Coverage for src / bioimageio / spec / model / v0_5.py: 76%

1547 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-17 16:08 +0000

1from __future__ import annotations 

2 

3import collections.abc 

4import re 

5import string 

6import warnings 

7from copy import deepcopy 

8from itertools import chain 

9from math import ceil 

10from pathlib import Path, PurePosixPath 

11from tempfile import mkdtemp 

12from textwrap import dedent 

13from typing import ( 

14 TYPE_CHECKING, 

15 Any, 

16 Callable, 

17 ClassVar, 

18 Dict, 

19 Generic, 

20 List, 

21 Literal, 

22 Mapping, 

23 NamedTuple, 

24 Optional, 

25 Sequence, 

26 Set, 

27 Tuple, 

28 Type, 

29 TypeVar, 

30 Union, 

31 cast, 

32 overload, 

33) 

34 

35import numpy as np 

36from annotated_types import Ge, Gt, Interval, MaxLen, MinLen, Predicate 

37from imageio.v3 import imread, imwrite # pyright: ignore[reportUnknownVariableType] 

38from loguru import logger 

39from numpy.typing import NDArray 

40from pydantic import ( 

41 AfterValidator, 

42 Discriminator, 

43 Field, 

44 RootModel, 

45 SerializationInfo, 

46 SerializerFunctionWrapHandler, 

47 StrictInt, 

48 Tag, 

49 ValidationInfo, 

50 WrapSerializer, 

51 field_validator, 

52 model_serializer, 

53 model_validator, 

54) 

55from typing_extensions import Annotated, Self, assert_never, get_args 

56 

57from .._internal.common_nodes import ( 

58 InvalidDescr, 

59 KwargsNode, 

60 Node, 

61 NodeWithExplicitlySetFields, 

62) 

63from .._internal.constants import DTYPE_LIMITS 

64from .._internal.field_warning import issue_warning, warn 

65from .._internal.io import BioimageioYamlContent as BioimageioYamlContent 

66from .._internal.io import FileDescr as FileDescr 

67from .._internal.io import ( 

68 FileSource, 

69 WithSuffix, 

70 YamlValue, 

71 extract_file_name, 

72 get_reader, 

73 wo_special_file_name, 

74) 

75from .._internal.io_basics import Sha256 as Sha256 

76from .._internal.io_packaging import ( 

77 FileDescr_, 

78 FileSource_, 

79 package_file_descr_serializer, 

80) 

81from .._internal.io_utils import load_array 

82from .._internal.node_converter import Converter 

83from .._internal.type_guards import is_dict, is_sequence 

84from .._internal.types import ( 

85 FAIR, 

86 AbsoluteTolerance, 

87 LowerCaseIdentifier, 

88 LowerCaseIdentifierAnno, 

89 MismatchedElementsPerMillion, 

90 RelativeTolerance, 

91) 

92from .._internal.types import Datetime as Datetime 

93from .._internal.types import Identifier as Identifier 

94from .._internal.types import NotEmpty as NotEmpty 

95from .._internal.types import SiUnit as SiUnit 

96from .._internal.url import HttpUrl as HttpUrl 

97from .._internal.validation_context import get_validation_context 

98from .._internal.validator_annotations import RestrictCharacters 

99from .._internal.version_type import Version as Version 

100from .._internal.warning_levels import INFO 

101from ..dataset.v0_2 import DatasetDescr as DatasetDescr02 

102from ..dataset.v0_2 import LinkedDataset as LinkedDataset02 

103from ..dataset.v0_3 import DatasetDescr as DatasetDescr 

104from ..dataset.v0_3 import DatasetId as DatasetId 

105from ..dataset.v0_3 import LinkedDataset as LinkedDataset 

106from ..dataset.v0_3 import Uploader as Uploader 

107from ..generic.v0_3 import ( 

108 VALID_COVER_IMAGE_EXTENSIONS as VALID_COVER_IMAGE_EXTENSIONS, 

109) 

110from ..generic.v0_3 import Author as Author 

111from ..generic.v0_3 import BadgeDescr as BadgeDescr 

112from ..generic.v0_3 import CiteEntry as CiteEntry 

113from ..generic.v0_3 import DeprecatedLicenseId as DeprecatedLicenseId 

114from ..generic.v0_3 import Doi as Doi 

115from ..generic.v0_3 import ( 

116 FileSource_documentation, 

117 GenericModelDescrBase, 

118 LinkedResourceBase, 

119 _author_conv, # pyright: ignore[reportPrivateUsage] 

120 _maintainer_conv, # pyright: ignore[reportPrivateUsage] 

121) 

122from ..generic.v0_3 import LicenseId as LicenseId 

123from ..generic.v0_3 import LinkedResource as LinkedResource 

124from ..generic.v0_3 import Maintainer as Maintainer 

125from ..generic.v0_3 import OrcidId as OrcidId 

126from ..generic.v0_3 import RelativeFilePath as RelativeFilePath 

127from ..generic.v0_3 import ResourceId as ResourceId 

128from .v0_4 import Author as _Author_v0_4 

129from .v0_4 import BinarizeDescr as _BinarizeDescr_v0_4 

130from .v0_4 import CallableFromDepencency as CallableFromDepencency 

131from .v0_4 import CallableFromDepencency as _CallableFromDepencency_v0_4 

132from .v0_4 import CallableFromFile as _CallableFromFile_v0_4 

133from .v0_4 import ClipDescr as _ClipDescr_v0_4 

134from .v0_4 import ImplicitOutputShape as _ImplicitOutputShape_v0_4 

135from .v0_4 import InputTensorDescr as _InputTensorDescr_v0_4 

136from .v0_4 import KnownRunMode as KnownRunMode 

137from .v0_4 import ModelDescr as _ModelDescr_v0_4 

138from .v0_4 import OutputTensorDescr as _OutputTensorDescr_v0_4 

139from .v0_4 import ParameterizedInputShape as _ParameterizedInputShape_v0_4 

140from .v0_4 import PostprocessingDescr as _PostprocessingDescr_v0_4 

141from .v0_4 import PreprocessingDescr as _PreprocessingDescr_v0_4 

142from .v0_4 import RunMode as RunMode 

143from .v0_4 import ScaleLinearDescr as _ScaleLinearDescr_v0_4 

144from .v0_4 import ScaleMeanVarianceDescr as _ScaleMeanVarianceDescr_v0_4 

145from .v0_4 import ScaleRangeDescr as _ScaleRangeDescr_v0_4 

146from .v0_4 import SigmoidDescr as _SigmoidDescr_v0_4 

147from .v0_4 import TensorName as _TensorName_v0_4 

148from .v0_4 import WeightsFormat as WeightsFormat 

149from .v0_4 import ZeroMeanUnitVarianceDescr as _ZeroMeanUnitVarianceDescr_v0_4 

150from .v0_4 import package_weights 

151 

# Closed set of admissible spatial units (SI metric plus imperial), alphabetically sorted.
SpaceUnit = Literal[
    "attometer",
    "angstrom",
    "centimeter",
    "decimeter",
    "exameter",
    "femtometer",
    "foot",
    "gigameter",
    "hectometer",
    "inch",
    "kilometer",
    "megameter",
    "meter",
    "micrometer",
    "mile",
    "millimeter",
    "nanometer",
    "parsec",
    "petameter",
    "picometer",
    "terameter",
    "yard",
    "yoctometer",
    "yottameter",
    "zeptometer",
    "zettameter",
]
"""Space unit compatible to the [OME-Zarr axes specification 0.5](https://ngff.openmicroscopy.org/0.5/#axes-md)"""

181 

# Closed set of admissible time units, alphabetically sorted.
TimeUnit = Literal[
    "attosecond",
    "centisecond",
    "day",
    "decisecond",
    "exasecond",
    "femtosecond",
    "gigasecond",
    "hectosecond",
    "hour",
    "kilosecond",
    "megasecond",
    "microsecond",
    "millisecond",
    "minute",
    "nanosecond",
    "petasecond",
    "picosecond",
    "second",
    "terasecond",
    "yoctosecond",
    "yottasecond",
    "zeptosecond",
    "zettasecond",
]
"""Time unit compatible to the [OME-Zarr axes specification 0.5](https://ngff.openmicroscopy.org/0.5/#axes-md)"""

208 

209AxisType = Literal["batch", "channel", "index", "time", "space"] 

210 

211_AXIS_TYPE_MAP: Mapping[str, AxisType] = { 

212 "b": "batch", 

213 "t": "time", 

214 "i": "index", 

215 "c": "channel", 

216 "x": "space", 

217 "y": "space", 

218 "z": "space", 

219} 

220 

221_AXIS_ID_MAP = { 

222 "b": "batch", 

223 "t": "time", 

224 "i": "index", 

225 "c": "channel", 

226} 

227 

228 

class TensorId(LowerCaseIdentifier):
    """A lower-case identifier of a tensor, limited to 32 characters."""

    root_model: ClassVar[Type[RootModel[Any]]] = RootModel[
        Annotated[LowerCaseIdentifierAnno, MaxLen(32)]
    ]

233 

234 

def _normalize_axis_id(a: str):
    """Map a single-letter axis id to its long form (e.g. 'b' -> 'batch').

    Ids not present in `_AXIS_ID_MAP` pass through unchanged. A warning is
    logged (attributed to the caller via `depth=3`) whenever a normalization
    actually takes place.
    """
    given = str(a)
    canonical = _AXIS_ID_MAP.get(given, given)
    if canonical != given:
        logger.opt(depth=3).warning(
            "Normalized axis id from '{}' to '{}'.", given, canonical
        )
    return canonical

243 

244 

class AxisId(LowerCaseIdentifier):
    """A lower-case identifier of an axis, limited to 16 characters.

    Single-letter ids are normalized to their long form via `_normalize_axis_id`
    (e.g. 'b' -> 'batch').
    """

    root_model: ClassVar[Type[RootModel[Any]]] = RootModel[
        Annotated[
            LowerCaseIdentifierAnno,
            MaxLen(16),
            AfterValidator(_normalize_axis_id),
        ]
    ]

253 

254 

255def _is_batch(a: str) -> bool: 

256 return str(a) == "batch" 

257 

258 

259def _is_not_batch(a: str) -> bool: 

260 return not _is_batch(a) 

261 

262 

# An AxisId that is guaranteed not to be the reserved id "batch".
NonBatchAxisId = Annotated[AxisId, Predicate(_is_not_batch)]

# Closed sets of the processing operation ids valid for inputs / outputs.
PreprocessingId = Literal[
    "binarize",
    "clip",
    "ensure_dtype",
    "fixed_zero_mean_unit_variance",
    "scale_linear",
    "scale_range",
    "sigmoid",
    "softmax",
]
PostprocessingId = Literal[
    "binarize",
    "clip",
    "ensure_dtype",
    "fixed_zero_mean_unit_variance",
    "scale_linear",
    "scale_mean_variance",
    "scale_range",
    "sigmoid",
    "softmax",
    "zero_mean_unit_variance",
]


# Sentinel: an axis id that defaults to the axis' type string.
SAME_AS_TYPE = "<same as type>"


ParameterizedSize_N = int
"""
Annotates an integer to calculate a concrete axis size from a `ParameterizedSize`.
"""

296 

297 

class ParameterizedSize(Node):
    """Describes a range of valid tensor axis sizes as `size = min + n*step`.

    - **min** and **step** are given by the model description.
    - All blocksize parameters n = 0,1,2,... yield a valid `size`.
    - A greater blocksize parameter n = 0,1,2,... results in a greater **size**.
      This allows to adjust the axis size more generically.
    """

    N: ClassVar[Type[int]] = ParameterizedSize_N
    """Positive integer to parameterize this axis"""

    min: Annotated[int, Gt(0)]
    step: Annotated[int, Gt(0)]

    def validate_size(self, size: int) -> int:
        """Return `size` unchanged if it lies on the `min + n*step` grid.

        Raises:
            ValueError: if `size` is below `min` or not expressible as
                `min + n*step` for a non-negative integer n.
        """
        if size < self.min:
            raise ValueError(f"size {size} < {self.min}")
        if (size - self.min) % self.step != 0:
            raise ValueError(
                f"axis of size {size} is not parameterized by `min + n*step` ="
                + f" `{self.min} + n*{self.step}`"
            )

        return size

    def get_size(self, n: ParameterizedSize_N) -> int:
        """Return the concrete axis size for blocksize parameter `n`."""
        return self.min + self.step * n

    def get_n(self, s: int) -> ParameterizedSize_N:
        """return smallest n parameterizing a size greater or equal than `s`"""
        # exact integer ceiling division: -((a) // b) with a = min - s equals
        # ceil((s - min) / step) without the float rounding errors that
        # `math.ceil((s - self.min) / self.step)` can exhibit for large ints
        return -((self.min - s) // self.step)

330 

331 

class DataDependentSize(Node):
    """An axis size bounded by `min`/`max` that is only known after model inference."""

    min: Annotated[int, Gt(0)] = 1
    max: Annotated[Optional[int], Gt(1)] = None

    @model_validator(mode="after")
    def _validate_max_gt_min(self):
        # reject degenerate (min == max) or inverted ranges; max=None means unbounded
        if self.max is not None and self.min >= self.max:
            raise ValueError(f"expected `min` < `max`, but got {self.min}, {self.max}")

        return self

    def validate_size(self, size: int) -> int:
        """Return `size` if it lies within [min, max], otherwise raise ValueError."""
        if size < self.min:
            raise ValueError(f"size {size} < {self.min}")

        if self.max is not None and size > self.max:
            raise ValueError(f"size {size} > {self.max}")

        return size

351 

352 

class SizeReference(Node):
    """A tensor axis size (extent in pixels/frames) defined in relation to a reference axis.

    `axis.size = reference.size * reference.scale / axis.scale + offset`

    Note:
    1. The axis and the referenced axis need to have the same unit (or no unit).
    2. Batch axes may not be referenced.
    3. Fractions are rounded down.
    4. If the reference axis is `concatenable` the referencing axis is assumed to be
       `concatenable` as well with the same block order.

    Example:
    An anisotropic input image of w*h=100*49 pixels depicts a physical space of 200*196mm².
    Let's assume that we want to express the image height h in relation to its width w
    instead of only accepting input images of exactly 100*49 pixels
    (for example to express a range of valid image shapes by parametrizing w, see `ParameterizedSize`).

    >>> w = SpaceInputAxis(id=AxisId("w"), size=100, unit="millimeter", scale=2)
    >>> h = SpaceInputAxis(
    ...     id=AxisId("h"),
    ...     size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("w"), offset=-1),
    ...     unit="millimeter",
    ...     scale=4,
    ... )
    >>> print(h.size.get_size(h, w))
    49

    ⇒ h = w * w.scale / h.scale + offset = 100 * 2mm / 4mm - 1 = 49
    """

    tensor_id: TensorId
    """tensor id of the reference axis"""

    axis_id: AxisId
    """axis id of the reference axis"""

    offset: StrictInt = 0

    def get_size(
        self,
        axis: Union[
            ChannelAxis,
            IndexInputAxis,
            IndexOutputAxis,
            TimeInputAxis,
            SpaceInputAxis,
            TimeOutputAxis,
            TimeOutputAxisWithHalo,
            SpaceOutputAxis,
            SpaceOutputAxisWithHalo,
        ],
        ref_axis: Union[
            ChannelAxis,
            IndexInputAxis,
            IndexOutputAxis,
            TimeInputAxis,
            SpaceInputAxis,
            TimeOutputAxis,
            TimeOutputAxisWithHalo,
            SpaceOutputAxis,
            SpaceOutputAxisWithHalo,
        ],
        n: ParameterizedSize_N = 0,
        ref_size: Optional[int] = None,
    ):
        """Compute the concrete size for a given axis and its reference axis.

        Args:
            axis: The axis this [SizeReference][] is the size of.
            ref_axis: The reference axis to compute the size from.
            n: If the **ref_axis** is parameterized (of type `ParameterizedSize`)
                and no fixed **ref_size** is given,
                **n** is used to compute the size of the parameterized **ref_axis**.
            ref_size: Overwrite the reference size instead of deriving it from
                **ref_axis**
                (**ref_axis.scale** is still used; any given **n** is ignored).
        """
        # NOTE(review): internal-contract checks via `assert` are stripped under -O
        assert axis.size == self, (
            "Given `axis.size` is not defined by this `SizeReference`"
        )

        assert ref_axis.id == self.axis_id, (
            f"Expected `ref_axis.id` to be {self.axis_id}, but got {ref_axis.id}."
        )

        assert axis.unit == ref_axis.unit, (
            "`SizeReference` requires `axis` and `ref_axis` to have the same `unit`,"
            f" but {axis.unit}!={ref_axis.unit}"
        )
        if ref_size is None:
            # derive the reference size from the reference axis' own size spec
            if isinstance(ref_axis.size, (int, float)):
                ref_size = ref_axis.size
            elif isinstance(ref_axis.size, ParameterizedSize):
                ref_size = ref_axis.size.get_size(n)
            elif isinstance(ref_axis.size, DataDependentSize):
                raise ValueError(
                    "Reference axis referenced in `SizeReference` may not be a `DataDependentSize`."
                )
            elif isinstance(ref_axis.size, SizeReference):
                raise ValueError(
                    "Reference axis referenced in `SizeReference` may not be sized by a"
                    + " `SizeReference` itself."
                )
            else:
                assert_never(ref_axis.size)

        # int() truncates; sizes are positive, so fractions are rounded down
        # (see class docstring, note 3)
        return int(ref_size * ref_axis.scale / axis.scale + self.offset)

    @staticmethod
    def _get_unit(
        axis: Union[
            ChannelAxis,
            IndexInputAxis,
            IndexOutputAxis,
            TimeInputAxis,
            SpaceInputAxis,
            TimeOutputAxis,
            TimeOutputAxisWithHalo,
            SpaceOutputAxis,
            SpaceOutputAxisWithHalo,
        ],
    ):
        # helper to read the unit of any concrete axis kind
        return axis.unit

477 

478 

class AxisBase(NodeWithExplicitlySetFields):
    """Fields common to all axis descriptions."""

    id: AxisId
    """An axis id unique across all axes of one tensor."""

    description: Annotated[str, MaxLen(128)] = ""
    """A short description of this axis beyond its type and id."""

485 

486 

class WithHalo(Node):
    """Mixin adding a `halo` and a `SizeReference`-based `size` to an output axis."""

    halo: Annotated[int, Ge(1)]
    """The halo should be cropped from the output tensor to avoid boundary effects.
    It is to be cropped from both sides, i.e. `size_after_crop = size - 2 * halo`.
    To document a halo that is already cropped by the model use `size.offset` instead."""

    size: Annotated[
        SizeReference,
        Field(
            examples=[
                10,
                SizeReference(
                    tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
                ).model_dump(mode="json"),
            ]
        ),
    ]
    """reference to another axis with an optional offset (see [SizeReference][])"""

505 

506 

# The reserved axis id identifying the batch axis.
BATCH_AXIS_ID = AxisId("batch")

508 

509 

class BatchAxis(AxisBase):
    """The batch axis (reserved id 'batch'); unit-less and always concatenable."""

    implemented_type: ClassVar[Literal["batch"]] = "batch"
    if TYPE_CHECKING:
        # static view: `type` appears as a defaulted literal field
        type: Literal["batch"] = "batch"
    else:
        # runtime view: no default here; presumably filled from `implemented_type`
        # by NodeWithExplicitlySetFields -- TODO(review): confirm
        type: Literal["batch"]

    id: Annotated[AxisId, Predicate(_is_batch)] = BATCH_AXIS_ID
    size: Optional[Literal[1]] = None
    """The batch size may be fixed to 1,
    otherwise (the default) it may be chosen arbitrarily depending on available memory"""

    @property
    def scale(self):
        # batch entries have no physical scale
        return 1.0

    @property
    def concatenable(self):
        return True

    @property
    def unit(self):
        return None

533 

534 

class ChannelAxis(AxisBase):
    """A channel axis; its size is given implicitly by the number of `channel_names`."""

    implemented_type: ClassVar[Literal["channel"]] = "channel"
    if TYPE_CHECKING:
        type: Literal["channel"] = "channel"
    else:
        type: Literal["channel"]

    id: NonBatchAxisId = AxisId("channel")

    channel_names: NotEmpty[List[Identifier]]

    @property
    def size(self) -> int:
        # one entry per channel name
        return len(self.channel_names)

    @property
    def concatenable(self):
        return False

    @property
    def scale(self) -> float:
        return 1.0

    @property
    def unit(self):
        return None

561 

562 

class IndexAxisBase(AxisBase):
    """Fields common to index axes (unit-less, scale fixed to 1.0)."""

    implemented_type: ClassVar[Literal["index"]] = "index"
    if TYPE_CHECKING:
        type: Literal["index"] = "index"
    else:
        type: Literal["index"]

    id: NonBatchAxisId = AxisId("index")

    @property
    def scale(self) -> float:
        return 1.0

    @property
    def unit(self):
        return None

579 

580 

class _WithInputAxisSize(Node):
    """Mixin adding the `size` field shared by sized input axes."""

    size: Annotated[
        Union[Annotated[int, Gt(0)], ParameterizedSize, SizeReference],
        Field(
            examples=[
                10,
                ParameterizedSize(min=32, step=16).model_dump(mode="json"),
                SizeReference(
                    tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
                ).model_dump(mode="json"),
            ]
        ),
    ]
    """The size/length of this axis can be specified as
    - fixed integer
    - parameterized series of valid sizes ([ParameterizedSize][])
    - reference to another axis with an optional offset ([SizeReference][])
    """

599 

600 

class IndexInputAxis(IndexAxisBase, _WithInputAxisSize):
    """An input index axis with a size specification and optional concatenability."""

    concatenable: bool = False
    """If a model has a `concatenable` input axis, it can be processed blockwise,
    splitting a longer sample axis into blocks matching its input tensor description.
    Output axes are concatenable if they have a [SizeReference][] to a concatenable
    input axis.
    """

608 

609 

class IndexOutputAxis(IndexAxisBase):
    """An output index axis; unlike input axes its size may be data dependent."""

    size: Annotated[
        Union[Annotated[int, Gt(0)], SizeReference, DataDependentSize],
        Field(
            examples=[
                10,
                SizeReference(
                    tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
                ).model_dump(mode="json"),
            ]
        ),
    ]
    """The size/length of this axis can be specified as
    - fixed integer
    - reference to another axis with an optional offset ([SizeReference][])
    - data dependent size using [DataDependentSize][] (size is only known after model inference)
    """

627 

628 

class TimeAxisBase(AxisBase):
    """Fields common to time axes (optional `TimeUnit`, positive `scale`)."""

    implemented_type: ClassVar[Literal["time"]] = "time"
    if TYPE_CHECKING:
        type: Literal["time"] = "time"
    else:
        type: Literal["time"]

    id: NonBatchAxisId = AxisId("time")
    unit: Optional[TimeUnit] = None
    scale: Annotated[float, Gt(0)] = 1.0

639 

640 

class TimeInputAxis(TimeAxisBase, _WithInputAxisSize):
    """An input time axis with a size specification and optional concatenability."""

    concatenable: bool = False
    """If a model has a `concatenable` input axis, it can be processed blockwise,
    splitting a longer sample axis into blocks matching its input tensor description.
    Output axes are concatenable if they have a [SizeReference][] to a concatenable
    input axis.
    """

648 

649 

class SpaceAxisBase(AxisBase):
    """Fields common to space axes (optional `SpaceUnit`, positive `scale`)."""

    implemented_type: ClassVar[Literal["space"]] = "space"
    if TYPE_CHECKING:
        type: Literal["space"] = "space"
    else:
        type: Literal["space"]

    id: Annotated[NonBatchAxisId, Field(examples=["x", "y", "z"])] = AxisId("x")
    unit: Optional[SpaceUnit] = None
    scale: Annotated[float, Gt(0)] = 1.0

660 

661 

class SpaceInputAxis(SpaceAxisBase, _WithInputAxisSize):
    """An input space axis with a size specification and optional concatenability."""

    concatenable: bool = False
    """If a model has a `concatenable` input axis, it can be processed blockwise,
    splitting a longer sample axis into blocks matching its input tensor description.
    Output axes are concatenable if they have a [SizeReference][] to a concatenable
    input axis.
    """

669 

670 

INPUT_AXIS_TYPES = (
    BatchAxis,
    ChannelAxis,
    IndexInputAxis,
    TimeInputAxis,
    SpaceInputAxis,
)
"""intended for isinstance comparisons in py<3.10"""

_InputAxisUnion = Union[
    BatchAxis, ChannelAxis, IndexInputAxis, TimeInputAxis, SpaceInputAxis
]
# tagged union resolved via the literal `type` field of each axis class
InputAxis = Annotated[_InputAxisUnion, Discriminator("type")]

684 

685 

class _WithOutputAxisSize(Node):
    """Mixin adding the `size` field shared by halo-less output axes."""

    size: Annotated[
        Union[Annotated[int, Gt(0)], SizeReference],
        Field(
            examples=[
                10,
                SizeReference(
                    tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
                ).model_dump(mode="json"),
            ]
        ),
    ]
    """The size/length of this axis can be specified as
    - fixed integer
    - reference to another axis with an optional offset (see [SizeReference][])
    """

702 

703 

class TimeOutputAxis(TimeAxisBase, _WithOutputAxisSize):
    """A time output axis without a halo."""

    pass

706 

707 

class TimeOutputAxisWithHalo(TimeAxisBase, WithHalo):
    """A time output axis with a halo (see [WithHalo][])."""

    pass

710 

711 

712def _get_halo_axis_discriminator_value(v: Any) -> Literal["with_halo", "wo_halo"]: 

713 if isinstance(v, dict): 

714 return "with_halo" if "halo" in v else "wo_halo" 

715 else: 

716 return "with_halo" if hasattr(v, "halo") else "wo_halo" 

717 

718 

# Discriminate time output axes by the presence of a `halo` field/key.
_TimeOutputAxisUnion = Annotated[
    Union[
        Annotated[TimeOutputAxis, Tag("wo_halo")],
        Annotated[TimeOutputAxisWithHalo, Tag("with_halo")],
    ],
    Discriminator(_get_halo_axis_discriminator_value),
]

726 

727 

class SpaceOutputAxis(SpaceAxisBase, _WithOutputAxisSize):
    """A space output axis without a halo."""

    pass

730 

731 

class SpaceOutputAxisWithHalo(SpaceAxisBase, WithHalo):
    """A space output axis with a halo (see [WithHalo][])."""

    pass

734 

735 

# Discriminate space output axes by the presence of a `halo` field/key.
_SpaceOutputAxisUnion = Annotated[
    Union[
        Annotated[SpaceOutputAxis, Tag("wo_halo")],
        Annotated[SpaceOutputAxisWithHalo, Tag("with_halo")],
    ],
    Discriminator(_get_halo_axis_discriminator_value),
]

743 

744 

_OutputAxisUnion = Union[
    BatchAxis, ChannelAxis, IndexOutputAxis, _TimeOutputAxisUnion, _SpaceOutputAxisUnion
]
# tagged union resolved via the literal `type` field of each axis class
OutputAxis = Annotated[_OutputAxisUnion, Discriminator("type")]

OUTPUT_AXIS_TYPES = (
    BatchAxis,
    ChannelAxis,
    IndexOutputAxis,
    TimeOutputAxis,
    TimeOutputAxisWithHalo,
    SpaceOutputAxis,
    SpaceOutputAxisWithHalo,
)
"""intended for isinstance comparisons in py<3.10"""


AnyAxis = Union[InputAxis, OutputAxis]

ANY_AXIS_TYPES = INPUT_AXIS_TYPES + OUTPUT_AXIS_TYPES
"""intended for isinstance comparisons in py<3.10"""

# A homogeneous, non-empty list of tensor values (one of four element types).
TVs = Union[
    NotEmpty[List[int]],
    NotEmpty[List[float]],
    NotEmpty[List[bool]],
    NotEmpty[List[str]],
]

773 

774 

# Dtypes admissible for nominal or ordinal tensor data.
NominalOrOrdinalDType = Literal[
    "float32",
    "float64",
    "uint8",
    "int8",
    "uint16",
    "int16",
    "uint32",
    "int32",
    "uint64",
    "int64",
    "bool",
]

788 

789 

class NominalOrOrdinalDataDescr(Node):
    """Description of tensor data restricted to a fixed set of (nominal/ordinal) values."""

    values: TVs
    """A fixed set of nominal or an ascending sequence of ordinal values.
    In this case `data.type` is required to be an unsigned integer type, e.g. 'uint8'.
    String `values` are interpreted as labels for tensor values 0, ..., N.
    Note: as YAML 1.2 does not natively support a "set" datatype,
    nominal values should be given as a sequence (aka list/array) as well.
    """

    type: Annotated[
        NominalOrOrdinalDType,
        Field(
            examples=[
                "float32",
                "uint8",
                "uint16",
                "int64",
                "bool",
            ],
        ),
    ] = "uint8"

    @model_validator(mode="after")
    def _validate_values_match_type(
        self,
    ) -> Self:
        # collect values that cannot be represented by `self.type`
        incompatible: List[Any] = []
        for v in self.values:
            if self.type == "bool":
                if not isinstance(v, bool):
                    incompatible.append(v)
            elif self.type in DTYPE_LIMITS:
                if (
                    isinstance(v, (int, float))
                    and (
                        v < DTYPE_LIMITS[self.type].min
                        or v > DTYPE_LIMITS[self.type].max
                    )
                    # string labels are only valid for unsigned integer types
                    or (isinstance(v, str) and "uint" not in self.type)
                    # float values cannot be stored in (u)int types
                    or (isinstance(v, float) and "int" in self.type)
                ):
                    incompatible.append(v)
            else:
                incompatible.append(v)

            # report at most 5 offending values, then abbreviate
            if len(incompatible) == 5:
                incompatible.append("...")
                break

        if incompatible:
            raise ValueError(
                f"data type '{self.type}' incompatible with values {incompatible}"
            )

        return self

    unit: Optional[Union[Literal["arbitrary unit"], SiUnit]] = None

    @property
    def range(self):
        # string values are labels for tensor values 0..N-1
        if isinstance(self.values[0], str):
            return 0, len(self.values) - 1
        else:
            return min(self.values), max(self.values)

854 

855 

# Dtypes admissible for interval/ratio tensor data (no bool).
IntervalOrRatioDType = Literal[
    "float32",
    "float64",
    "uint8",
    "int8",
    "uint16",
    "int16",
    "uint32",
    "int32",
    "uint64",
    "int64",
]

868 

869 

class IntervalOrRatioDataDescr(Node):
    """Description of continuous tensor data on an interval or ratio scale."""

    type: Annotated[  # TODO: rename to dtype
        IntervalOrRatioDType,
        Field(
            examples=["float32", "float64", "uint8", "uint16"],
        ),
    ] = "float32"
    range: Tuple[Optional[float], Optional[float]] = (
        None,
        None,
    )
    """Tuple `(minimum, maximum)` specifying the allowed range of the data in this tensor.
    `None` corresponds to min/max of what can be expressed by **type**."""
    unit: Union[Literal["arbitrary unit"], SiUnit] = "arbitrary unit"
    scale: float = 1.0
    """Scale for data on an interval (or ratio) scale."""
    offset: Optional[float] = None
    """Offset for data on a ratio scale."""

    @model_validator(mode="before")
    def _replace_inf(cls, data: Any):
        # YAML/JSON inputs may encode unbounded range ends as (string or float)
        # infinity; replace those with None (= dtype min/max) and warn.
        if is_dict(data):
            if "range" in data and is_sequence(data["range"]):
                forbidden = (
                    "inf",
                    "-inf",
                    ".inf",
                    "-.inf",
                    float("inf"),
                    float("-inf"),
                )
                if any(v in forbidden for v in data["range"]):
                    issue_warning("replaced 'inf' value", value=data["range"])

                data["range"] = tuple(
                    (None if v in forbidden else v) for v in data["range"]
                )

        return data

909 

910 

# Any tensor data description: discrete (nominal/ordinal) or continuous (interval/ratio).
TensorDataDescr = Union[NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr]

912 

913 

class BinarizeKwargs(KwargsNode):
    """key word arguments for [BinarizeDescr][]"""

    threshold: float
    """The fixed threshold"""

919 

920 

class BinarizeAlongAxisKwargs(KwargsNode):
    """key word arguments for [BinarizeDescr][] (one threshold per entry along `axis`)"""

    threshold: NotEmpty[List[float]]
    """The fixed threshold values along `axis`"""

    axis: Annotated[NonBatchAxisId, Field(examples=["channel"])]
    """The `threshold` axis"""

929 

930 

class BinarizeDescr(NodeWithExplicitlySetFields):
    """Binarize the tensor with a fixed threshold.

    Values above [BinarizeKwargs.threshold][]/[BinarizeAlongAxisKwargs.threshold][]
    will be set to one, values below the threshold to zero.

    Examples:
    - in YAML
        ```yaml
        postprocessing:
          - id: binarize
            kwargs:
              axis: 'channel'
              threshold: [0.25, 0.5, 0.75]
        ```
    - in Python:
        >>> postprocessing = [BinarizeDescr(
        ...   kwargs=BinarizeAlongAxisKwargs(
        ...       axis=AxisId('channel'),
        ...       threshold=[0.25, 0.5, 0.75],
        ...   )
        ... )]
    """

    implemented_id: ClassVar[Literal["binarize"]] = "binarize"
    if TYPE_CHECKING:
        id: Literal["binarize"] = "binarize"
    else:
        id: Literal["binarize"]
    # single global threshold or one threshold per entry along an axis
    kwargs: Union[BinarizeKwargs, BinarizeAlongAxisKwargs]

961 

962 

class ClipKwargs(KwargsNode):
    """key word arguments for [ClipDescr][]"""

    min: Optional[float] = None
    """Minimum value for clipping.

    Exclusive with [min_percentile][]
    """
    min_percentile: Optional[Annotated[float, Interval(ge=0, lt=100)]] = None
    """Minimum percentile for clipping.

    Exclusive with [min][].

    In range [0, 100).
    """

    max: Optional[float] = None
    """Maximum value for clipping.

    Exclusive with `max_percentile`.
    """
    max_percentile: Optional[Annotated[float, Interval(gt=1, le=100)]] = None
    """Maximum percentile for clipping.

    Exclusive with `max`.

    In range (1, 100].
    """

    axes: Annotated[
        Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
    ] = None
    """The subset of axes to determine percentiles jointly,

    i.e. axes to reduce to compute min/max from `min_percentile`/`max_percentile`.
    For example to clip 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
    resulting in a tensor of equal shape with clipped values per channel, specify `axes=('batch', 'x', 'y')`.
    To clip samples independently, leave out the 'batch' axis.

    Only valid if `min_percentile` and/or `max_percentile` are set.

    Default: Compute percentiles over all axes jointly."""

    @model_validator(mode="after")
    def _validate(self) -> Self:
        # enforce the pairwise exclusivity documented on the fields above
        if (self.min is not None) and (self.min_percentile is not None):
            raise ValueError(
                "Only one of `min` and `min_percentile` may be set, not both."
            )
        if (self.max is not None) and (self.max_percentile is not None):
            raise ValueError(
                "Only one of `max` and `max_percentile` may be set, not both."
            )
        # at least one bound is required, otherwise clipping is a no-op
        if (
            self.min is None
            and self.min_percentile is None
            and self.max is None
            and self.max_percentile is None
        ):
            raise ValueError(
                "At least one of `min`, `min_percentile`, `max`, or `max_percentile` must be set."
            )

        # `axes` only has meaning for percentile computation
        if (
            self.axes is not None
            and self.min_percentile is None
            and self.max_percentile is None
        ):
            raise ValueError(
                "If `axes` is set, at least one of `min_percentile` or `max_percentile` must be set."
            )

        return self

1036 

1037 

class ClipDescr(NodeWithExplicitlySetFields):
    """Set tensor values below min to min and above max to max.

    See `ScaleRangeDescr` for examples.
    """

    implemented_id: ClassVar[Literal["clip"]] = "clip"
    if TYPE_CHECKING:
        id: Literal["clip"] = "clip"
    else:
        id: Literal["clip"]

    kwargs: ClipKwargs

1051 

1052 

class EnsureDtypeKwargs(KwargsNode):
    """key word arguments for [EnsureDtypeDescr][]"""

    # target data type to cast to (all admissible tensor dtypes incl. bool)
    dtype: Literal[
        "float32",
        "float64",
        "uint8",
        "int8",
        "uint16",
        "int16",
        "uint32",
        "int32",
        "uint64",
        "int64",
        "bool",
    ]

1069 

1070 

class EnsureDtypeDescr(NodeWithExplicitlySetFields):
    """Cast the tensor data type to `EnsureDtypeKwargs.dtype` (if not matching).

    This can for example be used to ensure the inner neural network model gets a
    different input tensor data type than the fully described bioimage.io model does.

    Examples:
        The described bioimage.io model (incl. preprocessing) accepts any
        float32-compatible tensor, normalizes it with percentiles and clipping and then
        casts it to uint8, which is what the neural network in this example expects.
        - in YAML
            ```yaml
            inputs:
            - data:
                type: float32  # described bioimage.io model is compatible with any float32 input tensor
              preprocessing:
              - id: scale_range
                  kwargs:
                  axes: ['y', 'x']
                  max_percentile: 99.8
                  min_percentile: 5.0
              - id: clip
                  kwargs:
                  min: 0.0
                  max: 1.0
              - id: ensure_dtype  # the neural network of the model requires uint8
                  kwargs:
                  dtype: uint8
            ```
        - in Python:
            >>> preprocessing = [
            ...     ScaleRangeDescr(
            ...         kwargs=ScaleRangeKwargs(
            ...           axes= (AxisId('y'), AxisId('x')),
            ...           max_percentile= 99.8,
            ...           min_percentile= 5.0,
            ...         )
            ...     ),
            ...     ClipDescr(kwargs=ClipKwargs(min=0.0, max=1.0)),
            ...     EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="uint8")),
            ... ]
    """

    implemented_id: ClassVar[Literal["ensure_dtype"]] = "ensure_dtype"
    if TYPE_CHECKING:
        id: Literal["ensure_dtype"] = "ensure_dtype"
    else:
        id: Literal["ensure_dtype"]

    kwargs: EnsureDtypeKwargs

1121 

1122 

class ScaleLinearKwargs(KwargsNode):
    """Key word arguments for [ScaleLinearDescr][]"""

    gain: float = 1.0
    """multiplicative factor"""

    offset: float = 0.0
    """additive term"""

    @model_validator(mode="after")
    def _validate(self) -> Self:
        """Reject the identity transformation `out = 1.0 * tensor + 0.0`.

        Raises:
            ValueError: if both `gain` and `offset` are left at their neutral
                defaults, i.e. the scaling would be a no-op.
        """
        if self.gain == 1.0 and self.offset == 0.0:
            # fix: error message previously read "allowd"
            raise ValueError(
                "Redundant linear scaling not allowed. Set `gain` != 1.0 and/or"
                + " `offset` != 0.0."
            )

        return self

1141 

1142 

class ScaleLinearAlongAxisKwargs(KwargsNode):
    """Key word arguments for [ScaleLinearDescr][]"""

    axis: Annotated[NonBatchAxisId, Field(examples=["channel"])]
    """The axis of gain and offset values."""

    gain: Union[float, NotEmpty[List[float]]] = 1.0
    """multiplicative factor"""

    offset: Union[float, NotEmpty[List[float]]] = 0.0
    """additive term"""

    @model_validator(mode="after")
    def _validate(self) -> Self:
        """Normalize `gain`/`offset` to equally sized lists and reject no-ops.

        A scalar `gain` (or `offset`) is broadcast to match the length of its
        list-valued counterpart.

        Raises:
            ValueError: if `gain` and `offset` are lists of different length,
                if both are scalars (then `axis` is superfluous), or if the
                scaling is the identity transformation.
        """
        if isinstance(self.gain, list):
            if isinstance(self.offset, list):
                if len(self.gain) != len(self.offset):
                    raise ValueError(
                        f"Size of `gain` ({len(self.gain)}) and `offset` ({len(self.offset)}) must match."
                    )
            else:
                # broadcast scalar offset to the length of the gain list
                self.offset = [float(self.offset)] * len(self.gain)
        elif isinstance(self.offset, list):
            # broadcast scalar gain to the length of the offset list
            self.gain = [float(self.gain)] * len(self.offset)
        else:
            raise ValueError(
                "Do not specify an `axis` for scalar gain and offset values."
            )

        # at this point both are lists of equal length
        if all(g == 1.0 for g in self.gain) and all(off == 0.0 for off in self.offset):
            # fix: error message previously read "allowd"
            raise ValueError(
                "Redundant linear scaling not allowed. Set `gain` != 1.0 and/or"
                + " `offset` != 0.0."
            )

        return self

1179 

1180 

class ScaleLinearDescr(NodeWithExplicitlySetFields):
    """Fixed linear scaling.

    Examples:
      1. Scale with scalar gain and offset
        - in YAML
        ```yaml
        preprocessing:
          - id: scale_linear
            kwargs:
              gain: 2.0
              offset: 3.0
        ```
        - in Python:
        >>> preprocessing = [
        ...     ScaleLinearDescr(kwargs=ScaleLinearKwargs(gain= 2.0, offset=3.0))
        ... ]

      2. Independent scaling along an axis
        - in YAML
        ```yaml
        preprocessing:
          - id: scale_linear
            kwargs:
              axis: 'channel'
              gain: [1.0, 2.0, 3.0]
        ```
        - in Python:
        >>> preprocessing = [
        ...     ScaleLinearDescr(
        ...         kwargs=ScaleLinearAlongAxisKwargs(
        ...             axis=AxisId("channel"),
        ...             gain=[1.0, 2.0, 3.0],
        ...         )
        ...     )
        ... ]

    """

    implemented_id: ClassVar[Literal["scale_linear"]] = "scale_linear"
    # discriminator field; required at runtime, defaulted for type checkers
    if TYPE_CHECKING:
        id: Literal["scale_linear"] = "scale_linear"
    else:
        id: Literal["scale_linear"]
    # scalar kwargs or per-axis kwargs (the latter carry an `axis` field)
    kwargs: Union[ScaleLinearKwargs, ScaleLinearAlongAxisKwargs]

1226 

1227 

class SigmoidDescr(NodeWithExplicitlySetFields):
    """The logistic sigmoid function, a.k.a. expit function.

    Examples:
        - in YAML
        ```yaml
        postprocessing:
          - id: sigmoid
        ```
        - in Python:
        >>> postprocessing = [SigmoidDescr()]
    """

    implemented_id: ClassVar[Literal["sigmoid"]] = "sigmoid"
    # discriminator field; required at runtime, defaulted for type checkers
    if TYPE_CHECKING:
        id: Literal["sigmoid"] = "sigmoid"
    else:
        id: Literal["sigmoid"]

    @property
    def kwargs(self) -> KwargsNode:
        """empty kwargs"""
        # sigmoid takes no arguments; expose an empty node so callers can
        # treat all processing descriptions uniformly
        return KwargsNode()

1251 

1252 

class SoftmaxKwargs(KwargsNode):
    """key word arguments for [SoftmaxDescr][]"""

    axis: Annotated[NonBatchAxisId, Field(examples=["channel"])] = AxisId("channel")
    """The axis to apply the softmax function along.
    Note:
        Defaults to 'channel' axis
        (which may not exist, in which case
        a different axis id has to be specified).
    """

1263 

1264 

class SoftmaxDescr(NodeWithExplicitlySetFields):
    """The softmax function.

    Examples:
        - in YAML
        ```yaml
        postprocessing:
          - id: softmax
            kwargs:
              axis: channel
        ```
        - in Python:
        >>> postprocessing = [SoftmaxDescr(kwargs=SoftmaxKwargs(axis=AxisId("channel")))]
    """

    implemented_id: ClassVar[Literal["softmax"]] = "softmax"
    # discriminator field; required at runtime, defaulted for type checkers
    if TYPE_CHECKING:
        id: Literal["softmax"] = "softmax"
    else:
        id: Literal["softmax"]

    # `model_construct` skips validation when building the default instance
    kwargs: SoftmaxKwargs = Field(default_factory=SoftmaxKwargs.model_construct)

1287 

1288 

class FixedZeroMeanUnitVarianceKwargs(KwargsNode):
    """key word arguments for [FixedZeroMeanUnitVarianceDescr][]"""

    mean: float
    """The mean value to normalize with."""

    std: Annotated[float, Ge(1e-6)]
    """The standard deviation value to normalize with."""

1297 

1298 

class FixedZeroMeanUnitVarianceAlongAxisKwargs(KwargsNode):
    """key word arguments for [FixedZeroMeanUnitVarianceDescr][]"""

    mean: NotEmpty[List[float]]
    """The mean value(s) to normalize with."""

    std: NotEmpty[List[Annotated[float, Ge(1e-6)]]]
    """The standard deviation value(s) to normalize with.
    Size must match `mean` values."""

    axis: Annotated[NonBatchAxisId, Field(examples=["channel", "index"])]
    """The axis of the mean/std values to normalize each entry along that dimension
    separately."""

    @model_validator(mode="after")
    def _mean_and_std_match(self) -> Self:
        """Ensure `mean` and `std` describe the same number of entries."""
        if len(self.mean) != len(self.std):
            raise ValueError(
                f"Size of `mean` ({len(self.mean)}) and `std` ({len(self.std)})"
                + " must match."
            )

        return self

1322 

1323 

class FixedZeroMeanUnitVarianceDescr(NodeWithExplicitlySetFields):
    """Subtract a given mean and divide by the standard deviation.

    Normalize with fixed, precomputed values for
    `FixedZeroMeanUnitVarianceKwargs.mean` and `FixedZeroMeanUnitVarianceKwargs.std`
    Use `FixedZeroMeanUnitVarianceAlongAxisKwargs` for independent scaling along given
    axes.

    Examples:
    1. scalar value for whole tensor
        - in YAML
        ```yaml
        preprocessing:
          - id: fixed_zero_mean_unit_variance
            kwargs:
              mean: 103.5
              std: 13.7
        ```
        - in Python
        >>> preprocessing = [FixedZeroMeanUnitVarianceDescr(
        ...     kwargs=FixedZeroMeanUnitVarianceKwargs(mean=103.5, std=13.7)
        ... )]

    2. independently along an axis
        - in YAML
        ```yaml
        preprocessing:
          - id: fixed_zero_mean_unit_variance
            kwargs:
              axis: channel
              mean: [101.5, 102.5, 103.5]
              std: [11.7, 12.7, 13.7]
        ```
        - in Python
        >>> preprocessing = [FixedZeroMeanUnitVarianceDescr(
        ...     kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs(
        ...         axis=AxisId("channel"),
        ...         mean=[101.5, 102.5, 103.5],
        ...         std=[11.7, 12.7, 13.7],
        ...     )
        ... )]
    """

    implemented_id: ClassVar[Literal["fixed_zero_mean_unit_variance"]] = (
        "fixed_zero_mean_unit_variance"
    )
    # discriminator field; required at runtime, defaulted for type checkers
    if TYPE_CHECKING:
        id: Literal["fixed_zero_mean_unit_variance"] = "fixed_zero_mean_unit_variance"
    else:
        id: Literal["fixed_zero_mean_unit_variance"]

    # scalar kwargs or per-axis kwargs (the latter carry an `axis` field)
    kwargs: Union[
        FixedZeroMeanUnitVarianceKwargs, FixedZeroMeanUnitVarianceAlongAxisKwargs
    ]

1378 

1379 

class ZeroMeanUnitVarianceKwargs(KwargsNode):
    """key word arguments for [ZeroMeanUnitVarianceDescr][]"""

    axes: Annotated[
        Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
    ] = None
    """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std.
    For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
    resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
    To normalize each sample independently leave out the 'batch' axis.
    Default: Scale all axes jointly."""

    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
    """epsilon for numeric stability: `out = (tensor - mean) / (std + eps)`."""

1394 

1395 

class ZeroMeanUnitVarianceDescr(NodeWithExplicitlySetFields):
    """Subtract mean and divide by variance.

    Examples:
        Subtract tensor mean and variance
        - in YAML
        ```yaml
        preprocessing:
          - id: zero_mean_unit_variance
        ```
        - in Python
        >>> preprocessing = [ZeroMeanUnitVarianceDescr()]
    """

    implemented_id: ClassVar[Literal["zero_mean_unit_variance"]] = (
        "zero_mean_unit_variance"
    )
    # discriminator field; required at runtime, defaulted for type checkers
    if TYPE_CHECKING:
        id: Literal["zero_mean_unit_variance"] = "zero_mean_unit_variance"
    else:
        id: Literal["zero_mean_unit_variance"]

    # `model_construct` skips validation when building the default instance
    kwargs: ZeroMeanUnitVarianceKwargs = Field(
        default_factory=ZeroMeanUnitVarianceKwargs.model_construct
    )

1421 

1422 

class ScaleRangeKwargs(KwargsNode):
    """key word arguments for [ScaleRangeDescr][]

    For `min_percentile`=0.0 (the default) and `max_percentile`=100 (the default)
    this processing step normalizes data to the [0, 1] interval.
    For other percentiles the normalized values will partially be outside the [0, 1]
    interval. Use `ScaleRange` followed by `ClipDescr` if you want to limit the
    normalized values to a range.
    """

    axes: Annotated[
        Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
    ] = None
    """The subset of axes to normalize jointly, i.e. axes to reduce to compute the min/max percentile value.
    For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
    resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
    To normalize samples independently, leave out the "batch" axis.
    Default: Scale all axes jointly."""

    min_percentile: Annotated[float, Interval(ge=0, lt=100)] = 0.0
    """The lower percentile used to determine the value to align with zero."""

    max_percentile: Annotated[float, Interval(gt=1, le=100)] = 100.0
    """The upper percentile used to determine the value to align with one.
    Has to be bigger than `min_percentile`.
    The range is 1 to 100 instead of 0 to 100 to avoid mistakenly
    accepting percentiles specified in the range 0.0 to 1.0."""

    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
    """Epsilon for numeric stability.
    `out = (tensor - v_lower) / (v_upper - v_lower + eps)`;
    with `v_lower,v_upper` values at the respective percentiles."""

    reference_tensor: Optional[TensorId] = None
    """Tensor ID to compute the percentiles from. Default: The tensor itself.
    For any tensor in `inputs` only input tensor references are allowed."""

    @field_validator("max_percentile", mode="after")
    @classmethod
    def min_smaller_max(cls, value: float, info: ValidationInfo) -> float:
        """Ensure `min_percentile` < `max_percentile`.

        Raises:
            ValueError: if `min_percentile` is not smaller than `max_percentile`.
        """
        # use `.get` with the field default: if `min_percentile` itself failed
        # validation it is absent from `info.data` and indexing would raise a
        # confusing KeyError instead of the dedicated field error
        if (min_p := info.data.get("min_percentile", 0.0)) >= value:
            raise ValueError(f"min_percentile {min_p} >= max_percentile {value}")

        return value

1467 

1468 

class ScaleRangeDescr(NodeWithExplicitlySetFields):
    """Scale with percentiles.

    Examples:
    1. Scale linearly to map 5th percentile to 0 and 99.8th percentile to 1.0
        - in YAML
        ```yaml
        preprocessing:
          - id: scale_range
            kwargs:
              axes: ['y', 'x']
              max_percentile: 99.8
              min_percentile: 5.0
        ```
        - in Python
        >>> preprocessing = [
        ...     ScaleRangeDescr(
        ...         kwargs=ScaleRangeKwargs(
        ...           axes= (AxisId('y'), AxisId('x')),
        ...           max_percentile= 99.8,
        ...           min_percentile= 5.0,
        ...         )
        ...     )
        ... ]

    2. Combine the above scaling with additional clipping to clip values outside the range given by the percentiles.
        - in YAML
        ```yaml
        preprocessing:
          - id: scale_range
            kwargs:
              axes: ['y', 'x']
              max_percentile: 99.8
              min_percentile: 5.0
          - id: clip
            kwargs:
              min: 0.0
              max: 1.0
        ```
        - in Python
        >>> preprocessing = [
        ...     ScaleRangeDescr(
        ...         kwargs=ScaleRangeKwargs(
        ...           axes= (AxisId('y'), AxisId('x')),
        ...           max_percentile= 99.8,
        ...           min_percentile= 5.0,
        ...         )
        ...     ),
        ...     ClipDescr(
        ...         kwargs=ClipKwargs(
        ...             min=0.0,
        ...             max=1.0,
        ...         )
        ...     ),
        ... ]

    """

    implemented_id: ClassVar[Literal["scale_range"]] = "scale_range"
    # discriminator field; required at runtime, defaulted for type checkers
    if TYPE_CHECKING:
        id: Literal["scale_range"] = "scale_range"
    else:
        id: Literal["scale_range"]
    kwargs: ScaleRangeKwargs = Field(default_factory=ScaleRangeKwargs.model_construct)

1534 

1535 

class ScaleMeanVarianceKwargs(KwargsNode):
    """key word arguments for [ScaleMeanVarianceDescr][]"""

    reference_tensor: TensorId
    """Name of tensor to match."""

    axes: Annotated[
        Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
    ] = None
    """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std.
    For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
    resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
    To normalize samples independently, leave out the 'batch' axis.
    Default: Scale all axes jointly."""

    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
    """Epsilon for numeric stability:
    `out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.`"""

1554 

1555 

class ScaleMeanVarianceDescr(NodeWithExplicitlySetFields):
    """Scale a tensor's data distribution to match another tensor's mean/std.
    `out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.`
    """

    implemented_id: ClassVar[Literal["scale_mean_variance"]] = "scale_mean_variance"
    # discriminator field; required at runtime, defaulted for type checkers
    if TYPE_CHECKING:
        id: Literal["scale_mean_variance"] = "scale_mean_variance"
    else:
        id: Literal["scale_mean_variance"]
    kwargs: ScaleMeanVarianceKwargs

1567 

1568 

# Union of all processing steps usable as preprocessing,
# discriminated by their `id` field.
PreprocessingDescr = Annotated[
    Union[
        BinarizeDescr,
        ClipDescr,
        EnsureDtypeDescr,
        FixedZeroMeanUnitVarianceDescr,
        ScaleLinearDescr,
        ScaleRangeDescr,
        SigmoidDescr,
        SoftmaxDescr,
        ZeroMeanUnitVarianceDescr,
    ],
    Discriminator("id"),
]
# Union of all processing steps usable as postprocessing; additionally
# includes `ScaleMeanVarianceDescr` compared to `PreprocessingDescr`.
PostprocessingDescr = Annotated[
    Union[
        BinarizeDescr,
        ClipDescr,
        EnsureDtypeDescr,
        FixedZeroMeanUnitVarianceDescr,
        ScaleLinearDescr,
        ScaleMeanVarianceDescr,
        ScaleRangeDescr,
        SigmoidDescr,
        SoftmaxDescr,
        ZeroMeanUnitVarianceDescr,
    ],
    Discriminator("id"),
]

# axis type parameter for `TensorDescrBase`: constrained to either
# `InputAxis` (input tensors) or `OutputAxis` (output tensors)
IO_AxisT = TypeVar("IO_AxisT", InputAxis, OutputAxis)

1600 

1601 

class TensorDescrBase(Node, Generic[IO_AxisT]):
    """Common base for input and output tensor descriptions."""

    id: TensorId
    """Tensor id. No duplicates are allowed."""

    description: Annotated[str, MaxLen(128)] = ""
    """free text description"""

    axes: NotEmpty[Sequence[IO_AxisT]]
    """tensor axes"""

    @property
    def shape(self):
        """Tuple of the axes' `size` entries (ints or size descriptions)."""
        return tuple(a.size for a in self.axes)

    @field_validator("axes", mode="after", check_fields=False)
    @classmethod
    def _validate_axes(cls, axes: Sequence[AnyAxis]) -> Sequence[AnyAxis]:
        """Allow at most one batch axis and forbid duplicate axis ids."""
        batch_axes = [a for a in axes if a.type == "batch"]
        if len(batch_axes) > 1:
            raise ValueError(
                f"Only one batch axis (per tensor) allowed, but got {batch_axes}"
            )

        seen_ids: Set[AxisId] = set()
        duplicate_axes_ids: Set[AxisId] = set()
        for a in axes:
            # first occurrence goes to `seen_ids`, any repeat to `duplicate_axes_ids`
            (duplicate_axes_ids if a.id in seen_ids else seen_ids).add(a.id)

        if duplicate_axes_ids:
            raise ValueError(f"Duplicate axis ids: {duplicate_axes_ids}")

        return axes

    test_tensor: FAIR[Optional[FileDescr_]] = None
    """An example tensor to use for testing.
    Using the model with the test input tensors is expected to yield the test output tensors.
    Each test tensor has to be an ndarray in the
    [numpy.lib file format](https://numpy.org/doc/stable/reference/generated/numpy.lib.format.html#module-numpy.lib.format).
    The file extension must be '.npy'."""

    sample_tensor: FAIR[Optional[FileDescr_]] = None
    """A sample tensor to illustrate a possible input/output for the model,
    The sample image primarily serves to inform a human user about an example use case
    and is typically stored as .hdf5, .png or .tiff.
    It has to be readable by the [imageio library](https://imageio.readthedocs.io/en/stable/formats/index.html#supported-formats)
    (numpy's `.npy` format is not supported).
    The image dimensionality has to match the number of axes specified in this tensor description.
    """

    @model_validator(mode="after")
    def _validate_sample_tensor(self) -> Self:
        """Read the sample image and check its dimensionality against `axes`.

        Skipped if no sample tensor is given or IO checks are disabled in the
        current validation context.
        """
        if self.sample_tensor is None or not get_validation_context().perform_io_checks:
            return self

        reader = get_reader(self.sample_tensor.source, sha256=self.sample_tensor.sha256)
        tensor: NDArray[Any] = imread(  # pyright: ignore[reportUnknownVariableType]
            reader.read(),
            extension=PurePosixPath(reader.original_file_name).suffix,
        )
        # compare against the squeezed shape: each axis that may be a
        # singleton (and thus squeezed away) lowers the minimal expected
        # dimensionality by one
        n_dims = len(tensor.squeeze().shape)
        n_dims_min = n_dims_max = len(self.axes)

        for a in self.axes:
            if isinstance(a, BatchAxis):
                n_dims_min -= 1
            elif isinstance(a.size, int):
                if a.size == 1:
                    n_dims_min -= 1
            elif isinstance(a.size, (ParameterizedSize, DataDependentSize)):
                if a.size.min == 1:
                    n_dims_min -= 1
            elif isinstance(a.size, SizeReference):
                if a.size.offset < 2:
                    # size reference may result in singleton axis
                    n_dims_min -= 1
            else:
                assert_never(a.size)

        n_dims_min = max(0, n_dims_min)
        if n_dims < n_dims_min or n_dims > n_dims_max:
            raise ValueError(
                f"Expected sample tensor to have {n_dims_min} to"
                + f" {n_dims_max} dimensions, but found {n_dims} (shape: {tensor.shape})."
            )

        return self

    data: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]] = (
        IntervalOrRatioDataDescr()
    )
    """Description of the tensor's data values, optionally per channel.
    If specified per channel, the data `type` needs to match across channels."""

    @property
    def dtype(
        self,
    ) -> Literal[
        "float32",
        "float64",
        "uint8",
        "int8",
        "uint16",
        "int16",
        "uint32",
        "int32",
        "uint64",
        "int64",
        "bool",
    ]:
        """dtype as specified under `data.type` or `data[i].type`"""
        # per-channel data descriptions agree in `type` (enforced by
        # `_check_data_type_across_channels`), so the first entry suffices
        if isinstance(self.data, collections.abc.Sequence):
            return self.data[0].type
        else:
            return self.data.type

    @field_validator("data", mode="after")
    @classmethod
    def _check_data_type_across_channels(
        cls, value: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]]
    ) -> Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]]:
        """If `data` is given per channel, all entries must share one `type`."""
        if not isinstance(value, list):
            return value

        dtypes = {t.type for t in value}
        if len(dtypes) > 1:
            raise ValueError(
                "Tensor data descriptions per channel need to agree in their data"
                + f" `type`, but found {dtypes}."
            )

        return value

    @model_validator(mode="after")
    def _check_data_matches_channelaxis(self) -> Self:
        """If `data` is given per channel, its length must equal the channel axis size."""
        if not isinstance(self.data, (list, tuple)):
            return self

        for a in self.axes:
            if isinstance(a, ChannelAxis):
                size = a.size
                assert isinstance(size, int)
                break
        else:
            # no channel axis present: nothing to compare against
            return self

        if len(self.data) != size:
            raise ValueError(
                f"Got tensor data descriptions for {len(self.data)} channels, but"
                + f" '{a.id}' axis has size {size}."
            )

        return self

    def get_axis_sizes_for_array(self, array: NDArray[Any]) -> Dict[AxisId, int]:
        """Map each axis id to the corresponding dimension size of `array`.

        Raises:
            ValueError: if `array`'s number of dimensions does not match `axes`.
        """
        if len(array.shape) != len(self.axes):
            raise ValueError(
                f"Dimension mismatch: array shape {array.shape} (#{len(array.shape)})"
                + f" incompatible with {len(self.axes)} axes."
            )
        return {a.id: array.shape[i] for i, a in enumerate(self.axes)}

1762 

1763 

class InputTensorDescr(TensorDescrBase[InputAxis]):
    id: TensorId = TensorId("input")
    """Input tensor id.
    No duplicates are allowed across all inputs and outputs."""

    optional: bool = False
    """indicates that this tensor may be `None`"""

    preprocessing: List[PreprocessingDescr] = Field(
        default_factory=cast(Callable[[], List[PreprocessingDescr]], list)
    )

    """Description of how this input should be preprocessed.

    notes:
    - If preprocessing does not start with an 'ensure_dtype' entry, it is added
      to ensure an input tensor's data type matches the input tensor's data description.
    - If preprocessing does not end with an 'ensure_dtype' or 'binarize' entry, an
      'ensure_dtype' step is added to ensure preprocessing steps are not unintentionally
      changing the data type.
    """

    @model_validator(mode="after")
    def _validate_preprocessing_kwargs(self) -> Self:
        """Check `kwargs.axes` of each step and wrap `preprocessing` in dtype guards.

        Any `axes` argument of a preprocessing step must be a subset of this
        tensor's axis ids. Additionally, `ensure_dtype` steps are inserted at
        the start/end of `preprocessing` as described in the field docstring
        (the list is mutated in place).
        """
        axes_ids = [a.id for a in self.axes]
        for p in self.preprocessing:
            kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes")
            if kwargs_axes is None:
                continue

            if not isinstance(kwargs_axes, collections.abc.Sequence):
                raise ValueError(
                    f"Expected `preprocessing.i.kwargs.axes` to be a sequence, but got {type(kwargs_axes)}"
                )

            if any(a not in axes_ids for a in kwargs_axes):
                raise ValueError(
                    "`preprocessing.i.kwargs.axes` needs to be subset of axes ids"
                )

        if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)):
            dtype = self.data.type
        else:
            # per-channel data descriptions agree in `type`
            # (validated in `TensorDescrBase`), so the first entry suffices
            dtype = self.data[0].type

        # ensure `preprocessing` begins with `EnsureDtypeDescr`
        if not self.preprocessing or not isinstance(
            self.preprocessing[0], EnsureDtypeDescr
        ):
            self.preprocessing.insert(
                0, EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype))
            )

        # ensure `preprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr`
        if not isinstance(self.preprocessing[-1], (EnsureDtypeDescr, BinarizeDescr)):
            self.preprocessing.append(
                EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype))
            )

        return self

1824 

1825 

def convert_axes(
    axes: str,
    *,
    shape: Union[
        Sequence[int], _ParameterizedInputShape_v0_4, _ImplicitOutputShape_v0_4
    ],
    tensor_type: Literal["input", "output"],
    halo: Optional[Sequence[int]],
    size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]],
):
    """Convert v0.4 axes letters plus shape information to v0.5 axis descriptions.

    Args:
        axes: v0.4 axes string, one letter per axis (e.g. "bczyx").
        shape: fixed shape, parameterized input shape, or implicit output shape
            (the latter references another tensor's shape).
        tensor_type: whether input or output axis descriptions are created.
        halo: per-axis halo values for output tensors, if any.
        size_refs: known axis sizes of other tensors, used to resolve
            implicit output shape references.

    Returns:
        List of v0.5 axis descriptions corresponding to `axes`.
    """
    ret: List[AnyAxis] = []
    for i, a in enumerate(axes):
        axis_type = _AXIS_TYPE_MAP.get(a, a)
        if axis_type == "batch":
            # batch axes carry no size information in v0.5
            ret.append(BatchAxis())
            continue

        scale = 1.0
        if isinstance(shape, _ParameterizedInputShape_v0_4):
            # step == 0 marks a fixed axis size in v0.4
            if shape.step[i] == 0:
                size = shape.min[i]
            else:
                size = ParameterizedSize(min=shape.min[i], step=shape.step[i])
        elif isinstance(shape, _ImplicitOutputShape_v0_4):
            ref_t = str(shape.reference_tensor)
            if ref_t.count(".") == 1:
                # reference of the form "<tensor>.<axis>"
                t_id, orig_a_id = ref_t.split(".")
            else:
                t_id = ref_t
                orig_a_id = a

            # NOTE(review): the fallback here is the current axis letter `a`,
            # not `orig_a_id` -- presumably intentional, but worth confirming.
            a_id = _AXIS_ID_MAP.get(orig_a_id, a)
            if not (orig_scale := shape.scale[i]):
                # old way to insert a new axis dimension
                size = int(2 * shape.offset[i])
            else:
                scale = 1 / orig_scale
                if axis_type in ("channel", "index"):
                    # these axes no longer have a scale
                    offset_from_scale = orig_scale * size_refs.get(
                        _TensorName_v0_4(t_id), {}
                    ).get(orig_a_id, 0)
                else:
                    offset_from_scale = 0
                size = SizeReference(
                    tensor_id=TensorId(t_id),
                    axis_id=AxisId(a_id),
                    offset=int(offset_from_scale + 2 * shape.offset[i]),
                )
        else:
            # plain fixed shape
            size = shape[i]

        if axis_type == "time":
            if tensor_type == "input":
                ret.append(TimeInputAxis(size=size, scale=scale))
            else:
                assert not isinstance(size, ParameterizedSize)
                if halo is None:
                    ret.append(TimeOutputAxis(size=size, scale=scale))
                else:
                    assert not isinstance(size, int)
                    ret.append(
                        TimeOutputAxisWithHalo(size=size, scale=scale, halo=halo[i])
                    )

        elif axis_type == "index":
            if tensor_type == "input":
                ret.append(IndexInputAxis(size=size))
            else:
                if isinstance(size, ParameterizedSize):
                    size = DataDependentSize(min=size.min)

                ret.append(IndexOutputAxis(size=size))
        elif axis_type == "channel":
            assert not isinstance(size, ParameterizedSize)
            if isinstance(size, SizeReference):
                warnings.warn(
                    "Conversion of channel size from an implicit output shape may be"
                    + " wrong"
                )
                ret.append(
                    ChannelAxis(
                        channel_names=[
                            Identifier(f"channel{i}") for i in range(size.offset)
                        ]
                    )
                )
            else:
                ret.append(
                    ChannelAxis(
                        channel_names=[Identifier(f"channel{i}") for i in range(size)]
                    )
                )
        elif axis_type == "space":
            if tensor_type == "input":
                ret.append(SpaceInputAxis(id=AxisId(a), size=size, scale=scale))
            else:
                assert not isinstance(size, ParameterizedSize)
                if halo is None or halo[i] == 0:
                    ret.append(SpaceOutputAxis(id=AxisId(a), size=size, scale=scale))
                elif isinstance(size, int):
                    raise NotImplementedError(
                        f"output axis with halo and fixed size (here {size}) not allowed"
                    )
                else:
                    ret.append(
                        SpaceOutputAxisWithHalo(
                            id=AxisId(a), size=size, scale=scale, halo=halo[i]
                        )
                    )

    return ret

1938 

1939 

1940def _axes_letters_to_ids( 

1941 axes: Optional[str], 

1942) -> Optional[List[AxisId]]: 

1943 if axes is None: 

1944 return None 

1945 

1946 return [AxisId(a) for a in axes] 

1947 

1948 

1949def _get_complement_v04_axis( 

1950 tensor_axes: Sequence[str], axes: Optional[Sequence[str]] 

1951) -> Optional[AxisId]: 

1952 if axes is None: 

1953 return None 

1954 

1955 non_complement_axes = set(axes) | {"b"} 

1956 complement_axes = [a for a in tensor_axes if a not in non_complement_axes] 

1957 if len(complement_axes) > 1: 

1958 raise ValueError( 

1959 f"Expected none or a single complement axis, but axes '{axes}' " 

1960 + f"for tensor dims '{tensor_axes}' leave '{complement_axes}'." 

1961 ) 

1962 

1963 return None if not complement_axes else AxisId(complement_axes[0]) 

1964 

1965 

def _convert_proc(
    p: Union[_PreprocessingDescr_v0_4, _PostprocessingDescr_v0_4],
    tensor_axes: Sequence[str],
) -> Union[PreprocessingDescr, PostprocessingDescr]:
    """Convert a single v0.4 pre-/postprocessing step to its v0.5 counterpart.

    `tensor_axes` (v0.4 axes letters) is used to translate `axes` kwargs and
    to derive the per-axis (e.g. channel-wise) kwargs variants where v0.5
    splits one v0.4 processing description into two.
    """
    if isinstance(p, _BinarizeDescr_v0_4):
        return BinarizeDescr(kwargs=BinarizeKwargs(threshold=p.kwargs.threshold))
    elif isinstance(p, _ClipDescr_v0_4):
        return ClipDescr(kwargs=ClipKwargs(min=p.kwargs.min, max=p.kwargs.max))
    elif isinstance(p, _SigmoidDescr_v0_4):
        return SigmoidDescr()
    elif isinstance(p, _ScaleLinearDescr_v0_4):
        axes = _axes_letters_to_ids(p.kwargs.axes)
        if p.kwargs.axes is None:
            axis = None
        else:
            # a single uncovered (non-batch) axis selects the per-axis variant
            axis = _get_complement_v04_axis(tensor_axes, p.kwargs.axes)

        if axis is None:
            assert not isinstance(p.kwargs.gain, list)
            assert not isinstance(p.kwargs.offset, list)
            kwargs = ScaleLinearKwargs(gain=p.kwargs.gain, offset=p.kwargs.offset)
        else:
            kwargs = ScaleLinearAlongAxisKwargs(
                axis=axis, gain=p.kwargs.gain, offset=p.kwargs.offset
            )
        return ScaleLinearDescr(kwargs=kwargs)
    elif isinstance(p, _ScaleMeanVarianceDescr_v0_4):
        return ScaleMeanVarianceDescr(
            kwargs=ScaleMeanVarianceKwargs(
                axes=_axes_letters_to_ids(p.kwargs.axes),
                reference_tensor=TensorId(str(p.kwargs.reference_tensor)),
                eps=p.kwargs.eps,
            )
        )
    elif isinstance(p, _ZeroMeanUnitVarianceDescr_v0_4):
        if p.kwargs.mode == "fixed":
            mean = p.kwargs.mean
            std = p.kwargs.std
            assert mean is not None
            assert std is not None

            axis = _get_complement_v04_axis(tensor_axes, p.kwargs.axes)

            if axis is None:
                if isinstance(mean, list):
                    raise ValueError("Expected single float value for mean, not <list>")
                if isinstance(std, list):
                    raise ValueError("Expected single float value for std, not <list>")
                # `model_construct` skips validation of the already checked values
                return FixedZeroMeanUnitVarianceDescr(
                    kwargs=FixedZeroMeanUnitVarianceKwargs.model_construct(
                        mean=mean,
                        std=std,
                    )
                )
            else:
                # per-axis variant accepts only lists; wrap scalars
                if not isinstance(mean, list):
                    mean = [float(mean)]
                if not isinstance(std, list):
                    std = [float(std)]

                return FixedZeroMeanUnitVarianceDescr(
                    kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs(
                        axis=axis, mean=mean, std=std
                    )
                )

        else:
            axes = _axes_letters_to_ids(p.kwargs.axes) or []
            if p.kwargs.mode == "per_dataset":
                # per-dataset statistics include the batch dimension
                axes = [AxisId("batch")] + axes
            if not axes:
                axes = None
            return ZeroMeanUnitVarianceDescr(
                kwargs=ZeroMeanUnitVarianceKwargs(axes=axes, eps=p.kwargs.eps)
            )

    elif isinstance(p, _ScaleRangeDescr_v0_4):
        return ScaleRangeDescr(
            kwargs=ScaleRangeKwargs(
                axes=_axes_letters_to_ids(p.kwargs.axes),
                min_percentile=p.kwargs.min_percentile,
                max_percentile=p.kwargs.max_percentile,
                eps=p.kwargs.eps,
            )
        )
    else:
        assert_never(p)

2053 

2054 

class _InputTensorConv(
    Converter[
        _InputTensorDescr_v0_4,
        InputTensorDescr,
        FileSource_,
        Optional[FileSource_],
        Mapping[_TensorName_v0_4, Mapping[str, int]],
    ]
):
    """Converts a v0.4 input tensor description to a v0.5 `InputTensorDescr`."""

    def _convert(
        self,
        src: _InputTensorDescr_v0_4,
        tgt: "type[InputTensorDescr] | type[dict[str, Any]]",
        test_tensor: FileSource_,
        sample_tensor: Optional[FileSource_],
        size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]],
    ) -> "InputTensorDescr | dict[str, Any]":
        converted_axes: List[InputAxis] = (
            convert_axes(  # pyright: ignore[reportAssignmentType]
                src.axes,
                shape=src.shape,
                tensor_type="input",
                halo=None,
                size_refs=size_refs,
            )
        )

        preprocessing: List[PreprocessingDescr] = []
        for proc in src.preprocessing:
            converted = _convert_proc(proc, src.axes)
            # the converted operation must be a valid *pre*processing operation
            assert not isinstance(converted, ScaleMeanVarianceDescr)
            preprocessing.append(converted)

        # always cast inputs to float32 at the end of preprocessing
        preprocessing.append(
            EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="float32"))
        )

        sample_descr = (
            None if sample_tensor is None else FileDescr(source=sample_tensor)
        )
        return tgt(
            axes=converted_axes,
            id=TensorId(str(src.name)),
            test_tensor=FileDescr(source=test_tensor),
            sample_tensor=sample_descr,
            data=dict(type=src.data_type),  # pyright: ignore[reportArgumentType]
            preprocessing=preprocessing,
        )

2097 

2098 

# module-level converter instance translating v0.4 input tensor descriptions
_input_tensor_conv = _InputTensorConv(_InputTensorDescr_v0_4, InputTensorDescr)

2100 

2101 

class OutputTensorDescr(TensorDescrBase[OutputAxis]):
    id: TensorId = TensorId("output")
    """Output tensor id.
    No duplicates are allowed across all inputs and outputs."""

    postprocessing: List[PostprocessingDescr] = Field(
        default_factory=cast(Callable[[], List[PostprocessingDescr]], list)
    )
    """Description of how this output should be postprocessed.

    note: `postprocessing` always ends with an 'ensure_dtype' operation.
    If not given this is added to cast to this tensor's `data.type`.
    """

    @model_validator(mode="after")
    def _validate_postprocessing_kwargs(self) -> Self:
        # any `axes` given in a postprocessing op's kwargs must be a subset of
        # this tensor's axis ids
        axes_ids = [a.id for a in self.axes]
        for p in self.postprocessing:
            kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes")
            if kwargs_axes is None:
                continue

            if not isinstance(kwargs_axes, collections.abc.Sequence):
                raise ValueError(
                    f"expected `axes` sequence, but got {type(kwargs_axes)}"
                )

            if any(a not in axes_ids for a in kwargs_axes):
                raise ValueError("`kwargs.axes` needs to be subset of axes ids")

        # determine this tensor's dtype (first entry when `data` is a sequence
        # of data descriptions)
        if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)):
            dtype = self.data.type
        else:
            dtype = self.data[0].type

        # ensure `postprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr`
        if not self.postprocessing or not isinstance(
            self.postprocessing[-1], (EnsureDtypeDescr, BinarizeDescr)
        ):
            self.postprocessing.append(
                EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype))
            )
        return self

2145 

2146 

class _OutputTensorConv(
    Converter[
        _OutputTensorDescr_v0_4,
        OutputTensorDescr,
        FileSource_,
        Optional[FileSource_],
        Mapping[_TensorName_v0_4, Mapping[str, int]],
    ]
):
    """Converts a v0.4 output tensor description to a v0.5 `OutputTensorDescr`."""

    def _convert(
        self,
        src: _OutputTensorDescr_v0_4,
        tgt: "type[OutputTensorDescr] | type[dict[str, Any]]",
        test_tensor: FileSource_,
        sample_tensor: Optional[FileSource_],
        size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]],
    ) -> "OutputTensorDescr | dict[str, Any]":
        # TODO: split convert_axes into convert_output_axes and convert_input_axes
        converted_axes: List[OutputAxis] = (
            convert_axes(  # pyright: ignore[reportAssignmentType]
                src.axes,
                shape=src.shape,
                tensor_type="output",
                halo=src.halo,
                size_refs=size_refs,
            )
        )

        # boolean data additionally lists its possible values
        data_descr: Dict[str, Any] = {"type": src.data_type}
        if data_descr["type"] == "bool":
            data_descr["values"] = [False, True]

        sample_descr = (
            None if sample_tensor is None else FileDescr(source=sample_tensor)
        )
        return tgt(
            axes=converted_axes,
            id=TensorId(str(src.name)),
            test_tensor=FileDescr(source=test_tensor),
            sample_tensor=sample_descr,
            data=data_descr,  # pyright: ignore[reportArgumentType]
            postprocessing=[_convert_proc(p, src.axes) for p in src.postprocessing],
        )

2186 

2187 

# module-level converter instance translating v0.4 output tensor descriptions
_output_tensor_conv = _OutputTensorConv(_OutputTensorDescr_v0_4, OutputTensorDescr)

2189 

2190 

# any v0.5 tensor description (input or output)
TensorDescr = Union[InputTensorDescr, OutputTensorDescr]

2192 

2193 

def validate_tensors(
    tensors: Mapping[TensorId, Tuple[TensorDescr, Optional[NDArray[Any]]]],
    tensor_origin: Literal[
        "test_tensor"
    ],  # for more precise error messages, e.g. 'test_tensor'
):
    """Validate tensor arrays against their descriptions.

    For every tensor with an array given, checks that:
    - the array dtype matches the described dtype (any (u)int or float dtype is
      accepted for described 'float32'/'float64' tensors),
    - the array holds values of sufficient magnitude for reliable testing,
    - every axis size is compatible with its description (fixed int,
      `ParameterizedSize`, `DataDependentSize` or `SizeReference`).

    Raises:
        ValueError: if any check fails.
    """
    # first pass: collect all axis sizes (required to resolve `SizeReference`s
    # pointing at other tensors in the second pass)
    all_tensor_axes: Dict[TensorId, Dict[AxisId, Tuple[AnyAxis, Optional[int]]]] = {}

    def e_msg(d: TensorDescr):
        # error message prefix identifying the tensor
        return f"{'inputs' if isinstance(d, InputTensorDescr) else 'outputs'}[{d.id}]"

    for descr, array in tensors.values():
        if array is None:
            axis_sizes = {a.id: None for a in descr.axes}
        else:
            try:
                axis_sizes = descr.get_axis_sizes_for_array(array)
            except ValueError as e:
                raise ValueError(f"{e_msg(descr)} {e}") from e

        all_tensor_axes[descr.id] = {a.id: (a, axis_sizes[a.id]) for a in descr.axes}

    # second pass: validate dtype, value magnitude and axis sizes
    for descr, array in tensors.values():
        if array is None:
            continue

        if descr.dtype in ("float32", "float64"):
            # any integer or float array may be described as float32/float64
            invalid_test_tensor_dtype = array.dtype.name not in (
                "float32",
                "float64",
                "uint8",
                "int8",
                "uint16",
                "int16",
                "uint32",
                "int32",
                "uint64",
                "int64",
            )
        else:
            invalid_test_tensor_dtype = array.dtype.name != descr.dtype

        if invalid_test_tensor_dtype:
            raise ValueError(
                f"{e_msg(descr)}.{tensor_origin}.dtype '{array.dtype.name}' does not"
                + f" match described dtype '{descr.dtype}'"
            )

        # all values within (-1e-4, 1e-4) => too small to test reliably
        # (message now states the same thresholds the check uses)
        if array.min() > -1e-4 and array.max() < 1e-4:
            raise ValueError(
                "Output values are too small for reliable testing."
                + f" Values <=-1e-4 or >=1e-4 must be present in {tensor_origin}"
            )

        for a in descr.axes:
            actual_size = all_tensor_axes[descr.id][a.id][1]
            if actual_size is None:
                continue

            if a.size is None:
                continue

            if isinstance(a.size, int):
                if actual_size != a.size:
                    raise ValueError(
                        f"{e_msg(descr)}.{tensor_origin}: axis '{a.id}' "
                        + f"has incompatible size {actual_size}, expected {a.size}"
                    )
            elif isinstance(a.size, (ParameterizedSize, DataDependentSize)):
                # both size types validate via their own `validate_size`
                _ = a.size.validate_size(actual_size)
            elif isinstance(a.size, SizeReference):
                ref_tensor_axes = all_tensor_axes.get(a.size.tensor_id)
                if ref_tensor_axes is None:
                    raise ValueError(
                        f"{e_msg(descr)}.axes[{a.id}].size.tensor_id: Unknown tensor"
                        + f" reference '{a.size.tensor_id}'"
                    )

                ref_axis, ref_size = ref_tensor_axes.get(a.size.axis_id, (None, None))
                if ref_axis is None or ref_size is None:
                    raise ValueError(
                        f"{e_msg(descr)}.axes[{a.id}].size.axis_id: Unknown tensor axis"
                        + f" reference '{a.size.tensor_id}.{a.size.axis_id}'"
                    )

                if a.unit != ref_axis.unit:
                    raise ValueError(
                        f"{e_msg(descr)}.axes[{a.id}].size: `SizeReference` requires"
                        + " axis and reference axis to have the same `unit`, but"
                        + f" {a.unit}!={ref_axis.unit}"
                    )

                if actual_size != (
                    expected_size := (
                        ref_size * ref_axis.scale / a.scale + a.size.offset
                    )
                ):
                    raise ValueError(
                        f"{e_msg(descr)}.{tensor_origin}: axis '{a.id}' of size"
                        + f" {actual_size} invalid for referenced size {ref_size};"
                        + f" expected {expected_size}"
                    )
            else:
                assert_never(a.size)

2300 

2301 

# file description of a Conda environment file (suffix restricted to .yaml/.yml)
FileDescr_dependencies = Annotated[
    FileDescr_,
    WithSuffix((".yaml", ".yml"), case_sensitive=True),
    Field(examples=[dict(source="environment.yaml")]),
]

2307 

2308 

class _ArchitectureCallableDescr(Node):
    """Base description of an architecture given as a callable plus its kwargs."""

    callable: Annotated[Identifier, Field(examples=["MyNetworkClass", "get_my_model"])]
    """Identifier of the callable that returns a torch.nn.Module instance."""

    kwargs: Dict[str, YamlValue] = Field(
        default_factory=cast(Callable[[], Dict[str, YamlValue]], dict)
    )
    """key word arguments for the `callable`"""

2317 

2318 

class ArchitectureFromFileDescr(_ArchitectureCallableDescr, FileDescr):
    """Architecture defined by a callable in a source file."""

    source: Annotated[FileSource, AfterValidator(wo_special_file_name)]
    """Architecture source file"""

    @model_serializer(mode="wrap", when_used="unless-none")
    def _serialize(self, nxt: SerializerFunctionWrapHandler, info: SerializationInfo):
        # include the source file when packaging the resource
        return package_file_descr_serializer(self, nxt, info)

2326 

2327 

class ArchitectureFromLibraryDescr(_ArchitectureCallableDescr):
    """Architecture defined by a callable importable from a library."""

    import_from: str
    """Where to import the callable from, i.e. `from <import_from> import <callable>`"""

2331 

2332 

class _ArchFileConv(
    Converter[
        _CallableFromFile_v0_4,
        ArchitectureFromFileDescr,
        Optional[Sha256],
        Dict[str, Any],
    ]
):
    """Converts a v0.4 `<file source>:<callable>` string to `ArchitectureFromFileDescr`."""

    def _convert(
        self,
        src: _CallableFromFile_v0_4,
        tgt: "type[ArchitectureFromFileDescr | dict[str, Any]]",
        sha256: Optional[Sha256],
        kwargs: Dict[str, Any],
    ) -> "ArchitectureFromFileDescr | dict[str, Any]":
        # `src` has the form "<file source>:<callable name>", where the file
        # source may itself contain colons when it is an URL.
        if src.startswith("http") and src.count(":") >= 2:
            # rsplit keeps all colons within the URL (scheme and e.g. a port)
            # as part of the source; the previous `split(":")` handling broke
            # for URLs with more than one extra colon (e.g. "https://host:8080/a.py:f").
            source, callable_ = src.rsplit(":", 1)
        elif not src.startswith("http") and src.count(":") == 1:
            source, callable_ = src.split(":")
        else:
            # no separable callable part: use `src` as both source and callable
            source = str(src)
            callable_ = str(src)
        return tgt(
            callable=Identifier(callable_),
            source=cast(FileSource_, source),
            sha256=sha256,
            kwargs=kwargs,
        )

2362 

2363 

# module-level converter instance for file-based architecture callables
_arch_file_conv = _ArchFileConv(_CallableFromFile_v0_4, ArchitectureFromFileDescr)

2365 

2366 

class _ArchLibConv(
    Converter[
        _CallableFromDepencency_v0_4, ArchitectureFromLibraryDescr, Dict[str, Any]
    ]
):
    """Converts a v0.4 dotted import path to an `ArchitectureFromLibraryDescr`."""

    def _convert(
        self,
        src: _CallableFromDepencency_v0_4,
        tgt: "type[ArchitectureFromLibraryDescr | dict[str, Any]]",
        kwargs: Dict[str, Any],
    ) -> "ArchitectureFromLibraryDescr | dict[str, Any]":
        # "pkg.module.func" -> import_from="pkg.module", callable="func"
        module_path, _sep, callable_name = src.rpartition(".")
        return tgt(
            import_from=module_path,
            callable=Identifier(callable_name),
            kwargs=kwargs,
        )

2383 

2384 

# module-level converter instance for library-based architecture callables
_arch_lib_conv = _ArchLibConv(
    _CallableFromDepencency_v0_4, ArchitectureFromLibraryDescr
)

2388 

2389 

class WeightsEntryDescrBase(FileDescr):
    """Base class for all weights entry descriptions (one per weights format)."""

    # set by each concrete subclass; `type` doubles as the key under `weights`
    type: ClassVar[WeightsFormat]
    weights_format_name: ClassVar[str]  # human readable

    source: Annotated[FileSource, AfterValidator(wo_special_file_name)]
    """Source of the weights file."""

    authors: Optional[List[Author]] = None
    """Authors
    Either the person(s) that have trained this model resulting in the original weights file.
    (If this is the initial weights entry, i.e. it does not have a `parent`)
    Or the person(s) who have converted the weights to this weights format.
    (If this is a child weight, i.e. it has a `parent` field)
    """

    parent: Annotated[
        Optional[WeightsFormat], Field(examples=["pytorch_state_dict"])
    ] = None
    """The source weights these weights were converted from.
    For example, if a model's weights were converted from the `pytorch_state_dict` format to `torchscript`,
    The `pytorch_state_dict` weights entry has no `parent` and is the parent of the `torchscript` weights.
    All weight entries except one (the initial set of weights resulting from training the model),
    need to have this field."""

    comment: str = ""
    """A comment about this weights entry, for example how these weights were created."""

    @model_validator(mode="after")
    def _validate(self) -> Self:
        # an entry must not reference its own format as its conversion source
        if self.type == self.parent:
            raise ValueError("Weights entry can't be it's own parent.")

        return self

    @model_serializer(mode="wrap", when_used="unless-none")
    def _serialize(self, nxt: SerializerFunctionWrapHandler, info: SerializationInfo):
        # include the weights file when packaging the resource
        return package_file_descr_serializer(self, nxt, info)

2427 

2428 

class KerasHdf5WeightsDescr(WeightsEntryDescrBase):
    """Weights in Keras HDF5 format."""

    type: ClassVar[WeightsFormat] = "keras_hdf5"
    weights_format_name: ClassVar[str] = "Keras HDF5"
    tensorflow_version: Version
    """TensorFlow version used to create these weights."""

2434 

2435 

# file description of an external ONNX data file (suffix restricted to .data)
FileDescr_external_data = Annotated[
    FileDescr_,
    WithSuffix(".data", case_sensitive=True),
    Field(examples=[dict(source="weights.onnx.data")]),
]

2441 

2442 

class OnnxWeightsDescr(WeightsEntryDescrBase):
    """Weights in ONNX format (optionally with an external data file)."""

    type: ClassVar[WeightsFormat] = "onnx"
    weights_format_name: ClassVar[str] = "ONNX"
    opset_version: Annotated[int, Ge(7)]
    """ONNX opset version"""

    external_data: Optional[FileDescr_external_data] = None
    """Source of the external ONNX data file holding the weights.
    (If present **source** holds the ONNX architecture without weights)."""

    @model_validator(mode="after")
    def _validate_external_data_unique_file_name(self) -> Self:
        # both files end up side by side in a package, so their names must differ
        if self.external_data is not None and (
            extract_file_name(self.source)
            == extract_file_name(self.external_data.source)
        ):
            raise ValueError(
                f"ONNX `external_data` file name '{extract_file_name(self.external_data.source)}'"
                + " must be different from ONNX `source` file name."
            )

        return self

2465 

2466 

class PytorchStateDictWeightsDescr(WeightsEntryDescrBase):
    """Weights as a PyTorch state dict (requires an architecture description)."""

    type: ClassVar[WeightsFormat] = "pytorch_state_dict"
    weights_format_name: ClassVar[str] = "Pytorch State Dict"
    architecture: Union[ArchitectureFromFileDescr, ArchitectureFromLibraryDescr]
    pytorch_version: Version
    """Version of the PyTorch library used.
    If `architecture.dependencies` is specified it has to include pytorch and any version pinning has to be compatible.
    """
    dependencies: Optional[FileDescr_dependencies] = None
    """Custom dependencies beyond pytorch described in a Conda environment file.
    Allows to specify custom dependencies, see conda docs:
    - [Exporting an environment file across platforms](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#exporting-an-environment-file-across-platforms)
    - [Creating an environment file manually](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#creating-an-environment-file-manually)

    The conda environment file should include pytorch and any version pinning has to be compatible with
    **pytorch_version**.
    """

2484 

2485 

class TensorflowJsWeightsDescr(WeightsEntryDescrBase):
    """Weights in Tensorflow.js format (zipped multi-file)."""

    type: ClassVar[WeightsFormat] = "tensorflow_js"
    weights_format_name: ClassVar[str] = "Tensorflow.js"
    tensorflow_version: Version
    """Version of the TensorFlow library used."""

    source: Annotated[FileSource, AfterValidator(wo_special_file_name)]
    """The multi-file weights.
    All required files/folders should be a zip archive."""

2495 

2496 

class TensorflowSavedModelBundleWeightsDescr(WeightsEntryDescrBase):
    """Weights as a Tensorflow saved model bundle (zipped multi-file)."""

    type: ClassVar[WeightsFormat] = "tensorflow_saved_model_bundle"
    weights_format_name: ClassVar[str] = "Tensorflow Saved Model"
    tensorflow_version: Version
    """Version of the TensorFlow library used."""

    dependencies: Optional[FileDescr_dependencies] = None
    """Custom dependencies beyond tensorflow.
    Should include tensorflow and any version pinning has to be compatible with **tensorflow_version**."""

    source: Annotated[FileSource, AfterValidator(wo_special_file_name)]
    """The multi-file weights.
    All required files/folders should be a zip archive."""

2510 

2511 

class TorchscriptWeightsDescr(WeightsEntryDescrBase):
    """Weights in TorchScript format."""

    type: ClassVar[WeightsFormat] = "torchscript"
    weights_format_name: ClassVar[str] = "TorchScript"
    pytorch_version: Version
    """Version of the PyTorch library used."""

2517 

2518 

# union of all concrete weights entry descriptions
SpecificWeightsDescr = Union[
    KerasHdf5WeightsDescr,
    OnnxWeightsDescr,
    PytorchStateDictWeightsDescr,
    TensorflowJsWeightsDescr,
    TensorflowSavedModelBundleWeightsDescr,
    TorchscriptWeightsDescr,
]

2527 

2528 

class WeightsDescr(Node):
    """The weights of a model in one or more formats (at least one required)."""

    keras_hdf5: Optional[KerasHdf5WeightsDescr] = None
    onnx: Optional[OnnxWeightsDescr] = None
    pytorch_state_dict: Optional[PytorchStateDictWeightsDescr] = None
    tensorflow_js: Optional[TensorflowJsWeightsDescr] = None
    tensorflow_saved_model_bundle: Optional[TensorflowSavedModelBundleWeightsDescr] = (
        None
    )
    torchscript: Optional[TorchscriptWeightsDescr] = None

    @model_validator(mode="after")
    def check_entries(self) -> Self:
        # at least one weights entry is required
        entries = {wtype for wtype, entry in self if entry is not None}

        if not entries:
            raise ValueError("Missing weights entry")

        # exactly one entry should be the original (parent-less) set of weights;
        # any other count only triggers a warning, not an error
        entries_wo_parent = {
            wtype
            for wtype, entry in self
            if entry is not None and hasattr(entry, "parent") and entry.parent is None
        }
        if len(entries_wo_parent) != 1:
            issue_warning(
                "Exactly one weights entry may not specify the `parent` field (got"
                + " {value}). That entry is considered the original set of model weights."
                + " Other weight formats are created through conversion of the orignal or"
                + " already converted weights. They have to reference the weights format"
                + " they were converted from as their `parent`.",
                value=len(entries_wo_parent),
                field="weights",
            )

        # every `parent` must reference another *specified* weights format
        for wtype, entry in self:
            if entry is None:
                continue

            assert hasattr(entry, "type")
            assert hasattr(entry, "parent")
            assert wtype == entry.type
            if (
                entry.parent is not None and entry.parent not in entries
            ):  # self reference checked for `parent` field
                raise ValueError(
                    f"`weights.{wtype}.parent={entry.parent} not in specified weight"
                    + f" formats: {entries}"
                )

        return self

    def __getitem__(
        self,
        key: Literal[
            "keras_hdf5",
            "onnx",
            "pytorch_state_dict",
            "tensorflow_js",
            "tensorflow_saved_model_bundle",
            "torchscript",
        ],
    ):
        """Return the weights entry for `key`; raise `KeyError` if unknown or unset."""
        if key == "keras_hdf5":
            ret = self.keras_hdf5
        elif key == "onnx":
            ret = self.onnx
        elif key == "pytorch_state_dict":
            ret = self.pytorch_state_dict
        elif key == "tensorflow_js":
            ret = self.tensorflow_js
        elif key == "tensorflow_saved_model_bundle":
            ret = self.tensorflow_saved_model_bundle
        elif key == "torchscript":
            ret = self.torchscript
        else:
            raise KeyError(key)

        if ret is None:
            raise KeyError(key)

        return ret

    @overload
    def __setitem__(
        self, key: Literal["keras_hdf5"], value: Optional[KerasHdf5WeightsDescr]
    ) -> None: ...
    @overload
    def __setitem__(
        self, key: Literal["onnx"], value: Optional[OnnxWeightsDescr]
    ) -> None: ...
    @overload
    def __setitem__(
        self,
        key: Literal["pytorch_state_dict"],
        value: Optional[PytorchStateDictWeightsDescr],
    ) -> None: ...
    @overload
    def __setitem__(
        self, key: Literal["tensorflow_js"], value: Optional[TensorflowJsWeightsDescr]
    ) -> None: ...
    @overload
    def __setitem__(
        self,
        key: Literal["tensorflow_saved_model_bundle"],
        value: Optional[TensorflowSavedModelBundleWeightsDescr],
    ) -> None: ...
    @overload
    def __setitem__(
        self, key: Literal["torchscript"], value: Optional[TorchscriptWeightsDescr]
    ) -> None: ...

    def __setitem__(
        self,
        key: Literal[
            "keras_hdf5",
            "onnx",
            "pytorch_state_dict",
            "tensorflow_js",
            "tensorflow_saved_model_bundle",
            "torchscript",
        ],
        value: Optional[SpecificWeightsDescr],
    ):
        """Set the weights entry for `key` after checking `value`'s type."""
        if key == "keras_hdf5":
            if value is not None and not isinstance(value, KerasHdf5WeightsDescr):
                raise TypeError(
                    f"Expected KerasHdf5WeightsDescr or None for key 'keras_hdf5', got {type(value)}"
                )
            self.keras_hdf5 = value
        elif key == "onnx":
            if value is not None and not isinstance(value, OnnxWeightsDescr):
                raise TypeError(
                    f"Expected OnnxWeightsDescr or None for key 'onnx', got {type(value)}"
                )
            self.onnx = value
        elif key == "pytorch_state_dict":
            if value is not None and not isinstance(
                value, PytorchStateDictWeightsDescr
            ):
                raise TypeError(
                    f"Expected PytorchStateDictWeightsDescr or None for key 'pytorch_state_dict', got {type(value)}"
                )
            self.pytorch_state_dict = value
        elif key == "tensorflow_js":
            if value is not None and not isinstance(value, TensorflowJsWeightsDescr):
                raise TypeError(
                    f"Expected TensorflowJsWeightsDescr or None for key 'tensorflow_js', got {type(value)}"
                )
            self.tensorflow_js = value
        elif key == "tensorflow_saved_model_bundle":
            if value is not None and not isinstance(
                value, TensorflowSavedModelBundleWeightsDescr
            ):
                raise TypeError(
                    f"Expected TensorflowSavedModelBundleWeightsDescr or None for key 'tensorflow_saved_model_bundle', got {type(value)}"
                )
            self.tensorflow_saved_model_bundle = value
        elif key == "torchscript":
            if value is not None and not isinstance(value, TorchscriptWeightsDescr):
                raise TypeError(
                    f"Expected TorchscriptWeightsDescr or None for key 'torchscript', got {type(value)}"
                )
            self.torchscript = value
        else:
            raise KeyError(key)

    @property
    def available_formats(self) -> Dict[WeightsFormat, SpecificWeightsDescr]:
        """Mapping of all weights formats that are set (non-None)."""
        return {
            **({} if self.keras_hdf5 is None else {"keras_hdf5": self.keras_hdf5}),
            **({} if self.onnx is None else {"onnx": self.onnx}),
            **(
                {}
                if self.pytorch_state_dict is None
                else {"pytorch_state_dict": self.pytorch_state_dict}
            ),
            **(
                {}
                if self.tensorflow_js is None
                else {"tensorflow_js": self.tensorflow_js}
            ),
            **(
                {}
                if self.tensorflow_saved_model_bundle is None
                else {
                    "tensorflow_saved_model_bundle": self.tensorflow_saved_model_bundle
                }
            ),
            **({} if self.torchscript is None else {"torchscript": self.torchscript}),
        }

    @property
    def missing_formats(self) -> Set[WeightsFormat]:
        """All known weights formats that are not set on this instance."""
        return {
            wf for wf in get_args(WeightsFormat) if wf not in self.available_formats
        }

2724 

2725 

class ModelId(ResourceId):
    """A model resource id (see `ResourceId`)."""

    pass

2728 

2729 

class LinkedModel(LinkedResourceBase):
    """Reference to a bioimage.io model."""

    id: ModelId
    """A valid model `id` from the bioimage.io collection."""

2735 

2736 

class _DataDepSize(NamedTuple):
    """size bounds of a data-dependent axis
    (`max is None` presumably means unbounded -- see `DataDependentSize`)"""

    min: StrictInt
    max: Optional[StrictInt]

2740 

2741 

class _AxisSizes(NamedTuple):
    """the lengths of all axes of model inputs and outputs"""

    inputs: Dict[Tuple[TensorId, AxisId], int]
    # output sizes may be data-dependent, then given as a `_DataDepSize` range
    outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]]

2747 

2748 

class _TensorSizes(NamedTuple):
    """_AxisSizes as nested dicts (tensor id -> axis id -> size)"""

    inputs: Dict[TensorId, Dict[AxisId, int]]
    outputs: Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]

2754 

2755 

class ReproducibilityTolerance(Node, extra="allow"):
    """Describes what small numerical differences -- if any -- may be tolerated
    in the generated output when executing in different environments.

    A tensor element *output* is considered mismatched to the **test_tensor** if
    abs(*output* - **test_tensor**) > **absolute_tolerance** + **relative_tolerance** * abs(**test_tensor**).
    (Internally we call [numpy.testing.assert_allclose](https://numpy.org/doc/stable/reference/generated/numpy.testing.assert_allclose.html).)

    Motivation:
        For testing we can request the respective deep learning frameworks to be as
        reproducible as possible by setting seeds and choosing deterministic algorithms,
        but differences in operating systems, available hardware and installed drivers
        may still lead to numerical differences.
    """

    relative_tolerance: RelativeTolerance = 1e-3
    """Maximum relative tolerance of reproduced test tensor."""

    absolute_tolerance: AbsoluteTolerance = 1e-3
    """Maximum absolute tolerance of reproduced test tensor."""

    mismatched_elements_per_million: MismatchedElementsPerMillion = 100
    """Maximum number of mismatched elements/pixels per million to tolerate."""

    output_ids: Sequence[TensorId] = ()
    """Limits the output tensor IDs these reproducibility details apply to."""

    weights_formats: Sequence[WeightsFormat] = ()
    """Limits the weights formats these details apply to."""

2785 

2786 

class BiasRisksLimitations(Node, extra="allow"):
    """Known biases, risks, technical limitations, and recommendations for model use."""

    known_biases: str = dedent("""\
        In general bioimage models may suffer from biases caused by:

        - Imaging protocol dependencies
        - Use of a specific cell type
        - Species-specific training data limitations

        """)
    """Biases in training data or model behavior."""

    risks: str = dedent("""\
        Common risks in bioimage analysis include:

        - Erroneously assuming generalization to unseen experimental conditions
        - Trusting (overconfident) model outputs without validation
        - Misinterpretation of results

        """)
    """Potential risks in the context of bioimage analysis."""

    limitations: Optional[str] = None
    """Technical limitations and failure modes."""

    recommendations: str = "Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model."
    """Mitigation strategies regarding `known_biases`, `risks`, and `limitations`, as well as applicable best practices.

    Consider:
    - How to use a validation dataset?
    - How to manually validate?
    - Feasibility of domain adaptation for different experimental setups?

    """

    def format_md(self) -> str:
        """Render these fields as a markdown section."""
        # only emit the 'Limitations' heading if limitations are given
        if self.limitations is None:
            limitations_header = ""
        else:
            limitations_header = "## Limitations\n\n"

        return f"""# Bias, Risks, and Limitations

{self.known_biases}

{self.risks}

{limitations_header}{self.limitations or ""}

## Recommendations

{self.recommendations}

"""

2842 

2843 

class TrainingDetails(Node, extra="allow"):
    """Details about how the model was trained."""

    training_preprocessing: Optional[str] = None
    """Detailed image preprocessing steps during model training:

    Mention:
    - *Normalization methods*
    - *Augmentation strategies*
    - *Resizing/resampling procedures*
    - *Artifact handling*

    """

    training_epochs: Optional[float] = None
    """Number of training epochs."""

    training_batch_size: Optional[float] = None
    """Batch size used in training."""

    initial_learning_rate: Optional[float] = None
    """Initial learning rate used in training."""

    learning_rate_schedule: Optional[str] = None
    """Learning rate schedule used in training."""

    loss_function: Optional[str] = None
    """Loss function used in training, e.g. nn.MSELoss."""

    loss_function_kwargs: Dict[str, YamlValue] = Field(
        default_factory=cast(Callable[[], Dict[str, YamlValue]], dict)
    )
    """key word arguments for the `loss_function`"""

    optimizer: Optional[str] = None
    """optimizer, e.g. torch.optim.Adam"""

    optimizer_kwargs: Dict[str, YamlValue] = Field(
        default_factory=cast(Callable[[], Dict[str, YamlValue]], dict)
    )
    """key word arguments for the `optimizer`"""

    regularization: Optional[str] = None
    """Regularization techniques used during training, e.g. drop-out or weight decay."""

    training_duration: Optional[float] = None
    """Total training duration in hours."""

2890 

class Evaluation(Node, extra="allow"):
    """Quantitative evaluation of a model on one dataset.

    Results form a table with one row per entry in `metrics` and one column
    per entry in `evaluation_factors` (see `results`); the cross-field
    validator below enforces these shape constraints.
    """

    model_id: Optional[ModelId] = None
    """Model being evaluated."""

    dataset_id: DatasetId
    """Dataset used for evaluation."""

    dataset_source: HttpUrl
    """Source of the dataset."""

    dataset_role: Literal["train", "validation", "test", "independent", "unknown"]
    """Role of the dataset used for evaluation.

    - `train`: dataset was (part of) the training data
    - `validation`: dataset was (part of) the validation data used during training, e.g. used for model selection or hyperparameter tuning
    - `test`: dataset was (part of) the designated test data; not used during training or validation, but acquired from the same source/distribution as training data
    - `independent`: dataset is entirely independent test data; not used during training or validation, and acquired from a different source/distribution than training data
    - `unknown`: role of the dataset is unknown; choose this if you are not certain if (a subset) of the data was seen by the model during training.
    """

    sample_count: int
    """Number of evaluated samples."""

    evaluation_factors: List[Annotated[str, MaxLen(16)]]
    """(Abbreviations of) each evaluation factor.

    Evaluation factors are criteria along which model performance is evaluated, e.g. different image conditions
    like 'low SNR', 'high cell density', or different biological conditions like 'cell type A', 'cell type B'.
    An 'overall' factor may be included to summarize performance across all conditions.
    """

    evaluation_factors_long: List[str]
    """Descriptions (long form) of each evaluation factor."""

    metrics: List[Annotated[str, MaxLen(16)]]
    """(Abbreviations of) metrics used for evaluation."""

    metrics_long: List[str]
    """Description of each metric used."""

    @model_validator(mode="after")
    def _validate_list_lengths(self) -> Self:
        # Cross-field consistency: each `*_long` list describes its abbreviated
        # counterpart entry by entry, and `results` must be a
        # len(metrics) x len(evaluation_factors) table.
        # NOTE: `results` is declared further down; that is fine because an
        # "after" model validator runs once all fields are populated.
        if len(self.evaluation_factors) != len(self.evaluation_factors_long):
            raise ValueError(
                "`evaluation_factors` and `evaluation_factors_long` must have the same length"
            )

        if len(self.metrics) != len(self.metrics_long):
            raise ValueError("`metrics` and `metrics_long` must have the same length")

        if len(self.results) != len(self.metrics):
            raise ValueError("`results` must have the same number of rows as `metrics`")

        for row in self.results:
            if len(row) != len(self.evaluation_factors):
                raise ValueError(
                    "`results` must have the same number of columns (in every row) as `evaluation_factors`"
                )

        return self

    results: List[List[Union[str, float, int]]]
    """Results for each metric (rows; outer list) and each evaluation factor (columns; inner list)."""

    results_summary: Optional[str] = None
    """Interpretation of results for general audience.

    Consider:
    - Overall model performance
    - Comparison to existing methods
    - Limitations and areas for improvement

"""

    def format_md(self) -> str:
        """Return this evaluation rendered as a Markdown section.

        Produces a "Testing Data, Factors & Metrics" section with factor and
        metric listings, the results as a Markdown table (metrics as rows,
        evaluation factors as columns), and the `results_summary` (or
        "missing" if unset).
        """
        results_header = ["Metric"] + self.evaluation_factors
        # table cells: header row, Markdown separator row, then one row per metric
        results_table_cells = [results_header, ["---"] * len(results_header)] + [
            [metric] + [str(r) for r in row]
            for metric, row in zip(self.metrics, self.results)
        ]

        results_table = "".join(
            "| " + " | ".join(row) + " |\n" for row in results_table_cells
        )
        factors = "".join(
            f"\n - {ef}: {efl}"
            for ef, efl in zip(self.evaluation_factors, self.evaluation_factors_long)
        )
        metrics = "".join(
            f"\n - {em}: {eml}" for em, eml in zip(self.metrics, self.metrics_long)
        )

        return f"""## Testing Data, Factors & Metrics

Evaluation of {self.model_id or "this"} model on the {self.dataset_id} dataset (dataset role: {self.dataset_role}).

### Testing Data

- **Source:** [{self.dataset_id}]({self.dataset_source})
- **Size:** {self.sample_count} evaluated samples

### Factors
{factors}

### Metrics
{metrics}

## Results

### Quantitative Results

{results_table}

### Summary

{self.results_summary or "missing"}

"""

3009 

3010 

class EnvironmentalImpact(Node, extra="allow"):
    """Environmental considerations for model training and deployment.

    Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
    """

    hardware_type: Optional[str] = None
    """GPU/CPU specifications"""

    hours_used: Optional[float] = None
    """Total compute hours"""

    cloud_provider: Optional[str] = None
    """If applicable"""

    compute_region: Optional[str] = None
    """Geographic location"""

    co2_emitted: Optional[float] = None
    """kg CO2 equivalent

    Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
    """

    def format_md(self) -> str:
        """Filled Markdown template section following [Hugging Face Model Card Template](https://huggingface.co/docs/hub/en/model-card-annotated)."""
        # Omit the section entirely if nothing was specified
        # (i.e. `self` equals a freshly constructed, all-defaults instance).
        if self == self.__class__():
            return ""

        # Only fields that were explicitly set produce a bullet point.
        ret = "# Environmental Impact\n\n"
        if self.hardware_type is not None:
            ret += f"- **Hardware Type:** {self.hardware_type}\n"
        if self.hours_used is not None:
            ret += f"- **Hours used:** {self.hours_used}\n"
        if self.cloud_provider is not None:
            ret += f"- **Cloud Provider:** {self.cloud_provider}\n"
        if self.compute_region is not None:
            ret += f"- **Compute Region:** {self.compute_region}\n"
        if self.co2_emitted is not None:
            ret += f"- **Carbon Emitted:** {self.co2_emitted} kg CO2e\n"

        return ret + "\n"

3053 

3054 

class BioimageioConfig(Node, extra="allow"):
    """bioimage.io-specific model metadata, stored under `config.bioimageio`
    of a model description (see `Config`)."""

    reproducibility_tolerance: Sequence[ReproducibilityTolerance] = ()
    """Tolerances to allow when reproducing the model's test outputs
    from the model's test inputs.
    Only the first entry matching tensor id and weights format is considered.
    """

    funded_by: Optional[str] = None
    """Funding agency, grant number if applicable"""

    architecture_type: Optional[Annotated[str, MaxLen(32)]] = (
        None  # TODO: add to differentiated tags
    )
    """Model architecture type, e.g., 3D U-Net, ResNet, transformer"""

    architecture_description: Optional[str] = None
    """Text description of model architecture."""

    modality: Optional[str] = None  # TODO: add to differentiated tags
    """Input modality, e.g., fluorescence microscopy, electron microscopy"""

    target_structure: List[str] = Field(  # TODO: add to differentiated tags
        default_factory=cast(Callable[[], List[str]], list)
    )
    """Biological structure(s) the model is designed to analyze, e.g., nuclei, mitochondria, cells"""

    task: Optional[str] = None  # TODO: add to differentiated tags
    """Bioimage-specific task type, e.g., segmentation, classification, detection, denoising"""

    new_version: Optional[ModelId] = None
    """A new version of this model exists with a different model id."""

    out_of_scope_use: Optional[str] = None
    """Describe how the model may be misused in bioimage analysis contexts and what users should **not** do with the model."""

    bias_risks_limitations: BiasRisksLimitations = Field(
        default_factory=BiasRisksLimitations.model_construct
    )
    """Description of known bias, risks, and technical limitations for in-scope model use."""

    model_parameter_count: Optional[int] = None
    """Total number of model parameters."""

    training: TrainingDetails = Field(default_factory=TrainingDetails.model_construct)
    """Details on how the model was trained."""

    inference_time: Optional[str] = None
    """Average inference time per image/tile. Specify hardware and image size. Multiple examples can be given."""

    memory_requirements_inference: Optional[str] = None
    """GPU memory needed for inference. Multiple examples with different image size can be given."""

    memory_requirements_training: Optional[str] = None
    """GPU memory needed for training. Multiple examples with different image/batch sizes can be given."""

    evaluations: List[Evaluation] = Field(
        default_factory=cast(Callable[[], List[Evaluation]], list)
    )
    """Quantitative model evaluations.

    Note:
        At the moment we recommend to include only a single test dataset
        (with evaluation factors that may mark subsets of the dataset)
        to avoid confusion and make the presentation of results cleaner.
    """

    environmental_impact: EnvironmentalImpact = Field(
        default_factory=EnvironmentalImpact.model_construct
    )
    """Environmental considerations for model training and deployment"""

3125 

3126 

class Config(Node, extra="allow"):
    """Custom configuration container of a model description.

    Extra keys are allowed (consumer-specific configuration); the
    `bioimageio` sub-node holds bioimage.io-specific metadata.
    """

    # bioimage.io-internal part of the configuration; defaults to an
    # all-defaults instance (constructed without validation).
    bioimageio: BioimageioConfig = Field(
        default_factory=BioimageioConfig.model_construct
    )

3131 

3132 

class ModelDescr(GenericModelDescrBase):
    """Specification of the fields used in a bioimage.io-compliant RDF to describe AI models with pretrained weights.
    These fields are typically stored in a YAML file which we call a model resource description file (model RDF).
    """

    implemented_format_version: ClassVar[Literal["0.5.7"]] = "0.5.7"
    if TYPE_CHECKING:
        # for type checkers the field has a default, so constructing without it is ok
        format_version: Literal["0.5.7"] = "0.5.7"
    else:
        format_version: Literal["0.5.7"]
        """Version of the bioimage.io model description specification used.
        When creating a new model always use the latest micro/patch version described here.
        The `format_version` is important for any consumer software to understand how to parse the fields.
        """

    implemented_type: ClassVar[Literal["model"]] = "model"
    if TYPE_CHECKING:
        type: Literal["model"] = "model"
    else:
        type: Literal["model"]
        """Specialized resource type 'model'"""

    id: Optional[ModelId] = None
    """bioimage.io-wide unique resource identifier
    assigned by bioimage.io; version **un**specific."""

    authors: FAIR[List[Author]] = Field(
        default_factory=cast(Callable[[], List[Author]], list)
    )
    """The authors are the creators of the model RDF and the primary points of contact."""

    documentation: FAIR[Optional[FileSource_documentation]] = None
    """URL or relative path to a markdown file with additional documentation.
    The recommended documentation file name is `README.md`. An `.md` suffix is mandatory.
    The documentation should include a '#[#] Validation' (sub)section
    with details on how to quantitatively validate the model on unseen data."""

    @field_validator("documentation", mode="after")
    @classmethod
    def _validate_documentation(
        cls, value: Optional[FileSource_documentation]
    ) -> Optional[FileSource_documentation]:
        """Warn (do not fail) if the documentation lacks a validation (sub)section."""
        if not get_validation_context().perform_io_checks or value is None:
            return value

        doc_reader = get_reader(value)
        doc_content = doc_reader.read().decode(encoding="utf-8")
        # any Markdown heading containing "validation"/"Validation" counts
        if not re.search("#.*[vV]alidation", doc_content):
            issue_warning(
                "No '# Validation' (sub)section found in {value}.",
                value=value,
                field="documentation",
            )

        return value

    inputs: NotEmpty[Sequence[InputTensorDescr]]
    """Describes the input tensors expected by this model."""

    @field_validator("inputs", mode="after")
    @classmethod
    def _validate_input_axes(
        cls, inputs: Sequence[InputTensorDescr]
    ) -> Sequence[InputTensorDescr]:
        """Validate every input axis, resolving `SizeReference`s against input axes."""
        input_size_refs = cls._get_axes_with_independent_size(inputs)

        for i, ipt in enumerate(inputs):
            # axes a SizeReference of this tensor may point to:
            # this tensor's own independently sized axes plus those of all inputs
            valid_independent_refs: Dict[
                Tuple[TensorId, AxisId],
                Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
            ] = {
                **{
                    (ipt.id, a.id): (ipt, a, a.size)
                    for a in ipt.axes
                    if not isinstance(a, BatchAxis)
                    and isinstance(a.size, (int, ParameterizedSize))
                },
                **input_size_refs,
            }
            for a, ax in enumerate(ipt.axes):
                cls._validate_axis(
                    "inputs",
                    i=i,
                    tensor_id=ipt.id,
                    a=a,
                    axis=ax,
                    valid_independent_refs=valid_independent_refs,
                )
        return inputs

    @staticmethod
    def _validate_axis(
        field_name: str,
        i: int,
        tensor_id: TensorId,
        a: int,
        axis: AnyAxis,
        valid_independent_refs: Dict[
            Tuple[TensorId, AxisId],
            Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
        ],
    ):
        """Validate a single axis of tensor `tensor_id` (axis index `a` of the
        `i`-th entry in `field_name`).

        Only axes whose size is a `SizeReference` need checking here; the
        reference must exist in `valid_independent_refs`, must not be the axis
        itself, must have matching units, must not be a batch axis, and (for
        channel axes) may expand `{i}`-templated channel names in place.
        For axes with a halo, the minimum size and the halo scaled to the
        reference axis are checked as well.
        """
        if isinstance(axis, BatchAxis) or isinstance(
            axis.size, (int, ParameterizedSize, DataDependentSize)
        ):
            # nothing to resolve for batch/fixed/parameterized/data-dependent sizes
            return
        elif not isinstance(axis.size, SizeReference):
            assert_never(axis.size)

        # validate axis.size SizeReference
        ref = (axis.size.tensor_id, axis.size.axis_id)
        if ref not in valid_independent_refs:
            raise ValueError(
                "Invalid tensor axis reference at"
                + f" {field_name}[{i}].axes[{a}].size: {axis.size}."
            )
        if ref == (tensor_id, axis.id):
            raise ValueError(
                "Self-referencing not allowed for"
                + f" {field_name}[{i}].axes[{a}].size: {axis.size}"
            )
        if axis.type == "channel":
            if valid_independent_refs[ref][1].type != "channel":
                raise ValueError(
                    "A channel axis' size may only reference another fixed size"
                    + " channel axis."
                )
            if isinstance(axis.channel_names, str) and "{i}" in axis.channel_names:
                # expand a '{i}'-template into explicit, 1-based channel names
                ref_size = valid_independent_refs[ref][2]
                assert isinstance(ref_size, int), (
                    "channel axis ref (another channel axis) has to specify fixed"
                    + " size"
                )
                generated_channel_names = [
                    Identifier(axis.channel_names.format(i=i))
                    for i in range(1, ref_size + 1)
                ]
                axis.channel_names = generated_channel_names

        if (ax_unit := getattr(axis, "unit", None)) != (
            ref_unit := getattr(valid_independent_refs[ref][1], "unit", None)
        ):
            raise ValueError(
                "The units of an axis and its reference axis need to match, but"
                + f" '{ax_unit}' != '{ref_unit}'."
            )
        ref_axis = valid_independent_refs[ref][1]
        if isinstance(ref_axis, BatchAxis):
            raise ValueError(
                f"Invalid reference axis '{ref_axis.id}' for {tensor_id}.{axis.id}"
                + " (a batch axis is not allowed as reference)."
            )

        if isinstance(axis, WithHalo):
            # minimum valid size (n=0) must still leave room for the halo on both sides
            min_size = axis.size.get_size(axis, ref_axis, n=0)
            if (min_size - 2 * axis.halo) < 1:
                raise ValueError(
                    f"axis {axis.id} with minimum size {min_size} is too small for halo"
                    + f" {axis.halo}."
                )

            # the halo translated to the reference axis' scale must be integral
            ref_halo = axis.halo * axis.scale / ref_axis.scale
            if ref_halo != int(ref_halo):
                raise ValueError(
                    f"Inferred halo for {'.'.join(ref)} is not an integer ({ref_halo} ="
                    + f" {tensor_id}.{axis.id}.halo {axis.halo}"
                    + f" * {tensor_id}.{axis.id}.scale {axis.scale}"
                    + f" / {'.'.join(ref)}.scale {ref_axis.scale})."
                )

    @model_validator(mode="after")
    def _validate_test_tensors(self) -> Self:
        """Load all test tensors, validate them against their descriptions, and
        sanity-check configured reproducibility tolerances (skipped when I/O
        checks are disabled)."""
        if not get_validation_context().perform_io_checks:
            return self

        test_output_arrays = [
            None if descr.test_tensor is None else load_array(descr.test_tensor)
            for descr in self.outputs
        ]
        test_input_arrays = [
            None if descr.test_tensor is None else load_array(descr.test_tensor)
            for descr in self.inputs
        ]

        tensors = {
            descr.id: (descr, array)
            for descr, array in zip(
                chain(self.inputs, self.outputs), test_input_arrays + test_output_arrays
            )
        }
        validate_tensors(tensors, tensor_origin="test_tensor")

        output_arrays = {
            descr.id: array for descr, array in zip(self.outputs, test_output_arrays)
        }
        for rep_tol in self.config.bioimageio.reproducibility_tolerance:
            if not rep_tol.absolute_tolerance:
                continue

            # restrict the check to the outputs the tolerance entry applies to
            if rep_tol.output_ids:
                out_arrays = {
                    oid: a
                    for oid, a in output_arrays.items()
                    if oid in rep_tol.output_ids
                }
            else:
                out_arrays = output_arrays

            for out_id, array in out_arrays.items():
                if array is None:
                    continue

                # an absolute tolerance above 1% of the output's max value would
                # make the reproducibility check meaningless
                if rep_tol.absolute_tolerance > (max_test_value := array.max()) * 0.01:
                    raise ValueError(
                        "config.bioimageio.reproducibility_tolerance.absolute_tolerance="
                        + f"{rep_tol.absolute_tolerance} > 0.01*{max_test_value}"
                        + f" (1% of the maximum value of the test tensor '{out_id}')"
                    )

        return self

    @model_validator(mode="after")
    def _validate_tensor_references_in_proc_kwargs(self, info: ValidationInfo) -> Self:
        """Check that any `reference_tensor` used in pre-/postprocessing kwargs
        points to an existing tensor (preprocessing may only reference inputs;
        postprocessing may reference inputs or outputs)."""
        ipt_refs = {t.id for t in self.inputs}
        out_refs = {t.id for t in self.outputs}
        for ipt in self.inputs:
            for p in ipt.preprocessing:
                ref = p.kwargs.get("reference_tensor")
                if ref is None:
                    continue
                if ref not in ipt_refs:
                    raise ValueError(
                        f"`reference_tensor` '{ref}' not found. Valid input tensor"
                        + f" references are: {ipt_refs}."
                    )

        for out in self.outputs:
            for p in out.postprocessing:
                ref = p.kwargs.get("reference_tensor")
                if ref is None:
                    continue

                if ref not in ipt_refs and ref not in out_refs:
                    raise ValueError(
                        f"`reference_tensor` '{ref}' not found. Valid tensor references"
                        + f" are: {ipt_refs | out_refs}."
                    )

        return self

    # TODO: use validate funcs in validate_test_tensors
    # def validate_inputs(self, input_tensors: Mapping[TensorId, NDArray[Any]]) -> Mapping[TensorId, NDArray[Any]]:

    name: Annotated[
        str,
        RestrictCharacters(string.ascii_letters + string.digits + "_+- ()"),
        MinLen(5),
        MaxLen(128),
        warn(MaxLen(64), "Name longer than 64 characters.", INFO),
    ]
    """A human-readable name of this model.
    It should be no longer than 64 characters
    and may only contain letter, number, underscore, minus, parentheses and spaces.
    We recommend to choose a name that refers to the model's task and image modality.
    """

    outputs: NotEmpty[Sequence[OutputTensorDescr]]
    """Describes the output tensors."""

    @field_validator("outputs", mode="after")
    @classmethod
    def _validate_tensor_ids(
        cls, outputs: Sequence[OutputTensorDescr], info: ValidationInfo
    ) -> Sequence[OutputTensorDescr]:
        """Ensure tensor ids are unique across inputs and outputs combined."""
        tensor_ids = [
            t.id for t in info.data.get("inputs", []) + info.data.get("outputs", [])
        ]
        duplicate_tensor_ids: List[str] = []
        seen: Set[str] = set()
        for t in tensor_ids:
            if t in seen:
                duplicate_tensor_ids.append(t)

            seen.add(t)

        if duplicate_tensor_ids:
            raise ValueError(f"Duplicate tensor ids: {duplicate_tensor_ids}")

        return outputs

    @staticmethod
    def _get_axes_with_parameterized_size(
        io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]],
    ):
        """Map "tensor_id.axis_id" to (tensor, axis, size) for all axes with a
        `ParameterizedSize`."""
        return {
            f"{t.id}.{a.id}": (t, a, a.size)
            for t in io
            for a in t.axes
            if not isinstance(a, BatchAxis) and isinstance(a.size, ParameterizedSize)
        }

    @staticmethod
    def _get_axes_with_independent_size(
        io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]],
    ):
        """Map (tensor_id, axis_id) to (tensor, axis, size) for all non-batch
        axes whose size is independent (fixed int or `ParameterizedSize`)."""
        return {
            (t.id, a.id): (t, a, a.size)
            for t in io
            for a in t.axes
            if not isinstance(a, BatchAxis)
            and isinstance(a.size, (int, ParameterizedSize))
        }

    @field_validator("outputs", mode="after")
    @classmethod
    def _validate_output_axes(
        cls, outputs: List[OutputTensorDescr], info: ValidationInfo
    ) -> List[OutputTensorDescr]:
        """Validate every output axis; `SizeReference`s may target input axes
        as well as output axes."""
        input_size_refs = cls._get_axes_with_independent_size(
            info.data.get("inputs", [])
        )
        output_size_refs = cls._get_axes_with_independent_size(outputs)

        for i, out in enumerate(outputs):
            valid_independent_refs: Dict[
                Tuple[TensorId, AxisId],
                Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
            ] = {
                **{
                    (out.id, a.id): (out, a, a.size)
                    for a in out.axes
                    if not isinstance(a, BatchAxis)
                    and isinstance(a.size, (int, ParameterizedSize))
                },
                **input_size_refs,
                **output_size_refs,
            }
            for a, ax in enumerate(out.axes):
                cls._validate_axis(
                    "outputs",
                    i,
                    out.id,
                    a,
                    ax,
                    valid_independent_refs=valid_independent_refs,
                )

        return outputs

    packaged_by: List[Author] = Field(
        default_factory=cast(Callable[[], List[Author]], list)
    )
    """The persons that have packaged and uploaded this model.
    Only required if those persons differ from the `authors`."""

    parent: Optional[LinkedModel] = None
    """The model from which this model is derived, e.g. by fine-tuning the weights."""

    @model_validator(mode="after")
    def _validate_parent_is_not_self(self) -> Self:
        if self.parent is not None and self.parent.id == self.id:
            raise ValueError("A model description may not reference itself as parent.")

        return self

    run_mode: Annotated[
        Optional[RunMode],
        warn(None, "Run mode '{value}' has limited support across consumer softwares."),
    ] = None
    """Custom run mode for this model: for more complex prediction procedures like test time
    data augmentation that currently cannot be expressed in the specification.
    No standard run modes are defined yet."""

    timestamp: Datetime = Field(default_factory=Datetime.now)
    """Timestamp in [ISO 8601](#https://en.wikipedia.org/wiki/ISO_8601) format
    with a few restrictions listed [here](https://docs.python.org/3/library/datetime.html#datetime.datetime.fromisoformat).
    (In Python a datetime object is valid, too)."""

    training_data: Annotated[
        Union[None, LinkedDataset, DatasetDescr, DatasetDescr02],
        Field(union_mode="left_to_right"),
    ] = None
    """The dataset used to train this model"""

    weights: Annotated[WeightsDescr, WrapSerializer(package_weights)]
    """The weights for this model.
    Weights can be given for different formats, but should otherwise be equivalent.
    The available weight formats determine which consumers can use this model."""

    # custom configuration; defaults to an all-defaults instance
    config: Config = Field(default_factory=Config.model_construct)

    @model_validator(mode="after")
    def _add_default_cover(self) -> Self:
        """Generate cover image(s) from the test tensors if none were given
        (only when I/O checks are enabled); failures only warn."""
        if not get_validation_context().perform_io_checks or self.covers:
            return self

        try:
            generated_covers = generate_covers(
                [
                    (t, load_array(t.test_tensor))
                    for t in self.inputs
                    if t.test_tensor is not None
                ],
                [
                    (t, load_array(t.test_tensor))
                    for t in self.outputs
                    if t.test_tensor is not None
                ],
            )
        except Exception as e:
            issue_warning(
                "Failed to generate cover image(s): {e}",
                value=self.covers,
                msg_context=dict(e=e),
                field="covers",
            )
        else:
            self.covers.extend(generated_covers)

        return self

    def get_input_test_arrays(self) -> List[NDArray[Any]]:
        """Load and return the test arrays of all input tensors."""
        return self._get_test_arrays(self.inputs)

    def get_output_test_arrays(self) -> List[NDArray[Any]]:
        """Load and return the test arrays of all output tensors."""
        return self._get_test_arrays(self.outputs)

    @staticmethod
    def _get_test_arrays(
        io_descr: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]],
    ):
        """Load the `test_tensor` of every given tensor description.

        Raises:
            ValueError: if any description is missing a `test_tensor`.
        """
        ts: List[FileDescr] = []
        for d in io_descr:
            if d.test_tensor is None:
                raise ValueError(
                    f"Failed to get test arrays: description of '{d.id}' is missing a `test_tensor`."
                )
            ts.append(d.test_tensor)

        data = [load_array(t) for t in ts]
        assert all(isinstance(d, np.ndarray) for d in data)
        return data

    @staticmethod
    def get_batch_size(tensor_sizes: Mapping[TensorId, Mapping[AxisId, int]]) -> int:
        """Return the common batch size found in **tensor_sizes** (1 if no
        batch axis is present).

        Raises:
            ValueError: if two tensors specify conflicting batch sizes (!= 1).
        """
        batch_size = 1
        tensor_with_batchsize: Optional[TensorId] = None
        for tid in tensor_sizes:
            for aid, s in tensor_sizes[tid].items():
                # only batch axes with a new, non-trivial size are of interest
                if aid != BATCH_AXIS_ID or s == 1 or s == batch_size:
                    continue

                if batch_size != 1:
                    assert tensor_with_batchsize is not None
                    raise ValueError(
                        f"batch size mismatch for tensors '{tensor_with_batchsize}' ({batch_size}) and '{tid}' ({s})"
                    )

                batch_size = s
                tensor_with_batchsize = tid

        return batch_size

    def get_output_tensor_sizes(
        self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]
    ) -> Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]:
        """Returns the tensor output sizes for given **input_sizes**.
        Only if **input_sizes** has a valid input shape, the tensor output size is exact.
        Otherwise it might be larger than the actual (valid) output"""
        batch_size = self.get_batch_size(input_sizes)
        ns = self.get_ns(input_sizes)

        tensor_sizes = self.get_tensor_sizes(ns, batch_size=batch_size)
        return tensor_sizes.outputs

    def get_ns(self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]):
        """get parameter `n` for each parameterized axis
        such that the valid input size is >= the given input size"""
        ret: Dict[Tuple[TensorId, AxisId], ParameterizedSize_N] = {}
        axes = {t.id: {a.id: a for a in t.axes} for t in self.inputs}
        for tid in input_sizes:
            for aid, s in input_sizes[tid].items():
                size_descr = axes[tid][aid].size
                if isinstance(size_descr, ParameterizedSize):
                    ret[(tid, aid)] = size_descr.get_n(s)
                elif size_descr is None or isinstance(size_descr, (int, SizeReference)):
                    pass  # not parameterized; no `n` to determine
                else:
                    assert_never(size_descr)

        return ret

    def get_tensor_sizes(
        self, ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], batch_size: int
    ) -> _TensorSizes:
        """Resolve axis sizes for **ns** and regroup the flat
        (tensor_id, axis_id) -> size mappings by tensor id."""
        axis_sizes = self.get_axis_sizes(ns, batch_size=batch_size)
        return _TensorSizes(
            {
                t: {
                    aa: axis_sizes.inputs[(tt, aa)]
                    for tt, aa in axis_sizes.inputs
                    if tt == t
                }
                for t in {tt for tt, _ in axis_sizes.inputs}
            },
            {
                t: {
                    aa: axis_sizes.outputs[(tt, aa)]
                    for tt, aa in axis_sizes.outputs
                    if tt == t
                }
                for t in {tt for tt, _ in axis_sizes.outputs}
            },
        )

    def get_axis_sizes(
        self,
        ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N],
        batch_size: Optional[int] = None,
        *,
        max_input_shape: Optional[Mapping[Tuple[TensorId, AxisId], int]] = None,
    ) -> _AxisSizes:
        """Determine input and output block shape for scale factors **ns**
        of parameterized input sizes.

        Args:
            ns: Scale factor `n` for each axis (keyed by (tensor_id, axis_id))
                that is parameterized as `size = min + n * step`.
            batch_size: The desired size of the batch dimension.
                If given **batch_size** overwrites any batch size present in
                **max_input_shape**. Default 1.
            max_input_shape: Limits the derived block shapes.
                Each axis for which the input size, parameterized by `n`, is larger
                than **max_input_shape** is set to the minimal value `n_min` for which
                this is still true.
                Use this for small input samples or large values of **ns**.
                Or simply whenever you know the full input shape.

        Returns:
            Resolved axis sizes for model inputs and outputs.
        """
        max_input_shape = max_input_shape or {}
        if batch_size is None:
            # fall back to a batch size found in max_input_shape, else 1
            for (_t_id, a_id), s in max_input_shape.items():
                if a_id == BATCH_AXIS_ID:
                    batch_size = s
                    break
            else:
                batch_size = 1

        all_axes = {
            t.id: {a.id: a for a in t.axes} for t in chain(self.inputs, self.outputs)
        }

        inputs: Dict[Tuple[TensorId, AxisId], int] = {}
        outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]] = {}

        def get_axis_size(a: Union[InputAxis, OutputAxis]):
            # Resolve one axis size; `t_descr` is taken from the enclosing loop.
            # An `n` entry in **ns** is only meaningful for parameterized axes;
            # for all other axis kinds it is ignored with a warning.
            if isinstance(a, BatchAxis):
                if (t_descr.id, a.id) in ns:
                    logger.warning(
                        "Ignoring unexpected size increment factor (n) for batch axis"
                        + " of tensor '{}'.",
                        t_descr.id,
                    )
                return batch_size
            elif isinstance(a.size, int):
                if (t_descr.id, a.id) in ns:
                    logger.warning(
                        "Ignoring unexpected size increment factor (n) for fixed size"
                        + " axis '{}' of tensor '{}'.",
                        a.id,
                        t_descr.id,
                    )
                return a.size
            elif isinstance(a.size, ParameterizedSize):
                if (t_descr.id, a.id) not in ns:
                    raise ValueError(
                        "Size increment factor (n) missing for parametrized axis"
                        + f" '{a.id}' of tensor '{t_descr.id}'."
                    )
                n = ns[(t_descr.id, a.id)]
                s_max = max_input_shape.get((t_descr.id, a.id))
                if s_max is not None:
                    # cap n so the resolved size stays within max_input_shape
                    n = min(n, a.size.get_n(s_max))

                return a.size.get_size(n)

            elif isinstance(a.size, SizeReference):
                if (t_descr.id, a.id) in ns:
                    logger.warning(
                        "Ignoring unexpected size increment factor (n) for axis '{}'"
                        + " of tensor '{}' with size reference.",
                        a.id,
                        t_descr.id,
                    )
                assert not isinstance(a, BatchAxis)
                ref_axis = all_axes[a.size.tensor_id][a.size.axis_id]
                assert not isinstance(ref_axis, BatchAxis)
                ref_key = (a.size.tensor_id, a.size.axis_id)
                # the referenced axis must have been resolved already
                ref_size = inputs.get(ref_key, outputs.get(ref_key))
                assert ref_size is not None, ref_key
                assert not isinstance(ref_size, _DataDepSize), ref_key
                return a.size.get_size(
                    axis=a,
                    ref_axis=ref_axis,
                    ref_size=ref_size,
                )
            elif isinstance(a.size, DataDependentSize):
                if (t_descr.id, a.id) in ns:
                    logger.warning(
                        "Ignoring unexpected increment factor (n) for data dependent"
                        + " size axis '{}' of tensor '{}'.",
                        a.id,
                        t_descr.id,
                    )
                return _DataDepSize(a.size.min, a.size.max)
            else:
                assert_never(a.size)

        # first resolve all but the `SizeReference` input sizes
        for t_descr in self.inputs:
            for a in t_descr.axes:
                if not isinstance(a.size, SizeReference):
                    s = get_axis_size(a)
                    assert not isinstance(s, _DataDepSize)
                    inputs[t_descr.id, a.id] = s

        # resolve all other input axis sizes
        for t_descr in self.inputs:
            for a in t_descr.axes:
                if isinstance(a.size, SizeReference):
                    s = get_axis_size(a)
                    assert not isinstance(s, _DataDepSize)
                    inputs[t_descr.id, a.id] = s

        # resolve all output axis sizes
        for t_descr in self.outputs:
            for a in t_descr.axes:
                assert not isinstance(a.size, ParameterizedSize)
                s = get_axis_size(a)
                outputs[t_descr.id, a.id] = s

        return _AxisSizes(inputs=inputs, outputs=outputs)

    @model_validator(mode="before")
    @classmethod
    def _convert(cls, data: Dict[str, Any]) -> Dict[str, Any]:
        # upgrade older-format metadata in place before validation runs
        cls.convert_from_old_format_wo_validation(data)
        return data

    @classmethod
    def convert_from_old_format_wo_validation(cls, data: Dict[str, Any]) -> None:
        """Convert metadata following an older format version to this classes' format
        without validating the result.
        """
        if (
            data.get("type") == "model"
            and isinstance(fv := data.get("format_version"), str)
            and fv.count(".") == 2
        ):
            fv_parts = fv.split(".")
            if any(not p.isdigit() for p in fv_parts):
                # not a purely numeric version; leave data untouched
                return

            fv_tuple = tuple(map(int, fv_parts))

            assert cls.implemented_format_version_tuple[0:2] == (0, 5)
            if fv_tuple[:2] in ((0, 3), (0, 4)):
                # parse as 0.4 and convert; conversion is best-effort for
                # invalid 0.4 descriptions
                m04 = _ModelDescr_v0_4.load(data)
                if isinstance(m04, InvalidDescr):
                    try:
                        updated = _model_conv.convert_as_dict(
                            m04  # pyright: ignore[reportArgumentType]
                        )
                    except Exception as e:
                        logger.error(
                            "Failed to convert from invalid model 0.4 description."
                            + f"\nerror: {e}"
                            + "\nProceeding with model 0.5 validation without conversion."
                        )
                        updated = None
                else:
                    updated = _model_conv.convert_as_dict(m04)

                if updated is not None:
                    data.clear()
                    data.update(updated)

            elif fv_tuple[:2] == (0, 5):
                # bump patch version
                data["format_version"] = cls.implemented_format_version

3825 

3826 

class _ModelConv(Converter[_ModelDescr_v0_4, ModelDescr]):
    """Converter from model description format 0.4 to 0.5."""

    def _convert(
        self, src: _ModelDescr_v0_4, tgt: "type[ModelDescr] | type[dict[str, Any]]"
    ) -> "ModelDescr | dict[str, Any]":
        # 0.5 restricts allowed name characters; replace anything outside the
        # allowed set with a space instead of failing.
        name = "".join(
            c if c in string.ascii_letters + string.digits + "_+- ()" else " "
            for c in src.name
        )

        def conv_authors(auths: Optional[Sequence[_Author_v0_4]]):
            # At type-check time use the typed converter; at runtime produce
            # plain dicts so `tgt` may be `dict` without validation.
            conv = (
                _author_conv.convert if TYPE_CHECKING else _author_conv.convert_as_dict
            )
            return None if auths is None else [conv(a) for a in auths]

        # same typed-vs-dict split for architecture converters
        if TYPE_CHECKING:
            arch_file_conv = _arch_file_conv.convert
            arch_lib_conv = _arch_lib_conv.convert
        else:
            arch_file_conv = _arch_file_conv.convert_as_dict
            arch_lib_conv = _arch_lib_conv.convert_as_dict

        # map tensor name -> {axis letter: size} for resolving size references;
        # parameterized input shapes contribute their minimum size
        input_size_refs = {
            ipt.name: {
                a: s
                for a, s in zip(
                    ipt.axes,
                    (
                        ipt.shape.min
                        if isinstance(ipt.shape, _ParameterizedInputShape_v0_4)
                        else ipt.shape
                    ),
                )
            }
            for ipt in src.inputs
            if ipt.shape
        }
        # outputs may reference explicit output shapes as well as all input sizes
        output_size_refs = {
            **{
                out.name: {a: s for a, s in zip(out.axes, out.shape)}
                for out in src.outputs
                if not isinstance(out.shape, _ImplicitOutputShape_v0_4)
            },
            **input_size_refs,
        }

        return tgt(
            attachments=(
                []
                if src.attachments is None
                else [FileDescr(source=f) for f in src.attachments.files]
            ),
            authors=[_author_conv.convert_as_dict(a) for a in src.authors],  # pyright: ignore[reportArgumentType]
            cite=[{"text": c.text, "doi": c.doi, "url": c.url} for c in src.cite],  # pyright: ignore[reportArgumentType]
            config=src.config,  # pyright: ignore[reportArgumentType]
            covers=src.covers,
            description=src.description,
            documentation=src.documentation,
            format_version="0.5.7",
            git_repo=src.git_repo,  # pyright: ignore[reportArgumentType]
            icon=src.icon,
            id=None if src.id is None else ModelId(src.id),
            id_emoji=src.id_emoji,
            license=src.license,  # type: ignore
            links=src.links,
            maintainers=[_maintainer_conv.convert_as_dict(m) for m in src.maintainers],  # pyright: ignore[reportArgumentType]
            name=name,
            tags=src.tags,
            type=src.type,
            uploader=src.uploader,
            version=src.version,
            inputs=[  # pyright: ignore[reportArgumentType]
                _input_tensor_conv.convert_as_dict(ipt, tt, st, input_size_refs)
                for ipt, tt, st in zip(
                    src.inputs,
                    src.test_inputs,
                    # sample tensors are optional in 0.4; pad with None
                    src.sample_inputs or [None] * len(src.test_inputs),
                )
            ],
            outputs=[  # pyright: ignore[reportArgumentType]
                _output_tensor_conv.convert_as_dict(out, tt, st, output_size_refs)
                for out, tt, st in zip(
                    src.outputs,
                    src.test_outputs,
                    src.sample_outputs or [None] * len(src.test_outputs),
                )
            ],
            # 0.5 encodes a parent's version number as an "id/version" suffix
            parent=(
                None
                if src.parent is None
                else LinkedModel(
                    id=ModelId(
                        str(src.parent.id)
                        + (
                            ""
                            if src.parent.version_number is None
                            else f"/{src.parent.version_number}"
                        )
                    )
                )
            ),
            training_data=(
                None
                if src.training_data is None
                else (
                    LinkedDataset(
                        id=DatasetId(
                            str(src.training_data.id)
                            + (
                                ""
                                if src.training_data.version_number is None
                                else f"/{src.training_data.version_number}"
                            )
                        )
                    )
                    if isinstance(src.training_data, LinkedDataset02)
                    else src.training_data
                )
            ),
            packaged_by=[_author_conv.convert_as_dict(a) for a in src.packaged_by],  # pyright: ignore[reportArgumentType]
            run_mode=src.run_mode,
            timestamp=src.timestamp,
            # each weights entry uses `(w := ...) and (...)` so that absent
            # (None/falsy) entries pass through unchanged; versions missing in
            # 0.4 get conservative defaults (TF 1.15, torch 1.10, opset 15)
            weights=(WeightsDescr if TYPE_CHECKING else dict)(
                keras_hdf5=(w := src.weights.keras_hdf5)
                and (KerasHdf5WeightsDescr if TYPE_CHECKING else dict)(
                    authors=conv_authors(w.authors),
                    source=w.source,
                    tensorflow_version=w.tensorflow_version or Version("1.15"),
                    parent=w.parent,
                ),
                onnx=(w := src.weights.onnx)
                and (OnnxWeightsDescr if TYPE_CHECKING else dict)(
                    source=w.source,
                    authors=conv_authors(w.authors),
                    parent=w.parent,
                    opset_version=w.opset_version or 15,
                ),
                pytorch_state_dict=(w := src.weights.pytorch_state_dict)
                and (PytorchStateDictWeightsDescr if TYPE_CHECKING else dict)(
                    source=w.source,
                    authors=conv_authors(w.authors),
                    parent=w.parent,
                    architecture=(
                        arch_file_conv(
                            w.architecture,
                            w.architecture_sha256,
                            w.kwargs,
                        )
                        if isinstance(w.architecture, _CallableFromFile_v0_4)
                        else arch_lib_conv(w.architecture, w.kwargs)
                    ),
                    pytorch_version=w.pytorch_version or Version("1.10"),
                    dependencies=(
                        None
                        if w.dependencies is None
                        # strip the legacy "conda:" prefix from dependency sources
                        else (FileDescr if TYPE_CHECKING else dict)(
                            source=cast(
                                FileSource,
                                str(deps := w.dependencies)[
                                    (
                                        len("conda:")
                                        if str(deps).startswith("conda:")
                                        else 0
                                    ) :
                                ],
                            )
                        )
                    ),
                ),
                tensorflow_js=(w := src.weights.tensorflow_js)
                and (TensorflowJsWeightsDescr if TYPE_CHECKING else dict)(
                    source=w.source,
                    authors=conv_authors(w.authors),
                    parent=w.parent,
                    tensorflow_version=w.tensorflow_version or Version("1.15"),
                ),
                tensorflow_saved_model_bundle=(
                    w := src.weights.tensorflow_saved_model_bundle
                )
                and (TensorflowSavedModelBundleWeightsDescr if TYPE_CHECKING else dict)(
                    authors=conv_authors(w.authors),
                    parent=w.parent,
                    source=w.source,
                    tensorflow_version=w.tensorflow_version or Version("1.15"),
                    dependencies=(
                        None
                        if w.dependencies is None
                        else (FileDescr if TYPE_CHECKING else dict)(
                            source=cast(
                                FileSource,
                                (
                                    str(w.dependencies)[len("conda:") :]
                                    if str(w.dependencies).startswith("conda:")
                                    else str(w.dependencies)
                                ),
                            )
                        )
                    ),
                ),
                torchscript=(w := src.weights.torchscript)
                and (TorchscriptWeightsDescr if TYPE_CHECKING else dict)(
                    source=w.source,
                    authors=conv_authors(w.authors),
                    parent=w.parent,
                    pytorch_version=w.pytorch_version or Version("1.10"),
                ),
            ),
        )

4035 

4036 

# module-level singleton used by ModelDescr.convert_from_old_format_wo_validation
_model_conv = _ModelConv(_ModelDescr_v0_4, ModelDescr)

4038 

4039 

# create better cover images for 3d data and non-image outputs
def generate_covers(
    inputs: Sequence[Tuple[InputTensorDescr, NDArray[Any]]],
    outputs: Sequence[Tuple[OutputTensorDescr, NDArray[Any]]],
) -> List[Path]:
    """Generate cover image(s) from the first test input/output tensor pair.

    Both tensors are reduced to 2D RGB images. If the two images end up with
    the same shape, a single diagonally split image is written; otherwise two
    separate images (input and output) are written. PNG files are placed in a
    fresh temporary directory and their paths returned.

    Raises:
        ValueError: if `inputs` or `outputs` is empty, or a tensor cannot be
            reduced to a 2D image.
    """

    def squeeze(
        data: NDArray[Any], axes: Sequence[AnyAxis]
    ) -> Tuple[NDArray[Any], List[AnyAxis]]:
        """apply numpy.ndarray.squeeze while keeping track of the axis descriptions remaining"""
        if data.ndim != len(axes):
            raise ValueError(
                f"tensor shape {data.shape} does not match described axes"
                + f" {[a.id for a in axes]}"
            )

        # keep (deep copies of) only the axis descriptions of non-singleton dims
        axes = [deepcopy(a) for a, s in zip(axes, data.shape) if s != 1]
        return data.squeeze(), axes

    def normalize(
        data: NDArray[Any], axis: Optional[Tuple[int, ...]], eps: float = 1e-7
    ) -> NDArray[np.float32]:
        # min-max normalize to [0, 1) along `axis`; eps avoids division by zero
        data = data.astype("float32")
        data -= data.min(axis=axis, keepdims=True)
        data /= data.max(axis=axis, keepdims=True) + eps
        return data

    def to_2d_image(data: NDArray[Any], axes: Sequence[AnyAxis]) -> NDArray[Any]:
        """Reduce `data` to an (h, w, 3) uint8 RGB image by slicing/squeezing."""
        original_shape = data.shape
        original_axes = list(axes)
        data, axes = squeeze(data, axes)

        # take slice from any batch or index axis if needed
        # and convert the first channel axis and take a slice from any additional channel axes
        slices: Tuple[slice, ...] = ()
        ndim = data.ndim
        # target 3 dims if a channel axis remains after squeezing, else 2
        ndim_need = 3 if any(isinstance(a, ChannelAxis) for a in axes) else 2
        has_c_axis = False
        for i, a in enumerate(axes):
            s = data.shape[i]
            assert s > 1
            if (
                isinstance(a, (BatchAxis, IndexInputAxis, IndexOutputAxis))
                and ndim > ndim_need
            ):
                # keep a single central sample (slice of length 1 is squeezed later)
                data = data[slices + (slice(s // 2 - 1, s // 2),)]
                ndim -= 1
            elif isinstance(a, ChannelAxis):
                if has_c_axis:
                    # second channel axis
                    data = data[slices + (slice(0, 1),)]
                    ndim -= 1
                else:
                    has_c_axis = True
                    if s == 2:
                        # visualize two channels with cyan and magenta
                        data = np.concatenate(
                            [
                                data[slices + (slice(1, 2),)],
                                data[slices + (slice(0, 1),)],
                                (
                                    data[slices + (slice(0, 1),)]
                                    + data[slices + (slice(1, 2),)]
                                )
                                / 2,  # TODO: take maximum instead?
                            ],
                            axis=i,
                        )
                    elif data.shape[i] == 3:
                        pass  # visualize 3 channels as RGB
                    else:
                        # visualize first 3 channels as RGB
                        data = data[slices + (slice(3),)]

                    assert data.shape[i] == 3

            slices += (slice(None),)

        data, axes = squeeze(data, axes)
        assert len(axes) == ndim
        # take slice from z axis if needed
        slices = ()
        if ndim > ndim_need:
            for i, a in enumerate(axes):
                s = data.shape[i]
                if a.id == AxisId("z"):
                    # keep central z slice
                    data = data[slices + (slice(s // 2 - 1, s // 2),)]
                    data, axes = squeeze(data, axes)
                    ndim -= 1
                    break

                slices += (slice(None),)

        # take slice from any space or time axis
        slices = ()

        for i, a in enumerate(axes):
            if ndim <= ndim_need:
                break

            s = data.shape[i]
            assert s > 1
            if isinstance(
                a, (SpaceInputAxis, SpaceOutputAxis, TimeInputAxis, TimeOutputAxis)
            ):
                data = data[slices + (slice(s // 2 - 1, s // 2),)]
                ndim -= 1

            slices += (slice(None),)

        del slices
        data, axes = squeeze(data, axes)
        assert len(axes) == ndim

        if (has_c_axis and ndim != 3) or (not has_c_axis and ndim != 2):
            raise ValueError(
                f"Failed to construct cover image from shape {original_shape} with axes {[a.id for a in original_axes]}."
            )

        if not has_c_axis:
            # grayscale: replicate into an RGB channel axis
            assert ndim == 2
            data = np.repeat(data[:, :, None], 3, axis=2)
            axes.append(ChannelAxis(channel_names=list(map(Identifier, "RGB"))))
            ndim += 1

        assert ndim == 3

        # transpose axis order such that longest axis comes first...
        axis_order: List[int] = list(np.argsort(list(data.shape)))
        axis_order.reverse()
        # ... and channel axis is last
        c = [i for i in range(3) if isinstance(axes[i], ChannelAxis)][0]
        axis_order.append(axis_order.pop(c))
        axes = [axes[ao] for ao in axis_order]
        data = data.transpose(axis_order)

        # h, w = data.shape[:2]
        # if h / w in (1.0 or 2.0):
        #     pass
        # elif h / w < 2:
        # TODO: enforce 2:1 or 1:1 aspect ratio for generated cover images

        norm_along = (
            tuple(i for i, a in enumerate(axes) if a.type in ("space", "time")) or None
        )
        # normalize the data and map to 8 bit
        data = normalize(data, norm_along)
        data = (data * 255).astype("uint8")

        return data

    def create_diagonal_split_image(
        im0: NDArray[Any], im1: NDArray[Any]
    ) -> NDArray[Any]:
        """Combine two equally-shaped RGB images: im0 below the diagonal, im1 above."""
        assert im0.dtype == im1.dtype == np.uint8
        assert im0.shape == im1.shape
        assert im0.ndim == 3
        N, M, C = im0.shape
        assert C == 3
        out = np.ones((N, M, C), dtype="uint8")
        for c in range(C):
            # lower triangle from im0; zeros (upper triangle) filled from im1
            outc = np.tril(im0[..., c])
            mask = outc == 0
            outc[mask] = np.triu(im1[..., c])[mask]
            out[..., c] = outc

        return out

    if not inputs:
        raise ValueError("Missing test input tensor for cover generation.")

    if not outputs:
        raise ValueError("Missing test output tensor for cover generation.")

    # only the first input/output tensor pair is visualized
    ipt_descr, ipt = inputs[0]
    out_descr, out = outputs[0]

    ipt_img = to_2d_image(ipt, ipt_descr.axes)
    out_img = to_2d_image(out, out_descr.axes)

    cover_folder = Path(mkdtemp())
    if ipt_img.shape == out_img.shape:
        covers = [cover_folder / "cover.png"]
        imwrite(covers[0], create_diagonal_split_image(ipt_img, out_img))
    else:
        covers = [cover_folder / "input.png", cover_folder / "output.png"]
        imwrite(covers[0], ipt_img)
        imwrite(covers[1], out_img)

    return covers