Coverage for bioimageio/spec/model/v0_5.py: 45%
1233 statements
« prev ^ index » next coverage.py v7.6.10, created at 2025-02-05 13:53 +0000
« prev ^ index » next coverage.py v7.6.10, created at 2025-02-05 13:53 +0000
1from __future__ import annotations
3import collections.abc
4import re
5import string
6import warnings
7from abc import ABC
8from copy import deepcopy
9from datetime import datetime
10from itertools import chain
11from math import ceil
12from pathlib import Path, PurePosixPath
13from tempfile import mkdtemp
14from typing import (
15 TYPE_CHECKING,
16 Any,
17 ClassVar,
18 Dict,
19 FrozenSet,
20 Generic,
21 List,
22 Literal,
23 Mapping,
24 NamedTuple,
25 Optional,
26 Sequence,
27 Set,
28 Tuple,
29 Type,
30 TypeVar,
31 Union,
32 cast,
33)
35import numpy as np
36from annotated_types import Ge, Gt, Interval, MaxLen, MinLen, Predicate
37from imageio.v3 import imread, imwrite # pyright: ignore[reportUnknownVariableType]
38from loguru import logger
39from numpy.typing import NDArray
40from pydantic import (
41 Discriminator,
42 Field,
43 RootModel,
44 Tag,
45 ValidationInfo,
46 WrapSerializer,
47 field_validator,
48 model_validator,
49)
50from typing_extensions import Annotated, LiteralString, Self, assert_never, get_args
52from .._internal.common_nodes import (
53 InvalidDescr,
54 Node,
55 NodeWithExplicitlySetFields,
56)
57from .._internal.constants import DTYPE_LIMITS
58from .._internal.field_warning import issue_warning, warn
59from .._internal.io import BioimageioYamlContent as BioimageioYamlContent
60from .._internal.io import FileDescr as FileDescr
61from .._internal.io import WithSuffix, YamlValue, download
62from .._internal.io_basics import AbsoluteFilePath as AbsoluteFilePath
63from .._internal.io_basics import Sha256 as Sha256
64from .._internal.io_utils import load_array
65from .._internal.node_converter import Converter
66from .._internal.types import Datetime as Datetime
67from .._internal.types import Identifier as Identifier
68from .._internal.types import (
69 ImportantFileSource,
70 LowerCaseIdentifier,
71 LowerCaseIdentifierAnno,
72 SiUnit,
73)
74from .._internal.types import NotEmpty as NotEmpty
75from .._internal.url import HttpUrl as HttpUrl
76from .._internal.validation_context import validation_context_var
77from .._internal.validator_annotations import RestrictCharacters
78from .._internal.version_type import Version as Version
79from .._internal.warning_levels import INFO
80from ..dataset.v0_2 import DatasetDescr as DatasetDescr02
81from ..dataset.v0_2 import LinkedDataset as LinkedDataset02
82from ..dataset.v0_3 import DatasetDescr as DatasetDescr
83from ..dataset.v0_3 import DatasetId as DatasetId
84from ..dataset.v0_3 import LinkedDataset as LinkedDataset
85from ..dataset.v0_3 import Uploader as Uploader
86from ..generic.v0_3 import (
87 VALID_COVER_IMAGE_EXTENSIONS as VALID_COVER_IMAGE_EXTENSIONS,
88)
89from ..generic.v0_3 import Author as Author
90from ..generic.v0_3 import BadgeDescr as BadgeDescr
91from ..generic.v0_3 import CiteEntry as CiteEntry
92from ..generic.v0_3 import DeprecatedLicenseId as DeprecatedLicenseId
93from ..generic.v0_3 import (
94 DocumentationSource,
95 GenericModelDescrBase,
96 LinkedResourceBase,
97 _author_conv, # pyright: ignore[reportPrivateUsage]
98 _maintainer_conv, # pyright: ignore[reportPrivateUsage]
99)
100from ..generic.v0_3 import Doi as Doi
101from ..generic.v0_3 import LicenseId as LicenseId
102from ..generic.v0_3 import LinkedResource as LinkedResource
103from ..generic.v0_3 import Maintainer as Maintainer
104from ..generic.v0_3 import OrcidId as OrcidId
105from ..generic.v0_3 import RelativeFilePath as RelativeFilePath
106from ..generic.v0_3 import ResourceId as ResourceId
107from .v0_4 import Author as _Author_v0_4
108from .v0_4 import BinarizeDescr as _BinarizeDescr_v0_4
109from .v0_4 import CallableFromDepencency as CallableFromDepencency
110from .v0_4 import CallableFromDepencency as _CallableFromDepencency_v0_4
111from .v0_4 import CallableFromFile as _CallableFromFile_v0_4
112from .v0_4 import ClipDescr as _ClipDescr_v0_4
113from .v0_4 import ClipKwargs as ClipKwargs
114from .v0_4 import ImplicitOutputShape as _ImplicitOutputShape_v0_4
115from .v0_4 import InputTensorDescr as _InputTensorDescr_v0_4
116from .v0_4 import KnownRunMode as KnownRunMode
117from .v0_4 import ModelDescr as _ModelDescr_v0_4
118from .v0_4 import OutputTensorDescr as _OutputTensorDescr_v0_4
119from .v0_4 import ParameterizedInputShape as _ParameterizedInputShape_v0_4
120from .v0_4 import PostprocessingDescr as _PostprocessingDescr_v0_4
121from .v0_4 import PreprocessingDescr as _PreprocessingDescr_v0_4
122from .v0_4 import ProcessingKwargs as ProcessingKwargs
123from .v0_4 import RunMode as RunMode
124from .v0_4 import ScaleLinearDescr as _ScaleLinearDescr_v0_4
125from .v0_4 import ScaleMeanVarianceDescr as _ScaleMeanVarianceDescr_v0_4
126from .v0_4 import ScaleRangeDescr as _ScaleRangeDescr_v0_4
127from .v0_4 import SigmoidDescr as _SigmoidDescr_v0_4
128from .v0_4 import TensorName as _TensorName_v0_4
129from .v0_4 import WeightsFormat as WeightsFormat
130from .v0_4 import ZeroMeanUnitVarianceDescr as _ZeroMeanUnitVarianceDescr_v0_4
131from .v0_4 import package_weights
133SpaceUnit = Literal[
134 "attometer",
135 "angstrom",
136 "centimeter",
137 "decimeter",
138 "exameter",
139 "femtometer",
140 "foot",
141 "gigameter",
142 "hectometer",
143 "inch",
144 "kilometer",
145 "megameter",
146 "meter",
147 "micrometer",
148 "mile",
149 "millimeter",
150 "nanometer",
151 "parsec",
152 "petameter",
153 "picometer",
154 "terameter",
155 "yard",
156 "yoctometer",
157 "yottameter",
158 "zeptometer",
159 "zettameter",
160]
161"""Space unit compatible to the [OME-Zarr axes specification 0.5](https://ngff.openmicroscopy.org/0.5/#axes-md)"""
163TimeUnit = Literal[
164 "attosecond",
165 "centisecond",
166 "day",
167 "decisecond",
168 "exasecond",
169 "femtosecond",
170 "gigasecond",
171 "hectosecond",
172 "hour",
173 "kilosecond",
174 "megasecond",
175 "microsecond",
176 "millisecond",
177 "minute",
178 "nanosecond",
179 "petasecond",
180 "picosecond",
181 "second",
182 "terasecond",
183 "yoctosecond",
184 "yottasecond",
185 "zeptosecond",
186 "zettasecond",
187]
188"""Time unit compatible to the [OME-Zarr axes specification 0.5](https://ngff.openmicroscopy.org/0.5/#axes-md)"""
190AxisType = Literal["batch", "channel", "index", "time", "space"]
193class TensorId(LowerCaseIdentifier):
194 root_model: ClassVar[Type[RootModel[Any]]] = RootModel[
195 Annotated[LowerCaseIdentifierAnno, MaxLen(32)]
196 ]
199class AxisId(LowerCaseIdentifier):
200 root_model: ClassVar[Type[RootModel[Any]]] = RootModel[
201 Annotated[LowerCaseIdentifierAnno, MaxLen(16)]
202 ]
205def _is_batch(a: str) -> bool:
206 return a == BATCH_AXIS_ID
209def _is_not_batch(a: str) -> bool:
210 return not _is_batch(a)
213NonBatchAxisId = Annotated[AxisId, Predicate(_is_not_batch)]
215PostprocessingId = Literal[
216 "binarize",
217 "clip",
218 "ensure_dtype",
219 "fixed_zero_mean_unit_variance",
220 "scale_linear",
221 "scale_mean_variance",
222 "scale_range",
223 "sigmoid",
224 "zero_mean_unit_variance",
225]
226PreprocessingId = Literal[
227 "binarize",
228 "clip",
229 "ensure_dtype",
230 "scale_linear",
231 "sigmoid",
232 "zero_mean_unit_variance",
233 "scale_range",
234]
237SAME_AS_TYPE = "<same as type>"
240ParameterizedSize_N = int
243class ParameterizedSize(Node):
244 """Describes a range of valid tensor axis sizes as `size = min + n*step`."""
246 N: ClassVar[Type[int]] = ParameterizedSize_N
247 """integer to parameterize this axis"""
249 min: Annotated[int, Gt(0)]
250 step: Annotated[int, Gt(0)]
252 def validate_size(self, size: int) -> int:
253 if size < self.min:
254 raise ValueError(f"size {size} < {self.min}")
255 if (size - self.min) % self.step != 0:
256 raise ValueError(
257 f"axis of size {size} is not parameterized by `min + n*step` ="
258 + f" `{self.min} + n*{self.step}`"
259 )
261 return size
263 def get_size(self, n: ParameterizedSize_N) -> int:
264 return self.min + self.step * n
266 def get_n(self, s: int) -> ParameterizedSize_N:
267 """return smallest n parameterizing a size greater or equal than `s`"""
268 return ceil((s - self.min) / self.step)
271class DataDependentSize(Node):
272 min: Annotated[int, Gt(0)] = 1
273 max: Annotated[Optional[int], Gt(1)] = None
275 @model_validator(mode="after")
276 def _validate_max_gt_min(self):
277 if self.max is not None and self.min >= self.max:
278 raise ValueError(f"expected `min` < `max`, but got {self.min}, {self.max}")
280 return self
282 def validate_size(self, size: int) -> int:
283 if size < self.min:
284 raise ValueError(f"size {size} < {self.min}")
286 if self.max is not None and size > self.max:
287 raise ValueError(f"size {size} > {self.max}")
289 return size
292class SizeReference(Node):
293 """A tensor axis size (extent in pixels/frames) defined in relation to a reference axis.
295 `axis.size = reference.size * reference.scale / axis.scale + offset`
297 Note:
298 1. The axis and the referenced axis need to have the same unit (or no unit).
299 2. Batch axes may not be referenced.
300 3. Fractions are rounded down.
301 4. If the reference axis is `concatenable` the referencing axis is assumed to be
302 `concatenable` as well with the same block order.
304 Example:
305 An unisotropic input image of w*h=100*49 pixels depicts a phsical space of 200*196mm².
306 Let's assume that we want to express the image height h in relation to its width w
307 instead of only accepting input images of exactly 100*49 pixels
308 (for example to express a range of valid image shapes by parametrizing w, see `ParameterizedSize`).
310 >>> w = SpaceInputAxis(id=AxisId("w"), size=100, unit="millimeter", scale=2)
311 >>> h = SpaceInputAxis(
312 ... id=AxisId("h"),
313 ... size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("w"), offset=-1),
314 ... unit="millimeter",
315 ... scale=4,
316 ... )
317 >>> print(h.size.get_size(h, w))
318 49
320 ⇒ h = w * w.scale / h.scale + offset = 100 * 2mm / 4mm - 1 = 49
321 """
323 tensor_id: TensorId
324 """tensor id of the reference axis"""
326 axis_id: AxisId
327 """axis id of the reference axis"""
329 offset: int = 0
331 def get_size(
332 self,
333 axis: Union[
334 ChannelAxis,
335 IndexInputAxis,
336 IndexOutputAxis,
337 TimeInputAxis,
338 SpaceInputAxis,
339 TimeOutputAxis,
340 TimeOutputAxisWithHalo,
341 SpaceOutputAxis,
342 SpaceOutputAxisWithHalo,
343 ],
344 ref_axis: Union[
345 ChannelAxis,
346 IndexInputAxis,
347 IndexOutputAxis,
348 TimeInputAxis,
349 SpaceInputAxis,
350 TimeOutputAxis,
351 TimeOutputAxisWithHalo,
352 SpaceOutputAxis,
353 SpaceOutputAxisWithHalo,
354 ],
355 n: ParameterizedSize_N = 0,
356 ref_size: Optional[int] = None,
357 ):
358 """Compute the concrete size for a given axis and its reference axis.
360 Args:
361 axis: The axis this `SizeReference` is the size of.
362 ref_axis: The reference axis to compute the size from.
363 n: If the **ref_axis** is parameterized (of type `ParameterizedSize`)
364 and no fixed **ref_size** is given,
365 **n** is used to compute the size of the parameterized **ref_axis**.
366 ref_size: Overwrite the reference size instead of deriving it from
367 **ref_axis**
368 (**ref_axis.scale** is still used; any given **n** is ignored).
369 """
370 assert (
371 axis.size == self
372 ), "Given `axis.size` is not defined by this `SizeReference`"
374 assert (
375 ref_axis.id == self.axis_id
376 ), f"Expected `ref_axis.id` to be {self.axis_id}, but got {ref_axis.id}."
378 assert axis.unit == ref_axis.unit, (
379 "`SizeReference` requires `axis` and `ref_axis` to have the same `unit`,"
380 f" but {axis.unit}!={ref_axis.unit}"
381 )
382 if ref_size is None:
383 if isinstance(ref_axis.size, (int, float)):
384 ref_size = ref_axis.size
385 elif isinstance(ref_axis.size, ParameterizedSize):
386 ref_size = ref_axis.size.get_size(n)
387 elif isinstance(ref_axis.size, DataDependentSize):
388 raise ValueError(
389 "Reference axis referenced in `SizeReference` may not be a `DataDependentSize`."
390 )
391 elif isinstance(ref_axis.size, SizeReference):
392 raise ValueError(
393 "Reference axis referenced in `SizeReference` may not be sized by a"
394 + " `SizeReference` itself."
395 )
396 else:
397 assert_never(ref_axis.size)
399 return int(ref_size * ref_axis.scale / axis.scale + self.offset)
401 @staticmethod
402 def _get_unit(
403 axis: Union[
404 ChannelAxis,
405 IndexInputAxis,
406 IndexOutputAxis,
407 TimeInputAxis,
408 SpaceInputAxis,
409 TimeOutputAxis,
410 TimeOutputAxisWithHalo,
411 SpaceOutputAxis,
412 SpaceOutputAxisWithHalo,
413 ],
414 ):
415 return axis.unit
418class AxisBase(NodeWithExplicitlySetFields):
419 fields_to_set_explicitly: ClassVar[FrozenSet[LiteralString]] = frozenset({"type"})
421 id: AxisId
422 """An axis id unique across all axes of one tensor."""
424 description: Annotated[str, MaxLen(128)] = ""
427class WithHalo(Node):
428 halo: Annotated[int, Ge(1)]
429 """The halo should be cropped from the output tensor to avoid boundary effects.
430 It is to be cropped from both sides, i.e. `size_after_crop = size - 2 * halo`.
431 To document a halo that is already cropped by the model use `size.offset` instead."""
433 size: Annotated[
434 SizeReference,
435 Field(
436 examples=[
437 10,
438 SizeReference(
439 tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
440 ).model_dump(mode="json"),
441 ]
442 ),
443 ]
444 """reference to another axis with an optional offset (see `SizeReference`)"""
447BATCH_AXIS_ID = AxisId("batch")
450class BatchAxis(AxisBase):
451 type: Literal["batch"] = "batch"
452 id: Annotated[AxisId, Predicate(_is_batch)] = BATCH_AXIS_ID
453 size: Optional[Literal[1]] = None
454 """The batch size may be fixed to 1,
455 otherwise (the default) it may be chosen arbitrarily depending on available memory"""
457 @property
458 def scale(self):
459 return 1.0
461 @property
462 def concatenable(self):
463 return True
465 @property
466 def unit(self):
467 return None
470class ChannelAxis(AxisBase):
471 type: Literal["channel"] = "channel"
472 id: NonBatchAxisId = AxisId("channel")
473 channel_names: NotEmpty[List[Identifier]]
475 @property
476 def size(self) -> int:
477 return len(self.channel_names)
479 @property
480 def concatenable(self):
481 return False
483 @property
484 def scale(self) -> float:
485 return 1.0
487 @property
488 def unit(self):
489 return None
492class IndexAxisBase(AxisBase):
493 type: Literal["index"] = "index"
494 id: NonBatchAxisId = AxisId("index")
496 @property
497 def scale(self) -> float:
498 return 1.0
500 @property
501 def unit(self):
502 return None
505class _WithInputAxisSize(Node):
506 size: Annotated[
507 Union[Annotated[int, Gt(0)], ParameterizedSize, SizeReference],
508 Field(
509 examples=[
510 10,
511 ParameterizedSize(min=32, step=16).model_dump(mode="json"),
512 SizeReference(
513 tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
514 ).model_dump(mode="json"),
515 ]
516 ),
517 ]
518 """The size/length of this axis can be specified as
519 - fixed integer
520 - parameterized series of valid sizes (`ParameterizedSize`)
521 - reference to another axis with an optional offset (`SizeReference`)
522 """
525class IndexInputAxis(IndexAxisBase, _WithInputAxisSize):
526 concatenable: bool = False
527 """If a model has a `concatenable` input axis, it can be processed blockwise,
528 splitting a longer sample axis into blocks matching its input tensor description.
529 Output axes are concatenable if they have a `SizeReference` to a concatenable
530 input axis.
531 """
534class IndexOutputAxis(IndexAxisBase):
535 size: Annotated[
536 Union[Annotated[int, Gt(0)], SizeReference, DataDependentSize],
537 Field(
538 examples=[
539 10,
540 SizeReference(
541 tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
542 ).model_dump(mode="json"),
543 ]
544 ),
545 ]
546 """The size/length of this axis can be specified as
547 - fixed integer
548 - reference to another axis with an optional offset (`SizeReference`)
549 - data dependent size using `DataDependentSize` (size is only known after model inference)
550 """
553class TimeAxisBase(AxisBase):
554 type: Literal["time"] = "time"
555 id: NonBatchAxisId = AxisId("time")
556 unit: Optional[TimeUnit] = None
557 scale: Annotated[float, Gt(0)] = 1.0
560class TimeInputAxis(TimeAxisBase, _WithInputAxisSize):
561 concatenable: bool = False
562 """If a model has a `concatenable` input axis, it can be processed blockwise,
563 splitting a longer sample axis into blocks matching its input tensor description.
564 Output axes are concatenable if they have a `SizeReference` to a concatenable
565 input axis.
566 """
569class SpaceAxisBase(AxisBase):
570 type: Literal["space"] = "space"
571 id: Annotated[NonBatchAxisId, Field(examples=["x", "y", "z"])] = AxisId("x")
572 unit: Optional[SpaceUnit] = None
573 scale: Annotated[float, Gt(0)] = 1.0
576class SpaceInputAxis(SpaceAxisBase, _WithInputAxisSize):
577 concatenable: bool = False
578 """If a model has a `concatenable` input axis, it can be processed blockwise,
579 splitting a longer sample axis into blocks matching its input tensor description.
580 Output axes are concatenable if they have a `SizeReference` to a concatenable
581 input axis.
582 """
585_InputAxisUnion = Union[
586 BatchAxis, ChannelAxis, IndexInputAxis, TimeInputAxis, SpaceInputAxis
587]
588InputAxis = Annotated[_InputAxisUnion, Discriminator("type")]
591class _WithOutputAxisSize(Node):
592 size: Annotated[
593 Union[Annotated[int, Gt(0)], SizeReference],
594 Field(
595 examples=[
596 10,
597 SizeReference(
598 tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
599 ).model_dump(mode="json"),
600 ]
601 ),
602 ]
603 """The size/length of this axis can be specified as
604 - fixed integer
605 - reference to another axis with an optional offset (see `SizeReference`)
606 """
609class TimeOutputAxis(TimeAxisBase, _WithOutputAxisSize):
610 pass
613class TimeOutputAxisWithHalo(TimeAxisBase, WithHalo):
614 pass
617def _get_halo_axis_discriminator_value(v: Any) -> Literal["with_halo", "wo_halo"]:
618 if isinstance(v, dict):
619 return "with_halo" if "halo" in v else "wo_halo"
620 else:
621 return "with_halo" if hasattr(v, "halo") else "wo_halo"
624_TimeOutputAxisUnion = Annotated[
625 Union[
626 Annotated[TimeOutputAxis, Tag("wo_halo")],
627 Annotated[TimeOutputAxisWithHalo, Tag("with_halo")],
628 ],
629 Discriminator(_get_halo_axis_discriminator_value),
630]
633class SpaceOutputAxis(SpaceAxisBase, _WithOutputAxisSize):
634 pass
637class SpaceOutputAxisWithHalo(SpaceAxisBase, WithHalo):
638 pass
641_SpaceOutputAxisUnion = Annotated[
642 Union[
643 Annotated[SpaceOutputAxis, Tag("wo_halo")],
644 Annotated[SpaceOutputAxisWithHalo, Tag("with_halo")],
645 ],
646 Discriminator(_get_halo_axis_discriminator_value),
647]
650_OutputAxisUnion = Union[
651 BatchAxis, ChannelAxis, IndexOutputAxis, _TimeOutputAxisUnion, _SpaceOutputAxisUnion
652]
653OutputAxis = Annotated[_OutputAxisUnion, Discriminator("type")]
655AnyAxis = Union[InputAxis, OutputAxis]
657TVs = Union[
658 NotEmpty[List[int]],
659 NotEmpty[List[float]],
660 NotEmpty[List[bool]],
661 NotEmpty[List[str]],
662]
665NominalOrOrdinalDType = Literal[
666 "float32",
667 "float64",
668 "uint8",
669 "int8",
670 "uint16",
671 "int16",
672 "uint32",
673 "int32",
674 "uint64",
675 "int64",
676 "bool",
677]
680class NominalOrOrdinalDataDescr(Node):
681 values: TVs
682 """A fixed set of nominal or an ascending sequence of ordinal values.
683 In this case `data_type` is required to be an unsigend integer type, e.g. 'uint8'.
684 String `values` are interpreted as labels for tensor values 0, ..., N.
685 Note: as YAML 1.2 does not natively support a "set" datatype,
686 nominal values should be given as a sequence (aka list/array) as well.
687 """
689 type: Annotated[
690 NominalOrOrdinalDType,
691 Field(
692 examples=[
693 "float32",
694 "uint8",
695 "uint16",
696 "int64",
697 "bool",
698 ],
699 ),
700 ] = "uint8"
702 @model_validator(mode="after")
703 def _validate_values_match_type(
704 self,
705 ) -> Self:
706 incompatible: List[Any] = []
707 for v in self.values:
708 if self.type == "bool":
709 if not isinstance(v, bool):
710 incompatible.append(v)
711 elif self.type in DTYPE_LIMITS:
712 if (
713 isinstance(v, (int, float))
714 and (
715 v < DTYPE_LIMITS[self.type].min
716 or v > DTYPE_LIMITS[self.type].max
717 )
718 or (isinstance(v, str) and "uint" not in self.type)
719 or (isinstance(v, float) and "int" in self.type)
720 ):
721 incompatible.append(v)
722 else:
723 incompatible.append(v)
725 if len(incompatible) == 5:
726 incompatible.append("...")
727 break
729 if incompatible:
730 raise ValueError(
731 f"data type '{self.type}' incompatible with values {incompatible}"
732 )
734 return self
736 unit: Optional[Union[Literal["arbitrary unit"], SiUnit]] = None
738 @property
739 def range(self):
740 if isinstance(self.values[0], str):
741 return 0, len(self.values) - 1
742 else:
743 return min(self.values), max(self.values)
746IntervalOrRatioDType = Literal[
747 "float32",
748 "float64",
749 "uint8",
750 "int8",
751 "uint16",
752 "int16",
753 "uint32",
754 "int32",
755 "uint64",
756 "int64",
757]
760class IntervalOrRatioDataDescr(Node):
761 type: Annotated[ # todo: rename to dtype
762 IntervalOrRatioDType,
763 Field(
764 examples=["float32", "float64", "uint8", "uint16"],
765 ),
766 ] = "float32"
767 range: Tuple[Optional[float], Optional[float]] = (
768 None,
769 None,
770 )
771 """Tuple `(minimum, maximum)` specifying the allowed range of the data in this tensor.
772 `None` corresponds to min/max of what can be expressed by `data_type`."""
773 unit: Union[Literal["arbitrary unit"], SiUnit] = "arbitrary unit"
774 scale: float = 1.0
775 """Scale for data on an interval (or ratio) scale."""
776 offset: Optional[float] = None
777 """Offset for data on a ratio scale."""
780TensorDataDescr = Union[NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr]
783class ProcessingDescrBase(NodeWithExplicitlySetFields, ABC):
784 """processing base class"""
786 # id: Literal[PreprocessingId, PostprocessingId] # make abstract field
787 fields_to_set_explicitly: ClassVar[FrozenSet[LiteralString]] = frozenset({"id"})
790class BinarizeKwargs(ProcessingKwargs):
791 """key word arguments for `BinarizeDescr`"""
793 threshold: float
794 """The fixed threshold"""
797class BinarizeAlongAxisKwargs(ProcessingKwargs):
798 """key word arguments for `BinarizeDescr`"""
800 threshold: NotEmpty[List[float]]
801 """The fixed threshold values along `axis`"""
803 axis: Annotated[NonBatchAxisId, Field(examples=["channel"])]
804 """The `threshold` axis"""
807class BinarizeDescr(ProcessingDescrBase):
808 """Binarize the tensor with a fixed threshold.
810 Values above `BinarizeKwargs.threshold`/`BinarizeAlongAxisKwargs.threshold`
811 will be set to one, values below the threshold to zero.
813 Examples:
814 - in YAML
815 ```yaml
816 postprocessing:
817 - id: binarize
818 kwargs:
819 axis: 'channel'
820 threshold: [0.25, 0.5, 0.75]
821 ```
822 - in Python:
823 >>> postprocessing = [BinarizeDescr(
824 ... kwargs=BinarizeAlongAxisKwargs(
825 ... axis=AxisId('channel'),
826 ... threshold=[0.25, 0.5, 0.75],
827 ... )
828 ... )]
829 """
831 id: Literal["binarize"] = "binarize"
832 kwargs: Union[BinarizeKwargs, BinarizeAlongAxisKwargs]
835class ClipDescr(ProcessingDescrBase):
836 """Set tensor values below min to min and above max to max.
838 See `ScaleRangeDescr` for examples.
839 """
841 id: Literal["clip"] = "clip"
842 kwargs: ClipKwargs
845class EnsureDtypeKwargs(ProcessingKwargs):
846 """key word arguments for `EnsureDtypeDescr`"""
848 dtype: Literal[
849 "float32",
850 "float64",
851 "uint8",
852 "int8",
853 "uint16",
854 "int16",
855 "uint32",
856 "int32",
857 "uint64",
858 "int64",
859 "bool",
860 ]
863class EnsureDtypeDescr(ProcessingDescrBase):
864 """Cast the tensor data type to `EnsureDtypeKwargs.dtype` (if not matching).
866 This can for example be used to ensure the inner neural network model gets a
867 different input tensor data type than the fully described bioimage.io model does.
869 Examples:
870 The described bioimage.io model (incl. preprocessing) accepts any
871 float32-compatible tensor, normalizes it with percentiles and clipping and then
872 casts it to uint8, which is what the neural network in this example expects.
873 - in YAML
874 ```yaml
875 inputs:
876 - data:
877 type: float32 # described bioimage.io model is compatible with any float32 input tensor
878 preprocessing:
879 - id: scale_range
880 kwargs:
881 axes: ['y', 'x']
882 max_percentile: 99.8
883 min_percentile: 5.0
884 - id: clip
885 kwargs:
886 min: 0.0
887 max: 1.0
888 - id: ensure_dtype
889 kwargs:
890 dtype: uint8
891 ```
892 - in Python:
893 >>> preprocessing = [
894 ... ScaleRangeDescr(
895 ... kwargs=ScaleRangeKwargs(
896 ... axes= (AxisId('y'), AxisId('x')),
897 ... max_percentile= 99.8,
898 ... min_percentile= 5.0,
899 ... )
900 ... ),
901 ... ClipDescr(kwargs=ClipKwargs(min=0.0, max=1.0)),
902 ... EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="uint8")),
903 ... ]
904 """
906 id: Literal["ensure_dtype"] = "ensure_dtype"
907 kwargs: EnsureDtypeKwargs
910class ScaleLinearKwargs(ProcessingKwargs):
911 """Key word arguments for `ScaleLinearDescr`"""
913 gain: float = 1.0
914 """multiplicative factor"""
916 offset: float = 0.0
917 """additive term"""
919 @model_validator(mode="after")
920 def _validate(self) -> Self:
921 if self.gain == 1.0 and self.offset == 0.0:
922 raise ValueError(
923 "Redundant linear scaling not allowd. Set `gain` != 1.0 and/or `offset`"
924 + " != 0.0."
925 )
927 return self
930class ScaleLinearAlongAxisKwargs(ProcessingKwargs):
931 """Key word arguments for `ScaleLinearDescr`"""
933 axis: Annotated[NonBatchAxisId, Field(examples=["channel"])]
934 """The axis of of gains/offsets values."""
936 gain: Union[float, NotEmpty[List[float]]] = 1.0
937 """multiplicative factor"""
939 offset: Union[float, NotEmpty[List[float]]] = 0.0
940 """additive term"""
942 @model_validator(mode="after")
943 def _validate(self) -> Self:
945 if isinstance(self.gain, list):
946 if isinstance(self.offset, list):
947 if len(self.gain) != len(self.offset):
948 raise ValueError(
949 f"Size of `gain` ({len(self.gain)}) and `offset` ({len(self.offset)}) must match."
950 )
951 else:
952 self.offset = [float(self.offset)] * len(self.gain)
953 elif isinstance(self.offset, list):
954 self.gain = [float(self.gain)] * len(self.offset)
955 else:
956 raise ValueError(
957 "Do not specify an `axis` for scalar gain and offset values."
958 )
960 if all(g == 1.0 for g in self.gain) and all(off == 0.0 for off in self.offset):
961 raise ValueError(
962 "Redundant linear scaling not allowd. Set `gain` != 1.0 and/or `offset`"
963 + " != 0.0."
964 )
966 return self
969class ScaleLinearDescr(ProcessingDescrBase):
970 """Fixed linear scaling.
972 Examples:
973 1. Scale with scalar gain and offset
974 - in YAML
975 ```yaml
976 preprocessing:
977 - id: scale_linear
978 kwargs:
979 gain: 2.0
980 offset: 3.0
981 ```
982 - in Python:
983 >>> preprocessing = [
984 ... ScaleLinearDescr(kwargs=ScaleLinearKwargs(gain= 2.0, offset=3.0))
985 ... ]
987 2. Independent scaling along an axis
988 - in YAML
989 ```yaml
990 preprocessing:
991 - id: scale_linear
992 kwargs:
993 axis: 'channel'
994 gain: [1.0, 2.0, 3.0]
995 ```
996 - in Python:
997 >>> preprocessing = [
998 ... ScaleLinearDescr(
999 ... kwargs=ScaleLinearAlongAxisKwargs(
1000 ... axis=AxisId("channel"),
1001 ... gain=[1.0, 2.0, 3.0],
1002 ... )
1003 ... )
1004 ... ]
1006 """
1008 id: Literal["scale_linear"] = "scale_linear"
1009 kwargs: Union[ScaleLinearKwargs, ScaleLinearAlongAxisKwargs]
1012class SigmoidDescr(ProcessingDescrBase):
1013 """The logistic sigmoid funciton, a.k.a. expit function.
1015 Examples:
1016 - in YAML
1017 ```yaml
1018 postprocessing:
1019 - id: sigmoid
1020 ```
1021 - in Python:
1022 >>> postprocessing = [SigmoidDescr()]
1023 """
1025 id: Literal["sigmoid"] = "sigmoid"
1027 @property
1028 def kwargs(self) -> ProcessingKwargs:
1029 """empty kwargs"""
1030 return ProcessingKwargs()
1033class FixedZeroMeanUnitVarianceKwargs(ProcessingKwargs):
1034 """key word arguments for `FixedZeroMeanUnitVarianceDescr`"""
1036 mean: float
1037 """The mean value to normalize with."""
1039 std: Annotated[float, Ge(1e-6)]
1040 """The standard deviation value to normalize with."""
1043class FixedZeroMeanUnitVarianceAlongAxisKwargs(ProcessingKwargs):
1044 """key word arguments for `FixedZeroMeanUnitVarianceDescr`"""
1046 mean: NotEmpty[List[float]]
1047 """The mean value(s) to normalize with."""
1049 std: NotEmpty[List[Annotated[float, Ge(1e-6)]]]
1050 """The standard deviation value(s) to normalize with.
1051 Size must match `mean` values."""
1053 axis: Annotated[NonBatchAxisId, Field(examples=["channel", "index"])]
1054 """The axis of the mean/std values to normalize each entry along that dimension
1055 separately."""
1057 @model_validator(mode="after")
1058 def _mean_and_std_match(self) -> Self:
1059 if len(self.mean) != len(self.std):
1060 raise ValueError(
1061 f"Size of `mean` ({len(self.mean)}) and `std` ({len(self.std)})"
1062 + " must match."
1063 )
1065 return self
1068class FixedZeroMeanUnitVarianceDescr(ProcessingDescrBase):
1069 """Subtract a given mean and divide by the standard deviation.
1071 Normalize with fixed, precomputed values for
1072 `FixedZeroMeanUnitVarianceKwargs.mean` and `FixedZeroMeanUnitVarianceKwargs.std`
1073 Use `FixedZeroMeanUnitVarianceAlongAxisKwargs` for independent scaling along given
1074 axes.
1076 Examples:
1077 1. scalar value for whole tensor
1078 - in YAML
1079 ```yaml
1080 preprocessing:
1081 - id: fixed_zero_mean_unit_variance
1082 kwargs:
1083 mean: 103.5
1084 std: 13.7
1085 ```
1086 - in Python
1087 >>> preprocessing = [FixedZeroMeanUnitVarianceDescr(
1088 ... kwargs=FixedZeroMeanUnitVarianceKwargs(mean=103.5, std=13.7)
1089 ... )]
1091 2. independently along an axis
1092 - in YAML
1093 ```yaml
1094 preprocessing:
1095 - id: fixed_zero_mean_unit_variance
1096 kwargs:
1097 axis: channel
1098 mean: [101.5, 102.5, 103.5]
1099 std: [11.7, 12.7, 13.7]
1100 ```
1101 - in Python
1102 >>> preprocessing = [FixedZeroMeanUnitVarianceDescr(
1103 ... kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs(
1104 ... axis=AxisId("channel"),
1105 ... mean=[101.5, 102.5, 103.5],
1106 ... std=[11.7, 12.7, 13.7],
1107 ... )
1108 ... )]
1109 """
1111 id: Literal["fixed_zero_mean_unit_variance"] = "fixed_zero_mean_unit_variance"
1112 kwargs: Union[
1113 FixedZeroMeanUnitVarianceKwargs, FixedZeroMeanUnitVarianceAlongAxisKwargs
1114 ]
1117class ZeroMeanUnitVarianceKwargs(ProcessingKwargs):
1118 """key word arguments for `ZeroMeanUnitVarianceDescr`"""
1120 axes: Annotated[
1121 Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
1122 ] = None
1123 """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std.
1124 For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
1125 resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
1126 To normalize each sample independently leave out the 'batch' axis.
1127 Default: Scale all axes jointly."""
1129 eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
1130 """epsilon for numeric stability: `out = (tensor - mean) / (std + eps)`."""
1133class ZeroMeanUnitVarianceDescr(ProcessingDescrBase):
1134 """Subtract mean and divide by variance.
1136 Examples:
1137 Subtract tensor mean and variance
1138 - in YAML
1139 ```yaml
1140 preprocessing:
1141 - id: zero_mean_unit_variance
1142 ```
1143 - in Python
1144 >>> preprocessing = [ZeroMeanUnitVarianceDescr()]
1145 """
1147 id: Literal["zero_mean_unit_variance"] = "zero_mean_unit_variance"
1148 kwargs: ZeroMeanUnitVarianceKwargs = Field(
1149 default_factory=ZeroMeanUnitVarianceKwargs
1150 )
1153class ScaleRangeKwargs(ProcessingKwargs):
1154 """key word arguments for `ScaleRangeDescr`
1156 For `min_percentile`=0.0 (the default) and `max_percentile`=100 (the default)
1157 this processing step normalizes data to the [0, 1] intervall.
1158 For other percentiles the normalized values will partially be outside the [0, 1]
1159 intervall. Use `ScaleRange` followed by `ClipDescr` if you want to limit the
1160 normalized values to a range.
1161 """
1163 axes: Annotated[
1164 Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
1165 ] = None
1166 """The subset of axes to normalize jointly, i.e. axes to reduce to compute the min/max percentile value.
1167 For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
1168 resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
1169 To normalize samples independently, leave out the "batch" axis.
1170 Default: Scale all axes jointly."""
1172 min_percentile: Annotated[float, Interval(ge=0, lt=100)] = 0.0
1173 """The lower percentile used to determine the value to align with zero."""
1175 max_percentile: Annotated[float, Interval(gt=1, le=100)] = 100.0
1176 """The upper percentile used to determine the value to align with one.
1177 Has to be bigger than `min_percentile`.
1178 The range is 1 to 100 instead of 0 to 100 to avoid mistakenly
1179 accepting percentiles specified in the range 0.0 to 1.0."""
1181 eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
1182 """Epsilon for numeric stability.
1183 `out = (tensor - v_lower) / (v_upper - v_lower + eps)`;
1184 with `v_lower,v_upper` values at the respective percentiles."""
1186 reference_tensor: Optional[TensorId] = None
1187 """Tensor ID to compute the percentiles from. Default: The tensor itself.
1188 For any tensor in `inputs` only input tensor references are allowed."""
1190 @field_validator("max_percentile", mode="after")
1191 @classmethod
1192 def min_smaller_max(cls, value: float, info: ValidationInfo) -> float:
1193 if (min_p := info.data["min_percentile"]) >= value:
1194 raise ValueError(f"min_percentile {min_p} >= max_percentile {value}")
1196 return value
1199class ScaleRangeDescr(ProcessingDescrBase):
1200 """Scale with percentiles.
1202 Examples:
1203 1. Scale linearly to map 5th percentile to 0 and 99.8th percentile to 1.0
1204 - in YAML
1205 ```yaml
1206 preprocessing:
1207 - id: scale_range
1208 kwargs:
1209 axes: ['y', 'x']
1210 max_percentile: 99.8
1211 min_percentile: 5.0
1212 ```
1213 - in Python
1214 >>> preprocessing = [
1215 ... ScaleRangeDescr(
1216 ... kwargs=ScaleRangeKwargs(
1217 ... axes= (AxisId('y'), AxisId('x')),
1218 ... max_percentile= 99.8,
1219 ... min_percentile= 5.0,
1220 ... )
1221 ... ),
1222 ... ClipDescr(
1223 ... kwargs=ClipKwargs(
1224 ... min=0.0,
1225 ... max=1.0,
1226 ... )
1227 ... ),
1228 ... ]
1230 2. Combine the above scaling with additional clipping to clip values outside the range given by the percentiles.
1231 - in YAML
1232 ```yaml
1233 preprocessing:
1234 - id: scale_range
1235 kwargs:
1236 axes: ['y', 'x']
1237 max_percentile: 99.8
1238 min_percentile: 5.0
1239 - id: scale_range
1240 - id: clip
1241 kwargs:
1242 min: 0.0
1243 max: 1.0
1244 ```
1245 - in Python
1246 >>> preprocessing = [ScaleRangeDescr(
1247 ... kwargs=ScaleRangeKwargs(
1248 ... axes= (AxisId('y'), AxisId('x')),
1249 ... max_percentile= 99.8,
1250 ... min_percentile= 5.0,
1251 ... )
1252 ... )]
1254 """
1256 id: Literal["scale_range"] = "scale_range"
1257 kwargs: ScaleRangeKwargs
1260class ScaleMeanVarianceKwargs(ProcessingKwargs):
1261 """key word arguments for `ScaleMeanVarianceKwargs`"""
1263 reference_tensor: TensorId
1264 """Name of tensor to match."""
1266 axes: Annotated[
1267 Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
1268 ] = None
1269 """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std.
1270 For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
1271 resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
1272 To normalize samples independently, leave out the 'batch' axis.
1273 Default: Scale all axes jointly."""
1275 eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
1276 """Epsilon for numeric stability:
1277 `out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.`"""
1280class ScaleMeanVarianceDescr(ProcessingDescrBase):
1281 """Scale a tensor's data distribution to match another tensor's mean/std.
1282 `out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.`
1283 """
1285 id: Literal["scale_mean_variance"] = "scale_mean_variance"
1286 kwargs: ScaleMeanVarianceKwargs
1289PreprocessingDescr = Annotated[
1290 Union[
1291 BinarizeDescr,
1292 ClipDescr,
1293 EnsureDtypeDescr,
1294 ScaleLinearDescr,
1295 SigmoidDescr,
1296 FixedZeroMeanUnitVarianceDescr,
1297 ZeroMeanUnitVarianceDescr,
1298 ScaleRangeDescr,
1299 ],
1300 Discriminator("id"),
1301]
1302PostprocessingDescr = Annotated[
1303 Union[
1304 BinarizeDescr,
1305 ClipDescr,
1306 EnsureDtypeDescr,
1307 ScaleLinearDescr,
1308 SigmoidDescr,
1309 FixedZeroMeanUnitVarianceDescr,
1310 ZeroMeanUnitVarianceDescr,
1311 ScaleRangeDescr,
1312 ScaleMeanVarianceDescr,
1313 ],
1314 Discriminator("id"),
1315]
1317IO_AxisT = TypeVar("IO_AxisT", InputAxis, OutputAxis)
1320class TensorDescrBase(Node, Generic[IO_AxisT]):
1321 id: TensorId
1322 """Tensor id. No duplicates are allowed."""
1324 description: Annotated[str, MaxLen(128)] = ""
1325 """free text description"""
1327 axes: NotEmpty[Sequence[IO_AxisT]]
1328 """tensor axes"""
1330 @property
1331 def shape(self):
1332 return tuple(a.size for a in self.axes)
1334 @field_validator("axes", mode="after", check_fields=False)
1335 @classmethod
1336 def _validate_axes(cls, axes: Sequence[AnyAxis]) -> Sequence[AnyAxis]:
1337 batch_axes = [a for a in axes if a.type == "batch"]
1338 if len(batch_axes) > 1:
1339 raise ValueError(
1340 f"Only one batch axis (per tensor) allowed, but got {batch_axes}"
1341 )
1343 seen_ids: Set[AxisId] = set()
1344 duplicate_axes_ids: Set[AxisId] = set()
1345 for a in axes:
1346 (duplicate_axes_ids if a.id in seen_ids else seen_ids).add(a.id)
1348 if duplicate_axes_ids:
1349 raise ValueError(f"Duplicate axis ids: {duplicate_axes_ids}")
1351 return axes
1353 test_tensor: FileDescr
1354 """An example tensor to use for testing.
1355 Using the model with the test input tensors is expected to yield the test output tensors.
1356 Each test tensor has be a an ndarray in the
1357 [numpy.lib file format](https://numpy.org/doc/stable/reference/generated/numpy.lib.format.html#module-numpy.lib.format).
1358 The file extension must be '.npy'."""
1360 sample_tensor: Optional[FileDescr] = None
1361 """A sample tensor to illustrate a possible input/output for the model,
1362 The sample image primarily serves to inform a human user about an example use case
1363 and is typically stored as .hdf5, .png or .tiff.
1364 It has to be readable by the [imageio library](https://imageio.readthedocs.io/en/stable/formats/index.html#supported-formats)
1365 (numpy's `.npy` format is not supported).
1366 The image dimensionality has to match the number of axes specified in this tensor description.
1367 """
1369 @model_validator(mode="after")
1370 def _validate_sample_tensor(self) -> Self:
1371 if (
1372 self.sample_tensor is None
1373 or not validation_context_var.get().perform_io_checks
1374 ):
1375 return self
1377 local = download(self.sample_tensor.source, sha256=self.sample_tensor.sha256)
1378 tensor: NDArray[Any] = imread(
1379 local.path.read_bytes(),
1380 extension=PurePosixPath(local.original_file_name).suffix,
1381 )
1382 n_dims = len(tensor.squeeze().shape)
1383 n_dims_min = n_dims_max = len(self.axes)
1385 for a in self.axes:
1386 if isinstance(a, BatchAxis):
1387 n_dims_min -= 1
1388 elif isinstance(a.size, int):
1389 if a.size == 1:
1390 n_dims_min -= 1
1391 elif isinstance(a.size, (ParameterizedSize, DataDependentSize)):
1392 if a.size.min == 1:
1393 n_dims_min -= 1
1394 elif isinstance(a.size, SizeReference):
1395 if a.size.offset < 2:
1396 # size reference may result in singleton axis
1397 n_dims_min -= 1
1398 else:
1399 assert_never(a.size)
1401 n_dims_min = max(0, n_dims_min)
1402 if n_dims < n_dims_min or n_dims > n_dims_max:
1403 raise ValueError(
1404 f"Expected sample tensor to have {n_dims_min} to"
1405 + f" {n_dims_max} dimensions, but found {n_dims} (shape: {tensor.shape})."
1406 )
1408 return self
1410 data: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]] = (
1411 IntervalOrRatioDataDescr()
1412 )
1413 """Description of the tensor's data values, optionally per channel.
1414 If specified per channel, the data `type` needs to match across channels."""
1416 @property
1417 def dtype(
1418 self,
1419 ) -> Literal[
1420 "float32",
1421 "float64",
1422 "uint8",
1423 "int8",
1424 "uint16",
1425 "int16",
1426 "uint32",
1427 "int32",
1428 "uint64",
1429 "int64",
1430 "bool",
1431 ]:
1432 """dtype as specified under `data.type` or `data[i].type`"""
1433 if isinstance(self.data, collections.abc.Sequence):
1434 return self.data[0].type
1435 else:
1436 return self.data.type
1438 @field_validator("data", mode="after")
1439 @classmethod
1440 def _check_data_type_across_channels(
1441 cls, value: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]]
1442 ) -> Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]]:
1443 if not isinstance(value, list):
1444 return value
1446 dtypes = {t.type for t in value}
1447 if len(dtypes) > 1:
1448 raise ValueError(
1449 "Tensor data descriptions per channel need to agree in their data"
1450 + f" `type`, but found {dtypes}."
1451 )
1453 return value
1455 @model_validator(mode="after")
1456 def _check_data_matches_channelaxis(self) -> Self:
1457 if not isinstance(self.data, (list, tuple)):
1458 return self
1460 for a in self.axes:
1461 if isinstance(a, ChannelAxis):
1462 size = a.size
1463 assert isinstance(size, int)
1464 break
1465 else:
1466 return self
1468 if len(self.data) != size:
1469 raise ValueError(
1470 f"Got tensor data descriptions for {len(self.data)} channels, but"
1471 + f" '{a.id}' axis has size {size}."
1472 )
1474 return self
1476 def get_axis_sizes_for_array(self, array: NDArray[Any]) -> Dict[AxisId, int]:
1477 if len(array.shape) != len(self.axes):
1478 raise ValueError(
1479 f"Dimension mismatch: array shape {array.shape} (#{len(array.shape)})"
1480 + f" incompatible with {len(self.axes)} axes."
1481 )
1482 return {a.id: array.shape[i] for i, a in enumerate(self.axes)}
1485class InputTensorDescr(TensorDescrBase[InputAxis]):
1486 id: TensorId = TensorId("input")
1487 """Input tensor id.
1488 No duplicates are allowed across all inputs and outputs."""
1490 optional: bool = False
1491 """indicates that this tensor may be `None`"""
1493 preprocessing: List[PreprocessingDescr] = Field(default_factory=list)
1494 """Description of how this input should be preprocessed.
1496 notes:
1497 - If preprocessing does not start with an 'ensure_dtype' entry, it is added
1498 to ensure an input tensor's data type matches the input tensor's data description.
1499 - If preprocessing does not end with an 'ensure_dtype' or 'binarize' entry, an
1500 'ensure_dtype' step is added to ensure preprocessing steps are not unintentionally
1501 changing the data type.
1502 """
1504 @model_validator(mode="after")
1505 def _validate_preprocessing_kwargs(self) -> Self:
1506 axes_ids = [a.id for a in self.axes]
1507 for p in self.preprocessing:
1508 kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes")
1509 if kwargs_axes is None:
1510 continue
1512 if not isinstance(kwargs_axes, collections.abc.Sequence):
1513 raise ValueError(
1514 f"Expected `preprocessing.i.kwargs.axes` to be a sequence, but got {type(kwargs_axes)}"
1515 )
1517 if any(a not in axes_ids for a in kwargs_axes):
1518 raise ValueError(
1519 "`preprocessing.i.kwargs.axes` needs to be subset of axes ids"
1520 )
1522 if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)):
1523 dtype = self.data.type
1524 else:
1525 dtype = self.data[0].type
1527 # ensure `preprocessing` begins with `EnsureDtypeDescr`
1528 if not self.preprocessing or not isinstance(
1529 self.preprocessing[0], EnsureDtypeDescr
1530 ):
1531 self.preprocessing.insert(
1532 0, EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype))
1533 )
1535 # ensure `preprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr`
1536 if not isinstance(self.preprocessing[-1], (EnsureDtypeDescr, BinarizeDescr)):
1537 self.preprocessing.append(
1538 EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype))
1539 )
1541 return self
1544def convert_axes(
1545 axes: str,
1546 *,
1547 shape: Union[
1548 Sequence[int], _ParameterizedInputShape_v0_4, _ImplicitOutputShape_v0_4
1549 ],
1550 tensor_type: Literal["input", "output"],
1551 halo: Optional[Sequence[int]],
1552 size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]],
1553):
1554 ret: List[AnyAxis] = []
1555 for i, a in enumerate(axes):
1556 axis_type = _AXIS_TYPE_MAP.get(a, a)
1557 if axis_type == "batch":
1558 ret.append(BatchAxis())
1559 continue
1561 scale = 1.0
1562 if isinstance(shape, _ParameterizedInputShape_v0_4):
1563 if shape.step[i] == 0:
1564 size = shape.min[i]
1565 else:
1566 size = ParameterizedSize(min=shape.min[i], step=shape.step[i])
1567 elif isinstance(shape, _ImplicitOutputShape_v0_4):
1568 ref_t = str(shape.reference_tensor)
1569 if ref_t.count(".") == 1:
1570 t_id, orig_a_id = ref_t.split(".")
1571 else:
1572 t_id = ref_t
1573 orig_a_id = a
1575 a_id = _AXIS_ID_MAP.get(orig_a_id, a)
1576 if not (orig_scale := shape.scale[i]):
1577 # old way to insert a new axis dimension
1578 size = int(2 * shape.offset[i])
1579 else:
1580 scale = 1 / orig_scale
1581 if axis_type in ("channel", "index"):
1582 # these axes no longer have a scale
1583 offset_from_scale = orig_scale * size_refs.get(
1584 _TensorName_v0_4(t_id), {}
1585 ).get(orig_a_id, 0)
1586 else:
1587 offset_from_scale = 0
1588 size = SizeReference(
1589 tensor_id=TensorId(t_id),
1590 axis_id=AxisId(a_id),
1591 offset=int(offset_from_scale + 2 * shape.offset[i]),
1592 )
1593 else:
1594 size = shape[i]
1596 if axis_type == "time":
1597 if tensor_type == "input":
1598 ret.append(TimeInputAxis(size=size, scale=scale))
1599 else:
1600 assert not isinstance(size, ParameterizedSize)
1601 if halo is None:
1602 ret.append(TimeOutputAxis(size=size, scale=scale))
1603 else:
1604 assert not isinstance(size, int)
1605 ret.append(
1606 TimeOutputAxisWithHalo(size=size, scale=scale, halo=halo[i])
1607 )
1609 elif axis_type == "index":
1610 if tensor_type == "input":
1611 ret.append(IndexInputAxis(size=size))
1612 else:
1613 if isinstance(size, ParameterizedSize):
1614 size = DataDependentSize(min=size.min)
1616 ret.append(IndexOutputAxis(size=size))
1617 elif axis_type == "channel":
1618 assert not isinstance(size, ParameterizedSize)
1619 if isinstance(size, SizeReference):
1620 warnings.warn(
1621 "Conversion of channel size from an implicit output shape may be"
1622 + " wrong"
1623 )
1624 ret.append(
1625 ChannelAxis(
1626 channel_names=[
1627 Identifier(f"channel{i}") for i in range(size.offset)
1628 ]
1629 )
1630 )
1631 else:
1632 ret.append(
1633 ChannelAxis(
1634 channel_names=[Identifier(f"channel{i}") for i in range(size)]
1635 )
1636 )
1637 elif axis_type == "space":
1638 if tensor_type == "input":
1639 ret.append(SpaceInputAxis(id=AxisId(a), size=size, scale=scale))
1640 else:
1641 assert not isinstance(size, ParameterizedSize)
1642 if halo is None or halo[i] == 0:
1643 ret.append(SpaceOutputAxis(id=AxisId(a), size=size, scale=scale))
1644 elif isinstance(size, int):
1645 raise NotImplementedError(
1646 f"output axis with halo and fixed size (here {size}) not allowed"
1647 )
1648 else:
1649 ret.append(
1650 SpaceOutputAxisWithHalo(
1651 id=AxisId(a), size=size, scale=scale, halo=halo[i]
1652 )
1653 )
1655 return ret
1658_AXIS_TYPE_MAP = {
1659 "b": "batch",
1660 "t": "time",
1661 "i": "index",
1662 "c": "channel",
1663 "x": "space",
1664 "y": "space",
1665 "z": "space",
1666}
1668_AXIS_ID_MAP = {
1669 "b": "batch",
1670 "t": "time",
1671 "i": "index",
1672 "c": "channel",
1673}
1676def _axes_letters_to_ids(
1677 axes: Optional[str],
1678) -> Optional[List[AxisId]]:
1679 if axes is None:
1680 return None
1681 return [AxisId(_AXIS_ID_MAP.get(a, a)) for a in map(str, axes)]
1684def _get_complement_v04_axis(
1685 tensor_axes: Sequence[str], axes: Optional[Sequence[str]]
1686) -> Optional[AxisId]:
1687 if axes is None:
1688 return None
1690 non_complement_axes = set(axes) | {"b"}
1691 complement_axes = [a for a in tensor_axes if a not in non_complement_axes]
1692 if len(complement_axes) > 1:
1693 raise ValueError(
1694 f"Expected none or a single complement axis, but axes '{axes}' "
1695 + f"for tensor dims '{tensor_axes}' leave '{complement_axes}'."
1696 )
1698 return None if not complement_axes else AxisId(complement_axes[0])
1701def _convert_proc(
1702 p: Union[_PreprocessingDescr_v0_4, _PostprocessingDescr_v0_4],
1703 tensor_axes: Sequence[str],
1704) -> Union[PreprocessingDescr, PostprocessingDescr]:
1705 if isinstance(p, _BinarizeDescr_v0_4):
1706 return BinarizeDescr(kwargs=BinarizeKwargs(threshold=p.kwargs.threshold))
1707 elif isinstance(p, _ClipDescr_v0_4):
1708 return ClipDescr(kwargs=ClipKwargs(min=p.kwargs.min, max=p.kwargs.max))
1709 elif isinstance(p, _SigmoidDescr_v0_4):
1710 return SigmoidDescr()
1711 elif isinstance(p, _ScaleLinearDescr_v0_4):
1712 axes = _axes_letters_to_ids(p.kwargs.axes)
1713 if p.kwargs.axes is None:
1714 axis = None
1715 else:
1716 axis = _get_complement_v04_axis(tensor_axes, p.kwargs.axes)
1718 if axis is None:
1719 assert not isinstance(p.kwargs.gain, list)
1720 assert not isinstance(p.kwargs.offset, list)
1721 kwargs = ScaleLinearKwargs(gain=p.kwargs.gain, offset=p.kwargs.offset)
1722 else:
1723 kwargs = ScaleLinearAlongAxisKwargs(
1724 axis=axis, gain=p.kwargs.gain, offset=p.kwargs.offset
1725 )
1726 return ScaleLinearDescr(kwargs=kwargs)
1727 elif isinstance(p, _ScaleMeanVarianceDescr_v0_4):
1728 return ScaleMeanVarianceDescr(
1729 kwargs=ScaleMeanVarianceKwargs(
1730 axes=_axes_letters_to_ids(p.kwargs.axes),
1731 reference_tensor=TensorId(str(p.kwargs.reference_tensor)),
1732 eps=p.kwargs.eps,
1733 )
1734 )
1735 elif isinstance(p, _ZeroMeanUnitVarianceDescr_v0_4):
1736 if p.kwargs.mode == "fixed":
1737 mean = p.kwargs.mean
1738 std = p.kwargs.std
1739 assert mean is not None
1740 assert std is not None
1742 axis = _get_complement_v04_axis(tensor_axes, p.kwargs.axes)
1744 if axis is None:
1745 return FixedZeroMeanUnitVarianceDescr(
1746 kwargs=FixedZeroMeanUnitVarianceKwargs(
1747 mean=mean, std=std # pyright: ignore[reportArgumentType]
1748 )
1749 )
1750 else:
1751 if not isinstance(mean, list):
1752 mean = [float(mean)]
1753 if not isinstance(std, list):
1754 std = [float(std)]
1756 return FixedZeroMeanUnitVarianceDescr(
1757 kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs(
1758 axis=axis, mean=mean, std=std
1759 )
1760 )
1762 else:
1763 axes = _axes_letters_to_ids(p.kwargs.axes) or []
1764 if p.kwargs.mode == "per_dataset":
1765 axes = [AxisId("batch")] + axes
1766 if not axes:
1767 axes = None
1768 return ZeroMeanUnitVarianceDescr(
1769 kwargs=ZeroMeanUnitVarianceKwargs(axes=axes, eps=p.kwargs.eps)
1770 )
1772 elif isinstance(p, _ScaleRangeDescr_v0_4):
1773 return ScaleRangeDescr(
1774 kwargs=ScaleRangeKwargs(
1775 axes=_axes_letters_to_ids(p.kwargs.axes),
1776 min_percentile=p.kwargs.min_percentile,
1777 max_percentile=p.kwargs.max_percentile,
1778 eps=p.kwargs.eps,
1779 )
1780 )
1781 else:
1782 assert_never(p)
1785class _InputTensorConv(
1786 Converter[
1787 _InputTensorDescr_v0_4,
1788 InputTensorDescr,
1789 ImportantFileSource,
1790 Optional[ImportantFileSource],
1791 Mapping[_TensorName_v0_4, Mapping[str, int]],
1792 ]
1793):
1794 def _convert(
1795 self,
1796 src: _InputTensorDescr_v0_4,
1797 tgt: "type[InputTensorDescr] | type[dict[str, Any]]",
1798 test_tensor: ImportantFileSource,
1799 sample_tensor: Optional[ImportantFileSource],
1800 size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]],
1801 ) -> "InputTensorDescr | dict[str, Any]":
1802 axes: List[InputAxis] = convert_axes( # pyright: ignore[reportAssignmentType]
1803 src.axes,
1804 shape=src.shape,
1805 tensor_type="input",
1806 halo=None,
1807 size_refs=size_refs,
1808 )
1809 prep: List[PreprocessingDescr] = []
1810 for p in src.preprocessing:
1811 cp = _convert_proc(p, src.axes)
1812 assert not isinstance(cp, ScaleMeanVarianceDescr)
1813 prep.append(cp)
1815 return tgt(
1816 axes=axes,
1817 id=TensorId(str(src.name)),
1818 test_tensor=FileDescr(source=test_tensor),
1819 sample_tensor=(
1820 None if sample_tensor is None else FileDescr(source=sample_tensor)
1821 ),
1822 data=dict(type=src.data_type), # pyright: ignore[reportArgumentType]
1823 preprocessing=prep,
1824 )
1827_input_tensor_conv = _InputTensorConv(_InputTensorDescr_v0_4, InputTensorDescr)
1830class OutputTensorDescr(TensorDescrBase[OutputAxis]):
1831 id: TensorId = TensorId("output")
1832 """Output tensor id.
1833 No duplicates are allowed across all inputs and outputs."""
1835 postprocessing: List[PostprocessingDescr] = Field(default_factory=list)
1836 """Description of how this output should be postprocessed.
1838 note: `postprocessing` always ends with an 'ensure_dtype' operation.
1839 If not given this is added to cast to this tensor's `data.type`.
1840 """
1842 @model_validator(mode="after")
1843 def _validate_postprocessing_kwargs(self) -> Self:
1844 axes_ids = [a.id for a in self.axes]
1845 for p in self.postprocessing:
1846 kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes")
1847 if kwargs_axes is None:
1848 continue
1850 if not isinstance(kwargs_axes, collections.abc.Sequence):
1851 raise ValueError(
1852 f"expected `axes` sequence, but got {type(kwargs_axes)}"
1853 )
1855 if any(a not in axes_ids for a in kwargs_axes):
1856 raise ValueError("`kwargs.axes` needs to be subset of axes ids")
1858 if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)):
1859 dtype = self.data.type
1860 else:
1861 dtype = self.data[0].type
1863 # ensure `postprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr`
1864 if not self.postprocessing or not isinstance(
1865 self.postprocessing[-1], (EnsureDtypeDescr, BinarizeDescr)
1866 ):
1867 self.postprocessing.append(
1868 EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype))
1869 )
1870 return self
1873class _OutputTensorConv(
1874 Converter[
1875 _OutputTensorDescr_v0_4,
1876 OutputTensorDescr,
1877 ImportantFileSource,
1878 Optional[ImportantFileSource],
1879 Mapping[_TensorName_v0_4, Mapping[str, int]],
1880 ]
1881):
1882 def _convert(
1883 self,
1884 src: _OutputTensorDescr_v0_4,
1885 tgt: "type[OutputTensorDescr] | type[dict[str, Any]]",
1886 test_tensor: ImportantFileSource,
1887 sample_tensor: Optional[ImportantFileSource],
1888 size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]],
1889 ) -> "OutputTensorDescr | dict[str, Any]":
1890 # TODO: split convert_axes into convert_output_axes and convert_input_axes
1891 axes: List[OutputAxis] = convert_axes( # pyright: ignore[reportAssignmentType]
1892 src.axes,
1893 shape=src.shape,
1894 tensor_type="output",
1895 halo=src.halo,
1896 size_refs=size_refs,
1897 )
1898 data_descr: Dict[str, Any] = dict(type=src.data_type)
1899 if data_descr["type"] == "bool":
1900 data_descr["values"] = [False, True]
1902 return tgt(
1903 axes=axes,
1904 id=TensorId(str(src.name)),
1905 test_tensor=FileDescr(source=test_tensor),
1906 sample_tensor=(
1907 None if sample_tensor is None else FileDescr(source=sample_tensor)
1908 ),
1909 data=data_descr, # pyright: ignore[reportArgumentType]
1910 postprocessing=[_convert_proc(p, src.axes) for p in src.postprocessing],
1911 )
1914_output_tensor_conv = _OutputTensorConv(_OutputTensorDescr_v0_4, OutputTensorDescr)
1917TensorDescr = Union[InputTensorDescr, OutputTensorDescr]
1920def validate_tensors(
1921 tensors: Mapping[TensorId, Tuple[TensorDescr, NDArray[Any]]],
1922 tensor_origin: str, # for more precise error messages, e.g. 'test_tensor'
1923):
1924 all_tensor_axes: Dict[TensorId, Dict[AxisId, Tuple[AnyAxis, int]]] = {}
1926 def e_msg(d: TensorDescr):
1927 return f"{'inputs' if isinstance(d, InputTensorDescr) else 'outputs'}[{d.id}]"
1929 for descr, array in tensors.values():
1930 try:
1931 axis_sizes = descr.get_axis_sizes_for_array(array)
1932 except ValueError as e:
1933 raise ValueError(f"{e_msg(descr)} {e}")
1934 else:
1935 all_tensor_axes[descr.id] = {
1936 a.id: (a, axis_sizes[a.id]) for a in descr.axes
1937 }
1939 for descr, array in tensors.values():
1940 if array.dtype.name != descr.dtype:
1941 raise ValueError(
1942 f"{e_msg(descr)}.{tensor_origin}.dtype '{array.dtype.name}' does not"
1943 + f" match described dtype '{descr.dtype}'"
1944 )
1946 for a in descr.axes:
1947 actual_size = all_tensor_axes[descr.id][a.id][1]
1948 if a.size is None:
1949 continue
1951 if isinstance(a.size, int):
1952 if actual_size != a.size:
1953 raise ValueError(
1954 f"{e_msg(descr)}.{tensor_origin}: axis '{a.id}' "
1955 + f"has incompatible size {actual_size}, expected {a.size}"
1956 )
1957 elif isinstance(a.size, ParameterizedSize):
1958 _ = a.size.validate_size(actual_size)
1959 elif isinstance(a.size, DataDependentSize):
1960 _ = a.size.validate_size(actual_size)
1961 elif isinstance(a.size, SizeReference):
1962 ref_tensor_axes = all_tensor_axes.get(a.size.tensor_id)
1963 if ref_tensor_axes is None:
1964 raise ValueError(
1965 f"{e_msg(descr)}.axes[{a.id}].size.tensor_id: Unknown tensor"
1966 + f" reference '{a.size.tensor_id}'"
1967 )
1969 ref_axis, ref_size = ref_tensor_axes.get(a.size.axis_id, (None, None))
1970 if ref_axis is None or ref_size is None:
1971 raise ValueError(
1972 f"{e_msg(descr)}.axes[{a.id}].size.axis_id: Unknown tensor axis"
1973 + f" reference '{a.size.tensor_id}.{a.size.axis_id}"
1974 )
1976 if a.unit != ref_axis.unit:
1977 raise ValueError(
1978 f"{e_msg(descr)}.axes[{a.id}].size: `SizeReference` requires"
1979 + " axis and reference axis to have the same `unit`, but"
1980 + f" {a.unit}!={ref_axis.unit}"
1981 )
1983 if actual_size != (
1984 expected_size := (
1985 ref_size * ref_axis.scale / a.scale + a.size.offset
1986 )
1987 ):
1988 raise ValueError(
1989 f"{e_msg(descr)}.{tensor_origin}: axis '{a.id}' of size"
1990 + f" {actual_size} invalid for referenced size {ref_size};"
1991 + f" expected {expected_size}"
1992 )
1993 else:
1994 assert_never(a.size)
1997class EnvironmentFileDescr(FileDescr):
1998 source: Annotated[
1999 ImportantFileSource,
2000 WithSuffix((".yaml", ".yml"), case_sensitive=True),
2001 Field(
2002 examples=["environment.yaml"],
2003 ),
2004 ]
2005 """∈📦 Conda environment file.
2006 Allows to specify custom dependencies, see conda docs:
2007 - [Exporting an environment file across platforms](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#exporting-an-environment-file-across-platforms)
2008 - [Creating an environment file manually](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#creating-an-environment-file-manually)
2009 """
2012class _ArchitectureCallableDescr(Node):
2013 callable: Annotated[Identifier, Field(examples=["MyNetworkClass", "get_my_model"])]
2014 """Identifier of the callable that returns a torch.nn.Module instance."""
2016 kwargs: Dict[str, YamlValue] = Field(default_factory=dict)
2017 """key word arguments for the `callable`"""
2020class ArchitectureFromFileDescr(_ArchitectureCallableDescr, FileDescr):
2021 pass
2024class ArchitectureFromLibraryDescr(_ArchitectureCallableDescr):
2025 import_from: str
2026 """Where to import the callable from, i.e. `from <import_from> import <callable>`"""
2029ArchitectureDescr = Annotated[
2030 Union[ArchitectureFromFileDescr, ArchitectureFromLibraryDescr],
2031 Field(union_mode="left_to_right"),
2032]
2035class _ArchFileConv(
2036 Converter[
2037 _CallableFromFile_v0_4,
2038 ArchitectureFromFileDescr,
2039 Optional[Sha256],
2040 Dict[str, Any],
2041 ]
2042):
2043 def _convert(
2044 self,
2045 src: _CallableFromFile_v0_4,
2046 tgt: "type[ArchitectureFromFileDescr | dict[str, Any]]",
2047 sha256: Optional[Sha256],
2048 kwargs: Dict[str, Any],
2049 ) -> "ArchitectureFromFileDescr | dict[str, Any]":
2050 if src.startswith("http") and src.count(":") == 2:
2051 http, source, callable_ = src.split(":")
2052 source = ":".join((http, source))
2053 elif not src.startswith("http") and src.count(":") == 1:
2054 source, callable_ = src.split(":")
2055 else:
2056 source = str(src)
2057 callable_ = str(src)
2058 return tgt(
2059 callable=Identifier(callable_),
2060 source=cast(ImportantFileSource, source),
2061 sha256=sha256,
2062 kwargs=kwargs,
2063 )
2066_arch_file_conv = _ArchFileConv(_CallableFromFile_v0_4, ArchitectureFromFileDescr)
2069class _ArchLibConv(
2070 Converter[
2071 _CallableFromDepencency_v0_4, ArchitectureFromLibraryDescr, Dict[str, Any]
2072 ]
2073):
2074 def _convert(
2075 self,
2076 src: _CallableFromDepencency_v0_4,
2077 tgt: "type[ArchitectureFromLibraryDescr | dict[str, Any]]",
2078 kwargs: Dict[str, Any],
2079 ) -> "ArchitectureFromLibraryDescr | dict[str, Any]":
2080 *mods, callable_ = src.split(".")
2081 import_from = ".".join(mods)
2082 return tgt(
2083 import_from=import_from, callable=Identifier(callable_), kwargs=kwargs
2084 )
2087_arch_lib_conv = _ArchLibConv(
2088 _CallableFromDepencency_v0_4, ArchitectureFromLibraryDescr
2089)
2092class WeightsEntryDescrBase(FileDescr):
2093 type: ClassVar[WeightsFormat]
2094 weights_format_name: ClassVar[str] # human readable
2096 source: ImportantFileSource
2097 """∈📦 The weights file."""
2099 authors: Optional[List[Author]] = None
2100 """Authors
2101 Either the person(s) that have trained this model resulting in the original weights file.
2102 (If this is the initial weights entry, i.e. it does not have a `parent`)
2103 Or the person(s) who have converted the weights to this weights format.
2104 (If this is a child weight, i.e. it has a `parent` field)
2105 """
2107 parent: Annotated[
2108 Optional[WeightsFormat], Field(examples=["pytorch_state_dict"])
2109 ] = None
2110 """The source weights these weights were converted from.
2111 For example, if a model's weights were converted from the `pytorch_state_dict` format to `torchscript`,
2112 The `pytorch_state_dict` weights entry has no `parent` and is the parent of the `torchscript` weights.
2113 All weight entries except one (the initial set of weights resulting from training the model),
2114 need to have this field."""
2116 @model_validator(mode="after")
2117 def check_parent_is_not_self(self) -> Self:
2118 if self.type == self.parent:
2119 raise ValueError("Weights entry can't be it's own parent.")
2121 return self
2124class KerasHdf5WeightsDescr(WeightsEntryDescrBase):
2125 type = "keras_hdf5"
2126 weights_format_name: ClassVar[str] = "Keras HDF5"
2127 tensorflow_version: Version
2128 """TensorFlow version used to create these weights."""
2131class OnnxWeightsDescr(WeightsEntryDescrBase):
2132 type = "onnx"
2133 weights_format_name: ClassVar[str] = "ONNX"
2134 opset_version: Annotated[int, Ge(7)]
2135 """ONNX opset version"""
2138class PytorchStateDictWeightsDescr(WeightsEntryDescrBase):
2139 type = "pytorch_state_dict"
2140 weights_format_name: ClassVar[str] = "Pytorch State Dict"
2141 architecture: ArchitectureDescr
2142 pytorch_version: Version
2143 """Version of the PyTorch library used.
2144 If `architecture.depencencies` is specified it has to include pytorch and any version pinning has to be compatible.
2145 """
2146 dependencies: Optional[EnvironmentFileDescr] = None
2147 """Custom depencies beyond pytorch.
2148 The conda environment file should include pytorch and any version pinning has to be compatible with
2149 `pytorch_version`.
2150 """
2153class TensorflowJsWeightsDescr(WeightsEntryDescrBase):
2154 type = "tensorflow_js"
2155 weights_format_name: ClassVar[str] = "Tensorflow.js"
2156 tensorflow_version: Version
2157 """Version of the TensorFlow library used."""
2159 source: ImportantFileSource
2160 """∈📦 The multi-file weights.
2161 All required files/folders should be a zip archive."""
2164class TensorflowSavedModelBundleWeightsDescr(WeightsEntryDescrBase):
2165 type = "tensorflow_saved_model_bundle"
2166 weights_format_name: ClassVar[str] = "Tensorflow Saved Model"
2167 tensorflow_version: Version
2168 """Version of the TensorFlow library used."""
2170 dependencies: Optional[EnvironmentFileDescr] = None
2171 """Custom dependencies beyond tensorflow.
2172 Should include tensorflow and any version pinning has to be compatible with `tensorflow_version`."""
2174 source: ImportantFileSource
2175 """∈📦 The multi-file weights.
2176 All required files/folders should be a zip archive."""
2179class TorchscriptWeightsDescr(WeightsEntryDescrBase):
2180 type = "torchscript"
2181 weights_format_name: ClassVar[str] = "TorchScript"
2182 pytorch_version: Version
2183 """Version of the PyTorch library used."""
2186class WeightsDescr(Node):
2187 keras_hdf5: Optional[KerasHdf5WeightsDescr] = None
2188 onnx: Optional[OnnxWeightsDescr] = None
2189 pytorch_state_dict: Optional[PytorchStateDictWeightsDescr] = None
2190 tensorflow_js: Optional[TensorflowJsWeightsDescr] = None
2191 tensorflow_saved_model_bundle: Optional[TensorflowSavedModelBundleWeightsDescr] = (
2192 None
2193 )
2194 torchscript: Optional[TorchscriptWeightsDescr] = None
2196 @model_validator(mode="after")
2197 def check_entries(self) -> Self:
2198 entries = {wtype for wtype, entry in self if entry is not None}
2200 if not entries:
2201 raise ValueError("Missing weights entry")
2203 entries_wo_parent = {
2204 wtype
2205 for wtype, entry in self
2206 if entry is not None and hasattr(entry, "parent") and entry.parent is None
2207 }
2208 if len(entries_wo_parent) != 1:
2209 issue_warning(
2210 "Exactly one weights entry may not specify the `parent` field (got"
2211 + " {value}). That entry is considered the original set of model weights."
2212 + " Other weight formats are created through conversion of the orignal or"
2213 + " already converted weights. They have to reference the weights format"
2214 + " they were converted from as their `parent`.",
2215 value=len(entries_wo_parent),
2216 field="weights",
2217 )
2219 for wtype, entry in self:
2220 if entry is None:
2221 continue
2223 assert hasattr(entry, "type")
2224 assert hasattr(entry, "parent")
2225 assert wtype == entry.type
2226 if (
2227 entry.parent is not None and entry.parent not in entries
2228 ): # self reference checked for `parent` field
2229 raise ValueError(
2230 f"`weights.{wtype}.parent={entry.parent} not in specified weight"
2231 + f" formats: {entries}"
2232 )
2234 return self
2236 def __getitem__(
2237 self,
2238 key: Literal[
2239 "keras_hdf5",
2240 "onnx",
2241 "pytorch_state_dict",
2242 "tensorflow_js",
2243 "tensorflow_saved_model_bundle",
2244 "torchscript",
2245 ],
2246 ):
2247 if key == "keras_hdf5":
2248 ret = self.keras_hdf5
2249 elif key == "onnx":
2250 ret = self.onnx
2251 elif key == "pytorch_state_dict":
2252 ret = self.pytorch_state_dict
2253 elif key == "tensorflow_js":
2254 ret = self.tensorflow_js
2255 elif key == "tensorflow_saved_model_bundle":
2256 ret = self.tensorflow_saved_model_bundle
2257 elif key == "torchscript":
2258 ret = self.torchscript
2259 else:
2260 raise KeyError(key)
2262 if ret is None:
2263 raise KeyError(key)
2265 return ret
2267 @property
2268 def available_formats(self):
2269 return {
2270 **({} if self.keras_hdf5 is None else {"keras_hdf5": self.keras_hdf5}),
2271 **({} if self.onnx is None else {"onnx": self.onnx}),
2272 **(
2273 {}
2274 if self.pytorch_state_dict is None
2275 else {"pytorch_state_dict": self.pytorch_state_dict}
2276 ),
2277 **(
2278 {}
2279 if self.tensorflow_js is None
2280 else {"tensorflow_js": self.tensorflow_js}
2281 ),
2282 **(
2283 {}
2284 if self.tensorflow_saved_model_bundle is None
2285 else {
2286 "tensorflow_saved_model_bundle": self.tensorflow_saved_model_bundle
2287 }
2288 ),
2289 **({} if self.torchscript is None else {"torchscript": self.torchscript}),
2290 }
2292 @property
2293 def missing_formats(self):
2294 return {
2295 wf for wf in get_args(WeightsFormat) if wf not in self.available_formats
2296 }
2299class ModelId(ResourceId):
2300 pass
2303class LinkedModel(LinkedResourceBase):
2304 """Reference to a bioimage.io model."""
2306 id: ModelId
2307 """A valid model `id` from the bioimage.io collection."""
2310class _DataDepSize(NamedTuple):
2311 min: int
2312 max: Optional[int]
2315class _AxisSizes(NamedTuple):
2316 """the lenghts of all axes of model inputs and outputs"""
2318 inputs: Dict[Tuple[TensorId, AxisId], int]
2319 outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]]
2322class _TensorSizes(NamedTuple):
2323 """_AxisSizes as nested dicts"""
2325 inputs: Dict[TensorId, Dict[AxisId, int]]
2326 outputs: Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]
2329class ModelDescr(GenericModelDescrBase):
2330 """Specification of the fields used in a bioimage.io-compliant RDF to describe AI models with pretrained weights.
2331 These fields are typically stored in a YAML file which we call a model resource description file (model RDF).
2332 """
2334 format_version: Literal["0.5.3"] = "0.5.3"
2335 """Version of the bioimage.io model description specification used.
2336 When creating a new model always use the latest micro/patch version described here.
2337 The `format_version` is important for any consumer software to understand how to parse the fields.
2338 """
2340 type: Literal["model"] = "model"
2341 """Specialized resource type 'model'"""
2343 id: Optional[ModelId] = None
2344 """bioimage.io-wide unique resource identifier
2345 assigned by bioimage.io; version **un**specific."""
2347 authors: NotEmpty[List[Author]]
2348 """The authors are the creators of the model RDF and the primary points of contact."""
2350 documentation: Annotated[
2351 DocumentationSource,
2352 Field(
2353 examples=[
2354 "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/unet2d_nuclei_broad/README.md",
2355 "README.md",
2356 ],
2357 ),
2358 ]
2359 """∈📦 URL or relative path to a markdown file with additional documentation.
2360 The recommended documentation file name is `README.md`. An `.md` suffix is mandatory.
2361 The documentation should include a '#[#] Validation' (sub)section
2362 with details on how to quantitatively validate the model on unseen data."""
2364 @field_validator("documentation", mode="after")
2365 @classmethod
2366 def _validate_documentation(cls, value: DocumentationSource) -> DocumentationSource:
2367 if not validation_context_var.get().perform_io_checks:
2368 return value
2370 doc_path = download(value).path
2371 doc_content = doc_path.read_text(encoding="utf-8")
2372 assert isinstance(doc_content, str)
2373 if not re.match("#.*[vV]alidation", doc_content):
2374 issue_warning(
2375 "No '# Validation' (sub)section found in {value}.",
2376 value=value,
2377 field="documentation",
2378 )
2380 return value
2382 inputs: NotEmpty[Sequence[InputTensorDescr]]
2383 """Describes the input tensors expected by this model."""
2385 @field_validator("inputs", mode="after")
2386 @classmethod
2387 def _validate_input_axes(
2388 cls, inputs: Sequence[InputTensorDescr]
2389 ) -> Sequence[InputTensorDescr]:
2390 input_size_refs = cls._get_axes_with_independent_size(inputs)
2392 for i, ipt in enumerate(inputs):
2393 valid_independent_refs: Dict[
2394 Tuple[TensorId, AxisId],
2395 Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
2396 ] = {
2397 **{
2398 (ipt.id, a.id): (ipt, a, a.size)
2399 for a in ipt.axes
2400 if not isinstance(a, BatchAxis)
2401 and isinstance(a.size, (int, ParameterizedSize))
2402 },
2403 **input_size_refs,
2404 }
2405 for a, ax in enumerate(ipt.axes):
2406 cls._validate_axis(
2407 "inputs",
2408 i=i,
2409 tensor_id=ipt.id,
2410 a=a,
2411 axis=ax,
2412 valid_independent_refs=valid_independent_refs,
2413 )
2414 return inputs
2416 @staticmethod
2417 def _validate_axis(
2418 field_name: str,
2419 i: int,
2420 tensor_id: TensorId,
2421 a: int,
2422 axis: AnyAxis,
2423 valid_independent_refs: Dict[
2424 Tuple[TensorId, AxisId],
2425 Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
2426 ],
2427 ):
2428 if isinstance(axis, BatchAxis) or isinstance(
2429 axis.size, (int, ParameterizedSize, DataDependentSize)
2430 ):
2431 return
2432 elif not isinstance(axis.size, SizeReference):
2433 assert_never(axis.size)
2435 # validate axis.size SizeReference
2436 ref = (axis.size.tensor_id, axis.size.axis_id)
2437 if ref not in valid_independent_refs:
2438 raise ValueError(
2439 "Invalid tensor axis reference at"
2440 + f" {field_name}[{i}].axes[{a}].size: {axis.size}."
2441 )
2442 if ref == (tensor_id, axis.id):
2443 raise ValueError(
2444 "Self-referencing not allowed for"
2445 + f" {field_name}[{i}].axes[{a}].size: {axis.size}"
2446 )
2447 if axis.type == "channel":
2448 if valid_independent_refs[ref][1].type != "channel":
2449 raise ValueError(
2450 "A channel axis' size may only reference another fixed size"
2451 + " channel axis."
2452 )
2453 if isinstance(axis.channel_names, str) and "{i}" in axis.channel_names:
2454 ref_size = valid_independent_refs[ref][2]
2455 assert isinstance(ref_size, int), (
2456 "channel axis ref (another channel axis) has to specify fixed"
2457 + " size"
2458 )
2459 generated_channel_names = [
2460 Identifier(axis.channel_names.format(i=i))
2461 for i in range(1, ref_size + 1)
2462 ]
2463 axis.channel_names = generated_channel_names
2465 if (ax_unit := getattr(axis, "unit", None)) != (
2466 ref_unit := getattr(valid_independent_refs[ref][1], "unit", None)
2467 ):
2468 raise ValueError(
2469 "The units of an axis and its reference axis need to match, but"
2470 + f" '{ax_unit}' != '{ref_unit}'."
2471 )
2472 ref_axis = valid_independent_refs[ref][1]
2473 if isinstance(ref_axis, BatchAxis):
2474 raise ValueError(
2475 f"Invalid reference axis '{ref_axis.id}' for {tensor_id}.{axis.id}"
2476 + " (a batch axis is not allowed as reference)."
2477 )
2479 if isinstance(axis, WithHalo):
2480 min_size = axis.size.get_size(axis, ref_axis, n=0)
2481 if (min_size - 2 * axis.halo) < 1:
2482 raise ValueError(
2483 f"axis {axis.id} with minimum size {min_size} is too small for halo"
2484 + f" {axis.halo}."
2485 )
2487 input_halo = axis.halo * axis.scale / ref_axis.scale
2488 if input_halo != int(input_halo) or input_halo % 2 == 1:
2489 raise ValueError(
2490 f"input_halo {input_halo} (output_halo {axis.halo} *"
2491 + f" output_scale {axis.scale} / input_scale {ref_axis.scale})"
2492 + f" is not an even integer for {tensor_id}.{axis.id}."
2493 )
2495 @model_validator(mode="after")
2496 def _validate_test_tensors(self) -> Self:
2497 if not validation_context_var.get().perform_io_checks:
2498 return self
2500 test_arrays = [
2501 load_array(descr.test_tensor.download().path)
2502 for descr in chain(self.inputs, self.outputs)
2503 ]
2504 tensors = {
2505 descr.id: (descr, array)
2506 for descr, array in zip(chain(self.inputs, self.outputs), test_arrays)
2507 }
2508 validate_tensors(tensors, tensor_origin="test_tensor")
2509 return self
2511 @model_validator(mode="after")
2512 def _validate_tensor_references_in_proc_kwargs(self, info: ValidationInfo) -> Self:
2513 ipt_refs = {t.id for t in self.inputs}
2514 out_refs = {t.id for t in self.outputs}
2515 for ipt in self.inputs:
2516 for p in ipt.preprocessing:
2517 ref = p.kwargs.get("reference_tensor")
2518 if ref is None:
2519 continue
2520 if ref not in ipt_refs:
2521 raise ValueError(
2522 f"`reference_tensor` '{ref}' not found. Valid input tensor"
2523 + f" references are: {ipt_refs}."
2524 )
2526 for out in self.outputs:
2527 for p in out.postprocessing:
2528 ref = p.kwargs.get("reference_tensor")
2529 if ref is None:
2530 continue
2532 if ref not in ipt_refs and ref not in out_refs:
2533 raise ValueError(
2534 f"`reference_tensor` '{ref}' not found. Valid tensor references"
2535 + f" are: {ipt_refs | out_refs}."
2536 )
2538 return self
2540 # TODO: use validate funcs in validate_test_tensors
2541 # def validate_inputs(self, input_tensors: Mapping[TensorId, NDArray[Any]]) -> Mapping[TensorId, NDArray[Any]]:
2543 name: Annotated[
2544 Annotated[
2545 str, RestrictCharacters(string.ascii_letters + string.digits + "_- ()")
2546 ],
2547 MinLen(5),
2548 MaxLen(128),
2549 warn(MaxLen(64), "Name longer than 64 characters.", INFO),
2550 ]
2551 """A human-readable name of this model.
2552 It should be no longer than 64 characters
2553 and may only contain letter, number, underscore, minus, parentheses and spaces.
2554 We recommend to chose a name that refers to the model's task and image modality.
2555 """
2557 outputs: NotEmpty[Sequence[OutputTensorDescr]]
2558 """Describes the output tensors."""
2560 @field_validator("outputs", mode="after")
2561 @classmethod
2562 def _validate_tensor_ids(
2563 cls, outputs: Sequence[OutputTensorDescr], info: ValidationInfo
2564 ) -> Sequence[OutputTensorDescr]:
2565 tensor_ids = [
2566 t.id for t in info.data.get("inputs", []) + info.data.get("outputs", [])
2567 ]
2568 duplicate_tensor_ids: List[str] = []
2569 seen: Set[str] = set()
2570 for t in tensor_ids:
2571 if t in seen:
2572 duplicate_tensor_ids.append(t)
2574 seen.add(t)
2576 if duplicate_tensor_ids:
2577 raise ValueError(f"Duplicate tensor ids: {duplicate_tensor_ids}")
2579 return outputs
2581 @staticmethod
2582 def _get_axes_with_parameterized_size(
2583 io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]],
2584 ):
2585 return {
2586 f"{t.id}.{a.id}": (t, a, a.size)
2587 for t in io
2588 for a in t.axes
2589 if not isinstance(a, BatchAxis) and isinstance(a.size, ParameterizedSize)
2590 }
2592 @staticmethod
2593 def _get_axes_with_independent_size(
2594 io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]],
2595 ):
2596 return {
2597 (t.id, a.id): (t, a, a.size)
2598 for t in io
2599 for a in t.axes
2600 if not isinstance(a, BatchAxis)
2601 and isinstance(a.size, (int, ParameterizedSize))
2602 }
2604 @field_validator("outputs", mode="after")
2605 @classmethod
2606 def _validate_output_axes(
2607 cls, outputs: List[OutputTensorDescr], info: ValidationInfo
2608 ) -> List[OutputTensorDescr]:
2609 input_size_refs = cls._get_axes_with_independent_size(
2610 info.data.get("inputs", [])
2611 )
2612 output_size_refs = cls._get_axes_with_independent_size(outputs)
2614 for i, out in enumerate(outputs):
2615 valid_independent_refs: Dict[
2616 Tuple[TensorId, AxisId],
2617 Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
2618 ] = {
2619 **{
2620 (out.id, a.id): (out, a, a.size)
2621 for a in out.axes
2622 if not isinstance(a, BatchAxis)
2623 and isinstance(a.size, (int, ParameterizedSize))
2624 },
2625 **input_size_refs,
2626 **output_size_refs,
2627 }
2628 for a, ax in enumerate(out.axes):
2629 cls._validate_axis(
2630 "outputs",
2631 i,
2632 out.id,
2633 a,
2634 ax,
2635 valid_independent_refs=valid_independent_refs,
2636 )
2638 return outputs
2640 packaged_by: List[Author] = Field(default_factory=list)
2641 """The persons that have packaged and uploaded this model.
2642 Only required if those persons differ from the `authors`."""
2644 parent: Optional[LinkedModel] = None
2645 """The model from which this model is derived, e.g. by fine-tuning the weights."""
2647 # todo: add parent self check once we have `id`
2648 # @model_validator(mode="after")
2649 # def validate_parent_is_not_self(self) -> Self:
2650 # if self.parent is not None and self.parent == self.id:
2651 # raise ValueError("The model may not reference itself as parent model")
2653 # return self
2655 run_mode: Annotated[
2656 Optional[RunMode],
2657 warn(None, "Run mode '{value}' has limited support across consumer softwares."),
2658 ] = None
2659 """Custom run mode for this model: for more complex prediction procedures like test time
2660 data augmentation that currently cannot be expressed in the specification.
2661 No standard run modes are defined yet."""
2663 timestamp: Datetime = Datetime(datetime.now())
2664 """Timestamp in [ISO 8601](#https://en.wikipedia.org/wiki/ISO_8601) format
2665 with a few restrictions listed [here](https://docs.python.org/3/library/datetime.html#datetime.datetime.fromisoformat).
2666 (In Python a datetime object is valid, too)."""
2668 training_data: Annotated[
2669 Union[None, LinkedDataset, DatasetDescr, DatasetDescr02],
2670 Field(union_mode="left_to_right"),
2671 ] = None
2672 """The dataset used to train this model"""
2674 weights: Annotated[WeightsDescr, WrapSerializer(package_weights)]
2675 """The weights for this model.
2676 Weights can be given for different formats, but should otherwise be equivalent.
2677 The available weight formats determine which consumers can use this model."""
2679 @model_validator(mode="after")
2680 def _add_default_cover(self) -> Self:
2681 if not validation_context_var.get().perform_io_checks or self.covers:
2682 return self
2684 try:
2685 generated_covers = generate_covers(
2686 [(t, load_array(t.test_tensor.download().path)) for t in self.inputs],
2687 [(t, load_array(t.test_tensor.download().path)) for t in self.outputs],
2688 )
2689 except Exception as e:
2690 issue_warning(
2691 "Failed to generate cover image(s): {e}",
2692 value=self.covers,
2693 msg_context=dict(e=e),
2694 field="covers",
2695 )
2696 else:
2697 self.covers.extend(generated_covers)
2699 return self
2701 def get_input_test_arrays(self) -> List[NDArray[Any]]:
2702 data = [load_array(ipt.test_tensor.download().path) for ipt in self.inputs]
2703 assert all(isinstance(d, np.ndarray) for d in data)
2704 return data
2706 def get_output_test_arrays(self) -> List[NDArray[Any]]:
2707 data = [load_array(out.test_tensor.download().path) for out in self.outputs]
2708 assert all(isinstance(d, np.ndarray) for d in data)
2709 return data
2711 @staticmethod
2712 def get_batch_size(tensor_sizes: Mapping[TensorId, Mapping[AxisId, int]]) -> int:
2713 batch_size = 1
2714 tensor_with_batchsize: Optional[TensorId] = None
2715 for tid in tensor_sizes:
2716 for aid, s in tensor_sizes[tid].items():
2717 if aid != BATCH_AXIS_ID or s == 1 or s == batch_size:
2718 continue
2720 if batch_size != 1:
2721 assert tensor_with_batchsize is not None
2722 raise ValueError(
2723 f"batch size mismatch for tensors '{tensor_with_batchsize}' ({batch_size}) and '{tid}' ({s})"
2724 )
2726 batch_size = s
2727 tensor_with_batchsize = tid
2729 return batch_size
2731 def get_output_tensor_sizes(
2732 self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]
2733 ) -> Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]:
2734 """Returns the tensor output sizes for given **input_sizes**.
2735 Only if **input_sizes** has a valid input shape, the tensor output size is exact.
2736 Otherwise it might be larger than the actual (valid) output"""
2737 batch_size = self.get_batch_size(input_sizes)
2738 ns = self.get_ns(input_sizes)
2740 tensor_sizes = self.get_tensor_sizes(ns, batch_size=batch_size)
2741 return tensor_sizes.outputs
2743 def get_ns(self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]):
2744 """get parameter `n` for each parameterized axis
2745 such that the valid input size is >= the given input size"""
2746 ret: Dict[Tuple[TensorId, AxisId], ParameterizedSize_N] = {}
2747 axes = {t.id: {a.id: a for a in t.axes} for t in self.inputs}
2748 for tid in input_sizes:
2749 for aid, s in input_sizes[tid].items():
2750 size_descr = axes[tid][aid].size
2751 if isinstance(size_descr, ParameterizedSize):
2752 ret[(tid, aid)] = size_descr.get_n(s)
2753 elif size_descr is None or isinstance(size_descr, (int, SizeReference)):
2754 pass
2755 else:
2756 assert_never(size_descr)
2758 return ret
2760 def get_tensor_sizes(
2761 self, ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], batch_size: int
2762 ) -> _TensorSizes:
2763 axis_sizes = self.get_axis_sizes(ns, batch_size=batch_size)
2764 return _TensorSizes(
2765 {
2766 t: {
2767 aa: axis_sizes.inputs[(tt, aa)]
2768 for tt, aa in axis_sizes.inputs
2769 if tt == t
2770 }
2771 for t in {tt for tt, _ in axis_sizes.inputs}
2772 },
2773 {
2774 t: {
2775 aa: axis_sizes.outputs[(tt, aa)]
2776 for tt, aa in axis_sizes.outputs
2777 if tt == t
2778 }
2779 for t in {tt for tt, _ in axis_sizes.outputs}
2780 },
2781 )
2783 def get_axis_sizes(
2784 self,
2785 ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N],
2786 batch_size: Optional[int] = None,
2787 *,
2788 max_input_shape: Optional[Mapping[Tuple[TensorId, AxisId], int]] = None,
2789 ) -> _AxisSizes:
2790 """Determine input and output block shape for scale factors **ns**
2791 of parameterized input sizes.
2793 Args:
2794 ns: Scale factor `n` for each axis (keyed by (tensor_id, axis_id))
2795 that is parameterized as `size = min + n * step`.
2796 batch_size: The desired size of the batch dimension.
2797 If given **batch_size** overwrites any batch size present in
2798 **max_input_shape**. Default 1.
2799 max_input_shape: Limits the derived block shapes.
2800 Each axis for which the input size, parameterized by `n`, is larger
2801 than **max_input_shape** is set to the minimal value `n_min` for which
2802 this is still true.
2803 Use this for small input samples or large values of **ns**.
2804 Or simply whenever you know the full input shape.
2806 Returns:
2807 Resolved axis sizes for model inputs and outputs.
2808 """
2809 max_input_shape = max_input_shape or {}
2810 if batch_size is None:
2811 for (_t_id, a_id), s in max_input_shape.items():
2812 if a_id == BATCH_AXIS_ID:
2813 batch_size = s
2814 break
2815 else:
2816 batch_size = 1
2818 all_axes = {
2819 t.id: {a.id: a for a in t.axes} for t in chain(self.inputs, self.outputs)
2820 }
2822 inputs: Dict[Tuple[TensorId, AxisId], int] = {}
2823 outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]] = {}
2825 def get_axis_size(a: Union[InputAxis, OutputAxis]):
2826 if isinstance(a, BatchAxis):
2827 if (t_descr.id, a.id) in ns:
2828 logger.warning(
2829 "Ignoring unexpected size increment factor (n) for batch axis"
2830 + " of tensor '{}'.",
2831 t_descr.id,
2832 )
2833 return batch_size
2834 elif isinstance(a.size, int):
2835 if (t_descr.id, a.id) in ns:
2836 logger.warning(
2837 "Ignoring unexpected size increment factor (n) for fixed size"
2838 + " axis '{}' of tensor '{}'.",
2839 a.id,
2840 t_descr.id,
2841 )
2842 return a.size
2843 elif isinstance(a.size, ParameterizedSize):
2844 if (t_descr.id, a.id) not in ns:
2845 raise ValueError(
2846 "Size increment factor (n) missing for parametrized axis"
2847 + f" '{a.id}' of tensor '{t_descr.id}'."
2848 )
2849 n = ns[(t_descr.id, a.id)]
2850 s_max = max_input_shape.get((t_descr.id, a.id))
2851 if s_max is not None:
2852 n = min(n, a.size.get_n(s_max))
2854 return a.size.get_size(n)
2856 elif isinstance(a.size, SizeReference):
2857 if (t_descr.id, a.id) in ns:
2858 logger.warning(
2859 "Ignoring unexpected size increment factor (n) for axis '{}'"
2860 + " of tensor '{}' with size reference.",
2861 a.id,
2862 t_descr.id,
2863 )
2864 assert not isinstance(a, BatchAxis)
2865 ref_axis = all_axes[a.size.tensor_id][a.size.axis_id]
2866 assert not isinstance(ref_axis, BatchAxis)
2867 ref_key = (a.size.tensor_id, a.size.axis_id)
2868 ref_size = inputs.get(ref_key, outputs.get(ref_key))
2869 assert ref_size is not None, ref_key
2870 assert not isinstance(ref_size, _DataDepSize), ref_key
2871 return a.size.get_size(
2872 axis=a,
2873 ref_axis=ref_axis,
2874 ref_size=ref_size,
2875 )
2876 elif isinstance(a.size, DataDependentSize):
2877 if (t_descr.id, a.id) in ns:
2878 logger.warning(
2879 "Ignoring unexpected increment factor (n) for data dependent"
2880 + " size axis '{}' of tensor '{}'.",
2881 a.id,
2882 t_descr.id,
2883 )
2884 return _DataDepSize(a.size.min, a.size.max)
2885 else:
2886 assert_never(a.size)
2888 # first resolve all , but the `SizeReference` input sizes
2889 for t_descr in self.inputs:
2890 for a in t_descr.axes:
2891 if not isinstance(a.size, SizeReference):
2892 s = get_axis_size(a)
2893 assert not isinstance(s, _DataDepSize)
2894 inputs[t_descr.id, a.id] = s
2896 # resolve all other input axis sizes
2897 for t_descr in self.inputs:
2898 for a in t_descr.axes:
2899 if isinstance(a.size, SizeReference):
2900 s = get_axis_size(a)
2901 assert not isinstance(s, _DataDepSize)
2902 inputs[t_descr.id, a.id] = s
2904 # resolve all output axis sizes
2905 for t_descr in self.outputs:
2906 for a in t_descr.axes:
2907 assert not isinstance(a.size, ParameterizedSize)
2908 s = get_axis_size(a)
2909 outputs[t_descr.id, a.id] = s
2911 return _AxisSizes(inputs=inputs, outputs=outputs)
2913 @model_validator(mode="before")
2914 @classmethod
2915 def _convert(cls, data: Dict[str, Any]) -> Dict[str, Any]:
2916 if (
2917 data.get("type") == "model"
2918 and isinstance(fv := data.get("format_version"), str)
2919 and fv.count(".") == 2
2920 ):
2921 fv_parts = fv.split(".")
2922 if any(not p.isdigit() for p in fv_parts):
2923 return data
2925 fv_tuple = tuple(map(int, fv_parts))
2927 assert cls.implemented_format_version_tuple[0:2] == (0, 5)
2928 if fv_tuple[:2] in ((0, 3), (0, 4)):
2929 m04 = _ModelDescr_v0_4.load(data)
2930 if not isinstance(m04, InvalidDescr):
2931 return _model_conv.convert_as_dict(m04)
2932 elif fv_tuple[:2] == (0, 5):
2933 # bump patch version
2934 data["format_version"] = cls.implemented_format_version
2936 return data
2939class _ModelConv(Converter[_ModelDescr_v0_4, ModelDescr]):
2940 def _convert(
2941 self, src: _ModelDescr_v0_4, tgt: "type[ModelDescr] | type[dict[str, Any]]"
2942 ) -> "ModelDescr | dict[str, Any]":
2943 name = "".join(
2944 c if c in string.ascii_letters + string.digits + "_- ()" else " "
2945 for c in src.name
2946 )
2948 def conv_authors(auths: Optional[Sequence[_Author_v0_4]]):
2949 conv = (
2950 _author_conv.convert if TYPE_CHECKING else _author_conv.convert_as_dict
2951 )
2952 return None if auths is None else [conv(a) for a in auths]
2954 if TYPE_CHECKING:
2955 arch_file_conv = _arch_file_conv.convert
2956 arch_lib_conv = _arch_lib_conv.convert
2957 else:
2958 arch_file_conv = _arch_file_conv.convert_as_dict
2959 arch_lib_conv = _arch_lib_conv.convert_as_dict
2961 input_size_refs = {
2962 ipt.name: {
2963 a: s
2964 for a, s in zip(
2965 ipt.axes,
2966 (
2967 ipt.shape.min
2968 if isinstance(ipt.shape, _ParameterizedInputShape_v0_4)
2969 else ipt.shape
2970 ),
2971 )
2972 }
2973 for ipt in src.inputs
2974 if ipt.shape
2975 }
2976 output_size_refs = {
2977 **{
2978 out.name: {a: s for a, s in zip(out.axes, out.shape)}
2979 for out in src.outputs
2980 if not isinstance(out.shape, _ImplicitOutputShape_v0_4)
2981 },
2982 **input_size_refs,
2983 }
2985 return tgt(
2986 attachments=(
2987 []
2988 if src.attachments is None
2989 else [FileDescr(source=f) for f in src.attachments.files]
2990 ),
2991 authors=[
2992 _author_conv.convert_as_dict(a) for a in src.authors
2993 ], # pyright: ignore[reportArgumentType]
2994 cite=[
2995 {"text": c.text, "doi": c.doi, "url": c.url} for c in src.cite
2996 ], # pyright: ignore[reportArgumentType]
2997 config=src.config,
2998 covers=src.covers,
2999 description=src.description,
3000 documentation=src.documentation,
3001 format_version="0.5.3",
3002 git_repo=src.git_repo, # pyright: ignore[reportArgumentType]
3003 icon=src.icon,
3004 id=None if src.id is None else ModelId(src.id),
3005 id_emoji=src.id_emoji,
3006 license=src.license, # type: ignore
3007 links=src.links,
3008 maintainers=[
3009 _maintainer_conv.convert_as_dict(m) for m in src.maintainers
3010 ], # pyright: ignore[reportArgumentType]
3011 name=name,
3012 tags=src.tags,
3013 type=src.type,
3014 uploader=src.uploader,
3015 version=src.version,
3016 inputs=[ # pyright: ignore[reportArgumentType]
3017 _input_tensor_conv.convert_as_dict(ipt, tt, st, input_size_refs)
3018 for ipt, tt, st, in zip(
3019 src.inputs,
3020 src.test_inputs,
3021 src.sample_inputs or [None] * len(src.test_inputs),
3022 )
3023 ],
3024 outputs=[ # pyright: ignore[reportArgumentType]
3025 _output_tensor_conv.convert_as_dict(out, tt, st, output_size_refs)
3026 for out, tt, st, in zip(
3027 src.outputs,
3028 src.test_outputs,
3029 src.sample_outputs or [None] * len(src.test_outputs),
3030 )
3031 ],
3032 parent=(
3033 None
3034 if src.parent is None
3035 else LinkedModel(
3036 id=ModelId(
3037 str(src.parent.id)
3038 + (
3039 ""
3040 if src.parent.version_number is None
3041 else f"/{src.parent.version_number}"
3042 )
3043 )
3044 )
3045 ),
3046 training_data=(
3047 None
3048 if src.training_data is None
3049 else (
3050 LinkedDataset(
3051 id=DatasetId(
3052 str(src.training_data.id)
3053 + (
3054 ""
3055 if src.training_data.version_number is None
3056 else f"/{src.training_data.version_number}"
3057 )
3058 )
3059 )
3060 if isinstance(src.training_data, LinkedDataset02)
3061 else src.training_data
3062 )
3063 ),
3064 packaged_by=[
3065 _author_conv.convert_as_dict(a) for a in src.packaged_by
3066 ], # pyright: ignore[reportArgumentType]
3067 run_mode=src.run_mode,
3068 timestamp=src.timestamp,
3069 weights=(WeightsDescr if TYPE_CHECKING else dict)(
3070 keras_hdf5=(w := src.weights.keras_hdf5)
3071 and (KerasHdf5WeightsDescr if TYPE_CHECKING else dict)(
3072 authors=conv_authors(w.authors),
3073 source=w.source,
3074 tensorflow_version=w.tensorflow_version or Version("1.15"),
3075 parent=w.parent,
3076 ),
3077 onnx=(w := src.weights.onnx)
3078 and (OnnxWeightsDescr if TYPE_CHECKING else dict)(
3079 source=w.source,
3080 authors=conv_authors(w.authors),
3081 parent=w.parent,
3082 opset_version=w.opset_version or 15,
3083 ),
3084 pytorch_state_dict=(w := src.weights.pytorch_state_dict)
3085 and (PytorchStateDictWeightsDescr if TYPE_CHECKING else dict)(
3086 source=w.source,
3087 authors=conv_authors(w.authors),
3088 parent=w.parent,
3089 architecture=(
3090 arch_file_conv(
3091 w.architecture,
3092 w.architecture_sha256,
3093 w.kwargs,
3094 )
3095 if isinstance(w.architecture, _CallableFromFile_v0_4)
3096 else arch_lib_conv(w.architecture, w.kwargs)
3097 ),
3098 pytorch_version=w.pytorch_version or Version("1.10"),
3099 dependencies=(
3100 None
3101 if w.dependencies is None
3102 else (EnvironmentFileDescr if TYPE_CHECKING else dict)(
3103 source=cast(
3104 ImportantFileSource,
3105 str(deps := w.dependencies)[
3106 (
3107 len("conda:")
3108 if str(deps).startswith("conda:")
3109 else 0
3110 ) :
3111 ],
3112 )
3113 )
3114 ),
3115 ),
3116 tensorflow_js=(w := src.weights.tensorflow_js)
3117 and (TensorflowJsWeightsDescr if TYPE_CHECKING else dict)(
3118 source=w.source,
3119 authors=conv_authors(w.authors),
3120 parent=w.parent,
3121 tensorflow_version=w.tensorflow_version or Version("1.15"),
3122 ),
3123 tensorflow_saved_model_bundle=(
3124 w := src.weights.tensorflow_saved_model_bundle
3125 )
3126 and (TensorflowSavedModelBundleWeightsDescr if TYPE_CHECKING else dict)(
3127 authors=conv_authors(w.authors),
3128 parent=w.parent,
3129 source=w.source,
3130 tensorflow_version=w.tensorflow_version or Version("1.15"),
3131 dependencies=(
3132 None
3133 if w.dependencies is None
3134 else (EnvironmentFileDescr if TYPE_CHECKING else dict)(
3135 source=cast(
3136 ImportantFileSource,
3137 (
3138 str(w.dependencies)[len("conda:") :]
3139 if str(w.dependencies).startswith("conda:")
3140 else str(w.dependencies)
3141 ),
3142 )
3143 )
3144 ),
3145 ),
3146 torchscript=(w := src.weights.torchscript)
3147 and (TorchscriptWeightsDescr if TYPE_CHECKING else dict)(
3148 source=w.source,
3149 authors=conv_authors(w.authors),
3150 parent=w.parent,
3151 pytorch_version=w.pytorch_version or Version("1.10"),
3152 ),
3153 ),
3154 )
3157_model_conv = _ModelConv(_ModelDescr_v0_4, ModelDescr)
3160# create better cover images for 3d data and non-image outputs
3161def generate_covers(
3162 inputs: Sequence[Tuple[InputTensorDescr, NDArray[Any]]],
3163 outputs: Sequence[Tuple[OutputTensorDescr, NDArray[Any]]],
3164) -> List[Path]:
3165 def squeeze(
3166 data: NDArray[Any], axes: Sequence[AnyAxis]
3167 ) -> Tuple[NDArray[Any], List[AnyAxis]]:
3168 """apply numpy.ndarray.squeeze while keeping track of the axis descriptions remaining"""
3169 if data.ndim != len(axes):
3170 raise ValueError(
3171 f"tensor shape {data.shape} does not match described axes"
3172 + f" {[a.id for a in axes]}"
3173 )
3175 axes = [deepcopy(a) for a, s in zip(axes, data.shape) if s != 1]
3176 return data.squeeze(), axes
3178 def normalize(
3179 data: NDArray[Any], axis: Optional[Tuple[int, ...]], eps: float = 1e-7
3180 ) -> NDArray[np.float32]:
3181 data = data.astype("float32")
3182 data -= data.min(axis=axis, keepdims=True)
3183 data /= data.max(axis=axis, keepdims=True) + eps
3184 return data
3186 def to_2d_image(data: NDArray[Any], axes: Sequence[AnyAxis]):
3187 original_shape = data.shape
3188 data, axes = squeeze(data, axes)
3190 # take slice fom any batch or index axis if needed
3191 # and convert the first channel axis and take a slice from any additional channel axes
3192 slices: Tuple[slice, ...] = ()
3193 ndim = data.ndim
3194 ndim_need = 3 if any(isinstance(a, ChannelAxis) for a in axes) else 2
3195 has_c_axis = False
3196 for i, a in enumerate(axes):
3197 s = data.shape[i]
3198 assert s > 1
3199 if (
3200 isinstance(a, (BatchAxis, IndexInputAxis, IndexOutputAxis))
3201 and ndim > ndim_need
3202 ):
3203 data = data[slices + (slice(s // 2 - 1, s // 2),)]
3204 ndim -= 1
3205 elif isinstance(a, ChannelAxis):
3206 if has_c_axis:
3207 # second channel axis
3208 data = data[slices + (slice(0, 1),)]
3209 ndim -= 1
3210 else:
3211 has_c_axis = True
3212 if s == 2:
3213 # visualize two channels with cyan and magenta
3214 data = np.concatenate(
3215 [
3216 data[slices + (slice(1, 2),)],
3217 data[slices + (slice(0, 1),)],
3218 (
3219 data[slices + (slice(0, 1),)]
3220 + data[slices + (slice(1, 2),)]
3221 )
3222 / 2, # TODO: take maximum instead?
3223 ],
3224 axis=i,
3225 )
3226 elif data.shape[i] == 3:
3227 pass # visualize 3 channels as RGB
3228 else:
3229 # visualize first 3 channels as RGB
3230 data = data[slices + (slice(3),)]
3232 assert data.shape[i] == 3
3234 slices += (slice(None),)
3236 data, axes = squeeze(data, axes)
3237 assert len(axes) == ndim
3238 # take slice from z axis if needed
3239 slices = ()
3240 if ndim > ndim_need:
3241 for i, a in enumerate(axes):
3242 s = data.shape[i]
3243 if a.id == AxisId("z"):
3244 data = data[slices + (slice(s // 2 - 1, s // 2),)]
3245 data, axes = squeeze(data, axes)
3246 ndim -= 1
3247 break
3249 slices += (slice(None),)
3251 # take slice from any space or time axis
3252 slices = ()
3254 for i, a in enumerate(axes):
3255 if ndim <= ndim_need:
3256 break
3258 s = data.shape[i]
3259 assert s > 1
3260 if isinstance(
3261 a, (SpaceInputAxis, SpaceOutputAxis, TimeInputAxis, TimeOutputAxis)
3262 ):
3263 data = data[slices + (slice(s // 2 - 1, s // 2),)]
3264 ndim -= 1
3266 slices += (slice(None),)
3268 del slices
3269 data, axes = squeeze(data, axes)
3270 assert len(axes) == ndim
3272 if (has_c_axis and ndim != 3) or ndim != 2:
3273 raise ValueError(
3274 f"Failed to construct cover image from shape {original_shape}"
3275 )
3277 if not has_c_axis:
3278 assert ndim == 2
3279 data = np.repeat(data[:, :, None], 3, axis=2)
3280 axes.append(ChannelAxis(channel_names=list(map(Identifier, "RGB"))))
3281 ndim += 1
3283 assert ndim == 3
3285 # transpose axis order such that longest axis comes first...
3286 axis_order = list(np.argsort(list(data.shape)))
3287 axis_order.reverse()
3288 # ... and channel axis is last
3289 c = [i for i in range(3) if isinstance(axes[i], ChannelAxis)][0]
3290 axis_order.append(axis_order.pop(c))
3291 axes = [axes[ao] for ao in axis_order]
3292 data = data.transpose(axis_order)
3294 # h, w = data.shape[:2]
3295 # if h / w in (1.0 or 2.0):
3296 # pass
3297 # elif h / w < 2:
3298 # TODO: enforce 2:1 or 1:1 aspect ratio for generated cover images
3300 norm_along = (
3301 tuple(i for i, a in enumerate(axes) if a.type in ("space", "time")) or None
3302 )
3303 # normalize the data and map to 8 bit
3304 data = normalize(data, norm_along)
3305 data = (data * 255).astype("uint8")
3307 return data
3309 def create_diagonal_split_image(im0: NDArray[Any], im1: NDArray[Any]):
3310 assert im0.dtype == im1.dtype == np.uint8
3311 assert im0.shape == im1.shape
3312 assert im0.ndim == 3
3313 N, M, C = im0.shape
3314 assert C == 3
3315 out = np.ones((N, M, C), dtype="uint8")
3316 for c in range(C):
3317 outc = np.tril(im0[..., c])
3318 mask = outc == 0
3319 outc[mask] = np.triu(im1[..., c])[mask]
3320 out[..., c] = outc
3322 return out
3324 ipt_descr, ipt = inputs[0]
3325 out_descr, out = outputs[0]
3327 ipt_img = to_2d_image(ipt, ipt_descr.axes)
3328 out_img = to_2d_image(out, out_descr.axes)
3330 cover_folder = Path(mkdtemp())
3331 if ipt_img.shape == out_img.shape:
3332 covers = [cover_folder / "cover.png"]
3333 imwrite(covers[0], create_diagonal_split_image(ipt_img, out_img))
3334 else:
3335 covers = [cover_folder / "input.png", cover_folder / "output.png"]
3336 imwrite(covers[0], ipt_img)
3337 imwrite(covers[1], out_img)
3339 return covers