bioimageio.spec.model

Implementations of all released minor versions are available in submodules:

 1# autogen: start
 2"""
 3Implementations of all released minor versions are available in submodules:
 4- model v0_4: `bioimageio.spec.model.v0_4.ModelDescr`
 5- model v0_5: `bioimageio.spec.model.v0_5.ModelDescr`
 6"""
 7
 8from typing import Union
 9
10from pydantic import Discriminator, Field
11from typing_extensions import Annotated
12
13from . import v0_4, v0_5
14
15ModelDescr = v0_5.ModelDescr
16ModelDescr_v0_4 = v0_4.ModelDescr
17ModelDescr_v0_5 = v0_5.ModelDescr
18
19AnyModelDescr = Annotated[
20    Union[
21        Annotated[ModelDescr_v0_4, Field(title="model 0.4")],
22        Annotated[ModelDescr_v0_5, Field(title="model 0.5")],
23    ],
24    Discriminator("format_version"),
25    Field(title="model"),
26]
27"""Union of any released model desription"""
28# autogen: stop
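`AnyModelDescr` is a pydantic discriminated union, so a model RDF loaded as a plain dict can be validated against it directly. A minimal sketch using pydantic's `TypeAdapter` (the file path is hypothetical; the loading helpers in `bioimageio.spec` additionally set up a validation context):

```python
from pathlib import Path

import yaml
from pydantic import TypeAdapter

from bioimageio.spec.model import AnyModelDescr

# hypothetical local file; any released model 0.4 or 0.5 RDF works
rdf_dict = yaml.safe_load(Path("rdf.yaml").read_text(encoding="utf-8"))

# `format_version` acts as the discriminator: 0.4.x data validates as
# ModelDescr_v0_4, 0.5.x data as ModelDescr_v0_5
descr = TypeAdapter(AnyModelDescr).validate_python(rdf_dict)
print(type(descr).__name__, descr.format_version)
```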
2635class ModelDescr(GenericModelDescrBase):
2636    """Specification of the fields used in a bioimage.io-compliant RDF to describe AI models with pretrained weights.
2637    These fields are typically stored in a YAML file which we call a model resource description file (model RDF).
2638    """
2639
2640    implemented_format_version: ClassVar[Literal["0.5.6"]] = "0.5.6"
2641    if TYPE_CHECKING:
2642        format_version: Literal["0.5.6"] = "0.5.6"
2643    else:
2644        format_version: Literal["0.5.6"]
2645        """Version of the bioimage.io model description specification used.
2646        When creating a new model always use the latest micro/patch version described here.
2647        The `format_version` is important for any consumer software to understand how to parse the fields.
2648        """
2649
2650    implemented_type: ClassVar[Literal["model"]] = "model"
2651    if TYPE_CHECKING:
2652        type: Literal["model"] = "model"
2653    else:
2654        type: Literal["model"]
2655        """Specialized resource type 'model'"""
2656
2657    id: Optional[ModelId] = None
2658    """bioimage.io-wide unique resource identifier
2659    assigned by bioimage.io; version **un**specific."""
2660
2661    authors: FAIR[List[Author]] = Field(
2662        default_factory=cast(Callable[[], List[Author]], list)
2663    )
2664    """The authors are the creators of the model RDF and the primary points of contact."""
2665
2666    documentation: FAIR[Optional[FileSource_documentation]] = None
2667    """URL or relative path to a markdown file with additional documentation.
2668    The recommended documentation file name is `README.md`. An `.md` suffix is mandatory.
2669    The documentation should include a '#[#] Validation' (sub)section
2670    with details on how to quantitatively validate the model on unseen data."""
2671
2672    @field_validator("documentation", mode="after")
2673    @classmethod
2674    def _validate_documentation(
2675        cls, value: Optional[FileSource_documentation]
2676    ) -> Optional[FileSource_documentation]:
2677        if not get_validation_context().perform_io_checks or value is None:
2678            return value
2679
2680        doc_reader = get_reader(value)
2681        doc_content = doc_reader.read().decode(encoding="utf-8")
2682        if not re.search("#.*[vV]alidation", doc_content):
2683            issue_warning(
2684                "No '# Validation' (sub)section found in {value}.",
2685                value=value,
2686                field="documentation",
2687            )
2688
2689        return value
2690
2691    inputs: NotEmpty[Sequence[InputTensorDescr]]
2692    """Describes the input tensors expected by this model."""
2693
2694    @field_validator("inputs", mode="after")
2695    @classmethod
2696    def _validate_input_axes(
2697        cls, inputs: Sequence[InputTensorDescr]
2698    ) -> Sequence[InputTensorDescr]:
2699        input_size_refs = cls._get_axes_with_independent_size(inputs)
2700
2701        for i, ipt in enumerate(inputs):
2702            valid_independent_refs: Dict[
2703                Tuple[TensorId, AxisId],
2704                Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
2705            ] = {
2706                **{
2707                    (ipt.id, a.id): (ipt, a, a.size)
2708                    for a in ipt.axes
2709                    if not isinstance(a, BatchAxis)
2710                    and isinstance(a.size, (int, ParameterizedSize))
2711                },
2712                **input_size_refs,
2713            }
2714            for a, ax in enumerate(ipt.axes):
2715                cls._validate_axis(
2716                    "inputs",
2717                    i=i,
2718                    tensor_id=ipt.id,
2719                    a=a,
2720                    axis=ax,
2721                    valid_independent_refs=valid_independent_refs,
2722                )
2723        return inputs
2724
2725    @staticmethod
2726    def _validate_axis(
2727        field_name: str,
2728        i: int,
2729        tensor_id: TensorId,
2730        a: int,
2731        axis: AnyAxis,
2732        valid_independent_refs: Dict[
2733            Tuple[TensorId, AxisId],
2734            Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
2735        ],
2736    ):
2737        if isinstance(axis, BatchAxis) or isinstance(
2738            axis.size, (int, ParameterizedSize, DataDependentSize)
2739        ):
2740            return
2741        elif not isinstance(axis.size, SizeReference):
2742            assert_never(axis.size)
2743
2744        # validate axis.size SizeReference
2745        ref = (axis.size.tensor_id, axis.size.axis_id)
2746        if ref not in valid_independent_refs:
2747            raise ValueError(
2748                "Invalid tensor axis reference at"
2749                + f" {field_name}[{i}].axes[{a}].size: {axis.size}."
2750            )
2751        if ref == (tensor_id, axis.id):
2752            raise ValueError(
2753                "Self-referencing not allowed for"
2754                + f" {field_name}[{i}].axes[{a}].size: {axis.size}"
2755            )
2756        if axis.type == "channel":
2757            if valid_independent_refs[ref][1].type != "channel":
2758                raise ValueError(
2759                    "A channel axis' size may only reference another fixed size"
2760                    + " channel axis."
2761                )
2762            if isinstance(axis.channel_names, str) and "{i}" in axis.channel_names:
2763                ref_size = valid_independent_refs[ref][2]
2764                assert isinstance(ref_size, int), (
2765                    "channel axis ref (another channel axis) has to specify fixed"
2766                    + " size"
2767                )
2768                generated_channel_names = [
2769                    Identifier(axis.channel_names.format(i=i))
2770                    for i in range(1, ref_size + 1)
2771                ]
2772                axis.channel_names = generated_channel_names
2773
2774        if (ax_unit := getattr(axis, "unit", None)) != (
2775            ref_unit := getattr(valid_independent_refs[ref][1], "unit", None)
2776        ):
2777            raise ValueError(
2778                "The units of an axis and its reference axis need to match, but"
2779                + f" '{ax_unit}' != '{ref_unit}'."
2780            )
2781        ref_axis = valid_independent_refs[ref][1]
2782        if isinstance(ref_axis, BatchAxis):
2783            raise ValueError(
2784                f"Invalid reference axis '{ref_axis.id}' for {tensor_id}.{axis.id}"
2785                + " (a batch axis is not allowed as reference)."
2786            )
2787
2788        if isinstance(axis, WithHalo):
2789            min_size = axis.size.get_size(axis, ref_axis, n=0)
2790            if (min_size - 2 * axis.halo) < 1:
2791                raise ValueError(
2792                    f"axis {axis.id} with minimum size {min_size} is too small for halo"
2793                    + f" {axis.halo}."
2794                )
2795
2796            input_halo = axis.halo * axis.scale / ref_axis.scale
2797            if input_halo != int(input_halo) or input_halo % 2 == 1:
2798                raise ValueError(
2799                    f"input_halo {input_halo} (output_halo {axis.halo} *"
2800                    + f" output_scale {axis.scale} / input_scale {ref_axis.scale})"
2801                    + f" is not an even integer for {tensor_id}.{axis.id}."
2802                )
2803
2804    @model_validator(mode="after")
2805    def _validate_test_tensors(self) -> Self:
2806        if not get_validation_context().perform_io_checks:
2807            return self
2808
2809        test_output_arrays = [
2810            None if descr.test_tensor is None else load_array(descr.test_tensor)
2811            for descr in self.outputs
2812        ]
2813        test_input_arrays = [
2814            None if descr.test_tensor is None else load_array(descr.test_tensor)
2815            for descr in self.inputs
2816        ]
2817
2818        tensors = {
2819            descr.id: (descr, array)
2820            for descr, array in zip(
2821                chain(self.inputs, self.outputs), test_input_arrays + test_output_arrays
2822            )
2823        }
2824        validate_tensors(tensors, tensor_origin="test_tensor")
2825
2826        output_arrays = {
2827            descr.id: array for descr, array in zip(self.outputs, test_output_arrays)
2828        }
2829        for rep_tol in self.config.bioimageio.reproducibility_tolerance:
2830            if not rep_tol.absolute_tolerance:
2831                continue
2832
2833            if rep_tol.output_ids:
2834                out_arrays = {
2835                    oid: a
2836                    for oid, a in output_arrays.items()
2837                    if oid in rep_tol.output_ids
2838                }
2839            else:
2840                out_arrays = output_arrays
2841
2842            for out_id, array in out_arrays.items():
2843                if array is None:
2844                    continue
2845
2846                if rep_tol.absolute_tolerance > (max_test_value := array.max()) * 0.01:
2847                    raise ValueError(
2848                        "config.bioimageio.reproducibility_tolerance.absolute_tolerance="
2849                        + f"{rep_tol.absolute_tolerance} > 0.01*{max_test_value}"
2850                        + f" (1% of the maximum value of the test tensor '{out_id}')"
2851                    )
2852
2853        return self
2854
2855    @model_validator(mode="after")
2856    def _validate_tensor_references_in_proc_kwargs(self, info: ValidationInfo) -> Self:
2857        ipt_refs = {t.id for t in self.inputs}
2858        out_refs = {t.id for t in self.outputs}
2859        for ipt in self.inputs:
2860            for p in ipt.preprocessing:
2861                ref = p.kwargs.get("reference_tensor")
2862                if ref is None:
2863                    continue
2864                if ref not in ipt_refs:
2865                    raise ValueError(
2866                        f"`reference_tensor` '{ref}' not found. Valid input tensor"
2867                        + f" references are: {ipt_refs}."
2868                    )
2869
2870        for out in self.outputs:
2871            for p in out.postprocessing:
2872                ref = p.kwargs.get("reference_tensor")
2873                if ref is None:
2874                    continue
2875
2876                if ref not in ipt_refs and ref not in out_refs:
2877                    raise ValueError(
2878                        f"`reference_tensor` '{ref}' not found. Valid tensor references"
2879                        + f" are: {ipt_refs | out_refs}."
2880                    )
2881
2882        return self
2883
2884    # TODO: use validate funcs in validate_test_tensors
2885    # def validate_inputs(self, input_tensors: Mapping[TensorId, NDArray[Any]]) -> Mapping[TensorId, NDArray[Any]]:
2886
2887    name: Annotated[
2888        str,
2889        RestrictCharacters(string.ascii_letters + string.digits + "_+- ()"),
2890        MinLen(5),
2891        MaxLen(128),
2892        warn(MaxLen(64), "Name longer than 64 characters.", INFO),
2893    ]
2894    """A human-readable name of this model.
2895    It should be no longer than 64 characters
2896    and may only contain letters, numbers, underscores, minus signs, parentheses, and spaces.
2897    We recommend choosing a name that refers to the model's task and image modality.
2898    """
2899
2900    outputs: NotEmpty[Sequence[OutputTensorDescr]]
2901    """Describes the output tensors."""
2902
2903    @field_validator("outputs", mode="after")
2904    @classmethod
2905    def _validate_tensor_ids(
2906        cls, outputs: Sequence[OutputTensorDescr], info: ValidationInfo
2907    ) -> Sequence[OutputTensorDescr]:
2908        tensor_ids = [
2909            t.id for t in info.data.get("inputs", []) + info.data.get("outputs", [])
2910        ]
2911        duplicate_tensor_ids: List[str] = []
2912        seen: Set[str] = set()
2913        for t in tensor_ids:
2914            if t in seen:
2915                duplicate_tensor_ids.append(t)
2916
2917            seen.add(t)
2918
2919        if duplicate_tensor_ids:
2920            raise ValueError(f"Duplicate tensor ids: {duplicate_tensor_ids}")
2921
2922        return outputs
2923
2924    @staticmethod
2925    def _get_axes_with_parameterized_size(
2926        io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]],
2927    ):
2928        return {
2929            f"{t.id}.{a.id}": (t, a, a.size)
2930            for t in io
2931            for a in t.axes
2932            if not isinstance(a, BatchAxis) and isinstance(a.size, ParameterizedSize)
2933        }
2934
2935    @staticmethod
2936    def _get_axes_with_independent_size(
2937        io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]],
2938    ):
2939        return {
2940            (t.id, a.id): (t, a, a.size)
2941            for t in io
2942            for a in t.axes
2943            if not isinstance(a, BatchAxis)
2944            and isinstance(a.size, (int, ParameterizedSize))
2945        }
2946
2947    @field_validator("outputs", mode="after")
2948    @classmethod
2949    def _validate_output_axes(
2950        cls, outputs: List[OutputTensorDescr], info: ValidationInfo
2951    ) -> List[OutputTensorDescr]:
2952        input_size_refs = cls._get_axes_with_independent_size(
2953            info.data.get("inputs", [])
2954        )
2955        output_size_refs = cls._get_axes_with_independent_size(outputs)
2956
2957        for i, out in enumerate(outputs):
2958            valid_independent_refs: Dict[
2959                Tuple[TensorId, AxisId],
2960                Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
2961            ] = {
2962                **{
2963                    (out.id, a.id): (out, a, a.size)
2964                    for a in out.axes
2965                    if not isinstance(a, BatchAxis)
2966                    and isinstance(a.size, (int, ParameterizedSize))
2967                },
2968                **input_size_refs,
2969                **output_size_refs,
2970            }
2971            for a, ax in enumerate(out.axes):
2972                cls._validate_axis(
2973                    "outputs",
2974                    i,
2975                    out.id,
2976                    a,
2977                    ax,
2978                    valid_independent_refs=valid_independent_refs,
2979                )
2980
2981        return outputs
2982
2983    packaged_by: List[Author] = Field(
2984        default_factory=cast(Callable[[], List[Author]], list)
2985    )
2986    """The persons that have packaged and uploaded this model.
2987    Only required if those persons differ from the `authors`."""
2988
2989    parent: Optional[LinkedModel] = None
2990    """The model from which this model is derived, e.g. by fine-tuning the weights."""
2991
2992    @model_validator(mode="after")
2993    def _validate_parent_is_not_self(self) -> Self:
2994        if self.parent is not None and self.parent.id == self.id:
2995            raise ValueError("A model description may not reference itself as parent.")
2996
2997        return self
2998
2999    run_mode: Annotated[
3000        Optional[RunMode],
3001        warn(None, "Run mode '{value}' has limited support across consumer software."),
3002    ] = None
3003    """Custom run mode for this model: for more complex prediction procedures like test time
3004    data augmentation that currently cannot be expressed in the specification.
3005    No standard run modes are defined yet."""
3006
3007    timestamp: Datetime = Field(default_factory=Datetime.now)
3008    """Timestamp in [ISO 8601](https://en.wikipedia.org/wiki/ISO_8601) format
3009    with a few restrictions listed [here](https://docs.python.org/3/library/datetime.html#datetime.datetime.fromisoformat).
3010    (In Python a datetime object is valid, too)."""
3011
3012    training_data: Annotated[
3013        Union[None, LinkedDataset, DatasetDescr, DatasetDescr02],
3014        Field(union_mode="left_to_right"),
3015    ] = None
3016    """The dataset used to train this model"""
3017
3018    weights: Annotated[WeightsDescr, WrapSerializer(package_weights)]
3019    """The weights for this model.
3020    Weights can be given for different formats, but should otherwise be equivalent.
3021    The available weight formats determine which consumers can use this model."""
3022
3023    config: Config = Field(default_factory=Config.model_construct)
3024
3025    @model_validator(mode="after")
3026    def _add_default_cover(self) -> Self:
3027        if not get_validation_context().perform_io_checks or self.covers:
3028            return self
3029
3030        try:
3031            generated_covers = generate_covers(
3032                [
3033                    (t, load_array(t.test_tensor))
3034                    for t in self.inputs
3035                    if t.test_tensor is not None
3036                ],
3037                [
3038                    (t, load_array(t.test_tensor))
3039                    for t in self.outputs
3040                    if t.test_tensor is not None
3041                ],
3042            )
3043        except Exception as e:
3044            issue_warning(
3045                "Failed to generate cover image(s): {e}",
3046                value=self.covers,
3047                msg_context=dict(e=e),
3048                field="covers",
3049            )
3050        else:
3051            self.covers.extend(generated_covers)
3052
3053        return self
3054
3055    def get_input_test_arrays(self) -> List[NDArray[Any]]:
3056        return self._get_test_arrays(self.inputs)
3057
3058    def get_output_test_arrays(self) -> List[NDArray[Any]]:
3059        return self._get_test_arrays(self.outputs)
3060
3061    @staticmethod
3062    def _get_test_arrays(
3063        io_descr: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]],
3064    ):
3065        ts: List[FileDescr] = []
3066        for d in io_descr:
3067            if d.test_tensor is None:
3068                raise ValueError(
3069                    f"Failed to get test arrays: description of '{d.id}' is missing a `test_tensor`."
3070                )
3071            ts.append(d.test_tensor)
3072
3073        data = [load_array(t) for t in ts]
3074        assert all(isinstance(d, np.ndarray) for d in data)
3075        return data
3076
3077    @staticmethod
3078    def get_batch_size(tensor_sizes: Mapping[TensorId, Mapping[AxisId, int]]) -> int:
3079        batch_size = 1
3080        tensor_with_batchsize: Optional[TensorId] = None
3081        for tid in tensor_sizes:
3082            for aid, s in tensor_sizes[tid].items():
3083                if aid != BATCH_AXIS_ID or s == 1 or s == batch_size:
3084                    continue
3085
3086                if batch_size != 1:
3087                    assert tensor_with_batchsize is not None
3088                    raise ValueError(
3089                        f"batch size mismatch for tensors '{tensor_with_batchsize}' ({batch_size}) and '{tid}' ({s})"
3090                    )
3091
3092                batch_size = s
3093                tensor_with_batchsize = tid
3094
3095        return batch_size
3096
3097    def get_output_tensor_sizes(
3098        self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]
3099    ) -> Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]:
3100        """Returns the output tensor sizes for the given **input_sizes**.
3101        The output sizes are exact only if **input_sizes** describes a valid input shape;
3102        otherwise they may be larger than the actual (valid) output sizes."""
3103        batch_size = self.get_batch_size(input_sizes)
3104        ns = self.get_ns(input_sizes)
3105
3106        tensor_sizes = self.get_tensor_sizes(ns, batch_size=batch_size)
3107        return tensor_sizes.outputs
3108
3109    def get_ns(self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]):
3110        """Get the parameter `n` for each parameterized axis
3111        such that the valid input size is >= the given input size."""
3112        ret: Dict[Tuple[TensorId, AxisId], ParameterizedSize_N] = {}
3113        axes = {t.id: {a.id: a for a in t.axes} for t in self.inputs}
3114        for tid in input_sizes:
3115            for aid, s in input_sizes[tid].items():
3116                size_descr = axes[tid][aid].size
3117                if isinstance(size_descr, ParameterizedSize):
3118                    ret[(tid, aid)] = size_descr.get_n(s)
3119                elif size_descr is None or isinstance(size_descr, (int, SizeReference)):
3120                    pass
3121                else:
3122                    assert_never(size_descr)
3123
3124        return ret
3125
3126    def get_tensor_sizes(
3127        self, ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], batch_size: int
3128    ) -> _TensorSizes:
3129        axis_sizes = self.get_axis_sizes(ns, batch_size=batch_size)
3130        return _TensorSizes(
3131            {
3132                t: {
3133                    aa: axis_sizes.inputs[(tt, aa)]
3134                    for tt, aa in axis_sizes.inputs
3135                    if tt == t
3136                }
3137                for t in {tt for tt, _ in axis_sizes.inputs}
3138            },
3139            {
3140                t: {
3141                    aa: axis_sizes.outputs[(tt, aa)]
3142                    for tt, aa in axis_sizes.outputs
3143                    if tt == t
3144                }
3145                for t in {tt for tt, _ in axis_sizes.outputs}
3146            },
3147        )
3148
3149    def get_axis_sizes(
3150        self,
3151        ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N],
3152        batch_size: Optional[int] = None,
3153        *,
3154        max_input_shape: Optional[Mapping[Tuple[TensorId, AxisId], int]] = None,
3155    ) -> _AxisSizes:
3156        """Determine input and output block shape for scale factors **ns**
3157        of parameterized input sizes.
3158
3159        Args:
3160            ns: Scale factor `n` for each axis (keyed by (tensor_id, axis_id))
3161                that is parameterized as `size = min + n * step`.
3162            batch_size: The desired size of the batch dimension.
3163                If given **batch_size** overwrites any batch size present in
3164                **max_input_shape**. Default 1.
3165            max_input_shape: Limits the derived block shapes.
3166                Each axis for which the input size, parameterized by `n`, is larger
3167                than **max_input_shape** is set to the minimal value `n_min` for which
3168                this is still true.
3169                Use this for small input samples or large values of **ns**.
3170                Or simply whenever you know the full input shape.
3171
3172        Returns:
3173            Resolved axis sizes for model inputs and outputs.
3174        """
3175        max_input_shape = max_input_shape or {}
3176        if batch_size is None:
3177            for (_t_id, a_id), s in max_input_shape.items():
3178                if a_id == BATCH_AXIS_ID:
3179                    batch_size = s
3180                    break
3181            else:
3182                batch_size = 1
3183
3184        all_axes = {
3185            t.id: {a.id: a for a in t.axes} for t in chain(self.inputs, self.outputs)
3186        }
3187
3188        inputs: Dict[Tuple[TensorId, AxisId], int] = {}
3189        outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]] = {}
3190
3191        def get_axis_size(a: Union[InputAxis, OutputAxis]):
3192            if isinstance(a, BatchAxis):
3193                if (t_descr.id, a.id) in ns:
3194                    logger.warning(
3195                        "Ignoring unexpected size increment factor (n) for batch axis"
3196                        + " of tensor '{}'.",
3197                        t_descr.id,
3198                    )
3199                return batch_size
3200            elif isinstance(a.size, int):
3201                if (t_descr.id, a.id) in ns:
3202                    logger.warning(
3203                        "Ignoring unexpected size increment factor (n) for fixed size"
3204                        + " axis '{}' of tensor '{}'.",
3205                        a.id,
3206                        t_descr.id,
3207                    )
3208                return a.size
3209            elif isinstance(a.size, ParameterizedSize):
3210                if (t_descr.id, a.id) not in ns:
3211                    raise ValueError(
3212                        "Size increment factor (n) missing for parametrized axis"
3213                        + f" '{a.id}' of tensor '{t_descr.id}'."
3214                    )
3215                n = ns[(t_descr.id, a.id)]
3216                s_max = max_input_shape.get((t_descr.id, a.id))
3217                if s_max is not None:
3218                    n = min(n, a.size.get_n(s_max))
3219
3220                return a.size.get_size(n)
3221
3222            elif isinstance(a.size, SizeReference):
3223                if (t_descr.id, a.id) in ns:
3224                    logger.warning(
3225                        "Ignoring unexpected size increment factor (n) for axis '{}'"
3226                        + " of tensor '{}' with size reference.",
3227                        a.id,
3228                        t_descr.id,
3229                    )
3230                assert not isinstance(a, BatchAxis)
3231                ref_axis = all_axes[a.size.tensor_id][a.size.axis_id]
3232                assert not isinstance(ref_axis, BatchAxis)
3233                ref_key = (a.size.tensor_id, a.size.axis_id)
3234                ref_size = inputs.get(ref_key, outputs.get(ref_key))
3235                assert ref_size is not None, ref_key
3236                assert not isinstance(ref_size, _DataDepSize), ref_key
3237                return a.size.get_size(
3238                    axis=a,
3239                    ref_axis=ref_axis,
3240                    ref_size=ref_size,
3241                )
3242            elif isinstance(a.size, DataDependentSize):
3243                if (t_descr.id, a.id) in ns:
3244                    logger.warning(
3245                        "Ignoring unexpected increment factor (n) for data dependent"
3246                        + " size axis '{}' of tensor '{}'.",
3247                        a.id,
3248                        t_descr.id,
3249                    )
3250                return _DataDepSize(a.size.min, a.size.max)
3251            else:
3252                assert_never(a.size)
3253
3254        # first resolve all input sizes, except the `SizeReference` ones
3255        for t_descr in self.inputs:
3256            for a in t_descr.axes:
3257                if not isinstance(a.size, SizeReference):
3258                    s = get_axis_size(a)
3259                    assert not isinstance(s, _DataDepSize)
3260                    inputs[t_descr.id, a.id] = s
3261
3262        # resolve all other input axis sizes
3263        for t_descr in self.inputs:
3264            for a in t_descr.axes:
3265                if isinstance(a.size, SizeReference):
3266                    s = get_axis_size(a)
3267                    assert not isinstance(s, _DataDepSize)
3268                    inputs[t_descr.id, a.id] = s
3269
3270        # resolve all output axis sizes
3271        for t_descr in self.outputs:
3272            for a in t_descr.axes:
3273                assert not isinstance(a.size, ParameterizedSize)
3274                s = get_axis_size(a)
3275                outputs[t_descr.id, a.id] = s
3276
3277        return _AxisSizes(inputs=inputs, outputs=outputs)
3278
3279    @model_validator(mode="before")
3280    @classmethod
3281    def _convert(cls, data: Dict[str, Any]) -> Dict[str, Any]:
3282        cls.convert_from_old_format_wo_validation(data)
3283        return data
3284
3285    @classmethod
3286    def convert_from_old_format_wo_validation(cls, data: Dict[str, Any]) -> None:
3287        """Convert metadata following an older format version to this class's format
3288        without validating the result.
3289        """
3290        if (
3291            data.get("type") == "model"
3292            and isinstance(fv := data.get("format_version"), str)
3293            and fv.count(".") == 2
3294        ):
3295            fv_parts = fv.split(".")
3296            if any(not p.isdigit() for p in fv_parts):
3297                return
3298
3299            fv_tuple = tuple(map(int, fv_parts))
3300
3301            assert cls.implemented_format_version_tuple[0:2] == (0, 5)
3302            if fv_tuple[:2] in ((0, 3), (0, 4)):
3303                m04 = _ModelDescr_v0_4.load(data)
3304                if isinstance(m04, InvalidDescr):
3305                    try:
3306                        updated = _model_conv.convert_as_dict(
3307                            m04  # pyright: ignore[reportArgumentType]
3308                        )
3309                    except Exception as e:
3310                        logger.error(
3311                            "Failed to convert from invalid model 0.4 description."
3312                            + f"\nerror: {e}"
3313                            + "\nProceeding with model 0.5 validation without conversion."
3314                        )
3315                        updated = None
3316                else:
3317                    updated = _model_conv.convert_as_dict(m04)
3318
3319                if updated is not None:
3320                    data.clear()
3321                    data.update(updated)
3322
3323            elif fv_tuple[:2] == (0, 5):
3324                # bump patch version
3325                data["format_version"] = cls.implemented_format_version

Specification of the fields used in a bioimage.io-compliant RDF to describe AI models with pretrained weights. These fields are typically stored in a YAML file which we call a model resource description file (model RDF).

implemented_format_version: ClassVar[Literal['0.5.6']] = '0.5.6'
implemented_type: ClassVar[Literal['model']] = 'model'
id: Optional[bioimageio.spec.model.v0_5.ModelId] = None

bioimage.io-wide unique resource identifier assigned by bioimage.io; version unspecific.

authors: List[bioimageio.spec.generic.v0_3.Author]

The authors are the creators of the model RDF and the primary points of contact.

documentation: Optional[Union[bioimageio.spec._internal.url.HttpUrl, bioimageio.spec._internal.io.RelativeFilePath, pathlib.Path]] = None

URL or relative path to a markdown file with additional documentation. The recommended documentation file name is README.md. An .md suffix is mandatory. The documentation should include a '#[#] Validation' (sub)section with details on how to quantitatively validate the model on unseen data.
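For illustration, the warning check in `_validate_documentation` above boils down to a regex search over the markdown content; any heading line mentioning "Validation" satisfies it:

```python
import re

doc = "# My Model\n\n## Validation\nDice scores on held-out data are reported here.\n"
# the same pattern the validator uses; the "## Validation" heading matches
assert re.search("#.*[vV]alidation", doc) is not None
```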

inputs: Annotated[Sequence[bioimageio.spec.model.v0_5.InputTensorDescr], MinLen(min_length=1)]

Describes the input tensors expected by this model.

name: Annotated[str, RestrictCharacters(alphabet='abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_+- ()'), MinLen(min_length=5), MaxLen(max_length=128)]

A human-readable name of this model. It should be no longer than 64 characters and may only contain letters, numbers, underscores, minus signs, parentheses, and spaces. We recommend choosing a name that refers to the model's task and image modality.
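A plain-Python sketch of the same constraints (allowed character set, length between 5 and 128), useful for quick checks outside of pydantic; `is_valid_model_name` is a hypothetical helper:

```python
import string

ALLOWED = set(string.ascii_letters + string.digits + "_+- ()")

def is_valid_model_name(name: str) -> bool:
    # mirrors RestrictCharacters(...), MinLen(5) and MaxLen(128) above
    return 5 <= len(name) <= 128 and set(name) <= ALLOWED

assert is_valid_model_name("2D UNet (nuclei)")
assert not is_valid_model_name("bad/name")  # '/' is not in the allowed alphabet
```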

outputs: Annotated[Sequence[bioimageio.spec.model.v0_5.OutputTensorDescr], MinLen(min_length=1)]

Describes the output tensors.

packaged_by: List[bioimageio.spec.generic.v0_3.Author]

The persons that have packaged and uploaded this model. Only required if those persons differ from the authors.

parent: Optional[bioimageio.spec.model.v0_5.LinkedModel] = None

The model from which this model is derived, e.g. by fine-tuning the weights.

run_mode: Optional[bioimageio.spec.model.v0_4.RunMode] = None

Custom run mode for this model: for more complex prediction procedures like test time data augmentation that currently cannot be expressed in the specification. No standard run modes are defined yet.

timestamp: bioimageio.spec._internal.types.Datetime

Timestamp in ISO 8601 format, with the few restrictions of Python's `datetime.datetime.fromisoformat`. (In Python a datetime object is valid, too).

training_data: Union[None, bioimageio.spec.dataset.v0_3.LinkedDataset, bioimageio.spec.DatasetDescr, bioimageio.spec.dataset.v0_2.DatasetDescr] = None

The dataset used to train this model

weights: bioimageio.spec.model.v0_5.WeightsDescr

The weights for this model. Weights can be given for different formats, but should otherwise be equivalent. The available weight formats determine which consumers can use this model.

config: bioimageio.spec.model.v0_5.Config
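The usual way to obtain a validated `ModelDescr` instance is `bioimageio.spec.load_description`; a short sketch (the source is a hypothetical local path; a URL or bioimage.io resource ID works as well):

```python
from bioimageio.spec import load_description
from bioimageio.spec.model.v0_5 import ModelDescr

descr = load_description("rdf.yaml")  # hypothetical model RDF
if isinstance(descr, ModelDescr):
    print(descr.name, descr.format_version)
    print([t.id for t in descr.inputs], [t.id for t in descr.outputs])
```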
def get_input_test_arrays(self) -> List[numpy.ndarray[tuple[int, ...], numpy.dtype[Any]]]:
def get_output_test_arrays(self) -> List[numpy.ndarray[tuple[int, ...], numpy.dtype[Any]]]:
@staticmethod
def get_batch_size(tensor_sizes: Mapping[bioimageio.spec.model.v0_5.TensorId, Mapping[bioimageio.spec.model.v0_5.AxisId, int]]) -> int:
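`get_batch_size` scans the given axis sizes for batch axes and checks that they agree, raising `ValueError` on a mismatch. A small sketch with hypothetical tensor ids, assuming the batch axis id is `batch`:

```python
from bioimageio.spec.model.v0_5 import AxisId, ModelDescr, TensorId

sizes = {
    TensorId("raw"): {AxisId("batch"): 4, AxisId("x"): 256},
    TensorId("mask"): {AxisId("batch"): 4, AxisId("x"): 256},
}
assert ModelDescr.get_batch_size(sizes) == 4
```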
def get_output_tensor_sizes(self, input_sizes: Mapping[bioimageio.spec.model.v0_5.TensorId, Mapping[bioimageio.spec.model.v0_5.AxisId, int]]) -> Dict[bioimageio.spec.model.v0_5.TensorId, Dict[bioimageio.spec.model.v0_5.AxisId, Union[int, bioimageio.spec.model.v0_5._DataDepSize]]]:

Returns the output tensor sizes for the given input_sizes. The output sizes are exact only if input_sizes describes a valid input shape; otherwise they may be larger than the actual (valid) output sizes.
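Usage sketch, continuing from the `load_description` example above (tensor and axis ids are hypothetical and must match the loaded description):

```python
from bioimageio.spec.model.v0_5 import AxisId, TensorId

out_sizes = descr.get_output_tensor_sizes(
    {TensorId("raw"): {AxisId("batch"): 1, AxisId("y"): 512, AxisId("x"): 512}}
)
# e.g. {TensorId("prob"): {AxisId("batch"): 1, AxisId("y"): 512, AxisId("x"): 512}}
print(out_sizes)
```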

def get_ns(self, input_sizes: Mapping[bioimageio.spec.model.v0_5.TensorId, Mapping[bioimageio.spec.model.v0_5.AxisId, int]]):

Get the parameter `n` for each parameterized axis such that the valid input size is >= the given input size.
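A parameterized axis has `size = min + n * step`, so `get_n` effectively returns the smallest `n` whose valid size is at least the given size. A worked sketch:

```python
from bioimageio.spec.model.v0_5 import ParameterizedSize

size_descr = ParameterizedSize(min=64, step=16)
n = size_descr.get_n(100)          # smallest n with 64 + n * 16 >= 100
assert n == 3                      # 64 + 3 * 16 = 112 >= 100
assert size_descr.get_size(n) == 112
```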

def get_tensor_sizes(self, ns: Mapping[Tuple[bioimageio.spec.model.v0_5.TensorId, bioimageio.spec.model.v0_5.AxisId], int], batch_size: int) -> bioimageio.spec.model.v0_5._TensorSizes:
def get_axis_sizes(self, ns: Mapping[Tuple[bioimageio.spec.model.v0_5.TensorId, bioimageio.spec.model.v0_5.AxisId], int], batch_size: Optional[int] = None, *, max_input_shape: Optional[Mapping[Tuple[bioimageio.spec.model.v0_5.TensorId, bioimageio.spec.model.v0_5.AxisId], int]] = None) -> bioimageio.spec.model.v0_5._AxisSizes:

Determine input and output block shape for scale factors ns of parameterized input sizes.

Arguments:
  • ns: Scale factor n for each axis (keyed by (tensor_id, axis_id)) that is parameterized as size = min + n * step.
  • batch_size: The desired size of the batch dimension. If given batch_size overwrites any batch size present in max_input_shape. Default 1.
  • max_input_shape: Limits the derived block shapes. Each axis for which the input size, parameterized by n, is larger than max_input_shape is set to the minimal value n_min for which this is still true. Use this for small input samples or large values of ns. Or simply whenever you know the full input shape.
Returns:
  Resolved axis sizes for model inputs and outputs.
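Usage sketch, again with the `descr` instance and hypothetical tensor/axis ids from the examples above:

```python
from bioimageio.spec.model.v0_5 import AxisId, TensorId

axis_sizes = descr.get_axis_sizes(
    ns={(TensorId("raw"), AxisId("x")): 3, (TensorId("raw"), AxisId("y")): 3},
    batch_size=1,
)
print(axis_sizes.inputs)   # {(TensorId("raw"), AxisId("x")): <int>, ...}
print(axis_sizes.outputs)  # resolved from the input sizes (or _DataDepSize)
```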

@classmethod
def convert_from_old_format_wo_validation(cls, data: Dict[str, Any]) -> None:

Convert metadata following an older format version to this class's format without validating the result.
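Usage sketch: the method mutates the given dict in place, converting 0.3/0.4 metadata to the 0.5 layout and bumping the patch version for 0.5 data (the file path is hypothetical):

```python
from pathlib import Path

import yaml

from bioimageio.spec.model.v0_5 import ModelDescr

rdf = yaml.safe_load(Path("rdf.yaml").read_text(encoding="utf-8"))  # e.g. a model 0.4 RDF
ModelDescr.convert_from_old_format_wo_validation(rdf)
print(rdf["format_version"])  # "0.5.6" if the conversion succeeded
```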

implemented_format_version_tuple: ClassVar[Tuple[int, int, int]] = (0, 5, 6)
ModelDescr_v0_4 = bioimageio.spec.model.v0_4.ModelDescr
ModelDescr_v0_5 = bioimageio.spec.model.v0_5.ModelDescr
AnyModelDescr = Annotated[Union[Annotated[ModelDescr_v0_4, Field(title='model 0.4')], Annotated[ModelDescr_v0_5, Field(title='model 0.5')]], Discriminator('format_version'), Field(title='model')]

Union of any released model description