bioimageio.spec.model
Implementations of all released minor versions are available in submodules:

- model v0_4: `bioimageio.spec.model.v0_4.ModelDescr`
- model v0_5: `bioimageio.spec.model.v0_5.ModelDescr`
```python
# autogen: start
"""
Implementations of all released minor versions are available in submodules:
- model v0_4: `bioimageio.spec.model.v0_4.ModelDescr`
- model v0_5: `bioimageio.spec.model.v0_5.ModelDescr`
"""

from typing import Union

from pydantic import Discriminator, Field
from typing_extensions import Annotated

from . import v0_4, v0_5

ModelDescr = v0_5.ModelDescr
ModelDescr_v0_4 = v0_4.ModelDescr
ModelDescr_v0_5 = v0_5.ModelDescr

AnyModelDescr = Annotated[
    Union[
        Annotated[ModelDescr_v0_4, Field(title="model 0.4")],
        Annotated[ModelDescr_v0_5, Field(title="model 0.5")],
    ],
    Discriminator("format_version"),
    Field(title="model"),
]
"""Union of any released model description"""
# autogen: stop
```
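A minimal usage sketch of these exports, assuming a hypothetical local `rdf.yaml` (URLs and bioimage.io resource IDs work as well with `bioimageio.spec.load_description`):

```python
# Sketch: load a model description; the returned object is a v0_4.ModelDescr
# or v0_5.ModelDescr instance, selected by the description's `format_version`.
from bioimageio.spec import load_description
from bioimageio.spec.model import ModelDescr, ModelDescr_v0_4

descr = load_description("path/to/rdf.yaml")  # hypothetical path
if isinstance(descr, ModelDescr):  # `ModelDescr` aliases v0_5.ModelDescr
    print(descr.name, "uses format", descr.format_version)
elif isinstance(descr, ModelDescr_v0_4):
    print(descr.name, "still uses the 0.4 format")
else:
    print("invalid description:", descr.validation_summary)
```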
```python
class ModelDescr(GenericModelDescrBase):
    """Specification of the fields used in a bioimage.io-compliant RDF to describe AI models with pretrained weights.
    These fields are typically stored in a YAML file which we call a model resource description file (model RDF).
    """

    implemented_format_version: ClassVar[Literal["0.5.4"]] = "0.5.4"
    if TYPE_CHECKING:
        format_version: Literal["0.5.4"] = "0.5.4"
    else:
        format_version: Literal["0.5.4"]
        """Version of the bioimage.io model description specification used.
        When creating a new model always use the latest micro/patch version described here.
        The `format_version` is important for any consumer software to understand how to parse the fields.
        """

    implemented_type: ClassVar[Literal["model"]] = "model"
    if TYPE_CHECKING:
        type: Literal["model"] = "model"
    else:
        type: Literal["model"]
        """Specialized resource type 'model'"""

    id: Optional[ModelId] = None
    """bioimage.io-wide unique resource identifier
    assigned by bioimage.io; version **un**specific."""

    authors: NotEmpty[List[Author]]
    """The authors are the creators of the model RDF and the primary points of contact."""

    documentation: FileSource_documentation
    """URL or relative path to a markdown file with additional documentation.
    The recommended documentation file name is `README.md`. An `.md` suffix is mandatory.
    The documentation should include a '#[#] Validation' (sub)section
    with details on how to quantitatively validate the model on unseen data."""

    @field_validator("documentation", mode="after")
    @classmethod
    def _validate_documentation(
        cls, value: FileSource_documentation
    ) -> FileSource_documentation:
        if not get_validation_context().perform_io_checks:
            return value

        doc_reader = get_reader(value)
        doc_content = doc_reader.read().decode(encoding="utf-8")
        if not re.search("#.*[vV]alidation", doc_content):
            issue_warning(
                "No '# Validation' (sub)section found in {value}.",
                value=value,
                field="documentation",
            )

        return value

    inputs: NotEmpty[Sequence[InputTensorDescr]]
    """Describes the input tensors expected by this model."""

    @field_validator("inputs", mode="after")
    @classmethod
    def _validate_input_axes(
        cls, inputs: Sequence[InputTensorDescr]
    ) -> Sequence[InputTensorDescr]:
        input_size_refs = cls._get_axes_with_independent_size(inputs)

        for i, ipt in enumerate(inputs):
            valid_independent_refs: Dict[
                Tuple[TensorId, AxisId],
                Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
            ] = {
                **{
                    (ipt.id, a.id): (ipt, a, a.size)
                    for a in ipt.axes
                    if not isinstance(a, BatchAxis)
                    and isinstance(a.size, (int, ParameterizedSize))
                },
                **input_size_refs,
            }
            for a, ax in enumerate(ipt.axes):
                cls._validate_axis(
                    "inputs",
                    i=i,
                    tensor_id=ipt.id,
                    a=a,
                    axis=ax,
                    valid_independent_refs=valid_independent_refs,
                )
        return inputs

    @staticmethod
    def _validate_axis(
        field_name: str,
        i: int,
        tensor_id: TensorId,
        a: int,
        axis: AnyAxis,
        valid_independent_refs: Dict[
            Tuple[TensorId, AxisId],
            Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
        ],
    ):
        if isinstance(axis, BatchAxis) or isinstance(
            axis.size, (int, ParameterizedSize, DataDependentSize)
        ):
            return
        elif not isinstance(axis.size, SizeReference):
            assert_never(axis.size)

        # validate axis.size SizeReference
        ref = (axis.size.tensor_id, axis.size.axis_id)
        if ref not in valid_independent_refs:
            raise ValueError(
                "Invalid tensor axis reference at"
                + f" {field_name}[{i}].axes[{a}].size: {axis.size}."
            )
        if ref == (tensor_id, axis.id):
            raise ValueError(
                "Self-referencing not allowed for"
                + f" {field_name}[{i}].axes[{a}].size: {axis.size}"
            )
        if axis.type == "channel":
            if valid_independent_refs[ref][1].type != "channel":
                raise ValueError(
                    "A channel axis' size may only reference another fixed size"
                    + " channel axis."
                )
            if isinstance(axis.channel_names, str) and "{i}" in axis.channel_names:
                ref_size = valid_independent_refs[ref][2]
                assert isinstance(ref_size, int), (
                    "channel axis ref (another channel axis) has to specify fixed"
                    + " size"
                )
                generated_channel_names = [
                    Identifier(axis.channel_names.format(i=i))
                    for i in range(1, ref_size + 1)
                ]
                axis.channel_names = generated_channel_names

        if (ax_unit := getattr(axis, "unit", None)) != (
            ref_unit := getattr(valid_independent_refs[ref][1], "unit", None)
        ):
            raise ValueError(
                "The units of an axis and its reference axis need to match, but"
                + f" '{ax_unit}' != '{ref_unit}'."
            )
        ref_axis = valid_independent_refs[ref][1]
        if isinstance(ref_axis, BatchAxis):
            raise ValueError(
                f"Invalid reference axis '{ref_axis.id}' for {tensor_id}.{axis.id}"
                + " (a batch axis is not allowed as reference)."
            )

        if isinstance(axis, WithHalo):
            min_size = axis.size.get_size(axis, ref_axis, n=0)
            if (min_size - 2 * axis.halo) < 1:
                raise ValueError(
                    f"axis {axis.id} with minimum size {min_size} is too small for halo"
                    + f" {axis.halo}."
                )

            input_halo = axis.halo * axis.scale / ref_axis.scale
            if input_halo != int(input_halo) or input_halo % 2 == 1:
                raise ValueError(
                    f"input_halo {input_halo} (output_halo {axis.halo} *"
                    + f" output_scale {axis.scale} / input_scale {ref_axis.scale})"
                    + f" is not an even integer for {tensor_id}.{axis.id}."
                )

    @model_validator(mode="after")
    def _validate_test_tensors(self) -> Self:
        if not get_validation_context().perform_io_checks:
            return self

        test_output_arrays = [load_array(descr.test_tensor) for descr in self.outputs]
        test_input_arrays = [load_array(descr.test_tensor) for descr in self.inputs]

        tensors = {
            descr.id: (descr, array)
            for descr, array in zip(
                chain(self.inputs, self.outputs), test_input_arrays + test_output_arrays
            )
        }
        validate_tensors(tensors, tensor_origin="test_tensor")

        output_arrays = {
            descr.id: array for descr, array in zip(self.outputs, test_output_arrays)
        }
        for rep_tol in self.config.bioimageio.reproducibility_tolerance:
            if not rep_tol.absolute_tolerance:
                continue

            if rep_tol.output_ids:
                out_arrays = {
                    oid: a
                    for oid, a in output_arrays.items()
                    if oid in rep_tol.output_ids
                }
            else:
                out_arrays = output_arrays

            for out_id, array in out_arrays.items():
                if rep_tol.absolute_tolerance > (max_test_value := array.max()) * 0.01:
                    raise ValueError(
                        "config.bioimageio.reproducibility_tolerance.absolute_tolerance="
                        + f"{rep_tol.absolute_tolerance} > 0.01*{max_test_value}"
                        + f" (1% of the maximum value of the test tensor '{out_id}')"
                    )

        return self

    @model_validator(mode="after")
    def _validate_tensor_references_in_proc_kwargs(self, info: ValidationInfo) -> Self:
        ipt_refs = {t.id for t in self.inputs}
        out_refs = {t.id for t in self.outputs}
        for ipt in self.inputs:
            for p in ipt.preprocessing:
                ref = p.kwargs.get("reference_tensor")
                if ref is None:
                    continue
                if ref not in ipt_refs:
                    raise ValueError(
                        f"`reference_tensor` '{ref}' not found. Valid input tensor"
                        + f" references are: {ipt_refs}."
                    )

        for out in self.outputs:
            for p in out.postprocessing:
                ref = p.kwargs.get("reference_tensor")
                if ref is None:
                    continue

                if ref not in ipt_refs and ref not in out_refs:
                    raise ValueError(
                        f"`reference_tensor` '{ref}' not found. Valid tensor references"
                        + f" are: {ipt_refs | out_refs}."
                    )

        return self

    # TODO: use validate funcs in validate_test_tensors
    # def validate_inputs(self, input_tensors: Mapping[TensorId, NDArray[Any]]) -> Mapping[TensorId, NDArray[Any]]:

    name: Annotated[
        Annotated[
            str, RestrictCharacters(string.ascii_letters + string.digits + "_+- ()")
        ],
        MinLen(5),
        MaxLen(128),
        warn(MaxLen(64), "Name longer than 64 characters.", INFO),
    ]
    """A human-readable name of this model.
    It should be no longer than 64 characters
    and may only contain letters, numbers, underscores, minus signs, parentheses and spaces.
    We recommend choosing a name that refers to the model's task and image modality.
    """

    outputs: NotEmpty[Sequence[OutputTensorDescr]]
    """Describes the output tensors."""

    @field_validator("outputs", mode="after")
    @classmethod
    def _validate_tensor_ids(
        cls, outputs: Sequence[OutputTensorDescr], info: ValidationInfo
    ) -> Sequence[OutputTensorDescr]:
        tensor_ids = [
            t.id for t in info.data.get("inputs", []) + info.data.get("outputs", [])
        ]
        duplicate_tensor_ids: List[str] = []
        seen: Set[str] = set()
        for t in tensor_ids:
            if t in seen:
                duplicate_tensor_ids.append(t)

            seen.add(t)

        if duplicate_tensor_ids:
            raise ValueError(f"Duplicate tensor ids: {duplicate_tensor_ids}")

        return outputs

    @staticmethod
    def _get_axes_with_parameterized_size(
        io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]],
    ):
        return {
            f"{t.id}.{a.id}": (t, a, a.size)
            for t in io
            for a in t.axes
            if not isinstance(a, BatchAxis) and isinstance(a.size, ParameterizedSize)
        }

    @staticmethod
    def _get_axes_with_independent_size(
        io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]],
    ):
        return {
            (t.id, a.id): (t, a, a.size)
            for t in io
            for a in t.axes
            if not isinstance(a, BatchAxis)
            and isinstance(a.size, (int, ParameterizedSize))
        }

    @field_validator("outputs", mode="after")
    @classmethod
    def _validate_output_axes(
        cls, outputs: List[OutputTensorDescr], info: ValidationInfo
    ) -> List[OutputTensorDescr]:
        input_size_refs = cls._get_axes_with_independent_size(
            info.data.get("inputs", [])
        )
        output_size_refs = cls._get_axes_with_independent_size(outputs)

        for i, out in enumerate(outputs):
            valid_independent_refs: Dict[
                Tuple[TensorId, AxisId],
                Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
            ] = {
                **{
                    (out.id, a.id): (out, a, a.size)
                    for a in out.axes
                    if not isinstance(a, BatchAxis)
                    and isinstance(a.size, (int, ParameterizedSize))
                },
                **input_size_refs,
                **output_size_refs,
            }
            for a, ax in enumerate(out.axes):
                cls._validate_axis(
                    "outputs",
                    i,
                    out.id,
                    a,
                    ax,
                    valid_independent_refs=valid_independent_refs,
                )

        return outputs

    packaged_by: List[Author] = Field(
        default_factory=cast(Callable[[], List[Author]], list)
    )
    """The persons that have packaged and uploaded this model.
    Only required if those persons differ from the `authors`."""

    parent: Optional[LinkedModel] = None
    """The model from which this model is derived, e.g. by fine-tuning the weights."""

    @model_validator(mode="after")
    def _validate_parent_is_not_self(self) -> Self:
        if self.parent is not None and self.parent.id == self.id:
            raise ValueError("A model description may not reference itself as parent.")

        return self

    run_mode: Annotated[
        Optional[RunMode],
        warn(None, "Run mode '{value}' has limited support across consumer software."),
    ] = None
    """Custom run mode for this model: for more complex prediction procedures like test-time
    data augmentation that currently cannot be expressed in the specification.
    No standard run modes are defined yet."""

    timestamp: Datetime = Field(default_factory=Datetime.now)
    """Timestamp in [ISO 8601](https://en.wikipedia.org/wiki/ISO_8601) format
    with a few restrictions listed [here](https://docs.python.org/3/library/datetime.html#datetime.datetime.fromisoformat).
    (In Python a datetime object is valid, too)."""

    training_data: Annotated[
        Union[None, LinkedDataset, DatasetDescr, DatasetDescr02],
        Field(union_mode="left_to_right"),
    ] = None
    """The dataset used to train this model."""

    weights: Annotated[WeightsDescr, WrapSerializer(package_weights)]
    """The weights for this model.
    Weights can be given for different formats, but should otherwise be equivalent.
    The available weight formats determine which consumers can use this model."""

    config: Config = Field(default_factory=Config)

    @model_validator(mode="after")
    def _add_default_cover(self) -> Self:
        if not get_validation_context().perform_io_checks or self.covers:
            return self

        try:
            generated_covers = generate_covers(
                [(t, load_array(t.test_tensor)) for t in self.inputs],
                [(t, load_array(t.test_tensor)) for t in self.outputs],
            )
        except Exception as e:
            issue_warning(
                "Failed to generate cover image(s): {e}",
                value=self.covers,
                msg_context=dict(e=e),
                field="covers",
            )
        else:
            self.covers.extend(generated_covers)

        return self

    def get_input_test_arrays(self) -> List[NDArray[Any]]:
        data = [load_array(ipt.test_tensor) for ipt in self.inputs]
        assert all(isinstance(d, np.ndarray) for d in data)
        return data

    def get_output_test_arrays(self) -> List[NDArray[Any]]:
        data = [load_array(out.test_tensor) for out in self.outputs]
        assert all(isinstance(d, np.ndarray) for d in data)
        return data

    @staticmethod
    def get_batch_size(tensor_sizes: Mapping[TensorId, Mapping[AxisId, int]]) -> int:
        batch_size = 1
        tensor_with_batchsize: Optional[TensorId] = None
        for tid in tensor_sizes:
            for aid, s in tensor_sizes[tid].items():
                if aid != BATCH_AXIS_ID or s == 1 or s == batch_size:
                    continue

                if batch_size != 1:
                    assert tensor_with_batchsize is not None
                    raise ValueError(
                        f"batch size mismatch for tensors '{tensor_with_batchsize}' ({batch_size}) and '{tid}' ({s})"
                    )

                batch_size = s
                tensor_with_batchsize = tid

        return batch_size

    def get_output_tensor_sizes(
        self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]
    ) -> Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]:
        """Returns the output tensor sizes for the given **input_sizes**.
        The output sizes are exact only if **input_sizes** is a valid input shape;
        otherwise they may be larger than the actual (valid) output."""
        batch_size = self.get_batch_size(input_sizes)
        ns = self.get_ns(input_sizes)

        tensor_sizes = self.get_tensor_sizes(ns, batch_size=batch_size)
        return tensor_sizes.outputs

    def get_ns(self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]):
        """Get parameter `n` for each parameterized axis
        such that the valid input size is >= the given input size."""
        ret: Dict[Tuple[TensorId, AxisId], ParameterizedSize_N] = {}
        axes = {t.id: {a.id: a for a in t.axes} for t in self.inputs}
        for tid in input_sizes:
            for aid, s in input_sizes[tid].items():
                size_descr = axes[tid][aid].size
                if isinstance(size_descr, ParameterizedSize):
                    ret[(tid, aid)] = size_descr.get_n(s)
                elif size_descr is None or isinstance(size_descr, (int, SizeReference)):
                    pass
                else:
                    assert_never(size_descr)

        return ret

    def get_tensor_sizes(
        self, ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], batch_size: int
    ) -> _TensorSizes:
        axis_sizes = self.get_axis_sizes(ns, batch_size=batch_size)
        return _TensorSizes(
            {
                t: {
                    aa: axis_sizes.inputs[(tt, aa)]
                    for tt, aa in axis_sizes.inputs
                    if tt == t
                }
                for t in {tt for tt, _ in axis_sizes.inputs}
            },
            {
                t: {
                    aa: axis_sizes.outputs[(tt, aa)]
                    for tt, aa in axis_sizes.outputs
                    if tt == t
                }
                for t in {tt for tt, _ in axis_sizes.outputs}
            },
        )

    def get_axis_sizes(
        self,
        ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N],
        batch_size: Optional[int] = None,
        *,
        max_input_shape: Optional[Mapping[Tuple[TensorId, AxisId], int]] = None,
    ) -> _AxisSizes:
        """Determine input and output block shape for scale factors **ns**
        of parameterized input sizes.

        Args:
            ns: Scale factor `n` for each axis (keyed by (tensor_id, axis_id))
                that is parameterized as `size = min + n * step`.
            batch_size: The desired size of the batch dimension.
                If given, **batch_size** overrides any batch size present in
                **max_input_shape**. Defaults to 1.
            max_input_shape: Limits the derived block shapes.
                Each axis for which the input size, parameterized by `n`, is larger
                than **max_input_shape** is set to the minimal value `n_min` for which
                this is still true.
                Use this for small input samples or large values of **ns**,
                or simply whenever you know the full input shape.

        Returns:
            Resolved axis sizes for model inputs and outputs.
        """
        max_input_shape = max_input_shape or {}
        if batch_size is None:
            for (_t_id, a_id), s in max_input_shape.items():
                if a_id == BATCH_AXIS_ID:
                    batch_size = s
                    break
            else:
                batch_size = 1

        all_axes = {
            t.id: {a.id: a for a in t.axes} for t in chain(self.inputs, self.outputs)
        }

        inputs: Dict[Tuple[TensorId, AxisId], int] = {}
        outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]] = {}

        def get_axis_size(a: Union[InputAxis, OutputAxis]):
            if isinstance(a, BatchAxis):
                if (t_descr.id, a.id) in ns:
                    logger.warning(
                        "Ignoring unexpected size increment factor (n) for batch axis"
                        + " of tensor '{}'.",
                        t_descr.id,
                    )
                return batch_size
            elif isinstance(a.size, int):
                if (t_descr.id, a.id) in ns:
                    logger.warning(
                        "Ignoring unexpected size increment factor (n) for fixed size"
                        + " axis '{}' of tensor '{}'.",
                        a.id,
                        t_descr.id,
                    )
                return a.size
            elif isinstance(a.size, ParameterizedSize):
                if (t_descr.id, a.id) not in ns:
                    raise ValueError(
                        "Size increment factor (n) missing for parameterized axis"
                        + f" '{a.id}' of tensor '{t_descr.id}'."
                    )
                n = ns[(t_descr.id, a.id)]
                s_max = max_input_shape.get((t_descr.id, a.id))
                if s_max is not None:
                    n = min(n, a.size.get_n(s_max))

                return a.size.get_size(n)

            elif isinstance(a.size, SizeReference):
                if (t_descr.id, a.id) in ns:
                    logger.warning(
                        "Ignoring unexpected size increment factor (n) for axis '{}'"
                        + " of tensor '{}' with size reference.",
                        a.id,
                        t_descr.id,
                    )
                assert not isinstance(a, BatchAxis)
                ref_axis = all_axes[a.size.tensor_id][a.size.axis_id]
                assert not isinstance(ref_axis, BatchAxis)
                ref_key = (a.size.tensor_id, a.size.axis_id)
                ref_size = inputs.get(ref_key, outputs.get(ref_key))
                assert ref_size is not None, ref_key
                assert not isinstance(ref_size, _DataDepSize), ref_key
                return a.size.get_size(
                    axis=a,
                    ref_axis=ref_axis,
                    ref_size=ref_size,
                )
            elif isinstance(a.size, DataDependentSize):
                if (t_descr.id, a.id) in ns:
                    logger.warning(
                        "Ignoring unexpected increment factor (n) for data dependent"
                        + " size axis '{}' of tensor '{}'.",
                        a.id,
                        t_descr.id,
                    )
                return _DataDepSize(a.size.min, a.size.max)
            else:
                assert_never(a.size)

        # first resolve all but the `SizeReference` input sizes
        for t_descr in self.inputs:
            for a in t_descr.axes:
                if not isinstance(a.size, SizeReference):
                    s = get_axis_size(a)
                    assert not isinstance(s, _DataDepSize)
                    inputs[t_descr.id, a.id] = s

        # resolve all other input axis sizes
        for t_descr in self.inputs:
            for a in t_descr.axes:
                if isinstance(a.size, SizeReference):
                    s = get_axis_size(a)
                    assert not isinstance(s, _DataDepSize)
                    inputs[t_descr.id, a.id] = s

        # resolve all output axis sizes
        for t_descr in self.outputs:
            for a in t_descr.axes:
                assert not isinstance(a.size, ParameterizedSize)
                s = get_axis_size(a)
                outputs[t_descr.id, a.id] = s

        return _AxisSizes(inputs=inputs, outputs=outputs)

    @model_validator(mode="before")
    @classmethod
    def _convert(cls, data: Dict[str, Any]) -> Dict[str, Any]:
        cls.convert_from_old_format_wo_validation(data)
        return data

    @classmethod
    def convert_from_old_format_wo_validation(cls, data: Dict[str, Any]) -> None:
        """Convert metadata following an older format version to this class's format
        without validating the result.
        """
        if (
            data.get("type") == "model"
            and isinstance(fv := data.get("format_version"), str)
            and fv.count(".") == 2
        ):
            fv_parts = fv.split(".")
            if any(not p.isdigit() for p in fv_parts):
                return

            fv_tuple = tuple(map(int, fv_parts))

            assert cls.implemented_format_version_tuple[0:2] == (0, 5)
            if fv_tuple[:2] in ((0, 3), (0, 4)):
                m04 = _ModelDescr_v0_4.load(data)
                if isinstance(m04, InvalidDescr):
                    try:
                        updated = _model_conv.convert_as_dict(
                            m04  # pyright: ignore[reportArgumentType]
                        )
                    except Exception as e:
                        logger.error(
                            "Failed to convert from invalid model 0.4 description."
                            + f"\nerror: {e}"
                            + "\nProceeding with model 0.5 validation without conversion."
                        )
                        updated = None
                else:
                    updated = _model_conv.convert_as_dict(m04)

                if updated is not None:
                    data.clear()
                    data.update(updated)

            elif fv_tuple[:2] == (0, 5):
                # bump patch version
                data["format_version"] = cls.implemented_format_version
```
Specification of the fields used in a bioimage.io-compliant RDF to describe AI models with pretrained weights. These fields are typically stored in a YAML file which we call a model resource description file (model RDF).
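To make the RDF structure concrete, here is a schematic, deliberately incomplete YAML sketch; all names and values are hypothetical, and a real model RDF requires further fields (`description`, `license`, `cite`, test tensors, axis and data descriptions, etc.):

```yaml
# schematic model RDF sketch -- NOT a complete, valid description
format_version: 0.5.4
type: model
name: hypothetical-nucleus-segmentation   # hypothetical name
authors:
  - name: Jane Doe                        # hypothetical author
documentation: README.md
inputs:
  - id: raw
    axes: []    # batch/channel/space axis descriptions go here
outputs:
  - id: mask
    axes: []
weights: {}     # at least one weight format entry, e.g. pytorch_state_dict
```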
`id`: bioimage.io-wide unique resource identifier assigned by bioimage.io; version unspecific.
`documentation`: URL or relative path to a markdown file with additional documentation. The recommended documentation file name is `README.md`. An `.md` suffix is mandatory. The documentation should include a '#[#] Validation' (sub)section with details on how to quantitatively validate the model on unseen data.
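The check behind this requirement is a simple pattern search (see `_validate_documentation` above: `re.search("#.*[vV]alidation", ...)`), so any markdown heading containing "Validation" satisfies it. A minimal `README.md` outline:

```markdown
# Hypothetical Nucleus Segmentation Model

Short description of the model...

## Validation
How to quantitatively validate the model on unseen data
(metrics, datasets, expected scores)...
```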
`inputs`: Describes the input tensors expected by this model.

`name`: A human-readable name of this model. It should be no longer than 64 characters and may only contain letters, numbers, underscores, minus signs, parentheses and spaces. We recommend choosing a name that refers to the model's task and image modality.

`outputs`: Describes the output tensors.
`packaged_by`: The persons that have packaged and uploaded this model. Only required if those persons differ from the `authors`.

`parent`: The model from which this model is derived, e.g. by fine-tuning the weights.

`run_mode`: Custom run mode for this model: for more complex prediction procedures like test-time data augmentation that currently cannot be expressed in the specification. No standard run modes are defined yet.

`training_data`: The dataset used to train this model.

`weights`: The weights for this model. Weights can be given for different formats, but should otherwise be equivalent. The available weight formats determine which consumers can use this model.
`ModelDescr.get_batch_size(tensor_sizes: Mapping[TensorId, Mapping[AxisId, int]]) -> int` (static method; source in the class listing above)
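A small usage sketch; the tensor ids are hypothetical, and `BATCH_AXIS_ID` is assumed to be the reserved batch axis id, rendered here as `"batch"`:

```python
from bioimageio.spec.model.v0_5 import AxisId, ModelDescr, TensorId

# consistent batch axis entries across two hypothetical tensors
sizes = {
    TensorId("raw"): {AxisId("batch"): 4, AxisId("x"): 256},
    TensorId("mask"): {AxisId("batch"): 4, AxisId("x"): 256},
}
assert ModelDescr.get_batch_size(sizes) == 4
# mismatched batch sizes (e.g. 4 vs 2) would raise a ValueError instead
```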
`ModelDescr.get_output_tensor_sizes(input_sizes: Mapping[TensorId, Mapping[AxisId, int]]) -> Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]` (source in the class listing above)
Returns the output tensor sizes for the given `input_sizes`. The output sizes are exact only if `input_sizes` is a valid input shape; otherwise they may be larger than the actual (valid) output.
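For example, a sketch with hypothetical tensor and axis ids, where `model` is assumed to be a loaded `ModelDescr`:

```python
out_sizes = model.get_output_tensor_sizes(
    {TensorId("raw"): {AxisId("batch"): 1, AxisId("x"): 112, AxisId("y"): 112}}
)
# e.g. {TensorId("mask"): {AxisId("batch"): 1, AxisId("x"): 112, AxisId("y"): 112}}
```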
`ModelDescr.get_ns(input_sizes: Mapping[TensorId, Mapping[AxisId, int]])` (source in the class listing above)
Get parameter `n` for each parameterized axis such that the valid input size is >= the given input size.
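As a worked example, consider a parameterized axis with `min=64`, `step=16` (hypothetical values), i.e. valid sizes 64, 80, 96, 112, ...; for a requested size of 100, the smallest `n` with `64 + n * 16 >= 100` is `n = 3`:

```python
from bioimageio.spec.model.v0_5 import ParameterizedSize

size = ParameterizedSize(min=64, step=16)
n = size.get_n(100)             # smallest n such that min + n * step >= 100
assert n == 3
assert size.get_size(n) == 112  # the valid input size covering 100
```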
`ModelDescr.get_tensor_sizes(ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], batch_size: int) -> _TensorSizes` (source in the class listing above)
`ModelDescr.get_axis_sizes(ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], batch_size: Optional[int] = None, *, max_input_shape: Optional[Mapping[Tuple[TensorId, AxisId], int]] = None) -> _AxisSizes` (source in the class listing above)
Determine input and output block shape for scale factors `ns` of parameterized input sizes.

Arguments:
- `ns`: Scale factor `n` for each axis (keyed by `(tensor_id, axis_id)`) that is parameterized as `size = min + n * step`.
- `batch_size`: The desired size of the batch dimension. If given, `batch_size` overrides any batch size present in `max_input_shape`. Defaults to 1.
- `max_input_shape`: Limits the derived block shapes. Each axis for which the input size, parameterized by `n`, is larger than `max_input_shape` is set to the minimal value `n_min` for which this is still true. Use this for small input samples or large values of `ns`, or simply whenever you know the full input shape.

Returns:
Resolved axis sizes for model inputs and outputs.
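A usage sketch with hypothetical ids, where `model` is assumed to be a loaded `ModelDescr` with a parameterized `x` axis on input tensor `raw`:

```python
axis_sizes = model.get_axis_sizes(
    ns={(TensorId("raw"), AxisId("x")): 3},  # n=3 for the parameterized axis
    batch_size=2,
)
print(axis_sizes.inputs)   # {(TensorId("raw"), AxisId("x")): <int>, ...}
print(axis_sizes.outputs)  # ints, or _DataDepSize for data-dependent axes
```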
`ModelDescr.convert_from_old_format_wo_validation(data: Dict[str, Any]) -> None` (class method; source in the class listing above)
Convert metadata following an older format version to this class's format without validating the result.
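This hook is what lets `ModelDescr` accept 0.3/0.4 metadata: the `_convert` model validator above calls it before field validation. It can also be applied manually to raw RDF data — a sketch assuming a hypothetical local `rdf.yaml` (normally you would simply use `bioimageio.spec.load_description`):

```python
import yaml

from bioimageio.spec.model.v0_5 import ModelDescr

with open("rdf.yaml") as f:  # hypothetical path
    data = yaml.safe_load(f)

ModelDescr.convert_from_old_format_wo_validation(data)  # updates `data` in place
model = ModelDescr.model_validate(data)  # validation happens separately
```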
`model_config`: Configuration for the model, should be a dictionary conforming to [`ConfigDict`][pydantic.config.ConfigDict].
```python
def init_private_attributes(self: BaseModel, context: Any, /) -> None:
    """This function is meant to behave like a BaseModel method to initialise private attributes.

    It takes context as an argument since that's what pydantic-core passes when calling it.

    Args:
        self: The BaseModel instance.
        context: The context.
    """
    if getattr(self, '__pydantic_private__', None) is None:
        pydantic_private = {}
        for name, private_attr in self.__private_attributes__.items():
            default = private_attr.get_default()
            if default is not PydanticUndefined:
                pydantic_private[name] = default
        object_setattr(self, '__pydantic_private__', pydantic_private)
```
This function is meant to behave like a BaseModel method to initialise private attributes.
It takes context as an argument since that's what pydantic-core passes when calling it.
Arguments:
- self: The BaseModel instance.
- context: The context.
Union of any released model description
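Since `AnyModelDescr` is an annotated, discriminated union, it can in principle be used directly with pydantic to parse raw data into whichever released class matches `format_version` — a sketch, assuming a hypothetical `rdf.yaml` holding a complete raw model RDF:

```python
import yaml
from pydantic import TypeAdapter

from bioimageio.spec.model import AnyModelDescr

with open("rdf.yaml") as f:  # hypothetical path
    raw = yaml.safe_load(f)

descr = TypeAdapter(AnyModelDescr).validate_python(raw)
print(type(descr).__module__)  # ...v0_4 or ...v0_5, selected by format_version
```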