bioimageio.spec.model

Implementations of all released minor versions are available in submodules:

 1# autogen: start
 2"""
 3Implementations of all released minor versions are available in submodules:
 4- model v0_4: `bioimageio.spec.model.v0_4.ModelDescr`
 5- model v0_5: `bioimageio.spec.model.v0_5.ModelDescr`
 6"""
 7
 8from typing import Union
 9
10from pydantic import Discriminator, Field
11from typing_extensions import Annotated
12
13from . import v0_4, v0_5
14
15ModelDescr = v0_5.ModelDescr
16ModelDescr_v0_4 = v0_4.ModelDescr
17ModelDescr_v0_5 = v0_5.ModelDescr
18
19AnyModelDescr = Annotated[
20    Union[
21        Annotated[ModelDescr_v0_4, Field(title="model 0.4")],
22        Annotated[ModelDescr_v0_5, Field(title="model 0.5")],
23    ],
24    Discriminator("format_version"),
25    Field(title="model"),
26]
27"""Union of any released model desription"""
28# autogen: stop
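
For example, a model description can be loaded and validated like this (a minimal sketch; `load_description` is exposed by the top-level `bioimageio.spec` package and the path is illustrative):

```python
from bioimageio.spec import load_description
from bioimageio.spec.model import ModelDescr

# load and validate a model RDF; returns an InvalidDescr on validation failure
descr = load_description("path/to/rdf.yaml")
assert isinstance(descr, ModelDescr)  # holds only for a valid model 0.5 RDF
```
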
2554class ModelDescr(GenericModelDescrBase):
2555    """Specification of the fields used in a bioimage.io-compliant RDF to describe AI models with pretrained weights.
2556    These fields are typically stored in a YAML file which we call a model resource description file (model RDF).
2557    """
2558
2559    implemented_format_version: ClassVar[Literal["0.5.4"]] = "0.5.4"
2560    if TYPE_CHECKING:
2561        format_version: Literal["0.5.4"] = "0.5.4"
2562    else:
2563        format_version: Literal["0.5.4"]
2564        """Version of the bioimage.io model description specification used.
2565        When creating a new model, always use the latest micro/patch version described here.
2566        The `format_version` is important for any consumer software to understand how to parse the fields.
2567        """
2568
2569    implemented_type: ClassVar[Literal["model"]] = "model"
2570    if TYPE_CHECKING:
2571        type: Literal["model"] = "model"
2572    else:
2573        type: Literal["model"]
2574        """Specialized resource type 'model'"""
2575
2576    id: Optional[ModelId] = None
2577    """bioimage.io-wide unique resource identifier
2578    assigned by bioimage.io; version **un**specific."""
2579
2580    authors: NotEmpty[List[Author]]
2581    """The authors are the creators of the model RDF and the primary points of contact."""
2582
2583    documentation: FileSource_documentation
2584    """URL or relative path to a markdown file with additional documentation.
2585    The recommended documentation file name is `README.md`. An `.md` suffix is mandatory.
2586    The documentation should include a '#[#] Validation' (sub)section
2587    with details on how to quantitatively validate the model on unseen data."""
2588
2589    @field_validator("documentation", mode="after")
2590    @classmethod
2591    def _validate_documentation(
2592        cls, value: FileSource_documentation
2593    ) -> FileSource_documentation:
2594        if not get_validation_context().perform_io_checks:
2595            return value
2596
2597        doc_reader = get_reader(value)
2598        doc_content = doc_reader.read().decode(encoding="utf-8")
2599        if not re.search("#.*[vV]alidation", doc_content):
2600            issue_warning(
2601                "No '# Validation' (sub)section found in {value}.",
2602                value=value,
2603                field="documentation",
2604            )
2605
2606        return value
2607
2608    inputs: NotEmpty[Sequence[InputTensorDescr]]
2609    """Describes the input tensors expected by this model."""
2610
2611    @field_validator("inputs", mode="after")
2612    @classmethod
2613    def _validate_input_axes(
2614        cls, inputs: Sequence[InputTensorDescr]
2615    ) -> Sequence[InputTensorDescr]:
2616        input_size_refs = cls._get_axes_with_independent_size(inputs)
2617
2618        for i, ipt in enumerate(inputs):
2619            valid_independent_refs: Dict[
2620                Tuple[TensorId, AxisId],
2621                Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
2622            ] = {
2623                **{
2624                    (ipt.id, a.id): (ipt, a, a.size)
2625                    for a in ipt.axes
2626                    if not isinstance(a, BatchAxis)
2627                    and isinstance(a.size, (int, ParameterizedSize))
2628                },
2629                **input_size_refs,
2630            }
2631            for a, ax in enumerate(ipt.axes):
2632                cls._validate_axis(
2633                    "inputs",
2634                    i=i,
2635                    tensor_id=ipt.id,
2636                    a=a,
2637                    axis=ax,
2638                    valid_independent_refs=valid_independent_refs,
2639                )
2640        return inputs
2641
2642    @staticmethod
2643    def _validate_axis(
2644        field_name: str,
2645        i: int,
2646        tensor_id: TensorId,
2647        a: int,
2648        axis: AnyAxis,
2649        valid_independent_refs: Dict[
2650            Tuple[TensorId, AxisId],
2651            Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
2652        ],
2653    ):
2654        if isinstance(axis, BatchAxis) or isinstance(
2655            axis.size, (int, ParameterizedSize, DataDependentSize)
2656        ):
2657            return
2658        elif not isinstance(axis.size, SizeReference):
2659            assert_never(axis.size)
2660
2661        # validate axis.size SizeReference
2662        ref = (axis.size.tensor_id, axis.size.axis_id)
2663        if ref not in valid_independent_refs:
2664            raise ValueError(
2665                "Invalid tensor axis reference at"
2666                + f" {field_name}[{i}].axes[{a}].size: {axis.size}."
2667            )
2668        if ref == (tensor_id, axis.id):
2669            raise ValueError(
2670                "Self-referencing not allowed for"
2671                + f" {field_name}[{i}].axes[{a}].size: {axis.size}"
2672            )
2673        if axis.type == "channel":
2674            if valid_independent_refs[ref][1].type != "channel":
2675                raise ValueError(
2676                    "A channel axis' size may only reference another fixed size"
2677                    + " channel axis."
2678                )
2679            if isinstance(axis.channel_names, str) and "{i}" in axis.channel_names:
2680                ref_size = valid_independent_refs[ref][2]
2681                assert isinstance(ref_size, int), (
2682                    "channel axis ref (another channel axis) has to specify fixed"
2683                    + " size"
2684                )
2685                generated_channel_names = [
2686                    Identifier(axis.channel_names.format(i=i))
2687                    for i in range(1, ref_size + 1)
2688                ]
2689                axis.channel_names = generated_channel_names
2690
2691        if (ax_unit := getattr(axis, "unit", None)) != (
2692            ref_unit := getattr(valid_independent_refs[ref][1], "unit", None)
2693        ):
2694            raise ValueError(
2695                "The units of an axis and its reference axis need to match, but"
2696                + f" '{ax_unit}' != '{ref_unit}'."
2697            )
2698        ref_axis = valid_independent_refs[ref][1]
2699        if isinstance(ref_axis, BatchAxis):
2700            raise ValueError(
2701                f"Invalid reference axis '{ref_axis.id}' for {tensor_id}.{axis.id}"
2702                + " (a batch axis is not allowed as reference)."
2703            )
2704
2705        if isinstance(axis, WithHalo):
2706            min_size = axis.size.get_size(axis, ref_axis, n=0)
2707            if (min_size - 2 * axis.halo) < 1:
2708                raise ValueError(
2709                    f"axis {axis.id} with minimum size {min_size} is too small for halo"
2710                    + f" {axis.halo}."
2711                )
2712
2713            input_halo = axis.halo * axis.scale / ref_axis.scale
2714            if input_halo != int(input_halo) or input_halo % 2 == 1:
2715                raise ValueError(
2716                    f"input_halo {input_halo} (output_halo {axis.halo} *"
2717                    + f" output_scale {axis.scale} / input_scale {ref_axis.scale})"
2718                    + f" has to be an even integer for {tensor_id}.{axis.id}."
2719                )
2720
2721    @model_validator(mode="after")
2722    def _validate_test_tensors(self) -> Self:
2723        if not get_validation_context().perform_io_checks:
2724            return self
2725
2726        test_output_arrays = [load_array(descr.test_tensor) for descr in self.outputs]
2727        test_input_arrays = [load_array(descr.test_tensor) for descr in self.inputs]
2728
2729        tensors = {
2730            descr.id: (descr, array)
2731            for descr, array in zip(
2732                chain(self.inputs, self.outputs), test_input_arrays + test_output_arrays
2733            )
2734        }
2735        validate_tensors(tensors, tensor_origin="test_tensor")
2736
2737        output_arrays = {
2738            descr.id: array for descr, array in zip(self.outputs, test_output_arrays)
2739        }
2740        for rep_tol in self.config.bioimageio.reproducibility_tolerance:
2741            if not rep_tol.absolute_tolerance:
2742                continue
2743
2744            if rep_tol.output_ids:
2745                out_arrays = {
2746                    oid: a
2747                    for oid, a in output_arrays.items()
2748                    if oid in rep_tol.output_ids
2749                }
2750            else:
2751                out_arrays = output_arrays
2752
2753            for out_id, array in out_arrays.items():
2754                if rep_tol.absolute_tolerance > (max_test_value := array.max()) * 0.01:
2755                    raise ValueError(
2756                        "config.bioimageio.reproducibility_tolerance.absolute_tolerance="
2757                        + f"{rep_tol.absolute_tolerance} > 0.01*{max_test_value}"
2758                        + f" (1% of the maximum value of the test tensor '{out_id}')"
2759                    )
2760
2761        return self
2762
2763    @model_validator(mode="after")
2764    def _validate_tensor_references_in_proc_kwargs(self, info: ValidationInfo) -> Self:
2765        ipt_refs = {t.id for t in self.inputs}
2766        out_refs = {t.id for t in self.outputs}
2767        for ipt in self.inputs:
2768            for p in ipt.preprocessing:
2769                ref = p.kwargs.get("reference_tensor")
2770                if ref is None:
2771                    continue
2772                if ref not in ipt_refs:
2773                    raise ValueError(
2774                        f"`reference_tensor` '{ref}' not found. Valid input tensor"
2775                        + f" references are: {ipt_refs}."
2776                    )
2777
2778        for out in self.outputs:
2779            for p in out.postprocessing:
2780                ref = p.kwargs.get("reference_tensor")
2781                if ref is None:
2782                    continue
2783
2784                if ref not in ipt_refs and ref not in out_refs:
2785                    raise ValueError(
2786                        f"`reference_tensor` '{ref}' not found. Valid tensor references"
2787                        + f" are: {ipt_refs | out_refs}."
2788                    )
2789
2790        return self
2791
2792    # TODO: use validate funcs in validate_test_tensors
2793    # def validate_inputs(self, input_tensors: Mapping[TensorId, NDArray[Any]]) -> Mapping[TensorId, NDArray[Any]]:
2794
2795    name: Annotated[
2796        Annotated[
2797            str, RestrictCharacters(string.ascii_letters + string.digits + "_+- ()")
2798        ],
2799        MinLen(5),
2800        MaxLen(128),
2801        warn(MaxLen(64), "Name longer than 64 characters.", INFO),
2802    ]
2803    """A human-readable name of this model.
2804    It should be no longer than 64 characters
2805    and may only contain letters, numbers, underscores, minus signs, parentheses and spaces.
2806    We recommend choosing a name that refers to the model's task and image modality.
2807    """
2808
2809    outputs: NotEmpty[Sequence[OutputTensorDescr]]
2810    """Describes the output tensors."""
2811
2812    @field_validator("outputs", mode="after")
2813    @classmethod
2814    def _validate_tensor_ids(
2815        cls, outputs: Sequence[OutputTensorDescr], info: ValidationInfo
2816    ) -> Sequence[OutputTensorDescr]:
2817        tensor_ids = [
2818            t.id for t in info.data.get("inputs", []) + info.data.get("outputs", [])
2819        ]
2820        duplicate_tensor_ids: List[str] = []
2821        seen: Set[str] = set()
2822        for t in tensor_ids:
2823            if t in seen:
2824                duplicate_tensor_ids.append(t)
2825
2826            seen.add(t)
2827
2828        if duplicate_tensor_ids:
2829            raise ValueError(f"Duplicate tensor ids: {duplicate_tensor_ids}")
2830
2831        return outputs
2832
2833    @staticmethod
2834    def _get_axes_with_parameterized_size(
2835        io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]],
2836    ):
2837        return {
2838            f"{t.id}.{a.id}": (t, a, a.size)
2839            for t in io
2840            for a in t.axes
2841            if not isinstance(a, BatchAxis) and isinstance(a.size, ParameterizedSize)
2842        }
2843
2844    @staticmethod
2845    def _get_axes_with_independent_size(
2846        io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]],
2847    ):
2848        return {
2849            (t.id, a.id): (t, a, a.size)
2850            for t in io
2851            for a in t.axes
2852            if not isinstance(a, BatchAxis)
2853            and isinstance(a.size, (int, ParameterizedSize))
2854        }
2855
2856    @field_validator("outputs", mode="after")
2857    @classmethod
2858    def _validate_output_axes(
2859        cls, outputs: List[OutputTensorDescr], info: ValidationInfo
2860    ) -> List[OutputTensorDescr]:
2861        input_size_refs = cls._get_axes_with_independent_size(
2862            info.data.get("inputs", [])
2863        )
2864        output_size_refs = cls._get_axes_with_independent_size(outputs)
2865
2866        for i, out in enumerate(outputs):
2867            valid_independent_refs: Dict[
2868                Tuple[TensorId, AxisId],
2869                Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
2870            ] = {
2871                **{
2872                    (out.id, a.id): (out, a, a.size)
2873                    for a in out.axes
2874                    if not isinstance(a, BatchAxis)
2875                    and isinstance(a.size, (int, ParameterizedSize))
2876                },
2877                **input_size_refs,
2878                **output_size_refs,
2879            }
2880            for a, ax in enumerate(out.axes):
2881                cls._validate_axis(
2882                    "outputs",
2883                    i,
2884                    out.id,
2885                    a,
2886                    ax,
2887                    valid_independent_refs=valid_independent_refs,
2888                )
2889
2890        return outputs
2891
2892    packaged_by: List[Author] = Field(
2893        default_factory=cast(Callable[[], List[Author]], list)
2894    )
2895    """The persons that have packaged and uploaded this model.
2896    Only required if those persons differ from the `authors`."""
2897
2898    parent: Optional[LinkedModel] = None
2899    """The model from which this model is derived, e.g. by fine-tuning the weights."""
2900
2901    @model_validator(mode="after")
2902    def _validate_parent_is_not_self(self) -> Self:
2903        if self.parent is not None and self.parent.id == self.id:
2904            raise ValueError("A model description may not reference itself as parent.")
2905
2906        return self
2907
2908    run_mode: Annotated[
2909        Optional[RunMode],
2910        warn(None, "Run mode '{value}' has limited support across consumer software."),
2911    ] = None
2912    """Custom run mode for this model: for more complex prediction procedures like test time
2913    data augmentation that currently cannot be expressed in the specification.
2914    No standard run modes are defined yet."""
2915
2916    timestamp: Datetime = Field(default_factory=Datetime.now)
2917    """Timestamp in [ISO 8601](https://en.wikipedia.org/wiki/ISO_8601) format
2918    with a few restrictions listed [here](https://docs.python.org/3/library/datetime.html#datetime.datetime.fromisoformat).
2919    (In Python a datetime object is valid, too)."""
2920
2921    training_data: Annotated[
2922        Union[None, LinkedDataset, DatasetDescr, DatasetDescr02],
2923        Field(union_mode="left_to_right"),
2924    ] = None
2925    """The dataset used to train this model"""
2926
2927    weights: Annotated[WeightsDescr, WrapSerializer(package_weights)]
2928    """The weights for this model.
2929    Weights can be given for different formats, but should otherwise be equivalent.
2930    The available weight formats determine which consumers can use this model."""
2931
2932    config: Config = Field(default_factory=Config)
2933
2934    @model_validator(mode="after")
2935    def _add_default_cover(self) -> Self:
2936        if not get_validation_context().perform_io_checks or self.covers:
2937            return self
2938
2939        try:
2940            generated_covers = generate_covers(
2941                [(t, load_array(t.test_tensor)) for t in self.inputs],
2942                [(t, load_array(t.test_tensor)) for t in self.outputs],
2943            )
2944        except Exception as e:
2945            issue_warning(
2946                "Failed to generate cover image(s): {e}",
2947                value=self.covers,
2948                msg_context=dict(e=e),
2949                field="covers",
2950            )
2951        else:
2952            self.covers.extend(generated_covers)
2953
2954        return self
2955
2956    def get_input_test_arrays(self) -> List[NDArray[Any]]:
2957        data = [load_array(ipt.test_tensor) for ipt in self.inputs]
2958        assert all(isinstance(d, np.ndarray) for d in data)
2959        return data
2960
2961    def get_output_test_arrays(self) -> List[NDArray[Any]]:
2962        data = [load_array(out.test_tensor) for out in self.outputs]
2963        assert all(isinstance(d, np.ndarray) for d in data)
2964        return data
2965
2966    @staticmethod
2967    def get_batch_size(tensor_sizes: Mapping[TensorId, Mapping[AxisId, int]]) -> int:
2968        batch_size = 1
2969        tensor_with_batchsize: Optional[TensorId] = None
2970        for tid in tensor_sizes:
2971            for aid, s in tensor_sizes[tid].items():
2972                if aid != BATCH_AXIS_ID or s == 1 or s == batch_size:
2973                    continue
2974
2975                if batch_size != 1:
2976                    assert tensor_with_batchsize is not None
2977                    raise ValueError(
2978                        f"batch size mismatch for tensors '{tensor_with_batchsize}' ({batch_size}) and '{tid}' ({s})"
2979                    )
2980
2981                batch_size = s
2982                tensor_with_batchsize = tid
2983
2984        return batch_size
2985
2986    def get_output_tensor_sizes(
2987        self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]
2988    ) -> Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]:
2989        """Returns the tensor output sizes for given **input_sizes**.
2990        The tensor output size is exact only if **input_sizes** is a valid input shape.
2991        Otherwise it might be larger than the actual (valid) output."""
2992        batch_size = self.get_batch_size(input_sizes)
2993        ns = self.get_ns(input_sizes)
2994
2995        tensor_sizes = self.get_tensor_sizes(ns, batch_size=batch_size)
2996        return tensor_sizes.outputs
2997
2998    def get_ns(self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]):
2999        """Get parameter `n` for each parameterized axis
3000        such that the valid input size is >= the given input size."""
3001        ret: Dict[Tuple[TensorId, AxisId], ParameterizedSize_N] = {}
3002        axes = {t.id: {a.id: a for a in t.axes} for t in self.inputs}
3003        for tid in input_sizes:
3004            for aid, s in input_sizes[tid].items():
3005                size_descr = axes[tid][aid].size
3006                if isinstance(size_descr, ParameterizedSize):
3007                    ret[(tid, aid)] = size_descr.get_n(s)
3008                elif size_descr is None or isinstance(size_descr, (int, SizeReference)):
3009                    pass
3010                else:
3011                    assert_never(size_descr)
3012
3013        return ret
3014
3015    def get_tensor_sizes(
3016        self, ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], batch_size: int
3017    ) -> _TensorSizes:
3018        axis_sizes = self.get_axis_sizes(ns, batch_size=batch_size)
3019        return _TensorSizes(
3020            {
3021                t: {
3022                    aa: axis_sizes.inputs[(tt, aa)]
3023                    for tt, aa in axis_sizes.inputs
3024                    if tt == t
3025                }
3026                for t in {tt for tt, _ in axis_sizes.inputs}
3027            },
3028            {
3029                t: {
3030                    aa: axis_sizes.outputs[(tt, aa)]
3031                    for tt, aa in axis_sizes.outputs
3032                    if tt == t
3033                }
3034                for t in {tt for tt, _ in axis_sizes.outputs}
3035            },
3036        )
3037
3038    def get_axis_sizes(
3039        self,
3040        ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N],
3041        batch_size: Optional[int] = None,
3042        *,
3043        max_input_shape: Optional[Mapping[Tuple[TensorId, AxisId], int]] = None,
3044    ) -> _AxisSizes:
3045        """Determine input and output block shape for scale factors **ns**
3046        of parameterized input sizes.
3047
3048        Args:
3049            ns: Scale factor `n` for each axis (keyed by (tensor_id, axis_id))
3050                that is parameterized as `size = min + n * step`.
3051            batch_size: The desired size of the batch dimension.
3052                If given, **batch_size** overrides any batch size present in
3053                **max_input_shape**. Defaults to 1.
3054            max_input_shape: Limits the derived block shapes.
3055                Each axis for which the input size, parameterized by `n`, is larger
3056                than **max_input_shape** is set to the minimal value `n_min` for which
3057                this is still true.
3058                Use this for small input samples or large values of **ns**.
3059                Or simply whenever you know the full input shape.
3060
3061        Returns:
3062            Resolved axis sizes for model inputs and outputs.
3063        """
3064        max_input_shape = max_input_shape or {}
3065        if batch_size is None:
3066            for (_t_id, a_id), s in max_input_shape.items():
3067                if a_id == BATCH_AXIS_ID:
3068                    batch_size = s
3069                    break
3070            else:
3071                batch_size = 1
3072
3073        all_axes = {
3074            t.id: {a.id: a for a in t.axes} for t in chain(self.inputs, self.outputs)
3075        }
3076
3077        inputs: Dict[Tuple[TensorId, AxisId], int] = {}
3078        outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]] = {}
3079
3080        def get_axis_size(a: Union[InputAxis, OutputAxis]):
3081            if isinstance(a, BatchAxis):
3082                if (t_descr.id, a.id) in ns:
3083                    logger.warning(
3084                        "Ignoring unexpected size increment factor (n) for batch axis"
3085                        + " of tensor '{}'.",
3086                        t_descr.id,
3087                    )
3088                return batch_size
3089            elif isinstance(a.size, int):
3090                if (t_descr.id, a.id) in ns:
3091                    logger.warning(
3092                        "Ignoring unexpected size increment factor (n) for fixed size"
3093                        + " axis '{}' of tensor '{}'.",
3094                        a.id,
3095                        t_descr.id,
3096                    )
3097                return a.size
3098            elif isinstance(a.size, ParameterizedSize):
3099                if (t_descr.id, a.id) not in ns:
3100                    raise ValueError(
3101                        "Size increment factor (n) missing for parameterized axis"
3102                        + f" '{a.id}' of tensor '{t_descr.id}'."
3103                    )
3104                n = ns[(t_descr.id, a.id)]
3105                s_max = max_input_shape.get((t_descr.id, a.id))
3106                if s_max is not None:
3107                    n = min(n, a.size.get_n(s_max))
3108
3109                return a.size.get_size(n)
3110
3111            elif isinstance(a.size, SizeReference):
3112                if (t_descr.id, a.id) in ns:
3113                    logger.warning(
3114                        "Ignoring unexpected size increment factor (n) for axis '{}'"
3115                        + " of tensor '{}' with size reference.",
3116                        a.id,
3117                        t_descr.id,
3118                    )
3119                assert not isinstance(a, BatchAxis)
3120                ref_axis = all_axes[a.size.tensor_id][a.size.axis_id]
3121                assert not isinstance(ref_axis, BatchAxis)
3122                ref_key = (a.size.tensor_id, a.size.axis_id)
3123                ref_size = inputs.get(ref_key, outputs.get(ref_key))
3124                assert ref_size is not None, ref_key
3125                assert not isinstance(ref_size, _DataDepSize), ref_key
3126                return a.size.get_size(
3127                    axis=a,
3128                    ref_axis=ref_axis,
3129                    ref_size=ref_size,
3130                )
3131            elif isinstance(a.size, DataDependentSize):
3132                if (t_descr.id, a.id) in ns:
3133                    logger.warning(
3134                        "Ignoring unexpected increment factor (n) for data dependent"
3135                        + " size axis '{}' of tensor '{}'.",
3136                        a.id,
3137                        t_descr.id,
3138                    )
3139                return _DataDepSize(a.size.min, a.size.max)
3140            else:
3141                assert_never(a.size)
3142
3143        # first resolve all input sizes except the `SizeReference` ones
3144        for t_descr in self.inputs:
3145            for a in t_descr.axes:
3146                if not isinstance(a.size, SizeReference):
3147                    s = get_axis_size(a)
3148                    assert not isinstance(s, _DataDepSize)
3149                    inputs[t_descr.id, a.id] = s
3150
3151        # resolve all other input axis sizes
3152        for t_descr in self.inputs:
3153            for a in t_descr.axes:
3154                if isinstance(a.size, SizeReference):
3155                    s = get_axis_size(a)
3156                    assert not isinstance(s, _DataDepSize)
3157                    inputs[t_descr.id, a.id] = s
3158
3159        # resolve all output axis sizes
3160        for t_descr in self.outputs:
3161            for a in t_descr.axes:
3162                assert not isinstance(a.size, ParameterizedSize)
3163                s = get_axis_size(a)
3164                outputs[t_descr.id, a.id] = s
3165
3166        return _AxisSizes(inputs=inputs, outputs=outputs)
3167
3168    @model_validator(mode="before")
3169    @classmethod
3170    def _convert(cls, data: Dict[str, Any]) -> Dict[str, Any]:
3171        cls.convert_from_old_format_wo_validation(data)
3172        return data
3173
3174    @classmethod
3175    def convert_from_old_format_wo_validation(cls, data: Dict[str, Any]) -> None:
3176        """Convert metadata following an older format version to this class's format
3177        without validating the result.
3178        """
3179        if (
3180            data.get("type") == "model"
3181            and isinstance(fv := data.get("format_version"), str)
3182            and fv.count(".") == 2
3183        ):
3184            fv_parts = fv.split(".")
3185            if any(not p.isdigit() for p in fv_parts):
3186                return
3187
3188            fv_tuple = tuple(map(int, fv_parts))
3189
3190            assert cls.implemented_format_version_tuple[0:2] == (0, 5)
3191            if fv_tuple[:2] in ((0, 3), (0, 4)):
3192                m04 = _ModelDescr_v0_4.load(data)
3193                if isinstance(m04, InvalidDescr):
3194                    try:
3195                        updated = _model_conv.convert_as_dict(
3196                            m04  # pyright: ignore[reportArgumentType]
3197                        )
3198                    except Exception as e:
3199                        logger.error(
3200                            "Failed to convert from invalid model 0.4 description."
3201                            + f"\nerror: {e}"
3202                            + "\nProceeding with model 0.5 validation without conversion."
3203                        )
3204                        updated = None
3205                else:
3206                    updated = _model_conv.convert_as_dict(m04)
3207
3208                if updated is not None:
3209                    data.clear()
3210                    data.update(updated)
3211
3212            elif fv_tuple[:2] == (0, 5):
3213                # bump patch version
3214                data["format_version"] = cls.implemented_format_version

Specification of the fields used in a bioimage.io-compliant RDF to describe AI models with pretrained weights. These fields are typically stored in a YAML file which we call a model resource description file (model RDF).

implemented_format_version: ClassVar[Literal['0.5.4']] = '0.5.4'
implemented_type: ClassVar[Literal['model']] = 'model'

bioimage.io-wide unique resource identifier assigned by bioimage.io; version unspecific.

authors: Annotated[List[bioimageio.spec.generic.v0_3.Author], MinLen(min_length=1)]

The authors are the creators of the model RDF and the primary points of contact.

documentation: Annotated[Union[bioimageio.spec._internal.url.HttpUrl, bioimageio.spec._internal.io.RelativeFilePath, Annotated[pathlib.Path, PathType(path_type='file'), FieldInfo(annotation=NoneType, required=True, title='FilePath')]], FieldInfo(annotation=NoneType, required=True, metadata=[_PydanticGeneralMetadata(union_mode='left_to_right')]), AfterValidator(func=<function wo_special_file_name>), PlainSerializer(func=<function _package_serializer>, return_type=PydanticUndefined, when_used='unless-none'), WithSuffix(suffix='.md', case_sensitive=True), FieldInfo(annotation=NoneType, required=True, examples=['https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/unet2d_nuclei_broad/README.md', 'README.md'])]

URL or relative path to a markdown file with additional documentation. The recommended documentation file name is README.md. An .md suffix is mandatory. The documentation should include a '#[#] Validation' (sub)section with details on how to quantitatively validate the model on unseen data.
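
Note that `_validate_documentation` above only issues a warning if no such section is found. A quick sketch of content that satisfies the check (the README text is illustrative):

```python
import re

# illustrative documentation with the recommended '## Validation' subsection
doc_content = "# My Model\n\n## Validation\nWe report IoU on held-out images.\n"
assert re.search("#.*[vV]alidation", doc_content)  # the same pattern the validator applies
```
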

inputs: Annotated[Sequence[bioimageio.spec.model.v0_5.InputTensorDescr], MinLen(min_length=1)]

Describes the input tensors expected by this model.

name: Annotated[str, RestrictCharacters(alphabet='abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_+- ()'), MinLen(min_length=5), MaxLen(max_length=128), AfterWarner(func=<function as_warning.<locals>.wrapper>, severity=20, msg='Name longer than 64 characters.', context={'typ': Annotated[Any, MaxLen(max_length=64)]})]

A human-readable name of this model. It should be no longer than 64 characters and may only contain letters, numbers, underscores, minus signs, parentheses and spaces. We recommend choosing a name that refers to the model's task and image modality.
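
A sketch of the constraints encoded in the `name` annotation above (the example name is illustrative):

```python
import string

allowed = string.ascii_letters + string.digits + "_+- ()"  # RestrictCharacters alphabet
name = "2D UNet (nuclei segmentation)"
assert 5 <= len(name) <= 128  # MinLen(5)/MaxLen(128); names over 64 only trigger a warning
assert all(c in allowed for c in name)
```
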

outputs: Annotated[Sequence[bioimageio.spec.model.v0_5.OutputTensorDescr], MinLen(min_length=1)]

Describes the output tensors.

The persons that have packaged and uploaded this model. Only required if those persons differ from the authors.

The model from which this model is derived, e.g. by fine-tuning the weights.

run_mode: Annotated[Optional[bioimageio.spec.model.v0_4.RunMode], AfterWarner(func=<function as_warning.<locals>.wrapper>, severity=30, msg="Run mode '{value}' has limited support across consumer software.", context={'typ': None})]

Custom run mode for this model: for more complex prediction procedures like test time data augmentation that currently cannot be expressed in the specification. No standard run modes are defined yet.

Timestamp in ISO 8601 format with a few restrictions listed here. (In Python a datetime object is valid, too).

training_data: Annotated[Union[NoneType, bioimageio.spec.dataset.v0_3.LinkedDataset, bioimageio.spec.DatasetDescr, bioimageio.spec.dataset.v0_2.DatasetDescr], FieldInfo(annotation=NoneType, required=True, metadata=[_PydanticGeneralMetadata(union_mode='left_to_right')])]

The dataset used to train this model

weights: Annotated[bioimageio.spec.model.v0_5.WeightsDescr, WrapSerializer(func=<function package_weights>, return_type=PydanticUndefined, when_used='always')]

The weights for this model. Weights can be given for different formats, but should otherwise be equivalent. The available weight formats determine which consumers can use this model.

def get_input_test_arrays(self) -> List[numpy.ndarray[tuple[Any, ...], numpy.dtype[Any]]]:
def get_output_test_arrays(self) -> List[numpy.ndarray[tuple[Any, ...], numpy.dtype[Any]]]:
@staticmethod
def get_batch_size( tensor_sizes: Mapping[bioimageio.spec.model.v0_5.TensorId, Mapping[bioimageio.spec.model.v0_5.AxisId, int]]) -> int:
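For instance (a sketch assuming the batch axis id is `batch`, i.e. `BATCH_AXIS_ID`; tensor ids are illustrative):

```python
from bioimageio.spec.model.v0_5 import AxisId, ModelDescr, TensorId

tensor_sizes = {
    TensorId("raw"): {AxisId("batch"): 4, AxisId("x"): 256},
    TensorId("mask"): {AxisId("batch"): 4, AxisId("x"): 256},
}
assert ModelDescr.get_batch_size(tensor_sizes) == 4  # mismatching batch sizes raise a ValueError
```
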
def get_output_tensor_sizes( self, input_sizes: Mapping[bioimageio.spec.model.v0_5.TensorId, Mapping[bioimageio.spec.model.v0_5.AxisId, int]]) -> Dict[bioimageio.spec.model.v0_5.TensorId, Dict[bioimageio.spec.model.v0_5.AxisId, Union[int, bioimageio.spec.model.v0_5._DataDepSize]]]:

Returns the tensor output sizes for given input_sizes. The tensor output size is exact only if input_sizes is a valid input shape; otherwise it might be larger than the actual (valid) output.
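
A usage sketch (assuming `model` is a validated `ModelDescr` with a single input tensor `raw`; ids and sizes are illustrative):

```python
from bioimageio.spec.model.v0_5 import AxisId, TensorId

input_sizes = {
    TensorId("raw"): {AxisId("batch"): 1, AxisId("x"): 512, AxisId("y"): 512}
}
output_sizes = model.get_output_tensor_sizes(input_sizes)
# e.g. {TensorId("prob"): {AxisId("batch"): 1, AxisId("x"): 512, AxisId("y"): 512}}
```
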

def get_ns( self, input_sizes: Mapping[bioimageio.spec.model.v0_5.TensorId, Mapping[bioimageio.spec.model.v0_5.AxisId, int]]):

Get parameter n for each parameterized axis such that the valid input size is >= the given input size.
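
To illustrate the arithmetic for a `ParameterizedSize` (`size = min + n*step`): requesting size 100 on an axis with `min=64`, `step=16` yields `n=3`, because 64 + 3*16 = 112 is the smallest valid size >= 100 (a sketch; the values are illustrative):

```python
from bioimageio.spec.model.v0_5 import ParameterizedSize

size_descr = ParameterizedSize(min=64, step=16)
n = size_descr.get_n(100)  # smallest n such that min + n*step >= 100
assert n == 3
assert size_descr.get_size(n) == 112
```
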

def get_tensor_sizes( self, ns: Mapping[Tuple[bioimageio.spec.model.v0_5.TensorId, bioimageio.spec.model.v0_5.AxisId], int], batch_size: int) -> bioimageio.spec.model.v0_5._TensorSizes:
def get_axis_sizes( self, ns: Mapping[Tuple[bioimageio.spec.model.v0_5.TensorId, bioimageio.spec.model.v0_5.AxisId], int], batch_size: Optional[int] = None, *, max_input_shape: Optional[Mapping[Tuple[bioimageio.spec.model.v0_5.TensorId, bioimageio.spec.model.v0_5.AxisId], int]] = None) -> bioimageio.spec.model.v0_5._AxisSizes:

Determine input and output block shape for scale factors ns of parameterized input sizes.

Arguments:
  • ns: Scale factor n for each axis (keyed by (tensor_id, axis_id)) that is parameterized as size = min + n * step.
  • batch_size: The desired size of the batch dimension. If given, batch_size overrides any batch size present in max_input_shape. Defaults to 1.
  • max_input_shape: Limits the derived block shapes. Each axis for which the input size, parameterized by n, is larger than max_input_shape is set to the minimal value n_min for which this is still true. Use this for small input samples or large values of ns. Or simply whenever you know the full input shape.
Returns:

Resolved axis sizes for model inputs and outputs.
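
A usage sketch (assuming `model` is a validated `ModelDescr` whose input `raw` has parameterized `x` and `y` axes; ids and the chosen `n` are illustrative):

```python
from bioimageio.spec.model.v0_5 import AxisId, TensorId

ns = {(TensorId("raw"), AxisId("x")): 2, (TensorId("raw"), AxisId("y")): 2}
axis_sizes = model.get_axis_sizes(ns, batch_size=1)
print(axis_sizes.inputs)   # e.g. {(TensorId("raw"), AxisId("x")): 96, ...}
print(axis_sizes.outputs)  # output sizes resolved via size references etc.
```
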

@classmethod
def convert_from_old_format_wo_validation(cls, data: Dict[str, Any]) -> None:

Convert metadata following an older format version to this class's format without validating the result.
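
A sketch of the in-place conversion (the 0.4 RDF dict is abbreviated; a real one needs all required 0.4 fields):

```python
raw = {
    "type": "model",
    "format_version": "0.4.10",
    # ... remaining model 0.4 RDF fields (elided)
}
ModelDescr.convert_from_old_format_wo_validation(raw)  # mutates `raw` in place
# on success, raw["format_version"] equals ModelDescr.implemented_format_version ("0.5.4")
```
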

implemented_format_version_tuple: ClassVar[Tuple[int, int, int]] = (0, 5, 4)
model_config: ClassVar[pydantic.config.ConfigDict] = {'allow_inf_nan': False, 'extra': 'forbid', 'frozen': False, 'model_title_generator': <function _node_title_generator>, 'populate_by_name': True, 'revalidate_instances': 'never', 'use_attribute_docstrings': True, 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

ModelDescr_v0_4 = <class 'bioimageio.spec.model.v0_4.ModelDescr'>
ModelDescr_v0_5 = <class 'ModelDescr'>
AnyModelDescr = typing.Annotated[typing.Union[typing.Annotated[bioimageio.spec.model.v0_4.ModelDescr, FieldInfo(annotation=NoneType, required=True, title='model 0.4')], typing.Annotated[ModelDescr, FieldInfo(annotation=NoneType, required=True, title='model 0.5')]], Discriminator(discriminator='format_version', custom_error_type=None, custom_error_message=None, custom_error_context=None), FieldInfo(annotation=NoneType, required=True, title='model')]

Union of any released model description.
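
Because `AnyModelDescr` is a discriminated union keyed on `format_version`, it can be parsed directly with a pydantic `TypeAdapter` (a sketch; `rdf_dict` is assumed to be a model RDF dict with a supported `format_version`):

```python
from pydantic import TypeAdapter

from bioimageio.spec.model import AnyModelDescr

adapter = TypeAdapter(AnyModelDescr)
descr = adapter.validate_python(rdf_dict)  # a ModelDescr_v0_4 or ModelDescr_v0_5 instance
```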