Coverage for src/bioimageio/spec/model/v0_5.py: 74%

1391 statements  

coverage.py v7.12.0, created at 2025-12-08 13:04 +0000

1from __future__ import annotations 

2 

3import collections.abc 

4import re 

5import string 

6import warnings 

7from abc import ABC 

8from copy import deepcopy 

9from itertools import chain 

10from math import ceil 

11from pathlib import Path, PurePosixPath 

12from tempfile import mkdtemp 

13from typing import ( 

14 TYPE_CHECKING, 

15 Any, 

16 Callable, 

17 ClassVar, 

18 Dict, 

19 Generic, 

20 List, 

21 Literal, 

22 Mapping, 

23 NamedTuple, 

24 Optional, 

25 Sequence, 

26 Set, 

27 Tuple, 

28 Type, 

29 TypeVar, 

30 Union, 

31 cast, 

32 overload, 

33) 

34 

35import numpy as np 

36from annotated_types import Ge, Gt, Interval, MaxLen, MinLen, Predicate 

37from imageio.v3 import imread, imwrite # pyright: ignore[reportUnknownVariableType] 

38from loguru import logger 

39from numpy.typing import NDArray 

40from pydantic import ( 

41 AfterValidator, 

42 Discriminator, 

43 Field, 

44 RootModel, 

45 SerializationInfo, 

46 SerializerFunctionWrapHandler, 

47 StrictInt, 

48 Tag, 

49 ValidationInfo, 

50 WrapSerializer, 

51 field_validator, 

52 model_serializer, 

53 model_validator, 

54) 

55from typing_extensions import Annotated, Self, assert_never, get_args 

56 

57from .._internal.common_nodes import ( 

58 InvalidDescr, 

59 Node, 

60 NodeWithExplicitlySetFields, 

61) 

62from .._internal.constants import DTYPE_LIMITS 

63from .._internal.field_warning import issue_warning, warn 

64from .._internal.io import BioimageioYamlContent as BioimageioYamlContent 

65from .._internal.io import FileDescr as FileDescr 

66from .._internal.io import ( 

67 FileSource, 

68 WithSuffix, 

69 YamlValue, 

70 extract_file_name, 

71 get_reader, 

72 wo_special_file_name, 

73) 

74from .._internal.io_basics import Sha256 as Sha256 

75from .._internal.io_packaging import ( 

76 FileDescr_, 

77 FileSource_, 

78 package_file_descr_serializer, 

79) 

80from .._internal.io_utils import load_array 

81from .._internal.node_converter import Converter 

82from .._internal.type_guards import is_dict, is_sequence 

83from .._internal.types import ( 

84 FAIR, 

85 AbsoluteTolerance, 

86 LowerCaseIdentifier, 

87 LowerCaseIdentifierAnno, 

88 MismatchedElementsPerMillion, 

89 RelativeTolerance, 

90) 

91from .._internal.types import Datetime as Datetime 

92from .._internal.types import Identifier as Identifier 

93from .._internal.types import NotEmpty as NotEmpty 

94from .._internal.types import SiUnit as SiUnit 

95from .._internal.url import HttpUrl as HttpUrl 

96from .._internal.validation_context import get_validation_context 

97from .._internal.validator_annotations import RestrictCharacters 

98from .._internal.version_type import Version as Version 

99from .._internal.warning_levels import INFO 

100from ..dataset.v0_2 import DatasetDescr as DatasetDescr02 

101from ..dataset.v0_2 import LinkedDataset as LinkedDataset02 

102from ..dataset.v0_3 import DatasetDescr as DatasetDescr 

103from ..dataset.v0_3 import DatasetId as DatasetId 

104from ..dataset.v0_3 import LinkedDataset as LinkedDataset 

105from ..dataset.v0_3 import Uploader as Uploader 

106from ..generic.v0_3 import ( 

107 VALID_COVER_IMAGE_EXTENSIONS as VALID_COVER_IMAGE_EXTENSIONS, 

108) 

109from ..generic.v0_3 import Author as Author 

110from ..generic.v0_3 import BadgeDescr as BadgeDescr 

111from ..generic.v0_3 import CiteEntry as CiteEntry 

112from ..generic.v0_3 import DeprecatedLicenseId as DeprecatedLicenseId 

113from ..generic.v0_3 import Doi as Doi 

114from ..generic.v0_3 import ( 

115 FileSource_documentation, 

116 GenericModelDescrBase, 

117 LinkedResourceBase, 

118 _author_conv, # pyright: ignore[reportPrivateUsage] 

119 _maintainer_conv, # pyright: ignore[reportPrivateUsage] 

120) 

121from ..generic.v0_3 import LicenseId as LicenseId 

122from ..generic.v0_3 import LinkedResource as LinkedResource 

123from ..generic.v0_3 import Maintainer as Maintainer 

124from ..generic.v0_3 import OrcidId as OrcidId 

125from ..generic.v0_3 import RelativeFilePath as RelativeFilePath 

126from ..generic.v0_3 import ResourceId as ResourceId 

127from .v0_4 import Author as _Author_v0_4 

128from .v0_4 import BinarizeDescr as _BinarizeDescr_v0_4 

129from .v0_4 import CallableFromDepencency as CallableFromDepencency 

130from .v0_4 import CallableFromDepencency as _CallableFromDepencency_v0_4 

131from .v0_4 import CallableFromFile as _CallableFromFile_v0_4 

132from .v0_4 import ClipDescr as _ClipDescr_v0_4 

133from .v0_4 import ClipKwargs as ClipKwargs 

134from .v0_4 import ImplicitOutputShape as _ImplicitOutputShape_v0_4 

135from .v0_4 import InputTensorDescr as _InputTensorDescr_v0_4 

136from .v0_4 import KnownRunMode as KnownRunMode 

137from .v0_4 import ModelDescr as _ModelDescr_v0_4 

138from .v0_4 import OutputTensorDescr as _OutputTensorDescr_v0_4 

139from .v0_4 import ParameterizedInputShape as _ParameterizedInputShape_v0_4 

140from .v0_4 import PostprocessingDescr as _PostprocessingDescr_v0_4 

141from .v0_4 import PreprocessingDescr as _PreprocessingDescr_v0_4 

142from .v0_4 import ProcessingKwargs as ProcessingKwargs 

143from .v0_4 import RunMode as RunMode 

144from .v0_4 import ScaleLinearDescr as _ScaleLinearDescr_v0_4 

145from .v0_4 import ScaleMeanVarianceDescr as _ScaleMeanVarianceDescr_v0_4 

146from .v0_4 import ScaleRangeDescr as _ScaleRangeDescr_v0_4 

147from .v0_4 import SigmoidDescr as _SigmoidDescr_v0_4 

148from .v0_4 import TensorName as _TensorName_v0_4 

149from .v0_4 import WeightsFormat as WeightsFormat 

150from .v0_4 import ZeroMeanUnitVarianceDescr as _ZeroMeanUnitVarianceDescr_v0_4 

151from .v0_4 import package_weights 

152 

153SpaceUnit = Literal[ 

154 "attometer", 

155 "angstrom", 

156 "centimeter", 

157 "decimeter", 

158 "exameter", 

159 "femtometer", 

160 "foot", 

161 "gigameter", 

162 "hectometer", 

163 "inch", 

164 "kilometer", 

165 "megameter", 

166 "meter", 

167 "micrometer", 

168 "mile", 

169 "millimeter", 

170 "nanometer", 

171 "parsec", 

172 "petameter", 

173 "picometer", 

174 "terameter", 

175 "yard", 

176 "yoctometer", 

177 "yottameter", 

178 "zeptometer", 

179 "zettameter", 

180] 

181"""Space unit compatible to the [OME-Zarr axes specification 0.5](https://ngff.openmicroscopy.org/0.5/#axes-md)""" 

182 

183TimeUnit = Literal[ 

184 "attosecond", 

185 "centisecond", 

186 "day", 

187 "decisecond", 

188 "exasecond", 

189 "femtosecond", 

190 "gigasecond", 

191 "hectosecond", 

192 "hour", 

193 "kilosecond", 

194 "megasecond", 

195 "microsecond", 

196 "millisecond", 

197 "minute", 

198 "nanosecond", 

199 "petasecond", 

200 "picosecond", 

201 "second", 

202 "terasecond", 

203 "yoctosecond", 

204 "yottasecond", 

205 "zeptosecond", 

206 "zettasecond", 

207] 

208"""Time unit compatible to the [OME-Zarr axes specification 0.5](https://ngff.openmicroscopy.org/0.5/#axes-md)""" 

209 

210AxisType = Literal["batch", "channel", "index", "time", "space"] 

211 

212_AXIS_TYPE_MAP: Mapping[str, AxisType] = { 

213 "b": "batch", 

214 "t": "time", 

215 "i": "index", 

216 "c": "channel", 

217 "x": "space", 

218 "y": "space", 

219 "z": "space", 

220} 

221 

222_AXIS_ID_MAP = { 

223 "b": "batch", 

224 "t": "time", 

225 "i": "index", 

226 "c": "channel", 

227} 

228 

229 

230class TensorId(LowerCaseIdentifier): 

231 root_model: ClassVar[Type[RootModel[Any]]] = RootModel[ 

232 Annotated[LowerCaseIdentifierAnno, MaxLen(32)] 

233 ] 

234 

235 

236def _normalize_axis_id(a: str): 

237 a = str(a) 

238 normalized = _AXIS_ID_MAP.get(a, a) 

239 if a != normalized: 

240 logger.opt(depth=3).warning( 

241 "Normalized axis id from '{}' to '{}'.", a, normalized 

242 ) 

243 return normalized 

244 

245 

246class AxisId(LowerCaseIdentifier): 

247 root_model: ClassVar[Type[RootModel[Any]]] = RootModel[ 

248 Annotated[ 

249 LowerCaseIdentifierAnno, 

250 MaxLen(16), 

251 AfterValidator(_normalize_axis_id), 

252 ] 

253 ] 

254 
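# Editor's note: a hypothetical usage sketch (not part of the module), assuming
# `str()` of a `LowerCaseIdentifier` yields its validated root value:
# >>> str(AxisId("t"))  # single-letter ids are normalized via _AXIS_ID_MAP (with a warning)
# 'time'
# >>> str(AxisId("x"))  # 'x' has no entry in _AXIS_ID_MAP and is kept as-is
# 'x'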

255 

256def _is_batch(a: str) -> bool: 

257 return str(a) == "batch" 

258 

259 

260def _is_not_batch(a: str) -> bool: 

261 return not _is_batch(a) 

262 

263 

264NonBatchAxisId = Annotated[AxisId, Predicate(_is_not_batch)] 

265 

266PreprocessingId = Literal[ 

267 "binarize", 

268 "clip", 

269 "ensure_dtype", 

270 "fixed_zero_mean_unit_variance", 

271 "scale_linear", 

272 "scale_range", 

273 "sigmoid", 

274 "softmax", 

275] 

276PostprocessingId = Literal[ 

277 "binarize", 

278 "clip", 

279 "ensure_dtype", 

280 "fixed_zero_mean_unit_variance", 

281 "scale_linear", 

282 "scale_mean_variance", 

283 "scale_range", 

284 "sigmoid", 

285 "softmax", 

286 "zero_mean_unit_variance", 

287] 

288 

289 

290SAME_AS_TYPE = "<same as type>" 

291 

292 

293ParameterizedSize_N = int 

294""" 

295Annotates an integer to calculate a concrete axis size from a `ParameterizedSize`. 

296""" 

297 

298 

299class ParameterizedSize(Node): 

300 """Describes a range of valid tensor axis sizes as `size = min + n*step`. 

301 

302 - **min** and **step** are given by the model description. 

303 - All block size parameters n = 0,1,2,... yield a valid `size`. 

304 - A greater block size parameter n results in a greater **size**. 

305 This allows the axis size to be adjusted generically. 

306 """ 

307 

308 N: ClassVar[Type[int]] = ParameterizedSize_N 

309 """Positive integer to parameterize this axis""" 

310 

311 min: Annotated[int, Gt(0)] 

312 step: Annotated[int, Gt(0)] 

313 

314 def validate_size(self, size: int) -> int: 

315 if size < self.min: 

316 raise ValueError(f"size {size} < {self.min}") 

317 if (size - self.min) % self.step != 0: 

318 raise ValueError( 

319 f"axis of size {size} is not parameterized by `min + n*step` =" 

320 + f" `{self.min} + n*{self.step}`" 

321 ) 

322 

323 return size 

324 

325 def get_size(self, n: ParameterizedSize_N) -> int: 

326 return self.min + self.step * n 

327 

328 def get_n(self, s: int) -> ParameterizedSize_N: 

329 """return smallest n parameterizing a size greater or equal than `s`""" 

330 return ceil((s - self.min) / self.step) 

331 

332 
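# Editor's note: a minimal sketch of the `size = min + n*step` arithmetic
# (hypothetical usage, not part of the module):
# >>> p = ParameterizedSize(min=32, step=16)
# >>> p.get_size(2)        # 32 + 2*16
# 64
# >>> p.get_n(70)          # smallest n with min + n*step >= 70: ceil((70-32)/16)
# 3
# >>> p.validate_size(64)  # 64 == 32 + 2*16, hence valid
# 64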

333class DataDependentSize(Node): 

334 min: Annotated[int, Gt(0)] = 1 

335 max: Annotated[Optional[int], Gt(1)] = None 

336 

337 @model_validator(mode="after") 

338 def _validate_max_gt_min(self): 

339 if self.max is not None and self.min >= self.max: 

340 raise ValueError(f"expected `min` < `max`, but got {self.min}, {self.max}") 

341 

342 return self 

343 

344 def validate_size(self, size: int) -> int: 

345 if size < self.min: 

346 raise ValueError(f"size {size} < {self.min}") 

347 

348 if self.max is not None and size > self.max: 

349 raise ValueError(f"size {size} > {self.max}") 

350 

351 return size 

352 
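# Editor's note: hypothetical sketch; a `DataDependentSize` is only known after
# inference, so validation can merely check the declared bounds:
# >>> DataDependentSize(min=1, max=512).validate_size(100)
# 100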

353 

354class SizeReference(Node): 

355 """A tensor axis size (extent in pixels/frames) defined in relation to a reference axis. 

356 

357 `axis.size = reference.size * reference.scale / axis.scale + offset` 

358 

359 Note: 

360 1. The axis and the referenced axis need to have the same unit (or no unit). 

361 2. Batch axes may not be referenced. 

362 3. Fractions are rounded down. 

363 4. If the reference axis is `concatenable` the referencing axis is assumed to be 

364 `concatenable` as well with the same block order. 

365 

366 Example: 

367 An anisotropic input image of w*h=100*49 pixels depicts a physical space of 200*196 mm². 

368 Let's assume that we want to express the image height h in relation to its width w 

369 instead of only accepting input images of exactly 100*49 pixels 

370 (for example to express a range of valid image shapes by parametrizing w, see `ParameterizedSize`). 

371 

372 >>> w = SpaceInputAxis(id=AxisId("w"), size=100, unit="millimeter", scale=2) 

373 >>> h = SpaceInputAxis( 

374 ... id=AxisId("h"), 

375 ... size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("w"), offset=-1), 

376 ... unit="millimeter", 

377 ... scale=4, 

378 ... ) 

379 >>> print(h.size.get_size(h, w)) 

380 49 

381 

382 ⇒ h = w * w.scale / h.scale + offset = 100 * 2mm / 4mm - 1 = 49 

383 """ 

384 

385 tensor_id: TensorId 

386 """tensor id of the reference axis""" 

387 

388 axis_id: AxisId 

389 """axis id of the reference axis""" 

390 

391 offset: StrictInt = 0 

392 

393 def get_size( 

394 self, 

395 axis: Union[ 

396 ChannelAxis, 

397 IndexInputAxis, 

398 IndexOutputAxis, 

399 TimeInputAxis, 

400 SpaceInputAxis, 

401 TimeOutputAxis, 

402 TimeOutputAxisWithHalo, 

403 SpaceOutputAxis, 

404 SpaceOutputAxisWithHalo, 

405 ], 

406 ref_axis: Union[ 

407 ChannelAxis, 

408 IndexInputAxis, 

409 IndexOutputAxis, 

410 TimeInputAxis, 

411 SpaceInputAxis, 

412 TimeOutputAxis, 

413 TimeOutputAxisWithHalo, 

414 SpaceOutputAxis, 

415 SpaceOutputAxisWithHalo, 

416 ], 

417 n: ParameterizedSize_N = 0, 

418 ref_size: Optional[int] = None, 

419 ): 

420 """Compute the concrete size for a given axis and its reference axis. 

421 

422 Args: 

423 axis: The axis this `SizeReference` is the size of. 

424 ref_axis: The reference axis to compute the size from. 

425 n: If the **ref_axis** is parameterized (of type `ParameterizedSize`) 

426 and no fixed **ref_size** is given, 

427 **n** is used to compute the size of the parameterized **ref_axis**. 

428 ref_size: Overwrite the reference size instead of deriving it from 

429 **ref_axis** 

430 (**ref_axis.scale** is still used; any given **n** is ignored). 

431 """ 

432 assert axis.size == self, ( 

433 "Given `axis.size` is not defined by this `SizeReference`" 

434 ) 

435 

436 assert ref_axis.id == self.axis_id, ( 

437 f"Expected `ref_axis.id` to be {self.axis_id}, but got {ref_axis.id}." 

438 ) 

439 

440 assert axis.unit == ref_axis.unit, ( 

441 "`SizeReference` requires `axis` and `ref_axis` to have the same `unit`," 

442 f" but {axis.unit}!={ref_axis.unit}" 

443 ) 

444 if ref_size is None: 

445 if isinstance(ref_axis.size, (int, float)): 

446 ref_size = ref_axis.size 

447 elif isinstance(ref_axis.size, ParameterizedSize): 

448 ref_size = ref_axis.size.get_size(n) 

449 elif isinstance(ref_axis.size, DataDependentSize): 

450 raise ValueError( 

451 "Reference axis referenced in `SizeReference` may not be a `DataDependentSize`." 

452 ) 

453 elif isinstance(ref_axis.size, SizeReference): 

454 raise ValueError( 

455 "Reference axis referenced in `SizeReference` may not be sized by a" 

456 + " `SizeReference` itself." 

457 ) 

458 else: 

459 assert_never(ref_axis.size) 

460 

461 return int(ref_size * ref_axis.scale / axis.scale + self.offset) 

462 

463 @staticmethod 

464 def _get_unit( 

465 axis: Union[ 

466 ChannelAxis, 

467 IndexInputAxis, 

468 IndexOutputAxis, 

469 TimeInputAxis, 

470 SpaceInputAxis, 

471 TimeOutputAxis, 

472 TimeOutputAxisWithHalo, 

473 SpaceOutputAxis, 

474 SpaceOutputAxisWithHalo, 

475 ], 

476 ): 

477 return axis.unit 

478 

479 

480class AxisBase(NodeWithExplicitlySetFields): 

481 id: AxisId 

482 """An axis id unique across all axes of one tensor.""" 

483 

484 description: Annotated[str, MaxLen(128)] = "" 

485 """A short description of this axis beyond its type and id.""" 

486 

487 

488class WithHalo(Node): 

489 halo: Annotated[int, Ge(1)] 

490 """The halo should be cropped from the output tensor to avoid boundary effects. 

491 It is to be cropped from both sides, i.e. `size_after_crop = size - 2 * halo`. 

492 To document a halo that is already cropped by the model use `size.offset` instead.""" 

493 

494 size: Annotated[ 

495 SizeReference, 

496 Field( 

497 examples=[ 

498 10, 

499 SizeReference( 

500 tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5 

501 ).model_dump(mode="json"), 

502 ] 

503 ), 

504 ] 

505 """reference to another axis with an optional offset (see `SizeReference`)""" 

506 
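# Editor's note: worked example of the halo arithmetic above; with halo=32 and
# an output axis size of 256, the returned tensor is cropped on both sides:
# size_after_crop = 256 - 2*32 = 192.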

507 

508BATCH_AXIS_ID = AxisId("batch") 

509 

510 

511class BatchAxis(AxisBase): 

512 implemented_type: ClassVar[Literal["batch"]] = "batch" 

513 if TYPE_CHECKING: 

514 type: Literal["batch"] = "batch" 

515 else: 

516 type: Literal["batch"] 

517 

518 id: Annotated[AxisId, Predicate(_is_batch)] = BATCH_AXIS_ID 

519 size: Optional[Literal[1]] = None 

520 """The batch size may be fixed to 1, 

521 otherwise (the default) it may be chosen arbitrarily depending on available memory""" 

522 

523 @property 

524 def scale(self): 

525 return 1.0 

526 

527 @property 

528 def concatenable(self): 

529 return True 

530 

531 @property 

532 def unit(self): 

533 return None 

534 

535 

536class ChannelAxis(AxisBase): 

537 implemented_type: ClassVar[Literal["channel"]] = "channel" 

538 if TYPE_CHECKING: 

539 type: Literal["channel"] = "channel" 

540 else: 

541 type: Literal["channel"] 

542 

543 id: NonBatchAxisId = AxisId("channel") 

544 

545 channel_names: NotEmpty[List[Identifier]] 

546 

547 @property 

548 def size(self) -> int: 

549 return len(self.channel_names) 

550 

551 @property 

552 def concatenable(self): 

553 return False 

554 

555 @property 

556 def scale(self) -> float: 

557 return 1.0 

558 

559 @property 

560 def unit(self): 

561 return None 

562 

563 

564class IndexAxisBase(AxisBase): 

565 implemented_type: ClassVar[Literal["index"]] = "index" 

566 if TYPE_CHECKING: 

567 type: Literal["index"] = "index" 

568 else: 

569 type: Literal["index"] 

570 

571 id: NonBatchAxisId = AxisId("index") 

572 

573 @property 

574 def scale(self) -> float: 

575 return 1.0 

576 

577 @property 

578 def unit(self): 

579 return None 

580 

581 

582class _WithInputAxisSize(Node): 

583 size: Annotated[ 

584 Union[Annotated[int, Gt(0)], ParameterizedSize, SizeReference], 

585 Field( 

586 examples=[ 

587 10, 

588 ParameterizedSize(min=32, step=16).model_dump(mode="json"), 

589 SizeReference( 

590 tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5 

591 ).model_dump(mode="json"), 

592 ] 

593 ), 

594 ] 

595 """The size/length of this axis can be specified as 

596 - fixed integer 

597 - parameterized series of valid sizes (`ParameterizedSize`) 

598 - reference to another axis with an optional offset (`SizeReference`) 

599 """ 

600 

601 

602class IndexInputAxis(IndexAxisBase, _WithInputAxisSize): 

603 concatenable: bool = False 

604 """If a model has a `concatenable` input axis, it can be processed blockwise, 

605 splitting a longer sample axis into blocks matching its input tensor description. 

606 Output axes are concatenable if they have a `SizeReference` to a concatenable 

607 input axis. 

608 """ 

609 

610 

611class IndexOutputAxis(IndexAxisBase): 

612 size: Annotated[ 

613 Union[Annotated[int, Gt(0)], SizeReference, DataDependentSize], 

614 Field( 

615 examples=[ 

616 10, 

617 SizeReference( 

618 tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5 

619 ).model_dump(mode="json"), 

620 ] 

621 ), 

622 ] 

623 """The size/length of this axis can be specified as 

624 - fixed integer 

625 - reference to another axis with an optional offset (`SizeReference`) 

626 - data dependent size using `DataDependentSize` (size is only known after model inference) 

627 """ 

628 

629 

630class TimeAxisBase(AxisBase): 

631 implemented_type: ClassVar[Literal["time"]] = "time" 

632 if TYPE_CHECKING: 

633 type: Literal["time"] = "time" 

634 else: 

635 type: Literal["time"] 

636 

637 id: NonBatchAxisId = AxisId("time") 

638 unit: Optional[TimeUnit] = None 

639 scale: Annotated[float, Gt(0)] = 1.0 

640 

641 

642class TimeInputAxis(TimeAxisBase, _WithInputAxisSize): 

643 concatenable: bool = False 

644 """If a model has a `concatenable` input axis, it can be processed blockwise, 

645 splitting a longer sample axis into blocks matching its input tensor description. 

646 Output axes are concatenable if they have a `SizeReference` to a concatenable 

647 input axis. 

648 """ 

649 

650 

651class SpaceAxisBase(AxisBase): 

652 implemented_type: ClassVar[Literal["space"]] = "space" 

653 if TYPE_CHECKING: 

654 type: Literal["space"] = "space" 

655 else: 

656 type: Literal["space"] 

657 

658 id: Annotated[NonBatchAxisId, Field(examples=["x", "y", "z"])] = AxisId("x") 

659 unit: Optional[SpaceUnit] = None 

660 scale: Annotated[float, Gt(0)] = 1.0 

661 

662 

663class SpaceInputAxis(SpaceAxisBase, _WithInputAxisSize): 

664 concatenable: bool = False 

665 """If a model has a `concatenable` input axis, it can be processed blockwise, 

666 splitting a longer sample axis into blocks matching its input tensor description. 

667 Output axes are concatenable if they have a `SizeReference` to a concatenable 

668 input axis. 

669 """ 

670 

671 

672INPUT_AXIS_TYPES = ( 

673 BatchAxis, 

674 ChannelAxis, 

675 IndexInputAxis, 

676 TimeInputAxis, 

677 SpaceInputAxis, 

678) 

679"""intended for isinstance comparisons in py<3.10""" 

680 

681_InputAxisUnion = Union[ 

682 BatchAxis, ChannelAxis, IndexInputAxis, TimeInputAxis, SpaceInputAxis 

683] 

684InputAxis = Annotated[_InputAxisUnion, Discriminator("type")] 

685 

686 

687class _WithOutputAxisSize(Node): 

688 size: Annotated[ 

689 Union[Annotated[int, Gt(0)], SizeReference], 

690 Field( 

691 examples=[ 

692 10, 

693 SizeReference( 

694 tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5 

695 ).model_dump(mode="json"), 

696 ] 

697 ), 

698 ] 

699 """The size/length of this axis can be specified as 

700 - fixed integer 

701 - reference to another axis with an optional offset (see `SizeReference`) 

702 """ 

703 

704 

705class TimeOutputAxis(TimeAxisBase, _WithOutputAxisSize): 

706 pass 

707 

708 

709class TimeOutputAxisWithHalo(TimeAxisBase, WithHalo): 

710 pass 

711 

712 

713def _get_halo_axis_discriminator_value(v: Any) -> Literal["with_halo", "wo_halo"]: 

714 if isinstance(v, dict): 

715 return "with_halo" if "halo" in v else "wo_halo" 

716 else: 

717 return "with_halo" if hasattr(v, "halo") else "wo_halo" 

718 
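# Editor's note: illustrative sketch (not part of the module) of how this
# discriminator routes raw data to the tagged union members defined below:
# >>> _get_halo_axis_discriminator_value({"id": "t", "size": 10})
# 'wo_halo'
# >>> _get_halo_axis_discriminator_value({"id": "t", "size": 10, "halo": 32})
# 'with_halo'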

719 

720_TimeOutputAxisUnion = Annotated[ 

721 Union[ 

722 Annotated[TimeOutputAxis, Tag("wo_halo")], 

723 Annotated[TimeOutputAxisWithHalo, Tag("with_halo")], 

724 ], 

725 Discriminator(_get_halo_axis_discriminator_value), 

726] 

727 

728 

729class SpaceOutputAxis(SpaceAxisBase, _WithOutputAxisSize): 

730 pass 

731 

732 

733class SpaceOutputAxisWithHalo(SpaceAxisBase, WithHalo): 

734 pass 

735 

736 

737_SpaceOutputAxisUnion = Annotated[ 

738 Union[ 

739 Annotated[SpaceOutputAxis, Tag("wo_halo")], 

740 Annotated[SpaceOutputAxisWithHalo, Tag("with_halo")], 

741 ], 

742 Discriminator(_get_halo_axis_discriminator_value), 

743] 

744 

745 

746_OutputAxisUnion = Union[ 

747 BatchAxis, ChannelAxis, IndexOutputAxis, _TimeOutputAxisUnion, _SpaceOutputAxisUnion 

748] 

749OutputAxis = Annotated[_OutputAxisUnion, Discriminator("type")] 

750 

751OUTPUT_AXIS_TYPES = ( 

752 BatchAxis, 

753 ChannelAxis, 

754 IndexOutputAxis, 

755 TimeOutputAxis, 

756 TimeOutputAxisWithHalo, 

757 SpaceOutputAxis, 

758 SpaceOutputAxisWithHalo, 

759) 

760"""intended for isinstance comparisons in py<3.10""" 

761 

762 

763AnyAxis = Union[InputAxis, OutputAxis] 

764 

765ANY_AXIS_TYPES = INPUT_AXIS_TYPES + OUTPUT_AXIS_TYPES 

766"""intended for isinstance comparisons in py<3.10""" 

767 

768TVs = Union[ 

769 NotEmpty[List[int]], 

770 NotEmpty[List[float]], 

771 NotEmpty[List[bool]], 

772 NotEmpty[List[str]], 

773] 

774 

775 

776NominalOrOrdinalDType = Literal[ 

777 "float32", 

778 "float64", 

779 "uint8", 

780 "int8", 

781 "uint16", 

782 "int16", 

783 "uint32", 

784 "int32", 

785 "uint64", 

786 "int64", 

787 "bool", 

788] 

789 

790 

791class NominalOrOrdinalDataDescr(Node): 

792 values: TVs 

793 """A fixed set of nominal or an ascending sequence of ordinal values. 

794 In this case `data.type` is required to be an unsigned integer type, e.g. 'uint8'. 

795 String `values` are interpreted as labels for tensor values 0, ..., N-1. 

796 Note: as YAML 1.2 does not natively support a "set" datatype, 

797 nominal values should be given as a sequence (aka list/array) as well. 

798 """ 

799 

800 type: Annotated[ 

801 NominalOrOrdinalDType, 

802 Field( 

803 examples=[ 

804 "float32", 

805 "uint8", 

806 "uint16", 

807 "int64", 

808 "bool", 

809 ], 

810 ), 

811 ] = "uint8" 

812 

813 @model_validator(mode="after") 

814 def _validate_values_match_type( 

815 self, 

816 ) -> Self: 

817 incompatible: List[Any] = [] 

818 for v in self.values: 

819 if self.type == "bool": 

820 if not isinstance(v, bool): 

821 incompatible.append(v) 

822 elif self.type in DTYPE_LIMITS: 

823 if ( 

824 isinstance(v, (int, float)) 

825 and ( 

826 v < DTYPE_LIMITS[self.type].min 

827 or v > DTYPE_LIMITS[self.type].max 

828 ) 

829 or (isinstance(v, str) and "uint" not in self.type) 

830 or (isinstance(v, float) and "int" in self.type) 

831 ): 

832 incompatible.append(v) 

833 else: 

834 incompatible.append(v) 

835 

836 if len(incompatible) == 5: 

837 incompatible.append("...") 

838 break 

839 

840 if incompatible: 

841 raise ValueError( 

842 f"data type '{self.type}' incompatible with values {incompatible}" 

843 ) 

844 

845 return self 

846 

847 unit: Optional[Union[Literal["arbitrary unit"], SiUnit]] = None 

848 

849 @property 

850 def range(self): 

851 if isinstance(self.values[0], str): 

852 return 0, len(self.values) - 1 

853 else: 

854 return min(self.values), max(self.values) 

855 
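# Editor's note: hypothetical sketch of the `range` property; string values act
# as labels for 0..N-1, numeric values are bounded by their min/max:
# >>> NominalOrOrdinalDataDescr(values=["background", "cell"], type="uint8").range
# (0, 1)
# >>> NominalOrOrdinalDataDescr(values=[0, 1, 4], type="uint8").range
# (0, 4)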

856 

857IntervalOrRatioDType = Literal[ 

858 "float32", 

859 "float64", 

860 "uint8", 

861 "int8", 

862 "uint16", 

863 "int16", 

864 "uint32", 

865 "int32", 

866 "uint64", 

867 "int64", 

868] 

869 

870 

871class IntervalOrRatioDataDescr(Node): 

872 type: Annotated[ # TODO: rename to dtype 

873 IntervalOrRatioDType, 

874 Field( 

875 examples=["float32", "float64", "uint8", "uint16"], 

876 ), 

877 ] = "float32" 

878 range: Tuple[Optional[float], Optional[float]] = ( 

879 None, 

880 None, 

881 ) 

882 """Tuple `(minimum, maximum)` specifying the allowed range of the data in this tensor. 

883 `None` corresponds to min/max of what can be expressed by **type**.""" 

884 unit: Union[Literal["arbitrary unit"], SiUnit] = "arbitrary unit" 

885 scale: float = 1.0 

886 """Scale for data on an interval (or ratio) scale.""" 

887 offset: Optional[float] = None 

888 """Offset for data on a ratio scale.""" 

889 

890 @model_validator(mode="before") 

891 def _replace_inf(cls, data: Any): 

892 if is_dict(data): 

893 if "range" in data and is_sequence(data["range"]): 

894 forbidden = ( 

895 "inf", 

896 "-inf", 

897 ".inf", 

898 "-.inf", 

899 float("inf"), 

900 float("-inf"), 

901 ) 

902 if any(v in forbidden for v in data["range"]): 

903 issue_warning("replaced 'inf' value", value=data["range"]) 

904 

905 data["range"] = tuple( 

906 (None if v in forbidden else v) for v in data["range"] 

907 ) 

908 

909 return data 

910 
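# Editor's note: illustrative sketch (not part of the module); an 'inf' range
# bound is replaced by None (meaning the dtype's own limit), with a warning
# issued via `issue_warning` (which may raise in strict validation contexts):
# >>> IntervalOrRatioDataDescr(type="float32", range=("-inf", 100.0)).range
# (None, 100.0)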

911 

912TensorDataDescr = Union[NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr] 

913 

914 

915class ProcessingDescrBase(NodeWithExplicitlySetFields, ABC): 

916 """processing base class""" 

917 

918 

919class BinarizeKwargs(ProcessingKwargs): 

920 """key word arguments for `BinarizeDescr`""" 

921 

922 threshold: float 

923 """The fixed threshold""" 

924 

925 

926class BinarizeAlongAxisKwargs(ProcessingKwargs): 

927 """key word arguments for `BinarizeDescr`""" 

928 

929 threshold: NotEmpty[List[float]] 

930 """The fixed threshold values along `axis`""" 

931 

932 axis: Annotated[NonBatchAxisId, Field(examples=["channel"])] 

933 """The `threshold` axis""" 

934 

935 

936class BinarizeDescr(ProcessingDescrBase): 

937 """Binarize the tensor with a fixed threshold. 

938 

939 Values above `BinarizeKwargs.threshold`/`BinarizeAlongAxisKwargs.threshold` 

940 will be set to one, values below the threshold to zero. 

941 

942 Examples: 

943 - in YAML 

944 ```yaml 

945 postprocessing: 

946 - id: binarize 

947 kwargs: 

948 axis: 'channel' 

949 threshold: [0.25, 0.5, 0.75] 

950 ``` 

951 - in Python: 

952 >>> postprocessing = [BinarizeDescr( 

953 ... kwargs=BinarizeAlongAxisKwargs( 

954 ... axis=AxisId('channel'), 

955 ... threshold=[0.25, 0.5, 0.75], 

956 ... ) 

957 ... )] 

958 """ 

959 

960 implemented_id: ClassVar[Literal["binarize"]] = "binarize" 

961 if TYPE_CHECKING: 

962 id: Literal["binarize"] = "binarize" 

963 else: 

964 id: Literal["binarize"] 

965 kwargs: Union[BinarizeKwargs, BinarizeAlongAxisKwargs] 

966 

967 

968class ClipDescr(ProcessingDescrBase): 

969 """Set tensor values below min to min and above max to max. 

970 

971 See `ScaleRangeDescr` for examples. 

972 """ 

973 

974 implemented_id: ClassVar[Literal["clip"]] = "clip" 

975 if TYPE_CHECKING: 

976 id: Literal["clip"] = "clip" 

977 else: 

978 id: Literal["clip"] 

979 

980 kwargs: ClipKwargs 

981 

982 

983class EnsureDtypeKwargs(ProcessingKwargs): 

984 """key word arguments for `EnsureDtypeDescr`""" 

985 

986 dtype: Literal[ 

987 "float32", 

988 "float64", 

989 "uint8", 

990 "int8", 

991 "uint16", 

992 "int16", 

993 "uint32", 

994 "int32", 

995 "uint64", 

996 "int64", 

997 "bool", 

998 ] 

999 

1000 

1001class EnsureDtypeDescr(ProcessingDescrBase): 

1002 """Cast the tensor data type to `EnsureDtypeKwargs.dtype` (if not matching). 

1003 

1004 This can for example be used to ensure the inner neural network model gets a 

1005 different input tensor data type than the fully described bioimage.io model does. 

1006 

1007 Examples: 

1008 The described bioimage.io model (incl. preprocessing) accepts any 

1009 float32-compatible tensor, normalizes it with percentiles and clipping and then 

1010 casts it to uint8, which is what the neural network in this example expects. 

1011 - in YAML 

1012 ```yaml 

1013 inputs: 

1014 - data: 

1015 type: float32 # described bioimage.io model is compatible with any float32 input tensor 

1016 preprocessing: 

1017 - id: scale_range 

1018 kwargs: 

1019 axes: ['y', 'x'] 

1020 max_percentile: 99.8 

1021 min_percentile: 5.0 

1022 - id: clip 

1023 kwargs: 

1024 min: 0.0 

1025 max: 1.0 

1026 - id: ensure_dtype # the neural network of the model requires uint8 

1027 kwargs: 

1028 dtype: uint8 

1029 ``` 

1030 - in Python: 

1031 >>> preprocessing = [ 

1032 ... ScaleRangeDescr( 

1033 ... kwargs=ScaleRangeKwargs( 

1034 ... axes= (AxisId('y'), AxisId('x')), 

1035 ... max_percentile= 99.8, 

1036 ... min_percentile= 5.0, 

1037 ... ) 

1038 ... ), 

1039 ... ClipDescr(kwargs=ClipKwargs(min=0.0, max=1.0)), 

1040 ... EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="uint8")), 

1041 ... ] 

1042 """ 

1043 

1044 implemented_id: ClassVar[Literal["ensure_dtype"]] = "ensure_dtype" 

1045 if TYPE_CHECKING: 

1046 id: Literal["ensure_dtype"] = "ensure_dtype" 

1047 else: 

1048 id: Literal["ensure_dtype"] 

1049 

1050 kwargs: EnsureDtypeKwargs 

1051 

1052 

1053class ScaleLinearKwargs(ProcessingKwargs): 

1054 """Key word arguments for `ScaleLinearDescr`""" 

1055 

1056 gain: float = 1.0 

1057 """multiplicative factor""" 

1058 

1059 offset: float = 0.0 

1060 """additive term""" 

1061 

1062 @model_validator(mode="after") 

1063 def _validate(self) -> Self: 

1064 if self.gain == 1.0 and self.offset == 0.0: 

1065 raise ValueError( 

1066 "Redundant linear scaling not allowd. Set `gain` != 1.0 and/or `offset`" 

1067 + " != 0.0." 

1068 ) 

1069 

1070 return self 

1071 

1072 

1073class ScaleLinearAlongAxisKwargs(ProcessingKwargs): 

1074 """Key word arguments for `ScaleLinearDescr`""" 

1075 

1076 axis: Annotated[NonBatchAxisId, Field(examples=["channel"])] 

1077 """The axis of gain and offset values.""" 

1078 

1079 gain: Union[float, NotEmpty[List[float]]] = 1.0 

1080 """multiplicative factor""" 

1081 

1082 offset: Union[float, NotEmpty[List[float]]] = 0.0 

1083 """additive term""" 

1084 

1085 @model_validator(mode="after") 

1086 def _validate(self) -> Self: 

1087 if isinstance(self.gain, list): 

1088 if isinstance(self.offset, list): 

1089 if len(self.gain) != len(self.offset): 

1090 raise ValueError( 

1091 f"Size of `gain` ({len(self.gain)}) and `offset` ({len(self.offset)}) must match." 

1092 ) 

1093 else: 

1094 self.offset = [float(self.offset)] * len(self.gain) 

1095 elif isinstance(self.offset, list): 

1096 self.gain = [float(self.gain)] * len(self.offset) 

1097 else: 

1098 raise ValueError( 

1099 "Do not specify an `axis` for scalar gain and offset values." 

1100 ) 

1101 

1102 if all(g == 1.0 for g in self.gain) and all(off == 0.0 for off in self.offset): 

1103 raise ValueError( 

1104 "Redundant linear scaling not allowd. Set `gain` != 1.0 and/or `offset`" 

1105 + " != 0.0." 

1106 ) 

1107 

1108 return self 

1109 
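# Editor's note: hypothetical sketch; the validator above broadcasts a scalar
# `offset` to match a per-axis `gain` list (and vice versa):
# >>> k = ScaleLinearAlongAxisKwargs(axis=AxisId("channel"), gain=[1.0, 2.0], offset=3.0)
# >>> k.offset
# [3.0, 3.0]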

1110 

1111class ScaleLinearDescr(ProcessingDescrBase): 

1112 """Fixed linear scaling. 

1113 

1114 Examples: 

1115 1. Scale with scalar gain and offset 

1116 - in YAML 

1117 ```yaml 

1118 preprocessing: 

1119 - id: scale_linear 

1120 kwargs: 

1121 gain: 2.0 

1122 offset: 3.0 

1123 ``` 

1124 - in Python: 

1125 >>> preprocessing = [ 

1126 ... ScaleLinearDescr(kwargs=ScaleLinearKwargs(gain= 2.0, offset=3.0)) 

1127 ... ] 

1128 

1129 2. Independent scaling along an axis 

1130 - in YAML 

1131 ```yaml 

1132 preprocessing: 

1133 - id: scale_linear 

1134 kwargs: 

1135 axis: 'channel' 

1136 gain: [1.0, 2.0, 3.0] 

1137 ``` 

1138 - in Python: 

1139 >>> preprocessing = [ 

1140 ... ScaleLinearDescr( 

1141 ... kwargs=ScaleLinearAlongAxisKwargs( 

1142 ... axis=AxisId("channel"), 

1143 ... gain=[1.0, 2.0, 3.0], 

1144 ... ) 

1145 ... ) 

1146 ... ] 

1147 

1148 """ 

1149 

1150 implemented_id: ClassVar[Literal["scale_linear"]] = "scale_linear" 

1151 if TYPE_CHECKING: 

1152 id: Literal["scale_linear"] = "scale_linear" 

1153 else: 

1154 id: Literal["scale_linear"] 

1155 kwargs: Union[ScaleLinearKwargs, ScaleLinearAlongAxisKwargs] 

1156 

1157 

1158class SigmoidDescr(ProcessingDescrBase): 

1159 """The logistic sigmoid function, a.k.a. expit function. 

1160 

1161 Examples: 

1162 - in YAML 

1163 ```yaml 

1164 postprocessing: 

1165 - id: sigmoid 

1166 ``` 

1167 - in Python: 

1168 >>> postprocessing = [SigmoidDescr()] 

1169 """ 

1170 

1171 implemented_id: ClassVar[Literal["sigmoid"]] = "sigmoid" 

1172 if TYPE_CHECKING: 

1173 id: Literal["sigmoid"] = "sigmoid" 

1174 else: 

1175 id: Literal["sigmoid"] 

1176 

1177 @property 

1178 def kwargs(self) -> ProcessingKwargs: 

1179 """empty kwargs""" 

1180 return ProcessingKwargs() 

1181 

1182 

1183class SoftmaxKwargs(ProcessingKwargs): 

1184 """key word arguments for `SoftmaxDescr`""" 

1185 

1186 axis: Annotated[NonBatchAxisId, Field(examples=["channel"])] = AxisId("channel") 

1187 """The axis to apply the softmax function along. 

1188 Note: 

1189 Defaults to the 'channel' axis, 

1190 which may not exist; in that case 

1191 a different axis id has to be specified. 

1192 """ 

1193 

1194 

1195class SoftmaxDescr(ProcessingDescrBase): 

1196 """The softmax function. 

1197 

1198 Examples: 

1199 - in YAML 

1200 ```yaml 

1201 postprocessing: 

1202 - id: softmax 

1203 kwargs: 

1204 axis: channel 

1205 ``` 

1206 - in Python: 

1207 >>> postprocessing = [SoftmaxDescr(kwargs=SoftmaxKwargs(axis=AxisId("channel")))] 

1208 """ 

1209 

1210 implemented_id: ClassVar[Literal["softmax"]] = "softmax" 

1211 if TYPE_CHECKING: 

1212 id: Literal["softmax"] = "softmax" 

1213 else: 

1214 id: Literal["softmax"] 

1215 

1216 kwargs: SoftmaxKwargs = Field(default_factory=SoftmaxKwargs.model_construct) 

1217 

1218 

1219class FixedZeroMeanUnitVarianceKwargs(ProcessingKwargs): 

1220 """key word arguments for `FixedZeroMeanUnitVarianceDescr`""" 

1221 

1222 mean: float 

1223 """The mean value to normalize with.""" 

1224 

1225 std: Annotated[float, Ge(1e-6)] 

1226 """The standard deviation value to normalize with.""" 

1227 

1228 

1229class FixedZeroMeanUnitVarianceAlongAxisKwargs(ProcessingKwargs): 

1230 """key word arguments for `FixedZeroMeanUnitVarianceDescr`""" 

1231 

1232 mean: NotEmpty[List[float]] 

1233 """The mean value(s) to normalize with.""" 

1234 

1235 std: NotEmpty[List[Annotated[float, Ge(1e-6)]]] 

1236 """The standard deviation value(s) to normalize with. 

1237 Size must match `mean` values.""" 

1238 

1239 axis: Annotated[NonBatchAxisId, Field(examples=["channel", "index"])] 

1240 """The axis of the mean/std values to normalize each entry along that dimension 

1241 separately.""" 

1242 

1243 @model_validator(mode="after") 

1244 def _mean_and_std_match(self) -> Self: 

1245 if len(self.mean) != len(self.std): 

1246 raise ValueError( 

1247 f"Size of `mean` ({len(self.mean)}) and `std` ({len(self.std)})" 

1248 + " must match." 

1249 ) 

1250 

1251 return self 

1252 

1253 

1254class FixedZeroMeanUnitVarianceDescr(ProcessingDescrBase): 

1255 """Subtract a given mean and divide by the standard deviation. 

1256 

1257 Normalize with fixed, precomputed values for 

1258 `FixedZeroMeanUnitVarianceKwargs.mean` and `FixedZeroMeanUnitVarianceKwargs.std`. 

1259 Use `FixedZeroMeanUnitVarianceAlongAxisKwargs` for independent scaling along given 

1260 axes. 

1261 

1262 Examples: 

1263 1. scalar value for whole tensor 

1264 - in YAML 

1265 ```yaml 

1266 preprocessing: 

1267 - id: fixed_zero_mean_unit_variance 

1268 kwargs: 

1269 mean: 103.5 

1270 std: 13.7 

1271 ``` 

1272 - in Python 

1273 >>> preprocessing = [FixedZeroMeanUnitVarianceDescr( 

1274 ... kwargs=FixedZeroMeanUnitVarianceKwargs(mean=103.5, std=13.7) 

1275 ... )] 

1276 

1277 2. independently along an axis 

1278 - in YAML 

1279 ```yaml 

1280 preprocessing: 

1281 - id: fixed_zero_mean_unit_variance 

1282 kwargs: 

1283 axis: channel 

1284 mean: [101.5, 102.5, 103.5] 

1285 std: [11.7, 12.7, 13.7] 

1286 ``` 

1287 - in Python 

1288 >>> preprocessing = [FixedZeroMeanUnitVarianceDescr( 

1289 ... kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs( 

1290 ... axis=AxisId("channel"), 

1291 ... mean=[101.5, 102.5, 103.5], 

1292 ... std=[11.7, 12.7, 13.7], 

1293 ... ) 

1294 ... )] 

1295 """ 

1296 

1297 implemented_id: ClassVar[Literal["fixed_zero_mean_unit_variance"]] = ( 

1298 "fixed_zero_mean_unit_variance" 

1299 ) 

1300 if TYPE_CHECKING: 

1301 id: Literal["fixed_zero_mean_unit_variance"] = "fixed_zero_mean_unit_variance" 

1302 else: 

1303 id: Literal["fixed_zero_mean_unit_variance"] 

1304 

1305 kwargs: Union[ 

1306 FixedZeroMeanUnitVarianceKwargs, FixedZeroMeanUnitVarianceAlongAxisKwargs 

1307 ] 

1308 

1309 

1310class ZeroMeanUnitVarianceKwargs(ProcessingKwargs): 

1311 """key word arguments for `ZeroMeanUnitVarianceDescr`""" 

1312 

1313 axes: Annotated[ 

1314 Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")]) 

1315 ] = None 

1316 """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std. 

1317 For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x') 

1318 resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`. 

1319 To normalize each sample independently leave out the 'batch' axis. 

1320 Default: Scale all axes jointly.""" 

1321 

1322 eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6 

1323 """epsilon for numeric stability: `out = (tensor - mean) / (std + eps)`.""" 

1324 

1325 

1326class ZeroMeanUnitVarianceDescr(ProcessingDescrBase): 

1327 """Subtract mean and divide by variance. 

1328 

1329 Examples: 

1330 Subtract the tensor mean and divide by its standard deviation 

1331 - in YAML 

1332 ```yaml 

1333 preprocessing: 

1334 - id: zero_mean_unit_variance 

1335 ``` 

1336 - in Python 

1337 >>> preprocessing = [ZeroMeanUnitVarianceDescr()] 

1338 """ 

1339 

1340 implemented_id: ClassVar[Literal["zero_mean_unit_variance"]] = ( 

1341 "zero_mean_unit_variance" 

1342 ) 

1343 if TYPE_CHECKING: 

1344 id: Literal["zero_mean_unit_variance"] = "zero_mean_unit_variance" 

1345 else: 

1346 id: Literal["zero_mean_unit_variance"] 

1347 

1348 kwargs: ZeroMeanUnitVarianceKwargs = Field( 

1349 default_factory=ZeroMeanUnitVarianceKwargs.model_construct 

1350 ) 

1351 

1352 

1353class ScaleRangeKwargs(ProcessingKwargs): 

1354 """key word arguments for `ScaleRangeDescr` 

1355 

1356 For `min_percentile`=0.0 (the default) and `max_percentile`=100 (the default) 

1357 this processing step normalizes data to the [0, 1] intervall. 

1358 For other percentiles the normalized values will partially be outside the [0, 1] 

1359 intervall. Use `ScaleRange` followed by `ClipDescr` if you want to limit the 

1360 normalized values to a range. 

1361 """ 

1362 

1363 axes: Annotated[ 

1364 Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")]) 

1365 ] = None 

1366 """The subset of axes to normalize jointly, i.e. axes to reduce to compute the min/max percentile value. 

1367 For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x') 

1368 resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`. 

1369 To normalize samples independently, leave out the "batch" axis. 

1370 Default: Scale all axes jointly.""" 

1371 

1372 min_percentile: Annotated[float, Interval(ge=0, lt=100)] = 0.0 

1373 """The lower percentile used to determine the value to align with zero.""" 

1374 

1375 max_percentile: Annotated[float, Interval(gt=1, le=100)] = 100.0 

1376 """The upper percentile used to determine the value to align with one. 

1377 Has to be bigger than `min_percentile`. 

1378 The range is 1 to 100 instead of 0 to 100 to avoid mistakenly 

1379 accepting percentiles specified in the range 0.0 to 1.0.""" 

1380 

1381 eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6 

1382 """Epsilon for numeric stability. 

1383 `out = (tensor - v_lower) / (v_upper - v_lower + eps)`; 

1384 with `v_lower,v_upper` values at the respective percentiles.""" 

1385 

1386 reference_tensor: Optional[TensorId] = None 

1387 """Tensor ID to compute the percentiles from. Default: The tensor itself. 

1388 For any tensor in `inputs` only input tensor references are allowed.""" 

1389 

1390 @field_validator("max_percentile", mode="after") 

1391 @classmethod 

1392 def min_smaller_max(cls, value: float, info: ValidationInfo) -> float: 

1393 if (min_p := info.data["min_percentile"]) >= value: 

1394 raise ValueError(f"min_percentile {min_p} >= max_percentile {value}") 

1395 

1396 return value 

1397 

1398 

1399class ScaleRangeDescr(ProcessingDescrBase): 

1400 """Scale with percentiles. 

1401 

1402 Examples: 

1403 1. Scale linearly to map 5th percentile to 0 and 99.8th percentile to 1.0 

1404 - in YAML 

1405 ```yaml 

1406 preprocessing: 

1407 - id: scale_range 

1408 kwargs: 

1409 axes: ['y', 'x'] 

1410 max_percentile: 99.8 

1411 min_percentile: 5.0 

1412 ``` 

1413 - in Python 

1414 >>> preprocessing = [ 

1415 ... ScaleRangeDescr( 

1416 ... kwargs=ScaleRangeKwargs( 

1417 ... axes= (AxisId('y'), AxisId('x')), 

1418 ... max_percentile= 99.8, 

1419 ... min_percentile= 5.0, 

1420 ... ) 

1421 ... ) 

1422 ... ] 

1423 

1424 2. Combine the above scaling with additional clipping to clip values outside the range given by the percentiles. 

1425 - in YAML 

1426 ```yaml 

1427 preprocessing: 

1428 - id: scale_range 

1429 kwargs: 

1430 axes: ['y', 'x'] 

1431 max_percentile: 99.8 

1432 min_percentile: 5.0 

1434 - id: clip 

1435 kwargs: 

1436 min: 0.0 

1437 max: 1.0 

1438 ``` 

1439 - in Python 

1440 >>> preprocessing = [ 

1441 ... ScaleRangeDescr( 

1442 ... kwargs=ScaleRangeKwargs( 

1443 ... axes= (AxisId('y'), AxisId('x')), 

1444 ... max_percentile= 99.8, 

1445 ... min_percentile= 5.0, 

1446 ... ) 

1447 ... ), 

1448 ... ClipDescr( 

1449 ... kwargs=ClipKwargs( 

1450 ... min=0.0, 

1451 ... max=1.0, 

1452 ... ) 

1453 ... ), 

1454 ... ] 

1455 

1456 """ 

1457 

1458 implemented_id: ClassVar[Literal["scale_range"]] = "scale_range" 

1459 if TYPE_CHECKING: 

1460 id: Literal["scale_range"] = "scale_range" 

1461 else: 

1462 id: Literal["scale_range"] 

1463 kwargs: ScaleRangeKwargs = Field(default_factory=ScaleRangeKwargs.model_construct) 

1464 

1465 

1466class ScaleMeanVarianceKwargs(ProcessingKwargs): 

1467 """key word arguments for `ScaleMeanVarianceKwargs`""" 

1468 

1469 reference_tensor: TensorId 

1470 """Name of tensor to match.""" 

1471 

1472 axes: Annotated[ 

1473 Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")]) 

1474 ] = None 

1475 """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std. 

1476 For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x') 

1477 resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`. 

1478 To normalize samples independently, leave out the 'batch' axis. 

1479 Default: Scale all axes jointly.""" 

1480 

1481 eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6 

1482 """Epsilon for numeric stability: 

1483 `out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.`""" 

1484 

1485 

1486class ScaleMeanVarianceDescr(ProcessingDescrBase): 

1487 """Scale a tensor's data distribution to match another tensor's mean/std. 

1488 `out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.` 

1489 """ 

1490 

1491 implemented_id: ClassVar[Literal["scale_mean_variance"]] = "scale_mean_variance" 

1492 if TYPE_CHECKING: 

1493 id: Literal["scale_mean_variance"] = "scale_mean_variance" 

1494 else: 

1495 id: Literal["scale_mean_variance"] 

1496 kwargs: ScaleMeanVarianceKwargs 

1497 

1498 

1499PreprocessingDescr = Annotated[ 

1500 Union[ 

1501 BinarizeDescr, 

1502 ClipDescr, 

1503 EnsureDtypeDescr, 

1504 FixedZeroMeanUnitVarianceDescr, 

1505 ScaleLinearDescr, 

1506 ScaleRangeDescr, 

1507 SigmoidDescr, 

1508 SoftmaxDescr, 

1509 ZeroMeanUnitVarianceDescr, 

1510 ], 

1511 Discriminator("id"), 

1512] 

1513PostprocessingDescr = Annotated[ 

1514 Union[ 

1515 BinarizeDescr, 

1516 ClipDescr, 

1517 EnsureDtypeDescr, 

1518 FixedZeroMeanUnitVarianceDescr, 

1519 ScaleLinearDescr, 

1520 ScaleMeanVarianceDescr, 

1521 ScaleRangeDescr, 

1522 SigmoidDescr, 

1523 SoftmaxDescr, 

1524 ZeroMeanUnitVarianceDescr, 

1525 ], 

1526 Discriminator("id"), 

1527] 

1528 

1529IO_AxisT = TypeVar("IO_AxisT", InputAxis, OutputAxis) 

1530 

1531 

1532class TensorDescrBase(Node, Generic[IO_AxisT]): 

1533 id: TensorId 

1534 """Tensor id. No duplicates are allowed.""" 

1535 

1536 description: Annotated[str, MaxLen(128)] = "" 

1537 """free text description""" 

1538 

1539 axes: NotEmpty[Sequence[IO_AxisT]] 

1540 """tensor axes""" 

1541 

1542 @property 

1543 def shape(self): 

1544 return tuple(a.size for a in self.axes) 

1545 

1546 @field_validator("axes", mode="after", check_fields=False) 

1547 @classmethod 

1548 def _validate_axes(cls, axes: Sequence[AnyAxis]) -> Sequence[AnyAxis]: 

1549 batch_axes = [a for a in axes if a.type == "batch"] 

1550 if len(batch_axes) > 1: 

1551 raise ValueError( 

1552 f"Only one batch axis (per tensor) allowed, but got {batch_axes}" 

1553 ) 

1554 

1555 seen_ids: Set[AxisId] = set() 

1556 duplicate_axes_ids: Set[AxisId] = set() 

1557 for a in axes: 

1558 (duplicate_axes_ids if a.id in seen_ids else seen_ids).add(a.id) 

1559 

1560 if duplicate_axes_ids: 

1561 raise ValueError(f"Duplicate axis ids: {duplicate_axes_ids}") 

1562 

1563 return axes 

1564 

1565 test_tensor: FAIR[Optional[FileDescr_]] = None 

1566 """An example tensor to use for testing. 

1567 Using the model with the test input tensors is expected to yield the test output tensors. 

1568 Each test tensor has to be an ndarray in the 

1569 [numpy.lib file format](https://numpy.org/doc/stable/reference/generated/numpy.lib.format.html#module-numpy.lib.format). 

1570 The file extension must be '.npy'.""" 

1571 

1572 sample_tensor: FAIR[Optional[FileDescr_]] = None 

1573 """A sample tensor to illustrate a possible input/output for the model, 

1574 The sample image primarily serves to inform a human user about an example use case 

1575 and is typically stored as .hdf5, .png or .tiff. 

1576 It has to be readable by the [imageio library](https://imageio.readthedocs.io/en/stable/formats/index.html#supported-formats) 

1577 (numpy's `.npy` format is not supported). 

1578 The image dimensionality has to match the number of axes specified in this tensor description. 

1579 """ 

1580 

1581 @model_validator(mode="after") 

1582 def _validate_sample_tensor(self) -> Self: 

1583 if self.sample_tensor is None or not get_validation_context().perform_io_checks: 

1584 return self 

1585 

1586 reader = get_reader(self.sample_tensor.source, sha256=self.sample_tensor.sha256) 

1587 tensor: NDArray[Any] = imread( # pyright: ignore[reportUnknownVariableType] 

1588 reader.read(), 

1589 extension=PurePosixPath(reader.original_file_name).suffix, 

1590 ) 

1591 n_dims = len(tensor.squeeze().shape) 

1592 n_dims_min = n_dims_max = len(self.axes) 

1593 

1594 for a in self.axes: 

1595 if isinstance(a, BatchAxis): 

1596 n_dims_min -= 1 

1597 elif isinstance(a.size, int): 

1598 if a.size == 1: 

1599 n_dims_min -= 1 

1600 elif isinstance(a.size, (ParameterizedSize, DataDependentSize)): 

1601 if a.size.min == 1: 

1602 n_dims_min -= 1 

1603 elif isinstance(a.size, SizeReference): 

1604 if a.size.offset < 2: 

1605 # size reference may result in singleton axis 

1606 n_dims_min -= 1 

1607 else: 

1608 assert_never(a.size) 

1609 

1610 n_dims_min = max(0, n_dims_min) 

1611 if n_dims < n_dims_min or n_dims > n_dims_max: 

1612 raise ValueError( 

1613 f"Expected sample tensor to have {n_dims_min} to" 

1614 + f" {n_dims_max} dimensions, but found {n_dims} (shape: {tensor.shape})." 

1615 ) 

1616 

1617 return self 

1618 

1619 data: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]] = ( 

1620 IntervalOrRatioDataDescr() 

1621 ) 

1622 """Description of the tensor's data values, optionally per channel. 

1623 If specified per channel, the data `type` needs to match across channels.""" 

1624 

1625 @property 

1626 def dtype( 

1627 self, 

1628 ) -> Literal[ 

1629 "float32", 

1630 "float64", 

1631 "uint8", 

1632 "int8", 

1633 "uint16", 

1634 "int16", 

1635 "uint32", 

1636 "int32", 

1637 "uint64", 

1638 "int64", 

1639 "bool", 

1640 ]: 

1641 """dtype as specified under `data.type` or `data[i].type`""" 

1642 if isinstance(self.data, collections.abc.Sequence): 

1643 return self.data[0].type 

1644 else: 

1645 return self.data.type 

1646 

1647 @field_validator("data", mode="after") 

1648 @classmethod 

1649 def _check_data_type_across_channels( 

1650 cls, value: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]] 

1651 ) -> Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]]: 

1652 if not isinstance(value, list): 

1653 return value 

1654 

1655 dtypes = {t.type for t in value} 

1656 if len(dtypes) > 1: 

1657 raise ValueError( 

1658 "Tensor data descriptions per channel need to agree in their data" 

1659 + f" `type`, but found {dtypes}." 

1660 ) 

1661 

1662 return value 

1663 

1664 @model_validator(mode="after") 

1665 def _check_data_matches_channelaxis(self) -> Self: 

1666 if not isinstance(self.data, (list, tuple)): 

1667 return self 

1668 

1669 for a in self.axes: 

1670 if isinstance(a, ChannelAxis): 

1671 size = a.size 

1672 assert isinstance(size, int) 

1673 break 

1674 else: 

1675 return self 

1676 

1677 if len(self.data) != size: 

1678 raise ValueError( 

1679 f"Got tensor data descriptions for {len(self.data)} channels, but" 

1680 + f" '{a.id}' axis has size {size}." 

1681 ) 

1682 

1683 return self 

1684 

1685 def get_axis_sizes_for_array(self, array: NDArray[Any]) -> Dict[AxisId, int]: 

1686 if len(array.shape) != len(self.axes): 

1687 raise ValueError( 

1688 f"Dimension mismatch: array shape {array.shape} (#{len(array.shape)})" 

1689 + f" incompatible with {len(self.axes)} axes." 

1690 ) 

1691 return {a.id: array.shape[i] for i, a in enumerate(self.axes)} 

1692 
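# Editor's note: for a tensor description with axes (batch, x), an array of
# shape (1, 64) maps to {AxisId('batch'): 1, AxisId('x'): 64}; an array whose
# number of dimensions does not match the axes raises a ValueError.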

1693 

1694class InputTensorDescr(TensorDescrBase[InputAxis]): 

1695 id: TensorId = TensorId("input") 

1696 """Input tensor id. 

1697 No duplicates are allowed across all inputs and outputs.""" 

1698 

1699 optional: bool = False 

1700 """indicates that this tensor may be `None`""" 

1701 

1702 preprocessing: List[PreprocessingDescr] = Field( 

1703 default_factory=cast(Callable[[], List[PreprocessingDescr]], list) 

1704 ) 

1705 

1706 """Description of how this input should be preprocessed. 

1707 

1708 notes: 

1709 - If preprocessing does not start with an 'ensure_dtype' entry, it is added 

1710 to ensure an input tensor's data type matches the input tensor's data description. 

1711 - If preprocessing does not end with an 'ensure_dtype' or 'binarize' entry, an 

1712 'ensure_dtype' step is added to ensure preprocessing steps are not unintentionally 

1713 changing the data type. 

1714 """ 

1715 

1716 @model_validator(mode="after") 

1717 def _validate_preprocessing_kwargs(self) -> Self: 

1718 axes_ids = [a.id for a in self.axes] 

1719 for p in self.preprocessing: 

1720 kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes") 

1721 if kwargs_axes is None: 

1722 continue 

1723 

1724 if not isinstance(kwargs_axes, collections.abc.Sequence): 

1725 raise ValueError( 

1726 f"Expected `preprocessing.i.kwargs.axes` to be a sequence, but got {type(kwargs_axes)}" 

1727 ) 

1728 

1729 if any(a not in axes_ids for a in kwargs_axes): 

1730 raise ValueError( 

1731 "`preprocessing.i.kwargs.axes` needs to be subset of axes ids" 

1732 ) 

1733 

1734 if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)): 

1735 dtype = self.data.type 

1736 else: 

1737 dtype = self.data[0].type 

1738 

1739 # ensure `preprocessing` begins with `EnsureDtypeDescr` 

1740 if not self.preprocessing or not isinstance( 

1741 self.preprocessing[0], EnsureDtypeDescr 

1742 ): 

1743 self.preprocessing.insert( 

1744 0, EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype)) 

1745 ) 

1746 

1747 # ensure `preprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr` 

1748 if not isinstance(self.preprocessing[-1], (EnsureDtypeDescr, BinarizeDescr)): 

1749 self.preprocessing.append( 

1750 EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype)) 

1751 ) 

1752 

1753 return self 

1754 
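# Editor's note: a hypothetical sketch (assuming the `implemented_*` machinery
# fills the `type`/`id` fields at runtime) of the automatic `ensure_dtype`
# bracketing performed by the validator above:
# >>> descr = InputTensorDescr(
# ...     id=TensorId("input"),
# ...     axes=[BatchAxis(), SpaceInputAxis(id=AxisId("x"), size=64)],
# ...     preprocessing=[SigmoidDescr()],
# ... )
# >>> [p.id for p in descr.preprocessing]
# ['ensure_dtype', 'sigmoid', 'ensure_dtype']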

1755 

1756def convert_axes( 

1757 axes: str, 

1758 *, 

1759 shape: Union[ 

1760 Sequence[int], _ParameterizedInputShape_v0_4, _ImplicitOutputShape_v0_4 

1761 ], 

1762 tensor_type: Literal["input", "output"], 

1763 halo: Optional[Sequence[int]], 

1764 size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]], 

1765): 

1766 ret: List[AnyAxis] = [] 

1767 for i, a in enumerate(axes): 

1768 axis_type = _AXIS_TYPE_MAP.get(a, a) 

1769 if axis_type == "batch": 

1770 ret.append(BatchAxis()) 

1771 continue 

1772 

1773 scale = 1.0 

1774 if isinstance(shape, _ParameterizedInputShape_v0_4): 

1775 if shape.step[i] == 0: 

1776 size = shape.min[i] 

1777 else: 

1778 size = ParameterizedSize(min=shape.min[i], step=shape.step[i]) 

1779 elif isinstance(shape, _ImplicitOutputShape_v0_4): 

1780 ref_t = str(shape.reference_tensor) 

1781 if ref_t.count(".") == 1: 

1782 t_id, orig_a_id = ref_t.split(".") 

1783 else: 

1784 t_id = ref_t 

1785 orig_a_id = a 

1786 

1787 a_id = _AXIS_ID_MAP.get(orig_a_id, a) 

1788 if not (orig_scale := shape.scale[i]): 

1789 # old way to insert a new axis dimension 

1790 size = int(2 * shape.offset[i]) 

1791 else: 

1792 scale = 1 / orig_scale 

1793 if axis_type in ("channel", "index"): 

1794 # these axes no longer have a scale 

1795 offset_from_scale = orig_scale * size_refs.get( 

1796 _TensorName_v0_4(t_id), {} 

1797 ).get(orig_a_id, 0) 

1798 else: 

1799 offset_from_scale = 0 

1800 size = SizeReference( 

1801 tensor_id=TensorId(t_id), 

1802 axis_id=AxisId(a_id), 

1803 offset=int(offset_from_scale + 2 * shape.offset[i]), 

1804 ) 

1805 else: 

1806 size = shape[i] 

1807 

1808 if axis_type == "time": 

1809 if tensor_type == "input": 

1810 ret.append(TimeInputAxis(size=size, scale=scale)) 

1811 else: 

1812 assert not isinstance(size, ParameterizedSize) 

1813 if halo is None: 

1814 ret.append(TimeOutputAxis(size=size, scale=scale)) 

1815 else: 

1816 assert not isinstance(size, int) 

1817 ret.append( 

1818 TimeOutputAxisWithHalo(size=size, scale=scale, halo=halo[i]) 

1819 ) 

1820 

1821 elif axis_type == "index": 

1822 if tensor_type == "input": 

1823 ret.append(IndexInputAxis(size=size)) 

1824 else: 

1825 if isinstance(size, ParameterizedSize): 

1826 size = DataDependentSize(min=size.min) 

1827 

1828 ret.append(IndexOutputAxis(size=size)) 

1829 elif axis_type == "channel": 

1830 assert not isinstance(size, ParameterizedSize) 

1831 if isinstance(size, SizeReference): 

1832 warnings.warn( 

1833 "Conversion of channel size from an implicit output shape may be" 

1834 + " wrong" 

1835 ) 

1836 ret.append( 

1837 ChannelAxis( 

1838 channel_names=[ 

1839 Identifier(f"channel{i}") for i in range(size.offset) 

1840 ] 

1841 ) 

1842 ) 

1843 else: 

1844 ret.append( 

1845 ChannelAxis( 

1846 channel_names=[Identifier(f"channel{i}") for i in range(size)] 

1847 ) 

1848 ) 

1849 elif axis_type == "space": 

1850 if tensor_type == "input": 

1851 ret.append(SpaceInputAxis(id=AxisId(a), size=size, scale=scale)) 

1852 else: 

1853 assert not isinstance(size, ParameterizedSize) 

1854 if halo is None or halo[i] == 0: 

1855 ret.append(SpaceOutputAxis(id=AxisId(a), size=size, scale=scale)) 

1856 elif isinstance(size, int): 

1857 raise NotImplementedError( 

1858 f"output axis with halo and fixed size (here {size}) not allowed" 

1859 ) 

1860 else: 

1861 ret.append( 

1862 SpaceOutputAxisWithHalo( 

1863 id=AxisId(a), size=size, scale=scale, halo=halo[i] 

1864 ) 

1865 ) 

1866 

1867 return ret 

1868 

1869 
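# A minimal sketch of what `convert_axes` produces for a fixed v0.4 input
# shape (assuming the usual `_AXIS_TYPE_MAP` letter mapping, e.g.
# 'b'->'batch', 'c'->'channel', 'y'/'x'->'space'):
#
#   convert_axes("bcyx", shape=[1, 3, 256, 256], tensor_type="input",
#                halo=None, size_refs={})
#   # -> [BatchAxis(),
#   #     ChannelAxis(channel_names=[Identifier("channel0"), ...]),
#   #     SpaceInputAxis(id=AxisId("y"), size=256, scale=1.0),
#   #     SpaceInputAxis(id=AxisId("x"), size=256, scale=1.0)]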

1870def _axes_letters_to_ids( 

1871 axes: Optional[str], 

1872) -> Optional[List[AxisId]]: 

1873 if axes is None: 

1874 return None 

1875 

1876 return [AxisId(a) for a in axes] 

1877 

1878 

1879def _get_complement_v04_axis( 

1880 tensor_axes: Sequence[str], axes: Optional[Sequence[str]] 

1881) -> Optional[AxisId]: 

1882 if axes is None: 

1883 return None 

1884 

1885 non_complement_axes = set(axes) | {"b"} 

1886 complement_axes = [a for a in tensor_axes if a not in non_complement_axes] 

1887 if len(complement_axes) > 1: 

1888 raise ValueError( 

1889 f"Expected none or a single complement axis, but axes '{axes}' " 

1890 + f"for tensor dims '{tensor_axes}' leave '{complement_axes}'." 

1891 ) 

1892 

1893 return None if not complement_axes else AxisId(complement_axes[0]) 

1894 
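# For example (a sketch grounded in the logic above): with tensor dims "bcyx"
# the batch axis 'b' never counts as a complement axis, so
#
#   _get_complement_v04_axis("bcyx", "yx")   # -> AxisId("c")
#   _get_complement_v04_axis("bcyx", "cyx")  # -> None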

1895 

1896def _convert_proc( 

1897 p: Union[_PreprocessingDescr_v0_4, _PostprocessingDescr_v0_4], 

1898 tensor_axes: Sequence[str], 

1899) -> Union[PreprocessingDescr, PostprocessingDescr]: 

1900 if isinstance(p, _BinarizeDescr_v0_4): 

1901 return BinarizeDescr(kwargs=BinarizeKwargs(threshold=p.kwargs.threshold)) 

1902 elif isinstance(p, _ClipDescr_v0_4): 

1903 return ClipDescr(kwargs=ClipKwargs(min=p.kwargs.min, max=p.kwargs.max)) 

1904 elif isinstance(p, _SigmoidDescr_v0_4): 

1905 return SigmoidDescr() 

1906 elif isinstance(p, _ScaleLinearDescr_v0_4): 

1907 axes = _axes_letters_to_ids(p.kwargs.axes) 

1908 if p.kwargs.axes is None: 

1909 axis = None 

1910 else: 

1911 axis = _get_complement_v04_axis(tensor_axes, p.kwargs.axes) 

1912 

1913 if axis is None: 

1914 assert not isinstance(p.kwargs.gain, list) 

1915 assert not isinstance(p.kwargs.offset, list) 

1916 kwargs = ScaleLinearKwargs(gain=p.kwargs.gain, offset=p.kwargs.offset) 

1917 else: 

1918 kwargs = ScaleLinearAlongAxisKwargs( 

1919 axis=axis, gain=p.kwargs.gain, offset=p.kwargs.offset 

1920 ) 

1921 return ScaleLinearDescr(kwargs=kwargs) 

1922 elif isinstance(p, _ScaleMeanVarianceDescr_v0_4): 

1923 return ScaleMeanVarianceDescr( 

1924 kwargs=ScaleMeanVarianceKwargs( 

1925 axes=_axes_letters_to_ids(p.kwargs.axes), 

1926 reference_tensor=TensorId(str(p.kwargs.reference_tensor)), 

1927 eps=p.kwargs.eps, 

1928 ) 

1929 ) 

1930 elif isinstance(p, _ZeroMeanUnitVarianceDescr_v0_4): 

1931 if p.kwargs.mode == "fixed": 

1932 mean = p.kwargs.mean 

1933 std = p.kwargs.std 

1934 assert mean is not None 

1935 assert std is not None 

1936 

1937 axis = _get_complement_v04_axis(tensor_axes, p.kwargs.axes) 

1938 

1939 if axis is None: 

1940 if isinstance(mean, list): 

1941 raise ValueError("Expected single float value for mean, not <list>") 

1942 if isinstance(std, list): 

1943 raise ValueError("Expected single float value for std, not <list>") 

1944 return FixedZeroMeanUnitVarianceDescr( 

1945 kwargs=FixedZeroMeanUnitVarianceKwargs.model_construct( 

1946 mean=mean, 

1947 std=std, 

1948 ) 

1949 ) 

1950 else: 

1951 if not isinstance(mean, list): 

1952 mean = [float(mean)] 

1953 if not isinstance(std, list): 

1954 std = [float(std)] 

1955 

1956 return FixedZeroMeanUnitVarianceDescr( 

1957 kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs( 

1958 axis=axis, mean=mean, std=std 

1959 ) 

1960 ) 

1961 

1962 else: 

1963 axes = _axes_letters_to_ids(p.kwargs.axes) or [] 

1964 if p.kwargs.mode == "per_dataset": 

1965 axes = [AxisId("batch")] + axes 

1966 if not axes: 

1967 axes = None 

1968 return ZeroMeanUnitVarianceDescr( 

1969 kwargs=ZeroMeanUnitVarianceKwargs(axes=axes, eps=p.kwargs.eps) 

1970 ) 

1971 

1972 elif isinstance(p, _ScaleRangeDescr_v0_4): 

1973 return ScaleRangeDescr( 

1974 kwargs=ScaleRangeKwargs( 

1975 axes=_axes_letters_to_ids(p.kwargs.axes), 

1976 min_percentile=p.kwargs.min_percentile, 

1977 max_percentile=p.kwargs.max_percentile, 

1978 eps=p.kwargs.eps, 

1979 ) 

1980 ) 

1981 else: 

1982 assert_never(p) 

1983 
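# Sketch of the zero_mean_unit_variance mode mapping implemented above
# (the eps value is hypothetical): a v0.4 entry with mode="per_dataset" and
# axes "yx" becomes
#
#   ZeroMeanUnitVarianceDescr(
#       kwargs=ZeroMeanUnitVarianceKwargs(
#           axes=[AxisId("batch"), AxisId("y"), AxisId("x")], eps=1e-6
#       )
#   )
#
# whereas mode="fixed" is mapped to a `FixedZeroMeanUnitVarianceDescr`.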

1984 

1985class _InputTensorConv( 

1986 Converter[ 

1987 _InputTensorDescr_v0_4, 

1988 InputTensorDescr, 

1989 FileSource_, 

1990 Optional[FileSource_], 

1991 Mapping[_TensorName_v0_4, Mapping[str, int]], 

1992 ] 

1993): 

1994 def _convert( 

1995 self, 

1996 src: _InputTensorDescr_v0_4, 

1997 tgt: "type[InputTensorDescr] | type[dict[str, Any]]", 

1998 test_tensor: FileSource_, 

1999 sample_tensor: Optional[FileSource_], 

2000 size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]], 

2001 ) -> "InputTensorDescr | dict[str, Any]": 

2002 axes: List[InputAxis] = convert_axes( # pyright: ignore[reportAssignmentType] 

2003 src.axes, 

2004 shape=src.shape, 

2005 tensor_type="input", 

2006 halo=None, 

2007 size_refs=size_refs, 

2008 ) 

2009 prep: List[PreprocessingDescr] = [] 

2010 for p in src.preprocessing: 

2011 cp = _convert_proc(p, src.axes) 

2012 assert not isinstance(cp, ScaleMeanVarianceDescr) 

2013 prep.append(cp) 

2014 

2015 prep.append(EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="float32"))) 

2016 

2017 return tgt( 

2018 axes=axes, 

2019 id=TensorId(str(src.name)), 

2020 test_tensor=FileDescr(source=test_tensor), 

2021 sample_tensor=( 

2022 None if sample_tensor is None else FileDescr(source=sample_tensor) 

2023 ), 

2024 data=dict(type=src.data_type), # pyright: ignore[reportArgumentType] 

2025 preprocessing=prep, 

2026 ) 

2027 

2028 

2029_input_tensor_conv = _InputTensorConv(_InputTensorDescr_v0_4, InputTensorDescr) 

2030 

2031 

2032class OutputTensorDescr(TensorDescrBase[OutputAxis]): 

2033 id: TensorId = TensorId("output") 

2034 """Output tensor id. 

2035 No duplicates are allowed across all inputs and outputs.""" 

2036 

2037 postprocessing: List[PostprocessingDescr] = Field( 

2038 default_factory=cast(Callable[[], List[PostprocessingDescr]], list) 

2039 ) 

2040 """Description of how this output should be postprocessed. 

2041 

2042 note: `postprocessing` always ends with an 'ensure_dtype' operation. 

2043 If not given, one is appended to cast to this tensor's `data.type`.

2044 """ 

2045 

2046 @model_validator(mode="after") 

2047 def _validate_postprocessing_kwargs(self) -> Self: 

2048 axes_ids = [a.id for a in self.axes] 

2049 for p in self.postprocessing: 

2050 kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes") 

2051 if kwargs_axes is None: 

2052 continue 

2053 

2054 if not isinstance(kwargs_axes, collections.abc.Sequence): 

2055 raise ValueError( 

2056 f"expected `axes` sequence, but got {type(kwargs_axes)}" 

2057 ) 

2058 

2059 if any(a not in axes_ids for a in kwargs_axes): 

2060 raise ValueError("`kwargs.axes` needs to be a subset of the axes ids")

2061 

2062 if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)): 

2063 dtype = self.data.type 

2064 else: 

2065 dtype = self.data[0].type 

2066 

2067 # ensure `postprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr` 

2068 if not self.postprocessing or not isinstance( 

2069 self.postprocessing[-1], (EnsureDtypeDescr, BinarizeDescr) 

2070 ): 

2071 self.postprocessing.append( 

2072 EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype)) 

2073 ) 

2074 return self 

2075 

2076 

2077class _OutputTensorConv( 

2078 Converter[ 

2079 _OutputTensorDescr_v0_4, 

2080 OutputTensorDescr, 

2081 FileSource_, 

2082 Optional[FileSource_], 

2083 Mapping[_TensorName_v0_4, Mapping[str, int]], 

2084 ] 

2085): 

2086 def _convert( 

2087 self, 

2088 src: _OutputTensorDescr_v0_4, 

2089 tgt: "type[OutputTensorDescr] | type[dict[str, Any]]", 

2090 test_tensor: FileSource_, 

2091 sample_tensor: Optional[FileSource_], 

2092 size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]], 

2093 ) -> "OutputTensorDescr | dict[str, Any]": 

2094 # TODO: split convert_axes into convert_output_axes and convert_input_axes 

2095 axes: List[OutputAxis] = convert_axes( # pyright: ignore[reportAssignmentType] 

2096 src.axes, 

2097 shape=src.shape, 

2098 tensor_type="output", 

2099 halo=src.halo, 

2100 size_refs=size_refs, 

2101 ) 

2102 data_descr: Dict[str, Any] = dict(type=src.data_type) 

2103 if data_descr["type"] == "bool": 

2104 data_descr["values"] = [False, True] 

2105 

2106 return tgt( 

2107 axes=axes, 

2108 id=TensorId(str(src.name)), 

2109 test_tensor=FileDescr(source=test_tensor), 

2110 sample_tensor=( 

2111 None if sample_tensor is None else FileDescr(source=sample_tensor) 

2112 ), 

2113 data=data_descr, # pyright: ignore[reportArgumentType] 

2114 postprocessing=[_convert_proc(p, src.axes) for p in src.postprocessing], 

2115 ) 

2116 

2117 

2118_output_tensor_conv = _OutputTensorConv(_OutputTensorDescr_v0_4, OutputTensorDescr) 

2119 

2120 

2121TensorDescr = Union[InputTensorDescr, OutputTensorDescr] 

2122 

2123 

2124def validate_tensors( 

2125 tensors: Mapping[TensorId, Tuple[TensorDescr, Optional[NDArray[Any]]]], 

2126 tensor_origin: Literal[ 

2127 "test_tensor" 

2128 ], # for more precise error messages, e.g. 'test_tensor' 

2129): 

2130 all_tensor_axes: Dict[TensorId, Dict[AxisId, Tuple[AnyAxis, Optional[int]]]] = {} 

2131 

2132 def e_msg(d: TensorDescr): 

2133 return f"{'inputs' if isinstance(d, InputTensorDescr) else 'outputs'}[{d.id}]" 

2134 

2135 for descr, array in tensors.values(): 

2136 if array is None: 

2137 axis_sizes = {a.id: None for a in descr.axes} 

2138 else: 

2139 try: 

2140 axis_sizes = descr.get_axis_sizes_for_array(array) 

2141 except ValueError as e: 

2142 raise ValueError(f"{e_msg(descr)} {e}") 

2143 

2144 all_tensor_axes[descr.id] = {a.id: (a, axis_sizes[a.id]) for a in descr.axes} 

2145 

2146 for descr, array in tensors.values(): 

2147 if array is None: 

2148 continue 

2149 

2150 if descr.dtype in ("float32", "float64"): 

2151 invalid_test_tensor_dtype = array.dtype.name not in ( 

2152 "float32", 

2153 "float64", 

2154 "uint8", 

2155 "int8", 

2156 "uint16", 

2157 "int16", 

2158 "uint32", 

2159 "int32", 

2160 "uint64", 

2161 "int64", 

2162 ) 

2163 else: 

2164 invalid_test_tensor_dtype = array.dtype.name != descr.dtype 

2165 

2166 if invalid_test_tensor_dtype: 

2167 raise ValueError( 

2168 f"{e_msg(descr)}.{tensor_origin}.dtype '{array.dtype.name}' does not" 

2169 + f" match described dtype '{descr.dtype}'" 

2170 ) 

2171 

2172 if array.min() > -1e-4 and array.max() < 1e-4: 

2173 raise ValueError( 

2174 "Output values are too small for reliable testing." 

2175 + f" Values <-1e5 or >=1e5 must be present in {tensor_origin}" 

2176 ) 

2177 

2178 for a in descr.axes: 

2179 actual_size = all_tensor_axes[descr.id][a.id][1] 

2180 if actual_size is None: 

2181 continue 

2182 

2183 if a.size is None: 

2184 continue 

2185 

2186 if isinstance(a.size, int): 

2187 if actual_size != a.size: 

2188 raise ValueError( 

2189 f"{e_msg(descr)}.{tensor_origin}: axis '{a.id}' " 

2190 + f"has incompatible size {actual_size}, expected {a.size}" 

2191 ) 

2192 elif isinstance(a.size, ParameterizedSize): 

2193 _ = a.size.validate_size(actual_size) 

2194 elif isinstance(a.size, DataDependentSize): 

2195 _ = a.size.validate_size(actual_size) 

2196 elif isinstance(a.size, SizeReference): 

2197 ref_tensor_axes = all_tensor_axes.get(a.size.tensor_id) 

2198 if ref_tensor_axes is None: 

2199 raise ValueError( 

2200 f"{e_msg(descr)}.axes[{a.id}].size.tensor_id: Unknown tensor" 

2201 + f" reference '{a.size.tensor_id}'" 

2202 ) 

2203 

2204 ref_axis, ref_size = ref_tensor_axes.get(a.size.axis_id, (None, None)) 

2205 if ref_axis is None or ref_size is None: 

2206 raise ValueError( 

2207 f"{e_msg(descr)}.axes[{a.id}].size.axis_id: Unknown tensor axis" 

2208 + f" reference '{a.size.tensor_id}.{a.size.axis_id}" 

2209 ) 

2210 

2211 if a.unit != ref_axis.unit: 

2212 raise ValueError( 

2213 f"{e_msg(descr)}.axes[{a.id}].size: `SizeReference` requires" 

2214 + " axis and reference axis to have the same `unit`, but" 

2215 + f" {a.unit}!={ref_axis.unit}" 

2216 ) 

2217 

2218 if actual_size != ( 

2219 expected_size := ( 

2220 ref_size * ref_axis.scale / a.scale + a.size.offset 

2221 ) 

2222 ): 

2223 raise ValueError( 

2224 f"{e_msg(descr)}.{tensor_origin}: axis '{a.id}' of size" 

2225 + f" {actual_size} invalid for referenced size {ref_size};" 

2226 + f" expected {expected_size}" 

2227 ) 

2228 else: 

2229 assert_never(a.size) 

2230 
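# Worked example for the `SizeReference` consistency check above (a sketch):
# with ref_size=128, ref_axis.scale=1.0, a.scale=2.0 and a.size.offset=0, the
# expected axis size is 128 * 1.0 / 2.0 + 0 == 64; any other actual size
# raises a ValueError.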

2231 

2232FileDescr_dependencies = Annotated[ 

2233 FileDescr_, 

2234 WithSuffix((".yaml", ".yml"), case_sensitive=True), 

2235 Field(examples=[dict(source="environment.yaml")]), 

2236] 

2237 

2238 

2239class _ArchitectureCallableDescr(Node): 

2240 callable: Annotated[Identifier, Field(examples=["MyNetworkClass", "get_my_model"])] 

2241 """Identifier of the callable that returns a torch.nn.Module instance.""" 

2242 

2243 kwargs: Dict[str, YamlValue] = Field( 

2244 default_factory=cast(Callable[[], Dict[str, YamlValue]], dict) 

2245 ) 

2246 """key word arguments for the `callable`""" 

2247 

2248 

2249class ArchitectureFromFileDescr(_ArchitectureCallableDescr, FileDescr): 

2250 source: Annotated[FileSource, AfterValidator(wo_special_file_name)] 

2251 """Architecture source file""" 

2252 

2253 @model_serializer(mode="wrap", when_used="unless-none") 

2254 def _serialize(self, nxt: SerializerFunctionWrapHandler, info: SerializationInfo): 

2255 return package_file_descr_serializer(self, nxt, info) 

2256 

2257 

2258class ArchitectureFromLibraryDescr(_ArchitectureCallableDescr): 

2259 import_from: str 

2260 """Where to import the callable from, i.e. `from <import_from> import <callable>`""" 

2261 
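# For example (a sketch; package and class names are hypothetical): a callable
# importable as `from mypackage.models import UNet2d` would be described as
#
#   ArchitectureFromLibraryDescr(
#       import_from="mypackage.models",
#       callable=Identifier("UNet2d"),
#       kwargs={"depth": 4},
#   )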

2262 

2263class _ArchFileConv( 

2264 Converter[ 

2265 _CallableFromFile_v0_4, 

2266 ArchitectureFromFileDescr, 

2267 Optional[Sha256], 

2268 Dict[str, Any], 

2269 ] 

2270): 

2271 def _convert( 

2272 self, 

2273 src: _CallableFromFile_v0_4, 

2274 tgt: "type[ArchitectureFromFileDescr | dict[str, Any]]", 

2275 sha256: Optional[Sha256], 

2276 kwargs: Dict[str, Any], 

2277 ) -> "ArchitectureFromFileDescr | dict[str, Any]": 

2278 if src.startswith("http") and src.count(":") == 2: 

2279 http, source, callable_ = src.split(":") 

2280 source = ":".join((http, source)) 

2281 elif not src.startswith("http") and src.count(":") == 1: 

2282 source, callable_ = src.split(":") 

2283 else: 

2284 source = str(src) 

2285 callable_ = str(src) 

2286 return tgt( 

2287 callable=Identifier(callable_), 

2288 source=cast(FileSource_, source), 

2289 sha256=sha256, 

2290 kwargs=kwargs, 

2291 ) 

2292 

2293 

2294_arch_file_conv = _ArchFileConv(_CallableFromFile_v0_4, ArchitectureFromFileDescr) 

2295 
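# Sketch of the "<source>:<callable>" parsing above (hypothetical values):
# "models.py:MyNetworkClass" splits into source "models.py" and callable
# "MyNetworkClass", while the scheme colon of a URL is preserved, e.g.
# "https://example.com/models.py:MyNetworkClass"
# -> source "https://example.com/models.py", callable "MyNetworkClass".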

2296 

2297class _ArchLibConv( 

2298 Converter[ 

2299 _CallableFromDepencency_v0_4, ArchitectureFromLibraryDescr, Dict[str, Any] 

2300 ] 

2301): 

2302 def _convert( 

2303 self, 

2304 src: _CallableFromDepencency_v0_4, 

2305 tgt: "type[ArchitectureFromLibraryDescr | dict[str, Any]]", 

2306 kwargs: Dict[str, Any], 

2307 ) -> "ArchitectureFromLibraryDescr | dict[str, Any]": 

2308 *mods, callable_ = src.split(".") 

2309 import_from = ".".join(mods) 

2310 return tgt( 

2311 import_from=import_from, callable=Identifier(callable_), kwargs=kwargs 

2312 ) 

2313 

2314 

2315_arch_lib_conv = _ArchLibConv( 

2316 _CallableFromDepencency_v0_4, ArchitectureFromLibraryDescr 

2317) 

2318 

2319 

2320class WeightsEntryDescrBase(FileDescr): 

2321 type: ClassVar[WeightsFormat] 

2322 weights_format_name: ClassVar[str] # human readable 

2323 

2324 source: Annotated[FileSource, AfterValidator(wo_special_file_name)] 

2325 """Source of the weights file.""" 

2326 

2327 authors: Optional[List[Author]] = None 

2328 """Authors 

2329 Either the person(s) that trained this model, resulting in the original weights file

2330 (if this is the initial weights entry, i.e. it does not have a `parent`),

2331 or the person(s) who converted the weights to this weights format

2332 (if this is a child weights entry, i.e. it has a `parent` field).

2333 """ 

2334 

2335 parent: Annotated[ 

2336 Optional[WeightsFormat], Field(examples=["pytorch_state_dict"]) 

2337 ] = None 

2338 """The source weights these weights were converted from. 

2339 For example, if a model's weights were converted from the `pytorch_state_dict` format to `torchscript`, 

2340 the `pytorch_state_dict` weights entry has no `parent` and is the parent of the `torchscript` weights.

2341 All weight entries except one (the initial set of weights resulting from training the model)

2342 need to have this field."""

2343 

2344 comment: str = "" 

2345 """A comment about this weights entry, for example how these weights were created.""" 

2346 

2347 @model_validator(mode="after") 

2348 def _validate(self) -> Self: 

2349 if self.type == self.parent: 

2350 raise ValueError("Weights entry can't be its own parent.")

2351 

2352 return self 

2353 

2354 @model_serializer(mode="wrap", when_used="unless-none") 

2355 def _serialize(self, nxt: SerializerFunctionWrapHandler, info: SerializationInfo): 

2356 return package_file_descr_serializer(self, nxt, info) 

2357 

2358 

2359class KerasHdf5WeightsDescr(WeightsEntryDescrBase): 

2360 type: ClassVar[WeightsFormat] = "keras_hdf5" 

2361 weights_format_name: ClassVar[str] = "Keras HDF5" 

2362 tensorflow_version: Version 

2363 """TensorFlow version used to create these weights.""" 

2364 

2365 

2366FileDescr_external_data = Annotated[ 

2367 FileDescr_, 

2368 WithSuffix(".data", case_sensitive=True), 

2369 Field(examples=[dict(source="weights.onnx.data")]), 

2370] 

2371 

2372 

2373class OnnxWeightsDescr(WeightsEntryDescrBase): 

2374 type: ClassVar[WeightsFormat] = "onnx" 

2375 weights_format_name: ClassVar[str] = "ONNX" 

2376 opset_version: Annotated[int, Ge(7)] 

2377 """ONNX opset version""" 

2378 

2379 external_data: Optional[FileDescr_external_data] = None 

2380 """Source of the external ONNX data file holding the weights. 

2381 (If present, **source** holds the ONNX architecture without weights.)"""

2382 

2383 @model_validator(mode="after") 

2384 def _validate_external_data_unique_file_name(self) -> Self: 

2385 if self.external_data is not None and ( 

2386 extract_file_name(self.source) 

2387 == extract_file_name(self.external_data.source) 

2388 ): 

2389 raise ValueError( 

2390 f"ONNX `external_data` file name '{extract_file_name(self.external_data.source)}'" 

2391 + " must be different from ONNX `source` file name." 

2392 ) 

2393 

2394 return self 

2395 
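# A minimal sketch (hypothetical file names; io checks assumed to be
# disabled):
#
#   OnnxWeightsDescr(
#       source="weights.onnx",  # architecture without the weight values
#       external_data=dict(source="weights.onnx.data"),
#       opset_version=17,
#   )
#
# Naming `external_data` identically to `source` is rejected by the validator
# above.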

2396 

2397class PytorchStateDictWeightsDescr(WeightsEntryDescrBase): 

2398 type: ClassVar[WeightsFormat] = "pytorch_state_dict" 

2399 weights_format_name: ClassVar[str] = "Pytorch State Dict" 

2400 architecture: Union[ArchitectureFromFileDescr, ArchitectureFromLibraryDescr] 

2401 pytorch_version: Version 

2402 """Version of the PyTorch library used. 

2403 If `dependencies` is specified, it has to include pytorch and any version pinning has to be compatible.

2404 """ 

2405 dependencies: Optional[FileDescr_dependencies] = None 

2406 """Custom depencies beyond pytorch described in a Conda environment file. 

2407 Allows to specify custom dependencies, see conda docs: 

2408 - [Exporting an environment file across platforms](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#exporting-an-environment-file-across-platforms) 

2409 - [Creating an environment file manually](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#creating-an-environment-file-manually) 

2410 

2411 The conda environment file should include pytorch, and any version pinning has to be compatible with

2412 **pytorch_version**. 

2413 """ 

2414 

2415 

2416class TensorflowJsWeightsDescr(WeightsEntryDescrBase): 

2417 type: ClassVar[WeightsFormat] = "tensorflow_js" 

2418 weights_format_name: ClassVar[str] = "Tensorflow.js" 

2419 tensorflow_version: Version 

2420 """Version of the TensorFlow library used.""" 

2421 

2422 source: Annotated[FileSource, AfterValidator(wo_special_file_name)] 

2423 """The multi-file weights. 

2424 All required files/folders should be packaged in a zip archive."""

2425 

2426 

2427class TensorflowSavedModelBundleWeightsDescr(WeightsEntryDescrBase): 

2428 type: ClassVar[WeightsFormat] = "tensorflow_saved_model_bundle" 

2429 weights_format_name: ClassVar[str] = "Tensorflow Saved Model" 

2430 tensorflow_version: Version 

2431 """Version of the TensorFlow library used.""" 

2432 

2433 dependencies: Optional[FileDescr_dependencies] = None 

2434 """Custom dependencies beyond tensorflow. 

2435 Should include tensorflow, and any version pinning has to be compatible with **tensorflow_version**."""

2436 

2437 source: Annotated[FileSource, AfterValidator(wo_special_file_name)] 

2438 """The multi-file weights. 

2439 All required files/folders should be packaged in a zip archive."""

2440 

2441 

2442class TorchscriptWeightsDescr(WeightsEntryDescrBase): 

2443 type: ClassVar[WeightsFormat] = "torchscript" 

2444 weights_format_name: ClassVar[str] = "TorchScript" 

2445 pytorch_version: Version 

2446 """Version of the PyTorch library used.""" 

2447 

2448 

2449SpecificWeightsDescr = Union[ 

2450 KerasHdf5WeightsDescr, 

2451 OnnxWeightsDescr, 

2452 PytorchStateDictWeightsDescr, 

2453 TensorflowJsWeightsDescr, 

2454 TensorflowSavedModelBundleWeightsDescr, 

2455 TorchscriptWeightsDescr, 

2456] 

2457 

2458 

2459class WeightsDescr(Node): 

2460 keras_hdf5: Optional[KerasHdf5WeightsDescr] = None 

2461 onnx: Optional[OnnxWeightsDescr] = None 

2462 pytorch_state_dict: Optional[PytorchStateDictWeightsDescr] = None 

2463 tensorflow_js: Optional[TensorflowJsWeightsDescr] = None 

2464 tensorflow_saved_model_bundle: Optional[TensorflowSavedModelBundleWeightsDescr] = ( 

2465 None 

2466 ) 

2467 torchscript: Optional[TorchscriptWeightsDescr] = None 

2468 

2469 @model_validator(mode="after") 

2470 def check_entries(self) -> Self: 

2471 entries = {wtype for wtype, entry in self if entry is not None} 

2472 

2473 if not entries: 

2474 raise ValueError("Missing weights entry") 

2475 

2476 entries_wo_parent = { 

2477 wtype 

2478 for wtype, entry in self 

2479 if entry is not None and hasattr(entry, "parent") and entry.parent is None 

2480 } 

2481 if len(entries_wo_parent) != 1: 

2482 issue_warning( 

2483 "Exactly one weights entry may not specify the `parent` field (got" 

2484 + " {value}). That entry is considered the original set of model weights." 

2485 + " Other weight formats are created through conversion of the orignal or" 

2486 + " already converted weights. They have to reference the weights format" 

2487 + " they were converted from as their `parent`.", 

2488 value=len(entries_wo_parent), 

2489 field="weights", 

2490 ) 

2491 

2492 for wtype, entry in self: 

2493 if entry is None: 

2494 continue 

2495 

2496 assert hasattr(entry, "type") 

2497 assert hasattr(entry, "parent") 

2498 assert wtype == entry.type 

2499 if ( 

2500 entry.parent is not None and entry.parent not in entries 

2501 ): # self reference checked for `parent` field 

2502 raise ValueError( 

2503 f"`weights.{wtype}.parent={entry.parent} not in specified weight" 

2504 + f" formats: {entries}" 

2505 ) 

2506 

2507 return self 

2508 

2509 def __getitem__( 

2510 self, 

2511 key: Literal[ 

2512 "keras_hdf5", 

2513 "onnx", 

2514 "pytorch_state_dict", 

2515 "tensorflow_js", 

2516 "tensorflow_saved_model_bundle", 

2517 "torchscript", 

2518 ], 

2519 ): 

2520 if key == "keras_hdf5": 

2521 ret = self.keras_hdf5 

2522 elif key == "onnx": 

2523 ret = self.onnx 

2524 elif key == "pytorch_state_dict": 

2525 ret = self.pytorch_state_dict 

2526 elif key == "tensorflow_js": 

2527 ret = self.tensorflow_js 

2528 elif key == "tensorflow_saved_model_bundle": 

2529 ret = self.tensorflow_saved_model_bundle 

2530 elif key == "torchscript": 

2531 ret = self.torchscript 

2532 else: 

2533 raise KeyError(key) 

2534 

2535 if ret is None: 

2536 raise KeyError(key) 

2537 

2538 return ret 

2539 

2540 @overload 

2541 def __setitem__( 

2542 self, key: Literal["keras_hdf5"], value: Optional[KerasHdf5WeightsDescr] 

2543 ) -> None: ... 

2544 @overload 

2545 def __setitem__( 

2546 self, key: Literal["onnx"], value: Optional[OnnxWeightsDescr] 

2547 ) -> None: ... 

2548 @overload 

2549 def __setitem__( 

2550 self, 

2551 key: Literal["pytorch_state_dict"], 

2552 value: Optional[PytorchStateDictWeightsDescr], 

2553 ) -> None: ... 

2554 @overload 

2555 def __setitem__( 

2556 self, key: Literal["tensorflow_js"], value: Optional[TensorflowJsWeightsDescr] 

2557 ) -> None: ... 

2558 @overload 

2559 def __setitem__( 

2560 self, 

2561 key: Literal["tensorflow_saved_model_bundle"], 

2562 value: Optional[TensorflowSavedModelBundleWeightsDescr], 

2563 ) -> None: ... 

2564 @overload 

2565 def __setitem__( 

2566 self, key: Literal["torchscript"], value: Optional[TorchscriptWeightsDescr] 

2567 ) -> None: ... 

2568 

2569 def __setitem__( 

2570 self, 

2571 key: Literal[ 

2572 "keras_hdf5", 

2573 "onnx", 

2574 "pytorch_state_dict", 

2575 "tensorflow_js", 

2576 "tensorflow_saved_model_bundle", 

2577 "torchscript", 

2578 ], 

2579 value: Optional[SpecificWeightsDescr], 

2580 ): 

2581 if key == "keras_hdf5": 

2582 if value is not None and not isinstance(value, KerasHdf5WeightsDescr): 

2583 raise TypeError( 

2584 f"Expected KerasHdf5WeightsDescr or None for key 'keras_hdf5', got {type(value)}" 

2585 ) 

2586 self.keras_hdf5 = value 

2587 elif key == "onnx": 

2588 if value is not None and not isinstance(value, OnnxWeightsDescr): 

2589 raise TypeError( 

2590 f"Expected OnnxWeightsDescr or None for key 'onnx', got {type(value)}" 

2591 ) 

2592 self.onnx = value 

2593 elif key == "pytorch_state_dict": 

2594 if value is not None and not isinstance( 

2595 value, PytorchStateDictWeightsDescr 

2596 ): 

2597 raise TypeError( 

2598 f"Expected PytorchStateDictWeightsDescr or None for key 'pytorch_state_dict', got {type(value)}" 

2599 ) 

2600 self.pytorch_state_dict = value 

2601 elif key == "tensorflow_js": 

2602 if value is not None and not isinstance(value, TensorflowJsWeightsDescr): 

2603 raise TypeError( 

2604 f"Expected TensorflowJsWeightsDescr or None for key 'tensorflow_js', got {type(value)}" 

2605 ) 

2606 self.tensorflow_js = value 

2607 elif key == "tensorflow_saved_model_bundle": 

2608 if value is not None and not isinstance( 

2609 value, TensorflowSavedModelBundleWeightsDescr 

2610 ): 

2611 raise TypeError( 

2612 f"Expected TensorflowSavedModelBundleWeightsDescr or None for key 'tensorflow_saved_model_bundle', got {type(value)}" 

2613 ) 

2614 self.tensorflow_saved_model_bundle = value 

2615 elif key == "torchscript": 

2616 if value is not None and not isinstance(value, TorchscriptWeightsDescr): 

2617 raise TypeError( 

2618 f"Expected TorchscriptWeightsDescr or None for key 'torchscript', got {type(value)}" 

2619 ) 

2620 self.torchscript = value 

2621 else: 

2622 raise KeyError(key) 

2623 

2624 @property 

2625 def available_formats(self) -> Dict[WeightsFormat, SpecificWeightsDescr]: 

2626 return { 

2627 **({} if self.keras_hdf5 is None else {"keras_hdf5": self.keras_hdf5}), 

2628 **({} if self.onnx is None else {"onnx": self.onnx}), 

2629 **( 

2630 {} 

2631 if self.pytorch_state_dict is None 

2632 else {"pytorch_state_dict": self.pytorch_state_dict} 

2633 ), 

2634 **( 

2635 {} 

2636 if self.tensorflow_js is None 

2637 else {"tensorflow_js": self.tensorflow_js} 

2638 ), 

2639 **( 

2640 {} 

2641 if self.tensorflow_saved_model_bundle is None 

2642 else { 

2643 "tensorflow_saved_model_bundle": self.tensorflow_saved_model_bundle 

2644 } 

2645 ), 

2646 **({} if self.torchscript is None else {"torchscript": self.torchscript}), 

2647 } 

2648 

2649 @property 

2650 def missing_formats(self) -> Set[WeightsFormat]: 

2651 return { 

2652 wf for wf in get_args(WeightsFormat) if wf not in self.available_formats 

2653 } 

2654 
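# A minimal usage sketch (assuming `ts_entry` is a valid
# `TorchscriptWeightsDescr` without a `parent`):
#
#   weights = WeightsDescr(torchscript=ts_entry)
#   weights["torchscript"]             # -> ts_entry
#   weights.available_formats          # -> {"torchscript": ts_entry}
#   "onnx" in weights.missing_formats  # -> True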

2655 

2656class ModelId(ResourceId): 

2657 pass 

2658 

2659 

2660class LinkedModel(LinkedResourceBase): 

2661 """Reference to a bioimage.io model.""" 

2662 

2663 id: ModelId 

2664 """A valid model `id` from the bioimage.io collection.""" 

2665 

2666 

2667class _DataDepSize(NamedTuple): 

2668 min: StrictInt 

2669 max: Optional[StrictInt] 

2670 

2671 

2672class _AxisSizes(NamedTuple): 

2673 """the lenghts of all axes of model inputs and outputs""" 

2674 

2675 inputs: Dict[Tuple[TensorId, AxisId], int] 

2676 outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]] 

2677 

2678 

2679class _TensorSizes(NamedTuple): 

2680 """_AxisSizes as nested dicts""" 

2681 

2682 inputs: Dict[TensorId, Dict[AxisId, int]] 

2683 outputs: Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]] 

2684 

2685 

2686class ReproducibilityTolerance(Node, extra="allow"): 

2687 """Describes what small numerical differences -- if any -- may be tolerated 

2688 in the generated output when executing in different environments. 

2689 

2690 A tensor element *output* is considered mismatched to the **test_tensor** if 

2691 abs(*output* - **test_tensor**) > **absolute_tolerance** + **relative_tolerance** * abs(**test_tensor**). 

2692 (Internally we call [numpy.testing.assert_allclose](https://numpy.org/doc/stable/reference/generated/numpy.testing.assert_allclose.html).) 

2693 

2694 Motivation: 

2695 For testing we can request the respective deep learning frameworks to be as 

2696 reproducible as possible by setting seeds and choosing deterministic algorithms,

2697 but differences in operating systems, available hardware and installed drivers 

2698 may still lead to numerical differences. 

2699 """ 

2700 

2701 relative_tolerance: RelativeTolerance = 1e-3 

2702 """Maximum relative tolerance of reproduced test tensor.""" 

2703 

2704 absolute_tolerance: AbsoluteTolerance = 1e-4 

2705 """Maximum absolute tolerance of reproduced test tensor.""" 

2706 

2707 mismatched_elements_per_million: MismatchedElementsPerMillion = 100 

2708 """Maximum number of mismatched elements/pixels per million to tolerate.""" 

2709 

2710 output_ids: Sequence[TensorId] = () 

2711 """Limits the output tensor IDs these reproducibility details apply to.""" 

2712 

2713 weights_formats: Sequence[WeightsFormat] = () 

2714 """Limits the weights formats these details apply to.""" 

2715 
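# Worked example of the mismatch criterion above: with the defaults
# relative_tolerance=1e-3 and absolute_tolerance=1e-4, an output element
# reproducing a test value of 1.0 counts as mismatched if it deviates by more
# than 1e-4 + 1e-3 * 1.0 == 1.1e-3.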

2716 

2717class BioimageioConfig(Node, extra="allow"): 

2718 reproducibility_tolerance: Sequence[ReproducibilityTolerance] = () 

2719 """Tolerances to allow when reproducing the model's test outputs 

2720 from the model's test inputs. 

2721 Only the first entry matching tensor id and weights format is considered. 

2722 """ 

2723 

2724 

2725class Config(Node, extra="allow"): 

2726 bioimageio: BioimageioConfig = Field( 

2727 default_factory=BioimageioConfig.model_construct 

2728 ) 

2729 

2730 

2731class ModelDescr(GenericModelDescrBase): 

2732 """Specification of the fields used in a bioimage.io-compliant RDF to describe AI models with pretrained weights. 

2733 These fields are typically stored in a YAML file which we call a model resource description file (model RDF). 

2734 """ 

2735 

2736 implemented_format_version: ClassVar[Literal["0.5.6"]] = "0.5.6" 

2737 if TYPE_CHECKING: 

2738 format_version: Literal["0.5.6"] = "0.5.6" 

2739 else: 

2740 format_version: Literal["0.5.6"] 

2741 """Version of the bioimage.io model description specification used. 

2742 When creating a new model always use the latest micro/patch version described here. 

2743 The `format_version` is important for any consumer software to understand how to parse the fields. 

2744 """ 

2745 

2746 implemented_type: ClassVar[Literal["model"]] = "model" 

2747 if TYPE_CHECKING: 

2748 type: Literal["model"] = "model" 

2749 else: 

2750 type: Literal["model"] 

2751 """Specialized resource type 'model'""" 

2752 

2753 id: Optional[ModelId] = None 

2754 """bioimage.io-wide unique resource identifier 

2755 assigned by bioimage.io; version **un**specific.""" 

2756 

2757 authors: FAIR[List[Author]] = Field( 

2758 default_factory=cast(Callable[[], List[Author]], list) 

2759 ) 

2760 """The authors are the creators of the model RDF and the primary points of contact.""" 

2761 

2762 documentation: FAIR[Optional[FileSource_documentation]] = None 

2763 """URL or relative path to a markdown file with additional documentation. 

2764 The recommended documentation file name is `README.md`. An `.md` suffix is mandatory. 

2765 The documentation should include a '#[#] Validation' (sub)section 

2766 with details on how to quantitatively validate the model on unseen data.""" 

2767 

2768 @field_validator("documentation", mode="after") 

2769 @classmethod 

2770 def _validate_documentation( 

2771 cls, value: Optional[FileSource_documentation] 

2772 ) -> Optional[FileSource_documentation]: 

2773 if not get_validation_context().perform_io_checks or value is None: 

2774 return value 

2775 

2776 doc_reader = get_reader(value) 

2777 doc_content = doc_reader.read().decode(encoding="utf-8") 

2778 if not re.search("#.*[vV]alidation", doc_content): 

2779 issue_warning( 

2780 "No '# Validation' (sub)section found in {value}.", 

2781 value=value, 

2782 field="documentation", 

2783 ) 

2784 

2785 return value 

2786 

2787 inputs: NotEmpty[Sequence[InputTensorDescr]] 

2788 """Describes the input tensors expected by this model.""" 

2789 

2790 @field_validator("inputs", mode="after") 

2791 @classmethod 

2792 def _validate_input_axes( 

2793 cls, inputs: Sequence[InputTensorDescr] 

2794 ) -> Sequence[InputTensorDescr]: 

2795 input_size_refs = cls._get_axes_with_independent_size(inputs) 

2796 

2797 for i, ipt in enumerate(inputs): 

2798 valid_independent_refs: Dict[ 

2799 Tuple[TensorId, AxisId], 

2800 Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]], 

2801 ] = { 

2802 **{ 

2803 (ipt.id, a.id): (ipt, a, a.size) 

2804 for a in ipt.axes 

2805 if not isinstance(a, BatchAxis) 

2806 and isinstance(a.size, (int, ParameterizedSize)) 

2807 }, 

2808 **input_size_refs, 

2809 } 

2810 for a, ax in enumerate(ipt.axes): 

2811 cls._validate_axis( 

2812 "inputs", 

2813 i=i, 

2814 tensor_id=ipt.id, 

2815 a=a, 

2816 axis=ax, 

2817 valid_independent_refs=valid_independent_refs, 

2818 ) 

2819 return inputs 

2820 

2821 @staticmethod 

2822 def _validate_axis( 

2823 field_name: str, 

2824 i: int, 

2825 tensor_id: TensorId, 

2826 a: int, 

2827 axis: AnyAxis, 

2828 valid_independent_refs: Dict[ 

2829 Tuple[TensorId, AxisId], 

2830 Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]], 

2831 ], 

2832 ): 

2833 if isinstance(axis, BatchAxis) or isinstance( 

2834 axis.size, (int, ParameterizedSize, DataDependentSize) 

2835 ): 

2836 return 

2837 elif not isinstance(axis.size, SizeReference): 

2838 assert_never(axis.size) 

2839 

2840 # validate axis.size SizeReference 

2841 ref = (axis.size.tensor_id, axis.size.axis_id) 

2842 if ref not in valid_independent_refs: 

2843 raise ValueError( 

2844 "Invalid tensor axis reference at" 

2845 + f" {field_name}[{i}].axes[{a}].size: {axis.size}." 

2846 ) 

2847 if ref == (tensor_id, axis.id): 

2848 raise ValueError( 

2849 "Self-referencing not allowed for" 

2850 + f" {field_name}[{i}].axes[{a}].size: {axis.size}" 

2851 ) 

2852 if axis.type == "channel": 

2853 if valid_independent_refs[ref][1].type != "channel": 

2854 raise ValueError( 

2855 "A channel axis' size may only reference another fixed size" 

2856 + " channel axis." 

2857 ) 

2858 if isinstance(axis.channel_names, str) and "{i}" in axis.channel_names: 

2859 ref_size = valid_independent_refs[ref][2] 

2860 assert isinstance(ref_size, int), ( 

2861 "channel axis ref (another channel axis) has to specify fixed" 

2862 + " size" 

2863 ) 

2864 generated_channel_names = [ 

2865 Identifier(axis.channel_names.format(i=i)) 

2866 for i in range(1, ref_size + 1) 

2867 ] 

2868 axis.channel_names = generated_channel_names 

2869 

2870 if (ax_unit := getattr(axis, "unit", None)) != ( 

2871 ref_unit := getattr(valid_independent_refs[ref][1], "unit", None) 

2872 ): 

2873 raise ValueError( 

2874 "The units of an axis and its reference axis need to match, but" 

2875 + f" '{ax_unit}' != '{ref_unit}'." 

2876 ) 

2877 ref_axis = valid_independent_refs[ref][1] 

2878 if isinstance(ref_axis, BatchAxis): 

2879 raise ValueError( 

2880 f"Invalid reference axis '{ref_axis.id}' for {tensor_id}.{axis.id}" 

2881 + " (a batch axis is not allowed as reference)." 

2882 ) 

2883 

2884 if isinstance(axis, WithHalo): 

2885 min_size = axis.size.get_size(axis, ref_axis, n=0) 

2886 if (min_size - 2 * axis.halo) < 1: 

2887 raise ValueError( 

2888 f"axis {axis.id} with minimum size {min_size} is too small for halo" 

2889 + f" {axis.halo}." 

2890 ) 

2891 

2892 ref_halo = axis.halo * axis.scale / ref_axis.scale 

2893 if ref_halo != int(ref_halo): 

2894 raise ValueError( 

2895 f"Inferred halo for {'.'.join(ref)} is not an integer ({ref_halo} =" 

2896 + f" {tensor_id}.{axis.id}.halo {axis.halo}" 

2897 + f" * {tensor_id}.{axis.id}.scale {axis.scale}" 

2898 + f" / {'.'.join(ref)}.scale {ref_axis.scale})." 

2899 ) 

2900 
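# Worked example of the halo consistency check above (a sketch): an axis
# with halo=8 and scale=2.0 referencing an axis with scale=1.0 implies a
# reference halo of 8 * 2.0 / 1.0 == 16.0, which is an integer and hence
# valid.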

2901 @model_validator(mode="after") 

2902 def _validate_test_tensors(self) -> Self: 

2903 if not get_validation_context().perform_io_checks: 

2904 return self 

2905 

2906 test_output_arrays = [ 

2907 None if descr.test_tensor is None else load_array(descr.test_tensor) 

2908 for descr in self.outputs 

2909 ] 

2910 test_input_arrays = [ 

2911 None if descr.test_tensor is None else load_array(descr.test_tensor) 

2912 for descr in self.inputs 

2913 ] 

2914 

2915 tensors = { 

2916 descr.id: (descr, array) 

2917 for descr, array in zip( 

2918 chain(self.inputs, self.outputs), test_input_arrays + test_output_arrays 

2919 ) 

2920 } 

2921 validate_tensors(tensors, tensor_origin="test_tensor") 

2922 

2923 output_arrays = { 

2924 descr.id: array for descr, array in zip(self.outputs, test_output_arrays) 

2925 } 

2926 for rep_tol in self.config.bioimageio.reproducibility_tolerance: 

2927 if not rep_tol.absolute_tolerance: 

2928 continue 

2929 

2930 if rep_tol.output_ids: 

2931 out_arrays = { 

2932 oid: a 

2933 for oid, a in output_arrays.items() 

2934 if oid in rep_tol.output_ids 

2935 } 

2936 else: 

2937 out_arrays = output_arrays 

2938 

2939 for out_id, array in out_arrays.items(): 

2940 if array is None: 

2941 continue 

2942 

2943 if rep_tol.absolute_tolerance > (max_test_value := array.max()) * 0.01: 

2944 raise ValueError( 

2945 "config.bioimageio.reproducibility_tolerance.absolute_tolerance=" 

2946 + f"{rep_tol.absolute_tolerance} > 0.01*{max_test_value}" 

2947 + f" (1% of the maximum value of the test tensor '{out_id}')" 

2948 ) 

2949 

2950 return self 

2951 

2952 @model_validator(mode="after") 

2953 def _validate_tensor_references_in_proc_kwargs(self, info: ValidationInfo) -> Self: 

2954 ipt_refs = {t.id for t in self.inputs} 

2955 out_refs = {t.id for t in self.outputs} 

2956 for ipt in self.inputs: 

2957 for p in ipt.preprocessing: 

2958 ref = p.kwargs.get("reference_tensor") 

2959 if ref is None: 

2960 continue 

2961 if ref not in ipt_refs: 

2962 raise ValueError( 

2963 f"`reference_tensor` '{ref}' not found. Valid input tensor" 

2964 + f" references are: {ipt_refs}." 

2965 ) 

2966 

2967 for out in self.outputs: 

2968 for p in out.postprocessing: 

2969 ref = p.kwargs.get("reference_tensor") 

2970 if ref is None: 

2971 continue 

2972 

2973 if ref not in ipt_refs and ref not in out_refs: 

2974 raise ValueError( 

2975 f"`reference_tensor` '{ref}' not found. Valid tensor references" 

2976 + f" are: {ipt_refs | out_refs}." 

2977 ) 

2978 

2979 return self 

2980 

2981 # TODO: use validate funcs in validate_test_tensors 

2982 # def validate_inputs(self, input_tensors: Mapping[TensorId, NDArray[Any]]) -> Mapping[TensorId, NDArray[Any]]: 

2983 

2984 name: Annotated[ 

2985 str, 

2986 RestrictCharacters(string.ascii_letters + string.digits + "_+- ()"), 

2987 MinLen(5), 

2988 MaxLen(128), 

2989 warn(MaxLen(64), "Name longer than 64 characters.", INFO), 

2990 ] 

2991 """A human-readable name of this model. 

2992 It should be no longer than 64 characters 

2993 and may only contain letters, numbers, underscores, minus signs, parentheses and spaces.

2994 We recommend choosing a name that refers to the model's task and image modality.

2995 """ 

2996 

2997 outputs: NotEmpty[Sequence[OutputTensorDescr]] 

2998 """Describes the output tensors.""" 

2999 

3000 @field_validator("outputs", mode="after") 

3001 @classmethod 

3002 def _validate_tensor_ids( 

3003 cls, outputs: Sequence[OutputTensorDescr], info: ValidationInfo 

3004 ) -> Sequence[OutputTensorDescr]: 

3005 tensor_ids = [ 

3006 t.id for t in info.data.get("inputs", []) + info.data.get("outputs", []) 

3007 ] 

3008 duplicate_tensor_ids: List[str] = [] 

3009 seen: Set[str] = set() 

3010 for t in tensor_ids: 

3011 if t in seen: 

3012 duplicate_tensor_ids.append(t) 

3013 

3014 seen.add(t) 

3015 

3016 if duplicate_tensor_ids: 

3017 raise ValueError(f"Duplicate tensor ids: {duplicate_tensor_ids}") 

3018 

3019 return outputs 

3020 

3021 @staticmethod 

3022 def _get_axes_with_parameterized_size( 

3023 io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]], 

3024 ): 

3025 return { 

3026 f"{t.id}.{a.id}": (t, a, a.size) 

3027 for t in io 

3028 for a in t.axes 

3029 if not isinstance(a, BatchAxis) and isinstance(a.size, ParameterizedSize) 

3030 } 

3031 

3032 @staticmethod 

3033 def _get_axes_with_independent_size( 

3034 io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]], 

3035 ): 

3036 return { 

3037 (t.id, a.id): (t, a, a.size) 

3038 for t in io 

3039 for a in t.axes 

3040 if not isinstance(a, BatchAxis) 

3041 and isinstance(a.size, (int, ParameterizedSize)) 

3042 } 

3043 

3044 @field_validator("outputs", mode="after") 

3045 @classmethod 

3046 def _validate_output_axes( 

3047 cls, outputs: List[OutputTensorDescr], info: ValidationInfo 

3048 ) -> List[OutputTensorDescr]: 

3049 input_size_refs = cls._get_axes_with_independent_size( 

3050 info.data.get("inputs", []) 

3051 ) 

3052 output_size_refs = cls._get_axes_with_independent_size(outputs) 

3053 

3054 for i, out in enumerate(outputs): 

3055 valid_independent_refs: Dict[ 

3056 Tuple[TensorId, AxisId], 

3057 Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]], 

3058 ] = { 

3059 **{ 

3060 (out.id, a.id): (out, a, a.size) 

3061 for a in out.axes 

3062 if not isinstance(a, BatchAxis) 

3063 and isinstance(a.size, (int, ParameterizedSize)) 

3064 }, 

3065 **input_size_refs, 

3066 **output_size_refs, 

3067 } 

3068 for a, ax in enumerate(out.axes): 

3069 cls._validate_axis( 

3070 "outputs", 

3071 i, 

3072 out.id, 

3073 a, 

3074 ax, 

3075 valid_independent_refs=valid_independent_refs, 

3076 ) 

3077 

3078 return outputs 

3079 

3080 packaged_by: List[Author] = Field( 

3081 default_factory=cast(Callable[[], List[Author]], list) 

3082 ) 

3083 """The persons that have packaged and uploaded this model. 

3084 Only required if those persons differ from the `authors`.""" 

3085 

3086 parent: Optional[LinkedModel] = None 

3087 """The model from which this model is derived, e.g. by fine-tuning the weights.""" 

3088 

3089 @model_validator(mode="after") 

3090 def _validate_parent_is_not_self(self) -> Self: 

3091 if self.parent is not None and self.parent.id == self.id: 

3092 raise ValueError("A model description may not reference itself as parent.") 

3093 

3094 return self 

3095 

3096 run_mode: Annotated[ 

3097 Optional[RunMode], 

3098 warn(None, "Run mode '{value}' has limited support across consumer software."),

3099 ] = None 

3100 """Custom run mode for this model: for more complex prediction procedures like test time 

3101 data augmentation that currently cannot be expressed in the specification. 

3102 No standard run modes are defined yet.""" 

3103 

3104 timestamp: Datetime = Field(default_factory=Datetime.now) 

3105 """Timestamp in [ISO 8601](#https://en.wikipedia.org/wiki/ISO_8601) format 

3106 with a few restrictions listed [here](https://docs.python.org/3/library/datetime.html#datetime.datetime.fromisoformat). 

3107 (In Python a datetime object is valid, too).""" 

3108 

3109 training_data: Annotated[ 

3110 Union[None, LinkedDataset, DatasetDescr, DatasetDescr02], 

3111 Field(union_mode="left_to_right"), 

3112 ] = None 

3113 """The dataset used to train this model""" 

3114 

3115 weights: Annotated[WeightsDescr, WrapSerializer(package_weights)] 

3116 """The weights for this model. 

3117 Weights can be given for different formats, but should otherwise be equivalent. 

3118 The available weight formats determine which consumers can use this model.""" 

3119 

3120 config: Config = Field(default_factory=Config.model_construct) 

3121 

3122 @model_validator(mode="after") 

3123 def _add_default_cover(self) -> Self: 

3124 if not get_validation_context().perform_io_checks or self.covers: 

3125 return self 

3126 

3127 try: 

3128 generated_covers = generate_covers( 

3129 [ 

3130 (t, load_array(t.test_tensor)) 

3131 for t in self.inputs 

3132 if t.test_tensor is not None 

3133 ], 

3134 [ 

3135 (t, load_array(t.test_tensor)) 

3136 for t in self.outputs 

3137 if t.test_tensor is not None 

3138 ], 

3139 ) 

3140 except Exception as e: 

3141 issue_warning( 

3142 "Failed to generate cover image(s): {e}", 

3143 value=self.covers, 

3144 msg_context=dict(e=e), 

3145 field="covers", 

3146 ) 

3147 else: 

3148 self.covers.extend(generated_covers) 

3149 

3150 return self 

3151 

3152 def get_input_test_arrays(self) -> List[NDArray[Any]]: 

3153 return self._get_test_arrays(self.inputs) 

3154 

3155 def get_output_test_arrays(self) -> List[NDArray[Any]]: 

3156 return self._get_test_arrays(self.outputs) 

3157 

3158 @staticmethod 

3159 def _get_test_arrays( 

3160 io_descr: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]], 

3161 ): 

3162 ts: List[FileDescr] = [] 

3163 for d in io_descr: 

3164 if d.test_tensor is None: 

3165 raise ValueError( 

3166 f"Failed to get test arrays: description of '{d.id}' is missing a `test_tensor`." 

3167 ) 

3168 ts.append(d.test_tensor) 

3169 

3170 data = [load_array(t) for t in ts] 

3171 assert all(isinstance(d, np.ndarray) for d in data) 

3172 return data 

3173 

3174 @staticmethod 

3175 def get_batch_size(tensor_sizes: Mapping[TensorId, Mapping[AxisId, int]]) -> int: 

3176 batch_size = 1 

3177 tensor_with_batchsize: Optional[TensorId] = None 

3178 for tid in tensor_sizes: 

3179 for aid, s in tensor_sizes[tid].items(): 

3180 if aid != BATCH_AXIS_ID or s == 1 or s == batch_size: 

3181 continue 

3182 

3183 if batch_size != 1: 

3184 assert tensor_with_batchsize is not None 

3185 raise ValueError( 

3186 f"batch size mismatch for tensors '{tensor_with_batchsize}' ({batch_size}) and '{tid}' ({s})" 

3187 ) 

3188 

3189 batch_size = s 

3190 tensor_with_batchsize = tid 

3191 

3192 return batch_size 

3193 
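# For example (a sketch, assuming `BATCH_AXIS_ID == AxisId("batch")` and a
# hypothetical tensor id "raw"):
#
#   ModelDescr.get_batch_size(
#       {TensorId("raw"): {AxisId("batch"): 4, AxisId("x"): 64}}
#   )  # -> 4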

3194 def get_output_tensor_sizes( 

3195 self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]] 

3196 ) -> Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]: 

3197 """Returns the tensor output sizes for given **input_sizes**. 

3198 The output sizes are exact only if **input_sizes** constitutes a valid input shape;

3199 otherwise they may be larger than the actual (valid) output sizes."""

3200 batch_size = self.get_batch_size(input_sizes) 

3201 ns = self.get_ns(input_sizes) 

3202 

3203 tensor_sizes = self.get_tensor_sizes(ns, batch_size=batch_size) 

3204 return tensor_sizes.outputs 

3205 

3206 def get_ns(self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]): 

3207 """get parameter `n` for each parameterized axis 

3208 such that the valid input size is >= the given input size""" 

3209 ret: Dict[Tuple[TensorId, AxisId], ParameterizedSize_N] = {} 

3210 axes = {t.id: {a.id: a for a in t.axes} for t in self.inputs} 

3211 for tid in input_sizes: 

3212 for aid, s in input_sizes[tid].items(): 

3213 size_descr = axes[tid][aid].size 

3214 if isinstance(size_descr, ParameterizedSize): 

3215 ret[(tid, aid)] = size_descr.get_n(s) 

3216 elif size_descr is None or isinstance(size_descr, (int, SizeReference)): 

3217 pass 

3218 else: 

3219 assert_never(size_descr) 

3220 

3221 return ret 

3222 
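# For example (a sketch): for an axis described by
# ParameterizedSize(min=32, step=16), a requested size of 40 yields n=1,
# the smallest n with 32 + n * 16 >= 40 (i.e. a valid size of 48).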

3223 def get_tensor_sizes( 

3224 self, ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], batch_size: int 

3225 ) -> _TensorSizes: 

3226 axis_sizes = self.get_axis_sizes(ns, batch_size=batch_size) 

3227 return _TensorSizes( 

3228 { 

3229 t: { 

3230 aa: axis_sizes.inputs[(tt, aa)] 

3231 for tt, aa in axis_sizes.inputs 

3232 if tt == t 

3233 } 

3234 for t in {tt for tt, _ in axis_sizes.inputs} 

3235 }, 

3236 { 

3237 t: { 

3238 aa: axis_sizes.outputs[(tt, aa)] 

3239 for tt, aa in axis_sizes.outputs 

3240 if tt == t 

3241 } 

3242 for t in {tt for tt, _ in axis_sizes.outputs} 

3243 }, 

3244 ) 

3245 

3246 def get_axis_sizes( 

3247 self, 

3248 ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], 

3249 batch_size: Optional[int] = None, 

3250 *, 

3251 max_input_shape: Optional[Mapping[Tuple[TensorId, AxisId], int]] = None, 

3252 ) -> _AxisSizes: 

3253 """Determine input and output block shape for scale factors **ns** 

3254 of parameterized input sizes. 

3255 

3256 Args: 

3257 ns: Scale factor `n` for each axis (keyed by (tensor_id, axis_id)) 

3258 that is parameterized as `size = min + n * step`. 

3259 batch_size: The desired size of the batch dimension. 

3260 If given **batch_size** overwrites any batch size present in 

3261 **max_input_shape**. Default 1. 

3262 max_input_shape: Limits the derived block shapes. 

3263 Each axis for which the input size, parameterized by `n`, is larger 

3264 than **max_input_shape** is set to the minimal value `n_min` for which 

3265 this is still true. 

3266 Use this for small input samples or large values of **ns**. 

3267 Or simply whenever you know the full input shape. 

3268 

3269 Returns: 

3270 Resolved axis sizes for model inputs and outputs. 

3271 """ 

3272 max_input_shape = max_input_shape or {} 

3273 if batch_size is None: 

3274 for (_t_id, a_id), s in max_input_shape.items(): 

3275 if a_id == BATCH_AXIS_ID: 

3276 batch_size = s 

3277 break 

3278 else: 

3279 batch_size = 1 

3280 

3281 all_axes = { 

3282 t.id: {a.id: a for a in t.axes} for t in chain(self.inputs, self.outputs) 

3283 } 

3284 

3285 inputs: Dict[Tuple[TensorId, AxisId], int] = {} 

3286 outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]] = {} 

3287 

3288 def get_axis_size(a: Union[InputAxis, OutputAxis]): 

3289 if isinstance(a, BatchAxis): 

3290 if (t_descr.id, a.id) in ns: 

3291 logger.warning( 

3292 "Ignoring unexpected size increment factor (n) for batch axis" 

3293 + " of tensor '{}'.", 

3294 t_descr.id, 

3295 ) 

3296 return batch_size 

3297 elif isinstance(a.size, int): 

3298 if (t_descr.id, a.id) in ns: 

3299 logger.warning( 

3300 "Ignoring unexpected size increment factor (n) for fixed size" 

3301 + " axis '{}' of tensor '{}'.", 

3302 a.id, 

3303 t_descr.id, 

3304 ) 

3305 return a.size 

3306 elif isinstance(a.size, ParameterizedSize): 

3307 if (t_descr.id, a.id) not in ns: 

3308 raise ValueError( 

3309 "Size increment factor (n) missing for parametrized axis" 

3310 + f" '{a.id}' of tensor '{t_descr.id}'." 

3311 ) 

3312 n = ns[(t_descr.id, a.id)] 

3313 s_max = max_input_shape.get((t_descr.id, a.id)) 

3314 if s_max is not None: 

3315 n = min(n, a.size.get_n(s_max)) 

3316 

3317 return a.size.get_size(n) 

3318 

3319 elif isinstance(a.size, SizeReference): 

3320 if (t_descr.id, a.id) in ns: 

3321 logger.warning( 

3322 "Ignoring unexpected size increment factor (n) for axis '{}'" 

3323 + " of tensor '{}' with size reference.", 

3324 a.id, 

3325 t_descr.id, 

3326 ) 

3327 assert not isinstance(a, BatchAxis) 

3328 ref_axis = all_axes[a.size.tensor_id][a.size.axis_id] 

3329 assert not isinstance(ref_axis, BatchAxis) 

3330 ref_key = (a.size.tensor_id, a.size.axis_id) 

3331 ref_size = inputs.get(ref_key, outputs.get(ref_key)) 

3332 assert ref_size is not None, ref_key 

3333 assert not isinstance(ref_size, _DataDepSize), ref_key 

3334 return a.size.get_size( 

3335 axis=a, 

3336 ref_axis=ref_axis, 

3337 ref_size=ref_size, 

3338 ) 

3339 elif isinstance(a.size, DataDependentSize): 

3340 if (t_descr.id, a.id) in ns: 

3341 logger.warning( 

3342 "Ignoring unexpected increment factor (n) for data dependent" 

3343 + " size axis '{}' of tensor '{}'.", 

3344 a.id, 

3345 t_descr.id, 

3346 ) 

3347 return _DataDepSize(a.size.min, a.size.max) 

3348 else: 

3349 assert_never(a.size) 

3350 

3351 # first resolve all input sizes except the `SizeReference` ones

3352 for t_descr in self.inputs: 

3353 for a in t_descr.axes: 

3354 if not isinstance(a.size, SizeReference): 

3355 s = get_axis_size(a) 

3356 assert not isinstance(s, _DataDepSize) 

3357 inputs[t_descr.id, a.id] = s 

3358 

3359 # resolve all other input axis sizes 

3360 for t_descr in self.inputs: 

3361 for a in t_descr.axes: 

3362 if isinstance(a.size, SizeReference): 

3363 s = get_axis_size(a) 

3364 assert not isinstance(s, _DataDepSize) 

3365 inputs[t_descr.id, a.id] = s 

3366 

3367 # resolve all output axis sizes 

3368 for t_descr in self.outputs: 

3369 for a in t_descr.axes: 

3370 assert not isinstance(a.size, ParameterizedSize) 

3371 s = get_axis_size(a) 

3372 outputs[t_descr.id, a.id] = s 

3373 

3374 return _AxisSizes(inputs=inputs, outputs=outputs) 

3375 
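# A minimal usage sketch (hypothetical ids; `model` is a `ModelDescr` with
# a parameterized input axis "x" on tensor "raw"):
#
#   sizes = model.get_axis_sizes(
#       ns={(TensorId("raw"), AxisId("x")): 2}, batch_size=1
#   )
#   sizes.inputs[(TensorId("raw"), AxisId("x"))]  # -> min + 2 * step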

3376 @model_validator(mode="before") 

3377 @classmethod 

3378 def _convert(cls, data: Dict[str, Any]) -> Dict[str, Any]: 

3379 cls.convert_from_old_format_wo_validation(data) 

3380 return data 

3381 

3382 @classmethod 

3383 def convert_from_old_format_wo_validation(cls, data: Dict[str, Any]) -> None: 

3384 """Convert metadata following an older format version to this classes' format 

3385 without validating the result. 

3386 """ 

3387 if ( 

3388 data.get("type") == "model" 

3389 and isinstance(fv := data.get("format_version"), str) 

3390 and fv.count(".") == 2 

3391 ): 

3392 fv_parts = fv.split(".") 

3393 if any(not p.isdigit() for p in fv_parts): 

3394 return 

3395 

3396 fv_tuple = tuple(map(int, fv_parts)) 

3397 

3398 assert cls.implemented_format_version_tuple[0:2] == (0, 5) 

3399 if fv_tuple[:2] in ((0, 3), (0, 4)): 

3400 m04 = _ModelDescr_v0_4.load(data) 

3401 if isinstance(m04, InvalidDescr): 

3402 try: 

3403 updated = _model_conv.convert_as_dict( 

3404 m04 # pyright: ignore[reportArgumentType] 

3405 ) 

3406 except Exception as e: 

3407 logger.error( 

3408 "Failed to convert from invalid model 0.4 description." 

3409 + f"\nerror: {e}" 

3410 + "\nProceeding with model 0.5 validation without conversion." 

3411 ) 

3412 updated = None 

3413 else: 

3414 updated = _model_conv.convert_as_dict(m04) 

3415 

3416 if updated is not None: 

3417 data.clear() 

3418 data.update(updated) 

3419 

3420 elif fv_tuple[:2] == (0, 5): 

3421 # bump patch version 

3422 data["format_version"] = cls.implemented_format_version 

3423 
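# A minimal usage sketch of the in-place upgrade above (hypothetical, heavily
# abbreviated RDF content):
#
#   data = {"type": "model", "format_version": "0.4.10", ...}
#   ModelDescr.convert_from_old_format_wo_validation(data)
#   data["format_version"]  # -> "0.5.6"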

3424 

3425class _ModelConv(Converter[_ModelDescr_v0_4, ModelDescr]): 

3426 def _convert( 

3427 self, src: _ModelDescr_v0_4, tgt: "type[ModelDescr] | type[dict[str, Any]]" 

3428 ) -> "ModelDescr | dict[str, Any]": 

3429 name = "".join( 

3430 c if c in string.ascii_letters + string.digits + "_+- ()" else " " 

3431 for c in src.name 

3432 ) 

3433 

3434 def conv_authors(auths: Optional[Sequence[_Author_v0_4]]): 

3435 conv = ( 

3436 _author_conv.convert if TYPE_CHECKING else _author_conv.convert_as_dict 

3437 ) 

3438 return None if auths is None else [conv(a) for a in auths] 

3439 

3440 if TYPE_CHECKING: 

3441 arch_file_conv = _arch_file_conv.convert 

3442 arch_lib_conv = _arch_lib_conv.convert 

3443 else: 

3444 arch_file_conv = _arch_file_conv.convert_as_dict 

3445 arch_lib_conv = _arch_lib_conv.convert_as_dict 

3446 

3447 input_size_refs = { 

3448 ipt.name: { 

3449 a: s 

3450 for a, s in zip( 

3451 ipt.axes, 

3452 ( 

3453 ipt.shape.min 

3454 if isinstance(ipt.shape, _ParameterizedInputShape_v0_4) 

3455 else ipt.shape 

3456 ), 

3457 ) 

3458 } 

3459 for ipt in src.inputs 

3460 if ipt.shape 

3461 } 

3462 output_size_refs = { 

3463 **{ 

3464 out.name: {a: s for a, s in zip(out.axes, out.shape)} 

3465 for out in src.outputs 

3466 if not isinstance(out.shape, _ImplicitOutputShape_v0_4) 

3467 }, 

3468 **input_size_refs, 

3469 } 

3470 

3471 return tgt( 

3472 attachments=( 

3473 [] 

3474 if src.attachments is None 

3475 else [FileDescr(source=f) for f in src.attachments.files] 

3476 ), 

3477 authors=[_author_conv.convert_as_dict(a) for a in src.authors], # pyright: ignore[reportArgumentType] 

3478 cite=[{"text": c.text, "doi": c.doi, "url": c.url} for c in src.cite], # pyright: ignore[reportArgumentType] 

3479 config=src.config, # pyright: ignore[reportArgumentType] 

3480 covers=src.covers, 

3481 description=src.description, 

3482 documentation=src.documentation, 

3483 format_version="0.5.6", 

3484 git_repo=src.git_repo, # pyright: ignore[reportArgumentType] 

3485 icon=src.icon, 

3486 id=None if src.id is None else ModelId(src.id), 

3487 id_emoji=src.id_emoji, 

3488 license=src.license, # type: ignore 

3489 links=src.links, 

3490 maintainers=[_maintainer_conv.convert_as_dict(m) for m in src.maintainers], # pyright: ignore[reportArgumentType] 

3491 name=name, 

3492 tags=src.tags, 

3493 type=src.type, 

3494 uploader=src.uploader, 

3495 version=src.version, 

3496 inputs=[ # pyright: ignore[reportArgumentType] 

3497 _input_tensor_conv.convert_as_dict(ipt, tt, st, input_size_refs) 

3498 for ipt, tt, st in zip( 

3499 src.inputs, 

3500 src.test_inputs, 

3501 src.sample_inputs or [None] * len(src.test_inputs), 

3502 ) 

3503 ], 

3504 outputs=[ # pyright: ignore[reportArgumentType] 

3505 _output_tensor_conv.convert_as_dict(out, tt, st, output_size_refs) 

3506 for out, tt, st in zip( 

3507 src.outputs, 

3508 src.test_outputs, 

3509 src.sample_outputs or [None] * len(src.test_outputs), 

3510 ) 

3511 ], 

3512 parent=( 

3513 None 

3514 if src.parent is None 

3515 else LinkedModel( 

3516 id=ModelId( 

3517 str(src.parent.id) 

3518 + ( 

3519 "" 

3520 if src.parent.version_number is None 

3521 else f"/{src.parent.version_number}" 

3522 ) 

3523 ) 

3524 ) 

3525 ), 

3526 training_data=( 

3527 None 

3528 if src.training_data is None 

3529 else ( 

3530 LinkedDataset( 

3531 id=DatasetId( 

3532 str(src.training_data.id) 

3533 + ( 

3534 "" 

3535 if src.training_data.version_number is None 

3536 else f"/{src.training_data.version_number}" 

3537 ) 

3538 ) 

3539 ) 

3540 if isinstance(src.training_data, LinkedDataset02) 

3541 else src.training_data 

3542 ) 

3543 ), 

3544 packaged_by=[_author_conv.convert_as_dict(a) for a in src.packaged_by], # pyright: ignore[reportArgumentType] 

3545 run_mode=src.run_mode, 

3546 timestamp=src.timestamp, 

3547 weights=(WeightsDescr if TYPE_CHECKING else dict)( 

3548 keras_hdf5=(w := src.weights.keras_hdf5) 

3549 and (KerasHdf5WeightsDescr if TYPE_CHECKING else dict)( 

3550 authors=conv_authors(w.authors), 

3551 source=w.source, 

3552 tensorflow_version=w.tensorflow_version or Version("1.15"), 

3553 parent=w.parent, 

3554 ), 

3555 onnx=(w := src.weights.onnx) 

3556 and (OnnxWeightsDescr if TYPE_CHECKING else dict)( 

3557 source=w.source, 

3558 authors=conv_authors(w.authors), 

3559 parent=w.parent, 

3560 opset_version=w.opset_version or 15, 

3561 ), 

3562 pytorch_state_dict=(w := src.weights.pytorch_state_dict) 

3563 and (PytorchStateDictWeightsDescr if TYPE_CHECKING else dict)( 

3564 source=w.source, 

3565 authors=conv_authors(w.authors), 

3566 parent=w.parent, 

3567 architecture=( 

3568 arch_file_conv( 

3569 w.architecture, 

3570 w.architecture_sha256, 

3571 w.kwargs, 

3572 ) 

3573 if isinstance(w.architecture, _CallableFromFile_v0_4) 

3574 else arch_lib_conv(w.architecture, w.kwargs) 

3575 ), 

3576 pytorch_version=w.pytorch_version or Version("1.10"), 

3577 dependencies=( 

3578 None 

3579 if w.dependencies is None 

3580 else (FileDescr if TYPE_CHECKING else dict)( 

3581 source=cast( 

3582 FileSource, 

3583 str(deps := w.dependencies)[ 

3584 ( 

3585 len("conda:") 

3586 if str(deps).startswith("conda:") 

3587 else 0 

3588 ) : 

3589 ], 

3590 ) 

3591 ) 

3592 ), 

3593 ), 

3594 tensorflow_js=(w := src.weights.tensorflow_js) 

3595 and (TensorflowJsWeightsDescr if TYPE_CHECKING else dict)( 

3596 source=w.source, 

3597 authors=conv_authors(w.authors), 

3598 parent=w.parent, 

3599 tensorflow_version=w.tensorflow_version or Version("1.15"), 

3600 ), 

3601 tensorflow_saved_model_bundle=( 

3602 w := src.weights.tensorflow_saved_model_bundle 

3603 ) 

3604 and (TensorflowSavedModelBundleWeightsDescr if TYPE_CHECKING else dict)( 

3605 authors=conv_authors(w.authors), 

3606 parent=w.parent, 

3607 source=w.source, 

3608 tensorflow_version=w.tensorflow_version or Version("1.15"), 

3609 dependencies=( 

3610 None 

3611 if w.dependencies is None 

3612 else (FileDescr if TYPE_CHECKING else dict)( 

3613 source=cast( 

3614 FileSource, 

3615 ( 

3616 str(w.dependencies)[len("conda:") :] 

3617 if str(w.dependencies).startswith("conda:") 

3618 else str(w.dependencies) 

3619 ), 

3620 ) 

3621 ) 

3622 ), 

3623 ), 

3624 torchscript=(w := src.weights.torchscript) 

3625 and (TorchscriptWeightsDescr if TYPE_CHECKING else dict)( 

3626 source=w.source, 

3627 authors=conv_authors(w.authors), 

3628 parent=w.parent, 

3629 pytorch_version=w.pytorch_version or Version("1.10"), 

3630 ), 

3631 ), 

3632 ) 

3633 
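Throughout the weights conversion above, the pattern `(w := src.weights.X) and SomeDescr(...)` serves as a None-propagating shorthand: if the 0.4 weights entry is absent, the whole expression evaluates to None; otherwise the walrus-bound `w` feeds the new descriptor. The pattern in isolation (illustrative names):

    from typing import Optional

    def maybe_upper(value: Optional[str]) -> Optional[str]:
        # short-circuit: None stays None, otherwise transform the bound value
        return (v := value) and v.upper()

    assert maybe_upper(None) is None
    assert maybe_upper("onnx") == "ONNX"
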

3634 

3635_model_conv = _ModelConv(_ModelDescr_v0_4, ModelDescr) 

3636 
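`_ModelConv._convert` begins by sanitizing the 0.4 model name: any character outside letters, digits, and `"_+- ()"` is replaced with a space, presumably to satisfy the stricter 0.5 name constraints. A standalone sketch of that filter (the helper name is illustrative):

    import string

    ALLOWED = string.ascii_letters + string.digits + "_+- ()"

    def sanitize_name(name: str) -> str:
        # same character filter as at the top of _ModelConv._convert
        return "".join(c if c in ALLOWED else " " for c in name)

    assert sanitize_name("U-Net (2D) v1.0") == "U-Net (2D) v1 0"
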

3637 

3638# create better cover images for 3d data and non-image outputs 

3639def generate_covers( 

3640 inputs: Sequence[Tuple[InputTensorDescr, NDArray[Any]]], 

3641 outputs: Sequence[Tuple[OutputTensorDescr, NDArray[Any]]], 

3642) -> List[Path]: 

3643 def squeeze( 

3644 data: NDArray[Any], axes: Sequence[AnyAxis] 

3645 ) -> Tuple[NDArray[Any], List[AnyAxis]]: 

3646 """apply numpy.ndarray.squeeze while keeping track of the axis descriptions remaining""" 

3647 if data.ndim != len(axes): 

3648 raise ValueError( 

3649 f"tensor shape {data.shape} does not match described axes" 

3650 + f" {[a.id for a in axes]}" 

3651 ) 

3652 

3653 axes = [deepcopy(a) for a, s in zip(axes, data.shape) if s != 1] 

3654 return data.squeeze(), axes 

3655 
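`squeeze` drops singleton dimensions and the matching axis descriptions in lockstep, so the returned array and axis list stay aligned. The same bookkeeping in plain numpy:

    import numpy as np

    data = np.zeros((1, 3, 256, 256))  # e.g. batch, channel, y, x
    axes = ["batch", "channel", "y", "x"]
    kept = [a for a, s in zip(axes, data.shape) if s != 1]
    assert data.squeeze().shape == (3, 256, 256)
    assert kept == ["channel", "y", "x"]
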

3656 def normalize( 

3657 data: NDArray[Any], axis: Optional[Tuple[int, ...]], eps: float = 1e-7 

3658 ) -> NDArray[np.float32]: 

3659 data = data.astype("float32") 

3660 data -= data.min(axis=axis, keepdims=True) 

3661 data /= data.max(axis=axis, keepdims=True) + eps 

3662 return data 

3663 
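`normalize` is a standard min-max rescale with an epsilon guard against division by zero; `axis=None` normalizes over the whole array, while a tuple of axes normalizes along just those (e.g. per channel). The equivalent inline numpy, with eps as in the default above:

    import numpy as np

    x = np.array([[0.0, 5.0], [10.0, 20.0]], dtype="float32")
    x -= x.min(axis=None, keepdims=True)          # shift the minimum to 0
    x /= x.max(axis=None, keepdims=True) + 1e-7   # scale the maximum to just under 1
    assert x.min() == 0.0 and x.max() < 1.0
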

3664 def to_2d_image(data: NDArray[Any], axes: Sequence[AnyAxis]): 

3665 original_shape = data.shape 

3666 original_axes = list(axes) 

3667 data, axes = squeeze(data, axes) 

3668 

3669 # take slice from any batch or index axis if needed 

3670 # and map the first channel axis to three displayable channels, taking a slice from any additional channel axes 

3671 slices: Tuple[slice, ...] = () 

3672 ndim = data.ndim 

3673 ndim_need = 3 if any(isinstance(a, ChannelAxis) for a in axes) else 2 

3674 has_c_axis = False 

3675 for i, a in enumerate(axes): 

3676 s = data.shape[i] 

3677 assert s > 1 

3678 if ( 

3679 isinstance(a, (BatchAxis, IndexInputAxis, IndexOutputAxis)) 

3680 and ndim > ndim_need 

3681 ): 

3682 data = data[slices + (slice(s // 2 - 1, s // 2),)] 

3683 ndim -= 1 

3684 elif isinstance(a, ChannelAxis): 

3685 if has_c_axis: 

3686 # second channel axis 

3687 data = data[slices + (slice(0, 1),)] 

3688 ndim -= 1 

3689 else: 

3690 has_c_axis = True 

3691 if s == 2: 

3692 # visualize two channels with cyan and magenta 

3693 data = np.concatenate( 

3694 [ 

3695 data[slices + (slice(1, 2),)], 

3696 data[slices + (slice(0, 1),)], 

3697 ( 

3698 data[slices + (slice(0, 1),)] 

3699 + data[slices + (slice(1, 2),)] 

3700 ) 

3701 / 2, # TODO: take maximum instead? 

3702 ], 

3703 axis=i, 

3704 ) 

3705 elif data.shape[i] == 3: 

3706 pass # visualize 3 channels as RGB 

3707 else: 

3708 # visualize first 3 channels as RGB 

3709 data = data[slices + (slice(3),)] 

3710 

3711 assert data.shape[i] == 3 

3712 

3713 slices += (slice(None),) 

3714 

3715 data, axes = squeeze(data, axes) 

3716 assert len(axes) == ndim 

3717 # take slice from z axis if needed 

3718 slices = () 

3719 if ndim > ndim_need: 

3720 for i, a in enumerate(axes): 

3721 s = data.shape[i] 

3722 if a.id == AxisId("z"): 

3723 data = data[slices + (slice(s // 2 - 1, s // 2),)] 

3724 data, axes = squeeze(data, axes) 

3725 ndim -= 1 

3726 break 

3727 

3728 slices += (slice(None),) 

3729 

3730 # take slice from any space or time axis 

3731 slices = () 

3732 

3733 for i, a in enumerate(axes): 

3734 if ndim <= ndim_need: 

3735 break 

3736 

3737 s = data.shape[i] 

3738 assert s > 1 

3739 if isinstance( 

3740 a, (SpaceInputAxis, SpaceOutputAxis, TimeInputAxis, TimeOutputAxis) 

3741 ): 

3742 data = data[slices + (slice(s // 2 - 1, s // 2),)] 

3743 ndim -= 1 

3744 

3745 slices += (slice(None),) 

3746 

3747 del slices 

3748 data, axes = squeeze(data, axes) 

3749 assert len(axes) == ndim 

3750 

3751 if (has_c_axis and ndim != 3) or (not has_c_axis and ndim != 2): 

3752 raise ValueError( 

3753 f"Failed to construct cover image from shape {original_shape} with axes {[a.id for a in original_axes]}." 

3754 ) 

3755 

3756 if not has_c_axis: 

3757 assert ndim == 2 

3758 data = np.repeat(data[:, :, None], 3, axis=2) 

3759 axes.append(ChannelAxis(channel_names=list(map(Identifier, "RGB")))) 

3760 ndim += 1 

3761 

3762 assert ndim == 3 

3763 

3764 # transpose axis order such that longest axis comes first... 

3765 axis_order: List[int] = list(np.argsort(list(data.shape))) 

3766 axis_order.reverse() 

3767 # ... and channel axis is last 

3768 c = [i for i in range(3) if isinstance(axes[i], ChannelAxis)][0] 

3769 axis_order.append(axis_order.pop(c)) 

3770 axes = [axes[ao] for ao in axis_order] 

3771 data = data.transpose(axis_order) 

3772 

3773 # h, w = data.shape[:2] 

3774 # if h / w in (1.0, 2.0): 

3775 # pass 

3776 # elif h / w < 2: 

3777 # TODO: enforce 2:1 or 1:1 aspect ratio for generated cover images 

3778 

3779 norm_along = ( 

3780 tuple(i for i, a in enumerate(axes) if a.type in ("space", "time")) or None 

3781 ) 

3782 # normalize the data and map to 8 bit 

3783 data = normalize(data, norm_along) 

3784 data = (data * 255).astype("uint8") 

3785 

3786 return data 

3787 
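The reordering near the end of `to_2d_image` puts the longest axis first (reversed `argsort` of the shape) and moves the channel axis to the last position, the layout `imwrite` expects for an RGB image. The same index arithmetic on a bare shape tuple:

    import numpy as np

    shape = (64, 3, 128)  # e.g. y, channel, x after slicing
    order = list(np.argsort(list(shape)))[::-1]  # longest axis first
    c = shape.index(3)  # channel position (found by size here; the code above checks the axis type)
    order.append(order.pop(order.index(c)))      # ... and channel axis last
    assert tuple(shape[i] for i in order) == (128, 64, 3)
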

3788 def create_diagonal_split_image(im0: NDArray[Any], im1: NDArray[Any]): 

3789 assert im0.dtype == im1.dtype == np.uint8 

3790 assert im0.shape == im1.shape 

3791 assert im0.ndim == 3 

3792 N, M, C = im0.shape 

3793 assert C == 3 

3794 out = np.ones((N, M, C), dtype="uint8") 

3795 for c in range(C): 

3796 outc = np.tril(im0[..., c]) 

3797 mask = outc == 0 

3798 outc[mask] = np.triu(im1[..., c])[mask] 

3799 out[..., c] = outc 

3800 

3801 return out 

3802 
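`create_diagonal_split_image` composes the cover from the lower triangle of the input image and the upper triangle of the output image, channel by channel. The core masking trick on a single uint8 channel:

    import numpy as np

    im0 = np.full((4, 4), 10, dtype="uint8")  # stands in for one input channel
    im1 = np.full((4, 4), 99, dtype="uint8")  # stands in for one output channel
    outc = np.tril(im0)              # lower triangle (incl. diagonal) of im0, zeros elsewhere
    mask = outc == 0                 # everything tril zeroed out ...
    outc[mask] = np.triu(im1)[mask]  # ... is filled from the upper triangle of im1
    assert outc[3, 0] == 10 and outc[0, 3] == 99
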

3803 if not inputs: 

3804 raise ValueError("Missing test input tensor for cover generation.") 

3805 

3806 if not outputs: 

3807 raise ValueError("Missing test output tensor for cover generation.") 

3808 

3809 ipt_descr, ipt = inputs[0] 

3810 out_descr, out = outputs[0] 

3811 

3812 ipt_img = to_2d_image(ipt, ipt_descr.axes) 

3813 out_img = to_2d_image(out, out_descr.axes) 

3814 

3815 cover_folder = Path(mkdtemp()) 

3816 if ipt_img.shape == out_img.shape: 

3817 covers = [cover_folder / "cover.png"] 

3818 imwrite(covers[0], create_diagonal_split_image(ipt_img, out_img)) 

3819 else: 

3820 covers = [cover_folder / "input.png", cover_folder / "output.png"] 

3821 imwrite(covers[0], ipt_img) 

3822 imwrite(covers[1], out_img) 

3823 

3824 return covers