bioimageio.spec.model

implementations of all released minor versions are available in submodules:

- model v0_4: `bioimageio.spec.model.v0_4.ModelDescr`
- model v0_5: `bioimageio.spec.model.v0_5.ModelDescr`

 1# autogen: start
 2"""
 3implementations of all released minor versions are available in submodules:
 4- model v0_4: `bioimageio.spec.model.v0_4.ModelDescr`
 5- model v0_5: `bioimageio.spec.model.v0_5.ModelDescr`
 6"""
 7
 8from typing import Union
 9
10from pydantic import Discriminator, Field
11from typing_extensions import Annotated
12
13from . import v0_4, v0_5
14
15ModelDescr = v0_5.ModelDescr
16ModelDescr_v0_4 = v0_4.ModelDescr
17ModelDescr_v0_5 = v0_5.ModelDescr
18
19AnyModelDescr = Annotated[
20    Union[
21        Annotated[ModelDescr_v0_4, Field(title="model 0.4")],
22        Annotated[ModelDescr_v0_5, Field(title="model 0.5")],
23    ],
24    Discriminator("format_version"),
25    Field(title="model"),
26]
27"""Union of any released model desription"""
28# autogen: stop
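
A minimal usage sketch (the RDF path below is hypothetical): `bioimageio.spec.load_description` parses and validates a resource description, so consumer code can dispatch on the union members:

```python
from bioimageio.spec import load_description
from bioimageio.spec.model import ModelDescr, ModelDescr_v0_4

descr = load_description("path/to/rdf.yaml")  # hypothetical path to a model RDF
if isinstance(descr, ModelDescr):
    print(f"model 0.5 description: {descr.name}")
elif isinstance(descr, ModelDescr_v0_4):
    print(f"legacy model 0.4 description: {descr.name}")
```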
2606class ModelDescr(GenericModelDescrBase):
2607    """Specification of the fields used in a bioimage.io-compliant RDF to describe AI models with pretrained weights.
2608    These fields are typically stored in a YAML file which we call a model resource description file (model RDF).
2609    """
2610
2611    implemented_format_version: ClassVar[Literal["0.5.5"]] = "0.5.5"
2612    if TYPE_CHECKING:
2613        format_version: Literal["0.5.5"] = "0.5.5"
2614    else:
2615        format_version: Literal["0.5.5"]
2616        """Version of the bioimage.io model description specification used.
2617        When creating a new model always use the latest micro/patch version described here.
2618        The `format_version` is important for any consumer software to understand how to parse the fields.
2619        """
2620
2621    implemented_type: ClassVar[Literal["model"]] = "model"
2622    if TYPE_CHECKING:
2623        type: Literal["model"] = "model"
2624    else:
2625        type: Literal["model"]
2626        """Specialized resource type 'model'"""
2627
2628    id: Optional[ModelId] = None
2629    """bioimage.io-wide unique resource identifier
2630    assigned by bioimage.io; version **un**specific."""
2631
2632    authors: FAIR[List[Author]] = Field(
2633        default_factory=cast(Callable[[], List[Author]], list)
2634    )
2635    """The authors are the creators of the model RDF and the primary points of contact."""
2636
2637    documentation: FAIR[Optional[FileSource_documentation]] = None
2638    """URL or relative path to a markdown file with additional documentation.
2639    The recommended documentation file name is `README.md`. An `.md` suffix is mandatory.
2640    The documentation should include a '#[#] Validation' (sub)section
2641    with details on how to quantitatively validate the model on unseen data."""
2642
2643    @field_validator("documentation", mode="after")
2644    @classmethod
2645    def _validate_documentation(
2646        cls, value: Optional[FileSource_documentation]
2647    ) -> Optional[FileSource_documentation]:
2648        if not get_validation_context().perform_io_checks or value is None:
2649            return value
2650
2651        doc_reader = get_reader(value)
2652        doc_content = doc_reader.read().decode(encoding="utf-8")
2653        if not re.search("#.*[vV]alidation", doc_content):
2654            issue_warning(
2655                "No '# Validation' (sub)section found in {value}.",
2656                value=value,
2657                field="documentation",
2658            )
2659
2660        return value
2661
2662    inputs: NotEmpty[Sequence[InputTensorDescr]]
2663    """Describes the input tensors expected by this model."""
2664
2665    @field_validator("inputs", mode="after")
2666    @classmethod
2667    def _validate_input_axes(
2668        cls, inputs: Sequence[InputTensorDescr]
2669    ) -> Sequence[InputTensorDescr]:
2670        input_size_refs = cls._get_axes_with_independent_size(inputs)
2671
2672        for i, ipt in enumerate(inputs):
2673            valid_independent_refs: Dict[
2674                Tuple[TensorId, AxisId],
2675                Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
2676            ] = {
2677                **{
2678                    (ipt.id, a.id): (ipt, a, a.size)
2679                    for a in ipt.axes
2680                    if not isinstance(a, BatchAxis)
2681                    and isinstance(a.size, (int, ParameterizedSize))
2682                },
2683                **input_size_refs,
2684            }
2685            for a, ax in enumerate(ipt.axes):
2686                cls._validate_axis(
2687                    "inputs",
2688                    i=i,
2689                    tensor_id=ipt.id,
2690                    a=a,
2691                    axis=ax,
2692                    valid_independent_refs=valid_independent_refs,
2693                )
2694        return inputs
2695
2696    @staticmethod
2697    def _validate_axis(
2698        field_name: str,
2699        i: int,
2700        tensor_id: TensorId,
2701        a: int,
2702        axis: AnyAxis,
2703        valid_independent_refs: Dict[
2704            Tuple[TensorId, AxisId],
2705            Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
2706        ],
2707    ):
2708        if isinstance(axis, BatchAxis) or isinstance(
2709            axis.size, (int, ParameterizedSize, DataDependentSize)
2710        ):
2711            return
2712        elif not isinstance(axis.size, SizeReference):
2713            assert_never(axis.size)
2714
2715        # validate axis.size SizeReference
2716        ref = (axis.size.tensor_id, axis.size.axis_id)
2717        if ref not in valid_independent_refs:
2718            raise ValueError(
2719                "Invalid tensor axis reference at"
2720                + f" {field_name}[{i}].axes[{a}].size: {axis.size}."
2721            )
2722        if ref == (tensor_id, axis.id):
2723            raise ValueError(
2724                "Self-referencing not allowed for"
2725                + f" {field_name}[{i}].axes[{a}].size: {axis.size}"
2726            )
2727        if axis.type == "channel":
2728            if valid_independent_refs[ref][1].type != "channel":
2729                raise ValueError(
2730                    "A channel axis' size may only reference another fixed size"
2731                    + " channel axis."
2732                )
2733            if isinstance(axis.channel_names, str) and "{i}" in axis.channel_names:
2734                ref_size = valid_independent_refs[ref][2]
2735                assert isinstance(ref_size, int), (
2736                    "channel axis ref (another channel axis) has to specify fixed"
2737                    + " size"
2738                )
2739                generated_channel_names = [
2740                    Identifier(axis.channel_names.format(i=i))
2741                    for i in range(1, ref_size + 1)
2742                ]
2743                axis.channel_names = generated_channel_names
2744
2745        if (ax_unit := getattr(axis, "unit", None)) != (
2746            ref_unit := getattr(valid_independent_refs[ref][1], "unit", None)
2747        ):
2748            raise ValueError(
2749                "The units of an axis and its reference axis need to match, but"
2750                + f" '{ax_unit}' != '{ref_unit}'."
2751            )
2752        ref_axis = valid_independent_refs[ref][1]
2753        if isinstance(ref_axis, BatchAxis):
2754            raise ValueError(
2755                f"Invalid reference axis '{ref_axis.id}' for {tensor_id}.{axis.id}"
2756                + " (a batch axis is not allowed as reference)."
2757            )
2758
2759        if isinstance(axis, WithHalo):
2760            min_size = axis.size.get_size(axis, ref_axis, n=0)
2761            if (min_size - 2 * axis.halo) < 1:
2762                raise ValueError(
2763                    f"axis {axis.id} with minimum size {min_size} is too small for halo"
2764                    + f" {axis.halo}."
2765                )
2766
2767            input_halo = axis.halo * axis.scale / ref_axis.scale
2768            if input_halo != int(input_halo) or input_halo % 2 == 1:
2769                raise ValueError(
2770                    f"input_halo {input_halo} (output_halo {axis.halo} *"
2771                    + f" output_scale {axis.scale} / input_scale {ref_axis.scale})"
2772                + f" is not an even integer for {tensor_id}.{axis.id}."
2773                )
2774
2775    @model_validator(mode="after")
2776    def _validate_test_tensors(self) -> Self:
2777        if not get_validation_context().perform_io_checks:
2778            return self
2779
2780        test_output_arrays = [
2781            None if descr.test_tensor is None else load_array(descr.test_tensor)
2782            for descr in self.outputs
2783        ]
2784        test_input_arrays = [
2785            None if descr.test_tensor is None else load_array(descr.test_tensor)
2786            for descr in self.inputs
2787        ]
2788
2789        tensors = {
2790            descr.id: (descr, array)
2791            for descr, array in zip(
2792                chain(self.inputs, self.outputs), test_input_arrays + test_output_arrays
2793            )
2794        }
2795        validate_tensors(tensors, tensor_origin="test_tensor")
2796
2797        output_arrays = {
2798            descr.id: array for descr, array in zip(self.outputs, test_output_arrays)
2799        }
2800        for rep_tol in self.config.bioimageio.reproducibility_tolerance:
2801            if not rep_tol.absolute_tolerance:
2802                continue
2803
2804            if rep_tol.output_ids:
2805                out_arrays = {
2806                    oid: a
2807                    for oid, a in output_arrays.items()
2808                    if oid in rep_tol.output_ids
2809                }
2810            else:
2811                out_arrays = output_arrays
2812
2813            for out_id, array in out_arrays.items():
2814                if array is None:
2815                    continue
2816
2817                if rep_tol.absolute_tolerance > (max_test_value := array.max()) * 0.01:
2818                    raise ValueError(
2819                        "config.bioimageio.reproducibility_tolerance.absolute_tolerance="
2820                        + f"{rep_tol.absolute_tolerance} > 0.01*{max_test_value}"
2821                        + f" (1% of the maximum value of the test tensor '{out_id}')"
2822                    )
2823
2824        return self
2825
2826    @model_validator(mode="after")
2827    def _validate_tensor_references_in_proc_kwargs(self, info: ValidationInfo) -> Self:
2828        ipt_refs = {t.id for t in self.inputs}
2829        out_refs = {t.id for t in self.outputs}
2830        for ipt in self.inputs:
2831            for p in ipt.preprocessing:
2832                ref = p.kwargs.get("reference_tensor")
2833                if ref is None:
2834                    continue
2835                if ref not in ipt_refs:
2836                    raise ValueError(
2837                        f"`reference_tensor` '{ref}' not found. Valid input tensor"
2838                        + f" references are: {ipt_refs}."
2839                    )
2840
2841        for out in self.outputs:
2842            for p in out.postprocessing:
2843                ref = p.kwargs.get("reference_tensor")
2844                if ref is None:
2845                    continue
2846
2847                if ref not in ipt_refs and ref not in out_refs:
2848                    raise ValueError(
2849                        f"`reference_tensor` '{ref}' not found. Valid tensor references"
2850                        + f" are: {ipt_refs | out_refs}."
2851                    )
2852
2853        return self
2854
2855    # TODO: use validate funcs in validate_test_tensors
2856    # def validate_inputs(self, input_tensors: Mapping[TensorId, NDArray[Any]]) -> Mapping[TensorId, NDArray[Any]]:
2857
2858    name: Annotated[
2859        str,
2860        RestrictCharacters(string.ascii_letters + string.digits + "_+- ()"),
2861        MinLen(5),
2862        MaxLen(128),
2863        warn(MaxLen(64), "Name longer than 64 characters.", INFO),
2864    ]
2865    """A human-readable name of this model.
2866    It should be no longer than 64 characters
2867    and may only contain letters, numbers, underscores, minus signs, parentheses and spaces.
2868    We recommend choosing a name that refers to the model's task and image modality.
2869    """
2870
2871    outputs: NotEmpty[Sequence[OutputTensorDescr]]
2872    """Describes the output tensors."""
2873
2874    @field_validator("outputs", mode="after")
2875    @classmethod
2876    def _validate_tensor_ids(
2877        cls, outputs: Sequence[OutputTensorDescr], info: ValidationInfo
2878    ) -> Sequence[OutputTensorDescr]:
2879        tensor_ids = [
2880            t.id for t in info.data.get("inputs", []) + info.data.get("outputs", [])
2881        ]
2882        duplicate_tensor_ids: List[str] = []
2883        seen: Set[str] = set()
2884        for t in tensor_ids:
2885            if t in seen:
2886                duplicate_tensor_ids.append(t)
2887
2888            seen.add(t)
2889
2890        if duplicate_tensor_ids:
2891            raise ValueError(f"Duplicate tensor ids: {duplicate_tensor_ids}")
2892
2893        return outputs
2894
2895    @staticmethod
2896    def _get_axes_with_parameterized_size(
2897        io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]],
2898    ):
2899        return {
2900            f"{t.id}.{a.id}": (t, a, a.size)
2901            for t in io
2902            for a in t.axes
2903            if not isinstance(a, BatchAxis) and isinstance(a.size, ParameterizedSize)
2904        }
2905
2906    @staticmethod
2907    def _get_axes_with_independent_size(
2908        io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]],
2909    ):
2910        return {
2911            (t.id, a.id): (t, a, a.size)
2912            for t in io
2913            for a in t.axes
2914            if not isinstance(a, BatchAxis)
2915            and isinstance(a.size, (int, ParameterizedSize))
2916        }
2917
2918    @field_validator("outputs", mode="after")
2919    @classmethod
2920    def _validate_output_axes(
2921        cls, outputs: List[OutputTensorDescr], info: ValidationInfo
2922    ) -> List[OutputTensorDescr]:
2923        input_size_refs = cls._get_axes_with_independent_size(
2924            info.data.get("inputs", [])
2925        )
2926        output_size_refs = cls._get_axes_with_independent_size(outputs)
2927
2928        for i, out in enumerate(outputs):
2929            valid_independent_refs: Dict[
2930                Tuple[TensorId, AxisId],
2931                Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
2932            ] = {
2933                **{
2934                    (out.id, a.id): (out, a, a.size)
2935                    for a in out.axes
2936                    if not isinstance(a, BatchAxis)
2937                    and isinstance(a.size, (int, ParameterizedSize))
2938                },
2939                **input_size_refs,
2940                **output_size_refs,
2941            }
2942            for a, ax in enumerate(out.axes):
2943                cls._validate_axis(
2944                    "outputs",
2945                    i,
2946                    out.id,
2947                    a,
2948                    ax,
2949                    valid_independent_refs=valid_independent_refs,
2950                )
2951
2952        return outputs
2953
2954    packaged_by: List[Author] = Field(
2955        default_factory=cast(Callable[[], List[Author]], list)
2956    )
2957    """The persons that have packaged and uploaded this model.
2958    Only required if those persons differ from the `authors`."""
2959
2960    parent: Optional[LinkedModel] = None
2961    """The model from which this model is derived, e.g. by fine-tuning the weights."""
2962
2963    @model_validator(mode="after")
2964    def _validate_parent_is_not_self(self) -> Self:
2965        if self.parent is not None and self.parent.id == self.id:
2966            raise ValueError("A model description may not reference itself as parent.")
2967
2968        return self
2969
2970    run_mode: Annotated[
2971        Optional[RunMode],
2972        warn(None, "Run mode '{value}' has limited support across consumer software."),
2973    ] = None
2974    """Custom run mode for this model: for more complex prediction procedures like test time
2975    data augmentation that currently cannot be expressed in the specification.
2976    No standard run modes are defined yet."""
2977
2978    timestamp: Datetime = Field(default_factory=Datetime.now)
2979    """Timestamp in [ISO 8601](https://en.wikipedia.org/wiki/ISO_8601) format
2980    with a few restrictions listed [here](https://docs.python.org/3/library/datetime.html#datetime.datetime.fromisoformat).
2981    (In Python a datetime object is valid, too)."""
2982
2983    training_data: Annotated[
2984        Union[None, LinkedDataset, DatasetDescr, DatasetDescr02],
2985        Field(union_mode="left_to_right"),
2986    ] = None
2987    """The dataset used to train this model"""
2988
2989    weights: Annotated[WeightsDescr, WrapSerializer(package_weights)]
2990    """The weights for this model.
2991    Weights can be given for different formats, but should otherwise be equivalent.
2992    The available weight formats determine which consumers can use this model."""
2993
2994    config: Config = Field(default_factory=Config.model_construct)
2995
2996    @model_validator(mode="after")
2997    def _add_default_cover(self) -> Self:
2998        if not get_validation_context().perform_io_checks or self.covers:
2999            return self
3000
3001        try:
3002            generated_covers = generate_covers(
3003                [
3004                    (t, load_array(t.test_tensor))
3005                    for t in self.inputs
3006                    if t.test_tensor is not None
3007                ],
3008                [
3009                    (t, load_array(t.test_tensor))
3010                    for t in self.outputs
3011                    if t.test_tensor is not None
3012                ],
3013            )
3014        except Exception as e:
3015            issue_warning(
3016                "Failed to generate cover image(s): {e}",
3017                value=self.covers,
3018                msg_context=dict(e=e),
3019                field="covers",
3020            )
3021        else:
3022            self.covers.extend(generated_covers)
3023
3024        return self
3025
3026    def get_input_test_arrays(self) -> List[NDArray[Any]]:
3027        return self._get_test_arrays(self.inputs)
3028
3029    def get_output_test_arrays(self) -> List[NDArray[Any]]:
3030        return self._get_test_arrays(self.outputs)
3031
3032    @staticmethod
3033    def _get_test_arrays(
3034        io_descr: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]],
3035    ):
3036        ts: List[FileDescr] = []
3037        for d in io_descr:
3038            if d.test_tensor is None:
3039                raise ValueError(
3040                    f"Failed to get test arrays: description of '{d.id}' is missing a `test_tensor`."
3041                )
3042            ts.append(d.test_tensor)
3043
3044        data = [load_array(t) for t in ts]
3045        assert all(isinstance(d, np.ndarray) for d in data)
3046        return data
3047
3048    @staticmethod
3049    def get_batch_size(tensor_sizes: Mapping[TensorId, Mapping[AxisId, int]]) -> int:
3050        batch_size = 1
3051        tensor_with_batchsize: Optional[TensorId] = None
3052        for tid in tensor_sizes:
3053            for aid, s in tensor_sizes[tid].items():
3054                if aid != BATCH_AXIS_ID or s == 1 or s == batch_size:
3055                    continue
3056
3057                if batch_size != 1:
3058                    assert tensor_with_batchsize is not None
3059                    raise ValueError(
3060                        f"batch size mismatch for tensors '{tensor_with_batchsize}' ({batch_size}) and '{tid}' ({s})"
3061                    )
3062
3063                batch_size = s
3064                tensor_with_batchsize = tid
3065
3066        return batch_size
3067
3068    def get_output_tensor_sizes(
3069        self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]
3070    ) -> Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]:
3071        """Returns the tensor output sizes for given **input_sizes**.
3072        The returned sizes are exact only if **input_sizes** specifies a valid input shape;
3073        otherwise they may be larger than the actual (valid) output sizes."""
3074        batch_size = self.get_batch_size(input_sizes)
3075        ns = self.get_ns(input_sizes)
3076
3077        tensor_sizes = self.get_tensor_sizes(ns, batch_size=batch_size)
3078        return tensor_sizes.outputs
3079
3080    def get_ns(self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]):
3081        """Get parameter `n` for each parameterized axis
3082        such that the valid input size is >= the given input size."""
3083        ret: Dict[Tuple[TensorId, AxisId], ParameterizedSize_N] = {}
3084        axes = {t.id: {a.id: a for a in t.axes} for t in self.inputs}
3085        for tid in input_sizes:
3086            for aid, s in input_sizes[tid].items():
3087                size_descr = axes[tid][aid].size
3088                if isinstance(size_descr, ParameterizedSize):
3089                    ret[(tid, aid)] = size_descr.get_n(s)
3090                elif size_descr is None or isinstance(size_descr, (int, SizeReference)):
3091                    pass
3092                else:
3093                    assert_never(size_descr)
3094
3095        return ret
3096
3097    def get_tensor_sizes(
3098        self, ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], batch_size: int
3099    ) -> _TensorSizes:
3100        axis_sizes = self.get_axis_sizes(ns, batch_size=batch_size)
3101        return _TensorSizes(
3102            {
3103                t: {
3104                    aa: axis_sizes.inputs[(tt, aa)]
3105                    for tt, aa in axis_sizes.inputs
3106                    if tt == t
3107                }
3108                for t in {tt for tt, _ in axis_sizes.inputs}
3109            },
3110            {
3111                t: {
3112                    aa: axis_sizes.outputs[(tt, aa)]
3113                    for tt, aa in axis_sizes.outputs
3114                    if tt == t
3115                }
3116                for t in {tt for tt, _ in axis_sizes.outputs}
3117            },
3118        )
3119
3120    def get_axis_sizes(
3121        self,
3122        ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N],
3123        batch_size: Optional[int] = None,
3124        *,
3125        max_input_shape: Optional[Mapping[Tuple[TensorId, AxisId], int]] = None,
3126    ) -> _AxisSizes:
3127        """Determine input and output block shape for scale factors **ns**
3128        of parameterized input sizes.
3129
3130        Args:
3131            ns: Scale factor `n` for each axis (keyed by (tensor_id, axis_id))
3132                that is parameterized as `size = min + n * step`.
3133            batch_size: The desired size of the batch dimension.
3134                If given **batch_size** overwrites any batch size present in
3135                **max_input_shape**. Default 1.
3136            max_input_shape: Limits the derived block shapes.
3137                Each axis for which the input size, parameterized by `n`, is larger
3138                than **max_input_shape** is set to the minimal value `n_min` for which
3139                this is still true.
3140                Use this for small input samples or large values of **ns**.
3141                Or simply whenever you know the full input shape.
3142
3143        Returns:
3144            Resolved axis sizes for model inputs and outputs.
3145        """
3146        max_input_shape = max_input_shape or {}
3147        if batch_size is None:
3148            for (_t_id, a_id), s in max_input_shape.items():
3149                if a_id == BATCH_AXIS_ID:
3150                    batch_size = s
3151                    break
3152            else:
3153                batch_size = 1
3154
3155        all_axes = {
3156            t.id: {a.id: a for a in t.axes} for t in chain(self.inputs, self.outputs)
3157        }
3158
3159        inputs: Dict[Tuple[TensorId, AxisId], int] = {}
3160        outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]] = {}
3161
3162        def get_axis_size(a: Union[InputAxis, OutputAxis]):
3163            if isinstance(a, BatchAxis):
3164                if (t_descr.id, a.id) in ns:
3165                    logger.warning(
3166                        "Ignoring unexpected size increment factor (n) for batch axis"
3167                        + " of tensor '{}'.",
3168                        t_descr.id,
3169                    )
3170                return batch_size
3171            elif isinstance(a.size, int):
3172                if (t_descr.id, a.id) in ns:
3173                    logger.warning(
3174                        "Ignoring unexpected size increment factor (n) for fixed size"
3175                        + " axis '{}' of tensor '{}'.",
3176                        a.id,
3177                        t_descr.id,
3178                    )
3179                return a.size
3180            elif isinstance(a.size, ParameterizedSize):
3181                if (t_descr.id, a.id) not in ns:
3182                    raise ValueError(
3183                        "Size increment factor (n) missing for parametrized axis"
3184                        + f" '{a.id}' of tensor '{t_descr.id}'."
3185                    )
3186                n = ns[(t_descr.id, a.id)]
3187                s_max = max_input_shape.get((t_descr.id, a.id))
3188                if s_max is not None:
3189                    n = min(n, a.size.get_n(s_max))
3190
3191                return a.size.get_size(n)
3192
3193            elif isinstance(a.size, SizeReference):
3194                if (t_descr.id, a.id) in ns:
3195                    logger.warning(
3196                        "Ignoring unexpected size increment factor (n) for axis '{}'"
3197                        + " of tensor '{}' with size reference.",
3198                        a.id,
3199                        t_descr.id,
3200                    )
3201                assert not isinstance(a, BatchAxis)
3202                ref_axis = all_axes[a.size.tensor_id][a.size.axis_id]
3203                assert not isinstance(ref_axis, BatchAxis)
3204                ref_key = (a.size.tensor_id, a.size.axis_id)
3205                ref_size = inputs.get(ref_key, outputs.get(ref_key))
3206                assert ref_size is not None, ref_key
3207                assert not isinstance(ref_size, _DataDepSize), ref_key
3208                return a.size.get_size(
3209                    axis=a,
3210                    ref_axis=ref_axis,
3211                    ref_size=ref_size,
3212                )
3213            elif isinstance(a.size, DataDependentSize):
3214                if (t_descr.id, a.id) in ns:
3215                    logger.warning(
3216                        "Ignoring unexpected increment factor (n) for data dependent"
3217                        + " size axis '{}' of tensor '{}'.",
3218                        a.id,
3219                        t_descr.id,
3220                    )
3221                return _DataDepSize(a.size.min, a.size.max)
3222            else:
3223                assert_never(a.size)
3224
3225        # first resolve all input sizes except the `SizeReference` ones
3226        for t_descr in self.inputs:
3227            for a in t_descr.axes:
3228                if not isinstance(a.size, SizeReference):
3229                    s = get_axis_size(a)
3230                    assert not isinstance(s, _DataDepSize)
3231                    inputs[t_descr.id, a.id] = s
3232
3233        # resolve all other input axis sizes
3234        for t_descr in self.inputs:
3235            for a in t_descr.axes:
3236                if isinstance(a.size, SizeReference):
3237                    s = get_axis_size(a)
3238                    assert not isinstance(s, _DataDepSize)
3239                    inputs[t_descr.id, a.id] = s
3240
3241        # resolve all output axis sizes
3242        for t_descr in self.outputs:
3243            for a in t_descr.axes:
3244                assert not isinstance(a.size, ParameterizedSize)
3245                s = get_axis_size(a)
3246                outputs[t_descr.id, a.id] = s
3247
3248        return _AxisSizes(inputs=inputs, outputs=outputs)
3249
3250    @model_validator(mode="before")
3251    @classmethod
3252    def _convert(cls, data: Dict[str, Any]) -> Dict[str, Any]:
3253        cls.convert_from_old_format_wo_validation(data)
3254        return data
3255
3256    @classmethod
3257    def convert_from_old_format_wo_validation(cls, data: Dict[str, Any]) -> None:
3258        """Convert metadata following an older format version to this class's format
3259        without validating the result.
3260        """
3261        if (
3262            data.get("type") == "model"
3263            and isinstance(fv := data.get("format_version"), str)
3264            and fv.count(".") == 2
3265        ):
3266            fv_parts = fv.split(".")
3267            if any(not p.isdigit() for p in fv_parts):
3268                return
3269
3270            fv_tuple = tuple(map(int, fv_parts))
3271
3272            assert cls.implemented_format_version_tuple[0:2] == (0, 5)
3273            if fv_tuple[:2] in ((0, 3), (0, 4)):
3274                m04 = _ModelDescr_v0_4.load(data)
3275                if isinstance(m04, InvalidDescr):
3276                    try:
3277                        updated = _model_conv.convert_as_dict(
3278                            m04  # pyright: ignore[reportArgumentType]
3279                        )
3280                    except Exception as e:
3281                        logger.error(
3282                            "Failed to convert from invalid model 0.4 description."
3283                            + f"\nerror: {e}"
3284                            + "\nProceeding with model 0.5 validation without conversion."
3285                        )
3286                        updated = None
3287                else:
3288                    updated = _model_conv.convert_as_dict(m04)
3289
3290                if updated is not None:
3291                    data.clear()
3292                    data.update(updated)
3293
3294            elif fv_tuple[:2] == (0, 5):
3295                # bump patch version
3296                data["format_version"] = cls.implemented_format_version
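
To make the halo check in `_validate_axis` concrete, here is a worked example with assumed numbers (not taken from any real model):

```python
# an output axis with halo=3, scale=1.0, referencing an input axis with scale=1.0
output_halo, output_scale, input_scale = 3, 1.0, 1.0
input_halo = output_halo * output_scale / input_scale  # -> 3.0
# 3.0 is an odd integer, so _validate_axis raises;
# output_halo=4 would yield input_halo=4.0 (even) and pass.
```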

Specification of the fields used in a bioimage.io-compliant RDF to describe AI models with pretrained weights. These fields are typically stored in a YAML file which we call a model resource description file (model RDF).

implemented_format_version: ClassVar[Literal['0.5.5']] = '0.5.5'
implemented_type: ClassVar[Literal['model']] = 'model'

id: Optional[bioimageio.spec.model.v0_5.ModelId]

bioimage.io-wide unique resource identifier assigned by bioimage.io; version unspecific.

authors: Annotated[List[bioimageio.spec.generic.v0_3.Author], AfterWarner(func=<function as_warning.<locals>.wrapper>, severity=35, msg=None, context=None)]

The authors are the creators of the model RDF and the primary points of contact.

documentation: Annotated[Optional[Annotated[Union[bioimageio.spec._internal.url.HttpUrl, bioimageio.spec._internal.io.RelativeFilePath, Annotated[pathlib.Path, PathType(path_type='file'), FieldInfo(annotation=NoneType, required=True, title='FilePath')]], FieldInfo(annotation=NoneType, required=True, metadata=[_PydanticGeneralMetadata(union_mode='left_to_right')]), AfterValidator(func=<function wo_special_file_name>), PlainSerializer(func=<function _package_serializer>, return_type=PydanticUndefined, when_used='unless-none'), WithSuffix(suffix='.md', case_sensitive=True), FieldInfo(annotation=NoneType, required=True, examples=['https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/unet2d_nuclei_broad/README.md', 'README.md'])]], AfterWarner(func=<function as_warning.<locals>.wrapper>, severity=35, msg=None, context=None)]

URL or relative path to a markdown file with additional documentation. The recommended documentation file name is README.md. An .md suffix is mandatory. The documentation should include a '#[#] Validation' (sub)section with details on how to quantitatively validate the model on unseen data.

inputs: Annotated[Sequence[bioimageio.spec.model.v0_5.InputTensorDescr], MinLen(min_length=1)]

Describes the input tensors expected by this model.

name: Annotated[str, RestrictCharacters(alphabet='abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_+- ()'), MinLen(min_length=5), MaxLen(max_length=128), AfterWarner(func=<function as_warning.<locals>.wrapper>, severity=20, msg='Name longer than 64 characters.', context={'typ': Annotated[Any, MaxLen(max_length=64)]})]

A human-readable name of this model. It should be no longer than 64 characters and may only contain letters, numbers, underscores, minus signs, parentheses and spaces. We recommend choosing a name that refers to the model's task and image modality.

outputs: Annotated[Sequence[bioimageio.spec.model.v0_5.OutputTensorDescr], MinLen(min_length=1)]

Describes the output tensors.

packaged_by: List[bioimageio.spec.generic.v0_3.Author]

The persons that have packaged and uploaded this model. Only required if those persons differ from the authors.

parent: Optional[bioimageio.spec.model.v0_5.LinkedModel]

The model from which this model is derived, e.g. by fine-tuning the weights.

run_mode: Annotated[Optional[bioimageio.spec.model.v0_4.RunMode], AfterWarner(func=<function as_warning.<locals>.wrapper>, severity=30, msg="Run mode '{value}' has limited support across consumer software.", context={'typ': None})]

Custom run mode for this model: for more complex prediction procedures like test time data augmentation that currently cannot be expressed in the specification. No standard run modes are defined yet.

timestamp: Datetime

Timestamp in [ISO 8601](https://en.wikipedia.org/wiki/ISO_8601) format with a few restrictions listed [here](https://docs.python.org/3/library/datetime.html#datetime.datetime.fromisoformat). (In Python a datetime object is valid, too).

training_data: Annotated[Union[NoneType, bioimageio.spec.dataset.v0_3.LinkedDataset, bioimageio.spec.DatasetDescr, bioimageio.spec.dataset.v0_2.DatasetDescr], FieldInfo(annotation=NoneType, required=True, metadata=[_PydanticGeneralMetadata(union_mode='left_to_right')])]

The dataset used to train this model

weights: Annotated[bioimageio.spec.model.v0_5.WeightsDescr, WrapSerializer(func=<function package_weights>, return_type=PydanticUndefined, when_used='always')]

The weights for this model. Weights can be given for different formats, but should otherwise be equivalent. The available weight formats determine which consumers can use this model.
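
A sketch of picking an available weights entry from a loaded 0.5 model description `model` (the variable and the preference order are assumed):

```python
# prefer torchscript, fall back to pytorch_state_dict
entry = model.weights.torchscript or model.weights.pytorch_state_dict
if entry is not None:
    print(entry.source)  # URL or path of the weights file
```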

def get_input_test_arrays(self) -> List[numpy.ndarray[tuple[Any, ...], numpy.dtype[Any]]]:
def get_output_test_arrays(self) -> List[numpy.ndarray[tuple[Any, ...], numpy.dtype[Any]]]:
@staticmethod
def get_batch_size(tensor_sizes: Mapping[bioimageio.spec.model.v0_5.TensorId, Mapping[bioimageio.spec.model.v0_5.AxisId, int]]) -> int:
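
A sketch of `get_batch_size` with assumed tensor and axis ids; it resolves a single batch size across all given tensors and raises on mismatch:

```python
from bioimageio.spec.model.v0_5 import AxisId, ModelDescr, TensorId

tensor_sizes = {
    TensorId("raw"): {AxisId("batch"): 4, AxisId("x"): 512},
    TensorId("mask"): {AxisId("batch"): 4, AxisId("x"): 512},
}
print(ModelDescr.get_batch_size(tensor_sizes))  # -> 4 (a mismatch would raise ValueError)
```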
def get_output_tensor_sizes(self, input_sizes: Mapping[bioimageio.spec.model.v0_5.TensorId, Mapping[bioimageio.spec.model.v0_5.AxisId, int]]) -> Dict[bioimageio.spec.model.v0_5.TensorId, Dict[bioimageio.spec.model.v0_5.AxisId, Union[int, bioimageio.spec.model.v0_5._DataDepSize]]]:

Returns the tensor output sizes for given input_sizes. The returned sizes are exact only if input_sizes specifies a valid input shape; otherwise they may be larger than the actual (valid) output sizes.
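
For example (a sketch, assuming a loaded 0.5 model description `model` with an input tensor `raw`):

```python
from bioimageio.spec.model.v0_5 import AxisId, TensorId

input_sizes = {TensorId("raw"): {AxisId("batch"): 1, AxisId("x"): 256, AxisId("y"): 256}}
for tid, sizes in model.get_output_tensor_sizes(input_sizes).items():
    print(tid, sizes)  # sizes are ints, or _DataDepSize(min, max) for data-dependent axes
```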

def get_ns(self, input_sizes: Mapping[bioimageio.spec.model.v0_5.TensorId, Mapping[bioimageio.spec.model.v0_5.AxisId, int]]):

Get parameter n for each parameterized axis such that the valid input size is >= the given input size.
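
A sketch with assumed parameterization: for an input axis declared as `ParameterizedSize(min=256, step=16)`, a requested size of 300 resolves to the smallest sufficient `n`:

```python
from bioimageio.spec.model.v0_5 import AxisId, TensorId

ns = model.get_ns({TensorId("raw"): {AxisId("x"): 300}})  # assumed ids, loaded `model`
print(ns)  # -> {(TensorId("raw"), AxisId("x")): 3}, since 256 + 3*16 = 304 >= 300
```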

def get_tensor_sizes(self, ns: Mapping[Tuple[bioimageio.spec.model.v0_5.TensorId, bioimageio.spec.model.v0_5.AxisId], int], batch_size: int) -> bioimageio.spec.model.v0_5._TensorSizes:
def get_axis_sizes(self, ns: Mapping[Tuple[bioimageio.spec.model.v0_5.TensorId, bioimageio.spec.model.v0_5.AxisId], int], batch_size: Optional[int] = None, *, max_input_shape: Optional[Mapping[Tuple[bioimageio.spec.model.v0_5.TensorId, bioimageio.spec.model.v0_5.AxisId], int]] = None) -> bioimageio.spec.model.v0_5._AxisSizes:

Determine input and output block shape for scale factors ns of parameterized input sizes.

Arguments:
  • ns: Scale factor n for each axis (keyed by (tensor_id, axis_id)) that is parameterized as size = min + n * step.
  • batch_size: The desired size of the batch dimension. If given batch_size overwrites any batch size present in max_input_shape. Default 1.
  • max_input_shape: Limits the derived block shapes. Each axis for which the input size, parameterized by n, is larger than max_input_shape is set to the minimal value n_min for which this is still true. Use this for small input samples or large values of ns. Or simply whenever you know the full input shape.
Returns:
  Resolved axis sizes for model inputs and outputs.
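
A sketch (same assumed parameterization as above, `ParameterizedSize(min=256, step=16)`) showing how max_input_shape caps a requested `n`:

```python
from bioimageio.spec.model.v0_5 import AxisId, TensorId

axis_sizes = model.get_axis_sizes(
    ns={(TensorId("raw"), AxisId("x")): 50},                # would give 256 + 50*16 = 1056
    max_input_shape={(TensorId("raw"), AxisId("x")): 400},  # caps n at 9: 256 + 9*16 = 400
)
print(axis_sizes.inputs)   # contains (TensorId('raw'), AxisId('x')): 400
print(axis_sizes.outputs)  # output sizes resolved from the input sizes
```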

@classmethod
def convert_from_old_format_wo_validation(cls, data: Dict[str, Any]) -> None:

Convert metadata following an older format version to this class's format without validating the result.
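
A sketch of the in-place upgrade (the RDF path is hypothetical):

```python
import yaml
from bioimageio.spec.model.v0_5 import ModelDescr

with open("old_model_rdf.yaml") as f:  # hypothetical model 0.4 RDF
    data = yaml.safe_load(f)

ModelDescr.convert_from_old_format_wo_validation(data)  # mutates `data` in place
print(data["format_version"])  # "0.5.5" if the conversion succeeded
```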

implemented_format_version_tuple: ClassVar[Tuple[int, int, int]] = (0, 5, 5)
model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

def model_post_init(self: pydantic.main.BaseModel, context: Any, /) -> None:

This function is meant to behave like a BaseModel method to initialise private attributes.

It takes context as an argument since that's what pydantic-core passes when calling it.

Arguments:
  • self: The BaseModel instance.
  • context: The context.
ModelDescr_v0_4 = <class 'bioimageio.spec.model.v0_4.ModelDescr'>
ModelDescr_v0_5 = <class 'ModelDescr'>
AnyModelDescr = typing.Annotated[typing.Union[typing.Annotated[bioimageio.spec.model.v0_4.ModelDescr, FieldInfo(annotation=NoneType, required=True, title='model 0.4')], typing.Annotated[ModelDescr, FieldInfo(annotation=NoneType, required=True, title='model 0.5')]], Discriminator(discriminator='format_version', custom_error_type=None, custom_error_message=None, custom_error_context=None), FieldInfo(annotation=NoneType, required=True, title='model')]

Union of any released model description
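
Since `AnyModelDescr` is a discriminated union, raw RDF data can also be validated against it directly; a sketch (`rdf_dict` is assumed to hold a parsed model RDF):

```python
from pydantic import TypeAdapter
from bioimageio.spec.model import AnyModelDescr

adapter = TypeAdapter(AnyModelDescr)
descr = adapter.validate_python(rdf_dict)  # dispatches on `format_version`
```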