Coverage for bioimageio/spec/model/v0_5.py: 75%
1311 statements
coverage.py v7.8.0, created at 2025-04-02 14:21 +0000
1from __future__ import annotations
3import collections.abc
4import re
5import string
6import warnings
7from abc import ABC
8from copy import deepcopy
9from itertools import chain
10from math import ceil
11from pathlib import Path, PurePosixPath
12from tempfile import mkdtemp
13from typing import (
14 TYPE_CHECKING,
15 Any,
16 ClassVar,
17 Dict,
18 Generic,
19 List,
20 Literal,
21 Mapping,
22 NamedTuple,
23 Optional,
24 Sequence,
25 Set,
26 Tuple,
27 Type,
28 TypeVar,
29 Union,
30 cast,
31)
33import numpy as np
34from annotated_types import Ge, Gt, Interval, MaxLen, MinLen, Predicate
35from imageio.v3 import imread, imwrite # pyright: ignore[reportUnknownVariableType]
36from loguru import logger
37from numpy.typing import NDArray
38from pydantic import (
39 AfterValidator,
40 Discriminator,
41 Field,
42 RootModel,
43 Tag,
44 ValidationInfo,
45 WrapSerializer,
46 field_validator,
47 model_validator,
48)
49from typing_extensions import Annotated, Self, assert_never, get_args
51from .._internal.common_nodes import (
52 InvalidDescr,
53 Node,
54 NodeWithExplicitlySetFields,
55)
56from .._internal.constants import DTYPE_LIMITS
57from .._internal.field_warning import issue_warning, warn
58from .._internal.io import BioimageioYamlContent as BioimageioYamlContent
59from .._internal.io import FileDescr as FileDescr
60from .._internal.io import WithSuffix, YamlValue, download
61from .._internal.io_basics import AbsoluteFilePath as AbsoluteFilePath
62from .._internal.io_basics import Sha256 as Sha256
63from .._internal.io_utils import load_array
64from .._internal.node_converter import Converter
65from .._internal.types import (
66 AbsoluteTolerance,
67 ImportantFileSource,
68 LowerCaseIdentifier,
69 LowerCaseIdentifierAnno,
70 MismatchedElementsPerMillion,
71 RelativeTolerance,
72 SiUnit,
73)
74from .._internal.types import Datetime as Datetime
75from .._internal.types import Identifier as Identifier
76from .._internal.types import NotEmpty as NotEmpty
77from .._internal.url import HttpUrl as HttpUrl
78from .._internal.validation_context import get_validation_context
79from .._internal.validator_annotations import RestrictCharacters
80from .._internal.version_type import Version as Version
81from .._internal.warning_levels import INFO
82from ..dataset.v0_2 import DatasetDescr as DatasetDescr02
83from ..dataset.v0_2 import LinkedDataset as LinkedDataset02
84from ..dataset.v0_3 import DatasetDescr as DatasetDescr
85from ..dataset.v0_3 import DatasetId as DatasetId
86from ..dataset.v0_3 import LinkedDataset as LinkedDataset
87from ..dataset.v0_3 import Uploader as Uploader
88from ..generic.v0_3 import (
89 VALID_COVER_IMAGE_EXTENSIONS as VALID_COVER_IMAGE_EXTENSIONS,
90)
91from ..generic.v0_3 import Author as Author
92from ..generic.v0_3 import BadgeDescr as BadgeDescr
93from ..generic.v0_3 import CiteEntry as CiteEntry
94from ..generic.v0_3 import DeprecatedLicenseId as DeprecatedLicenseId
95from ..generic.v0_3 import (
96 DocumentationSource,
97 GenericModelDescrBase,
98 LinkedResourceBase,
99 _author_conv, # pyright: ignore[reportPrivateUsage]
100 _maintainer_conv, # pyright: ignore[reportPrivateUsage]
101)
102from ..generic.v0_3 import Doi as Doi
103from ..generic.v0_3 import LicenseId as LicenseId
104from ..generic.v0_3 import LinkedResource as LinkedResource
105from ..generic.v0_3 import Maintainer as Maintainer
106from ..generic.v0_3 import OrcidId as OrcidId
107from ..generic.v0_3 import RelativeFilePath as RelativeFilePath
108from ..generic.v0_3 import ResourceId as ResourceId
109from .v0_4 import Author as _Author_v0_4
110from .v0_4 import BinarizeDescr as _BinarizeDescr_v0_4
111from .v0_4 import CallableFromDepencency as CallableFromDepencency
112from .v0_4 import CallableFromDepencency as _CallableFromDepencency_v0_4
113from .v0_4 import CallableFromFile as _CallableFromFile_v0_4
114from .v0_4 import ClipDescr as _ClipDescr_v0_4
115from .v0_4 import ClipKwargs as ClipKwargs
116from .v0_4 import ImplicitOutputShape as _ImplicitOutputShape_v0_4
117from .v0_4 import InputTensorDescr as _InputTensorDescr_v0_4
118from .v0_4 import KnownRunMode as KnownRunMode
119from .v0_4 import ModelDescr as _ModelDescr_v0_4
120from .v0_4 import OutputTensorDescr as _OutputTensorDescr_v0_4
121from .v0_4 import ParameterizedInputShape as _ParameterizedInputShape_v0_4
122from .v0_4 import PostprocessingDescr as _PostprocessingDescr_v0_4
123from .v0_4 import PreprocessingDescr as _PreprocessingDescr_v0_4
124from .v0_4 import ProcessingKwargs as ProcessingKwargs
125from .v0_4 import RunMode as RunMode
126from .v0_4 import ScaleLinearDescr as _ScaleLinearDescr_v0_4
127from .v0_4 import ScaleMeanVarianceDescr as _ScaleMeanVarianceDescr_v0_4
128from .v0_4 import ScaleRangeDescr as _ScaleRangeDescr_v0_4
129from .v0_4 import SigmoidDescr as _SigmoidDescr_v0_4
130from .v0_4 import TensorName as _TensorName_v0_4
131from .v0_4 import WeightsFormat as WeightsFormat
132from .v0_4 import ZeroMeanUnitVarianceDescr as _ZeroMeanUnitVarianceDescr_v0_4
133from .v0_4 import package_weights
135SpaceUnit = Literal[
136 "attometer",
137 "angstrom",
138 "centimeter",
139 "decimeter",
140 "exameter",
141 "femtometer",
142 "foot",
143 "gigameter",
144 "hectometer",
145 "inch",
146 "kilometer",
147 "megameter",
148 "meter",
149 "micrometer",
150 "mile",
151 "millimeter",
152 "nanometer",
153 "parsec",
154 "petameter",
155 "picometer",
156 "terameter",
157 "yard",
158 "yoctometer",
159 "yottameter",
160 "zeptometer",
161 "zettameter",
162]
163"""Space unit compatible to the [OME-Zarr axes specification 0.5](https://ngff.openmicroscopy.org/0.5/#axes-md)"""
165TimeUnit = Literal[
166 "attosecond",
167 "centisecond",
168 "day",
169 "decisecond",
170 "exasecond",
171 "femtosecond",
172 "gigasecond",
173 "hectosecond",
174 "hour",
175 "kilosecond",
176 "megasecond",
177 "microsecond",
178 "millisecond",
179 "minute",
180 "nanosecond",
181 "petasecond",
182 "picosecond",
183 "second",
184 "terasecond",
185 "yoctosecond",
186 "yottasecond",
187 "zeptosecond",
188 "zettasecond",
189]
190"""Time unit compatible to the [OME-Zarr axes specification 0.5](https://ngff.openmicroscopy.org/0.5/#axes-md)"""
192AxisType = Literal["batch", "channel", "index", "time", "space"]
194_AXIS_TYPE_MAP: Mapping[str, AxisType] = {
195 "b": "batch",
196 "t": "time",
197 "i": "index",
198 "c": "channel",
199 "x": "space",
200 "y": "space",
201 "z": "space",
202}
204_AXIS_ID_MAP = {
205 "b": "batch",
206 "t": "time",
207 "i": "index",
208 "c": "channel",
209}
212class TensorId(LowerCaseIdentifier):
213 root_model: ClassVar[Type[RootModel[Any]]] = RootModel[
214 Annotated[LowerCaseIdentifierAnno, MaxLen(32)]
215 ]
218def _normalize_axis_id(a: str):
219 a = str(a)
220 normalized = _AXIS_ID_MAP.get(a, a)
221 if a != normalized:
222 logger.opt(depth=3).warning(
223 "Normalized axis id from '{}' to '{}'.", a, normalized
224 )
225 return normalized
228class AxisId(LowerCaseIdentifier):
229 root_model: ClassVar[Type[RootModel[Any]]] = RootModel[
230 Annotated[
231 LowerCaseIdentifierAnno,
232 MaxLen(16),
233 AfterValidator(_normalize_axis_id),
234 ]
235 ]
238def _is_batch(a: str) -> bool:
239 return str(a) == "batch"
242def _is_not_batch(a: str) -> bool:
243 return not _is_batch(a)
246NonBatchAxisId = Annotated[AxisId, Predicate(_is_not_batch)]
248PostprocessingId = Literal[
249 "binarize",
250 "clip",
251 "ensure_dtype",
252 "fixed_zero_mean_unit_variance",
253 "scale_linear",
254 "scale_mean_variance",
255 "scale_range",
256 "sigmoid",
257 "zero_mean_unit_variance",
258]
259PreprocessingId = Literal[
260 "binarize",
261 "clip",
262 "ensure_dtype",
263 "scale_linear",
264 "sigmoid",
265 "zero_mean_unit_variance",
266 "scale_range",
267]
270SAME_AS_TYPE = "<same as type>"
273ParameterizedSize_N = int
274"""
275Annotates an integer to calculate a concrete axis size from a `ParameterizedSize`.
276"""
279class ParameterizedSize(Node):
280 """Describes a range of valid tensor axis sizes as `size = min + n*step`.
282 - **min** and **step** are given by the model description.
283 - All blocksize parameters n = 0,1,2,... yield a valid `size`.
284 - A greater blocksize parameter n results in a greater **size**.
285 This allows adjusting the axis size more generically.
286 """
288 N: ClassVar[Type[int]] = ParameterizedSize_N
289 """Positive integer to parameterize this axis"""
291 min: Annotated[int, Gt(0)]
292 step: Annotated[int, Gt(0)]
294 def validate_size(self, size: int) -> int:
295 if size < self.min:
296 raise ValueError(f"size {size} < {self.min}")
297 if (size - self.min) % self.step != 0:
298 raise ValueError(
299 f"axis of size {size} is not parameterized by `min + n*step` ="
300 + f" `{self.min} + n*{self.step}`"
301 )
303 return size
305 def get_size(self, n: ParameterizedSize_N) -> int:
306 return self.min + self.step * n
308 def get_n(self, s: int) -> ParameterizedSize_N:
309 """return smallest n parameterizing a size greater or equal than `s`"""
310 return ceil((s - self.min) / self.step)
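# Illustrative sketch (not part of the original module): with min=32 and step=16
# the valid sizes are 32, 48, 64, ..., and `get_n` returns the smallest n whose
# size covers a requested extent.
# >>> p = ParameterizedSize(min=32, step=16)
# >>> p.get_size(2)
# 64
# >>> p.get_n(50)  # smallest n with get_size(n) >= 50
# 2
# >>> p.validate_size(48)
# 48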
313class DataDependentSize(Node):
314 min: Annotated[int, Gt(0)] = 1
315 max: Annotated[Optional[int], Gt(1)] = None
317 @model_validator(mode="after")
318 def _validate_max_gt_min(self):
319 if self.max is not None and self.min >= self.max:
320 raise ValueError(f"expected `min` < `max`, but got {self.min}, {self.max}")
322 return self
324 def validate_size(self, size: int) -> int:
325 if size < self.min:
326 raise ValueError(f"size {size} < {self.min}")
328 if self.max is not None and size > self.max:
329 raise ValueError(f"size {size} > {self.max}")
331 return size
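# Illustrative sketch (not part of the original module): a data dependent axis
# length is only known after inference, but must fall within [min, max].
# >>> DataDependentSize(min=1, max=64).validate_size(10)
# 10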
334class SizeReference(Node):
335 """A tensor axis size (extent in pixels/frames) defined in relation to a reference axis.
337 `axis.size = reference.size * reference.scale / axis.scale + offset`
339 Note:
340 1. The axis and the referenced axis need to have the same unit (or no unit).
341 2. Batch axes may not be referenced.
342 3. Fractions are rounded down.
343 4. If the reference axis is `concatenable` the referencing axis is assumed to be
344 `concatenable` as well with the same block order.
346 Example:
347 An anisotropic input image of w*h=100*49 pixels depicts a physical space of 200*196mm².
348 Let's assume that we want to express the image height h in relation to its width w
349 instead of only accepting input images of exactly 100*49 pixels
350 (for example to express a range of valid image shapes by parameterizing w, see `ParameterizedSize`).
352 >>> w = SpaceInputAxis(id=AxisId("w"), size=100, unit="millimeter", scale=2)
353 >>> h = SpaceInputAxis(
354 ... id=AxisId("h"),
355 ... size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("w"), offset=-1),
356 ... unit="millimeter",
357 ... scale=4,
358 ... )
359 >>> print(h.size.get_size(h, w))
360 49
362 ⇒ h = w * w.scale / h.scale + offset = 100 * 2mm / 4mm - 1 = 49
363 """
365 tensor_id: TensorId
366 """tensor id of the reference axis"""
368 axis_id: AxisId
369 """axis id of the reference axis"""
371 offset: int = 0
373 def get_size(
374 self,
375 axis: Union[
376 ChannelAxis,
377 IndexInputAxis,
378 IndexOutputAxis,
379 TimeInputAxis,
380 SpaceInputAxis,
381 TimeOutputAxis,
382 TimeOutputAxisWithHalo,
383 SpaceOutputAxis,
384 SpaceOutputAxisWithHalo,
385 ],
386 ref_axis: Union[
387 ChannelAxis,
388 IndexInputAxis,
389 IndexOutputAxis,
390 TimeInputAxis,
391 SpaceInputAxis,
392 TimeOutputAxis,
393 TimeOutputAxisWithHalo,
394 SpaceOutputAxis,
395 SpaceOutputAxisWithHalo,
396 ],
397 n: ParameterizedSize_N = 0,
398 ref_size: Optional[int] = None,
399 ):
400 """Compute the concrete size for a given axis and its reference axis.
402 Args:
403 axis: The axis this `SizeReference` is the size of.
404 ref_axis: The reference axis to compute the size from.
405 n: If the **ref_axis** is parameterized (of type `ParameterizedSize`)
406 and no fixed **ref_size** is given,
407 **n** is used to compute the size of the parameterized **ref_axis**.
408 ref_size: Overwrite the reference size instead of deriving it from
409 **ref_axis**
410 (**ref_axis.scale** is still used; any given **n** is ignored).
411 """
412 assert (
413 axis.size == self
414 ), "Given `axis.size` is not defined by this `SizeReference`"
416 assert (
417 ref_axis.id == self.axis_id
418 ), f"Expected `ref_axis.id` to be {self.axis_id}, but got {ref_axis.id}."
420 assert axis.unit == ref_axis.unit, (
421 "`SizeReference` requires `axis` and `ref_axis` to have the same `unit`,"
422 f" but {axis.unit}!={ref_axis.unit}"
423 )
424 if ref_size is None:
425 if isinstance(ref_axis.size, (int, float)):
426 ref_size = ref_axis.size
427 elif isinstance(ref_axis.size, ParameterizedSize):
428 ref_size = ref_axis.size.get_size(n)
429 elif isinstance(ref_axis.size, DataDependentSize):
430 raise ValueError(
431 "Reference axis referenced in `SizeReference` may not be a `DataDependentSize`."
432 )
433 elif isinstance(ref_axis.size, SizeReference):
434 raise ValueError(
435 "Reference axis referenced in `SizeReference` may not be sized by a"
436 + " `SizeReference` itself."
437 )
438 else:
439 assert_never(ref_axis.size)
441 return int(ref_size * ref_axis.scale / axis.scale + self.offset)
443 @staticmethod
444 def _get_unit(
445 axis: Union[
446 ChannelAxis,
447 IndexInputAxis,
448 IndexOutputAxis,
449 TimeInputAxis,
450 SpaceInputAxis,
451 TimeOutputAxis,
452 TimeOutputAxisWithHalo,
453 SpaceOutputAxis,
454 SpaceOutputAxisWithHalo,
455 ],
456 ):
457 return axis.unit
460class AxisBase(NodeWithExplicitlySetFields):
461 id: AxisId
462 """An axis id unique across all axes of one tensor."""
464 description: Annotated[str, MaxLen(128)] = ""
467class WithHalo(Node):
468 halo: Annotated[int, Ge(1)]
469 """The halo should be cropped from the output tensor to avoid boundary effects.
470 It is to be cropped from both sides, i.e. `size_after_crop = size - 2 * halo`.
471 To document a halo that is already cropped by the model use `size.offset` instead."""
473 size: Annotated[
474 SizeReference,
475 Field(
476 examples=[
477 10,
478 SizeReference(
479 tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
480 ).model_dump(mode="json"),
481 ]
482 ),
483 ]
484 """reference to another axis with an optional offset (see `SizeReference`)"""
487BATCH_AXIS_ID = AxisId("batch")
490class BatchAxis(AxisBase):
491 implemented_type: ClassVar[Literal["batch"]] = "batch"
492 if TYPE_CHECKING:
493 type: Literal["batch"] = "batch"
494 else:
495 type: Literal["batch"]
497 id: Annotated[AxisId, Predicate(_is_batch)] = BATCH_AXIS_ID
498 size: Optional[Literal[1]] = None
499 """The batch size may be fixed to 1,
500 otherwise (the default) it may be chosen arbitrarily depending on available memory"""
502 @property
503 def scale(self):
504 return 1.0
506 @property
507 def concatenable(self):
508 return True
510 @property
511 def unit(self):
512 return None
515class ChannelAxis(AxisBase):
516 implemented_type: ClassVar[Literal["channel"]] = "channel"
517 if TYPE_CHECKING:
518 type: Literal["channel"] = "channel"
519 else:
520 type: Literal["channel"]
522 id: NonBatchAxisId = AxisId("channel")
523 channel_names: NotEmpty[List[Identifier]]
525 @property
526 def size(self) -> int:
527 return len(self.channel_names)
529 @property
530 def concatenable(self):
531 return False
533 @property
534 def scale(self) -> float:
535 return 1.0
537 @property
538 def unit(self):
539 return None
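# Illustrative sketch (not part of the original module): the size of a channel
# axis is implied by the number of channel names.
# >>> ChannelAxis(channel_names=[Identifier("r"), Identifier("g"), Identifier("b")]).size
# 3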
542class IndexAxisBase(AxisBase):
543 implemented_type: ClassVar[Literal["index"]] = "index"
544 if TYPE_CHECKING:
545 type: Literal["index"] = "index"
546 else:
547 type: Literal["index"]
549 id: NonBatchAxisId = AxisId("index")
551 @property
552 def scale(self) -> float:
553 return 1.0
555 @property
556 def unit(self):
557 return None
560class _WithInputAxisSize(Node):
561 size: Annotated[
562 Union[Annotated[int, Gt(0)], ParameterizedSize, SizeReference],
563 Field(
564 examples=[
565 10,
566 ParameterizedSize(min=32, step=16).model_dump(mode="json"),
567 SizeReference(
568 tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
569 ).model_dump(mode="json"),
570 ]
571 ),
572 ]
573 """The size/length of this axis can be specified as
574 - fixed integer
575 - parameterized series of valid sizes (`ParameterizedSize`)
576 - reference to another axis with an optional offset (`SizeReference`)
577 """
580class IndexInputAxis(IndexAxisBase, _WithInputAxisSize):
581 concatenable: bool = False
582 """If a model has a `concatenable` input axis, it can be processed blockwise,
583 splitting a longer sample axis into blocks matching its input tensor description.
584 Output axes are concatenable if they have a `SizeReference` to a concatenable
585 input axis.
586 """
589class IndexOutputAxis(IndexAxisBase):
590 size: Annotated[
591 Union[Annotated[int, Gt(0)], SizeReference, DataDependentSize],
592 Field(
593 examples=[
594 10,
595 SizeReference(
596 tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
597 ).model_dump(mode="json"),
598 ]
599 ),
600 ]
601 """The size/length of this axis can be specified as
602 - fixed integer
603 - reference to another axis with an optional offset (`SizeReference`)
604 - data dependent size using `DataDependentSize` (size is only known after model inference)
605 """
608class TimeAxisBase(AxisBase):
609 implemented_type: ClassVar[Literal["time"]] = "time"
610 if TYPE_CHECKING:
611 type: Literal["time"] = "time"
612 else:
613 type: Literal["time"]
615 id: NonBatchAxisId = AxisId("time")
616 unit: Optional[TimeUnit] = None
617 scale: Annotated[float, Gt(0)] = 1.0
620class TimeInputAxis(TimeAxisBase, _WithInputAxisSize):
621 concatenable: bool = False
622 """If a model has a `concatenable` input axis, it can be processed blockwise,
623 splitting a longer sample axis into blocks matching its input tensor description.
624 Output axes are concatenable if they have a `SizeReference` to a concatenable
625 input axis.
626 """
629class SpaceAxisBase(AxisBase):
630 implemented_type: ClassVar[Literal["space"]] = "space"
631 if TYPE_CHECKING:
632 type: Literal["space"] = "space"
633 else:
634 type: Literal["space"]
636 id: Annotated[NonBatchAxisId, Field(examples=["x", "y", "z"])] = AxisId("x")
637 unit: Optional[SpaceUnit] = None
638 scale: Annotated[float, Gt(0)] = 1.0
641class SpaceInputAxis(SpaceAxisBase, _WithInputAxisSize):
642 concatenable: bool = False
643 """If a model has a `concatenable` input axis, it can be processed blockwise,
644 splitting a longer sample axis into blocks matching its input tensor description.
645 Output axes are concatenable if they have a `SizeReference` to a concatenable
646 input axis.
647 """
650INPUT_AXIS_TYPES = (
651 BatchAxis,
652 ChannelAxis,
653 IndexInputAxis,
654 TimeInputAxis,
655 SpaceInputAxis,
656)
657"""intended for isinstance comparisons in py<3.10"""
659_InputAxisUnion = Union[
660 BatchAxis, ChannelAxis, IndexInputAxis, TimeInputAxis, SpaceInputAxis
661]
662InputAxis = Annotated[_InputAxisUnion, Discriminator("type")]
665class _WithOutputAxisSize(Node):
666 size: Annotated[
667 Union[Annotated[int, Gt(0)], SizeReference],
668 Field(
669 examples=[
670 10,
671 SizeReference(
672 tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
673 ).model_dump(mode="json"),
674 ]
675 ),
676 ]
677 """The size/length of this axis can be specified as
678 - fixed integer
679 - reference to another axis with an optional offset (see `SizeReference`)
680 """
683class TimeOutputAxis(TimeAxisBase, _WithOutputAxisSize):
684 pass
687class TimeOutputAxisWithHalo(TimeAxisBase, WithHalo):
688 pass
691def _get_halo_axis_discriminator_value(v: Any) -> Literal["with_halo", "wo_halo"]:
692 if isinstance(v, dict):
693 return "with_halo" if "halo" in v else "wo_halo"
694 else:
695 return "with_halo" if hasattr(v, "halo") else "wo_halo"
698_TimeOutputAxisUnion = Annotated[
699 Union[
700 Annotated[TimeOutputAxis, Tag("wo_halo")],
701 Annotated[TimeOutputAxisWithHalo, Tag("with_halo")],
702 ],
703 Discriminator(_get_halo_axis_discriminator_value),
704]
707class SpaceOutputAxis(SpaceAxisBase, _WithOutputAxisSize):
708 pass
711class SpaceOutputAxisWithHalo(SpaceAxisBase, WithHalo):
712 pass
715_SpaceOutputAxisUnion = Annotated[
716 Union[
717 Annotated[SpaceOutputAxis, Tag("wo_halo")],
718 Annotated[SpaceOutputAxisWithHalo, Tag("with_halo")],
719 ],
720 Discriminator(_get_halo_axis_discriminator_value),
721]
724_OutputAxisUnion = Union[
725 BatchAxis, ChannelAxis, IndexOutputAxis, _TimeOutputAxisUnion, _SpaceOutputAxisUnion
726]
727OutputAxis = Annotated[_OutputAxisUnion, Discriminator("type")]
729OUTPUT_AXIS_TYPES = (
730 BatchAxis,
731 ChannelAxis,
732 IndexOutputAxis,
733 TimeOutputAxis,
734 TimeOutputAxisWithHalo,
735 SpaceOutputAxis,
736 SpaceOutputAxisWithHalo,
737)
738"""intended for isinstance comparisons in py<3.10"""
741AnyAxis = Union[InputAxis, OutputAxis]
743ANY_AXIS_TYPES = INPUT_AXIS_TYPES + OUTPUT_AXIS_TYPES
744"""intended for isinstance comparisons in py<3.10"""
746TVs = Union[
747 NotEmpty[List[int]],
748 NotEmpty[List[float]],
749 NotEmpty[List[bool]],
750 NotEmpty[List[str]],
751]
754NominalOrOrdinalDType = Literal[
755 "float32",
756 "float64",
757 "uint8",
758 "int8",
759 "uint16",
760 "int16",
761 "uint32",
762 "int32",
763 "uint64",
764 "int64",
765 "bool",
766]
769class NominalOrOrdinalDataDescr(Node):
770 values: TVs
771 """A fixed set of nominal or an ascending sequence of ordinal values.
772 In this case `data.type` is required to be an unsigned integer type, e.g. 'uint8'.
773 String `values` are interpreted as labels for tensor values 0, ..., N.
774 Note: as YAML 1.2 does not natively support a "set" datatype,
775 nominal values should be given as a sequence (aka list/array) as well.
776 """
778 type: Annotated[
779 NominalOrOrdinalDType,
780 Field(
781 examples=[
782 "float32",
783 "uint8",
784 "uint16",
785 "int64",
786 "bool",
787 ],
788 ),
789 ] = "uint8"
791 @model_validator(mode="after")
792 def _validate_values_match_type(
793 self,
794 ) -> Self:
795 incompatible: List[Any] = []
796 for v in self.values:
797 if self.type == "bool":
798 if not isinstance(v, bool):
799 incompatible.append(v)
800 elif self.type in DTYPE_LIMITS:
801 if (
802 isinstance(v, (int, float))
803 and (
804 v < DTYPE_LIMITS[self.type].min
805 or v > DTYPE_LIMITS[self.type].max
806 )
807 or (isinstance(v, str) and "uint" not in self.type)
808 or (isinstance(v, float) and "int" in self.type)
809 ):
810 incompatible.append(v)
811 else:
812 incompatible.append(v)
814 if len(incompatible) == 5:
815 incompatible.append("...")
816 break
818 if incompatible:
819 raise ValueError(
820 f"data type '{self.type}' incompatible with values {incompatible}"
821 )
823 return self
825 unit: Optional[Union[Literal["arbitrary unit"], SiUnit]] = None
827 @property
828 def range(self):
829 if isinstance(self.values[0], str):
830 return 0, len(self.values) - 1
831 else:
832 return min(self.values), max(self.values)
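# Illustrative sketch (not part of the original module): string values are
# treated as labels for 0, ..., N-1, so `range` reports the index range.
# >>> NominalOrOrdinalDataDescr(values=["background", "cell"], type="uint8").range
# (0, 1)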
835IntervalOrRatioDType = Literal[
836 "float32",
837 "float64",
838 "uint8",
839 "int8",
840 "uint16",
841 "int16",
842 "uint32",
843 "int32",
844 "uint64",
845 "int64",
846]
849class IntervalOrRatioDataDescr(Node):
850 type: Annotated[ # todo: rename to dtype
851 IntervalOrRatioDType,
852 Field(
853 examples=["float32", "float64", "uint8", "uint16"],
854 ),
855 ] = "float32"
856 range: Tuple[Optional[float], Optional[float]] = (
857 None,
858 None,
859 )
860 """Tuple `(minimum, maximum)` specifying the allowed range of the data in this tensor.
861 `None` corresponds to min/max of what can be expressed by **type**."""
862 unit: Union[Literal["arbitrary unit"], SiUnit] = "arbitrary unit"
863 scale: float = 1.0
864 """Scale for data on an interval (or ratio) scale."""
865 offset: Optional[float] = None
866 """Offset for data on a ratio scale."""
869TensorDataDescr = Union[NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr]
872class ProcessingDescrBase(NodeWithExplicitlySetFields, ABC):
873 """processing base class"""
876class BinarizeKwargs(ProcessingKwargs):
877 """key word arguments for `BinarizeDescr`"""
879 threshold: float
880 """The fixed threshold"""
883class BinarizeAlongAxisKwargs(ProcessingKwargs):
884 """key word arguments for `BinarizeDescr`"""
886 threshold: NotEmpty[List[float]]
887 """The fixed threshold values along `axis`"""
889 axis: Annotated[NonBatchAxisId, Field(examples=["channel"])]
890 """The `threshold` axis"""
893class BinarizeDescr(ProcessingDescrBase):
894 """Binarize the tensor with a fixed threshold.
896 Values above `BinarizeKwargs.threshold`/`BinarizeAlongAxisKwargs.threshold`
897 will be set to one, values below the threshold to zero.
899 Examples:
900 - in YAML
901 ```yaml
902 postprocessing:
903 - id: binarize
904 kwargs:
905 axis: 'channel'
906 threshold: [0.25, 0.5, 0.75]
907 ```
908 - in Python:
909 >>> postprocessing = [BinarizeDescr(
910 ... kwargs=BinarizeAlongAxisKwargs(
911 ... axis=AxisId('channel'),
912 ... threshold=[0.25, 0.5, 0.75],
913 ... )
914 ... )]
915 """
917 implemented_id: ClassVar[Literal["binarize"]] = "binarize"
918 if TYPE_CHECKING:
919 id: Literal["binarize"] = "binarize"
920 else:
921 id: Literal["binarize"]
922 kwargs: Union[BinarizeKwargs, BinarizeAlongAxisKwargs]
925class ClipDescr(ProcessingDescrBase):
926 """Set tensor values below min to min and above max to max.
928 See `ScaleRangeDescr` for examples.
929 """
931 implemented_id: ClassVar[Literal["clip"]] = "clip"
932 if TYPE_CHECKING:
933 id: Literal["clip"] = "clip"
934 else:
935 id: Literal["clip"]
937 kwargs: ClipKwargs
940class EnsureDtypeKwargs(ProcessingKwargs):
941 """key word arguments for `EnsureDtypeDescr`"""
943 dtype: Literal[
944 "float32",
945 "float64",
946 "uint8",
947 "int8",
948 "uint16",
949 "int16",
950 "uint32",
951 "int32",
952 "uint64",
953 "int64",
954 "bool",
955 ]
958class EnsureDtypeDescr(ProcessingDescrBase):
959 """Cast the tensor data type to `EnsureDtypeKwargs.dtype` (if not matching).
961 This can for example be used to ensure the inner neural network model gets a
962 different input tensor data type than the fully described bioimage.io model does.
964 Examples:
965 The described bioimage.io model (incl. preprocessing) accepts any
966 float32-compatible tensor, normalizes it with percentiles and clipping and then
967 casts it to uint8, which is what the neural network in this example expects.
968 - in YAML
969 ```yaml
970 inputs:
971 - data:
972 type: float32 # described bioimage.io model is compatible with any float32 input tensor
973 preprocessing:
974 - id: scale_range
975 kwargs:
976 axes: ['y', 'x']
977 max_percentile: 99.8
978 min_percentile: 5.0
979 - id: clip
980 kwargs:
981 min: 0.0
982 max: 1.0
983 - id: ensure_dtype
984 kwargs:
985 dtype: uint8
986 ```
987 - in Python:
988 >>> preprocessing = [
989 ... ScaleRangeDescr(
990 ... kwargs=ScaleRangeKwargs(
991 ... axes= (AxisId('y'), AxisId('x')),
992 ... max_percentile= 99.8,
993 ... min_percentile= 5.0,
994 ... )
995 ... ),
996 ... ClipDescr(kwargs=ClipKwargs(min=0.0, max=1.0)),
997 ... EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="uint8")),
998 ... ]
999 """
1001 implemented_id: ClassVar[Literal["ensure_dtype"]] = "ensure_dtype"
1002 if TYPE_CHECKING:
1003 id: Literal["ensure_dtype"] = "ensure_dtype"
1004 else:
1005 id: Literal["ensure_dtype"]
1007 kwargs: EnsureDtypeKwargs
1010class ScaleLinearKwargs(ProcessingKwargs):
1011 """Key word arguments for `ScaleLinearDescr`"""
1013 gain: float = 1.0
1014 """multiplicative factor"""
1016 offset: float = 0.0
1017 """additive term"""
1019 @model_validator(mode="after")
1020 def _validate(self) -> Self:
1021 if self.gain == 1.0 and self.offset == 0.0:
1022 raise ValueError(
1023 "Redundant linear scaling not allowd. Set `gain` != 1.0 and/or `offset`"
1024 + " != 0.0."
1025 )
1027 return self
1030class ScaleLinearAlongAxisKwargs(ProcessingKwargs):
1031 """Key word arguments for `ScaleLinearDescr`"""
1033 axis: Annotated[NonBatchAxisId, Field(examples=["channel"])]
1034 """The axis of gain and offset values."""
1036 gain: Union[float, NotEmpty[List[float]]] = 1.0
1037 """multiplicative factor"""
1039 offset: Union[float, NotEmpty[List[float]]] = 0.0
1040 """additive term"""
1042 @model_validator(mode="after")
1043 def _validate(self) -> Self:
1045 if isinstance(self.gain, list):
1046 if isinstance(self.offset, list):
1047 if len(self.gain) != len(self.offset):
1048 raise ValueError(
1049 f"Size of `gain` ({len(self.gain)}) and `offset` ({len(self.offset)}) must match."
1050 )
1051 else:
1052 self.offset = [float(self.offset)] * len(self.gain)
1053 elif isinstance(self.offset, list):
1054 self.gain = [float(self.gain)] * len(self.offset)
1055 else:
1056 raise ValueError(
1057 "Do not specify an `axis` for scalar gain and offset values."
1058 )
1060 if all(g == 1.0 for g in self.gain) and all(off == 0.0 for off in self.offset):
1061 raise ValueError(
1062 "Redundant linear scaling not allowd. Set `gain` != 1.0 and/or `offset`"
1063 + " != 0.0."
1064 )
1066 return self
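# Illustrative sketch (not part of the original module): a scalar `offset` is
# broadcast to match a per-axis `gain` (and vice versa) by the validator above.
# >>> kw = ScaleLinearAlongAxisKwargs(axis=AxisId("channel"), gain=[1.0, 2.0], offset=0.5)
# >>> kw.offset
# [0.5, 0.5]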
1069class ScaleLinearDescr(ProcessingDescrBase):
1070 """Fixed linear scaling.
1072 Examples:
1073 1. Scale with scalar gain and offset
1074 - in YAML
1075 ```yaml
1076 preprocessing:
1077 - id: scale_linear
1078 kwargs:
1079 gain: 2.0
1080 offset: 3.0
1081 ```
1082 - in Python:
1083 >>> preprocessing = [
1084 ... ScaleLinearDescr(kwargs=ScaleLinearKwargs(gain= 2.0, offset=3.0))
1085 ... ]
1087 2. Independent scaling along an axis
1088 - in YAML
1089 ```yaml
1090 preprocessing:
1091 - id: scale_linear
1092 kwargs:
1093 axis: 'channel'
1094 gain: [1.0, 2.0, 3.0]
1095 ```
1096 - in Python:
1097 >>> preprocessing = [
1098 ... ScaleLinearDescr(
1099 ... kwargs=ScaleLinearAlongAxisKwargs(
1100 ... axis=AxisId("channel"),
1101 ... gain=[1.0, 2.0, 3.0],
1102 ... )
1103 ... )
1104 ... ]
1106 """
1108 implemented_id: ClassVar[Literal["scale_linear"]] = "scale_linear"
1109 if TYPE_CHECKING:
1110 id: Literal["scale_linear"] = "scale_linear"
1111 else:
1112 id: Literal["scale_linear"]
1113 kwargs: Union[ScaleLinearKwargs, ScaleLinearAlongAxisKwargs]
1116class SigmoidDescr(ProcessingDescrBase):
1117 """The logistic sigmoid funciton, a.k.a. expit function.
1119 Examples:
1120 - in YAML
1121 ```yaml
1122 postprocessing:
1123 - id: sigmoid
1124 ```
1125 - in Python:
1126 >>> postprocessing = [SigmoidDescr()]
1127 """
1129 implemented_id: ClassVar[Literal["sigmoid"]] = "sigmoid"
1130 if TYPE_CHECKING:
1131 id: Literal["sigmoid"] = "sigmoid"
1132 else:
1133 id: Literal["sigmoid"]
1135 @property
1136 def kwargs(self) -> ProcessingKwargs:
1137 """empty kwargs"""
1138 return ProcessingKwargs()
1141class FixedZeroMeanUnitVarianceKwargs(ProcessingKwargs):
1142 """key word arguments for `FixedZeroMeanUnitVarianceDescr`"""
1144 mean: float
1145 """The mean value to normalize with."""
1147 std: Annotated[float, Ge(1e-6)]
1148 """The standard deviation value to normalize with."""
1151class FixedZeroMeanUnitVarianceAlongAxisKwargs(ProcessingKwargs):
1152 """key word arguments for `FixedZeroMeanUnitVarianceDescr`"""
1154 mean: NotEmpty[List[float]]
1155 """The mean value(s) to normalize with."""
1157 std: NotEmpty[List[Annotated[float, Ge(1e-6)]]]
1158 """The standard deviation value(s) to normalize with.
1159 Size must match `mean` values."""
1161 axis: Annotated[NonBatchAxisId, Field(examples=["channel", "index"])]
1162 """The axis of the mean/std values to normalize each entry along that dimension
1163 separately."""
1165 @model_validator(mode="after")
1166 def _mean_and_std_match(self) -> Self:
1167 if len(self.mean) != len(self.std):
1168 raise ValueError(
1169 f"Size of `mean` ({len(self.mean)}) and `std` ({len(self.std)})"
1170 + " must match."
1171 )
1173 return self
1176class FixedZeroMeanUnitVarianceDescr(ProcessingDescrBase):
1177 """Subtract a given mean and divide by the standard deviation.
1179 Normalize with fixed, precomputed values for
1180 `FixedZeroMeanUnitVarianceKwargs.mean` and `FixedZeroMeanUnitVarianceKwargs.std`.
1181 Use `FixedZeroMeanUnitVarianceAlongAxisKwargs` for independent scaling along given
1182 axes.
1184 Examples:
1185 1. scalar value for whole tensor
1186 - in YAML
1187 ```yaml
1188 preprocessing:
1189 - id: fixed_zero_mean_unit_variance
1190 kwargs:
1191 mean: 103.5
1192 std: 13.7
1193 ```
1194 - in Python
1195 >>> preprocessing = [FixedZeroMeanUnitVarianceDescr(
1196 ... kwargs=FixedZeroMeanUnitVarianceKwargs(mean=103.5, std=13.7)
1197 ... )]
1199 2. independently along an axis
1200 - in YAML
1201 ```yaml
1202 preprocessing:
1203 - id: fixed_zero_mean_unit_variance
1204 kwargs:
1205 axis: channel
1206 mean: [101.5, 102.5, 103.5]
1207 std: [11.7, 12.7, 13.7]
1208 ```
1209 - in Python
1210 >>> preprocessing = [FixedZeroMeanUnitVarianceDescr(
1211 ... kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs(
1212 ... axis=AxisId("channel"),
1213 ... mean=[101.5, 102.5, 103.5],
1214 ... std=[11.7, 12.7, 13.7],
1215 ... )
1216 ... )]
1217 """
1219 implemented_id: ClassVar[Literal["fixed_zero_mean_unit_variance"]] = (
1220 "fixed_zero_mean_unit_variance"
1221 )
1222 if TYPE_CHECKING:
1223 id: Literal["fixed_zero_mean_unit_variance"] = "fixed_zero_mean_unit_variance"
1224 else:
1225 id: Literal["fixed_zero_mean_unit_variance"]
1227 kwargs: Union[
1228 FixedZeroMeanUnitVarianceKwargs, FixedZeroMeanUnitVarianceAlongAxisKwargs
1229 ]
1232class ZeroMeanUnitVarianceKwargs(ProcessingKwargs):
1233 """key word arguments for `ZeroMeanUnitVarianceDescr`"""
1235 axes: Annotated[
1236 Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
1237 ] = None
1238 """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std.
1239 For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
1240 resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
1241 To normalize each sample independently leave out the 'batch' axis.
1242 Default: Scale all axes jointly."""
1244 eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
1245 """epsilon for numeric stability: `out = (tensor - mean) / (std + eps)`."""
1248class ZeroMeanUnitVarianceDescr(ProcessingDescrBase):
1249 """Subtract mean and divide by variance.
1251 Examples:
1252 Subtract tensor mean and variance
1253 - in YAML
1254 ```yaml
1255 preprocessing:
1256 - id: zero_mean_unit_variance
1257 ```
1258 - in Python
1259 >>> preprocessing = [ZeroMeanUnitVarianceDescr()]
1260 """
1262 implemented_id: ClassVar[Literal["zero_mean_unit_variance"]] = (
1263 "zero_mean_unit_variance"
1264 )
1265 if TYPE_CHECKING:
1266 id: Literal["zero_mean_unit_variance"] = "zero_mean_unit_variance"
1267 else:
1268 id: Literal["zero_mean_unit_variance"]
1270 kwargs: ZeroMeanUnitVarianceKwargs = Field(
1271 default_factory=ZeroMeanUnitVarianceKwargs
1272 )
1275class ScaleRangeKwargs(ProcessingKwargs):
1276 """key word arguments for `ScaleRangeDescr`
1278 For `min_percentile`=0.0 (the default) and `max_percentile`=100 (the default)
1279 this processing step normalizes data to the [0, 1] interval.
1280 For other percentiles the normalized values will partially be outside the [0, 1]
1281 interval. Use `ScaleRangeDescr` followed by `ClipDescr` if you want to limit the
1282 normalized values to a range.
1283 """
1285 axes: Annotated[
1286 Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
1287 ] = None
1288 """The subset of axes to normalize jointly, i.e. axes to reduce to compute the min/max percentile value.
1289 For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
1290 resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
1291 To normalize samples independently, leave out the "batch" axis.
1292 Default: Scale all axes jointly."""
1294 min_percentile: Annotated[float, Interval(ge=0, lt=100)] = 0.0
1295 """The lower percentile used to determine the value to align with zero."""
1297 max_percentile: Annotated[float, Interval(gt=1, le=100)] = 100.0
1298 """The upper percentile used to determine the value to align with one.
1299 Has to be bigger than `min_percentile`.
1300 The range is 1 to 100 instead of 0 to 100 to avoid mistakenly
1301 accepting percentiles specified in the range 0.0 to 1.0."""
1303 eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
1304 """Epsilon for numeric stability.
1305 `out = (tensor - v_lower) / (v_upper - v_lower + eps)`;
1306 with `v_lower,v_upper` values at the respective percentiles."""
1308 reference_tensor: Optional[TensorId] = None
1309 """Tensor ID to compute the percentiles from. Default: The tensor itself.
1310 For any tensor in `inputs` only input tensor references are allowed."""
1312 @field_validator("max_percentile", mode="after")
1313 @classmethod
1314 def min_smaller_max(cls, value: float, info: ValidationInfo) -> float:
1315 if (min_p := info.data["min_percentile"]) >= value:
1316 raise ValueError(f"min_percentile {min_p} >= max_percentile {value}")
1318 return value
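# Illustrative sketch (not part of the original module): percentile kwargs are
# validated such that min_percentile < max_percentile; the normalization itself
# is `out = (tensor - v_lower) / (v_upper - v_lower + eps)`.
# >>> kw = ScaleRangeKwargs(axes=(AxisId("y"), AxisId("x")), min_percentile=5.0, max_percentile=99.8)
# >>> kw.max_percentile
# 99.8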
1321class ScaleRangeDescr(ProcessingDescrBase):
1322 """Scale with percentiles.
1324 Examples:
1325 1. Scale linearly to map 5th percentile to 0 and 99.8th percentile to 1.0
1326 - in YAML
1327 ```yaml
1328 preprocessing:
1329 - id: scale_range
1330 kwargs:
1331 axes: ['y', 'x']
1332 max_percentile: 99.8
1333 min_percentile: 5.0
1334 ```
1335 - in Python
1336 >>> preprocessing = [
1337 ... ScaleRangeDescr(
1338 ... kwargs=ScaleRangeKwargs(
1339 ... axes= (AxisId('y'), AxisId('x')),
1340 ... max_percentile= 99.8,
1341 ... min_percentile= 5.0,
1342 ... )
1343 ... ),
1350 ... ]
1352 2. Combine the above scaling with additional clipping to clip values outside the range given by the percentiles.
1353 - in YAML
1354 ```yaml
1355 preprocessing:
1356 - id: scale_range
1357 kwargs:
1358 axes: ['y', 'x']
1359 max_percentile: 99.8
1360 min_percentile: 5.0
1362 - id: clip
1363 kwargs:
1364 min: 0.0
1365 max: 1.0
1366 ```
1367 - in Python
1368 >>> preprocessing = [ScaleRangeDescr(
1369 ... kwargs=ScaleRangeKwargs(
1370 ... axes= (AxisId('y'), AxisId('x')),
1371 ... max_percentile= 99.8,
1372 ... min_percentile= 5.0,
1373 ... )
1374 ... ),
... ClipDescr(kwargs=ClipKwargs(min=0.0, max=1.0)),
... ]
1376 """
1378 implemented_id: ClassVar[Literal["scale_range"]] = "scale_range"
1379 if TYPE_CHECKING:
1380 id: Literal["scale_range"] = "scale_range"
1381 else:
1382 id: Literal["scale_range"]
1383 kwargs: ScaleRangeKwargs
1386class ScaleMeanVarianceKwargs(ProcessingKwargs):
1387 """key word arguments for `ScaleMeanVarianceKwargs`"""
1389 reference_tensor: TensorId
1390 """Name of tensor to match."""
1392 axes: Annotated[
1393 Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
1394 ] = None
1395 """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std.
1396 For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
1397 resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
1398 To normalize samples independently, leave out the 'batch' axis.
1399 Default: Scale all axes jointly."""
1401 eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
1402 """Epsilon for numeric stability:
1403 `out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.`"""
1406class ScaleMeanVarianceDescr(ProcessingDescrBase):
1407 """Scale a tensor's data distribution to match another tensor's mean/std.
1408 `out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.`
1409 """
1411 implemented_id: ClassVar[Literal["scale_mean_variance"]] = "scale_mean_variance"
1412 if TYPE_CHECKING:
1413 id: Literal["scale_mean_variance"] = "scale_mean_variance"
1414 else:
1415 id: Literal["scale_mean_variance"]
1416 kwargs: ScaleMeanVarianceKwargs
1419PreprocessingDescr = Annotated[
1420 Union[
1421 BinarizeDescr,
1422 ClipDescr,
1423 EnsureDtypeDescr,
1424 ScaleLinearDescr,
1425 SigmoidDescr,
1426 FixedZeroMeanUnitVarianceDescr,
1427 ZeroMeanUnitVarianceDescr,
1428 ScaleRangeDescr,
1429 ],
1430 Discriminator("id"),
1431]
1432PostprocessingDescr = Annotated[
1433 Union[
1434 BinarizeDescr,
1435 ClipDescr,
1436 EnsureDtypeDescr,
1437 ScaleLinearDescr,
1438 SigmoidDescr,
1439 FixedZeroMeanUnitVarianceDescr,
1440 ZeroMeanUnitVarianceDescr,
1441 ScaleRangeDescr,
1442 ScaleMeanVarianceDescr,
1443 ],
1444 Discriminator("id"),
1445]
1447IO_AxisT = TypeVar("IO_AxisT", InputAxis, OutputAxis)
1450class TensorDescrBase(Node, Generic[IO_AxisT]):
1451 id: TensorId
1452 """Tensor id. No duplicates are allowed."""
1454 description: Annotated[str, MaxLen(128)] = ""
1455 """free text description"""
1457 axes: NotEmpty[Sequence[IO_AxisT]]
1458 """tensor axes"""
1460 @property
1461 def shape(self):
1462 return tuple(a.size for a in self.axes)
1464 @field_validator("axes", mode="after", check_fields=False)
1465 @classmethod
1466 def _validate_axes(cls, axes: Sequence[AnyAxis]) -> Sequence[AnyAxis]:
1467 batch_axes = [a for a in axes if a.type == "batch"]
1468 if len(batch_axes) > 1:
1469 raise ValueError(
1470 f"Only one batch axis (per tensor) allowed, but got {batch_axes}"
1471 )
1473 seen_ids: Set[AxisId] = set()
1474 duplicate_axes_ids: Set[AxisId] = set()
1475 for a in axes:
1476 (duplicate_axes_ids if a.id in seen_ids else seen_ids).add(a.id)
1478 if duplicate_axes_ids:
1479 raise ValueError(f"Duplicate axis ids: {duplicate_axes_ids}")
1481 return axes
1483 test_tensor: FileDescr
1484 """An example tensor to use for testing.
1485 Using the model with the test input tensors is expected to yield the test output tensors.
1486 Each test tensor has to be an ndarray in the
1487 [numpy.lib file format](https://numpy.org/doc/stable/reference/generated/numpy.lib.format.html#module-numpy.lib.format).
1488 The file extension must be '.npy'."""
1490 sample_tensor: Optional[FileDescr] = None
1491 """A sample tensor to illustrate a possible input/output for the model,
1492 The sample image primarily serves to inform a human user about an example use case
1493 and is typically stored as .hdf5, .png or .tiff.
1494 It has to be readable by the [imageio library](https://imageio.readthedocs.io/en/stable/formats/index.html#supported-formats)
1495 (numpy's `.npy` format is not supported).
1496 The image dimensionality has to match the number of axes specified in this tensor description.
1497 """
1499 @model_validator(mode="after")
1500 def _validate_sample_tensor(self) -> Self:
1501 if self.sample_tensor is None or not get_validation_context().perform_io_checks:
1502 return self
1504 local = download(self.sample_tensor.source, sha256=self.sample_tensor.sha256)
1505 tensor: NDArray[Any] = imread(
1506 local.path.read_bytes(),
1507 extension=PurePosixPath(local.original_file_name).suffix,
1508 )
1509 n_dims = len(tensor.squeeze().shape)
1510 n_dims_min = n_dims_max = len(self.axes)
1512 for a in self.axes:
1513 if isinstance(a, BatchAxis):
1514 n_dims_min -= 1
1515 elif isinstance(a.size, int):
1516 if a.size == 1:
1517 n_dims_min -= 1
1518 elif isinstance(a.size, (ParameterizedSize, DataDependentSize)):
1519 if a.size.min == 1:
1520 n_dims_min -= 1
1521 elif isinstance(a.size, SizeReference):
1522 if a.size.offset < 2:
1523 # size reference may result in singleton axis
1524 n_dims_min -= 1
1525 else:
1526 assert_never(a.size)
1528 n_dims_min = max(0, n_dims_min)
1529 if n_dims < n_dims_min or n_dims > n_dims_max:
1530 raise ValueError(
1531 f"Expected sample tensor to have {n_dims_min} to"
1532 + f" {n_dims_max} dimensions, but found {n_dims} (shape: {tensor.shape})."
1533 )
1535 return self
1537 data: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]] = (
1538 IntervalOrRatioDataDescr()
1539 )
1540 """Description of the tensor's data values, optionally per channel.
1541 If specified per channel, the data `type` needs to match across channels."""
1543 @property
1544 def dtype(
1545 self,
1546 ) -> Literal[
1547 "float32",
1548 "float64",
1549 "uint8",
1550 "int8",
1551 "uint16",
1552 "int16",
1553 "uint32",
1554 "int32",
1555 "uint64",
1556 "int64",
1557 "bool",
1558 ]:
1559 """dtype as specified under `data.type` or `data[i].type`"""
1560 if isinstance(self.data, collections.abc.Sequence):
1561 return self.data[0].type
1562 else:
1563 return self.data.type
1565 @field_validator("data", mode="after")
1566 @classmethod
1567 def _check_data_type_across_channels(
1568 cls, value: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]]
1569 ) -> Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]]:
1570 if not isinstance(value, list):
1571 return value
1573 dtypes = {t.type for t in value}
1574 if len(dtypes) > 1:
1575 raise ValueError(
1576 "Tensor data descriptions per channel need to agree in their data"
1577 + f" `type`, but found {dtypes}."
1578 )
1580 return value
1582 @model_validator(mode="after")
1583 def _check_data_matches_channelaxis(self) -> Self:
1584 if not isinstance(self.data, (list, tuple)):
1585 return self
1587 for a in self.axes:
1588 if isinstance(a, ChannelAxis):
1589 size = a.size
1590 assert isinstance(size, int)
1591 break
1592 else:
1593 return self
1595 if len(self.data) != size:
1596 raise ValueError(
1597 f"Got tensor data descriptions for {len(self.data)} channels, but"
1598 + f" '{a.id}' axis has size {size}."
1599 )
1601 return self
1603 def get_axis_sizes_for_array(self, array: NDArray[Any]) -> Dict[AxisId, int]:
1604 if len(array.shape) != len(self.axes):
1605 raise ValueError(
1606 f"Dimension mismatch: array shape {array.shape} (#{len(array.shape)})"
1607 + f" incompatible with {len(self.axes)} axes."
1608 )
1609 return {a.id: array.shape[i] for i, a in enumerate(self.axes)}
1612class InputTensorDescr(TensorDescrBase[InputAxis]):
1613 id: TensorId = TensorId("input")
1614 """Input tensor id.
1615 No duplicates are allowed across all inputs and outputs."""
1617 optional: bool = False
1618 """indicates that this tensor may be `None`"""
1620 preprocessing: List[PreprocessingDescr] = Field(default_factory=list)
1621 """Description of how this input should be preprocessed.
1623 notes:
1624 - If preprocessing does not start with an 'ensure_dtype' entry, it is added
1625 to ensure an input tensor's data type matches the input tensor's data description.
1626 - If preprocessing does not end with an 'ensure_dtype' or 'binarize' entry, an
1627 'ensure_dtype' step is added to ensure preprocessing steps are not unintentionally
1628 changing the data type.
1629 """
1631 @model_validator(mode="after")
1632 def _validate_preprocessing_kwargs(self) -> Self:
1633 axes_ids = [a.id for a in self.axes]
1634 for p in self.preprocessing:
1635 kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes")
1636 if kwargs_axes is None:
1637 continue
1639 if not isinstance(kwargs_axes, collections.abc.Sequence):
1640 raise ValueError(
1641 f"Expected `preprocessing.i.kwargs.axes` to be a sequence, but got {type(kwargs_axes)}"
1642 )
1644 if any(a not in axes_ids for a in kwargs_axes):
1645 raise ValueError(
1646 "`preprocessing.i.kwargs.axes` needs to be subset of axes ids"
1647 )
1649 if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)):
1650 dtype = self.data.type
1651 else:
1652 dtype = self.data[0].type
1654 # ensure `preprocessing` begins with `EnsureDtypeDescr`
1655 if not self.preprocessing or not isinstance(
1656 self.preprocessing[0], EnsureDtypeDescr
1657 ):
1658 self.preprocessing.insert(
1659 0, EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype))
1660 )
1662 # ensure `preprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr`
1663 if not isinstance(self.preprocessing[-1], (EnsureDtypeDescr, BinarizeDescr)):
1664 self.preprocessing.append(
1665 EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype))
1666 )
1668 return self
1671def convert_axes(
1672 axes: str,
1673 *,
1674 shape: Union[
1675 Sequence[int], _ParameterizedInputShape_v0_4, _ImplicitOutputShape_v0_4
1676 ],
1677 tensor_type: Literal["input", "output"],
1678 halo: Optional[Sequence[int]],
1679 size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]],
1680):
1681 ret: List[AnyAxis] = []
1682 for i, a in enumerate(axes):
1683 axis_type = _AXIS_TYPE_MAP.get(a, a)
1684 if axis_type == "batch":
1685 ret.append(BatchAxis())
1686 continue
1688 scale = 1.0
1689 if isinstance(shape, _ParameterizedInputShape_v0_4):
1690 if shape.step[i] == 0:
1691 size = shape.min[i]
1692 else:
1693 size = ParameterizedSize(min=shape.min[i], step=shape.step[i])
1694 elif isinstance(shape, _ImplicitOutputShape_v0_4):
1695 ref_t = str(shape.reference_tensor)
1696 if ref_t.count(".") == 1:
1697 t_id, orig_a_id = ref_t.split(".")
1698 else:
1699 t_id = ref_t
1700 orig_a_id = a
1702 a_id = _AXIS_ID_MAP.get(orig_a_id, a)
1703 if not (orig_scale := shape.scale[i]):
1704 # old way to insert a new axis dimension
1705 size = int(2 * shape.offset[i])
1706 else:
1707 scale = 1 / orig_scale
1708 if axis_type in ("channel", "index"):
1709 # these axes no longer have a scale
1710 offset_from_scale = orig_scale * size_refs.get(
1711 _TensorName_v0_4(t_id), {}
1712 ).get(orig_a_id, 0)
1713 else:
1714 offset_from_scale = 0
1715 size = SizeReference(
1716 tensor_id=TensorId(t_id),
1717 axis_id=AxisId(a_id),
1718 offset=int(offset_from_scale + 2 * shape.offset[i]),
1719 )
1720 else:
1721 size = shape[i]
1723 if axis_type == "time":
1724 if tensor_type == "input":
1725 ret.append(TimeInputAxis(size=size, scale=scale))
1726 else:
1727 assert not isinstance(size, ParameterizedSize)
1728 if halo is None:
1729 ret.append(TimeOutputAxis(size=size, scale=scale))
1730 else:
1731 assert not isinstance(size, int)
1732 ret.append(
1733 TimeOutputAxisWithHalo(size=size, scale=scale, halo=halo[i])
1734 )
1736 elif axis_type == "index":
1737 if tensor_type == "input":
1738 ret.append(IndexInputAxis(size=size))
1739 else:
1740 if isinstance(size, ParameterizedSize):
1741 size = DataDependentSize(min=size.min)
1743 ret.append(IndexOutputAxis(size=size))
1744 elif axis_type == "channel":
1745 assert not isinstance(size, ParameterizedSize)
1746 if isinstance(size, SizeReference):
1747 warnings.warn(
1748 "Conversion of channel size from an implicit output shape may be"
1749 + " wrong"
1750 )
1751 ret.append(
1752 ChannelAxis(
1753 channel_names=[
1754 Identifier(f"channel{i}") for i in range(size.offset)
1755 ]
1756 )
1757 )
1758 else:
1759 ret.append(
1760 ChannelAxis(
1761 channel_names=[Identifier(f"channel{i}") for i in range(size)]
1762 )
1763 )
1764 elif axis_type == "space":
1765 if tensor_type == "input":
1766 ret.append(SpaceInputAxis(id=AxisId(a), size=size, scale=scale))
1767 else:
1768 assert not isinstance(size, ParameterizedSize)
1769 if halo is None or halo[i] == 0:
1770 ret.append(SpaceOutputAxis(id=AxisId(a), size=size, scale=scale))
1771 elif isinstance(size, int):
1772 raise NotImplementedError(
1773 f"output axis with halo and fixed size (here {size}) not allowed"
1774 )
1775 else:
1776 ret.append(
1777 SpaceOutputAxisWithHalo(
1778 id=AxisId(a), size=size, scale=scale, halo=halo[i]
1779 )
1780 )
1782 return ret
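# Illustrative sketch (not part of the original module): converting a fixed 0.4
# input described by the axes string "bcyx" and shape [1, 3, 256, 256].
# >>> axes = convert_axes(
# ...     "bcyx", shape=[1, 3, 256, 256], tensor_type="input", halo=None, size_refs={}
# ... )
# >>> [type(a).__name__ for a in axes]
# ['BatchAxis', 'ChannelAxis', 'SpaceInputAxis', 'SpaceInputAxis']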
1785def _axes_letters_to_ids(
1786 axes: Optional[str],
1787) -> Optional[List[AxisId]]:
1788 if axes is None:
1789 return None
1791 return [AxisId(a) for a in axes]
1794def _get_complement_v04_axis(
1795 tensor_axes: Sequence[str], axes: Optional[Sequence[str]]
1796) -> Optional[AxisId]:
1797 if axes is None:
1798 return None
1800 non_complement_axes = set(axes) | {"b"}
1801 complement_axes = [a for a in tensor_axes if a not in non_complement_axes]
1802 if len(complement_axes) > 1:
1803 raise ValueError(
1804 f"Expected none or a single complement axis, but axes '{axes}' "
1805 + f"for tensor dims '{tensor_axes}' leave '{complement_axes}'."
1806 )
1808 return None if not complement_axes else AxisId(complement_axes[0])
1811def _convert_proc(
1812 p: Union[_PreprocessingDescr_v0_4, _PostprocessingDescr_v0_4],
1813 tensor_axes: Sequence[str],
1814) -> Union[PreprocessingDescr, PostprocessingDescr]:
1815 if isinstance(p, _BinarizeDescr_v0_4):
1816 return BinarizeDescr(kwargs=BinarizeKwargs(threshold=p.kwargs.threshold))
1817 elif isinstance(p, _ClipDescr_v0_4):
1818 return ClipDescr(kwargs=ClipKwargs(min=p.kwargs.min, max=p.kwargs.max))
1819 elif isinstance(p, _SigmoidDescr_v0_4):
1820 return SigmoidDescr()
1821 elif isinstance(p, _ScaleLinearDescr_v0_4):
1822 axes = _axes_letters_to_ids(p.kwargs.axes)
1823 if p.kwargs.axes is None:
1824 axis = None
1825 else:
1826 axis = _get_complement_v04_axis(tensor_axes, p.kwargs.axes)
1828 if axis is None:
1829 assert not isinstance(p.kwargs.gain, list)
1830 assert not isinstance(p.kwargs.offset, list)
1831 kwargs = ScaleLinearKwargs(gain=p.kwargs.gain, offset=p.kwargs.offset)
1832 else:
1833 kwargs = ScaleLinearAlongAxisKwargs(
1834 axis=axis, gain=p.kwargs.gain, offset=p.kwargs.offset
1835 )
1836 return ScaleLinearDescr(kwargs=kwargs)
1837 elif isinstance(p, _ScaleMeanVarianceDescr_v0_4):
1838 return ScaleMeanVarianceDescr(
1839 kwargs=ScaleMeanVarianceKwargs(
1840 axes=_axes_letters_to_ids(p.kwargs.axes),
1841 reference_tensor=TensorId(str(p.kwargs.reference_tensor)),
1842 eps=p.kwargs.eps,
1843 )
1844 )
1845 elif isinstance(p, _ZeroMeanUnitVarianceDescr_v0_4):
1846 if p.kwargs.mode == "fixed":
1847 mean = p.kwargs.mean
1848 std = p.kwargs.std
1849 assert mean is not None
1850 assert std is not None
1852 axis = _get_complement_v04_axis(tensor_axes, p.kwargs.axes)
1854 if axis is None:
1855 return FixedZeroMeanUnitVarianceDescr(
1856 kwargs=FixedZeroMeanUnitVarianceKwargs(
1857 mean=mean, std=std # pyright: ignore[reportArgumentType]
1858 )
1859 )
1860 else:
1861 if not isinstance(mean, list):
1862 mean = [float(mean)]
1863 if not isinstance(std, list):
1864 std = [float(std)]
1866 return FixedZeroMeanUnitVarianceDescr(
1867 kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs(
1868 axis=axis, mean=mean, std=std
1869 )
1870 )
1872 else:
1873 axes = _axes_letters_to_ids(p.kwargs.axes) or []
1874 if p.kwargs.mode == "per_dataset":
1875 axes = [AxisId("batch")] + axes
1876 if not axes:
1877 axes = None
1878 return ZeroMeanUnitVarianceDescr(
1879 kwargs=ZeroMeanUnitVarianceKwargs(axes=axes, eps=p.kwargs.eps)
1880 )
1882 elif isinstance(p, _ScaleRangeDescr_v0_4):
1883 return ScaleRangeDescr(
1884 kwargs=ScaleRangeKwargs(
1885 axes=_axes_letters_to_ids(p.kwargs.axes),
1886 min_percentile=p.kwargs.min_percentile,
1887 max_percentile=p.kwargs.max_percentile,
1888 eps=p.kwargs.eps,
1889 )
1890 )
1891 else:
1892 assert_never(p)
1895class _InputTensorConv(
1896 Converter[
1897 _InputTensorDescr_v0_4,
1898 InputTensorDescr,
1899 ImportantFileSource,
1900 Optional[ImportantFileSource],
1901 Mapping[_TensorName_v0_4, Mapping[str, int]],
1902 ]
1903):
1904 def _convert(
1905 self,
1906 src: _InputTensorDescr_v0_4,
1907 tgt: "type[InputTensorDescr] | type[dict[str, Any]]",
1908 test_tensor: ImportantFileSource,
1909 sample_tensor: Optional[ImportantFileSource],
1910 size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]],
1911 ) -> "InputTensorDescr | dict[str, Any]":
1912 axes: List[InputAxis] = convert_axes( # pyright: ignore[reportAssignmentType]
1913 src.axes,
1914 shape=src.shape,
1915 tensor_type="input",
1916 halo=None,
1917 size_refs=size_refs,
1918 )
1919 prep: List[PreprocessingDescr] = []
1920 for p in src.preprocessing:
1921 cp = _convert_proc(p, src.axes)
1922 assert not isinstance(cp, ScaleMeanVarianceDescr)
1923 prep.append(cp)
1925 prep.append(EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="float32")))
1927 return tgt(
1928 axes=axes,
1929 id=TensorId(str(src.name)),
1930 test_tensor=FileDescr(source=test_tensor),
1931 sample_tensor=(
1932 None if sample_tensor is None else FileDescr(source=sample_tensor)
1933 ),
1934 data=dict(type=src.data_type), # pyright: ignore[reportArgumentType]
1935 preprocessing=prep,
1936 )
1939_input_tensor_conv = _InputTensorConv(_InputTensorDescr_v0_4, InputTensorDescr)
1942class OutputTensorDescr(TensorDescrBase[OutputAxis]):
1943 id: TensorId = TensorId("output")
1944 """Output tensor id.
1945 No duplicates are allowed across all inputs and outputs."""
1947 postprocessing: List[PostprocessingDescr] = Field(default_factory=list)
1948 """Description of how this output should be postprocessed.
1950 note: `postprocessing` always ends with an 'ensure_dtype' operation.
1951 If not given this is added to cast to this tensor's `data.type`.
1952 """
1954 @model_validator(mode="after")
1955 def _validate_postprocessing_kwargs(self) -> Self:
1956 axes_ids = [a.id for a in self.axes]
1957 for p in self.postprocessing:
1958 kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes")
1959 if kwargs_axes is None:
1960 continue
1962 if not isinstance(kwargs_axes, collections.abc.Sequence):
1963 raise ValueError(
1964 f"expected `axes` sequence, but got {type(kwargs_axes)}"
1965 )
1967 if any(a not in axes_ids for a in kwargs_axes):
1968 raise ValueError("`kwargs.axes` needs to be subset of axes ids")
1970 if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)):
1971 dtype = self.data.type
1972 else:
1973 dtype = self.data[0].type
1975 # ensure `postprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr`
1976 if not self.postprocessing or not isinstance(
1977 self.postprocessing[-1], (EnsureDtypeDescr, BinarizeDescr)
1978 ):
1979 self.postprocessing.append(
1980 EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype))
1981 )
1982 return self
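# Example (illustrative sketch): after validation, `postprocessing` always ends with an
# `ensure_dtype` (or `binarize`) step. For an output described with dtype "float32" and
# no explicit postprocessing, the appended step is equivalent to:
#     EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="float32"))
# so consumers can rely on the final output dtype.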
1985class _OutputTensorConv(
1986 Converter[
1987 _OutputTensorDescr_v0_4,
1988 OutputTensorDescr,
1989 ImportantFileSource,
1990 Optional[ImportantFileSource],
1991 Mapping[_TensorName_v0_4, Mapping[str, int]],
1992 ]
1993):
1994 def _convert(
1995 self,
1996 src: _OutputTensorDescr_v0_4,
1997 tgt: "type[OutputTensorDescr] | type[dict[str, Any]]",
1998 test_tensor: ImportantFileSource,
1999 sample_tensor: Optional[ImportantFileSource],
2000 size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]],
2001 ) -> "OutputTensorDescr | dict[str, Any]":
2002 # TODO: split convert_axes into convert_output_axes and convert_input_axes
2003 axes: List[OutputAxis] = convert_axes( # pyright: ignore[reportAssignmentType]
2004 src.axes,
2005 shape=src.shape,
2006 tensor_type="output",
2007 halo=src.halo,
2008 size_refs=size_refs,
2009 )
2010 data_descr: Dict[str, Any] = dict(type=src.data_type)
2011 if data_descr["type"] == "bool":
2012 data_descr["values"] = [False, True]
2014 return tgt(
2015 axes=axes,
2016 id=TensorId(str(src.name)),
2017 test_tensor=FileDescr(source=test_tensor),
2018 sample_tensor=(
2019 None if sample_tensor is None else FileDescr(source=sample_tensor)
2020 ),
2021 data=data_descr, # pyright: ignore[reportArgumentType]
2022 postprocessing=[_convert_proc(p, src.axes) for p in src.postprocessing],
2023 )
2026_output_tensor_conv = _OutputTensorConv(_OutputTensorDescr_v0_4, OutputTensorDescr)
2029TensorDescr = Union[InputTensorDescr, OutputTensorDescr]
2032def validate_tensors(
2033 tensors: Mapping[TensorId, Tuple[TensorDescr, NDArray[Any]]],
2034 tensor_origin: Literal[
2035 "test_tensor"
2036 ], # for more precise error messages, e.g. 'test_tensor'
2037):
2038 all_tensor_axes: Dict[TensorId, Dict[AxisId, Tuple[AnyAxis, int]]] = {}
2040 def e_msg(d: TensorDescr):
2041 return f"{'inputs' if isinstance(d, InputTensorDescr) else 'outputs'}[{d.id}]"
2043 for descr, array in tensors.values():
2044 try:
2045 axis_sizes = descr.get_axis_sizes_for_array(array)
2046 except ValueError as e:
2047 raise ValueError(f"{e_msg(descr)} {e}")
2048 else:
2049 all_tensor_axes[descr.id] = {
2050 a.id: (a, axis_sizes[a.id]) for a in descr.axes
2051 }
2053 for descr, array in tensors.values():
2054 if descr.dtype in ("float32", "float64"):
2055 invalid_test_tensor_dtype = array.dtype.name not in (
2056 "float32",
2057 "float64",
2058 "uint8",
2059 "int8",
2060 "uint16",
2061 "int16",
2062 "uint32",
2063 "int32",
2064 "uint64",
2065 "int64",
2066 )
2067 else:
2068 invalid_test_tensor_dtype = array.dtype.name != descr.dtype
2070 if invalid_test_tensor_dtype:
2071 raise ValueError(
2072 f"{e_msg(descr)}.{tensor_origin}.dtype '{array.dtype.name}' does not"
2073 + f" match described dtype '{descr.dtype}'"
2074 )
2076 if array.min() > -1e-4 and array.max() < 1e-4:
2077 raise ValueError(
2078 "Output values are too small for reliable testing."
2079 + f" Values <-1e5 or >=1e5 must be present in {tensor_origin}"
2080 )
2082 for a in descr.axes:
2083 actual_size = all_tensor_axes[descr.id][a.id][1]
2084 if a.size is None:
2085 continue
2087 if isinstance(a.size, int):
2088 if actual_size != a.size:
2089 raise ValueError(
2090 f"{e_msg(descr)}.{tensor_origin}: axis '{a.id}' "
2091 + f"has incompatible size {actual_size}, expected {a.size}"
2092 )
2093 elif isinstance(a.size, ParameterizedSize):
2094 _ = a.size.validate_size(actual_size)
2095 elif isinstance(a.size, DataDependentSize):
2096 _ = a.size.validate_size(actual_size)
2097 elif isinstance(a.size, SizeReference):
2098 ref_tensor_axes = all_tensor_axes.get(a.size.tensor_id)
2099 if ref_tensor_axes is None:
2100 raise ValueError(
2101 f"{e_msg(descr)}.axes[{a.id}].size.tensor_id: Unknown tensor"
2102 + f" reference '{a.size.tensor_id}'"
2103 )
2105 ref_axis, ref_size = ref_tensor_axes.get(a.size.axis_id, (None, None))
2106 if ref_axis is None or ref_size is None:
2107 raise ValueError(
2108 f"{e_msg(descr)}.axes[{a.id}].size.axis_id: Unknown tensor axis"
2109 + f" reference '{a.size.tensor_id}.{a.size.axis_id}"
2110 )
2112 if a.unit != ref_axis.unit:
2113 raise ValueError(
2114 f"{e_msg(descr)}.axes[{a.id}].size: `SizeReference` requires"
2115 + " axis and reference axis to have the same `unit`, but"
2116 + f" {a.unit}!={ref_axis.unit}"
2117 )
2119 if actual_size != (
2120 expected_size := (
2121 ref_size * ref_axis.scale / a.scale + a.size.offset
2122 )
2123 ):
2124 raise ValueError(
2125 f"{e_msg(descr)}.{tensor_origin}: axis '{a.id}' of size"
2126 + f" {actual_size} invalid for referenced size {ref_size};"
2127 + f" expected {expected_size}"
2128 )
2129 else:
2130 assert_never(a.size)
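# Worked example (illustrative values) for the `SizeReference` check above:
# with ref_size=128, ref_axis.scale=1.0, a.scale=0.5 and a.size.offset=0 the expected
# size is 128 * 1.0 / 0.5 + 0 = 256, so only an actual axis size of 256 passes; any
# other size raises the ValueError above.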
2133class EnvironmentFileDescr(FileDescr):
2134 source: Annotated[
2135 ImportantFileSource,
2136 WithSuffix((".yaml", ".yml"), case_sensitive=True),
2137 Field(
2138 examples=["environment.yaml"],
2139 ),
2140 ]
2141 """∈📦 Conda environment file.
2142 Allows specifying custom dependencies; see the conda docs:
2143 - [Exporting an environment file across platforms](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#exporting-an-environment-file-across-platforms)
2144 - [Creating an environment file manually](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#creating-an-environment-file-manually)
2145 """
2148class _ArchitectureCallableDescr(Node):
2149 callable: Annotated[Identifier, Field(examples=["MyNetworkClass", "get_my_model"])]
2150 """Identifier of the callable that returns a torch.nn.Module instance."""
2152 kwargs: Dict[str, YamlValue] = Field(default_factory=dict)
2153 """key word arguments for the `callable`"""
2156class ArchitectureFromFileDescr(_ArchitectureCallableDescr, FileDescr):
2157 pass
2160class ArchitectureFromLibraryDescr(_ArchitectureCallableDescr):
2161 import_from: str
2162 """Where to import the callable from, i.e. `from <import_from> import <callable>`"""
2165ArchitectureDescr = Annotated[
2166 Union[ArchitectureFromFileDescr, ArchitectureFromLibraryDescr],
2167 Field(union_mode="left_to_right"),
2168]
2171class _ArchFileConv(
2172 Converter[
2173 _CallableFromFile_v0_4,
2174 ArchitectureFromFileDescr,
2175 Optional[Sha256],
2176 Dict[str, Any],
2177 ]
2178):
2179 def _convert(
2180 self,
2181 src: _CallableFromFile_v0_4,
2182 tgt: "type[ArchitectureFromFileDescr | dict[str, Any]]",
2183 sha256: Optional[Sha256],
2184 kwargs: Dict[str, Any],
2185 ) -> "ArchitectureFromFileDescr | dict[str, Any]":
2186 if src.startswith("http") and src.count(":") == 2:
2187 http, source, callable_ = src.split(":")
2188 source = ":".join((http, source))
2189 elif not src.startswith("http") and src.count(":") == 1:
2190 source, callable_ = src.split(":")
2191 else:
2192 source = str(src)
2193 callable_ = str(src)
2194 return tgt(
2195 callable=Identifier(callable_),
2196 source=cast(ImportantFileSource, source),
2197 sha256=sha256,
2198 kwargs=kwargs,
2199 )
2202_arch_file_conv = _ArchFileConv(_CallableFromFile_v0_4, ArchitectureFromFileDescr)
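# Worked example (illustrative) of the source/callable split performed above: a 0.4
# callable like "https://example.com/model_arch.py:MyNetworkClass" contains two ':' and
# is split into source="https://example.com/model_arch.py" and callable="MyNetworkClass";
# a local "model_arch.py:MyNetworkClass" (one ':') splits into the file path and the
# callable name directly.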
2205class _ArchLibConv(
2206 Converter[
2207 _CallableFromDepencency_v0_4, ArchitectureFromLibraryDescr, Dict[str, Any]
2208 ]
2209):
2210 def _convert(
2211 self,
2212 src: _CallableFromDepencency_v0_4,
2213 tgt: "type[ArchitectureFromLibraryDescr | dict[str, Any]]",
2214 kwargs: Dict[str, Any],
2215 ) -> "ArchitectureFromLibraryDescr | dict[str, Any]":
2216 *mods, callable_ = src.split(".")
2217 import_from = ".".join(mods)
2218 return tgt(
2219 import_from=import_from, callable=Identifier(callable_), kwargs=kwargs
2220 )
2223_arch_lib_conv = _ArchLibConv(
2224 _CallableFromDepencency_v0_4, ArchitectureFromLibraryDescr
2225)
2228class WeightsEntryDescrBase(FileDescr):
2229 type: ClassVar[WeightsFormat]
2230 weights_format_name: ClassVar[str] # human readable
2232 source: ImportantFileSource
2233 """∈📦 The weights file."""
2235 authors: Optional[List[Author]] = None
2236 """Authors
2237 Either the person(s) that have trained this model resulting in the original weights file.
2238 (If this is the initial weights entry, i.e. it does not have a `parent`)
2239 Or the person(s) who have converted the weights to this weights format.
2240 (If this is a child weight, i.e. it has a `parent` field)
2241 """
2243 parent: Annotated[
2244 Optional[WeightsFormat], Field(examples=["pytorch_state_dict"])
2245 ] = None
2246 """The source weights these weights were converted from.
2247 For example, if a model's weights were converted from the `pytorch_state_dict` format to `torchscript`,
2248 The `pytorch_state_dict` weights entry has no `parent` and is the parent of the `torchscript` weights.
2249 All weight entries except one (the initial set of weights resulting from training the model),
2250 need to have this field."""
2252 comment: str = ""
2253 """A comment about this weights entry, for example how these weights were created."""
2255 @model_validator(mode="after")
2256 def check_parent_is_not_self(self) -> Self:
2257 if self.type == self.parent:
2258 raise ValueError("Weights entry can't be its own parent.")
2260 return self
2263class KerasHdf5WeightsDescr(WeightsEntryDescrBase):
2264 type = "keras_hdf5"
2265 weights_format_name: ClassVar[str] = "Keras HDF5"
2266 tensorflow_version: Version
2267 """TensorFlow version used to create these weights."""
2270class OnnxWeightsDescr(WeightsEntryDescrBase):
2271 type = "onnx"
2272 weights_format_name: ClassVar[str] = "ONNX"
2273 opset_version: Annotated[int, Ge(7)]
2274 """ONNX opset version"""
2277class PytorchStateDictWeightsDescr(WeightsEntryDescrBase):
2278 type = "pytorch_state_dict"
2279 weights_format_name: ClassVar[str] = "Pytorch State Dict"
2280 architecture: ArchitectureDescr
2281 pytorch_version: Version
2282 """Version of the PyTorch library used.
2283 If `dependencies` is specified it has to include pytorch and any version pinning has to be compatible.
2284 """
2285 dependencies: Optional[EnvironmentFileDescr] = None
2286 """Custom depencies beyond pytorch.
2287 The conda environment file should include pytorch and any version pinning has to be compatible with
2288 `pytorch_version`.
2289 """
2292class TensorflowJsWeightsDescr(WeightsEntryDescrBase):
2293 type = "tensorflow_js"
2294 weights_format_name: ClassVar[str] = "Tensorflow.js"
2295 tensorflow_version: Version
2296 """Version of the TensorFlow library used."""
2298 source: ImportantFileSource
2299 """∈📦 The multi-file weights.
2300 All required files/folders should be packaged in a zip archive."""
2303class TensorflowSavedModelBundleWeightsDescr(WeightsEntryDescrBase):
2304 type = "tensorflow_saved_model_bundle"
2305 weights_format_name: ClassVar[str] = "Tensorflow Saved Model"
2306 tensorflow_version: Version
2307 """Version of the TensorFlow library used."""
2309 dependencies: Optional[EnvironmentFileDescr] = None
2310 """Custom dependencies beyond tensorflow.
2311 Should include tensorflow, and any version pinning has to be compatible with `tensorflow_version`."""
2313 source: ImportantFileSource
2314 """∈📦 The multi-file weights.
2315 All required files/folders should be packaged in a zip archive."""
2318class TorchscriptWeightsDescr(WeightsEntryDescrBase):
2319 type = "torchscript"
2320 weights_format_name: ClassVar[str] = "TorchScript"
2321 pytorch_version: Version
2322 """Version of the PyTorch library used."""
2325class WeightsDescr(Node):
2326 keras_hdf5: Optional[KerasHdf5WeightsDescr] = None
2327 onnx: Optional[OnnxWeightsDescr] = None
2328 pytorch_state_dict: Optional[PytorchStateDictWeightsDescr] = None
2329 tensorflow_js: Optional[TensorflowJsWeightsDescr] = None
2330 tensorflow_saved_model_bundle: Optional[TensorflowSavedModelBundleWeightsDescr] = (
2331 None
2332 )
2333 torchscript: Optional[TorchscriptWeightsDescr] = None
2335 @model_validator(mode="after")
2336 def check_entries(self) -> Self:
2337 entries = {wtype for wtype, entry in self if entry is not None}
2339 if not entries:
2340 raise ValueError("Missing weights entry")
2342 entries_wo_parent = {
2343 wtype
2344 for wtype, entry in self
2345 if entry is not None and hasattr(entry, "parent") and entry.parent is None
2346 }
2347 if len(entries_wo_parent) != 1:
2348 issue_warning(
2349 "Exactly one weights entry may not specify the `parent` field (got"
2350 + " {value}). That entry is considered the original set of model weights."
2351 + " Other weight formats are created through conversion of the orignal or"
2352 + " already converted weights. They have to reference the weights format"
2353 + " they were converted from as their `parent`.",
2354 value=len(entries_wo_parent),
2355 field="weights",
2356 )
2358 for wtype, entry in self:
2359 if entry is None:
2360 continue
2362 assert hasattr(entry, "type")
2363 assert hasattr(entry, "parent")
2364 assert wtype == entry.type
2365 if (
2366 entry.parent is not None and entry.parent not in entries
2367 ): # self reference checked for `parent` field
2368 raise ValueError(
2369 f"`weights.{wtype}.parent={entry.parent} not in specified weight"
2370 + f" formats: {entries}"
2371 )
2373 return self
2375 def __getitem__(
2376 self,
2377 key: Literal[
2378 "keras_hdf5",
2379 "onnx",
2380 "pytorch_state_dict",
2381 "tensorflow_js",
2382 "tensorflow_saved_model_bundle",
2383 "torchscript",
2384 ],
2385 ):
2386 if key == "keras_hdf5":
2387 ret = self.keras_hdf5
2388 elif key == "onnx":
2389 ret = self.onnx
2390 elif key == "pytorch_state_dict":
2391 ret = self.pytorch_state_dict
2392 elif key == "tensorflow_js":
2393 ret = self.tensorflow_js
2394 elif key == "tensorflow_saved_model_bundle":
2395 ret = self.tensorflow_saved_model_bundle
2396 elif key == "torchscript":
2397 ret = self.torchscript
2398 else:
2399 raise KeyError(key)
2401 if ret is None:
2402 raise KeyError(key)
2404 return ret
2406 @property
2407 def available_formats(self):
2408 return {
2409 **({} if self.keras_hdf5 is None else {"keras_hdf5": self.keras_hdf5}),
2410 **({} if self.onnx is None else {"onnx": self.onnx}),
2411 **(
2412 {}
2413 if self.pytorch_state_dict is None
2414 else {"pytorch_state_dict": self.pytorch_state_dict}
2415 ),
2416 **(
2417 {}
2418 if self.tensorflow_js is None
2419 else {"tensorflow_js": self.tensorflow_js}
2420 ),
2421 **(
2422 {}
2423 if self.tensorflow_saved_model_bundle is None
2424 else {
2425 "tensorflow_saved_model_bundle": self.tensorflow_saved_model_bundle
2426 }
2427 ),
2428 **({} if self.torchscript is None else {"torchscript": self.torchscript}),
2429 }
2431 @property
2432 def missing_formats(self):
2433 return {
2434 wf for wf in get_args(WeightsFormat) if wf not in self.available_formats
2435 }
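# Usage sketch (assumes `model` is a loaded `ModelDescr` that ships torchscript weights):
#
#     model.weights.available_formats   # e.g. {"pytorch_state_dict": ..., "torchscript": ...}
#     model.weights.missing_formats     # the remaining entries of `WeightsFormat`
#     model.weights["torchscript"]      # the TorchscriptWeightsDescr entry (KeyError if absent)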
2438class ModelId(ResourceId):
2439 pass
2442class LinkedModel(LinkedResourceBase):
2443 """Reference to a bioimage.io model."""
2445 id: ModelId
2446 """A valid model `id` from the bioimage.io collection."""
2449class _DataDepSize(NamedTuple):
2450 min: int
2451 max: Optional[int]
2454class _AxisSizes(NamedTuple):
2455 """the lenghts of all axes of model inputs and outputs"""
2457 inputs: Dict[Tuple[TensorId, AxisId], int]
2458 outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]]
2461class _TensorSizes(NamedTuple):
2462 """_AxisSizes as nested dicts"""
2464 inputs: Dict[TensorId, Dict[AxisId, int]]
2465 outputs: Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]
2468class ReproducibilityTolerance(Node, extra="allow"):
2469 """Describes what small numerical differences -- if any -- may be tolerated
2470 in the generated output when executing in different environments.
2472 A tensor element *output* is considered mismatched to the **test_tensor** if
2473 abs(*output* - **test_tensor**) > **absolute_tolerance** + **relative_tolerance** * abs(**test_tensor**).
2474 (Internally we call [numpy.testing.assert_allclose](https://numpy.org/doc/stable/reference/generated/numpy.testing.assert_allclose.html).)
2476 Motivation:
2477 For testing we can request the respective deep learning frameworks to be as
2478 reproducible as possible by setting seeds and choosing deterministic algorithms,
2479 but differences in operating systems, available hardware and installed drivers
2480 may still lead to numerical differences.
2481 """
2483 relative_tolerance: RelativeTolerance = 1e-3
2484 """Maximum relative tolerance of reproduced test tensor."""
2486 absolute_tolerance: AbsoluteTolerance = 1e-4
2487 """Maximum absolute tolerance of reproduced test tensor."""
2489 mismatched_elements_per_million: MismatchedElementsPerMillion = 0
2490 """Maximum number of mismatched elements/pixels per million to tolerate."""
2492 output_ids: Sequence[TensorId] = ()
2493 """Limits the output tensor IDs these reproducibility details apply to."""
2495 weights_formats: Sequence[WeightsFormat] = ()
2496 """Limits the weights formats these details apply to."""
2499class BioimageioConfig(Node, extra="allow"):
2500 reproducibility_tolerance: Sequence[ReproducibilityTolerance] = ()
2501 """Tolerances to allow when reproducing the model's test outputs
2502 from the model's test inputs.
2503 Only the first entry matching tensor id and weights format is considered.
2504 """
2507class Config(Node, extra="allow"):
2508 bioimageio: BioimageioConfig = Field(default_factory=BioimageioConfig)
2511class ModelDescr(GenericModelDescrBase):
2512 """Specification of the fields used in a bioimage.io-compliant RDF to describe AI models with pretrained weights.
2513 These fields are typically stored in a YAML file which we call a model resource description file (model RDF).
2514 """
2516 implemented_format_version: ClassVar[Literal["0.5.4"]] = "0.5.4"
2517 if TYPE_CHECKING:
2518 format_version: Literal["0.5.4"] = "0.5.4"
2519 else:
2520 format_version: Literal["0.5.4"]
2521 """Version of the bioimage.io model description specification used.
2522 When creating a new model always use the latest micro/patch version described here.
2523 The `format_version` is important for any consumer software to understand how to parse the fields.
2524 """
2526 implemented_type: ClassVar[Literal["model"]] = "model"
2527 if TYPE_CHECKING:
2528 type: Literal["model"] = "model"
2529 else:
2530 type: Literal["model"]
2531 """Specialized resource type 'model'"""
2533 id: Optional[ModelId] = None
2534 """bioimage.io-wide unique resource identifier
2535 assigned by bioimage.io; version **un**specific."""
2537 authors: NotEmpty[List[Author]]
2538 """The authors are the creators of the model RDF and the primary points of contact."""
2540 documentation: Annotated[
2541 DocumentationSource,
2542 Field(
2543 examples=[
2544 "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/unet2d_nuclei_broad/README.md",
2545 "README.md",
2546 ],
2547 ),
2548 ]
2549 """∈📦 URL or relative path to a markdown file with additional documentation.
2550 The recommended documentation file name is `README.md`. An `.md` suffix is mandatory.
2551 The documentation should include a '#[#] Validation' (sub)section
2552 with details on how to quantitatively validate the model on unseen data."""
2554 @field_validator("documentation", mode="after")
2555 @classmethod
2556 def _validate_documentation(cls, value: DocumentationSource) -> DocumentationSource:
2557 if not get_validation_context().perform_io_checks:
2558 return value
2560 doc_path = download(value).path
2561 doc_content = doc_path.read_text(encoding="utf-8")
2562 assert isinstance(doc_content, str)
2563 if not re.search("#.*[vV]alidation", doc_content):
2564 issue_warning(
2565 "No '# Validation' (sub)section found in {value}.",
2566 value=value,
2567 field="documentation",
2568 )
2570 return value
2572 inputs: NotEmpty[Sequence[InputTensorDescr]]
2573 """Describes the input tensors expected by this model."""
2575 @field_validator("inputs", mode="after")
2576 @classmethod
2577 def _validate_input_axes(
2578 cls, inputs: Sequence[InputTensorDescr]
2579 ) -> Sequence[InputTensorDescr]:
2580 input_size_refs = cls._get_axes_with_independent_size(inputs)
2582 for i, ipt in enumerate(inputs):
2583 valid_independent_refs: Dict[
2584 Tuple[TensorId, AxisId],
2585 Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
2586 ] = {
2587 **{
2588 (ipt.id, a.id): (ipt, a, a.size)
2589 for a in ipt.axes
2590 if not isinstance(a, BatchAxis)
2591 and isinstance(a.size, (int, ParameterizedSize))
2592 },
2593 **input_size_refs,
2594 }
2595 for a, ax in enumerate(ipt.axes):
2596 cls._validate_axis(
2597 "inputs",
2598 i=i,
2599 tensor_id=ipt.id,
2600 a=a,
2601 axis=ax,
2602 valid_independent_refs=valid_independent_refs,
2603 )
2604 return inputs
2606 @staticmethod
2607 def _validate_axis(
2608 field_name: str,
2609 i: int,
2610 tensor_id: TensorId,
2611 a: int,
2612 axis: AnyAxis,
2613 valid_independent_refs: Dict[
2614 Tuple[TensorId, AxisId],
2615 Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
2616 ],
2617 ):
2618 if isinstance(axis, BatchAxis) or isinstance(
2619 axis.size, (int, ParameterizedSize, DataDependentSize)
2620 ):
2621 return
2622 elif not isinstance(axis.size, SizeReference):
2623 assert_never(axis.size)
2625 # validate axis.size SizeReference
2626 ref = (axis.size.tensor_id, axis.size.axis_id)
2627 if ref not in valid_independent_refs:
2628 raise ValueError(
2629 "Invalid tensor axis reference at"
2630 + f" {field_name}[{i}].axes[{a}].size: {axis.size}."
2631 )
2632 if ref == (tensor_id, axis.id):
2633 raise ValueError(
2634 "Self-referencing not allowed for"
2635 + f" {field_name}[{i}].axes[{a}].size: {axis.size}"
2636 )
2637 if axis.type == "channel":
2638 if valid_independent_refs[ref][1].type != "channel":
2639 raise ValueError(
2640 "A channel axis' size may only reference another fixed size"
2641 + " channel axis."
2642 )
2643 if isinstance(axis.channel_names, str) and "{i}" in axis.channel_names:
2644 ref_size = valid_independent_refs[ref][2]
2645 assert isinstance(ref_size, int), (
2646 "channel axis ref (another channel axis) has to specify fixed"
2647 + " size"
2648 )
2649 generated_channel_names = [
2650 Identifier(axis.channel_names.format(i=i))
2651 for i in range(1, ref_size + 1)
2652 ]
2653 axis.channel_names = generated_channel_names
2655 if (ax_unit := getattr(axis, "unit", None)) != (
2656 ref_unit := getattr(valid_independent_refs[ref][1], "unit", None)
2657 ):
2658 raise ValueError(
2659 "The units of an axis and its reference axis need to match, but"
2660 + f" '{ax_unit}' != '{ref_unit}'."
2661 )
2662 ref_axis = valid_independent_refs[ref][1]
2663 if isinstance(ref_axis, BatchAxis):
2664 raise ValueError(
2665 f"Invalid reference axis '{ref_axis.id}' for {tensor_id}.{axis.id}"
2666 + " (a batch axis is not allowed as reference)."
2667 )
2669 if isinstance(axis, WithHalo):
2670 min_size = axis.size.get_size(axis, ref_axis, n=0)
2671 if (min_size - 2 * axis.halo) < 1:
2672 raise ValueError(
2673 f"axis {axis.id} with minimum size {min_size} is too small for halo"
2674 + f" {axis.halo}."
2675 )
2677 input_halo = axis.halo * axis.scale / ref_axis.scale
2678 if input_halo != int(input_halo) or input_halo % 2 == 1:
2679 raise ValueError(
2680 f"input_halo {input_halo} (output_halo {axis.halo} *"
2681 + f" output_scale {axis.scale} / input_scale {ref_axis.scale})"
2682 + f" {tensor_id}.{axis.id}."
2683 )
2685 @model_validator(mode="after")
2686 def _validate_test_tensors(self) -> Self:
2687 if not get_validation_context().perform_io_checks:
2688 return self
2690 test_output_arrays = [
2691 load_array(descr.test_tensor.download().path) for descr in self.outputs
2692 ]
2693 test_input_arrays = [
2694 load_array(descr.test_tensor.download().path) for descr in self.inputs
2695 ]
2697 tensors = {
2698 descr.id: (descr, array)
2699 for descr, array in zip(
2700 chain(self.inputs, self.outputs), test_input_arrays + test_output_arrays
2701 )
2702 }
2703 validate_tensors(tensors, tensor_origin="test_tensor")
2705 output_arrays = {
2706 descr.id: array for descr, array in zip(self.outputs, test_output_arrays)
2707 }
2708 for rep_tol in self.config.bioimageio.reproducibility_tolerance:
2709 if not rep_tol.absolute_tolerance:
2710 continue
2712 if rep_tol.output_ids:
2713 out_arrays = {
2714 oid: a
2715 for oid, a in output_arrays.items()
2716 if oid in rep_tol.output_ids
2717 }
2718 else:
2719 out_arrays = output_arrays
2721 for out_id, array in out_arrays.items():
2722 if rep_tol.absolute_tolerance > (max_test_value := array.max()) * 0.01:
2723 raise ValueError(
2724 "config.bioimageio.reproducibility_tolerance.absolute_tolerance="
2725 + f"{rep_tol.absolute_tolerance} > 0.01*{max_test_value}"
2726 + f" (1% of the maximum value of the test tensor '{out_id}')"
2727 )
2729 return self
2731 @model_validator(mode="after")
2732 def _validate_tensor_references_in_proc_kwargs(self, info: ValidationInfo) -> Self:
2733 ipt_refs = {t.id for t in self.inputs}
2734 out_refs = {t.id for t in self.outputs}
2735 for ipt in self.inputs:
2736 for p in ipt.preprocessing:
2737 ref = p.kwargs.get("reference_tensor")
2738 if ref is None:
2739 continue
2740 if ref not in ipt_refs:
2741 raise ValueError(
2742 f"`reference_tensor` '{ref}' not found. Valid input tensor"
2743 + f" references are: {ipt_refs}."
2744 )
2746 for out in self.outputs:
2747 for p in out.postprocessing:
2748 ref = p.kwargs.get("reference_tensor")
2749 if ref is None:
2750 continue
2752 if ref not in ipt_refs and ref not in out_refs:
2753 raise ValueError(
2754 f"`reference_tensor` '{ref}' not found. Valid tensor references"
2755 + f" are: {ipt_refs | out_refs}."
2756 )
2758 return self
2760 # TODO: use validate funcs in validate_test_tensors
2761 # def validate_inputs(self, input_tensors: Mapping[TensorId, NDArray[Any]]) -> Mapping[TensorId, NDArray[Any]]:
2763 name: Annotated[
2764 Annotated[
2765 str, RestrictCharacters(string.ascii_letters + string.digits + "_+- ()")
2766 ],
2767 MinLen(5),
2768 MaxLen(128),
2769 warn(MaxLen(64), "Name longer than 64 characters.", INFO),
2770 ]
2771 """A human-readable name of this model.
2772 It should be no longer than 64 characters
2773 and may only contain letters, numbers, underscores, minus signs, parentheses and spaces.
2774 We recommend choosing a name that refers to the model's task and image modality.
2775 """
2777 outputs: NotEmpty[Sequence[OutputTensorDescr]]
2778 """Describes the output tensors."""
2780 @field_validator("outputs", mode="after")
2781 @classmethod
2782 def _validate_tensor_ids(
2783 cls, outputs: Sequence[OutputTensorDescr], info: ValidationInfo
2784 ) -> Sequence[OutputTensorDescr]:
2785 tensor_ids = [
2786 t.id for t in info.data.get("inputs", []) + info.data.get("outputs", [])
2787 ]
2788 duplicate_tensor_ids: List[str] = []
2789 seen: Set[str] = set()
2790 for t in tensor_ids:
2791 if t in seen:
2792 duplicate_tensor_ids.append(t)
2794 seen.add(t)
2796 if duplicate_tensor_ids:
2797 raise ValueError(f"Duplicate tensor ids: {duplicate_tensor_ids}")
2799 return outputs
2801 @staticmethod
2802 def _get_axes_with_parameterized_size(
2803 io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]],
2804 ):
2805 return {
2806 f"{t.id}.{a.id}": (t, a, a.size)
2807 for t in io
2808 for a in t.axes
2809 if not isinstance(a, BatchAxis) and isinstance(a.size, ParameterizedSize)
2810 }
2812 @staticmethod
2813 def _get_axes_with_independent_size(
2814 io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]],
2815 ):
2816 return {
2817 (t.id, a.id): (t, a, a.size)
2818 for t in io
2819 for a in t.axes
2820 if not isinstance(a, BatchAxis)
2821 and isinstance(a.size, (int, ParameterizedSize))
2822 }
2824 @field_validator("outputs", mode="after")
2825 @classmethod
2826 def _validate_output_axes(
2827 cls, outputs: List[OutputTensorDescr], info: ValidationInfo
2828 ) -> List[OutputTensorDescr]:
2829 input_size_refs = cls._get_axes_with_independent_size(
2830 info.data.get("inputs", [])
2831 )
2832 output_size_refs = cls._get_axes_with_independent_size(outputs)
2834 for i, out in enumerate(outputs):
2835 valid_independent_refs: Dict[
2836 Tuple[TensorId, AxisId],
2837 Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
2838 ] = {
2839 **{
2840 (out.id, a.id): (out, a, a.size)
2841 for a in out.axes
2842 if not isinstance(a, BatchAxis)
2843 and isinstance(a.size, (int, ParameterizedSize))
2844 },
2845 **input_size_refs,
2846 **output_size_refs,
2847 }
2848 for a, ax in enumerate(out.axes):
2849 cls._validate_axis(
2850 "outputs",
2851 i,
2852 out.id,
2853 a,
2854 ax,
2855 valid_independent_refs=valid_independent_refs,
2856 )
2858 return outputs
2860 packaged_by: List[Author] = Field(default_factory=list)
2861 """The persons that have packaged and uploaded this model.
2862 Only required if those persons differ from the `authors`."""
2864 parent: Optional[LinkedModel] = None
2865 """The model from which this model is derived, e.g. by fine-tuning the weights."""
2867 @model_validator(mode="after")
2868 def _validate_parent_is_not_self(self) -> Self:
2869 if self.parent is not None and self.parent.id == self.id:
2870 raise ValueError("A model description may not reference itself as parent.")
2872 return self
2874 run_mode: Annotated[
2875 Optional[RunMode],
2876 warn(None, "Run mode '{value}' has limited support across consumer software."),
2877 ] = None
2878 """Custom run mode for this model: for more complex prediction procedures like test time
2879 data augmentation that currently cannot be expressed in the specification.
2880 No standard run modes are defined yet."""
2882 timestamp: Datetime = Field(default_factory=Datetime.now)
2883 """Timestamp in [ISO 8601](#https://en.wikipedia.org/wiki/ISO_8601) format
2884 with a few restrictions listed [here](https://docs.python.org/3/library/datetime.html#datetime.datetime.fromisoformat).
2885 (In Python a datetime object is valid, too)."""
2887 training_data: Annotated[
2888 Union[None, LinkedDataset, DatasetDescr, DatasetDescr02],
2889 Field(union_mode="left_to_right"),
2890 ] = None
2891 """The dataset used to train this model"""
2893 weights: Annotated[WeightsDescr, WrapSerializer(package_weights)]
2894 """The weights for this model.
2895 Weights can be given for different formats, but should otherwise be equivalent.
2896 The available weight formats determine which consumers can use this model."""
2898 config: Config = Field(default_factory=Config)
2900 @model_validator(mode="after")
2901 def _add_default_cover(self) -> Self:
2902 if not get_validation_context().perform_io_checks or self.covers:
2903 return self
2905 try:
2906 generated_covers = generate_covers(
2907 [(t, load_array(t.test_tensor.download().path)) for t in self.inputs],
2908 [(t, load_array(t.test_tensor.download().path)) for t in self.outputs],
2909 )
2910 except Exception as e:
2911 issue_warning(
2912 "Failed to generate cover image(s): {e}",
2913 value=self.covers,
2914 msg_context=dict(e=e),
2915 field="covers",
2916 )
2917 else:
2918 self.covers.extend(generated_covers)
2920 return self
2922 def get_input_test_arrays(self) -> List[NDArray[Any]]:
2923 data = [load_array(ipt.test_tensor.download().path) for ipt in self.inputs]
2924 assert all(isinstance(d, np.ndarray) for d in data)
2925 return data
2927 def get_output_test_arrays(self) -> List[NDArray[Any]]:
2928 data = [load_array(out.test_tensor.download().path) for out in self.outputs]
2929 assert all(isinstance(d, np.ndarray) for d in data)
2930 return data
2932 @staticmethod
2933 def get_batch_size(tensor_sizes: Mapping[TensorId, Mapping[AxisId, int]]) -> int:
2934 batch_size = 1
2935 tensor_with_batchsize: Optional[TensorId] = None
2936 for tid in tensor_sizes:
2937 for aid, s in tensor_sizes[tid].items():
2938 if aid != BATCH_AXIS_ID or s == 1 or s == batch_size:
2939 continue
2941 if batch_size != 1:
2942 assert tensor_with_batchsize is not None
2943 raise ValueError(
2944 f"batch size mismatch for tensors '{tensor_with_batchsize}' ({batch_size}) and '{tid}' ({s})"
2945 )
2947 batch_size = s
2948 tensor_with_batchsize = tid
2950 return batch_size
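# Illustrative example (hypothetical tensor/axis ids; assumes BATCH_AXIS_ID == AxisId("batch"),
# as used elsewhere in this module):
#
#     sizes = {
#         TensorId("raw"): {AxisId("batch"): 2, AxisId("x"): 128},
#         TensorId("mask"): {AxisId("batch"): 2, AxisId("x"): 128},
#     }
#     ModelDescr.get_batch_size(sizes)  # -> 2; mismatching batch sizes raise a ValueError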
2952 def get_output_tensor_sizes(
2953 self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]
2954 ) -> Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]:
2955 """Returns the tensor output sizes for given **input_sizes**.
2956 The returned output sizes are exact only if **input_sizes** describes a valid input shape.
2957 Otherwise they might be larger than the actual (valid) output."""
2958 batch_size = self.get_batch_size(input_sizes)
2959 ns = self.get_ns(input_sizes)
2961 tensor_sizes = self.get_tensor_sizes(ns, batch_size=batch_size)
2962 return tensor_sizes.outputs
2964 def get_ns(self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]):
2965 """get parameter `n` for each parameterized axis
2966 such that the valid input size is >= the given input size"""
2967 ret: Dict[Tuple[TensorId, AxisId], ParameterizedSize_N] = {}
2968 axes = {t.id: {a.id: a for a in t.axes} for t in self.inputs}
2969 for tid in input_sizes:
2970 for aid, s in input_sizes[tid].items():
2971 size_descr = axes[tid][aid].size
2972 if isinstance(size_descr, ParameterizedSize):
2973 ret[(tid, aid)] = size_descr.get_n(s)
2974 elif size_descr is None or isinstance(size_descr, (int, SizeReference)):
2975 pass
2976 else:
2977 assert_never(size_descr)
2979 return ret
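# Worked example (illustrative) of how `n` is derived: for a parameterized axis with
# min=64 and step=16 (size = 64 + n*16), a requested input size of 100 yields n=3,
# i.e. the smallest n whose valid size (112) is >= 100.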
2981 def get_tensor_sizes(
2982 self, ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], batch_size: int
2983 ) -> _TensorSizes:
2984 axis_sizes = self.get_axis_sizes(ns, batch_size=batch_size)
2985 return _TensorSizes(
2986 {
2987 t: {
2988 aa: axis_sizes.inputs[(tt, aa)]
2989 for tt, aa in axis_sizes.inputs
2990 if tt == t
2991 }
2992 for t in {tt for tt, _ in axis_sizes.inputs}
2993 },
2994 {
2995 t: {
2996 aa: axis_sizes.outputs[(tt, aa)]
2997 for tt, aa in axis_sizes.outputs
2998 if tt == t
2999 }
3000 for t in {tt for tt, _ in axis_sizes.outputs}
3001 },
3002 )
3004 def get_axis_sizes(
3005 self,
3006 ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N],
3007 batch_size: Optional[int] = None,
3008 *,
3009 max_input_shape: Optional[Mapping[Tuple[TensorId, AxisId], int]] = None,
3010 ) -> _AxisSizes:
3011 """Determine input and output block shape for scale factors **ns**
3012 of parameterized input sizes.
3014 Args:
3015 ns: Scale factor `n` for each axis (keyed by (tensor_id, axis_id))
3016 that is parameterized as `size = min + n * step`.
3017 batch_size: The desired size of the batch dimension.
3018 If given, **batch_size** overrides any batch size present in
3019 **max_input_shape**. Defaults to 1.
3020 max_input_shape: Limits the derived block shapes.
3021 Each axis for which the input size, parameterized by `n`, is larger
3022 than **max_input_shape** is set to the minimal value `n_min` for which
3023 this is still true.
3024 Use this for small input samples or large values of **ns**.
3025 Or simply whenever you know the full input shape.
3027 Returns:
3028 Resolved axis sizes for model inputs and outputs.
3029 """
3030 max_input_shape = max_input_shape or {}
3031 if batch_size is None:
3032 for (_t_id, a_id), s in max_input_shape.items():
3033 if a_id == BATCH_AXIS_ID:
3034 batch_size = s
3035 break
3036 else:
3037 batch_size = 1
3039 all_axes = {
3040 t.id: {a.id: a for a in t.axes} for t in chain(self.inputs, self.outputs)
3041 }
3043 inputs: Dict[Tuple[TensorId, AxisId], int] = {}
3044 outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]] = {}
3046 def get_axis_size(a: Union[InputAxis, OutputAxis]):
3047 if isinstance(a, BatchAxis):
3048 if (t_descr.id, a.id) in ns:
3049 logger.warning(
3050 "Ignoring unexpected size increment factor (n) for batch axis"
3051 + " of tensor '{}'.",
3052 t_descr.id,
3053 )
3054 return batch_size
3055 elif isinstance(a.size, int):
3056 if (t_descr.id, a.id) in ns:
3057 logger.warning(
3058 "Ignoring unexpected size increment factor (n) for fixed size"
3059 + " axis '{}' of tensor '{}'.",
3060 a.id,
3061 t_descr.id,
3062 )
3063 return a.size
3064 elif isinstance(a.size, ParameterizedSize):
3065 if (t_descr.id, a.id) not in ns:
3066 raise ValueError(
3067 "Size increment factor (n) missing for parametrized axis"
3068 + f" '{a.id}' of tensor '{t_descr.id}'."
3069 )
3070 n = ns[(t_descr.id, a.id)]
3071 s_max = max_input_shape.get((t_descr.id, a.id))
3072 if s_max is not None:
3073 n = min(n, a.size.get_n(s_max))
3075 return a.size.get_size(n)
3077 elif isinstance(a.size, SizeReference):
3078 if (t_descr.id, a.id) in ns:
3079 logger.warning(
3080 "Ignoring unexpected size increment factor (n) for axis '{}'"
3081 + " of tensor '{}' with size reference.",
3082 a.id,
3083 t_descr.id,
3084 )
3085 assert not isinstance(a, BatchAxis)
3086 ref_axis = all_axes[a.size.tensor_id][a.size.axis_id]
3087 assert not isinstance(ref_axis, BatchAxis)
3088 ref_key = (a.size.tensor_id, a.size.axis_id)
3089 ref_size = inputs.get(ref_key, outputs.get(ref_key))
3090 assert ref_size is not None, ref_key
3091 assert not isinstance(ref_size, _DataDepSize), ref_key
3092 return a.size.get_size(
3093 axis=a,
3094 ref_axis=ref_axis,
3095 ref_size=ref_size,
3096 )
3097 elif isinstance(a.size, DataDependentSize):
3098 if (t_descr.id, a.id) in ns:
3099 logger.warning(
3100 "Ignoring unexpected increment factor (n) for data dependent"
3101 + " size axis '{}' of tensor '{}'.",
3102 a.id,
3103 t_descr.id,
3104 )
3105 return _DataDepSize(a.size.min, a.size.max)
3106 else:
3107 assert_never(a.size)
3109 # first resolve all but the `SizeReference` input sizes
3110 for t_descr in self.inputs:
3111 for a in t_descr.axes:
3112 if not isinstance(a.size, SizeReference):
3113 s = get_axis_size(a)
3114 assert not isinstance(s, _DataDepSize)
3115 inputs[t_descr.id, a.id] = s
3117 # resolve all other input axis sizes
3118 for t_descr in self.inputs:
3119 for a in t_descr.axes:
3120 if isinstance(a.size, SizeReference):
3121 s = get_axis_size(a)
3122 assert not isinstance(s, _DataDepSize)
3123 inputs[t_descr.id, a.id] = s
3125 # resolve all output axis sizes
3126 for t_descr in self.outputs:
3127 for a in t_descr.axes:
3128 assert not isinstance(a.size, ParameterizedSize)
3129 s = get_axis_size(a)
3130 outputs[t_descr.id, a.id] = s
3132 return _AxisSizes(inputs=inputs, outputs=outputs)
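# Usage sketch (assumes `model` is a loaded ModelDescr with parameterized spatial axes;
# tensor/axis ids are hypothetical):
#
#     ns = model.get_ns({TensorId("raw"): {AxisId("x"): 512, AxisId("y"): 512}})
#     sizes = model.get_tensor_sizes(ns, batch_size=1)
#     sizes.inputs   # e.g. {TensorId("raw"): {AxisId("x"): 512, ...}}
#     sizes.outputs  # ints, or _DataDepSize(min, max) for data-dependent axes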
3134 @model_validator(mode="before")
3135 @classmethod
3136 def _convert(cls, data: Dict[str, Any]) -> Dict[str, Any]:
3137 cls.convert_from_old_format_wo_validation(data)
3138 return data
3140 @classmethod
3141 def convert_from_old_format_wo_validation(cls, data: Dict[str, Any]) -> None:
3142 """Convert metadata following an older format version to this classes' format
3143 without validating the result.
3144 """
3145 if (
3146 data.get("type") == "model"
3147 and isinstance(fv := data.get("format_version"), str)
3148 and fv.count(".") == 2
3149 ):
3150 fv_parts = fv.split(".")
3151 if any(not p.isdigit() for p in fv_parts):
3152 return
3154 fv_tuple = tuple(map(int, fv_parts))
3156 assert cls.implemented_format_version_tuple[0:2] == (0, 5)
3157 if fv_tuple[:2] in ((0, 3), (0, 4)):
3158 m04 = _ModelDescr_v0_4.load(data)
3159 if isinstance(m04, InvalidDescr):
3160 try:
3161 updated = _model_conv.convert_as_dict(
3162 m04 # pyright: ignore[reportArgumentType]
3163 )
3164 except Exception as e:
3165 logger.error(
3166 "Failed to convert from invalid model 0.4 description."
3167 + f"\nerror: {e}"
3168 + "\nProceeding with model 0.5 validation without conversion."
3169 )
3170 updated = None
3171 else:
3172 updated = _model_conv.convert_as_dict(m04)
3174 if updated is not None:
3175 data.clear()
3176 data.update(updated)
3178 elif fv_tuple[:2] == (0, 5):
3179 # bump patch version
3180 data["format_version"] = cls.implemented_format_version
3183class _ModelConv(Converter[_ModelDescr_v0_4, ModelDescr]):
3184 def _convert(
3185 self, src: _ModelDescr_v0_4, tgt: "type[ModelDescr] | type[dict[str, Any]]"
3186 ) -> "ModelDescr | dict[str, Any]":
3187 name = "".join(
3188 c if c in string.ascii_letters + string.digits + "_+- ()" else " "
3189 for c in src.name
3190 )
3192 def conv_authors(auths: Optional[Sequence[_Author_v0_4]]):
3193 conv = (
3194 _author_conv.convert if TYPE_CHECKING else _author_conv.convert_as_dict
3195 )
3196 return None if auths is None else [conv(a) for a in auths]
3198 if TYPE_CHECKING:
3199 arch_file_conv = _arch_file_conv.convert
3200 arch_lib_conv = _arch_lib_conv.convert
3201 else:
3202 arch_file_conv = _arch_file_conv.convert_as_dict
3203 arch_lib_conv = _arch_lib_conv.convert_as_dict
3205 input_size_refs = {
3206 ipt.name: {
3207 a: s
3208 for a, s in zip(
3209 ipt.axes,
3210 (
3211 ipt.shape.min
3212 if isinstance(ipt.shape, _ParameterizedInputShape_v0_4)
3213 else ipt.shape
3214 ),
3215 )
3216 }
3217 for ipt in src.inputs
3218 if ipt.shape
3219 }
3220 output_size_refs = {
3221 **{
3222 out.name: {a: s for a, s in zip(out.axes, out.shape)}
3223 for out in src.outputs
3224 if not isinstance(out.shape, _ImplicitOutputShape_v0_4)
3225 },
3226 **input_size_refs,
3227 }
3229 return tgt(
3230 attachments=(
3231 []
3232 if src.attachments is None
3233 else [FileDescr(source=f) for f in src.attachments.files]
3234 ),
3235 authors=[
3236 _author_conv.convert_as_dict(a) for a in src.authors
3237 ], # pyright: ignore[reportArgumentType]
3238 cite=[
3239 {"text": c.text, "doi": c.doi, "url": c.url} for c in src.cite
3240 ], # pyright: ignore[reportArgumentType]
3241 config=src.config, # pyright: ignore[reportArgumentType]
3242 covers=src.covers,
3243 description=src.description,
3244 documentation=src.documentation,
3245 format_version="0.5.4",
3246 git_repo=src.git_repo, # pyright: ignore[reportArgumentType]
3247 icon=src.icon,
3248 id=None if src.id is None else ModelId(src.id),
3249 id_emoji=src.id_emoji,
3250 license=src.license, # type: ignore
3251 links=src.links,
3252 maintainers=[
3253 _maintainer_conv.convert_as_dict(m) for m in src.maintainers
3254 ], # pyright: ignore[reportArgumentType]
3255 name=name,
3256 tags=src.tags,
3257 type=src.type,
3258 uploader=src.uploader,
3259 version=src.version,
3260 inputs=[ # pyright: ignore[reportArgumentType]
3261 _input_tensor_conv.convert_as_dict(ipt, tt, st, input_size_refs)
3262 for ipt, tt, st, in zip(
3263 src.inputs,
3264 src.test_inputs,
3265 src.sample_inputs or [None] * len(src.test_inputs),
3266 )
3267 ],
3268 outputs=[ # pyright: ignore[reportArgumentType]
3269 _output_tensor_conv.convert_as_dict(out, tt, st, output_size_refs)
3270 for out, tt, st, in zip(
3271 src.outputs,
3272 src.test_outputs,
3273 src.sample_outputs or [None] * len(src.test_outputs),
3274 )
3275 ],
3276 parent=(
3277 None
3278 if src.parent is None
3279 else LinkedModel(
3280 id=ModelId(
3281 str(src.parent.id)
3282 + (
3283 ""
3284 if src.parent.version_number is None
3285 else f"/{src.parent.version_number}"
3286 )
3287 )
3288 )
3289 ),
3290 training_data=(
3291 None
3292 if src.training_data is None
3293 else (
3294 LinkedDataset(
3295 id=DatasetId(
3296 str(src.training_data.id)
3297 + (
3298 ""
3299 if src.training_data.version_number is None
3300 else f"/{src.training_data.version_number}"
3301 )
3302 )
3303 )
3304 if isinstance(src.training_data, LinkedDataset02)
3305 else src.training_data
3306 )
3307 ),
3308 packaged_by=[
3309 _author_conv.convert_as_dict(a) for a in src.packaged_by
3310 ], # pyright: ignore[reportArgumentType]
3311 run_mode=src.run_mode,
3312 timestamp=src.timestamp,
3313 weights=(WeightsDescr if TYPE_CHECKING else dict)(
3314 keras_hdf5=(w := src.weights.keras_hdf5)
3315 and (KerasHdf5WeightsDescr if TYPE_CHECKING else dict)(
3316 authors=conv_authors(w.authors),
3317 source=w.source,
3318 tensorflow_version=w.tensorflow_version or Version("1.15"),
3319 parent=w.parent,
3320 ),
3321 onnx=(w := src.weights.onnx)
3322 and (OnnxWeightsDescr if TYPE_CHECKING else dict)(
3323 source=w.source,
3324 authors=conv_authors(w.authors),
3325 parent=w.parent,
3326 opset_version=w.opset_version or 15,
3327 ),
3328 pytorch_state_dict=(w := src.weights.pytorch_state_dict)
3329 and (PytorchStateDictWeightsDescr if TYPE_CHECKING else dict)(
3330 source=w.source,
3331 authors=conv_authors(w.authors),
3332 parent=w.parent,
3333 architecture=(
3334 arch_file_conv(
3335 w.architecture,
3336 w.architecture_sha256,
3337 w.kwargs,
3338 )
3339 if isinstance(w.architecture, _CallableFromFile_v0_4)
3340 else arch_lib_conv(w.architecture, w.kwargs)
3341 ),
3342 pytorch_version=w.pytorch_version or Version("1.10"),
3343 dependencies=(
3344 None
3345 if w.dependencies is None
3346 else (EnvironmentFileDescr if TYPE_CHECKING else dict)(
3347 source=cast(
3348 ImportantFileSource,
3349 str(deps := w.dependencies)[
3350 (
3351 len("conda:")
3352 if str(deps).startswith("conda:")
3353 else 0
3354 ) :
3355 ],
3356 )
3357 )
3358 ),
3359 ),
3360 tensorflow_js=(w := src.weights.tensorflow_js)
3361 and (TensorflowJsWeightsDescr if TYPE_CHECKING else dict)(
3362 source=w.source,
3363 authors=conv_authors(w.authors),
3364 parent=w.parent,
3365 tensorflow_version=w.tensorflow_version or Version("1.15"),
3366 ),
3367 tensorflow_saved_model_bundle=(
3368 w := src.weights.tensorflow_saved_model_bundle
3369 )
3370 and (TensorflowSavedModelBundleWeightsDescr if TYPE_CHECKING else dict)(
3371 authors=conv_authors(w.authors),
3372 parent=w.parent,
3373 source=w.source,
3374 tensorflow_version=w.tensorflow_version or Version("1.15"),
3375 dependencies=(
3376 None
3377 if w.dependencies is None
3378 else (EnvironmentFileDescr if TYPE_CHECKING else dict)(
3379 source=cast(
3380 ImportantFileSource,
3381 (
3382 str(w.dependencies)[len("conda:") :]
3383 if str(w.dependencies).startswith("conda:")
3384 else str(w.dependencies)
3385 ),
3386 )
3387 )
3388 ),
3389 ),
3390 torchscript=(w := src.weights.torchscript)
3391 and (TorchscriptWeightsDescr if TYPE_CHECKING else dict)(
3392 source=w.source,
3393 authors=conv_authors(w.authors),
3394 parent=w.parent,
3395 pytorch_version=w.pytorch_version or Version("1.10"),
3396 ),
3397 ),
3398 )
3401_model_conv = _ModelConv(_ModelDescr_v0_4, ModelDescr)
3404 # TODO: create better cover images for 3d data and non-image outputs
3405def generate_covers(
3406 inputs: Sequence[Tuple[InputTensorDescr, NDArray[Any]]],
3407 outputs: Sequence[Tuple[OutputTensorDescr, NDArray[Any]]],
3408) -> List[Path]:
3409 def squeeze(
3410 data: NDArray[Any], axes: Sequence[AnyAxis]
3411 ) -> Tuple[NDArray[Any], List[AnyAxis]]:
3412 """apply numpy.ndarray.squeeze while keeping track of the axis descriptions remaining"""
3413 if data.ndim != len(axes):
3414 raise ValueError(
3415 f"tensor shape {data.shape} does not match described axes"
3416 + f" {[a.id for a in axes]}"
3417 )
3419 axes = [deepcopy(a) for a, s in zip(axes, data.shape) if s != 1]
3420 return data.squeeze(), axes
3422 def normalize(
3423 data: NDArray[Any], axis: Optional[Tuple[int, ...]], eps: float = 1e-7
3424 ) -> NDArray[np.float32]:
3425 data = data.astype("float32")
3426 data -= data.min(axis=axis, keepdims=True)
3427 data /= data.max(axis=axis, keepdims=True) + eps
3428 return data
3430 def to_2d_image(data: NDArray[Any], axes: Sequence[AnyAxis]):
3431 original_shape = data.shape
3432 data, axes = squeeze(data, axes)
3434 # take slice from any batch or index axis if needed
3435 # and handle the first channel axis; take a slice from any additional channel axes
3436 slices: Tuple[slice, ...] = ()
3437 ndim = data.ndim
3438 ndim_need = 3 if any(isinstance(a, ChannelAxis) for a in axes) else 2
3439 has_c_axis = False
3440 for i, a in enumerate(axes):
3441 s = data.shape[i]
3442 assert s > 1
3443 if (
3444 isinstance(a, (BatchAxis, IndexInputAxis, IndexOutputAxis))
3445 and ndim > ndim_need
3446 ):
3447 data = data[slices + (slice(s // 2 - 1, s // 2),)]
3448 ndim -= 1
3449 elif isinstance(a, ChannelAxis):
3450 if has_c_axis:
3451 # second channel axis
3452 data = data[slices + (slice(0, 1),)]
3453 ndim -= 1
3454 else:
3455 has_c_axis = True
3456 if s == 2:
3457 # visualize two channels with cyan and magenta
3458 data = np.concatenate(
3459 [
3460 data[slices + (slice(1, 2),)],
3461 data[slices + (slice(0, 1),)],
3462 (
3463 data[slices + (slice(0, 1),)]
3464 + data[slices + (slice(1, 2),)]
3465 )
3466 / 2, # TODO: take maximum instead?
3467 ],
3468 axis=i,
3469 )
3470 elif data.shape[i] == 3:
3471 pass # visualize 3 channels as RGB
3472 else:
3473 # visualize first 3 channels as RGB
3474 data = data[slices + (slice(3),)]
3476 assert data.shape[i] == 3
3478 slices += (slice(None),)
3480 data, axes = squeeze(data, axes)
3481 assert len(axes) == ndim
3482 # take slice from z axis if needed
3483 slices = ()
3484 if ndim > ndim_need:
3485 for i, a in enumerate(axes):
3486 s = data.shape[i]
3487 if a.id == AxisId("z"):
3488 data = data[slices + (slice(s // 2 - 1, s // 2),)]
3489 data, axes = squeeze(data, axes)
3490 ndim -= 1
3491 break
3493 slices += (slice(None),)
3495 # take slice from any space or time axis
3496 slices = ()
3498 for i, a in enumerate(axes):
3499 if ndim <= ndim_need:
3500 break
3502 s = data.shape[i]
3503 assert s > 1
3504 if isinstance(
3505 a, (SpaceInputAxis, SpaceOutputAxis, TimeInputAxis, TimeOutputAxis)
3506 ):
3507 data = data[slices + (slice(s // 2 - 1, s // 2),)]
3508 ndim -= 1
3510 slices += (slice(None),)
3512 del slices
3513 data, axes = squeeze(data, axes)
3514 assert len(axes) == ndim
3516 if (has_c_axis and ndim != 3) or (not has_c_axis and ndim != 2):
3517 raise ValueError(
3518 f"Failed to construct cover image from shape {original_shape}"
3519 )
3521 if not has_c_axis:
3522 assert ndim == 2
3523 data = np.repeat(data[:, :, None], 3, axis=2)
3524 axes.append(ChannelAxis(channel_names=list(map(Identifier, "RGB"))))
3525 ndim += 1
3527 assert ndim == 3
3529 # transpose axis order such that longest axis comes first...
3530 axis_order: List[int] = list(np.argsort(list(data.shape)))
3531 axis_order.reverse()
3532 # ... and channel axis is last
3533 c = [i for i in range(3) if isinstance(axes[i], ChannelAxis)][0]
3534 axis_order.append(axis_order.pop(c))
3535 axes = [axes[ao] for ao in axis_order]
3536 data = data.transpose(axis_order)
3538 # h, w = data.shape[:2]
3539 # if h / w in (1.0 or 2.0):
3540 # pass
3541 # elif h / w < 2:
3542 # TODO: enforce 2:1 or 1:1 aspect ratio for generated cover images
3544 norm_along = (
3545 tuple(i for i, a in enumerate(axes) if a.type in ("space", "time")) or None
3546 )
3547 # normalize the data and map to 8 bit
3548 data = normalize(data, norm_along)
3549 data = (data * 255).astype("uint8")
3551 return data
3553 def create_diagonal_split_image(im0: NDArray[Any], im1: NDArray[Any]):
3554 assert im0.dtype == im1.dtype == np.uint8
3555 assert im0.shape == im1.shape
3556 assert im0.ndim == 3
3557 N, M, C = im0.shape
3558 assert C == 3
3559 out = np.ones((N, M, C), dtype="uint8")
3560 for c in range(C):
3561 outc = np.tril(im0[..., c])
3562 mask = outc == 0
3563 outc[mask] = np.triu(im1[..., c])[mask]
3564 out[..., c] = outc
3566 return out
3568 ipt_descr, ipt = inputs[0]
3569 out_descr, out = outputs[0]
3571 ipt_img = to_2d_image(ipt, ipt_descr.axes)
3572 out_img = to_2d_image(out, out_descr.axes)
3574 cover_folder = Path(mkdtemp())
3575 if ipt_img.shape == out_img.shape:
3576 covers = [cover_folder / "cover.png"]
3577 imwrite(covers[0], create_diagonal_split_image(ipt_img, out_img))
3578 else:
3579 covers = [cover_folder / "input.png", cover_folder / "output.png"]
3580 imwrite(covers[0], ipt_img)
3581 imwrite(covers[1], out_img)
3583 return covers
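# Usage sketch (assumes `model` is a loaded ModelDescr; this mirrors how the
# `_add_default_cover` validator above calls this function):
#
#     covers = generate_covers(
#         [(t, load_array(t.test_tensor.download().path)) for t in model.inputs],
#         [(t, load_array(t.test_tensor.download().path)) for t in model.outputs],
#     )
#     # -> list of PNG paths in a temporary folder (one combined image if the 2d
#     #    projections of input and output have the same shape, otherwise two images)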