Coverage for bioimageio/spec/model/v0_5.py: 75%
1325 statements
coverage.py v7.9.2, created at 2025-07-18 12:47 +0000
1from __future__ import annotations
3import collections.abc
4import re
5import string
6import warnings
7from abc import ABC
8from copy import deepcopy
9from itertools import chain
10from math import ceil
11from pathlib import Path, PurePosixPath
12from tempfile import mkdtemp
13from typing import (
14 TYPE_CHECKING,
15 Any,
16 Callable,
17 ClassVar,
18 Dict,
19 Generic,
20 List,
21 Literal,
22 Mapping,
23 NamedTuple,
24 Optional,
25 Sequence,
26 Set,
27 Tuple,
28 Type,
29 TypeVar,
30 Union,
31 cast,
32)
34import numpy as np
35from annotated_types import Ge, Gt, Interval, MaxLen, MinLen, Predicate
36from imageio.v3 import imread, imwrite # pyright: ignore[reportUnknownVariableType]
37from loguru import logger
38from numpy.typing import NDArray
39from pydantic import (
40 AfterValidator,
41 Discriminator,
42 Field,
43 RootModel,
44 SerializationInfo,
45 SerializerFunctionWrapHandler,
46 StrictInt,
47 Tag,
48 ValidationInfo,
49 WrapSerializer,
50 field_validator,
51 model_serializer,
52 model_validator,
53)
54from typing_extensions import Annotated, Self, assert_never, get_args
56from .._internal.common_nodes import (
57 InvalidDescr,
58 Node,
59 NodeWithExplicitlySetFields,
60)
61from .._internal.constants import DTYPE_LIMITS
62from .._internal.field_warning import issue_warning, warn
63from .._internal.io import BioimageioYamlContent as BioimageioYamlContent
64from .._internal.io import FileDescr as FileDescr
65from .._internal.io import (
66 FileSource,
67 WithSuffix,
68 YamlValue,
69 get_reader,
70 wo_special_file_name,
71)
72from .._internal.io_basics import Sha256 as Sha256
73from .._internal.io_packaging import (
74 FileDescr_,
75 FileSource_,
76 package_file_descr_serializer,
77)
78from .._internal.io_utils import load_array
79from .._internal.node_converter import Converter
80from .._internal.type_guards import is_dict, is_sequence
81from .._internal.types import (
82 AbsoluteTolerance,
83 LowerCaseIdentifier,
84 LowerCaseIdentifierAnno,
85 MismatchedElementsPerMillion,
86 RelativeTolerance,
87)
88from .._internal.types import Datetime as Datetime
89from .._internal.types import Identifier as Identifier
90from .._internal.types import NotEmpty as NotEmpty
91from .._internal.types import SiUnit as SiUnit
92from .._internal.url import HttpUrl as HttpUrl
93from .._internal.validation_context import get_validation_context
94from .._internal.validator_annotations import RestrictCharacters
95from .._internal.version_type import Version as Version
96from .._internal.warning_levels import INFO
97from ..dataset.v0_2 import DatasetDescr as DatasetDescr02
98from ..dataset.v0_2 import LinkedDataset as LinkedDataset02
99from ..dataset.v0_3 import DatasetDescr as DatasetDescr
100from ..dataset.v0_3 import DatasetId as DatasetId
101from ..dataset.v0_3 import LinkedDataset as LinkedDataset
102from ..dataset.v0_3 import Uploader as Uploader
103from ..generic.v0_3 import (
104 VALID_COVER_IMAGE_EXTENSIONS as VALID_COVER_IMAGE_EXTENSIONS,
105)
106from ..generic.v0_3 import Author as Author
107from ..generic.v0_3 import BadgeDescr as BadgeDescr
108from ..generic.v0_3 import CiteEntry as CiteEntry
109from ..generic.v0_3 import DeprecatedLicenseId as DeprecatedLicenseId
110from ..generic.v0_3 import Doi as Doi
111from ..generic.v0_3 import (
112 FileSource_documentation,
113 GenericModelDescrBase,
114 LinkedResourceBase,
115 _author_conv, # pyright: ignore[reportPrivateUsage]
116 _maintainer_conv, # pyright: ignore[reportPrivateUsage]
117)
118from ..generic.v0_3 import LicenseId as LicenseId
119from ..generic.v0_3 import LinkedResource as LinkedResource
120from ..generic.v0_3 import Maintainer as Maintainer
121from ..generic.v0_3 import OrcidId as OrcidId
122from ..generic.v0_3 import RelativeFilePath as RelativeFilePath
123from ..generic.v0_3 import ResourceId as ResourceId
124from .v0_4 import Author as _Author_v0_4
125from .v0_4 import BinarizeDescr as _BinarizeDescr_v0_4
126from .v0_4 import CallableFromDepencency as CallableFromDepencency
127from .v0_4 import CallableFromDepencency as _CallableFromDepencency_v0_4
128from .v0_4 import CallableFromFile as _CallableFromFile_v0_4
129from .v0_4 import ClipDescr as _ClipDescr_v0_4
130from .v0_4 import ClipKwargs as ClipKwargs
131from .v0_4 import ImplicitOutputShape as _ImplicitOutputShape_v0_4
132from .v0_4 import InputTensorDescr as _InputTensorDescr_v0_4
133from .v0_4 import KnownRunMode as KnownRunMode
134from .v0_4 import ModelDescr as _ModelDescr_v0_4
135from .v0_4 import OutputTensorDescr as _OutputTensorDescr_v0_4
136from .v0_4 import ParameterizedInputShape as _ParameterizedInputShape_v0_4
137from .v0_4 import PostprocessingDescr as _PostprocessingDescr_v0_4
138from .v0_4 import PreprocessingDescr as _PreprocessingDescr_v0_4
139from .v0_4 import ProcessingKwargs as ProcessingKwargs
140from .v0_4 import RunMode as RunMode
141from .v0_4 import ScaleLinearDescr as _ScaleLinearDescr_v0_4
142from .v0_4 import ScaleMeanVarianceDescr as _ScaleMeanVarianceDescr_v0_4
143from .v0_4 import ScaleRangeDescr as _ScaleRangeDescr_v0_4
144from .v0_4 import SigmoidDescr as _SigmoidDescr_v0_4
145from .v0_4 import TensorName as _TensorName_v0_4
146from .v0_4 import WeightsFormat as WeightsFormat
147from .v0_4 import ZeroMeanUnitVarianceDescr as _ZeroMeanUnitVarianceDescr_v0_4
148from .v0_4 import package_weights
150SpaceUnit = Literal[
151 "attometer",
152 "angstrom",
153 "centimeter",
154 "decimeter",
155 "exameter",
156 "femtometer",
157 "foot",
158 "gigameter",
159 "hectometer",
160 "inch",
161 "kilometer",
162 "megameter",
163 "meter",
164 "micrometer",
165 "mile",
166 "millimeter",
167 "nanometer",
168 "parsec",
169 "petameter",
170 "picometer",
171 "terameter",
172 "yard",
173 "yoctometer",
174 "yottameter",
175 "zeptometer",
176 "zettameter",
177]
178"""Space unit compatible to the [OME-Zarr axes specification 0.5](https://ngff.openmicroscopy.org/0.5/#axes-md)"""
180TimeUnit = Literal[
181 "attosecond",
182 "centisecond",
183 "day",
184 "decisecond",
185 "exasecond",
186 "femtosecond",
187 "gigasecond",
188 "hectosecond",
189 "hour",
190 "kilosecond",
191 "megasecond",
192 "microsecond",
193 "millisecond",
194 "minute",
195 "nanosecond",
196 "petasecond",
197 "picosecond",
198 "second",
199 "terasecond",
200 "yoctosecond",
201 "yottasecond",
202 "zeptosecond",
203 "zettasecond",
204]
205"""Time unit compatible to the [OME-Zarr axes specification 0.5](https://ngff.openmicroscopy.org/0.5/#axes-md)"""
207AxisType = Literal["batch", "channel", "index", "time", "space"]
209_AXIS_TYPE_MAP: Mapping[str, AxisType] = {
210 "b": "batch",
211 "t": "time",
212 "i": "index",
213 "c": "channel",
214 "x": "space",
215 "y": "space",
216 "z": "space",
217}
219_AXIS_ID_MAP = {
220 "b": "batch",
221 "t": "time",
222 "i": "index",
223 "c": "channel",
224}
227class TensorId(LowerCaseIdentifier):
228 root_model: ClassVar[Type[RootModel[Any]]] = RootModel[
229 Annotated[LowerCaseIdentifierAnno, MaxLen(32)]
230 ]
233def _normalize_axis_id(a: str):
234 a = str(a)
235 normalized = _AXIS_ID_MAP.get(a, a)
236 if a != normalized:
237 logger.opt(depth=3).warning(
238 "Normalized axis id from '{}' to '{}'.", a, normalized
239 )
240 return normalized
243class AxisId(LowerCaseIdentifier):
244 root_model: ClassVar[Type[RootModel[Any]]] = RootModel[
245 Annotated[
246 LowerCaseIdentifierAnno,
247 MaxLen(16),
248 AfterValidator(_normalize_axis_id),
249 ]
250 ]
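# Note (added sketch, not part of the spec itself): the `AfterValidator` above maps
# legacy single-letter axis ids from the 0.4 spec to their 0.5 names, e.g.
#   AxisId("b")  # is expected to validate to "batch" (a normalization warning is logged)
#   AxisId("x")  # has no entry in _AXIS_ID_MAP and is kept as-is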
253def _is_batch(a: str) -> bool:
254 return str(a) == "batch"
257def _is_not_batch(a: str) -> bool:
258 return not _is_batch(a)
261NonBatchAxisId = Annotated[AxisId, Predicate(_is_not_batch)]
263PostprocessingId = Literal[
264 "binarize",
265 "clip",
266 "ensure_dtype",
267 "fixed_zero_mean_unit_variance",
268 "scale_linear",
269 "scale_mean_variance",
270 "scale_range",
271 "sigmoid",
272 "zero_mean_unit_variance",
273]
274PreprocessingId = Literal[
275 "binarize",
276 "clip",
277 "ensure_dtype",
278 "scale_linear",
279 "sigmoid",
280 "zero_mean_unit_variance",
281 "scale_range",
282]
285SAME_AS_TYPE = "<same as type>"
288ParameterizedSize_N = int
289"""
290Annotates an integer to calculate a concrete axis size from a `ParameterizedSize`.
291"""
294class ParameterizedSize(Node):
295 """Describes a range of valid tensor axis sizes as `size = min + n*step`.
297 - **min** and **step** are given by the model description.
298 - All blocksize parameters n = 0,1,2,... yield a valid `size`.
299 - A greater blocksize parameter n = 0,1,2,... results in a greater **size**.
300 This allows adjusting the axis size more generically.
301 """
303 N: ClassVar[Type[int]] = ParameterizedSize_N
304 """Positive integer to parameterize this axis"""
306 min: Annotated[int, Gt(0)]
307 step: Annotated[int, Gt(0)]
309 def validate_size(self, size: int) -> int:
310 if size < self.min:
311 raise ValueError(f"size {size} < {self.min}")
312 if (size - self.min) % self.step != 0:
313 raise ValueError(
314 f"axis of size {size} is not parameterized by `min + n*step` ="
315 + f" `{self.min} + n*{self.step}`"
316 )
318 return size
320 def get_size(self, n: ParameterizedSize_N) -> int:
321 return self.min + self.step * n
323 def get_n(self, s: int) -> ParameterizedSize_N:
324 """return smallest n parameterizing a size greater or equal than `s`"""
325 return ceil((s - self.min) / self.step)
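# Usage sketch with illustrative values (size = min + n*step):
#   ps = ParameterizedSize(min=32, step=16)
#   ps.get_size(3)        # -> 32 + 3*16 = 80
#   ps.get_n(70)          # -> 3, the smallest n such that min + n*step >= 70
#   ps.validate_size(80)  # -> 80
#   ps.validate_size(70)  # raises ValueError: 70 is not 32 + n*16 for any n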
328class DataDependentSize(Node):
329 min: Annotated[int, Gt(0)] = 1
330 max: Annotated[Optional[int], Gt(1)] = None
332 @model_validator(mode="after")
333 def _validate_max_gt_min(self):
334 if self.max is not None and self.min >= self.max:
335 raise ValueError(f"expected `min` < `max`, but got {self.min}, {self.max}")
337 return self
339 def validate_size(self, size: int) -> int:
340 if size < self.min:
341 raise ValueError(f"size {size} < {self.min}")
343 if self.max is not None and size > self.max:
344 raise ValueError(f"size {size} > {self.max}")
346 return size
349class SizeReference(Node):
350 """A tensor axis size (extent in pixels/frames) defined in relation to a reference axis.
352 `axis.size = reference.size * reference.scale / axis.scale + offset`
354 Note:
355 1. The axis and the referenced axis need to have the same unit (or no unit).
356 2. Batch axes may not be referenced.
357 3. Fractions are rounded down.
358 4. If the reference axis is `concatenable` the referencing axis is assumed to be
359 `concatenable` as well with the same block order.
361 Example:
362 An anisotropic input image of w*h=100*49 pixels depicts a physical space of 200*196mm².
363 Let's assume that we want to express the image height h in relation to its width w
364 instead of only accepting input images of exactly 100*49 pixels
365 (for example to express a range of valid image shapes by parameterizing w, see `ParameterizedSize`).
367 >>> w = SpaceInputAxis(id=AxisId("w"), size=100, unit="millimeter", scale=2)
368 >>> h = SpaceInputAxis(
369 ... id=AxisId("h"),
370 ... size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("w"), offset=-1),
371 ... unit="millimeter",
372 ... scale=4,
373 ... )
374 >>> print(h.size.get_size(h, w))
375 49
377 ⇒ h = w * w.scale / h.scale + offset = 100 * 2mm / 4mm - 1 = 49
378 """
380 tensor_id: TensorId
381 """tensor id of the reference axis"""
383 axis_id: AxisId
384 """axis id of the reference axis"""
386 offset: StrictInt = 0
388 def get_size(
389 self,
390 axis: Union[
391 ChannelAxis,
392 IndexInputAxis,
393 IndexOutputAxis,
394 TimeInputAxis,
395 SpaceInputAxis,
396 TimeOutputAxis,
397 TimeOutputAxisWithHalo,
398 SpaceOutputAxis,
399 SpaceOutputAxisWithHalo,
400 ],
401 ref_axis: Union[
402 ChannelAxis,
403 IndexInputAxis,
404 IndexOutputAxis,
405 TimeInputAxis,
406 SpaceInputAxis,
407 TimeOutputAxis,
408 TimeOutputAxisWithHalo,
409 SpaceOutputAxis,
410 SpaceOutputAxisWithHalo,
411 ],
412 n: ParameterizedSize_N = 0,
413 ref_size: Optional[int] = None,
414 ):
415 """Compute the concrete size for a given axis and its reference axis.
417 Args:
418 axis: The axis this `SizeReference` is the size of.
419 ref_axis: The reference axis to compute the size from.
420 n: If the **ref_axis** is parameterized (of type `ParameterizedSize`)
421 and no fixed **ref_size** is given,
422 **n** is used to compute the size of the parameterized **ref_axis**.
423 ref_size: Overwrite the reference size instead of deriving it from
424 **ref_axis**
425 (**ref_axis.scale** is still used; any given **n** is ignored).
426 """
427 assert (
428 axis.size == self
429 ), "Given `axis.size` is not defined by this `SizeReference`"
431 assert (
432 ref_axis.id == self.axis_id
433 ), f"Expected `ref_axis.id` to be {self.axis_id}, but got {ref_axis.id}."
435 assert axis.unit == ref_axis.unit, (
436 "`SizeReference` requires `axis` and `ref_axis` to have the same `unit`,"
437 f" but {axis.unit}!={ref_axis.unit}"
438 )
439 if ref_size is None:
440 if isinstance(ref_axis.size, (int, float)):
441 ref_size = ref_axis.size
442 elif isinstance(ref_axis.size, ParameterizedSize):
443 ref_size = ref_axis.size.get_size(n)
444 elif isinstance(ref_axis.size, DataDependentSize):
445 raise ValueError(
446 "Reference axis referenced in `SizeReference` may not be a `DataDependentSize`."
447 )
448 elif isinstance(ref_axis.size, SizeReference):
449 raise ValueError(
450 "Reference axis referenced in `SizeReference` may not be sized by a"
451 + " `SizeReference` itself."
452 )
453 else:
454 assert_never(ref_axis.size)
456 return int(ref_size * ref_axis.scale / axis.scale + self.offset)
458 @staticmethod
459 def _get_unit(
460 axis: Union[
461 ChannelAxis,
462 IndexInputAxis,
463 IndexOutputAxis,
464 TimeInputAxis,
465 SpaceInputAxis,
466 TimeOutputAxis,
467 TimeOutputAxisWithHalo,
468 SpaceOutputAxis,
469 SpaceOutputAxisWithHalo,
470 ],
471 ):
472 return axis.unit
475class AxisBase(NodeWithExplicitlySetFields):
476 id: AxisId
477 """An axis id unique across all axes of one tensor."""
479 description: Annotated[str, MaxLen(128)] = ""
482class WithHalo(Node):
483 halo: Annotated[int, Ge(1)]
484 """The halo should be cropped from the output tensor to avoid boundary effects.
485 It is to be cropped from both sides, i.e. `size_after_crop = size - 2 * halo`.
486 To document a halo that is already cropped by the model use `size.offset` instead."""
488 size: Annotated[
489 SizeReference,
490 Field(
491 examples=[
492 10,
493 SizeReference(
494 tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
495 ).model_dump(mode="json"),
496 ]
497 ),
498 ]
499 """reference to another axis with an optional offset (see `SizeReference`)"""
502BATCH_AXIS_ID = AxisId("batch")
505class BatchAxis(AxisBase):
506 implemented_type: ClassVar[Literal["batch"]] = "batch"
507 if TYPE_CHECKING:
508 type: Literal["batch"] = "batch"
509 else:
510 type: Literal["batch"]
512 id: Annotated[AxisId, Predicate(_is_batch)] = BATCH_AXIS_ID
513 size: Optional[Literal[1]] = None
514 """The batch size may be fixed to 1,
515 otherwise (the default) it may be chosen arbitrarily depending on available memory"""
517 @property
518 def scale(self):
519 return 1.0
521 @property
522 def concatenable(self):
523 return True
525 @property
526 def unit(self):
527 return None
530class ChannelAxis(AxisBase):
531 implemented_type: ClassVar[Literal["channel"]] = "channel"
532 if TYPE_CHECKING:
533 type: Literal["channel"] = "channel"
534 else:
535 type: Literal["channel"]
537 id: NonBatchAxisId = AxisId("channel")
538 channel_names: NotEmpty[List[Identifier]]
540 @property
541 def size(self) -> int:
542 return len(self.channel_names)
544 @property
545 def concatenable(self):
546 return False
548 @property
549 def scale(self) -> float:
550 return 1.0
552 @property
553 def unit(self):
554 return None
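# Usage sketch (illustrative channel names): the axis size is derived from the names.
#   ax = ChannelAxis(channel_names=[Identifier("dapi"), Identifier("gfp")])
#   ax.size  # -> 2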
557class IndexAxisBase(AxisBase):
558 implemented_type: ClassVar[Literal["index"]] = "index"
559 if TYPE_CHECKING:
560 type: Literal["index"] = "index"
561 else:
562 type: Literal["index"]
564 id: NonBatchAxisId = AxisId("index")
566 @property
567 def scale(self) -> float:
568 return 1.0
570 @property
571 def unit(self):
572 return None
575class _WithInputAxisSize(Node):
576 size: Annotated[
577 Union[Annotated[int, Gt(0)], ParameterizedSize, SizeReference],
578 Field(
579 examples=[
580 10,
581 ParameterizedSize(min=32, step=16).model_dump(mode="json"),
582 SizeReference(
583 tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
584 ).model_dump(mode="json"),
585 ]
586 ),
587 ]
588 """The size/length of this axis can be specified as
589 - fixed integer
590 - parameterized series of valid sizes (`ParameterizedSize`)
591 - reference to another axis with an optional offset (`SizeReference`)
592 """
595class IndexInputAxis(IndexAxisBase, _WithInputAxisSize):
596 concatenable: bool = False
597 """If a model has a `concatenable` input axis, it can be processed blockwise,
598 splitting a longer sample axis into blocks matching its input tensor description.
599 Output axes are concatenable if they have a `SizeReference` to a concatenable
600 input axis.
601 """
604class IndexOutputAxis(IndexAxisBase):
605 size: Annotated[
606 Union[Annotated[int, Gt(0)], SizeReference, DataDependentSize],
607 Field(
608 examples=[
609 10,
610 SizeReference(
611 tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
612 ).model_dump(mode="json"),
613 ]
614 ),
615 ]
616 """The size/length of this axis can be specified as
617 - fixed integer
618 - reference to another axis with an optional offset (`SizeReference`)
619 - data dependent size using `DataDependentSize` (size is only known after model inference)
620 """
623class TimeAxisBase(AxisBase):
624 implemented_type: ClassVar[Literal["time"]] = "time"
625 if TYPE_CHECKING:
626 type: Literal["time"] = "time"
627 else:
628 type: Literal["time"]
630 id: NonBatchAxisId = AxisId("time")
631 unit: Optional[TimeUnit] = None
632 scale: Annotated[float, Gt(0)] = 1.0
635class TimeInputAxis(TimeAxisBase, _WithInputAxisSize):
636 concatenable: bool = False
637 """If a model has a `concatenable` input axis, it can be processed blockwise,
638 splitting a longer sample axis into blocks matching its input tensor description.
639 Output axes are concatenable if they have a `SizeReference` to a concatenable
640 input axis.
641 """
644class SpaceAxisBase(AxisBase):
645 implemented_type: ClassVar[Literal["space"]] = "space"
646 if TYPE_CHECKING:
647 type: Literal["space"] = "space"
648 else:
649 type: Literal["space"]
651 id: Annotated[NonBatchAxisId, Field(examples=["x", "y", "z"])] = AxisId("x")
652 unit: Optional[SpaceUnit] = None
653 scale: Annotated[float, Gt(0)] = 1.0
656class SpaceInputAxis(SpaceAxisBase, _WithInputAxisSize):
657 concatenable: bool = False
658 """If a model has a `concatenable` input axis, it can be processed blockwise,
659 splitting a longer sample axis into blocks matching its input tensor description.
660 Output axes are concatenable if they have a `SizeReference` to a concatenable
661 input axis.
662 """
665INPUT_AXIS_TYPES = (
666 BatchAxis,
667 ChannelAxis,
668 IndexInputAxis,
669 TimeInputAxis,
670 SpaceInputAxis,
671)
672"""intended for isinstance comparisons in py<3.10"""
674_InputAxisUnion = Union[
675 BatchAxis, ChannelAxis, IndexInputAxis, TimeInputAxis, SpaceInputAxis
676]
677InputAxis = Annotated[_InputAxisUnion, Discriminator("type")]
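# Usage sketch (assumes pydantic's `TypeAdapter`): the `Discriminator("type")` picks the
# concrete input axis class when validating raw (e.g. YAML) data:
#   from pydantic import TypeAdapter
#   TypeAdapter(InputAxis).validate_python({"type": "space", "id": "x", "size": 64})
#   # is expected to yield a `SpaceInputAxis`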
680class _WithOutputAxisSize(Node):
681 size: Annotated[
682 Union[Annotated[int, Gt(0)], SizeReference],
683 Field(
684 examples=[
685 10,
686 SizeReference(
687 tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
688 ).model_dump(mode="json"),
689 ]
690 ),
691 ]
692 """The size/length of this axis can be specified as
693 - fixed integer
694 - reference to another axis with an optional offset (see `SizeReference`)
695 """
698class TimeOutputAxis(TimeAxisBase, _WithOutputAxisSize):
699 pass
702class TimeOutputAxisWithHalo(TimeAxisBase, WithHalo):
703 pass
706def _get_halo_axis_discriminator_value(v: Any) -> Literal["with_halo", "wo_halo"]:
707 if isinstance(v, dict):
708 return "with_halo" if "halo" in v else "wo_halo"
709 else:
710 return "with_halo" if hasattr(v, "halo") else "wo_halo"
713_TimeOutputAxisUnion = Annotated[
714 Union[
715 Annotated[TimeOutputAxis, Tag("wo_halo")],
716 Annotated[TimeOutputAxisWithHalo, Tag("with_halo")],
717 ],
718 Discriminator(_get_halo_axis_discriminator_value),
719]
722class SpaceOutputAxis(SpaceAxisBase, _WithOutputAxisSize):
723 pass
726class SpaceOutputAxisWithHalo(SpaceAxisBase, WithHalo):
727 pass
730_SpaceOutputAxisUnion = Annotated[
731 Union[
732 Annotated[SpaceOutputAxis, Tag("wo_halo")],
733 Annotated[SpaceOutputAxisWithHalo, Tag("with_halo")],
734 ],
735 Discriminator(_get_halo_axis_discriminator_value),
736]
739_OutputAxisUnion = Union[
740 BatchAxis, ChannelAxis, IndexOutputAxis, _TimeOutputAxisUnion, _SpaceOutputAxisUnion
741]
742OutputAxis = Annotated[_OutputAxisUnion, Discriminator("type")]
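# Usage sketch (assumes pydantic's `TypeAdapter`): the outer discriminator dispatches on
# "type"; for time/space axes the presence of a "halo" key selects the `...WithHalo`
# variant via `_get_halo_axis_discriminator_value`. Note that axes with a halo require a
# `SizeReference` size:
#   TypeAdapter(OutputAxis).validate_python({"type": "space", "id": "x", "size": 128})
#   # is expected to yield a `SpaceOutputAxis`
#   TypeAdapter(OutputAxis).validate_python(
#       {"type": "space", "id": "x", "halo": 16,
#        "size": {"tensor_id": "input", "axis_id": "x"}}
#   )  # is expected to yield a `SpaceOutputAxisWithHalo`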
744OUTPUT_AXIS_TYPES = (
745 BatchAxis,
746 ChannelAxis,
747 IndexOutputAxis,
748 TimeOutputAxis,
749 TimeOutputAxisWithHalo,
750 SpaceOutputAxis,
751 SpaceOutputAxisWithHalo,
752)
753"""intended for isinstance comparisons in py<3.10"""
756AnyAxis = Union[InputAxis, OutputAxis]
758ANY_AXIS_TYPES = INPUT_AXIS_TYPES + OUTPUT_AXIS_TYPES
759"""intended for isinstance comparisons in py<3.10"""
761TVs = Union[
762 NotEmpty[List[int]],
763 NotEmpty[List[float]],
764 NotEmpty[List[bool]],
765 NotEmpty[List[str]],
766]
769NominalOrOrdinalDType = Literal[
770 "float32",
771 "float64",
772 "uint8",
773 "int8",
774 "uint16",
775 "int16",
776 "uint32",
777 "int32",
778 "uint64",
779 "int64",
780 "bool",
781]
784class NominalOrOrdinalDataDescr(Node):
785 values: TVs
786 """A fixed set of nominal or an ascending sequence of ordinal values.
787 In this case `data.type` is required to be an unsigned integer type, e.g. 'uint8'.
788 String `values` are interpreted as labels for tensor values 0, ..., N.
789 Note: as YAML 1.2 does not natively support a "set" datatype,
790 nominal values should be given as a sequence (aka list/array) as well.
791 """
793 type: Annotated[
794 NominalOrOrdinalDType,
795 Field(
796 examples=[
797 "float32",
798 "uint8",
799 "uint16",
800 "int64",
801 "bool",
802 ],
803 ),
804 ] = "uint8"
806 @model_validator(mode="after")
807 def _validate_values_match_type(
808 self,
809 ) -> Self:
810 incompatible: List[Any] = []
811 for v in self.values:
812 if self.type == "bool":
813 if not isinstance(v, bool):
814 incompatible.append(v)
815 elif self.type in DTYPE_LIMITS:
816 if (
817 isinstance(v, (int, float))
818 and (
819 v < DTYPE_LIMITS[self.type].min
820 or v > DTYPE_LIMITS[self.type].max
821 )
822 or (isinstance(v, str) and "uint" not in self.type)
823 or (isinstance(v, float) and "int" in self.type)
824 ):
825 incompatible.append(v)
826 else:
827 incompatible.append(v)
829 if len(incompatible) == 5:
830 incompatible.append("...")
831 break
833 if incompatible:
834 raise ValueError(
835 f"data type '{self.type}' incompatible with values {incompatible}"
836 )
838 return self
840 unit: Optional[Union[Literal["arbitrary unit"], SiUnit]] = None
842 @property
843 def range(self):
844 if isinstance(self.values[0], str):
845 return 0, len(self.values) - 1
846 else:
847 return min(self.values), max(self.values)
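# Usage sketch (illustrative labels): string values label the tensor values 0, ..., N-1.
#   descr = NominalOrOrdinalDataDescr(values=["background", "cell", "membrane"], type="uint8")
#   descr.range  # -> (0, 2)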
850IntervalOrRatioDType = Literal[
851 "float32",
852 "float64",
853 "uint8",
854 "int8",
855 "uint16",
856 "int16",
857 "uint32",
858 "int32",
859 "uint64",
860 "int64",
861]
864class IntervalOrRatioDataDescr(Node):
865 type: Annotated[ # todo: rename to dtype
866 IntervalOrRatioDType,
867 Field(
868 examples=["float32", "float64", "uint8", "uint16"],
869 ),
870 ] = "float32"
871 range: Tuple[Optional[float], Optional[float]] = (
872 None,
873 None,
874 )
875 """Tuple `(minimum, maximum)` specifying the allowed range of the data in this tensor.
876 `None` corresponds to min/max of what can be expressed by **type**."""
877 unit: Union[Literal["arbitrary unit"], SiUnit] = "arbitrary unit"
878 scale: float = 1.0
879 """Scale for data on an interval (or ratio) scale."""
880 offset: Optional[float] = None
881 """Offset for data on a ratio scale."""
883 @model_validator(mode="before")
884 def _replace_inf(cls, data: Any):
885 if is_dict(data):
886 if "range" in data and is_sequence(data["range"]):
887 forbidden = (
888 "inf",
889 "-inf",
890 ".inf",
891 "-.inf",
892 float("inf"),
893 float("-inf"),
894 )
895 if any(v in forbidden for v in data["range"]):
896 issue_warning("replaced 'inf' value", value=data["range"])
898 data["range"] = tuple(
899 (None if v in forbidden else v) for v in data["range"]
900 )
902 return data
905TensorDataDescr = Union[NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr]
908class ProcessingDescrBase(NodeWithExplicitlySetFields, ABC):
909 """processing base class"""
912class BinarizeKwargs(ProcessingKwargs):
913 """key word arguments for `BinarizeDescr`"""
915 threshold: float
916 """The fixed threshold"""
919class BinarizeAlongAxisKwargs(ProcessingKwargs):
920 """key word arguments for `BinarizeDescr`"""
922 threshold: NotEmpty[List[float]]
923 """The fixed threshold values along `axis`"""
925 axis: Annotated[NonBatchAxisId, Field(examples=["channel"])]
926 """The `threshold` axis"""
929class BinarizeDescr(ProcessingDescrBase):
930 """Binarize the tensor with a fixed threshold.
932 Values above `BinarizeKwargs.threshold`/`BinarizeAlongAxisKwargs.threshold`
933 will be set to one, values below the threshold to zero.
935 Examples:
936 - in YAML
937 ```yaml
938 postprocessing:
939 - id: binarize
940 kwargs:
941 axis: 'channel'
942 threshold: [0.25, 0.5, 0.75]
943 ```
944 - in Python:
945 >>> postprocessing = [BinarizeDescr(
946 ... kwargs=BinarizeAlongAxisKwargs(
947 ... axis=AxisId('channel'),
948 ... threshold=[0.25, 0.5, 0.75],
949 ... )
950 ... )]
951 """
953 implemented_id: ClassVar[Literal["binarize"]] = "binarize"
954 if TYPE_CHECKING:
955 id: Literal["binarize"] = "binarize"
956 else:
957 id: Literal["binarize"]
958 kwargs: Union[BinarizeKwargs, BinarizeAlongAxisKwargs]
961class ClipDescr(ProcessingDescrBase):
962 """Set tensor values below min to min and above max to max.
964 See `ScaleRangeDescr` for examples.
965 """
967 implemented_id: ClassVar[Literal["clip"]] = "clip"
968 if TYPE_CHECKING:
969 id: Literal["clip"] = "clip"
970 else:
971 id: Literal["clip"]
973 kwargs: ClipKwargs
976class EnsureDtypeKwargs(ProcessingKwargs):
977 """key word arguments for `EnsureDtypeDescr`"""
979 dtype: Literal[
980 "float32",
981 "float64",
982 "uint8",
983 "int8",
984 "uint16",
985 "int16",
986 "uint32",
987 "int32",
988 "uint64",
989 "int64",
990 "bool",
991 ]
994class EnsureDtypeDescr(ProcessingDescrBase):
995 """Cast the tensor data type to `EnsureDtypeKwargs.dtype` (if not matching).
997 This can for example be used to ensure the inner neural network model gets a
998 different input tensor data type than the fully described bioimage.io model does.
1000 Examples:
1001 The described bioimage.io model (incl. preprocessing) accepts any
1002 float32-compatible tensor, normalizes it with percentiles and clipping and then
1003 casts it to uint8, which is what the neural network in this example expects.
1004 - in YAML
1005 ```yaml
1006 inputs:
1007 - data:
1008 type: float32 # described bioimage.io model is compatible with any float32 input tensor
1009 preprocessing:
1010 - id: scale_range
1011 kwargs:
1012 axes: ['y', 'x']
1013 max_percentile: 99.8
1014 min_percentile: 5.0
1015 - id: clip
1016 kwargs:
1017 min: 0.0
1018 max: 1.0
1019 - id: ensure_dtype # the neural network of the model requires uint8
1020 kwargs:
1021 dtype: uint8
1022 ```
1023 - in Python:
1024 >>> preprocessing = [
1025 ... ScaleRangeDescr(
1026 ... kwargs=ScaleRangeKwargs(
1027 ... axes= (AxisId('y'), AxisId('x')),
1028 ... max_percentile= 99.8,
1029 ... min_percentile= 5.0,
1030 ... )
1031 ... ),
1032 ... ClipDescr(kwargs=ClipKwargs(min=0.0, max=1.0)),
1033 ... EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="uint8")),
1034 ... ]
1035 """
1037 implemented_id: ClassVar[Literal["ensure_dtype"]] = "ensure_dtype"
1038 if TYPE_CHECKING:
1039 id: Literal["ensure_dtype"] = "ensure_dtype"
1040 else:
1041 id: Literal["ensure_dtype"]
1043 kwargs: EnsureDtypeKwargs
1046class ScaleLinearKwargs(ProcessingKwargs):
1047 """Key word arguments for `ScaleLinearDescr`"""
1049 gain: float = 1.0
1050 """multiplicative factor"""
1052 offset: float = 0.0
1053 """additive term"""
1055 @model_validator(mode="after")
1056 def _validate(self) -> Self:
1057 if self.gain == 1.0 and self.offset == 0.0:
1058 raise ValueError(
1059 "Redundant linear scaling not allowd. Set `gain` != 1.0 and/or `offset`"
1060 + " != 0.0."
1061 )
1063 return self
1066class ScaleLinearAlongAxisKwargs(ProcessingKwargs):
1067 """Key word arguments for `ScaleLinearDescr`"""
1069 axis: Annotated[NonBatchAxisId, Field(examples=["channel"])]
1070 """The axis of gain and offset values."""
1072 gain: Union[float, NotEmpty[List[float]]] = 1.0
1073 """multiplicative factor"""
1075 offset: Union[float, NotEmpty[List[float]]] = 0.0
1076 """additive term"""
1078 @model_validator(mode="after")
1079 def _validate(self) -> Self:
1081 if isinstance(self.gain, list):
1082 if isinstance(self.offset, list):
1083 if len(self.gain) != len(self.offset):
1084 raise ValueError(
1085 f"Size of `gain` ({len(self.gain)}) and `offset` ({len(self.offset)}) must match."
1086 )
1087 else:
1088 self.offset = [float(self.offset)] * len(self.gain)
1089 elif isinstance(self.offset, list):
1090 self.gain = [float(self.gain)] * len(self.offset)
1091 else:
1092 raise ValueError(
1093 "Do not specify an `axis` for scalar gain and offset values."
1094 )
1096 if all(g == 1.0 for g in self.gain) and all(off == 0.0 for off in self.offset):
1097 raise ValueError(
1098 "Redundant linear scaling not allowd. Set `gain` != 1.0 and/or `offset`"
1099 + " != 0.0."
1100 )
1102 return self
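# Usage sketch (illustrative values): a scalar `offset` is broadcast to the length of a
# list `gain` (and vice versa) by the validator above:
#   kw = ScaleLinearAlongAxisKwargs(axis=AxisId("channel"), gain=[1.0, 2.0, 3.0], offset=0.5)
#   kw.offset  # -> [0.5, 0.5, 0.5]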
1105class ScaleLinearDescr(ProcessingDescrBase):
1106 """Fixed linear scaling.
1108 Examples:
1109 1. Scale with scalar gain and offset
1110 - in YAML
1111 ```yaml
1112 preprocessing:
1113 - id: scale_linear
1114 kwargs:
1115 gain: 2.0
1116 offset: 3.0
1117 ```
1118 - in Python:
1119 >>> preprocessing = [
1120 ... ScaleLinearDescr(kwargs=ScaleLinearKwargs(gain= 2.0, offset=3.0))
1121 ... ]
1123 2. Independent scaling along an axis
1124 - in YAML
1125 ```yaml
1126 preprocessing:
1127 - id: scale_linear
1128 kwargs:
1129 axis: 'channel'
1130 gain: [1.0, 2.0, 3.0]
1131 ```
1132 - in Python:
1133 >>> preprocessing = [
1134 ... ScaleLinearDescr(
1135 ... kwargs=ScaleLinearAlongAxisKwargs(
1136 ... axis=AxisId("channel"),
1137 ... gain=[1.0, 2.0, 3.0],
1138 ... )
1139 ... )
1140 ... ]
1142 """
1144 implemented_id: ClassVar[Literal["scale_linear"]] = "scale_linear"
1145 if TYPE_CHECKING:
1146 id: Literal["scale_linear"] = "scale_linear"
1147 else:
1148 id: Literal["scale_linear"]
1149 kwargs: Union[ScaleLinearKwargs, ScaleLinearAlongAxisKwargs]
1152class SigmoidDescr(ProcessingDescrBase):
1153 """The logistic sigmoid funciton, a.k.a. expit function.
1155 Examples:
1156 - in YAML
1157 ```yaml
1158 postprocessing:
1159 - id: sigmoid
1160 ```
1161 - in Python:
1162 >>> postprocessing = [SigmoidDescr()]
1163 """
1165 implemented_id: ClassVar[Literal["sigmoid"]] = "sigmoid"
1166 if TYPE_CHECKING:
1167 id: Literal["sigmoid"] = "sigmoid"
1168 else:
1169 id: Literal["sigmoid"]
1171 @property
1172 def kwargs(self) -> ProcessingKwargs:
1173 """empty kwargs"""
1174 return ProcessingKwargs()
1177class FixedZeroMeanUnitVarianceKwargs(ProcessingKwargs):
1178 """key word arguments for `FixedZeroMeanUnitVarianceDescr`"""
1180 mean: float
1181 """The mean value to normalize with."""
1183 std: Annotated[float, Ge(1e-6)]
1184 """The standard deviation value to normalize with."""
1187class FixedZeroMeanUnitVarianceAlongAxisKwargs(ProcessingKwargs):
1188 """key word arguments for `FixedZeroMeanUnitVarianceDescr`"""
1190 mean: NotEmpty[List[float]]
1191 """The mean value(s) to normalize with."""
1193 std: NotEmpty[List[Annotated[float, Ge(1e-6)]]]
1194 """The standard deviation value(s) to normalize with.
1195 Size must match `mean` values."""
1197 axis: Annotated[NonBatchAxisId, Field(examples=["channel", "index"])]
1198 """The axis of the mean/std values to normalize each entry along that dimension
1199 separately."""
1201 @model_validator(mode="after")
1202 def _mean_and_std_match(self) -> Self:
1203 if len(self.mean) != len(self.std):
1204 raise ValueError(
1205 f"Size of `mean` ({len(self.mean)}) and `std` ({len(self.std)})"
1206 + " must match."
1207 )
1209 return self
1212class FixedZeroMeanUnitVarianceDescr(ProcessingDescrBase):
1213 """Subtract a given mean and divide by the standard deviation.
1215 Normalize with fixed, precomputed values for
1216 `FixedZeroMeanUnitVarianceKwargs.mean` and `FixedZeroMeanUnitVarianceKwargs.std`
1217 Use `FixedZeroMeanUnitVarianceAlongAxisKwargs` for independent scaling along given
1218 axes.
1220 Examples:
1221 1. scalar value for whole tensor
1222 - in YAML
1223 ```yaml
1224 preprocessing:
1225 - id: fixed_zero_mean_unit_variance
1226 kwargs:
1227 mean: 103.5
1228 std: 13.7
1229 ```
1230 - in Python
1231 >>> preprocessing = [FixedZeroMeanUnitVarianceDescr(
1232 ... kwargs=FixedZeroMeanUnitVarianceKwargs(mean=103.5, std=13.7)
1233 ... )]
1235 2. independently along an axis
1236 - in YAML
1237 ```yaml
1238 preprocessing:
1239 - id: fixed_zero_mean_unit_variance
1240 kwargs:
1241 axis: channel
1242 mean: [101.5, 102.5, 103.5]
1243 std: [11.7, 12.7, 13.7]
1244 ```
1245 - in Python
1246 >>> preprocessing = [FixedZeroMeanUnitVarianceDescr(
1247 ... kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs(
1248 ... axis=AxisId("channel"),
1249 ... mean=[101.5, 102.5, 103.5],
1250 ... std=[11.7, 12.7, 13.7],
1251 ... )
1252 ... )]
1253 """
1255 implemented_id: ClassVar[Literal["fixed_zero_mean_unit_variance"]] = (
1256 "fixed_zero_mean_unit_variance"
1257 )
1258 if TYPE_CHECKING:
1259 id: Literal["fixed_zero_mean_unit_variance"] = "fixed_zero_mean_unit_variance"
1260 else:
1261 id: Literal["fixed_zero_mean_unit_variance"]
1263 kwargs: Union[
1264 FixedZeroMeanUnitVarianceKwargs, FixedZeroMeanUnitVarianceAlongAxisKwargs
1265 ]
1268class ZeroMeanUnitVarianceKwargs(ProcessingKwargs):
1269 """key word arguments for `ZeroMeanUnitVarianceDescr`"""
1271 axes: Annotated[
1272 Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
1273 ] = None
1274 """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std.
1275 For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
1276 resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
1277 To normalize each sample independently leave out the 'batch' axis.
1278 Default: Scale all axes jointly."""
1280 eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
1281 """epsilon for numeric stability: `out = (tensor - mean) / (std + eps)`."""
1284class ZeroMeanUnitVarianceDescr(ProcessingDescrBase):
1285 """Subtract mean and divide by variance.
1287 Examples:
1288 Subtract tensor mean and variance
1289 - in YAML
1290 ```yaml
1291 preprocessing:
1292 - id: zero_mean_unit_variance
1293 ```
1294 - in Python
1295 >>> preprocessing = [ZeroMeanUnitVarianceDescr()]
1296 """
1298 implemented_id: ClassVar[Literal["zero_mean_unit_variance"]] = (
1299 "zero_mean_unit_variance"
1300 )
1301 if TYPE_CHECKING:
1302 id: Literal["zero_mean_unit_variance"] = "zero_mean_unit_variance"
1303 else:
1304 id: Literal["zero_mean_unit_variance"]
1306 kwargs: ZeroMeanUnitVarianceKwargs = Field(
1307 default_factory=ZeroMeanUnitVarianceKwargs
1308 )
1311class ScaleRangeKwargs(ProcessingKwargs):
1312 """key word arguments for `ScaleRangeDescr`
1314 For `min_percentile`=0.0 (the default) and `max_percentile`=100 (the default)
1315 this processing step normalizes data to the [0, 1] intervall.
1316 For other percentiles the normalized values will partially be outside the [0, 1]
1317 intervall. Use `ScaleRange` followed by `ClipDescr` if you want to limit the
1318 normalized values to a range.
1319 """
1321 axes: Annotated[
1322 Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
1323 ] = None
1324 """The subset of axes to normalize jointly, i.e. axes to reduce to compute the min/max percentile value.
1325 For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
1326 resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
1327 To normalize samples independently, leave out the "batch" axis.
1328 Default: Scale all axes jointly."""
1330 min_percentile: Annotated[float, Interval(ge=0, lt=100)] = 0.0
1331 """The lower percentile used to determine the value to align with zero."""
1333 max_percentile: Annotated[float, Interval(gt=1, le=100)] = 100.0
1334 """The upper percentile used to determine the value to align with one.
1335 Has to be bigger than `min_percentile`.
1336 The range is 1 to 100 instead of 0 to 100 to avoid mistakenly
1337 accepting percentiles specified in the range 0.0 to 1.0."""
1339 eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
1340 """Epsilon for numeric stability.
1341 `out = (tensor - v_lower) / (v_upper - v_lower + eps)`;
1342 with `v_lower,v_upper` values at the respective percentiles."""
1344 reference_tensor: Optional[TensorId] = None
1345 """Tensor ID to compute the percentiles from. Default: The tensor itself.
1346 For any tensor in `inputs` only input tensor references are allowed."""
1348 @field_validator("max_percentile", mode="after")
1349 @classmethod
1350 def min_smaller_max(cls, value: float, info: ValidationInfo) -> float:
1351 if (min_p := info.data["min_percentile"]) >= value:
1352 raise ValueError(f"min_percentile {min_p} >= max_percentile {value}")
1354 return value
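# Minimal numpy sketch of the scaling described above (illustrative; percentiles are
# computed over all axes here, matching `axes=None`):
#   import numpy as np
#   t = np.random.rand(3, 64, 64)
#   v_lower, v_upper = np.percentile(t, [5.0, 99.8])
#   out = (t - v_lower) / (v_upper - v_lower + 1e-6)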
1357class ScaleRangeDescr(ProcessingDescrBase):
1358 """Scale with percentiles.
1360 Examples:
1361 1. Scale linearly to map 5th percentile to 0 and 99.8th percentile to 1.0
1362 - in YAML
1363 ```yaml
1364 preprocessing:
1365 - id: scale_range
1366 kwargs:
1367 axes: ['y', 'x']
1368 max_percentile: 99.8
1369 min_percentile: 5.0
1370 ```
1371 - in Python
1372 >>> preprocessing = [
1373 ... ScaleRangeDescr(
1374 ... kwargs=ScaleRangeKwargs(
1375 ... axes= (AxisId('y'), AxisId('x')),
1376 ... max_percentile= 99.8,
1377 ... min_percentile= 5.0,
1378 ... )
1379 ... ),
1386 ... ]
1388 2. Combine the above scaling with additional clipping to clip values outside the range given by the percentiles.
1389 - in YAML
1390 ```yaml
1391 preprocessing:
1392 - id: scale_range
1393 kwargs:
1394 axes: ['y', 'x']
1395 max_percentile: 99.8
1396 min_percentile: 5.0
1398 - id: clip
1399 kwargs:
1400 min: 0.0
1401 max: 1.0
1402 ```
1403 - in Python
1404 >>> preprocessing = [ScaleRangeDescr(
1405 ... kwargs=ScaleRangeKwargs(
1406 ... axes= (AxisId('y'), AxisId('x')),
1407 ... max_percentile= 99.8,
1408 ... min_percentile= 5.0,
1409 ... )
1410 ... ), ClipDescr(kwargs=ClipKwargs(min=0.0, max=1.0))]
1412 """
1414 implemented_id: ClassVar[Literal["scale_range"]] = "scale_range"
1415 if TYPE_CHECKING:
1416 id: Literal["scale_range"] = "scale_range"
1417 else:
1418 id: Literal["scale_range"]
1419 kwargs: ScaleRangeKwargs
1422class ScaleMeanVarianceKwargs(ProcessingKwargs):
1423 """key word arguments for `ScaleMeanVarianceKwargs`"""
1425 reference_tensor: TensorId
1426 """Name of tensor to match."""
1428 axes: Annotated[
1429 Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
1430 ] = None
1431 """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std.
1432 For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
1433 resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
1434 To normalize samples independently, leave out the 'batch' axis.
1435 Default: Scale all axes jointly."""
1437 eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
1438 """Epsilon for numeric stability:
1439 `out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.`"""
1442class ScaleMeanVarianceDescr(ProcessingDescrBase):
1443 """Scale a tensor's data distribution to match another tensor's mean/std.
1444 `out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.`
1445 """
1447 implemented_id: ClassVar[Literal["scale_mean_variance"]] = "scale_mean_variance"
1448 if TYPE_CHECKING:
1449 id: Literal["scale_mean_variance"] = "scale_mean_variance"
1450 else:
1451 id: Literal["scale_mean_variance"]
1452 kwargs: ScaleMeanVarianceKwargs
1455PreprocessingDescr = Annotated[
1456 Union[
1457 BinarizeDescr,
1458 ClipDescr,
1459 EnsureDtypeDescr,
1460 ScaleLinearDescr,
1461 SigmoidDescr,
1462 FixedZeroMeanUnitVarianceDescr,
1463 ZeroMeanUnitVarianceDescr,
1464 ScaleRangeDescr,
1465 ],
1466 Discriminator("id"),
1467]
1468PostprocessingDescr = Annotated[
1469 Union[
1470 BinarizeDescr,
1471 ClipDescr,
1472 EnsureDtypeDescr,
1473 ScaleLinearDescr,
1474 SigmoidDescr,
1475 FixedZeroMeanUnitVarianceDescr,
1476 ZeroMeanUnitVarianceDescr,
1477 ScaleRangeDescr,
1478 ScaleMeanVarianceDescr,
1479 ],
1480 Discriminator("id"),
1481]
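# Usage sketch (assumes pydantic's `TypeAdapter`): the `Discriminator("id")` picks the
# concrete processing class when validating raw (e.g. YAML) data:
#   from pydantic import TypeAdapter
#   TypeAdapter(PostprocessingDescr).validate_python(
#       {"id": "binarize", "kwargs": {"threshold": 0.5}}
#   )  # is expected to yield a `BinarizeDescr` with `BinarizeKwargs(threshold=0.5)`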
1483IO_AxisT = TypeVar("IO_AxisT", InputAxis, OutputAxis)
1486class TensorDescrBase(Node, Generic[IO_AxisT]):
1487 id: TensorId
1488 """Tensor id. No duplicates are allowed."""
1490 description: Annotated[str, MaxLen(128)] = ""
1491 """free text description"""
1493 axes: NotEmpty[Sequence[IO_AxisT]]
1494 """tensor axes"""
1496 @property
1497 def shape(self):
1498 return tuple(a.size for a in self.axes)
1500 @field_validator("axes", mode="after", check_fields=False)
1501 @classmethod
1502 def _validate_axes(cls, axes: Sequence[AnyAxis]) -> Sequence[AnyAxis]:
1503 batch_axes = [a for a in axes if a.type == "batch"]
1504 if len(batch_axes) > 1:
1505 raise ValueError(
1506 f"Only one batch axis (per tensor) allowed, but got {batch_axes}"
1507 )
1509 seen_ids: Set[AxisId] = set()
1510 duplicate_axes_ids: Set[AxisId] = set()
1511 for a in axes:
1512 (duplicate_axes_ids if a.id in seen_ids else seen_ids).add(a.id)
1514 if duplicate_axes_ids:
1515 raise ValueError(f"Duplicate axis ids: {duplicate_axes_ids}")
1517 return axes
1519 test_tensor: FileDescr_
1520 """An example tensor to use for testing.
1521 Using the model with the test input tensors is expected to yield the test output tensors.
1522 Each test tensor has to be an ndarray in the
1523 [numpy.lib file format](https://numpy.org/doc/stable/reference/generated/numpy.lib.format.html#module-numpy.lib.format).
1524 The file extension must be '.npy'."""
1526 sample_tensor: Optional[FileDescr_] = None
1527 """A sample tensor to illustrate a possible input/output for the model,
1528 The sample image primarily serves to inform a human user about an example use case
1529 and is typically stored as .hdf5, .png or .tiff.
1530 It has to be readable by the [imageio library](https://imageio.readthedocs.io/en/stable/formats/index.html#supported-formats)
1531 (numpy's `.npy` format is not supported).
1532 The image dimensionality has to match the number of axes specified in this tensor description.
1533 """
1535 @model_validator(mode="after")
1536 def _validate_sample_tensor(self) -> Self:
1537 if self.sample_tensor is None or not get_validation_context().perform_io_checks:
1538 return self
1540 reader = get_reader(self.sample_tensor.source, sha256=self.sample_tensor.sha256)
1541 tensor: NDArray[Any] = imread(
1542 reader.read(),
1543 extension=PurePosixPath(reader.original_file_name).suffix,
1544 )
1545 n_dims = len(tensor.squeeze().shape)
1546 n_dims_min = n_dims_max = len(self.axes)
1548 for a in self.axes:
1549 if isinstance(a, BatchAxis):
1550 n_dims_min -= 1
1551 elif isinstance(a.size, int):
1552 if a.size == 1:
1553 n_dims_min -= 1
1554 elif isinstance(a.size, (ParameterizedSize, DataDependentSize)):
1555 if a.size.min == 1:
1556 n_dims_min -= 1
1557 elif isinstance(a.size, SizeReference):
1558 if a.size.offset < 2:
1559 # size reference may result in singleton axis
1560 n_dims_min -= 1
1561 else:
1562 assert_never(a.size)
1564 n_dims_min = max(0, n_dims_min)
1565 if n_dims < n_dims_min or n_dims > n_dims_max:
1566 raise ValueError(
1567 f"Expected sample tensor to have {n_dims_min} to"
1568 + f" {n_dims_max} dimensions, but found {n_dims} (shape: {tensor.shape})."
1569 )
1571 return self
1573 data: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]] = (
1574 IntervalOrRatioDataDescr()
1575 )
1576 """Description of the tensor's data values, optionally per channel.
1577 If specified per channel, the data `type` needs to match across channels."""
1579 @property
1580 def dtype(
1581 self,
1582 ) -> Literal[
1583 "float32",
1584 "float64",
1585 "uint8",
1586 "int8",
1587 "uint16",
1588 "int16",
1589 "uint32",
1590 "int32",
1591 "uint64",
1592 "int64",
1593 "bool",
1594 ]:
1595 """dtype as specified under `data.type` or `data[i].type`"""
1596 if isinstance(self.data, collections.abc.Sequence):
1597 return self.data[0].type
1598 else:
1599 return self.data.type
1601 @field_validator("data", mode="after")
1602 @classmethod
1603 def _check_data_type_across_channels(
1604 cls, value: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]]
1605 ) -> Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]]:
1606 if not isinstance(value, list):
1607 return value
1609 dtypes = {t.type for t in value}
1610 if len(dtypes) > 1:
1611 raise ValueError(
1612 "Tensor data descriptions per channel need to agree in their data"
1613 + f" `type`, but found {dtypes}."
1614 )
1616 return value
1618 @model_validator(mode="after")
1619 def _check_data_matches_channelaxis(self) -> Self:
1620 if not isinstance(self.data, (list, tuple)):
1621 return self
1623 for a in self.axes:
1624 if isinstance(a, ChannelAxis):
1625 size = a.size
1626 assert isinstance(size, int)
1627 break
1628 else:
1629 return self
1631 if len(self.data) != size:
1632 raise ValueError(
1633 f"Got tensor data descriptions for {len(self.data)} channels, but"
1634 + f" '{a.id}' axis has size {size}."
1635 )
1637 return self
1639 def get_axis_sizes_for_array(self, array: NDArray[Any]) -> Dict[AxisId, int]:
1640 if len(array.shape) != len(self.axes):
1641 raise ValueError(
1642 f"Dimension mismatch: array shape {array.shape} (#{len(array.shape)})"
1643 + f" incompatible with {len(self.axes)} axes."
1644 )
1645 return {a.id: array.shape[i] for i, a in enumerate(self.axes)}
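# Usage sketch (illustrative): for a tensor described with axes (batch, channel, y, x)
# and an array of shape (1, 3, 512, 512) this returns
#   {AxisId("batch"): 1, AxisId("channel"): 3, AxisId("y"): 512, AxisId("x"): 512}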
1648class InputTensorDescr(TensorDescrBase[InputAxis]):
1649 id: TensorId = TensorId("input")
1650 """Input tensor id.
1651 No duplicates are allowed across all inputs and outputs."""
1653 optional: bool = False
1654 """indicates that this tensor may be `None`"""
1656 preprocessing: List[PreprocessingDescr] = Field(
1657 default_factory=cast(Callable[[], List[PreprocessingDescr]], list)
1658 )
1660 """Description of how this input should be preprocessed.
1662 notes:
1663 - If preprocessing does not start with an 'ensure_dtype' entry, it is added
1664 to ensure an input tensor's data type matches the input tensor's data description.
1665 - If preprocessing does not end with an 'ensure_dtype' or 'binarize' entry, an
1666 'ensure_dtype' step is added to ensure preprocessing steps are not unintentionally
1667 changing the data type.
1668 """
1670 @model_validator(mode="after")
1671 def _validate_preprocessing_kwargs(self) -> Self:
1672 axes_ids = [a.id for a in self.axes]
1673 for p in self.preprocessing:
1674 kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes")
1675 if kwargs_axes is None:
1676 continue
1678 if not isinstance(kwargs_axes, collections.abc.Sequence):
1679 raise ValueError(
1680 f"Expected `preprocessing.i.kwargs.axes` to be a sequence, but got {type(kwargs_axes)}"
1681 )
1683 if any(a not in axes_ids for a in kwargs_axes):
1684 raise ValueError(
1685 "`preprocessing.i.kwargs.axes` needs to be subset of axes ids"
1686 )
1688 if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)):
1689 dtype = self.data.type
1690 else:
1691 dtype = self.data[0].type
1693 # ensure `preprocessing` begins with `EnsureDtypeDescr`
1694 if not self.preprocessing or not isinstance(
1695 self.preprocessing[0], EnsureDtypeDescr
1696 ):
1697 self.preprocessing.insert(
1698 0, EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype))
1699 )
1701 # ensure `preprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr`
1702 if not isinstance(self.preprocessing[-1], (EnsureDtypeDescr, BinarizeDescr)):
1703 self.preprocessing.append(
1704 EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype))
1705 )
1707 return self
1710def convert_axes(
1711 axes: str,
1712 *,
1713 shape: Union[
1714 Sequence[int], _ParameterizedInputShape_v0_4, _ImplicitOutputShape_v0_4
1715 ],
1716 tensor_type: Literal["input", "output"],
1717 halo: Optional[Sequence[int]],
1718 size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]],
1719):
1720 ret: List[AnyAxis] = []
1721 for i, a in enumerate(axes):
1722 axis_type = _AXIS_TYPE_MAP.get(a, a)
1723 if axis_type == "batch":
1724 ret.append(BatchAxis())
1725 continue
1727 scale = 1.0
1728 if isinstance(shape, _ParameterizedInputShape_v0_4):
1729 if shape.step[i] == 0:
1730 size = shape.min[i]
1731 else:
1732 size = ParameterizedSize(min=shape.min[i], step=shape.step[i])
1733 elif isinstance(shape, _ImplicitOutputShape_v0_4):
1734 ref_t = str(shape.reference_tensor)
1735 if ref_t.count(".") == 1:
1736 t_id, orig_a_id = ref_t.split(".")
1737 else:
1738 t_id = ref_t
1739 orig_a_id = a
1741 a_id = _AXIS_ID_MAP.get(orig_a_id, a)
1742 if not (orig_scale := shape.scale[i]):
1743 # old way to insert a new axis dimension
1744 size = int(2 * shape.offset[i])
1745 else:
1746 scale = 1 / orig_scale
1747 if axis_type in ("channel", "index"):
1748 # these axes no longer have a scale
1749 offset_from_scale = orig_scale * size_refs.get(
1750 _TensorName_v0_4(t_id), {}
1751 ).get(orig_a_id, 0)
1752 else:
1753 offset_from_scale = 0
1754 size = SizeReference(
1755 tensor_id=TensorId(t_id),
1756 axis_id=AxisId(a_id),
1757 offset=int(offset_from_scale + 2 * shape.offset[i]),
1758 )
1759 else:
1760 size = shape[i]
1762 if axis_type == "time":
1763 if tensor_type == "input":
1764 ret.append(TimeInputAxis(size=size, scale=scale))
1765 else:
1766 assert not isinstance(size, ParameterizedSize)
1767 if halo is None:
1768 ret.append(TimeOutputAxis(size=size, scale=scale))
1769 else:
1770 assert not isinstance(size, int)
1771 ret.append(
1772 TimeOutputAxisWithHalo(size=size, scale=scale, halo=halo[i])
1773 )
1775 elif axis_type == "index":
1776 if tensor_type == "input":
1777 ret.append(IndexInputAxis(size=size))
1778 else:
1779 if isinstance(size, ParameterizedSize):
1780 size = DataDependentSize(min=size.min)
1782 ret.append(IndexOutputAxis(size=size))
1783 elif axis_type == "channel":
1784 assert not isinstance(size, ParameterizedSize)
1785 if isinstance(size, SizeReference):
1786 warnings.warn(
1787 "Conversion of channel size from an implicit output shape may be"
1788 + " wrong"
1789 )
1790 ret.append(
1791 ChannelAxis(
1792 channel_names=[
1793 Identifier(f"channel{i}") for i in range(size.offset)
1794 ]
1795 )
1796 )
1797 else:
1798 ret.append(
1799 ChannelAxis(
1800 channel_names=[Identifier(f"channel{i}") for i in range(size)]
1801 )
1802 )
1803 elif axis_type == "space":
1804 if tensor_type == "input":
1805 ret.append(SpaceInputAxis(id=AxisId(a), size=size, scale=scale))
1806 else:
1807 assert not isinstance(size, ParameterizedSize)
1808 if halo is None or halo[i] == 0:
1809 ret.append(SpaceOutputAxis(id=AxisId(a), size=size, scale=scale))
1810 elif isinstance(size, int):
1811 raise NotImplementedError(
1812 f"output axis with halo and fixed size (here {size}) not allowed"
1813 )
1814 else:
1815 ret.append(
1816 SpaceOutputAxisWithHalo(
1817 id=AxisId(a), size=size, scale=scale, halo=halo[i]
1818 )
1819 )
1821 return ret
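# Usage sketch (illustrative, converting a fixed 0.4 "bcyx" input shape):
#   convert_axes("bcyx", shape=[1, 3, 256, 256], tensor_type="input", halo=None, size_refs={})
#   # is expected to yield
#   # [BatchAxis(), ChannelAxis(channel_names=["channel0", "channel1", "channel2"]),
#   #  SpaceInputAxis(id="y", size=256), SpaceInputAxis(id="x", size=256)]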
1824def _axes_letters_to_ids(
1825 axes: Optional[str],
1826) -> Optional[List[AxisId]]:
1827 if axes is None:
1828 return None
1830 return [AxisId(a) for a in axes]
1833def _get_complement_v04_axis(
1834 tensor_axes: Sequence[str], axes: Optional[Sequence[str]]
1835) -> Optional[AxisId]:
1836 if axes is None:
1837 return None
1839 non_complement_axes = set(axes) | {"b"}
1840 complement_axes = [a for a in tensor_axes if a not in non_complement_axes]
1841 if len(complement_axes) > 1:
1842 raise ValueError(
1843 f"Expected none or a single complement axis, but axes '{axes}' "
1844 + f"for tensor dims '{tensor_axes}' leave '{complement_axes}'."
1845 )
1847 return None if not complement_axes else AxisId(complement_axes[0])
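# Usage sketch (illustrative): for 0.4 tensor axes "bcyx" and kwargs axes ["x", "y"],
# the only remaining non-batch letter is "c", which is returned as an `AxisId`
# (and normalized to "channel" by the `AxisId` validator).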
1850def _convert_proc(
1851 p: Union[_PreprocessingDescr_v0_4, _PostprocessingDescr_v0_4],
1852 tensor_axes: Sequence[str],
1853) -> Union[PreprocessingDescr, PostprocessingDescr]:
1854 if isinstance(p, _BinarizeDescr_v0_4):
1855 return BinarizeDescr(kwargs=BinarizeKwargs(threshold=p.kwargs.threshold))
1856 elif isinstance(p, _ClipDescr_v0_4):
1857 return ClipDescr(kwargs=ClipKwargs(min=p.kwargs.min, max=p.kwargs.max))
1858 elif isinstance(p, _SigmoidDescr_v0_4):
1859 return SigmoidDescr()
1860 elif isinstance(p, _ScaleLinearDescr_v0_4):
1861 axes = _axes_letters_to_ids(p.kwargs.axes)
1862 if p.kwargs.axes is None:
1863 axis = None
1864 else:
1865 axis = _get_complement_v04_axis(tensor_axes, p.kwargs.axes)
1867 if axis is None:
1868 assert not isinstance(p.kwargs.gain, list)
1869 assert not isinstance(p.kwargs.offset, list)
1870 kwargs = ScaleLinearKwargs(gain=p.kwargs.gain, offset=p.kwargs.offset)
1871 else:
1872 kwargs = ScaleLinearAlongAxisKwargs(
1873 axis=axis, gain=p.kwargs.gain, offset=p.kwargs.offset
1874 )
1875 return ScaleLinearDescr(kwargs=kwargs)
1876 elif isinstance(p, _ScaleMeanVarianceDescr_v0_4):
1877 return ScaleMeanVarianceDescr(
1878 kwargs=ScaleMeanVarianceKwargs(
1879 axes=_axes_letters_to_ids(p.kwargs.axes),
1880 reference_tensor=TensorId(str(p.kwargs.reference_tensor)),
1881 eps=p.kwargs.eps,
1882 )
1883 )
1884 elif isinstance(p, _ZeroMeanUnitVarianceDescr_v0_4):
1885 if p.kwargs.mode == "fixed":
1886 mean = p.kwargs.mean
1887 std = p.kwargs.std
1888 assert mean is not None
1889 assert std is not None
1891 axis = _get_complement_v04_axis(tensor_axes, p.kwargs.axes)
1893 if axis is None:
1894 return FixedZeroMeanUnitVarianceDescr(
1895 kwargs=FixedZeroMeanUnitVarianceKwargs(
1896 mean=mean, std=std # pyright: ignore[reportArgumentType]
1897 )
1898 )
1899 else:
1900 if not isinstance(mean, list):
1901 mean = [float(mean)]
1902 if not isinstance(std, list):
1903 std = [float(std)]
1905 return FixedZeroMeanUnitVarianceDescr(
1906 kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs(
1907 axis=axis, mean=mean, std=std
1908 )
1909 )
1911 else:
1912 axes = _axes_letters_to_ids(p.kwargs.axes) or []
1913 if p.kwargs.mode == "per_dataset":
1914 axes = [AxisId("batch")] + axes
1915 if not axes:
1916 axes = None
1917 return ZeroMeanUnitVarianceDescr(
1918 kwargs=ZeroMeanUnitVarianceKwargs(axes=axes, eps=p.kwargs.eps)
1919 )
1921 elif isinstance(p, _ScaleRangeDescr_v0_4):
1922 return ScaleRangeDescr(
1923 kwargs=ScaleRangeKwargs(
1924 axes=_axes_letters_to_ids(p.kwargs.axes),
1925 min_percentile=p.kwargs.min_percentile,
1926 max_percentile=p.kwargs.max_percentile,
1927 eps=p.kwargs.eps,
1928 )
1929 )
1930 else:
1931 assert_never(p)
1934class _InputTensorConv(
1935 Converter[
1936 _InputTensorDescr_v0_4,
1937 InputTensorDescr,
1938 FileSource_,
1939 Optional[FileSource_],
1940 Mapping[_TensorName_v0_4, Mapping[str, int]],
1941 ]
1942):
1943 def _convert(
1944 self,
1945 src: _InputTensorDescr_v0_4,
1946 tgt: "type[InputTensorDescr] | type[dict[str, Any]]",
1947 test_tensor: FileSource_,
1948 sample_tensor: Optional[FileSource_],
1949 size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]],
1950 ) -> "InputTensorDescr | dict[str, Any]":
1951 axes: List[InputAxis] = convert_axes( # pyright: ignore[reportAssignmentType]
1952 src.axes,
1953 shape=src.shape,
1954 tensor_type="input",
1955 halo=None,
1956 size_refs=size_refs,
1957 )
1958 prep: List[PreprocessingDescr] = []
1959 for p in src.preprocessing:
1960 cp = _convert_proc(p, src.axes)
1961 assert not isinstance(cp, ScaleMeanVarianceDescr)
1962 prep.append(cp)
1964 prep.append(EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="float32")))
1966 return tgt(
1967 axes=axes,
1968 id=TensorId(str(src.name)),
1969 test_tensor=FileDescr(source=test_tensor),
1970 sample_tensor=(
1971 None if sample_tensor is None else FileDescr(source=sample_tensor)
1972 ),
1973 data=dict(type=src.data_type), # pyright: ignore[reportArgumentType]
1974 preprocessing=prep,
1975 )
1978_input_tensor_conv = _InputTensorConv(_InputTensorDescr_v0_4, InputTensorDescr)
1981class OutputTensorDescr(TensorDescrBase[OutputAxis]):
1982 id: TensorId = TensorId("output")
1983 """Output tensor id.
1984 No duplicates are allowed across all inputs and outputs."""
1986 postprocessing: List[PostprocessingDescr] = Field(
1987 default_factory=cast(Callable[[], List[PostprocessingDescr]], list)
1988 )
1989 """Description of how this output should be postprocessed.
1991 Note: `postprocessing` always ends with an 'ensure_dtype' operation.
1992 If not given, one is added automatically to cast to this tensor's `data.type`.
1993 """
1995 @model_validator(mode="after")
1996 def _validate_postprocessing_kwargs(self) -> Self:
1997 axes_ids = [a.id for a in self.axes]
1998 for p in self.postprocessing:
1999 kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes")
2000 if kwargs_axes is None:
2001 continue
2003 if not isinstance(kwargs_axes, collections.abc.Sequence):
2004 raise ValueError(
2005 f"expected `axes` sequence, but got {type(kwargs_axes)}"
2006 )
2008 if any(a not in axes_ids for a in kwargs_axes):
2009 raise ValueError("`kwargs.axes` needs to be a subset of the axis ids")
2011 if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)):
2012 dtype = self.data.type
2013 else:
2014 dtype = self.data[0].type
2016 # ensure `postprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr`
2017 if not self.postprocessing or not isinstance(
2018 self.postprocessing[-1], (EnsureDtypeDescr, BinarizeDescr)
2019 ):
2020 self.postprocessing.append(
2021 EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype))
2022 )
2023 return self
2026class _OutputTensorConv(
2027 Converter[
2028 _OutputTensorDescr_v0_4,
2029 OutputTensorDescr,
2030 FileSource_,
2031 Optional[FileSource_],
2032 Mapping[_TensorName_v0_4, Mapping[str, int]],
2033 ]
2034):
2035 def _convert(
2036 self,
2037 src: _OutputTensorDescr_v0_4,
2038 tgt: "type[OutputTensorDescr] | type[dict[str, Any]]",
2039 test_tensor: FileSource_,
2040 sample_tensor: Optional[FileSource_],
2041 size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]],
2042 ) -> "OutputTensorDescr | dict[str, Any]":
2043 # TODO: split convert_axes into convert_output_axes and convert_input_axes
2044 axes: List[OutputAxis] = convert_axes( # pyright: ignore[reportAssignmentType]
2045 src.axes,
2046 shape=src.shape,
2047 tensor_type="output",
2048 halo=src.halo,
2049 size_refs=size_refs,
2050 )
2051 data_descr: Dict[str, Any] = dict(type=src.data_type)
2052 if data_descr["type"] == "bool":
2053 data_descr["values"] = [False, True]
2055 return tgt(
2056 axes=axes,
2057 id=TensorId(str(src.name)),
2058 test_tensor=FileDescr(source=test_tensor),
2059 sample_tensor=(
2060 None if sample_tensor is None else FileDescr(source=sample_tensor)
2061 ),
2062 data=data_descr, # pyright: ignore[reportArgumentType]
2063 postprocessing=[_convert_proc(p, src.axes) for p in src.postprocessing],
2064 )
2067_output_tensor_conv = _OutputTensorConv(_OutputTensorDescr_v0_4, OutputTensorDescr)
2070TensorDescr = Union[InputTensorDescr, OutputTensorDescr]
2073def validate_tensors(
2074 tensors: Mapping[TensorId, Tuple[TensorDescr, NDArray[Any]]],
2075 tensor_origin: Literal[
2076 "test_tensor"
2077 ], # for more precise error messages, e.g. 'test_tensor'
2078):
2079 all_tensor_axes: Dict[TensorId, Dict[AxisId, Tuple[AnyAxis, int]]] = {}
2081 def e_msg(d: TensorDescr):
2082 return f"{'inputs' if isinstance(d, InputTensorDescr) else 'outputs'}[{d.id}]"
2084 for descr, array in tensors.values():
2085 try:
2086 axis_sizes = descr.get_axis_sizes_for_array(array)
2087 except ValueError as e:
2088 raise ValueError(f"{e_msg(descr)} {e}")
2089 else:
2090 all_tensor_axes[descr.id] = {
2091 a.id: (a, axis_sizes[a.id]) for a in descr.axes
2092 }
2094 for descr, array in tensors.values():
2095 if descr.dtype in ("float32", "float64"):
2096 invalid_test_tensor_dtype = array.dtype.name not in (
2097 "float32",
2098 "float64",
2099 "uint8",
2100 "int8",
2101 "uint16",
2102 "int16",
2103 "uint32",
2104 "int32",
2105 "uint64",
2106 "int64",
2107 )
2108 else:
2109 invalid_test_tensor_dtype = array.dtype.name != descr.dtype
2111 if invalid_test_tensor_dtype:
2112 raise ValueError(
2113 f"{e_msg(descr)}.{tensor_origin}.dtype '{array.dtype.name}' does not"
2114 + f" match described dtype '{descr.dtype}'"
2115 )
2117 if array.min() > -1e-4 and array.max() < 1e-4:
2118 raise ValueError(
2119 "Output values are too small for reliable testing."
2120 + f" Values <-1e5 or >=1e5 must be present in {tensor_origin}"
2121 )
2123 for a in descr.axes:
2124 actual_size = all_tensor_axes[descr.id][a.id][1]
2125 if a.size is None:
2126 continue
2128 if isinstance(a.size, int):
2129 if actual_size != a.size:
2130 raise ValueError(
2131 f"{e_msg(descr)}.{tensor_origin}: axis '{a.id}' "
2132 + f"has incompatible size {actual_size}, expected {a.size}"
2133 )
2134 elif isinstance(a.size, ParameterizedSize):
2135 _ = a.size.validate_size(actual_size)
2136 elif isinstance(a.size, DataDependentSize):
2137 _ = a.size.validate_size(actual_size)
2138 elif isinstance(a.size, SizeReference):
2139 ref_tensor_axes = all_tensor_axes.get(a.size.tensor_id)
2140 if ref_tensor_axes is None:
2141 raise ValueError(
2142 f"{e_msg(descr)}.axes[{a.id}].size.tensor_id: Unknown tensor"
2143 + f" reference '{a.size.tensor_id}'"
2144 )
2146 ref_axis, ref_size = ref_tensor_axes.get(a.size.axis_id, (None, None))
2147 if ref_axis is None or ref_size is None:
2148 raise ValueError(
2149 f"{e_msg(descr)}.axes[{a.id}].size.axis_id: Unknown tensor axis"
2150 + f" reference '{a.size.tensor_id}.{a.size.axis_id}"
2151 )
2153 if a.unit != ref_axis.unit:
2154 raise ValueError(
2155 f"{e_msg(descr)}.axes[{a.id}].size: `SizeReference` requires"
2156 + " axis and reference axis to have the same `unit`, but"
2157 + f" {a.unit}!={ref_axis.unit}"
2158 )
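# Illustrative example (hypothetical values, not part of the original source):
# with a referenced axis of size 128 and scale 1.0, this axis having scale 0.5
# and a SizeReference offset of 0, the expected size below evaluates to
# 128 * 1.0 / 0.5 + 0 = 256.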
2160 if actual_size != (
2161 expected_size := (
2162 ref_size * ref_axis.scale / a.scale + a.size.offset
2163 )
2164 ):
2165 raise ValueError(
2166 f"{e_msg(descr)}.{tensor_origin}: axis '{a.id}' of size"
2167 + f" {actual_size} invalid for referenced size {ref_size};"
2168 + f" expected {expected_size}"
2169 )
2170 else:
2171 assert_never(a.size)
2174FileDescr_dependencies = Annotated[
2175 FileDescr_,
2176 WithSuffix((".yaml", ".yml"), case_sensitive=True),
2177 Field(examples=[dict(source="environment.yaml")]),
2178]
2181class _ArchitectureCallableDescr(Node):
2182 callable: Annotated[Identifier, Field(examples=["MyNetworkClass", "get_my_model"])]
2183 """Identifier of the callable that returns a torch.nn.Module instance."""
2185 kwargs: Dict[str, YamlValue] = Field(
2186 default_factory=cast(Callable[[], Dict[str, YamlValue]], dict)
2187 )
2188 """key word arguments for the `callable`"""
2191class ArchitectureFromFileDescr(_ArchitectureCallableDescr, FileDescr):
2192 source: Annotated[FileSource, AfterValidator(wo_special_file_name)]
2193 """Architecture source file"""
2195 @model_serializer(mode="wrap", when_used="unless-none")
2196 def _serialize(self, nxt: SerializerFunctionWrapHandler, info: SerializationInfo):
2197 return package_file_descr_serializer(self, nxt, info)
2200class ArchitectureFromLibraryDescr(_ArchitectureCallableDescr):
2201 import_from: str
2202 """Where to import the callable from, i.e. `from <import_from> import <callable>`"""
2205class _ArchFileConv(
2206 Converter[
2207 _CallableFromFile_v0_4,
2208 ArchitectureFromFileDescr,
2209 Optional[Sha256],
2210 Dict[str, Any],
2211 ]
2212):
2213 def _convert(
2214 self,
2215 src: _CallableFromFile_v0_4,
2216 tgt: "type[ArchitectureFromFileDescr | dict[str, Any]]",
2217 sha256: Optional[Sha256],
2218 kwargs: Dict[str, Any],
2219 ) -> "ArchitectureFromFileDescr | dict[str, Any]":
2220 if src.startswith("http") and src.count(":") == 2:
2221 http, source, callable_ = src.split(":")
2222 source = ":".join((http, source))
2223 elif not src.startswith("http") and src.count(":") == 1:
2224 source, callable_ = src.split(":")
2225 else:
2226 source = str(src)
2227 callable_ = str(src)
2228 return tgt(
2229 callable=Identifier(callable_),
2230 source=cast(FileSource_, source),
2231 sha256=sha256,
2232 kwargs=kwargs,
2233 )
2236_arch_file_conv = _ArchFileConv(_CallableFromFile_v0_4, ArchitectureFromFileDescr)
2239class _ArchLibConv(
2240 Converter[
2241 _CallableFromDepencency_v0_4, ArchitectureFromLibraryDescr, Dict[str, Any]
2242 ]
2243):
2244 def _convert(
2245 self,
2246 src: _CallableFromDepencency_v0_4,
2247 tgt: "type[ArchitectureFromLibraryDescr | dict[str, Any]]",
2248 kwargs: Dict[str, Any],
2249 ) -> "ArchitectureFromLibraryDescr | dict[str, Any]":
2250 *mods, callable_ = src.split(".")
2251 import_from = ".".join(mods)
2252 return tgt(
2253 import_from=import_from, callable=Identifier(callable_), kwargs=kwargs
2254 )
2257_arch_lib_conv = _ArchLibConv(
2258 _CallableFromDepencency_v0_4, ArchitectureFromLibraryDescr
2259)
2262class WeightsEntryDescrBase(FileDescr):
2263 type: ClassVar[WeightsFormat]
2264 weights_format_name: ClassVar[str] # human readable
2266 source: Annotated[FileSource, AfterValidator(wo_special_file_name)]
2267 """Source of the weights file."""
2269 authors: Optional[List[Author]] = None
2270 """Authors
2271 Either the person(s) that have trained this model resulting in the original weights file.
2272 (If this is the initial weights entry, i.e. it does not have a `parent`)
2273 Or the person(s) who have converted the weights to this weights format.
2274 (If this is a child weight, i.e. it has a `parent` field)
2275 """
2277 parent: Annotated[
2278 Optional[WeightsFormat], Field(examples=["pytorch_state_dict"])
2279 ] = None
2280 """The source weights these weights were converted from.
2281 For example, if a model's weights were converted from the `pytorch_state_dict` format to `torchscript`,
2282 The `pytorch_state_dict` weights entry has no `parent` and is the parent of the `torchscript` weights.
2283 All weight entries except one (the initial set of weights resulting from training the model),
2284 need to have this field."""
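# Illustrative example (not part of the original source): a torchscript entry
# converted from the original pytorch_state_dict weights would set
# `parent="pytorch_state_dict"`, while the pytorch_state_dict entry itself
# leaves `parent` unset (None).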
2286 comment: str = ""
2287 """A comment about this weights entry, for example how these weights were created."""
2289 @model_validator(mode="after")
2290 def _validate(self) -> Self:
2291 if self.type == self.parent:
2292 raise ValueError("Weights entry can't be it's own parent.")
2294 return self
2296 @model_serializer(mode="wrap", when_used="unless-none")
2297 def _serialize(self, nxt: SerializerFunctionWrapHandler, info: SerializationInfo):
2298 return package_file_descr_serializer(self, nxt, info)
2301class KerasHdf5WeightsDescr(WeightsEntryDescrBase):
2302 type = "keras_hdf5"
2303 weights_format_name: ClassVar[str] = "Keras HDF5"
2304 tensorflow_version: Version
2305 """TensorFlow version used to create these weights."""
2308class OnnxWeightsDescr(WeightsEntryDescrBase):
2309 type = "onnx"
2310 weights_format_name: ClassVar[str] = "ONNX"
2311 opset_version: Annotated[int, Ge(7)]
2312 """ONNX opset version"""
2315class PytorchStateDictWeightsDescr(WeightsEntryDescrBase):
2316 type = "pytorch_state_dict"
2317 weights_format_name: ClassVar[str] = "Pytorch State Dict"
2318 architecture: Union[ArchitectureFromFileDescr, ArchitectureFromLibraryDescr]
2319 pytorch_version: Version
2320 """Version of the PyTorch library used.
2321 If `dependencies` is specified, it has to include pytorch and any version pinning has to be compatible.
2322 """
2323 dependencies: Optional[FileDescr_dependencies] = None
2324 """Custom depencies beyond pytorch described in a Conda environment file.
2325 Allows to specify custom dependencies, see conda docs:
2326 - [Exporting an environment file across platforms](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#exporting-an-environment-file-across-platforms)
2327 - [Creating an environment file manually](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#creating-an-environment-file-manually)
2329 The conda environment file should include pytorch, and any version pinning has to be compatible with
2330 **pytorch_version**.
2331 """
2334class TensorflowJsWeightsDescr(WeightsEntryDescrBase):
2335 type = "tensorflow_js"
2336 weights_format_name: ClassVar[str] = "Tensorflow.js"
2337 tensorflow_version: Version
2338 """Version of the TensorFlow library used."""
2340 source: Annotated[FileSource, AfterValidator(wo_special_file_name)]
2341 """The multi-file weights.
2342 All required files/folders should be a zip archive."""
2345class TensorflowSavedModelBundleWeightsDescr(WeightsEntryDescrBase):
2346 type = "tensorflow_saved_model_bundle"
2347 weights_format_name: ClassVar[str] = "Tensorflow Saved Model"
2348 tensorflow_version: Version
2349 """Version of the TensorFlow library used."""
2351 dependencies: Optional[FileDescr_dependencies] = None
2352 """Custom dependencies beyond tensorflow.
2353 Should include tensorflow and any version pinning has to be compatible with **tensorflow_version**."""
2355 source: Annotated[FileSource, AfterValidator(wo_special_file_name)]
2356 """The multi-file weights.
2357 All required files/folders should be a zip archive."""
2360class TorchscriptWeightsDescr(WeightsEntryDescrBase):
2361 type = "torchscript"
2362 weights_format_name: ClassVar[str] = "TorchScript"
2363 pytorch_version: Version
2364 """Version of the PyTorch library used."""
2367class WeightsDescr(Node):
2368 keras_hdf5: Optional[KerasHdf5WeightsDescr] = None
2369 onnx: Optional[OnnxWeightsDescr] = None
2370 pytorch_state_dict: Optional[PytorchStateDictWeightsDescr] = None
2371 tensorflow_js: Optional[TensorflowJsWeightsDescr] = None
2372 tensorflow_saved_model_bundle: Optional[TensorflowSavedModelBundleWeightsDescr] = (
2373 None
2374 )
2375 torchscript: Optional[TorchscriptWeightsDescr] = None
2377 @model_validator(mode="after")
2378 def check_entries(self) -> Self:
2379 entries = {wtype for wtype, entry in self if entry is not None}
2381 if not entries:
2382 raise ValueError("Missing weights entry")
2384 entries_wo_parent = {
2385 wtype
2386 for wtype, entry in self
2387 if entry is not None and hasattr(entry, "parent") and entry.parent is None
2388 }
2389 if len(entries_wo_parent) != 1:
2390 issue_warning(
2391 "Exactly one weights entry may not specify the `parent` field (got"
2392 + " {value}). That entry is considered the original set of model weights."
2393 + " Other weight formats are created through conversion of the orignal or"
2394 + " already converted weights. They have to reference the weights format"
2395 + " they were converted from as their `parent`.",
2396 value=len(entries_wo_parent),
2397 field="weights",
2398 )
2400 for wtype, entry in self:
2401 if entry is None:
2402 continue
2404 assert hasattr(entry, "type")
2405 assert hasattr(entry, "parent")
2406 assert wtype == entry.type
2407 if (
2408 entry.parent is not None and entry.parent not in entries
2409 ): # self reference checked for `parent` field
2410 raise ValueError(
2411 f"`weights.{wtype}.parent={entry.parent} not in specified weight"
2412 + f" formats: {entries}"
2413 )
2415 return self
2417 def __getitem__(
2418 self,
2419 key: Literal[
2420 "keras_hdf5",
2421 "onnx",
2422 "pytorch_state_dict",
2423 "tensorflow_js",
2424 "tensorflow_saved_model_bundle",
2425 "torchscript",
2426 ],
2427 ):
2428 if key == "keras_hdf5":
2429 ret = self.keras_hdf5
2430 elif key == "onnx":
2431 ret = self.onnx
2432 elif key == "pytorch_state_dict":
2433 ret = self.pytorch_state_dict
2434 elif key == "tensorflow_js":
2435 ret = self.tensorflow_js
2436 elif key == "tensorflow_saved_model_bundle":
2437 ret = self.tensorflow_saved_model_bundle
2438 elif key == "torchscript":
2439 ret = self.torchscript
2440 else:
2441 raise KeyError(key)
2443 if ret is None:
2444 raise KeyError(key)
2446 return ret
2448 @property
2449 def available_formats(self):
2450 return {
2451 **({} if self.keras_hdf5 is None else {"keras_hdf5": self.keras_hdf5}),
2452 **({} if self.onnx is None else {"onnx": self.onnx}),
2453 **(
2454 {}
2455 if self.pytorch_state_dict is None
2456 else {"pytorch_state_dict": self.pytorch_state_dict}
2457 ),
2458 **(
2459 {}
2460 if self.tensorflow_js is None
2461 else {"tensorflow_js": self.tensorflow_js}
2462 ),
2463 **(
2464 {}
2465 if self.tensorflow_saved_model_bundle is None
2466 else {
2467 "tensorflow_saved_model_bundle": self.tensorflow_saved_model_bundle
2468 }
2469 ),
2470 **({} if self.torchscript is None else {"torchscript": self.torchscript}),
2471 }
2473 @property
2474 def missing_formats(self):
2475 return {
2476 wf for wf in get_args(WeightsFormat) if wf not in self.available_formats
2477 }
2480class ModelId(ResourceId):
2481 pass
2484class LinkedModel(LinkedResourceBase):
2485 """Reference to a bioimage.io model."""
2487 id: ModelId
2488 """A valid model `id` from the bioimage.io collection."""
2491class _DataDepSize(NamedTuple):
2492 min: StrictInt
2493 max: Optional[StrictInt]
2496class _AxisSizes(NamedTuple):
2497 """the lenghts of all axes of model inputs and outputs"""
2499 inputs: Dict[Tuple[TensorId, AxisId], int]
2500 outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]]
2503class _TensorSizes(NamedTuple):
2504 """_AxisSizes as nested dicts"""
2506 inputs: Dict[TensorId, Dict[AxisId, int]]
2507 outputs: Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]
2510class ReproducibilityTolerance(Node, extra="allow"):
2511 """Describes what small numerical differences -- if any -- may be tolerated
2512 in the generated output when executing in different environments.
2514 A tensor element *output* is considered mismatched to the **test_tensor** if
2515 abs(*output* - **test_tensor**) > **absolute_tolerance** + **relative_tolerance** * abs(**test_tensor**).
2516 (Internally we call [numpy.testing.assert_allclose](https://numpy.org/doc/stable/reference/generated/numpy.testing.assert_allclose.html).)
2518 Motivation:
2519 For testing we can request the respective deep learning frameworks to be as
2520 reproducible as possible by setting seeds and choosing deterministic algorithms,
2521 but differences in operating systems, available hardware and installed drivers
2522 may still lead to numerical differences.
2523 """
2525 relative_tolerance: RelativeTolerance = 1e-3
2526 """Maximum relative tolerance of reproduced test tensor."""
2528 absolute_tolerance: AbsoluteTolerance = 1e-4
2529 """Maximum absolute tolerance of reproduced test tensor."""
2531 mismatched_elements_per_million: MismatchedElementsPerMillion = 100
2532 """Maximum number of mismatched elements/pixels per million to tolerate."""
2534 output_ids: Sequence[TensorId] = ()
2535 """Limits the output tensor IDs these reproducibility details apply to."""
2537 weights_formats: Sequence[WeightsFormat] = ()
2538 """Limits the weights formats these details apply to."""
2541class BioimageioConfig(Node, extra="allow"):
2542 reproducibility_tolerance: Sequence[ReproducibilityTolerance] = ()
2543 """Tolerances to allow when reproducing the model's test outputs
2544 from the model's test inputs.
2545 Only the first entry matching tensor id and weights format is considered.
2546 """
2549class Config(Node, extra="allow"):
2550 bioimageio: BioimageioConfig = Field(default_factory=BioimageioConfig)
2553class ModelDescr(GenericModelDescrBase):
2554 """Specification of the fields used in a bioimage.io-compliant RDF to describe AI models with pretrained weights.
2555 These fields are typically stored in a YAML file which we call a model resource description file (model RDF).
2556 """
2558 implemented_format_version: ClassVar[Literal["0.5.4"]] = "0.5.4"
2559 if TYPE_CHECKING:
2560 format_version: Literal["0.5.4"] = "0.5.4"
2561 else:
2562 format_version: Literal["0.5.4"]
2563 """Version of the bioimage.io model description specification used.
2564 When creating a new model always use the latest micro/patch version described here.
2565 The `format_version` is important for any consumer software to understand how to parse the fields.
2566 """
2568 implemented_type: ClassVar[Literal["model"]] = "model"
2569 if TYPE_CHECKING:
2570 type: Literal["model"] = "model"
2571 else:
2572 type: Literal["model"]
2573 """Specialized resource type 'model'"""
2575 id: Optional[ModelId] = None
2576 """bioimage.io-wide unique resource identifier
2577 assigned by bioimage.io; version **un**specific."""
2579 authors: NotEmpty[List[Author]]
2580 """The authors are the creators of the model RDF and the primary points of contact."""
2582 documentation: FileSource_documentation
2583 """URL or relative path to a markdown file with additional documentation.
2584 The recommended documentation file name is `README.md`. An `.md` suffix is mandatory.
2585 The documentation should include a '#[#] Validation' (sub)section
2586 with details on how to quantitatively validate the model on unseen data."""
2588 @field_validator("documentation", mode="after")
2589 @classmethod
2590 def _validate_documentation(
2591 cls, value: FileSource_documentation
2592 ) -> FileSource_documentation:
2593 if not get_validation_context().perform_io_checks:
2594 return value
2596 doc_reader = get_reader(value)
2597 doc_content = doc_reader.read().decode(encoding="utf-8")
2598 if not re.search("#.*[vV]alidation", doc_content):
2599 issue_warning(
2600 "No '# Validation' (sub)section found in {value}.",
2601 value=value,
2602 field="documentation",
2603 )
2605 return value
2607 inputs: NotEmpty[Sequence[InputTensorDescr]]
2608 """Describes the input tensors expected by this model."""
2610 @field_validator("inputs", mode="after")
2611 @classmethod
2612 def _validate_input_axes(
2613 cls, inputs: Sequence[InputTensorDescr]
2614 ) -> Sequence[InputTensorDescr]:
2615 input_size_refs = cls._get_axes_with_independent_size(inputs)
2617 for i, ipt in enumerate(inputs):
2618 valid_independent_refs: Dict[
2619 Tuple[TensorId, AxisId],
2620 Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
2621 ] = {
2622 **{
2623 (ipt.id, a.id): (ipt, a, a.size)
2624 for a in ipt.axes
2625 if not isinstance(a, BatchAxis)
2626 and isinstance(a.size, (int, ParameterizedSize))
2627 },
2628 **input_size_refs,
2629 }
2630 for a, ax in enumerate(ipt.axes):
2631 cls._validate_axis(
2632 "inputs",
2633 i=i,
2634 tensor_id=ipt.id,
2635 a=a,
2636 axis=ax,
2637 valid_independent_refs=valid_independent_refs,
2638 )
2639 return inputs
2641 @staticmethod
2642 def _validate_axis(
2643 field_name: str,
2644 i: int,
2645 tensor_id: TensorId,
2646 a: int,
2647 axis: AnyAxis,
2648 valid_independent_refs: Dict[
2649 Tuple[TensorId, AxisId],
2650 Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
2651 ],
2652 ):
2653 if isinstance(axis, BatchAxis) or isinstance(
2654 axis.size, (int, ParameterizedSize, DataDependentSize)
2655 ):
2656 return
2657 elif not isinstance(axis.size, SizeReference):
2658 assert_never(axis.size)
2660 # validate axis.size SizeReference
2661 ref = (axis.size.tensor_id, axis.size.axis_id)
2662 if ref not in valid_independent_refs:
2663 raise ValueError(
2664 "Invalid tensor axis reference at"
2665 + f" {field_name}[{i}].axes[{a}].size: {axis.size}."
2666 )
2667 if ref == (tensor_id, axis.id):
2668 raise ValueError(
2669 "Self-referencing not allowed for"
2670 + f" {field_name}[{i}].axes[{a}].size: {axis.size}"
2671 )
2672 if axis.type == "channel":
2673 if valid_independent_refs[ref][1].type != "channel":
2674 raise ValueError(
2675 "A channel axis' size may only reference another fixed size"
2676 + " channel axis."
2677 )
2678 if isinstance(axis.channel_names, str) and "{i}" in axis.channel_names:
2679 ref_size = valid_independent_refs[ref][2]
2680 assert isinstance(ref_size, int), (
2681 "channel axis ref (another channel axis) has to specify fixed"
2682 + " size"
2683 )
2684 generated_channel_names = [
2685 Identifier(axis.channel_names.format(i=i))
2686 for i in range(1, ref_size + 1)
2687 ]
2688 axis.channel_names = generated_channel_names
2690 if (ax_unit := getattr(axis, "unit", None)) != (
2691 ref_unit := getattr(valid_independent_refs[ref][1], "unit", None)
2692 ):
2693 raise ValueError(
2694 "The units of an axis and its reference axis need to match, but"
2695 + f" '{ax_unit}' != '{ref_unit}'."
2696 )
2697 ref_axis = valid_independent_refs[ref][1]
2698 if isinstance(ref_axis, BatchAxis):
2699 raise ValueError(
2700 f"Invalid reference axis '{ref_axis.id}' for {tensor_id}.{axis.id}"
2701 + " (a batch axis is not allowed as reference)."
2702 )
2704 if isinstance(axis, WithHalo):
2705 min_size = axis.size.get_size(axis, ref_axis, n=0)
2706 if (min_size - 2 * axis.halo) < 1:
2707 raise ValueError(
2708 f"axis {axis.id} with minimum size {min_size} is too small for halo"
2709 + f" {axis.halo}."
2710 )
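# Illustrative example (hypothetical values, not part of the original source):
# an output halo of 8 with output scale 2.0 and input (reference) scale 1.0
# gives input_halo = 8 * 2.0 / 1.0 = 16, a valid, even integer.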
2712 input_halo = axis.halo * axis.scale / ref_axis.scale
2713 if input_halo != int(input_halo) or input_halo % 2 == 1:
2714 raise ValueError(
2715 f"input_halo {input_halo} (output_halo {axis.halo} *"
2716 + f" output_scale {axis.scale} / input_scale {ref_axis.scale})"
2717 + f" {tensor_id}.{axis.id}."
2718 )
2720 @model_validator(mode="after")
2721 def _validate_test_tensors(self) -> Self:
2722 if not get_validation_context().perform_io_checks:
2723 return self
2725 test_output_arrays = [load_array(descr.test_tensor) for descr in self.outputs]
2726 test_input_arrays = [load_array(descr.test_tensor) for descr in self.inputs]
2728 tensors = {
2729 descr.id: (descr, array)
2730 for descr, array in zip(
2731 chain(self.inputs, self.outputs), test_input_arrays + test_output_arrays
2732 )
2733 }
2734 validate_tensors(tensors, tensor_origin="test_tensor")
2736 output_arrays = {
2737 descr.id: array for descr, array in zip(self.outputs, test_output_arrays)
2738 }
2739 for rep_tol in self.config.bioimageio.reproducibility_tolerance:
2740 if not rep_tol.absolute_tolerance:
2741 continue
2743 if rep_tol.output_ids:
2744 out_arrays = {
2745 oid: a
2746 for oid, a in output_arrays.items()
2747 if oid in rep_tol.output_ids
2748 }
2749 else:
2750 out_arrays = output_arrays
2752 for out_id, array in out_arrays.items():
2753 if rep_tol.absolute_tolerance > (max_test_value := array.max()) * 0.01:
2754 raise ValueError(
2755 "config.bioimageio.reproducibility_tolerance.absolute_tolerance="
2756 + f"{rep_tol.absolute_tolerance} > 0.01*{max_test_value}"
2757 + f" (1% of the maximum value of the test tensor '{out_id}')"
2758 )
2760 return self
2762 @model_validator(mode="after")
2763 def _validate_tensor_references_in_proc_kwargs(self, info: ValidationInfo) -> Self:
2764 ipt_refs = {t.id for t in self.inputs}
2765 out_refs = {t.id for t in self.outputs}
2766 for ipt in self.inputs:
2767 for p in ipt.preprocessing:
2768 ref = p.kwargs.get("reference_tensor")
2769 if ref is None:
2770 continue
2771 if ref not in ipt_refs:
2772 raise ValueError(
2773 f"`reference_tensor` '{ref}' not found. Valid input tensor"
2774 + f" references are: {ipt_refs}."
2775 )
2777 for out in self.outputs:
2778 for p in out.postprocessing:
2779 ref = p.kwargs.get("reference_tensor")
2780 if ref is None:
2781 continue
2783 if ref not in ipt_refs and ref not in out_refs:
2784 raise ValueError(
2785 f"`reference_tensor` '{ref}' not found. Valid tensor references"
2786 + f" are: {ipt_refs | out_refs}."
2787 )
2789 return self
2791 # TODO: use validate funcs in validate_test_tensors
2792 # def validate_inputs(self, input_tensors: Mapping[TensorId, NDArray[Any]]) -> Mapping[TensorId, NDArray[Any]]:
2794 name: Annotated[
2795 Annotated[
2796 str, RestrictCharacters(string.ascii_letters + string.digits + "_+- ()")
2797 ],
2798 MinLen(5),
2799 MaxLen(128),
2800 warn(MaxLen(64), "Name longer than 64 characters.", INFO),
2801 ]
2802 """A human-readable name of this model.
2803 It should be no longer than 64 characters
2804 and may only contain letters, numbers, underscores, minus signs, parentheses, and spaces.
2805 We recommend choosing a name that refers to the model's task and image modality.
2806 """
2808 outputs: NotEmpty[Sequence[OutputTensorDescr]]
2809 """Describes the output tensors."""
2811 @field_validator("outputs", mode="after")
2812 @classmethod
2813 def _validate_tensor_ids(
2814 cls, outputs: Sequence[OutputTensorDescr], info: ValidationInfo
2815 ) -> Sequence[OutputTensorDescr]:
2816 tensor_ids = [
2817 t.id for t in info.data.get("inputs", []) + info.data.get("outputs", [])
2818 ]
2819 duplicate_tensor_ids: List[str] = []
2820 seen: Set[str] = set()
2821 for t in tensor_ids:
2822 if t in seen:
2823 duplicate_tensor_ids.append(t)
2825 seen.add(t)
2827 if duplicate_tensor_ids:
2828 raise ValueError(f"Duplicate tensor ids: {duplicate_tensor_ids}")
2830 return outputs
2832 @staticmethod
2833 def _get_axes_with_parameterized_size(
2834 io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]],
2835 ):
2836 return {
2837 f"{t.id}.{a.id}": (t, a, a.size)
2838 for t in io
2839 for a in t.axes
2840 if not isinstance(a, BatchAxis) and isinstance(a.size, ParameterizedSize)
2841 }
2843 @staticmethod
2844 def _get_axes_with_independent_size(
2845 io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]],
2846 ):
2847 return {
2848 (t.id, a.id): (t, a, a.size)
2849 for t in io
2850 for a in t.axes
2851 if not isinstance(a, BatchAxis)
2852 and isinstance(a.size, (int, ParameterizedSize))
2853 }
2855 @field_validator("outputs", mode="after")
2856 @classmethod
2857 def _validate_output_axes(
2858 cls, outputs: List[OutputTensorDescr], info: ValidationInfo
2859 ) -> List[OutputTensorDescr]:
2860 input_size_refs = cls._get_axes_with_independent_size(
2861 info.data.get("inputs", [])
2862 )
2863 output_size_refs = cls._get_axes_with_independent_size(outputs)
2865 for i, out in enumerate(outputs):
2866 valid_independent_refs: Dict[
2867 Tuple[TensorId, AxisId],
2868 Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
2869 ] = {
2870 **{
2871 (out.id, a.id): (out, a, a.size)
2872 for a in out.axes
2873 if not isinstance(a, BatchAxis)
2874 and isinstance(a.size, (int, ParameterizedSize))
2875 },
2876 **input_size_refs,
2877 **output_size_refs,
2878 }
2879 for a, ax in enumerate(out.axes):
2880 cls._validate_axis(
2881 "outputs",
2882 i,
2883 out.id,
2884 a,
2885 ax,
2886 valid_independent_refs=valid_independent_refs,
2887 )
2889 return outputs
2891 packaged_by: List[Author] = Field(
2892 default_factory=cast(Callable[[], List[Author]], list)
2893 )
2894 """The persons that have packaged and uploaded this model.
2895 Only required if those persons differ from the `authors`."""
2897 parent: Optional[LinkedModel] = None
2898 """The model from which this model is derived, e.g. by fine-tuning the weights."""
2900 @model_validator(mode="after")
2901 def _validate_parent_is_not_self(self) -> Self:
2902 if self.parent is not None and self.parent.id == self.id:
2903 raise ValueError("A model description may not reference itself as parent.")
2905 return self
2907 run_mode: Annotated[
2908 Optional[RunMode],
2909 warn(None, "Run mode '{value}' has limited support across consumer software."),
2910 ] = None
2911 """Custom run mode for this model: for more complex prediction procedures like test time
2912 data augmentation that currently cannot be expressed in the specification.
2913 No standard run modes are defined yet."""
2915 timestamp: Datetime = Field(default_factory=Datetime.now)
2916 """Timestamp in [ISO 8601](#https://en.wikipedia.org/wiki/ISO_8601) format
2917 with a few restrictions listed [here](https://docs.python.org/3/library/datetime.html#datetime.datetime.fromisoformat).
2918 (In Python a datetime object is valid, too)."""
2920 training_data: Annotated[
2921 Union[None, LinkedDataset, DatasetDescr, DatasetDescr02],
2922 Field(union_mode="left_to_right"),
2923 ] = None
2924 """The dataset used to train this model"""
2926 weights: Annotated[WeightsDescr, WrapSerializer(package_weights)]
2927 """The weights for this model.
2928 Weights can be given for different formats, but should otherwise be equivalent.
2929 The available weight formats determine which consumers can use this model."""
2931 config: Config = Field(default_factory=Config)
2933 @model_validator(mode="after")
2934 def _add_default_cover(self) -> Self:
2935 if not get_validation_context().perform_io_checks or self.covers:
2936 return self
2938 try:
2939 generated_covers = generate_covers(
2940 [(t, load_array(t.test_tensor)) for t in self.inputs],
2941 [(t, load_array(t.test_tensor)) for t in self.outputs],
2942 )
2943 except Exception as e:
2944 issue_warning(
2945 "Failed to generate cover image(s): {e}",
2946 value=self.covers,
2947 msg_context=dict(e=e),
2948 field="covers",
2949 )
2950 else:
2951 self.covers.extend(generated_covers)
2953 return self
2955 def get_input_test_arrays(self) -> List[NDArray[Any]]:
2956 data = [load_array(ipt.test_tensor) for ipt in self.inputs]
2957 assert all(isinstance(d, np.ndarray) for d in data)
2958 return data
2960 def get_output_test_arrays(self) -> List[NDArray[Any]]:
2961 data = [load_array(out.test_tensor) for out in self.outputs]
2962 assert all(isinstance(d, np.ndarray) for d in data)
2963 return data
2965 @staticmethod
2966 def get_batch_size(tensor_sizes: Mapping[TensorId, Mapping[AxisId, int]]) -> int:
2967 batch_size = 1
2968 tensor_with_batchsize: Optional[TensorId] = None
2969 for tid in tensor_sizes:
2970 for aid, s in tensor_sizes[tid].items():
2971 if aid != BATCH_AXIS_ID or s == 1 or s == batch_size:
2972 continue
2974 if batch_size != 1:
2975 assert tensor_with_batchsize is not None
2976 raise ValueError(
2977 f"batch size mismatch for tensors '{tensor_with_batchsize}' ({batch_size}) and '{tid}' ({s})"
2978 )
2980 batch_size = s
2981 tensor_with_batchsize = tid
2983 return batch_size
2985 def get_output_tensor_sizes(
2986 self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]
2987 ) -> Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]:
2988 """Returns the tensor output sizes for given **input_sizes**.
2989 The tensor output sizes are exact only if **input_sizes** is a valid input shape.
2990 Otherwise they might be larger than the actual (valid) output sizes."""
2991 batch_size = self.get_batch_size(input_sizes)
2992 ns = self.get_ns(input_sizes)
2994 tensor_sizes = self.get_tensor_sizes(ns, batch_size=batch_size)
2995 return tensor_sizes.outputs
2997 def get_ns(self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]):
2998 """get parameter `n` for each parameterized axis
2999 such that the valid input size is >= the given input size"""
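# Illustrative example (hypothetical sizes, not from the original source;
# assuming `ParameterizedSize.get_n` returns the smallest such `n`):
# for an axis described by ParameterizedSize(min=64, step=16), an input size of 70
# yields n=1 below, because the smallest valid size >= 70 is 64 + 1 * 16 = 80.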
3000 ret: Dict[Tuple[TensorId, AxisId], ParameterizedSize_N] = {}
3001 axes = {t.id: {a.id: a for a in t.axes} for t in self.inputs}
3002 for tid in input_sizes:
3003 for aid, s in input_sizes[tid].items():
3004 size_descr = axes[tid][aid].size
3005 if isinstance(size_descr, ParameterizedSize):
3006 ret[(tid, aid)] = size_descr.get_n(s)
3007 elif size_descr is None or isinstance(size_descr, (int, SizeReference)):
3008 pass
3009 else:
3010 assert_never(size_descr)
3012 return ret
3014 def get_tensor_sizes(
3015 self, ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], batch_size: int
3016 ) -> _TensorSizes:
3017 axis_sizes = self.get_axis_sizes(ns, batch_size=batch_size)
3018 return _TensorSizes(
3019 {
3020 t: {
3021 aa: axis_sizes.inputs[(tt, aa)]
3022 for tt, aa in axis_sizes.inputs
3023 if tt == t
3024 }
3025 for t in {tt for tt, _ in axis_sizes.inputs}
3026 },
3027 {
3028 t: {
3029 aa: axis_sizes.outputs[(tt, aa)]
3030 for tt, aa in axis_sizes.outputs
3031 if tt == t
3032 }
3033 for t in {tt for tt, _ in axis_sizes.outputs}
3034 },
3035 )
3037 def get_axis_sizes(
3038 self,
3039 ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N],
3040 batch_size: Optional[int] = None,
3041 *,
3042 max_input_shape: Optional[Mapping[Tuple[TensorId, AxisId], int]] = None,
3043 ) -> _AxisSizes:
3044 """Determine input and output block shape for scale factors **ns**
3045 of parameterized input sizes.
3047 Args:
3048 ns: Scale factor `n` for each axis (keyed by (tensor_id, axis_id))
3049 that is parameterized as `size = min + n * step`.
3050 batch_size: The desired size of the batch dimension.
3051 If given, **batch_size** overwrites any batch size present in
3052 **max_input_shape**. Default 1.
3053 max_input_shape: Limits the derived block shapes.
3054 Each axis for which the input size, parameterized by `n`, is larger
3055 than **max_input_shape** is set to the minimal value `n_min` for which
3056 this is still true.
3057 Use this for small input samples or large values of **ns**,
3058 or simply whenever you know the full input shape.
3060 Returns:
3061 Resolved axis sizes for model inputs and outputs.
3062 """
3063 max_input_shape = max_input_shape or {}
3064 if batch_size is None:
3065 for (_t_id, a_id), s in max_input_shape.items():
3066 if a_id == BATCH_AXIS_ID:
3067 batch_size = s
3068 break
3069 else:
3070 batch_size = 1
3072 all_axes = {
3073 t.id: {a.id: a for a in t.axes} for t in chain(self.inputs, self.outputs)
3074 }
3076 inputs: Dict[Tuple[TensorId, AxisId], int] = {}
3077 outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]] = {}
3079 def get_axis_size(a: Union[InputAxis, OutputAxis]):
3080 if isinstance(a, BatchAxis):
3081 if (t_descr.id, a.id) in ns:
3082 logger.warning(
3083 "Ignoring unexpected size increment factor (n) for batch axis"
3084 + " of tensor '{}'.",
3085 t_descr.id,
3086 )
3087 return batch_size
3088 elif isinstance(a.size, int):
3089 if (t_descr.id, a.id) in ns:
3090 logger.warning(
3091 "Ignoring unexpected size increment factor (n) for fixed size"
3092 + " axis '{}' of tensor '{}'.",
3093 a.id,
3094 t_descr.id,
3095 )
3096 return a.size
3097 elif isinstance(a.size, ParameterizedSize):
3098 if (t_descr.id, a.id) not in ns:
3099 raise ValueError(
3100 "Size increment factor (n) missing for parametrized axis"
3101 + f" '{a.id}' of tensor '{t_descr.id}'."
3102 )
3103 n = ns[(t_descr.id, a.id)]
3104 s_max = max_input_shape.get((t_descr.id, a.id))
3105 if s_max is not None:
3106 n = min(n, a.size.get_n(s_max))
3108 return a.size.get_size(n)
3110 elif isinstance(a.size, SizeReference):
3111 if (t_descr.id, a.id) in ns:
3112 logger.warning(
3113 "Ignoring unexpected size increment factor (n) for axis '{}'"
3114 + " of tensor '{}' with size reference.",
3115 a.id,
3116 t_descr.id,
3117 )
3118 assert not isinstance(a, BatchAxis)
3119 ref_axis = all_axes[a.size.tensor_id][a.size.axis_id]
3120 assert not isinstance(ref_axis, BatchAxis)
3121 ref_key = (a.size.tensor_id, a.size.axis_id)
3122 ref_size = inputs.get(ref_key, outputs.get(ref_key))
3123 assert ref_size is not None, ref_key
3124 assert not isinstance(ref_size, _DataDepSize), ref_key
3125 return a.size.get_size(
3126 axis=a,
3127 ref_axis=ref_axis,
3128 ref_size=ref_size,
3129 )
3130 elif isinstance(a.size, DataDependentSize):
3131 if (t_descr.id, a.id) in ns:
3132 logger.warning(
3133 "Ignoring unexpected increment factor (n) for data dependent"
3134 + " size axis '{}' of tensor '{}'.",
3135 a.id,
3136 t_descr.id,
3137 )
3138 return _DataDepSize(a.size.min, a.size.max)
3139 else:
3140 assert_never(a.size)
3142 # first resolve all input sizes except those given as `SizeReference`
3143 for t_descr in self.inputs:
3144 for a in t_descr.axes:
3145 if not isinstance(a.size, SizeReference):
3146 s = get_axis_size(a)
3147 assert not isinstance(s, _DataDepSize)
3148 inputs[t_descr.id, a.id] = s
3150 # resolve all other input axis sizes
3151 for t_descr in self.inputs:
3152 for a in t_descr.axes:
3153 if isinstance(a.size, SizeReference):
3154 s = get_axis_size(a)
3155 assert not isinstance(s, _DataDepSize)
3156 inputs[t_descr.id, a.id] = s
3158 # resolve all output axis sizes
3159 for t_descr in self.outputs:
3160 for a in t_descr.axes:
3161 assert not isinstance(a.size, ParameterizedSize)
3162 s = get_axis_size(a)
3163 outputs[t_descr.id, a.id] = s
3165 return _AxisSizes(inputs=inputs, outputs=outputs)
3167 @model_validator(mode="before")
3168 @classmethod
3169 def _convert(cls, data: Dict[str, Any]) -> Dict[str, Any]:
3170 cls.convert_from_old_format_wo_validation(data)
3171 return data
3173 @classmethod
3174 def convert_from_old_format_wo_validation(cls, data: Dict[str, Any]) -> None:
3175 """Convert metadata following an older format version to this classes' format
3176 without validating the result.
3177 """
3178 if (
3179 data.get("type") == "model"
3180 and isinstance(fv := data.get("format_version"), str)
3181 and fv.count(".") == 2
3182 ):
3183 fv_parts = fv.split(".")
3184 if any(not p.isdigit() for p in fv_parts):
3185 return
3187 fv_tuple = tuple(map(int, fv_parts))
3189 assert cls.implemented_format_version_tuple[0:2] == (0, 5)
3190 if fv_tuple[:2] in ((0, 3), (0, 4)):
3191 m04 = _ModelDescr_v0_4.load(data)
3192 if isinstance(m04, InvalidDescr):
3193 try:
3194 updated = _model_conv.convert_as_dict(
3195 m04 # pyright: ignore[reportArgumentType]
3196 )
3197 except Exception as e:
3198 logger.error(
3199 "Failed to convert from invalid model 0.4 description."
3200 + f"\nerror: {e}"
3201 + "\nProceeding with model 0.5 validation without conversion."
3202 )
3203 updated = None
3204 else:
3205 updated = _model_conv.convert_as_dict(m04)
3207 if updated is not None:
3208 data.clear()
3209 data.update(updated)
3211 elif fv_tuple[:2] == (0, 5):
3212 # bump patch version
3213 data["format_version"] = cls.implemented_format_version
3216class _ModelConv(Converter[_ModelDescr_v0_4, ModelDescr]):
3217 def _convert(
3218 self, src: _ModelDescr_v0_4, tgt: "type[ModelDescr] | type[dict[str, Any]]"
3219 ) -> "ModelDescr | dict[str, Any]":
3220 name = "".join(
3221 c if c in string.ascii_letters + string.digits + "_+- ()" else " "
3222 for c in src.name
3223 )
3225 def conv_authors(auths: Optional[Sequence[_Author_v0_4]]):
3226 conv = (
3227 _author_conv.convert if TYPE_CHECKING else _author_conv.convert_as_dict
3228 )
3229 return None if auths is None else [conv(a) for a in auths]
3231 if TYPE_CHECKING:
3232 arch_file_conv = _arch_file_conv.convert
3233 arch_lib_conv = _arch_lib_conv.convert
3234 else:
3235 arch_file_conv = _arch_file_conv.convert_as_dict
3236 arch_lib_conv = _arch_lib_conv.convert_as_dict
3238 input_size_refs = {
3239 ipt.name: {
3240 a: s
3241 for a, s in zip(
3242 ipt.axes,
3243 (
3244 ipt.shape.min
3245 if isinstance(ipt.shape, _ParameterizedInputShape_v0_4)
3246 else ipt.shape
3247 ),
3248 )
3249 }
3250 for ipt in src.inputs
3251 if ipt.shape
3252 }
3253 output_size_refs = {
3254 **{
3255 out.name: {a: s for a, s in zip(out.axes, out.shape)}
3256 for out in src.outputs
3257 if not isinstance(out.shape, _ImplicitOutputShape_v0_4)
3258 },
3259 **input_size_refs,
3260 }
3262 return tgt(
3263 attachments=(
3264 []
3265 if src.attachments is None
3266 else [FileDescr(source=f) for f in src.attachments.files]
3267 ),
3268 authors=[
3269 _author_conv.convert_as_dict(a) for a in src.authors
3270 ], # pyright: ignore[reportArgumentType]
3271 cite=[
3272 {"text": c.text, "doi": c.doi, "url": c.url} for c in src.cite
3273 ], # pyright: ignore[reportArgumentType]
3274 config=src.config, # pyright: ignore[reportArgumentType]
3275 covers=src.covers,
3276 description=src.description,
3277 documentation=src.documentation,
3278 format_version="0.5.4",
3279 git_repo=src.git_repo, # pyright: ignore[reportArgumentType]
3280 icon=src.icon,
3281 id=None if src.id is None else ModelId(src.id),
3282 id_emoji=src.id_emoji,
3283 license=src.license, # type: ignore
3284 links=src.links,
3285 maintainers=[
3286 _maintainer_conv.convert_as_dict(m) for m in src.maintainers
3287 ], # pyright: ignore[reportArgumentType]
3288 name=name,
3289 tags=src.tags,
3290 type=src.type,
3291 uploader=src.uploader,
3292 version=src.version,
3293 inputs=[ # pyright: ignore[reportArgumentType]
3294 _input_tensor_conv.convert_as_dict(ipt, tt, st, input_size_refs)
3295 for ipt, tt, st, in zip(
3296 src.inputs,
3297 src.test_inputs,
3298 src.sample_inputs or [None] * len(src.test_inputs),
3299 )
3300 ],
3301 outputs=[ # pyright: ignore[reportArgumentType]
3302 _output_tensor_conv.convert_as_dict(out, tt, st, output_size_refs)
3303 for out, tt, st, in zip(
3304 src.outputs,
3305 src.test_outputs,
3306 src.sample_outputs or [None] * len(src.test_outputs),
3307 )
3308 ],
3309 parent=(
3310 None
3311 if src.parent is None
3312 else LinkedModel(
3313 id=ModelId(
3314 str(src.parent.id)
3315 + (
3316 ""
3317 if src.parent.version_number is None
3318 else f"/{src.parent.version_number}"
3319 )
3320 )
3321 )
3322 ),
3323 training_data=(
3324 None
3325 if src.training_data is None
3326 else (
3327 LinkedDataset(
3328 id=DatasetId(
3329 str(src.training_data.id)
3330 + (
3331 ""
3332 if src.training_data.version_number is None
3333 else f"/{src.training_data.version_number}"
3334 )
3335 )
3336 )
3337 if isinstance(src.training_data, LinkedDataset02)
3338 else src.training_data
3339 )
3340 ),
3341 packaged_by=[
3342 _author_conv.convert_as_dict(a) for a in src.packaged_by
3343 ], # pyright: ignore[reportArgumentType]
3344 run_mode=src.run_mode,
3345 timestamp=src.timestamp,
3346 weights=(WeightsDescr if TYPE_CHECKING else dict)(
3347 keras_hdf5=(w := src.weights.keras_hdf5)
3348 and (KerasHdf5WeightsDescr if TYPE_CHECKING else dict)(
3349 authors=conv_authors(w.authors),
3350 source=w.source,
3351 tensorflow_version=w.tensorflow_version or Version("1.15"),
3352 parent=w.parent,
3353 ),
3354 onnx=(w := src.weights.onnx)
3355 and (OnnxWeightsDescr if TYPE_CHECKING else dict)(
3356 source=w.source,
3357 authors=conv_authors(w.authors),
3358 parent=w.parent,
3359 opset_version=w.opset_version or 15,
3360 ),
3361 pytorch_state_dict=(w := src.weights.pytorch_state_dict)
3362 and (PytorchStateDictWeightsDescr if TYPE_CHECKING else dict)(
3363 source=w.source,
3364 authors=conv_authors(w.authors),
3365 parent=w.parent,
3366 architecture=(
3367 arch_file_conv(
3368 w.architecture,
3369 w.architecture_sha256,
3370 w.kwargs,
3371 )
3372 if isinstance(w.architecture, _CallableFromFile_v0_4)
3373 else arch_lib_conv(w.architecture, w.kwargs)
3374 ),
3375 pytorch_version=w.pytorch_version or Version("1.10"),
3376 dependencies=(
3377 None
3378 if w.dependencies is None
3379 else (FileDescr if TYPE_CHECKING else dict)(
3380 source=cast(
3381 FileSource,
3382 str(deps := w.dependencies)[
3383 (
3384 len("conda:")
3385 if str(deps).startswith("conda:")
3386 else 0
3387 ) :
3388 ],
3389 )
3390 )
3391 ),
3392 ),
3393 tensorflow_js=(w := src.weights.tensorflow_js)
3394 and (TensorflowJsWeightsDescr if TYPE_CHECKING else dict)(
3395 source=w.source,
3396 authors=conv_authors(w.authors),
3397 parent=w.parent,
3398 tensorflow_version=w.tensorflow_version or Version("1.15"),
3399 ),
3400 tensorflow_saved_model_bundle=(
3401 w := src.weights.tensorflow_saved_model_bundle
3402 )
3403 and (TensorflowSavedModelBundleWeightsDescr if TYPE_CHECKING else dict)(
3404 authors=conv_authors(w.authors),
3405 parent=w.parent,
3406 source=w.source,
3407 tensorflow_version=w.tensorflow_version or Version("1.15"),
3408 dependencies=(
3409 None
3410 if w.dependencies is None
3411 else (FileDescr if TYPE_CHECKING else dict)(
3412 source=cast(
3413 FileSource,
3414 (
3415 str(w.dependencies)[len("conda:") :]
3416 if str(w.dependencies).startswith("conda:")
3417 else str(w.dependencies)
3418 ),
3419 )
3420 )
3421 ),
3422 ),
3423 torchscript=(w := src.weights.torchscript)
3424 and (TorchscriptWeightsDescr if TYPE_CHECKING else dict)(
3425 source=w.source,
3426 authors=conv_authors(w.authors),
3427 parent=w.parent,
3428 pytorch_version=w.pytorch_version or Version("1.10"),
3429 ),
3430 ),
3431 )
3434_model_conv = _ModelConv(_ModelDescr_v0_4, ModelDescr)
3437 # TODO: create better cover images for 3d data and non-image outputs
3438def generate_covers(
3439 inputs: Sequence[Tuple[InputTensorDescr, NDArray[Any]]],
3440 outputs: Sequence[Tuple[OutputTensorDescr, NDArray[Any]]],
3441) -> List[Path]:
3442 def squeeze(
3443 data: NDArray[Any], axes: Sequence[AnyAxis]
3444 ) -> Tuple[NDArray[Any], List[AnyAxis]]:
3445 """apply numpy.ndarray.squeeze while keeping track of the axis descriptions remaining"""
3446 if data.ndim != len(axes):
3447 raise ValueError(
3448 f"tensor shape {data.shape} does not match described axes"
3449 + f" {[a.id for a in axes]}"
3450 )
3452 axes = [deepcopy(a) for a, s in zip(axes, data.shape) if s != 1]
3453 return data.squeeze(), axes
3455 def normalize(
3456 data: NDArray[Any], axis: Optional[Tuple[int, ...]], eps: float = 1e-7
3457 ) -> NDArray[np.float32]:
3458 data = data.astype("float32")
3459 data -= data.min(axis=axis, keepdims=True)
3460 data /= data.max(axis=axis, keepdims=True) + eps
3461 return data
3463 def to_2d_image(data: NDArray[Any], axes: Sequence[AnyAxis]):
3464 original_shape = data.shape
3465 data, axes = squeeze(data, axes)
3467 # take a slice from any batch or index axis if needed
3468 # and convert the first channel axis and take a slice from any additional channel axes
3469 slices: Tuple[slice, ...] = ()
3470 ndim = data.ndim
3471 ndim_need = 3 if any(isinstance(a, ChannelAxis) for a in axes) else 2
3472 has_c_axis = False
3473 for i, a in enumerate(axes):
3474 s = data.shape[i]
3475 assert s > 1
3476 if (
3477 isinstance(a, (BatchAxis, IndexInputAxis, IndexOutputAxis))
3478 and ndim > ndim_need
3479 ):
3480 data = data[slices + (slice(s // 2 - 1, s // 2),)]
3481 ndim -= 1
3482 elif isinstance(a, ChannelAxis):
3483 if has_c_axis:
3484 # second channel axis
3485 data = data[slices + (slice(0, 1),)]
3486 ndim -= 1
3487 else:
3488 has_c_axis = True
3489 if s == 2:
3490 # visualize two channels with cyan and magenta
3491 data = np.concatenate(
3492 [
3493 data[slices + (slice(1, 2),)],
3494 data[slices + (slice(0, 1),)],
3495 (
3496 data[slices + (slice(0, 1),)]
3497 + data[slices + (slice(1, 2),)]
3498 )
3499 / 2, # TODO: take maximum instead?
3500 ],
3501 axis=i,
3502 )
3503 elif data.shape[i] == 3:
3504 pass # visualize 3 channels as RGB
3505 else:
3506 # visualize first 3 channels as RGB
3507 data = data[slices + (slice(3),)]
3509 assert data.shape[i] == 3
3511 slices += (slice(None),)
3513 data, axes = squeeze(data, axes)
3514 assert len(axes) == ndim
3515 # take slice from z axis if needed
3516 slices = ()
3517 if ndim > ndim_need:
3518 for i, a in enumerate(axes):
3519 s = data.shape[i]
3520 if a.id == AxisId("z"):
3521 data = data[slices + (slice(s // 2 - 1, s // 2),)]
3522 data, axes = squeeze(data, axes)
3523 ndim -= 1
3524 break
3526 slices += (slice(None),)
3528 # take slice from any space or time axis
3529 slices = ()
3531 for i, a in enumerate(axes):
3532 if ndim <= ndim_need:
3533 break
3535 s = data.shape[i]
3536 assert s > 1
3537 if isinstance(
3538 a, (SpaceInputAxis, SpaceOutputAxis, TimeInputAxis, TimeOutputAxis)
3539 ):
3540 data = data[slices + (slice(s // 2 - 1, s // 2),)]
3541 ndim -= 1
3543 slices += (slice(None),)
3545 del slices
3546 data, axes = squeeze(data, axes)
3547 assert len(axes) == ndim
3549 if (has_c_axis and ndim != 3) or (not has_c_axis and ndim != 2):
3550 raise ValueError(
3551 f"Failed to construct cover image from shape {original_shape}"
3552 )
3554 if not has_c_axis:
3555 assert ndim == 2
3556 data = np.repeat(data[:, :, None], 3, axis=2)
3557 axes.append(ChannelAxis(channel_names=list(map(Identifier, "RGB"))))
3558 ndim += 1
3560 assert ndim == 3
3562 # transpose axis order such that longest axis comes first...
3563 axis_order: List[int] = list(np.argsort(list(data.shape)))
3564 axis_order.reverse()
3565 # ... and channel axis is last
3566 c = [i for i in range(3) if isinstance(axes[i], ChannelAxis)][0]
3567 axis_order.append(axis_order.pop(c))
3568 axes = [axes[ao] for ao in axis_order]
3569 data = data.transpose(axis_order)
3571 # h, w = data.shape[:2]
3572 # if h / w in (1.0 or 2.0):
3573 # pass
3574 # elif h / w < 2:
3575 # TODO: enforce 2:1 or 1:1 aspect ratio for generated cover images
3577 norm_along = (
3578 tuple(i for i, a in enumerate(axes) if a.type in ("space", "time")) or None
3579 )
3580 # normalize the data and map to 8 bit
3581 data = normalize(data, norm_along)
3582 data = (data * 255).astype("uint8")
3584 return data
3586 def create_diagonal_split_image(im0: NDArray[Any], im1: NDArray[Any]):
3587 assert im0.dtype == im1.dtype == np.uint8
3588 assert im0.shape == im1.shape
3589 assert im0.ndim == 3
3590 N, M, C = im0.shape
3591 assert C == 3
3592 out = np.ones((N, M, C), dtype="uint8")
3593 for c in range(C):
3594 outc = np.tril(im0[..., c])
3595 mask = outc == 0
3596 outc[mask] = np.triu(im1[..., c])[mask]
3597 out[..., c] = outc
3599 return out
3601 ipt_descr, ipt = inputs[0]
3602 out_descr, out = outputs[0]
3604 ipt_img = to_2d_image(ipt, ipt_descr.axes)
3605 out_img = to_2d_image(out, out_descr.axes)
3607 cover_folder = Path(mkdtemp())
3608 if ipt_img.shape == out_img.shape:
3609 covers = [cover_folder / "cover.png"]
3610 imwrite(covers[0], create_diagonal_split_image(ipt_img, out_img))
3611 else:
3612 covers = [cover_folder / "input.png", cover_folder / "output.png"]
3613 imwrite(covers[0], ipt_img)
3614 imwrite(covers[1], out_img)
3616 return covers