# Coverage for src/bioimageio/spec/model/v0_5.py: 74% (1409 statements)

from __future__ import annotations

import collections.abc
import re
import string
import warnings
from copy import deepcopy
from itertools import chain
from math import ceil
from pathlib import Path, PurePosixPath
from tempfile import mkdtemp
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    ClassVar,
    Dict,
    Generic,
    List,
    Literal,
    Mapping,
    NamedTuple,
    Optional,
    Sequence,
    Set,
    Tuple,
    Type,
    TypeVar,
    Union,
    cast,
    overload,
)

import numpy as np
from annotated_types import Ge, Gt, Interval, MaxLen, MinLen, Predicate
from imageio.v3 import imread, imwrite  # pyright: ignore[reportUnknownVariableType]
from loguru import logger
from numpy.typing import NDArray
from pydantic import (
    AfterValidator,
    Discriminator,
    Field,
    RootModel,
    SerializationInfo,
    SerializerFunctionWrapHandler,
    StrictInt,
    Tag,
    ValidationInfo,
    WrapSerializer,
    field_validator,
    model_serializer,
    model_validator,
)
from typing_extensions import Annotated, Self, assert_never, get_args

from .._internal.common_nodes import (
    InvalidDescr,
    KwargsNode,
    Node,
    NodeWithExplicitlySetFields,
)
from .._internal.constants import DTYPE_LIMITS
from .._internal.field_warning import issue_warning, warn
from .._internal.io import BioimageioYamlContent as BioimageioYamlContent
from .._internal.io import FileDescr as FileDescr
from .._internal.io import (
    FileSource,
    WithSuffix,
    YamlValue,
    extract_file_name,
    get_reader,
    wo_special_file_name,
)
from .._internal.io_basics import Sha256 as Sha256
from .._internal.io_packaging import (
    FileDescr_,
    FileSource_,
    package_file_descr_serializer,
)
from .._internal.io_utils import load_array
from .._internal.node_converter import Converter
from .._internal.type_guards import is_dict, is_sequence
from .._internal.types import (
    FAIR,
    AbsoluteTolerance,
    LowerCaseIdentifier,
    LowerCaseIdentifierAnno,
    MismatchedElementsPerMillion,
    RelativeTolerance,
)
from .._internal.types import Datetime as Datetime
from .._internal.types import Identifier as Identifier
from .._internal.types import NotEmpty as NotEmpty
from .._internal.types import SiUnit as SiUnit
from .._internal.url import HttpUrl as HttpUrl
from .._internal.validation_context import get_validation_context
from .._internal.validator_annotations import RestrictCharacters
from .._internal.version_type import Version as Version
from .._internal.warning_levels import INFO
from ..dataset.v0_2 import DatasetDescr as DatasetDescr02
from ..dataset.v0_2 import LinkedDataset as LinkedDataset02
from ..dataset.v0_3 import DatasetDescr as DatasetDescr
from ..dataset.v0_3 import DatasetId as DatasetId
from ..dataset.v0_3 import LinkedDataset as LinkedDataset
from ..dataset.v0_3 import Uploader as Uploader
from ..generic.v0_3 import (
    VALID_COVER_IMAGE_EXTENSIONS as VALID_COVER_IMAGE_EXTENSIONS,
)
from ..generic.v0_3 import Author as Author
from ..generic.v0_3 import BadgeDescr as BadgeDescr
from ..generic.v0_3 import CiteEntry as CiteEntry
from ..generic.v0_3 import DeprecatedLicenseId as DeprecatedLicenseId
from ..generic.v0_3 import Doi as Doi
from ..generic.v0_3 import (
    FileSource_documentation,
    GenericModelDescrBase,
    LinkedResourceBase,
    _author_conv,  # pyright: ignore[reportPrivateUsage]
    _maintainer_conv,  # pyright: ignore[reportPrivateUsage]
)
from ..generic.v0_3 import LicenseId as LicenseId
from ..generic.v0_3 import LinkedResource as LinkedResource
from ..generic.v0_3 import Maintainer as Maintainer
from ..generic.v0_3 import OrcidId as OrcidId
from ..generic.v0_3 import RelativeFilePath as RelativeFilePath
from ..generic.v0_3 import ResourceId as ResourceId
from .v0_4 import Author as _Author_v0_4
from .v0_4 import BinarizeDescr as _BinarizeDescr_v0_4
from .v0_4 import CallableFromDepencency as CallableFromDepencency
from .v0_4 import CallableFromDepencency as _CallableFromDepencency_v0_4
from .v0_4 import CallableFromFile as _CallableFromFile_v0_4
from .v0_4 import ClipDescr as _ClipDescr_v0_4
from .v0_4 import ImplicitOutputShape as _ImplicitOutputShape_v0_4
from .v0_4 import InputTensorDescr as _InputTensorDescr_v0_4
from .v0_4 import KnownRunMode as KnownRunMode
from .v0_4 import ModelDescr as _ModelDescr_v0_4
from .v0_4 import OutputTensorDescr as _OutputTensorDescr_v0_4
from .v0_4 import ParameterizedInputShape as _ParameterizedInputShape_v0_4
from .v0_4 import PostprocessingDescr as _PostprocessingDescr_v0_4
from .v0_4 import PreprocessingDescr as _PreprocessingDescr_v0_4
from .v0_4 import RunMode as RunMode
from .v0_4 import ScaleLinearDescr as _ScaleLinearDescr_v0_4
from .v0_4 import ScaleMeanVarianceDescr as _ScaleMeanVarianceDescr_v0_4
from .v0_4 import ScaleRangeDescr as _ScaleRangeDescr_v0_4
from .v0_4 import SigmoidDescr as _SigmoidDescr_v0_4
from .v0_4 import TensorName as _TensorName_v0_4
from .v0_4 import WeightsFormat as WeightsFormat
from .v0_4 import ZeroMeanUnitVarianceDescr as _ZeroMeanUnitVarianceDescr_v0_4
from .v0_4 import package_weights

SpaceUnit = Literal[
    "attometer",
    "angstrom",
    "centimeter",
    "decimeter",
    "exameter",
    "femtometer",
    "foot",
    "gigameter",
    "hectometer",
    "inch",
    "kilometer",
    "megameter",
    "meter",
    "micrometer",
    "mile",
    "millimeter",
    "nanometer",
    "parsec",
    "petameter",
    "picometer",
    "terameter",
    "yard",
    "yoctometer",
    "yottameter",
    "zeptometer",
    "zettameter",
]
"""Space unit compatible with the [OME-Zarr axes specification 0.5](https://ngff.openmicroscopy.org/0.5/#axes-md)"""

TimeUnit = Literal[
    "attosecond",
    "centisecond",
    "day",
    "decisecond",
    "exasecond",
    "femtosecond",
    "gigasecond",
    "hectosecond",
    "hour",
    "kilosecond",
    "megasecond",
    "microsecond",
    "millisecond",
    "minute",
    "nanosecond",
    "petasecond",
    "picosecond",
    "second",
    "terasecond",
    "yoctosecond",
    "yottasecond",
    "zeptosecond",
    "zettasecond",
]
"""Time unit compatible with the [OME-Zarr axes specification 0.5](https://ngff.openmicroscopy.org/0.5/#axes-md)"""

AxisType = Literal["batch", "channel", "index", "time", "space"]

_AXIS_TYPE_MAP: Mapping[str, AxisType] = {
    "b": "batch",
    "t": "time",
    "i": "index",
    "c": "channel",
    "x": "space",
    "y": "space",
    "z": "space",
}

_AXIS_ID_MAP = {
    "b": "batch",
    "t": "time",
    "i": "index",
    "c": "channel",
}

class TensorId(LowerCaseIdentifier):
    root_model: ClassVar[Type[RootModel[Any]]] = RootModel[
        Annotated[LowerCaseIdentifierAnno, MaxLen(32)]
    ]


def _normalize_axis_id(a: str):
    a = str(a)
    normalized = _AXIS_ID_MAP.get(a, a)
    if a != normalized:
        logger.opt(depth=3).warning(
            "Normalized axis id from '{}' to '{}'.", a, normalized
        )
    return normalized


class AxisId(LowerCaseIdentifier):
    root_model: ClassVar[Type[RootModel[Any]]] = RootModel[
        Annotated[
            LowerCaseIdentifierAnno,
            MaxLen(16),
            AfterValidator(_normalize_axis_id),
        ]
    ]
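
# Example (sketch): the `AfterValidator` above maps single-letter v0.4 axis ids
# to their v0.5 names (logging a warning), while unmapped ids pass through:
#
#     >>> str(AxisId("c"))
#     'channel'
#     >>> str(AxisId("x"))
#     'x'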

def _is_batch(a: str) -> bool:
    return str(a) == "batch"


def _is_not_batch(a: str) -> bool:
    return not _is_batch(a)


NonBatchAxisId = Annotated[AxisId, Predicate(_is_not_batch)]

PreprocessingId = Literal[
    "binarize",
    "clip",
    "ensure_dtype",
    "fixed_zero_mean_unit_variance",
    "scale_linear",
    "scale_range",
    "sigmoid",
    "softmax",
]
PostprocessingId = Literal[
    "binarize",
    "clip",
    "ensure_dtype",
    "fixed_zero_mean_unit_variance",
    "scale_linear",
    "scale_mean_variance",
    "scale_range",
    "sigmoid",
    "softmax",
    "zero_mean_unit_variance",
]

SAME_AS_TYPE = "<same as type>"

ParameterizedSize_N = int
"""
Annotates an integer to calculate a concrete axis size from a `ParameterizedSize`.
"""

class ParameterizedSize(Node):
    """Describes a range of valid tensor axis sizes as `size = min + n*step`.

    - **min** and **step** are given by the model description.
    - All block size parameters n = 0,1,2,... yield a valid `size`.
    - A greater block size parameter n = 0,1,2,... results in a greater **size**.
      This allows adjusting the axis size more generically.
    """

    N: ClassVar[Type[int]] = ParameterizedSize_N
    """Positive integer to parameterize this axis"""

    min: Annotated[int, Gt(0)]
    step: Annotated[int, Gt(0)]

    def validate_size(self, size: int) -> int:
        if size < self.min:
            raise ValueError(f"size {size} < {self.min}")
        if (size - self.min) % self.step != 0:
            raise ValueError(
                f"axis of size {size} is not parameterized by `min + n*step` ="
                + f" `{self.min} + n*{self.step}`"
            )

        return size

    def get_size(self, n: ParameterizedSize_N) -> int:
        return self.min + self.step * n

    def get_n(self, s: int) -> ParameterizedSize_N:
        """return the smallest n parameterizing a size greater than or equal to `s`"""
        return ceil((s - self.min) / self.step)
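
# Example (sketch): `ParameterizedSize(min=32, step=16)` admits sizes
# 32, 48, 64, ...; `get_n` returns the smallest n reaching at least `s`:
#
#     >>> p = ParameterizedSize(min=32, step=16)
#     >>> p.get_size(2)
#     64
#     >>> p.get_n(50)
#     2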

class DataDependentSize(Node):
    min: Annotated[int, Gt(0)] = 1
    max: Annotated[Optional[int], Gt(1)] = None

    @model_validator(mode="after")
    def _validate_max_gt_min(self):
        if self.max is not None and self.min >= self.max:
            raise ValueError(f"expected `min` < `max`, but got {self.min}, {self.max}")

        return self

    def validate_size(self, size: int) -> int:
        if size < self.min:
            raise ValueError(f"size {size} < {self.min}")

        if self.max is not None and size > self.max:
            raise ValueError(f"size {size} > {self.max}")

        return size

class SizeReference(Node):
    """A tensor axis size (extent in pixels/frames) defined in relation to a reference axis.

    `axis.size = reference.size * reference.scale / axis.scale + offset`

    Note:
    1. The axis and the referenced axis need to have the same unit (or no unit).
    2. Batch axes may not be referenced.
    3. Fractions are rounded down.
    4. If the reference axis is `concatenable` the referencing axis is assumed to be
       `concatenable` as well with the same block order.

    Example:
    An anisotropic input image of w*h=100*49 pixels depicts a physical space of 200*196mm².
    Let's assume that we want to express the image height h in relation to its width w
    instead of only accepting input images of exactly 100*49 pixels
    (for example to express a range of valid image shapes by parameterizing w, see `ParameterizedSize`).

    >>> w = SpaceInputAxis(id=AxisId("w"), size=100, unit="millimeter", scale=2)
    >>> h = SpaceInputAxis(
    ...     id=AxisId("h"),
    ...     size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("w"), offset=-1),
    ...     unit="millimeter",
    ...     scale=4,
    ... )
    >>> print(h.size.get_size(h, w))
    49

    ⇒ h = w * w.scale / h.scale + offset = 100 * 2mm / 4mm - 1 = 49
    """

    tensor_id: TensorId
    """tensor id of the reference axis"""

    axis_id: AxisId
    """axis id of the reference axis"""

    offset: StrictInt = 0

    def get_size(
        self,
        axis: Union[
            ChannelAxis,
            IndexInputAxis,
            IndexOutputAxis,
            TimeInputAxis,
            SpaceInputAxis,
            TimeOutputAxis,
            TimeOutputAxisWithHalo,
            SpaceOutputAxis,
            SpaceOutputAxisWithHalo,
        ],
        ref_axis: Union[
            ChannelAxis,
            IndexInputAxis,
            IndexOutputAxis,
            TimeInputAxis,
            SpaceInputAxis,
            TimeOutputAxis,
            TimeOutputAxisWithHalo,
            SpaceOutputAxis,
            SpaceOutputAxisWithHalo,
        ],
        n: ParameterizedSize_N = 0,
        ref_size: Optional[int] = None,
    ):
        """Compute the concrete size for a given axis and its reference axis.

        Args:
            axis: The axis this [SizeReference][] is the size of.
            ref_axis: The reference axis to compute the size from.
            n: If the **ref_axis** is parameterized (of type `ParameterizedSize`)
                and no fixed **ref_size** is given,
                **n** is used to compute the size of the parameterized **ref_axis**.
            ref_size: Overwrite the reference size instead of deriving it from
                **ref_axis**
                (**ref_axis.scale** is still used; any given **n** is ignored).
        """
        assert axis.size == self, (
            "Given `axis.size` is not defined by this `SizeReference`"
        )

        assert ref_axis.id == self.axis_id, (
            f"Expected `ref_axis.id` to be {self.axis_id}, but got {ref_axis.id}."
        )

        assert axis.unit == ref_axis.unit, (
            "`SizeReference` requires `axis` and `ref_axis` to have the same `unit`,"
            f" but {axis.unit}!={ref_axis.unit}"
        )
        if ref_size is None:
            if isinstance(ref_axis.size, (int, float)):
                ref_size = ref_axis.size
            elif isinstance(ref_axis.size, ParameterizedSize):
                ref_size = ref_axis.size.get_size(n)
            elif isinstance(ref_axis.size, DataDependentSize):
                raise ValueError(
                    "Reference axis referenced in `SizeReference` may not be a `DataDependentSize`."
                )
            elif isinstance(ref_axis.size, SizeReference):
                raise ValueError(
                    "Reference axis referenced in `SizeReference` may not be sized by a"
                    + " `SizeReference` itself."
                )
            else:
                assert_never(ref_axis.size)

        return int(ref_size * ref_axis.scale / axis.scale + self.offset)

    @staticmethod
    def _get_unit(
        axis: Union[
            ChannelAxis,
            IndexInputAxis,
            IndexOutputAxis,
            TimeInputAxis,
            SpaceInputAxis,
            TimeOutputAxis,
            TimeOutputAxisWithHalo,
            SpaceOutputAxis,
            SpaceOutputAxisWithHalo,
        ],
    ):
        return axis.unit

class AxisBase(NodeWithExplicitlySetFields):
    id: AxisId
    """An axis id unique across all axes of one tensor."""

    description: Annotated[str, MaxLen(128)] = ""
    """A short description of this axis beyond its type and id."""


class WithHalo(Node):
    halo: Annotated[int, Ge(1)]
    """The halo should be cropped from the output tensor to avoid boundary effects.
    It is to be cropped from both sides, i.e. `size_after_crop = size - 2 * halo`.
    To document a halo that is already cropped by the model use `size.offset` instead."""

    size: Annotated[
        SizeReference,
        Field(
            examples=[
                10,
                SizeReference(
                    tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
                ).model_dump(mode="json"),
            ]
        ),
    ]
    """reference to another axis with an optional offset (see [SizeReference][])"""


BATCH_AXIS_ID = AxisId("batch")

class BatchAxis(AxisBase):
    implemented_type: ClassVar[Literal["batch"]] = "batch"
    if TYPE_CHECKING:
        type: Literal["batch"] = "batch"
    else:
        type: Literal["batch"]

    id: Annotated[AxisId, Predicate(_is_batch)] = BATCH_AXIS_ID
    size: Optional[Literal[1]] = None
    """The batch size may be fixed to 1,
    otherwise (the default) it may be chosen arbitrarily depending on available memory"""

    @property
    def scale(self):
        return 1.0

    @property
    def concatenable(self):
        return True

    @property
    def unit(self):
        return None

class ChannelAxis(AxisBase):
    implemented_type: ClassVar[Literal["channel"]] = "channel"
    if TYPE_CHECKING:
        type: Literal["channel"] = "channel"
    else:
        type: Literal["channel"]

    id: NonBatchAxisId = AxisId("channel")

    channel_names: NotEmpty[List[Identifier]]

    @property
    def size(self) -> int:
        return len(self.channel_names)

    @property
    def concatenable(self):
        return False

    @property
    def scale(self) -> float:
        return 1.0

    @property
    def unit(self):
        return None


class IndexAxisBase(AxisBase):
    implemented_type: ClassVar[Literal["index"]] = "index"
    if TYPE_CHECKING:
        type: Literal["index"] = "index"
    else:
        type: Literal["index"]

    id: NonBatchAxisId = AxisId("index")

    @property
    def scale(self) -> float:
        return 1.0

    @property
    def unit(self):
        return None

class _WithInputAxisSize(Node):
    size: Annotated[
        Union[Annotated[int, Gt(0)], ParameterizedSize, SizeReference],
        Field(
            examples=[
                10,
                ParameterizedSize(min=32, step=16).model_dump(mode="json"),
                SizeReference(
                    tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
                ).model_dump(mode="json"),
            ]
        ),
    ]
    """The size/length of this axis can be specified as
    - fixed integer
    - parameterized series of valid sizes ([ParameterizedSize][])
    - reference to another axis with an optional offset ([SizeReference][])
    """


class IndexInputAxis(IndexAxisBase, _WithInputAxisSize):
    concatenable: bool = False
    """If a model has a `concatenable` input axis, it can be processed blockwise,
    splitting a longer sample axis into blocks matching its input tensor description.
    Output axes are concatenable if they have a [SizeReference][] to a concatenable
    input axis.
    """


class IndexOutputAxis(IndexAxisBase):
    size: Annotated[
        Union[Annotated[int, Gt(0)], SizeReference, DataDependentSize],
        Field(
            examples=[
                10,
                SizeReference(
                    tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
                ).model_dump(mode="json"),
            ]
        ),
    ]
    """The size/length of this axis can be specified as
    - fixed integer
    - reference to another axis with an optional offset ([SizeReference][])
    - data dependent size using [DataDependentSize][] (size is only known after model inference)
    """

class TimeAxisBase(AxisBase):
    implemented_type: ClassVar[Literal["time"]] = "time"
    if TYPE_CHECKING:
        type: Literal["time"] = "time"
    else:
        type: Literal["time"]

    id: NonBatchAxisId = AxisId("time")
    unit: Optional[TimeUnit] = None
    scale: Annotated[float, Gt(0)] = 1.0


class TimeInputAxis(TimeAxisBase, _WithInputAxisSize):
    concatenable: bool = False
    """If a model has a `concatenable` input axis, it can be processed blockwise,
    splitting a longer sample axis into blocks matching its input tensor description.
    Output axes are concatenable if they have a [SizeReference][] to a concatenable
    input axis.
    """


class SpaceAxisBase(AxisBase):
    implemented_type: ClassVar[Literal["space"]] = "space"
    if TYPE_CHECKING:
        type: Literal["space"] = "space"
    else:
        type: Literal["space"]

    id: Annotated[NonBatchAxisId, Field(examples=["x", "y", "z"])] = AxisId("x")
    unit: Optional[SpaceUnit] = None
    scale: Annotated[float, Gt(0)] = 1.0


class SpaceInputAxis(SpaceAxisBase, _WithInputAxisSize):
    concatenable: bool = False
    """If a model has a `concatenable` input axis, it can be processed blockwise,
    splitting a longer sample axis into blocks matching its input tensor description.
    Output axes are concatenable if they have a [SizeReference][] to a concatenable
    input axis.
    """


INPUT_AXIS_TYPES = (
    BatchAxis,
    ChannelAxis,
    IndexInputAxis,
    TimeInputAxis,
    SpaceInputAxis,
)
"""intended for isinstance comparisons in py<3.10"""

_InputAxisUnion = Union[
    BatchAxis, ChannelAxis, IndexInputAxis, TimeInputAxis, SpaceInputAxis
]
InputAxis = Annotated[_InputAxisUnion, Discriminator("type")]

class _WithOutputAxisSize(Node):
    size: Annotated[
        Union[Annotated[int, Gt(0)], SizeReference],
        Field(
            examples=[
                10,
                SizeReference(
                    tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
                ).model_dump(mode="json"),
            ]
        ),
    ]
    """The size/length of this axis can be specified as
    - fixed integer
    - reference to another axis with an optional offset (see [SizeReference][])
    """


class TimeOutputAxis(TimeAxisBase, _WithOutputAxisSize):
    pass


class TimeOutputAxisWithHalo(TimeAxisBase, WithHalo):
    pass


def _get_halo_axis_discriminator_value(v: Any) -> Literal["with_halo", "wo_halo"]:
    if isinstance(v, dict):
        return "with_halo" if "halo" in v else "wo_halo"
    else:
        return "with_halo" if hasattr(v, "halo") else "wo_halo"
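
# Example (sketch): the discriminator routes raw dicts and instantiated axes by
# the presence of a `halo` entry/attribute, letting pydantic choose the matching
# tagged union member:
#
#     >>> _get_halo_axis_discriminator_value({"size": 10})
#     'wo_halo'
#     >>> _get_halo_axis_discriminator_value({"halo": 2, "size": 10})
#     'with_halo'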

_TimeOutputAxisUnion = Annotated[
    Union[
        Annotated[TimeOutputAxis, Tag("wo_halo")],
        Annotated[TimeOutputAxisWithHalo, Tag("with_halo")],
    ],
    Discriminator(_get_halo_axis_discriminator_value),
]


class SpaceOutputAxis(SpaceAxisBase, _WithOutputAxisSize):
    pass


class SpaceOutputAxisWithHalo(SpaceAxisBase, WithHalo):
    pass


_SpaceOutputAxisUnion = Annotated[
    Union[
        Annotated[SpaceOutputAxis, Tag("wo_halo")],
        Annotated[SpaceOutputAxisWithHalo, Tag("with_halo")],
    ],
    Discriminator(_get_halo_axis_discriminator_value),
]

_OutputAxisUnion = Union[
    BatchAxis, ChannelAxis, IndexOutputAxis, _TimeOutputAxisUnion, _SpaceOutputAxisUnion
]
OutputAxis = Annotated[_OutputAxisUnion, Discriminator("type")]

OUTPUT_AXIS_TYPES = (
    BatchAxis,
    ChannelAxis,
    IndexOutputAxis,
    TimeOutputAxis,
    TimeOutputAxisWithHalo,
    SpaceOutputAxis,
    SpaceOutputAxisWithHalo,
)
"""intended for isinstance comparisons in py<3.10"""

AnyAxis = Union[InputAxis, OutputAxis]

ANY_AXIS_TYPES = INPUT_AXIS_TYPES + OUTPUT_AXIS_TYPES
"""intended for isinstance comparisons in py<3.10"""

TVs = Union[
    NotEmpty[List[int]],
    NotEmpty[List[float]],
    NotEmpty[List[bool]],
    NotEmpty[List[str]],
]

NominalOrOrdinalDType = Literal[
    "float32",
    "float64",
    "uint8",
    "int8",
    "uint16",
    "int16",
    "uint32",
    "int32",
    "uint64",
    "int64",
    "bool",
]

class NominalOrOrdinalDataDescr(Node):
    values: TVs
    """A fixed set of nominal or an ascending sequence of ordinal values.
    In this case `data.type` is required to be an unsigned integer type, e.g. 'uint8'.
    String `values` are interpreted as labels for tensor values 0, ..., N.
    Note: as YAML 1.2 does not natively support a "set" datatype,
    nominal values should be given as a sequence (aka list/array) as well.
    """

    type: Annotated[
        NominalOrOrdinalDType,
        Field(
            examples=[
                "float32",
                "uint8",
                "uint16",
                "int64",
                "bool",
            ],
        ),
    ] = "uint8"

    @model_validator(mode="after")
    def _validate_values_match_type(
        self,
    ) -> Self:
        incompatible: List[Any] = []
        for v in self.values:
            if self.type == "bool":
                if not isinstance(v, bool):
                    incompatible.append(v)
            elif self.type in DTYPE_LIMITS:
                if (
                    isinstance(v, (int, float))
                    and (
                        v < DTYPE_LIMITS[self.type].min
                        or v > DTYPE_LIMITS[self.type].max
                    )
                    or (isinstance(v, str) and "uint" not in self.type)
                    or (isinstance(v, float) and "int" in self.type)
                ):
                    incompatible.append(v)
            else:
                incompatible.append(v)

            if len(incompatible) == 5:
                incompatible.append("...")
                break

        if incompatible:
            raise ValueError(
                f"data type '{self.type}' incompatible with values {incompatible}"
            )

        return self

    unit: Optional[Union[Literal["arbitrary unit"], SiUnit]] = None

    @property
    def range(self):
        if isinstance(self.values[0], str):
            return 0, len(self.values) - 1
        else:
            return min(self.values), max(self.values)
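
# Example (sketch): string values act as labels for tensor values 0, ..., N-1,
# so `range` reports label indices rather than the labels themselves:
#
#     >>> d = NominalOrOrdinalDataDescr(values=["background", "foreground"], type="uint8")
#     >>> d.range
#     (0, 1)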

IntervalOrRatioDType = Literal[
    "float32",
    "float64",
    "uint8",
    "int8",
    "uint16",
    "int16",
    "uint32",
    "int32",
    "uint64",
    "int64",
]

class IntervalOrRatioDataDescr(Node):
    type: Annotated[  # TODO: rename to dtype
        IntervalOrRatioDType,
        Field(
            examples=["float32", "float64", "uint8", "uint16"],
        ),
    ] = "float32"
    range: Tuple[Optional[float], Optional[float]] = (
        None,
        None,
    )
    """Tuple `(minimum, maximum)` specifying the allowed range of the data in this tensor.
    `None` corresponds to min/max of what can be expressed by **type**."""
    unit: Union[Literal["arbitrary unit"], SiUnit] = "arbitrary unit"
    scale: float = 1.0
    """Scale for data on an interval (or ratio) scale."""
    offset: Optional[float] = None
    """Offset for data on a ratio scale."""

    @model_validator(mode="before")
    def _replace_inf(cls, data: Any):
        if is_dict(data):
            if "range" in data and is_sequence(data["range"]):
                forbidden = (
                    "inf",
                    "-inf",
                    ".inf",
                    "-.inf",
                    float("inf"),
                    float("-inf"),
                )
                if any(v in forbidden for v in data["range"]):
                    issue_warning("replaced 'inf' value", value=data["range"])

                data["range"] = tuple(
                    (None if v in forbidden else v) for v in data["range"]
                )

        return data
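
# Note (sketch): `_replace_inf` normalizes YAML-style infinity spellings, e.g. a
# serialized `range: [-.inf, 100.0]` becomes `(None, 100.0)`, i.e. an open lower
# bound that falls back to the dtype minimum; a validation warning is issued for
# the replaced value.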

TensorDataDescr = Union[NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr]


class BinarizeKwargs(KwargsNode):
    """keyword arguments for [BinarizeDescr][]"""

    threshold: float
    """The fixed threshold"""


class BinarizeAlongAxisKwargs(KwargsNode):
    """keyword arguments for [BinarizeDescr][]"""

    threshold: NotEmpty[List[float]]
    """The fixed threshold values along `axis`"""

    axis: Annotated[NonBatchAxisId, Field(examples=["channel"])]
    """The `threshold` axis"""

class BinarizeDescr(NodeWithExplicitlySetFields):
    """Binarize the tensor with a fixed threshold.

    Values above [BinarizeKwargs.threshold][]/[BinarizeAlongAxisKwargs.threshold][]
    will be set to one, values below the threshold to zero.

    Examples:
    - in YAML
        ```yaml
        postprocessing:
          - id: binarize
            kwargs:
              axis: 'channel'
              threshold: [0.25, 0.5, 0.75]
        ```
    - in Python:
        >>> postprocessing = [BinarizeDescr(
        ...     kwargs=BinarizeAlongAxisKwargs(
        ...         axis=AxisId('channel'),
        ...         threshold=[0.25, 0.5, 0.75],
        ...     )
        ... )]
    """

    implemented_id: ClassVar[Literal["binarize"]] = "binarize"
    if TYPE_CHECKING:
        id: Literal["binarize"] = "binarize"
    else:
        id: Literal["binarize"]
    kwargs: Union[BinarizeKwargs, BinarizeAlongAxisKwargs]

class ClipKwargs(KwargsNode):
    """keyword arguments for [ClipDescr][]"""

    min: Optional[float] = None
    """Minimum value for clipping.

    Exclusive with `min_percentile`.
    """
    min_percentile: Optional[Annotated[float, Interval(ge=0, lt=100)]] = None
    """Minimum percentile for clipping.

    Exclusive with `min`.

    In range [0, 100).
    """

    max: Optional[float] = None
    """Maximum value for clipping.

    Exclusive with `max_percentile`.
    """
    max_percentile: Optional[Annotated[float, Interval(gt=1, le=100)]] = None
    """Maximum percentile for clipping.

    Exclusive with `max`.

    In range (1, 100].
    """

    axes: Annotated[
        Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
    ] = None
    """The subset of axes to determine percentiles jointly,
    i.e. axes to reduce to compute min/max from `min_percentile`/`max_percentile`.
    For example to clip 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
    resulting in a tensor of equal shape with clipped values per channel, specify `axes=('batch', 'x', 'y')`.
    To clip samples independently, leave out the 'batch' axis.

    Only valid if `min_percentile` and/or `max_percentile` are set.

    Default: Compute percentiles over all axes jointly."""

    @model_validator(mode="after")
    def _validate(self) -> Self:
        if (self.min is not None) and (self.min_percentile is not None):
            raise ValueError(
                "Only one of `min` and `min_percentile` may be set, not both."
            )
        if (self.max is not None) and (self.max_percentile is not None):
            raise ValueError(
                "Only one of `max` and `max_percentile` may be set, not both."
            )
        if (
            self.min is None
            and self.min_percentile is None
            and self.max is None
            and self.max_percentile is None
        ):
            raise ValueError(
                "At least one of `min`, `min_percentile`, `max`, or `max_percentile` must be set."
            )

        if (
            self.axes is not None
            and self.min_percentile is None
            and self.max_percentile is None
        ):
            raise ValueError(
                "If `axes` is set, at least one of `min_percentile` or `max_percentile` must be set."
            )

        return self
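
# Example (sketch): plain bounds and percentile bounds are mutually exclusive,
# and at least one of the four must be given:
#
#     >>> ClipKwargs(min=0.0, max=1.0).max
#     1.0
#
# `ClipKwargs(min=0.0, min_percentile=5.0)` or an empty `ClipKwargs()` would
# fail validation instead.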

class ClipDescr(NodeWithExplicitlySetFields):
    """Set tensor values below min to min and above max to max.

    See `ScaleRangeDescr` for examples.
    """

    implemented_id: ClassVar[Literal["clip"]] = "clip"
    if TYPE_CHECKING:
        id: Literal["clip"] = "clip"
    else:
        id: Literal["clip"]

    kwargs: ClipKwargs

class EnsureDtypeKwargs(KwargsNode):
    """keyword arguments for [EnsureDtypeDescr][]"""

    dtype: Literal[
        "float32",
        "float64",
        "uint8",
        "int8",
        "uint16",
        "int16",
        "uint32",
        "int32",
        "uint64",
        "int64",
        "bool",
    ]

class EnsureDtypeDescr(NodeWithExplicitlySetFields):
    """Cast the tensor data type to `EnsureDtypeKwargs.dtype` (if not matching).

    This can for example be used to ensure the inner neural network model gets a
    different input tensor data type than the fully described bioimage.io model does.

    Examples:
        The described bioimage.io model (incl. preprocessing) accepts any
        float32-compatible tensor, normalizes it with percentiles and clipping and then
        casts it to uint8, which is what the neural network in this example expects.
        - in YAML
            ```yaml
            inputs:
            - data:
                type: float32  # described bioimage.io model is compatible with any float32 input tensor
              preprocessing:
              - id: scale_range
                kwargs:
                  axes: ['y', 'x']
                  max_percentile: 99.8
                  min_percentile: 5.0
              - id: clip
                kwargs:
                  min: 0.0
                  max: 1.0
              - id: ensure_dtype  # the neural network of the model requires uint8
                kwargs:
                  dtype: uint8
            ```
        - in Python:
            >>> preprocessing = [
            ...     ScaleRangeDescr(
            ...         kwargs=ScaleRangeKwargs(
            ...             axes=(AxisId('y'), AxisId('x')),
            ...             max_percentile=99.8,
            ...             min_percentile=5.0,
            ...         )
            ...     ),
            ...     ClipDescr(kwargs=ClipKwargs(min=0.0, max=1.0)),
            ...     EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="uint8")),
            ... ]
    """

    implemented_id: ClassVar[Literal["ensure_dtype"]] = "ensure_dtype"
    if TYPE_CHECKING:
        id: Literal["ensure_dtype"] = "ensure_dtype"
    else:
        id: Literal["ensure_dtype"]

    kwargs: EnsureDtypeKwargs

class ScaleLinearKwargs(KwargsNode):
    """keyword arguments for [ScaleLinearDescr][]"""

    gain: float = 1.0
    """multiplicative factor"""

    offset: float = 0.0
    """additive term"""

    @model_validator(mode="after")
    def _validate(self) -> Self:
        if self.gain == 1.0 and self.offset == 0.0:
            raise ValueError(
                "Redundant linear scaling not allowed. Set `gain` != 1.0 and/or"
                + " `offset` != 0.0."
            )

        return self

class ScaleLinearAlongAxisKwargs(KwargsNode):
    """keyword arguments for [ScaleLinearDescr][]"""

    axis: Annotated[NonBatchAxisId, Field(examples=["channel"])]
    """The axis of gain and offset values."""

    gain: Union[float, NotEmpty[List[float]]] = 1.0
    """multiplicative factor"""

    offset: Union[float, NotEmpty[List[float]]] = 0.0
    """additive term"""

    @model_validator(mode="after")
    def _validate(self) -> Self:
        if isinstance(self.gain, list):
            if isinstance(self.offset, list):
                if len(self.gain) != len(self.offset):
                    raise ValueError(
                        f"Size of `gain` ({len(self.gain)}) and `offset` ({len(self.offset)}) must match."
                    )
            else:
                self.offset = [float(self.offset)] * len(self.gain)
        elif isinstance(self.offset, list):
            self.gain = [float(self.gain)] * len(self.offset)
        else:
            raise ValueError(
                "Do not specify an `axis` for scalar gain and offset values."
            )

        if all(g == 1.0 for g in self.gain) and all(off == 0.0 for off in self.offset):
            raise ValueError(
                "Redundant linear scaling not allowed. Set `gain` != 1.0 and/or"
                + " `offset` != 0.0."
            )

        return self
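
# Example (sketch): a scalar `offset` is broadcast to match a list-valued `gain`
# (and vice versa) during validation:
#
#     >>> k = ScaleLinearAlongAxisKwargs(axis=AxisId("channel"), gain=[1.0, 2.0])
#     >>> k.offset
#     [0.0, 0.0]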

class ScaleLinearDescr(NodeWithExplicitlySetFields):
    """Fixed linear scaling.

    Examples:
    1. Scale with scalar gain and offset
        - in YAML
        ```yaml
        preprocessing:
          - id: scale_linear
            kwargs:
              gain: 2.0
              offset: 3.0
        ```
        - in Python:
        >>> preprocessing = [
        ...     ScaleLinearDescr(kwargs=ScaleLinearKwargs(gain=2.0, offset=3.0))
        ... ]

    2. Independent scaling along an axis
        - in YAML
        ```yaml
        preprocessing:
          - id: scale_linear
            kwargs:
              axis: 'channel'
              gain: [1.0, 2.0, 3.0]
        ```
        - in Python:
        >>> preprocessing = [
        ...     ScaleLinearDescr(
        ...         kwargs=ScaleLinearAlongAxisKwargs(
        ...             axis=AxisId("channel"),
        ...             gain=[1.0, 2.0, 3.0],
        ...         )
        ...     )
        ... ]
    """

    implemented_id: ClassVar[Literal["scale_linear"]] = "scale_linear"
    if TYPE_CHECKING:
        id: Literal["scale_linear"] = "scale_linear"
    else:
        id: Literal["scale_linear"]
    kwargs: Union[ScaleLinearKwargs, ScaleLinearAlongAxisKwargs]

class SigmoidDescr(NodeWithExplicitlySetFields):
    """The logistic sigmoid function, a.k.a. expit function.

    Examples:
    - in YAML
        ```yaml
        postprocessing:
          - id: sigmoid
        ```
    - in Python:
        >>> postprocessing = [SigmoidDescr()]
    """

    implemented_id: ClassVar[Literal["sigmoid"]] = "sigmoid"
    if TYPE_CHECKING:
        id: Literal["sigmoid"] = "sigmoid"
    else:
        id: Literal["sigmoid"]

    @property
    def kwargs(self) -> KwargsNode:
        """empty kwargs"""
        return KwargsNode()

class SoftmaxKwargs(KwargsNode):
    """keyword arguments for [SoftmaxDescr][]"""

    axis: Annotated[NonBatchAxisId, Field(examples=["channel"])] = AxisId("channel")
    """The axis to apply the softmax function along.

    Note:
        Defaults to the 'channel' axis
        (which may not exist, in which case
        a different axis id has to be specified).
    """


class SoftmaxDescr(NodeWithExplicitlySetFields):
    """The softmax function.

    Examples:
    - in YAML
        ```yaml
        postprocessing:
          - id: softmax
            kwargs:
              axis: channel
        ```
    - in Python:
        >>> postprocessing = [SoftmaxDescr(kwargs=SoftmaxKwargs(axis=AxisId("channel")))]
    """

    implemented_id: ClassVar[Literal["softmax"]] = "softmax"
    if TYPE_CHECKING:
        id: Literal["softmax"] = "softmax"
    else:
        id: Literal["softmax"]

    kwargs: SoftmaxKwargs = Field(default_factory=SoftmaxKwargs.model_construct)

class FixedZeroMeanUnitVarianceKwargs(KwargsNode):
    """keyword arguments for [FixedZeroMeanUnitVarianceDescr][]"""

    mean: float
    """The mean value to normalize with."""

    std: Annotated[float, Ge(1e-6)]
    """The standard deviation value to normalize with."""


class FixedZeroMeanUnitVarianceAlongAxisKwargs(KwargsNode):
    """keyword arguments for [FixedZeroMeanUnitVarianceDescr][]"""

    mean: NotEmpty[List[float]]
    """The mean value(s) to normalize with."""

    std: NotEmpty[List[Annotated[float, Ge(1e-6)]]]
    """The standard deviation value(s) to normalize with.
    Size must match `mean` values."""

    axis: Annotated[NonBatchAxisId, Field(examples=["channel", "index"])]
    """The axis of the mean/std values to normalize each entry along that dimension
    separately."""

    @model_validator(mode="after")
    def _mean_and_std_match(self) -> Self:
        if len(self.mean) != len(self.std):
            raise ValueError(
                f"Size of `mean` ({len(self.mean)}) and `std` ({len(self.std)})"
                + " must match."
            )

        return self

class FixedZeroMeanUnitVarianceDescr(NodeWithExplicitlySetFields):
    """Subtract a given mean and divide by the standard deviation.

    Normalize with fixed, precomputed values for
    `FixedZeroMeanUnitVarianceKwargs.mean` and `FixedZeroMeanUnitVarianceKwargs.std`.
    Use `FixedZeroMeanUnitVarianceAlongAxisKwargs` for independent scaling along given
    axes.

    Examples:
    1. scalar value for the whole tensor
        - in YAML
        ```yaml
        preprocessing:
          - id: fixed_zero_mean_unit_variance
            kwargs:
              mean: 103.5
              std: 13.7
        ```
        - in Python
        >>> preprocessing = [FixedZeroMeanUnitVarianceDescr(
        ...     kwargs=FixedZeroMeanUnitVarianceKwargs(mean=103.5, std=13.7)
        ... )]

    2. independently along an axis
        - in YAML
        ```yaml
        preprocessing:
          - id: fixed_zero_mean_unit_variance
            kwargs:
              axis: channel
              mean: [101.5, 102.5, 103.5]
              std: [11.7, 12.7, 13.7]
        ```
        - in Python
        >>> preprocessing = [FixedZeroMeanUnitVarianceDescr(
        ...     kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs(
        ...         axis=AxisId("channel"),
        ...         mean=[101.5, 102.5, 103.5],
        ...         std=[11.7, 12.7, 13.7],
        ...     )
        ... )]
    """

    implemented_id: ClassVar[Literal["fixed_zero_mean_unit_variance"]] = (
        "fixed_zero_mean_unit_variance"
    )
    if TYPE_CHECKING:
        id: Literal["fixed_zero_mean_unit_variance"] = "fixed_zero_mean_unit_variance"
    else:
        id: Literal["fixed_zero_mean_unit_variance"]

    kwargs: Union[
        FixedZeroMeanUnitVarianceKwargs, FixedZeroMeanUnitVarianceAlongAxisKwargs
    ]

class ZeroMeanUnitVarianceKwargs(KwargsNode):
    """keyword arguments for [ZeroMeanUnitVarianceDescr][]"""

    axes: Annotated[
        Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
    ] = None
    """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std.
    For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
    resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
    To normalize each sample independently, leave out the 'batch' axis.
    Default: Scale all axes jointly."""

    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
    """epsilon for numeric stability: `out = (tensor - mean) / (std + eps)`."""
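
# Sketch of the computation described above (numpy semantics; the actual
# processing is applied by downstream consumer libraries, not by this spec
# package):
#
#     >>> t = np.arange(6, dtype=np.float32).reshape(2, 3)
#     >>> out = (t - t.mean()) / (t.std() + 1e-6)  # all axes reduced jointly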

class ZeroMeanUnitVarianceDescr(NodeWithExplicitlySetFields):
    """Subtract the tensor mean and divide by the standard deviation.

    Examples:
    Normalize with the tensor's own mean and standard deviation
    - in YAML
        ```yaml
        preprocessing:
          - id: zero_mean_unit_variance
        ```
    - in Python
        >>> preprocessing = [ZeroMeanUnitVarianceDescr()]
    """

    implemented_id: ClassVar[Literal["zero_mean_unit_variance"]] = (
        "zero_mean_unit_variance"
    )
    if TYPE_CHECKING:
        id: Literal["zero_mean_unit_variance"] = "zero_mean_unit_variance"
    else:
        id: Literal["zero_mean_unit_variance"]

    kwargs: ZeroMeanUnitVarianceKwargs = Field(
        default_factory=ZeroMeanUnitVarianceKwargs.model_construct
    )

class ScaleRangeKwargs(KwargsNode):
    """keyword arguments for [ScaleRangeDescr][]

    For `min_percentile`=0.0 (the default) and `max_percentile`=100 (the default)
    this processing step normalizes data to the [0, 1] interval.
    For other percentiles the normalized values will partially be outside the [0, 1]
    interval. Use `ScaleRange` followed by `ClipDescr` if you want to limit the
    normalized values to a range.
    """

    axes: Annotated[
        Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
    ] = None
    """The subset of axes to normalize jointly, i.e. axes to reduce to compute the min/max percentile value.
    For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
    resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
    To normalize samples independently, leave out the 'batch' axis.
    Default: Scale all axes jointly."""

    min_percentile: Annotated[float, Interval(ge=0, lt=100)] = 0.0
    """The lower percentile used to determine the value to align with zero."""

    max_percentile: Annotated[float, Interval(gt=1, le=100)] = 100.0
    """The upper percentile used to determine the value to align with one.
    Has to be bigger than `min_percentile`.
    The range is 1 to 100 instead of 0 to 100 to avoid mistakenly
    accepting percentiles specified in the range 0.0 to 1.0."""

    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
    """Epsilon for numeric stability.
    `out = (tensor - v_lower) / (v_upper - v_lower + eps)`;
    with `v_lower,v_upper` values at the respective percentiles."""

    reference_tensor: Optional[TensorId] = None
    """Tensor ID to compute the percentiles from. Default: The tensor itself.
    For any tensor in `inputs` only input tensor references are allowed."""

    @field_validator("max_percentile", mode="after")
    @classmethod
    def min_smaller_max(cls, value: float, info: ValidationInfo) -> float:
        if (min_p := info.data["min_percentile"]) >= value:
            raise ValueError(f"min_percentile {min_p} >= max_percentile {value}")

        return value
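
# Sketch of the normalization described above (numpy semantics; the spec only
# describes the operation, it is applied by downstream libraries):
#
#     >>> t = np.random.rand(1, 3, 64, 64)  # ('batch', 'channel', 'y', 'x')
#     >>> lo = np.percentile(t, 5.0, axis=(0, 2, 3), keepdims=True)
#     >>> hi = np.percentile(t, 99.8, axis=(0, 2, 3), keepdims=True)
#     >>> out = (t - lo) / (hi - lo + 1e-6)  # normalized per channel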

class ScaleRangeDescr(NodeWithExplicitlySetFields):
    """Scale with percentiles.

    Examples:
    1. Scale linearly to map the 5th percentile to 0 and the 99.8th percentile to 1.0
        - in YAML
        ```yaml
        preprocessing:
          - id: scale_range
            kwargs:
              axes: ['y', 'x']
              max_percentile: 99.8
              min_percentile: 5.0
        ```
        - in Python
        >>> preprocessing = [
        ...     ScaleRangeDescr(
        ...         kwargs=ScaleRangeKwargs(
        ...             axes=(AxisId('y'), AxisId('x')),
        ...             max_percentile=99.8,
        ...             min_percentile=5.0,
        ...         )
        ...     )
        ... ]

    2. Combine the above scaling with additional clipping to clip values outside the range given by the percentiles.
        - in YAML
        ```yaml
        preprocessing:
          - id: scale_range
            kwargs:
              axes: ['y', 'x']
              max_percentile: 99.8
              min_percentile: 5.0
          - id: clip
            kwargs:
              min: 0.0
              max: 1.0
        ```
        - in Python
        >>> preprocessing = [
        ...     ScaleRangeDescr(
        ...         kwargs=ScaleRangeKwargs(
        ...             axes=(AxisId('y'), AxisId('x')),
        ...             max_percentile=99.8,
        ...             min_percentile=5.0,
        ...         )
        ...     ),
        ...     ClipDescr(
        ...         kwargs=ClipKwargs(
        ...             min=0.0,
        ...             max=1.0,
        ...         )
        ...     ),
        ... ]
    """

    implemented_id: ClassVar[Literal["scale_range"]] = "scale_range"
    if TYPE_CHECKING:
        id: Literal["scale_range"] = "scale_range"
    else:
        id: Literal["scale_range"]
    kwargs: ScaleRangeKwargs = Field(default_factory=ScaleRangeKwargs.model_construct)

class ScaleMeanVarianceKwargs(KwargsNode):
    """keyword arguments for [ScaleMeanVarianceDescr][]"""

    reference_tensor: TensorId
    """Name of tensor to match."""

    axes: Annotated[
        Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
    ] = None
    """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std.
    For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
    resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
    To normalize samples independently, leave out the 'batch' axis.
    Default: Scale all axes jointly."""

    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
    """Epsilon for numeric stability:
    `out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.`"""


class ScaleMeanVarianceDescr(NodeWithExplicitlySetFields):
    """Scale a tensor's data distribution to match another tensor's mean/std.
    `out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.`
    """

    implemented_id: ClassVar[Literal["scale_mean_variance"]] = "scale_mean_variance"
    if TYPE_CHECKING:
        id: Literal["scale_mean_variance"] = "scale_mean_variance"
    else:
        id: Literal["scale_mean_variance"]
    kwargs: ScaleMeanVarianceKwargs

PreprocessingDescr = Annotated[
    Union[
        BinarizeDescr,
        ClipDescr,
        EnsureDtypeDescr,
        FixedZeroMeanUnitVarianceDescr,
        ScaleLinearDescr,
        ScaleRangeDescr,
        SigmoidDescr,
        SoftmaxDescr,
        ZeroMeanUnitVarianceDescr,
    ],
    Discriminator("id"),
]
PostprocessingDescr = Annotated[
    Union[
        BinarizeDescr,
        ClipDescr,
        EnsureDtypeDescr,
        FixedZeroMeanUnitVarianceDescr,
        ScaleLinearDescr,
        ScaleMeanVarianceDescr,
        ScaleRangeDescr,
        SigmoidDescr,
        SoftmaxDescr,
        ZeroMeanUnitVarianceDescr,
    ],
    Discriminator("id"),
]

IO_AxisT = TypeVar("IO_AxisT", InputAxis, OutputAxis)

class TensorDescrBase(Node, Generic[IO_AxisT]):
    id: TensorId
    """Tensor id. No duplicates are allowed."""

    description: Annotated[str, MaxLen(128)] = ""
    """free text description"""

    axes: NotEmpty[Sequence[IO_AxisT]]
    """tensor axes"""

    @property
    def shape(self):
        return tuple(a.size for a in self.axes)

    @field_validator("axes", mode="after", check_fields=False)
    @classmethod
    def _validate_axes(cls, axes: Sequence[AnyAxis]) -> Sequence[AnyAxis]:
        batch_axes = [a for a in axes if a.type == "batch"]
        if len(batch_axes) > 1:
            raise ValueError(
                f"Only one batch axis (per tensor) allowed, but got {batch_axes}"
            )

        seen_ids: Set[AxisId] = set()
        duplicate_axes_ids: Set[AxisId] = set()
        for a in axes:
            (duplicate_axes_ids if a.id in seen_ids else seen_ids).add(a.id)

        if duplicate_axes_ids:
            raise ValueError(f"Duplicate axis ids: {duplicate_axes_ids}")

        return axes

    test_tensor: FAIR[Optional[FileDescr_]] = None
    """An example tensor to use for testing.
    Using the model with the test input tensors is expected to yield the test output tensors.
    Each test tensor has to be an ndarray in the
    [numpy.lib file format](https://numpy.org/doc/stable/reference/generated/numpy.lib.format.html#module-numpy.lib.format).
    The file extension must be '.npy'."""

    sample_tensor: FAIR[Optional[FileDescr_]] = None
    """A sample tensor to illustrate a possible input/output for the model.
    The sample image primarily serves to inform a human user about an example use case
    and is typically stored as .hdf5, .png or .tiff.
    It has to be readable by the [imageio library](https://imageio.readthedocs.io/en/stable/formats/index.html#supported-formats)
    (numpy's `.npy` format is not supported).
    The image dimensionality has to match the number of axes specified in this tensor description.
    """

    @model_validator(mode="after")
    def _validate_sample_tensor(self) -> Self:
        if self.sample_tensor is None or not get_validation_context().perform_io_checks:
            return self

        reader = get_reader(self.sample_tensor.source, sha256=self.sample_tensor.sha256)
        tensor: NDArray[Any] = imread(  # pyright: ignore[reportUnknownVariableType]
            reader.read(),
            extension=PurePosixPath(reader.original_file_name).suffix,
        )
        n_dims = len(tensor.squeeze().shape)
        n_dims_min = n_dims_max = len(self.axes)

        for a in self.axes:
            if isinstance(a, BatchAxis):
                n_dims_min -= 1
            elif isinstance(a.size, int):
                if a.size == 1:
                    n_dims_min -= 1
            elif isinstance(a.size, (ParameterizedSize, DataDependentSize)):
                if a.size.min == 1:
                    n_dims_min -= 1
            elif isinstance(a.size, SizeReference):
                if a.size.offset < 2:
                    # size reference may result in a singleton axis
                    n_dims_min -= 1
            else:
                assert_never(a.size)

        n_dims_min = max(0, n_dims_min)
        if n_dims < n_dims_min or n_dims > n_dims_max:
            raise ValueError(
                f"Expected sample tensor to have {n_dims_min} to"
                + f" {n_dims_max} dimensions, but found {n_dims} (shape: {tensor.shape})."
            )

        return self

    data: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]] = (
        IntervalOrRatioDataDescr()
    )
    """Description of the tensor's data values, optionally per channel.
    If specified per channel, the data `type` needs to match across channels."""

    @property
    def dtype(
        self,
    ) -> Literal[
        "float32",
        "float64",
        "uint8",
        "int8",
        "uint16",
        "int16",
        "uint32",
        "int32",
        "uint64",
        "int64",
        "bool",
    ]:
        """dtype as specified under `data.type` or `data[i].type`"""
        if isinstance(self.data, collections.abc.Sequence):
            return self.data[0].type
        else:
            return self.data.type

    @field_validator("data", mode="after")
    @classmethod
    def _check_data_type_across_channels(
        cls, value: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]]
    ) -> Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]]:
        if not isinstance(value, list):
            return value

        dtypes = {t.type for t in value}
        if len(dtypes) > 1:
            raise ValueError(
                "Tensor data descriptions per channel need to agree in their data"
                + f" `type`, but found {dtypes}."
            )

        return value

    @model_validator(mode="after")
    def _check_data_matches_channelaxis(self) -> Self:
        if not isinstance(self.data, (list, tuple)):
            return self

        for a in self.axes:
            if isinstance(a, ChannelAxis):
                size = a.size
                assert isinstance(size, int)
                break
        else:
            return self

        if len(self.data) != size:
            raise ValueError(
                f"Got tensor data descriptions for {len(self.data)} channels, but"
                + f" '{a.id}' axis has size {size}."
            )

        return self

    def get_axis_sizes_for_array(self, array: NDArray[Any]) -> Dict[AxisId, int]:
        if len(array.shape) != len(self.axes):
            raise ValueError(
                f"Dimension mismatch: array shape {array.shape} (#{len(array.shape)})"
                + f" incompatible with {len(self.axes)} axes."
            )
        return {a.id: array.shape[i] for i, a in enumerate(self.axes)}

class InputTensorDescr(TensorDescrBase[InputAxis]):
    id: TensorId = TensorId("input")
    """Input tensor id.
    No duplicates are allowed across all inputs and outputs."""

    optional: bool = False
    """indicates that this tensor may be `None`"""

    preprocessing: List[PreprocessingDescr] = Field(
        default_factory=cast(Callable[[], List[PreprocessingDescr]], list)
    )
    """Description of how this input should be preprocessed.

    notes:
    - If preprocessing does not start with an 'ensure_dtype' entry, it is added
      to ensure an input tensor's data type matches the input tensor's data description.
    - If preprocessing does not end with an 'ensure_dtype' or 'binarize' entry, an
      'ensure_dtype' step is added to ensure preprocessing steps are not unintentionally
      changing the data type.
    """

    @model_validator(mode="after")
    def _validate_preprocessing_kwargs(self) -> Self:
        axes_ids = [a.id for a in self.axes]
        for p in self.preprocessing:
            kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes")
            if kwargs_axes is None:
                continue

            if not isinstance(kwargs_axes, collections.abc.Sequence):
                raise ValueError(
                    f"Expected `preprocessing.i.kwargs.axes` to be a sequence, but got {type(kwargs_axes)}"
                )

            if any(a not in axes_ids for a in kwargs_axes):
                raise ValueError(
                    "`preprocessing.i.kwargs.axes` needs to be a subset of the axes ids"
                )

        if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)):
            dtype = self.data.type
        else:
            dtype = self.data[0].type

        # ensure `preprocessing` begins with `EnsureDtypeDescr`
        if not self.preprocessing or not isinstance(
            self.preprocessing[0], EnsureDtypeDescr
        ):
            self.preprocessing.insert(
                0, EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype))
            )

        # ensure `preprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr`
        if not isinstance(self.preprocessing[-1], (EnsureDtypeDescr, BinarizeDescr)):
            self.preprocessing.append(
                EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype))
            )

        return self
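
# Example (sketch): a minimal input tensor description and how an array's shape
# maps onto its axes (test/sample tensors are omitted here; validating a full
# model description would also check those files):
#
#     >>> descr = InputTensorDescr(
#     ...     id=TensorId("raw"),
#     ...     axes=[
#     ...         BatchAxis(),
#     ...         ChannelAxis(channel_names=[Identifier("c0")]),
#     ...         SpaceInputAxis(id=AxisId("y"), size=64),
#     ...         SpaceInputAxis(id=AxisId("x"), size=64),
#     ...     ],
#     ... )
#     >>> sizes = descr.get_axis_sizes_for_array(np.zeros((1, 1, 64, 64)))
#     >>> [int(s) for s in sizes.values()]
#     [1, 1, 64, 64]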

def convert_axes(
    axes: str,
    *,
    shape: Union[
        Sequence[int], _ParameterizedInputShape_v0_4, _ImplicitOutputShape_v0_4
    ],
    tensor_type: Literal["input", "output"],
    halo: Optional[Sequence[int]],
    size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]],
):
    ret: List[AnyAxis] = []
    for i, a in enumerate(axes):
        axis_type = _AXIS_TYPE_MAP.get(a, a)
        if axis_type == "batch":
            ret.append(BatchAxis())
            continue

        scale = 1.0
        if isinstance(shape, _ParameterizedInputShape_v0_4):
            if shape.step[i] == 0:
                size = shape.min[i]
            else:
                size = ParameterizedSize(min=shape.min[i], step=shape.step[i])
        elif isinstance(shape, _ImplicitOutputShape_v0_4):
            ref_t = str(shape.reference_tensor)
            if ref_t.count(".") == 1:
                t_id, orig_a_id = ref_t.split(".")
            else:
                t_id = ref_t
                orig_a_id = a

            a_id = _AXIS_ID_MAP.get(orig_a_id, a)
            if not (orig_scale := shape.scale[i]):
                # old way to insert a new axis dimension
                size = int(2 * shape.offset[i])
            else:
                scale = 1 / orig_scale
                if axis_type in ("channel", "index"):
                    # these axes no longer have a scale
                    offset_from_scale = orig_scale * size_refs.get(
                        _TensorName_v0_4(t_id), {}
                    ).get(orig_a_id, 0)
                else:
                    offset_from_scale = 0
                size = SizeReference(
                    tensor_id=TensorId(t_id),
                    axis_id=AxisId(a_id),
                    offset=int(offset_from_scale + 2 * shape.offset[i]),
                )
        else:
            size = shape[i]

        if axis_type == "time":
            if tensor_type == "input":
                ret.append(TimeInputAxis(size=size, scale=scale))
            else:
                assert not isinstance(size, ParameterizedSize)
                if halo is None:
                    ret.append(TimeOutputAxis(size=size, scale=scale))
                else:
                    assert not isinstance(size, int)
                    ret.append(
                        TimeOutputAxisWithHalo(size=size, scale=scale, halo=halo[i])
                    )

        elif axis_type == "index":
            if tensor_type == "input":
                ret.append(IndexInputAxis(size=size))
            else:
                if isinstance(size, ParameterizedSize):
                    size = DataDependentSize(min=size.min)

                ret.append(IndexOutputAxis(size=size))
        elif axis_type == "channel":
            assert not isinstance(size, ParameterizedSize)
            if isinstance(size, SizeReference):
                warnings.warn(
                    "Conversion of channel size from an implicit output shape may be"
                    + " wrong"
                )
                ret.append(
                    ChannelAxis(
                        channel_names=[
                            Identifier(f"channel{i}") for i in range(size.offset)
                        ]
                    )
                )
            else:
                ret.append(
                    ChannelAxis(
                        channel_names=[Identifier(f"channel{i}") for i in range(size)]
                    )
                )
        elif axis_type == "space":
            if tensor_type == "input":
                ret.append(SpaceInputAxis(id=AxisId(a), size=size, scale=scale))
            else:
                assert not isinstance(size, ParameterizedSize)
                if halo is None or halo[i] == 0:
                    ret.append(SpaceOutputAxis(id=AxisId(a), size=size, scale=scale))
                elif isinstance(size, int):
                    raise NotImplementedError(
                        f"output axis with halo and fixed size (here {size}) not allowed"
                    )
                else:
                    ret.append(
                        SpaceOutputAxisWithHalo(
                            id=AxisId(a), size=size, scale=scale, halo=halo[i]
                        )
                    )

    return ret
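
# Example (sketch): converting a v0.4 axes string with a fixed shape into v0.5
# axis descriptions:
#
#     >>> converted = convert_axes(
#     ...     "bcyx", shape=[1, 3, 64, 64], tensor_type="input", halo=None, size_refs={}
#     ... )
#     >>> [a.type for a in converted]
#     ['batch', 'channel', 'space', 'space']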

def _axes_letters_to_ids(
    axes: Optional[str],
) -> Optional[List[AxisId]]:
    if axes is None:
        return None

    return [AxisId(a) for a in axes]


def _get_complement_v04_axis(
    tensor_axes: Sequence[str], axes: Optional[Sequence[str]]
) -> Optional[AxisId]:
    if axes is None:
        return None

    non_complement_axes = set(axes) | {"b"}
    complement_axes = [a for a in tensor_axes if a not in non_complement_axes]
    if len(complement_axes) > 1:
        raise ValueError(
            f"Expected none or a single complement axis, but axes '{axes}' "
            + f"for tensor dims '{tensor_axes}' leave '{complement_axes}'."
        )

    return None if not complement_axes else AxisId(complement_axes[0])
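
# Example (sketch): given tensor dims "bcyx" and kwargs axes "xy", the batch
# axis is excluded and the single remaining axis "c" is the complement
# (normalized to "channel" by `AxisId`):
#
#     >>> str(_get_complement_v04_axis("bcyx", "xy"))
#     'channel'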

def _convert_proc(
    p: Union[_PreprocessingDescr_v0_4, _PostprocessingDescr_v0_4],
    tensor_axes: Sequence[str],
) -> Union[PreprocessingDescr, PostprocessingDescr]:
    if isinstance(p, _BinarizeDescr_v0_4):
        return BinarizeDescr(kwargs=BinarizeKwargs(threshold=p.kwargs.threshold))
    elif isinstance(p, _ClipDescr_v0_4):
        return ClipDescr(kwargs=ClipKwargs(min=p.kwargs.min, max=p.kwargs.max))
    elif isinstance(p, _SigmoidDescr_v0_4):
        return SigmoidDescr()
    elif isinstance(p, _ScaleLinearDescr_v0_4):
        axes = _axes_letters_to_ids(p.kwargs.axes)
        if p.kwargs.axes is None:
            axis = None
        else:
            axis = _get_complement_v04_axis(tensor_axes, p.kwargs.axes)

        if axis is None:
            assert not isinstance(p.kwargs.gain, list)
            assert not isinstance(p.kwargs.offset, list)
            kwargs = ScaleLinearKwargs(gain=p.kwargs.gain, offset=p.kwargs.offset)
        else:
            kwargs = ScaleLinearAlongAxisKwargs(
                axis=axis, gain=p.kwargs.gain, offset=p.kwargs.offset
            )
        return ScaleLinearDescr(kwargs=kwargs)
    elif isinstance(p, _ScaleMeanVarianceDescr_v0_4):
        return ScaleMeanVarianceDescr(
            kwargs=ScaleMeanVarianceKwargs(
                axes=_axes_letters_to_ids(p.kwargs.axes),
                reference_tensor=TensorId(str(p.kwargs.reference_tensor)),
                eps=p.kwargs.eps,
            )
        )
    elif isinstance(p, _ZeroMeanUnitVarianceDescr_v0_4):
        if p.kwargs.mode == "fixed":
            mean = p.kwargs.mean
            std = p.kwargs.std
            assert mean is not None
            assert std is not None

            axis = _get_complement_v04_axis(tensor_axes, p.kwargs.axes)

            if axis is None:
                if isinstance(mean, list):
                    raise ValueError("Expected single float value for mean, not <list>")
                if isinstance(std, list):
                    raise ValueError("Expected single float value for std, not <list>")
                return FixedZeroMeanUnitVarianceDescr(
                    kwargs=FixedZeroMeanUnitVarianceKwargs.model_construct(
                        mean=mean,
                        std=std,
                    )
                )
            else:
                if not isinstance(mean, list):
                    mean = [float(mean)]
                if not isinstance(std, list):
                    std = [float(std)]

                return FixedZeroMeanUnitVarianceDescr(
                    kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs(
                        axis=axis, mean=mean, std=std
                    )
                )

        else:
            axes = _axes_letters_to_ids(p.kwargs.axes) or []
            if p.kwargs.mode == "per_dataset":
                axes = [AxisId("batch")] + axes
            if not axes:
                axes = None
            return ZeroMeanUnitVarianceDescr(
                kwargs=ZeroMeanUnitVarianceKwargs(axes=axes, eps=p.kwargs.eps)
            )

    elif isinstance(p, _ScaleRangeDescr_v0_4):
        return ScaleRangeDescr(
            kwargs=ScaleRangeKwargs(
                axes=_axes_letters_to_ids(p.kwargs.axes),
                min_percentile=p.kwargs.min_percentile,
                max_percentile=p.kwargs.max_percentile,
                eps=p.kwargs.eps,
            )
        )
    else:
        assert_never(p)
2054 class _InputTensorConv(
2055 Converter[
2056 _InputTensorDescr_v0_4,
2057 InputTensorDescr,
2058 FileSource_,
2059 Optional[FileSource_],
2060 Mapping[_TensorName_v0_4, Mapping[str, int]],
2061 ]
2062):
2063 def _convert(
2064 self,
2065 src: _InputTensorDescr_v0_4,
2066 tgt: "type[InputTensorDescr] | type[dict[str, Any]]",
2067 test_tensor: FileSource_,
2068 sample_tensor: Optional[FileSource_],
2069 size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]],
2070 ) -> "InputTensorDescr | dict[str, Any]":
2071 axes: List[InputAxis] = convert_axes( # pyright: ignore[reportAssignmentType]
2072 src.axes,
2073 shape=src.shape,
2074 tensor_type="input",
2075 halo=None,
2076 size_refs=size_refs,
2077 )
2078 prep: List[PreprocessingDescr] = []
2079 for p in src.preprocessing:
2080 cp = _convert_proc(p, src.axes)
2081 assert not isinstance(cp, ScaleMeanVarianceDescr)
2082 prep.append(cp)
2084 prep.append(EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="float32")))
2086 return tgt(
2087 axes=axes,
2088 id=TensorId(str(src.name)),
2089 test_tensor=FileDescr(source=test_tensor),
2090 sample_tensor=(
2091 None if sample_tensor is None else FileDescr(source=sample_tensor)
2092 ),
2093 data=dict(type=src.data_type), # pyright: ignore[reportArgumentType]
2094 preprocessing=prep,
2095 )
2098 _input_tensor_conv = _InputTensorConv(_InputTensorDescr_v0_4, InputTensorDescr)
2101 class OutputTensorDescr(TensorDescrBase[OutputAxis]):
2102 id: TensorId = TensorId("output")
2103 """Output tensor id.
2104 No duplicates are allowed across all inputs and outputs."""
2106 postprocessing: List[PostprocessingDescr] = Field(
2107 default_factory=cast(Callable[[], List[PostprocessingDescr]], list)
2108 )
2109 """Description of how this output should be postprocessed.
2111 Note: `postprocessing` always ends with an 'ensure_dtype' operation.
2112 If not given, one is added automatically to cast to this tensor's `data.type`.
2113 """
2115 @model_validator(mode="after")
2116 def _validate_postprocessing_kwargs(self) -> Self:
2117 axes_ids = [a.id for a in self.axes]
2118 for p in self.postprocessing:
2119 kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes")
2120 if kwargs_axes is None:
2121 continue
2123 if not isinstance(kwargs_axes, collections.abc.Sequence):
2124 raise ValueError(
2125 f"expected `axes` sequence, but got {type(kwargs_axes)}"
2126 )
2128 if any(a not in axes_ids for a in kwargs_axes):
2129 raise ValueError("`kwargs.axes` needs to be subset of axes ids")
2131 if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)):
2132 dtype = self.data.type
2133 else:
2134 dtype = self.data[0].type
2136 # ensure `postprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr`
2137 if not self.postprocessing or not isinstance(
2138 self.postprocessing[-1], (EnsureDtypeDescr, BinarizeDescr)
2139 ):
2140 self.postprocessing.append(
2141 EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype))
2142 )
2143 return self
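# Illustrative sketch (hypothetical helper, not part of the spec API): the
# rule enforced above, mirrored on a plain-dict postprocessing chain.
def _ensure_trailing_dtype_cast(postprocessing, dtype: str):
    # append an 'ensure_dtype' step unless the chain already ends with
    # 'ensure_dtype' or 'binarize'
    if not postprocessing or postprocessing[-1]["id"] not in ("ensure_dtype", "binarize"):
        postprocessing.append({"id": "ensure_dtype", "kwargs": {"dtype": dtype}})
    return postprocessing

assert _ensure_trailing_dtype_cast([], "uint8")[-1]["id"] == "ensure_dtype"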
2146 class _OutputTensorConv(
2147 Converter[
2148 _OutputTensorDescr_v0_4,
2149 OutputTensorDescr,
2150 FileSource_,
2151 Optional[FileSource_],
2152 Mapping[_TensorName_v0_4, Mapping[str, int]],
2153 ]
2154):
2155 def _convert(
2156 self,
2157 src: _OutputTensorDescr_v0_4,
2158 tgt: "type[OutputTensorDescr] | type[dict[str, Any]]",
2159 test_tensor: FileSource_,
2160 sample_tensor: Optional[FileSource_],
2161 size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]],
2162 ) -> "OutputTensorDescr | dict[str, Any]":
2163 # TODO: split convert_axes into convert_output_axes and convert_input_axes
2164 axes: List[OutputAxis] = convert_axes( # pyright: ignore[reportAssignmentType]
2165 src.axes,
2166 shape=src.shape,
2167 tensor_type="output",
2168 halo=src.halo,
2169 size_refs=size_refs,
2170 )
2171 data_descr: Dict[str, Any] = dict(type=src.data_type)
2172 if data_descr["type"] == "bool":
2173 data_descr["values"] = [False, True]
2175 return tgt(
2176 axes=axes,
2177 id=TensorId(str(src.name)),
2178 test_tensor=FileDescr(source=test_tensor),
2179 sample_tensor=(
2180 None if sample_tensor is None else FileDescr(source=sample_tensor)
2181 ),
2182 data=data_descr, # pyright: ignore[reportArgumentType]
2183 postprocessing=[_convert_proc(p, src.axes) for p in src.postprocessing],
2184 )
2187 _output_tensor_conv = _OutputTensorConv(_OutputTensorDescr_v0_4, OutputTensorDescr)
2190 TensorDescr = Union[InputTensorDescr, OutputTensorDescr]
2193 def validate_tensors(
2194 tensors: Mapping[TensorId, Tuple[TensorDescr, Optional[NDArray[Any]]]],
2195 tensor_origin: Literal[
2196 "test_tensor"
2197 ], # for more precise error messages, e.g. 'test_tensor'
2198):
2199 all_tensor_axes: Dict[TensorId, Dict[AxisId, Tuple[AnyAxis, Optional[int]]]] = {}
2201 def e_msg(d: TensorDescr):
2202 return f"{'inputs' if isinstance(d, InputTensorDescr) else 'outputs'}[{d.id}]"
2204 for descr, array in tensors.values():
2205 if array is None:
2206 axis_sizes = {a.id: None for a in descr.axes}
2207 else:
2208 try:
2209 axis_sizes = descr.get_axis_sizes_for_array(array)
2210 except ValueError as e:
2211 raise ValueError(f"{e_msg(descr)} {e}")
2213 all_tensor_axes[descr.id] = {a.id: (a, axis_sizes[a.id]) for a in descr.axes}
2215 for descr, array in tensors.values():
2216 if array is None:
2217 continue
2219 if descr.dtype in ("float32", "float64"):
2220 invalid_test_tensor_dtype = array.dtype.name not in (
2221 "float32",
2222 "float64",
2223 "uint8",
2224 "int8",
2225 "uint16",
2226 "int16",
2227 "uint32",
2228 "int32",
2229 "uint64",
2230 "int64",
2231 )
2232 else:
2233 invalid_test_tensor_dtype = array.dtype.name != descr.dtype
2235 if invalid_test_tensor_dtype:
2236 raise ValueError(
2237 f"{e_msg(descr)}.{tensor_origin}.dtype '{array.dtype.name}' does not"
2238 + f" match described dtype '{descr.dtype}'"
2239 )
2241 if array.min() > -1e-4 and array.max() < 1e-4:
2242 raise ValueError(
2243 "Output values are too small for reliable testing."
2244 + f" Values <= -1e-4 or >= 1e-4 must be present in {tensor_origin}"
2245 )
2247 for a in descr.axes:
2248 actual_size = all_tensor_axes[descr.id][a.id][1]
2249 if actual_size is None:
2250 continue
2252 if a.size is None:
2253 continue
2255 if isinstance(a.size, int):
2256 if actual_size != a.size:
2257 raise ValueError(
2258 f"{e_msg(descr)}.{tensor_origin}: axis '{a.id}' "
2259 + f"has incompatible size {actual_size}, expected {a.size}"
2260 )
2261 elif isinstance(a.size, ParameterizedSize):
2262 _ = a.size.validate_size(actual_size)
2263 elif isinstance(a.size, DataDependentSize):
2264 _ = a.size.validate_size(actual_size)
2265 elif isinstance(a.size, SizeReference):
2266 ref_tensor_axes = all_tensor_axes.get(a.size.tensor_id)
2267 if ref_tensor_axes is None:
2268 raise ValueError(
2269 f"{e_msg(descr)}.axes[{a.id}].size.tensor_id: Unknown tensor"
2270 + f" reference '{a.size.tensor_id}'"
2271 )
2273 ref_axis, ref_size = ref_tensor_axes.get(a.size.axis_id, (None, None))
2274 if ref_axis is None or ref_size is None:
2275 raise ValueError(
2276 f"{e_msg(descr)}.axes[{a.id}].size.axis_id: Unknown tensor axis"
2277 + f" reference '{a.size.tensor_id}.{a.size.axis_id}'"
2278 )
2280 if a.unit != ref_axis.unit:
2281 raise ValueError(
2282 f"{e_msg(descr)}.axes[{a.id}].size: `SizeReference` requires"
2283 + " axis and reference axis to have the same `unit`, but"
2284 + f" {a.unit}!={ref_axis.unit}"
2285 )
2287 if actual_size != (
2288 expected_size := (
2289 ref_size * ref_axis.scale / a.scale + a.size.offset
2290 )
2291 ):
2292 raise ValueError(
2293 f"{e_msg(descr)}.{tensor_origin}: axis '{a.id}' of size"
2294 + f" {actual_size} invalid for referenced size {ref_size};"
2295 + f" expected {expected_size}"
2296 )
2297 else:
2298 assert_never(a.size)
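# Worked example (hypothetical helper and numbers) of the `SizeReference`
# consistency check above: expected = ref_size * ref_axis.scale / axis.scale + offset
def _expected_referenced_size(ref_size: int, ref_scale: float, scale: float, offset: int):
    return ref_size * ref_scale / scale + offset

# an output axis at half the input resolution (scale 2 vs. 1), cropped by 4:
assert _expected_referenced_size(256, ref_scale=1.0, scale=2.0, offset=-4) == 124.0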
2301 FileDescr_dependencies = Annotated[
2302 FileDescr_,
2303 WithSuffix((".yaml", ".yml"), case_sensitive=True),
2304 Field(examples=[dict(source="environment.yaml")]),
2305]
2308 class _ArchitectureCallableDescr(Node):
2309 callable: Annotated[Identifier, Field(examples=["MyNetworkClass", "get_my_model"])]
2310 """Identifier of the callable that returns a torch.nn.Module instance."""
2312 kwargs: Dict[str, YamlValue] = Field(
2313 default_factory=cast(Callable[[], Dict[str, YamlValue]], dict)
2314 )
2315 """key word arguments for the `callable`"""
2318 class ArchitectureFromFileDescr(_ArchitectureCallableDescr, FileDescr):
2319 source: Annotated[FileSource, AfterValidator(wo_special_file_name)]
2320 """Architecture source file"""
2322 @model_serializer(mode="wrap", when_used="unless-none")
2323 def _serialize(self, nxt: SerializerFunctionWrapHandler, info: SerializationInfo):
2324 return package_file_descr_serializer(self, nxt, info)
2327 class ArchitectureFromLibraryDescr(_ArchitectureCallableDescr):
2328 import_from: str
2329 """Where to import the callable from, i.e. `from <import_from> import <callable>`"""
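# Minimal usage sketch; `mypackage.nets.MyNetworkClass` is a hypothetical
# callable that would be imported as `from mypackage.nets import MyNetworkClass`.
_example_arch_from_library = ArchitectureFromLibraryDescr(
    import_from="mypackage.nets",
    callable=Identifier("MyNetworkClass"),
    kwargs={"num_classes": 2},
)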
2332 class _ArchFileConv(
2333 Converter[
2334 _CallableFromFile_v0_4,
2335 ArchitectureFromFileDescr,
2336 Optional[Sha256],
2337 Dict[str, Any],
2338 ]
2339):
2340 def _convert(
2341 self,
2342 src: _CallableFromFile_v0_4,
2343 tgt: "type[ArchitectureFromFileDescr | dict[str, Any]]",
2344 sha256: Optional[Sha256],
2345 kwargs: Dict[str, Any],
2346 ) -> "ArchitectureFromFileDescr | dict[str, Any]":
2347 if src.startswith("http") and src.count(":") == 2:
2348 http, source, callable_ = src.split(":")
2349 source = ":".join((http, source))
2350 elif not src.startswith("http") and src.count(":") == 1:
2351 source, callable_ = src.split(":")
2352 else:
2353 source = str(src)
2354 callable_ = str(src)
2355 return tgt(
2356 callable=Identifier(callable_),
2357 source=cast(FileSource_, source),
2358 sha256=sha256,
2359 kwargs=kwargs,
2360 )
2363 _arch_file_conv = _ArchFileConv(_CallableFromFile_v0_4, ArchitectureFromFileDescr)
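# Sketch (hypothetical helper) of the "<source>:<callable>" parsing above;
# an http(s) source itself contains one ':' of its own.
def _split_v04_callable_source(src: str):
    if src.startswith("http") and src.count(":") == 2:
        scheme, rest, callable_ = src.split(":")
        return ":".join((scheme, rest)), callable_
    if not src.startswith("http") and src.count(":") == 1:
        source, callable_ = src.split(":")
        return source, callable_
    return src, src

assert _split_v04_callable_source("https://example.com/net.py:MyNet") == (
    "https://example.com/net.py",
    "MyNet",
)
assert _split_v04_callable_source("net.py:MyNet") == ("net.py", "MyNet")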
2366 class _ArchLibConv(
2367 Converter[
2368 _CallableFromDepencency_v0_4, ArchitectureFromLibraryDescr, Dict[str, Any]
2369 ]
2370):
2371 def _convert(
2372 self,
2373 src: _CallableFromDepencency_v0_4,
2374 tgt: "type[ArchitectureFromLibraryDescr | dict[str, Any]]",
2375 kwargs: Dict[str, Any],
2376 ) -> "ArchitectureFromLibraryDescr | dict[str, Any]":
2377 *mods, callable_ = src.split(".")
2378 import_from = ".".join(mods)
2379 return tgt(
2380 import_from=import_from, callable=Identifier(callable_), kwargs=kwargs
2381 )
2384 _arch_lib_conv = _ArchLibConv(
2385 _CallableFromDepencency_v0_4, ArchitectureFromLibraryDescr
2386)
2389 class WeightsEntryDescrBase(FileDescr):
2390 type: ClassVar[WeightsFormat]
2391 weights_format_name: ClassVar[str] # human readable
2393 source: Annotated[FileSource, AfterValidator(wo_special_file_name)]
2394 """Source of the weights file."""
2396 authors: Optional[List[Author]] = None
2397 """Authors
2398 Either the person(s) that trained this model, resulting in the original weights file
2399 (if this is the initial weights entry, i.e. it does not have a `parent`),
2400 or the person(s) who converted the weights to this weights format
2401 (if this is a child weights entry, i.e. it has a `parent` field).
2402 """
2404 parent: Annotated[
2405 Optional[WeightsFormat], Field(examples=["pytorch_state_dict"])
2406 ] = None
2407 """The source weights these weights were converted from.
2408 For example, if a model's weights were converted from the `pytorch_state_dict` format to `torchscript`,
2409 the `pytorch_state_dict` weights entry has no `parent` and is the parent of the `torchscript` weights.
2410 All weights entries except one (the initial set of weights resulting from training the model)
2411 need to have this field."""
2413 comment: str = ""
2414 """A comment about this weights entry, for example how these weights were created."""
2416 @model_validator(mode="after")
2417 def _validate(self) -> Self:
2418 if self.type == self.parent:
2419 raise ValueError("A weights entry cannot be its own parent.")
2421 return self
2423 @model_serializer(mode="wrap", when_used="unless-none")
2424 def _serialize(self, nxt: SerializerFunctionWrapHandler, info: SerializationInfo):
2425 return package_file_descr_serializer(self, nxt, info)
2428 class KerasHdf5WeightsDescr(WeightsEntryDescrBase):
2429 type: ClassVar[WeightsFormat] = "keras_hdf5"
2430 weights_format_name: ClassVar[str] = "Keras HDF5"
2431 tensorflow_version: Version
2432 """TensorFlow version used to create these weights."""
2435 FileDescr_external_data = Annotated[
2436 FileDescr_,
2437 WithSuffix(".data", case_sensitive=True),
2438 Field(examples=[dict(source="weights.onnx.data")]),
2439]
2442 class OnnxWeightsDescr(WeightsEntryDescrBase):
2443 type: ClassVar[WeightsFormat] = "onnx"
2444 weights_format_name: ClassVar[str] = "ONNX"
2445 opset_version: Annotated[int, Ge(7)]
2446 """ONNX opset version"""
2448 external_data: Optional[FileDescr_external_data] = None
2449 """Source of the external ONNX data file holding the weights.
2450 (If present, **source** holds the ONNX architecture without weights)."""
2452 @model_validator(mode="after")
2453 def _validate_external_data_unique_file_name(self) -> Self:
2454 if self.external_data is not None and (
2455 extract_file_name(self.source)
2456 == extract_file_name(self.external_data.source)
2457 ):
2458 raise ValueError(
2459 f"ONNX `external_data` file name '{extract_file_name(self.external_data.source)}'"
2460 + " must be different from ONNX `source` file name."
2461 )
2463 return self
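# Sketch of an `onnx` weights entry with external data (hypothetical file
# names, following the field definitions above); the two file names must differ:
_example_onnx_entry = {
    "source": "weights.onnx",  # graph with the weights stripped out
    "opset_version": 15,
    "external_data": {"source": "weights.onnx.data"},  # the actual weights
}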
2466 class PytorchStateDictWeightsDescr(WeightsEntryDescrBase):
2467 type: ClassVar[WeightsFormat] = "pytorch_state_dict"
2468 weights_format_name: ClassVar[str] = "Pytorch State Dict"
2469 architecture: Union[ArchitectureFromFileDescr, ArchitectureFromLibraryDescr]
2470 pytorch_version: Version
2471 """Version of the PyTorch library used.
2472 If `architecture.dependencies` is specified, it has to include pytorch and any version pinning has to be compatible.
2473 """
2474 dependencies: Optional[FileDescr_dependencies] = None
2475 """Custom dependencies beyond pytorch described in a Conda environment file.
2476 Allows specifying custom dependencies; see the conda docs:
2477 - [Exporting an environment file across platforms](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#exporting-an-environment-file-across-platforms)
2478 - [Creating an environment file manually](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#creating-an-environment-file-manually)
2480 The conda environment file should include pytorch and any version pinning has to be compatible with
2481 **pytorch_version**.
2482 """
2485 class TensorflowJsWeightsDescr(WeightsEntryDescrBase):
2486 type: ClassVar[WeightsFormat] = "tensorflow_js"
2487 weights_format_name: ClassVar[str] = "Tensorflow.js"
2488 tensorflow_version: Version
2489 """Version of the TensorFlow library used."""
2491 source: Annotated[FileSource, AfterValidator(wo_special_file_name)]
2492 """The multi-file weights.
2493 All required files/folders should be bundled in a zip archive."""
2496 class TensorflowSavedModelBundleWeightsDescr(WeightsEntryDescrBase):
2497 type: ClassVar[WeightsFormat] = "tensorflow_saved_model_bundle"
2498 weights_format_name: ClassVar[str] = "Tensorflow Saved Model"
2499 tensorflow_version: Version
2500 """Version of the TensorFlow library used."""
2502 dependencies: Optional[FileDescr_dependencies] = None
2503 """Custom dependencies beyond tensorflow.
2504 Should include tensorflow, and any version pinning has to be compatible with **tensorflow_version**."""
2506 source: Annotated[FileSource, AfterValidator(wo_special_file_name)]
2507 """The multi-file weights.
2508 All required files/folders should be bundled in a zip archive."""
2511 class TorchscriptWeightsDescr(WeightsEntryDescrBase):
2512 type: ClassVar[WeightsFormat] = "torchscript"
2513 weights_format_name: ClassVar[str] = "TorchScript"
2514 pytorch_version: Version
2515 """Version of the PyTorch library used."""
2518 SpecificWeightsDescr = Union[
2519 KerasHdf5WeightsDescr,
2520 OnnxWeightsDescr,
2521 PytorchStateDictWeightsDescr,
2522 TensorflowJsWeightsDescr,
2523 TensorflowSavedModelBundleWeightsDescr,
2524 TorchscriptWeightsDescr,
2525]
2528 class WeightsDescr(Node):
2529 keras_hdf5: Optional[KerasHdf5WeightsDescr] = None
2530 onnx: Optional[OnnxWeightsDescr] = None
2531 pytorch_state_dict: Optional[PytorchStateDictWeightsDescr] = None
2532 tensorflow_js: Optional[TensorflowJsWeightsDescr] = None
2533 tensorflow_saved_model_bundle: Optional[TensorflowSavedModelBundleWeightsDescr] = (
2534 None
2535 )
2536 torchscript: Optional[TorchscriptWeightsDescr] = None
2538 @model_validator(mode="after")
2539 def check_entries(self) -> Self:
2540 entries = {wtype for wtype, entry in self if entry is not None}
2542 if not entries:
2543 raise ValueError("Missing weights entry")
2545 entries_wo_parent = {
2546 wtype
2547 for wtype, entry in self
2548 if entry is not None and hasattr(entry, "parent") and entry.parent is None
2549 }
2550 if len(entries_wo_parent) != 1:
2551 issue_warning(
2552 "Exactly one weights entry must not specify the `parent` field (got"
2553 + " {value}). That entry is considered the original set of model weights."
2554 + " Other weights formats are created through conversion of the original or"
2555 + " already converted weights. They have to reference the weights format"
2556 + " they were converted from as their `parent`.",
2557 value=len(entries_wo_parent),
2558 field="weights",
2559 )
2561 for wtype, entry in self:
2562 if entry is None:
2563 continue
2565 assert hasattr(entry, "type")
2566 assert hasattr(entry, "parent")
2567 assert wtype == entry.type
2568 if (
2569 entry.parent is not None and entry.parent not in entries
2570 ): # self reference checked for `parent` field
2571 raise ValueError(
2572 f"`weights.{wtype}.parent={entry.parent}` not in specified weight"
2573 + f" formats: {entries}"
2574 )
2576 return self
2578 def __getitem__(
2579 self,
2580 key: Literal[
2581 "keras_hdf5",
2582 "onnx",
2583 "pytorch_state_dict",
2584 "tensorflow_js",
2585 "tensorflow_saved_model_bundle",
2586 "torchscript",
2587 ],
2588 ):
2589 if key == "keras_hdf5":
2590 ret = self.keras_hdf5
2591 elif key == "onnx":
2592 ret = self.onnx
2593 elif key == "pytorch_state_dict":
2594 ret = self.pytorch_state_dict
2595 elif key == "tensorflow_js":
2596 ret = self.tensorflow_js
2597 elif key == "tensorflow_saved_model_bundle":
2598 ret = self.tensorflow_saved_model_bundle
2599 elif key == "torchscript":
2600 ret = self.torchscript
2601 else:
2602 raise KeyError(key)
2604 if ret is None:
2605 raise KeyError(key)
2607 return ret
2609 @overload
2610 def __setitem__(
2611 self, key: Literal["keras_hdf5"], value: Optional[KerasHdf5WeightsDescr]
2612 ) -> None: ...
2613 @overload
2614 def __setitem__(
2615 self, key: Literal["onnx"], value: Optional[OnnxWeightsDescr]
2616 ) -> None: ...
2617 @overload
2618 def __setitem__(
2619 self,
2620 key: Literal["pytorch_state_dict"],
2621 value: Optional[PytorchStateDictWeightsDescr],
2622 ) -> None: ...
2623 @overload
2624 def __setitem__(
2625 self, key: Literal["tensorflow_js"], value: Optional[TensorflowJsWeightsDescr]
2626 ) -> None: ...
2627 @overload
2628 def __setitem__(
2629 self,
2630 key: Literal["tensorflow_saved_model_bundle"],
2631 value: Optional[TensorflowSavedModelBundleWeightsDescr],
2632 ) -> None: ...
2633 @overload
2634 def __setitem__(
2635 self, key: Literal["torchscript"], value: Optional[TorchscriptWeightsDescr]
2636 ) -> None: ...
2638 def __setitem__(
2639 self,
2640 key: Literal[
2641 "keras_hdf5",
2642 "onnx",
2643 "pytorch_state_dict",
2644 "tensorflow_js",
2645 "tensorflow_saved_model_bundle",
2646 "torchscript",
2647 ],
2648 value: Optional[SpecificWeightsDescr],
2649 ):
2650 if key == "keras_hdf5":
2651 if value is not None and not isinstance(value, KerasHdf5WeightsDescr):
2652 raise TypeError(
2653 f"Expected KerasHdf5WeightsDescr or None for key 'keras_hdf5', got {type(value)}"
2654 )
2655 self.keras_hdf5 = value
2656 elif key == "onnx":
2657 if value is not None and not isinstance(value, OnnxWeightsDescr):
2658 raise TypeError(
2659 f"Expected OnnxWeightsDescr or None for key 'onnx', got {type(value)}"
2660 )
2661 self.onnx = value
2662 elif key == "pytorch_state_dict":
2663 if value is not None and not isinstance(
2664 value, PytorchStateDictWeightsDescr
2665 ):
2666 raise TypeError(
2667 f"Expected PytorchStateDictWeightsDescr or None for key 'pytorch_state_dict', got {type(value)}"
2668 )
2669 self.pytorch_state_dict = value
2670 elif key == "tensorflow_js":
2671 if value is not None and not isinstance(value, TensorflowJsWeightsDescr):
2672 raise TypeError(
2673 f"Expected TensorflowJsWeightsDescr or None for key 'tensorflow_js', got {type(value)}"
2674 )
2675 self.tensorflow_js = value
2676 elif key == "tensorflow_saved_model_bundle":
2677 if value is not None and not isinstance(
2678 value, TensorflowSavedModelBundleWeightsDescr
2679 ):
2680 raise TypeError(
2681 f"Expected TensorflowSavedModelBundleWeightsDescr or None for key 'tensorflow_saved_model_bundle', got {type(value)}"
2682 )
2683 self.tensorflow_saved_model_bundle = value
2684 elif key == "torchscript":
2685 if value is not None and not isinstance(value, TorchscriptWeightsDescr):
2686 raise TypeError(
2687 f"Expected TorchscriptWeightsDescr or None for key 'torchscript', got {type(value)}"
2688 )
2689 self.torchscript = value
2690 else:
2691 raise KeyError(key)
2693 @property
2694 def available_formats(self) -> Dict[WeightsFormat, SpecificWeightsDescr]:
2695 return {
2696 **({} if self.keras_hdf5 is None else {"keras_hdf5": self.keras_hdf5}),
2697 **({} if self.onnx is None else {"onnx": self.onnx}),
2698 **(
2699 {}
2700 if self.pytorch_state_dict is None
2701 else {"pytorch_state_dict": self.pytorch_state_dict}
2702 ),
2703 **(
2704 {}
2705 if self.tensorflow_js is None
2706 else {"tensorflow_js": self.tensorflow_js}
2707 ),
2708 **(
2709 {}
2710 if self.tensorflow_saved_model_bundle is None
2711 else {
2712 "tensorflow_saved_model_bundle": self.tensorflow_saved_model_bundle
2713 }
2714 ),
2715 **({} if self.torchscript is None else {"torchscript": self.torchscript}),
2716 }
2718 @property
2719 def missing_formats(self) -> Set[WeightsFormat]:
2720 return {
2721 wf for wf in get_args(WeightsFormat) if wf not in self.available_formats
2722 }
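# Sketch of a `weights` mapping as it could appear in a model RDF
# (hypothetical file and package names): exactly one entry, here
# `pytorch_state_dict`, has no `parent`; the converted entry references it.
_example_weights_rdf = {
    "pytorch_state_dict": {
        "source": "weights.pt",
        "architecture": {"import_from": "mypackage.nets", "callable": "MyNetworkClass"},
        "pytorch_version": "2.1",
    },
    "torchscript": {
        "source": "weights_torchscript.pt",
        "pytorch_version": "2.1",
        "parent": "pytorch_state_dict",
    },
}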
2725 class ModelId(ResourceId):
2726 pass
2729 class LinkedModel(LinkedResourceBase):
2730 """Reference to a bioimage.io model."""
2732 id: ModelId
2733 """A valid model `id` from the bioimage.io collection."""
2736 class _DataDepSize(NamedTuple):
2737 min: StrictInt
2738 max: Optional[StrictInt]
2741 class _AxisSizes(NamedTuple):
2742 """the lengths of all axes of model inputs and outputs"""
2744 inputs: Dict[Tuple[TensorId, AxisId], int]
2745 outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]]
2748 class _TensorSizes(NamedTuple):
2749 """_AxisSizes as nested dicts"""
2751 inputs: Dict[TensorId, Dict[AxisId, int]]
2752 outputs: Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]
2755 class ReproducibilityTolerance(Node, extra="allow"):
2756 """Describes what small numerical differences -- if any -- may be tolerated
2757 in the generated output when executing in different environments.
2759 A tensor element *output* is considered mismatched to the **test_tensor** if
2760 abs(*output* - **test_tensor**) > **absolute_tolerance** + **relative_tolerance** * abs(**test_tensor**).
2761 (Internally we call [numpy.testing.assert_allclose](https://numpy.org/doc/stable/reference/generated/numpy.testing.assert_allclose.html).)
2763 Motivation:
2764 For testing we can request the respective deep learning frameworks to be as
2765 reproducible as possible by setting seeds and choosing deterministic algorithms,
2766 but differences in operating systems, available hardware and installed drivers
2767 may still lead to numerical differences.
2768 """
2770 relative_tolerance: RelativeTolerance = 1e-3
2771 """Maximum relative tolerance of reproduced test tensor."""
2773 absolute_tolerance: AbsoluteTolerance = 1e-3
2774 """Maximum absolute tolerance of reproduced test tensor."""
2776 mismatched_elements_per_million: MismatchedElementsPerMillion = 100
2777 """Maximum number of mismatched elements/pixels per million to tolerate."""
2779 output_ids: Sequence[TensorId] = ()
2780 """Limits the output tensor IDs these reproducibility details apply to."""
2782 weights_formats: Sequence[WeightsFormat] = ()
2783 """Limits the weights formats these details apply to."""
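# Sketch of the mismatch criterion from the docstring above (hypothetical
# helper; `np` is numpy, imported at the top of this module):
def _mismatched_per_million(output, test_tensor, atol=1e-3, rtol=1e-3) -> int:
    mismatched = np.abs(output - test_tensor) > atol + rtol * np.abs(test_tensor)
    return int(round(float(mismatched.mean()) * 1e6))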
2786 class BioimageioConfig(Node, extra="allow"):
2787 reproducibility_tolerance: Sequence[ReproducibilityTolerance] = ()
2788 """Tolerances to allow when reproducing the model's test outputs
2789 from the model's test inputs.
2790 Only the first entry matching tensor id and weights format is considered.
2791 """
2794 class Config(Node, extra="allow"):
2795 bioimageio: BioimageioConfig = Field(
2796 default_factory=BioimageioConfig.model_construct
2797 )
2800 class ModelDescr(GenericModelDescrBase):
2801 """Specification of the fields used in a bioimage.io-compliant RDF to describe AI models with pretrained weights.
2802 These fields are typically stored in a YAML file which we call a model resource description file (model RDF).
2803 """
2805 implemented_format_version: ClassVar[Literal["0.5.7"]] = "0.5.7"
2806 if TYPE_CHECKING:
2807 format_version: Literal["0.5.7"] = "0.5.7"
2808 else:
2809 format_version: Literal["0.5.7"]
2810 """Version of the bioimage.io model description specification used.
2811 When creating a new model always use the latest micro/patch version described here.
2812 The `format_version` is important for any consumer software to understand how to parse the fields.
2813 """
2815 implemented_type: ClassVar[Literal["model"]] = "model"
2816 if TYPE_CHECKING:
2817 type: Literal["model"] = "model"
2818 else:
2819 type: Literal["model"]
2820 """Specialized resource type 'model'"""
2822 id: Optional[ModelId] = None
2823 """bioimage.io-wide unique resource identifier
2824 assigned by bioimage.io; version **un**specific."""
2826 authors: FAIR[List[Author]] = Field(
2827 default_factory=cast(Callable[[], List[Author]], list)
2828 )
2829 """The authors are the creators of the model RDF and the primary points of contact."""
2831 documentation: FAIR[Optional[FileSource_documentation]] = None
2832 """URL or relative path to a markdown file with additional documentation.
2833 The recommended documentation file name is `README.md`. An `.md` suffix is mandatory.
2834 The documentation should include a '#[#] Validation' (sub)section
2835 with details on how to quantitatively validate the model on unseen data."""
2837 @field_validator("documentation", mode="after")
2838 @classmethod
2839 def _validate_documentation(
2840 cls, value: Optional[FileSource_documentation]
2841 ) -> Optional[FileSource_documentation]:
2842 if not get_validation_context().perform_io_checks or value is None:
2843 return value
2845 doc_reader = get_reader(value)
2846 doc_content = doc_reader.read().decode(encoding="utf-8")
2847 if not re.search("#.*[vV]alidation", doc_content):
2848 issue_warning(
2849 "No '# Validation' (sub)section found in {value}.",
2850 value=value,
2851 field="documentation",
2852 )
2854 return value
2856 inputs: NotEmpty[Sequence[InputTensorDescr]]
2857 """Describes the input tensors expected by this model."""
2859 @field_validator("inputs", mode="after")
2860 @classmethod
2861 def _validate_input_axes(
2862 cls, inputs: Sequence[InputTensorDescr]
2863 ) -> Sequence[InputTensorDescr]:
2864 input_size_refs = cls._get_axes_with_independent_size(inputs)
2866 for i, ipt in enumerate(inputs):
2867 valid_independent_refs: Dict[
2868 Tuple[TensorId, AxisId],
2869 Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
2870 ] = {
2871 **{
2872 (ipt.id, a.id): (ipt, a, a.size)
2873 for a in ipt.axes
2874 if not isinstance(a, BatchAxis)
2875 and isinstance(a.size, (int, ParameterizedSize))
2876 },
2877 **input_size_refs,
2878 }
2879 for a, ax in enumerate(ipt.axes):
2880 cls._validate_axis(
2881 "inputs",
2882 i=i,
2883 tensor_id=ipt.id,
2884 a=a,
2885 axis=ax,
2886 valid_independent_refs=valid_independent_refs,
2887 )
2888 return inputs
2890 @staticmethod
2891 def _validate_axis(
2892 field_name: str,
2893 i: int,
2894 tensor_id: TensorId,
2895 a: int,
2896 axis: AnyAxis,
2897 valid_independent_refs: Dict[
2898 Tuple[TensorId, AxisId],
2899 Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
2900 ],
2901 ):
2902 if isinstance(axis, BatchAxis) or isinstance(
2903 axis.size, (int, ParameterizedSize, DataDependentSize)
2904 ):
2905 return
2906 elif not isinstance(axis.size, SizeReference):
2907 assert_never(axis.size)
2909 # validate axis.size SizeReference
2910 ref = (axis.size.tensor_id, axis.size.axis_id)
2911 if ref not in valid_independent_refs:
2912 raise ValueError(
2913 "Invalid tensor axis reference at"
2914 + f" {field_name}[{i}].axes[{a}].size: {axis.size}."
2915 )
2916 if ref == (tensor_id, axis.id):
2917 raise ValueError(
2918 "Self-referencing not allowed for"
2919 + f" {field_name}[{i}].axes[{a}].size: {axis.size}"
2920 )
2921 if axis.type == "channel":
2922 if valid_independent_refs[ref][1].type != "channel":
2923 raise ValueError(
2924 "A channel axis' size may only reference another fixed size"
2925 + " channel axis."
2926 )
2927 if isinstance(axis.channel_names, str) and "{i}" in axis.channel_names:
2928 ref_size = valid_independent_refs[ref][2]
2929 assert isinstance(ref_size, int), (
2930 "channel axis ref (another channel axis) has to specify fixed"
2931 + " size"
2932 )
2933 generated_channel_names = [
2934 Identifier(axis.channel_names.format(i=i))
2935 for i in range(1, ref_size + 1)
2936 ]
2937 axis.channel_names = generated_channel_names
2939 if (ax_unit := getattr(axis, "unit", None)) != (
2940 ref_unit := getattr(valid_independent_refs[ref][1], "unit", None)
2941 ):
2942 raise ValueError(
2943 "The units of an axis and its reference axis need to match, but"
2944 + f" '{ax_unit}' != '{ref_unit}'."
2945 )
2946 ref_axis = valid_independent_refs[ref][1]
2947 if isinstance(ref_axis, BatchAxis):
2948 raise ValueError(
2949 f"Invalid reference axis '{ref_axis.id}' for {tensor_id}.{axis.id}"
2950 + " (a batch axis is not allowed as reference)."
2951 )
2953 if isinstance(axis, WithHalo):
2954 min_size = axis.size.get_size(axis, ref_axis, n=0)
2955 if (min_size - 2 * axis.halo) < 1:
2956 raise ValueError(
2957 f"axis {axis.id} with minimum size {min_size} is too small for halo"
2958 + f" {axis.halo}."
2959 )
2961 ref_halo = axis.halo * axis.scale / ref_axis.scale
2962 if ref_halo != int(ref_halo):
2963 raise ValueError(
2964 f"Inferred halo for {'.'.join(ref)} is not an integer ({ref_halo} ="
2965 + f" {tensor_id}.{axis.id}.halo {axis.halo}"
2966 + f" * {tensor_id}.{axis.id}.scale {axis.scale}"
2967 + f" / {'.'.join(ref)}.scale {ref_axis.scale})."
2968 )
2970 @model_validator(mode="after")
2971 def _validate_test_tensors(self) -> Self:
2972 if not get_validation_context().perform_io_checks:
2973 return self
2975 test_output_arrays = [
2976 None if descr.test_tensor is None else load_array(descr.test_tensor)
2977 for descr in self.outputs
2978 ]
2979 test_input_arrays = [
2980 None if descr.test_tensor is None else load_array(descr.test_tensor)
2981 for descr in self.inputs
2982 ]
2984 tensors = {
2985 descr.id: (descr, array)
2986 for descr, array in zip(
2987 chain(self.inputs, self.outputs), test_input_arrays + test_output_arrays
2988 )
2989 }
2990 validate_tensors(tensors, tensor_origin="test_tensor")
2992 output_arrays = {
2993 descr.id: array for descr, array in zip(self.outputs, test_output_arrays)
2994 }
2995 for rep_tol in self.config.bioimageio.reproducibility_tolerance:
2996 if not rep_tol.absolute_tolerance:
2997 continue
2999 if rep_tol.output_ids:
3000 out_arrays = {
3001 oid: a
3002 for oid, a in output_arrays.items()
3003 if oid in rep_tol.output_ids
3004 }
3005 else:
3006 out_arrays = output_arrays
3008 for out_id, array in out_arrays.items():
3009 if array is None:
3010 continue
3012 if rep_tol.absolute_tolerance > (max_test_value := array.max()) * 0.01:
3013 raise ValueError(
3014 "config.bioimageio.reproducibility_tolerance.absolute_tolerance="
3015 + f"{rep_tol.absolute_tolerance} > 0.01*{max_test_value}"
3016 + f" (1% of the maximum value of the test tensor '{out_id}')"
3017 )
3019 return self
3021 @model_validator(mode="after")
3022 def _validate_tensor_references_in_proc_kwargs(self, info: ValidationInfo) -> Self:
3023 ipt_refs = {t.id for t in self.inputs}
3024 out_refs = {t.id for t in self.outputs}
3025 for ipt in self.inputs:
3026 for p in ipt.preprocessing:
3027 ref = p.kwargs.get("reference_tensor")
3028 if ref is None:
3029 continue
3030 if ref not in ipt_refs:
3031 raise ValueError(
3032 f"`reference_tensor` '{ref}' not found. Valid input tensor"
3033 + f" references are: {ipt_refs}."
3034 )
3036 for out in self.outputs:
3037 for p in out.postprocessing:
3038 ref = p.kwargs.get("reference_tensor")
3039 if ref is None:
3040 continue
3042 if ref not in ipt_refs and ref not in out_refs:
3043 raise ValueError(
3044 f"`reference_tensor` '{ref}' not found. Valid tensor references"
3045 + f" are: {ipt_refs | out_refs}."
3046 )
3048 return self
3050 # TODO: use validate funcs in validate_test_tensors
3051 # def validate_inputs(self, input_tensors: Mapping[TensorId, NDArray[Any]]) -> Mapping[TensorId, NDArray[Any]]:
3053 name: Annotated[
3054 str,
3055 RestrictCharacters(string.ascii_letters + string.digits + "_+- ()"),
3056 MinLen(5),
3057 MaxLen(128),
3058 warn(MaxLen(64), "Name longer than 64 characters.", INFO),
3059 ]
3060 """A human-readable name of this model.
3061 It should be no longer than 64 characters
3062 and may only contain letters, numbers, underscores, minus signs, parentheses and spaces.
3063 We recommend choosing a name that refers to the model's task and image modality.
3064 """
3066 outputs: NotEmpty[Sequence[OutputTensorDescr]]
3067 """Describes the output tensors."""
3069 @field_validator("outputs", mode="after")
3070 @classmethod
3071 def _validate_tensor_ids(
3072 cls, outputs: Sequence[OutputTensorDescr], info: ValidationInfo
3073 ) -> Sequence[OutputTensorDescr]:
3074 tensor_ids = [
3075 t.id for t in info.data.get("inputs", []) + info.data.get("outputs", [])
3076 ]
3077 duplicate_tensor_ids: List[str] = []
3078 seen: Set[str] = set()
3079 for t in tensor_ids:
3080 if t in seen:
3081 duplicate_tensor_ids.append(t)
3083 seen.add(t)
3085 if duplicate_tensor_ids:
3086 raise ValueError(f"Duplicate tensor ids: {duplicate_tensor_ids}")
3088 return outputs
3090 @staticmethod
3091 def _get_axes_with_parameterized_size(
3092 io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]],
3093 ):
3094 return {
3095 f"{t.id}.{a.id}": (t, a, a.size)
3096 for t in io
3097 for a in t.axes
3098 if not isinstance(a, BatchAxis) and isinstance(a.size, ParameterizedSize)
3099 }
3101 @staticmethod
3102 def _get_axes_with_independent_size(
3103 io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]],
3104 ):
3105 return {
3106 (t.id, a.id): (t, a, a.size)
3107 for t in io
3108 for a in t.axes
3109 if not isinstance(a, BatchAxis)
3110 and isinstance(a.size, (int, ParameterizedSize))
3111 }
3113 @field_validator("outputs", mode="after")
3114 @classmethod
3115 def _validate_output_axes(
3116 cls, outputs: List[OutputTensorDescr], info: ValidationInfo
3117 ) -> List[OutputTensorDescr]:
3118 input_size_refs = cls._get_axes_with_independent_size(
3119 info.data.get("inputs", [])
3120 )
3121 output_size_refs = cls._get_axes_with_independent_size(outputs)
3123 for i, out in enumerate(outputs):
3124 valid_independent_refs: Dict[
3125 Tuple[TensorId, AxisId],
3126 Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
3127 ] = {
3128 **{
3129 (out.id, a.id): (out, a, a.size)
3130 for a in out.axes
3131 if not isinstance(a, BatchAxis)
3132 and isinstance(a.size, (int, ParameterizedSize))
3133 },
3134 **input_size_refs,
3135 **output_size_refs,
3136 }
3137 for a, ax in enumerate(out.axes):
3138 cls._validate_axis(
3139 "outputs",
3140 i,
3141 out.id,
3142 a,
3143 ax,
3144 valid_independent_refs=valid_independent_refs,
3145 )
3147 return outputs
3149 packaged_by: List[Author] = Field(
3150 default_factory=cast(Callable[[], List[Author]], list)
3151 )
3152 """The persons that have packaged and uploaded this model.
3153 Only required if those persons differ from the `authors`."""
3155 parent: Optional[LinkedModel] = None
3156 """The model from which this model is derived, e.g. by fine-tuning the weights."""
3158 @model_validator(mode="after")
3159 def _validate_parent_is_not_self(self) -> Self:
3160 if self.parent is not None and self.parent.id == self.id:
3161 raise ValueError("A model description may not reference itself as parent.")
3163 return self
3165 run_mode: Annotated[
3166 Optional[RunMode],
3167 warn(None, "Run mode '{value}' has limited support across consumer software."),
3168 ] = None
3169 """Custom run mode for this model: for more complex prediction procedures like test time
3170 data augmentation that currently cannot be expressed in the specification.
3171 No standard run modes are defined yet."""
3173 timestamp: Datetime = Field(default_factory=Datetime.now)
3174 """Timestamp in [ISO 8601](https://en.wikipedia.org/wiki/ISO_8601) format
3175 with a few restrictions listed [here](https://docs.python.org/3/library/datetime.html#datetime.datetime.fromisoformat).
3176 (In Python a datetime object is valid, too)."""
3178 training_data: Annotated[
3179 Union[None, LinkedDataset, DatasetDescr, DatasetDescr02],
3180 Field(union_mode="left_to_right"),
3181 ] = None
3182 """The dataset used to train this model"""
3184 weights: Annotated[WeightsDescr, WrapSerializer(package_weights)]
3185 """The weights for this model.
3186 Weights can be given for different formats, but should otherwise be equivalent.
3187 The available weight formats determine which consumers can use this model."""
3189 config: Config = Field(default_factory=Config.model_construct)
3191 @model_validator(mode="after")
3192 def _add_default_cover(self) -> Self:
3193 if not get_validation_context().perform_io_checks or self.covers:
3194 return self
3196 try:
3197 generated_covers = generate_covers(
3198 [
3199 (t, load_array(t.test_tensor))
3200 for t in self.inputs
3201 if t.test_tensor is not None
3202 ],
3203 [
3204 (t, load_array(t.test_tensor))
3205 for t in self.outputs
3206 if t.test_tensor is not None
3207 ],
3208 )
3209 except Exception as e:
3210 issue_warning(
3211 "Failed to generate cover image(s): {e}",
3212 value=self.covers,
3213 msg_context=dict(e=e),
3214 field="covers",
3215 )
3216 else:
3217 self.covers.extend(generated_covers)
3219 return self
3221 def get_input_test_arrays(self) -> List[NDArray[Any]]:
3222 return self._get_test_arrays(self.inputs)
3224 def get_output_test_arrays(self) -> List[NDArray[Any]]:
3225 return self._get_test_arrays(self.outputs)
3227 @staticmethod
3228 def _get_test_arrays(
3229 io_descr: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]],
3230 ):
3231 ts: List[FileDescr] = []
3232 for d in io_descr:
3233 if d.test_tensor is None:
3234 raise ValueError(
3235 f"Failed to get test arrays: description of '{d.id}' is missing a `test_tensor`."
3236 )
3237 ts.append(d.test_tensor)
3239 data = [load_array(t) for t in ts]
3240 assert all(isinstance(d, np.ndarray) for d in data)
3241 return data
3243 @staticmethod
3244 def get_batch_size(tensor_sizes: Mapping[TensorId, Mapping[AxisId, int]]) -> int:
3245 batch_size = 1
3246 tensor_with_batchsize: Optional[TensorId] = None
3247 for tid in tensor_sizes:
3248 for aid, s in tensor_sizes[tid].items():
3249 if aid != BATCH_AXIS_ID or s == 1 or s == batch_size:
3250 continue
3252 if batch_size != 1:
3253 assert tensor_with_batchsize is not None
3254 raise ValueError(
3255 f"batch size mismatch for tensors '{tensor_with_batchsize}' ({batch_size}) and '{tid}' ({s})"
3256 )
3258 batch_size = s
3259 tensor_with_batchsize = tid
3261 return batch_size
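# Usage sketch (hypothetical tensor and axis ids; assumes the batch axis id
# is 'batch'):
#
#     sizes = {
#         TensorId("raw"): {AxisId("batch"): 4, AxisId("x"): 256},
#         TensorId("mask"): {AxisId("batch"): 4, AxisId("x"): 256},
#     }
#     assert ModelDescr.get_batch_size(sizes) == 4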
3263 def get_output_tensor_sizes(
3264 self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]
3265 ) -> Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]:
3266 """Returns the tensor output sizes for given **input_sizes**.
3267 The returned output sizes are exact only if **input_sizes** describes a valid input shape.
3268 Otherwise they might be larger than the actual (valid) output sizes."""
3269 batch_size = self.get_batch_size(input_sizes)
3270 ns = self.get_ns(input_sizes)
3272 tensor_sizes = self.get_tensor_sizes(ns, batch_size=batch_size)
3273 return tensor_sizes.outputs
3275 def get_ns(self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]):
3276 """get parameter `n` for each parameterized axis
3277 such that the valid input size is >= the given input size"""
3278 ret: Dict[Tuple[TensorId, AxisId], ParameterizedSize_N] = {}
3279 axes = {t.id: {a.id: a for a in t.axes} for t in self.inputs}
3280 for tid in input_sizes:
3281 for aid, s in input_sizes[tid].items():
3282 size_descr = axes[tid][aid].size
3283 if isinstance(size_descr, ParameterizedSize):
3284 ret[(tid, aid)] = size_descr.get_n(s)
3285 elif size_descr is None or isinstance(size_descr, (int, SizeReference)):
3286 pass
3287 else:
3288 assert_never(size_descr)
3290 return ret
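# Worked example (hypothetical numbers, a sketch of the rule above): for a
# parameterized axis with `size = min + n * step`, the smallest `n` with
# `size >= s` is `ceil(max(0, s - min) / step)`. With min=64, step=16 and a
# given input size s=100:
#
#     n = ceil((100 - 64) / 16) = 3  ->  valid size 64 + 3 * 16 = 112 >= 100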
3292 def get_tensor_sizes(
3293 self, ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], batch_size: int
3294 ) -> _TensorSizes:
3295 axis_sizes = self.get_axis_sizes(ns, batch_size=batch_size)
3296 return _TensorSizes(
3297 {
3298 t: {
3299 aa: axis_sizes.inputs[(tt, aa)]
3300 for tt, aa in axis_sizes.inputs
3301 if tt == t
3302 }
3303 for t in {tt for tt, _ in axis_sizes.inputs}
3304 },
3305 {
3306 t: {
3307 aa: axis_sizes.outputs[(tt, aa)]
3308 for tt, aa in axis_sizes.outputs
3309 if tt == t
3310 }
3311 for t in {tt for tt, _ in axis_sizes.outputs}
3312 },
3313 )
3315 def get_axis_sizes(
3316 self,
3317 ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N],
3318 batch_size: Optional[int] = None,
3319 *,
3320 max_input_shape: Optional[Mapping[Tuple[TensorId, AxisId], int]] = None,
3321 ) -> _AxisSizes:
3322 """Determine input and output block shape for scale factors **ns**
3323 of parameterized input sizes.
3325 Args:
3326 ns: Scale factor `n` for each axis (keyed by (tensor_id, axis_id))
3327 that is parameterized as `size = min + n * step`.
3328 batch_size: The desired size of the batch dimension.
3329 If given **batch_size** overwrites any batch size present in
3330 **max_input_shape**. Default 1.
3331 max_input_shape: Limits the derived block shapes.
3332 Each axis for which the input size, parameterized by `n`, is larger
3333 than **max_input_shape** is set to the minimal value `n_min` for which
3334 this is still true.
3335 Use this for small input samples or large values of **ns**,
3336 or simply whenever you know the full input shape.
3338 Returns:
3339 Resolved axis sizes for model inputs and outputs.
3340 """
3341 max_input_shape = max_input_shape or {}
3342 if batch_size is None:
3343 for (_t_id, a_id), s in max_input_shape.items():
3344 if a_id == BATCH_AXIS_ID:
3345 batch_size = s
3346 break
3347 else:
3348 batch_size = 1
3350 all_axes = {
3351 t.id: {a.id: a for a in t.axes} for t in chain(self.inputs, self.outputs)
3352 }
3354 inputs: Dict[Tuple[TensorId, AxisId], int] = {}
3355 outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]] = {}
3357 def get_axis_size(a: Union[InputAxis, OutputAxis]):
3358 if isinstance(a, BatchAxis):
3359 if (t_descr.id, a.id) in ns:
3360 logger.warning(
3361 "Ignoring unexpected size increment factor (n) for batch axis"
3362 + " of tensor '{}'.",
3363 t_descr.id,
3364 )
3365 return batch_size
3366 elif isinstance(a.size, int):
3367 if (t_descr.id, a.id) in ns:
3368 logger.warning(
3369 "Ignoring unexpected size increment factor (n) for fixed size"
3370 + " axis '{}' of tensor '{}'.",
3371 a.id,
3372 t_descr.id,
3373 )
3374 return a.size
3375 elif isinstance(a.size, ParameterizedSize):
3376 if (t_descr.id, a.id) not in ns:
3377 raise ValueError(
3378 "Size increment factor (n) missing for parameterized axis"
3379 + f" '{a.id}' of tensor '{t_descr.id}'."
3380 )
3381 n = ns[(t_descr.id, a.id)]
3382 s_max = max_input_shape.get((t_descr.id, a.id))
3383 if s_max is not None:
3384 n = min(n, a.size.get_n(s_max))
3386 return a.size.get_size(n)
3388 elif isinstance(a.size, SizeReference):
3389 if (t_descr.id, a.id) in ns:
3390 logger.warning(
3391 "Ignoring unexpected size increment factor (n) for axis '{}'"
3392 + " of tensor '{}' with size reference.",
3393 a.id,
3394 t_descr.id,
3395 )
3396 assert not isinstance(a, BatchAxis)
3397 ref_axis = all_axes[a.size.tensor_id][a.size.axis_id]
3398 assert not isinstance(ref_axis, BatchAxis)
3399 ref_key = (a.size.tensor_id, a.size.axis_id)
3400 ref_size = inputs.get(ref_key, outputs.get(ref_key))
3401 assert ref_size is not None, ref_key
3402 assert not isinstance(ref_size, _DataDepSize), ref_key
3403 return a.size.get_size(
3404 axis=a,
3405 ref_axis=ref_axis,
3406 ref_size=ref_size,
3407 )
3408 elif isinstance(a.size, DataDependentSize):
3409 if (t_descr.id, a.id) in ns:
3410 logger.warning(
3411 "Ignoring unexpected increment factor (n) for data dependent"
3412 + " size axis '{}' of tensor '{}'.",
3413 a.id,
3414 t_descr.id,
3415 )
3416 return _DataDepSize(a.size.min, a.size.max)
3417 else:
3418 assert_never(a.size)
3420 # first resolve all input sizes except those given as a `SizeReference`
3421 for t_descr in self.inputs:
3422 for a in t_descr.axes:
3423 if not isinstance(a.size, SizeReference):
3424 s = get_axis_size(a)
3425 assert not isinstance(s, _DataDepSize)
3426 inputs[t_descr.id, a.id] = s
3428 # resolve all other input axis sizes
3429 for t_descr in self.inputs:
3430 for a in t_descr.axes:
3431 if isinstance(a.size, SizeReference):
3432 s = get_axis_size(a)
3433 assert not isinstance(s, _DataDepSize)
3434 inputs[t_descr.id, a.id] = s
3436 # resolve all output axis sizes
3437 for t_descr in self.outputs:
3438 for a in t_descr.axes:
3439 assert not isinstance(a.size, ParameterizedSize)
3440 s = get_axis_size(a)
3441 outputs[t_descr.id, a.id] = s
3443 return _AxisSizes(inputs=inputs, outputs=outputs)
3445 @model_validator(mode="before")
3446 @classmethod
3447 def _convert(cls, data: Dict[str, Any]) -> Dict[str, Any]:
3448 cls.convert_from_old_format_wo_validation(data)
3449 return data
3451 @classmethod
3452 def convert_from_old_format_wo_validation(cls, data: Dict[str, Any]) -> None:
3453 """Convert metadata following an older format version to this class's format
3454 without validating the result.
3455 """
3456 if (
3457 data.get("type") == "model"
3458 and isinstance(fv := data.get("format_version"), str)
3459 and fv.count(".") == 2
3460 ):
3461 fv_parts = fv.split(".")
3462 if any(not p.isdigit() for p in fv_parts):
3463 return
3465 fv_tuple = tuple(map(int, fv_parts))
3467 assert cls.implemented_format_version_tuple[0:2] == (0, 5)
3468 if fv_tuple[:2] in ((0, 3), (0, 4)):
3469 m04 = _ModelDescr_v0_4.load(data)
3470 if isinstance(m04, InvalidDescr):
3471 try:
3472 updated = _model_conv.convert_as_dict(
3473 m04 # pyright: ignore[reportArgumentType]
3474 )
3475 except Exception as e:
3476 logger.error(
3477 "Failed to convert from invalid model 0.4 description."
3478 + f"\nerror: {e}"
3479 + "\nProceeding with model 0.5 validation without conversion."
3480 )
3481 updated = None
3482 else:
3483 updated = _model_conv.convert_as_dict(m04)
3485 if updated is not None:
3486 data.clear()
3487 data.update(updated)
3489 elif fv_tuple[:2] == (0, 5):
3490 # bump patch version
3491 data["format_version"] = cls.implemented_format_version
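# Usage sketch (hypothetical minimal data): within format 0.5 only the
# patch version is bumped in place.
#
#     data = {"type": "model", "format_version": "0.5.0"}
#     ModelDescr.convert_from_old_format_wo_validation(data)
#     assert data["format_version"] == ModelDescr.implemented_format_version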
3494 class _ModelConv(Converter[_ModelDescr_v0_4, ModelDescr]):
3495 def _convert(
3496 self, src: _ModelDescr_v0_4, tgt: "type[ModelDescr] | type[dict[str, Any]]"
3497 ) -> "ModelDescr | dict[str, Any]":
3498 name = "".join(
3499 c if c in string.ascii_letters + string.digits + "_+- ()" else " "
3500 for c in src.name
3501 )
3503 def conv_authors(auths: Optional[Sequence[_Author_v0_4]]):
3504 conv = (
3505 _author_conv.convert if TYPE_CHECKING else _author_conv.convert_as_dict
3506 )
3507 return None if auths is None else [conv(a) for a in auths]
3509 if TYPE_CHECKING:
3510 arch_file_conv = _arch_file_conv.convert
3511 arch_lib_conv = _arch_lib_conv.convert
3512 else:
3513 arch_file_conv = _arch_file_conv.convert_as_dict
3514 arch_lib_conv = _arch_lib_conv.convert_as_dict
3516 input_size_refs = {
3517 ipt.name: {
3518 a: s
3519 for a, s in zip(
3520 ipt.axes,
3521 (
3522 ipt.shape.min
3523 if isinstance(ipt.shape, _ParameterizedInputShape_v0_4)
3524 else ipt.shape
3525 ),
3526 )
3527 }
3528 for ipt in src.inputs
3529 if ipt.shape
3530 }
3531 output_size_refs = {
3532 **{
3533 out.name: {a: s for a, s in zip(out.axes, out.shape)}
3534 for out in src.outputs
3535 if not isinstance(out.shape, _ImplicitOutputShape_v0_4)
3536 },
3537 **input_size_refs,
3538 }
3540 return tgt(
3541 attachments=(
3542 []
3543 if src.attachments is None
3544 else [FileDescr(source=f) for f in src.attachments.files]
3545 ),
3546 authors=[_author_conv.convert_as_dict(a) for a in src.authors], # pyright: ignore[reportArgumentType]
3547 cite=[{"text": c.text, "doi": c.doi, "url": c.url} for c in src.cite], # pyright: ignore[reportArgumentType]
3548 config=src.config, # pyright: ignore[reportArgumentType]
3549 covers=src.covers,
3550 description=src.description,
3551 documentation=src.documentation,
3552 format_version="0.5.7",
3553 git_repo=src.git_repo, # pyright: ignore[reportArgumentType]
3554 icon=src.icon,
3555 id=None if src.id is None else ModelId(src.id),
3556 id_emoji=src.id_emoji,
3557 license=src.license, # type: ignore
3558 links=src.links,
3559 maintainers=[_maintainer_conv.convert_as_dict(m) for m in src.maintainers], # pyright: ignore[reportArgumentType]
3560 name=name,
3561 tags=src.tags,
3562 type=src.type,
3563 uploader=src.uploader,
3564 version=src.version,
3565 inputs=[ # pyright: ignore[reportArgumentType]
3566 _input_tensor_conv.convert_as_dict(ipt, tt, st, input_size_refs)
3567 for ipt, tt, st in zip(
3568 src.inputs,
3569 src.test_inputs,
3570 src.sample_inputs or [None] * len(src.test_inputs),
3571 )
3572 ],
3573 outputs=[ # pyright: ignore[reportArgumentType]
3574 _output_tensor_conv.convert_as_dict(out, tt, st, output_size_refs)
3575 for out, tt, st in zip(
3576 src.outputs,
3577 src.test_outputs,
3578 src.sample_outputs or [None] * len(src.test_outputs),
3579 )
3580 ],
3581 parent=(
3582 None
3583 if src.parent is None
3584 else LinkedModel(
3585 id=ModelId(
3586 str(src.parent.id)
3587 + (
3588 ""
3589 if src.parent.version_number is None
3590 else f"/{src.parent.version_number}"
3591 )
3592 )
3593 )
3594 ),
3595 training_data=(
3596 None
3597 if src.training_data is None
3598 else (
3599 LinkedDataset(
3600 id=DatasetId(
3601 str(src.training_data.id)
3602 + (
3603 ""
3604 if src.training_data.version_number is None
3605 else f"/{src.training_data.version_number}"
3606 )
3607 )
3608 )
3609 if isinstance(src.training_data, LinkedDataset02)
3610 else src.training_data
3611 )
3612 ),
3613 packaged_by=[_author_conv.convert_as_dict(a) for a in src.packaged_by], # pyright: ignore[reportArgumentType]
3614 run_mode=src.run_mode,
3615 timestamp=src.timestamp,
3616 weights=(WeightsDescr if TYPE_CHECKING else dict)(
3617 keras_hdf5=(w := src.weights.keras_hdf5)
3618 and (KerasHdf5WeightsDescr if TYPE_CHECKING else dict)(
3619 authors=conv_authors(w.authors),
3620 source=w.source,
3621 tensorflow_version=w.tensorflow_version or Version("1.15"),
3622 parent=w.parent,
3623 ),
3624 onnx=(w := src.weights.onnx)
3625 and (OnnxWeightsDescr if TYPE_CHECKING else dict)(
3626 source=w.source,
3627 authors=conv_authors(w.authors),
3628 parent=w.parent,
3629 opset_version=w.opset_version or 15,
3630 ),
3631 pytorch_state_dict=(w := src.weights.pytorch_state_dict)
3632 and (PytorchStateDictWeightsDescr if TYPE_CHECKING else dict)(
3633 source=w.source,
3634 authors=conv_authors(w.authors),
3635 parent=w.parent,
3636 architecture=(
3637 arch_file_conv(
3638 w.architecture,
3639 w.architecture_sha256,
3640 w.kwargs,
3641 )
3642 if isinstance(w.architecture, _CallableFromFile_v0_4)
3643 else arch_lib_conv(w.architecture, w.kwargs)
3644 ),
3645 pytorch_version=w.pytorch_version or Version("1.10"),
3646 dependencies=(
3647 None
3648 if w.dependencies is None
3649 else (FileDescr if TYPE_CHECKING else dict)(
3650 source=cast(
3651 FileSource,
3652 str(deps := w.dependencies)[
3653 (
3654 len("conda:")
3655 if str(deps).startswith("conda:")
3656 else 0
3657 ) :
3658 ],
3659 )
3660 )
3661 ),
3662 ),
3663 tensorflow_js=(w := src.weights.tensorflow_js)
3664 and (TensorflowJsWeightsDescr if TYPE_CHECKING else dict)(
3665 source=w.source,
3666 authors=conv_authors(w.authors),
3667 parent=w.parent,
3668 tensorflow_version=w.tensorflow_version or Version("1.15"),
3669 ),
3670 tensorflow_saved_model_bundle=(
3671 w := src.weights.tensorflow_saved_model_bundle
3672 )
3673 and (TensorflowSavedModelBundleWeightsDescr if TYPE_CHECKING else dict)(
3674 authors=conv_authors(w.authors),
3675 parent=w.parent,
3676 source=w.source,
3677 tensorflow_version=w.tensorflow_version or Version("1.15"),
3678 dependencies=(
3679 None
3680 if w.dependencies is None
3681 else (FileDescr if TYPE_CHECKING else dict)(
3682 source=cast(
3683 FileSource,
3684 (
3685 str(w.dependencies)[len("conda:") :]
3686 if str(w.dependencies).startswith("conda:")
3687 else str(w.dependencies)
3688 ),
3689 )
3690 )
3691 ),
3692 ),
3693 torchscript=(w := src.weights.torchscript)
3694 and (TorchscriptWeightsDescr if TYPE_CHECKING else dict)(
3695 source=w.source,
3696 authors=conv_authors(w.authors),
3697 parent=w.parent,
3698 pytorch_version=w.pytorch_version or Version("1.10"),
3699 ),
3700 ),
3701 )
3704 _model_conv = _ModelConv(_ModelDescr_v0_4, ModelDescr)
3707 # create better cover images for 3d data and non-image outputs
3708 def generate_covers(
3709 inputs: Sequence[Tuple[InputTensorDescr, NDArray[Any]]],
3710 outputs: Sequence[Tuple[OutputTensorDescr, NDArray[Any]]],
3711) -> List[Path]:
3712 def squeeze(
3713 data: NDArray[Any], axes: Sequence[AnyAxis]
3714 ) -> Tuple[NDArray[Any], List[AnyAxis]]:
3715 """apply numpy.ndarray.squeeze while keeping track of the remaining axis descriptions"""
3716 if data.ndim != len(axes):
3717 raise ValueError(
3718 f"tensor shape {data.shape} does not match described axes"
3719 + f" {[a.id for a in axes]}"
3720 )
3722 axes = [deepcopy(a) for a, s in zip(axes, data.shape) if s != 1]
3723 return data.squeeze(), axes
3725 def normalize(
3726 data: NDArray[Any], axis: Optional[Tuple[int, ...]], eps: float = 1e-7
3727 ) -> NDArray[np.float32]:
3728 data = data.astype("float32")
3729 data -= data.min(axis=axis, keepdims=True)
3730 data /= data.max(axis=axis, keepdims=True) + eps
3731 return data
    def to_2d_image(data: NDArray[Any], axes: Sequence[AnyAxis]):
        original_shape = data.shape
        original_axes = list(axes)
        data, axes = squeeze(data, axes)

        # take a slice from any batch or index axis if needed,
        # convert the first channel axis to RGB, and slice away any additional channel axes
        slices: Tuple[slice, ...] = ()
        ndim = data.ndim
        ndim_need = 3 if any(isinstance(a, ChannelAxis) for a in axes) else 2
        has_c_axis = False
        for i, a in enumerate(axes):
            s = data.shape[i]
            assert s > 1
            if (
                isinstance(a, (BatchAxis, IndexInputAxis, IndexOutputAxis))
                and ndim > ndim_need
            ):
                data = data[slices + (slice(s // 2 - 1, s // 2),)]
                ndim -= 1
            elif isinstance(a, ChannelAxis):
                if has_c_axis:
                    # second channel axis
                    data = data[slices + (slice(0, 1),)]
                    ndim -= 1
                else:
                    has_c_axis = True
                    if s == 2:
                        # visualize two channels with cyan and magenta
                        data = np.concatenate(
                            [
                                data[slices + (slice(1, 2),)],
                                data[slices + (slice(0, 1),)],
                                (
                                    data[slices + (slice(0, 1),)]
                                    + data[slices + (slice(1, 2),)]
                                )
                                / 2,  # TODO: take maximum instead?
                            ],
                            axis=i,
                        )
                    elif data.shape[i] == 3:
                        pass  # visualize 3 channels as RGB
                    else:
                        # visualize first 3 channels as RGB
                        data = data[slices + (slice(3),)]

                    assert data.shape[i] == 3

            slices += (slice(None),)

        data, axes = squeeze(data, axes)
        assert len(axes) == ndim
        # take slice from z axis if needed
        slices = ()
        if ndim > ndim_need:
            for i, a in enumerate(axes):
                s = data.shape[i]
                if a.id == AxisId("z"):
                    data = data[slices + (slice(s // 2 - 1, s // 2),)]
                    data, axes = squeeze(data, axes)
                    ndim -= 1
                    break

                slices += (slice(None),)

        # take slice from any space or time axis
        slices = ()

        for i, a in enumerate(axes):
            if ndim <= ndim_need:
                break

            s = data.shape[i]
            assert s > 1
            if isinstance(
                a, (SpaceInputAxis, SpaceOutputAxis, TimeInputAxis, TimeOutputAxis)
            ):
                data = data[slices + (slice(s // 2 - 1, s // 2),)]
                ndim -= 1

            slices += (slice(None),)

        del slices
        data, axes = squeeze(data, axes)
        assert len(axes) == ndim

        if (has_c_axis and ndim != 3) or (not has_c_axis and ndim != 2):
            raise ValueError(
                f"Failed to construct cover image from shape {original_shape}"
                + f" with axes {[a.id for a in original_axes]}."
            )

        if not has_c_axis:
            assert ndim == 2
            data = np.repeat(data[:, :, None], 3, axis=2)
            axes.append(ChannelAxis(channel_names=list(map(Identifier, "RGB"))))
            ndim += 1

        assert ndim == 3

        # transpose axis order such that longest axis comes first...
        axis_order: List[int] = list(np.argsort(list(data.shape)))
        axis_order.reverse()
        # ... and channel axis is last
        c = next(i for i in range(3) if isinstance(axes[i], ChannelAxis))
        axis_order.append(axis_order.pop(c))
        axes = [axes[ao] for ao in axis_order]
        data = data.transpose(axis_order)

        # TODO: enforce a 2:1 or 1:1 aspect ratio for generated cover images

        norm_along = (
            tuple(i for i, a in enumerate(axes) if a.type in ("space", "time")) or None
        )
        # normalize the data and map to 8 bit
        data = normalize(data, norm_along)
        data = (data * 255).astype("uint8")

        return data
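
    # Compose two equally shaped uint8 RGB images into one: the lower-left
    # triangle shows im0 (the input), the upper-right triangle shows im1
    # (the output).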
    def create_diagonal_split_image(im0: NDArray[Any], im1: NDArray[Any]):
        assert im0.dtype == im1.dtype == np.uint8
        assert im0.shape == im1.shape
        assert im0.ndim == 3
        N, M, C = im0.shape
        assert C == 3
        out = np.ones((N, M, C), dtype="uint8")
        for c in range(C):
            outc = np.tril(im0[..., c])
            mask = outc == 0
            outc[mask] = np.triu(im1[..., c])[mask]
            out[..., c] = outc

        return out

    if not inputs:
        raise ValueError("Missing test input tensor for cover generation.")

    if not outputs:
        raise ValueError("Missing test output tensor for cover generation.")

    ipt_descr, ipt = inputs[0]
    out_descr, out = outputs[0]

    ipt_img = to_2d_image(ipt, ipt_descr.axes)
    out_img = to_2d_image(out, out_descr.axes)

    cover_folder = Path(mkdtemp())
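    # equal 2D shapes: merge input and output into a single diagonal-split
    # cover; otherwise write two separate cover images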
    if ipt_img.shape == out_img.shape:
        covers = [cover_folder / "cover.png"]
        imwrite(covers[0], create_diagonal_split_image(ipt_img, out_img))
    else:
        covers = [cover_folder / "input.png", cover_folder / "output.png"]
        imwrite(covers[0], ipt_img)
        imwrite(covers[1], out_img)

    return covers
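

# Usage sketch (hypothetical; assumes a loaded v0.5 `ModelDescr` named `model`
# whose test tensors can be read with `load_array`):
#
#     covers = generate_covers(
#         inputs=[(model.inputs[0], load_array(model.inputs[0].test_tensor))],
#         outputs=[(model.outputs[0], load_array(model.outputs[0].test_tensor))],
#     )
#     # -> [<tmpdir>/cover.png] or [<tmpdir>/input.png, <tmpdir>/output.png]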