bioimageio.spec.model

implementations of all released minor versions are available in submodules:

- model v0_4: bioimageio.spec.model.v0_4.ModelDescr
- model v0_5: bioimageio.spec.model.v0_5.ModelDescr

 1# autogen: start
 2"""
 3implementations of all released minor versions are available in submodules:
 4- model v0_4: `bioimageio.spec.model.v0_4.ModelDescr`
 5- model v0_5: `bioimageio.spec.model.v0_5.ModelDescr`
 6"""
 7
 8from typing import Union
 9
10from pydantic import Discriminator, Field
11from typing_extensions import Annotated
12
13from . import v0_4, v0_5
14
15ModelDescr = v0_5.ModelDescr
16ModelDescr_v0_4 = v0_4.ModelDescr
17ModelDescr_v0_5 = v0_5.ModelDescr
18
19AnyModelDescr = Annotated[
20    Union[
21        Annotated[ModelDescr_v0_4, Field(title="model 0.4")],
22        Annotated[ModelDescr_v0_5, Field(title="model 0.5")],
23    ],
24    Discriminator("format_version"),
25    Field(title="model"),
26]
27"""Union of any released model desription"""
28# autogen: stop
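A minimal sketch of validating a raw model RDF dict directly against this union with pydantic's `TypeAdapter` (the file name `rdf.yaml` is illustrative, and the sketch assumes the RDF declares a `format_version` implemented here; `bioimageio.spec.load_description` is the usual higher-level entry point):

    import yaml
    from pydantic import TypeAdapter

    from bioimageio.spec.model import AnyModelDescr

    with open("rdf.yaml", encoding="utf-8") as f:  # illustrative path
        raw_rdf = yaml.safe_load(f)

    # Discriminator("format_version") dispatches to ModelDescr_v0_4 or
    # ModelDescr_v0_5 based on the declared format version.
    model = TypeAdapter(AnyModelDescr).validate_python(raw_rdf)
    print(type(model).__name__, model.format_version)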
2512class ModelDescr(GenericModelDescrBase):
2513    """Specification of the fields used in a bioimage.io-compliant RDF to describe AI models with pretrained weights.
2514    These fields are typically stored in a YAML file which we call a model resource description file (model RDF).
2515    """
2516
2517    implemented_format_version: ClassVar[Literal["0.5.4"]] = "0.5.4"
2518    if TYPE_CHECKING:
2519        format_version: Literal["0.5.4"] = "0.5.4"
2520    else:
2521        format_version: Literal["0.5.4"]
2522        """Version of the bioimage.io model description specification used.
2523        When creating a new model always use the latest micro/patch version described here.
2524        The `format_version` is important for any consumer software to understand how to parse the fields.
2525        """
2526
2527    implemented_type: ClassVar[Literal["model"]] = "model"
2528    if TYPE_CHECKING:
2529        type: Literal["model"] = "model"
2530    else:
2531        type: Literal["model"]
2532        """Specialized resource type 'model'"""
2533
2534    id: Optional[ModelId] = None
2535    """bioimage.io-wide unique resource identifier
2536    assigned by bioimage.io; version **un**specific."""
2537
2538    authors: NotEmpty[List[Author]]
2539    """The authors are the creators of the model RDF and the primary points of contact."""
2540
2541    documentation: Annotated[
2542        DocumentationSource,
2543        Field(
2544            examples=[
2545                "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/unet2d_nuclei_broad/README.md",
2546                "README.md",
2547            ],
2548        ),
2549    ]
2550    """∈📦 URL or relative path to a markdown file with additional documentation.
2551    The recommended documentation file name is `README.md`. An `.md` suffix is mandatory.
2552    The documentation should include a '#[#] Validation' (sub)section
2553    with details on how to quantitatively validate the model on unseen data."""
2554
2555    @field_validator("documentation", mode="after")
2556    @classmethod
2557    def _validate_documentation(cls, value: DocumentationSource) -> DocumentationSource:
2558        if not get_validation_context().perform_io_checks:
2559            return value
2560
2561        doc_path = download(value).path
2562        doc_content = doc_path.read_text(encoding="utf-8")
2563        assert isinstance(doc_content, str)
2564        if not re.search("#.*[vV]alidation", doc_content):
2565            issue_warning(
2566                "No '# Validation' (sub)section found in {value}.",
2567                value=value,
2568                field="documentation",
2569            )
2570
2571        return value
2572
2573    inputs: NotEmpty[Sequence[InputTensorDescr]]
2574    """Describes the input tensors expected by this model."""
2575
2576    @field_validator("inputs", mode="after")
2577    @classmethod
2578    def _validate_input_axes(
2579        cls, inputs: Sequence[InputTensorDescr]
2580    ) -> Sequence[InputTensorDescr]:
2581        input_size_refs = cls._get_axes_with_independent_size(inputs)
2582
2583        for i, ipt in enumerate(inputs):
2584            valid_independent_refs: Dict[
2585                Tuple[TensorId, AxisId],
2586                Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
2587            ] = {
2588                **{
2589                    (ipt.id, a.id): (ipt, a, a.size)
2590                    for a in ipt.axes
2591                    if not isinstance(a, BatchAxis)
2592                    and isinstance(a.size, (int, ParameterizedSize))
2593                },
2594                **input_size_refs,
2595            }
2596            for a, ax in enumerate(ipt.axes):
2597                cls._validate_axis(
2598                    "inputs",
2599                    i=i,
2600                    tensor_id=ipt.id,
2601                    a=a,
2602                    axis=ax,
2603                    valid_independent_refs=valid_independent_refs,
2604                )
2605        return inputs
2606
2607    @staticmethod
2608    def _validate_axis(
2609        field_name: str,
2610        i: int,
2611        tensor_id: TensorId,
2612        a: int,
2613        axis: AnyAxis,
2614        valid_independent_refs: Dict[
2615            Tuple[TensorId, AxisId],
2616            Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
2617        ],
2618    ):
2619        if isinstance(axis, BatchAxis) or isinstance(
2620            axis.size, (int, ParameterizedSize, DataDependentSize)
2621        ):
2622            return
2623        elif not isinstance(axis.size, SizeReference):
2624            assert_never(axis.size)
2625
2626        # validate axis.size SizeReference
2627        ref = (axis.size.tensor_id, axis.size.axis_id)
2628        if ref not in valid_independent_refs:
2629            raise ValueError(
2630                "Invalid tensor axis reference at"
2631                + f" {field_name}[{i}].axes[{a}].size: {axis.size}."
2632            )
2633        if ref == (tensor_id, axis.id):
2634            raise ValueError(
2635                "Self-referencing not allowed for"
2636                + f" {field_name}[{i}].axes[{a}].size: {axis.size}"
2637            )
2638        if axis.type == "channel":
2639            if valid_independent_refs[ref][1].type != "channel":
2640                raise ValueError(
2641                    "A channel axis' size may only reference another fixed size"
2642                    + " channel axis."
2643                )
2644            if isinstance(axis.channel_names, str) and "{i}" in axis.channel_names:
2645                ref_size = valid_independent_refs[ref][2]
2646                assert isinstance(ref_size, int), (
2647                    "channel axis ref (another channel axis) has to specify fixed"
2648                    + " size"
2649                )
2650                generated_channel_names = [
2651                    Identifier(axis.channel_names.format(i=i))
2652                    for i in range(1, ref_size + 1)
2653                ]
2654                axis.channel_names = generated_channel_names
2655
2656        if (ax_unit := getattr(axis, "unit", None)) != (
2657            ref_unit := getattr(valid_independent_refs[ref][1], "unit", None)
2658        ):
2659            raise ValueError(
2660                "The units of an axis and its reference axis need to match, but"
2661                + f" '{ax_unit}' != '{ref_unit}'."
2662            )
2663        ref_axis = valid_independent_refs[ref][1]
2664        if isinstance(ref_axis, BatchAxis):
2665            raise ValueError(
2666                f"Invalid reference axis '{ref_axis.id}' for {tensor_id}.{axis.id}"
2667                + " (a batch axis is not allowed as reference)."
2668            )
2669
2670        if isinstance(axis, WithHalo):
2671            min_size = axis.size.get_size(axis, ref_axis, n=0)
2672            if (min_size - 2 * axis.halo) < 1:
2673                raise ValueError(
2674                    f"axis {axis.id} with minimum size {min_size} is too small for halo"
2675                    + f" {axis.halo}."
2676                )
2677
2678            input_halo = axis.halo * axis.scale / ref_axis.scale
2679            if input_halo != int(input_halo) or input_halo % 2 == 1:
2680                raise ValueError(
2681                    f"input_halo {input_halo} (output_halo {axis.halo} *"
2682                    + f" output_scale {axis.scale} / input_scale {ref_axis.scale})"
2683                    + f" has to be an even integer for {tensor_id}.{axis.id}."
2684                )
2685
2686    @model_validator(mode="after")
2687    def _validate_test_tensors(self) -> Self:
2688        if not get_validation_context().perform_io_checks:
2689            return self
2690
2691        test_output_arrays = [
2692            load_array(descr.test_tensor.download().path) for descr in self.outputs
2693        ]
2694        test_input_arrays = [
2695            load_array(descr.test_tensor.download().path) for descr in self.inputs
2696        ]
2697
2698        tensors = {
2699            descr.id: (descr, array)
2700            for descr, array in zip(
2701                chain(self.inputs, self.outputs), test_input_arrays + test_output_arrays
2702            )
2703        }
2704        validate_tensors(tensors, tensor_origin="test_tensor")
2705
2706        output_arrays = {
2707            descr.id: array for descr, array in zip(self.outputs, test_output_arrays)
2708        }
2709        for rep_tol in self.config.bioimageio.reproducibility_tolerance:
2710            if not rep_tol.absolute_tolerance:
2711                continue
2712
2713            if rep_tol.output_ids:
2714                out_arrays = {
2715                    oid: a
2716                    for oid, a in output_arrays.items()
2717                    if oid in rep_tol.output_ids
2718                }
2719            else:
2720                out_arrays = output_arrays
2721
2722            for out_id, array in out_arrays.items():
2723                if rep_tol.absolute_tolerance > (max_test_value := array.max()) * 0.01:
2724                    raise ValueError(
2725                        "config.bioimageio.reproducibility_tolerance.absolute_tolerance="
2726                        + f"{rep_tol.absolute_tolerance} > 0.01*{max_test_value}"
2727                        + f" (1% of the maximum value of the test tensor '{out_id}')"
2728                    )
2729
2730        return self
2731
2732    @model_validator(mode="after")
2733    def _validate_tensor_references_in_proc_kwargs(self, info: ValidationInfo) -> Self:
2734        ipt_refs = {t.id for t in self.inputs}
2735        out_refs = {t.id for t in self.outputs}
2736        for ipt in self.inputs:
2737            for p in ipt.preprocessing:
2738                ref = p.kwargs.get("reference_tensor")
2739                if ref is None:
2740                    continue
2741                if ref not in ipt_refs:
2742                    raise ValueError(
2743                        f"`reference_tensor` '{ref}' not found. Valid input tensor"
2744                        + f" references are: {ipt_refs}."
2745                    )
2746
2747        for out in self.outputs:
2748            for p in out.postprocessing:
2749                ref = p.kwargs.get("reference_tensor")
2750                if ref is None:
2751                    continue
2752
2753                if ref not in ipt_refs and ref not in out_refs:
2754                    raise ValueError(
2755                        f"`reference_tensor` '{ref}' not found. Valid tensor references"
2756                        + f" are: {ipt_refs | out_refs}."
2757                    )
2758
2759        return self
2760
2761    # TODO: use validate funcs in validate_test_tensors
2762    # def validate_inputs(self, input_tensors: Mapping[TensorId, NDArray[Any]]) -> Mapping[TensorId, NDArray[Any]]:
2763
2764    name: Annotated[
2765        Annotated[
2766            str, RestrictCharacters(string.ascii_letters + string.digits + "_+- ()")
2767        ],
2768        MinLen(5),
2769        MaxLen(128),
2770        warn(MaxLen(64), "Name longer than 64 characters.", INFO),
2771    ]
2772    """A human-readable name of this model.
2773    It should be no longer than 64 characters
2774    and may only contain letter, number, underscore, minus, parentheses and spaces.
2775    We recommend to chose a name that refers to the model's task and image modality.
2776    """
2777
2778    outputs: NotEmpty[Sequence[OutputTensorDescr]]
2779    """Describes the output tensors."""
2780
2781    @field_validator("outputs", mode="after")
2782    @classmethod
2783    def _validate_tensor_ids(
2784        cls, outputs: Sequence[OutputTensorDescr], info: ValidationInfo
2785    ) -> Sequence[OutputTensorDescr]:
2786        tensor_ids = [
2787            t.id for t in info.data.get("inputs", []) + info.data.get("outputs", [])
2788        ]
2789        duplicate_tensor_ids: List[str] = []
2790        seen: Set[str] = set()
2791        for t in tensor_ids:
2792            if t in seen:
2793                duplicate_tensor_ids.append(t)
2794
2795            seen.add(t)
2796
2797        if duplicate_tensor_ids:
2798            raise ValueError(f"Duplicate tensor ids: {duplicate_tensor_ids}")
2799
2800        return outputs
2801
2802    @staticmethod
2803    def _get_axes_with_parameterized_size(
2804        io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]],
2805    ):
2806        return {
2807            f"{t.id}.{a.id}": (t, a, a.size)
2808            for t in io
2809            for a in t.axes
2810            if not isinstance(a, BatchAxis) and isinstance(a.size, ParameterizedSize)
2811        }
2812
2813    @staticmethod
2814    def _get_axes_with_independent_size(
2815        io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]],
2816    ):
2817        return {
2818            (t.id, a.id): (t, a, a.size)
2819            for t in io
2820            for a in t.axes
2821            if not isinstance(a, BatchAxis)
2822            and isinstance(a.size, (int, ParameterizedSize))
2823        }
2824
2825    @field_validator("outputs", mode="after")
2826    @classmethod
2827    def _validate_output_axes(
2828        cls, outputs: List[OutputTensorDescr], info: ValidationInfo
2829    ) -> List[OutputTensorDescr]:
2830        input_size_refs = cls._get_axes_with_independent_size(
2831            info.data.get("inputs", [])
2832        )
2833        output_size_refs = cls._get_axes_with_independent_size(outputs)
2834
2835        for i, out in enumerate(outputs):
2836            valid_independent_refs: Dict[
2837                Tuple[TensorId, AxisId],
2838                Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
2839            ] = {
2840                **{
2841                    (out.id, a.id): (out, a, a.size)
2842                    for a in out.axes
2843                    if not isinstance(a, BatchAxis)
2844                    and isinstance(a.size, (int, ParameterizedSize))
2845                },
2846                **input_size_refs,
2847                **output_size_refs,
2848            }
2849            for a, ax in enumerate(out.axes):
2850                cls._validate_axis(
2851                    "outputs",
2852                    i,
2853                    out.id,
2854                    a,
2855                    ax,
2856                    valid_independent_refs=valid_independent_refs,
2857                )
2858
2859        return outputs
2860
2861    packaged_by: List[Author] = Field(default_factory=list)
2862    """The persons that have packaged and uploaded this model.
2863    Only required if those persons differ from the `authors`."""
2864
2865    parent: Optional[LinkedModel] = None
2866    """The model from which this model is derived, e.g. by fine-tuning the weights."""
2867
2868    @model_validator(mode="after")
2869    def _validate_parent_is_not_self(self) -> Self:
2870        if self.parent is not None and self.parent.id == self.id:
2871            raise ValueError("A model description may not reference itself as parent.")
2872
2873        return self
2874
2875    run_mode: Annotated[
2876        Optional[RunMode],
2877        warn(None, "Run mode '{value}' has limited support across consumer software."),
2878    ] = None
2879    """Custom run mode for this model: for more complex prediction procedures like test time
2880    data augmentation that currently cannot be expressed in the specification.
2881    No standard run modes are defined yet."""
2882
2883    timestamp: Datetime = Field(default_factory=Datetime.now)
2884    """Timestamp in [ISO 8601](https://en.wikipedia.org/wiki/ISO_8601) format
2885    with a few restrictions listed [here](https://docs.python.org/3/library/datetime.html#datetime.datetime.fromisoformat).
2886    (In Python a datetime object is valid, too)."""
2887
2888    training_data: Annotated[
2889        Union[None, LinkedDataset, DatasetDescr, DatasetDescr02],
2890        Field(union_mode="left_to_right"),
2891    ] = None
2892    """The dataset used to train this model."""
2893
2894    weights: Annotated[WeightsDescr, WrapSerializer(package_weights)]
2895    """The weights for this model.
2896    Weights can be given for different formats, but should otherwise be equivalent.
2897    The available weight formats determine which consumers can use this model."""
2898
2899    config: Config = Field(default_factory=Config)
2900
2901    @model_validator(mode="after")
2902    def _add_default_cover(self) -> Self:
2903        if not get_validation_context().perform_io_checks or self.covers:
2904            return self
2905
2906        try:
2907            generated_covers = generate_covers(
2908                [(t, load_array(t.test_tensor.download().path)) for t in self.inputs],
2909                [(t, load_array(t.test_tensor.download().path)) for t in self.outputs],
2910            )
2911        except Exception as e:
2912            issue_warning(
2913                "Failed to generate cover image(s): {e}",
2914                value=self.covers,
2915                msg_context=dict(e=e),
2916                field="covers",
2917            )
2918        else:
2919            self.covers.extend(generated_covers)
2920
2921        return self
2922
2923    def get_input_test_arrays(self) -> List[NDArray[Any]]:
2924        data = [load_array(ipt.test_tensor.download().path) for ipt in self.inputs]
2925        assert all(isinstance(d, np.ndarray) for d in data)
2926        return data
2927
2928    def get_output_test_arrays(self) -> List[NDArray[Any]]:
2929        data = [load_array(out.test_tensor.download().path) for out in self.outputs]
2930        assert all(isinstance(d, np.ndarray) for d in data)
2931        return data
2932
2933    @staticmethod
2934    def get_batch_size(tensor_sizes: Mapping[TensorId, Mapping[AxisId, int]]) -> int:
2935        batch_size = 1
2936        tensor_with_batchsize: Optional[TensorId] = None
2937        for tid in tensor_sizes:
2938            for aid, s in tensor_sizes[tid].items():
2939                if aid != BATCH_AXIS_ID or s == 1 or s == batch_size:
2940                    continue
2941
2942                if batch_size != 1:
2943                    assert tensor_with_batchsize is not None
2944                    raise ValueError(
2945                        f"batch size mismatch for tensors '{tensor_with_batchsize}' ({batch_size}) and '{tid}' ({s})"
2946                    )
2947
2948                batch_size = s
2949                tensor_with_batchsize = tid
2950
2951        return batch_size
2952
2953    def get_output_tensor_sizes(
2954        self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]
2955    ) -> Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]:
2956        """Returns the tensor output sizes for given **input_sizes**.
2957        The returned sizes are exact only if **input_sizes** is a valid input shape;
2958        otherwise they may be larger than the actual (valid) output sizes."""
2959        batch_size = self.get_batch_size(input_sizes)
2960        ns = self.get_ns(input_sizes)
2961
2962        tensor_sizes = self.get_tensor_sizes(ns, batch_size=batch_size)
2963        return tensor_sizes.outputs
2964
2965    def get_ns(self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]):
2966        """Get parameter `n` for each parameterized axis
2967        such that the valid input size is >= the given input size."""
2968        ret: Dict[Tuple[TensorId, AxisId], ParameterizedSize_N] = {}
2969        axes = {t.id: {a.id: a for a in t.axes} for t in self.inputs}
2970        for tid in input_sizes:
2971            for aid, s in input_sizes[tid].items():
2972                size_descr = axes[tid][aid].size
2973                if isinstance(size_descr, ParameterizedSize):
2974                    ret[(tid, aid)] = size_descr.get_n(s)
2975                elif size_descr is None or isinstance(size_descr, (int, SizeReference)):
2976                    pass
2977                else:
2978                    assert_never(size_descr)
2979
2980        return ret
2981
2982    def get_tensor_sizes(
2983        self, ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], batch_size: int
2984    ) -> _TensorSizes:
2985        axis_sizes = self.get_axis_sizes(ns, batch_size=batch_size)
2986        return _TensorSizes(
2987            {
2988                t: {
2989                    aa: axis_sizes.inputs[(tt, aa)]
2990                    for tt, aa in axis_sizes.inputs
2991                    if tt == t
2992                }
2993                for t in {tt for tt, _ in axis_sizes.inputs}
2994            },
2995            {
2996                t: {
2997                    aa: axis_sizes.outputs[(tt, aa)]
2998                    for tt, aa in axis_sizes.outputs
2999                    if tt == t
3000                }
3001                for t in {tt for tt, _ in axis_sizes.outputs}
3002            },
3003        )
3004
3005    def get_axis_sizes(
3006        self,
3007        ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N],
3008        batch_size: Optional[int] = None,
3009        *,
3010        max_input_shape: Optional[Mapping[Tuple[TensorId, AxisId], int]] = None,
3011    ) -> _AxisSizes:
3012        """Determine input and output block shape for scale factors **ns**
3013        of parameterized input sizes.
3014
3015        Args:
3016            ns: Scale factor `n` for each axis (keyed by (tensor_id, axis_id))
3017                that is parameterized as `size = min + n * step`.
3018            batch_size: The desired size of the batch dimension.
3019                If given, **batch_size** overwrites any batch size present in
3020                **max_input_shape**. Default 1.
3021            max_input_shape: Limits the derived block shapes.
3022                Each axis for which the input size, parameterized by `n`, is larger
3023                than **max_input_shape** is set to the minimal value `n_min` for which
3024                this is still true.
3025                Use this for small input samples or large values of **ns**.
3026                Or simply whenever you know the full input shape.
3027
3028        Returns:
3029            Resolved axis sizes for model inputs and outputs.
3030        """
3031        max_input_shape = max_input_shape or {}
3032        if batch_size is None:
3033            for (_t_id, a_id), s in max_input_shape.items():
3034                if a_id == BATCH_AXIS_ID:
3035                    batch_size = s
3036                    break
3037            else:
3038                batch_size = 1
3039
3040        all_axes = {
3041            t.id: {a.id: a for a in t.axes} for t in chain(self.inputs, self.outputs)
3042        }
3043
3044        inputs: Dict[Tuple[TensorId, AxisId], int] = {}
3045        outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]] = {}
3046
3047        def get_axis_size(a: Union[InputAxis, OutputAxis]):
3048            if isinstance(a, BatchAxis):
3049                if (t_descr.id, a.id) in ns:
3050                    logger.warning(
3051                        "Ignoring unexpected size increment factor (n) for batch axis"
3052                        + " of tensor '{}'.",
3053                        t_descr.id,
3054                    )
3055                return batch_size
3056            elif isinstance(a.size, int):
3057                if (t_descr.id, a.id) in ns:
3058                    logger.warning(
3059                        "Ignoring unexpected size increment factor (n) for fixed size"
3060                        + " axis '{}' of tensor '{}'.",
3061                        a.id,
3062                        t_descr.id,
3063                    )
3064                return a.size
3065            elif isinstance(a.size, ParameterizedSize):
3066                if (t_descr.id, a.id) not in ns:
3067                    raise ValueError(
3068                        "Size increment factor (n) missing for parametrized axis"
3069                        + f" '{a.id}' of tensor '{t_descr.id}'."
3070                    )
3071                n = ns[(t_descr.id, a.id)]
3072                s_max = max_input_shape.get((t_descr.id, a.id))
3073                if s_max is not None:
3074                    n = min(n, a.size.get_n(s_max))
3075
3076                return a.size.get_size(n)
3077
3078            elif isinstance(a.size, SizeReference):
3079                if (t_descr.id, a.id) in ns:
3080                    logger.warning(
3081                        "Ignoring unexpected size increment factor (n) for axis '{}'"
3082                        + " of tensor '{}' with size reference.",
3083                        a.id,
3084                        t_descr.id,
3085                    )
3086                assert not isinstance(a, BatchAxis)
3087                ref_axis = all_axes[a.size.tensor_id][a.size.axis_id]
3088                assert not isinstance(ref_axis, BatchAxis)
3089                ref_key = (a.size.tensor_id, a.size.axis_id)
3090                ref_size = inputs.get(ref_key, outputs.get(ref_key))
3091                assert ref_size is not None, ref_key
3092                assert not isinstance(ref_size, _DataDepSize), ref_key
3093                return a.size.get_size(
3094                    axis=a,
3095                    ref_axis=ref_axis,
3096                    ref_size=ref_size,
3097                )
3098            elif isinstance(a.size, DataDependentSize):
3099                if (t_descr.id, a.id) in ns:
3100                    logger.warning(
3101                        "Ignoring unexpected increment factor (n) for data dependent"
3102                        + " size axis '{}' of tensor '{}'.",
3103                        a.id,
3104                        t_descr.id,
3105                    )
3106                return _DataDepSize(a.size.min, a.size.max)
3107            else:
3108                assert_never(a.size)
3109
3110        # first resolve all input sizes except the `SizeReference` ones
3111        for t_descr in self.inputs:
3112            for a in t_descr.axes:
3113                if not isinstance(a.size, SizeReference):
3114                    s = get_axis_size(a)
3115                    assert not isinstance(s, _DataDepSize)
3116                    inputs[t_descr.id, a.id] = s
3117
3118        # resolve all other input axis sizes
3119        for t_descr in self.inputs:
3120            for a in t_descr.axes:
3121                if isinstance(a.size, SizeReference):
3122                    s = get_axis_size(a)
3123                    assert not isinstance(s, _DataDepSize)
3124                    inputs[t_descr.id, a.id] = s
3125
3126        # resolve all output axis sizes
3127        for t_descr in self.outputs:
3128            for a in t_descr.axes:
3129                assert not isinstance(a.size, ParameterizedSize)
3130                s = get_axis_size(a)
3131                outputs[t_descr.id, a.id] = s
3132
3133        return _AxisSizes(inputs=inputs, outputs=outputs)
3134
3135    @model_validator(mode="before")
3136    @classmethod
3137    def _convert(cls, data: Dict[str, Any]) -> Dict[str, Any]:
3138        cls.convert_from_old_format_wo_validation(data)
3139        return data
3140
3141    @classmethod
3142    def convert_from_old_format_wo_validation(cls, data: Dict[str, Any]) -> None:
3143        """Convert metadata following an older format version to this class's format
3144        without validating the result.
3145        """
3146        if (
3147            data.get("type") == "model"
3148            and isinstance(fv := data.get("format_version"), str)
3149            and fv.count(".") == 2
3150        ):
3151            fv_parts = fv.split(".")
3152            if any(not p.isdigit() for p in fv_parts):
3153                return
3154
3155            fv_tuple = tuple(map(int, fv_parts))
3156
3157            assert cls.implemented_format_version_tuple[0:2] == (0, 5)
3158            if fv_tuple[:2] in ((0, 3), (0, 4)):
3159                m04 = _ModelDescr_v0_4.load(data)
3160                if isinstance(m04, InvalidDescr):
3161                    try:
3162                        updated = _model_conv.convert_as_dict(
3163                            m04  # pyright: ignore[reportArgumentType]
3164                        )
3165                    except Exception as e:
3166                        logger.error(
3167                            "Failed to convert from invalid model 0.4 description."
3168                            + f"\nerror: {e}"
3169                            + "\nProceeding with model 0.5 validation without conversion."
3170                        )
3171                        updated = None
3172                else:
3173                    updated = _model_conv.convert_as_dict(m04)
3174
3175                if updated is not None:
3176                    data.clear()
3177                    data.update(updated)
3178
3179            elif fv_tuple[:2] == (0, 5):
3180                # bump patch version
3181                data["format_version"] = cls.implemented_format_version

Specification of the fields used in a bioimage.io-compliant RDF to describe AI models with pretrained weights. These fields are typically stored in a YAML file which we call a model resource description file (model RDF).

implemented_format_version: ClassVar[Literal['0.5.4']] = '0.5.4'
implemented_type: ClassVar[Literal['model']] = 'model'

id: Optional[bioimageio.spec.model.v0_5.ModelId]

bioimage.io-wide unique resource identifier assigned by bioimage.io; version unspecific.

authors: Annotated[List[bioimageio.spec.generic.v0_3.Author], MinLen(min_length=1)]

The authors are the creators of the model RDF and the primary points of contact.

documentation: Annotated[Union[AbsoluteFilePath, RelativeFilePath, HttpUrl], Field(union_mode='left_to_right'), AfterValidator(_validate_md_suffix), PlainSerializer(_package, when_used='unless-none'), Field(examples=['https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/unet2d_nuclei_broad/README.md', 'README.md'])]

∈📦 URL or relative path to a markdown file with additional documentation. The recommended documentation file name is README.md. An .md suffix is mandatory. The documentation should include a '#[#] Validation' (sub)section with details on how to quantitatively validate the model on unseen data.
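For illustration, a minimal README.md skeleton that satisfies the validator's search for a '#[#] Validation' (sub)section (all content is a made-up example):

    # My Model

    Short description of what the model does.

    ## Validation

    Describe how to quantitatively validate the model on unseen data.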

inputs: Annotated[Sequence[bioimageio.spec.model.v0_5.InputTensorDescr], MinLen(min_length=1)]

Describes the input tensors expected by this model.

name: Annotated[str, RestrictCharacters(string.ascii_letters + string.digits + '_+- ()'), MinLen(5), MaxLen(128), warn(MaxLen(64), 'Name longer than 64 characters.', INFO)]

A human-readable name of this model. It should be no longer than 64 characters and may only contain letters, numbers, underscores, minus signs, parentheses and spaces. We recommend choosing a name that refers to the model's task and image modality.

outputs: Annotated[Sequence[bioimageio.spec.model.v0_5.OutputTensorDescr], MinLen(min_length=1)]

Describes the output tensors.

packaged_by: List[bioimageio.spec.generic.v0_3.Author]

The persons that have packaged and uploaded this model. Only required if those persons differ from the authors.

parent: Optional[bioimageio.spec.model.v0_5.LinkedModel]

The model from which this model is derived, e.g. by fine-tuning the weights.

run_mode: Annotated[Optional[bioimageio.spec.model.v0_4.RunMode], warn(None, "Run mode '{value}' has limited support across consumer software.")]

Custom run mode for this model: for more complex prediction procedures like test time data augmentation that currently cannot be expressed in the specification. No standard run modes are defined yet.

timestamp: bioimageio.spec._internal.types.Datetime

Timestamp in ISO 8601 (https://en.wikipedia.org/wiki/ISO_8601) format with a few restrictions listed at https://docs.python.org/3/library/datetime.html#datetime.datetime.fromisoformat. (In Python a datetime object is valid, too).

training_data: Annotated[Union[None, bioimageio.spec.dataset.v0_3.LinkedDataset, bioimageio.spec.DatasetDescr, bioimageio.spec.dataset.v0_2.DatasetDescr], Field(union_mode='left_to_right')]

The dataset used to train this model.

weights: Annotated[bioimageio.spec.model.v0_5.WeightsDescr, WrapSerializer(package_weights, when_used='always')]

The weights for this model. Weights can be given for different formats, but should otherwise be equivalent. The available weight formats determine which consumers can use this model.

def get_input_test_arrays(self) -> List[numpy.ndarray[Any, numpy.dtype[Any]]]:
def get_output_test_arrays(self) -> List[numpy.ndarray[Any, numpy.dtype[Any]]]:
@staticmethod
def get_batch_size( tensor_sizes: Mapping[bioimageio.spec.model.v0_5.TensorId, Mapping[bioimageio.spec.model.v0_5.AxisId, int]]) -> int:
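get_batch_size extracts the batch size shared by all given tensor sizes, raising a ValueError on mismatch. A small usage sketch (tensor and axis ids are illustrative; the batch axis id is assumed to be 'batch', cf. BATCH_AXIS_ID):

    from bioimageio.spec.model.v0_5 import AxisId, ModelDescr, TensorId

    sizes = {
        TensorId("raw"): {AxisId("batch"): 2, AxisId("x"): 128},
        TensorId("mask"): {AxisId("batch"): 2, AxisId("x"): 128},
    }
    assert ModelDescr.get_batch_size(sizes) == 2
    # Mismatching batch sizes (e.g. 2 vs 4) raise a ValueError.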
def get_output_tensor_sizes( self, input_sizes: Mapping[bioimageio.spec.model.v0_5.TensorId, Mapping[bioimageio.spec.model.v0_5.AxisId, int]]) -> Dict[bioimageio.spec.model.v0_5.TensorId, Dict[bioimageio.spec.model.v0_5.AxisId, Union[int, bioimageio.spec.model.v0_5._DataDepSize]]]:

Returns the tensor output sizes for given input_sizes. The returned sizes are exact only if input_sizes is a valid input shape; otherwise they may be larger than the actual (valid) output sizes.
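A usage sketch, assuming `model` is a validated ModelDescr with an input tensor 'raw' (all ids and sizes are illustrative):

    from bioimageio.spec.model.v0_5 import AxisId, TensorId

    out_sizes = model.get_output_tensor_sizes(
        {TensorId("raw"): {AxisId("batch"): 1, AxisId("x"): 512, AxisId("y"): 512}}
    )
    # e.g. {TensorId("mask"): {AxisId("batch"): 1, AxisId("x"): 512, AxisId("y"): 512}}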

def get_ns( self, input_sizes: Mapping[bioimageio.spec.model.v0_5.TensorId, Mapping[bioimageio.spec.model.v0_5.AxisId, int]]):

Get parameter n for each parameterized axis such that the valid input size is >= the given input size.
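A worked sketch of the relationship get_ns relies on (min/step values are illustrative): for a parameterized size min + n * step, get_n returns the smallest n whose valid size covers the requested size.

    from bioimageio.spec.model.v0_5 import ParameterizedSize

    size = ParameterizedSize(min=64, step=16)
    n = size.get_n(100)  # smallest n such that 64 + n * 16 >= 100
    assert n == 3
    assert size.get_size(n) == 112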

def get_tensor_sizes( self, ns: Mapping[Tuple[bioimageio.spec.model.v0_5.TensorId, bioimageio.spec.model.v0_5.AxisId], int], batch_size: int) -> bioimageio.spec.model.v0_5._TensorSizes:
def get_axis_sizes( self, ns: Mapping[Tuple[bioimageio.spec.model.v0_5.TensorId, bioimageio.spec.model.v0_5.AxisId], int], batch_size: Optional[int] = None, *, max_input_shape: Optional[Mapping[Tuple[bioimageio.spec.model.v0_5.TensorId, bioimageio.spec.model.v0_5.AxisId], int]] = None) -> bioimageio.spec.model.v0_5._AxisSizes:

Determine input and output block shape for scale factors ns of parameterized input sizes.

Arguments:
  • ns: Scale factor n for each axis (keyed by (tensor_id, axis_id)) that is parameterized as size = min + n * step.
  • batch_size: The desired size of the batch dimension. If given, batch_size overwrites any batch size present in max_input_shape. Default 1.
  • max_input_shape: Limits the derived block shapes. Each axis for which the input size, parameterized by n, is larger than max_input_shape is set to the minimal value n_min for which this is still true. Use this for small input samples or large values of ns. Or simply whenever you know the full input shape.
Returns:

Resolved axis sizes for model inputs and outputs.
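A usage sketch, assuming `model` is a validated ModelDescr whose input 'raw' has parameterized x and y axes (ids and the factor 3 are illustrative):

    from bioimageio.spec.model.v0_5 import AxisId, TensorId

    ns = {
        (TensorId("raw"), AxisId("x")): 3,
        (TensorId("raw"), AxisId("y")): 3,
    }
    axis_sizes = model.get_axis_sizes(ns, batch_size=1)
    print(axis_sizes.inputs)   # {(TensorId("raw"), AxisId("x")): <int>, ...}
    print(axis_sizes.outputs)  # may contain _DataDepSize for data-dependent axes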

@classmethod
def convert_from_old_format_wo_validation(cls, data: Dict[str, Any]) -> None:

Convert metadata following an older format version to this class's format without validating the result.
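A usage sketch (the path is illustrative), converting an old model RDF dict in place before 0.5 validation:

    import yaml

    from bioimageio.spec.model.v0_5 import ModelDescr

    with open("rdf.yaml", encoding="utf-8") as f:  # illustrative path
        data = yaml.safe_load(f)

    ModelDescr.convert_from_old_format_wo_validation(data)
    print(data.get("format_version"))  # the implemented 0.5 version on success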

implemented_format_version_tuple: ClassVar[Tuple[int, int, int]] = (0, 5, 4)

ModelDescr_v0_4 = <class 'bioimageio.spec.model.v0_4.ModelDescr'>
ModelDescr_v0_5 = <class 'ModelDescr'>
AnyModelDescr = typing.Annotated[typing.Union[typing.Annotated[bioimageio.spec.model.v0_4.ModelDescr, Field(title='model 0.4')], typing.Annotated[ModelDescr, Field(title='model 0.5')]], Discriminator('format_version'), Field(title='model')]

Union of any released model description