Coverage for src/bioimageio/spec/model/v0_5.py: 74%
1391 statements
coverage.py v7.12.0, created at 2025-12-08 13:04 +0000
1from __future__ import annotations
3import collections.abc
4import re
5import string
6import warnings
7from abc import ABC
8from copy import deepcopy
9from itertools import chain
10from math import ceil
11from pathlib import Path, PurePosixPath
12from tempfile import mkdtemp
13from typing import (
14 TYPE_CHECKING,
15 Any,
16 Callable,
17 ClassVar,
18 Dict,
19 Generic,
20 List,
21 Literal,
22 Mapping,
23 NamedTuple,
24 Optional,
25 Sequence,
26 Set,
27 Tuple,
28 Type,
29 TypeVar,
30 Union,
31 cast,
32 overload,
33)
35import numpy as np
36from annotated_types import Ge, Gt, Interval, MaxLen, MinLen, Predicate
37from imageio.v3 import imread, imwrite # pyright: ignore[reportUnknownVariableType]
38from loguru import logger
39from numpy.typing import NDArray
40from pydantic import (
41 AfterValidator,
42 Discriminator,
43 Field,
44 RootModel,
45 SerializationInfo,
46 SerializerFunctionWrapHandler,
47 StrictInt,
48 Tag,
49 ValidationInfo,
50 WrapSerializer,
51 field_validator,
52 model_serializer,
53 model_validator,
54)
55from typing_extensions import Annotated, Self, assert_never, get_args
57from .._internal.common_nodes import (
58 InvalidDescr,
59 Node,
60 NodeWithExplicitlySetFields,
61)
62from .._internal.constants import DTYPE_LIMITS
63from .._internal.field_warning import issue_warning, warn
64from .._internal.io import BioimageioYamlContent as BioimageioYamlContent
65from .._internal.io import FileDescr as FileDescr
66from .._internal.io import (
67 FileSource,
68 WithSuffix,
69 YamlValue,
70 extract_file_name,
71 get_reader,
72 wo_special_file_name,
73)
74from .._internal.io_basics import Sha256 as Sha256
75from .._internal.io_packaging import (
76 FileDescr_,
77 FileSource_,
78 package_file_descr_serializer,
79)
80from .._internal.io_utils import load_array
81from .._internal.node_converter import Converter
82from .._internal.type_guards import is_dict, is_sequence
83from .._internal.types import (
84 FAIR,
85 AbsoluteTolerance,
86 LowerCaseIdentifier,
87 LowerCaseIdentifierAnno,
88 MismatchedElementsPerMillion,
89 RelativeTolerance,
90)
91from .._internal.types import Datetime as Datetime
92from .._internal.types import Identifier as Identifier
93from .._internal.types import NotEmpty as NotEmpty
94from .._internal.types import SiUnit as SiUnit
95from .._internal.url import HttpUrl as HttpUrl
96from .._internal.validation_context import get_validation_context
97from .._internal.validator_annotations import RestrictCharacters
98from .._internal.version_type import Version as Version
99from .._internal.warning_levels import INFO
100from ..dataset.v0_2 import DatasetDescr as DatasetDescr02
101from ..dataset.v0_2 import LinkedDataset as LinkedDataset02
102from ..dataset.v0_3 import DatasetDescr as DatasetDescr
103from ..dataset.v0_3 import DatasetId as DatasetId
104from ..dataset.v0_3 import LinkedDataset as LinkedDataset
105from ..dataset.v0_3 import Uploader as Uploader
106from ..generic.v0_3 import (
107 VALID_COVER_IMAGE_EXTENSIONS as VALID_COVER_IMAGE_EXTENSIONS,
108)
109from ..generic.v0_3 import Author as Author
110from ..generic.v0_3 import BadgeDescr as BadgeDescr
111from ..generic.v0_3 import CiteEntry as CiteEntry
112from ..generic.v0_3 import DeprecatedLicenseId as DeprecatedLicenseId
113from ..generic.v0_3 import Doi as Doi
114from ..generic.v0_3 import (
115 FileSource_documentation,
116 GenericModelDescrBase,
117 LinkedResourceBase,
118 _author_conv, # pyright: ignore[reportPrivateUsage]
119 _maintainer_conv, # pyright: ignore[reportPrivateUsage]
120)
121from ..generic.v0_3 import LicenseId as LicenseId
122from ..generic.v0_3 import LinkedResource as LinkedResource
123from ..generic.v0_3 import Maintainer as Maintainer
124from ..generic.v0_3 import OrcidId as OrcidId
125from ..generic.v0_3 import RelativeFilePath as RelativeFilePath
126from ..generic.v0_3 import ResourceId as ResourceId
127from .v0_4 import Author as _Author_v0_4
128from .v0_4 import BinarizeDescr as _BinarizeDescr_v0_4
129from .v0_4 import CallableFromDepencency as CallableFromDepencency
130from .v0_4 import CallableFromDepencency as _CallableFromDepencency_v0_4
131from .v0_4 import CallableFromFile as _CallableFromFile_v0_4
132from .v0_4 import ClipDescr as _ClipDescr_v0_4
133from .v0_4 import ClipKwargs as ClipKwargs
134from .v0_4 import ImplicitOutputShape as _ImplicitOutputShape_v0_4
135from .v0_4 import InputTensorDescr as _InputTensorDescr_v0_4
136from .v0_4 import KnownRunMode as KnownRunMode
137from .v0_4 import ModelDescr as _ModelDescr_v0_4
138from .v0_4 import OutputTensorDescr as _OutputTensorDescr_v0_4
139from .v0_4 import ParameterizedInputShape as _ParameterizedInputShape_v0_4
140from .v0_4 import PostprocessingDescr as _PostprocessingDescr_v0_4
141from .v0_4 import PreprocessingDescr as _PreprocessingDescr_v0_4
142from .v0_4 import ProcessingKwargs as ProcessingKwargs
143from .v0_4 import RunMode as RunMode
144from .v0_4 import ScaleLinearDescr as _ScaleLinearDescr_v0_4
145from .v0_4 import ScaleMeanVarianceDescr as _ScaleMeanVarianceDescr_v0_4
146from .v0_4 import ScaleRangeDescr as _ScaleRangeDescr_v0_4
147from .v0_4 import SigmoidDescr as _SigmoidDescr_v0_4
148from .v0_4 import TensorName as _TensorName_v0_4
149from .v0_4 import WeightsFormat as WeightsFormat
150from .v0_4 import ZeroMeanUnitVarianceDescr as _ZeroMeanUnitVarianceDescr_v0_4
151from .v0_4 import package_weights
153SpaceUnit = Literal[
154 "attometer",
155 "angstrom",
156 "centimeter",
157 "decimeter",
158 "exameter",
159 "femtometer",
160 "foot",
161 "gigameter",
162 "hectometer",
163 "inch",
164 "kilometer",
165 "megameter",
166 "meter",
167 "micrometer",
168 "mile",
169 "millimeter",
170 "nanometer",
171 "parsec",
172 "petameter",
173 "picometer",
174 "terameter",
175 "yard",
176 "yoctometer",
177 "yottameter",
178 "zeptometer",
179 "zettameter",
180]
181"""Space unit compatible to the [OME-Zarr axes specification 0.5](https://ngff.openmicroscopy.org/0.5/#axes-md)"""
183TimeUnit = Literal[
184 "attosecond",
185 "centisecond",
186 "day",
187 "decisecond",
188 "exasecond",
189 "femtosecond",
190 "gigasecond",
191 "hectosecond",
192 "hour",
193 "kilosecond",
194 "megasecond",
195 "microsecond",
196 "millisecond",
197 "minute",
198 "nanosecond",
199 "petasecond",
200 "picosecond",
201 "second",
202 "terasecond",
203 "yoctosecond",
204 "yottasecond",
205 "zeptosecond",
206 "zettasecond",
207]
208"""Time unit compatible to the [OME-Zarr axes specification 0.5](https://ngff.openmicroscopy.org/0.5/#axes-md)"""
210AxisType = Literal["batch", "channel", "index", "time", "space"]
212_AXIS_TYPE_MAP: Mapping[str, AxisType] = {
213 "b": "batch",
214 "t": "time",
215 "i": "index",
216 "c": "channel",
217 "x": "space",
218 "y": "space",
219 "z": "space",
220}
222_AXIS_ID_MAP = {
223 "b": "batch",
224 "t": "time",
225 "i": "index",
226 "c": "channel",
227}
230class TensorId(LowerCaseIdentifier):
231 root_model: ClassVar[Type[RootModel[Any]]] = RootModel[
232 Annotated[LowerCaseIdentifierAnno, MaxLen(32)]
233 ]
236def _normalize_axis_id(a: str):
237 a = str(a)
238 normalized = _AXIS_ID_MAP.get(a, a)
239 if a != normalized:
240 logger.opt(depth=3).warning(
241 "Normalized axis id from '{}' to '{}'.", a, normalized
242 )
243 return normalized
246class AxisId(LowerCaseIdentifier):
247 root_model: ClassVar[Type[RootModel[Any]]] = RootModel[
248 Annotated[
249 LowerCaseIdentifierAnno,
250 MaxLen(16),
251 AfterValidator(_normalize_axis_id),
252 ]
253 ]
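# A minimal sketch of the axis id normalization applied by `AxisId` above:
# `AfterValidator(_normalize_axis_id)` maps the single-letter ids listed in
# `_AXIS_ID_MAP` ("b", "t", "i", "c") to their long form (logging a warning),
# while ids not in the map, e.g. "x", "y" or "z", pass through unchanged.
assert _normalize_axis_id("c") == "channel"  # normalized (a warning is logged)
assert _normalize_axis_id("x") == "x"        # not in _AXIS_ID_MAP: unchanged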
256def _is_batch(a: str) -> bool:
257 return str(a) == "batch"
260def _is_not_batch(a: str) -> bool:
261 return not _is_batch(a)
264NonBatchAxisId = Annotated[AxisId, Predicate(_is_not_batch)]
266PreprocessingId = Literal[
267 "binarize",
268 "clip",
269 "ensure_dtype",
270 "fixed_zero_mean_unit_variance",
271 "scale_linear",
272 "scale_range",
273 "sigmoid",
274 "softmax",
275]
276PostprocessingId = Literal[
277 "binarize",
278 "clip",
279 "ensure_dtype",
280 "fixed_zero_mean_unit_variance",
281 "scale_linear",
282 "scale_mean_variance",
283 "scale_range",
284 "sigmoid",
285 "softmax",
286 "zero_mean_unit_variance",
287]
290SAME_AS_TYPE = "<same as type>"
293ParameterizedSize_N = int
294"""
295Annotates an integer to calculate a concrete axis size from a `ParameterizedSize`.
296"""
299class ParameterizedSize(Node):
300 """Describes a range of valid tensor axis sizes as `size = min + n*step`.
302 - **min** and **step** are given by the model description.
 303 - All blocksize parameters n = 0,1,2,... yield a valid `size`.
 304 - A greater blocksize parameter n results in a greater **size**.
 305 This allows the axis size to be adjusted generically.
306 """
308 N: ClassVar[Type[int]] = ParameterizedSize_N
309 """Positive integer to parameterize this axis"""
311 min: Annotated[int, Gt(0)]
312 step: Annotated[int, Gt(0)]
314 def validate_size(self, size: int) -> int:
315 if size < self.min:
316 raise ValueError(f"size {size} < {self.min}")
317 if (size - self.min) % self.step != 0:
318 raise ValueError(
319 f"axis of size {size} is not parameterized by `min + n*step` ="
320 + f" `{self.min} + n*{self.step}`"
321 )
323 return size
325 def get_size(self, n: ParameterizedSize_N) -> int:
326 return self.min + self.step * n
328 def get_n(self, s: int) -> ParameterizedSize_N:
329 """return smallest n parameterizing a size greater or equal than `s`"""
330 return ceil((s - self.min) / self.step)
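# A minimal usage sketch for `ParameterizedSize` (illustrative values only):
# valid sizes are `min + n*step` for n = 0, 1, 2, ...
_p = ParameterizedSize(min=32, step=16)
assert _p.get_size(2) == 64          # 32 + 2*16
assert _p.get_n(70) == 3             # smallest n such that 32 + n*16 >= 70
assert _p.validate_size(48) == 48    # 48 = 32 + 1*16 is a valid size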
333class DataDependentSize(Node):
334 min: Annotated[int, Gt(0)] = 1
335 max: Annotated[Optional[int], Gt(1)] = None
337 @model_validator(mode="after")
338 def _validate_max_gt_min(self):
339 if self.max is not None and self.min >= self.max:
340 raise ValueError(f"expected `min` < `max`, but got {self.min}, {self.max}")
342 return self
344 def validate_size(self, size: int) -> int:
345 if size < self.min:
346 raise ValueError(f"size {size} < {self.min}")
348 if self.max is not None and size > self.max:
349 raise ValueError(f"size {size} > {self.max}")
351 return size
354class SizeReference(Node):
355 """A tensor axis size (extent in pixels/frames) defined in relation to a reference axis.
357 `axis.size = reference.size * reference.scale / axis.scale + offset`
359 Note:
360 1. The axis and the referenced axis need to have the same unit (or no unit).
361 2. Batch axes may not be referenced.
362 3. Fractions are rounded down.
363 4. If the reference axis is `concatenable` the referencing axis is assumed to be
364 `concatenable` as well with the same block order.
366 Example:
 367 An anisotropic input image of w*h=100*49 pixels depicts a physical space of 200*196mm².
 368 Let's assume we want to express the image height h in relation to its width w
 369 instead of only accepting input images of exactly 100*49 pixels
 370 (for example to express a range of valid image shapes by parameterizing w, see `ParameterizedSize`).
372 >>> w = SpaceInputAxis(id=AxisId("w"), size=100, unit="millimeter", scale=2)
373 >>> h = SpaceInputAxis(
374 ... id=AxisId("h"),
375 ... size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("w"), offset=-1),
376 ... unit="millimeter",
377 ... scale=4,
378 ... )
379 >>> print(h.size.get_size(h, w))
380 49
382 ⇒ h = w * w.scale / h.scale + offset = 100 * 2mm / 4mm - 1 = 49
383 """
385 tensor_id: TensorId
386 """tensor id of the reference axis"""
388 axis_id: AxisId
389 """axis id of the reference axis"""
391 offset: StrictInt = 0
393 def get_size(
394 self,
395 axis: Union[
396 ChannelAxis,
397 IndexInputAxis,
398 IndexOutputAxis,
399 TimeInputAxis,
400 SpaceInputAxis,
401 TimeOutputAxis,
402 TimeOutputAxisWithHalo,
403 SpaceOutputAxis,
404 SpaceOutputAxisWithHalo,
405 ],
406 ref_axis: Union[
407 ChannelAxis,
408 IndexInputAxis,
409 IndexOutputAxis,
410 TimeInputAxis,
411 SpaceInputAxis,
412 TimeOutputAxis,
413 TimeOutputAxisWithHalo,
414 SpaceOutputAxis,
415 SpaceOutputAxisWithHalo,
416 ],
417 n: ParameterizedSize_N = 0,
418 ref_size: Optional[int] = None,
419 ):
420 """Compute the concrete size for a given axis and its reference axis.
422 Args:
423 axis: The axis this `SizeReference` is the size of.
424 ref_axis: The reference axis to compute the size from.
425 n: If the **ref_axis** is parameterized (of type `ParameterizedSize`)
426 and no fixed **ref_size** is given,
427 **n** is used to compute the size of the parameterized **ref_axis**.
428 ref_size: Overwrite the reference size instead of deriving it from
429 **ref_axis**
430 (**ref_axis.scale** is still used; any given **n** is ignored).
431 """
432 assert axis.size == self, (
433 "Given `axis.size` is not defined by this `SizeReference`"
434 )
436 assert ref_axis.id == self.axis_id, (
437 f"Expected `ref_axis.id` to be {self.axis_id}, but got {ref_axis.id}."
438 )
440 assert axis.unit == ref_axis.unit, (
441 "`SizeReference` requires `axis` and `ref_axis` to have the same `unit`,"
442 f" but {axis.unit}!={ref_axis.unit}"
443 )
444 if ref_size is None:
445 if isinstance(ref_axis.size, (int, float)):
446 ref_size = ref_axis.size
447 elif isinstance(ref_axis.size, ParameterizedSize):
448 ref_size = ref_axis.size.get_size(n)
449 elif isinstance(ref_axis.size, DataDependentSize):
450 raise ValueError(
451 "Reference axis referenced in `SizeReference` may not be a `DataDependentSize`."
452 )
453 elif isinstance(ref_axis.size, SizeReference):
454 raise ValueError(
455 "Reference axis referenced in `SizeReference` may not be sized by a"
456 + " `SizeReference` itself."
457 )
458 else:
459 assert_never(ref_axis.size)
461 return int(ref_size * ref_axis.scale / axis.scale + self.offset)
463 @staticmethod
464 def _get_unit(
465 axis: Union[
466 ChannelAxis,
467 IndexInputAxis,
468 IndexOutputAxis,
469 TimeInputAxis,
470 SpaceInputAxis,
471 TimeOutputAxis,
472 TimeOutputAxisWithHalo,
473 SpaceOutputAxis,
474 SpaceOutputAxisWithHalo,
475 ],
476 ):
477 return axis.unit
480class AxisBase(NodeWithExplicitlySetFields):
481 id: AxisId
482 """An axis id unique across all axes of one tensor."""
484 description: Annotated[str, MaxLen(128)] = ""
485 """A short description of this axis beyond its type and id."""
488class WithHalo(Node):
489 halo: Annotated[int, Ge(1)]
490 """The halo should be cropped from the output tensor to avoid boundary effects.
491 It is to be cropped from both sides, i.e. `size_after_crop = size - 2 * halo`.
492 To document a halo that is already cropped by the model use `size.offset` instead."""
494 size: Annotated[
495 SizeReference,
496 Field(
497 examples=[
498 10,
499 SizeReference(
500 tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
501 ).model_dump(mode="json"),
502 ]
503 ),
504 ]
505 """reference to another axis with an optional offset (see `SizeReference`)"""
508BATCH_AXIS_ID = AxisId("batch")
511class BatchAxis(AxisBase):
512 implemented_type: ClassVar[Literal["batch"]] = "batch"
513 if TYPE_CHECKING:
514 type: Literal["batch"] = "batch"
515 else:
516 type: Literal["batch"]
518 id: Annotated[AxisId, Predicate(_is_batch)] = BATCH_AXIS_ID
519 size: Optional[Literal[1]] = None
520 """The batch size may be fixed to 1,
521 otherwise (the default) it may be chosen arbitrarily depending on available memory"""
523 @property
524 def scale(self):
525 return 1.0
527 @property
528 def concatenable(self):
529 return True
531 @property
532 def unit(self):
533 return None
536class ChannelAxis(AxisBase):
537 implemented_type: ClassVar[Literal["channel"]] = "channel"
538 if TYPE_CHECKING:
539 type: Literal["channel"] = "channel"
540 else:
541 type: Literal["channel"]
543 id: NonBatchAxisId = AxisId("channel")
545 channel_names: NotEmpty[List[Identifier]]
547 @property
548 def size(self) -> int:
549 return len(self.channel_names)
551 @property
552 def concatenable(self):
553 return False
555 @property
556 def scale(self) -> float:
557 return 1.0
559 @property
560 def unit(self):
561 return None
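# A minimal sketch: a `ChannelAxis` derives its `size` from `channel_names`
# (its id defaults to "channel" and it is never concatenable).
_rgb = ChannelAxis(channel_names=[Identifier("r"), Identifier("g"), Identifier("b")])
assert _rgb.size == 3 and not _rgb.concatenable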
564class IndexAxisBase(AxisBase):
565 implemented_type: ClassVar[Literal["index"]] = "index"
566 if TYPE_CHECKING:
567 type: Literal["index"] = "index"
568 else:
569 type: Literal["index"]
571 id: NonBatchAxisId = AxisId("index")
573 @property
574 def scale(self) -> float:
575 return 1.0
577 @property
578 def unit(self):
579 return None
582class _WithInputAxisSize(Node):
583 size: Annotated[
584 Union[Annotated[int, Gt(0)], ParameterizedSize, SizeReference],
585 Field(
586 examples=[
587 10,
588 ParameterizedSize(min=32, step=16).model_dump(mode="json"),
589 SizeReference(
590 tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
591 ).model_dump(mode="json"),
592 ]
593 ),
594 ]
595 """The size/length of this axis can be specified as
596 - fixed integer
597 - parameterized series of valid sizes (`ParameterizedSize`)
598 - reference to another axis with an optional offset (`SizeReference`)
599 """
602class IndexInputAxis(IndexAxisBase, _WithInputAxisSize):
603 concatenable: bool = False
604 """If a model has a `concatenable` input axis, it can be processed blockwise,
605 splitting a longer sample axis into blocks matching its input tensor description.
606 Output axes are concatenable if they have a `SizeReference` to a concatenable
607 input axis.
608 """
611class IndexOutputAxis(IndexAxisBase):
612 size: Annotated[
613 Union[Annotated[int, Gt(0)], SizeReference, DataDependentSize],
614 Field(
615 examples=[
616 10,
617 SizeReference(
618 tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
619 ).model_dump(mode="json"),
620 ]
621 ),
622 ]
623 """The size/length of this axis can be specified as
624 - fixed integer
625 - reference to another axis with an optional offset (`SizeReference`)
626 - data dependent size using `DataDependentSize` (size is only known after model inference)
627 """
630class TimeAxisBase(AxisBase):
631 implemented_type: ClassVar[Literal["time"]] = "time"
632 if TYPE_CHECKING:
633 type: Literal["time"] = "time"
634 else:
635 type: Literal["time"]
637 id: NonBatchAxisId = AxisId("time")
638 unit: Optional[TimeUnit] = None
639 scale: Annotated[float, Gt(0)] = 1.0
642class TimeInputAxis(TimeAxisBase, _WithInputAxisSize):
643 concatenable: bool = False
644 """If a model has a `concatenable` input axis, it can be processed blockwise,
645 splitting a longer sample axis into blocks matching its input tensor description.
646 Output axes are concatenable if they have a `SizeReference` to a concatenable
647 input axis.
648 """
651class SpaceAxisBase(AxisBase):
652 implemented_type: ClassVar[Literal["space"]] = "space"
653 if TYPE_CHECKING:
654 type: Literal["space"] = "space"
655 else:
656 type: Literal["space"]
658 id: Annotated[NonBatchAxisId, Field(examples=["x", "y", "z"])] = AxisId("x")
659 unit: Optional[SpaceUnit] = None
660 scale: Annotated[float, Gt(0)] = 1.0
663class SpaceInputAxis(SpaceAxisBase, _WithInputAxisSize):
664 concatenable: bool = False
665 """If a model has a `concatenable` input axis, it can be processed blockwise,
666 splitting a longer sample axis into blocks matching its input tensor description.
667 Output axes are concatenable if they have a `SizeReference` to a concatenable
668 input axis.
669 """
672INPUT_AXIS_TYPES = (
673 BatchAxis,
674 ChannelAxis,
675 IndexInputAxis,
676 TimeInputAxis,
677 SpaceInputAxis,
678)
679"""intended for isinstance comparisons in py<3.10"""
681_InputAxisUnion = Union[
682 BatchAxis, ChannelAxis, IndexInputAxis, TimeInputAxis, SpaceInputAxis
683]
684InputAxis = Annotated[_InputAxisUnion, Discriminator("type")]
687class _WithOutputAxisSize(Node):
688 size: Annotated[
689 Union[Annotated[int, Gt(0)], SizeReference],
690 Field(
691 examples=[
692 10,
693 SizeReference(
694 tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
695 ).model_dump(mode="json"),
696 ]
697 ),
698 ]
699 """The size/length of this axis can be specified as
700 - fixed integer
701 - reference to another axis with an optional offset (see `SizeReference`)
702 """
705class TimeOutputAxis(TimeAxisBase, _WithOutputAxisSize):
706 pass
709class TimeOutputAxisWithHalo(TimeAxisBase, WithHalo):
710 pass
713def _get_halo_axis_discriminator_value(v: Any) -> Literal["with_halo", "wo_halo"]:
714 if isinstance(v, dict):
715 return "with_halo" if "halo" in v else "wo_halo"
716 else:
717 return "with_halo" if hasattr(v, "halo") else "wo_halo"
720_TimeOutputAxisUnion = Annotated[
721 Union[
722 Annotated[TimeOutputAxis, Tag("wo_halo")],
723 Annotated[TimeOutputAxisWithHalo, Tag("with_halo")],
724 ],
725 Discriminator(_get_halo_axis_discriminator_value),
726]
729class SpaceOutputAxis(SpaceAxisBase, _WithOutputAxisSize):
730 pass
733class SpaceOutputAxisWithHalo(SpaceAxisBase, WithHalo):
734 pass
737_SpaceOutputAxisUnion = Annotated[
738 Union[
739 Annotated[SpaceOutputAxis, Tag("wo_halo")],
740 Annotated[SpaceOutputAxisWithHalo, Tag("with_halo")],
741 ],
742 Discriminator(_get_halo_axis_discriminator_value),
743]
746_OutputAxisUnion = Union[
747 BatchAxis, ChannelAxis, IndexOutputAxis, _TimeOutputAxisUnion, _SpaceOutputAxisUnion
748]
749OutputAxis = Annotated[_OutputAxisUnion, Discriminator("type")]
751OUTPUT_AXIS_TYPES = (
752 BatchAxis,
753 ChannelAxis,
754 IndexOutputAxis,
755 TimeOutputAxis,
756 TimeOutputAxisWithHalo,
757 SpaceOutputAxis,
758 SpaceOutputAxisWithHalo,
759)
760"""intended for isinstance comparisons in py<3.10"""
763AnyAxis = Union[InputAxis, OutputAxis]
765ANY_AXIS_TYPES = INPUT_AXIS_TYPES + OUTPUT_AXIS_TYPES
766"""intended for isinstance comparisons in py<3.10"""
768TVs = Union[
769 NotEmpty[List[int]],
770 NotEmpty[List[float]],
771 NotEmpty[List[bool]],
772 NotEmpty[List[str]],
773]
776NominalOrOrdinalDType = Literal[
777 "float32",
778 "float64",
779 "uint8",
780 "int8",
781 "uint16",
782 "int16",
783 "uint32",
784 "int32",
785 "uint64",
786 "int64",
787 "bool",
788]
791class NominalOrOrdinalDataDescr(Node):
792 values: TVs
793 """A fixed set of nominal or an ascending sequence of ordinal values.
 794 In this case `data.type` is required to be an unsigned integer type, e.g. 'uint8'.
795 String `values` are interpreted as labels for tensor values 0, ..., N.
796 Note: as YAML 1.2 does not natively support a "set" datatype,
797 nominal values should be given as a sequence (aka list/array) as well.
798 """
800 type: Annotated[
801 NominalOrOrdinalDType,
802 Field(
803 examples=[
804 "float32",
805 "uint8",
806 "uint16",
807 "int64",
808 "bool",
809 ],
810 ),
811 ] = "uint8"
813 @model_validator(mode="after")
814 def _validate_values_match_type(
815 self,
816 ) -> Self:
817 incompatible: List[Any] = []
818 for v in self.values:
819 if self.type == "bool":
820 if not isinstance(v, bool):
821 incompatible.append(v)
822 elif self.type in DTYPE_LIMITS:
823 if (
824 isinstance(v, (int, float))
825 and (
826 v < DTYPE_LIMITS[self.type].min
827 or v > DTYPE_LIMITS[self.type].max
828 )
829 or (isinstance(v, str) and "uint" not in self.type)
830 or (isinstance(v, float) and "int" in self.type)
831 ):
832 incompatible.append(v)
833 else:
834 incompatible.append(v)
836 if len(incompatible) == 5:
837 incompatible.append("...")
838 break
840 if incompatible:
841 raise ValueError(
842 f"data type '{self.type}' incompatible with values {incompatible}"
843 )
845 return self
847 unit: Optional[Union[Literal["arbitrary unit"], SiUnit]] = None
849 @property
850 def range(self):
851 if isinstance(self.values[0], str):
852 return 0, len(self.values) - 1
853 else:
854 return min(self.values), max(self.values)
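# A minimal sketch of `NominalOrOrdinalDataDescr`: string values act as labels
# for tensor values 0, ..., N, so `range` spans those indices; numeric values
# would report their min/max instead.
_labels = NominalOrOrdinalDataDescr(values=["background", "cell", "membrane"], type="uint8")
assert _labels.range == (0, 2)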
857IntervalOrRatioDType = Literal[
858 "float32",
859 "float64",
860 "uint8",
861 "int8",
862 "uint16",
863 "int16",
864 "uint32",
865 "int32",
866 "uint64",
867 "int64",
868]
871class IntervalOrRatioDataDescr(Node):
872 type: Annotated[ # TODO: rename to dtype
873 IntervalOrRatioDType,
874 Field(
875 examples=["float32", "float64", "uint8", "uint16"],
876 ),
877 ] = "float32"
878 range: Tuple[Optional[float], Optional[float]] = (
879 None,
880 None,
881 )
882 """Tuple `(minimum, maximum)` specifying the allowed range of the data in this tensor.
883 `None` corresponds to min/max of what can be expressed by **type**."""
884 unit: Union[Literal["arbitrary unit"], SiUnit] = "arbitrary unit"
885 scale: float = 1.0
886 """Scale for data on an interval (or ratio) scale."""
887 offset: Optional[float] = None
888 """Offset for data on a ratio scale."""
890 @model_validator(mode="before")
891 def _replace_inf(cls, data: Any):
892 if is_dict(data):
893 if "range" in data and is_sequence(data["range"]):
894 forbidden = (
895 "inf",
896 "-inf",
897 ".inf",
898 "-.inf",
899 float("inf"),
900 float("-inf"),
901 )
902 if any(v in forbidden for v in data["range"]):
903 issue_warning("replaced 'inf' value", value=data["range"])
905 data["range"] = tuple(
906 (None if v in forbidden else v) for v in data["range"]
907 )
909 return data
912TensorDataDescr = Union[NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr]
915class ProcessingDescrBase(NodeWithExplicitlySetFields, ABC):
916 """processing base class"""
919class BinarizeKwargs(ProcessingKwargs):
920 """key word arguments for `BinarizeDescr`"""
922 threshold: float
923 """The fixed threshold"""
926class BinarizeAlongAxisKwargs(ProcessingKwargs):
927 """key word arguments for `BinarizeDescr`"""
929 threshold: NotEmpty[List[float]]
930 """The fixed threshold values along `axis`"""
932 axis: Annotated[NonBatchAxisId, Field(examples=["channel"])]
933 """The `threshold` axis"""
936class BinarizeDescr(ProcessingDescrBase):
937 """Binarize the tensor with a fixed threshold.
939 Values above `BinarizeKwargs.threshold`/`BinarizeAlongAxisKwargs.threshold`
940 will be set to one, values below the threshold to zero.
942 Examples:
943 - in YAML
944 ```yaml
945 postprocessing:
946 - id: binarize
947 kwargs:
948 axis: 'channel'
949 threshold: [0.25, 0.5, 0.75]
950 ```
951 - in Python:
952 >>> postprocessing = [BinarizeDescr(
953 ... kwargs=BinarizeAlongAxisKwargs(
954 ... axis=AxisId('channel'),
955 ... threshold=[0.25, 0.5, 0.75],
956 ... )
957 ... )]
958 """
960 implemented_id: ClassVar[Literal["binarize"]] = "binarize"
961 if TYPE_CHECKING:
962 id: Literal["binarize"] = "binarize"
963 else:
964 id: Literal["binarize"]
965 kwargs: Union[BinarizeKwargs, BinarizeAlongAxisKwargs]
968class ClipDescr(ProcessingDescrBase):
969 """Set tensor values below min to min and above max to max.
971 See `ScaleRangeDescr` for examples.
972 """
974 implemented_id: ClassVar[Literal["clip"]] = "clip"
975 if TYPE_CHECKING:
976 id: Literal["clip"] = "clip"
977 else:
978 id: Literal["clip"]
980 kwargs: ClipKwargs
983class EnsureDtypeKwargs(ProcessingKwargs):
984 """key word arguments for `EnsureDtypeDescr`"""
986 dtype: Literal[
987 "float32",
988 "float64",
989 "uint8",
990 "int8",
991 "uint16",
992 "int16",
993 "uint32",
994 "int32",
995 "uint64",
996 "int64",
997 "bool",
998 ]
1001class EnsureDtypeDescr(ProcessingDescrBase):
1002 """Cast the tensor data type to `EnsureDtypeKwargs.dtype` (if not matching).
 1004 For example, this can be used to ensure that the inner neural network receives a
 1005 different input tensor data type than the fully described bioimage.io model accepts.
1007 Examples:
1008 The described bioimage.io model (incl. preprocessing) accepts any
1009 float32-compatible tensor, normalizes it with percentiles and clipping and then
1010 casts it to uint8, which is what the neural network in this example expects.
1011 - in YAML
1012 ```yaml
1013 inputs:
1014 - data:
1015 type: float32 # described bioimage.io model is compatible with any float32 input tensor
1016 preprocessing:
1017 - id: scale_range
1018 kwargs:
1019 axes: ['y', 'x']
1020 max_percentile: 99.8
1021 min_percentile: 5.0
1022 - id: clip
1023 kwargs:
1024 min: 0.0
1025 max: 1.0
1026 - id: ensure_dtype # the neural network of the model requires uint8
1027 kwargs:
1028 dtype: uint8
1029 ```
1030 - in Python:
1031 >>> preprocessing = [
1032 ... ScaleRangeDescr(
1033 ... kwargs=ScaleRangeKwargs(
1034 ... axes= (AxisId('y'), AxisId('x')),
1035 ... max_percentile= 99.8,
1036 ... min_percentile= 5.0,
1037 ... )
1038 ... ),
1039 ... ClipDescr(kwargs=ClipKwargs(min=0.0, max=1.0)),
1040 ... EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="uint8")),
1041 ... ]
1042 """
1044 implemented_id: ClassVar[Literal["ensure_dtype"]] = "ensure_dtype"
1045 if TYPE_CHECKING:
1046 id: Literal["ensure_dtype"] = "ensure_dtype"
1047 else:
1048 id: Literal["ensure_dtype"]
1050 kwargs: EnsureDtypeKwargs
1053class ScaleLinearKwargs(ProcessingKwargs):
1054 """Key word arguments for `ScaleLinearDescr`"""
1056 gain: float = 1.0
1057 """multiplicative factor"""
1059 offset: float = 0.0
1060 """additive term"""
1062 @model_validator(mode="after")
1063 def _validate(self) -> Self:
1064 if self.gain == 1.0 and self.offset == 0.0:
1065 raise ValueError(
1066 "Redundant linear scaling not allowd. Set `gain` != 1.0 and/or `offset`"
1067 + " != 0.0."
1068 )
1070 return self
1073class ScaleLinearAlongAxisKwargs(ProcessingKwargs):
1074 """Key word arguments for `ScaleLinearDescr`"""
1076 axis: Annotated[NonBatchAxisId, Field(examples=["channel"])]
1077 """The axis of gain and offset values."""
1079 gain: Union[float, NotEmpty[List[float]]] = 1.0
1080 """multiplicative factor"""
1082 offset: Union[float, NotEmpty[List[float]]] = 0.0
1083 """additive term"""
1085 @model_validator(mode="after")
1086 def _validate(self) -> Self:
1087 if isinstance(self.gain, list):
1088 if isinstance(self.offset, list):
1089 if len(self.gain) != len(self.offset):
1090 raise ValueError(
1091 f"Size of `gain` ({len(self.gain)}) and `offset` ({len(self.offset)}) must match."
1092 )
1093 else:
1094 self.offset = [float(self.offset)] * len(self.gain)
1095 elif isinstance(self.offset, list):
1096 self.gain = [float(self.gain)] * len(self.offset)
1097 else:
1098 raise ValueError(
1099 "Do not specify an `axis` for scalar gain and offset values."
1100 )
1102 if all(g == 1.0 for g in self.gain) and all(off == 0.0 for off in self.offset):
1103 raise ValueError(
1104 "Redundant linear scaling not allowd. Set `gain` != 1.0 and/or `offset`"
1105 + " != 0.0."
1106 )
1108 return self
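# A minimal sketch of the broadcasting performed by the validator above:
# a scalar `offset` is expanded to match a per-channel `gain` list.
_k = ScaleLinearAlongAxisKwargs(axis=AxisId("channel"), gain=[1.0, 2.0, 3.0], offset=0.5)
assert _k.offset == [0.5, 0.5, 0.5]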
1111class ScaleLinearDescr(ProcessingDescrBase):
1112 """Fixed linear scaling.
1114 Examples:
1115 1. Scale with scalar gain and offset
1116 - in YAML
1117 ```yaml
1118 preprocessing:
1119 - id: scale_linear
1120 kwargs:
1121 gain: 2.0
1122 offset: 3.0
1123 ```
1124 - in Python:
1125 >>> preprocessing = [
1126 ... ScaleLinearDescr(kwargs=ScaleLinearKwargs(gain= 2.0, offset=3.0))
1127 ... ]
1129 2. Independent scaling along an axis
1130 - in YAML
1131 ```yaml
1132 preprocessing:
1133 - id: scale_linear
1134 kwargs:
1135 axis: 'channel'
1136 gain: [1.0, 2.0, 3.0]
1137 ```
1138 - in Python:
1139 >>> preprocessing = [
1140 ... ScaleLinearDescr(
1141 ... kwargs=ScaleLinearAlongAxisKwargs(
1142 ... axis=AxisId("channel"),
1143 ... gain=[1.0, 2.0, 3.0],
1144 ... )
1145 ... )
1146 ... ]
1148 """
1150 implemented_id: ClassVar[Literal["scale_linear"]] = "scale_linear"
1151 if TYPE_CHECKING:
1152 id: Literal["scale_linear"] = "scale_linear"
1153 else:
1154 id: Literal["scale_linear"]
1155 kwargs: Union[ScaleLinearKwargs, ScaleLinearAlongAxisKwargs]
1158class SigmoidDescr(ProcessingDescrBase):
1159 """The logistic sigmoid function, a.k.a. expit function.
1161 Examples:
1162 - in YAML
1163 ```yaml
1164 postprocessing:
1165 - id: sigmoid
1166 ```
1167 - in Python:
1168 >>> postprocessing = [SigmoidDescr()]
1169 """
1171 implemented_id: ClassVar[Literal["sigmoid"]] = "sigmoid"
1172 if TYPE_CHECKING:
1173 id: Literal["sigmoid"] = "sigmoid"
1174 else:
1175 id: Literal["sigmoid"]
1177 @property
1178 def kwargs(self) -> ProcessingKwargs:
1179 """empty kwargs"""
1180 return ProcessingKwargs()
1183class SoftmaxKwargs(ProcessingKwargs):
1184 """key word arguments for `SoftmaxDescr`"""
1186 axis: Annotated[NonBatchAxisId, Field(examples=["channel"])] = AxisId("channel")
1187 """The axis to apply the softmax function along.
 1188 Note:
 1189 Defaults to the 'channel' axis,
 1190 which may not exist; in that case
 1191 a different axis id has to be specified.
1192 """
1195class SoftmaxDescr(ProcessingDescrBase):
1196 """The softmax function.
1198 Examples:
1199 - in YAML
1200 ```yaml
1201 postprocessing:
1202 - id: softmax
1203 kwargs:
1204 axis: channel
1205 ```
1206 - in Python:
1207 >>> postprocessing = [SoftmaxDescr(kwargs=SoftmaxKwargs(axis=AxisId("channel")))]
1208 """
1210 implemented_id: ClassVar[Literal["softmax"]] = "softmax"
1211 if TYPE_CHECKING:
1212 id: Literal["softmax"] = "softmax"
1213 else:
1214 id: Literal["softmax"]
1216 kwargs: SoftmaxKwargs = Field(default_factory=SoftmaxKwargs.model_construct)
1219class FixedZeroMeanUnitVarianceKwargs(ProcessingKwargs):
1220 """key word arguments for `FixedZeroMeanUnitVarianceDescr`"""
1222 mean: float
1223 """The mean value to normalize with."""
1225 std: Annotated[float, Ge(1e-6)]
1226 """The standard deviation value to normalize with."""
1229class FixedZeroMeanUnitVarianceAlongAxisKwargs(ProcessingKwargs):
1230 """key word arguments for `FixedZeroMeanUnitVarianceDescr`"""
1232 mean: NotEmpty[List[float]]
1233 """The mean value(s) to normalize with."""
1235 std: NotEmpty[List[Annotated[float, Ge(1e-6)]]]
1236 """The standard deviation value(s) to normalize with.
1237 Size must match `mean` values."""
1239 axis: Annotated[NonBatchAxisId, Field(examples=["channel", "index"])]
1240 """The axis of the mean/std values to normalize each entry along that dimension
1241 separately."""
1243 @model_validator(mode="after")
1244 def _mean_and_std_match(self) -> Self:
1245 if len(self.mean) != len(self.std):
1246 raise ValueError(
1247 f"Size of `mean` ({len(self.mean)}) and `std` ({len(self.std)})"
1248 + " must match."
1249 )
1251 return self
1254class FixedZeroMeanUnitVarianceDescr(ProcessingDescrBase):
1255 """Subtract a given mean and divide by the standard deviation.
1257 Normalize with fixed, precomputed values for
 1258 `FixedZeroMeanUnitVarianceKwargs.mean` and `FixedZeroMeanUnitVarianceKwargs.std`.
1259 Use `FixedZeroMeanUnitVarianceAlongAxisKwargs` for independent scaling along given
1260 axes.
1262 Examples:
1263 1. scalar value for whole tensor
1264 - in YAML
1265 ```yaml
1266 preprocessing:
1267 - id: fixed_zero_mean_unit_variance
1268 kwargs:
1269 mean: 103.5
1270 std: 13.7
1271 ```
1272 - in Python
1273 >>> preprocessing = [FixedZeroMeanUnitVarianceDescr(
1274 ... kwargs=FixedZeroMeanUnitVarianceKwargs(mean=103.5, std=13.7)
1275 ... )]
1277 2. independently along an axis
1278 - in YAML
1279 ```yaml
1280 preprocessing:
1281 - id: fixed_zero_mean_unit_variance
1282 kwargs:
1283 axis: channel
1284 mean: [101.5, 102.5, 103.5]
1285 std: [11.7, 12.7, 13.7]
1286 ```
1287 - in Python
1288 >>> preprocessing = [FixedZeroMeanUnitVarianceDescr(
1289 ... kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs(
1290 ... axis=AxisId("channel"),
1291 ... mean=[101.5, 102.5, 103.5],
1292 ... std=[11.7, 12.7, 13.7],
1293 ... )
1294 ... )]
1295 """
1297 implemented_id: ClassVar[Literal["fixed_zero_mean_unit_variance"]] = (
1298 "fixed_zero_mean_unit_variance"
1299 )
1300 if TYPE_CHECKING:
1301 id: Literal["fixed_zero_mean_unit_variance"] = "fixed_zero_mean_unit_variance"
1302 else:
1303 id: Literal["fixed_zero_mean_unit_variance"]
1305 kwargs: Union[
1306 FixedZeroMeanUnitVarianceKwargs, FixedZeroMeanUnitVarianceAlongAxisKwargs
1307 ]
1310class ZeroMeanUnitVarianceKwargs(ProcessingKwargs):
1311 """key word arguments for `ZeroMeanUnitVarianceDescr`"""
1313 axes: Annotated[
1314 Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
1315 ] = None
1316 """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std.
1317 For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
1318 resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
1319 To normalize each sample independently leave out the 'batch' axis.
1320 Default: Scale all axes jointly."""
1322 eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
1323 """epsilon for numeric stability: `out = (tensor - mean) / (std + eps)`."""
1326class ZeroMeanUnitVarianceDescr(ProcessingDescrBase):
1327 """Subtract mean and divide by variance.
1329 Examples:
 1330 Subtract the tensor mean and divide by its standard deviation
1331 - in YAML
1332 ```yaml
1333 preprocessing:
1334 - id: zero_mean_unit_variance
1335 ```
1336 - in Python
1337 >>> preprocessing = [ZeroMeanUnitVarianceDescr()]
1338 """
1340 implemented_id: ClassVar[Literal["zero_mean_unit_variance"]] = (
1341 "zero_mean_unit_variance"
1342 )
1343 if TYPE_CHECKING:
1344 id: Literal["zero_mean_unit_variance"] = "zero_mean_unit_variance"
1345 else:
1346 id: Literal["zero_mean_unit_variance"]
1348 kwargs: ZeroMeanUnitVarianceKwargs = Field(
1349 default_factory=ZeroMeanUnitVarianceKwargs.model_construct
1350 )
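# A minimal sketch of per-channel normalization as described above: reducing over
# 'batch', 'y' and 'x' of a ('batch', 'channel', 'y', 'x') tensor normalizes each
# channel independently.
_zmuv = ZeroMeanUnitVarianceDescr(
    kwargs=ZeroMeanUnitVarianceKwargs(axes=(AxisId("batch"), AxisId("y"), AxisId("x")))
)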
1353class ScaleRangeKwargs(ProcessingKwargs):
1354 """key word arguments for `ScaleRangeDescr`
1356 For `min_percentile`=0.0 (the default) and `max_percentile`=100 (the default)
1357 this processing step normalizes data to the [0, 1] intervall.
1358 For other percentiles the normalized values will partially be outside the [0, 1]
1359 intervall. Use `ScaleRange` followed by `ClipDescr` if you want to limit the
1360 normalized values to a range.
1361 """
1363 axes: Annotated[
1364 Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
1365 ] = None
1366 """The subset of axes to normalize jointly, i.e. axes to reduce to compute the min/max percentile value.
1367 For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
1368 resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
1369 To normalize samples independently, leave out the "batch" axis.
1370 Default: Scale all axes jointly."""
1372 min_percentile: Annotated[float, Interval(ge=0, lt=100)] = 0.0
1373 """The lower percentile used to determine the value to align with zero."""
1375 max_percentile: Annotated[float, Interval(gt=1, le=100)] = 100.0
1376 """The upper percentile used to determine the value to align with one.
1377 Has to be bigger than `min_percentile`.
1378 The range is 1 to 100 instead of 0 to 100 to avoid mistakenly
1379 accepting percentiles specified in the range 0.0 to 1.0."""
1381 eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
1382 """Epsilon for numeric stability.
1383 `out = (tensor - v_lower) / (v_upper - v_lower + eps)`;
1384 with `v_lower,v_upper` values at the respective percentiles."""
1386 reference_tensor: Optional[TensorId] = None
1387 """Tensor ID to compute the percentiles from. Default: The tensor itself.
1388 For any tensor in `inputs` only input tensor references are allowed."""
1390 @field_validator("max_percentile", mode="after")
1391 @classmethod
1392 def min_smaller_max(cls, value: float, info: ValidationInfo) -> float:
1393 if (min_p := info.data["min_percentile"]) >= value:
1394 raise ValueError(f"min_percentile {min_p} >= max_percentile {value}")
1396 return value
1399class ScaleRangeDescr(ProcessingDescrBase):
1400 """Scale with percentiles.
1402 Examples:
1403 1. Scale linearly to map 5th percentile to 0 and 99.8th percentile to 1.0
1404 - in YAML
1405 ```yaml
1406 preprocessing:
1407 - id: scale_range
1408 kwargs:
1409 axes: ['y', 'x']
1410 max_percentile: 99.8
1411 min_percentile: 5.0
1412 ```
1413 - in Python
1414 >>> preprocessing = [
1415 ... ScaleRangeDescr(
1416 ... kwargs=ScaleRangeKwargs(
1417 ... axes= (AxisId('y'), AxisId('x')),
1418 ... max_percentile= 99.8,
1419 ... min_percentile= 5.0,
1420 ... )
1421 ... )
1422 ... ]
1424 2. Combine the above scaling with additional clipping to clip values outside the range given by the percentiles.
1425 - in YAML
1426 ```yaml
1427 preprocessing:
1428 - id: scale_range
1429 kwargs:
1430 axes: ['y', 'x']
1431 max_percentile: 99.8
1432 min_percentile: 5.0
1434 - id: clip
1435 kwargs:
1436 min: 0.0
1437 max: 1.0
1438 ```
1439 - in Python
1440 >>> preprocessing = [
1441 ... ScaleRangeDescr(
1442 ... kwargs=ScaleRangeKwargs(
1443 ... axes= (AxisId('y'), AxisId('x')),
1444 ... max_percentile= 99.8,
1445 ... min_percentile= 5.0,
1446 ... )
1447 ... ),
1448 ... ClipDescr(
1449 ... kwargs=ClipKwargs(
1450 ... min=0.0,
1451 ... max=1.0,
1452 ... )
1453 ... ),
1454 ... ]
1456 """
1458 implemented_id: ClassVar[Literal["scale_range"]] = "scale_range"
1459 if TYPE_CHECKING:
1460 id: Literal["scale_range"] = "scale_range"
1461 else:
1462 id: Literal["scale_range"]
1463 kwargs: ScaleRangeKwargs = Field(default_factory=ScaleRangeKwargs.model_construct)
1466class ScaleMeanVarianceKwargs(ProcessingKwargs):
1467 """key word arguments for `ScaleMeanVarianceKwargs`"""
1469 reference_tensor: TensorId
1470 """Name of tensor to match."""
1472 axes: Annotated[
1473 Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
1474 ] = None
1475 """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std.
1476 For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
1477 resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
1478 To normalize samples independently, leave out the 'batch' axis.
1479 Default: Scale all axes jointly."""
1481 eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
1482 """Epsilon for numeric stability:
1483 `out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.`"""
1486class ScaleMeanVarianceDescr(ProcessingDescrBase):
1487 """Scale a tensor's data distribution to match another tensor's mean/std.
1488 `out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.`
1489 """
1491 implemented_id: ClassVar[Literal["scale_mean_variance"]] = "scale_mean_variance"
1492 if TYPE_CHECKING:
1493 id: Literal["scale_mean_variance"] = "scale_mean_variance"
1494 else:
1495 id: Literal["scale_mean_variance"]
1496 kwargs: ScaleMeanVarianceKwargs
1499PreprocessingDescr = Annotated[
1500 Union[
1501 BinarizeDescr,
1502 ClipDescr,
1503 EnsureDtypeDescr,
1504 FixedZeroMeanUnitVarianceDescr,
1505 ScaleLinearDescr,
1506 ScaleRangeDescr,
1507 SigmoidDescr,
1508 SoftmaxDescr,
1509 ZeroMeanUnitVarianceDescr,
1510 ],
1511 Discriminator("id"),
1512]
1513PostprocessingDescr = Annotated[
1514 Union[
1515 BinarizeDescr,
1516 ClipDescr,
1517 EnsureDtypeDescr,
1518 FixedZeroMeanUnitVarianceDescr,
1519 ScaleLinearDescr,
1520 ScaleMeanVarianceDescr,
1521 ScaleRangeDescr,
1522 SigmoidDescr,
1523 SoftmaxDescr,
1524 ZeroMeanUnitVarianceDescr,
1525 ],
1526 Discriminator("id"),
1527]
1529IO_AxisT = TypeVar("IO_AxisT", InputAxis, OutputAxis)
1532class TensorDescrBase(Node, Generic[IO_AxisT]):
1533 id: TensorId
1534 """Tensor id. No duplicates are allowed."""
1536 description: Annotated[str, MaxLen(128)] = ""
1537 """free text description"""
1539 axes: NotEmpty[Sequence[IO_AxisT]]
1540 """tensor axes"""
1542 @property
1543 def shape(self):
1544 return tuple(a.size for a in self.axes)
1546 @field_validator("axes", mode="after", check_fields=False)
1547 @classmethod
1548 def _validate_axes(cls, axes: Sequence[AnyAxis]) -> Sequence[AnyAxis]:
1549 batch_axes = [a for a in axes if a.type == "batch"]
1550 if len(batch_axes) > 1:
1551 raise ValueError(
1552 f"Only one batch axis (per tensor) allowed, but got {batch_axes}"
1553 )
1555 seen_ids: Set[AxisId] = set()
1556 duplicate_axes_ids: Set[AxisId] = set()
1557 for a in axes:
1558 (duplicate_axes_ids if a.id in seen_ids else seen_ids).add(a.id)
1560 if duplicate_axes_ids:
1561 raise ValueError(f"Duplicate axis ids: {duplicate_axes_ids}")
1563 return axes
1565 test_tensor: FAIR[Optional[FileDescr_]] = None
1566 """An example tensor to use for testing.
1567 Using the model with the test input tensors is expected to yield the test output tensors.
 1568 Each test tensor has to be an ndarray in the
1569 [numpy.lib file format](https://numpy.org/doc/stable/reference/generated/numpy.lib.format.html#module-numpy.lib.format).
1570 The file extension must be '.npy'."""
1572 sample_tensor: FAIR[Optional[FileDescr_]] = None
1573 """A sample tensor to illustrate a possible input/output for the model,
1574 The sample image primarily serves to inform a human user about an example use case
1575 and is typically stored as .hdf5, .png or .tiff.
1576 It has to be readable by the [imageio library](https://imageio.readthedocs.io/en/stable/formats/index.html#supported-formats)
1577 (numpy's `.npy` format is not supported).
1578 The image dimensionality has to match the number of axes specified in this tensor description.
1579 """
1581 @model_validator(mode="after")
1582 def _validate_sample_tensor(self) -> Self:
1583 if self.sample_tensor is None or not get_validation_context().perform_io_checks:
1584 return self
1586 reader = get_reader(self.sample_tensor.source, sha256=self.sample_tensor.sha256)
1587 tensor: NDArray[Any] = imread( # pyright: ignore[reportUnknownVariableType]
1588 reader.read(),
1589 extension=PurePosixPath(reader.original_file_name).suffix,
1590 )
1591 n_dims = len(tensor.squeeze().shape)
1592 n_dims_min = n_dims_max = len(self.axes)
1594 for a in self.axes:
1595 if isinstance(a, BatchAxis):
1596 n_dims_min -= 1
1597 elif isinstance(a.size, int):
1598 if a.size == 1:
1599 n_dims_min -= 1
1600 elif isinstance(a.size, (ParameterizedSize, DataDependentSize)):
1601 if a.size.min == 1:
1602 n_dims_min -= 1
1603 elif isinstance(a.size, SizeReference):
1604 if a.size.offset < 2:
1605 # size reference may result in singleton axis
1606 n_dims_min -= 1
1607 else:
1608 assert_never(a.size)
1610 n_dims_min = max(0, n_dims_min)
1611 if n_dims < n_dims_min or n_dims > n_dims_max:
1612 raise ValueError(
1613 f"Expected sample tensor to have {n_dims_min} to"
1614 + f" {n_dims_max} dimensions, but found {n_dims} (shape: {tensor.shape})."
1615 )
1617 return self
1619 data: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]] = (
1620 IntervalOrRatioDataDescr()
1621 )
1622 """Description of the tensor's data values, optionally per channel.
1623 If specified per channel, the data `type` needs to match across channels."""
1625 @property
1626 def dtype(
1627 self,
1628 ) -> Literal[
1629 "float32",
1630 "float64",
1631 "uint8",
1632 "int8",
1633 "uint16",
1634 "int16",
1635 "uint32",
1636 "int32",
1637 "uint64",
1638 "int64",
1639 "bool",
1640 ]:
1641 """dtype as specified under `data.type` or `data[i].type`"""
1642 if isinstance(self.data, collections.abc.Sequence):
1643 return self.data[0].type
1644 else:
1645 return self.data.type
1647 @field_validator("data", mode="after")
1648 @classmethod
1649 def _check_data_type_across_channels(
1650 cls, value: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]]
1651 ) -> Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]]:
1652 if not isinstance(value, list):
1653 return value
1655 dtypes = {t.type for t in value}
1656 if len(dtypes) > 1:
1657 raise ValueError(
1658 "Tensor data descriptions per channel need to agree in their data"
1659 + f" `type`, but found {dtypes}."
1660 )
1662 return value
1664 @model_validator(mode="after")
1665 def _check_data_matches_channelaxis(self) -> Self:
1666 if not isinstance(self.data, (list, tuple)):
1667 return self
1669 for a in self.axes:
1670 if isinstance(a, ChannelAxis):
1671 size = a.size
1672 assert isinstance(size, int)
1673 break
1674 else:
1675 return self
1677 if len(self.data) != size:
1678 raise ValueError(
1679 f"Got tensor data descriptions for {len(self.data)} channels, but"
1680 + f" '{a.id}' axis has size {size}."
1681 )
1683 return self
1685 def get_axis_sizes_for_array(self, array: NDArray[Any]) -> Dict[AxisId, int]:
1686 if len(array.shape) != len(self.axes):
1687 raise ValueError(
1688 f"Dimension mismatch: array shape {array.shape} (#{len(array.shape)})"
1689 + f" incompatible with {len(self.axes)} axes."
1690 )
1691 return {a.id: array.shape[i] for i, a in enumerate(self.axes)}
1694class InputTensorDescr(TensorDescrBase[InputAxis]):
1695 id: TensorId = TensorId("input")
1696 """Input tensor id.
1697 No duplicates are allowed across all inputs and outputs."""
1699 optional: bool = False
1700 """indicates that this tensor may be `None`"""
1702 preprocessing: List[PreprocessingDescr] = Field(
1703 default_factory=cast(Callable[[], List[PreprocessingDescr]], list)
1704 )
1706 """Description of how this input should be preprocessed.
1708 notes:
1709 - If preprocessing does not start with an 'ensure_dtype' entry, it is added
1710 to ensure an input tensor's data type matches the input tensor's data description.
1711 - If preprocessing does not end with an 'ensure_dtype' or 'binarize' entry, an
1712 'ensure_dtype' step is added to ensure preprocessing steps are not unintentionally
1713 changing the data type.
1714 """
1716 @model_validator(mode="after")
1717 def _validate_preprocessing_kwargs(self) -> Self:
1718 axes_ids = [a.id for a in self.axes]
1719 for p in self.preprocessing:
1720 kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes")
1721 if kwargs_axes is None:
1722 continue
1724 if not isinstance(kwargs_axes, collections.abc.Sequence):
1725 raise ValueError(
1726 f"Expected `preprocessing.i.kwargs.axes` to be a sequence, but got {type(kwargs_axes)}"
1727 )
1729 if any(a not in axes_ids for a in kwargs_axes):
1730 raise ValueError(
1731 "`preprocessing.i.kwargs.axes` needs to be subset of axes ids"
1732 )
1734 if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)):
1735 dtype = self.data.type
1736 else:
1737 dtype = self.data[0].type
1739 # ensure `preprocessing` begins with `EnsureDtypeDescr`
1740 if not self.preprocessing or not isinstance(
1741 self.preprocessing[0], EnsureDtypeDescr
1742 ):
1743 self.preprocessing.insert(
1744 0, EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype))
1745 )
1747 # ensure `preprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr`
1748 if not isinstance(self.preprocessing[-1], (EnsureDtypeDescr, BinarizeDescr)):
1749 self.preprocessing.append(
1750 EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype))
1751 )
1753 return self
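# A minimal sketch (illustrative ids and sizes): how `get_axis_sizes_for_array`
# maps array dimensions to axis ids for a batched, single-channel 2D input.
# Note that the validator above wraps an empty `preprocessing` list in an
# 'ensure_dtype' step matching the (default float32) data description.
_inp = InputTensorDescr(
    id=TensorId("raw"),
    axes=[
        BatchAxis(),
        ChannelAxis(channel_names=[Identifier("intensity")]),
        SpaceInputAxis(id=AxisId("y"), size=ParameterizedSize(min=32, step=16)),
        SpaceInputAxis(id=AxisId("x"), size=ParameterizedSize(min=32, step=16)),
    ],
)
assert _inp.get_axis_sizes_for_array(np.zeros((1, 1, 64, 48))) == {
    AxisId("batch"): 1,
    AxisId("channel"): 1,
    AxisId("y"): 64,
    AxisId("x"): 48,
}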
1756def convert_axes(
1757 axes: str,
1758 *,
1759 shape: Union[
1760 Sequence[int], _ParameterizedInputShape_v0_4, _ImplicitOutputShape_v0_4
1761 ],
1762 tensor_type: Literal["input", "output"],
1763 halo: Optional[Sequence[int]],
1764 size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]],
1765):
1766 ret: List[AnyAxis] = []
1767 for i, a in enumerate(axes):
1768 axis_type = _AXIS_TYPE_MAP.get(a, a)
1769 if axis_type == "batch":
1770 ret.append(BatchAxis())
1771 continue
1773 scale = 1.0
1774 if isinstance(shape, _ParameterizedInputShape_v0_4):
1775 if shape.step[i] == 0:
1776 size = shape.min[i]
1777 else:
1778 size = ParameterizedSize(min=shape.min[i], step=shape.step[i])
1779 elif isinstance(shape, _ImplicitOutputShape_v0_4):
1780 ref_t = str(shape.reference_tensor)
1781 if ref_t.count(".") == 1:
1782 t_id, orig_a_id = ref_t.split(".")
1783 else:
1784 t_id = ref_t
1785 orig_a_id = a
1787 a_id = _AXIS_ID_MAP.get(orig_a_id, a)
1788 if not (orig_scale := shape.scale[i]):
1789 # old way to insert a new axis dimension
1790 size = int(2 * shape.offset[i])
1791 else:
1792 scale = 1 / orig_scale
1793 if axis_type in ("channel", "index"):
1794 # these axes no longer have a scale
1795 offset_from_scale = orig_scale * size_refs.get(
1796 _TensorName_v0_4(t_id), {}
1797 ).get(orig_a_id, 0)
1798 else:
1799 offset_from_scale = 0
1800 size = SizeReference(
1801 tensor_id=TensorId(t_id),
1802 axis_id=AxisId(a_id),
1803 offset=int(offset_from_scale + 2 * shape.offset[i]),
1804 )
1805 else:
1806 size = shape[i]
1808 if axis_type == "time":
1809 if tensor_type == "input":
1810 ret.append(TimeInputAxis(size=size, scale=scale))
1811 else:
1812 assert not isinstance(size, ParameterizedSize)
1813 if halo is None:
1814 ret.append(TimeOutputAxis(size=size, scale=scale))
1815 else:
1816 assert not isinstance(size, int)
1817 ret.append(
1818 TimeOutputAxisWithHalo(size=size, scale=scale, halo=halo[i])
1819 )
1821 elif axis_type == "index":
1822 if tensor_type == "input":
1823 ret.append(IndexInputAxis(size=size))
1824 else:
1825 if isinstance(size, ParameterizedSize):
1826 size = DataDependentSize(min=size.min)
1828 ret.append(IndexOutputAxis(size=size))
1829 elif axis_type == "channel":
1830 assert not isinstance(size, ParameterizedSize)
1831 if isinstance(size, SizeReference):
1832 warnings.warn(
1833 "Conversion of channel size from an implicit output shape may be"
1834 + " wrong"
1835 )
1836 ret.append(
1837 ChannelAxis(
1838 channel_names=[
1839 Identifier(f"channel{i}") for i in range(size.offset)
1840 ]
1841 )
1842 )
1843 else:
1844 ret.append(
1845 ChannelAxis(
1846 channel_names=[Identifier(f"channel{i}") for i in range(size)]
1847 )
1848 )
1849 elif axis_type == "space":
1850 if tensor_type == "input":
1851 ret.append(SpaceInputAxis(id=AxisId(a), size=size, scale=scale))
1852 else:
1853 assert not isinstance(size, ParameterizedSize)
1854 if halo is None or halo[i] == 0:
1855 ret.append(SpaceOutputAxis(id=AxisId(a), size=size, scale=scale))
1856 elif isinstance(size, int):
1857 raise NotImplementedError(
1858 f"output axis with halo and fixed size (here {size}) not allowed"
1859 )
1860 else:
1861 ret.append(
1862 SpaceOutputAxisWithHalo(
1863 id=AxisId(a), size=size, scale=scale, halo=halo[i]
1864 )
1865 )
1867 return ret
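# A minimal sketch of `convert_axes` for a fixed-shape v0.4 input tensor
# ("bcyx" with shape [1, 3, 256, 256]): the batch and channel letters become
# dedicated axis descriptions and the spatial letters keep their ids.
_axes = convert_axes(
    "bcyx",
    shape=[1, 3, 256, 256],
    tensor_type="input",
    halo=None,
    size_refs={},
)
assert isinstance(_axes[0], BatchAxis)
assert isinstance(_axes[1], ChannelAxis) and _axes[1].size == 3
assert isinstance(_axes[2], SpaceInputAxis) and _axes[2].id == AxisId("y")
assert _axes[3].size == 256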
1870def _axes_letters_to_ids(
1871 axes: Optional[str],
1872) -> Optional[List[AxisId]]:
1873 if axes is None:
1874 return None
1876 return [AxisId(a) for a in axes]
1879def _get_complement_v04_axis(
1880 tensor_axes: Sequence[str], axes: Optional[Sequence[str]]
1881) -> Optional[AxisId]:
1882 if axes is None:
1883 return None
1885 non_complement_axes = set(axes) | {"b"}
1886 complement_axes = [a for a in tensor_axes if a not in non_complement_axes]
1887 if len(complement_axes) > 1:
1888 raise ValueError(
1889 f"Expected none or a single complement axis, but axes '{axes}' "
1890 + f"for tensor dims '{tensor_axes}' leave '{complement_axes}'."
1891 )
1893 return None if not complement_axes else AxisId(complement_axes[0])
1896def _convert_proc(
1897 p: Union[_PreprocessingDescr_v0_4, _PostprocessingDescr_v0_4],
1898 tensor_axes: Sequence[str],
1899) -> Union[PreprocessingDescr, PostprocessingDescr]:
1900 if isinstance(p, _BinarizeDescr_v0_4):
1901 return BinarizeDescr(kwargs=BinarizeKwargs(threshold=p.kwargs.threshold))
1902 elif isinstance(p, _ClipDescr_v0_4):
1903 return ClipDescr(kwargs=ClipKwargs(min=p.kwargs.min, max=p.kwargs.max))
1904 elif isinstance(p, _SigmoidDescr_v0_4):
1905 return SigmoidDescr()
1906 elif isinstance(p, _ScaleLinearDescr_v0_4):
1907 axes = _axes_letters_to_ids(p.kwargs.axes)
1908 if p.kwargs.axes is None:
1909 axis = None
1910 else:
1911 axis = _get_complement_v04_axis(tensor_axes, p.kwargs.axes)
1913 if axis is None:
1914 assert not isinstance(p.kwargs.gain, list)
1915 assert not isinstance(p.kwargs.offset, list)
1916 kwargs = ScaleLinearKwargs(gain=p.kwargs.gain, offset=p.kwargs.offset)
1917 else:
1918 kwargs = ScaleLinearAlongAxisKwargs(
1919 axis=axis, gain=p.kwargs.gain, offset=p.kwargs.offset
1920 )
1921 return ScaleLinearDescr(kwargs=kwargs)
1922 elif isinstance(p, _ScaleMeanVarianceDescr_v0_4):
1923 return ScaleMeanVarianceDescr(
1924 kwargs=ScaleMeanVarianceKwargs(
1925 axes=_axes_letters_to_ids(p.kwargs.axes),
1926 reference_tensor=TensorId(str(p.kwargs.reference_tensor)),
1927 eps=p.kwargs.eps,
1928 )
1929 )
1930 elif isinstance(p, _ZeroMeanUnitVarianceDescr_v0_4):
1931 if p.kwargs.mode == "fixed":
1932 mean = p.kwargs.mean
1933 std = p.kwargs.std
1934 assert mean is not None
1935 assert std is not None
1937 axis = _get_complement_v04_axis(tensor_axes, p.kwargs.axes)
1939 if axis is None:
1940 if isinstance(mean, list):
1941 raise ValueError("Expected single float value for mean, not <list>")
1942 if isinstance(std, list):
1943 raise ValueError("Expected single float value for std, not <list>")
1944 return FixedZeroMeanUnitVarianceDescr(
1945 kwargs=FixedZeroMeanUnitVarianceKwargs.model_construct(
1946 mean=mean,
1947 std=std,
1948 )
1949 )
1950 else:
1951 if not isinstance(mean, list):
1952 mean = [float(mean)]
1953 if not isinstance(std, list):
1954 std = [float(std)]
1956 return FixedZeroMeanUnitVarianceDescr(
1957 kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs(
1958 axis=axis, mean=mean, std=std
1959 )
1960 )
1962 else:
1963 axes = _axes_letters_to_ids(p.kwargs.axes) or []
1964 if p.kwargs.mode == "per_dataset":
1965 axes = [AxisId("batch")] + axes
1966 if not axes:
1967 axes = None
1968 return ZeroMeanUnitVarianceDescr(
1969 kwargs=ZeroMeanUnitVarianceKwargs(axes=axes, eps=p.kwargs.eps)
1970 )
1972 elif isinstance(p, _ScaleRangeDescr_v0_4):
1973 return ScaleRangeDescr(
1974 kwargs=ScaleRangeKwargs(
1975 axes=_axes_letters_to_ids(p.kwargs.axes),
1976 min_percentile=p.kwargs.min_percentile,
1977 max_percentile=p.kwargs.max_percentile,
1978 eps=p.kwargs.eps,
1979 )
1980 )
1981 else:
1982 assert_never(p)
1985class _InputTensorConv(
1986 Converter[
1987 _InputTensorDescr_v0_4,
1988 InputTensorDescr,
1989 FileSource_,
1990 Optional[FileSource_],
1991 Mapping[_TensorName_v0_4, Mapping[str, int]],
1992 ]
1993):
1994 def _convert(
1995 self,
1996 src: _InputTensorDescr_v0_4,
1997 tgt: "type[InputTensorDescr] | type[dict[str, Any]]",
1998 test_tensor: FileSource_,
1999 sample_tensor: Optional[FileSource_],
2000 size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]],
2001 ) -> "InputTensorDescr | dict[str, Any]":
2002 axes: List[InputAxis] = convert_axes( # pyright: ignore[reportAssignmentType]
2003 src.axes,
2004 shape=src.shape,
2005 tensor_type="input",
2006 halo=None,
2007 size_refs=size_refs,
2008 )
2009 prep: List[PreprocessingDescr] = []
2010 for p in src.preprocessing:
2011 cp = _convert_proc(p, src.axes)
2012 assert not isinstance(cp, ScaleMeanVarianceDescr)
2013 prep.append(cp)
2015 prep.append(EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="float32")))
2017 return tgt(
2018 axes=axes,
2019 id=TensorId(str(src.name)),
2020 test_tensor=FileDescr(source=test_tensor),
2021 sample_tensor=(
2022 None if sample_tensor is None else FileDescr(source=sample_tensor)
2023 ),
2024 data=dict(type=src.data_type), # pyright: ignore[reportArgumentType]
2025 preprocessing=prep,
2026 )
2029_input_tensor_conv = _InputTensorConv(_InputTensorDescr_v0_4, InputTensorDescr)
2032class OutputTensorDescr(TensorDescrBase[OutputAxis]):
2033 id: TensorId = TensorId("output")
2034 """Output tensor id.
2035 No duplicates are allowed across all inputs and outputs."""
2037 postprocessing: List[PostprocessingDescr] = Field(
2038 default_factory=cast(Callable[[], List[PostprocessingDescr]], list)
2039 )
2040 """Description of how this output should be postprocessed.
2042 note: `postprocessing` always ends with an 'ensure_dtype' operation.
2043 If not specified explicitly, one is appended that casts to this tensor's `data.type`.
2044 """
2046 @model_validator(mode="after")
2047 def _validate_postprocessing_kwargs(self) -> Self:
2048 axes_ids = [a.id for a in self.axes]
2049 for p in self.postprocessing:
2050 kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes")
2051 if kwargs_axes is None:
2052 continue
2054 if not isinstance(kwargs_axes, collections.abc.Sequence):
2055 raise ValueError(
2056 f"expected `axes` sequence, but got {type(kwargs_axes)}"
2057 )
2059 if any(a not in axes_ids for a in kwargs_axes):
2060 raise ValueError("`kwargs.axes` needs to be subset of axes ids")
2062 if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)):
2063 dtype = self.data.type
2064 else:
2065 dtype = self.data[0].type
2067 # ensure `postprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr`
2068 if not self.postprocessing or not isinstance(
2069 self.postprocessing[-1], (EnsureDtypeDescr, BinarizeDescr)
2070 ):
2071 self.postprocessing.append(
2072 EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype))
2073 )
2074 return self
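# Illustrative sketch of the dtype guarantee above (other required fields elided;
# the "uint8" data description is hypothetical): an output declared without any
# postprocessing ends up with a trailing cast to its own data type after validation:
#   descr = OutputTensorDescr(axes=..., test_tensor=..., data={"type": "uint8"})
#   descr.postprocessing[-1]  # -> EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="uint8"))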
2077class _OutputTensorConv(
2078 Converter[
2079 _OutputTensorDescr_v0_4,
2080 OutputTensorDescr,
2081 FileSource_,
2082 Optional[FileSource_],
2083 Mapping[_TensorName_v0_4, Mapping[str, int]],
2084 ]
2085):
2086 def _convert(
2087 self,
2088 src: _OutputTensorDescr_v0_4,
2089 tgt: "type[OutputTensorDescr] | type[dict[str, Any]]",
2090 test_tensor: FileSource_,
2091 sample_tensor: Optional[FileSource_],
2092 size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]],
2093 ) -> "OutputTensorDescr | dict[str, Any]":
2094 # TODO: split convert_axes into convert_output_axes and convert_input_axes
2095 axes: List[OutputAxis] = convert_axes( # pyright: ignore[reportAssignmentType]
2096 src.axes,
2097 shape=src.shape,
2098 tensor_type="output",
2099 halo=src.halo,
2100 size_refs=size_refs,
2101 )
2102 data_descr: Dict[str, Any] = dict(type=src.data_type)
2103 if data_descr["type"] == "bool":
2104 data_descr["values"] = [False, True]
2106 return tgt(
2107 axes=axes,
2108 id=TensorId(str(src.name)),
2109 test_tensor=FileDescr(source=test_tensor),
2110 sample_tensor=(
2111 None if sample_tensor is None else FileDescr(source=sample_tensor)
2112 ),
2113 data=data_descr, # pyright: ignore[reportArgumentType]
2114 postprocessing=[_convert_proc(p, src.axes) for p in src.postprocessing],
2115 )
2118_output_tensor_conv = _OutputTensorConv(_OutputTensorDescr_v0_4, OutputTensorDescr)
2121TensorDescr = Union[InputTensorDescr, OutputTensorDescr]
2124def validate_tensors(
2125 tensors: Mapping[TensorId, Tuple[TensorDescr, Optional[NDArray[Any]]]],
2126 tensor_origin: Literal[
2127 "test_tensor"
2128 ], # for more precise error messages, e.g. 'test_tensor'
2129):
2130 all_tensor_axes: Dict[TensorId, Dict[AxisId, Tuple[AnyAxis, Optional[int]]]] = {}
2132 def e_msg(d: TensorDescr):
2133 return f"{'inputs' if isinstance(d, InputTensorDescr) else 'outputs'}[{d.id}]"
2135 for descr, array in tensors.values():
2136 if array is None:
2137 axis_sizes = {a.id: None for a in descr.axes}
2138 else:
2139 try:
2140 axis_sizes = descr.get_axis_sizes_for_array(array)
2141 except ValueError as e:
2142 raise ValueError(f"{e_msg(descr)} {e}")
2144 all_tensor_axes[descr.id] = {a.id: (a, axis_sizes[a.id]) for a in descr.axes}
2146 for descr, array in tensors.values():
2147 if array is None:
2148 continue
2150 if descr.dtype in ("float32", "float64"):
2151 invalid_test_tensor_dtype = array.dtype.name not in (
2152 "float32",
2153 "float64",
2154 "uint8",
2155 "int8",
2156 "uint16",
2157 "int16",
2158 "uint32",
2159 "int32",
2160 "uint64",
2161 "int64",
2162 )
2163 else:
2164 invalid_test_tensor_dtype = array.dtype.name != descr.dtype
2166 if invalid_test_tensor_dtype:
2167 raise ValueError(
2168 f"{e_msg(descr)}.{tensor_origin}.dtype '{array.dtype.name}' does not"
2169 + f" match described dtype '{descr.dtype}'"
2170 )
2172 if array.min() > -1e-4 and array.max() < 1e-4:
2173 raise ValueError(
2174 "Output values are too small for reliable testing."
2175 + f" Values <-1e5 or >=1e5 must be present in {tensor_origin}"
2176 )
2178 for a in descr.axes:
2179 actual_size = all_tensor_axes[descr.id][a.id][1]
2180 if actual_size is None:
2181 continue
2183 if a.size is None:
2184 continue
2186 if isinstance(a.size, int):
2187 if actual_size != a.size:
2188 raise ValueError(
2189 f"{e_msg(descr)}.{tensor_origin}: axis '{a.id}' "
2190 + f"has incompatible size {actual_size}, expected {a.size}"
2191 )
2192 elif isinstance(a.size, ParameterizedSize):
2193 _ = a.size.validate_size(actual_size)
2194 elif isinstance(a.size, DataDependentSize):
2195 _ = a.size.validate_size(actual_size)
2196 elif isinstance(a.size, SizeReference):
2197 ref_tensor_axes = all_tensor_axes.get(a.size.tensor_id)
2198 if ref_tensor_axes is None:
2199 raise ValueError(
2200 f"{e_msg(descr)}.axes[{a.id}].size.tensor_id: Unknown tensor"
2201 + f" reference '{a.size.tensor_id}'"
2202 )
2204 ref_axis, ref_size = ref_tensor_axes.get(a.size.axis_id, (None, None))
2205 if ref_axis is None or ref_size is None:
2206 raise ValueError(
2207 f"{e_msg(descr)}.axes[{a.id}].size.axis_id: Unknown tensor axis"
2208 + f" reference '{a.size.tensor_id}.{a.size.axis_id}"
2209 )
2211 if a.unit != ref_axis.unit:
2212 raise ValueError(
2213 f"{e_msg(descr)}.axes[{a.id}].size: `SizeReference` requires"
2214 + " axis and reference axis to have the same `unit`, but"
2215 + f" {a.unit}!={ref_axis.unit}"
2216 )
2218 if actual_size != (
2219 expected_size := (
2220 ref_size * ref_axis.scale / a.scale + a.size.offset
2221 )
2222 ):
2223 raise ValueError(
2224 f"{e_msg(descr)}.{tensor_origin}: axis '{a.id}' of size"
2225 + f" {actual_size} invalid for referenced size {ref_size};"
2226 + f" expected {expected_size}"
2227 )
2228 else:
2229 assert_never(a.size)
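# Worked example of the `SizeReference` check above: for a referenced axis of
# size 128 with scale 1.0, an axis with scale 2.0 and offset 0 is expected to
# have size 128 * 1.0 / 2.0 + 0 = 64; any other actual size raises a ValueError.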
2232FileDescr_dependencies = Annotated[
2233 FileDescr_,
2234 WithSuffix((".yaml", ".yml"), case_sensitive=True),
2235 Field(examples=[dict(source="environment.yaml")]),
2236]
2239class _ArchitectureCallableDescr(Node):
2240 callable: Annotated[Identifier, Field(examples=["MyNetworkClass", "get_my_model"])]
2241 """Identifier of the callable that returns a torch.nn.Module instance."""
2243 kwargs: Dict[str, YamlValue] = Field(
2244 default_factory=cast(Callable[[], Dict[str, YamlValue]], dict)
2245 )
2246 """key word arguments for the `callable`"""
2249class ArchitectureFromFileDescr(_ArchitectureCallableDescr, FileDescr):
2250 source: Annotated[FileSource, AfterValidator(wo_special_file_name)]
2251 """Architecture source file"""
2253 @model_serializer(mode="wrap", when_used="unless-none")
2254 def _serialize(self, nxt: SerializerFunctionWrapHandler, info: SerializationInfo):
2255 return package_file_descr_serializer(self, nxt, info)
2258class ArchitectureFromLibraryDescr(_ArchitectureCallableDescr):
2259 import_from: str
2260 """Where to import the callable from, i.e. `from <import_from> import <callable>`"""
2263class _ArchFileConv(
2264 Converter[
2265 _CallableFromFile_v0_4,
2266 ArchitectureFromFileDescr,
2267 Optional[Sha256],
2268 Dict[str, Any],
2269 ]
2270):
2271 def _convert(
2272 self,
2273 src: _CallableFromFile_v0_4,
2274 tgt: "type[ArchitectureFromFileDescr | dict[str, Any]]",
2275 sha256: Optional[Sha256],
2276 kwargs: Dict[str, Any],
2277 ) -> "ArchitectureFromFileDescr | dict[str, Any]":
2278 if src.startswith("http") and src.count(":") == 2:
2279 http, source, callable_ = src.split(":")
2280 source = ":".join((http, source))
2281 elif not src.startswith("http") and src.count(":") == 1:
2282 source, callable_ = src.split(":")
2283 else:
2284 source = str(src)
2285 callable_ = str(src)
2286 return tgt(
2287 callable=Identifier(callable_),
2288 source=cast(FileSource_, source),
2289 sha256=sha256,
2290 kwargs=kwargs,
2291 )
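# Illustrative sketch of the source/callable split performed above
# (file names and URLs are hypothetical):
#   "my_arch.py:MyNet"                  -> source="my_arch.py", callable="MyNet"
#   "https://example.com/arch.py:MyNet" -> source="https://example.com/arch.py", callable="MyNet"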
2294_arch_file_conv = _ArchFileConv(_CallableFromFile_v0_4, ArchitectureFromFileDescr)
2297class _ArchLibConv(
2298 Converter[
2299 _CallableFromDepencency_v0_4, ArchitectureFromLibraryDescr, Dict[str, Any]
2300 ]
2301):
2302 def _convert(
2303 self,
2304 src: _CallableFromDepencency_v0_4,
2305 tgt: "type[ArchitectureFromLibraryDescr | dict[str, Any]]",
2306 kwargs: Dict[str, Any],
2307 ) -> "ArchitectureFromLibraryDescr | dict[str, Any]":
2308 *mods, callable_ = src.split(".")
2309 import_from = ".".join(mods)
2310 return tgt(
2311 import_from=import_from, callable=Identifier(callable_), kwargs=kwargs
2312 )
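# Illustrative sketch (the module path is hypothetical):
#   "my_package.nets.unet.UNet" -> import_from="my_package.nets.unet", callable="UNet"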
2315_arch_lib_conv = _ArchLibConv(
2316 _CallableFromDepencency_v0_4, ArchitectureFromLibraryDescr
2317)
2320class WeightsEntryDescrBase(FileDescr):
2321 type: ClassVar[WeightsFormat]
2322 weights_format_name: ClassVar[str] # human readable
2324 source: Annotated[FileSource, AfterValidator(wo_special_file_name)]
2325 """Source of the weights file."""
2327 authors: Optional[List[Author]] = None
2328 """Authors
2329 Either the person(s) that have trained this model resulting in the original weights file.
2330 (If this is the initial weights entry, i.e. it does not have a `parent`)
2331 Or the person(s) who have converted the weights to this weights format.
2332 (If this is a child weight, i.e. it has a `parent` field)
2333 """
2335 parent: Annotated[
2336 Optional[WeightsFormat], Field(examples=["pytorch_state_dict"])
2337 ] = None
2338 """The source weights these weights were converted from.
2339 For example, if a model's weights were converted from the `pytorch_state_dict` format to `torchscript`,
2340 the `pytorch_state_dict` weights entry has no `parent` and is the parent of the `torchscript` weights.
2341 All weight entries except one (the initial set of weights resulting from training the model)
2342 need to have this field."""
2344 comment: str = ""
2345 """A comment about this weights entry, for example how these weights were created."""
2347 @model_validator(mode="after")
2348 def _validate(self) -> Self:
2349 if self.type == self.parent:
2350 raise ValueError("Weights entry can't be it's own parent.")
2352 return self
2354 @model_serializer(mode="wrap", when_used="unless-none")
2355 def _serialize(self, nxt: SerializerFunctionWrapHandler, info: SerializationInfo):
2356 return package_file_descr_serializer(self, nxt, info)
2359class KerasHdf5WeightsDescr(WeightsEntryDescrBase):
2360 type: ClassVar[WeightsFormat] = "keras_hdf5"
2361 weights_format_name: ClassVar[str] = "Keras HDF5"
2362 tensorflow_version: Version
2363 """TensorFlow version used to create these weights."""
2366FileDescr_external_data = Annotated[
2367 FileDescr_,
2368 WithSuffix(".data", case_sensitive=True),
2369 Field(examples=[dict(source="weights.onnx.data")]),
2370]
2373class OnnxWeightsDescr(WeightsEntryDescrBase):
2374 type: ClassVar[WeightsFormat] = "onnx"
2375 weights_format_name: ClassVar[str] = "ONNX"
2376 opset_version: Annotated[int, Ge(7)]
2377 """ONNX opset version"""
2379 external_data: Optional[FileDescr_external_data] = None
2380 """Source of the external ONNX data file holding the weights.
2381 (If present **source** holds the ONNX architecture without weights)."""
2383 @model_validator(mode="after")
2384 def _validate_external_data_unique_file_name(self) -> Self:
2385 if self.external_data is not None and (
2386 extract_file_name(self.source)
2387 == extract_file_name(self.external_data.source)
2388 ):
2389 raise ValueError(
2390 f"ONNX `external_data` file name '{extract_file_name(self.external_data.source)}'"
2391 + " must be different from ONNX `source` file name."
2392 )
2394 return self
2397class PytorchStateDictWeightsDescr(WeightsEntryDescrBase):
2398 type: ClassVar[WeightsFormat] = "pytorch_state_dict"
2399 weights_format_name: ClassVar[str] = "Pytorch State Dict"
2400 architecture: Union[ArchitectureFromFileDescr, ArchitectureFromLibraryDescr]
2401 pytorch_version: Version
2402 """Version of the PyTorch library used.
2403 If `architecture.dependencies` is specified, it has to include pytorch and any version pinning has to be compatible.
2404 """
2405 dependencies: Optional[FileDescr_dependencies] = None
2406 """Custom depencies beyond pytorch described in a Conda environment file.
2407 Allows to specify custom dependencies, see conda docs:
2408 - [Exporting an environment file across platforms](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#exporting-an-environment-file-across-platforms)
2409 - [Creating an environment file manually](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#creating-an-environment-file-manually)
2411 The conda environment file should include pytorch and any version pinning has to be compatible with
2412 **pytorch_version**.
2413 """
2416class TensorflowJsWeightsDescr(WeightsEntryDescrBase):
2417 type: ClassVar[WeightsFormat] = "tensorflow_js"
2418 weights_format_name: ClassVar[str] = "Tensorflow.js"
2419 tensorflow_version: Version
2420 """Version of the TensorFlow library used."""
2422 source: Annotated[FileSource, AfterValidator(wo_special_file_name)]
2423 """The multi-file weights.
2424 All required files/folders should be provided as a zip archive."""
2427class TensorflowSavedModelBundleWeightsDescr(WeightsEntryDescrBase):
2428 type: ClassVar[WeightsFormat] = "tensorflow_saved_model_bundle"
2429 weights_format_name: ClassVar[str] = "Tensorflow Saved Model"
2430 tensorflow_version: Version
2431 """Version of the TensorFlow library used."""
2433 dependencies: Optional[FileDescr_dependencies] = None
2434 """Custom dependencies beyond tensorflow.
2435 Should include tensorflow and any version pinning has to be compatible with **tensorflow_version**."""
2437 source: Annotated[FileSource, AfterValidator(wo_special_file_name)]
2438 """The multi-file weights.
2439 All required files/folders should be provided as a zip archive."""
2442class TorchscriptWeightsDescr(WeightsEntryDescrBase):
2443 type: ClassVar[WeightsFormat] = "torchscript"
2444 weights_format_name: ClassVar[str] = "TorchScript"
2445 pytorch_version: Version
2446 """Version of the PyTorch library used."""
2449SpecificWeightsDescr = Union[
2450 KerasHdf5WeightsDescr,
2451 OnnxWeightsDescr,
2452 PytorchStateDictWeightsDescr,
2453 TensorflowJsWeightsDescr,
2454 TensorflowSavedModelBundleWeightsDescr,
2455 TorchscriptWeightsDescr,
2456]
2459class WeightsDescr(Node):
2460 keras_hdf5: Optional[KerasHdf5WeightsDescr] = None
2461 onnx: Optional[OnnxWeightsDescr] = None
2462 pytorch_state_dict: Optional[PytorchStateDictWeightsDescr] = None
2463 tensorflow_js: Optional[TensorflowJsWeightsDescr] = None
2464 tensorflow_saved_model_bundle: Optional[TensorflowSavedModelBundleWeightsDescr] = (
2465 None
2466 )
2467 torchscript: Optional[TorchscriptWeightsDescr] = None
2469 @model_validator(mode="after")
2470 def check_entries(self) -> Self:
2471 entries = {wtype for wtype, entry in self if entry is not None}
2473 if not entries:
2474 raise ValueError("Missing weights entry")
2476 entries_wo_parent = {
2477 wtype
2478 for wtype, entry in self
2479 if entry is not None and hasattr(entry, "parent") and entry.parent is None
2480 }
2481 if len(entries_wo_parent) != 1:
2482 issue_warning(
2483 "Exactly one weights entry may not specify the `parent` field (got"
2484 + " {value}). That entry is considered the original set of model weights."
2485 + " Other weight formats are created through conversion of the orignal or"
2486 + " already converted weights. They have to reference the weights format"
2487 + " they were converted from as their `parent`.",
2488 value=len(entries_wo_parent),
2489 field="weights",
2490 )
2492 for wtype, entry in self:
2493 if entry is None:
2494 continue
2496 assert hasattr(entry, "type")
2497 assert hasattr(entry, "parent")
2498 assert wtype == entry.type
2499 if (
2500 entry.parent is not None and entry.parent not in entries
2501 ): # self reference checked for `parent` field
2502 raise ValueError(
2503 f"`weights.{wtype}.parent={entry.parent} not in specified weight"
2504 + f" formats: {entries}"
2505 )
2507 return self
2509 def __getitem__(
2510 self,
2511 key: Literal[
2512 "keras_hdf5",
2513 "onnx",
2514 "pytorch_state_dict",
2515 "tensorflow_js",
2516 "tensorflow_saved_model_bundle",
2517 "torchscript",
2518 ],
2519 ):
2520 if key == "keras_hdf5":
2521 ret = self.keras_hdf5
2522 elif key == "onnx":
2523 ret = self.onnx
2524 elif key == "pytorch_state_dict":
2525 ret = self.pytorch_state_dict
2526 elif key == "tensorflow_js":
2527 ret = self.tensorflow_js
2528 elif key == "tensorflow_saved_model_bundle":
2529 ret = self.tensorflow_saved_model_bundle
2530 elif key == "torchscript":
2531 ret = self.torchscript
2532 else:
2533 raise KeyError(key)
2535 if ret is None:
2536 raise KeyError(key)
2538 return ret
2540 @overload
2541 def __setitem__(
2542 self, key: Literal["keras_hdf5"], value: Optional[KerasHdf5WeightsDescr]
2543 ) -> None: ...
2544 @overload
2545 def __setitem__(
2546 self, key: Literal["onnx"], value: Optional[OnnxWeightsDescr]
2547 ) -> None: ...
2548 @overload
2549 def __setitem__(
2550 self,
2551 key: Literal["pytorch_state_dict"],
2552 value: Optional[PytorchStateDictWeightsDescr],
2553 ) -> None: ...
2554 @overload
2555 def __setitem__(
2556 self, key: Literal["tensorflow_js"], value: Optional[TensorflowJsWeightsDescr]
2557 ) -> None: ...
2558 @overload
2559 def __setitem__(
2560 self,
2561 key: Literal["tensorflow_saved_model_bundle"],
2562 value: Optional[TensorflowSavedModelBundleWeightsDescr],
2563 ) -> None: ...
2564 @overload
2565 def __setitem__(
2566 self, key: Literal["torchscript"], value: Optional[TorchscriptWeightsDescr]
2567 ) -> None: ...
2569 def __setitem__(
2570 self,
2571 key: Literal[
2572 "keras_hdf5",
2573 "onnx",
2574 "pytorch_state_dict",
2575 "tensorflow_js",
2576 "tensorflow_saved_model_bundle",
2577 "torchscript",
2578 ],
2579 value: Optional[SpecificWeightsDescr],
2580 ):
2581 if key == "keras_hdf5":
2582 if value is not None and not isinstance(value, KerasHdf5WeightsDescr):
2583 raise TypeError(
2584 f"Expected KerasHdf5WeightsDescr or None for key 'keras_hdf5', got {type(value)}"
2585 )
2586 self.keras_hdf5 = value
2587 elif key == "onnx":
2588 if value is not None and not isinstance(value, OnnxWeightsDescr):
2589 raise TypeError(
2590 f"Expected OnnxWeightsDescr or None for key 'onnx', got {type(value)}"
2591 )
2592 self.onnx = value
2593 elif key == "pytorch_state_dict":
2594 if value is not None and not isinstance(
2595 value, PytorchStateDictWeightsDescr
2596 ):
2597 raise TypeError(
2598 f"Expected PytorchStateDictWeightsDescr or None for key 'pytorch_state_dict', got {type(value)}"
2599 )
2600 self.pytorch_state_dict = value
2601 elif key == "tensorflow_js":
2602 if value is not None and not isinstance(value, TensorflowJsWeightsDescr):
2603 raise TypeError(
2604 f"Expected TensorflowJsWeightsDescr or None for key 'tensorflow_js', got {type(value)}"
2605 )
2606 self.tensorflow_js = value
2607 elif key == "tensorflow_saved_model_bundle":
2608 if value is not None and not isinstance(
2609 value, TensorflowSavedModelBundleWeightsDescr
2610 ):
2611 raise TypeError(
2612 f"Expected TensorflowSavedModelBundleWeightsDescr or None for key 'tensorflow_saved_model_bundle', got {type(value)}"
2613 )
2614 self.tensorflow_saved_model_bundle = value
2615 elif key == "torchscript":
2616 if value is not None and not isinstance(value, TorchscriptWeightsDescr):
2617 raise TypeError(
2618 f"Expected TorchscriptWeightsDescr or None for key 'torchscript', got {type(value)}"
2619 )
2620 self.torchscript = value
2621 else:
2622 raise KeyError(key)
2624 @property
2625 def available_formats(self) -> Dict[WeightsFormat, SpecificWeightsDescr]:
2626 return {
2627 **({} if self.keras_hdf5 is None else {"keras_hdf5": self.keras_hdf5}),
2628 **({} if self.onnx is None else {"onnx": self.onnx}),
2629 **(
2630 {}
2631 if self.pytorch_state_dict is None
2632 else {"pytorch_state_dict": self.pytorch_state_dict}
2633 ),
2634 **(
2635 {}
2636 if self.tensorflow_js is None
2637 else {"tensorflow_js": self.tensorflow_js}
2638 ),
2639 **(
2640 {}
2641 if self.tensorflow_saved_model_bundle is None
2642 else {
2643 "tensorflow_saved_model_bundle": self.tensorflow_saved_model_bundle
2644 }
2645 ),
2646 **({} if self.torchscript is None else {"torchscript": self.torchscript}),
2647 }
2649 @property
2650 def missing_formats(self) -> Set[WeightsFormat]:
2651 return {
2652 wf for wf in get_args(WeightsFormat) if wf not in self.available_formats
2653 }
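# Illustrative usage sketch (assumes `weights` is a validated WeightsDescr with
# only a pytorch_state_dict entry):
#   weights["pytorch_state_dict"]       # -> PytorchStateDictWeightsDescr
#   weights["onnx"]                     # -> raises KeyError (entry not present)
#   weights.available_formats           # -> {"pytorch_state_dict": ...}
#   "onnx" in weights.missing_formats   # -> True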
2656class ModelId(ResourceId):
2657 pass
2660class LinkedModel(LinkedResourceBase):
2661 """Reference to a bioimage.io model."""
2663 id: ModelId
2664 """A valid model `id` from the bioimage.io collection."""
2667class _DataDepSize(NamedTuple):
2668 min: StrictInt
2669 max: Optional[StrictInt]
2672class _AxisSizes(NamedTuple):
2673 """the lenghts of all axes of model inputs and outputs"""
2675 inputs: Dict[Tuple[TensorId, AxisId], int]
2676 outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]]
2679class _TensorSizes(NamedTuple):
2680 """_AxisSizes as nested dicts"""
2682 inputs: Dict[TensorId, Dict[AxisId, int]]
2683 outputs: Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]
2686class ReproducibilityTolerance(Node, extra="allow"):
2687 """Describes what small numerical differences -- if any -- may be tolerated
2688 in the generated output when executing in different environments.
2690 A tensor element *output* is considered mismatched to the **test_tensor** if
2691 abs(*output* - **test_tensor**) > **absolute_tolerance** + **relative_tolerance** * abs(**test_tensor**).
2692 (Internally we call [numpy.testing.assert_allclose](https://numpy.org/doc/stable/reference/generated/numpy.testing.assert_allclose.html).)
2694 Motivation:
2695 For testing we can request the respective deep learning frameworks to be as
2696 reproducible as possible by setting seeds and choosing deterministic algorithms,
2697 but differences in operating systems, available hardware and installed drivers
2698 may still lead to numerical differences.
2699 """
2701 relative_tolerance: RelativeTolerance = 1e-3
2702 """Maximum relative tolerance of reproduced test tensor."""
2704 absolute_tolerance: AbsoluteTolerance = 1e-4
2705 """Maximum absolute tolerance of reproduced test tensor."""
2707 mismatched_elements_per_million: MismatchedElementsPerMillion = 100
2708 """Maximum number of mismatched elements/pixels per million to tolerate."""
2710 output_ids: Sequence[TensorId] = ()
2711 """Limits the output tensor IDs these reproducibility details apply to."""
2713 weights_formats: Sequence[WeightsFormat] = ()
2714 """Limits the weights formats these details apply to."""
2717class BioimageioConfig(Node, extra="allow"):
2718 reproducibility_tolerance: Sequence[ReproducibilityTolerance] = ()
2719 """Tolerances to allow when reproducing the model's test outputs
2720 from the model's test inputs.
2721 Only the first entry matching tensor id and weights format is considered.
2722 """
2725class Config(Node, extra="allow"):
2726 bioimageio: BioimageioConfig = Field(
2727 default_factory=BioimageioConfig.model_construct
2728 )
2731class ModelDescr(GenericModelDescrBase):
2732 """Specification of the fields used in a bioimage.io-compliant RDF to describe AI models with pretrained weights.
2733 These fields are typically stored in a YAML file which we call a model resource description file (model RDF).
2734 """
2736 implemented_format_version: ClassVar[Literal["0.5.6"]] = "0.5.6"
2737 if TYPE_CHECKING:
2738 format_version: Literal["0.5.6"] = "0.5.6"
2739 else:
2740 format_version: Literal["0.5.6"]
2741 """Version of the bioimage.io model description specification used.
2742 When creating a new model always use the latest micro/patch version described here.
2743 The `format_version` is important for any consumer software to understand how to parse the fields.
2744 """
2746 implemented_type: ClassVar[Literal["model"]] = "model"
2747 if TYPE_CHECKING:
2748 type: Literal["model"] = "model"
2749 else:
2750 type: Literal["model"]
2751 """Specialized resource type 'model'"""
2753 id: Optional[ModelId] = None
2754 """bioimage.io-wide unique resource identifier
2755 assigned by bioimage.io; version **un**specific."""
2757 authors: FAIR[List[Author]] = Field(
2758 default_factory=cast(Callable[[], List[Author]], list)
2759 )
2760 """The authors are the creators of the model RDF and the primary points of contact."""
2762 documentation: FAIR[Optional[FileSource_documentation]] = None
2763 """URL or relative path to a markdown file with additional documentation.
2764 The recommended documentation file name is `README.md`. An `.md` suffix is mandatory.
2765 The documentation should include a '#[#] Validation' (sub)section
2766 with details on how to quantitatively validate the model on unseen data."""
2768 @field_validator("documentation", mode="after")
2769 @classmethod
2770 def _validate_documentation(
2771 cls, value: Optional[FileSource_documentation]
2772 ) -> Optional[FileSource_documentation]:
2773 if not get_validation_context().perform_io_checks or value is None:
2774 return value
2776 doc_reader = get_reader(value)
2777 doc_content = doc_reader.read().decode(encoding="utf-8")
2778 if not re.search("#.*[vV]alidation", doc_content):
2779 issue_warning(
2780 "No '# Validation' (sub)section found in {value}.",
2781 value=value,
2782 field="documentation",
2783 )
2785 return value
2787 inputs: NotEmpty[Sequence[InputTensorDescr]]
2788 """Describes the input tensors expected by this model."""
2790 @field_validator("inputs", mode="after")
2791 @classmethod
2792 def _validate_input_axes(
2793 cls, inputs: Sequence[InputTensorDescr]
2794 ) -> Sequence[InputTensorDescr]:
2795 input_size_refs = cls._get_axes_with_independent_size(inputs)
2797 for i, ipt in enumerate(inputs):
2798 valid_independent_refs: Dict[
2799 Tuple[TensorId, AxisId],
2800 Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
2801 ] = {
2802 **{
2803 (ipt.id, a.id): (ipt, a, a.size)
2804 for a in ipt.axes
2805 if not isinstance(a, BatchAxis)
2806 and isinstance(a.size, (int, ParameterizedSize))
2807 },
2808 **input_size_refs,
2809 }
2810 for a, ax in enumerate(ipt.axes):
2811 cls._validate_axis(
2812 "inputs",
2813 i=i,
2814 tensor_id=ipt.id,
2815 a=a,
2816 axis=ax,
2817 valid_independent_refs=valid_independent_refs,
2818 )
2819 return inputs
2821 @staticmethod
2822 def _validate_axis(
2823 field_name: str,
2824 i: int,
2825 tensor_id: TensorId,
2826 a: int,
2827 axis: AnyAxis,
2828 valid_independent_refs: Dict[
2829 Tuple[TensorId, AxisId],
2830 Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
2831 ],
2832 ):
2833 if isinstance(axis, BatchAxis) or isinstance(
2834 axis.size, (int, ParameterizedSize, DataDependentSize)
2835 ):
2836 return
2837 elif not isinstance(axis.size, SizeReference):
2838 assert_never(axis.size)
2840 # validate axis.size SizeReference
2841 ref = (axis.size.tensor_id, axis.size.axis_id)
2842 if ref not in valid_independent_refs:
2843 raise ValueError(
2844 "Invalid tensor axis reference at"
2845 + f" {field_name}[{i}].axes[{a}].size: {axis.size}."
2846 )
2847 if ref == (tensor_id, axis.id):
2848 raise ValueError(
2849 "Self-referencing not allowed for"
2850 + f" {field_name}[{i}].axes[{a}].size: {axis.size}"
2851 )
2852 if axis.type == "channel":
2853 if valid_independent_refs[ref][1].type != "channel":
2854 raise ValueError(
2855 "A channel axis' size may only reference another fixed size"
2856 + " channel axis."
2857 )
2858 if isinstance(axis.channel_names, str) and "{i}" in axis.channel_names:
2859 ref_size = valid_independent_refs[ref][2]
2860 assert isinstance(ref_size, int), (
2861 "channel axis ref (another channel axis) has to specify fixed"
2862 + " size"
2863 )
2864 generated_channel_names = [
2865 Identifier(axis.channel_names.format(i=i))
2866 for i in range(1, ref_size + 1)
2867 ]
2868 axis.channel_names = generated_channel_names
2870 if (ax_unit := getattr(axis, "unit", None)) != (
2871 ref_unit := getattr(valid_independent_refs[ref][1], "unit", None)
2872 ):
2873 raise ValueError(
2874 "The units of an axis and its reference axis need to match, but"
2875 + f" '{ax_unit}' != '{ref_unit}'."
2876 )
2877 ref_axis = valid_independent_refs[ref][1]
2878 if isinstance(ref_axis, BatchAxis):
2879 raise ValueError(
2880 f"Invalid reference axis '{ref_axis.id}' for {tensor_id}.{axis.id}"
2881 + " (a batch axis is not allowed as reference)."
2882 )
2884 if isinstance(axis, WithHalo):
2885 min_size = axis.size.get_size(axis, ref_axis, n=0)
2886 if (min_size - 2 * axis.halo) < 1:
2887 raise ValueError(
2888 f"axis {axis.id} with minimum size {min_size} is too small for halo"
2889 + f" {axis.halo}."
2890 )
2892 ref_halo = axis.halo * axis.scale / ref_axis.scale
2893 if ref_halo != int(ref_halo):
2894 raise ValueError(
2895 f"Inferred halo for {'.'.join(ref)} is not an integer ({ref_halo} ="
2896 + f" {tensor_id}.{axis.id}.halo {axis.halo}"
2897 + f" * {tensor_id}.{axis.id}.scale {axis.scale}"
2898 + f" / {'.'.join(ref)}.scale {ref_axis.scale})."
2899 )
2901 @model_validator(mode="after")
2902 def _validate_test_tensors(self) -> Self:
2903 if not get_validation_context().perform_io_checks:
2904 return self
2906 test_output_arrays = [
2907 None if descr.test_tensor is None else load_array(descr.test_tensor)
2908 for descr in self.outputs
2909 ]
2910 test_input_arrays = [
2911 None if descr.test_tensor is None else load_array(descr.test_tensor)
2912 for descr in self.inputs
2913 ]
2915 tensors = {
2916 descr.id: (descr, array)
2917 for descr, array in zip(
2918 chain(self.inputs, self.outputs), test_input_arrays + test_output_arrays
2919 )
2920 }
2921 validate_tensors(tensors, tensor_origin="test_tensor")
2923 output_arrays = {
2924 descr.id: array for descr, array in zip(self.outputs, test_output_arrays)
2925 }
2926 for rep_tol in self.config.bioimageio.reproducibility_tolerance:
2927 if not rep_tol.absolute_tolerance:
2928 continue
2930 if rep_tol.output_ids:
2931 out_arrays = {
2932 oid: a
2933 for oid, a in output_arrays.items()
2934 if oid in rep_tol.output_ids
2935 }
2936 else:
2937 out_arrays = output_arrays
2939 for out_id, array in out_arrays.items():
2940 if array is None:
2941 continue
2943 if rep_tol.absolute_tolerance > (max_test_value := array.max()) * 0.01:
2944 raise ValueError(
2945 "config.bioimageio.reproducibility_tolerance.absolute_tolerance="
2946 + f"{rep_tol.absolute_tolerance} > 0.01*{max_test_value}"
2947 + f" (1% of the maximum value of the test tensor '{out_id}')"
2948 )
2950 return self
2952 @model_validator(mode="after")
2953 def _validate_tensor_references_in_proc_kwargs(self, info: ValidationInfo) -> Self:
2954 ipt_refs = {t.id for t in self.inputs}
2955 out_refs = {t.id for t in self.outputs}
2956 for ipt in self.inputs:
2957 for p in ipt.preprocessing:
2958 ref = p.kwargs.get("reference_tensor")
2959 if ref is None:
2960 continue
2961 if ref not in ipt_refs:
2962 raise ValueError(
2963 f"`reference_tensor` '{ref}' not found. Valid input tensor"
2964 + f" references are: {ipt_refs}."
2965 )
2967 for out in self.outputs:
2968 for p in out.postprocessing:
2969 ref = p.kwargs.get("reference_tensor")
2970 if ref is None:
2971 continue
2973 if ref not in ipt_refs and ref not in out_refs:
2974 raise ValueError(
2975 f"`reference_tensor` '{ref}' not found. Valid tensor references"
2976 + f" are: {ipt_refs | out_refs}."
2977 )
2979 return self
2981 # TODO: use validate funcs in validate_test_tensors
2982 # def validate_inputs(self, input_tensors: Mapping[TensorId, NDArray[Any]]) -> Mapping[TensorId, NDArray[Any]]:
2984 name: Annotated[
2985 str,
2986 RestrictCharacters(string.ascii_letters + string.digits + "_+- ()"),
2987 MinLen(5),
2988 MaxLen(128),
2989 warn(MaxLen(64), "Name longer than 64 characters.", INFO),
2990 ]
2991 """A human-readable name of this model.
2992 It should be no longer than 64 characters
2993 and may only contain letters, numbers, underscores, plus and minus signs, parentheses, and spaces.
2994 We recommend choosing a name that refers to the model's task and image modality.
2995 """
2997 outputs: NotEmpty[Sequence[OutputTensorDescr]]
2998 """Describes the output tensors."""
3000 @field_validator("outputs", mode="after")
3001 @classmethod
3002 def _validate_tensor_ids(
3003 cls, outputs: Sequence[OutputTensorDescr], info: ValidationInfo
3004 ) -> Sequence[OutputTensorDescr]:
3005 tensor_ids = [
3006 t.id for t in info.data.get("inputs", []) + info.data.get("outputs", [])
3007 ]
3008 duplicate_tensor_ids: List[str] = []
3009 seen: Set[str] = set()
3010 for t in tensor_ids:
3011 if t in seen:
3012 duplicate_tensor_ids.append(t)
3014 seen.add(t)
3016 if duplicate_tensor_ids:
3017 raise ValueError(f"Duplicate tensor ids: {duplicate_tensor_ids}")
3019 return outputs
3021 @staticmethod
3022 def _get_axes_with_parameterized_size(
3023 io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]],
3024 ):
3025 return {
3026 f"{t.id}.{a.id}": (t, a, a.size)
3027 for t in io
3028 for a in t.axes
3029 if not isinstance(a, BatchAxis) and isinstance(a.size, ParameterizedSize)
3030 }
3032 @staticmethod
3033 def _get_axes_with_independent_size(
3034 io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]],
3035 ):
3036 return {
3037 (t.id, a.id): (t, a, a.size)
3038 for t in io
3039 for a in t.axes
3040 if not isinstance(a, BatchAxis)
3041 and isinstance(a.size, (int, ParameterizedSize))
3042 }
3044 @field_validator("outputs", mode="after")
3045 @classmethod
3046 def _validate_output_axes(
3047 cls, outputs: List[OutputTensorDescr], info: ValidationInfo
3048 ) -> List[OutputTensorDescr]:
3049 input_size_refs = cls._get_axes_with_independent_size(
3050 info.data.get("inputs", [])
3051 )
3052 output_size_refs = cls._get_axes_with_independent_size(outputs)
3054 for i, out in enumerate(outputs):
3055 valid_independent_refs: Dict[
3056 Tuple[TensorId, AxisId],
3057 Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
3058 ] = {
3059 **{
3060 (out.id, a.id): (out, a, a.size)
3061 for a in out.axes
3062 if not isinstance(a, BatchAxis)
3063 and isinstance(a.size, (int, ParameterizedSize))
3064 },
3065 **input_size_refs,
3066 **output_size_refs,
3067 }
3068 for a, ax in enumerate(out.axes):
3069 cls._validate_axis(
3070 "outputs",
3071 i,
3072 out.id,
3073 a,
3074 ax,
3075 valid_independent_refs=valid_independent_refs,
3076 )
3078 return outputs
3080 packaged_by: List[Author] = Field(
3081 default_factory=cast(Callable[[], List[Author]], list)
3082 )
3083 """The persons that have packaged and uploaded this model.
3084 Only required if those persons differ from the `authors`."""
3086 parent: Optional[LinkedModel] = None
3087 """The model from which this model is derived, e.g. by fine-tuning the weights."""
3089 @model_validator(mode="after")
3090 def _validate_parent_is_not_self(self) -> Self:
3091 if self.parent is not None and self.parent.id == self.id:
3092 raise ValueError("A model description may not reference itself as parent.")
3094 return self
3096 run_mode: Annotated[
3097 Optional[RunMode],
3098 warn(None, "Run mode '{value}' has limited support across consumer softwares."),
3099 ] = None
3100 """Custom run mode for this model: for more complex prediction procedures like test time
3101 data augmentation that currently cannot be expressed in the specification.
3102 No standard run modes are defined yet."""
3104 timestamp: Datetime = Field(default_factory=Datetime.now)
3105 """Timestamp in [ISO 8601](#https://en.wikipedia.org/wiki/ISO_8601) format
3106 with a few restrictions listed [here](https://docs.python.org/3/library/datetime.html#datetime.datetime.fromisoformat).
3107 (In Python a datetime object is valid, too)."""
3109 training_data: Annotated[
3110 Union[None, LinkedDataset, DatasetDescr, DatasetDescr02],
3111 Field(union_mode="left_to_right"),
3112 ] = None
3113 """The dataset used to train this model"""
3115 weights: Annotated[WeightsDescr, WrapSerializer(package_weights)]
3116 """The weights for this model.
3117 Weights can be given for different formats, but should otherwise be equivalent.
3118 The available weight formats determine which consumers can use this model."""
3120 config: Config = Field(default_factory=Config.model_construct)
3122 @model_validator(mode="after")
3123 def _add_default_cover(self) -> Self:
3124 if not get_validation_context().perform_io_checks or self.covers:
3125 return self
3127 try:
3128 generated_covers = generate_covers(
3129 [
3130 (t, load_array(t.test_tensor))
3131 for t in self.inputs
3132 if t.test_tensor is not None
3133 ],
3134 [
3135 (t, load_array(t.test_tensor))
3136 for t in self.outputs
3137 if t.test_tensor is not None
3138 ],
3139 )
3140 except Exception as e:
3141 issue_warning(
3142 "Failed to generate cover image(s): {e}",
3143 value=self.covers,
3144 msg_context=dict(e=e),
3145 field="covers",
3146 )
3147 else:
3148 self.covers.extend(generated_covers)
3150 return self
3152 def get_input_test_arrays(self) -> List[NDArray[Any]]:
3153 return self._get_test_arrays(self.inputs)
3155 def get_output_test_arrays(self) -> List[NDArray[Any]]:
3156 return self._get_test_arrays(self.outputs)
3158 @staticmethod
3159 def _get_test_arrays(
3160 io_descr: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]],
3161 ):
3162 ts: List[FileDescr] = []
3163 for d in io_descr:
3164 if d.test_tensor is None:
3165 raise ValueError(
3166 f"Failed to get test arrays: description of '{d.id}' is missing a `test_tensor`."
3167 )
3168 ts.append(d.test_tensor)
3170 data = [load_array(t) for t in ts]
3171 assert all(isinstance(d, np.ndarray) for d in data)
3172 return data
3174 @staticmethod
3175 def get_batch_size(tensor_sizes: Mapping[TensorId, Mapping[AxisId, int]]) -> int:
3176 batch_size = 1
3177 tensor_with_batchsize: Optional[TensorId] = None
3178 for tid in tensor_sizes:
3179 for aid, s in tensor_sizes[tid].items():
3180 if aid != BATCH_AXIS_ID or s == 1 or s == batch_size:
3181 continue
3183 if batch_size != 1:
3184 assert tensor_with_batchsize is not None
3185 raise ValueError(
3186 f"batch size mismatch for tensors '{tensor_with_batchsize}' ({batch_size}) and '{tid}' ({s})"
3187 )
3189 batch_size = s
3190 tensor_with_batchsize = tid
3192 return batch_size
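# Illustrative sketch (tensor/axis ids are hypothetical; the batch axis is
# assumed to be identified by AxisId("batch")):
#   ModelDescr.get_batch_size({TensorId("raw"): {AxisId("batch"): 4, AxisId("x"): 512}})
#   # -> 4; a second tensor reporting a different batch size (e.g. 2) would raise a ValueError.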
3194 def get_output_tensor_sizes(
3195 self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]
3196 ) -> Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]:
3197 """Returns the tensor output sizes for given **input_sizes**.
3198 The returned output sizes are exact only if **input_sizes** is a valid input shape.
3199 Otherwise they may be larger than the actual (valid) output sizes."""
3200 batch_size = self.get_batch_size(input_sizes)
3201 ns = self.get_ns(input_sizes)
3203 tensor_sizes = self.get_tensor_sizes(ns, batch_size=batch_size)
3204 return tensor_sizes.outputs
3206 def get_ns(self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]):
3207 """get parameter `n` for each parameterized axis
3208 such that the valid input size is >= the given input size"""
3209 ret: Dict[Tuple[TensorId, AxisId], ParameterizedSize_N] = {}
3210 axes = {t.id: {a.id: a for a in t.axes} for t in self.inputs}
3211 for tid in input_sizes:
3212 for aid, s in input_sizes[tid].items():
3213 size_descr = axes[tid][aid].size
3214 if isinstance(size_descr, ParameterizedSize):
3215 ret[(tid, aid)] = size_descr.get_n(s)
3216 elif size_descr is None or isinstance(size_descr, (int, SizeReference)):
3217 pass
3218 else:
3219 assert_never(size_descr)
3221 return ret
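# Illustrative sketch (ids and the parameterization are hypothetical; `model` is a
# validated ModelDescr): if input "raw" has an x axis with ParameterizedSize(min=64, step=16),
#   model.get_ns({TensorId("raw"): {AxisId("x"): 100}})
# returns {(TensorId("raw"), AxisId("x")): n} with the smallest n whose
# parameterized size 64 + n * 16 covers 100 (n=3, giving size 112).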
3223 def get_tensor_sizes(
3224 self, ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], batch_size: int
3225 ) -> _TensorSizes:
3226 axis_sizes = self.get_axis_sizes(ns, batch_size=batch_size)
3227 return _TensorSizes(
3228 {
3229 t: {
3230 aa: axis_sizes.inputs[(tt, aa)]
3231 for tt, aa in axis_sizes.inputs
3232 if tt == t
3233 }
3234 for t in {tt for tt, _ in axis_sizes.inputs}
3235 },
3236 {
3237 t: {
3238 aa: axis_sizes.outputs[(tt, aa)]
3239 for tt, aa in axis_sizes.outputs
3240 if tt == t
3241 }
3242 for t in {tt for tt, _ in axis_sizes.outputs}
3243 },
3244 )
3246 def get_axis_sizes(
3247 self,
3248 ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N],
3249 batch_size: Optional[int] = None,
3250 *,
3251 max_input_shape: Optional[Mapping[Tuple[TensorId, AxisId], int]] = None,
3252 ) -> _AxisSizes:
3253 """Determine input and output block shape for scale factors **ns**
3254 of parameterized input sizes.
3256 Args:
3257 ns: Scale factor `n` for each axis (keyed by (tensor_id, axis_id))
3258 that is parameterized as `size = min + n * step`.
3259 batch_size: The desired size of the batch dimension.
3260 If given, **batch_size** overrides any batch size present in
3261 **max_input_shape**. Defaults to 1.
3262 max_input_shape: Limits the derived block shapes.
3263 Each axis for which the input size, parameterized by `n`, would exceed
3264 **max_input_shape** has `n` reduced to the smallest value whose
3265 parameterized size still covers **max_input_shape**.
3266 Use this for small input samples or large values of **ns**,
3267 or simply whenever you know the full input shape.
3269 Returns:
3270 Resolved axis sizes for model inputs and outputs.
3271 """
3272 max_input_shape = max_input_shape or {}
3273 if batch_size is None:
3274 for (_t_id, a_id), s in max_input_shape.items():
3275 if a_id == BATCH_AXIS_ID:
3276 batch_size = s
3277 break
3278 else:
3279 batch_size = 1
3281 all_axes = {
3282 t.id: {a.id: a for a in t.axes} for t in chain(self.inputs, self.outputs)
3283 }
3285 inputs: Dict[Tuple[TensorId, AxisId], int] = {}
3286 outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]] = {}
3288 def get_axis_size(a: Union[InputAxis, OutputAxis]):
3289 if isinstance(a, BatchAxis):
3290 if (t_descr.id, a.id) in ns:
3291 logger.warning(
3292 "Ignoring unexpected size increment factor (n) for batch axis"
3293 + " of tensor '{}'.",
3294 t_descr.id,
3295 )
3296 return batch_size
3297 elif isinstance(a.size, int):
3298 if (t_descr.id, a.id) in ns:
3299 logger.warning(
3300 "Ignoring unexpected size increment factor (n) for fixed size"
3301 + " axis '{}' of tensor '{}'.",
3302 a.id,
3303 t_descr.id,
3304 )
3305 return a.size
3306 elif isinstance(a.size, ParameterizedSize):
3307 if (t_descr.id, a.id) not in ns:
3308 raise ValueError(
3309 "Size increment factor (n) missing for parametrized axis"
3310 + f" '{a.id}' of tensor '{t_descr.id}'."
3311 )
3312 n = ns[(t_descr.id, a.id)]
3313 s_max = max_input_shape.get((t_descr.id, a.id))
3314 if s_max is not None:
3315 n = min(n, a.size.get_n(s_max))
3317 return a.size.get_size(n)
3319 elif isinstance(a.size, SizeReference):
3320 if (t_descr.id, a.id) in ns:
3321 logger.warning(
3322 "Ignoring unexpected size increment factor (n) for axis '{}'"
3323 + " of tensor '{}' with size reference.",
3324 a.id,
3325 t_descr.id,
3326 )
3327 assert not isinstance(a, BatchAxis)
3328 ref_axis = all_axes[a.size.tensor_id][a.size.axis_id]
3329 assert not isinstance(ref_axis, BatchAxis)
3330 ref_key = (a.size.tensor_id, a.size.axis_id)
3331 ref_size = inputs.get(ref_key, outputs.get(ref_key))
3332 assert ref_size is not None, ref_key
3333 assert not isinstance(ref_size, _DataDepSize), ref_key
3334 return a.size.get_size(
3335 axis=a,
3336 ref_axis=ref_axis,
3337 ref_size=ref_size,
3338 )
3339 elif isinstance(a.size, DataDependentSize):
3340 if (t_descr.id, a.id) in ns:
3341 logger.warning(
3342 "Ignoring unexpected increment factor (n) for data dependent"
3343 + " size axis '{}' of tensor '{}'.",
3344 a.id,
3345 t_descr.id,
3346 )
3347 return _DataDepSize(a.size.min, a.size.max)
3348 else:
3349 assert_never(a.size)
3351 # first resolve all but the `SizeReference` input sizes
3352 for t_descr in self.inputs:
3353 for a in t_descr.axes:
3354 if not isinstance(a.size, SizeReference):
3355 s = get_axis_size(a)
3356 assert not isinstance(s, _DataDepSize)
3357 inputs[t_descr.id, a.id] = s
3359 # resolve all other input axis sizes
3360 for t_descr in self.inputs:
3361 for a in t_descr.axes:
3362 if isinstance(a.size, SizeReference):
3363 s = get_axis_size(a)
3364 assert not isinstance(s, _DataDepSize)
3365 inputs[t_descr.id, a.id] = s
3367 # resolve all output axis sizes
3368 for t_descr in self.outputs:
3369 for a in t_descr.axes:
3370 assert not isinstance(a.size, ParameterizedSize)
3371 s = get_axis_size(a)
3372 outputs[t_descr.id, a.id] = s
3374 return _AxisSizes(inputs=inputs, outputs=outputs)
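# Illustrative usage sketch tying get_ns / get_tensor_sizes / get_axis_sizes together
# (tensor/axis ids and the parameterization are hypothetical; `model` is a validated
# ModelDescr whose single input "raw" has a batch axis and an x axis with
# ParameterizedSize(min=64, step=16)):
#   ns = model.get_ns({TensorId("raw"): {AxisId("x"): 100}})
#   sizes = model.get_tensor_sizes(ns, batch_size=2)
#   sizes.inputs   # e.g. {TensorId("raw"): {AxisId("batch"): 2, AxisId("x"): 112}}
#   sizes.outputs  # output axis sizes as ints or _DataDepSize(min, max) entries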
3376 @model_validator(mode="before")
3377 @classmethod
3378 def _convert(cls, data: Dict[str, Any]) -> Dict[str, Any]:
3379 cls.convert_from_old_format_wo_validation(data)
3380 return data
3382 @classmethod
3383 def convert_from_old_format_wo_validation(cls, data: Dict[str, Any]) -> None:
3384 """Convert metadata following an older format version to this classes' format
3385 without validating the result.
3386 """
3387 if (
3388 data.get("type") == "model"
3389 and isinstance(fv := data.get("format_version"), str)
3390 and fv.count(".") == 2
3391 ):
3392 fv_parts = fv.split(".")
3393 if any(not p.isdigit() for p in fv_parts):
3394 return
3396 fv_tuple = tuple(map(int, fv_parts))
3398 assert cls.implemented_format_version_tuple[0:2] == (0, 5)
3399 if fv_tuple[:2] in ((0, 3), (0, 4)):
3400 m04 = _ModelDescr_v0_4.load(data)
3401 if isinstance(m04, InvalidDescr):
3402 try:
3403 updated = _model_conv.convert_as_dict(
3404 m04 # pyright: ignore[reportArgumentType]
3405 )
3406 except Exception as e:
3407 logger.error(
3408 "Failed to convert from invalid model 0.4 description."
3409 + f"\nerror: {e}"
3410 + "\nProceeding with model 0.5 validation without conversion."
3411 )
3412 updated = None
3413 else:
3414 updated = _model_conv.convert_as_dict(m04)
3416 if updated is not None:
3417 data.clear()
3418 data.update(updated)
3420 elif fv_tuple[:2] == (0, 5):
3421 # bump patch version
3422 data["format_version"] = cls.implemented_format_version
3425class _ModelConv(Converter[_ModelDescr_v0_4, ModelDescr]):
3426 def _convert(
3427 self, src: _ModelDescr_v0_4, tgt: "type[ModelDescr] | type[dict[str, Any]]"
3428 ) -> "ModelDescr | dict[str, Any]":
3429 name = "".join(
3430 c if c in string.ascii_letters + string.digits + "_+- ()" else " "
3431 for c in src.name
3432 )
3434 def conv_authors(auths: Optional[Sequence[_Author_v0_4]]):
3435 conv = (
3436 _author_conv.convert if TYPE_CHECKING else _author_conv.convert_as_dict
3437 )
3438 return None if auths is None else [conv(a) for a in auths]
3440 if TYPE_CHECKING:
3441 arch_file_conv = _arch_file_conv.convert
3442 arch_lib_conv = _arch_lib_conv.convert
3443 else:
3444 arch_file_conv = _arch_file_conv.convert_as_dict
3445 arch_lib_conv = _arch_lib_conv.convert_as_dict
3447 input_size_refs = {
3448 ipt.name: {
3449 a: s
3450 for a, s in zip(
3451 ipt.axes,
3452 (
3453 ipt.shape.min
3454 if isinstance(ipt.shape, _ParameterizedInputShape_v0_4)
3455 else ipt.shape
3456 ),
3457 )
3458 }
3459 for ipt in src.inputs
3460 if ipt.shape
3461 }
3462 output_size_refs = {
3463 **{
3464 out.name: {a: s for a, s in zip(out.axes, out.shape)}
3465 for out in src.outputs
3466 if not isinstance(out.shape, _ImplicitOutputShape_v0_4)
3467 },
3468 **input_size_refs,
3469 }
3471 return tgt(
3472 attachments=(
3473 []
3474 if src.attachments is None
3475 else [FileDescr(source=f) for f in src.attachments.files]
3476 ),
3477 authors=[_author_conv.convert_as_dict(a) for a in src.authors], # pyright: ignore[reportArgumentType]
3478 cite=[{"text": c.text, "doi": c.doi, "url": c.url} for c in src.cite], # pyright: ignore[reportArgumentType]
3479 config=src.config, # pyright: ignore[reportArgumentType]
3480 covers=src.covers,
3481 description=src.description,
3482 documentation=src.documentation,
3483 format_version="0.5.6",
3484 git_repo=src.git_repo, # pyright: ignore[reportArgumentType]
3485 icon=src.icon,
3486 id=None if src.id is None else ModelId(src.id),
3487 id_emoji=src.id_emoji,
3488 license=src.license, # type: ignore
3489 links=src.links,
3490 maintainers=[_maintainer_conv.convert_as_dict(m) for m in src.maintainers], # pyright: ignore[reportArgumentType]
3491 name=name,
3492 tags=src.tags,
3493 type=src.type,
3494 uploader=src.uploader,
3495 version=src.version,
3496 inputs=[ # pyright: ignore[reportArgumentType]
3497 _input_tensor_conv.convert_as_dict(ipt, tt, st, input_size_refs)
3498 for ipt, tt, st in zip(
3499 src.inputs,
3500 src.test_inputs,
3501 src.sample_inputs or [None] * len(src.test_inputs),
3502 )
3503 ],
3504 outputs=[ # pyright: ignore[reportArgumentType]
3505 _output_tensor_conv.convert_as_dict(out, tt, st, output_size_refs)
3506 for out, tt, st in zip(
3507 src.outputs,
3508 src.test_outputs,
3509 src.sample_outputs or [None] * len(src.test_outputs),
3510 )
3511 ],
3512 parent=(
3513 None
3514 if src.parent is None
3515 else LinkedModel(
3516 id=ModelId(
3517 str(src.parent.id)
3518 + (
3519 ""
3520 if src.parent.version_number is None
3521 else f"/{src.parent.version_number}"
3522 )
3523 )
3524 )
3525 ),
3526 training_data=(
3527 None
3528 if src.training_data is None
3529 else (
3530 LinkedDataset(
3531 id=DatasetId(
3532 str(src.training_data.id)
3533 + (
3534 ""
3535 if src.training_data.version_number is None
3536 else f"/{src.training_data.version_number}"
3537 )
3538 )
3539 )
3540 if isinstance(src.training_data, LinkedDataset02)
3541 else src.training_data
3542 )
3543 ),
3544 packaged_by=[_author_conv.convert_as_dict(a) for a in src.packaged_by], # pyright: ignore[reportArgumentType]
3545 run_mode=src.run_mode,
3546 timestamp=src.timestamp,
3547 weights=(WeightsDescr if TYPE_CHECKING else dict)(
3548 keras_hdf5=(w := src.weights.keras_hdf5)
3549 and (KerasHdf5WeightsDescr if TYPE_CHECKING else dict)(
3550 authors=conv_authors(w.authors),
3551 source=w.source,
3552 tensorflow_version=w.tensorflow_version or Version("1.15"),
3553 parent=w.parent,
3554 ),
3555 onnx=(w := src.weights.onnx)
3556 and (OnnxWeightsDescr if TYPE_CHECKING else dict)(
3557 source=w.source,
3558 authors=conv_authors(w.authors),
3559 parent=w.parent,
3560 opset_version=w.opset_version or 15,
3561 ),
3562 pytorch_state_dict=(w := src.weights.pytorch_state_dict)
3563 and (PytorchStateDictWeightsDescr if TYPE_CHECKING else dict)(
3564 source=w.source,
3565 authors=conv_authors(w.authors),
3566 parent=w.parent,
3567 architecture=(
3568 arch_file_conv(
3569 w.architecture,
3570 w.architecture_sha256,
3571 w.kwargs,
3572 )
3573 if isinstance(w.architecture, _CallableFromFile_v0_4)
3574 else arch_lib_conv(w.architecture, w.kwargs)
3575 ),
3576 pytorch_version=w.pytorch_version or Version("1.10"),
3577 dependencies=(
3578 None
3579 if w.dependencies is None
3580 else (FileDescr if TYPE_CHECKING else dict)(
3581 source=cast(
3582 FileSource,
3583 str(deps := w.dependencies)[
3584 (
3585 len("conda:")
3586 if str(deps).startswith("conda:")
3587 else 0
3588 ) :
3589 ],
3590 )
3591 )
3592 ),
3593 ),
3594 tensorflow_js=(w := src.weights.tensorflow_js)
3595 and (TensorflowJsWeightsDescr if TYPE_CHECKING else dict)(
3596 source=w.source,
3597 authors=conv_authors(w.authors),
3598 parent=w.parent,
3599 tensorflow_version=w.tensorflow_version or Version("1.15"),
3600 ),
3601 tensorflow_saved_model_bundle=(
3602 w := src.weights.tensorflow_saved_model_bundle
3603 )
3604 and (TensorflowSavedModelBundleWeightsDescr if TYPE_CHECKING else dict)(
3605 authors=conv_authors(w.authors),
3606 parent=w.parent,
3607 source=w.source,
3608 tensorflow_version=w.tensorflow_version or Version("1.15"),
3609 dependencies=(
3610 None
3611 if w.dependencies is None
3612 else (FileDescr if TYPE_CHECKING else dict)(
3613 source=cast(
3614 FileSource,
3615 (
3616 str(w.dependencies)[len("conda:") :]
3617 if str(w.dependencies).startswith("conda:")
3618 else str(w.dependencies)
3619 ),
3620 )
3621 )
3622 ),
3623 ),
3624 torchscript=(w := src.weights.torchscript)
3625 and (TorchscriptWeightsDescr if TYPE_CHECKING else dict)(
3626 source=w.source,
3627 authors=conv_authors(w.authors),
3628 parent=w.parent,
3629 pytorch_version=w.pytorch_version or Version("1.10"),
3630 ),
3631 ),
3632 )
3635_model_conv = _ModelConv(_ModelDescr_v0_4, ModelDescr)
3638# create better cover images for 3d data and non-image outputs
3639def generate_covers(
3640 inputs: Sequence[Tuple[InputTensorDescr, NDArray[Any]]],
3641 outputs: Sequence[Tuple[OutputTensorDescr, NDArray[Any]]],
3642) -> List[Path]:
3643 def squeeze(
3644 data: NDArray[Any], axes: Sequence[AnyAxis]
3645 ) -> Tuple[NDArray[Any], List[AnyAxis]]:
3646 """apply numpy.ndarray.squeeze while keeping track of the axis descriptions remaining"""
3647 if data.ndim != len(axes):
3648 raise ValueError(
3649 f"tensor shape {data.shape} does not match described axes"
3650 + f" {[a.id for a in axes]}"
3651 )
3653 axes = [deepcopy(a) for a, s in zip(axes, data.shape) if s != 1]
3654 return data.squeeze(), axes
3656 def normalize(
3657 data: NDArray[Any], axis: Optional[Tuple[int, ...]], eps: float = 1e-7
3658 ) -> NDArray[np.float32]:
3659 data = data.astype("float32")
3660 data -= data.min(axis=axis, keepdims=True)
3661 data /= data.max(axis=axis, keepdims=True) + eps
3662 return data
3664 def to_2d_image(data: NDArray[Any], axes: Sequence[AnyAxis]):
3665 original_shape = data.shape
3666 original_axes = list(axes)
3667 data, axes = squeeze(data, axes)
3669 # take a slice from any batch or index axis if needed,
3670 # convert the first channel axis, and take a slice from any additional channel axes
3671 slices: Tuple[slice, ...] = ()
3672 ndim = data.ndim
3673 ndim_need = 3 if any(isinstance(a, ChannelAxis) for a in axes) else 2
3674 has_c_axis = False
3675 for i, a in enumerate(axes):
3676 s = data.shape[i]
3677 assert s > 1
3678 if (
3679 isinstance(a, (BatchAxis, IndexInputAxis, IndexOutputAxis))
3680 and ndim > ndim_need
3681 ):
3682 data = data[slices + (slice(s // 2 - 1, s // 2),)]
3683 ndim -= 1
3684 elif isinstance(a, ChannelAxis):
3685 if has_c_axis:
3686 # second channel axis
3687 data = data[slices + (slice(0, 1),)]
3688 ndim -= 1
3689 else:
3690 has_c_axis = True
3691 if s == 2:
3692 # visualize two channels with cyan and magenta
3693 data = np.concatenate(
3694 [
3695 data[slices + (slice(1, 2),)],
3696 data[slices + (slice(0, 1),)],
3697 (
3698 data[slices + (slice(0, 1),)]
3699 + data[slices + (slice(1, 2),)]
3700 )
3701 / 2, # TODO: take maximum instead?
3702 ],
3703 axis=i,
3704 )
3705 elif data.shape[i] == 3:
3706 pass # visualize 3 channels as RGB
3707 else:
3708 # visualize first 3 channels as RGB
3709 data = data[slices + (slice(3),)]
3711 assert data.shape[i] == 3
3713 slices += (slice(None),)
3715 data, axes = squeeze(data, axes)
3716 assert len(axes) == ndim
3717 # take slice from z axis if needed
3718 slices = ()
3719 if ndim > ndim_need:
3720 for i, a in enumerate(axes):
3721 s = data.shape[i]
3722 if a.id == AxisId("z"):
3723 data = data[slices + (slice(s // 2 - 1, s // 2),)]
3724 data, axes = squeeze(data, axes)
3725 ndim -= 1
3726 break
3728 slices += (slice(None),)
3730 # take slice from any space or time axis
3731 slices = ()
3733 for i, a in enumerate(axes):
3734 if ndim <= ndim_need:
3735 break
3737 s = data.shape[i]
3738 assert s > 1
3739 if isinstance(
3740 a, (SpaceInputAxis, SpaceOutputAxis, TimeInputAxis, TimeOutputAxis)
3741 ):
3742 data = data[slices + (slice(s // 2 - 1, s // 2),)]
3743 ndim -= 1
3745 slices += (slice(None),)
3747 del slices
3748 data, axes = squeeze(data, axes)
3749 assert len(axes) == ndim
3751 if (has_c_axis and ndim != 3) or (not has_c_axis and ndim != 2):
3752 raise ValueError(
3753 f"Failed to construct cover image from shape {original_shape} with axes {[a.id for a in original_axes]}."
3754 )
3756 if not has_c_axis:
3757 assert ndim == 2
3758 data = np.repeat(data[:, :, None], 3, axis=2)
3759 axes.append(ChannelAxis(channel_names=list(map(Identifier, "RGB"))))
3760 ndim += 1
3762 assert ndim == 3
3764 # transpose axis order such that longest axis comes first...
3765 axis_order: List[int] = list(np.argsort(list(data.shape)))
3766 axis_order.reverse()
3767 # ... and channel axis is last
3768 c = [i for i in range(3) if isinstance(axes[i], ChannelAxis)][0]
3769 axis_order.append(axis_order.pop(c))
3770 axes = [axes[ao] for ao in axis_order]
3771 data = data.transpose(axis_order)
3773 # h, w = data.shape[:2]
3774 # if h / w in (1.0 or 2.0):
3775 # pass
3776 # elif h / w < 2:
3777 # TODO: enforce 2:1 or 1:1 aspect ratio for generated cover images
3779 norm_along = (
3780 tuple(i for i, a in enumerate(axes) if a.type in ("space", "time")) or None
3781 )
3782 # normalize the data and map to 8 bit
3783 data = normalize(data, norm_along)
3784 data = (data * 255).astype("uint8")
3786 return data
3788 def create_diagonal_split_image(im0: NDArray[Any], im1: NDArray[Any]):
3789 assert im0.dtype == im1.dtype == np.uint8
3790 assert im0.shape == im1.shape
3791 assert im0.ndim == 3
3792 N, M, C = im0.shape
3793 assert C == 3
3794 out = np.ones((N, M, C), dtype="uint8")
3795 for c in range(C):
3796 outc = np.tril(im0[..., c])
3797 mask = outc == 0
3798 outc[mask] = np.triu(im1[..., c])[mask]
3799 out[..., c] = outc
3801 return out
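# Minimal standalone sketch of the diagonal split above (tiny hypothetical images):
#   im0 = np.full((2, 2, 3), 100, dtype="uint8")  # e.g. the input cover image
#   im1 = np.full((2, 2, 3), 255, dtype="uint8")  # e.g. the output cover image
#   create_diagonal_split_image(im0, im1)[..., 0]
#   # -> [[100, 255],
#   #     [100, 100]]  (lower triangle incl. diagonal from im0, the rest from im1)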
3803 if not inputs:
3804 raise ValueError("Missing test input tensor for cover generation.")
3806 if not outputs:
3807 raise ValueError("Missing test output tensor for cover generation.")
3809 ipt_descr, ipt = inputs[0]
3810 out_descr, out = outputs[0]
3812 ipt_img = to_2d_image(ipt, ipt_descr.axes)
3813 out_img = to_2d_image(out, out_descr.axes)
3815 cover_folder = Path(mkdtemp())
3816 if ipt_img.shape == out_img.shape:
3817 covers = [cover_folder / "cover.png"]
3818 imwrite(covers[0], create_diagonal_split_image(ipt_img, out_img))
3819 else:
3820 covers = [cover_folder / "input.png", cover_folder / "output.png"]
3821 imwrite(covers[0], ipt_img)
3822 imwrite(covers[1], out_img)
3824 return covers