bioimageio.spec.model
Implementations of all released minor versions are available in submodules:
- model v0_4: `bioimageio.spec.model.v0_4.ModelDescr`
- model v0_5: `bioimageio.spec.model.v0_5.ModelDescr`
```python
# autogen: start
"""
Implementations of all released minor versions are available in submodules:
- model v0_4: `bioimageio.spec.model.v0_4.ModelDescr`
- model v0_5: `bioimageio.spec.model.v0_5.ModelDescr`
"""

from typing import Union

from pydantic import Discriminator, Field
from typing_extensions import Annotated

from . import v0_4, v0_5

ModelDescr = v0_5.ModelDescr
ModelDescr_v0_4 = v0_4.ModelDescr
ModelDescr_v0_5 = v0_5.ModelDescr

AnyModelDescr = Annotated[
    Union[
        Annotated[ModelDescr_v0_4, Field(title="model 0.4")],
        Annotated[ModelDescr_v0_5, Field(title="model 0.5")],
    ],
    Discriminator("format_version"),
    Field(title="model"),
]
"""Union of any released model description"""
# autogen: stop
```
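Because `AnyModelDescr` is an annotated, discriminated union, it can be validated with pydantic's standard `TypeAdapter`. A minimal sketch (the `rdf.yaml` path and the use of PyYAML are illustrative assumptions, not part of the module above):

```python
import yaml  # assumption: PyYAML is available
from pydantic import TypeAdapter

from bioimageio.spec.model import AnyModelDescr

adapter = TypeAdapter(AnyModelDescr)

with open("rdf.yaml") as f:  # hypothetical path to a model RDF file
    raw_rdf = yaml.safe_load(f)

# pydantic dispatches on the `format_version` discriminator declared above,
# yielding a ModelDescr_v0_4 or ModelDescr_v0_5 instance
descr = adapter.validate_python(raw_rdf)
print(type(descr).__name__, descr.format_version)
```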
```python
class ModelDescr(GenericModelDescrBase):
    """Specification of the fields used in a bioimage.io-compliant RDF to describe AI models with pretrained weights.
    These fields are typically stored in a YAML file which we call a model resource description file (model RDF).
    """

    implemented_format_version: ClassVar[Literal["0.5.4"]] = "0.5.4"
    if TYPE_CHECKING:
        format_version: Literal["0.5.4"] = "0.5.4"
    else:
        format_version: Literal["0.5.4"]
        """Version of the bioimage.io model description specification used.
        When creating a new model always use the latest micro/patch version described here.
        The `format_version` is important for any consumer software to understand how to parse the fields.
        """

    implemented_type: ClassVar[Literal["model"]] = "model"
    if TYPE_CHECKING:
        type: Literal["model"] = "model"
    else:
        type: Literal["model"]
        """Specialized resource type 'model'"""

    id: Optional[ModelId] = None
    """bioimage.io-wide unique resource identifier
    assigned by bioimage.io; version **un**specific."""

    authors: NotEmpty[List[Author]]
    """The authors are the creators of the model RDF and the primary points of contact."""

    documentation: FileSource_documentation
    """URL or relative path to a markdown file with additional documentation.
    The recommended documentation file name is `README.md`. An `.md` suffix is mandatory.
    The documentation should include a '#[#] Validation' (sub)section
    with details on how to quantitatively validate the model on unseen data."""

    @field_validator("documentation", mode="after")
    @classmethod
    def _validate_documentation(
        cls, value: FileSource_documentation
    ) -> FileSource_documentation:
        if not get_validation_context().perform_io_checks:
            return value

        doc_reader = get_reader(value)
        doc_content = doc_reader.read().decode(encoding="utf-8")
        if not re.search("#.*[vV]alidation", doc_content):
            issue_warning(
                "No '# Validation' (sub)section found in {value}.",
                value=value,
                field="documentation",
            )

        return value

    inputs: NotEmpty[Sequence[InputTensorDescr]]
    """Describes the input tensors expected by this model."""

    @field_validator("inputs", mode="after")
    @classmethod
    def _validate_input_axes(
        cls, inputs: Sequence[InputTensorDescr]
    ) -> Sequence[InputTensorDescr]:
        input_size_refs = cls._get_axes_with_independent_size(inputs)

        for i, ipt in enumerate(inputs):
            valid_independent_refs: Dict[
                Tuple[TensorId, AxisId],
                Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
            ] = {
                **{
                    (ipt.id, a.id): (ipt, a, a.size)
                    for a in ipt.axes
                    if not isinstance(a, BatchAxis)
                    and isinstance(a.size, (int, ParameterizedSize))
                },
                **input_size_refs,
            }
            for a, ax in enumerate(ipt.axes):
                cls._validate_axis(
                    "inputs",
                    i=i,
                    tensor_id=ipt.id,
                    a=a,
                    axis=ax,
                    valid_independent_refs=valid_independent_refs,
                )
        return inputs

    @staticmethod
    def _validate_axis(
        field_name: str,
        i: int,
        tensor_id: TensorId,
        a: int,
        axis: AnyAxis,
        valid_independent_refs: Dict[
            Tuple[TensorId, AxisId],
            Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
        ],
    ):
        if isinstance(axis, BatchAxis) or isinstance(
            axis.size, (int, ParameterizedSize, DataDependentSize)
        ):
            return
        elif not isinstance(axis.size, SizeReference):
            assert_never(axis.size)

        # validate axis.size SizeReference
        ref = (axis.size.tensor_id, axis.size.axis_id)
        if ref not in valid_independent_refs:
            raise ValueError(
                "Invalid tensor axis reference at"
                + f" {field_name}[{i}].axes[{a}].size: {axis.size}."
            )
        if ref == (tensor_id, axis.id):
            raise ValueError(
                "Self-referencing not allowed for"
                + f" {field_name}[{i}].axes[{a}].size: {axis.size}"
            )
        if axis.type == "channel":
            if valid_independent_refs[ref][1].type != "channel":
                raise ValueError(
                    "A channel axis' size may only reference another fixed size"
                    + " channel axis."
                )
            if isinstance(axis.channel_names, str) and "{i}" in axis.channel_names:
                ref_size = valid_independent_refs[ref][2]
                assert isinstance(ref_size, int), (
                    "channel axis ref (another channel axis) has to specify fixed"
                    + " size"
                )
                generated_channel_names = [
                    Identifier(axis.channel_names.format(i=i))
                    for i in range(1, ref_size + 1)
                ]
                axis.channel_names = generated_channel_names

        if (ax_unit := getattr(axis, "unit", None)) != (
            ref_unit := getattr(valid_independent_refs[ref][1], "unit", None)
        ):
            raise ValueError(
                "The units of an axis and its reference axis need to match, but"
                + f" '{ax_unit}' != '{ref_unit}'."
            )
        ref_axis = valid_independent_refs[ref][1]
        if isinstance(ref_axis, BatchAxis):
            raise ValueError(
                f"Invalid reference axis '{ref_axis.id}' for {tensor_id}.{axis.id}"
                + " (a batch axis is not allowed as reference)."
            )

        if isinstance(axis, WithHalo):
            min_size = axis.size.get_size(axis, ref_axis, n=0)
            if (min_size - 2 * axis.halo) < 1:
                raise ValueError(
                    f"axis {axis.id} with minimum size {min_size} is too small for halo"
                    + f" {axis.halo}."
                )

            input_halo = axis.halo * axis.scale / ref_axis.scale
            if input_halo != int(input_halo) or input_halo % 2 == 1:
                raise ValueError(
                    f"input_halo {input_halo} (output_halo {axis.halo} *"
                    + f" output_scale {axis.scale} / input_scale {ref_axis.scale})"
                    + f" {tensor_id}.{axis.id}."
                )

    @model_validator(mode="after")
    def _validate_test_tensors(self) -> Self:
        if not get_validation_context().perform_io_checks:
            return self

        test_output_arrays = [load_array(descr.test_tensor) for descr in self.outputs]
        test_input_arrays = [load_array(descr.test_tensor) for descr in self.inputs]

        tensors = {
            descr.id: (descr, array)
            for descr, array in zip(
                chain(self.inputs, self.outputs),
                test_input_arrays + test_output_arrays,
            )
        }
        validate_tensors(tensors, tensor_origin="test_tensor")

        output_arrays = {
            descr.id: array for descr, array in zip(self.outputs, test_output_arrays)
        }
        for rep_tol in self.config.bioimageio.reproducibility_tolerance:
            if not rep_tol.absolute_tolerance:
                continue

            if rep_tol.output_ids:
                out_arrays = {
                    oid: a
                    for oid, a in output_arrays.items()
                    if oid in rep_tol.output_ids
                }
            else:
                out_arrays = output_arrays

            for out_id, array in out_arrays.items():
                if rep_tol.absolute_tolerance > (max_test_value := array.max()) * 0.01:
                    raise ValueError(
                        "config.bioimageio.reproducibility_tolerance.absolute_tolerance="
                        + f"{rep_tol.absolute_tolerance} > 0.01*{max_test_value}"
                        + f" (1% of the maximum value of the test tensor '{out_id}')"
                    )

        return self

    @model_validator(mode="after")
    def _validate_tensor_references_in_proc_kwargs(self, info: ValidationInfo) -> Self:
        ipt_refs = {t.id for t in self.inputs}
        out_refs = {t.id for t in self.outputs}
        for ipt in self.inputs:
            for p in ipt.preprocessing:
                ref = p.kwargs.get("reference_tensor")
                if ref is None:
                    continue
                if ref not in ipt_refs:
                    raise ValueError(
                        f"`reference_tensor` '{ref}' not found. Valid input tensor"
                        + f" references are: {ipt_refs}."
                    )

        for out in self.outputs:
            for p in out.postprocessing:
                ref = p.kwargs.get("reference_tensor")
                if ref is None:
                    continue

                if ref not in ipt_refs and ref not in out_refs:
                    raise ValueError(
                        f"`reference_tensor` '{ref}' not found. Valid tensor references"
                        + f" are: {ipt_refs | out_refs}."
                    )

        return self

    # TODO: use validate funcs in validate_test_tensors
    # def validate_inputs(self, input_tensors: Mapping[TensorId, NDArray[Any]]) -> Mapping[TensorId, NDArray[Any]]:

    name: Annotated[
        Annotated[
            str, RestrictCharacters(string.ascii_letters + string.digits + "_+- ()")
        ],
        MinLen(5),
        MaxLen(128),
        warn(MaxLen(64), "Name longer than 64 characters.", INFO),
    ]
    """A human-readable name of this model.
    It should be no longer than 64 characters
    and may only contain letters, numbers, underscores, minus signs, parentheses and spaces.
    We recommend choosing a name that refers to the model's task and image modality.
    """

    outputs: NotEmpty[Sequence[OutputTensorDescr]]
    """Describes the output tensors."""

    @field_validator("outputs", mode="after")
    @classmethod
    def _validate_tensor_ids(
        cls, outputs: Sequence[OutputTensorDescr], info: ValidationInfo
    ) -> Sequence[OutputTensorDescr]:
        tensor_ids = [
            t.id for t in info.data.get("inputs", []) + info.data.get("outputs", [])
        ]
        duplicate_tensor_ids: List[str] = []
        seen: Set[str] = set()
        for t in tensor_ids:
            if t in seen:
                duplicate_tensor_ids.append(t)

            seen.add(t)

        if duplicate_tensor_ids:
            raise ValueError(f"Duplicate tensor ids: {duplicate_tensor_ids}")

        return outputs

    @staticmethod
    def _get_axes_with_parameterized_size(
        io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]],
    ):
        return {
            f"{t.id}.{a.id}": (t, a, a.size)
            for t in io
            for a in t.axes
            if not isinstance(a, BatchAxis) and isinstance(a.size, ParameterizedSize)
        }

    @staticmethod
    def _get_axes_with_independent_size(
        io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]],
    ):
        return {
            (t.id, a.id): (t, a, a.size)
            for t in io
            for a in t.axes
            if not isinstance(a, BatchAxis)
            and isinstance(a.size, (int, ParameterizedSize))
        }

    @field_validator("outputs", mode="after")
    @classmethod
    def _validate_output_axes(
        cls, outputs: List[OutputTensorDescr], info: ValidationInfo
    ) -> List[OutputTensorDescr]:
        input_size_refs = cls._get_axes_with_independent_size(
            info.data.get("inputs", [])
        )
        output_size_refs = cls._get_axes_with_independent_size(outputs)

        for i, out in enumerate(outputs):
            valid_independent_refs: Dict[
                Tuple[TensorId, AxisId],
                Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
            ] = {
                **{
                    (out.id, a.id): (out, a, a.size)
                    for a in out.axes
                    if not isinstance(a, BatchAxis)
                    and isinstance(a.size, (int, ParameterizedSize))
                },
                **input_size_refs,
                **output_size_refs,
            }
            for a, ax in enumerate(out.axes):
                cls._validate_axis(
                    "outputs",
                    i,
                    out.id,
                    a,
                    ax,
                    valid_independent_refs=valid_independent_refs,
                )

        return outputs

    packaged_by: List[Author] = Field(
        default_factory=cast(Callable[[], List[Author]], list)
    )
    """The persons that have packaged and uploaded this model.
    Only required if those persons differ from the `authors`."""

    parent: Optional[LinkedModel] = None
    """The model from which this model is derived, e.g. by fine-tuning the weights."""

    @model_validator(mode="after")
    def _validate_parent_is_not_self(self) -> Self:
        if self.parent is not None and self.parent.id == self.id:
            raise ValueError("A model description may not reference itself as parent.")

        return self

    run_mode: Annotated[
        Optional[RunMode],
        warn(None, "Run mode '{value}' has limited support across consumer software."),
    ] = None
    """Custom run mode for this model: for more complex prediction procedures like test time
    data augmentation that currently cannot be expressed in the specification.
    No standard run modes are defined yet."""

    timestamp: Datetime = Field(default_factory=Datetime.now)
    """Timestamp in [ISO 8601](https://en.wikipedia.org/wiki/ISO_8601) format
    with a few restrictions listed [here](https://docs.python.org/3/library/datetime.html#datetime.datetime.fromisoformat).
    (In Python a datetime object is valid, too)."""

    training_data: Annotated[
        Union[None, LinkedDataset, DatasetDescr, DatasetDescr02],
        Field(union_mode="left_to_right"),
    ] = None
    """The dataset used to train this model."""

    weights: Annotated[WeightsDescr, WrapSerializer(package_weights)]
    """The weights for this model.
    Weights can be given for different formats, but should otherwise be equivalent.
    The available weight formats determine which consumers can use this model."""

    config: Config = Field(default_factory=Config)

    @model_validator(mode="after")
    def _add_default_cover(self) -> Self:
        if not get_validation_context().perform_io_checks or self.covers:
            return self

        try:
            generated_covers = generate_covers(
                [(t, load_array(t.test_tensor)) for t in self.inputs],
                [(t, load_array(t.test_tensor)) for t in self.outputs],
            )
        except Exception as e:
            issue_warning(
                "Failed to generate cover image(s): {e}",
                value=self.covers,
                msg_context=dict(e=e),
                field="covers",
            )
        else:
            self.covers.extend(generated_covers)

        return self

    def get_input_test_arrays(self) -> List[NDArray[Any]]:
        data = [load_array(ipt.test_tensor) for ipt in self.inputs]
        assert all(isinstance(d, np.ndarray) for d in data)
        return data

    def get_output_test_arrays(self) -> List[NDArray[Any]]:
        data = [load_array(out.test_tensor) for out in self.outputs]
        assert all(isinstance(d, np.ndarray) for d in data)
        return data

    @staticmethod
    def get_batch_size(tensor_sizes: Mapping[TensorId, Mapping[AxisId, int]]) -> int:
        batch_size = 1
        tensor_with_batchsize: Optional[TensorId] = None
        for tid in tensor_sizes:
            for aid, s in tensor_sizes[tid].items():
                if aid != BATCH_AXIS_ID or s == 1 or s == batch_size:
                    continue

                if batch_size != 1:
                    assert tensor_with_batchsize is not None
                    raise ValueError(
                        f"batch size mismatch for tensors '{tensor_with_batchsize}' ({batch_size}) and '{tid}' ({s})"
                    )

                batch_size = s
                tensor_with_batchsize = tid

        return batch_size

    def get_output_tensor_sizes(
        self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]
    ) -> Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]:
        """Returns the tensor output sizes for given **input_sizes**.
        Only if **input_sizes** has a valid input shape, the tensor output size is exact.
        Otherwise it might be larger than the actual (valid) output."""
        batch_size = self.get_batch_size(input_sizes)
        ns = self.get_ns(input_sizes)

        tensor_sizes = self.get_tensor_sizes(ns, batch_size=batch_size)
        return tensor_sizes.outputs

    def get_ns(self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]):
        """Get parameter `n` for each parameterized axis
        such that the valid input size is >= the given input size."""
        ret: Dict[Tuple[TensorId, AxisId], ParameterizedSize_N] = {}
        axes = {t.id: {a.id: a for a in t.axes} for t in self.inputs}
        for tid in input_sizes:
            for aid, s in input_sizes[tid].items():
                size_descr = axes[tid][aid].size
                if isinstance(size_descr, ParameterizedSize):
                    ret[(tid, aid)] = size_descr.get_n(s)
                elif size_descr is None or isinstance(size_descr, (int, SizeReference)):
                    pass
                else:
                    assert_never(size_descr)

        return ret

    def get_tensor_sizes(
        self, ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], batch_size: int
    ) -> _TensorSizes:
        axis_sizes = self.get_axis_sizes(ns, batch_size=batch_size)
        return _TensorSizes(
            {
                t: {
                    aa: axis_sizes.inputs[(tt, aa)]
                    for tt, aa in axis_sizes.inputs
                    if tt == t
                }
                for t in {tt for tt, _ in axis_sizes.inputs}
            },
            {
                t: {
                    aa: axis_sizes.outputs[(tt, aa)]
                    for tt, aa in axis_sizes.outputs
                    if tt == t
                }
                for t in {tt for tt, _ in axis_sizes.outputs}
            },
        )

    def get_axis_sizes(
        self,
        ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N],
        batch_size: Optional[int] = None,
        *,
        max_input_shape: Optional[Mapping[Tuple[TensorId, AxisId], int]] = None,
    ) -> _AxisSizes:
        """Determine input and output block shape for scale factors **ns**
        of parameterized input sizes.

        Args:
            ns: Scale factor `n` for each axis (keyed by (tensor_id, axis_id))
                that is parameterized as `size = min + n * step`.
            batch_size: The desired size of the batch dimension.
                If given **batch_size** overwrites any batch size present in
                **max_input_shape**. Default 1.
            max_input_shape: Limits the derived block shapes.
                Each axis for which the input size, parameterized by `n`, is larger
                than **max_input_shape** is set to the minimal value `n_min` for which
                this is still true.
                Use this for small input samples or large values of **ns**.
                Or simply whenever you know the full input shape.

        Returns:
            Resolved axis sizes for model inputs and outputs.
        """
        max_input_shape = max_input_shape or {}
        if batch_size is None:
            for (_t_id, a_id), s in max_input_shape.items():
                if a_id == BATCH_AXIS_ID:
                    batch_size = s
                    break
            else:
                batch_size = 1

        all_axes = {
            t.id: {a.id: a for a in t.axes} for t in chain(self.inputs, self.outputs)
        }

        inputs: Dict[Tuple[TensorId, AxisId], int] = {}
        outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]] = {}

        def get_axis_size(a: Union[InputAxis, OutputAxis]):
            if isinstance(a, BatchAxis):
                if (t_descr.id, a.id) in ns:
                    logger.warning(
                        "Ignoring unexpected size increment factor (n) for batch axis"
                        + " of tensor '{}'.",
                        t_descr.id,
                    )
                return batch_size
            elif isinstance(a.size, int):
                if (t_descr.id, a.id) in ns:
                    logger.warning(
                        "Ignoring unexpected size increment factor (n) for fixed size"
                        + " axis '{}' of tensor '{}'.",
                        a.id,
                        t_descr.id,
                    )
                return a.size
            elif isinstance(a.size, ParameterizedSize):
                if (t_descr.id, a.id) not in ns:
                    raise ValueError(
                        "Size increment factor (n) missing for parametrized axis"
                        + f" '{a.id}' of tensor '{t_descr.id}'."
                    )
                n = ns[(t_descr.id, a.id)]
                s_max = max_input_shape.get((t_descr.id, a.id))
                if s_max is not None:
                    n = min(n, a.size.get_n(s_max))

                return a.size.get_size(n)

            elif isinstance(a.size, SizeReference):
                if (t_descr.id, a.id) in ns:
                    logger.warning(
                        "Ignoring unexpected size increment factor (n) for axis '{}'"
                        + " of tensor '{}' with size reference.",
                        a.id,
                        t_descr.id,
                    )
                assert not isinstance(a, BatchAxis)
                ref_axis = all_axes[a.size.tensor_id][a.size.axis_id]
                assert not isinstance(ref_axis, BatchAxis)
                ref_key = (a.size.tensor_id, a.size.axis_id)
                ref_size = inputs.get(ref_key, outputs.get(ref_key))
                assert ref_size is not None, ref_key
                assert not isinstance(ref_size, _DataDepSize), ref_key
                return a.size.get_size(
                    axis=a,
                    ref_axis=ref_axis,
                    ref_size=ref_size,
                )
            elif isinstance(a.size, DataDependentSize):
                if (t_descr.id, a.id) in ns:
                    logger.warning(
                        "Ignoring unexpected increment factor (n) for data dependent"
                        + " size axis '{}' of tensor '{}'.",
                        a.id,
                        t_descr.id,
                    )
                return _DataDepSize(a.size.min, a.size.max)
            else:
                assert_never(a.size)

        # first resolve all but the `SizeReference` input sizes
        for t_descr in self.inputs:
            for a in t_descr.axes:
                if not isinstance(a.size, SizeReference):
                    s = get_axis_size(a)
                    assert not isinstance(s, _DataDepSize)
                    inputs[t_descr.id, a.id] = s

        # resolve all other input axis sizes
        for t_descr in self.inputs:
            for a in t_descr.axes:
                if isinstance(a.size, SizeReference):
                    s = get_axis_size(a)
                    assert not isinstance(s, _DataDepSize)
                    inputs[t_descr.id, a.id] = s

        # resolve all output axis sizes
        for t_descr in self.outputs:
            for a in t_descr.axes:
                assert not isinstance(a.size, ParameterizedSize)
                s = get_axis_size(a)
                outputs[t_descr.id, a.id] = s

        return _AxisSizes(inputs=inputs, outputs=outputs)

    @model_validator(mode="before")
    @classmethod
    def _convert(cls, data: Dict[str, Any]) -> Dict[str, Any]:
        cls.convert_from_old_format_wo_validation(data)
        return data

    @classmethod
    def convert_from_old_format_wo_validation(cls, data: Dict[str, Any]) -> None:
        """Convert metadata following an older format version to this class's format
        without validating the result.
        """
        if (
            data.get("type") == "model"
            and isinstance(fv := data.get("format_version"), str)
            and fv.count(".") == 2
        ):
            fv_parts = fv.split(".")
            if any(not p.isdigit() for p in fv_parts):
                return

            fv_tuple = tuple(map(int, fv_parts))

            assert cls.implemented_format_version_tuple[0:2] == (0, 5)
            if fv_tuple[:2] in ((0, 3), (0, 4)):
                m04 = _ModelDescr_v0_4.load(data)
                if isinstance(m04, InvalidDescr):
                    try:
                        updated = _model_conv.convert_as_dict(
                            m04  # pyright: ignore[reportArgumentType]
                        )
                    except Exception as e:
                        logger.error(
                            "Failed to convert from invalid model 0.4 description."
                            + f"\nerror: {e}"
                            + "\nProceeding with model 0.5 validation without conversion."
                        )
                        updated = None
                else:
                    updated = _model_conv.convert_as_dict(m04)

                if updated is not None:
                    data.clear()
                    data.update(updated)

            elif fv_tuple[:2] == (0, 5):
                # bump patch version
                data["format_version"] = cls.implemented_format_version
```
Specification of the fields used in a bioimage.io-compliant RDF to describe AI models with pretrained weights. These fields are typically stored in a YAML file which we call a model resource description file (model RDF).
id: Optional[ModelId]
bioimage.io-wide unique resource identifier assigned by bioimage.io; version unspecific.
documentation: FileSource_documentation
URL or relative path to a markdown file with additional documentation. The recommended documentation file name is `README.md`. An `.md` suffix is mandatory. The documentation should include a '#[#] Validation' (sub)section with details on how to quantitatively validate the model on unseen data.
inputs: NotEmpty[Sequence[InputTensorDescr]]
Describes the input tensors expected by this model.
name: str
A human-readable name of this model. It should be no longer than 64 characters and may only contain letters, numbers, underscores, minus signs, parentheses and spaces. We recommend choosing a name that refers to the model's task and image modality.
outputs: NotEmpty[Sequence[OutputTensorDescr]]
Describes the output tensors.
packaged_by: List[Author]
The persons that have packaged and uploaded this model. Only required if those persons differ from the `authors`.
parent: Optional[LinkedModel]
The model from which this model is derived, e.g. by fine-tuning the weights.
run_mode: Optional[RunMode]
Custom run mode for this model: for more complex prediction procedures like test time data augmentation that currently cannot be expressed in the specification. No standard run modes are defined yet.
training_data: Union[None, LinkedDataset, DatasetDescr, DatasetDescr02]
The dataset used to train this model.
weights: WeightsDescr
The weights for this model. Weights can be given for different formats, but should otherwise be equivalent. The available weight formats determine which consumers can use this model.
@staticmethod
def get_batch_size(tensor_sizes: Mapping[TensorId, Mapping[AxisId, int]]) -> int
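For illustration, a hedged example of the batch-size resolution above (tensor and axis ids are made up; it assumes `BATCH_AXIS_ID` is the id `"batch"` as used in `bioimageio.spec.model.v0_5`):

```python
from bioimageio.spec.model.v0_5 import AxisId, ModelDescr, TensorId

# both tensors agree on batch size 4; non-batch axes are ignored
sizes = {
    TensorId("raw"): {AxisId("batch"): 4, AxisId("x"): 512},
    TensorId("mask"): {AxisId("batch"): 4, AxisId("y"): 512},
}
assert ModelDescr.get_batch_size(sizes) == 4

# conflicting batch sizes (e.g. 4 vs 8) would raise a ValueError instead
```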
def get_output_tensor_sizes(self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]) -> Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]
Returns the tensor output sizes for the given **input_sizes**. The output sizes are exact only if **input_sizes** is a valid input shape; otherwise they may be larger than the actual (valid) output.
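A hedged usage sketch (`model` stands in for a validated `ModelDescr` instance; tensor and axis ids are assumptions):

```python
from bioimageio.spec.model.v0_5 import AxisId, TensorId

input_sizes = {
    TensorId("raw"): {AxisId("batch"): 1, AxisId("x"): 512, AxisId("y"): 512}
}
# exact if `input_sizes` is a valid input shape, otherwise an upper bound
output_sizes = model.get_output_tensor_sizes(input_sizes)
```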
def get_ns(self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]])
Get parameter `n` for each parameterized axis such that the valid input size is >= the given input size.
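A worked example of the rounding this implies (the concrete `min`/`step` values are hypothetical): for `size = min + n * step` with `min=32` and `step=16`, a requested size of 100 needs `n = 5`, i.e. a valid size of 112:

```python
from bioimageio.spec.model.v0_5 import ParameterizedSize

size_descr = ParameterizedSize(min=32, step=16)
n = size_descr.get_n(100)  # smallest n with 32 + 16 * n >= 100
assert n == 5
assert size_descr.get_size(n) == 112
```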
def get_tensor_sizes(self, ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], batch_size: int) -> _TensorSizes
def get_axis_sizes(self, ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], batch_size: Optional[int] = None, *, max_input_shape: Optional[Mapping[Tuple[TensorId, AxisId], int]] = None) -> _AxisSizes
Determine input and output block shape for scale factors **ns** of parameterized input sizes.

Arguments:
- ns: Scale factor `n` for each axis (keyed by `(tensor_id, axis_id)`) that is parameterized as `size = min + n * step`.
- batch_size: The desired size of the batch dimension. If given, **batch_size** overwrites any batch size present in **max_input_shape**. Default 1.
- max_input_shape: Limits the derived block shapes. Each axis for which the input size, parameterized by `n`, is larger than **max_input_shape** is set to the minimal value `n_min` for which this is still true. Use this for small input samples or large values of **ns**, or simply whenever you know the full input shape.

Returns:
Resolved axis sizes for model inputs and outputs.
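A hedged sketch of calling it (`model` stands in for a validated `ModelDescr`; ids and `n` values are made up):

```python
from bioimageio.spec.model.v0_5 import AxisId, TensorId

ns = {
    (TensorId("raw"), AxisId("x")): 3,
    (TensorId("raw"), AxisId("y")): 3,
}
axis_sizes = model.get_axis_sizes(ns, batch_size=1)
print(axis_sizes.inputs)   # {(tensor_id, axis_id): int}
print(axis_sizes.outputs)  # ints, or _DataDepSize for data-dependent axes
```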
@classmethod
def convert_from_old_format_wo_validation(cls, data: Dict[str, Any]) -> None
Convert metadata following an older format version to this class's format without validating the result.
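A hedged sketch of the intended call pattern (`old_rdf` stands in for a dict loaded from a model 0.4 rdf.yaml):

```python
from bioimageio.spec.model.v0_5 import ModelDescr

ModelDescr.convert_from_old_format_wo_validation(old_rdf)  # mutates the dict in place
descr = ModelDescr.model_validate(old_rdf)  # validate separately afterwards
```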
AnyModelDescr
Union of any released model description.