bioimageio.spec.model
Implementations of all released minor versions are available in submodules:
- model v0_4: `bioimageio.spec.model.v0_4.ModelDescr`
- model v0_5: `bioimageio.spec.model.v0_5.ModelDescr`
```python
# autogen: start
"""
implementations of all released minor versions are available in submodules:
- model v0_4: `bioimageio.spec.model.v0_4.ModelDescr`
- model v0_5: `bioimageio.spec.model.v0_5.ModelDescr`
"""

from typing import Union

from pydantic import Discriminator, Field
from typing_extensions import Annotated

from . import v0_4, v0_5

ModelDescr = v0_5.ModelDescr
ModelDescr_v0_4 = v0_4.ModelDescr
ModelDescr_v0_5 = v0_5.ModelDescr

AnyModelDescr = Annotated[
    Union[
        Annotated[ModelDescr_v0_4, Field(title="model 0.4")],
        Annotated[ModelDescr_v0_5, Field(title="model 0.5")],
    ],
    Discriminator("format_version"),
    Field(title="model"),
]
"""Union of any released model description"""
# autogen: stop
```
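As a minimal sketch of how this union can be used: pydantic's `TypeAdapter` dispatches raw RDF data to the matching description class via the `format_version` discriminator. The `rdf.yaml` path is a placeholder; in practice loading is typically handled by bioimageio.spec's own loading functions rather than manually.

```python
# Sketch: validate raw model RDF data against the discriminated union.
import yaml
from pydantic import TypeAdapter

from bioimageio.spec.model import AnyModelDescr

with open("rdf.yaml") as f:  # placeholder path
    raw = yaml.safe_load(f)

# `Discriminator("format_version")` routes the data to ModelDescr_v0_4 or
# ModelDescr_v0_5 before field validation runs.
descr = TypeAdapter(AnyModelDescr).validate_python(raw)
print(type(descr).__name__, descr.format_version)
```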
```python
class ModelDescr(GenericModelDescrBase):
    """Specification of the fields used in a bioimage.io-compliant RDF to describe AI models with pretrained weights.
    These fields are typically stored in a YAML file which we call a model resource description file (model RDF).
    """

    implemented_format_version: ClassVar[Literal["0.5.6"]] = "0.5.6"
    if TYPE_CHECKING:
        format_version: Literal["0.5.6"] = "0.5.6"
    else:
        format_version: Literal["0.5.6"]
    """Version of the bioimage.io model description specification used.
    When creating a new model always use the latest micro/patch version described here.
    The `format_version` is important for any consumer software to understand how to parse the fields.
    """

    implemented_type: ClassVar[Literal["model"]] = "model"
    if TYPE_CHECKING:
        type: Literal["model"] = "model"
    else:
        type: Literal["model"]
    """Specialized resource type 'model'"""

    id: Optional[ModelId] = None
    """bioimage.io-wide unique resource identifier
    assigned by bioimage.io; version **un**specific."""

    authors: FAIR[List[Author]] = Field(
        default_factory=cast(Callable[[], List[Author]], list)
    )
    """The authors are the creators of the model RDF and the primary points of contact."""

    documentation: FAIR[Optional[FileSource_documentation]] = None
    """URL or relative path to a markdown file with additional documentation.
    The recommended documentation file name is `README.md`. An `.md` suffix is mandatory.
    The documentation should include a '#[#] Validation' (sub)section
    with details on how to quantitatively validate the model on unseen data."""

    @field_validator("documentation", mode="after")
    @classmethod
    def _validate_documentation(
        cls, value: Optional[FileSource_documentation]
    ) -> Optional[FileSource_documentation]:
        if not get_validation_context().perform_io_checks or value is None:
            return value

        doc_reader = get_reader(value)
        doc_content = doc_reader.read().decode(encoding="utf-8")
        if not re.search("#.*[vV]alidation", doc_content):
            issue_warning(
                "No '# Validation' (sub)section found in {value}.",
                value=value,
                field="documentation",
            )

        return value

    inputs: NotEmpty[Sequence[InputTensorDescr]]
    """Describes the input tensors expected by this model."""

    @field_validator("inputs", mode="after")
    @classmethod
    def _validate_input_axes(
        cls, inputs: Sequence[InputTensorDescr]
    ) -> Sequence[InputTensorDescr]:
        input_size_refs = cls._get_axes_with_independent_size(inputs)

        for i, ipt in enumerate(inputs):
            valid_independent_refs: Dict[
                Tuple[TensorId, AxisId],
                Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
            ] = {
                **{
                    (ipt.id, a.id): (ipt, a, a.size)
                    for a in ipt.axes
                    if not isinstance(a, BatchAxis)
                    and isinstance(a.size, (int, ParameterizedSize))
                },
                **input_size_refs,
            }
            for a, ax in enumerate(ipt.axes):
                cls._validate_axis(
                    "inputs",
                    i=i,
                    tensor_id=ipt.id,
                    a=a,
                    axis=ax,
                    valid_independent_refs=valid_independent_refs,
                )
        return inputs

    @staticmethod
    def _validate_axis(
        field_name: str,
        i: int,
        tensor_id: TensorId,
        a: int,
        axis: AnyAxis,
        valid_independent_refs: Dict[
            Tuple[TensorId, AxisId],
            Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
        ],
    ):
        if isinstance(axis, BatchAxis) or isinstance(
            axis.size, (int, ParameterizedSize, DataDependentSize)
        ):
            return
        elif not isinstance(axis.size, SizeReference):
            assert_never(axis.size)

        # validate axis.size SizeReference
        ref = (axis.size.tensor_id, axis.size.axis_id)
        if ref not in valid_independent_refs:
            raise ValueError(
                "Invalid tensor axis reference at"
                + f" {field_name}[{i}].axes[{a}].size: {axis.size}."
            )
        if ref == (tensor_id, axis.id):
            raise ValueError(
                "Self-referencing not allowed for"
                + f" {field_name}[{i}].axes[{a}].size: {axis.size}"
            )
        if axis.type == "channel":
            if valid_independent_refs[ref][1].type != "channel":
                raise ValueError(
                    "A channel axis' size may only reference another fixed size"
                    + " channel axis."
                )
            if isinstance(axis.channel_names, str) and "{i}" in axis.channel_names:
                ref_size = valid_independent_refs[ref][2]
                assert isinstance(ref_size, int), (
                    "channel axis ref (another channel axis) has to specify fixed"
                    + " size"
                )
                generated_channel_names = [
                    Identifier(axis.channel_names.format(i=i))
                    for i in range(1, ref_size + 1)
                ]
                axis.channel_names = generated_channel_names

        if (ax_unit := getattr(axis, "unit", None)) != (
            ref_unit := getattr(valid_independent_refs[ref][1], "unit", None)
        ):
            raise ValueError(
                "The units of an axis and its reference axis need to match, but"
                + f" '{ax_unit}' != '{ref_unit}'."
            )
        ref_axis = valid_independent_refs[ref][1]
        if isinstance(ref_axis, BatchAxis):
            raise ValueError(
                f"Invalid reference axis '{ref_axis.id}' for {tensor_id}.{axis.id}"
                + " (a batch axis is not allowed as reference)."
            )

        if isinstance(axis, WithHalo):
            min_size = axis.size.get_size(axis, ref_axis, n=0)
            if (min_size - 2 * axis.halo) < 1:
                raise ValueError(
                    f"axis {axis.id} with minimum size {min_size} is too small for halo"
                    + f" {axis.halo}."
                )

            input_halo = axis.halo * axis.scale / ref_axis.scale
            if input_halo != int(input_halo) or input_halo % 2 == 1:
                raise ValueError(
                    f"input_halo {input_halo} (output_halo {axis.halo} *"
                    + f" output_scale {axis.scale} / input_scale {ref_axis.scale})"
                    + f" {tensor_id}.{axis.id}."
                )

    @model_validator(mode="after")
    def _validate_test_tensors(self) -> Self:
        if not get_validation_context().perform_io_checks:
            return self

        test_output_arrays = [
            None if descr.test_tensor is None else load_array(descr.test_tensor)
            for descr in self.outputs
        ]
        test_input_arrays = [
            None if descr.test_tensor is None else load_array(descr.test_tensor)
            for descr in self.inputs
        ]

        tensors = {
            descr.id: (descr, array)
            for descr, array in zip(
                chain(self.inputs, self.outputs), test_input_arrays + test_output_arrays
            )
        }
        validate_tensors(tensors, tensor_origin="test_tensor")

        output_arrays = {
            descr.id: array for descr, array in zip(self.outputs, test_output_arrays)
        }
        for rep_tol in self.config.bioimageio.reproducibility_tolerance:
            if not rep_tol.absolute_tolerance:
                continue

            if rep_tol.output_ids:
                out_arrays = {
                    oid: a
                    for oid, a in output_arrays.items()
                    if oid in rep_tol.output_ids
                }
            else:
                out_arrays = output_arrays

            for out_id, array in out_arrays.items():
                if array is None:
                    continue

                if rep_tol.absolute_tolerance > (max_test_value := array.max()) * 0.01:
                    raise ValueError(
                        "config.bioimageio.reproducibility_tolerance.absolute_tolerance="
                        + f"{rep_tol.absolute_tolerance} > 0.01*{max_test_value}"
                        + f" (1% of the maximum value of the test tensor '{out_id}')"
                    )

        return self

    @model_validator(mode="after")
    def _validate_tensor_references_in_proc_kwargs(self, info: ValidationInfo) -> Self:
        ipt_refs = {t.id for t in self.inputs}
        out_refs = {t.id for t in self.outputs}
        for ipt in self.inputs:
            for p in ipt.preprocessing:
                ref = p.kwargs.get("reference_tensor")
                if ref is None:
                    continue
                if ref not in ipt_refs:
                    raise ValueError(
                        f"`reference_tensor` '{ref}' not found. Valid input tensor"
                        + f" references are: {ipt_refs}."
                    )

        for out in self.outputs:
            for p in out.postprocessing:
                ref = p.kwargs.get("reference_tensor")
                if ref is None:
                    continue

                if ref not in ipt_refs and ref not in out_refs:
                    raise ValueError(
                        f"`reference_tensor` '{ref}' not found. Valid tensor references"
                        + f" are: {ipt_refs | out_refs}."
                    )

        return self

    # TODO: use validate funcs in validate_test_tensors
    # def validate_inputs(self, input_tensors: Mapping[TensorId, NDArray[Any]]) -> Mapping[TensorId, NDArray[Any]]:

    name: Annotated[
        str,
        RestrictCharacters(string.ascii_letters + string.digits + "_+- ()"),
        MinLen(5),
        MaxLen(128),
        warn(MaxLen(64), "Name longer than 64 characters.", INFO),
    ]
    """A human-readable name of this model.
    It should be no longer than 64 characters
    and may only contain letters, numbers, underscores, minus signs, parentheses and spaces.
    We recommend choosing a name that refers to the model's task and image modality.
    """

    outputs: NotEmpty[Sequence[OutputTensorDescr]]
    """Describes the output tensors."""

    @field_validator("outputs", mode="after")
    @classmethod
    def _validate_tensor_ids(
        cls, outputs: Sequence[OutputTensorDescr], info: ValidationInfo
    ) -> Sequence[OutputTensorDescr]:
        tensor_ids = [
            t.id for t in info.data.get("inputs", []) + info.data.get("outputs", [])
        ]
        duplicate_tensor_ids: List[str] = []
        seen: Set[str] = set()
        for t in tensor_ids:
            if t in seen:
                duplicate_tensor_ids.append(t)

            seen.add(t)

        if duplicate_tensor_ids:
            raise ValueError(f"Duplicate tensor ids: {duplicate_tensor_ids}")

        return outputs

    @staticmethod
    def _get_axes_with_parameterized_size(
        io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]],
    ):
        return {
            f"{t.id}.{a.id}": (t, a, a.size)
            for t in io
            for a in t.axes
            if not isinstance(a, BatchAxis) and isinstance(a.size, ParameterizedSize)
        }

    @staticmethod
    def _get_axes_with_independent_size(
        io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]],
    ):
        return {
            (t.id, a.id): (t, a, a.size)
            for t in io
            for a in t.axes
            if not isinstance(a, BatchAxis)
            and isinstance(a.size, (int, ParameterizedSize))
        }

    @field_validator("outputs", mode="after")
    @classmethod
    def _validate_output_axes(
        cls, outputs: List[OutputTensorDescr], info: ValidationInfo
    ) -> List[OutputTensorDescr]:
        input_size_refs = cls._get_axes_with_independent_size(
            info.data.get("inputs", [])
        )
        output_size_refs = cls._get_axes_with_independent_size(outputs)

        for i, out in enumerate(outputs):
            valid_independent_refs: Dict[
                Tuple[TensorId, AxisId],
                Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
            ] = {
                **{
                    (out.id, a.id): (out, a, a.size)
                    for a in out.axes
                    if not isinstance(a, BatchAxis)
                    and isinstance(a.size, (int, ParameterizedSize))
                },
                **input_size_refs,
                **output_size_refs,
            }
            for a, ax in enumerate(out.axes):
                cls._validate_axis(
                    "outputs",
                    i,
                    out.id,
                    a,
                    ax,
                    valid_independent_refs=valid_independent_refs,
                )

        return outputs

    packaged_by: List[Author] = Field(
        default_factory=cast(Callable[[], List[Author]], list)
    )
    """The persons that have packaged and uploaded this model.
    Only required if those persons differ from the `authors`."""

    parent: Optional[LinkedModel] = None
    """The model from which this model is derived, e.g. by fine-tuning the weights."""

    @model_validator(mode="after")
    def _validate_parent_is_not_self(self) -> Self:
        if self.parent is not None and self.parent.id == self.id:
            raise ValueError("A model description may not reference itself as parent.")

        return self

    run_mode: Annotated[
        Optional[RunMode],
        warn(None, "Run mode '{value}' has limited support across consumer software."),
    ] = None
    """Custom run mode for this model: for more complex prediction procedures like test time
    data augmentation that currently cannot be expressed in the specification.
    No standard run modes are defined yet."""

    timestamp: Datetime = Field(default_factory=Datetime.now)
    """Timestamp in [ISO 8601](https://en.wikipedia.org/wiki/ISO_8601) format
    with a few restrictions listed [here](https://docs.python.org/3/library/datetime.html#datetime.datetime.fromisoformat).
    (In Python a datetime object is valid, too)."""

    training_data: Annotated[
        Union[None, LinkedDataset, DatasetDescr, DatasetDescr02],
        Field(union_mode="left_to_right"),
    ] = None
    """The dataset used to train this model"""

    weights: Annotated[WeightsDescr, WrapSerializer(package_weights)]
    """The weights for this model.
    Weights can be given for different formats, but should otherwise be equivalent.
    The available weight formats determine which consumers can use this model."""

    config: Config = Field(default_factory=Config.model_construct)

    @model_validator(mode="after")
    def _add_default_cover(self) -> Self:
        if not get_validation_context().perform_io_checks or self.covers:
            return self

        try:
            generated_covers = generate_covers(
                [
                    (t, load_array(t.test_tensor))
                    for t in self.inputs
                    if t.test_tensor is not None
                ],
                [
                    (t, load_array(t.test_tensor))
                    for t in self.outputs
                    if t.test_tensor is not None
                ],
            )
        except Exception as e:
            issue_warning(
                "Failed to generate cover image(s): {e}",
                value=self.covers,
                msg_context=dict(e=e),
                field="covers",
            )
        else:
            self.covers.extend(generated_covers)

        return self

    def get_input_test_arrays(self) -> List[NDArray[Any]]:
        return self._get_test_arrays(self.inputs)

    def get_output_test_arrays(self) -> List[NDArray[Any]]:
        return self._get_test_arrays(self.outputs)

    @staticmethod
    def _get_test_arrays(
        io_descr: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]],
    ):
        ts: List[FileDescr] = []
        for d in io_descr:
            if d.test_tensor is None:
                raise ValueError(
                    f"Failed to get test arrays: description of '{d.id}' is missing a `test_tensor`."
                )
            ts.append(d.test_tensor)

        data = [load_array(t) for t in ts]
        assert all(isinstance(d, np.ndarray) for d in data)
        return data

    @staticmethod
    def get_batch_size(tensor_sizes: Mapping[TensorId, Mapping[AxisId, int]]) -> int:
        batch_size = 1
        tensor_with_batchsize: Optional[TensorId] = None
        for tid in tensor_sizes:
            for aid, s in tensor_sizes[tid].items():
                if aid != BATCH_AXIS_ID or s == 1 or s == batch_size:
                    continue

                if batch_size != 1:
                    assert tensor_with_batchsize is not None
                    raise ValueError(
                        f"batch size mismatch for tensors '{tensor_with_batchsize}' ({batch_size}) and '{tid}' ({s})"
                    )

                batch_size = s
                tensor_with_batchsize = tid

        return batch_size

    def get_output_tensor_sizes(
        self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]
    ) -> Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]:
        """Returns the tensor output sizes for given **input_sizes**.
        Only if **input_sizes** has a valid input shape, the tensor output size is exact.
        Otherwise it might be larger than the actual (valid) output."""
        batch_size = self.get_batch_size(input_sizes)
        ns = self.get_ns(input_sizes)

        tensor_sizes = self.get_tensor_sizes(ns, batch_size=batch_size)
        return tensor_sizes.outputs

    def get_ns(self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]):
        """get parameter `n` for each parameterized axis
        such that the valid input size is >= the given input size"""
        ret: Dict[Tuple[TensorId, AxisId], ParameterizedSize_N] = {}
        axes = {t.id: {a.id: a for a in t.axes} for t in self.inputs}
        for tid in input_sizes:
            for aid, s in input_sizes[tid].items():
                size_descr = axes[tid][aid].size
                if isinstance(size_descr, ParameterizedSize):
                    ret[(tid, aid)] = size_descr.get_n(s)
                elif size_descr is None or isinstance(size_descr, (int, SizeReference)):
                    pass
                else:
                    assert_never(size_descr)

        return ret

    def get_tensor_sizes(
        self, ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], batch_size: int
    ) -> _TensorSizes:
        axis_sizes = self.get_axis_sizes(ns, batch_size=batch_size)
        return _TensorSizes(
            {
                t: {
                    aa: axis_sizes.inputs[(tt, aa)]
                    for tt, aa in axis_sizes.inputs
                    if tt == t
                }
                for t in {tt for tt, _ in axis_sizes.inputs}
            },
            {
                t: {
                    aa: axis_sizes.outputs[(tt, aa)]
                    for tt, aa in axis_sizes.outputs
                    if tt == t
                }
                for t in {tt for tt, _ in axis_sizes.outputs}
            },
        )

    def get_axis_sizes(
        self,
        ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N],
        batch_size: Optional[int] = None,
        *,
        max_input_shape: Optional[Mapping[Tuple[TensorId, AxisId], int]] = None,
    ) -> _AxisSizes:
        """Determine input and output block shape for scale factors **ns**
        of parameterized input sizes.

        Args:
            ns: Scale factor `n` for each axis (keyed by (tensor_id, axis_id))
                that is parameterized as `size = min + n * step`.
            batch_size: The desired size of the batch dimension.
                If given, **batch_size** overwrites any batch size present in
                **max_input_shape**. Default 1.
            max_input_shape: Limits the derived block shapes.
                Each axis for which the input size, parameterized by `n`, is larger
                than **max_input_shape** is set to the minimal value `n_min` for which
                this is still true.
                Use this for small input samples or large values of **ns**.
                Or simply whenever you know the full input shape.

        Returns:
            Resolved axis sizes for model inputs and outputs.
        """
        max_input_shape = max_input_shape or {}
        if batch_size is None:
            for (_t_id, a_id), s in max_input_shape.items():
                if a_id == BATCH_AXIS_ID:
                    batch_size = s
                    break
            else:
                batch_size = 1

        all_axes = {
            t.id: {a.id: a for a in t.axes} for t in chain(self.inputs, self.outputs)
        }

        inputs: Dict[Tuple[TensorId, AxisId], int] = {}
        outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]] = {}

        def get_axis_size(a: Union[InputAxis, OutputAxis]):
            if isinstance(a, BatchAxis):
                if (t_descr.id, a.id) in ns:
                    logger.warning(
                        "Ignoring unexpected size increment factor (n) for batch axis"
                        + " of tensor '{}'.",
                        t_descr.id,
                    )
                return batch_size
            elif isinstance(a.size, int):
                if (t_descr.id, a.id) in ns:
                    logger.warning(
                        "Ignoring unexpected size increment factor (n) for fixed size"
                        + " axis '{}' of tensor '{}'.",
                        a.id,
                        t_descr.id,
                    )
                return a.size
            elif isinstance(a.size, ParameterizedSize):
                if (t_descr.id, a.id) not in ns:
                    raise ValueError(
                        "Size increment factor (n) missing for parametrized axis"
                        + f" '{a.id}' of tensor '{t_descr.id}'."
                    )
                n = ns[(t_descr.id, a.id)]
                s_max = max_input_shape.get((t_descr.id, a.id))
                if s_max is not None:
                    n = min(n, a.size.get_n(s_max))

                return a.size.get_size(n)

            elif isinstance(a.size, SizeReference):
                if (t_descr.id, a.id) in ns:
                    logger.warning(
                        "Ignoring unexpected size increment factor (n) for axis '{}'"
                        + " of tensor '{}' with size reference.",
                        a.id,
                        t_descr.id,
                    )
                assert not isinstance(a, BatchAxis)
                ref_axis = all_axes[a.size.tensor_id][a.size.axis_id]
                assert not isinstance(ref_axis, BatchAxis)
                ref_key = (a.size.tensor_id, a.size.axis_id)
                ref_size = inputs.get(ref_key, outputs.get(ref_key))
                assert ref_size is not None, ref_key
                assert not isinstance(ref_size, _DataDepSize), ref_key
                return a.size.get_size(
                    axis=a,
                    ref_axis=ref_axis,
                    ref_size=ref_size,
                )
            elif isinstance(a.size, DataDependentSize):
                if (t_descr.id, a.id) in ns:
                    logger.warning(
                        "Ignoring unexpected increment factor (n) for data dependent"
                        + " size axis '{}' of tensor '{}'.",
                        a.id,
                        t_descr.id,
                    )
                return _DataDepSize(a.size.min, a.size.max)
            else:
                assert_never(a.size)

        # first resolve all but the `SizeReference` input sizes
        for t_descr in self.inputs:
            for a in t_descr.axes:
                if not isinstance(a.size, SizeReference):
                    s = get_axis_size(a)
                    assert not isinstance(s, _DataDepSize)
                    inputs[t_descr.id, a.id] = s

        # resolve all other input axis sizes
        for t_descr in self.inputs:
            for a in t_descr.axes:
                if isinstance(a.size, SizeReference):
                    s = get_axis_size(a)
                    assert not isinstance(s, _DataDepSize)
                    inputs[t_descr.id, a.id] = s

        # resolve all output axis sizes
        for t_descr in self.outputs:
            for a in t_descr.axes:
                assert not isinstance(a.size, ParameterizedSize)
                s = get_axis_size(a)
                outputs[t_descr.id, a.id] = s

        return _AxisSizes(inputs=inputs, outputs=outputs)

    @model_validator(mode="before")
    @classmethod
    def _convert(cls, data: Dict[str, Any]) -> Dict[str, Any]:
        cls.convert_from_old_format_wo_validation(data)
        return data

    @classmethod
    def convert_from_old_format_wo_validation(cls, data: Dict[str, Any]) -> None:
        """Convert metadata following an older format version to this class's format
        without validating the result.
        """
        if (
            data.get("type") == "model"
            and isinstance(fv := data.get("format_version"), str)
            and fv.count(".") == 2
        ):
            fv_parts = fv.split(".")
            if any(not p.isdigit() for p in fv_parts):
                return

            fv_tuple = tuple(map(int, fv_parts))

            assert cls.implemented_format_version_tuple[0:2] == (0, 5)
            if fv_tuple[:2] in ((0, 3), (0, 4)):
                m04 = _ModelDescr_v0_4.load(data)
                if isinstance(m04, InvalidDescr):
                    try:
                        updated = _model_conv.convert_as_dict(
                            m04  # pyright: ignore[reportArgumentType]
                        )
                    except Exception as e:
                        logger.error(
                            "Failed to convert from invalid model 0.4 description."
                            + f"\nerror: {e}"
                            + "\nProceeding with model 0.5 validation without conversion."
                        )
                        updated = None
                else:
                    updated = _model_conv.convert_as_dict(m04)

                if updated is not None:
                    data.clear()
                    data.update(updated)

            elif fv_tuple[:2] == (0, 5):
                # bump patch version
                data["format_version"] = cls.implemented_format_version
```
Specification of the fields used in a bioimage.io-compliant RDF to describe AI models with pretrained weights. These fields are typically stored in a YAML file which we call a model resource description file (model RDF).
URL or relative path to a markdown file with additional documentation.
The recommended documentation file name is README.md. An .md suffix is mandatory.
The documentation should include a '#[#] Validation' (sub)section
with details on how to quantitatively validate the model on unseen data.
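Note that a missing section only triggers a warning (via `issue_warning` in `_validate_documentation`, shown in the source above), not a validation error. For illustration, the check is a plain `re.search`, so any markdown heading containing "validation" satisfies it:

```python
import re

# Headings that satisfy the documentation check ...
assert re.search("#.*[vV]alidation", "# Validation")
assert re.search("#.*[vV]alidation", "## Model Validation")
# ... and one that would trigger the warning instead:
assert not re.search("#.*[vV]alidation", "# Usage")
```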
Describes the input tensors expected by this model.
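As a loosely-held sketch of what one element of `inputs` can look like: the class and field names below follow `bioimageio.spec.model.v0_5`, but which fields are required (e.g. `test_tensor`, `data`) may differ between patch versions.

```python
# Hypothetical sketch of a single input tensor description.
from bioimageio.spec.model.v0_5 import (
    AxisId,
    BatchAxis,
    InputTensorDescr,
    ParameterizedSize,
    SpaceInputAxis,
    TensorId,
)

ipt = InputTensorDescr(
    id=TensorId("raw"),
    axes=[
        BatchAxis(),
        # valid sizes: 64, 80, 96, ... (size = min + n * step)
        SpaceInputAxis(id=AxisId("y"), size=ParameterizedSize(min=64, step=16)),
        SpaceInputAxis(id=AxisId("x"), size=ParameterizedSize(min=64, step=16)),
    ],
)
```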
A human-readable name of this model. It should be no longer than 64 characters and may only contain letters, numbers, underscores, minus signs, parentheses, and spaces. We recommend choosing a name that refers to the model's task and image modality.
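A quick way to see the character restriction in action; the allowed set is taken verbatim from the `RestrictCharacters` annotation in the source above:

```python
import string

allowed = set(string.ascii_letters + string.digits + "_+- ()")
assert set("U-Net (2D nuclei)") <= allowed  # acceptable name
assert not set("model@v2!") <= allowed      # '@' and '!' are rejected
```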
Describes the output tensors.
The persons that have packaged and uploaded this model.
Only required if those persons differ from the authors.
Custom run mode for this model: for more complex prediction procedures like test time data augmentation that currently cannot be expressed in the specification. No standard run modes are defined yet.
The dataset used to train this model
The weights for this model. Weights can be given for different formats, but should otherwise be equivalent. The available weight formats determine which consumers can use this model.
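A common consumer-side pattern, sketched here with an assumed already-loaded description `descr`; the format field names follow `WeightsDescr` in this module:

```python
# Probe which optional weight-format entries are set on `descr.weights`.
available = [
    fmt
    for fmt in ("pytorch_state_dict", "torchscript", "onnx", "keras_hdf5")
    if getattr(descr.weights, fmt, None) is not None
]
print("available weight formats:", available)
```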
```python
    @staticmethod
    def get_batch_size(tensor_sizes: Mapping[TensorId, Mapping[AxisId, int]]) -> int:
        batch_size = 1
        tensor_with_batchsize: Optional[TensorId] = None
        for tid in tensor_sizes:
            for aid, s in tensor_sizes[tid].items():
                if aid != BATCH_AXIS_ID or s == 1 or s == batch_size:
                    continue

                if batch_size != 1:
                    assert tensor_with_batchsize is not None
                    raise ValueError(
                        f"batch size mismatch for tensors '{tensor_with_batchsize}' ({batch_size}) and '{tid}' ({s})"
                    )

                batch_size = s
                tensor_with_batchsize = tid

        return batch_size
```
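A worked example of the rule implemented above, where a batch size of 1 is treated as compatible with any other batch size. Plain strings stand in for `TensorId`/`AxisId` (both `str` subtypes), and the batch axis id is assumed to be `"batch"`:

```python
from bioimageio.spec.model.v0_5 import ModelDescr

sizes = {"raw": {"batch": 4, "x": 512}, "mask": {"batch": 1}}
assert ModelDescr.get_batch_size(sizes) == 4

# A genuine conflict raises:
# ModelDescr.get_batch_size({"raw": {"batch": 4}, "mask": {"batch": 2}})  # ValueError
```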
```python
    def get_output_tensor_sizes(
        self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]
    ) -> Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]:
        """Returns the tensor output sizes for given **input_sizes**.
        Only if **input_sizes** has a valid input shape, the tensor output size is exact.
        Otherwise it might be larger than the actual (valid) output."""
        batch_size = self.get_batch_size(input_sizes)
        ns = self.get_ns(input_sizes)

        tensor_sizes = self.get_tensor_sizes(ns, batch_size=batch_size)
        return tensor_sizes.outputs
```
Returns the tensor output sizes for the given **input_sizes**. The output sizes are exact only if **input_sizes** is a valid input shape; otherwise they may be larger than the actual (valid) output.
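Usage sketch, assuming `descr` is a loaded 0.5 model description with tensors and axes named as below (ids given as plain strings):

```python
# Predict output sizes for concrete input sizes.
out_sizes = descr.get_output_tensor_sizes({"raw": {"batch": 1, "y": 512, "x": 512}})
# e.g. {"mask": {"batch": 1, "y": 512, "x": 512}} for a same-sized output
```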
```python
    def get_ns(self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]):
        """get parameter `n` for each parameterized axis
        such that the valid input size is >= the given input size"""
        ret: Dict[Tuple[TensorId, AxisId], ParameterizedSize_N] = {}
        axes = {t.id: {a.id: a for a in t.axes} for t in self.inputs}
        for tid in input_sizes:
            for aid, s in input_sizes[tid].items():
                size_descr = axes[tid][aid].size
                if isinstance(size_descr, ParameterizedSize):
                    ret[(tid, aid)] = size_descr.get_n(s)
                elif size_descr is None or isinstance(size_descr, (int, SizeReference)):
                    pass
                else:
                    assert_never(size_descr)

        return ret
```
Get the parameter `n` for each parameterized axis such that the valid input size is >= the given input size.
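To make the `n` parameterization concrete, a small sketch with `ParameterizedSize` (valid sizes are `min + n * step`; `get_size` and `get_n` are the same calls used in the source above):

```python
from bioimageio.spec.model.v0_5 import ParameterizedSize

p = ParameterizedSize(min=64, step=16)
assert p.get_size(2) == 96  # 64 + 2 * 16
assert p.get_n(96) == 2     # inverse of get_size on the size grid
```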
```python
    def get_tensor_sizes(
        self, ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], batch_size: int
    ) -> _TensorSizes:
        axis_sizes = self.get_axis_sizes(ns, batch_size=batch_size)
        return _TensorSizes(
            {
                t: {
                    aa: axis_sizes.inputs[(tt, aa)]
                    for tt, aa in axis_sizes.inputs
                    if tt == t
                }
                for t in {tt for tt, _ in axis_sizes.inputs}
            },
            {
                t: {
                    aa: axis_sizes.outputs[(tt, aa)]
                    for tt, aa in axis_sizes.outputs
                    if tt == t
                }
                for t in {tt for tt, _ in axis_sizes.outputs}
            },
        )
```
```python
    def get_axis_sizes(
        self,
        ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N],
        batch_size: Optional[int] = None,
        *,
        max_input_shape: Optional[Mapping[Tuple[TensorId, AxisId], int]] = None,
    ) -> _AxisSizes:
        """Determine input and output block shape for scale factors **ns**
        of parameterized input sizes.

        Args:
            ns: Scale factor `n` for each axis (keyed by (tensor_id, axis_id))
                that is parameterized as `size = min + n * step`.
            batch_size: The desired size of the batch dimension.
                If given, **batch_size** overwrites any batch size present in
                **max_input_shape**. Default 1.
            max_input_shape: Limits the derived block shapes.
                Each axis for which the input size, parameterized by `n`, is larger
                than **max_input_shape** is set to the minimal value `n_min` for which
                this is still true.
                Use this for small input samples or large values of **ns**.
                Or simply whenever you know the full input shape.

        Returns:
            Resolved axis sizes for model inputs and outputs.
        """
        max_input_shape = max_input_shape or {}
        if batch_size is None:
            for (_t_id, a_id), s in max_input_shape.items():
                if a_id == BATCH_AXIS_ID:
                    batch_size = s
                    break
            else:
                batch_size = 1

        all_axes = {
            t.id: {a.id: a for a in t.axes} for t in chain(self.inputs, self.outputs)
        }

        inputs: Dict[Tuple[TensorId, AxisId], int] = {}
        outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]] = {}

        def get_axis_size(a: Union[InputAxis, OutputAxis]):
            if isinstance(a, BatchAxis):
                if (t_descr.id, a.id) in ns:
                    logger.warning(
                        "Ignoring unexpected size increment factor (n) for batch axis"
                        + " of tensor '{}'.",
                        t_descr.id,
                    )
                return batch_size
            elif isinstance(a.size, int):
                if (t_descr.id, a.id) in ns:
                    logger.warning(
                        "Ignoring unexpected size increment factor (n) for fixed size"
                        + " axis '{}' of tensor '{}'.",
                        a.id,
                        t_descr.id,
                    )
                return a.size
            elif isinstance(a.size, ParameterizedSize):
                if (t_descr.id, a.id) not in ns:
                    raise ValueError(
                        "Size increment factor (n) missing for parametrized axis"
                        + f" '{a.id}' of tensor '{t_descr.id}'."
                    )
                n = ns[(t_descr.id, a.id)]
                s_max = max_input_shape.get((t_descr.id, a.id))
                if s_max is not None:
                    n = min(n, a.size.get_n(s_max))

                return a.size.get_size(n)

            elif isinstance(a.size, SizeReference):
                if (t_descr.id, a.id) in ns:
                    logger.warning(
                        "Ignoring unexpected size increment factor (n) for axis '{}'"
                        + " of tensor '{}' with size reference.",
                        a.id,
                        t_descr.id,
                    )
                assert not isinstance(a, BatchAxis)
                ref_axis = all_axes[a.size.tensor_id][a.size.axis_id]
                assert not isinstance(ref_axis, BatchAxis)
                ref_key = (a.size.tensor_id, a.size.axis_id)
                ref_size = inputs.get(ref_key, outputs.get(ref_key))
                assert ref_size is not None, ref_key
                assert not isinstance(ref_size, _DataDepSize), ref_key
                return a.size.get_size(
                    axis=a,
                    ref_axis=ref_axis,
                    ref_size=ref_size,
                )
            elif isinstance(a.size, DataDependentSize):
                if (t_descr.id, a.id) in ns:
                    logger.warning(
                        "Ignoring unexpected increment factor (n) for data dependent"
                        + " size axis '{}' of tensor '{}'.",
                        a.id,
                        t_descr.id,
                    )
                return _DataDepSize(a.size.min, a.size.max)
            else:
                assert_never(a.size)

        # first resolve all but the `SizeReference` input sizes
        for t_descr in self.inputs:
            for a in t_descr.axes:
                if not isinstance(a.size, SizeReference):
                    s = get_axis_size(a)
                    assert not isinstance(s, _DataDepSize)
                    inputs[t_descr.id, a.id] = s

        # resolve all other input axis sizes
        for t_descr in self.inputs:
            for a in t_descr.axes:
                if isinstance(a.size, SizeReference):
                    s = get_axis_size(a)
                    assert not isinstance(s, _DataDepSize)
                    inputs[t_descr.id, a.id] = s

        # resolve all output axis sizes
        for t_descr in self.outputs:
            for a in t_descr.axes:
                assert not isinstance(a.size, ParameterizedSize)
                s = get_axis_size(a)
                outputs[t_descr.id, a.id] = s

        return _AxisSizes(inputs=inputs, outputs=outputs)
```
Determine input and output block shape for scale factors **ns** of parameterized input sizes.

Arguments:
- ns: Scale factor `n` for each axis (keyed by `(tensor_id, axis_id)`) that is parameterized as `size = min + n * step`.
- batch_size: The desired size of the batch dimension. If given, **batch_size** overwrites any batch size present in **max_input_shape**. Default 1.
- max_input_shape: Limits the derived block shapes. Each axis for which the input size, parameterized by `n`, is larger than **max_input_shape** is set to the minimal value `n_min` for which this is still true. Use this for small input samples or large values of **ns**, or simply whenever you know the full input shape.

Returns:
Resolved axis sizes for model inputs and outputs.
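Putting the sizing methods together, a hedged end-to-end sketch (again assuming a loaded description `descr` with a parameterized spatial input tensor `raw`):

```python
# 1. derive n per parameterized axis from the desired input sizes
input_sizes = {"raw": {"batch": 1, "y": 512, "x": 512}}
ns = descr.get_ns(input_sizes)

# 2. resolve every axis size; max_input_shape caps the derived block shapes
axis_sizes = descr.get_axis_sizes(
    ns,
    batch_size=descr.get_batch_size(input_sizes),
    max_input_shape={("raw", "y"): 512, ("raw", "x"): 512},
)
print(axis_sizes.inputs, axis_sizes.outputs)
```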
```python
    @classmethod
    def convert_from_old_format_wo_validation(cls, data: Dict[str, Any]) -> None:
        """Convert metadata following an older format version to this class's format
        without validating the result.
        """
        if (
            data.get("type") == "model"
            and isinstance(fv := data.get("format_version"), str)
            and fv.count(".") == 2
        ):
            fv_parts = fv.split(".")
            if any(not p.isdigit() for p in fv_parts):
                return

            fv_tuple = tuple(map(int, fv_parts))

            assert cls.implemented_format_version_tuple[0:2] == (0, 5)
            if fv_tuple[:2] in ((0, 3), (0, 4)):
                m04 = _ModelDescr_v0_4.load(data)
                if isinstance(m04, InvalidDescr):
                    try:
                        updated = _model_conv.convert_as_dict(
                            m04  # pyright: ignore[reportArgumentType]
                        )
                    except Exception as e:
                        logger.error(
                            "Failed to convert from invalid model 0.4 description."
                            + f"\nerror: {e}"
                            + "\nProceeding with model 0.5 validation without conversion."
                        )
                        updated = None
                else:
                    updated = _model_conv.convert_as_dict(m04)

                if updated is not None:
                    data.clear()
                    data.update(updated)

            elif fv_tuple[:2] == (0, 5):
                # bump patch version
                data["format_version"] = cls.implemented_format_version
```
Convert metadata following an older format version to this class's format without validating the result.
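A minimal example of the in-place upgrade path: 0.5.x data only gets its patch version bumped, while 0.3/0.4 data is first loaded as a 0.4 description and then converted.

```python
from bioimageio.spec.model.v0_5 import ModelDescr

data = {"type": "model", "format_version": "0.5.0"}
ModelDescr.convert_from_old_format_wo_validation(data)  # mutates `data` in place
assert data["format_version"] == ModelDescr.implemented_format_version  # "0.5.6"
```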
Union of any released model description