Coverage for bioimageio/spec/model/v0_5.py: 75%
1315 statements
coverage.py v7.9.1, created at 2025-06-27 09:20 +0000
1from __future__ import annotations
3import collections.abc
4import re
5import string
6import warnings
7from abc import ABC
8from copy import deepcopy
9from itertools import chain
10from math import ceil
11from pathlib import Path, PurePosixPath
12from tempfile import mkdtemp
13from typing import (
14 TYPE_CHECKING,
15 Any,
16 Callable,
17 ClassVar,
18 Dict,
19 Generic,
20 List,
21 Literal,
22 Mapping,
23 NamedTuple,
24 Optional,
25 Sequence,
26 Set,
27 Tuple,
28 Type,
29 TypeVar,
30 Union,
31 cast,
32)
34import numpy as np
35from annotated_types import Ge, Gt, Interval, MaxLen, MinLen, Predicate
36from imageio.v3 import imread, imwrite # pyright: ignore[reportUnknownVariableType]
37from loguru import logger
38from numpy.typing import NDArray
39from pydantic import (
40 AfterValidator,
41 Discriminator,
42 Field,
43 RootModel,
44 SerializationInfo,
45 SerializerFunctionWrapHandler,
46 Tag,
47 ValidationInfo,
48 WrapSerializer,
49 field_validator,
50 model_serializer,
51 model_validator,
52)
53from typing_extensions import Annotated, Self, assert_never, get_args
55from .._internal.common_nodes import (
56 InvalidDescr,
57 Node,
58 NodeWithExplicitlySetFields,
59)
60from .._internal.constants import DTYPE_LIMITS
61from .._internal.field_warning import issue_warning, warn
62from .._internal.io import BioimageioYamlContent as BioimageioYamlContent
63from .._internal.io import FileDescr as FileDescr
64from .._internal.io import (
65 FileSource,
66 WithSuffix,
67 YamlValue,
68 get_reader,
69 wo_special_file_name,
70)
71from .._internal.io_basics import Sha256 as Sha256
72from .._internal.io_packaging import (
73 FileDescr_,
74 FileSource_,
75 package_file_descr_serializer,
76)
77from .._internal.io_utils import load_array
78from .._internal.node_converter import Converter
79from .._internal.types import (
80 AbsoluteTolerance,
81 LowerCaseIdentifier,
82 LowerCaseIdentifierAnno,
83 MismatchedElementsPerMillion,
84 RelativeTolerance,
85)
86from .._internal.types import Datetime as Datetime
87from .._internal.types import Identifier as Identifier
88from .._internal.types import NotEmpty as NotEmpty
89from .._internal.types import SiUnit as SiUnit
90from .._internal.url import HttpUrl as HttpUrl
91from .._internal.validation_context import get_validation_context
92from .._internal.validator_annotations import RestrictCharacters
93from .._internal.version_type import Version as Version
94from .._internal.warning_levels import INFO
95from ..dataset.v0_2 import DatasetDescr as DatasetDescr02
96from ..dataset.v0_2 import LinkedDataset as LinkedDataset02
97from ..dataset.v0_3 import DatasetDescr as DatasetDescr
98from ..dataset.v0_3 import DatasetId as DatasetId
99from ..dataset.v0_3 import LinkedDataset as LinkedDataset
100from ..dataset.v0_3 import Uploader as Uploader
101from ..generic.v0_3 import (
102 VALID_COVER_IMAGE_EXTENSIONS as VALID_COVER_IMAGE_EXTENSIONS,
103)
104from ..generic.v0_3 import Author as Author
105from ..generic.v0_3 import BadgeDescr as BadgeDescr
106from ..generic.v0_3 import CiteEntry as CiteEntry
107from ..generic.v0_3 import DeprecatedLicenseId as DeprecatedLicenseId
108from ..generic.v0_3 import Doi as Doi
109from ..generic.v0_3 import (
110 FileSource_documentation,
111 GenericModelDescrBase,
112 LinkedResourceBase,
113 _author_conv, # pyright: ignore[reportPrivateUsage]
114 _maintainer_conv, # pyright: ignore[reportPrivateUsage]
115)
116from ..generic.v0_3 import LicenseId as LicenseId
117from ..generic.v0_3 import LinkedResource as LinkedResource
118from ..generic.v0_3 import Maintainer as Maintainer
119from ..generic.v0_3 import OrcidId as OrcidId
120from ..generic.v0_3 import RelativeFilePath as RelativeFilePath
121from ..generic.v0_3 import ResourceId as ResourceId
122from .v0_4 import Author as _Author_v0_4
123from .v0_4 import BinarizeDescr as _BinarizeDescr_v0_4
124from .v0_4 import CallableFromDepencency as CallableFromDepencency
125from .v0_4 import CallableFromDepencency as _CallableFromDepencency_v0_4
126from .v0_4 import CallableFromFile as _CallableFromFile_v0_4
127from .v0_4 import ClipDescr as _ClipDescr_v0_4
128from .v0_4 import ClipKwargs as ClipKwargs
129from .v0_4 import ImplicitOutputShape as _ImplicitOutputShape_v0_4
130from .v0_4 import InputTensorDescr as _InputTensorDescr_v0_4
131from .v0_4 import KnownRunMode as KnownRunMode
132from .v0_4 import ModelDescr as _ModelDescr_v0_4
133from .v0_4 import OutputTensorDescr as _OutputTensorDescr_v0_4
134from .v0_4 import ParameterizedInputShape as _ParameterizedInputShape_v0_4
135from .v0_4 import PostprocessingDescr as _PostprocessingDescr_v0_4
136from .v0_4 import PreprocessingDescr as _PreprocessingDescr_v0_4
137from .v0_4 import ProcessingKwargs as ProcessingKwargs
138from .v0_4 import RunMode as RunMode
139from .v0_4 import ScaleLinearDescr as _ScaleLinearDescr_v0_4
140from .v0_4 import ScaleMeanVarianceDescr as _ScaleMeanVarianceDescr_v0_4
141from .v0_4 import ScaleRangeDescr as _ScaleRangeDescr_v0_4
142from .v0_4 import SigmoidDescr as _SigmoidDescr_v0_4
143from .v0_4 import TensorName as _TensorName_v0_4
144from .v0_4 import WeightsFormat as WeightsFormat
145from .v0_4 import ZeroMeanUnitVarianceDescr as _ZeroMeanUnitVarianceDescr_v0_4
146from .v0_4 import package_weights
148SpaceUnit = Literal[
149 "attometer",
150 "angstrom",
151 "centimeter",
152 "decimeter",
153 "exameter",
154 "femtometer",
155 "foot",
156 "gigameter",
157 "hectometer",
158 "inch",
159 "kilometer",
160 "megameter",
161 "meter",
162 "micrometer",
163 "mile",
164 "millimeter",
165 "nanometer",
166 "parsec",
167 "petameter",
168 "picometer",
169 "terameter",
170 "yard",
171 "yoctometer",
172 "yottameter",
173 "zeptometer",
174 "zettameter",
175]
176"""Space unit compatible to the [OME-Zarr axes specification 0.5](https://ngff.openmicroscopy.org/0.5/#axes-md)"""
178TimeUnit = Literal[
179 "attosecond",
180 "centisecond",
181 "day",
182 "decisecond",
183 "exasecond",
184 "femtosecond",
185 "gigasecond",
186 "hectosecond",
187 "hour",
188 "kilosecond",
189 "megasecond",
190 "microsecond",
191 "millisecond",
192 "minute",
193 "nanosecond",
194 "petasecond",
195 "picosecond",
196 "second",
197 "terasecond",
198 "yoctosecond",
199 "yottasecond",
200 "zeptosecond",
201 "zettasecond",
202]
203"""Time unit compatible to the [OME-Zarr axes specification 0.5](https://ngff.openmicroscopy.org/0.5/#axes-md)"""
205AxisType = Literal["batch", "channel", "index", "time", "space"]
207_AXIS_TYPE_MAP: Mapping[str, AxisType] = {
208 "b": "batch",
209 "t": "time",
210 "i": "index",
211 "c": "channel",
212 "x": "space",
213 "y": "space",
214 "z": "space",
215}
217_AXIS_ID_MAP = {
218 "b": "batch",
219 "t": "time",
220 "i": "index",
221 "c": "channel",
222}
225class TensorId(LowerCaseIdentifier):
226 root_model: ClassVar[Type[RootModel[Any]]] = RootModel[
227 Annotated[LowerCaseIdentifierAnno, MaxLen(32)]
228 ]
231def _normalize_axis_id(a: str):
232 a = str(a)
233 normalized = _AXIS_ID_MAP.get(a, a)
234 if a != normalized:
235 logger.opt(depth=3).warning(
236 "Normalized axis id from '{}' to '{}'.", a, normalized
237 )
238 return normalized
241class AxisId(LowerCaseIdentifier):
242 root_model: ClassVar[Type[RootModel[Any]]] = RootModel[
243 Annotated[
244 LowerCaseIdentifierAnno,
245 MaxLen(16),
246 AfterValidator(_normalize_axis_id),
247 ]
248 ]
251def _is_batch(a: str) -> bool:
252 return str(a) == "batch"
255def _is_not_batch(a: str) -> bool:
256 return not _is_batch(a)
259NonBatchAxisId = Annotated[AxisId, Predicate(_is_not_batch)]
261PostprocessingId = Literal[
262 "binarize",
263 "clip",
264 "ensure_dtype",
265 "fixed_zero_mean_unit_variance",
266 "scale_linear",
267 "scale_mean_variance",
268 "scale_range",
269 "sigmoid",
270 "zero_mean_unit_variance",
271]
272PreprocessingId = Literal[
273 "binarize",
274 "clip",
275 "ensure_dtype",
276 "scale_linear",
277 "sigmoid",
278 "zero_mean_unit_variance",
279 "scale_range",
280]
283SAME_AS_TYPE = "<same as type>"
286ParameterizedSize_N = int
287"""
288Annotates an integer to calculate a concrete axis size from a `ParameterizedSize`.
289"""
292class ParameterizedSize(Node):
293 """Describes a range of valid tensor axis sizes as `size = min + n*step`.
295 - **min** and **step** are given by the model description.
296 - All blocksize parameters n = 0,1,2,... yield a valid `size`.
297 - A greater blocksize parameter n = 0,1,2,... results in a greater **size**.
298 This allows adjusting the axis size more generically.
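
    Example:
        A minimal illustration for `min=32, step=16` (valid sizes are 32, 48, 64, ...):

        >>> ParameterizedSize(min=32, step=16).get_size(3)
        80
        >>> ParameterizedSize(min=32, step=16).get_n(70)
        3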
299 """
301 N: ClassVar[Type[int]] = ParameterizedSize_N
302 """Positive integer to parameterize this axis"""
304 min: Annotated[int, Gt(0)]
305 step: Annotated[int, Gt(0)]
307 def validate_size(self, size: int) -> int:
308 if size < self.min:
309 raise ValueError(f"size {size} < {self.min}")
310 if (size - self.min) % self.step != 0:
311 raise ValueError(
312 f"axis of size {size} is not parameterized by `min + n*step` ="
313 + f" `{self.min} + n*{self.step}`"
314 )
316 return size
318 def get_size(self, n: ParameterizedSize_N) -> int:
319 return self.min + self.step * n
321 def get_n(self, s: int) -> ParameterizedSize_N:
322 """return smallest n parameterizing a size greater or equal than `s`"""
323 return ceil((s - self.min) / self.step)
326class DataDependentSize(Node):
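    """An axis size that is only known after model inference, bounded by `min` and (optionally) `max`.

    A small sketch (the bounds are illustrative):

    >>> DataDependentSize(min=1, max=512).validate_size(10)
    10
    """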
327 min: Annotated[int, Gt(0)] = 1
328 max: Annotated[Optional[int], Gt(1)] = None
330 @model_validator(mode="after")
331 def _validate_max_gt_min(self):
332 if self.max is not None and self.min >= self.max:
333 raise ValueError(f"expected `min` < `max`, but got {self.min}, {self.max}")
335 return self
337 def validate_size(self, size: int) -> int:
338 if size < self.min:
339 raise ValueError(f"size {size} < {self.min}")
341 if self.max is not None and size > self.max:
342 raise ValueError(f"size {size} > {self.max}")
344 return size
347class SizeReference(Node):
348 """A tensor axis size (extent in pixels/frames) defined in relation to a reference axis.
350 `axis.size = reference.size * reference.scale / axis.scale + offset`
352 Note:
353 1. The axis and the referenced axis need to have the same unit (or no unit).
354 2. Batch axes may not be referenced.
355 3. Fractions are rounded down.
356 4. If the reference axis is `concatenable` the referencing axis is assumed to be
357 `concatenable` as well with the same block order.
359 Example:
360 An anisotropic input image of w*h=100*49 pixels depicts a physical space of 200*196mm².
361 Let's assume that we want to express the image height h in relation to its width w
362 instead of only accepting input images of exactly 100*49 pixels
363 (for example to express a range of valid image shapes by parametrizing w, see `ParameterizedSize`).
365 >>> w = SpaceInputAxis(id=AxisId("w"), size=100, unit="millimeter", scale=2)
366 >>> h = SpaceInputAxis(
367 ... id=AxisId("h"),
368 ... size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("w"), offset=-1),
369 ... unit="millimeter",
370 ... scale=4,
371 ... )
372 >>> print(h.size.get_size(h, w))
373 49
375 ⇒ h = w * w.scale / h.scale + offset = 100 * 2mm / 4mm - 1 = 49
376 """
378 tensor_id: TensorId
379 """tensor id of the reference axis"""
381 axis_id: AxisId
382 """axis id of the reference axis"""
384 offset: int = 0
386 def get_size(
387 self,
388 axis: Union[
389 ChannelAxis,
390 IndexInputAxis,
391 IndexOutputAxis,
392 TimeInputAxis,
393 SpaceInputAxis,
394 TimeOutputAxis,
395 TimeOutputAxisWithHalo,
396 SpaceOutputAxis,
397 SpaceOutputAxisWithHalo,
398 ],
399 ref_axis: Union[
400 ChannelAxis,
401 IndexInputAxis,
402 IndexOutputAxis,
403 TimeInputAxis,
404 SpaceInputAxis,
405 TimeOutputAxis,
406 TimeOutputAxisWithHalo,
407 SpaceOutputAxis,
408 SpaceOutputAxisWithHalo,
409 ],
410 n: ParameterizedSize_N = 0,
411 ref_size: Optional[int] = None,
412 ):
413 """Compute the concrete size for a given axis and its reference axis.
415 Args:
416 axis: The axis this `SizeReference` is the size of.
417 ref_axis: The reference axis to compute the size from.
418 n: If the **ref_axis** is parameterized (of type `ParameterizedSize`)
419 and no fixed **ref_size** is given,
420 **n** is used to compute the size of the parameterized **ref_axis**.
421 ref_size: Overwrite the reference size instead of deriving it from
422 **ref_axis**
423 (**ref_axis.scale** is still used; any given **n** is ignored).
424 """
425 assert (
426 axis.size == self
427 ), "Given `axis.size` is not defined by this `SizeReference`"
429 assert (
430 ref_axis.id == self.axis_id
431 ), f"Expected `ref_axis.id` to be {self.axis_id}, but got {ref_axis.id}."
433 assert axis.unit == ref_axis.unit, (
434 "`SizeReference` requires `axis` and `ref_axis` to have the same `unit`,"
435 f" but {axis.unit}!={ref_axis.unit}"
436 )
437 if ref_size is None:
438 if isinstance(ref_axis.size, (int, float)):
439 ref_size = ref_axis.size
440 elif isinstance(ref_axis.size, ParameterizedSize):
441 ref_size = ref_axis.size.get_size(n)
442 elif isinstance(ref_axis.size, DataDependentSize):
443 raise ValueError(
444 "Reference axis referenced in `SizeReference` may not be a `DataDependentSize`."
445 )
446 elif isinstance(ref_axis.size, SizeReference):
447 raise ValueError(
448 "Reference axis referenced in `SizeReference` may not be sized by a"
449 + " `SizeReference` itself."
450 )
451 else:
452 assert_never(ref_axis.size)
454 return int(ref_size * ref_axis.scale / axis.scale + self.offset)
456 @staticmethod
457 def _get_unit(
458 axis: Union[
459 ChannelAxis,
460 IndexInputAxis,
461 IndexOutputAxis,
462 TimeInputAxis,
463 SpaceInputAxis,
464 TimeOutputAxis,
465 TimeOutputAxisWithHalo,
466 SpaceOutputAxis,
467 SpaceOutputAxisWithHalo,
468 ],
469 ):
470 return axis.unit
473class AxisBase(NodeWithExplicitlySetFields):
474 id: AxisId
475 """An axis id unique across all axes of one tensor."""
477 description: Annotated[str, MaxLen(128)] = ""
480class WithHalo(Node):
481 halo: Annotated[int, Ge(1)]
482 """The halo should be cropped from the output tensor to avoid boundary effects.
483 It is to be cropped from both sides, i.e. `size_after_crop = size - 2 * halo`.
484 To document a halo that is already cropped by the model use `size.offset` instead."""
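    # For example, an output axis of size 128 with halo=16 is cropped to 128 - 2*16 = 96.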
486 size: Annotated[
487 SizeReference,
488 Field(
489 examples=[
490 10,
491 SizeReference(
492 tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
493 ).model_dump(mode="json"),
494 ]
495 ),
496 ]
497 """reference to another axis with an optional offset (see `SizeReference`)"""
500BATCH_AXIS_ID = AxisId("batch")
503class BatchAxis(AxisBase):
504 implemented_type: ClassVar[Literal["batch"]] = "batch"
505 if TYPE_CHECKING:
506 type: Literal["batch"] = "batch"
507 else:
508 type: Literal["batch"]
510 id: Annotated[AxisId, Predicate(_is_batch)] = BATCH_AXIS_ID
511 size: Optional[Literal[1]] = None
512 """The batch size may be fixed to 1,
513 otherwise (the default) it may be chosen arbitrarily depending on available memory"""
515 @property
516 def scale(self):
517 return 1.0
519 @property
520 def concatenable(self):
521 return True
523 @property
524 def unit(self):
525 return None
528class ChannelAxis(AxisBase):
529 implemented_type: ClassVar[Literal["channel"]] = "channel"
530 if TYPE_CHECKING:
531 type: Literal["channel"] = "channel"
532 else:
533 type: Literal["channel"]
535 id: NonBatchAxisId = AxisId("channel")
536 channel_names: NotEmpty[List[Identifier]]
538 @property
539 def size(self) -> int:
540 return len(self.channel_names)
542 @property
543 def concatenable(self):
544 return False
546 @property
547 def scale(self) -> float:
548 return 1.0
550 @property
551 def unit(self):
552 return None
555class IndexAxisBase(AxisBase):
556 implemented_type: ClassVar[Literal["index"]] = "index"
557 if TYPE_CHECKING:
558 type: Literal["index"] = "index"
559 else:
560 type: Literal["index"]
562 id: NonBatchAxisId = AxisId("index")
564 @property
565 def scale(self) -> float:
566 return 1.0
568 @property
569 def unit(self):
570 return None
573class _WithInputAxisSize(Node):
574 size: Annotated[
575 Union[Annotated[int, Gt(0)], ParameterizedSize, SizeReference],
576 Field(
577 examples=[
578 10,
579 ParameterizedSize(min=32, step=16).model_dump(mode="json"),
580 SizeReference(
581 tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
582 ).model_dump(mode="json"),
583 ]
584 ),
585 ]
586 """The size/length of this axis can be specified as
587 - fixed integer
588 - parameterized series of valid sizes (`ParameterizedSize`)
589 - reference to another axis with an optional offset (`SizeReference`)
590 """
593class IndexInputAxis(IndexAxisBase, _WithInputAxisSize):
594 concatenable: bool = False
595 """If a model has a `concatenable` input axis, it can be processed blockwise,
596 splitting a longer sample axis into blocks matching its input tensor description.
597 Output axes are concatenable if they have a `SizeReference` to a concatenable
598 input axis.
599 """
602class IndexOutputAxis(IndexAxisBase):
603 size: Annotated[
604 Union[Annotated[int, Gt(0)], SizeReference, DataDependentSize],
605 Field(
606 examples=[
607 10,
608 SizeReference(
609 tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
610 ).model_dump(mode="json"),
611 ]
612 ),
613 ]
614 """The size/length of this axis can be specified as
615 - fixed integer
616 - reference to another axis with an optional offset (`SizeReference`)
617 - data dependent size using `DataDependentSize` (size is only known after model inference)
618 """
621class TimeAxisBase(AxisBase):
622 implemented_type: ClassVar[Literal["time"]] = "time"
623 if TYPE_CHECKING:
624 type: Literal["time"] = "time"
625 else:
626 type: Literal["time"]
628 id: NonBatchAxisId = AxisId("time")
629 unit: Optional[TimeUnit] = None
630 scale: Annotated[float, Gt(0)] = 1.0
633class TimeInputAxis(TimeAxisBase, _WithInputAxisSize):
634 concatenable: bool = False
635 """If a model has a `concatenable` input axis, it can be processed blockwise,
636 splitting a longer sample axis into blocks matching its input tensor description.
637 Output axes are concatenable if they have a `SizeReference` to a concatenable
638 input axis.
639 """
642class SpaceAxisBase(AxisBase):
643 implemented_type: ClassVar[Literal["space"]] = "space"
644 if TYPE_CHECKING:
645 type: Literal["space"] = "space"
646 else:
647 type: Literal["space"]
649 id: Annotated[NonBatchAxisId, Field(examples=["x", "y", "z"])] = AxisId("x")
650 unit: Optional[SpaceUnit] = None
651 scale: Annotated[float, Gt(0)] = 1.0
654class SpaceInputAxis(SpaceAxisBase, _WithInputAxisSize):
655 concatenable: bool = False
656 """If a model has a `concatenable` input axis, it can be processed blockwise,
657 splitting a longer sample axis into blocks matching its input tensor description.
658 Output axes are concatenable if they have a `SizeReference` to a concatenable
659 input axis.
660 """
663INPUT_AXIS_TYPES = (
664 BatchAxis,
665 ChannelAxis,
666 IndexInputAxis,
667 TimeInputAxis,
668 SpaceInputAxis,
669)
670"""intended for isinstance comparisons in py<3.10"""
672_InputAxisUnion = Union[
673 BatchAxis, ChannelAxis, IndexInputAxis, TimeInputAxis, SpaceInputAxis
674]
675InputAxis = Annotated[_InputAxisUnion, Discriminator("type")]
678class _WithOutputAxisSize(Node):
679 size: Annotated[
680 Union[Annotated[int, Gt(0)], SizeReference],
681 Field(
682 examples=[
683 10,
684 SizeReference(
685 tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
686 ).model_dump(mode="json"),
687 ]
688 ),
689 ]
690 """The size/length of this axis can be specified as
691 - fixed integer
692 - reference to another axis with an optional offset (see `SizeReference`)
693 """
696class TimeOutputAxis(TimeAxisBase, _WithOutputAxisSize):
697 pass
700class TimeOutputAxisWithHalo(TimeAxisBase, WithHalo):
701 pass
704def _get_halo_axis_discriminator_value(v: Any) -> Literal["with_halo", "wo_halo"]:
705 if isinstance(v, dict):
706 return "with_halo" if "halo" in v else "wo_halo"
707 else:
708 return "with_halo" if hasattr(v, "halo") else "wo_halo"
711_TimeOutputAxisUnion = Annotated[
712 Union[
713 Annotated[TimeOutputAxis, Tag("wo_halo")],
714 Annotated[TimeOutputAxisWithHalo, Tag("with_halo")],
715 ],
716 Discriminator(_get_halo_axis_discriminator_value),
717]
720class SpaceOutputAxis(SpaceAxisBase, _WithOutputAxisSize):
721 pass
724class SpaceOutputAxisWithHalo(SpaceAxisBase, WithHalo):
725 pass
728_SpaceOutputAxisUnion = Annotated[
729 Union[
730 Annotated[SpaceOutputAxis, Tag("wo_halo")],
731 Annotated[SpaceOutputAxisWithHalo, Tag("with_halo")],
732 ],
733 Discriminator(_get_halo_axis_discriminator_value),
734]
737_OutputAxisUnion = Union[
738 BatchAxis, ChannelAxis, IndexOutputAxis, _TimeOutputAxisUnion, _SpaceOutputAxisUnion
739]
740OutputAxis = Annotated[_OutputAxisUnion, Discriminator("type")]
742OUTPUT_AXIS_TYPES = (
743 BatchAxis,
744 ChannelAxis,
745 IndexOutputAxis,
746 TimeOutputAxis,
747 TimeOutputAxisWithHalo,
748 SpaceOutputAxis,
749 SpaceOutputAxisWithHalo,
750)
751"""intended for isinstance comparisons in py<3.10"""
754AnyAxis = Union[InputAxis, OutputAxis]
756ANY_AXIS_TYPES = INPUT_AXIS_TYPES + OUTPUT_AXIS_TYPES
757"""intended for isinstance comparisons in py<3.10"""
759TVs = Union[
760 NotEmpty[List[int]],
761 NotEmpty[List[float]],
762 NotEmpty[List[bool]],
763 NotEmpty[List[str]],
764]
767NominalOrOrdinalDType = Literal[
768 "float32",
769 "float64",
770 "uint8",
771 "int8",
772 "uint16",
773 "int16",
774 "uint32",
775 "int32",
776 "uint64",
777 "int64",
778 "bool",
779]
782class NominalOrOrdinalDataDescr(Node):
783 values: TVs
784 """A fixed set of nominal or an ascending sequence of ordinal values.
785 In this case `data.type` is required to be an unsigned integer type, e.g. 'uint8'.
786 String `values` are interpreted as labels for tensor values 0, ..., N.
787 Note: as YAML 1.2 does not natively support a "set" datatype,
788 nominal values should be given as a sequence (aka list/array) as well.
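
    Example:
        An illustrative sketch using string labels (the label names are made up):

        >>> NominalOrOrdinalDataDescr(values=["background", "foreground"], type="uint8").range
        (0, 1)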
789 """
791 type: Annotated[
792 NominalOrOrdinalDType,
793 Field(
794 examples=[
795 "float32",
796 "uint8",
797 "uint16",
798 "int64",
799 "bool",
800 ],
801 ),
802 ] = "uint8"
804 @model_validator(mode="after")
805 def _validate_values_match_type(
806 self,
807 ) -> Self:
808 incompatible: List[Any] = []
809 for v in self.values:
810 if self.type == "bool":
811 if not isinstance(v, bool):
812 incompatible.append(v)
813 elif self.type in DTYPE_LIMITS:
814 if (
815 isinstance(v, (int, float))
816 and (
817 v < DTYPE_LIMITS[self.type].min
818 or v > DTYPE_LIMITS[self.type].max
819 )
820 or (isinstance(v, str) and "uint" not in self.type)
821 or (isinstance(v, float) and "int" in self.type)
822 ):
823 incompatible.append(v)
824 else:
825 incompatible.append(v)
827 if len(incompatible) == 5:
828 incompatible.append("...")
829 break
831 if incompatible:
832 raise ValueError(
833 f"data type '{self.type}' incompatible with values {incompatible}"
834 )
836 return self
838 unit: Optional[Union[Literal["arbitrary unit"], SiUnit]] = None
840 @property
841 def range(self):
842 if isinstance(self.values[0], str):
843 return 0, len(self.values) - 1
844 else:
845 return min(self.values), max(self.values)
848IntervalOrRatioDType = Literal[
849 "float32",
850 "float64",
851 "uint8",
852 "int8",
853 "uint16",
854 "int16",
855 "uint32",
856 "int32",
857 "uint64",
858 "int64",
859]
862class IntervalOrRatioDataDescr(Node):
863 type: Annotated[ # todo: rename to dtype
864 IntervalOrRatioDType,
865 Field(
866 examples=["float32", "float64", "uint8", "uint16"],
867 ),
868 ] = "float32"
869 range: Tuple[Optional[float], Optional[float]] = (
870 None,
871 None,
872 )
873 """Tuple `(minimum, maximum)` specifying the allowed range of the data in this tensor.
874 `None` corresponds to min/max of what can be expressed by **type**."""
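    # For example, `range=(0.0, None)` declares non-negative data with no explicit upper bound.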
875 unit: Union[Literal["arbitrary unit"], SiUnit] = "arbitrary unit"
876 scale: float = 1.0
877 """Scale for data on an interval (or ratio) scale."""
878 offset: Optional[float] = None
879 """Offset for data on a ratio scale."""
882TensorDataDescr = Union[NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr]
885class ProcessingDescrBase(NodeWithExplicitlySetFields, ABC):
886 """processing base class"""
889class BinarizeKwargs(ProcessingKwargs):
890 """key word arguments for `BinarizeDescr`"""
892 threshold: float
893 """The fixed threshold"""
896class BinarizeAlongAxisKwargs(ProcessingKwargs):
897 """key word arguments for `BinarizeDescr`"""
899 threshold: NotEmpty[List[float]]
900 """The fixed threshold values along `axis`"""
902 axis: Annotated[NonBatchAxisId, Field(examples=["channel"])]
903 """The `threshold` axis"""
906class BinarizeDescr(ProcessingDescrBase):
907 """Binarize the tensor with a fixed threshold.
909 Values above `BinarizeKwargs.threshold`/`BinarizeAlongAxisKwargs.threshold`
910 will be set to one, values below the threshold to zero.
912 Examples:
913 - in YAML
914 ```yaml
915 postprocessing:
916 - id: binarize
917 kwargs:
918 axis: 'channel'
919 threshold: [0.25, 0.5, 0.75]
920 ```
921 - in Python:
922 >>> postprocessing = [BinarizeDescr(
923 ... kwargs=BinarizeAlongAxisKwargs(
924 ... axis=AxisId('channel'),
925 ... threshold=[0.25, 0.5, 0.75],
926 ... )
927 ... )]
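
        A scalar threshold for the whole tensor uses `BinarizeKwargs` instead
        (the threshold value is illustrative):

        >>> postprocessing = [BinarizeDescr(kwargs=BinarizeKwargs(threshold=0.5))]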
928 """
930 implemented_id: ClassVar[Literal["binarize"]] = "binarize"
931 if TYPE_CHECKING:
932 id: Literal["binarize"] = "binarize"
933 else:
934 id: Literal["binarize"]
935 kwargs: Union[BinarizeKwargs, BinarizeAlongAxisKwargs]
938class ClipDescr(ProcessingDescrBase):
939 """Set tensor values below min to min and above max to max.
941 See `ScaleRangeDescr` for examples.
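
    Example:
        A standalone sketch clipping values to the [0, 1] interval:

        >>> postprocessing = [ClipDescr(kwargs=ClipKwargs(min=0.0, max=1.0))]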
942 """
944 implemented_id: ClassVar[Literal["clip"]] = "clip"
945 if TYPE_CHECKING:
946 id: Literal["clip"] = "clip"
947 else:
948 id: Literal["clip"]
950 kwargs: ClipKwargs
953class EnsureDtypeKwargs(ProcessingKwargs):
954 """key word arguments for `EnsureDtypeDescr`"""
956 dtype: Literal[
957 "float32",
958 "float64",
959 "uint8",
960 "int8",
961 "uint16",
962 "int16",
963 "uint32",
964 "int32",
965 "uint64",
966 "int64",
967 "bool",
968 ]
971class EnsureDtypeDescr(ProcessingDescrBase):
972 """Cast the tensor data type to `EnsureDtypeKwargs.dtype` (if not matching).
974 This can for example be used to ensure the inner neural network model gets a
975 different input tensor data type than the fully described bioimage.io model does.
977 Examples:
978 The described bioimage.io model (incl. preprocessing) accepts any
979 float32-compatible tensor, normalizes it with percentiles and clipping and then
980 casts it to uint8, which is what the neural network in this example expects.
981 - in YAML
982 ```yaml
983 inputs:
984 - data:
985 type: float32 # described bioimage.io model is compatible with any float32 input tensor
986 preprocessing:
987 - id: scale_range
988 kwargs:
989 axes: ['y', 'x']
990 max_percentile: 99.8
991 min_percentile: 5.0
992 - id: clip
993 kwargs:
994 min: 0.0
995 max: 1.0
996 - id: ensure_dtype # the neural network of the model requires uint8
997 kwargs:
998 dtype: uint8
999 ```
1000 - in Python:
1001 >>> preprocessing = [
1002 ... ScaleRangeDescr(
1003 ... kwargs=ScaleRangeKwargs(
1004 ... axes= (AxisId('y'), AxisId('x')),
1005 ... max_percentile= 99.8,
1006 ... min_percentile= 5.0,
1007 ... )
1008 ... ),
1009 ... ClipDescr(kwargs=ClipKwargs(min=0.0, max=1.0)),
1010 ... EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="uint8")),
1011 ... ]
1012 """
1014 implemented_id: ClassVar[Literal["ensure_dtype"]] = "ensure_dtype"
1015 if TYPE_CHECKING:
1016 id: Literal["ensure_dtype"] = "ensure_dtype"
1017 else:
1018 id: Literal["ensure_dtype"]
1020 kwargs: EnsureDtypeKwargs
1023class ScaleLinearKwargs(ProcessingKwargs):
1024 """Key word arguments for `ScaleLinearDescr`"""
1026 gain: float = 1.0
1027 """multiplicative factor"""
1029 offset: float = 0.0
1030 """additive term"""
1032 @model_validator(mode="after")
1033 def _validate(self) -> Self:
1034 if self.gain == 1.0 and self.offset == 0.0:
1035 raise ValueError(
1036 "Redundant linear scaling not allowd. Set `gain` != 1.0 and/or `offset`"
1037 + " != 0.0."
1038 )
1040 return self
1043class ScaleLinearAlongAxisKwargs(ProcessingKwargs):
1044 """Key word arguments for `ScaleLinearDescr`"""
1046 axis: Annotated[NonBatchAxisId, Field(examples=["channel"])]
1047 """The axis of gain and offset values."""
1049 gain: Union[float, NotEmpty[List[float]]] = 1.0
1050 """multiplicative factor"""
1052 offset: Union[float, NotEmpty[List[float]]] = 0.0
1053 """additive term"""
1055 @model_validator(mode="after")
1056 def _validate(self) -> Self:
1058 if isinstance(self.gain, list):
1059 if isinstance(self.offset, list):
1060 if len(self.gain) != len(self.offset):
1061 raise ValueError(
1062 f"Size of `gain` ({len(self.gain)}) and `offset` ({len(self.offset)}) must match."
1063 )
1064 else:
1065 self.offset = [float(self.offset)] * len(self.gain)
1066 elif isinstance(self.offset, list):
1067 self.gain = [float(self.gain)] * len(self.offset)
1068 else:
1069 raise ValueError(
1070 "Do not specify an `axis` for scalar gain and offset values."
1071 )
1073 if all(g == 1.0 for g in self.gain) and all(off == 0.0 for off in self.offset):
1074 raise ValueError(
1075 "Redundant linear scaling not allowd. Set `gain` != 1.0 and/or `offset`"
1076 + " != 0.0."
1077 )
1079 return self
1082class ScaleLinearDescr(ProcessingDescrBase):
1083 """Fixed linear scaling.
1085 Examples:
1086 1. Scale with scalar gain and offset
1087 - in YAML
1088 ```yaml
1089 preprocessing:
1090 - id: scale_linear
1091 kwargs:
1092 gain: 2.0
1093 offset: 3.0
1094 ```
1095 - in Python:
1096 >>> preprocessing = [
1097 ... ScaleLinearDescr(kwargs=ScaleLinearKwargs(gain= 2.0, offset=3.0))
1098 ... ]
1100 2. Independent scaling along an axis
1101 - in YAML
1102 ```yaml
1103 preprocessing:
1104 - id: scale_linear
1105 kwargs:
1106 axis: 'channel'
1107 gain: [1.0, 2.0, 3.0]
1108 ```
1109 - in Python:
1110 >>> preprocessing = [
1111 ... ScaleLinearDescr(
1112 ... kwargs=ScaleLinearAlongAxisKwargs(
1113 ... axis=AxisId("channel"),
1114 ... gain=[1.0, 2.0, 3.0],
1115 ... )
1116 ... )
1117 ... ]
1119 """
1121 implemented_id: ClassVar[Literal["scale_linear"]] = "scale_linear"
1122 if TYPE_CHECKING:
1123 id: Literal["scale_linear"] = "scale_linear"
1124 else:
1125 id: Literal["scale_linear"]
1126 kwargs: Union[ScaleLinearKwargs, ScaleLinearAlongAxisKwargs]
1129class SigmoidDescr(ProcessingDescrBase):
1130 """The logistic sigmoid funciton, a.k.a. expit function.
1132 Examples:
1133 - in YAML
1134 ```yaml
1135 postprocessing:
1136 - id: sigmoid
1137 ```
1138 - in Python:
1139 >>> postprocessing = [SigmoidDescr()]
1140 """
1142 implemented_id: ClassVar[Literal["sigmoid"]] = "sigmoid"
1143 if TYPE_CHECKING:
1144 id: Literal["sigmoid"] = "sigmoid"
1145 else:
1146 id: Literal["sigmoid"]
1148 @property
1149 def kwargs(self) -> ProcessingKwargs:
1150 """empty kwargs"""
1151 return ProcessingKwargs()
1154class FixedZeroMeanUnitVarianceKwargs(ProcessingKwargs):
1155 """key word arguments for `FixedZeroMeanUnitVarianceDescr`"""
1157 mean: float
1158 """The mean value to normalize with."""
1160 std: Annotated[float, Ge(1e-6)]
1161 """The standard deviation value to normalize with."""
1164class FixedZeroMeanUnitVarianceAlongAxisKwargs(ProcessingKwargs):
1165 """key word arguments for `FixedZeroMeanUnitVarianceDescr`"""
1167 mean: NotEmpty[List[float]]
1168 """The mean value(s) to normalize with."""
1170 std: NotEmpty[List[Annotated[float, Ge(1e-6)]]]
1171 """The standard deviation value(s) to normalize with.
1172 Size must match `mean` values."""
1174 axis: Annotated[NonBatchAxisId, Field(examples=["channel", "index"])]
1175 """The axis of the mean/std values to normalize each entry along that dimension
1176 separately."""
1178 @model_validator(mode="after")
1179 def _mean_and_std_match(self) -> Self:
1180 if len(self.mean) != len(self.std):
1181 raise ValueError(
1182 f"Size of `mean` ({len(self.mean)}) and `std` ({len(self.std)})"
1183 + " must match."
1184 )
1186 return self
1189class FixedZeroMeanUnitVarianceDescr(ProcessingDescrBase):
1190 """Subtract a given mean and divide by the standard deviation.
1192 Normalize with fixed, precomputed values for
1193 `FixedZeroMeanUnitVarianceKwargs.mean` and `FixedZeroMeanUnitVarianceKwargs.std`
1194 Use `FixedZeroMeanUnitVarianceAlongAxisKwargs` for independent scaling along given
1195 axes.
1197 Examples:
1198 1. scalar value for whole tensor
1199 - in YAML
1200 ```yaml
1201 preprocessing:
1202 - id: fixed_zero_mean_unit_variance
1203 kwargs:
1204 mean: 103.5
1205 std: 13.7
1206 ```
1207 - in Python
1208 >>> preprocessing = [FixedZeroMeanUnitVarianceDescr(
1209 ... kwargs=FixedZeroMeanUnitVarianceKwargs(mean=103.5, std=13.7)
1210 ... )]
1212 2. independently along an axis
1213 - in YAML
1214 ```yaml
1215 preprocessing:
1216 - id: fixed_zero_mean_unit_variance
1217 kwargs:
1218 axis: channel
1219 mean: [101.5, 102.5, 103.5]
1220 std: [11.7, 12.7, 13.7]
1221 ```
1222 - in Python
1223 >>> preprocessing = [FixedZeroMeanUnitVarianceDescr(
1224 ... kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs(
1225 ... axis=AxisId("channel"),
1226 ... mean=[101.5, 102.5, 103.5],
1227 ... std=[11.7, 12.7, 13.7],
1228 ... )
1229 ... )]
1230 """
1232 implemented_id: ClassVar[Literal["fixed_zero_mean_unit_variance"]] = (
1233 "fixed_zero_mean_unit_variance"
1234 )
1235 if TYPE_CHECKING:
1236 id: Literal["fixed_zero_mean_unit_variance"] = "fixed_zero_mean_unit_variance"
1237 else:
1238 id: Literal["fixed_zero_mean_unit_variance"]
1240 kwargs: Union[
1241 FixedZeroMeanUnitVarianceKwargs, FixedZeroMeanUnitVarianceAlongAxisKwargs
1242 ]
1245class ZeroMeanUnitVarianceKwargs(ProcessingKwargs):
1246 """key word arguments for `ZeroMeanUnitVarianceDescr`"""
1248 axes: Annotated[
1249 Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
1250 ] = None
1251 """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std.
1252 For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
1253 resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
1254 To normalize each sample independently leave out the 'batch' axis.
1255 Default: Scale all axes jointly."""
1257 eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
1258 """epsilon for numeric stability: `out = (tensor - mean) / (std + eps)`."""
1261class ZeroMeanUnitVarianceDescr(ProcessingDescrBase):
1262 """Subtract mean and divide by variance.
1264 Examples:
1265 Normalize a tensor to zero mean and unit variance
1266 - in YAML
1267 ```yaml
1268 preprocessing:
1269 - id: zero_mean_unit_variance
1270 ```
1271 - in Python
1272 >>> preprocessing = [ZeroMeanUnitVarianceDescr()]
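
        To normalize each channel separately, restrict `axes` as described in
        `ZeroMeanUnitVarianceKwargs` (the axis ids here are illustrative):

        >>> preprocessing = [ZeroMeanUnitVarianceDescr(
        ...     kwargs=ZeroMeanUnitVarianceKwargs(axes=(AxisId("batch"), AxisId("x"), AxisId("y")))
        ... )]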
1273 """
1275 implemented_id: ClassVar[Literal["zero_mean_unit_variance"]] = (
1276 "zero_mean_unit_variance"
1277 )
1278 if TYPE_CHECKING:
1279 id: Literal["zero_mean_unit_variance"] = "zero_mean_unit_variance"
1280 else:
1281 id: Literal["zero_mean_unit_variance"]
1283 kwargs: ZeroMeanUnitVarianceKwargs = Field(
1284 default_factory=ZeroMeanUnitVarianceKwargs
1285 )
1288class ScaleRangeKwargs(ProcessingKwargs):
1289 """key word arguments for `ScaleRangeDescr`
1291 For `min_percentile`=0.0 (the default) and `max_percentile`=100 (the default)
1292 this processing step normalizes data to the [0, 1] interval.
1293 For other percentiles the normalized values will partially be outside the [0, 1]
1294 interval. Use `ScaleRangeDescr` followed by `ClipDescr` if you want to limit the
1295 normalized values to a range.
1296 """
1298 axes: Annotated[
1299 Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
1300 ] = None
1301 """The subset of axes to normalize jointly, i.e. axes to reduce to compute the min/max percentile value.
1302 For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
1303 resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
1304 To normalize samples independently, leave out the "batch" axis.
1305 Default: Scale all axes jointly."""
1307 min_percentile: Annotated[float, Interval(ge=0, lt=100)] = 0.0
1308 """The lower percentile used to determine the value to align with zero."""
1310 max_percentile: Annotated[float, Interval(gt=1, le=100)] = 100.0
1311 """The upper percentile used to determine the value to align with one.
1312 Has to be bigger than `min_percentile`.
1313 The range is 1 to 100 instead of 0 to 100 to avoid mistakenly
1314 accepting percentiles specified in the range 0.0 to 1.0."""
1316 eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
1317 """Epsilon for numeric stability.
1318 `out = (tensor - v_lower) / (v_upper - v_lower + eps)`;
1319 with `v_lower,v_upper` values at the respective percentiles."""
1321 reference_tensor: Optional[TensorId] = None
1322 """Tensor ID to compute the percentiles from. Default: The tensor itself.
1323 For any tensor in `inputs` only input tensor references are allowed."""
1325 @field_validator("max_percentile", mode="after")
1326 @classmethod
1327 def min_smaller_max(cls, value: float, info: ValidationInfo) -> float:
1328 if (min_p := info.data["min_percentile"]) >= value:
1329 raise ValueError(f"min_percentile {min_p} >= max_percentile {value}")
1331 return value
1334class ScaleRangeDescr(ProcessingDescrBase):
1335 """Scale with percentiles.
1337 Examples:
1338 1. Scale linearly to map 5th percentile to 0 and 99.8th percentile to 1.0
1339 - in YAML
1340 ```yaml
1341 preprocessing:
1342 - id: scale_range
1343 kwargs:
1344 axes: ['y', 'x']
1345 max_percentile: 99.8
1346 min_percentile: 5.0
1347 ```
1348 - in Python
1349 >>> preprocessing = [
1350 ... ScaleRangeDescr(
1351 ... kwargs=ScaleRangeKwargs(
1352 ... axes= (AxisId('y'), AxisId('x')),
1353 ... max_percentile= 99.8,
1354 ... min_percentile= 5.0,
1355 ... )
1356 ... ),
1357 ... ]
1365 2. Combine the above scaling with additional clipping to clip values outside the range given by the percentiles.
1366 - in YAML
1367 ```yaml
1368 preprocessing:
1369 - id: scale_range
1370 kwargs:
1371 axes: ['y', 'x']
1372 max_percentile: 99.8
1373 min_percentile: 5.0
1375 - id: clip
1376 kwargs:
1377 min: 0.0
1378 max: 1.0
1379 ```
1380 - in Python
1381 >>> preprocessing = [
1382 ... ScaleRangeDescr(
1383 ... kwargs=ScaleRangeKwargs(
1384 ... axes=(AxisId('y'), AxisId('x')), max_percentile=99.8, min_percentile=5.0
1385 ... )
1386 ... ),
1387 ... ClipDescr(kwargs=ClipKwargs(min=0.0, max=1.0)),
1388 ... ]
1389 """
1391 implemented_id: ClassVar[Literal["scale_range"]] = "scale_range"
1392 if TYPE_CHECKING:
1393 id: Literal["scale_range"] = "scale_range"
1394 else:
1395 id: Literal["scale_range"]
1396 kwargs: ScaleRangeKwargs
1399class ScaleMeanVarianceKwargs(ProcessingKwargs):
1400 """key word arguments for `ScaleMeanVarianceKwargs`"""
1402 reference_tensor: TensorId
1403 """Name of tensor to match."""
1405 axes: Annotated[
1406 Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
1407 ] = None
1408 """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std.
1409 For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
1410 resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
1411 To normalize samples independently, leave out the 'batch' axis.
1412 Default: Scale all axes jointly."""
1414 eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
1415 """Epsilon for numeric stability:
1416 `out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.`"""
1419class ScaleMeanVarianceDescr(ProcessingDescrBase):
1420 """Scale a tensor's data distribution to match another tensor's mean/std.
1421 `out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.`
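
    Example:
        A sketch referencing a (hypothetical) tensor with id "raw":

        >>> postprocessing = [ScaleMeanVarianceDescr(
        ...     kwargs=ScaleMeanVarianceKwargs(reference_tensor=TensorId("raw"))
        ... )]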
1422 """
1424 implemented_id: ClassVar[Literal["scale_mean_variance"]] = "scale_mean_variance"
1425 if TYPE_CHECKING:
1426 id: Literal["scale_mean_variance"] = "scale_mean_variance"
1427 else:
1428 id: Literal["scale_mean_variance"]
1429 kwargs: ScaleMeanVarianceKwargs
1432PreprocessingDescr = Annotated[
1433 Union[
1434 BinarizeDescr,
1435 ClipDescr,
1436 EnsureDtypeDescr,
1437 ScaleLinearDescr,
1438 SigmoidDescr,
1439 FixedZeroMeanUnitVarianceDescr,
1440 ZeroMeanUnitVarianceDescr,
1441 ScaleRangeDescr,
1442 ],
1443 Discriminator("id"),
1444]
1445PostprocessingDescr = Annotated[
1446 Union[
1447 BinarizeDescr,
1448 ClipDescr,
1449 EnsureDtypeDescr,
1450 ScaleLinearDescr,
1451 SigmoidDescr,
1452 FixedZeroMeanUnitVarianceDescr,
1453 ZeroMeanUnitVarianceDescr,
1454 ScaleRangeDescr,
1455 ScaleMeanVarianceDescr,
1456 ],
1457 Discriminator("id"),
1458]
1460IO_AxisT = TypeVar("IO_AxisT", InputAxis, OutputAxis)
1463class TensorDescrBase(Node, Generic[IO_AxisT]):
1464 id: TensorId
1465 """Tensor id. No duplicates are allowed."""
1467 description: Annotated[str, MaxLen(128)] = ""
1468 """free text description"""
1470 axes: NotEmpty[Sequence[IO_AxisT]]
1471 """tensor axes"""
1473 @property
1474 def shape(self):
1475 return tuple(a.size for a in self.axes)
1477 @field_validator("axes", mode="after", check_fields=False)
1478 @classmethod
1479 def _validate_axes(cls, axes: Sequence[AnyAxis]) -> Sequence[AnyAxis]:
1480 batch_axes = [a for a in axes if a.type == "batch"]
1481 if len(batch_axes) > 1:
1482 raise ValueError(
1483 f"Only one batch axis (per tensor) allowed, but got {batch_axes}"
1484 )
1486 seen_ids: Set[AxisId] = set()
1487 duplicate_axes_ids: Set[AxisId] = set()
1488 for a in axes:
1489 (duplicate_axes_ids if a.id in seen_ids else seen_ids).add(a.id)
1491 if duplicate_axes_ids:
1492 raise ValueError(f"Duplicate axis ids: {duplicate_axes_ids}")
1494 return axes
1496 test_tensor: FileDescr_
1497 """An example tensor to use for testing.
1498 Using the model with the test input tensors is expected to yield the test output tensors.
1499 Each test tensor has to be an ndarray in the
1500 [numpy.lib file format](https://numpy.org/doc/stable/reference/generated/numpy.lib.format.html#module-numpy.lib.format).
1501 The file extension must be '.npy'."""
1503 sample_tensor: Optional[FileDescr_] = None
1504 """A sample tensor to illustrate a possible input/output for the model,
1505 The sample image primarily serves to inform a human user about an example use case
1506 and is typically stored as .hdf5, .png or .tiff.
1507 It has to be readable by the [imageio library](https://imageio.readthedocs.io/en/stable/formats/index.html#supported-formats)
1508 (numpy's `.npy` format is not supported).
1509 The image dimensionality has to match the number of axes specified in this tensor description.
1510 """
1512 @model_validator(mode="after")
1513 def _validate_sample_tensor(self) -> Self:
1514 if self.sample_tensor is None or not get_validation_context().perform_io_checks:
1515 return self
1517 reader = get_reader(self.sample_tensor.source, sha256=self.sample_tensor.sha256)
1518 tensor: NDArray[Any] = imread(
1519 reader.read(),
1520 extension=PurePosixPath(reader.original_file_name).suffix,
1521 )
1522 n_dims = len(tensor.squeeze().shape)
1523 n_dims_min = n_dims_max = len(self.axes)
1525 for a in self.axes:
1526 if isinstance(a, BatchAxis):
1527 n_dims_min -= 1
1528 elif isinstance(a.size, int):
1529 if a.size == 1:
1530 n_dims_min -= 1
1531 elif isinstance(a.size, (ParameterizedSize, DataDependentSize)):
1532 if a.size.min == 1:
1533 n_dims_min -= 1
1534 elif isinstance(a.size, SizeReference):
1535 if a.size.offset < 2:
1536 # size reference may result in singleton axis
1537 n_dims_min -= 1
1538 else:
1539 assert_never(a.size)
1541 n_dims_min = max(0, n_dims_min)
1542 if n_dims < n_dims_min or n_dims > n_dims_max:
1543 raise ValueError(
1544 f"Expected sample tensor to have {n_dims_min} to"
1545 + f" {n_dims_max} dimensions, but found {n_dims} (shape: {tensor.shape})."
1546 )
1548 return self
1550 data: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]] = (
1551 IntervalOrRatioDataDescr()
1552 )
1553 """Description of the tensor's data values, optionally per channel.
1554 If specified per channel, the data `type` needs to match across channels."""
1556 @property
1557 def dtype(
1558 self,
1559 ) -> Literal[
1560 "float32",
1561 "float64",
1562 "uint8",
1563 "int8",
1564 "uint16",
1565 "int16",
1566 "uint32",
1567 "int32",
1568 "uint64",
1569 "int64",
1570 "bool",
1571 ]:
1572 """dtype as specified under `data.type` or `data[i].type`"""
1573 if isinstance(self.data, collections.abc.Sequence):
1574 return self.data[0].type
1575 else:
1576 return self.data.type
1578 @field_validator("data", mode="after")
1579 @classmethod
1580 def _check_data_type_across_channels(
1581 cls, value: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]]
1582 ) -> Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]]:
1583 if not isinstance(value, list):
1584 return value
1586 dtypes = {t.type for t in value}
1587 if len(dtypes) > 1:
1588 raise ValueError(
1589 "Tensor data descriptions per channel need to agree in their data"
1590 + f" `type`, but found {dtypes}."
1591 )
1593 return value
1595 @model_validator(mode="after")
1596 def _check_data_matches_channelaxis(self) -> Self:
1597 if not isinstance(self.data, (list, tuple)):
1598 return self
1600 for a in self.axes:
1601 if isinstance(a, ChannelAxis):
1602 size = a.size
1603 assert isinstance(size, int)
1604 break
1605 else:
1606 return self
1608 if len(self.data) != size:
1609 raise ValueError(
1610 f"Got tensor data descriptions for {len(self.data)} channels, but"
1611 + f" '{a.id}' axis has size {size}."
1612 )
1614 return self
1616 def get_axis_sizes_for_array(self, array: NDArray[Any]) -> Dict[AxisId, int]:
1617 if len(array.shape) != len(self.axes):
1618 raise ValueError(
1619 f"Dimension mismatch: array shape {array.shape} (#{len(array.shape)})"
1620 + f" incompatible with {len(self.axes)} axes."
1621 )
1622 return {a.id: array.shape[i] for i, a in enumerate(self.axes)}
1625class InputTensorDescr(TensorDescrBase[InputAxis]):
1626 id: TensorId = TensorId("input")
1627 """Input tensor id.
1628 No duplicates are allowed across all inputs and outputs."""
1630 optional: bool = False
1631 """indicates that this tensor may be `None`"""
1633 preprocessing: List[PreprocessingDescr] = Field(
1634 default_factory=cast(Callable[[], List[PreprocessingDescr]], list)
1635 )
1637 """Description of how this input should be preprocessed.
1639 notes:
1640 - If preprocessing does not start with an 'ensure_dtype' entry, it is added
1641 to ensure an input tensor's data type matches the input tensor's data description.
1642 - If preprocessing does not end with an 'ensure_dtype' or 'binarize' entry, an
1643 'ensure_dtype' step is added to ensure preprocessing steps are not unintentionally
1644 changing the data type.
1645 """
1647 @model_validator(mode="after")
1648 def _validate_preprocessing_kwargs(self) -> Self:
1649 axes_ids = [a.id for a in self.axes]
1650 for p in self.preprocessing:
1651 kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes")
1652 if kwargs_axes is None:
1653 continue
1655 if not isinstance(kwargs_axes, collections.abc.Sequence):
1656 raise ValueError(
1657 f"Expected `preprocessing.i.kwargs.axes` to be a sequence, but got {type(kwargs_axes)}"
1658 )
1660 if any(a not in axes_ids for a in kwargs_axes):
1661 raise ValueError(
1662 "`preprocessing.i.kwargs.axes` needs to be subset of axes ids"
1663 )
1665 if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)):
1666 dtype = self.data.type
1667 else:
1668 dtype = self.data[0].type
1670 # ensure `preprocessing` begins with `EnsureDtypeDescr`
1671 if not self.preprocessing or not isinstance(
1672 self.preprocessing[0], EnsureDtypeDescr
1673 ):
1674 self.preprocessing.insert(
1675 0, EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype))
1676 )
1678 # ensure `preprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr`
1679 if not isinstance(self.preprocessing[-1], (EnsureDtypeDescr, BinarizeDescr)):
1680 self.preprocessing.append(
1681 EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype))
1682 )
1684 return self
1687def convert_axes(
1688 axes: str,
1689 *,
1690 shape: Union[
1691 Sequence[int], _ParameterizedInputShape_v0_4, _ImplicitOutputShape_v0_4
1692 ],
1693 tensor_type: Literal["input", "output"],
1694 halo: Optional[Sequence[int]],
1695 size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]],
1696):
1697 ret: List[AnyAxis] = []
1698 for i, a in enumerate(axes):
1699 axis_type = _AXIS_TYPE_MAP.get(a, a)
1700 if axis_type == "batch":
1701 ret.append(BatchAxis())
1702 continue
1704 scale = 1.0
1705 if isinstance(shape, _ParameterizedInputShape_v0_4):
1706 if shape.step[i] == 0:
1707 size = shape.min[i]
1708 else:
1709 size = ParameterizedSize(min=shape.min[i], step=shape.step[i])
1710 elif isinstance(shape, _ImplicitOutputShape_v0_4):
1711 ref_t = str(shape.reference_tensor)
1712 if ref_t.count(".") == 1:
1713 t_id, orig_a_id = ref_t.split(".")
1714 else:
1715 t_id = ref_t
1716 orig_a_id = a
1718 a_id = _AXIS_ID_MAP.get(orig_a_id, a)
1719 if not (orig_scale := shape.scale[i]):
1720 # old way to insert a new axis dimension
1721 size = int(2 * shape.offset[i])
1722 else:
1723 scale = 1 / orig_scale
1724 if axis_type in ("channel", "index"):
1725 # these axes no longer have a scale
1726 offset_from_scale = orig_scale * size_refs.get(
1727 _TensorName_v0_4(t_id), {}
1728 ).get(orig_a_id, 0)
1729 else:
1730 offset_from_scale = 0
1731 size = SizeReference(
1732 tensor_id=TensorId(t_id),
1733 axis_id=AxisId(a_id),
1734 offset=int(offset_from_scale + 2 * shape.offset[i]),
1735 )
1736 else:
1737 size = shape[i]
1739 if axis_type == "time":
1740 if tensor_type == "input":
1741 ret.append(TimeInputAxis(size=size, scale=scale))
1742 else:
1743 assert not isinstance(size, ParameterizedSize)
1744 if halo is None:
1745 ret.append(TimeOutputAxis(size=size, scale=scale))
1746 else:
1747 assert not isinstance(size, int)
1748 ret.append(
1749 TimeOutputAxisWithHalo(size=size, scale=scale, halo=halo[i])
1750 )
1752 elif axis_type == "index":
1753 if tensor_type == "input":
1754 ret.append(IndexInputAxis(size=size))
1755 else:
1756 if isinstance(size, ParameterizedSize):
1757 size = DataDependentSize(min=size.min)
1759 ret.append(IndexOutputAxis(size=size))
1760 elif axis_type == "channel":
1761 assert not isinstance(size, ParameterizedSize)
1762 if isinstance(size, SizeReference):
1763 warnings.warn(
1764 "Conversion of channel size from an implicit output shape may be"
1765 + " wrong"
1766 )
1767 ret.append(
1768 ChannelAxis(
1769 channel_names=[
1770 Identifier(f"channel{i}") for i in range(size.offset)
1771 ]
1772 )
1773 )
1774 else:
1775 ret.append(
1776 ChannelAxis(
1777 channel_names=[Identifier(f"channel{i}") for i in range(size)]
1778 )
1779 )
1780 elif axis_type == "space":
1781 if tensor_type == "input":
1782 ret.append(SpaceInputAxis(id=AxisId(a), size=size, scale=scale))
1783 else:
1784 assert not isinstance(size, ParameterizedSize)
1785 if halo is None or halo[i] == 0:
1786 ret.append(SpaceOutputAxis(id=AxisId(a), size=size, scale=scale))
1787 elif isinstance(size, int):
1788 raise NotImplementedError(
1789 f"output axis with halo and fixed size (here {size}) not allowed"
1790 )
1791 else:
1792 ret.append(
1793 SpaceOutputAxisWithHalo(
1794 id=AxisId(a), size=size, scale=scale, halo=halo[i]
1795 )
1796 )
1798 return ret
1801def _axes_letters_to_ids(
1802 axes: Optional[str],
1803) -> Optional[List[AxisId]]:
1804 if axes is None:
1805 return None
1807 return [AxisId(a) for a in axes]
1810def _get_complement_v04_axis(
1811 tensor_axes: Sequence[str], axes: Optional[Sequence[str]]
1812) -> Optional[AxisId]:
1813 if axes is None:
1814 return None
1816 non_complement_axes = set(axes) | {"b"}
1817 complement_axes = [a for a in tensor_axes if a not in non_complement_axes]
1818 if len(complement_axes) > 1:
1819 raise ValueError(
1820 f"Expected none or a single complement axis, but axes '{axes}' "
1821 + f"for tensor dims '{tensor_axes}' leave '{complement_axes}'."
1822 )
1824 return None if not complement_axes else AxisId(complement_axes[0])
1827def _convert_proc(
1828 p: Union[_PreprocessingDescr_v0_4, _PostprocessingDescr_v0_4],
1829 tensor_axes: Sequence[str],
1830) -> Union[PreprocessingDescr, PostprocessingDescr]:
1831 if isinstance(p, _BinarizeDescr_v0_4):
1832 return BinarizeDescr(kwargs=BinarizeKwargs(threshold=p.kwargs.threshold))
1833 elif isinstance(p, _ClipDescr_v0_4):
1834 return ClipDescr(kwargs=ClipKwargs(min=p.kwargs.min, max=p.kwargs.max))
1835 elif isinstance(p, _SigmoidDescr_v0_4):
1836 return SigmoidDescr()
1837 elif isinstance(p, _ScaleLinearDescr_v0_4):
1838 axes = _axes_letters_to_ids(p.kwargs.axes)
1839 if p.kwargs.axes is None:
1840 axis = None
1841 else:
1842 axis = _get_complement_v04_axis(tensor_axes, p.kwargs.axes)
1844 if axis is None:
1845 assert not isinstance(p.kwargs.gain, list)
1846 assert not isinstance(p.kwargs.offset, list)
1847 kwargs = ScaleLinearKwargs(gain=p.kwargs.gain, offset=p.kwargs.offset)
1848 else:
1849 kwargs = ScaleLinearAlongAxisKwargs(
1850 axis=axis, gain=p.kwargs.gain, offset=p.kwargs.offset
1851 )
1852 return ScaleLinearDescr(kwargs=kwargs)
1853 elif isinstance(p, _ScaleMeanVarianceDescr_v0_4):
1854 return ScaleMeanVarianceDescr(
1855 kwargs=ScaleMeanVarianceKwargs(
1856 axes=_axes_letters_to_ids(p.kwargs.axes),
1857 reference_tensor=TensorId(str(p.kwargs.reference_tensor)),
1858 eps=p.kwargs.eps,
1859 )
1860 )
1861 elif isinstance(p, _ZeroMeanUnitVarianceDescr_v0_4):
1862 if p.kwargs.mode == "fixed":
1863 mean = p.kwargs.mean
1864 std = p.kwargs.std
1865 assert mean is not None
1866 assert std is not None
1868 axis = _get_complement_v04_axis(tensor_axes, p.kwargs.axes)
1870 if axis is None:
1871 return FixedZeroMeanUnitVarianceDescr(
1872 kwargs=FixedZeroMeanUnitVarianceKwargs(
1873 mean=mean, std=std # pyright: ignore[reportArgumentType]
1874 )
1875 )
1876 else:
1877 if not isinstance(mean, list):
1878 mean = [float(mean)]
1879 if not isinstance(std, list):
1880 std = [float(std)]
1882 return FixedZeroMeanUnitVarianceDescr(
1883 kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs(
1884 axis=axis, mean=mean, std=std
1885 )
1886 )
1888 else:
1889 axes = _axes_letters_to_ids(p.kwargs.axes) or []
1890 if p.kwargs.mode == "per_dataset":
1891 axes = [AxisId("batch")] + axes
1892 if not axes:
1893 axes = None
1894 return ZeroMeanUnitVarianceDescr(
1895 kwargs=ZeroMeanUnitVarianceKwargs(axes=axes, eps=p.kwargs.eps)
1896 )
1898 elif isinstance(p, _ScaleRangeDescr_v0_4):
1899 return ScaleRangeDescr(
1900 kwargs=ScaleRangeKwargs(
1901 axes=_axes_letters_to_ids(p.kwargs.axes),
1902 min_percentile=p.kwargs.min_percentile,
1903 max_percentile=p.kwargs.max_percentile,
1904 eps=p.kwargs.eps,
1905 )
1906 )
1907 else:
1908 assert_never(p)
1911class _InputTensorConv(
1912 Converter[
1913 _InputTensorDescr_v0_4,
1914 InputTensorDescr,
1915 FileSource_,
1916 Optional[FileSource_],
1917 Mapping[_TensorName_v0_4, Mapping[str, int]],
1918 ]
1919):
1920 def _convert(
1921 self,
1922 src: _InputTensorDescr_v0_4,
1923 tgt: "type[InputTensorDescr] | type[dict[str, Any]]",
1924 test_tensor: FileSource_,
1925 sample_tensor: Optional[FileSource_],
1926 size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]],
1927 ) -> "InputTensorDescr | dict[str, Any]":
1928 axes: List[InputAxis] = convert_axes( # pyright: ignore[reportAssignmentType]
1929 src.axes,
1930 shape=src.shape,
1931 tensor_type="input",
1932 halo=None,
1933 size_refs=size_refs,
1934 )
1935 prep: List[PreprocessingDescr] = []
1936 for p in src.preprocessing:
1937 cp = _convert_proc(p, src.axes)
1938 assert not isinstance(cp, ScaleMeanVarianceDescr)
1939 prep.append(cp)
1941 prep.append(EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="float32")))
1943 return tgt(
1944 axes=axes,
1945 id=TensorId(str(src.name)),
1946 test_tensor=FileDescr(source=test_tensor),
1947 sample_tensor=(
1948 None if sample_tensor is None else FileDescr(source=sample_tensor)
1949 ),
1950 data=dict(type=src.data_type), # pyright: ignore[reportArgumentType]
1951 preprocessing=prep,
1952 )
1955_input_tensor_conv = _InputTensorConv(_InputTensorDescr_v0_4, InputTensorDescr)
1958class OutputTensorDescr(TensorDescrBase[OutputAxis]):
1959 id: TensorId = TensorId("output")
1960 """Output tensor id.
1961 No duplicates are allowed across all inputs and outputs."""
1963 postprocessing: List[PostprocessingDescr] = Field(
1964 default_factory=cast(Callable[[], List[PostprocessingDescr]], list)
1965 )
1966 """Description of how this output should be postprocessed.
1968 note: `postprocessing` always ends with an 'ensure_dtype' operation.
1969 If not given, such a step is appended to cast to this tensor's `data.type`.
1970 """
1972 @model_validator(mode="after")
1973 def _validate_postprocessing_kwargs(self) -> Self:
1974 axes_ids = [a.id for a in self.axes]
1975 for p in self.postprocessing:
1976 kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes")
1977 if kwargs_axes is None:
1978 continue
1980 if not isinstance(kwargs_axes, collections.abc.Sequence):
1981 raise ValueError(
1982 f"expected `axes` sequence, but got {type(kwargs_axes)}"
1983 )
1985 if any(a not in axes_ids for a in kwargs_axes):
1986 raise ValueError("`kwargs.axes` needs to be subset of axes ids")
1988 if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)):
1989 dtype = self.data.type
1990 else:
1991 dtype = self.data[0].type
1993 # ensure `postprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr`
1994 if not self.postprocessing or not isinstance(
1995 self.postprocessing[-1], (EnsureDtypeDescr, BinarizeDescr)
1996 ):
1997 self.postprocessing.append(
1998 EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype))
1999 )
2000 return self
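# Hedged illustration (not part of the original module): the validator above guarantees
# that e.g. an output described with `data.type == "uint8"` and an empty `postprocessing`
# list ends up with
#   postprocessing == [EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="uint8"))]
# so consumers can always rely on a final dtype cast (or a `binarize` step).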
2003class _OutputTensorConv(
2004 Converter[
2005 _OutputTensorDescr_v0_4,
2006 OutputTensorDescr,
2007 FileSource_,
2008 Optional[FileSource_],
2009 Mapping[_TensorName_v0_4, Mapping[str, int]],
2010 ]
2011):
2012 def _convert(
2013 self,
2014 src: _OutputTensorDescr_v0_4,
2015 tgt: "type[OutputTensorDescr] | type[dict[str, Any]]",
2016 test_tensor: FileSource_,
2017 sample_tensor: Optional[FileSource_],
2018 size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]],
2019 ) -> "OutputTensorDescr | dict[str, Any]":
2020 # TODO: split convert_axes into convert_output_axes and convert_input_axes
2021 axes: List[OutputAxis] = convert_axes( # pyright: ignore[reportAssignmentType]
2022 src.axes,
2023 shape=src.shape,
2024 tensor_type="output",
2025 halo=src.halo,
2026 size_refs=size_refs,
2027 )
2028 data_descr: Dict[str, Any] = dict(type=src.data_type)
2029 if data_descr["type"] == "bool":
2030 data_descr["values"] = [False, True]
2032 return tgt(
2033 axes=axes,
2034 id=TensorId(str(src.name)),
2035 test_tensor=FileDescr(source=test_tensor),
2036 sample_tensor=(
2037 None if sample_tensor is None else FileDescr(source=sample_tensor)
2038 ),
2039 data=data_descr, # pyright: ignore[reportArgumentType]
2040 postprocessing=[_convert_proc(p, src.axes) for p in src.postprocessing],
2041 )
2044_output_tensor_conv = _OutputTensorConv(_OutputTensorDescr_v0_4, OutputTensorDescr)
2047TensorDescr = Union[InputTensorDescr, OutputTensorDescr]
2050def validate_tensors(
2051 tensors: Mapping[TensorId, Tuple[TensorDescr, NDArray[Any]]],
2052 tensor_origin: Literal[
2053 "test_tensor"
2054 ], # for more precise error messages, e.g. 'test_tensor'
2055):
2056 all_tensor_axes: Dict[TensorId, Dict[AxisId, Tuple[AnyAxis, int]]] = {}
2058 def e_msg(d: TensorDescr):
2059 return f"{'inputs' if isinstance(d, InputTensorDescr) else 'outputs'}[{d.id}]"
2061 for descr, array in tensors.values():
2062 try:
2063 axis_sizes = descr.get_axis_sizes_for_array(array)
2064 except ValueError as e:
2065 raise ValueError(f"{e_msg(descr)} {e}")
2066 else:
2067 all_tensor_axes[descr.id] = {
2068 a.id: (a, axis_sizes[a.id]) for a in descr.axes
2069 }
2071 for descr, array in tensors.values():
2072 if descr.dtype in ("float32", "float64"):
2073 invalid_test_tensor_dtype = array.dtype.name not in (
2074 "float32",
2075 "float64",
2076 "uint8",
2077 "int8",
2078 "uint16",
2079 "int16",
2080 "uint32",
2081 "int32",
2082 "uint64",
2083 "int64",
2084 )
2085 else:
2086 invalid_test_tensor_dtype = array.dtype.name != descr.dtype
2088 if invalid_test_tensor_dtype:
2089 raise ValueError(
2090 f"{e_msg(descr)}.{tensor_origin}.dtype '{array.dtype.name}' does not"
2091 + f" match described dtype '{descr.dtype}'"
2092 )
2094 if array.min() > -1e-4 and array.max() < 1e-4:
2095 raise ValueError(
2096 "Output values are too small for reliable testing."
2097 + f" Values <-1e5 or >=1e5 must be present in {tensor_origin}"
2098 )
2100 for a in descr.axes:
2101 actual_size = all_tensor_axes[descr.id][a.id][1]
2102 if a.size is None:
2103 continue
2105 if isinstance(a.size, int):
2106 if actual_size != a.size:
2107 raise ValueError(
2108 f"{e_msg(descr)}.{tensor_origin}: axis '{a.id}' "
2109 + f"has incompatible size {actual_size}, expected {a.size}"
2110 )
2111 elif isinstance(a.size, ParameterizedSize):
2112 _ = a.size.validate_size(actual_size)
2113 elif isinstance(a.size, DataDependentSize):
2114 _ = a.size.validate_size(actual_size)
2115 elif isinstance(a.size, SizeReference):
2116 ref_tensor_axes = all_tensor_axes.get(a.size.tensor_id)
2117 if ref_tensor_axes is None:
2118 raise ValueError(
2119 f"{e_msg(descr)}.axes[{a.id}].size.tensor_id: Unknown tensor"
2120 + f" reference '{a.size.tensor_id}'"
2121 )
2123 ref_axis, ref_size = ref_tensor_axes.get(a.size.axis_id, (None, None))
2124 if ref_axis is None or ref_size is None:
2125 raise ValueError(
2126 f"{e_msg(descr)}.axes[{a.id}].size.axis_id: Unknown tensor axis"
2127 + f" reference '{a.size.tensor_id}.{a.size.axis_id}"
2128 )
2130 if a.unit != ref_axis.unit:
2131 raise ValueError(
2132 f"{e_msg(descr)}.axes[{a.id}].size: `SizeReference` requires"
2133 + " axis and reference axis to have the same `unit`, but"
2134 + f" {a.unit}!={ref_axis.unit}"
2135 )
2137 if actual_size != (
2138 expected_size := (
2139 ref_size * ref_axis.scale / a.scale + a.size.offset
2140 )
2141 ):
2142 raise ValueError(
2143 f"{e_msg(descr)}.{tensor_origin}: axis '{a.id}' of size"
2144 + f" {actual_size} invalid for referenced size {ref_size};"
2145 + f" expected {expected_size}"
2146 )
2147 else:
2148 assert_never(a.size)
2151FileDescr_dependencies = Annotated[
2152 FileDescr_,
2153 WithSuffix((".yaml", ".yml"), case_sensitive=True),
2154 Field(examples=[dict(source="environment.yaml")]),
2155]
2158class _ArchitectureCallableDescr(Node):
2159 callable: Annotated[Identifier, Field(examples=["MyNetworkClass", "get_my_model"])]
2160 """Identifier of the callable that returns a torch.nn.Module instance."""
2162 kwargs: Dict[str, YamlValue] = Field(
2163 default_factory=cast(Callable[[], Dict[str, YamlValue]], dict)
2164 )
2165 """key word arguments for the `callable`"""
2168class ArchitectureFromFileDescr(_ArchitectureCallableDescr, FileDescr):
2169 source: Annotated[FileSource, AfterValidator(wo_special_file_name)]
2170 """Architecture source file"""
2172 @model_serializer(mode="wrap", when_used="unless-none")
2173 def _serialize(self, nxt: SerializerFunctionWrapHandler, info: SerializationInfo):
2174 return package_file_descr_serializer(self, nxt, info)
2177class ArchitectureFromLibraryDescr(_ArchitectureCallableDescr):
2178 import_from: str
2179 """Where to import the callable from, i.e. `from <import_from> import <callable>`"""
2182class _ArchFileConv(
2183 Converter[
2184 _CallableFromFile_v0_4,
2185 ArchitectureFromFileDescr,
2186 Optional[Sha256],
2187 Dict[str, Any],
2188 ]
2189):
2190 def _convert(
2191 self,
2192 src: _CallableFromFile_v0_4,
2193 tgt: "type[ArchitectureFromFileDescr | dict[str, Any]]",
2194 sha256: Optional[Sha256],
2195 kwargs: Dict[str, Any],
2196 ) -> "ArchitectureFromFileDescr | dict[str, Any]":
2197 if src.startswith("http") and src.count(":") == 2:
2198 http, source, callable_ = src.split(":")
2199 source = ":".join((http, source))
2200 elif not src.startswith("http") and src.count(":") == 1:
2201 source, callable_ = src.split(":")
2202 else:
2203 source = str(src)
2204 callable_ = str(src)
2205 return tgt(
2206 callable=Identifier(callable_),
2207 source=cast(FileSource_, source),
2208 sha256=sha256,
2209 kwargs=kwargs,
2210 )
2213_arch_file_conv = _ArchFileConv(_CallableFromFile_v0_4, ArchitectureFromFileDescr)
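# Illustrative, hypothetical values: `_ArchFileConv` splits the v0.4 "<source>:<callable>"
# string into separate fields, e.g.
#   "https://example.com/unet.py:UNet" -> source="https://example.com/unet.py", callable="UNet"
#   "unet.py:UNet"                     -> source="unet.py",                     callable="UNet"
# (a string without the expected number of ':' is kept as both source and callable).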
2216class _ArchLibConv(
2217 Converter[
2218 _CallableFromDepencency_v0_4, ArchitectureFromLibraryDescr, Dict[str, Any]
2219 ]
2220):
2221 def _convert(
2222 self,
2223 src: _CallableFromDepencency_v0_4,
2224 tgt: "type[ArchitectureFromLibraryDescr | dict[str, Any]]",
2225 kwargs: Dict[str, Any],
2226 ) -> "ArchitectureFromLibraryDescr | dict[str, Any]":
2227 *mods, callable_ = src.split(".")
2228 import_from = ".".join(mods)
2229 return tgt(
2230 import_from=import_from, callable=Identifier(callable_), kwargs=kwargs
2231 )
2234_arch_lib_conv = _ArchLibConv(
2235 _CallableFromDepencency_v0_4, ArchitectureFromLibraryDescr
2236)
2239class WeightsEntryDescrBase(FileDescr):
2240 type: ClassVar[WeightsFormat]
2241 weights_format_name: ClassVar[str] # human readable
2243 source: Annotated[FileSource, AfterValidator(wo_special_file_name)]
2244 """Source of the weights file."""
2246 authors: Optional[List[Author]] = None
2247 """Authors
2248 Either the person(s) that have trained this model resulting in the original weights file.
2249 (If this is the initial weights entry, i.e. it does not have a `parent`)
2250 Or the person(s) who have converted the weights to this weights format.
2251 (If this is a child weight, i.e. it has a `parent` field)
2252 """
2254 parent: Annotated[
2255 Optional[WeightsFormat], Field(examples=["pytorch_state_dict"])
2256 ] = None
2257 """The source weights these weights were converted from.
2258 For example, if a model's weights were converted from the `pytorch_state_dict` format to `torchscript`,
2259 the `pytorch_state_dict` weights entry has no `parent` and is the parent of the `torchscript` weights.
2260 All weight entries except one (the initial set of weights resulting from training the model)
2261 need to have this field."""
2263 comment: str = ""
2264 """A comment about this weights entry, for example how these weights were created."""
2266 @model_validator(mode="after")
2267 def _validate(self) -> Self:
2268 if self.type == self.parent:
2269 raise ValueError("Weights entry can't be it's own parent.")
2271 return self
2273 @model_serializer(mode="wrap", when_used="unless-none")
2274 def _serialize(self, nxt: SerializerFunctionWrapHandler, info: SerializationInfo):
2275 return package_file_descr_serializer(self, nxt, info)
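# Hedged example (hypothetical file names): for weights converted from PyTorch to
# TorchScript the `parent` chain looks like
#   pytorch_state_dict: {source: weights.pt}                              # original, no `parent`
#   torchscript:        {source: weights.torchscript, parent: pytorch_state_dict}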
2278class KerasHdf5WeightsDescr(WeightsEntryDescrBase):
2279 type = "keras_hdf5"
2280 weights_format_name: ClassVar[str] = "Keras HDF5"
2281 tensorflow_version: Version
2282 """TensorFlow version used to create these weights."""
2285class OnnxWeightsDescr(WeightsEntryDescrBase):
2286 type = "onnx"
2287 weights_format_name: ClassVar[str] = "ONNX"
2288 opset_version: Annotated[int, Ge(7)]
2289 """ONNX opset version"""
2292class PytorchStateDictWeightsDescr(WeightsEntryDescrBase):
2293 type = "pytorch_state_dict"
2294 weights_format_name: ClassVar[str] = "Pytorch State Dict"
2295 architecture: Union[ArchitectureFromFileDescr, ArchitectureFromLibraryDescr]
2296 pytorch_version: Version
2297 """Version of the PyTorch library used.
2298 If `dependencies` is specified it has to include pytorch and any version pinning has to be compatible.
2299 """
2300 dependencies: Optional[FileDescr_dependencies] = None
2301 """Custom depencies beyond pytorch described in a Conda environment file.
2302 Allows to specify custom dependencies, see conda docs:
2303 - [Exporting an environment file across platforms](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#exporting-an-environment-file-across-platforms)
2304 - [Creating an environment file manually](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#creating-an-environment-file-manually)
2306 The conda environment file should include pytorch and any version pinning has to be compatible with
2307 **pytorch_version**.
2308 """
2311class TensorflowJsWeightsDescr(WeightsEntryDescrBase):
2312 type = "tensorflow_js"
2313 weights_format_name: ClassVar[str] = "Tensorflow.js"
2314 tensorflow_version: Version
2315 """Version of the TensorFlow library used."""
2317 source: Annotated[FileSource, AfterValidator(wo_special_file_name)]
2318 """The multi-file weights.
2319 All required files/folders should be bundled in a zip archive."""
2322class TensorflowSavedModelBundleWeightsDescr(WeightsEntryDescrBase):
2323 type = "tensorflow_saved_model_bundle"
2324 weights_format_name: ClassVar[str] = "Tensorflow Saved Model"
2325 tensorflow_version: Version
2326 """Version of the TensorFlow library used."""
2328 dependencies: Optional[FileDescr_dependencies] = None
2329 """Custom dependencies beyond tensorflow.
2330 Should include tensorflow and any version pinning has to be compatible with **tensorflow_version**."""
2332 source: Annotated[FileSource, AfterValidator(wo_special_file_name)]
2333 """The multi-file weights.
2334 All required files/folders should be bundled in a zip archive."""
2337class TorchscriptWeightsDescr(WeightsEntryDescrBase):
2338 type = "torchscript"
2339 weights_format_name: ClassVar[str] = "TorchScript"
2340 pytorch_version: Version
2341 """Version of the PyTorch library used."""
2344class WeightsDescr(Node):
2345 keras_hdf5: Optional[KerasHdf5WeightsDescr] = None
2346 onnx: Optional[OnnxWeightsDescr] = None
2347 pytorch_state_dict: Optional[PytorchStateDictWeightsDescr] = None
2348 tensorflow_js: Optional[TensorflowJsWeightsDescr] = None
2349 tensorflow_saved_model_bundle: Optional[TensorflowSavedModelBundleWeightsDescr] = (
2350 None
2351 )
2352 torchscript: Optional[TorchscriptWeightsDescr] = None
2354 @model_validator(mode="after")
2355 def check_entries(self) -> Self:
2356 entries = {wtype for wtype, entry in self if entry is not None}
2358 if not entries:
2359 raise ValueError("Missing weights entry")
2361 entries_wo_parent = {
2362 wtype
2363 for wtype, entry in self
2364 if entry is not None and hasattr(entry, "parent") and entry.parent is None
2365 }
2366 if len(entries_wo_parent) != 1:
2367 issue_warning(
2368 "Exactly one weights entry may not specify the `parent` field (got"
2369 + " {value}). That entry is considered the original set of model weights."
2370 + " Other weight formats are created through conversion of the orignal or"
2371 + " already converted weights. They have to reference the weights format"
2372 + " they were converted from as their `parent`.",
2373 value=len(entries_wo_parent),
2374 field="weights",
2375 )
2377 for wtype, entry in self:
2378 if entry is None:
2379 continue
2381 assert hasattr(entry, "type")
2382 assert hasattr(entry, "parent")
2383 assert wtype == entry.type
2384 if (
2385 entry.parent is not None and entry.parent not in entries
2386 ): # self reference checked for `parent` field
2387 raise ValueError(
2388 f"`weights.{wtype}.parent={entry.parent} not in specified weight"
2389 + f" formats: {entries}"
2390 )
2392 return self
2394 def __getitem__(
2395 self,
2396 key: Literal[
2397 "keras_hdf5",
2398 "onnx",
2399 "pytorch_state_dict",
2400 "tensorflow_js",
2401 "tensorflow_saved_model_bundle",
2402 "torchscript",
2403 ],
2404 ):
2405 if key == "keras_hdf5":
2406 ret = self.keras_hdf5
2407 elif key == "onnx":
2408 ret = self.onnx
2409 elif key == "pytorch_state_dict":
2410 ret = self.pytorch_state_dict
2411 elif key == "tensorflow_js":
2412 ret = self.tensorflow_js
2413 elif key == "tensorflow_saved_model_bundle":
2414 ret = self.tensorflow_saved_model_bundle
2415 elif key == "torchscript":
2416 ret = self.torchscript
2417 else:
2418 raise KeyError(key)
2420 if ret is None:
2421 raise KeyError(key)
2423 return ret
2425 @property
2426 def available_formats(self):
2427 return {
2428 **({} if self.keras_hdf5 is None else {"keras_hdf5": self.keras_hdf5}),
2429 **({} if self.onnx is None else {"onnx": self.onnx}),
2430 **(
2431 {}
2432 if self.pytorch_state_dict is None
2433 else {"pytorch_state_dict": self.pytorch_state_dict}
2434 ),
2435 **(
2436 {}
2437 if self.tensorflow_js is None
2438 else {"tensorflow_js": self.tensorflow_js}
2439 ),
2440 **(
2441 {}
2442 if self.tensorflow_saved_model_bundle is None
2443 else {
2444 "tensorflow_saved_model_bundle": self.tensorflow_saved_model_bundle
2445 }
2446 ),
2447 **({} if self.torchscript is None else {"torchscript": self.torchscript}),
2448 }
2450 @property
2451 def missing_formats(self):
2452 return {
2453 wf for wf in get_args(WeightsFormat) if wf not in self.available_formats
2454 }
2457class ModelId(ResourceId):
2458 pass
2461class LinkedModel(LinkedResourceBase):
2462 """Reference to a bioimage.io model."""
2464 id: ModelId
2465 """A valid model `id` from the bioimage.io collection."""
2468class _DataDepSize(NamedTuple):
2469 min: int
2470 max: Optional[int]
2473class _AxisSizes(NamedTuple):
2474 """the lenghts of all axes of model inputs and outputs"""
2476 inputs: Dict[Tuple[TensorId, AxisId], int]
2477 outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]]
2480class _TensorSizes(NamedTuple):
2481 """_AxisSizes as nested dicts"""
2483 inputs: Dict[TensorId, Dict[AxisId, int]]
2484 outputs: Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]
2487class ReproducibilityTolerance(Node, extra="allow"):
2488 """Describes what small numerical differences -- if any -- may be tolerated
2489 in the generated output when executing in different environments.
2491 A tensor element *output* is considered mismatched to the **test_tensor** if
2492 abs(*output* - **test_tensor**) > **absolute_tolerance** + **relative_tolerance** * abs(**test_tensor**).
2493 (Internally we call [numpy.testing.assert_allclose](https://numpy.org/doc/stable/reference/generated/numpy.testing.assert_allclose.html).)
2495 Motivation:
2496 For testing we can request the respective deep learning frameworks to be as
2497 reproducible as possible by setting seeds and choosing deterministic algorithms,
2498 but differences in operating systems, available hardware and installed drivers
2499 may still lead to numerical differences.
2500 """
2502 relative_tolerance: RelativeTolerance = 1e-3
2503 """Maximum relative tolerance of reproduced test tensor."""
2505 absolute_tolerance: AbsoluteTolerance = 1e-4
2506 """Maximum absolute tolerance of reproduced test tensor."""
2508 mismatched_elements_per_million: MismatchedElementsPerMillion = 100
2509 """Maximum number of mismatched elements/pixels per million to tolerate."""
2511 output_ids: Sequence[TensorId] = ()
2512 """Limits the output tensor IDs these reproducibility details apply to."""
2514 weights_formats: Sequence[WeightsFormat] = ()
2515 """Limits the weights formats these details apply to."""
2518class BioimageioConfig(Node, extra="allow"):
2519 reproducibility_tolerance: Sequence[ReproducibilityTolerance] = ()
2520 """Tolerances to allow when reproducing the model's test outputs
2521 from the model's test inputs.
2522 Only the first entry matching tensor id and weights format is considered.
2523 """
2526class Config(Node, extra="allow"):
2527 bioimageio: BioimageioConfig = Field(default_factory=BioimageioConfig)
2530class ModelDescr(GenericModelDescrBase):
2531 """Specification of the fields used in a bioimage.io-compliant RDF to describe AI models with pretrained weights.
2532 These fields are typically stored in a YAML file which we call a model resource description file (model RDF).
2533 """
2535 implemented_format_version: ClassVar[Literal["0.5.4"]] = "0.5.4"
2536 if TYPE_CHECKING:
2537 format_version: Literal["0.5.4"] = "0.5.4"
2538 else:
2539 format_version: Literal["0.5.4"]
2540 """Version of the bioimage.io model description specification used.
2541 When creating a new model always use the latest micro/patch version described here.
2542 The `format_version` is important for any consumer software to understand how to parse the fields.
2543 """
2545 implemented_type: ClassVar[Literal["model"]] = "model"
2546 if TYPE_CHECKING:
2547 type: Literal["model"] = "model"
2548 else:
2549 type: Literal["model"]
2550 """Specialized resource type 'model'"""
2552 id: Optional[ModelId] = None
2553 """bioimage.io-wide unique resource identifier
2554 assigned by bioimage.io; version **un**specific."""
2556 authors: NotEmpty[List[Author]]
2557 """The authors are the creators of the model RDF and the primary points of contact."""
2559 documentation: FileSource_documentation
2560 """URL or relative path to a markdown file with additional documentation.
2561 The recommended documentation file name is `README.md`. An `.md` suffix is mandatory.
2562 The documentation should include a '#[#] Validation' (sub)section
2563 with details on how to quantitatively validate the model on unseen data."""
2565 @field_validator("documentation", mode="after")
2566 @classmethod
2567 def _validate_documentation(
2568 cls, value: FileSource_documentation
2569 ) -> FileSource_documentation:
2570 if not get_validation_context().perform_io_checks:
2571 return value
2573 doc_reader = get_reader(value)
2574 doc_content = doc_reader.read().decode(encoding="utf-8")
2575 if not re.search("#.*[vV]alidation", doc_content):
2576 issue_warning(
2577 "No '# Validation' (sub)section found in {value}.",
2578 value=value,
2579 field="documentation",
2580 )
2582 return value
2584 inputs: NotEmpty[Sequence[InputTensorDescr]]
2585 """Describes the input tensors expected by this model."""
2587 @field_validator("inputs", mode="after")
2588 @classmethod
2589 def _validate_input_axes(
2590 cls, inputs: Sequence[InputTensorDescr]
2591 ) -> Sequence[InputTensorDescr]:
2592 input_size_refs = cls._get_axes_with_independent_size(inputs)
2594 for i, ipt in enumerate(inputs):
2595 valid_independent_refs: Dict[
2596 Tuple[TensorId, AxisId],
2597 Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
2598 ] = {
2599 **{
2600 (ipt.id, a.id): (ipt, a, a.size)
2601 for a in ipt.axes
2602 if not isinstance(a, BatchAxis)
2603 and isinstance(a.size, (int, ParameterizedSize))
2604 },
2605 **input_size_refs,
2606 }
2607 for a, ax in enumerate(ipt.axes):
2608 cls._validate_axis(
2609 "inputs",
2610 i=i,
2611 tensor_id=ipt.id,
2612 a=a,
2613 axis=ax,
2614 valid_independent_refs=valid_independent_refs,
2615 )
2616 return inputs
2618 @staticmethod
2619 def _validate_axis(
2620 field_name: str,
2621 i: int,
2622 tensor_id: TensorId,
2623 a: int,
2624 axis: AnyAxis,
2625 valid_independent_refs: Dict[
2626 Tuple[TensorId, AxisId],
2627 Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
2628 ],
2629 ):
2630 if isinstance(axis, BatchAxis) or isinstance(
2631 axis.size, (int, ParameterizedSize, DataDependentSize)
2632 ):
2633 return
2634 elif not isinstance(axis.size, SizeReference):
2635 assert_never(axis.size)
2637 # validate axis.size SizeReference
2638 ref = (axis.size.tensor_id, axis.size.axis_id)
2639 if ref not in valid_independent_refs:
2640 raise ValueError(
2641 "Invalid tensor axis reference at"
2642 + f" {field_name}[{i}].axes[{a}].size: {axis.size}."
2643 )
2644 if ref == (tensor_id, axis.id):
2645 raise ValueError(
2646 "Self-referencing not allowed for"
2647 + f" {field_name}[{i}].axes[{a}].size: {axis.size}"
2648 )
2649 if axis.type == "channel":
2650 if valid_independent_refs[ref][1].type != "channel":
2651 raise ValueError(
2652 "A channel axis' size may only reference another fixed size"
2653 + " channel axis."
2654 )
2655 if isinstance(axis.channel_names, str) and "{i}" in axis.channel_names:
2656 ref_size = valid_independent_refs[ref][2]
2657 assert isinstance(ref_size, int), (
2658 "channel axis ref (another channel axis) has to specify fixed"
2659 + " size"
2660 )
2661 generated_channel_names = [
2662 Identifier(axis.channel_names.format(i=i))
2663 for i in range(1, ref_size + 1)
2664 ]
2665 axis.channel_names = generated_channel_names
2667 if (ax_unit := getattr(axis, "unit", None)) != (
2668 ref_unit := getattr(valid_independent_refs[ref][1], "unit", None)
2669 ):
2670 raise ValueError(
2671 "The units of an axis and its reference axis need to match, but"
2672 + f" '{ax_unit}' != '{ref_unit}'."
2673 )
2674 ref_axis = valid_independent_refs[ref][1]
2675 if isinstance(ref_axis, BatchAxis):
2676 raise ValueError(
2677 f"Invalid reference axis '{ref_axis.id}' for {tensor_id}.{axis.id}"
2678 + " (a batch axis is not allowed as reference)."
2679 )
2681 if isinstance(axis, WithHalo):
2682 min_size = axis.size.get_size(axis, ref_axis, n=0)
2683 if (min_size - 2 * axis.halo) < 1:
2684 raise ValueError(
2685 f"axis {axis.id} with minimum size {min_size} is too small for halo"
2686 + f" {axis.halo}."
2687 )
2689 input_halo = axis.halo * axis.scale / ref_axis.scale
2690 if input_halo != int(input_halo) or input_halo % 2 == 1:
2691 raise ValueError(
2692 f"input_halo {input_halo} (output_halo {axis.halo} *"
2693 + f" output_scale {axis.scale} / input_scale {ref_axis.scale})"
2694 + f" {tensor_id}.{axis.id}."
2695 )
2697 @model_validator(mode="after")
2698 def _validate_test_tensors(self) -> Self:
2699 if not get_validation_context().perform_io_checks:
2700 return self
2702 test_output_arrays = [load_array(descr.test_tensor) for descr in self.outputs]
2703 test_input_arrays = [load_array(descr.test_tensor) for descr in self.inputs]
2705 tensors = {
2706 descr.id: (descr, array)
2707 for descr, array in zip(
2708 chain(self.inputs, self.outputs), test_input_arrays + test_output_arrays
2709 )
2710 }
2711 validate_tensors(tensors, tensor_origin="test_tensor")
2713 output_arrays = {
2714 descr.id: array for descr, array in zip(self.outputs, test_output_arrays)
2715 }
2716 for rep_tol in self.config.bioimageio.reproducibility_tolerance:
2717 if not rep_tol.absolute_tolerance:
2718 continue
2720 if rep_tol.output_ids:
2721 out_arrays = {
2722 oid: a
2723 for oid, a in output_arrays.items()
2724 if oid in rep_tol.output_ids
2725 }
2726 else:
2727 out_arrays = output_arrays
2729 for out_id, array in out_arrays.items():
2730 if rep_tol.absolute_tolerance > (max_test_value := array.max()) * 0.01:
2731 raise ValueError(
2732 "config.bioimageio.reproducibility_tolerance.absolute_tolerance="
2733 + f"{rep_tol.absolute_tolerance} > 0.01*{max_test_value}"
2734 + f" (1% of the maximum value of the test tensor '{out_id}')"
2735 )
2737 return self
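# Worked example for the 1% rule enforced above (hypothetical numbers): with the default
# absolute_tolerance of 1e-4, a test output whose maximum value is 0.005 is rejected,
# because 0.01 * 0.005 = 5e-5 < 1e-4; the test output's maximum therefore has to be at
# least 0.01 for the default tolerance to be accepted.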
2739 @model_validator(mode="after")
2740 def _validate_tensor_references_in_proc_kwargs(self, info: ValidationInfo) -> Self:
2741 ipt_refs = {t.id for t in self.inputs}
2742 out_refs = {t.id for t in self.outputs}
2743 for ipt in self.inputs:
2744 for p in ipt.preprocessing:
2745 ref = p.kwargs.get("reference_tensor")
2746 if ref is None:
2747 continue
2748 if ref not in ipt_refs:
2749 raise ValueError(
2750 f"`reference_tensor` '{ref}' not found. Valid input tensor"
2751 + f" references are: {ipt_refs}."
2752 )
2754 for out in self.outputs:
2755 for p in out.postprocessing:
2756 ref = p.kwargs.get("reference_tensor")
2757 if ref is None:
2758 continue
2760 if ref not in ipt_refs and ref not in out_refs:
2761 raise ValueError(
2762 f"`reference_tensor` '{ref}' not found. Valid tensor references"
2763 + f" are: {ipt_refs | out_refs}."
2764 )
2766 return self
2768 # TODO: use validate funcs in validate_test_tensors
2769 # def validate_inputs(self, input_tensors: Mapping[TensorId, NDArray[Any]]) -> Mapping[TensorId, NDArray[Any]]:
2771 name: Annotated[
2772 Annotated[
2773 str, RestrictCharacters(string.ascii_letters + string.digits + "_+- ()")
2774 ],
2775 MinLen(5),
2776 MaxLen(128),
2777 warn(MaxLen(64), "Name longer than 64 characters.", INFO),
2778 ]
2779 """A human-readable name of this model.
2780 It should be no longer than 64 characters
2781 and may only contain letters, numbers, underscores, minus signs, parentheses, and spaces.
2782 We recommend choosing a name that refers to the model's task and image modality.
2783 """
2785 outputs: NotEmpty[Sequence[OutputTensorDescr]]
2786 """Describes the output tensors."""
2788 @field_validator("outputs", mode="after")
2789 @classmethod
2790 def _validate_tensor_ids(
2791 cls, outputs: Sequence[OutputTensorDescr], info: ValidationInfo
2792 ) -> Sequence[OutputTensorDescr]:
2793 tensor_ids = [
2794 t.id for t in info.data.get("inputs", []) + info.data.get("outputs", [])
2795 ]
2796 duplicate_tensor_ids: List[str] = []
2797 seen: Set[str] = set()
2798 for t in tensor_ids:
2799 if t in seen:
2800 duplicate_tensor_ids.append(t)
2802 seen.add(t)
2804 if duplicate_tensor_ids:
2805 raise ValueError(f"Duplicate tensor ids: {duplicate_tensor_ids}")
2807 return outputs
2809 @staticmethod
2810 def _get_axes_with_parameterized_size(
2811 io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]],
2812 ):
2813 return {
2814 f"{t.id}.{a.id}": (t, a, a.size)
2815 for t in io
2816 for a in t.axes
2817 if not isinstance(a, BatchAxis) and isinstance(a.size, ParameterizedSize)
2818 }
2820 @staticmethod
2821 def _get_axes_with_independent_size(
2822 io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]],
2823 ):
2824 return {
2825 (t.id, a.id): (t, a, a.size)
2826 for t in io
2827 for a in t.axes
2828 if not isinstance(a, BatchAxis)
2829 and isinstance(a.size, (int, ParameterizedSize))
2830 }
2832 @field_validator("outputs", mode="after")
2833 @classmethod
2834 def _validate_output_axes(
2835 cls, outputs: List[OutputTensorDescr], info: ValidationInfo
2836 ) -> List[OutputTensorDescr]:
2837 input_size_refs = cls._get_axes_with_independent_size(
2838 info.data.get("inputs", [])
2839 )
2840 output_size_refs = cls._get_axes_with_independent_size(outputs)
2842 for i, out in enumerate(outputs):
2843 valid_independent_refs: Dict[
2844 Tuple[TensorId, AxisId],
2845 Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
2846 ] = {
2847 **{
2848 (out.id, a.id): (out, a, a.size)
2849 for a in out.axes
2850 if not isinstance(a, BatchAxis)
2851 and isinstance(a.size, (int, ParameterizedSize))
2852 },
2853 **input_size_refs,
2854 **output_size_refs,
2855 }
2856 for a, ax in enumerate(out.axes):
2857 cls._validate_axis(
2858 "outputs",
2859 i,
2860 out.id,
2861 a,
2862 ax,
2863 valid_independent_refs=valid_independent_refs,
2864 )
2866 return outputs
2868 packaged_by: List[Author] = Field(
2869 default_factory=cast(Callable[[], List[Author]], list)
2870 )
2871 """The persons that have packaged and uploaded this model.
2872 Only required if those persons differ from the `authors`."""
2874 parent: Optional[LinkedModel] = None
2875 """The model from which this model is derived, e.g. by fine-tuning the weights."""
2877 @model_validator(mode="after")
2878 def _validate_parent_is_not_self(self) -> Self:
2879 if self.parent is not None and self.parent.id == self.id:
2880 raise ValueError("A model description may not reference itself as parent.")
2882 return self
2884 run_mode: Annotated[
2885 Optional[RunMode],
2886 warn(None, "Run mode '{value}' has limited support across consumer softwares."),
2887 ] = None
2888 """Custom run mode for this model: for more complex prediction procedures like test time
2889 data augmentation that currently cannot be expressed in the specification.
2890 No standard run modes are defined yet."""
2892 timestamp: Datetime = Field(default_factory=Datetime.now)
2893 """Timestamp in [ISO 8601](#https://en.wikipedia.org/wiki/ISO_8601) format
2894 with a few restrictions listed [here](https://docs.python.org/3/library/datetime.html#datetime.datetime.fromisoformat).
2895 (In Python a datetime object is valid, too)."""
2897 training_data: Annotated[
2898 Union[None, LinkedDataset, DatasetDescr, DatasetDescr02],
2899 Field(union_mode="left_to_right"),
2900 ] = None
2901 """The dataset used to train this model"""
2903 weights: Annotated[WeightsDescr, WrapSerializer(package_weights)]
2904 """The weights for this model.
2905 Weights can be given for different formats, but should otherwise be equivalent.
2906 The available weight formats determine which consumers can use this model."""
2908 config: Config = Field(default_factory=Config)
2910 @model_validator(mode="after")
2911 def _add_default_cover(self) -> Self:
2912 if not get_validation_context().perform_io_checks or self.covers:
2913 return self
2915 try:
2916 generated_covers = generate_covers(
2917 [(t, load_array(t.test_tensor)) for t in self.inputs],
2918 [(t, load_array(t.test_tensor)) for t in self.outputs],
2919 )
2920 except Exception as e:
2921 issue_warning(
2922 "Failed to generate cover image(s): {e}",
2923 value=self.covers,
2924 msg_context=dict(e=e),
2925 field="covers",
2926 )
2927 else:
2928 self.covers.extend(generated_covers)
2930 return self
2932 def get_input_test_arrays(self) -> List[NDArray[Any]]:
2933 data = [load_array(ipt.test_tensor) for ipt in self.inputs]
2934 assert all(isinstance(d, np.ndarray) for d in data)
2935 return data
2937 def get_output_test_arrays(self) -> List[NDArray[Any]]:
2938 data = [load_array(out.test_tensor) for out in self.outputs]
2939 assert all(isinstance(d, np.ndarray) for d in data)
2940 return data
2942 @staticmethod
2943 def get_batch_size(tensor_sizes: Mapping[TensorId, Mapping[AxisId, int]]) -> int:
2944 batch_size = 1
2945 tensor_with_batchsize: Optional[TensorId] = None
2946 for tid in tensor_sizes:
2947 for aid, s in tensor_sizes[tid].items():
2948 if aid != BATCH_AXIS_ID or s == 1 or s == batch_size:
2949 continue
2951 if batch_size != 1:
2952 assert tensor_with_batchsize is not None
2953 raise ValueError(
2954 f"batch size mismatch for tensors '{tensor_with_batchsize}' ({batch_size}) and '{tid}' ({s})"
2955 )
2957 batch_size = s
2958 tensor_with_batchsize = tid
2960 return batch_size
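# Hedged usage sketch (hypothetical tensor/axis ids; assumes BATCH_AXIS_ID == AxisId("batch")):
# >>> ModelDescr.get_batch_size({TensorId("raw"): {AxisId("batch"): 2, AxisId("x"): 128}})
# 2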
2962 def get_output_tensor_sizes(
2963 self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]
2964 ) -> Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]:
2965 """Returns the tensor output sizes for given **input_sizes**.
2966 The returned output sizes are exact only if **input_sizes** describes a valid input shape;
2967 otherwise they may be larger than the actual (valid) output sizes."""
2968 batch_size = self.get_batch_size(input_sizes)
2969 ns = self.get_ns(input_sizes)
2971 tensor_sizes = self.get_tensor_sizes(ns, batch_size=batch_size)
2972 return tensor_sizes.outputs
2974 def get_ns(self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]):
2975 """get parameter `n` for each parameterized axis
2976 such that the valid input size is >= the given input size"""
2977 ret: Dict[Tuple[TensorId, AxisId], ParameterizedSize_N] = {}
2978 axes = {t.id: {a.id: a for a in t.axes} for t in self.inputs}
2979 for tid in input_sizes:
2980 for aid, s in input_sizes[tid].items():
2981 size_descr = axes[tid][aid].size
2982 if isinstance(size_descr, ParameterizedSize):
2983 ret[(tid, aid)] = size_descr.get_n(s)
2984 elif size_descr is None or isinstance(size_descr, (int, SizeReference)):
2985 pass
2986 else:
2987 assert_never(size_descr)
2989 return ret
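# Worked example (hypothetical axis): for a parameterized axis with `min=64, step=32` and a
# requested input size of 100, the smallest `n` with 64 + n * 32 >= 100 is n = 2
# (valid size 128), assuming `ParameterizedSize.get_n` rounds up as the docstring above implies.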
2991 def get_tensor_sizes(
2992 self, ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], batch_size: int
2993 ) -> _TensorSizes:
2994 axis_sizes = self.get_axis_sizes(ns, batch_size=batch_size)
2995 return _TensorSizes(
2996 {
2997 t: {
2998 aa: axis_sizes.inputs[(tt, aa)]
2999 for tt, aa in axis_sizes.inputs
3000 if tt == t
3001 }
3002 for t in {tt for tt, _ in axis_sizes.inputs}
3003 },
3004 {
3005 t: {
3006 aa: axis_sizes.outputs[(tt, aa)]
3007 for tt, aa in axis_sizes.outputs
3008 if tt == t
3009 }
3010 for t in {tt for tt, _ in axis_sizes.outputs}
3011 },
3012 )
3014 def get_axis_sizes(
3015 self,
3016 ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N],
3017 batch_size: Optional[int] = None,
3018 *,
3019 max_input_shape: Optional[Mapping[Tuple[TensorId, AxisId], int]] = None,
3020 ) -> _AxisSizes:
3021 """Determine input and output block shape for scale factors **ns**
3022 of parameterized input sizes.
3024 Args:
3025 ns: Scale factor `n` for each axis (keyed by (tensor_id, axis_id))
3026 that is parameterized as `size = min + n * step`.
3027 batch_size: The desired size of the batch dimension.
3028 If given, **batch_size** overrides any batch size present in
3029 **max_input_shape**. Default: 1.
3030 max_input_shape: Limits the derived block shapes.
3031 For each axis whose input size, parameterized by `n`, would exceed
3032 **max_input_shape**, `n` is reduced to the smallest value for which the
3033 parameterized size still covers **max_input_shape**.
3034 Use this for small input samples or large values of **ns**,
3035 or simply whenever you know the full input shape.
3037 Returns:
3038 Resolved axis sizes for model inputs and outputs.
3039 """
3040 max_input_shape = max_input_shape or {}
3041 if batch_size is None:
3042 for (_t_id, a_id), s in max_input_shape.items():
3043 if a_id == BATCH_AXIS_ID:
3044 batch_size = s
3045 break
3046 else:
3047 batch_size = 1
3049 all_axes = {
3050 t.id: {a.id: a for a in t.axes} for t in chain(self.inputs, self.outputs)
3051 }
3053 inputs: Dict[Tuple[TensorId, AxisId], int] = {}
3054 outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]] = {}
3056 def get_axis_size(a: Union[InputAxis, OutputAxis]):
3057 if isinstance(a, BatchAxis):
3058 if (t_descr.id, a.id) in ns:
3059 logger.warning(
3060 "Ignoring unexpected size increment factor (n) for batch axis"
3061 + " of tensor '{}'.",
3062 t_descr.id,
3063 )
3064 return batch_size
3065 elif isinstance(a.size, int):
3066 if (t_descr.id, a.id) in ns:
3067 logger.warning(
3068 "Ignoring unexpected size increment factor (n) for fixed size"
3069 + " axis '{}' of tensor '{}'.",
3070 a.id,
3071 t_descr.id,
3072 )
3073 return a.size
3074 elif isinstance(a.size, ParameterizedSize):
3075 if (t_descr.id, a.id) not in ns:
3076 raise ValueError(
3077 "Size increment factor (n) missing for parametrized axis"
3078 + f" '{a.id}' of tensor '{t_descr.id}'."
3079 )
3080 n = ns[(t_descr.id, a.id)]
3081 s_max = max_input_shape.get((t_descr.id, a.id))
3082 if s_max is not None:
3083 n = min(n, a.size.get_n(s_max))
3085 return a.size.get_size(n)
3087 elif isinstance(a.size, SizeReference):
3088 if (t_descr.id, a.id) in ns:
3089 logger.warning(
3090 "Ignoring unexpected size increment factor (n) for axis '{}'"
3091 + " of tensor '{}' with size reference.",
3092 a.id,
3093 t_descr.id,
3094 )
3095 assert not isinstance(a, BatchAxis)
3096 ref_axis = all_axes[a.size.tensor_id][a.size.axis_id]
3097 assert not isinstance(ref_axis, BatchAxis)
3098 ref_key = (a.size.tensor_id, a.size.axis_id)
3099 ref_size = inputs.get(ref_key, outputs.get(ref_key))
3100 assert ref_size is not None, ref_key
3101 assert not isinstance(ref_size, _DataDepSize), ref_key
3102 return a.size.get_size(
3103 axis=a,
3104 ref_axis=ref_axis,
3105 ref_size=ref_size,
3106 )
3107 elif isinstance(a.size, DataDependentSize):
3108 if (t_descr.id, a.id) in ns:
3109 logger.warning(
3110 "Ignoring unexpected increment factor (n) for data dependent"
3111 + " size axis '{}' of tensor '{}'.",
3112 a.id,
3113 t_descr.id,
3114 )
3115 return _DataDepSize(a.size.min, a.size.max)
3116 else:
3117 assert_never(a.size)
3119 # first resolve all input axis sizes except those given as a `SizeReference`
3120 for t_descr in self.inputs:
3121 for a in t_descr.axes:
3122 if not isinstance(a.size, SizeReference):
3123 s = get_axis_size(a)
3124 assert not isinstance(s, _DataDepSize)
3125 inputs[t_descr.id, a.id] = s
3127 # resolve all other input axis sizes
3128 for t_descr in self.inputs:
3129 for a in t_descr.axes:
3130 if isinstance(a.size, SizeReference):
3131 s = get_axis_size(a)
3132 assert not isinstance(s, _DataDepSize)
3133 inputs[t_descr.id, a.id] = s
3135 # resolve all output axis sizes
3136 for t_descr in self.outputs:
3137 for a in t_descr.axes:
3138 assert not isinstance(a.size, ParameterizedSize)
3139 s = get_axis_size(a)
3140 outputs[t_descr.id, a.id] = s
3142 return _AxisSizes(inputs=inputs, outputs=outputs)
3144 @model_validator(mode="before")
3145 @classmethod
3146 def _convert(cls, data: Dict[str, Any]) -> Dict[str, Any]:
3147 cls.convert_from_old_format_wo_validation(data)
3148 return data
3150 @classmethod
3151 def convert_from_old_format_wo_validation(cls, data: Dict[str, Any]) -> None:
3152 """Convert metadata following an older format version to this classes' format
3153 without validating the result.
3154 """
3155 if (
3156 data.get("type") == "model"
3157 and isinstance(fv := data.get("format_version"), str)
3158 and fv.count(".") == 2
3159 ):
3160 fv_parts = fv.split(".")
3161 if any(not p.isdigit() for p in fv_parts):
3162 return
3164 fv_tuple = tuple(map(int, fv_parts))
3166 assert cls.implemented_format_version_tuple[0:2] == (0, 5)
3167 if fv_tuple[:2] in ((0, 3), (0, 4)):
3168 m04 = _ModelDescr_v0_4.load(data)
3169 if isinstance(m04, InvalidDescr):
3170 try:
3171 updated = _model_conv.convert_as_dict(
3172 m04 # pyright: ignore[reportArgumentType]
3173 )
3174 except Exception as e:
3175 logger.error(
3176 "Failed to convert from invalid model 0.4 description."
3177 + f"\nerror: {e}"
3178 + "\nProceeding with model 0.5 validation without conversion."
3179 )
3180 updated = None
3181 else:
3182 updated = _model_conv.convert_as_dict(m04)
3184 if updated is not None:
3185 data.clear()
3186 data.update(updated)
3188 elif fv_tuple[:2] == (0, 5):
3189 # bump patch version
3190 data["format_version"] = cls.implemented_format_version
3193class _ModelConv(Converter[_ModelDescr_v0_4, ModelDescr]):
3194 def _convert(
3195 self, src: _ModelDescr_v0_4, tgt: "type[ModelDescr] | type[dict[str, Any]]"
3196 ) -> "ModelDescr | dict[str, Any]":
3197 name = "".join(
3198 c if c in string.ascii_letters + string.digits + "_+- ()" else " "
3199 for c in src.name
3200 )
3202 def conv_authors(auths: Optional[Sequence[_Author_v0_4]]):
3203 conv = (
3204 _author_conv.convert if TYPE_CHECKING else _author_conv.convert_as_dict
3205 )
3206 return None if auths is None else [conv(a) for a in auths]
3208 if TYPE_CHECKING:
3209 arch_file_conv = _arch_file_conv.convert
3210 arch_lib_conv = _arch_lib_conv.convert
3211 else:
3212 arch_file_conv = _arch_file_conv.convert_as_dict
3213 arch_lib_conv = _arch_lib_conv.convert_as_dict
3215 input_size_refs = {
3216 ipt.name: {
3217 a: s
3218 for a, s in zip(
3219 ipt.axes,
3220 (
3221 ipt.shape.min
3222 if isinstance(ipt.shape, _ParameterizedInputShape_v0_4)
3223 else ipt.shape
3224 ),
3225 )
3226 }
3227 for ipt in src.inputs
3228 if ipt.shape
3229 }
3230 output_size_refs = {
3231 **{
3232 out.name: {a: s for a, s in zip(out.axes, out.shape)}
3233 for out in src.outputs
3234 if not isinstance(out.shape, _ImplicitOutputShape_v0_4)
3235 },
3236 **input_size_refs,
3237 }
3239 return tgt(
3240 attachments=(
3241 []
3242 if src.attachments is None
3243 else [FileDescr(source=f) for f in src.attachments.files]
3244 ),
3245 authors=[
3246 _author_conv.convert_as_dict(a) for a in src.authors
3247 ], # pyright: ignore[reportArgumentType]
3248 cite=[
3249 {"text": c.text, "doi": c.doi, "url": c.url} for c in src.cite
3250 ], # pyright: ignore[reportArgumentType]
3251 config=src.config, # pyright: ignore[reportArgumentType]
3252 covers=src.covers,
3253 description=src.description,
3254 documentation=src.documentation,
3255 format_version="0.5.4",
3256 git_repo=src.git_repo, # pyright: ignore[reportArgumentType]
3257 icon=src.icon,
3258 id=None if src.id is None else ModelId(src.id),
3259 id_emoji=src.id_emoji,
3260 license=src.license, # type: ignore
3261 links=src.links,
3262 maintainers=[
3263 _maintainer_conv.convert_as_dict(m) for m in src.maintainers
3264 ], # pyright: ignore[reportArgumentType]
3265 name=name,
3266 tags=src.tags,
3267 type=src.type,
3268 uploader=src.uploader,
3269 version=src.version,
3270 inputs=[ # pyright: ignore[reportArgumentType]
3271 _input_tensor_conv.convert_as_dict(ipt, tt, st, input_size_refs)
3272 for ipt, tt, st, in zip(
3273 src.inputs,
3274 src.test_inputs,
3275 src.sample_inputs or [None] * len(src.test_inputs),
3276 )
3277 ],
3278 outputs=[ # pyright: ignore[reportArgumentType]
3279 _output_tensor_conv.convert_as_dict(out, tt, st, output_size_refs)
3280 for out, tt, st, in zip(
3281 src.outputs,
3282 src.test_outputs,
3283 src.sample_outputs or [None] * len(src.test_outputs),
3284 )
3285 ],
3286 parent=(
3287 None
3288 if src.parent is None
3289 else LinkedModel(
3290 id=ModelId(
3291 str(src.parent.id)
3292 + (
3293 ""
3294 if src.parent.version_number is None
3295 else f"/{src.parent.version_number}"
3296 )
3297 )
3298 )
3299 ),
3300 training_data=(
3301 None
3302 if src.training_data is None
3303 else (
3304 LinkedDataset(
3305 id=DatasetId(
3306 str(src.training_data.id)
3307 + (
3308 ""
3309 if src.training_data.version_number is None
3310 else f"/{src.training_data.version_number}"
3311 )
3312 )
3313 )
3314 if isinstance(src.training_data, LinkedDataset02)
3315 else src.training_data
3316 )
3317 ),
3318 packaged_by=[
3319 _author_conv.convert_as_dict(a) for a in src.packaged_by
3320 ], # pyright: ignore[reportArgumentType]
3321 run_mode=src.run_mode,
3322 timestamp=src.timestamp,
3323 weights=(WeightsDescr if TYPE_CHECKING else dict)(
3324 keras_hdf5=(w := src.weights.keras_hdf5)
3325 and (KerasHdf5WeightsDescr if TYPE_CHECKING else dict)(
3326 authors=conv_authors(w.authors),
3327 source=w.source,
3328 tensorflow_version=w.tensorflow_version or Version("1.15"),
3329 parent=w.parent,
3330 ),
3331 onnx=(w := src.weights.onnx)
3332 and (OnnxWeightsDescr if TYPE_CHECKING else dict)(
3333 source=w.source,
3334 authors=conv_authors(w.authors),
3335 parent=w.parent,
3336 opset_version=w.opset_version or 15,
3337 ),
3338 pytorch_state_dict=(w := src.weights.pytorch_state_dict)
3339 and (PytorchStateDictWeightsDescr if TYPE_CHECKING else dict)(
3340 source=w.source,
3341 authors=conv_authors(w.authors),
3342 parent=w.parent,
3343 architecture=(
3344 arch_file_conv(
3345 w.architecture,
3346 w.architecture_sha256,
3347 w.kwargs,
3348 )
3349 if isinstance(w.architecture, _CallableFromFile_v0_4)
3350 else arch_lib_conv(w.architecture, w.kwargs)
3351 ),
3352 pytorch_version=w.pytorch_version or Version("1.10"),
3353 dependencies=(
3354 None
3355 if w.dependencies is None
3356 else (FileDescr if TYPE_CHECKING else dict)(
3357 source=cast(
3358 FileSource,
3359 str(deps := w.dependencies)[
3360 (
3361 len("conda:")
3362 if str(deps).startswith("conda:")
3363 else 0
3364 ) :
3365 ],
3366 )
3367 )
3368 ),
3369 ),
3370 tensorflow_js=(w := src.weights.tensorflow_js)
3371 and (TensorflowJsWeightsDescr if TYPE_CHECKING else dict)(
3372 source=w.source,
3373 authors=conv_authors(w.authors),
3374 parent=w.parent,
3375 tensorflow_version=w.tensorflow_version or Version("1.15"),
3376 ),
3377 tensorflow_saved_model_bundle=(
3378 w := src.weights.tensorflow_saved_model_bundle
3379 )
3380 and (TensorflowSavedModelBundleWeightsDescr if TYPE_CHECKING else dict)(
3381 authors=conv_authors(w.authors),
3382 parent=w.parent,
3383 source=w.source,
3384 tensorflow_version=w.tensorflow_version or Version("1.15"),
3385 dependencies=(
3386 None
3387 if w.dependencies is None
3388 else (FileDescr if TYPE_CHECKING else dict)(
3389 source=cast(
3390 FileSource,
3391 (
3392 str(w.dependencies)[len("conda:") :]
3393 if str(w.dependencies).startswith("conda:")
3394 else str(w.dependencies)
3395 ),
3396 )
3397 )
3398 ),
3399 ),
3400 torchscript=(w := src.weights.torchscript)
3401 and (TorchscriptWeightsDescr if TYPE_CHECKING else dict)(
3402 source=w.source,
3403 authors=conv_authors(w.authors),
3404 parent=w.parent,
3405 pytorch_version=w.pytorch_version or Version("1.10"),
3406 ),
3407 ),
3408 )
3411_model_conv = _ModelConv(_ModelDescr_v0_4, ModelDescr)
3414 # TODO: create better cover images for 3d data and non-image outputs
3415def generate_covers(
3416 inputs: Sequence[Tuple[InputTensorDescr, NDArray[Any]]],
3417 outputs: Sequence[Tuple[OutputTensorDescr, NDArray[Any]]],
3418) -> List[Path]:
3419 def squeeze(
3420 data: NDArray[Any], axes: Sequence[AnyAxis]
3421 ) -> Tuple[NDArray[Any], List[AnyAxis]]:
3422 """apply numpy.ndarray.squeeze while keeping track of the axis descriptions remaining"""
3423 if data.ndim != len(axes):
3424 raise ValueError(
3425 f"tensor shape {data.shape} does not match described axes"
3426 + f" {[a.id for a in axes]}"
3427 )
3429 axes = [deepcopy(a) for a, s in zip(axes, data.shape) if s != 1]
3430 return data.squeeze(), axes
3432 def normalize(
3433 data: NDArray[Any], axis: Optional[Tuple[int, ...]], eps: float = 1e-7
3434 ) -> NDArray[np.float32]:
3435 data = data.astype("float32")
3436 data -= data.min(axis=axis, keepdims=True)
3437 data /= data.max(axis=axis, keepdims=True) + eps
3438 return data
3440 def to_2d_image(data: NDArray[Any], axes: Sequence[AnyAxis]):
3441 original_shape = data.shape
3442 data, axes = squeeze(data, axes)
3444 # take a slice from any batch or index axis if needed,
3445 # convert the first channel axis to RGB and take a slice from any additional channel axes
3446 slices: Tuple[slice, ...] = ()
3447 ndim = data.ndim
3448 ndim_need = 3 if any(isinstance(a, ChannelAxis) for a in axes) else 2
3449 has_c_axis = False
3450 for i, a in enumerate(axes):
3451 s = data.shape[i]
3452 assert s > 1
3453 if (
3454 isinstance(a, (BatchAxis, IndexInputAxis, IndexOutputAxis))
3455 and ndim > ndim_need
3456 ):
3457 data = data[slices + (slice(s // 2 - 1, s // 2),)]
3458 ndim -= 1
3459 elif isinstance(a, ChannelAxis):
3460 if has_c_axis:
3461 # second channel axis
3462 data = data[slices + (slice(0, 1),)]
3463 ndim -= 1
3464 else:
3465 has_c_axis = True
3466 if s == 2:
3467 # visualize two channels with cyan and magenta
3468 data = np.concatenate(
3469 [
3470 data[slices + (slice(1, 2),)],
3471 data[slices + (slice(0, 1),)],
3472 (
3473 data[slices + (slice(0, 1),)]
3474 + data[slices + (slice(1, 2),)]
3475 )
3476 / 2, # TODO: take maximum instead?
3477 ],
3478 axis=i,
3479 )
3480 elif data.shape[i] == 3:
3481 pass # visualize 3 channels as RGB
3482 else:
3483 # visualize first 3 channels as RGB
3484 data = data[slices + (slice(3),)]
3486 assert data.shape[i] == 3
3488 slices += (slice(None),)
3490 data, axes = squeeze(data, axes)
3491 assert len(axes) == ndim
3492 # take slice from z axis if needed
3493 slices = ()
3494 if ndim > ndim_need:
3495 for i, a in enumerate(axes):
3496 s = data.shape[i]
3497 if a.id == AxisId("z"):
3498 data = data[slices + (slice(s // 2 - 1, s // 2),)]
3499 data, axes = squeeze(data, axes)
3500 ndim -= 1
3501 break
3503 slices += (slice(None),)
3505 # take slice from any space or time axis
3506 slices = ()
3508 for i, a in enumerate(axes):
3509 if ndim <= ndim_need:
3510 break
3512 s = data.shape[i]
3513 assert s > 1
3514 if isinstance(
3515 a, (SpaceInputAxis, SpaceOutputAxis, TimeInputAxis, TimeOutputAxis)
3516 ):
3517 data = data[slices + (slice(s // 2 - 1, s // 2),)]
3518 ndim -= 1
3520 slices += (slice(None),)
3522 del slices
3523 data, axes = squeeze(data, axes)
3524 assert len(axes) == ndim
3526 if (has_c_axis and ndim != 3) or (not has_c_axis and ndim != 2):
3527 raise ValueError(
3528 f"Failed to construct cover image from shape {original_shape}"
3529 )
3531 if not has_c_axis:
3532 assert ndim == 2
3533 data = np.repeat(data[:, :, None], 3, axis=2)
3534 axes.append(ChannelAxis(channel_names=list(map(Identifier, "RGB"))))
3535 ndim += 1
3537 assert ndim == 3
3539 # transpose axis order such that longest axis comes first...
3540 axis_order: List[int] = list(np.argsort(list(data.shape)))
3541 axis_order.reverse()
3542 # ... and channel axis is last
3543 c = [i for i in range(3) if isinstance(axes[i], ChannelAxis)][0]
3544 axis_order.append(axis_order.pop(c))
3545 axes = [axes[ao] for ao in axis_order]
3546 data = data.transpose(axis_order)
3548 # h, w = data.shape[:2]
3549 # if h / w in (1.0 or 2.0):
3550 # pass
3551 # elif h / w < 2:
3552 # TODO: enforce 2:1 or 1:1 aspect ratio for generated cover images
3554 norm_along = (
3555 tuple(i for i, a in enumerate(axes) if a.type in ("space", "time")) or None
3556 )
3557 # normalize the data and map to 8 bit
3558 data = normalize(data, norm_along)
3559 data = (data * 255).astype("uint8")
3561 return data
3563 def create_diagonal_split_image(im0: NDArray[Any], im1: NDArray[Any]):
3564 assert im0.dtype == im1.dtype == np.uint8
3565 assert im0.shape == im1.shape
3566 assert im0.ndim == 3
3567 N, M, C = im0.shape
3568 assert C == 3
3569 out = np.ones((N, M, C), dtype="uint8")
3570 for c in range(C):
3571 outc = np.tril(im0[..., c])
3572 mask = outc == 0
3573 outc[mask] = np.triu(im1[..., c])[mask]
3574 out[..., c] = outc
3576 return out
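# Minimal doctest-style sketch of the helper above (per channel, the lower triangle
# including the diagonal comes from im0, the strict upper triangle from im1; zero-valued
# pixels of im0 also fall through to im1):
# >>> im0 = np.full((2, 2, 3), 10, dtype="uint8")
# >>> im1 = np.full((2, 2, 3), 200, dtype="uint8")
# >>> create_diagonal_split_image(im0, im1)[:, :, 0]
# array([[ 10, 200],
#        [ 10,  10]], dtype=uint8)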
3578 ipt_descr, ipt = inputs[0]
3579 out_descr, out = outputs[0]
3581 ipt_img = to_2d_image(ipt, ipt_descr.axes)
3582 out_img = to_2d_image(out, out_descr.axes)
3584 cover_folder = Path(mkdtemp())
3585 if ipt_img.shape == out_img.shape:
3586 covers = [cover_folder / "cover.png"]
3587 imwrite(covers[0], create_diagonal_split_image(ipt_img, out_img))
3588 else:
3589 covers = [cover_folder / "input.png", cover_folder / "output.png"]
3590 imwrite(covers[0], ipt_img)
3591 imwrite(covers[1], out_img)
3593 return covers