Coverage for src / bioimageio / spec / model / v0_5.py: 76%

1581 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-03-27 14:45 +0000

1from __future__ import annotations 

2 

3import collections.abc 

4import re 

5import string 

6import warnings 

7from copy import deepcopy 

8from itertools import chain 

9from math import ceil 

10from pathlib import Path, PurePosixPath 

11from tempfile import mkdtemp 

12from textwrap import dedent 

13from typing import ( 

14 TYPE_CHECKING, 

15 Any, 

16 Callable, 

17 ClassVar, 

18 Dict, 

19 Generic, 

20 List, 

21 Literal, 

22 Mapping, 

23 NamedTuple, 

24 Optional, 

25 Sequence, 

26 Set, 

27 Tuple, 

28 Type, 

29 TypeVar, 

30 Union, 

31 cast, 

32 overload, 

33) 

34 

35import numpy as np 

36from annotated_types import Ge, Gt, Interval, MaxLen, MinLen, Predicate 

37from imageio.v3 import imread, imwrite # pyright: ignore[reportUnknownVariableType] 

38from loguru import logger 

39from numpy.typing import NDArray 

40from pydantic import ( 

41 AfterValidator, 

42 Discriminator, 

43 Field, 

44 RootModel, 

45 SerializationInfo, 

46 SerializerFunctionWrapHandler, 

47 StrictInt, 

48 Tag, 

49 ValidationInfo, 

50 WrapSerializer, 

51 field_validator, 

52 model_serializer, 

53 model_validator, 

54) 

55from typing_extensions import Annotated, Self, assert_never, get_args 

56 

57from .._internal.common_nodes import ( 

58 InvalidDescr, 

59 KwargsNode, 

60 Node, 

61 NodeWithExplicitlySetFields, 

62) 

63from .._internal.constants import DTYPE_LIMITS 

64from .._internal.field_warning import issue_warning, warn 

65from .._internal.io import BioimageioYamlContent as BioimageioYamlContent 

66from .._internal.io import FileDescr as FileDescr 

67from .._internal.io import ( 

68 FileSource, 

69 WithSuffix, 

70 YamlValue, 

71 extract_file_name, 

72 get_reader, 

73 wo_special_file_name, 

74) 

75from .._internal.io_basics import Sha256 as Sha256 

76from .._internal.io_packaging import ( 

77 FileDescr_, 

78 FileSource_, 

79 package_file_descr_serializer, 

80) 

81from .._internal.io_utils import load_array 

82from .._internal.node_converter import Converter 

83from .._internal.type_guards import is_dict, is_sequence 

84from .._internal.types import ( 

85 FAIR, 

86 AbsoluteTolerance, 

87 LowerCaseIdentifier, 

88 LowerCaseIdentifierAnno, 

89 MismatchedElementsPerMillion, 

90 RelativeTolerance, 

91) 

92from .._internal.types import Datetime as Datetime 

93from .._internal.types import Identifier as Identifier 

94from .._internal.types import NotEmpty as NotEmpty 

95from .._internal.types import SiUnit as SiUnit 

96from .._internal.url import HttpUrl as HttpUrl 

97from .._internal.validation_context import get_validation_context 

98from .._internal.validator_annotations import RestrictCharacters 

99from .._internal.version_type import Version as Version 

100from .._internal.warning_levels import INFO 

101from ..dataset.v0_2 import DatasetDescr as DatasetDescr02 

102from ..dataset.v0_2 import LinkedDataset as LinkedDataset02 

103from ..dataset.v0_3 import DatasetDescr as DatasetDescr 

104from ..dataset.v0_3 import DatasetId as DatasetId 

105from ..dataset.v0_3 import LinkedDataset as LinkedDataset 

106from ..dataset.v0_3 import Uploader as Uploader 

107from ..generic.v0_3 import ( 

108 VALID_COVER_IMAGE_EXTENSIONS as VALID_COVER_IMAGE_EXTENSIONS, 

109) 

110from ..generic.v0_3 import Author as Author 

111from ..generic.v0_3 import BadgeDescr as BadgeDescr 

112from ..generic.v0_3 import CiteEntry as CiteEntry 

113from ..generic.v0_3 import DeprecatedLicenseId as DeprecatedLicenseId 

114from ..generic.v0_3 import Doi as Doi 

115from ..generic.v0_3 import ( 

116 FileSource_documentation, 

117 GenericModelDescrBase, 

118 LinkedResourceBase, 

119 _author_conv, # pyright: ignore[reportPrivateUsage] 

120 _maintainer_conv, # pyright: ignore[reportPrivateUsage] 

121) 

122from ..generic.v0_3 import LicenseId as LicenseId 

123from ..generic.v0_3 import LinkedResource as LinkedResource 

124from ..generic.v0_3 import Maintainer as Maintainer 

125from ..generic.v0_3 import OrcidId as OrcidId 

126from ..generic.v0_3 import RelativeFilePath as RelativeFilePath 

127from ..generic.v0_3 import ResourceId as ResourceId 

128from .v0_4 import Author as _Author_v0_4 

129from .v0_4 import BinarizeDescr as _BinarizeDescr_v0_4 

130from .v0_4 import CallableFromDepencency as CallableFromDepencency 

131from .v0_4 import CallableFromDepencency as _CallableFromDepencency_v0_4 

132from .v0_4 import CallableFromFile as _CallableFromFile_v0_4 

133from .v0_4 import ClipDescr as _ClipDescr_v0_4 

134from .v0_4 import ImplicitOutputShape as _ImplicitOutputShape_v0_4 

135from .v0_4 import InputTensorDescr as _InputTensorDescr_v0_4 

136from .v0_4 import KnownRunMode as KnownRunMode 

137from .v0_4 import ModelDescr as _ModelDescr_v0_4 

138from .v0_4 import OutputTensorDescr as _OutputTensorDescr_v0_4 

139from .v0_4 import ParameterizedInputShape as _ParameterizedInputShape_v0_4 

140from .v0_4 import PostprocessingDescr as _PostprocessingDescr_v0_4 

141from .v0_4 import PreprocessingDescr as _PreprocessingDescr_v0_4 

142from .v0_4 import RunMode as RunMode 

143from .v0_4 import ScaleLinearDescr as _ScaleLinearDescr_v0_4 

144from .v0_4 import ScaleMeanVarianceDescr as _ScaleMeanVarianceDescr_v0_4 

145from .v0_4 import ScaleRangeDescr as _ScaleRangeDescr_v0_4 

146from .v0_4 import SigmoidDescr as _SigmoidDescr_v0_4 

147from .v0_4 import TensorName as _TensorName_v0_4 

148from .v0_4 import ZeroMeanUnitVarianceDescr as _ZeroMeanUnitVarianceDescr_v0_4 

149from .v0_4 import package_weights 

150 

SpaceUnit = Literal[
    "attometer",
    "angstrom",
    "centimeter",
    "decimeter",
    "exameter",
    "femtometer",
    "foot",
    "gigameter",
    "hectometer",
    "inch",
    "kilometer",
    "megameter",
    "meter",
    "micrometer",
    "mile",
    "millimeter",
    "nanometer",
    "parsec",
    "petameter",
    "picometer",
    "terameter",
    "yard",
    "yoctometer",
    "yottameter",
    "zeptometer",
    "zettameter",
]
"""Space unit compatible to the [OME-Zarr axes specification 0.5](https://ngff.openmicroscopy.org/0.5/#axes-md)"""

TimeUnit = Literal[
    "attosecond",
    "centisecond",
    "day",
    "decisecond",
    "exasecond",
    "femtosecond",
    "gigasecond",
    "hectosecond",
    "hour",
    "kilosecond",
    "megasecond",
    "microsecond",
    "millisecond",
    "minute",
    "nanosecond",
    "petasecond",
    "picosecond",
    "second",
    "terasecond",
    "yoctosecond",
    "yottasecond",
    "zeptosecond",
    "zettasecond",
]
"""Time unit compatible to the [OME-Zarr axes specification 0.5](https://ngff.openmicroscopy.org/0.5/#axes-md)"""

AxisType = Literal["batch", "channel", "index", "time", "space"]

# Maps single-letter axis ids to their axis *type*.
_AXIS_TYPE_MAP: Mapping[str, AxisType] = {
    "b": "batch",
    "t": "time",
    "i": "index",
    "c": "channel",
    "x": "space",
    "y": "space",
    "z": "space",
}

# Maps single-letter axis ids to their long-form axis *id*.
# Note: spatial letters (x/y/z) are intentionally absent — they remain valid ids.
_AXIS_ID_MAP = {
    "b": "batch",
    "t": "time",
    "i": "index",
    "c": "channel",
}

WeightsFormat = Literal[
    "keras_hdf5",
    "keras_v3",
    "onnx",
    "pytorch_state_dict",
    "tensorflow_js",
    "tensorflow_saved_model_bundle",
    "torchscript",
]

236 

237 

class TensorId(LowerCaseIdentifier):
    """A tensor id: a lower-case identifier of at most 32 characters."""

    root_model: ClassVar[Type[RootModel[Any]]] = RootModel[
        Annotated[LowerCaseIdentifierAnno, MaxLen(32)]
    ]

242 

243 

def _normalize_axis_id(a: str):
    """Map single-letter axis ids (e.g. 'b', 't', 'i', 'c') to their long form.

    Unknown ids pass through unchanged. When a normalization actually takes
    place, a warning is logged and attributed to the caller (``depth=3``).
    """
    original = str(a)
    result = _AXIS_ID_MAP.get(original, original)
    if result != original:
        logger.opt(depth=3).warning(
            "Normalized axis id from '{}' to '{}'.", original, result
        )
    return result

252 

253 

class AxisId(LowerCaseIdentifier):
    """An axis id: a lower-case identifier of at most 16 characters.

    Single-letter ids b/t/i/c are normalized to their long form
    (see `_normalize_axis_id`).
    """

    root_model: ClassVar[Type[RootModel[Any]]] = RootModel[
        Annotated[
            LowerCaseIdentifierAnno,
            MaxLen(16),
            AfterValidator(_normalize_axis_id),
        ]
    ]

262 

263 

264def _is_batch(a: str) -> bool: 

265 return str(a) == "batch" 

266 

267 

def _is_not_batch(a: str) -> bool:
    # Predicate excluding the reserved "batch" axis id (used by NonBatchAxisId).
    return not _is_batch(a)

270 

271 

# An axis id that may be anything but the reserved "batch" id.
NonBatchAxisId = Annotated[AxisId, Predicate(_is_not_batch)]

PreprocessingId = Literal[
    "binarize",
    "clip",
    "ensure_dtype",
    "fixed_zero_mean_unit_variance",
    "scale_linear",
    "scale_range",
    "sigmoid",
    "softmax",
]
PostprocessingId = Literal[
    "binarize",
    "clip",
    "ensure_dtype",
    "fixed_zero_mean_unit_variance",
    "scale_linear",
    "scale_mean_variance",
    "scale_range",
    "sigmoid",
    "softmax",
    "zero_mean_unit_variance",
]


SAME_AS_TYPE = "<same as type>"


ParameterizedSize_N = int
"""
Annotates an integer to calculate a concrete axis size from a `ParameterizedSize`.
"""

305 

306 

class ParameterizedSize(Node):
    """Describes a range of valid tensor axis sizes as `size = min + n*step`.

    - **min** and **step** are given by the model description.
    - All blocksize parameters n = 0,1,2,... yield a valid `size`.
    - A greater blocksize parameter n = 0,1,2,... results in a greater **size**.
      This allows to adjust the axis size more generically.
    """

    N: ClassVar[Type[int]] = ParameterizedSize_N
    """Positive integer to parameterize this axis"""

    min: Annotated[int, Gt(0)]
    step: Annotated[int, Gt(0)]

    def validate_size(self, size: int, msg_prefix: str = "") -> int:
        """Return **size** unchanged if it is valid for this axis.

        Raises:
            ValueError: if `size < min` or `size` is not of the form
                `min + n*step` for some n = 0,1,2,...
        """
        if size < self.min:
            raise ValueError(
                f"{msg_prefix}size {size} < {self.min} (minimum axis size)"
            )
        if (size - self.min) % self.step != 0:
            raise ValueError(
                f"{msg_prefix}size {size} is not parameterized by `min + n*step` ="
                + f" `{self.min} + n*{self.step}`"
            )

        return size

    def get_size(self, n: ParameterizedSize_N) -> int:
        """Return the concrete size `min + n*step` for blocksize parameter **n**."""
        return self.min + self.step * n

    def get_n(self, s: int) -> ParameterizedSize_N:
        """return smallest n parameterizing a size greater or equal than `s`"""
        # Clamp at 0: valid blocksize parameters are n = 0,1,2,... (see class
        # docstring). Without the clamp, s <= min would yield a negative n and
        # thereby a size below `min`.
        return max(0, ceil((s - self.min) / self.step))

341 

342 

class DataDependentSize(Node):
    """An axis size that is only known after model inference, bounded by `min`/`max`."""

    min: Annotated[int, Gt(0)] = 1
    max: Annotated[Optional[int], Gt(1)] = None

    @model_validator(mode="after")
    def _validate_max_gt_min(self):
        # `min` must be strictly smaller than `max` (equal bounds would describe
        # a fixed size, not a data dependent one).
        if self.max is not None and self.min >= self.max:
            raise ValueError(f"expected `min` < `max`, but got {self.min}, {self.max}")

        return self

    def validate_size(self, size: int, msg_prefix: str = "") -> int:
        """Return **size** unchanged if `min <= size` (and `size <= max` if `max` is set); raise otherwise."""
        if size < self.min:
            raise ValueError(f"{msg_prefix}size {size} < {self.min}")

        if self.max is not None and size > self.max:
            raise ValueError(f"{msg_prefix}size {size} > {self.max}")

        return size

362 

363 

class SizeReference(Node):
    """A tensor axis size (extent in pixels/frames) defined in relation to a reference axis.

    `axis.size = reference.size * reference.scale / axis.scale + offset`

    Note:
    1. The axis and the referenced axis need to have the same unit (or no unit).
    2. Batch axes may not be referenced.
    3. Fractions are rounded down.
    4. If the reference axis is `concatenable` the referencing axis is assumed to be
       `concatenable` as well with the same block order.

    Example:
    An anisotropic input image of w*h=100*49 pixels depicts a physical space of 200*196mm².
    Let's assume that we want to express the image height h in relation to its width w
    instead of only accepting input images of exactly 100*49 pixels
    (for example to express a range of valid image shapes by parametrizing w, see `ParameterizedSize`).

    >>> w = SpaceInputAxis(id=AxisId("w"), size=100, unit="millimeter", scale=2)
    >>> h = SpaceInputAxis(
    ...     id=AxisId("h"),
    ...     size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("w"), offset=-1),
    ...     unit="millimeter",
    ...     scale=4,
    ... )
    >>> print(h.size.get_size(h, w))
    49

    ⇒ h = w * w.scale / h.scale + offset = 100 * 2mm / 4mm - 1 = 49
    """

    tensor_id: TensorId
    """tensor id of the reference axis"""

    axis_id: AxisId
    """axis id of the reference axis"""

    offset: StrictInt = 0

    def get_size(
        self,
        axis: Union[
            ChannelAxis,
            IndexInputAxis,
            IndexOutputAxis,
            TimeInputAxis,
            SpaceInputAxis,
            TimeOutputAxis,
            TimeOutputAxisWithHalo,
            SpaceOutputAxis,
            SpaceOutputAxisWithHalo,
        ],
        ref_axis: Union[
            ChannelAxis,
            IndexInputAxis,
            IndexOutputAxis,
            TimeInputAxis,
            SpaceInputAxis,
            TimeOutputAxis,
            TimeOutputAxisWithHalo,
            SpaceOutputAxis,
            SpaceOutputAxisWithHalo,
        ],
        n: ParameterizedSize_N = 0,
        ref_size: Optional[int] = None,
    ):
        """Compute the concrete size for a given axis and its reference axis.

        Args:
            axis: The axis this [SizeReference][] is the size of.
            ref_axis: The reference axis to compute the size from.
            n: If the **ref_axis** is parameterized (of type `ParameterizedSize`)
                and no fixed **ref_size** is given,
                **n** is used to compute the size of the parameterized **ref_axis**.
            ref_size: Overwrite the reference size instead of deriving it from
                **ref_axis**
                (**ref_axis.scale** is still used; any given **n** is ignored).
        """
        # sanity checks: this object must actually describe `axis.size` and be
        # paired with the axis it references, with matching units
        assert axis.size == self, (
            "Given `axis.size` is not defined by this `SizeReference`"
        )

        assert ref_axis.id == self.axis_id, (
            f"Expected `ref_axis.id` to be {self.axis_id}, but got {ref_axis.id}."
        )

        assert axis.unit == ref_axis.unit, (
            "`SizeReference` requires `axis` and `ref_axis` to have the same `unit`,"
            f" but {axis.unit}!={ref_axis.unit}"
        )
        if ref_size is None:
            # derive the reference size from the reference axis' own size description
            if isinstance(ref_axis.size, (int, float)):
                ref_size = ref_axis.size
            elif isinstance(ref_axis.size, ParameterizedSize):
                ref_size = ref_axis.size.get_size(n)
            elif isinstance(ref_axis.size, DataDependentSize):
                raise ValueError(
                    "Reference axis referenced in `SizeReference` may not be a `DataDependentSize`."
                )
            elif isinstance(ref_axis.size, SizeReference):
                raise ValueError(
                    "Reference axis referenced in `SizeReference` may not be sized by a"
                    + " `SizeReference` itself."
                )
            else:
                assert_never(ref_axis.size)

        # fractions are rounded down (see class docstring note 3)
        return int(ref_size * ref_axis.scale / axis.scale + self.offset)

    @staticmethod
    def _get_unit(
        axis: Union[
            ChannelAxis,
            IndexInputAxis,
            IndexOutputAxis,
            TimeInputAxis,
            SpaceInputAxis,
            TimeOutputAxis,
            TimeOutputAxisWithHalo,
            SpaceOutputAxis,
            SpaceOutputAxisWithHalo,
        ],
    ):
        # helper to read an axis' unit uniformly across all axis types
        return axis.unit

488 

489 

class AxisBase(NodeWithExplicitlySetFields):
    """Common fields shared by all axis descriptions."""

    id: AxisId
    """An axis id unique across all axes of one tensor."""

    description: Annotated[str, MaxLen(128)] = ""
    """A short description of this axis beyond its type and id."""

496 

497 

class WithHalo(Node):
    """Mixin for output axes whose borders should be cropped by a halo."""

    halo: Annotated[int, Ge(1)]
    """The halo should be cropped from the output tensor to avoid boundary effects.
    It is to be cropped from both sides, i.e. `size_after_crop = size - 2 * halo`.
    To document a halo that is already cropped by the model use `size.offset` instead."""

    size: Annotated[
        SizeReference,
        Field(
            examples=[
                10,
                SizeReference(
                    tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
                ).model_dump(mode="json"),
            ]
        ),
    ]
    """reference to another axis with an optional offset (see [SizeReference][])"""

516 

517 

# The reserved axis id identifying batch axes.
BATCH_AXIS_ID = AxisId("batch")

519 

520 

class BatchAxis(AxisBase):
    """The batch axis of a tensor."""

    implemented_type: ClassVar[Literal["batch"]] = "batch"
    # At type-check time `type` carries a default for convenient construction;
    # at runtime it is declared without one.
    # NOTE(review): presumably NodeWithExplicitlySetFields supplies the value
    # from `implemented_type` — confirm against common_nodes.
    if TYPE_CHECKING:
        type: Literal["batch"] = "batch"
    else:
        type: Literal["batch"]

    id: Annotated[AxisId, Predicate(_is_batch)] = BATCH_AXIS_ID
    size: Optional[Literal[1]] = None
    """The batch size may be fixed to 1,
    otherwise (the default) it may be chosen arbitrarily depending on available memory"""

    @property
    def scale(self):
        # batch axes have no physical scale
        return 1.0

    @property
    def concatenable(self):
        # samples can always be concatenated along the batch axis
        return True

    @property
    def unit(self):
        return None

544 

545 

class ChannelAxis(AxisBase):
    """A channel axis; its size is implied by the number of `channel_names`."""

    implemented_type: ClassVar[Literal["channel"]] = "channel"
    if TYPE_CHECKING:
        type: Literal["channel"] = "channel"
    else:
        type: Literal["channel"]

    id: NonBatchAxisId = AxisId("channel")

    channel_names: NotEmpty[List[Identifier]]

    @property
    def size(self) -> int:
        # the axis size equals the number of named channels
        return len(self.channel_names)

    @property
    def concatenable(self):
        return False

    @property
    def scale(self) -> float:
        return 1.0

    @property
    def unit(self):
        return None

572 

573 

class IndexAxisBase(AxisBase):
    """Base for index axes (unitless, unscaled enumeration axes)."""

    implemented_type: ClassVar[Literal["index"]] = "index"
    if TYPE_CHECKING:
        type: Literal["index"] = "index"
    else:
        type: Literal["index"]

    id: NonBatchAxisId = AxisId("index")

    @property
    def scale(self) -> float:
        return 1.0

    @property
    def unit(self):
        return None

590 

591 

class _WithInputAxisSize(Node):
    """Mixin adding the `size` field shared by all sized input axes."""

    size: Annotated[
        Union[Annotated[int, Gt(0)], ParameterizedSize, SizeReference],
        Field(
            examples=[
                10,
                ParameterizedSize(min=32, step=16).model_dump(mode="json"),
                SizeReference(
                    tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
                ).model_dump(mode="json"),
            ]
        ),
    ]
    """The size/length of this axis can be specified as
    - fixed integer
    - parameterized series of valid sizes ([ParameterizedSize][])
    - reference to another axis with an optional offset ([SizeReference][])
    """

610 

611 

class IndexInputAxis(IndexAxisBase, _WithInputAxisSize):
    """An index axis of an input tensor."""

    concatenable: bool = False
    """If a model has a `concatenable` input axis, it can be processed blockwise,
    splitting a longer sample axis into blocks matching its input tensor description.
    Output axes are concatenable if they have a [SizeReference][] to a concatenable
    input axis.
    """

619 

620 

class IndexOutputAxis(IndexAxisBase):
    """An index axis of an output tensor."""

    size: Annotated[
        Union[Annotated[int, Gt(0)], SizeReference, DataDependentSize],
        Field(
            examples=[
                10,
                SizeReference(
                    tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
                ).model_dump(mode="json"),
            ]
        ),
    ]
    """The size/length of this axis can be specified as
    - fixed integer
    - reference to another axis with an optional offset ([SizeReference][])
    - data dependent size using [DataDependentSize][] (size is only known after model inference)
    """

638 

639 

class TimeAxisBase(AxisBase):
    """Base for time axes, carrying an optional time unit and a positive scale."""

    implemented_type: ClassVar[Literal["time"]] = "time"
    if TYPE_CHECKING:
        type: Literal["time"] = "time"
    else:
        type: Literal["time"]

    id: NonBatchAxisId = AxisId("time")
    unit: Optional[TimeUnit] = None
    scale: Annotated[float, Gt(0)] = 1.0

650 

651 

class TimeInputAxis(TimeAxisBase, _WithInputAxisSize):
    """A time axis of an input tensor."""

    concatenable: bool = False
    """If a model has a `concatenable` input axis, it can be processed blockwise,
    splitting a longer sample axis into blocks matching its input tensor description.
    Output axes are concatenable if they have a [SizeReference][] to a concatenable
    input axis.
    """

659 

660 

class SpaceAxisBase(AxisBase):
    """Base for spatial axes, carrying an optional space unit and a positive scale."""

    implemented_type: ClassVar[Literal["space"]] = "space"
    if TYPE_CHECKING:
        type: Literal["space"] = "space"
    else:
        type: Literal["space"]

    id: Annotated[NonBatchAxisId, Field(examples=["x", "y", "z"])] = AxisId("x")
    unit: Optional[SpaceUnit] = None
    scale: Annotated[float, Gt(0)] = 1.0

671 

672 

class SpaceInputAxis(SpaceAxisBase, _WithInputAxisSize):
    """A spatial axis of an input tensor."""

    concatenable: bool = False
    """If a model has a `concatenable` input axis, it can be processed blockwise,
    splitting a longer sample axis into blocks matching its input tensor description.
    Output axes are concatenable if they have a [SizeReference][] to a concatenable
    input axis.
    """

680 

681 

INPUT_AXIS_TYPES = (
    BatchAxis,
    ChannelAxis,
    IndexInputAxis,
    TimeInputAxis,
    SpaceInputAxis,
)
"""intended for isinstance comparisons in py<3.10"""

_InputAxisUnion = Union[
    BatchAxis, ChannelAxis, IndexInputAxis, TimeInputAxis, SpaceInputAxis
]
# discriminated by the `type` field of each axis class
InputAxis = Annotated[_InputAxisUnion, Discriminator("type")]

695 

696 

class _WithOutputAxisSize(Node):
    """Mixin adding the `size` field shared by sized output axes (without halo)."""

    size: Annotated[
        Union[Annotated[int, Gt(0)], SizeReference],
        Field(
            examples=[
                10,
                SizeReference(
                    tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
                ).model_dump(mode="json"),
            ]
        ),
    ]
    """The size/length of this axis can be specified as
    - fixed integer
    - reference to another axis with an optional offset (see [SizeReference][])
    """

713 

714 

class TimeOutputAxis(TimeAxisBase, _WithOutputAxisSize):
    """A time axis of an output tensor (without halo)."""

    pass

717 

718 

class TimeOutputAxisWithHalo(TimeAxisBase, WithHalo):
    """A time axis of an output tensor with a halo to crop (see [WithHalo][])."""

    pass

721 

722 

723def _get_halo_axis_discriminator_value(v: Any) -> Literal["with_halo", "wo_halo"]: 

724 if isinstance(v, dict): 

725 return "with_halo" if "halo" in v else "wo_halo" 

726 else: 

727 return "with_halo" if hasattr(v, "halo") else "wo_halo" 

728 

729 

# Time output axes discriminated by whether a `halo` is present.
_TimeOutputAxisUnion = Annotated[
    Union[
        Annotated[TimeOutputAxis, Tag("wo_halo")],
        Annotated[TimeOutputAxisWithHalo, Tag("with_halo")],
    ],
    Discriminator(_get_halo_axis_discriminator_value),
]

737 

738 

class SpaceOutputAxis(SpaceAxisBase, _WithOutputAxisSize):
    """A spatial axis of an output tensor (without halo)."""

    pass

741 

742 

class SpaceOutputAxisWithHalo(SpaceAxisBase, WithHalo):
    """A spatial axis of an output tensor with a halo to crop (see [WithHalo][])."""

    pass

745 

746 

# Spatial output axes discriminated by whether a `halo` is present.
_SpaceOutputAxisUnion = Annotated[
    Union[
        Annotated[SpaceOutputAxis, Tag("wo_halo")],
        Annotated[SpaceOutputAxisWithHalo, Tag("with_halo")],
    ],
    Discriminator(_get_halo_axis_discriminator_value),
]

754 

755 

_OutputAxisUnion = Union[
    BatchAxis, ChannelAxis, IndexOutputAxis, _TimeOutputAxisUnion, _SpaceOutputAxisUnion
]
# discriminated by the `type` field of each axis class
OutputAxis = Annotated[_OutputAxisUnion, Discriminator("type")]

OUTPUT_AXIS_TYPES = (
    BatchAxis,
    ChannelAxis,
    IndexOutputAxis,
    TimeOutputAxis,
    TimeOutputAxisWithHalo,
    SpaceOutputAxis,
    SpaceOutputAxisWithHalo,
)
"""intended for isinstance comparisons in py<3.10"""

771 

772 

AnyAxis = Union[InputAxis, OutputAxis]

ANY_AXIS_TYPES = INPUT_AXIS_TYPES + OUTPUT_AXIS_TYPES
"""intended for isinstance comparisons in py<3.10"""

# Allowed (non-empty, homogeneous) value lists for NominalOrOrdinalDataDescr.values
TVs = Union[
    NotEmpty[List[int]],
    NotEmpty[List[float]],
    NotEmpty[List[bool]],
    NotEmpty[List[str]],
]


NominalOrOrdinalDType = Literal[
    "float32",
    "float64",
    "uint8",
    "int8",
    "uint16",
    "int16",
    "uint32",
    "int32",
    "uint64",
    "int64",
    "bool",
]

799 

800 

class NominalOrOrdinalDataDescr(Node):
    """Describes tensor data restricted to a fixed set of nominal/ordinal values."""

    values: TVs
    """A fixed set of nominal or an ascending sequence of ordinal values.
    In this case `data.type` is required to be an unsigned integer type, e.g. 'uint8'.
    String `values` are interpreted as labels for tensor values 0, ..., N.
    Note: as YAML 1.2 does not natively support a "set" datatype,
    nominal values should be given as a sequence (aka list/array) as well.
    """

    type: Annotated[
        NominalOrOrdinalDType,
        Field(
            examples=[
                "float32",
                "uint8",
                "uint16",
                "int64",
                "bool",
            ],
        ),
    ] = "uint8"

    @model_validator(mode="after")
    def _validate_values_match_type(
        self,
    ) -> Self:
        # Collect values that cannot be represented by `type`:
        # - non-bool values for a bool type
        # - numeric values outside the dtype's min/max limits
        # - string labels for any non-unsigned-integer type
        # - float values for any integer type
        incompatible: List[Any] = []
        for v in self.values:
            if self.type == "bool":
                if not isinstance(v, bool):
                    incompatible.append(v)
            elif self.type in DTYPE_LIMITS:
                if (
                    isinstance(v, (int, float))
                    and (
                        v < DTYPE_LIMITS[self.type].min
                        or v > DTYPE_LIMITS[self.type].max
                    )
                    or (isinstance(v, str) and "uint" not in self.type)
                    or (isinstance(v, float) and "int" in self.type)
                ):
                    incompatible.append(v)
            else:
                incompatible.append(v)

            # cap the error report at 5 offending examples
            if len(incompatible) == 5:
                incompatible.append("...")
                break

        if incompatible:
            raise ValueError(
                f"data type '{self.type}' incompatible with values {incompatible}"
            )

        return self

    unit: Optional[Union[Literal["arbitrary unit"], SiUnit]] = None

    @property
    def range(self):
        # (min, max) of the allowed values; string labels imply the integer
        # range 0 .. N-1 (see `values` docstring).
        if isinstance(self.values[0], str):
            return 0, len(self.values) - 1
        else:
            return min(self.values), max(self.values)

865 

866 

IntervalOrRatioDType = Literal[
    "float32",
    "float64",
    "uint8",
    "int8",
    "uint16",
    "int16",
    "uint32",
    "int32",
    "uint64",
    "int64",
]

879 

880 

class IntervalOrRatioDataDescr(Node):
    """Describes tensor data on an interval or ratio scale."""

    type: Annotated[  # TODO: rename to dtype
        IntervalOrRatioDType,
        Field(
            examples=["float32", "float64", "uint8", "uint16"],
        ),
    ] = "float32"
    range: Tuple[Optional[float], Optional[float]] = (
        None,
        None,
    )
    """Tuple `(minimum, maximum)` specifying the allowed range of the data in this tensor.
    `None` corresponds to min/max of what can be expressed by **type**."""
    unit: Union[Literal["arbitrary unit"], SiUnit] = "arbitrary unit"
    scale: float = 1.0
    """Scale for data on an interval (or ratio) scale."""
    offset: Optional[float] = None
    """Offset for data on a ratio scale."""

    @model_validator(mode="before")
    def _replace_inf(cls, data: Any):
        # Replace infinite bounds (in their YAML string spellings or as floats)
        # by `None`, warning the user about the substitution.
        if is_dict(data):
            if "range" in data and is_sequence(data["range"]):
                forbidden = (
                    "inf",
                    "-inf",
                    ".inf",
                    "-.inf",
                    float("inf"),
                    float("-inf"),
                )
                if any(v in forbidden for v in data["range"]):
                    issue_warning("replaced 'inf' value", value=data["range"])

                data["range"] = tuple(
                    (None if v in forbidden else v) for v in data["range"]
                )

        return data

920 

921 

# Any tensor data description: nominal/ordinal or interval/ratio scale.
TensorDataDescr = Union[NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr]

923 

924 

class BinarizeKwargs(KwargsNode):
    """key word arguments for [BinarizeDescr][]"""

    threshold: float
    """The fixed threshold"""

930 

931 

class BinarizeAlongAxisKwargs(KwargsNode):
    """key word arguments for [BinarizeDescr][]"""

    threshold: NotEmpty[List[float]]
    """The fixed threshold values along `axis`"""

    axis: Annotated[NonBatchAxisId, Field(examples=["channel"])]
    """The `threshold` axis"""

940 

941 

class BinarizeDescr(NodeWithExplicitlySetFields):
    """Binarize the tensor with a fixed threshold.

    Values above [BinarizeKwargs.threshold][]/[BinarizeAlongAxisKwargs.threshold][]
    will be set to one, values below the threshold to zero.

    Examples:
    - in YAML
        ```yaml
        postprocessing:
          - id: binarize
            kwargs:
              axis: 'channel'
              threshold: [0.25, 0.5, 0.75]
        ```
    - in Python:
        >>> postprocessing = [BinarizeDescr(
        ...     kwargs=BinarizeAlongAxisKwargs(
        ...         axis=AxisId('channel'),
        ...         threshold=[0.25, 0.5, 0.75],
        ...     )
        ... )]
    """

    implemented_id: ClassVar[Literal["binarize"]] = "binarize"
    if TYPE_CHECKING:
        id: Literal["binarize"] = "binarize"
    else:
        id: Literal["binarize"]
    # a single threshold or per-axis-entry thresholds
    kwargs: Union[BinarizeKwargs, BinarizeAlongAxisKwargs]

972 

973 

class ClipKwargs(KwargsNode):
    """key word arguments for [ClipDescr][]"""

    min: Optional[float] = None
    """Minimum value for clipping.

    Exclusive with [min_percentile][]
    """
    min_percentile: Optional[Annotated[float, Interval(ge=0, lt=100)]] = None
    """Minimum percentile for clipping.

    Exclusive with [min][].

    In range [0, 100).
    """

    max: Optional[float] = None
    """Maximum value for clipping.

    Exclusive with `max_percentile`.
    """
    max_percentile: Optional[Annotated[float, Interval(gt=1, le=100)]] = None
    """Maximum percentile for clipping.

    Exclusive with `max`.

    In range (1, 100].
    """

    axes: Annotated[
        Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
    ] = None
    """The subset of axes to determine percentiles jointly,

    i.e. axes to reduce to compute min/max from `min_percentile`/`max_percentile`.
    For example to clip 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
    resulting in a tensor of equal shape with clipped values per channel, specify `axes=('batch', 'x', 'y')`.
    To clip samples independently, leave out the 'batch' axis.

    Only valid if `min_percentile` and/or `max_percentile` are set.

    Default: Compute percentiles over all axes jointly."""

    @model_validator(mode="after")
    def _validate(self) -> Self:
        # enforce mutual exclusivity of absolute vs percentile bounds,
        # require at least one bound, and only allow `axes` with percentiles
        if (self.min is not None) and (self.min_percentile is not None):
            raise ValueError(
                "Only one of `min` and `min_percentile` may be set, not both."
            )
        if (self.max is not None) and (self.max_percentile is not None):
            raise ValueError(
                "Only one of `max` and `max_percentile` may be set, not both."
            )
        if (
            self.min is None
            and self.min_percentile is None
            and self.max is None
            and self.max_percentile is None
        ):
            raise ValueError(
                "At least one of `min`, `min_percentile`, `max`, or `max_percentile` must be set."
            )

        if (
            self.axes is not None
            and self.min_percentile is None
            and self.max_percentile is None
        ):
            raise ValueError(
                "If `axes` is set, at least one of `min_percentile` or `max_percentile` must be set."
            )

        return self

1047 

1048 

class ClipDescr(NodeWithExplicitlySetFields):
    """Set tensor values below min to min and above max to max.

    See `ScaleRangeDescr` for examples.
    """

    implemented_id: ClassVar[Literal["clip"]] = "clip"
    if TYPE_CHECKING:
        id: Literal["clip"] = "clip"
    else:
        id: Literal["clip"]

    kwargs: ClipKwargs

1062 

1063 

class EnsureDtypeKwargs(KwargsNode):
    """key word arguments for [EnsureDtypeDescr][]"""

    # target data type to cast to
    dtype: Literal[
        "float32",
        "float64",
        "uint8",
        "int8",
        "uint16",
        "int16",
        "uint32",
        "int32",
        "uint64",
        "int64",
        "bool",
    ]

1080 

1081 

class EnsureDtypeDescr(NodeWithExplicitlySetFields):
    """Cast the tensor data type to `EnsureDtypeKwargs.dtype` (if not matching).

    This can for example be used to ensure the inner neural network model gets a
    different input tensor data type than the fully described bioimage.io model does.

    Examples:
        The described bioimage.io model (incl. preprocessing) accepts any
        float32-compatible tensor, normalizes it with percentiles and clipping and then
        casts it to uint8, which is what the neural network in this example expects.
        - in YAML
            ```yaml
            inputs:
            - data:
                type: float32  # described bioimage.io model is compatible with any float32 input tensor
              preprocessing:
              - id: scale_range
                  kwargs:
                  axes: ['y', 'x']
                  max_percentile: 99.8
                  min_percentile: 5.0
              - id: clip
                  kwargs:
                  min: 0.0
                  max: 1.0
              - id: ensure_dtype  # the neural network of the model requires uint8
                  kwargs:
                  dtype: uint8
            ```
        - in Python:
            >>> preprocessing = [
            ...     ScaleRangeDescr(
            ...         kwargs=ScaleRangeKwargs(
            ...           axes= (AxisId('y'), AxisId('x')),
            ...           max_percentile= 99.8,
            ...           min_percentile= 5.0,
            ...         )
            ...     ),
            ...     ClipDescr(kwargs=ClipKwargs(min=0.0, max=1.0)),
            ...     EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="uint8")),
            ... ]
    """

    implemented_id: ClassVar[Literal["ensure_dtype"]] = "ensure_dtype"
    if TYPE_CHECKING:
        # default value for static type checkers only
        id: Literal["ensure_dtype"] = "ensure_dtype"
    else:
        # at runtime `id` has no default (see `implemented_id` and the base class)
        id: Literal["ensure_dtype"]

    kwargs: EnsureDtypeKwargs

1132 

1133 

class ScaleLinearKwargs(KwargsNode):
    """Key word arguments for [ScaleLinearDescr][]"""

    gain: float = 1.0
    """multiplicative factor"""

    offset: float = 0.0
    """additive term"""

    @model_validator(mode="after")
    def _validate(self) -> Self:
        """Reject an identity transform: `out = 1.0 * tensor + 0.0` is a no-op."""
        if self.gain == 1.0 and self.offset == 0.0:
            # fix: error message previously read "allowd"
            raise ValueError(
                "Redundant linear scaling not allowed. Set `gain` != 1.0 and/or `offset`"
                + " != 0.0."
            )

        return self

1152 

1153 

class ScaleLinearAlongAxisKwargs(KwargsNode):
    """Key word arguments for [ScaleLinearDescr][]"""

    axis: Annotated[NonBatchAxisId, Field(examples=["channel"])]
    """The axis of gain and offset values."""

    gain: Union[float, NotEmpty[List[float]]] = 1.0
    """multiplicative factor"""

    offset: Union[float, NotEmpty[List[float]]] = 0.0
    """additive term"""

    @model_validator(mode="after")
    def _validate(self) -> Self:
        """Broadcast a scalar `gain`/`offset` to the length of its list-valued
        counterpart and reject redundant (identity) scaling."""
        if isinstance(self.gain, list):
            if isinstance(self.offset, list):
                if len(self.gain) != len(self.offset):
                    raise ValueError(
                        f"Size of `gain` ({len(self.gain)}) and `offset` ({len(self.offset)}) must match."
                    )
            else:
                # broadcast scalar offset to the length of `gain`
                self.offset = [float(self.offset)] * len(self.gain)
        elif isinstance(self.offset, list):
            # broadcast scalar gain to the length of `offset`
            self.gain = [float(self.gain)] * len(self.offset)
        else:
            # both scalar: use `ScaleLinearKwargs` (without `axis`) instead
            raise ValueError(
                "Do not specify an `axis` for scalar gain and offset values."
            )

        if all(g == 1.0 for g in self.gain) and all(off == 0.0 for off in self.offset):
            # fix: error message previously read "allowd"
            raise ValueError(
                "Redundant linear scaling not allowed. Set `gain` != 1.0 and/or `offset`"
                + " != 0.0."
            )

        return self

1190 

1191 

class ScaleLinearDescr(NodeWithExplicitlySetFields):
    """Fixed linear scaling.

    Examples:
      1. Scale with scalar gain and offset
        - in YAML
        ```yaml
        preprocessing:
          - id: scale_linear
            kwargs:
              gain: 2.0
              offset: 3.0
        ```
        - in Python:
        >>> preprocessing = [
        ...     ScaleLinearDescr(kwargs=ScaleLinearKwargs(gain= 2.0, offset=3.0))
        ... ]

      2. Independent scaling along an axis
        - in YAML
        ```yaml
        preprocessing:
          - id: scale_linear
            kwargs:
              axis: 'channel'
              gain: [1.0, 2.0, 3.0]
        ```
        - in Python:
        >>> preprocessing = [
        ...     ScaleLinearDescr(
        ...         kwargs=ScaleLinearAlongAxisKwargs(
        ...             axis=AxisId("channel"),
        ...             gain=[1.0, 2.0, 3.0],
        ...         )
        ...     )
        ... ]

    """

    implemented_id: ClassVar[Literal["scale_linear"]] = "scale_linear"
    if TYPE_CHECKING:
        # default value for static type checkers only
        id: Literal["scale_linear"] = "scale_linear"
    else:
        # at runtime `id` has no default (see `implemented_id` and the base class)
        id: Literal["scale_linear"]
    # scalar kwargs or per-axis kwargs (with `axis` set)
    kwargs: Union[ScaleLinearKwargs, ScaleLinearAlongAxisKwargs]

1237 

1238 

class SigmoidDescr(NodeWithExplicitlySetFields):
    """The logistic sigmoid function, a.k.a. expit function.

    Examples:
        - in YAML
        ```yaml
        postprocessing:
          - id: sigmoid
        ```
        - in Python:
        >>> postprocessing = [SigmoidDescr()]
    """

    implemented_id: ClassVar[Literal["sigmoid"]] = "sigmoid"
    if TYPE_CHECKING:
        # default value for static type checkers only
        id: Literal["sigmoid"] = "sigmoid"
    else:
        # at runtime `id` has no default (see `implemented_id` and the base class)
        id: Literal["sigmoid"]

    @property
    def kwargs(self) -> KwargsNode:
        """empty kwargs"""
        # sigmoid takes no parameters; an empty kwargs node keeps the
        # interface uniform with the other processing descriptions
        return KwargsNode()

1262 

1263 

class SoftmaxKwargs(KwargsNode):
    """key word arguments for [SoftmaxDescr][]"""

    axis: Annotated[NonBatchAxisId, Field(examples=["channel"])] = AxisId("channel")
    """The axis to apply the softmax function along.
    Note:
        Defaults to 'channel' axis
        (which may not exist, in which case
        a different axis id has to be specified).
    """

1274 

1275 

class SoftmaxDescr(NodeWithExplicitlySetFields):
    """The softmax function.

    Examples:
        - in YAML
        ```yaml
        postprocessing:
          - id: softmax
            kwargs:
              axis: channel
        ```
        - in Python:
        >>> postprocessing = [SoftmaxDescr(kwargs=SoftmaxKwargs(axis=AxisId("channel")))]
    """

    implemented_id: ClassVar[Literal["softmax"]] = "softmax"
    if TYPE_CHECKING:
        # default value for static type checkers only
        id: Literal["softmax"] = "softmax"
    else:
        # at runtime `id` has no default (see `implemented_id` and the base class)
        id: Literal["softmax"]

    # `model_construct` skips validation so the default does not fail eagerly
    kwargs: SoftmaxKwargs = Field(default_factory=SoftmaxKwargs.model_construct)

1298 

1299 

class _StardistPostprocessingKwargsBase(KwargsNode):
    """key word arguments for [StardistPostprocessingDescr][]"""

    prob_threshold: float
    """The probability threshold for object candidate selection."""

    nms_threshold: float
    """The IoU threshold for non-maximum suppression."""

1308 

1309 

class StardistPostprocessingKwargs2D(_StardistPostprocessingKwargsBase):
    # 2D variant: two spatial dimensions per entry
    grid: Tuple[int, int]
    """Grid size of network predictions."""

    b: Union[int, Tuple[Tuple[int, int], Tuple[int, int]]]
    """Border region in which object probability is set to zero."""

1316 

1317 

class StardistPostprocessingKwargs3D(_StardistPostprocessingKwargsBase):
    # 3D variant: three spatial dimensions per entry
    grid: Tuple[int, int, int]
    """Grid size of network predictions."""

    b: Union[int, Tuple[Tuple[int, int], Tuple[int, int], Tuple[int, int]]]
    """Border region in which object probability is set to zero."""

    n_rays: int
    """Number of rays for 3D star-convex polyhedra."""

    anisotropy: Tuple[float, float, float]
    """Anisotropy factors for 3D star-convex polyhedra, i.e. the physical pixel size along each spatial axis."""

    overlap_label: Optional[int] = None
    """Optional label to apply to any area of overlapping predicted objects."""

1333 

1334 

class StardistPostprocessingDescr(NodeWithExplicitlySetFields):
    """Stardist postprocessing including non-maximum suppression and converting polygon representations to instance labels

    as described in:
    - Uwe Schmidt, Martin Weigert, Coleman Broaddus, and Gene Myers.
    [*Cell Detection with Star-convex Polygons*](https://arxiv.org/abs/1806.03535).
    International Conference on Medical Image Computing and Computer-Assisted Intervention (MICCAI), Granada, Spain, September 2018.
    - Martin Weigert, Uwe Schmidt, Robert Haase, Ko Sugawara, and Gene Myers.
    [*Star-convex Polyhedra for 3D Object Detection and Segmentation in Microscopy*](http://openaccess.thecvf.com/content_WACV_2020/papers/Weigert_Star-convex_Polyhedra_for_3D_Object_Detection_and_Segmentation_in_Microscopy_WACV_2020_paper.pdf).
    The IEEE Winter Conference on Applications of Computer Vision (WACV), Snowmass Village, Colorado, March 2020.

    Note: Only available if the `stardist` package is installed.
    """

    implemented_id: ClassVar[Literal["stardist_postprocessing"]] = (
        "stardist_postprocessing"
    )
    if TYPE_CHECKING:
        # default value for static type checkers only
        id: Literal["stardist_postprocessing"] = "stardist_postprocessing"
    else:
        # at runtime `id` has no default (see `implemented_id` and the base class)
        id: Literal["stardist_postprocessing"]

    # 2D or 3D kwargs variant, distinguished by their field shapes
    kwargs: Union[StardistPostprocessingKwargs2D, StardistPostprocessingKwargs3D]

1358 

1359 

class FixedZeroMeanUnitVarianceKwargs(KwargsNode):
    """key word arguments for [FixedZeroMeanUnitVarianceDescr][]"""

    mean: float
    """The mean value to normalize with."""

    # lower bound avoids division by (near) zero during normalization
    std: Annotated[float, Ge(1e-6)]
    """The standard deviation value to normalize with."""

1368 

1369 

class FixedZeroMeanUnitVarianceAlongAxisKwargs(KwargsNode):
    """key word arguments for [FixedZeroMeanUnitVarianceDescr][]"""

    mean: NotEmpty[List[float]]
    """The mean value(s) to normalize with."""

    # lower bound avoids division by (near) zero during normalization
    std: NotEmpty[List[Annotated[float, Ge(1e-6)]]]
    """The standard deviation value(s) to normalize with.
    Size must match `mean` values."""

    axis: Annotated[NonBatchAxisId, Field(examples=["channel", "index"])]
    """The axis of the mean/std values to normalize each entry along that dimension
    separately."""

    @model_validator(mode="after")
    def _mean_and_std_match(self) -> Self:
        """Require one `std` entry per `mean` entry."""
        if len(self.mean) != len(self.std):
            raise ValueError(
                f"Size of `mean` ({len(self.mean)}) and `std` ({len(self.std)})"
                + " must match."
            )

        return self

1393 

1394 

class FixedZeroMeanUnitVarianceDescr(NodeWithExplicitlySetFields):
    """Subtract a given mean and divide by the standard deviation.

    Normalize with fixed, precomputed values for
    `FixedZeroMeanUnitVarianceKwargs.mean` and `FixedZeroMeanUnitVarianceKwargs.std`
    Use `FixedZeroMeanUnitVarianceAlongAxisKwargs` for independent scaling along given
    axes.

    Examples:
    1. scalar value for whole tensor
        - in YAML
        ```yaml
        preprocessing:
          - id: fixed_zero_mean_unit_variance
            kwargs:
              mean: 103.5
              std: 13.7
        ```
        - in Python
        >>> preprocessing = [FixedZeroMeanUnitVarianceDescr(
        ...     kwargs=FixedZeroMeanUnitVarianceKwargs(mean=103.5, std=13.7)
        ... )]

    2. independently along an axis
        - in YAML
        ```yaml
        preprocessing:
          - id: fixed_zero_mean_unit_variance
            kwargs:
              axis: channel
              mean: [101.5, 102.5, 103.5]
              std: [11.7, 12.7, 13.7]
        ```
        - in Python
        >>> preprocessing = [FixedZeroMeanUnitVarianceDescr(
        ...     kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs(
        ...         axis=AxisId("channel"),
        ...         mean=[101.5, 102.5, 103.5],
        ...         std=[11.7, 12.7, 13.7],
        ...     )
        ... )]
    """

    implemented_id: ClassVar[Literal["fixed_zero_mean_unit_variance"]] = (
        "fixed_zero_mean_unit_variance"
    )
    if TYPE_CHECKING:
        # default value for static type checkers only
        id: Literal["fixed_zero_mean_unit_variance"] = "fixed_zero_mean_unit_variance"
    else:
        # at runtime `id` has no default (see `implemented_id` and the base class)
        id: Literal["fixed_zero_mean_unit_variance"]

    # scalar kwargs or per-axis kwargs (with `axis` set)
    kwargs: Union[
        FixedZeroMeanUnitVarianceKwargs, FixedZeroMeanUnitVarianceAlongAxisKwargs
    ]

1449 

1450 

class ZeroMeanUnitVarianceKwargs(KwargsNode):
    """key word arguments for [ZeroMeanUnitVarianceDescr][]"""

    axes: Annotated[
        Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
    ] = None
    """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std.
    For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
    resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
    To normalize each sample independently leave out the 'batch' axis.
    Default: Scale all axes jointly."""

    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
    """epsilon for numeric stability: `out = (tensor - mean) / (std + eps)`."""

1465 

1466 

class ZeroMeanUnitVarianceDescr(NodeWithExplicitlySetFields):
    """Subtract mean and divide by variance.

    Examples:
        Subtract tensor mean and variance
        - in YAML
        ```yaml
        preprocessing:
          - id: zero_mean_unit_variance
        ```
        - in Python
        >>> preprocessing = [ZeroMeanUnitVarianceDescr()]
    """

    implemented_id: ClassVar[Literal["zero_mean_unit_variance"]] = (
        "zero_mean_unit_variance"
    )
    if TYPE_CHECKING:
        # default value for static type checkers only
        id: Literal["zero_mean_unit_variance"] = "zero_mean_unit_variance"
    else:
        # at runtime `id` has no default (see `implemented_id` and the base class)
        id: Literal["zero_mean_unit_variance"]

    # `model_construct` skips validation so the default does not fail eagerly
    kwargs: ZeroMeanUnitVarianceKwargs = Field(
        default_factory=ZeroMeanUnitVarianceKwargs.model_construct
    )

1492 

1493 

class ScaleRangeKwargs(KwargsNode):
    """key word arguments for [ScaleRangeDescr][]

    For `min_percentile`=0.0 (the default) and `max_percentile`=100 (the default)
    this processing step normalizes data to the [0, 1] interval.
    For other percentiles the normalized values will partially be outside the [0, 1]
    interval. Use `ScaleRangeDescr` followed by `ClipDescr` if you want to limit the
    normalized values to a range.
    """

    axes: Annotated[
        Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
    ] = None
    """The subset of axes to normalize jointly, i.e. axes to reduce to compute the min/max percentile value.
    For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
    resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
    To normalize samples independently, leave out the "batch" axis.
    Default: Scale all axes jointly."""

    min_percentile: Annotated[float, Interval(ge=0, lt=100)] = 0.0
    """The lower percentile used to determine the value to align with zero."""

    max_percentile: Annotated[float, Interval(gt=1, le=100)] = 100.0
    """The upper percentile used to determine the value to align with one.
    Has to be bigger than `min_percentile`.
    The range is 1 to 100 instead of 0 to 100 to avoid mistakenly
    accepting percentiles specified in the range 0.0 to 1.0."""

    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
    """Epsilon for numeric stability.
    `out = (tensor - v_lower) / (v_upper - v_lower + eps)`;
    with `v_lower,v_upper` values at the respective percentiles."""

    reference_tensor: Optional[TensorId] = None
    """ID of the unprocessed input tensor to compute the percentiles from.
    Default: The tensor itself.
    """

    @field_validator("max_percentile", mode="after")
    @classmethod
    def min_smaller_max(cls, value: float, info: ValidationInfo) -> float:
        """Ensure `min_percentile` < `max_percentile`."""
        # `min_percentile` is absent from `info.data` if it failed its own
        # validation; use `.get` so we do not mask that error with a KeyError.
        min_p = info.data.get("min_percentile")
        if min_p is not None and min_p >= value:
            raise ValueError(f"min_percentile {min_p} >= max_percentile {value}")

        return value

1539 

1540 

class ScaleRangeDescr(NodeWithExplicitlySetFields):
    """Scale with percentiles.

    Examples:
    1. Scale linearly to map 5th percentile to 0 and 99.8th percentile to 1.0
        - in YAML
        ```yaml
        preprocessing:
          - id: scale_range
            kwargs:
              axes: ['y', 'x']
              max_percentile: 99.8
              min_percentile: 5.0
        ```
        - in Python
        >>> preprocessing = [
        ...     ScaleRangeDescr(
        ...         kwargs=ScaleRangeKwargs(
        ...           axes= (AxisId('y'), AxisId('x')),
        ...           max_percentile= 99.8,
        ...           min_percentile= 5.0,
        ...         )
        ...     )
        ... ]

    2. Combine the above scaling with additional clipping to clip values outside the range given by the percentiles.
        - in YAML
        ```yaml
        preprocessing:
          - id: scale_range
            kwargs:
              axes: ['y', 'x']
              max_percentile: 99.8
              min_percentile: 5.0
          - id: clip
            kwargs:
              min: 0.0
              max: 1.0
        ```
        - in Python
        >>> preprocessing = [
        ...     ScaleRangeDescr(
        ...         kwargs=ScaleRangeKwargs(
        ...           axes= (AxisId('y'), AxisId('x')),
        ...           max_percentile= 99.8,
        ...           min_percentile= 5.0,
        ...         )
        ...     ),
        ...     ClipDescr(
        ...         kwargs=ClipKwargs(
        ...             min=0.0,
        ...             max=1.0,
        ...         )
        ...     ),
        ... ]

    """

    implemented_id: ClassVar[Literal["scale_range"]] = "scale_range"
    if TYPE_CHECKING:
        # default value for static type checkers only
        id: Literal["scale_range"] = "scale_range"
    else:
        # at runtime `id` has no default (see `implemented_id` and the base class)
        id: Literal["scale_range"]
    # `model_construct` skips validation so the default does not fail eagerly
    kwargs: ScaleRangeKwargs = Field(default_factory=ScaleRangeKwargs.model_construct)

1606 

1607 

class ScaleMeanVarianceKwargs(KwargsNode):
    """key word arguments for [ScaleMeanVarianceDescr][]"""

    reference_tensor: TensorId
    """ID of unprocessed input tensor to match."""

    axes: Annotated[
        Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
    ] = None
    """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std.
    For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
    resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
    To normalize samples independently, leave out the 'batch' axis.
    Default: Scale all axes jointly."""

    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
    """Epsilon for numeric stability:
    `out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.`"""

1627 

class ScaleMeanVarianceDescr(NodeWithExplicitlySetFields):
    """Scale a tensor's data distribution to match another tensor's mean/std.
    `out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.`
    """

    implemented_id: ClassVar[Literal["scale_mean_variance"]] = "scale_mean_variance"
    if TYPE_CHECKING:
        # default value for static type checkers only
        id: Literal["scale_mean_variance"] = "scale_mean_variance"
    else:
        # at runtime `id` has no default (see `implemented_id` and the base class)
        id: Literal["scale_mean_variance"]
    kwargs: ScaleMeanVarianceKwargs

1639 

1640 

# All preprocessing operations, discriminated by their `id` field.
PreprocessingDescr = Annotated[
    Union[
        BinarizeDescr,
        ClipDescr,
        EnsureDtypeDescr,
        FixedZeroMeanUnitVarianceDescr,
        ScaleLinearDescr,
        ScaleRangeDescr,
        SigmoidDescr,
        SoftmaxDescr,
        ZeroMeanUnitVarianceDescr,
    ],
    Discriminator("id"),
]
# All postprocessing operations, discriminated by their `id` field
# (the preprocessing operations plus output-only operations).
PostprocessingDescr = Annotated[
    Union[
        BinarizeDescr,
        ClipDescr,
        EnsureDtypeDescr,
        FixedZeroMeanUnitVarianceDescr,
        ScaleLinearDescr,
        ScaleMeanVarianceDescr,
        ScaleRangeDescr,
        SigmoidDescr,
        SoftmaxDescr,
        StardistPostprocessingDescr,
        ZeroMeanUnitVarianceDescr,
    ],
    Discriminator("id"),
]

# Type variable constrained to input/output axis types;
# lets `TensorDescrBase` be shared by input and output tensor descriptions.
IO_AxisT = TypeVar("IO_AxisT", InputAxis, OutputAxis)

1673 

1674 

class TensorDescrBase(Node, Generic[IO_AxisT]):
    id: TensorId
    """Tensor id. No duplicates are allowed."""

    description: Annotated[str, MaxLen(128)] = ""
    """free text description"""

    axes: NotEmpty[Sequence[IO_AxisT]]
    """tensor axes"""

    @property
    def shape(self):
        """The `size` entries of all axes (in axes order)."""
        return tuple(a.size for a in self.axes)

    @field_validator("axes", mode="after", check_fields=False)
    @classmethod
    def _validate_axes(cls, axes: Sequence[AnyAxis]) -> Sequence[AnyAxis]:
        """Check for at most one batch axis and no duplicate axis ids."""
        batch_axes = [a for a in axes if a.type == "batch"]
        if len(batch_axes) > 1:
            raise ValueError(
                f"Only one batch axis (per tensor) allowed, but got {batch_axes}"
            )

        seen_ids: Set[AxisId] = set()
        duplicate_axes_ids: Set[AxisId] = set()
        for a in axes:
            # first occurrence goes to `seen_ids`, any repeat to `duplicate_axes_ids`
            (duplicate_axes_ids if a.id in seen_ids else seen_ids).add(a.id)

        if duplicate_axes_ids:
            raise ValueError(f"Duplicate axis ids: {duplicate_axes_ids}")

        return axes

    test_tensor: FAIR[Optional[FileDescr_]] = None
    """An example tensor to use for testing.
    Using the model with the test input tensors is expected to yield the test output tensors.
    Each test tensor has to be an ndarray in the
    [numpy.lib file format](https://numpy.org/doc/stable/reference/generated/numpy.lib.format.html#module-numpy.lib.format).
    The file extension must be '.npy'."""

    sample_tensor: FAIR[Optional[FileDescr_]] = None
    """A sample tensor to illustrate a possible input/output for the model,
    The sample image primarily serves to inform a human user about an example use case
    and is typically stored as .hdf5, .png or .tiff.
    It has to be readable by the [imageio library](https://imageio.readthedocs.io/en/stable/formats/index.html#supported-formats)
    (numpy's `.npy` format is not supported).
    The image dimensionality has to match the number of axes specified in this tensor description.
    """

    @model_validator(mode="after")
    def _validate_sample_tensor(self) -> Self:
        """Load `sample_tensor` (when IO checks are enabled) and verify that its
        dimensionality is compatible with the described axes."""
        if self.sample_tensor is None or not get_validation_context().perform_io_checks:
            return self

        reader = get_reader(self.sample_tensor.source, sha256=self.sample_tensor.sha256)
        tensor: NDArray[Any] = imread(  # pyright: ignore[reportUnknownVariableType]
            reader.read(),
            extension=PurePosixPath(reader.original_file_name).suffix,
        )
        n_dims = len(tensor.squeeze().shape)
        n_dims_min = n_dims_max = len(self.axes)

        # axes that may be singleton can be dropped by `squeeze()` above,
        # so each of them lowers the acceptable minimum number of dimensions
        for a in self.axes:
            if isinstance(a, BatchAxis):
                n_dims_min -= 1
            elif isinstance(a.size, int):
                if a.size == 1:
                    n_dims_min -= 1
            elif isinstance(a.size, (ParameterizedSize, DataDependentSize)):
                if a.size.min == 1:
                    n_dims_min -= 1
            elif isinstance(a.size, SizeReference):
                if a.size.offset < 2:
                    # size reference may result in singleton axis
                    n_dims_min -= 1
            else:
                assert_never(a.size)

        n_dims_min = max(0, n_dims_min)
        if n_dims < n_dims_min or n_dims > n_dims_max:
            raise ValueError(
                f"Expected sample tensor to have {n_dims_min} to"
                + f" {n_dims_max} dimensions, but found {n_dims} (shape: {tensor.shape})."
            )

        return self

    data: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]] = (
        IntervalOrRatioDataDescr()
    )
    """Description of the tensor's data values, optionally per channel.
    If specified per channel, the data `type` needs to match across channels."""

    @property
    def dtype(
        self,
    ) -> Literal[
        "float32",
        "float64",
        "uint8",
        "int8",
        "uint16",
        "int16",
        "uint32",
        "int32",
        "uint64",
        "int64",
        "bool",
    ]:
        """dtype as specified under `data.type` or `data[i].type`"""
        # per-channel entries agree in `type`
        # (enforced by `_check_data_type_across_channels`),
        # so the first entry is representative
        if isinstance(self.data, collections.abc.Sequence):
            return self.data[0].type
        else:
            return self.data.type

    @field_validator("data", mode="after")
    @classmethod
    def _check_data_type_across_channels(
        cls, value: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]]
    ) -> Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]]:
        """If `data` is given per channel, require a single common data `type`."""
        if not isinstance(value, list):
            return value

        dtypes = {t.type for t in value}
        if len(dtypes) > 1:
            raise ValueError(
                "Tensor data descriptions per channel need to agree in their data"
                + f" `type`, but found {dtypes}."
            )

        return value

    @model_validator(mode="after")
    def _check_data_matches_channelaxis(self) -> Self:
        """If `data` is given per channel, require one entry per channel."""
        if not isinstance(self.data, (list, tuple)):
            return self

        for a in self.axes:
            if isinstance(a, ChannelAxis):
                size = a.size
                assert isinstance(size, int)
                break
        else:
            # no channel axis present: nothing to compare against
            return self

        if len(self.data) != size:
            raise ValueError(
                f"Got tensor data descriptions for {len(self.data)} channels, but"
                + f" '{a.id}' axis has size {size}."
            )

        return self

    def get_axis_sizes_for_array(self, array: NDArray[Any]) -> Dict[AxisId, int]:
        """Map each described axis id to the corresponding dimension size of `array`.

        Raises:
            ValueError: if `array` does not have exactly one dimension per axis.
        """
        if len(array.shape) != len(self.axes):
            raise ValueError(
                f"Dimension mismatch: array shape {array.shape} (#{len(array.shape)})"
                + f" incompatible with {len(self.axes)} axes."
            )
        return {a.id: array.shape[i] for i, a in enumerate(self.axes)}

1835 

1836 

class InputTensorDescr(TensorDescrBase[InputAxis]):
    id: TensorId = TensorId("input")
    """Input tensor id.
    No duplicates are allowed across all inputs and outputs."""

    optional: bool = False
    """indicates that this tensor may be `None`"""

    preprocessing: List[PreprocessingDescr] = Field(
        default_factory=cast(Callable[[], List[PreprocessingDescr]], list)
    )

    """Description of how this input should be preprocessed.

    notes:
    - If preprocessing does not start with an 'ensure_dtype' entry, it is added
      to ensure an input tensor's data type matches the input tensor's data description.
    - If preprocessing does not end with an 'ensure_dtype' or 'binarize' entry, an
      'ensure_dtype' step is added to ensure preprocessing steps are not unintentionally
      changing the data type.
    """

    @model_validator(mode="after")
    def _validate_preprocessing_kwargs(self) -> Self:
        """Check that any `kwargs.axes` reference existing axis ids and wrap the
        preprocessing chain in 'ensure_dtype' steps (see the `preprocessing` note)."""
        axes_ids = [a.id for a in self.axes]
        for p in self.preprocessing:
            kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes")
            if kwargs_axes is None:
                continue

            if not isinstance(kwargs_axes, collections.abc.Sequence):
                raise ValueError(
                    f"Expected `preprocessing.i.kwargs.axes` to be a sequence, but got {type(kwargs_axes)}"
                )

            if any(a not in axes_ids for a in kwargs_axes):
                raise ValueError(
                    "`preprocessing.i.kwargs.axes` needs to be subset of axes ids"
                )

        # per-channel data descriptions agree in `type` (validated in the base class),
        # so the first entry is representative
        if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)):
            dtype = self.data.type
        else:
            dtype = self.data[0].type

        # ensure `preprocessing` begins with `EnsureDtypeDescr`
        if not self.preprocessing or not isinstance(
            self.preprocessing[0], EnsureDtypeDescr
        ):
            self.preprocessing.insert(
                0, EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype))
            )

        # ensure `preprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr`
        if not isinstance(self.preprocessing[-1], (EnsureDtypeDescr, BinarizeDescr)):
            self.preprocessing.append(
                EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype))
            )

        return self

1897 

1898 

def convert_axes(
    axes: str,
    *,
    shape: Union[
        Sequence[int], _ParameterizedInputShape_v0_4, _ImplicitOutputShape_v0_4
    ],
    tensor_type: Literal["input", "output"],
    halo: Optional[Sequence[int]],
    size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]],
):
    """Convert a v0.4 axes letter string and shape description into v0.5 axes.

    Args:
        axes: v0.4 axes string, one letter per axis (e.g. "bczyx").
        shape: explicit sizes, a parameterized input shape, or an implicit
            output shape (v0.4 representations).
        tensor_type: whether these axes describe an input or an output tensor.
        halo: optional per-axis halo values (used for output axes only).
        size_refs: known axis sizes of other tensors by tensor name and axis
            letter; used to resolve implicit output shapes.

    Returns:
        List of v0.5 axis descriptions.
    """
    ret: List[AnyAxis] = []
    for i, a in enumerate(axes):
        axis_type = _AXIS_TYPE_MAP.get(a, a)
        if axis_type == "batch":
            ret.append(BatchAxis())
            continue

        scale = 1.0
        if isinstance(shape, _ParameterizedInputShape_v0_4):
            # step == 0 marks a fixed axis size; otherwise the size is parameterized
            if shape.step[i] == 0:
                size = shape.min[i]
            else:
                size = ParameterizedSize(min=shape.min[i], step=shape.step[i])
        elif isinstance(shape, _ImplicitOutputShape_v0_4):
            # reference may be given as "<tensor>.<axis>" or just "<tensor>"
            ref_t = str(shape.reference_tensor)
            if ref_t.count(".") == 1:
                t_id, orig_a_id = ref_t.split(".")
            else:
                t_id = ref_t
                orig_a_id = a

            # NOTE(review): the fallback is the current axis letter `a`,
            # not `orig_a_id` — confirm this is intended
            a_id = _AXIS_ID_MAP.get(orig_a_id, a)
            if not (orig_scale := shape.scale[i]):
                # old way to insert a new axis dimension
                size = int(2 * shape.offset[i])
            else:
                scale = 1 / orig_scale
                if axis_type in ("channel", "index"):
                    # these axes no longer have a scale
                    offset_from_scale = orig_scale * size_refs.get(
                        _TensorName_v0_4(t_id), {}
                    ).get(orig_a_id, 0)
                else:
                    offset_from_scale = 0
                size = SizeReference(
                    tensor_id=TensorId(t_id),
                    axis_id=AxisId(a_id),
                    offset=int(offset_from_scale + 2 * shape.offset[i]),
                )
        else:
            # explicit shape: sizes given directly
            size = shape[i]

        if axis_type == "time":
            if tensor_type == "input":
                ret.append(TimeInputAxis(size=size, scale=scale))
            else:
                assert not isinstance(size, ParameterizedSize)
                if halo is None:
                    ret.append(TimeOutputAxis(size=size, scale=scale))
                else:
                    assert not isinstance(size, int)
                    ret.append(
                        TimeOutputAxisWithHalo(size=size, scale=scale, halo=halo[i])
                    )

        elif axis_type == "index":
            if tensor_type == "input":
                ret.append(IndexInputAxis(size=size))
            else:
                if isinstance(size, ParameterizedSize):
                    size = DataDependentSize(min=size.min)

                ret.append(IndexOutputAxis(size=size))
        elif axis_type == "channel":
            assert not isinstance(size, ParameterizedSize)
            if isinstance(size, SizeReference):
                warnings.warn(
                    "Conversion of channel size from an implicit output shape may be"
                    + " wrong"
                )
                ret.append(
                    ChannelAxis(
                        channel_names=[
                            Identifier(f"channel{i}") for i in range(size.offset)
                        ]
                    )
                )
            else:
                ret.append(
                    ChannelAxis(
                        channel_names=[Identifier(f"channel{i}") for i in range(size)]
                    )
                )
        elif axis_type == "space":
            if tensor_type == "input":
                ret.append(SpaceInputAxis(id=AxisId(a), size=size, scale=scale))
            else:
                assert not isinstance(size, ParameterizedSize)
                if halo is None or halo[i] == 0:
                    ret.append(SpaceOutputAxis(id=AxisId(a), size=size, scale=scale))
                elif isinstance(size, int):
                    raise NotImplementedError(
                        f"output axis with halo and fixed size (here {size}) not allowed"
                    )
                else:
                    ret.append(
                        SpaceOutputAxisWithHalo(
                            id=AxisId(a), size=size, scale=scale, halo=halo[i]
                        )
                    )

    return ret

2011 

2012 

2013def _axes_letters_to_ids( 

2014 axes: Optional[str], 

2015) -> Optional[List[AxisId]]: 

2016 if axes is None: 

2017 return None 

2018 

2019 return [AxisId(a) for a in axes] 

2020 

2021 

2022def _get_complement_v04_axis( 

2023 tensor_axes: Sequence[str], axes: Optional[Sequence[str]] 

2024) -> Optional[AxisId]: 

2025 if axes is None: 

2026 return None 

2027 

2028 non_complement_axes = set(axes) | {"b"} 

2029 complement_axes = [a for a in tensor_axes if a not in non_complement_axes] 

2030 if len(complement_axes) > 1: 

2031 raise ValueError( 

2032 f"Expected none or a single complement axis, but axes '{axes}' " 

2033 + f"for tensor dims '{tensor_axes}' leave '{complement_axes}'." 

2034 ) 

2035 

2036 return None if not complement_axes else AxisId(complement_axes[0]) 

2037 

2038 

def _convert_proc(
    p: Union[_PreprocessingDescr_v0_4, _PostprocessingDescr_v0_4],
    tensor_axes: Sequence[str],
) -> Union[PreprocessingDescr, PostprocessingDescr]:
    """Convert a single v0.4 pre-/postprocessing description to its v0.5 counterpart.

    Args:
        p: the v0.4 processing description to convert.
        tensor_axes: the v0.4 axes letters of the tensor `p` applies to;
            used to determine the implicit "complement" axis for
            per-element (e.g. per-channel) parameters.

    Returns:
        The equivalent v0.5 processing description.

    Raises:
        ValueError: if `p`'s kwargs are inconsistent, e.g. a list of
            means/stds without a complement axis to map them onto.
    """
    if isinstance(p, _BinarizeDescr_v0_4):
        return BinarizeDescr(kwargs=BinarizeKwargs(threshold=p.kwargs.threshold))
    elif isinstance(p, _ClipDescr_v0_4):
        return ClipDescr(kwargs=ClipKwargs(min=p.kwargs.min, max=p.kwargs.max))
    elif isinstance(p, _SigmoidDescr_v0_4):
        return SigmoidDescr()
    elif isinstance(p, _ScaleLinearDescr_v0_4):
        # `_get_complement_v04_axis` already returns None for `axes=None`,
        # so no separate None check is required here.
        axis = _get_complement_v04_axis(tensor_axes, p.kwargs.axes)
        if axis is None:
            # scalar gain/offset applied to the whole tensor
            assert not isinstance(p.kwargs.gain, list)
            assert not isinstance(p.kwargs.offset, list)
            kwargs = ScaleLinearKwargs(gain=p.kwargs.gain, offset=p.kwargs.offset)
        else:
            # per-element gain/offset along the complement (e.g. channel) axis
            kwargs = ScaleLinearAlongAxisKwargs(
                axis=axis, gain=p.kwargs.gain, offset=p.kwargs.offset
            )
        return ScaleLinearDescr(kwargs=kwargs)
    elif isinstance(p, _ScaleMeanVarianceDescr_v0_4):
        return ScaleMeanVarianceDescr(
            kwargs=ScaleMeanVarianceKwargs(
                axes=_axes_letters_to_ids(p.kwargs.axes),
                reference_tensor=TensorId(str(p.kwargs.reference_tensor)),
                eps=p.kwargs.eps,
            )
        )
    elif isinstance(p, _ZeroMeanUnitVarianceDescr_v0_4):
        if p.kwargs.mode == "fixed":
            mean = p.kwargs.mean
            std = p.kwargs.std
            # "fixed" mode requires explicit statistics in v0.4
            assert mean is not None
            assert std is not None

            axis = _get_complement_v04_axis(tensor_axes, p.kwargs.axes)

            if axis is None:
                if isinstance(mean, list):
                    raise ValueError("Expected single float value for mean, not <list>")
                if isinstance(std, list):
                    raise ValueError("Expected single float value for std, not <list>")
                return FixedZeroMeanUnitVarianceDescr(
                    kwargs=FixedZeroMeanUnitVarianceKwargs.model_construct(
                        mean=mean,
                        std=std,
                    )
                )
            else:
                # normalize scalars to 1-element lists for the along-axis kwargs
                if not isinstance(mean, list):
                    mean = [float(mean)]
                if not isinstance(std, list):
                    std = [float(std)]

                return FixedZeroMeanUnitVarianceDescr(
                    kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs(
                        axis=axis, mean=mean, std=std
                    )
                )

        else:
            axes = _axes_letters_to_ids(p.kwargs.axes) or []
            if p.kwargs.mode == "per_dataset":
                # per_dataset statistics additionally aggregate over the batch axis
                axes = [AxisId("batch")] + axes
            if not axes:
                axes = None
            return ZeroMeanUnitVarianceDescr(
                kwargs=ZeroMeanUnitVarianceKwargs(axes=axes, eps=p.kwargs.eps)
            )

    elif isinstance(p, _ScaleRangeDescr_v0_4):
        return ScaleRangeDescr(
            kwargs=ScaleRangeKwargs(
                axes=_axes_letters_to_ids(p.kwargs.axes),
                min_percentile=p.kwargs.min_percentile,
                max_percentile=p.kwargs.max_percentile,
                eps=p.kwargs.eps,
            )
        )
    else:
        assert_never(p)

2126 

2127 

class _InputTensorConv(
    Converter[
        _InputTensorDescr_v0_4,
        InputTensorDescr,
        FileSource_,
        Optional[FileSource_],
        Mapping[_TensorName_v0_4, Mapping[str, int]],
    ]
):
    """Converter from a v0.4 input tensor description to a v0.5 `InputTensorDescr`."""

    def _convert(
        self,
        src: _InputTensorDescr_v0_4,
        tgt: "type[InputTensorDescr] | type[dict[str, Any]]",
        test_tensor: FileSource_,
        sample_tensor: Optional[FileSource_],
        size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]],
    ) -> "InputTensorDescr | dict[str, Any]":
        # translate v0.4 axes letters + shape into explicit v0.5 input axes
        axes: List[InputAxis] = convert_axes(  # pyright: ignore[reportAssignmentType]
            src.axes,
            shape=src.shape,
            tensor_type="input",
            halo=None,
            size_refs=size_refs,
        )
        prep: List[PreprocessingDescr] = []
        for p in src.preprocessing:
            cp = _convert_proc(p, src.axes)
            # these descrs are postprocessing-only; they cannot result from
            # converting a v0.4 *pre*processing entry
            assert not isinstance(
                cp, (ScaleMeanVarianceDescr, StardistPostprocessingDescr)
            )
            prep.append(cp)

        # v0.4 implicitly fed float32 to the model; make that cast explicit
        prep.append(EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="float32")))

        return tgt(
            axes=axes,
            id=TensorId(str(src.name)),
            test_tensor=FileDescr(source=test_tensor),
            sample_tensor=(
                None if sample_tensor is None else FileDescr(source=sample_tensor)
            ),
            data=dict(type=src.data_type),  # pyright: ignore[reportArgumentType]
            preprocessing=prep,
        )

2172 

2173 

# module-level converter instance translating v0.4 input tensor descriptions
# to v0.5 `InputTensorDescr` objects
_input_tensor_conv = _InputTensorConv(_InputTensorDescr_v0_4, InputTensorDescr)

2175 

2176 

class OutputTensorDescr(TensorDescrBase[OutputAxis]):
    id: TensorId = TensorId("output")
    """Output tensor id.
    No duplicates are allowed across all inputs and outputs."""

    postprocessing: List[PostprocessingDescr] = Field(
        default_factory=cast(Callable[[], List[PostprocessingDescr]], list)
    )
    """Description of how this output should be postprocessed.

    note: `postprocessing` always ends with an 'ensure_dtype' operation.
    If not given this is added to cast to this tensor's `data.type`.
    """

    @model_validator(mode="after")
    def _validate_postprocessing_kwargs(self) -> Self:
        # any `axes` given in a postprocessing step must be a subset of this
        # tensor's axis ids
        axes_ids = [a.id for a in self.axes]
        for p in self.postprocessing:
            kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes")
            if kwargs_axes is None:
                continue

            if not isinstance(kwargs_axes, collections.abc.Sequence):
                raise ValueError(
                    f"expected `axes` sequence, but got {type(kwargs_axes)}"
                )

            if any(a not in axes_ids for a in kwargs_axes):
                raise ValueError("`kwargs.axes` needs to be subset of axes ids")

        # determine this tensor's dtype (first entry when `data` is a sequence
        # of per-channel data descriptions)
        if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)):
            dtype = self.data.type
        else:
            dtype = self.data[0].type

        # ensure `postprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr`
        if not self.postprocessing or not isinstance(
            self.postprocessing[-1], (EnsureDtypeDescr, BinarizeDescr)
        ):
            self.postprocessing.append(
                EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype))
            )
        return self

2220 

2221 

class _OutputTensorConv(
    Converter[
        _OutputTensorDescr_v0_4,
        OutputTensorDescr,
        FileSource_,
        Optional[FileSource_],
        Mapping[_TensorName_v0_4, Mapping[str, int]],
    ]
):
    """Converter from a v0.4 output tensor description to a v0.5 `OutputTensorDescr`."""

    def _convert(
        self,
        src: _OutputTensorDescr_v0_4,
        tgt: "type[OutputTensorDescr] | type[dict[str, Any]]",
        test_tensor: FileSource_,
        sample_tensor: Optional[FileSource_],
        size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]],
    ) -> "OutputTensorDescr | dict[str, Any]":
        # TODO: split convert_axes into convert_output_axes and convert_input_axes
        axes: List[OutputAxis] = convert_axes(  # pyright: ignore[reportAssignmentType]
            src.axes,
            shape=src.shape,
            tensor_type="output",
            halo=src.halo,
            size_refs=size_refs,
        )
        data_descr: Dict[str, Any] = dict(type=src.data_type)
        if data_descr["type"] == "bool":
            # boolean outputs enumerate their possible values explicitly
            data_descr["values"] = [False, True]

        return tgt(
            axes=axes,
            id=TensorId(str(src.name)),
            test_tensor=FileDescr(source=test_tensor),
            sample_tensor=(
                None if sample_tensor is None else FileDescr(source=sample_tensor)
            ),
            data=data_descr,  # pyright: ignore[reportArgumentType]
            postprocessing=[_convert_proc(p, src.axes) for p in src.postprocessing],
        )

2261 

2262 

# module-level converter instance translating v0.4 output tensor descriptions
# to v0.5 `OutputTensorDescr` objects
_output_tensor_conv = _OutputTensorConv(_OutputTensorDescr_v0_4, OutputTensorDescr)


# a tensor description is either an input or an output tensor description
TensorDescr = Union[InputTensorDescr, OutputTensorDescr]

2267 

2268 

def validate_tensors(
    tensors: Mapping[TensorId, Tuple[TensorDescr, Optional[NDArray[Any]]]],
    tensor_origin: Literal[
        "source", "test_tensor"
    ] = "source",  # for more precise error messages
):
    """Validate (tensor description, array) pairs against each other.

    Checks that each given array matches its description's dtype and axis
    sizes, including `SizeReference` relations across tensors.

    Args:
        tensors: maps tensor id to its description and (optionally) the array
            loaded from `tensor_origin`. Entries with a `None` array only
            contribute their described axes for size references of others.
        tensor_origin: where the arrays come from; only used in error messages.

    Raises:
        ValueError: if any array contradicts its own or a referenced
            description.
    """
    all_tensor_axes: Dict[TensorId, Dict[AxisId, Tuple[AnyAxis, Optional[int]]]] = {}

    def e_msg_location(d: TensorDescr):
        # error message prefix pointing at the offending tensor
        return f"{'inputs' if isinstance(d, InputTensorDescr) else 'outputs'}[{d.id}]"

    # first pass: collect the actual axis sizes of all tensors so that
    # `SizeReference`s can be resolved in the second pass
    for descr, array in tensors.values():
        if array is None:
            axis_sizes = {a.id: None for a in descr.axes}
        else:
            try:
                axis_sizes = descr.get_axis_sizes_for_array(array)
            except ValueError as e:
                raise ValueError(f"{e_msg_location(descr)} {e}") from e

        all_tensor_axes[descr.id] = {a.id: (a, axis_sizes[a.id]) for a in descr.axes}

    # second pass: validate dtype, value range and axis sizes
    for descr, array in tensors.values():
        if array is None:
            continue

        if descr.dtype in ("float32", "float64"):
            # integer arrays are accepted for float tensors
            invalid_test_tensor_dtype = array.dtype.name not in (
                "float32",
                "float64",
                "uint8",
                "int8",
                "uint16",
                "int16",
                "uint32",
                "int32",
                "uint64",
                "int64",
            )
        else:
            invalid_test_tensor_dtype = array.dtype.name != descr.dtype

        if invalid_test_tensor_dtype:
            raise ValueError(
                f"{tensor_origin} data type '{array.dtype.name}' does not"
                + f" match described {e_msg_location(descr)}.dtype '{descr.dtype}'"
            )

        # all-near-zero tensors make reproducibility testing meaningless;
        # message now states the actual thresholds checked above (was
        # inconsistently reporting 1e5)
        if array.min() > -1e-4 and array.max() < 1e-4:
            raise ValueError(
                "Output values are too small for reliable testing."
                + f" Values <=-1e-4 or >=1e-4 must be present in {tensor_origin}"
            )

        for a in descr.axes:
            actual_size = all_tensor_axes[descr.id][a.id][1]
            if actual_size is None:
                continue

            if a.size is None:
                # e.g. batch axis without a fixed size: anything goes
                continue

            if isinstance(a.size, int):
                if actual_size != a.size:
                    raise ValueError(
                        f"{e_msg_location(descr)}.axes[{a.id}]: {tensor_origin} axis "
                        + f"has incompatible size {actual_size}, expected {a.size}"
                    )
            elif isinstance(a.size, (ParameterizedSize, DataDependentSize)):
                # both size kinds validate themselves; merged previously
                # duplicated branches
                _ = a.size.validate_size(
                    actual_size,
                    f"{e_msg_location(descr)}.axes[{a.id}]: {tensor_origin} axis ",
                )
            elif isinstance(a.size, SizeReference):
                ref_tensor_axes = all_tensor_axes.get(a.size.tensor_id)
                if ref_tensor_axes is None:
                    raise ValueError(
                        f"{e_msg_location(descr)}.axes[{a.id}].size.tensor_id: Unknown tensor"
                        + f" reference '{a.size.tensor_id}', available: {list(all_tensor_axes)}"
                    )

                ref_axis, ref_size = ref_tensor_axes.get(a.size.axis_id, (None, None))
                if ref_axis is None or ref_size is None:
                    raise ValueError(
                        f"{e_msg_location(descr)}.axes[{a.id}].size.axis_id: Unknown tensor axis"
                        + f" reference '{a.size.tensor_id}.{a.size.axis_id}, available: {list(ref_tensor_axes)}"
                    )

                if a.unit != ref_axis.unit:
                    raise ValueError(
                        f"{e_msg_location(descr)}.axes[{a.id}].size: `SizeReference` requires"
                        + " axis and reference axis to have the same `unit`, but"
                        + f" {a.unit}!={ref_axis.unit}"
                    )

                # referenced size scaled into this axis' unit plus offset
                if actual_size != (
                    expected_size := (
                        ref_size * ref_axis.scale / a.scale + a.size.offset
                    )
                ):
                    raise ValueError(
                        f"{e_msg_location(descr)}.{tensor_origin}: axis '{a.id}' of size"
                        + f" {actual_size} invalid for referenced size {ref_size};"
                        + f" expected {expected_size}"
                    )
            else:
                assert_never(a.size)

2381 

2382 

# file description constrained to conda environment files (.yaml/.yml)
# used for custom dependency specifications of weights entries
FileDescr_dependencies = Annotated[
    FileDescr_,
    WithSuffix((".yaml", ".yml"), case_sensitive=True),
    Field(examples=[dict(source="environment.yaml")]),
]

2388 

2389 

class _ArchitectureCallableDescr(Node):
    """Common base for architecture descriptions: a callable plus its kwargs."""

    callable: Annotated[Identifier, Field(examples=["MyNetworkClass", "get_my_model"])]
    """Identifier of the callable that returns a torch.nn.Module instance."""

    kwargs: Dict[str, YamlValue] = Field(
        default_factory=cast(Callable[[], Dict[str, YamlValue]], dict)
    )
    """key word arguments for the `callable`"""

2398 

2399 

class ArchitectureFromFileDescr(_ArchitectureCallableDescr, FileDescr):
    source: Annotated[FileSource, AfterValidator(wo_special_file_name)]
    """Architecture source file"""

    @model_serializer(mode="wrap", when_used="unless-none")
    def _serialize(self, nxt: SerializerFunctionWrapHandler, info: SerializationInfo):
        # ensure the architecture file is included when packaging
        return package_file_descr_serializer(self, nxt, info)

2407 

2408 

class ArchitectureFromLibraryDescr(_ArchitectureCallableDescr):
    import_from: str
    """Where to import the callable from, i.e. `from <import_from> import <callable>`"""

2412 

2413 

class _ArchFileConv(
    Converter[
        _CallableFromFile_v0_4,
        ArchitectureFromFileDescr,
        Optional[Sha256],
        Dict[str, Any],
    ]
):
    """Converter from a v0.4 `<source>:<callable>` string to `ArchitectureFromFileDescr`."""

    def _convert(
        self,
        src: _CallableFromFile_v0_4,
        tgt: "type[ArchitectureFromFileDescr | dict[str, Any]]",
        sha256: Optional[Sha256],
        kwargs: Dict[str, Any],
    ) -> "ArchitectureFromFileDescr | dict[str, Any]":
        # Split "<source>:<callable>" on the last colon; a URL source keeps
        # the colon of its scheme ("http://..."), hence the colon-count guard.
        looks_like_url = src.startswith("http")
        n_colons = src.count(":")
        if (looks_like_url and n_colons == 2) or (not looks_like_url and n_colons == 1):
            source, _, callable_ = src.rpartition(":")
        else:
            # no recognizable separator: treat whole string as both parts
            source = callable_ = str(src)
        return tgt(
            callable=Identifier(callable_),
            source=cast(FileSource_, source),
            sha256=sha256,
            kwargs=kwargs,
        )

2443 

2444 

# module-level converter instance for v0.4 file-based architecture callables
_arch_file_conv = _ArchFileConv(_CallableFromFile_v0_4, ArchitectureFromFileDescr)

2446 

2447 

class _ArchLibConv(
    Converter[
        _CallableFromDepencency_v0_4, ArchitectureFromLibraryDescr, Dict[str, Any]
    ]
):
    """Converter from a v0.4 dotted import path to `ArchitectureFromLibraryDescr`."""

    def _convert(
        self,
        src: _CallableFromDepencency_v0_4,
        tgt: "type[ArchitectureFromLibraryDescr | dict[str, Any]]",
        kwargs: Dict[str, Any],
    ) -> "ArchitectureFromLibraryDescr | dict[str, Any]":
        # "pkg.module.callable" -> import_from="pkg.module", callable="callable"
        # (rpartition yields an empty import_from for a bare callable name)
        import_from, _, callable_ = src.rpartition(".")
        return tgt(
            import_from=import_from, callable=Identifier(callable_), kwargs=kwargs
        )

2464 

2465 

# module-level converter instance for v0.4 dependency-based architecture callables
_arch_lib_conv = _ArchLibConv(
    _CallableFromDepencency_v0_4, ArchitectureFromLibraryDescr
)

2469 

2470 

class WeightsEntryDescrBase(FileDescr):
    # common base of all weights entries; `type` and `weights_format_name`
    # are fixed per subclass (ClassVar, not serialized)
    type: ClassVar[WeightsFormat]
    weights_format_name: ClassVar[str]  # human readable

    source: Annotated[FileSource, AfterValidator(wo_special_file_name)]
    """Source of the weights file."""

    authors: Optional[List[Author]] = None
    """Authors
    Either the person(s) that have trained this model resulting in the original weights file.
    (If this is the initial weights entry, i.e. it does not have a `parent`)
    Or the person(s) who have converted the weights to this weights format.
    (If this is a child weight, i.e. it has a `parent` field)
    """

    parent: Annotated[
        Optional[WeightsFormat], Field(examples=["pytorch_state_dict"])
    ] = None
    """The source weights these weights were converted from.
    For example, if a model's weights were converted from the `pytorch_state_dict` format to `torchscript`,
    The `pytorch_state_dict` weights entry has no `parent` and is the parent of the `torchscript` weights.
    All weight entries except one (the initial set of weights resulting from training the model),
    need to have this field."""

    comment: str = ""
    """A comment about this weights entry, for example how these weights were created."""

    @model_validator(mode="after")
    def _validate(self) -> Self:
        # an entry referencing its own format as `parent` would be a cycle
        if self.type == self.parent:
            raise ValueError("Weights entry can't be it's own parent.")

        return self

    @model_serializer(mode="wrap", when_used="unless-none")
    def _serialize(self, nxt: SerializerFunctionWrapHandler, info: SerializationInfo):
        # ensure the weights file is included when packaging
        return package_file_descr_serializer(self, nxt, info)

2508 

2509 

class KerasHdf5WeightsDescr(WeightsEntryDescrBase):
    # legacy Keras weights stored as a single HDF5 file
    type: ClassVar[WeightsFormat] = "keras_hdf5"
    weights_format_name: ClassVar[str] = "Keras HDF5"
    tensorflow_version: Version
    """TensorFlow version used to create these weights."""

2515 

2516 

class KerasV3WeightsDescr(WeightsEntryDescrBase):
    # Keras 3 native ".keras" archive format
    type: ClassVar[WeightsFormat] = "keras_v3"
    weights_format_name: ClassVar[str] = "Keras v3"
    keras_version: Annotated[Version, Ge(Version(3))]
    """Keras version used to create these weights."""
    backend: Tuple[Literal["tensorflow", "jax", "torch"], Version]
    """Keras backend used to create these weights."""
    source: Annotated[
        FileSource,
        AfterValidator(wo_special_file_name),
        WithSuffix(".keras", case_sensitive=True),
    ]
    """Source of the .keras weights file."""

2530 

2531 

# file description constrained to ONNX external-data files (".data" suffix)
FileDescr_external_data = Annotated[
    FileDescr_,
    WithSuffix(".data", case_sensitive=True),
    Field(examples=[dict(source="weights.onnx.data")]),
]

2537 

2538 

class OnnxWeightsDescr(WeightsEntryDescrBase):
    type: ClassVar[WeightsFormat] = "onnx"
    weights_format_name: ClassVar[str] = "ONNX"
    opset_version: Annotated[int, Ge(7)]
    """ONNX opset version"""

    external_data: Optional[FileDescr_external_data] = None
    """Source of the external ONNX data file holding the weights.
    (If present **source** holds the ONNX architecture without weights)."""

    @model_validator(mode="after")
    def _validate_external_data_unique_file_name(self) -> Self:
        # both files end up side by side in a package, so identical file
        # names would collide
        if self.external_data is not None and (
            extract_file_name(self.source)
            == extract_file_name(self.external_data.source)
        ):
            raise ValueError(
                f"ONNX `external_data` file name '{extract_file_name(self.external_data.source)}'"
                + " must be different from ONNX `source` file name."
            )

        return self

2561 

2562 

class PytorchStateDictWeightsDescr(WeightsEntryDescrBase):
    type: ClassVar[WeightsFormat] = "pytorch_state_dict"
    weights_format_name: ClassVar[str] = "Pytorch State Dict"
    architecture: Union[ArchitectureFromFileDescr, ArchitectureFromLibraryDescr]
    pytorch_version: Version
    """Version of the PyTorch library used.
    If `architecture.dependencies` is specified it has to include pytorch and any version pinning has to be compatible.
    """
    dependencies: Optional[FileDescr_dependencies] = None
    """Custom dependencies beyond pytorch described in a Conda environment file.
    Allows to specify custom dependencies, see conda docs:
    - [Exporting an environment file across platforms](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#exporting-an-environment-file-across-platforms)
    - [Creating an environment file manually](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#creating-an-environment-file-manually)

    The conda environment file should include pytorch and any version pinning has to be compatible with
    **pytorch_version**.
    """

2580 

2581 

class TensorflowJsWeightsDescr(WeightsEntryDescrBase):
    type: ClassVar[WeightsFormat] = "tensorflow_js"
    weights_format_name: ClassVar[str] = "Tensorflow.js"
    tensorflow_version: Version
    """Version of the TensorFlow library used."""

    source: Annotated[FileSource, AfterValidator(wo_special_file_name)]
    """The multi-file weights.
    All required files/folders should be a zip archive."""

2591 

2592 

class TensorflowSavedModelBundleWeightsDescr(WeightsEntryDescrBase):
    type: ClassVar[WeightsFormat] = "tensorflow_saved_model_bundle"
    weights_format_name: ClassVar[str] = "Tensorflow Saved Model"
    tensorflow_version: Version
    """Version of the TensorFlow library used."""

    dependencies: Optional[FileDescr_dependencies] = None
    """Custom dependencies beyond tensorflow.
    Should include tensorflow and any version pinning has to be compatible with **tensorflow_version**."""

    source: Annotated[FileSource, AfterValidator(wo_special_file_name)]
    """The multi-file weights.
    All required files/folders should be a zip archive."""

2606 

2607 

class TorchscriptWeightsDescr(WeightsEntryDescrBase):
    # TorchScript archive exported via torch.jit
    type: ClassVar[WeightsFormat] = "torchscript"
    weights_format_name: ClassVar[str] = "TorchScript"
    pytorch_version: Version
    """Version of the PyTorch library used."""

2613 

2614 

# union of all concrete weights entry descriptions (one per weights format)
SpecificWeightsDescr = Union[
    KerasHdf5WeightsDescr,
    KerasV3WeightsDescr,
    OnnxWeightsDescr,
    PytorchStateDictWeightsDescr,
    TensorflowJsWeightsDescr,
    TensorflowSavedModelBundleWeightsDescr,
    TorchscriptWeightsDescr,
]

2624 

2625 

class WeightsDescr(Node):
    # one optional field per supported weights format; at least one must be set
    keras_hdf5: Optional[KerasHdf5WeightsDescr] = None
    keras_v3: Optional[KerasV3WeightsDescr] = None
    onnx: Optional[OnnxWeightsDescr] = None
    pytorch_state_dict: Optional[PytorchStateDictWeightsDescr] = None
    tensorflow_js: Optional[TensorflowJsWeightsDescr] = None
    tensorflow_saved_model_bundle: Optional[TensorflowSavedModelBundleWeightsDescr] = (
        None
    )
    torchscript: Optional[TorchscriptWeightsDescr] = None

    @model_validator(mode="after")
    def check_entries(self) -> Self:
        # validate that at least one entry exists, that exactly one entry is
        # the "original" (no `parent`), and that every `parent` reference
        # points to another specified entry
        entries = {wtype for wtype, entry in self if entry is not None}

        if not entries:
            raise ValueError("Missing weights entry")

        entries_wo_parent = {
            wtype
            for wtype, entry in self
            if entry is not None and hasattr(entry, "parent") and entry.parent is None
        }
        if len(entries_wo_parent) != 1:
            # warning only: older/partial descriptions may violate this
            issue_warning(
                "Exactly one weights entry may not specify the `parent` field (got"
                + " {value}). That entry is considered the original set of model weights."
                + " Other weight formats are created through conversion of the orignal or"
                + " already converted weights. They have to reference the weights format"
                + " they were converted from as their `parent`.",
                value=len(entries_wo_parent),
                field="weights",
            )

        for wtype, entry in self:
            if entry is None:
                continue

            assert hasattr(entry, "type")
            assert hasattr(entry, "parent")
            assert wtype == entry.type
            if (
                entry.parent is not None and entry.parent not in entries
            ):  # self reference checked for `parent` field
                raise ValueError(
                    f"`weights.{wtype}.parent={entry.parent} not in specified weight"
                    + f" formats: {entries}"
                )

        return self

    def __getitem__(
        self,
        key: WeightsFormat,
    ):
        # dict-like access by weights format name; raises KeyError both for
        # unknown keys and for formats that are not specified (None)
        if key == "keras_hdf5":
            ret = self.keras_hdf5
        elif key == "keras_v3":
            ret = self.keras_v3
        elif key == "onnx":
            ret = self.onnx
        elif key == "pytorch_state_dict":
            ret = self.pytorch_state_dict
        elif key == "tensorflow_js":
            ret = self.tensorflow_js
        elif key == "tensorflow_saved_model_bundle":
            ret = self.tensorflow_saved_model_bundle
        elif key == "torchscript":
            ret = self.torchscript
        else:
            raise KeyError(key)

        if ret is None:
            raise KeyError(key)

        return ret

    # overloads pair each format key with its specific descr type
    @overload
    def __setitem__(
        self, key: Literal["keras_hdf5"], value: Optional[KerasHdf5WeightsDescr]
    ) -> None: ...
    @overload
    def __setitem__(
        self, key: Literal["keras_v3"], value: Optional[KerasV3WeightsDescr]
    ) -> None: ...
    @overload
    def __setitem__(
        self, key: Literal["onnx"], value: Optional[OnnxWeightsDescr]
    ) -> None: ...
    @overload
    def __setitem__(
        self,
        key: Literal["pytorch_state_dict"],
        value: Optional[PytorchStateDictWeightsDescr],
    ) -> None: ...
    @overload
    def __setitem__(
        self, key: Literal["tensorflow_js"], value: Optional[TensorflowJsWeightsDescr]
    ) -> None: ...
    @overload
    def __setitem__(
        self,
        key: Literal["tensorflow_saved_model_bundle"],
        value: Optional[TensorflowSavedModelBundleWeightsDescr],
    ) -> None: ...
    @overload
    def __setitem__(
        self, key: Literal["torchscript"], value: Optional[TorchscriptWeightsDescr]
    ) -> None: ...

    def __setitem__(
        self,
        key: WeightsFormat,
        value: Optional[SpecificWeightsDescr],
    ):
        # dict-like assignment with runtime type checks matching the overloads
        if key == "keras_hdf5":
            if value is not None and not isinstance(value, KerasHdf5WeightsDescr):
                raise TypeError(
                    f"Expected KerasHdf5WeightsDescr or None for key 'keras_hdf5', got {type(value)}"
                )
            self.keras_hdf5 = value
        elif key == "keras_v3":
            if value is not None and not isinstance(value, KerasV3WeightsDescr):
                raise TypeError(
                    f"Expected KerasV3WeightsDescr or None for key 'keras_v3', got {type(value)}"
                )
            self.keras_v3 = value
        elif key == "onnx":
            if value is not None and not isinstance(value, OnnxWeightsDescr):
                raise TypeError(
                    f"Expected OnnxWeightsDescr or None for key 'onnx', got {type(value)}"
                )
            self.onnx = value
        elif key == "pytorch_state_dict":
            if value is not None and not isinstance(
                value, PytorchStateDictWeightsDescr
            ):
                raise TypeError(
                    f"Expected PytorchStateDictWeightsDescr or None for key 'pytorch_state_dict', got {type(value)}"
                )
            self.pytorch_state_dict = value
        elif key == "tensorflow_js":
            if value is not None and not isinstance(value, TensorflowJsWeightsDescr):
                raise TypeError(
                    f"Expected TensorflowJsWeightsDescr or None for key 'tensorflow_js', got {type(value)}"
                )
            self.tensorflow_js = value
        elif key == "tensorflow_saved_model_bundle":
            if value is not None and not isinstance(
                value, TensorflowSavedModelBundleWeightsDescr
            ):
                raise TypeError(
                    f"Expected TensorflowSavedModelBundleWeightsDescr or None for key 'tensorflow_saved_model_bundle', got {type(value)}"
                )
            self.tensorflow_saved_model_bundle = value
        elif key == "torchscript":
            if value is not None and not isinstance(value, TorchscriptWeightsDescr):
                raise TypeError(
                    f"Expected TorchscriptWeightsDescr or None for key 'torchscript', got {type(value)}"
                )
            self.torchscript = value
        else:
            raise KeyError(key)

    @property
    def available_formats(self) -> Dict[WeightsFormat, SpecificWeightsDescr]:
        # mapping of all specified (non-None) entries keyed by format name
        return {
            **({} if self.keras_hdf5 is None else {"keras_hdf5": self.keras_hdf5}),
            **({} if self.keras_v3 is None else {"keras_v3": self.keras_v3}),
            **({} if self.onnx is None else {"onnx": self.onnx}),
            **(
                {}
                if self.pytorch_state_dict is None
                else {"pytorch_state_dict": self.pytorch_state_dict}
            ),
            **(
                {}
                if self.tensorflow_js is None
                else {"tensorflow_js": self.tensorflow_js}
            ),
            **(
                {}
                if self.tensorflow_saved_model_bundle is None
                else {
                    "tensorflow_saved_model_bundle": self.tensorflow_saved_model_bundle
                }
            ),
            **({} if self.torchscript is None else {"torchscript": self.torchscript}),
        }

    @property
    def missing_formats(self) -> Set[WeightsFormat]:
        # all known weights formats without a specified entry
        return {
            wf for wf in get_args(WeightsFormat) if wf not in self.available_formats
        }

2821 

2822 

class ModelId(ResourceId):
    # distinct subclass so model ids are type-distinguishable from other
    # resource ids
    pass

2825 

2826 

class LinkedModel(LinkedResourceBase):
    """Reference to a bioimage.io model."""

    id: ModelId
    """A valid model `id` from the bioimage.io collection."""

2832 

2833 

class _DataDepSize(NamedTuple):
    """inclusive size bounds of a data-dependent axis (`max` may be open)"""

    min: StrictInt
    max: Optional[StrictInt]

2837 

2838 

class _AxisSizes(NamedTuple):
    """the lengths of all axes of model inputs and outputs"""

    inputs: Dict[Tuple[TensorId, AxisId], int]
    outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]]

2844 

2845 

class _TensorSizes(NamedTuple):
    """_AxisSizes as nested dicts"""

    inputs: Dict[TensorId, Dict[AxisId, int]]
    outputs: Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]

2851 

2852 

class ReproducibilityTolerance(Node, extra="allow"):
    """Describes what small numerical differences -- if any -- may be tolerated
    in the generated output when executing in different environments.

    A tensor element *output* is considered mismatched to the **test_tensor** if
    abs(*output* - **test_tensor**) > **absolute_tolerance** + **relative_tolerance** * abs(**test_tensor**).
    (Internally we call [numpy.testing.assert_allclose](https://numpy.org/doc/stable/reference/generated/numpy.testing.assert_allclose.html).)

    Motivation:
        For testing we can request the respective deep learning frameworks to be as
        reproducible as possible by setting seeds and choosing deterministic algorithms,
        but differences in operating systems, available hardware and installed drivers
        may still lead to numerical differences.
    """

    relative_tolerance: RelativeTolerance = 1e-3
    """Maximum relative tolerance of reproduced test tensor."""

    absolute_tolerance: AbsoluteTolerance = 1e-3
    """Maximum absolute tolerance of reproduced test tensor."""

    mismatched_elements_per_million: MismatchedElementsPerMillion = 100
    """Maximum number of mismatched elements/pixels per million to tolerate."""

    output_ids: Sequence[TensorId] = ()
    """Limits the output tensor IDs these reproducibility details apply to."""

    weights_formats: Sequence[WeightsFormat] = ()
    """Limits the weights formats these details apply to."""

2882 

2883 

class BiasRisksLimitations(Node, extra="allow"):
    """Known biases, risks, technical limitations, and recommendations for model use."""

    known_biases: str = dedent("""\
        In general bioimage models may suffer from biases caused by:

        - Imaging protocol dependencies
        - Use of a specific cell type
        - Species-specific training data limitations

        """)
    """Biases in training data or model behavior."""

    risks: str = dedent("""\
        Common risks in bioimage analysis include:

        - Erroneously assuming generalization to unseen experimental conditions
        - Trusting (overconfident) model outputs without validation
        - Misinterpretation of results

        """)
    """Potential risks in the context of bioimage analysis."""

    limitations: Optional[str] = None
    """Technical limitations and failure modes."""

    recommendations: str = "Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model."
    """Mitigation strategies regarding `known_biases`, `risks`, and `limitations`, as well as applicable best practices.

    Consider:
    - How to use a validation dataset?
    - How to manually validate?
    - Feasibility of domain adaptation for different experimental setups?

    """

    def format_md(self) -> str:
        # render the fields as a markdown section; the "Limitations" heading
        # is only emitted when `limitations` is set
        if self.limitations is None:
            limitations_header = ""
        else:
            limitations_header = "## Limitations\n\n"

        return f"""# Bias, Risks, and Limitations

{self.known_biases}

{self.risks}

{limitations_header}{self.limitations or ""}

## Recommendations

{self.recommendations}

"""

2939 

2940 

class TrainingDetails(Node, extra="allow"):
    # Free-form record of the training setup; every field is optional.
    training_preprocessing: Optional[str] = None
    """Detailed image preprocessing steps during model training:

    Mention:
    - *Normalization methods*
    - *Augmentation strategies*
    - *Resizing/resampling procedures*
    - *Artifact handling*

    """

    training_epochs: Optional[float] = None
    """Number of training epochs."""

    training_batch_size: Optional[float] = None
    """Batch size used in training."""

    initial_learning_rate: Optional[float] = None
    """Initial learning rate used in training."""

    learning_rate_schedule: Optional[str] = None
    """Learning rate schedule used in training."""

    loss_function: Optional[str] = None
    """Loss function used in training, e.g. nn.MSELoss."""

    loss_function_kwargs: Dict[str, YamlValue] = Field(
        default_factory=cast(Callable[[], Dict[str, YamlValue]], dict)
    )
    """key word arguments for the `loss_function`"""

    optimizer: Optional[str] = None
    """optimizer, e.g. torch.optim.Adam"""

    optimizer_kwargs: Dict[str, YamlValue] = Field(
        default_factory=cast(Callable[[], Dict[str, YamlValue]], dict)
    )
    """key word arguments for the `optimizer`"""

    regularization: Optional[str] = None
    """Regularization techniques used during training, e.g. drop-out or weight decay."""

    training_duration: Optional[float] = None
    """Total training duration in hours."""

2986 

2987 

class Evaluation(Node, extra="allow"):
    # Quantitative evaluation of a model on one dataset; rendered to Markdown
    # via `format_md` for the "Testing Data, Factors & Metrics" card section.
    model_id: Optional[ModelId] = None
    """Model being evaluated."""

    dataset_id: DatasetId
    """Dataset used for evaluation."""

    dataset_source: HttpUrl
    """Source of the dataset."""

    dataset_role: Literal["train", "validation", "test", "independent", "unknown"]
    """Role of the dataset used for evaluation.

    - `train`: dataset was (part of) the training data
    - `validation`: dataset was (part of) the validation data used during training, e.g. used for model selection or hyperparameter tuning
    - `test`: dataset was (part of) the designated test data; not used during training or validation, but acquired from the same source/distribution as training data
    - `independent`: dataset is entirely independent test data; not used during training or validation, and acquired from a different source/distribution than training data
    - `unknown`: role of the dataset is unknown; choose this if you are not certain if (a subset) of the data was seen by the model during training.
    """

    sample_count: int
    """Number of evaluated samples."""

    evaluation_factors: List[Annotated[str, MaxLen(16)]]
    """(Abbreviations of) each evaluation factor.

    Evaluation factors are criteria along which model performance is evaluated, e.g. different image conditions
    like 'low SNR', 'high cell density', or different biological conditions like 'cell type A', 'cell type B'.
    An 'overall' factor may be included to summarize performance across all conditions.
    """

    evaluation_factors_long: List[str]
    """Descriptions (long form) of each evaluation factor."""

    metrics: List[Annotated[str, MaxLen(16)]]
    """(Abbreviations of) metrics used for evaluation."""

    metrics_long: List[str]
    """Description of each metric used."""

    # This runs after all fields (including `results`, declared further below)
    # are set, so referencing `self.results` here is safe despite the
    # declaration order.
    @model_validator(mode="after")
    def _validate_list_lengths(self) -> Self:
        # short names must pair 1:1 with their long-form descriptions, and
        # `results` must form a len(metrics) x len(evaluation_factors) table
        if len(self.evaluation_factors) != len(self.evaluation_factors_long):
            raise ValueError(
                "`evaluation_factors` and `evaluation_factors_long` must have the same length"
            )

        if len(self.metrics) != len(self.metrics_long):
            raise ValueError("`metrics` and `metrics_long` must have the same length")

        if len(self.results) != len(self.metrics):
            raise ValueError("`results` must have the same number of rows as `metrics`")

        for row in self.results:
            if len(row) != len(self.evaluation_factors):
                raise ValueError(
                    "`results` must have the same number of columns (in every row) as `evaluation_factors`"
                )

        return self

    results: List[List[Union[str, float, int]]]
    """Results for each metric (rows; outer list) and each evaluation factor (columns; inner list)."""

    results_summary: Optional[str] = None
    """Interpretation of results for general audience.

    Consider:
    - Overall model performance
    - Comparison to existing methods
    - Limitations and areas for improvement
"""

    def format_md(self):
        """Render this evaluation as a Markdown section with a results table,
        and listings of evaluation factors and metrics."""
        # Markdown table: header row, '---' separator row, one row per metric.
        results_header = ["Metric"] + self.evaluation_factors
        results_table_cells = [results_header, ["---"] * len(results_header)] + [
            [metric] + [str(r) for r in row]
            for metric, row in zip(self.metrics, self.results)
        ]

        results_table = "".join(
            "| " + " | ".join(row) + " |\n" for row in results_table_cells
        )
        factors = "".join(
            f"\n - {ef}: {efl}"
            for ef, efl in zip(self.evaluation_factors, self.evaluation_factors_long)
        )
        metrics = "".join(
            f"\n - {em}: {eml}" for em, eml in zip(self.metrics, self.metrics_long)
        )

        return f"""## Testing Data, Factors & Metrics

Evaluation of {self.model_id or "this"} model on the {self.dataset_id} dataset (dataset role: {self.dataset_role}).

### Testing Data

- **Source:** [{self.dataset_id}]({self.dataset_source})
- **Size:** {self.sample_count} evaluated samples

### Factors
{factors}

### Metrics
{metrics}

## Results

### Quantitative Results

{results_table}

### Summary

{self.results_summary or "missing"}

"""

3106 

3107 

class EnvironmentalImpact(Node, extra="allow"):
    """Environmental considerations for model training and deployment.

    Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
    """

    hardware_type: Optional[str] = None
    """GPU/CPU specifications"""

    hours_used: Optional[float] = None
    """Total compute hours"""

    cloud_provider: Optional[str] = None
    """If applicable"""

    compute_region: Optional[str] = None
    """Geographic location"""

    co2_emitted: Optional[float] = None
    """kg CO2 equivalent

    Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
    """

    def format_md(self):
        """Filled Markdown template section following [Hugging Face Model Card Template](https://huggingface.co/docs/hub/en/model-card-annotated)."""
        # an all-default instance renders as an empty section -> emit nothing
        if self == self.__class__():
            return ""

        bullet_points = [
            ("Hardware Type", self.hardware_type, ""),
            ("Hours used", self.hours_used, ""),
            ("Cloud Provider", self.cloud_provider, ""),
            ("Compute Region", self.compute_region, ""),
            ("Carbon Emitted", self.co2_emitted, " kg CO2e"),
        ]
        body = "".join(
            f"- **{label}:** {value}{suffix}\n"
            for label, value, suffix in bullet_points
            if value is not None
        )
        return "# Environmental Impact\n\n" + body + "\n"

3150 

3151 

class BioimageioConfig(Node, extra="allow"):
    # bioimage.io-specific extra metadata (model card details, reproducibility
    # tolerances); `extra="allow"` keeps unknown keys.
    reproducibility_tolerance: Sequence[ReproducibilityTolerance] = ()
    """Tolerances to allow when reproducing the model's test outputs
    from the model's test inputs.
    Only the first entry matching tensor id and weights format is considered.
    """

    funded_by: Optional[str] = None
    """Funding agency, grant number if applicable"""

    architecture_type: Optional[Annotated[str, MaxLen(32)]] = (
        None  # TODO: add to differentiated tags
    )
    """Model architecture type, e.g., 3D U-Net, ResNet, transformer"""

    architecture_description: Optional[str] = None
    """Text description of model architecture."""

    modality: Optional[str] = None  # TODO: add to differentiated tags
    """Input modality, e.g., fluorescence microscopy, electron microscopy"""

    target_structure: List[str] = Field(  # TODO: add to differentiated tags
        default_factory=cast(Callable[[], List[str]], list)
    )
    """Biological structure(s) the model is designed to analyze, e.g., nuclei, mitochondria, cells"""

    task: Optional[str] = None  # TODO: add to differentiated tags
    """Bioimage-specific task type, e.g., segmentation, classification, detection, denoising"""

    new_version: Optional[ModelId] = None
    """A new version of this model exists with a different model id."""

    out_of_scope_use: Optional[str] = None
    """Describe how the model may be misused in bioimage analysis contexts and what users should **not** do with the model."""

    bias_risks_limitations: BiasRisksLimitations = Field(
        default_factory=BiasRisksLimitations.model_construct
    )
    """Description of known bias, risks, and technical limitations for in-scope model use."""

    model_parameter_count: Optional[int] = None
    """Total number of model parameters."""

    training: TrainingDetails = Field(default_factory=TrainingDetails.model_construct)
    """Details on how the model was trained."""

    inference_time: Optional[str] = None
    """Average inference time per image/tile. Specify hardware and image size. Multiple examples can be given."""

    memory_requirements_inference: Optional[str] = None
    """GPU memory needed for inference. Multiple examples with different image size can be given."""

    memory_requirements_training: Optional[str] = None
    """GPU memory needed for training. Multiple examples with different image/batch sizes can be given."""

    evaluations: List[Evaluation] = Field(
        default_factory=cast(Callable[[], List[Evaluation]], list)
    )
    """Quantitative model evaluations.

    Note:
        At the moment we recommend to include only a single test dataset
        (with evaluation factors that may mark subsets of the dataset)
        to avoid confusion and make the presentation of results cleaner.
    """

    environmental_impact: EnvironmentalImpact = Field(
        default_factory=EnvironmentalImpact.model_construct
    )
    """Environmental considerations for model training and deployment"""

3222 

3223 

class Config(Node, extra="allow"):
    # Top-level `config` section; `extra="allow"` permits arbitrary
    # tool-specific sub-sections next to the known ones below.
    bioimageio: BioimageioConfig = Field(
        default_factory=BioimageioConfig.model_construct
    )
    # tool-specific StarDist configuration; structure is not specified here
    stardist: YamlValue = None

3229 

3230 

class ModelDescr(GenericModelDescrBase):
    """Specification of the fields used in a bioimage.io-compliant RDF to describe AI models with pretrained weights.
    These fields are typically stored in a YAML file which we call a model resource description file (model RDF).
    """

    implemented_format_version: ClassVar[Literal["0.5.9"]] = "0.5.9"
    # Static type checkers see a defaulted field; at runtime the field has no
    # default and is therefore required.
    if TYPE_CHECKING:
        format_version: Literal["0.5.9"] = "0.5.9"
    else:
        format_version: Literal["0.5.9"]
    """Version of the bioimage.io model description specification used.
    When creating a new model always use the latest micro/patch version described here.
    The `format_version` is important for any consumer software to understand how to parse the fields.
    """

    implemented_type: ClassVar[Literal["model"]] = "model"
    if TYPE_CHECKING:
        type: Literal["model"] = "model"
    else:
        type: Literal["model"]
    """Specialized resource type 'model'"""

    id: Optional[ModelId] = None
    """bioimage.io-wide unique resource identifier
    assigned by bioimage.io; version **un**specific."""

    authors: FAIR[List[Author]] = Field(
        default_factory=cast(Callable[[], List[Author]], list)
    )
    """The authors are the creators of the model RDF and the primary points of contact."""

    documentation: FAIR[Optional[FileSource_documentation]] = None
    """URL or relative path to a markdown file with additional documentation.
    The recommended documentation file name is `README.md`. An `.md` suffix is mandatory.
    The documentation should include a '#[#] Validation' (sub)section
    with details on how to quantitatively validate the model on unseen data."""

3267 

3268 @field_validator("documentation", mode="after") 

3269 @classmethod 

3270 def _validate_documentation( 

3271 cls, value: Optional[FileSource_documentation] 

3272 ) -> Optional[FileSource_documentation]: 

3273 if not get_validation_context().perform_io_checks or value is None: 

3274 return value 

3275 

3276 doc_reader = get_reader(value) 

3277 doc_content = doc_reader.read().decode(encoding="utf-8") 

3278 if not re.search("#.*[vV]alidation", doc_content): 

3279 issue_warning( 

3280 "No '# Validation' (sub)section found in {value}.", 

3281 value=value, 

3282 field="documentation", 

3283 ) 

3284 

3285 return value 

3286 

    inputs: NotEmpty[Sequence[InputTensorDescr]]
    """Describes the input tensors expected by this model."""

    @field_validator("inputs", mode="after")
    @classmethod
    def _validate_input_axes(
        cls, inputs: Sequence[InputTensorDescr]
    ) -> Sequence[InputTensorDescr]:
        """Validate each input axis; `SizeReference` sizes may only point to
        non-batch axes (of any input tensor) with fixed or parameterized size.
        """
        input_size_refs = cls._get_axes_with_independent_size(inputs)

        for i, ipt in enumerate(inputs):
            # candidate reference targets: this tensor's own independent axes
            # plus the independently-sized axes of all inputs
            valid_independent_refs: Dict[
                Tuple[TensorId, AxisId],
                Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
            ] = {
                **{
                    (ipt.id, a.id): (ipt, a, a.size)
                    for a in ipt.axes
                    if not isinstance(a, BatchAxis)
                    and isinstance(a.size, (int, ParameterizedSize))
                },
                **input_size_refs,
            }
            for a, ax in enumerate(ipt.axes):
                cls._validate_axis(
                    "inputs",
                    i=i,
                    tensor_id=ipt.id,
                    a=a,
                    axis=ax,
                    valid_independent_refs=valid_independent_refs,
                )
        return inputs

3320 

    @staticmethod
    def _validate_axis(
        field_name: str,
        i: int,
        tensor_id: TensorId,
        a: int,
        axis: AnyAxis,
        valid_independent_refs: Dict[
            Tuple[TensorId, AxisId],
            Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
        ],
    ):
        """Validate a single axis description.

        Axes with an independent size (batch, fixed int, parameterized or
        data-dependent) need no further checks. For a `SizeReference` size the
        referenced axis is checked for existence and compatibility, and halo
        constraints are verified.

        Args:
            field_name: "inputs" or "outputs"; used in error messages.
            i: index of the tensor within `field_name` (for error messages).
            tensor_id: id of the tensor the axis belongs to.
            a: index of the axis within the tensor's axes (for error messages).
            axis: the axis description to validate (may be mutated, see
                channel name expansion below).
            valid_independent_refs: all (tensor_id, axis_id) pairs a
                `SizeReference` may legally point to, mapped to
                (tensor descr, axis descr, size).

        Raises:
            ValueError: for an unresolvable, self-referencing or otherwise
                incompatible size reference, or for a halo too large for the
                minimal axis size.
        """
        if isinstance(axis, BatchAxis) or isinstance(
            axis.size, (int, ParameterizedSize, DataDependentSize)
        ):
            return
        elif not isinstance(axis.size, SizeReference):
            assert_never(axis.size)

        # validate axis.size SizeReference
        ref = (axis.size.tensor_id, axis.size.axis_id)
        if ref not in valid_independent_refs:
            raise ValueError(
                "Invalid tensor axis reference at"
                + f" {field_name}[{i}].axes[{a}].size: {axis.size}."
            )
        if ref == (tensor_id, axis.id):
            raise ValueError(
                "Self-referencing not allowed for"
                + f" {field_name}[{i}].axes[{a}].size: {axis.size}"
            )
        if axis.type == "channel":
            if valid_independent_refs[ref][1].type != "channel":
                raise ValueError(
                    "A channel axis' size may only reference another fixed size"
                    + " channel axis."
                )
            # a '{i}' placeholder in `channel_names` is expanded to one name
            # per channel (1-based); this requires the referenced channel axis
            # to have a fixed integer size
            if isinstance(axis.channel_names, str) and "{i}" in axis.channel_names:
                ref_size = valid_independent_refs[ref][2]
                assert isinstance(ref_size, int), (
                    "channel axis ref (another channel axis) has to specify fixed"
                    + " size"
                )
                generated_channel_names = [
                    Identifier(axis.channel_names.format(i=i))
                    for i in range(1, ref_size + 1)
                ]
                axis.channel_names = generated_channel_names

        if (ax_unit := getattr(axis, "unit", None)) != (
            ref_unit := getattr(valid_independent_refs[ref][1], "unit", None)
        ):
            raise ValueError(
                "The units of an axis and its reference axis need to match, but"
                + f" '{ax_unit}' != '{ref_unit}'."
            )
        ref_axis = valid_independent_refs[ref][1]
        if isinstance(ref_axis, BatchAxis):
            raise ValueError(
                f"Invalid reference axis '{ref_axis.id}' for {tensor_id}.{axis.id}"
                + " (a batch axis is not allowed as reference)."
            )

        if isinstance(axis, WithHalo):
            # the halo is subtracted from both sides (2 * halo); at least one
            # element must remain at the minimal (n=0) size
            min_size = axis.size.get_size(axis, ref_axis, n=0)
            if (min_size - 2 * axis.halo) < 1:
                raise ValueError(
                    f"axis {axis.id} with minimum size {min_size} is too small for halo"
                    + f" {axis.halo}."
                )

            # the halo scaled to the reference axis' resolution must be a
            # whole number of elements there
            ref_halo = axis.halo * axis.scale / ref_axis.scale
            if ref_halo != int(ref_halo):
                raise ValueError(
                    f"Inferred halo for {'.'.join(ref)} is not an integer ({ref_halo} ="
                    + f" {tensor_id}.{axis.id}.halo {axis.halo}"
                    + f" * {tensor_id}.{axis.id}.scale {axis.scale}"
                    + f" / {'.'.join(ref)}.scale {ref_axis.scale})."
                )

3400 

3401 def validate_input_tensors( 

3402 self, 

3403 sources: Union[ 

3404 Sequence[NDArray[Any]], Mapping[TensorId, Optional[NDArray[Any]]] 

3405 ], 

3406 ) -> Mapping[TensorId, Optional[NDArray[Any]]]: 

3407 """Check if the given input tensors match the model's input tensor descriptions. 

3408 This includes checks of tensor shapes and dtypes, but not of the actual values. 

3409 """ 

3410 if not isinstance(sources, collections.abc.Mapping): 

3411 sources = {descr.id: tensor for descr, tensor in zip(self.inputs, sources)} 

3412 

3413 tensors = {descr.id: (descr, sources.get(descr.id)) for descr in self.inputs} 

3414 validate_tensors(tensors) 

3415 

3416 return sources 

3417 

    @model_validator(mode="after")
    def _validate_test_tensors(self) -> Self:
        """Load all test tensors and validate them against their descriptions.

        Additionally rejects `reproducibility_tolerance` entries whose
        `absolute_tolerance` exceeds 1% of the maximum value of the matching
        output test tensor.
        """
        if not get_validation_context().perform_io_checks:
            return self

        test_inputs = {
            descr.id: (
                descr,
                None if descr.test_tensor is None else load_array(descr.test_tensor),
            )
            for descr in self.inputs
        }
        test_outputs = {
            descr.id: (
                descr,
                None if descr.test_tensor is None else load_array(descr.test_tensor),
            )
            for descr in self.outputs
        }

        validate_tensors({**test_inputs, **test_outputs}, tensor_origin="test_tensor")

        for rep_tol in self.config.bioimageio.reproducibility_tolerance:
            # a zero/unset absolute tolerance needs no sanity check
            if not rep_tol.absolute_tolerance:
                continue

            # an empty `output_ids` means the entry applies to all outputs
            if rep_tol.output_ids:
                out_arrays = {
                    k: v[1] for k, v in test_outputs.items() if k in rep_tol.output_ids
                }
            else:
                out_arrays = {k: v[1] for k, v in test_outputs.items()}

            for out_id, array in out_arrays.items():
                if array is None:
                    continue

                if rep_tol.absolute_tolerance > (max_test_value := array.max()) * 0.01:
                    raise ValueError(
                        "config.bioimageio.reproducibility_tolerance.absolute_tolerance="
                        + f"{rep_tol.absolute_tolerance} > 0.01*{max_test_value}"
                        + f" (1% of the maximum value of the test tensor '{out_id}')"
                    )

        return self

3463 

3464 @model_validator(mode="after") 

3465 def _validate_tensor_references_in_proc_kwargs(self, info: ValidationInfo) -> Self: 

3466 ipt_refs = {t.id for t in self.inputs} 

3467 missing_refs = [ 

3468 k["reference_tensor"] 

3469 for k in [p.kwargs for ipt in self.inputs for p in ipt.preprocessing] 

3470 + [p.kwargs for out in self.outputs for p in out.postprocessing] 

3471 if "reference_tensor" in k 

3472 and k["reference_tensor"] is not None 

3473 and k["reference_tensor"] not in ipt_refs 

3474 ] 

3475 

3476 if missing_refs: 

3477 raise ValueError( 

3478 f"`reference_tensor`s {missing_refs} not found. Valid input tensor" 

3479 + f" references are: {ipt_refs}." 

3480 ) 

3481 

3482 return self 

3483 

    # exceeding 64 characters only triggers an INFO warning;
    # the hard limit is 128 characters
    name: Annotated[
        str,
        RestrictCharacters(string.ascii_letters + string.digits + "_+- ()"),
        MinLen(5),
        MaxLen(128),
        warn(MaxLen(64), "Name longer than 64 characters.", INFO),
    ]
    """A human-readable name of this model.
    It should be no longer than 64 characters
    and may only contain letter, number, underscore, minus, parentheses and spaces.
    We recommend to chose a name that refers to the model's task and image modality.
    """

    outputs: NotEmpty[Sequence[OutputTensorDescr]]
    """Describes the output tensors."""

3499 

3500 @field_validator("outputs", mode="after") 

3501 @classmethod 

3502 def _validate_tensor_ids( 

3503 cls, outputs: Sequence[OutputTensorDescr], info: ValidationInfo 

3504 ) -> Sequence[OutputTensorDescr]: 

3505 tensor_ids = [ 

3506 t.id for t in info.data.get("inputs", []) + info.data.get("outputs", []) 

3507 ] 

3508 duplicate_tensor_ids: List[str] = [] 

3509 seen: Set[str] = set() 

3510 for t in tensor_ids: 

3511 if t in seen: 

3512 duplicate_tensor_ids.append(t) 

3513 

3514 seen.add(t) 

3515 

3516 if duplicate_tensor_ids: 

3517 raise ValueError(f"Duplicate tensor ids: {duplicate_tensor_ids}") 

3518 

3519 return outputs 

3520 

3521 @staticmethod 

3522 def _get_axes_with_parameterized_size( 

3523 io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]], 

3524 ): 

3525 return { 

3526 f"{t.id}.{a.id}": (t, a, a.size) 

3527 for t in io 

3528 for a in t.axes 

3529 if not isinstance(a, BatchAxis) and isinstance(a.size, ParameterizedSize) 

3530 } 

3531 

3532 @staticmethod 

3533 def _get_axes_with_independent_size( 

3534 io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]], 

3535 ): 

3536 return { 

3537 (t.id, a.id): (t, a, a.size) 

3538 for t in io 

3539 for a in t.axes 

3540 if not isinstance(a, BatchAxis) 

3541 and isinstance(a.size, (int, ParameterizedSize)) 

3542 } 

3543 

    @field_validator("outputs", mode="after")
    @classmethod
    def _validate_output_axes(
        cls, outputs: List[OutputTensorDescr], info: ValidationInfo
    ) -> List[OutputTensorDescr]:
        """Validate each output axis; unlike input axes, output
        `SizeReference`s may point to input axes as well as output axes."""
        input_size_refs = cls._get_axes_with_independent_size(
            info.data.get("inputs", [])
        )
        output_size_refs = cls._get_axes_with_independent_size(outputs)

        for i, out in enumerate(outputs):
            # candidate reference targets: this tensor's own independent axes
            # plus the independently-sized axes of all inputs and outputs
            valid_independent_refs: Dict[
                Tuple[TensorId, AxisId],
                Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
            ] = {
                **{
                    (out.id, a.id): (out, a, a.size)
                    for a in out.axes
                    if not isinstance(a, BatchAxis)
                    and isinstance(a.size, (int, ParameterizedSize))
                },
                **input_size_refs,
                **output_size_refs,
            }
            for a, ax in enumerate(out.axes):
                cls._validate_axis(
                    "outputs",
                    i,
                    out.id,
                    a,
                    ax,
                    valid_independent_refs=valid_independent_refs,
                )

        return outputs

3579 

    packaged_by: List[Author] = Field(
        default_factory=cast(Callable[[], List[Author]], list)
    )
    """The persons that have packaged and uploaded this model.
    Only required if those persons differ from the `authors`."""

    # `parent` is checked against `id` in `_validate_parent_is_not_self`
    parent: Optional[LinkedModel] = None
    """The model from which this model is derived, e.g. by fine-tuning the weights."""

3588 

3589 @model_validator(mode="after") 

3590 def _validate_parent_is_not_self(self) -> Self: 

3591 if self.parent is not None and self.parent.id == self.id: 

3592 raise ValueError("A model description may not reference itself as parent.") 

3593 

3594 return self 

3595 

    run_mode: Annotated[
        Optional[RunMode],
        warn(None, "Run mode '{value}' has limited support across consumer softwares."),
    ] = None
    """Custom run mode for this model: for more complex prediction procedures like test time
    data augmentation that currently cannot be expressed in the specification.
    No standard run modes are defined yet."""

    timestamp: Datetime = Field(default_factory=Datetime.now)
    """Timestamp in [ISO 8601](#https://en.wikipedia.org/wiki/ISO_8601) format
    with a few restrictions listed [here](https://docs.python.org/3/library/datetime.html#datetime.datetime.fromisoformat).
    (In Python a datetime object is valid, too)."""

    # left_to_right: the first union member that validates wins
    training_data: Annotated[
        Union[None, LinkedDataset, DatasetDescr, DatasetDescr02],
        Field(union_mode="left_to_right"),
    ] = None
    """The dataset used to train this model"""

    weights: Annotated[WeightsDescr, WrapSerializer(package_weights)]
    """The weights for this model.
    Weights can be given for different formats, but should otherwise be equivalent.
    The available weight formats determine which consumers can use this model."""

    # additional metadata; see `Config`/`BioimageioConfig`
    config: Config = Field(default_factory=Config.model_construct)

3621 

    @model_validator(mode="after")
    def _add_default_cover(self) -> Self:
        """Generate default cover image(s) from the test tensors if `covers` is empty.

        Only runs when IO checks are enabled. A failure to generate covers is
        reported as a warning on the `covers` field rather than raised.
        """
        if not get_validation_context().perform_io_checks or self.covers:
            return self

        try:
            generated_covers = generate_covers(
                [
                    (t, load_array(t.test_tensor))
                    for t in self.inputs
                    if t.test_tensor is not None
                ],
                [
                    (t, load_array(t.test_tensor))
                    for t in self.outputs
                    if t.test_tensor is not None
                ],
            )
        except Exception as e:
            issue_warning(
                "Failed to generate cover image(s): {e}",
                value=self.covers,
                msg_context=dict(e=e),
                field="covers",
            )
        else:
            self.covers.extend(generated_covers)

        return self

3651 

    def get_input_test_arrays(self) -> List[NDArray[Any]]:
        """Load and return the test arrays of all input tensors.

        Raises:
            ValueError: if any input description lacks a `test_tensor`.
        """
        return self._get_test_arrays(self.inputs)

3654 

    def get_output_test_arrays(self) -> List[NDArray[Any]]:
        """Load and return the test arrays of all output tensors.

        Raises:
            ValueError: if any output description lacks a `test_tensor`.
        """
        return self._get_test_arrays(self.outputs)

3657 

3658 @staticmethod 

3659 def _get_test_arrays( 

3660 io_descr: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]], 

3661 ): 

3662 ts: List[FileDescr] = [] 

3663 for d in io_descr: 

3664 if d.test_tensor is None: 

3665 raise ValueError( 

3666 f"Failed to get test arrays: description of '{d.id}' is missing a `test_tensor`." 

3667 ) 

3668 ts.append(d.test_tensor) 

3669 

3670 data = [load_array(t) for t in ts] 

3671 assert all(isinstance(d, np.ndarray) for d in data) 

3672 return data 

3673 

3674 @staticmethod 

3675 def get_batch_size(tensor_sizes: Mapping[TensorId, Mapping[AxisId, int]]) -> int: 

3676 batch_size = 1 

3677 tensor_with_batchsize: Optional[TensorId] = None 

3678 for tid in tensor_sizes: 

3679 for aid, s in tensor_sizes[tid].items(): 

3680 if aid != BATCH_AXIS_ID or s == 1 or s == batch_size: 

3681 continue 

3682 

3683 if batch_size != 1: 

3684 assert tensor_with_batchsize is not None 

3685 raise ValueError( 

3686 f"batch size mismatch for tensors '{tensor_with_batchsize}' ({batch_size}) and '{tid}' ({s})" 

3687 ) 

3688 

3689 batch_size = s 

3690 tensor_with_batchsize = tid 

3691 

3692 return batch_size 

3693 

3694 def get_output_tensor_sizes( 

3695 self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]] 

3696 ) -> Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]: 

3697 """Returns the tensor output sizes for given **input_sizes**. 

3698 Only if **input_sizes** has a valid input shape, the tensor output size is exact. 

3699 Otherwise it might be larger than the actual (valid) output""" 

3700 batch_size = self.get_batch_size(input_sizes) 

3701 ns = self.get_ns(input_sizes) 

3702 

3703 tensor_sizes = self.get_tensor_sizes(ns, batch_size=batch_size) 

3704 return tensor_sizes.outputs 

3705 

3706 def get_ns(self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]): 

3707 """get parameter `n` for each parameterized axis 

3708 such that the valid input size is >= the given input size""" 

3709 ret: Dict[Tuple[TensorId, AxisId], ParameterizedSize_N] = {} 

3710 axes = {t.id: {a.id: a for a in t.axes} for t in self.inputs} 

3711 for tid in input_sizes: 

3712 for aid, s in input_sizes[tid].items(): 

3713 size_descr = axes[tid][aid].size 

3714 if isinstance(size_descr, ParameterizedSize): 

3715 ret[(tid, aid)] = size_descr.get_n(s) 

3716 elif size_descr is None or isinstance(size_descr, (int, SizeReference)): 

3717 pass 

3718 else: 

3719 assert_never(size_descr) 

3720 

3721 return ret 

3722 

3723 def get_tensor_sizes( 

3724 self, ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], batch_size: int 

3725 ) -> _TensorSizes: 

3726 axis_sizes = self.get_axis_sizes(ns, batch_size=batch_size) 

3727 return _TensorSizes( 

3728 { 

3729 t: { 

3730 aa: axis_sizes.inputs[(tt, aa)] 

3731 for tt, aa in axis_sizes.inputs 

3732 if tt == t 

3733 } 

3734 for t in {tt for tt, _ in axis_sizes.inputs} 

3735 }, 

3736 { 

3737 t: { 

3738 aa: axis_sizes.outputs[(tt, aa)] 

3739 for tt, aa in axis_sizes.outputs 

3740 if tt == t 

3741 } 

3742 for t in {tt for tt, _ in axis_sizes.outputs} 

3743 }, 

3744 ) 

3745 

    def get_axis_sizes(
        self,
        ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N],
        batch_size: Optional[int] = None,
        *,
        max_input_shape: Optional[Mapping[Tuple[TensorId, AxisId], int]] = None,
    ) -> _AxisSizes:
        """Determine input and output block shape for scale factors **ns**
        of parameterized input sizes.

        Args:
            ns: Scale factor `n` for each axis (keyed by (tensor_id, axis_id))
                that is parameterized as `size = min + n * step`.
            batch_size: The desired size of the batch dimension.
                If given **batch_size** overwrites any batch size present in
                **max_input_shape**. Default 1.
            max_input_shape: Limits the derived block shapes.
                Each axis for which the input size, parameterized by `n`, is larger
                than **max_input_shape** is set to the minimal value `n_min` for which
                this is still true.
                Use this for small input samples or large values of **ns**.
                Or simply whenever you know the full input shape.

        Returns:
            Resolved axis sizes for model inputs and outputs.
        """
        max_input_shape = max_input_shape or {}
        if batch_size is None:
            # default to the batch size found in **max_input_shape** (if any),
            # otherwise fall back to 1
            for (_t_id, a_id), s in max_input_shape.items():
                if a_id == BATCH_AXIS_ID:
                    batch_size = s
                    break
            else:
                batch_size = 1

        # axis descriptions of all tensors, keyed by tensor id, then axis id
        all_axes = {
            t.id: {a.id: a for a in t.axes} for t in chain(self.inputs, self.outputs)
        }

        inputs: Dict[Tuple[TensorId, AxisId], int] = {}
        outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]] = {}

        # NOTE: closes over `t_descr` of the resolution loops below; it must
        # only be called from within those loops.
        def get_axis_size(a: Union[InputAxis, OutputAxis]):
            if isinstance(a, BatchAxis):
                if (t_descr.id, a.id) in ns:
                    logger.warning(
                        "Ignoring unexpected size increment factor (n) for batch axis"
                        + " of tensor '{}'.",
                        t_descr.id,
                    )
                return batch_size
            elif isinstance(a.size, int):
                # fixed size axis: `n` is meaningless here
                if (t_descr.id, a.id) in ns:
                    logger.warning(
                        "Ignoring unexpected size increment factor (n) for fixed size"
                        + " axis '{}' of tensor '{}'.",
                        a.id,
                        t_descr.id,
                    )
                return a.size
            elif isinstance(a.size, ParameterizedSize):
                if (t_descr.id, a.id) not in ns:
                    raise ValueError(
                        "Size increment factor (n) missing for parametrized axis"
                        + f" '{a.id}' of tensor '{t_descr.id}'."
                    )
                n = ns[(t_descr.id, a.id)]
                # clamp `n` such that the resulting size does not exceed
                # **max_input_shape** (if a limit is given for this axis)
                s_max = max_input_shape.get((t_descr.id, a.id))
                if s_max is not None:
                    n = min(n, a.size.get_n(s_max))

                return a.size.get_size(n)

            elif isinstance(a.size, SizeReference):
                if (t_descr.id, a.id) in ns:
                    logger.warning(
                        "Ignoring unexpected size increment factor (n) for axis '{}'"
                        + " of tensor '{}' with size reference.",
                        a.id,
                        t_descr.id,
                    )
                assert not isinstance(a, BatchAxis)
                ref_axis = all_axes[a.size.tensor_id][a.size.axis_id]
                assert not isinstance(ref_axis, BatchAxis)
                ref_key = (a.size.tensor_id, a.size.axis_id)
                # the referenced axis size must already be resolved; the phased
                # resolution order below guarantees this for valid descriptions
                ref_size = inputs.get(ref_key, outputs.get(ref_key))
                assert ref_size is not None, ref_key
                assert not isinstance(ref_size, _DataDepSize), ref_key
                return a.size.get_size(
                    axis=a,
                    ref_axis=ref_axis,
                    ref_size=ref_size,
                )
            elif isinstance(a.size, DataDependentSize):
                if (t_descr.id, a.id) in ns:
                    logger.warning(
                        "Ignoring unexpected increment factor (n) for data dependent"
                        + " size axis '{}' of tensor '{}'.",
                        a.id,
                        t_descr.id,
                    )
                # data dependent sizes cannot be resolved to a number here;
                # return the (min, max) bounds instead
                return _DataDepSize(a.size.min, a.size.max)
            else:
                assert_never(a.size)

        # first resolve all but the `SizeReference` input sizes
        for t_descr in self.inputs:
            for a in t_descr.axes:
                if not isinstance(a.size, SizeReference):
                    s = get_axis_size(a)
                    assert not isinstance(s, _DataDepSize)
                    inputs[t_descr.id, a.id] = s

        # resolve all other input axis sizes
        for t_descr in self.inputs:
            for a in t_descr.axes:
                if isinstance(a.size, SizeReference):
                    s = get_axis_size(a)
                    assert not isinstance(s, _DataDepSize)
                    inputs[t_descr.id, a.id] = s

        # resolve all output axis sizes
        for t_descr in self.outputs:
            for a in t_descr.axes:
                assert not isinstance(a.size, ParameterizedSize)
                s = get_axis_size(a)
                outputs[t_descr.id, a.id] = s

        return _AxisSizes(inputs=inputs, outputs=outputs)

3875 

    @model_validator(mode="before")
    @classmethod
    def _convert(cls, data: Dict[str, Any]) -> Dict[str, Any]:
        # pydantic pre-validation hook: upgrade raw metadata written in an
        # older format version in place, then hand the dict on to validation
        cls.convert_from_old_format_wo_validation(data)
        return data

3881 

3882 @classmethod 

3883 def convert_from_old_format_wo_validation(cls, data: Dict[str, Any]) -> None: 

3884 """Convert metadata following an older format version to this classes' format 

3885 without validating the result. 

3886 """ 

3887 if ( 

3888 data.get("type") == "model" 

3889 and isinstance(fv := data.get("format_version"), str) 

3890 and fv.count(".") == 2 

3891 ): 

3892 fv_parts = fv.split(".") 

3893 if any(not p.isdigit() for p in fv_parts): 

3894 return 

3895 

3896 fv_tuple = tuple(map(int, fv_parts)) 

3897 

3898 assert cls.implemented_format_version_tuple[0:2] == (0, 5) 

3899 if fv_tuple[:2] in ((0, 3), (0, 4)): 

3900 m04 = _ModelDescr_v0_4.load(data) 

3901 if isinstance(m04, InvalidDescr): 

3902 try: 

3903 updated = _model_conv.convert_as_dict( 

3904 m04 # pyright: ignore[reportArgumentType] 

3905 ) 

3906 except Exception as e: 

3907 logger.error( 

3908 "Failed to convert from invalid model 0.4 description." 

3909 + f"\nerror: {e}" 

3910 + "\nProceeding with model 0.5 validation without conversion." 

3911 ) 

3912 updated = None 

3913 else: 

3914 updated = _model_conv.convert_as_dict(m04) 

3915 

3916 if updated is not None: 

3917 data.clear() 

3918 data.update(updated) 

3919 

3920 elif fv_tuple[:2] == (0, 5): 

3921 # bump patch version 

3922 data["format_version"] = cls.implemented_format_version 

3923 

3924 

class _ModelConv(Converter[_ModelDescr_v0_4, ModelDescr]):
    """Converts a model description from format 0.4 to format 0.5."""

    def _convert(
        self, src: _ModelDescr_v0_4, tgt: "type[ModelDescr] | type[dict[str, Any]]"
    ) -> "ModelDescr | dict[str, Any]":
        # sanitize the name: replace every character not allowed in 0.5 model
        # names by a space
        name = "".join(
            c if c in string.ascii_letters + string.digits + "_+- ()" else " "
            for c in src.name
        )

        def conv_authors(auths: Optional[Sequence[_Author_v0_4]]):
            # under TYPE_CHECKING use the typed converter so static analysis
            # sees proper types; at runtime produce plain dicts
            conv = (
                _author_conv.convert if TYPE_CHECKING else _author_conv.convert_as_dict
            )
            return None if auths is None else [conv(a) for a in auths]

        # same typed-vs-dict switch for architecture converters
        if TYPE_CHECKING:
            arch_file_conv = _arch_file_conv.convert
            arch_lib_conv = _arch_lib_conv.convert
        else:
            arch_file_conv = _arch_file_conv.convert_as_dict
            arch_lib_conv = _arch_lib_conv.convert_as_dict

        # per-input-tensor mapping of axis name -> (minimal) size, used to
        # resolve 0.4 implicit/referenced shapes during tensor conversion
        input_size_refs = {
            ipt.name: {
                a: s
                for a, s in zip(
                    ipt.axes,
                    (
                        ipt.shape.min
                        if isinstance(ipt.shape, _ParameterizedInputShape_v0_4)
                        else ipt.shape
                    ),
                )
            }
            for ipt in src.inputs
            if ipt.shape
        }
        # outputs may reference explicit output shapes as well as input shapes
        output_size_refs = {
            **{
                out.name: {a: s for a, s in zip(out.axes, out.shape)}
                for out in src.outputs
                if not isinstance(out.shape, _ImplicitOutputShape_v0_4)
            },
            **input_size_refs,
        }

        # `tgt` is either `ModelDescr` or `dict`; in both cases the keyword
        # arguments below form the converted 0.5 content
        return tgt(
            attachments=(
                []
                if src.attachments is None
                else [FileDescr(source=f) for f in src.attachments.files]
            ),
            authors=[_author_conv.convert_as_dict(a) for a in src.authors],  # pyright: ignore[reportArgumentType]
            cite=[{"text": c.text, "doi": c.doi, "url": c.url} for c in src.cite],  # pyright: ignore[reportArgumentType]
            config=src.config,  # pyright: ignore[reportArgumentType]
            covers=src.covers,
            description=src.description,
            documentation=src.documentation,
            # target format version this converter produces
            format_version="0.5.9",
            git_repo=src.git_repo,  # pyright: ignore[reportArgumentType]
            icon=src.icon,
            id=None if src.id is None else ModelId(src.id),
            id_emoji=src.id_emoji,
            license=src.license,  # type: ignore
            links=src.links,
            maintainers=[_maintainer_conv.convert_as_dict(m) for m in src.maintainers],  # pyright: ignore[reportArgumentType]
            name=name,
            tags=src.tags,
            type=src.type,
            uploader=src.uploader,
            version=src.version,
            inputs=[  # pyright: ignore[reportArgumentType]
                _input_tensor_conv.convert_as_dict(ipt, tt, st, input_size_refs)
                for ipt, tt, st in zip(
                    src.inputs,
                    src.test_inputs,
                    src.sample_inputs or [None] * len(src.test_inputs),
                )
            ],
            outputs=[  # pyright: ignore[reportArgumentType]
                _output_tensor_conv.convert_as_dict(out, tt, st, output_size_refs)
                for out, tt, st in zip(
                    src.outputs,
                    src.test_outputs,
                    src.sample_outputs or [None] * len(src.test_outputs),
                )
            ],
            # 0.4 stores a version number next to the id; 0.5 encodes it in
            # the id as "<id>/<version_number>"
            parent=(
                None
                if src.parent is None
                else LinkedModel(
                    id=ModelId(
                        str(src.parent.id)
                        + (
                            ""
                            if src.parent.version_number is None
                            else f"/{src.parent.version_number}"
                        )
                    )
                )
            ),
            training_data=(
                None
                if src.training_data is None
                else (
                    LinkedDataset(
                        id=DatasetId(
                            str(src.training_data.id)
                            + (
                                ""
                                if src.training_data.version_number is None
                                else f"/{src.training_data.version_number}"
                            )
                        )
                    )
                    if isinstance(src.training_data, LinkedDataset02)
                    else src.training_data
                )
            ),
            packaged_by=[_author_conv.convert_as_dict(a) for a in src.packaged_by],  # pyright: ignore[reportArgumentType]
            run_mode=src.run_mode,
            timestamp=src.timestamp,
            # each weights entry: `(w := src.weights.<fmt>) and <descr>(...)`
            # evaluates to None if the format is absent, else to the converted
            # entry (missing 0.4 framework versions get conservative defaults)
            weights=(WeightsDescr if TYPE_CHECKING else dict)(
                keras_hdf5=(w := src.weights.keras_hdf5)
                and (KerasHdf5WeightsDescr if TYPE_CHECKING else dict)(
                    authors=conv_authors(w.authors),
                    source=w.source,
                    tensorflow_version=w.tensorflow_version or Version("1.15"),
                    parent=w.parent,
                ),
                onnx=(w := src.weights.onnx)
                and (OnnxWeightsDescr if TYPE_CHECKING else dict)(
                    source=w.source,
                    authors=conv_authors(w.authors),
                    parent=w.parent,
                    opset_version=w.opset_version or 15,
                ),
                pytorch_state_dict=(w := src.weights.pytorch_state_dict)
                and (PytorchStateDictWeightsDescr if TYPE_CHECKING else dict)(
                    source=w.source,
                    authors=conv_authors(w.authors),
                    parent=w.parent,
                    architecture=(
                        arch_file_conv(
                            w.architecture,
                            w.architecture_sha256,
                            w.kwargs,
                        )
                        if isinstance(w.architecture, _CallableFromFile_v0_4)
                        else arch_lib_conv(w.architecture, w.kwargs)
                    ),
                    pytorch_version=w.pytorch_version or Version("1.10"),
                    # strip a legacy "conda:" prefix from the dependencies source
                    dependencies=(
                        None
                        if w.dependencies is None
                        else (FileDescr if TYPE_CHECKING else dict)(
                            source=cast(
                                FileSource,
                                str(deps := w.dependencies)[
                                    (
                                        len("conda:")
                                        if str(deps).startswith("conda:")
                                        else 0
                                    ) :
                                ],
                            )
                        )
                    ),
                ),
                tensorflow_js=(w := src.weights.tensorflow_js)
                and (TensorflowJsWeightsDescr if TYPE_CHECKING else dict)(
                    source=w.source,
                    authors=conv_authors(w.authors),
                    parent=w.parent,
                    tensorflow_version=w.tensorflow_version or Version("1.15"),
                ),
                tensorflow_saved_model_bundle=(
                    w := src.weights.tensorflow_saved_model_bundle
                )
                and (TensorflowSavedModelBundleWeightsDescr if TYPE_CHECKING else dict)(
                    authors=conv_authors(w.authors),
                    parent=w.parent,
                    source=w.source,
                    tensorflow_version=w.tensorflow_version or Version("1.15"),
                    dependencies=(
                        None
                        if w.dependencies is None
                        else (FileDescr if TYPE_CHECKING else dict)(
                            source=cast(
                                FileSource,
                                (
                                    str(w.dependencies)[len("conda:") :]
                                    if str(w.dependencies).startswith("conda:")
                                    else str(w.dependencies)
                                ),
                            )
                        )
                    ),
                ),
                torchscript=(w := src.weights.torchscript)
                and (TorchscriptWeightsDescr if TYPE_CHECKING else dict)(
                    source=w.source,
                    authors=conv_authors(w.authors),
                    parent=w.parent,
                    pytorch_version=w.pytorch_version or Version("1.10"),
                ),
            ),
        )

4133 

4134 

4135_model_conv = _ModelConv(_ModelDescr_v0_4, ModelDescr) 

4136 

4137 

# create better cover images for 3d data and non-image outputs
def generate_covers(
    inputs: Sequence[Tuple[InputTensorDescr, NDArray[Any]]],
    outputs: Sequence[Tuple[OutputTensorDescr, NDArray[Any]]],
) -> List[Path]:
    """Generate cover image file(s) from the first input/output tensor pair.

    Each tensor is reduced to a 2d RGB image (slicing through batch, index, z,
    extra channel, space and time axes as needed). If input and output images
    end up with the same shape a single, diagonally split image is written;
    otherwise two separate images are written.

    Args:
        inputs: Input tensor descriptions paired with (test) tensor data.
        outputs: Output tensor descriptions paired with (test) tensor data.

    Returns:
        Paths to the created PNG cover image(s) in a fresh temporary folder.

    Raises:
        ValueError: If **inputs** or **outputs** is empty, or if a tensor
            cannot be reduced to a 2d image.
    """

    def squeeze(
        data: NDArray[Any], axes: Sequence[AnyAxis]
    ) -> Tuple[NDArray[Any], List[AnyAxis]]:
        """apply numpy.ndarray.squeeze while keeping track of the axis descriptions remaining"""
        if data.ndim != len(axes):
            raise ValueError(
                f"tensor shape {data.shape} does not match described axes"
                + f" {[a.id for a in axes]}"
            )

        axes = [deepcopy(a) for a, s in zip(axes, data.shape) if s != 1]
        return data.squeeze(), axes

    def normalize(
        data: NDArray[Any], axis: Optional[Tuple[int, ...]], eps: float = 1e-7
    ) -> NDArray[np.float32]:
        # min-max normalize to [0, 1); `eps` guards against division by zero
        data = data.astype("float32")
        data -= data.min(axis=axis, keepdims=True)
        data /= data.max(axis=axis, keepdims=True) + eps
        return data

    def to_2d_image(data: NDArray[Any], axes: Sequence[AnyAxis]):
        """Reduce `data` to a 2d uint8 RGB image (h, w, 3)."""
        original_shape = data.shape
        original_axes = list(axes)
        data, axes = squeeze(data, axes)

        # take slice from any batch or index axis if needed
        # and convert the first channel axis and take a slice from any additional channel axes
        slices: Tuple[slice, ...] = ()
        ndim = data.ndim
        ndim_need = 3 if any(isinstance(a, ChannelAxis) for a in axes) else 2
        has_c_axis = False
        for i, a in enumerate(axes):
            s = data.shape[i]
            assert s > 1
            if (
                isinstance(a, (BatchAxis, IndexInputAxis, IndexOutputAxis))
                and ndim > ndim_need
            ):
                # keep a single (central) element; the size-1 dim is squeezed later
                data = data[slices + (slice(s // 2 - 1, s // 2),)]
                ndim -= 1
            elif isinstance(a, ChannelAxis):
                if has_c_axis:
                    # second channel axis
                    data = data[slices + (slice(0, 1),)]
                    ndim -= 1
                else:
                    has_c_axis = True
                    if s == 2:
                        # visualize two channels with cyan and magenta
                        data = np.concatenate(
                            [
                                data[slices + (slice(1, 2),)],
                                data[slices + (slice(0, 1),)],
                                (
                                    data[slices + (slice(0, 1),)]
                                    + data[slices + (slice(1, 2),)]
                                )
                                / 2,  # TODO: take maximum instead?
                            ],
                            axis=i,
                        )
                    elif data.shape[i] == 3:
                        pass  # visualize 3 channels as RGB
                    else:
                        # visualize first 3 channels as RGB
                        data = data[slices + (slice(3),)]

                    assert data.shape[i] == 3

            slices += (slice(None),)

        data, axes = squeeze(data, axes)
        assert len(axes) == ndim
        # take slice from z axis if needed
        slices = ()
        if ndim > ndim_need:
            for i, a in enumerate(axes):
                s = data.shape[i]
                if a.id == AxisId("z"):
                    data = data[slices + (slice(s // 2 - 1, s // 2),)]
                    data, axes = squeeze(data, axes)
                    ndim -= 1
                    break

                slices += (slice(None),)

        # take slice from any space or time axis
        slices = ()

        for i, a in enumerate(axes):
            if ndim <= ndim_need:
                break

            s = data.shape[i]
            assert s > 1
            if isinstance(
                a, (SpaceInputAxis, SpaceOutputAxis, TimeInputAxis, TimeOutputAxis)
            ):
                data = data[slices + (slice(s // 2 - 1, s // 2),)]
                ndim -= 1

            slices += (slice(None),)

        del slices
        data, axes = squeeze(data, axes)
        assert len(axes) == ndim

        if (has_c_axis and ndim != 3) or (not has_c_axis and ndim != 2):
            raise ValueError(
                f"Failed to construct cover image from shape {original_shape} with axes {[a.id for a in original_axes]}."
            )

        if not has_c_axis:
            assert ndim == 2
            # grayscale -> RGB by repeating the data for each color channel
            data = np.repeat(data[:, :, None], 3, axis=2)
            axes.append(ChannelAxis(channel_names=list(map(Identifier, "RGB"))))
            ndim += 1

        assert ndim == 3

        # transpose axis order such that longest axis comes first...
        axis_order: List[int] = list(np.argsort(list(data.shape)))
        axis_order.reverse()
        # ... and channel axis is last
        c = [i for i in range(3) if isinstance(axes[i], ChannelAxis)][0]
        # BUGFIX: `c` is an axis index, i.e. a *value* in the permutation
        # `axis_order`, not a position. The previous `axis_order.pop(c)` popped
        # by position and therefore moved the wrong axis to the end whenever
        # the channel axis was not coincidentally the last entry (e.g. for
        # channel-first data of shape (3, h, w)). Remove it by value instead.
        axis_order.remove(c)
        axis_order.append(c)
        axes = [axes[ao] for ao in axis_order]
        data = data.transpose(axis_order)

        # h, w = data.shape[:2]
        # if h / w in (1.0 or 2.0):
        #     pass
        # elif h / w < 2:
        # TODO: enforce 2:1 or 1:1 aspect ratio for generated cover images

        norm_along = (
            tuple(i for i, a in enumerate(axes) if a.type in ("space", "time")) or None
        )
        # normalize the data and map to 8 bit
        data = normalize(data, norm_along)
        data = (data * 255).astype("uint8")

        return data

    def create_diagonal_split_image(im0: NDArray[Any], im1: NDArray[Any]):
        """Combine `im0` (lower-left triangle) and `im1` (upper-right triangle)."""
        assert im0.dtype == im1.dtype == np.uint8
        assert im0.shape == im1.shape
        assert im0.ndim == 3
        N, M, C = im0.shape
        assert C == 3
        out = np.ones((N, M, C), dtype="uint8")
        for c in range(C):
            # NOTE: zero-valued pixels of im0's lower triangle are also filled
            # from im1 (the triangle mask is based on value, not position)
            outc = np.tril(im0[..., c])
            mask = outc == 0
            outc[mask] = np.triu(im1[..., c])[mask]
            out[..., c] = outc

        return out

    if not inputs:
        raise ValueError("Missing test input tensor for cover generation.")

    if not outputs:
        raise ValueError("Missing test output tensor for cover generation.")

    # only the first input/output tensor pair is visualized
    ipt_descr, ipt = inputs[0]
    out_descr, out = outputs[0]

    ipt_img = to_2d_image(ipt, ipt_descr.axes)
    out_img = to_2d_image(out, out_descr.axes)

    cover_folder = Path(mkdtemp())
    if ipt_img.shape == out_img.shape:
        covers = [cover_folder / "cover.png"]
        imwrite(covers[0], create_diagonal_split_image(ipt_img, out_img))
    else:
        covers = [cover_folder / "input.png", cover_folder / "output.png"]
        imwrite(covers[0], ipt_img)
        imwrite(covers[1], out_img)

    return covers