bioimageio.spec.model

Implementations of all released minor versions are available in submodules:

 1# autogen: start
 2"""
 3Implementations of all released minor versions are available in submodules:
 4- model v0_4: `bioimageio.spec.model.v0_4.ModelDescr`
 5- model v0_5: `bioimageio.spec.model.v0_5.ModelDescr`
 6"""
 7
 8from typing import Union
 9
10from pydantic import Discriminator, Field
11from typing_extensions import Annotated
12
13from . import v0_4, v0_5
14
15ModelDescr = v0_5.ModelDescr
16ModelDescr_v0_4 = v0_4.ModelDescr
17ModelDescr_v0_5 = v0_5.ModelDescr
18
19AnyModelDescr = Annotated[
20    Union[
21        Annotated[ModelDescr_v0_4, Field(title="model 0.4")],
22        Annotated[ModelDescr_v0_5, Field(title="model 0.5")],
23    ],
24    Discriminator("format_version"),
25    Field(title="model"),
26]
27"""Union of any released model desription"""
28# autogen: stop
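
For orientation, a minimal usage sketch: it assumes `bioimageio.spec.load_description` (part of this package) and an illustrative local path; `AnyModelDescr` is the discriminated union defined above.

from bioimageio.spec import load_description
from bioimageio.spec.model import ModelDescr, ModelDescr_v0_4

# "my-model/rdf.yaml" is an illustrative path; loading may also yield an
# InvalidDescr if validation fails.
descr = load_description("my-model/rdf.yaml")
if isinstance(descr, ModelDescr_v0_4):
    print("parsed as model 0.4")
elif isinstance(descr, ModelDescr):
    print("parsed as model 0.5")
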
2610class ModelDescr(GenericModelDescrBase):
2611    """Specification of the fields used in a bioimage.io-compliant RDF to describe AI models with pretrained weights.
2612    These fields are typically stored in a YAML file which we call a model resource description file (model RDF).
2613    """
2614
2615    implemented_format_version: ClassVar[Literal["0.5.5"]] = "0.5.5"
2616    if TYPE_CHECKING:
2617        format_version: Literal["0.5.5"] = "0.5.5"
2618    else:
2619        format_version: Literal["0.5.5"]
2620        """Version of the bioimage.io model description specification used.
2621        When creating a new model always use the latest micro/patch version described here.
2622        The `format_version` is important for any consumer software to understand how to parse the fields.
2623        """
2624
2625    implemented_type: ClassVar[Literal["model"]] = "model"
2626    if TYPE_CHECKING:
2627        type: Literal["model"] = "model"
2628    else:
2629        type: Literal["model"]
2630        """Specialized resource type 'model'"""
2631
2632    id: Optional[ModelId] = None
2633    """bioimage.io-wide unique resource identifier
2634    assigned by bioimage.io; version **un**specific."""
2635
2636    authors: FAIR[List[Author]] = Field(
2637        default_factory=cast(Callable[[], List[Author]], list)
2638    )
2639    """The authors are the creators of the model RDF and the primary points of contact."""
2640
2641    documentation: FAIR[Optional[FileSource_documentation]] = None
2642    """URL or relative path to a markdown file with additional documentation.
2643    The recommended documentation file name is `README.md`. An `.md` suffix is mandatory.
2644    The documentation should include a '#[#] Validation' (sub)section
2645    with details on how to quantitatively validate the model on unseen data."""
2646
2647    @field_validator("documentation", mode="after")
2648    @classmethod
2649    def _validate_documentation(
2650        cls, value: Optional[FileSource_documentation]
2651    ) -> Optional[FileSource_documentation]:
2652        if not get_validation_context().perform_io_checks or value is None:
2653            return value
2654
2655        doc_reader = get_reader(value)
2656        doc_content = doc_reader.read().decode(encoding="utf-8")
2657        if not re.search("#.*[vV]alidation", doc_content):
2658            issue_warning(
2659                "No '# Validation' (sub)section found in {value}.",
2660                value=value,
2661                field="documentation",
2662            )
2663
2664        return value
2665
2666    inputs: NotEmpty[Sequence[InputTensorDescr]]
2667    """Describes the input tensors expected by this model."""
2668
2669    @field_validator("inputs", mode="after")
2670    @classmethod
2671    def _validate_input_axes(
2672        cls, inputs: Sequence[InputTensorDescr]
2673    ) -> Sequence[InputTensorDescr]:
2674        input_size_refs = cls._get_axes_with_independent_size(inputs)
2675
2676        for i, ipt in enumerate(inputs):
2677            valid_independent_refs: Dict[
2678                Tuple[TensorId, AxisId],
2679                Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
2680            ] = {
2681                **{
2682                    (ipt.id, a.id): (ipt, a, a.size)
2683                    for a in ipt.axes
2684                    if not isinstance(a, BatchAxis)
2685                    and isinstance(a.size, (int, ParameterizedSize))
2686                },
2687                **input_size_refs,
2688            }
2689            for a, ax in enumerate(ipt.axes):
2690                cls._validate_axis(
2691                    "inputs",
2692                    i=i,
2693                    tensor_id=ipt.id,
2694                    a=a,
2695                    axis=ax,
2696                    valid_independent_refs=valid_independent_refs,
2697                )
2698        return inputs
2699
2700    @staticmethod
2701    def _validate_axis(
2702        field_name: str,
2703        i: int,
2704        tensor_id: TensorId,
2705        a: int,
2706        axis: AnyAxis,
2707        valid_independent_refs: Dict[
2708            Tuple[TensorId, AxisId],
2709            Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
2710        ],
2711    ):
2712        if isinstance(axis, BatchAxis) or isinstance(
2713            axis.size, (int, ParameterizedSize, DataDependentSize)
2714        ):
2715            return
2716        elif not isinstance(axis.size, SizeReference):
2717            assert_never(axis.size)
2718
2719        # validate axis.size SizeReference
2720        ref = (axis.size.tensor_id, axis.size.axis_id)
2721        if ref not in valid_independent_refs:
2722            raise ValueError(
2723                "Invalid tensor axis reference at"
2724                + f" {field_name}[{i}].axes[{a}].size: {axis.size}."
2725            )
2726        if ref == (tensor_id, axis.id):
2727            raise ValueError(
2728                "Self-referencing not allowed for"
2729                + f" {field_name}[{i}].axes[{a}].size: {axis.size}"
2730            )
2731        if axis.type == "channel":
2732            if valid_independent_refs[ref][1].type != "channel":
2733                raise ValueError(
2734                    "A channel axis' size may only reference another fixed size"
2735                    + " channel axis."
2736                )
2737            if isinstance(axis.channel_names, str) and "{i}" in axis.channel_names:
2738                ref_size = valid_independent_refs[ref][2]
2739                assert isinstance(ref_size, int), (
2740                    "channel axis ref (another channel axis) has to specify fixed"
2741                    + " size"
2742                )
2743                generated_channel_names = [
2744                    Identifier(axis.channel_names.format(i=i))
2745                    for i in range(1, ref_size + 1)
2746                ]
2747                axis.channel_names = generated_channel_names
2748
2749        if (ax_unit := getattr(axis, "unit", None)) != (
2750            ref_unit := getattr(valid_independent_refs[ref][1], "unit", None)
2751        ):
2752            raise ValueError(
2753                "The units of an axis and its reference axis need to match, but"
2754                + f" '{ax_unit}' != '{ref_unit}'."
2755            )
2756        ref_axis = valid_independent_refs[ref][1]
2757        if isinstance(ref_axis, BatchAxis):
2758            raise ValueError(
2759                f"Invalid reference axis '{ref_axis.id}' for {tensor_id}.{axis.id}"
2760                + " (a batch axis is not allowed as reference)."
2761            )
2762
2763        if isinstance(axis, WithHalo):
2764            min_size = axis.size.get_size(axis, ref_axis, n=0)
2765            if (min_size - 2 * axis.halo) < 1:
2766                raise ValueError(
2767                    f"axis {axis.id} with minimum size {min_size} is too small for halo"
2768                    + f" {axis.halo}."
2769                )
2770
2771            input_halo = axis.halo * axis.scale / ref_axis.scale
2772            if input_halo != int(input_halo) or input_halo % 2 == 1:
2773                raise ValueError(
2774                    f"input_halo {input_halo} (output_halo {axis.halo} *"
2775                    + f" output_scale {axis.scale} / input_scale {ref_axis.scale})"
2776                    + f" has to be an even integer for {tensor_id}.{axis.id}."
2777                )
2778
2779    @model_validator(mode="after")
2780    def _validate_test_tensors(self) -> Self:
2781        if not get_validation_context().perform_io_checks:
2782            return self
2783
2784        test_output_arrays = [
2785            None if descr.test_tensor is None else load_array(descr.test_tensor)
2786            for descr in self.outputs
2787        ]
2788        test_input_arrays = [
2789            None if descr.test_tensor is None else load_array(descr.test_tensor)
2790            for descr in self.inputs
2791        ]
2792
2793        tensors = {
2794            descr.id: (descr, array)
2795            for descr, array in zip(
2796                chain(self.inputs, self.outputs), test_input_arrays + test_output_arrays
2797            )
2798        }
2799        validate_tensors(tensors, tensor_origin="test_tensor")
2800
2801        output_arrays = {
2802            descr.id: array for descr, array in zip(self.outputs, test_output_arrays)
2803        }
2804        for rep_tol in self.config.bioimageio.reproducibility_tolerance:
2805            if not rep_tol.absolute_tolerance:
2806                continue
2807
2808            if rep_tol.output_ids:
2809                out_arrays = {
2810                    oid: a
2811                    for oid, a in output_arrays.items()
2812                    if oid in rep_tol.output_ids
2813                }
2814            else:
2815                out_arrays = output_arrays
2816
2817            for out_id, array in out_arrays.items():
2818                if array is None:
2819                    continue
2820
2821                if rep_tol.absolute_tolerance > (max_test_value := array.max()) * 0.01:
2822                    raise ValueError(
2823                        "config.bioimageio.reproducibility_tolerance.absolute_tolerance="
2824                        + f"{rep_tol.absolute_tolerance} > 0.01*{max_test_value}"
2825                        + f" (1% of the maximum value of the test tensor '{out_id}')"
2826                    )
2827
2828        return self
2829
2830    @model_validator(mode="after")
2831    def _validate_tensor_references_in_proc_kwargs(self, info: ValidationInfo) -> Self:
2832        ipt_refs = {t.id for t in self.inputs}
2833        out_refs = {t.id for t in self.outputs}
2834        for ipt in self.inputs:
2835            for p in ipt.preprocessing:
2836                ref = p.kwargs.get("reference_tensor")
2837                if ref is None:
2838                    continue
2839                if ref not in ipt_refs:
2840                    raise ValueError(
2841                        f"`reference_tensor` '{ref}' not found. Valid input tensor"
2842                        + f" references are: {ipt_refs}."
2843                    )
2844
2845        for out in self.outputs:
2846            for p in out.postprocessing:
2847                ref = p.kwargs.get("reference_tensor")
2848                if ref is None:
2849                    continue
2850
2851                if ref not in ipt_refs and ref not in out_refs:
2852                    raise ValueError(
2853                        f"`reference_tensor` '{ref}' not found. Valid tensor references"
2854                        + f" are: {ipt_refs | out_refs}."
2855                    )
2856
2857        return self
2858
2859    # TODO: use validate funcs in validate_test_tensors
2860    # def validate_inputs(self, input_tensors: Mapping[TensorId, NDArray[Any]]) -> Mapping[TensorId, NDArray[Any]]:
2861
2862    name: Annotated[
2863        str,
2864        RestrictCharacters(string.ascii_letters + string.digits + "_+- ()"),
2865        MinLen(5),
2866        MaxLen(128),
2867        warn(MaxLen(64), "Name longer than 64 characters.", INFO),
2868    ]
2869    """A human-readable name of this model.
2870    It should be no longer than 64 characters
2871    and may only contain letters, numbers, underscores, minus signs, parentheses and spaces.
2872    We recommend choosing a name that refers to the model's task and image modality.
2873    """
2874
2875    outputs: NotEmpty[Sequence[OutputTensorDescr]]
2876    """Describes the output tensors."""
2877
2878    @field_validator("outputs", mode="after")
2879    @classmethod
2880    def _validate_tensor_ids(
2881        cls, outputs: Sequence[OutputTensorDescr], info: ValidationInfo
2882    ) -> Sequence[OutputTensorDescr]:
2883        tensor_ids = [
2884            t.id for t in info.data.get("inputs", []) + info.data.get("outputs", [])
2885        ]
2886        duplicate_tensor_ids: List[str] = []
2887        seen: Set[str] = set()
2888        for t in tensor_ids:
2889            if t in seen:
2890                duplicate_tensor_ids.append(t)
2891
2892            seen.add(t)
2893
2894        if duplicate_tensor_ids:
2895            raise ValueError(f"Duplicate tensor ids: {duplicate_tensor_ids}")
2896
2897        return outputs
2898
2899    @staticmethod
2900    def _get_axes_with_parameterized_size(
2901        io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]],
2902    ):
2903        return {
2904            f"{t.id}.{a.id}": (t, a, a.size)
2905            for t in io
2906            for a in t.axes
2907            if not isinstance(a, BatchAxis) and isinstance(a.size, ParameterizedSize)
2908        }
2909
2910    @staticmethod
2911    def _get_axes_with_independent_size(
2912        io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]],
2913    ):
2914        return {
2915            (t.id, a.id): (t, a, a.size)
2916            for t in io
2917            for a in t.axes
2918            if not isinstance(a, BatchAxis)
2919            and isinstance(a.size, (int, ParameterizedSize))
2920        }
2921
2922    @field_validator("outputs", mode="after")
2923    @classmethod
2924    def _validate_output_axes(
2925        cls, outputs: List[OutputTensorDescr], info: ValidationInfo
2926    ) -> List[OutputTensorDescr]:
2927        input_size_refs = cls._get_axes_with_independent_size(
2928            info.data.get("inputs", [])
2929        )
2930        output_size_refs = cls._get_axes_with_independent_size(outputs)
2931
2932        for i, out in enumerate(outputs):
2933            valid_independent_refs: Dict[
2934                Tuple[TensorId, AxisId],
2935                Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
2936            ] = {
2937                **{
2938                    (out.id, a.id): (out, a, a.size)
2939                    for a in out.axes
2940                    if not isinstance(a, BatchAxis)
2941                    and isinstance(a.size, (int, ParameterizedSize))
2942                },
2943                **input_size_refs,
2944                **output_size_refs,
2945            }
2946            for a, ax in enumerate(out.axes):
2947                cls._validate_axis(
2948                    "outputs",
2949                    i,
2950                    out.id,
2951                    a,
2952                    ax,
2953                    valid_independent_refs=valid_independent_refs,
2954                )
2955
2956        return outputs
2957
2958    packaged_by: List[Author] = Field(
2959        default_factory=cast(Callable[[], List[Author]], list)
2960    )
2961    """The persons that have packaged and uploaded this model.
2962    Only required if those persons differ from the `authors`."""
2963
2964    parent: Optional[LinkedModel] = None
2965    """The model from which this model is derived, e.g. by fine-tuning the weights."""
2966
2967    @model_validator(mode="after")
2968    def _validate_parent_is_not_self(self) -> Self:
2969        if self.parent is not None and self.parent.id == self.id:
2970            raise ValueError("A model description may not reference itself as parent.")
2971
2972        return self
2973
2974    run_mode: Annotated[
2975        Optional[RunMode],
2976        warn(None, "Run mode '{value}' has limited support across consumer software."),
2977    ] = None
2978    """Custom run mode for this model: for more complex prediction procedures like test time
2979    data augmentation that currently cannot be expressed in the specification.
2980    No standard run modes are defined yet."""
2981
2982    timestamp: Datetime = Field(default_factory=Datetime.now)
2983    """Timestamp in [ISO 8601](#https://en.wikipedia.org/wiki/ISO_8601) format
2984    with a few restrictions listed [here](https://docs.python.org/3/library/datetime.html#datetime.datetime.fromisoformat).
2985    (In Python a datetime object is valid, too)."""
2986
2987    training_data: Annotated[
2988        Union[None, LinkedDataset, DatasetDescr, DatasetDescr02],
2989        Field(union_mode="left_to_right"),
2990    ] = None
2991    """The dataset used to train this model"""
2992
2993    weights: Annotated[WeightsDescr, WrapSerializer(package_weights)]
2994    """The weights for this model.
2995    Weights can be given for different formats, but should otherwise be equivalent.
2996    The available weight formats determine which consumers can use this model."""
2997
2998    config: Config = Field(default_factory=Config.model_construct)
2999
3000    @model_validator(mode="after")
3001    def _add_default_cover(self) -> Self:
3002        if not get_validation_context().perform_io_checks or self.covers:
3003            return self
3004
3005        try:
3006            generated_covers = generate_covers(
3007                [
3008                    (t, load_array(t.test_tensor))
3009                    for t in self.inputs
3010                    if t.test_tensor is not None
3011                ],
3012                [
3013                    (t, load_array(t.test_tensor))
3014                    for t in self.outputs
3015                    if t.test_tensor is not None
3016                ],
3017            )
3018        except Exception as e:
3019            issue_warning(
3020                "Failed to generate cover image(s): {e}",
3021                value=self.covers,
3022                msg_context=dict(e=e),
3023                field="covers",
3024            )
3025        else:
3026            self.covers.extend(generated_covers)
3027
3028        return self
3029
3030    def get_input_test_arrays(self) -> List[NDArray[Any]]:
3031        return self._get_test_arrays(self.inputs)
3032
3033    def get_output_test_arrays(self) -> List[NDArray[Any]]:
3034        return self._get_test_arrays(self.outputs)
3035
3036    @staticmethod
3037    def _get_test_arrays(
3038        io_descr: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]],
3039    ):
3040        ts: List[FileDescr] = []
3041        for d in io_descr:
3042            if d.test_tensor is None:
3043                raise ValueError(
3044                    f"Failed to get test arrays: description of '{d.id}' is missing a `test_tensor`."
3045                )
3046            ts.append(d.test_tensor)
3047
3048        data = [load_array(t) for t in ts]
3049        assert all(isinstance(d, np.ndarray) for d in data)
3050        return data
3051
3052    @staticmethod
3053    def get_batch_size(tensor_sizes: Mapping[TensorId, Mapping[AxisId, int]]) -> int:
3054        batch_size = 1
3055        tensor_with_batchsize: Optional[TensorId] = None
3056        for tid in tensor_sizes:
3057            for aid, s in tensor_sizes[tid].items():
3058                if aid != BATCH_AXIS_ID or s == 1 or s == batch_size:
3059                    continue
3060
3061                if batch_size != 1:
3062                    assert tensor_with_batchsize is not None
3063                    raise ValueError(
3064                        f"batch size mismatch for tensors '{tensor_with_batchsize}' ({batch_size}) and '{tid}' ({s})"
3065                    )
3066
3067                batch_size = s
3068                tensor_with_batchsize = tid
3069
3070        return batch_size
3071
3072    def get_output_tensor_sizes(
3073        self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]
3074    ) -> Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]:
3075        """Returns the tensor output sizes for given **input_sizes**.
3076        The returned sizes are exact only if **input_sizes** is a valid input shape;
3077        otherwise they may be larger than the actual (valid) output sizes."""
3078        batch_size = self.get_batch_size(input_sizes)
3079        ns = self.get_ns(input_sizes)
3080
3081        tensor_sizes = self.get_tensor_sizes(ns, batch_size=batch_size)
3082        return tensor_sizes.outputs
3083
3084    def get_ns(self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]):
3085        """get parameter `n` for each parameterized axis
3086        such that the valid input size is >= the given input size"""
3087        ret: Dict[Tuple[TensorId, AxisId], ParameterizedSize_N] = {}
3088        axes = {t.id: {a.id: a for a in t.axes} for t in self.inputs}
3089        for tid in input_sizes:
3090            for aid, s in input_sizes[tid].items():
3091                size_descr = axes[tid][aid].size
3092                if isinstance(size_descr, ParameterizedSize):
3093                    ret[(tid, aid)] = size_descr.get_n(s)
3094                elif size_descr is None or isinstance(size_descr, (int, SizeReference)):
3095                    pass
3096                else:
3097                    assert_never(size_descr)
3098
3099        return ret
3100
3101    def get_tensor_sizes(
3102        self, ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], batch_size: int
3103    ) -> _TensorSizes:
3104        axis_sizes = self.get_axis_sizes(ns, batch_size=batch_size)
3105        return _TensorSizes(
3106            {
3107                t: {
3108                    aa: axis_sizes.inputs[(tt, aa)]
3109                    for tt, aa in axis_sizes.inputs
3110                    if tt == t
3111                }
3112                for t in {tt for tt, _ in axis_sizes.inputs}
3113            },
3114            {
3115                t: {
3116                    aa: axis_sizes.outputs[(tt, aa)]
3117                    for tt, aa in axis_sizes.outputs
3118                    if tt == t
3119                }
3120                for t in {tt for tt, _ in axis_sizes.outputs}
3121            },
3122        )
3123
3124    def get_axis_sizes(
3125        self,
3126        ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N],
3127        batch_size: Optional[int] = None,
3128        *,
3129        max_input_shape: Optional[Mapping[Tuple[TensorId, AxisId], int]] = None,
3130    ) -> _AxisSizes:
3131        """Determine input and output block shape for scale factors **ns**
3132        of parameterized input sizes.
3133
3134        Args:
3135            ns: Scale factor `n` for each axis (keyed by (tensor_id, axis_id))
3136                that is parameterized as `size = min + n * step`.
3137            batch_size: The desired size of the batch dimension.
3138                If given, **batch_size** overwrites any batch size present in
3139                **max_input_shape**. Default 1.
3140            max_input_shape: Limits the derived block shapes.
3141                Each axis for which the input size, parameterized by `n`, is larger
3142                than **max_input_shape** is set to the minimal value `n_min` for which
3143                this is still true.
3144                Use this for small input samples or large values of **ns**,
3145                or simply whenever you know the full input shape.
3146
3147        Returns:
3148            Resolved axis sizes for model inputs and outputs.
3149        """
3150        max_input_shape = max_input_shape or {}
3151        if batch_size is None:
3152            for (_t_id, a_id), s in max_input_shape.items():
3153                if a_id == BATCH_AXIS_ID:
3154                    batch_size = s
3155                    break
3156            else:
3157                batch_size = 1
3158
3159        all_axes = {
3160            t.id: {a.id: a for a in t.axes} for t in chain(self.inputs, self.outputs)
3161        }
3162
3163        inputs: Dict[Tuple[TensorId, AxisId], int] = {}
3164        outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]] = {}
3165
3166        def get_axis_size(a: Union[InputAxis, OutputAxis]):
3167            if isinstance(a, BatchAxis):
3168                if (t_descr.id, a.id) in ns:
3169                    logger.warning(
3170                        "Ignoring unexpected size increment factor (n) for batch axis"
3171                        + " of tensor '{}'.",
3172                        t_descr.id,
3173                    )
3174                return batch_size
3175            elif isinstance(a.size, int):
3176                if (t_descr.id, a.id) in ns:
3177                    logger.warning(
3178                        "Ignoring unexpected size increment factor (n) for fixed size"
3179                        + " axis '{}' of tensor '{}'.",
3180                        a.id,
3181                        t_descr.id,
3182                    )
3183                return a.size
3184            elif isinstance(a.size, ParameterizedSize):
3185                if (t_descr.id, a.id) not in ns:
3186                    raise ValueError(
3187                        "Size increment factor (n) missing for parametrized axis"
3188                        + f" '{a.id}' of tensor '{t_descr.id}'."
3189                    )
3190                n = ns[(t_descr.id, a.id)]
3191                s_max = max_input_shape.get((t_descr.id, a.id))
3192                if s_max is not None:
3193                    n = min(n, a.size.get_n(s_max))
3194
3195                return a.size.get_size(n)
3196
3197            elif isinstance(a.size, SizeReference):
3198                if (t_descr.id, a.id) in ns:
3199                    logger.warning(
3200                        "Ignoring unexpected size increment factor (n) for axis '{}'"
3201                        + " of tensor '{}' with size reference.",
3202                        a.id,
3203                        t_descr.id,
3204                    )
3205                assert not isinstance(a, BatchAxis)
3206                ref_axis = all_axes[a.size.tensor_id][a.size.axis_id]
3207                assert not isinstance(ref_axis, BatchAxis)
3208                ref_key = (a.size.tensor_id, a.size.axis_id)
3209                ref_size = inputs.get(ref_key, outputs.get(ref_key))
3210                assert ref_size is not None, ref_key
3211                assert not isinstance(ref_size, _DataDepSize), ref_key
3212                return a.size.get_size(
3213                    axis=a,
3214                    ref_axis=ref_axis,
3215                    ref_size=ref_size,
3216                )
3217            elif isinstance(a.size, DataDependentSize):
3218                if (t_descr.id, a.id) in ns:
3219                    logger.warning(
3220                        "Ignoring unexpected increment factor (n) for data dependent"
3221                        + " size axis '{}' of tensor '{}'.",
3222                        a.id,
3223                        t_descr.id,
3224                    )
3225                return _DataDepSize(a.size.min, a.size.max)
3226            else:
3227                assert_never(a.size)
3228
3229        # first resolve all input axis sizes except those given as a `SizeReference`
3230        for t_descr in self.inputs:
3231            for a in t_descr.axes:
3232                if not isinstance(a.size, SizeReference):
3233                    s = get_axis_size(a)
3234                    assert not isinstance(s, _DataDepSize)
3235                    inputs[t_descr.id, a.id] = s
3236
3237        # resolve all other input axis sizes
3238        for t_descr in self.inputs:
3239            for a in t_descr.axes:
3240                if isinstance(a.size, SizeReference):
3241                    s = get_axis_size(a)
3242                    assert not isinstance(s, _DataDepSize)
3243                    inputs[t_descr.id, a.id] = s
3244
3245        # resolve all output axis sizes
3246        for t_descr in self.outputs:
3247            for a in t_descr.axes:
3248                assert not isinstance(a.size, ParameterizedSize)
3249                s = get_axis_size(a)
3250                outputs[t_descr.id, a.id] = s
3251
3252        return _AxisSizes(inputs=inputs, outputs=outputs)
3253
3254    @model_validator(mode="before")
3255    @classmethod
3256    def _convert(cls, data: Dict[str, Any]) -> Dict[str, Any]:
3257        cls.convert_from_old_format_wo_validation(data)
3258        return data
3259
3260    @classmethod
3261    def convert_from_old_format_wo_validation(cls, data: Dict[str, Any]) -> None:
3262        """Convert metadata following an older format version to this classes' format
3263        without validating the result.
3264        """
3265        if (
3266            data.get("type") == "model"
3267            and isinstance(fv := data.get("format_version"), str)
3268            and fv.count(".") == 2
3269        ):
3270            fv_parts = fv.split(".")
3271            if any(not p.isdigit() for p in fv_parts):
3272                return
3273
3274            fv_tuple = tuple(map(int, fv_parts))
3275
3276            assert cls.implemented_format_version_tuple[0:2] == (0, 5)
3277            if fv_tuple[:2] in ((0, 3), (0, 4)):
3278                m04 = _ModelDescr_v0_4.load(data)
3279                if isinstance(m04, InvalidDescr):
3280                    try:
3281                        updated = _model_conv.convert_as_dict(
3282                            m04  # pyright: ignore[reportArgumentType]
3283                        )
3284                    except Exception as e:
3285                        logger.error(
3286                            "Failed to convert from invalid model 0.4 description."
3287                            + f"\nerror: {e}"
3288                            + "\nProceeding with model 0.5 validation without conversion."
3289                        )
3290                        updated = None
3291                else:
3292                    updated = _model_conv.convert_as_dict(m04)
3293
3294                if updated is not None:
3295                    data.clear()
3296                    data.update(updated)
3297
3298            elif fv_tuple[:2] == (0, 5):
3299                # bump patch version
3300                data["format_version"] = cls.implemented_format_version

Specification of the fields used in a bioimage.io-compliant RDF to describe AI models with pretrained weights. These fields are typically stored in a YAML file which we call a model resource description file (model RDF).

implemented_format_version: ClassVar[Literal['0.5.5']] = '0.5.5'
implemented_type: ClassVar[Literal['model']] = 'model'

bioimage.io-wide unique resource identifier assigned by bioimage.io; version unspecific.

authors: Annotated[List[bioimageio.spec.generic.v0_3.Author], AfterWarner(...)]

The authors are the creators of the model RDF and the primary points of contact.

documentation: Annotated[Optional[Union[bioimageio.spec._internal.url.HttpUrl, bioimageio.spec._internal.io.RelativeFilePath, pathlib.Path]], WithSuffix(suffix='.md', case_sensitive=True), AfterWarner(...)] (examples: 'https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/unet2d_nuclei_broad/README.md', 'README.md')

URL or relative path to a markdown file with additional documentation. The recommended documentation file name is README.md. An .md suffix is mandatory. The documentation should include a '#[#] Validation' (sub)section with details on how to quantitatively validate the model on unseen data.
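
The '#[#] Validation' requirement is enforced with a simple regular expression (see _validate_documentation in the source above); a small sketch of the same check:

import re

doc = "# My Model\n\n## Validation\nEvaluate IoU on the held-out test split."
assert re.search("#.*[vV]alidation", doc)  # passes: a '## Validation' heading exists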

inputs: Annotated[Sequence[bioimageio.spec.model.v0_5.InputTensorDescr], MinLen(min_length=1)]

Describes the input tensors expected by this model.
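
As an illustration only, the sketch below describes one input tensor with a batch axis and two parameterized space axes; the names are from bioimageio.spec.model.v0_5, and a complete description would typically also set test_tensor, data and preprocessing.

from bioimageio.spec.model.v0_5 import (
    AxisId,
    BatchAxis,
    InputTensorDescr,
    ParameterizedSize,
    SpaceInputAxis,
    TensorId,
)

# illustrative sketch of a single input tensor description
ipt = InputTensorDescr(
    id=TensorId("raw"),
    axes=[
        BatchAxis(),  # batch size is left to the consumer
        SpaceInputAxis(id=AxisId("y"), size=ParameterizedSize(min=64, step=16)),
        SpaceInputAxis(id=AxisId("x"), size=ParameterizedSize(min=64, step=16)),
    ],
)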

name: Annotated[str, RestrictCharacters(alphabet='abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_+- ()'), MinLen(min_length=5), MaxLen(max_length=128), AfterWarner(..., severity=20, msg='Name longer than 64 characters.')]

A human-readable name of this model. It should be no longer than 64 characters and may only contain letters, numbers, underscores, minus signs, parentheses and spaces. We recommend choosing a name that refers to the model's task and image modality.
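
The constraints above (allowed alphabet, 5-128 characters) can be expressed as a plain regular expression; a standard-library sketch:

import re
import string

allowed = re.escape(string.ascii_letters + string.digits + "_+- ()")
name_re = re.compile(f"^[{allowed}]{{5,128}}$")

assert name_re.match("UNet2D Nucleus Segmentation (fine-tuned)")
assert not name_re.match("ab")        # too short: MinLen(5)
assert not name_re.match("model/v1")  # '/' is not in the allowed alphabet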

outputs: Annotated[Sequence[bioimageio.spec.model.v0_5.OutputTensorDescr], MinLen(min_length=1)]

Describes the output tensors.

The persons that have packaged and uploaded this model. Only required if those persons differ from the authors.

The model from which this model is derived, e.g. by fine-tuning the weights.

run_mode: Annotated[Optional[bioimageio.spec.model.v0_4.RunMode], AfterWarner(..., severity=30, msg="Run mode '{value}' has limited support across consumer software.")]

Custom run mode for this model: for more complex prediction procedures like test time data augmentation that currently cannot be expressed in the specification. No standard run modes are defined yet.

Timestamp in ISO 8601 format with a few restrictions listed in the Python documentation for datetime.fromisoformat. (In Python a datetime object is valid, too).
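
Acceptable string values are those parseable by Python's datetime.fromisoformat; for example:

from datetime import datetime

# a valid model RDF timestamp string
datetime.fromisoformat("2019-12-11T12:22:32+00:00")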

training_data: Annotated[Union[None, bioimageio.spec.dataset.v0_3.LinkedDataset, bioimageio.spec.DatasetDescr, bioimageio.spec.dataset.v0_2.DatasetDescr], Field(union_mode='left_to_right')]

The dataset used to train this model

weights: Annotated[bioimageio.spec.model.v0_5.WeightsDescr, WrapSerializer(package_weights, when_used='always')]

The weights for this model. Weights can be given for different formats, but should otherwise be equivalent. The available weight formats determine which consumers can use this model.

def get_input_test_arrays(self) -> List[numpy.ndarray[tuple[Any, ...], numpy.dtype[Any]]]:
def get_output_test_arrays(self) -> List[numpy.ndarray[tuple[Any, ...], numpy.dtype[Any]]]:
@staticmethod
def get_batch_size(tensor_sizes: Mapping[bioimageio.spec.model.v0_5.TensorId, Mapping[bioimageio.spec.model.v0_5.AxisId, int]]) -> int:
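
Since get_batch_size is a static method over plain mappings, it can be exercised directly; tensor and axis ids below are illustrative:

from bioimageio.spec.model.v0_5 import AxisId, ModelDescr, TensorId

sizes = {
    TensorId("raw"): {AxisId("batch"): 4, AxisId("x"): 128},
    TensorId("mask"): {AxisId("batch"): 4, AxisId("x"): 128},
}
assert ModelDescr.get_batch_size(sizes) == 4

sizes[TensorId("mask")][AxisId("batch")] = 2
# ModelDescr.get_batch_size(sizes) would now raise ValueError:
# "batch size mismatch for tensors 'raw' (4) and 'mask' (2)"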
def get_output_tensor_sizes(self, input_sizes: Mapping[bioimageio.spec.model.v0_5.TensorId, Mapping[bioimageio.spec.model.v0_5.AxisId, int]]) -> Dict[bioimageio.spec.model.v0_5.TensorId, Dict[bioimageio.spec.model.v0_5.AxisId, Union[int, bioimageio.spec.model.v0_5._DataDepSize]]]:

Returns the tensor output sizes for given input_sizes. The returned sizes are exact only if input_sizes is a valid input shape; otherwise they may be larger than the actual (valid) output sizes.
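
A usage sketch, assuming `model` is a loaded ModelDescr and using illustrative tensor/axis ids:

from bioimageio.spec.model.v0_5 import AxisId, TensorId

out_sizes = model.get_output_tensor_sizes(
    {TensorId("raw"): {AxisId("batch"): 1, AxisId("y"): 512, AxisId("x"): 512}}
)
# e.g. {TensorId('segmentation'): {AxisId('batch'): 1, AxisId('y'): 512, AxisId('x'): 512}}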

def get_ns(self, input_sizes: Mapping[bioimageio.spec.model.v0_5.TensorId, Mapping[bioimageio.spec.model.v0_5.AxisId, int]]):

Get parameter n for each parameterized axis such that the valid input size is >= the given input size.
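
For a single parameterized axis the relationship is size = min + n * step, and get_n picks the smallest n whose valid size covers the requested one; a worked example with ParameterizedSize from this module:

from bioimageio.spec.model.v0_5 import ParameterizedSize

p = ParameterizedSize(min=64, step=16)
n = p.get_n(100)             # smallest n such that 64 + n * 16 >= 100
assert n == 3
assert p.get_size(n) == 112  # the valid size actually used (>= 100)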

def get_tensor_sizes(self, ns: Mapping[Tuple[bioimageio.spec.model.v0_5.TensorId, bioimageio.spec.model.v0_5.AxisId], int], batch_size: int) -> bioimageio.spec.model.v0_5._TensorSizes:
def get_axis_sizes(self, ns: Mapping[Tuple[bioimageio.spec.model.v0_5.TensorId, bioimageio.spec.model.v0_5.AxisId], int], batch_size: Optional[int] = None, *, max_input_shape: Optional[Mapping[Tuple[bioimageio.spec.model.v0_5.TensorId, bioimageio.spec.model.v0_5.AxisId], int]] = None) -> bioimageio.spec.model.v0_5._AxisSizes:

Determine input and output block shape for scale factors ns of parameterized input sizes.

Arguments:
  • ns: Scale factor n for each axis (keyed by (tensor_id, axis_id)) that is parameterized as size = min + n * step.
  • batch_size: The desired size of the batch dimension. If given, batch_size overwrites any batch size present in max_input_shape. Default 1.
  • max_input_shape: Limits the derived block shapes. Each axis for which the input size, parameterized by n, is larger than max_input_shape is set to the minimal value n_min for which this is still true. Use this for small input samples or large values of ns, or simply whenever you know the full input shape.
Returns:

Resolved axis sizes for model inputs and outputs.
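
A sketch, assuming `model` is a loaded ModelDescr whose input 'raw' has a parameterized 'x' axis (ids are illustrative):

from bioimageio.spec.model.v0_5 import AxisId, TensorId

axis_sizes = model.get_axis_sizes(
    ns={(TensorId("raw"), AxisId("x")): 10},
    # cap the derived block shape by the known image extent:
    max_input_shape={(TensorId("raw"), AxisId("x")): 128},
)
print(axis_sizes.inputs[(TensorId("raw"), AxisId("x"))])  # n was reduced to fit 128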

@classmethod
def convert_from_old_format_wo_validation(cls, data: Dict[str, Any]) -> None:

Convert metadata following an older format version to this class's format without validating the result.
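
The conversion mutates the given dict in place; a minimal sketch (a real model 0.4 RDF carries many more fields, so conversion of this stub would be skipped with an error logged):

from bioimageio.spec.model.v0_5 import ModelDescr

data = {
    "type": "model",
    "format_version": "0.4.10",
    # ... remaining model 0.4 fields ...
}
ModelDescr.convert_from_old_format_wo_validation(data)
# on success `data` now follows the 0.5 layout; for data that is already 0.5.x
# only the patch version is bumped to "0.5.5"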

implemented_format_version_tuple: ClassVar[Tuple[int, int, int]] = (0, 5, 5)
model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'always', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': True, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

def model_post_init(self: pydantic.main.BaseModel, context: Any, /) -> None:
337def init_private_attributes(self: BaseModel, context: Any, /) -> None:
338    """This function is meant to behave like a BaseModel method to initialise private attributes.
339
340    It takes context as an argument since that's what pydantic-core passes when calling it.
341
342    Args:
343        self: The BaseModel instance.
344        context: The context.
345    """
346    if getattr(self, '__pydantic_private__', None) is None:
347        pydantic_private = {}
348        for name, private_attr in self.__private_attributes__.items():
349            default = private_attr.get_default()
350            if default is not PydanticUndefined:
351                pydantic_private[name] = default
352        object_setattr(self, '__pydantic_private__', pydantic_private)

This function is meant to behave like a BaseModel method to initialise private attributes.

It takes context as an argument since that's what pydantic-core passes when calling it.

Arguments:
  • self: The BaseModel instance.
  • context: The context.
ModelDescr_v0_4 = <class 'bioimageio.spec.model.v0_4.ModelDescr'>
ModelDescr_v0_5 = <class 'ModelDescr'>
AnyModelDescr = typing.Annotated[typing.Union[typing.Annotated[bioimageio.spec.model.v0_4.ModelDescr, FieldInfo(annotation=NoneType, required=True, title='model 0.4')], typing.Annotated[ModelDescr, FieldInfo(annotation=NoneType, required=True, title='model 0.5')]], Discriminator(discriminator='format_version', custom_error_type=None, custom_error_message=None, custom_error_context=None), FieldInfo(annotation=NoneType, required=True, title='model')]

Union of any released model description