from __future__ import annotations

import collections.abc
import re
import string
import warnings
from copy import deepcopy
from itertools import chain
from math import ceil
from pathlib import Path, PurePosixPath
from tempfile import mkdtemp
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    ClassVar,
    Dict,
    Generic,
    List,
    Literal,
    Mapping,
    NamedTuple,
    Optional,
    Sequence,
    Set,
    Tuple,
    Type,
    TypeVar,
    Union,
    cast,
    overload,
)

import numpy as np
from annotated_types import Ge, Gt, Interval, MaxLen, MinLen, Predicate
from imageio.v3 import imread, imwrite  # pyright: ignore[reportUnknownVariableType]
from loguru import logger
from numpy.typing import NDArray
from pydantic import (
    AfterValidator,
    Discriminator,
    Field,
    RootModel,
    SerializationInfo,
    SerializerFunctionWrapHandler,
    StrictInt,
    Tag,
    ValidationInfo,
    WrapSerializer,
    field_validator,
    model_serializer,
    model_validator,
)
from typing_extensions import Annotated, Self, assert_never, get_args

from .._internal.common_nodes import (
    InvalidDescr,
    KwargsNode,
    Node,
    NodeWithExplicitlySetFields,
)
from .._internal.constants import DTYPE_LIMITS
from .._internal.field_warning import issue_warning, warn
from .._internal.io import BioimageioYamlContent as BioimageioYamlContent
from .._internal.io import FileDescr as FileDescr
from .._internal.io import (
    FileSource,
    WithSuffix,
    YamlValue,
    extract_file_name,
    get_reader,
    wo_special_file_name,
)
from .._internal.io_basics import Sha256 as Sha256
from .._internal.io_packaging import (
    FileDescr_,
    FileSource_,
    package_file_descr_serializer,
)
from .._internal.io_utils import load_array
from .._internal.node_converter import Converter
from .._internal.type_guards import is_dict, is_sequence
from .._internal.types import (
    FAIR,
    AbsoluteTolerance,
    LowerCaseIdentifier,
    LowerCaseIdentifierAnno,
    MismatchedElementsPerMillion,
    RelativeTolerance,
)
from .._internal.types import Datetime as Datetime
from .._internal.types import Identifier as Identifier
from .._internal.types import NotEmpty as NotEmpty
from .._internal.types import SiUnit as SiUnit
from .._internal.url import HttpUrl as HttpUrl
from .._internal.validation_context import get_validation_context
from .._internal.validator_annotations import RestrictCharacters
from .._internal.version_type import Version as Version
from .._internal.warning_levels import INFO
from ..dataset.v0_2 import DatasetDescr as DatasetDescr02
from ..dataset.v0_2 import LinkedDataset as LinkedDataset02
from ..dataset.v0_3 import DatasetDescr as DatasetDescr
from ..dataset.v0_3 import DatasetId as DatasetId
from ..dataset.v0_3 import LinkedDataset as LinkedDataset
from ..dataset.v0_3 import Uploader as Uploader
from ..generic.v0_3 import (
    VALID_COVER_IMAGE_EXTENSIONS as VALID_COVER_IMAGE_EXTENSIONS,
)
from ..generic.v0_3 import Author as Author
from ..generic.v0_3 import BadgeDescr as BadgeDescr
from ..generic.v0_3 import CiteEntry as CiteEntry
from ..generic.v0_3 import DeprecatedLicenseId as DeprecatedLicenseId
from ..generic.v0_3 import Doi as Doi
from ..generic.v0_3 import (
    FileSource_documentation,
    GenericModelDescrBase,
    LinkedResourceBase,
    _author_conv,  # pyright: ignore[reportPrivateUsage]
    _maintainer_conv,  # pyright: ignore[reportPrivateUsage]
)
from ..generic.v0_3 import LicenseId as LicenseId
from ..generic.v0_3 import LinkedResource as LinkedResource
from ..generic.v0_3 import Maintainer as Maintainer
from ..generic.v0_3 import OrcidId as OrcidId
from ..generic.v0_3 import RelativeFilePath as RelativeFilePath
from ..generic.v0_3 import ResourceId as ResourceId
from .v0_4 import Author as _Author_v0_4
from .v0_4 import BinarizeDescr as _BinarizeDescr_v0_4
from .v0_4 import CallableFromDepencency as CallableFromDepencency
from .v0_4 import CallableFromDepencency as _CallableFromDepencency_v0_4
from .v0_4 import CallableFromFile as _CallableFromFile_v0_4
from .v0_4 import ClipDescr as _ClipDescr_v0_4
from .v0_4 import ImplicitOutputShape as _ImplicitOutputShape_v0_4
from .v0_4 import InputTensorDescr as _InputTensorDescr_v0_4
from .v0_4 import KnownRunMode as KnownRunMode
from .v0_4 import ModelDescr as _ModelDescr_v0_4
from .v0_4 import OutputTensorDescr as _OutputTensorDescr_v0_4
from .v0_4 import ParameterizedInputShape as _ParameterizedInputShape_v0_4
from .v0_4 import PostprocessingDescr as _PostprocessingDescr_v0_4
from .v0_4 import PreprocessingDescr as _PreprocessingDescr_v0_4
from .v0_4 import RunMode as RunMode
from .v0_4 import ScaleLinearDescr as _ScaleLinearDescr_v0_4
from .v0_4 import ScaleMeanVarianceDescr as _ScaleMeanVarianceDescr_v0_4
from .v0_4 import ScaleRangeDescr as _ScaleRangeDescr_v0_4
from .v0_4 import SigmoidDescr as _SigmoidDescr_v0_4
from .v0_4 import TensorName as _TensorName_v0_4
from .v0_4 import WeightsFormat as WeightsFormat
from .v0_4 import ZeroMeanUnitVarianceDescr as _ZeroMeanUnitVarianceDescr_v0_4
from .v0_4 import package_weights

SpaceUnit = Literal[
    "attometer",
    "angstrom",
    "centimeter",
    "decimeter",
    "exameter",
    "femtometer",
    "foot",
    "gigameter",
    "hectometer",
    "inch",
    "kilometer",
    "megameter",
    "meter",
    "micrometer",
    "mile",
    "millimeter",
    "nanometer",
    "parsec",
    "petameter",
    "picometer",
    "terameter",
    "yard",
    "yoctometer",
    "yottameter",
    "zeptometer",
    "zettameter",
]
"""Space unit compatible with the [OME-Zarr axes specification 0.5](https://ngff.openmicroscopy.org/0.5/#axes-md)"""

TimeUnit = Literal[
    "attosecond",
    "centisecond",
    "day",
    "decisecond",
    "exasecond",
    "femtosecond",
    "gigasecond",
    "hectosecond",
    "hour",
    "kilosecond",
    "megasecond",
    "microsecond",
    "millisecond",
    "minute",
    "nanosecond",
    "petasecond",
    "picosecond",
    "second",
    "terasecond",
    "yoctosecond",
    "yottasecond",
    "zeptosecond",
    "zettasecond",
]
"""Time unit compatible with the [OME-Zarr axes specification 0.5](https://ngff.openmicroscopy.org/0.5/#axes-md)"""

AxisType = Literal["batch", "channel", "index", "time", "space"]

_AXIS_TYPE_MAP: Mapping[str, AxisType] = {
    "b": "batch",
    "t": "time",
    "i": "index",
    "c": "channel",
    "x": "space",
    "y": "space",
    "z": "space",
}

_AXIS_ID_MAP = {
    "b": "batch",
    "t": "time",
    "i": "index",
    "c": "channel",
}


class TensorId(LowerCaseIdentifier):
    root_model: ClassVar[Type[RootModel[Any]]] = RootModel[
        Annotated[LowerCaseIdentifierAnno, MaxLen(32)]
    ]


def _normalize_axis_id(a: str):
    a = str(a)
    normalized = _AXIS_ID_MAP.get(a, a)
    if a != normalized:
        logger.opt(depth=3).warning(
            "Normalized axis id from '{}' to '{}'.", a, normalized
        )
    return normalized


class AxisId(LowerCaseIdentifier):
    root_model: ClassVar[Type[RootModel[Any]]] = RootModel[
        Annotated[
            LowerCaseIdentifierAnno,
            MaxLen(16),
            AfterValidator(_normalize_axis_id),
        ]
    ]
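
# Illustration (editor's sketch, not part of the original module): `AxisId`
# normalizes single-letter ids via `_AXIS_ID_MAP`, so "b", "t", "i" and "c"
# become their long forms while spatial ids pass through unchanged:
#
#     assert AxisId("c") == "channel"  # normalized; a warning is logged
#     assert AxisId("x") == "x"        # spatial ids are kept as-is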


def _is_batch(a: str) -> bool:
    return str(a) == "batch"


def _is_not_batch(a: str) -> bool:
    return not _is_batch(a)


NonBatchAxisId = Annotated[AxisId, Predicate(_is_not_batch)]

PreprocessingId = Literal[
    "binarize",
    "clip",
    "ensure_dtype",
    "fixed_zero_mean_unit_variance",
    "scale_linear",
    "scale_range",
    "sigmoid",
    "softmax",
]
PostprocessingId = Literal[
    "binarize",
    "clip",
    "ensure_dtype",
    "fixed_zero_mean_unit_variance",
    "scale_linear",
    "scale_mean_variance",
    "scale_range",
    "sigmoid",
    "softmax",
    "zero_mean_unit_variance",
]


SAME_AS_TYPE = "<same as type>"


ParameterizedSize_N = int
"""
Annotates an integer to calculate a concrete axis size from a `ParameterizedSize`.
"""


class ParameterizedSize(Node):
    """Describes a range of valid tensor axis sizes as `size = min + n*step`.

    - **min** and **step** are given by the model description.
    - All block size parameters n = 0,1,2,... yield a valid `size`.
    - A greater block size parameter n = 0,1,2,... results in a greater **size**.
      This allows adjusting the axis size more generically.
    """

    N: ClassVar[Type[int]] = ParameterizedSize_N
    """Positive integer to parameterize this axis"""

    min: Annotated[int, Gt(0)]
    step: Annotated[int, Gt(0)]

    def validate_size(self, size: int) -> int:
        if size < self.min:
            raise ValueError(f"size {size} < {self.min}")
        if (size - self.min) % self.step != 0:
            raise ValueError(
                f"axis of size {size} is not parameterized by `min + n*step` ="
                + f" `{self.min} + n*{self.step}`"
            )

        return size

    def get_size(self, n: ParameterizedSize_N) -> int:
        return self.min + self.step * n

    def get_n(self, s: int) -> ParameterizedSize_N:
        """return the smallest n parameterizing a size greater than or equal to `s`"""
        return ceil((s - self.min) / self.step)
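
# Illustration (editor's sketch, not part of the original module): a
# `ParameterizedSize(min=32, step=16)` admits the sizes 32, 48, 64, ...
#
#     p = ParameterizedSize(min=32, step=16)
#     assert p.get_size(2) == 64        # 32 + 2 * 16
#     assert p.get_n(50) == 2           # smallest n with get_size(n) >= 50
#     assert p.validate_size(48) == 48  # 48 = 32 + 1 * 16 is valid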


class DataDependentSize(Node):
    min: Annotated[int, Gt(0)] = 1
    max: Annotated[Optional[int], Gt(1)] = None

    @model_validator(mode="after")
    def _validate_max_gt_min(self):
        if self.max is not None and self.min >= self.max:
            raise ValueError(f"expected `min` < `max`, but got {self.min}, {self.max}")

        return self

    def validate_size(self, size: int) -> int:
        if size < self.min:
            raise ValueError(f"size {size} < {self.min}")

        if self.max is not None and size > self.max:
            raise ValueError(f"size {size} > {self.max}")

        return size
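
# Illustration (editor's sketch, not part of the original module): a
# `DataDependentSize` bounds an output axis whose exact size is only known
# after inference; `validate_size` enforces min <= size (<= max):
#
#     d = DataDependentSize(min=1, max=512)
#     assert d.validate_size(100) == 100
#     # d.validate_size(513) would raise ValueError: size 513 > 512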


class SizeReference(Node):
    """A tensor axis size (extent in pixels/frames) defined in relation to a reference axis.

    `axis.size = reference.size * reference.scale / axis.scale + offset`

    Note:
    1. The axis and the referenced axis need to have the same unit (or no unit).
    2. Batch axes may not be referenced.
    3. Fractions are rounded down.
    4. If the reference axis is `concatenable` the referencing axis is assumed to be
       `concatenable` as well with the same block order.

    Example:
    An anisotropic input image of w*h=100*49 pixels depicts a physical space of 200*196mm².
    Let's assume that we want to express the image height h in relation to its width w
    instead of only accepting input images of exactly 100*49 pixels
    (for example to express a range of valid image shapes by parametrizing w, see `ParameterizedSize`).

    >>> w = SpaceInputAxis(id=AxisId("w"), size=100, unit="millimeter", scale=2)
    >>> h = SpaceInputAxis(
    ...     id=AxisId("h"),
    ...     size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("w"), offset=-1),
    ...     unit="millimeter",
    ...     scale=4,
    ... )
    >>> print(h.size.get_size(h, w))
    49

    ⇒ h = w * w.scale / h.scale + offset = 100 * 2mm / 4mm - 1 = 49
    """

    tensor_id: TensorId
    """tensor id of the reference axis"""

    axis_id: AxisId
    """axis id of the reference axis"""

    offset: StrictInt = 0

    def get_size(
        self,
        axis: Union[
            ChannelAxis,
            IndexInputAxis,
            IndexOutputAxis,
            TimeInputAxis,
            SpaceInputAxis,
            TimeOutputAxis,
            TimeOutputAxisWithHalo,
            SpaceOutputAxis,
            SpaceOutputAxisWithHalo,
        ],
        ref_axis: Union[
            ChannelAxis,
            IndexInputAxis,
            IndexOutputAxis,
            TimeInputAxis,
            SpaceInputAxis,
            TimeOutputAxis,
            TimeOutputAxisWithHalo,
            SpaceOutputAxis,
            SpaceOutputAxisWithHalo,
        ],
        n: ParameterizedSize_N = 0,
        ref_size: Optional[int] = None,
    ):
        """Compute the concrete size for a given axis and its reference axis.

        Args:
            axis: The axis this [SizeReference][] is the size of.
            ref_axis: The reference axis to compute the size from.
            n: If the **ref_axis** is parameterized (of type `ParameterizedSize`)
                and no fixed **ref_size** is given,
                **n** is used to compute the size of the parameterized **ref_axis**.
            ref_size: Overwrite the reference size instead of deriving it from
                **ref_axis**
                (**ref_axis.scale** is still used; any given **n** is ignored).
        """
        assert axis.size == self, (
            "Given `axis.size` is not defined by this `SizeReference`"
        )

        assert ref_axis.id == self.axis_id, (
            f"Expected `ref_axis.id` to be {self.axis_id}, but got {ref_axis.id}."
        )

        assert axis.unit == ref_axis.unit, (
            "`SizeReference` requires `axis` and `ref_axis` to have the same `unit`,"
            f" but {axis.unit}!={ref_axis.unit}"
        )
        if ref_size is None:
            if isinstance(ref_axis.size, (int, float)):
                ref_size = ref_axis.size
            elif isinstance(ref_axis.size, ParameterizedSize):
                ref_size = ref_axis.size.get_size(n)
            elif isinstance(ref_axis.size, DataDependentSize):
                raise ValueError(
                    "Reference axis referenced in `SizeReference` may not be a `DataDependentSize`."
                )
            elif isinstance(ref_axis.size, SizeReference):
                raise ValueError(
                    "Reference axis referenced in `SizeReference` may not be sized by a"
                    + " `SizeReference` itself."
                )
            else:
                assert_never(ref_axis.size)

        return int(ref_size * ref_axis.scale / axis.scale + self.offset)

    @staticmethod
    def _get_unit(
        axis: Union[
            ChannelAxis,
            IndexInputAxis,
            IndexOutputAxis,
            TimeInputAxis,
            SpaceInputAxis,
            TimeOutputAxis,
            TimeOutputAxisWithHalo,
            SpaceOutputAxis,
            SpaceOutputAxisWithHalo,
        ],
    ):
        return axis.unit


class AxisBase(NodeWithExplicitlySetFields):
    id: AxisId
    """An axis id unique across all axes of one tensor."""

    description: Annotated[str, MaxLen(128)] = ""
    """A short description of this axis beyond its type and id."""


class WithHalo(Node):
    halo: Annotated[int, Ge(1)]
    """The halo should be cropped from the output tensor to avoid boundary effects.
    It is to be cropped from both sides, i.e. `size_after_crop = size - 2 * halo`.
    To document a halo that is already cropped by the model use `size.offset` instead."""

    size: Annotated[
        SizeReference,
        Field(
            examples=[
                10,
                SizeReference(
                    tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
                ).model_dump(mode="json"),
            ]
        ),
    ]
    """reference to another axis with an optional offset (see [SizeReference][])"""


BATCH_AXIS_ID = AxisId("batch")


class BatchAxis(AxisBase):
    implemented_type: ClassVar[Literal["batch"]] = "batch"
    if TYPE_CHECKING:
        type: Literal["batch"] = "batch"
    else:
        type: Literal["batch"]

    id: Annotated[AxisId, Predicate(_is_batch)] = BATCH_AXIS_ID
    size: Optional[Literal[1]] = None
    """The batch size may be fixed to 1,
    otherwise (the default) it may be chosen arbitrarily depending on available memory"""

    @property
    def scale(self):
        return 1.0

    @property
    def concatenable(self):
        return True

    @property
    def unit(self):
        return None


class ChannelAxis(AxisBase):
    implemented_type: ClassVar[Literal["channel"]] = "channel"
    if TYPE_CHECKING:
        type: Literal["channel"] = "channel"
    else:
        type: Literal["channel"]

    id: NonBatchAxisId = AxisId("channel")

    channel_names: NotEmpty[List[Identifier]]

    @property
    def size(self) -> int:
        return len(self.channel_names)

    @property
    def concatenable(self):
        return False

    @property
    def scale(self) -> float:
        return 1.0

    @property
    def unit(self):
        return None
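
# Illustration (editor's sketch, not part of the original module): a
# `ChannelAxis` derives its size from its channel names and is never
# concatenable:
#
#     rgb = ChannelAxis(channel_names=[Identifier(n) for n in ("r", "g", "b")])
#     assert rgb.size == 3 and not rgb.concatenable and rgb.unit is None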


class IndexAxisBase(AxisBase):
    implemented_type: ClassVar[Literal["index"]] = "index"
    if TYPE_CHECKING:
        type: Literal["index"] = "index"
    else:
        type: Literal["index"]

    id: NonBatchAxisId = AxisId("index")

    @property
    def scale(self) -> float:
        return 1.0

    @property
    def unit(self):
        return None


class _WithInputAxisSize(Node):
    size: Annotated[
        Union[Annotated[int, Gt(0)], ParameterizedSize, SizeReference],
        Field(
            examples=[
                10,
                ParameterizedSize(min=32, step=16).model_dump(mode="json"),
                SizeReference(
                    tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
                ).model_dump(mode="json"),
            ]
        ),
    ]
    """The size/length of this axis can be specified as
    - fixed integer
    - parameterized series of valid sizes ([ParameterizedSize][])
    - reference to another axis with an optional offset ([SizeReference][])
    """


class IndexInputAxis(IndexAxisBase, _WithInputAxisSize):
    concatenable: bool = False
    """If a model has a `concatenable` input axis, it can be processed blockwise,
    splitting a longer sample axis into blocks matching its input tensor description.
    Output axes are concatenable if they have a [SizeReference][] to a concatenable
    input axis.
    """


class IndexOutputAxis(IndexAxisBase):
    size: Annotated[
        Union[Annotated[int, Gt(0)], SizeReference, DataDependentSize],
        Field(
            examples=[
                10,
                SizeReference(
                    tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
                ).model_dump(mode="json"),
            ]
        ),
    ]
    """The size/length of this axis can be specified as
    - fixed integer
    - reference to another axis with an optional offset ([SizeReference][])
    - data dependent size using [DataDependentSize][] (size is only known after model inference)
    """


class TimeAxisBase(AxisBase):
    implemented_type: ClassVar[Literal["time"]] = "time"
    if TYPE_CHECKING:
        type: Literal["time"] = "time"
    else:
        type: Literal["time"]

    id: NonBatchAxisId = AxisId("time")
    unit: Optional[TimeUnit] = None
    scale: Annotated[float, Gt(0)] = 1.0


class TimeInputAxis(TimeAxisBase, _WithInputAxisSize):
    concatenable: bool = False
    """If a model has a `concatenable` input axis, it can be processed blockwise,
    splitting a longer sample axis into blocks matching its input tensor description.
    Output axes are concatenable if they have a [SizeReference][] to a concatenable
    input axis.
    """


class SpaceAxisBase(AxisBase):
    implemented_type: ClassVar[Literal["space"]] = "space"
    if TYPE_CHECKING:
        type: Literal["space"] = "space"
    else:
        type: Literal["space"]

    id: Annotated[NonBatchAxisId, Field(examples=["x", "y", "z"])] = AxisId("x")
    unit: Optional[SpaceUnit] = None
    scale: Annotated[float, Gt(0)] = 1.0


class SpaceInputAxis(SpaceAxisBase, _WithInputAxisSize):
    concatenable: bool = False
    """If a model has a `concatenable` input axis, it can be processed blockwise,
    splitting a longer sample axis into blocks matching its input tensor description.
    Output axes are concatenable if they have a [SizeReference][] to a concatenable
    input axis.
    """


INPUT_AXIS_TYPES = (
    BatchAxis,
    ChannelAxis,
    IndexInputAxis,
    TimeInputAxis,
    SpaceInputAxis,
)
"""intended for isinstance comparisons in py<3.10"""

_InputAxisUnion = Union[
    BatchAxis, ChannelAxis, IndexInputAxis, TimeInputAxis, SpaceInputAxis
]
InputAxis = Annotated[_InputAxisUnion, Discriminator("type")]


class _WithOutputAxisSize(Node):
    size: Annotated[
        Union[Annotated[int, Gt(0)], SizeReference],
        Field(
            examples=[
                10,
                SizeReference(
                    tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
                ).model_dump(mode="json"),
            ]
        ),
    ]
    """The size/length of this axis can be specified as
    - fixed integer
    - reference to another axis with an optional offset (see [SizeReference][])
    """


class TimeOutputAxis(TimeAxisBase, _WithOutputAxisSize):
    pass


class TimeOutputAxisWithHalo(TimeAxisBase, WithHalo):
    pass


def _get_halo_axis_discriminator_value(v: Any) -> Literal["with_halo", "wo_halo"]:
    if isinstance(v, dict):
        return "with_halo" if "halo" in v else "wo_halo"
    else:
        return "with_halo" if hasattr(v, "halo") else "wo_halo"
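
# Illustration (editor's sketch, not part of the original module): the tagged
# unions below dispatch on the presence of a "halo" key/attribute instead of a
# dedicated discriminator field:
#
#     assert _get_halo_axis_discriminator_value({"size": 10}) == "wo_halo"
#     assert _get_halo_axis_discriminator_value({"halo": 2}) == "with_halo"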


_TimeOutputAxisUnion = Annotated[
    Union[
        Annotated[TimeOutputAxis, Tag("wo_halo")],
        Annotated[TimeOutputAxisWithHalo, Tag("with_halo")],
    ],
    Discriminator(_get_halo_axis_discriminator_value),
]


class SpaceOutputAxis(SpaceAxisBase, _WithOutputAxisSize):
    pass


class SpaceOutputAxisWithHalo(SpaceAxisBase, WithHalo):
    pass


_SpaceOutputAxisUnion = Annotated[
    Union[
        Annotated[SpaceOutputAxis, Tag("wo_halo")],
        Annotated[SpaceOutputAxisWithHalo, Tag("with_halo")],
    ],
    Discriminator(_get_halo_axis_discriminator_value),
]


_OutputAxisUnion = Union[
    BatchAxis, ChannelAxis, IndexOutputAxis, _TimeOutputAxisUnion, _SpaceOutputAxisUnion
]
OutputAxis = Annotated[_OutputAxisUnion, Discriminator("type")]

OUTPUT_AXIS_TYPES = (
    BatchAxis,
    ChannelAxis,
    IndexOutputAxis,
    TimeOutputAxis,
    TimeOutputAxisWithHalo,
    SpaceOutputAxis,
    SpaceOutputAxisWithHalo,
)
"""intended for isinstance comparisons in py<3.10"""


AnyAxis = Union[InputAxis, OutputAxis]

ANY_AXIS_TYPES = INPUT_AXIS_TYPES + OUTPUT_AXIS_TYPES
"""intended for isinstance comparisons in py<3.10"""

TVs = Union[
    NotEmpty[List[int]],
    NotEmpty[List[float]],
    NotEmpty[List[bool]],
    NotEmpty[List[str]],
]


NominalOrOrdinalDType = Literal[
    "float32",
    "float64",
    "uint8",
    "int8",
    "uint16",
    "int16",
    "uint32",
    "int32",
    "uint64",
    "int64",
    "bool",
]


class NominalOrOrdinalDataDescr(Node):
    values: TVs
    """A fixed set of nominal or an ascending sequence of ordinal values.
    In this case `data.type` is required to be an unsigned integer type, e.g. 'uint8'.
    String `values` are interpreted as labels for tensor values 0, ..., N.
    Note: as YAML 1.2 does not natively support a "set" datatype,
    nominal values should be given as a sequence (aka list/array) as well.
    """

    type: Annotated[
        NominalOrOrdinalDType,
        Field(
            examples=[
                "float32",
                "uint8",
                "uint16",
                "int64",
                "bool",
            ],
        ),
    ] = "uint8"

    @model_validator(mode="after")
    def _validate_values_match_type(
        self,
    ) -> Self:
        incompatible: List[Any] = []
        for v in self.values:
            if self.type == "bool":
                if not isinstance(v, bool):
                    incompatible.append(v)
            elif self.type in DTYPE_LIMITS:
                if (
                    isinstance(v, (int, float))
                    and (
                        v < DTYPE_LIMITS[self.type].min
                        or v > DTYPE_LIMITS[self.type].max
                    )
                    or (isinstance(v, str) and "uint" not in self.type)
                    or (isinstance(v, float) and "int" in self.type)
                ):
                    incompatible.append(v)
            else:
                incompatible.append(v)

            if len(incompatible) == 5:
                incompatible.append("...")
                break

        if incompatible:
            raise ValueError(
                f"data type '{self.type}' incompatible with values {incompatible}"
            )

        return self

    unit: Optional[Union[Literal["arbitrary unit"], SiUnit]] = None

    @property
    def range(self):
        if isinstance(self.values[0], str):
            return 0, len(self.values) - 1
        else:
            return min(self.values), max(self.values)
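
# Illustration (editor's sketch, not part of the original module): string
# values act as labels for the tensor values 0, ..., N and therefore require
# an unsigned integer `type`; `range` then reports the label indices:
#
#     d = NominalOrOrdinalDataDescr(values=["background", "cell"], type="uint8")
#     assert d.range == (0, 1)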


IntervalOrRatioDType = Literal[
    "float32",
    "float64",
    "uint8",
    "int8",
    "uint16",
    "int16",
    "uint32",
    "int32",
    "uint64",
    "int64",
]


class IntervalOrRatioDataDescr(Node):
    type: Annotated[  # TODO: rename to dtype
        IntervalOrRatioDType,
        Field(
            examples=["float32", "float64", "uint8", "uint16"],
        ),
    ] = "float32"
    range: Tuple[Optional[float], Optional[float]] = (
        None,
        None,
    )
    """Tuple `(minimum, maximum)` specifying the allowed range of the data in this tensor.
    `None` corresponds to min/max of what can be expressed by **type**."""
    unit: Union[Literal["arbitrary unit"], SiUnit] = "arbitrary unit"
    scale: float = 1.0
    """Scale for data on an interval (or ratio) scale."""
    offset: Optional[float] = None
    """Offset for data on a ratio scale."""

    @model_validator(mode="before")
    def _replace_inf(cls, data: Any):
        if is_dict(data):
            if "range" in data and is_sequence(data["range"]):
                forbidden = (
                    "inf",
                    "-inf",
                    ".inf",
                    "-.inf",
                    float("inf"),
                    float("-inf"),
                )
                if any(v in forbidden for v in data["range"]):
                    issue_warning("replaced 'inf' value", value=data["range"])

                data["range"] = tuple(
                    (None if v in forbidden else v) for v in data["range"]
                )

        return data
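
# Illustration (editor's sketch, not part of the original module): infinite
# `range` bounds are replaced by `None` (i.e. the dtype's min/max) before
# validation, and a validation warning is issued:
#
#     d = IntervalOrRatioDataDescr.model_validate({"range": [0.0, ".inf"]})
#     assert d.range == (0.0, None)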


TensorDataDescr = Union[NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr]


class BinarizeKwargs(KwargsNode):
    """keyword arguments for [BinarizeDescr][]"""

    threshold: float
    """The fixed threshold"""


class BinarizeAlongAxisKwargs(KwargsNode):
    """keyword arguments for [BinarizeDescr][]"""

    threshold: NotEmpty[List[float]]
    """The fixed threshold values along `axis`"""

    axis: Annotated[NonBatchAxisId, Field(examples=["channel"])]
    """The `threshold` axis"""


class BinarizeDescr(NodeWithExplicitlySetFields):
    """Binarize the tensor with a fixed threshold.

    Values above [BinarizeKwargs.threshold][]/[BinarizeAlongAxisKwargs.threshold][]
    will be set to one, values below the threshold to zero.

    Examples:
    - in YAML
        ```yaml
        postprocessing:
          - id: binarize
            kwargs:
              axis: 'channel'
              threshold: [0.25, 0.5, 0.75]
        ```
    - in Python:
        >>> postprocessing = [BinarizeDescr(
        ...     kwargs=BinarizeAlongAxisKwargs(
        ...         axis=AxisId('channel'),
        ...         threshold=[0.25, 0.5, 0.75],
        ...     )
        ... )]
    """

    implemented_id: ClassVar[Literal["binarize"]] = "binarize"
    if TYPE_CHECKING:
        id: Literal["binarize"] = "binarize"
    else:
        id: Literal["binarize"]
    kwargs: Union[BinarizeKwargs, BinarizeAlongAxisKwargs]


class ClipKwargs(KwargsNode):
    """keyword arguments for [ClipDescr][]"""

    min: Optional[float] = None
    """Minimum value for clipping.

    Exclusive with `min_percentile`.
    """
    min_percentile: Optional[Annotated[float, Interval(ge=0, lt=100)]] = None
    """Minimum percentile for clipping.

    Exclusive with `min`.

    In range [0, 100).
    """

    max: Optional[float] = None
    """Maximum value for clipping.

    Exclusive with `max_percentile`.
    """
    max_percentile: Optional[Annotated[float, Interval(gt=1, le=100)]] = None
    """Maximum percentile for clipping.

    Exclusive with `max`.

    In range (1, 100].
    """

    axes: Annotated[
        Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
    ] = None
    """The subset of axes to determine percentiles jointly,
    i.e. axes to reduce to compute min/max from `min_percentile`/`max_percentile`.
    For example to clip 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
    resulting in a tensor of equal shape with clipped values per channel, specify `axes=('batch', 'x', 'y')`.
    To clip samples independently, leave out the 'batch' axis.

    Only valid if `min_percentile` and/or `max_percentile` are set.

    Default: Compute percentiles over all axes jointly."""

    @model_validator(mode="after")
    def _validate(self) -> Self:
        if (self.min is not None) and (self.min_percentile is not None):
            raise ValueError(
                "Only one of `min` and `min_percentile` may be set, not both."
            )
        if (self.max is not None) and (self.max_percentile is not None):
            raise ValueError(
                "Only one of `max` and `max_percentile` may be set, not both."
            )
        if (
            self.min is None
            and self.min_percentile is None
            and self.max is None
            and self.max_percentile is None
        ):
            raise ValueError(
                "At least one of `min`, `min_percentile`, `max`, or `max_percentile` must be set."
            )

        if (
            self.axes is not None
            and self.min_percentile is None
            and self.max_percentile is None
        ):
            raise ValueError(
                "If `axes` is set, at least one of `min_percentile` or `max_percentile` must be set."
            )

        return self


class ClipDescr(NodeWithExplicitlySetFields):
    """Set tensor values below min to min and above max to max.

    See `ScaleRangeDescr` for examples.
    """

    implemented_id: ClassVar[Literal["clip"]] = "clip"
    if TYPE_CHECKING:
        id: Literal["clip"] = "clip"
    else:
        id: Literal["clip"]

    kwargs: ClipKwargs
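
# Illustration (editor's sketch, not part of the original module): clip to a
# fixed interval; percentile-based bounds would use `min_percentile`/
# `max_percentile`, which are mutually exclusive with `min`/`max`:
#
#     clip = ClipDescr(kwargs=ClipKwargs(min=0.0, max=1.0))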


class EnsureDtypeKwargs(KwargsNode):
    """keyword arguments for [EnsureDtypeDescr][]"""

    dtype: Literal[
        "float32",
        "float64",
        "uint8",
        "int8",
        "uint16",
        "int16",
        "uint32",
        "int32",
        "uint64",
        "int64",
        "bool",
    ]


class EnsureDtypeDescr(NodeWithExplicitlySetFields):
    """Cast the tensor data type to `EnsureDtypeKwargs.dtype` (if not matching).

    This can for example be used to ensure the inner neural network model gets a
    different input tensor data type than the fully described bioimage.io model does.

    Examples:
    The described bioimage.io model (incl. preprocessing) accepts any
    float32-compatible tensor, normalizes it with percentiles and clipping and then
    casts it to uint8, which is what the neural network in this example expects.
    - in YAML
        ```yaml
        inputs:
          - data:
              type: float32  # described bioimage.io model is compatible with any float32 input tensor
            preprocessing:
              - id: scale_range
                kwargs:
                  axes: ['y', 'x']
                  max_percentile: 99.8
                  min_percentile: 5.0
              - id: clip
                kwargs:
                  min: 0.0
                  max: 1.0
              - id: ensure_dtype  # the neural network of the model requires uint8
                kwargs:
                  dtype: uint8
        ```
    - in Python:
        >>> preprocessing = [
        ...     ScaleRangeDescr(
        ...         kwargs=ScaleRangeKwargs(
        ...             axes=(AxisId('y'), AxisId('x')),
        ...             max_percentile=99.8,
        ...             min_percentile=5.0,
        ...         )
        ...     ),
        ...     ClipDescr(kwargs=ClipKwargs(min=0.0, max=1.0)),
        ...     EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="uint8")),
        ... ]
    """

    implemented_id: ClassVar[Literal["ensure_dtype"]] = "ensure_dtype"
    if TYPE_CHECKING:
        id: Literal["ensure_dtype"] = "ensure_dtype"
    else:
        id: Literal["ensure_dtype"]

    kwargs: EnsureDtypeKwargs


class ScaleLinearKwargs(KwargsNode):
    """Keyword arguments for [ScaleLinearDescr][]"""

    gain: float = 1.0
    """multiplicative factor"""

    offset: float = 0.0
    """additive term"""

    @model_validator(mode="after")
    def _validate(self) -> Self:
        if self.gain == 1.0 and self.offset == 0.0:
            raise ValueError(
                "Redundant linear scaling not allowed. Set `gain` != 1.0 and/or `offset`"
                + " != 0.0."
            )

        return self


class ScaleLinearAlongAxisKwargs(KwargsNode):
    """Keyword arguments for [ScaleLinearDescr][]"""

    axis: Annotated[NonBatchAxisId, Field(examples=["channel"])]
    """The axis of gain and offset values."""

    gain: Union[float, NotEmpty[List[float]]] = 1.0
    """multiplicative factor"""

    offset: Union[float, NotEmpty[List[float]]] = 0.0
    """additive term"""

    @model_validator(mode="after")
    def _validate(self) -> Self:
        if isinstance(self.gain, list):
            if isinstance(self.offset, list):
                if len(self.gain) != len(self.offset):
                    raise ValueError(
                        f"Size of `gain` ({len(self.gain)}) and `offset` ({len(self.offset)}) must match."
                    )
            else:
                self.offset = [float(self.offset)] * len(self.gain)
        elif isinstance(self.offset, list):
            self.gain = [float(self.gain)] * len(self.offset)
        else:
            raise ValueError(
                "Do not specify an `axis` for scalar gain and offset values."
            )

        if all(g == 1.0 for g in self.gain) and all(off == 0.0 for off in self.offset):
            raise ValueError(
                "Redundant linear scaling not allowed. Set `gain` != 1.0 and/or `offset`"
                + " != 0.0."
            )

        return self
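
# Illustration (editor's sketch, not part of the original module): when only
# one of `gain`/`offset` is a list, the validator broadcasts the scalar to a
# list of matching length:
#
#     k = ScaleLinearAlongAxisKwargs(
#         axis=AxisId("channel"), gain=[1.0, 2.0], offset=0.5
#     )
#     assert k.offset == [0.5, 0.5]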


class ScaleLinearDescr(NodeWithExplicitlySetFields):
    """Fixed linear scaling.

    Examples:
      1. Scale with scalar gain and offset
        - in YAML
        ```yaml
        preprocessing:
          - id: scale_linear
            kwargs:
              gain: 2.0
              offset: 3.0
        ```
        - in Python:
        >>> preprocessing = [
        ...     ScaleLinearDescr(kwargs=ScaleLinearKwargs(gain=2.0, offset=3.0))
        ... ]

      2. Independent scaling along an axis
        - in YAML
        ```yaml
        preprocessing:
          - id: scale_linear
            kwargs:
              axis: 'channel'
              gain: [1.0, 2.0, 3.0]
        ```
        - in Python:
        >>> preprocessing = [
        ...     ScaleLinearDescr(
        ...         kwargs=ScaleLinearAlongAxisKwargs(
        ...             axis=AxisId("channel"),
        ...             gain=[1.0, 2.0, 3.0],
        ...         )
        ...     )
        ... ]

    """

    implemented_id: ClassVar[Literal["scale_linear"]] = "scale_linear"
    if TYPE_CHECKING:
        id: Literal["scale_linear"] = "scale_linear"
    else:
        id: Literal["scale_linear"]
    kwargs: Union[ScaleLinearKwargs, ScaleLinearAlongAxisKwargs]


class SigmoidDescr(NodeWithExplicitlySetFields):
    """The logistic sigmoid function, a.k.a. expit function.

    Examples:
    - in YAML
        ```yaml
        postprocessing:
          - id: sigmoid
        ```
    - in Python:
        >>> postprocessing = [SigmoidDescr()]
    """

    implemented_id: ClassVar[Literal["sigmoid"]] = "sigmoid"
    if TYPE_CHECKING:
        id: Literal["sigmoid"] = "sigmoid"
    else:
        id: Literal["sigmoid"]

    @property
    def kwargs(self) -> KwargsNode:
        """empty kwargs"""
        return KwargsNode()


class SoftmaxKwargs(KwargsNode):
    """keyword arguments for [SoftmaxDescr][]"""

    axis: Annotated[NonBatchAxisId, Field(examples=["channel"])] = AxisId("channel")
    """The axis to apply the softmax function along.
    Note:
        Defaults to the 'channel' axis
        (which may not exist, in which case
        a different axis id has to be specified).
    """


class SoftmaxDescr(NodeWithExplicitlySetFields):
    """The softmax function.

    Examples:
    - in YAML
        ```yaml
        postprocessing:
          - id: softmax
            kwargs:
              axis: channel
        ```
    - in Python:
        >>> postprocessing = [SoftmaxDescr(kwargs=SoftmaxKwargs(axis=AxisId("channel")))]
    """

    implemented_id: ClassVar[Literal["softmax"]] = "softmax"
    if TYPE_CHECKING:
        id: Literal["softmax"] = "softmax"
    else:
        id: Literal["softmax"]

    kwargs: SoftmaxKwargs = Field(default_factory=SoftmaxKwargs.model_construct)


class FixedZeroMeanUnitVarianceKwargs(KwargsNode):
    """keyword arguments for [FixedZeroMeanUnitVarianceDescr][]"""

    mean: float
    """The mean value to normalize with."""

    std: Annotated[float, Ge(1e-6)]
    """The standard deviation value to normalize with."""


class FixedZeroMeanUnitVarianceAlongAxisKwargs(KwargsNode):
    """keyword arguments for [FixedZeroMeanUnitVarianceDescr][]"""

    mean: NotEmpty[List[float]]
    """The mean value(s) to normalize with."""

    std: NotEmpty[List[Annotated[float, Ge(1e-6)]]]
    """The standard deviation value(s) to normalize with.
    Size must match `mean` values."""

    axis: Annotated[NonBatchAxisId, Field(examples=["channel", "index"])]
    """The axis of the mean/std values to normalize each entry along that dimension
    separately."""

    @model_validator(mode="after")
    def _mean_and_std_match(self) -> Self:
        if len(self.mean) != len(self.std):
            raise ValueError(
                f"Size of `mean` ({len(self.mean)}) and `std` ({len(self.std)})"
                + " must match."
            )

        return self


class FixedZeroMeanUnitVarianceDescr(NodeWithExplicitlySetFields):
    """Subtract a given mean and divide by the standard deviation.

    Normalize with fixed, precomputed values for
    `FixedZeroMeanUnitVarianceKwargs.mean` and `FixedZeroMeanUnitVarianceKwargs.std`.
    Use `FixedZeroMeanUnitVarianceAlongAxisKwargs` for independent scaling along given
    axes.

    Examples:
      1. scalar value for the whole tensor
        - in YAML
        ```yaml
        preprocessing:
          - id: fixed_zero_mean_unit_variance
            kwargs:
              mean: 103.5
              std: 13.7
        ```
        - in Python
        >>> preprocessing = [FixedZeroMeanUnitVarianceDescr(
        ...     kwargs=FixedZeroMeanUnitVarianceKwargs(mean=103.5, std=13.7)
        ... )]

      2. independently along an axis
        - in YAML
        ```yaml
        preprocessing:
          - id: fixed_zero_mean_unit_variance
            kwargs:
              axis: channel
              mean: [101.5, 102.5, 103.5]
              std: [11.7, 12.7, 13.7]
        ```
        - in Python
        >>> preprocessing = [FixedZeroMeanUnitVarianceDescr(
        ...     kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs(
        ...         axis=AxisId("channel"),
        ...         mean=[101.5, 102.5, 103.5],
        ...         std=[11.7, 12.7, 13.7],
        ...     )
        ... )]
    """

    implemented_id: ClassVar[Literal["fixed_zero_mean_unit_variance"]] = (
        "fixed_zero_mean_unit_variance"
    )
    if TYPE_CHECKING:
        id: Literal["fixed_zero_mean_unit_variance"] = "fixed_zero_mean_unit_variance"
    else:
        id: Literal["fixed_zero_mean_unit_variance"]

    kwargs: Union[
        FixedZeroMeanUnitVarianceKwargs, FixedZeroMeanUnitVarianceAlongAxisKwargs
    ]


class ZeroMeanUnitVarianceKwargs(KwargsNode):
    """keyword arguments for [ZeroMeanUnitVarianceDescr][]"""

    axes: Annotated[
        Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
    ] = None
    """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std.
    For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
    resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
    To normalize each sample independently leave out the 'batch' axis.
    Default: Scale all axes jointly."""

    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
    """epsilon for numeric stability: `out = (tensor - mean) / (std + eps)`."""


class ZeroMeanUnitVarianceDescr(NodeWithExplicitlySetFields):
    """Subtract mean and divide by standard deviation.

    Examples:
    Normalize with the tensor's own mean and standard deviation
    - in YAML
        ```yaml
        preprocessing:
          - id: zero_mean_unit_variance
        ```
    - in Python
        >>> preprocessing = [ZeroMeanUnitVarianceDescr()]
    """

    implemented_id: ClassVar[Literal["zero_mean_unit_variance"]] = (
        "zero_mean_unit_variance"
    )
    if TYPE_CHECKING:
        id: Literal["zero_mean_unit_variance"] = "zero_mean_unit_variance"
    else:
        id: Literal["zero_mean_unit_variance"]

    kwargs: ZeroMeanUnitVarianceKwargs = Field(
        default_factory=ZeroMeanUnitVarianceKwargs.model_construct
    )


class ScaleRangeKwargs(KwargsNode):
    """keyword arguments for [ScaleRangeDescr][]

    For `min_percentile`=0.0 (the default) and `max_percentile`=100 (the default)
    this processing step normalizes data to the [0, 1] interval.
    For other percentiles the normalized values will partially be outside the [0, 1]
    interval. Use `ScaleRange` followed by `ClipDescr` if you want to limit the
    normalized values to a range.
    """

    axes: Annotated[
        Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
    ] = None
    """The subset of axes to normalize jointly, i.e. axes to reduce to compute the min/max percentile value.
    For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
    resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
    To normalize samples independently, leave out the 'batch' axis.
    Default: Scale all axes jointly."""

    min_percentile: Annotated[float, Interval(ge=0, lt=100)] = 0.0
    """The lower percentile used to determine the value to align with zero."""

    max_percentile: Annotated[float, Interval(gt=1, le=100)] = 100.0
    """The upper percentile used to determine the value to align with one.
    Has to be bigger than `min_percentile`.
    The range is 1 to 100 instead of 0 to 100 to avoid mistakenly
    accepting percentiles specified in the range 0.0 to 1.0."""

    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
    """Epsilon for numeric stability.
    `out = (tensor - v_lower) / (v_upper - v_lower + eps)`;
    with `v_lower,v_upper` values at the respective percentiles."""

    reference_tensor: Optional[TensorId] = None
    """Tensor ID to compute the percentiles from. Default: The tensor itself.
    For any tensor in `inputs` only input tensor references are allowed."""

    @field_validator("max_percentile", mode="after")
    @classmethod
    def min_smaller_max(cls, value: float, info: ValidationInfo) -> float:
        if (min_p := info.data["min_percentile"]) >= value:
            raise ValueError(f"min_percentile {min_p} >= max_percentile {value}")

        return value


class ScaleRangeDescr(NodeWithExplicitlySetFields):
    """Scale with percentiles.

    Examples:
      1. Scale linearly to map the 5th percentile to 0 and the 99.8th percentile to 1.0
        - in YAML
        ```yaml
        preprocessing:
          - id: scale_range
            kwargs:
              axes: ['y', 'x']
              max_percentile: 99.8
              min_percentile: 5.0
        ```
        - in Python
        >>> preprocessing = [
        ...     ScaleRangeDescr(
        ...         kwargs=ScaleRangeKwargs(
        ...             axes=(AxisId('y'), AxisId('x')),
        ...             max_percentile=99.8,
        ...             min_percentile=5.0,
        ...         )
        ...     )
        ... ]

      2. Combine the above scaling with additional clipping to clip values outside the range given by the percentiles.
        - in YAML
        ```yaml
        preprocessing:
          - id: scale_range
            kwargs:
              axes: ['y', 'x']
              max_percentile: 99.8
              min_percentile: 5.0
          - id: clip
            kwargs:
              min: 0.0
              max: 1.0
        ```
        - in Python
        >>> preprocessing = [
        ...     ScaleRangeDescr(
        ...         kwargs=ScaleRangeKwargs(
        ...             axes=(AxisId('y'), AxisId('x')),
        ...             max_percentile=99.8,
        ...             min_percentile=5.0,
        ...         )
        ...     ),
        ...     ClipDescr(
        ...         kwargs=ClipKwargs(
        ...             min=0.0,
        ...             max=1.0,
        ...         )
        ...     ),
        ... ]

    """

    implemented_id: ClassVar[Literal["scale_range"]] = "scale_range"
    if TYPE_CHECKING:
        id: Literal["scale_range"] = "scale_range"
    else:
        id: Literal["scale_range"]
    kwargs: ScaleRangeKwargs = Field(default_factory=ScaleRangeKwargs.model_construct)


class ScaleMeanVarianceKwargs(KwargsNode):
    """keyword arguments for [ScaleMeanVarianceDescr][]"""

    reference_tensor: TensorId
    """Name of tensor to match."""

    axes: Annotated[
        Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
    ] = None
    """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std.
    For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
    resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
    To normalize samples independently, leave out the 'batch' axis.
    Default: Scale all axes jointly."""

    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
    """Epsilon for numeric stability:
    `out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.`"""


class ScaleMeanVarianceDescr(NodeWithExplicitlySetFields):
    """Scale a tensor's data distribution to match another tensor's mean/std.
    `out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.`
    """

    implemented_id: ClassVar[Literal["scale_mean_variance"]] = "scale_mean_variance"
    if TYPE_CHECKING:
        id: Literal["scale_mean_variance"] = "scale_mean_variance"
    else:
        id: Literal["scale_mean_variance"]
    kwargs: ScaleMeanVarianceKwargs
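
# Illustration (editor's sketch, not part of the original module): match an
# output tensor's distribution to that of the input tensor:
#
#     postprocessing = [ScaleMeanVarianceDescr(
#         kwargs=ScaleMeanVarianceKwargs(reference_tensor=TensorId("input"))
#     )]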

1567 

1568PreprocessingDescr = Annotated[ 

1569 Union[ 

1570 BinarizeDescr, 

1571 ClipDescr, 

1572 EnsureDtypeDescr, 

1573 FixedZeroMeanUnitVarianceDescr, 

1574 ScaleLinearDescr, 

1575 ScaleRangeDescr, 

1576 SigmoidDescr, 

1577 SoftmaxDescr, 

1578 ZeroMeanUnitVarianceDescr, 

1579 ], 

1580 Discriminator("id"), 

1581] 

1582PostprocessingDescr = Annotated[ 

1583 Union[ 

1584 BinarizeDescr, 

1585 ClipDescr, 

1586 EnsureDtypeDescr, 

1587 FixedZeroMeanUnitVarianceDescr, 

1588 ScaleLinearDescr, 

1589 ScaleMeanVarianceDescr, 

1590 ScaleRangeDescr, 

1591 SigmoidDescr, 

1592 SoftmaxDescr, 

1593 ZeroMeanUnitVarianceDescr, 

1594 ], 

1595 Discriminator("id"), 

1596] 

1597 

1598IO_AxisT = TypeVar("IO_AxisT", InputAxis, OutputAxis) 

1599 

1600 

1601class TensorDescrBase(Node, Generic[IO_AxisT]): 

1602 id: TensorId 

1603 """Tensor id. No duplicates are allowed.""" 

1604 

1605 description: Annotated[str, MaxLen(128)] = "" 

1606 """free text description""" 

1607 

1608 axes: NotEmpty[Sequence[IO_AxisT]] 

1609 """tensor axes""" 

1610 

1611 @property 

1612 def shape(self): 

1613 return tuple(a.size for a in self.axes) 

1614 

1615 @field_validator("axes", mode="after", check_fields=False) 

1616 @classmethod 

1617 def _validate_axes(cls, axes: Sequence[AnyAxis]) -> Sequence[AnyAxis]: 

1618 batch_axes = [a for a in axes if a.type == "batch"] 

1619 if len(batch_axes) > 1: 

1620 raise ValueError( 

1621 f"Only one batch axis (per tensor) allowed, but got {batch_axes}" 

1622 ) 

1623 

1624 seen_ids: Set[AxisId] = set() 

1625 duplicate_axes_ids: Set[AxisId] = set() 

1626 for a in axes: 

1627 (duplicate_axes_ids if a.id in seen_ids else seen_ids).add(a.id) 

1628 

1629 if duplicate_axes_ids: 

1630 raise ValueError(f"Duplicate axis ids: {duplicate_axes_ids}") 

1631 

1632 return axes 

1633 

1634 test_tensor: FAIR[Optional[FileDescr_]] = None 

1635 """An example tensor to use for testing. 

1636 Using the model with the test input tensors is expected to yield the test output tensors. 

1637 Each test tensor has be a an ndarray in the 

1638 [numpy.lib file format](https://numpy.org/doc/stable/reference/generated/numpy.lib.format.html#module-numpy.lib.format). 

1639 The file extension must be '.npy'.""" 

1640 

1641 sample_tensor: FAIR[Optional[FileDescr_]] = None 

1642 """A sample tensor to illustrate a possible input/output for the model, 

1643 The sample image primarily serves to inform a human user about an example use case 

1644 and is typically stored as .hdf5, .png or .tiff. 

1645 It has to be readable by the [imageio library](https://imageio.readthedocs.io/en/stable/formats/index.html#supported-formats) 

1646 (numpy's `.npy` format is not supported). 

1647 The image dimensionality has to match the number of axes specified in this tensor description. 

1648 """ 

1649 

1650 @model_validator(mode="after") 

1651 def _validate_sample_tensor(self) -> Self: 

1652 if self.sample_tensor is None or not get_validation_context().perform_io_checks: 

1653 return self 

1654 

1655 reader = get_reader(self.sample_tensor.source, sha256=self.sample_tensor.sha256) 

1656 tensor: NDArray[Any] = imread( # pyright: ignore[reportUnknownVariableType] 

1657 reader.read(), 

1658 extension=PurePosixPath(reader.original_file_name).suffix, 

1659 ) 

1660 n_dims = len(tensor.squeeze().shape) 

1661 n_dims_min = n_dims_max = len(self.axes) 

1662 

1663 for a in self.axes: 

1664 if isinstance(a, BatchAxis): 

1665 n_dims_min -= 1 

1666 elif isinstance(a.size, int): 

1667 if a.size == 1: 

1668 n_dims_min -= 1 

1669 elif isinstance(a.size, (ParameterizedSize, DataDependentSize)): 

1670 if a.size.min == 1: 

1671 n_dims_min -= 1 

1672 elif isinstance(a.size, SizeReference): 

1673 if a.size.offset < 2: 

1674 # size reference may result in singleton axis 

1675 n_dims_min -= 1 

1676 else: 

1677 assert_never(a.size) 

1678 

1679 n_dims_min = max(0, n_dims_min) 

1680 if n_dims < n_dims_min or n_dims > n_dims_max: 

1681 raise ValueError( 

1682 f"Expected sample tensor to have {n_dims_min} to" 

1683 + f" {n_dims_max} dimensions, but found {n_dims} (shape: {tensor.shape})." 

1684 ) 

1685 

1686 return self 

1687 

1688 data: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]] = ( 

1689 IntervalOrRatioDataDescr() 

1690 ) 

1691 """Description of the tensor's data values, optionally per channel. 

1692 If specified per channel, the data `type` needs to match across channels.""" 

1693 

1694 @property 

1695 def dtype( 

1696 self, 

1697 ) -> Literal[ 

1698 "float32", 

1699 "float64", 

1700 "uint8", 

1701 "int8", 

1702 "uint16", 

1703 "int16", 

1704 "uint32", 

1705 "int32", 

1706 "uint64", 

1707 "int64", 

1708 "bool", 

1709 ]: 

1710 """dtype as specified under `data.type` or `data[i].type`""" 

1711 if isinstance(self.data, collections.abc.Sequence): 

1712 return self.data[0].type 

1713 else: 

1714 return self.data.type 

1715 

1716 @field_validator("data", mode="after") 

1717 @classmethod 

1718 def _check_data_type_across_channels( 

1719 cls, value: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]] 

1720 ) -> Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]]: 

1721 if not isinstance(value, list): 

1722 return value 

1723 

1724 dtypes = {t.type for t in value} 

1725 if len(dtypes) > 1: 

1726 raise ValueError( 

1727 "Tensor data descriptions per channel need to agree in their data" 

1728 + f" `type`, but found {dtypes}." 

1729 ) 

1730 

1731 return value 

1732 

1733 @model_validator(mode="after") 

1734 def _check_data_matches_channelaxis(self) -> Self: 

1735 if not isinstance(self.data, (list, tuple)): 

1736 return self 

1737 

1738 for a in self.axes: 

1739 if isinstance(a, ChannelAxis): 

1740 size = a.size 

1741 assert isinstance(size, int) 

1742 break 

1743 else: 

1744 return self 

1745 

1746 if len(self.data) != size: 

1747 raise ValueError( 

1748 f"Got tensor data descriptions for {len(self.data)} channels, but" 

1749 + f" '{a.id}' axis has size {size}." 

1750 ) 

1751 

1752 return self 

1753 

1754 def get_axis_sizes_for_array(self, array: NDArray[Any]) -> Dict[AxisId, int]: 

1755 if len(array.shape) != len(self.axes): 

1756 raise ValueError( 

1757 f"Dimension mismatch: array shape {array.shape} (#{len(array.shape)})" 

1758 + f" incompatible with {len(self.axes)} axes." 

1759 ) 

1760 return {a.id: array.shape[i] for i, a in enumerate(self.axes)} 

1761 

1762 

1763class InputTensorDescr(TensorDescrBase[InputAxis]): 

1764 id: TensorId = TensorId("input") 

1765 """Input tensor id. 

1766 No duplicates are allowed across all inputs and outputs.""" 

1767 

1768 optional: bool = False 

1769 """indicates that this tensor may be `None`""" 

1770 

1771 preprocessing: List[PreprocessingDescr] = Field( 

1772 default_factory=cast(Callable[[], List[PreprocessingDescr]], list) 

1773 ) 

1774 

1775 """Description of how this input should be preprocessed. 

1776 

1777 notes: 

1778 - If preprocessing does not start with an 'ensure_dtype' entry, it is added 

1779 to ensure an input tensor's data type matches the input tensor's data description. 

1780 - If preprocessing does not end with an 'ensure_dtype' or 'binarize' entry, an 

1781 'ensure_dtype' step is added to ensure preprocessing steps are not unintentionally 

1782 changing the data type. 

1783 """ 

1784 

1785 @model_validator(mode="after") 

1786 def _validate_preprocessing_kwargs(self) -> Self: 

1787 axes_ids = [a.id for a in self.axes] 

1788 for p in self.preprocessing: 

1789 kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes") 

1790 if kwargs_axes is None: 

1791 continue 

1792 

1793 if not isinstance(kwargs_axes, collections.abc.Sequence): 

1794 raise ValueError( 

1795 f"Expected `preprocessing.i.kwargs.axes` to be a sequence, but got {type(kwargs_axes)}" 

1796 ) 

1797 

1798 if any(a not in axes_ids for a in kwargs_axes): 

1799 raise ValueError( 

1800 "`preprocessing.i.kwargs.axes` needs to be subset of axes ids" 

1801 ) 

1802 

1803 if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)): 

1804 dtype = self.data.type 

1805 else: 

1806 dtype = self.data[0].type 

1807 

1808 # ensure `preprocessing` begins with `EnsureDtypeDescr` 

1809 if not self.preprocessing or not isinstance( 

1810 self.preprocessing[0], EnsureDtypeDescr 

1811 ): 

1812 self.preprocessing.insert( 

1813 0, EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype)) 

1814 ) 

1815 

1816 # ensure `preprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr` 

1817 if not isinstance(self.preprocessing[-1], (EnsureDtypeDescr, BinarizeDescr)): 

1818 self.preprocessing.append( 

1819 EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype)) 

1820 ) 

1821 

1822 return self 
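
# Example (sketch; field values are illustrative and it is assumed that 
# `test_tensor` may be left unset in this format version): the validator 
# above pads `preprocessing`, so a single `SigmoidDescr` step ends up 
# wrapped in `ensure_dtype` steps that pin the tensor's described dtype: 
# 
#     ipt = InputTensorDescr( 
#         id=TensorId("raw"), 
#         axes=[BatchAxis(), SpaceInputAxis(id=AxisId("x"), size=64)], 
#         data=IntervalOrRatioDataDescr(type="float32"), 
#         preprocessing=[SigmoidDescr()], 
#     ) 
#     assert isinstance(ipt.preprocessing[0], EnsureDtypeDescr) 
#     assert isinstance(ipt.preprocessing[-1], EnsureDtypeDescr) 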

1823 

1824 

1825def convert_axes( 

1826 axes: str, 

1827 *, 

1828 shape: Union[ 

1829 Sequence[int], _ParameterizedInputShape_v0_4, _ImplicitOutputShape_v0_4 

1830 ], 

1831 tensor_type: Literal["input", "output"], 

1832 halo: Optional[Sequence[int]], 

1833 size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]], 

1834): 

1835 ret: List[AnyAxis] = [] 

1836 for i, a in enumerate(axes): 

1837 axis_type = _AXIS_TYPE_MAP.get(a, a) 

1838 if axis_type == "batch": 

1839 ret.append(BatchAxis()) 

1840 continue 

1841 

1842 scale = 1.0 

1843 if isinstance(shape, _ParameterizedInputShape_v0_4): 

1844 if shape.step[i] == 0: 

1845 size = shape.min[i] 

1846 else: 

1847 size = ParameterizedSize(min=shape.min[i], step=shape.step[i]) 

1848 elif isinstance(shape, _ImplicitOutputShape_v0_4): 

1849 ref_t = str(shape.reference_tensor) 

1850 if ref_t.count(".") == 1: 

1851 t_id, orig_a_id = ref_t.split(".") 

1852 else: 

1853 t_id = ref_t 

1854 orig_a_id = a 

1855 

1856 a_id = _AXIS_ID_MAP.get(orig_a_id, a) 

1857 if not (orig_scale := shape.scale[i]): 

1858 # old way to insert a new axis dimension 

1859 size = int(2 * shape.offset[i]) 

1860 else: 

1861 scale = 1 / orig_scale 

1862 if axis_type in ("channel", "index"): 

1863 # these axes no longer have a scale 

1864 offset_from_scale = orig_scale * size_refs.get( 

1865 _TensorName_v0_4(t_id), {} 

1866 ).get(orig_a_id, 0) 

1867 else: 

1868 offset_from_scale = 0 

1869 size = SizeReference( 

1870 tensor_id=TensorId(t_id), 

1871 axis_id=AxisId(a_id), 

1872 offset=int(offset_from_scale + 2 * shape.offset[i]), 

1873 ) 

1874 else: 

1875 size = shape[i] 

1876 

1877 if axis_type == "time": 

1878 if tensor_type == "input": 

1879 ret.append(TimeInputAxis(size=size, scale=scale)) 

1880 else: 

1881 assert not isinstance(size, ParameterizedSize) 

1882 if halo is None: 

1883 ret.append(TimeOutputAxis(size=size, scale=scale)) 

1884 else: 

1885 assert not isinstance(size, int) 

1886 ret.append( 

1887 TimeOutputAxisWithHalo(size=size, scale=scale, halo=halo[i]) 

1888 ) 

1889 

1890 elif axis_type == "index": 

1891 if tensor_type == "input": 

1892 ret.append(IndexInputAxis(size=size)) 

1893 else: 

1894 if isinstance(size, ParameterizedSize): 

1895 size = DataDependentSize(min=size.min) 

1896 

1897 ret.append(IndexOutputAxis(size=size)) 

1898 elif axis_type == "channel": 

1899 assert not isinstance(size, ParameterizedSize) 

1900 if isinstance(size, SizeReference): 

1901 warnings.warn( 

1902 "Conversion of channel size from an implicit output shape may be" 

1903 + " wrong" 

1904 ) 

1905 ret.append( 

1906 ChannelAxis( 

1907 channel_names=[ 

1908 Identifier(f"channel{i}") for i in range(size.offset) 

1909 ] 

1910 ) 

1911 ) 

1912 else: 

1913 ret.append( 

1914 ChannelAxis( 

1915 channel_names=[Identifier(f"channel{i}") for i in range(size)] 

1916 ) 

1917 ) 

1918 elif axis_type == "space": 

1919 if tensor_type == "input": 

1920 ret.append(SpaceInputAxis(id=AxisId(a), size=size, scale=scale)) 

1921 else: 

1922 assert not isinstance(size, ParameterizedSize) 

1923 if halo is None or halo[i] == 0: 

1924 ret.append(SpaceOutputAxis(id=AxisId(a), size=size, scale=scale)) 

1925 elif isinstance(size, int): 

1926 raise NotImplementedError( 

1927 f"output axis with halo and fixed size (here {size}) not allowed" 

1928 ) 

1929 else: 

1930 ret.append( 

1931 SpaceOutputAxisWithHalo( 

1932 id=AxisId(a), size=size, scale=scale, halo=halo[i] 

1933 ) 

1934 ) 

1935 

1936 return ret 
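
# Example (sketch): converting the v0.4 axes string "bcyx" together with a 
# fixed shape; `size_refs` may be empty because no implicit output shape 
# references another tensor here: 
# 
#     convert_axes( 
#         "bcyx", 
#         shape=[1, 3, 256, 256], 
#         tensor_type="input", 
#         halo=None, 
#         size_refs={}, 
#     ) 
#     # -> [BatchAxis(), 
#     #     ChannelAxis(channel_names=["channel0", "channel1", "channel2"]), 
#     #     SpaceInputAxis(id=AxisId("y"), size=256, scale=1.0), 
#     #     SpaceInputAxis(id=AxisId("x"), size=256, scale=1.0)] 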

1937 

1938 

1939def _axes_letters_to_ids( 

1940 axes: Optional[str], 

1941) -> Optional[List[AxisId]]: 

1942 if axes is None: 

1943 return None 

1944 

1945 return [AxisId(a) for a in axes] 

1946 

1947 

1948def _get_complement_v04_axis( 

1949 tensor_axes: Sequence[str], axes: Optional[Sequence[str]] 

1950) -> Optional[AxisId]: 

1951 if axes is None: 

1952 return None 

1953 

1954 non_complement_axes = set(axes) | {"b"} 

1955 complement_axes = [a for a in tensor_axes if a not in non_complement_axes] 

1956 if len(complement_axes) > 1: 

1957 raise ValueError( 

1958 f"Expected none or a single complement axis, but axes '{axes}' " 

1959 + f"for tensor dims '{tensor_axes}' leave '{complement_axes}'." 

1960 ) 

1961 

1962 return None if not complement_axes else AxisId(complement_axes[0]) 
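
# Example: for a v0.4 tensor with dims "bcyx", kwargs axes ("x", "y") leave 
# exactly one non-batch axis unaccounted for -- the channel axis: 
# 
#     _get_complement_v04_axis("bcyx", ["x", "y"])  # -> AxisId("c") 
#     _get_complement_v04_axis("byx", ["x", "y"])   # -> None 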

1963 

1964 

1965def _convert_proc( 

1966 p: Union[_PreprocessingDescr_v0_4, _PostprocessingDescr_v0_4], 

1967 tensor_axes: Sequence[str], 

1968) -> Union[PreprocessingDescr, PostprocessingDescr]: 

1969 if isinstance(p, _BinarizeDescr_v0_4): 

1970 return BinarizeDescr(kwargs=BinarizeKwargs(threshold=p.kwargs.threshold)) 

1971 elif isinstance(p, _ClipDescr_v0_4): 

1972 return ClipDescr(kwargs=ClipKwargs(min=p.kwargs.min, max=p.kwargs.max)) 

1973 elif isinstance(p, _SigmoidDescr_v0_4): 

1974 return SigmoidDescr() 

1975 elif isinstance(p, _ScaleLinearDescr_v0_4): 

1976 axes = _axes_letters_to_ids(p.kwargs.axes) 

1977 if p.kwargs.axes is None: 

1978 axis = None 

1979 else: 

1980 axis = _get_complement_v04_axis(tensor_axes, p.kwargs.axes) 

1981 

1982 if axis is None: 

1983 assert not isinstance(p.kwargs.gain, list) 

1984 assert not isinstance(p.kwargs.offset, list) 

1985 kwargs = ScaleLinearKwargs(gain=p.kwargs.gain, offset=p.kwargs.offset) 

1986 else: 

1987 kwargs = ScaleLinearAlongAxisKwargs( 

1988 axis=axis, gain=p.kwargs.gain, offset=p.kwargs.offset 

1989 ) 

1990 return ScaleLinearDescr(kwargs=kwargs) 

1991 elif isinstance(p, _ScaleMeanVarianceDescr_v0_4): 

1992 return ScaleMeanVarianceDescr( 

1993 kwargs=ScaleMeanVarianceKwargs( 

1994 axes=_axes_letters_to_ids(p.kwargs.axes), 

1995 reference_tensor=TensorId(str(p.kwargs.reference_tensor)), 

1996 eps=p.kwargs.eps, 

1997 ) 

1998 ) 

1999 elif isinstance(p, _ZeroMeanUnitVarianceDescr_v0_4): 

2000 if p.kwargs.mode == "fixed": 

2001 mean = p.kwargs.mean 

2002 std = p.kwargs.std 

2003 assert mean is not None 

2004 assert std is not None 

2005 

2006 axis = _get_complement_v04_axis(tensor_axes, p.kwargs.axes) 

2007 

2008 if axis is None: 

2009 if isinstance(mean, list): 

2010 raise ValueError("Expected single float value for mean, not <list>") 

2011 if isinstance(std, list): 

2012 raise ValueError("Expected single float value for std, not <list>") 

2013 return FixedZeroMeanUnitVarianceDescr( 

2014 kwargs=FixedZeroMeanUnitVarianceKwargs.model_construct( 

2015 mean=mean, 

2016 std=std, 

2017 ) 

2018 ) 

2019 else: 

2020 if not isinstance(mean, list): 

2021 mean = [float(mean)] 

2022 if not isinstance(std, list): 

2023 std = [float(std)] 

2024 

2025 return FixedZeroMeanUnitVarianceDescr( 

2026 kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs( 

2027 axis=axis, mean=mean, std=std 

2028 ) 

2029 ) 

2030 

2031 else: 

2032 axes = _axes_letters_to_ids(p.kwargs.axes) or [] 

2033 if p.kwargs.mode == "per_dataset": 

2034 axes = [AxisId("batch")] + axes 

2035 if not axes: 

2036 axes = None 

2037 return ZeroMeanUnitVarianceDescr( 

2038 kwargs=ZeroMeanUnitVarianceKwargs(axes=axes, eps=p.kwargs.eps) 

2039 ) 

2040 

2041 elif isinstance(p, _ScaleRangeDescr_v0_4): 

2042 return ScaleRangeDescr( 

2043 kwargs=ScaleRangeKwargs( 

2044 axes=_axes_letters_to_ids(p.kwargs.axes), 

2045 min_percentile=p.kwargs.min_percentile, 

2046 max_percentile=p.kwargs.max_percentile, 

2047 eps=p.kwargs.eps, 

2048 ) 

2049 ) 

2050 else: 

2051 assert_never(p) 

2052 

2053 

2054class _InputTensorConv( 

2055 Converter[ 

2056 _InputTensorDescr_v0_4, 

2057 InputTensorDescr, 

2058 FileSource_, 

2059 Optional[FileSource_], 

2060 Mapping[_TensorName_v0_4, Mapping[str, int]], 

2061 ] 

2062): 

2063 def _convert( 

2064 self, 

2065 src: _InputTensorDescr_v0_4, 

2066 tgt: "type[InputTensorDescr] | type[dict[str, Any]]", 

2067 test_tensor: FileSource_, 

2068 sample_tensor: Optional[FileSource_], 

2069 size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]], 

2070 ) -> "InputTensorDescr | dict[str, Any]": 

2071 axes: List[InputAxis] = convert_axes( # pyright: ignore[reportAssignmentType] 

2072 src.axes, 

2073 shape=src.shape, 

2074 tensor_type="input", 

2075 halo=None, 

2076 size_refs=size_refs, 

2077 ) 

2078 prep: List[PreprocessingDescr] = [] 

2079 for p in src.preprocessing: 

2080 cp = _convert_proc(p, src.axes) 

2081 assert not isinstance(cp, ScaleMeanVarianceDescr) 

2082 prep.append(cp) 

2083 

2084 prep.append(EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="float32"))) 

2085 

2086 return tgt( 

2087 axes=axes, 

2088 id=TensorId(str(src.name)), 

2089 test_tensor=FileDescr(source=test_tensor), 

2090 sample_tensor=( 

2091 None if sample_tensor is None else FileDescr(source=sample_tensor) 

2092 ), 

2093 data=dict(type=src.data_type), # pyright: ignore[reportArgumentType] 

2094 preprocessing=prep, 

2095 ) 

2096 

2097 

2098_input_tensor_conv = _InputTensorConv(_InputTensorDescr_v0_4, InputTensorDescr) 

2099 

2100 

2101class OutputTensorDescr(TensorDescrBase[OutputAxis]): 

2102 id: TensorId = TensorId("output") 

2103 """Output tensor id. 

2104 No duplicates are allowed across all inputs and outputs.""" 

2105 

2106 postprocessing: List[PostprocessingDescr] = Field( 

2107 default_factory=cast(Callable[[], List[PostprocessingDescr]], list) 

2108 ) 

2109 """Description of how this output should be postprocessed. 

2110 

2111 note: `postprocessing` always ends with an 'ensure_dtype' operation. 

2112 If not specified, one is added that casts to this tensor's `data.type`. 

2113 """ 

2114 

2115 @model_validator(mode="after") 

2116 def _validate_postprocessing_kwargs(self) -> Self: 

2117 axes_ids = [a.id for a in self.axes] 

2118 for p in self.postprocessing: 

2119 kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes") 

2120 if kwargs_axes is None: 

2121 continue 

2122 

2123 if not isinstance(kwargs_axes, collections.abc.Sequence): 

2124 raise ValueError( 

2125 f"expected `axes` sequence, but got {type(kwargs_axes)}" 

2126 ) 

2127 

2128 if any(a not in axes_ids for a in kwargs_axes): 

2129 raise ValueError("`kwargs.axes` needs to be a subset of the axes ids") 

2130 

2131 if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)): 

2132 dtype = self.data.type 

2133 else: 

2134 dtype = self.data[0].type 

2135 

2136 # ensure `postprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr` 

2137 if not self.postprocessing or not isinstance( 

2138 self.postprocessing[-1], (EnsureDtypeDescr, BinarizeDescr) 

2139 ): 

2140 self.postprocessing.append( 

2141 EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype)) 

2142 ) 

2143 return self 

2144 

2145 

2146class _OutputTensorConv( 

2147 Converter[ 

2148 _OutputTensorDescr_v0_4, 

2149 OutputTensorDescr, 

2150 FileSource_, 

2151 Optional[FileSource_], 

2152 Mapping[_TensorName_v0_4, Mapping[str, int]], 

2153 ] 

2154): 

2155 def _convert( 

2156 self, 

2157 src: _OutputTensorDescr_v0_4, 

2158 tgt: "type[OutputTensorDescr] | type[dict[str, Any]]", 

2159 test_tensor: FileSource_, 

2160 sample_tensor: Optional[FileSource_], 

2161 size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]], 

2162 ) -> "OutputTensorDescr | dict[str, Any]": 

2163 # TODO: split convert_axes into convert_output_axes and convert_input_axes 

2164 axes: List[OutputAxis] = convert_axes( # pyright: ignore[reportAssignmentType] 

2165 src.axes, 

2166 shape=src.shape, 

2167 tensor_type="output", 

2168 halo=src.halo, 

2169 size_refs=size_refs, 

2170 ) 

2171 data_descr: Dict[str, Any] = dict(type=src.data_type) 

2172 if data_descr["type"] == "bool": 

2173 data_descr["values"] = [False, True] 

2174 

2175 return tgt( 

2176 axes=axes, 

2177 id=TensorId(str(src.name)), 

2178 test_tensor=FileDescr(source=test_tensor), 

2179 sample_tensor=( 

2180 None if sample_tensor is None else FileDescr(source=sample_tensor) 

2181 ), 

2182 data=data_descr, # pyright: ignore[reportArgumentType] 

2183 postprocessing=[_convert_proc(p, src.axes) for p in src.postprocessing], 

2184 ) 

2185 

2186 

2187_output_tensor_conv = _OutputTensorConv(_OutputTensorDescr_v0_4, OutputTensorDescr) 

2188 

2189 

2190TensorDescr = Union[InputTensorDescr, OutputTensorDescr] 

2191 

2192 

2193def validate_tensors( 

2194 tensors: Mapping[TensorId, Tuple[TensorDescr, Optional[NDArray[Any]]]], 

2195 tensor_origin: Literal[ 

2196 "test_tensor" 

2197 ], # for more precise error messages, e.g. 'test_tensor' 

2198): 

2199 all_tensor_axes: Dict[TensorId, Dict[AxisId, Tuple[AnyAxis, Optional[int]]]] = {} 

2200 

2201 def e_msg(d: TensorDescr): 

2202 return f"{'inputs' if isinstance(d, InputTensorDescr) else 'outputs'}[{d.id}]" 

2203 

2204 for descr, array in tensors.values(): 

2205 if array is None: 

2206 axis_sizes = {a.id: None for a in descr.axes} 

2207 else: 

2208 try: 

2209 axis_sizes = descr.get_axis_sizes_for_array(array) 

2210 except ValueError as e: 

2211 raise ValueError(f"{e_msg(descr)} {e}") 

2212 

2213 all_tensor_axes[descr.id] = {a.id: (a, axis_sizes[a.id]) for a in descr.axes} 

2214 

2215 for descr, array in tensors.values(): 

2216 if array is None: 

2217 continue 

2218 

2219 if descr.dtype in ("float32", "float64"): 

2220 invalid_test_tensor_dtype = array.dtype.name not in ( 

2221 "float32", 

2222 "float64", 

2223 "uint8", 

2224 "int8", 

2225 "uint16", 

2226 "int16", 

2227 "uint32", 

2228 "int32", 

2229 "uint64", 

2230 "int64", 

2231 ) 

2232 else: 

2233 invalid_test_tensor_dtype = array.dtype.name != descr.dtype 

2234 

2235 if invalid_test_tensor_dtype: 

2236 raise ValueError( 

2237 f"{e_msg(descr)}.{tensor_origin}.dtype '{array.dtype.name}' does not" 

2238 + f" match described dtype '{descr.dtype}'" 

2239 ) 

2240 

2241 if array.min() > -1e-4 and array.max() < 1e-4: 

2242 raise ValueError( 

2243 "Output values are too small for reliable testing." 

2244 + f" Values <-1e5 or >=1e5 must be present in {tensor_origin}" 

2245 ) 

2246 

2247 for a in descr.axes: 

2248 actual_size = all_tensor_axes[descr.id][a.id][1] 

2249 if actual_size is None: 

2250 continue 

2251 

2252 if a.size is None: 

2253 continue 

2254 

2255 if isinstance(a.size, int): 

2256 if actual_size != a.size: 

2257 raise ValueError( 

2258 f"{e_msg(descr)}.{tensor_origin}: axis '{a.id}' " 

2259 + f"has incompatible size {actual_size}, expected {a.size}" 

2260 ) 

2261 elif isinstance(a.size, ParameterizedSize): 

2262 _ = a.size.validate_size(actual_size) 

2263 elif isinstance(a.size, DataDependentSize): 

2264 _ = a.size.validate_size(actual_size) 

2265 elif isinstance(a.size, SizeReference): 

2266 ref_tensor_axes = all_tensor_axes.get(a.size.tensor_id) 

2267 if ref_tensor_axes is None: 

2268 raise ValueError( 

2269 f"{e_msg(descr)}.axes[{a.id}].size.tensor_id: Unknown tensor" 

2270 + f" reference '{a.size.tensor_id}'" 

2271 ) 

2272 

2273 ref_axis, ref_size = ref_tensor_axes.get(a.size.axis_id, (None, None)) 

2274 if ref_axis is None or ref_size is None: 

2275 raise ValueError( 

2276 f"{e_msg(descr)}.axes[{a.id}].size.axis_id: Unknown tensor axis" 

2277 + f" reference '{a.size.tensor_id}.{a.size.axis_id}" 

2278 ) 

2279 

2280 if a.unit != ref_axis.unit: 

2281 raise ValueError( 

2282 f"{e_msg(descr)}.axes[{a.id}].size: `SizeReference` requires" 

2283 + " axis and reference axis to have the same `unit`, but" 

2284 + f" {a.unit}!={ref_axis.unit}" 

2285 ) 

2286 

2287 if actual_size != ( 

2288 expected_size := ( 

2289 ref_size * ref_axis.scale / a.scale + a.size.offset 

2290 ) 

2291 ): 

2292 raise ValueError( 

2293 f"{e_msg(descr)}.{tensor_origin}: axis '{a.id}' of size" 

2294 + f" {actual_size} invalid for referenced size {ref_size};" 

2295 + f" expected {expected_size}" 

2296 ) 

2297 else: 

2298 assert_never(a.size) 
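
# Worked example for the `SizeReference` branch above: with a referenced 
# axis of size 128 and scale 1.0, a dependent axis with scale 2.0 and 
# offset 0 must have size 128 * 1.0 / 2.0 + 0 == 64; any other actual size 
# raises the ValueError above. 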

2299 

2300 

2301FileDescr_dependencies = Annotated[ 

2302 FileDescr_, 

2303 WithSuffix((".yaml", ".yml"), case_sensitive=True), 

2304 Field(examples=[dict(source="environment.yaml")]), 

2305] 

2306 

2307 

2308class _ArchitectureCallableDescr(Node): 

2309 callable: Annotated[Identifier, Field(examples=["MyNetworkClass", "get_my_model"])] 

2310 """Identifier of the callable that returns a torch.nn.Module instance.""" 

2311 

2312 kwargs: Dict[str, YamlValue] = Field( 

2313 default_factory=cast(Callable[[], Dict[str, YamlValue]], dict) 

2314 ) 

2315 """key word arguments for the `callable`""" 

2316 

2317 

2318class ArchitectureFromFileDescr(_ArchitectureCallableDescr, FileDescr): 

2319 source: Annotated[FileSource, AfterValidator(wo_special_file_name)] 

2320 """Architecture source file""" 

2321 

2322 @model_serializer(mode="wrap", when_used="unless-none") 

2323 def _serialize(self, nxt: SerializerFunctionWrapHandler, info: SerializationInfo): 

2324 return package_file_descr_serializer(self, nxt, info) 

2325 

2326 

2327class ArchitectureFromLibraryDescr(_ArchitectureCallableDescr): 

2328 import_from: str 

2329 """Where to import the callable from, i.e. `from <import_from> import <callable>`""" 

2330 

2331 

2332class _ArchFileConv( 

2333 Converter[ 

2334 _CallableFromFile_v0_4, 

2335 ArchitectureFromFileDescr, 

2336 Optional[Sha256], 

2337 Dict[str, Any], 

2338 ] 

2339): 

2340 def _convert( 

2341 self, 

2342 src: _CallableFromFile_v0_4, 

2343 tgt: "type[ArchitectureFromFileDescr | dict[str, Any]]", 

2344 sha256: Optional[Sha256], 

2345 kwargs: Dict[str, Any], 

2346 ) -> "ArchitectureFromFileDescr | dict[str, Any]": 

2347 if src.startswith("http") and src.count(":") == 2: 

2348 http, source, callable_ = src.split(":") 

2349 source = ":".join((http, source)) 

2350 elif not src.startswith("http") and src.count(":") == 1: 

2351 source, callable_ = src.split(":") 

2352 else: 

2353 source = str(src) 

2354 callable_ = str(src) 

2355 return tgt( 

2356 callable=Identifier(callable_), 

2357 source=cast(FileSource_, source), 

2358 sha256=sha256, 

2359 kwargs=kwargs, 

2360 ) 

2361 

2362 

2363_arch_file_conv = _ArchFileConv(_CallableFromFile_v0_4, ArchitectureFromFileDescr) 
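
# Example of the "source:callable" notation handled above (URL and file 
# name are placeholders); a second ":" only occurs in the scheme of an 
# http(s) URL: 
# 
#     "https://example.com/arch.py:MyNet" -> source="https://example.com/arch.py", 
#                                            callable="MyNet" 
#     "models/arch.py:MyNet"              -> source="models/arch.py", 
#                                            callable="MyNet" 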

2364 

2365 

2366class _ArchLibConv( 

2367 Converter[ 

2368 _CallableFromDepencency_v0_4, ArchitectureFromLibraryDescr, Dict[str, Any] 

2369 ] 

2370): 

2371 def _convert( 

2372 self, 

2373 src: _CallableFromDepencency_v0_4, 

2374 tgt: "type[ArchitectureFromLibraryDescr | dict[str, Any]]", 

2375 kwargs: Dict[str, Any], 

2376 ) -> "ArchitectureFromLibraryDescr | dict[str, Any]": 

2377 *mods, callable_ = src.split(".") 

2378 import_from = ".".join(mods) 

2379 return tgt( 

2380 import_from=import_from, callable=Identifier(callable_), kwargs=kwargs 

2381 ) 

2382 

2383 

2384_arch_lib_conv = _ArchLibConv( 

2385 _CallableFromDepencency_v0_4, ArchitectureFromLibraryDescr 

2386) 
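
# Example: the dotted v0.4 import path is split at its last "." (names are 
# placeholders): 
# 
#     "mypackage.submodule.get_model" -> import_from="mypackage.submodule", 
#                                        callable="get_model" 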

2387 

2388 

2389class WeightsEntryDescrBase(FileDescr): 

2390 type: ClassVar[WeightsFormat] 

2391 weights_format_name: ClassVar[str] # human readable 

2392 

2393 source: Annotated[FileSource, AfterValidator(wo_special_file_name)] 

2394 """Source of the weights file.""" 

2395 

2396 authors: Optional[List[Author]] = None 

2397 """Authors 

2398 Either the person(s) that have trained this model resulting in the original weights file. 

2399 (If this is the initial weights entry, i.e. it does not have a `parent`) 

2400 Or the person(s) who have converted the weights to this weights format. 

2401 (If this is a child weight, i.e. it has a `parent` field) 

2402 """ 

2403 

2404 parent: Annotated[ 

2405 Optional[WeightsFormat], Field(examples=["pytorch_state_dict"]) 

2406 ] = None 

2407 """The source weights these weights were converted from. 

2408 For example, if a model's weights were converted from the `pytorch_state_dict` format to `torchscript`, 

2409 the `pytorch_state_dict` weights entry has no `parent` and is the parent of the `torchscript` weights. 

2410 All weight entries except one (the initial set of weights resulting from training the model) 

2411 need to have this field.""" 

2412 

2413 comment: str = "" 

2414 """A comment about this weights entry, for example how these weights were created.""" 

2415 

2416 @model_validator(mode="after") 

2417 def _validate(self) -> Self: 

2418 if self.type == self.parent: 

2419 raise ValueError("Weights entry can't be it's own parent.") 

2420 

2421 return self 

2422 

2423 @model_serializer(mode="wrap", when_used="unless-none") 

2424 def _serialize(self, nxt: SerializerFunctionWrapHandler, info: SerializationInfo): 

2425 return package_file_descr_serializer(self, nxt, info) 

2426 

2427 

2428class KerasHdf5WeightsDescr(WeightsEntryDescrBase): 

2429 type: ClassVar[WeightsFormat] = "keras_hdf5" 

2430 weights_format_name: ClassVar[str] = "Keras HDF5" 

2431 tensorflow_version: Version 

2432 """TensorFlow version used to create these weights.""" 

2433 

2434 

2435FileDescr_external_data = Annotated[ 

2436 FileDescr_, 

2437 WithSuffix(".data", case_sensitive=True), 

2438 Field(examples=[dict(source="weights.onnx.data")]), 

2439] 

2440 

2441 

2442class OnnxWeightsDescr(WeightsEntryDescrBase): 

2443 type: ClassVar[WeightsFormat] = "onnx" 

2444 weights_format_name: ClassVar[str] = "ONNX" 

2445 opset_version: Annotated[int, Ge(7)] 

2446 """ONNX opset version""" 

2447 

2448 external_data: Optional[FileDescr_external_data] = None 

2449 """Source of the external ONNX data file holding the weights. 

2450 (If present, **source** holds the ONNX architecture without weights.)""" 

2451 

2452 @model_validator(mode="after") 

2453 def _validate_external_data_unique_file_name(self) -> Self: 

2454 if self.external_data is not None and ( 

2455 extract_file_name(self.source) 

2456 == extract_file_name(self.external_data.source) 

2457 ): 

2458 raise ValueError( 

2459 f"ONNX `external_data` file name '{extract_file_name(self.external_data.source)}'" 

2460 + " must be different from ONNX `source` file name." 

2461 ) 

2462 

2463 return self 

2464 

2465 

2466class PytorchStateDictWeightsDescr(WeightsEntryDescrBase): 

2467 type: ClassVar[WeightsFormat] = "pytorch_state_dict" 

2468 weights_format_name: ClassVar[str] = "Pytorch State Dict" 

2469 architecture: Union[ArchitectureFromFileDescr, ArchitectureFromLibraryDescr] 

2470 pytorch_version: Version 

2471 """Version of the PyTorch library used. 

2472 If `dependencies` is specified it has to include pytorch, and any version pinning has to be compatible. 

2473 """ 

2474 dependencies: Optional[FileDescr_dependencies] = None 

2475 """Custom depencies beyond pytorch described in a Conda environment file. 

2476 Allows to specify custom dependencies, see conda docs: 

2477 - [Exporting an environment file across platforms](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#exporting-an-environment-file-across-platforms) 

2478 - [Creating an environment file manually](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#creating-an-environment-file-manually) 

2479 

2480 The conda environment file should include pytorch, and any version pinning has to be compatible with 

2481 **pytorch_version**. 

2482 """ 

2483 

2484 

2485class TensorflowJsWeightsDescr(WeightsEntryDescrBase): 

2486 type: ClassVar[WeightsFormat] = "tensorflow_js" 

2487 weights_format_name: ClassVar[str] = "Tensorflow.js" 

2488 tensorflow_version: Version 

2489 """Version of the TensorFlow library used.""" 

2490 

2491 source: Annotated[FileSource, AfterValidator(wo_special_file_name)] 

2492 """The multi-file weights. 

2493 All required files/folders should be bundled in a zip archive.""" 

2494 

2495 

2496class TensorflowSavedModelBundleWeightsDescr(WeightsEntryDescrBase): 

2497 type: ClassVar[WeightsFormat] = "tensorflow_saved_model_bundle" 

2498 weights_format_name: ClassVar[str] = "Tensorflow Saved Model" 

2499 tensorflow_version: Version 

2500 """Version of the TensorFlow library used.""" 

2501 

2502 dependencies: Optional[FileDescr_dependencies] = None 

2503 """Custom dependencies beyond tensorflow. 

2504 Should include tensorflow, and any version pinning has to be compatible with **tensorflow_version**.""" 

2505 

2506 source: Annotated[FileSource, AfterValidator(wo_special_file_name)] 

2507 """The multi-file weights. 

2508 All required files/folders should be bundled in a zip archive.""" 

2509 

2510 

2511class TorchscriptWeightsDescr(WeightsEntryDescrBase): 

2512 type: ClassVar[WeightsFormat] = "torchscript" 

2513 weights_format_name: ClassVar[str] = "TorchScript" 

2514 pytorch_version: Version 

2515 """Version of the PyTorch library used.""" 

2516 

2517 

2518SpecificWeightsDescr = Union[ 

2519 KerasHdf5WeightsDescr, 

2520 OnnxWeightsDescr, 

2521 PytorchStateDictWeightsDescr, 

2522 TensorflowJsWeightsDescr, 

2523 TensorflowSavedModelBundleWeightsDescr, 

2524 TorchscriptWeightsDescr, 

2525] 

2526 

2527 

2528class WeightsDescr(Node): 

2529 keras_hdf5: Optional[KerasHdf5WeightsDescr] = None 

2530 onnx: Optional[OnnxWeightsDescr] = None 

2531 pytorch_state_dict: Optional[PytorchStateDictWeightsDescr] = None 

2532 tensorflow_js: Optional[TensorflowJsWeightsDescr] = None 

2533 tensorflow_saved_model_bundle: Optional[TensorflowSavedModelBundleWeightsDescr] = ( 

2534 None 

2535 ) 

2536 torchscript: Optional[TorchscriptWeightsDescr] = None 

2537 

2538 @model_validator(mode="after") 

2539 def check_entries(self) -> Self: 

2540 entries = {wtype for wtype, entry in self if entry is not None} 

2541 

2542 if not entries: 

2543 raise ValueError("Missing weights entry") 

2544 

2545 entries_wo_parent = { 

2546 wtype 

2547 for wtype, entry in self 

2548 if entry is not None and hasattr(entry, "parent") and entry.parent is None 

2549 } 

2550 if len(entries_wo_parent) != 1: 

2551 issue_warning( 

2552 "Exactly one weights entry may not specify the `parent` field (got" 

2553 + " {value}). That entry is considered the original set of model weights." 

2554 + " Other weight formats are created through conversion of the orignal or" 

2555 + " already converted weights. They have to reference the weights format" 

2556 + " they were converted from as their `parent`.", 

2557 value=len(entries_wo_parent), 

2558 field="weights", 

2559 ) 

2560 

2561 for wtype, entry in self: 

2562 if entry is None: 

2563 continue 

2564 

2565 assert hasattr(entry, "type") 

2566 assert hasattr(entry, "parent") 

2567 assert wtype == entry.type 

2568 if ( 

2569 entry.parent is not None and entry.parent not in entries 

2570 ): # self reference checked for `parent` field 

2571 raise ValueError( 

2572 f"`weights.{wtype}.parent={entry.parent} not in specified weight" 

2573 + f" formats: {entries}" 

2574 ) 

2575 

2576 return self 

2577 

2578 def __getitem__( 

2579 self, 

2580 key: Literal[ 

2581 "keras_hdf5", 

2582 "onnx", 

2583 "pytorch_state_dict", 

2584 "tensorflow_js", 

2585 "tensorflow_saved_model_bundle", 

2586 "torchscript", 

2587 ], 

2588 ): 

2589 if key == "keras_hdf5": 

2590 ret = self.keras_hdf5 

2591 elif key == "onnx": 

2592 ret = self.onnx 

2593 elif key == "pytorch_state_dict": 

2594 ret = self.pytorch_state_dict 

2595 elif key == "tensorflow_js": 

2596 ret = self.tensorflow_js 

2597 elif key == "tensorflow_saved_model_bundle": 

2598 ret = self.tensorflow_saved_model_bundle 

2599 elif key == "torchscript": 

2600 ret = self.torchscript 

2601 else: 

2602 raise KeyError(key) 

2603 

2604 if ret is None: 

2605 raise KeyError(key) 

2606 

2607 return ret 

2608 

2609 @overload 

2610 def __setitem__( 

2611 self, key: Literal["keras_hdf5"], value: Optional[KerasHdf5WeightsDescr] 

2612 ) -> None: ... 

2613 @overload 

2614 def __setitem__( 

2615 self, key: Literal["onnx"], value: Optional[OnnxWeightsDescr] 

2616 ) -> None: ... 

2617 @overload 

2618 def __setitem__( 

2619 self, 

2620 key: Literal["pytorch_state_dict"], 

2621 value: Optional[PytorchStateDictWeightsDescr], 

2622 ) -> None: ... 

2623 @overload 

2624 def __setitem__( 

2625 self, key: Literal["tensorflow_js"], value: Optional[TensorflowJsWeightsDescr] 

2626 ) -> None: ... 

2627 @overload 

2628 def __setitem__( 

2629 self, 

2630 key: Literal["tensorflow_saved_model_bundle"], 

2631 value: Optional[TensorflowSavedModelBundleWeightsDescr], 

2632 ) -> None: ... 

2633 @overload 

2634 def __setitem__( 

2635 self, key: Literal["torchscript"], value: Optional[TorchscriptWeightsDescr] 

2636 ) -> None: ... 

2637 

2638 def __setitem__( 

2639 self, 

2640 key: Literal[ 

2641 "keras_hdf5", 

2642 "onnx", 

2643 "pytorch_state_dict", 

2644 "tensorflow_js", 

2645 "tensorflow_saved_model_bundle", 

2646 "torchscript", 

2647 ], 

2648 value: Optional[SpecificWeightsDescr], 

2649 ): 

2650 if key == "keras_hdf5": 

2651 if value is not None and not isinstance(value, KerasHdf5WeightsDescr): 

2652 raise TypeError( 

2653 f"Expected KerasHdf5WeightsDescr or None for key 'keras_hdf5', got {type(value)}" 

2654 ) 

2655 self.keras_hdf5 = value 

2656 elif key == "onnx": 

2657 if value is not None and not isinstance(value, OnnxWeightsDescr): 

2658 raise TypeError( 

2659 f"Expected OnnxWeightsDescr or None for key 'onnx', got {type(value)}" 

2660 ) 

2661 self.onnx = value 

2662 elif key == "pytorch_state_dict": 

2663 if value is not None and not isinstance( 

2664 value, PytorchStateDictWeightsDescr 

2665 ): 

2666 raise TypeError( 

2667 f"Expected PytorchStateDictWeightsDescr or None for key 'pytorch_state_dict', got {type(value)}" 

2668 ) 

2669 self.pytorch_state_dict = value 

2670 elif key == "tensorflow_js": 

2671 if value is not None and not isinstance(value, TensorflowJsWeightsDescr): 

2672 raise TypeError( 

2673 f"Expected TensorflowJsWeightsDescr or None for key 'tensorflow_js', got {type(value)}" 

2674 ) 

2675 self.tensorflow_js = value 

2676 elif key == "tensorflow_saved_model_bundle": 

2677 if value is not None and not isinstance( 

2678 value, TensorflowSavedModelBundleWeightsDescr 

2679 ): 

2680 raise TypeError( 

2681 f"Expected TensorflowSavedModelBundleWeightsDescr or None for key 'tensorflow_saved_model_bundle', got {type(value)}" 

2682 ) 

2683 self.tensorflow_saved_model_bundle = value 

2684 elif key == "torchscript": 

2685 if value is not None and not isinstance(value, TorchscriptWeightsDescr): 

2686 raise TypeError( 

2687 f"Expected TorchscriptWeightsDescr or None for key 'torchscript', got {type(value)}" 

2688 ) 

2689 self.torchscript = value 

2690 else: 

2691 raise KeyError(key) 

2692 

2693 @property 

2694 def available_formats(self) -> Dict[WeightsFormat, SpecificWeightsDescr]: 

2695 return { 

2696 **({} if self.keras_hdf5 is None else {"keras_hdf5": self.keras_hdf5}), 

2697 **({} if self.onnx is None else {"onnx": self.onnx}), 

2698 **( 

2699 {} 

2700 if self.pytorch_state_dict is None 

2701 else {"pytorch_state_dict": self.pytorch_state_dict} 

2702 ), 

2703 **( 

2704 {} 

2705 if self.tensorflow_js is None 

2706 else {"tensorflow_js": self.tensorflow_js} 

2707 ), 

2708 **( 

2709 {} 

2710 if self.tensorflow_saved_model_bundle is None 

2711 else { 

2712 "tensorflow_saved_model_bundle": self.tensorflow_saved_model_bundle 

2713 } 

2714 ), 

2715 **({} if self.torchscript is None else {"torchscript": self.torchscript}), 

2716 } 

2717 

2718 @property 

2719 def missing_formats(self) -> Set[WeightsFormat]: 

2720 return { 

2721 wf for wf in get_args(WeightsFormat) if wf not in self.available_formats 

2722 } 

2723 

2724 

2725class ModelId(ResourceId): 

2726 pass 

2727 

2728 

2729class LinkedModel(LinkedResourceBase): 

2730 """Reference to a bioimage.io model.""" 

2731 

2732 id: ModelId 

2733 """A valid model `id` from the bioimage.io collection.""" 

2734 

2735 

2736class _DataDepSize(NamedTuple): 

2737 min: StrictInt 

2738 max: Optional[StrictInt] 

2739 

2740 

2741class _AxisSizes(NamedTuple): 

2742 """the lenghts of all axes of model inputs and outputs""" 

2743 

2744 inputs: Dict[Tuple[TensorId, AxisId], int] 

2745 outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]] 

2746 

2747 

2748class _TensorSizes(NamedTuple): 

2749 """_AxisSizes as nested dicts""" 

2750 

2751 inputs: Dict[TensorId, Dict[AxisId, int]] 

2752 outputs: Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]] 

2753 

2754 

2755class ReproducibilityTolerance(Node, extra="allow"): 

2756 """Describes what small numerical differences -- if any -- may be tolerated 

2757 in the generated output when executing in different environments. 

2758 

2759 A tensor element *output* is considered mismatched to the **test_tensor** if 

2760 abs(*output* - **test_tensor**) > **absolute_tolerance** + **relative_tolerance** * abs(**test_tensor**). 

2761 (Internally we call [numpy.testing.assert_allclose](https://numpy.org/doc/stable/reference/generated/numpy.testing.assert_allclose.html).) 

2762 

2763 Motivation: 

2764 For testing we can request the respective deep learning frameworks to be as 

2765 reproducible as possible by setting seeds and choosing deterministic algorithms, 

2766 but differences in operating systems, available hardware and installed drivers 

2767 may still lead to numerical differences. 

2768 """ 

2769 

2770 relative_tolerance: RelativeTolerance = 1e-3 

2771 """Maximum relative tolerance of reproduced test tensor.""" 

2772 

2773 absolute_tolerance: AbsoluteTolerance = 1e-3 

2774 """Maximum absolute tolerance of reproduced test tensor.""" 

2775 

2776 mismatched_elements_per_million: MismatchedElementsPerMillion = 100 

2777 """Maximum number of mismatched elements/pixels per million to tolerate.""" 

2778 

2779 output_ids: Sequence[TensorId] = () 

2780 """Limits the output tensor IDs these reproducibility details apply to.""" 

2781 

2782 weights_formats: Sequence[WeightsFormat] = () 

2783 """Limits the weights formats these details apply to.""" 

2784 

2785 

2786class BioimageioConfig(Node, extra="allow"): 

2787 reproducibility_tolerance: Sequence[ReproducibilityTolerance] = () 

2788 """Tolerances to allow when reproducing the model's test outputs 

2789 from the model's test inputs. 

2790 Only the first entry matching tensor id and weights format is considered. 

2791 """ 

2792 

2793 

2794class Config(Node, extra="allow"): 

2795 bioimageio: BioimageioConfig = Field( 

2796 default_factory=BioimageioConfig.model_construct 

2797 ) 

2798 

2799 

2800class ModelDescr(GenericModelDescrBase): 

2801 """Specification of the fields used in a bioimage.io-compliant RDF to describe AI models with pretrained weights. 

2802 These fields are typically stored in a YAML file which we call a model resource description file (model RDF). 

2803 """ 

2804 

2805 implemented_format_version: ClassVar[Literal["0.5.7"]] = "0.5.7" 

2806 if TYPE_CHECKING: 

2807 format_version: Literal["0.5.7"] = "0.5.7" 

2808 else: 

2809 format_version: Literal["0.5.7"] 

2810 """Version of the bioimage.io model description specification used. 

2811 When creating a new model always use the latest micro/patch version described here. 

2812 The `format_version` is important for any consumer software to understand how to parse the fields. 

2813 """ 

2814 

2815 implemented_type: ClassVar[Literal["model"]] = "model" 

2816 if TYPE_CHECKING: 

2817 type: Literal["model"] = "model" 

2818 else: 

2819 type: Literal["model"] 

2820 """Specialized resource type 'model'""" 

2821 

2822 id: Optional[ModelId] = None 

2823 """bioimage.io-wide unique resource identifier 

2824 assigned by bioimage.io; version **un**specific.""" 

2825 

2826 authors: FAIR[List[Author]] = Field( 

2827 default_factory=cast(Callable[[], List[Author]], list) 

2828 ) 

2829 """The authors are the creators of the model RDF and the primary points of contact.""" 

2830 

2831 documentation: FAIR[Optional[FileSource_documentation]] = None 

2832 """URL or relative path to a markdown file with additional documentation. 

2833 The recommended documentation file name is `README.md`. An `.md` suffix is mandatory. 

2834 The documentation should include a '#[#] Validation' (sub)section 

2835 with details on how to quantitatively validate the model on unseen data.""" 

2836 

2837 @field_validator("documentation", mode="after") 

2838 @classmethod 

2839 def _validate_documentation( 

2840 cls, value: Optional[FileSource_documentation] 

2841 ) -> Optional[FileSource_documentation]: 

2842 if not get_validation_context().perform_io_checks or value is None: 

2843 return value 

2844 

2845 doc_reader = get_reader(value) 

2846 doc_content = doc_reader.read().decode(encoding="utf-8") 

2847 if not re.search("#.*[vV]alidation", doc_content): 

2848 issue_warning( 

2849 "No '# Validation' (sub)section found in {value}.", 

2850 value=value, 

2851 field="documentation", 

2852 ) 

2853 

2854 return value 

2855 

2856 inputs: NotEmpty[Sequence[InputTensorDescr]] 

2857 """Describes the input tensors expected by this model.""" 

2858 

2859 @field_validator("inputs", mode="after") 

2860 @classmethod 

2861 def _validate_input_axes( 

2862 cls, inputs: Sequence[InputTensorDescr] 

2863 ) -> Sequence[InputTensorDescr]: 

2864 input_size_refs = cls._get_axes_with_independent_size(inputs) 

2865 

2866 for i, ipt in enumerate(inputs): 

2867 valid_independent_refs: Dict[ 

2868 Tuple[TensorId, AxisId], 

2869 Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]], 

2870 ] = { 

2871 **{ 

2872 (ipt.id, a.id): (ipt, a, a.size) 

2873 for a in ipt.axes 

2874 if not isinstance(a, BatchAxis) 

2875 and isinstance(a.size, (int, ParameterizedSize)) 

2876 }, 

2877 **input_size_refs, 

2878 } 

2879 for a, ax in enumerate(ipt.axes): 

2880 cls._validate_axis( 

2881 "inputs", 

2882 i=i, 

2883 tensor_id=ipt.id, 

2884 a=a, 

2885 axis=ax, 

2886 valid_independent_refs=valid_independent_refs, 

2887 ) 

2888 return inputs 

2889 

2890 @staticmethod 

2891 def _validate_axis( 

2892 field_name: str, 

2893 i: int, 

2894 tensor_id: TensorId, 

2895 a: int, 

2896 axis: AnyAxis, 

2897 valid_independent_refs: Dict[ 

2898 Tuple[TensorId, AxisId], 

2899 Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]], 

2900 ], 

2901 ): 

2902 if isinstance(axis, BatchAxis) or isinstance( 

2903 axis.size, (int, ParameterizedSize, DataDependentSize) 

2904 ): 

2905 return 

2906 elif not isinstance(axis.size, SizeReference): 

2907 assert_never(axis.size) 

2908 

2909 # validate axis.size SizeReference 

2910 ref = (axis.size.tensor_id, axis.size.axis_id) 

2911 if ref not in valid_independent_refs: 

2912 raise ValueError( 

2913 "Invalid tensor axis reference at" 

2914 + f" {field_name}[{i}].axes[{a}].size: {axis.size}." 

2915 ) 

2916 if ref == (tensor_id, axis.id): 

2917 raise ValueError( 

2918 "Self-referencing not allowed for" 

2919 + f" {field_name}[{i}].axes[{a}].size: {axis.size}" 

2920 ) 

2921 if axis.type == "channel": 

2922 if valid_independent_refs[ref][1].type != "channel": 

2923 raise ValueError( 

2924 "A channel axis' size may only reference another fixed size" 

2925 + " channel axis." 

2926 ) 

2927 if isinstance(axis.channel_names, str) and "{i}" in axis.channel_names: 

2928 ref_size = valid_independent_refs[ref][2] 

2929 assert isinstance(ref_size, int), ( 

2930 "channel axis ref (another channel axis) has to specify fixed" 

2931 + " size" 

2932 ) 

2933 generated_channel_names = [ 

2934 Identifier(axis.channel_names.format(i=i)) 

2935 for i in range(1, ref_size + 1) 

2936 ] 

2937 axis.channel_names = generated_channel_names 

2938 

2939 if (ax_unit := getattr(axis, "unit", None)) != ( 

2940 ref_unit := getattr(valid_independent_refs[ref][1], "unit", None) 

2941 ): 

2942 raise ValueError( 

2943 "The units of an axis and its reference axis need to match, but" 

2944 + f" '{ax_unit}' != '{ref_unit}'." 

2945 ) 

2946 ref_axis = valid_independent_refs[ref][1] 

2947 if isinstance(ref_axis, BatchAxis): 

2948 raise ValueError( 

2949 f"Invalid reference axis '{ref_axis.id}' for {tensor_id}.{axis.id}" 

2950 + " (a batch axis is not allowed as reference)." 

2951 ) 

2952 

2953 if isinstance(axis, WithHalo): 

2954 min_size = axis.size.get_size(axis, ref_axis, n=0) 

2955 if (min_size - 2 * axis.halo) < 1: 

2956 raise ValueError( 

2957 f"axis {axis.id} with minimum size {min_size} is too small for halo" 

2958 + f" {axis.halo}." 

2959 ) 

2960 

2961 ref_halo = axis.halo * axis.scale / ref_axis.scale 

2962 if ref_halo != int(ref_halo): 

2963 raise ValueError( 

2964 f"Inferred halo for {'.'.join(ref)} is not an integer ({ref_halo} =" 

2965 + f" {tensor_id}.{axis.id}.halo {axis.halo}" 

2966 + f" * {tensor_id}.{axis.id}.scale {axis.scale}" 

2967 + f" / {'.'.join(ref)}.scale {ref_axis.scale})." 

2968 ) 

2969 

2970 @model_validator(mode="after") 

2971 def _validate_test_tensors(self) -> Self: 

2972 if not get_validation_context().perform_io_checks: 

2973 return self 

2974 

2975 test_output_arrays = [ 

2976 None if descr.test_tensor is None else load_array(descr.test_tensor) 

2977 for descr in self.outputs 

2978 ] 

2979 test_input_arrays = [ 

2980 None if descr.test_tensor is None else load_array(descr.test_tensor) 

2981 for descr in self.inputs 

2982 ] 

2983 

2984 tensors = { 

2985 descr.id: (descr, array) 

2986 for descr, array in zip( 

2987 chain(self.inputs, self.outputs), test_input_arrays + test_output_arrays 

2988 ) 

2989 } 

2990 validate_tensors(tensors, tensor_origin="test_tensor") 

2991 

2992 output_arrays = { 

2993 descr.id: array for descr, array in zip(self.outputs, test_output_arrays) 

2994 } 

2995 for rep_tol in self.config.bioimageio.reproducibility_tolerance: 

2996 if not rep_tol.absolute_tolerance: 

2997 continue 

2998 

2999 if rep_tol.output_ids: 

3000 out_arrays = { 

3001 oid: a 

3002 for oid, a in output_arrays.items() 

3003 if oid in rep_tol.output_ids 

3004 } 

3005 else: 

3006 out_arrays = output_arrays 

3007 

3008 for out_id, array in out_arrays.items(): 

3009 if array is None: 

3010 continue 

3011 

3012 if rep_tol.absolute_tolerance > (max_test_value := array.max()) * 0.01: 

3013 raise ValueError( 

3014 "config.bioimageio.reproducibility_tolerance.absolute_tolerance=" 

3015 + f"{rep_tol.absolute_tolerance} > 0.01*{max_test_value}" 

3016 + f" (1% of the maximum value of the test tensor '{out_id}')" 

3017 ) 

3018 

3019 return self 

3020 

3021 @model_validator(mode="after") 

3022 def _validate_tensor_references_in_proc_kwargs(self, info: ValidationInfo) -> Self: 

3023 ipt_refs = {t.id for t in self.inputs} 

3024 out_refs = {t.id for t in self.outputs} 

3025 for ipt in self.inputs: 

3026 for p in ipt.preprocessing: 

3027 ref = p.kwargs.get("reference_tensor") 

3028 if ref is None: 

3029 continue 

3030 if ref not in ipt_refs: 

3031 raise ValueError( 

3032 f"`reference_tensor` '{ref}' not found. Valid input tensor" 

3033 + f" references are: {ipt_refs}." 

3034 ) 

3035 

3036 for out in self.outputs: 

3037 for p in out.postprocessing: 

3038 ref = p.kwargs.get("reference_tensor") 

3039 if ref is None: 

3040 continue 

3041 

3042 if ref not in ipt_refs and ref not in out_refs: 

3043 raise ValueError( 

3044 f"`reference_tensor` '{ref}' not found. Valid tensor references" 

3045 + f" are: {ipt_refs | out_refs}." 

3046 ) 

3047 

3048 return self 

3049 

3050 # TODO: use validate funcs in validate_test_tensors 

3051 # def validate_inputs(self, input_tensors: Mapping[TensorId, NDArray[Any]]) -> Mapping[TensorId, NDArray[Any]]: 

3052 

3053 name: Annotated[ 

3054 str, 

3055 RestrictCharacters(string.ascii_letters + string.digits + "_+- ()"), 

3056 MinLen(5), 

3057 MaxLen(128), 

3058 warn(MaxLen(64), "Name longer than 64 characters.", INFO), 

3059 ] 

3060 """A human-readable name of this model. 

3061 It should be no longer than 64 characters 

3062 and may only contain letters, numbers, underscores, minus signs, parentheses and spaces. 

3063 We recommend choosing a name that refers to the model's task and image modality. 

3064 """ 

3065 

3066 outputs: NotEmpty[Sequence[OutputTensorDescr]] 

3067 """Describes the output tensors.""" 

3068 

3069 @field_validator("outputs", mode="after") 

3070 @classmethod 

3071 def _validate_tensor_ids( 

3072 cls, outputs: Sequence[OutputTensorDescr], info: ValidationInfo 

3073 ) -> Sequence[OutputTensorDescr]: 

3074 tensor_ids = [ 

3075 t.id for t in info.data.get("inputs", []) + info.data.get("outputs", []) 

3076 ] 

3077 duplicate_tensor_ids: List[str] = [] 

3078 seen: Set[str] = set() 

3079 for t in tensor_ids: 

3080 if t in seen: 

3081 duplicate_tensor_ids.append(t) 

3082 

3083 seen.add(t) 

3084 

3085 if duplicate_tensor_ids: 

3086 raise ValueError(f"Duplicate tensor ids: {duplicate_tensor_ids}") 

3087 

3088 return outputs 

3089 

3090 @staticmethod 

3091 def _get_axes_with_parameterized_size( 

3092 io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]], 

3093 ): 

3094 return { 

3095 f"{t.id}.{a.id}": (t, a, a.size) 

3096 for t in io 

3097 for a in t.axes 

3098 if not isinstance(a, BatchAxis) and isinstance(a.size, ParameterizedSize) 

3099 } 

3100 

3101 @staticmethod 

3102 def _get_axes_with_independent_size( 

3103 io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]], 

3104 ): 

3105 return { 

3106 (t.id, a.id): (t, a, a.size) 

3107 for t in io 

3108 for a in t.axes 

3109 if not isinstance(a, BatchAxis) 

3110 and isinstance(a.size, (int, ParameterizedSize)) 

3111 } 

3112 

3113 @field_validator("outputs", mode="after") 

3114 @classmethod 

3115 def _validate_output_axes( 

3116 cls, outputs: List[OutputTensorDescr], info: ValidationInfo 

3117 ) -> List[OutputTensorDescr]: 

3118 input_size_refs = cls._get_axes_with_independent_size( 

3119 info.data.get("inputs", []) 

3120 ) 

3121 output_size_refs = cls._get_axes_with_independent_size(outputs) 

3122 

3123 for i, out in enumerate(outputs): 

3124 valid_independent_refs: Dict[ 

3125 Tuple[TensorId, AxisId], 

3126 Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]], 

3127 ] = { 

3128 **{ 

3129 (out.id, a.id): (out, a, a.size) 

3130 for a in out.axes 

3131 if not isinstance(a, BatchAxis) 

3132 and isinstance(a.size, (int, ParameterizedSize)) 

3133 }, 

3134 **input_size_refs, 

3135 **output_size_refs, 

3136 } 

3137 for a, ax in enumerate(out.axes): 

3138 cls._validate_axis( 

3139 "outputs", 

3140 i, 

3141 out.id, 

3142 a, 

3143 ax, 

3144 valid_independent_refs=valid_independent_refs, 

3145 ) 

3146 

3147 return outputs 

3148 

3149 packaged_by: List[Author] = Field( 

3150 default_factory=cast(Callable[[], List[Author]], list) 

3151 ) 

3152 """The persons that have packaged and uploaded this model. 

3153 Only required if those persons differ from the `authors`.""" 

3154 

3155 parent: Optional[LinkedModel] = None 

3156 """The model from which this model is derived, e.g. by fine-tuning the weights.""" 

3157 

3158 @model_validator(mode="after") 

3159 def _validate_parent_is_not_self(self) -> Self: 

3160 if self.parent is not None and self.parent.id == self.id: 

3161 raise ValueError("A model description may not reference itself as parent.") 

3162 

3163 return self 

3164 

3165 run_mode: Annotated[ 

3166 Optional[RunMode], 

3167 warn(None, "Run mode '{value}' has limited support across consumer software."), 

3168 ] = None 

3169 """Custom run mode for this model: for more complex prediction procedures like test time 

3170 data augmentation that currently cannot be expressed in the specification. 

3171 No standard run modes are defined yet.""" 

3172 

3173 timestamp: Datetime = Field(default_factory=Datetime.now) 

3174 """Timestamp in [ISO 8601](#https://en.wikipedia.org/wiki/ISO_8601) format 

3175 with a few restrictions listed [here](https://docs.python.org/3/library/datetime.html#datetime.datetime.fromisoformat). 

3176 (In Python a datetime object is valid, too).""" 

3177 

3178 training_data: Annotated[ 

3179 Union[None, LinkedDataset, DatasetDescr, DatasetDescr02], 

3180 Field(union_mode="left_to_right"), 

3181 ] = None 

3182 """The dataset used to train this model""" 

3183 

3184 weights: Annotated[WeightsDescr, WrapSerializer(package_weights)] 

3185 """The weights for this model. 

3186 Weights can be given for different formats, but should otherwise be equivalent. 

3187 The available weight formats determine which consumers can use this model.""" 

3188 

3189 config: Config = Field(default_factory=Config.model_construct) 

3190 

3191 @model_validator(mode="after") 

3192 def _add_default_cover(self) -> Self: 

3193 if not get_validation_context().perform_io_checks or self.covers: 

3194 return self 

3195 

3196 try: 

3197 generated_covers = generate_covers( 

3198 [ 

3199 (t, load_array(t.test_tensor)) 

3200 for t in self.inputs 

3201 if t.test_tensor is not None 

3202 ], 

3203 [ 

3204 (t, load_array(t.test_tensor)) 

3205 for t in self.outputs 

3206 if t.test_tensor is not None 

3207 ], 

3208 ) 

3209 except Exception as e: 

3210 issue_warning( 

3211 "Failed to generate cover image(s): {e}", 

3212 value=self.covers, 

3213 msg_context=dict(e=e), 

3214 field="covers", 

3215 ) 

3216 else: 

3217 self.covers.extend(generated_covers) 

3218 

3219 return self 

3220 

3221 def get_input_test_arrays(self) -> List[NDArray[Any]]: 

3222 return self._get_test_arrays(self.inputs) 

3223 

3224 def get_output_test_arrays(self) -> List[NDArray[Any]]: 

3225 return self._get_test_arrays(self.outputs) 

3226 

3227 @staticmethod 

3228 def _get_test_arrays( 

3229 io_descr: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]], 

3230 ): 

3231 ts: List[FileDescr] = [] 

3232 for d in io_descr: 

3233 if d.test_tensor is None: 

3234 raise ValueError( 

3235 f"Failed to get test arrays: description of '{d.id}' is missing a `test_tensor`." 

3236 ) 

3237 ts.append(d.test_tensor) 

3238 

3239 data = [load_array(t) for t in ts] 

3240 assert all(isinstance(d, np.ndarray) for d in data) 

3241 return data 

3242 

3243 @staticmethod 

3244 def get_batch_size(tensor_sizes: Mapping[TensorId, Mapping[AxisId, int]]) -> int: 

3245 batch_size = 1 

3246 tensor_with_batchsize: Optional[TensorId] = None 

3247 for tid in tensor_sizes: 

3248 for aid, s in tensor_sizes[tid].items(): 

3249 if aid != BATCH_AXIS_ID or s == 1 or s == batch_size: 

3250 continue 

3251 

3252 if batch_size != 1: 

3253 assert tensor_with_batchsize is not None 

3254 raise ValueError( 

3255 f"batch size mismatch for tensors '{tensor_with_batchsize}' ({batch_size}) and '{tid}' ({s})" 

3256 ) 

3257 

3258 batch_size = s 

3259 tensor_with_batchsize = tid 

3260 

3261 return batch_size 
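
# Example: all tensors have to agree on the batch axis size; entries of 
# size 1 are treated as compatible with any batch size: 
# 
#     ModelDescr.get_batch_size({ 
#         TensorId("raw"): {BATCH_AXIS_ID: 4, AxisId("x"): 64}, 
#         TensorId("mask"): {BATCH_AXIS_ID: 4}, 
#     }) 
#     # -> 4 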

3262 

3263 def get_output_tensor_sizes( 

3264 self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]] 

3265 ) -> Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]: 

3266 """Returns the tensor output sizes for given **input_sizes**. 

3267 The output sizes are exact only if **input_sizes** specifies a valid input shape; 

3268 otherwise they may be larger than the actual (valid) output sizes.""" 

3269 batch_size = self.get_batch_size(input_sizes) 

3270 ns = self.get_ns(input_sizes) 

3271 

3272 tensor_sizes = self.get_tensor_sizes(ns, batch_size=batch_size) 

3273 return tensor_sizes.outputs 

3274 

3275 def get_ns(self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]): 

3276 """get parameter `n` for each parameterized axis 

3277 such that the valid input size is >= the given input size""" 

3278 ret: Dict[Tuple[TensorId, AxisId], ParameterizedSize_N] = {} 

3279 axes = {t.id: {a.id: a for a in t.axes} for t in self.inputs} 

3280 for tid in input_sizes: 

3281 for aid, s in input_sizes[tid].items(): 

3282 size_descr = axes[tid][aid].size 

3283 if isinstance(size_descr, ParameterizedSize): 

3284 ret[(tid, aid)] = size_descr.get_n(s) 

3285 elif size_descr is None or isinstance(size_descr, (int, SizeReference)): 

3286 pass 

3287 else: 

3288 assert_never(size_descr) 

3289 

3290 return ret 
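
# Example (sketch, assuming `ParameterizedSize.get_n` rounds up to the next 
# valid size): a parameterized axis with min=64 and step=16 needs n=1 to 
# cover an input size of 70, because 64 + 1 * 16 == 80 >= 70. 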

3291 

3292 def get_tensor_sizes( 

3293 self, ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], batch_size: int 

3294 ) -> _TensorSizes: 

3295 axis_sizes = self.get_axis_sizes(ns, batch_size=batch_size) 

3296 return _TensorSizes( 

3297 { 

3298 t: { 

3299 aa: axis_sizes.inputs[(tt, aa)] 

3300 for tt, aa in axis_sizes.inputs 

3301 if tt == t 

3302 } 

3303 for t in {tt for tt, _ in axis_sizes.inputs} 

3304 }, 

3305 { 

3306 t: { 

3307 aa: axis_sizes.outputs[(tt, aa)] 

3308 for tt, aa in axis_sizes.outputs 

3309 if tt == t 

3310 } 

3311 for t in {tt for tt, _ in axis_sizes.outputs} 

3312 }, 

3313 ) 

3314 

3315 def get_axis_sizes( 

3316 self, 

3317 ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], 

3318 batch_size: Optional[int] = None, 

3319 *, 

3320 max_input_shape: Optional[Mapping[Tuple[TensorId, AxisId], int]] = None, 

3321 ) -> _AxisSizes: 

3322 """Determine input and output block shape for scale factors **ns** 

3323 of parameterized input sizes. 

3324 

3325 Args: 

3326 ns: Scale factor `n` for each axis (keyed by (tensor_id, axis_id)) 

3327 that is parameterized as `size = min + n * step`. 

3328 batch_size: The desired size of the batch dimension. 

3329 If given, **batch_size** overrides any batch size present in 

3330 **max_input_shape**. Defaults to 1. 

3331 max_input_shape: Limits the derived block shapes. 

3332 For each axis whose input size, parameterized by `n`, is larger 

3333 than **max_input_shape**, `n` is reduced to the minimal value `n_min` 

3334 for which this still holds. 

3335 Use this for small input samples, for large values of **ns**, 

3336 or simply whenever you know the full input shape. 

3337 

3338 Returns: 

3339 Resolved axis sizes for model inputs and outputs. 

3340 """ 

3341 max_input_shape = max_input_shape or {} 

3342 if batch_size is None: 

3343 for (_t_id, a_id), s in max_input_shape.items(): 

3344 if a_id == BATCH_AXIS_ID: 

3345 batch_size = s 

3346 break 

3347 else: 

3348 batch_size = 1 

3349 

3350 all_axes = { 

3351 t.id: {a.id: a for a in t.axes} for t in chain(self.inputs, self.outputs) 

3352 } 

3353 

3354 inputs: Dict[Tuple[TensorId, AxisId], int] = {} 

3355 outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]] = {} 

3356 

3357 def get_axis_size(a: Union[InputAxis, OutputAxis]): 

3358 if isinstance(a, BatchAxis): 

3359 if (t_descr.id, a.id) in ns: 

3360 logger.warning( 

3361 "Ignoring unexpected size increment factor (n) for batch axis" 

3362 + " of tensor '{}'.", 

3363 t_descr.id, 

3364 ) 

3365 return batch_size 

3366 elif isinstance(a.size, int): 

3367 if (t_descr.id, a.id) in ns: 

3368 logger.warning( 

3369 "Ignoring unexpected size increment factor (n) for fixed size" 

3370 + " axis '{}' of tensor '{}'.", 

3371 a.id, 

3372 t_descr.id, 

3373 ) 

3374 return a.size 

3375 elif isinstance(a.size, ParameterizedSize): 

3376 if (t_descr.id, a.id) not in ns: 

3377 raise ValueError( 

3378 "Size increment factor (n) missing for parametrized axis" 

3379 + f" '{a.id}' of tensor '{t_descr.id}'." 

3380 ) 

3381 n = ns[(t_descr.id, a.id)] 

3382 s_max = max_input_shape.get((t_descr.id, a.id)) 

3383 if s_max is not None: 

3384 n = min(n, a.size.get_n(s_max)) 

3385 

3386 return a.size.get_size(n) 

3387 

3388 elif isinstance(a.size, SizeReference): 

3389 if (t_descr.id, a.id) in ns: 

3390 logger.warning( 

3391 "Ignoring unexpected size increment factor (n) for axis '{}'" 

3392 + " of tensor '{}' with size reference.", 

3393 a.id, 

3394 t_descr.id, 

3395 ) 

3396 assert not isinstance(a, BatchAxis) 

3397 ref_axis = all_axes[a.size.tensor_id][a.size.axis_id] 

3398 assert not isinstance(ref_axis, BatchAxis) 

3399 ref_key = (a.size.tensor_id, a.size.axis_id) 

3400 ref_size = inputs.get(ref_key, outputs.get(ref_key)) 

3401 assert ref_size is not None, ref_key 

3402 assert not isinstance(ref_size, _DataDepSize), ref_key 

3403 return a.size.get_size( 

3404 axis=a, 

3405 ref_axis=ref_axis, 

3406 ref_size=ref_size, 

3407 ) 

3408 elif isinstance(a.size, DataDependentSize): 

3409 if (t_descr.id, a.id) in ns: 

3410 logger.warning( 

3411 "Ignoring unexpected increment factor (n) for data dependent" 

3412 + " size axis '{}' of tensor '{}'.", 

3413 a.id, 

3414 t_descr.id, 

3415 ) 

3416 return _DataDepSize(a.size.min, a.size.max) 

3417 else: 

3418 assert_never(a.size) 

3419 

3420 # first resolve all input sizes except those given as a `SizeReference` 

3421 for t_descr in self.inputs: 

3422 for a in t_descr.axes: 

3423 if not isinstance(a.size, SizeReference): 

3424 s = get_axis_size(a) 

3425 assert not isinstance(s, _DataDepSize) 

3426 inputs[t_descr.id, a.id] = s 

3427 

3428 # then resolve `SizeReference` input sizes, which may refer to input sizes resolved above 

3429 for t_descr in self.inputs: 

3430 for a in t_descr.axes: 

3431 if isinstance(a.size, SizeReference): 

3432 s = get_axis_size(a) 

3433 assert not isinstance(s, _DataDepSize) 

3434 inputs[t_descr.id, a.id] = s 

3435 

3436 # resolve all output axis sizes 

3437 for t_descr in self.outputs: 

3438 for a in t_descr.axes: 

3439 assert not isinstance(a.size, ParameterizedSize) 

3440 s = get_axis_size(a) 

3441 outputs[t_descr.id, a.id] = s 

3442 

3443 return _AxisSizes(inputs=inputs, outputs=outputs) 

3444 

3445 @model_validator(mode="before") 

3446 @classmethod 

3447 def _convert(cls, data: Dict[str, Any]) -> Dict[str, Any]: 

3448 cls.convert_from_old_format_wo_validation(data) 

3449 return data 

3450 

3451 @classmethod 

3452 def convert_from_old_format_wo_validation(cls, data: Dict[str, Any]) -> None: 

3453 """Convert metadata following an older format version to this classes' format 

3454 without validating the result. 

3455 """ 

3456 if ( 

3457 data.get("type") == "model" 

3458 and isinstance(fv := data.get("format_version"), str) 

3459 and fv.count(".") == 2 

3460 ): 

3461 fv_parts = fv.split(".") 

3462 if any(not p.isdigit() for p in fv_parts): 

3463 return 

3464 

3465 fv_tuple = tuple(map(int, fv_parts)) 

3466 

3467 assert cls.implemented_format_version_tuple[0:2] == (0, 5) 

3468 if fv_tuple[:2] in ((0, 3), (0, 4)): 

3469 m04 = _ModelDescr_v0_4.load(data) 

3470 if isinstance(m04, InvalidDescr): 

3471 try: 

3472 updated = _model_conv.convert_as_dict( 

3473 m04 # pyright: ignore[reportArgumentType] 

3474 ) 

3475 except Exception as e: 

3476 logger.error( 

3477 "Failed to convert from invalid model 0.4 description." 

3478 + f"\nerror: {e}" 

3479 + "\nProceeding with model 0.5 validation without conversion." 

3480 ) 

3481 updated = None 

3482 else: 

3483 updated = _model_conv.convert_as_dict(m04) 

3484 

3485 if updated is not None: 

3486 data.clear() 

3487 data.update(updated) 

3488 

3489 elif fv_tuple[:2] == (0, 5): 

3490 # bump patch version 

3491 data["format_version"] = cls.implemented_format_version 

3492 

3493 
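# Minimal runnable sketch (the dict below is a hypothetical, deliberately
# incomplete RDF): `_convert` runs before pydantic validation, so a 0.5
# description with an older patch version is bumped in place without
# touching other fields.
_example_rdf: Dict[str, Any] = {"type": "model", "format_version": "0.5.0"}
ModelDescr.convert_from_old_format_wo_validation(_example_rdf)
assert _example_rdf["format_version"] == ModelDescr.implemented_format_version
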

3494class _ModelConv(Converter[_ModelDescr_v0_4, ModelDescr]): 

3495 def _convert( 

3496 self, src: _ModelDescr_v0_4, tgt: "type[ModelDescr] | type[dict[str, Any]]" 

3497 ) -> "ModelDescr | dict[str, Any]": 

3498 name = "".join( 

3499 c if c in string.ascii_letters + string.digits + "_+- ()" else " " 

3500 for c in src.name 

3501 ) 

3502 

3503 def conv_authors(auths: Optional[Sequence[_Author_v0_4]]): 

3504 conv = ( 

3505 _author_conv.convert if TYPE_CHECKING else _author_conv.convert_as_dict 

3506 ) 

3507 return None if auths is None else [conv(a) for a in auths] 

3508 
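# at runtime the sub-converters emit plain dicts (validated later as part
# of the converted model); the typed `.convert` variants exist only to
# satisfy the static type checker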

3509 if TYPE_CHECKING: 

3510 arch_file_conv = _arch_file_conv.convert 

3511 arch_lib_conv = _arch_lib_conv.convert 

3512 else: 

3513 arch_file_conv = _arch_file_conv.convert_as_dict 

3514 arch_lib_conv = _arch_lib_conv.convert_as_dict 

3515 
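# per-tensor axis-size lookups taken from the 0.4 description; they are
# passed to the tensor converters below to resolve 0.4 implicit/relative
# shapes into explicit 0.5 sizes and size references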

3516 input_size_refs = { 

3517 ipt.name: { 

3518 a: s 

3519 for a, s in zip( 

3520 ipt.axes, 

3521 ( 

3522 ipt.shape.min 

3523 if isinstance(ipt.shape, _ParameterizedInputShape_v0_4) 

3524 else ipt.shape 

3525 ), 

3526 ) 

3527 } 

3528 for ipt in src.inputs 

3529 if ipt.shape 

3530 } 

3531 output_size_refs = { 

3532 **{ 

3533 out.name: {a: s for a, s in zip(out.axes, out.shape)} 

3534 for out in src.outputs 

3535 if not isinstance(out.shape, _ImplicitOutputShape_v0_4) 

3536 }, 

3537 **input_size_refs, 

3538 } 

3539 

3540 return tgt( 

3541 attachments=( 

3542 [] 

3543 if src.attachments is None 

3544 else [FileDescr(source=f) for f in src.attachments.files] 

3545 ), 

3546 authors=[_author_conv.convert_as_dict(a) for a in src.authors], # pyright: ignore[reportArgumentType] 

3547 cite=[{"text": c.text, "doi": c.doi, "url": c.url} for c in src.cite], # pyright: ignore[reportArgumentType] 

3548 config=src.config, # pyright: ignore[reportArgumentType] 

3549 covers=src.covers, 

3550 description=src.description, 

3551 documentation=src.documentation, 

3552 format_version="0.5.7", 

3553 git_repo=src.git_repo, # pyright: ignore[reportArgumentType] 

3554 icon=src.icon, 

3555 id=None if src.id is None else ModelId(src.id), 

3556 id_emoji=src.id_emoji, 

3557 license=src.license, # type: ignore 

3558 links=src.links, 

3559 maintainers=[_maintainer_conv.convert_as_dict(m) for m in src.maintainers], # pyright: ignore[reportArgumentType] 

3560 name=name, 

3561 tags=src.tags, 

3562 type=src.type, 

3563 uploader=src.uploader, 

3564 version=src.version, 

3565 inputs=[ # pyright: ignore[reportArgumentType] 

3566 _input_tensor_conv.convert_as_dict(ipt, tt, st, input_size_refs) 

3567 for ipt, tt, st in zip( 

3568 src.inputs, 

3569 src.test_inputs, 

3570 src.sample_inputs or [None] * len(src.test_inputs), 

3571 ) 

3572 ], 

3573 outputs=[ # pyright: ignore[reportArgumentType] 

3574 _output_tensor_conv.convert_as_dict(out, tt, st, output_size_refs) 

3575 for out, tt, st in zip( 

3576 src.outputs, 

3577 src.test_outputs, 

3578 src.sample_outputs or [None] * len(src.test_outputs), 

3579 ) 

3580 ], 

3581 parent=( 

3582 None 

3583 if src.parent is None 

3584 else LinkedModel( 

3585 id=ModelId( 

3586 str(src.parent.id) 

3587 + ( 

3588 "" 

3589 if src.parent.version_number is None 

3590 else f"/{src.parent.version_number}" 

3591 ) 

3592 ) 

3593 ) 

3594 ), 

3595 training_data=( 

3596 None 

3597 if src.training_data is None 

3598 else ( 

3599 LinkedDataset( 

3600 id=DatasetId( 

3601 str(src.training_data.id) 

3602 + ( 

3603 "" 

3604 if src.training_data.version_number is None 

3605 else f"/{src.training_data.version_number}" 

3606 ) 

3607 ) 

3608 ) 

3609 if isinstance(src.training_data, LinkedDataset02) 

3610 else src.training_data 

3611 ) 

3612 ), 

3613 packaged_by=[_author_conv.convert_as_dict(a) for a in src.packaged_by], # pyright: ignore[reportArgumentType] 

3614 run_mode=src.run_mode, 

3615 timestamp=src.timestamp, 
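# for each weights format below, `(w := ...) and (...)` evaluates to None
# when that format is absent, so missing entries stay unset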

3616 weights=(WeightsDescr if TYPE_CHECKING else dict)( 

3617 keras_hdf5=(w := src.weights.keras_hdf5) 

3618 and (KerasHdf5WeightsDescr if TYPE_CHECKING else dict)( 

3619 authors=conv_authors(w.authors), 

3620 source=w.source, 

3621 tensorflow_version=w.tensorflow_version or Version("1.15"), 

3622 parent=w.parent, 

3623 ), 

3624 onnx=(w := src.weights.onnx) 

3625 and (OnnxWeightsDescr if TYPE_CHECKING else dict)( 

3626 source=w.source, 

3627 authors=conv_authors(w.authors), 

3628 parent=w.parent, 

3629 opset_version=w.opset_version or 15, 

3630 ), 

3631 pytorch_state_dict=(w := src.weights.pytorch_state_dict) 

3632 and (PytorchStateDictWeightsDescr if TYPE_CHECKING else dict)( 

3633 source=w.source, 

3634 authors=conv_authors(w.authors), 

3635 parent=w.parent, 

3636 architecture=( 

3637 arch_file_conv( 

3638 w.architecture, 

3639 w.architecture_sha256, 

3640 w.kwargs, 

3641 ) 

3642 if isinstance(w.architecture, _CallableFromFile_v0_4) 

3643 else arch_lib_conv(w.architecture, w.kwargs) 

3644 ), 

3645 pytorch_version=w.pytorch_version or Version("1.10"), 
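# v0.4 stored conda dependencies as "conda:<file>"; strip that prefix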

3646 dependencies=( 

3647 None 

3648 if w.dependencies is None 

3649 else (FileDescr if TYPE_CHECKING else dict)( 

3650 source=cast( 

3651 FileSource, 

3652 str(deps := w.dependencies)[ 

3653 ( 

3654 len("conda:") 

3655 if str(deps).startswith("conda:") 

3656 else 0 

3657 ) : 

3658 ], 

3659 ) 

3660 ) 

3661 ), 

3662 ), 

3663 tensorflow_js=(w := src.weights.tensorflow_js) 

3664 and (TensorflowJsWeightsDescr if TYPE_CHECKING else dict)( 

3665 source=w.source, 

3666 authors=conv_authors(w.authors), 

3667 parent=w.parent, 

3668 tensorflow_version=w.tensorflow_version or Version("1.15"), 

3669 ), 

3670 tensorflow_saved_model_bundle=( 

3671 w := src.weights.tensorflow_saved_model_bundle 

3672 ) 

3673 and (TensorflowSavedModelBundleWeightsDescr if TYPE_CHECKING else dict)( 

3674 authors=conv_authors(w.authors), 

3675 parent=w.parent, 

3676 source=w.source, 

3677 tensorflow_version=w.tensorflow_version or Version("1.15"), 

3678 dependencies=( 

3679 None 

3680 if w.dependencies is None 

3681 else (FileDescr if TYPE_CHECKING else dict)( 

3682 source=cast( 

3683 FileSource, 

3684 ( 

3685 str(w.dependencies)[len("conda:") :] 

3686 if str(w.dependencies).startswith("conda:") 

3687 else str(w.dependencies) 

3688 ), 

3689 ) 

3690 ) 

3691 ), 

3692 ), 

3693 torchscript=(w := src.weights.torchscript) 

3694 and (TorchscriptWeightsDescr if TYPE_CHECKING else dict)( 

3695 source=w.source, 

3696 authors=conv_authors(w.authors), 

3697 parent=w.parent, 

3698 pytorch_version=w.pytorch_version or Version("1.10"), 

3699 ), 

3700 ), 

3701 ) 

3702 

3703 

3704_model_conv = _ModelConv(_ModelDescr_v0_4, ModelDescr) 
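# referenced (at call time) by ModelDescr.convert_from_old_format_wo_validation
# above, so this late module-level definition is safe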

3705 

3706 

3707# create better cover images for 3d data and non-image outputs 

3708def generate_covers( 

3709 inputs: Sequence[Tuple[InputTensorDescr, NDArray[Any]]], 

3710 outputs: Sequence[Tuple[OutputTensorDescr, NDArray[Any]]], 

3711) -> List[Path]: 

3712 def squeeze( 

3713 data: NDArray[Any], axes: Sequence[AnyAxis] 

3714 ) -> Tuple[NDArray[Any], List[AnyAxis]]: 

3715 """apply numpy.ndarray.squeeze while keeping track of the axis descriptions remaining""" 

3716 if data.ndim != len(axes): 

3717 raise ValueError( 

3718 f"tensor shape {data.shape} does not match described axes" 

3719 + f" {[a.id for a in axes]}" 

3720 ) 

3721 

3722 axes = [deepcopy(a) for a, s in zip(axes, data.shape) if s != 1] 

3723 return data.squeeze(), axes 

3724 

3725 def normalize( 

3726 data: NDArray[Any], axis: Optional[Tuple[int, ...]], eps: float = 1e-7 

3727 ) -> NDArray[np.float32]: 
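"""min-max normalize `data` to [0, 1) along `axis`; `eps` avoids division by zero"""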

3728 data = data.astype("float32") 

3729 data -= data.min(axis=axis, keepdims=True) 

3730 data /= data.max(axis=axis, keepdims=True) + eps 

3731 return data 

3732 

3733 def to_2d_image(data: NDArray[Any], axes: Sequence[AnyAxis]): 

3734 original_shape = data.shape 

3735 original_axes = list(axes) 

3736 data, axes = squeeze(data, axes) 

3737 

3738 # take a slice from any batch or index axis if needed, 

3739 # map the first channel axis to three (RGB) channels and take a slice from any additional channel axes 

3740 slices: Tuple[slice, ...] = () 

3741 ndim = data.ndim 

3742 ndim_need = 3 if any(isinstance(a, ChannelAxis) for a in axes) else 2 
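# aim for an HxWxC image if a channel axis is present, else HxW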

3743 has_c_axis = False 

3744 for i, a in enumerate(axes): 

3745 s = data.shape[i] 

3746 assert s > 1 

3747 if ( 

3748 isinstance(a, (BatchAxis, IndexInputAxis, IndexOutputAxis)) 

3749 and ndim > ndim_need 

3750 ): 

3751 data = data[slices + (slice(s // 2 - 1, s // 2),)] 

3752 ndim -= 1 

3753 elif isinstance(a, ChannelAxis): 

3754 if has_c_axis: 

3755 # second channel axis 

3756 data = data[slices + (slice(0, 1),)] 

3757 ndim -= 1 

3758 else: 

3759 has_c_axis = True 

3760 if s == 2: 

3761 # visualize two channels with cyan and magenta 

3762 data = np.concatenate( 

3763 [ 

3764 data[slices + (slice(1, 2),)], 

3765 data[slices + (slice(0, 1),)], 

3766 ( 

3767 data[slices + (slice(0, 1),)] 

3768 + data[slices + (slice(1, 2),)] 

3769 ) 

3770 / 2, # TODO: take maximum instead? 

3771 ], 

3772 axis=i, 

3773 ) 

3774 elif data.shape[i] == 3: 

3775 pass # visualize 3 channels as RGB 

3776 else: 

3777 # visualize first 3 channels as RGB 

3778 data = data[slices + (slice(3),)] 

3779 

3780 assert data.shape[i] == 3 

3781 

3782 slices += (slice(None),) 

3783 

3784 data, axes = squeeze(data, axes) 

3785 assert len(axes) == ndim 

3786 # take slice from z axis if needed 

3787 slices = () 

3788 if ndim > ndim_need: 

3789 for i, a in enumerate(axes): 

3790 s = data.shape[i] 

3791 if a.id == AxisId("z"): 

3792 data = data[slices + (slice(s // 2 - 1, s // 2),)] 

3793 data, axes = squeeze(data, axes) 

3794 ndim -= 1 

3795 break 

3796 

3797 slices += (slice(None),) 

3798 

3799 # take slice from any space or time axis 

3800 slices = () 

3801 

3802 for i, a in enumerate(axes): 

3803 if ndim <= ndim_need: 

3804 break 

3805 

3806 s = data.shape[i] 

3807 assert s > 1 

3808 if isinstance( 

3809 a, (SpaceInputAxis, SpaceOutputAxis, TimeInputAxis, TimeOutputAxis) 

3810 ): 

3811 data = data[slices + (slice(s // 2 - 1, s // 2),)] 

3812 ndim -= 1 

3813 

3814 slices += (slice(None),) 

3815 

3816 del slices 

3817 data, axes = squeeze(data, axes) 

3818 assert len(axes) == ndim 

3819 

3820 if (has_c_axis and ndim != 3) or (not has_c_axis and ndim != 2): 

3821 raise ValueError( 

3822 f"Failed to construct cover image from shape {original_shape} with axes {[a.id for a in original_axes]}." 

3823 ) 

3824 

3825 if not has_c_axis: 

3826 assert ndim == 2 

3827 data = np.repeat(data[:, :, None], 3, axis=2) 

3828 axes.append(ChannelAxis(channel_names=list(map(Identifier, "RGB")))) 

3829 ndim += 1 

3830 

3831 assert ndim == 3 

3832 

3833 # transpose axis order such that longest axis comes first... 

3834 axis_order: List[int] = list(np.argsort(list(data.shape))) 

3835 axis_order.reverse() 

3836 # ... and channel axis is last 

3837 c = [i for i in range(3) if isinstance(axes[i], ChannelAxis)][0] 

3838 axis_order.append(axis_order.pop(c)) 

3839 axes = [axes[ao] for ao in axis_order] 

3840 data = data.transpose(axis_order) 

3841 

3842 # h, w = data.shape[:2] 

3843 # if h / w in (1.0, 2.0): 

3844 # pass 

3845 # elif h / w < 2: 

3846 # TODO: enforce 2:1 or 1:1 aspect ratio for generated cover images 

3847 

3848 norm_along = ( 

3849 tuple(i for i, a in enumerate(axes) if a.type in ("space", "time")) or None 

3850 ) 

3851 # normalize the data and map to 8 bit 

3852 data = normalize(data, norm_along) 

3853 data = (data * 255).astype("uint8") 

3854 

3855 return data 

3856 

3857 def create_diagonal_split_image(im0: NDArray[Any], im1: NDArray[Any]): 
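"""combine two uint8 RGB images: `im0` on and below the diagonal, `im1` above it"""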

3858 assert im0.dtype == im1.dtype == np.uint8 

3859 assert im0.shape == im1.shape 

3860 assert im0.ndim == 3 

3861 N, M, C = im0.shape 

3862 assert C == 3 

3863 out = np.ones((N, M, C), dtype="uint8") 

3864 for c in range(C): 

3865 outc = np.tril(im0[..., c]) 

3866 mask = outc == 0 

3867 outc[mask] = np.triu(im1[..., c])[mask] 

3868 out[..., c] = outc 

3869 

3870 return out 

3871 

3872 if not inputs: 

3873 raise ValueError("Missing test input tensor for cover generation.") 

3874 

3875 if not outputs: 

3876 raise ValueError("Missing test output tensor for cover generation.") 

3877 

3878 ipt_descr, ipt = inputs[0] 

3879 out_descr, out = outputs[0] 

3880 

3881 ipt_img = to_2d_image(ipt, ipt_descr.axes) 

3882 out_img = to_2d_image(out, out_descr.axes) 

3883 

3884 cover_folder = Path(mkdtemp()) 

3885 if ipt_img.shape == out_img.shape: 

3886 covers = [cover_folder / "cover.png"] 

3887 imwrite(covers[0], create_diagonal_split_image(ipt_img, out_img)) 

3888 else: 

3889 covers = [cover_folder / "input.png", cover_folder / "output.png"] 

3890 imwrite(covers[0], ipt_img) 

3891 imwrite(covers[1], out_img) 

3892 

3893 return covers
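

# Standalone sketch of the diagonal-split technique used in
# `create_diagonal_split_image` above, expressed via an equivalent `np.where`
# mask (illustrative arrays; `_im0`, `_im1`, `_mask`, `_split` are
# hypothetical names, not part of the spec API):
_im0 = np.full((4, 4, 3), 200, dtype="uint8")
_im1 = np.full((4, 4, 3), 50, dtype="uint8")
_mask = np.tril(np.ones((4, 4), dtype=bool))[..., None]  # True on/below diagonal
_split = np.where(_mask, _im0, _im1)
assert _split[3, 0, 0] == 200 and _split[0, 3, 0] == 50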