Coverage for src / bioimageio / spec / model / v0_5.py: 71%
1657 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-13 11:29 +0000
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-13 11:29 +0000
1from __future__ import annotations
3import collections.abc
4import re
5import string
6import warnings
7from copy import deepcopy
8from functools import partial
9from itertools import chain
10from math import ceil
11from pathlib import Path, PurePosixPath
12from tempfile import mkdtemp
13from textwrap import dedent
14from typing import (
15 TYPE_CHECKING,
16 Any,
17 Callable,
18 ClassVar,
19 Dict,
20 Generic,
21 List,
22 Literal,
23 Mapping,
24 NamedTuple,
25 Optional,
26 Sequence,
27 Set,
28 Tuple,
29 Type,
30 TypeVar,
31 Union,
32 cast,
33 overload,
34)
36import numpy as np
37from annotated_types import Ge, Gt, Interval, MaxLen, MinLen, Predicate
38from imageio.v3 import imread, imwrite # pyright: ignore[reportUnknownVariableType]
39from loguru import logger
40from numpy.typing import NDArray
41from pydantic import (
42 AfterValidator,
43 Discriminator,
44 Field,
45 RootModel,
46 SerializationInfo,
47 SerializerFunctionWrapHandler,
48 StrictInt,
49 Tag,
50 ValidationInfo,
51 WrapSerializer,
52 field_validator,
53 model_serializer,
54 model_validator,
55)
56from typing_extensions import Annotated, Self, assert_never, get_args
58from .._internal.common_nodes import (
59 InvalidDescr,
60 KwargsNode,
61 Node,
62 NodeWithExplicitlySetFields,
63)
64from .._internal.constants import DTYPE_LIMITS
65from .._internal.field_warning import issue_warning, warn
66from .._internal.io import BioimageioYamlContent as BioimageioYamlContent
67from .._internal.io import FileDescr as FileDescr
68from .._internal.io import (
69 FileSource,
70 WithSuffix,
71 YamlValue,
72 extract_file_name,
73 get_reader,
74 wo_special_file_name,
75)
76from .._internal.io_basics import Sha256 as Sha256
77from .._internal.io_packaging import (
78 FileDescr_,
79 FileSource_,
80 package_file_descr_serializer,
81)
82from .._internal.io_utils import load_array
83from .._internal.node_converter import Converter
84from .._internal.type_guards import is_dict, is_sequence
85from .._internal.types import (
86 FAIR,
87 AbsoluteTolerance,
88 LowerCaseIdentifier,
89 LowerCaseIdentifierAnno,
90 MismatchedElementsPerMillion,
91 RelativeTolerance,
92)
93from .._internal.types import Datetime as Datetime
94from .._internal.types import Identifier as Identifier
95from .._internal.types import NotEmpty as NotEmpty
96from .._internal.types import SiUnit as SiUnit
97from .._internal.url import HttpUrl as HttpUrl
98from .._internal.utils import try_all_raise_last
99from .._internal.validation_context import get_validation_context
100from .._internal.validator_annotations import RestrictCharacters
101from .._internal.version_type import Version as Version
102from .._internal.warning_levels import INFO
103from ..dataset.v0_2 import DatasetDescr as DatasetDescr02
104from ..dataset.v0_2 import LinkedDataset as LinkedDataset02
105from ..dataset.v0_3 import DatasetDescr as DatasetDescr
106from ..dataset.v0_3 import DatasetId as DatasetId
107from ..dataset.v0_3 import LinkedDataset as LinkedDataset
108from ..dataset.v0_3 import Uploader as Uploader
109from ..generic.v0_3 import (
110 VALID_COVER_IMAGE_EXTENSIONS as VALID_COVER_IMAGE_EXTENSIONS,
111)
112from ..generic.v0_3 import Author as Author
113from ..generic.v0_3 import BadgeDescr as BadgeDescr
114from ..generic.v0_3 import CiteEntry as CiteEntry
115from ..generic.v0_3 import DeprecatedLicenseId as DeprecatedLicenseId
116from ..generic.v0_3 import Doi as Doi
117from ..generic.v0_3 import (
118 FileSource_documentation,
119 GenericModelDescrBase,
120 LinkedResourceBase,
121 _author_conv, # pyright: ignore[reportPrivateUsage]
122 _maintainer_conv, # pyright: ignore[reportPrivateUsage]
123)
124from ..generic.v0_3 import LicenseId as LicenseId
125from ..generic.v0_3 import LinkedResource as LinkedResource
126from ..generic.v0_3 import Maintainer as Maintainer
127from ..generic.v0_3 import OrcidId as OrcidId
128from ..generic.v0_3 import RelativeFilePath as RelativeFilePath
129from ..generic.v0_3 import ResourceId as ResourceId
130from .v0_4 import Author as _Author_v0_4
131from .v0_4 import BinarizeDescr as _BinarizeDescr_v0_4
132from .v0_4 import CallableFromDepencency as CallableFromDepencency
133from .v0_4 import CallableFromDepencency as _CallableFromDepencency_v0_4
134from .v0_4 import CallableFromFile as _CallableFromFile_v0_4
135from .v0_4 import ClipDescr as _ClipDescr_v0_4
136from .v0_4 import ImplicitOutputShape as _ImplicitOutputShape_v0_4
137from .v0_4 import InputTensorDescr as _InputTensorDescr_v0_4
138from .v0_4 import KnownRunMode as KnownRunMode
139from .v0_4 import ModelDescr as _ModelDescr_v0_4
140from .v0_4 import OutputTensorDescr as _OutputTensorDescr_v0_4
141from .v0_4 import ParameterizedInputShape as _ParameterizedInputShape_v0_4
142from .v0_4 import PostprocessingDescr as _PostprocessingDescr_v0_4
143from .v0_4 import PreprocessingDescr as _PreprocessingDescr_v0_4
144from .v0_4 import RunMode as RunMode
145from .v0_4 import ScaleLinearDescr as _ScaleLinearDescr_v0_4
146from .v0_4 import ScaleMeanVarianceDescr as _ScaleMeanVarianceDescr_v0_4
147from .v0_4 import ScaleRangeDescr as _ScaleRangeDescr_v0_4
148from .v0_4 import SigmoidDescr as _SigmoidDescr_v0_4
149from .v0_4 import TensorName as _TensorName_v0_4
150from .v0_4 import ZeroMeanUnitVarianceDescr as _ZeroMeanUnitVarianceDescr_v0_4
151from .v0_4 import package_weights
# Unit names accepted for the `unit` field of space axes (see `SpaceAxisBase`).
SpaceUnit = Literal[
    "attometer",
    "angstrom",
    "centimeter",
    "decimeter",
    "exameter",
    "femtometer",
    "foot",
    "gigameter",
    "hectometer",
    "inch",
    "kilometer",
    "megameter",
    "meter",
    "micrometer",
    "mile",
    "millimeter",
    "nanometer",
    "parsec",
    "petameter",
    "picometer",
    "terameter",
    "yard",
    "yoctometer",
    "yottameter",
    "zeptometer",
    "zettameter",
]
"""Space unit compatible to the [OME-Zarr axes specification 0.5](https://ngff.openmicroscopy.org/0.5/#axes-md)"""

# Unit names accepted for the `unit` field of time axes (see `TimeAxisBase`).
TimeUnit = Literal[
    "attosecond",
    "centisecond",
    "day",
    "decisecond",
    "exasecond",
    "femtosecond",
    "gigasecond",
    "hectosecond",
    "hour",
    "kilosecond",
    "megasecond",
    "microsecond",
    "millisecond",
    "minute",
    "nanosecond",
    "petasecond",
    "picosecond",
    "second",
    "terasecond",
    "yoctosecond",
    "yottasecond",
    "zeptosecond",
    "zettasecond",
]
"""Time unit compatible to the [OME-Zarr axes specification 0.5](https://ngff.openmicroscopy.org/0.5/#axes-md)"""

# The axis kinds distinguished by the axis classes of this module.
AxisType = Literal["batch", "channel", "index", "time", "space"]

# Maps single-letter axis names to their axis type
# ("x", "y" and "z" all map to "space").
_AXIS_TYPE_MAP: Mapping[str, AxisType] = {
    "b": "batch",
    "t": "time",
    "i": "index",
    "c": "channel",
    "x": "space",
    "y": "space",
    "z": "space",
}

# Maps single-letter axis ids to their long form; read by `_normalize_axis_id`.
# Unlike `_AXIS_TYPE_MAP` there are no entries for "x", "y" and "z",
# so those ids are left unchanged.
_AXIS_ID_MAP = {
    "b": "batch",
    "t": "time",
    "i": "index",
    "c": "channel",
}

# Names of the weights formats known to this module.
WeightsFormat = Literal[
    "keras_hdf5",
    "keras_v3",
    "onnx",
    "pytorch_state_dict",
    "tensorflow_js",
    "tensorflow_saved_model_bundle",
    "torchscript",
]
class TensorId(LowerCaseIdentifier):
    # Identifier of a tensor: a `LowerCaseIdentifierAnno` value whose length is
    # limited to 32 characters by the `MaxLen(32)` constraint below.
    root_model: ClassVar[Type[RootModel[Any]]] = RootModel[
        Annotated[LowerCaseIdentifierAnno, MaxLen(32)]
    ]
def _normalize_axis_id(a: str):
    """Return the long form of axis id **a**.

    Ids listed in `_AXIS_ID_MAP` are replaced by their mapped value and the
    replacement is reported as a warning; every other id is returned unchanged
    (converted to `str`).
    """
    axis_id = str(a)
    long_form = _AXIS_ID_MAP[axis_id] if axis_id in _AXIS_ID_MAP else axis_id
    if long_form == axis_id:
        # nothing was replaced, so there is nothing to report
        return long_form

    # `depth=3` attributes the log record to a frame further up the call stack
    logger.opt(depth=3).warning(
        "Normalized axis id from '{}' to '{}'.", axis_id, long_form
    )
    return long_form
class AxisId(LowerCaseIdentifier):
    # Identifier of an axis: a `LowerCaseIdentifierAnno` value of at most 16
    # characters. After validation the value is passed through
    # `_normalize_axis_id`, which maps the ids in `_AXIS_ID_MAP` to their long form.
    root_model: ClassVar[Type[RootModel[Any]]] = RootModel[
        Annotated[
            LowerCaseIdentifierAnno,
            MaxLen(16),
            AfterValidator(_normalize_axis_id),
        ]
    ]
266def _is_batch(a: str) -> bool:
267 return str(a) == "batch"
def _is_not_batch(a: str) -> bool:
    """Return `True` unless **a** is the batch axis id (negation of `_is_batch`)."""
    is_batch_axis = _is_batch(a)
    return not is_batch_axis
# An `AxisId` that must satisfy `_is_not_batch`, i.e. any axis id except "batch".
NonBatchAxisId = Annotated[AxisId, Predicate(_is_not_batch)]

# Ids of the operations accepted as preprocessing steps.
PreprocessingId = Literal[
    "binarize",
    "clip",
    "ensure_dtype",
    "fixed_zero_mean_unit_variance",
    "scale_linear",
    "scale_range",
    "sigmoid",
    "softmax",
]
# Ids of the operations accepted as postprocessing steps
# (a superset of `PreprocessingId`).
PostprocessingId = Literal[
    "binarize",
    "clip",
    "custom",
    "ensure_dtype",
    "fixed_zero_mean_unit_variance",
    "scale_linear",
    "scale_mean_variance",
    "scale_range",
    "sigmoid",
    "softmax",
    "zero_mean_unit_variance",
]


# Placeholder string; its consumers are outside this part of the file.
SAME_AS_TYPE = "<same as type>"


# Alias of `int` naming the block-size parameter `n` of a `ParameterizedSize`.
ParameterizedSize_N = int
"""
Annotates an integer to calculate a concrete axis size from a `ParameterizedSize`.
"""
class ParameterizedSize(Node):
    """Describes a range of valid tensor axis sizes as `size = min + n*step`.

    - **min** and **step** are given by the model description.
    - All blocksize parameters n = 0,1,2,... yield a valid `size`.
    - A greater blocksize parameter n = 0,1,2,... results in a greater **size**.
      This allows to adjust the axis size more generically.
    """

    N: ClassVar[Type[int]] = ParameterizedSize_N
    """Positive integer to parameterize this axis"""

    min: Annotated[int, Gt(0)]
    step: Annotated[int, Gt(0)]

    def validate_size(self, size: int, msg_prefix: str = "") -> int:
        """Return **size** unchanged if it can be written as `min + n*step`.

        Args:
            size: The axis size to check.
            msg_prefix: Text prepended to the error message.

        Raises:
            ValueError: If **size** is smaller than `min` or does not lie on the
                `min + n*step` grid.
        """
        if size < self.min:
            raise ValueError(
                f"{msg_prefix}size {size} < {self.min} (minimum axis size)"
            )
        if (size - self.min) % self.step != 0:
            raise ValueError(
                f"{msg_prefix}size {size} is not parameterized by `min + n*step` ="
                + f" `{self.min} + n*{self.step}`"
            )

        return size

    def get_size(self, n: ParameterizedSize_N) -> int:
        """Return the concrete axis size `min + n*step` for block-size parameter **n**."""
        return self.min + self.step * n

    def get_n(self, s: int) -> ParameterizedSize_N:
        """return smallest n parameterizing a size greater or equal than `s`

        The result is never negative: for any `s <= min` the answer is `0`,
        because `n = 0` already yields `min`, the smallest valid size.
        """
        # Without the clamp a request below `min` would produce a negative `n`,
        # and `get_size` of that `n` would be smaller than `min` (an invalid size).
        return max(0, ceil((s - self.min) / self.step))
class DataDependentSize(Node):
    # Inclusive bounds of an axis size; `max=None` means there is no upper bound.
    min: Annotated[int, Gt(0)] = 1
    max: Annotated[Optional[int], Gt(1)] = None

    @model_validator(mode="after")
    def _validate_max_gt_min(self):
        """Reject descriptions whose `max` is not strictly greater than `min`."""
        upper = self.max
        if upper is None or self.min < upper:
            return self

        raise ValueError(f"expected `min` < `max`, but got {self.min}, {self.max}")

    def validate_size(self, size: int, msg_prefix: str = "") -> int:
        """Return **size** if it lies within `[min, max]`, else raise `ValueError`.

        **msg_prefix** is prepended to the error message.
        """
        below_min = size < self.min
        if below_min:
            raise ValueError(f"{msg_prefix}size {size} < {self.min}")

        above_max = self.max is not None and size > self.max
        if above_max:
            raise ValueError(f"{msg_prefix}size {size} > {self.max}")

        return size
class SizeReference(Node):
    """A tensor axis size (extent in pixels/frames) defined in relation to a reference axis.

    `axis.size = reference.size * reference.scale / axis.scale + offset`

    Note:
    1. The axis and the referenced axis need to have the same unit (or no unit).
    2. Batch axes may not be referenced.
    3. Fractions are rounded down.
    4. If the reference axis is `concatenable` the referencing axis is assumed to be
        `concatenable` as well with the same block order.

    Example:
    An unisotropic input image of w*h=100*49 pixels depicts a physical space of 200*196mm².
    Let's assume that we want to express the image height h in relation to its width w
    instead of only accepting input images of exactly 100*49 pixels
    (for example to express a range of valid image shapes by parametrizing w, see `ParameterizedSize`).

    >>> w = SpaceInputAxis(id=AxisId("w"), size=100, unit="millimeter", scale=2)
    >>> h = SpaceInputAxis(
    ...     id=AxisId("h"),
    ...     size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("w"), offset=-1),
    ...     unit="millimeter",
    ...     scale=4,
    ... )
    >>> print(h.size.get_size(h, w))
    49

    ⇒ h = w * w.scale / h.scale + offset = 100 * 2mm / 4mm - 1 = 49
    """

    tensor_id: TensorId
    """tensor id of the reference axis"""

    axis_id: AxisId
    """axis id of the reference axis"""

    # Constant added after scaling the reference size (see `get_size`).
    offset: StrictInt = 0

    def get_size(
        self,
        axis: Union[
            ChannelAxis,
            IndexInputAxis,
            IndexOutputAxis,
            TimeInputAxis,
            SpaceInputAxis,
            TimeOutputAxis,
            TimeOutputAxisWithHalo,
            SpaceOutputAxis,
            SpaceOutputAxisWithHalo,
        ],
        ref_axis: Union[
            ChannelAxis,
            IndexInputAxis,
            IndexOutputAxis,
            TimeInputAxis,
            SpaceInputAxis,
            TimeOutputAxis,
            TimeOutputAxisWithHalo,
            SpaceOutputAxis,
            SpaceOutputAxisWithHalo,
        ],
        n: ParameterizedSize_N = 0,
        ref_size: Optional[int] = None,
    ):
        """Compute the concrete size for a given axis and its reference axis.

        Args:
            axis: The axis this [SizeReference][] is the size of.
            ref_axis: The reference axis to compute the size from.
            n: If the **ref_axis** is parameterized (of type `ParameterizedSize`)
                and no fixed **ref_size** is given,
                **n** is used to compute the size of the parameterized **ref_axis**.
            ref_size: Overwrite the reference size instead of deriving it from
                **ref_axis**
                (**ref_axis.scale** is still used; any given **n** is ignored).

        Returns:
            `int(ref_size * ref_axis.scale / axis.scale + offset)`.

        Raises:
            AssertionError: If **axis** is not sized by this object, **ref_axis**
                has a different id than `axis_id`, or the units of the two axes differ.
            ValueError: If **ref_size** is not given and **ref_axis** is sized by a
                `DataDependentSize` or by another `SizeReference`.
        """
        assert axis.size == self, (
            "Given `axis.size` is not defined by this `SizeReference`"
        )

        assert ref_axis.id == self.axis_id, (
            f"Expected `ref_axis.id` to be {self.axis_id}, but got {ref_axis.id}."
        )

        assert axis.unit == ref_axis.unit, (
            "`SizeReference` requires `axis` and `ref_axis` to have the same `unit`,"
            f" but {axis.unit}!={ref_axis.unit}"
        )
        # Derive the reference size from `ref_axis` only if the caller gave none.
        if ref_size is None:
            if isinstance(ref_axis.size, (int, float)):
                ref_size = ref_axis.size
            elif isinstance(ref_axis.size, ParameterizedSize):
                ref_size = ref_axis.size.get_size(n)
            elif isinstance(ref_axis.size, DataDependentSize):
                raise ValueError(
                    "Reference axis referenced in `SizeReference` may not be a `DataDependentSize`."
                )
            elif isinstance(ref_axis.size, SizeReference):
                raise ValueError(
                    "Reference axis referenced in `SizeReference` may not be sized by a"
                    + " `SizeReference` itself."
                )
            else:
                assert_never(ref_axis.size)

        # NOTE(review): `int()` truncates toward zero, which only equals
        # "rounded down" for non-negative results.
        return int(ref_size * ref_axis.scale / axis.scale + self.offset)

    @staticmethod
    def _get_unit(
        axis: Union[
            ChannelAxis,
            IndexInputAxis,
            IndexOutputAxis,
            TimeInputAxis,
            SpaceInputAxis,
            TimeOutputAxis,
            TimeOutputAxisWithHalo,
            SpaceOutputAxis,
            SpaceOutputAxisWithHalo,
        ],
    ):
        """Return the `unit` of **axis**."""
        return axis.unit
class AxisBase(NodeWithExplicitlySetFields):
    # Fields shared by all axis descriptions in this module: an id and a short description.
    id: AxisId
    """An axis id unique across all axes of one tensor."""

    description: Annotated[str, MaxLen(128)] = ""
    """A short description of this axis beyond its type and id."""
class WithHalo(Node):
    # Mixin for output axes that have a halo (at least 1) and whose size is
    # given by a `SizeReference`.
    halo: Annotated[int, Ge(1)]
    """The halo should be cropped from the output tensor to avoid boundary effects.
    It is to be cropped from both sides, i.e. `size_after_crop = size - 2 * halo`.
    To document a halo that is already cropped by the model use `size.offset` instead."""

    # NOTE(review): the first example (`10`) is a plain integer although the
    # annotation only allows `SizeReference` -- confirm whether it is intended.
    size: Annotated[
        SizeReference,
        Field(
            examples=[
                10,
                SizeReference(
                    tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
                ).model_dump(mode="json"),
            ]
        ),
    ]
    """reference to another axis with an optional offset (see [SizeReference][])"""
# The axis id "batch"; used as the default `id` of `BatchAxis`.
BATCH_AXIS_ID = AxisId("batch")
class BatchAxis(AxisBase):
    # Description of the batch axis of a tensor.
    implemented_type: ClassVar[Literal["batch"]] = "batch"
    # Static type checkers see a default so `type` may be omitted in constructor
    # calls; at runtime the field is declared without a default here
    # (presumably the base class fills it in from `implemented_type` -- TODO confirm).
    if TYPE_CHECKING:
        type: Literal["batch"] = "batch"
    else:
        type: Literal["batch"]

    # The id must satisfy `_is_batch`, i.e. it has to be "batch".
    id: Annotated[AxisId, Predicate(_is_batch)] = BATCH_AXIS_ID
    size: Optional[Literal[1]] = None
    """The batch size may be fixed to 1,
    otherwise (the default) it may be chosen arbitrarily depending on available memory"""

    @property
    def scale(self):
        """Constant scale of `1.0`."""
        return 1.0

    @property
    def concatenable(self):
        """Always `True`."""
        return True

    @property
    def unit(self):
        """Always `None` (this axis has no unit)."""
        return None
class ChannelAxis(AxisBase):
    # Description of a channel axis; its size is derived from `channel_names`.
    implemented_type: ClassVar[Literal["channel"]] = "channel"
    # Static type checkers see a default so `type` may be omitted in constructor
    # calls; at runtime the field is declared without a default here
    # (presumably the base class fills it in from `implemented_type` -- TODO confirm).
    if TYPE_CHECKING:
        type: Literal["channel"] = "channel"
    else:
        type: Literal["channel"]

    id: NonBatchAxisId = AxisId("channel")

    # One name per channel; the list must not be empty.
    channel_names: NotEmpty[List[Identifier]]

    @property
    def size(self) -> int:
        """Number of channels, i.e. `len(channel_names)`."""
        return len(self.channel_names)

    @property
    def concatenable(self):
        """Always `False`."""
        return False

    @property
    def scale(self) -> float:
        """Constant scale of `1.0`."""
        return 1.0

    @property
    def unit(self):
        """Always `None` (this axis has no unit)."""
        return None
class _WithInputAxisSize(Node):
    # Mixin that adds the `size` field used by the index, time and space input axes.
    size: Annotated[
        Union[Annotated[int, Gt(0)], ParameterizedSize, SizeReference],
        Field(
            examples=[
                10,
                ParameterizedSize(min=32, step=16).model_dump(mode="json"),
                SizeReference(
                    tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
                ).model_dump(mode="json"),
            ]
        ),
    ]
    """The size/length of this axis can be specified as
    - fixed integer
    - parameterized series of valid sizes ([ParameterizedSize][])
    - reference to another axis with an optional offset ([SizeReference][])
    """
class IndexAxisBase(AxisBase):
    # Common part of the index input and index output axis descriptions.
    implemented_type: ClassVar[Literal["index"]] = "index"
    # Static type checkers see a default so `type` may be omitted in constructor
    # calls; at runtime the field is declared without a default here
    # (presumably the base class fills it in from `implemented_type` -- TODO confirm).
    if TYPE_CHECKING:
        type: Literal["index"] = "index"
    else:
        type: Literal["index"]

    id: NonBatchAxisId = AxisId("index")

    @property
    def scale(self) -> float:
        """Constant scale of `1.0`."""
        return 1.0

    @property
    def unit(self):
        """Always `None` (this axis has no unit)."""
        return None
class IndexInputAxis(IndexAxisBase, _WithInputAxisSize):
    # Index axis of an input tensor: the `IndexAxisBase` fields plus the input
    # `size` field and the `concatenable` flag.
    concatenable: bool = False
    """If a model has a `concatenable` input axis, it can be processed blockwise,
    splitting a longer sample axis into blocks matching its input tensor description.
    Output axes are concatenable if they have a [SizeReference][] to a concatenable
    input axis.
    """
class IndexOutputAxis(IndexAxisBase):
    # Index axis of an output tensor; its size may additionally be a
    # `DataDependentSize`.
    size: Annotated[
        Union[Annotated[int, Gt(0)], SizeReference, DataDependentSize],
        Field(
            examples=[
                10,
                SizeReference(
                    tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
                ).model_dump(mode="json"),
            ]
        ),
    ]
    """The size/length of this axis can be specified as
    - fixed integer
    - reference to another axis with an optional offset ([SizeReference][])
    - data dependent size using [DataDependentSize][] (size is only known after model inference)
    """
class TimeAxisBase(AxisBase):
    # Common part of the time input and time output axis descriptions.
    implemented_type: ClassVar[Literal["time"]] = "time"
    # Static type checkers see a default so `type` may be omitted in constructor
    # calls; at runtime the field is declared without a default here
    # (presumably the base class fills it in from `implemented_type` -- TODO confirm).
    if TYPE_CHECKING:
        type: Literal["time"] = "time"
    else:
        type: Literal["time"]

    id: NonBatchAxisId = AxisId("time")
    # Optional time unit; `None` means no unit is given.
    unit: Optional[TimeUnit] = None
    # Strictly positive scale factor (used e.g. by `SizeReference.get_size`).
    scale: Annotated[float, Gt(0)] = 1.0
class TimeInputAxis(TimeAxisBase, _WithInputAxisSize):
    # Time axis of an input tensor: the `TimeAxisBase` fields plus the input
    # `size` field and the `concatenable` flag.
    concatenable: bool = False
    """If a model has a `concatenable` input axis, it can be processed blockwise,
    splitting a longer sample axis into blocks matching its input tensor description.
    Output axes are concatenable if they have a [SizeReference][] to a concatenable
    input axis.
    """
class SpaceAxisBase(AxisBase):
    # Common part of the space input and space output axis descriptions.
    implemented_type: ClassVar[Literal["space"]] = "space"
    # Static type checkers see a default so `type` may be omitted in constructor
    # calls; at runtime the field is declared without a default here
    # (presumably the base class fills it in from `implemented_type` -- TODO confirm).
    if TYPE_CHECKING:
        type: Literal["space"] = "space"
    else:
        type: Literal["space"]

    # Defaults to "x"; any id except "batch" is accepted.
    id: Annotated[NonBatchAxisId, Field(examples=["x", "y", "z"])] = AxisId("x")
    # Optional space unit; `None` means no unit is given.
    unit: Optional[SpaceUnit] = None
    # Strictly positive scale factor (used e.g. by `SizeReference.get_size`).
    scale: Annotated[float, Gt(0)] = 1.0
class SpaceInputAxis(SpaceAxisBase, _WithInputAxisSize):
    # Space axis of an input tensor: the `SpaceAxisBase` fields plus the input
    # `size` field and the `concatenable` flag.
    concatenable: bool = False
    """If a model has a `concatenable` input axis, it can be processed blockwise,
    splitting a longer sample axis into blocks matching its input tensor description.
    Output axes are concatenable if they have a [SizeReference][] to a concatenable
    input axis.
    """
# Tuple of all input axis classes.
INPUT_AXIS_TYPES = (
    BatchAxis,
    ChannelAxis,
    IndexInputAxis,
    TimeInputAxis,
    SpaceInputAxis,
)
"""intended for isinstance comparisons in py<3.10"""

# Union of the same classes for use in annotations.
_InputAxisUnion = Union[
    BatchAxis, ChannelAxis, IndexInputAxis, TimeInputAxis, SpaceInputAxis
]
# The union member is selected by the value of the `type` field.
InputAxis = Annotated[_InputAxisUnion, Discriminator("type")]
class _WithOutputAxisSize(Node):
    # Mixin that adds the `size` field used by the time and space output axes
    # without halo.
    size: Annotated[
        Union[Annotated[int, Gt(0)], SizeReference],
        Field(
            examples=[
                10,
                SizeReference(
                    tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
                ).model_dump(mode="json"),
            ]
        ),
    ]
    """The size/length of this axis can be specified as
    - fixed integer
    - reference to another axis with an optional offset (see [SizeReference][])
    """
class TimeOutputAxis(TimeAxisBase, _WithOutputAxisSize):
    # Time axis of an output tensor without a halo; adds nothing to its bases.
    pass
class TimeOutputAxisWithHalo(TimeAxisBase, WithHalo):
    # Time axis of an output tensor with a halo; adds nothing to its bases.
    pass
726def _get_halo_axis_discriminator_value(v: Any) -> Literal["with_halo", "wo_halo"]:
727 if isinstance(v, dict):
728 return "with_halo" if "halo" in v else "wo_halo"
729 else:
730 return "with_halo" if hasattr(v, "halo") else "wo_halo"
# Time output axis with or without halo; the member is selected by the tag
# that `_get_halo_axis_discriminator_value` returns for the input.
_TimeOutputAxisUnion = Annotated[
    Union[
        Annotated[TimeOutputAxis, Tag("wo_halo")],
        Annotated[TimeOutputAxisWithHalo, Tag("with_halo")],
    ],
    Discriminator(_get_halo_axis_discriminator_value),
]
class SpaceOutputAxis(SpaceAxisBase, _WithOutputAxisSize):
    # Space axis of an output tensor without a halo; adds nothing to its bases.
    pass
class SpaceOutputAxisWithHalo(SpaceAxisBase, WithHalo):
    # Space axis of an output tensor with a halo; adds nothing to its bases.
    pass
# Space output axis with or without halo; the member is selected by the tag
# that `_get_halo_axis_discriminator_value` returns for the input.
_SpaceOutputAxisUnion = Annotated[
    Union[
        Annotated[SpaceOutputAxis, Tag("wo_halo")],
        Annotated[SpaceOutputAxisWithHalo, Tag("with_halo")],
    ],
    Discriminator(_get_halo_axis_discriminator_value),
]


# Union of all output axis descriptions.
_OutputAxisUnion = Union[
    BatchAxis, ChannelAxis, IndexOutputAxis, _TimeOutputAxisUnion, _SpaceOutputAxisUnion
]
# The union member is selected by the value of the `type` field.
OutputAxis = Annotated[_OutputAxisUnion, Discriminator("type")]

# Tuple of all output axis classes.
OUTPUT_AXIS_TYPES = (
    BatchAxis,
    ChannelAxis,
    IndexOutputAxis,
    TimeOutputAxis,
    TimeOutputAxisWithHalo,
    SpaceOutputAxis,
    SpaceOutputAxisWithHalo,
)
"""intended for isinstance comparisons in py<3.10"""


AnyAxis = Union[InputAxis, OutputAxis]

# Plain tuple concatenation: `BatchAxis` and `ChannelAxis` appear twice.
ANY_AXIS_TYPES = INPUT_AXIS_TYPES + OUTPUT_AXIS_TYPES
"""intended for isinstance comparisons in py<3.10"""

# A non-empty list of values that all share one of the types int, float, bool or str.
TVs = Union[
    NotEmpty[List[int]],
    NotEmpty[List[float]],
    NotEmpty[List[bool]],
    NotEmpty[List[str]],
]


# Data types allowed for `NominalOrOrdinalDataDescr.type`.
NominalOrOrdinalDType = Literal[
    "float32",
    "float64",
    "uint8",
    "int8",
    "uint16",
    "int16",
    "uint32",
    "int32",
    "uint64",
    "int64",
    "bool",
]
class NominalOrOrdinalDataDescr(Node):
    # Describes tensor data by an explicit list of allowed values and a data type.
    values: TVs
    """A fixed set of nominal or an ascending sequence of ordinal values.
    In this case `data.type` is required to be an unsigend integer type, e.g. 'uint8'.
    String `values` are interpreted as labels for tensor values 0, ..., N.
    Note: as YAML 1.2 does not natively support a "set" datatype,
    nominal values should be given as a sequence (aka list/array) as well.
    """

    type: Annotated[
        NominalOrOrdinalDType,
        Field(
            examples=[
                "float32",
                "uint8",
                "uint16",
                "int64",
                "bool",
            ],
        ),
    ] = "uint8"

    @model_validator(mode="after")
    def _validate_values_match_type(
        self,
    ) -> Self:
        """Check that every entry of `values` is compatible with `type`.

        Raises:
            ValueError: Listing up to five incompatible values, followed by
                "..." once the fifth has been found.
        """
        incompatible: List[Any] = []
        for v in self.values:
            if self.type == "bool":
                # for type "bool" only genuine `bool` values are accepted
                if not isinstance(v, bool):
                    incompatible.append(v)
            elif self.type in DTYPE_LIMITS:
                # `and` binds tighter than `or`; a value is incompatible if it is
                # - a number outside the limits of the data type, or
                # - a string while the type is not an unsigned integer type, or
                # - a float while the type is an integer type.
                if (
                    isinstance(v, (int, float))
                    and (
                        v < DTYPE_LIMITS[self.type].min
                        or v > DTYPE_LIMITS[self.type].max
                    )
                    or (isinstance(v, str) and "uint" not in self.type)
                    or (isinstance(v, float) and "int" in self.type)
                ):
                    incompatible.append(v)
            else:
                # no limits known for this data type: nothing can be confirmed
                incompatible.append(v)

            # stop after five findings to keep the error message short
            if len(incompatible) == 5:
                incompatible.append("...")
                break

        if incompatible:
            raise ValueError(
                f"data type '{self.type}' incompatible with values {incompatible}"
            )

        return self

    # Optional unit of the data; `None` means no unit is given.
    unit: Optional[Union[Literal["arbitrary unit"], SiUnit]] = None

    @property
    def range(self):
        """`(minimum, maximum)` of the described tensor values.

        String `values` yield `(0, len(values) - 1)`; only the first entry is
        inspected to decide whether `values` holds strings.
        """
        if isinstance(self.values[0], str):
            return 0, len(self.values) - 1
        else:
            return min(self.values), max(self.values)
# Data types allowed for `IntervalOrRatioDataDescr.type`
# (the same names as `NominalOrOrdinalDType` without "bool").
IntervalOrRatioDType = Literal[
    "float32",
    "float64",
    "uint8",
    "int8",
    "uint16",
    "int16",
    "uint32",
    "int32",
    "uint64",
    "int64",
]
class IntervalOrRatioDataDescr(Node):
    # Describes tensor data by data type, value range, unit, scale and offset.
    type: Annotated[  # TODO: rename to dtype
        IntervalOrRatioDType,
        Field(
            examples=["float32", "float64", "uint8", "uint16"],
        ),
    ] = "float32"
    range: Tuple[Optional[float], Optional[float]] = (
        None,
        None,
    )
    """Tuple `(minimum, maximum)` specifying the allowed range of the data in this tensor.
    `None` corresponds to min/max of what can be expressed by **type**."""
    unit: Union[Literal["arbitrary unit"], SiUnit] = "arbitrary unit"
    scale: float = 1.0
    """Scale for data on an interval (or ratio) scale."""
    offset: Optional[float] = None
    """Offset for data on a ratio scale."""

    @model_validator(mode="before")
    def _replace_inf(cls, data: Any):
        """Replace infinite entries of a given `range` by `None`.

        Only applies if **data** is a dict whose "range" entry is a sequence;
        that entry is then stored back as a tuple and a warning is issued if
        anything was replaced. Any other input is passed through untouched.
        """
        if not is_dict(data):
            return data

        if "range" not in data or not is_sequence(data["range"]):
            return data

        # spellings (and float values) of infinity that are mapped to `None`
        forbidden = (
            "inf",
            "-inf",
            ".inf",
            "-.inf",
            float("inf"),
            float("-inf"),
        )
        given_range = data["range"]
        if any(v in forbidden for v in given_range):
            issue_warning("replaced 'inf' value", value=given_range)

        data["range"] = tuple(
            (None if v in forbidden else v) for v in given_range
        )
        return data
# Tensor data is described either by a list of allowed values or by a value range.
TensorDataDescr = Union[NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr]
class BinarizeKwargs(KwargsNode):
    """key word arguments for [BinarizeDescr][]"""

    # A single threshold; see `BinarizeAlongAxisKwargs` for per-axis thresholds.
    threshold: float
    """The fixed threshold"""
class BinarizeAlongAxisKwargs(KwargsNode):
    """key word arguments for [BinarizeDescr][]"""

    # A non-empty list of thresholds (one value per entry along `axis`).
    threshold: NotEmpty[List[float]]
    """The fixed threshold values along `axis`"""

    # Any axis except the batch axis.
    axis: Annotated[NonBatchAxisId, Field(examples=["channel"])]
    """The `threshold` axis"""
class BinarizeDescr(NodeWithExplicitlySetFields):
    """Binarize the tensor with a fixed threshold.

    Values above [BinarizeKwargs.threshold][]/[BinarizeAlongAxisKwargs.threshold][]
    will be set to one, values below the threshold to zero.

    Examples:
    - in YAML
        ```yaml
        postprocessing:
          - id: binarize
            kwargs:
              axis: 'channel'
              threshold: [0.25, 0.5, 0.75]
        ```
    - in Python:

        >>> postprocessing = [BinarizeDescr(
        ...   kwargs=BinarizeAlongAxisKwargs(
        ...       axis=AxisId('channel'),
        ...       threshold=[0.25, 0.5, 0.75],
        ...   )
        ... )]
    """

    implemented_id: ClassVar[Literal["binarize"]] = "binarize"
    # Static type checkers see a default so `id` may be omitted in constructor
    # calls; at runtime the field is declared without a default here
    # (presumably the base class fills it in from `implemented_id` -- TODO confirm).
    if TYPE_CHECKING:
        id: Literal["binarize"] = "binarize"
    else:
        id: Literal["binarize"]
    # Either a single threshold or thresholds along one axis.
    kwargs: Union[BinarizeKwargs, BinarizeAlongAxisKwargs]
class ClipKwargs(KwargsNode):
    """key word arguments for [ClipDescr][]"""

    min: Optional[float] = None
    """Minimum value for clipping.

    Exclusive with [min_percentile][]
    """
    min_percentile: Optional[Annotated[float, Interval(ge=0, lt=100)]] = None
    """Minimum percentile for clipping.

    Exclusive with [min][].

    In range [0, 100).
    """

    max: Optional[float] = None
    """Maximum value for clipping.

    Exclusive with `max_percentile`.
    """
    max_percentile: Optional[Annotated[float, Interval(gt=1, le=100)]] = None
    """Maximum percentile for clipping.

    Exclusive with `max`.

    In range (1, 100].
    """

    axes: Annotated[
        Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
    ] = None
    """The subset of axes to determine percentiles jointly,

    i.e. axes to reduce to compute min/max from `min_percentile`/`max_percentile`.
    For example to clip 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
    resulting in a tensor of equal shape with clipped values per channel, specify `axes=('batch', 'x', 'y')`.
    To clip samples independently, leave out the 'batch' axis.

    Only valid if `min_percentile` and/or `max_percentile` are set.

    Default: Compute percentiles over all axes jointly."""

    @model_validator(mode="after")
    def _validate(self) -> Self:
        """Check that the given combination of clipping bounds is consistent.

        Raises:
            ValueError: If a bound is given both as value and as percentile,
                if no bound is given at all, or if `axes` is set without any
                percentile bound.
        """
        has_min = self.min is not None
        has_min_percentile = self.min_percentile is not None
        has_max = self.max is not None
        has_max_percentile = self.max_percentile is not None

        if has_min and has_min_percentile:
            raise ValueError(
                "Only one of `min` and `min_percentile` may be set, not both."
            )
        if has_max and has_max_percentile:
            raise ValueError(
                "Only one of `max` and `max_percentile` may be set, not both."
            )
        if not (has_min or has_min_percentile or has_max or has_max_percentile):
            raise ValueError(
                "At least one of `min`, `min_percentile`, `max`, or `max_percentile` must be set."
            )

        # `axes` only has a meaning when percentiles are computed
        if self.axes is not None and not (has_min_percentile or has_max_percentile):
            raise ValueError(
                "If `axes` is set, at least one of `min_percentile` or `max_percentile` must be set."
            )

        return self
class ClipDescr(NodeWithExplicitlySetFields):
    """Set tensor values below min to min and above max to max.

    See `ScaleRangeDescr` for examples.
    """

    implemented_id: ClassVar[Literal["clip"]] = "clip"
    # Static type checkers see a default so `id` may be omitted in constructor
    # calls; at runtime the field is declared without a default here
    # (presumably the base class fills it in from `implemented_id` -- TODO confirm).
    if TYPE_CHECKING:
        id: Literal["clip"] = "clip"
    else:
        id: Literal["clip"]

    # The clipping bounds, see `ClipKwargs`.
    kwargs: ClipKwargs
class EnsureDtypeKwargs(KwargsNode):
    """key word arguments for [EnsureDtypeDescr][]"""

    # Name of the data type to cast to.
    dtype: Literal[
        "float32",
        "float64",
        "uint8",
        "int8",
        "uint16",
        "int16",
        "uint32",
        "int32",
        "uint64",
        "int64",
        "bool",
    ]
class EnsureDtypeDescr(NodeWithExplicitlySetFields):
    """Cast the tensor data type to `EnsureDtypeKwargs.dtype` (if not matching).

    This can for example be used to ensure the inner neural network model gets a
    different input tensor data type than the fully described bioimage.io model does.

    Examples:
        The described bioimage.io model (incl. preprocessing) accepts any
        float32-compatible tensor, normalizes it with percentiles and clipping and then
        casts it to uint8, which is what the neural network in this example expects.
        - in YAML
            ```yaml
            inputs:
            - data:
                type: float32  # described bioimage.io model is compatible with any float32 input tensor
              preprocessing:
              - id: scale_range
                  kwargs:
                  axes: ['y', 'x']
                  max_percentile: 99.8
                  min_percentile: 5.0
              - id: clip
                  kwargs:
                  min: 0.0
                  max: 1.0
              - id: ensure_dtype  # the neural network of the model requires uint8
                  kwargs:
                  dtype: uint8
            ```
        - in Python:
            >>> preprocessing = [
            ...     ScaleRangeDescr(
            ...         kwargs=ScaleRangeKwargs(
            ...           axes= (AxisId('y'), AxisId('x')),
            ...           max_percentile= 99.8,
            ...           min_percentile= 5.0,
            ...         )
            ...     ),
            ...     ClipDescr(kwargs=ClipKwargs(min=0.0, max=1.0)),
            ...     EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="uint8")),
            ... ]
    """

    implemented_id: ClassVar[Literal["ensure_dtype"]] = "ensure_dtype"
    # Static type checkers see a default so `id` may be omitted in constructor
    # calls; at runtime the field is declared without a default here
    # (presumably the base class fills it in from `implemented_id` -- TODO confirm).
    if TYPE_CHECKING:
        id: Literal["ensure_dtype"] = "ensure_dtype"
    else:
        id: Literal["ensure_dtype"]

    # The target data type, see `EnsureDtypeKwargs`.
    kwargs: EnsureDtypeKwargs
class ScaleLinearKwargs(KwargsNode):
    """Key word arguments for [ScaleLinearDescr][]"""

    gain: float = 1.0
    """multiplicative factor"""

    offset: float = 0.0
    """additive term"""

    @model_validator(mode="after")
    def _validate(self) -> Self:
        """Reject the identity transformation (`gain` of 1.0 with `offset` of 0.0)."""
        changes_values = self.gain != 1.0 or self.offset != 0.0
        if changes_values:
            return self

        raise ValueError(
            "Redundant linear scaling not allowd. Set `gain` != 1.0 and/or `offset`"
            + " != 0.0."
        )
class ScaleLinearAlongAxisKwargs(KwargsNode):
    """Key word arguments for [ScaleLinearDescr][]"""

    axis: Annotated[NonBatchAxisId, Field(examples=["channel"])]
    """The axis of gain and offset values."""

    gain: Union[float, NotEmpty[List[float]]] = 1.0
    """multiplicative factor"""

    offset: Union[float, NotEmpty[List[float]]] = 0.0
    """additive term"""

    @model_validator(mode="after")
    def _validate(self) -> Self:
        """Bring `gain` and `offset` to lists of equal length and reject no-ops.

        A scalar given next to a list is repeated to the length of that list.

        Raises:
            ValueError: If both are scalars, if both are lists of different
                length, or if all gains are 1.0 and all offsets are 0.0.
        """
        if isinstance(self.gain, list) and isinstance(self.offset, list):
            if len(self.gain) != len(self.offset):
                raise ValueError(
                    f"Size of `gain` ({len(self.gain)}) and `offset` ({len(self.offset)}) must match."
                )
        elif isinstance(self.gain, list):
            # scalar offset: repeat it for every gain value
            self.offset = [float(self.offset)] * len(self.gain)
        elif isinstance(self.offset, list):
            # scalar gain: repeat it for every offset value
            self.gain = [float(self.gain)] * len(self.offset)
        else:
            raise ValueError(
                "Do not specify an `axis` for scalar gain and offset values."
            )

        is_identity = all(g == 1.0 for g in self.gain) and all(
            off == 0.0 for off in self.offset
        )
        if is_identity:
            raise ValueError(
                "Redundant linear scaling not allowd. Set `gain` != 1.0 and/or `offset`"
                + " != 0.0."
            )

        return self
class ScaleLinearDescr(NodeWithExplicitlySetFields):
    """Fixed linear scaling.

    Examples:
      1. Scale with scalar gain and offset
        - in YAML
        ```yaml
        preprocessing:
          - id: scale_linear
            kwargs:
              gain: 2.0
              offset: 3.0
        ```
        - in Python:

          >>> preprocessing = [
          ...     ScaleLinearDescr(kwargs=ScaleLinearKwargs(gain= 2.0, offset=3.0))
          ... ]

      2. Independent scaling along an axis
        - in YAML
        ```yaml
        preprocessing:
          - id: scale_linear
            kwargs:
              axis: 'channel'
              gain: [1.0, 2.0, 3.0]
        ```
        - in Python:

          >>> preprocessing = [
          ...     ScaleLinearDescr(
          ...         kwargs=ScaleLinearAlongAxisKwargs(
          ...             axis=AxisId("channel"),
          ...             gain=[1.0, 2.0, 3.0],
          ...         )
          ...     )
          ... ]

    """

    # Identifier of the operation this description implements.
    implemented_id: ClassVar[Literal["scale_linear"]] = "scale_linear"
    # Type checkers see a default for `id`; at runtime the field is declared
    # without one (presumably filled in by `NodeWithExplicitlySetFields` -- confirm).
    if TYPE_CHECKING:
        id: Literal["scale_linear"] = "scale_linear"
    else:
        id: Literal["scale_linear"]
    # Either scalar kwargs for the whole tensor or per-entry kwargs along one axis.
    kwargs: Union[ScaleLinearKwargs, ScaleLinearAlongAxisKwargs]
class SigmoidDescr(NodeWithExplicitlySetFields):
    """The logistic sigmoid function, a.k.a. expit function.

    Examples:
    - in YAML
        ```yaml
        postprocessing:
          - id: sigmoid
        ```
    - in Python:

        >>> postprocessing = [SigmoidDescr()]
    """

    # Identifier of the operation this description implements.
    implemented_id: ClassVar[Literal["sigmoid"]] = "sigmoid"
    # Type checkers see a default for `id`; at runtime the field is declared
    # without one (presumably filled in by `NodeWithExplicitlySetFields` -- confirm).
    if TYPE_CHECKING:
        id: Literal["sigmoid"] = "sigmoid"
    else:
        id: Literal["sigmoid"]

    # Sigmoid takes no arguments; `kwargs` is a read-only property (not a field)
    # that returns a fresh, empty `KwargsNode` on every access.
    @property
    def kwargs(self) -> KwargsNode:
        """empty kwargs"""
        return KwargsNode()
class SoftmaxKwargs(KwargsNode):
    """key word arguments for [SoftmaxDescr][]"""

    # Axis the softmax is applied along; defaults to the axis id "channel".
    axis: Annotated[NonBatchAxisId, Field(examples=["channel"])] = AxisId("channel")
    """The axis to apply the softmax function along.
    Note:
        Defaults to 'channel' axis
        (which may not exist, in which case
        a different axis id has to be specified).
    """
class SoftmaxDescr(NodeWithExplicitlySetFields):
    """The softmax function.

    Examples:
    - in YAML
        ```yaml
        postprocessing:
          - id: softmax
            kwargs:
              axis: channel
        ```
    - in Python:

        >>> postprocessing = [SoftmaxDescr(kwargs=SoftmaxKwargs(axis=AxisId("channel")))]
    """

    # Identifier of the operation this description implements.
    implemented_id: ClassVar[Literal["softmax"]] = "softmax"
    # Type checkers see a default for `id`; at runtime the field is declared
    # without one (presumably filled in by `NodeWithExplicitlySetFields` -- confirm).
    if TYPE_CHECKING:
        id: Literal["softmax"] = "softmax"
    else:
        id: Literal["softmax"]

    # Default kwargs are built with `model_construct`, i.e. without running validation.
    kwargs: SoftmaxKwargs = Field(default_factory=SoftmaxKwargs.model_construct)
class _StardistPostprocessingKwargsBase(KwargsNode):
    """key word arguments for [StardistPostprocessingDescr][]"""

    # Fields shared by the 2D and 3D kwargs variants; none of them has a default.
    prob_threshold: float
    """The probability threshold for object candidate selection."""

    nms_threshold: float
    """The IoU threshold for non-maximum suppression."""

    n_rays: int
    """Number of radial lines (rays) cast from the center of an object to its boundary."""
class StardistPostprocessingKwargs2D(_StardistPostprocessingKwargsBase):
    # 2D variant: `grid` has two entries and `b` is either a single int or one
    # (int, int) pair per grid dimension.
    grid: Tuple[int, int]
    """Grid size of network predictions."""

    b: Union[int, Tuple[Tuple[int, int], Tuple[int, int]]]
    """Border region in which object probability is set to zero."""
class StardistPostprocessingKwargs3D(_StardistPostprocessingKwargsBase):
    # 3D variant: `grid` and `anisotropy` have three entries and `b` is either a
    # single int or one (int, int) pair per grid dimension.
    grid: Tuple[int, int, int]
    """Grid size of network predictions."""

    b: Union[int, Tuple[Tuple[int, int], Tuple[int, int], Tuple[int, int]]]
    """Border region in which object probability is set to zero."""

    anisotropy: Tuple[float, float, float]
    """Anisotropy factors for 3D star-convex polyhedra, i.e. the physical pixel size along each spatial axis."""

    overlap_label: Optional[int] = None
    """Optional label to apply to any area of overlapping predicted objects."""
class StardistPostprocessingDescr(NodeWithExplicitlySetFields):
    """Stardist postprocessing including non-maximum suppression and converting polygon representations to instance labels

    as described in:
    - Uwe Schmidt, Martin Weigert, Coleman Broaddus, and Gene Myers.
    [*Cell Detection with Star-convex Polygons*](https://arxiv.org/abs/1806.03535).
    International Conference on Medical Image Computing and Computer-Assisted Intervention (MICCAI), Granada, Spain, September 2018.
    - Martin Weigert, Uwe Schmidt, Robert Haase, Ko Sugawara, and Gene Myers.
    [*Star-convex Polyhedra for 3D Object Detection and Segmentation in Microscopy*](http://openaccess.thecvf.com/content_WACV_2020/papers/Weigert_Star-convex_Polyhedra_for_3D_Object_Detection_and_Segmentation_in_Microscopy_WACV_2020_paper.pdf).
    The IEEE Winter Conference on Applications of Computer Vision (WACV), Snowmass Village, Colorado, March 2020.

    Note: Only available if the `stardist` package is installed.
    """

    # Identifier of the operation this description implements.
    implemented_id: ClassVar[Literal["stardist_postprocessing"]] = (
        "stardist_postprocessing"
    )
    # Type checkers see a default for `id`; at runtime the field is declared
    # without one (presumably filled in by `NodeWithExplicitlySetFields` -- confirm).
    if TYPE_CHECKING:
        id: Literal["stardist_postprocessing"] = "stardist_postprocessing"
    else:
        id: Literal["stardist_postprocessing"]

    # Required; the 2D and 3D variants differ in the arity of `grid`/`b`.
    kwargs: Union[StardistPostprocessingKwargs2D, StardistPostprocessingKwargs3D]
class CellposeFlowDynamicsKwargs(KwargsNode):
    """key word arguments for [CellposeFlowDynamicsDescr][]"""

    # NOTE(review): the three required fields below are undocumented here; their
    # names suggest they mirror the Cellpose parameters of the same name -- confirm.
    cellprob_threshold: float
    flow_threshold: float
    do_3D: bool
    min_size: int = 15
    """Minimum size of objects to keep, in pixels. Default is 15, which is the default in Cellpose. Set to 0 to disable filtering by size."""
    # Restricted to the two listed unsigned integer dtypes; defaults to "uint16".
    output_dtype: Literal["uint16", "uint32"] = "uint16"
class CellposeFlowDynamicsDescr(NodeWithExplicitlySetFields):
    """Cellpose flow dynamics postprocessing as described in:
    - Carsen Stringer and Marius Pachitariu. [*Cellpose: a generalist algorithm for cellular segmentation*](https://www.nature.com/articles/s41592-020-01018-x). Nature Methods, 2021.

    Note: Only available if the `cellpose` package is installed.
    """

    # Identifier of the operation this description implements.
    implemented_id: ClassVar[Literal["cellpose_flow_dynamics"]] = (
        "cellpose_flow_dynamics"
    )
    # Type checkers see a default for `id`; at runtime the field is declared
    # without one (presumably filled in by `NodeWithExplicitlySetFields` -- confirm).
    if TYPE_CHECKING:
        id: Literal["cellpose_flow_dynamics"] = "cellpose_flow_dynamics"
    else:
        id: Literal["cellpose_flow_dynamics"]

    # Required: there is no default because three of the kwargs have no default.
    kwargs: CellposeFlowDynamicsKwargs
class CustomProcessingDescr(NodeWithExplicitlySetFields, FileDescr):
    """Custom (post)processing op — source file shipped inline with the model.

    Supports (post)processing that cannot be expressed by the built-in named
    operations (watershed, connected components, etc.)
    using a simple Python callable interface.

    The op is implemented in a ``.py`` file packaged alongside the model weights.
    Two styles are supported:

    *Callable class* — kwargs go to ``__init__``, tensors arrive in ``__call__``:

    .. code-block:: python

        # my_postprocess.py
        import numpy as np

        class my_postprocess:
            def __init__(self, threshold: float = 0.5) -> None:
                self.threshold = threshold
            def __call__(self, *arrays: np.ndarray) -> np.ndarray:
                # arrays = model output tensors in rdf.yaml declaration order
                return (arrays[0] > self.threshold).astype(np.uint8)

    *Factory function* — alternative closure style, identical runtime behaviour:

    .. code-block:: python

        # my_postprocess.py
        import numpy as np

        def my_postprocess(threshold: float = 0.5):
            def run(*arrays: np.ndarray) -> np.ndarray:
                return (arrays[0] > threshold).astype(np.uint8)
            return run

    Reference it in ``rdf.yaml`` with the source file included in the package:

    .. code-block:: yaml

        postprocessing:
          - id: custom
            callable: my_postprocess        # class or function name in source
            source: my_postprocess.py       # packaged alongside weights
            sha256: <hash>                  # sha256 of the source file
            kwargs:                         # forwarded to __init__ / factory
              threshold: 0.5

    **Security:** source files are SHA-256 verified before execution.
    Execution requires explicit opt-in in bioimageio.core and curator
    review before Zoo publication.
    """

    # Identifier of the operation this description implements.
    implemented_id: ClassVar[Literal["custom"]] = "custom"
    # Type checkers see a default for `id`; at runtime the field is declared
    # without one (presumably filled in by `NodeWithExplicitlySetFields` -- confirm).
    if TYPE_CHECKING:
        id: Literal["custom"] = "custom"
    else:
        id: Literal["custom"]

    # NOTE: this field name shadows the builtin `callable` inside the class body;
    # it is part of the serialized format and therefore kept as is.
    callable: Annotated[
        str,
        Field(examples=["my_postprocess_factory", "MyPostprocessClass"]),
    ]
    """Name of the callable class or factory function defined in ``source``.

    At runtime: ``op = callable(**kwargs)``, then ``result = op(*output_tensors)``
    per image. Both a class with ``__call__`` and a factory function returning
    a callable satisfy this protocol."""

    # Re-declares `source` with an extra `wo_special_file_name` after-validator.
    source: Annotated[FileSource, AfterValidator(wo_special_file_name)]
    """Python source file (included when packaging the model)."""

    # Free-form YAML mapping; defaults to a fresh empty dict per instance.
    kwargs: Dict[str, YamlValue] = Field(
        default_factory=cast(Callable[[], Dict[str, YamlValue]], dict)
    )
    """Keyword arguments forwarded to the callable (``__init__`` or factory)."""

    # Wrap serializer that delegates entirely to `package_file_descr_serializer`.
    @model_serializer(mode="wrap", when_used="unless-none")
    def _serialize(
        self, nxt: SerializerFunctionWrapHandler, info: SerializationInfo
    ) -> Dict[str, YamlValue]:
        return package_file_descr_serializer(self, nxt, info)
class FixedZeroMeanUnitVarianceKwargs(KwargsNode):
    """key word arguments for [FixedZeroMeanUnitVarianceDescr][]"""

    # Scalar variant: one mean/std pair; both fields are required.
    mean: float
    """The mean value to normalize with."""

    # `Ge(1e-6)` rejects std values below 1e-6.
    std: Annotated[float, Ge(1e-6)]
    """The standard deviation value to normalize with."""
class FixedZeroMeanUnitVarianceAlongAxisKwargs(KwargsNode):
    """key word arguments for [FixedZeroMeanUnitVarianceDescr][]"""

    mean: NotEmpty[List[float]]
    """The mean value(s) to normalize with."""

    std: NotEmpty[List[Annotated[float, Ge(1e-6)]]]
    """The standard deviation value(s) to normalize with.
    Size must match `mean` values."""

    axis: Annotated[NonBatchAxisId, Field(examples=["channel", "index"])]
    """The axis of the mean/std values to normalize each entry along that dimension
    separately."""

    @model_validator(mode="after")
    def _mean_and_std_match(self) -> Self:
        """Ensure `mean` and `std` provide the same number of entries."""
        n_mean, n_std = len(self.mean), len(self.std)
        if n_mean == n_std:
            return self

        raise ValueError(
            f"Size of `mean` ({n_mean}) and `std` ({n_std})" + " must match."
        )
class FixedZeroMeanUnitVarianceDescr(NodeWithExplicitlySetFields):
    """Subtract a given mean and divide by the standard deviation.

    Normalize with fixed, precomputed values for
    `FixedZeroMeanUnitVarianceKwargs.mean` and `FixedZeroMeanUnitVarianceKwargs.std`.
    Use `FixedZeroMeanUnitVarianceAlongAxisKwargs` for independent scaling along a
    given axis.

    Examples:
    1. scalar value for whole tensor
        - in YAML
        ```yaml
        preprocessing:
          - id: fixed_zero_mean_unit_variance
            kwargs:
              mean: 103.5
              std: 13.7
        ```
        - in Python
        >>> preprocessing = [FixedZeroMeanUnitVarianceDescr(
        ...   kwargs=FixedZeroMeanUnitVarianceKwargs(mean=103.5, std=13.7)
        ... )]

    2. independently along an axis
        - in YAML
        ```yaml
        preprocessing:
          - id: fixed_zero_mean_unit_variance
            kwargs:
              axis: channel
              mean: [101.5, 102.5, 103.5]
              std: [11.7, 12.7, 13.7]
        ```
        - in Python
        >>> preprocessing = [FixedZeroMeanUnitVarianceDescr(
        ...   kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs(
        ...     axis=AxisId("channel"),
        ...     mean=[101.5, 102.5, 103.5],
        ...     std=[11.7, 12.7, 13.7],
        ...   )
        ... )]
    """

    # Identifier of the operation this description implements.
    implemented_id: ClassVar[Literal["fixed_zero_mean_unit_variance"]] = (
        "fixed_zero_mean_unit_variance"
    )
    # Type checkers see a default for `id`; at runtime the field is declared
    # without one (presumably filled in by `NodeWithExplicitlySetFields` -- confirm).
    if TYPE_CHECKING:
        id: Literal["fixed_zero_mean_unit_variance"] = "fixed_zero_mean_unit_variance"
    else:
        id: Literal["fixed_zero_mean_unit_variance"]

    # Either one scalar mean/std pair or per-entry lists along a single axis.
    kwargs: Union[
        FixedZeroMeanUnitVarianceKwargs, FixedZeroMeanUnitVarianceAlongAxisKwargs
    ]
class ZeroMeanUnitVarianceKwargs(KwargsNode):
    """key word arguments for [ZeroMeanUnitVarianceDescr][]"""

    # `None` (the default) means: compute mean/std over all axes jointly.
    axes: Annotated[
        Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
    ] = None
    """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std.
    For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
    resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
    To normalize each sample independently leave out the 'batch' axis.
    Default: Scale all axes jointly."""

    # Constrained to the half-open interval (0, 0.1].
    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
    """epsilon for numeric stability: `out = (tensor - mean) / (std + eps)`."""
class ZeroMeanUnitVarianceDescr(NodeWithExplicitlySetFields):
    """Subtract mean and divide by standard deviation:
    `out = (tensor - mean) / (std + eps)`.

    Examples:
        Normalize with the mean and standard deviation of the tensor
        - in YAML
        ```yaml
        preprocessing:
          - id: zero_mean_unit_variance
        ```
        - in Python
        >>> preprocessing = [ZeroMeanUnitVarianceDescr()]
    """

    # Identifier of the operation this description implements.
    implemented_id: ClassVar[Literal["zero_mean_unit_variance"]] = (
        "zero_mean_unit_variance"
    )
    # Type checkers see a default for `id`; at runtime the field is declared
    # without one (presumably filled in by `NodeWithExplicitlySetFields` -- confirm).
    if TYPE_CHECKING:
        id: Literal["zero_mean_unit_variance"] = "zero_mean_unit_variance"
    else:
        id: Literal["zero_mean_unit_variance"]

    # Default kwargs are built with `model_construct`, i.e. without running validation.
    kwargs: ZeroMeanUnitVarianceKwargs = Field(
        default_factory=ZeroMeanUnitVarianceKwargs.model_construct
    )
class ScaleRangeKwargs(KwargsNode):
    """key word arguments for [ScaleRangeDescr][]

    For `min_percentile`=0.0 (the default) and `max_percentile`=100 (the default)
    this processing step normalizes data to the [0, 1] interval.
    For other percentiles the normalized values will partially be outside the [0, 1]
    interval. Use `ScaleRangeDescr` followed by `ClipDescr` if you want to limit the
    normalized values to a range.
    """

    axes: Annotated[
        Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
    ] = None
    """The subset of axes to normalize jointly, i.e. axes to reduce to compute the min/max percentile value.
    For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
    resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
    To normalize samples independently, leave out the "batch" axis.
    Default: Scale all axes jointly."""

    min_percentile: Annotated[float, Interval(ge=0, lt=100)] = 0.0
    """The lower percentile used to determine the value to align with zero."""

    max_percentile: Annotated[float, Interval(gt=1, le=100)] = 100.0
    """The upper percentile used to determine the value to align with one.
    Has to be bigger than `min_percentile`.
    The range is 1 to 100 instead of 0 to 100 to avoid mistakenly
    accepting percentiles specified in the range 0.0 to 1.0."""

    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
    """Epsilon for numeric stability.
    `out = (tensor - v_lower) / (v_upper - v_lower + eps)`;
    with `v_lower,v_upper` values at the respective percentiles."""

    reference_tensor: Optional[TensorId] = None
    """ID of the unprocessed input tensor to compute the percentiles from.
    Default: The tensor itself.
    """

    @field_validator("max_percentile", mode="after")
    @classmethod
    def min_smaller_max(cls, value: float, info: ValidationInfo) -> float:
        """Ensure `min_percentile` is strictly smaller than `max_percentile`.

        Args:
            value: The already validated `max_percentile`.
            info: Validation info; `info.data` holds the previously validated fields.

        Returns:
            `value` unchanged.

        Raises:
            ValueError: If `min_percentile >= max_percentile`.
        """
        # `info.data` only contains fields that validated successfully. If
        # `min_percentile` itself was invalid it is missing here and has already
        # been reported; indexing it directly would leak a `KeyError` instead of
        # a regular validation error, so skip the comparison in that case.
        min_p = info.data.get("min_percentile")
        if min_p is None:
            return value

        if min_p >= value:
            raise ValueError(f"min_percentile {min_p} >= max_percentile {value}")

        return value
class ScaleRangeDescr(NodeWithExplicitlySetFields):
    """Scale with percentiles.

    Examples:
    1. Scale linearly to map 5th percentile to 0 and 99.8th percentile to 1.0
        - in YAML
        ```yaml
        preprocessing:
          - id: scale_range
            kwargs:
              axes: ['y', 'x']
              max_percentile: 99.8
              min_percentile: 5.0
        ```
        - in Python

        >>> preprocessing = [
        ...     ScaleRangeDescr(
        ...         kwargs=ScaleRangeKwargs(
        ...           axes= (AxisId('y'), AxisId('x')),
        ...           max_percentile= 99.8,
        ...           min_percentile= 5.0,
        ...         )
        ...     )
        ... ]

      2. Combine the above scaling with additional clipping to clip values outside the range given by the percentiles.
        - in YAML
        ```yaml
        preprocessing:
          - id: scale_range
            kwargs:
              axes: ['y', 'x']
              max_percentile: 99.8
              min_percentile: 5.0
                  - id: scale_range
           - id: clip
             kwargs:
              min: 0.0
              max: 1.0
        ```
        - in Python

        >>> preprocessing = [
        ...   ScaleRangeDescr(
        ...     kwargs=ScaleRangeKwargs(
        ...       axes= (AxisId('y'), AxisId('x')),
        ...       max_percentile= 99.8,
        ...       min_percentile= 5.0,
        ...     )
        ...   ),
        ...   ClipDescr(
        ...     kwargs=ClipKwargs(
        ...       min=0.0,
        ...       max=1.0,
        ...     )
        ...   ),
        ... ]

    """

    # Identifier of the operation this description implements.
    implemented_id: ClassVar[Literal["scale_range"]] = "scale_range"
    # Type checkers see a default for `id`; at runtime the field is declared
    # without one (presumably filled in by `NodeWithExplicitlySetFields` -- confirm).
    if TYPE_CHECKING:
        id: Literal["scale_range"] = "scale_range"
    else:
        id: Literal["scale_range"]
    # Default kwargs are built with `model_construct`, i.e. without running validation.
    kwargs: ScaleRangeKwargs = Field(default_factory=ScaleRangeKwargs.model_construct)
class ScaleMeanVarianceKwargs(KwargsNode):
    """key word arguments for [ScaleMeanVarianceDescr][]"""

    # Required: the only kwarg without a default.
    reference_tensor: TensorId
    """ID of unprocessed input tensor to match."""

    # `None` (the default) means: compute mean/std over all axes jointly.
    axes: Annotated[
        Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
    ] = None
    """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std.
    For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
    resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
    To normalize samples independently, leave out the 'batch' axis.
    Default: Scale all axes jointly."""

    # Constrained to the half-open interval (0, 0.1].
    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
    """Epsilon for numeric stability:
    `out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.`"""
class ScaleMeanVarianceDescr(NodeWithExplicitlySetFields):
    """Scale a tensor's data distribution to match another tensor's mean/std.
    `out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.`
    """

    # Identifier of the operation this description implements.
    implemented_id: ClassVar[Literal["scale_mean_variance"]] = "scale_mean_variance"
    # Type checkers see a default for `id`; at runtime the field is declared
    # without one (presumably filled in by `NodeWithExplicitlySetFields` -- confirm).
    if TYPE_CHECKING:
        id: Literal["scale_mean_variance"] = "scale_mean_variance"
    else:
        id: Literal["scale_mean_variance"]
    # Required: `ScaleMeanVarianceKwargs.reference_tensor` has no default.
    kwargs: ScaleMeanVarianceKwargs
# Union of all processing descriptions accepted as preprocessing steps.
# The `id` field is the discriminator that selects the concrete description class.
PreprocessingDescr = Annotated[
    Union[
        BinarizeDescr,
        ClipDescr,
        EnsureDtypeDescr,
        FixedZeroMeanUnitVarianceDescr,
        ScaleLinearDescr,
        ScaleRangeDescr,
        SigmoidDescr,
        SoftmaxDescr,
        ZeroMeanUnitVarianceDescr,
    ],
    Discriminator("id"),
]
# Union of all processing descriptions accepted as postprocessing steps: every
# preprocessing member plus `CellposeFlowDynamicsDescr`, `CustomProcessingDescr`,
# `ScaleMeanVarianceDescr` and `StardistPostprocessingDescr`.
PostprocessingDescr = Annotated[
    Union[
        BinarizeDescr,
        CellposeFlowDynamicsDescr,
        ClipDescr,
        CustomProcessingDescr,
        EnsureDtypeDescr,
        FixedZeroMeanUnitVarianceDescr,
        ScaleLinearDescr,
        ScaleMeanVarianceDescr,
        ScaleRangeDescr,
        SigmoidDescr,
        SoftmaxDescr,
        StardistPostprocessingDescr,
        ZeroMeanUnitVarianceDescr,
    ],
    Discriminator("id"),
]

# Type variable constrained to exactly `InputAxis` or `OutputAxis`.
IO_AxisT = TypeVar("IO_AxisT", InputAxis, OutputAxis)
class TensorDescrBase(Node, Generic[IO_AxisT]):
    """Fields and validation shared by tensor descriptions, generic over the axis type."""

    id: TensorId
    """Tensor id. No duplicates are allowed."""

    description: Annotated[str, MaxLen(128)] = ""
    """free text description"""

    axes: NotEmpty[Sequence[IO_AxisT]]
    """tensor axes"""

    @property
    def shape(self):
        """Tuple of the `size` attribute of every axis, in axis order."""
        return tuple(a.size for a in self.axes)

    @field_validator("axes", mode="after", check_fields=False)
    @classmethod
    def _validate_axes(cls, axes: Sequence[AnyAxis]) -> Sequence[AnyAxis]:
        """Allow at most one batch axis and no duplicate axis ids.

        Raises:
            ValueError: On more than one batch axis or on duplicate axis ids.
        """
        batch_axes = [a for a in axes if a.type == "batch"]
        if len(batch_axes) > 1:
            raise ValueError(
                f"Only one batch axis (per tensor) allowed, but got {batch_axes}"
            )

        seen_ids: Set[AxisId] = set()
        duplicate_axes_ids: Set[AxisId] = set()
        for a in axes:
            # first occurrence goes to `seen_ids`, any repetition to the duplicates
            (duplicate_axes_ids if a.id in seen_ids else seen_ids).add(a.id)

        if duplicate_axes_ids:
            raise ValueError(f"Duplicate axis ids: {duplicate_axes_ids}")

        return axes

    test_tensor: FAIR[Optional[FileDescr_]] = None
    """An example tensor to use for testing.
    Using the model with the test input tensors is expected to yield the test output tensors.
    Each test tensor has to be an ndarray in the
    [numpy.lib file format](https://numpy.org/doc/stable/reference/generated/numpy.lib.format.html#module-numpy.lib.format).
    The file extension must be '.npy'."""

    sample_tensor: FAIR[Optional[FileDescr_]] = None
    """A sample tensor to illustrate a possible input/output for the model,
    The sample image primarily serves to inform a human user about an example use case
    and is typically stored as .hdf5, .png or .tiff.
    It has to be readable by the [imageio library](https://imageio.readthedocs.io/en/stable/formats/index.html#supported-formats)
    (numpy's `.npy` format is not supported).
    The image dimensionality has to match the number of axes specified in this tensor description.
    """

    @model_validator(mode="after")
    def _validate_sample_tensor(self) -> Self:
        """Check that the sample tensor's dimensionality is compatible with `axes`.

        Skipped when no sample tensor is given or IO checks are disabled in the
        validation context.

        Raises:
            ValueError: If the number of non-singleton dimensions of the sample
                is outside the range admitted by the axes.
        """
        if self.sample_tensor is None or not get_validation_context().perform_io_checks:
            return self

        reader = get_reader(self.sample_tensor.source, sha256=self.sample_tensor.sha256)
        tensor: NDArray[Any] = imread(  # pyright: ignore[reportUnknownVariableType]
            reader.read(),
            extension=PurePosixPath(reader.original_file_name).suffix,
        )
        # singleton dimensions are squeezed away before counting
        n_dims = len(tensor.squeeze().shape)
        n_dims_min = n_dims_max = len(self.axes)

        # every axis that may have length one lowers the minimum expected rank
        for a in self.axes:
            if isinstance(a, BatchAxis):
                n_dims_min -= 1
            elif isinstance(a.size, int):
                if a.size == 1:
                    n_dims_min -= 1
            elif isinstance(a.size, (ParameterizedSize, DataDependentSize)):
                if a.size.min == 1:
                    n_dims_min -= 1
            elif isinstance(a.size, SizeReference):
                if a.size.offset < 2:
                    # size reference may result in singleton axis
                    n_dims_min -= 1
            else:
                assert_never(a.size)

        n_dims_min = max(0, n_dims_min)
        if n_dims < n_dims_min or n_dims > n_dims_max:
            raise ValueError(
                f"Expected sample tensor to have {n_dims_min} to"
                + f" {n_dims_max} dimensions, but found {n_dims} (shape: {tensor.shape})."
            )

        return self

    data: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]] = (
        IntervalOrRatioDataDescr()
    )
    """Description of the tensor's data values, optionally per channel.
    If specified per channel, the data `type` needs to match across channels."""

    @property
    def dtype(
        self,
    ) -> Literal[
        "float32",
        "float64",
        "uint8",
        "int8",
        "uint16",
        "int16",
        "uint32",
        "int32",
        "uint64",
        "int64",
        "bool",
    ]:
        """dtype as specified under `data.type` or `data[i].type`"""
        if isinstance(self.data, collections.abc.Sequence):
            return self.data[0].type
        else:
            return self.data.type

    @field_validator("data", mode="after")
    @classmethod
    def _check_data_type_across_channels(
        cls, value: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]]
    ) -> Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]]:
        """Ensure per-channel data descriptions agree in their `type`.

        Raises:
            ValueError: If the per-channel descriptions name more than one dtype.
        """
        # Accept tuples as well as lists (as `_check_data_matches_channelaxis`
        # does); otherwise a tuple of descriptions would skip this check.
        if not isinstance(value, (list, tuple)):
            return value

        dtypes = {t.type for t in value}
        if len(dtypes) > 1:
            raise ValueError(
                "Tensor data descriptions per channel need to agree in their data"
                + f" `type`, but found {dtypes}."
            )

        return value

    @model_validator(mode="after")
    def _check_data_matches_channelaxis(self) -> Self:
        """Ensure the number of per-channel data descriptions equals the channel count.

        Raises:
            ValueError: If a channel axis exists and its size differs from `len(data)`.
        """
        if not isinstance(self.data, (list, tuple)):
            return self

        # find the first channel axis; without one there is nothing to compare
        for a in self.axes:
            if isinstance(a, ChannelAxis):
                size = a.size
                assert isinstance(size, int)
                break
        else:
            return self

        if len(self.data) != size:
            raise ValueError(
                f"Got tensor data descriptions for {len(self.data)} channels, but"
                + f" '{a.id}' axis has size {size}."
            )

        return self

    def get_axis_sizes_for_array(self, array: NDArray[Any]) -> Dict[AxisId, int]:
        """Map each axis id to the length of the corresponding dimension of `array`.

        Raises:
            ValueError: If `array` does not have exactly one dimension per axis.
        """
        if len(array.shape) != len(self.axes):
            raise ValueError(
                f"Dimension mismatch: array shape {array.shape} (#{len(array.shape)})"
                + f" incompatible with {len(self.axes)} axes."
            )
        return {a.id: array.shape[i] for i, a in enumerate(self.axes)}
# Padding descriptions, one class per `mode` literal.
# NOTE(review): the mode names match those of `numpy.pad` -- presumably the
# semantics do too; confirm against the consumer of these descriptions.
class ConstantPadding(Node):
    mode: Literal["constant"] = "constant"
    # value used for padding; defaults to 0
    value: Union[int, float] = 0


class EdgePadding(Node):
    mode: Literal["edge"] = "edge"


class ReflectPadding(Node):
    mode: Literal["reflect"] = "reflect"


class SymmetricPadding(Node):
    mode: Literal["symmetric"] = "symmetric"


# Any of the supported padding descriptions.
Padding = Union[ConstantPadding, EdgePadding, ReflectPadding, SymmetricPadding]
class InputTensorDescr(TensorDescrBase[InputAxis]):
    id: TensorId = TensorId("input")
    """Input tensor id.
    No duplicates are allowed across all inputs and outputs."""

    optional: bool = False
    """indicates that this tensor may be `None`"""

    pad: Optional[Padding] = None
    """Explicitly specify how to pad this input tensor.

    Use `axes[i].pad` to specify padding width.

    Note:
        Non-blockwise sample prediction only applies padding for axes with a `pad` specification.
    """

    preprocessing: List[PreprocessingDescr] = Field(
        default_factory=cast(Callable[[], List[PreprocessingDescr]], list)
    )
    """Description of how this input should be preprocessed.

    notes:
    - If preprocessing does not start with an 'ensure_dtype' entry, it is added
      to ensure an input tensor's data type matches the input tensor's data description.
    - If preprocessing does not end with an 'ensure_dtype' or 'binarize' entry, an
      'ensure_dtype' step is added to ensure preprocessing steps are not unintentionally
      changing the data type.
    """

    @model_validator(mode="after")
    def _validate_preprocessing_kwargs(self) -> Self:
        """Validate `axes` kwargs of all steps and frame the steps with dtype casts."""
        known_axis_ids = [axis.id for axis in self.axes]
        for step in self.preprocessing:
            step_axes: Optional[Sequence[Any]] = step.kwargs.get("axes")
            if step_axes is None:
                # step has no (or an unset) `axes` kwarg: nothing to check
                continue

            if not isinstance(step_axes, collections.abc.Sequence):
                raise ValueError(
                    f"Expected `preprocessing.i.kwargs.axes` to be a sequence, but got {type(step_axes)}"
                )

            if not all(axis_id in known_axis_ids for axis_id in step_axes):
                raise ValueError(
                    "`preprocessing.i.kwargs.axes` needs to be subset of axes ids"
                )

        # a single data description applies to the whole tensor; otherwise take
        # the type of the first per-channel description
        data_descr = self.data
        dtype = (
            data_descr.type
            if isinstance(
                data_descr, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)
            )
            else data_descr[0].type
        )

        steps = self.preprocessing

        # ensure `preprocessing` begins with `EnsureDtypeDescr`
        starts_with_cast = bool(steps) and isinstance(steps[0], EnsureDtypeDescr)
        if not starts_with_cast:
            steps.insert(0, EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype)))

        # ensure `preprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr`
        ends_with_cast = isinstance(steps[-1], (EnsureDtypeDescr, BinarizeDescr))
        if not ends_with_cast:
            steps.append(EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype)))

        return self
def convert_axes(
    axes: str,
    *,
    shape: Union[
        Sequence[int], _ParameterizedInputShape_v0_4, _ImplicitOutputShape_v0_4
    ],
    tensor_type: Literal["input", "output"],
    halo: Optional[Sequence[int]],
    size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]],
):
    """Build a list of axis descriptions from an axes letter string.

    Args:
        axes: One character per axis; each letter is looked up in `_AXIS_TYPE_MAP`.
        shape: Explicit sizes, a parameterized input shape (min/step) or an
            implicit output shape (reference tensor with scale/offset).
        tensor_type: Selects input or output axis classes.
        halo: Optional per-axis halo; only read for output time/space axes.
        size_refs: Known axis sizes per tensor name; used to fold the scale of
            channel/index axes into a size reference offset.

    Returns:
        The converted axes. Letters whose axis type is none of batch, time,
        index, channel or space produce no entry.

    Raises:
        NotImplementedError: For an output space axis with a non-zero halo and
            a fixed integer size.
    """
    ret: List[AnyAxis] = []
    for i, a in enumerate(axes):
        axis_type = _AXIS_TYPE_MAP.get(a, a)
        if axis_type == "batch":
            ret.append(BatchAxis())
            continue

        # --- determine `size` (and `scale`) of axis i from the given shape ---
        scale = 1.0
        if isinstance(shape, _ParameterizedInputShape_v0_4):
            if shape.step[i] == 0:
                # a zero step means the size is fixed to the minimum
                size = shape.min[i]
            else:
                size = ParameterizedSize(min=shape.min[i], step=shape.step[i])
        elif isinstance(shape, _ImplicitOutputShape_v0_4):
            # reference may be given as "<tensor>" or "<tensor>.<axis>"
            ref_t = str(shape.reference_tensor)
            if ref_t.count(".") == 1:
                t_id, orig_a_id = ref_t.split(".")
            else:
                t_id = ref_t
                orig_a_id = a

            a_id = _AXIS_ID_MAP.get(orig_a_id, a)
            if not (orig_scale := shape.scale[i]):
                # old way to insert a new axis dimension
                size = int(2 * shape.offset[i])
            else:
                scale = 1 / orig_scale
                if axis_type in ("channel", "index"):
                    # these axes no longer have a scale
                    offset_from_scale = orig_scale * size_refs.get(
                        _TensorName_v0_4(t_id), {}
                    ).get(orig_a_id, 0)
                else:
                    offset_from_scale = 0
                size = SizeReference(
                    tensor_id=TensorId(t_id),
                    axis_id=AxisId(a_id),
                    offset=int(offset_from_scale + 2 * shape.offset[i]),
                )
        else:
            size = shape[i]

        # --- create the axis description matching axis type and tensor type ---
        if axis_type == "time":
            if tensor_type == "input":
                ret.append(TimeInputAxis(size=size, scale=scale))
            else:
                assert not isinstance(size, ParameterizedSize)
                # NOTE(review): unlike the space branch below, a given halo with
                # `halo[i] == 0` is not treated as "no halo" here -- confirm intent.
                if halo is None:
                    ret.append(TimeOutputAxis(size=size, scale=scale))
                else:
                    assert not isinstance(size, int)
                    ret.append(
                        TimeOutputAxisWithHalo(size=size, scale=scale, halo=halo[i])
                    )

        elif axis_type == "index":
            if tensor_type == "input":
                ret.append(IndexInputAxis(size=size))
            else:
                if isinstance(size, ParameterizedSize):
                    # only the minimum is carried over; the step is dropped
                    size = DataDependentSize(min=size.min)

                ret.append(IndexOutputAxis(size=size))
        elif axis_type == "channel":
            assert not isinstance(size, ParameterizedSize)
            if isinstance(size, SizeReference):
                warnings.warn(
                    "Conversion of channel size from an implicit output shape may be"
                    + " wrong"
                )
                # the offset of the size reference is used as the channel count;
                # `i` inside the comprehension is local to it
                ret.append(
                    ChannelAxis(
                        channel_names=[
                            Identifier(f"channel{i}") for i in range(size.offset)
                        ]
                    )
                )
            else:
                ret.append(
                    ChannelAxis(
                        channel_names=[Identifier(f"channel{i}") for i in range(size)]
                    )
                )
        elif axis_type == "space":
            if tensor_type == "input":
                ret.append(SpaceInputAxis(id=AxisId(a), size=size, scale=scale))
            else:
                assert not isinstance(size, ParameterizedSize)
                if halo is None or halo[i] == 0:
                    ret.append(SpaceOutputAxis(id=AxisId(a), size=size, scale=scale))
                elif isinstance(size, int):
                    raise NotImplementedError(
                        f"output axis with halo and fixed size (here {size}) not allowed"
                    )
                else:
                    ret.append(
                        SpaceOutputAxisWithHalo(
                            id=AxisId(a), size=size, scale=scale, halo=halo[i]
                        )
                    )

    return ret
2165def _axes_letters_to_ids(
2166 axes: Optional[str],
2167) -> Optional[List[AxisId]]:
2168 if axes is None:
2169 return None
2171 return [AxisId(a) for a in axes]
2174def _get_complement_v04_axis(
2175 tensor_axes: Sequence[str], axes: Optional[Sequence[str]]
2176) -> Optional[AxisId]:
2177 if axes is None:
2178 return None
2180 non_complement_axes = set(axes) | {"b"}
2181 complement_axes = [a for a in tensor_axes if a not in non_complement_axes]
2182 if len(complement_axes) > 1:
2183 raise ValueError(
2184 f"Expected none or a single complement axis, but axes '{axes}' "
2185 + f"for tensor dims '{tensor_axes}' leave '{complement_axes}'."
2186 )
2188 return None if not complement_axes else AxisId(complement_axes[0])
def _convert_proc(
    p: Union[_PreprocessingDescr_v0_4, _PostprocessingDescr_v0_4],
    tensor_axes: Sequence[str],
) -> Union[PreprocessingDescr, PostprocessingDescr]:
    """Convert a v0.4 pre-/postprocessing description to its v0.5 counterpart.

    Args:
        p: The v0.4 processing description to convert.
        tensor_axes: The v0.4 axes letters of the tensor `p` belongs to.
            Used to determine the single 'complement' axis (if any) that
            per-axis parameters (gain/offset, mean/std) are given along.

    Returns:
        The corresponding v0.5 processing description.

    Raises:
        ValueError: If more than one complement axis is left, or if a 'fixed'
            zero-mean-unit-variance without complement axis has a list-valued
            mean or std.
    """
    if isinstance(p, _BinarizeDescr_v0_4):
        return BinarizeDescr(kwargs=BinarizeKwargs(threshold=p.kwargs.threshold))
    elif isinstance(p, _ClipDescr_v0_4):
        return ClipDescr(kwargs=ClipKwargs(min=p.kwargs.min, max=p.kwargs.max))
    elif isinstance(p, _SigmoidDescr_v0_4):
        return SigmoidDescr()
    elif isinstance(p, _ScaleLinearDescr_v0_4):
        # `_get_complement_v04_axis` already returns None for `axes=None`
        axis = _get_complement_v04_axis(tensor_axes, p.kwargs.axes)

        if axis is None:
            assert not isinstance(p.kwargs.gain, list)
            assert not isinstance(p.kwargs.offset, list)
            kwargs = ScaleLinearKwargs(gain=p.kwargs.gain, offset=p.kwargs.offset)
        else:
            kwargs = ScaleLinearAlongAxisKwargs(
                axis=axis, gain=p.kwargs.gain, offset=p.kwargs.offset
            )
        return ScaleLinearDescr(kwargs=kwargs)
    elif isinstance(p, _ScaleMeanVarianceDescr_v0_4):
        return ScaleMeanVarianceDescr(
            kwargs=ScaleMeanVarianceKwargs(
                axes=_axes_letters_to_ids(p.kwargs.axes),
                reference_tensor=TensorId(str(p.kwargs.reference_tensor)),
                eps=p.kwargs.eps,
            )
        )
    elif isinstance(p, _ZeroMeanUnitVarianceDescr_v0_4):
        if p.kwargs.mode == "fixed":
            mean = p.kwargs.mean
            std = p.kwargs.std
            assert mean is not None
            assert std is not None

            axis = _get_complement_v04_axis(tensor_axes, p.kwargs.axes)

            if axis is None:
                # without a complement axis only scalar mean/std make sense
                if isinstance(mean, list):
                    raise ValueError("Expected single float value for mean, not <list>")
                if isinstance(std, list):
                    raise ValueError("Expected single float value for std, not <list>")
                return FixedZeroMeanUnitVarianceDescr(
                    kwargs=FixedZeroMeanUnitVarianceKwargs.model_construct(
                        mean=mean,
                        std=std,
                    )
                )
            else:
                # the along-axis variant expects one value per entry along `axis`
                if not isinstance(mean, list):
                    mean = [float(mean)]
                if not isinstance(std, list):
                    std = [float(std)]

                return FixedZeroMeanUnitVarianceDescr(
                    kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs(
                        axis=axis, mean=mean, std=std
                    )
                )

        else:
            axes = _axes_letters_to_ids(p.kwargs.axes) or []
            if p.kwargs.mode == "per_dataset":
                # 'per_dataset' statistics additionally reduce over the batch axis
                axes = [AxisId("batch")] + axes
            if not axes:
                axes = None
            return ZeroMeanUnitVarianceDescr(
                kwargs=ZeroMeanUnitVarianceKwargs(axes=axes, eps=p.kwargs.eps)
            )

    elif isinstance(p, _ScaleRangeDescr_v0_4):
        return ScaleRangeDescr(
            kwargs=ScaleRangeKwargs(
                axes=_axes_letters_to_ids(p.kwargs.axes),
                min_percentile=p.kwargs.min_percentile,
                max_percentile=p.kwargs.max_percentile,
                eps=p.kwargs.eps,
            )
        )
    else:
        assert_never(p)
class _InputTensorConv(
    Converter[
        _InputTensorDescr_v0_4,
        InputTensorDescr,
        FileSource_,
        Optional[FileSource_],
        Mapping[_TensorName_v0_4, Mapping[str, int]],
    ]
):
    """Converter from a v0.4 input tensor description to `InputTensorDescr`."""

    def _convert(
        self,
        src: _InputTensorDescr_v0_4,
        tgt: "type[InputTensorDescr] | type[dict[str, Any]]",
        test_tensor: FileSource_,
        sample_tensor: Optional[FileSource_],
        size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]],
    ) -> "InputTensorDescr | dict[str, Any]":
        """Build the `tgt` representation of the input tensor `src`."""
        input_axes: List[InputAxis] = convert_axes(  # pyright: ignore[reportAssignmentType]
            src.axes,
            shape=src.shape,
            tensor_type="input",
            halo=None,
            size_refs=size_refs,
        )

        # processing kinds that are not expected as converted input preprocessing
        not_preprocessing = (
            CellposeFlowDynamicsDescr,
            CustomProcessingDescr,
            ScaleMeanVarianceDescr,
            StardistPostprocessingDescr,
        )
        preprocessing: List[PreprocessingDescr] = []
        for old_proc in src.preprocessing:
            new_proc = _convert_proc(old_proc, src.axes)
            assert not isinstance(new_proc, not_preprocessing)
            preprocessing.append(new_proc)

        # converted preprocessing always ends with a cast to float32
        preprocessing.append(
            EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="float32"))
        )

        tensor_id = TensorId(str(src.name))
        test_descr = FileDescr(source=test_tensor)
        if sample_tensor is None:
            sample_descr = None
        else:
            sample_descr = FileDescr(source=sample_tensor)

        return tgt(
            axes=input_axes,
            id=tensor_id,
            test_tensor=test_descr,
            sample_tensor=sample_descr,
            data=dict(type=src.data_type),  # pyright: ignore[reportArgumentType]
            preprocessing=preprocessing,
        )


_input_tensor_conv = _InputTensorConv(_InputTensorDescr_v0_4, InputTensorDescr)
class OutputTensorDescr(TensorDescrBase[OutputAxis]):
    id: TensorId = TensorId("output")
    """Output tensor id.
    No duplicates are allowed across all inputs and outputs."""

    postprocessing: List[PostprocessingDescr] = Field(
        default_factory=cast(Callable[[], List[PostprocessingDescr]], list)
    )
    """Description of how this output should be postprocessed.

    note: `postprocessing` always ends with an 'ensure_dtype' operation.
          If not given this is added to cast to this tensor's `data.type`.
    """

    @model_validator(mode="after")
    def _validate_postprocessing_kwargs(self) -> Self:
        """Check `axes` kwargs of all postprocessing steps and append a final dtype cast if needed."""
        known_axis_ids = [axis.id for axis in self.axes]
        for proc in self.postprocessing:
            proc_axes = proc.kwargs.get("axes")
            if proc_axes is None:
                continue

            if not isinstance(proc_axes, collections.abc.Sequence):
                raise ValueError(
                    f"expected `axes` sequence, but got {type(proc_axes)}"
                )

            proc_axes_seq: Sequence[Any] = cast(Sequence[Any], proc_axes)
            if not all(a in known_axis_ids for a in proc_axes_seq):
                raise ValueError("`kwargs.axes` needs to be subset of axes ids")

        # `data` is either a single data description or a sequence of them;
        # in the latter case the first entry determines the dtype
        if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)):
            data_descr = self.data
        else:
            data_descr = self.data[0]

        dtype = data_descr.type

        # ensure `postprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr`
        ends_with_dtype_op = bool(self.postprocessing) and isinstance(
            self.postprocessing[-1], (EnsureDtypeDescr, BinarizeDescr)
        )
        if not ends_with_dtype_op:
            self.postprocessing.append(
                EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype))
            )

        return self
class _OutputTensorConv(
    Converter[
        _OutputTensorDescr_v0_4,
        OutputTensorDescr,
        FileSource_,
        Optional[FileSource_],
        Mapping[_TensorName_v0_4, Mapping[str, int]],
    ]
):
    """Converter from a v0.4 output tensor description to `OutputTensorDescr`."""

    def _convert(
        self,
        src: _OutputTensorDescr_v0_4,
        tgt: "type[OutputTensorDescr] | type[dict[str, Any]]",
        test_tensor: FileSource_,
        sample_tensor: Optional[FileSource_],
        size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]],
    ) -> "OutputTensorDescr | dict[str, Any]":
        """Build the `tgt` representation of the output tensor `src`."""
        # TODO: split convert_axes into convert_output_axes and convert_input_axes
        output_axes: List[OutputAxis] = convert_axes(  # pyright: ignore[reportAssignmentType]
            src.axes,
            shape=src.shape,
            tensor_type="output",
            halo=src.halo,
            size_refs=size_refs,
        )

        data: Dict[str, Any] = {"type": src.data_type}
        if data["type"] == "bool":
            # boolean data explicitly lists its two possible values
            data["values"] = [False, True]

        tensor_id = TensorId(str(src.name))
        test_descr = FileDescr(source=test_tensor)
        if sample_tensor is None:
            sample_descr = None
        else:
            sample_descr = FileDescr(source=sample_tensor)

        postprocessing = [
            _convert_proc(old_proc, src.axes) for old_proc in src.postprocessing
        ]

        return tgt(
            axes=output_axes,
            id=tensor_id,
            test_tensor=test_descr,
            sample_tensor=sample_descr,
            data=data,  # pyright: ignore[reportArgumentType]
            postprocessing=postprocessing,
        )


_output_tensor_conv = _OutputTensorConv(_OutputTensorDescr_v0_4, OutputTensorDescr)
# Either kind of v0.5 tensor description: model input or model output.
TensorDescr = Union[InputTensorDescr, OutputTensorDescr]
def get_halos(
    tensors: Mapping[TensorId, TensorDescr],
    /,
) -> Dict[TensorId, Dict[AxisId, Tuple[int, int]]]:
    """Collect the halos of all input and output tensors.

    Every output axis with a halo contributes two entries: its own halo
    (to be cropped) and a halo for the input axis it references (to be padded),
    which is the output halo rescaled by the two axes' scales.

    Note:
    - Input halos are to be padded
    - Output halos are to be cropped
    """
    collected: Dict[TensorId, Dict[AxisId, Tuple[int, int]]] = {}
    non_input_descrs = (
        d for d in tensors.values() if not isinstance(d, InputTensorDescr)
    )
    for out_descr in non_input_descrs:
        for out_axis in out_descr.axes:
            if not isinstance(out_axis, WithHalo):
                continue

            ref = out_axis.size
            ref_axis = next(
                candidate
                for candidate in tensors[ref.tensor_id].axes
                if candidate.id == ref.axis_id
            )
            ref_scale = ref_axis.scale

            # set output halo (to be cropped)
            crop_width = out_axis.halo
            collected.setdefault(out_descr.id, {})[out_axis.id] = (
                crop_width,
                crop_width,
            )

            # set input halo (to be padded)
            pad_width = int(out_axis.halo / out_axis.scale * ref_scale)
            collected.setdefault(ref.tensor_id, {})[ref.axis_id] = (
                pad_width,
                pad_width,
            )

    return collected
def validate_tensors(
    tensors: Mapping[TensorId, Tuple[TensorDescr, Optional[NDArray[Any]]]],
    tensor_origin: Literal[
        "source", "test_tensor"
    ] = "source",  # for more precise error messages
    *,
    pad_inputs: Union[bool, Literal["allow"]] = True,
    crop_outputs: Union[bool, Literal["allow"]] = True,
):
    """Validate all inputs (and optionally output tensors) against their tensor descriptions.

    Args:
        tensors: Mapping of tensor id to a tuple of tensor description and optional numpy array.
        tensor_origin: String to use in error messages to indicate the origin of the tensors being validated.
        pad_inputs: Whether to apply/allow padding of inputs before shape comparison.
        crop_outputs: Whether to apply/allow cropping of outputs before shape comparison.

    Raises:
        ValueError: If an array's axes, data type, value range or axis sizes
            do not match the corresponding tensor description.
    """
    all_tensor_axes: Dict[TensorId, Dict[AxisId, Tuple[AnyAxis, Optional[int]]]] = {}

    def e_msg_location(d: TensorDescr):
        return f"{'inputs' if isinstance(d, InputTensorDescr) else 'outputs'}[{d.id}]"

    # first pass: collect the actual size (or None without array) of every axis
    for descr, array in tensors.values():
        if array is None:
            axis_sizes = {a.id: None for a in descr.axes}
        else:
            try:
                axis_sizes = descr.get_axis_sizes_for_array(array)
            except ValueError as e:
                # keep the original error as cause for easier debugging
                raise ValueError(f"{e_msg_location(descr)} {e}") from e

        all_tensor_axes[descr.id] = {a.id: (a, axis_sizes[a.id]) for a in descr.axes}

    # get halos to be padded/cropped to validate against halo-adjusted sizes
    io_halos = get_halos({k: v[0] for k, v in tensors.items()})

    # second pass: validate dtype, value range and axis sizes of given arrays
    for descr, array in tensors.values():
        if array is None:
            continue

        if descr.dtype in ("float32", "float64"):
            # tensors described as float may be given as any of these dtypes
            invalid_test_tensor_dtype = array.dtype.name not in (
                "float32",
                "float64",
                "uint8",
                "int8",
                "uint16",
                "int16",
                "uint32",
                "int32",
                "uint64",
                "int64",
            )
        else:
            invalid_test_tensor_dtype = array.dtype.name != descr.dtype

        if invalid_test_tensor_dtype:
            raise ValueError(
                f"{tensor_origin} data type '{array.dtype.name}' does not"
                + f" match described {e_msg_location(descr)}.dtype '{descr.dtype}'"
            )

        if array.min() > -1e-4 and array.max() < 1e-4:
            # message states the threshold that is actually checked above
            raise ValueError(
                "Output values are too small for reliable testing."
                + f" Values <=-1e-4 or >=1e-4 must be present in {tensor_origin}"
            )

        for a in descr.axes:
            actual_size = all_tensor_axes[descr.id][a.id][1]

            if actual_size is None:
                continue

            if a.size is None:
                continue

            # add padding width to actual tensor size
            total_axis_halo = sum(io_halos.get(descr.id, {}).get(a.id, (0, 0)))
            if isinstance(descr, InputTensorDescr):
                # pad input halos
                actual_size_with_halo = actual_size + total_axis_halo
                if pad_inputs is True:
                    check_sizes = {actual_size_with_halo}
                    size_hint = " (after padding input halo)"
                elif pad_inputs == "allow":
                    check_sizes = {actual_size, actual_size_with_halo}
                    size_hint = " (with or without padding input halo)"
                elif pad_inputs is False:
                    check_sizes = {actual_size}
                    size_hint = ""
                else:
                    assert_never(pad_inputs)

            elif isinstance(descr, OutputTensorDescr):
                # crop output halos
                actual_size_with_halo = max(0, actual_size - total_axis_halo)
                if crop_outputs is True:
                    check_sizes = {actual_size_with_halo}
                    size_hint = " (after cropping output halo)"
                elif crop_outputs == "allow":
                    check_sizes = {actual_size, actual_size_with_halo}
                    size_hint = " (with or without cropping output halo)"
                elif crop_outputs is False:
                    check_sizes = {actual_size}
                    size_hint = ""
                else:
                    assert_never(crop_outputs)
            else:
                assert_never(descr)

            del actual_size  # make sure we explicitly use unchanged or halo-adjusted size from here on

            if isinstance(a.size, int):
                if a.size not in check_sizes:
                    raise ValueError(
                        f"{e_msg_location(descr)}.axes[{a.id}]: {tensor_origin} axis "
                        + f"has incompatible size {check_sizes}{size_hint}, expected {a.size}"
                    )
            elif isinstance(a.size, (ParameterizedSize, DataDependentSize)):
                _ = try_all_raise_last(
                    (partial(a.size.validate_size, s) for s in check_sizes),
                    f"{e_msg_location(descr)}.axes[{a.id}]: {tensor_origin} axis ",
                )
            elif isinstance(a.size, SizeReference):
                ref_tensor_axes = all_tensor_axes.get(a.size.tensor_id)
                if ref_tensor_axes is None:
                    raise ValueError(
                        f"{e_msg_location(descr)}.axes[{a.id}].size.tensor_id: Unknown tensor"
                        + f" reference '{a.size.tensor_id}', available: {list(all_tensor_axes)}"
                    )

                ref_axis, ref_size = ref_tensor_axes.get(a.size.axis_id, (None, None))
                if ref_axis is None or ref_size is None:
                    raise ValueError(
                        f"{e_msg_location(descr)}.axes[{a.id}].size.axis_id: Unknown tensor axis"
                        + f" reference '{a.size.tensor_id}.{a.size.axis_id}', available: {list(ref_tensor_axes)}"
                    )

                if a.unit != ref_axis.unit:
                    raise ValueError(
                        f"{e_msg_location(descr)}.axes[{a.id}].size: `SizeReference` requires"
                        + " axis and reference axis to have the same `unit`, but"
                        + f" {a.unit}!={ref_axis.unit}"
                    )

                if (
                    expected_size := (
                        ref_size * ref_axis.scale / a.scale + a.size.offset
                    )
                ) not in check_sizes:
                    raise ValueError(
                        f"{e_msg_location(descr)}.{tensor_origin}: axis '{a.id}' of size"
                        + f" {check_sizes} invalid for referenced size {ref_size};"
                        + f" expected {expected_size}"
                    )
            else:
                assert_never(a.size)
# File description whose source must have a `.yaml` or `.yml` suffix
# (checked case-sensitively).
FileDescr_dependencies = Annotated[
    FileDescr_,
    WithSuffix((".yaml", ".yml"), case_sensitive=True),
    Field(examples=[dict(source="environment.yaml")]),
]
class _ArchitectureCallableDescr(Node):
    # Fields shared by architecture descriptions:
    # the name of a callable and the keyword arguments to call it with.
    callable: Annotated[Identifier, Field(examples=["MyNetworkClass", "get_my_model"])]
    """Identifier of the callable that returns a torch.nn.Module instance."""

    kwargs: Dict[str, YamlValue] = Field(
        default_factory=cast(Callable[[], Dict[str, YamlValue]], dict)
    )
    """key word arguments for the `callable`"""
class ArchitectureFromFileDescr(_ArchitectureCallableDescr, FileDescr):
    # Architecture callable that is defined in a source file.
    source: Annotated[FileSource, AfterValidator(wo_special_file_name)]
    """Architecture source file"""

    @model_serializer(mode="wrap", when_used="unless-none")
    def _serialize(self, nxt: SerializerFunctionWrapHandler, info: SerializationInfo):
        # serialization is delegated to the shared file description serializer
        return package_file_descr_serializer(self, nxt, info)
class ArchitectureFromLibraryDescr(_ArchitectureCallableDescr):
    # Architecture callable that is imported from a module instead of a source file.
    import_from: str
    """Where to import the callable from, i.e. `from <import_from> import <callable>`"""
class _ArchFileConv(
    Converter[
        _CallableFromFile_v0_4,
        ArchitectureFromFileDescr,
        Optional[Sha256],
        Dict[str, Any],
    ]
):
    """Converter from a v0.4 `<source>:<callable>` string to `ArchitectureFromFileDescr`."""

    def _convert(
        self,
        src: _CallableFromFile_v0_4,
        tgt: "type[ArchitectureFromFileDescr | dict[str, Any]]",
        sha256: Optional[Sha256],
        kwargs: Dict[str, Any],
    ) -> "ArchitectureFromFileDescr | dict[str, Any]":
        """Split `src` into source and callable name and build the `tgt` representation."""
        # a source starting with "http" is expected to carry one colon of its own,
        # any other source none; exactly one extra colon separates the callable name
        expected_source_colons = 1 if src.startswith("http") else 0
        if src.count(":") - expected_source_colons == 1:
            source, _, callable_ = src.rpartition(":")
        else:
            # not separable: use the whole string for both parts
            source = str(src)
            callable_ = str(src)

        return tgt(
            callable=Identifier(callable_),
            source=cast(FileSource_, source),
            sha256=sha256,
            kwargs=kwargs,
        )


_arch_file_conv = _ArchFileConv(_CallableFromFile_v0_4, ArchitectureFromFileDescr)
class _ArchLibConv(
    Converter[
        _CallableFromDepencency_v0_4, ArchitectureFromLibraryDescr, Dict[str, Any]
    ]
):
    """Converter from a v0.4 dotted `<module>.<callable>` string to `ArchitectureFromLibraryDescr`."""

    def _convert(
        self,
        src: _CallableFromDepencency_v0_4,
        tgt: "type[ArchitectureFromLibraryDescr | dict[str, Any]]",
        kwargs: Dict[str, Any],
    ) -> "ArchitectureFromLibraryDescr | dict[str, Any]":
        """Split `src` at its last dot into module path and callable name."""
        # everything before the last dot is the module path (empty without any dot)
        module_path, _, callable_name = src.rpartition(".")
        return tgt(
            import_from=module_path,
            callable=Identifier(callable_name),
            kwargs=kwargs,
        )


_arch_lib_conv = _ArchLibConv(
    _CallableFromDepencency_v0_4, ArchitectureFromLibraryDescr
)
class WeightsEntryDescrBase(FileDescr):
    # Base for the description of one weights entry.
    # Subclasses set the two class variables below.
    type: ClassVar[WeightsFormat]
    weights_format_name: ClassVar[str]  # human readable

    source: Annotated[FileSource, AfterValidator(wo_special_file_name)]
    """Source of the weights file."""

    authors: Optional[List[Author]] = None
    """Authors
    Either the person(s) that have trained this model resulting in the original weights file.
        (If this is the initial weights entry, i.e. it does not have a `parent`)
    Or the person(s) who have converted the weights to this weights format.
        (If this is a child weight, i.e. it has a `parent` field)
    """

    parent: Annotated[
        Optional[WeightsFormat], Field(examples=["pytorch_state_dict"])
    ] = None
    """The source weights these weights were converted from.
    For example, if a model's weights were converted from the `pytorch_state_dict` format to `torchscript`,
    The `pytorch_state_dict` weights entry has no `parent` and is the parent of the `torchscript` weights.
    All weight entries except one (the initial set of weights resulting from training the model),
    need to have this field."""

    comment: str = ""
    """A comment about this weights entry, for example how these weights were created."""

    @model_validator(mode="after")
    def _validate(self) -> Self:
        # an entry must not name its own weights format as `parent`
        if self.type == self.parent:
            raise ValueError("Weights entry can't be it's own parent.")

        return self

    @model_serializer(mode="wrap", when_used="unless-none")
    def _serialize(self, nxt: SerializerFunctionWrapHandler, info: SerializationInfo):
        # serialization is delegated to the shared file description serializer
        return package_file_descr_serializer(self, nxt, info)
class KerasHdf5WeightsDescr(WeightsEntryDescrBase):
    # Weights entry for the "keras_hdf5" weights format.
    type: ClassVar[WeightsFormat] = "keras_hdf5"
    weights_format_name: ClassVar[str] = "Keras HDF5"
    tensorflow_version: Version
    """TensorFlow version used to create these weights."""
class KerasV3WeightsDescr(WeightsEntryDescrBase):
    # Weights entry for the "keras_v3" weights format.
    type: ClassVar[WeightsFormat] = "keras_v3"
    weights_format_name: ClassVar[str] = "Keras v3"
    keras_version: Annotated[Version, Ge(Version(3))]  # at least version 3
    """Keras version used to create these weights."""
    backend: Tuple[Literal["tensorflow", "jax", "torch"], Version]  # (name, version)
    """Keras backend used to create these weights."""
    # narrows the inherited `source` to files with a (case-sensitive) `.keras` suffix
    source: Annotated[
        FileSource,
        AfterValidator(wo_special_file_name),
        WithSuffix(".keras", case_sensitive=True),
    ]
    """Source of the .keras weights file."""
# File description whose source must have a `.data` suffix (checked case-sensitively).
FileDescr_external_data = Annotated[
    FileDescr_,
    WithSuffix(".data", case_sensitive=True),
    Field(examples=[dict(source="weights.onnx.data")]),
]
class OnnxWeightsDescr(WeightsEntryDescrBase):
    type: ClassVar[WeightsFormat] = "onnx"
    weights_format_name: ClassVar[str] = "ONNX"
    opset_version: Annotated[int, Ge(7)]
    """ONNX opset version"""

    external_data: Optional[FileDescr_external_data] = None
    """Source of the external ONNX data file holding the weights.
    (If present **source** holds the ONNX architecture without weights)."""

    @model_validator(mode="after")
    def _validate_external_data_unique_file_name(self) -> Self:
        """Reject an `external_data` file that has the same file name as `source`."""
        if self.external_data is None:
            return self

        source_name = extract_file_name(self.source)
        external_name = extract_file_name(self.external_data.source)
        if source_name != external_name:
            return self

        raise ValueError(
            f"ONNX `external_data` file name '{external_name}'"
            + " must be different from ONNX `source` file name."
        )
class PytorchStateDictWeightsDescr(WeightsEntryDescrBase):
    # Weights entry for the "pytorch_state_dict" weights format.
    type: ClassVar[WeightsFormat] = "pytorch_state_dict"
    weights_format_name: ClassVar[str] = "Pytorch State Dict"
    # the architecture is given either by a source file or by an importable library
    architecture: Union[ArchitectureFromFileDescr, ArchitectureFromLibraryDescr]
    pytorch_version: Version
    """Version of the PyTorch library used.
    If `architecture.depencencies` is specified it has to include pytorch and any version pinning has to be compatible.
    """
    dependencies: Optional[FileDescr_dependencies] = None
    """Custom depencies beyond pytorch described in a Conda environment file.
    Allows to specify custom dependencies, see conda docs:
    - [Exporting an environment file across platforms](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#exporting-an-environment-file-across-platforms)
    - [Creating an environment file manually](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#creating-an-environment-file-manually)

    The conda environment file should include pytorch and any version pinning has to be compatible with
    **pytorch_version**.
    """
    strict: bool = True
    """Whether to allow missing or unexpected keys or to be strict about the architecture matching the state dict weights."""
class TensorflowJsWeightsDescr(WeightsEntryDescrBase):
    # Weights entry for the "tensorflow_js" weights format.
    type: ClassVar[WeightsFormat] = "tensorflow_js"
    weights_format_name: ClassVar[str] = "Tensorflow.js"
    tensorflow_version: Version
    """Version of the TensorFlow library used."""

    # same annotation as the inherited `source`; redeclared with its own documentation
    source: Annotated[FileSource, AfterValidator(wo_special_file_name)]
    """The multi-file weights.
    All required files/folders should be a zip archive."""
class TensorflowSavedModelBundleWeightsDescr(WeightsEntryDescrBase):
    # Weights entry for the "tensorflow_saved_model_bundle" weights format.
    type: ClassVar[WeightsFormat] = "tensorflow_saved_model_bundle"
    weights_format_name: ClassVar[str] = "Tensorflow Saved Model"
    tensorflow_version: Version
    """Version of the TensorFlow library used."""

    dependencies: Optional[FileDescr_dependencies] = None
    """Custom dependencies beyond tensorflow.
    Should include tensorflow and any version pinning has to be compatible with **tensorflow_version**."""

    # same annotation as the inherited `source`; redeclared with its own documentation
    source: Annotated[FileSource, AfterValidator(wo_special_file_name)]
    """The multi-file weights.
    All required files/folders should be a zip archive."""
class TorchscriptWeightsDescr(WeightsEntryDescrBase):
    # Weights entry for the "torchscript" weights format.
    type: ClassVar[WeightsFormat] = "torchscript"
    weights_format_name: ClassVar[str] = "TorchScript"
    pytorch_version: Version
    """Version of the PyTorch library used."""
# Union of the concrete weights entry descriptions (one per weights format).
SpecificWeightsDescr = Union[
    KerasHdf5WeightsDescr,
    KerasV3WeightsDescr,
    OnnxWeightsDescr,
    PytorchStateDictWeightsDescr,
    TensorflowJsWeightsDescr,
    TensorflowSavedModelBundleWeightsDescr,
    TorchscriptWeightsDescr,
]
class WeightsDescr(Node):
    # One optional weights entry per weights format.
    # Entries can also be accessed by weights format name via item access.
    keras_hdf5: Optional[KerasHdf5WeightsDescr] = None
    keras_v3: Optional[KerasV3WeightsDescr] = None
    onnx: Optional[OnnxWeightsDescr] = None
    pytorch_state_dict: Optional[PytorchStateDictWeightsDescr] = None
    tensorflow_js: Optional[TensorflowJsWeightsDescr] = None
    tensorflow_saved_model_bundle: Optional[TensorflowSavedModelBundleWeightsDescr] = (
        None
    )
    torchscript: Optional[TorchscriptWeightsDescr] = None

    @model_validator(mode="after")
    def check_entries(self) -> Self:
        """Check that weights are given and that `parent` references are consistent.

        A warning is issued (via `issue_warning`) unless exactly one entry has no `parent`.

        Raises:
            ValueError: If no weights entry is given, or if an entry's `parent`
                is not among the specified weights formats.
        """
        entries = {wtype for wtype, entry in self if entry is not None}

        if not entries:
            raise ValueError("Missing weights entry")

        entries_wo_parent = {
            wtype
            for wtype, entry in self
            if entry is not None and hasattr(entry, "parent") and entry.parent is None
        }
        if len(entries_wo_parent) != 1:
            issue_warning(
                "Exactly one weights entry may not specify the `parent` field (got"
                + " {value}). That entry is considered the original set of model weights."
                + " Other weight formats are created through conversion of the orignal or"
                + " already converted weights. They have to reference the weights format"
                + " they were converted from as their `parent`.",
                value=len(entries_wo_parent),
                field="weights",
            )

        for wtype, entry in self:
            if entry is None:
                continue

            assert hasattr(entry, "type")
            assert hasattr(entry, "parent")
            assert wtype == entry.type
            if (
                entry.parent is not None and entry.parent not in entries
            ):  # self reference checked for `parent` field
                raise ValueError(
                    f"`weights.{wtype}.parent={entry.parent} not in specified weight"
                    + f" formats: {entries}"
                )

        return self

    def __getitem__(
        self,
        key: WeightsFormat,
    ):
        """Return the weights entry for the weights format `key`.

        Raises:
            KeyError: If `key` is not a known weights format or its entry is `None`.
        """
        # explicit branches (instead of `getattr`) keep the attribute types visible
        if key == "keras_hdf5":
            ret = self.keras_hdf5
        elif key == "keras_v3":
            ret = self.keras_v3
        elif key == "onnx":
            ret = self.onnx
        elif key == "pytorch_state_dict":
            ret = self.pytorch_state_dict
        elif key == "tensorflow_js":
            ret = self.tensorflow_js
        elif key == "tensorflow_saved_model_bundle":
            ret = self.tensorflow_saved_model_bundle
        elif key == "torchscript":
            ret = self.torchscript
        else:
            raise KeyError(key)

        if ret is None:
            raise KeyError(key)

        return ret

    @overload
    def __setitem__(
        self, key: Literal["keras_hdf5"], value: Optional[KerasHdf5WeightsDescr]
    ) -> None: ...
    @overload
    def __setitem__(
        self, key: Literal["keras_v3"], value: Optional[KerasV3WeightsDescr]
    ) -> None: ...
    @overload
    def __setitem__(
        self, key: Literal["onnx"], value: Optional[OnnxWeightsDescr]
    ) -> None: ...
    @overload
    def __setitem__(
        self,
        key: Literal["pytorch_state_dict"],
        value: Optional[PytorchStateDictWeightsDescr],
    ) -> None: ...
    @overload
    def __setitem__(
        self, key: Literal["tensorflow_js"], value: Optional[TensorflowJsWeightsDescr]
    ) -> None: ...
    @overload
    def __setitem__(
        self,
        key: Literal["tensorflow_saved_model_bundle"],
        value: Optional[TensorflowSavedModelBundleWeightsDescr],
    ) -> None: ...
    @overload
    def __setitem__(
        self, key: Literal["torchscript"], value: Optional[TorchscriptWeightsDescr]
    ) -> None: ...

    def __setitem__(
        self,
        key: WeightsFormat,
        value: Optional[SpecificWeightsDescr],
    ):
        """Set (or clear with `None`) the weights entry for the weights format `key`.

        Raises:
            TypeError: If `value` is neither `None` nor of the type matching `key`.
            KeyError: If `key` is not a known weights format.
        """
        if key == "keras_hdf5":
            if value is not None and not isinstance(value, KerasHdf5WeightsDescr):
                raise TypeError(
                    f"Expected KerasHdf5WeightsDescr or None for key 'keras_hdf5', got {type(value)}"
                )
            self.keras_hdf5 = value
        elif key == "keras_v3":
            if value is not None and not isinstance(value, KerasV3WeightsDescr):
                raise TypeError(
                    f"Expected KerasV3WeightsDescr or None for key 'keras_v3', got {type(value)}"
                )
            self.keras_v3 = value
        elif key == "onnx":
            if value is not None and not isinstance(value, OnnxWeightsDescr):
                raise TypeError(
                    f"Expected OnnxWeightsDescr or None for key 'onnx', got {type(value)}"
                )
            self.onnx = value
        elif key == "pytorch_state_dict":
            if value is not None and not isinstance(
                value, PytorchStateDictWeightsDescr
            ):
                raise TypeError(
                    f"Expected PytorchStateDictWeightsDescr or None for key 'pytorch_state_dict', got {type(value)}"
                )
            self.pytorch_state_dict = value
        elif key == "tensorflow_js":
            if value is not None and not isinstance(value, TensorflowJsWeightsDescr):
                raise TypeError(
                    f"Expected TensorflowJsWeightsDescr or None for key 'tensorflow_js', got {type(value)}"
                )
            self.tensorflow_js = value
        elif key == "tensorflow_saved_model_bundle":
            if value is not None and not isinstance(
                value, TensorflowSavedModelBundleWeightsDescr
            ):
                raise TypeError(
                    f"Expected TensorflowSavedModelBundleWeightsDescr or None for key 'tensorflow_saved_model_bundle', got {type(value)}"
                )
            self.tensorflow_saved_model_bundle = value
        elif key == "torchscript":
            if value is not None and not isinstance(value, TorchscriptWeightsDescr):
                raise TypeError(
                    f"Expected TorchscriptWeightsDescr or None for key 'torchscript', got {type(value)}"
                )
            self.torchscript = value
        else:
            raise KeyError(key)

    @property
    def available_formats(self) -> Dict[WeightsFormat, SpecificWeightsDescr]:
        """All weights entries that are not `None`, keyed by weights format."""
        return {
            **({} if self.keras_hdf5 is None else {"keras_hdf5": self.keras_hdf5}),
            **({} if self.keras_v3 is None else {"keras_v3": self.keras_v3}),
            **({} if self.onnx is None else {"onnx": self.onnx}),
            **(
                {}
                if self.pytorch_state_dict is None
                else {"pytorch_state_dict": self.pytorch_state_dict}
            ),
            **(
                {}
                if self.tensorflow_js is None
                else {"tensorflow_js": self.tensorflow_js}
            ),
            **(
                {}
                if self.tensorflow_saved_model_bundle is None
                else {
                    "tensorflow_saved_model_bundle": self.tensorflow_saved_model_bundle
                }
            ),
            **({} if self.torchscript is None else {"torchscript": self.torchscript}),
        }

    @property
    def missing_formats(self) -> Set[WeightsFormat]:
        """All weights formats of `WeightsFormat` without an entry."""
        # build the mapping once instead of once per weights format
        available = self.available_formats
        return {wf for wf in get_args(WeightsFormat) if wf not in available}
class ModelId(ResourceId):
    # A `ResourceId` subtype without additions; used as the type of `LinkedModel.id`.
    pass
class LinkedModel(LinkedResourceBase):
    """Reference to a bioimage.io model."""

    # narrows the id to a `ModelId`
    id: ModelId
    """A valid model `id` from the bioimage.io collection."""
class _DataDepSize(NamedTuple):
    # (min, max) pair of strict integers where `max` may be None.
    # NOTE(review): presumably the size bounds of a data dependent axis
    # (used as a value type in `_AxisSizes.outputs`) -- confirm against callers.
    min: StrictInt
    max: Optional[StrictInt]
class _AxisSizes(NamedTuple):
    """the lengths of all axes of model inputs and outputs"""

    # keyed by (tensor id, axis id)
    inputs: Dict[Tuple[TensorId, AxisId], int]
    # output sizes may also be given as a `_DataDepSize`
    outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]]
class _TensorSizes(NamedTuple):
    """_AxisSizes as nested dicts"""

    # tensor id -> axis id -> size
    inputs: Dict[TensorId, Dict[AxisId, int]]
    outputs: Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]
class ReproducibilityTolerance(Node, extra="allow"):
    """Describes what small numerical differences -- if any -- may be tolerated
    in the generated output when executing in different environments.

    A tensor element *output* is considered mismatched to the **test_tensor** if
    abs(*output* - **test_tensor**) > **absolute_tolerance** + **relative_tolerance** * abs(**test_tensor**).
    (Internally we call [numpy.testing.assert_allclose](https://numpy.org/doc/stable/reference/generated/numpy.testing.assert_allclose.html).)

    Motivation:
        For testing we can request the respective deep learning frameworks to be as
        reproducible as possible by setting seeds and choosing deterministic algorithms,
        but differences in operating systems, available hardware and installed drivers
        may still lead to numerical differences.
    """

    relative_tolerance: RelativeTolerance = 1e-3
    """Maximum relative tolerance of reproduced test tensor."""

    absolute_tolerance: AbsoluteTolerance = 1e-3
    """Maximum absolute tolerance of reproduced test tensor."""

    mismatched_elements_per_million: MismatchedElementsPerMillion = 100
    """Maximum number of mismatched elements/pixels per million to tolerate."""

    # NOTE(review): an empty sequence presumably means "applies to all outputs" -- confirm
    output_ids: Sequence[TensorId] = ()
    """Limits the output tensor IDs these reproducibility details apply to."""

    # NOTE(review): an empty sequence presumably means "applies to all weights formats" -- confirm
    weights_formats: Sequence[WeightsFormat] = ()
    """Limits the weights formats these details apply to."""
class BiasRisksLimitations(Node, extra="allow"):
    """Known biases, risks, technical limitations, and recommendations for model use."""

    # default text; `dedent` strips the common leading indentation
    known_biases: str = dedent("""\
        In general bioimage models may suffer from biases caused by:

        - Imaging protocol dependencies
        - Use of a specific cell type
        - Species-specific training data limitations

        """)
    """Biases in training data or model behavior."""

    # default text; `dedent` strips the common leading indentation
    risks: str = dedent("""\
        Common risks in bioimage analysis include:

        - Erroneously assuming generalization to unseen experimental conditions
        - Trusting (overconfident) model outputs without validation
        - Misinterpretation of results

        """)
    """Potential risks in the context of bioimage analysis."""

    limitations: Optional[str] = None
    """Technical limitations and failure modes."""

    recommendations: str = "Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model."
    """Mitigation strategies regarding `known_biases`, `risks`, and `limitations`, as well as applicable best practices.

    Consider:
    - How to use a validation dataset?
    - How to manually validate?
    - Feasibility of domain adaptation for different experimental setups?

    """

    def format_md(self) -> str:
        """Format the biases, risks, limitations and recommendations as a Markdown text.

        The "Limitations" header is only included if `limitations` is not `None`.
        """
        if self.limitations is None:
            limitations_header = ""
        else:
            limitations_header = "## Limitations\n\n"

        return f"""# Bias, Risks, and Limitations

{self.known_biases}

{self.risks}

{limitations_header}{self.limitations or ""}

## Recommendations

{self.recommendations}

"""
class TrainingDetails(Node, extra="allow"):
    # Every field is optional or has an empty default.

    # --- data handling ---

    training_preprocessing: Optional[str] = None
    """Detailed image preprocessing steps during model training:

    Mention:
    - *Normalization methods*
    - *Augmentation strategies*
    - *Resizing/resampling procedures*
    - *Artifact handling*

    """

    # --- training schedule ---

    training_epochs: Optional[float] = None
    """Number of training epochs."""

    training_batch_size: Optional[float] = None
    """Batch size used in training."""

    initial_learning_rate: Optional[float] = None
    """Initial learning rate used in training."""

    learning_rate_schedule: Optional[str] = None
    """Learning rate schedule used in training."""

    # --- objective and optimization ---

    loss_function: Optional[str] = None
    """Loss function used in training, e.g. nn.MSELoss."""

    loss_function_kwargs: Dict[str, YamlValue] = Field(
        default_factory=cast(Callable[[], Dict[str, YamlValue]], dict)
    )
    """key word arguments for the `loss_function`"""

    optimizer: Optional[str] = None
    """optimizer, e.g. torch.optim.Adam"""

    optimizer_kwargs: Dict[str, YamlValue] = Field(
        default_factory=cast(Callable[[], Dict[str, YamlValue]], dict)
    )
    """key word arguments for the `optimizer`"""

    regularization: Optional[str] = None
    """Regularization techniques used during training, e.g. drop-out or weight decay."""

    # --- cost ---

    training_duration: Optional[float] = None
    """Total training duration in hours."""
class Evaluation(Node, extra="allow"):
    # `results` is a table: one row per entry in `metrics`,
    # one column per entry in `evaluation_factors`
    # (enforced by `_validate_list_lengths`).

    model_id: Optional[ModelId] = None
    """Model being evaluated."""

    dataset_id: DatasetId
    """Dataset used for evaluation."""

    dataset_source: HttpUrl
    """Source of the dataset."""

    dataset_role: Literal["train", "validation", "test", "independent", "unknown"]
    """Role of the dataset used for evaluation.

    - `train`: dataset was (part of) the training data
    - `validation`: dataset was (part of) the validation data used during training, e.g. used for model selection or hyperparameter tuning
    - `test`: dataset was (part of) the designated test data; not used during training or validation, but acquired from the same source/distribution as training data
    - `independent`: dataset is entirely independent test data; not used during training or validation, and acquired from a different source/distribution than training data
    - `unknown`: role of the dataset is unknown; choose this if you are not certain if (a subset) of the data was seen by the model during training.
    """

    sample_count: int
    """Number of evaluated samples."""

    evaluation_factors: List[Annotated[str, MaxLen(16)]]
    """(Abbreviations of) each evaluation factor.

    Evaluation factors are criteria along which model performance is evaluated, e.g. different image conditions
    like 'low SNR', 'high cell density', or different biological conditions like 'cell type A', 'cell type B'.
    An 'overall' factor may be included to summarize performance across all conditions.
    """

    evaluation_factors_long: List[str]
    """Descriptions (long form) of each evaluation factor."""

    metrics: List[Annotated[str, MaxLen(16)]]
    """(Abbreviations of) metrics used for evaluation."""

    metrics_long: List[str]
    """Description of each metric used."""

    results: List[List[Union[str, float, int]]]
    """Results for each metric (rows; outer list) and each evaluation factor (columns; inner list)."""

    results_summary: Optional[str] = None
    """Interpretation of results for general audience.

    Consider:
    - Overall model performance
    - Comparison to existing methods
    - Limitations and areas for improvement

"""

    @model_validator(mode="after")
    def _validate_list_lengths(self) -> Self:
        """Check that abbreviations, long forms and the results table agree in size."""
        n_factors = len(self.evaluation_factors)
        n_metrics = len(self.metrics)

        if n_factors != len(self.evaluation_factors_long):
            raise ValueError(
                "`evaluation_factors` and `evaluation_factors_long` must have the same length"
            )

        if n_metrics != len(self.metrics_long):
            raise ValueError("`metrics` and `metrics_long` must have the same length")

        if len(self.results) != n_metrics:
            raise ValueError("`results` must have the same number of rows as `metrics`")

        if any(len(row) != n_factors for row in self.results):
            raise ValueError(
                "`results` must have the same number of columns (in every row) as `evaluation_factors`"
            )

        return self

    def format_md(self):
        """Render this evaluation as Markdown, including a table of the results."""
        header_cells = ["Metric"] + self.evaluation_factors
        table_rows = [header_cells, ["---"] * len(header_cells)]
        for metric, row in zip(self.metrics, self.results):
            table_rows.append([metric] + [str(r) for r in row])

        results_table = "".join(
            "| " + " | ".join(cells) + " |\n" for cells in table_rows
        )
        factors = "".join(
            f"\n - {ef}: {efl}"
            for ef, efl in zip(self.evaluation_factors, self.evaluation_factors_long)
        )
        metrics = "".join(
            f"\n - {em}: {eml}" for em, eml in zip(self.metrics, self.metrics_long)
        )
        evaluated_model = self.model_id or "this"
        summary = self.results_summary or "missing"

        return f"""## Testing Data, Factors & Metrics

Evaluation of {evaluated_model} model on the {self.dataset_id} dataset (dataset role: {self.dataset_role}).

### Testing Data

- **Source:** [{self.dataset_id}]({self.dataset_source})
- **Size:** {self.sample_count} evaluated samples

### Factors
{factors}

### Metrics
{metrics}

## Results

### Quantitative Results

{results_table}

### Summary

{summary}

"""
class EnvironmentalImpact(Node, extra="allow"):
    """Environmental considerations for model training and deployment.

    Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
    """

    hardware_type: Optional[str] = None
    """GPU/CPU specifications"""

    hours_used: Optional[float] = None
    """Total compute hours"""

    cloud_provider: Optional[str] = None
    """If applicable"""

    compute_region: Optional[str] = None
    """Geographic location"""

    co2_emitted: Optional[float] = None
    """kg CO2 equivalent

    Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
    """

    def format_md(self):
        """Filled Markdown template section following [Hugging Face Model Card Template](https://huggingface.co/docs/hub/en/model-card-annotated)."""
        # nothing to report if this instance equals a default-constructed one
        if self == self.__class__():
            return ""

        # (label, value, text appended after the value), in output order
        entries = (
            ("Hardware Type", self.hardware_type, ""),
            ("Hours used", self.hours_used, ""),
            ("Cloud Provider", self.cloud_provider, ""),
            ("Compute Region", self.compute_region, ""),
            ("Carbon Emitted", self.co2_emitted, " kg CO2e"),
        )
        bullets = "".join(
            f"- **{label}:** {value}{suffix}\n"
            for label, value, suffix in entries
            if value is not None
        )

        return "# Environmental Impact\n\n" + bullets + "\n"
class BioimageioConfig(Node, extra="allow"):
    # --- reproducibility ---

    reproducibility_tolerance: Sequence[ReproducibilityTolerance] = ()
    """Tolerances to allow when reproducing the model's test outputs
    from the model's test inputs.
    Only the first entry matching tensor id and weights format is considered.
    """

    # --- general model card information ---

    funded_by: Optional[str] = None
    """Funding agency, grant number if applicable"""

    architecture_type: Optional[Annotated[str, MaxLen(32)]] = (
        None  # TODO: add to differentiated tags
    )
    """Model architecture type, e.g., 3D U-Net, ResNet, transformer"""

    architecture_description: Optional[str] = None
    """Text description of model architecture."""

    modality: Optional[str] = None  # TODO: add to differentiated tags
    """Input modality, e.g., fluorescence microscopy, electron microscopy"""

    target_structure: List[str] = Field(  # TODO: add to differentiated tags
        default_factory=cast(Callable[[], List[str]], list)
    )
    """Biological structure(s) the model is designed to analyze, e.g., nuclei, mitochondria, cells"""

    task: Optional[str] = None  # TODO: add to differentiated tags
    """Bioimage-specific task type, e.g., segmentation, classification, detection, denoising"""

    new_version: Optional[ModelId] = None
    """A new version of this model exists with a different model id."""

    # --- intended use and caveats ---

    out_of_scope_use: Optional[str] = None
    """Describe how the model may be misused in bioimage analysis contexts and what users should **not** do with the model."""

    bias_risks_limitations: BiasRisksLimitations = Field(
        default_factory=BiasRisksLimitations.model_construct
    )
    """Description of known bias, risks, and technical limitations for in-scope model use."""

    # --- training and resource requirements ---

    model_parameter_count: Optional[int] = None
    """Total number of model parameters."""

    training: TrainingDetails = Field(default_factory=TrainingDetails.model_construct)
    """Details on how the model was trained."""

    inference_time: Optional[str] = None
    """Average inference time per image/tile. Specify hardware and image size. Multiple examples can be given."""

    memory_requirements_inference: Optional[str] = None
    """GPU memory needed for inference. Multiple examples with different image size can be given."""

    memory_requirements_training: Optional[str] = None
    """GPU memory needed for training. Multiple examples with different image/batch sizes can be given."""

    # --- evaluation and impact ---

    evaluations: List[Evaluation] = Field(
        default_factory=cast(Callable[[], List[Evaluation]], list)
    )
    """Quantitative model evaluations.

    Note:
        At the moment we recommend to include only a single test dataset
        (with evaluation factors that may mark subsets of the dataset)
        to avoid confusion and make the presentation of results cleaner.
    """

    environmental_impact: EnvironmentalImpact = Field(
        default_factory=EnvironmentalImpact.model_construct
    )
    """Environmental considerations for model training and deployment"""
class Config(Node, extra="allow"):
    # bioimage.io specific settings; defaults to an unvalidated default instance
    bioimageio: BioimageioConfig = Field(
        default_factory=BioimageioConfig.model_construct
    )

    # free-form YAML value; NOTE(review): presumably consumed by StarDist tooling -- confirm
    stardist: YamlValue = None
3474class ModelDescr(GenericModelDescrBase):
3475 """Specification of the fields used in a bioimage.io-compliant RDF to describe AI models with pretrained weights.
3476 These fields are typically stored in a YAML file which we call a model resource description file (model RDF).
3477 """
3479 implemented_format_version: ClassVar[Literal["0.5.10"]] = "0.5.10"
3480 if TYPE_CHECKING:
3481 format_version: Literal["0.5.10"] = "0.5.10"
3482 else:
3483 format_version: Literal["0.5.10"]
3484 """Version of the bioimage.io model description specification used.
3485 When creating a new model always use the latest micro/patch version described here.
3486 The `format_version` is important for any consumer software to understand how to parse the fields.
3487 """
3489 implemented_type: ClassVar[Literal["model"]] = "model"
3490 if TYPE_CHECKING:
3491 type: Literal["model"] = "model"
3492 else:
3493 type: Literal["model"]
3494 """Specialized resource type 'model'"""
3496 id: Optional[ModelId] = None
3497 """bioimage.io-wide unique resource identifier
3498 assigned by bioimage.io; version **un**specific."""
3500 authors: FAIR[List[Author]] = Field(
3501 default_factory=cast(Callable[[], List[Author]], list)
3502 )
3503 """The authors are the creators of the model RDF and the primary points of contact."""
3505 documentation: FAIR[Optional[FileSource_documentation]] = None
3506 """URL or relative path to a markdown file with additional documentation.
3507 The recommended documentation file name is `README.md`. An `.md` suffix is mandatory.
3508 The documentation should include a '#[#] Validation' (sub)section
3509 with details on how to quantitatively validate the model on unseen data."""
3511 @field_validator("documentation", mode="after")
3512 @classmethod
3513 def _validate_documentation(
3514 cls, value: Optional[FileSource_documentation]
3515 ) -> Optional[FileSource_documentation]:
3516 if not get_validation_context().perform_io_checks or value is None:
3517 return value
3519 doc_reader = get_reader(value)
3520 doc_content = doc_reader.read().decode(encoding="utf-8")
3521 if not re.search("#.*[vV]alidation", doc_content):
3522 issue_warning(
3523 "No '# Validation' (sub)section found in {value}.",
3524 value=value,
3525 field="documentation",
3526 )
3528 return value
3530 inputs: NotEmpty[Sequence[InputTensorDescr]]
3531 """Describes the input tensors expected by this model."""
3533 @field_validator("inputs", mode="after")
3534 @classmethod
3535 def _validate_input_axes(
3536 cls, inputs: Sequence[InputTensorDescr]
3537 ) -> Sequence[InputTensorDescr]:
3538 input_size_refs = cls._get_axes_with_independent_size(inputs)
3540 for i, ipt in enumerate(inputs):
3541 valid_independent_refs: Dict[
3542 Tuple[TensorId, AxisId],
3543 Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
3544 ] = {
3545 **{
3546 (ipt.id, a.id): (ipt, a, a.size)
3547 for a in ipt.axes
3548 if not isinstance(a, BatchAxis)
3549 and isinstance(a.size, (int, ParameterizedSize))
3550 },
3551 **input_size_refs,
3552 }
3553 for a, ax in enumerate(ipt.axes):
3554 cls._validate_axis(
3555 "inputs",
3556 i=i,
3557 tensor_id=ipt.id,
3558 a=a,
3559 axis=ax,
3560 valid_independent_refs=valid_independent_refs,
3561 )
3562 return inputs
3564 @staticmethod
3565 def _validate_axis(
3566 field_name: str,
3567 i: int,
3568 tensor_id: TensorId,
3569 a: int,
3570 axis: AnyAxis,
3571 valid_independent_refs: Dict[
3572 Tuple[TensorId, AxisId],
3573 Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
3574 ],
3575 ):
3576 if isinstance(axis, BatchAxis) or isinstance(
3577 axis.size, (int, ParameterizedSize, DataDependentSize)
3578 ):
3579 return
3580 elif not isinstance(axis.size, SizeReference):
3581 assert_never(axis.size)
3583 # validate axis.size SizeReference
3584 ref = (axis.size.tensor_id, axis.size.axis_id)
3585 if ref not in valid_independent_refs:
3586 raise ValueError(
3587 "Invalid tensor axis reference at"
3588 + f" {field_name}[{i}].axes[{a}].size: {axis.size}."
3589 )
3590 if ref == (tensor_id, axis.id):
3591 raise ValueError(
3592 "Self-referencing not allowed for"
3593 + f" {field_name}[{i}].axes[{a}].size: {axis.size}"
3594 )
3595 if axis.type == "channel":
3596 if valid_independent_refs[ref][1].type != "channel":
3597 raise ValueError(
3598 "A channel axis' size may only reference another fixed size"
3599 + " channel axis."
3600 )
3601 if isinstance(axis.channel_names, str) and "{i}" in axis.channel_names:
3602 ref_size = valid_independent_refs[ref][2]
3603 assert isinstance(ref_size, int), (
3604 "channel axis ref (another channel axis) has to specify fixed"
3605 + " size"
3606 )
3607 generated_channel_names = [
3608 Identifier(axis.channel_names.format(i=i))
3609 for i in range(1, ref_size + 1)
3610 ]
3611 axis.channel_names = generated_channel_names
3613 if (ax_unit := getattr(axis, "unit", None)) != (
3614 ref_unit := getattr(valid_independent_refs[ref][1], "unit", None)
3615 ):
3616 raise ValueError(
3617 "The units of an axis and its reference axis need to match, but"
3618 + f" '{ax_unit}' != '{ref_unit}'."
3619 )
3620 ref_axis = valid_independent_refs[ref][1]
3621 if isinstance(ref_axis, BatchAxis):
3622 raise ValueError(
3623 f"Invalid reference axis '{ref_axis.id}' for {tensor_id}.{axis.id}"
3624 + " (a batch axis is not allowed as reference)."
3625 )
3627 if isinstance(axis, WithHalo):
3628 min_size = axis.size.get_size(axis, ref_axis, n=0)
3629 if (min_size - 2 * axis.halo) < 1:
3630 raise ValueError(
3631 f"axis {axis.id} with minimum size {min_size} is too small for halo"
3632 + f" {axis.halo}."
3633 )
3635 ref_halo = axis.halo * axis.scale / ref_axis.scale
3636 if ref_halo != int(ref_halo):
3637 raise ValueError(
3638 f"Inferred halo for {'.'.join(ref)} is not an integer ({ref_halo} ="
3639 + f" {tensor_id}.{axis.id}.halo {axis.halo}"
3640 + f" * {tensor_id}.{axis.id}.scale {axis.scale}"
3641 + f" / {'.'.join(ref)}.scale {ref_axis.scale})."
3642 )
3644 def validate_input_tensors(
3645 self,
3646 sources: Union[
3647 Sequence[NDArray[Any]], Mapping[TensorId, Optional[NDArray[Any]]]
3648 ],
3649 *,
3650 pad_inputs: Union[bool, Literal["allow"]] = True,
3651 crop_outputs: Union[bool, Literal["allow"]] = True,
3652 ) -> Mapping[TensorId, Optional[NDArray[Any]]]:
3653 """Check if the given input tensors match the model's input tensor descriptions.
3654 This includes checks of tensor shapes and dtypes, but not of the actual values.
3655 """
3656 if not isinstance(sources, collections.abc.Mapping):
3657 sources = {descr.id: tensor for descr, tensor in zip(self.inputs, sources)}
3659 tensors = {
3660 **{descr.id: (descr, sources.get(descr.id)) for descr in self.inputs},
3661 **{ # outputs are required for halo
3662 descr.id: (descr, None) for descr in self.outputs
3663 },
3664 }
3665 validate_tensors(tensors, pad_inputs=pad_inputs, crop_outputs=crop_outputs)
3667 return sources
3669 @model_validator(mode="after")
3670 def _validate_test_tensors(self) -> Self:
3671 if not get_validation_context().perform_io_checks:
3672 return self
3674 test_inputs = {
3675 descr.id: (
3676 descr,
3677 None if descr.test_tensor is None else load_array(descr.test_tensor),
3678 )
3679 for descr in self.inputs
3680 }
3681 test_outputs = {
3682 descr.id: (
3683 descr,
3684 None if descr.test_tensor is None else load_array(descr.test_tensor),
3685 )
3686 for descr in self.outputs
3687 }
3689 validate_tensors(
3690 {**test_inputs, **test_outputs},
3691 tensor_origin="test_tensor",
3692 pad_inputs="allow",
3693 crop_outputs="allow",
3694 )
3696 for rep_tol in self.config.bioimageio.reproducibility_tolerance:
3697 if not rep_tol.absolute_tolerance:
3698 continue
3700 if rep_tol.output_ids:
3701 out_arrays = {
3702 k: v[1] for k, v in test_outputs.items() if k in rep_tol.output_ids
3703 }
3704 else:
3705 out_arrays = {k: v[1] for k, v in test_outputs.items()}
3707 for out_id, array in out_arrays.items():
3708 if array is None:
3709 continue
3711 if rep_tol.absolute_tolerance > (max_test_value := array.max()) * 0.01:
3712 raise ValueError(
3713 "config.bioimageio.reproducibility_tolerance.absolute_tolerance="
3714 + f"{rep_tol.absolute_tolerance} > 0.01*{max_test_value}"
3715 + f" (1% of the maximum value of the test tensor '{out_id}')"
3716 )
3718 return self
3720 @model_validator(mode="after")
3721 def _validate_tensor_references_in_proc_kwargs(self, info: ValidationInfo) -> Self:
3722 ipt_refs = {t.id for t in self.inputs}
3723 missing_refs = [
3724 k["reference_tensor"]
3725 for k in [p.kwargs for ipt in self.inputs for p in ipt.preprocessing]
3726 + [p.kwargs for out in self.outputs for p in out.postprocessing]
3727 if "reference_tensor" in k
3728 and k["reference_tensor"] is not None
3729 and k["reference_tensor"] not in ipt_refs
3730 ]
3732 if missing_refs:
3733 raise ValueError(
3734 f"`reference_tensor`s {missing_refs} not found. Valid input tensor"
3735 + f" references are: {ipt_refs}."
3736 )
3738 return self
3740 name: Annotated[
3741 str,
3742 RestrictCharacters(string.ascii_letters + string.digits + "_+- ()"),
3743 MinLen(5),
3744 MaxLen(128),
3745 warn(MaxLen(64), "Name longer than 64 characters.", INFO),
3746 ]
3747 """A human-readable name of this model.
3748 It should be no longer than 64 characters
3749 and may only contain letters, numbers, underscores, plus and minus signs, parentheses and spaces.
3750 We recommend choosing a name that refers to the model's task and image modality.
3751 """
3753 outputs: NotEmpty[Sequence[OutputTensorDescr]]
3754 """Describes the output tensors."""
3756 @field_validator("outputs", mode="after")
3757 @classmethod
3758 def _validate_tensor_ids(
3759 cls, outputs: Sequence[OutputTensorDescr], info: ValidationInfo
3760 ) -> Sequence[OutputTensorDescr]:
3761 tensor_ids = [
3762 t.id for t in info.data.get("inputs", []) + info.data.get("outputs", [])
3763 ]
3764 duplicate_tensor_ids: List[str] = []
3765 seen: Set[str] = set()
3766 for t in tensor_ids:
3767 if t in seen:
3768 duplicate_tensor_ids.append(t)
3770 seen.add(t)
3772 if duplicate_tensor_ids:
3773 raise ValueError(f"Duplicate tensor ids: {duplicate_tensor_ids}")
3775 return outputs
3777 @staticmethod
3778 def _get_axes_with_parameterized_size(
3779 io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]],
3780 ):
3781 return {
3782 f"{t.id}.{a.id}": (t, a, a.size)
3783 for t in io
3784 for a in t.axes
3785 if not isinstance(a, BatchAxis) and isinstance(a.size, ParameterizedSize)
3786 }
3788 @staticmethod
3789 def _get_axes_with_independent_size(
3790 io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]],
3791 ):
3792 return {
3793 (t.id, a.id): (t, a, a.size)
3794 for t in io
3795 for a in t.axes
3796 if not isinstance(a, BatchAxis)
3797 and isinstance(a.size, (int, ParameterizedSize))
3798 }
3800 @field_validator("outputs", mode="after")
3801 @classmethod
3802 def _validate_output_axes(
3803 cls, outputs: List[OutputTensorDescr], info: ValidationInfo
3804 ) -> List[OutputTensorDescr]:
3805 input_size_refs = cls._get_axes_with_independent_size(
3806 info.data.get("inputs", [])
3807 )
3808 output_size_refs = cls._get_axes_with_independent_size(outputs)
3810 for i, out in enumerate(outputs):
3811 valid_independent_refs: Dict[
3812 Tuple[TensorId, AxisId],
3813 Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
3814 ] = {
3815 **{
3816 (out.id, a.id): (out, a, a.size)
3817 for a in out.axes
3818 if not isinstance(a, BatchAxis)
3819 and isinstance(a.size, (int, ParameterizedSize))
3820 },
3821 **input_size_refs,
3822 **output_size_refs,
3823 }
3824 for a, ax in enumerate(out.axes):
3825 cls._validate_axis(
3826 "outputs",
3827 i,
3828 out.id,
3829 a,
3830 ax,
3831 valid_independent_refs=valid_independent_refs,
3832 )
3834 return outputs
3836 packaged_by: List[Author] = Field(
3837 default_factory=cast(Callable[[], List[Author]], list)
3838 )
3839 """The persons that have packaged and uploaded this model.
3840 Only required if those persons differ from the `authors`."""
3842 parent: Optional[LinkedModel] = None
3843 """The model from which this model is derived, e.g. by fine-tuning the weights."""
3845 @model_validator(mode="after")
3846 def _validate_parent_is_not_self(self) -> Self:
3847 if self.parent is not None and self.parent.id == self.id:
3848 raise ValueError("A model description may not reference itself as parent.")
3850 return self
3852 run_mode: Annotated[
3853 Optional[RunMode],
3854 warn(None, "Run mode '{value}' has limited support across consumer softwares."),
3855 ] = None
3856 """Custom run mode for this model: for more complex prediction procedures like test time
3857 data augmentation that currently cannot be expressed in the specification.
3858 No standard run modes are defined yet."""
3860 timestamp: Datetime = Field(default_factory=Datetime.now)
3861 """Timestamp in [ISO 8601](https://en.wikipedia.org/wiki/ISO_8601) format
3862 with a few restrictions listed [here](https://docs.python.org/3/library/datetime.html#datetime.datetime.fromisoformat).
3863 (In Python a datetime object is valid, too)."""
3865 training_data: Annotated[
3866 Union[None, LinkedDataset, DatasetDescr, DatasetDescr02],
3867 Field(union_mode="left_to_right"),
3868 ] = None
3869 """The dataset used to train this model"""
3871 weights: Annotated[WeightsDescr, WrapSerializer(package_weights)]
3872 """The weights for this model.
3873 Weights can be given for different formats, but should otherwise be equivalent.
3874 The available weight formats determine which consumers can use this model."""
3876 config: Config = Field(default_factory=Config.model_construct)
3878 @model_validator(mode="after")
3879 def _add_default_cover(self) -> Self:
3880 if not get_validation_context().perform_io_checks or self.covers:
3881 return self
3883 try:
3884 generated_covers = generate_covers(
3885 [
3886 (t, load_array(t.test_tensor))
3887 for t in self.inputs
3888 if t.test_tensor is not None
3889 ],
3890 [
3891 (t, load_array(t.test_tensor))
3892 for t in self.outputs
3893 if t.test_tensor is not None
3894 ],
3895 )
3896 except Exception as e:
3897 issue_warning(
3898 "Failed to generate cover image(s): {e}",
3899 value=self.covers,
3900 msg_context=dict(e=e),
3901 field="covers",
3902 )
3903 else:
3904 self.covers.extend(generated_covers)
3906 return self
3908 def get_input_test_arrays(self) -> List[NDArray[Any]]:
3909 return self._get_test_arrays(self.inputs)
3911 def get_output_test_arrays(self) -> List[NDArray[Any]]:
3912 return self._get_test_arrays(self.outputs)
3914 @staticmethod
3915 def _get_test_arrays(
3916 io_descr: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]],
3917 ):
3918 ts: List[FileDescr] = []
3919 for d in io_descr:
3920 if d.test_tensor is None:
3921 raise ValueError(
3922 f"Failed to get test arrays: description of '{d.id}' is missing a `test_tensor`."
3923 )
3924 ts.append(d.test_tensor)
3926 data = [load_array(t) for t in ts]
3927 assert all(isinstance(d, np.ndarray) for d in data)
3928 return data
3930 @staticmethod
3931 def get_batch_size(tensor_sizes: Mapping[TensorId, Mapping[AxisId, int]]) -> int:
3932 batch_size = 1
3933 tensor_with_batchsize: Optional[TensorId] = None
3934 for tid in tensor_sizes:
3935 for aid, s in tensor_sizes[tid].items():
3936 if aid != BATCH_AXIS_ID or s == 1 or s == batch_size:
3937 continue
3939 if batch_size != 1:
3940 assert tensor_with_batchsize is not None
3941 raise ValueError(
3942 f"batch size mismatch for tensors '{tensor_with_batchsize}' ({batch_size}) and '{tid}' ({s})"
3943 )
3945 batch_size = s
3946 tensor_with_batchsize = tid
3948 return batch_size
3950 def get_output_tensor_sizes(
3951 self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]
3952 ) -> Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]:
3953 """Returns the tensor output sizes for given **input_sizes**.
3954 Only if **input_sizes** has a valid input shape, the tensor output size is exact.
3955 Otherwise it might be larger than the actual (valid) output"""
3956 batch_size = self.get_batch_size(input_sizes)
3957 ns = self.get_ns(input_sizes)
3959 tensor_sizes = self.get_tensor_sizes(ns, batch_size=batch_size)
3960 return tensor_sizes.outputs
3962 def get_ns(self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]):
3963 """get parameter `n` for each parameterized axis
3964 such that the valid input size is >= the given input size"""
3965 ret: Dict[Tuple[TensorId, AxisId], ParameterizedSize_N] = {}
3966 axes = {t.id: {a.id: a for a in t.axes} for t in self.inputs}
3967 for tid in input_sizes:
3968 for aid, s in input_sizes[tid].items():
3969 size_descr = axes[tid][aid].size
3970 if isinstance(size_descr, ParameterizedSize):
3971 ret[(tid, aid)] = size_descr.get_n(s)
3972 elif size_descr is None or isinstance(size_descr, (int, SizeReference)):
3973 pass
3974 else:
3975 assert_never(size_descr)
3977 return ret
3979 def get_tensor_sizes(
3980 self, ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], batch_size: int
3981 ) -> _TensorSizes:
3982 axis_sizes = self.get_axis_sizes(ns, batch_size=batch_size)
3983 return _TensorSizes(
3984 {
3985 t: {
3986 aa: axis_sizes.inputs[(tt, aa)]
3987 for tt, aa in axis_sizes.inputs
3988 if tt == t
3989 }
3990 for t in {tt for tt, _ in axis_sizes.inputs}
3991 },
3992 {
3993 t: {
3994 aa: axis_sizes.outputs[(tt, aa)]
3995 for tt, aa in axis_sizes.outputs
3996 if tt == t
3997 }
3998 for t in {tt for tt, _ in axis_sizes.outputs}
3999 },
4000 )
4002 def get_axis_sizes(
4003 self,
4004 ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N],
4005 batch_size: Optional[int] = None,
4006 *,
4007 max_input_shape: Optional[Mapping[Tuple[TensorId, AxisId], int]] = None,
4008 ) -> _AxisSizes:
4009 """Determine input and output block shape for scale factors **ns**
4010 of parameterized input sizes.
4012 Args:
4013 ns: Scale factor `n` for each axis (keyed by (tensor_id, axis_id))
4014 that is parameterized as `size = min + n * step`.
4015 batch_size: The desired size of the batch dimension.
4016 If given **batch_size** overwrites any batch size present in
4017 **max_input_shape**. Default 1.
4018 max_input_shape: Limits the derived block shapes.
4019 Each axis for which the input size, parameterized by `n`, is larger
4020 than **max_input_shape** is set to the minimal value `n_min` for which
4021 this is still true.
4022 Use this for small input samples or large values of **ns**.
4023 Or simply whenever you know the full input shape.
4025 Returns:
4026 Resolved axis sizes for model inputs and outputs.
4027 """
4028 max_input_shape = max_input_shape or {}
4029 if batch_size is None:
4030 for (_t_id, a_id), s in max_input_shape.items():
4031 if a_id == BATCH_AXIS_ID:
4032 batch_size = s
4033 break
4034 else:
4035 batch_size = 1
4037 all_axes = {
4038 t.id: {a.id: a for a in t.axes} for t in chain(self.inputs, self.outputs)
4039 }
4041 inputs: Dict[Tuple[TensorId, AxisId], int] = {}
4042 outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]] = {}
4044 def get_axis_size(a: Union[InputAxis, OutputAxis]):
4045 if isinstance(a, BatchAxis):
4046 if (t_descr.id, a.id) in ns:
4047 logger.warning(
4048 "Ignoring unexpected size increment factor (n) for batch axis"
4049 + " of tensor '{}'.",
4050 t_descr.id,
4051 )
4052 return batch_size
4053 elif isinstance(a.size, int):
4054 if (t_descr.id, a.id) in ns:
4055 logger.warning(
4056 "Ignoring unexpected size increment factor (n) for fixed size"
4057 + " axis '{}' of tensor '{}'.",
4058 a.id,
4059 t_descr.id,
4060 )
4061 return a.size
4062 elif isinstance(a.size, ParameterizedSize):
4063 if (t_descr.id, a.id) not in ns:
4064 raise ValueError(
4065 "Size increment factor (n) missing for parametrized axis"
4066 + f" '{a.id}' of tensor '{t_descr.id}'."
4067 )
4068 n = ns[(t_descr.id, a.id)]
4069 s_max = max_input_shape.get((t_descr.id, a.id))
4070 if s_max is not None:
4071 n = min(n, a.size.get_n(s_max))
4073 return a.size.get_size(n)
4075 elif isinstance(a.size, SizeReference):
4076 if (t_descr.id, a.id) in ns:
4077 logger.warning(
4078 "Ignoring unexpected size increment factor (n) for axis '{}'"
4079 + " of tensor '{}' with size reference.",
4080 a.id,
4081 t_descr.id,
4082 )
4083 assert not isinstance(a, BatchAxis)
4084 ref_axis = all_axes[a.size.tensor_id][a.size.axis_id]
4085 assert not isinstance(ref_axis, BatchAxis)
4086 ref_key = (a.size.tensor_id, a.size.axis_id)
4087 ref_size = inputs.get(ref_key, outputs.get(ref_key))
4088 assert ref_size is not None, ref_key
4089 assert not isinstance(ref_size, _DataDepSize), ref_key
4090 return a.size.get_size(
4091 axis=a,
4092 ref_axis=ref_axis,
4093 ref_size=ref_size,
4094 )
4095 elif isinstance(a.size, DataDependentSize):
4096 if (t_descr.id, a.id) in ns:
4097 logger.warning(
4098 "Ignoring unexpected increment factor (n) for data dependent"
4099 + " size axis '{}' of tensor '{}'.",
4100 a.id,
4101 t_descr.id,
4102 )
4103 return _DataDepSize(a.size.min, a.size.max)
4104 else:
4105 assert_never(a.size)
4107 # first resolve all , but the `SizeReference` input sizes
4108 for t_descr in self.inputs:
4109 for a in t_descr.axes:
4110 if not isinstance(a.size, SizeReference):
4111 s = get_axis_size(a)
4112 assert not isinstance(s, _DataDepSize)
4113 inputs[t_descr.id, a.id] = s
4115 # resolve all other input axis sizes
4116 for t_descr in self.inputs:
4117 for a in t_descr.axes:
4118 if isinstance(a.size, SizeReference):
4119 s = get_axis_size(a)
4120 assert not isinstance(s, _DataDepSize)
4121 inputs[t_descr.id, a.id] = s
4123 # resolve all output axis sizes
4124 for t_descr in self.outputs:
4125 for a in t_descr.axes:
4126 assert not isinstance(a.size, ParameterizedSize)
4127 s = get_axis_size(a)
4128 outputs[t_descr.id, a.id] = s
4130 return _AxisSizes(inputs=inputs, outputs=outputs)
@model_validator(mode="before")
@classmethod
def _convert(cls, data: Dict[str, Any]) -> Dict[str, Any]:
    """Upgrade raw input `data` from an older format version before validation.

    The conversion itself is delegated to `convert_from_old_format_wo_validation`,
    which returns nothing and updates `data` in place. The very same (possibly
    updated) object is therefore handed on.
    """
    # in-place conversion; there is no return value to pick up here
    cls.convert_from_old_format_wo_validation(data)
    return data
@classmethod
def convert_from_old_format_wo_validation(cls, data: Dict[str, Any]) -> None:
    """Convert metadata following an older format version to this class' format
    without validating the result.

    The conversion happens in place: `data` is modified and nothing is returned.
    `data` is left untouched unless its "type" is "model" and its
    "format_version" is a string consisting of three dot separated decimal
    numbers.

    - 0.3.x and 0.4.x: `data` is loaded as a model 0.4 description and converted.
      If converting an invalid model 0.4 description fails, the error is logged
      and `data` is left as it is.
    - 0.5.x: only "format_version" is set to `cls.implemented_format_version`.

    Args:
        data: The raw model description content to update in place.
    """
    if data.get("type") != "model":
        return

    fv = data.get("format_version")
    if not isinstance(fv, str) or fv.count(".") != 2:
        return

    fv_parts = fv.split(".")
    # `str.isdecimal` only accepts characters that `int` is able to parse.
    # `str.isdigit` would also let through characters like "²", for which
    # `int` raises a ValueError.
    if not all(p.isdecimal() for p in fv_parts):
        return

    fv_tuple = tuple(map(int, fv_parts))

    assert cls.implemented_format_version_tuple[0:2] == (0, 5)
    if fv_tuple[:2] in ((0, 3), (0, 4)):
        m04 = _ModelDescr_v0_4.load(data)
        if isinstance(m04, InvalidDescr):
            # best effort: try to convert even an invalid 0.4 description
            try:
                updated = _model_conv.convert_as_dict(
                    m04  # pyright: ignore[reportArgumentType]
                )
            except Exception as e:
                logger.error(
                    "Failed to convert from invalid model 0.4 description."
                    + f"\nerror: {e}"
                    + "\nProceeding with model 0.5 validation without conversion."
                )
                updated = None
        else:
            updated = _model_conv.convert_as_dict(m04)

        if updated is not None:
            # replace the content, but keep the identity of `data`
            data.clear()
            data.update(updated)

    elif fv_tuple[:2] == (0, 5):
        # bump patch version
        data["format_version"] = cls.implemented_format_version
class _ModelConv(Converter[_ModelDescr_v0_4, ModelDescr]):
    """Converter from a model 0.4 description to a model 0.5 description."""

    def _convert(
        self, src: _ModelDescr_v0_4, tgt: "type[ModelDescr] | type[dict[str, Any]]"
    ) -> "ModelDescr | dict[str, Any]":
        """Assemble `tgt` from the fields of the model 0.4 description `src`.

        Expressions of the form `X if TYPE_CHECKING else Y` evaluate to `Y` at
        runtime; `X` is only there for static type checkers.
        """
        # any character outside of this set is replaced by a blank
        allowed_name_chars = string.ascii_letters + string.digits + "_+- ()"
        name = "".join(ch if ch in allowed_name_chars else " " for ch in src.name)

        def conv_authors(auths: Optional[Sequence[_Author_v0_4]]):
            """Convert a sequence of authors; `None` is passed through."""
            if auths is None:
                return None

            conv = (
                _author_conv.convert if TYPE_CHECKING else _author_conv.convert_as_dict
            )
            return [conv(a) for a in auths]

        def strip_conda_prefix(deps: Any) -> str:
            """Return `str(deps)` without a leading "conda:"."""
            deps_str = str(deps)
            if deps_str.startswith("conda:"):
                return deps_str[len("conda:") :]

            return deps_str

        arch_file_conv = (
            _arch_file_conv.convert if TYPE_CHECKING else _arch_file_conv.convert_as_dict
        )
        arch_lib_conv = (
            _arch_lib_conv.convert if TYPE_CHECKING else _arch_lib_conv.convert_as_dict
        )

        # axis -> size per input tensor with a (truthy) shape;
        # for parameterized shapes the minimal shape is used
        input_size_refs = {}
        for ipt in src.inputs:
            if not ipt.shape:
                continue

            if isinstance(ipt.shape, _ParameterizedInputShape_v0_4):
                ipt_sizes = ipt.shape.min
            else:
                ipt_sizes = ipt.shape

            input_size_refs[ipt.name] = dict(zip(ipt.axes, ipt_sizes))

        # explicit output shapes first; input entries take precedence on name clashes
        output_size_refs = {
            out.name: dict(zip(out.axes, out.shape))
            for out in src.outputs
            if not isinstance(out.shape, _ImplicitOutputShape_v0_4)
        }
        output_size_refs.update(input_size_refs)

        # The values below are computed in the order of the keyword arguments
        # they are passed as further down.
        if src.attachments is None:
            attachments = []
        else:
            attachments = [FileDescr(source=f) for f in src.attachments.files]

        authors = [_author_conv.convert_as_dict(a) for a in src.authors]
        cite = [{"text": c.text, "doi": c.doi, "url": c.url} for c in src.cite]
        model_id = None if src.id is None else ModelId(src.id)
        maintainers = [_maintainer_conv.convert_as_dict(m) for m in src.maintainers]

        sample_inputs = src.sample_inputs or [None] * len(src.test_inputs)
        inputs = [
            _input_tensor_conv.convert_as_dict(ipt, tt, st, input_size_refs)
            for ipt, tt, st in zip(src.inputs, src.test_inputs, sample_inputs)
        ]

        sample_outputs = src.sample_outputs or [None] * len(src.test_outputs)
        outputs = [
            _output_tensor_conv.convert_as_dict(out, tt, st, output_size_refs)
            for out, tt, st in zip(src.outputs, src.test_outputs, sample_outputs)
        ]

        if src.parent is None:
            parent = None
        else:
            parent_id = str(src.parent.id)
            if src.parent.version_number is not None:
                parent_id += f"/{src.parent.version_number}"

            parent = LinkedModel(id=ModelId(parent_id))

        if src.training_data is None:
            training_data = None
        elif isinstance(src.training_data, LinkedDataset02):
            dataset_id = str(src.training_data.id)
            if src.training_data.version_number is not None:
                dataset_id += f"/{src.training_data.version_number}"

            training_data = LinkedDataset(id=DatasetId(dataset_id))
        else:
            training_data = src.training_data

        packaged_by = [_author_conv.convert_as_dict(a) for a in src.packaged_by]

        # weights entries: a missing (falsy) entry is passed on as is
        keras_w = src.weights.keras_hdf5
        keras_hdf5 = keras_w and (KerasHdf5WeightsDescr if TYPE_CHECKING else dict)(
            authors=conv_authors(keras_w.authors),
            source=keras_w.source,
            tensorflow_version=keras_w.tensorflow_version or Version("1.15"),
            parent=keras_w.parent,
        )

        onnx_w = src.weights.onnx
        onnx = onnx_w and (OnnxWeightsDescr if TYPE_CHECKING else dict)(
            source=onnx_w.source,
            authors=conv_authors(onnx_w.authors),
            parent=onnx_w.parent,
            opset_version=onnx_w.opset_version or 15,
        )

        pt_w = src.weights.pytorch_state_dict
        pytorch_state_dict = pt_w and (
            PytorchStateDictWeightsDescr if TYPE_CHECKING else dict
        )(
            source=pt_w.source,
            authors=conv_authors(pt_w.authors),
            parent=pt_w.parent,
            architecture=(
                arch_file_conv(
                    pt_w.architecture,
                    pt_w.architecture_sha256,
                    pt_w.kwargs,
                )
                if isinstance(pt_w.architecture, _CallableFromFile_v0_4)
                else arch_lib_conv(pt_w.architecture, pt_w.kwargs)
            ),
            pytorch_version=pt_w.pytorch_version or Version("1.10"),
            dependencies=(
                None
                if pt_w.dependencies is None
                else (FileDescr if TYPE_CHECKING else dict)(
                    source=cast(FileSource, strip_conda_prefix(pt_w.dependencies))
                )
            ),
        )

        tfjs_w = src.weights.tensorflow_js
        tensorflow_js = tfjs_w and (TensorflowJsWeightsDescr if TYPE_CHECKING else dict)(
            source=tfjs_w.source,
            authors=conv_authors(tfjs_w.authors),
            parent=tfjs_w.parent,
            tensorflow_version=tfjs_w.tensorflow_version or Version("1.15"),
        )

        tf_w = src.weights.tensorflow_saved_model_bundle
        tensorflow_saved_model_bundle = tf_w and (
            TensorflowSavedModelBundleWeightsDescr if TYPE_CHECKING else dict
        )(
            authors=conv_authors(tf_w.authors),
            parent=tf_w.parent,
            source=tf_w.source,
            tensorflow_version=tf_w.tensorflow_version or Version("1.15"),
            dependencies=(
                None
                if tf_w.dependencies is None
                else (FileDescr if TYPE_CHECKING else dict)(
                    source=cast(FileSource, strip_conda_prefix(tf_w.dependencies))
                )
            ),
        )

        ts_w = src.weights.torchscript
        torchscript = ts_w and (TorchscriptWeightsDescr if TYPE_CHECKING else dict)(
            source=ts_w.source,
            authors=conv_authors(ts_w.authors),
            parent=ts_w.parent,
            pytorch_version=ts_w.pytorch_version or Version("1.10"),
        )

        weights = (WeightsDescr if TYPE_CHECKING else dict)(
            keras_hdf5=keras_hdf5,
            onnx=onnx,
            pytorch_state_dict=pytorch_state_dict,
            tensorflow_js=tensorflow_js,
            tensorflow_saved_model_bundle=tensorflow_saved_model_bundle,
            torchscript=torchscript,
        )

        # keep this keyword order: for `tgt=dict` it determines the key order
        return tgt(
            attachments=attachments,
            authors=authors,  # pyright: ignore[reportArgumentType]
            cite=cite,  # pyright: ignore[reportArgumentType]
            config=src.config,  # pyright: ignore[reportArgumentType]
            covers=src.covers,
            description=src.description,
            documentation=src.documentation,
            format_version="0.5.10",
            git_repo=src.git_repo,  # pyright: ignore[reportArgumentType]
            icon=src.icon,
            id=model_id,
            id_emoji=src.id_emoji,
            license=src.license,  # type: ignore
            links=src.links,
            maintainers=maintainers,  # pyright: ignore[reportArgumentType]
            name=name,
            tags=src.tags,
            type=src.type,
            uploader=src.uploader,
            version=src.version,
            inputs=inputs,  # pyright: ignore[reportArgumentType]
            outputs=outputs,  # pyright: ignore[reportArgumentType]
            parent=parent,
            training_data=training_data,
            packaged_by=packaged_by,  # pyright: ignore[reportArgumentType]
            run_mode=src.run_mode,
            timestamp=src.timestamp,
            weights=weights,
        )


# shared converter instance for model 0.4 -> model 0.5 descriptions
_model_conv = _ModelConv(_ModelDescr_v0_4, ModelDescr)
# create better cover images for 3d data and non-image outputs
def generate_covers(
    inputs: Sequence[Tuple[InputTensorDescr, NDArray[Any]]],
    outputs: Sequence[Tuple[OutputTensorDescr, NDArray[Any]]],
) -> List[Path]:
    """Generate cover image(s) from the first input and the first output tensor.

    Both tensors are reduced to 8 bit images with three channels. If the two
    images share the same shape, a single diagonally split "cover.png" is
    written; otherwise "input.png" and "output.png" are written. The files are
    created in a new temporary folder, which this function does not remove.

    Args:
        inputs: Pairs of input tensor description and matching tensor data.
        outputs: Pairs of output tensor description and matching tensor data.

    Returns:
        Paths to the written cover image(s).

    Raises:
        ValueError: If `inputs` or `outputs` is empty, if a tensor does not
            match its described axes, or if a tensor cannot be reduced to a
            2d image.
    """

    def squeeze(
        data: NDArray[Any], axes: Sequence[AnyAxis]
    ) -> Tuple[NDArray[Any], List[AnyAxis]]:
        """apply numpy.ndarray.squeeze while keeping track of the axis descriptions remaining"""
        if data.ndim != len(axes):
            raise ValueError(
                f"tensor shape {data.shape} does not match described axes"
                + f" {[a.id for a in axes]}"
            )

        axes = [deepcopy(a) for a, s in zip(axes, data.shape) if s != 1]
        return data.squeeze(), axes

    def normalize(
        data: NDArray[Any], axis: Optional[Tuple[int, ...]], eps: float = 1e-7
    ) -> NDArray[np.float32]:
        """shift and scale `data` (as float32) to roughly [0, 1] along `axis`"""
        data = data.astype("float32")
        data -= data.min(axis=axis, keepdims=True)
        data /= data.max(axis=axis, keepdims=True) + eps
        return data

    def to_2d_image(data: NDArray[Any], axes: Sequence[AnyAxis]):
        """reduce `data` to an uint8 image with the (3) channels as last axis"""
        original_shape = data.shape
        original_axes = list(axes)
        data, axes = squeeze(data, axes)

        # take slice from any batch or index axis if needed
        # and convert the first channel axis and take a slice from any additional channel axes
        slices: Tuple[slice, ...] = ()
        # `ndim` tracks the number of non-singleton axes; sliced axes are only
        # squeezed from `data` after the loop
        ndim = data.ndim
        ndim_need = 3 if any(isinstance(a, ChannelAxis) for a in axes) else 2
        has_c_axis = False
        for i, a in enumerate(axes):
            s = data.shape[i]
            assert s > 1
            if (
                isinstance(a, (BatchAxis, IndexInputAxis, IndexOutputAxis))
                and ndim > ndim_need
            ):
                data = data[slices + (slice(s // 2 - 1, s // 2),)]
                ndim -= 1
            elif isinstance(a, ChannelAxis):
                if has_c_axis:
                    # second channel axis
                    data = data[slices + (slice(0, 1),)]
                    ndim -= 1
                else:
                    has_c_axis = True
                    if s == 2:
                        # visualize two channels with cyan and magenta
                        c0 = data[slices + (slice(0, 1),)]
                        c1 = data[slices + (slice(1, 2),)]
                        # halve before adding: summing first would wrap
                        # around for integer data types
                        # TODO: take maximum instead?
                        data = np.concatenate([c1, c0, c0 / 2 + c1 / 2], axis=i)
                    elif data.shape[i] == 3:
                        pass  # visualize 3 channels as RGB
                    else:
                        # visualize first 3 channels as RGB
                        data = data[slices + (slice(3),)]

                    assert data.shape[i] == 3

            slices += (slice(None),)

        data, axes = squeeze(data, axes)
        assert len(axes) == ndim
        # take slice from z axis if needed
        slices = ()
        if ndim > ndim_need:
            for i, a in enumerate(axes):
                s = data.shape[i]
                if a.id == AxisId("z"):
                    data = data[slices + (slice(s // 2 - 1, s // 2),)]
                    data, axes = squeeze(data, axes)
                    ndim -= 1
                    break

                slices += (slice(None),)

        # take slice from any space or time axis
        slices = ()

        for i, a in enumerate(axes):
            if ndim <= ndim_need:
                break

            s = data.shape[i]
            assert s > 1
            if isinstance(
                a, (SpaceInputAxis, SpaceOutputAxis, TimeInputAxis, TimeOutputAxis)
            ):
                data = data[slices + (slice(s // 2 - 1, s // 2),)]
                ndim -= 1

            slices += (slice(None),)

        del slices
        data, axes = squeeze(data, axes)
        assert len(axes) == ndim

        if (has_c_axis and ndim != 3) or (not has_c_axis and ndim != 2):
            raise ValueError(
                f"Failed to construct cover image from shape {original_shape} with axes {[a.id for a in original_axes]}."
            )

        if not has_c_axis:
            assert ndim == 2
            data = np.repeat(data[:, :, None], 3, axis=2)
            axes.append(ChannelAxis(channel_names=list(map(Identifier, "RGB"))))
            ndim += 1

        assert ndim == 3

        # transpose axis order such that longest axis comes first...
        axis_order: List[int] = np.argsort(list(data.shape)).tolist()
        axis_order.reverse()
        # ... and channel axis is last
        c = [i for i in range(3) if isinstance(axes[i], ChannelAxis)][0]
        # `c` is an axis index, i.e. a value in `axis_order`, not a position
        # within `axis_order`
        axis_order.remove(c)
        axis_order.append(c)
        axes = [axes[ao] for ao in axis_order]
        data = data.transpose(axis_order)

        # TODO: enforce 2:1 or 1:1 aspect ratio for generated cover images

        norm_along = (
            tuple(i for i, a in enumerate(axes) if a.type in ("space", "time")) or None
        )
        # normalize the data and map to 8 bit
        data = normalize(data, norm_along)
        data = (data * 255).astype("uint8")

        return data

    def create_diagonal_split_image(im0: NDArray[Any], im1: NDArray[Any]):
        """combine the lower triangle of `im0` with the upper triangle of `im1`"""
        assert im0.dtype == im1.dtype == np.uint8
        assert im0.shape == im1.shape
        assert im0.ndim == 3
        N, M, C = im0.shape
        assert C == 3
        out = np.ones((N, M, C), dtype="uint8")
        for c in range(C):
            outc = np.tril(im0[..., c])
            # fill every zero valued pixel from the upper triangle of `im1`
            mask = outc == 0
            outc[mask] = np.triu(im1[..., c])[mask]
            out[..., c] = outc

        return out

    if not inputs:
        raise ValueError("Missing test input tensor for cover generation.")

    if not outputs:
        raise ValueError("Missing test output tensor for cover generation.")

    # only the first input and the first output are visualized
    ipt_descr, ipt = inputs[0]
    out_descr, out = outputs[0]

    ipt_img = to_2d_image(ipt, ipt_descr.axes)
    out_img = to_2d_image(out, out_descr.axes)

    cover_folder = Path(mkdtemp())
    if ipt_img.shape == out_img.shape:
        covers = [cover_folder / "cover.png"]
        imwrite(covers[0], create_diagonal_split_image(ipt_img, out_img))
    else:
        covers = [cover_folder / "input.png", cover_folder / "output.png"]
        imwrite(covers[0], ipt_img)
        imwrite(covers[1], out_img)

    return covers