Coverage for src/bioimageio/spec/model/v0_5.py: 71%

1660 statements  

« prev     ^ index     » next       coverage.py v7.14.1, created at 2026-06-15 15:08 +0000

1from __future__ import annotations 

2 

3import collections.abc 

4import re 

5import string 

6import warnings 

7from copy import deepcopy 

8from functools import partial 

9from itertools import chain 

10from math import ceil 

11from pathlib import Path, PurePosixPath 

12from tempfile import mkdtemp 

13from textwrap import dedent 

14from typing import ( 

15 TYPE_CHECKING, 

16 Any, 

17 Callable, 

18 ClassVar, 

19 Dict, 

20 Generic, 

21 List, 

22 Literal, 

23 Mapping, 

24 NamedTuple, 

25 Optional, 

26 Sequence, 

27 Set, 

28 Tuple, 

29 Type, 

30 TypeVar, 

31 Union, 

32 cast, 

33 overload, 

34) 

35 

36import numpy as np 

37from annotated_types import Ge, Gt, Interval, MaxLen, MinLen, Predicate 

38from imageio.v3 import imread, imwrite # pyright: ignore[reportUnknownVariableType] 

39from loguru import logger 

40from numpy.typing import NDArray 

41from pydantic import ( 

42 AfterValidator, 

43 Discriminator, 

44 Field, 

45 RootModel, 

46 SerializationInfo, 

47 SerializerFunctionWrapHandler, 

48 StrictInt, 

49 Tag, 

50 ValidationInfo, 

51 WrapSerializer, 

52 field_validator, 

53 model_serializer, 

54 model_validator, 

55) 

56from typing_extensions import Annotated, Self, TypeAlias, assert_never, get_args 

57 

58from .._internal.common_nodes import ( 

59 InvalidDescr, 

60 KwargsNode, 

61 Node, 

62 NodeWithExplicitlySetFields, 

63) 

64from .._internal.constants import DTYPE_LIMITS 

65from .._internal.field_warning import issue_warning, warn 

66from .._internal.io import BioimageioYamlContent as BioimageioYamlContent 

67from .._internal.io import FileDescr as FileDescr 

68from .._internal.io import ( 

69 FileSource, 

70 WithSuffix, 

71 YamlValue, 

72 extract_file_name, 

73 get_reader, 

74 wo_special_file_name, 

75) 

76from .._internal.io_basics import Sha256 as Sha256 

77from .._internal.io_packaging import ( 

78 FileDescr_package, 

79 package_file_descr_serializer, 

80) 

81from .._internal.io_utils import load_array 

82from .._internal.node_converter import Converter 

83from .._internal.type_guards import is_dict, is_sequence 

84from .._internal.types import ( 

85 FAIR, 

86 AbsoluteTolerance, 

87 LowerCaseIdentifier, 

88 LowerCaseIdentifierAnno, 

89 MismatchedElementsPerMillion, 

90 RelativeTolerance, 

91) 

92from .._internal.types import Datetime as Datetime 

93from .._internal.types import Identifier as Identifier 

94from .._internal.types import NotEmpty as NotEmpty 

95from .._internal.types import SiUnit as SiUnit 

96from .._internal.url import HttpUrl as HttpUrl 

97from .._internal.utils import try_all_raise_last 

98from .._internal.validation_context import get_validation_context 

99from .._internal.validator_annotations import RestrictCharacters 

100from .._internal.version_type import Version as Version 

101from .._internal.warning_levels import INFO 

102from ..dataset.v0_2 import DatasetDescr as DatasetDescr02 

103from ..dataset.v0_2 import LinkedDataset as LinkedDataset02 

104from ..dataset.v0_3 import DatasetDescr as DatasetDescr 

105from ..dataset.v0_3 import DatasetId as DatasetId 

106from ..dataset.v0_3 import LinkedDataset as LinkedDataset 

107from ..dataset.v0_3 import Uploader as Uploader 

108from ..generic._v0_3_converter import convert_plain_covers_and_docs_and_icon 

109from ..generic.v0_3 import ( 

110 VALID_COVER_IMAGE_EXTENSIONS as VALID_COVER_IMAGE_EXTENSIONS, 

111) 

112from ..generic.v0_3 import Author as Author 

113from ..generic.v0_3 import BadgeDescr as BadgeDescr 

114from ..generic.v0_3 import CiteEntry as CiteEntry 

115from ..generic.v0_3 import DeprecatedLicenseId as DeprecatedLicenseId 

116from ..generic.v0_3 import Doi as Doi 

117from ..generic.v0_3 import ( 

118 FileDescr_documentation, 

119 GenericModelDescrBase, 

120 LinkedResourceBase, 

121 _author_conv, # pyright: ignore[reportPrivateUsage] 

122 _maintainer_conv, # pyright: ignore[reportPrivateUsage] 

123) 

124from ..generic.v0_3 import LicenseId as LicenseId 

125from ..generic.v0_3 import LinkedResource as LinkedResource 

126from ..generic.v0_3 import Maintainer as Maintainer 

127from ..generic.v0_3 import OrcidId as OrcidId 

128from ..generic.v0_3 import RelativeFilePath as RelativeFilePath 

129from ..generic.v0_3 import ResourceId as ResourceId 

130from .v0_4 import Author as _Author_v0_4 

131from .v0_4 import BinarizeDescr as _BinarizeDescr_v0_4 

132from .v0_4 import CallableFromDepencency as CallableFromDepencency 

133from .v0_4 import CallableFromDepencency as _CallableFromDepencency_v0_4 

134from .v0_4 import CallableFromFile as _CallableFromFile_v0_4 

135from .v0_4 import ClipDescr as _ClipDescr_v0_4 

136from .v0_4 import ImplicitOutputShape as _ImplicitOutputShape_v0_4 

137from .v0_4 import InputTensorDescr as _InputTensorDescr_v0_4 

138from .v0_4 import KnownRunMode as KnownRunMode 

139from .v0_4 import ModelDescr as _ModelDescr_v0_4 

140from .v0_4 import OutputTensorDescr as _OutputTensorDescr_v0_4 

141from .v0_4 import ParameterizedInputShape as _ParameterizedInputShape_v0_4 

142from .v0_4 import PostprocessingDescr as _PostprocessingDescr_v0_4 

143from .v0_4 import PreprocessingDescr as _PreprocessingDescr_v0_4 

144from .v0_4 import RunMode as RunMode 

145from .v0_4 import ScaleLinearDescr as _ScaleLinearDescr_v0_4 

146from .v0_4 import ScaleMeanVarianceDescr as _ScaleMeanVarianceDescr_v0_4 

147from .v0_4 import ScaleRangeDescr as _ScaleRangeDescr_v0_4 

148from .v0_4 import SigmoidDescr as _SigmoidDescr_v0_4 

149from .v0_4 import TensorName as _TensorName_v0_4 

150from .v0_4 import ZeroMeanUnitVarianceDescr as _ZeroMeanUnitVarianceDescr_v0_4 

151from .v0_4 import package_weights 

152 

# Units accepted for the `unit` field of space axes (see `SpaceAxisBase`).
SpaceUnit = Literal[
    "attometer",
    "angstrom",
    "centimeter",
    "decimeter",
    "exameter",
    "femtometer",
    "foot",
    "gigameter",
    "hectometer",
    "inch",
    "kilometer",
    "megameter",
    "meter",
    "micrometer",
    "mile",
    "millimeter",
    "nanometer",
    "parsec",
    "petameter",
    "picometer",
    "terameter",
    "yard",
    "yoctometer",
    "yottameter",
    "zeptometer",
    "zettameter",
]
"""Space unit compatible with the [OME-Zarr axes specification 0.5](https://ngff.openmicroscopy.org/0.5/#axes-md)"""

# Units accepted for the `unit` field of time axes (see `TimeAxisBase`).
TimeUnit = Literal[
    "attosecond",
    "centisecond",
    "day",
    "decisecond",
    "exasecond",
    "femtosecond",
    "gigasecond",
    "hectosecond",
    "hour",
    "kilosecond",
    "megasecond",
    "microsecond",
    "millisecond",
    "minute",
    "nanosecond",
    "petasecond",
    "picosecond",
    "second",
    "terasecond",
    "yoctosecond",
    "yottasecond",
    "zeptosecond",
    "zettasecond",
]
"""Time unit compatible with the [OME-Zarr axes specification 0.5](https://ngff.openmicroscopy.org/0.5/#axes-md)"""

209 

# All axis kinds; every concrete axis class below declares one of these as its `type`.
AxisType = Literal["batch", "channel", "index", "time", "space"]

# Maps single-letter axis names to their axis type ("x", "y", "z" are all spatial).
# NOTE(review): presumably used when converting letter-based axis strings
# (e.g. "bcyx"); the consumer is outside this section — confirm.
_AXIS_TYPE_MAP: Mapping[str, AxisType] = {
    "b": "batch",
    "t": "time",
    "i": "index",
    "c": "channel",
    "x": "space",
    "y": "space",
    "z": "space",
}

# Short axis ids that `_normalize_axis_id` expands to full names.
# Spatial letters ("x", "y", "z") are not listed and are therefore kept unchanged.
_AXIS_ID_MAP: Mapping[str, str] = {
    "b": "batch",
    "t": "time",
    "i": "index",
    "c": "channel",
}

# Identifiers of the weight formats known to this spec version.
WeightsFormat = Literal[
    "keras_hdf5",
    "keras_v3",
    "onnx",
    "pytorch_state_dict",
    "tensorflow_js",
    "tensorflow_saved_model_bundle",
    "torchscript",
]

238 

239 

class TensorId(LowerCaseIdentifier):
    # A tensor id is a lower-case identifier (as validated by
    # `LowerCaseIdentifierAnno`) of at most 32 characters.
    root_model: ClassVar[Type[RootModel[Any]]] = RootModel[
        Annotated[LowerCaseIdentifierAnno, MaxLen(32)]
    ]

244 

245 

def _normalize_axis_id(a: str):
    """Expand a short axis id to its full name.

    Ids found in `_AXIS_ID_MAP` (e.g. "b" -> "batch") are replaced and a warning
    is logged; every other id is returned unchanged (converted to `str`).
    """
    axis_id = str(a)
    full_name = _AXIS_ID_MAP.get(axis_id, axis_id)
    if full_name == axis_id:
        return full_name

    # depth=3 attributes the log record to a frame further up the call stack
    logger.opt(depth=3).warning(
        "Normalized axis id from '{}' to '{}'.", axis_id, full_name
    )
    return full_name

254 

255 

class AxisId(LowerCaseIdentifier):
    # An axis id is a lower-case identifier of at most 16 characters.
    # After validation, short ids such as "b" or "c" are expanded to their full
    # names by `_normalize_axis_id` (which logs a warning when it does so).
    root_model: ClassVar[Type[RootModel[Any]]] = RootModel[
        Annotated[
            LowerCaseIdentifierAnno,
            MaxLen(16),
            AfterValidator(_normalize_axis_id),
        ]
    ]

264 

265 

266def _is_batch(a: str) -> bool: 

267 return str(a) == "batch" 

268 

269 

270def _is_not_batch(a: str) -> bool: 

271 return not _is_batch(a) 

272 

273 

# An `AxisId` that may be anything except "batch" (enforced by `_is_not_batch`).
NonBatchAxisId = Annotated[AxisId, Predicate(_is_not_batch)]

# Identifiers of preprocessing operations.
PreprocessingId = Literal[
    "binarize",
    "clip",
    "ensure_dtype",
    "fixed_zero_mean_unit_variance",
    "scale_linear",
    "scale_range",
    "sigmoid",
    "softmax",
]
# Identifiers of postprocessing operations; in addition to the preprocessing ids
# this includes "custom", "scale_mean_variance" and "zero_mean_unit_variance".
PostprocessingId = Literal[
    "binarize",
    "clip",
    "custom",
    "ensure_dtype",
    "fixed_zero_mean_unit_variance",
    "scale_linear",
    "scale_mean_variance",
    "scale_range",
    "sigmoid",
    "softmax",
    "zero_mean_unit_variance",
]

299 

300 

# Placeholder string value.
# NOTE(review): its consumers are not visible in this section; presumably it marks
# a field that should take the same value as a `type` field — confirm.
SAME_AS_TYPE = "<same as type>"


# Alias for the integer blocksize parameter `n` of `ParameterizedSize`.
ParameterizedSize_N: TypeAlias = int
"""
Annotates an integer to calculate a concrete axis size from a `ParameterizedSize`.
"""

308 

309 

class ParameterizedSize(Node):
    """Describes a range of valid tensor axis sizes as `size = min + n*step`.

    - **min** and **step** are given by the model description.
    - All blocksize parameters n = 0,1,2,... yield a valid `size`.
    - A greater blocksize parameter n = 0,1,2,... results in a greater **size**.
      This allows adjusting the axis size more generically.
    """

    N: ClassVar[Type[int]] = ParameterizedSize_N
    """Non-negative integer to parameterize this axis"""

    min: Annotated[int, Gt(0)]
    step: Annotated[int, Gt(0)]

    def validate_size(self, size: int, msg_prefix: str = "") -> int:
        """Return `size` if it equals `min + n*step` for some n >= 0.

        Raises:
            ValueError: if `size < min` or `size` is not on the `min + n*step` grid.
        """
        if size < self.min:
            raise ValueError(
                f"{msg_prefix}size {size} < {self.min} (minimum axis size)"
            )
        if (size - self.min) % self.step != 0:
            raise ValueError(
                f"{msg_prefix}size {size} is not parameterized by `min + n*step` ="
                + f" `{self.min} + n*{self.step}`"
            )

        return size

    def get_size(self, n: ParameterizedSize_N) -> int:
        """Return the axis size `min + step * n` for blocksize parameter `n`."""
        return self.min + self.step * n

    def get_n(self, s: int) -> ParameterizedSize_N:
        """Return the smallest n parameterizing a size greater or equal than `s`.

        The result is never negative: for `s <= min` this returns 0, because
        `min` (n = 0) is the smallest valid size.
        """
        # ceil((s - min) / step), computed exactly with integer floor division
        # (true division loses precision for very large sizes).
        n = -((self.min - s) // self.step)
        # a negative n would produce a size below `min`, which is invalid
        return max(0, int(n))

344 

345 

class DataDependentSize(Node):
    # Axis size known only after model inference; bounded below by `min` and,
    # optionally, above by `max`.
    min: Annotated[int, Gt(0)] = 1
    max: Annotated[Optional[int], Gt(1)] = None

    @model_validator(mode="after")
    def _validate_max_gt_min(self):
        # an upper bound, when given, has to lie strictly above the lower bound
        if self.max is None or self.min < self.max:
            return self

        raise ValueError(f"expected `min` < `max`, but got {self.min}, {self.max}")

    def validate_size(self, size: int, msg_prefix: str = "") -> int:
        """Return `size` unchanged if `min <= size <= max` (max optional).

        Raises:
            ValueError: if `size` falls outside the allowed bounds.
        """
        below_min = size < self.min
        if below_min:
            raise ValueError(f"{msg_prefix}size {size} < {self.min}")

        above_max = self.max is not None and size > self.max
        if above_max:
            raise ValueError(f"{msg_prefix}size {size} > {self.max}")

        return size

365 

366 

class SizeReference(Node):
    """A tensor axis size (extent in pixels/frames) defined in relation to a reference axis.

    `axis.size = reference.size * reference.scale / axis.scale + offset`

    Note:
    1. The axis and the referenced axis need to have the same unit (or no unit).
    2. Batch axes may not be referenced.
    3. Fractions are rounded down.
    4. If the reference axis is `concatenable` the referencing axis is assumed to be
        `concatenable` as well with the same block order.

    Example:
    An anisotropic input image of w*h=100*49 pixels depicts a physical space of 200*196mm².
    Let's assume that we want to express the image height h in relation to its width w
    instead of only accepting input images of exactly 100*49 pixels
    (for example to express a range of valid image shapes by parameterizing w, see `ParameterizedSize`).

    >>> w = SpaceInputAxis(id=AxisId("w"), size=100, unit="millimeter", scale=2)
    >>> h = SpaceInputAxis(
    ...     id=AxisId("h"),
    ...     size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("w"), offset=-1),
    ...     unit="millimeter",
    ...     scale=4,
    ... )
    >>> print(h.size.get_size(h, w))
    49

    ⇒ h = w * w.scale / h.scale + offset = 100 * 2mm / 4mm - 1 = 49
    """

    tensor_id: TensorId
    """tensor id of the reference axis"""

    axis_id: AxisId
    """axis id of the reference axis"""

    offset: StrictInt = 0

    def get_size(
        self,
        axis: Union[
            ChannelAxis,
            IndexInputAxis,
            IndexOutputAxis,
            TimeInputAxis,
            SpaceInputAxis,
            TimeOutputAxis,
            TimeOutputAxisWithHalo,
            SpaceOutputAxis,
            SpaceOutputAxisWithHalo,
        ],
        ref_axis: Union[
            ChannelAxis,
            IndexInputAxis,
            IndexOutputAxis,
            TimeInputAxis,
            SpaceInputAxis,
            TimeOutputAxis,
            TimeOutputAxisWithHalo,
            SpaceOutputAxis,
            SpaceOutputAxisWithHalo,
        ],
        n: ParameterizedSize_N = 0,
        ref_size: Optional[int] = None,
    ):
        """Resolve the concrete size of **axis** from its reference axis.

        Args:
            axis: The axis whose `size` is this [SizeReference][].
            ref_axis: The axis referenced by **axis_id**.
            n: Blocksize parameter used to evaluate a [ParameterizedSize][]
                **ref_axis** when no **ref_size** is given.
            ref_size: Explicit size of the reference axis; when given,
                **ref_axis.size** and **n** are ignored
                (**ref_axis.scale** is still applied).

        Returns:
            `int(ref_size * ref_axis.scale / axis.scale + offset)`

        Raises:
            ValueError: if **ref_size** is not given and **ref_axis** is sized by a
                `DataDependentSize` or by another `SizeReference`.
        """
        # consistency checks on the call itself (assertions, not input validation)
        assert axis.size == self, (
            "Given `axis.size` is not defined by this `SizeReference`"
        )
        assert ref_axis.id == self.axis_id, (
            f"Expected `ref_axis.id` to be {self.axis_id}, but got {ref_axis.id}."
        )
        assert axis.unit == ref_axis.unit, (
            "`SizeReference` requires `axis` and `ref_axis` to have the same `unit`,"
            f" but {axis.unit}!={ref_axis.unit}"
        )

        if ref_size is None:
            declared = ref_axis.size
            # reject size kinds that cannot be resolved statically first
            if isinstance(declared, DataDependentSize):
                raise ValueError(
                    "Reference axis referenced in `SizeReference` may not be a `DataDependentSize`."
                )
            if isinstance(declared, SizeReference):
                raise ValueError(
                    "Reference axis referenced in `SizeReference` may not be sized by a"
                    " `SizeReference` itself."
                )
            if isinstance(declared, ParameterizedSize):
                ref_size = declared.get_size(n)
            elif isinstance(declared, (int, float)):
                ref_size = declared
            else:
                assert_never(declared)

        scaled_size = ref_size * ref_axis.scale / axis.scale
        return int(scaled_size + self.offset)

    @staticmethod
    def _get_unit(
        axis: Union[
            ChannelAxis,
            IndexInputAxis,
            IndexOutputAxis,
            TimeInputAxis,
            SpaceInputAxis,
            TimeOutputAxis,
            TimeOutputAxisWithHalo,
            SpaceOutputAxis,
            SpaceOutputAxisWithHalo,
        ],
    ):
        """Return the `unit` attribute of **axis**."""
        unit = axis.unit
        return unit

491 

492 

493class AxisBase(NodeWithExplicitlySetFields): 

494 id: AxisId 

495 """An axis id unique across all axes of one tensor.""" 

496 

497 description: Annotated[str, MaxLen(128)] = "" 

498 """A short description of this axis beyond its type and id.""" 

499 

500 

class WithHalo(Node):
    # Mixin for output axes with a halo (`TimeOutputAxisWithHalo`,
    # `SpaceOutputAxisWithHalo`).
    halo: Annotated[int, Ge(1)]
    """The halo should be cropped from the output tensor to avoid boundary effects.
    It is to be cropped from both sides, i.e. `size_after_crop = size - 2 * halo`.
    To document a halo that is already cropped by the model use `size.offset` instead."""

    # Only a `SizeReference` is accepted here, so the schema examples must not
    # contain a plain integer (it would not validate against this field).
    size: Annotated[
        SizeReference,
        Field(
            examples=[
                SizeReference(
                    tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
                ).model_dump(mode="json"),
            ]
        ),
    ]
    """reference to another axis with an optional offset (see [SizeReference][])"""

519 

520 

# Default id of a batch axis; `BatchAxis.id` is restricted to this value.
BATCH_AXIS_ID = AxisId("batch")


class BatchAxis(AxisBase):
    # Class-level marker of the axis type implemented by this class.
    implemented_type: ClassVar[Literal["batch"]] = "batch"
    # Static type checkers see `type` with a default; at runtime it is declared
    # without one. NOTE(review): presumably filled in from `implemented_type` by a
    # base class — confirm.
    if TYPE_CHECKING:
        type: Literal["batch"] = "batch"
    else:
        type: Literal["batch"]

    # Only "batch" passes the `_is_batch` predicate.
    id: Annotated[AxisId, Predicate(_is_batch)] = BATCH_AXIS_ID
    size: Optional[Literal[1]] = None
    """The batch size may be fixed to 1,
    otherwise (the default) it may be chosen arbitrarily depending on available memory"""

    # Fixed (non-field) properties: scale 1.0, always concatenable, no unit.
    @property
    def scale(self):
        return 1.0

    @property
    def concatenable(self):
        return True

    @property
    def unit(self):
        return None

547 

548 

class ChannelAxis(AxisBase):
    # Class-level marker of the axis type implemented by this class.
    implemented_type: ClassVar[Literal["channel"]] = "channel"
    # Static type checkers see `type` with a default; at runtime it is declared
    # without one. NOTE(review): presumably filled in from `implemented_type` by a
    # base class — confirm.
    if TYPE_CHECKING:
        type: Literal["channel"] = "channel"
    else:
        type: Literal["channel"]

    # Any id except "batch" (see `NonBatchAxisId`).
    id: NonBatchAxisId = AxisId("channel")

    # One (non-empty) list entry per channel; determines `size`.
    channel_names: NotEmpty[List[str]]

    @property
    def size(self) -> int:
        # the channel axis size is the number of channel names
        return len(self.channel_names)

    @property
    def concatenable(self):
        # channel axes are never concatenable
        return False

    @property
    def scale(self) -> float:
        return 1.0

    @property
    def unit(self):
        return None

575 

576 

class _WithInputAxisSize(Node):
    # Mixin providing the `size` field of index/time/space input axes.
    size: Annotated[
        Union[Annotated[int, Gt(0)], ParameterizedSize, SizeReference],
        Field(
            examples=[
                10,
                ParameterizedSize(min=32, step=16).model_dump(mode="json"),
                SizeReference(
                    tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
                ).model_dump(mode="json"),
            ]
        ),
    ]
    """The size/length of this axis can be specified as
    - fixed integer
    - parameterized series of valid sizes ([ParameterizedSize][])
    - reference to another axis with an optional offset ([SizeReference][])
    """

595 

596 

class IndexAxisBase(AxisBase):
    # Common base of `IndexInputAxis` and `IndexOutputAxis`.
    implemented_type: ClassVar[Literal["index"]] = "index"
    # Static type checkers see `type` with a default; at runtime it is declared
    # without one.
    if TYPE_CHECKING:
        type: Literal["index"] = "index"
    else:
        type: Literal["index"]

    # Any id except "batch" (see `NonBatchAxisId`).
    id: NonBatchAxisId = AxisId("index")

    # Index axes have a fixed scale of 1.0 and no unit.
    @property
    def scale(self) -> float:
        return 1.0

    @property
    def unit(self):
        return None

613 

614 

class IndexInputAxis(IndexAxisBase, _WithInputAxisSize):
    # Index axis of an input tensor; its `size` field comes from `_WithInputAxisSize`.
    concatenable: bool = False
    """If a model has a `concatenable` input axis, it can be processed blockwise,
    splitting a longer sample axis into blocks matching its input tensor description.
    Output axes are concatenable if they have a [SizeReference][] to a concatenable
    input axis.
    """

622 

623 

class IndexOutputAxis(IndexAxisBase):
    # Index axis of an output tensor; unlike input axes its size may be data
    # dependent, but it cannot be parameterized.
    size: Annotated[
        Union[Annotated[int, Gt(0)], SizeReference, DataDependentSize],
        Field(
            examples=[
                10,
                SizeReference(
                    tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
                ).model_dump(mode="json"),
            ]
        ),
    ]
    """The size/length of this axis can be specified as
    - fixed integer
    - reference to another axis with an optional offset ([SizeReference][])
    - data dependent size using [DataDependentSize][] (size is only known after model inference)
    """

641 

642 

class TimeAxisBase(AxisBase):
    # Common base of all time axes (input and output, with or without halo).
    implemented_type: ClassVar[Literal["time"]] = "time"
    # Static type checkers see `type` with a default; at runtime it is declared
    # without one.
    if TYPE_CHECKING:
        type: Literal["time"] = "time"
    else:
        type: Literal["time"]

    # Any id except "batch" (see `NonBatchAxisId`).
    id: NonBatchAxisId = AxisId("time")
    # Optional physical time unit and a strictly positive scale (default 1.0).
    unit: Optional[TimeUnit] = None
    scale: Annotated[float, Gt(0)] = 1.0

653 

654 

class TimeInputAxis(TimeAxisBase, _WithInputAxisSize):
    # Time axis of an input tensor; its `size` field comes from `_WithInputAxisSize`.
    concatenable: bool = False
    """If a model has a `concatenable` input axis, it can be processed blockwise,
    splitting a longer sample axis into blocks matching its input tensor description.
    Output axes are concatenable if they have a [SizeReference][] to a concatenable
    input axis.
    """

662 

663 

class SpaceAxisBase(AxisBase):
    # Common base of all space axes (input and output, with or without halo).
    implemented_type: ClassVar[Literal["space"]] = "space"
    # Static type checkers see `type` with a default; at runtime it is declared
    # without one.
    if TYPE_CHECKING:
        type: Literal["space"] = "space"
    else:
        type: Literal["space"]

    # Any id except "batch" (see `NonBatchAxisId`); defaults to "x".
    id: Annotated[NonBatchAxisId, Field(examples=["x", "y", "z"])] = AxisId("x")
    # Optional physical space unit and a strictly positive scale (default 1.0).
    unit: Optional[SpaceUnit] = None
    scale: Annotated[float, Gt(0)] = 1.0

674 

675 

class SpaceInputAxis(SpaceAxisBase, _WithInputAxisSize):
    # Space axis of an input tensor; its `size` field comes from `_WithInputAxisSize`.
    concatenable: bool = False
    """If a model has a `concatenable` input axis, it can be processed blockwise,
    splitting a longer sample axis into blocks matching its input tensor description.
    Output axes are concatenable if they have a [SizeReference][] to a concatenable
    input axis.
    """

683 

684 

# Concrete input axis classes as a tuple, usable with `isinstance`.
INPUT_AXIS_TYPES = (
    BatchAxis,
    ChannelAxis,
    IndexInputAxis,
    TimeInputAxis,
    SpaceInputAxis,
)
"""intended for isinstance comparisons in py<3.10"""

_InputAxisUnion = Union[
    BatchAxis, ChannelAxis, IndexInputAxis, TimeInputAxis, SpaceInputAxis
]
# Input axis union; pydantic selects the member by the `type` field.
InputAxis = Annotated[_InputAxisUnion, Discriminator("type")]

698 

699 

class _WithOutputAxisSize(Node):
    # Mixin providing the `size` field of time/space output axes without halo
    # (`TimeOutputAxis`, `SpaceOutputAxis`).
    size: Annotated[
        Union[Annotated[int, Gt(0)], SizeReference],
        Field(
            examples=[
                10,
                SizeReference(
                    tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
                ).model_dump(mode="json"),
            ]
        ),
    ]
    """The size/length of this axis can be specified as
    - fixed integer
    - reference to another axis with an optional offset (see [SizeReference][])
    """

716 

717 

# Time output axis sized by a fixed integer or a `SizeReference` (no halo).
class TimeOutputAxis(TimeAxisBase, _WithOutputAxisSize):
    pass


# Time output axis with a `halo`; its size must be a `SizeReference`.
class TimeOutputAxisWithHalo(TimeAxisBase, WithHalo):
    pass

724 

725 

726def _get_halo_axis_discriminator_value(v: Any) -> Literal["with_halo", "wo_halo"]: 

727 if isinstance(v, dict): 

728 return "with_halo" if "halo" in v else "wo_halo" 

729 else: 

730 return "with_halo" if hasattr(v, "halo") else "wo_halo" 

731 

732 

# Time output axis with or without halo; the member is chosen by
# `_get_halo_axis_discriminator_value` (presence of a `halo` key/attribute).
_TimeOutputAxisUnion = Annotated[
    Union[
        Annotated[TimeOutputAxis, Tag("wo_halo")],
        Annotated[TimeOutputAxisWithHalo, Tag("with_halo")],
    ],
    Discriminator(_get_halo_axis_discriminator_value),
]


# Space output axis sized by a fixed integer or a `SizeReference` (no halo).
class SpaceOutputAxis(SpaceAxisBase, _WithOutputAxisSize):
    pass


# Space output axis with a `halo`; its size must be a `SizeReference`.
class SpaceOutputAxisWithHalo(SpaceAxisBase, WithHalo):
    pass


# Space output axis with or without halo, discriminated like `_TimeOutputAxisUnion`.
_SpaceOutputAxisUnion = Annotated[
    Union[
        Annotated[SpaceOutputAxis, Tag("wo_halo")],
        Annotated[SpaceOutputAxisWithHalo, Tag("with_halo")],
    ],
    Discriminator(_get_halo_axis_discriminator_value),
]


# All output axis variants. The outer union is discriminated by `type`; the
# time/space members are further discriminated by the presence of a halo.
_OutputAxisUnion = Union[
    BatchAxis, ChannelAxis, IndexOutputAxis, _TimeOutputAxisUnion, _SpaceOutputAxisUnion
]
OutputAxis = Annotated[_OutputAxisUnion, Discriminator("type")]

# Concrete output axis classes as a tuple, usable with `isinstance`.
OUTPUT_AXIS_TYPES = (
    BatchAxis,
    ChannelAxis,
    IndexOutputAxis,
    TimeOutputAxis,
    TimeOutputAxisWithHalo,
    SpaceOutputAxis,
    SpaceOutputAxisWithHalo,
)
"""intended for isinstance comparisons in py<3.10"""


# Any input or output axis.
AnyAxis = Union[InputAxis, OutputAxis]

# Concatenation of both tuples; `BatchAxis` and `ChannelAxis` appear twice,
# which is harmless for `isinstance`.
ANY_AXIS_TYPES = INPUT_AXIS_TYPES + OUTPUT_AXIS_TYPES
"""intended for isinstance comparisons in py<3.10"""

780 

# Non-empty list of values of a single primitive type
# (used for `NominalOrOrdinalDataDescr.values`).
TVs = Union[
    NotEmpty[List[int]],
    NotEmpty[List[float]],
    NotEmpty[List[bool]],
    NotEmpty[List[str]],
]


# Data types allowed for nominal/ordinal tensor data (numeric types plus "bool").
NominalOrOrdinalDType = Literal[
    "float32",
    "float64",
    "uint8",
    "int8",
    "uint16",
    "int16",
    "uint32",
    "int32",
    "uint64",
    "int64",
    "bool",
]

802 

803 

class NominalOrOrdinalDataDescr(Node):
    values: TVs
    """A fixed set of nominal or an ascending sequence of ordinal values.
    String `values` are interpreted as labels for tensor values 0, ..., N-1
    (N being the number of labels); in this case `data.type` is required to be
    an unsigned integer type, e.g. 'uint8'.
    Note: as YAML 1.2 does not natively support a "set" datatype,
    nominal values should be given as a sequence (aka list/array) as well.
    """

    type: Annotated[
        NominalOrOrdinalDType,
        Field(
            examples=[
                "float32",
                "uint8",
                "uint16",
                "int64",
                "bool",
            ],
        ),
    ] = "uint8"

    @model_validator(mode="after")
    def _validate_values_match_type(
        self,
    ) -> Self:
        """Raise a ValueError if any of `values` is not representable by `type`.

        At most 5 offending values are listed in the error message; "..." is
        appended only if further incompatible values exist.
        """
        max_reported = 5
        incompatible: List[Any] = []
        for v in self.values:
            if self.type == "bool":
                if not isinstance(v, bool):
                    incompatible.append(v)
            elif self.type in DTYPE_LIMITS:
                if (
                    isinstance(v, (int, float))
                    and (
                        v < DTYPE_LIMITS[self.type].min
                        or v > DTYPE_LIMITS[self.type].max
                    )
                    # string labels require an unsigned integer type
                    or (isinstance(v, str) and "uint" not in self.type)
                    or (isinstance(v, float) and "int" in self.type)
                ):
                    incompatible.append(v)
            else:
                incompatible.append(v)

            if len(incompatible) > max_reported:
                # more incompatible values than we report: mark the truncation
                incompatible[max_reported:] = ["..."]
                break

        if incompatible:
            raise ValueError(
                f"data type '{self.type}' incompatible with values {incompatible}"
            )

        return self

    unit: Optional[Union[Literal["arbitrary unit"], SiUnit]] = None

    @property
    def range(self):
        """`(min, max)` of the tensor values; label indices `(0, N-1)` for string values."""
        if isinstance(self.values[0], str):
            return 0, len(self.values) - 1
        else:
            return min(self.values), max(self.values)

868 

869 

# Data types allowed for interval/ratio tensor data: the numeric types of
# `NominalOrOrdinalDType` (i.e. without "bool").
IntervalOrRatioDType = Literal[
    "float32",
    "float64",
    "uint8",
    "int8",
    "uint16",
    "int16",
    "uint32",
    "int32",
    "uint64",
    "int64",
]

882 

883 

class IntervalOrRatioDataDescr(Node):
    type: Annotated[  # TODO: rename to dtype
        IntervalOrRatioDType,
        Field(
            examples=["float32", "float64", "uint8", "uint16"],
        ),
    ] = "float32"
    range: Tuple[Optional[float], Optional[float]] = (
        None,
        None,
    )
    """Tuple `(minimum, maximum)` specifying the allowed range of the data in this tensor.
    `None` corresponds to min/max of what can be expressed by **type**."""
    unit: Union[Literal["arbitrary unit"], SiUnit] = "arbitrary unit"
    scale: float = 1.0
    """Scale for data on an interval (or ratio) scale."""
    offset: Optional[float] = None
    """Offset for data on a ratio scale."""

    @model_validator(mode="before")
    def _replace_inf(cls, data: Any):
        """Replace infinite `range` bounds with `None` (i.e. unbounded).

        A warning is issued if any bound was infinite. The input mapping is
        shallow-copied before modification so the caller's data is not mutated.
        """
        if is_dict(data):
            if "range" in data and is_sequence(data["range"]):
                forbidden = (
                    "inf",
                    "-inf",
                    ".inf",
                    "-.inf",
                    float("inf"),
                    float("-inf"),
                )
                if any(v in forbidden for v in data["range"]):
                    issue_warning("replaced 'inf' value", value=data["range"])

                # copy so that validating does not change the caller's input
                data = dict(data)
                data["range"] = tuple(
                    (None if v in forbidden else v) for v in data["range"]
                )

        return data

923 

924 

# Description of a tensor's data: nominal/ordinal or interval/ratio.
TensorDataDescr = Union[NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr]


class BinarizeKwargs(KwargsNode):
    """key word arguments for [BinarizeDescr][]"""

    # A single scalar threshold (see `BinarizeAlongAxisKwargs` for one
    # threshold per entry along an axis).
    threshold: float
    """The fixed threshold"""


class BinarizeAlongAxisKwargs(KwargsNode):
    """key word arguments for [BinarizeDescr][]"""

    threshold: NotEmpty[List[float]]
    """The fixed threshold values along `axis`"""

    # Any axis except "batch" (see `NonBatchAxisId`).
    axis: Annotated[NonBatchAxisId, Field(examples=["channel"])]
    """The `threshold` axis"""

943 

944 

class BinarizeDescr(NodeWithExplicitlySetFields):
    """Binarize the tensor with a fixed threshold.

    Values above [BinarizeKwargs.threshold][]/[BinarizeAlongAxisKwargs.threshold][]
    will be set to one, values below the threshold to zero.

    Examples:
    - in YAML
        ```yaml
        postprocessing:
          - id: binarize
            kwargs:
              axis: 'channel'
              threshold: [0.25, 0.5, 0.75]
        ```
    - in Python:
        >>> postprocessing = [BinarizeDescr(
        ...   kwargs=BinarizeAlongAxisKwargs(
        ...       axis=AxisId('channel'),
        ...       threshold=[0.25, 0.5, 0.75],
        ...   )
        ... )]
    """

    # Class-level marker of the processing id implemented by this class.
    implemented_id: ClassVar[Literal["binarize"]] = "binarize"
    # Static type checkers see `id` with a default; at runtime it is declared
    # without one. NOTE(review): presumably filled in from `implemented_id` by a
    # base class — confirm.
    if TYPE_CHECKING:
        id: Literal["binarize"] = "binarize"
    else:
        id: Literal["binarize"]
    # Either one scalar threshold or one threshold per entry along an axis.
    kwargs: Union[BinarizeKwargs, BinarizeAlongAxisKwargs]

976 

977 

class ClipKwargs(KwargsNode):
    """key word arguments for [ClipDescr][]"""

    min: Optional[float] = None
    """Minimum value for clipping.

    Exclusive with [min_percentile][].
    Must not exceed `max` if both are set.
    """
    min_percentile: Optional[Annotated[float, Interval(ge=0, lt=100)]] = None
    """Minimum percentile for clipping.

    Exclusive with [min][].

    In range [0, 100).
    Must be smaller than `max_percentile` if both are set.
    """

    max: Optional[float] = None
    """Maximum value for clipping.

    Exclusive with `max_percentile`.
    """
    max_percentile: Optional[Annotated[float, Interval(gt=1, le=100)]] = None
    """Maximum percentile for clipping.

    Exclusive with `max`.

    In range (1, 100].
    """

    axes: Annotated[
        Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
    ] = None
    """The subset of axes to determine percentiles jointly,

    i.e. axes to reduce to compute min/max from `min_percentile`/`max_percentile`.
    For example to clip 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
    resulting in a tensor of equal shape with clipped values per channel, specify `axes=('batch', 'x', 'y')`.
    To clip samples independently, leave out the 'batch' axis.

    Only valid if `min_percentile` and/or `max_percentile` are set.

    Default: Compute percentiles over all axes jointly."""

    @model_validator(mode="after")
    def _validate(self) -> Self:
        """Check mutual exclusivity, presence and ordering of the clipping bounds."""
        if (self.min is not None) and (self.min_percentile is not None):
            raise ValueError(
                "Only one of `min` and `min_percentile` may be set, not both."
            )
        if (self.max is not None) and (self.max_percentile is not None):
            raise ValueError(
                "Only one of `max` and `max_percentile` may be set, not both."
            )
        if (
            self.min is None
            and self.min_percentile is None
            and self.max is None
            and self.max_percentile is None
        ):
            raise ValueError(
                "At least one of `min`, `min_percentile`, `max`, or `max_percentile` must be set."
            )

        if (
            self.axes is not None
            and self.min_percentile is None
            and self.max_percentile is None
        ):
            raise ValueError(
                "If `axes` is set, at least one of `min_percentile` or `max_percentile` must be set."
            )

        # Contradictory bounds can be detected statically when both sides use the
        # same kind of bound (absolute values or percentiles).
        if (
            self.min_percentile is not None
            and self.max_percentile is not None
            and self.min_percentile >= self.max_percentile
        ):
            raise ValueError(
                f"min_percentile {self.min_percentile} >= max_percentile"
                + f" {self.max_percentile}"
            )
        if self.min is not None and self.max is not None and self.min > self.max:
            raise ValueError(f"min {self.min} > max {self.max}")

        return self

1051 

1052 

class ClipDescr(NodeWithExplicitlySetFields):
    """Set tensor values below min to min and above max to max.

    See `ScaleRangeDescr` for examples.
    """

    # Class-level marker of the processing id implemented by this class.
    implemented_id: ClassVar[Literal["clip"]] = "clip"
    # Static type checkers see `id` with a default; at runtime it is declared
    # without one.
    if TYPE_CHECKING:
        id: Literal["clip"] = "clip"
    else:
        id: Literal["clip"]

    kwargs: ClipKwargs


class EnsureDtypeKwargs(KwargsNode):
    """key word arguments for [EnsureDtypeDescr][]"""

    # Target data type; same choices as `NominalOrOrdinalDType`.
    dtype: Literal[
        "float32",
        "float64",
        "uint8",
        "int8",
        "uint16",
        "int16",
        "uint32",
        "int32",
        "uint64",
        "int64",
        "bool",
    ]


class EnsureDtypeDescr(NodeWithExplicitlySetFields):
    """Cast the tensor data type to `EnsureDtypeKwargs.dtype` (if not matching).

    This can for example be used to ensure the inner neural network model gets a
    different input tensor data type than the fully described bioimage.io model does.

    Examples:
        The described bioimage.io model (incl. preprocessing) accepts any
        float32-compatible tensor, normalizes it with percentiles and clipping and then
        casts it to uint8, which is what the neural network in this example expects.
        - in YAML
            ```yaml
            inputs:
            - data:
                type: float32  # described bioimage.io model is compatible with any float32 input tensor
              preprocessing:
              - id: scale_range
                  kwargs:
                  axes: ['y', 'x']
                  max_percentile: 99.8
                  min_percentile: 5.0
              - id: clip
                  kwargs:
                  min: 0.0
                  max: 1.0
              - id: ensure_dtype  # the neural network of the model requires uint8
                  kwargs:
                  dtype: uint8
            ```
        - in Python:
            >>> preprocessing = [
            ...     ScaleRangeDescr(
            ...         kwargs=ScaleRangeKwargs(
            ...           axes= (AxisId('y'), AxisId('x')),
            ...           max_percentile= 99.8,
            ...           min_percentile= 5.0,
            ...         )
            ...     ),
            ...     ClipDescr(kwargs=ClipKwargs(min=0.0, max=1.0)),
            ...     EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="uint8")),
            ... ]
    """

    # Class-level marker of the processing id implemented by this class.
    implemented_id: ClassVar[Literal["ensure_dtype"]] = "ensure_dtype"
    # Static type checkers see `id` with a default; at runtime it is declared
    # without one.
    if TYPE_CHECKING:
        id: Literal["ensure_dtype"] = "ensure_dtype"
    else:
        id: Literal["ensure_dtype"]

    kwargs: EnsureDtypeKwargs

1136 

1137 

1138class ScaleLinearKwargs(KwargsNode): 

1139 """Key word arguments for [ScaleLinearDescr][]""" 

1140 

1141 gain: float = 1.0 

1142 """multiplicative factor""" 

1143 

1144 offset: float = 0.0 

1145 """additive term""" 

1146 

1147 @model_validator(mode="after") 

1148 def _validate(self) -> Self: 

1149 if self.gain == 1.0 and self.offset == 0.0: 

1150 raise ValueError( 

1151 "Redundant linear scaling not allowd. Set `gain` != 1.0 and/or `offset`" 

1152 + " != 0.0." 

1153 ) 

1154 

1155 return self 

1156 

1157 

1158class ScaleLinearAlongAxisKwargs(KwargsNode): 

1159 """Key word arguments for [ScaleLinearDescr][]""" 

1160 

1161 axis: Annotated[NonBatchAxisId, Field(examples=["channel"])] 

1162 """The axis of gain and offset values.""" 

1163 

1164 gain: Union[float, NotEmpty[List[float]]] = 1.0 

1165 """multiplicative factor""" 

1166 

1167 offset: Union[float, NotEmpty[List[float]]] = 0.0 

1168 """additive term""" 

1169 

1170 @model_validator(mode="after") 

1171 def _validate(self) -> Self: 

1172 if isinstance(self.gain, list): 

1173 if isinstance(self.offset, list): 

1174 if len(self.gain) != len(self.offset): 

1175 raise ValueError( 

1176 f"Size of `gain` ({len(self.gain)}) and `offset` ({len(self.offset)}) must match." 

1177 ) 

1178 else: 

1179 self.offset = [float(self.offset)] * len(self.gain) 

1180 elif isinstance(self.offset, list): 

1181 self.gain = [float(self.gain)] * len(self.offset) 

1182 else: 

1183 raise ValueError( 

1184 "Do not specify an `axis` for scalar gain and offset values." 

1185 ) 

1186 

1187 if all(g == 1.0 for g in self.gain) and all(off == 0.0 for off in self.offset): 

1188 raise ValueError( 

1189 "Redundant linear scaling not allowd. Set `gain` != 1.0 and/or `offset`" 

1190 + " != 0.0." 

1191 ) 

1192 

1193 return self 

1194 

1195 

1196class ScaleLinearDescr(NodeWithExplicitlySetFields): 

1197 """Fixed linear scaling. 

1198 

1199 Examples: 

1200 1. Scale with scalar gain and offset 

1201 - in YAML 

1202 ```yaml 

1203 preprocessing: 

1204 - id: scale_linear 

1205 kwargs: 

1206 gain: 2.0 

1207 offset: 3.0 

1208 ``` 

1209 - in Python: 

1210 

1211 >>> preprocessing = [ 

1212 ... ScaleLinearDescr(kwargs=ScaleLinearKwargs(gain= 2.0, offset=3.0)) 

1213 ... ] 

1214 

1215 2. Independent scaling along an axis 

1216 - in YAML 

1217 ```yaml 

1218 preprocessing: 

1219 - id: scale_linear 

1220 kwargs: 

1221 axis: 'channel' 

1222 gain: [1.0, 2.0, 3.0] 

1223 ``` 

1224 - in Python: 

1225 

1226 >>> preprocessing = [ 

1227 ... ScaleLinearDescr( 

1228 ... kwargs=ScaleLinearAlongAxisKwargs( 

1229 ... axis=AxisId("channel"), 

1230 ... gain=[1.0, 2.0, 3.0], 

1231 ... ) 

1232 ... ) 

1233 ... ] 

1234 

1235 """ 

1236 

1237 implemented_id: ClassVar[Literal["scale_linear"]] = "scale_linear" 

1238 if TYPE_CHECKING: 

1239 id: Literal["scale_linear"] = "scale_linear" 

1240 else: 

1241 id: Literal["scale_linear"] 

1242 kwargs: Union[ScaleLinearKwargs, ScaleLinearAlongAxisKwargs] 

1243 

1244 

1245class SigmoidDescr(NodeWithExplicitlySetFields): 

1246 """The logistic sigmoid function, a.k.a. expit function. 

1247 

1248 Examples: 

1249 - in YAML 

1250 ```yaml 

1251 postprocessing: 

1252 - id: sigmoid 

1253 ``` 

1254 - in Python: 

1255 

1256 >>> postprocessing = [SigmoidDescr()] 

1257 """ 

1258 

1259 implemented_id: ClassVar[Literal["sigmoid"]] = "sigmoid" 

1260 if TYPE_CHECKING: 

1261 id: Literal["sigmoid"] = "sigmoid" 

1262 else: 

1263 id: Literal["sigmoid"] 

1264 

1265 @property 

1266 def kwargs(self) -> KwargsNode: 

1267 """empty kwargs""" 

1268 return KwargsNode() 

1269 

1270 

1271class SoftmaxKwargs(KwargsNode): 

1272 """key word arguments for [SoftmaxDescr][]""" 

1273 

1274 axis: Annotated[NonBatchAxisId, Field(examples=["channel"])] = AxisId("channel") 

1275 """The axis to apply the softmax function along. 

1276 Note: 

1277 Defaults to 'channel' axis 

1278 (which may not exist, in which case 

1279 a different axis id has to be specified). 

1280 """ 

1281 

1282 

1283class SoftmaxDescr(NodeWithExplicitlySetFields): 

1284 """The softmax function. 

1285 

1286 Examples: 

1287 - in YAML 

1288 ```yaml 

1289 postprocessing: 

1290 - id: softmax 

1291 kwargs: 

1292 axis: channel 

1293 ``` 

1294 - in Python: 

1295 

1296 >>> postprocessing = [SoftmaxDescr(kwargs=SoftmaxKwargs(axis=AxisId("channel")))] 

1297 """ 

1298 

1299 implemented_id: ClassVar[Literal["softmax"]] = "softmax" 

1300 if TYPE_CHECKING: 

1301 id: Literal["softmax"] = "softmax" 

1302 else: 

1303 id: Literal["softmax"] 

1304 

1305 kwargs: SoftmaxKwargs = Field(default_factory=SoftmaxKwargs.model_construct) 

1306 

1307 

1308class _StardistPostprocessingKwargsBase(KwargsNode): 

1309 """key word arguments for [StardistPostprocessingDescr][]""" 

1310 

1311 prob_threshold: float 

1312 """The probability threshold for object candidate selection.""" 

1313 

1314 nms_threshold: float 

1315 """The IoU threshold for non-maximum suppression.""" 

1316 

1317 n_rays: int 

1318 """Number of radial lines (rays) cast from the center of an object to its boundary.""" 

1319 

1320 

1321class StardistPostprocessingKwargs2D(_StardistPostprocessingKwargsBase): 

1322 grid: Tuple[int, int] 

1323 """Grid size of network predictions.""" 

1324 

1325 b: Union[int, Tuple[Tuple[int, int], Tuple[int, int]]] 

1326 """Border region in which object probability is set to zero.""" 

1327 

1328 

1329class StardistPostprocessingKwargs3D(_StardistPostprocessingKwargsBase): 

1330 grid: Tuple[int, int, int] 

1331 """Grid size of network predictions.""" 

1332 

1333 b: Union[int, Tuple[Tuple[int, int], Tuple[int, int], Tuple[int, int]]] 

1334 """Border region in which object probability is set to zero.""" 

1335 

1336 anisotropy: Tuple[float, float, float] 

1337 """Anisotropy factors for 3D star-convex polyhedra, i.e. the physical pixel size along each spatial axis.""" 

1338 

1339 overlap_label: Optional[int] = None 

1340 """Optional label to apply to any area of overlapping predicted objects.""" 

1341 

1342 

1343class StardistPostprocessingDescr(NodeWithExplicitlySetFields): 

1344 """Stardist postprocessing including non-maximum suppression and converting polygon representations to instance labels 

1345 

1346 as described in: 

1347 - Uwe Schmidt, Martin Weigert, Coleman Broaddus, and Gene Myers. 

1348 [*Cell Detection with Star-convex Polygons*](https://arxiv.org/abs/1806.03535). 

1349 International Conference on Medical Image Computing and Computer-Assisted Intervention (MICCAI), Granada, Spain, September 2018. 

1350 - Martin Weigert, Uwe Schmidt, Robert Haase, Ko Sugawara, and Gene Myers. 

1351 [*Star-convex Polyhedra for 3D Object Detection and Segmentation in Microscopy*](http://openaccess.thecvf.com/content_WACV_2020/papers/Weigert_Star-convex_Polyhedra_for_3D_Object_Detection_and_Segmentation_in_Microscopy_WACV_2020_paper.pdf). 

1352 The IEEE Winter Conference on Applications of Computer Vision (WACV), Snowmass Village, Colorado, March 2020. 

1353 

1354 Note: Only available if the `stardist` package is installed. 

1355 """ 

1356 

1357 implemented_id: ClassVar[Literal["stardist_postprocessing"]] = ( 

1358 "stardist_postprocessing" 

1359 ) 

1360 if TYPE_CHECKING: 

1361 id: Literal["stardist_postprocessing"] = "stardist_postprocessing" 

1362 else: 

1363 id: Literal["stardist_postprocessing"] 

1364 

1365 kwargs: Union[StardistPostprocessingKwargs2D, StardistPostprocessingKwargs3D] 

1366 

1367 

1368class CellposeFlowDynamicsKwargs(KwargsNode): 

1369 """key word arguments for [CellposeFlowDynamicsDescr][]""" 

1370 

1371 cellprob_threshold: float 

1372 flow_threshold: float 

1373 do_3D: bool 

1374 min_size: int = 15 

1375 """Minimum size of objects to keep, in pixels. Default is 15, which is the default in Cellpose. Set to 0 to disable filtering by size.""" 

1376 output_dtype: Literal["uint16", "uint32"] = "uint16" 

1377 

1378 

1379class CellposeFlowDynamicsDescr(NodeWithExplicitlySetFields): 

1380 """Cellpose flow dynamics postprocessing as described in: 

1381 - Carsen Stringer and Marius Pachitariu. [*Cellpose: a generalist algorithm for cellular segmentation*](https://www.nature.com/articles/s41592-020-01018-x). Nature Methods, 2021. 

1382 

1383 Note: Only available if the `cellpose` package is installed. 

1384 """ 

1385 

1386 implemented_id: ClassVar[Literal["cellpose_flow_dynamics"]] = ( 

1387 "cellpose_flow_dynamics" 

1388 ) 

1389 if TYPE_CHECKING: 

1390 id: Literal["cellpose_flow_dynamics"] = "cellpose_flow_dynamics" 

1391 else: 

1392 id: Literal["cellpose_flow_dynamics"] 

1393 

1394 kwargs: CellposeFlowDynamicsKwargs 

1395 

1396 

1397class CustomProcessingDescr(NodeWithExplicitlySetFields, FileDescr): 

1398 """Custom (post)processing op — source file shipped inline with the model. 

1399 

1400 Supports (post)processing that cannot be expressed by the built-in named 

1401 operations (watershed, connected components, etc.) 

1402 using a simple Python callable interface. 

1403 

1404 The op is implemented in a ``.py`` file packaged alongside the model weights. 

1405 Two styles are supported: 

1406 

1407 *Callable class* — kwargs go to ``__init__``, tensors arrive in ``__call__``: 

1408 

1409 .. code-block:: python 

1410 

1411 # my_postprocess.py 

1412 import numpy as np 

1413 

1414 class my_postprocess: 

1415 def __init__(self, threshold: float = 0.5) -> None: 

1416 self.threshold = threshold 

1417 def __call__(self, *arrays: np.ndarray) -> np.ndarray: 

1418 # arrays = model output tensors in rdf.yaml declaration order 

1419 return (arrays[0] > self.threshold).astype(np.uint8) 

1420 

1421 *Factory function* — alternative closure style, identical runtime behaviour: 

1422 

1423 .. code-block:: python 

1424 

1425 # my_postprocess.py 

1426 import numpy as np 

1427 

1428 def my_postprocess(threshold: float = 0.5): 

1429 def run(*arrays: np.ndarray) -> np.ndarray: 

1430 return (arrays[0] > threshold).astype(np.uint8) 

1431 return run 

1432 

1433 Reference it in ``rdf.yaml`` with the source file included in the package: 

1434 

1435 .. code-block:: yaml 

1436 

1437 postprocessing: 

1438 - id: custom 

1439 callable: my_postprocess # class or function name in source 

1440 source: my_postprocess.py # packaged alongside weights 

1441 sha256: <hash> # sha256 of the source file 

1442 kwargs: # forwarded to __init__ / factory 

1443 threshold: 0.5 

1444 

1445 **Security:** source files are SHA-256 verified before execution. 

1446 Execution requires explicit opt-in in bioimageio.core and curator 

1447 review before Zoo publication. 

1448 """ 

1449 

1450 implemented_id: ClassVar[Literal["custom"]] = "custom" 

1451 if TYPE_CHECKING: 

1452 id: Literal["custom"] = "custom" 

1453 else: 

1454 id: Literal["custom"] 

1455 

1456 callable: Annotated[ 

1457 str, 

1458 Field(examples=["my_postprocess_factory", "MyPostprocessClass"]), 

1459 ] 

1460 """Name of the callable class or factory function defined in ``source``. 

1461 

1462 At runtime: ``op = callable(**kwargs)``, then ``result = op(*output_tensors)`` 

1463 per image. Both a class with ``__call__`` and a factory function returning 

1464 a callable satisfy this protocol.""" 

1465 

1466 source: Annotated[FileSource, AfterValidator(wo_special_file_name)] 

1467 """Python source file (included when packaging the model).""" 

1468 

1469 kwargs: Dict[str, YamlValue] = Field( 

1470 default_factory=cast(Callable[[], Dict[str, YamlValue]], dict) 

1471 ) 

1472 """Keyword arguments forwarded to the callable (``__init__`` or factory).""" 

1473 

1474 @model_serializer(mode="wrap", when_used="unless-none") 

1475 def _serialize( 

1476 self, nxt: SerializerFunctionWrapHandler, info: SerializationInfo 

1477 ) -> Dict[str, YamlValue]: 

1478 return package_file_descr_serializer(self, nxt, info) 

1479 

1480 

1481class FixedZeroMeanUnitVarianceKwargs(KwargsNode): 

1482 """key word arguments for [FixedZeroMeanUnitVarianceDescr][]""" 

1483 

1484 mean: float 

1485 """The mean value to normalize with.""" 

1486 

1487 std: Annotated[float, Ge(1e-6)] 

1488 """The standard deviation value to normalize with.""" 

1489 

1490 

1491class FixedZeroMeanUnitVarianceAlongAxisKwargs(KwargsNode): 

1492 """key word arguments for [FixedZeroMeanUnitVarianceDescr][]""" 

1493 

1494 mean: NotEmpty[List[float]] 

1495 """The mean value(s) to normalize with.""" 

1496 

1497 std: NotEmpty[List[Annotated[float, Ge(1e-6)]]] 

1498 """The standard deviation value(s) to normalize with. 

1499 Size must match `mean` values.""" 

1500 

1501 axis: Annotated[NonBatchAxisId, Field(examples=["channel", "index"])] 

1502 """The axis of the mean/std values to normalize each entry along that dimension 

1503 separately.""" 

1504 

1505 @model_validator(mode="after") 

1506 def _mean_and_std_match(self) -> Self: 

1507 if len(self.mean) != len(self.std): 

1508 raise ValueError( 

1509 f"Size of `mean` ({len(self.mean)}) and `std` ({len(self.std)})" 

1510 + " must match." 

1511 ) 

1512 

1513 return self 

1514 

1515 

1516class FixedZeroMeanUnitVarianceDescr(NodeWithExplicitlySetFields): 

1517 """Subtract a given mean and divide by the standard deviation. 

1518 

1519 Normalize with fixed, precomputed values for 

1520 `FixedZeroMeanUnitVarianceKwargs.mean` and `FixedZeroMeanUnitVarianceKwargs.std` 

1521 Use `FixedZeroMeanUnitVarianceAlongAxisKwargs` for independent scaling along given 

1522 axes. 

1523 

1524 Examples: 

1525 1. scalar value for whole tensor 

1526 - in YAML 

1527 ```yaml 

1528 preprocessing: 

1529 - id: fixed_zero_mean_unit_variance 

1530 kwargs: 

1531 mean: 103.5 

1532 std: 13.7 

1533 ``` 

1534 - in Python 

1535 >>> preprocessing = [FixedZeroMeanUnitVarianceDescr( 

1536 ... kwargs=FixedZeroMeanUnitVarianceKwargs(mean=103.5, std=13.7) 

1537 ... )] 

1538 

1539 2. independently along an axis 

1540 - in YAML 

1541 ```yaml 

1542 preprocessing: 

1543 - id: fixed_zero_mean_unit_variance 

1544 kwargs: 

1545 axis: channel 

1546 mean: [101.5, 102.5, 103.5] 

1547 std: [11.7, 12.7, 13.7] 

1548 ``` 

1549 - in Python 

1550 >>> preprocessing = [FixedZeroMeanUnitVarianceDescr( 

1551 ... kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs( 

1552 ... axis=AxisId("channel"), 

1553 ... mean=[101.5, 102.5, 103.5], 

1554 ... std=[11.7, 12.7, 13.7], 

1555 ... ) 

1556 ... )] 

1557 """ 

1558 

1559 implemented_id: ClassVar[Literal["fixed_zero_mean_unit_variance"]] = ( 

1560 "fixed_zero_mean_unit_variance" 

1561 ) 

1562 if TYPE_CHECKING: 

1563 id: Literal["fixed_zero_mean_unit_variance"] = "fixed_zero_mean_unit_variance" 

1564 else: 

1565 id: Literal["fixed_zero_mean_unit_variance"] 

1566 

1567 kwargs: Union[ 

1568 FixedZeroMeanUnitVarianceKwargs, FixedZeroMeanUnitVarianceAlongAxisKwargs 

1569 ] 

1570 

1571 

1572class ZeroMeanUnitVarianceKwargs(KwargsNode): 

1573 """key word arguments for [ZeroMeanUnitVarianceDescr][]""" 

1574 

1575 axes: Annotated[ 

1576 Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")]) 

1577 ] = None 

1578 """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std. 

1579 For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x') 

1580 resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`. 

1581 To normalize each sample independently leave out the 'batch' axis. 

1582 Default: Scale all axes jointly.""" 

1583 

1584 eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6 

1585 """epsilon for numeric stability: `out = (tensor - mean) / (std + eps)`.""" 

1586 

1587 

1588class ZeroMeanUnitVarianceDescr(NodeWithExplicitlySetFields): 

1589 """Subtract mean and divide by variance. 

1590 

1591 Examples: 

1592 Subtract tensor mean and variance 

1593 - in YAML 

1594 ```yaml 

1595 preprocessing: 

1596 - id: zero_mean_unit_variance 

1597 ``` 

1598 - in Python 

1599 >>> preprocessing = [ZeroMeanUnitVarianceDescr()] 

1600 """ 

1601 

1602 implemented_id: ClassVar[Literal["zero_mean_unit_variance"]] = ( 

1603 "zero_mean_unit_variance" 

1604 ) 

1605 if TYPE_CHECKING: 

1606 id: Literal["zero_mean_unit_variance"] = "zero_mean_unit_variance" 

1607 else: 

1608 id: Literal["zero_mean_unit_variance"] 

1609 

1610 kwargs: ZeroMeanUnitVarianceKwargs = Field( 

1611 default_factory=ZeroMeanUnitVarianceKwargs.model_construct 

1612 ) 

1613 

1614 

1615class ScaleRangeKwargs(KwargsNode): 

1616 """key word arguments for [ScaleRangeDescr][] 

1617 

1618 For `min_percentile`=0.0 (the default) and `max_percentile`=100 (the default) 

1619 this processing step normalizes data to the [0, 1] intervall. 

1620 For other percentiles the normalized values will partially be outside the [0, 1] 

1621 intervall. Use `ScaleRange` followed by `ClipDescr` if you want to limit the 

1622 normalized values to a range. 

1623 """ 

1624 

1625 axes: Annotated[ 

1626 Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")]) 

1627 ] = None 

1628 """The subset of axes to normalize jointly, i.e. axes to reduce to compute the min/max percentile value. 

1629 For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x') 

1630 resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`. 

1631 To normalize samples independently, leave out the "batch" axis. 

1632 Default: Scale all axes jointly.""" 

1633 

1634 min_percentile: Annotated[float, Interval(ge=0, lt=100)] = 0.0 

1635 """The lower percentile used to determine the value to align with zero.""" 

1636 

1637 max_percentile: Annotated[float, Interval(gt=1, le=100)] = 100.0 

1638 """The upper percentile used to determine the value to align with one. 

1639 Has to be bigger than `min_percentile`. 

1640 The range is 1 to 100 instead of 0 to 100 to avoid mistakenly 

1641 accepting percentiles specified in the range 0.0 to 1.0.""" 

1642 

1643 eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6 

1644 """Epsilon for numeric stability. 

1645 `out = (tensor - v_lower) / (v_upper - v_lower + eps)`; 

1646 with `v_lower,v_upper` values at the respective percentiles.""" 

1647 

1648 reference_tensor: Optional[TensorId] = None 

1649 """ID of the unprocessed input tensor to compute the percentiles from. 

1650 Default: The tensor itself. 

1651 """ 

1652 

1653 @field_validator("max_percentile", mode="after") 

1654 @classmethod 

1655 def min_smaller_max(cls, value: float, info: ValidationInfo) -> float: 

1656 if (min_p := info.data["min_percentile"]) >= value: 

1657 raise ValueError(f"min_percentile {min_p} >= max_percentile {value}") 

1658 

1659 return value 

1660 

1661 

1662class ScaleRangeDescr(NodeWithExplicitlySetFields): 

1663 """Scale with percentiles. 

1664 

1665 Examples: 

1666 1. Scale linearly to map 5th percentile to 0 and 99.8th percentile to 1.0 

1667 - in YAML 

1668 ```yaml 

1669 preprocessing: 

1670 - id: scale_range 

1671 kwargs: 

1672 axes: ['y', 'x'] 

1673 max_percentile: 99.8 

1674 min_percentile: 5.0 

1675 ``` 

1676 - in Python 

1677 

1678 >>> preprocessing = [ 

1679 ... ScaleRangeDescr( 

1680 ... kwargs=ScaleRangeKwargs( 

1681 ... axes= (AxisId('y'), AxisId('x')), 

1682 ... max_percentile= 99.8, 

1683 ... min_percentile= 5.0, 

1684 ... ) 

1685 ... ) 

1686 ... ] 

1687 

1688 2. Combine the above scaling with additional clipping to clip values outside the range given by the percentiles. 

1689 - in YAML 

1690 ```yaml 

1691 preprocessing: 

1692 - id: scale_range 

1693 kwargs: 

1694 axes: ['y', 'x'] 

1695 max_percentile: 99.8 

1696 min_percentile: 5.0 

1697 - id: clip 

1698 kwargs: 

1699 min: 0.0 

1700 max: 1.0 

1701 ``` 

1702 - in Python 

1703 

1704 >>> preprocessing = [ 

1705 ... ScaleRangeDescr( 

1706 ... kwargs=ScaleRangeKwargs( 

1707 ... axes= (AxisId('y'), AxisId('x')), 

1708 ... max_percentile= 99.8, 

1709 ... min_percentile= 5.0, 

1710 ... ) 

1711 ... ), 

1712 ... ClipDescr( 

1713 ... kwargs=ClipKwargs( 

1714 ... min=0.0, 

1715 ... max=1.0, 

1716 ... ) 

1717 ... ), 

1718 ... ] 

1719 

1720 """ 

1721 

1722 implemented_id: ClassVar[Literal["scale_range"]] = "scale_range" 

1723 if TYPE_CHECKING: 

1724 id: Literal["scale_range"] = "scale_range" 

1725 else: 

1726 id: Literal["scale_range"] 

1727 kwargs: ScaleRangeKwargs = Field(default_factory=ScaleRangeKwargs.model_construct) 

1728 

1729 

1730class ScaleMeanVarianceKwargs(KwargsNode): 

1731 """key word arguments for [ScaleMeanVarianceKwargs][]""" 

1732 

1733 reference_tensor: TensorId 

1734 """ID of unprocessed input tensor to match.""" 

1735 

1736 axes: Annotated[ 

1737 Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")]) 

1738 ] = None 

1739 """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std. 

1740 For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x') 

1741 resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`. 

1742 To normalize samples independently, leave out the 'batch' axis. 

1743 Default: Scale all axes jointly.""" 

1744 

1745 eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6 

1746 """Epsilon for numeric stability: 

1747 `out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.`""" 

1748 

1749 

1750class ScaleMeanVarianceDescr(NodeWithExplicitlySetFields): 

1751 """Scale a tensor's data distribution to match another tensor's mean/std. 

1752 `out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.` 

1753 """ 

1754 

1755 implemented_id: ClassVar[Literal["scale_mean_variance"]] = "scale_mean_variance" 

1756 if TYPE_CHECKING: 

1757 id: Literal["scale_mean_variance"] = "scale_mean_variance" 

1758 else: 

1759 id: Literal["scale_mean_variance"] 

1760 kwargs: ScaleMeanVarianceKwargs 

1761 

1762 

1763PreprocessingDescr = Annotated[ 

1764 Union[ 

1765 BinarizeDescr, 

1766 ClipDescr, 

1767 EnsureDtypeDescr, 

1768 FixedZeroMeanUnitVarianceDescr, 

1769 ScaleLinearDescr, 

1770 ScaleRangeDescr, 

1771 SigmoidDescr, 

1772 SoftmaxDescr, 

1773 ZeroMeanUnitVarianceDescr, 

1774 ], 

1775 Discriminator("id"), 

1776] 

1777PostprocessingDescr = Annotated[ 

1778 Union[ 

1779 BinarizeDescr, 

1780 CellposeFlowDynamicsDescr, 

1781 ClipDescr, 

1782 CustomProcessingDescr, 

1783 EnsureDtypeDescr, 

1784 FixedZeroMeanUnitVarianceDescr, 

1785 ScaleLinearDescr, 

1786 ScaleMeanVarianceDescr, 

1787 ScaleRangeDescr, 

1788 SigmoidDescr, 

1789 SoftmaxDescr, 

1790 StardistPostprocessingDescr, 

1791 ZeroMeanUnitVarianceDescr, 

1792 ], 

1793 Discriminator("id"), 

1794] 

1795 

1796IO_AxisT = TypeVar("IO_AxisT", InputAxis, OutputAxis) 

1797 

1798 

1799class TensorDescrBase(Node, Generic[IO_AxisT]): 

1800 id: TensorId 

1801 """Tensor id. No duplicates are allowed.""" 

1802 

1803 description: Annotated[str, MaxLen(128)] = "" 

1804 """free text description""" 

1805 

1806 axes: NotEmpty[Sequence[IO_AxisT]] 

1807 """tensor axes""" 

1808 

1809 @property 

1810 def shape(self): 

1811 return tuple(a.size for a in self.axes) 

1812 

1813 @field_validator("axes", mode="after", check_fields=False) 

1814 @classmethod 

1815 def _validate_axes(cls, axes: Sequence[AnyAxis]) -> Sequence[AnyAxis]: 

1816 batch_axes = [a for a in axes if a.type == "batch"] 

1817 if len(batch_axes) > 1: 

1818 raise ValueError( 

1819 f"Only one batch axis (per tensor) allowed, but got {batch_axes}" 

1820 ) 

1821 

1822 seen_ids: Set[AxisId] = set() 

1823 duplicate_axes_ids: Set[AxisId] = set() 

1824 for a in axes: 

1825 (duplicate_axes_ids if a.id in seen_ids else seen_ids).add(a.id) 

1826 

1827 if duplicate_axes_ids: 

1828 raise ValueError(f"Duplicate axis ids: {duplicate_axes_ids}") 

1829 

1830 return axes 

1831 

1832 test_tensor: FAIR[Optional[FileDescr_package]] = None 

1833 """An example tensor to use for testing. 

1834 Using the model with the test input tensors is expected to yield the test output tensors. 

1835 Each test tensor has be a an ndarray in the 

1836 [numpy.lib file format](https://numpy.org/doc/stable/reference/generated/numpy.lib.format.html#module-numpy.lib.format). 

1837 The file extension must be '.npy'.""" 

1838 

1839 sample_tensor: FAIR[Optional[FileDescr_package]] = None 

1840 """A sample tensor to illustrate a possible input/output for the model, 

1841 The sample image primarily serves to inform a human user about an example use case 

1842 and is typically stored as .hdf5, .png or .tiff. 

1843 It has to be readable by the [imageio library](https://imageio.readthedocs.io/en/stable/formats/index.html#supported-formats) 

1844 (numpy's `.npy` format is not supported). 

1845 The image dimensionality has to match the number of axes specified in this tensor description. 

1846 """ 

1847 

1848 @model_validator(mode="after") 

1849 def _validate_sample_tensor(self) -> Self: 

1850 if self.sample_tensor is None or not get_validation_context().perform_io_checks: 

1851 return self 

1852 

1853 reader = get_reader(self.sample_tensor.source, sha256=self.sample_tensor.sha256) 

1854 tensor: NDArray[Any] = imread( # pyright: ignore[reportUnknownVariableType] 

1855 reader.read(), 

1856 extension=PurePosixPath(reader.original_file_name).suffix, 

1857 ) 

1858 n_dims = len(tensor.squeeze().shape) 

1859 n_dims_min = n_dims_max = len(self.axes) 

1860 

1861 for a in self.axes: 

1862 if isinstance(a, BatchAxis): 

1863 n_dims_min -= 1 

1864 elif isinstance(a.size, int): 

1865 if a.size == 1: 

1866 n_dims_min -= 1 

1867 elif isinstance(a.size, (ParameterizedSize, DataDependentSize)): 

1868 if a.size.min == 1: 

1869 n_dims_min -= 1 

1870 elif isinstance(a.size, SizeReference): 

1871 if a.size.offset < 2: 

1872 # size reference may result in singleton axis 

1873 n_dims_min -= 1 

1874 else: 

1875 assert_never(a.size) 

1876 

1877 n_dims_min = max(0, n_dims_min) 

1878 if n_dims < n_dims_min or n_dims > n_dims_max: 

1879 raise ValueError( 

1880 f"Expected sample tensor to have {n_dims_min} to" 

1881 + f" {n_dims_max} dimensions, but found {n_dims} (shape: {tensor.shape})." 

1882 ) 

1883 

1884 return self 

1885 

1886 data: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]] = ( 

1887 IntervalOrRatioDataDescr() 

1888 ) 

1889 """Description of the tensor's data values, optionally per channel. 

1890 If specified per channel, the data `type` needs to match across channels.""" 

1891 

1892 @property 

1893 def dtype( 

1894 self, 

1895 ) -> Literal[ 

1896 "float32", 

1897 "float64", 

1898 "uint8", 

1899 "int8", 

1900 "uint16", 

1901 "int16", 

1902 "uint32", 

1903 "int32", 

1904 "uint64", 

1905 "int64", 

1906 "bool", 

1907 ]: 

1908 """dtype as specified under `data.type` or `data[i].type`""" 

1909 if isinstance(self.data, collections.abc.Sequence): 

1910 return self.data[0].type 

1911 else: 

1912 return self.data.type 

1913 

1914 @field_validator("data", mode="after") 

1915 @classmethod 

1916 def _check_data_type_across_channels( 

1917 cls, value: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]] 

1918 ) -> Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]]: 

1919 if not isinstance(value, list): 

1920 return value 

1921 

1922 dtypes = {t.type for t in value} 

1923 if len(dtypes) > 1: 

1924 raise ValueError( 

1925 "Tensor data descriptions per channel need to agree in their data" 

1926 + f" `type`, but found {dtypes}." 

1927 ) 

1928 

1929 return value 

1930 

1931 @model_validator(mode="after") 

1932 def _check_data_matches_channelaxis(self) -> Self: 

1933 if not isinstance(self.data, (list, tuple)): 

1934 return self 

1935 

1936 for a in self.axes: 

1937 if isinstance(a, ChannelAxis): 

1938 size = a.size 

1939 assert isinstance(size, int) 

1940 break 

1941 else: 

1942 return self 

1943 

1944 if len(self.data) != size: 

1945 raise ValueError( 

1946 f"Got tensor data descriptions for {len(self.data)} channels, but" 

1947 + f" '{a.id}' axis has size {size}." 

1948 ) 

1949 

1950 return self 

1951 

1952 def get_axis_sizes_for_array(self, array: NDArray[Any]) -> Dict[AxisId, int]: 

1953 if len(array.shape) != len(self.axes): 

1954 raise ValueError( 

1955 f"Dimension mismatch: array shape {array.shape} (#{len(array.shape)})" 

1956 + f" incompatible with {len(self.axes)} axes." 

1957 ) 

1958 return {a.id: array.shape[i] for i, a in enumerate(self.axes)} 

1959 

1960 

1961class ConstantPadding(Node): 

1962 mode: Literal["constant"] = "constant" 

1963 value: Union[int, float] = 0 

1964 

1965 

1966class EdgePadding(Node): 

1967 mode: Literal["edge"] = "edge" 

1968 

1969 

1970class ReflectPadding(Node): 

1971 mode: Literal["reflect"] = "reflect" 

1972 

1973 

1974class SymmetricPadding(Node): 

1975 mode: Literal["symmetric"] = "symmetric" 

1976 

1977 

1978Padding = Union[ConstantPadding, EdgePadding, ReflectPadding, SymmetricPadding] 

1979 

1980 

class InputTensorDescr(TensorDescrBase[InputAxis]):
    id: TensorId = TensorId("input")
    """Input tensor id.
    No duplicates are allowed across all inputs and outputs."""

    optional: bool = False
    """indicates that this tensor may be `None`"""

    pad: Optional[Padding] = None
    """Explicitly specify how to pad this input tensor.

    Use `axes[i].pad` to specify padding width.

    Note:
        Non-blockwise sample prediction only applies padding for axes with a `pad` specification.
    """

    preprocessing: List[PreprocessingDescr] = Field(
        default_factory=cast(Callable[[], List[PreprocessingDescr]], list)
    )
    """Description of how this input should be preprocessed.

    notes:
    - If preprocessing does not start with an 'ensure_dtype' entry, it is added
      to ensure an input tensor's data type matches the input tensor's data description.
    - If preprocessing does not end with an 'ensure_dtype' or 'binarize' entry, an
      'ensure_dtype' step is added to ensure preprocessing steps are not unintentionally
      changing the data type.
    """

    @model_validator(mode="after")
    def _validate_preprocessing_kwargs(self) -> Self:
        # Check every step's `kwargs.axes` (when given) against this tensor's
        # axis ids, then wrap `preprocessing` in 'ensure_dtype' steps as needed.
        known_axis_ids = [axis.id for axis in self.axes]
        for step in self.preprocessing:
            step_axes: Optional[Sequence[Any]] = step.kwargs.get("axes")
            if step_axes is None:
                continue

            if not isinstance(step_axes, collections.abc.Sequence):
                raise ValueError(
                    f"Expected `preprocessing.i.kwargs.axes` to be a sequence, but got {type(step_axes)}"
                )

            if not all(ax in known_axis_ids for ax in step_axes):
                raise ValueError(
                    "`preprocessing.i.kwargs.axes` needs to be subset of axes ids"
                )

        data_descr = self.data
        dtype = (
            data_descr.type
            if isinstance(
                data_descr, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)
            )
            else data_descr[0].type  # per-channel descriptions: use the first one
        )

        steps = self.preprocessing  # same list object; mutated in place
        # leading cast: input data is brought to the described dtype first
        if not (steps and isinstance(steps[0], EnsureDtypeDescr)):
            steps.insert(0, EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype)))

        # trailing cast unless the last step already fixes the dtype
        if not isinstance(steps[-1], (EnsureDtypeDescr, BinarizeDescr)):
            steps.append(EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype)))

        return self

2049 

2050 

def convert_axes(
    axes: str,
    *,
    shape: Union[
        Sequence[int], _ParameterizedInputShape_v0_4, _ImplicitOutputShape_v0_4
    ],
    tensor_type: Literal["input", "output"],
    halo: Optional[Sequence[int]],
    size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]],
):
    """Convert v0.4 axes letters and shape to a list of v0.5 axis descriptions.

    Args:
        axes: v0.4 axes letters, one per tensor dimension (e.g. "bcyx").
        shape: v0.4 shape: explicit sizes, a parameterized input shape, or an
            implicit output shape defined relative to a reference tensor.
        tensor_type: Whether to create input or output axis descriptions.
        halo: Optional v0.4 halo with one entry per dimension.
        size_refs: Known axis sizes of v0.4 tensors by tensor name and axis
            letter; used to fold the scale of channel/index axes of implicit
            output shapes into a size offset.

    Returns:
        The converted axis descriptions in the order of `axes`.

    Raises:
        NotImplementedError: For an output time/space axis with a halo and a
            fixed size.
        ValueError: For an axis letter that does not map to a known axis type.
    """
    ret: List[AnyAxis] = []
    for i, a in enumerate(axes):
        axis_type = _AXIS_TYPE_MAP.get(a, a)
        if axis_type == "batch":
            ret.append(BatchAxis())
            continue

        scale = 1.0
        if isinstance(shape, _ParameterizedInputShape_v0_4):
            if shape.step[i] == 0:
                size = shape.min[i]
            else:
                size = ParameterizedSize(min=shape.min[i], step=shape.step[i])
        elif isinstance(shape, _ImplicitOutputShape_v0_4):
            ref_t = str(shape.reference_tensor)
            if ref_t.count(".") == 1:
                # "<tensor>.<axis>" explicitly names the referenced axis
                t_id, orig_a_id = ref_t.split(".")
            else:
                t_id = ref_t
                orig_a_id = a

            # Fall back to the *referenced* axis letter (not the letter of the
            # axis being converted) when it has no dedicated v0.5 axis id.
            a_id = _AXIS_ID_MAP.get(orig_a_id, orig_a_id)
            if not (orig_scale := shape.scale[i]):
                # old way to insert a new axis dimension
                size = int(2 * shape.offset[i])
            else:
                scale = 1 / orig_scale
                if axis_type in ("channel", "index"):
                    # these axes no longer have a scale
                    offset_from_scale = orig_scale * size_refs.get(
                        _TensorName_v0_4(t_id), {}
                    ).get(orig_a_id, 0)
                else:
                    offset_from_scale = 0
                size = SizeReference(
                    tensor_id=TensorId(t_id),
                    axis_id=AxisId(a_id),
                    offset=int(offset_from_scale + 2 * shape.offset[i]),
                )
        else:
            size = shape[i]

        if axis_type == "time":
            if tensor_type == "input":
                ret.append(TimeInputAxis(size=size, scale=scale))
            else:
                assert not isinstance(size, ParameterizedSize)
                # same handling as space axes: a zero halo means "no halo"
                if halo is None or halo[i] == 0:
                    ret.append(TimeOutputAxis(size=size, scale=scale))
                elif isinstance(size, int):
                    raise NotImplementedError(
                        f"output axis with halo and fixed size (here {size}) not allowed"
                    )
                else:
                    ret.append(
                        TimeOutputAxisWithHalo(size=size, scale=scale, halo=halo[i])
                    )

        elif axis_type == "index":
            if tensor_type == "input":
                ret.append(IndexInputAxis(size=size))
            else:
                if isinstance(size, ParameterizedSize):
                    size = DataDependentSize(min=size.min)

                ret.append(IndexOutputAxis(size=size))
        elif axis_type == "channel":
            assert not isinstance(size, ParameterizedSize)
            if isinstance(size, SizeReference):
                warnings.warn(
                    "Conversion of channel size from an implicit output shape may be"
                    + " wrong"
                )
                ret.append(
                    ChannelAxis(
                        channel_names=[f"channel{i}" for i in range(size.offset)]
                    )
                )
            else:
                ret.append(
                    ChannelAxis(channel_names=[f"channel{i}" for i in range(size)])
                )
        elif axis_type == "space":
            if tensor_type == "input":
                ret.append(SpaceInputAxis(id=AxisId(a), size=size, scale=scale))
            else:
                assert not isinstance(size, ParameterizedSize)
                if halo is None or halo[i] == 0:
                    ret.append(SpaceOutputAxis(id=AxisId(a), size=size, scale=scale))
                elif isinstance(size, int):
                    raise NotImplementedError(
                        f"output axis with halo and fixed size (here {size}) not allowed"
                    )
                else:
                    ret.append(
                        SpaceOutputAxisWithHalo(
                            id=AxisId(a), size=size, scale=scale, halo=halo[i]
                        )
                    )
        else:
            # previously such axes were silently dropped, yielding fewer axes
            # than tensor dimensions
            raise ValueError(
                f"Unknown axis '{a}' (axis type '{axis_type}') in axes '{axes}'"
            )

    return ret

2159 

2160 

2161def _axes_letters_to_ids( 

2162 axes: Optional[str], 

2163) -> Optional[List[AxisId]]: 

2164 if axes is None: 

2165 return None 

2166 

2167 return [AxisId(a) for a in axes] 

2168 

2169 

2170def _get_complement_v04_axis( 

2171 tensor_axes: Sequence[str], axes: Optional[Sequence[str]] 

2172) -> Optional[AxisId]: 

2173 if axes is None: 

2174 return None 

2175 

2176 non_complement_axes = set(axes) | {"b"} 

2177 complement_axes = [a for a in tensor_axes if a not in non_complement_axes] 

2178 if len(complement_axes) > 1: 

2179 raise ValueError( 

2180 f"Expected none or a single complement axis, but axes '{axes}' " 

2181 + f"for tensor dims '{tensor_axes}' leave '{complement_axes}'." 

2182 ) 

2183 

2184 return None if not complement_axes else AxisId(complement_axes[0]) 

2185 

2186 

def _convert_proc(
    p: Union[_PreprocessingDescr_v0_4, _PostprocessingDescr_v0_4],
    tensor_axes: Sequence[str],
) -> Union[PreprocessingDescr, PostprocessingDescr]:
    """Convert a v0.4 pre-/postprocessing description to its v0.5 counterpart.

    Args:
        p: The v0.4 processing description.
        tensor_axes: v0.4 axes letters of the tensor `p` is applied to. These are
            used to find the single axis (if any) that is not listed in the v0.4
            `kwargs.axes`, see `_get_complement_v04_axis`.

    Returns:
        The equivalent v0.5 processing description.

    Raises:
        ValueError: If a fixed zero-mean-unit-variance normalization without
            such a complement axis has list-valued `mean` or `std`, or if more
            than one complement axis remains.
    """
    if isinstance(p, _BinarizeDescr_v0_4):
        return BinarizeDescr(kwargs=BinarizeKwargs(threshold=p.kwargs.threshold))
    elif isinstance(p, _ClipDescr_v0_4):
        return ClipDescr(kwargs=ClipKwargs(min=p.kwargs.min, max=p.kwargs.max))
    elif isinstance(p, _SigmoidDescr_v0_4):
        return SigmoidDescr()
    elif isinstance(p, _ScaleLinearDescr_v0_4):
        # `_get_complement_v04_axis` already returns None for `axes=None`.
        axis = _get_complement_v04_axis(tensor_axes, p.kwargs.axes)
        if axis is None:
            assert not isinstance(p.kwargs.gain, list)
            assert not isinstance(p.kwargs.offset, list)
            kwargs = ScaleLinearKwargs(gain=p.kwargs.gain, offset=p.kwargs.offset)
        else:
            # gain/offset are given per entry along the complement axis
            kwargs = ScaleLinearAlongAxisKwargs(
                axis=axis, gain=p.kwargs.gain, offset=p.kwargs.offset
            )
        return ScaleLinearDescr(kwargs=kwargs)
    elif isinstance(p, _ScaleMeanVarianceDescr_v0_4):
        return ScaleMeanVarianceDescr(
            kwargs=ScaleMeanVarianceKwargs(
                axes=_axes_letters_to_ids(p.kwargs.axes),
                reference_tensor=TensorId(str(p.kwargs.reference_tensor)),
                eps=p.kwargs.eps,
            )
        )
    elif isinstance(p, _ZeroMeanUnitVarianceDescr_v0_4):
        if p.kwargs.mode == "fixed":
            mean = p.kwargs.mean
            std = p.kwargs.std
            assert mean is not None
            assert std is not None

            axis = _get_complement_v04_axis(tensor_axes, p.kwargs.axes)

            if axis is None:
                if isinstance(mean, list):
                    raise ValueError("Expected single float value for mean, not <list>")
                if isinstance(std, list):
                    raise ValueError("Expected single float value for std, not <list>")
                # NOTE(review): `model_construct` skips pydantic validation here,
                # unlike the along-axis branch below — confirm this is intended.
                return FixedZeroMeanUnitVarianceDescr(
                    kwargs=FixedZeroMeanUnitVarianceKwargs.model_construct(
                        mean=mean,
                        std=std,
                    )
                )
            else:
                # along-axis kwargs expect lists; promote scalars
                if not isinstance(mean, list):
                    mean = [float(mean)]
                if not isinstance(std, list):
                    std = [float(std)]

                return FixedZeroMeanUnitVarianceDescr(
                    kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs(
                        axis=axis, mean=mean, std=std
                    )
                )

        else:
            axes = _axes_letters_to_ids(p.kwargs.axes) or []
            if p.kwargs.mode == "per_dataset":
                # statistics over the whole dataset also reduce the batch axis
                axes = [AxisId("batch")] + axes
            if not axes:
                axes = None
            return ZeroMeanUnitVarianceDescr(
                kwargs=ZeroMeanUnitVarianceKwargs(axes=axes, eps=p.kwargs.eps)
            )

    elif isinstance(p, _ScaleRangeDescr_v0_4):
        return ScaleRangeDescr(
            kwargs=ScaleRangeKwargs(
                axes=_axes_letters_to_ids(p.kwargs.axes),
                min_percentile=p.kwargs.min_percentile,
                max_percentile=p.kwargs.max_percentile,
                eps=p.kwargs.eps,
            )
        )
    else:
        assert_never(p)

2274 

2275 

class _InputTensorConv(
    Converter[
        _InputTensorDescr_v0_4,
        InputTensorDescr,
        FileSource,
        Optional[FileSource],
        Mapping[_TensorName_v0_4, Mapping[str, int]],
    ]
):
    # Converts a v0.4 input tensor description (plus test/sample tensor sources
    # and known v0.4 axis sizes) into a v0.5 `InputTensorDescr`.
    def _convert(
        self,
        src: _InputTensorDescr_v0_4,
        tgt: "type[InputTensorDescr] | type[dict[str, Any]]",
        test_tensor: FileSource,
        sample_tensor: Optional[FileSource],
        size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]],
    ) -> "InputTensorDescr | dict[str, Any]":
        input_axes: List[InputAxis] = convert_axes(  # pyright: ignore[reportAssignmentType]
            src.axes,
            shape=src.shape,
            tensor_type="input",
            halo=None,
            size_refs=size_refs,
        )

        converted_steps: List[PreprocessingDescr] = []
        for v04_step in src.preprocessing:
            step = _convert_proc(v04_step, src.axes)
            # narrow the union returned by `_convert_proc` to preprocessing kinds
            assert not isinstance(
                step,
                (
                    CellposeFlowDynamicsDescr,
                    CustomProcessingDescr,
                    ScaleMeanVarianceDescr,
                    StardistPostprocessingDescr,
                ),
            )
            converted_steps.append(step)

        # always finish with an explicit cast to float32
        converted_steps.append(
            EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="float32"))
        )

        return tgt(
            axes=input_axes,
            id=TensorId(str(src.name)),
            test_tensor=FileDescr(source=test_tensor),
            sample_tensor=(
                FileDescr(source=sample_tensor) if sample_tensor is not None else None
            ),
            data=dict(type=src.data_type),  # pyright: ignore[reportArgumentType]
            preprocessing=converted_steps,
        )


_input_tensor_conv = _InputTensorConv(_InputTensorDescr_v0_4, InputTensorDescr)

2329 

2330 

class OutputTensorDescr(TensorDescrBase[OutputAxis]):
    id: TensorId = TensorId("output")
    """Output tensor id.
    No duplicates are allowed across all inputs and outputs."""

    postprocessing: List[PostprocessingDescr] = Field(
        default_factory=cast(Callable[[], List[PostprocessingDescr]], list)
    )
    """Description of how this output should be postprocessed.

    note: `postprocessing` always ends with an 'ensure_dtype' or 'binarize' operation.
    If it does not, an 'ensure_dtype' operation casting to this tensor's `data.type` is appended.
    """

    @model_validator(mode="after")
    def _validate_postprocessing_kwargs(self) -> Self:
        # 1. `kwargs.axes` of every postprocessing step (if given) has to be a
        #    sequence of this tensor's axis ids.
        # 2. Make sure postprocessing ends with 'ensure_dtype' or 'binarize'.
        axes_ids = [a.id for a in self.axes]
        for p in self.postprocessing:
            kwargs_axes = p.kwargs.get("axes")
            if kwargs_axes is None:
                continue

            if not isinstance(kwargs_axes, collections.abc.Sequence):
                raise ValueError(
                    f"expected `axes` sequence, but got {type(kwargs_axes)}"
                )

            kwargs_axes_seq: Sequence[Any] = cast(Sequence[Any], kwargs_axes)
            if any(a not in axes_ids for a in kwargs_axes_seq):
                raise ValueError("`kwargs.axes` needs to be subset of axes ids")

        if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)):
            dtype = self.data.type
        else:
            # per-channel data descriptions: use the first one's type
            dtype = self.data[0].type

        # ensure `postprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr`
        if not self.postprocessing or not isinstance(
            self.postprocessing[-1], (EnsureDtypeDescr, BinarizeDescr)
        ):
            self.postprocessing.append(
                EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype))
            )
        return self

2375 

2376 

class _OutputTensorConv(
    Converter[
        _OutputTensorDescr_v0_4,
        OutputTensorDescr,
        FileSource,
        Optional[FileSource],
        Mapping[_TensorName_v0_4, Mapping[str, int]],
    ]
):
    # Converts a v0.4 output tensor description (plus test/sample tensor sources
    # and known v0.4 axis sizes) into a v0.5 `OutputTensorDescr`.
    def _convert(
        self,
        src: _OutputTensorDescr_v0_4,
        tgt: "type[OutputTensorDescr] | type[dict[str, Any]]",
        test_tensor: FileSource,
        sample_tensor: Optional[FileSource],
        size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]],
    ) -> "OutputTensorDescr | dict[str, Any]":
        # TODO: split convert_axes into convert_output_axes and convert_input_axes
        output_axes: List[OutputAxis] = convert_axes(  # pyright: ignore[reportAssignmentType]
            src.axes,
            shape=src.shape,
            tensor_type="output",
            halo=src.halo,
            size_refs=size_refs,
        )

        data: Dict[str, Any] = {"type": src.data_type}
        if data["type"] == "bool":
            # list the two possible values explicitly for boolean data
            data["values"] = [False, True]

        return tgt(
            axes=output_axes,
            id=TensorId(str(src.name)),
            test_tensor=FileDescr(source=test_tensor),
            sample_tensor=(
                FileDescr(source=sample_tensor) if sample_tensor is not None else None
            ),
            data=data,  # pyright: ignore[reportArgumentType]
            postprocessing=[
                _convert_proc(step, src.axes) for step in src.postprocessing
            ],
        )


_output_tensor_conv = _OutputTensorConv(_OutputTensorDescr_v0_4, OutputTensorDescr)


TensorDescr = Union[InputTensorDescr, OutputTensorDescr]

2422 

2423 

def get_halos(
    tensors: Mapping[TensorId, TensorDescr],
    /,
) -> Dict[TensorId, Dict[AxisId, Tuple[int, int]]]:
    """Get all input and output halos from tensor descriptions.

    Halos are derived from output axes with a halo (`WithHalo`). Each such axis
    also determines the pad width of the axis its size references.

    Note:
        - Input halos are to be padded
        - Output halos are to be cropped
        - If the referenced tensor or axis is not part of `tensors`, only the
          output halo is recorded, because the pad width cannot be derived
          without the reference axis' scale.

    Returns:
        Mapping of tensor id to axis id to (leading, trailing) halo width.
    """
    halos: Dict[TensorId, Dict[AxisId, Tuple[int, int]]] = {}
    for descr in tensors.values():
        if isinstance(descr, InputTensorDescr):
            continue
        for axis in descr.axes:
            if not isinstance(axis, WithHalo):
                continue

            # set output halo (to be cropped)
            halos.setdefault(descr.id, {})[axis.id] = (axis.halo, axis.halo)

            # look up the referenced axis; previously a missing tensor raised a
            # bare KeyError and a missing axis a bare StopIteration
            ref_descr = tensors.get(axis.size.tensor_id)
            if ref_descr is None:
                continue

            ref_axis = next(
                (a for a in ref_descr.axes if a.id == axis.size.axis_id), None
            )
            if ref_axis is None:
                continue

            # set input halo (to be padded)
            pad_width = int(axis.halo / axis.scale * ref_axis.scale)
            halos.setdefault(axis.size.tensor_id, {})[axis.size.axis_id] = (
                pad_width,
                pad_width,
            )

    return halos

2458 

2459 

def validate_tensors(
    tensors: Mapping[TensorId, Tuple[TensorDescr, Optional[NDArray[Any]]]],
    tensor_origin: Literal[
        "source", "test_tensor"
    ] = "source",  # for more precise error messages
    *,
    pad_inputs: Union[bool, Literal["allow"]] = True,
    crop_outputs: Union[bool, Literal["allow"]] = True,
):
    """Validate all inputs (and optionally output tensors) against their tensor descriptions.

    Args:
        tensors: Mapping of tensor id to a tuple of tensor description and optional numpy array.
            Entries without an array are not validated themselves.
        tensor_origin: String to use in error messages to indicate the origin of the tensors being validated.
        pad_inputs: Whether to apply/allow padding of inputs before shape comparison.
        crop_outputs: Whether to apply/allow cropping of outputs before shape comparison.

    Raises:
        ValueError: If an array's dimensionality, data type, value range or axis
            sizes do not match its description, or a size reference cannot be resolved.
    """
    all_tensor_axes: Dict[TensorId, Dict[AxisId, Tuple[AnyAxis, Optional[int]]]] = {}

    def e_msg_location(d: TensorDescr):
        return f"{'inputs' if isinstance(d, InputTensorDescr) else 'outputs'}[{d.id}]"

    for descr, array in tensors.values():
        if array is None:
            axis_sizes = {a.id: None for a in descr.axes}
        else:
            try:
                axis_sizes = descr.get_axis_sizes_for_array(array)
            except ValueError as e:
                raise ValueError(f"{e_msg_location(descr)} {e}") from e

        all_tensor_axes[descr.id] = {a.id: (a, axis_sizes[a.id]) for a in descr.axes}

    # get halos to be padded/cropped to validate against halo-adjusted sizes
    io_halos = get_halos({k: v[0] for k, v in tensors.items()})

    for descr, array in tensors.values():
        if array is None:
            continue

        if descr.dtype in ("float32", "float64"):
            # float tensors may be given as any float or integer array
            invalid_test_tensor_dtype = array.dtype.name not in (
                "float32",
                "float64",
                "uint8",
                "int8",
                "uint16",
                "int16",
                "uint32",
                "int32",
                "uint64",
                "int64",
            )
        else:
            invalid_test_tensor_dtype = array.dtype.name != descr.dtype

        if invalid_test_tensor_dtype:
            raise ValueError(
                f"{tensor_origin} data type '{array.dtype.name}' does not"
                + f" match described {e_msg_location(descr)}.dtype '{descr.dtype}'"
            )

        # reject (near) all-zero arrays; this applies to inputs and outputs alike
        if array.min() > -1e-4 and array.max() < 1e-4:
            raise ValueError(
                f"{e_msg_location(descr)}: values are too small for reliable testing."
                + f" Values <=-1e-4 or >=1e-4 must be present in {tensor_origin}"
            )

        for a in descr.axes:
            actual_size = all_tensor_axes[descr.id][a.id][1]

            if actual_size is None:
                continue

            if a.size is None:
                continue

            # add padding width to actual tensor size
            total_axis_halo = sum(io_halos.get(descr.id, {}).get(a.id, (0, 0)))
            if isinstance(descr, InputTensorDescr):
                # pad input halos
                actual_size_with_halo = actual_size + total_axis_halo
                if pad_inputs is True:
                    check_sizes = {actual_size_with_halo}
                    size_hint = " (after padding input halo)"
                elif pad_inputs == "allow":
                    check_sizes = {actual_size, actual_size_with_halo}
                    size_hint = " (with or without padding input halo)"
                elif pad_inputs is False:
                    check_sizes = {actual_size}
                    size_hint = ""
                else:
                    assert_never(pad_inputs)

            elif isinstance(descr, OutputTensorDescr):
                # crop output halos
                actual_size_with_halo = max(0, actual_size - total_axis_halo)
                if crop_outputs is True:
                    check_sizes = {actual_size_with_halo}
                    size_hint = " (after cropping output halo)"
                elif crop_outputs == "allow":
                    check_sizes = {actual_size, actual_size_with_halo}
                    size_hint = " (with or without cropping output halo)"
                elif crop_outputs is False:
                    check_sizes = {actual_size}
                    size_hint = ""
                else:
                    assert_never(crop_outputs)
            else:
                assert_never(descr)

            del actual_size  # make sure we explicitly use unchanged or halo-adjusted size from here on

            if isinstance(a.size, int):
                if a.size not in check_sizes:
                    raise ValueError(
                        f"{e_msg_location(descr)}.axes[{a.id}]: {tensor_origin} axis "
                        + f"has incompatible size {check_sizes}{size_hint}, expected {a.size}"
                    )
            elif isinstance(a.size, (ParameterizedSize, DataDependentSize)):
                _ = try_all_raise_last(
                    (partial(a.size.validate_size, s) for s in check_sizes),
                    f"{e_msg_location(descr)}.axes[{a.id}]: {tensor_origin} axis ",
                )
            elif isinstance(a.size, SizeReference):
                ref_tensor_axes = all_tensor_axes.get(a.size.tensor_id)
                if ref_tensor_axes is None:
                    raise ValueError(
                        f"{e_msg_location(descr)}.axes[{a.id}].size.tensor_id: Unknown tensor"
                        + f" reference '{a.size.tensor_id}', available: {list(all_tensor_axes)}"
                    )

                ref_axis, ref_size = ref_tensor_axes.get(a.size.axis_id, (None, None))
                if ref_axis is None or ref_size is None:
                    raise ValueError(
                        f"{e_msg_location(descr)}.axes[{a.id}].size.axis_id: Unknown tensor axis"
                        + f" reference '{a.size.tensor_id}.{a.size.axis_id}', available: {list(ref_tensor_axes)}"
                    )

                if a.unit != ref_axis.unit:
                    raise ValueError(
                        f"{e_msg_location(descr)}.axes[{a.id}].size: `SizeReference` requires"
                        + " axis and reference axis to have the same `unit`, but"
                        + f" {a.unit}!={ref_axis.unit}"
                    )

                if (
                    expected_size := (
                        ref_size * ref_axis.scale / a.scale + a.size.offset
                    )
                ) not in check_sizes:
                    raise ValueError(
                        f"{e_msg_location(descr)}.{tensor_origin}: axis '{a.id}' of size"
                        + f" {check_sizes} invalid for referenced size {ref_size};"
                        + f" expected {expected_size}"
                    )
            else:
                assert_never(a.size)

2618 

2619 

# File description restricted to `.yaml`/`.yml` files (case-sensitive suffix
# check), e.g. a Conda environment file describing custom dependencies.
FileDescr_dependencies = Annotated[
    FileDescr_package,
    WithSuffix((".yaml", ".yml"), case_sensitive=True),
    Field(examples=[dict(source="environment.yaml")]),
]

2625 

2626 

# Fields shared by architecture descriptions: the identifier of a callable plus
# keyword arguments to call it with.
class _ArchitectureCallableDescr(Node):
    callable: Annotated[Identifier, Field(examples=["MyNetworkClass", "get_my_model"])]
    """Identifier of the callable that returns a torch.nn.Module instance."""

    kwargs: Dict[str, YamlValue] = Field(
        default_factory=cast(Callable[[], Dict[str, YamlValue]], dict)
    )
    """key word arguments for the `callable`"""


# Architecture whose callable is defined in a source file (`source`).
class ArchitectureFromFileDescr(_ArchitectureCallableDescr, FileDescr):
    source: Annotated[FileSource, AfterValidator(wo_special_file_name)]
    """Architecture source file"""

    # Wrap serializer delegating to `package_file_descr_serializer`.
    @model_serializer(mode="wrap", when_used="unless-none")
    def _serialize(self, nxt: SerializerFunctionWrapHandler, info: SerializationInfo):
        return package_file_descr_serializer(self, nxt, info)


# Architecture whose callable is imported from a module (`import_from`).
class ArchitectureFromLibraryDescr(_ArchitectureCallableDescr):
    import_from: str
    """Where to import the callable from, i.e. `from <import_from> import <callable>`"""

2649 

2650 

class _ArchFileConv(
    Converter[
        _CallableFromFile_v0_4,
        ArchitectureFromFileDescr,
        Optional[Sha256],
        Dict[str, Any],
    ]
):
    # Converts a v0.4 "<file source>:<callable>" reference into an
    # `ArchitectureFromFileDescr`.
    def _convert(
        self,
        src: _CallableFromFile_v0_4,
        tgt: "type[ArchitectureFromFileDescr | dict[str, Any]]",
        sha256: Optional[Sha256],
        kwargs: Dict[str, Any],
    ) -> "ArchitectureFromFileDescr | dict[str, Any]":
        # Split at the *last* ":" so that colons belonging to the source itself
        # (URL scheme, URL port, Windows drive letter) are preserved. The
        # previous `count(":")` heuristic turned e.g.
        # "http://host:8080/model.py:Net" into `callable=<whole string>`.
        source, sep, callable_ = src.rpartition(":")
        if not sep or not source or "/" in callable_ or "\\" in callable_:
            # No ":<callable>" suffix (the last ":" is part of the source, e.g.
            # "https://host/model.py"); keep the previous lenient fallback.
            source = str(src)
            callable_ = str(src)

        return tgt(
            callable=Identifier(callable_),
            source=cast(FileSource, source),
            sha256=sha256,
            kwargs=kwargs,
        )


_arch_file_conv = _ArchFileConv(_CallableFromFile_v0_4, ArchitectureFromFileDescr)

2683 

2684 

class _ArchLibConv(
    Converter[
        _CallableFromDepencency_v0_4, ArchitectureFromLibraryDescr, Dict[str, Any]
    ]
):
    # Converts a v0.4 "<module path>.<callable>" reference into an
    # `ArchitectureFromLibraryDescr` (`import_from` + `callable`).
    def _convert(
        self,
        src: _CallableFromDepencency_v0_4,
        tgt: "type[ArchitectureFromLibraryDescr | dict[str, Any]]",
        kwargs: Dict[str, Any],
    ) -> "ArchitectureFromLibraryDescr | dict[str, Any]":
        # everything before the last "." is the module path ("" without any ".")
        module_path, _, callable_name = src.rpartition(".")
        return tgt(
            import_from=module_path,
            callable=Identifier(callable_name),
            kwargs=kwargs,
        )


_arch_lib_conv = _ArchLibConv(
    _CallableFromDepencency_v0_4, ArchitectureFromLibraryDescr
)

2706 

2707 

# Common base of all weights entries: a weights file (`source`) plus authorship
# and provenance (`parent`) information. Subclasses set the `type` and
# `weights_format_name` class variables.
class WeightsEntryDescrBase(FileDescr):
    type: ClassVar[WeightsFormat]
    weights_format_name: ClassVar[str]  # human readable

    source: Annotated[FileSource, AfterValidator(wo_special_file_name)]
    """Source of the weights file."""

    authors: Optional[List[Author]] = None
    """Authors
    Either the person(s) that have trained this model resulting in the original weights file.
        (If this is the initial weights entry, i.e. it does not have a `parent`)
    Or the person(s) who have converted the weights to this weights format.
        (If this is a child weight, i.e. it has a `parent` field)
    """

    parent: Annotated[
        Optional[WeightsFormat], Field(examples=["pytorch_state_dict"])
    ] = None
    """The source weights these weights were converted from.
    For example, if a model's weights were converted from the `pytorch_state_dict` format to `torchscript`,
    The `pytorch_state_dict` weights entry has no `parent` and is the parent of the `torchscript` weights.
    All weight entries except one (the initial set of weights resulting from training the model),
    need to have this field."""

    comment: str = ""
    """A comment about this weights entry, for example how these weights were created."""

    # An entry must not name its own weights format as its `parent`.
    @model_validator(mode="after")
    def _validate(self) -> Self:
        if self.type == self.parent:
            raise ValueError("Weights entry can't be it's own parent.")

        return self

    # Wrap serializer delegating to `package_file_descr_serializer`.
    @model_serializer(mode="wrap", when_used="unless-none")
    def _serialize(self, nxt: SerializerFunctionWrapHandler, info: SerializationInfo):
        return package_file_descr_serializer(self, nxt, info)

2745 

2746 

# Weights entry for the "keras_hdf5" format.
class KerasHdf5WeightsDescr(WeightsEntryDescrBase):
    type: ClassVar[WeightsFormat] = "keras_hdf5"
    weights_format_name: ClassVar[str] = "Keras HDF5"
    tensorflow_version: Version
    """TensorFlow version used to create these weights."""


# Weights entry for the "keras_v3" format; `source` must be a `.keras` file.
class KerasV3WeightsDescr(WeightsEntryDescrBase):
    type: ClassVar[WeightsFormat] = "keras_v3"
    weights_format_name: ClassVar[str] = "Keras v3"
    # constrained to Keras >= 3
    keras_version: Annotated[Version, Ge(Version(3))]
    """Keras version used to create these weights."""
    # (backend name, backend version)
    backend: Tuple[Literal["tensorflow", "jax", "torch"], Version]
    """Keras backend used to create these weights."""
    # overrides the base `source` to additionally require a `.keras` suffix
    source: Annotated[
        FileSource,
        AfterValidator(wo_special_file_name),
        WithSuffix(".keras", case_sensitive=True),
    ]
    """Source of the .keras weights file."""

2767 

2768 

# File description restricted to `.data` files (case-sensitive suffix check),
# used for ONNX external weights data.
FileDescr_external_data = Annotated[
    FileDescr_package,
    WithSuffix(".data", case_sensitive=True),
    Field(examples=[dict(source="weights.onnx.data")]),
]


class OnnxWeightsDescr(WeightsEntryDescrBase):
    type: ClassVar[WeightsFormat] = "onnx"
    weights_format_name: ClassVar[str] = "ONNX"
    opset_version: Annotated[int, Ge(7)]
    """ONNX opset version"""

    external_data: Optional[FileDescr_external_data] = None
    """Source of the external ONNX data file holding the weights.
    (If present **source** holds the ONNX architecture without weights)."""

    @model_validator(mode="after")
    def _validate_external_data_unique_file_name(self) -> Self:
        # `source` and `external_data.source` must not share a file name.
        if self.external_data is None:
            return self

        source_name = extract_file_name(self.source)
        data_name = extract_file_name(self.external_data.source)
        if source_name != data_name:
            return self

        raise ValueError(
            f"ONNX `external_data` file name '{data_name}'"
            + " must be different from ONNX `source` file name."
        )

2797 return self 

2798 

2799 

# Weights entry for the "pytorch_state_dict" format: a state dict file plus the
# description of the architecture callable it belongs to.
class PytorchStateDictWeightsDescr(WeightsEntryDescrBase):
    type: ClassVar[WeightsFormat] = "pytorch_state_dict"
    weights_format_name: ClassVar[str] = "Pytorch State Dict"
    architecture: Union[ArchitectureFromFileDescr, ArchitectureFromLibraryDescr]
    pytorch_version: Version
    """Version of the PyTorch library used.
    If `dependencies` is specified it has to include pytorch and any version pinning has to be compatible.
    """
    dependencies: Optional[FileDescr_dependencies] = None
    """Custom dependencies beyond pytorch described in a Conda environment file.
    Allows to specify custom dependencies, see conda docs:
    - [Exporting an environment file across platforms](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#exporting-an-environment-file-across-platforms)
    - [Creating an environment file manually](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#creating-an-environment-file-manually)

    The conda environment file should include pytorch and any version pinning has to be compatible with
    **pytorch_version**.
    """
    strict: bool = True
    """Whether to be strict about the architecture matching the state dict weights (`True`)
    or to allow missing or unexpected keys (`False`)."""

2819 

2820 

# Weights entry for the "tensorflow_js" format.
class TensorflowJsWeightsDescr(WeightsEntryDescrBase):
    type: ClassVar[WeightsFormat] = "tensorflow_js"
    weights_format_name: ClassVar[str] = "Tensorflow.js"
    tensorflow_version: Version
    """Version of the TensorFlow library used."""

    source: Annotated[FileSource, AfterValidator(wo_special_file_name)]
    """The multi-file weights.
    All required files/folders should be a zip archive."""


# Weights entry for the "tensorflow_saved_model_bundle" format, optionally with
# a Conda environment file for custom dependencies.
class TensorflowSavedModelBundleWeightsDescr(WeightsEntryDescrBase):
    type: ClassVar[WeightsFormat] = "tensorflow_saved_model_bundle"
    weights_format_name: ClassVar[str] = "Tensorflow Saved Model"
    tensorflow_version: Version
    """Version of the TensorFlow library used."""

    dependencies: Optional[FileDescr_dependencies] = None
    """Custom dependencies beyond tensorflow.
    Should include tensorflow and any version pinning has to be compatible with **tensorflow_version**."""

    source: Annotated[FileSource, AfterValidator(wo_special_file_name)]
    """The multi-file weights.
    All required files/folders should be a zip archive."""


# Weights entry for the "torchscript" format.
class TorchscriptWeightsDescr(WeightsEntryDescrBase):
    type: ClassVar[WeightsFormat] = "torchscript"
    weights_format_name: ClassVar[str] = "TorchScript"
    pytorch_version: Version
    """Version of the PyTorch library used."""


# Union of all concrete weights entry descriptions (one per weights format).
SpecificWeightsDescr = Union[
    KerasHdf5WeightsDescr,
    KerasV3WeightsDescr,
    OnnxWeightsDescr,
    PytorchStateDictWeightsDescr,
    TensorflowJsWeightsDescr,
    TensorflowSavedModelBundleWeightsDescr,
    TorchscriptWeightsDescr,
]

2863 

2864 

class WeightsDescr(Node):
    # One optional entry per supported weights format; at least one entry is
    # required (enforced by `check_entries`).
    keras_hdf5: Optional[KerasHdf5WeightsDescr] = None
    keras_v3: Optional[KerasV3WeightsDescr] = None
    onnx: Optional[OnnxWeightsDescr] = None
    pytorch_state_dict: Optional[PytorchStateDictWeightsDescr] = None
    tensorflow_js: Optional[TensorflowJsWeightsDescr] = None
    tensorflow_saved_model_bundle: Optional[TensorflowSavedModelBundleWeightsDescr] = (
        None
    )
    torchscript: Optional[TorchscriptWeightsDescr] = None

    @staticmethod
    def _entry_types() -> Dict[str, Type[Any]]:
        """Map each weights format key to the class its entry has to be an instance of.

        The insertion order matches the field order above and determines the
        order of `available_formats`.
        """
        return {
            "keras_hdf5": KerasHdf5WeightsDescr,
            "keras_v3": KerasV3WeightsDescr,
            "onnx": OnnxWeightsDescr,
            "pytorch_state_dict": PytorchStateDictWeightsDescr,
            "tensorflow_js": TensorflowJsWeightsDescr,
            "tensorflow_saved_model_bundle": TensorflowSavedModelBundleWeightsDescr,
            "torchscript": TorchscriptWeightsDescr,
        }

    @model_validator(mode="after")
    def check_entries(self) -> Self:
        """Validate the consistency of the given weights entries.

        Raises:
            ValueError: if no entry is given, or if an entry's `parent` refers to
                a weights format that has no entry.

        Issues a warning (instead of raising) if not exactly one entry omits
        `parent`, i.e. if the original set of weights is ambiguous.
        """
        entries = {wtype for wtype, entry in self if entry is not None}

        if not entries:
            raise ValueError("Missing weights entry")

        entries_wo_parent = {
            wtype
            for wtype, entry in self
            if entry is not None and hasattr(entry, "parent") and entry.parent is None
        }
        if len(entries_wo_parent) != 1:
            issue_warning(
                "Exactly one weights entry may not specify the `parent` field (got"
                + " {value}). That entry is considered the original set of model weights."
                + " Other weight formats are created through conversion of the original or"
                + " already converted weights. They have to reference the weights format"
                + " they were converted from as their `parent`.",
                value=len(entries_wo_parent),
                field="weights",
            )

        for wtype, entry in self:
            if entry is None:
                continue

            assert hasattr(entry, "type")
            assert hasattr(entry, "parent")
            assert wtype == entry.type
            if (
                entry.parent is not None and entry.parent not in entries
            ):  # self reference checked for `parent` field
                raise ValueError(
                    f"`weights.{wtype}.parent={entry.parent}` not in specified weight"
                    + f" formats: {entries}"
                )

        return self

    def __getitem__(
        self,
        key: WeightsFormat,
    ) -> SpecificWeightsDescr:
        """Return the weights entry for `key`.

        Raises:
            KeyError: if `key` is not a known weights format or if no entry is
                given for it.
        """
        if key not in self._entry_types():
            raise KeyError(key)

        ret = getattr(self, key)
        if ret is None:
            raise KeyError(key)

        return ret

    @overload
    def __setitem__(
        self, key: Literal["keras_hdf5"], value: Optional[KerasHdf5WeightsDescr]
    ) -> None: ...
    @overload
    def __setitem__(
        self, key: Literal["keras_v3"], value: Optional[KerasV3WeightsDescr]
    ) -> None: ...
    @overload
    def __setitem__(
        self, key: Literal["onnx"], value: Optional[OnnxWeightsDescr]
    ) -> None: ...
    @overload
    def __setitem__(
        self,
        key: Literal["pytorch_state_dict"],
        value: Optional[PytorchStateDictWeightsDescr],
    ) -> None: ...
    @overload
    def __setitem__(
        self, key: Literal["tensorflow_js"], value: Optional[TensorflowJsWeightsDescr]
    ) -> None: ...
    @overload
    def __setitem__(
        self,
        key: Literal["tensorflow_saved_model_bundle"],
        value: Optional[TensorflowSavedModelBundleWeightsDescr],
    ) -> None: ...
    @overload
    def __setitem__(
        self, key: Literal["torchscript"], value: Optional[TorchscriptWeightsDescr]
    ) -> None: ...

    def __setitem__(
        self,
        key: WeightsFormat,
        value: Optional[SpecificWeightsDescr],
    ):
        """Set (or, with `None`, remove) the weights entry for `key`.

        Raises:
            KeyError: if `key` is not a known weights format.
            TypeError: if `value` is neither `None` nor an instance of the
                description class belonging to `key`.
        """
        expected = self._entry_types().get(key)
        if expected is None:
            raise KeyError(key)

        if value is not None and not isinstance(value, expected):
            raise TypeError(
                f"Expected {expected.__name__} or None for key '{key}', got {type(value)}"
            )

        setattr(self, key, value)

    @property
    def available_formats(self) -> Dict[WeightsFormat, SpecificWeightsDescr]:
        """All given weights entries keyed by weights format (in field order)."""
        return {
            wf: entry
            for wf in self._entry_types()
            if (entry := getattr(self, wf)) is not None
        }

    @property
    def missing_formats(self) -> Set[WeightsFormat]:
        """All weights formats without an entry."""
        return {
            wf for wf in get_args(WeightsFormat) if wf not in self.available_formats
        }

3060 

3061 

class ModelId(ResourceId):
    # Distinct subtype of `ResourceId` used to type model ids
    # (e.g. `ModelDescr.id`, `LinkedModel.id`); adds no behavior of its own.
    pass

3064 

3065 

class LinkedModel(LinkedResourceBase):
    """Reference to a bioimage.io model."""

    # Typed as `ModelId` (rather than a generic resource id) so that only
    # model ids are accepted here.
    id: ModelId
    """A valid model `id` from the bioimage.io collection."""

3071 

3072 

class _DataDepSize(NamedTuple):
    # Bounds for an output axis whose size depends on the data (used as an
    # alternative to a fixed int in `_AxisSizes.outputs`/`_TensorSizes.outputs`).
    # `max` may be None — presumably meaning "no upper bound"; TODO confirm.
    min: StrictInt
    max: Optional[StrictInt]

3076 

3077 

class _AxisSizes(NamedTuple):
    """the lengths of all axes of model inputs and outputs"""

    # keyed by (tensor id, axis id)
    inputs: Dict[Tuple[TensorId, AxisId], int]
    # output axes may have a data-dependent size range instead of a fixed length
    outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]]

3083 

3084 

class _TensorSizes(NamedTuple):
    """_AxisSizes as nested dicts"""

    # tensor id -> axis id -> length
    inputs: Dict[TensorId, Dict[AxisId, int]]
    # tensor id -> axis id -> length or data-dependent size range
    outputs: Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]

3090 

3091 

class ReproducibilityTolerance(Node, extra="allow"):
    """Describes what small numerical differences -- if any -- may be tolerated
    in the generated output when executing in different environments.

    A tensor element *output* is considered mismatched to the **test_tensor** if
    abs(*output* - **test_tensor**) > **absolute_tolerance** + **relative_tolerance** * abs(**test_tensor**).
    (Internally we call [numpy.testing.assert_allclose](https://numpy.org/doc/stable/reference/generated/numpy.testing.assert_allclose.html).)

    Motivation:
        For testing we can request the respective deep learning frameworks to be as
        reproducible as possible by setting seeds and choosing deterministic algorithms,
        but differences in operating systems, available hardware and installed drivers
        may still lead to numerical differences.
    """

    relative_tolerance: RelativeTolerance = 1e-3
    """Maximum relative tolerance of reproduced test tensor."""

    absolute_tolerance: AbsoluteTolerance = 1e-3
    """Maximum absolute tolerance of reproduced test tensor."""

    mismatched_elements_per_million: MismatchedElementsPerMillion = 100
    """Maximum number of mismatched elements/pixels per million to tolerate."""

    # An empty sequence means "all outputs" (`ModelDescr._validate_test_tensors`
    # checks every test output when `output_ids` is empty).
    output_ids: Sequence[TensorId] = ()
    """Limits the output tensor IDs these reproducibility details apply to."""

    # presumably an empty sequence means "all weights formats" — TODO confirm
    weights_formats: Sequence[WeightsFormat] = ()
    """Limits the weights formats these details apply to."""

3121 

3122 

class BiasRisksLimitations(Node, extra="allow"):
    """Known biases, risks, technical limitations, and recommendations for model use."""

    known_biases: str = dedent("""\
        In general bioimage models may suffer from biases caused by:

        - Imaging protocol dependencies
        - Use of a specific cell type
        - Species-specific training data limitations

        """)
    """Biases in training data or model behavior."""

    risks: str = dedent("""\
        Common risks in bioimage analysis include:

        - Erroneously assuming generalization to unseen experimental conditions
        - Trusting (overconfident) model outputs without validation
        - Misinterpretation of results

        """)
    """Potential risks in the context of bioimage analysis."""

    limitations: Optional[str] = None
    """Technical limitations and failure modes."""

    recommendations: str = "Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model."
    """Mitigation strategies regarding `known_biases`, `risks`, and `limitations`, as well as applicable best practices.

    Consider:
    - How to use a validation dataset?
    - How to manually validate?
    - Feasibility of domain adaptation for different experimental setups?

    """

    def format_md(self) -> str:
        """Render these notes as a Markdown section.

        The "## Limitations" heading is only emitted if `limitations` has
        content, so an empty string no longer produces an empty subsection.
        """
        limitations_md = (
            f"## Limitations\n\n{self.limitations}" if self.limitations else ""
        )

        return f"""# Bias, Risks, and Limitations

{self.known_biases}

{self.risks}

{limitations_md}

## Recommendations

{self.recommendations}

"""

3178 

3179 

class TrainingDetails(Node, extra="allow"):
    # Free-form, optional training metadata; extra keys are allowed
    # (`extra="allow"`). All fields default to "unspecified".
    training_preprocessing: Optional[str] = None
    """Detailed image preprocessing steps during model training:

    Mention:
    - *Normalization methods*
    - *Augmentation strategies*
    - *Resizing/resampling procedures*
    - *Artifact handling*

    """

    # NOTE(review): typed as float (not int) — confirm whether fractional
    # epochs/batch sizes are intended.
    training_epochs: Optional[float] = None
    """Number of training epochs."""

    training_batch_size: Optional[float] = None
    """Batch size used in training."""

    initial_learning_rate: Optional[float] = None
    """Initial learning rate used in training."""

    learning_rate_schedule: Optional[str] = None
    """Learning rate schedule used in training."""

    loss_function: Optional[str] = None
    """Loss function used in training, e.g. nn.MSELoss."""

    # `default_factory` gives each instance its own (empty) dict.
    loss_function_kwargs: Dict[str, YamlValue] = Field(
        default_factory=cast(Callable[[], Dict[str, YamlValue]], dict)
    )
    """key word arguments for the `loss_function`"""

    optimizer: Optional[str] = None
    """optimizer, e.g. torch.optim.Adam"""

    optimizer_kwargs: Dict[str, YamlValue] = Field(
        default_factory=cast(Callable[[], Dict[str, YamlValue]], dict)
    )
    """key word arguments for the `optimizer`"""

    regularization: Optional[str] = None
    """Regularization techniques used during training, e.g. drop-out or weight decay."""

    training_duration: Optional[float] = None
    """Total training duration in hours."""

3225 

3226 

class Evaluation(Node, extra="allow"):
    model_id: Optional[ModelId] = None
    """Model being evaluated."""

    dataset_id: DatasetId
    """Dataset used for evaluation."""

    dataset_source: HttpUrl
    """Source of the dataset."""

    dataset_role: Literal["train", "validation", "test", "independent", "unknown"]
    """Role of the dataset used for evaluation.

    - `train`: dataset was (part of) the training data
    - `validation`: dataset was (part of) the validation data used during training, e.g. used for model selection or hyperparameter tuning
    - `test`: dataset was (part of) the designated test data; not used during training or validation, but acquired from the same source/distribution as training data
    - `independent`: dataset is entirely independent test data; not used during training or validation, and acquired from a different source/distribution than training data
    - `unknown`: role of the dataset is unknown; choose this if you are not certain if (a subset) of the data was seen by the model during training.
    """

    sample_count: int
    """Number of evaluated samples."""

    evaluation_factors: List[Annotated[str, MaxLen(16)]]
    """(Abbreviations of) each evaluation factor.

    Evaluation factors are criteria along which model performance is evaluated, e.g. different image conditions
    like 'low SNR', 'high cell density', or different biological conditions like 'cell type A', 'cell type B'.
    An 'overall' factor may be included to summarize performance across all conditions.
    """

    evaluation_factors_long: List[str]
    """Descriptions (long form) of each evaluation factor."""

    metrics: List[Annotated[str, MaxLen(16)]]
    """(Abbreviations of) metrics used for evaluation."""

    metrics_long: List[str]
    """Description of each metric used."""

    @model_validator(mode="after")
    def _validate_list_lengths(self) -> Self:
        """Ensure the parallel lists and the `results` table have matching shapes.

        `results` must have one row per metric and, in every row, one column per
        evaluation factor.
        """
        if len(self.evaluation_factors) != len(self.evaluation_factors_long):
            raise ValueError(
                "`evaluation_factors` and `evaluation_factors_long` must have the same length"
            )

        if len(self.metrics) != len(self.metrics_long):
            raise ValueError("`metrics` and `metrics_long` must have the same length")

        if len(self.results) != len(self.metrics):
            raise ValueError("`results` must have the same number of rows as `metrics`")

        for row in self.results:
            if len(row) != len(self.evaluation_factors):
                raise ValueError(
                    "`results` must have the same number of columns (in every row) as `evaluation_factors`"
                )

        return self

    results: List[List[Union[str, float, int]]]
    """Results for each metric (rows; outer list) and each evaluation factor (columns; inner list)."""

    results_summary: Optional[str] = None
    """Interpretation of results for general audience.

    Consider:
    - Overall model performance
    - Comparison to existing methods
    - Limitations and areas for improvement

"""

    def format_md(self):
        """Render this evaluation as Markdown (testing data, factors, metrics, results).

        Cells of the results table are escaped (`|` -> `\\|`, newlines -> spaces)
        so that user-provided factor/metric names or results cannot break the
        table layout.
        """

        def _md_cell(value: object) -> str:
            # A literal `|` would start a new column and a newline would end the row.
            return str(value).replace("|", "\\|").replace("\n", " ")

        results_header = [_md_cell(h) for h in ["Metric", *self.evaluation_factors]]
        results_table_cells = [results_header, ["---"] * len(results_header)] + [
            [_md_cell(metric)] + [_md_cell(r) for r in row]
            for metric, row in zip(self.metrics, self.results)
        ]

        results_table = "".join(
            "| " + " | ".join(row) + " |\n" for row in results_table_cells
        )
        factors = "".join(
            f"\n - {ef}: {efl}"
            for ef, efl in zip(self.evaluation_factors, self.evaluation_factors_long)
        )
        metrics = "".join(
            f"\n - {em}: {eml}" for em, eml in zip(self.metrics, self.metrics_long)
        )

        return f"""## Testing Data, Factors & Metrics

Evaluation of {self.model_id or "this"} model on the {self.dataset_id} dataset (dataset role: {self.dataset_role}).

### Testing Data

- **Source:** [{self.dataset_id}]({self.dataset_source})
- **Size:** {self.sample_count} evaluated samples

### Factors
{factors}

### Metrics
{metrics}

## Results

### Quantitative Results

{results_table}

### Summary

{self.results_summary or "missing"}

"""

3345 

3346 

class EnvironmentalImpact(Node, extra="allow"):
    """Environmental considerations for model training and deployment.

    Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
    """

    hardware_type: Optional[str] = None
    """GPU/CPU specifications"""

    hours_used: Optional[float] = None
    """Total compute hours"""

    cloud_provider: Optional[str] = None
    """If applicable"""

    compute_region: Optional[str] = None
    """Geographic location"""

    co2_emitted: Optional[float] = None
    """kg CO2 equivalent

    Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
    """

    def format_md(self):
        """Return the "Environmental Impact" section of a model card as Markdown.

        Follows the [Hugging Face Model Card Template](https://huggingface.co/docs/hub/en/model-card-annotated).
        Only the fields that are set are listed. The result is an empty string
        if the instance equals a default-constructed one.
        """
        if self == type(self)():
            return ""

        # (label, value, unit suffix) in the order they appear in the output
        entries = (
            ("Hardware Type", self.hardware_type, ""),
            ("Hours used", self.hours_used, ""),
            ("Cloud Provider", self.cloud_provider, ""),
            ("Compute Region", self.compute_region, ""),
            ("Carbon Emitted", self.co2_emitted, " kg CO2e"),
        )
        bullet_list = "".join(
            f"- **{label}:** {value}{unit}\n"
            for label, value, unit in entries
            if value is not None
        )
        return "# Environmental Impact\n\n" + bullet_list + "\n"

3389 

3390 

class BioimageioConfig(Node, extra="allow"):
    # bioimage.io specific model metadata stored under `config.bioimageio`.
    # Nested node defaults are created via `model_construct`, i.e. without
    # running pydantic validation on the default values.
    reproducibility_tolerance: Sequence[ReproducibilityTolerance] = ()
    """Tolerances to allow when reproducing the model's test outputs
    from the model's test inputs.
    Only the first entry matching tensor id and weights format is considered.
    """

    funded_by: Optional[str] = None
    """Funding agency, grant number if applicable"""

    architecture_type: Optional[Annotated[str, MaxLen(32)]] = (
        None  # TODO: add to differentiated tags
    )
    """Model architecture type, e.g., 3D U-Net, ResNet, transformer"""

    architecture_description: Optional[str] = None
    """Text description of model architecture."""

    modality: Optional[str] = None  # TODO: add to differentiated tags
    """Input modality, e.g., fluorescence microscopy, electron microscopy"""

    target_structure: List[str] = Field(  # TODO: add to differentiated tags
        default_factory=cast(Callable[[], List[str]], list)
    )
    """Biological structure(s) the model is designed to analyze, e.g., nuclei, mitochondria, cells"""

    task: Optional[str] = None  # TODO: add to differentiated tags
    """Bioimage-specific task type, e.g., segmentation, classification, detection, denoising"""

    new_version: Optional[ModelId] = None
    """A new version of this model exists with a different model id."""

    out_of_scope_use: Optional[str] = None
    """Describe how the model may be misused in bioimage analysis contexts and what users should **not** do with the model."""

    bias_risks_limitations: BiasRisksLimitations = Field(
        default_factory=BiasRisksLimitations.model_construct
    )
    """Description of known bias, risks, and technical limitations for in-scope model use."""

    model_parameter_count: Optional[int] = None
    """Total number of model parameters."""

    training: TrainingDetails = Field(default_factory=TrainingDetails.model_construct)
    """Details on how the model was trained."""

    inference_time: Optional[str] = None
    """Average inference time per image/tile. Specify hardware and image size. Multiple examples can be given."""

    memory_requirements_inference: Optional[str] = None
    """GPU memory needed for inference. Multiple examples with different image size can be given."""

    memory_requirements_training: Optional[str] = None
    """GPU memory needed for training. Multiple examples with different image/batch sizes can be given."""

    evaluations: List[Evaluation] = Field(
        default_factory=cast(Callable[[], List[Evaluation]], list)
    )
    """Quantitative model evaluations.

    Note:
        At the moment we recommend to include only a single test dataset
        (with evaluation factors that may mark subsets of the dataset)
        to avoid confusion and make the presentation of results cleaner.
    """

    environmental_impact: EnvironmentalImpact = Field(
        default_factory=EnvironmentalImpact.model_construct
    )
    """Environmental considerations for model training and deployment"""

3461 

3462 

class Config(Node, extra="allow"):
    # Top-level `config` of a model description; arbitrary extra keys are allowed.
    # `bioimageio` defaults to an (unvalidated, via `model_construct`) empty config.
    bioimageio: BioimageioConfig = Field(
        default_factory=BioimageioConfig.model_construct
    )
    # presumably tool-specific configuration for StarDist — TODO confirm
    stardist: YamlValue = None

3468 

3469 

3470class ModelDescr(GenericModelDescrBase): 

3471 """Specification of the fields used in a bioimage.io-compliant RDF to describe AI models with pretrained weights. 

3472 These fields are typically stored in a YAML file which we call a model resource description file (model RDF). 

3473 """ 

3474 

3475 implemented_format_version: ClassVar[Literal["0.5.11"]] = "0.5.11" 

3476 if TYPE_CHECKING: 

3477 format_version: Literal["0.5.11"] = "0.5.11" 

3478 else: 

3479 format_version: Literal["0.5.11"] 

3480 """Version of the bioimage.io model description specification used. 

3481 When creating a new model always use the latest micro/patch version described here. 

3482 The `format_version` is important for any consumer software to understand how to parse the fields. 

3483 """ 

3484 

3485 implemented_type: ClassVar[Literal["model"]] = "model" 

3486 if TYPE_CHECKING: 

3487 type: Literal["model"] = "model" 

3488 else: 

3489 type: Literal["model"] 

3490 """Specialized resource type 'model'""" 

3491 

3492 id: Optional[ModelId] = None 

3493 """bioimage.io-wide unique resource identifier 

3494 assigned by bioimage.io; version **un**specific.""" 

3495 

3496 authors: FAIR[List[Author]] = Field( 

3497 default_factory=cast(Callable[[], List[Author]], list) 

3498 ) 

3499 """The authors are the creators of the model RDF and the primary points of contact.""" 

3500 

3501 documentation: FAIR[Optional[FileDescr_documentation]] = None 

3502 """Additional model documentation. 

3503 The recommended documentation source file name is `README.md`. An `.md` suffix is mandatory. 

3504 The documentation should include a '#[#] Validation' (sub)section 

3505 with details on how to quantitatively validate the model on unseen data.""" 

3506 

3507 @field_validator("documentation", mode="after") 

3508 @classmethod 

3509 def _validate_documentation(cls, value: Optional[FileDescr]) -> Optional[FileDescr]: 

3510 if not get_validation_context().perform_io_checks or value is None: 

3511 return value 

3512 

3513 doc_reader = get_reader(value) 

3514 doc_content = doc_reader.read().decode(encoding="utf-8") 

3515 if not re.search("#.*[vV]alidation", doc_content): 

3516 issue_warning( 

3517 "No '# Validation' (sub)section found in {value}.", 

3518 value=value, 

3519 field="documentation", 

3520 ) 

3521 

3522 return value 

3523 

3524 inputs: NotEmpty[Sequence[InputTensorDescr]] 

3525 """Describes the input tensors expected by this model.""" 

3526 

3527 @field_validator("inputs", mode="after") 

3528 @classmethod 

3529 def _validate_input_axes( 

3530 cls, inputs: Sequence[InputTensorDescr] 

3531 ) -> Sequence[InputTensorDescr]: 

3532 input_size_refs = cls._get_axes_with_independent_size(inputs) 

3533 

3534 for i, ipt in enumerate(inputs): 

3535 valid_independent_refs: Dict[ 

3536 Tuple[TensorId, AxisId], 

3537 Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]], 

3538 ] = { 

3539 **{ 

3540 (ipt.id, a.id): (ipt, a, a.size) 

3541 for a in ipt.axes 

3542 if not isinstance(a, BatchAxis) 

3543 and isinstance(a.size, (int, ParameterizedSize)) 

3544 }, 

3545 **input_size_refs, 

3546 } 

3547 for a, ax in enumerate(ipt.axes): 

3548 cls._validate_axis( 

3549 "inputs", 

3550 i=i, 

3551 tensor_id=ipt.id, 

3552 a=a, 

3553 axis=ax, 

3554 valid_independent_refs=valid_independent_refs, 

3555 ) 

3556 return inputs 

3557 

    @staticmethod
    def _validate_axis(
        field_name: str,
        i: int,
        tensor_id: TensorId,
        a: int,
        axis: AnyAxis,
        valid_independent_refs: Dict[
            Tuple[TensorId, AxisId],
            Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
        ],
    ):
        """Validate a single axis whose size may reference another axis.

        Args:
            field_name: name of the validated field ("inputs"/"outputs"), for error messages.
            i: index of the tensor within `field_name`, for error messages.
            tensor_id: id of the tensor the axis belongs to.
            a: index of the axis within the tensor, for error messages.
            axis: the axis to validate.
            valid_independent_refs: (tensor id, axis id) -> (tensor, axis, size)
                of all axes that may be referenced.

        Raises:
            ValueError: for an unknown or self reference, a channel axis
                referencing a non-channel axis, mismatching units, a batch axis
                reference, a halo too large for the minimum size, or a
                non-integer halo inferred for the reference axis.

        Side effect: a templated `channel_names` string containing "{i}" is
        replaced by the list of generated channel names.
        """
        # Batch axes and axes with int/parameterized/data-dependent sizes do not
        # reference other axes; nothing to check here.
        if isinstance(axis, BatchAxis) or isinstance(
            axis.size, (int, ParameterizedSize, DataDependentSize)
        ):
            return
        elif not isinstance(axis.size, SizeReference):
            assert_never(axis.size)

        # validate axis.size SizeReference
        ref = (axis.size.tensor_id, axis.size.axis_id)
        if ref not in valid_independent_refs:
            raise ValueError(
                "Invalid tensor axis reference at"
                + f" {field_name}[{i}].axes[{a}].size: {axis.size}."
            )
        if ref == (tensor_id, axis.id):
            raise ValueError(
                "Self-referencing not allowed for"
                + f" {field_name}[{i}].axes[{a}].size: {axis.size}"
            )
        if axis.type == "channel":
            if valid_independent_refs[ref][1].type != "channel":
                raise ValueError(
                    "A channel axis' size may only reference another fixed size"
                    + " channel axis."
                )
            # Expand a channel name template like "ch{i}" into one name per
            # channel (1-based), using the referenced axis' fixed size.
            if isinstance(axis.channel_names, str) and "{i}" in axis.channel_names:
                ref_size = valid_independent_refs[ref][2]
                assert isinstance(ref_size, int), (
                    "channel axis ref (another channel axis) has to specify fixed"
                    + " size"
                )
                # NOTE: the comprehension variable `i` shadows the parameter `i`
                # only inside the comprehension.
                generated_channel_names = [
                    axis.channel_names.format(i=i) for i in range(1, ref_size + 1)
                ]
                axis.channel_names = generated_channel_names

        if (ax_unit := getattr(axis, "unit", None)) != (
            ref_unit := getattr(valid_independent_refs[ref][1], "unit", None)
        ):
            raise ValueError(
                "The units of an axis and its reference axis need to match, but"
                + f" '{ax_unit}' != '{ref_unit}'."
            )
        ref_axis = valid_independent_refs[ref][1]
        # NOTE(review): callers in this file build `valid_independent_refs`
        # without batch axes, so this looks like a defensive check.
        if isinstance(ref_axis, BatchAxis):
            raise ValueError(
                f"Invalid reference axis '{ref_axis.id}' for {tensor_id}.{axis.id}"
                + " (a batch axis is not allowed as reference)."
            )

        if isinstance(axis, WithHalo):
            # The halo is cropped on both sides; at least one element has to remain
            # at the smallest possible size (n=0).
            min_size = axis.size.get_size(axis, ref_axis, n=0)
            if (min_size - 2 * axis.halo) < 1:
                raise ValueError(
                    f"axis {axis.id} with minimum size {min_size} is too small for halo"
                    + f" {axis.halo}."
                )

            # Translate the halo into the reference axis' scale; it has to be a
            # whole number of elements there.
            ref_halo = axis.halo * axis.scale / ref_axis.scale
            if ref_halo != int(ref_halo):
                raise ValueError(
                    f"Inferred halo for {'.'.join(ref)} is not an integer ({ref_halo} ="
                    + f" {tensor_id}.{axis.id}.halo {axis.halo}"
                    + f" * {tensor_id}.{axis.id}.scale {axis.scale}"
                    + f" / {'.'.join(ref)}.scale {ref_axis.scale})."
                )

3636 

3637 def validate_input_tensors( 

3638 self, 

3639 sources: Union[ 

3640 Sequence[NDArray[Any]], Mapping[TensorId, Optional[NDArray[Any]]] 

3641 ], 

3642 *, 

3643 pad_inputs: Union[bool, Literal["allow"]] = True, 

3644 crop_outputs: Union[bool, Literal["allow"]] = True, 

3645 ) -> Mapping[TensorId, Optional[NDArray[Any]]]: 

3646 """Check if the given input tensors match the model's input tensor descriptions. 

3647 This includes checks of tensor shapes and dtypes, but not of the actual values. 

3648 """ 

3649 if not isinstance(sources, collections.abc.Mapping): 

3650 sources = {descr.id: tensor for descr, tensor in zip(self.inputs, sources)} 

3651 

3652 tensors = { 

3653 **{descr.id: (descr, sources.get(descr.id)) for descr in self.inputs}, 

3654 **{ # outputs are required for halo 

3655 descr.id: (descr, None) for descr in self.outputs 

3656 }, 

3657 } 

3658 validate_tensors(tensors, pad_inputs=pad_inputs, crop_outputs=crop_outputs) 

3659 

3660 return sources 

3661 

3662 @model_validator(mode="after") 

3663 def _validate_test_tensors(self) -> Self: 

3664 if not get_validation_context().perform_io_checks: 

3665 return self 

3666 

3667 test_inputs = { 

3668 descr.id: ( 

3669 descr, 

3670 None if descr.test_tensor is None else load_array(descr.test_tensor), 

3671 ) 

3672 for descr in self.inputs 

3673 } 

3674 test_outputs = { 

3675 descr.id: ( 

3676 descr, 

3677 None if descr.test_tensor is None else load_array(descr.test_tensor), 

3678 ) 

3679 for descr in self.outputs 

3680 } 

3681 

3682 validate_tensors( 

3683 {**test_inputs, **test_outputs}, 

3684 tensor_origin="test_tensor", 

3685 pad_inputs="allow", 

3686 crop_outputs="allow", 

3687 ) 

3688 

3689 for rep_tol in self.config.bioimageio.reproducibility_tolerance: 

3690 if not rep_tol.absolute_tolerance: 

3691 continue 

3692 

3693 if rep_tol.output_ids: 

3694 out_arrays = { 

3695 k: v[1] for k, v in test_outputs.items() if k in rep_tol.output_ids 

3696 } 

3697 else: 

3698 out_arrays = {k: v[1] for k, v in test_outputs.items()} 

3699 

3700 for out_id, array in out_arrays.items(): 

3701 if array is None: 

3702 continue 

3703 

3704 if rep_tol.absolute_tolerance > (max_test_value := array.max()) * 0.01: 

3705 raise ValueError( 

3706 "config.bioimageio.reproducibility_tolerance.absolute_tolerance=" 

3707 + f"{rep_tol.absolute_tolerance} > 0.01*{max_test_value}" 

3708 + f" (1% of the maximum value of the test tensor '{out_id}')" 

3709 ) 

3710 

3711 return self 

3712 

3713 @model_validator(mode="after") 

3714 def _validate_tensor_references_in_proc_kwargs(self, info: ValidationInfo) -> Self: 

3715 ipt_refs = {t.id for t in self.inputs} 

3716 missing_refs = [ 

3717 k["reference_tensor"] 

3718 for k in [p.kwargs for ipt in self.inputs for p in ipt.preprocessing] 

3719 + [p.kwargs for out in self.outputs for p in out.postprocessing] 

3720 if "reference_tensor" in k 

3721 and k["reference_tensor"] is not None 

3722 and k["reference_tensor"] not in ipt_refs 

3723 ] 

3724 

3725 if missing_refs: 

3726 raise ValueError( 

3727 f"`reference_tensor`s {missing_refs} not found. Valid input tensor" 

3728 + f" references are: {ipt_refs}." 

3729 ) 

3730 

3731 return self 

3732 

3733 name: Annotated[ 

3734 str, 

3735 RestrictCharacters(string.ascii_letters + string.digits + "_+- ()"), 

3736 MinLen(5), 

3737 MaxLen(128), 

3738 warn(MaxLen(64), "Name longer than 64 characters.", INFO), 

3739 ] 

3740 """A human-readable name of this model. 

3741 It should be no longer than 64 characters 

3742 and may only contain letter, number, underscore, minus, parentheses and spaces. 

3743 We recommend to chose a name that refers to the model's task and image modality. 

3744 """ 

3745 

3746 outputs: NotEmpty[Sequence[OutputTensorDescr]] 

3747 """Describes the output tensors.""" 

3748 

@field_validator("outputs", mode="after")
@classmethod
def _validate_tensor_ids(
    cls, outputs: Sequence[OutputTensorDescr], info: ValidationInfo
) -> Sequence[OutputTensorDescr]:
    """Ensure tensor ids are unique across all input and output tensors.

    While the `outputs` field is being validated, `info.data` only contains the
    fields validated before it (e.g. `inputs`), never `outputs` itself. The
    output ids are therefore taken from the `outputs` argument.

    Returns:
        The unchanged **outputs**.

    Raises:
        ValueError: If any tensor id occurs more than once.
    """
    tensor_ids = [t.id for t in chain(info.data.get("inputs", []), outputs)]
    duplicate_tensor_ids: List[str] = []
    seen: Set[str] = set()
    for t in tensor_ids:
        if t in seen:
            duplicate_tensor_ids.append(t)

        seen.add(t)

    if duplicate_tensor_ids:
        raise ValueError(f"Duplicate tensor ids: {duplicate_tensor_ids}")

    return outputs

3769 

@staticmethod
def _get_axes_with_parameterized_size(
    io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]],
):
    """Collect all non-batch axes whose size is a `ParameterizedSize`.

    Returns:
        A dict mapping `"<tensor id>.<axis id>"` to `(tensor, axis, axis.size)`.
    """

    def parameterized_entries():
        for tensor in io:
            for axis in tensor.axes:
                if isinstance(axis, BatchAxis):
                    continue
                if isinstance(axis.size, ParameterizedSize):
                    yield f"{tensor.id}.{axis.id}", (tensor, axis, axis.size)

    return dict(parameterized_entries())

3780 

@staticmethod
def _get_axes_with_independent_size(
    io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]],
):
    """Collect all non-batch axes with a fixed (`int`) or `ParameterizedSize`.

    These are the axes other axes may reference via a `SizeReference`.

    Returns:
        A dict mapping `(tensor id, axis id)` to `(tensor, axis, axis.size)`.
    """

    def independent_entries():
        for tensor in io:
            for axis in tensor.axes:
                if isinstance(axis, BatchAxis):
                    continue
                if isinstance(axis.size, (int, ParameterizedSize)):
                    yield (tensor.id, axis.id), (tensor, axis, axis.size)

    return dict(independent_entries())

3792 

@field_validator("outputs", mode="after")
@classmethod
def _validate_output_axes(
    cls, outputs: List[OutputTensorDescr], info: ValidationInfo
) -> List[OutputTensorDescr]:
    """Validate every output axis against the axes it may reference.

    Valid references are all non-batch axes of inputs and outputs that have an
    independent (`int` or `ParameterizedSize`) size. Output entries take
    precedence over input entries with the same key.

    Returns:
        The unchanged **outputs**.
    """
    input_size_refs = cls._get_axes_with_independent_size(
        info.data.get("inputs", [])
    )
    output_size_refs = cls._get_axes_with_independent_size(outputs)

    for i, out in enumerate(outputs):
        # `output_size_refs` already contains this output's own axes, so no extra
        # per-output entries are needed. A fresh dict is still built per output
        # so `_validate_axis` never sees state from a previous iteration.
        valid_independent_refs: Dict[
            Tuple[TensorId, AxisId],
            Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
        ] = {**input_size_refs, **output_size_refs}
        for a, ax in enumerate(out.axes):
            cls._validate_axis(
                "outputs",
                i,
                out.id,
                a,
                ax,
                valid_independent_refs=valid_independent_refs,
            )

    return outputs

3828 

# Defaults to a new empty list; the `cast` only adjusts the static type of `list`.
packaged_by: List[Author] = Field(
    default_factory=cast(Callable[[], List[Author]], list)
)
"""The persons that have packaged and uploaded this model.
Only required if those persons differ from the `authors`."""

# Optional link to a parent model; `_validate_parent_is_not_self` rejects a parent
# whose id equals this model's own id.
parent: Optional[LinkedModel] = None
"""The model from which this model is derived, e.g. by fine-tuning the weights."""

3837 

@model_validator(mode="after")
def _validate_parent_is_not_self(self) -> Self:
    """Reject a `parent` whose id is this model's own id."""
    parent = self.parent
    if parent is None or parent.id != self.id:
        return self

    raise ValueError("A model description may not reference itself as parent.")

3844 

run_mode: Annotated[
    Optional[RunMode],
    warn(None, "Run mode '{value}' has limited support across consumer softwares."),
] = None
"""Custom run mode for this model: for more complex prediction procedures like test time
data augmentation that currently cannot be expressed in the specification.
No standard run modes are defined yet."""

timestamp: Datetime = Field(default_factory=Datetime.now)
"""Timestamp in [ISO 8601](https://en.wikipedia.org/wiki/ISO_8601) format
with a few restrictions listed [here](https://docs.python.org/3/library/datetime.html#datetime.datetime.fromisoformat).
(In Python a datetime object is valid, too)."""

training_data: Annotated[
    Union[None, LinkedDataset, DatasetDescr, DatasetDescr02],
    Field(union_mode="left_to_right"),
] = None
"""The dataset used to train this model."""

weights: Annotated[WeightsDescr, WrapSerializer(package_weights)]
"""The weights for this model.
Weights can be given for different formats, but should otherwise be equivalent.
The available weight formats determine which consumers can use this model."""

# Defaults to `Config.model_construct()`; presumably pydantic's `model_construct`,
# i.e. a `Config` instance created without running validation.
config: Config = Field(default_factory=Config.model_construct)

3870 

@model_validator(mode="after")
def _add_default_cover(self) -> Self:
    """Generate cover images from the test tensors if no covers are given.

    Nothing happens when IO checks are disabled in the validation context or
    when covers already exist. Any exception raised while loading test tensors
    or generating covers is reported via `issue_warning` on the `covers` field
    instead of failing validation.
    """
    io_checks_enabled = get_validation_context().perform_io_checks
    if not io_checks_enabled or self.covers:
        return self

    try:
        input_pairs = [
            (descr, load_array(descr.test_tensor))
            for descr in self.inputs
            if descr.test_tensor is not None
        ]
        output_pairs = [
            (descr, load_array(descr.test_tensor))
            for descr in self.outputs
            if descr.test_tensor is not None
        ]
        generated_covers = generate_covers(input_pairs, output_pairs)
    except Exception as e:
        # best effort: a missing cover is not a validation error
        issue_warning(
            "Failed to generate cover image(s): {e}",
            value=self.covers,
            msg_context={"e": e},
            field="covers",
        )
    else:
        self.covers.extend(generated_covers)

    return self

3900 

def get_input_test_arrays(self) -> List[NDArray[Any]]:
    """Load the test tensor of every input tensor, in declaration order."""
    input_descrs = self.inputs
    return self._get_test_arrays(input_descrs)

3903 

def get_output_test_arrays(self) -> List[NDArray[Any]]:
    """Load the test tensor of every output tensor, in declaration order."""
    output_descrs = self.outputs
    return self._get_test_arrays(output_descrs)

3906 

@staticmethod
def _get_test_arrays(
    io_descr: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]],
):
    """Load the `test_tensor` of each tensor description in **io_descr**.

    All descriptions are checked for a `test_tensor` before any array is loaded.

    Raises:
        ValueError: If a description has no `test_tensor`.
    """
    sources: List[FileDescr] = []
    for descr in io_descr:
        test_tensor = descr.test_tensor
        if test_tensor is None:
            raise ValueError(
                f"Failed to get test arrays: description of '{descr.id}' is missing a `test_tensor`."
            )
        sources.append(test_tensor)

    arrays = [load_array(src) for src in sources]
    assert all(isinstance(arr, np.ndarray) for arr in arrays)
    return arrays

3922 

@staticmethod
def get_batch_size(tensor_sizes: Mapping[TensorId, Mapping[AxisId, int]]) -> int:
    """Derive the common batch size from concrete **tensor_sizes**.

    Batch axes of size 1 are compatible with any batch size; all other batch
    axis sizes must agree.

    Returns:
        The common batch size, or 1 if no batch axis has a size other than 1.

    Raises:
        ValueError: If two tensors have different batch sizes (both != 1).
    """
    found = 1
    found_in: Optional[TensorId] = None
    for tid, sizes in tensor_sizes.items():
        for aid, s in sizes.items():
            if aid != BATCH_AXIS_ID or s in (1, found):
                continue

            if found == 1:
                found, found_in = s, tid
                continue

            assert found_in is not None
            raise ValueError(
                f"batch size mismatch for tensors '{found_in}' ({found}) and '{tid}' ({s})"
            )

    return found

3942 

def get_output_tensor_sizes(
    self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]
) -> Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]:
    """Compute the output tensor sizes for concrete **input_sizes**.

    The result is exact only if **input_sizes** is a valid input shape;
    otherwise it may be larger than the actual (valid) output.
    """
    batch = self.get_batch_size(input_sizes)
    increments = self.get_ns(input_sizes)
    return self.get_tensor_sizes(increments, batch_size=batch).outputs

3954 

def get_ns(self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]):
    """Get the size increment factor `n` for each parameterized input axis
    such that the valid input size is >= the given input size.

    Axes with fixed, referenced or no size descriptions are skipped.
    """
    axes_by_tensor = {t.id: {a.id: a for a in t.axes} for t in self.inputs}
    ns: Dict[Tuple[TensorId, AxisId], ParameterizedSize_N] = {}
    for tensor_id, sizes in input_sizes.items():
        for axis_id, size in sizes.items():
            descr = axes_by_tensor[tensor_id][axis_id].size
            if isinstance(descr, ParameterizedSize):
                ns[(tensor_id, axis_id)] = descr.get_n(size)
            elif not (descr is None or isinstance(descr, (int, SizeReference))):
                assert_never(descr)

    return ns

3971 

def get_tensor_sizes(
    self, ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], batch_size: int
) -> _TensorSizes:
    """Resolve all axis sizes and group them per tensor.

    Args:
        ns: Size increment factor `n` for each parameterized input axis.
        batch_size: Size of the batch dimension.

    Returns:
        Input and output sizes as `{tensor_id: {axis_id: size}}` mappings.
        Tensors and axes appear in the order produced by `get_axis_sizes`.
    """
    axis_sizes = self.get_axis_sizes(ns, batch_size=batch_size)

    def group_by_tensor(
        sizes: Mapping[Tuple[TensorId, AxisId], Any],
    ) -> Dict[TensorId, Dict[AxisId, Any]]:
        # Single pass over all axes. Insertion order keeps a deterministic tensor
        # order, unlike iterating a `set` of tensor ids, whose order depends on
        # the hash seed.
        grouped: Dict[TensorId, Dict[AxisId, Any]] = {}
        for (tensor_id, axis_id), size in sizes.items():
            grouped.setdefault(tensor_id, {})[axis_id] = size
        return grouped

    return _TensorSizes(
        group_by_tensor(axis_sizes.inputs),
        group_by_tensor(axis_sizes.outputs),
    )

3994 

def get_axis_sizes(
    self,
    ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N],
    batch_size: Optional[int] = None,
    *,
    max_input_shape: Optional[Mapping[Tuple[TensorId, AxisId], int]] = None,
) -> _AxisSizes:
    """Determine input and output block shape for scale factors **ns**
    of parameterized input sizes.

    Args:
        ns: Scale factor `n` for each axis (keyed by (tensor_id, axis_id))
            that is parameterized as `size = min + n * step`.
        batch_size: The desired size of the batch dimension.
            If given, **batch_size** overrides any batch size present in
            **max_input_shape**. If omitted, the first batch axis entry in
            **max_input_shape** is used, falling back to 1.
        max_input_shape: Limits the derived block shapes.
            For each parameterized input axis listed here, `n` is clamped to
            `size.get_n(max_input_shape[...])`, so the derived input size does
            not grow beyond what that maximum size requires.
            Use this for small input samples or large values of **ns**.
            Or simply whenever you know the full input shape.

    Returns:
        Resolved axis sizes for model inputs and outputs.

    Raises:
        ValueError: If **ns** has no entry for a parameterized input axis.
    """
    max_input_shape = max_input_shape or {}
    if batch_size is None:
        for (_t_id, a_id), s in max_input_shape.items():
            if a_id == BATCH_AXIS_ID:
                batch_size = s
                break
        else:  # for-else: no batch axis entry in max_input_shape
            batch_size = 1

    # all axis descriptions, keyed by tensor id and axis id (inputs and outputs)
    all_axes = {
        t.id: {a.id: a for a in t.axes} for t in chain(self.inputs, self.outputs)
    }

    inputs: Dict[Tuple[TensorId, AxisId], int] = {}
    outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]] = {}

    # NOTE: this closure reads `t_descr` from the enclosing scope; it is only
    # meaningful while one of the `for t_descr in ...` loops below is running.
    def get_axis_size(a: Union[InputAxis, OutputAxis]):
        if isinstance(a, BatchAxis):
            if (t_descr.id, a.id) in ns:
                logger.warning(
                    "Ignoring unexpected size increment factor (n) for batch axis"
                    + " of tensor '{}'.",
                    t_descr.id,
                )
            return batch_size
        elif isinstance(a.size, int):
            if (t_descr.id, a.id) in ns:
                logger.warning(
                    "Ignoring unexpected size increment factor (n) for fixed size"
                    + " axis '{}' of tensor '{}'.",
                    a.id,
                    t_descr.id,
                )
            return a.size
        elif isinstance(a.size, ParameterizedSize):
            if (t_descr.id, a.id) not in ns:
                raise ValueError(
                    "Size increment factor (n) missing for parametrized axis"
                    + f" '{a.id}' of tensor '{t_descr.id}'."
                )
            n = ns[(t_descr.id, a.id)]
            s_max = max_input_shape.get((t_descr.id, a.id))
            if s_max is not None:
                # clamp n so the size does not exceed what s_max requires
                n = min(n, a.size.get_n(s_max))

            return a.size.get_size(n)

        elif isinstance(a.size, SizeReference):
            if (t_descr.id, a.id) in ns:
                logger.warning(
                    "Ignoring unexpected size increment factor (n) for axis '{}'"
                    + " of tensor '{}' with size reference.",
                    a.id,
                    t_descr.id,
                )
            assert not isinstance(a, BatchAxis)
            ref_axis = all_axes[a.size.tensor_id][a.size.axis_id]
            assert not isinstance(ref_axis, BatchAxis)
            ref_key = (a.size.tensor_id, a.size.axis_id)
            # The referenced size must already be resolved (see the pass order
            # below); otherwise the assertion fails.
            ref_size = inputs.get(ref_key, outputs.get(ref_key))
            assert ref_size is not None, ref_key
            assert not isinstance(ref_size, _DataDepSize), ref_key
            return a.size.get_size(
                axis=a,
                ref_axis=ref_axis,
                ref_size=ref_size,
            )
        elif isinstance(a.size, DataDependentSize):
            if (t_descr.id, a.id) in ns:
                logger.warning(
                    "Ignoring unexpected increment factor (n) for data dependent"
                    + " size axis '{}' of tensor '{}'.",
                    a.id,
                    t_descr.id,
                )
            return _DataDepSize(a.size.min, a.size.max)
        else:
            assert_never(a.size)

    # first resolve all input axis sizes except those given by a `SizeReference`
    for t_descr in self.inputs:
        for a in t_descr.axes:
            if not isinstance(a.size, SizeReference):
                s = get_axis_size(a)
                assert not isinstance(s, _DataDepSize)
                inputs[t_descr.id, a.id] = s

    # resolve all other input axis sizes
    # NOTE(review): an input `SizeReference` pointing at another input axis that is
    # itself a `SizeReference` only resolves if that axis was handled earlier in
    # this loop — confirm such chains are excluded by validation.
    for t_descr in self.inputs:
        for a in t_descr.axes:
            if isinstance(a.size, SizeReference):
                s = get_axis_size(a)
                assert not isinstance(s, _DataDepSize)
                inputs[t_descr.id, a.id] = s

    # resolve all output axis sizes
    # (references to other output axes only resolve if that output was handled earlier)
    for t_descr in self.outputs:
        for a in t_descr.axes:
            assert not isinstance(a.size, ParameterizedSize)
            s = get_axis_size(a)
            outputs[t_descr.id, a.id] = s

    return _AxisSizes(inputs=inputs, outputs=outputs)

4124 

@model_validator(mode="before")
@classmethod
def _convert(cls, data: Dict[str, Any]) -> Dict[str, Any]:
    """Before field validation, upgrade raw data of older format versions.

    The upgrade mutates **data** in place; the same dict object is returned.
    """
    upgrade_in_place = cls.convert_from_old_format_wo_validation
    upgrade_in_place(data)
    return data

4130 

@classmethod
def convert_from_old_format_wo_validation(cls, data: Dict[str, Any]) -> None:
    """Convert metadata following an older format version to this class's format
    without validating the result.

    **data** is modified in place. Only model descriptions with a numeric
    `major.minor.patch` format version are touched:
    - 0.3.x / 0.4.x: converted via the 0.4 model description. If that description
      is invalid, conversion is attempted anyway; if it fails, the data is left
      as is and an error is logged.
    - 0.5.x: the format version is bumped to the implemented 0.5 patch version.
    For all of these versions below 0.5.11, plain covers/documentation/icon
    entries are converted afterwards.
    """
    if (
        data.get("type") == "model"
        and isinstance(fv := data.get("format_version"), str)
        and fv.count(".") == 2
    ):
        fv_parts = fv.split(".")
        # `isdecimal` matches exactly what `int()` accepts; `isdigit` also accepts
        # characters like "²" that would make `int()` raise.
        if any(not p.isdecimal() for p in fv_parts):
            return

        fv_tuple = tuple(map(int, fv_parts))

        assert cls.implemented_format_version_tuple[0:2] == (0, 5)
        if fv_tuple[:2] in ((0, 3), (0, 4)):
            m04 = _ModelDescr_v0_4.load(data)
            if isinstance(m04, InvalidDescr):
                try:
                    updated = _model_conv.convert_as_dict(
                        m04  # pyright: ignore[reportArgumentType]
                    )
                except Exception as e:
                    logger.error(
                        "Failed to convert from invalid model 0.4 description."
                        + f"\nerror: {e}"
                        + "\nProceeding with model 0.5 validation without conversion."
                    )
                    updated = None
            else:
                updated = _model_conv.convert_as_dict(m04)

            if updated is not None:
                data.clear()
                data.update(updated)

        elif fv_tuple[:2] == (0, 5):
            # bump patch version
            data["format_version"] = cls.implemented_format_version

        if fv_tuple[:2] in ((0, 3), (0, 4)) or (
            fv_tuple[:2] == (0, 5) and fv_tuple[2] < 11
        ):
            convert_plain_covers_and_docs_and_icon(data)

4177 

4178 

class _ModelConv(Converter[_ModelDescr_v0_4, ModelDescr]):
    """Converter from a model 0.4 description to the model 0.5 format.

    At runtime nested descriptions are built as plain dicts (the description
    classes are only selected under `TYPE_CHECKING`) and handed to `tgt(...)`;
    presumably `tgt` validates them — TODO confirm in `Converter`.
    """

    def _convert(
        self, src: _ModelDescr_v0_4, tgt: "type[ModelDescr] | type[dict[str, Any]]"
    ) -> "ModelDescr | dict[str, Any]":
        # replace every character not allowed in a 0.5 model name with a space
        name = "".join(
            c if c in string.ascii_letters + string.digits + "_+- ()" else " "
            for c in src.name
        )

        # converts an optional list of 0.4 authors (used for weights authors below)
        def conv_authors(auths: Optional[Sequence[_Author_v0_4]]):
            conv = (
                _author_conv.convert if TYPE_CHECKING else _author_conv.convert_as_dict
            )
            return None if auths is None else [conv(a) for a in auths]

        # typed converters for static checking, dict-producing ones at runtime
        if TYPE_CHECKING:
            arch_file_conv = _arch_file_conv.convert
            arch_lib_conv = _arch_lib_conv.convert
        else:
            arch_file_conv = _arch_file_conv.convert_as_dict
            arch_lib_conv = _arch_lib_conv.convert_as_dict

        # per input tensor name: {axis: size}, using the minimal size for
        # parameterized input shapes
        input_size_refs = {
            ipt.name: {
                a: s
                for a, s in zip(
                    ipt.axes,
                    (
                        ipt.shape.min
                        if isinstance(ipt.shape, _ParameterizedInputShape_v0_4)
                        else ipt.shape
                    ),
                )
            }
            for ipt in src.inputs
            if ipt.shape
        }
        # explicit output shapes plus all input size refs; on a name clash the
        # input entry wins because it is unpacked last
        output_size_refs = {
            **{
                out.name: {a: s for a, s in zip(out.axes, out.shape)}
                for out in src.outputs
                if not isinstance(out.shape, _ImplicitOutputShape_v0_4)
            },
            **input_size_refs,
        }

        return tgt(
            attachments=(
                []
                if src.attachments is None
                else [FileDescr(source=f) for f in src.attachments.files]
            ),
            authors=[_author_conv.convert_as_dict(a) for a in src.authors],  # pyright: ignore[reportArgumentType]
            cite=[{"text": c.text, "doi": c.doi, "url": c.url} for c in src.cite],  # pyright: ignore[reportArgumentType]
            config=src.config,  # pyright: ignore[reportArgumentType]
            covers=[{"source": c} for c in src.covers],  # pyright: ignore[reportArgumentType]
            description=src.description,
            documentation={"source": src.documentation} if src.documentation else None,  # pyright: ignore[reportArgumentType]
            format_version="0.5.11",
            git_repo=src.git_repo,  # pyright: ignore[reportArgumentType]
            icon={"source": src.icon} if src.icon else None,  # pyright: ignore[reportArgumentType]
            id=None if src.id is None else ModelId(src.id),
            id_emoji=src.id_emoji,
            license=src.license,  # type: ignore
            links=src.links,
            maintainers=[_maintainer_conv.convert_as_dict(m) for m in src.maintainers],  # pyright: ignore[reportArgumentType]
            name=name,
            tags=src.tags,
            type=src.type,
            uploader=src.uploader,
            version=src.version,
            # NOTE: `zip` stops at the shortest sequence; missing sample tensors
            # are padded with None
            inputs=[  # pyright: ignore[reportArgumentType]
                _input_tensor_conv.convert_as_dict(ipt, tt, st, input_size_refs)
                for ipt, tt, st in zip(
                    src.inputs,
                    src.test_inputs,
                    src.sample_inputs or [None] * len(src.test_inputs),
                )
            ],
            outputs=[  # pyright: ignore[reportArgumentType]
                _output_tensor_conv.convert_as_dict(out, tt, st, output_size_refs)
                for out, tt, st in zip(
                    src.outputs,
                    src.test_outputs,
                    src.sample_outputs or [None] * len(src.test_outputs),
                )
            ],
            # 0.4 parent id and version number are merged into "<id>/<version>"
            parent=(
                None
                if src.parent is None
                else LinkedModel(
                    id=ModelId(
                        str(src.parent.id)
                        + (
                            ""
                            if src.parent.version_number is None
                            else f"/{src.parent.version_number}"
                        )
                    )
                )
            ),
            # linked datasets get the same "<id>/<version>" id; other dataset
            # descriptions are passed through unchanged
            training_data=(
                None
                if src.training_data is None
                else (
                    LinkedDataset(
                        id=DatasetId(
                            str(src.training_data.id)
                            + (
                                ""
                                if src.training_data.version_number is None
                                else f"/{src.training_data.version_number}"
                            )
                        )
                    )
                    if isinstance(src.training_data, LinkedDataset02)
                    else src.training_data
                )
            ),
            packaged_by=[_author_conv.convert_as_dict(a) for a in src.packaged_by],  # pyright: ignore[reportArgumentType]
            run_mode=src.run_mode,
            timestamp=src.timestamp,
            # Each weights entry below rebinds `w` via `:=`; keyword arguments are
            # evaluated left to right, so every `dict(...)` reads the entry bound
            # just before it. `(w := x) and dict(...)` keeps a missing (falsy)
            # entry as is. Missing framework versions get fixed defaults.
            weights=(WeightsDescr if TYPE_CHECKING else dict)(
                keras_hdf5=(w := src.weights.keras_hdf5)
                and (KerasHdf5WeightsDescr if TYPE_CHECKING else dict)(
                    authors=conv_authors(w.authors),
                    source=w.source,
                    tensorflow_version=w.tensorflow_version or Version("1.15"),
                    parent=w.parent,
                ),
                onnx=(w := src.weights.onnx)
                and (OnnxWeightsDescr if TYPE_CHECKING else dict)(
                    source=w.source,
                    authors=conv_authors(w.authors),
                    parent=w.parent,
                    opset_version=w.opset_version or 15,
                ),
                pytorch_state_dict=(w := src.weights.pytorch_state_dict)
                and (PytorchStateDictWeightsDescr if TYPE_CHECKING else dict)(
                    source=w.source,
                    authors=conv_authors(w.authors),
                    parent=w.parent,
                    architecture=(
                        arch_file_conv(
                            w.architecture,
                            w.architecture_sha256,
                            w.kwargs,
                        )
                        if isinstance(w.architecture, _CallableFromFile_v0_4)
                        else arch_lib_conv(w.architecture, w.kwargs)
                    ),
                    pytorch_version=w.pytorch_version or Version("1.10"),
                    # strip a leading "conda:" from the 0.4 dependencies string
                    dependencies=(
                        None
                        if w.dependencies is None
                        else (FileDescr if TYPE_CHECKING else dict)(
                            source=cast(
                                FileSource,
                                str(deps := w.dependencies)[
                                    (
                                        len("conda:")
                                        if str(deps).startswith("conda:")
                                        else 0
                                    ) :
                                ],
                            )
                        )
                    ),
                ),
                tensorflow_js=(w := src.weights.tensorflow_js)
                and (TensorflowJsWeightsDescr if TYPE_CHECKING else dict)(
                    source=w.source,
                    authors=conv_authors(w.authors),
                    parent=w.parent,
                    tensorflow_version=w.tensorflow_version or Version("1.15"),
                ),
                tensorflow_saved_model_bundle=(
                    w := src.weights.tensorflow_saved_model_bundle
                )
                and (TensorflowSavedModelBundleWeightsDescr if TYPE_CHECKING else dict)(
                    authors=conv_authors(w.authors),
                    parent=w.parent,
                    source=w.source,
                    tensorflow_version=w.tensorflow_version or Version("1.15"),
                    # strip a leading "conda:" from the 0.4 dependencies string
                    dependencies=(
                        None
                        if w.dependencies is None
                        else (FileDescr if TYPE_CHECKING else dict)(
                            source=cast(
                                FileSource,
                                (
                                    str(w.dependencies)[len("conda:") :]
                                    if str(w.dependencies).startswith("conda:")
                                    else str(w.dependencies)
                                ),
                            )
                        )
                    ),
                ),
                torchscript=(w := src.weights.torchscript)
                and (TorchscriptWeightsDescr if TYPE_CHECKING else dict)(
                    source=w.source,
                    authors=conv_authors(w.authors),
                    parent=w.parent,
                    pytorch_version=w.pytorch_version or Version("1.10"),
                ),
            ),
        )

4387 

4388 

# module-level converter instance, used by `ModelDescr.convert_from_old_format_wo_validation`
_model_conv = _ModelConv(_ModelDescr_v0_4, ModelDescr)

4390 

4391 

4392# create better cover images for 3d data and non-image outputs 

4393def generate_covers( 

4394 inputs: Sequence[Tuple[InputTensorDescr, NDArray[Any]]], 

4395 outputs: Sequence[Tuple[OutputTensorDescr, NDArray[Any]]], 

4396) -> List[FileDescr]: 

4397 def squeeze( 

4398 data: NDArray[Any], axes: Sequence[AnyAxis] 

4399 ) -> Tuple[NDArray[Any], List[AnyAxis]]: 

4400 """apply numpy.ndarray.squeeze while keeping track of the axis descriptions remaining""" 

4401 if data.ndim != len(axes): 

4402 raise ValueError( 

4403 f"tensor shape {data.shape} does not match described axes" 

4404 + f" {[a.id for a in axes]}" 

4405 ) 

4406 

4407 axes = [deepcopy(a) for a, s in zip(axes, data.shape) if s != 1] 

4408 return data.squeeze(), axes 

4409 

4410 def normalize( 

4411 data: NDArray[Any], axis: Optional[Tuple[int, ...]], eps: float = 1e-7 

4412 ) -> NDArray[np.float32]: 

4413 data = data.astype("float32") 

4414 data -= data.min(axis=axis, keepdims=True) 

4415 data /= data.max(axis=axis, keepdims=True) + eps 

4416 return data 

4417 

4418 def to_2d_image(data: NDArray[Any], axes: Sequence[AnyAxis]): 

4419 original_shape = data.shape 

4420 original_axes = list(axes) 

4421 data, axes = squeeze(data, axes) 

4422 

4423 # take slice fom any batch or index axis if needed 

4424 # and convert the first channel axis and take a slice from any additional channel axes 

4425 slices: Tuple[slice, ...] = () 

4426 ndim = data.ndim 

4427 ndim_need = 3 if any(isinstance(a, ChannelAxis) for a in axes) else 2 

4428 has_c_axis = False 

4429 for i, a in enumerate(axes): 

4430 s = data.shape[i] 

4431 assert s > 1 

4432 if ( 

4433 isinstance(a, (BatchAxis, IndexInputAxis, IndexOutputAxis)) 

4434 and ndim > ndim_need 

4435 ): 

4436 data = data[slices + (slice(s // 2 - 1, s // 2),)] 

4437 ndim -= 1 

4438 elif isinstance(a, ChannelAxis): 

4439 if has_c_axis: 

4440 # second channel axis 

4441 data = data[slices + (slice(0, 1),)] 

4442 ndim -= 1 

4443 else: 

4444 has_c_axis = True 

4445 if s == 2: 

4446 # visualize two channels with cyan and magenta 

4447 data = np.concatenate( 

4448 [ 

4449 data[slices + (slice(1, 2),)], 

4450 data[slices + (slice(0, 1),)], 

4451 ( 

4452 data[slices + (slice(0, 1),)] 

4453 + data[slices + (slice(1, 2),)] 

4454 ) 

4455 / 2, # TODO: take maximum instead? 

4456 ], 

4457 axis=i, 

4458 ) 

4459 elif data.shape[i] == 3: 

4460 pass # visualize 3 channels as RGB 

4461 else: 

4462 # visualize first 3 channels as RGB 

4463 data = data[slices + (slice(3),)] 

4464 

4465 assert data.shape[i] == 3 

4466 

4467 slices += (slice(None),) 

4468 

4469 data, axes = squeeze(data, axes) 

4470 assert len(axes) == ndim 

4471 # take slice from z axis if needed 

4472 slices = () 

4473 if ndim > ndim_need: 

4474 for i, a in enumerate(axes): 

4475 s = data.shape[i] 

4476 if a.id == AxisId("z"): 

4477 data = data[slices + (slice(s // 2 - 1, s // 2),)] 

4478 data, axes = squeeze(data, axes) 

4479 ndim -= 1 

4480 break 

4481 

4482 slices += (slice(None),) 

4483 

4484 # take slice from any space or time axis 

4485 slices = () 

4486 

4487 for i, a in enumerate(axes): 

4488 if ndim <= ndim_need: 

4489 break 

4490 

4491 s = data.shape[i] 

4492 assert s > 1 

4493 if isinstance( 

4494 a, (SpaceInputAxis, SpaceOutputAxis, TimeInputAxis, TimeOutputAxis) 

4495 ): 

4496 data = data[slices + (slice(s // 2 - 1, s // 2),)] 

4497 ndim -= 1 

4498 

4499 slices += (slice(None),) 

4500 

4501 del slices 

4502 data, axes = squeeze(data, axes) 

4503 assert len(axes) == ndim 

4504 

4505 if (has_c_axis and ndim != 3) or (not has_c_axis and ndim != 2): 

4506 raise ValueError( 

4507 f"Failed to construct cover image from shape {original_shape} with axes {[a.id for a in original_axes]}." 

4508 ) 

4509 

4510 if not has_c_axis: 

4511 assert ndim == 2 

4512 data = np.repeat(data[:, :, None], 3, axis=2) 

4513 axes.append(ChannelAxis(channel_names=list("RGB"))) 

4514 ndim += 1 

4515 

4516 assert ndim == 3 

4517 

4518 # transpose axis order such that longest axis comes first... 

4519 axis_order: List[int] = list(np.argsort(list(data.shape))) 

4520 axis_order.reverse() 

4521 # ... and channel axis is last 

4522 c = [i for i in range(3) if isinstance(axes[i], ChannelAxis)][0] 

4523 axis_order.append(axis_order.pop(c)) 

4524 axes = [axes[ao] for ao in axis_order] 

4525 data = data.transpose(axis_order) 

4526 

4527 # h, w = data.shape[:2] 

4528 # if h / w in (1.0 or 2.0): 

4529 # pass 

4530 # elif h / w < 2: 

4531 # TODO: enforce 2:1 or 1:1 aspect ratio for generated cover images 

4532 

4533 norm_along = ( 

4534 tuple(i for i, a in enumerate(axes) if a.type in ("space", "time")) or None 

4535 ) 

4536 # normalize the data and map to 8 bit 

4537 data = normalize(data, norm_along) 

4538 data = (data * 255).astype("uint8") 

4539 

4540 return data 

4541 

4542 def create_diagonal_split_image(im0: NDArray[Any], im1: NDArray[Any]): 

4543 assert im0.dtype == im1.dtype == np.uint8 

4544 assert im0.shape == im1.shape 

4545 assert im0.ndim == 3 

4546 N, M, C = im0.shape 

4547 assert C == 3 

4548 out = np.ones((N, M, C), dtype="uint8") 

4549 for c in range(C): 

4550 outc = np.tril(im0[..., c]) 

4551 mask = outc == 0 

4552 outc[mask] = np.triu(im1[..., c])[mask] 

4553 out[..., c] = outc 

4554 

4555 return out 

4556 

4557 if not inputs: 

4558 raise ValueError("Missing test input tensor for cover generation.") 

4559 

4560 if not outputs: 

4561 raise ValueError("Missing test output tensor for cover generation.") 

4562 

4563 ipt_descr, ipt = inputs[0] 

4564 out_descr, out = outputs[0] 

4565 

4566 ipt_img = to_2d_image(ipt, ipt_descr.axes) 

4567 out_img = to_2d_image(out, out_descr.axes) 

4568 

4569 cover_folder = Path(mkdtemp()) 

4570 if ipt_img.shape == out_img.shape: 

4571 covers = [cover_folder / "cover.png"] 

4572 imwrite(covers[0], create_diagonal_split_image(ipt_img, out_img)) 

4573 else: 

4574 covers = [cover_folder / "input.png", cover_folder / "output.png"] 

4575 imwrite(covers[0], ipt_img) 

4576 imwrite(covers[1], out_img) 

4577 

4578 return [FileDescr(source=c) for c in covers]