Coverage for src / bioimageio / spec / model / v0_5.py: 76%
1581 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-31 13:09 +0000
1from __future__ import annotations
3import collections.abc
4import re
5import string
6import warnings
7from copy import deepcopy
8from itertools import chain
9from math import ceil
10from pathlib import Path, PurePosixPath
11from tempfile import mkdtemp
12from textwrap import dedent
13from typing import (
14 TYPE_CHECKING,
15 Any,
16 Callable,
17 ClassVar,
18 Dict,
19 Generic,
20 List,
21 Literal,
22 Mapping,
23 NamedTuple,
24 Optional,
25 Sequence,
26 Set,
27 Tuple,
28 Type,
29 TypeVar,
30 Union,
31 cast,
32 overload,
33)
35import numpy as np
36from annotated_types import Ge, Gt, Interval, MaxLen, MinLen, Predicate
37from imageio.v3 import imread, imwrite # pyright: ignore[reportUnknownVariableType]
38from loguru import logger
39from numpy.typing import NDArray
40from pydantic import (
41 AfterValidator,
42 Discriminator,
43 Field,
44 RootModel,
45 SerializationInfo,
46 SerializerFunctionWrapHandler,
47 StrictInt,
48 Tag,
49 ValidationInfo,
50 WrapSerializer,
51 field_validator,
52 model_serializer,
53 model_validator,
54)
55from typing_extensions import Annotated, Self, assert_never, get_args
57from .._internal.common_nodes import (
58 InvalidDescr,
59 KwargsNode,
60 Node,
61 NodeWithExplicitlySetFields,
62)
63from .._internal.constants import DTYPE_LIMITS
64from .._internal.field_warning import issue_warning, warn
65from .._internal.io import BioimageioYamlContent as BioimageioYamlContent
66from .._internal.io import FileDescr as FileDescr
67from .._internal.io import (
68 FileSource,
69 WithSuffix,
70 YamlValue,
71 extract_file_name,
72 get_reader,
73 wo_special_file_name,
74)
75from .._internal.io_basics import Sha256 as Sha256
76from .._internal.io_packaging import (
77 FileDescr_,
78 FileSource_,
79 package_file_descr_serializer,
80)
81from .._internal.io_utils import load_array
82from .._internal.node_converter import Converter
83from .._internal.type_guards import is_dict, is_sequence
84from .._internal.types import (
85 FAIR,
86 AbsoluteTolerance,
87 LowerCaseIdentifier,
88 LowerCaseIdentifierAnno,
89 MismatchedElementsPerMillion,
90 RelativeTolerance,
91)
92from .._internal.types import Datetime as Datetime
93from .._internal.types import Identifier as Identifier
94from .._internal.types import NotEmpty as NotEmpty
95from .._internal.types import SiUnit as SiUnit
96from .._internal.url import HttpUrl as HttpUrl
97from .._internal.validation_context import get_validation_context
98from .._internal.validator_annotations import RestrictCharacters
99from .._internal.version_type import Version as Version
100from .._internal.warning_levels import INFO
101from ..dataset.v0_2 import DatasetDescr as DatasetDescr02
102from ..dataset.v0_2 import LinkedDataset as LinkedDataset02
103from ..dataset.v0_3 import DatasetDescr as DatasetDescr
104from ..dataset.v0_3 import DatasetId as DatasetId
105from ..dataset.v0_3 import LinkedDataset as LinkedDataset
106from ..dataset.v0_3 import Uploader as Uploader
107from ..generic.v0_3 import (
108 VALID_COVER_IMAGE_EXTENSIONS as VALID_COVER_IMAGE_EXTENSIONS,
109)
110from ..generic.v0_3 import Author as Author
111from ..generic.v0_3 import BadgeDescr as BadgeDescr
112from ..generic.v0_3 import CiteEntry as CiteEntry
113from ..generic.v0_3 import DeprecatedLicenseId as DeprecatedLicenseId
114from ..generic.v0_3 import Doi as Doi
115from ..generic.v0_3 import (
116 FileSource_documentation,
117 GenericModelDescrBase,
118 LinkedResourceBase,
119 _author_conv, # pyright: ignore[reportPrivateUsage]
120 _maintainer_conv, # pyright: ignore[reportPrivateUsage]
121)
122from ..generic.v0_3 import LicenseId as LicenseId
123from ..generic.v0_3 import LinkedResource as LinkedResource
124from ..generic.v0_3 import Maintainer as Maintainer
125from ..generic.v0_3 import OrcidId as OrcidId
126from ..generic.v0_3 import RelativeFilePath as RelativeFilePath
127from ..generic.v0_3 import ResourceId as ResourceId
128from .v0_4 import Author as _Author_v0_4
129from .v0_4 import BinarizeDescr as _BinarizeDescr_v0_4
130from .v0_4 import CallableFromDepencency as CallableFromDepencency
131from .v0_4 import CallableFromDepencency as _CallableFromDepencency_v0_4
132from .v0_4 import CallableFromFile as _CallableFromFile_v0_4
133from .v0_4 import ClipDescr as _ClipDescr_v0_4
134from .v0_4 import ImplicitOutputShape as _ImplicitOutputShape_v0_4
135from .v0_4 import InputTensorDescr as _InputTensorDescr_v0_4
136from .v0_4 import KnownRunMode as KnownRunMode
137from .v0_4 import ModelDescr as _ModelDescr_v0_4
138from .v0_4 import OutputTensorDescr as _OutputTensorDescr_v0_4
139from .v0_4 import ParameterizedInputShape as _ParameterizedInputShape_v0_4
140from .v0_4 import PostprocessingDescr as _PostprocessingDescr_v0_4
141from .v0_4 import PreprocessingDescr as _PreprocessingDescr_v0_4
142from .v0_4 import RunMode as RunMode
143from .v0_4 import ScaleLinearDescr as _ScaleLinearDescr_v0_4
144from .v0_4 import ScaleMeanVarianceDescr as _ScaleMeanVarianceDescr_v0_4
145from .v0_4 import ScaleRangeDescr as _ScaleRangeDescr_v0_4
146from .v0_4 import SigmoidDescr as _SigmoidDescr_v0_4
147from .v0_4 import TensorName as _TensorName_v0_4
148from .v0_4 import ZeroMeanUnitVarianceDescr as _ZeroMeanUnitVarianceDescr_v0_4
149from .v0_4 import package_weights
# Unit vocabularies are aligned with the OME-Zarr 0.5 axes specification so
# axis metadata can be exchanged with OME-Zarr datasets without translation.
SpaceUnit = Literal[
    "attometer",
    "angstrom",
    "centimeter",
    "decimeter",
    "exameter",
    "femtometer",
    "foot",
    "gigameter",
    "hectometer",
    "inch",
    "kilometer",
    "megameter",
    "meter",
    "micrometer",
    "mile",
    "millimeter",
    "nanometer",
    "parsec",
    "petameter",
    "picometer",
    "terameter",
    "yard",
    "yoctometer",
    "yottameter",
    "zeptometer",
    "zettameter",
]
"""Space unit compatible to the [OME-Zarr axes specification 0.5](https://ngff.openmicroscopy.org/0.5/#axes-md)"""

TimeUnit = Literal[
    "attosecond",
    "centisecond",
    "day",
    "decisecond",
    "exasecond",
    "femtosecond",
    "gigasecond",
    "hectosecond",
    "hour",
    "kilosecond",
    "megasecond",
    "microsecond",
    "millisecond",
    "minute",
    "nanosecond",
    "petasecond",
    "picosecond",
    "second",
    "terasecond",
    "yoctosecond",
    "yottasecond",
    "zeptosecond",
    "zettasecond",
]
"""Time unit compatible to the [OME-Zarr axes specification 0.5](https://ngff.openmicroscopy.org/0.5/#axes-md)"""

# The five axis categories distinguished by this spec version.
AxisType = Literal["batch", "channel", "index", "time", "space"]

# Maps single-letter axis names (as used e.g. in v0.4 axes strings) to the
# axis *type* they imply.
_AXIS_TYPE_MAP: Mapping[str, AxisType] = {
    "b": "batch",
    "t": "time",
    "i": "index",
    "c": "channel",
    "x": "space",
    "y": "space",
    "z": "space",
}

# Maps single-letter axis *ids* to their long form. "x", "y" and "z" are
# intentionally absent: they stay valid (space) axis ids as-is.
_AXIS_ID_MAP = {
    "b": "batch",
    "t": "time",
    "i": "index",
    "c": "channel",
}

# All weight formats a model description may provide.
WeightsFormat = Literal[
    "keras_hdf5",
    "keras_v3",
    "onnx",
    "pytorch_state_dict",
    "tensorflow_js",
    "tensorflow_saved_model_bundle",
    "torchscript",
]
class TensorId(LowerCaseIdentifier):
    """Identifier of a tensor: a lower-case identifier of at most 32 characters."""

    root_model: ClassVar[Type[RootModel[Any]]] = RootModel[
        Annotated[LowerCaseIdentifierAnno, MaxLen(32)]
    ]
def _normalize_axis_id(a: str):
    """Map single-letter axis id shorthands (b/t/i/c) to their long form,
    warning when a normalization actually happens."""
    raw = str(a)
    full = _AXIS_ID_MAP.get(raw, raw)
    if full != raw:
        # depth=3 attributes the warning to the caller's caller for context
        logger.opt(depth=3).warning(
            "Normalized axis id from '{}' to '{}'.", raw, full
        )
    return full
class AxisId(LowerCaseIdentifier):
    """Identifier of an axis: a lower-case identifier of at most 16 characters.

    Single-letter shorthands b/t/i/c are normalized to their long form
    (see `_normalize_axis_id`).
    """

    root_model: ClassVar[Type[RootModel[Any]]] = RootModel[
        Annotated[
            LowerCaseIdentifierAnno,
            MaxLen(16),
            AfterValidator(_normalize_axis_id),
        ]
    ]
264def _is_batch(a: str) -> bool:
265 return str(a) == "batch"
268def _is_not_batch(a: str) -> bool:
269 return not _is_batch(a)
# An `AxisId` that is guaranteed not to be the reserved "batch" id.
NonBatchAxisId = Annotated[AxisId, Predicate(_is_not_batch)]

# Valid `id` values of preprocessing step descriptions.
PreprocessingId = Literal[
    "binarize",
    "clip",
    "ensure_dtype",
    "fixed_zero_mean_unit_variance",
    "scale_linear",
    "scale_range",
    "sigmoid",
    "softmax",
]
# Valid `id` values of postprocessing step descriptions.
PostprocessingId = Literal[
    "binarize",
    "clip",
    "ensure_dtype",
    "fixed_zero_mean_unit_variance",
    "scale_linear",
    "scale_mean_variance",
    "scale_range",
    "sigmoid",
    "softmax",
    "zero_mean_unit_variance",
]

# Sentinel string; presumably marks values defaulting to the `type` field —
# its usages are elsewhere in this module.
SAME_AS_TYPE = "<same as type>"

ParameterizedSize_N = int
"""
Annotates an integer to calculate a concrete axis size from a `ParameterizedSize`.
"""
class ParameterizedSize(Node):
    """Describes a range of valid tensor axis sizes as `size = min + n*step`.

    - **min** and **step** are given by the model description.
    - All blocksize parameters n = 0,1,2,... yield a valid `size`.
    - A greater blocksize parameter n = 0,1,2,... results in a greater **size**.
      This allows to adjust the axis size more generically.
    """

    N: ClassVar[Type[int]] = ParameterizedSize_N
    """Positive integer to parameterize this axis"""

    min: Annotated[int, Gt(0)]
    step: Annotated[int, Gt(0)]

    def validate_size(self, size: int, msg_prefix: str = "") -> int:
        """Return `size` if it is expressible as `min + n*step`, else raise ValueError."""
        if size < self.min:
            raise ValueError(
                f"{msg_prefix}size {size} < {self.min} (minimum axis size)"
            )
        if (size - self.min) % self.step != 0:
            raise ValueError(
                f"{msg_prefix}size {size} is not parameterized by `min + n*step` ="
                + f" `{self.min} + n*{self.step}`"
            )

        return size

    def get_size(self, n: ParameterizedSize_N) -> int:
        """Compute the concrete axis size for blocksize parameter `n`."""
        return self.min + self.step * n

    def get_n(self, s: int) -> ParameterizedSize_N:
        """return smallest n parameterizing a size greater or equal than `s`"""
        # clamp to 0: valid blocksize parameters are n = 0,1,2,..., so any
        # s <= min maps to n=0 (the smallest valid size, `min`) instead of a
        # negative n that would parameterize an invalid size < min.
        return max(0, ceil((s - self.min) / self.step))
class DataDependentSize(Node):
    """An axis size only known after inference, bounded by `[min, max]`."""

    min: Annotated[int, Gt(0)] = 1
    max: Annotated[Optional[int], Gt(1)] = None

    @model_validator(mode="after")
    def _validate_max_gt_min(self):
        # a missing upper bound is always consistent with any `min`
        if self.max is None:
            return self
        if self.min < self.max:
            return self
        raise ValueError(f"expected `min` < `max`, but got {self.min}, {self.max}")

    def validate_size(self, size: int, msg_prefix: str = "") -> int:
        """Return `size` unchanged if it lies within `[min, max]`, else raise ValueError."""
        if size < self.min:
            raise ValueError(f"{msg_prefix}size {size} < {self.min}")

        if self.max is not None and size > self.max:
            raise ValueError(f"{msg_prefix}size {size} > {self.max}")

        return size
class SizeReference(Node):
    """A tensor axis size (extent in pixels/frames) defined in relation to a reference axis.

    `axis.size = reference.size * reference.scale / axis.scale + offset`

    Note:
    1. The axis and the referenced axis need to have the same unit (or no unit).
    2. Batch axes may not be referenced.
    3. Fractions are rounded down.
    4. If the reference axis is `concatenable` the referencing axis is assumed to be
       `concatenable` as well with the same block order.

    Example:
    An anisotropic input image of w*h=100*49 pixels depicts a physical space of 200*196mm².
    Let's assume that we want to express the image height h in relation to its width w
    instead of only accepting input images of exactly 100*49 pixels
    (for example to express a range of valid image shapes by parametrizing w, see `ParameterizedSize`).

    >>> w = SpaceInputAxis(id=AxisId("w"), size=100, unit="millimeter", scale=2)
    >>> h = SpaceInputAxis(
    ...     id=AxisId("h"),
    ...     size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("w"), offset=-1),
    ...     unit="millimeter",
    ...     scale=4,
    ... )
    >>> print(h.size.get_size(h, w))
    49

    ⇒ h = w * w.scale / h.scale + offset = 100 * 2mm / 4mm - 1 = 49
    """

    tensor_id: TensorId
    """tensor id of the reference axis"""

    axis_id: AxisId
    """axis id of the reference axis"""

    offset: StrictInt = 0

    def get_size(
        self,
        axis: Union[
            ChannelAxis,
            IndexInputAxis,
            IndexOutputAxis,
            TimeInputAxis,
            SpaceInputAxis,
            TimeOutputAxis,
            TimeOutputAxisWithHalo,
            SpaceOutputAxis,
            SpaceOutputAxisWithHalo,
        ],
        ref_axis: Union[
            ChannelAxis,
            IndexInputAxis,
            IndexOutputAxis,
            TimeInputAxis,
            SpaceInputAxis,
            TimeOutputAxis,
            TimeOutputAxisWithHalo,
            SpaceOutputAxis,
            SpaceOutputAxisWithHalo,
        ],
        n: ParameterizedSize_N = 0,
        ref_size: Optional[int] = None,
    ):
        """Compute the concrete size for a given axis and its reference axis.

        Args:
            axis: The axis this [SizeReference][] is the size of.
            ref_axis: The reference axis to compute the size from.
            n: If the **ref_axis** is parameterized (of type `ParameterizedSize`)
                and no fixed **ref_size** is given,
                **n** is used to compute the size of the parameterized **ref_axis**.
            ref_size: Overwrite the reference size instead of deriving it from
                **ref_axis**
                (**ref_axis.scale** is still used; any given **n** is ignored).

        Raises:
            ValueError: if **ref_axis** is itself sized by a `SizeReference` or
                by a `DataDependentSize` (neither may serve as reference).
        """
        assert axis.size == self, (
            "Given `axis.size` is not defined by this `SizeReference`"
        )

        assert ref_axis.id == self.axis_id, (
            f"Expected `ref_axis.id` to be {self.axis_id}, but got {ref_axis.id}."
        )

        assert axis.unit == ref_axis.unit, (
            "`SizeReference` requires `axis` and `ref_axis` to have the same `unit`,"
            f" but {axis.unit}!={ref_axis.unit}"
        )
        if ref_size is None:
            if isinstance(ref_axis.size, (int, float)):
                ref_size = ref_axis.size
            elif isinstance(ref_axis.size, ParameterizedSize):
                ref_size = ref_axis.size.get_size(n)
            elif isinstance(ref_axis.size, DataDependentSize):
                raise ValueError(
                    "Reference axis referenced in `SizeReference` may not be a `DataDependentSize`."
                )
            elif isinstance(ref_axis.size, SizeReference):
                raise ValueError(
                    "Reference axis referenced in `SizeReference` may not be sized by a"
                    + " `SizeReference` itself."
                )
            else:
                assert_never(ref_axis.size)

        # int() truncation implements the "fractions are rounded down" rule
        # for the non-negative results expected here
        return int(ref_size * ref_axis.scale / axis.scale + self.offset)

    @staticmethod
    def _get_unit(
        axis: Union[
            ChannelAxis,
            IndexInputAxis,
            IndexOutputAxis,
            TimeInputAxis,
            SpaceInputAxis,
            TimeOutputAxis,
            TimeOutputAxisWithHalo,
            SpaceOutputAxis,
            SpaceOutputAxisWithHalo,
        ],
    ):
        # small helper to read the unit of any concrete axis type
        return axis.unit
class AxisBase(NodeWithExplicitlySetFields):
    """Base class of all axis descriptions, providing `id` and `description`."""

    id: AxisId
    """An axis id unique across all axes of one tensor."""

    description: Annotated[str, MaxLen(128)] = ""
    """A short description of this axis beyond its type and id."""
class WithHalo(Node):
    """Mixin for output axes that carry a halo (and hence a referenced size)."""

    halo: Annotated[int, Ge(1)]
    """The halo should be cropped from the output tensor to avoid boundary effects.
    It is to be cropped from both sides, i.e. `size_after_crop = size - 2 * halo`.
    To document a halo that is already cropped by the model use `size.offset` instead."""

    size: Annotated[
        SizeReference,
        Field(
            examples=[
                10,
                SizeReference(
                    tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
                ).model_dump(mode="json"),
            ]
        ),
    ]
    """reference to another axis with an optional offset (see [SizeReference][])"""
# the one reserved axis id for batch axes
BATCH_AXIS_ID = AxisId("batch")


class BatchAxis(AxisBase):
    """The batch axis; its `id` is fixed to "batch"."""

    implemented_type: ClassVar[Literal["batch"]] = "batch"
    # NOTE(review): static checkers see a default for `type`; at runtime it is
    # required here and presumably populated from `implemented_type` by
    # `NodeWithExplicitlySetFields` — confirm against the base class.
    if TYPE_CHECKING:
        type: Literal["batch"] = "batch"
    else:
        type: Literal["batch"]

    id: Annotated[AxisId, Predicate(_is_batch)] = BATCH_AXIS_ID
    size: Optional[Literal[1]] = None
    """The batch size may be fixed to 1,
    otherwise (the default) it may be chosen arbitrarily depending on available memory"""

    @property
    def scale(self):
        # scale/unit are meaningless for the batch axis; 1.0 keeps size
        # arithmetic uniform across axis types
        return 1.0

    @property
    def concatenable(self):
        return True

    @property
    def unit(self):
        return None
class ChannelAxis(AxisBase):
    """Axis of named channels; its size equals the number of channel names."""

    implemented_type: ClassVar[Literal["channel"]] = "channel"
    if TYPE_CHECKING:
        type: Literal["channel"] = "channel"
    else:
        type: Literal["channel"]

    id: NonBatchAxisId = AxisId("channel")

    channel_names: NotEmpty[List[Identifier]]

    @property
    def size(self) -> int:
        # the axis extent is implied by the number of named channels
        return len(self.channel_names)

    @property
    def concatenable(self):
        return False

    @property
    def scale(self) -> float:
        return 1.0

    @property
    def unit(self):
        return None
class IndexAxisBase(AxisBase):
    """Base class of index axes (unit-less enumeration axes)."""

    implemented_type: ClassVar[Literal["index"]] = "index"
    if TYPE_CHECKING:
        type: Literal["index"] = "index"
    else:
        type: Literal["index"]

    id: NonBatchAxisId = AxisId("index")

    @property
    def scale(self) -> float:
        return 1.0

    @property
    def unit(self):
        return None
class _WithInputAxisSize(Node):
    """Mixin adding the `size` field shared by all sized input axes."""

    size: Annotated[
        Union[Annotated[int, Gt(0)], ParameterizedSize, SizeReference],
        Field(
            examples=[
                10,
                ParameterizedSize(min=32, step=16).model_dump(mode="json"),
                SizeReference(
                    tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
                ).model_dump(mode="json"),
            ]
        ),
    ]
    """The size/length of this axis can be specified as
    - fixed integer
    - parameterized series of valid sizes ([ParameterizedSize][])
    - reference to another axis with an optional offset ([SizeReference][])
    """
class IndexInputAxis(IndexAxisBase, _WithInputAxisSize):
    """Input index axis with a fixed, parameterized, or referenced size."""

    concatenable: bool = False
    """If a model has a `concatenable` input axis, it can be processed blockwise,
    splitting a longer sample axis into blocks matching its input tensor description.
    Output axes are concatenable if they have a [SizeReference][] to a concatenable
    input axis.
    """
class IndexOutputAxis(IndexAxisBase):
    """Output index axis; its size may additionally be data dependent."""

    size: Annotated[
        Union[Annotated[int, Gt(0)], SizeReference, DataDependentSize],
        Field(
            examples=[
                10,
                SizeReference(
                    tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
                ).model_dump(mode="json"),
            ]
        ),
    ]
    """The size/length of this axis can be specified as
    - fixed integer
    - reference to another axis with an optional offset ([SizeReference][])
    - data dependent size using [DataDependentSize][] (size is only known after model inference)
    """
class TimeAxisBase(AxisBase):
    """Base class of time axes; carries an optional time `unit` and a `scale`."""

    implemented_type: ClassVar[Literal["time"]] = "time"
    if TYPE_CHECKING:
        type: Literal["time"] = "time"
    else:
        type: Literal["time"]

    id: NonBatchAxisId = AxisId("time")
    unit: Optional[TimeUnit] = None
    scale: Annotated[float, Gt(0)] = 1.0
class TimeInputAxis(TimeAxisBase, _WithInputAxisSize):
    """Input time axis with a fixed, parameterized, or referenced size."""

    concatenable: bool = False
    """If a model has a `concatenable` input axis, it can be processed blockwise,
    splitting a longer sample axis into blocks matching its input tensor description.
    Output axes are concatenable if they have a [SizeReference][] to a concatenable
    input axis.
    """
class SpaceAxisBase(AxisBase):
    """Base class of space axes; carries an optional space `unit` and a `scale`."""

    implemented_type: ClassVar[Literal["space"]] = "space"
    if TYPE_CHECKING:
        type: Literal["space"] = "space"
    else:
        type: Literal["space"]

    id: Annotated[NonBatchAxisId, Field(examples=["x", "y", "z"])] = AxisId("x")
    unit: Optional[SpaceUnit] = None
    scale: Annotated[float, Gt(0)] = 1.0
class SpaceInputAxis(SpaceAxisBase, _WithInputAxisSize):
    """Input space axis with a fixed, parameterized, or referenced size."""

    concatenable: bool = False
    """If a model has a `concatenable` input axis, it can be processed blockwise,
    splitting a longer sample axis into blocks matching its input tensor description.
    Output axes are concatenable if they have a [SizeReference][] to a concatenable
    input axis.
    """
INPUT_AXIS_TYPES = (
    BatchAxis,
    ChannelAxis,
    IndexInputAxis,
    TimeInputAxis,
    SpaceInputAxis,
)
"""intended for isinstance comparisons in py<3.10"""

_InputAxisUnion = Union[
    BatchAxis, ChannelAxis, IndexInputAxis, TimeInputAxis, SpaceInputAxis
]
# discriminated by the `type` field for unambiguous pydantic parsing
InputAxis = Annotated[_InputAxisUnion, Discriminator("type")]
class _WithOutputAxisSize(Node):
    """Mixin adding the `size` field shared by halo-free output axes."""

    size: Annotated[
        Union[Annotated[int, Gt(0)], SizeReference],
        Field(
            examples=[
                10,
                SizeReference(
                    tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
                ).model_dump(mode="json"),
            ]
        ),
    ]
    """The size/length of this axis can be specified as
    - fixed integer
    - reference to another axis with an optional offset (see [SizeReference][])
    """
class TimeOutputAxis(TimeAxisBase, _WithOutputAxisSize):
    """Output time axis without a halo."""

    pass


class TimeOutputAxisWithHalo(TimeAxisBase, WithHalo):
    """Output time axis with a halo to crop (see [WithHalo][])."""

    pass
723def _get_halo_axis_discriminator_value(v: Any) -> Literal["with_halo", "wo_halo"]:
724 if isinstance(v, dict):
725 return "with_halo" if "halo" in v else "wo_halo"
726 else:
727 return "with_halo" if hasattr(v, "halo") else "wo_halo"
# Tagged union distinguishing time output axes by the presence of a `halo`
# entry/attribute (see `_get_halo_axis_discriminator_value`).
_TimeOutputAxisUnion = Annotated[
    Union[
        Annotated[TimeOutputAxis, Tag("wo_halo")],
        Annotated[TimeOutputAxisWithHalo, Tag("with_halo")],
    ],
    Discriminator(_get_halo_axis_discriminator_value),
]
class SpaceOutputAxis(SpaceAxisBase, _WithOutputAxisSize):
    """Output space axis without a halo."""

    pass


class SpaceOutputAxisWithHalo(SpaceAxisBase, WithHalo):
    """Output space axis with a halo to crop (see [WithHalo][])."""

    pass


# Tagged union distinguishing space output axes by the presence of a `halo`
# entry/attribute (see `_get_halo_axis_discriminator_value`).
_SpaceOutputAxisUnion = Annotated[
    Union[
        Annotated[SpaceOutputAxis, Tag("wo_halo")],
        Annotated[SpaceOutputAxisWithHalo, Tag("with_halo")],
    ],
    Discriminator(_get_halo_axis_discriminator_value),
]
_OutputAxisUnion = Union[
    BatchAxis, ChannelAxis, IndexOutputAxis, _TimeOutputAxisUnion, _SpaceOutputAxisUnion
]
# discriminated by the `type` field for unambiguous pydantic parsing
OutputAxis = Annotated[_OutputAxisUnion, Discriminator("type")]

OUTPUT_AXIS_TYPES = (
    BatchAxis,
    ChannelAxis,
    IndexOutputAxis,
    TimeOutputAxis,
    TimeOutputAxisWithHalo,
    SpaceOutputAxis,
    SpaceOutputAxisWithHalo,
)
"""intended for isinstance comparisons in py<3.10"""


AnyAxis = Union[InputAxis, OutputAxis]

ANY_AXIS_TYPES = INPUT_AXIS_TYPES + OUTPUT_AXIS_TYPES
"""intended for isinstance comparisons in py<3.10"""

# Homogeneous, non-empty value lists for `NominalOrOrdinalDataDescr.values`.
TVs = Union[
    NotEmpty[List[int]],
    NotEmpty[List[float]],
    NotEmpty[List[bool]],
    NotEmpty[List[str]],
]


NominalOrOrdinalDType = Literal[
    "float32",
    "float64",
    "uint8",
    "int8",
    "uint16",
    "int16",
    "uint32",
    "int32",
    "uint64",
    "int64",
    "bool",
]
class NominalOrOrdinalDataDescr(Node):
    """Description of tensor data given by a fixed set/sequence of values."""

    values: TVs
    """A fixed set of nominal or an ascending sequence of ordinal values.
    In this case `data.type` is required to be an unsigned integer type, e.g. 'uint8'.
    String `values` are interpreted as labels for tensor values 0, ..., N.
    Note: as YAML 1.2 does not natively support a "set" datatype,
    nominal values should be given as a sequence (aka list/array) as well.
    """

    type: Annotated[
        NominalOrOrdinalDType,
        Field(
            examples=[
                "float32",
                "uint8",
                "uint16",
                "int64",
                "bool",
            ],
        ),
    ] = "uint8"

    @model_validator(mode="after")
    def _validate_values_match_type(
        self,
    ) -> Self:
        # collect values that cannot be represented by `self.type`
        incompatible: List[Any] = []
        for v in self.values:
            if self.type == "bool":
                if not isinstance(v, bool):
                    incompatible.append(v)
            elif self.type in DTYPE_LIMITS:
                # incompatible if: numeric value outside the dtype's range,
                # a string label for a non-uint dtype, or a float for an
                # integer dtype
                if (
                    isinstance(v, (int, float))
                    and (
                        v < DTYPE_LIMITS[self.type].min
                        or v > DTYPE_LIMITS[self.type].max
                    )
                    or (isinstance(v, str) and "uint" not in self.type)
                    or (isinstance(v, float) and "int" in self.type)
                ):
                    incompatible.append(v)
            else:
                incompatible.append(v)

            # truncate the error listing after 5 offending values
            if len(incompatible) == 5:
                incompatible.append("...")
                break

        if incompatible:
            raise ValueError(
                f"data type '{self.type}' incompatible with values {incompatible}"
            )

        return self

    unit: Optional[Union[Literal["arbitrary unit"], SiUnit]] = None

    @property
    def range(self):
        """(minimum, maximum) of the described values; string labels map to 0..N-1."""
        if isinstance(self.values[0], str):
            return 0, len(self.values) - 1
        else:
            return min(self.values), max(self.values)
IntervalOrRatioDType = Literal[
    "float32",
    "float64",
    "uint8",
    "int8",
    "uint16",
    "int16",
    "uint32",
    "int32",
    "uint64",
    "int64",
]


class IntervalOrRatioDataDescr(Node):
    """Description of tensor data on an interval or ratio scale."""

    type: Annotated[  # TODO: rename to dtype
        IntervalOrRatioDType,
        Field(
            examples=["float32", "float64", "uint8", "uint16"],
        ),
    ] = "float32"
    range: Tuple[Optional[float], Optional[float]] = (
        None,
        None,
    )
    """Tuple `(minimum, maximum)` specifying the allowed range of the data in this tensor.
    `None` corresponds to min/max of what can be expressed by **type**."""
    unit: Union[Literal["arbitrary unit"], SiUnit] = "arbitrary unit"
    scale: float = 1.0
    """Scale for data on an interval (or ratio) scale."""
    offset: Optional[float] = None
    """Offset for data on a ratio scale."""

    @model_validator(mode="before")
    def _replace_inf(cls, data: Any):
        # infinite bounds (in YAML spellings or as floats) are mapped to None
        # (= unbounded) with a warning, as the schema forbids them
        if is_dict(data):
            if "range" in data and is_sequence(data["range"]):
                forbidden = (
                    "inf",
                    "-inf",
                    ".inf",
                    "-.inf",
                    float("inf"),
                    float("-inf"),
                )
                if any(v in forbidden for v in data["range"]):
                    issue_warning("replaced 'inf' value", value=data["range"])

                data["range"] = tuple(
                    (None if v in forbidden else v) for v in data["range"]
                )

        return data


TensorDataDescr = Union[NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr]
class BinarizeKwargs(KwargsNode):
    """key word arguments for [BinarizeDescr][]"""

    threshold: float
    """The fixed threshold"""


class BinarizeAlongAxisKwargs(KwargsNode):
    """key word arguments for [BinarizeDescr][]"""

    threshold: NotEmpty[List[float]]
    """The fixed threshold values along `axis`"""

    axis: Annotated[NonBatchAxisId, Field(examples=["channel"])]
    """The `threshold` axis"""
class BinarizeDescr(NodeWithExplicitlySetFields):
    """Binarize the tensor with a fixed threshold.

    Values above [BinarizeKwargs.threshold][]/[BinarizeAlongAxisKwargs.threshold][]
    will be set to one, values below the threshold to zero.

    Examples:
    - in YAML
        ```yaml
        postprocessing:
          - id: binarize
            kwargs:
              axis: 'channel'
              threshold: [0.25, 0.5, 0.75]
        ```
    - in Python:

        >>> postprocessing = [BinarizeDescr(
        ...     kwargs=BinarizeAlongAxisKwargs(
        ...         axis=AxisId('channel'),
        ...         threshold=[0.25, 0.5, 0.75],
        ...     )
        ... )]
    """

    implemented_id: ClassVar[Literal["binarize"]] = "binarize"
    # NOTE(review): at runtime `id` is presumably set from `implemented_id` by
    # NodeWithExplicitlySetFields — confirm against the base class.
    if TYPE_CHECKING:
        id: Literal["binarize"] = "binarize"
    else:
        id: Literal["binarize"]
    kwargs: Union[BinarizeKwargs, BinarizeAlongAxisKwargs]
class ClipKwargs(KwargsNode):
    """key word arguments for [ClipDescr][]"""

    min: Optional[float] = None
    """Minimum value for clipping.

    Exclusive with [min_percentile][]
    """
    min_percentile: Optional[Annotated[float, Interval(ge=0, lt=100)]] = None
    """Minimum percentile for clipping.

    Exclusive with [min][].

    In range [0, 100).
    """

    max: Optional[float] = None
    """Maximum value for clipping.

    Exclusive with `max_percentile`.
    """
    max_percentile: Optional[Annotated[float, Interval(gt=1, le=100)]] = None
    """Maximum percentile for clipping.

    Exclusive with `max`.

    In range (1, 100].
    """

    axes: Annotated[
        Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
    ] = None
    """The subset of axes to determine percentiles jointly,

    i.e. axes to reduce to compute min/max from `min_percentile`/`max_percentile`.
    For example to clip 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
    resulting in a tensor of equal shape with clipped values per channel, specify `axes=('batch', 'x', 'y')`.
    To clip samples independently, leave out the 'batch' axis.

    Only valid if `min_percentile` and/or `max_percentile` are set.

    Default: Compute percentiles over all axes jointly."""

    @model_validator(mode="after")
    def _validate(self) -> Self:
        # fixed value and percentile are mutually exclusive per bound
        if (self.min is not None) and (self.min_percentile is not None):
            raise ValueError(
                "Only one of `min` and `min_percentile` may be set, not both."
            )
        if (self.max is not None) and (self.max_percentile is not None):
            raise ValueError(
                "Only one of `max` and `max_percentile` may be set, not both."
            )
        # at least one bound is required for clipping to have any effect
        if (
            self.min is None
            and self.min_percentile is None
            and self.max is None
            and self.max_percentile is None
        ):
            raise ValueError(
                "At least one of `min`, `min_percentile`, `max`, or `max_percentile` must be set."
            )

        # `axes` only influences percentile computation
        if (
            self.axes is not None
            and self.min_percentile is None
            and self.max_percentile is None
        ):
            raise ValueError(
                "If `axes` is set, at least one of `min_percentile` or `max_percentile` must be set."
            )

        return self
class ClipDescr(NodeWithExplicitlySetFields):
    """Set tensor values below min to min and above max to max.

    See `ScaleRangeDescr` for examples.
    """

    implemented_id: ClassVar[Literal["clip"]] = "clip"
    # NOTE(review): at runtime `id` is presumably set from `implemented_id` by
    # NodeWithExplicitlySetFields — confirm against the base class.
    if TYPE_CHECKING:
        id: Literal["clip"] = "clip"
    else:
        id: Literal["clip"]

    kwargs: ClipKwargs
class EnsureDtypeKwargs(KwargsNode):
    """key word arguments for [EnsureDtypeDescr][]"""

    # target data type to cast to (numpy-style dtype name)
    dtype: Literal[
        "float32",
        "float64",
        "uint8",
        "int8",
        "uint16",
        "int16",
        "uint32",
        "int32",
        "uint64",
        "int64",
        "bool",
    ]
class EnsureDtypeDescr(NodeWithExplicitlySetFields):
    """Cast the tensor data type to `EnsureDtypeKwargs.dtype` (if not matching).

    This can for example be used to ensure the inner neural network model gets a
    different input tensor data type than the fully described bioimage.io model does.

    Examples:
        The described bioimage.io model (incl. preprocessing) accepts any
        float32-compatible tensor, normalizes it with percentiles and clipping and then
        casts it to uint8, which is what the neural network in this example expects.
        - in YAML
            ```yaml
            inputs:
            - data:
                type: float32  # described bioimage.io model is compatible with any float32 input tensor
              preprocessing:
              - id: scale_range
                kwargs:
                  axes: ['y', 'x']
                  max_percentile: 99.8
                  min_percentile: 5.0
              - id: clip
                kwargs:
                  min: 0.0
                  max: 1.0
              - id: ensure_dtype  # the neural network of the model requires uint8
                kwargs:
                  dtype: uint8
            ```
        - in Python:
            >>> preprocessing = [
            ...     ScaleRangeDescr(
            ...         kwargs=ScaleRangeKwargs(
            ...           axes= (AxisId('y'), AxisId('x')),
            ...           max_percentile= 99.8,
            ...           min_percentile= 5.0,
            ...         )
            ...     ),
            ...     ClipDescr(kwargs=ClipKwargs(min=0.0, max=1.0)),
            ...     EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="uint8")),
            ... ]
    """

    implemented_id: ClassVar[Literal["ensure_dtype"]] = "ensure_dtype"
    # NOTE(review): at runtime `id` is presumably set from `implemented_id` by
    # NodeWithExplicitlySetFields — confirm against the base class.
    if TYPE_CHECKING:
        id: Literal["ensure_dtype"] = "ensure_dtype"
    else:
        id: Literal["ensure_dtype"]

    kwargs: EnsureDtypeKwargs
class ScaleLinearKwargs(KwargsNode):
    """Key word arguments for [ScaleLinearDescr][]"""

    gain: float = 1.0
    """multiplicative factor"""

    offset: float = 0.0
    """additive term"""

    @model_validator(mode="after")
    def _validate(self) -> Self:
        # gain 1.0 with offset 0.0 is the identity and would only add a
        # redundant processing step
        if self.gain == 1.0 and self.offset == 0.0:
            # fix: error message typo "allowd" -> "allowed"
            raise ValueError(
                "Redundant linear scaling not allowed. Set `gain` != 1.0 and/or `offset`"
                + " != 0.0."
            )

        return self
class ScaleLinearAlongAxisKwargs(KwargsNode):
    """Key word arguments for [ScaleLinearDescr][]"""

    axis: Annotated[NonBatchAxisId, Field(examples=["channel"])]
    """The axis of gain and offset values."""

    gain: Union[float, NotEmpty[List[float]]] = 1.0
    """multiplicative factor"""

    offset: Union[float, NotEmpty[List[float]]] = 0.0
    """additive term"""

    @model_validator(mode="after")
    def _validate(self) -> Self:
        """Broadcast a scalar gain/offset to the other's length and reject
        redundant (identity) scaling."""
        if isinstance(self.gain, list):
            if isinstance(self.offset, list):
                if len(self.gain) != len(self.offset):
                    raise ValueError(
                        f"Size of `gain` ({len(self.gain)}) and `offset` ({len(self.offset)}) must match."
                    )
            else:
                # broadcast scalar offset to the length of gain
                self.offset = [float(self.offset)] * len(self.gain)
        elif isinstance(self.offset, list):
            # broadcast scalar gain to the length of offset
            self.gain = [float(self.gain)] * len(self.offset)
        else:
            raise ValueError(
                "Do not specify an `axis` for scalar gain and offset values."
            )

        if all(g == 1.0 for g in self.gain) and all(off == 0.0 for off in self.offset):
            # fix: error message typo "allowd" -> "allowed"
            raise ValueError(
                "Redundant linear scaling not allowed. Set `gain` != 1.0 and/or `offset`"
                + " != 0.0."
            )

        return self
class ScaleLinearDescr(NodeWithExplicitlySetFields):
    """Fixed linear scaling.

    Examples:
      1. Scale with scalar gain and offset
        - in YAML
        ```yaml
        preprocessing:
          - id: scale_linear
            kwargs:
              gain: 2.0
              offset: 3.0
        ```
        - in Python:

        >>> preprocessing = [
        ...     ScaleLinearDescr(kwargs=ScaleLinearKwargs(gain= 2.0, offset=3.0))
        ... ]

      2. Independent scaling along an axis
        - in YAML
        ```yaml
        preprocessing:
          - id: scale_linear
            kwargs:
              axis: 'channel'
              gain: [1.0, 2.0, 3.0]
        ```
        - in Python:

        >>> preprocessing = [
        ...     ScaleLinearDescr(
        ...         kwargs=ScaleLinearAlongAxisKwargs(
        ...             axis=AxisId("channel"),
        ...             gain=[1.0, 2.0, 3.0],
        ...         )
        ...     )
        ... ]

    """

    implemented_id: ClassVar[Literal["scale_linear"]] = "scale_linear"
    if TYPE_CHECKING:
        # default only declared for static type checkers; at runtime `id`
        # presumably gets set from `implemented_id` by the base class — verify
        id: Literal["scale_linear"] = "scale_linear"
    else:
        id: Literal["scale_linear"]
    # scalar kwargs scale the whole tensor uniformly; the "along axis" variant
    # applies one gain/offset per entry of the given axis
    kwargs: Union[ScaleLinearKwargs, ScaleLinearAlongAxisKwargs]
class SigmoidDescr(NodeWithExplicitlySetFields):
    """The logistic sigmoid function, a.k.a. expit function.

    Examples:
    - in YAML
        ```yaml
        postprocessing:
          - id: sigmoid
        ```
    - in Python:

    >>> postprocessing = [SigmoidDescr()]
    """

    implemented_id: ClassVar[Literal["sigmoid"]] = "sigmoid"
    if TYPE_CHECKING:
        # default only declared for static type checkers; at runtime `id`
        # presumably gets set from `implemented_id` by the base class — verify
        id: Literal["sigmoid"] = "sigmoid"
    else:
        id: Literal["sigmoid"]

    @property
    def kwargs(self) -> KwargsNode:
        """empty kwargs"""
        # sigmoid takes no parameters; this property keeps the interface
        # uniform with processing descriptions that do carry `kwargs`
        return KwargsNode()
class SoftmaxKwargs(KwargsNode):
    """key word arguments for [SoftmaxDescr][]"""

    axis: Annotated[NonBatchAxisId, Field(examples=["channel"])] = AxisId("channel")
    """The axis to apply the softmax function along.

    Note:
        Defaults to 'channel' axis
        (which may not exist, in which case
        a different axis id has to be specified).
    """
class SoftmaxDescr(NodeWithExplicitlySetFields):
    """The softmax function.

    Examples:
    - in YAML
        ```yaml
        postprocessing:
          - id: softmax
            kwargs:
              axis: channel
        ```
    - in Python:

    >>> postprocessing = [SoftmaxDescr(kwargs=SoftmaxKwargs(axis=AxisId("channel")))]
    """

    implemented_id: ClassVar[Literal["softmax"]] = "softmax"
    if TYPE_CHECKING:
        # default only declared for static type checkers; at runtime `id`
        # presumably gets set from `implemented_id` by the base class — verify
        id: Literal["softmax"] = "softmax"
    else:
        id: Literal["softmax"]

    # `model_construct` builds the default kwargs without running validation
    kwargs: SoftmaxKwargs = Field(default_factory=SoftmaxKwargs.model_construct)
# fields shared by the 2D and 3D stardist kwargs variants below
class _StardistPostprocessingKwargsBase(KwargsNode):
    """key word arguments for [StardistPostprocessingDescr][]"""

    prob_threshold: float
    """The probability threshold for object candidate selection."""

    nms_threshold: float
    """The IoU threshold for non-maximum suppression."""
class StardistPostprocessingKwargs2D(_StardistPostprocessingKwargsBase):
    """Stardist postprocessing key word arguments for 2D predictions."""

    grid: Tuple[int, int]
    """Grid size of network predictions."""

    b: Union[int, Tuple[Tuple[int, int], Tuple[int, int]]]
    """Border region in which object probability is set to zero."""
class StardistPostprocessingKwargs3D(_StardistPostprocessingKwargsBase):
    """Stardist postprocessing key word arguments for 3D predictions."""

    grid: Tuple[int, int, int]
    """Grid size of network predictions."""

    b: Union[int, Tuple[Tuple[int, int], Tuple[int, int], Tuple[int, int]]]
    """Border region in which object probability is set to zero."""

    n_rays: int
    """Number of rays for 3D star-convex polyhedra."""

    anisotropy: Tuple[float, float, float]
    """Anisotropy factors for 3D star-convex polyhedra, i.e. the physical pixel size along each spatial axis."""

    overlap_label: Optional[int] = None
    """Optional label to apply to any area of overlapping predicted objects."""
class StardistPostprocessingDescr(NodeWithExplicitlySetFields):
    """Stardist postprocessing including non-maximum suppression and converting polygon representations to instance labels

    as described in:
    - Uwe Schmidt, Martin Weigert, Coleman Broaddus, and Gene Myers.
    [*Cell Detection with Star-convex Polygons*](https://arxiv.org/abs/1806.03535).
    International Conference on Medical Image Computing and Computer-Assisted Intervention (MICCAI), Granada, Spain, September 2018.
    - Martin Weigert, Uwe Schmidt, Robert Haase, Ko Sugawara, and Gene Myers.
    [*Star-convex Polyhedra for 3D Object Detection and Segmentation in Microscopy*](http://openaccess.thecvf.com/content_WACV_2020/papers/Weigert_Star-convex_Polyhedra_for_3D_Object_Detection_and_Segmentation_in_Microscopy_WACV_2020_paper.pdf).
    The IEEE Winter Conference on Applications of Computer Vision (WACV), Snowmass Village, Colorado, March 2020.

    Note: Only available if the `stardist` package is installed.
    """

    implemented_id: ClassVar[Literal["stardist_postprocessing"]] = (
        "stardist_postprocessing"
    )
    if TYPE_CHECKING:
        # default only declared for static type checkers; at runtime `id`
        # presumably gets set from `implemented_id` by the base class — verify
        id: Literal["stardist_postprocessing"] = "stardist_postprocessing"
    else:
        id: Literal["stardist_postprocessing"]

    # 2D vs 3D kwargs are distinguished by their field types during validation
    kwargs: Union[StardistPostprocessingKwargs2D, StardistPostprocessingKwargs3D]
class FixedZeroMeanUnitVarianceKwargs(KwargsNode):
    """key word arguments for [FixedZeroMeanUnitVarianceDescr][]"""

    mean: float
    """The mean value to normalize with."""

    std: Annotated[float, Ge(1e-6)]
    """The standard deviation value to normalize with."""
class FixedZeroMeanUnitVarianceAlongAxisKwargs(KwargsNode):
    """key word arguments for [FixedZeroMeanUnitVarianceDescr][]"""

    mean: NotEmpty[List[float]]
    """The mean value(s) to normalize with."""

    std: NotEmpty[List[Annotated[float, Ge(1e-6)]]]
    """The standard deviation value(s) to normalize with.
    Size must match `mean` values."""

    axis: Annotated[NonBatchAxisId, Field(examples=["channel", "index"])]
    """The axis of the mean/std values to normalize each entry along that dimension
    separately."""

    @model_validator(mode="after")
    def _mean_and_std_match(self) -> Self:
        """Ensure `mean` and `std` have the same number of entries."""
        if len(self.mean) != len(self.std):
            raise ValueError(
                f"Size of `mean` ({len(self.mean)}) and `std` ({len(self.std)})"
                + " must match."
            )

        return self
class FixedZeroMeanUnitVarianceDescr(NodeWithExplicitlySetFields):
    """Subtract a given mean and divide by the standard deviation.

    Normalize with fixed, precomputed values for
    `FixedZeroMeanUnitVarianceKwargs.mean` and `FixedZeroMeanUnitVarianceKwargs.std`
    Use `FixedZeroMeanUnitVarianceAlongAxisKwargs` for independent scaling along given
    axes.

    Examples:
    1. scalar value for whole tensor
        - in YAML
        ```yaml
        preprocessing:
          - id: fixed_zero_mean_unit_variance
            kwargs:
              mean: 103.5
              std: 13.7
        ```
        - in Python
        >>> preprocessing = [FixedZeroMeanUnitVarianceDescr(
        ...     kwargs=FixedZeroMeanUnitVarianceKwargs(mean=103.5, std=13.7)
        ... )]

    2. independently along an axis
        - in YAML
        ```yaml
        preprocessing:
          - id: fixed_zero_mean_unit_variance
            kwargs:
              axis: channel
              mean: [101.5, 102.5, 103.5]
              std: [11.7, 12.7, 13.7]
        ```
        - in Python
        >>> preprocessing = [FixedZeroMeanUnitVarianceDescr(
        ...     kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs(
        ...         axis=AxisId("channel"),
        ...         mean=[101.5, 102.5, 103.5],
        ...         std=[11.7, 12.7, 13.7],
        ...     )
        ... )]
    """

    implemented_id: ClassVar[Literal["fixed_zero_mean_unit_variance"]] = (
        "fixed_zero_mean_unit_variance"
    )
    if TYPE_CHECKING:
        # default only declared for static type checkers; at runtime `id`
        # presumably gets set from `implemented_id` by the base class — verify
        id: Literal["fixed_zero_mean_unit_variance"] = "fixed_zero_mean_unit_variance"
    else:
        id: Literal["fixed_zero_mean_unit_variance"]

    # scalar kwargs normalize the whole tensor; the "along axis" variant
    # applies one mean/std pair per entry of the given axis
    kwargs: Union[
        FixedZeroMeanUnitVarianceKwargs, FixedZeroMeanUnitVarianceAlongAxisKwargs
    ]
class ZeroMeanUnitVarianceKwargs(KwargsNode):
    """key word arguments for [ZeroMeanUnitVarianceDescr][]"""

    axes: Annotated[
        Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
    ] = None
    """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std.
    For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
    resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
    To normalize each sample independently leave out the 'batch' axis.
    Default: Scale all axes jointly."""

    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
    """epsilon for numeric stability: `out = (tensor - mean) / (std + eps)`."""
class ZeroMeanUnitVarianceDescr(NodeWithExplicitlySetFields):
    """Subtract mean and divide by variance.

    Examples:
        Subtract tensor mean and variance
        - in YAML
        ```yaml
        preprocessing:
          - id: zero_mean_unit_variance
        ```
        - in Python
        >>> preprocessing = [ZeroMeanUnitVarianceDescr()]
    """

    implemented_id: ClassVar[Literal["zero_mean_unit_variance"]] = (
        "zero_mean_unit_variance"
    )
    if TYPE_CHECKING:
        # default only declared for static type checkers; at runtime `id`
        # presumably gets set from `implemented_id` by the base class — verify
        id: Literal["zero_mean_unit_variance"] = "zero_mean_unit_variance"
    else:
        id: Literal["zero_mean_unit_variance"]

    # `model_construct` builds the default kwargs without running validation
    kwargs: ZeroMeanUnitVarianceKwargs = Field(
        default_factory=ZeroMeanUnitVarianceKwargs.model_construct
    )
class ScaleRangeKwargs(KwargsNode):
    """key word arguments for [ScaleRangeDescr][]

    For `min_percentile`=0.0 (the default) and `max_percentile`=100 (the default)
    this processing step normalizes data to the [0, 1] interval.
    For other percentiles the normalized values will partially be outside the [0, 1]
    interval. Use `ScaleRange` followed by `ClipDescr` if you want to limit the
    normalized values to a range.
    """

    axes: Annotated[
        Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
    ] = None
    """The subset of axes to normalize jointly, i.e. axes to reduce to compute the min/max percentile value.
    For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
    resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
    To normalize samples independently, leave out the "batch" axis.
    Default: Scale all axes jointly."""

    min_percentile: Annotated[float, Interval(ge=0, lt=100)] = 0.0
    """The lower percentile used to determine the value to align with zero."""

    max_percentile: Annotated[float, Interval(gt=1, le=100)] = 100.0
    """The upper percentile used to determine the value to align with one.
    Has to be bigger than `min_percentile`.
    The range is 1 to 100 instead of 0 to 100 to avoid mistakenly
    accepting percentiles specified in the range 0.0 to 1.0."""

    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
    """Epsilon for numeric stability.
    `out = (tensor - v_lower) / (v_upper - v_lower + eps)`;
    with `v_lower,v_upper` values at the respective percentiles."""

    reference_tensor: Optional[TensorId] = None
    """ID of the unprocessed input tensor to compute the percentiles from.
    Default: The tensor itself.
    """

    @field_validator("max_percentile", mode="after")
    @classmethod
    def min_smaller_max(cls, value: float, info: ValidationInfo) -> float:
        """Ensure `min_percentile` < `max_percentile`."""
        # `min_percentile` is missing from `info.data` if its own validation
        # failed; skip the cross-field check then instead of raising a
        # KeyError (pydantic already reports the `min_percentile` error).
        min_p = info.data.get("min_percentile")
        if min_p is not None and min_p >= value:
            raise ValueError(f"min_percentile {min_p} >= max_percentile {value}")

        return value
class ScaleRangeDescr(NodeWithExplicitlySetFields):
    """Scale with percentiles.

    Examples:
    1. Scale linearly to map 5th percentile to 0 and 99.8th percentile to 1.0
        - in YAML
        ```yaml
        preprocessing:
          - id: scale_range
            kwargs:
              axes: ['y', 'x']
              max_percentile: 99.8
              min_percentile: 5.0
        ```
        - in Python

        >>> preprocessing = [
        ...     ScaleRangeDescr(
        ...         kwargs=ScaleRangeKwargs(
        ...             axes= (AxisId('y'), AxisId('x')),
        ...             max_percentile= 99.8,
        ...             min_percentile= 5.0,
        ...         )
        ...     )
        ... ]

    2. Combine the above scaling with additional clipping to clip values outside the range given by the percentiles.
        - in YAML
        ```yaml
        preprocessing:
          - id: scale_range
            kwargs:
              axes: ['y', 'x']
              max_percentile: 99.8
              min_percentile: 5.0
          - id: clip
            kwargs:
              min: 0.0
              max: 1.0
        ```
        - in Python

        >>> preprocessing = [
        ...     ScaleRangeDescr(
        ...         kwargs=ScaleRangeKwargs(
        ...             axes= (AxisId('y'), AxisId('x')),
        ...             max_percentile= 99.8,
        ...             min_percentile= 5.0,
        ...         )
        ...     ),
        ...     ClipDescr(
        ...         kwargs=ClipKwargs(
        ...             min=0.0,
        ...             max=1.0,
        ...         )
        ...     ),
        ... ]

    """

    implemented_id: ClassVar[Literal["scale_range"]] = "scale_range"
    if TYPE_CHECKING:
        # default only declared for static type checkers; at runtime `id`
        # presumably gets set from `implemented_id` by the base class — verify
        id: Literal["scale_range"] = "scale_range"
    else:
        id: Literal["scale_range"]
    # `model_construct` builds the default kwargs without running validation
    kwargs: ScaleRangeKwargs = Field(default_factory=ScaleRangeKwargs.model_construct)
class ScaleMeanVarianceKwargs(KwargsNode):
    """key word arguments for [ScaleMeanVarianceDescr][]"""

    reference_tensor: TensorId
    """ID of unprocessed input tensor to match."""

    axes: Annotated[
        Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
    ] = None
    """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std.
    For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
    resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
    To normalize samples independently, leave out the 'batch' axis.
    Default: Scale all axes jointly."""

    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
    """Epsilon for numeric stability:
    `out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.`"""
class ScaleMeanVarianceDescr(NodeWithExplicitlySetFields):
    """Scale a tensor's data distribution to match another tensor's mean/std.
    `out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.`
    """

    implemented_id: ClassVar[Literal["scale_mean_variance"]] = "scale_mean_variance"
    if TYPE_CHECKING:
        # default only declared for static type checkers; at runtime `id`
        # presumably gets set from `implemented_id` by the base class — verify
        id: Literal["scale_mean_variance"] = "scale_mean_variance"
    else:
        id: Literal["scale_mean_variance"]
    kwargs: ScaleMeanVarianceKwargs
# tagged union of all processing steps allowed on input tensors;
# pydantic dispatches on the `id` field via the discriminator
PreprocessingDescr = Annotated[
    Union[
        BinarizeDescr,
        ClipDescr,
        EnsureDtypeDescr,
        FixedZeroMeanUnitVarianceDescr,
        ScaleLinearDescr,
        ScaleRangeDescr,
        SigmoidDescr,
        SoftmaxDescr,
        ZeroMeanUnitVarianceDescr,
    ],
    Discriminator("id"),
]
# tagged union of all processing steps allowed on output tensors; superset of
# `PreprocessingDescr` adding scale_mean_variance and stardist_postprocessing
PostprocessingDescr = Annotated[
    Union[
        BinarizeDescr,
        ClipDescr,
        EnsureDtypeDescr,
        FixedZeroMeanUnitVarianceDescr,
        ScaleLinearDescr,
        ScaleMeanVarianceDescr,
        ScaleRangeDescr,
        SigmoidDescr,
        SoftmaxDescr,
        StardistPostprocessingDescr,
        ZeroMeanUnitVarianceDescr,
    ],
    Discriminator("id"),
]
1678IO_AxisT = TypeVar("IO_AxisT", InputAxis, OutputAxis)
# common base of `InputTensorDescr` and `OutputTensorDescr`
class TensorDescrBase(Node, Generic[IO_AxisT]):
    id: TensorId
    """Tensor id. No duplicates are allowed."""

    description: Annotated[str, MaxLen(128)] = ""
    """free text description"""

    axes: NotEmpty[Sequence[IO_AxisT]]
    """tensor axes"""

    @property
    def shape(self):
        """Tuple of the `size` entries of all axes (in axes order)."""
        return tuple(a.size for a in self.axes)

    @field_validator("axes", mode="after", check_fields=False)
    @classmethod
    def _validate_axes(cls, axes: Sequence[AnyAxis]) -> Sequence[AnyAxis]:
        """Enforce at most one batch axis and unique axis ids per tensor."""
        batch_axes = [a for a in axes if a.type == "batch"]
        if len(batch_axes) > 1:
            raise ValueError(
                f"Only one batch axis (per tensor) allowed, but got {batch_axes}"
            )

        # collect ids seen more than once
        seen_ids: Set[AxisId] = set()
        duplicate_axes_ids: Set[AxisId] = set()
        for a in axes:
            (duplicate_axes_ids if a.id in seen_ids else seen_ids).add(a.id)

        if duplicate_axes_ids:
            raise ValueError(f"Duplicate axis ids: {duplicate_axes_ids}")

        return axes

    test_tensor: FAIR[Optional[FileDescr_]] = None
    """An example tensor to use for testing.
    Using the model with the test input tensors is expected to yield the test output tensors.
    Each test tensor has to be an ndarray in the
    [numpy.lib file format](https://numpy.org/doc/stable/reference/generated/numpy.lib.format.html#module-numpy.lib.format).
    The file extension must be '.npy'."""

    sample_tensor: FAIR[Optional[FileDescr_]] = None
    """A sample tensor to illustrate a possible input/output for the model.
    The sample image primarily serves to inform a human user about an example use case
    and is typically stored as .hdf5, .png or .tiff.
    It has to be readable by the [imageio library](https://imageio.readthedocs.io/en/stable/formats/index.html#supported-formats)
    (numpy's `.npy` format is not supported).
    The image dimensionality has to match the number of axes specified in this tensor description.
    """

    @model_validator(mode="after")
    def _validate_sample_tensor(self) -> Self:
        """Check the sample image's dimensionality against the declared axes.

        Skipped if no sample tensor is given or IO checks are disabled.
        """
        if self.sample_tensor is None or not get_validation_context().perform_io_checks:
            return self

        reader = get_reader(self.sample_tensor.source, sha256=self.sample_tensor.sha256)
        tensor: NDArray[Any] = imread(  # pyright: ignore[reportUnknownVariableType]
            reader.read(),
            extension=PurePosixPath(reader.original_file_name).suffix,
        )
        # the image is squeezed before counting, so any axis that may be a
        # singleton reduces the lower bound of acceptable dimensionality
        n_dims = len(tensor.squeeze().shape)
        n_dims_min = n_dims_max = len(self.axes)

        for a in self.axes:
            if isinstance(a, BatchAxis):
                n_dims_min -= 1
            elif isinstance(a.size, int):
                if a.size == 1:
                    n_dims_min -= 1
            elif isinstance(a.size, (ParameterizedSize, DataDependentSize)):
                if a.size.min == 1:
                    n_dims_min -= 1
            elif isinstance(a.size, SizeReference):
                if a.size.offset < 2:
                    # size reference may result in singleton axis
                    n_dims_min -= 1
            else:
                assert_never(a.size)

        n_dims_min = max(0, n_dims_min)
        if n_dims < n_dims_min or n_dims > n_dims_max:
            raise ValueError(
                f"Expected sample tensor to have {n_dims_min} to"
                + f" {n_dims_max} dimensions, but found {n_dims} (shape: {tensor.shape})."
            )

        return self

    data: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]] = (
        IntervalOrRatioDataDescr()
    )
    """Description of the tensor's data values, optionally per channel.
    If specified per channel, the data `type` needs to match across channels."""

    @property
    def dtype(
        self,
    ) -> Literal[
        "float32",
        "float64",
        "uint8",
        "int8",
        "uint16",
        "int16",
        "uint32",
        "int32",
        "uint64",
        "int64",
        "bool",
    ]:
        """dtype as specified under `data.type` or `data[i].type`"""
        if isinstance(self.data, collections.abc.Sequence):
            # per-channel descriptions agree in `type`
            # (enforced by `_check_data_type_across_channels`)
            return self.data[0].type
        else:
            return self.data.type

    @field_validator("data", mode="after")
    @classmethod
    def _check_data_type_across_channels(
        cls, value: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]]
    ) -> Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]]:
        """Require a single data `type` across per-channel data descriptions."""
        if not isinstance(value, list):
            return value

        dtypes = {t.type for t in value}
        if len(dtypes) > 1:
            raise ValueError(
                "Tensor data descriptions per channel need to agree in their data"
                + f" `type`, but found {dtypes}."
            )

        return value

    @model_validator(mode="after")
    def _check_data_matches_channelaxis(self) -> Self:
        """Require one data description per channel if given as a sequence."""
        if not isinstance(self.data, (list, tuple)):
            return self

        for a in self.axes:
            if isinstance(a, ChannelAxis):
                size = a.size
                assert isinstance(size, int)
                break
        else:
            # no channel axis: nothing to compare against
            return self

        if len(self.data) != size:
            raise ValueError(
                f"Got tensor data descriptions for {len(self.data)} channels, but"
                + f" '{a.id}' axis has size {size}."
            )

        return self

    def get_axis_sizes_for_array(self, array: NDArray[Any]) -> Dict[AxisId, int]:
        """Map each axis id to the size of the corresponding dimension of `array`.

        Raises:
            ValueError: If `array` does not have exactly one dimension per axis.
        """
        if len(array.shape) != len(self.axes):
            raise ValueError(
                f"Dimension mismatch: array shape {array.shape} (#{len(array.shape)})"
                + f" incompatible with {len(self.axes)} axes."
            )
        return {a.id: array.shape[i] for i, a in enumerate(self.axes)}
class InputTensorDescr(TensorDescrBase[InputAxis]):
    id: TensorId = TensorId("input")
    """Input tensor id.
    No duplicates are allowed across all inputs and outputs."""

    optional: bool = False
    """indicates that this tensor may be `None`"""

    preprocessing: List[PreprocessingDescr] = Field(
        default_factory=cast(Callable[[], List[PreprocessingDescr]], list)
    )
    """Description of how this input should be preprocessed.

    notes:
    - If preprocessing does not start with an 'ensure_dtype' entry, it is added
      to ensure an input tensor's data type matches the input tensor's data description.
    - If preprocessing does not end with an 'ensure_dtype' or 'binarize' entry, an
      'ensure_dtype' step is added to ensure preprocessing steps are not unintentionally
      changing the data type.
    """

    @model_validator(mode="after")
    def _validate_preprocessing_kwargs(self) -> Self:
        """Check step kwargs against this tensor's axes and pad `preprocessing`
        with 'ensure_dtype' steps as documented on the field above."""
        # any `axes` referenced by a preprocessing step must exist on this tensor
        axes_ids = [a.id for a in self.axes]
        for p in self.preprocessing:
            kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes")
            if kwargs_axes is None:
                continue

            if not isinstance(kwargs_axes, collections.abc.Sequence):
                raise ValueError(
                    f"Expected `preprocessing.i.kwargs.axes` to be a sequence, but got {type(kwargs_axes)}"
                )

            if any(a not in axes_ids for a in kwargs_axes):
                raise ValueError(
                    "`preprocessing.i.kwargs.axes` needs to be subset of axes ids"
                )

        if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)):
            dtype = self.data.type
        else:
            # per-channel data descriptions agree in `type` (validated in base class)
            dtype = self.data[0].type

        # ensure `preprocessing` begins with `EnsureDtypeDescr`
        if not self.preprocessing or not isinstance(
            self.preprocessing[0], EnsureDtypeDescr
        ):
            self.preprocessing.insert(
                0, EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype))
            )

        # ensure `preprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr`
        if not isinstance(self.preprocessing[-1], (EnsureDtypeDescr, BinarizeDescr)):
            self.preprocessing.append(
                EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype))
            )

        return self
def convert_axes(
    axes: str,
    *,
    shape: Union[
        Sequence[int], _ParameterizedInputShape_v0_4, _ImplicitOutputShape_v0_4
    ],
    tensor_type: Literal["input", "output"],
    halo: Optional[Sequence[int]],
    size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]],
):
    """Convert a v0.4 axes letter string plus shape into v0.5 axis descriptions.

    Args:
        axes: v0.4 axis letters, one character per axis (e.g. "bcyx").
        shape: v0.4 shape specification: explicit sizes, a parameterized input
            shape, or an output shape given implicitly via a reference tensor.
        tensor_type: whether these axes describe an input or an output tensor.
        halo: optional per-axis halo values (used for output axes only).
        size_refs: known axis sizes per tensor, used to resolve v0.4
            scale/offset relations into fixed v0.5 size offsets.

    Returns:
        List of v0.5 axis descriptions, one per character of `axes`.
    """
    ret: List[AnyAxis] = []
    for i, a in enumerate(axes):
        # fall back to the letter itself if it is not a known axis type
        axis_type = _AXIS_TYPE_MAP.get(a, a)
        if axis_type == "batch":
            ret.append(BatchAxis())
            continue

        scale = 1.0
        if isinstance(shape, _ParameterizedInputShape_v0_4):
            # step == 0 marks a fixed-size axis
            if shape.step[i] == 0:
                size = shape.min[i]
            else:
                size = ParameterizedSize(min=shape.min[i], step=shape.step[i])
        elif isinstance(shape, _ImplicitOutputShape_v0_4):
            # v0.4 may address a specific axis as "<tensor>.<axis>"
            ref_t = str(shape.reference_tensor)
            if ref_t.count(".") == 1:
                t_id, orig_a_id = ref_t.split(".")
            else:
                t_id = ref_t
                orig_a_id = a

            a_id = _AXIS_ID_MAP.get(orig_a_id, a)
            if not (orig_scale := shape.scale[i]):
                # old way to insert a new axis dimension
                size = int(2 * shape.offset[i])
            else:
                scale = 1 / orig_scale
                if axis_type in ("channel", "index"):
                    # these axes no longer have a scale
                    offset_from_scale = orig_scale * size_refs.get(
                        _TensorName_v0_4(t_id), {}
                    ).get(orig_a_id, 0)
                else:
                    offset_from_scale = 0
                size = SizeReference(
                    tensor_id=TensorId(t_id),
                    axis_id=AxisId(a_id),
                    offset=int(offset_from_scale + 2 * shape.offset[i]),
                )
        else:
            # explicit shape: plain integer size
            size = shape[i]

        if axis_type == "time":
            if tensor_type == "input":
                ret.append(TimeInputAxis(size=size, scale=scale))
            else:
                assert not isinstance(size, ParameterizedSize)
                if halo is None:
                    ret.append(TimeOutputAxis(size=size, scale=scale))
                else:
                    assert not isinstance(size, int)
                    ret.append(
                        TimeOutputAxisWithHalo(size=size, scale=scale, halo=halo[i])
                    )

        elif axis_type == "index":
            if tensor_type == "input":
                ret.append(IndexInputAxis(size=size))
            else:
                # output index axes may grow with the data
                if isinstance(size, ParameterizedSize):
                    size = DataDependentSize(min=size.min)

                ret.append(IndexOutputAxis(size=size))
        elif axis_type == "channel":
            assert not isinstance(size, ParameterizedSize)
            if isinstance(size, SizeReference):
                warnings.warn(
                    "Conversion of channel size from an implicit output shape may be"
                    + " wrong"
                )
                ret.append(
                    ChannelAxis(
                        channel_names=[
                            Identifier(f"channel{i}") for i in range(size.offset)
                        ]
                    )
                )
            else:
                # generate generic channel names "channel0", "channel1", ...
                ret.append(
                    ChannelAxis(
                        channel_names=[Identifier(f"channel{i}") for i in range(size)]
                    )
                )
        elif axis_type == "space":
            if tensor_type == "input":
                ret.append(SpaceInputAxis(id=AxisId(a), size=size, scale=scale))
            else:
                assert not isinstance(size, ParameterizedSize)
                if halo is None or halo[i] == 0:
                    ret.append(SpaceOutputAxis(id=AxisId(a), size=size, scale=scale))
                elif isinstance(size, int):
                    raise NotImplementedError(
                        f"output axis with halo and fixed size (here {size}) not allowed"
                    )
                else:
                    ret.append(
                        SpaceOutputAxisWithHalo(
                            id=AxisId(a), size=size, scale=scale, halo=halo[i]
                        )
                    )

    return ret
def _axes_letters_to_ids(
    axes: Optional[str],
) -> Optional[List[AxisId]]:
    """Map each character of a v0.4 axes string to an `AxisId`; pass through `None`."""
    return None if axes is None else [AxisId(letter) for letter in axes]
def _get_complement_v04_axis(
    tensor_axes: Sequence[str], axes: Optional[Sequence[str]]
) -> Optional[AxisId]:
    """Find the single tensor axis contained in neither `axes` nor batch ("b").

    Returns `None` if `axes` is `None` or if no axis remains.

    Raises:
        ValueError: If more than one axis remains.
    """
    if axes is None:
        return None

    excluded = {*axes, "b"}
    remaining = [candidate for candidate in tensor_axes if candidate not in excluded]

    if len(remaining) > 1:
        raise ValueError(
            f"Expected none or a single complement axis, but axes '{axes}' "
            + f"for tensor dims '{tensor_axes}' leave '{remaining}'."
        )

    if not remaining:
        return None

    return AxisId(remaining[0])
def _convert_proc(
    p: Union[_PreprocessingDescr_v0_4, _PostprocessingDescr_v0_4],
    tensor_axes: Sequence[str],
) -> Union[PreprocessingDescr, PostprocessingDescr]:
    """Convert a single v0.4 (pre/post)processing description to v0.5.

    Args:
        p: the v0.4 processing description.
        tensor_axes: v0.4 axis letters of the tensor `p` belongs to.

    Raises:
        ValueError: for list-valued fixed mean/std without an identifiable
            complement axis (or if the complement axis is ambiguous).
    """
    if isinstance(p, _BinarizeDescr_v0_4):
        return BinarizeDescr(kwargs=BinarizeKwargs(threshold=p.kwargs.threshold))
    elif isinstance(p, _ClipDescr_v0_4):
        return ClipDescr(kwargs=ClipKwargs(min=p.kwargs.min, max=p.kwargs.max))
    elif isinstance(p, _SigmoidDescr_v0_4):
        return SigmoidDescr()
    elif isinstance(p, _ScaleLinearDescr_v0_4):
        # v0.5 expresses per-entry scaling via the single axis complementary to
        # the axes scaled jointly in v0.4 (returns None for `axes is None`)
        axis = _get_complement_v04_axis(tensor_axes, p.kwargs.axes)
        if axis is None:
            assert not isinstance(p.kwargs.gain, list)
            assert not isinstance(p.kwargs.offset, list)
            kwargs = ScaleLinearKwargs(gain=p.kwargs.gain, offset=p.kwargs.offset)
        else:
            kwargs = ScaleLinearAlongAxisKwargs(
                axis=axis, gain=p.kwargs.gain, offset=p.kwargs.offset
            )
        return ScaleLinearDescr(kwargs=kwargs)
    elif isinstance(p, _ScaleMeanVarianceDescr_v0_4):
        return ScaleMeanVarianceDescr(
            kwargs=ScaleMeanVarianceKwargs(
                axes=_axes_letters_to_ids(p.kwargs.axes),
                reference_tensor=TensorId(str(p.kwargs.reference_tensor)),
                eps=p.kwargs.eps,
            )
        )
    elif isinstance(p, _ZeroMeanUnitVarianceDescr_v0_4):
        if p.kwargs.mode == "fixed":
            mean = p.kwargs.mean
            std = p.kwargs.std
            assert mean is not None
            assert std is not None

            axis = _get_complement_v04_axis(tensor_axes, p.kwargs.axes)

            if axis is None:
                if isinstance(mean, list):
                    raise ValueError("Expected single float value for mean, not <list>")
                if isinstance(std, list):
                    raise ValueError("Expected single float value for std, not <list>")
                # model_construct: keep legacy values as-is without validation
                return FixedZeroMeanUnitVarianceDescr(
                    kwargs=FixedZeroMeanUnitVarianceKwargs.model_construct(
                        mean=mean,
                        std=std,
                    )
                )
            else:
                if not isinstance(mean, list):
                    mean = [float(mean)]
                if not isinstance(std, list):
                    std = [float(std)]

                return FixedZeroMeanUnitVarianceDescr(
                    kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs(
                        axis=axis, mean=mean, std=std
                    )
                )

        else:
            axes = _axes_letters_to_ids(p.kwargs.axes) or []
            if p.kwargs.mode == "per_dataset":
                # per-dataset normalization additionally reduces the batch axis
                axes = [AxisId("batch")] + axes
            if not axes:
                axes = None
            return ZeroMeanUnitVarianceDescr(
                kwargs=ZeroMeanUnitVarianceKwargs(axes=axes, eps=p.kwargs.eps)
            )

    elif isinstance(p, _ScaleRangeDescr_v0_4):
        return ScaleRangeDescr(
            kwargs=ScaleRangeKwargs(
                axes=_axes_letters_to_ids(p.kwargs.axes),
                min_percentile=p.kwargs.min_percentile,
                max_percentile=p.kwargs.max_percentile,
                eps=p.kwargs.eps,
            )
        )
    else:
        assert_never(p)
class _InputTensorConv(
    Converter[
        _InputTensorDescr_v0_4,
        InputTensorDescr,
        FileSource_,
        Optional[FileSource_],
        Mapping[_TensorName_v0_4, Mapping[str, int]],
    ]
):
    """Converter from v0.4 to v0.5 input tensor descriptions."""

    def _convert(
        self,
        src: _InputTensorDescr_v0_4,
        tgt: "type[InputTensorDescr] | type[dict[str, Any]]",
        test_tensor: FileSource_,
        sample_tensor: Optional[FileSource_],
        size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]],
    ) -> "InputTensorDescr | dict[str, Any]":
        """Build the v0.5 representation of `src`.

        Args:
            src: v0.4 input tensor description.
            tgt: target to instantiate (`InputTensorDescr` or plain `dict`).
            test_tensor: file source of the test tensor.
            sample_tensor: optional file source of a sample tensor.
            size_refs: known axis sizes per tensor (to resolve size references).
        """
        axes: List[InputAxis] = convert_axes(  # pyright: ignore[reportAssignmentType]
            src.axes,
            shape=src.shape,
            tensor_type="input",
            halo=None,
            size_refs=size_refs,
        )
        prep: List[PreprocessingDescr] = []
        for p in src.preprocessing:
            cp = _convert_proc(p, src.axes)
            # postprocessing-only steps must not appear among preprocessing
            assert not isinstance(
                cp, (ScaleMeanVarianceDescr, StardistPostprocessingDescr)
            )
            prep.append(cp)

        # NOTE(review): appends a final cast to float32 — presumably to match
        # how v0.4 consumers fed inputs to the model; verify against converter
        # callers
        prep.append(EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="float32")))

        return tgt(
            axes=axes,
            id=TensorId(str(src.name)),
            test_tensor=FileDescr(source=test_tensor),
            sample_tensor=(
                None if sample_tensor is None else FileDescr(source=sample_tensor)
            ),
            data=dict(type=src.data_type),  # pyright: ignore[reportArgumentType]
            preprocessing=prep,
        )
2180_input_tensor_conv = _InputTensorConv(_InputTensorDescr_v0_4, InputTensorDescr)
class OutputTensorDescr(TensorDescrBase[OutputAxis]):
    id: TensorId = TensorId("output")
    """Output tensor id.
    No duplicates are allowed across all inputs and outputs."""

    postprocessing: List[PostprocessingDescr] = Field(
        default_factory=cast(Callable[[], List[PostprocessingDescr]], list)
    )
    """Description of how this output should be postprocessed.

    note: `postprocessing` always ends with an 'ensure_dtype' operation.
          If not given this is added to cast to this tensor's `data.type`.
    """

    @model_validator(mode="after")
    def _validate_postprocessing_kwargs(self) -> Self:
        """Check step kwargs against this tensor's axes and append a final
        'ensure_dtype' step as documented on the field above."""
        # any `axes` referenced by a postprocessing step must exist on this tensor
        axes_ids = [a.id for a in self.axes]
        for p in self.postprocessing:
            kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes")
            if kwargs_axes is None:
                continue

            if not isinstance(kwargs_axes, collections.abc.Sequence):
                raise ValueError(
                    f"expected `axes` sequence, but got {type(kwargs_axes)}"
                )

            if any(a not in axes_ids for a in kwargs_axes):
                raise ValueError("`kwargs.axes` needs to be subset of axes ids")

        if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)):
            dtype = self.data.type
        else:
            # per-channel data descriptions agree in `type` (validated in base class)
            dtype = self.data[0].type

        # ensure `postprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr`
        if not self.postprocessing or not isinstance(
            self.postprocessing[-1], (EnsureDtypeDescr, BinarizeDescr)
        ):
            self.postprocessing.append(
                EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype))
            )
        return self
class _OutputTensorConv(
    Converter[
        _OutputTensorDescr_v0_4,
        OutputTensorDescr,
        FileSource_,
        Optional[FileSource_],
        Mapping[_TensorName_v0_4, Mapping[str, int]],
    ]
):
    def _convert(
        self,
        src: _OutputTensorDescr_v0_4,
        tgt: "type[OutputTensorDescr] | type[dict[str, Any]]",
        test_tensor: FileSource_,
        sample_tensor: Optional[FileSource_],
        size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]],
    ) -> "OutputTensorDescr | dict[str, Any]":
        """Translate a v0.4 output tensor description into its v0.5 counterpart."""
        # TODO: split convert_axes into convert_output_axes and convert_input_axes
        converted_axes: List[OutputAxis] = (
            convert_axes(  # pyright: ignore[reportAssignmentType]
                src.axes,
                shape=src.shape,
                tensor_type="output",
                halo=src.halo,
                size_refs=size_refs,
            )
        )

        data_descr: Dict[str, Any] = {"type": src.data_type}
        if src.data_type == "bool":
            # boolean outputs enumerate both possible values explicitly
            data_descr["values"] = [False, True]

        if sample_tensor is None:
            sample_descr = None
        else:
            sample_descr = FileDescr(source=sample_tensor)

        return tgt(
            axes=converted_axes,
            id=TensorId(str(src.name)),
            test_tensor=FileDescr(source=test_tensor),
            sample_tensor=sample_descr,
            data=data_descr,  # pyright: ignore[reportArgumentType]
            postprocessing=[_convert_proc(p, src.axes) for p in src.postprocessing],
        )
2269_output_tensor_conv = _OutputTensorConv(_OutputTensorDescr_v0_4, OutputTensorDescr)
2272TensorDescr = Union[InputTensorDescr, OutputTensorDescr]
def validate_tensors(
    tensors: Mapping[TensorId, Tuple[TensorDescr, Optional[NDArray[Any]]]],
    tensor_origin: Literal[
        "source", "test_tensor"
    ] = "source",  # for more precise error messages
):
    """Validate arrays against their tensor descriptions.

    For each `(description, array)` pair, checks (when `array` is not None)
    that the array's dtype and axis sizes match the description, including
    parameterized, data-dependent and cross-tensor (`SizeReference`) sizes,
    and that the values are large enough for reliable testing.

    Args:
        tensors: maps tensor id to its description and (optional) array.
        tensor_origin: where the arrays come from; only used in error messages.

    Raises:
        ValueError: if any array does not match its description.
    """
    all_tensor_axes: Dict[TensorId, Dict[AxisId, Tuple[AnyAxis, Optional[int]]]] = {}

    def e_msg_location(d: TensorDescr):
        # error message prefix, e.g. "inputs[raw]"
        return f"{'inputs' if isinstance(d, InputTensorDescr) else 'outputs'}[{d.id}]"

    # first pass: collect all axis sizes (needed to resolve `SizeReference`s
    # that may point at any other tensor)
    for descr, array in tensors.values():
        if array is None:
            axis_sizes = {a.id: None for a in descr.axes}
        else:
            try:
                axis_sizes = descr.get_axis_sizes_for_array(array)
            except ValueError as e:
                raise ValueError(f"{e_msg_location(descr)} {e}")

        all_tensor_axes[descr.id] = {a.id: (a, axis_sizes[a.id]) for a in descr.axes}

    # second pass: validate dtype, value magnitude and axis sizes
    for descr, array in tensors.values():
        if array is None:
            continue

        if descr.dtype in ("float32", "float64"):
            # integer arrays may be cast losslessly to the described float dtype
            invalid_test_tensor_dtype = array.dtype.name not in (
                "float32",
                "float64",
                "uint8",
                "int8",
                "uint16",
                "int16",
                "uint32",
                "int32",
                "uint64",
                "int64",
            )
        else:
            invalid_test_tensor_dtype = array.dtype.name != descr.dtype

        if invalid_test_tensor_dtype:
            raise ValueError(
                f"{tensor_origin} data type '{array.dtype.name}' does not"
                + f" match described {e_msg_location(descr)}.dtype '{descr.dtype}'"
            )

        # all values within (-1e-4, 1e-4) are too close to zero for a
        # meaningful closeness check of reproduced outputs
        if array.min() > -1e-4 and array.max() < 1e-4:
            raise ValueError(
                "Output values are too small for reliable testing."
                # fixed: message previously claimed thresholds of 1e5,
                # contradicting the ±1e-4 condition above
                + f" Values <=-1e-4 or >=1e-4 must be present in {tensor_origin}"
            )

        for a in descr.axes:
            actual_size = all_tensor_axes[descr.id][a.id][1]
            if actual_size is None:
                continue

            if a.size is None:
                continue

            if isinstance(a.size, int):
                if actual_size != a.size:
                    raise ValueError(
                        f"{e_msg_location(descr)}.axes[{a.id}]: {tensor_origin} axis "
                        + f"has incompatible size {actual_size}, expected {a.size}"
                    )
            elif isinstance(a.size, ParameterizedSize):
                _ = a.size.validate_size(
                    actual_size,
                    f"{e_msg_location(descr)}.axes[{a.id}]: {tensor_origin} axis ",
                )
            elif isinstance(a.size, DataDependentSize):
                _ = a.size.validate_size(
                    actual_size,
                    f"{e_msg_location(descr)}.axes[{a.id}]: {tensor_origin} axis ",
                )
            elif isinstance(a.size, SizeReference):
                ref_tensor_axes = all_tensor_axes.get(a.size.tensor_id)
                if ref_tensor_axes is None:
                    raise ValueError(
                        f"{e_msg_location(descr)}.axes[{a.id}].size.tensor_id: Unknown tensor"
                        + f" reference '{a.size.tensor_id}', available: {list(all_tensor_axes)}"
                    )

                ref_axis, ref_size = ref_tensor_axes.get(a.size.axis_id, (None, None))
                if ref_axis is None or ref_size is None:
                    raise ValueError(
                        f"{e_msg_location(descr)}.axes[{a.id}].size.axis_id: Unknown tensor axis"
                        # fixed: closing quote of the reference was missing
                        + f" reference '{a.size.tensor_id}.{a.size.axis_id}', available: {list(ref_tensor_axes)}"
                    )

                if a.unit != ref_axis.unit:
                    raise ValueError(
                        f"{e_msg_location(descr)}.axes[{a.id}].size: `SizeReference` requires"
                        + " axis and reference axis to have the same `unit`, but"
                        + f" {a.unit}!={ref_axis.unit}"
                    )

                if actual_size != (
                    expected_size := (
                        ref_size * ref_axis.scale / a.scale + a.size.offset
                    )
                ):
                    raise ValueError(
                        f"{e_msg_location(descr)}.{tensor_origin}: axis '{a.id}' of size"
                        + f" {actual_size} invalid for referenced size {ref_size};"
                        + f" expected {expected_size}"
                    )
            else:
                assert_never(a.size)
# file description restricted to conda environment files (.yaml/.yml)
FileDescr_dependencies = Annotated[
    FileDescr_,
    WithSuffix((".yaml", ".yml"), case_sensitive=True),
    Field(examples=[dict(source="environment.yaml")]),
]
class _ArchitectureCallableDescr(Node):
    # common base for architecture descriptions given as a callable
    callable: Annotated[Identifier, Field(examples=["MyNetworkClass", "get_my_model"])]
    """Identifier of the callable that returns a torch.nn.Module instance."""

    kwargs: Dict[str, YamlValue] = Field(
        default_factory=cast(Callable[[], Dict[str, YamlValue]], dict)
    )
    """key word arguments for the `callable`"""
class ArchitectureFromFileDescr(_ArchitectureCallableDescr, FileDescr):
    # architecture defined by a callable in a (Python) source file
    source: Annotated[FileSource, AfterValidator(wo_special_file_name)]
    """Architecture source file"""

    @model_serializer(mode="wrap", when_used="unless-none")
    def _serialize(self, nxt: SerializerFunctionWrapHandler, info: SerializationInfo):
        # include the referenced source file when serializing for packaging
        return package_file_descr_serializer(self, nxt, info)
class ArchitectureFromLibraryDescr(_ArchitectureCallableDescr):
    # architecture defined by a callable importable from an installed library
    import_from: str
    """Where to import the callable from, i.e. `from <import_from> import <callable>`"""
class _ArchFileConv(
    Converter[
        _CallableFromFile_v0_4,
        ArchitectureFromFileDescr,
        Optional[Sha256],
        Dict[str, Any],
    ]
):
    def _convert(
        self,
        src: _CallableFromFile_v0_4,
        tgt: "type[ArchitectureFromFileDescr | dict[str, Any]]",
        sha256: Optional[Sha256],
        kwargs: Dict[str, Any],
    ) -> "ArchitectureFromFileDescr | dict[str, Any]":
        """Split a v0.4 `<file>:<callable>` reference into source and callable."""
        is_url = src.startswith("http")
        n_colons = src.count(":")
        if (is_url and n_colons == 2) or (not is_url and n_colons == 1):
            # the callable name follows the last colon; everything before it
            # (including a possible URL scheme colon) is the file source
            source, callable_ = src.rsplit(":", 1)
        else:
            # no recognizable separator: keep the raw value for both fields
            source = str(src)
            callable_ = str(src)
        return tgt(
            callable=Identifier(callable_),
            source=cast(FileSource_, source),
            sha256=sha256,
            kwargs=kwargs,
        )
2451_arch_file_conv = _ArchFileConv(_CallableFromFile_v0_4, ArchitectureFromFileDescr)
class _ArchLibConv(
    Converter[
        _CallableFromDepencency_v0_4, ArchitectureFromLibraryDescr, Dict[str, Any]
    ]
):
    def _convert(
        self,
        src: _CallableFromDepencency_v0_4,
        tgt: "type[ArchitectureFromLibraryDescr | dict[str, Any]]",
        kwargs: Dict[str, Any],
    ) -> "ArchitectureFromLibraryDescr | dict[str, Any]":
        """Split a dotted v0.4 module path into import location and callable name."""
        # the callable is the last dotted segment; the rest is the module path
        # (empty when `src` contains no dot, matching the v0.4 behavior)
        import_from, _, callable_ = src.rpartition(".")
        return tgt(
            import_from=import_from, callable=Identifier(callable_), kwargs=kwargs
        )
# module-level converter instance used to migrate v0.4 callable-from-dependency references
_arch_lib_conv = _ArchLibConv(
    _CallableFromDepencency_v0_4, ArchitectureFromLibraryDescr
)
class WeightsEntryDescrBase(FileDescr):
    # common base for all weights entries; subclasses pin `type` and
    # `weights_format_name` as class variables
    type: ClassVar[WeightsFormat]
    weights_format_name: ClassVar[str]  # human readable

    source: Annotated[FileSource, AfterValidator(wo_special_file_name)]
    """Source of the weights file."""

    authors: Optional[List[Author]] = None
    """Authors
    Either the person(s) that have trained this model resulting in the original weights file.
    (If this is the initial weights entry, i.e. it does not have a `parent`)
    Or the person(s) who have converted the weights to this weights format.
    (If this is a child weight, i.e. it has a `parent` field)
    """

    parent: Annotated[
        Optional[WeightsFormat], Field(examples=["pytorch_state_dict"])
    ] = None
    """The source weights these weights were converted from.
    For example, if a model's weights were converted from the `pytorch_state_dict` format to `torchscript`,
    The `pytorch_state_dict` weights entry has no `parent` and is the parent of the `torchscript` weights.
    All weight entries except one (the initial set of weights resulting from training the model),
    need to have this field."""

    comment: str = ""
    """A comment about this weights entry, for example how these weights were created."""

    @model_validator(mode="after")
    def _validate(self) -> Self:
        # an entry referencing its own format as parent would be circular
        if self.type == self.parent:
            raise ValueError("Weights entry can't be it's own parent.")

        return self

    @model_serializer(mode="wrap", when_used="unless-none")
    def _serialize(self, nxt: SerializerFunctionWrapHandler, info: SerializationInfo):
        # include the weights file when serializing for packaging
        return package_file_descr_serializer(self, nxt, info)
class KerasHdf5WeightsDescr(WeightsEntryDescrBase):
    # weights in the legacy Keras HDF5 format
    type: ClassVar[WeightsFormat] = "keras_hdf5"
    weights_format_name: ClassVar[str] = "Keras HDF5"
    tensorflow_version: Version
    """TensorFlow version used to create these weights."""
class KerasV3WeightsDescr(WeightsEntryDescrBase):
    # weights in the Keras v3 (.keras) format
    type: ClassVar[WeightsFormat] = "keras_v3"
    weights_format_name: ClassVar[str] = "Keras v3"
    keras_version: Annotated[Version, Ge(Version(3))]
    """Keras version used to create these weights."""
    backend: Tuple[Literal["tensorflow", "jax", "torch"], Version]
    """Keras backend used to create these weights."""
    source: Annotated[
        FileSource,
        AfterValidator(wo_special_file_name),
        WithSuffix(".keras", case_sensitive=True),
    ]
    """Source of the .keras weights file."""
# file description restricted to ONNX external data files (.data suffix)
FileDescr_external_data = Annotated[
    FileDescr_,
    WithSuffix(".data", case_sensitive=True),
    Field(examples=[dict(source="weights.onnx.data")]),
]
class OnnxWeightsDescr(WeightsEntryDescrBase):
    # weights in the ONNX format, optionally split into architecture + external data
    type: ClassVar[WeightsFormat] = "onnx"
    weights_format_name: ClassVar[str] = "ONNX"
    opset_version: Annotated[int, Ge(7)]
    """ONNX opset version"""

    external_data: Optional[FileDescr_external_data] = None
    """Source of the external ONNX data file holding the weights.
    (If present **source** holds the ONNX architecture without weights)."""

    @model_validator(mode="after")
    def _validate_external_data_unique_file_name(self) -> Self:
        # reject identical file names for `source` and `external_data`
        # (presumably both end up side by side when packaged — TODO confirm)
        if self.external_data is not None and (
            extract_file_name(self.source)
            == extract_file_name(self.external_data.source)
        ):
            raise ValueError(
                f"ONNX `external_data` file name '{extract_file_name(self.external_data.source)}'"
                + " must be different from ONNX `source` file name."
            )

        return self
class PytorchStateDictWeightsDescr(WeightsEntryDescrBase):
    # weights as a PyTorch state dict; requires an architecture description
    type: ClassVar[WeightsFormat] = "pytorch_state_dict"
    weights_format_name: ClassVar[str] = "Pytorch State Dict"
    architecture: Union[ArchitectureFromFileDescr, ArchitectureFromLibraryDescr]
    pytorch_version: Version
    """Version of the PyTorch library used.
    If `architecture.dependencies` is specified it has to include pytorch and any version pinning has to be compatible.
    """
    dependencies: Optional[FileDescr_dependencies] = None
    """Custom dependencies beyond pytorch described in a Conda environment file.
    Allows to specify custom dependencies, see conda docs:
    - [Exporting an environment file across platforms](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#exporting-an-environment-file-across-platforms)
    - [Creating an environment file manually](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#creating-an-environment-file-manually)

    The conda environment file should include pytorch and any version pinning has to be compatible with
    **pytorch_version**.
    """
class TensorflowJsWeightsDescr(WeightsEntryDescrBase):
    # weights in the Tensorflow.js format (zipped multi-file bundle)
    type: ClassVar[WeightsFormat] = "tensorflow_js"
    weights_format_name: ClassVar[str] = "Tensorflow.js"
    tensorflow_version: Version
    """Version of the TensorFlow library used."""

    source: Annotated[FileSource, AfterValidator(wo_special_file_name)]
    """The multi-file weights.
    All required files/folders should be a zip archive."""
class TensorflowSavedModelBundleWeightsDescr(WeightsEntryDescrBase):
    # weights as a Tensorflow SavedModel bundle (zipped multi-file bundle)
    type: ClassVar[WeightsFormat] = "tensorflow_saved_model_bundle"
    weights_format_name: ClassVar[str] = "Tensorflow Saved Model"
    tensorflow_version: Version
    """Version of the TensorFlow library used."""

    dependencies: Optional[FileDescr_dependencies] = None
    """Custom dependencies beyond tensorflow.
    Should include tensorflow and any version pinning has to be compatible with **tensorflow_version**."""

    source: Annotated[FileSource, AfterValidator(wo_special_file_name)]
    """The multi-file weights.
    All required files/folders should be a zip archive."""
class TorchscriptWeightsDescr(WeightsEntryDescrBase):
    # weights as a TorchScript (traced/scripted) module
    type: ClassVar[WeightsFormat] = "torchscript"
    weights_format_name: ClassVar[str] = "TorchScript"
    pytorch_version: Version
    """Version of the PyTorch library used."""
# union of all concrete weights entry descriptions
SpecificWeightsDescr = Union[
    KerasHdf5WeightsDescr,
    KerasV3WeightsDescr,
    OnnxWeightsDescr,
    PytorchStateDictWeightsDescr,
    TensorflowJsWeightsDescr,
    TensorflowSavedModelBundleWeightsDescr,
    TorchscriptWeightsDescr,
]
class WeightsDescr(Node):
    # one optional entry per supported weights format; at least one must be set
    keras_hdf5: Optional[KerasHdf5WeightsDescr] = None
    keras_v3: Optional[KerasV3WeightsDescr] = None
    onnx: Optional[OnnxWeightsDescr] = None
    pytorch_state_dict: Optional[PytorchStateDictWeightsDescr] = None
    tensorflow_js: Optional[TensorflowJsWeightsDescr] = None
    tensorflow_saved_model_bundle: Optional[TensorflowSavedModelBundleWeightsDescr] = (
        None
    )
    torchscript: Optional[TorchscriptWeightsDescr] = None

    @model_validator(mode="after")
    def check_entries(self) -> Self:
        """Validate the set of weights entries.

        Raises:
            ValueError: if no entry is set at all, or an entry's `parent`
                references a weights format that is not present.

        Warns (via `issue_warning`) if not exactly one entry is without a
        `parent` (the original training result).
        """
        entries = {wtype for wtype, entry in self if entry is not None}

        if not entries:
            raise ValueError("Missing weights entry")

        entries_wo_parent = {
            wtype
            for wtype, entry in self
            if entry is not None and hasattr(entry, "parent") and entry.parent is None
        }
        if len(entries_wo_parent) != 1:
            issue_warning(
                "Exactly one weights entry may not specify the `parent` field (got"
                + " {value}). That entry is considered the original set of model weights."
                + " Other weight formats are created through conversion of the orignal or"
                + " already converted weights. They have to reference the weights format"
                + " they were converted from as their `parent`.",
                value=len(entries_wo_parent),
                field="weights",
            )

        for wtype, entry in self:
            if entry is None:
                continue

            assert hasattr(entry, "type")
            assert hasattr(entry, "parent")
            assert wtype == entry.type
            if (
                entry.parent is not None and entry.parent not in entries
            ):  # self reference checked for `parent` field
                raise ValueError(
                    f"`weights.{wtype}.parent={entry.parent} not in specified weight"
                    + f" formats: {entries}"
                )

        return self

    def __getitem__(
        self,
        key: WeightsFormat,
    ):
        """Return the entry for weights format `key`.

        Raises:
            KeyError: for an unknown key, and also when the entry for a known
                key is unset (None).
        """
        if key == "keras_hdf5":
            ret = self.keras_hdf5
        elif key == "keras_v3":
            ret = self.keras_v3
        elif key == "onnx":
            ret = self.onnx
        elif key == "pytorch_state_dict":
            ret = self.pytorch_state_dict
        elif key == "tensorflow_js":
            ret = self.tensorflow_js
        elif key == "tensorflow_saved_model_bundle":
            ret = self.tensorflow_saved_model_bundle
        elif key == "torchscript":
            ret = self.torchscript
        else:
            raise KeyError(key)

        if ret is None:
            raise KeyError(key)

        return ret

    @overload
    def __setitem__(
        self, key: Literal["keras_hdf5"], value: Optional[KerasHdf5WeightsDescr]
    ) -> None: ...
    @overload
    def __setitem__(
        self, key: Literal["keras_v3"], value: Optional[KerasV3WeightsDescr]
    ) -> None: ...
    @overload
    def __setitem__(
        self, key: Literal["onnx"], value: Optional[OnnxWeightsDescr]
    ) -> None: ...
    @overload
    def __setitem__(
        self,
        key: Literal["pytorch_state_dict"],
        value: Optional[PytorchStateDictWeightsDescr],
    ) -> None: ...
    @overload
    def __setitem__(
        self, key: Literal["tensorflow_js"], value: Optional[TensorflowJsWeightsDescr]
    ) -> None: ...
    @overload
    def __setitem__(
        self,
        key: Literal["tensorflow_saved_model_bundle"],
        value: Optional[TensorflowSavedModelBundleWeightsDescr],
    ) -> None: ...
    @overload
    def __setitem__(
        self, key: Literal["torchscript"], value: Optional[TorchscriptWeightsDescr]
    ) -> None: ...

    def __setitem__(
        self,
        key: WeightsFormat,
        value: Optional[SpecificWeightsDescr],
    ):
        """Set (or clear, with None) the entry for weights format `key`.

        Raises:
            TypeError: if `value` is not the descriptor type matching `key`.
            KeyError: for an unknown key.
        """
        if key == "keras_hdf5":
            if value is not None and not isinstance(value, KerasHdf5WeightsDescr):
                raise TypeError(
                    f"Expected KerasHdf5WeightsDescr or None for key 'keras_hdf5', got {type(value)}"
                )
            self.keras_hdf5 = value
        elif key == "keras_v3":
            if value is not None and not isinstance(value, KerasV3WeightsDescr):
                raise TypeError(
                    f"Expected KerasV3WeightsDescr or None for key 'keras_v3', got {type(value)}"
                )
            self.keras_v3 = value
        elif key == "onnx":
            if value is not None and not isinstance(value, OnnxWeightsDescr):
                raise TypeError(
                    f"Expected OnnxWeightsDescr or None for key 'onnx', got {type(value)}"
                )
            self.onnx = value
        elif key == "pytorch_state_dict":
            if value is not None and not isinstance(
                value, PytorchStateDictWeightsDescr
            ):
                raise TypeError(
                    f"Expected PytorchStateDictWeightsDescr or None for key 'pytorch_state_dict', got {type(value)}"
                )
            self.pytorch_state_dict = value
        elif key == "tensorflow_js":
            if value is not None and not isinstance(value, TensorflowJsWeightsDescr):
                raise TypeError(
                    f"Expected TensorflowJsWeightsDescr or None for key 'tensorflow_js', got {type(value)}"
                )
            self.tensorflow_js = value
        elif key == "tensorflow_saved_model_bundle":
            if value is not None and not isinstance(
                value, TensorflowSavedModelBundleWeightsDescr
            ):
                raise TypeError(
                    f"Expected TensorflowSavedModelBundleWeightsDescr or None for key 'tensorflow_saved_model_bundle', got {type(value)}"
                )
            self.tensorflow_saved_model_bundle = value
        elif key == "torchscript":
            if value is not None and not isinstance(value, TorchscriptWeightsDescr):
                raise TypeError(
                    f"Expected TorchscriptWeightsDescr or None for key 'torchscript', got {type(value)}"
                )
            self.torchscript = value
        else:
            raise KeyError(key)

    @property
    def available_formats(self) -> Dict[WeightsFormat, SpecificWeightsDescr]:
        """All weights entries that are set, keyed by their format name."""
        return {
            **({} if self.keras_hdf5 is None else {"keras_hdf5": self.keras_hdf5}),
            **({} if self.keras_v3 is None else {"keras_v3": self.keras_v3}),
            **({} if self.onnx is None else {"onnx": self.onnx}),
            **(
                {}
                if self.pytorch_state_dict is None
                else {"pytorch_state_dict": self.pytorch_state_dict}
            ),
            **(
                {}
                if self.tensorflow_js is None
                else {"tensorflow_js": self.tensorflow_js}
            ),
            **(
                {}
                if self.tensorflow_saved_model_bundle is None
                else {
                    "tensorflow_saved_model_bundle": self.tensorflow_saved_model_bundle
                }
            ),
            **({} if self.torchscript is None else {"torchscript": self.torchscript}),
        }

    @property
    def missing_formats(self) -> Set[WeightsFormat]:
        """All weights formats that have no entry set."""
        return {
            wf for wf in get_args(WeightsFormat) if wf not in self.available_formats
        }
class ModelId(ResourceId):
    # distinct id type for model resources (no additional constraints)
    pass
class LinkedModel(LinkedResourceBase):
    """Reference to a bioimage.io model."""

    id: ModelId
    """A valid model `id` from the bioimage.io collection."""
class _DataDepSize(NamedTuple):
    # bounds of a data-dependent axis size; `max is None` means unbounded
    min: StrictInt
    max: Optional[StrictInt]
class _AxisSizes(NamedTuple):
    """the lengths of all axes of model inputs and outputs"""

    inputs: Dict[Tuple[TensorId, AxisId], int]
    outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]]
class _TensorSizes(NamedTuple):
    """`_AxisSizes` as nested dicts (tensor id -> axis id -> size)"""

    inputs: Dict[TensorId, Dict[AxisId, int]]
    outputs: Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]
class ReproducibilityTolerance(Node, extra="allow"):
    """Describes what small numerical differences -- if any -- may be tolerated
    in the generated output when executing in different environments.

    A tensor element *output* is considered mismatched to the **test_tensor** if
    abs(*output* - **test_tensor**) > **absolute_tolerance** + **relative_tolerance** * abs(**test_tensor**).
    (Internally we call [numpy.testing.assert_allclose](https://numpy.org/doc/stable/reference/generated/numpy.testing.assert_allclose.html).)

    Motivation:
        For testing we can request the respective deep learning frameworks to be as
        reproducible as possible by setting seeds and choosing deterministic algorithms,
        but differences in operating systems, available hardware and installed drivers
        may still lead to numerical differences.
    """

    relative_tolerance: RelativeTolerance = 1e-3
    """Maximum relative tolerance of reproduced test tensor."""

    absolute_tolerance: AbsoluteTolerance = 1e-3
    """Maximum absolute tolerance of reproduced test tensor."""

    mismatched_elements_per_million: MismatchedElementsPerMillion = 100
    """Maximum number of mismatched elements/pixels per million to tolerate."""

    output_ids: Sequence[TensorId] = ()
    """Limits the output tensor IDs these reproducibility details apply to."""

    weights_formats: Sequence[WeightsFormat] = ()
    """Limits the weights formats these details apply to."""
class BiasRisksLimitations(Node, extra="allow"):
    """Known biases, risks, technical limitations, and recommendations for model use."""

    known_biases: str = dedent("""\
        In general bioimage models may suffer from biases caused by:

        - Imaging protocol dependencies
        - Use of a specific cell type
        - Species-specific training data limitations

        """)
    """Biases in training data or model behavior."""

    risks: str = dedent("""\
        Common risks in bioimage analysis include:

        - Erroneously assuming generalization to unseen experimental conditions
        - Trusting (overconfident) model outputs without validation
        - Misinterpretation of results

        """)
    """Potential risks in the context of bioimage analysis."""

    limitations: Optional[str] = None
    """Technical limitations and failure modes."""

    recommendations: str = "Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model."
    """Mitigation strategies regarding `known_biases`, `risks`, and `limitations`, as well as applicable best practices.

    Consider:
    - How to use a validation dataset?
    - How to manually validate?
    - Feasibility of domain adaptation for different experimental setups?

    """

    def format_md(self) -> str:
        """Render this section as Markdown (model card style)."""
        # only emit a "## Limitations" heading if limitations are given
        if self.limitations is None:
            limitations_header = ""
        else:
            limitations_header = "## Limitations\n\n"

        return f"""# Bias, Risks, and Limitations

{self.known_biases}

{self.risks}

{limitations_header}{self.limitations or ""}

## Recommendations

{self.recommendations}

"""
class TrainingDetails(Node, extra="allow"):
    # free-form training metadata; all fields optional
    training_preprocessing: Optional[str] = None
    """Detailed image preprocessing steps during model training:

    Mention:
    - *Normalization methods*
    - *Augmentation strategies*
    - *Resizing/resampling procedures*
    - *Artifact handling*

    """

    training_epochs: Optional[float] = None
    """Number of training epochs."""

    training_batch_size: Optional[float] = None
    """Batch size used in training."""

    initial_learning_rate: Optional[float] = None
    """Initial learning rate used in training."""

    learning_rate_schedule: Optional[str] = None
    """Learning rate schedule used in training."""

    loss_function: Optional[str] = None
    """Loss function used in training, e.g. nn.MSELoss."""

    loss_function_kwargs: Dict[str, YamlValue] = Field(
        default_factory=cast(Callable[[], Dict[str, YamlValue]], dict)
    )
    """key word arguments for the `loss_function`"""

    optimizer: Optional[str] = None
    """optimizer, e.g. torch.optim.Adam"""

    optimizer_kwargs: Dict[str, YamlValue] = Field(
        default_factory=cast(Callable[[], Dict[str, YamlValue]], dict)
    )
    """key word arguments for the `optimizer`"""

    regularization: Optional[str] = None
    """Regularization techniques used during training, e.g. drop-out or weight decay."""

    training_duration: Optional[float] = None
    """Total training duration in hours."""
class Evaluation(Node, extra="allow"):
    # quantitative evaluation of a model on one dataset; rendered as a
    # "Testing Data, Factors & Metrics" model card section by `format_md`
    model_id: Optional[ModelId] = None
    """Model being evaluated."""

    dataset_id: DatasetId
    """Dataset used for evaluation."""

    dataset_source: HttpUrl
    """Source of the dataset."""

    dataset_role: Literal["train", "validation", "test", "independent", "unknown"]
    """Role of the dataset used for evaluation.

    - `train`: dataset was (part of) the training data
    - `validation`: dataset was (part of) the validation data used during training, e.g. used for model selection or hyperparameter tuning
    - `test`: dataset was (part of) the designated test data; not used during training or validation, but acquired from the same source/distribution as training data
    - `independent`: dataset is entirely independent test data; not used during training or validation, and acquired from a different source/distribution than training data
    - `unknown`: role of the dataset is unknown; choose this if you are not certain if (a subset) of the data was seen by the model during training.
    """

    sample_count: int
    """Number of evaluated samples."""

    evaluation_factors: List[Annotated[str, MaxLen(16)]]
    """(Abbreviations of) each evaluation factor.

    Evaluation factors are criteria along which model performance is evaluated, e.g. different image conditions
    like 'low SNR', 'high cell density', or different biological conditions like 'cell type A', 'cell type B'.
    An 'overall' factor may be included to summarize performance across all conditions.
    """

    evaluation_factors_long: List[str]
    """Descriptions (long form) of each evaluation factor."""

    metrics: List[Annotated[str, MaxLen(16)]]
    """(Abbreviations of) metrics used for evaluation."""

    metrics_long: List[str]
    """Description of each metric used."""

    @model_validator(mode="after")
    def _validate_list_lengths(self) -> Self:
        """Check that factors, metrics and the results matrix agree in size.

        Note: `results` is declared below this validator; pydantic collects all
        fields before "after" validators run, so accessing it here is safe.
        """
        if len(self.evaluation_factors) != len(self.evaluation_factors_long):
            raise ValueError(
                "`evaluation_factors` and `evaluation_factors_long` must have the same length"
            )

        if len(self.metrics) != len(self.metrics_long):
            raise ValueError("`metrics` and `metrics_long` must have the same length")

        if len(self.results) != len(self.metrics):
            raise ValueError("`results` must have the same number of rows as `metrics`")

        for row in self.results:
            if len(row) != len(self.evaluation_factors):
                raise ValueError(
                    "`results` must have the same number of columns (in every row) as `evaluation_factors`"
                )

        return self

    results: List[List[Union[str, float, int]]]
    """Results for each metric (rows; outer list) and each evaluation factor (columns; inner list)."""

    results_summary: Optional[str] = None
    """Interpretation of results for general audience.

    Consider:
    - Overall model performance
    - Comparison to existing methods
    - Limitations and areas for improvement

"""

    def format_md(self):
        """Render this evaluation as a Markdown model card section."""
        # markdown table: header row, separator row, then one row per metric
        results_header = ["Metric"] + self.evaluation_factors
        results_table_cells = [results_header, ["---"] * len(results_header)] + [
            [metric] + [str(r) for r in row]
            for metric, row in zip(self.metrics, self.results)
        ]

        results_table = "".join(
            "| " + " | ".join(row) + " |\n" for row in results_table_cells
        )
        factors = "".join(
            f"\n - {ef}: {efl}"
            for ef, efl in zip(self.evaluation_factors, self.evaluation_factors_long)
        )
        metrics = "".join(
            f"\n - {em}: {eml}" for em, eml in zip(self.metrics, self.metrics_long)
        )

        return f"""## Testing Data, Factors & Metrics

Evaluation of {self.model_id or "this"} model on the {self.dataset_id} dataset (dataset role: {self.dataset_role}).

### Testing Data

- **Source:** [{self.dataset_id}]({self.dataset_source})
- **Size:** {self.sample_count} evaluated samples

### Factors
{factors}

### Metrics
{metrics}

## Results

### Quantitative Results

{results_table}

### Summary

{self.results_summary or "missing"}

"""
class EnvironmentalImpact(Node, extra="allow"):
    """Environmental considerations for model training and deployment.

    Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
    """

    hardware_type: Optional[str] = None
    """GPU/CPU specifications"""

    hours_used: Optional[float] = None
    """Total compute hours"""

    cloud_provider: Optional[str] = None
    """If applicable"""

    compute_region: Optional[str] = None
    """Geographic location"""

    co2_emitted: Optional[float] = None
    """kg CO2 equivalent

    Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
    """

    def format_md(self):
        """Filled Markdown template section following [Hugging Face Model Card Template](https://huggingface.co/docs/hub/en/model-card-annotated)."""
        # an instance equal to an all-default one carries no information
        if self == self.__class__():
            return ""

        parts = ["# Environmental Impact\n\n"]
        for label, value in (
            ("Hardware Type", self.hardware_type),
            ("Hours used", self.hours_used),
            ("Cloud Provider", self.cloud_provider),
            ("Compute Region", self.compute_region),
        ):
            if value is not None:
                parts.append(f"- **{label}:** {value}\n")

        if self.co2_emitted is not None:
            parts.append(f"- **Carbon Emitted:** {self.co2_emitted} kg CO2e\n")

        parts.append("\n")
        return "".join(parts)
class BioimageioConfig(Node, extra="allow"):
    """bioimage.io-specific metadata stored under `config.bioimageio` in the RDF."""

    reproducibility_tolerance: Sequence[ReproducibilityTolerance] = ()
    """Tolerances to allow when reproducing the model's test outputs
    from the model's test inputs.
    Only the first entry matching tensor id and weights format is considered.
    """

    funded_by: Optional[str] = None
    """Funding agency, grant number if applicable"""

    architecture_type: Optional[Annotated[str, MaxLen(32)]] = (
        None  # TODO: add to differentiated tags
    )
    """Model architecture type, e.g., 3D U-Net, ResNet, transformer"""

    architecture_description: Optional[str] = None
    """Text description of model architecture."""

    modality: Optional[str] = None  # TODO: add to differentiated tags
    """Input modality, e.g., fluorescence microscopy, electron microscopy"""

    target_structure: List[str] = Field(  # TODO: add to differentiated tags
        default_factory=cast(Callable[[], List[str]], list)
    )
    """Biological structure(s) the model is designed to analyze, e.g., nuclei, mitochondria, cells"""

    task: Optional[str] = None  # TODO: add to differentiated tags
    """Bioimage-specific task type, e.g., segmentation, classification, detection, denoising"""

    new_version: Optional[ModelId] = None
    """A new version of this model exists with a different model id."""

    out_of_scope_use: Optional[str] = None
    """Describe how the model may be misused in bioimage analysis contexts and what users should **not** do with the model."""

    bias_risks_limitations: BiasRisksLimitations = Field(
        default_factory=BiasRisksLimitations.model_construct
    )
    """Description of known bias, risks, and technical limitations for in-scope model use."""

    model_parameter_count: Optional[int] = None
    """Total number of model parameters."""

    training: TrainingDetails = Field(default_factory=TrainingDetails.model_construct)
    """Details on how the model was trained."""

    inference_time: Optional[str] = None
    """Average inference time per image/tile. Specify hardware and image size. Multiple examples can be given."""

    memory_requirements_inference: Optional[str] = None
    """GPU memory needed for inference. Multiple examples with different image size can be given."""

    memory_requirements_training: Optional[str] = None
    """GPU memory needed for training. Multiple examples with different image/batch sizes can be given."""

    evaluations: List[Evaluation] = Field(
        default_factory=cast(Callable[[], List[Evaluation]], list)
    )
    """Quantitative model evaluations.

    Note:
        At the moment we recommend to include only a single test dataset
        (with evaluation factors that may mark subsets of the dataset)
        to avoid confusion and make the presentation of results cleaner.
    """

    environmental_impact: EnvironmentalImpact = Field(
        default_factory=EnvironmentalImpact.model_construct
    )
    """Environmental considerations for model training and deployment"""
class Config(Node, extra="allow"):
    """The `config` section of a model RDF.

    `extra="allow"` permits arbitrary additional, consumer-specific entries
    besides the explicitly modeled ones below.
    """

    bioimageio: BioimageioConfig = Field(
        default_factory=BioimageioConfig.model_construct
    )
    # presumably a namespace reserved for the StarDist consumer — TODO confirm
    stardist: YamlValue = None
class ModelDescr(GenericModelDescrBase):
    """Specification of the fields used in a bioimage.io-compliant RDF to describe AI models with pretrained weights.
    These fields are typically stored in a YAML file which we call a model resource description file (model RDF).
    """

    implemented_format_version: ClassVar[Literal["0.5.9"]] = "0.5.9"
    if TYPE_CHECKING:
        format_version: Literal["0.5.9"] = "0.5.9"
    else:
        format_version: Literal["0.5.9"]
    """Version of the bioimage.io model description specification used.
    When creating a new model always use the latest micro/patch version described here.
    The `format_version` is important for any consumer software to understand how to parse the fields.
    """

    implemented_type: ClassVar[Literal["model"]] = "model"
    if TYPE_CHECKING:
        type: Literal["model"] = "model"
    else:
        type: Literal["model"]
    """Specialized resource type 'model'"""

    id: Optional[ModelId] = None
    """bioimage.io-wide unique resource identifier
    assigned by bioimage.io; version **un**specific."""

    authors: FAIR[List[Author]] = Field(
        default_factory=cast(Callable[[], List[Author]], list)
    )
    """The authors are the creators of the model RDF and the primary points of contact."""

    documentation: FAIR[Optional[FileSource_documentation]] = None
    """URL or relative path to a markdown file with additional documentation.
    The recommended documentation file name is `README.md`. An `.md` suffix is mandatory.
    The documentation should include a '#[#] Validation' (sub)section
    with details on how to quantitatively validate the model on unseen data."""

    @field_validator("documentation", mode="after")
    @classmethod
    def _validate_documentation(
        cls, value: Optional[FileSource_documentation]
    ) -> Optional[FileSource_documentation]:
        """Warn if the referenced documentation lacks a '# Validation' (sub)section.

        Only performed when the validation context allows IO checks.
        """
        if not get_validation_context().perform_io_checks or value is None:
            return value

        doc_reader = get_reader(value)
        doc_content = doc_reader.read().decode(encoding="utf-8")
        if not re.search("#.*[vV]alidation", doc_content):
            issue_warning(
                "No '# Validation' (sub)section found in {value}.",
                value=value,
                field="documentation",
            )

        return value

    inputs: NotEmpty[Sequence[InputTensorDescr]]
    """Describes the input tensors expected by this model."""

    @field_validator("inputs", mode="after")
    @classmethod
    def _validate_input_axes(
        cls, inputs: Sequence[InputTensorDescr]
    ) -> Sequence[InputTensorDescr]:
        """Validate every input axis against the set of axes it may reference."""
        input_size_refs = cls._get_axes_with_independent_size(inputs)

        for i, ipt in enumerate(inputs):
            # Axes of this tensor plus all independently sized input axes
            # are valid targets for a SizeReference.
            valid_independent_refs: Dict[
                Tuple[TensorId, AxisId],
                Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
            ] = {
                **{
                    (ipt.id, a.id): (ipt, a, a.size)
                    for a in ipt.axes
                    if not isinstance(a, BatchAxis)
                    and isinstance(a.size, (int, ParameterizedSize))
                },
                **input_size_refs,
            }
            for a, ax in enumerate(ipt.axes):
                cls._validate_axis(
                    "inputs",
                    i=i,
                    tensor_id=ipt.id,
                    a=a,
                    axis=ax,
                    valid_independent_refs=valid_independent_refs,
                )
        return inputs

    @staticmethod
    def _validate_axis(
        field_name: str,
        i: int,
        tensor_id: TensorId,
        a: int,
        axis: AnyAxis,
        valid_independent_refs: Dict[
            Tuple[TensorId, AxisId],
            Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
        ],
    ):
        """Validate a single axis description.

        Only axes whose size is a `SizeReference` need checking; the reference
        must exist in **valid_independent_refs**, must not be the axis itself,
        must match in axis unit (and, for channel axes, axis type), and must not
        be a batch axis. For axes with a halo the minimal size and the halo
        inferred for the reference axis are verified as well.

        Raises:
            ValueError: on any violated constraint.
        """
        if isinstance(axis, BatchAxis) or isinstance(
            axis.size, (int, ParameterizedSize, DataDependentSize)
        ):
            return  # nothing to check for independent sizes
        elif not isinstance(axis.size, SizeReference):
            assert_never(axis.size)

        # validate axis.size SizeReference
        ref = (axis.size.tensor_id, axis.size.axis_id)
        if ref not in valid_independent_refs:
            raise ValueError(
                "Invalid tensor axis reference at"
                + f" {field_name}[{i}].axes[{a}].size: {axis.size}."
            )
        if ref == (tensor_id, axis.id):
            raise ValueError(
                "Self-referencing not allowed for"
                + f" {field_name}[{i}].axes[{a}].size: {axis.size}"
            )
        if axis.type == "channel":
            if valid_independent_refs[ref][1].type != "channel":
                raise ValueError(
                    "A channel axis' size may only reference another fixed size"
                    + " channel axis."
                )
            if isinstance(axis.channel_names, str) and "{i}" in axis.channel_names:
                # expand a channel name template like "ch{i}" using the
                # (fixed) size of the referenced channel axis
                ref_size = valid_independent_refs[ref][2]
                assert isinstance(ref_size, int), (
                    "channel axis ref (another channel axis) has to specify fixed"
                    + " size"
                )
                generated_channel_names = [
                    Identifier(axis.channel_names.format(i=i))
                    for i in range(1, ref_size + 1)
                ]
                axis.channel_names = generated_channel_names

        if (ax_unit := getattr(axis, "unit", None)) != (
            ref_unit := getattr(valid_independent_refs[ref][1], "unit", None)
        ):
            raise ValueError(
                "The units of an axis and its reference axis need to match, but"
                + f" '{ax_unit}' != '{ref_unit}'."
            )
        ref_axis = valid_independent_refs[ref][1]
        if isinstance(ref_axis, BatchAxis):
            raise ValueError(
                f"Invalid reference axis '{ref_axis.id}' for {tensor_id}.{axis.id}"
                + " (a batch axis is not allowed as reference)."
            )

        if isinstance(axis, WithHalo):
            # the smallest resolvable size (n=0) must still accommodate the halo
            min_size = axis.size.get_size(axis, ref_axis, n=0)
            if (min_size - 2 * axis.halo) < 1:
                raise ValueError(
                    f"axis {axis.id} with minimum size {min_size} is too small for halo"
                    + f" {axis.halo}."
                )

            # the halo mapped to the reference axis' scale must be integral
            ref_halo = axis.halo * axis.scale / ref_axis.scale
            if ref_halo != int(ref_halo):
                raise ValueError(
                    f"Inferred halo for {'.'.join(ref)} is not an integer ({ref_halo} ="
                    + f" {tensor_id}.{axis.id}.halo {axis.halo}"
                    + f" * {tensor_id}.{axis.id}.scale {axis.scale}"
                    + f" / {'.'.join(ref)}.scale {ref_axis.scale})."
                )

    def validate_input_tensors(
        self,
        sources: Union[
            Sequence[NDArray[Any]], Mapping[TensorId, Optional[NDArray[Any]]]
        ],
    ) -> Mapping[TensorId, Optional[NDArray[Any]]]:
        """Check if the given input tensors match the model's input tensor descriptions.
        This includes checks of tensor shapes and dtypes, but not of the actual values.
        """
        if not isinstance(sources, collections.abc.Mapping):
            # a plain sequence is matched positionally to `self.inputs`
            sources = {descr.id: tensor for descr, tensor in zip(self.inputs, sources)}

        tensors = {descr.id: (descr, sources.get(descr.id)) for descr in self.inputs}
        validate_tensors(tensors)

        return sources

    @model_validator(mode="after")
    def _validate_test_tensors(self) -> Self:
        """Load and validate test tensors and check reproducibility tolerances.

        Skipped entirely when the validation context disables IO checks.
        """
        if not get_validation_context().perform_io_checks:
            return self

        test_inputs = {
            descr.id: (
                descr,
                None if descr.test_tensor is None else load_array(descr.test_tensor),
            )
            for descr in self.inputs
        }
        test_outputs = {
            descr.id: (
                descr,
                None if descr.test_tensor is None else load_array(descr.test_tensor),
            )
            for descr in self.outputs
        }

        validate_tensors({**test_inputs, **test_outputs}, tensor_origin="test_tensor")

        # an absolute tolerance above 1% of a test output's maximum value is rejected
        for rep_tol in self.config.bioimageio.reproducibility_tolerance:
            if not rep_tol.absolute_tolerance:
                continue

            if rep_tol.output_ids:
                out_arrays = {
                    k: v[1] for k, v in test_outputs.items() if k in rep_tol.output_ids
                }
            else:
                out_arrays = {k: v[1] for k, v in test_outputs.items()}

            for out_id, array in out_arrays.items():
                if array is None:
                    continue

                if rep_tol.absolute_tolerance > (max_test_value := array.max()) * 0.01:
                    raise ValueError(
                        "config.bioimageio.reproducibility_tolerance.absolute_tolerance="
                        + f"{rep_tol.absolute_tolerance} > 0.01*{max_test_value}"
                        + f" (1% of the maximum value of the test tensor '{out_id}')"
                    )

        return self

    @model_validator(mode="after")
    def _validate_tensor_references_in_proc_kwargs(self, info: ValidationInfo) -> Self:
        """Ensure every `reference_tensor` in pre-/postprocessing kwargs names an input tensor."""
        ipt_refs = {t.id for t in self.inputs}
        missing_refs = [
            k["reference_tensor"]
            for k in [p.kwargs for ipt in self.inputs for p in ipt.preprocessing]
            + [p.kwargs for out in self.outputs for p in out.postprocessing]
            if "reference_tensor" in k
            and k["reference_tensor"] is not None
            and k["reference_tensor"] not in ipt_refs
        ]

        if missing_refs:
            raise ValueError(
                f"`reference_tensor`s {missing_refs} not found. Valid input tensor"
                + f" references are: {ipt_refs}."
            )

        return self

    name: Annotated[
        str,
        RestrictCharacters(string.ascii_letters + string.digits + "_+- ()"),
        MinLen(5),
        MaxLen(128),
        warn(MaxLen(64), "Name longer than 64 characters.", INFO),
    ]
    """A human-readable name of this model.
    It should be no longer than 64 characters
    and may only contain letter, number, underscore, minus, parentheses and spaces.
    We recommend to chose a name that refers to the model's task and image modality.
    """

    outputs: NotEmpty[Sequence[OutputTensorDescr]]
    """Describes the output tensors."""

    @field_validator("outputs", mode="after")
    @classmethod
    def _validate_tensor_ids(
        cls, outputs: Sequence[OutputTensorDescr], info: ValidationInfo
    ) -> Sequence[OutputTensorDescr]:
        """Reject duplicate tensor ids across inputs and outputs combined."""
        tensor_ids = [
            t.id for t in info.data.get("inputs", []) + info.data.get("outputs", [])
        ]
        duplicate_tensor_ids: List[str] = []
        seen: Set[str] = set()
        for t in tensor_ids:
            if t in seen:
                duplicate_tensor_ids.append(t)

            seen.add(t)

        if duplicate_tensor_ids:
            raise ValueError(f"Duplicate tensor ids: {duplicate_tensor_ids}")

        return outputs

    @staticmethod
    def _get_axes_with_parameterized_size(
        io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]],
    ):
        """Map "tensor_id.axis_id" -> (tensor, axis, size) for `ParameterizedSize` axes."""
        return {
            f"{t.id}.{a.id}": (t, a, a.size)
            for t in io
            for a in t.axes
            if not isinstance(a, BatchAxis) and isinstance(a.size, ParameterizedSize)
        }

    @staticmethod
    def _get_axes_with_independent_size(
        io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]],
    ):
        """Map (tensor_id, axis_id) -> (tensor, axis, size) for fixed or parameterized axes."""
        return {
            (t.id, a.id): (t, a, a.size)
            for t in io
            for a in t.axes
            if not isinstance(a, BatchAxis)
            and isinstance(a.size, (int, ParameterizedSize))
        }

    @field_validator("outputs", mode="after")
    @classmethod
    def _validate_output_axes(
        cls, outputs: List[OutputTensorDescr], info: ValidationInfo
    ) -> List[OutputTensorDescr]:
        """Validate every output axis; output axes may reference input and output axes."""
        input_size_refs = cls._get_axes_with_independent_size(
            info.data.get("inputs", [])
        )
        output_size_refs = cls._get_axes_with_independent_size(outputs)

        for i, out in enumerate(outputs):
            valid_independent_refs: Dict[
                Tuple[TensorId, AxisId],
                Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
            ] = {
                **{
                    (out.id, a.id): (out, a, a.size)
                    for a in out.axes
                    if not isinstance(a, BatchAxis)
                    and isinstance(a.size, (int, ParameterizedSize))
                },
                **input_size_refs,
                **output_size_refs,
            }
            for a, ax in enumerate(out.axes):
                cls._validate_axis(
                    "outputs",
                    i,
                    out.id,
                    a,
                    ax,
                    valid_independent_refs=valid_independent_refs,
                )

        return outputs

    packaged_by: List[Author] = Field(
        default_factory=cast(Callable[[], List[Author]], list)
    )
    """The persons that have packaged and uploaded this model.
    Only required if those persons differ from the `authors`."""

    parent: Optional[LinkedModel] = None
    """The model from which this model is derived, e.g. by fine-tuning the weights."""

    @model_validator(mode="after")
    def _validate_parent_is_not_self(self) -> Self:
        """A model may not name itself as its `parent`."""
        if self.parent is not None and self.parent.id == self.id:
            raise ValueError("A model description may not reference itself as parent.")

        return self

    run_mode: Annotated[
        Optional[RunMode],
        warn(None, "Run mode '{value}' has limited support across consumer softwares."),
    ] = None
    """Custom run mode for this model: for more complex prediction procedures like test time
    data augmentation that currently cannot be expressed in the specification.
    No standard run modes are defined yet."""

    timestamp: Datetime = Field(default_factory=Datetime.now)
    """Timestamp in [ISO 8601](#https://en.wikipedia.org/wiki/ISO_8601) format
    with a few restrictions listed [here](https://docs.python.org/3/library/datetime.html#datetime.datetime.fromisoformat).
    (In Python a datetime object is valid, too)."""

    training_data: Annotated[
        Union[None, LinkedDataset, DatasetDescr, DatasetDescr02],
        Field(union_mode="left_to_right"),
    ] = None
    """The dataset used to train this model"""

    weights: Annotated[WeightsDescr, WrapSerializer(package_weights)]
    """The weights for this model.
    Weights can be given for different formats, but should otherwise be equivalent.
    The available weight formats determine which consumers can use this model."""

    config: Config = Field(default_factory=Config.model_construct)

    @model_validator(mode="after")
    def _add_default_cover(self) -> Self:
        """Generate cover images from test tensors if no covers are given.

        A failure to generate covers only issues a warning, never an error.
        """
        if not get_validation_context().perform_io_checks or self.covers:
            return self

        try:
            generated_covers = generate_covers(
                [
                    (t, load_array(t.test_tensor))
                    for t in self.inputs
                    if t.test_tensor is not None
                ],
                [
                    (t, load_array(t.test_tensor))
                    for t in self.outputs
                    if t.test_tensor is not None
                ],
            )
        except Exception as e:
            issue_warning(
                "Failed to generate cover image(s): {e}",
                value=self.covers,
                msg_context=dict(e=e),
                field="covers",
            )
        else:
            self.covers.extend(generated_covers)

        return self

    def get_input_test_arrays(self) -> List[NDArray[Any]]:
        """Load and return all input test tensors as numpy arrays."""
        return self._get_test_arrays(self.inputs)

    def get_output_test_arrays(self) -> List[NDArray[Any]]:
        """Load and return all output test tensors as numpy arrays."""
        return self._get_test_arrays(self.outputs)

    @staticmethod
    def _get_test_arrays(
        io_descr: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]],
    ):
        """Load the `test_tensor` of every given tensor description.

        Raises:
            ValueError: if any description is missing its `test_tensor`.
        """
        ts: List[FileDescr] = []
        for d in io_descr:
            if d.test_tensor is None:
                raise ValueError(
                    f"Failed to get test arrays: description of '{d.id}' is missing a `test_tensor`."
                )
            ts.append(d.test_tensor)

        data = [load_array(t) for t in ts]
        assert all(isinstance(d, np.ndarray) for d in data)
        return data

    @staticmethod
    def get_batch_size(tensor_sizes: Mapping[TensorId, Mapping[AxisId, int]]) -> int:
        """Return the batch size shared by all batch axes in **tensor_sizes**.

        Defaults to 1 when no batch axis (or only size-1 batch axes) is present.

        Raises:
            ValueError: if two tensors specify conflicting batch sizes.
        """
        batch_size = 1
        tensor_with_batchsize: Optional[TensorId] = None
        for tid in tensor_sizes:
            for aid, s in tensor_sizes[tid].items():
                if aid != BATCH_AXIS_ID or s == 1 or s == batch_size:
                    continue

                if batch_size != 1:
                    assert tensor_with_batchsize is not None
                    raise ValueError(
                        f"batch size mismatch for tensors '{tensor_with_batchsize}' ({batch_size}) and '{tid}' ({s})"
                    )

                batch_size = s
                tensor_with_batchsize = tid

        return batch_size

    def get_output_tensor_sizes(
        self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]
    ) -> Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]:
        """Returns the tensor output sizes for given **input_sizes**.
        Only if **input_sizes** has a valid input shape, the tensor output size is exact.
        Otherwise it might be larger than the actual (valid) output"""
        batch_size = self.get_batch_size(input_sizes)
        ns = self.get_ns(input_sizes)

        tensor_sizes = self.get_tensor_sizes(ns, batch_size=batch_size)
        return tensor_sizes.outputs

    def get_ns(self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]):
        """get parameter `n` for each parameterized axis
        such that the valid input size is >= the given input size"""
        ret: Dict[Tuple[TensorId, AxisId], ParameterizedSize_N] = {}
        axes = {t.id: {a.id: a for a in t.axes} for t in self.inputs}
        for tid in input_sizes:
            for aid, s in input_sizes[tid].items():
                size_descr = axes[tid][aid].size
                if isinstance(size_descr, ParameterizedSize):
                    ret[(tid, aid)] = size_descr.get_n(s)
                elif size_descr is None or isinstance(size_descr, (int, SizeReference)):
                    pass  # only parameterized axes contribute an `n`
                else:
                    assert_never(size_descr)

        return ret

    def get_tensor_sizes(
        self, ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], batch_size: int
    ) -> _TensorSizes:
        """Resolve axis sizes (see `get_axis_sizes`) and regroup them per tensor."""
        axis_sizes = self.get_axis_sizes(ns, batch_size=batch_size)
        return _TensorSizes(
            {
                t: {
                    aa: axis_sizes.inputs[(tt, aa)]
                    for tt, aa in axis_sizes.inputs
                    if tt == t
                }
                for t in {tt for tt, _ in axis_sizes.inputs}
            },
            {
                t: {
                    aa: axis_sizes.outputs[(tt, aa)]
                    for tt, aa in axis_sizes.outputs
                    if tt == t
                }
                for t in {tt for tt, _ in axis_sizes.outputs}
            },
        )

    def get_axis_sizes(
        self,
        ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N],
        batch_size: Optional[int] = None,
        *,
        max_input_shape: Optional[Mapping[Tuple[TensorId, AxisId], int]] = None,
    ) -> _AxisSizes:
        """Determine input and output block shape for scale factors **ns**
        of parameterized input sizes.

        Args:
            ns: Scale factor `n` for each axis (keyed by (tensor_id, axis_id))
                that is parameterized as `size = min + n * step`.
            batch_size: The desired size of the batch dimension.
                If given **batch_size** overwrites any batch size present in
                **max_input_shape**. Default 1.
            max_input_shape: Limits the derived block shapes.
                Each axis for which the input size, parameterized by `n`, is larger
                than **max_input_shape** is set to the minimal value `n_min` for which
                this is still true.
                Use this for small input samples or large values of **ns**.
                Or simply whenever you know the full input shape.

        Returns:
            Resolved axis sizes for model inputs and outputs.
        """
        max_input_shape = max_input_shape or {}
        if batch_size is None:
            # fall back to a batch size given in max_input_shape, else 1
            for (_t_id, a_id), s in max_input_shape.items():
                if a_id == BATCH_AXIS_ID:
                    batch_size = s
                    break
            else:
                batch_size = 1

        all_axes = {
            t.id: {a.id: a for a in t.axes} for t in chain(self.inputs, self.outputs)
        }

        inputs: Dict[Tuple[TensorId, AxisId], int] = {}
        outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]] = {}

        def get_axis_size(a: Union[InputAxis, OutputAxis]):
            # resolve a single axis size; `t_descr` is taken from the enclosing loop
            if isinstance(a, BatchAxis):
                if (t_descr.id, a.id) in ns:
                    logger.warning(
                        "Ignoring unexpected size increment factor (n) for batch axis"
                        + " of tensor '{}'.",
                        t_descr.id,
                    )
                return batch_size
            elif isinstance(a.size, int):
                if (t_descr.id, a.id) in ns:
                    logger.warning(
                        "Ignoring unexpected size increment factor (n) for fixed size"
                        + " axis '{}' of tensor '{}'.",
                        a.id,
                        t_descr.id,
                    )
                return a.size
            elif isinstance(a.size, ParameterizedSize):
                if (t_descr.id, a.id) not in ns:
                    raise ValueError(
                        "Size increment factor (n) missing for parametrized axis"
                        + f" '{a.id}' of tensor '{t_descr.id}'."
                    )
                n = ns[(t_descr.id, a.id)]
                s_max = max_input_shape.get((t_descr.id, a.id))
                if s_max is not None:
                    # cap `n` so the resolved size stays within max_input_shape
                    n = min(n, a.size.get_n(s_max))

                return a.size.get_size(n)

            elif isinstance(a.size, SizeReference):
                if (t_descr.id, a.id) in ns:
                    logger.warning(
                        "Ignoring unexpected size increment factor (n) for axis '{}'"
                        + " of tensor '{}' with size reference.",
                        a.id,
                        t_descr.id,
                    )
                assert not isinstance(a, BatchAxis)
                ref_axis = all_axes[a.size.tensor_id][a.size.axis_id]
                assert not isinstance(ref_axis, BatchAxis)
                ref_key = (a.size.tensor_id, a.size.axis_id)
                # referenced axis must already be resolved (see resolution order below)
                ref_size = inputs.get(ref_key, outputs.get(ref_key))
                assert ref_size is not None, ref_key
                assert not isinstance(ref_size, _DataDepSize), ref_key
                return a.size.get_size(
                    axis=a,
                    ref_axis=ref_axis,
                    ref_size=ref_size,
                )
            elif isinstance(a.size, DataDependentSize):
                if (t_descr.id, a.id) in ns:
                    logger.warning(
                        "Ignoring unexpected increment factor (n) for data dependent"
                        + " size axis '{}' of tensor '{}'.",
                        a.id,
                        t_descr.id,
                    )
                return _DataDepSize(a.size.min, a.size.max)
            else:
                assert_never(a.size)

        # first resolve all input sizes except the `SizeReference` ones
        for t_descr in self.inputs:
            for a in t_descr.axes:
                if not isinstance(a.size, SizeReference):
                    s = get_axis_size(a)
                    assert not isinstance(s, _DataDepSize)
                    inputs[t_descr.id, a.id] = s

        # resolve all other input axis sizes
        for t_descr in self.inputs:
            for a in t_descr.axes:
                if isinstance(a.size, SizeReference):
                    s = get_axis_size(a)
                    assert not isinstance(s, _DataDepSize)
                    inputs[t_descr.id, a.id] = s

        # resolve all output axis sizes
        for t_descr in self.outputs:
            for a in t_descr.axes:
                assert not isinstance(a.size, ParameterizedSize)
                s = get_axis_size(a)
                outputs[t_descr.id, a.id] = s

        return _AxisSizes(inputs=inputs, outputs=outputs)

    @model_validator(mode="before")
    @classmethod
    def _convert(cls, data: Dict[str, Any]) -> Dict[str, Any]:
        """Upgrade raw RDF data from older format versions before validation."""
        cls.convert_from_old_format_wo_validation(data)
        return data

    @classmethod
    def convert_from_old_format_wo_validation(cls, data: Dict[str, Any]) -> None:
        """Convert metadata following an older format version to this classes' format
        without validating the result.
        """
        if (
            data.get("type") == "model"
            and isinstance(fv := data.get("format_version"), str)
            and fv.count(".") == 2
        ):
            fv_parts = fv.split(".")
            if any(not p.isdigit() for p in fv_parts):
                return  # not a valid version triple; leave data untouched

            fv_tuple = tuple(map(int, fv_parts))

            assert cls.implemented_format_version_tuple[0:2] == (0, 5)
            if fv_tuple[:2] in ((0, 3), (0, 4)):
                m04 = _ModelDescr_v0_4.load(data)
                if isinstance(m04, InvalidDescr):
                    # best-effort conversion of an invalid 0.4 description
                    try:
                        updated = _model_conv.convert_as_dict(
                            m04  # pyright: ignore[reportArgumentType]
                        )
                    except Exception as e:
                        logger.error(
                            "Failed to convert from invalid model 0.4 description."
                            + f"\nerror: {e}"
                            + "\nProceeding with model 0.5 validation without conversion."
                        )
                        updated = None
                else:
                    updated = _model_conv.convert_as_dict(m04)

                if updated is not None:
                    data.clear()
                    data.update(updated)

            elif fv_tuple[:2] == (0, 5):
                # bump patch version
                data["format_version"] = cls.implemented_format_version
class _ModelConv(Converter[_ModelDescr_v0_4, ModelDescr]):
    """Converter from model description format 0.4 to 0.5."""

    def _convert(
        self, src: _ModelDescr_v0_4, tgt: "type[ModelDescr] | type[dict[str, Any]]"
    ) -> "ModelDescr | dict[str, Any]":
        # replace characters not allowed in a 0.5 model name by spaces
        name = "".join(
            c if c in string.ascii_letters + string.digits + "_+- ()" else " "
            for c in src.name
        )

        def conv_authors(auths: Optional[Sequence[_Author_v0_4]]):
            # At type-check time use the typed converter; at runtime convert to
            # plain dicts so `tgt` may also be `dict`.
            conv = (
                _author_conv.convert if TYPE_CHECKING else _author_conv.convert_as_dict
            )
            return None if auths is None else [conv(a) for a in auths]

        if TYPE_CHECKING:
            arch_file_conv = _arch_file_conv.convert
            arch_lib_conv = _arch_lib_conv.convert
        else:
            arch_file_conv = _arch_file_conv.convert_as_dict
            arch_lib_conv = _arch_lib_conv.convert_as_dict

        # axis letter -> size, per input tensor (minimum for parameterized shapes)
        input_size_refs = {
            ipt.name: {
                a: s
                for a, s in zip(
                    ipt.axes,
                    (
                        ipt.shape.min
                        if isinstance(ipt.shape, _ParameterizedInputShape_v0_4)
                        else ipt.shape
                    ),
                )
            }
            for ipt in src.inputs
            if ipt.shape
        }
        # explicit output shapes plus all input sizes are valid size references
        output_size_refs = {
            **{
                out.name: {a: s for a, s in zip(out.axes, out.shape)}
                for out in src.outputs
                if not isinstance(out.shape, _ImplicitOutputShape_v0_4)
            },
            **input_size_refs,
        }

        return tgt(
            attachments=(
                []
                if src.attachments is None
                else [FileDescr(source=f) for f in src.attachments.files]
            ),
            authors=[_author_conv.convert_as_dict(a) for a in src.authors],  # pyright: ignore[reportArgumentType]
            cite=[{"text": c.text, "doi": c.doi, "url": c.url} for c in src.cite],  # pyright: ignore[reportArgumentType]
            config=src.config,  # pyright: ignore[reportArgumentType]
            covers=src.covers,
            description=src.description,
            documentation=src.documentation,
            format_version="0.5.9",
            git_repo=src.git_repo,  # pyright: ignore[reportArgumentType]
            icon=src.icon,
            id=None if src.id is None else ModelId(src.id),
            id_emoji=src.id_emoji,
            license=src.license,  # type: ignore
            links=src.links,
            maintainers=[_maintainer_conv.convert_as_dict(m) for m in src.maintainers],  # pyright: ignore[reportArgumentType]
            name=name,
            tags=src.tags,
            type=src.type,
            uploader=src.uploader,
            version=src.version,
            inputs=[  # pyright: ignore[reportArgumentType]
                _input_tensor_conv.convert_as_dict(ipt, tt, st, input_size_refs)
                for ipt, tt, st in zip(
                    src.inputs,
                    src.test_inputs,
                    src.sample_inputs or [None] * len(src.test_inputs),
                )
            ],
            outputs=[  # pyright: ignore[reportArgumentType]
                _output_tensor_conv.convert_as_dict(out, tt, st, output_size_refs)
                for out, tt, st in zip(
                    src.outputs,
                    src.test_outputs,
                    src.sample_outputs or [None] * len(src.test_outputs),
                )
            ],
            # 0.4 stores id and version_number separately; 0.5 joins them as "id/version"
            parent=(
                None
                if src.parent is None
                else LinkedModel(
                    id=ModelId(
                        str(src.parent.id)
                        + (
                            ""
                            if src.parent.version_number is None
                            else f"/{src.parent.version_number}"
                        )
                    )
                )
            ),
            training_data=(
                None
                if src.training_data is None
                else (
                    LinkedDataset(
                        id=DatasetId(
                            str(src.training_data.id)
                            + (
                                ""
                                if src.training_data.version_number is None
                                else f"/{src.training_data.version_number}"
                            )
                        )
                    )
                    if isinstance(src.training_data, LinkedDataset02)
                    else src.training_data
                )
            ),
            packaged_by=[_author_conv.convert_as_dict(a) for a in src.packaged_by],  # pyright: ignore[reportArgumentType]
            run_mode=src.run_mode,
            timestamp=src.timestamp,
            weights=(WeightsDescr if TYPE_CHECKING else dict)(
                # `(w := ...) and descr(...)` keeps absent weight entries as None
                keras_hdf5=(w := src.weights.keras_hdf5)
                and (KerasHdf5WeightsDescr if TYPE_CHECKING else dict)(
                    authors=conv_authors(w.authors),
                    source=w.source,
                    tensorflow_version=w.tensorflow_version or Version("1.15"),
                    parent=w.parent,
                ),
                onnx=(w := src.weights.onnx)
                and (OnnxWeightsDescr if TYPE_CHECKING else dict)(
                    source=w.source,
                    authors=conv_authors(w.authors),
                    parent=w.parent,
                    opset_version=w.opset_version or 15,
                ),
                pytorch_state_dict=(w := src.weights.pytorch_state_dict)
                and (PytorchStateDictWeightsDescr if TYPE_CHECKING else dict)(
                    source=w.source,
                    authors=conv_authors(w.authors),
                    parent=w.parent,
                    architecture=(
                        arch_file_conv(
                            w.architecture,
                            w.architecture_sha256,
                            w.kwargs,
                        )
                        if isinstance(w.architecture, _CallableFromFile_v0_4)
                        else arch_lib_conv(w.architecture, w.kwargs)
                    ),
                    pytorch_version=w.pytorch_version or Version("1.10"),
                    dependencies=(
                        None
                        if w.dependencies is None
                        else (FileDescr if TYPE_CHECKING else dict)(
                            # strip a leading "conda:" scheme from the dependency source
                            source=cast(
                                FileSource,
                                str(deps := w.dependencies)[
                                    (
                                        len("conda:")
                                        if str(deps).startswith("conda:")
                                        else 0
                                    ) :
                                ],
                            )
                        )
                    ),
                ),
                tensorflow_js=(w := src.weights.tensorflow_js)
                and (TensorflowJsWeightsDescr if TYPE_CHECKING else dict)(
                    source=w.source,
                    authors=conv_authors(w.authors),
                    parent=w.parent,
                    tensorflow_version=w.tensorflow_version or Version("1.15"),
                ),
                tensorflow_saved_model_bundle=(
                    w := src.weights.tensorflow_saved_model_bundle
                )
                and (TensorflowSavedModelBundleWeightsDescr if TYPE_CHECKING else dict)(
                    authors=conv_authors(w.authors),
                    parent=w.parent,
                    source=w.source,
                    tensorflow_version=w.tensorflow_version or Version("1.15"),
                    dependencies=(
                        None
                        if w.dependencies is None
                        else (FileDescr if TYPE_CHECKING else dict)(
                            # strip a leading "conda:" scheme from the dependency source
                            source=cast(
                                FileSource,
                                (
                                    str(w.dependencies)[len("conda:") :]
                                    if str(w.dependencies).startswith("conda:")
                                    else str(w.dependencies)
                                ),
                            )
                        )
                    ),
                ),
                torchscript=(w := src.weights.torchscript)
                and (TorchscriptWeightsDescr if TYPE_CHECKING else dict)(
                    source=w.source,
                    authors=conv_authors(w.authors),
                    parent=w.parent,
                    pytorch_version=w.pytorch_version or Version("1.10"),
                ),
            ),
        )
# Module-level converter instance: turns v0.4 model descriptions
# (`_ModelDescr_v0_4`) into v0.5 `ModelDescr` objects via `_ModelConv`.
_model_conv = _ModelConv(_ModelDescr_v0_4, ModelDescr)
# TODO: create better cover images for 3d data and non-image outputs
def generate_covers(
    inputs: Sequence[Tuple[InputTensorDescr, NDArray[Any]]],
    outputs: Sequence[Tuple[OutputTensorDescr, NDArray[Any]]],
) -> List[Path]:
    """Generate cover image(s) from the first test input/output tensor pair.

    Each tensor is reduced to an ``(h, w, 3)`` uint8 RGB image: singleton
    axes are squeezed, batch/index/z/space/time axes are sliced at their
    center, and up to three channels are mapped to RGB (two channels are
    shown as magenta/cyan-ish mixes). If input and output images end up
    with the same shape, a single diagonally split cover is written;
    otherwise separate input and output covers are written.

    Args:
        inputs: ``(tensor description, test tensor)`` pairs; must not be empty.
        outputs: ``(tensor description, test tensor)`` pairs; must not be empty.

    Returns:
        Paths of the generated PNG file(s) in a freshly created temp directory.

    Raises:
        ValueError: if `inputs` or `outputs` is empty, or a tensor cannot be
            reduced to a 2D (RGB) image.
    """

    def squeeze(
        data: NDArray[Any], axes: Sequence[AnyAxis]
    ) -> Tuple[NDArray[Any], List[AnyAxis]]:
        """apply numpy.ndarray.squeeze while keeping track of the axis descriptions remaining"""
        if data.ndim != len(axes):
            raise ValueError(
                f"tensor shape {data.shape} does not match described axes"
                + f" {[a.id for a in axes]}"
            )

        # keep only axis descriptions whose extent survives the squeeze
        axes = [deepcopy(a) for a, s in zip(axes, data.shape) if s != 1]
        return data.squeeze(), axes

    def normalize(
        data: NDArray[Any], axis: Optional[Tuple[int, ...]], eps: float = 1e-7
    ) -> NDArray[np.float32]:
        """min/max-normalize `data` to (approximately) [0, 1] along `axis`."""
        data = data.astype("float32")
        data -= data.min(axis=axis, keepdims=True)
        # eps avoids division by zero for constant data
        data /= data.max(axis=axis, keepdims=True) + eps
        return data

    def to_2d_image(data: NDArray[Any], axes: Sequence[AnyAxis]) -> NDArray[Any]:
        """Reduce `data` (described by `axes`) to an (h, w, 3) uint8 RGB image."""
        original_shape = data.shape
        original_axes = list(axes)

        data, axes = squeeze(data, axes)

        # take slice fom any batch or index axis if needed
        # and convert the first channel axis and take a slice from any additional channel axes
        slices: Tuple[slice, ...] = ()
        ndim = data.ndim
        ndim_need = 3 if any(isinstance(a, ChannelAxis) for a in axes) else 2
        has_c_axis = False
        for i, a in enumerate(axes):
            s = data.shape[i]
            assert s > 1
            if (
                isinstance(a, (BatchAxis, IndexInputAxis, IndexOutputAxis))
                and ndim > ndim_need
            ):
                # center slice (kept as size-1 dim; squeezed later)
                data = data[slices + (slice(s // 2 - 1, s // 2),)]
                ndim -= 1
            elif isinstance(a, ChannelAxis):
                if has_c_axis:
                    # second channel axis
                    data = data[slices + (slice(0, 1),)]
                    ndim -= 1
                else:
                    has_c_axis = True
                    if s == 2:
                        # visualize two channels with cyan and magenta
                        data = np.concatenate(
                            [
                                data[slices + (slice(1, 2),)],
                                data[slices + (slice(0, 1),)],
                                (
                                    data[slices + (slice(0, 1),)]
                                    + data[slices + (slice(1, 2),)]
                                )
                                / 2,  # TODO: take maximum instead?
                            ],
                            axis=i,
                        )
                    elif data.shape[i] == 3:
                        pass  # visualize 3 channels as RGB
                    else:
                        # visualize first 3 channels as RGB
                        data = data[slices + (slice(3),)]

                    assert data.shape[i] == 3

            slices += (slice(None),)

        data, axes = squeeze(data, axes)
        assert len(axes) == ndim
        # take slice from z axis if needed
        slices = ()
        if ndim > ndim_need:
            for i, a in enumerate(axes):
                s = data.shape[i]
                if a.id == AxisId("z"):
                    data = data[slices + (slice(s // 2 - 1, s // 2),)]
                    data, axes = squeeze(data, axes)
                    ndim -= 1
                    break

                slices += (slice(None),)

        # take slice from any space or time axis
        slices = ()

        for i, a in enumerate(axes):
            if ndim <= ndim_need:
                break

            s = data.shape[i]
            assert s > 1
            if isinstance(
                a, (SpaceInputAxis, SpaceOutputAxis, TimeInputAxis, TimeOutputAxis)
            ):
                data = data[slices + (slice(s // 2 - 1, s // 2),)]
                ndim -= 1

            slices += (slice(None),)

        del slices
        data, axes = squeeze(data, axes)
        assert len(axes) == ndim

        if (has_c_axis and ndim != 3) or (not has_c_axis and ndim != 2):
            raise ValueError(
                f"Failed to construct cover image from shape {original_shape} with axes {[a.id for a in original_axes]}."
            )

        if not has_c_axis:
            assert ndim == 2
            # grayscale -> RGB by repeating the single plane
            data = np.repeat(data[:, :, None], 3, axis=2)
            axes.append(ChannelAxis(channel_names=list(map(Identifier, "RGB"))))
            ndim += 1

        assert ndim == 3

        # transpose axis order such that longest axis comes first...
        axis_order: List[int] = list(np.argsort(list(data.shape)))
        axis_order.reverse()
        # ... and channel axis is last
        c = [i for i in range(3) if isinstance(axes[i], ChannelAxis)][0]
        # bugfix: move the permutation entry *whose value is* `c` (the channel
        # axis) to the end. The previous `axis_order.pop(c)` popped by
        # *position*, which moved the wrong axis whenever the channel axis was
        # not already last in `axis_order` (e.g. channel-first tensors).
        axis_order.append(axis_order.pop(axis_order.index(c)))
        axes = [axes[ao] for ao in axis_order]
        data = data.transpose(axis_order)

        # TODO: enforce 2:1 or 1:1 aspect ratio for generated cover images

        # normalize jointly over all space/time axes (channels stay separate);
        # `or None` falls back to global normalization if none are present
        norm_along = (
            tuple(i for i, a in enumerate(axes) if a.type in ("space", "time")) or None
        )
        # normalize the data and map to 8 bit
        data = normalize(data, norm_along)
        data = (data * 255).astype("uint8")

        return data

    def create_diagonal_split_image(
        im0: NDArray[Any], im1: NDArray[Any]
    ) -> NDArray[Any]:
        """Compose two equally shaped uint8 RGB images: `im0` fills the lower
        triangle, `im1` the upper triangle.

        NOTE: zero-valued `im0` pixels in the lower triangle are also replaced
        by `im1` (the mask is `== 0`) — acceptable for cover images.
        """
        assert im0.dtype == im1.dtype == np.uint8
        assert im0.shape == im1.shape
        assert im0.ndim == 3
        N, M, C = im0.shape
        assert C == 3
        out = np.ones((N, M, C), dtype="uint8")
        for c in range(C):
            outc = np.tril(im0[..., c])
            mask = outc == 0
            outc[mask] = np.triu(im1[..., c])[mask]
            out[..., c] = outc

        return out

    if not inputs:
        raise ValueError("Missing test input tensor for cover generation.")

    if not outputs:
        raise ValueError("Missing test output tensor for cover generation.")

    ipt_descr, ipt = inputs[0]
    out_descr, out = outputs[0]

    ipt_img = to_2d_image(ipt, ipt_descr.axes)
    out_img = to_2d_image(out, out_descr.axes)

    cover_folder = Path(mkdtemp())
    if ipt_img.shape == out_img.shape:
        covers = [cover_folder / "cover.png"]
        imwrite(covers[0], create_diagonal_split_image(ipt_img, out_img))
    else:
        covers = [cover_folder / "input.png", cover_folder / "output.png"]
        imwrite(covers[0], ipt_img)
        imwrite(covers[1], out_img)

    return covers