bioimageio.spec.model

Implementations of all released minor versions are available in submodules:

- model v0_4: `bioimageio.spec.model.v0_4.ModelDescr`
- model v0_5: `bioimageio.spec.model.v0_5.ModelDescr`

 1# autogen: start
 2"""
 3Implementations of all released minor versions are available in submodules:
 4- model v0_4: `bioimageio.spec.model.v0_4.ModelDescr`
 5- model v0_5: `bioimageio.spec.model.v0_5.ModelDescr`
 6"""
 7
 8from typing import Union
 9
10from pydantic import Discriminator, Field
11from typing_extensions import Annotated
12
13from . import v0_4, v0_5
14
15ModelDescr = v0_5.ModelDescr
16ModelDescr_v0_4 = v0_4.ModelDescr
17ModelDescr_v0_5 = v0_5.ModelDescr
18
19AnyModelDescr = Annotated[
20    Union[
21        Annotated[ModelDescr_v0_4, Field(title="model 0.4")],
22        Annotated[ModelDescr_v0_5, Field(title="model 0.5")],
23    ],
24    Discriminator("format_version"),
25    Field(title="model"),
26]
27"""Union of any released model description"""
28# autogen: stop
2531class ModelDescr(GenericModelDescrBase):
2532    """Specification of the fields used in a bioimage.io-compliant RDF to describe AI models with pretrained weights.
2533    These fields are typically stored in a YAML file which we call a model resource description file (model RDF).
2534    """
2535
2536    implemented_format_version: ClassVar[Literal["0.5.4"]] = "0.5.4"
2537    if TYPE_CHECKING:
2538        format_version: Literal["0.5.4"] = "0.5.4"
2539    else:
2540        format_version: Literal["0.5.4"]
2541        """Version of the bioimage.io model description specification used.
2542        When creating a new model always use the latest micro/patch version described here.
2543        The `format_version` is important for any consumer software to understand how to parse the fields.
2544        """
2545
2546    implemented_type: ClassVar[Literal["model"]] = "model"
2547    if TYPE_CHECKING:
2548        type: Literal["model"] = "model"
2549    else:
2550        type: Literal["model"]
2551        """Specialized resource type 'model'"""
2552
2553    id: Optional[ModelId] = None
2554    """bioimage.io-wide unique resource identifier
2555    assigned by bioimage.io; version **un**specific."""
2556
2557    authors: NotEmpty[List[Author]]
2558    """The authors are the creators of the model RDF and the primary points of contact."""
2559
2560    documentation: FileSource_documentation
2561    """URL or relative path to a markdown file with additional documentation.
2562    The recommended documentation file name is `README.md`. An `.md` suffix is mandatory.
2563    The documentation should include a '#[#] Validation' (sub)section
2564    with details on how to quantitatively validate the model on unseen data."""
2565
2566    @field_validator("documentation", mode="after")
2567    @classmethod
2568    def _validate_documentation(
2569        cls, value: FileSource_documentation
2570    ) -> FileSource_documentation:
2571        if not get_validation_context().perform_io_checks:
2572            return value
2573
2574        doc_reader = get_reader(value)
2575        doc_content = doc_reader.read().decode(encoding="utf-8")
2576        if not re.search("#.*[vV]alidation", doc_content):
2577            issue_warning(
2578                "No '# Validation' (sub)section found in {value}.",
2579                value=value,
2580                field="documentation",
2581            )
2582
2583        return value
2584
2585    inputs: NotEmpty[Sequence[InputTensorDescr]]
2586    """Describes the input tensors expected by this model."""
2587
2588    @field_validator("inputs", mode="after")
2589    @classmethod
2590    def _validate_input_axes(
2591        cls, inputs: Sequence[InputTensorDescr]
2592    ) -> Sequence[InputTensorDescr]:
2593        input_size_refs = cls._get_axes_with_independent_size(inputs)
2594
2595        for i, ipt in enumerate(inputs):
2596            valid_independent_refs: Dict[
2597                Tuple[TensorId, AxisId],
2598                Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
2599            ] = {
2600                **{
2601                    (ipt.id, a.id): (ipt, a, a.size)
2602                    for a in ipt.axes
2603                    if not isinstance(a, BatchAxis)
2604                    and isinstance(a.size, (int, ParameterizedSize))
2605                },
2606                **input_size_refs,
2607            }
2608            for a, ax in enumerate(ipt.axes):
2609                cls._validate_axis(
2610                    "inputs",
2611                    i=i,
2612                    tensor_id=ipt.id,
2613                    a=a,
2614                    axis=ax,
2615                    valid_independent_refs=valid_independent_refs,
2616                )
2617        return inputs
2618
2619    @staticmethod
2620    def _validate_axis(
2621        field_name: str,
2622        i: int,
2623        tensor_id: TensorId,
2624        a: int,
2625        axis: AnyAxis,
2626        valid_independent_refs: Dict[
2627            Tuple[TensorId, AxisId],
2628            Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
2629        ],
2630    ):
2631        if isinstance(axis, BatchAxis) or isinstance(
2632            axis.size, (int, ParameterizedSize, DataDependentSize)
2633        ):
2634            return
2635        elif not isinstance(axis.size, SizeReference):
2636            assert_never(axis.size)
2637
2638        # validate axis.size SizeReference
2639        ref = (axis.size.tensor_id, axis.size.axis_id)
2640        if ref not in valid_independent_refs:
2641            raise ValueError(
2642                "Invalid tensor axis reference at"
2643                + f" {field_name}[{i}].axes[{a}].size: {axis.size}."
2644            )
2645        if ref == (tensor_id, axis.id):
2646            raise ValueError(
2647                "Self-referencing not allowed for"
2648                + f" {field_name}[{i}].axes[{a}].size: {axis.size}"
2649            )
2650        if axis.type == "channel":
2651            if valid_independent_refs[ref][1].type != "channel":
2652                raise ValueError(
2653                    "A channel axis' size may only reference another fixed size"
2654                    + " channel axis."
2655                )
2656            if isinstance(axis.channel_names, str) and "{i}" in axis.channel_names:
2657                ref_size = valid_independent_refs[ref][2]
2658                assert isinstance(ref_size, int), (
2659                    "channel axis ref (another channel axis) has to specify fixed"
2660                    + " size"
2661                )
2662                generated_channel_names = [
2663                    Identifier(axis.channel_names.format(i=i))
2664                    for i in range(1, ref_size + 1)
2665                ]
2666                axis.channel_names = generated_channel_names
2667
2668        if (ax_unit := getattr(axis, "unit", None)) != (
2669            ref_unit := getattr(valid_independent_refs[ref][1], "unit", None)
2670        ):
2671            raise ValueError(
2672                "The units of an axis and its reference axis need to match, but"
2673                + f" '{ax_unit}' != '{ref_unit}'."
2674            )
2675        ref_axis = valid_independent_refs[ref][1]
2676        if isinstance(ref_axis, BatchAxis):
2677            raise ValueError(
2678                f"Invalid reference axis '{ref_axis.id}' for {tensor_id}.{axis.id}"
2679                + " (a batch axis is not allowed as reference)."
2680            )
2681
2682        if isinstance(axis, WithHalo):
2683            min_size = axis.size.get_size(axis, ref_axis, n=0)
2684            if (min_size - 2 * axis.halo) < 1:
2685                raise ValueError(
2686                    f"axis {axis.id} with minimum size {min_size} is too small for halo"
2687                    + f" {axis.halo}."
2688                )
2689
2690            input_halo = axis.halo * axis.scale / ref_axis.scale
2691            if input_halo != int(input_halo) or input_halo % 2 == 1:
2692                raise ValueError(
2693                    f"input_halo {input_halo} (output_halo {axis.halo} *"
2694                    + f" output_scale {axis.scale} / input_scale {ref_axis.scale})"
2695                    + f" is not an even integer for {tensor_id}.{axis.id}."
2696                )
2697
2698    @model_validator(mode="after")
2699    def _validate_test_tensors(self) -> Self:
2700        if not get_validation_context().perform_io_checks:
2701            return self
2702
2703        test_output_arrays = [load_array(descr.test_tensor) for descr in self.outputs]
2704        test_input_arrays = [load_array(descr.test_tensor) for descr in self.inputs]
2705
2706        tensors = {
2707            descr.id: (descr, array)
2708            for descr, array in zip(
2709                chain(self.inputs, self.outputs), test_input_arrays + test_output_arrays
2710            )
2711        }
2712        validate_tensors(tensors, tensor_origin="test_tensor")
2713
2714        output_arrays = {
2715            descr.id: array for descr, array in zip(self.outputs, test_output_arrays)
2716        }
2717        for rep_tol in self.config.bioimageio.reproducibility_tolerance:
2718            if not rep_tol.absolute_tolerance:
2719                continue
2720
2721            if rep_tol.output_ids:
2722                out_arrays = {
2723                    oid: a
2724                    for oid, a in output_arrays.items()
2725                    if oid in rep_tol.output_ids
2726                }
2727            else:
2728                out_arrays = output_arrays
2729
2730            for out_id, array in out_arrays.items():
2731                if rep_tol.absolute_tolerance > (max_test_value := array.max()) * 0.01:
2732                    raise ValueError(
2733                        "config.bioimageio.reproducibility_tolerance.absolute_tolerance="
2734                        + f"{rep_tol.absolute_tolerance} > 0.01*{max_test_value}"
2735                        + f" (1% of the maximum value of the test tensor '{out_id}')"
2736                    )
2737
2738        return self
2739
2740    @model_validator(mode="after")
2741    def _validate_tensor_references_in_proc_kwargs(self, info: ValidationInfo) -> Self:
2742        ipt_refs = {t.id for t in self.inputs}
2743        out_refs = {t.id for t in self.outputs}
2744        for ipt in self.inputs:
2745            for p in ipt.preprocessing:
2746                ref = p.kwargs.get("reference_tensor")
2747                if ref is None:
2748                    continue
2749                if ref not in ipt_refs:
2750                    raise ValueError(
2751                        f"`reference_tensor` '{ref}' not found. Valid input tensor"
2752                        + f" references are: {ipt_refs}."
2753                    )
2754
2755        for out in self.outputs:
2756            for p in out.postprocessing:
2757                ref = p.kwargs.get("reference_tensor")
2758                if ref is None:
2759                    continue
2760
2761                if ref not in ipt_refs and ref not in out_refs:
2762                    raise ValueError(
2763                        f"`reference_tensor` '{ref}' not found. Valid tensor references"
2764                        + f" are: {ipt_refs | out_refs}."
2765                    )
2766
2767        return self
2768
2769    # TODO: use validate funcs in validate_test_tensors
2770    # def validate_inputs(self, input_tensors: Mapping[TensorId, NDArray[Any]]) -> Mapping[TensorId, NDArray[Any]]:
2771
2772    name: Annotated[
2773        Annotated[
2774            str, RestrictCharacters(string.ascii_letters + string.digits + "_+- ()")
2775        ],
2776        MinLen(5),
2777        MaxLen(128),
2778        warn(MaxLen(64), "Name longer than 64 characters.", INFO),
2779    ]
2780    """A human-readable name of this model.
2781    It should be no longer than 64 characters
2782    and may only contain letters, numbers, underscores, minus signs, parentheses and spaces.
2783    We recommend choosing a name that refers to the model's task and image modality.
2784    """
2785
2786    outputs: NotEmpty[Sequence[OutputTensorDescr]]
2787    """Describes the output tensors."""
2788
2789    @field_validator("outputs", mode="after")
2790    @classmethod
2791    def _validate_tensor_ids(
2792        cls, outputs: Sequence[OutputTensorDescr], info: ValidationInfo
2793    ) -> Sequence[OutputTensorDescr]:
2794        tensor_ids = [
2795            t.id for t in info.data.get("inputs", []) + info.data.get("outputs", [])
2796        ]
2797        duplicate_tensor_ids: List[str] = []
2798        seen: Set[str] = set()
2799        for t in tensor_ids:
2800            if t in seen:
2801                duplicate_tensor_ids.append(t)
2802
2803            seen.add(t)
2804
2805        if duplicate_tensor_ids:
2806            raise ValueError(f"Duplicate tensor ids: {duplicate_tensor_ids}")
2807
2808        return outputs
2809
2810    @staticmethod
2811    def _get_axes_with_parameterized_size(
2812        io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]],
2813    ):
2814        return {
2815            f"{t.id}.{a.id}": (t, a, a.size)
2816            for t in io
2817            for a in t.axes
2818            if not isinstance(a, BatchAxis) and isinstance(a.size, ParameterizedSize)
2819        }
2820
2821    @staticmethod
2822    def _get_axes_with_independent_size(
2823        io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]],
2824    ):
2825        return {
2826            (t.id, a.id): (t, a, a.size)
2827            for t in io
2828            for a in t.axes
2829            if not isinstance(a, BatchAxis)
2830            and isinstance(a.size, (int, ParameterizedSize))
2831        }
2832
2833    @field_validator("outputs", mode="after")
2834    @classmethod
2835    def _validate_output_axes(
2836        cls, outputs: List[OutputTensorDescr], info: ValidationInfo
2837    ) -> List[OutputTensorDescr]:
2838        input_size_refs = cls._get_axes_with_independent_size(
2839            info.data.get("inputs", [])
2840        )
2841        output_size_refs = cls._get_axes_with_independent_size(outputs)
2842
2843        for i, out in enumerate(outputs):
2844            valid_independent_refs: Dict[
2845                Tuple[TensorId, AxisId],
2846                Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
2847            ] = {
2848                **{
2849                    (out.id, a.id): (out, a, a.size)
2850                    for a in out.axes
2851                    if not isinstance(a, BatchAxis)
2852                    and isinstance(a.size, (int, ParameterizedSize))
2853                },
2854                **input_size_refs,
2855                **output_size_refs,
2856            }
2857            for a, ax in enumerate(out.axes):
2858                cls._validate_axis(
2859                    "outputs",
2860                    i,
2861                    out.id,
2862                    a,
2863                    ax,
2864                    valid_independent_refs=valid_independent_refs,
2865                )
2866
2867        return outputs
2868
2869    packaged_by: List[Author] = Field(
2870        default_factory=cast(Callable[[], List[Author]], list)
2871    )
2872    """The persons that have packaged and uploaded this model.
2873    Only required if those persons differ from the `authors`."""
2874
2875    parent: Optional[LinkedModel] = None
2876    """The model from which this model is derived, e.g. by fine-tuning the weights."""
2877
2878    @model_validator(mode="after")
2879    def _validate_parent_is_not_self(self) -> Self:
2880        if self.parent is not None and self.parent.id == self.id:
2881            raise ValueError("A model description may not reference itself as parent.")
2882
2883        return self
2884
2885    run_mode: Annotated[
2886        Optional[RunMode],
2887        warn(None, "Run mode '{value}' has limited support across consumer software."),
2888    ] = None
2889    """Custom run mode for this model: for more complex prediction procedures like test time
2890    data augmentation that currently cannot be expressed in the specification.
2891    No standard run modes are defined yet."""
2892
2893    timestamp: Datetime = Field(default_factory=Datetime.now)
2894    """Timestamp in [ISO 8601](https://en.wikipedia.org/wiki/ISO_8601) format
2895    with a few restrictions listed [here](https://docs.python.org/3/library/datetime.html#datetime.datetime.fromisoformat).
2896    (In Python a datetime object is valid, too)."""
2897
2898    training_data: Annotated[
2899        Union[None, LinkedDataset, DatasetDescr, DatasetDescr02],
2900        Field(union_mode="left_to_right"),
2901    ] = None
2902    """The dataset used to train this model"""
2903
2904    weights: Annotated[WeightsDescr, WrapSerializer(package_weights)]
2905    """The weights for this model.
2906    Weights can be given for different formats, but should otherwise be equivalent.
2907    The available weight formats determine which consumers can use this model."""
2908
2909    config: Config = Field(default_factory=Config)
2910
2911    @model_validator(mode="after")
2912    def _add_default_cover(self) -> Self:
2913        if not get_validation_context().perform_io_checks or self.covers:
2914            return self
2915
2916        try:
2917            generated_covers = generate_covers(
2918                [(t, load_array(t.test_tensor)) for t in self.inputs],
2919                [(t, load_array(t.test_tensor)) for t in self.outputs],
2920            )
2921        except Exception as e:
2922            issue_warning(
2923                "Failed to generate cover image(s): {e}",
2924                value=self.covers,
2925                msg_context=dict(e=e),
2926                field="covers",
2927            )
2928        else:
2929            self.covers.extend(generated_covers)
2930
2931        return self
2932
2933    def get_input_test_arrays(self) -> List[NDArray[Any]]:
2934        data = [load_array(ipt.test_tensor) for ipt in self.inputs]
2935        assert all(isinstance(d, np.ndarray) for d in data)
2936        return data
2937
2938    def get_output_test_arrays(self) -> List[NDArray[Any]]:
2939        data = [load_array(out.test_tensor) for out in self.outputs]
2940        assert all(isinstance(d, np.ndarray) for d in data)
2941        return data
2942
2943    @staticmethod
2944    def get_batch_size(tensor_sizes: Mapping[TensorId, Mapping[AxisId, int]]) -> int:
2945        batch_size = 1
2946        tensor_with_batchsize: Optional[TensorId] = None
2947        for tid in tensor_sizes:
2948            for aid, s in tensor_sizes[tid].items():
2949                if aid != BATCH_AXIS_ID or s == 1 or s == batch_size:
2950                    continue
2951
2952                if batch_size != 1:
2953                    assert tensor_with_batchsize is not None
2954                    raise ValueError(
2955                        f"batch size mismatch for tensors '{tensor_with_batchsize}' ({batch_size}) and '{tid}' ({s})"
2956                    )
2957
2958                batch_size = s
2959                tensor_with_batchsize = tid
2960
2961        return batch_size
2962
2963    def get_output_tensor_sizes(
2964        self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]
2965    ) -> Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]:
2966        """Returns the tensor output sizes for given **input_sizes**.
2967        The returned sizes are exact only if **input_sizes** specifies a valid input shape;
2968        otherwise they may be larger than the actual (valid) output sizes."""
2969        batch_size = self.get_batch_size(input_sizes)
2970        ns = self.get_ns(input_sizes)
2971
2972        tensor_sizes = self.get_tensor_sizes(ns, batch_size=batch_size)
2973        return tensor_sizes.outputs
2974
2975    def get_ns(self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]):
2976        """get parameter `n` for each parameterized axis
2977        such that the valid input size is >= the given input size"""
2978        ret: Dict[Tuple[TensorId, AxisId], ParameterizedSize_N] = {}
2979        axes = {t.id: {a.id: a for a in t.axes} for t in self.inputs}
2980        for tid in input_sizes:
2981            for aid, s in input_sizes[tid].items():
2982                size_descr = axes[tid][aid].size
2983                if isinstance(size_descr, ParameterizedSize):
2984                    ret[(tid, aid)] = size_descr.get_n(s)
2985                elif size_descr is None or isinstance(size_descr, (int, SizeReference)):
2986                    pass
2987                else:
2988                    assert_never(size_descr)
2989
2990        return ret
2991
2992    def get_tensor_sizes(
2993        self, ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], batch_size: int
2994    ) -> _TensorSizes:
2995        axis_sizes = self.get_axis_sizes(ns, batch_size=batch_size)
2996        return _TensorSizes(
2997            {
2998                t: {
2999                    aa: axis_sizes.inputs[(tt, aa)]
3000                    for tt, aa in axis_sizes.inputs
3001                    if tt == t
3002                }
3003                for t in {tt for tt, _ in axis_sizes.inputs}
3004            },
3005            {
3006                t: {
3007                    aa: axis_sizes.outputs[(tt, aa)]
3008                    for tt, aa in axis_sizes.outputs
3009                    if tt == t
3010                }
3011                for t in {tt for tt, _ in axis_sizes.outputs}
3012            },
3013        )
3014
3015    def get_axis_sizes(
3016        self,
3017        ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N],
3018        batch_size: Optional[int] = None,
3019        *,
3020        max_input_shape: Optional[Mapping[Tuple[TensorId, AxisId], int]] = None,
3021    ) -> _AxisSizes:
3022        """Determine input and output block shape for scale factors **ns**
3023        of parameterized input sizes.
3024
3025        Args:
3026            ns: Scale factor `n` for each axis (keyed by (tensor_id, axis_id))
3027                that is parameterized as `size = min + n * step`.
3028            batch_size: The desired size of the batch dimension.
3029                If given, **batch_size** overrides any batch size present in
3030                **max_input_shape**. Defaults to 1.
3031            max_input_shape: Limits the derived block shapes.
3032                Each axis for which the input size, parameterized by `n`, is larger
3033                than **max_input_shape** is set to the minimal value `n_min` for which
3034                this is still true.
3035                Use this for small input samples or large values of **ns**.
3036                Or simply whenever you know the full input shape.
3037
3038        Returns:
3039            Resolved axis sizes for model inputs and outputs.
3040        """
3041        max_input_shape = max_input_shape or {}
3042        if batch_size is None:
3043            for (_t_id, a_id), s in max_input_shape.items():
3044                if a_id == BATCH_AXIS_ID:
3045                    batch_size = s
3046                    break
3047            else:
3048                batch_size = 1
3049
3050        all_axes = {
3051            t.id: {a.id: a for a in t.axes} for t in chain(self.inputs, self.outputs)
3052        }
3053
3054        inputs: Dict[Tuple[TensorId, AxisId], int] = {}
3055        outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]] = {}
3056
3057        def get_axis_size(a: Union[InputAxis, OutputAxis]):
3058            if isinstance(a, BatchAxis):
3059                if (t_descr.id, a.id) in ns:
3060                    logger.warning(
3061                        "Ignoring unexpected size increment factor (n) for batch axis"
3062                        + " of tensor '{}'.",
3063                        t_descr.id,
3064                    )
3065                return batch_size
3066            elif isinstance(a.size, int):
3067                if (t_descr.id, a.id) in ns:
3068                    logger.warning(
3069                        "Ignoring unexpected size increment factor (n) for fixed size"
3070                        + " axis '{}' of tensor '{}'.",
3071                        a.id,
3072                        t_descr.id,
3073                    )
3074                return a.size
3075            elif isinstance(a.size, ParameterizedSize):
3076                if (t_descr.id, a.id) not in ns:
3077                    raise ValueError(
3078                        "Size increment factor (n) missing for parametrized axis"
3079                        + f" '{a.id}' of tensor '{t_descr.id}'."
3080                    )
3081                n = ns[(t_descr.id, a.id)]
3082                s_max = max_input_shape.get((t_descr.id, a.id))
3083                if s_max is not None:
3084                    n = min(n, a.size.get_n(s_max))
3085
3086                return a.size.get_size(n)
3087
3088            elif isinstance(a.size, SizeReference):
3089                if (t_descr.id, a.id) in ns:
3090                    logger.warning(
3091                        "Ignoring unexpected size increment factor (n) for axis '{}'"
3092                        + " of tensor '{}' with size reference.",
3093                        a.id,
3094                        t_descr.id,
3095                    )
3096                assert not isinstance(a, BatchAxis)
3097                ref_axis = all_axes[a.size.tensor_id][a.size.axis_id]
3098                assert not isinstance(ref_axis, BatchAxis)
3099                ref_key = (a.size.tensor_id, a.size.axis_id)
3100                ref_size = inputs.get(ref_key, outputs.get(ref_key))
3101                assert ref_size is not None, ref_key
3102                assert not isinstance(ref_size, _DataDepSize), ref_key
3103                return a.size.get_size(
3104                    axis=a,
3105                    ref_axis=ref_axis,
3106                    ref_size=ref_size,
3107                )
3108            elif isinstance(a.size, DataDependentSize):
3109                if (t_descr.id, a.id) in ns:
3110                    logger.warning(
3111                        "Ignoring unexpected increment factor (n) for data dependent"
3112                        + " size axis '{}' of tensor '{}'.",
3113                        a.id,
3114                        t_descr.id,
3115                    )
3116                return _DataDepSize(a.size.min, a.size.max)
3117            else:
3118                assert_never(a.size)
3119
3120        # first resolve all input sizes, except the `SizeReference` ones
3121        for t_descr in self.inputs:
3122            for a in t_descr.axes:
3123                if not isinstance(a.size, SizeReference):
3124                    s = get_axis_size(a)
3125                    assert not isinstance(s, _DataDepSize)
3126                    inputs[t_descr.id, a.id] = s
3127
3128        # resolve all other input axis sizes
3129        for t_descr in self.inputs:
3130            for a in t_descr.axes:
3131                if isinstance(a.size, SizeReference):
3132                    s = get_axis_size(a)
3133                    assert not isinstance(s, _DataDepSize)
3134                    inputs[t_descr.id, a.id] = s
3135
3136        # resolve all output axis sizes
3137        for t_descr in self.outputs:
3138            for a in t_descr.axes:
3139                assert not isinstance(a.size, ParameterizedSize)
3140                s = get_axis_size(a)
3141                outputs[t_descr.id, a.id] = s
3142
3143        return _AxisSizes(inputs=inputs, outputs=outputs)
3144
3145    @model_validator(mode="before")
3146    @classmethod
3147    def _convert(cls, data: Dict[str, Any]) -> Dict[str, Any]:
3148        cls.convert_from_old_format_wo_validation(data)
3149        return data
3150
3151    @classmethod
3152    def convert_from_old_format_wo_validation(cls, data: Dict[str, Any]) -> None:
3153        """Convert metadata following an older format version to this class's format
3154        without validating the result.
3155        """
3156        if (
3157            data.get("type") == "model"
3158            and isinstance(fv := data.get("format_version"), str)
3159            and fv.count(".") == 2
3160        ):
3161            fv_parts = fv.split(".")
3162            if any(not p.isdigit() for p in fv_parts):
3163                return
3164
3165            fv_tuple = tuple(map(int, fv_parts))
3166
3167            assert cls.implemented_format_version_tuple[0:2] == (0, 5)
3168            if fv_tuple[:2] in ((0, 3), (0, 4)):
3169                m04 = _ModelDescr_v0_4.load(data)
3170                if isinstance(m04, InvalidDescr):
3171                    try:
3172                        updated = _model_conv.convert_as_dict(
3173                            m04  # pyright: ignore[reportArgumentType]
3174                        )
3175                    except Exception as e:
3176                        logger.error(
3177                            "Failed to convert from invalid model 0.4 description."
3178                            + f"\nerror: {e}"
3179                            + "\nProceeding with model 0.5 validation without conversion."
3180                        )
3181                        updated = None
3182                else:
3183                    updated = _model_conv.convert_as_dict(m04)
3184
3185                if updated is not None:
3186                    data.clear()
3187                    data.update(updated)
3188
3189            elif fv_tuple[:2] == (0, 5):
3190                # bump patch version
3191                data["format_version"] = cls.implemented_format_version

Specification of the fields used in a bioimage.io-compliant RDF to describe AI models with pretrained weights. These fields are typically stored in a YAML file which we call a model resource description file (model RDF).

implemented_format_version: ClassVar[Literal['0.5.4']] = '0.5.4'
implemented_type: ClassVar[Literal['model']] = 'model'

bioimage.io-wide unique resource identifier assigned by bioimage.io; version unspecific.

authors: Annotated[List[bioimageio.spec.generic.v0_3.Author], MinLen(min_length=1)]

The authors are the creators of the model RDF and the primary points of contact.

documentation: Annotated[Union[bioimageio.spec._internal.url.HttpUrl, bioimageio.spec._internal.io.RelativeFilePath, Annotated[pathlib.Path, PathType(path_type='file'), FieldInfo(annotation=NoneType, required=True, title='FilePath')]], FieldInfo(annotation=NoneType, required=True, metadata=[_PydanticGeneralMetadata(union_mode='left_to_right')]), AfterValidator(func=<function wo_special_file_name>), PlainSerializer(func=<function _package_serializer>, return_type=PydanticUndefined, when_used='unless-none'), WithSuffix(suffix='.md', case_sensitive=True), FieldInfo(annotation=NoneType, required=True, examples=['https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/unet2d_nuclei_broad/README.md', 'README.md'])]

URL or relative path to a markdown file with additional documentation. The recommended documentation file name is README.md. An .md suffix is mandatory. The documentation should include a '#[#] Validation' (sub)section with details on how to quantitatively validate the model on unseen data.

inputs: Annotated[Sequence[bioimageio.spec.model.v0_5.InputTensorDescr], MinLen(min_length=1)]

Describes the input tensors expected by this model.

name: Annotated[str, RestrictCharacters(alphabet='abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_+- ()'), MinLen(min_length=5), MaxLen(max_length=128), AfterWarner(func=<function as_warning.<locals>.wrapper>, severity=20, msg='Name longer than 64 characters.', context={'typ': Annotated[Any, MaxLen(max_length=64)]})]

A human-readable name of this model. It should be no longer than 64 characters and may only contain letters, numbers, underscores, minus signs, parentheses and spaces. We recommend choosing a name that refers to the model's task and image modality.

outputs: Annotated[Sequence[bioimageio.spec.model.v0_5.OutputTensorDescr], MinLen(min_length=1)]

Describes the output tensors.

The persons that have packaged and uploaded this model. Only required if those persons differ from the authors.

The model from which this model is derived, e.g. by fine-tuning the weights.

run_mode: Annotated[Optional[bioimageio.spec.model.v0_4.RunMode], AfterWarner(func=<function as_warning.<locals>.wrapper>, severity=30, msg="Run mode '{value}' has limited support across consumer software.", context={'typ': None})]

Custom run mode for this model: for more complex prediction procedures like test time data augmentation that currently cannot be expressed in the specification. No standard run modes are defined yet.

timestamp: bioimageio.spec._internal.types.Datetime

Timestamp in [ISO 8601](https://en.wikipedia.org/wiki/ISO_8601) format with a few restrictions listed [here](https://docs.python.org/3/library/datetime.html#datetime.datetime.fromisoformat). (In Python a datetime object is valid, too).
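
A minimal sketch (not part of the generated docs): `Datetime.now()` is the default factory used for `timestamp` in the source above, and an ISO 8601 string is accepted as well.

from bioimageio.spec._internal.types import Datetime

ts = Datetime.now()  # current time; used as the default for `timestamp`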

training_data: Annotated[Union[NoneType, bioimageio.spec.dataset.v0_3.LinkedDataset, bioimageio.spec.DatasetDescr, bioimageio.spec.dataset.v0_2.DatasetDescr], FieldInfo(annotation=NoneType, required=True, metadata=[_PydanticGeneralMetadata(union_mode='left_to_right')])]

The dataset used to train this model

weights: Annotated[bioimageio.spec.model.v0_5.WeightsDescr, WrapSerializer(func=<function package_weights>, return_type=PydanticUndefined, when_used='always')]

The weights for this model. Weights can be given for different formats, but should otherwise be equivalent. The available weight formats determine which consumers can use this model.

def get_input_test_arrays(self) -> List[numpy.ndarray[Any, numpy.dtype[Any]]]:
def get_output_test_arrays(self) -> List[numpy.ndarray[Any, numpy.dtype[Any]]]:
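
A small usage sketch for both getters; `model_descr` is a hypothetical, already loaded `ModelDescr` whose test tensor files are resolvable.

import numpy as np

# assume `model_descr` is an already loaded ModelDescr instance (e.g. via bioimageio.spec.load_description)
input_arrays = model_descr.get_input_test_arrays()    # one numpy array per input tensor
output_arrays = model_descr.get_output_test_arrays()  # one numpy array per output tensor
assert all(isinstance(a, np.ndarray) for a in input_arrays + output_arrays)
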
@staticmethod
def get_batch_size( tensor_sizes: Mapping[bioimageio.spec.model.v0_5.TensorId, Mapping[bioimageio.spec.model.v0_5.AxisId, int]]) -> int:
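
A minimal sketch of the consistency check: all batch axes must agree (a size of 1 is treated as compatible with any batch size); the tensor and axis ids below are made up.

from bioimageio.spec.model.v0_5 import AxisId, ModelDescr, TensorId

sizes = {
    TensorId("raw"): {AxisId("batch"): 2, AxisId("x"): 512},
    TensorId("mask"): {AxisId("batch"): 2, AxisId("x"): 512},
}
assert ModelDescr.get_batch_size(sizes) == 2  # raises ValueError on mismatching batch sizes
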
def get_output_tensor_sizes( self, input_sizes: Mapping[bioimageio.spec.model.v0_5.TensorId, Mapping[bioimageio.spec.model.v0_5.AxisId, int]]) -> Dict[bioimageio.spec.model.v0_5.TensorId, Dict[bioimageio.spec.model.v0_5.AxisId, Union[int, bioimageio.spec.model.v0_5._DataDepSize]]]:

Returns the tensor output sizes for given input_sizes. The returned sizes are exact only if input_sizes specifies a valid input shape; otherwise they may be larger than the actual (valid) output sizes.
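
A usage sketch; tensor/axis ids and sizes are hypothetical and `model_descr` is an already loaded `ModelDescr`.

from bioimageio.spec.model.v0_5 import AxisId, TensorId

# assume `model_descr` is an already loaded ModelDescr instance
input_sizes = {TensorId("raw"): {AxisId("batch"): 1, AxisId("y"): 512, AxisId("x"): 512}}
output_sizes = model_descr.get_output_tensor_sizes(input_sizes)
# e.g. {TensorId("prediction"): {AxisId("batch"): 1, AxisId("y"): 512, AxisId("x"): 512}}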

def get_ns( self, input_sizes: Mapping[bioimageio.spec.model.v0_5.TensorId, Mapping[bioimageio.spec.model.v0_5.AxisId, int]]):

get parameter n for each parameterized axis such that the valid input size is >= the given input size
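
The returned `n` values follow the `ParameterizedSize` relationship `size = min + n*step` used in the source above; a minimal sketch of that relationship:

from bioimageio.spec.model.v0_5 import ParameterizedSize

size_descr = ParameterizedSize(min=64, step=16)
n = size_descr.get_n(100)            # smallest n with 64 + n*16 >= 100, i.e. 3
assert size_descr.get_size(n) == 112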

def get_tensor_sizes( self, ns: Mapping[Tuple[bioimageio.spec.model.v0_5.TensorId, bioimageio.spec.model.v0_5.AxisId], int], batch_size: int) -> bioimageio.spec.model.v0_5._TensorSizes:
def get_axis_sizes( self, ns: Mapping[Tuple[bioimageio.spec.model.v0_5.TensorId, bioimageio.spec.model.v0_5.AxisId], int], batch_size: Optional[int] = None, *, max_input_shape: Optional[Mapping[Tuple[bioimageio.spec.model.v0_5.TensorId, bioimageio.spec.model.v0_5.AxisId], int]] = None) -> bioimageio.spec.model.v0_5._AxisSizes:

Determine input and output block shape for scale factors ns of parameterized input sizes.

Arguments:
  • ns: Scale factor n for each axis (keyed by (tensor_id, axis_id)) that is parameterized as size = min + n * step.
  • batch_size: The desired size of the batch dimension. If given, batch_size overrides any batch size present in max_input_shape. Defaults to 1.
  • max_input_shape: Limits the derived block shapes. Each axis for which the input size, parameterized by n, is larger than max_input_shape is set to the minimal value n_min for which this is still true. Use this for small input samples or large values of ns. Or simply whenever you know the full input shape.
Returns:

Resolved axis sizes for model inputs and outputs.
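
A usage sketch assuming a single input tensor 'raw' with one parameterized axis 'x'; all ids are hypothetical and `model_descr` is an already loaded `ModelDescr`.

from bioimageio.spec.model.v0_5 import AxisId, TensorId

# assume `model_descr` is an already loaded ModelDescr instance
axis_sizes = model_descr.get_axis_sizes(
    ns={(TensorId("raw"), AxisId("x")): 3},  # x size = min + 3 * step
    batch_size=1,
)
axis_sizes.inputs   # e.g. {(TensorId("raw"), AxisId("batch")): 1, (TensorId("raw"), AxisId("x")): 112}
axis_sizes.outputs  # resolved output axis sizes (int or _DataDepSize)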

@classmethod
def convert_from_old_format_wo_validation(cls, data: Dict[str, Any]) -> None:

Convert metadata following an older format version to this class's format without validating the result.
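
A sketch of upgrading an older RDF dictionary in place before validation; the file path is a placeholder.

from pathlib import Path

import yaml

from bioimageio.spec.model.v0_5 import ModelDescr

data = yaml.safe_load(Path("old_model_rdf.yaml").read_text())  # hypothetical 0.4 model RDF
ModelDescr.convert_from_old_format_wo_validation(data)         # modifies `data` in place
# data["format_version"] is now "0.5.4" if the conversion succeeded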

implemented_format_version_tuple: ClassVar[Tuple[int, int, int]] = (0, 5, 4)
model_config: ClassVar[pydantic.config.ConfigDict] = {'extra': 'forbid', 'frozen': False, 'populate_by_name': True, 'revalidate_instances': 'never', 'validate_assignment': True, 'validate_default': False, 'validate_return': True, 'use_attribute_docstrings': True, 'model_title_generator': <function _node_title_generator>, 'validate_by_alias': True, 'validate_by_name': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

ModelDescr_v0_4 = <class 'bioimageio.spec.model.v0_4.ModelDescr'>
ModelDescr_v0_5 = <class 'ModelDescr'>
AnyModelDescr = typing.Annotated[typing.Union[typing.Annotated[bioimageio.spec.model.v0_4.ModelDescr, FieldInfo(annotation=NoneType, required=True, title='model 0.4')], typing.Annotated[ModelDescr, FieldInfo(annotation=NoneType, required=True, title='model 0.5')]], Discriminator(discriminator='format_version', custom_error_type=None, custom_error_message=None, custom_error_context=None), FieldInfo(annotation=NoneType, required=True, title='model')]

Union of any released model description
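
A minimal sketch (not part of the generated docs) of how `AnyModelDescr` can be used with pydantic: the discriminated union dispatches on `format_version` to pick the 0.4 or 0.5 class. `rdf_dict` is a hypothetical, already loaded model RDF dictionary.

from pydantic import TypeAdapter

from bioimageio.spec.model import AnyModelDescr

adapter = TypeAdapter(AnyModelDescr)
# model = adapter.validate_python(rdf_dict)
# -> a ModelDescr_v0_4 or ModelDescr_v0_5 instance, chosen by rdf_dict["format_version"]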