Coverage for bioimageio/spec/model/v0_5.py: 75%
1355 statements
1from __future__ import annotations
3import collections.abc
4import re
5import string
6import warnings
7from abc import ABC
8from copy import deepcopy
9from itertools import chain
10from math import ceil
11from pathlib import Path, PurePosixPath
12from tempfile import mkdtemp
13from typing import (
14 TYPE_CHECKING,
15 Any,
16 Callable,
17 ClassVar,
18 Dict,
19 Generic,
20 List,
21 Literal,
22 Mapping,
23 NamedTuple,
24 Optional,
25 Sequence,
26 Set,
27 Tuple,
28 Type,
29 TypeVar,
30 Union,
31 cast,
32)
34import numpy as np
35from annotated_types import Ge, Gt, Interval, MaxLen, MinLen, Predicate
36from imageio.v3 import imread, imwrite # pyright: ignore[reportUnknownVariableType]
37from loguru import logger
38from numpy.typing import NDArray
39from pydantic import (
40 AfterValidator,
41 Discriminator,
42 Field,
43 RootModel,
44 SerializationInfo,
45 SerializerFunctionWrapHandler,
46 StrictInt,
47 Tag,
48 ValidationInfo,
49 WrapSerializer,
50 field_validator,
51 model_serializer,
52 model_validator,
53)
54from typing_extensions import Annotated, Self, assert_never, get_args
56from .._internal.common_nodes import (
57 InvalidDescr,
58 Node,
59 NodeWithExplicitlySetFields,
60)
61from .._internal.constants import DTYPE_LIMITS
62from .._internal.field_warning import issue_warning, warn
63from .._internal.io import BioimageioYamlContent as BioimageioYamlContent
64from .._internal.io import FileDescr as FileDescr
65from .._internal.io import (
66 FileSource,
67 WithSuffix,
68 YamlValue,
69 get_reader,
70 wo_special_file_name,
71)
72from .._internal.io_basics import Sha256 as Sha256
73from .._internal.io_packaging import (
74 FileDescr_,
75 FileSource_,
76 package_file_descr_serializer,
77)
78from .._internal.io_utils import load_array
79from .._internal.node_converter import Converter
80from .._internal.type_guards import is_dict, is_sequence
81from .._internal.types import (
82 FAIR,
83 AbsoluteTolerance,
84 LowerCaseIdentifier,
85 LowerCaseIdentifierAnno,
86 MismatchedElementsPerMillion,
87 RelativeTolerance,
88)
89from .._internal.types import Datetime as Datetime
90from .._internal.types import Identifier as Identifier
91from .._internal.types import NotEmpty as NotEmpty
92from .._internal.types import SiUnit as SiUnit
93from .._internal.url import HttpUrl as HttpUrl
94from .._internal.validation_context import get_validation_context
95from .._internal.validator_annotations import RestrictCharacters
96from .._internal.version_type import Version as Version
97from .._internal.warning_levels import INFO
98from ..dataset.v0_2 import DatasetDescr as DatasetDescr02
99from ..dataset.v0_2 import LinkedDataset as LinkedDataset02
100from ..dataset.v0_3 import DatasetDescr as DatasetDescr
101from ..dataset.v0_3 import DatasetId as DatasetId
102from ..dataset.v0_3 import LinkedDataset as LinkedDataset
103from ..dataset.v0_3 import Uploader as Uploader
104from ..generic.v0_3 import (
105 VALID_COVER_IMAGE_EXTENSIONS as VALID_COVER_IMAGE_EXTENSIONS,
106)
107from ..generic.v0_3 import Author as Author
108from ..generic.v0_3 import BadgeDescr as BadgeDescr
109from ..generic.v0_3 import CiteEntry as CiteEntry
110from ..generic.v0_3 import DeprecatedLicenseId as DeprecatedLicenseId
111from ..generic.v0_3 import Doi as Doi
112from ..generic.v0_3 import (
113 FileSource_documentation,
114 GenericModelDescrBase,
115 LinkedResourceBase,
116 _author_conv, # pyright: ignore[reportPrivateUsage]
117 _maintainer_conv, # pyright: ignore[reportPrivateUsage]
118)
119from ..generic.v0_3 import LicenseId as LicenseId
120from ..generic.v0_3 import LinkedResource as LinkedResource
121from ..generic.v0_3 import Maintainer as Maintainer
122from ..generic.v0_3 import OrcidId as OrcidId
123from ..generic.v0_3 import RelativeFilePath as RelativeFilePath
124from ..generic.v0_3 import ResourceId as ResourceId
125from .v0_4 import Author as _Author_v0_4
126from .v0_4 import BinarizeDescr as _BinarizeDescr_v0_4
127from .v0_4 import CallableFromDepencency as CallableFromDepencency
128from .v0_4 import CallableFromDepencency as _CallableFromDepencency_v0_4
129from .v0_4 import CallableFromFile as _CallableFromFile_v0_4
130from .v0_4 import ClipDescr as _ClipDescr_v0_4
131from .v0_4 import ClipKwargs as ClipKwargs
132from .v0_4 import ImplicitOutputShape as _ImplicitOutputShape_v0_4
133from .v0_4 import InputTensorDescr as _InputTensorDescr_v0_4
134from .v0_4 import KnownRunMode as KnownRunMode
135from .v0_4 import ModelDescr as _ModelDescr_v0_4
136from .v0_4 import OutputTensorDescr as _OutputTensorDescr_v0_4
137from .v0_4 import ParameterizedInputShape as _ParameterizedInputShape_v0_4
138from .v0_4 import PostprocessingDescr as _PostprocessingDescr_v0_4
139from .v0_4 import PreprocessingDescr as _PreprocessingDescr_v0_4
140from .v0_4 import ProcessingKwargs as ProcessingKwargs
141from .v0_4 import RunMode as RunMode
142from .v0_4 import ScaleLinearDescr as _ScaleLinearDescr_v0_4
143from .v0_4 import ScaleMeanVarianceDescr as _ScaleMeanVarianceDescr_v0_4
144from .v0_4 import ScaleRangeDescr as _ScaleRangeDescr_v0_4
145from .v0_4 import SigmoidDescr as _SigmoidDescr_v0_4
146from .v0_4 import TensorName as _TensorName_v0_4
147from .v0_4 import WeightsFormat as WeightsFormat
148from .v0_4 import ZeroMeanUnitVarianceDescr as _ZeroMeanUnitVarianceDescr_v0_4
149from .v0_4 import package_weights
151SpaceUnit = Literal[
152 "attometer",
153 "angstrom",
154 "centimeter",
155 "decimeter",
156 "exameter",
157 "femtometer",
158 "foot",
159 "gigameter",
160 "hectometer",
161 "inch",
162 "kilometer",
163 "megameter",
164 "meter",
165 "micrometer",
166 "mile",
167 "millimeter",
168 "nanometer",
169 "parsec",
170 "petameter",
171 "picometer",
172 "terameter",
173 "yard",
174 "yoctometer",
175 "yottameter",
176 "zeptometer",
177 "zettameter",
178]
179"""Space unit compatible with the [OME-Zarr axes specification 0.5](https://ngff.openmicroscopy.org/0.5/#axes-md)"""
181TimeUnit = Literal[
182 "attosecond",
183 "centisecond",
184 "day",
185 "decisecond",
186 "exasecond",
187 "femtosecond",
188 "gigasecond",
189 "hectosecond",
190 "hour",
191 "kilosecond",
192 "megasecond",
193 "microsecond",
194 "millisecond",
195 "minute",
196 "nanosecond",
197 "petasecond",
198 "picosecond",
199 "second",
200 "terasecond",
201 "yoctosecond",
202 "yottasecond",
203 "zeptosecond",
204 "zettasecond",
205]
206"""Time unit compatible with the [OME-Zarr axes specification 0.5](https://ngff.openmicroscopy.org/0.5/#axes-md)"""
208AxisType = Literal["batch", "channel", "index", "time", "space"]
210_AXIS_TYPE_MAP: Mapping[str, AxisType] = {
211 "b": "batch",
212 "t": "time",
213 "i": "index",
214 "c": "channel",
215 "x": "space",
216 "y": "space",
217 "z": "space",
218}
220_AXIS_ID_MAP = {
221 "b": "batch",
222 "t": "time",
223 "i": "index",
224 "c": "channel",
225}
228class TensorId(LowerCaseIdentifier):
229 root_model: ClassVar[Type[RootModel[Any]]] = RootModel[
230 Annotated[LowerCaseIdentifierAnno, MaxLen(32)]
231 ]
234def _normalize_axis_id(a: str):
235 a = str(a)
236 normalized = _AXIS_ID_MAP.get(a, a)
237 if a != normalized:
238 logger.opt(depth=3).warning(
239 "Normalized axis id from '{}' to '{}'.", a, normalized
240 )
241 return normalized
244class AxisId(LowerCaseIdentifier):
245 root_model: ClassVar[Type[RootModel[Any]]] = RootModel[
246 Annotated[
247 LowerCaseIdentifierAnno,
248 MaxLen(16),
249 AfterValidator(_normalize_axis_id),
250 ]
251 ]
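# Illustrative note (not part of the original module): the `AfterValidator` above
# normalizes legacy single-letter axis ids from the 0.4 spec, e.g.
#   AxisId("b")  -> AxisId("batch")
#   AxisId("c")  -> AxisId("channel")
# while spatial letters such as "x", "y", "z" are kept unchanged.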
254def _is_batch(a: str) -> bool:
255 return str(a) == "batch"
258def _is_not_batch(a: str) -> bool:
259 return not _is_batch(a)
262NonBatchAxisId = Annotated[AxisId, Predicate(_is_not_batch)]
264PreprocessingId = Literal[
265 "binarize",
266 "clip",
267 "ensure_dtype",
268 "fixed_zero_mean_unit_variance",
269 "scale_linear",
270 "scale_range",
271 "sigmoid",
272 "softmax",
273]
274PostprocessingId = Literal[
275 "binarize",
276 "clip",
277 "ensure_dtype",
278 "fixed_zero_mean_unit_variance",
279 "scale_linear",
280 "scale_mean_variance",
281 "scale_range",
282 "sigmoid",
283 "softmax",
284 "zero_mean_unit_variance",
285]
288SAME_AS_TYPE = "<same as type>"
291ParameterizedSize_N = int
292"""
293Annotates an integer to calculate a concrete axis size from a `ParameterizedSize`.
294"""
297class ParameterizedSize(Node):
298 """Describes a range of valid tensor axis sizes as `size = min + n*step`.
300 - **min** and **step** are given by the model description.
301 - All blocksize parameters n = 0,1,2,... yield a valid `size`.
302 - A greater blocksize parameter n results in a greater **size**.
303 This allows adjusting the axis size more generically.
304 """
306 N: ClassVar[Type[int]] = ParameterizedSize_N
307 """Positive integer to parameterize this axis"""
309 min: Annotated[int, Gt(0)]
310 step: Annotated[int, Gt(0)]
312 def validate_size(self, size: int) -> int:
313 if size < self.min:
314 raise ValueError(f"size {size} < {self.min}")
315 if (size - self.min) % self.step != 0:
316 raise ValueError(
317 f"axis of size {size} is not parameterized by `min + n*step` ="
318 + f" `{self.min} + n*{self.step}`"
319 )
321 return size
323 def get_size(self, n: ParameterizedSize_N) -> int:
324 return self.min + self.step * n
326 def get_n(self, s: int) -> ParameterizedSize_N:
327 """return the smallest n parameterizing a size greater than or equal to `s`"""
328 return ceil((s - self.min) / self.step)
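# Usage sketch (illustrative, not part of the original module): an axis with
# min=32, step=16 accepts sizes 32, 48, 64, ...; `get_n` rounds up to the
# smallest valid block count.
#   p = ParameterizedSize(min=32, step=16)
#   p.get_size(2)        # 64  (= 32 + 2*16)
#   p.get_n(50)          # 2   (smallest n with 32 + n*16 >= 50)
#   p.validate_size(48)  # 48  (valid: 48 = 32 + 1*16)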
331class DataDependentSize(Node):
332 min: Annotated[int, Gt(0)] = 1
333 max: Annotated[Optional[int], Gt(1)] = None
335 @model_validator(mode="after")
336 def _validate_max_gt_min(self):
337 if self.max is not None and self.min >= self.max:
338 raise ValueError(f"expected `min` < `max`, but got {self.min}, {self.max}")
340 return self
342 def validate_size(self, size: int) -> int:
343 if size < self.min:
344 raise ValueError(f"size {size} < {self.min}")
346 if self.max is not None and size > self.max:
347 raise ValueError(f"size {size} > {self.max}")
349 return size
352class SizeReference(Node):
353 """A tensor axis size (extent in pixels/frames) defined in relation to a reference axis.
355 `axis.size = reference.size * reference.scale / axis.scale + offset`
357 Note:
358 1. The axis and the referenced axis need to have the same unit (or no unit).
359 2. Batch axes may not be referenced.
360 3. Fractions are rounded down.
361 4. If the reference axis is `concatenable` the referencing axis is assumed to be
362 `concatenable` as well with the same block order.
364 Example:
365 An anisotropic input image of w*h=100*49 pixels depicts a physical space of 200*196 mm².
366 Let's assume that we want to express the image height h in relation to its width w
367 instead of only accepting input images of exactly 100*49 pixels
368 (for example to express a range of valid image shapes by parametrizing w, see `ParameterizedSize`).
370 >>> w = SpaceInputAxis(id=AxisId("w"), size=100, unit="millimeter", scale=2)
371 >>> h = SpaceInputAxis(
372 ... id=AxisId("h"),
373 ... size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("w"), offset=-1),
374 ... unit="millimeter",
375 ... scale=4,
376 ... )
377 >>> print(h.size.get_size(h, w))
378 49
380 ⇒ h = w * w.scale / h.scale + offset = 100 * 2mm / 4mm - 1 = 49
381 """
383 tensor_id: TensorId
384 """tensor id of the reference axis"""
386 axis_id: AxisId
387 """axis id of the reference axis"""
389 offset: StrictInt = 0
391 def get_size(
392 self,
393 axis: Union[
394 ChannelAxis,
395 IndexInputAxis,
396 IndexOutputAxis,
397 TimeInputAxis,
398 SpaceInputAxis,
399 TimeOutputAxis,
400 TimeOutputAxisWithHalo,
401 SpaceOutputAxis,
402 SpaceOutputAxisWithHalo,
403 ],
404 ref_axis: Union[
405 ChannelAxis,
406 IndexInputAxis,
407 IndexOutputAxis,
408 TimeInputAxis,
409 SpaceInputAxis,
410 TimeOutputAxis,
411 TimeOutputAxisWithHalo,
412 SpaceOutputAxis,
413 SpaceOutputAxisWithHalo,
414 ],
415 n: ParameterizedSize_N = 0,
416 ref_size: Optional[int] = None,
417 ):
418 """Compute the concrete size for a given axis and its reference axis.
420 Args:
421 axis: The axis this `SizeReference` is the size of.
422 ref_axis: The reference axis to compute the size from.
423 n: If the **ref_axis** is parameterized (of type `ParameterizedSize`)
424 and no fixed **ref_size** is given,
425 **n** is used to compute the size of the parameterized **ref_axis**.
426 ref_size: Overwrite the reference size instead of deriving it from
427 **ref_axis**
428 (**ref_axis.scale** is still used; any given **n** is ignored).
429 """
430 assert axis.size == self, (
431 "Given `axis.size` is not defined by this `SizeReference`"
432 )
434 assert ref_axis.id == self.axis_id, (
435 f"Expected `ref_axis.id` to be {self.axis_id}, but got {ref_axis.id}."
436 )
438 assert axis.unit == ref_axis.unit, (
439 "`SizeReference` requires `axis` and `ref_axis` to have the same `unit`,"
440 f" but {axis.unit}!={ref_axis.unit}"
441 )
442 if ref_size is None:
443 if isinstance(ref_axis.size, (int, float)):
444 ref_size = ref_axis.size
445 elif isinstance(ref_axis.size, ParameterizedSize):
446 ref_size = ref_axis.size.get_size(n)
447 elif isinstance(ref_axis.size, DataDependentSize):
448 raise ValueError(
449 "Reference axis referenced in `SizeReference` may not be a `DataDependentSize`."
450 )
451 elif isinstance(ref_axis.size, SizeReference):
452 raise ValueError(
453 "Reference axis referenced in `SizeReference` may not be sized by a"
454 + " `SizeReference` itself."
455 )
456 else:
457 assert_never(ref_axis.size)
459 return int(ref_size * ref_axis.scale / axis.scale + self.offset)
461 @staticmethod
462 def _get_unit(
463 axis: Union[
464 ChannelAxis,
465 IndexInputAxis,
466 IndexOutputAxis,
467 TimeInputAxis,
468 SpaceInputAxis,
469 TimeOutputAxis,
470 TimeOutputAxisWithHalo,
471 SpaceOutputAxis,
472 SpaceOutputAxisWithHalo,
473 ],
474 ):
475 return axis.unit
478class AxisBase(NodeWithExplicitlySetFields):
479 id: AxisId
480 """An axis id unique across all axes of one tensor."""
482 description: Annotated[str, MaxLen(128)] = ""
483 """A short description of this axis beyond its type and id."""
486class WithHalo(Node):
487 halo: Annotated[int, Ge(1)]
488 """The halo should be cropped from the output tensor to avoid boundary effects.
489 It is to be cropped from both sides, i.e. `size_after_crop = size - 2 * halo`.
490 To document a halo that is already cropped by the model use `size.offset` instead."""
492 size: Annotated[
493 SizeReference,
494 Field(
495 examples=[
496 10,
497 SizeReference(
498 tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
499 ).model_dump(mode="json"),
500 ]
501 ),
502 ]
503 """reference to another axis with an optional offset (see `SizeReference`)"""
506BATCH_AXIS_ID = AxisId("batch")
509class BatchAxis(AxisBase):
510 implemented_type: ClassVar[Literal["batch"]] = "batch"
511 if TYPE_CHECKING:
512 type: Literal["batch"] = "batch"
513 else:
514 type: Literal["batch"]
516 id: Annotated[AxisId, Predicate(_is_batch)] = BATCH_AXIS_ID
517 size: Optional[Literal[1]] = None
518 """The batch size may be fixed to 1,
519 otherwise (the default) it may be chosen arbitrarily depending on available memory"""
521 @property
522 def scale(self):
523 return 1.0
525 @property
526 def concatenable(self):
527 return True
529 @property
530 def unit(self):
531 return None
534class ChannelAxis(AxisBase):
535 implemented_type: ClassVar[Literal["channel"]] = "channel"
536 if TYPE_CHECKING:
537 type: Literal["channel"] = "channel"
538 else:
539 type: Literal["channel"]
541 id: NonBatchAxisId = AxisId("channel")
543 channel_names: NotEmpty[List[Identifier]]
545 @property
546 def size(self) -> int:
547 return len(self.channel_names)
549 @property
550 def concatenable(self):
551 return False
553 @property
554 def scale(self) -> float:
555 return 1.0
557 @property
558 def unit(self):
559 return None
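# Illustrative: a channel axis derives its size from the number of channel names.
#   c = ChannelAxis(channel_names=[Identifier("dapi"), Identifier("gfp")])
#   c.size  # 2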
562class IndexAxisBase(AxisBase):
563 implemented_type: ClassVar[Literal["index"]] = "index"
564 if TYPE_CHECKING:
565 type: Literal["index"] = "index"
566 else:
567 type: Literal["index"]
569 id: NonBatchAxisId = AxisId("index")
571 @property
572 def scale(self) -> float:
573 return 1.0
575 @property
576 def unit(self):
577 return None
580class _WithInputAxisSize(Node):
581 size: Annotated[
582 Union[Annotated[int, Gt(0)], ParameterizedSize, SizeReference],
583 Field(
584 examples=[
585 10,
586 ParameterizedSize(min=32, step=16).model_dump(mode="json"),
587 SizeReference(
588 tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
589 ).model_dump(mode="json"),
590 ]
591 ),
592 ]
593 """The size/length of this axis can be specified as
594 - fixed integer
595 - parameterized series of valid sizes (`ParameterizedSize`)
596 - reference to another axis with an optional offset (`SizeReference`)
597 """
600class IndexInputAxis(IndexAxisBase, _WithInputAxisSize):
601 concatenable: bool = False
602 """If a model has a `concatenable` input axis, it can be processed blockwise,
603 splitting a longer sample axis into blocks matching its input tensor description.
604 Output axes are concatenable if they have a `SizeReference` to a concatenable
605 input axis.
606 """
609class IndexOutputAxis(IndexAxisBase):
610 size: Annotated[
611 Union[Annotated[int, Gt(0)], SizeReference, DataDependentSize],
612 Field(
613 examples=[
614 10,
615 SizeReference(
616 tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
617 ).model_dump(mode="json"),
618 ]
619 ),
620 ]
621 """The size/length of this axis can be specified as
622 - fixed integer
623 - reference to another axis with an optional offset (`SizeReference`)
624 - data dependent size using `DataDependentSize` (size is only known after model inference)
625 """
628class TimeAxisBase(AxisBase):
629 implemented_type: ClassVar[Literal["time"]] = "time"
630 if TYPE_CHECKING:
631 type: Literal["time"] = "time"
632 else:
633 type: Literal["time"]
635 id: NonBatchAxisId = AxisId("time")
636 unit: Optional[TimeUnit] = None
637 scale: Annotated[float, Gt(0)] = 1.0
640class TimeInputAxis(TimeAxisBase, _WithInputAxisSize):
641 concatenable: bool = False
642 """If a model has a `concatenable` input axis, it can be processed blockwise,
643 splitting a longer sample axis into blocks matching its input tensor description.
644 Output axes are concatenable if they have a `SizeReference` to a concatenable
645 input axis.
646 """
649class SpaceAxisBase(AxisBase):
650 implemented_type: ClassVar[Literal["space"]] = "space"
651 if TYPE_CHECKING:
652 type: Literal["space"] = "space"
653 else:
654 type: Literal["space"]
656 id: Annotated[NonBatchAxisId, Field(examples=["x", "y", "z"])] = AxisId("x")
657 unit: Optional[SpaceUnit] = None
658 scale: Annotated[float, Gt(0)] = 1.0
661class SpaceInputAxis(SpaceAxisBase, _WithInputAxisSize):
662 concatenable: bool = False
663 """If a model has a `concatenable` input axis, it can be processed blockwise,
664 splitting a longer sample axis into blocks matching its input tensor description.
665 Output axes are concatenable if they have a `SizeReference` to a concatenable
666 input axis.
667 """
670INPUT_AXIS_TYPES = (
671 BatchAxis,
672 ChannelAxis,
673 IndexInputAxis,
674 TimeInputAxis,
675 SpaceInputAxis,
676)
677"""intended for isinstance comparisons in py<3.10"""
679_InputAxisUnion = Union[
680 BatchAxis, ChannelAxis, IndexInputAxis, TimeInputAxis, SpaceInputAxis
681]
682InputAxis = Annotated[_InputAxisUnion, Discriminator("type")]
685class _WithOutputAxisSize(Node):
686 size: Annotated[
687 Union[Annotated[int, Gt(0)], SizeReference],
688 Field(
689 examples=[
690 10,
691 SizeReference(
692 tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
693 ).model_dump(mode="json"),
694 ]
695 ),
696 ]
697 """The size/length of this axis can be specified as
698 - fixed integer
699 - reference to another axis with an optional offset (see `SizeReference`)
700 """
703class TimeOutputAxis(TimeAxisBase, _WithOutputAxisSize):
704 pass
707class TimeOutputAxisWithHalo(TimeAxisBase, WithHalo):
708 pass
711def _get_halo_axis_discriminator_value(v: Any) -> Literal["with_halo", "wo_halo"]:
712 if isinstance(v, dict):
713 return "with_halo" if "halo" in v else "wo_halo"
714 else:
715 return "with_halo" if hasattr(v, "halo") else "wo_halo"
718_TimeOutputAxisUnion = Annotated[
719 Union[
720 Annotated[TimeOutputAxis, Tag("wo_halo")],
721 Annotated[TimeOutputAxisWithHalo, Tag("with_halo")],
722 ],
723 Discriminator(_get_halo_axis_discriminator_value),
724]
727class SpaceOutputAxis(SpaceAxisBase, _WithOutputAxisSize):
728 pass
731class SpaceOutputAxisWithHalo(SpaceAxisBase, WithHalo):
732 pass
735_SpaceOutputAxisUnion = Annotated[
736 Union[
737 Annotated[SpaceOutputAxis, Tag("wo_halo")],
738 Annotated[SpaceOutputAxisWithHalo, Tag("with_halo")],
739 ],
740 Discriminator(_get_halo_axis_discriminator_value),
741]
744_OutputAxisUnion = Union[
745 BatchAxis, ChannelAxis, IndexOutputAxis, _TimeOutputAxisUnion, _SpaceOutputAxisUnion
746]
747OutputAxis = Annotated[_OutputAxisUnion, Discriminator("type")]
749OUTPUT_AXIS_TYPES = (
750 BatchAxis,
751 ChannelAxis,
752 IndexOutputAxis,
753 TimeOutputAxis,
754 TimeOutputAxisWithHalo,
755 SpaceOutputAxis,
756 SpaceOutputAxisWithHalo,
757)
758"""intended for isinstance comparisons in py<3.10"""
761AnyAxis = Union[InputAxis, OutputAxis]
763ANY_AXIS_TYPES = INPUT_AXIS_TYPES + OUTPUT_AXIS_TYPES
764"""intended for isinstance comparisons in py<3.10"""
766TVs = Union[
767 NotEmpty[List[int]],
768 NotEmpty[List[float]],
769 NotEmpty[List[bool]],
770 NotEmpty[List[str]],
771]
774NominalOrOrdinalDType = Literal[
775 "float32",
776 "float64",
777 "uint8",
778 "int8",
779 "uint16",
780 "int16",
781 "uint32",
782 "int32",
783 "uint64",
784 "int64",
785 "bool",
786]
789class NominalOrOrdinalDataDescr(Node):
790 values: TVs
791 """A fixed set of nominal or an ascending sequence of ordinal values.
792 In this case `data.type` is required to be an unsigned integer type, e.g. 'uint8'.
793 String `values` are interpreted as labels for tensor values 0, ..., N.
794 Note: as YAML 1.2 does not natively support a "set" datatype,
795 nominal values should be given as a sequence (aka list/array) as well.
796 """
798 type: Annotated[
799 NominalOrOrdinalDType,
800 Field(
801 examples=[
802 "float32",
803 "uint8",
804 "uint16",
805 "int64",
806 "bool",
807 ],
808 ),
809 ] = "uint8"
811 @model_validator(mode="after")
812 def _validate_values_match_type(
813 self,
814 ) -> Self:
815 incompatible: List[Any] = []
816 for v in self.values:
817 if self.type == "bool":
818 if not isinstance(v, bool):
819 incompatible.append(v)
820 elif self.type in DTYPE_LIMITS:
821 if (
822 isinstance(v, (int, float))
823 and (
824 v < DTYPE_LIMITS[self.type].min
825 or v > DTYPE_LIMITS[self.type].max
826 )
827 or (isinstance(v, str) and "uint" not in self.type)
828 or (isinstance(v, float) and "int" in self.type)
829 ):
830 incompatible.append(v)
831 else:
832 incompatible.append(v)
834 if len(incompatible) == 5:
835 incompatible.append("...")
836 break
838 if incompatible:
839 raise ValueError(
840 f"data type '{self.type}' incompatible with values {incompatible}"
841 )
843 return self
845 unit: Optional[Union[Literal["arbitrary unit"], SiUnit]] = None
847 @property
848 def range(self):
849 if isinstance(self.values[0], str):
850 return 0, len(self.values) - 1
851 else:
852 return min(self.values), max(self.values)
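# Illustrative: string values act as labels for tensor values 0, ..., N-1, so the
# derived `range` of a two-label description is (0, 1).
#   d = NominalOrOrdinalDataDescr(values=["background", "foreground"], type="uint8")
#   d.range  # (0, 1)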
855IntervalOrRatioDType = Literal[
856 "float32",
857 "float64",
858 "uint8",
859 "int8",
860 "uint16",
861 "int16",
862 "uint32",
863 "int32",
864 "uint64",
865 "int64",
866]
869class IntervalOrRatioDataDescr(Node):
870 type: Annotated[ # TODO: rename to dtype
871 IntervalOrRatioDType,
872 Field(
873 examples=["float32", "float64", "uint8", "uint16"],
874 ),
875 ] = "float32"
876 range: Tuple[Optional[float], Optional[float]] = (
877 None,
878 None,
879 )
880 """Tuple `(minimum, maximum)` specifying the allowed range of the data in this tensor.
881 `None` corresponds to min/max of what can be expressed by **type**."""
882 unit: Union[Literal["arbitrary unit"], SiUnit] = "arbitrary unit"
883 scale: float = 1.0
884 """Scale for data on an interval (or ratio) scale."""
885 offset: Optional[float] = None
886 """Offset for data on a ratio scale."""
888 @model_validator(mode="before")
889 def _replace_inf(cls, data: Any):
890 if is_dict(data):
891 if "range" in data and is_sequence(data["range"]):
892 forbidden = (
893 "inf",
894 "-inf",
895 ".inf",
896 "-.inf",
897 float("inf"),
898 float("-inf"),
899 )
900 if any(v in forbidden for v in data["range"]):
901 issue_warning("replaced 'inf' value", value=data["range"])
903 data["range"] = tuple(
904 (None if v in forbidden else v) for v in data["range"]
905 )
907 return data
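# Illustrative: open-ended ranges are expressed with `None`; literal 'inf' values
# coming from YAML are replaced (with a warning) by the validator above.
#   d = IntervalOrRatioDataDescr(type="float32", range=(0.0, None))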
910TensorDataDescr = Union[NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr]
913class ProcessingDescrBase(NodeWithExplicitlySetFields, ABC):
914 """processing base class"""
917class BinarizeKwargs(ProcessingKwargs):
918 """Keyword arguments for `BinarizeDescr`"""
920 threshold: float
921 """The fixed threshold"""
924class BinarizeAlongAxisKwargs(ProcessingKwargs):
925 """Keyword arguments for `BinarizeDescr`"""
927 threshold: NotEmpty[List[float]]
928 """The fixed threshold values along `axis`"""
930 axis: Annotated[NonBatchAxisId, Field(examples=["channel"])]
931 """The `threshold` axis"""
934class BinarizeDescr(ProcessingDescrBase):
935 """Binarize the tensor with a fixed threshold.
937 Values above `BinarizeKwargs.threshold`/`BinarizeAlongAxisKwargs.threshold`
938 will be set to one, values below the threshold to zero.
940 Examples:
941 - in YAML
942 ```yaml
943 postprocessing:
944 - id: binarize
945 kwargs:
946 axis: 'channel'
947 threshold: [0.25, 0.5, 0.75]
948 ```
949 - in Python:
950 >>> postprocessing = [BinarizeDescr(
951 ... kwargs=BinarizeAlongAxisKwargs(
952 ... axis=AxisId('channel'),
953 ... threshold=[0.25, 0.5, 0.75],
954 ... )
955 ... )]
956 """
958 implemented_id: ClassVar[Literal["binarize"]] = "binarize"
959 if TYPE_CHECKING:
960 id: Literal["binarize"] = "binarize"
961 else:
962 id: Literal["binarize"]
963 kwargs: Union[BinarizeKwargs, BinarizeAlongAxisKwargs]
966class ClipDescr(ProcessingDescrBase):
967 """Set tensor values below min to min and above max to max.
969 See `ScaleRangeDescr` for examples.
970 """
972 implemented_id: ClassVar[Literal["clip"]] = "clip"
973 if TYPE_CHECKING:
974 id: Literal["clip"] = "clip"
975 else:
976 id: Literal["clip"]
978 kwargs: ClipKwargs
981class EnsureDtypeKwargs(ProcessingKwargs):
982 """Keyword arguments for `EnsureDtypeDescr`"""
984 dtype: Literal[
985 "float32",
986 "float64",
987 "uint8",
988 "int8",
989 "uint16",
990 "int16",
991 "uint32",
992 "int32",
993 "uint64",
994 "int64",
995 "bool",
996 ]
999class EnsureDtypeDescr(ProcessingDescrBase):
1000 """Cast the tensor data type to `EnsureDtypeKwargs.dtype` (if not matching).
1002 This can for example be used to ensure the inner neural network model gets a
1003 different input tensor data type than the fully described bioimage.io model does.
1005 Examples:
1006 The described bioimage.io model (incl. preprocessing) accepts any
1007 float32-compatible tensor, normalizes it with percentiles and clipping and then
1008 casts it to uint8, which is what the neural network in this example expects.
1009 - in YAML
1010 ```yaml
1011 inputs:
1012 - data:
1013 type: float32 # described bioimage.io model is compatible with any float32 input tensor
1014 preprocessing:
1015 - id: scale_range
1016 kwargs:
1017 axes: ['y', 'x']
1018 max_percentile: 99.8
1019 min_percentile: 5.0
1020 - id: clip
1021 kwargs:
1022 min: 0.0
1023 max: 1.0
1024 - id: ensure_dtype # the neural network of the model requires uint8
1025 kwargs:
1026 dtype: uint8
1027 ```
1028 - in Python:
1029 >>> preprocessing = [
1030 ... ScaleRangeDescr(
1031 ... kwargs=ScaleRangeKwargs(
1032 ... axes= (AxisId('y'), AxisId('x')),
1033 ... max_percentile= 99.8,
1034 ... min_percentile= 5.0,
1035 ... )
1036 ... ),
1037 ... ClipDescr(kwargs=ClipKwargs(min=0.0, max=1.0)),
1038 ... EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="uint8")),
1039 ... ]
1040 """
1042 implemented_id: ClassVar[Literal["ensure_dtype"]] = "ensure_dtype"
1043 if TYPE_CHECKING:
1044 id: Literal["ensure_dtype"] = "ensure_dtype"
1045 else:
1046 id: Literal["ensure_dtype"]
1048 kwargs: EnsureDtypeKwargs
1051class ScaleLinearKwargs(ProcessingKwargs):
1052 """Keyword arguments for `ScaleLinearDescr`"""
1054 gain: float = 1.0
1055 """multiplicative factor"""
1057 offset: float = 0.0
1058 """additive term"""
1060 @model_validator(mode="after")
1061 def _validate(self) -> Self:
1062 if self.gain == 1.0 and self.offset == 0.0:
1063 raise ValueError(
1064 "Redundant linear scaling not allowed. Set `gain` != 1.0 and/or `offset`"
1065 + " != 0.0."
1066 )
1068 return self
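# The operation implied by these kwargs (illustrative, assuming element-wise
# application by the consumer): out = tensor * gain + offset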
1071class ScaleLinearAlongAxisKwargs(ProcessingKwargs):
1072 """Keyword arguments for `ScaleLinearDescr`"""
1074 axis: Annotated[NonBatchAxisId, Field(examples=["channel"])]
1075 """The axis of gain and offset values."""
1077 gain: Union[float, NotEmpty[List[float]]] = 1.0
1078 """multiplicative factor"""
1080 offset: Union[float, NotEmpty[List[float]]] = 0.0
1081 """additive term"""
1083 @model_validator(mode="after")
1084 def _validate(self) -> Self:
1085 if isinstance(self.gain, list):
1086 if isinstance(self.offset, list):
1087 if len(self.gain) != len(self.offset):
1088 raise ValueError(
1089 f"Size of `gain` ({len(self.gain)}) and `offset` ({len(self.offset)}) must match."
1090 )
1091 else:
1092 self.offset = [float(self.offset)] * len(self.gain)
1093 elif isinstance(self.offset, list):
1094 self.gain = [float(self.gain)] * len(self.offset)
1095 else:
1096 raise ValueError(
1097 "Do not specify an `axis` for scalar gain and offset values."
1098 )
1100 if all(g == 1.0 for g in self.gain) and all(off == 0.0 for off in self.offset):
1101 "Redundant linear scaling not allowed. Set `gain` != 1.0 and/or `offset`"
1102 "Redundant linear scaling not allowd. Set `gain` != 1.0 and/or `offset`"
1103 + " != 0.0."
1104 )
1106 return self
1109class ScaleLinearDescr(ProcessingDescrBase):
1110 """Fixed linear scaling.
1112 Examples:
1113 1. Scale with scalar gain and offset
1114 - in YAML
1115 ```yaml
1116 preprocessing:
1117 - id: scale_linear
1118 kwargs:
1119 gain: 2.0
1120 offset: 3.0
1121 ```
1122 - in Python:
1123 >>> preprocessing = [
1124 ... ScaleLinearDescr(kwargs=ScaleLinearKwargs(gain= 2.0, offset=3.0))
1125 ... ]
1127 2. Independent scaling along an axis
1128 - in YAML
1129 ```yaml
1130 preprocessing:
1131 - id: scale_linear
1132 kwargs:
1133 axis: 'channel'
1134 gain: [1.0, 2.0, 3.0]
1135 ```
1136 - in Python:
1137 >>> preprocessing = [
1138 ... ScaleLinearDescr(
1139 ... kwargs=ScaleLinearAlongAxisKwargs(
1140 ... axis=AxisId("channel"),
1141 ... gain=[1.0, 2.0, 3.0],
1142 ... )
1143 ... )
1144 ... ]
1146 """
1148 implemented_id: ClassVar[Literal["scale_linear"]] = "scale_linear"
1149 if TYPE_CHECKING:
1150 id: Literal["scale_linear"] = "scale_linear"
1151 else:
1152 id: Literal["scale_linear"]
1153 kwargs: Union[ScaleLinearKwargs, ScaleLinearAlongAxisKwargs]
1156class SigmoidDescr(ProcessingDescrBase):
1157 """The logistic sigmoid function, a.k.a. expit function.
1159 Examples:
1160 - in YAML
1161 ```yaml
1162 postprocessing:
1163 - id: sigmoid
1164 ```
1165 - in Python:
1166 >>> postprocessing = [SigmoidDescr()]
1167 """
1169 implemented_id: ClassVar[Literal["sigmoid"]] = "sigmoid"
1170 if TYPE_CHECKING:
1171 id: Literal["sigmoid"] = "sigmoid"
1172 else:
1173 id: Literal["sigmoid"]
1175 @property
1176 def kwargs(self) -> ProcessingKwargs:
1177 """empty kwargs"""
1178 return ProcessingKwargs()
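# Sketch of the operation a consumer would apply (assumes numpy; illustrative,
# not part of the original module):
#   import numpy as np
#   def _sigmoid(x: "np.ndarray") -> "np.ndarray":
#       return 1.0 / (1.0 + np.exp(-x))  # expit(x)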
1181class SoftmaxKwargs(ProcessingKwargs):
1182 """Keyword arguments for `SoftmaxDescr`"""
1184 axis: Annotated[NonBatchAxisId, Field(examples=["channel"])] = AxisId("channel")
1185 """The axis to apply the softmax function along.
1186 Note:
1187 Defaults to 'channel' axis
1188 (which may not exist, in which case
1189 a different axis id has to be specified).
1190 """
1193class SoftmaxDescr(ProcessingDescrBase):
1194 """The softmax function.
1196 Examples:
1197 - in YAML
1198 ```yaml
1199 postprocessing:
1200 - id: softmax
1201 kwargs:
1202 axis: channel
1203 ```
1204 - in Python:
1205 >>> postprocessing = [SoftmaxDescr(kwargs=SoftmaxKwargs(axis=AxisId("channel")))]
1206 """
1208 implemented_id: ClassVar[Literal["softmax"]] = "softmax"
1209 if TYPE_CHECKING:
1210 id: Literal["softmax"] = "softmax"
1211 else:
1212 id: Literal["softmax"]
1214 kwargs: SoftmaxKwargs = Field(default_factory=SoftmaxKwargs.model_construct)
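# Sketch of the softmax a consumer would apply along the configured axis
# (assumes numpy; illustrative only):
#   import numpy as np
#   def _softmax(x: "np.ndarray", axis: int) -> "np.ndarray":
#       e = np.exp(x - x.max(axis=axis, keepdims=True))  # shift by max for stability
#       return e / e.sum(axis=axis, keepdims=True)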
1217class FixedZeroMeanUnitVarianceKwargs(ProcessingKwargs):
1218 """Keyword arguments for `FixedZeroMeanUnitVarianceDescr`"""
1220 mean: float
1221 """The mean value to normalize with."""
1223 std: Annotated[float, Ge(1e-6)]
1224 """The standard deviation value to normalize with."""
1227class FixedZeroMeanUnitVarianceAlongAxisKwargs(ProcessingKwargs):
1228 """Keyword arguments for `FixedZeroMeanUnitVarianceDescr`"""
1230 mean: NotEmpty[List[float]]
1231 """The mean value(s) to normalize with."""
1233 std: NotEmpty[List[Annotated[float, Ge(1e-6)]]]
1234 """The standard deviation value(s) to normalize with.
1235 Size must match `mean` values."""
1237 axis: Annotated[NonBatchAxisId, Field(examples=["channel", "index"])]
1238 """The axis of the mean/std values to normalize each entry along that dimension
1239 separately."""
1241 @model_validator(mode="after")
1242 def _mean_and_std_match(self) -> Self:
1243 if len(self.mean) != len(self.std):
1244 raise ValueError(
1245 f"Size of `mean` ({len(self.mean)}) and `std` ({len(self.std)})"
1246 + " must match."
1247 )
1249 return self
1252class FixedZeroMeanUnitVarianceDescr(ProcessingDescrBase):
1253 """Subtract a given mean and divide by the standard deviation.
1255 Normalize with fixed, precomputed values for
1256 `FixedZeroMeanUnitVarianceKwargs.mean` and `FixedZeroMeanUnitVarianceKwargs.std`.
1257 Use `FixedZeroMeanUnitVarianceAlongAxisKwargs` for independent scaling along given
1258 axes.
1260 Examples:
1261 1. scalar value for whole tensor
1262 - in YAML
1263 ```yaml
1264 preprocessing:
1265 - id: fixed_zero_mean_unit_variance
1266 kwargs:
1267 mean: 103.5
1268 std: 13.7
1269 ```
1270 - in Python
1271 >>> preprocessing = [FixedZeroMeanUnitVarianceDescr(
1272 ... kwargs=FixedZeroMeanUnitVarianceKwargs(mean=103.5, std=13.7)
1273 ... )]
1275 2. independently along an axis
1276 - in YAML
1277 ```yaml
1278 preprocessing:
1279 - id: fixed_zero_mean_unit_variance
1280 kwargs:
1281 axis: channel
1282 mean: [101.5, 102.5, 103.5]
1283 std: [11.7, 12.7, 13.7]
1284 ```
1285 - in Python
1286 >>> preprocessing = [FixedZeroMeanUnitVarianceDescr(
1287 ... kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs(
1288 ... axis=AxisId("channel"),
1289 ... mean=[101.5, 102.5, 103.5],
1290 ... std=[11.7, 12.7, 13.7],
1291 ... )
1292 ... )]
1293 """
1295 implemented_id: ClassVar[Literal["fixed_zero_mean_unit_variance"]] = (
1296 "fixed_zero_mean_unit_variance"
1297 )
1298 if TYPE_CHECKING:
1299 id: Literal["fixed_zero_mean_unit_variance"] = "fixed_zero_mean_unit_variance"
1300 else:
1301 id: Literal["fixed_zero_mean_unit_variance"]
1303 kwargs: Union[
1304 FixedZeroMeanUnitVarianceKwargs, FixedZeroMeanUnitVarianceAlongAxisKwargs
1305 ]
1308class ZeroMeanUnitVarianceKwargs(ProcessingKwargs):
1309 """Keyword arguments for `ZeroMeanUnitVarianceDescr`"""
1311 axes: Annotated[
1312 Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
1313 ] = None
1314 """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std.
1315 For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
1316 resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
1317 To normalize each sample independently leave out the 'batch' axis.
1318 Default: Scale all axes jointly."""
1320 eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
1321 """epsilon for numeric stability: `out = (tensor - mean) / (std + eps)`."""
1324class ZeroMeanUnitVarianceDescr(ProcessingDescrBase):
1325 """Subtract mean and divide by standard deviation.
1327 Examples:
1328 Subtract the tensor mean and divide by its standard deviation
1329 - in YAML
1330 ```yaml
1331 preprocessing:
1332 - id: zero_mean_unit_variance
1333 ```
1334 - in Python
1335 >>> preprocessing = [ZeroMeanUnitVarianceDescr()]
1336 """
1338 implemented_id: ClassVar[Literal["zero_mean_unit_variance"]] = (
1339 "zero_mean_unit_variance"
1340 )
1341 if TYPE_CHECKING:
1342 id: Literal["zero_mean_unit_variance"] = "zero_mean_unit_variance"
1343 else:
1344 id: Literal["zero_mean_unit_variance"]
1346 kwargs: ZeroMeanUnitVarianceKwargs = Field(
1347 default_factory=ZeroMeanUnitVarianceKwargs.model_construct
1348 )
1351class ScaleRangeKwargs(ProcessingKwargs):
1352 """Keyword arguments for `ScaleRangeDescr`
1354 For `min_percentile`=0.0 (the default) and `max_percentile`=100 (the default)
1355 this processing step normalizes data to the [0, 1] interval.
1356 For other percentiles the normalized values will partially be outside the [0, 1]
1357 interval. Use `ScaleRangeDescr` followed by `ClipDescr` if you want to limit the
1358 normalized values to a range.
1359 """
1361 axes: Annotated[
1362 Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
1363 ] = None
1364 """The subset of axes to normalize jointly, i.e. axes to reduce to compute the min/max percentile value.
1365 For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
1366 resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
1367 To normalize samples independently, leave out the "batch" axis.
1368 Default: Scale all axes jointly."""
1370 min_percentile: Annotated[float, Interval(ge=0, lt=100)] = 0.0
1371 """The lower percentile used to determine the value to align with zero."""
1373 max_percentile: Annotated[float, Interval(gt=1, le=100)] = 100.0
1374 """The upper percentile used to determine the value to align with one.
1375 Has to be bigger than `min_percentile`.
1376 The range is 1 to 100 instead of 0 to 100 to avoid mistakenly
1377 accepting percentiles specified in the range 0.0 to 1.0."""
1379 eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
1380 """Epsilon for numeric stability.
1381 `out = (tensor - v_lower) / (v_upper - v_lower + eps)`;
1382 with `v_lower,v_upper` values at the respective percentiles."""
1384 reference_tensor: Optional[TensorId] = None
1385 """Tensor ID to compute the percentiles from. Default: The tensor itself.
1386 For any tensor in `inputs` only input tensor references are allowed."""
1388 @field_validator("max_percentile", mode="after")
1389 @classmethod
1390 def min_smaller_max(cls, value: float, info: ValidationInfo) -> float:
1391 if (min_p := info.data["min_percentile"]) >= value:
1392 raise ValueError(f"min_percentile {min_p} >= max_percentile {value}")
1394 return value
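# Sketch of the percentile scaling implied by these kwargs (assumes numpy;
# illustrative only): out = (tensor - v_lower) / (v_upper - v_lower + eps)
#   import numpy as np
#   def _scale_range(x, min_percentile=0.0, max_percentile=100.0, eps=1e-6, axes=None):
#       v_lower = np.percentile(x, min_percentile, axis=axes, keepdims=True)
#       v_upper = np.percentile(x, max_percentile, axis=axes, keepdims=True)
#       return (x - v_lower) / (v_upper - v_lower + eps)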
1397class ScaleRangeDescr(ProcessingDescrBase):
1398 """Scale with percentiles.
1400 Examples:
1401 1. Scale linearly to map 5th percentile to 0 and 99.8th percentile to 1.0
1402 - in YAML
1403 ```yaml
1404 preprocessing:
1405 - id: scale_range
1406 kwargs:
1407 axes: ['y', 'x']
1408 max_percentile: 99.8
1409 min_percentile: 5.0
1410 ```
1411 - in Python
1412 >>> preprocessing = [
1413 ... ScaleRangeDescr(
1414 ... kwargs=ScaleRangeKwargs(
1415 ... axes= (AxisId('y'), AxisId('x')),
1416 ... max_percentile= 99.8,
1417 ... min_percentile= 5.0,
1418 ... )
1419 ... ),
1420 ... ClipDescr(
1421 ... kwargs=ClipKwargs(
1422 ... min=0.0,
1423 ... max=1.0,
1424 ... )
1425 ... ),
1426 ... ]
1428 2. Combine the above scaling with additional clipping to clip values outside the range given by the percentiles.
1429 - in YAML
1430 ```yaml
1431 preprocessing:
1432 - id: scale_range
1433 kwargs:
1434 axes: ['y', 'x']
1435 max_percentile: 99.8
1436 min_percentile: 5.0
1438 - id: clip
1439 kwargs:
1440 min: 0.0
1441 max: 1.0
1442 ```
1443 - in Python
1444 >>> preprocessing = [ScaleRangeDescr(
1445 ... kwargs=ScaleRangeKwargs(
1446 ... axes= (AxisId('y'), AxisId('x')),
1447 ... max_percentile= 99.8,
1448 ... min_percentile= 5.0,
1449 ... )
1450 ... )]
1452 """
1454 implemented_id: ClassVar[Literal["scale_range"]] = "scale_range"
1455 if TYPE_CHECKING:
1456 id: Literal["scale_range"] = "scale_range"
1457 else:
1458 id: Literal["scale_range"]
1459 kwargs: ScaleRangeKwargs = Field(default_factory=ScaleRangeKwargs.model_construct)
1462class ScaleMeanVarianceKwargs(ProcessingKwargs):
1463 """Keyword arguments for `ScaleMeanVarianceDescr`"""
1465 reference_tensor: TensorId
1466 """Name of tensor to match."""
1468 axes: Annotated[
1469 Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
1470 ] = None
1471 """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std.
1472 For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
1473 resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
1474 To normalize samples independently, leave out the 'batch' axis.
1475 Default: Scale all axes jointly."""
1477 eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
1478 """Epsilon for numeric stability:
1479 `out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.`"""
1482class ScaleMeanVarianceDescr(ProcessingDescrBase):
1483 """Scale a tensor's data distribution to match another tensor's mean/std.
1484 `out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.`
1485 """
1487 implemented_id: ClassVar[Literal["scale_mean_variance"]] = "scale_mean_variance"
1488 if TYPE_CHECKING:
1489 id: Literal["scale_mean_variance"] = "scale_mean_variance"
1490 else:
1491 id: Literal["scale_mean_variance"]
1492 kwargs: ScaleMeanVarianceKwargs
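# Sketch of the operation (assumes numpy; illustrative only):
#   out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean
#   def _scale_mean_variance(x, ref, eps=1e-6):
#       return (x - x.mean()) / (x.std() + eps) * (ref.std() + eps) + ref.mean()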
1495PreprocessingDescr = Annotated[
1496 Union[
1497 BinarizeDescr,
1498 ClipDescr,
1499 EnsureDtypeDescr,
1500 FixedZeroMeanUnitVarianceDescr,
1501 ScaleLinearDescr,
1502 ScaleRangeDescr,
1503 SigmoidDescr,
1504 SoftmaxDescr,
1505 ZeroMeanUnitVarianceDescr,
1506 ],
1507 Discriminator("id"),
1508]
1509PostprocessingDescr = Annotated[
1510 Union[
1511 BinarizeDescr,
1512 ClipDescr,
1513 EnsureDtypeDescr,
1514 FixedZeroMeanUnitVarianceDescr,
1515 ScaleLinearDescr,
1516 ScaleMeanVarianceDescr,
1517 ScaleRangeDescr,
1518 SigmoidDescr,
1519 SoftmaxDescr,
1520 ZeroMeanUnitVarianceDescr,
1521 ],
1522 Discriminator("id"),
1523]
1525IO_AxisT = TypeVar("IO_AxisT", InputAxis, OutputAxis)
1528class TensorDescrBase(Node, Generic[IO_AxisT]):
1529 id: TensorId
1530 """Tensor id. No duplicates are allowed."""
1532 description: Annotated[str, MaxLen(128)] = ""
1533 """free text description"""
1535 axes: NotEmpty[Sequence[IO_AxisT]]
1536 """tensor axes"""
1538 @property
1539 def shape(self):
1540 return tuple(a.size for a in self.axes)
1542 @field_validator("axes", mode="after", check_fields=False)
1543 @classmethod
1544 def _validate_axes(cls, axes: Sequence[AnyAxis]) -> Sequence[AnyAxis]:
1545 batch_axes = [a for a in axes if a.type == "batch"]
1546 if len(batch_axes) > 1:
1547 raise ValueError(
1548 f"Only one batch axis (per tensor) allowed, but got {batch_axes}"
1549 )
1551 seen_ids: Set[AxisId] = set()
1552 duplicate_axes_ids: Set[AxisId] = set()
1553 for a in axes:
1554 (duplicate_axes_ids if a.id in seen_ids else seen_ids).add(a.id)
1556 if duplicate_axes_ids:
1557 raise ValueError(f"Duplicate axis ids: {duplicate_axes_ids}")
1559 return axes
1561 test_tensor: FAIR[Optional[FileDescr_]] = None
1562 """An example tensor to use for testing.
1563 Using the model with the test input tensors is expected to yield the test output tensors.
1564 Each test tensor has to be an ndarray in the
1565 [numpy.lib file format](https://numpy.org/doc/stable/reference/generated/numpy.lib.format.html#module-numpy.lib.format).
1566 The file extension must be '.npy'."""
1568 sample_tensor: FAIR[Optional[FileDescr_]] = None
1569 """A sample tensor to illustrate a possible input/output for the model.
1570 The sample image primarily serves to inform a human user about an example use case
1571 and is typically stored as .hdf5, .png or .tiff.
1572 It has to be readable by the [imageio library](https://imageio.readthedocs.io/en/stable/formats/index.html#supported-formats)
1573 (numpy's `.npy` format is not supported).
1574 The image dimensionality has to match the number of axes specified in this tensor description.
1575 """
1577 @model_validator(mode="after")
1578 def _validate_sample_tensor(self) -> Self:
1579 if self.sample_tensor is None or not get_validation_context().perform_io_checks:
1580 return self
1582 reader = get_reader(self.sample_tensor.source, sha256=self.sample_tensor.sha256)
1583 tensor: NDArray[Any] = imread(
1584 reader.read(),
1585 extension=PurePosixPath(reader.original_file_name).suffix,
1586 )
1587 n_dims = len(tensor.squeeze().shape)
1588 n_dims_min = n_dims_max = len(self.axes)
1590 for a in self.axes:
1591 if isinstance(a, BatchAxis):
1592 n_dims_min -= 1
1593 elif isinstance(a.size, int):
1594 if a.size == 1:
1595 n_dims_min -= 1
1596 elif isinstance(a.size, (ParameterizedSize, DataDependentSize)):
1597 if a.size.min == 1:
1598 n_dims_min -= 1
1599 elif isinstance(a.size, SizeReference):
1600 if a.size.offset < 2:
1601 # size reference may result in singleton axis
1602 n_dims_min -= 1
1603 else:
1604 assert_never(a.size)
1606 n_dims_min = max(0, n_dims_min)
1607 if n_dims < n_dims_min or n_dims > n_dims_max:
1608 raise ValueError(
1609 f"Expected sample tensor to have {n_dims_min} to"
1610 + f" {n_dims_max} dimensions, but found {n_dims} (shape: {tensor.shape})."
1611 )
1613 return self
1615 data: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]] = (
1616 IntervalOrRatioDataDescr()
1617 )
1618 """Description of the tensor's data values, optionally per channel.
1619 If specified per channel, the data `type` needs to match across channels."""
1621 @property
1622 def dtype(
1623 self,
1624 ) -> Literal[
1625 "float32",
1626 "float64",
1627 "uint8",
1628 "int8",
1629 "uint16",
1630 "int16",
1631 "uint32",
1632 "int32",
1633 "uint64",
1634 "int64",
1635 "bool",
1636 ]:
1637 """dtype as specified under `data.type` or `data[i].type`"""
1638 if isinstance(self.data, collections.abc.Sequence):
1639 return self.data[0].type
1640 else:
1641 return self.data.type
1643 @field_validator("data", mode="after")
1644 @classmethod
1645 def _check_data_type_across_channels(
1646 cls, value: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]]
1647 ) -> Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]]:
1648 if not isinstance(value, list):
1649 return value
1651 dtypes = {t.type for t in value}
1652 if len(dtypes) > 1:
1653 raise ValueError(
1654 "Tensor data descriptions per channel need to agree in their data"
1655 + f" `type`, but found {dtypes}."
1656 )
1658 return value
1660 @model_validator(mode="after")
1661 def _check_data_matches_channelaxis(self) -> Self:
1662 if not isinstance(self.data, (list, tuple)):
1663 return self
1665 for a in self.axes:
1666 if isinstance(a, ChannelAxis):
1667 size = a.size
1668 assert isinstance(size, int)
1669 break
1670 else:
1671 return self
1673 if len(self.data) != size:
1674 raise ValueError(
1675 f"Got tensor data descriptions for {len(self.data)} channels, but"
1676 + f" '{a.id}' axis has size {size}."
1677 )
1679 return self
1681 def get_axis_sizes_for_array(self, array: NDArray[Any]) -> Dict[AxisId, int]:
1682 if len(array.shape) != len(self.axes):
1683 raise ValueError(
1684 f"Dimension mismatch: array shape {array.shape} (#{len(array.shape)})"
1685 + f" incompatible with {len(self.axes)} axes."
1686 )
1687 return {a.id: array.shape[i] for i, a in enumerate(self.axes)}
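# Illustrative usage (hypothetical tensor description `descr` with axes
# (batch, channel, y, x)):
#   descr.get_axis_sizes_for_array(np.zeros((1, 3, 256, 256)))
#   # -> {AxisId("batch"): 1, AxisId("channel"): 3, AxisId("y"): 256, AxisId("x"): 256}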
1690class InputTensorDescr(TensorDescrBase[InputAxis]):
1691 id: TensorId = TensorId("input")
1692 """Input tensor id.
1693 No duplicates are allowed across all inputs and outputs."""
1695 optional: bool = False
1696 """indicates that this tensor may be `None`"""
1698 preprocessing: List[PreprocessingDescr] = Field(
1699 default_factory=cast(Callable[[], List[PreprocessingDescr]], list)
1700 )
1702 """Description of how this input should be preprocessed.
1704 notes:
1705 - If preprocessing does not start with an 'ensure_dtype' entry, it is added
1706 to ensure an input tensor's data type matches the input tensor's data description.
1707 - If preprocessing does not end with an 'ensure_dtype' or 'binarize' entry, an
1708 'ensure_dtype' step is added to ensure preprocessing steps are not unintentionally
1709 changing the data type.
1710 """
1712 @model_validator(mode="after")
1713 def _validate_preprocessing_kwargs(self) -> Self:
1714 axes_ids = [a.id for a in self.axes]
1715 for p in self.preprocessing:
1716 kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes")
1717 if kwargs_axes is None:
1718 continue
1720 if not isinstance(kwargs_axes, collections.abc.Sequence):
1721 raise ValueError(
1722 f"Expected `preprocessing.i.kwargs.axes` to be a sequence, but got {type(kwargs_axes)}"
1723 )
1725 if any(a not in axes_ids for a in kwargs_axes):
1726 raise ValueError(
1727 "`preprocessing.i.kwargs.axes` needs to be subset of axes ids"
1728 )
1730 if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)):
1731 dtype = self.data.type
1732 else:
1733 dtype = self.data[0].type
1735 # ensure `preprocessing` begins with `EnsureDtypeDescr`
1736 if not self.preprocessing or not isinstance(
1737 self.preprocessing[0], EnsureDtypeDescr
1738 ):
1739 self.preprocessing.insert(
1740 0, EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype))
1741 )
1743 # ensure `preprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr`
1744 if not isinstance(self.preprocessing[-1], (EnsureDtypeDescr, BinarizeDescr)):
1745 self.preprocessing.append(
1746 EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype))
1747 )
1749 return self
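# Minimal construction sketch (illustrative; `test_tensor`/`sample_tensor` file
# descriptions omitted for brevity):
#   InputTensorDescr(
#       id=TensorId("raw"),
#       axes=[
#           BatchAxis(),
#           ChannelAxis(channel_names=[Identifier("c0")]),
#           SpaceInputAxis(id=AxisId("y"), size=256),
#           SpaceInputAxis(id=AxisId("x"), size=256),
#       ],
#       data=IntervalOrRatioDataDescr(type="float32"),
#   )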
1752def convert_axes(
1753 axes: str,
1754 *,
1755 shape: Union[
1756 Sequence[int], _ParameterizedInputShape_v0_4, _ImplicitOutputShape_v0_4
1757 ],
1758 tensor_type: Literal["input", "output"],
1759 halo: Optional[Sequence[int]],
1760 size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]],
1761):
1762 ret: List[AnyAxis] = []
1763 for i, a in enumerate(axes):
1764 axis_type = _AXIS_TYPE_MAP.get(a, a)
1765 if axis_type == "batch":
1766 ret.append(BatchAxis())
1767 continue
1769 scale = 1.0
1770 if isinstance(shape, _ParameterizedInputShape_v0_4):
1771 if shape.step[i] == 0:
1772 size = shape.min[i]
1773 else:
1774 size = ParameterizedSize(min=shape.min[i], step=shape.step[i])
1775 elif isinstance(shape, _ImplicitOutputShape_v0_4):
1776 ref_t = str(shape.reference_tensor)
1777 if ref_t.count(".") == 1:
1778 t_id, orig_a_id = ref_t.split(".")
1779 else:
1780 t_id = ref_t
1781 orig_a_id = a
1783 a_id = _AXIS_ID_MAP.get(orig_a_id, a)
1784 if not (orig_scale := shape.scale[i]):
1785 # old way to insert a new axis dimension
1786 size = int(2 * shape.offset[i])
1787 else:
1788 scale = 1 / orig_scale
1789 if axis_type in ("channel", "index"):
1790 # these axes no longer have a scale
1791 offset_from_scale = orig_scale * size_refs.get(
1792 _TensorName_v0_4(t_id), {}
1793 ).get(orig_a_id, 0)
1794 else:
1795 offset_from_scale = 0
1796 size = SizeReference(
1797 tensor_id=TensorId(t_id),
1798 axis_id=AxisId(a_id),
1799 offset=int(offset_from_scale + 2 * shape.offset[i]),
1800 )
1801 else:
1802 size = shape[i]
1804 if axis_type == "time":
1805 if tensor_type == "input":
1806 ret.append(TimeInputAxis(size=size, scale=scale))
1807 else:
1808 assert not isinstance(size, ParameterizedSize)
1809 if halo is None:
1810 ret.append(TimeOutputAxis(size=size, scale=scale))
1811 else:
1812 assert not isinstance(size, int)
1813 ret.append(
1814 TimeOutputAxisWithHalo(size=size, scale=scale, halo=halo[i])
1815 )
1817 elif axis_type == "index":
1818 if tensor_type == "input":
1819 ret.append(IndexInputAxis(size=size))
1820 else:
1821 if isinstance(size, ParameterizedSize):
1822 size = DataDependentSize(min=size.min)
1824 ret.append(IndexOutputAxis(size=size))
1825 elif axis_type == "channel":
1826 assert not isinstance(size, ParameterizedSize)
1827 if isinstance(size, SizeReference):
1828 warnings.warn(
1829 "Conversion of channel size from an implicit output shape may be"
1830 + " wrong"
1831 )
1832 ret.append(
1833 ChannelAxis(
1834 channel_names=[
1835 Identifier(f"channel{i}") for i in range(size.offset)
1836 ]
1837 )
1838 )
1839 else:
1840 ret.append(
1841 ChannelAxis(
1842 channel_names=[Identifier(f"channel{i}") for i in range(size)]
1843 )
1844 )
1845 elif axis_type == "space":
1846 if tensor_type == "input":
1847 ret.append(SpaceInputAxis(id=AxisId(a), size=size, scale=scale))
1848 else:
1849 assert not isinstance(size, ParameterizedSize)
1850 if halo is None or halo[i] == 0:
1851 ret.append(SpaceOutputAxis(id=AxisId(a), size=size, scale=scale))
1852 elif isinstance(size, int):
1853 raise NotImplementedError(
1854 f"output axis with halo and fixed size (here {size}) not allowed"
1855 )
1856 else:
1857 ret.append(
1858 SpaceOutputAxisWithHalo(
1859 id=AxisId(a), size=size, scale=scale, halo=halo[i]
1860 )
1861 )
1863 return ret
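# Illustrative call (hypothetical values): converting a v0.4 "bcyx" input tensor
# with a fixed shape yields BatchAxis, ChannelAxis and two SpaceInputAxis entries:
#   convert_axes("bcyx", shape=[1, 3, 256, 256], tensor_type="input", halo=None, size_refs={})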
1866def _axes_letters_to_ids(
1867 axes: Optional[str],
1868) -> Optional[List[AxisId]]:
1869 if axes is None:
1870 return None
1872 return [AxisId(a) for a in axes]
1875def _get_complement_v04_axis(
1876 tensor_axes: Sequence[str], axes: Optional[Sequence[str]]
1877) -> Optional[AxisId]:
1878 if axes is None:
1879 return None
1881 non_complement_axes = set(axes) | {"b"}
1882 complement_axes = [a for a in tensor_axes if a not in non_complement_axes]
1883 if len(complement_axes) > 1:
1884 raise ValueError(
1885 f"Expected none or a single complement axis, but axes '{axes}' "
1886 + f"for tensor dims '{tensor_axes}' leave '{complement_axes}'."
1887 )
1889 return None if not complement_axes else AxisId(complement_axes[0])
1892def _convert_proc(
1893 p: Union[_PreprocessingDescr_v0_4, _PostprocessingDescr_v0_4],
1894 tensor_axes: Sequence[str],
1895) -> Union[PreprocessingDescr, PostprocessingDescr]:
1896 if isinstance(p, _BinarizeDescr_v0_4):
1897 return BinarizeDescr(kwargs=BinarizeKwargs(threshold=p.kwargs.threshold))
1898 elif isinstance(p, _ClipDescr_v0_4):
1899 return ClipDescr(kwargs=ClipKwargs(min=p.kwargs.min, max=p.kwargs.max))
1900 elif isinstance(p, _SigmoidDescr_v0_4):
1901 return SigmoidDescr()
1902 elif isinstance(p, _ScaleLinearDescr_v0_4):
1903 axes = _axes_letters_to_ids(p.kwargs.axes)
1904 if p.kwargs.axes is None:
1905 axis = None
1906 else:
1907 axis = _get_complement_v04_axis(tensor_axes, p.kwargs.axes)
1909 if axis is None:
1910 assert not isinstance(p.kwargs.gain, list)
1911 assert not isinstance(p.kwargs.offset, list)
1912 kwargs = ScaleLinearKwargs(gain=p.kwargs.gain, offset=p.kwargs.offset)
1913 else:
1914 kwargs = ScaleLinearAlongAxisKwargs(
1915 axis=axis, gain=p.kwargs.gain, offset=p.kwargs.offset
1916 )
1917 return ScaleLinearDescr(kwargs=kwargs)
1918 elif isinstance(p, _ScaleMeanVarianceDescr_v0_4):
1919 return ScaleMeanVarianceDescr(
1920 kwargs=ScaleMeanVarianceKwargs(
1921 axes=_axes_letters_to_ids(p.kwargs.axes),
1922 reference_tensor=TensorId(str(p.kwargs.reference_tensor)),
1923 eps=p.kwargs.eps,
1924 )
1925 )
1926 elif isinstance(p, _ZeroMeanUnitVarianceDescr_v0_4):
1927 if p.kwargs.mode == "fixed":
1928 mean = p.kwargs.mean
1929 std = p.kwargs.std
1930 assert mean is not None
1931 assert std is not None
1933 axis = _get_complement_v04_axis(tensor_axes, p.kwargs.axes)
1935 if axis is None:
1936 if isinstance(mean, list):
1937 raise ValueError("Expected single float value for mean, not <list>")
1938 if isinstance(std, list):
1939 raise ValueError("Expected single float value for std, not <list>")
1940 return FixedZeroMeanUnitVarianceDescr(
1941 kwargs=FixedZeroMeanUnitVarianceKwargs.model_construct(
1942 mean=mean,
1943 std=std,
1944 )
1945 )
1946 else:
1947 if not isinstance(mean, list):
1948 mean = [float(mean)]
1949 if not isinstance(std, list):
1950 std = [float(std)]
1952 return FixedZeroMeanUnitVarianceDescr(
1953 kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs(
1954 axis=axis, mean=mean, std=std
1955 )
1956 )
1958 else:
1959 axes = _axes_letters_to_ids(p.kwargs.axes) or []
1960 if p.kwargs.mode == "per_dataset":
1961 axes = [AxisId("batch")] + axes
1962 if not axes:
1963 axes = None
1964 return ZeroMeanUnitVarianceDescr(
1965 kwargs=ZeroMeanUnitVarianceKwargs(axes=axes, eps=p.kwargs.eps)
1966 )
1968 elif isinstance(p, _ScaleRangeDescr_v0_4):
1969 return ScaleRangeDescr(
1970 kwargs=ScaleRangeKwargs(
1971 axes=_axes_letters_to_ids(p.kwargs.axes),
1972 min_percentile=p.kwargs.min_percentile,
1973 max_percentile=p.kwargs.max_percentile,
1974 eps=p.kwargs.eps,
1975 )
1976 )
1977 else:
1978 assert_never(p)
1981class _InputTensorConv(
1982 Converter[
1983 _InputTensorDescr_v0_4,
1984 InputTensorDescr,
1985 FileSource_,
1986 Optional[FileSource_],
1987 Mapping[_TensorName_v0_4, Mapping[str, int]],
1988 ]
1989):
1990 def _convert(
1991 self,
1992 src: _InputTensorDescr_v0_4,
1993 tgt: "type[InputTensorDescr] | type[dict[str, Any]]",
1994 test_tensor: FileSource_,
1995 sample_tensor: Optional[FileSource_],
1996 size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]],
1997 ) -> "InputTensorDescr | dict[str, Any]":
1998 axes: List[InputAxis] = convert_axes( # pyright: ignore[reportAssignmentType]
1999 src.axes,
2000 shape=src.shape,
2001 tensor_type="input",
2002 halo=None,
2003 size_refs=size_refs,
2004 )
2005 prep: List[PreprocessingDescr] = []
2006 for p in src.preprocessing:
2007 cp = _convert_proc(p, src.axes)
2008 assert not isinstance(cp, ScaleMeanVarianceDescr)
2009 prep.append(cp)
2011 prep.append(EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="float32")))
2013 return tgt(
2014 axes=axes,
2015 id=TensorId(str(src.name)),
2016 test_tensor=FileDescr(source=test_tensor),
2017 sample_tensor=(
2018 None if sample_tensor is None else FileDescr(source=sample_tensor)
2019 ),
2020 data=dict(type=src.data_type), # pyright: ignore[reportArgumentType]
2021 preprocessing=prep,
2022 )
2025_input_tensor_conv = _InputTensorConv(_InputTensorDescr_v0_4, InputTensorDescr)
2028class OutputTensorDescr(TensorDescrBase[OutputAxis]):
2029 id: TensorId = TensorId("output")
2030 """Output tensor id.
2031 No duplicates are allowed across all inputs and outputs."""
2033 postprocessing: List[PostprocessingDescr] = Field(
2034 default_factory=cast(Callable[[], List[PostprocessingDescr]], list)
2035 )
2036 """Description of how this output should be postprocessed.
2038 note: `postprocessing` always ends with an 'ensure_dtype' operation.
 2039 If not specified, an 'ensure_dtype' operation casting to this tensor's `data.type` is appended.
2040 """
2042 @model_validator(mode="after")
2043 def _validate_postprocessing_kwargs(self) -> Self:
2044 axes_ids = [a.id for a in self.axes]
2045 for p in self.postprocessing:
2046 kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes")
2047 if kwargs_axes is None:
2048 continue
2050 if not isinstance(kwargs_axes, collections.abc.Sequence):
2051 raise ValueError(
2052 f"expected `axes` sequence, but got {type(kwargs_axes)}"
2053 )
2055 if any(a not in axes_ids for a in kwargs_axes):
2056 raise ValueError("`kwargs.axes` needs to be subset of axes ids")
2058 if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)):
2059 dtype = self.data.type
2060 else:
2061 dtype = self.data[0].type
2063 # ensure `postprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr`
2064 if not self.postprocessing or not isinstance(
2065 self.postprocessing[-1], (EnsureDtypeDescr, BinarizeDescr)
2066 ):
2067 self.postprocessing.append(
2068 EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype))
2069 )
2070 return self
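# Illustrative note (not part of the original module): the validator above appends a
# final dtype cast whenever `postprocessing` does not already end in an 'ensure_dtype'
# or 'binarize' step, e.g. for an output with data type "uint8":
#   descr.postprocessing[-1]  # -> EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="uint8"))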
2073class _OutputTensorConv(
2074 Converter[
2075 _OutputTensorDescr_v0_4,
2076 OutputTensorDescr,
2077 FileSource_,
2078 Optional[FileSource_],
2079 Mapping[_TensorName_v0_4, Mapping[str, int]],
2080 ]
2081):
2082 def _convert(
2083 self,
2084 src: _OutputTensorDescr_v0_4,
2085 tgt: "type[OutputTensorDescr] | type[dict[str, Any]]",
2086 test_tensor: FileSource_,
2087 sample_tensor: Optional[FileSource_],
2088 size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]],
2089 ) -> "OutputTensorDescr | dict[str, Any]":
2090 # TODO: split convert_axes into convert_output_axes and convert_input_axes
2091 axes: List[OutputAxis] = convert_axes( # pyright: ignore[reportAssignmentType]
2092 src.axes,
2093 shape=src.shape,
2094 tensor_type="output",
2095 halo=src.halo,
2096 size_refs=size_refs,
2097 )
2098 data_descr: Dict[str, Any] = dict(type=src.data_type)
2099 if data_descr["type"] == "bool":
2100 data_descr["values"] = [False, True]
2102 return tgt(
2103 axes=axes,
2104 id=TensorId(str(src.name)),
2105 test_tensor=FileDescr(source=test_tensor),
2106 sample_tensor=(
2107 None if sample_tensor is None else FileDescr(source=sample_tensor)
2108 ),
2109 data=data_descr, # pyright: ignore[reportArgumentType]
2110 postprocessing=[_convert_proc(p, src.axes) for p in src.postprocessing],
2111 )
2114_output_tensor_conv = _OutputTensorConv(_OutputTensorDescr_v0_4, OutputTensorDescr)
2117TensorDescr = Union[InputTensorDescr, OutputTensorDescr]
2120def validate_tensors(
2121 tensors: Mapping[TensorId, Tuple[TensorDescr, Optional[NDArray[Any]]]],
2122 tensor_origin: Literal[
2123 "test_tensor"
2124 ], # for more precise error messages, e.g. 'test_tensor'
2125):
2126 all_tensor_axes: Dict[TensorId, Dict[AxisId, Tuple[AnyAxis, Optional[int]]]] = {}
2128 def e_msg(d: TensorDescr):
2129 return f"{'inputs' if isinstance(d, InputTensorDescr) else 'outputs'}[{d.id}]"
2131 for descr, array in tensors.values():
2132 if array is None:
2133 axis_sizes = {a.id: None for a in descr.axes}
2134 else:
2135 try:
2136 axis_sizes = descr.get_axis_sizes_for_array(array)
2137 except ValueError as e:
2138 raise ValueError(f"{e_msg(descr)} {e}")
2140 all_tensor_axes[descr.id] = {a.id: (a, axis_sizes[a.id]) for a in descr.axes}
2142 for descr, array in tensors.values():
2143 if array is None:
2144 continue
2146 if descr.dtype in ("float32", "float64"):
2147 invalid_test_tensor_dtype = array.dtype.name not in (
2148 "float32",
2149 "float64",
2150 "uint8",
2151 "int8",
2152 "uint16",
2153 "int16",
2154 "uint32",
2155 "int32",
2156 "uint64",
2157 "int64",
2158 )
2159 else:
2160 invalid_test_tensor_dtype = array.dtype.name != descr.dtype
2162 if invalid_test_tensor_dtype:
2163 raise ValueError(
2164 f"{e_msg(descr)}.{tensor_origin}.dtype '{array.dtype.name}' does not"
2165 + f" match described dtype '{descr.dtype}'"
2166 )
2168 if array.min() > -1e-4 and array.max() < 1e-4:
2169 raise ValueError(
2170 "Output values are too small for reliable testing."
2171 + f" Values <-1e5 or >=1e5 must be present in {tensor_origin}"
2172 )
2174 for a in descr.axes:
2175 actual_size = all_tensor_axes[descr.id][a.id][1]
2176 if actual_size is None:
2177 continue
2179 if a.size is None:
2180 continue
2182 if isinstance(a.size, int):
2183 if actual_size != a.size:
2184 raise ValueError(
2185 f"{e_msg(descr)}.{tensor_origin}: axis '{a.id}' "
2186 + f"has incompatible size {actual_size}, expected {a.size}"
2187 )
2188 elif isinstance(a.size, ParameterizedSize):
2189 _ = a.size.validate_size(actual_size)
2190 elif isinstance(a.size, DataDependentSize):
2191 _ = a.size.validate_size(actual_size)
2192 elif isinstance(a.size, SizeReference):
2193 ref_tensor_axes = all_tensor_axes.get(a.size.tensor_id)
2194 if ref_tensor_axes is None:
2195 raise ValueError(
2196 f"{e_msg(descr)}.axes[{a.id}].size.tensor_id: Unknown tensor"
2197 + f" reference '{a.size.tensor_id}'"
2198 )
2200 ref_axis, ref_size = ref_tensor_axes.get(a.size.axis_id, (None, None))
2201 if ref_axis is None or ref_size is None:
2202 raise ValueError(
2203 f"{e_msg(descr)}.axes[{a.id}].size.axis_id: Unknown tensor axis"
2204 + f" reference '{a.size.tensor_id}.{a.size.axis_id}"
2205 )
2207 if a.unit != ref_axis.unit:
2208 raise ValueError(
2209 f"{e_msg(descr)}.axes[{a.id}].size: `SizeReference` requires"
2210 + " axis and reference axis to have the same `unit`, but"
2211 + f" {a.unit}!={ref_axis.unit}"
2212 )
2214 if actual_size != (
2215 expected_size := (
2216 ref_size * ref_axis.scale / a.scale + a.size.offset
2217 )
2218 ):
2219 raise ValueError(
2220 f"{e_msg(descr)}.{tensor_origin}: axis '{a.id}' of size"
2221 + f" {actual_size} invalid for referenced size {ref_size};"
2222 + f" expected {expected_size}"
2223 )
2224 else:
2225 assert_never(a.size)
2228FileDescr_dependencies = Annotated[
2229 FileDescr_,
2230 WithSuffix((".yaml", ".yml"), case_sensitive=True),
2231 Field(examples=[dict(source="environment.yaml")]),
2232]
2235class _ArchitectureCallableDescr(Node):
2236 callable: Annotated[Identifier, Field(examples=["MyNetworkClass", "get_my_model"])]
2237 """Identifier of the callable that returns a torch.nn.Module instance."""
2239 kwargs: Dict[str, YamlValue] = Field(
2240 default_factory=cast(Callable[[], Dict[str, YamlValue]], dict)
2241 )
2242 """key word arguments for the `callable`"""
2245class ArchitectureFromFileDescr(_ArchitectureCallableDescr, FileDescr):
2246 source: Annotated[FileSource, AfterValidator(wo_special_file_name)]
2247 """Architecture source file"""
2249 @model_serializer(mode="wrap", when_used="unless-none")
2250 def _serialize(self, nxt: SerializerFunctionWrapHandler, info: SerializationInfo):
2251 return package_file_descr_serializer(self, nxt, info)
2254class ArchitectureFromLibraryDescr(_ArchitectureCallableDescr):
2255 import_from: str
2256 """Where to import the callable from, i.e. `from <import_from> import <callable>`"""
2259class _ArchFileConv(
2260 Converter[
2261 _CallableFromFile_v0_4,
2262 ArchitectureFromFileDescr,
2263 Optional[Sha256],
2264 Dict[str, Any],
2265 ]
2266):
2267 def _convert(
2268 self,
2269 src: _CallableFromFile_v0_4,
2270 tgt: "type[ArchitectureFromFileDescr | dict[str, Any]]",
2271 sha256: Optional[Sha256],
2272 kwargs: Dict[str, Any],
2273 ) -> "ArchitectureFromFileDescr | dict[str, Any]":
2274 if src.startswith("http") and src.count(":") == 2:
2275 http, source, callable_ = src.split(":")
2276 source = ":".join((http, source))
2277 elif not src.startswith("http") and src.count(":") == 1:
2278 source, callable_ = src.split(":")
2279 else:
2280 source = str(src)
2281 callable_ = str(src)
2282 return tgt(
2283 callable=Identifier(callable_),
2284 source=cast(FileSource_, source),
2285 sha256=sha256,
2286 kwargs=kwargs,
2287 )
2290_arch_file_conv = _ArchFileConv(_CallableFromFile_v0_4, ArchitectureFromFileDescr)
2293class _ArchLibConv(
2294 Converter[
2295 _CallableFromDepencency_v0_4, ArchitectureFromLibraryDescr, Dict[str, Any]
2296 ]
2297):
2298 def _convert(
2299 self,
2300 src: _CallableFromDepencency_v0_4,
2301 tgt: "type[ArchitectureFromLibraryDescr | dict[str, Any]]",
2302 kwargs: Dict[str, Any],
2303 ) -> "ArchitectureFromLibraryDescr | dict[str, Any]":
2304 *mods, callable_ = src.split(".")
2305 import_from = ".".join(mods)
2306 return tgt(
2307 import_from=import_from, callable=Identifier(callable_), kwargs=kwargs
2308 )
2311_arch_lib_conv = _ArchLibConv(
2312 _CallableFromDepencency_v0_4, ArchitectureFromLibraryDescr
2313)
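# Illustrative sketch (not part of the original module) of the two architecture
# converters above, using hypothetical v0.4 callable strings:
#   "https://example.com/unet.py:UNet" (callable from file)
#     -> ArchitectureFromFileDescr(source="https://example.com/unet.py", callable="UNet", ...)
#   "mypackage.models.UNet" (callable from dependency)
#     -> ArchitectureFromLibraryDescr(import_from="mypackage.models", callable="UNet", ...)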
2316class WeightsEntryDescrBase(FileDescr):
2317 type: ClassVar[WeightsFormat]
2318 weights_format_name: ClassVar[str] # human readable
2320 source: Annotated[FileSource, AfterValidator(wo_special_file_name)]
2321 """Source of the weights file."""
2323 authors: Optional[List[Author]] = None
2324 """Authors
2325 Either the person(s) that have trained this model resulting in the original weights file.
2326 (If this is the initial weights entry, i.e. it does not have a `parent`)
2327 Or the person(s) who have converted the weights to this weights format.
2328 (If this is a child weight, i.e. it has a `parent` field)
2329 """
2331 parent: Annotated[
2332 Optional[WeightsFormat], Field(examples=["pytorch_state_dict"])
2333 ] = None
2334 """The source weights these weights were converted from.
2335 For example, if a model's weights were converted from the `pytorch_state_dict` format to `torchscript`,
 2336 the `pytorch_state_dict` weights entry has no `parent` and is the parent of the `torchscript` weights.
 2337 All weight entries except one (the initial set of weights resulting from training the model)
2338 need to have this field."""
2340 comment: str = ""
2341 """A comment about this weights entry, for example how these weights were created."""
2343 @model_validator(mode="after")
2344 def _validate(self) -> Self:
2345 if self.type == self.parent:
2346 raise ValueError("Weights entry can't be it's own parent.")
2348 return self
2350 @model_serializer(mode="wrap", when_used="unless-none")
2351 def _serialize(self, nxt: SerializerFunctionWrapHandler, info: SerializationInfo):
2352 return package_file_descr_serializer(self, nxt, info)
2355class KerasHdf5WeightsDescr(WeightsEntryDescrBase):
2356 type = "keras_hdf5"
2357 weights_format_name: ClassVar[str] = "Keras HDF5"
2358 tensorflow_version: Version
2359 """TensorFlow version used to create these weights."""
2362class OnnxWeightsDescr(WeightsEntryDescrBase):
2363 type = "onnx"
2364 weights_format_name: ClassVar[str] = "ONNX"
2365 opset_version: Annotated[int, Ge(7)]
2366 """ONNX opset version"""
2369class PytorchStateDictWeightsDescr(WeightsEntryDescrBase):
2370 type = "pytorch_state_dict"
2371 weights_format_name: ClassVar[str] = "Pytorch State Dict"
2372 architecture: Union[ArchitectureFromFileDescr, ArchitectureFromLibraryDescr]
2373 pytorch_version: Version
2374 """Version of the PyTorch library used.
 2375 If `dependencies` is specified, it has to include pytorch and any version pinning has to be compatible.
2376 """
2377 dependencies: Optional[FileDescr_dependencies] = None
2378 """Custom depencies beyond pytorch described in a Conda environment file.
2379 Allows to specify custom dependencies, see conda docs:
2380 - [Exporting an environment file across platforms](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#exporting-an-environment-file-across-platforms)
2381 - [Creating an environment file manually](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#creating-an-environment-file-manually)
2383 The conda environment file should include pytorch and any version pinning has to be compatible with
2384 **pytorch_version**.
2385 """
2388class TensorflowJsWeightsDescr(WeightsEntryDescrBase):
2389 type = "tensorflow_js"
2390 weights_format_name: ClassVar[str] = "Tensorflow.js"
2391 tensorflow_version: Version
2392 """Version of the TensorFlow library used."""
2394 source: Annotated[FileSource, AfterValidator(wo_special_file_name)]
2395 """The multi-file weights.
 2396 All required files/folders should be packaged in a zip archive."""
2399class TensorflowSavedModelBundleWeightsDescr(WeightsEntryDescrBase):
2400 type = "tensorflow_saved_model_bundle"
2401 weights_format_name: ClassVar[str] = "Tensorflow Saved Model"
2402 tensorflow_version: Version
2403 """Version of the TensorFlow library used."""
2405 dependencies: Optional[FileDescr_dependencies] = None
2406 """Custom dependencies beyond tensorflow.
2407 Should include tensorflow and any version pinning has to be compatible with **tensorflow_version**."""
2409 source: Annotated[FileSource, AfterValidator(wo_special_file_name)]
2410 """The multi-file weights.
 2411 All required files/folders should be packaged in a zip archive."""
2414class TorchscriptWeightsDescr(WeightsEntryDescrBase):
2415 type = "torchscript"
2416 weights_format_name: ClassVar[str] = "TorchScript"
2417 pytorch_version: Version
2418 """Version of the PyTorch library used."""
2421class WeightsDescr(Node):
2422 keras_hdf5: Optional[KerasHdf5WeightsDescr] = None
2423 onnx: Optional[OnnxWeightsDescr] = None
2424 pytorch_state_dict: Optional[PytorchStateDictWeightsDescr] = None
2425 tensorflow_js: Optional[TensorflowJsWeightsDescr] = None
2426 tensorflow_saved_model_bundle: Optional[TensorflowSavedModelBundleWeightsDescr] = (
2427 None
2428 )
2429 torchscript: Optional[TorchscriptWeightsDescr] = None
2431 @model_validator(mode="after")
2432 def check_entries(self) -> Self:
2433 entries = {wtype for wtype, entry in self if entry is not None}
2435 if not entries:
2436 raise ValueError("Missing weights entry")
2438 entries_wo_parent = {
2439 wtype
2440 for wtype, entry in self
2441 if entry is not None and hasattr(entry, "parent") and entry.parent is None
2442 }
2443 if len(entries_wo_parent) != 1:
2444 issue_warning(
2445 "Exactly one weights entry may not specify the `parent` field (got"
2446 + " {value}). That entry is considered the original set of model weights."
2447 + " Other weight formats are created through conversion of the orignal or"
2448 + " already converted weights. They have to reference the weights format"
2449 + " they were converted from as their `parent`.",
2450 value=len(entries_wo_parent),
2451 field="weights",
2452 )
2454 for wtype, entry in self:
2455 if entry is None:
2456 continue
2458 assert hasattr(entry, "type")
2459 assert hasattr(entry, "parent")
2460 assert wtype == entry.type
2461 if (
2462 entry.parent is not None and entry.parent not in entries
2463 ): # self reference checked for `parent` field
2464 raise ValueError(
2465 f"`weights.{wtype}.parent={entry.parent} not in specified weight"
2466 + f" formats: {entries}"
2467 )
2469 return self
2471 def __getitem__(
2472 self,
2473 key: Literal[
2474 "keras_hdf5",
2475 "onnx",
2476 "pytorch_state_dict",
2477 "tensorflow_js",
2478 "tensorflow_saved_model_bundle",
2479 "torchscript",
2480 ],
2481 ):
2482 if key == "keras_hdf5":
2483 ret = self.keras_hdf5
2484 elif key == "onnx":
2485 ret = self.onnx
2486 elif key == "pytorch_state_dict":
2487 ret = self.pytorch_state_dict
2488 elif key == "tensorflow_js":
2489 ret = self.tensorflow_js
2490 elif key == "tensorflow_saved_model_bundle":
2491 ret = self.tensorflow_saved_model_bundle
2492 elif key == "torchscript":
2493 ret = self.torchscript
2494 else:
2495 raise KeyError(key)
2497 if ret is None:
2498 raise KeyError(key)
2500 return ret
2502 @property
2503 def available_formats(self):
2504 return {
2505 **({} if self.keras_hdf5 is None else {"keras_hdf5": self.keras_hdf5}),
2506 **({} if self.onnx is None else {"onnx": self.onnx}),
2507 **(
2508 {}
2509 if self.pytorch_state_dict is None
2510 else {"pytorch_state_dict": self.pytorch_state_dict}
2511 ),
2512 **(
2513 {}
2514 if self.tensorflow_js is None
2515 else {"tensorflow_js": self.tensorflow_js}
2516 ),
2517 **(
2518 {}
2519 if self.tensorflow_saved_model_bundle is None
2520 else {
2521 "tensorflow_saved_model_bundle": self.tensorflow_saved_model_bundle
2522 }
2523 ),
2524 **({} if self.torchscript is None else {"torchscript": self.torchscript}),
2525 }
2527 @property
2528 def missing_formats(self):
2529 return {
2530 wf for wf in get_args(WeightsFormat) if wf not in self.available_formats
2531 }
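# Illustrative usage sketch (not part of the original module), assuming `model` is a
# validated ModelDescr:
#   model.weights.available_formats      # e.g. {"pytorch_state_dict": ..., "torchscript": ...}
#   model.weights.missing_formats        # complement within get_args(WeightsFormat)
#   model.weights["torchscript"].source  # raises KeyError if no torchscript entry exists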
2534class ModelId(ResourceId):
2535 pass
2538class LinkedModel(LinkedResourceBase):
2539 """Reference to a bioimage.io model."""
2541 id: ModelId
2542 """A valid model `id` from the bioimage.io collection."""
2545class _DataDepSize(NamedTuple):
2546 min: StrictInt
2547 max: Optional[StrictInt]
2550class _AxisSizes(NamedTuple):
2551 """the lenghts of all axes of model inputs and outputs"""
2553 inputs: Dict[Tuple[TensorId, AxisId], int]
2554 outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]]
2557class _TensorSizes(NamedTuple):
2558 """_AxisSizes as nested dicts"""
2560 inputs: Dict[TensorId, Dict[AxisId, int]]
2561 outputs: Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]
2564class ReproducibilityTolerance(Node, extra="allow"):
2565 """Describes what small numerical differences -- if any -- may be tolerated
2566 in the generated output when executing in different environments.
2568 A tensor element *output* is considered mismatched to the **test_tensor** if
2569 abs(*output* - **test_tensor**) > **absolute_tolerance** + **relative_tolerance** * abs(**test_tensor**).
2570 (Internally we call [numpy.testing.assert_allclose](https://numpy.org/doc/stable/reference/generated/numpy.testing.assert_allclose.html).)
2572 Motivation:
2573 For testing we can request the respective deep learning frameworks to be as
 2574 reproducible as possible by setting seeds and choosing deterministic algorithms,
2575 but differences in operating systems, available hardware and installed drivers
2576 may still lead to numerical differences.
2577 """
2579 relative_tolerance: RelativeTolerance = 1e-3
2580 """Maximum relative tolerance of reproduced test tensor."""
2582 absolute_tolerance: AbsoluteTolerance = 1e-4
2583 """Maximum absolute tolerance of reproduced test tensor."""
2585 mismatched_elements_per_million: MismatchedElementsPerMillion = 100
2586 """Maximum number of mismatched elements/pixels per million to tolerate."""
2588 output_ids: Sequence[TensorId] = ()
2589 """Limits the output tensor IDs these reproducibility details apply to."""
2591 weights_formats: Sequence[WeightsFormat] = ()
2592 """Limits the weights formats these details apply to."""
2595class BioimageioConfig(Node, extra="allow"):
2596 reproducibility_tolerance: Sequence[ReproducibilityTolerance] = ()
2597 """Tolerances to allow when reproducing the model's test outputs
2598 from the model's test inputs.
2599 Only the first entry matching tensor id and weights format is considered.
2600 """
2603class Config(Node, extra="allow"):
2604 bioimageio: BioimageioConfig = Field(
2605 default_factory=BioimageioConfig.model_construct
2606 )
2609class ModelDescr(GenericModelDescrBase):
2610 """Specification of the fields used in a bioimage.io-compliant RDF to describe AI models with pretrained weights.
2611 These fields are typically stored in a YAML file which we call a model resource description file (model RDF).
2612 """
2614 implemented_format_version: ClassVar[Literal["0.5.5"]] = "0.5.5"
2615 if TYPE_CHECKING:
2616 format_version: Literal["0.5.5"] = "0.5.5"
2617 else:
2618 format_version: Literal["0.5.5"]
2619 """Version of the bioimage.io model description specification used.
2620 When creating a new model always use the latest micro/patch version described here.
2621 The `format_version` is important for any consumer software to understand how to parse the fields.
2622 """
2624 implemented_type: ClassVar[Literal["model"]] = "model"
2625 if TYPE_CHECKING:
2626 type: Literal["model"] = "model"
2627 else:
2628 type: Literal["model"]
2629 """Specialized resource type 'model'"""
2631 id: Optional[ModelId] = None
2632 """bioimage.io-wide unique resource identifier
2633 assigned by bioimage.io; version **un**specific."""
2635 authors: FAIR[List[Author]] = Field(
2636 default_factory=cast(Callable[[], List[Author]], list)
2637 )
2638 """The authors are the creators of the model RDF and the primary points of contact."""
2640 documentation: FAIR[Optional[FileSource_documentation]] = None
2641 """URL or relative path to a markdown file with additional documentation.
2642 The recommended documentation file name is `README.md`. An `.md` suffix is mandatory.
2643 The documentation should include a '#[#] Validation' (sub)section
2644 with details on how to quantitatively validate the model on unseen data."""
2646 @field_validator("documentation", mode="after")
2647 @classmethod
2648 def _validate_documentation(
2649 cls, value: Optional[FileSource_documentation]
2650 ) -> Optional[FileSource_documentation]:
2651 if not get_validation_context().perform_io_checks or value is None:
2652 return value
2654 doc_reader = get_reader(value)
2655 doc_content = doc_reader.read().decode(encoding="utf-8")
2656 if not re.search("#.*[vV]alidation", doc_content):
2657 issue_warning(
2658 "No '# Validation' (sub)section found in {value}.",
2659 value=value,
2660 field="documentation",
2661 )
2663 return value
2665 inputs: NotEmpty[Sequence[InputTensorDescr]]
2666 """Describes the input tensors expected by this model."""
2668 @field_validator("inputs", mode="after")
2669 @classmethod
2670 def _validate_input_axes(
2671 cls, inputs: Sequence[InputTensorDescr]
2672 ) -> Sequence[InputTensorDescr]:
2673 input_size_refs = cls._get_axes_with_independent_size(inputs)
2675 for i, ipt in enumerate(inputs):
2676 valid_independent_refs: Dict[
2677 Tuple[TensorId, AxisId],
2678 Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
2679 ] = {
2680 **{
2681 (ipt.id, a.id): (ipt, a, a.size)
2682 for a in ipt.axes
2683 if not isinstance(a, BatchAxis)
2684 and isinstance(a.size, (int, ParameterizedSize))
2685 },
2686 **input_size_refs,
2687 }
2688 for a, ax in enumerate(ipt.axes):
2689 cls._validate_axis(
2690 "inputs",
2691 i=i,
2692 tensor_id=ipt.id,
2693 a=a,
2694 axis=ax,
2695 valid_independent_refs=valid_independent_refs,
2696 )
2697 return inputs
2699 @staticmethod
2700 def _validate_axis(
2701 field_name: str,
2702 i: int,
2703 tensor_id: TensorId,
2704 a: int,
2705 axis: AnyAxis,
2706 valid_independent_refs: Dict[
2707 Tuple[TensorId, AxisId],
2708 Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
2709 ],
2710 ):
2711 if isinstance(axis, BatchAxis) or isinstance(
2712 axis.size, (int, ParameterizedSize, DataDependentSize)
2713 ):
2714 return
2715 elif not isinstance(axis.size, SizeReference):
2716 assert_never(axis.size)
2718 # validate axis.size SizeReference
2719 ref = (axis.size.tensor_id, axis.size.axis_id)
2720 if ref not in valid_independent_refs:
2721 raise ValueError(
2722 "Invalid tensor axis reference at"
2723 + f" {field_name}[{i}].axes[{a}].size: {axis.size}."
2724 )
2725 if ref == (tensor_id, axis.id):
2726 raise ValueError(
2727 "Self-referencing not allowed for"
2728 + f" {field_name}[{i}].axes[{a}].size: {axis.size}"
2729 )
2730 if axis.type == "channel":
2731 if valid_independent_refs[ref][1].type != "channel":
2732 raise ValueError(
2733 "A channel axis' size may only reference another fixed size"
2734 + " channel axis."
2735 )
2736 if isinstance(axis.channel_names, str) and "{i}" in axis.channel_names:
2737 ref_size = valid_independent_refs[ref][2]
2738 assert isinstance(ref_size, int), (
2739 "channel axis ref (another channel axis) has to specify fixed"
2740 + " size"
2741 )
2742 generated_channel_names = [
2743 Identifier(axis.channel_names.format(i=i))
2744 for i in range(1, ref_size + 1)
2745 ]
2746 axis.channel_names = generated_channel_names
2748 if (ax_unit := getattr(axis, "unit", None)) != (
2749 ref_unit := getattr(valid_independent_refs[ref][1], "unit", None)
2750 ):
2751 raise ValueError(
2752 "The units of an axis and its reference axis need to match, but"
2753 + f" '{ax_unit}' != '{ref_unit}'."
2754 )
2755 ref_axis = valid_independent_refs[ref][1]
2756 if isinstance(ref_axis, BatchAxis):
2757 raise ValueError(
2758 f"Invalid reference axis '{ref_axis.id}' for {tensor_id}.{axis.id}"
2759 + " (a batch axis is not allowed as reference)."
2760 )
2762 if isinstance(axis, WithHalo):
2763 min_size = axis.size.get_size(axis, ref_axis, n=0)
2764 if (min_size - 2 * axis.halo) < 1:
2765 raise ValueError(
2766 f"axis {axis.id} with minimum size {min_size} is too small for halo"
2767 + f" {axis.halo}."
2768 )
2770 input_halo = axis.halo * axis.scale / ref_axis.scale
2771 if input_halo != int(input_halo) or input_halo % 2 == 1:
2772 raise ValueError(
2773 f"input_halo {input_halo} (output_halo {axis.halo} *"
2774 + f" output_scale {axis.scale} / input_scale {ref_axis.scale})"
2775 + f" {tensor_id}.{axis.id}."
2776 )
2778 @model_validator(mode="after")
2779 def _validate_test_tensors(self) -> Self:
2780 if not get_validation_context().perform_io_checks:
2781 return self
2783 test_output_arrays = [
2784 None if descr.test_tensor is None else load_array(descr.test_tensor)
2785 for descr in self.outputs
2786 ]
2787 test_input_arrays = [
2788 None if descr.test_tensor is None else load_array(descr.test_tensor)
2789 for descr in self.inputs
2790 ]
2792 tensors = {
2793 descr.id: (descr, array)
2794 for descr, array in zip(
2795 chain(self.inputs, self.outputs), test_input_arrays + test_output_arrays
2796 )
2797 }
2798 validate_tensors(tensors, tensor_origin="test_tensor")
2800 output_arrays = {
2801 descr.id: array for descr, array in zip(self.outputs, test_output_arrays)
2802 }
2803 for rep_tol in self.config.bioimageio.reproducibility_tolerance:
2804 if not rep_tol.absolute_tolerance:
2805 continue
2807 if rep_tol.output_ids:
2808 out_arrays = {
2809 oid: a
2810 for oid, a in output_arrays.items()
2811 if oid in rep_tol.output_ids
2812 }
2813 else:
2814 out_arrays = output_arrays
2816 for out_id, array in out_arrays.items():
2817 if array is None:
2818 continue
2820 if rep_tol.absolute_tolerance > (max_test_value := array.max()) * 0.01:
2821 raise ValueError(
2822 "config.bioimageio.reproducibility_tolerance.absolute_tolerance="
2823 + f"{rep_tol.absolute_tolerance} > 0.01*{max_test_value}"
2824 + f" (1% of the maximum value of the test tensor '{out_id}')"
2825 )
2827 return self
2829 @model_validator(mode="after")
2830 def _validate_tensor_references_in_proc_kwargs(self, info: ValidationInfo) -> Self:
2831 ipt_refs = {t.id for t in self.inputs}
2832 out_refs = {t.id for t in self.outputs}
2833 for ipt in self.inputs:
2834 for p in ipt.preprocessing:
2835 ref = p.kwargs.get("reference_tensor")
2836 if ref is None:
2837 continue
2838 if ref not in ipt_refs:
2839 raise ValueError(
2840 f"`reference_tensor` '{ref}' not found. Valid input tensor"
2841 + f" references are: {ipt_refs}."
2842 )
2844 for out in self.outputs:
2845 for p in out.postprocessing:
2846 ref = p.kwargs.get("reference_tensor")
2847 if ref is None:
2848 continue
2850 if ref not in ipt_refs and ref not in out_refs:
2851 raise ValueError(
2852 f"`reference_tensor` '{ref}' not found. Valid tensor references"
2853 + f" are: {ipt_refs | out_refs}."
2854 )
2856 return self
2858 # TODO: use validate funcs in validate_test_tensors
2859 # def validate_inputs(self, input_tensors: Mapping[TensorId, NDArray[Any]]) -> Mapping[TensorId, NDArray[Any]]:
2861 name: Annotated[
2862 str,
2863 RestrictCharacters(string.ascii_letters + string.digits + "_+- ()"),
2864 MinLen(5),
2865 MaxLen(128),
2866 warn(MaxLen(64), "Name longer than 64 characters.", INFO),
2867 ]
2868 """A human-readable name of this model.
2869 It should be no longer than 64 characters
 2870 and may only contain letters, numbers, underscores, minus signs, parentheses and spaces.
 2871 We recommend choosing a name that refers to the model's task and image modality.
2872 """
2874 outputs: NotEmpty[Sequence[OutputTensorDescr]]
2875 """Describes the output tensors."""
2877 @field_validator("outputs", mode="after")
2878 @classmethod
2879 def _validate_tensor_ids(
2880 cls, outputs: Sequence[OutputTensorDescr], info: ValidationInfo
2881 ) -> Sequence[OutputTensorDescr]:
2882 tensor_ids = [
2883 t.id for t in info.data.get("inputs", []) + info.data.get("outputs", [])
2884 ]
2885 duplicate_tensor_ids: List[str] = []
2886 seen: Set[str] = set()
2887 for t in tensor_ids:
2888 if t in seen:
2889 duplicate_tensor_ids.append(t)
2891 seen.add(t)
2893 if duplicate_tensor_ids:
2894 raise ValueError(f"Duplicate tensor ids: {duplicate_tensor_ids}")
2896 return outputs
2898 @staticmethod
2899 def _get_axes_with_parameterized_size(
2900 io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]],
2901 ):
2902 return {
2903 f"{t.id}.{a.id}": (t, a, a.size)
2904 for t in io
2905 for a in t.axes
2906 if not isinstance(a, BatchAxis) and isinstance(a.size, ParameterizedSize)
2907 }
2909 @staticmethod
2910 def _get_axes_with_independent_size(
2911 io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]],
2912 ):
2913 return {
2914 (t.id, a.id): (t, a, a.size)
2915 for t in io
2916 for a in t.axes
2917 if not isinstance(a, BatchAxis)
2918 and isinstance(a.size, (int, ParameterizedSize))
2919 }
2921 @field_validator("outputs", mode="after")
2922 @classmethod
2923 def _validate_output_axes(
2924 cls, outputs: List[OutputTensorDescr], info: ValidationInfo
2925 ) -> List[OutputTensorDescr]:
2926 input_size_refs = cls._get_axes_with_independent_size(
2927 info.data.get("inputs", [])
2928 )
2929 output_size_refs = cls._get_axes_with_independent_size(outputs)
2931 for i, out in enumerate(outputs):
2932 valid_independent_refs: Dict[
2933 Tuple[TensorId, AxisId],
2934 Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
2935 ] = {
2936 **{
2937 (out.id, a.id): (out, a, a.size)
2938 for a in out.axes
2939 if not isinstance(a, BatchAxis)
2940 and isinstance(a.size, (int, ParameterizedSize))
2941 },
2942 **input_size_refs,
2943 **output_size_refs,
2944 }
2945 for a, ax in enumerate(out.axes):
2946 cls._validate_axis(
2947 "outputs",
2948 i,
2949 out.id,
2950 a,
2951 ax,
2952 valid_independent_refs=valid_independent_refs,
2953 )
2955 return outputs
2957 packaged_by: List[Author] = Field(
2958 default_factory=cast(Callable[[], List[Author]], list)
2959 )
2960 """The persons that have packaged and uploaded this model.
2961 Only required if those persons differ from the `authors`."""
2963 parent: Optional[LinkedModel] = None
2964 """The model from which this model is derived, e.g. by fine-tuning the weights."""
2966 @model_validator(mode="after")
2967 def _validate_parent_is_not_self(self) -> Self:
2968 if self.parent is not None and self.parent.id == self.id:
2969 raise ValueError("A model description may not reference itself as parent.")
2971 return self
2973 run_mode: Annotated[
2974 Optional[RunMode],
 2975 warn(None, "Run mode '{value}' has limited support across consumer software."),
2976 ] = None
2977 """Custom run mode for this model: for more complex prediction procedures like test time
2978 data augmentation that currently cannot be expressed in the specification.
2979 No standard run modes are defined yet."""
2981 timestamp: Datetime = Field(default_factory=Datetime.now)
2982 """Timestamp in [ISO 8601](#https://en.wikipedia.org/wiki/ISO_8601) format
2983 with a few restrictions listed [here](https://docs.python.org/3/library/datetime.html#datetime.datetime.fromisoformat).
2984 (In Python a datetime object is valid, too)."""
2986 training_data: Annotated[
2987 Union[None, LinkedDataset, DatasetDescr, DatasetDescr02],
2988 Field(union_mode="left_to_right"),
2989 ] = None
2990 """The dataset used to train this model"""
2992 weights: Annotated[WeightsDescr, WrapSerializer(package_weights)]
2993 """The weights for this model.
2994 Weights can be given for different formats, but should otherwise be equivalent.
2995 The available weight formats determine which consumers can use this model."""
2997 config: Config = Field(default_factory=Config.model_construct)
2999 @model_validator(mode="after")
3000 def _add_default_cover(self) -> Self:
3001 if not get_validation_context().perform_io_checks or self.covers:
3002 return self
3004 try:
3005 generated_covers = generate_covers(
3006 [
3007 (t, load_array(t.test_tensor))
3008 for t in self.inputs
3009 if t.test_tensor is not None
3010 ],
3011 [
3012 (t, load_array(t.test_tensor))
3013 for t in self.outputs
3014 if t.test_tensor is not None
3015 ],
3016 )
3017 except Exception as e:
3018 issue_warning(
3019 "Failed to generate cover image(s): {e}",
3020 value=self.covers,
3021 msg_context=dict(e=e),
3022 field="covers",
3023 )
3024 else:
3025 self.covers.extend(generated_covers)
3027 return self
3029 def get_input_test_arrays(self) -> List[NDArray[Any]]:
3030 return self._get_test_arrays(self.inputs)
3032 def get_output_test_arrays(self) -> List[NDArray[Any]]:
3033 return self._get_test_arrays(self.outputs)
3035 @staticmethod
3036 def _get_test_arrays(
3037 io_descr: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]],
3038 ):
3039 ts: List[FileDescr] = []
3040 for d in io_descr:
3041 if d.test_tensor is None:
3042 raise ValueError(
3043 f"Failed to get test arrays: description of '{d.id}' is missing a `test_tensor`."
3044 )
3045 ts.append(d.test_tensor)
3047 data = [load_array(t) for t in ts]
3048 assert all(isinstance(d, np.ndarray) for d in data)
3049 return data
3051 @staticmethod
3052 def get_batch_size(tensor_sizes: Mapping[TensorId, Mapping[AxisId, int]]) -> int:
3053 batch_size = 1
3054 tensor_with_batchsize: Optional[TensorId] = None
3055 for tid in tensor_sizes:
3056 for aid, s in tensor_sizes[tid].items():
3057 if aid != BATCH_AXIS_ID or s == 1 or s == batch_size:
3058 continue
3060 if batch_size != 1:
3061 assert tensor_with_batchsize is not None
3062 raise ValueError(
3063 f"batch size mismatch for tensors '{tensor_with_batchsize}' ({batch_size}) and '{tid}' ({s})"
3064 )
3066 batch_size = s
3067 tensor_with_batchsize = tid
3069 return batch_size
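# Illustrative example (not part of the original module), assuming BATCH_AXIS_ID is
# AxisId("batch"):
#   ModelDescr.get_batch_size({TensorId("raw"): {AxisId("batch"): 4, AxisId("x"): 512}})
#   # -> 4; raises ValueError if two tensors request different batch sizes != 1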
3071 def get_output_tensor_sizes(
3072 self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]
3073 ) -> Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]:
3074 """Returns the tensor output sizes for given **input_sizes**.
 3075 The output sizes are exact only if **input_sizes** describes a valid input shape;
 3076 otherwise they may be larger than the actual (valid) output sizes."""
3077 batch_size = self.get_batch_size(input_sizes)
3078 ns = self.get_ns(input_sizes)
3080 tensor_sizes = self.get_tensor_sizes(ns, batch_size=batch_size)
3081 return tensor_sizes.outputs
3083 def get_ns(self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]):
3084 """get parameter `n` for each parameterized axis
3085 such that the valid input size is >= the given input size"""
3086 ret: Dict[Tuple[TensorId, AxisId], ParameterizedSize_N] = {}
3087 axes = {t.id: {a.id: a for a in t.axes} for t in self.inputs}
3088 for tid in input_sizes:
3089 for aid, s in input_sizes[tid].items():
3090 size_descr = axes[tid][aid].size
3091 if isinstance(size_descr, ParameterizedSize):
3092 ret[(tid, aid)] = size_descr.get_n(s)
3093 elif size_descr is None or isinstance(size_descr, (int, SizeReference)):
3094 pass
3095 else:
3096 assert_never(size_descr)
3098 return ret
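# Illustrative sketch (not part of the original module), assuming an input axis with
# ParameterizedSize(min=64, step=16): `get_ns` picks the smallest n such that the valid
# size min + n*step is >= the requested size, e.g.
#   requested size 100 -> n = 3 (64 + 3*16 = 112 >= 100)
#   requested size 64  -> n = 0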
3100 def get_tensor_sizes(
3101 self, ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], batch_size: int
3102 ) -> _TensorSizes:
3103 axis_sizes = self.get_axis_sizes(ns, batch_size=batch_size)
3104 return _TensorSizes(
3105 {
3106 t: {
3107 aa: axis_sizes.inputs[(tt, aa)]
3108 for tt, aa in axis_sizes.inputs
3109 if tt == t
3110 }
3111 for t in {tt for tt, _ in axis_sizes.inputs}
3112 },
3113 {
3114 t: {
3115 aa: axis_sizes.outputs[(tt, aa)]
3116 for tt, aa in axis_sizes.outputs
3117 if tt == t
3118 }
3119 for t in {tt for tt, _ in axis_sizes.outputs}
3120 },
3121 )
3123 def get_axis_sizes(
3124 self,
3125 ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N],
3126 batch_size: Optional[int] = None,
3127 *,
3128 max_input_shape: Optional[Mapping[Tuple[TensorId, AxisId], int]] = None,
3129 ) -> _AxisSizes:
3130 """Determine input and output block shape for scale factors **ns**
3131 of parameterized input sizes.
3133 Args:
3134 ns: Scale factor `n` for each axis (keyed by (tensor_id, axis_id))
3135 that is parameterized as `size = min + n * step`.
3136 batch_size: The desired size of the batch dimension.
 3137 If given, **batch_size** overrides any batch size present in
 3138 **max_input_shape**. Defaults to 1.
3139 max_input_shape: Limits the derived block shapes.
3140 Each axis for which the input size, parameterized by `n`, is larger
3141 than **max_input_shape** is set to the minimal value `n_min` for which
3142 this is still true.
3143 Use this for small input samples or large values of **ns**.
3144 Or simply whenever you know the full input shape.
3146 Returns:
3147 Resolved axis sizes for model inputs and outputs.
3148 """
3149 max_input_shape = max_input_shape or {}
3150 if batch_size is None:
3151 for (_t_id, a_id), s in max_input_shape.items():
3152 if a_id == BATCH_AXIS_ID:
3153 batch_size = s
3154 break
3155 else:
3156 batch_size = 1
3158 all_axes = {
3159 t.id: {a.id: a for a in t.axes} for t in chain(self.inputs, self.outputs)
3160 }
3162 inputs: Dict[Tuple[TensorId, AxisId], int] = {}
3163 outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]] = {}
3165 def get_axis_size(a: Union[InputAxis, OutputAxis]):
3166 if isinstance(a, BatchAxis):
3167 if (t_descr.id, a.id) in ns:
3168 logger.warning(
3169 "Ignoring unexpected size increment factor (n) for batch axis"
3170 + " of tensor '{}'.",
3171 t_descr.id,
3172 )
3173 return batch_size
3174 elif isinstance(a.size, int):
3175 if (t_descr.id, a.id) in ns:
3176 logger.warning(
3177 "Ignoring unexpected size increment factor (n) for fixed size"
3178 + " axis '{}' of tensor '{}'.",
3179 a.id,
3180 t_descr.id,
3181 )
3182 return a.size
3183 elif isinstance(a.size, ParameterizedSize):
3184 if (t_descr.id, a.id) not in ns:
3185 raise ValueError(
3186 "Size increment factor (n) missing for parametrized axis"
3187 + f" '{a.id}' of tensor '{t_descr.id}'."
3188 )
3189 n = ns[(t_descr.id, a.id)]
3190 s_max = max_input_shape.get((t_descr.id, a.id))
3191 if s_max is not None:
3192 n = min(n, a.size.get_n(s_max))
3194 return a.size.get_size(n)
3196 elif isinstance(a.size, SizeReference):
3197 if (t_descr.id, a.id) in ns:
3198 logger.warning(
3199 "Ignoring unexpected size increment factor (n) for axis '{}'"
3200 + " of tensor '{}' with size reference.",
3201 a.id,
3202 t_descr.id,
3203 )
3204 assert not isinstance(a, BatchAxis)
3205 ref_axis = all_axes[a.size.tensor_id][a.size.axis_id]
3206 assert not isinstance(ref_axis, BatchAxis)
3207 ref_key = (a.size.tensor_id, a.size.axis_id)
3208 ref_size = inputs.get(ref_key, outputs.get(ref_key))
3209 assert ref_size is not None, ref_key
3210 assert not isinstance(ref_size, _DataDepSize), ref_key
3211 return a.size.get_size(
3212 axis=a,
3213 ref_axis=ref_axis,
3214 ref_size=ref_size,
3215 )
3216 elif isinstance(a.size, DataDependentSize):
3217 if (t_descr.id, a.id) in ns:
3218 logger.warning(
3219 "Ignoring unexpected increment factor (n) for data dependent"
3220 + " size axis '{}' of tensor '{}'.",
3221 a.id,
3222 t_descr.id,
3223 )
3224 return _DataDepSize(a.size.min, a.size.max)
3225 else:
3226 assert_never(a.size)
 3228 # first resolve all input axis sizes except those given by a `SizeReference`
3229 for t_descr in self.inputs:
3230 for a in t_descr.axes:
3231 if not isinstance(a.size, SizeReference):
3232 s = get_axis_size(a)
3233 assert not isinstance(s, _DataDepSize)
3234 inputs[t_descr.id, a.id] = s
3236 # resolve all other input axis sizes
3237 for t_descr in self.inputs:
3238 for a in t_descr.axes:
3239 if isinstance(a.size, SizeReference):
3240 s = get_axis_size(a)
3241 assert not isinstance(s, _DataDepSize)
3242 inputs[t_descr.id, a.id] = s
3244 # resolve all output axis sizes
3245 for t_descr in self.outputs:
3246 for a in t_descr.axes:
3247 assert not isinstance(a.size, ParameterizedSize)
3248 s = get_axis_size(a)
3249 outputs[t_descr.id, a.id] = s
3251 return _AxisSizes(inputs=inputs, outputs=outputs)
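# Illustrative usage sketch (not part of the original module), assuming a model whose
# single input tensor "raw" has a parameterized "x" axis:
#   ns = model.get_ns({TensorId("raw"): {AxisId("x"): 100, AxisId("batch"): 1}})
#   sizes = model.get_tensor_sizes(ns, batch_size=2)
#   sizes.inputs   # e.g. {TensorId("raw"): {AxisId("batch"): 2, AxisId("x"): 112, ...}}
#   sizes.outputs  # ints (fixed or via SizeReference) or _DataDepSize(min, max) entries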
3253 @model_validator(mode="before")
3254 @classmethod
3255 def _convert(cls, data: Dict[str, Any]) -> Dict[str, Any]:
3256 cls.convert_from_old_format_wo_validation(data)
3257 return data
3259 @classmethod
3260 def convert_from_old_format_wo_validation(cls, data: Dict[str, Any]) -> None:
3261 """Convert metadata following an older format version to this classes' format
3262 without validating the result.
3263 """
3264 if (
3265 data.get("type") == "model"
3266 and isinstance(fv := data.get("format_version"), str)
3267 and fv.count(".") == 2
3268 ):
3269 fv_parts = fv.split(".")
3270 if any(not p.isdigit() for p in fv_parts):
3271 return
3273 fv_tuple = tuple(map(int, fv_parts))
3275 assert cls.implemented_format_version_tuple[0:2] == (0, 5)
3276 if fv_tuple[:2] in ((0, 3), (0, 4)):
3277 m04 = _ModelDescr_v0_4.load(data)
3278 if isinstance(m04, InvalidDescr):
3279 try:
3280 updated = _model_conv.convert_as_dict(
3281 m04 # pyright: ignore[reportArgumentType]
3282 )
3283 except Exception as e:
3284 logger.error(
3285 "Failed to convert from invalid model 0.4 description."
3286 + f"\nerror: {e}"
3287 + "\nProceeding with model 0.5 validation without conversion."
3288 )
3289 updated = None
3290 else:
3291 updated = _model_conv.convert_as_dict(m04)
3293 if updated is not None:
3294 data.clear()
3295 data.update(updated)
3297 elif fv_tuple[:2] == (0, 5):
3298 # bump patch version
3299 data["format_version"] = cls.implemented_format_version
3302class _ModelConv(Converter[_ModelDescr_v0_4, ModelDescr]):
3303 def _convert(
3304 self, src: _ModelDescr_v0_4, tgt: "type[ModelDescr] | type[dict[str, Any]]"
3305 ) -> "ModelDescr | dict[str, Any]":
3306 name = "".join(
3307 c if c in string.ascii_letters + string.digits + "_+- ()" else " "
3308 for c in src.name
3309 )
3311 def conv_authors(auths: Optional[Sequence[_Author_v0_4]]):
3312 conv = (
3313 _author_conv.convert if TYPE_CHECKING else _author_conv.convert_as_dict
3314 )
3315 return None if auths is None else [conv(a) for a in auths]
3317 if TYPE_CHECKING:
3318 arch_file_conv = _arch_file_conv.convert
3319 arch_lib_conv = _arch_lib_conv.convert
3320 else:
3321 arch_file_conv = _arch_file_conv.convert_as_dict
3322 arch_lib_conv = _arch_lib_conv.convert_as_dict
3324 input_size_refs = {
3325 ipt.name: {
3326 a: s
3327 for a, s in zip(
3328 ipt.axes,
3329 (
3330 ipt.shape.min
3331 if isinstance(ipt.shape, _ParameterizedInputShape_v0_4)
3332 else ipt.shape
3333 ),
3334 )
3335 }
3336 for ipt in src.inputs
3337 if ipt.shape
3338 }
3339 output_size_refs = {
3340 **{
3341 out.name: {a: s for a, s in zip(out.axes, out.shape)}
3342 for out in src.outputs
3343 if not isinstance(out.shape, _ImplicitOutputShape_v0_4)
3344 },
3345 **input_size_refs,
3346 }
3348 return tgt(
3349 attachments=(
3350 []
3351 if src.attachments is None
3352 else [FileDescr(source=f) for f in src.attachments.files]
3353 ),
3354 authors=[_author_conv.convert_as_dict(a) for a in src.authors], # pyright: ignore[reportArgumentType]
3355 cite=[{"text": c.text, "doi": c.doi, "url": c.url} for c in src.cite], # pyright: ignore[reportArgumentType]
3356 config=src.config, # pyright: ignore[reportArgumentType]
3357 covers=src.covers,
3358 description=src.description,
3359 documentation=src.documentation,
3360 format_version="0.5.5",
3361 git_repo=src.git_repo, # pyright: ignore[reportArgumentType]
3362 icon=src.icon,
3363 id=None if src.id is None else ModelId(src.id),
3364 id_emoji=src.id_emoji,
3365 license=src.license, # type: ignore
3366 links=src.links,
3367 maintainers=[_maintainer_conv.convert_as_dict(m) for m in src.maintainers], # pyright: ignore[reportArgumentType]
3368 name=name,
3369 tags=src.tags,
3370 type=src.type,
3371 uploader=src.uploader,
3372 version=src.version,
3373 inputs=[ # pyright: ignore[reportArgumentType]
3374 _input_tensor_conv.convert_as_dict(ipt, tt, st, input_size_refs)
3375 for ipt, tt, st in zip(
3376 src.inputs,
3377 src.test_inputs,
3378 src.sample_inputs or [None] * len(src.test_inputs),
3379 )
3380 ],
3381 outputs=[ # pyright: ignore[reportArgumentType]
3382 _output_tensor_conv.convert_as_dict(out, tt, st, output_size_refs)
3383 for out, tt, st in zip(
3384 src.outputs,
3385 src.test_outputs,
3386 src.sample_outputs or [None] * len(src.test_outputs),
3387 )
3388 ],
3389 parent=(
3390 None
3391 if src.parent is None
3392 else LinkedModel(
3393 id=ModelId(
3394 str(src.parent.id)
3395 + (
3396 ""
3397 if src.parent.version_number is None
3398 else f"/{src.parent.version_number}"
3399 )
3400 )
3401 )
3402 ),
3403 training_data=(
3404 None
3405 if src.training_data is None
3406 else (
3407 LinkedDataset(
3408 id=DatasetId(
3409 str(src.training_data.id)
3410 + (
3411 ""
3412 if src.training_data.version_number is None
3413 else f"/{src.training_data.version_number}"
3414 )
3415 )
3416 )
3417 if isinstance(src.training_data, LinkedDataset02)
3418 else src.training_data
3419 )
3420 ),
3421 packaged_by=[_author_conv.convert_as_dict(a) for a in src.packaged_by], # pyright: ignore[reportArgumentType]
3422 run_mode=src.run_mode,
3423 timestamp=src.timestamp,
3424 weights=(WeightsDescr if TYPE_CHECKING else dict)(
3425 keras_hdf5=(w := src.weights.keras_hdf5)
3426 and (KerasHdf5WeightsDescr if TYPE_CHECKING else dict)(
3427 authors=conv_authors(w.authors),
3428 source=w.source,
3429 tensorflow_version=w.tensorflow_version or Version("1.15"),
3430 parent=w.parent,
3431 ),
3432 onnx=(w := src.weights.onnx)
3433 and (OnnxWeightsDescr if TYPE_CHECKING else dict)(
3434 source=w.source,
3435 authors=conv_authors(w.authors),
3436 parent=w.parent,
3437 opset_version=w.opset_version or 15,
3438 ),
3439 pytorch_state_dict=(w := src.weights.pytorch_state_dict)
3440 and (PytorchStateDictWeightsDescr if TYPE_CHECKING else dict)(
3441 source=w.source,
3442 authors=conv_authors(w.authors),
3443 parent=w.parent,
3444 architecture=(
3445 arch_file_conv(
3446 w.architecture,
3447 w.architecture_sha256,
3448 w.kwargs,
3449 )
3450 if isinstance(w.architecture, _CallableFromFile_v0_4)
3451 else arch_lib_conv(w.architecture, w.kwargs)
3452 ),
3453 pytorch_version=w.pytorch_version or Version("1.10"),
3454 dependencies=(
3455 None
3456 if w.dependencies is None
3457 else (FileDescr if TYPE_CHECKING else dict)(
3458 source=cast(
3459 FileSource,
3460 str(deps := w.dependencies)[
3461 (
3462 len("conda:")
3463 if str(deps).startswith("conda:")
3464 else 0
3465 ) :
3466 ],
3467 )
3468 )
3469 ),
3470 ),
3471 tensorflow_js=(w := src.weights.tensorflow_js)
3472 and (TensorflowJsWeightsDescr if TYPE_CHECKING else dict)(
3473 source=w.source,
3474 authors=conv_authors(w.authors),
3475 parent=w.parent,
3476 tensorflow_version=w.tensorflow_version or Version("1.15"),
3477 ),
3478 tensorflow_saved_model_bundle=(
3479 w := src.weights.tensorflow_saved_model_bundle
3480 )
3481 and (TensorflowSavedModelBundleWeightsDescr if TYPE_CHECKING else dict)(
3482 authors=conv_authors(w.authors),
3483 parent=w.parent,
3484 source=w.source,
3485 tensorflow_version=w.tensorflow_version or Version("1.15"),
3486 dependencies=(
3487 None
3488 if w.dependencies is None
3489 else (FileDescr if TYPE_CHECKING else dict)(
3490 source=cast(
3491 FileSource,
3492 (
3493 str(w.dependencies)[len("conda:") :]
3494 if str(w.dependencies).startswith("conda:")
3495 else str(w.dependencies)
3496 ),
3497 )
3498 )
3499 ),
3500 ),
3501 torchscript=(w := src.weights.torchscript)
3502 and (TorchscriptWeightsDescr if TYPE_CHECKING else dict)(
3503 source=w.source,
3504 authors=conv_authors(w.authors),
3505 parent=w.parent,
3506 pytorch_version=w.pytorch_version or Version("1.10"),
3507 ),
3508 ),
3509 )
3512_model_conv = _ModelConv(_ModelDescr_v0_4, ModelDescr)
3515# create better cover images for 3d data and non-image outputs
3516def generate_covers(
3517 inputs: Sequence[Tuple[InputTensorDescr, NDArray[Any]]],
3518 outputs: Sequence[Tuple[OutputTensorDescr, NDArray[Any]]],
3519) -> List[Path]:
3520 def squeeze(
3521 data: NDArray[Any], axes: Sequence[AnyAxis]
3522 ) -> Tuple[NDArray[Any], List[AnyAxis]]:
3523 """apply numpy.ndarray.squeeze while keeping track of the axis descriptions remaining"""
3524 if data.ndim != len(axes):
3525 raise ValueError(
3526 f"tensor shape {data.shape} does not match described axes"
3527 + f" {[a.id for a in axes]}"
3528 )
3530 axes = [deepcopy(a) for a, s in zip(axes, data.shape) if s != 1]
3531 return data.squeeze(), axes
3533 def normalize(
3534 data: NDArray[Any], axis: Optional[Tuple[int, ...]], eps: float = 1e-7
3535 ) -> NDArray[np.float32]:
3536 data = data.astype("float32")
3537 data -= data.min(axis=axis, keepdims=True)
3538 data /= data.max(axis=axis, keepdims=True) + eps
3539 return data
3541 def to_2d_image(data: NDArray[Any], axes: Sequence[AnyAxis]):
3542 original_shape = data.shape
3543 data, axes = squeeze(data, axes)
 3545 # take slice from any batch or index axis if needed
3546 # and convert the first channel axis and take a slice from any additional channel axes
3547 slices: Tuple[slice, ...] = ()
3548 ndim = data.ndim
3549 ndim_need = 3 if any(isinstance(a, ChannelAxis) for a in axes) else 2
3550 has_c_axis = False
3551 for i, a in enumerate(axes):
3552 s = data.shape[i]
3553 assert s > 1
3554 if (
3555 isinstance(a, (BatchAxis, IndexInputAxis, IndexOutputAxis))
3556 and ndim > ndim_need
3557 ):
3558 data = data[slices + (slice(s // 2 - 1, s // 2),)]
3559 ndim -= 1
3560 elif isinstance(a, ChannelAxis):
3561 if has_c_axis:
3562 # second channel axis
3563 data = data[slices + (slice(0, 1),)]
3564 ndim -= 1
3565 else:
3566 has_c_axis = True
3567 if s == 2:
3568 # visualize two channels with cyan and magenta
3569 data = np.concatenate(
3570 [
3571 data[slices + (slice(1, 2),)],
3572 data[slices + (slice(0, 1),)],
3573 (
3574 data[slices + (slice(0, 1),)]
3575 + data[slices + (slice(1, 2),)]
3576 )
3577 / 2, # TODO: take maximum instead?
3578 ],
3579 axis=i,
3580 )
3581 elif data.shape[i] == 3:
3582 pass # visualize 3 channels as RGB
3583 else:
3584 # visualize first 3 channels as RGB
3585 data = data[slices + (slice(3),)]
3587 assert data.shape[i] == 3
3589 slices += (slice(None),)
3591 data, axes = squeeze(data, axes)
3592 assert len(axes) == ndim
3593 # take slice from z axis if needed
3594 slices = ()
3595 if ndim > ndim_need:
3596 for i, a in enumerate(axes):
3597 s = data.shape[i]
3598 if a.id == AxisId("z"):
3599 data = data[slices + (slice(s // 2 - 1, s // 2),)]
3600 data, axes = squeeze(data, axes)
3601 ndim -= 1
3602 break
3604 slices += (slice(None),)
3606 # take slice from any space or time axis
3607 slices = ()
3609 for i, a in enumerate(axes):
3610 if ndim <= ndim_need:
3611 break
3613 s = data.shape[i]
3614 assert s > 1
3615 if isinstance(
3616 a, (SpaceInputAxis, SpaceOutputAxis, TimeInputAxis, TimeOutputAxis)
3617 ):
3618 data = data[slices + (slice(s // 2 - 1, s // 2),)]
3619 ndim -= 1
3621 slices += (slice(None),)
3623 del slices
3624 data, axes = squeeze(data, axes)
3625 assert len(axes) == ndim
 3627 if (has_c_axis and ndim != 3) or (not has_c_axis and ndim != 2):
3628 raise ValueError(
3629 f"Failed to construct cover image from shape {original_shape}"
3630 )
3632 if not has_c_axis:
3633 assert ndim == 2
3634 data = np.repeat(data[:, :, None], 3, axis=2)
3635 axes.append(ChannelAxis(channel_names=list(map(Identifier, "RGB"))))
3636 ndim += 1
3638 assert ndim == 3
3640 # transpose axis order such that longest axis comes first...
3641 axis_order: List[int] = list(np.argsort(list(data.shape)))
3642 axis_order.reverse()
3643 # ... and channel axis is last
3644 c = [i for i in range(3) if isinstance(axes[i], ChannelAxis)][0]
3645 axis_order.append(axis_order.pop(c))
3646 axes = [axes[ao] for ao in axis_order]
3647 data = data.transpose(axis_order)
3649 # h, w = data.shape[:2]
3650 # if h / w in (1.0 or 2.0):
3651 # pass
3652 # elif h / w < 2:
3653 # TODO: enforce 2:1 or 1:1 aspect ratio for generated cover images
3655 norm_along = (
3656 tuple(i for i, a in enumerate(axes) if a.type in ("space", "time")) or None
3657 )
3658 # normalize the data and map to 8 bit
3659 data = normalize(data, norm_along)
3660 data = (data * 255).astype("uint8")
3662 return data
3664 def create_diagonal_split_image(im0: NDArray[Any], im1: NDArray[Any]):
3665 assert im0.dtype == im1.dtype == np.uint8
3666 assert im0.shape == im1.shape
3667 assert im0.ndim == 3
3668 N, M, C = im0.shape
3669 assert C == 3
3670 out = np.ones((N, M, C), dtype="uint8")
3671 for c in range(C):
3672 outc = np.tril(im0[..., c])
3673 mask = outc == 0
3674 outc[mask] = np.triu(im1[..., c])[mask]
3675 out[..., c] = outc
3677 return out
3679 if not inputs:
3680 raise ValueError("Missing test input tensor for cover generation.")
3682 if not outputs:
3683 raise ValueError("Missing test output tensor for cover generation.")
3685 ipt_descr, ipt = inputs[0]
3686 out_descr, out = outputs[0]
3688 ipt_img = to_2d_image(ipt, ipt_descr.axes)
3689 out_img = to_2d_image(out, out_descr.axes)
3691 cover_folder = Path(mkdtemp())
3692 if ipt_img.shape == out_img.shape:
3693 covers = [cover_folder / "cover.png"]
3694 imwrite(covers[0], create_diagonal_split_image(ipt_img, out_img))
3695 else:
3696 covers = [cover_folder / "input.png", cover_folder / "output.png"]
3697 imwrite(covers[0], ipt_img)
3698 imwrite(covers[1], out_img)
3700 return covers