Coverage for src / bioimageio / spec / model / v0_5.py: 76%
1581 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-27 14:45 +0000
1from __future__ import annotations
3import collections.abc
4import re
5import string
6import warnings
7from copy import deepcopy
8from itertools import chain
9from math import ceil
10from pathlib import Path, PurePosixPath
11from tempfile import mkdtemp
12from textwrap import dedent
13from typing import (
14 TYPE_CHECKING,
15 Any,
16 Callable,
17 ClassVar,
18 Dict,
19 Generic,
20 List,
21 Literal,
22 Mapping,
23 NamedTuple,
24 Optional,
25 Sequence,
26 Set,
27 Tuple,
28 Type,
29 TypeVar,
30 Union,
31 cast,
32 overload,
33)
35import numpy as np
36from annotated_types import Ge, Gt, Interval, MaxLen, MinLen, Predicate
37from imageio.v3 import imread, imwrite # pyright: ignore[reportUnknownVariableType]
38from loguru import logger
39from numpy.typing import NDArray
40from pydantic import (
41 AfterValidator,
42 Discriminator,
43 Field,
44 RootModel,
45 SerializationInfo,
46 SerializerFunctionWrapHandler,
47 StrictInt,
48 Tag,
49 ValidationInfo,
50 WrapSerializer,
51 field_validator,
52 model_serializer,
53 model_validator,
54)
55from typing_extensions import Annotated, Self, assert_never, get_args
57from .._internal.common_nodes import (
58 InvalidDescr,
59 KwargsNode,
60 Node,
61 NodeWithExplicitlySetFields,
62)
63from .._internal.constants import DTYPE_LIMITS
64from .._internal.field_warning import issue_warning, warn
65from .._internal.io import BioimageioYamlContent as BioimageioYamlContent
66from .._internal.io import FileDescr as FileDescr
67from .._internal.io import (
68 FileSource,
69 WithSuffix,
70 YamlValue,
71 extract_file_name,
72 get_reader,
73 wo_special_file_name,
74)
75from .._internal.io_basics import Sha256 as Sha256
76from .._internal.io_packaging import (
77 FileDescr_,
78 FileSource_,
79 package_file_descr_serializer,
80)
81from .._internal.io_utils import load_array
82from .._internal.node_converter import Converter
83from .._internal.type_guards import is_dict, is_sequence
84from .._internal.types import (
85 FAIR,
86 AbsoluteTolerance,
87 LowerCaseIdentifier,
88 LowerCaseIdentifierAnno,
89 MismatchedElementsPerMillion,
90 RelativeTolerance,
91)
92from .._internal.types import Datetime as Datetime
93from .._internal.types import Identifier as Identifier
94from .._internal.types import NotEmpty as NotEmpty
95from .._internal.types import SiUnit as SiUnit
96from .._internal.url import HttpUrl as HttpUrl
97from .._internal.validation_context import get_validation_context
98from .._internal.validator_annotations import RestrictCharacters
99from .._internal.version_type import Version as Version
100from .._internal.warning_levels import INFO
101from ..dataset.v0_2 import DatasetDescr as DatasetDescr02
102from ..dataset.v0_2 import LinkedDataset as LinkedDataset02
103from ..dataset.v0_3 import DatasetDescr as DatasetDescr
104from ..dataset.v0_3 import DatasetId as DatasetId
105from ..dataset.v0_3 import LinkedDataset as LinkedDataset
106from ..dataset.v0_3 import Uploader as Uploader
107from ..generic.v0_3 import (
108 VALID_COVER_IMAGE_EXTENSIONS as VALID_COVER_IMAGE_EXTENSIONS,
109)
110from ..generic.v0_3 import Author as Author
111from ..generic.v0_3 import BadgeDescr as BadgeDescr
112from ..generic.v0_3 import CiteEntry as CiteEntry
113from ..generic.v0_3 import DeprecatedLicenseId as DeprecatedLicenseId
114from ..generic.v0_3 import Doi as Doi
115from ..generic.v0_3 import (
116 FileSource_documentation,
117 GenericModelDescrBase,
118 LinkedResourceBase,
119 _author_conv, # pyright: ignore[reportPrivateUsage]
120 _maintainer_conv, # pyright: ignore[reportPrivateUsage]
121)
122from ..generic.v0_3 import LicenseId as LicenseId
123from ..generic.v0_3 import LinkedResource as LinkedResource
124from ..generic.v0_3 import Maintainer as Maintainer
125from ..generic.v0_3 import OrcidId as OrcidId
126from ..generic.v0_3 import RelativeFilePath as RelativeFilePath
127from ..generic.v0_3 import ResourceId as ResourceId
128from .v0_4 import Author as _Author_v0_4
129from .v0_4 import BinarizeDescr as _BinarizeDescr_v0_4
130from .v0_4 import CallableFromDepencency as CallableFromDepencency
131from .v0_4 import CallableFromDepencency as _CallableFromDepencency_v0_4
132from .v0_4 import CallableFromFile as _CallableFromFile_v0_4
133from .v0_4 import ClipDescr as _ClipDescr_v0_4
134from .v0_4 import ImplicitOutputShape as _ImplicitOutputShape_v0_4
135from .v0_4 import InputTensorDescr as _InputTensorDescr_v0_4
136from .v0_4 import KnownRunMode as KnownRunMode
137from .v0_4 import ModelDescr as _ModelDescr_v0_4
138from .v0_4 import OutputTensorDescr as _OutputTensorDescr_v0_4
139from .v0_4 import ParameterizedInputShape as _ParameterizedInputShape_v0_4
140from .v0_4 import PostprocessingDescr as _PostprocessingDescr_v0_4
141from .v0_4 import PreprocessingDescr as _PreprocessingDescr_v0_4
142from .v0_4 import RunMode as RunMode
143from .v0_4 import ScaleLinearDescr as _ScaleLinearDescr_v0_4
144from .v0_4 import ScaleMeanVarianceDescr as _ScaleMeanVarianceDescr_v0_4
145from .v0_4 import ScaleRangeDescr as _ScaleRangeDescr_v0_4
146from .v0_4 import SigmoidDescr as _SigmoidDescr_v0_4
147from .v0_4 import TensorName as _TensorName_v0_4
148from .v0_4 import ZeroMeanUnitVarianceDescr as _ZeroMeanUnitVarianceDescr_v0_4
149from .v0_4 import package_weights
# Unit vocabularies mirror the unit lists of the OME-Zarr axes specification
# so that axis metadata stays interoperable with OME-Zarr.
SpaceUnit = Literal[
    "attometer",
    "angstrom",
    "centimeter",
    "decimeter",
    "exameter",
    "femtometer",
    "foot",
    "gigameter",
    "hectometer",
    "inch",
    "kilometer",
    "megameter",
    "meter",
    "micrometer",
    "mile",
    "millimeter",
    "nanometer",
    "parsec",
    "petameter",
    "picometer",
    "terameter",
    "yard",
    "yoctometer",
    "yottameter",
    "zeptometer",
    "zettameter",
]
"""Space unit compatible to the [OME-Zarr axes specification 0.5](https://ngff.openmicroscopy.org/0.5/#axes-md)"""

TimeUnit = Literal[
    "attosecond",
    "centisecond",
    "day",
    "decisecond",
    "exasecond",
    "femtosecond",
    "gigasecond",
    "hectosecond",
    "hour",
    "kilosecond",
    "megasecond",
    "microsecond",
    "millisecond",
    "minute",
    "nanosecond",
    "petasecond",
    "picosecond",
    "second",
    "terasecond",
    "yoctosecond",
    "yottasecond",
    "zeptosecond",
    "zettasecond",
]
"""Time unit compatible to the [OME-Zarr axes specification 0.5](https://ngff.openmicroscopy.org/0.5/#axes-md)"""
AxisType = Literal["batch", "channel", "index", "time", "space"]

# single-letter axis abbreviation -> axis type
_AXIS_TYPE_MAP: Mapping[str, AxisType] = {
    "b": "batch",
    "t": "time",
    "i": "index",
    "c": "channel",
    "x": "space",
    "y": "space",
    "z": "space",
}

# single-letter axis abbreviation -> canonical axis id.
# Note: "x", "y", "z" are deliberately absent -- space axes keep their letter
# as the axis id (see `_AXIS_TYPE_MAP` above, which maps them to type "space").
_AXIS_ID_MAP = {
    "b": "batch",
    "t": "time",
    "i": "index",
    "c": "channel",
}

# weight formats supported by this model specification
WeightsFormat = Literal[
    "keras_hdf5",
    "keras_v3",
    "onnx",
    "pytorch_state_dict",
    "tensorflow_js",
    "tensorflow_saved_model_bundle",
    "torchscript",
]
class TensorId(LowerCaseIdentifier):
    """Identifier of a tensor: a lower case identifier of at most 32 characters."""

    root_model: ClassVar[Type[RootModel[Any]]] = RootModel[
        Annotated[LowerCaseIdentifierAnno, MaxLen(32)]
    ]
def _normalize_axis_id(a: str):
    """Map a single-letter axis id to its canonical long form (e.g. 'b' -> 'batch').

    Unknown ids are passed through unchanged; a warning is logged whenever a
    normalization actually happened.
    """
    raw = str(a)
    canonical = _AXIS_ID_MAP.get(raw, raw)
    if canonical != raw:
        # depth=3 attributes the log record to the original caller
        logger.opt(depth=3).warning(
            "Normalized axis id from '{}' to '{}'.", raw, canonical
        )
    return canonical
class AxisId(LowerCaseIdentifier):
    """Identifier of an axis: lower case, at most 16 characters.

    Single-letter ids listed in `_AXIS_ID_MAP` are normalized to their long
    form (e.g. 'b' -> 'batch') by the `_normalize_axis_id` validator.
    """

    root_model: ClassVar[Type[RootModel[Any]]] = RootModel[
        Annotated[
            LowerCaseIdentifierAnno,
            MaxLen(16),
            AfterValidator(_normalize_axis_id),
        ]
    ]
264def _is_batch(a: str) -> bool:
265 return str(a) == "batch"
268def _is_not_batch(a: str) -> bool:
269 return not _is_batch(a)
# an axis id that is explicitly NOT the reserved "batch" axis
NonBatchAxisId = Annotated[AxisId, Predicate(_is_not_batch)]

# ids of the preprocessing operations defined in this version of the spec
PreprocessingId = Literal[
    "binarize",
    "clip",
    "ensure_dtype",
    "fixed_zero_mean_unit_variance",
    "scale_linear",
    "scale_range",
    "sigmoid",
    "softmax",
]
# ids of the postprocessing operations (superset of preprocessing ids)
PostprocessingId = Literal[
    "binarize",
    "clip",
    "ensure_dtype",
    "fixed_zero_mean_unit_variance",
    "scale_linear",
    "scale_mean_variance",
    "scale_range",
    "sigmoid",
    "softmax",
    "zero_mean_unit_variance",
]

# sentinel string; presumably marks a value that defaults to the `type` field -- confirm at usage sites
SAME_AS_TYPE = "<same as type>"

ParameterizedSize_N = int
"""
Annotates an integer to calculate a concrete axis size from a `ParameterizedSize`.
"""
class ParameterizedSize(Node):
    """Describes a range of valid tensor axis sizes as `size = min + n*step`.

    - **min** and **step** are given by the model description.
    - All block size parameters n = 0,1,2,... yield a valid `size`.
    - A greater block size parameter n = 0,1,2,... results in a greater **size**.
      This allows to adjust the axis size more generically.
    """

    N: ClassVar[Type[int]] = ParameterizedSize_N
    """Positive integer to parameterize this axis"""

    min: Annotated[int, Gt(0)]
    step: Annotated[int, Gt(0)]

    def validate_size(self, size: int, msg_prefix: str = "") -> int:
        """Return `size` if it is a valid realization of this parameterized size.

        Raises:
            ValueError: if `size` is below `min` or not of the form `min + n*step`.
        """
        if size < self.min:
            raise ValueError(
                f"{msg_prefix}size {size} < {self.min} (minimum axis size)"
            )
        if (size - self.min) % self.step != 0:
            raise ValueError(
                f"{msg_prefix}size {size} is not parameterized by `min + n*step` ="
                + f" `{self.min} + n*{self.step}`"
            )

        return size

    def get_size(self, n: ParameterizedSize_N) -> int:
        """Return the concrete axis size for parameter `n`."""
        return self.min + self.step * n

    def get_n(self, s: int) -> ParameterizedSize_N:
        """return smallest n parameterizing a size greater or equal than `s`"""
        # exact integer ceiling division; avoids the float rounding errors of
        # `ceil((s - self.min) / self.step)` for very large integers
        return -((self.min - s) // self.step)
class DataDependentSize(Node):
    """An axis size in the open interval [`min`, `max`] that is only known after inference."""

    min: Annotated[int, Gt(0)] = 1
    max: Annotated[Optional[int], Gt(1)] = None

    @model_validator(mode="after")
    def _validate_max_gt_min(self):
        """Ensure `min` is strictly below `max` (when `max` is given)."""
        if self.max is None:
            return self

        if self.min >= self.max:
            raise ValueError(f"expected `min` < `max`, but got {self.min}, {self.max}")

        return self

    def validate_size(self, size: int, msg_prefix: str = "") -> int:
        """Return `size` if it lies within [`min`, `max`], otherwise raise ValueError."""
        if not size >= self.min:
            raise ValueError(f"{msg_prefix}size {size} < {self.min}")

        if self.max is not None and not size <= self.max:
            raise ValueError(f"{msg_prefix}size {size} > {self.max}")

        return size
class SizeReference(Node):
    """A tensor axis size (extent in pixels/frames) defined in relation to a reference axis.

    `axis.size = reference.size * reference.scale / axis.scale + offset`

    Note:
    1. The axis and the referenced axis need to have the same unit (or no unit).
    2. Batch axes may not be referenced.
    3. Fractions are rounded down.
    4. If the reference axis is `concatenable` the referencing axis is assumed to be
        `concatenable` as well with the same block order.

    Example:
    An anisotropic input image of w*h=100*49 pixels depicts a physical space of 200*196mm².
    Let's assume that we want to express the image height h in relation to its width w
    instead of only accepting input images of exactly 100*49 pixels
    (for example to express a range of valid image shapes by parametrizing w, see `ParameterizedSize`).

    >>> w = SpaceInputAxis(id=AxisId("w"), size=100, unit="millimeter", scale=2)
    >>> h = SpaceInputAxis(
    ...     id=AxisId("h"),
    ...     size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("w"), offset=-1),
    ...     unit="millimeter",
    ...     scale=4,
    ... )
    >>> print(h.size.get_size(h, w))
    49

    ⇒ h = w * w.scale / h.scale + offset = 100 * 2mm / 4mm - 1 = 49
    """

    tensor_id: TensorId
    """tensor id of the reference axis"""

    axis_id: AxisId
    """axis id of the reference axis"""

    offset: StrictInt = 0

    def get_size(
        self,
        axis: Union[
            ChannelAxis,
            IndexInputAxis,
            IndexOutputAxis,
            TimeInputAxis,
            SpaceInputAxis,
            TimeOutputAxis,
            TimeOutputAxisWithHalo,
            SpaceOutputAxis,
            SpaceOutputAxisWithHalo,
        ],
        ref_axis: Union[
            ChannelAxis,
            IndexInputAxis,
            IndexOutputAxis,
            TimeInputAxis,
            SpaceInputAxis,
            TimeOutputAxis,
            TimeOutputAxisWithHalo,
            SpaceOutputAxis,
            SpaceOutputAxisWithHalo,
        ],
        n: ParameterizedSize_N = 0,
        ref_size: Optional[int] = None,
    ) -> int:
        """Compute the concrete size for a given axis and its reference axis.

        Args:
            axis: The axis this [SizeReference][] is the size of.
            ref_axis: The reference axis to compute the size from.
            n: If the **ref_axis** is parameterized (of type `ParameterizedSize`)
                and no fixed **ref_size** is given,
                **n** is used to compute the size of the parameterized **ref_axis**.
            ref_size: Overwrite the reference size instead of deriving it from
                **ref_axis**
                (**ref_axis.scale** is still used; any given **n** is ignored).
        """
        # sanity-check that the caller paired the right axes with this reference
        assert axis.size == self, (
            "Given `axis.size` is not defined by this `SizeReference`"
        )

        assert ref_axis.id == self.axis_id, (
            f"Expected `ref_axis.id` to be {self.axis_id}, but got {ref_axis.id}."
        )

        assert axis.unit == ref_axis.unit, (
            "`SizeReference` requires `axis` and `ref_axis` to have the same `unit`,"
            f" but {axis.unit}!={ref_axis.unit}"
        )
        if ref_size is None:
            # resolve the reference axis size: a fixed int, or a parameterized
            # size realized via `n`; referencing data-dependent sizes or other
            # size references is not allowed
            if isinstance(ref_axis.size, (int, float)):
                ref_size = ref_axis.size
            elif isinstance(ref_axis.size, ParameterizedSize):
                ref_size = ref_axis.size.get_size(n)
            elif isinstance(ref_axis.size, DataDependentSize):
                raise ValueError(
                    "Reference axis referenced in `SizeReference` may not be a `DataDependentSize`."
                )
            elif isinstance(ref_axis.size, SizeReference):
                raise ValueError(
                    "Reference axis referenced in `SizeReference` may not be sized by a"
                    + " `SizeReference` itself."
                )
            else:
                assert_never(ref_axis.size)

        # int() truncates; for the positive sizes used here this rounds down as documented
        return int(ref_size * ref_axis.scale / axis.scale + self.offset)

    @staticmethod
    def _get_unit(
        axis: Union[
            ChannelAxis,
            IndexInputAxis,
            IndexOutputAxis,
            TimeInputAxis,
            SpaceInputAxis,
            TimeOutputAxis,
            TimeOutputAxisWithHalo,
            SpaceOutputAxis,
            SpaceOutputAxisWithHalo,
        ],
    ):
        """Return the unit of `axis`."""
        return axis.unit
class AxisBase(NodeWithExplicitlySetFields):
    """Common base of all axis descriptions: a unique id plus a free-text description."""

    id: AxisId
    """An axis id unique across all axes of one tensor."""

    description: Annotated[str, MaxLen(128)] = ""
    """A short description of this axis beyond its type and id."""
class WithHalo(Node):
    """Mixin for output axes whose borders (the halo) should be cropped after inference."""

    halo: Annotated[int, Ge(1)]
    """The halo should be cropped from the output tensor to avoid boundary effects.
    It is to be cropped from both sides, i.e. `size_after_crop = size - 2 * halo`.
    To document a halo that is already cropped by the model use `size.offset` instead."""

    size: Annotated[
        SizeReference,
        Field(
            examples=[
                10,
                SizeReference(
                    tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
                ).model_dump(mode="json"),
            ]
        ),
    ]
    """reference to another axis with an optional offset (see [SizeReference][])"""
    # NOTE(review): the `10` example does not match the `SizeReference`-only
    # annotation of `size` -- presumably copied from a sized-axis mixin; confirm.
# the one reserved axis id for batch axes
BATCH_AXIS_ID = AxisId("batch")
class BatchAxis(AxisBase):
    """The batch axis; its id is fixed to "batch" and it carries no unit or scale."""

    implemented_type: ClassVar[Literal["batch"]] = "batch"
    if TYPE_CHECKING:
        # the default is only visible to static type checkers so `BatchAxis()`
        # type-checks without an explicit `type` argument
        type: Literal["batch"] = "batch"
    else:
        type: Literal["batch"]

    id: Annotated[AxisId, Predicate(_is_batch)] = BATCH_AXIS_ID
    size: Optional[Literal[1]] = None
    """The batch size may be fixed to 1,
    otherwise (the default) it may be chosen arbitrarily depending on available memory"""

    @property
    def scale(self):
        # batch axes have no physical extent; scale is a fixed no-op factor
        return 1.0

    @property
    def concatenable(self):
        # samples along the batch axis are independent, hence always concatenable
        return True

    @property
    def unit(self):
        return None
class ChannelAxis(AxisBase):
    """A channel axis; its size is given implicitly by the number of channel names."""

    implemented_type: ClassVar[Literal["channel"]] = "channel"
    if TYPE_CHECKING:
        # default only under TYPE_CHECKING so static checkers allow omitting `type`
        type: Literal["channel"] = "channel"
    else:
        type: Literal["channel"]

    id: NonBatchAxisId = AxisId("channel")

    channel_names: NotEmpty[List[Identifier]]

    @property
    def size(self) -> int:
        """number of channels, derived from `channel_names`"""
        return len(self.channel_names)

    @property
    def concatenable(self):
        # channels are not independent samples; never concatenable
        return False

    @property
    def scale(self) -> float:
        return 1.0

    @property
    def unit(self):
        return None
class IndexAxisBase(AxisBase):
    """Base for index axes: a unitless enumeration of elements (scale fixed to 1)."""

    implemented_type: ClassVar[Literal["index"]] = "index"
    if TYPE_CHECKING:
        # default only under TYPE_CHECKING so static checkers allow omitting `type`
        type: Literal["index"] = "index"
    else:
        type: Literal["index"]

    id: NonBatchAxisId = AxisId("index")

    @property
    def scale(self) -> float:
        return 1.0

    @property
    def unit(self):
        return None
class _WithInputAxisSize(Node):
    """Mixin adding the `size` field shared by all sized input axes."""

    size: Annotated[
        Union[Annotated[int, Gt(0)], ParameterizedSize, SizeReference],
        Field(
            examples=[
                10,
                ParameterizedSize(min=32, step=16).model_dump(mode="json"),
                SizeReference(
                    tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
                ).model_dump(mode="json"),
            ]
        ),
    ]
    """The size/length of this axis can be specified as
    - fixed integer
    - parameterized series of valid sizes ([ParameterizedSize][])
    - reference to another axis with an optional offset ([SizeReference][])
    """
class IndexInputAxis(IndexAxisBase, _WithInputAxisSize):
    """An index axis of an input tensor."""

    concatenable: bool = False
    """If a model has a `concatenable` input axis, it can be processed blockwise,
    splitting a longer sample axis into blocks matching its input tensor description.
    Output axes are concatenable if they have a [SizeReference][] to a concatenable
    input axis.
    """
class IndexOutputAxis(IndexAxisBase):
    """An index axis of an output tensor; may additionally have a data dependent size."""

    size: Annotated[
        Union[Annotated[int, Gt(0)], SizeReference, DataDependentSize],
        Field(
            examples=[
                10,
                SizeReference(
                    tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
                ).model_dump(mode="json"),
            ]
        ),
    ]
    """The size/length of this axis can be specified as
    - fixed integer
    - reference to another axis with an optional offset ([SizeReference][])
    - data dependent size using [DataDependentSize][] (size is only known after model inference)
    """
class TimeAxisBase(AxisBase):
    """Base for time axes with an optional time unit and a positive scale."""

    implemented_type: ClassVar[Literal["time"]] = "time"
    if TYPE_CHECKING:
        # default only under TYPE_CHECKING so static checkers allow omitting `type`
        type: Literal["time"] = "time"
    else:
        type: Literal["time"]

    id: NonBatchAxisId = AxisId("time")
    unit: Optional[TimeUnit] = None
    scale: Annotated[float, Gt(0)] = 1.0
class TimeInputAxis(TimeAxisBase, _WithInputAxisSize):
    """A time axis of an input tensor."""

    concatenable: bool = False
    """If a model has a `concatenable` input axis, it can be processed blockwise,
    splitting a longer sample axis into blocks matching its input tensor description.
    Output axes are concatenable if they have a [SizeReference][] to a concatenable
    input axis.
    """
class SpaceAxisBase(AxisBase):
    """Base for spatial axes with an optional space unit and a positive scale."""

    implemented_type: ClassVar[Literal["space"]] = "space"
    if TYPE_CHECKING:
        # default only under TYPE_CHECKING so static checkers allow omitting `type`
        type: Literal["space"] = "space"
    else:
        type: Literal["space"]

    id: Annotated[NonBatchAxisId, Field(examples=["x", "y", "z"])] = AxisId("x")
    unit: Optional[SpaceUnit] = None
    scale: Annotated[float, Gt(0)] = 1.0
class SpaceInputAxis(SpaceAxisBase, _WithInputAxisSize):
    """A spatial axis of an input tensor."""

    concatenable: bool = False
    """If a model has a `concatenable` input axis, it can be processed blockwise,
    splitting a longer sample axis into blocks matching its input tensor description.
    Output axes are concatenable if they have a [SizeReference][] to a concatenable
    input axis.
    """
INPUT_AXIS_TYPES = (
    BatchAxis,
    ChannelAxis,
    IndexInputAxis,
    TimeInputAxis,
    SpaceInputAxis,
)
"""intended for isinstance comparisons in py<3.10"""

_InputAxisUnion = Union[
    BatchAxis, ChannelAxis, IndexInputAxis, TimeInputAxis, SpaceInputAxis
]
# input axes are discriminated by their `type` field
InputAxis = Annotated[_InputAxisUnion, Discriminator("type")]
class _WithOutputAxisSize(Node):
    """Mixin adding the `size` field shared by sized output axes (no parameterized sizes)."""

    size: Annotated[
        Union[Annotated[int, Gt(0)], SizeReference],
        Field(
            examples=[
                10,
                SizeReference(
                    tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
                ).model_dump(mode="json"),
            ]
        ),
    ]
    """The size/length of this axis can be specified as
    - fixed integer
    - reference to another axis with an optional offset (see [SizeReference][])
    """
class TimeOutputAxis(TimeAxisBase, _WithOutputAxisSize):
    """A time axis of an output tensor (without halo)."""

    pass
class TimeOutputAxisWithHalo(TimeAxisBase, WithHalo):
    """A time axis of an output tensor with a halo to crop (see [WithHalo][])."""

    pass
723def _get_halo_axis_discriminator_value(v: Any) -> Literal["with_halo", "wo_halo"]:
724 if isinstance(v, dict):
725 return "with_halo" if "halo" in v else "wo_halo"
726 else:
727 return "with_halo" if hasattr(v, "halo") else "wo_halo"
# time output axes are further discriminated by whether a halo is present
_TimeOutputAxisUnion = Annotated[
    Union[
        Annotated[TimeOutputAxis, Tag("wo_halo")],
        Annotated[TimeOutputAxisWithHalo, Tag("with_halo")],
    ],
    Discriminator(_get_halo_axis_discriminator_value),
]
class SpaceOutputAxis(SpaceAxisBase, _WithOutputAxisSize):
    """A spatial axis of an output tensor (without halo)."""

    pass
class SpaceOutputAxisWithHalo(SpaceAxisBase, WithHalo):
    """A spatial axis of an output tensor with a halo to crop (see [WithHalo][])."""

    pass
# space output axes are further discriminated by whether a halo is present
_SpaceOutputAxisUnion = Annotated[
    Union[
        Annotated[SpaceOutputAxis, Tag("wo_halo")],
        Annotated[SpaceOutputAxisWithHalo, Tag("with_halo")],
    ],
    Discriminator(_get_halo_axis_discriminator_value),
]

# output axes are discriminated by their `type` field
_OutputAxisUnion = Union[
    BatchAxis, ChannelAxis, IndexOutputAxis, _TimeOutputAxisUnion, _SpaceOutputAxisUnion
]
OutputAxis = Annotated[_OutputAxisUnion, Discriminator("type")]

OUTPUT_AXIS_TYPES = (
    BatchAxis,
    ChannelAxis,
    IndexOutputAxis,
    TimeOutputAxis,
    TimeOutputAxisWithHalo,
    SpaceOutputAxis,
    SpaceOutputAxisWithHalo,
)
"""intended for isinstance comparisons in py<3.10"""

AnyAxis = Union[InputAxis, OutputAxis]

ANY_AXIS_TYPES = INPUT_AXIS_TYPES + OUTPUT_AXIS_TYPES
"""intended for isinstance comparisons in py<3.10"""

# a homogeneous, non-empty list of tensor values (int, float, bool, or str labels)
TVs = Union[
    NotEmpty[List[int]],
    NotEmpty[List[float]],
    NotEmpty[List[bool]],
    NotEmpty[List[str]],
]
# data types permitted for nominal/ordinal data (note: includes "bool")
NominalOrOrdinalDType = Literal[
    "float32",
    "float64",
    "uint8",
    "int8",
    "uint16",
    "int16",
    "uint32",
    "int32",
    "uint64",
    "int64",
    "bool",
]
class NominalOrOrdinalDataDescr(Node):
    """Describes tensor data restricted to a fixed set of nominal or ordinal values."""

    values: TVs
    """A fixed set of nominal or an ascending sequence of ordinal values.
    In this case `data.type` is required to be an unsigned integer type, e.g. 'uint8'.
    String `values` are interpreted as labels for tensor values 0, ..., N.
    Note: as YAML 1.2 does not natively support a "set" datatype,
    nominal values should be given as a sequence (aka list/array) as well.
    """

    type: Annotated[
        NominalOrOrdinalDType,
        Field(
            examples=[
                "float32",
                "uint8",
                "uint16",
                "int64",
                "bool",
            ],
        ),
    ] = "uint8"

    @model_validator(mode="after")
    def _validate_values_match_type(
        self,
    ) -> Self:
        """Collect all `values` that `type` cannot represent and raise if there are any."""
        incompatible: List[Any] = []
        for v in self.values:
            if self.type == "bool":
                if not isinstance(v, bool):
                    incompatible.append(v)
            elif self.type in DTYPE_LIMITS:
                # a value is incompatible if: it is numeric but outside the
                # dtype's range, or a string label with a non-uint dtype, or a
                # float value with an integer dtype
                if (
                    isinstance(v, (int, float))
                    and (
                        v < DTYPE_LIMITS[self.type].min
                        or v > DTYPE_LIMITS[self.type].max
                    )
                    or (isinstance(v, str) and "uint" not in self.type)
                    or (isinstance(v, float) and "int" in self.type)
                ):
                    incompatible.append(v)
            else:
                # unknown dtype: reject all values
                incompatible.append(v)

            if len(incompatible) == 5:
                # cap the error report at 5 offending values
                incompatible.append("...")
                break

        if incompatible:
            raise ValueError(
                f"data type '{self.type}' incompatible with values {incompatible}"
            )

        return self

    unit: Optional[Union[Literal["arbitrary unit"], SiUnit]] = None

    @property
    def range(self):
        """(minimum, maximum) of the described values;
        string labels map to their indices 0, ..., N."""
        if isinstance(self.values[0], str):
            return 0, len(self.values) - 1
        else:
            return min(self.values), max(self.values)
# data types permitted for continuous (interval/ratio scale) data (no "bool")
IntervalOrRatioDType = Literal[
    "float32",
    "float64",
    "uint8",
    "int8",
    "uint16",
    "int16",
    "uint32",
    "int32",
    "uint64",
    "int64",
]
class IntervalOrRatioDataDescr(Node):
    """Describes continuous tensor data on an interval or ratio scale."""

    type: Annotated[  # TODO: rename to dtype
        IntervalOrRatioDType,
        Field(
            examples=["float32", "float64", "uint8", "uint16"],
        ),
    ] = "float32"
    range: Tuple[Optional[float], Optional[float]] = (
        None,
        None,
    )
    """Tuple `(minimum, maximum)` specifying the allowed range of the data in this tensor.
    `None` corresponds to min/max of what can be expressed by **type**."""
    unit: Union[Literal["arbitrary unit"], SiUnit] = "arbitrary unit"
    scale: float = 1.0
    """Scale for data on an interval (or ratio) scale."""
    offset: Optional[float] = None
    """Offset for data on a ratio scale."""

    @model_validator(mode="before")
    def _replace_inf(cls, data: Any):
        """Replace infinite `range` bounds with `None` (unbounded) and warn."""
        if is_dict(data):
            if "range" in data and is_sequence(data["range"]):
                # accept the YAML spellings of infinity as well as float inf
                forbidden = (
                    "inf",
                    "-inf",
                    ".inf",
                    "-.inf",
                    float("inf"),
                    float("-inf"),
                )
                if any(v in forbidden for v in data["range"]):
                    issue_warning("replaced 'inf' value", value=data["range"])

                data["range"] = tuple(
                    (None if v in forbidden else v) for v in data["range"]
                )

        return data
# tensor data is either a fixed value set (nominal/ordinal) or continuous (interval/ratio)
TensorDataDescr = Union[NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr]
class BinarizeKwargs(KwargsNode):
    """key word arguments for [BinarizeDescr][] using a single, global threshold"""

    threshold: float
    """The fixed threshold"""
class BinarizeAlongAxisKwargs(KwargsNode):
    """key word arguments for [BinarizeDescr][] using per-entry thresholds along `axis`"""

    threshold: NotEmpty[List[float]]
    """The fixed threshold values along `axis`"""

    axis: Annotated[NonBatchAxisId, Field(examples=["channel"])]
    """The `threshold` axis"""
class BinarizeDescr(NodeWithExplicitlySetFields):
    """Binarize the tensor with a fixed threshold.

    Values above [BinarizeKwargs.threshold][]/[BinarizeAlongAxisKwargs.threshold][]
    will be set to one, values below the threshold to zero.

    Examples:
    - in YAML
        ```yaml
        postprocessing:
          - id: binarize
            kwargs:
              axis: 'channel'
              threshold: [0.25, 0.5, 0.75]
        ```
    - in Python:
        >>> postprocessing = [BinarizeDescr(
        ...     kwargs=BinarizeAlongAxisKwargs(
        ...         axis=AxisId('channel'),
        ...         threshold=[0.25, 0.5, 0.75],
        ...     )
        ... )]
    """

    implemented_id: ClassVar[Literal["binarize"]] = "binarize"
    if TYPE_CHECKING:
        # default only under TYPE_CHECKING so static checkers allow omitting `id`
        id: Literal["binarize"] = "binarize"
    else:
        id: Literal["binarize"]
    # either one global threshold or per-entry thresholds along one axis
    kwargs: Union[BinarizeKwargs, BinarizeAlongAxisKwargs]
class ClipKwargs(KwargsNode):
    """key word arguments for [ClipDescr][]"""

    min: Optional[float] = None
    """Minimum value for clipping.

    Exclusive with [min_percentile][]
    """
    min_percentile: Optional[Annotated[float, Interval(ge=0, lt=100)]] = None
    """Minimum percentile for clipping.

    Exclusive with [min][].

    In range [0, 100).
    """

    max: Optional[float] = None
    """Maximum value for clipping.

    Exclusive with `max_percentile`.
    """
    max_percentile: Optional[Annotated[float, Interval(gt=1, le=100)]] = None
    """Maximum percentile for clipping.

    Exclusive with `max`.

    In range (1, 100].
    """

    axes: Annotated[
        Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
    ] = None
    """The subset of axes to determine percentiles jointly,
    i.e. axes to reduce to compute min/max from `min_percentile`/`max_percentile`.
    For example to clip 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
    resulting in a tensor of equal shape with clipped values per channel, specify `axes=('batch', 'x', 'y')`.
    To clip samples independently, leave out the 'batch' axis.

    Only valid if `min_percentile` and/or `max_percentile` are set.

    Default: Compute percentiles over all axes jointly."""

    @model_validator(mode="after")
    def _validate(self) -> Self:
        """Enforce mutual exclusivity of absolute vs percentile bounds,
        require at least one bound, and allow `axes` only with percentiles."""
        uses_min_percentile = self.min_percentile is not None
        uses_max_percentile = self.max_percentile is not None
        any_percentile = uses_min_percentile or uses_max_percentile

        if self.min is not None and uses_min_percentile:
            raise ValueError(
                "Only one of `min` and `min_percentile` may be set, not both."
            )

        if self.max is not None and uses_max_percentile:
            raise ValueError(
                "Only one of `max` and `max_percentile` may be set, not both."
            )

        if self.min is None and self.max is None and not any_percentile:
            raise ValueError(
                "At least one of `min`, `min_percentile`, `max`, or `max_percentile` must be set."
            )

        if self.axes is not None and not any_percentile:
            raise ValueError(
                "If `axes` is set, at least one of `min_percentile` or `max_percentile` must be set."
            )

        return self
class ClipDescr(NodeWithExplicitlySetFields):
    """Set tensor values below min to min and above max to max.

    See `ScaleRangeDescr` for examples.
    """

    implemented_id: ClassVar[Literal["clip"]] = "clip"
    if TYPE_CHECKING:
        # default only under TYPE_CHECKING so static checkers allow omitting `id`
        id: Literal["clip"] = "clip"
    else:
        id: Literal["clip"]

    kwargs: ClipKwargs
class EnsureDtypeKwargs(KwargsNode):
    """key word arguments for [EnsureDtypeDescr][]"""

    # the target data type to cast the tensor to
    dtype: Literal[
        "float32",
        "float64",
        "uint8",
        "int8",
        "uint16",
        "int16",
        "uint32",
        "int32",
        "uint64",
        "int64",
        "bool",
    ]
class EnsureDtypeDescr(NodeWithExplicitlySetFields):
    """Cast the tensor data type to `EnsureDtypeKwargs.dtype` (if not matching).

    This can for example be used to ensure the inner neural network model gets a
    different input tensor data type than the fully described bioimage.io model does.

    Examples:
        The described bioimage.io model (incl. preprocessing) accepts any
        float32-compatible tensor, normalizes it with percentiles and clipping and then
        casts it to uint8, which is what the neural network in this example expects.
        - in YAML
            ```yaml
            inputs:
            - data:
                type: float32  # described bioimage.io model is compatible with any float32 input tensor
              preprocessing:
              - id: scale_range
                kwargs:
                  axes: ['y', 'x']
                  max_percentile: 99.8
                  min_percentile: 5.0
              - id: clip
                kwargs:
                  min: 0.0
                  max: 1.0
              - id: ensure_dtype  # the neural network of the model requires uint8
                kwargs:
                  dtype: uint8
            ```
        - in Python:
            >>> preprocessing = [
            ...     ScaleRangeDescr(
            ...         kwargs=ScaleRangeKwargs(
            ...             axes= (AxisId('y'), AxisId('x')),
            ...             max_percentile= 99.8,
            ...             min_percentile= 5.0,
            ...         )
            ...     ),
            ...     ClipDescr(kwargs=ClipKwargs(min=0.0, max=1.0)),
            ...     EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="uint8")),
            ... ]
    """

    implemented_id: ClassVar[Literal["ensure_dtype"]] = "ensure_dtype"
    if TYPE_CHECKING:
        # default only under TYPE_CHECKING so static checkers allow omitting `id`
        id: Literal["ensure_dtype"] = "ensure_dtype"
    else:
        id: Literal["ensure_dtype"]

    kwargs: EnsureDtypeKwargs
class ScaleLinearKwargs(KwargsNode):
    """Key word arguments for [ScaleLinearDescr][]"""

    gain: float = 1.0
    """multiplicative factor"""

    offset: float = 0.0
    """additive term"""

    @model_validator(mode="after")
    def _validate(self) -> Self:
        """Reject the identity transform (gain 1, offset 0): it would be a no-op step."""
        if self.gain == 1.0 and self.offset == 0.0:
            # fixed typo in user-facing message: "allowd" -> "allowed"
            raise ValueError(
                "Redundant linear scaling not allowed. Set `gain` != 1.0 and/or `offset`"
                + " != 0.0."
            )

        return self
class ScaleLinearAlongAxisKwargs(KwargsNode):
    """Key word arguments for [ScaleLinearDescr][]"""

    axis: Annotated[NonBatchAxisId, Field(examples=["channel"])]
    """The axis of gain and offset values."""

    gain: Union[float, NotEmpty[List[float]]] = 1.0
    """multiplicative factor"""

    offset: Union[float, NotEmpty[List[float]]] = 0.0
    """additive term"""

    @model_validator(mode="after")
    def _validate(self) -> Self:
        """Normalize `gain`/`offset` to equal-length lists and reject no-op scaling.

        Raises:
            ValueError: if the list lengths differ, if both values are scalar
                (then specifying an `axis` is superfluous), or if the scaling
                is the identity transform.
        """
        if isinstance(self.gain, list):
            if isinstance(self.offset, list):
                if len(self.gain) != len(self.offset):
                    raise ValueError(
                        f"Size of `gain` ({len(self.gain)}) and `offset` ({len(self.offset)}) must match."
                    )
            else:
                # broadcast the scalar offset to match the gain list
                self.offset = [float(self.offset)] * len(self.gain)
        elif isinstance(self.offset, list):
            # broadcast the scalar gain to match the offset list
            self.gain = [float(self.gain)] * len(self.offset)
        else:
            raise ValueError(
                "Do not specify an `axis` for scalar gain and offset values."
            )

        # by now both gain and offset are lists of equal length
        if all(g == 1.0 for g in self.gain) and all(off == 0.0 for off in self.offset):
            # fixed typo in user-facing message: "allowd" -> "allowed"
            raise ValueError(
                "Redundant linear scaling not allowed. Set `gain` != 1.0 and/or `offset`"
                + " != 0.0."
            )

        return self
class ScaleLinearDescr(NodeWithExplicitlySetFields):
    """Fixed linear scaling.

    Examples:
      1. Scale with scalar gain and offset
        - in YAML
        ```yaml
        preprocessing:
          - id: scale_linear
            kwargs:
              gain: 2.0
              offset: 3.0
        ```
        - in Python:
        >>> preprocessing = [
        ...     ScaleLinearDescr(kwargs=ScaleLinearKwargs(gain= 2.0, offset=3.0))
        ... ]

      2. Independent scaling along an axis
        - in YAML
        ```yaml
        preprocessing:
          - id: scale_linear
            kwargs:
              axis: 'channel'
              gain: [1.0, 2.0, 3.0]
        ```
        - in Python:
        >>> preprocessing = [
        ...     ScaleLinearDescr(
        ...         kwargs=ScaleLinearAlongAxisKwargs(
        ...             axis=AxisId("channel"),
        ...             gain=[1.0, 2.0, 3.0],
        ...         )
        ...     )
        ... ]

    """

    implemented_id: ClassVar[Literal["scale_linear"]] = "scale_linear"
    if TYPE_CHECKING:
        # default value only visible to static type checkers, so that
        # `ScaleLinearDescr(kwargs=...)` type-checks without an explicit `id`
        id: Literal["scale_linear"] = "scale_linear"
    else:
        id: Literal["scale_linear"]

    # scalar kwargs scale the whole tensor; the along-axis variant scales
    # per entry along a given (non-batch) axis
    kwargs: Union[ScaleLinearKwargs, ScaleLinearAlongAxisKwargs]
class SigmoidDescr(NodeWithExplicitlySetFields):
    """The logistic sigmoid function, a.k.a. expit function.

    Examples:
    - in YAML
        ```yaml
        postprocessing:
          - id: sigmoid
        ```
    - in Python:
        >>> postprocessing = [SigmoidDescr()]
    """

    implemented_id: ClassVar[Literal["sigmoid"]] = "sigmoid"
    if TYPE_CHECKING:
        # default value only visible to static type checkers, so that
        # `SigmoidDescr()` type-checks without an explicit `id`
        id: Literal["sigmoid"] = "sigmoid"
    else:
        id: Literal["sigmoid"]

    @property
    def kwargs(self) -> KwargsNode:
        """empty kwargs

        (sigmoid takes no parameters; the property keeps the interface
        uniform with processing descriptors that have a `kwargs` field)
        """
        return KwargsNode()
class SoftmaxKwargs(KwargsNode):
    """key word arguments for [SoftmaxDescr][]"""

    # `NonBatchAxisId` presumably excludes the batch axis -- TODO confirm
    # against its definition
    axis: Annotated[NonBatchAxisId, Field(examples=["channel"])] = AxisId("channel")
    """The axis to apply the softmax function along.

    Note:
        Defaults to 'channel' axis
        (which may not exist, in which case
        a different axis id has to be specified).
    """
class SoftmaxDescr(NodeWithExplicitlySetFields):
    """The softmax function.

    Examples:
    - in YAML
        ```yaml
        postprocessing:
          - id: softmax
            kwargs:
              axis: channel
        ```
    - in Python:
        >>> postprocessing = [SoftmaxDescr(kwargs=SoftmaxKwargs(axis=AxisId("channel")))]
    """

    implemented_id: ClassVar[Literal["softmax"]] = "softmax"
    if TYPE_CHECKING:
        # default value only visible to static type checkers
        id: Literal["softmax"] = "softmax"
    else:
        id: Literal["softmax"]

    # default kwargs are created via pydantic's `model_construct`,
    # i.e. without running validation on the default values
    kwargs: SoftmaxKwargs = Field(default_factory=SoftmaxKwargs.model_construct)
class _StardistPostprocessingKwargsBase(KwargsNode):
    """key word arguments for [StardistPostprocessingDescr][]"""

    # shared base of the 2D and 3D stardist kwargs variants defined below

    prob_threshold: float
    """The probability threshold for object candidate selection."""

    nms_threshold: float
    """The IoU threshold for non-maximum suppression."""
class StardistPostprocessingKwargs2D(_StardistPostprocessingKwargsBase):
    # 2D variant: two entries per spatial dimension in `grid`/`b`
    grid: Tuple[int, int]
    """Grid size of network predictions."""

    b: Union[int, Tuple[Tuple[int, int], Tuple[int, int]]]
    """Border region in which object probability is set to zero."""
class StardistPostprocessingKwargs3D(_StardistPostprocessingKwargsBase):
    # 3D variant: one entry per spatial dimension in `grid`/`b`/`anisotropy`
    grid: Tuple[int, int, int]
    """Grid size of network predictions."""

    b: Union[int, Tuple[Tuple[int, int], Tuple[int, int], Tuple[int, int]]]
    """Border region in which object probability is set to zero."""

    n_rays: int
    """Number of rays for 3D star-convex polyhedra."""

    anisotropy: Tuple[float, float, float]
    """Anisotropy factors for 3D star-convex polyhedra, i.e. the physical pixel size along each spatial axis."""

    overlap_label: Optional[int] = None
    """Optional label to apply to any area of overlapping predicted objects."""
class StardistPostprocessingDescr(NodeWithExplicitlySetFields):
    """Stardist postprocessing including non-maximum suppression and converting polygon representations to instance labels

    as described in:
    - Uwe Schmidt, Martin Weigert, Coleman Broaddus, and Gene Myers.
      [*Cell Detection with Star-convex Polygons*](https://arxiv.org/abs/1806.03535).
      International Conference on Medical Image Computing and Computer-Assisted Intervention (MICCAI), Granada, Spain, September 2018.
    - Martin Weigert, Uwe Schmidt, Robert Haase, Ko Sugawara, and Gene Myers.
      [*Star-convex Polyhedra for 3D Object Detection and Segmentation in Microscopy*](http://openaccess.thecvf.com/content_WACV_2020/papers/Weigert_Star-convex_Polyhedra_for_3D_Object_Detection_and_Segmentation_in_Microscopy_WACV_2020_paper.pdf).
      The IEEE Winter Conference on Applications of Computer Vision (WACV), Snowmass Village, Colorado, March 2020.

    Note: Only available if the `stardist` package is installed.
    """

    implemented_id: ClassVar[Literal["stardist_postprocessing"]] = (
        "stardist_postprocessing"
    )
    if TYPE_CHECKING:
        # default value only visible to static type checkers
        id: Literal["stardist_postprocessing"] = "stardist_postprocessing"
    else:
        id: Literal["stardist_postprocessing"]

    # non-discriminated union: members differ in their (tuple) field shapes
    kwargs: Union[StardistPostprocessingKwargs2D, StardistPostprocessingKwargs3D]
class FixedZeroMeanUnitVarianceKwargs(KwargsNode):
    """key word arguments for [FixedZeroMeanUnitVarianceDescr][]"""

    mean: float
    """The mean value to normalize with."""

    # lower bound guards against division by (near) zero
    std: Annotated[float, Ge(1e-6)]
    """The standard deviation value to normalize with."""
class FixedZeroMeanUnitVarianceAlongAxisKwargs(KwargsNode):
    """key word arguments for [FixedZeroMeanUnitVarianceDescr][]"""

    mean: NotEmpty[List[float]]
    """The mean value(s) to normalize with."""

    # lower bound per entry guards against division by (near) zero
    std: NotEmpty[List[Annotated[float, Ge(1e-6)]]]
    """The standard deviation value(s) to normalize with.
    Size must match `mean` values."""

    axis: Annotated[NonBatchAxisId, Field(examples=["channel", "index"])]
    """The axis of the mean/std values to normalize each entry along that dimension
    separately."""

    @model_validator(mode="after")
    def _mean_and_std_match(self) -> Self:
        """Ensure `mean` and `std` are paired per entry along `axis`."""
        if len(self.mean) != len(self.std):
            raise ValueError(
                f"Size of `mean` ({len(self.mean)}) and `std` ({len(self.std)})"
                + " must match."
            )

        return self
class FixedZeroMeanUnitVarianceDescr(NodeWithExplicitlySetFields):
    """Subtract a given mean and divide by the standard deviation.

    Normalize with fixed, precomputed values for
    `FixedZeroMeanUnitVarianceKwargs.mean` and `FixedZeroMeanUnitVarianceKwargs.std`
    Use `FixedZeroMeanUnitVarianceAlongAxisKwargs` for independent scaling along given
    axes.

    Examples:
    1. scalar value for whole tensor
        - in YAML
        ```yaml
        preprocessing:
          - id: fixed_zero_mean_unit_variance
            kwargs:
              mean: 103.5
              std: 13.7
        ```
        - in Python
        >>> preprocessing = [FixedZeroMeanUnitVarianceDescr(
        ...     kwargs=FixedZeroMeanUnitVarianceKwargs(mean=103.5, std=13.7)
        ... )]

    2. independently along an axis
        - in YAML
        ```yaml
        preprocessing:
          - id: fixed_zero_mean_unit_variance
            kwargs:
              axis: channel
              mean: [101.5, 102.5, 103.5]
              std: [11.7, 12.7, 13.7]
        ```
        - in Python
        >>> preprocessing = [FixedZeroMeanUnitVarianceDescr(
        ...     kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs(
        ...         axis=AxisId("channel"),
        ...         mean=[101.5, 102.5, 103.5],
        ...         std=[11.7, 12.7, 13.7],
        ...     )
        ... )]
    """

    implemented_id: ClassVar[Literal["fixed_zero_mean_unit_variance"]] = (
        "fixed_zero_mean_unit_variance"
    )
    if TYPE_CHECKING:
        # default value only visible to static type checkers
        id: Literal["fixed_zero_mean_unit_variance"] = "fixed_zero_mean_unit_variance"
    else:
        id: Literal["fixed_zero_mean_unit_variance"]

    # scalar kwargs for the whole tensor or per-entry kwargs along an axis
    kwargs: Union[
        FixedZeroMeanUnitVarianceKwargs, FixedZeroMeanUnitVarianceAlongAxisKwargs
    ]
class ZeroMeanUnitVarianceKwargs(KwargsNode):
    """key word arguments for [ZeroMeanUnitVarianceDescr][]"""

    axes: Annotated[
        Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
    ] = None
    """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std.
    For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
    resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
    To normalize each sample independently leave out the 'batch' axis.
    Default: Scale all axes jointly."""

    # only serves numeric stability (see formula in the docstring below)
    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
    """epsilon for numeric stability: `out = (tensor - mean) / (std + eps)`."""
class ZeroMeanUnitVarianceDescr(NodeWithExplicitlySetFields):
    """Subtract mean and divide by variance.

    Examples:
        Subtract tensor mean and variance
        - in YAML
        ```yaml
        preprocessing:
          - id: zero_mean_unit_variance
        ```
        - in Python
        >>> preprocessing = [ZeroMeanUnitVarianceDescr()]
    """

    implemented_id: ClassVar[Literal["zero_mean_unit_variance"]] = (
        "zero_mean_unit_variance"
    )
    if TYPE_CHECKING:
        # default value only visible to static type checkers
        id: Literal["zero_mean_unit_variance"] = "zero_mean_unit_variance"
    else:
        id: Literal["zero_mean_unit_variance"]

    # default kwargs are created via pydantic's `model_construct`,
    # i.e. without running validation on the default values
    kwargs: ZeroMeanUnitVarianceKwargs = Field(
        default_factory=ZeroMeanUnitVarianceKwargs.model_construct
    )
class ScaleRangeKwargs(KwargsNode):
    """key word arguments for [ScaleRangeDescr][]

    For `min_percentile`=0.0 (the default) and `max_percentile`=100 (the default)
    this processing step normalizes data to the [0, 1] interval.
    For other percentiles the normalized values will partially be outside the [0, 1]
    interval. Use `ScaleRange` followed by `ClipDescr` if you want to limit the
    normalized values to a range.
    """

    axes: Annotated[
        Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
    ] = None
    """The subset of axes to normalize jointly, i.e. axes to reduce to compute the min/max percentile value.
    For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
    resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
    To normalize samples independently, leave out the "batch" axis.
    Default: Scale all axes jointly."""

    min_percentile: Annotated[float, Interval(ge=0, lt=100)] = 0.0
    """The lower percentile used to determine the value to align with zero."""

    max_percentile: Annotated[float, Interval(gt=1, le=100)] = 100.0
    """The upper percentile used to determine the value to align with one.
    Has to be bigger than `min_percentile`.
    The range is 1 to 100 instead of 0 to 100 to avoid mistakenly
    accepting percentiles specified in the range 0.0 to 1.0."""

    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
    """Epsilon for numeric stability.
    `out = (tensor - v_lower) / (v_upper - v_lower + eps)`;
    with `v_lower,v_upper` values at the respective percentiles."""

    reference_tensor: Optional[TensorId] = None
    """ID of the unprocessed input tensor to compute the percentiles from.
    Default: The tensor itself.
    """

    @field_validator("max_percentile", mode="after")
    @classmethod
    def min_smaller_max(cls, value: float, info: ValidationInfo) -> float:
        """Ensure `min_percentile` < `max_percentile`."""
        # `min_percentile` is declared before `max_percentile`, so its
        # validated value is already available in `info.data` here
        # (pydantic validates fields in declaration order)
        if (min_p := info.data["min_percentile"]) >= value:
            raise ValueError(f"min_percentile {min_p} >= max_percentile {value}")

        return value
class ScaleRangeDescr(NodeWithExplicitlySetFields):
    """Scale with percentiles.

    Examples:
    1. Scale linearly to map 5th percentile to 0 and 99.8th percentile to 1.0
        - in YAML
        ```yaml
        preprocessing:
          - id: scale_range
            kwargs:
              axes: ['y', 'x']
              max_percentile: 99.8
              min_percentile: 5.0
        ```
        - in Python
        >>> preprocessing = [
        ...     ScaleRangeDescr(
        ...         kwargs=ScaleRangeKwargs(
        ...             axes= (AxisId('y'), AxisId('x')),
        ...             max_percentile= 99.8,
        ...             min_percentile= 5.0,
        ...         )
        ...     )
        ... ]

    2. Combine the above scaling with additional clipping to clip values outside the range given by the percentiles.
        - in YAML
        ```yaml
        preprocessing:
          - id: scale_range
            kwargs:
              axes: ['y', 'x']
              max_percentile: 99.8
              min_percentile: 5.0
          - id: clip
            kwargs:
              min: 0.0
              max: 1.0
        ```
        - in Python
        >>> preprocessing = [
        ...     ScaleRangeDescr(
        ...         kwargs=ScaleRangeKwargs(
        ...             axes= (AxisId('y'), AxisId('x')),
        ...             max_percentile= 99.8,
        ...             min_percentile= 5.0,
        ...         )
        ...     ),
        ...     ClipDescr(
        ...         kwargs=ClipKwargs(
        ...             min=0.0,
        ...             max=1.0,
        ...         )
        ...     ),
        ... ]

    """

    implemented_id: ClassVar[Literal["scale_range"]] = "scale_range"
    if TYPE_CHECKING:
        # default value only visible to static type checkers
        id: Literal["scale_range"] = "scale_range"
    else:
        id: Literal["scale_range"]
    # default kwargs are created via pydantic's `model_construct` (no validation)
    kwargs: ScaleRangeKwargs = Field(default_factory=ScaleRangeKwargs.model_construct)
class ScaleMeanVarianceKwargs(KwargsNode):
    """key word arguments for [ScaleMeanVarianceDescr][]"""

    reference_tensor: TensorId
    """ID of unprocessed input tensor to match."""

    axes: Annotated[
        Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
    ] = None
    """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std.
    For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
    resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
    To normalize samples independently, leave out the 'batch' axis.
    Default: Scale all axes jointly."""

    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
    """Epsilon for numeric stability:
    `out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.`"""
class ScaleMeanVarianceDescr(NodeWithExplicitlySetFields):
    """Scale a tensor's data distribution to match another tensor's mean/std.
    `out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.`
    """

    implemented_id: ClassVar[Literal["scale_mean_variance"]] = "scale_mean_variance"
    if TYPE_CHECKING:
        # default value only visible to static type checkers
        id: Literal["scale_mean_variance"] = "scale_mean_variance"
    else:
        id: Literal["scale_mean_variance"]
    kwargs: ScaleMeanVarianceKwargs
# All (pre)processing steps allowed for input tensors;
# pydantic resolves the union member via the literal `id` field.
PreprocessingDescr = Annotated[
    Union[
        BinarizeDescr,
        ClipDescr,
        EnsureDtypeDescr,
        FixedZeroMeanUnitVarianceDescr,
        ScaleLinearDescr,
        ScaleRangeDescr,
        SigmoidDescr,
        SoftmaxDescr,
        ZeroMeanUnitVarianceDescr,
    ],
    Discriminator("id"),
]
# All (post)processing steps allowed for output tensors; superset of the
# preprocessing steps plus `scale_mean_variance` and `stardist_postprocessing`.
PostprocessingDescr = Annotated[
    Union[
        BinarizeDescr,
        ClipDescr,
        EnsureDtypeDescr,
        FixedZeroMeanUnitVarianceDescr,
        ScaleLinearDescr,
        ScaleMeanVarianceDescr,
        ScaleRangeDescr,
        SigmoidDescr,
        SoftmaxDescr,
        StardistPostprocessingDescr,
        ZeroMeanUnitVarianceDescr,
    ],
    Discriminator("id"),
]

# type variable for `TensorDescrBase`: constrained to input or output axes
IO_AxisT = TypeVar("IO_AxisT", InputAxis, OutputAxis)
class TensorDescrBase(Node, Generic[IO_AxisT]):
    # generic base of `InputTensorDescr`/`OutputTensorDescr`;
    # `IO_AxisT` is `InputAxis` or `OutputAxis` respectively

    id: TensorId
    """Tensor id. No duplicates are allowed."""

    description: Annotated[str, MaxLen(128)] = ""
    """free text description"""

    axes: NotEmpty[Sequence[IO_AxisT]]
    """tensor axes"""

    @property
    def shape(self):
        # axis sizes in axis order; entries are whatever each axis' `size`
        # holds (int, `ParameterizedSize`, `SizeReference`, ...)
        return tuple(a.size for a in self.axes)

    @field_validator("axes", mode="after", check_fields=False)
    @classmethod
    def _validate_axes(cls, axes: Sequence[AnyAxis]) -> Sequence[AnyAxis]:
        """Ensure at most one batch axis and no duplicate axis ids."""
        batch_axes = [a for a in axes if a.type == "batch"]
        if len(batch_axes) > 1:
            raise ValueError(
                f"Only one batch axis (per tensor) allowed, but got {batch_axes}"
            )

        seen_ids: Set[AxisId] = set()
        duplicate_axes_ids: Set[AxisId] = set()
        for a in axes:
            # first occurrence goes to `seen_ids`, repeats to `duplicate_axes_ids`
            (duplicate_axes_ids if a.id in seen_ids else seen_ids).add(a.id)

        if duplicate_axes_ids:
            raise ValueError(f"Duplicate axis ids: {duplicate_axes_ids}")

        return axes

    test_tensor: FAIR[Optional[FileDescr_]] = None
    """An example tensor to use for testing.
    Using the model with the test input tensors is expected to yield the test output tensors.
    Each test tensor has be a an ndarray in the
    [numpy.lib file format](https://numpy.org/doc/stable/reference/generated/numpy.lib.format.html#module-numpy.lib.format).
    The file extension must be '.npy'."""

    sample_tensor: FAIR[Optional[FileDescr_]] = None
    """A sample tensor to illustrate a possible input/output for the model,
    The sample image primarily serves to inform a human user about an example use case
    and is typically stored as .hdf5, .png or .tiff.
    It has to be readable by the [imageio library](https://imageio.readthedocs.io/en/stable/formats/index.html#supported-formats)
    (numpy's `.npy` format is not supported).
    The image dimensionality has to match the number of axes specified in this tensor description.
    """

    @model_validator(mode="after")
    def _validate_sample_tensor(self) -> Self:
        """Check that the sample image's dimensionality can match this tensor's axes."""
        if self.sample_tensor is None or not get_validation_context().perform_io_checks:
            return self

        reader = get_reader(self.sample_tensor.source, sha256=self.sample_tensor.sha256)
        tensor: NDArray[Any] = imread(  # pyright: ignore[reportUnknownVariableType]
            reader.read(),
            extension=PurePosixPath(reader.original_file_name).suffix,
        )
        n_dims = len(tensor.squeeze().shape)
        n_dims_min = n_dims_max = len(self.axes)

        # any axis that may be a singleton could have been squeezed away,
        # lowering the acceptable minimum number of dimensions
        for a in self.axes:
            if isinstance(a, BatchAxis):
                n_dims_min -= 1
            elif isinstance(a.size, int):
                if a.size == 1:
                    n_dims_min -= 1
            elif isinstance(a.size, (ParameterizedSize, DataDependentSize)):
                if a.size.min == 1:
                    n_dims_min -= 1
            elif isinstance(a.size, SizeReference):
                if a.size.offset < 2:
                    # size reference may result in singleton axis
                    n_dims_min -= 1
            else:
                assert_never(a.size)

        n_dims_min = max(0, n_dims_min)
        if n_dims < n_dims_min or n_dims > n_dims_max:
            raise ValueError(
                f"Expected sample tensor to have {n_dims_min} to"
                + f" {n_dims_max} dimensions, but found {n_dims} (shape: {tensor.shape})."
            )

        return self

    data: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]] = (
        IntervalOrRatioDataDescr()
    )
    """Description of the tensor's data values, optionally per channel.
    If specified per channel, the data `type` needs to match across channels."""

    @property
    def dtype(
        self,
    ) -> Literal[
        "float32",
        "float64",
        "uint8",
        "int8",
        "uint16",
        "int16",
        "uint32",
        "int32",
        "uint64",
        "int64",
        "bool",
    ]:
        """dtype as specified under `data.type` or `data[i].type`"""
        # per-channel data descriptions agree in `type` (enforced by
        # `_check_data_type_across_channels`), so the first entry is representative
        if isinstance(self.data, collections.abc.Sequence):
            return self.data[0].type
        else:
            return self.data.type

    @field_validator("data", mode="after")
    @classmethod
    def _check_data_type_across_channels(
        cls, value: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]]
    ) -> Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]]:
        """Require a common data `type` across per-channel data descriptions."""
        if not isinstance(value, list):
            return value

        dtypes = {t.type for t in value}
        if len(dtypes) > 1:
            raise ValueError(
                "Tensor data descriptions per channel need to agree in their data"
                + f" `type`, but found {dtypes}."
            )

        return value

    @model_validator(mode="after")
    def _check_data_matches_channelaxis(self) -> Self:
        """Require one data description per channel if `data` is a sequence."""
        if not isinstance(self.data, (list, tuple)):
            return self

        for a in self.axes:
            if isinstance(a, ChannelAxis):
                size = a.size
                assert isinstance(size, int)
                break
        else:
            # no channel axis to compare against
            return self

        if len(self.data) != size:
            raise ValueError(
                f"Got tensor data descriptions for {len(self.data)} channels, but"
                + f" '{a.id}' axis has size {size}."
            )

        return self

    def get_axis_sizes_for_array(self, array: NDArray[Any]) -> Dict[AxisId, int]:
        """Map this tensor's axis ids to the dimension sizes of `array`.

        Raises:
            ValueError: if `array`'s rank does not match the number of axes.
        """
        if len(array.shape) != len(self.axes):
            raise ValueError(
                f"Dimension mismatch: array shape {array.shape} (#{len(array.shape)})"
                + f" incompatible with {len(self.axes)} axes."
            )
        return {a.id: array.shape[i] for i, a in enumerate(self.axes)}
class InputTensorDescr(TensorDescrBase[InputAxis]):
    id: TensorId = TensorId("input")
    """Input tensor id.
    No duplicates are allowed across all inputs and outputs."""

    optional: bool = False
    """indicates that this tensor may be `None`"""

    preprocessing: List[PreprocessingDescr] = Field(
        default_factory=cast(Callable[[], List[PreprocessingDescr]], list)
    )
    """Description of how this input should be preprocessed.

    notes:
    - If preprocessing does not start with an 'ensure_dtype' entry, it is added
      to ensure an input tensor's data type matches the input tensor's data description.
    - If preprocessing does not end with an 'ensure_dtype' or 'binarize' entry, an
      'ensure_dtype' step is added to ensure preprocessing steps are not unintentionally
      changing the data type.
    """

    @model_validator(mode="after")
    def _validate_preprocessing_kwargs(self) -> Self:
        """Check `kwargs.axes` entries and add the dtype casts described in
        the `preprocessing` notes above."""
        axes_ids = [a.id for a in self.axes]
        for p in self.preprocessing:
            # any `axes` given in a preprocessing step must reference this
            # tensor's axes
            kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes")
            if kwargs_axes is None:
                continue

            if not isinstance(kwargs_axes, collections.abc.Sequence):
                raise ValueError(
                    f"Expected `preprocessing.i.kwargs.axes` to be a sequence, but got {type(kwargs_axes)}"
                )

            if any(a not in axes_ids for a in kwargs_axes):
                raise ValueError(
                    "`preprocessing.i.kwargs.axes` needs to be subset of axes ids"
                )

        if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)):
            dtype = self.data.type
        else:
            # per-channel data descriptions agree in `type` (validated in base class)
            dtype = self.data[0].type

        # ensure `preprocessing` begins with `EnsureDtypeDescr`
        if not self.preprocessing or not isinstance(
            self.preprocessing[0], EnsureDtypeDescr
        ):
            self.preprocessing.insert(
                0, EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype))
            )

        # ensure `preprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr`
        if not isinstance(self.preprocessing[-1], (EnsureDtypeDescr, BinarizeDescr)):
            self.preprocessing.append(
                EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype))
            )

        return self
def convert_axes(
    axes: str,
    *,
    shape: Union[
        Sequence[int], _ParameterizedInputShape_v0_4, _ImplicitOutputShape_v0_4
    ],
    tensor_type: Literal["input", "output"],
    halo: Optional[Sequence[int]],
    size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]],
):
    """Convert a v0.4 axes letter string plus shape information to v0.5 axis descriptions.

    Args:
        axes: v0.4 axes string with one letter per axis, e.g. "bcyx".
        shape: explicit axis sizes, a parameterized input shape, or an
            implicit output shape referencing another tensor.
        tensor_type: whether input or output axis descriptions are created.
        halo: optional per-axis halo values (used for output tensors only).
        size_refs: axis sizes of other tensors (by tensor name and axis
            letter), used to resolve implicit output shapes.

    Returns:
        List of v0.5 axis descriptions, one per letter in `axes`.
    """
    ret: List[AnyAxis] = []
    for i, a in enumerate(axes):
        axis_type = _AXIS_TYPE_MAP.get(a, a)
        if axis_type == "batch":
            ret.append(BatchAxis())
            continue

        scale = 1.0
        if isinstance(shape, _ParameterizedInputShape_v0_4):
            if shape.step[i] == 0:
                # step 0 means this axis has a fixed size
                size = shape.min[i]
            else:
                size = ParameterizedSize(min=shape.min[i], step=shape.step[i])
        elif isinstance(shape, _ImplicitOutputShape_v0_4):
            # the output size is derived from a referenced tensor
            ref_t = str(shape.reference_tensor)
            if ref_t.count(".") == 1:
                # "<tensor>.<axis>" style reference
                t_id, orig_a_id = ref_t.split(".")
            else:
                t_id = ref_t
                orig_a_id = a

            # NOTE(review): the fallback default is the current letter `a`,
            # not `orig_a_id` -- confirm this is intended for "<tensor>.<axis>"
            # references whose axis letter is not in `_AXIS_ID_MAP`
            a_id = _AXIS_ID_MAP.get(orig_a_id, a)
            if not (orig_scale := shape.scale[i]):
                # old way to insert a new axis dimension
                size = int(2 * shape.offset[i])
            else:
                scale = 1 / orig_scale
                if axis_type in ("channel", "index"):
                    # these axes no longer have a scale
                    offset_from_scale = orig_scale * size_refs.get(
                        _TensorName_v0_4(t_id), {}
                    ).get(orig_a_id, 0)
                else:
                    offset_from_scale = 0
                size = SizeReference(
                    tensor_id=TensorId(t_id),
                    axis_id=AxisId(a_id),
                    offset=int(offset_from_scale + 2 * shape.offset[i]),
                )
        else:
            # explicit size given
            size = shape[i]

        if axis_type == "time":
            if tensor_type == "input":
                ret.append(TimeInputAxis(size=size, scale=scale))
            else:
                assert not isinstance(size, ParameterizedSize)
                if halo is None:
                    ret.append(TimeOutputAxis(size=size, scale=scale))
                else:
                    assert not isinstance(size, int)
                    ret.append(
                        TimeOutputAxisWithHalo(size=size, scale=scale, halo=halo[i])
                    )

        elif axis_type == "index":
            if tensor_type == "input":
                ret.append(IndexInputAxis(size=size))
            else:
                if isinstance(size, ParameterizedSize):
                    size = DataDependentSize(min=size.min)

                ret.append(IndexOutputAxis(size=size))
        elif axis_type == "channel":
            assert not isinstance(size, ParameterizedSize)
            if isinstance(size, SizeReference):
                warnings.warn(
                    "Conversion of channel size from an implicit output shape may be"
                    + " wrong"
                )
                ret.append(
                    ChannelAxis(
                        channel_names=[
                            Identifier(f"channel{i}") for i in range(size.offset)
                        ]
                    )
                )
            else:
                ret.append(
                    ChannelAxis(
                        channel_names=[Identifier(f"channel{i}") for i in range(size)]
                    )
                )
        elif axis_type == "space":
            if tensor_type == "input":
                ret.append(SpaceInputAxis(id=AxisId(a), size=size, scale=scale))
            else:
                assert not isinstance(size, ParameterizedSize)
                if halo is None or halo[i] == 0:
                    ret.append(SpaceOutputAxis(id=AxisId(a), size=size, scale=scale))
                elif isinstance(size, int):
                    raise NotImplementedError(
                        f"output axis with halo and fixed size (here {size}) not allowed"
                    )
                else:
                    ret.append(
                        SpaceOutputAxisWithHalo(
                            id=AxisId(a), size=size, scale=scale, halo=halo[i]
                        )
                    )

    return ret
2013def _axes_letters_to_ids(
2014 axes: Optional[str],
2015) -> Optional[List[AxisId]]:
2016 if axes is None:
2017 return None
2019 return [AxisId(a) for a in axes]
2022def _get_complement_v04_axis(
2023 tensor_axes: Sequence[str], axes: Optional[Sequence[str]]
2024) -> Optional[AxisId]:
2025 if axes is None:
2026 return None
2028 non_complement_axes = set(axes) | {"b"}
2029 complement_axes = [a for a in tensor_axes if a not in non_complement_axes]
2030 if len(complement_axes) > 1:
2031 raise ValueError(
2032 f"Expected none or a single complement axis, but axes '{axes}' "
2033 + f"for tensor dims '{tensor_axes}' leave '{complement_axes}'."
2034 )
2036 return None if not complement_axes else AxisId(complement_axes[0])
def _convert_proc(
    p: Union[_PreprocessingDescr_v0_4, _PostprocessingDescr_v0_4],
    tensor_axes: Sequence[str],
) -> Union[PreprocessingDescr, PostprocessingDescr]:
    """Convert a single v0.4 (pre-/post-)processing description to v0.5.

    Args:
        p: v0.4 processing description.
        tensor_axes: v0.4 axes letters of the tensor `p` belongs to; used to
            determine the complement axis where v0.5 needs an explicit `axis`.
    """
    if isinstance(p, _BinarizeDescr_v0_4):
        return BinarizeDescr(kwargs=BinarizeKwargs(threshold=p.kwargs.threshold))
    elif isinstance(p, _ClipDescr_v0_4):
        return ClipDescr(kwargs=ClipKwargs(min=p.kwargs.min, max=p.kwargs.max))
    elif isinstance(p, _SigmoidDescr_v0_4):
        return SigmoidDescr()
    elif isinstance(p, _ScaleLinearDescr_v0_4):
        axes = _axes_letters_to_ids(p.kwargs.axes)  # NOTE(review): unused local
        if p.kwargs.axes is None:
            axis = None
        else:
            axis = _get_complement_v04_axis(tensor_axes, p.kwargs.axes)

        if axis is None:
            # scalar gain/offset apply to the whole tensor
            assert not isinstance(p.kwargs.gain, list)
            assert not isinstance(p.kwargs.offset, list)
            kwargs = ScaleLinearKwargs(gain=p.kwargs.gain, offset=p.kwargs.offset)
        else:
            kwargs = ScaleLinearAlongAxisKwargs(
                axis=axis, gain=p.kwargs.gain, offset=p.kwargs.offset
            )
        return ScaleLinearDescr(kwargs=kwargs)
    elif isinstance(p, _ScaleMeanVarianceDescr_v0_4):
        return ScaleMeanVarianceDescr(
            kwargs=ScaleMeanVarianceKwargs(
                axes=_axes_letters_to_ids(p.kwargs.axes),
                reference_tensor=TensorId(str(p.kwargs.reference_tensor)),
                eps=p.kwargs.eps,
            )
        )
    elif isinstance(p, _ZeroMeanUnitVarianceDescr_v0_4):
        if p.kwargs.mode == "fixed":
            # "fixed" mode maps to the v0.5 `fixed_zero_mean_unit_variance` step
            mean = p.kwargs.mean
            std = p.kwargs.std
            assert mean is not None
            assert std is not None

            axis = _get_complement_v04_axis(tensor_axes, p.kwargs.axes)

            if axis is None:
                if isinstance(mean, list):
                    raise ValueError("Expected single float value for mean, not <list>")
                if isinstance(std, list):
                    raise ValueError("Expected single float value for std, not <list>")
                return FixedZeroMeanUnitVarianceDescr(
                    kwargs=FixedZeroMeanUnitVarianceKwargs.model_construct(
                        mean=mean,
                        std=std,
                    )
                )
            else:
                # broadcast scalars to single-element lists for the
                # along-axis kwargs variant
                if not isinstance(mean, list):
                    mean = [float(mean)]
                if not isinstance(std, list):
                    std = [float(std)]

                return FixedZeroMeanUnitVarianceDescr(
                    kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs(
                        axis=axis, mean=mean, std=std
                    )
                )

        else:
            axes = _axes_letters_to_ids(p.kwargs.axes) or []
            if p.kwargs.mode == "per_dataset":
                # per-dataset statistics include the batch axis
                axes = [AxisId("batch")] + axes
            if not axes:
                axes = None
            return ZeroMeanUnitVarianceDescr(
                kwargs=ZeroMeanUnitVarianceKwargs(axes=axes, eps=p.kwargs.eps)
            )

    elif isinstance(p, _ScaleRangeDescr_v0_4):
        return ScaleRangeDescr(
            kwargs=ScaleRangeKwargs(
                axes=_axes_letters_to_ids(p.kwargs.axes),
                min_percentile=p.kwargs.min_percentile,
                max_percentile=p.kwargs.max_percentile,
                eps=p.kwargs.eps,
            )
        )
    else:
        assert_never(p)
class _InputTensorConv(
    Converter[
        _InputTensorDescr_v0_4,
        InputTensorDescr,
        FileSource_,
        Optional[FileSource_],
        Mapping[_TensorName_v0_4, Mapping[str, int]],
    ]
):
    # converts a v0.4 input tensor description to its v0.5 equivalent

    def _convert(
        self,
        src: _InputTensorDescr_v0_4,
        tgt: "type[InputTensorDescr] | type[dict[str, Any]]",
        test_tensor: FileSource_,
        sample_tensor: Optional[FileSource_],
        size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]],
    ) -> "InputTensorDescr | dict[str, Any]":
        axes: List[InputAxis] = convert_axes(  # pyright: ignore[reportAssignmentType]
            src.axes,
            shape=src.shape,
            tensor_type="input",
            halo=None,
            size_refs=size_refs,
        )
        prep: List[PreprocessingDescr] = []
        for p in src.preprocessing:
            cp = _convert_proc(p, src.axes)
            # postprocessing-only steps cannot appear in `src.preprocessing`
            assert not isinstance(
                cp, (ScaleMeanVarianceDescr, StardistPostprocessingDescr)
            )
            prep.append(cp)

        # NOTE(review): appends a final float32 cast; presumably v0.4
        # preprocessing always yields float32 -- confirm
        prep.append(EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="float32")))

        return tgt(
            axes=axes,
            id=TensorId(str(src.name)),
            test_tensor=FileDescr(source=test_tensor),
            sample_tensor=(
                None if sample_tensor is None else FileDescr(source=sample_tensor)
            ),
            data=dict(type=src.data_type),  # pyright: ignore[reportArgumentType]
            preprocessing=prep,
        )
# module-level converter instance for v0.4 -> v0.5 input tensor descriptions
_input_tensor_conv = _InputTensorConv(_InputTensorDescr_v0_4, InputTensorDescr)
class OutputTensorDescr(TensorDescrBase[OutputAxis]):
    id: TensorId = TensorId("output")
    """Output tensor id.
    No duplicates are allowed across all inputs and outputs."""

    postprocessing: List[PostprocessingDescr] = Field(
        default_factory=cast(Callable[[], List[PostprocessingDescr]], list)
    )
    """Description of how this output should be postprocessed.

    note: `postprocessing` always ends with an 'ensure_dtype' operation.
    If not given this is added to cast to this tensor's `data.type`.
    """

    @model_validator(mode="after")
    def _validate_postprocessing_kwargs(self) -> Self:
        """Check `kwargs.axes` entries and append a final dtype cast if missing."""
        axes_ids = [a.id for a in self.axes]
        for p in self.postprocessing:
            # any `axes` given in a postprocessing step must reference this
            # tensor's axes
            kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes")
            if kwargs_axes is None:
                continue

            if not isinstance(kwargs_axes, collections.abc.Sequence):
                raise ValueError(
                    f"expected `axes` sequence, but got {type(kwargs_axes)}"
                )

            if any(a not in axes_ids for a in kwargs_axes):
                raise ValueError("`kwargs.axes` needs to be subset of axes ids")

        if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)):
            dtype = self.data.type
        else:
            # per-channel data descriptions agree in `type` (validated in base class)
            dtype = self.data[0].type

        # ensure `postprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr`
        if not self.postprocessing or not isinstance(
            self.postprocessing[-1], (EnsureDtypeDescr, BinarizeDescr)
        ):
            self.postprocessing.append(
                EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype))
            )
        return self
class _OutputTensorConv(
    Converter[
        _OutputTensorDescr_v0_4,
        OutputTensorDescr,
        FileSource_,
        Optional[FileSource_],
        Mapping[_TensorName_v0_4, Mapping[str, int]],
    ]
):
    def _convert(
        self,
        src: _OutputTensorDescr_v0_4,
        tgt: "type[OutputTensorDescr] | type[dict[str, Any]]",
        test_tensor: FileSource_,
        sample_tensor: Optional[FileSource_],
        size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]],
    ) -> "OutputTensorDescr | dict[str, Any]":
        """Convert a v0.4 output tensor description to its v0.5 counterpart."""
        # TODO: split convert_axes into convert_output_axes and convert_input_axes
        converted_axes: List[OutputAxis] = convert_axes(  # pyright: ignore[reportAssignmentType]
            src.axes,
            shape=src.shape,
            tensor_type="output",
            halo=src.halo,
            size_refs=size_refs,
        )

        data: Dict[str, Any] = {"type": src.data_type}
        if data["type"] == "bool":
            # boolean data implies the full value set
            data["values"] = [False, True]

        sample = None if sample_tensor is None else FileDescr(source=sample_tensor)
        converted_post = [_convert_proc(p, src.axes) for p in src.postprocessing]

        return tgt(
            axes=converted_axes,
            id=TensorId(str(src.name)),
            test_tensor=FileDescr(source=test_tensor),
            sample_tensor=sample,
            data=data,  # pyright: ignore[reportArgumentType]
            postprocessing=converted_post,
        )
# module-level converter instance for v0.4 -> v0.5 output tensor descriptions
_output_tensor_conv = _OutputTensorConv(_OutputTensorDescr_v0_4, OutputTensorDescr)


# union of all v0.5 tensor descriptions (inputs and outputs)
TensorDescr = Union[InputTensorDescr, OutputTensorDescr]
def validate_tensors(
    tensors: Mapping[TensorId, Tuple[TensorDescr, Optional[NDArray[Any]]]],
    tensor_origin: Literal[
        "source", "test_tensor"
    ] = "source",  # for more precise error messages
):
    """Validate arrays against their tensor descriptions.

    Checks each array's dtype and axis sizes (fixed, parameterized,
    data-dependent and referenced) against its description and rejects
    all-near-zero arrays that do not allow for meaningful testing.

    Args:
        tensors: Maps each tensor id to a `(description, array)` pair.
            An array may be `None`; then only the description's axis metadata
            is collected (e.g. to resolve size references of other tensors).
        tensor_origin: Where the arrays originate from;
            only used to phrase error messages more precisely.

    Raises:
        ValueError: If any given array is incompatible with its description.
    """
    # tensor id -> axis id -> (axis description, actual size or None)
    all_tensor_axes: Dict[TensorId, Dict[AxisId, Tuple[AnyAxis, Optional[int]]]] = {}

    def e_msg_location(d: TensorDescr):
        # error message prefix pointing at the described tensor
        return f"{'inputs' if isinstance(d, InputTensorDescr) else 'outputs'}[{d.id}]"

    # first pass: collect actual axis sizes (needed to resolve `SizeReference`s)
    for descr, array in tensors.values():
        if array is None:
            axis_sizes = {a.id: None for a in descr.axes}
        else:
            try:
                axis_sizes = descr.get_axis_sizes_for_array(array)
            except ValueError as e:
                raise ValueError(f"{e_msg_location(descr)} {e}")

        all_tensor_axes[descr.id] = {a.id: (a, axis_sizes[a.id]) for a in descr.axes}

    # second pass: validate dtype, value range and axis sizes
    for descr, array in tensors.values():
        if array is None:
            continue

        if descr.dtype in ("float32", "float64"):
            # for described float dtypes, integer arrays are accepted as well
            invalid_test_tensor_dtype = array.dtype.name not in (
                "float32",
                "float64",
                "uint8",
                "int8",
                "uint16",
                "int16",
                "uint32",
                "int32",
                "uint64",
                "int64",
            )
        else:
            invalid_test_tensor_dtype = array.dtype.name != descr.dtype

        if invalid_test_tensor_dtype:
            raise ValueError(
                f"{tensor_origin} data type '{array.dtype.name}' does not"
                + f" match described {e_msg_location(descr)}.dtype '{descr.dtype}'"
            )

        if array.min() > -1e-4 and array.max() < 1e-4:
            # all values are (close to) zero; such data is unsuitable for testing
            # (message fixed to match the +/-1e-4 threshold checked above)
            raise ValueError(
                "Output values are too small for reliable testing."
                + f" Values <=-1e-4 or >=1e-4 must be present in {tensor_origin}"
            )

        for a in descr.axes:
            actual_size = all_tensor_axes[descr.id][a.id][1]
            if actual_size is None:
                continue

            if a.size is None:
                continue

            if isinstance(a.size, int):
                if actual_size != a.size:
                    raise ValueError(
                        f"{e_msg_location(descr)}.axes[{a.id}]: {tensor_origin} axis "
                        + f"has incompatible size {actual_size}, expected {a.size}"
                    )
            elif isinstance(a.size, (ParameterizedSize, DataDependentSize)):
                # both size descriptions validate via their `validate_size`
                _ = a.size.validate_size(
                    actual_size,
                    f"{e_msg_location(descr)}.axes[{a.id}]: {tensor_origin} axis ",
                )
            elif isinstance(a.size, SizeReference):
                ref_tensor_axes = all_tensor_axes.get(a.size.tensor_id)
                if ref_tensor_axes is None:
                    raise ValueError(
                        f"{e_msg_location(descr)}.axes[{a.id}].size.tensor_id: Unknown tensor"
                        + f" reference '{a.size.tensor_id}', available: {list(all_tensor_axes)}"
                    )

                ref_axis, ref_size = ref_tensor_axes.get(a.size.axis_id, (None, None))
                if ref_axis is None or ref_size is None:
                    raise ValueError(
                        f"{e_msg_location(descr)}.axes[{a.id}].size.axis_id: Unknown tensor axis"
                        + f" reference '{a.size.tensor_id}.{a.size.axis_id}', available: {list(ref_tensor_axes)}"
                    )

                if a.unit != ref_axis.unit:
                    raise ValueError(
                        f"{e_msg_location(descr)}.axes[{a.id}].size: `SizeReference` requires"
                        + " axis and reference axis to have the same `unit`, but"
                        + f" {a.unit}!={ref_axis.unit}"
                    )

                # referenced size formula: ref_size * ref_scale / scale + offset
                if actual_size != (
                    expected_size := (
                        ref_size * ref_axis.scale / a.scale + a.size.offset
                    )
                ):
                    raise ValueError(
                        f"{e_msg_location(descr)}.{tensor_origin}: axis '{a.id}' of size"
                        + f" {actual_size} invalid for referenced size {ref_size};"
                        + f" expected {expected_size}"
                    )
            else:
                assert_never(a.size)
# a YAML file describing a conda environment with additional dependencies
FileDescr_dependencies = Annotated[
    FileDescr_,
    WithSuffix((".yaml", ".yml"), case_sensitive=True),
    Field(examples=[dict(source="environment.yaml")]),
]
class _ArchitectureCallableDescr(Node):
    # common base for architecture descriptions that reference a callable
    callable: Annotated[Identifier, Field(examples=["MyNetworkClass", "get_my_model"])]
    """Identifier of the callable that returns a torch.nn.Module instance."""

    kwargs: Dict[str, YamlValue] = Field(
        default_factory=cast(Callable[[], Dict[str, YamlValue]], dict)
    )
    """key word arguments for the `callable`"""
class ArchitectureFromFileDescr(_ArchitectureCallableDescr, FileDescr):
    source: Annotated[FileSource, AfterValidator(wo_special_file_name)]
    """Architecture source file"""

    @model_serializer(mode="wrap", when_used="unless-none")
    def _serialize(self, nxt: SerializerFunctionWrapHandler, info: SerializationInfo):
        # include the architecture source file when serializing for packaging
        return package_file_descr_serializer(self, nxt, info)
class ArchitectureFromLibraryDescr(_ArchitectureCallableDescr):
    # architecture provided by an importable (installed) library
    import_from: str
    """Where to import the callable from, i.e. `from <import_from> import <callable>`"""
class _ArchFileConv(
    Converter[
        _CallableFromFile_v0_4,
        ArchitectureFromFileDescr,
        Optional[Sha256],
        Dict[str, Any],
    ]
):
    def _convert(
        self,
        src: _CallableFromFile_v0_4,
        tgt: "type[ArchitectureFromFileDescr | dict[str, Any]]",
        sha256: Optional[Sha256],
        kwargs: Dict[str, Any],
    ) -> "ArchitectureFromFileDescr | dict[str, Any]":
        """Convert a v0.4 `<source>:<callable>` reference to an
        `ArchitectureFromFileDescr` (or plain dict).

        `src` is either a URL or a local path, optionally followed by
        `:<callable>` naming the callable within that file.
        """
        if src.startswith("http") and src.count(":") >= 2:
            # URL source: the callable follows the last colon.
            # rsplit keeps URLs with an explicit port intact,
            # e.g. "https://host:8080/arch.py:MyNet"
            source, callable_ = src.rsplit(":", 1)
        elif not src.startswith("http") and src.count(":") == 1:
            # local path source: "path/to/arch.py:MyNet"
            source, callable_ = src.split(":")
        else:
            # no callable encoded; use the whole string for both fields
            source = str(src)
            callable_ = str(src)
        return tgt(
            callable=Identifier(callable_),
            source=cast(FileSource_, source),
            sha256=sha256,
            kwargs=kwargs,
        )
# module-level converter instance for v0.4 "callable from file" entries
_arch_file_conv = _ArchFileConv(_CallableFromFile_v0_4, ArchitectureFromFileDescr)
class _ArchLibConv(
    Converter[
        _CallableFromDepencency_v0_4, ArchitectureFromLibraryDescr, Dict[str, Any]
    ]
):
    def _convert(
        self,
        src: _CallableFromDepencency_v0_4,
        tgt: "type[ArchitectureFromLibraryDescr | dict[str, Any]]",
        kwargs: Dict[str, Any],
    ) -> "ArchitectureFromLibraryDescr | dict[str, Any]":
        """Convert a dotted v0.4 import path to an `ArchitectureFromLibraryDescr`."""
        # everything up to the final dot is the module path,
        # the last segment names the callable
        parts = src.split(".")
        module_path = ".".join(parts[:-1])
        return tgt(
            import_from=module_path,
            callable=Identifier(parts[-1]),
            kwargs=kwargs,
        )
# module-level converter instance for v0.4 "callable from dependency" entries
_arch_lib_conv = _ArchLibConv(
    _CallableFromDepencency_v0_4, ArchitectureFromLibraryDescr
)
class WeightsEntryDescrBase(FileDescr):
    type: ClassVar[WeightsFormat]
    weights_format_name: ClassVar[str]  # human readable

    source: Annotated[FileSource, AfterValidator(wo_special_file_name)]
    """Source of the weights file."""

    authors: Optional[List[Author]] = None
    """Authors
    Either the person(s) that have trained this model resulting in the original weights file.
    (If this is the initial weights entry, i.e. it does not have a `parent`)
    Or the person(s) who have converted the weights to this weights format.
    (If this is a child weight, i.e. it has a `parent` field)
    """

    parent: Annotated[
        Optional[WeightsFormat], Field(examples=["pytorch_state_dict"])
    ] = None
    """The source weights these weights were converted from.
    For example, if a model's weights were converted from the `pytorch_state_dict` format to `torchscript`,
    The `pytorch_state_dict` weights entry has no `parent` and is the parent of the `torchscript` weights.
    All weight entries except one (the initial set of weights resulting from training the model),
    need to have this field."""

    comment: str = ""
    """A comment about this weights entry, for example how these weights were created."""

    @model_validator(mode="after")
    def _validate(self) -> Self:
        # a weights entry cannot have been converted from itself
        # (message fixed: "it's" -> "its")
        if self.type == self.parent:
            raise ValueError("Weights entry can't be its own parent.")

        return self

    @model_serializer(mode="wrap", when_used="unless-none")
    def _serialize(self, nxt: SerializerFunctionWrapHandler, info: SerializationInfo):
        # include the weights file when serializing for packaging
        return package_file_descr_serializer(self, nxt, info)
class KerasHdf5WeightsDescr(WeightsEntryDescrBase):
    # weights in the legacy Keras HDF5 (.h5) format
    type: ClassVar[WeightsFormat] = "keras_hdf5"
    weights_format_name: ClassVar[str] = "Keras HDF5"
    tensorflow_version: Version
    """TensorFlow version used to create these weights."""
class KerasV3WeightsDescr(WeightsEntryDescrBase):
    # weights in the Keras v3 (.keras) format; requires keras >= 3
    type: ClassVar[WeightsFormat] = "keras_v3"
    weights_format_name: ClassVar[str] = "Keras v3"
    keras_version: Annotated[Version, Ge(Version(3))]
    """Keras version used to create these weights."""
    backend: Tuple[Literal["tensorflow", "jax", "torch"], Version]
    """Keras backend used to create these weights."""
    source: Annotated[
        FileSource,
        AfterValidator(wo_special_file_name),
        WithSuffix(".keras", case_sensitive=True),
    ]
    """Source of the .keras weights file."""
# an external ".data" file holding ONNX weights,
# see `OnnxWeightsDescr.external_data`
FileDescr_external_data = Annotated[
    FileDescr_,
    WithSuffix(".data", case_sensitive=True),
    Field(examples=[dict(source="weights.onnx.data")]),
]
class OnnxWeightsDescr(WeightsEntryDescrBase):
    type: ClassVar[WeightsFormat] = "onnx"
    weights_format_name: ClassVar[str] = "ONNX"
    opset_version: Annotated[int, Ge(7)]
    """ONNX opset version"""

    external_data: Optional[FileDescr_external_data] = None
    """Source of the external ONNX data file holding the weights.
    (If present **source** holds the ONNX architecture without weights)."""

    @model_validator(mode="after")
    def _validate_external_data_unique_file_name(self) -> Self:
        # `source` and `external_data` are separate files;
        # equal file names would collide when both are stored side by side
        if self.external_data is not None and (
            extract_file_name(self.source)
            == extract_file_name(self.external_data.source)
        ):
            raise ValueError(
                f"ONNX `external_data` file name '{extract_file_name(self.external_data.source)}'"
                + " must be different from ONNX `source` file name."
            )

        return self
class PytorchStateDictWeightsDescr(WeightsEntryDescrBase):
    type: ClassVar[WeightsFormat] = "pytorch_state_dict"
    weights_format_name: ClassVar[str] = "Pytorch State Dict"
    # architecture providing the torch.nn.Module the state dict is loaded into
    architecture: Union[ArchitectureFromFileDescr, ArchitectureFromLibraryDescr]
    pytorch_version: Version
    """Version of the PyTorch library used.
    If `dependencies` is specified it has to include pytorch and any version pinning has to be compatible.
    """
    dependencies: Optional[FileDescr_dependencies] = None
    """Custom dependencies beyond pytorch described in a Conda environment file.
    Allows to specify custom dependencies, see conda docs:
    - [Exporting an environment file across platforms](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#exporting-an-environment-file-across-platforms)
    - [Creating an environment file manually](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#creating-an-environment-file-manually)

    The conda environment file should include pytorch and any version pinning has to be compatible with
    **pytorch_version**.
    """
class TensorflowJsWeightsDescr(WeightsEntryDescrBase):
    type: ClassVar[WeightsFormat] = "tensorflow_js"
    weights_format_name: ClassVar[str] = "Tensorflow.js"
    tensorflow_version: Version
    """Version of the TensorFlow library used."""

    # `source` re-declared to document the multi-file/zip expectation
    source: Annotated[FileSource, AfterValidator(wo_special_file_name)]
    """The multi-file weights.
    All required files/folders should be a zip archive."""
class TensorflowSavedModelBundleWeightsDescr(WeightsEntryDescrBase):
    type: ClassVar[WeightsFormat] = "tensorflow_saved_model_bundle"
    weights_format_name: ClassVar[str] = "Tensorflow Saved Model"
    tensorflow_version: Version
    """Version of the TensorFlow library used."""

    dependencies: Optional[FileDescr_dependencies] = None
    """Custom dependencies beyond tensorflow.
    Should include tensorflow and any version pinning has to be compatible with **tensorflow_version**."""

    # `source` re-declared to document the multi-file/zip expectation
    source: Annotated[FileSource, AfterValidator(wo_special_file_name)]
    """The multi-file weights.
    All required files/folders should be a zip archive."""
class TorchscriptWeightsDescr(WeightsEntryDescrBase):
    # weights as a traced/scripted TorchScript module
    type: ClassVar[WeightsFormat] = "torchscript"
    weights_format_name: ClassVar[str] = "TorchScript"
    pytorch_version: Version
    """Version of the PyTorch library used."""
# union of all concrete weights entry descriptions
SpecificWeightsDescr = Union[
    KerasHdf5WeightsDescr,
    KerasV3WeightsDescr,
    OnnxWeightsDescr,
    PytorchStateDictWeightsDescr,
    TensorflowJsWeightsDescr,
    TensorflowSavedModelBundleWeightsDescr,
    TorchscriptWeightsDescr,
]
class WeightsDescr(Node):
    """The model's weights, with one optional entry per weights format.

    At least one entry is required. Exactly one entry (the original training
    result) should have no `parent`; all other entries reference the format
    they were converted from via `parent`.
    """

    keras_hdf5: Optional[KerasHdf5WeightsDescr] = None
    keras_v3: Optional[KerasV3WeightsDescr] = None
    onnx: Optional[OnnxWeightsDescr] = None
    pytorch_state_dict: Optional[PytorchStateDictWeightsDescr] = None
    tensorflow_js: Optional[TensorflowJsWeightsDescr] = None
    tensorflow_saved_model_bundle: Optional[TensorflowSavedModelBundleWeightsDescr] = (
        None
    )
    torchscript: Optional[TorchscriptWeightsDescr] = None

    @model_validator(mode="after")
    def check_entries(self) -> Self:
        """Ensure at least one entry is given and `parent` references are consistent."""
        entries = {wtype for wtype, entry in self if entry is not None}

        if not entries:
            raise ValueError("Missing weights entry")

        entries_wo_parent = {
            wtype
            for wtype, entry in self
            if entry is not None and hasattr(entry, "parent") and entry.parent is None
        }
        if len(entries_wo_parent) != 1:
            # warn only (not an error) to keep older, non-conforming
            # descriptions loadable; message typo fixed ("orignal" -> "original")
            issue_warning(
                "Exactly one weights entry may not specify the `parent` field (got"
                + " {value}). That entry is considered the original set of model weights."
                + " Other weight formats are created through conversion of the original or"
                + " already converted weights. They have to reference the weights format"
                + " they were converted from as their `parent`.",
                value=len(entries_wo_parent),
                field="weights",
            )

        for wtype, entry in self:
            if entry is None:
                continue

            assert hasattr(entry, "type")
            assert hasattr(entry, "parent")
            assert wtype == entry.type
            if (
                entry.parent is not None and entry.parent not in entries
            ):  # self reference checked for `parent` field
                raise ValueError(
                    f"`weights.{wtype}.parent={entry.parent} not in specified weight"
                    + f" formats: {entries}"
                )

        return self

    def __getitem__(
        self,
        key: WeightsFormat,
    ):
        """Return the weights entry for `key`.

        Raises:
            KeyError: If `key` is not a known weights format or its entry is unset.
        """
        if key == "keras_hdf5":
            ret = self.keras_hdf5
        elif key == "keras_v3":
            ret = self.keras_v3
        elif key == "onnx":
            ret = self.onnx
        elif key == "pytorch_state_dict":
            ret = self.pytorch_state_dict
        elif key == "tensorflow_js":
            ret = self.tensorflow_js
        elif key == "tensorflow_saved_model_bundle":
            ret = self.tensorflow_saved_model_bundle
        elif key == "torchscript":
            ret = self.torchscript
        else:
            raise KeyError(key)

        if ret is None:
            raise KeyError(key)

        return ret

    # overloads tie each literal key to its matching entry type for type checkers
    @overload
    def __setitem__(
        self, key: Literal["keras_hdf5"], value: Optional[KerasHdf5WeightsDescr]
    ) -> None: ...
    @overload
    def __setitem__(
        self, key: Literal["keras_v3"], value: Optional[KerasV3WeightsDescr]
    ) -> None: ...
    @overload
    def __setitem__(
        self, key: Literal["onnx"], value: Optional[OnnxWeightsDescr]
    ) -> None: ...
    @overload
    def __setitem__(
        self,
        key: Literal["pytorch_state_dict"],
        value: Optional[PytorchStateDictWeightsDescr],
    ) -> None: ...
    @overload
    def __setitem__(
        self, key: Literal["tensorflow_js"], value: Optional[TensorflowJsWeightsDescr]
    ) -> None: ...
    @overload
    def __setitem__(
        self,
        key: Literal["tensorflow_saved_model_bundle"],
        value: Optional[TensorflowSavedModelBundleWeightsDescr],
    ) -> None: ...
    @overload
    def __setitem__(
        self, key: Literal["torchscript"], value: Optional[TorchscriptWeightsDescr]
    ) -> None: ...

    def __setitem__(
        self,
        key: WeightsFormat,
        value: Optional[SpecificWeightsDescr],
    ):
        """Set (or unset with `None`) the weights entry for `key`.

        Raises:
            KeyError: If `key` is not a known weights format.
            TypeError: If `value` does not match the entry type for `key`.
        """
        if key == "keras_hdf5":
            if value is not None and not isinstance(value, KerasHdf5WeightsDescr):
                raise TypeError(
                    f"Expected KerasHdf5WeightsDescr or None for key 'keras_hdf5', got {type(value)}"
                )
            self.keras_hdf5 = value
        elif key == "keras_v3":
            if value is not None and not isinstance(value, KerasV3WeightsDescr):
                raise TypeError(
                    f"Expected KerasV3WeightsDescr or None for key 'keras_v3', got {type(value)}"
                )
            self.keras_v3 = value
        elif key == "onnx":
            if value is not None and not isinstance(value, OnnxWeightsDescr):
                raise TypeError(
                    f"Expected OnnxWeightsDescr or None for key 'onnx', got {type(value)}"
                )
            self.onnx = value
        elif key == "pytorch_state_dict":
            if value is not None and not isinstance(
                value, PytorchStateDictWeightsDescr
            ):
                raise TypeError(
                    f"Expected PytorchStateDictWeightsDescr or None for key 'pytorch_state_dict', got {type(value)}"
                )
            self.pytorch_state_dict = value
        elif key == "tensorflow_js":
            if value is not None and not isinstance(value, TensorflowJsWeightsDescr):
                raise TypeError(
                    f"Expected TensorflowJsWeightsDescr or None for key 'tensorflow_js', got {type(value)}"
                )
            self.tensorflow_js = value
        elif key == "tensorflow_saved_model_bundle":
            if value is not None and not isinstance(
                value, TensorflowSavedModelBundleWeightsDescr
            ):
                raise TypeError(
                    f"Expected TensorflowSavedModelBundleWeightsDescr or None for key 'tensorflow_saved_model_bundle', got {type(value)}"
                )
            self.tensorflow_saved_model_bundle = value
        elif key == "torchscript":
            if value is not None and not isinstance(value, TorchscriptWeightsDescr):
                raise TypeError(
                    f"Expected TorchscriptWeightsDescr or None for key 'torchscript', got {type(value)}"
                )
            self.torchscript = value
        else:
            raise KeyError(key)

    @property
    def available_formats(self) -> Dict[WeightsFormat, SpecificWeightsDescr]:
        """All weights formats that have an entry set, mapped to their entries."""
        return {
            **({} if self.keras_hdf5 is None else {"keras_hdf5": self.keras_hdf5}),
            **({} if self.keras_v3 is None else {"keras_v3": self.keras_v3}),
            **({} if self.onnx is None else {"onnx": self.onnx}),
            **(
                {}
                if self.pytorch_state_dict is None
                else {"pytorch_state_dict": self.pytorch_state_dict}
            ),
            **(
                {}
                if self.tensorflow_js is None
                else {"tensorflow_js": self.tensorflow_js}
            ),
            **(
                {}
                if self.tensorflow_saved_model_bundle is None
                else {
                    "tensorflow_saved_model_bundle": self.tensorflow_saved_model_bundle
                }
            ),
            **({} if self.torchscript is None else {"torchscript": self.torchscript}),
        }

    @property
    def missing_formats(self) -> Set[WeightsFormat]:
        """All weights formats without an entry."""
        return {
            wf for wf in get_args(WeightsFormat) if wf not in self.available_formats
        }
class ModelId(ResourceId):
    # distinct id type for models (no additional behavior)
    pass
class LinkedModel(LinkedResourceBase):
    """Reference to a bioimage.io model."""

    id: ModelId
    """A valid model `id` from the bioimage.io collection."""
class _DataDepSize(NamedTuple):
    # lower/upper bounds of a data-dependent axis size;
    # max is None if unbounded
    min: StrictInt
    max: Optional[StrictInt]
class _AxisSizes(NamedTuple):
    """the lengths of all axes of model inputs and outputs"""

    inputs: Dict[Tuple[TensorId, AxisId], int]
    outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]]
class _TensorSizes(NamedTuple):
    """_AxisSizes as nested dicts (tensor id -> axis id -> size)"""

    inputs: Dict[TensorId, Dict[AxisId, int]]
    outputs: Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]
class ReproducibilityTolerance(Node, extra="allow"):
    """Describes what small numerical differences -- if any -- may be tolerated
    in the generated output when executing in different environments.

    A tensor element *output* is considered mismatched to the **test_tensor** if
    abs(*output* - **test_tensor**) > **absolute_tolerance** + **relative_tolerance** * abs(**test_tensor**).
    (Internally we call [numpy.testing.assert_allclose](https://numpy.org/doc/stable/reference/generated/numpy.testing.assert_allclose.html).)

    Motivation:
        For testing we can request the respective deep learning frameworks to be as
        reproducible as possible by setting seeds and choosing deterministic algorithms,
        but differences in operating systems, available hardware and installed drivers
        may still lead to numerical differences.
    """

    relative_tolerance: RelativeTolerance = 1e-3
    """Maximum relative tolerance of reproduced test tensor."""

    absolute_tolerance: AbsoluteTolerance = 1e-3
    """Maximum absolute tolerance of reproduced test tensor."""

    mismatched_elements_per_million: MismatchedElementsPerMillion = 100
    """Maximum number of mismatched elements/pixels per million to tolerate."""

    output_ids: Sequence[TensorId] = ()
    """Limits the output tensor IDs these reproducibility details apply to."""

    weights_formats: Sequence[WeightsFormat] = ()
    """Limits the weights formats these details apply to."""
class BiasRisksLimitations(Node, extra="allow"):
    """Known biases, risks, technical limitations, and recommendations for model use."""

    known_biases: str = dedent("""\
        In general bioimage models may suffer from biases caused by:

        - Imaging protocol dependencies
        - Use of a specific cell type
        - Species-specific training data limitations

        """)
    """Biases in training data or model behavior."""

    risks: str = dedent("""\
        Common risks in bioimage analysis include:

        - Erroneously assuming generalization to unseen experimental conditions
        - Trusting (overconfident) model outputs without validation
        - Misinterpretation of results

        """)
    """Potential risks in the context of bioimage analysis."""

    limitations: Optional[str] = None
    """Technical limitations and failure modes."""

    recommendations: str = "Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model."
    """Mitigation strategies regarding `known_biases`, `risks`, and `limitations`, as well as applicable best practices.

    Consider:
    - How to use a validation dataset?
    - How to manually validate?
    - Feasibility of domain adaptation for different experimental setups?

    """

    def format_md(self) -> str:
        """Render this section as Markdown (model-card style)."""
        # the "Limitations" heading is only emitted if limitations are given
        if self.limitations is None:
            limitations_header = ""
        else:
            limitations_header = "## Limitations\n\n"

        return f"""# Bias, Risks, and Limitations

{self.known_biases}

{self.risks}

{limitations_header}{self.limitations or ""}

## Recommendations

{self.recommendations}

"""
class TrainingDetails(Node, extra="allow"):
    # free-form (extra="allow") description of how the model was trained
    training_preprocessing: Optional[str] = None
    """Detailed image preprocessing steps during model training:

    Mention:
    - *Normalization methods*
    - *Augmentation strategies*
    - *Resizing/resampling procedures*
    - *Artifact handling*

    """

    training_epochs: Optional[float] = None
    """Number of training epochs."""

    training_batch_size: Optional[float] = None
    """Batch size used in training."""

    initial_learning_rate: Optional[float] = None
    """Initial learning rate used in training."""

    learning_rate_schedule: Optional[str] = None
    """Learning rate schedule used in training."""

    loss_function: Optional[str] = None
    """Loss function used in training, e.g. nn.MSELoss."""

    loss_function_kwargs: Dict[str, YamlValue] = Field(
        default_factory=cast(Callable[[], Dict[str, YamlValue]], dict)
    )
    """key word arguments for the `loss_function`"""

    optimizer: Optional[str] = None
    """optimizer, e.g. torch.optim.Adam"""

    optimizer_kwargs: Dict[str, YamlValue] = Field(
        default_factory=cast(Callable[[], Dict[str, YamlValue]], dict)
    )
    """key word arguments for the `optimizer`"""

    regularization: Optional[str] = None
    """Regularization techniques used during training, e.g. drop-out or weight decay."""

    training_duration: Optional[float] = None
    """Total training duration in hours."""
class Evaluation(Node, extra="allow"):
    # one evaluation of a model on one dataset
    model_id: Optional[ModelId] = None
    """Model being evaluated."""

    dataset_id: DatasetId
    """Dataset used for evaluation."""

    dataset_source: HttpUrl
    """Source of the dataset."""

    dataset_role: Literal["train", "validation", "test", "independent", "unknown"]
    """Role of the dataset used for evaluation.

    - `train`: dataset was (part of) the training data
    - `validation`: dataset was (part of) the validation data used during training, e.g. used for model selection or hyperparameter tuning
    - `test`: dataset was (part of) the designated test data; not used during training or validation, but acquired from the same source/distribution as training data
    - `independent`: dataset is entirely independent test data; not used during training or validation, and acquired from a different source/distribution than training data
    - `unknown`: role of the dataset is unknown; choose this if you are not certain if (a subset) of the data was seen by the model during training.
    """

    sample_count: int
    """Number of evaluated samples."""

    evaluation_factors: List[Annotated[str, MaxLen(16)]]
    """(Abbreviations of) each evaluation factor.

    Evaluation factors are criteria along which model performance is evaluated, e.g. different image conditions
    like 'low SNR', 'high cell density', or different biological conditions like 'cell type A', 'cell type B'.
    An 'overall' factor may be included to summarize performance across all conditions.
    """

    evaluation_factors_long: List[str]
    """Descriptions (long form) of each evaluation factor."""

    metrics: List[Annotated[str, MaxLen(16)]]
    """(Abbreviations of) metrics used for evaluation."""

    metrics_long: List[str]
    """Description of each metric used."""

    @model_validator(mode="after")
    def _validate_list_lengths(self) -> Self:
        # abbreviations and long descriptions must pair up 1:1 and the
        # `results` matrix must be metrics (rows) x evaluation_factors (cols);
        # note: `results` is declared below, but mode="after" validators run
        # once all fields are set
        if len(self.evaluation_factors) != len(self.evaluation_factors_long):
            raise ValueError(
                "`evaluation_factors` and `evaluation_factors_long` must have the same length"
            )

        if len(self.metrics) != len(self.metrics_long):
            raise ValueError("`metrics` and `metrics_long` must have the same length")

        if len(self.results) != len(self.metrics):
            raise ValueError("`results` must have the same number of rows as `metrics`")

        for row in self.results:
            if len(row) != len(self.evaluation_factors):
                raise ValueError(
                    "`results` must have the same number of columns (in every row) as `evaluation_factors`"
                )

        return self

    results: List[List[Union[str, float, int]]]
    """Results for each metric (rows; outer list) and each evaluation factor (columns; inner list)."""

    results_summary: Optional[str] = None
    """Interpretation of results for general audience.

    Consider:
    - Overall model performance
    - Comparison to existing methods
    - Limitations and areas for improvement

"""

    def format_md(self):
        """Render this evaluation as a Markdown model-card section."""
        # results as a Markdown table with metrics as rows, factors as columns
        results_header = ["Metric"] + self.evaluation_factors
        results_table_cells = [results_header, ["---"] * len(results_header)] + [
            [metric] + [str(r) for r in row]
            for metric, row in zip(self.metrics, self.results)
        ]

        results_table = "".join(
            "| " + " | ".join(row) + " |\n" for row in results_table_cells
        )
        factors = "".join(
            f"\n - {ef}: {efl}"
            for ef, efl in zip(self.evaluation_factors, self.evaluation_factors_long)
        )
        metrics = "".join(
            f"\n - {em}: {eml}" for em, eml in zip(self.metrics, self.metrics_long)
        )

        return f"""## Testing Data, Factors & Metrics

Evaluation of {self.model_id or "this"} model on the {self.dataset_id} dataset (dataset role: {self.dataset_role}).

### Testing Data

- **Source:** [{self.dataset_id}]({self.dataset_source})
- **Size:** {self.sample_count} evaluated samples

### Factors
{factors}

### Metrics
{metrics}

## Results

### Quantitative Results

{results_table}

### Summary

{self.results_summary or "missing"}

"""
class EnvironmentalImpact(Node, extra="allow"):
    """Environmental considerations for model training and deployment.

    Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
    """

    hardware_type: Optional[str] = None
    """GPU/CPU specifications"""

    hours_used: Optional[float] = None
    """Total compute hours"""

    cloud_provider: Optional[str] = None
    """If applicable"""

    compute_region: Optional[str] = None
    """Geographic location"""

    co2_emitted: Optional[float] = None
    """kg CO2 equivalent

    Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
    """

    def format_md(self):
        """Filled Markdown template section following [Hugging Face Model Card Template](https://huggingface.co/docs/hub/en/model-card-annotated)."""
        # skip the whole section if nothing deviates from the defaults
        if self == self.__class__():
            return ""

        ret = "# Environmental Impact\n\n"
        if self.hardware_type is not None:
            ret += f"- **Hardware Type:** {self.hardware_type}\n"
        if self.hours_used is not None:
            ret += f"- **Hours used:** {self.hours_used}\n"
        if self.cloud_provider is not None:
            ret += f"- **Cloud Provider:** {self.cloud_provider}\n"
        if self.compute_region is not None:
            ret += f"- **Compute Region:** {self.compute_region}\n"
        if self.co2_emitted is not None:
            ret += f"- **Carbon Emitted:** {self.co2_emitted} kg CO2e\n"

        return ret + "\n"
class BioimageioConfig(Node, extra="allow"):
    # bioimage.io-specific metadata nested under `config.bioimageio` in a model RDF.
    # NOTE: no class docstring on purpose — pydantic would surface it as the
    # schema description; field documentation lives in the attribute docstrings.
    reproducibility_tolerance: Sequence[ReproducibilityTolerance] = ()
    """Tolerances to allow when reproducing the model's test outputs
    from the model's test inputs.
    Only the first entry matching tensor id and weights format is considered.
    """

    funded_by: Optional[str] = None
    """Funding agency, grant number if applicable"""

    architecture_type: Optional[Annotated[str, MaxLen(32)]] = (
        None  # TODO: add to differentiated tags
    )
    """Model architecture type, e.g., 3D U-Net, ResNet, transformer"""

    architecture_description: Optional[str] = None
    """Text description of model architecture."""

    modality: Optional[str] = None  # TODO: add to differentiated tags
    """Input modality, e.g., fluorescence microscopy, electron microscopy"""

    target_structure: List[str] = Field(  # TODO: add to differentiated tags
        default_factory=cast(Callable[[], List[str]], list)
    )
    """Biological structure(s) the model is designed to analyze, e.g., nuclei, mitochondria, cells"""

    task: Optional[str] = None  # TODO: add to differentiated tags
    """Bioimage-specific task type, e.g., segmentation, classification, detection, denoising"""

    new_version: Optional[ModelId] = None
    """A new version of this model exists with a different model id."""

    out_of_scope_use: Optional[str] = None
    """Describe how the model may be misused in bioimage analysis contexts and what users should **not** do with the model."""

    bias_risks_limitations: BiasRisksLimitations = Field(
        default_factory=BiasRisksLimitations.model_construct
    )
    """Description of known bias, risks, and technical limitations for in-scope model use."""

    model_parameter_count: Optional[int] = None
    """Total number of model parameters."""

    training: TrainingDetails = Field(default_factory=TrainingDetails.model_construct)
    """Details on how the model was trained."""

    inference_time: Optional[str] = None
    """Average inference time per image/tile. Specify hardware and image size. Multiple examples can be given."""

    memory_requirements_inference: Optional[str] = None
    """GPU memory needed for inference. Multiple examples with different image size can be given."""

    memory_requirements_training: Optional[str] = None
    """GPU memory needed for training. Multiple examples with different image/batch sizes can be given."""

    evaluations: List[Evaluation] = Field(
        default_factory=cast(Callable[[], List[Evaluation]], list)
    )
    """Quantitative model evaluations.

    Note:
        At the moment we recommend to include only a single test dataset
        (with evaluation factors that may mark subsets of the dataset)
        to avoid confusion and make the presentation of results cleaner.
    """

    environmental_impact: EnvironmentalImpact = Field(
        default_factory=EnvironmentalImpact.model_construct
    )
    """Environmental considerations for model training and deployment"""
class Config(Node, extra="allow"):
    # Top-level `config` section of a model RDF; `extra="allow"` keeps
    # arbitrary consumer-specific sub-sections intact.
    bioimageio: BioimageioConfig = Field(
        default_factory=BioimageioConfig.model_construct
    )
    # NOTE(review): free-form, unvalidated YAML value — presumably StarDist
    # tool-specific configuration; confirm semantics with its consumers.
    stardist: YamlValue = None
class ModelDescr(GenericModelDescrBase):
    """Specification of the fields used in a bioimage.io-compliant RDF to describe AI models with pretrained weights.
    These fields are typically stored in a YAML file which we call a model resource description file (model RDF).
    """

    implemented_format_version: ClassVar[Literal["0.5.9"]] = "0.5.9"
    if TYPE_CHECKING:
        format_version: Literal["0.5.9"] = "0.5.9"
    else:
        format_version: Literal["0.5.9"]
        """Version of the bioimage.io model description specification used.
        When creating a new model always use the latest micro/patch version described here.
        The `format_version` is important for any consumer software to understand how to parse the fields.
        """

    implemented_type: ClassVar[Literal["model"]] = "model"
    if TYPE_CHECKING:
        type: Literal["model"] = "model"
    else:
        type: Literal["model"]
        """Specialized resource type 'model'"""

    id: Optional[ModelId] = None
    """bioimage.io-wide unique resource identifier
    assigned by bioimage.io; version **un**specific."""

    authors: FAIR[List[Author]] = Field(
        default_factory=cast(Callable[[], List[Author]], list)
    )
    """The authors are the creators of the model RDF and the primary points of contact."""

    documentation: FAIR[Optional[FileSource_documentation]] = None
    """URL or relative path to a markdown file with additional documentation.
    The recommended documentation file name is `README.md`. An `.md` suffix is mandatory.
    The documentation should include a '#[#] Validation' (sub)section
    with details on how to quantitatively validate the model on unseen data."""

    @field_validator("documentation", mode="after")
    @classmethod
    def _validate_documentation(
        cls, value: Optional[FileSource_documentation]
    ) -> Optional[FileSource_documentation]:
        """Warn if the referenced documentation lacks a '# Validation' (sub)section.

        Only runs when the validation context permits I/O; never rejects the value.
        """
        if not get_validation_context().perform_io_checks or value is None:
            return value

        doc_reader = get_reader(value)
        doc_content = doc_reader.read().decode(encoding="utf-8")
        # matches any markdown heading containing 'validation'/'Validation'
        if not re.search("#.*[vV]alidation", doc_content):
            issue_warning(
                "No '# Validation' (sub)section found in {value}.",
                value=value,
                field="documentation",
            )

        return value

    inputs: NotEmpty[Sequence[InputTensorDescr]]
    """Describes the input tensors expected by this model."""

    @field_validator("inputs", mode="after")
    @classmethod
    def _validate_input_axes(
        cls, inputs: Sequence[InputTensorDescr]
    ) -> Sequence[InputTensorDescr]:
        """Validate every input axis, resolving `SizeReference`s between inputs."""
        input_size_refs = cls._get_axes_with_independent_size(inputs)

        for i, ipt in enumerate(inputs):
            valid_independent_refs: Dict[
                Tuple[TensorId, AxisId],
                Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
            ] = {
                **{
                    (ipt.id, a.id): (ipt, a, a.size)
                    for a in ipt.axes
                    if not isinstance(a, BatchAxis)
                    and isinstance(a.size, (int, ParameterizedSize))
                },
                **input_size_refs,
            }
            for a, ax in enumerate(ipt.axes):
                cls._validate_axis(
                    "inputs",
                    i=i,
                    tensor_id=ipt.id,
                    a=a,
                    axis=ax,
                    valid_independent_refs=valid_independent_refs,
                )
        return inputs

    @staticmethod
    def _validate_axis(
        field_name: str,
        i: int,
        tensor_id: TensorId,
        a: int,
        axis: AnyAxis,
        valid_independent_refs: Dict[
            Tuple[TensorId, AxisId],
            Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
        ],
    ):
        """Validate a single axis description.

        Axes with an independent size (fixed int, parameterized, or data
        dependent) and batch axes need no checks; for `SizeReference` sizes the
        referenced axis must exist, be compatible (type/unit), and not be the
        axis itself or a batch axis. May mutate `axis.channel_names` to expand
        a '{i}' template into concrete channel names.
        """
        if isinstance(axis, BatchAxis) or isinstance(
            axis.size, (int, ParameterizedSize, DataDependentSize)
        ):
            return
        elif not isinstance(axis.size, SizeReference):
            assert_never(axis.size)

        # validate axis.size SizeReference
        ref = (axis.size.tensor_id, axis.size.axis_id)
        if ref not in valid_independent_refs:
            raise ValueError(
                "Invalid tensor axis reference at"
                + f" {field_name}[{i}].axes[{a}].size: {axis.size}."
            )
        if ref == (tensor_id, axis.id):
            raise ValueError(
                "Self-referencing not allowed for"
                + f" {field_name}[{i}].axes[{a}].size: {axis.size}"
            )
        if axis.type == "channel":
            if valid_independent_refs[ref][1].type != "channel":
                raise ValueError(
                    "A channel axis' size may only reference another fixed size"
                    + " channel axis."
                )
            if isinstance(axis.channel_names, str) and "{i}" in axis.channel_names:
                ref_size = valid_independent_refs[ref][2]
                assert isinstance(ref_size, int), (
                    "channel axis ref (another channel axis) has to specify fixed"
                    + " size"
                )
                # expand the '{i}' template into one name per channel (1-based)
                generated_channel_names = [
                    Identifier(axis.channel_names.format(i=i))
                    for i in range(1, ref_size + 1)
                ]
                axis.channel_names = generated_channel_names

        if (ax_unit := getattr(axis, "unit", None)) != (
            ref_unit := getattr(valid_independent_refs[ref][1], "unit", None)
        ):
            raise ValueError(
                "The units of an axis and its reference axis need to match, but"
                + f" '{ax_unit}' != '{ref_unit}'."
            )
        ref_axis = valid_independent_refs[ref][1]
        if isinstance(ref_axis, BatchAxis):
            raise ValueError(
                f"Invalid reference axis '{ref_axis.id}' for {tensor_id}.{axis.id}"
                + " (a batch axis is not allowed as reference)."
            )

        if isinstance(axis, WithHalo):
            # n=0 yields the minimal size of a parameterized reference axis
            min_size = axis.size.get_size(axis, ref_axis, n=0)
            if (min_size - 2 * axis.halo) < 1:
                raise ValueError(
                    f"axis {axis.id} with minimum size {min_size} is too small for halo"
                    + f" {axis.halo}."
                )

            # the halo scaled into the reference axis' coordinates must be integral
            ref_halo = axis.halo * axis.scale / ref_axis.scale
            if ref_halo != int(ref_halo):
                raise ValueError(
                    f"Inferred halo for {'.'.join(ref)} is not an integer ({ref_halo} ="
                    + f" {tensor_id}.{axis.id}.halo {axis.halo}"
                    + f" * {tensor_id}.{axis.id}.scale {axis.scale}"
                    + f" / {'.'.join(ref)}.scale {ref_axis.scale})."
                )

    def validate_input_tensors(
        self,
        sources: Union[
            Sequence[NDArray[Any]], Mapping[TensorId, Optional[NDArray[Any]]]
        ],
    ) -> Mapping[TensorId, Optional[NDArray[Any]]]:
        """Check if the given input tensors match the model's input tensor descriptions.
        This includes checks of tensor shapes and dtypes, but not of the actual values.
        """
        if not isinstance(sources, collections.abc.Mapping):
            # positional arrays are matched to inputs by declaration order
            sources = {descr.id: tensor for descr, tensor in zip(self.inputs, sources)}

        tensors = {descr.id: (descr, sources.get(descr.id)) for descr in self.inputs}
        validate_tensors(tensors)

        return sources

    @model_validator(mode="after")
    def _validate_test_tensors(self) -> Self:
        """Load and validate test tensors; sanity-check reproducibility tolerances.

        Skipped entirely when the validation context disables I/O checks.
        """
        if not get_validation_context().perform_io_checks:
            return self

        test_inputs = {
            descr.id: (
                descr,
                None if descr.test_tensor is None else load_array(descr.test_tensor),
            )
            for descr in self.inputs
        }
        test_outputs = {
            descr.id: (
                descr,
                None if descr.test_tensor is None else load_array(descr.test_tensor),
            )
            for descr in self.outputs
        }

        validate_tensors({**test_inputs, **test_outputs}, tensor_origin="test_tensor")

        for rep_tol in self.config.bioimageio.reproducibility_tolerance:
            if not rep_tol.absolute_tolerance:
                continue

            if rep_tol.output_ids:
                out_arrays = {
                    k: v[1] for k, v in test_outputs.items() if k in rep_tol.output_ids
                }
            else:
                out_arrays = {k: v[1] for k, v in test_outputs.items()}

            for out_id, array in out_arrays.items():
                if array is None:
                    continue

                # reject absolute tolerances above 1% of the test tensor's max value
                if rep_tol.absolute_tolerance > (max_test_value := array.max()) * 0.01:
                    raise ValueError(
                        "config.bioimageio.reproducibility_tolerance.absolute_tolerance="
                        + f"{rep_tol.absolute_tolerance} > 0.01*{max_test_value}"
                        + f" (1% of the maximum value of the test tensor '{out_id}')"
                    )

        return self

    @model_validator(mode="after")
    def _validate_tensor_references_in_proc_kwargs(self, info: ValidationInfo) -> Self:
        """Ensure every `reference_tensor` in pre-/postprocessing kwargs names an input."""
        ipt_refs = {t.id for t in self.inputs}
        missing_refs = [
            k["reference_tensor"]
            for k in [p.kwargs for ipt in self.inputs for p in ipt.preprocessing]
            + [p.kwargs for out in self.outputs for p in out.postprocessing]
            if "reference_tensor" in k
            and k["reference_tensor"] is not None
            and k["reference_tensor"] not in ipt_refs
        ]

        if missing_refs:
            raise ValueError(
                f"`reference_tensor`s {missing_refs} not found. Valid input tensor"
                + f" references are: {ipt_refs}."
            )

        return self

    name: Annotated[
        str,
        RestrictCharacters(string.ascii_letters + string.digits + "_+- ()"),
        MinLen(5),
        MaxLen(128),
        warn(MaxLen(64), "Name longer than 64 characters.", INFO),
    ]
    """A human-readable name of this model.
    It should be no longer than 64 characters
    and may only contain letter, number, underscore, minus, parentheses and spaces.
    We recommend to chose a name that refers to the model's task and image modality.
    """

    outputs: NotEmpty[Sequence[OutputTensorDescr]]
    """Describes the output tensors."""

    @field_validator("outputs", mode="after")
    @classmethod
    def _validate_tensor_ids(
        cls, outputs: Sequence[OutputTensorDescr], info: ValidationInfo
    ) -> Sequence[OutputTensorDescr]:
        """Reject duplicate tensor ids across inputs and outputs combined."""
        tensor_ids = [
            t.id for t in info.data.get("inputs", []) + info.data.get("outputs", [])
        ]
        duplicate_tensor_ids: List[str] = []
        seen: Set[str] = set()
        for t in tensor_ids:
            if t in seen:
                duplicate_tensor_ids.append(t)

            seen.add(t)

        if duplicate_tensor_ids:
            raise ValueError(f"Duplicate tensor ids: {duplicate_tensor_ids}")

        return outputs

    @staticmethod
    def _get_axes_with_parameterized_size(
        io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]],
    ):
        """Map 'tensor_id.axis_id' to (tensor, axis, size) for parameterized axes."""
        return {
            f"{t.id}.{a.id}": (t, a, a.size)
            for t in io
            for a in t.axes
            if not isinstance(a, BatchAxis) and isinstance(a.size, ParameterizedSize)
        }

    @staticmethod
    def _get_axes_with_independent_size(
        io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]],
    ):
        """Map (tensor_id, axis_id) to (tensor, axis, size) for axes whose size
        does not depend on another axis (fixed int or parameterized)."""
        return {
            (t.id, a.id): (t, a, a.size)
            for t in io
            for a in t.axes
            if not isinstance(a, BatchAxis)
            and isinstance(a.size, (int, ParameterizedSize))
        }

    @field_validator("outputs", mode="after")
    @classmethod
    def _validate_output_axes(
        cls, outputs: List[OutputTensorDescr], info: ValidationInfo
    ) -> List[OutputTensorDescr]:
        """Validate output axes; their `SizeReference`s may point at inputs or outputs."""
        input_size_refs = cls._get_axes_with_independent_size(
            info.data.get("inputs", [])
        )
        output_size_refs = cls._get_axes_with_independent_size(outputs)

        for i, out in enumerate(outputs):
            valid_independent_refs: Dict[
                Tuple[TensorId, AxisId],
                Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
            ] = {
                **{
                    (out.id, a.id): (out, a, a.size)
                    for a in out.axes
                    if not isinstance(a, BatchAxis)
                    and isinstance(a.size, (int, ParameterizedSize))
                },
                **input_size_refs,
                **output_size_refs,
            }
            for a, ax in enumerate(out.axes):
                cls._validate_axis(
                    "outputs",
                    i,
                    out.id,
                    a,
                    ax,
                    valid_independent_refs=valid_independent_refs,
                )

        return outputs

    packaged_by: List[Author] = Field(
        default_factory=cast(Callable[[], List[Author]], list)
    )
    """The persons that have packaged and uploaded this model.
    Only required if those persons differ from the `authors`."""

    parent: Optional[LinkedModel] = None
    """The model from which this model is derived, e.g. by fine-tuning the weights."""

    @model_validator(mode="after")
    def _validate_parent_is_not_self(self) -> Self:
        if self.parent is not None and self.parent.id == self.id:
            raise ValueError("A model description may not reference itself as parent.")

        return self

    run_mode: Annotated[
        Optional[RunMode],
        warn(None, "Run mode '{value}' has limited support across consumer softwares."),
    ] = None
    """Custom run mode for this model: for more complex prediction procedures like test time
    data augmentation that currently cannot be expressed in the specification.
    No standard run modes are defined yet."""

    timestamp: Datetime = Field(default_factory=Datetime.now)
    """Timestamp in [ISO 8601](#https://en.wikipedia.org/wiki/ISO_8601) format
    with a few restrictions listed [here](https://docs.python.org/3/library/datetime.html#datetime.datetime.fromisoformat).
    (In Python a datetime object is valid, too)."""

    training_data: Annotated[
        Union[None, LinkedDataset, DatasetDescr, DatasetDescr02],
        Field(union_mode="left_to_right"),
    ] = None
    """The dataset used to train this model"""

    weights: Annotated[WeightsDescr, WrapSerializer(package_weights)]
    """The weights for this model.
    Weights can be given for different formats, but should otherwise be equivalent.
    The available weight formats determine which consumers can use this model."""

    config: Config = Field(default_factory=Config.model_construct)

    @model_validator(mode="after")
    def _add_default_cover(self) -> Self:
        """Generate cover images from test tensors when none were provided.

        Best effort: failures only warn (field='covers'), they never invalidate
        the description.
        """
        if not get_validation_context().perform_io_checks or self.covers:
            return self

        try:
            generated_covers = generate_covers(
                [
                    (t, load_array(t.test_tensor))
                    for t in self.inputs
                    if t.test_tensor is not None
                ],
                [
                    (t, load_array(t.test_tensor))
                    for t in self.outputs
                    if t.test_tensor is not None
                ],
            )
        except Exception as e:
            issue_warning(
                "Failed to generate cover image(s): {e}",
                value=self.covers,
                msg_context=dict(e=e),
                field="covers",
            )
        else:
            self.covers.extend(generated_covers)

        return self

    def get_input_test_arrays(self) -> List[NDArray[Any]]:
        return self._get_test_arrays(self.inputs)

    def get_output_test_arrays(self) -> List[NDArray[Any]]:
        return self._get_test_arrays(self.outputs)

    @staticmethod
    def _get_test_arrays(
        io_descr: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]],
    ):
        """Load all test tensors of **io_descr**; raise if any is missing."""
        ts: List[FileDescr] = []
        for d in io_descr:
            if d.test_tensor is None:
                raise ValueError(
                    f"Failed to get test arrays: description of '{d.id}' is missing a `test_tensor`."
                )
            ts.append(d.test_tensor)

        data = [load_array(t) for t in ts]
        assert all(isinstance(d, np.ndarray) for d in data)
        return data

    @staticmethod
    def get_batch_size(tensor_sizes: Mapping[TensorId, Mapping[AxisId, int]]) -> int:
        """Return the common batch size across **tensor_sizes**.

        Batch sizes of 1 are treated as compatible with anything; two distinct
        batch sizes != 1 raise a ValueError.
        """
        batch_size = 1
        tensor_with_batchsize: Optional[TensorId] = None
        for tid in tensor_sizes:
            for aid, s in tensor_sizes[tid].items():
                if aid != BATCH_AXIS_ID or s == 1 or s == batch_size:
                    continue

                if batch_size != 1:
                    assert tensor_with_batchsize is not None
                    raise ValueError(
                        f"batch size mismatch for tensors '{tensor_with_batchsize}' ({batch_size}) and '{tid}' ({s})"
                    )

                batch_size = s
                tensor_with_batchsize = tid

        return batch_size

    def get_output_tensor_sizes(
        self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]
    ) -> Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]:
        """Returns the tensor output sizes for given **input_sizes**.
        Only if **input_sizes** has a valid input shape, the tensor output size is exact.
        Otherwise it might be larger than the actual (valid) output"""
        batch_size = self.get_batch_size(input_sizes)
        ns = self.get_ns(input_sizes)

        tensor_sizes = self.get_tensor_sizes(ns, batch_size=batch_size)
        return tensor_sizes.outputs

    def get_ns(self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]):
        """get parameter `n` for each parameterized axis
        such that the valid input size is >= the given input size"""
        ret: Dict[Tuple[TensorId, AxisId], ParameterizedSize_N] = {}
        axes = {t.id: {a.id: a for a in t.axes} for t in self.inputs}
        for tid in input_sizes:
            for aid, s in input_sizes[tid].items():
                size_descr = axes[tid][aid].size
                if isinstance(size_descr, ParameterizedSize):
                    ret[(tid, aid)] = size_descr.get_n(s)
                elif size_descr is None or isinstance(size_descr, (int, SizeReference)):
                    pass
                else:
                    assert_never(size_descr)

        return ret

    def get_tensor_sizes(
        self, ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], batch_size: int
    ) -> _TensorSizes:
        """Resolve axis sizes and regroup them per tensor id."""
        axis_sizes = self.get_axis_sizes(ns, batch_size=batch_size)
        return _TensorSizes(
            {
                t: {
                    aa: axis_sizes.inputs[(tt, aa)]
                    for tt, aa in axis_sizes.inputs
                    if tt == t
                }
                for t in {tt for tt, _ in axis_sizes.inputs}
            },
            {
                t: {
                    aa: axis_sizes.outputs[(tt, aa)]
                    for tt, aa in axis_sizes.outputs
                    if tt == t
                }
                for t in {tt for tt, _ in axis_sizes.outputs}
            },
        )

    def get_axis_sizes(
        self,
        ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N],
        batch_size: Optional[int] = None,
        *,
        max_input_shape: Optional[Mapping[Tuple[TensorId, AxisId], int]] = None,
    ) -> _AxisSizes:
        """Determine input and output block shape for scale factors **ns**
        of parameterized input sizes.

        Args:
            ns: Scale factor `n` for each axis (keyed by (tensor_id, axis_id))
                that is parameterized as `size = min + n * step`.
            batch_size: The desired size of the batch dimension.
                If given **batch_size** overwrites any batch size present in
                **max_input_shape**. Default 1.
            max_input_shape: Limits the derived block shapes.
                Each axis for which the input size, parameterized by `n`, is larger
                than **max_input_shape** is set to the minimal value `n_min` for which
                this is still true.
                Use this for small input samples or large values of **ns**.
                Or simply whenever you know the full input shape.

        Returns:
            Resolved axis sizes for model inputs and outputs.
        """
        max_input_shape = max_input_shape or {}
        if batch_size is None:
            # fall back to a batch size found in max_input_shape, else 1
            for (_t_id, a_id), s in max_input_shape.items():
                if a_id == BATCH_AXIS_ID:
                    batch_size = s
                    break
            else:
                batch_size = 1

        all_axes = {
            t.id: {a.id: a for a in t.axes} for t in chain(self.inputs, self.outputs)
        }

        inputs: Dict[Tuple[TensorId, AxisId], int] = {}
        outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]] = {}

        def get_axis_size(a: Union[InputAxis, OutputAxis]):
            # NOTE: closes over `t_descr` from the resolution loops below
            if isinstance(a, BatchAxis):
                if (t_descr.id, a.id) in ns:
                    logger.warning(
                        "Ignoring unexpected size increment factor (n) for batch axis"
                        + " of tensor '{}'.",
                        t_descr.id,
                    )
                return batch_size
            elif isinstance(a.size, int):
                if (t_descr.id, a.id) in ns:
                    logger.warning(
                        "Ignoring unexpected size increment factor (n) for fixed size"
                        + " axis '{}' of tensor '{}'.",
                        a.id,
                        t_descr.id,
                    )
                return a.size
            elif isinstance(a.size, ParameterizedSize):
                if (t_descr.id, a.id) not in ns:
                    raise ValueError(
                        "Size increment factor (n) missing for parametrized axis"
                        + f" '{a.id}' of tensor '{t_descr.id}'."
                    )
                n = ns[(t_descr.id, a.id)]
                s_max = max_input_shape.get((t_descr.id, a.id))
                if s_max is not None:
                    # cap n so the resulting size stays within max_input_shape
                    n = min(n, a.size.get_n(s_max))

                return a.size.get_size(n)

            elif isinstance(a.size, SizeReference):
                if (t_descr.id, a.id) in ns:
                    logger.warning(
                        "Ignoring unexpected size increment factor (n) for axis '{}'"
                        + " of tensor '{}' with size reference.",
                        a.id,
                        t_descr.id,
                    )
                assert not isinstance(a, BatchAxis)
                ref_axis = all_axes[a.size.tensor_id][a.size.axis_id]
                assert not isinstance(ref_axis, BatchAxis)
                ref_key = (a.size.tensor_id, a.size.axis_id)
                # referenced sizes must already have been resolved (see loop order)
                ref_size = inputs.get(ref_key, outputs.get(ref_key))
                assert ref_size is not None, ref_key
                assert not isinstance(ref_size, _DataDepSize), ref_key
                return a.size.get_size(
                    axis=a,
                    ref_axis=ref_axis,
                    ref_size=ref_size,
                )
            elif isinstance(a.size, DataDependentSize):
                if (t_descr.id, a.id) in ns:
                    logger.warning(
                        "Ignoring unexpected increment factor (n) for data dependent"
                        + " size axis '{}' of tensor '{}'.",
                        a.id,
                        t_descr.id,
                    )
                return _DataDepSize(a.size.min, a.size.max)
            else:
                assert_never(a.size)

        # first resolve all but the `SizeReference` input sizes
        for t_descr in self.inputs:
            for a in t_descr.axes:
                if not isinstance(a.size, SizeReference):
                    s = get_axis_size(a)
                    assert not isinstance(s, _DataDepSize)
                    inputs[t_descr.id, a.id] = s

        # resolve all other input axis sizes
        for t_descr in self.inputs:
            for a in t_descr.axes:
                if isinstance(a.size, SizeReference):
                    s = get_axis_size(a)
                    assert not isinstance(s, _DataDepSize)
                    inputs[t_descr.id, a.id] = s

        # resolve all output axis sizes
        for t_descr in self.outputs:
            for a in t_descr.axes:
                assert not isinstance(a.size, ParameterizedSize)
                s = get_axis_size(a)
                outputs[t_descr.id, a.id] = s

        return _AxisSizes(inputs=inputs, outputs=outputs)

    @model_validator(mode="before")
    @classmethod
    def _convert(cls, data: Dict[str, Any]) -> Dict[str, Any]:
        # in-place upgrade of older format versions before field validation
        cls.convert_from_old_format_wo_validation(data)
        return data

    @classmethod
    def convert_from_old_format_wo_validation(cls, data: Dict[str, Any]) -> None:
        """Convert metadata following an older format version to this classes' format
        without validating the result.
        """
        if (
            data.get("type") == "model"
            and isinstance(fv := data.get("format_version"), str)
            and fv.count(".") == 2
        ):
            fv_parts = fv.split(".")
            if any(not p.isdigit() for p in fv_parts):
                return

            fv_tuple = tuple(map(int, fv_parts))

            assert cls.implemented_format_version_tuple[0:2] == (0, 5)
            if fv_tuple[:2] in ((0, 3), (0, 4)):
                m04 = _ModelDescr_v0_4.load(data)
                if isinstance(m04, InvalidDescr):
                    # even an invalid 0.4 description may be convertible; fall
                    # back to unconverted data if conversion fails
                    try:
                        updated = _model_conv.convert_as_dict(
                            m04  # pyright: ignore[reportArgumentType]
                        )
                    except Exception as e:
                        logger.error(
                            "Failed to convert from invalid model 0.4 description."
                            + f"\nerror: {e}"
                            + "\nProceeding with model 0.5 validation without conversion."
                        )
                        updated = None
                else:
                    updated = _model_conv.convert_as_dict(m04)

                if updated is not None:
                    data.clear()
                    data.update(updated)

            elif fv_tuple[:2] == (0, 5):
                # bump patch version
                data["format_version"] = cls.implemented_format_version
class _ModelConv(Converter[_ModelDescr_v0_4, ModelDescr]):
    # Converts a model description 0.4 into the 0.5 format.
    def _convert(
        self, src: _ModelDescr_v0_4, tgt: "type[ModelDescr] | type[dict[str, Any]]"
    ) -> "ModelDescr | dict[str, Any]":
        # replace characters not allowed by ModelDescr.name's RestrictCharacters
        # constraint with spaces
        name = "".join(
            c if c in string.ascii_letters + string.digits + "_+- ()" else " "
            for c in src.name
        )

        def conv_authors(auths: Optional[Sequence[_Author_v0_4]]):
            # at type-check time use the typed converter; at runtime produce
            # plain dicts so `tgt` may be `dict`
            conv = (
                _author_conv.convert if TYPE_CHECKING else _author_conv.convert_as_dict
            )
            return None if auths is None else [conv(a) for a in auths]

        if TYPE_CHECKING:
            arch_file_conv = _arch_file_conv.convert
            arch_lib_conv = _arch_lib_conv.convert
        else:
            arch_file_conv = _arch_file_conv.convert_as_dict
            arch_lib_conv = _arch_lib_conv.convert_as_dict

        # per-tensor axis sizes (keyed by the 0.4 tensor names) handed to the
        # tensor converters to resolve size references
        input_size_refs = {
            ipt.name: {
                a: s
                for a, s in zip(
                    ipt.axes,
                    (
                        ipt.shape.min
                        if isinstance(ipt.shape, _ParameterizedInputShape_v0_4)
                        else ipt.shape
                    ),
                )
            }
            for ipt in src.inputs
            if ipt.shape
        }
        output_size_refs = {
            **{
                out.name: {a: s for a, s in zip(out.axes, out.shape)}
                for out in src.outputs
                if not isinstance(out.shape, _ImplicitOutputShape_v0_4)
            },
            **input_size_refs,
        }

        return tgt(
            attachments=(
                []
                if src.attachments is None
                else [FileDescr(source=f) for f in src.attachments.files]
            ),
            authors=[_author_conv.convert_as_dict(a) for a in src.authors],  # pyright: ignore[reportArgumentType]
            cite=[{"text": c.text, "doi": c.doi, "url": c.url} for c in src.cite],  # pyright: ignore[reportArgumentType]
            config=src.config,  # pyright: ignore[reportArgumentType]
            covers=src.covers,
            description=src.description,
            documentation=src.documentation,
            format_version="0.5.9",
            git_repo=src.git_repo,  # pyright: ignore[reportArgumentType]
            icon=src.icon,
            id=None if src.id is None else ModelId(src.id),
            id_emoji=src.id_emoji,
            license=src.license,  # type: ignore
            links=src.links,
            maintainers=[_maintainer_conv.convert_as_dict(m) for m in src.maintainers],  # pyright: ignore[reportArgumentType]
            name=name,
            tags=src.tags,
            type=src.type,
            uploader=src.uploader,
            version=src.version,
            inputs=[  # pyright: ignore[reportArgumentType]
                _input_tensor_conv.convert_as_dict(ipt, tt, st, input_size_refs)
                for ipt, tt, st in zip(
                    src.inputs,
                    src.test_inputs,
                    src.sample_inputs or [None] * len(src.test_inputs),
                )
            ],
            outputs=[  # pyright: ignore[reportArgumentType]
                _output_tensor_conv.convert_as_dict(out, tt, st, output_size_refs)
                for out, tt, st in zip(
                    src.outputs,
                    src.test_outputs,
                    src.sample_outputs or [None] * len(src.test_outputs),
                )
            ],
            # 0.4 linked resources carried a separate version_number; fold it
            # into the 0.5 id as '<id>/<version_number>'
            parent=(
                None
                if src.parent is None
                else LinkedModel(
                    id=ModelId(
                        str(src.parent.id)
                        + (
                            ""
                            if src.parent.version_number is None
                            else f"/{src.parent.version_number}"
                        )
                    )
                )
            ),
            training_data=(
                None
                if src.training_data is None
                else (
                    LinkedDataset(
                        id=DatasetId(
                            str(src.training_data.id)
                            + (
                                ""
                                if src.training_data.version_number is None
                                else f"/{src.training_data.version_number}"
                            )
                        )
                    )
                    if isinstance(src.training_data, LinkedDataset02)
                    else src.training_data
                )
            ),
            packaged_by=[_author_conv.convert_as_dict(a) for a in src.packaged_by],  # pyright: ignore[reportArgumentType]
            run_mode=src.run_mode,
            timestamp=src.timestamp,
            weights=(WeightsDescr if TYPE_CHECKING else dict)(
                keras_hdf5=(w := src.weights.keras_hdf5)
                and (KerasHdf5WeightsDescr if TYPE_CHECKING else dict)(
                    authors=conv_authors(w.authors),
                    source=w.source,
                    tensorflow_version=w.tensorflow_version or Version("1.15"),
                    parent=w.parent,
                ),
                onnx=(w := src.weights.onnx)
                and (OnnxWeightsDescr if TYPE_CHECKING else dict)(
                    source=w.source,
                    authors=conv_authors(w.authors),
                    parent=w.parent,
                    opset_version=w.opset_version or 15,
                ),
                pytorch_state_dict=(w := src.weights.pytorch_state_dict)
                and (PytorchStateDictWeightsDescr if TYPE_CHECKING else dict)(
                    source=w.source,
                    authors=conv_authors(w.authors),
                    parent=w.parent,
                    architecture=(
                        arch_file_conv(
                            w.architecture,
                            w.architecture_sha256,
                            w.kwargs,
                        )
                        if isinstance(w.architecture, _CallableFromFile_v0_4)
                        else arch_lib_conv(w.architecture, w.kwargs)
                    ),
                    pytorch_version=w.pytorch_version or Version("1.10"),
                    dependencies=(
                        None
                        if w.dependencies is None
                        else (FileDescr if TYPE_CHECKING else dict)(
                            # strip the 0.4 'conda:' prefix from dependency sources
                            source=cast(
                                FileSource,
                                str(deps := w.dependencies)[
                                    (
                                        len("conda:")
                                        if str(deps).startswith("conda:")
                                        else 0
                                    ) :
                                ],
                            )
                        )
                    ),
                ),
                tensorflow_js=(w := src.weights.tensorflow_js)
                and (TensorflowJsWeightsDescr if TYPE_CHECKING else dict)(
                    source=w.source,
                    authors=conv_authors(w.authors),
                    parent=w.parent,
                    tensorflow_version=w.tensorflow_version or Version("1.15"),
                ),
                tensorflow_saved_model_bundle=(
                    w := src.weights.tensorflow_saved_model_bundle
                )
                and (TensorflowSavedModelBundleWeightsDescr if TYPE_CHECKING else dict)(
                    authors=conv_authors(w.authors),
                    parent=w.parent,
                    source=w.source,
                    tensorflow_version=w.tensorflow_version or Version("1.15"),
                    dependencies=(
                        None
                        if w.dependencies is None
                        else (FileDescr if TYPE_CHECKING else dict)(
                            source=cast(
                                FileSource,
                                (
                                    str(w.dependencies)[len("conda:") :]
                                    if str(w.dependencies).startswith("conda:")
                                    else str(w.dependencies)
                                ),
                            )
                        )
                    ),
                ),
                torchscript=(w := src.weights.torchscript)
                and (TorchscriptWeightsDescr if TYPE_CHECKING else dict)(
                    source=w.source,
                    authors=conv_authors(w.authors),
                    parent=w.parent,
                    pytorch_version=w.pytorch_version or Version("1.10"),
                ),
            ),
        )
# Module-level converter instance: upgrades a v0.4 model description
# (`_ModelDescr_v0_4`) to this module's v0.5 `ModelDescr`.
_model_conv = _ModelConv(_ModelDescr_v0_4, ModelDescr)
# TODO: create better cover images for 3d data and non-image outputs
def generate_covers(
    inputs: Sequence[Tuple[InputTensorDescr, NDArray[Any]]],
    outputs: Sequence[Tuple[OutputTensorDescr, NDArray[Any]]],
) -> List[Path]:
    """Generate cover image(s) from the first test input/output tensor pair.

    Each tensor is reduced to a 2D RGB uint8 image. If the two images end up
    with the same shape, a single diagonally split image (input lower-left,
    output upper-right) is written; otherwise two separate images are written.

    Args:
        inputs: pairs of input tensor description and corresponding test array.
            Only the first pair is used.
        outputs: pairs of output tensor description and corresponding test
            array. Only the first pair is used.

    Returns:
        Paths to the PNG file(s) written into a fresh temporary directory
        (one path for the split image, or two paths ``input.png``/``output.png``).

    Raises:
        ValueError: if `inputs` or `outputs` is empty, or if a tensor cannot
            be reduced to a 2D image.
    """

    def squeeze(
        data: NDArray[Any], axes: Sequence[AnyAxis]
    ) -> Tuple[NDArray[Any], List[AnyAxis]]:
        """apply numpy.ndarray.squeeze while keeping track of the axis descriptions remaining"""
        if data.ndim != len(axes):
            raise ValueError(
                f"tensor shape {data.shape} does not match described axes"
                + f" {[a.id for a in axes]}"
            )

        # keep (a deep copy of) only the axis descriptions whose size is not 1,
        # mirroring what ndarray.squeeze() removes
        axes = [deepcopy(a) for a, s in zip(axes, data.shape) if s != 1]
        return data.squeeze(), axes

    def normalize(
        data: NDArray[Any], axis: Optional[Tuple[int, ...]], eps: float = 1e-7
    ) -> NDArray[np.float32]:
        """Min-max normalize `data` to [0, 1) along `axis` (all axes if None).

        `eps` avoids division by zero for constant data (max == 0 after the
        min subtraction).
        """
        # astype returns a copy, so the in-place ops below do not mutate the input
        data = data.astype("float32")
        data -= data.min(axis=axis, keepdims=True)
        data /= data.max(axis=axis, keepdims=True) + eps
        return data

    def to_2d_image(data: NDArray[Any], axes: Sequence[AnyAxis]) -> NDArray[np.uint8]:
        """Reduce `data` to an RGB uint8 image (h, w, 3) by slicing/squeezing.

        Singleton axes are squeezed away; batch/index/extra-channel/z/space/
        time axes are reduced to a central (or first) slice until only
        height, width and a 3-channel axis remain.
        """
        original_shape = data.shape
        original_axes = list(axes)
        data, axes = squeeze(data, axes)

        # take slice from any batch or index axis if needed
        # and convert the first channel axis and take a slice from any additional channel axes
        slices: Tuple[slice, ...] = ()
        ndim = data.ndim
        # target rank: 3 if there is a channel axis (h, w, c), otherwise 2 (h, w)
        ndim_need = 3 if any(isinstance(a, ChannelAxis) for a in axes) else 2
        has_c_axis = False
        for i, a in enumerate(axes):
            s = data.shape[i]
            assert s > 1  # singleton axes were squeezed away above
            if (
                isinstance(a, (BatchAxis, IndexInputAxis, IndexOutputAxis))
                and ndim > ndim_need
            ):
                # keep a single central element; the resulting singleton axis
                # is removed by the squeeze() after this loop
                data = data[slices + (slice(s // 2 - 1, s // 2),)]
                ndim -= 1
            elif isinstance(a, ChannelAxis):
                if has_c_axis:
                    # second channel axis
                    data = data[slices + (slice(0, 1),)]
                    ndim -= 1
                else:
                    has_c_axis = True
                    if s == 2:
                        # visualize two channels with cyan and magenta
                        data = np.concatenate(
                            [
                                data[slices + (slice(1, 2),)],
                                data[slices + (slice(0, 1),)],
                                (
                                    data[slices + (slice(0, 1),)]
                                    + data[slices + (slice(1, 2),)]
                                )
                                / 2,  # TODO: take maximum instead?
                            ],
                            axis=i,
                        )
                    elif data.shape[i] == 3:
                        pass  # visualize 3 channels as RGB
                    else:
                        # visualize first 3 channels as RGB
                        data = data[slices + (slice(3),)]

                    assert data.shape[i] == 3

            # already-processed axes are passed through unchanged
            slices += (slice(None),)

        data, axes = squeeze(data, axes)
        assert len(axes) == ndim
        # take slice from z axis if needed
        slices = ()
        if ndim > ndim_need:
            for i, a in enumerate(axes):
                s = data.shape[i]
                if a.id == AxisId("z"):
                    # central z slice; squeeze immediately so ndim stays in sync
                    data = data[slices + (slice(s // 2 - 1, s // 2),)]
                    data, axes = squeeze(data, axes)
                    ndim -= 1
                    break

                slices += (slice(None),)

        # take slice from any space or time axis
        slices = ()

        for i, a in enumerate(axes):
            if ndim <= ndim_need:
                break  # already at the target rank

            s = data.shape[i]
            assert s > 1
            if isinstance(
                a, (SpaceInputAxis, SpaceOutputAxis, TimeInputAxis, TimeOutputAxis)
            ):
                data = data[slices + (slice(s // 2 - 1, s // 2),)]
                ndim -= 1

            slices += (slice(None),)

        del slices
        data, axes = squeeze(data, axes)
        assert len(axes) == ndim

        if (has_c_axis and ndim != 3) or (not has_c_axis and ndim != 2):
            raise ValueError(
                f"Failed to construct cover image from shape {original_shape} with axes {[a.id for a in original_axes]}."
            )

        if not has_c_axis:
            # grayscale -> RGB by repeating the single plane three times
            assert ndim == 2
            data = np.repeat(data[:, :, None], 3, axis=2)
            axes.append(ChannelAxis(channel_names=list(map(Identifier, "RGB"))))
            ndim += 1

        assert ndim == 3

        # transpose axis order such that longest axis comes first...
        axis_order: List[int] = list(np.argsort(list(data.shape)))
        axis_order.reverse()
        # ... and channel axis is last
        c = [i for i in range(3) if isinstance(axes[i], ChannelAxis)][0]
        axis_order.append(axis_order.pop(c))
        axes = [axes[ao] for ao in axis_order]
        data = data.transpose(axis_order)

        # h, w = data.shape[:2]
        # if h / w in (1.0 or 2.0):
        #     pass
        # elif h / w < 2:
        # TODO: enforce 2:1 or 1:1 aspect ratio for generated cover images

        # normalize jointly over all space/time axes (None -> over everything)
        norm_along = (
            tuple(i for i, a in enumerate(axes) if a.type in ("space", "time")) or None
        )
        # normalize the data and map to 8 bit
        data = normalize(data, norm_along)
        data = (data * 255).astype("uint8")

        return data

    def create_diagonal_split_image(
        im0: NDArray[Any], im1: NDArray[Any]
    ) -> NDArray[np.uint8]:
        """Combine two equally shaped RGB uint8 images along the diagonal:
        lower triangle (incl. diagonal) from `im0`, upper triangle from `im1`.
        """
        assert im0.dtype == im1.dtype == np.uint8
        assert im0.shape == im1.shape
        assert im0.ndim == 3
        N, M, C = im0.shape
        assert C == 3
        out = np.ones((N, M, C), dtype="uint8")
        for c in range(C):
            # tril() zeros the strict upper triangle; the zero-valued mask thus
            # selects the upper-triangle positions (plus genuinely zero pixels in
            # the lower triangle, which triu(im1) also yields 0 for -> unchanged)
            outc = np.tril(im0[..., c])
            mask = outc == 0
            outc[mask] = np.triu(im1[..., c])[mask]
            out[..., c] = outc

        return out

    if not inputs:
        raise ValueError("Missing test input tensor for cover generation.")

    if not outputs:
        raise ValueError("Missing test output tensor for cover generation.")

    # only the first input/output tensor pair is visualized
    ipt_descr, ipt = inputs[0]
    out_descr, out = outputs[0]

    ipt_img = to_2d_image(ipt, ipt_descr.axes)
    out_img = to_2d_image(out, out_descr.axes)

    cover_folder = Path(mkdtemp())
    if ipt_img.shape == out_img.shape:
        covers = [cover_folder / "cover.png"]
        imwrite(covers[0], create_diagonal_split_image(ipt_img, out_img))
    else:
        covers = [cover_folder / "input.png", cover_folder / "output.png"]
        imwrite(covers[0], ipt_img)
        imwrite(covers[1], out_img)

    return covers