bioimageio.spec.model
Implementations of all released minor versions are available in submodules:

- model v0_4: `bioimageio.spec.model.v0_4.ModelDescr`
- model v0_5: `bioimageio.spec.model.v0_5.ModelDescr`
```python
# autogen: start
"""
implementations of all released minor versions are available in submodules:
- model v0_4: `bioimageio.spec.model.v0_4.ModelDescr`
- model v0_5: `bioimageio.spec.model.v0_5.ModelDescr`
"""

from typing import Union

from pydantic import Discriminator, Field
from typing_extensions import Annotated

from . import v0_4, v0_5

ModelDescr = v0_5.ModelDescr
ModelDescr_v0_4 = v0_4.ModelDescr
ModelDescr_v0_5 = v0_5.ModelDescr

AnyModelDescr = Annotated[
    Union[
        Annotated[ModelDescr_v0_4, Field(title="model 0.4")],
        Annotated[ModelDescr_v0_5, Field(title="model 0.5")],
    ],
    Discriminator("format_version"),
    Field(title="model"),
]
"""Union of any released model description"""
# autogen: stop
```
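The `Discriminator("format_version")` annotation lets pydantic dispatch on the `format_version` field when validating raw data against `AnyModelDescr`. A minimal, self-contained sketch of the same pattern, using toy classes in place of the real model descriptions (which accept several `format_version` literals per minor version):

```python
from typing import Literal, Union

from pydantic import BaseModel, Discriminator, Field, TypeAdapter
from typing_extensions import Annotated


class DescrV0_4(BaseModel):
    format_version: Literal["0.4.10"]


class DescrV0_5(BaseModel):
    format_version: Literal["0.5.5"]


AnyDescr = Annotated[
    Union[
        Annotated[DescrV0_4, Field(title="model 0.4")],
        Annotated[DescrV0_5, Field(title="model 0.5")],
    ],
    Discriminator("format_version"),
]

# The discriminator value selects the matching branch:
descr = TypeAdapter(AnyDescr).validate_python({"format_version": "0.5.5"})
assert isinstance(descr, DescrV0_5)
```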
```python
class ModelDescr(GenericModelDescrBase):
    """Specification of the fields used in a bioimage.io-compliant RDF to describe AI models with pretrained weights.
    These fields are typically stored in a YAML file which we call a model resource description file (model RDF).
    """

    implemented_format_version: ClassVar[Literal["0.5.5"]] = "0.5.5"
    if TYPE_CHECKING:
        format_version: Literal["0.5.5"] = "0.5.5"
    else:
        format_version: Literal["0.5.5"]
    """Version of the bioimage.io model description specification used.
    When creating a new model always use the latest micro/patch version described here.
    The `format_version` is important for any consumer software to understand how to parse the fields.
    """

    implemented_type: ClassVar[Literal["model"]] = "model"
    if TYPE_CHECKING:
        type: Literal["model"] = "model"
    else:
        type: Literal["model"]
    """Specialized resource type 'model'"""

    id: Optional[ModelId] = None
    """bioimage.io-wide unique resource identifier
    assigned by bioimage.io; version **un**specific."""

    authors: FAIR[List[Author]] = Field(
        default_factory=cast(Callable[[], List[Author]], list)
    )
    """The authors are the creators of the model RDF and the primary points of contact."""

    documentation: FAIR[Optional[FileSource_documentation]] = None
    """URL or relative path to a markdown file with additional documentation.
    The recommended documentation file name is `README.md`. An `.md` suffix is mandatory.
    The documentation should include a '#[#] Validation' (sub)section
    with details on how to quantitatively validate the model on unseen data."""

    @field_validator("documentation", mode="after")
    @classmethod
    def _validate_documentation(
        cls, value: Optional[FileSource_documentation]
    ) -> Optional[FileSource_documentation]:
        if not get_validation_context().perform_io_checks or value is None:
            return value

        doc_reader = get_reader(value)
        doc_content = doc_reader.read().decode(encoding="utf-8")
        if not re.search("#.*[vV]alidation", doc_content):
            issue_warning(
                "No '# Validation' (sub)section found in {value}.",
                value=value,
                field="documentation",
            )

        return value

    inputs: NotEmpty[Sequence[InputTensorDescr]]
    """Describes the input tensors expected by this model."""

    @field_validator("inputs", mode="after")
    @classmethod
    def _validate_input_axes(
        cls, inputs: Sequence[InputTensorDescr]
    ) -> Sequence[InputTensorDescr]:
        input_size_refs = cls._get_axes_with_independent_size(inputs)

        for i, ipt in enumerate(inputs):
            valid_independent_refs: Dict[
                Tuple[TensorId, AxisId],
                Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
            ] = {
                **{
                    (ipt.id, a.id): (ipt, a, a.size)
                    for a in ipt.axes
                    if not isinstance(a, BatchAxis)
                    and isinstance(a.size, (int, ParameterizedSize))
                },
                **input_size_refs,
            }
            for a, ax in enumerate(ipt.axes):
                cls._validate_axis(
                    "inputs",
                    i=i,
                    tensor_id=ipt.id,
                    a=a,
                    axis=ax,
                    valid_independent_refs=valid_independent_refs,
                )
        return inputs

    @staticmethod
    def _validate_axis(
        field_name: str,
        i: int,
        tensor_id: TensorId,
        a: int,
        axis: AnyAxis,
        valid_independent_refs: Dict[
            Tuple[TensorId, AxisId],
            Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
        ],
    ):
        if isinstance(axis, BatchAxis) or isinstance(
            axis.size, (int, ParameterizedSize, DataDependentSize)
        ):
            return
        elif not isinstance(axis.size, SizeReference):
            assert_never(axis.size)

        # validate axis.size SizeReference
        ref = (axis.size.tensor_id, axis.size.axis_id)
        if ref not in valid_independent_refs:
            raise ValueError(
                "Invalid tensor axis reference at"
                + f" {field_name}[{i}].axes[{a}].size: {axis.size}."
            )
        if ref == (tensor_id, axis.id):
            raise ValueError(
                "Self-referencing not allowed for"
                + f" {field_name}[{i}].axes[{a}].size: {axis.size}"
            )
        if axis.type == "channel":
            if valid_independent_refs[ref][1].type != "channel":
                raise ValueError(
                    "A channel axis' size may only reference another fixed size"
                    + " channel axis."
                )
            if isinstance(axis.channel_names, str) and "{i}" in axis.channel_names:
                ref_size = valid_independent_refs[ref][2]
                assert isinstance(ref_size, int), (
                    "channel axis ref (another channel axis) has to specify fixed"
                    + " size"
                )
                generated_channel_names = [
                    Identifier(axis.channel_names.format(i=i))
                    for i in range(1, ref_size + 1)
                ]
                axis.channel_names = generated_channel_names

        if (ax_unit := getattr(axis, "unit", None)) != (
            ref_unit := getattr(valid_independent_refs[ref][1], "unit", None)
        ):
            raise ValueError(
                "The units of an axis and its reference axis need to match, but"
                + f" '{ax_unit}' != '{ref_unit}'."
            )
        ref_axis = valid_independent_refs[ref][1]
        if isinstance(ref_axis, BatchAxis):
            raise ValueError(
                f"Invalid reference axis '{ref_axis.id}' for {tensor_id}.{axis.id}"
                + " (a batch axis is not allowed as reference)."
            )

        if isinstance(axis, WithHalo):
            min_size = axis.size.get_size(axis, ref_axis, n=0)
            if (min_size - 2 * axis.halo) < 1:
                raise ValueError(
                    f"axis {axis.id} with minimum size {min_size} is too small for halo"
                    + f" {axis.halo}."
                )

            input_halo = axis.halo * axis.scale / ref_axis.scale
            if input_halo != int(input_halo) or input_halo % 2 == 1:
                raise ValueError(
                    f"input_halo {input_halo} (output_halo {axis.halo} *"
                    + f" output_scale {axis.scale} / input_scale {ref_axis.scale})"
                    + f" {tensor_id}.{axis.id}."
                )

    @model_validator(mode="after")
    def _validate_test_tensors(self) -> Self:
        if not get_validation_context().perform_io_checks:
            return self

        test_output_arrays = [
            None if descr.test_tensor is None else load_array(descr.test_tensor)
            for descr in self.outputs
        ]
        test_input_arrays = [
            None if descr.test_tensor is None else load_array(descr.test_tensor)
            for descr in self.inputs
        ]

        tensors = {
            descr.id: (descr, array)
            for descr, array in zip(
                chain(self.inputs, self.outputs),
                test_input_arrays + test_output_arrays,
            )
        }
        validate_tensors(tensors, tensor_origin="test_tensor")

        output_arrays = {
            descr.id: array for descr, array in zip(self.outputs, test_output_arrays)
        }
        for rep_tol in self.config.bioimageio.reproducibility_tolerance:
            if not rep_tol.absolute_tolerance:
                continue

            if rep_tol.output_ids:
                out_arrays = {
                    oid: a
                    for oid, a in output_arrays.items()
                    if oid in rep_tol.output_ids
                }
            else:
                out_arrays = output_arrays

            for out_id, array in out_arrays.items():
                if array is None:
                    continue

                if rep_tol.absolute_tolerance > (max_test_value := array.max()) * 0.01:
                    raise ValueError(
                        "config.bioimageio.reproducibility_tolerance.absolute_tolerance="
                        + f"{rep_tol.absolute_tolerance} > 0.01*{max_test_value}"
                        + f" (1% of the maximum value of the test tensor '{out_id}')"
                    )

        return self

    @model_validator(mode="after")
    def _validate_tensor_references_in_proc_kwargs(self, info: ValidationInfo) -> Self:
        ipt_refs = {t.id for t in self.inputs}
        out_refs = {t.id for t in self.outputs}
        for ipt in self.inputs:
            for p in ipt.preprocessing:
                ref = p.kwargs.get("reference_tensor")
                if ref is None:
                    continue
                if ref not in ipt_refs:
                    raise ValueError(
                        f"`reference_tensor` '{ref}' not found. Valid input tensor"
                        + f" references are: {ipt_refs}."
                    )

        for out in self.outputs:
            for p in out.postprocessing:
                ref = p.kwargs.get("reference_tensor")
                if ref is None:
                    continue

                if ref not in ipt_refs and ref not in out_refs:
                    raise ValueError(
                        f"`reference_tensor` '{ref}' not found. Valid tensor references"
                        + f" are: {ipt_refs | out_refs}."
                    )

        return self

    # TODO: use validate funcs in validate_test_tensors
    # def validate_inputs(self, input_tensors: Mapping[TensorId, NDArray[Any]]) -> Mapping[TensorId, NDArray[Any]]:

    name: Annotated[
        str,
        RestrictCharacters(string.ascii_letters + string.digits + "_+- ()"),
        MinLen(5),
        MaxLen(128),
        warn(MaxLen(64), "Name longer than 64 characters.", INFO),
    ]
    """A human-readable name of this model.
    It should be no longer than 64 characters
    and may only contain letters, numbers, underscores, minus signs, parentheses and spaces.
    We recommend choosing a name that refers to the model's task and image modality.
    """

    outputs: NotEmpty[Sequence[OutputTensorDescr]]
    """Describes the output tensors."""

    @field_validator("outputs", mode="after")
    @classmethod
    def _validate_tensor_ids(
        cls, outputs: Sequence[OutputTensorDescr], info: ValidationInfo
    ) -> Sequence[OutputTensorDescr]:
        tensor_ids = [
            t.id for t in info.data.get("inputs", []) + info.data.get("outputs", [])
        ]
        duplicate_tensor_ids: List[str] = []
        seen: Set[str] = set()
        for t in tensor_ids:
            if t in seen:
                duplicate_tensor_ids.append(t)

            seen.add(t)

        if duplicate_tensor_ids:
            raise ValueError(f"Duplicate tensor ids: {duplicate_tensor_ids}")

        return outputs

    @staticmethod
    def _get_axes_with_parameterized_size(
        io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]],
    ):
        return {
            f"{t.id}.{a.id}": (t, a, a.size)
            for t in io
            for a in t.axes
            if not isinstance(a, BatchAxis) and isinstance(a.size, ParameterizedSize)
        }

    @staticmethod
    def _get_axes_with_independent_size(
        io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]],
    ):
        return {
            (t.id, a.id): (t, a, a.size)
            for t in io
            for a in t.axes
            if not isinstance(a, BatchAxis)
            and isinstance(a.size, (int, ParameterizedSize))
        }

    @field_validator("outputs", mode="after")
    @classmethod
    def _validate_output_axes(
        cls, outputs: List[OutputTensorDescr], info: ValidationInfo
    ) -> List[OutputTensorDescr]:
        input_size_refs = cls._get_axes_with_independent_size(
            info.data.get("inputs", [])
        )
        output_size_refs = cls._get_axes_with_independent_size(outputs)

        for i, out in enumerate(outputs):
            valid_independent_refs: Dict[
                Tuple[TensorId, AxisId],
                Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
            ] = {
                **{
                    (out.id, a.id): (out, a, a.size)
                    for a in out.axes
                    if not isinstance(a, BatchAxis)
                    and isinstance(a.size, (int, ParameterizedSize))
                },
                **input_size_refs,
                **output_size_refs,
            }
            for a, ax in enumerate(out.axes):
                cls._validate_axis(
                    "outputs",
                    i,
                    out.id,
                    a,
                    ax,
                    valid_independent_refs=valid_independent_refs,
                )

        return outputs

    packaged_by: List[Author] = Field(
        default_factory=cast(Callable[[], List[Author]], list)
    )
    """The persons that have packaged and uploaded this model.
    Only required if those persons differ from the `authors`."""

    parent: Optional[LinkedModel] = None
    """The model from which this model is derived, e.g. by fine-tuning the weights."""

    @model_validator(mode="after")
    def _validate_parent_is_not_self(self) -> Self:
        if self.parent is not None and self.parent.id == self.id:
            raise ValueError("A model description may not reference itself as parent.")

        return self

    run_mode: Annotated[
        Optional[RunMode],
        warn(None, "Run mode '{value}' has limited support across consumer software."),
    ] = None
    """Custom run mode for this model: for more complex prediction procedures like test time
    data augmentation that currently cannot be expressed in the specification.
    No standard run modes are defined yet."""

    timestamp: Datetime = Field(default_factory=Datetime.now)
    """Timestamp in [ISO 8601](https://en.wikipedia.org/wiki/ISO_8601) format
    with a few restrictions listed [here](https://docs.python.org/3/library/datetime.html#datetime.datetime.fromisoformat).
    (In Python a datetime object is valid, too)."""

    training_data: Annotated[
        Union[None, LinkedDataset, DatasetDescr, DatasetDescr02],
        Field(union_mode="left_to_right"),
    ] = None
    """The dataset used to train this model"""

    weights: Annotated[WeightsDescr, WrapSerializer(package_weights)]
    """The weights for this model.
    Weights can be given for different formats, but should otherwise be equivalent.
    The available weight formats determine which consumers can use this model."""

    config: Config = Field(default_factory=Config.model_construct)

    @model_validator(mode="after")
    def _add_default_cover(self) -> Self:
        if not get_validation_context().perform_io_checks or self.covers:
            return self

        try:
            generated_covers = generate_covers(
                [
                    (t, load_array(t.test_tensor))
                    for t in self.inputs
                    if t.test_tensor is not None
                ],
                [
                    (t, load_array(t.test_tensor))
                    for t in self.outputs
                    if t.test_tensor is not None
                ],
            )
        except Exception as e:
            issue_warning(
                "Failed to generate cover image(s): {e}",
                value=self.covers,
                msg_context=dict(e=e),
                field="covers",
            )
        else:
            self.covers.extend(generated_covers)

        return self

    def get_input_test_arrays(self) -> List[NDArray[Any]]:
        return self._get_test_arrays(self.inputs)

    def get_output_test_arrays(self) -> List[NDArray[Any]]:
        return self._get_test_arrays(self.outputs)

    @staticmethod
    def _get_test_arrays(
        io_descr: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]],
    ):
        ts: List[FileDescr] = []
        for d in io_descr:
            if d.test_tensor is None:
                raise ValueError(
                    f"Failed to get test arrays: description of '{d.id}' is missing a `test_tensor`."
                )
            ts.append(d.test_tensor)

        data = [load_array(t) for t in ts]
        assert all(isinstance(d, np.ndarray) for d in data)
        return data

    @staticmethod
    def get_batch_size(tensor_sizes: Mapping[TensorId, Mapping[AxisId, int]]) -> int:
        batch_size = 1
        tensor_with_batchsize: Optional[TensorId] = None
        for tid in tensor_sizes:
            for aid, s in tensor_sizes[tid].items():
                if aid != BATCH_AXIS_ID or s == 1 or s == batch_size:
                    continue

                if batch_size != 1:
                    assert tensor_with_batchsize is not None
                    raise ValueError(
                        f"batch size mismatch for tensors '{tensor_with_batchsize}' ({batch_size}) and '{tid}' ({s})"
                    )

                batch_size = s
                tensor_with_batchsize = tid

        return batch_size

    def get_output_tensor_sizes(
        self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]
    ) -> Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]:
        """Returns the tensor output sizes for given **input_sizes**.
        Only if **input_sizes** has a valid input shape, the tensor output size is exact.
        Otherwise it might be larger than the actual (valid) output"""
        batch_size = self.get_batch_size(input_sizes)
        ns = self.get_ns(input_sizes)

        tensor_sizes = self.get_tensor_sizes(ns, batch_size=batch_size)
        return tensor_sizes.outputs

    def get_ns(self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]):
        """get parameter `n` for each parameterized axis
        such that the valid input size is >= the given input size"""
        ret: Dict[Tuple[TensorId, AxisId], ParameterizedSize_N] = {}
        axes = {t.id: {a.id: a for a in t.axes} for t in self.inputs}
        for tid in input_sizes:
            for aid, s in input_sizes[tid].items():
                size_descr = axes[tid][aid].size
                if isinstance(size_descr, ParameterizedSize):
                    ret[(tid, aid)] = size_descr.get_n(s)
                elif size_descr is None or isinstance(size_descr, (int, SizeReference)):
                    pass
                else:
                    assert_never(size_descr)

        return ret

    def get_tensor_sizes(
        self, ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], batch_size: int
    ) -> _TensorSizes:
        axis_sizes = self.get_axis_sizes(ns, batch_size=batch_size)
        return _TensorSizes(
            {
                t: {
                    aa: axis_sizes.inputs[(tt, aa)]
                    for tt, aa in axis_sizes.inputs
                    if tt == t
                }
                for t in {tt for tt, _ in axis_sizes.inputs}
            },
            {
                t: {
                    aa: axis_sizes.outputs[(tt, aa)]
                    for tt, aa in axis_sizes.outputs
                    if tt == t
                }
                for t in {tt for tt, _ in axis_sizes.outputs}
            },
        )

    def get_axis_sizes(
        self,
        ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N],
        batch_size: Optional[int] = None,
        *,
        max_input_shape: Optional[Mapping[Tuple[TensorId, AxisId], int]] = None,
    ) -> _AxisSizes:
        """Determine input and output block shape for scale factors **ns**
        of parameterized input sizes.

        Args:
            ns: Scale factor `n` for each axis (keyed by (tensor_id, axis_id))
                that is parameterized as `size = min + n * step`.
            batch_size: The desired size of the batch dimension.
                If given **batch_size** overwrites any batch size present in
                **max_input_shape**. Default 1.
            max_input_shape: Limits the derived block shapes.
                Each axis for which the input size, parameterized by `n`, is larger
                than **max_input_shape** is set to the minimal value `n_min` for which
                this is still true.
                Use this for small input samples or large values of **ns**.
                Or simply whenever you know the full input shape.

        Returns:
            Resolved axis sizes for model inputs and outputs.
        """
        max_input_shape = max_input_shape or {}
        if batch_size is None:
            for (_t_id, a_id), s in max_input_shape.items():
                if a_id == BATCH_AXIS_ID:
                    batch_size = s
                    break
            else:
                batch_size = 1

        all_axes = {
            t.id: {a.id: a for a in t.axes} for t in chain(self.inputs, self.outputs)
        }

        inputs: Dict[Tuple[TensorId, AxisId], int] = {}
        outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]] = {}

        def get_axis_size(a: Union[InputAxis, OutputAxis]):
            if isinstance(a, BatchAxis):
                if (t_descr.id, a.id) in ns:
                    logger.warning(
                        "Ignoring unexpected size increment factor (n) for batch axis"
                        + " of tensor '{}'.",
                        t_descr.id,
                    )
                return batch_size
            elif isinstance(a.size, int):
                if (t_descr.id, a.id) in ns:
                    logger.warning(
                        "Ignoring unexpected size increment factor (n) for fixed size"
                        + " axis '{}' of tensor '{}'.",
                        a.id,
                        t_descr.id,
                    )
                return a.size
            elif isinstance(a.size, ParameterizedSize):
                if (t_descr.id, a.id) not in ns:
                    raise ValueError(
                        "Size increment factor (n) missing for parametrized axis"
                        + f" '{a.id}' of tensor '{t_descr.id}'."
                    )
                n = ns[(t_descr.id, a.id)]
                s_max = max_input_shape.get((t_descr.id, a.id))
                if s_max is not None:
                    n = min(n, a.size.get_n(s_max))

                return a.size.get_size(n)

            elif isinstance(a.size, SizeReference):
                if (t_descr.id, a.id) in ns:
                    logger.warning(
                        "Ignoring unexpected size increment factor (n) for axis '{}'"
                        + " of tensor '{}' with size reference.",
                        a.id,
                        t_descr.id,
                    )
                assert not isinstance(a, BatchAxis)
                ref_axis = all_axes[a.size.tensor_id][a.size.axis_id]
                assert not isinstance(ref_axis, BatchAxis)
                ref_key = (a.size.tensor_id, a.size.axis_id)
                ref_size = inputs.get(ref_key, outputs.get(ref_key))
                assert ref_size is not None, ref_key
                assert not isinstance(ref_size, _DataDepSize), ref_key
                return a.size.get_size(
                    axis=a,
                    ref_axis=ref_axis,
                    ref_size=ref_size,
                )
            elif isinstance(a.size, DataDependentSize):
                if (t_descr.id, a.id) in ns:
                    logger.warning(
                        "Ignoring unexpected increment factor (n) for data dependent"
                        + " size axis '{}' of tensor '{}'.",
                        a.id,
                        t_descr.id,
                    )
                return _DataDepSize(a.size.min, a.size.max)
            else:
                assert_never(a.size)

        # first resolve all but the `SizeReference` input sizes
        for t_descr in self.inputs:
            for a in t_descr.axes:
                if not isinstance(a.size, SizeReference):
                    s = get_axis_size(a)
                    assert not isinstance(s, _DataDepSize)
                    inputs[t_descr.id, a.id] = s

        # resolve all other input axis sizes
        for t_descr in self.inputs:
            for a in t_descr.axes:
                if isinstance(a.size, SizeReference):
                    s = get_axis_size(a)
                    assert not isinstance(s, _DataDepSize)
                    inputs[t_descr.id, a.id] = s

        # resolve all output axis sizes
        for t_descr in self.outputs:
            for a in t_descr.axes:
                assert not isinstance(a.size, ParameterizedSize)
                s = get_axis_size(a)
                outputs[t_descr.id, a.id] = s

        return _AxisSizes(inputs=inputs, outputs=outputs)

    @model_validator(mode="before")
    @classmethod
    def _convert(cls, data: Dict[str, Any]) -> Dict[str, Any]:
        cls.convert_from_old_format_wo_validation(data)
        return data

    @classmethod
    def convert_from_old_format_wo_validation(cls, data: Dict[str, Any]) -> None:
        """Convert metadata following an older format version to this class's format
        without validating the result.
        """
        if (
            data.get("type") == "model"
            and isinstance(fv := data.get("format_version"), str)
            and fv.count(".") == 2
        ):
            fv_parts = fv.split(".")
            if any(not p.isdigit() for p in fv_parts):
                return

            fv_tuple = tuple(map(int, fv_parts))

            assert cls.implemented_format_version_tuple[0:2] == (0, 5)
            if fv_tuple[:2] in ((0, 3), (0, 4)):
                m04 = _ModelDescr_v0_4.load(data)
                if isinstance(m04, InvalidDescr):
                    try:
                        updated = _model_conv.convert_as_dict(
                            m04  # pyright: ignore[reportArgumentType]
                        )
                    except Exception as e:
                        logger.error(
                            "Failed to convert from invalid model 0.4 description."
                            + f"\nerror: {e}"
                            + "\nProceeding with model 0.5 validation without conversion."
                        )
                        updated = None
                else:
                    updated = _model_conv.convert_as_dict(m04)

                if updated is not None:
                    data.clear()
                    data.update(updated)

            elif fv_tuple[:2] == (0, 5):
                # bump patch version
                data["format_version"] = cls.implemented_format_version
```
Specification of the fields used in a bioimage.io-compliant RDF to describe AI models with pretrained weights. These fields are typically stored in a YAML file which we call a model resource description file (model RDF).
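For orientation, a hedged usage sketch: it assumes `load_description` is exposed at the package root (as in recent `bioimageio.spec` releases) and that a model RDF exists at the given (hypothetical) path.

```python
from bioimageio.spec import load_description
from bioimageio.spec.model.v0_5 import ModelDescr

descr = load_description("path/to/rdf.yaml")  # hypothetical local model RDF
if isinstance(descr, ModelDescr):
    print(descr.name, descr.format_version)
```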
bioimage.io-wide unique resource identifier assigned by bioimage.io; version unspecific.
URL or relative path to a markdown file with additional documentation.
The recommended documentation file name is `README.md`. An `.md` suffix is mandatory.
The documentation should include a '#[#] Validation' (sub)section
with details on how to quantitatively validate the model on unseen data.
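The `_validate_documentation` validator in the source above only warns when no matching heading is found. A quick local check with the same regex (the README string is illustrative):

```python
import re

readme = "# My Model\n\n## Validation\nEvaluate IoU on held-out images...\n"
if not re.search("#.*[vV]alidation", readme):
    print("warning: no '# Validation' (sub)section found")
```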
Describes the input tensors expected by this model.
A human-readable name of this model. It should be no longer than 64 characters and may only contain letters, numbers, underscores, minus signs, parentheses and spaces. We recommend choosing a name that refers to the model's task and image modality.
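The character and length constraints can be checked directly. A sketch mirroring the `RestrictCharacters`/`MinLen`/`MaxLen` annotations shown in the source above (the example name is illustrative):

```python
import string

# the character set from RestrictCharacters(...) in the `name` annotation
allowed = set(string.ascii_letters + string.digits + "_+- ()")
name = "U-Net (2D) nucleus segmentation"
assert 5 <= len(name) <= 128  # MinLen(5), MaxLen(128)
assert set(name) <= allowed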
Describes the output tensors.
The persons that have packaged and uploaded this model. Only required if those persons differ from the `authors`.
The model from which this model is derived, e.g. by fine-tuning the weights.
Custom run mode for this model: for more complex prediction procedures like test time data augmentation that currently cannot be expressed in the specification. No standard run modes are defined yet.
The dataset used to train this model
The weights for this model. Weights can be given for different formats, but should otherwise be equivalent. The available weight formats determine which consumers can use this model.
```python
@staticmethod
def get_batch_size(tensor_sizes: Mapping[TensorId, Mapping[AxisId, int]]) -> int:
    ...  # implementation shown in the ModelDescr source above
```
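A hedged usage sketch; it assumes `BATCH_AXIS_ID` is the axis id `"batch"`, as in `bioimageio.spec.model.v0_5`:

```python
from bioimageio.spec.model.v0_5 import AxisId, ModelDescr, TensorId

sizes = {
    TensorId("raw"): {AxisId("batch"): 4, AxisId("x"): 512},
    TensorId("mask"): {AxisId("batch"): 4, AxisId("y"): 512},
}
# both tensors agree on batch size 4; a mismatch would raise ValueError
assert ModelDescr.get_batch_size(sizes) == 4
```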
```python
def get_output_tensor_sizes(
    self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]
) -> Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]:
    ...  # implementation shown in the ModelDescr source above
```
Returns the tensor output sizes for the given input_sizes. The output sizes are exact only if input_sizes is a valid input shape; otherwise they may be larger than the actual (valid) output.
```python
def get_ns(self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]):
    ...  # implementation shown in the ModelDescr source above
```
Get the parameter `n` for each parameterized axis such that the valid input size is >= the given input size.
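A worked example of the rounding this implies, assuming `ParameterizedSize.get_n` returns the smallest `n` whose parameterized size is at least the given size:

```python
from math import ceil

# size = min + n * step with min=16, step=8; a requested size of 50
# needs the smallest n with 16 + 8 * n >= 50:
n = ceil((50 - 16) / 8)
assert n == 5 and 16 + n * 8 == 56  # 56 is the smallest valid size >= 50
```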
```python
def get_tensor_sizes(
    self, ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], batch_size: int
) -> _TensorSizes:
    ...  # implementation shown in the ModelDescr source above
```
```python
def get_axis_sizes(
    self,
    ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N],
    batch_size: Optional[int] = None,
    *,
    max_input_shape: Optional[Mapping[Tuple[TensorId, AxisId], int]] = None,
) -> _AxisSizes:
    ...  # implementation shown in the ModelDescr source above
```
Determine input and output block shape for scale factors `ns` of parameterized input sizes.

Arguments:
- ns: Scale factor `n` for each axis (keyed by `(tensor_id, axis_id)`) that is parameterized as `size = min + n * step`.
- batch_size: The desired size of the batch dimension. If given, `batch_size` overwrites any batch size present in `max_input_shape`. Default 1.
- max_input_shape: Limits the derived block shapes. Each axis for which the input size, parameterized by `n`, is larger than `max_input_shape` is set to the minimal value `n_min` for which this is still true. Use this for small input samples or large values of `ns`, or simply whenever you know the full input shape.
Returns:
Resolved axis sizes for model inputs and outputs.
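How the `max_input_shape` cap interacts with `ns` (the source above applies `n = min(n, a.size.get_n(s_max))`), as a self-contained sketch with illustrative numbers; it assumes `get_n` picks the smallest `n` whose parameterized size is >= the given size:

```python
from math import ceil

min_, step = 32, 16
n_requested = 10                     # from `ns`
s_max = 100                          # from `max_input_shape`
n_cap = ceil((s_max - min_) / step)  # smallest n with min + n*step >= s_max
n = min(n_requested, n_cap)          # the capping applied by get_axis_sizes
assert n == 5 and min_ + n * step == 112
```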
```python
@classmethod
def convert_from_old_format_wo_validation(cls, data: Dict[str, Any]) -> None:
    ...  # implementation shown in the ModelDescr source above
```
Convert metadata following an older format version to this class's format without validating the result.
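A sketch of the in-place conversion for a 0.5.x dict, where only the patch version is bumped (the dict below is illustrative, not a complete model RDF):

```python
from bioimageio.spec.model.v0_5 import ModelDescr

data = {"type": "model", "format_version": "0.5.0", "name": "example model"}
ModelDescr.convert_from_old_format_wo_validation(data)
assert data["format_version"] == "0.5.5"  # bumped in place, no validation
```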
Configuration for the model; should be a dictionary conforming to [`ConfigDict`][pydantic.config.ConfigDict].
```python
def init_private_attributes(self: BaseModel, context: Any, /) -> None:
    """This function is meant to behave like a BaseModel method to initialise private attributes.

    It takes context as an argument since that's what pydantic-core passes when calling it.

    Args:
        self: The BaseModel instance.
        context: The context.
    """
    if getattr(self, '__pydantic_private__', None) is None:
        pydantic_private = {}
        for name, private_attr in self.__private_attributes__.items():
            default = private_attr.get_default()
            if default is not PydanticUndefined:
                pydantic_private[name] = default
        object_setattr(self, '__pydantic_private__', pydantic_private)
```
This function is meant to behave like a BaseModel method to initialise private attributes.
It takes context as an argument since that's what pydantic-core passes when calling it.
Arguments:
- self: The BaseModel instance.
- context: The context.
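What this inherited hook achieves in practice, as a minimal pydantic v2 example:

```python
from pydantic import BaseModel, PrivateAttr


class WithCache(BaseModel):
    x: int
    _cache: dict = PrivateAttr(default_factory=dict)


m = WithCache(x=1)
assert m._cache == {}  # private attribute initialized from its default
```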
Union of any released model description.