Coverage for src / bioimageio / spec / model / v0_5.py: 76%

1581 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-03-31 13:09 +0000

1from __future__ import annotations 

2 

3import collections.abc 

4import re 

5import string 

6import warnings 

7from copy import deepcopy 

8from itertools import chain 

9from math import ceil 

10from pathlib import Path, PurePosixPath 

11from tempfile import mkdtemp 

12from textwrap import dedent 

13from typing import ( 

14 TYPE_CHECKING, 

15 Any, 

16 Callable, 

17 ClassVar, 

18 Dict, 

19 Generic, 

20 List, 

21 Literal, 

22 Mapping, 

23 NamedTuple, 

24 Optional, 

25 Sequence, 

26 Set, 

27 Tuple, 

28 Type, 

29 TypeVar, 

30 Union, 

31 cast, 

32 overload, 

33) 

34 

35import numpy as np 

36from annotated_types import Ge, Gt, Interval, MaxLen, MinLen, Predicate 

37from imageio.v3 import imread, imwrite # pyright: ignore[reportUnknownVariableType] 

38from loguru import logger 

39from numpy.typing import NDArray 

40from pydantic import ( 

41 AfterValidator, 

42 Discriminator, 

43 Field, 

44 RootModel, 

45 SerializationInfo, 

46 SerializerFunctionWrapHandler, 

47 StrictInt, 

48 Tag, 

49 ValidationInfo, 

50 WrapSerializer, 

51 field_validator, 

52 model_serializer, 

53 model_validator, 

54) 

55from typing_extensions import Annotated, Self, assert_never, get_args 

56 

57from .._internal.common_nodes import ( 

58 InvalidDescr, 

59 KwargsNode, 

60 Node, 

61 NodeWithExplicitlySetFields, 

62) 

63from .._internal.constants import DTYPE_LIMITS 

64from .._internal.field_warning import issue_warning, warn 

65from .._internal.io import BioimageioYamlContent as BioimageioYamlContent 

66from .._internal.io import FileDescr as FileDescr 

67from .._internal.io import ( 

68 FileSource, 

69 WithSuffix, 

70 YamlValue, 

71 extract_file_name, 

72 get_reader, 

73 wo_special_file_name, 

74) 

75from .._internal.io_basics import Sha256 as Sha256 

76from .._internal.io_packaging import ( 

77 FileDescr_, 

78 FileSource_, 

79 package_file_descr_serializer, 

80) 

81from .._internal.io_utils import load_array 

82from .._internal.node_converter import Converter 

83from .._internal.type_guards import is_dict, is_sequence 

84from .._internal.types import ( 

85 FAIR, 

86 AbsoluteTolerance, 

87 LowerCaseIdentifier, 

88 LowerCaseIdentifierAnno, 

89 MismatchedElementsPerMillion, 

90 RelativeTolerance, 

91) 

92from .._internal.types import Datetime as Datetime 

93from .._internal.types import Identifier as Identifier 

94from .._internal.types import NotEmpty as NotEmpty 

95from .._internal.types import SiUnit as SiUnit 

96from .._internal.url import HttpUrl as HttpUrl 

97from .._internal.validation_context import get_validation_context 

98from .._internal.validator_annotations import RestrictCharacters 

99from .._internal.version_type import Version as Version 

100from .._internal.warning_levels import INFO 

101from ..dataset.v0_2 import DatasetDescr as DatasetDescr02 

102from ..dataset.v0_2 import LinkedDataset as LinkedDataset02 

103from ..dataset.v0_3 import DatasetDescr as DatasetDescr 

104from ..dataset.v0_3 import DatasetId as DatasetId 

105from ..dataset.v0_3 import LinkedDataset as LinkedDataset 

106from ..dataset.v0_3 import Uploader as Uploader 

107from ..generic.v0_3 import ( 

108 VALID_COVER_IMAGE_EXTENSIONS as VALID_COVER_IMAGE_EXTENSIONS, 

109) 

110from ..generic.v0_3 import Author as Author 

111from ..generic.v0_3 import BadgeDescr as BadgeDescr 

112from ..generic.v0_3 import CiteEntry as CiteEntry 

113from ..generic.v0_3 import DeprecatedLicenseId as DeprecatedLicenseId 

114from ..generic.v0_3 import Doi as Doi 

115from ..generic.v0_3 import ( 

116 FileSource_documentation, 

117 GenericModelDescrBase, 

118 LinkedResourceBase, 

119 _author_conv, # pyright: ignore[reportPrivateUsage] 

120 _maintainer_conv, # pyright: ignore[reportPrivateUsage] 

121) 

122from ..generic.v0_3 import LicenseId as LicenseId 

123from ..generic.v0_3 import LinkedResource as LinkedResource 

124from ..generic.v0_3 import Maintainer as Maintainer 

125from ..generic.v0_3 import OrcidId as OrcidId 

126from ..generic.v0_3 import RelativeFilePath as RelativeFilePath 

127from ..generic.v0_3 import ResourceId as ResourceId 

128from .v0_4 import Author as _Author_v0_4 

129from .v0_4 import BinarizeDescr as _BinarizeDescr_v0_4 

130from .v0_4 import CallableFromDepencency as CallableFromDepencency 

131from .v0_4 import CallableFromDepencency as _CallableFromDepencency_v0_4 

132from .v0_4 import CallableFromFile as _CallableFromFile_v0_4 

133from .v0_4 import ClipDescr as _ClipDescr_v0_4 

134from .v0_4 import ImplicitOutputShape as _ImplicitOutputShape_v0_4 

135from .v0_4 import InputTensorDescr as _InputTensorDescr_v0_4 

136from .v0_4 import KnownRunMode as KnownRunMode 

137from .v0_4 import ModelDescr as _ModelDescr_v0_4 

138from .v0_4 import OutputTensorDescr as _OutputTensorDescr_v0_4 

139from .v0_4 import ParameterizedInputShape as _ParameterizedInputShape_v0_4 

140from .v0_4 import PostprocessingDescr as _PostprocessingDescr_v0_4 

141from .v0_4 import PreprocessingDescr as _PreprocessingDescr_v0_4 

142from .v0_4 import RunMode as RunMode 

143from .v0_4 import ScaleLinearDescr as _ScaleLinearDescr_v0_4 

144from .v0_4 import ScaleMeanVarianceDescr as _ScaleMeanVarianceDescr_v0_4 

145from .v0_4 import ScaleRangeDescr as _ScaleRangeDescr_v0_4 

146from .v0_4 import SigmoidDescr as _SigmoidDescr_v0_4 

147from .v0_4 import TensorName as _TensorName_v0_4 

148from .v0_4 import ZeroMeanUnitVarianceDescr as _ZeroMeanUnitVarianceDescr_v0_4 

149from .v0_4 import package_weights 

150 

# Unit vocabularies below mirror the OME-Zarr axes specification so that
# bioimage.io axis metadata stays interoperable with OME-Zarr.
SpaceUnit = Literal[
    "attometer",
    "angstrom",
    "centimeter",
    "decimeter",
    "exameter",
    "femtometer",
    "foot",
    "gigameter",
    "hectometer",
    "inch",
    "kilometer",
    "megameter",
    "meter",
    "micrometer",
    "mile",
    "millimeter",
    "nanometer",
    "parsec",
    "petameter",
    "picometer",
    "terameter",
    "yard",
    "yoctometer",
    "yottameter",
    "zeptometer",
    "zettameter",
]
"""Space unit compatible to the [OME-Zarr axes specification 0.5](https://ngff.openmicroscopy.org/0.5/#axes-md)"""

TimeUnit = Literal[
    "attosecond",
    "centisecond",
    "day",
    "decisecond",
    "exasecond",
    "femtosecond",
    "gigasecond",
    "hectosecond",
    "hour",
    "kilosecond",
    "megasecond",
    "microsecond",
    "millisecond",
    "minute",
    "nanosecond",
    "petasecond",
    "picosecond",
    "second",
    "terasecond",
    "yoctosecond",
    "yottasecond",
    "zeptosecond",
    "zettasecond",
]
"""Time unit compatible to the [OME-Zarr axes specification 0.5](https://ngff.openmicroscopy.org/0.5/#axes-md)"""

207 

AxisType = Literal["batch", "channel", "index", "time", "space"]

# Maps single-letter axis names (as used e.g. in v0.4-style axis strings)
# to their axis type.
_AXIS_TYPE_MAP: Mapping[str, AxisType] = {
    "b": "batch",
    "t": "time",
    "i": "index",
    "c": "channel",
    "x": "space",
    "y": "space",
    "z": "space",
}

# Maps single-letter axis ids to their long-form ids; the spatial ids
# 'x'/'y'/'z' keep their one-letter form and are deliberately absent here.
_AXIS_ID_MAP = {
    "b": "batch",
    "t": "time",
    "i": "index",
    "c": "channel",
}

226 

# The weight formats a model description may ship.
WeightsFormat = Literal[
    "keras_hdf5",
    "keras_v3",
    "onnx",
    "pytorch_state_dict",
    "tensorflow_js",
    "tensorflow_saved_model_bundle",
    "torchscript",
]

236 

237 

class TensorId(LowerCaseIdentifier):
    # Restricts the generic lower-case identifier to at most 32 characters.
    root_model: ClassVar[Type[RootModel[Any]]] = RootModel[
        Annotated[LowerCaseIdentifierAnno, MaxLen(32)]
    ]

242 

243 

def _normalize_axis_id(a: str):
    """Map single-letter axis ids ('b', 't', 'i', 'c') to their long forms.

    Ids without a short-form entry are returned unchanged; a warning is
    logged whenever a normalization actually happens.
    """
    raw = str(a)
    canonical = _AXIS_ID_MAP.get(raw, raw)
    if canonical != raw:
        # depth=3 attributes the log record to the original caller rather
        # than to the validation machinery invoking this function.
        logger.opt(depth=3).warning(
            "Normalized axis id from '{}' to '{}'.", raw, canonical
        )
    return canonical

252 

253 

class AxisId(LowerCaseIdentifier):
    # At most 16 characters; single-letter ids 'b'/'t'/'i'/'c' are normalized
    # to 'batch'/'time'/'index'/'channel' (with a logged warning).
    root_model: ClassVar[Type[RootModel[Any]]] = RootModel[
        Annotated[
            LowerCaseIdentifierAnno,
            MaxLen(16),
            AfterValidator(_normalize_axis_id),
        ]
    ]

262 

263 

264def _is_batch(a: str) -> bool: 

265 return str(a) == "batch" 

266 

267 

268def _is_not_batch(a: str) -> bool: 

269 return not _is_batch(a) 

270 

271 

# An axis id that is guaranteed not to be "batch".
NonBatchAxisId = Annotated[AxisId, Predicate(_is_not_batch)]

PreprocessingId = Literal[
    "binarize",
    "clip",
    "ensure_dtype",
    "fixed_zero_mean_unit_variance",
    "scale_linear",
    "scale_range",
    "sigmoid",
    "softmax",
]
PostprocessingId = Literal[
    "binarize",
    "clip",
    "ensure_dtype",
    "fixed_zero_mean_unit_variance",
    "scale_linear",
    "scale_mean_variance",
    "scale_range",
    "sigmoid",
    "softmax",
    "zero_mean_unit_variance",
]


# Sentinel string; presumably stands in for "axis id defaults to the axis
# type name" — usage is not visible in this chunk, confirm against the file.
SAME_AS_TYPE = "<same as type>"


ParameterizedSize_N = int
"""
Annotates an integer to calculate a concrete axis size from a `ParameterizedSize`.
"""

305 

306 

class ParameterizedSize(Node):
    """Describes a range of valid tensor axis sizes as `size = min + n*step`.

    - **min** and **step** are given by the model description.
    - All blocksize parameters n = 0,1,2,... yield a valid `size`.
    - A greater blocksize parameter n = 0,1,2,... results in a greater **size**.
      This allows to adjust the axis size more generically.
    """

    N: ClassVar[Type[int]] = ParameterizedSize_N
    """Positive integer to parameterize this axis"""

    min: Annotated[int, Gt(0)]
    step: Annotated[int, Gt(0)]

    def validate_size(self, size: int, msg_prefix: str = "") -> int:
        """Return `size` if it is expressible as `min + n*step`, else raise ValueError."""
        if size < self.min:
            raise ValueError(
                f"{msg_prefix}size {size} < {self.min} (minimum axis size)"
            )
        if (size - self.min) % self.step != 0:
            raise ValueError(
                f"{msg_prefix}size {size} is not parameterized by `min + n*step` ="
                + f" `{self.min} + n*{self.step}`"
            )

        return size

    def get_size(self, n: ParameterizedSize_N) -> int:
        """Return the concrete axis size for blocksize parameter `n`."""
        return self.min + self.step * n

    def get_n(self, s: int) -> ParameterizedSize_N:
        """return smallest n parameterizing a size greater or equal than `s`

        Note: valid parameters are n = 0,1,2,..., so any `s` <= `min` maps
        to n = 0 (previously a negative — and thus invalid — n was returned).
        """
        return max(0, ceil((s - self.min) / self.step))

341 

342 

# An axis size only known after inference, optionally bounded by min/max.
class DataDependentSize(Node):
    min: Annotated[int, Gt(0)] = 1
    max: Annotated[Optional[int], Gt(1)] = None

    @model_validator(mode="after")
    def _validate_max_gt_min(self):
        # `max` is optional; when present it must strictly exceed `min`.
        if self.max is not None and not (self.min < self.max):
            raise ValueError(f"expected `min` < `max`, but got {self.min}, {self.max}")

        return self

    def validate_size(self, size: int, msg_prefix: str = "") -> int:
        """Return `size` if it lies within [min, max], else raise ValueError."""
        if size < self.min:
            raise ValueError(f"{msg_prefix}size {size} < {self.min}")

        too_large = self.max is not None and size > self.max
        if too_large:
            raise ValueError(f"{msg_prefix}size {size} > {self.max}")

        return size

362 

363 

class SizeReference(Node):
    """A tensor axis size (extent in pixels/frames) defined in relation to a reference axis.

    `axis.size = reference.size * reference.scale / axis.scale + offset`

    Note:
    1. The axis and the referenced axis need to have the same unit (or no unit).
    2. Batch axes may not be referenced.
    3. Fractions are rounded down.
    4. If the reference axis is `concatenable` the referencing axis is assumed to be
       `concatenable` as well with the same block order.

    Example:
        An anisotropic input image of w*h=100*49 pixels depicts a physical space of 200*196mm².
        Let's assume that we want to express the image height h in relation to its width w
        instead of only accepting input images of exactly 100*49 pixels
        (for example to express a range of valid image shapes by parametrizing w, see `ParameterizedSize`).

        >>> w = SpaceInputAxis(id=AxisId("w"), size=100, unit="millimeter", scale=2)
        >>> h = SpaceInputAxis(
        ...     id=AxisId("h"),
        ...     size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("w"), offset=-1),
        ...     unit="millimeter",
        ...     scale=4,
        ... )
        >>> print(h.size.get_size(h, w))
        49

        ⇒ h = w * w.scale / h.scale + offset = 100 * 2mm / 4mm - 1 = 49
    """

    tensor_id: TensorId
    """tensor id of the reference axis"""

    axis_id: AxisId
    """axis id of the reference axis"""

    offset: StrictInt = 0

    def get_size(
        self,
        axis: Union[
            ChannelAxis,
            IndexInputAxis,
            IndexOutputAxis,
            TimeInputAxis,
            SpaceInputAxis,
            TimeOutputAxis,
            TimeOutputAxisWithHalo,
            SpaceOutputAxis,
            SpaceOutputAxisWithHalo,
        ],
        ref_axis: Union[
            ChannelAxis,
            IndexInputAxis,
            IndexOutputAxis,
            TimeInputAxis,
            SpaceInputAxis,
            TimeOutputAxis,
            TimeOutputAxisWithHalo,
            SpaceOutputAxis,
            SpaceOutputAxisWithHalo,
        ],
        n: ParameterizedSize_N = 0,
        ref_size: Optional[int] = None,
    ):
        """Compute the concrete size for a given axis and its reference axis.

        Args:
            axis: The axis this [SizeReference][] is the size of.
            ref_axis: The reference axis to compute the size from.
            n: If the **ref_axis** is parameterized (of type `ParameterizedSize`)
                and no fixed **ref_size** is given,
                **n** is used to compute the size of the parameterized **ref_axis**.
            ref_size: Overwrite the reference size instead of deriving it from
                **ref_axis**
                (**ref_axis.scale** is still used; any given **n** is ignored).
        """
        assert axis.size == self, (
            "Given `axis.size` is not defined by this `SizeReference`"
        )

        assert ref_axis.id == self.axis_id, (
            f"Expected `ref_axis.id` to be {self.axis_id}, but got {ref_axis.id}."
        )

        assert axis.unit == ref_axis.unit, (
            "`SizeReference` requires `axis` and `ref_axis` to have the same `unit`,"
            f" but {axis.unit}!={ref_axis.unit}"
        )
        if ref_size is None:
            # derive the reference size from the reference axis' own size field
            if isinstance(ref_axis.size, (int, float)):
                ref_size = ref_axis.size
            elif isinstance(ref_axis.size, ParameterizedSize):
                ref_size = ref_axis.size.get_size(n)
            elif isinstance(ref_axis.size, DataDependentSize):
                raise ValueError(
                    "Reference axis referenced in `SizeReference` may not be a `DataDependentSize`."
                )
            elif isinstance(ref_axis.size, SizeReference):
                raise ValueError(
                    "Reference axis referenced in `SizeReference` may not be sized by a"
                    + " `SizeReference` itself."
                )
            else:
                assert_never(ref_axis.size)

        # fractions are rounded down (see class docstring, note 3)
        return int(ref_size * ref_axis.scale / axis.scale + self.offset)

    @staticmethod
    def _get_unit(
        axis: Union[
            ChannelAxis,
            IndexInputAxis,
            IndexOutputAxis,
            TimeInputAxis,
            SpaceInputAxis,
            TimeOutputAxis,
            TimeOutputAxisWithHalo,
            SpaceOutputAxis,
            SpaceOutputAxisWithHalo,
        ],
    ):
        # thin accessor used where only the unit of any axis kind is needed
        return axis.unit

488 

489 

# Fields common to every axis description.
class AxisBase(NodeWithExplicitlySetFields):
    id: AxisId
    """An axis id unique across all axes of one tensor."""

    description: Annotated[str, MaxLen(128)] = ""
    """A short description of this axis beyond its type and id."""

496 

497 

# Mixin for output axes that carry a halo; forces `size` to be a SizeReference.
class WithHalo(Node):
    halo: Annotated[int, Ge(1)]
    """The halo should be cropped from the output tensor to avoid boundary effects.
    It is to be cropped from both sides, i.e. `size_after_crop = size - 2 * halo`.
    To document a halo that is already cropped by the model use `size.offset` instead."""

    size: Annotated[
        SizeReference,
        Field(
            examples=[
                10,
                SizeReference(
                    tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
                ).model_dump(mode="json"),
            ]
        ),
    ]
    """reference to another axis with an optional offset (see [SizeReference][])"""

516 

517 

# Canonical id shared by all batch axes.
BATCH_AXIS_ID = AxisId("batch")

519 

520 

class BatchAxis(AxisBase):
    # `type` is the union discriminator. Under TYPE_CHECKING it is declared
    # with a default for ergonomic construction; at runtime it has no default
    # (presumably populated from `implemented_type` by
    # NodeWithExplicitlySetFields — confirm against common_nodes).
    implemented_type: ClassVar[Literal["batch"]] = "batch"
    if TYPE_CHECKING:
        type: Literal["batch"] = "batch"
    else:
        type: Literal["batch"]

    id: Annotated[AxisId, Predicate(_is_batch)] = BATCH_AXIS_ID
    size: Optional[Literal[1]] = None
    """The batch size may be fixed to 1,
    otherwise (the default) it may be chosen arbitrarily depending on available memory"""

    # Properties below give BatchAxis the same read interface as sized axes.
    @property
    def scale(self):
        return 1.0

    @property
    def concatenable(self):
        # batches can always be split/concatenated along this axis
        return True

    @property
    def unit(self):
        return None

544 

545 

class ChannelAxis(AxisBase):
    # discriminator setup analogous to the other axis classes
    implemented_type: ClassVar[Literal["channel"]] = "channel"
    if TYPE_CHECKING:
        type: Literal["channel"] = "channel"
    else:
        type: Literal["channel"]

    id: NonBatchAxisId = AxisId("channel")

    channel_names: NotEmpty[List[Identifier]]

    @property
    def size(self) -> int:
        # a channel axis is implicitly sized by its named channels
        return len(self.channel_names)

    @property
    def concatenable(self):
        return False

    @property
    def scale(self) -> float:
        return 1.0

    @property
    def unit(self):
        return None

572 

573 

# Shared base for index axes (dimensionless enumeration axes).
class IndexAxisBase(AxisBase):
    implemented_type: ClassVar[Literal["index"]] = "index"
    if TYPE_CHECKING:
        type: Literal["index"] = "index"
    else:
        type: Literal["index"]

    id: NonBatchAxisId = AxisId("index")

    @property
    def scale(self) -> float:
        return 1.0

    @property
    def unit(self):
        return None

590 

591 

# Mixin providing the `size` field shared by all sized *input* axes.
class _WithInputAxisSize(Node):
    size: Annotated[
        Union[Annotated[int, Gt(0)], ParameterizedSize, SizeReference],
        Field(
            examples=[
                10,
                ParameterizedSize(min=32, step=16).model_dump(mode="json"),
                SizeReference(
                    tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
                ).model_dump(mode="json"),
            ]
        ),
    ]
    """The size/length of this axis can be specified as
    - fixed integer
    - parameterized series of valid sizes ([ParameterizedSize][])
    - reference to another axis with an optional offset ([SizeReference][])
    """

610 

611 

class IndexInputAxis(IndexAxisBase, _WithInputAxisSize):
    concatenable: bool = False
    """If a model has a `concatenable` input axis, it can be processed blockwise,
    splitting a longer sample axis into blocks matching its input tensor description.
    Output axes are concatenable if they have a [SizeReference][] to a concatenable
    input axis.
    """

619 

620 

class IndexOutputAxis(IndexAxisBase):
    # unlike input axes, output index axes additionally allow DataDependentSize
    size: Annotated[
        Union[Annotated[int, Gt(0)], SizeReference, DataDependentSize],
        Field(
            examples=[
                10,
                SizeReference(
                    tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
                ).model_dump(mode="json"),
            ]
        ),
    ]
    """The size/length of this axis can be specified as
    - fixed integer
    - reference to another axis with an optional offset ([SizeReference][])
    - data dependent size using [DataDependentSize][] (size is only known after model inference)
    """

638 

639 

# Shared base for time axes; adds the physical unit and scale.
class TimeAxisBase(AxisBase):
    implemented_type: ClassVar[Literal["time"]] = "time"
    if TYPE_CHECKING:
        type: Literal["time"] = "time"
    else:
        type: Literal["time"]

    id: NonBatchAxisId = AxisId("time")
    unit: Optional[TimeUnit] = None
    scale: Annotated[float, Gt(0)] = 1.0

650 

651 

class TimeInputAxis(TimeAxisBase, _WithInputAxisSize):
    concatenable: bool = False
    """If a model has a `concatenable` input axis, it can be processed blockwise,
    splitting a longer sample axis into blocks matching its input tensor description.
    Output axes are concatenable if they have a [SizeReference][] to a concatenable
    input axis.
    """

659 

660 

# Shared base for space axes; adds the physical unit and scale.
class SpaceAxisBase(AxisBase):
    implemented_type: ClassVar[Literal["space"]] = "space"
    if TYPE_CHECKING:
        type: Literal["space"] = "space"
    else:
        type: Literal["space"]

    id: Annotated[NonBatchAxisId, Field(examples=["x", "y", "z"])] = AxisId("x")
    unit: Optional[SpaceUnit] = None
    scale: Annotated[float, Gt(0)] = 1.0

671 

672 

class SpaceInputAxis(SpaceAxisBase, _WithInputAxisSize):
    concatenable: bool = False
    """If a model has a `concatenable` input axis, it can be processed blockwise,
    splitting a longer sample axis into blocks matching its input tensor description.
    Output axes are concatenable if they have a [SizeReference][] to a concatenable
    input axis.
    """

680 

681 

INPUT_AXIS_TYPES = (
    BatchAxis,
    ChannelAxis,
    IndexInputAxis,
    TimeInputAxis,
    SpaceInputAxis,
)
"""intended for isinstance comparisons in py<3.10"""

_InputAxisUnion = Union[
    BatchAxis, ChannelAxis, IndexInputAxis, TimeInputAxis, SpaceInputAxis
]
# Discriminated on the `type` field for efficient pydantic validation.
InputAxis = Annotated[_InputAxisUnion, Discriminator("type")]

695 

696 

# Mixin providing the `size` field shared by halo-free sized *output* axes.
class _WithOutputAxisSize(Node):
    size: Annotated[
        Union[Annotated[int, Gt(0)], SizeReference],
        Field(
            examples=[
                10,
                SizeReference(
                    tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
                ).model_dump(mode="json"),
            ]
        ),
    ]
    """The size/length of this axis can be specified as
    - fixed integer
    - reference to another axis with an optional offset (see [SizeReference][])
    """

713 

714 

class TimeOutputAxis(TimeAxisBase, _WithOutputAxisSize):
    # time output axis without a halo
    pass


class TimeOutputAxisWithHalo(TimeAxisBase, WithHalo):
    # time output axis with a halo; `size` must be a SizeReference (see WithHalo)
    pass

721 

722 

723def _get_halo_axis_discriminator_value(v: Any) -> Literal["with_halo", "wo_halo"]: 

724 if isinstance(v, dict): 

725 return "with_halo" if "halo" in v else "wo_halo" 

726 else: 

727 return "with_halo" if hasattr(v, "halo") else "wo_halo" 

728 

729 

# Both variants share type="time", so discriminate on halo presence instead.
_TimeOutputAxisUnion = Annotated[
    Union[
        Annotated[TimeOutputAxis, Tag("wo_halo")],
        Annotated[TimeOutputAxisWithHalo, Tag("with_halo")],
    ],
    Discriminator(_get_halo_axis_discriminator_value),
]

737 

738 

class SpaceOutputAxis(SpaceAxisBase, _WithOutputAxisSize):
    # space output axis without a halo
    pass


class SpaceOutputAxisWithHalo(SpaceAxisBase, WithHalo):
    # space output axis with a halo; `size` must be a SizeReference (see WithHalo)
    pass


# Both variants share type="space", so discriminate on halo presence instead.
_SpaceOutputAxisUnion = Annotated[
    Union[
        Annotated[SpaceOutputAxis, Tag("wo_halo")],
        Annotated[SpaceOutputAxisWithHalo, Tag("with_halo")],
    ],
    Discriminator(_get_halo_axis_discriminator_value),
]

754 

755 

_OutputAxisUnion = Union[
    BatchAxis, ChannelAxis, IndexOutputAxis, _TimeOutputAxisUnion, _SpaceOutputAxisUnion
]
# Outer union discriminates on `type`; the nested time/space unions further
# discriminate on halo presence.
OutputAxis = Annotated[_OutputAxisUnion, Discriminator("type")]

OUTPUT_AXIS_TYPES = (
    BatchAxis,
    ChannelAxis,
    IndexOutputAxis,
    TimeOutputAxis,
    TimeOutputAxisWithHalo,
    SpaceOutputAxis,
    SpaceOutputAxisWithHalo,
)
"""intended for isinstance comparisons in py<3.10"""


AnyAxis = Union[InputAxis, OutputAxis]

ANY_AXIS_TYPES = INPUT_AXIS_TYPES + OUTPUT_AXIS_TYPES
"""intended for isinstance comparisons in py<3.10"""

# Allowed value lists for nominal/ordinal tensor data (homogeneous per list).
TVs = Union[
    NotEmpty[List[int]],
    NotEmpty[List[float]],
    NotEmpty[List[bool]],
    NotEmpty[List[str]],
]

784 

785 

NominalOrOrdinalDType = Literal[
    "float32",
    "float64",
    "uint8",
    "int8",
    "uint16",
    "int16",
    "uint32",
    "int32",
    "uint64",
    "int64",
    "bool",
]

799 

800 

class NominalOrOrdinalDataDescr(Node):
    values: TVs
    """A fixed set of nominal or an ascending sequence of ordinal values.
    In this case `data.type` is required to be an unsigned integer type, e.g. 'uint8'.
    String `values` are interpreted as labels for tensor values 0, ..., N.
    Note: as YAML 1.2 does not natively support a "set" datatype,
    nominal values should be given as a sequence (aka list/array) as well.
    """

    type: Annotated[
        NominalOrOrdinalDType,
        Field(
            examples=[
                "float32",
                "uint8",
                "uint16",
                "int64",
                "bool",
            ],
        ),
    ] = "uint8"

    @model_validator(mode="after")
    def _validate_values_match_type(
        self,
    ) -> Self:
        # Collect values that cannot be represented by `self.type`.
        incompatible: List[Any] = []
        for v in self.values:
            if self.type == "bool":
                if not isinstance(v, bool):
                    incompatible.append(v)
            elif self.type in DTYPE_LIMITS:
                # numeric value out of the dtype's range, string label for a
                # non-uint dtype, or float value for an integer dtype
                if (
                    isinstance(v, (int, float))
                    and (
                        v < DTYPE_LIMITS[self.type].min
                        or v > DTYPE_LIMITS[self.type].max
                    )
                    or (isinstance(v, str) and "uint" not in self.type)
                    or (isinstance(v, float) and "int" in self.type)
                ):
                    incompatible.append(v)
            else:
                incompatible.append(v)

            # report at most 5 offending values, then elide the rest
            if len(incompatible) == 5:
                incompatible.append("...")
                break

        if incompatible:
            raise ValueError(
                f"data type '{self.type}' incompatible with values {incompatible}"
            )

        return self

    unit: Optional[Union[Literal["arbitrary unit"], SiUnit]] = None

    @property
    def range(self):
        # string labels stand for tensor values 0..N (see `values` docstring)
        if isinstance(self.values[0], str):
            return 0, len(self.values) - 1
        else:
            return min(self.values), max(self.values)

865 

866 

# Like NominalOrOrdinalDType, but without "bool".
IntervalOrRatioDType = Literal[
    "float32",
    "float64",
    "uint8",
    "int8",
    "uint16",
    "int16",
    "uint32",
    "int32",
    "uint64",
    "int64",
]

879 

880 

class IntervalOrRatioDataDescr(Node):
    type: Annotated[  # TODO: rename to dtype
        IntervalOrRatioDType,
        Field(
            examples=["float32", "float64", "uint8", "uint16"],
        ),
    ] = "float32"
    range: Tuple[Optional[float], Optional[float]] = (
        None,
        None,
    )
    """Tuple `(minimum, maximum)` specifying the allowed range of the data in this tensor.
    `None` corresponds to min/max of what can be expressed by **type**."""
    unit: Union[Literal["arbitrary unit"], SiUnit] = "arbitrary unit"
    scale: float = 1.0
    """Scale for data on an interval (or ratio) scale."""
    offset: Optional[float] = None
    """Offset for data on a ratio scale."""

    @model_validator(mode="before")
    def _replace_inf(cls, data: Any):
        # Map infinite range bounds (both YAML spellings and Python floats)
        # to None ("unbounded"), warning about the replacement.
        if is_dict(data):
            if "range" in data and is_sequence(data["range"]):
                forbidden = (
                    "inf",
                    "-inf",
                    ".inf",
                    "-.inf",
                    float("inf"),
                    float("-inf"),
                )
                if any(v in forbidden for v in data["range"]):
                    issue_warning("replaced 'inf' value", value=data["range"])

                data["range"] = tuple(
                    (None if v in forbidden else v) for v in data["range"]
                )

        return data

920 

921 

# Either kind of tensor data description.
TensorDataDescr = Union[NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr]

923 

924 

class BinarizeKwargs(KwargsNode):
    """key word arguments for [BinarizeDescr][]"""

    threshold: float
    """The fixed threshold"""

930 

931 

class BinarizeAlongAxisKwargs(KwargsNode):
    """key word arguments for [BinarizeDescr][]"""

    threshold: NotEmpty[List[float]]
    """The fixed threshold values along `axis`"""

    axis: Annotated[NonBatchAxisId, Field(examples=["channel"])]
    """The `threshold` axis"""

940 

941 

class BinarizeDescr(NodeWithExplicitlySetFields):
    """Binarize the tensor with a fixed threshold.

    Values above [BinarizeKwargs.threshold][]/[BinarizeAlongAxisKwargs.threshold][]
    will be set to one, values below the threshold to zero.

    Examples:
    - in YAML
        ```yaml
        postprocessing:
          - id: binarize
            kwargs:
              axis: 'channel'
              threshold: [0.25, 0.5, 0.75]
        ```
    - in Python:

        >>> postprocessing = [BinarizeDescr(
        ...   kwargs=BinarizeAlongAxisKwargs(
        ...       axis=AxisId('channel'),
        ...       threshold=[0.25, 0.5, 0.75],
        ...   )
        ... )]
    """

    # `id` is the processing-op discriminator (same TYPE_CHECKING pattern as
    # the axis classes' `type` field).
    implemented_id: ClassVar[Literal["binarize"]] = "binarize"
    if TYPE_CHECKING:
        id: Literal["binarize"] = "binarize"
    else:
        id: Literal["binarize"]
    kwargs: Union[BinarizeKwargs, BinarizeAlongAxisKwargs]

973 

974 

class ClipKwargs(KwargsNode):
    """key word arguments for [ClipDescr][]"""

    min: Optional[float] = None
    """Minimum value for clipping.

    Exclusive with [min_percentile][]
    """
    min_percentile: Optional[Annotated[float, Interval(ge=0, lt=100)]] = None
    """Minimum percentile for clipping.

    Exclusive with [min][].

    In range [0, 100).
    """

    max: Optional[float] = None
    """Maximum value for clipping.

    Exclusive with `max_percentile`.
    """
    max_percentile: Optional[Annotated[float, Interval(gt=1, le=100)]] = None
    """Maximum percentile for clipping.

    Exclusive with `max`.

    In range (1, 100].
    """

    axes: Annotated[
        Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
    ] = None
    """The subset of axes to determine percentiles jointly,

    i.e. axes to reduce to compute min/max from `min_percentile`/`max_percentile`.
    For example to clip 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
    resulting in a tensor of equal shape with clipped values per channel, specify `axes=('batch', 'x', 'y')`.
    To clip samples independently, leave out the 'batch' axis.

    Only valid if `min_percentile` and/or `max_percentile` are set.

    Default: Compute percentiles over all axes jointly."""

    @model_validator(mode="after")
    def _validate(self) -> Self:
        # Enforce the mutual exclusivity documented on the fields above.
        if (self.min is not None) and (self.min_percentile is not None):
            raise ValueError(
                "Only one of `min` and `min_percentile` may be set, not both."
            )
        if (self.max is not None) and (self.max_percentile is not None):
            raise ValueError(
                "Only one of `max` and `max_percentile` may be set, not both."
            )
        # At least one bound must be given, else clipping would be a no-op.
        if (
            self.min is None
            and self.min_percentile is None
            and self.max is None
            and self.max_percentile is None
        ):
            raise ValueError(
                "At least one of `min`, `min_percentile`, `max`, or `max_percentile` must be set."
            )

        # `axes` only influences percentile computation; reject it otherwise.
        if (
            self.axes is not None
            and self.min_percentile is None
            and self.max_percentile is None
        ):
            raise ValueError(
                "If `axes` is set, at least one of `min_percentile` or `max_percentile` must be set."
            )

        return self

1048 

1049 

class ClipDescr(NodeWithExplicitlySetFields):
    """Set tensor values below min to min and above max to max.

    See `ScaleRangeDescr` for examples.
    """

    implemented_id: ClassVar[Literal["clip"]] = "clip"
    if TYPE_CHECKING:
        id: Literal["clip"] = "clip"
    else:
        id: Literal["clip"]

    kwargs: ClipKwargs

1063 

1064 

class EnsureDtypeKwargs(KwargsNode):
    """key word arguments for [EnsureDtypeDescr][]"""

    # target data type to cast to
    dtype: Literal[
        "float32",
        "float64",
        "uint8",
        "int8",
        "uint16",
        "int16",
        "uint32",
        "int32",
        "uint64",
        "int64",
        "bool",
    ]

1081 

1082 

class EnsureDtypeDescr(NodeWithExplicitlySetFields):
    """Cast the tensor data type to `EnsureDtypeKwargs.dtype` (if not matching).

    This can for example be used to ensure the inner neural network model gets a
    different input tensor data type than the fully described bioimage.io model does.

    Examples:
        The described bioimage.io model (incl. preprocessing) accepts any
        float32-compatible tensor, normalizes it with percentiles and clipping and then
        casts it to uint8, which is what the neural network in this example expects.
        - in YAML
            ```yaml
            inputs:
            - data:
                type: float32  # described bioimage.io model is compatible with any float32 input tensor
              preprocessing:
              - id: scale_range
                kwargs:
                  axes: ['y', 'x']
                  max_percentile: 99.8
                  min_percentile: 5.0
              - id: clip
                kwargs:
                  min: 0.0
                  max: 1.0
              - id: ensure_dtype  # the neural network of the model requires uint8
                kwargs:
                  dtype: uint8
            ```
        - in Python:
            >>> preprocessing = [
            ...     ScaleRangeDescr(
            ...         kwargs=ScaleRangeKwargs(
            ...             axes= (AxisId('y'), AxisId('x')),
            ...             max_percentile= 99.8,
            ...             min_percentile= 5.0,
            ...         )
            ...     ),
            ...     ClipDescr(kwargs=ClipKwargs(min=0.0, max=1.0)),
            ...     EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="uint8")),
            ... ]
    """

    implemented_id: ClassVar[Literal["ensure_dtype"]] = "ensure_dtype"
    if TYPE_CHECKING:
        # type checkers see a default for `id`
        id: Literal["ensure_dtype"] = "ensure_dtype"
    else:
        # at runtime `id` has no default here; presumably supplied via
        # NodeWithExplicitlySetFields from `implemented_id` — confirm
        id: Literal["ensure_dtype"]

    kwargs: EnsureDtypeKwargs

1133 

1134 

class ScaleLinearKwargs(KwargsNode):
    """Key word arguments for [ScaleLinearDescr][]

    Describes `out = tensor * gain + offset` with scalar `gain`/`offset`.
    """

    gain: float = 1.0
    """multiplicative factor"""

    offset: float = 0.0
    """additive term"""

    @model_validator(mode="after")
    def _validate(self) -> Self:
        # Reject the identity transform: gain 1.0 with offset 0.0 is a no-op.
        # (fixed typo in the error message: "allowd" -> "allowed")
        if self.gain == 1.0 and self.offset == 0.0:
            raise ValueError(
                "Redundant linear scaling not allowed. Set `gain` != 1.0 and/or"
                + " `offset` != 0.0."
            )

        return self

1153 

1154 

class ScaleLinearAlongAxisKwargs(KwargsNode):
    """Key word arguments for [ScaleLinearDescr][]

    Describes `out = tensor * gain + offset` with per-entry `gain`/`offset`
    values along `axis`.
    """

    axis: Annotated[NonBatchAxisId, Field(examples=["channel"])]
    """The axis of gain and offset values."""

    gain: Union[float, NotEmpty[List[float]]] = 1.0
    """multiplicative factor"""

    offset: Union[float, NotEmpty[List[float]]] = 0.0
    """additive term"""

    @model_validator(mode="after")
    def _validate(self) -> Self:
        # Broadcast a scalar gain/offset to the length of the other, list-valued
        # field; if both are scalar, the `axis` specification is pointless.
        if isinstance(self.gain, list):
            if isinstance(self.offset, list):
                if len(self.gain) != len(self.offset):
                    raise ValueError(
                        f"Size of `gain` ({len(self.gain)}) and `offset` ({len(self.offset)}) must match."
                    )
            else:
                self.offset = [float(self.offset)] * len(self.gain)
        elif isinstance(self.offset, list):
            self.gain = [float(self.gain)] * len(self.offset)
        else:
            raise ValueError(
                "Do not specify an `axis` for scalar gain and offset values."
            )

        # Reject the identity transform: all gains 1.0 and all offsets 0.0.
        # (fixed typo in the error message: "allowd" -> "allowed")
        if all(g == 1.0 for g in self.gain) and all(off == 0.0 for off in self.offset):
            raise ValueError(
                "Redundant linear scaling not allowed. Set `gain` != 1.0 and/or"
                + " `offset` != 0.0."
            )

        return self

1191 

1192 

class ScaleLinearDescr(NodeWithExplicitlySetFields):
    """Fixed linear scaling.

    Examples:
      1. Scale with scalar gain and offset
        - in YAML
        ```yaml
        preprocessing:
          - id: scale_linear
            kwargs:
              gain: 2.0
              offset: 3.0
        ```
        - in Python:

        >>> preprocessing = [
        ...     ScaleLinearDescr(kwargs=ScaleLinearKwargs(gain= 2.0, offset=3.0))
        ... ]

      2. Independent scaling along an axis
        - in YAML
        ```yaml
        preprocessing:
          - id: scale_linear
            kwargs:
              axis: 'channel'
              gain: [1.0, 2.0, 3.0]
        ```
        - in Python:

        >>> preprocessing = [
        ...     ScaleLinearDescr(
        ...         kwargs=ScaleLinearAlongAxisKwargs(
        ...             axis=AxisId("channel"),
        ...             gain=[1.0, 2.0, 3.0],
        ...         )
        ...     )
        ... ]

    """

    implemented_id: ClassVar[Literal["scale_linear"]] = "scale_linear"
    if TYPE_CHECKING:
        # type checkers see a default for `id`
        id: Literal["scale_linear"] = "scale_linear"
    else:
        # at runtime `id` has no default here; presumably supplied via
        # NodeWithExplicitlySetFields from `implemented_id` — confirm
        id: Literal["scale_linear"]
    # scalar kwargs or per-axis-entry kwargs (with `axis` set)
    kwargs: Union[ScaleLinearKwargs, ScaleLinearAlongAxisKwargs]

1240 

1241 

class SigmoidDescr(NodeWithExplicitlySetFields):
    """The logistic sigmoid function, a.k.a. expit function.

    Examples:
        - in YAML
        ```yaml
        postprocessing:
          - id: sigmoid
        ```
        - in Python:

        >>> postprocessing = [SigmoidDescr()]
    """

    implemented_id: ClassVar[Literal["sigmoid"]] = "sigmoid"
    if TYPE_CHECKING:
        # type checkers see a default for `id`
        id: Literal["sigmoid"] = "sigmoid"
    else:
        # at runtime `id` has no default here; presumably supplied via
        # NodeWithExplicitlySetFields from `implemented_id` — confirm
        id: Literal["sigmoid"]

    @property
    def kwargs(self) -> KwargsNode:
        """empty kwargs (sigmoid takes no parameters); provided so all
        processing descriptions expose a uniform `kwargs` attribute"""
        return KwargsNode()

1266 

1267 

class SoftmaxKwargs(KwargsNode):
    """key word arguments for [SoftmaxDescr][]"""

    axis: Annotated[NonBatchAxisId, Field(examples=["channel"])] = AxisId("channel")
    """The axis to apply the softmax function along.
    Note:
        Defaults to the 'channel' axis, which may not exist;
        in that case a different axis id has to be specified.
    """

1278 

1279 

class SoftmaxDescr(NodeWithExplicitlySetFields):
    """The softmax function.

    Examples:
        - in YAML
        ```yaml
        postprocessing:
          - id: softmax
            kwargs:
              axis: channel
        ```
        - in Python:

        >>> postprocessing = [SoftmaxDescr(kwargs=SoftmaxKwargs(axis=AxisId("channel")))]
    """

    implemented_id: ClassVar[Literal["softmax"]] = "softmax"
    if TYPE_CHECKING:
        # type checkers see a default for `id`
        id: Literal["softmax"] = "softmax"
    else:
        # at runtime `id` has no default here; presumably supplied via
        # NodeWithExplicitlySetFields from `implemented_id` — confirm
        id: Literal["softmax"]

    # `model_construct` skips validation so the default kwargs (axis='channel')
    # can be built without triggering field validators
    kwargs: SoftmaxKwargs = Field(default_factory=SoftmaxKwargs.model_construct)

1303 

1304 

class _StardistPostprocessingKwargsBase(KwargsNode):
    """key word arguments for [StardistPostprocessingDescr][]

    Shared by the 2D and 3D variants.
    """

    prob_threshold: float
    """The probability threshold for object candidate selection."""

    nms_threshold: float
    """The IoU threshold for non-maximum suppression."""

1313 

1314 

class StardistPostprocessingKwargs2D(_StardistPostprocessingKwargsBase):
    """Stardist postprocessing key word arguments for 2D models."""

    grid: Tuple[int, int]
    """Grid size of network predictions."""

    # either one scalar border width or per-axis (lower, upper) pairs
    b: Union[int, Tuple[Tuple[int, int], Tuple[int, int]]]
    """Border region in which object probability is set to zero."""

1321 

1322 

class StardistPostprocessingKwargs3D(_StardistPostprocessingKwargsBase):
    """Stardist postprocessing key word arguments for 3D models."""

    grid: Tuple[int, int, int]
    """Grid size of network predictions."""

    # either one scalar border width or per-axis (lower, upper) pairs
    b: Union[int, Tuple[Tuple[int, int], Tuple[int, int], Tuple[int, int]]]
    """Border region in which object probability is set to zero."""

    n_rays: int
    """Number of rays for 3D star-convex polyhedra."""

    anisotropy: Tuple[float, float, float]
    """Anisotropy factors for 3D star-convex polyhedra, i.e. the physical pixel size along each spatial axis."""

    overlap_label: Optional[int] = None
    """Optional label to apply to any area of overlapping predicted objects."""

1338 

1339 

class StardistPostprocessingDescr(NodeWithExplicitlySetFields):
    """Stardist postprocessing including non-maximum suppression and converting polygon representations to instance labels

    as described in:
    - Uwe Schmidt, Martin Weigert, Coleman Broaddus, and Gene Myers.
    [*Cell Detection with Star-convex Polygons*](https://arxiv.org/abs/1806.03535).
    International Conference on Medical Image Computing and Computer-Assisted Intervention (MICCAI), Granada, Spain, September 2018.
    - Martin Weigert, Uwe Schmidt, Robert Haase, Ko Sugawara, and Gene Myers.
    [*Star-convex Polyhedra for 3D Object Detection and Segmentation in Microscopy*](http://openaccess.thecvf.com/content_WACV_2020/papers/Weigert_Star-convex_Polyhedra_for_3D_Object_Detection_and_Segmentation_in_Microscopy_WACV_2020_paper.pdf).
    The IEEE Winter Conference on Applications of Computer Vision (WACV), Snowmass Village, Colorado, March 2020.

    Note: Only available if the `stardist` package is installed.
    """

    implemented_id: ClassVar[Literal["stardist_postprocessing"]] = (
        "stardist_postprocessing"
    )
    if TYPE_CHECKING:
        # type checkers see a default for `id`
        id: Literal["stardist_postprocessing"] = "stardist_postprocessing"
    else:
        # at runtime `id` has no default here; presumably supplied via
        # NodeWithExplicitlySetFields from `implemented_id` — confirm
        id: Literal["stardist_postprocessing"]

    # 2D or 3D variant; distinguished by field shape (e.g. `grid` length)
    kwargs: Union[StardistPostprocessingKwargs2D, StardistPostprocessingKwargs3D]

1363 

1364 

class FixedZeroMeanUnitVarianceKwargs(KwargsNode):
    """key word arguments for [FixedZeroMeanUnitVarianceDescr][]"""

    mean: float
    """The mean value to normalize with."""

    # lower bound avoids division by (near-)zero during normalization
    std: Annotated[float, Ge(1e-6)]
    """The standard deviation value to normalize with."""

1373 

1374 

class FixedZeroMeanUnitVarianceAlongAxisKwargs(KwargsNode):
    """key word arguments for [FixedZeroMeanUnitVarianceDescr][]"""

    mean: NotEmpty[List[float]]
    """The mean value(s) to normalize with."""

    # lower bound avoids division by (near-)zero during normalization
    std: NotEmpty[List[Annotated[float, Ge(1e-6)]]]
    """The standard deviation value(s) to normalize with.
    Size must match `mean` values."""

    axis: Annotated[NonBatchAxisId, Field(examples=["channel", "index"])]
    """The axis of the mean/std values to normalize each entry along that dimension
    separately."""

    @model_validator(mode="after")
    def _mean_and_std_match(self) -> Self:
        """Ensure `mean` and `std` provide one value per entry along `axis`."""
        if len(self.mean) != len(self.std):
            raise ValueError(
                f"Size of `mean` ({len(self.mean)}) and `std` ({len(self.std)})"
                + " must match."
            )

        return self

1398 

1399 

class FixedZeroMeanUnitVarianceDescr(NodeWithExplicitlySetFields):
    """Subtract a given mean and divide by the standard deviation.

    Normalize with fixed, precomputed values for
    `FixedZeroMeanUnitVarianceKwargs.mean` and `FixedZeroMeanUnitVarianceKwargs.std`
    Use `FixedZeroMeanUnitVarianceAlongAxisKwargs` for independent scaling along given
    axes.

    Examples:
    1. scalar value for whole tensor
        - in YAML
        ```yaml
        preprocessing:
          - id: fixed_zero_mean_unit_variance
            kwargs:
              mean: 103.5
              std: 13.7
        ```
        - in Python
        >>> preprocessing = [FixedZeroMeanUnitVarianceDescr(
        ...     kwargs=FixedZeroMeanUnitVarianceKwargs(mean=103.5, std=13.7)
        ... )]

    2. independently along an axis
        - in YAML
        ```yaml
        preprocessing:
          - id: fixed_zero_mean_unit_variance
            kwargs:
              axis: channel
              mean: [101.5, 102.5, 103.5]
              std: [11.7, 12.7, 13.7]
        ```
        - in Python
        >>> preprocessing = [FixedZeroMeanUnitVarianceDescr(
        ...     kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs(
        ...         axis=AxisId("channel"),
        ...         mean=[101.5, 102.5, 103.5],
        ...         std=[11.7, 12.7, 13.7],
        ...     )
        ... )]
    """

    implemented_id: ClassVar[Literal["fixed_zero_mean_unit_variance"]] = (
        "fixed_zero_mean_unit_variance"
    )
    if TYPE_CHECKING:
        # type checkers see a default for `id`
        id: Literal["fixed_zero_mean_unit_variance"] = "fixed_zero_mean_unit_variance"
    else:
        # at runtime `id` has no default here; presumably supplied via
        # NodeWithExplicitlySetFields from `implemented_id` — confirm
        id: Literal["fixed_zero_mean_unit_variance"]

    # scalar kwargs or per-axis-entry kwargs (with `axis` set)
    kwargs: Union[
        FixedZeroMeanUnitVarianceKwargs, FixedZeroMeanUnitVarianceAlongAxisKwargs
    ]

1454 

1455 

class ZeroMeanUnitVarianceKwargs(KwargsNode):
    """key word arguments for [ZeroMeanUnitVarianceDescr][]"""

    axes: Annotated[
        Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
    ] = None
    """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std.
    For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
    resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
    To normalize each sample independently leave out the 'batch' axis.
    Default: Scale all axes jointly."""

    # strictly positive, bounded above to catch accidentally huge epsilons
    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
    """epsilon for numeric stability: `out = (tensor - mean) / (std + eps)`."""

1470 

1471 

class ZeroMeanUnitVarianceDescr(NodeWithExplicitlySetFields):
    """Subtract mean and divide by variance.

    Examples:
        Subtract tensor mean and variance
        - in YAML
        ```yaml
        preprocessing:
          - id: zero_mean_unit_variance
        ```
        - in Python
        >>> preprocessing = [ZeroMeanUnitVarianceDescr()]
    """

    implemented_id: ClassVar[Literal["zero_mean_unit_variance"]] = (
        "zero_mean_unit_variance"
    )
    if TYPE_CHECKING:
        # type checkers see a default for `id`
        id: Literal["zero_mean_unit_variance"] = "zero_mean_unit_variance"
    else:
        # at runtime `id` has no default here; presumably supplied via
        # NodeWithExplicitlySetFields from `implemented_id` — confirm
        id: Literal["zero_mean_unit_variance"]

    # `model_construct` skips validation so default kwargs can be built directly
    kwargs: ZeroMeanUnitVarianceKwargs = Field(
        default_factory=ZeroMeanUnitVarianceKwargs.model_construct
    )

1497 

1498 

class ScaleRangeKwargs(KwargsNode):
    """key word arguments for [ScaleRangeDescr][]

    For `min_percentile`=0.0 (the default) and `max_percentile`=100 (the default)
    this processing step normalizes data to the [0, 1] interval.
    For other percentiles the normalized values will partially be outside the [0, 1]
    interval. Use `ScaleRange` followed by `ClipDescr` if you want to limit the
    normalized values to a range.
    """

    axes: Annotated[
        Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
    ] = None
    """The subset of axes to normalize jointly, i.e. axes to reduce to compute the min/max percentile value.
    For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
    resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
    To normalize samples independently, leave out the "batch" axis.
    Default: Scale all axes jointly."""

    min_percentile: Annotated[float, Interval(ge=0, lt=100)] = 0.0
    """The lower percentile used to determine the value to align with zero."""

    max_percentile: Annotated[float, Interval(gt=1, le=100)] = 100.0
    """The upper percentile used to determine the value to align with one.
    Has to be bigger than `min_percentile`.
    The range is 1 to 100 instead of 0 to 100 to avoid mistakenly
    accepting percentiles specified in the range 0.0 to 1.0."""

    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
    """Epsilon for numeric stability.
    `out = (tensor - v_lower) / (v_upper - v_lower + eps)`;
    with `v_lower,v_upper` values at the respective percentiles."""

    reference_tensor: Optional[TensorId] = None
    """ID of the unprocessed input tensor to compute the percentiles from.
    Default: The tensor itself.
    """

    @field_validator("max_percentile", mode="after")
    @classmethod
    def min_smaller_max(cls, value: float, info: ValidationInfo) -> float:
        """Ensure `min_percentile` < `max_percentile`."""
        # NOTE(review): assumes `min_percentile` validated successfully and is
        # present in `info.data` (fields are validated in declaration order)
        if (min_p := info.data["min_percentile"]) >= value:
            raise ValueError(f"min_percentile {min_p} >= max_percentile {value}")

        return value

1544 

1545 

class ScaleRangeDescr(NodeWithExplicitlySetFields):
    """Scale with percentiles.

    Examples:
    1. Scale linearly to map 5th percentile to 0 and 99.8th percentile to 1.0
        - in YAML
        ```yaml
        preprocessing:
          - id: scale_range
            kwargs:
              axes: ['y', 'x']
              max_percentile: 99.8
              min_percentile: 5.0
        ```
        - in Python

        >>> preprocessing = [
        ...     ScaleRangeDescr(
        ...         kwargs=ScaleRangeKwargs(
        ...             axes= (AxisId('y'), AxisId('x')),
        ...             max_percentile= 99.8,
        ...             min_percentile= 5.0,
        ...         )
        ...     )
        ... ]

    2. Combine the above scaling with additional clipping to clip values outside the range given by the percentiles.
        - in YAML
        ```yaml
        preprocessing:
          - id: scale_range
            kwargs:
              axes: ['y', 'x']
              max_percentile: 99.8
              min_percentile: 5.0
          - id: clip
            kwargs:
              min: 0.0
              max: 1.0
        ```
        - in Python

        >>> preprocessing = [
        ...     ScaleRangeDescr(
        ...         kwargs=ScaleRangeKwargs(
        ...             axes= (AxisId('y'), AxisId('x')),
        ...             max_percentile= 99.8,
        ...             min_percentile= 5.0,
        ...         )
        ...     ),
        ...     ClipDescr(
        ...         kwargs=ClipKwargs(
        ...             min=0.0,
        ...             max=1.0,
        ...         )
        ...     ),
        ... ]

    """

    implemented_id: ClassVar[Literal["scale_range"]] = "scale_range"
    if TYPE_CHECKING:
        # type checkers see a default for `id`
        id: Literal["scale_range"] = "scale_range"
    else:
        # at runtime `id` has no default here; presumably supplied via
        # NodeWithExplicitlySetFields from `implemented_id` — confirm
        id: Literal["scale_range"]
    # `model_construct` skips validation so default kwargs can be built directly
    kwargs: ScaleRangeKwargs = Field(default_factory=ScaleRangeKwargs.model_construct)

1612 

1613 

class ScaleMeanVarianceKwargs(KwargsNode):
    """key word arguments for [ScaleMeanVarianceDescr][]"""

    reference_tensor: TensorId
    """ID of unprocessed input tensor to match."""

    axes: Annotated[
        Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
    ] = None
    """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std.
    For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
    resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
    To normalize samples independently, leave out the 'batch' axis.
    Default: Scale all axes jointly."""

    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
    """Epsilon for numeric stability:
    `out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.`"""

1632 

1633 

class ScaleMeanVarianceDescr(NodeWithExplicitlySetFields):
    """Scale a tensor's data distribution to match another tensor's mean/std.
    `out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.`
    """

    implemented_id: ClassVar[Literal["scale_mean_variance"]] = "scale_mean_variance"
    if TYPE_CHECKING:
        # type checkers see a default for `id`
        id: Literal["scale_mean_variance"] = "scale_mean_variance"
    else:
        # at runtime `id` has no default here; presumably supplied via
        # NodeWithExplicitlySetFields from `implemented_id` — confirm
        id: Literal["scale_mean_variance"]
    kwargs: ScaleMeanVarianceKwargs

1645 

1646 

# All processing steps applicable to an input tensor; pydantic discriminates
# the union members by their `id` field.
PreprocessingDescr = Annotated[
    Union[
        BinarizeDescr,
        ClipDescr,
        EnsureDtypeDescr,
        FixedZeroMeanUnitVarianceDescr,
        ScaleLinearDescr,
        ScaleRangeDescr,
        SigmoidDescr,
        SoftmaxDescr,
        ZeroMeanUnitVarianceDescr,
    ],
    Discriminator("id"),
]
# All processing steps applicable to an output tensor; a superset of
# `PreprocessingDescr` additionally allowing `scale_mean_variance` and
# `stardist_postprocessing`.
PostprocessingDescr = Annotated[
    Union[
        BinarizeDescr,
        ClipDescr,
        EnsureDtypeDescr,
        FixedZeroMeanUnitVarianceDescr,
        ScaleLinearDescr,
        ScaleMeanVarianceDescr,
        ScaleRangeDescr,
        SigmoidDescr,
        SoftmaxDescr,
        StardistPostprocessingDescr,
        ZeroMeanUnitVarianceDescr,
    ],
    Discriminator("id"),
]

# constrained type variable: a tensor description holds either input axes or
# output axes, never a mix
IO_AxisT = TypeVar("IO_AxisT", InputAxis, OutputAxis)

1679 

1680 

class TensorDescrBase(Node, Generic[IO_AxisT]):
    """Base description of a model input or output tensor (axes, data, test files)."""

    id: TensorId
    """Tensor id. No duplicates are allowed."""

    description: Annotated[str, MaxLen(128)] = ""
    """free text description"""

    axes: NotEmpty[Sequence[IO_AxisT]]
    """tensor axes"""

    @property
    def shape(self):
        """The tensor shape given as the tuple of the axes' `size` entries."""
        return tuple(a.size for a in self.axes)

    @field_validator("axes", mode="after", check_fields=False)
    @classmethod
    def _validate_axes(cls, axes: Sequence[AnyAxis]) -> Sequence[AnyAxis]:
        """Allow at most one batch axis and require unique axis ids."""
        batch_axes = [a for a in axes if a.type == "batch"]
        if len(batch_axes) > 1:
            raise ValueError(
                f"Only one batch axis (per tensor) allowed, but got {batch_axes}"
            )

        seen_ids: Set[AxisId] = set()
        duplicate_axes_ids: Set[AxisId] = set()
        for a in axes:
            (duplicate_axes_ids if a.id in seen_ids else seen_ids).add(a.id)

        if duplicate_axes_ids:
            raise ValueError(f"Duplicate axis ids: {duplicate_axes_ids}")

        return axes

    test_tensor: FAIR[Optional[FileDescr_]] = None
    """An example tensor to use for testing.
    Using the model with the test input tensors is expected to yield the test output tensors.
    Each test tensor has to be an ndarray in the
    [numpy.lib file format](https://numpy.org/doc/stable/reference/generated/numpy.lib.format.html#module-numpy.lib.format).
    The file extension must be '.npy'."""

    sample_tensor: FAIR[Optional[FileDescr_]] = None
    """A sample tensor to illustrate a possible input/output for the model,
    The sample image primarily serves to inform a human user about an example use case
    and is typically stored as .hdf5, .png or .tiff.
    It has to be readable by the [imageio library](https://imageio.readthedocs.io/en/stable/formats/index.html#supported-formats)
    (numpy's `.npy` format is not supported).
    The image dimensionality has to match the number of axes specified in this tensor description.
    """

    @model_validator(mode="after")
    def _validate_sample_tensor(self) -> Self:
        """Read `sample_tensor` (if IO checks are enabled) and check its rank."""
        if self.sample_tensor is None or not get_validation_context().perform_io_checks:
            return self

        reader = get_reader(self.sample_tensor.source, sha256=self.sample_tensor.sha256)
        tensor: NDArray[Any] = imread(  # pyright: ignore[reportUnknownVariableType]
            reader.read(),
            extension=PurePosixPath(reader.original_file_name).suffix,
        )
        n_dims = len(tensor.squeeze().shape)
        n_dims_min = n_dims_max = len(self.axes)

        # axes that may be singletons could have been squeezed away in the
        # sample image, so each of them lowers the acceptable minimum rank
        for a in self.axes:
            if isinstance(a, BatchAxis):
                n_dims_min -= 1
            elif isinstance(a.size, int):
                if a.size == 1:
                    n_dims_min -= 1
            elif isinstance(a.size, (ParameterizedSize, DataDependentSize)):
                if a.size.min == 1:
                    n_dims_min -= 1
            elif isinstance(a.size, SizeReference):
                if a.size.offset < 2:
                    # size reference may result in singleton axis
                    n_dims_min -= 1
            else:
                assert_never(a.size)

        n_dims_min = max(0, n_dims_min)
        if n_dims < n_dims_min or n_dims > n_dims_max:
            raise ValueError(
                f"Expected sample tensor to have {n_dims_min} to"
                + f" {n_dims_max} dimensions, but found {n_dims} (shape: {tensor.shape})."
            )

        return self

    data: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]] = (
        IntervalOrRatioDataDescr()
    )
    """Description of the tensor's data values, optionally per channel.
    If specified per channel, the data `type` needs to match across channels."""

    @property
    def dtype(
        self,
    ) -> Literal[
        "float32",
        "float64",
        "uint8",
        "int8",
        "uint16",
        "int16",
        "uint32",
        "int32",
        "uint64",
        "int64",
        "bool",
    ]:
        """dtype as specified under `data.type` or `data[i].type`"""
        if isinstance(self.data, collections.abc.Sequence):
            return self.data[0].type
        else:
            return self.data.type

    @field_validator("data", mode="after")
    @classmethod
    def _check_data_type_across_channels(
        cls, value: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]]
    ) -> Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]]:
        """Per-channel data descriptions must agree in their data `type`."""
        # fix: also check tuple-valued `data` for consistency with
        # `_check_data_matches_channelaxis` below; previously only `list`
        # input was checked, so tuples silently skipped this validation
        if not isinstance(value, (list, tuple)):
            return value

        dtypes = {t.type for t in value}
        if len(dtypes) > 1:
            raise ValueError(
                "Tensor data descriptions per channel need to agree in their data"
                + f" `type`, but found {dtypes}."
            )

        return value

    @model_validator(mode="after")
    def _check_data_matches_channelaxis(self) -> Self:
        """Number of per-channel data descriptions must equal the channel axis size."""
        if not isinstance(self.data, (list, tuple)):
            return self

        for a in self.axes:
            if isinstance(a, ChannelAxis):
                size = a.size
                assert isinstance(size, int)
                break
        else:
            # no channel axis -> nothing to compare against
            return self

        if len(self.data) != size:
            raise ValueError(
                f"Got tensor data descriptions for {len(self.data)} channels, but"
                + f" '{a.id}' axis has size {size}."
            )

        return self

    def get_axis_sizes_for_array(self, array: NDArray[Any]) -> Dict[AxisId, int]:
        """Map each axis id to the corresponding dimension size of `array`.

        Raises:
            ValueError: if `array`'s rank does not match the number of axes.
        """
        if len(array.shape) != len(self.axes):
            raise ValueError(
                f"Dimension mismatch: array shape {array.shape} (#{len(array.shape)})"
                + f" incompatible with {len(self.axes)} axes."
            )
        return {a.id: array.shape[i] for i, a in enumerate(self.axes)}

1841 

1842 

class InputTensorDescr(TensorDescrBase[InputAxis]):
    id: TensorId = TensorId("input")
    """Input tensor id.
    No duplicates are allowed across all inputs and outputs."""

    optional: bool = False
    """indicates that this tensor may be `None`"""

    preprocessing: List[PreprocessingDescr] = Field(
        default_factory=cast(Callable[[], List[PreprocessingDescr]], list)
    )

    """Description of how this input should be preprocessed.

    notes:
    - If preprocessing does not start with an 'ensure_dtype' entry, it is added
      to ensure an input tensor's data type matches the input tensor's data description.
    - If preprocessing does not end with an 'ensure_dtype' or 'binarize' entry, an
      'ensure_dtype' step is added to ensure preprocessing steps are not unintentionally
      changing the data type.
    """

    @model_validator(mode="after")
    def _validate_preprocessing_kwargs(self) -> Self:
        """Check `kwargs.axes` of each step against this tensor's axes and
        wrap `preprocessing` in 'ensure_dtype' steps (mutates the list in place)."""
        axes_ids = [a.id for a in self.axes]
        for p in self.preprocessing:
            kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes")
            if kwargs_axes is None:
                continue

            if not isinstance(kwargs_axes, collections.abc.Sequence):
                raise ValueError(
                    f"Expected `preprocessing.i.kwargs.axes` to be a sequence, but got {type(kwargs_axes)}"
                )

            if any(a not in axes_ids for a in kwargs_axes):
                raise ValueError(
                    "`preprocessing.i.kwargs.axes` needs to be subset of axes ids"
                )

        # per-channel data descriptions agree in `type` (validated in base class),
        # so the first entry is representative
        if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)):
            dtype = self.data.type
        else:
            dtype = self.data[0].type

        # ensure `preprocessing` begins with `EnsureDtypeDescr`
        if not self.preprocessing or not isinstance(
            self.preprocessing[0], EnsureDtypeDescr
        ):
            self.preprocessing.insert(
                0, EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype))
            )

        # ensure `preprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr`
        if not isinstance(self.preprocessing[-1], (EnsureDtypeDescr, BinarizeDescr)):
            self.preprocessing.append(
                EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype))
            )

        return self

1903 

1904 

def convert_axes(
    axes: str,
    *,
    shape: Union[
        Sequence[int], _ParameterizedInputShape_v0_4, _ImplicitOutputShape_v0_4
    ],
    tensor_type: Literal["input", "output"],
    halo: Optional[Sequence[int]],
    size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]],
) -> List[AnyAxis]:
    """Convert a v0.4 axes letter string plus shape description into v0.5 axis objects.

    Args:
        axes: v0.4 axes string, one letter per axis (e.g. "bczyx").
        shape: explicit sizes, a parameterized input shape, or an implicit
            (reference-based) output shape.
        tensor_type: whether these axes describe an input or an output tensor.
        halo: per-axis halo values (indexed like `axes`); only used for outputs.
            NOTE(review): assumed to have one entry per axis — confirm at call sites.
        size_refs: known axis sizes per referenced tensor, used to translate
            scale-based implicit shapes for scale-less axis types.

    Returns:
        One v0.5 axis description per letter in `axes`.
    """
    ret: List[AnyAxis] = []
    for i, a in enumerate(axes):
        axis_type = _AXIS_TYPE_MAP.get(a, a)
        if axis_type == "batch":
            ret.append(BatchAxis())
            continue

        scale = 1.0
        if isinstance(shape, _ParameterizedInputShape_v0_4):
            # step 0 means this axis is fixed at its minimum
            if shape.step[i] == 0:
                size = shape.min[i]
            else:
                size = ParameterizedSize(min=shape.min[i], step=shape.step[i])
        elif isinstance(shape, _ImplicitOutputShape_v0_4):
            # reference may be "tensor.axis" or just "tensor"
            ref_t = str(shape.reference_tensor)
            if ref_t.count(".") == 1:
                t_id, orig_a_id = ref_t.split(".")
            else:
                t_id = ref_t
                orig_a_id = a

            # NOTE(review): fallback is the current letter `a`, not `orig_a_id`
            # — confirm this is intended when `orig_a_id` came from a split
            a_id = _AXIS_ID_MAP.get(orig_a_id, a)
            if not (orig_scale := shape.scale[i]):
                # old way to insert a new axis dimension
                size = int(2 * shape.offset[i])
            else:
                scale = 1 / orig_scale
                if axis_type in ("channel", "index"):
                    # these axes no longer have a scale
                    offset_from_scale = orig_scale * size_refs.get(
                        _TensorName_v0_4(t_id), {}
                    ).get(orig_a_id, 0)
                else:
                    offset_from_scale = 0
                size = SizeReference(
                    tensor_id=TensorId(t_id),
                    axis_id=AxisId(a_id),
                    offset=int(offset_from_scale + 2 * shape.offset[i]),
                )
        else:
            # explicit shape: fixed integer size
            size = shape[i]

        if axis_type == "time":
            if tensor_type == "input":
                ret.append(TimeInputAxis(size=size, scale=scale))
            else:
                assert not isinstance(size, ParameterizedSize)
                # NOTE(review): unlike the "space" branch below, `halo[i] == 0`
                # is not treated like `halo is None` here — confirm intended
                if halo is None:
                    ret.append(TimeOutputAxis(size=size, scale=scale))
                else:
                    assert not isinstance(size, int)
                    ret.append(
                        TimeOutputAxisWithHalo(size=size, scale=scale, halo=halo[i])
                    )

        elif axis_type == "index":
            if tensor_type == "input":
                ret.append(IndexInputAxis(size=size))
            else:
                if isinstance(size, ParameterizedSize):
                    size = DataDependentSize(min=size.min)

                ret.append(IndexOutputAxis(size=size))
        elif axis_type == "channel":
            assert not isinstance(size, ParameterizedSize)
            if isinstance(size, SizeReference):
                warnings.warn(
                    "Conversion of channel size from an implicit output shape may be"
                    + " wrong"
                )
                # best effort: derive channel count from the size reference offset
                ret.append(
                    ChannelAxis(
                        channel_names=[
                            Identifier(f"channel{i}") for i in range(size.offset)
                        ]
                    )
                )
            else:
                ret.append(
                    ChannelAxis(
                        channel_names=[Identifier(f"channel{i}") for i in range(size)]
                    )
                )
        elif axis_type == "space":
            if tensor_type == "input":
                ret.append(SpaceInputAxis(id=AxisId(a), size=size, scale=scale))
            else:
                assert not isinstance(size, ParameterizedSize)
                if halo is None or halo[i] == 0:
                    ret.append(SpaceOutputAxis(id=AxisId(a), size=size, scale=scale))
                elif isinstance(size, int):
                    raise NotImplementedError(
                        f"output axis with halo and fixed size (here {size}) not allowed"
                    )
                else:
                    ret.append(
                        SpaceOutputAxisWithHalo(
                            id=AxisId(a), size=size, scale=scale, halo=halo[i]
                        )
                    )

    return ret

2017 

2018 

2019def _axes_letters_to_ids( 

2020 axes: Optional[str], 

2021) -> Optional[List[AxisId]]: 

2022 if axes is None: 

2023 return None 

2024 

2025 return [AxisId(a) for a in axes] 

2026 

2027 

def _get_complement_v04_axis(
    tensor_axes: Sequence[str], axes: Optional[Sequence[str]]
) -> Optional[AxisId]:
    """Return the single axis of `tensor_axes` not listed in `axes`.

    The batch axis 'b' is never considered a complement axis. Returns `None`
    if `axes` is `None` or no complement axis remains.

    Raises:
        ValueError: If more than one complement axis remains.
    """
    if axes is None:
        return None

    covered = {"b", *axes}
    leftover = [dim for dim in tensor_axes if dim not in covered]
    if len(leftover) > 1:
        raise ValueError(
            f"Expected none or a single complement axis, but axes '{axes}' "
            + f"for tensor dims '{tensor_axes}' leave '{leftover}'."
        )

    return AxisId(leftover[0]) if leftover else None

2043 

2044 

def _convert_proc(
    p: Union[_PreprocessingDescr_v0_4, _PostprocessingDescr_v0_4],
    tensor_axes: Sequence[str],
) -> Union[PreprocessingDescr, PostprocessingDescr]:
    """Convert a v0.4 (pre/post)processing description to its v0.5 counterpart.

    Args:
        p: The v0.4 processing description to convert.
        tensor_axes: v0.4 axes letters of the tensor `p` belongs to; used to
            determine the complement axis for per-element arguments.

    Raises:
        ValueError: For fixed-mode zero-mean/unit-variance with list-valued
            mean/std but no complement axis to map the list onto.
    """
    if isinstance(p, _BinarizeDescr_v0_4):
        return BinarizeDescr(kwargs=BinarizeKwargs(threshold=p.kwargs.threshold))
    elif isinstance(p, _ClipDescr_v0_4):
        return ClipDescr(kwargs=ClipKwargs(min=p.kwargs.min, max=p.kwargs.max))
    elif isinstance(p, _SigmoidDescr_v0_4):
        return SigmoidDescr()
    elif isinstance(p, _ScaleLinearDescr_v0_4):
        # fix: removed an unused `axes = _axes_letters_to_ids(p.kwargs.axes)`
        # assignment; `_get_complement_v04_axis` already returns None when
        # `p.kwargs.axes` is None, so the explicit None-branch is folded in.
        axis = _get_complement_v04_axis(tensor_axes, p.kwargs.axes)

        if axis is None:
            assert not isinstance(p.kwargs.gain, list)
            assert not isinstance(p.kwargs.offset, list)
            kwargs = ScaleLinearKwargs(gain=p.kwargs.gain, offset=p.kwargs.offset)
        else:
            kwargs = ScaleLinearAlongAxisKwargs(
                axis=axis, gain=p.kwargs.gain, offset=p.kwargs.offset
            )
        return ScaleLinearDescr(kwargs=kwargs)
    elif isinstance(p, _ScaleMeanVarianceDescr_v0_4):
        return ScaleMeanVarianceDescr(
            kwargs=ScaleMeanVarianceKwargs(
                axes=_axes_letters_to_ids(p.kwargs.axes),
                reference_tensor=TensorId(str(p.kwargs.reference_tensor)),
                eps=p.kwargs.eps,
            )
        )
    elif isinstance(p, _ZeroMeanUnitVarianceDescr_v0_4):
        if p.kwargs.mode == "fixed":
            mean = p.kwargs.mean
            std = p.kwargs.std
            assert mean is not None
            assert std is not None

            axis = _get_complement_v04_axis(tensor_axes, p.kwargs.axes)

            if axis is None:
                if isinstance(mean, list):
                    raise ValueError("Expected single float value for mean, not <list>")
                if isinstance(std, list):
                    raise ValueError("Expected single float value for std, not <list>")
                # model_construct bypasses validation here — presumably to
                # tolerate v0.4 values a strict v0.5 validation would reject;
                # TODO confirm
                return FixedZeroMeanUnitVarianceDescr(
                    kwargs=FixedZeroMeanUnitVarianceKwargs.model_construct(
                        mean=mean,
                        std=std,
                    )
                )
            else:
                # per-axis variant expects lists; wrap scalars accordingly
                if not isinstance(mean, list):
                    mean = [float(mean)]
                if not isinstance(std, list):
                    std = [float(std)]

                return FixedZeroMeanUnitVarianceDescr(
                    kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs(
                        axis=axis, mean=mean, std=std
                    )
                )

        else:
            axes = _axes_letters_to_ids(p.kwargs.axes) or []
            if p.kwargs.mode == "per_dataset":
                # per-dataset statistics additionally aggregate over the batch
                axes = [AxisId("batch")] + axes
            if not axes:
                axes = None
            return ZeroMeanUnitVarianceDescr(
                kwargs=ZeroMeanUnitVarianceKwargs(axes=axes, eps=p.kwargs.eps)
            )

    elif isinstance(p, _ScaleRangeDescr_v0_4):
        return ScaleRangeDescr(
            kwargs=ScaleRangeKwargs(
                axes=_axes_letters_to_ids(p.kwargs.axes),
                min_percentile=p.kwargs.min_percentile,
                max_percentile=p.kwargs.max_percentile,
                eps=p.kwargs.eps,
            )
        )
    else:
        assert_never(p)

2132 

2133 

class _InputTensorConv(
    Converter[
        _InputTensorDescr_v0_4,
        InputTensorDescr,
        FileSource_,
        Optional[FileSource_],
        Mapping[_TensorName_v0_4, Mapping[str, int]],
    ]
):
    """Converts a v0.4 input tensor description to a v0.5 `InputTensorDescr`."""

    def _convert(
        self,
        src: _InputTensorDescr_v0_4,
        tgt: "type[InputTensorDescr] | type[dict[str, Any]]",
        test_tensor: FileSource_,
        sample_tensor: Optional[FileSource_],
        size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]],
    ) -> "InputTensorDescr | dict[str, Any]":
        converted_axes: List[InputAxis] = (
            convert_axes(  # pyright: ignore[reportAssignmentType]
                src.axes,
                shape=src.shape,
                tensor_type="input",
                halo=None,
                size_refs=size_refs,
            )
        )
        preprocessing: List[PreprocessingDescr] = []
        for proc in src.preprocessing:
            converted = _convert_proc(proc, src.axes)
            # these descriptions only make sense as postprocessing
            assert not isinstance(
                converted, (ScaleMeanVarianceDescr, StardistPostprocessingDescr)
            )
            preprocessing.append(converted)

        # cast inputs explicitly to float32 at the end of preprocessing
        preprocessing.append(
            EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="float32"))
        )

        sample_descr = (
            None if sample_tensor is None else FileDescr(source=sample_tensor)
        )
        return tgt(
            axes=converted_axes,
            id=TensorId(str(src.name)),
            test_tensor=FileDescr(source=test_tensor),
            sample_tensor=sample_descr,
            data=dict(type=src.data_type),  # pyright: ignore[reportArgumentType]
            preprocessing=preprocessing,
        )

2178 

2179 

# reusable converter instance for v0.4 -> v0.5 input tensor descriptions
_input_tensor_conv = _InputTensorConv(_InputTensorDescr_v0_4, InputTensorDescr)

2181 

2182 

class OutputTensorDescr(TensorDescrBase[OutputAxis]):
    """Description of an output tensor of a model."""

    id: TensorId = TensorId("output")
    """Output tensor id.
    No duplicates are allowed across all inputs and outputs."""

    postprocessing: List[PostprocessingDescr] = Field(
        default_factory=cast(Callable[[], List[PostprocessingDescr]], list)
    )
    """Description of how this output should be postprocessed.

    note: `postprocessing` always ends with an 'ensure_dtype' operation.
    If not given this is added to cast to this tensor's `data.type`.
    """

    @model_validator(mode="after")
    def _validate_postprocessing_kwargs(self) -> Self:
        """Check that postprocessing `kwargs.axes` reference existing axes and
        guarantee a trailing dtype-fixing operation."""
        axes_ids = [a.id for a in self.axes]
        for p in self.postprocessing:
            kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes")
            if kwargs_axes is None:
                continue

            if not isinstance(kwargs_axes, collections.abc.Sequence):
                raise ValueError(
                    f"expected `axes` sequence, but got {type(kwargs_axes)}"
                )

            if any(a not in axes_ids for a in kwargs_axes):
                raise ValueError("`kwargs.axes` needs to be subset of axes ids")

        # determine this tensor's dtype (first entry for multi-channel data)
        if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)):
            dtype = self.data.type
        else:
            dtype = self.data[0].type

        # ensure `postprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr`
        if not self.postprocessing or not isinstance(
            self.postprocessing[-1], (EnsureDtypeDescr, BinarizeDescr)
        ):
            self.postprocessing.append(
                EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype))
            )
        return self

2226 

2227 

class _OutputTensorConv(
    Converter[
        _OutputTensorDescr_v0_4,
        OutputTensorDescr,
        FileSource_,
        Optional[FileSource_],
        Mapping[_TensorName_v0_4, Mapping[str, int]],
    ]
):
    """Converts a v0.4 output tensor description to a v0.5 `OutputTensorDescr`."""

    def _convert(
        self,
        src: _OutputTensorDescr_v0_4,
        tgt: "type[OutputTensorDescr] | type[dict[str, Any]]",
        test_tensor: FileSource_,
        sample_tensor: Optional[FileSource_],
        size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]],
    ) -> "OutputTensorDescr | dict[str, Any]":
        # TODO: split convert_axes into convert_output_axes and convert_input_axes
        converted_axes: List[OutputAxis] = (
            convert_axes(  # pyright: ignore[reportAssignmentType]
                src.axes,
                shape=src.shape,
                tensor_type="output",
                halo=src.halo,
                size_refs=size_refs,
            )
        )
        data_descr: Dict[str, Any] = dict(type=src.data_type)
        if data_descr["type"] == "bool":
            # boolean data enumerates its two possible values explicitly
            data_descr["values"] = [False, True]

        sample_descr = (
            None if sample_tensor is None else FileDescr(source=sample_tensor)
        )
        return tgt(
            axes=converted_axes,
            id=TensorId(str(src.name)),
            test_tensor=FileDescr(source=test_tensor),
            sample_tensor=sample_descr,
            data=data_descr,  # pyright: ignore[reportArgumentType]
            postprocessing=[_convert_proc(p, src.axes) for p in src.postprocessing],
        )

2267 

2268 

# reusable converter instance for v0.4 -> v0.5 output tensor descriptions
_output_tensor_conv = _OutputTensorConv(_OutputTensorDescr_v0_4, OutputTensorDescr)


# either an input or an output tensor description
TensorDescr = Union[InputTensorDescr, OutputTensorDescr]

2273 

2274 

def validate_tensors(
    tensors: Mapping[TensorId, Tuple[TensorDescr, Optional[NDArray[Any]]]],
    tensor_origin: Literal[
        "source", "test_tensor"
    ] = "source",  # for more precise error messages
):
    """Validate tensor arrays against their tensor descriptions.

    Args:
        tensors: Maps each tensor id to its description and (optionally) a
            loaded array. Entries without an array only contribute their axis
            metadata for resolving `SizeReference`s.
        tensor_origin: Where the arrays come from; only used in error messages.

    Raises:
        ValueError: If dtype, value range, or any axis size of an array does
            not match its description.
    """
    all_tensor_axes: Dict[TensorId, Dict[AxisId, Tuple[AnyAxis, Optional[int]]]] = {}

    def e_msg_location(d: TensorDescr):
        # error message prefix, e.g. "inputs[raw]"
        return f"{'inputs' if isinstance(d, InputTensorDescr) else 'outputs'}[{d.id}]"

    # first pass: record each axis with its actual size (None if no array is
    # given) so that cross-tensor `SizeReference`s can be resolved below
    for descr, array in tensors.values():
        if array is None:
            axis_sizes = {a.id: None for a in descr.axes}
        else:
            try:
                axis_sizes = descr.get_axis_sizes_for_array(array)
            except ValueError as e:
                raise ValueError(f"{e_msg_location(descr)} {e}") from e

        all_tensor_axes[descr.id] = {a.id: (a, axis_sizes[a.id]) for a in descr.axes}

    # second pass: check dtype, value range and axis sizes of each given array
    for descr, array in tensors.values():
        if array is None:
            continue

        if descr.dtype in ("float32", "float64"):
            # for float tensor descriptions, integer arrays are accepted, too
            invalid_test_tensor_dtype = array.dtype.name not in (
                "float32",
                "float64",
                "uint8",
                "int8",
                "uint16",
                "int16",
                "uint32",
                "int32",
                "uint64",
                "int64",
            )
        else:
            invalid_test_tensor_dtype = array.dtype.name != descr.dtype

        if invalid_test_tensor_dtype:
            raise ValueError(
                f"{tensor_origin} data type '{array.dtype.name}' does not"
                + f" match described {e_msg_location(descr)}.dtype '{descr.dtype}'"
            )

        if array.min() > -1e-4 and array.max() < 1e-4:
            # fix: message previously claimed thresholds of 1e5, contradicting
            # the actually checked thresholds of 1e-4
            raise ValueError(
                "Output values are too small for reliable testing."
                + f" Values <-1e-4 or >=1e-4 must be present in {tensor_origin}"
            )

        for a in descr.axes:
            actual_size = all_tensor_axes[descr.id][a.id][1]
            if actual_size is None:
                continue

            if a.size is None:
                continue

            if isinstance(a.size, int):
                if actual_size != a.size:
                    raise ValueError(
                        f"{e_msg_location(descr)}.axes[{a.id}]: {tensor_origin} axis "
                        + f"has incompatible size {actual_size}, expected {a.size}"
                    )
            elif isinstance(a.size, (ParameterizedSize, DataDependentSize)):
                # both size types implement `validate_size` with the same
                # signature; previously handled in two identical branches
                _ = a.size.validate_size(
                    actual_size,
                    f"{e_msg_location(descr)}.axes[{a.id}]: {tensor_origin} axis ",
                )
            elif isinstance(a.size, SizeReference):
                ref_tensor_axes = all_tensor_axes.get(a.size.tensor_id)
                if ref_tensor_axes is None:
                    raise ValueError(
                        f"{e_msg_location(descr)}.axes[{a.id}].size.tensor_id: Unknown tensor"
                        + f" reference '{a.size.tensor_id}', available: {list(all_tensor_axes)}"
                    )

                ref_axis, ref_size = ref_tensor_axes.get(a.size.axis_id, (None, None))
                if ref_axis is None or ref_size is None:
                    raise ValueError(
                        f"{e_msg_location(descr)}.axes[{a.id}].size.axis_id: Unknown tensor axis"
                        + f" reference '{a.size.tensor_id}.{a.size.axis_id}, available: {list(ref_tensor_axes)}"
                    )

                if a.unit != ref_axis.unit:
                    raise ValueError(
                        f"{e_msg_location(descr)}.axes[{a.id}].size: `SizeReference` requires"
                        + " axis and reference axis to have the same `unit`, but"
                        + f" {a.unit}!={ref_axis.unit}"
                    )

                if actual_size != (
                    expected_size := (
                        ref_size * ref_axis.scale / a.scale + a.size.offset
                    )
                ):
                    raise ValueError(
                        f"{e_msg_location(descr)}.{tensor_origin}: axis '{a.id}' of size"
                        + f" {actual_size} invalid for referenced size {ref_size};"
                        + f" expected {expected_size}"
                    )
            else:
                assert_never(a.size)

2387 

2388 

# file description restricted to conda environment files (.yaml/.yml suffix)
FileDescr_dependencies = Annotated[
    FileDescr_,
    WithSuffix((".yaml", ".yml"), case_sensitive=True),
    Field(examples=[dict(source="environment.yaml")]),
]

2394 

2395 

class _ArchitectureCallableDescr(Node):
    """Base description of an architecture given as a callable plus kwargs."""

    callable: Annotated[Identifier, Field(examples=["MyNetworkClass", "get_my_model"])]
    """Identifier of the callable that returns a torch.nn.Module instance."""

    kwargs: Dict[str, YamlValue] = Field(
        default_factory=cast(Callable[[], Dict[str, YamlValue]], dict)
    )
    """Keyword arguments for the `callable`."""

2404 

2405 

class ArchitectureFromFileDescr(_ArchitectureCallableDescr, FileDescr):
    """Architecture callable defined in a source file."""

    source: Annotated[FileSource, AfterValidator(wo_special_file_name)]
    """Architecture source file"""

    @model_serializer(mode="wrap", when_used="unless-none")
    def _serialize(self, nxt: SerializerFunctionWrapHandler, info: SerializationInfo):
        # delegate to the file-packaging serializer so the referenced file is
        # included when serializing for packaging
        return package_file_descr_serializer(self, nxt, info)

2413 

2414 

class ArchitectureFromLibraryDescr(_ArchitectureCallableDescr):
    """Architecture callable importable from a module/library."""

    import_from: str
    """Where to import the callable from, i.e. `from <import_from> import <callable>`"""

2418 

2419 

class _ArchFileConv(
    Converter[
        _CallableFromFile_v0_4,
        ArchitectureFromFileDescr,
        Optional[Sha256],
        Dict[str, Any],
    ]
):
    """Converts a v0.4 `<file>:<callable>` reference to v0.5."""

    def _convert(
        self,
        src: _CallableFromFile_v0_4,
        tgt: "type[ArchitectureFromFileDescr | dict[str, Any]]",
        sha256: Optional[Sha256],
        kwargs: Dict[str, Any],
    ) -> "ArchitectureFromFileDescr | dict[str, Any]":
        is_url = src.startswith("http")
        if is_url and src.count(":") == 2:
            # e.g. "https://host/arch.py:MyNet": rejoin the scheme separator
            scheme, remainder, callable_name = src.split(":")
            file_source = ":".join((scheme, remainder))
        elif not is_url and src.count(":") == 1:
            file_source, callable_name = src.split(":")
        else:
            # no recognizable "<file>:<callable>" separator: keep raw value
            file_source = str(src)
            callable_name = str(src)
        return tgt(
            callable=Identifier(callable_name),
            source=cast(FileSource_, file_source),
            sha256=sha256,
            kwargs=kwargs,
        )

2449 

2450 

# reusable converter instance for v0.4 file-based architecture references
_arch_file_conv = _ArchFileConv(_CallableFromFile_v0_4, ArchitectureFromFileDescr)

2452 

2453 

class _ArchLibConv(
    Converter[
        _CallableFromDepencency_v0_4, ArchitectureFromLibraryDescr, Dict[str, Any]
    ]
):
    """Converts a v0.4 dotted-path callable reference to v0.5."""

    def _convert(
        self,
        src: _CallableFromDepencency_v0_4,
        tgt: "type[ArchitectureFromLibraryDescr | dict[str, Any]]",
        kwargs: Dict[str, Any],
    ) -> "ArchitectureFromLibraryDescr | dict[str, Any]":
        # split "pkg.module.callable" into import path and callable name
        dotted = src.split(".")
        callable_name = dotted[-1]
        module_path = ".".join(dotted[:-1])
        return tgt(
            import_from=module_path, callable=Identifier(callable_name), kwargs=kwargs
        )

2470 

2471 

# reusable converter instance for v0.4 library-based architecture references
_arch_lib_conv = _ArchLibConv(
    _CallableFromDepencency_v0_4, ArchitectureFromLibraryDescr
)

2475 

2476 

class WeightsEntryDescrBase(FileDescr):
    """Base class for all weights format entries of a model."""

    type: ClassVar[WeightsFormat]
    weights_format_name: ClassVar[str]  # human readable

    source: Annotated[FileSource, AfterValidator(wo_special_file_name)]
    """Source of the weights file."""

    authors: Optional[List[Author]] = None
    """Authors
    Either the person(s) that have trained this model resulting in the original weights file.
    (If this is the initial weights entry, i.e. it does not have a `parent`)
    Or the person(s) who have converted the weights to this weights format.
    (If this is a child weight, i.e. it has a `parent` field)
    """

    parent: Annotated[
        Optional[WeightsFormat], Field(examples=["pytorch_state_dict"])
    ] = None
    """The source weights these weights were converted from.
    For example, if a model's weights were converted from the `pytorch_state_dict` format to `torchscript`,
    The `pytorch_state_dict` weights entry has no `parent` and is the parent of the `torchscript` weights.
    All weight entries except one (the initial set of weights resulting from training the model),
    need to have this field."""

    comment: str = ""
    """A comment about this weights entry, for example how these weights were created."""

    @model_validator(mode="after")
    def _validate(self) -> Self:
        """Reject an entry that names itself as its own conversion source."""
        if self.type == self.parent:
            raise ValueError("Weights entry can't be it's own parent.")

        return self

    @model_serializer(mode="wrap", when_used="unless-none")
    def _serialize(self, nxt: SerializerFunctionWrapHandler, info: SerializationInfo):
        # delegate to the file-packaging serializer so the weights file is
        # included when serializing for packaging
        return package_file_descr_serializer(self, nxt, info)

2514 

2515 

class KerasHdf5WeightsDescr(WeightsEntryDescrBase):
    """Weights in Keras HDF5 format."""

    type: ClassVar[WeightsFormat] = "keras_hdf5"
    weights_format_name: ClassVar[str] = "Keras HDF5"
    tensorflow_version: Version
    """TensorFlow version used to create these weights."""

2521 

2522 

class KerasV3WeightsDescr(WeightsEntryDescrBase):
    """Weights in Keras v3 (.keras) format."""

    type: ClassVar[WeightsFormat] = "keras_v3"
    weights_format_name: ClassVar[str] = "Keras v3"
    keras_version: Annotated[Version, Ge(Version(3))]
    """Keras version used to create these weights."""
    backend: Tuple[Literal["tensorflow", "jax", "torch"], Version]
    """Keras backend used to create these weights."""
    source: Annotated[
        FileSource,
        AfterValidator(wo_special_file_name),
        WithSuffix(".keras", case_sensitive=True),
    ]
    """Source of the .keras weights file."""

2536 

2537 

# file description restricted to ONNX external data files (.data suffix)
FileDescr_external_data = Annotated[
    FileDescr_,
    WithSuffix(".data", case_sensitive=True),
    Field(examples=[dict(source="weights.onnx.data")]),
]

2543 

2544 

class OnnxWeightsDescr(WeightsEntryDescrBase):
    """Weights in ONNX format."""

    type: ClassVar[WeightsFormat] = "onnx"
    weights_format_name: ClassVar[str] = "ONNX"
    opset_version: Annotated[int, Ge(7)]
    """ONNX opset version"""

    external_data: Optional[FileDescr_external_data] = None
    """Source of the external ONNX data file holding the weights.
    (If present **source** holds the ONNX architecture without weights)."""

    @model_validator(mode="after")
    def _validate_external_data_unique_file_name(self) -> Self:
        """Require distinct file names for `source` and `external_data`."""
        if self.external_data is not None and (
            extract_file_name(self.source)
            == extract_file_name(self.external_data.source)
        ):
            raise ValueError(
                f"ONNX `external_data` file name '{extract_file_name(self.external_data.source)}'"
                + " must be different from ONNX `source` file name."
            )

        return self

2567 

2568 

class PytorchStateDictWeightsDescr(WeightsEntryDescrBase):
    """Weights as a PyTorch state dict, requiring an architecture description."""

    type: ClassVar[WeightsFormat] = "pytorch_state_dict"
    weights_format_name: ClassVar[str] = "Pytorch State Dict"
    architecture: Union[ArchitectureFromFileDescr, ArchitectureFromLibraryDescr]
    pytorch_version: Version
    """Version of the PyTorch library used.
    If `architecture.dependencies` is specified it has to include pytorch and any version pinning has to be compatible.
    """
    dependencies: Optional[FileDescr_dependencies] = None
    """Custom dependencies beyond pytorch described in a Conda environment file.
    Allows to specify custom dependencies, see conda docs:
    - [Exporting an environment file across platforms](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#exporting-an-environment-file-across-platforms)
    - [Creating an environment file manually](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#creating-an-environment-file-manually)

    The conda environment file should include pytorch and any version pinning has to be compatible with
    **pytorch_version**.
    """

2586 

2587 

class TensorflowJsWeightsDescr(WeightsEntryDescrBase):
    """Weights in Tensorflow.js format."""

    type: ClassVar[WeightsFormat] = "tensorflow_js"
    weights_format_name: ClassVar[str] = "Tensorflow.js"
    tensorflow_version: Version
    """Version of the TensorFlow library used."""

    source: Annotated[FileSource, AfterValidator(wo_special_file_name)]
    """The multi-file weights.
    All required files/folders should be a zip archive."""

2597 

2598 

class TensorflowSavedModelBundleWeightsDescr(WeightsEntryDescrBase):
    """Weights as a Tensorflow saved model bundle."""

    type: ClassVar[WeightsFormat] = "tensorflow_saved_model_bundle"
    weights_format_name: ClassVar[str] = "Tensorflow Saved Model"
    tensorflow_version: Version
    """Version of the TensorFlow library used."""

    dependencies: Optional[FileDescr_dependencies] = None
    """Custom dependencies beyond tensorflow.
    Should include tensorflow and any version pinning has to be compatible with **tensorflow_version**."""

    source: Annotated[FileSource, AfterValidator(wo_special_file_name)]
    """The multi-file weights.
    All required files/folders should be a zip archive."""

2612 

2613 

class TorchscriptWeightsDescr(WeightsEntryDescrBase):
    """Weights in TorchScript format."""

    type: ClassVar[WeightsFormat] = "torchscript"
    weights_format_name: ClassVar[str] = "TorchScript"
    pytorch_version: Version
    """Version of the PyTorch library used."""

2619 

2620 

# union of all concrete weights entry descriptions
SpecificWeightsDescr = Union[
    KerasHdf5WeightsDescr,
    KerasV3WeightsDescr,
    OnnxWeightsDescr,
    PytorchStateDictWeightsDescr,
    TensorflowJsWeightsDescr,
    TensorflowSavedModelBundleWeightsDescr,
    TorchscriptWeightsDescr,
]

2630 

2631 

class WeightsDescr(Node):
    """All weight entries of a model, at most one entry per weights format.

    At least one entry must be given (see `check_entries`).
    Provides dict-like access by weights format name via `__getitem__` and
    `__setitem__`.
    """

    keras_hdf5: Optional[KerasHdf5WeightsDescr] = None
    keras_v3: Optional[KerasV3WeightsDescr] = None
    onnx: Optional[OnnxWeightsDescr] = None
    pytorch_state_dict: Optional[PytorchStateDictWeightsDescr] = None
    tensorflow_js: Optional[TensorflowJsWeightsDescr] = None
    tensorflow_saved_model_bundle: Optional[TensorflowSavedModelBundleWeightsDescr] = (
        None
    )
    torchscript: Optional[TorchscriptWeightsDescr] = None

    @model_validator(mode="after")
    def check_entries(self) -> Self:
        """Require at least one entry, warn unless exactly one entry has no
        `parent`, and verify every `parent` references a specified format."""
        entries = {wtype for wtype, entry in self if entry is not None}

        if not entries:
            raise ValueError("Missing weights entry")

        # the single entry without a `parent` is considered the original weights
        entries_wo_parent = {
            wtype
            for wtype, entry in self
            if entry is not None and hasattr(entry, "parent") and entry.parent is None
        }
        if len(entries_wo_parent) != 1:
            issue_warning(
                "Exactly one weights entry may not specify the `parent` field (got"
                + " {value}). That entry is considered the original set of model weights."
                + " Other weight formats are created through conversion of the orignal or"
                + " already converted weights. They have to reference the weights format"
                + " they were converted from as their `parent`.",
                value=len(entries_wo_parent),
                field="weights",
            )

        for wtype, entry in self:
            if entry is None:
                continue

            assert hasattr(entry, "type")
            assert hasattr(entry, "parent")
            assert wtype == entry.type
            if (
                entry.parent is not None and entry.parent not in entries
            ):  # self reference checked for `parent` field
                raise ValueError(
                    f"`weights.{wtype}.parent={entry.parent} not in specified weight"
                    + f" formats: {entries}"
                )

        return self

    def __getitem__(
        self,
        key: WeightsFormat,
    ):
        """Return the entry for weights format `key`.

        Raises:
            KeyError: If `key` is not a known format or its entry is not set.
        """
        if key == "keras_hdf5":
            ret = self.keras_hdf5
        elif key == "keras_v3":
            ret = self.keras_v3
        elif key == "onnx":
            ret = self.onnx
        elif key == "pytorch_state_dict":
            ret = self.pytorch_state_dict
        elif key == "tensorflow_js":
            ret = self.tensorflow_js
        elif key == "tensorflow_saved_model_bundle":
            ret = self.tensorflow_saved_model_bundle
        elif key == "torchscript":
            ret = self.torchscript
        else:
            raise KeyError(key)

        if ret is None:
            raise KeyError(key)

        return ret

    @overload
    def __setitem__(
        self, key: Literal["keras_hdf5"], value: Optional[KerasHdf5WeightsDescr]
    ) -> None: ...
    @overload
    def __setitem__(
        self, key: Literal["keras_v3"], value: Optional[KerasV3WeightsDescr]
    ) -> None: ...
    @overload
    def __setitem__(
        self, key: Literal["onnx"], value: Optional[OnnxWeightsDescr]
    ) -> None: ...
    @overload
    def __setitem__(
        self,
        key: Literal["pytorch_state_dict"],
        value: Optional[PytorchStateDictWeightsDescr],
    ) -> None: ...
    @overload
    def __setitem__(
        self, key: Literal["tensorflow_js"], value: Optional[TensorflowJsWeightsDescr]
    ) -> None: ...
    @overload
    def __setitem__(
        self,
        key: Literal["tensorflow_saved_model_bundle"],
        value: Optional[TensorflowSavedModelBundleWeightsDescr],
    ) -> None: ...
    @overload
    def __setitem__(
        self, key: Literal["torchscript"], value: Optional[TorchscriptWeightsDescr]
    ) -> None: ...

    def __setitem__(
        self,
        key: WeightsFormat,
        value: Optional[SpecificWeightsDescr],
    ):
        """Set the entry for weights format `key` after a runtime type check.

        Raises:
            TypeError: If `value` does not match the descr type for `key`.
            KeyError: If `key` is not a known weights format.
        """
        if key == "keras_hdf5":
            if value is not None and not isinstance(value, KerasHdf5WeightsDescr):
                raise TypeError(
                    f"Expected KerasHdf5WeightsDescr or None for key 'keras_hdf5', got {type(value)}"
                )
            self.keras_hdf5 = value
        elif key == "keras_v3":
            if value is not None and not isinstance(value, KerasV3WeightsDescr):
                raise TypeError(
                    f"Expected KerasV3WeightsDescr or None for key 'keras_v3', got {type(value)}"
                )
            self.keras_v3 = value
        elif key == "onnx":
            if value is not None and not isinstance(value, OnnxWeightsDescr):
                raise TypeError(
                    f"Expected OnnxWeightsDescr or None for key 'onnx', got {type(value)}"
                )
            self.onnx = value
        elif key == "pytorch_state_dict":
            if value is not None and not isinstance(
                value, PytorchStateDictWeightsDescr
            ):
                raise TypeError(
                    f"Expected PytorchStateDictWeightsDescr or None for key 'pytorch_state_dict', got {type(value)}"
                )
            self.pytorch_state_dict = value
        elif key == "tensorflow_js":
            if value is not None and not isinstance(value, TensorflowJsWeightsDescr):
                raise TypeError(
                    f"Expected TensorflowJsWeightsDescr or None for key 'tensorflow_js', got {type(value)}"
                )
            self.tensorflow_js = value
        elif key == "tensorflow_saved_model_bundle":
            if value is not None and not isinstance(
                value, TensorflowSavedModelBundleWeightsDescr
            ):
                raise TypeError(
                    f"Expected TensorflowSavedModelBundleWeightsDescr or None for key 'tensorflow_saved_model_bundle', got {type(value)}"
                )
            self.tensorflow_saved_model_bundle = value
        elif key == "torchscript":
            if value is not None and not isinstance(value, TorchscriptWeightsDescr):
                raise TypeError(
                    f"Expected TorchscriptWeightsDescr or None for key 'torchscript', got {type(value)}"
                )
            self.torchscript = value
        else:
            raise KeyError(key)

    @property
    def available_formats(self) -> Dict[WeightsFormat, SpecificWeightsDescr]:
        """All weights formats that have a specified entry."""
        return {
            **({} if self.keras_hdf5 is None else {"keras_hdf5": self.keras_hdf5}),
            **({} if self.keras_v3 is None else {"keras_v3": self.keras_v3}),
            **({} if self.onnx is None else {"onnx": self.onnx}),
            **(
                {}
                if self.pytorch_state_dict is None
                else {"pytorch_state_dict": self.pytorch_state_dict}
            ),
            **(
                {}
                if self.tensorflow_js is None
                else {"tensorflow_js": self.tensorflow_js}
            ),
            **(
                {}
                if self.tensorflow_saved_model_bundle is None
                else {
                    "tensorflow_saved_model_bundle": self.tensorflow_saved_model_bundle
                }
            ),
            **({} if self.torchscript is None else {"torchscript": self.torchscript}),
        }

    @property
    def missing_formats(self) -> Set[WeightsFormat]:
        """All weights formats without a specified entry."""
        return {
            wf for wf in get_args(WeightsFormat) if wf not in self.available_formats
        }

2827 

2828 

class ModelId(ResourceId):
    """A model resource id (no constraints beyond `ResourceId`)."""

    pass

2831 

2832 

class LinkedModel(LinkedResourceBase):
    """Reference to a bioimage.io model."""

    id: ModelId
    """A valid model `id` from the bioimage.io collection."""

2838 

2839 

class _DataDepSize(NamedTuple):
    """min/max of a data-dependent axis size (`max=None`: presumably unbounded)."""

    min: StrictInt
    max: Optional[StrictInt]

2843 

2844 

class _AxisSizes(NamedTuple):
    """The lengths of all axes of model inputs and outputs."""

    inputs: Dict[Tuple[TensorId, AxisId], int]
    outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]]

2850 

2851 

class _TensorSizes(NamedTuple):
    """`_AxisSizes` as nested dicts (tensor id -> axis id -> size)."""

    inputs: Dict[TensorId, Dict[AxisId, int]]
    outputs: Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]

2857 

2858 

class ReproducibilityTolerance(Node, extra="allow"):
    """Describes what small numerical differences -- if any -- may be tolerated
    in the generated output when executing in different environments.

    A tensor element *output* is considered mismatched to the **test_tensor** if
    abs(*output* - **test_tensor**) > **absolute_tolerance** + **relative_tolerance** * abs(**test_tensor**).
    (Internally we call [numpy.testing.assert_allclose](https://numpy.org/doc/stable/reference/generated/numpy.testing.assert_allclose.html).)

    Motivation:
        For testing we can request the respective deep learning frameworks to be as
        reproducible as possible by setting seeds and choosing deterministic algorithms,
        but differences in operating systems, available hardware and installed drivers
        may still lead to numerical differences.
    """

    relative_tolerance: RelativeTolerance = 1e-3
    """Maximum relative tolerance of reproduced test tensor."""

    absolute_tolerance: AbsoluteTolerance = 1e-3
    """Maximum absolute tolerance of reproduced test tensor."""

    mismatched_elements_per_million: MismatchedElementsPerMillion = 100
    """Maximum number of mismatched elements/pixels per million to tolerate."""

    output_ids: Sequence[TensorId] = ()
    """Limits the output tensor IDs these reproducibility details apply to."""

    weights_formats: Sequence[WeightsFormat] = ()
    """Limits the weights formats these details apply to."""

2888 

2889 

class BiasRisksLimitations(Node, extra="allow"):
    """Known biases, risks, technical limitations, and recommendations for model use."""

    known_biases: str = dedent("""\
        In general bioimage models may suffer from biases caused by:

        - Imaging protocol dependencies
        - Use of a specific cell type
        - Species-specific training data limitations

        """)
    """Biases in training data or model behavior."""

    risks: str = dedent("""\
        Common risks in bioimage analysis include:

        - Erroneously assuming generalization to unseen experimental conditions
        - Trusting (overconfident) model outputs without validation
        - Misinterpretation of results

        """)
    """Potential risks in the context of bioimage analysis."""

    limitations: Optional[str] = None
    """Technical limitations and failure modes."""

    recommendations: str = "Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model."
    """Mitigation strategies regarding `known_biases`, `risks`, and `limitations`, as well as applicable best practices.

    Consider:
    - How to use a validation dataset?
    - How to manually validate?
    - Feasibility of domain adaptation for different experimental setups?

    """

    def format_md(self) -> str:
        """Render these notes as a Markdown document."""
        # only emit the "Limitations" heading if limitations are given
        if self.limitations is None:
            limitations_header = ""
        else:
            limitations_header = "## Limitations\n\n"

        return f"""# Bias, Risks, and Limitations

{self.known_biases}

{self.risks}

{limitations_header}{self.limitations or ""}

## Recommendations

{self.recommendations}

"""

2945 

2946 

class TrainingDetails(Node, extra="allow"):
    """Details on how the model was trained."""

    training_preprocessing: Optional[str] = None
    """Detailed image preprocessing steps during model training:

    Mention:
    - *Normalization methods*
    - *Augmentation strategies*
    - *Resizing/resampling procedures*
    - *Artifact handling*

    """

    training_epochs: Optional[float] = None
    """Number of training epochs."""

    training_batch_size: Optional[float] = None
    """Batch size used in training."""

    initial_learning_rate: Optional[float] = None
    """Initial learning rate used in training."""

    learning_rate_schedule: Optional[str] = None
    """Learning rate schedule used in training."""

    loss_function: Optional[str] = None
    """Loss function used in training, e.g. nn.MSELoss."""

    loss_function_kwargs: Dict[str, YamlValue] = Field(
        default_factory=cast(Callable[[], Dict[str, YamlValue]], dict)
    )
    """key word arguments for the `loss_function`"""

    optimizer: Optional[str] = None
    """optimizer, e.g. torch.optim.Adam"""

    optimizer_kwargs: Dict[str, YamlValue] = Field(
        default_factory=cast(Callable[[], Dict[str, YamlValue]], dict)
    )
    """key word arguments for the `optimizer`"""

    regularization: Optional[str] = None
    """Regularization techniques used during training, e.g. drop-out or weight decay."""

    training_duration: Optional[float] = None
    """Total training duration in hours."""

2992 

2993 

class Evaluation(Node, extra="allow"):
    """Quantitative evaluation of a model on one dataset.

    `results` forms a table with one row per metric and one column per
    evaluation factor; the validator below enforces consistent lengths.
    """

    model_id: Optional[ModelId] = None
    """Model being evaluated."""

    dataset_id: DatasetId
    """Dataset used for evaluation."""

    dataset_source: HttpUrl
    """Source of the dataset."""

    dataset_role: Literal["train", "validation", "test", "independent", "unknown"]
    """Role of the dataset used for evaluation.

    - `train`: dataset was (part of) the training data
    - `validation`: dataset was (part of) the validation data used during training, e.g. used for model selection or hyperparameter tuning
    - `test`: dataset was (part of) the designated test data; not used during training or validation, but acquired from the same source/distribution as training data
    - `independent`: dataset is entirely independent test data; not used during training or validation, and acquired from a different source/distribution than training data
    - `unknown`: role of the dataset is unknown; choose this if you are not certain if (a subset) of the data was seen by the model during training.
    """

    sample_count: int
    """Number of evaluated samples."""

    evaluation_factors: List[Annotated[str, MaxLen(16)]]
    """(Abbreviations of) each evaluation factor.

    Evaluation factors are criteria along which model performance is evaluated, e.g. different image conditions
    like 'low SNR', 'high cell density', or different biological conditions like 'cell type A', 'cell type B'.
    An 'overall' factor may be included to summarize performance across all conditions.
    """

    evaluation_factors_long: List[str]
    """Descriptions (long form) of each evaluation factor."""

    metrics: List[Annotated[str, MaxLen(16)]]
    """(Abbreviations of) metrics used for evaluation."""

    metrics_long: List[str]
    """Description of each metric used."""

    @model_validator(mode="after")
    def _validate_list_lengths(self) -> Self:
        # The long-form lists must parallel the short-form lists, and `results`
        # must be a rectangular metrics x evaluation_factors table.
        if len(self.evaluation_factors) != len(self.evaluation_factors_long):
            raise ValueError(
                "`evaluation_factors` and `evaluation_factors_long` must have the same length"
            )

        if len(self.metrics) != len(self.metrics_long):
            raise ValueError("`metrics` and `metrics_long` must have the same length")

        if len(self.results) != len(self.metrics):
            raise ValueError("`results` must have the same number of rows as `metrics`")

        for row in self.results:
            if len(row) != len(self.evaluation_factors):
                raise ValueError(
                    "`results` must have the same number of columns (in every row) as `evaluation_factors`"
                )

        return self

    results: List[List[Union[str, float, int]]]
    """Results for each metric (rows; outer list) and each evaluation factor (columns; inner list)."""

    results_summary: Optional[str] = None
    """Interpretation of results for general audience.

    Consider:
    - Overall model performance
    - Comparison to existing methods
    - Limitations and areas for improvement
    """

    def format_md(self) -> str:
        """Render this evaluation as a Markdown section with a results table."""
        # Markdown table: header row, '---' separator row, then one row per metric.
        results_header = ["Metric"] + self.evaluation_factors
        results_table_cells = [results_header, ["---"] * len(results_header)] + [
            [metric] + [str(r) for r in row]
            for metric, row in zip(self.metrics, self.results)
        ]

        results_table = "".join(
            "| " + " | ".join(row) + " |\n" for row in results_table_cells
        )
        # Legends mapping abbreviations to their long-form descriptions.
        factors = "".join(
            f"\n - {ef}: {efl}"
            for ef, efl in zip(self.evaluation_factors, self.evaluation_factors_long)
        )
        metrics = "".join(
            f"\n - {em}: {eml}" for em, eml in zip(self.metrics, self.metrics_long)
        )

        return f"""## Testing Data, Factors & Metrics

Evaluation of {self.model_id or "this"} model on the {self.dataset_id} dataset (dataset role: {self.dataset_role}).

### Testing Data

- **Source:** [{self.dataset_id}]({self.dataset_source})
- **Size:** {self.sample_count} evaluated samples

### Factors
{factors}

### Metrics
{metrics}

## Results

### Quantitative Results

{results_table}

### Summary

{self.results_summary or "missing"}

"""

3112 

3113 

class EnvironmentalImpact(Node, extra="allow"):
    """Environmental considerations for model training and deployment.

    Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
    """

    hardware_type: Optional[str] = None
    """GPU/CPU specifications"""

    hours_used: Optional[float] = None
    """Total compute hours"""

    cloud_provider: Optional[str] = None
    """If applicable"""

    compute_region: Optional[str] = None
    """Geographic location"""

    co2_emitted: Optional[float] = None
    """kg CO2 equivalent

    Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
    """

    def format_md(self):
        """Filled Markdown template section following [Hugging Face Model Card Template](https://huggingface.co/docs/hub/en/model-card-annotated)."""
        # An all-default instance has nothing to report.
        if self == self.__class__():
            return ""

        parts = ["# Environmental Impact", ""]
        field_rows = (
            ("Hardware Type", self.hardware_type, ""),
            ("Hours used", self.hours_used, ""),
            ("Cloud Provider", self.cloud_provider, ""),
            ("Compute Region", self.compute_region, ""),
            ("Carbon Emitted", self.co2_emitted, " kg CO2e"),
        )
        # One bullet per field that is actually set.
        for label, value, suffix in field_rows:
            if value is not None:
                parts.append(f"- **{label}:** {value}{suffix}")

        return "\n".join(parts) + "\n\n"

3156 

3157 

class BioimageioConfig(Node, extra="allow"):
    """bioimage.io-specific configuration fields (stored under `config.bioimageio`)."""

    reproducibility_tolerance: Sequence[ReproducibilityTolerance] = ()
    """Tolerances to allow when reproducing the model's test outputs
    from the model's test inputs.
    Only the first entry matching tensor id and weights format is considered.
    """

    funded_by: Optional[str] = None
    """Funding agency, grant number if applicable"""

    architecture_type: Optional[Annotated[str, MaxLen(32)]] = (
        None  # TODO: add to differentiated tags
    )
    """Model architecture type, e.g., 3D U-Net, ResNet, transformer"""

    architecture_description: Optional[str] = None
    """Text description of model architecture."""

    modality: Optional[str] = None  # TODO: add to differentiated tags
    """Input modality, e.g., fluorescence microscopy, electron microscopy"""

    target_structure: List[str] = Field(  # TODO: add to differentiated tags
        default_factory=cast(Callable[[], List[str]], list)
    )
    """Biological structure(s) the model is designed to analyze, e.g., nuclei, mitochondria, cells"""

    task: Optional[str] = None  # TODO: add to differentiated tags
    """Bioimage-specific task type, e.g., segmentation, classification, detection, denoising"""

    new_version: Optional[ModelId] = None
    """A new version of this model exists with a different model id."""

    out_of_scope_use: Optional[str] = None
    """Describe how the model may be misused in bioimage analysis contexts and what users should **not** do with the model."""

    bias_risks_limitations: BiasRisksLimitations = Field(
        default_factory=BiasRisksLimitations.model_construct
    )
    """Description of known bias, risks, and technical limitations for in-scope model use."""

    model_parameter_count: Optional[int] = None
    """Total number of model parameters."""

    training: TrainingDetails = Field(default_factory=TrainingDetails.model_construct)
    """Details on how the model was trained."""

    inference_time: Optional[str] = None
    """Average inference time per image/tile. Specify hardware and image size. Multiple examples can be given."""

    memory_requirements_inference: Optional[str] = None
    """GPU memory needed for inference. Multiple examples with different image size can be given."""

    memory_requirements_training: Optional[str] = None
    """GPU memory needed for training. Multiple examples with different image/batch sizes can be given."""

    evaluations: List[Evaluation] = Field(
        default_factory=cast(Callable[[], List[Evaluation]], list)
    )
    """Quantitative model evaluations.

    Note:
        At the moment we recommend to include only a single test dataset
        (with evaluation factors that may mark subsets of the dataset)
        to avoid confusion and make the presentation of results cleaner.
    """

    environmental_impact: EnvironmentalImpact = Field(
        default_factory=EnvironmentalImpact.model_construct
    )
    """Environmental considerations for model training and deployment"""

3228 

3229 

class Config(Node, extra="allow"):
    """Free-form configuration container with a dedicated `bioimageio` sub-section."""

    bioimageio: BioimageioConfig = Field(
        default_factory=BioimageioConfig.model_construct
    )
    # NOTE(review): presumably StarDist-tool specific configuration kept as raw
    # YAML; it is intentionally left opaque here — confirm against consumers.
    stardist: YamlValue = None

3235 

3236 

3237class ModelDescr(GenericModelDescrBase): 

3238 """Specification of the fields used in a bioimage.io-compliant RDF to describe AI models with pretrained weights. 

3239 These fields are typically stored in a YAML file which we call a model resource description file (model RDF). 

3240 """ 

3241 

    implemented_format_version: ClassVar[Literal["0.5.9"]] = "0.5.9"
    if TYPE_CHECKING:
        # give type checkers a default so partially specified models type-check
        format_version: Literal["0.5.9"] = "0.5.9"
    else:
        format_version: Literal["0.5.9"]
        """Version of the bioimage.io model description specification used.
        When creating a new model always use the latest micro/patch version described here.
        The `format_version` is important for any consumer software to understand how to parse the fields.
        """

    implemented_type: ClassVar[Literal["model"]] = "model"
    if TYPE_CHECKING:
        type: Literal["model"] = "model"
    else:
        type: Literal["model"]
        """Specialized resource type 'model'"""

    id: Optional[ModelId] = None
    """bioimage.io-wide unique resource identifier
    assigned by bioimage.io; version **un**specific."""

    authors: FAIR[List[Author]] = Field(
        default_factory=cast(Callable[[], List[Author]], list)
    )
    """The authors are the creators of the model RDF and the primary points of contact."""

    documentation: FAIR[Optional[FileSource_documentation]] = None
    """URL or relative path to a markdown file with additional documentation.
    The recommended documentation file name is `README.md`. An `.md` suffix is mandatory.
    The documentation should include a '#[#] Validation' (sub)section
    with details on how to quantitatively validate the model on unseen data."""

3273 

3274 @field_validator("documentation", mode="after") 

3275 @classmethod 

3276 def _validate_documentation( 

3277 cls, value: Optional[FileSource_documentation] 

3278 ) -> Optional[FileSource_documentation]: 

3279 if not get_validation_context().perform_io_checks or value is None: 

3280 return value 

3281 

3282 doc_reader = get_reader(value) 

3283 doc_content = doc_reader.read().decode(encoding="utf-8") 

3284 if not re.search("#.*[vV]alidation", doc_content): 

3285 issue_warning( 

3286 "No '# Validation' (sub)section found in {value}.", 

3287 value=value, 

3288 field="documentation", 

3289 ) 

3290 

3291 return value 

3292 

    inputs: NotEmpty[Sequence[InputTensorDescr]]
    """Describes the input tensors expected by this model."""

    @field_validator("inputs", mode="after")
    @classmethod
    def _validate_input_axes(
        cls, inputs: Sequence[InputTensorDescr]
    ) -> Sequence[InputTensorDescr]:
        # Axes with an independent size (int or parameterized) across all inputs
        # are valid targets for `SizeReference`s.
        input_size_refs = cls._get_axes_with_independent_size(inputs)

        for i, ipt in enumerate(inputs):
            # per-tensor refs first, then the全 input-wide refs (which overwrite
            # duplicates with identical entries)
            valid_independent_refs: Dict[
                Tuple[TensorId, AxisId],
                Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
            ] = {
                **{
                    (ipt.id, a.id): (ipt, a, a.size)
                    for a in ipt.axes
                    if not isinstance(a, BatchAxis)
                    and isinstance(a.size, (int, ParameterizedSize))
                },
                **input_size_refs,
            }
            for a, ax in enumerate(ipt.axes):
                cls._validate_axis(
                    "inputs",
                    i=i,
                    tensor_id=ipt.id,
                    a=a,
                    axis=ax,
                    valid_independent_refs=valid_independent_refs,
                )
        return inputs

3326 

    @staticmethod
    def _validate_axis(
        field_name: str,
        i: int,
        tensor_id: TensorId,
        a: int,
        axis: AnyAxis,
        valid_independent_refs: Dict[
            Tuple[TensorId, AxisId],
            Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
        ],
    ):
        # Validate one axis of tensor `field_name[i]` (id `tensor_id`).
        # Only `SizeReference` sizes need checking; batch axes and axes with an
        # int/parameterized/data-dependent size are accepted as-is.
        if isinstance(axis, BatchAxis) or isinstance(
            axis.size, (int, ParameterizedSize, DataDependentSize)
        ):
            return
        elif not isinstance(axis.size, SizeReference):
            assert_never(axis.size)

        # validate axis.size SizeReference
        ref = (axis.size.tensor_id, axis.size.axis_id)
        if ref not in valid_independent_refs:
            raise ValueError(
                "Invalid tensor axis reference at"
                + f" {field_name}[{i}].axes[{a}].size: {axis.size}."
            )
        if ref == (tensor_id, axis.id):
            raise ValueError(
                "Self-referencing not allowed for"
                + f" {field_name}[{i}].axes[{a}].size: {axis.size}"
            )
        if axis.type == "channel":
            # a channel axis may only reference another channel axis of fixed size
            if valid_independent_refs[ref][1].type != "channel":
                raise ValueError(
                    "A channel axis' size may only reference another fixed size"
                    + " channel axis."
                )
            if isinstance(axis.channel_names, str) and "{i}" in axis.channel_names:
                # expand a '{i}' channel-name template to one name per channel
                ref_size = valid_independent_refs[ref][2]
                assert isinstance(ref_size, int), (
                    "channel axis ref (another channel axis) has to specify fixed"
                    + " size"
                )
                generated_channel_names = [
                    Identifier(axis.channel_names.format(i=i))
                    for i in range(1, ref_size + 1)
                ]
                axis.channel_names = generated_channel_names

        # axis and reference axis must agree on their (optional) unit
        if (ax_unit := getattr(axis, "unit", None)) != (
            ref_unit := getattr(valid_independent_refs[ref][1], "unit", None)
        ):
            raise ValueError(
                "The units of an axis and its reference axis need to match, but"
                + f" '{ax_unit}' != '{ref_unit}'."
            )
        ref_axis = valid_independent_refs[ref][1]
        if isinstance(ref_axis, BatchAxis):
            raise ValueError(
                f"Invalid reference axis '{ref_axis.id}' for {tensor_id}.{axis.id}"
                + " (a batch axis is not allowed as reference)."
            )

        if isinstance(axis, WithHalo):
            # the minimal axis size must leave at least 1 element after removing
            # the halo on both sides
            min_size = axis.size.get_size(axis, ref_axis, n=0)
            if (min_size - 2 * axis.halo) < 1:
                raise ValueError(
                    f"axis {axis.id} with minimum size {min_size} is too small for halo"
                    + f" {axis.halo}."
                )

            # the halo mapped into the reference axis' scale must be an integer
            ref_halo = axis.halo * axis.scale / ref_axis.scale
            if ref_halo != int(ref_halo):
                raise ValueError(
                    f"Inferred halo for {'.'.join(ref)} is not an integer ({ref_halo} ="
                    + f" {tensor_id}.{axis.id}.halo {axis.halo}"
                    + f" * {tensor_id}.{axis.id}.scale {axis.scale}"
                    + f" / {'.'.join(ref)}.scale {ref_axis.scale})."
                )

3406 

3407 def validate_input_tensors( 

3408 self, 

3409 sources: Union[ 

3410 Sequence[NDArray[Any]], Mapping[TensorId, Optional[NDArray[Any]]] 

3411 ], 

3412 ) -> Mapping[TensorId, Optional[NDArray[Any]]]: 

3413 """Check if the given input tensors match the model's input tensor descriptions. 

3414 This includes checks of tensor shapes and dtypes, but not of the actual values. 

3415 """ 

3416 if not isinstance(sources, collections.abc.Mapping): 

3417 sources = {descr.id: tensor for descr, tensor in zip(self.inputs, sources)} 

3418 

3419 tensors = {descr.id: (descr, sources.get(descr.id)) for descr in self.inputs} 

3420 validate_tensors(tensors) 

3421 

3422 return sources 

3423 

    @model_validator(mode="after")
    def _validate_test_tensors(self) -> Self:
        # Load all test tensors and check them against their descriptions;
        # skipped entirely when IO checks are disabled in the validation context.
        if not get_validation_context().perform_io_checks:
            return self

        test_inputs = {
            descr.id: (
                descr,
                None if descr.test_tensor is None else load_array(descr.test_tensor),
            )
            for descr in self.inputs
        }
        test_outputs = {
            descr.id: (
                descr,
                None if descr.test_tensor is None else load_array(descr.test_tensor),
            )
            for descr in self.outputs
        }

        validate_tensors({**test_inputs, **test_outputs}, tensor_origin="test_tensor")

        # Sanity-check configured reproducibility tolerances: an absolute
        # tolerance above 1% of a test tensor's maximum value is rejected.
        for rep_tol in self.config.bioimageio.reproducibility_tolerance:
            if not rep_tol.absolute_tolerance:
                continue

            # restrict the check to the configured output ids, if any
            if rep_tol.output_ids:
                out_arrays = {
                    k: v[1] for k, v in test_outputs.items() if k in rep_tol.output_ids
                }
            else:
                out_arrays = {k: v[1] for k, v in test_outputs.items()}

            for out_id, array in out_arrays.items():
                if array is None:
                    continue

                if rep_tol.absolute_tolerance > (max_test_value := array.max()) * 0.01:
                    raise ValueError(
                        "config.bioimageio.reproducibility_tolerance.absolute_tolerance="
                        + f"{rep_tol.absolute_tolerance} > 0.01*{max_test_value}"
                        + f" (1% of the maximum value of the test tensor '{out_id}')"
                    )

        return self

3469 

3470 @model_validator(mode="after") 

3471 def _validate_tensor_references_in_proc_kwargs(self, info: ValidationInfo) -> Self: 

3472 ipt_refs = {t.id for t in self.inputs} 

3473 missing_refs = [ 

3474 k["reference_tensor"] 

3475 for k in [p.kwargs for ipt in self.inputs for p in ipt.preprocessing] 

3476 + [p.kwargs for out in self.outputs for p in out.postprocessing] 

3477 if "reference_tensor" in k 

3478 and k["reference_tensor"] is not None 

3479 and k["reference_tensor"] not in ipt_refs 

3480 ] 

3481 

3482 if missing_refs: 

3483 raise ValueError( 

3484 f"`reference_tensor`s {missing_refs} not found. Valid input tensor" 

3485 + f" references are: {ipt_refs}." 

3486 ) 

3487 

3488 return self 

3489 

    name: Annotated[
        str,
        RestrictCharacters(string.ascii_letters + string.digits + "_+- ()"),
        MinLen(5),
        MaxLen(128),
        warn(MaxLen(64), "Name longer than 64 characters.", INFO),
    ]
    """A human-readable name of this model.
    It should be no longer than 64 characters
    and may only contain letters, numbers, underscore, minus, parentheses and spaces.
    We recommend to choose a name that refers to the model's task and image modality.
    """

3502 

3503 outputs: NotEmpty[Sequence[OutputTensorDescr]] 

3504 """Describes the output tensors.""" 

3505 

3506 @field_validator("outputs", mode="after") 

3507 @classmethod 

3508 def _validate_tensor_ids( 

3509 cls, outputs: Sequence[OutputTensorDescr], info: ValidationInfo 

3510 ) -> Sequence[OutputTensorDescr]: 

3511 tensor_ids = [ 

3512 t.id for t in info.data.get("inputs", []) + info.data.get("outputs", []) 

3513 ] 

3514 duplicate_tensor_ids: List[str] = [] 

3515 seen: Set[str] = set() 

3516 for t in tensor_ids: 

3517 if t in seen: 

3518 duplicate_tensor_ids.append(t) 

3519 

3520 seen.add(t) 

3521 

3522 if duplicate_tensor_ids: 

3523 raise ValueError(f"Duplicate tensor ids: {duplicate_tensor_ids}") 

3524 

3525 return outputs 

3526 

3527 @staticmethod 

3528 def _get_axes_with_parameterized_size( 

3529 io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]], 

3530 ): 

3531 return { 

3532 f"{t.id}.{a.id}": (t, a, a.size) 

3533 for t in io 

3534 for a in t.axes 

3535 if not isinstance(a, BatchAxis) and isinstance(a.size, ParameterizedSize) 

3536 } 

3537 

3538 @staticmethod 

3539 def _get_axes_with_independent_size( 

3540 io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]], 

3541 ): 

3542 return { 

3543 (t.id, a.id): (t, a, a.size) 

3544 for t in io 

3545 for a in t.axes 

3546 if not isinstance(a, BatchAxis) 

3547 and isinstance(a.size, (int, ParameterizedSize)) 

3548 } 

3549 

    @field_validator("outputs", mode="after")
    @classmethod
    def _validate_output_axes(
        cls, outputs: List[OutputTensorDescr], info: ValidationInfo
    ) -> List[OutputTensorDescr]:
        # Output axes may reference independently-sized axes of both inputs
        # (already validated, available via `info.data`) and outputs.
        input_size_refs = cls._get_axes_with_independent_size(
            info.data.get("inputs", [])
        )
        output_size_refs = cls._get_axes_with_independent_size(outputs)

        for i, out in enumerate(outputs):
            valid_independent_refs: Dict[
                Tuple[TensorId, AxisId],
                Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
            ] = {
                **{
                    (out.id, a.id): (out, a, a.size)
                    for a in out.axes
                    if not isinstance(a, BatchAxis)
                    and isinstance(a.size, (int, ParameterizedSize))
                },
                **input_size_refs,
                **output_size_refs,
            }
            for a, ax in enumerate(out.axes):
                cls._validate_axis(
                    "outputs",
                    i,
                    out.id,
                    a,
                    ax,
                    valid_independent_refs=valid_independent_refs,
                )

        return outputs

3585 

    packaged_by: List[Author] = Field(
        default_factory=cast(Callable[[], List[Author]], list)
    )
    """The persons that have packaged and uploaded this model.
    Only required if those persons differ from the `authors`."""

    parent: Optional[LinkedModel] = None
    """The model from which this model is derived, e.g. by fine-tuning the weights."""

3594 

3595 @model_validator(mode="after") 

3596 def _validate_parent_is_not_self(self) -> Self: 

3597 if self.parent is not None and self.parent.id == self.id: 

3598 raise ValueError("A model description may not reference itself as parent.") 

3599 

3600 return self 

3601 

    run_mode: Annotated[
        Optional[RunMode],
        warn(None, "Run mode '{value}' has limited support across consumer softwares."),
    ] = None
    """Custom run mode for this model: for more complex prediction procedures like test time
    data augmentation that currently cannot be expressed in the specification.
    No standard run modes are defined yet."""

    timestamp: Datetime = Field(default_factory=Datetime.now)
    """Timestamp in [ISO 8601](#https://en.wikipedia.org/wiki/ISO_8601) format
    with a few restrictions listed [here](https://docs.python.org/3/library/datetime.html#datetime.datetime.fromisoformat).
    (In Python a datetime object is valid, too)."""

    training_data: Annotated[
        Union[None, LinkedDataset, DatasetDescr, DatasetDescr02],
        Field(union_mode="left_to_right"),
    ] = None
    """The dataset used to train this model"""

    weights: Annotated[WeightsDescr, WrapSerializer(package_weights)]
    """The weights for this model.
    Weights can be given for different formats, but should otherwise be equivalent.
    The available weight formats determine which consumers can use this model."""

    # free-form configuration; `config.bioimageio` holds the bioimage.io
    # specific fields defined by `BioimageioConfig` above
    config: Config = Field(default_factory=Config.model_construct)

3627 

    @model_validator(mode="after")
    def _add_default_cover(self) -> Self:
        # Generate cover image(s) from the test tensors when none were provided;
        # only attempted when IO checks are enabled. Generation failures are
        # reported as a warning, not an error.
        if not get_validation_context().perform_io_checks or self.covers:
            return self

        try:
            generated_covers = generate_covers(
                [
                    (t, load_array(t.test_tensor))
                    for t in self.inputs
                    if t.test_tensor is not None
                ],
                [
                    (t, load_array(t.test_tensor))
                    for t in self.outputs
                    if t.test_tensor is not None
                ],
            )
        except Exception as e:
            issue_warning(
                "Failed to generate cover image(s): {e}",
                value=self.covers,
                msg_context=dict(e=e),
                field="covers",
            )
        else:
            self.covers.extend(generated_covers)

        return self

3657 

3658 def get_input_test_arrays(self) -> List[NDArray[Any]]: 

3659 return self._get_test_arrays(self.inputs) 

3660 

3661 def get_output_test_arrays(self) -> List[NDArray[Any]]: 

3662 return self._get_test_arrays(self.outputs) 

3663 

3664 @staticmethod 

3665 def _get_test_arrays( 

3666 io_descr: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]], 

3667 ): 

3668 ts: List[FileDescr] = [] 

3669 for d in io_descr: 

3670 if d.test_tensor is None: 

3671 raise ValueError( 

3672 f"Failed to get test arrays: description of '{d.id}' is missing a `test_tensor`." 

3673 ) 

3674 ts.append(d.test_tensor) 

3675 

3676 data = [load_array(t) for t in ts] 

3677 assert all(isinstance(d, np.ndarray) for d in data) 

3678 return data 

3679 

3680 @staticmethod 

3681 def get_batch_size(tensor_sizes: Mapping[TensorId, Mapping[AxisId, int]]) -> int: 

3682 batch_size = 1 

3683 tensor_with_batchsize: Optional[TensorId] = None 

3684 for tid in tensor_sizes: 

3685 for aid, s in tensor_sizes[tid].items(): 

3686 if aid != BATCH_AXIS_ID or s == 1 or s == batch_size: 

3687 continue 

3688 

3689 if batch_size != 1: 

3690 assert tensor_with_batchsize is not None 

3691 raise ValueError( 

3692 f"batch size mismatch for tensors '{tensor_with_batchsize}' ({batch_size}) and '{tid}' ({s})" 

3693 ) 

3694 

3695 batch_size = s 

3696 tensor_with_batchsize = tid 

3697 

3698 return batch_size 

3699 

3700 def get_output_tensor_sizes( 

3701 self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]] 

3702 ) -> Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]: 

3703 """Returns the tensor output sizes for given **input_sizes**. 

3704 Only if **input_sizes** has a valid input shape, the tensor output size is exact. 

3705 Otherwise it might be larger than the actual (valid) output""" 

3706 batch_size = self.get_batch_size(input_sizes) 

3707 ns = self.get_ns(input_sizes) 

3708 

3709 tensor_sizes = self.get_tensor_sizes(ns, batch_size=batch_size) 

3710 return tensor_sizes.outputs 

3711 

3712 def get_ns(self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]): 

3713 """get parameter `n` for each parameterized axis 

3714 such that the valid input size is >= the given input size""" 

3715 ret: Dict[Tuple[TensorId, AxisId], ParameterizedSize_N] = {} 

3716 axes = {t.id: {a.id: a for a in t.axes} for t in self.inputs} 

3717 for tid in input_sizes: 

3718 for aid, s in input_sizes[tid].items(): 

3719 size_descr = axes[tid][aid].size 

3720 if isinstance(size_descr, ParameterizedSize): 

3721 ret[(tid, aid)] = size_descr.get_n(s) 

3722 elif size_descr is None or isinstance(size_descr, (int, SizeReference)): 

3723 pass 

3724 else: 

3725 assert_never(size_descr) 

3726 

3727 return ret 

3728 

3729 def get_tensor_sizes( 

3730 self, ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], batch_size: int 

3731 ) -> _TensorSizes: 

3732 axis_sizes = self.get_axis_sizes(ns, batch_size=batch_size) 

3733 return _TensorSizes( 

3734 { 

3735 t: { 

3736 aa: axis_sizes.inputs[(tt, aa)] 

3737 for tt, aa in axis_sizes.inputs 

3738 if tt == t 

3739 } 

3740 for t in {tt for tt, _ in axis_sizes.inputs} 

3741 }, 

3742 { 

3743 t: { 

3744 aa: axis_sizes.outputs[(tt, aa)] 

3745 for tt, aa in axis_sizes.outputs 

3746 if tt == t 

3747 } 

3748 for t in {tt for tt, _ in axis_sizes.outputs} 

3749 }, 

3750 ) 

3751 

    def get_axis_sizes(
        self,
        ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N],
        batch_size: Optional[int] = None,
        *,
        max_input_shape: Optional[Mapping[Tuple[TensorId, AxisId], int]] = None,
    ) -> _AxisSizes:
        """Determine input and output block shape for scale factors **ns**
        of parameterized input sizes.

        Args:
            ns: Scale factor `n` for each axis (keyed by (tensor_id, axis_id))
                that is parameterized as `size = min + n * step`.
            batch_size: The desired size of the batch dimension.
                If given **batch_size** overwrites any batch size present in
                **max_input_shape**. Default 1.
            max_input_shape: Limits the derived block shapes.
                Each axis for which the input size, parameterized by `n`, is larger
                than **max_input_shape** is set to the minimal value `n_min` for which
                this is still true.
                Use this for small input samples or large values of **ns**.
                Or simply whenever you know the full input shape.

        Returns:
            Resolved axis sizes for model inputs and outputs.
        """
        max_input_shape = max_input_shape or {}
        if batch_size is None:
            # derive the batch size from **max_input_shape** when it contains
            # a batch axis entry; otherwise default to 1 (for-else)
            for (_t_id, a_id), s in max_input_shape.items():
                if a_id == BATCH_AXIS_ID:
                    batch_size = s
                    break
            else:
                batch_size = 1

        # lookup of every axis description by (tensor id -> axis id),
        # spanning both inputs and outputs so SizeReference targets resolve
        all_axes = {
            t.id: {a.id: a for a in t.axes} for t in chain(self.inputs, self.outputs)
        }

        inputs: Dict[Tuple[TensorId, AxisId], int] = {}
        outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]] = {}

        def get_axis_size(a: Union[InputAxis, OutputAxis]):
            # Resolve the size of a single axis.
            # NOTE: reads `t_descr` from the enclosing loops below, so the
            # warnings/errors can name the tensor currently being processed.
            if isinstance(a, BatchAxis):
                if (t_descr.id, a.id) in ns:
                    logger.warning(
                        "Ignoring unexpected size increment factor (n) for batch axis"
                        + " of tensor '{}'.",
                        t_descr.id,
                    )
                return batch_size
            elif isinstance(a.size, int):
                if (t_descr.id, a.id) in ns:
                    logger.warning(
                        "Ignoring unexpected size increment factor (n) for fixed size"
                        + " axis '{}' of tensor '{}'.",
                        a.id,
                        t_descr.id,
                    )
                return a.size
            elif isinstance(a.size, ParameterizedSize):
                if (t_descr.id, a.id) not in ns:
                    raise ValueError(
                        "Size increment factor (n) missing for parametrized axis"
                        + f" '{a.id}' of tensor '{t_descr.id}'."
                    )
                n = ns[(t_descr.id, a.id)]
                # clamp n such that the resulting size does not exceed
                # the given maximum input shape for this axis
                s_max = max_input_shape.get((t_descr.id, a.id))
                if s_max is not None:
                    n = min(n, a.size.get_n(s_max))

                return a.size.get_size(n)

            elif isinstance(a.size, SizeReference):
                if (t_descr.id, a.id) in ns:
                    logger.warning(
                        "Ignoring unexpected size increment factor (n) for axis '{}'"
                        + " of tensor '{}' with size reference.",
                        a.id,
                        t_descr.id,
                    )
                assert not isinstance(a, BatchAxis)
                ref_axis = all_axes[a.size.tensor_id][a.size.axis_id]
                assert not isinstance(ref_axis, BatchAxis)
                ref_key = (a.size.tensor_id, a.size.axis_id)
                # the referenced axis must have been resolved already
                # (guaranteed by the resolution order below)
                ref_size = inputs.get(ref_key, outputs.get(ref_key))
                assert ref_size is not None, ref_key
                assert not isinstance(ref_size, _DataDepSize), ref_key
                return a.size.get_size(
                    axis=a,
                    ref_axis=ref_axis,
                    ref_size=ref_size,
                )
            elif isinstance(a.size, DataDependentSize):
                if (t_descr.id, a.id) in ns:
                    logger.warning(
                        "Ignoring unexpected increment factor (n) for data dependent"
                        + " size axis '{}' of tensor '{}'.",
                        a.id,
                        t_descr.id,
                    )
                return _DataDepSize(a.size.min, a.size.max)
            else:
                assert_never(a.size)

        # first resolve all input axis sizes except the `SizeReference` ones
        for t_descr in self.inputs:
            for a in t_descr.axes:
                if not isinstance(a.size, SizeReference):
                    s = get_axis_size(a)
                    assert not isinstance(s, _DataDepSize)
                    inputs[t_descr.id, a.id] = s

        # resolve all other input axis sizes
        # (their references can only point to already resolved axes now)
        for t_descr in self.inputs:
            for a in t_descr.axes:
                if isinstance(a.size, SizeReference):
                    s = get_axis_size(a)
                    assert not isinstance(s, _DataDepSize)
                    inputs[t_descr.id, a.id] = s

        # resolve all output axis sizes
        for t_descr in self.outputs:
            for a in t_descr.axes:
                assert not isinstance(a.size, ParameterizedSize)
                s = get_axis_size(a)
                outputs[t_descr.id, a.id] = s

        return _AxisSizes(inputs=inputs, outputs=outputs)

3881 

    @model_validator(mode="before")
    @classmethod
    def _convert(cls, data: Dict[str, Any]) -> Dict[str, Any]:
        # pydantic before-validator: migrate metadata written in an older
        # format version in place, then let regular field validation proceed
        cls.convert_from_old_format_wo_validation(data)
        return data

3887 

3888 @classmethod 

3889 def convert_from_old_format_wo_validation(cls, data: Dict[str, Any]) -> None: 

3890 """Convert metadata following an older format version to this classes' format 

3891 without validating the result. 

3892 """ 

3893 if ( 

3894 data.get("type") == "model" 

3895 and isinstance(fv := data.get("format_version"), str) 

3896 and fv.count(".") == 2 

3897 ): 

3898 fv_parts = fv.split(".") 

3899 if any(not p.isdigit() for p in fv_parts): 

3900 return 

3901 

3902 fv_tuple = tuple(map(int, fv_parts)) 

3903 

3904 assert cls.implemented_format_version_tuple[0:2] == (0, 5) 

3905 if fv_tuple[:2] in ((0, 3), (0, 4)): 

3906 m04 = _ModelDescr_v0_4.load(data) 

3907 if isinstance(m04, InvalidDescr): 

3908 try: 

3909 updated = _model_conv.convert_as_dict( 

3910 m04 # pyright: ignore[reportArgumentType] 

3911 ) 

3912 except Exception as e: 

3913 logger.error( 

3914 "Failed to convert from invalid model 0.4 description." 

3915 + f"\nerror: {e}" 

3916 + "\nProceeding with model 0.5 validation without conversion." 

3917 ) 

3918 updated = None 

3919 else: 

3920 updated = _model_conv.convert_as_dict(m04) 

3921 

3922 if updated is not None: 

3923 data.clear() 

3924 data.update(updated) 

3925 

3926 elif fv_tuple[:2] == (0, 5): 

3927 # bump patch version 

3928 data["format_version"] = cls.implemented_format_version 

3929 

3930 

class _ModelConv(Converter[_ModelDescr_v0_4, ModelDescr]):
    """Converter from model description format 0.4 to 0.5."""

    def _convert(
        self, src: _ModelDescr_v0_4, tgt: "type[ModelDescr] | type[dict[str, Any]]"
    ) -> "ModelDescr | dict[str, Any]":
        # replace characters not allowed in a 0.5 model name by a space
        name = "".join(
            c if c in string.ascii_letters + string.digits + "_+- ()" else " "
            for c in src.name
        )

        def conv_authors(auths: Optional[Sequence[_Author_v0_4]]):
            # under TYPE_CHECKING use the typed converter so the static types
            # line up; at runtime convert to plain dicts
            conv = (
                _author_conv.convert if TYPE_CHECKING else _author_conv.convert_as_dict
            )
            return None if auths is None else [conv(a) for a in auths]

        if TYPE_CHECKING:
            arch_file_conv = _arch_file_conv.convert
            arch_lib_conv = _arch_lib_conv.convert
        else:
            arch_file_conv = _arch_file_conv.convert_as_dict
            arch_lib_conv = _arch_lib_conv.convert_as_dict

        # per-tensor axis-letter -> size mapping used to resolve implicit
        # size references during tensor conversion; for parameterized shapes
        # the minimal shape is used
        input_size_refs = {
            ipt.name: {
                a: s
                for a, s in zip(
                    ipt.axes,
                    (
                        ipt.shape.min
                        if isinstance(ipt.shape, _ParameterizedInputShape_v0_4)
                        else ipt.shape
                    ),
                )
            }
            for ipt in src.inputs
            if ipt.shape
        }
        # outputs may also reference input tensor sizes
        output_size_refs = {
            **{
                out.name: {a: s for a, s in zip(out.axes, out.shape)}
                for out in src.outputs
                if not isinstance(out.shape, _ImplicitOutputShape_v0_4)
            },
            **input_size_refs,
        }

        return tgt(
            attachments=(
                []
                if src.attachments is None
                else [FileDescr(source=f) for f in src.attachments.files]
            ),
            authors=[_author_conv.convert_as_dict(a) for a in src.authors],  # pyright: ignore[reportArgumentType]
            cite=[{"text": c.text, "doi": c.doi, "url": c.url} for c in src.cite],  # pyright: ignore[reportArgumentType]
            config=src.config,  # pyright: ignore[reportArgumentType]
            covers=src.covers,
            description=src.description,
            documentation=src.documentation,
            format_version="0.5.9",
            git_repo=src.git_repo,  # pyright: ignore[reportArgumentType]
            icon=src.icon,
            id=None if src.id is None else ModelId(src.id),
            id_emoji=src.id_emoji,
            license=src.license,  # type: ignore
            links=src.links,
            maintainers=[_maintainer_conv.convert_as_dict(m) for m in src.maintainers],  # pyright: ignore[reportArgumentType]
            name=name,
            tags=src.tags,
            type=src.type,
            uploader=src.uploader,
            version=src.version,
            inputs=[  # pyright: ignore[reportArgumentType]
                _input_tensor_conv.convert_as_dict(ipt, tt, st, input_size_refs)
                for ipt, tt, st in zip(
                    src.inputs,
                    src.test_inputs,
                    src.sample_inputs or [None] * len(src.test_inputs),
                )
            ],
            outputs=[  # pyright: ignore[reportArgumentType]
                _output_tensor_conv.convert_as_dict(out, tt, st, output_size_refs)
                for out, tt, st in zip(
                    src.outputs,
                    src.test_outputs,
                    src.sample_outputs or [None] * len(src.test_outputs),
                )
            ],
            # 0.4 parent/training_data carried a separate version number;
            # 0.5 encodes it as an "<id>/<version>" suffix
            parent=(
                None
                if src.parent is None
                else LinkedModel(
                    id=ModelId(
                        str(src.parent.id)
                        + (
                            ""
                            if src.parent.version_number is None
                            else f"/{src.parent.version_number}"
                        )
                    )
                )
            ),
            training_data=(
                None
                if src.training_data is None
                else (
                    LinkedDataset(
                        id=DatasetId(
                            str(src.training_data.id)
                            + (
                                ""
                                if src.training_data.version_number is None
                                else f"/{src.training_data.version_number}"
                            )
                        )
                    )
                    if isinstance(src.training_data, LinkedDataset02)
                    else src.training_data
                )
            ),
            packaged_by=[_author_conv.convert_as_dict(a) for a in src.packaged_by],  # pyright: ignore[reportArgumentType]
            run_mode=src.run_mode,
            timestamp=src.timestamp,
            # each weights entry: `(w := ...) and (...)` keeps absent (None)
            # entries as None and only builds the descriptor when present;
            # missing framework versions fall back to fixed defaults
            weights=(WeightsDescr if TYPE_CHECKING else dict)(
                keras_hdf5=(w := src.weights.keras_hdf5)
                and (KerasHdf5WeightsDescr if TYPE_CHECKING else dict)(
                    authors=conv_authors(w.authors),
                    source=w.source,
                    tensorflow_version=w.tensorflow_version or Version("1.15"),
                    parent=w.parent,
                ),
                onnx=(w := src.weights.onnx)
                and (OnnxWeightsDescr if TYPE_CHECKING else dict)(
                    source=w.source,
                    authors=conv_authors(w.authors),
                    parent=w.parent,
                    opset_version=w.opset_version or 15,
                ),
                pytorch_state_dict=(w := src.weights.pytorch_state_dict)
                and (PytorchStateDictWeightsDescr if TYPE_CHECKING else dict)(
                    source=w.source,
                    authors=conv_authors(w.authors),
                    parent=w.parent,
                    architecture=(
                        arch_file_conv(
                            w.architecture,
                            w.architecture_sha256,
                            w.kwargs,
                        )
                        if isinstance(w.architecture, _CallableFromFile_v0_4)
                        else arch_lib_conv(w.architecture, w.kwargs)
                    ),
                    pytorch_version=w.pytorch_version or Version("1.10"),
                    dependencies=(
                        None
                        if w.dependencies is None
                        else (FileDescr if TYPE_CHECKING else dict)(
                            # strip a leading "conda:" scheme from the
                            # dependencies source, if present
                            source=cast(
                                FileSource,
                                str(deps := w.dependencies)[
                                    (
                                        len("conda:")
                                        if str(deps).startswith("conda:")
                                        else 0
                                    ) :
                                ],
                            )
                        )
                    ),
                ),
                tensorflow_js=(w := src.weights.tensorflow_js)
                and (TensorflowJsWeightsDescr if TYPE_CHECKING else dict)(
                    source=w.source,
                    authors=conv_authors(w.authors),
                    parent=w.parent,
                    tensorflow_version=w.tensorflow_version or Version("1.15"),
                ),
                tensorflow_saved_model_bundle=(
                    w := src.weights.tensorflow_saved_model_bundle
                )
                and (TensorflowSavedModelBundleWeightsDescr if TYPE_CHECKING else dict)(
                    authors=conv_authors(w.authors),
                    parent=w.parent,
                    source=w.source,
                    tensorflow_version=w.tensorflow_version or Version("1.15"),
                    dependencies=(
                        None
                        if w.dependencies is None
                        else (FileDescr if TYPE_CHECKING else dict)(
                            # strip a leading "conda:" scheme, as above
                            source=cast(
                                FileSource,
                                (
                                    str(w.dependencies)[len("conda:") :]
                                    if str(w.dependencies).startswith("conda:")
                                    else str(w.dependencies)
                                ),
                            )
                        )
                    ),
                ),
                torchscript=(w := src.weights.torchscript)
                and (TorchscriptWeightsDescr if TYPE_CHECKING else dict)(
                    source=w.source,
                    authors=conv_authors(w.authors),
                    parent=w.parent,
                    pytorch_version=w.pytorch_version or Version("1.10"),
                ),
            ),
        )

4139 

4140 

# module-level converter instance used when migrating 0.4 model descriptions
_model_conv = _ModelConv(_ModelDescr_v0_4, ModelDescr)

4142 

4143 

# create better cover images for 3d data and non-image outputs
def generate_covers(
    inputs: Sequence[Tuple[InputTensorDescr, NDArray[Any]]],
    outputs: Sequence[Tuple[OutputTensorDescr, NDArray[Any]]],
) -> List[Path]:
    """Generate cover image(s) from the first test input/output tensor pair.

    Writes PNG file(s) to a fresh temporary directory: a single diagonally
    split image if input and output reduce to the same 2D shape, otherwise
    separate "input.png" and "output.png".

    Args:
        inputs: (tensor description, test tensor) pairs; only the first is used.
        outputs: (tensor description, test tensor) pairs; only the first is used.

    Returns:
        Paths of the written cover image(s).

    Raises:
        ValueError: If **inputs** or **outputs** is empty, or a tensor cannot
            be reduced to a 2D image.
    """

    def squeeze(
        data: NDArray[Any], axes: Sequence[AnyAxis]
    ) -> Tuple[NDArray[Any], List[AnyAxis]]:
        """apply numpy.ndarray.squeeze while keeping track of the axis descriptions remaining"""
        if data.ndim != len(axes):
            raise ValueError(
                f"tensor shape {data.shape} does not match described axes"
                + f" {[a.id for a in axes]}"
            )

        # drop the axis descriptions of all singleton dimensions in lockstep
        axes = [deepcopy(a) for a, s in zip(axes, data.shape) if s != 1]
        return data.squeeze(), axes

    def normalize(
        data: NDArray[Any], axis: Optional[Tuple[int, ...]], eps: float = 1e-7
    ) -> NDArray[np.float32]:
        # min-max normalize along **axis**; eps guards against division by zero
        data = data.astype("float32")
        data -= data.min(axis=axis, keepdims=True)
        data /= data.max(axis=axis, keepdims=True) + eps
        return data

    def to_2d_image(data: NDArray[Any], axes: Sequence[AnyAxis]):
        # reduce **data** to an (h, w, 3) uint8 RGB image by slicing
        # non-spatial axes and squeezing singleton dimensions
        original_shape = data.shape
        original_axes = list(axes)
        data, axes = squeeze(data, axes)

        # take slice fom any batch or index axis if needed
        # and convert the first channel axis and take a slice from any additional channel axes
        slices: Tuple[slice, ...] = ()
        ndim = data.ndim
        # target rank: 3 if a channel axis is present (h, w, c), else 2 (h, w)
        ndim_need = 3 if any(isinstance(a, ChannelAxis) for a in axes) else 2
        has_c_axis = False
        for i, a in enumerate(axes):
            s = data.shape[i]
            assert s > 1
            if (
                isinstance(a, (BatchAxis, IndexInputAxis, IndexOutputAxis))
                and ndim > ndim_need
            ):
                # keep a centered singleton slice (removed by a later squeeze)
                data = data[slices + (slice(s // 2 - 1, s // 2),)]
                ndim -= 1
            elif isinstance(a, ChannelAxis):
                if has_c_axis:
                    # second channel axis
                    data = data[slices + (slice(0, 1),)]
                    ndim -= 1
                else:
                    has_c_axis = True
                    if s == 2:
                        # visualize two channels with cyan and magenta
                        data = np.concatenate(
                            [
                                data[slices + (slice(1, 2),)],
                                data[slices + (slice(0, 1),)],
                                (
                                    data[slices + (slice(0, 1),)]
                                    + data[slices + (slice(1, 2),)]
                                )
                                / 2,  # TODO: take maximum instead?
                            ],
                            axis=i,
                        )
                    elif data.shape[i] == 3:
                        pass  # visualize 3 channels as RGB
                    else:
                        # visualize first 3 channels as RGB
                        data = data[slices + (slice(3),)]

                    assert data.shape[i] == 3

            slices += (slice(None),)

        data, axes = squeeze(data, axes)
        assert len(axes) == ndim
        # take slice from z axis if needed
        slices = ()
        if ndim > ndim_need:
            for i, a in enumerate(axes):
                s = data.shape[i]
                if a.id == AxisId("z"):
                    data = data[slices + (slice(s // 2 - 1, s // 2),)]
                    data, axes = squeeze(data, axes)
                    ndim -= 1
                    break

                slices += (slice(None),)

        # take slice from any space or time axis
        slices = ()

        for i, a in enumerate(axes):
            if ndim <= ndim_need:
                break

            s = data.shape[i]
            assert s > 1
            if isinstance(
                a, (SpaceInputAxis, SpaceOutputAxis, TimeInputAxis, TimeOutputAxis)
            ):
                data = data[slices + (slice(s // 2 - 1, s // 2),)]
                ndim -= 1

            slices += (slice(None),)

        del slices
        data, axes = squeeze(data, axes)
        assert len(axes) == ndim

        if (has_c_axis and ndim != 3) or (not has_c_axis and ndim != 2):
            raise ValueError(
                f"Failed to construct cover image from shape {original_shape} with axes {[a.id for a in original_axes]}."
            )

        if not has_c_axis:
            # grayscale image: replicate the single plane to RGB
            assert ndim == 2
            data = np.repeat(data[:, :, None], 3, axis=2)
            axes.append(ChannelAxis(channel_names=list(map(Identifier, "RGB"))))
            ndim += 1

        assert ndim == 3

        # transpose axis order such that longest axis comes first...
        axis_order: List[int] = list(np.argsort(list(data.shape)))
        axis_order.reverse()
        # ... and channel axis is last
        c = [i for i in range(3) if isinstance(axes[i], ChannelAxis)][0]
        axis_order.append(axis_order.pop(c))
        axes = [axes[ao] for ao in axis_order]
        data = data.transpose(axis_order)

        # h, w = data.shape[:2]
        # if h / w in (1.0 or 2.0):
        #     pass
        # elif h / w < 2:
        # TODO: enforce 2:1 or 1:1 aspect ratio for generated cover images

        # normalize jointly over space/time axes (channels normalized together)
        norm_along = (
            tuple(i for i, a in enumerate(axes) if a.type in ("space", "time")) or None
        )
        # normalize the data and map to 8 bit
        data = normalize(data, norm_along)
        data = (data * 255).astype("uint8")

        return data

    def create_diagonal_split_image(im0: NDArray[Any], im1: NDArray[Any]):
        # combine two equally shaped RGB images: im0 fills the lower-left
        # triangle, im1 the upper-right (per channel, via tril/triu masks)
        assert im0.dtype == im1.dtype == np.uint8
        assert im0.shape == im1.shape
        assert im0.ndim == 3
        N, M, C = im0.shape
        assert C == 3
        out = np.ones((N, M, C), dtype="uint8")
        for c in range(C):
            outc = np.tril(im0[..., c])
            mask = outc == 0
            outc[mask] = np.triu(im1[..., c])[mask]
            out[..., c] = outc

        return out

    if not inputs:
        raise ValueError("Missing test input tensor for cover generation.")

    if not outputs:
        raise ValueError("Missing test output tensor for cover generation.")

    ipt_descr, ipt = inputs[0]
    out_descr, out = outputs[0]

    ipt_img = to_2d_image(ipt, ipt_descr.axes)
    out_img = to_2d_image(out, out_descr.axes)

    cover_folder = Path(mkdtemp())
    if ipt_img.shape == out_img.shape:
        covers = [cover_folder / "cover.png"]
        imwrite(covers[0], create_diagonal_split_image(ipt_img, out_img))
    else:
        covers = [cover_folder / "input.png", cover_folder / "output.png"]
        imwrite(covers[0], ipt_img)
        imwrite(covers[1], out_img)

    return covers