bioimageio.spec.model
Implementations of all released minor versions are available in submodules:

- model v0_4: `bioimageio.spec.model.v0_4.ModelDescr`
- model v0_5: `bioimageio.spec.model.v0_5.ModelDescr`
```python
# autogen: start
"""
Implementations of all released minor versions are available in submodules:
- model v0_4: `bioimageio.spec.model.v0_4.ModelDescr`
- model v0_5: `bioimageio.spec.model.v0_5.ModelDescr`
"""

from typing import Union

from pydantic import Discriminator, Field
from typing_extensions import Annotated

from . import v0_4, v0_5

ModelDescr = v0_5.ModelDescr
ModelDescr_v0_4 = v0_4.ModelDescr
ModelDescr_v0_5 = v0_5.ModelDescr

AnyModelDescr = Annotated[
    Union[
        Annotated[ModelDescr_v0_4, Field(title="model 0.4")],
        Annotated[ModelDescr_v0_5, Field(title="model 0.5")],
    ],
    Discriminator("format_version"),
    Field(title="model"),
]
"""Union of any released model description"""
# autogen: stop
```
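To illustrate how the discriminated union resolves, here is a minimal sketch using pydantic's `TypeAdapter`. The `rdf.yaml` path is a placeholder, and the sketch assumes the file's `format_version` exactly matches an implemented version; in practice `bioimageio.spec.load_description` is the more convenient entry point, as it also handles older patch versions.

```python
import yaml
from pydantic import TypeAdapter

from bioimageio.spec.model import AnyModelDescr

# Placeholder path to a model resource description file (model RDF).
with open("rdf.yaml", encoding="utf-8") as f:
    rdf_data = yaml.safe_load(f)

# Discriminator("format_version") routes "0.4.*" data to v0_4.ModelDescr
# and "0.5.*" data to v0_5.ModelDescr.
model = TypeAdapter(AnyModelDescr).validate_python(rdf_data)
print(type(model).__module__)
```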
```python
class ModelDescr(GenericModelDescrBase):
    """Specification of the fields used in a bioimage.io-compliant RDF to describe AI models with pretrained weights.
    These fields are typically stored in a YAML file which we call a model resource description file (model RDF).
    """

    implemented_format_version: ClassVar[Literal["0.5.4"]] = "0.5.4"
    if TYPE_CHECKING:
        format_version: Literal["0.5.4"] = "0.5.4"
    else:
        format_version: Literal["0.5.4"]
    """Version of the bioimage.io model description specification used.
    When creating a new model always use the latest micro/patch version described here.
    The `format_version` is important for any consumer software to understand how to parse the fields.
    """

    implemented_type: ClassVar[Literal["model"]] = "model"
    if TYPE_CHECKING:
        type: Literal["model"] = "model"
    else:
        type: Literal["model"]
    """Specialized resource type 'model'"""

    id: Optional[ModelId] = None
    """bioimage.io-wide unique resource identifier
    assigned by bioimage.io; version **un**specific."""

    authors: NotEmpty[List[Author]]
    """The authors are the creators of the model RDF and the primary points of contact."""

    documentation: Annotated[
        DocumentationSource,
        Field(
            examples=[
                "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/unet2d_nuclei_broad/README.md",
                "README.md",
            ],
        ),
    ]
    """∈📦 URL or relative path to a markdown file with additional documentation.
    The recommended documentation file name is `README.md`. An `.md` suffix is mandatory.
    The documentation should include a '#[#] Validation' (sub)section
    with details on how to quantitatively validate the model on unseen data."""

    @field_validator("documentation", mode="after")
    @classmethod
    def _validate_documentation(cls, value: DocumentationSource) -> DocumentationSource:
        if not get_validation_context().perform_io_checks:
            return value

        doc_path = download(value).path
        doc_content = doc_path.read_text(encoding="utf-8")
        assert isinstance(doc_content, str)
        if not re.search("#.*[vV]alidation", doc_content):
            issue_warning(
                "No '# Validation' (sub)section found in {value}.",
                value=value,
                field="documentation",
            )

        return value

    inputs: NotEmpty[Sequence[InputTensorDescr]]
    """Describes the input tensors expected by this model."""

    @field_validator("inputs", mode="after")
    @classmethod
    def _validate_input_axes(
        cls, inputs: Sequence[InputTensorDescr]
    ) -> Sequence[InputTensorDescr]:
        input_size_refs = cls._get_axes_with_independent_size(inputs)

        for i, ipt in enumerate(inputs):
            valid_independent_refs: Dict[
                Tuple[TensorId, AxisId],
                Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
            ] = {
                **{
                    (ipt.id, a.id): (ipt, a, a.size)
                    for a in ipt.axes
                    if not isinstance(a, BatchAxis)
                    and isinstance(a.size, (int, ParameterizedSize))
                },
                **input_size_refs,
            }
            for a, ax in enumerate(ipt.axes):
                cls._validate_axis(
                    "inputs",
                    i=i,
                    tensor_id=ipt.id,
                    a=a,
                    axis=ax,
                    valid_independent_refs=valid_independent_refs,
                )
        return inputs

    @staticmethod
    def _validate_axis(
        field_name: str,
        i: int,
        tensor_id: TensorId,
        a: int,
        axis: AnyAxis,
        valid_independent_refs: Dict[
            Tuple[TensorId, AxisId],
            Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
        ],
    ):
        if isinstance(axis, BatchAxis) or isinstance(
            axis.size, (int, ParameterizedSize, DataDependentSize)
        ):
            return
        elif not isinstance(axis.size, SizeReference):
            assert_never(axis.size)

        # validate axis.size SizeReference
        ref = (axis.size.tensor_id, axis.size.axis_id)
        if ref not in valid_independent_refs:
            raise ValueError(
                "Invalid tensor axis reference at"
                + f" {field_name}[{i}].axes[{a}].size: {axis.size}."
            )
        if ref == (tensor_id, axis.id):
            raise ValueError(
                "Self-referencing not allowed for"
                + f" {field_name}[{i}].axes[{a}].size: {axis.size}"
            )
        if axis.type == "channel":
            if valid_independent_refs[ref][1].type != "channel":
                raise ValueError(
                    "A channel axis' size may only reference another fixed size"
                    + " channel axis."
                )
            if isinstance(axis.channel_names, str) and "{i}" in axis.channel_names:
                ref_size = valid_independent_refs[ref][2]
                assert isinstance(ref_size, int), (
                    "channel axis ref (another channel axis) has to specify fixed"
                    + " size"
                )
                generated_channel_names = [
                    Identifier(axis.channel_names.format(i=i))
                    for i in range(1, ref_size + 1)
                ]
                axis.channel_names = generated_channel_names

        if (ax_unit := getattr(axis, "unit", None)) != (
            ref_unit := getattr(valid_independent_refs[ref][1], "unit", None)
        ):
            raise ValueError(
                "The units of an axis and its reference axis need to match, but"
                + f" '{ax_unit}' != '{ref_unit}'."
            )
        ref_axis = valid_independent_refs[ref][1]
        if isinstance(ref_axis, BatchAxis):
            raise ValueError(
                f"Invalid reference axis '{ref_axis.id}' for {tensor_id}.{axis.id}"
                + " (a batch axis is not allowed as reference)."
            )

        if isinstance(axis, WithHalo):
            min_size = axis.size.get_size(axis, ref_axis, n=0)
            if (min_size - 2 * axis.halo) < 1:
                raise ValueError(
                    f"axis {axis.id} with minimum size {min_size} is too small for halo"
                    + f" {axis.halo}."
                )

            input_halo = axis.halo * axis.scale / ref_axis.scale
            if input_halo != int(input_halo) or input_halo % 2 == 1:
                raise ValueError(
                    f"input_halo {input_halo} (output_halo {axis.halo} *"
                    + f" output_scale {axis.scale} / input_scale {ref_axis.scale})"
                    + f" is not an even integer for {tensor_id}.{axis.id}."
                )

    @model_validator(mode="after")
    def _validate_test_tensors(self) -> Self:
        if not get_validation_context().perform_io_checks:
            return self

        test_output_arrays = [
            load_array(descr.test_tensor.download().path) for descr in self.outputs
        ]
        test_input_arrays = [
            load_array(descr.test_tensor.download().path) for descr in self.inputs
        ]

        tensors = {
            descr.id: (descr, array)
            for descr, array in zip(
                chain(self.inputs, self.outputs), test_input_arrays + test_output_arrays
            )
        }
        validate_tensors(tensors, tensor_origin="test_tensor")

        output_arrays = {
            descr.id: array for descr, array in zip(self.outputs, test_output_arrays)
        }
        for rep_tol in self.config.bioimageio.reproducibility_tolerance:
            if not rep_tol.absolute_tolerance:
                continue

            if rep_tol.output_ids:
                out_arrays = {
                    oid: a
                    for oid, a in output_arrays.items()
                    if oid in rep_tol.output_ids
                }
            else:
                out_arrays = output_arrays

            for out_id, array in out_arrays.items():
                if rep_tol.absolute_tolerance > (max_test_value := array.max()) * 0.01:
                    raise ValueError(
                        "config.bioimageio.reproducibility_tolerance.absolute_tolerance="
                        + f"{rep_tol.absolute_tolerance} > 0.01*{max_test_value}"
                        + f" (1% of the maximum value of the test tensor '{out_id}')"
                    )

        return self

    @model_validator(mode="after")
    def _validate_tensor_references_in_proc_kwargs(self, info: ValidationInfo) -> Self:
        ipt_refs = {t.id for t in self.inputs}
        out_refs = {t.id for t in self.outputs}
        for ipt in self.inputs:
            for p in ipt.preprocessing:
                ref = p.kwargs.get("reference_tensor")
                if ref is None:
                    continue
                if ref not in ipt_refs:
                    raise ValueError(
                        f"`reference_tensor` '{ref}' not found. Valid input tensor"
                        + f" references are: {ipt_refs}."
                    )

        for out in self.outputs:
            for p in out.postprocessing:
                ref = p.kwargs.get("reference_tensor")
                if ref is None:
                    continue

                if ref not in ipt_refs and ref not in out_refs:
                    raise ValueError(
                        f"`reference_tensor` '{ref}' not found. Valid tensor references"
                        + f" are: {ipt_refs | out_refs}."
                    )

        return self

    # TODO: use validate funcs in validate_test_tensors
    # def validate_inputs(self, input_tensors: Mapping[TensorId, NDArray[Any]]) -> Mapping[TensorId, NDArray[Any]]:

    name: Annotated[
        Annotated[
            str, RestrictCharacters(string.ascii_letters + string.digits + "_+- ()")
        ],
        MinLen(5),
        MaxLen(128),
        warn(MaxLen(64), "Name longer than 64 characters.", INFO),
    ]
    """A human-readable name of this model.
    It should be no longer than 64 characters
    and may only contain letters, numbers, underscores, minus signs, parentheses and spaces.
    We recommend choosing a name that refers to the model's task and image modality.
    """

    outputs: NotEmpty[Sequence[OutputTensorDescr]]
    """Describes the output tensors."""

    @field_validator("outputs", mode="after")
    @classmethod
    def _validate_tensor_ids(
        cls, outputs: Sequence[OutputTensorDescr], info: ValidationInfo
    ) -> Sequence[OutputTensorDescr]:
        tensor_ids = [
            t.id for t in info.data.get("inputs", []) + info.data.get("outputs", [])
        ]
        duplicate_tensor_ids: List[str] = []
        seen: Set[str] = set()
        for t in tensor_ids:
            if t in seen:
                duplicate_tensor_ids.append(t)

            seen.add(t)

        if duplicate_tensor_ids:
            raise ValueError(f"Duplicate tensor ids: {duplicate_tensor_ids}")

        return outputs

    @staticmethod
    def _get_axes_with_parameterized_size(
        io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]],
    ):
        return {
            f"{t.id}.{a.id}": (t, a, a.size)
            for t in io
            for a in t.axes
            if not isinstance(a, BatchAxis) and isinstance(a.size, ParameterizedSize)
        }

    @staticmethod
    def _get_axes_with_independent_size(
        io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]],
    ):
        return {
            (t.id, a.id): (t, a, a.size)
            for t in io
            for a in t.axes
            if not isinstance(a, BatchAxis)
            and isinstance(a.size, (int, ParameterizedSize))
        }

    @field_validator("outputs", mode="after")
    @classmethod
    def _validate_output_axes(
        cls, outputs: List[OutputTensorDescr], info: ValidationInfo
    ) -> List[OutputTensorDescr]:
        input_size_refs = cls._get_axes_with_independent_size(
            info.data.get("inputs", [])
        )
        output_size_refs = cls._get_axes_with_independent_size(outputs)

        for i, out in enumerate(outputs):
            valid_independent_refs: Dict[
                Tuple[TensorId, AxisId],
                Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]],
            ] = {
                **{
                    (out.id, a.id): (out, a, a.size)
                    for a in out.axes
                    if not isinstance(a, BatchAxis)
                    and isinstance(a.size, (int, ParameterizedSize))
                },
                **input_size_refs,
                **output_size_refs,
            }
            for a, ax in enumerate(out.axes):
                cls._validate_axis(
                    "outputs",
                    i,
                    out.id,
                    a,
                    ax,
                    valid_independent_refs=valid_independent_refs,
                )

        return outputs

    packaged_by: List[Author] = Field(default_factory=list)
    """The persons that have packaged and uploaded this model.
    Only required if those persons differ from the `authors`."""

    parent: Optional[LinkedModel] = None
    """The model from which this model is derived, e.g. by fine-tuning the weights."""

    @model_validator(mode="after")
    def _validate_parent_is_not_self(self) -> Self:
        if self.parent is not None and self.parent.id == self.id:
            raise ValueError("A model description may not reference itself as parent.")

        return self

    run_mode: Annotated[
        Optional[RunMode],
        warn(None, "Run mode '{value}' has limited support across consumer software."),
    ] = None
    """Custom run mode for this model: for more complex prediction procedures like test-time
    data augmentation that currently cannot be expressed in the specification.
    No standard run modes are defined yet."""

    timestamp: Datetime = Field(default_factory=Datetime.now)
    """Timestamp in [ISO 8601](https://en.wikipedia.org/wiki/ISO_8601) format
    with a few restrictions listed [here](https://docs.python.org/3/library/datetime.html#datetime.datetime.fromisoformat).
    (In Python a datetime object is valid, too)."""

    training_data: Annotated[
        Union[None, LinkedDataset, DatasetDescr, DatasetDescr02],
        Field(union_mode="left_to_right"),
    ] = None
    """The dataset used to train this model."""

    weights: Annotated[WeightsDescr, WrapSerializer(package_weights)]
    """The weights for this model.
    Weights can be given for different formats, but should otherwise be equivalent.
    The available weight formats determine which consumers can use this model."""

    config: Config = Field(default_factory=Config)

    @model_validator(mode="after")
    def _add_default_cover(self) -> Self:
        if not get_validation_context().perform_io_checks or self.covers:
            return self

        try:
            generated_covers = generate_covers(
                [(t, load_array(t.test_tensor.download().path)) for t in self.inputs],
                [(t, load_array(t.test_tensor.download().path)) for t in self.outputs],
            )
        except Exception as e:
            issue_warning(
                "Failed to generate cover image(s): {e}",
                value=self.covers,
                msg_context=dict(e=e),
                field="covers",
            )
        else:
            self.covers.extend(generated_covers)

        return self

    def get_input_test_arrays(self) -> List[NDArray[Any]]:
        data = [load_array(ipt.test_tensor.download().path) for ipt in self.inputs]
        assert all(isinstance(d, np.ndarray) for d in data)
        return data

    def get_output_test_arrays(self) -> List[NDArray[Any]]:
        data = [load_array(out.test_tensor.download().path) for out in self.outputs]
        assert all(isinstance(d, np.ndarray) for d in data)
        return data

    @staticmethod
    def get_batch_size(tensor_sizes: Mapping[TensorId, Mapping[AxisId, int]]) -> int:
        batch_size = 1
        tensor_with_batchsize: Optional[TensorId] = None
        for tid in tensor_sizes:
            for aid, s in tensor_sizes[tid].items():
                if aid != BATCH_AXIS_ID or s == 1 or s == batch_size:
                    continue

                if batch_size != 1:
                    assert tensor_with_batchsize is not None
                    raise ValueError(
                        f"batch size mismatch for tensors '{tensor_with_batchsize}' ({batch_size}) and '{tid}' ({s})"
                    )

                batch_size = s
                tensor_with_batchsize = tid

        return batch_size

    def get_output_tensor_sizes(
        self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]
    ) -> Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]:
        """Returns the tensor output sizes for given **input_sizes**.
        Only if **input_sizes** specifies a valid input shape are the returned
        output sizes exact; otherwise they may be larger than the actual
        (valid) output sizes."""
        batch_size = self.get_batch_size(input_sizes)
        ns = self.get_ns(input_sizes)

        tensor_sizes = self.get_tensor_sizes(ns, batch_size=batch_size)
        return tensor_sizes.outputs

    def get_ns(self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]):
        """Get parameter `n` for each parameterized axis
        such that the valid input size is >= the given input size"""
        ret: Dict[Tuple[TensorId, AxisId], ParameterizedSize_N] = {}
        axes = {t.id: {a.id: a for a in t.axes} for t in self.inputs}
        for tid in input_sizes:
            for aid, s in input_sizes[tid].items():
                size_descr = axes[tid][aid].size
                if isinstance(size_descr, ParameterizedSize):
                    ret[(tid, aid)] = size_descr.get_n(s)
                elif size_descr is None or isinstance(size_descr, (int, SizeReference)):
                    pass
                else:
                    assert_never(size_descr)

        return ret

    def get_tensor_sizes(
        self, ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], batch_size: int
    ) -> _TensorSizes:
        axis_sizes = self.get_axis_sizes(ns, batch_size=batch_size)
        return _TensorSizes(
            {
                t: {
                    aa: axis_sizes.inputs[(tt, aa)]
                    for tt, aa in axis_sizes.inputs
                    if tt == t
                }
                for t in {tt for tt, _ in axis_sizes.inputs}
            },
            {
                t: {
                    aa: axis_sizes.outputs[(tt, aa)]
                    for tt, aa in axis_sizes.outputs
                    if tt == t
                }
                for t in {tt for tt, _ in axis_sizes.outputs}
            },
        )

    def get_axis_sizes(
        self,
        ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N],
        batch_size: Optional[int] = None,
        *,
        max_input_shape: Optional[Mapping[Tuple[TensorId, AxisId], int]] = None,
    ) -> _AxisSizes:
        """Determine input and output block shape for scale factors **ns**
        of parameterized input sizes.

        Args:
            ns: Scale factor `n` for each axis (keyed by (tensor_id, axis_id))
                that is parameterized as `size = min + n * step`.
            batch_size: The desired size of the batch dimension.
                If given, **batch_size** overwrites any batch size present in
                **max_input_shape**. Default 1.
            max_input_shape: Limits the derived block shapes.
                Each axis for which the input size, parameterized by `n`, is larger
                than **max_input_shape** is set to the minimal value `n_min` for which
                this is still true.
                Use this for small input samples or large values of **ns**.
                Or simply whenever you know the full input shape.

        Returns:
            Resolved axis sizes for model inputs and outputs.
        """
        max_input_shape = max_input_shape or {}
        if batch_size is None:
            for (_t_id, a_id), s in max_input_shape.items():
                if a_id == BATCH_AXIS_ID:
                    batch_size = s
                    break
            else:
                batch_size = 1

        all_axes = {
            t.id: {a.id: a for a in t.axes} for t in chain(self.inputs, self.outputs)
        }

        inputs: Dict[Tuple[TensorId, AxisId], int] = {}
        outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]] = {}

        def get_axis_size(a: Union[InputAxis, OutputAxis]):
            if isinstance(a, BatchAxis):
                if (t_descr.id, a.id) in ns:
                    logger.warning(
                        "Ignoring unexpected size increment factor (n) for batch axis"
                        + " of tensor '{}'.",
                        t_descr.id,
                    )
                return batch_size
            elif isinstance(a.size, int):
                if (t_descr.id, a.id) in ns:
                    logger.warning(
                        "Ignoring unexpected size increment factor (n) for fixed size"
                        + " axis '{}' of tensor '{}'.",
                        a.id,
                        t_descr.id,
                    )
                return a.size
            elif isinstance(a.size, ParameterizedSize):
                if (t_descr.id, a.id) not in ns:
                    raise ValueError(
                        "Size increment factor (n) missing for parameterized axis"
                        + f" '{a.id}' of tensor '{t_descr.id}'."
                    )
                n = ns[(t_descr.id, a.id)]
                s_max = max_input_shape.get((t_descr.id, a.id))
                if s_max is not None:
                    n = min(n, a.size.get_n(s_max))

                return a.size.get_size(n)

            elif isinstance(a.size, SizeReference):
                if (t_descr.id, a.id) in ns:
                    logger.warning(
                        "Ignoring unexpected size increment factor (n) for axis '{}'"
                        + " of tensor '{}' with size reference.",
                        a.id,
                        t_descr.id,
                    )
                assert not isinstance(a, BatchAxis)
                ref_axis = all_axes[a.size.tensor_id][a.size.axis_id]
                assert not isinstance(ref_axis, BatchAxis)
                ref_key = (a.size.tensor_id, a.size.axis_id)
                ref_size = inputs.get(ref_key, outputs.get(ref_key))
                assert ref_size is not None, ref_key
                assert not isinstance(ref_size, _DataDepSize), ref_key
                return a.size.get_size(
                    axis=a,
                    ref_axis=ref_axis,
                    ref_size=ref_size,
                )
            elif isinstance(a.size, DataDependentSize):
                if (t_descr.id, a.id) in ns:
                    logger.warning(
                        "Ignoring unexpected increment factor (n) for data dependent"
                        + " size axis '{}' of tensor '{}'.",
                        a.id,
                        t_descr.id,
                    )
                return _DataDepSize(a.size.min, a.size.max)
            else:
                assert_never(a.size)

        # first resolve all input sizes, except the `SizeReference` ones
        for t_descr in self.inputs:
            for a in t_descr.axes:
                if not isinstance(a.size, SizeReference):
                    s = get_axis_size(a)
                    assert not isinstance(s, _DataDepSize)
                    inputs[t_descr.id, a.id] = s

        # resolve all other input axis sizes
        for t_descr in self.inputs:
            for a in t_descr.axes:
                if isinstance(a.size, SizeReference):
                    s = get_axis_size(a)
                    assert not isinstance(s, _DataDepSize)
                    inputs[t_descr.id, a.id] = s

        # resolve all output axis sizes
        for t_descr in self.outputs:
            for a in t_descr.axes:
                assert not isinstance(a.size, ParameterizedSize)
                s = get_axis_size(a)
                outputs[t_descr.id, a.id] = s

        return _AxisSizes(inputs=inputs, outputs=outputs)

    @model_validator(mode="before")
    @classmethod
    def _convert(cls, data: Dict[str, Any]) -> Dict[str, Any]:
        cls.convert_from_old_format_wo_validation(data)
        return data

    @classmethod
    def convert_from_old_format_wo_validation(cls, data: Dict[str, Any]) -> None:
        """Convert metadata following an older format version to this class's format
        without validating the result.
        """
        if (
            data.get("type") == "model"
            and isinstance(fv := data.get("format_version"), str)
            and fv.count(".") == 2
        ):
            fv_parts = fv.split(".")
            if any(not p.isdigit() for p in fv_parts):
                return

            fv_tuple = tuple(map(int, fv_parts))

            assert cls.implemented_format_version_tuple[0:2] == (0, 5)
            if fv_tuple[:2] in ((0, 3), (0, 4)):
                m04 = _ModelDescr_v0_4.load(data)
                if isinstance(m04, InvalidDescr):
                    try:
                        updated = _model_conv.convert_as_dict(
                            m04  # pyright: ignore[reportArgumentType]
                        )
                    except Exception as e:
                        logger.error(
                            "Failed to convert from invalid model 0.4 description."
                            + f"\nerror: {e}"
                            + "\nProceeding with model 0.5 validation without conversion."
                        )
                        updated = None
                else:
                    updated = _model_conv.convert_as_dict(m04)

                if updated is not None:
                    data.clear()
                    data.update(updated)

            elif fv_tuple[:2] == (0, 5):
                # bump patch version
                data["format_version"] = cls.implemented_format_version
```
Specification of the fields used in a bioimage.io-compliant RDF to describe AI models with pretrained weights. These fields are typically stored in a YAML file which we call a model resource description file (model RDF).
`id`: bioimage.io-wide unique resource identifier assigned by bioimage.io; version unspecific.
`documentation`: ∈📦 URL or relative path to a markdown file with additional documentation. The recommended documentation file name is `README.md`. An `.md` suffix is mandatory. The documentation should include a '#[#] Validation' (sub)section with details on how to quantitatively validate the model on unseen data.
`inputs`: Describes the input tensors expected by this model.
`name`: A human-readable name of this model. It should be no longer than 64 characters and may only contain letters, numbers, underscores, minus signs, parentheses and spaces. We recommend choosing a name that refers to the model's task and image modality.
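These constraints translate directly from the field's annotation (`MinLen(5)`, `MaxLen(128)`, `RestrictCharacters(...)` in the source above). A small sketch re-implementing the hard limits; the helper name is hypothetical and the soft 64-character warning is omitted:

```python
import string

_ALLOWED = set(string.ascii_letters + string.digits + "_+- ()")

def is_valid_model_name(name: str) -> bool:
    # Hard limits from the annotation: length 5..128, restricted character set.
    return 5 <= len(name) <= 128 and set(name) <= _ALLOWED

assert is_valid_model_name("UNet2D nuclei (broad)")
assert not is_valid_model_name("bad/name")  # "/" is not an allowed character
```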
`outputs`: Describes the output tensors.
`packaged_by`: The persons that have packaged and uploaded this model. Only required if those persons differ from the `authors`.
`parent`: The model from which this model is derived, e.g. by fine-tuning the weights.
`run_mode`: Custom run mode for this model: for more complex prediction procedures like test-time data augmentation that currently cannot be expressed in the specification. No standard run modes are defined yet.
`training_data`: The dataset used to train this model.
`weights`: The weights for this model. Weights can be given for different formats, but should otherwise be equivalent. The available weight formats determine which consumers can use this model.
```python
@staticmethod
def get_batch_size(
    tensor_sizes: Mapping[TensorId, Mapping[AxisId, int]]
) -> int: ...
```
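`get_batch_size` scans all batch axes in **tensor_sizes** and returns their common size, raising a `ValueError` on a mismatch (see the source above). As a static method on plain mappings it can be tried in isolation; the tensor and axis ids below are made up:

```python
from bioimageio.spec.model.v0_5 import AxisId, ModelDescr, TensorId

sizes = {
    TensorId("raw"): {AxisId("batch"): 4, AxisId("x"): 512},
    TensorId("mask"): {AxisId("batch"): 4, AxisId("y"): 256},
}
assert ModelDescr.get_batch_size(sizes) == 4
# Differing batch sizes, e.g. 4 for "raw" and 2 for "mask", raise a ValueError.
```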
```python
def get_output_tensor_sizes(
    self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]
) -> Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]: ...
```
Returns the tensor output sizes for given **input_sizes**. Only if **input_sizes** specifies a valid input shape are the returned output sizes exact; otherwise they may be larger than the actual (valid) output sizes.
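A hedged usage sketch, assuming a model RDF at a placeholder path whose single input tensor is named `raw` with `batch`/`x`/`y` axes:

```python
from bioimageio.spec import load_description
from bioimageio.spec.model.v0_5 import AxisId, ModelDescr, TensorId

model = load_description("rdf.yaml")  # placeholder source
assert isinstance(model, ModelDescr)

output_sizes = model.get_output_tensor_sizes(
    {TensorId("raw"): {AxisId("batch"): 1, AxisId("x"): 512, AxisId("y"): 512}}
)
# -> e.g. {TensorId("prediction"): {AxisId("batch"): 1, AxisId("x"): 512, ...}}
```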
```python
def get_ns(self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]): ...
```
Get parameter `n` for each parameterized axis such that the valid input size is >= the given input size.
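A worked example of that relation, assuming (per the docstring) that `ParameterizedSize.get_n` returns the smallest such `n`; the min/step values are illustrative:

```python
from bioimageio.spec.model.v0_5 import ParameterizedSize

size = ParameterizedSize(min=64, step=16)
n = size.get_n(70)             # smallest n with 64 + n * 16 >= 70, i.e. n == 1
assert size.get_size(n) == 80  # 64 + 1 * 16: the smallest valid size >= 70
```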
```python
def get_tensor_sizes(
    self,
    ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N],
    batch_size: int,
) -> _TensorSizes: ...
```
```python
def get_axis_sizes(
    self,
    ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N],
    batch_size: Optional[int] = None,
    *,
    max_input_shape: Optional[Mapping[Tuple[TensorId, AxisId], int]] = None,
) -> _AxisSizes: ...
```
Determine input and output block shape for scale factors **ns** of parameterized input sizes.

Arguments:
- `ns`: Scale factor `n` for each axis (keyed by `(tensor_id, axis_id)`) that is parameterized as `size = min + n * step`.
- `batch_size`: The desired size of the batch dimension. If given, **batch_size** overwrites any batch size present in **max_input_shape**. Default 1.
- `max_input_shape`: Limits the derived block shapes. Each axis for which the input size, parameterized by `n`, is larger than **max_input_shape** is set to the minimal value `n_min` for which this is still true. Use this for small input samples or large values of **ns**, or simply whenever you know the full input shape.

Returns:
Resolved axis sizes for model inputs and outputs.
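Taken together, the sizing helpers compose exactly as in `get_output_tensor_sizes` above; a sketch continuing the earlier example with the same hypothetical `model` and ids:

```python
input_sizes = {TensorId("raw"): {AxisId("batch"): 2, AxisId("x"): 300}}

batch_size = model.get_batch_size(input_sizes)  # -> 2
ns = model.get_ns(input_sizes)                  # n per parameterized axis
tensor_sizes = model.get_tensor_sizes(ns, batch_size=batch_size)

print(tensor_sizes.inputs)   # valid (possibly enlarged) input sizes
print(tensor_sizes.outputs)  # corresponding output sizes
```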
```python
@classmethod
def convert_from_old_format_wo_validation(cls, data: Dict[str, Any]) -> None: ...
```
Convert metadata following an older format version to this class's format without validating the result.
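A hedged sketch of the in-place conversion; the dict below is deliberately incomplete (a real call needs a full model RDF as `data`):

```python
from bioimageio.spec.model.v0_5 import ModelDescr

data = {
    "type": "model",
    "format_version": "0.4.10",
    # ... all remaining model 0.4 fields ...
}
ModelDescr.convert_from_old_format_wo_validation(data)
# For a complete 0.4 RDF, `data` is rewritten in place to the 0.5 layout;
# for a 0.5 RDF, only the patch version is bumped to
# ModelDescr.implemented_format_version.
```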
```python
def wrapped_model_post_init(self: BaseModel, context: Any, /) -> None:
    """We need to both initialize private attributes and call the user-defined model_post_init
    method.
    """
    init_private_attributes(self, context)
    original_model_post_init(self, context)
```
We need to both initialize private attributes and call the user-defined model_post_init method.
Union of any released model description