bioimageio.core
"""
.. include:: ../../README.md
"""
# ruff: noqa: E402

__version__ = "0.9.4"
from loguru import logger

# Disable loguru records for the "bioimageio.core" name. This statement sits
# between imports, which is why E402 (import not at top of file) is
# suppressed for this module.
logger.disable("bioimageio.core")

# Names re-exported from bioimageio.spec (they are listed in `__all__` below).
from bioimageio.spec import (
    ValidationSummary,
    build_description,
    dump_description,
    load_dataset_description,
    load_description,
    load_description_and_validate_format_only,
    load_model_description,
    save_bioimageio_package,
    save_bioimageio_package_as_folder,
    save_bioimageio_yaml_only,
    validate_format,
)

# Submodules exposed as attributes of the package.
from . import (
    axis,
    block_meta,
    cli,
    commands,
    common,
    digest_spec,
    io,
    model_adapters,
    prediction,
    proc_ops,
    proc_setup,
    sample,
    stat_calculators,
    stat_measures,
    tensor,
)
from ._prediction_pipeline import PredictionPipeline, create_prediction_pipeline
from ._resource_tests import (
    enable_determinism,
    load_description_and_test,
    test_description,
    test_model,
)
from ._settings import settings
from .axis import Axis, AxisId
from .backends import create_model_adapter
from .block_meta import BlockMeta
from .common import MemberId
from .prediction import predict, predict_many
from .sample import Sample
from .stat_calculators import compute_dataset_measures
from .stat_measures import Stat
from .tensor import Tensor
from .weight_converters import add_weights

# aliases
test_resource = test_description
"""alias of `test_description`"""
load_resource = load_description
"""alias of `load_description`"""
load_model = load_model_description
"""alias of `load_model_description`"""

# Public API of the package: submodules, re-exported bioimageio.spec
# functions, core classes/functions and the aliases defined above.
__all__ = [
    "__version__",
    "add_weights",
    "axis",
    "Axis",
    "AxisId",
    "block_meta",
    "BlockMeta",
    "build_description",
    "cli",
    "commands",
    "common",
    "compute_dataset_measures",
    "create_model_adapter",
    "create_prediction_pipeline",
    "digest_spec",
    "dump_description",
    "enable_determinism",
    "io",
    "load_dataset_description",
    "load_description_and_test",
    "load_description_and_validate_format_only",
    "load_description",
    "load_model_description",
    "load_model",
    "load_resource",
    "MemberId",
    "model_adapters",
    "predict_many",
    "predict",
    "prediction",
    "PredictionPipeline",
    "proc_ops",
    "proc_setup",
    "sample",
    "Sample",
    "save_bioimageio_package_as_folder",
    "save_bioimageio_package",
    "save_bioimageio_yaml_only",
    "settings",
    "stat_calculators",
    "stat_measures",
    "Stat",
    "tensor",
    "Tensor",
    "test_description",
    "test_model",
    "test_resource",
    "validate_format",
    "ValidationSummary",
]
def add_weights(
    model_descr: ModelDescr,
    *,
    output_path: DirectoryPath,
    source_format: Optional[WeightsFormat] = None,
    target_format: Optional[WeightsFormat] = None,
    verbose: bool = False,
    allow_tracing: bool = True,
) -> Union[ModelDescr, InvalidDescr]:
    """Convert model weights to other formats and add them to the model description

    Args:
        model_descr: The model description to convert weights for.
            The given object is not edited; the model is saved to `output_path`,
            reloaded from there, and the reloaded copy is updated.
        output_path: Path to save updated model package to.
        source_format: convert from a specific weights format.
            Default: choose automatically from any available.
        target_format: convert to a specific weights format.
            Default: attempt to convert to any missing format.
        verbose: log more (error) output
        allow_tracing: If 'torchscript' weights are still missing after the
            conversion from 'pytorch_state_dict' without tracing, attempt the
            conversion again with tracing.

    Returns:
        A (potentially invalid) model copy stored at `output_path` with added weights if any conversion was possible.
        If no weights could be added, the reloaded (untested) model copy is returned.

    Raises:
        TypeError: If `model_descr` is not a `ModelDescr` instance.

    """
    if not isinstance(model_descr, ModelDescr):
        # give a more helpful message for valid model descriptions of another
        # (unsupported) format version
        if model_descr.type == "model" and not isinstance(model_descr, InvalidDescr):
            raise TypeError(
                f"Model format {model_descr.format} is not supported, please update"
                + f" model to format {ModelDescr.implemented_format_version} first."
            )

        raise TypeError(type(model_descr))

    # save model to local folder
    output_path = save_bioimageio_package_as_folder(
        model_descr, output_path=output_path
    )
    # reload from local folder to make sure we do not edit the given model
    model_descr = load_model_description(
        output_path, perform_io_checks=False, format_version="latest"
    )

    # `available`/`missing` are updated after every successful conversion so that
    # later steps can build on weights produced by earlier ones
    if source_format is None:
        available = set(model_descr.weights.available_formats)
    else:
        available = {source_format}

    if target_format is None:
        missing = set(model_descr.weights.missing_formats)
    else:
        missing = {target_format}

    # remember the initial state to detect whether anything was added
    originally_missing = set(missing)

    if "pytorch_state_dict" in available and "torchscript" in missing:
        logger.info(
            "Attempting to convert 'pytorch_state_dict' weights to 'torchscript'."
        )
        from .pytorch_to_torchscript import convert

        try:
            torchscript_weights_path = output_path / "weights_torchscript.pt"
            model_descr.weights.torchscript = convert(
                model_descr,
                output_path=torchscript_weights_path,
                use_tracing=False,
            )
        except Exception as e:
            # a failed conversion is logged, not raised
            if verbose:
                traceback.print_exception(type(e), e, e.__traceback__)

            logger.error(e)
        else:
            available.add("torchscript")
            missing.discard("torchscript")

    # second attempt with tracing; only reached if 'torchscript' is still missing
    if allow_tracing and "pytorch_state_dict" in available and "torchscript" in missing:
        logger.info(
            "Attempting to convert 'pytorch_state_dict' weights to 'torchscript' by tracing."
        )
        from .pytorch_to_torchscript import convert

        try:
            torchscript_weights_path = output_path / "weights_torchscript_traced.pt"

            model_descr.weights.torchscript = convert(
                model_descr,
                output_path=torchscript_weights_path,
                use_tracing=True,
            )
        except Exception as e:
            if verbose:
                traceback.print_exception(type(e), e, e.__traceback__)

            logger.error(e)
        else:
            available.add("torchscript")
            missing.discard("torchscript")

    if "torchscript" in available and "onnx" in missing:
        logger.info("Attempting to convert 'torchscript' weights to 'onnx'.")
        from .torchscript_to_onnx import convert

        try:
            onnx_weights_path = output_path / "weights.onnx"
            model_descr.weights.onnx = convert(
                model_descr,
                output_path=onnx_weights_path,
            )
        except Exception as e:
            if verbose:
                traceback.print_exception(type(e), e, e.__traceback__)

            logger.error(e)
        else:
            available.add("onnx")
            missing.discard("onnx")

    # only reached with 'onnx' still missing, i.e. if the torchscript route
    # above was not taken or failed
    if "pytorch_state_dict" in available and "onnx" in missing:
        logger.info("Attempting to convert 'pytorch_state_dict' weights to 'onnx'.")
        from .pytorch_to_onnx import convert

        try:
            onnx_weights_path = output_path / "weights.onnx"

            model_descr.weights.onnx = convert(
                model_descr,
                output_path=onnx_weights_path,
                verbose=verbose,
            )
        except Exception as e:
            if verbose:
                traceback.print_exception(type(e), e, e.__traceback__)

            logger.error(e)
        else:
            available.add("onnx")
            missing.discard("onnx")

    if missing:
        logger.warning(
            f"Converting from any of the available weights formats {available} to any"
            + f" of {missing} failed or is not yet implemented. Please create an issue"
            + " at https://github.com/bioimage-io/core-bioimage-io-python/issues/new/choose"
            + " if you would like bioimageio.core to support a particular conversion."
        )

    if originally_missing == missing:
        # nothing was converted: return the reloaded copy without re-saving/testing
        logger.warning("failed to add any converted weights")
        return model_descr
    else:
        logger.info("added weights formats {}", originally_missing - missing)
        # resave model with updated rdf.yaml
        _ = save_bioimageio_package_as_folder(model_descr, output_path=output_path)
        tested_model_descr = load_description_and_test(
            model_descr, format_version="latest", expected_type="model"
        )
        if not isinstance(tested_model_descr, ModelDescr):
            logger.error(
                f"The updated model description at {output_path} did not pass testing."
            )

        return tested_model_descr
Convert model weights to other formats and add them to the model description
Arguments:
- output_path: Path to save updated model package to.
- source_format: convert from a specific weights format. Default: choose automatically from any available.
- target_format: convert to a specific weights format. Default: attempt to convert to any missing format.
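- allow_tracing: if the conversion from 'pytorch_state_dict' to 'torchscript' without tracing did not succeed, attempt it again with tracing. (Note: the function signature has no `devices` argument.)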
- verbose: log more (error) output
Returns:
A (potentially invalid) model copy stored at
output_path
with added weights if any conversion was possible.
@dataclass
class Axis:
    """An axis given by its `id` and its `type`.

    For axes of type "batch" or "channel" the `id` is overwritten after
    initialization with `AxisId("batch")` or `AxisId("channel")`.
    """

    id: AxisId
    type: Literal["batch", "channel", "index", "space", "time"]

    def __post_init__(self):
        # batch and channel axes always carry the id matching their type
        for fixed_id in ("batch", "channel"):
            if self.type == fixed_id:
                self.id = AxisId(fixed_id)
                break

    @classmethod
    def create(cls, axis: AxisLike) -> Axis:
        """Build an `Axis` from an axis-like object.

        Instances of this class are returned unchanged. For an `AxisId`/`str`
        the axis type is guessed from its string form. For any other object
        its `type` and `id` attributes are used where present; otherwise the
        type is guessed from `str(axis)` and the object itself is used as id.
        """
        if isinstance(axis, cls):
            return axis

        is_plain_id = isinstance(axis, (AxisId, str))

        if is_plain_id or not hasattr(axis, "type"):
            kind = _guess_axis_type(str(axis))
        else:
            kind = axis.type

        if not is_plain_id and hasattr(axis, "id"):
            ident = axis.id
        else:
            ident = axis

        return Axis(id=AxisId(ident), type=kind)
68 @classmethod 69 def create(cls, axis: AxisLike) -> Axis: 70 if isinstance(axis, cls): 71 return axis 72 73 if isinstance(axis, (AxisId, str)): 74 axis_id = axis 75 axis_type = _guess_axis_type(str(axis)) 76 else: 77 if hasattr(axis, "type"): 78 axis_type = axis.type 79 else: 80 axis_type = _guess_axis_type(str(axis)) 81 82 if hasattr(axis, "id"): 83 axis_id = axis.id 84 else: 85 axis_id = axis 86 87 return Axis(id=AxisId(axis_id), type=axis_type)
class AxisId(LowerCaseIdentifier):
    # Root model of this identifier type: `LowerCaseIdentifierAnno` annotated
    # with `MaxLen(16)` and with `_normalize_axis_id` as an after-validator.
    root_model: ClassVar[Type[RootModel[Any]]] = RootModel[
        Annotated[
            LowerCaseIdentifierAnno,
            MaxLen(16),
            AfterValidator(_normalize_axis_id),
        ]
    ]
str(object='') -> str str(bytes_or_buffer[, encoding[, errors]]) -> str
Create a new string object from the given object. If encoding or errors is specified, then the object must expose a data buffer that will be decoded using the given encoding and error handler. Otherwise, returns the result of object.__str__() (if defined) or repr(object). encoding defaults to sys.getdefaultencoding(). errors defaults to 'strict'.
@dataclass(frozen=True)
class BlockMeta:
    """Block meta data of a sample member (a tensor in a sample)

    Figure for illustration:
    The first 2d block (dashed) of a sample member (**bold**).
    The inner slice (thin) is expanded by a halo in both dimensions on both sides.
    The outer slice reaches from the sample member origin (0, 0) to the right halo point.

    ```terminal
    ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┐
    ╷ halo(left)                         ╷
    ╷                                    ╷
    ╷  (0, 0)┏━━━━━━━━━━━━━━━━━┯━━━━━━━━━┯━━━➔
    ╷        ┃                 │         ╷  sample member
    ╷        ┃      inner      │         ╷
    ╷        ┃   (and outer)   │  outer  ╷
    ╷        ┃      slice      │  slice  ╷
    ╷        ┃                 │         ╷
    ╷        ┣─────────────────┘         ╷
    ╷        ┃   outer slice             ╷
    ╷        ┃               halo(right) ╷
    └ ─ ─ ─ ─┃─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─┘
             ⬇
    ```

    note:
    - Inner and outer slices are specified in sample member coordinates.
    - The outer_slice of a block at the sample edge may overlap by more than the
        halo with the neighboring block (the inner slices will not overlap though).
    - Axes of the inner slice without an entry in `halo` are treated as having
        no halo (0 on both sides).

    """

    sample_shape: PerAxis[int]
    """the axis sizes of the whole (unblocked) sample"""

    inner_slice: PerAxis[SliceInfo]
    """inner region (without halo) wrt the sample"""

    halo: PerAxis[Halo]
    """halo enlarging the inner region to the block's sizes"""

    block_index: BlockIndex
    """the i-th block of the sample"""

    blocks_in_sample: TotalNumberOfBlocks
    """total number of blocks in the sample"""

    @cached_property
    def shape(self) -> PerAxis[int]:
        """axis lengths of the block"""
        return Frozen(
            {
                a: s.stop - s.start + (sum(self.halo[a]) if a in self.halo else 0)
                for a, s in self.inner_slice.items()
            }
        )

    @cached_property
    def padding(self) -> PerAxis[PadWidth]:
        """padding to realize the halo at the sample edge
        where we cannot simply enlarge the inner slice"""
        return Frozen(
            {
                a: PadWidth(
                    (
                        self.halo[a].left
                        - (self.inner_slice[a].start - self.outer_slice[a].start)
                        if a in self.halo
                        else 0
                    ),
                    (
                        self.halo[a].right
                        - (self.outer_slice[a].stop - self.inner_slice[a].stop)
                        if a in self.halo
                        else 0
                    ),
                )
                for a in self.inner_slice
            }
        )

    @cached_property
    def outer_slice(self) -> PerAxis[SliceInfo]:
        """slice of the outer block (without padding) wrt the sample"""
        return Frozen(
            {
                a: SliceInfo(
                    max(
                        0,
                        min(
                            self.inner_slice[a].start
                            - (self.halo[a].left if a in self.halo else 0),
                            self.sample_shape[a]
                            - self.inner_shape[a]
                            - (self.halo[a].left if a in self.halo else 0),
                        ),
                    ),
                    min(
                        self.sample_shape[a],
                        self.inner_slice[a].stop
                        + (self.halo[a].right if a in self.halo else 0),
                    ),
                )
                for a in self.inner_slice
            }
        )

    @cached_property
    def inner_shape(self) -> PerAxis[int]:
        """axis lengths of the inner region (without halo)"""
        return Frozen({a: s.stop - s.start for a, s in self.inner_slice.items()})

    @cached_property
    def local_slice(self) -> PerAxis[SliceInfo]:
        """inner slice wrt the block, **not** the sample"""
        slices = {}
        for a in self.inner_slice:
            # an axis without halo entry has no halo, as in `shape`,
            # `padding` and `outer_slice`
            left = self.halo[a].left if a in self.halo else 0
            slices[a] = SliceInfo(left, left + self.inner_shape[a])

        return Frozen(slices)

    @property
    def dims(self) -> Collection[AxisId]:
        """the axis ids of the block"""
        return set(self.inner_shape)

    @property
    def tagged_shape(self) -> PerAxis[int]:
        """alias for shape"""
        return self.shape

    @property
    def inner_slice_wo_overlap(self):
        """subslice of the inner slice, such that all `inner_slice_wo_overlap` can be
        stitched together trivially to form the original sample.

        This can also be used to calculate statistics
        without overrepresenting block edge regions."""
        # TODO: update inner_slice_wo_overlap when adding block overlap
        return self.inner_slice

    def __post_init__(self):
        # freeze mutable inputs
        # (`object.__setattr__` is required because the dataclass is frozen)
        if not isinstance(self.sample_shape, Frozen):
            object.__setattr__(self, "sample_shape", Frozen(self.sample_shape))

        if not isinstance(self.inner_slice, Frozen):
            object.__setattr__(self, "inner_slice", Frozen(self.inner_slice))

        if not isinstance(self.halo, Frozen):
            object.__setattr__(self, "halo", Frozen(self.halo))

        assert all(a in self.sample_shape for a in self.inner_slice), (
            "block has axes not present in sample"
        )

        assert all(a in self.inner_slice for a in self.halo), (
            "halo has axes not present in block"
        )

        if any(s > self.sample_shape[a] for a, s in self.shape.items()):
            logger.warning(
                "block {} larger than sample {}", self.shape, self.sample_shape
            )

    def get_transformed(
        self, new_axes: PerAxis[Union[LinearAxisTransform, int]]
    ) -> Self:
        """Create the block meta data for a set of new axes.

        Args:
            new_axes: For each new axis either
                - an `int`: the axis gets this fixed size, an inner slice
                  covering it completely and no halo, or
                - a transform: sample size and inner slice bounds of the
                  referenced axis (`trf.axis`) are mapped with `trf.compute`;
                  the halo of the referenced axis is taken over as is
                  (no halo if the referenced axis has none).

        Returns:
            A new instance with the same `block_index` and `blocks_in_sample`.
        """
        return self.__class__(
            sample_shape={
                a: (
                    trf
                    if isinstance(trf, int)
                    else trf.compute(self.sample_shape[trf.axis])
                )
                for a, trf in new_axes.items()
            },
            inner_slice={
                a: (
                    SliceInfo(0, trf)
                    if isinstance(trf, int)
                    else SliceInfo(
                        trf.compute(self.inner_slice[trf.axis].start),
                        trf.compute(self.inner_slice[trf.axis].stop),
                    )
                )
                for a, trf in new_axes.items()
            },
            halo={
                a: (
                    Halo(self.halo[trf.axis].left, self.halo[trf.axis].right)
                    # a referenced axis without halo entry has no halo
                    if not isinstance(trf, int) and trf.axis in self.halo
                    else Halo(0, 0)
                )
                for a, trf in new_axes.items()
            },
            block_index=self.block_index,
            blocks_in_sample=self.blocks_in_sample,
        )
Block meta data of a sample member (a tensor in a sample)
Figure for illustration: The first 2d block (dashed) of a sample member (bold). The inner slice (thin) is expanded by a halo in both dimensions on both sides. The outer slice reaches from the sample member origin (0, 0) to the right halo point.
┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┐
╷ halo(left) ╷
╷ ╷
╷ (0, 0)┏━━━━━━━━━━━━━━━━━┯━━━━━━━━━┯━━━➔
╷ ┃ │ ╷ sample member
╷ ┃ inner │ ╷
╷ ┃ (and outer) │ outer ╷
╷ ┃ slice │ slice ╷
╷ ┃ │ ╷
╷ ┣─────────────────┘ ╷
╷ ┃ outer slice ╷
╷ ┃ halo(right) ╷
└ ─ ─ ─ ─┃─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─┘
⬇
note:
- Inner and outer slices are specified in sample member coordinates.
- The outer_slice of a block at the sample edge may overlap by more than the halo with the neighboring block (the inner slices will not overlap though).
inner region (without halo) wrt the sample
halo enlarging the inner region to the block's sizes
94 @cached_property 95 def shape(self) -> PerAxis[int]: 96 """axis lengths of the block""" 97 return Frozen( 98 { 99 a: s.stop - s.start + (sum(self.halo[a]) if a in self.halo else 0) 100 for a, s in self.inner_slice.items() 101 } 102 )
axis lengths of the block
104 @cached_property 105 def padding(self) -> PerAxis[PadWidth]: 106 """padding to realize the halo at the sample edge 107 where we cannot simply enlarge the inner slice""" 108 return Frozen( 109 { 110 a: PadWidth( 111 ( 112 self.halo[a].left 113 - (self.inner_slice[a].start - self.outer_slice[a].start) 114 if a in self.halo 115 else 0 116 ), 117 ( 118 self.halo[a].right 119 - (self.outer_slice[a].stop - self.inner_slice[a].stop) 120 if a in self.halo 121 else 0 122 ), 123 ) 124 for a in self.inner_slice 125 } 126 )
padding to realize the halo at the sample edge where we cannot simply enlarge the inner slice
128 @cached_property 129 def outer_slice(self) -> PerAxis[SliceInfo]: 130 """slice of the outer block (without padding) wrt the sample""" 131 return Frozen( 132 { 133 a: SliceInfo( 134 max( 135 0, 136 min( 137 self.inner_slice[a].start 138 - (self.halo[a].left if a in self.halo else 0), 139 self.sample_shape[a] 140 - self.inner_shape[a] 141 - (self.halo[a].left if a in self.halo else 0), 142 ), 143 ), 144 min( 145 self.sample_shape[a], 146 self.inner_slice[a].stop 147 + (self.halo[a].right if a in self.halo else 0), 148 ), 149 ) 150 for a in self.inner_slice 151 } 152 )
slice of the outer block (without padding) wrt the sample
154 @cached_property 155 def inner_shape(self) -> PerAxis[int]: 156 """axis lengths of the inner region (without halo)""" 157 return Frozen({a: s.stop - s.start for a, s in self.inner_slice.items()})
axis lengths of the inner region (without halo)
159 @cached_property 160 def local_slice(self) -> PerAxis[SliceInfo]: 161 """inner slice wrt the block, **not** the sample""" 162 return Frozen( 163 { 164 a: SliceInfo( 165 self.halo[a].left, 166 self.halo[a].left + self.inner_shape[a], 167 ) 168 for a in self.inner_slice 169 } 170 )
inner slice wrt the block, not the sample
176 @property 177 def tagged_shape(self) -> PerAxis[int]: 178 """alias for shape""" 179 return self.shape
alias for shape
181 @property 182 def inner_slice_wo_overlap(self): 183 """subslice of the inner slice, such that all `inner_slice_wo_overlap` can be 184 stiched together trivially to form the original sample. 185 186 This can also be used to calculate statistics 187 without overrepresenting block edge regions.""" 188 # TODO: update inner_slice_wo_overlap when adding block overlap 189 return self.inner_slice
subslice of the inner slice, such that all inner_slice_wo_overlap
can be
stitched together trivially to form the original sample.
This can also be used to calculate statistics without overrepresenting block edge regions.
215 def get_transformed( 216 self, new_axes: PerAxis[Union[LinearAxisTransform, int]] 217 ) -> Self: 218 return self.__class__( 219 sample_shape={ 220 a: ( 221 trf 222 if isinstance(trf, int) 223 else trf.compute(self.sample_shape[trf.axis]) 224 ) 225 for a, trf in new_axes.items() 226 }, 227 inner_slice={ 228 a: ( 229 SliceInfo(0, trf) 230 if isinstance(trf, int) 231 else SliceInfo( 232 trf.compute(self.inner_slice[trf.axis].start), 233 trf.compute(self.inner_slice[trf.axis].stop), 234 ) 235 ) 236 for a, trf in new_axes.items() 237 }, 238 halo={ 239 a: ( 240 Halo(0, 0) 241 if isinstance(trf, int) 242 else Halo(self.halo[trf.axis].left, self.halo[trf.axis].right) 243 ) 244 for a, trf in new_axes.items() 245 }, 246 block_index=self.block_index, 247 blocks_in_sample=self.blocks_in_sample, 248 )
def build_description(
    content: BioimageioYamlContentView,
    /,
    *,
    context: Optional[ValidationContext] = None,
    format_version: Union[FormatVersionPlaceholder, str] = DISCOVER,
) -> Union[ResourceDescr, InvalidDescr]:
    """Build a bioimage.io resource description from the content of an RDF.

    To build a resource description from an rdf.yaml file or a bioimage.io
    zip-package use `load_description` instead.

    Args:
        content: loaded rdf.yaml file (loaded with YAML, not bioimageio.spec)
        context: validation context to use during validation
        format_version:
            (optional) use this argument to load the resource and
            convert its metadata to a higher format_version.
            Note:
            - Use "latest" to convert to the latest available format version.
            - Use "discover" to use the format version specified in the RDF.
            - Only considers major.minor format version, ignores patch version.
            - Conversion to lower format versions is not supported.

    Returns:
        An object holding all metadata of the bioimage.io resource

    """
    # delegate to the shared implementation with this module's class lookup
    descr = build_description_impl(
        content,
        get_rd_class=_get_rd_class,
        format_version=format_version,
        context=context,
    )
    return descr
build a bioimage.io resource description from an RDF's content.
Use load_description
if you want to build a resource description from an rdf.yaml
or bioimage.io zip-package.
Arguments:
- content: loaded rdf.yaml file (loaded with YAML, not bioimageio.spec)
- context: validation context to use during validation
- format_version: (optional) use this argument to load the resource and
convert its metadata to a higher format_version.
Note:
- Use "latest" to convert to the latest available format version.
- Use "discover" to use the format version specified in the RDF.
- Only considers major.minor format version, ignores patch version.
- Conversion to lower format versions is not supported.
Returns:
An object holding all metadata of the bioimage.io resource
def compute_dataset_measures(
    measures: Iterable[DatasetMeasure], dataset: Iterable[Sample]
) -> Dict[DatasetMeasure, MeasureValue]:
    """Compute the values of all given dataset `measures` over `dataset`.

    Every sample of `dataset` is fed to each measure calculator once;
    afterwards the finalized results of all calculators are merged.
    """
    sample_calcs, dataset_calcs = get_measure_calculators(measures)
    assert not sample_calcs

    for smpl in dataset:
        for calculator in dataset_calcs:
            calculator.update(smpl)

    results: Dict[DatasetMeasure, MeasureValue] = {}
    for calculator in dataset_calcs:
        for measure, value in calculator.finalize().items():
            results[measure] = value

    return results
compute all dataset measures
for the given dataset
    @final
    @classmethod
    def create(
        cls,
        model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr],
        *,
        devices: Optional[Sequence[str]] = None,
        weight_format_priority_order: Optional[Sequence[SupportedWeightsFormat]] = None,
    ):
        """
        Creates model adapter based on the passed spec
        Note: All specific adapters should happen inside this function to prevent different framework
        initializations interfering with each other

        Args:
            model_description: The model description to create an adapter for.
            devices: Passed on to the created model adapter.
            weight_format_priority_order: Weights formats to try, in this order.
                Formats without an entry in the model's weights are skipped.
                Default: `DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER`.

        Returns:
            The adapter of the first weights format for which an adapter could be created.

        Raises:
            TypeError: If **model_description** is not a v0_4 or v0_5 `ModelDescr`.
            ValueError: If none of the weights formats to try is present in the model.
            Exception: If **weight_format_priority_order** has exactly one entry,
                the error raised while creating its adapter is raised as is.
            ExceptionGroup: Otherwise, if no adapter could be created
                (holds the error of every attempted weights format).
        """
        if not isinstance(model_description, (v0_4.ModelDescr, v0_5.ModelDescr)):
            raise TypeError(
                f"expected v0_4.ModelDescr or v0_5.ModelDescr, but got {type(model_description)}"
            )

        weights = model_description.weights
        # errors of failed attempts; raised at the end if no adapter was created
        errors: List[Exception] = []
        weight_format_priority_order = (
            DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER
            if weight_format_priority_order is None
            else weight_format_priority_order
        )
        # limit weight formats to the ones present
        weight_format_priority_order_present: Sequence[SupportedWeightsFormat] = [
            w for w in weight_format_priority_order if getattr(weights, w) is not None
        ]
        if not weight_format_priority_order_present:
            raise ValueError(
                f"None of the specified weight formats ({weight_format_priority_order}) is present ({weight_format_priority_order_present})"
            )

        # Try the formats in priority order and return the first adapter that can
        # be created. Backend modules are imported inside the respective branch
        # only (see the note in the docstring).
        for wf in weight_format_priority_order_present:
            if wf == "pytorch_state_dict":
                assert weights.pytorch_state_dict is not None
                try:
                    from .pytorch_backend import PytorchModelAdapter

                    return PytorchModelAdapter(
                        model_description=model_description, devices=devices
                    )
                except Exception as e:
                    errors.append(e)
            elif wf == "tensorflow_saved_model_bundle":
                assert weights.tensorflow_saved_model_bundle is not None
                try:
                    from .tensorflow_backend import create_tf_model_adapter

                    return create_tf_model_adapter(
                        model_description=model_description, devices=devices
                    )
                except Exception as e:
                    errors.append(e)
            elif wf == "onnx":
                assert weights.onnx is not None
                try:
                    from .onnx_backend import ONNXModelAdapter

                    return ONNXModelAdapter(
                        model_description=model_description, devices=devices
                    )
                except Exception as e:
                    errors.append(e)
            elif wf == "torchscript":
                assert weights.torchscript is not None
                try:
                    from .torchscript_backend import TorchscriptModelAdapter

                    return TorchscriptModelAdapter(
                        model_description=model_description, devices=devices
                    )
                except Exception as e:
                    errors.append(e)
            elif wf == "keras_hdf5":
                assert weights.keras_hdf5 is not None
                # keras can either be installed as a separate package or used as part of tensorflow
                # we try to first import the keras model adapter using the separate package and,
                # if it is not available, try to load the one using tf
                try:
                    try:
                        from .keras_backend import KerasModelAdapter
                    except Exception:
                        from .tensorflow_backend import KerasModelAdapter

                    return KerasModelAdapter(
                        model_description=model_description, devices=devices
                    )
                except Exception as e:
                    errors.append(e)
            else:
                # all supported weights formats are handled above
                assert_never(wf)

        # reaching this point means every attempt above failed
        assert errors
        if len(weight_format_priority_order) == 1:
            assert len(errors) == 1
            raise errors[0]

        else:
            msg = (
                "None of the weight format specific model adapters could be created"
                + " in this environment."
            )
            raise ExceptionGroup(msg, errors)
Creates model adapter based on the passed spec Note: All specific adapters should happen inside this function to prevent different framework initializations interfering with each other
def create_prediction_pipeline(
    bioimageio_model: AnyModelDescr,
    *,
    devices: Optional[Sequence[str]] = None,
    weight_format: Optional[SupportedWeightsFormat] = None,
    weights_format: Optional[SupportedWeightsFormat] = None,
    dataset_for_initial_statistics: Iterable[Union[Sample, Sequence[Tensor]]] = tuple(),
    keep_updating_initial_dataset_statistics: bool = False,
    fixed_dataset_statistics: Mapping[DatasetMeasure, MeasureValue] = MappingProxyType(
        {}
    ),
    model_adapter: Optional[ModelAdapter] = None,
    ns: Optional[BlocksizeParameter] = None,
    default_blocksize_parameter: BlocksizeParameter = 10,
    **deprecated_kwargs: Any,
) -> PredictionPipeline:
    """
    Creates prediction pipeline which includes:
    * computation of input statistics
    * preprocessing
    * model prediction
    * computation of output statistics
    * postprocessing

    Args:
        bioimageio_model: A bioimageio model description.
        devices: (optional) Passed on when a model adapter is created.
        weight_format: deprecated in favor of **weights_format**
        weights_format: (optional) Use a specific **weights_format** rather than
            choosing one automatically.
            A corresponding `bioimageio.core.model_adapters.ModelAdapter` will be
            created to run inference with the **bioimageio_model**.
        dataset_for_initial_statistics: (optional) If preprocessing steps require input
            dataset statistics, **dataset_for_initial_statistics** allows you to
            specify a dataset from which these statistics are computed.
            Entries that are not a `Sample` are taken as a sequence of tensors
            matching the model inputs in order.
        keep_updating_initial_dataset_statistics: (optional) Set to `True` if you want
            to update dataset statistics with each processed sample.
        fixed_dataset_statistics: (optional) Allows you to specify a mapping of
            `DatasetMeasure`s to precomputed `MeasureValue`s.
        model_adapter: (optional) Allows you to use a custom **model_adapter** instead
            of creating one according to the present/selected **weights_format**.
        ns: deprecated in favor of **default_blocksize_parameter**
        default_blocksize_parameter: Allows to control the default block size for
            blockwise predictions, see `BlocksizeParameter`.
        **deprecated_kwargs: Any further keyword arguments are ignored,
            but trigger a warning.

    """
    # the deprecated arguments take precedence if they are given
    if weight_format:
        weights_format = weight_format

    if ns:
        default_blocksize_parameter = ns

    if deprecated_kwargs:
        warnings.warn(
            f"deprecated create_prediction_pipeline kwargs: {set(deprecated_kwargs)}"
        )

    if not model_adapter:
        # restrict the adapter creation to the selected weights format (if any)
        priority_order = (weights_format,) if weights_format else weights_format
        model_adapter = create_model_adapter(
            model_description=bioimageio_model,
            devices=devices,
            weight_format_priority_order=priority_order,
        )

    input_ids = get_member_ids(bioimageio_model.inputs)

    def initial_samples():
        # all samples created here share one (initially empty) stat dict
        shared_stat: Stat = {}
        for index, entry in enumerate(dataset_for_initial_statistics):
            if not isinstance(entry, Sample):
                entry = Sample(
                    members=dict(zip(input_ids, entry)), stat=shared_stat, id=index
                )

            yield entry

    preprocessing, postprocessing = setup_pre_and_postprocessing(
        bioimageio_model,
        initial_samples(),
        keep_updating_initial_dataset_stats=keep_updating_initial_dataset_statistics,
        fixed_dataset_stats=fixed_dataset_statistics,
    )

    pipeline = PredictionPipeline(
        name=bioimageio_model.name,
        model_description=bioimageio_model,
        model_adapter=model_adapter,
        preprocessing=preprocessing,
        postprocessing=postprocessing,
        default_blocksize_parameter=default_blocksize_parameter,
    )
    return pipeline
Creates prediction pipeline which includes:
- computation of input statistics
- preprocessing
- model prediction
- computation of output statistics
- postprocessing
Arguments:
- bioimageio_model: A bioimageio model description.
- devices: (optional)
- weight_format: deprecated in favor of weights_format
- weights_format: (optional) Use a specific weights_format rather than
choosing one automatically.
A corresponding
bioimageio.core.model_adapters.ModelAdapter
will be created to run inference with the bioimageio_model. - dataset_for_initial_statistics: (optional) If preprocessing steps require input dataset statistics, dataset_for_initial_statistics allows you to specify a dataset from which these statistics are computed.
- keep_updating_initial_dataset_statistics: (optional) Set to
True
if you want to update dataset statistics with each processed sample. - fixed_dataset_statistics: (optional) Allows you to specify a mapping of
DatasetMeasure
s to precomputed MeasureValue
s. - model_adapter: (optional) Allows you to use a custom model_adapter instead of creating one according to the present/selected weights_format.
- ns: deprecated in favor of default_blocksize_parameter
- default_blocksize_parameter: Allows to control the default block size for
blockwise predictions, see
BlocksizeParameter
.
def dump_description(
    rd: Union[ResourceDescr, InvalidDescr],
    /,
    *,
    exclude_unset: bool = True,
    exclude_defaults: bool = False,
) -> BioimageioYamlContent:
    """Convert a resource description to a dictionary of simple types that can be serialized to YAML directly.

    Args:
        rd: bioimageio resource description
        exclude_unset: Exclude fields that have not explicitly been set.
        exclude_defaults: Exclude fields that have the default value (even if set explicitly).

    Returns:
        The result of `rd.model_dump` in "json" mode.
    """
    dump_options = {
        "mode": "json",
        "exclude_unset": exclude_unset,
        "exclude_defaults": exclude_defaults,
    }
    return rd.model_dump(**dump_options)
Converts a resource to a dictionary containing only simple types that can directly be serialized to YAML.
Arguments:
- rd: bioimageio resource description
- exclude_unset: Exclude fields that have not explicitly been set.
- exclude_defaults: Exclude fields that have the default value (even if set explicitly).
88def enable_determinism( 89 mode: Literal["seed_only", "full"] = "full", 90 weight_formats: Optional[Sequence[SupportedWeightsFormat]] = None, 91): 92 """Seed and configure ML frameworks for maximum reproducibility. 93 May degrade performance. Only recommended for testing reproducibility! 94 95 Seed any random generators and (if **mode**=="full") request ML frameworks to use 96 deterministic algorithms. 97 98 Args: 99 mode: determinism mode 100 - 'seed_only' -- only set seeds, or 101 - 'full' determinsm features (might degrade performance or throw exceptions) 102 weight_formats: Limit deep learning importing deep learning frameworks 103 based on weight_formats. 104 E.g. this allows to avoid importing tensorflow when testing with pytorch. 105 106 Notes: 107 - **mode** == "full" might degrade performance or throw exceptions. 108 - Subsequent inference calls might still differ. Call before each function 109 (sequence) that is expected to be reproducible. 110 - Degraded performance: Use for testing reproducibility only! 
111 - Recipes: 112 - [PyTorch](https://pytorch.org/docs/stable/notes/randomness.html) 113 - [Keras](https://keras.io/examples/keras_recipes/reproducibility_recipes/) 114 - [NumPy](https://numpy.org/doc/2.0/reference/random/generated/numpy.random.seed.html) 115 """ 116 try: 117 try: 118 import numpy.random 119 except ImportError: 120 pass 121 else: 122 numpy.random.seed(0) 123 except Exception as e: 124 logger.debug(str(e)) 125 126 if ( 127 weight_formats is None 128 or "pytorch_state_dict" in weight_formats 129 or "torchscript" in weight_formats 130 ): 131 try: 132 try: 133 import torch 134 except ImportError: 135 pass 136 else: 137 _ = torch.manual_seed(0) 138 torch.use_deterministic_algorithms(mode == "full") 139 except Exception as e: 140 logger.debug(str(e)) 141 142 if ( 143 weight_formats is None 144 or "tensorflow_saved_model_bundle" in weight_formats 145 or "keras_hdf5" in weight_formats 146 ): 147 try: 148 os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0" 149 try: 150 import tensorflow as tf # pyright: ignore[reportMissingTypeStubs] 151 except ImportError: 152 pass 153 else: 154 tf.random.set_seed(0) 155 if mode == "full": 156 tf.config.experimental.enable_op_determinism() 157 # TODO: find possibility to switch it off again?? 158 except Exception as e: 159 logger.debug(str(e)) 160 161 if weight_formats is None or "keras_hdf5" in weight_formats: 162 try: 163 try: 164 import keras # pyright: ignore[reportMissingTypeStubs] 165 except ImportError: 166 pass 167 else: 168 keras.utils.set_random_seed(0) 169 except Exception as e: 170 logger.debug(str(e))
Seed and configure ML frameworks for maximum reproducibility. May degrade performance. Only recommended for testing reproducibility!
Seed any random generators and (if mode=="full") request ML frameworks to use deterministic algorithms.
Arguments:
- mode: determinism mode
- 'seed_only' -- only set seeds, or
- 'full' -- set seeds and additionally enable determinism features (might degrade performance or throw exceptions)
- weight_formats: Limit which deep learning frameworks are imported based on weight_formats. E.g. this avoids importing tensorflow when testing with pytorch.
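Notes:
- mode == "full" might degrade performance or throw exceptions.
- Subsequent inference calls might still differ. Call before each function (sequence) that is expected to be reproducible.
- Degraded performance: Use for testing reproducibility only!
- Recipes:
  - PyTorch: https://pytorch.org/docs/stable/notes/randomness.html
  - Keras: https://keras.io/examples/keras_recipes/reproducibility_recipes/
  - NumPy: https://numpy.org/doc/2.0/reference/random/generated/numpy.random.seed.html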
191def load_dataset_description( 192 source: Union[PermissiveFileSource, ZipFile], 193 /, 194 *, 195 format_version: Union[FormatVersionPlaceholder, str] = DISCOVER, 196 perform_io_checks: Optional[bool] = None, 197 known_files: Optional[Dict[str, Optional[Sha256]]] = None, 198 sha256: Optional[Sha256] = None, 199) -> AnyDatasetDescr: 200 """same as `load_description`, but addtionally ensures that the loaded 201 description is valid and of type 'dataset'. 202 """ 203 rd = load_description( 204 source, 205 format_version=format_version, 206 perform_io_checks=perform_io_checks, 207 known_files=known_files, 208 sha256=sha256, 209 ) 210 return ensure_description_is_dataset(rd)
Same as `load_description`, but additionally ensures that the loaded description is valid and of type 'dataset'.
588def load_description_and_test( 589 source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent], 590 *, 591 format_version: Union[FormatVersionPlaceholder, str] = DISCOVER, 592 weight_format: Optional[SupportedWeightsFormat] = None, 593 devices: Optional[Sequence[str]] = None, 594 determinism: Literal["seed_only", "full"] = "seed_only", 595 expected_type: Optional[str] = None, 596 sha256: Optional[Sha256] = None, 597 stop_early: bool = True, 598 **deprecated: Unpack[DeprecatedKwargs], 599) -> Union[ResourceDescr, InvalidDescr]: 600 """Test a bioimage.io resource dynamically, 601 for example run prediction of test tensors for models. 602 603 See `test_description` for more details. 604 605 Returns: 606 A (possibly invalid) resource description object 607 with a populated `.validation_summary` attribute. 608 """ 609 if isinstance(source, ResourceDescrBase): 610 root = source.root 611 file_name = source.file_name 612 if ( 613 ( 614 format_version 615 not in ( 616 DISCOVER, 617 source.format_version, 618 ".".join(source.format_version.split(".")[:2]), 619 ) 620 ) 621 or (c := source.validation_summary.details[0].context) is None 622 or not c.perform_io_checks 623 ): 624 logger.debug( 625 "deserializing source to ensure we validate and test using format {} and perform io checks", 626 format_version, 627 ) 628 source = dump_description(source) 629 else: 630 root = Path() 631 file_name = None 632 633 if isinstance(source, ResourceDescrBase): 634 rd = source 635 elif isinstance(source, dict): 636 # check context for a given root; default to root of source 637 context = get_validation_context( 638 ValidationContext(root=root, file_name=file_name) 639 ).replace( 640 perform_io_checks=True # make sure we perform io checks though 641 ) 642 643 rd = build_description( 644 source, 645 format_version=format_version, 646 context=context, 647 ) 648 else: 649 rd = load_description( 650 source, format_version=format_version, sha256=sha256, perform_io_checks=True 651 ) 
652 653 rd.validation_summary.env.add( 654 InstalledPackage(name="bioimageio.core", version=__version__) 655 ) 656 657 if expected_type is not None: 658 _test_expected_resource_type(rd, expected_type) 659 660 if isinstance(rd, (v0_4.ModelDescr, v0_5.ModelDescr)): 661 if weight_format is None: 662 weight_formats: List[SupportedWeightsFormat] = [ 663 w for w, we in rd.weights if we is not None 664 ] # pyright: ignore[reportAssignmentType] 665 else: 666 weight_formats = [weight_format] 667 668 enable_determinism(determinism, weight_formats=weight_formats) 669 for w in weight_formats: 670 _test_model_inference(rd, w, devices, stop_early=stop_early, **deprecated) 671 if stop_early and rd.validation_summary.status == "failed": 672 break 673 674 if not isinstance(rd, v0_4.ModelDescr): 675 _test_model_inference_parametrized( 676 rd, w, devices, stop_early=stop_early 677 ) 678 if stop_early and rd.validation_summary.status == "failed": 679 break 680 681 # TODO: add execution of jupyter notebooks 682 # TODO: add more tests 683 684 if rd.validation_summary.status == "valid-format": 685 rd.validation_summary.status = "passed" 686 687 return rd
Test a bioimage.io resource dynamically, for example run prediction of test tensors for models.
See test_description
for more details.
Returns:
A (possibly invalid) resource description object with a populated
.validation_summary
attribute.
243def load_description_and_validate_format_only( 244 source: Union[PermissiveFileSource, ZipFile], 245 /, 246 *, 247 format_version: Union[FormatVersionPlaceholder, str] = DISCOVER, 248 perform_io_checks: Optional[bool] = None, 249 known_files: Optional[Dict[str, Optional[Sha256]]] = None, 250 sha256: Optional[Sha256] = None, 251) -> ValidationSummary: 252 """same as `load_description`, but only return the validation summary. 253 254 Returns: 255 Validation summary of the bioimage.io resource found at `source`. 256 257 """ 258 rd = load_description( 259 source, 260 format_version=format_version, 261 perform_io_checks=perform_io_checks, 262 known_files=known_files, 263 sha256=sha256, 264 ) 265 assert rd.validation_summary is not None 266 return rd.validation_summary
Same as `load_description`, but only returns the validation summary.
Returns:
Validation summary of the bioimage.io resource found at `source`.
57def load_description( 58 source: Union[PermissiveFileSource, ZipFile], 59 /, 60 *, 61 format_version: Union[FormatVersionPlaceholder, str] = DISCOVER, 62 perform_io_checks: Optional[bool] = None, 63 known_files: Optional[Dict[str, Optional[Sha256]]] = None, 64 sha256: Optional[Sha256] = None, 65) -> Union[ResourceDescr, InvalidDescr]: 66 """load a bioimage.io resource description 67 68 Args: 69 source: 70 Path or URL to an rdf.yaml or a bioimage.io package 71 (zip-file with rdf.yaml in it). 72 format_version: 73 (optional) Use this argument to load the resource and 74 convert its metadata to a higher format_version. 75 Note: 76 - Use "latest" to convert to the latest available format version. 77 - Use "discover" to use the format version specified in the RDF. 78 - Only considers major.minor format version, ignores patch version. 79 - Conversion to lower format versions is not supported. 80 perform_io_checks: 81 Wether or not to perform validation that requires file io, 82 e.g. downloading a remote files. The existence of local 83 absolute file paths is still being checked. 84 known_files: 85 Allows to bypass download and hashing of referenced files 86 (even if perform_io_checks is True). 87 Checked files will be added to this dictionary 88 with their SHA-256 value. 
89 sha256: 90 Optional SHA-256 value of **source** 91 92 Returns: 93 An object holding all metadata of the bioimage.io resource 94 95 """ 96 if isinstance(source, ResourceDescrBase): 97 name = getattr(source, "name", f"{str(source)[:10]}...") 98 logger.warning("returning already loaded description '{}' as is", name) 99 return source # pyright: ignore[reportReturnType] 100 101 opened = open_bioimageio_yaml(source, sha256=sha256) 102 103 context = get_validation_context().replace( 104 root=opened.original_root, 105 file_name=opened.original_file_name, 106 original_source_name=opened.original_source_name, 107 perform_io_checks=perform_io_checks, 108 known_files=known_files, 109 ) 110 111 return build_description( 112 opened.content, 113 context=context, 114 format_version=format_version, 115 )
load a bioimage.io resource description
Arguments:
- source: Path or URL to an rdf.yaml or a bioimage.io package (zip-file with rdf.yaml in it).
- format_version: (optional) Use this argument to load the resource and
convert its metadata to a higher format_version.
Note:
- Use "latest" to convert to the latest available format version.
- Use "discover" to use the format version specified in the RDF.
- Only considers major.minor format version, ignores patch version.
- Conversion to lower format versions is not supported.
- perform_io_checks: Whether or not to perform validation that requires file io, e.g. downloading remote files. The existence of local absolute file paths is still being checked.
- known_files: Allows to bypass download and hashing of referenced files (even if perform_io_checks is True). Checked files will be added to this dictionary with their SHA-256 value.
- sha256: Optional SHA-256 value of source
Returns:
An object holding all metadata of the bioimage.io resource
142def load_model_description( 143 source: Union[PermissiveFileSource, ZipFile], 144 /, 145 *, 146 format_version: Union[FormatVersionPlaceholder, str] = DISCOVER, 147 perform_io_checks: Optional[bool] = None, 148 known_files: Optional[Dict[str, Optional[Sha256]]] = None, 149 sha256: Optional[Sha256] = None, 150) -> AnyModelDescr: 151 """same as `load_description`, but addtionally ensures that the loaded 152 description is valid and of type 'model'. 153 154 Raises: 155 ValueError: for invalid or non-model resources 156 """ 157 rd = load_description( 158 source, 159 format_version=format_version, 160 perform_io_checks=perform_io_checks, 161 known_files=known_files, 162 sha256=sha256, 163 ) 164 return ensure_description_is_model(rd)
Same as `load_description`, but additionally ensures that the loaded description is valid and of type 'model'.
Raises:
- ValueError: for invalid or non-model resources
142def load_model_description( 143 source: Union[PermissiveFileSource, ZipFile], 144 /, 145 *, 146 format_version: Union[FormatVersionPlaceholder, str] = DISCOVER, 147 perform_io_checks: Optional[bool] = None, 148 known_files: Optional[Dict[str, Optional[Sha256]]] = None, 149 sha256: Optional[Sha256] = None, 150) -> AnyModelDescr: 151 """same as `load_description`, but addtionally ensures that the loaded 152 description is valid and of type 'model'. 153 154 Raises: 155 ValueError: for invalid or non-model resources 156 """ 157 rd = load_description( 158 source, 159 format_version=format_version, 160 perform_io_checks=perform_io_checks, 161 known_files=known_files, 162 sha256=sha256, 163 ) 164 return ensure_description_is_model(rd)
alias of load_model_description
57def load_description( 58 source: Union[PermissiveFileSource, ZipFile], 59 /, 60 *, 61 format_version: Union[FormatVersionPlaceholder, str] = DISCOVER, 62 perform_io_checks: Optional[bool] = None, 63 known_files: Optional[Dict[str, Optional[Sha256]]] = None, 64 sha256: Optional[Sha256] = None, 65) -> Union[ResourceDescr, InvalidDescr]: 66 """load a bioimage.io resource description 67 68 Args: 69 source: 70 Path or URL to an rdf.yaml or a bioimage.io package 71 (zip-file with rdf.yaml in it). 72 format_version: 73 (optional) Use this argument to load the resource and 74 convert its metadata to a higher format_version. 75 Note: 76 - Use "latest" to convert to the latest available format version. 77 - Use "discover" to use the format version specified in the RDF. 78 - Only considers major.minor format version, ignores patch version. 79 - Conversion to lower format versions is not supported. 80 perform_io_checks: 81 Wether or not to perform validation that requires file io, 82 e.g. downloading a remote files. The existence of local 83 absolute file paths is still being checked. 84 known_files: 85 Allows to bypass download and hashing of referenced files 86 (even if perform_io_checks is True). 87 Checked files will be added to this dictionary 88 with their SHA-256 value. 
89 sha256: 90 Optional SHA-256 value of **source** 91 92 Returns: 93 An object holding all metadata of the bioimage.io resource 94 95 """ 96 if isinstance(source, ResourceDescrBase): 97 name = getattr(source, "name", f"{str(source)[:10]}...") 98 logger.warning("returning already loaded description '{}' as is", name) 99 return source # pyright: ignore[reportReturnType] 100 101 opened = open_bioimageio_yaml(source, sha256=sha256) 102 103 context = get_validation_context().replace( 104 root=opened.original_root, 105 file_name=opened.original_file_name, 106 original_source_name=opened.original_source_name, 107 perform_io_checks=perform_io_checks, 108 known_files=known_files, 109 ) 110 111 return build_description( 112 opened.content, 113 context=context, 114 format_version=format_version, 115 )
alias of load_description
131def predict_many( 132 *, 133 model: Union[ 134 PermissiveFileSource, v0_4.ModelDescr, v0_5.ModelDescr, PredictionPipeline 135 ], 136 inputs: Union[Iterable[PerMember[TensorSource]], Iterable[TensorSource]], 137 sample_id: str = "sample{i:03}", 138 blocksize_parameter: Optional[ 139 Union[ 140 v0_5.ParameterizedSize_N, 141 Mapping[Tuple[MemberId, AxisId], v0_5.ParameterizedSize_N], 142 ] 143 ] = None, 144 skip_preprocessing: bool = False, 145 skip_postprocessing: bool = False, 146 save_output_path: Optional[Union[Path, str]] = None, 147) -> Iterator[Sample]: 148 """Run prediction for a multiple sets of inputs with a bioimage.io model 149 150 Args: 151 model: Model to predict with. 152 May be given as RDF source, model description or prediction pipeline. 153 inputs: An iterable of the named input(s) for this model as a dictionary. 154 sample_id: The sample id. 155 note: `{i}` will be formatted as the i-th sample. 156 If `{i}` (or `{i:`) is not present and `inputs` is not an iterable `{i:03}` 157 is appended. 158 blocksize_parameter: (optional) Tile the input into blocks parametrized by 159 blocksize according to any parametrized axis sizes defined in the model RDF. 160 skip_preprocessing: Flag to skip the model's preprocessing. 161 skip_postprocessing: Flag to skip the model's postprocessing. 162 save_output_path: A path to save the output to. 163 Must contain: 164 - `{sample_id}` to differentiate predicted samples 165 - `{output_id}` (or `{member_id}`) if the model has multiple outputs 166 """ 167 if save_output_path is not None and "{sample_id}" not in str(save_output_path): 168 raise ValueError( 169 f"Missing `{{sample_id}}` in save_output_path={save_output_path}" 170 + " to differentiate predicted samples." 
171 ) 172 173 if isinstance(model, PredictionPipeline): 174 pp = model 175 else: 176 if not isinstance(model, (v0_4.ModelDescr, v0_5.ModelDescr)): 177 loaded = load_description(model) 178 if not isinstance(loaded, (v0_4.ModelDescr, v0_5.ModelDescr)): 179 raise ValueError(f"expected model description, but got {loaded}") 180 model = loaded 181 182 pp = create_prediction_pipeline(model) 183 184 if not isinstance(inputs, collections.abc.Mapping): 185 if "{i}" not in sample_id and "{i:" not in sample_id: 186 sample_id += "{i:03}" 187 188 total = len(inputs) if isinstance(inputs, collections.abc.Sized) else None 189 190 for i, ipts in tqdm(enumerate(inputs), total=total): 191 yield predict( 192 model=pp, 193 inputs=ipts, 194 sample_id=sample_id.format(i=i), 195 blocksize_parameter=blocksize_parameter, 196 skip_preprocessing=skip_preprocessing, 197 skip_postprocessing=skip_postprocessing, 198 save_output_path=save_output_path, 199 )
Run prediction for multiple sets of inputs with a bioimage.io model
Arguments:
- model: Model to predict with. May be given as RDF source, model description or prediction pipeline.
- inputs: An iterable of the named input(s) for this model as a dictionary.
- sample_id: The sample id. Note: `{i}` is formatted with the zero-based index of the sample. If `{i}` (or `{i:`) is not present and `inputs` is not a mapping, `{i:03}` is appended.
- blocksize_parameter: (optional) Tile the input into blocks parametrized by blocksize_parameter according to any parametrized axis sizes defined in the model RDF.
- skip_preprocessing: Flag to skip the model's preprocessing.
- skip_postprocessing: Flag to skip the model's postprocessing.
- save_output_path: A path to save the output to.
  Must contain:
  - `{sample_id}` to differentiate predicted samples
  - `{output_id}` (or `{member_id}`) if the model has multiple outputs
29def predict( 30 *, 31 model: Union[ 32 PermissiveFileSource, v0_4.ModelDescr, v0_5.ModelDescr, PredictionPipeline 33 ], 34 inputs: Union[Sample, PerMember[TensorSource], TensorSource], 35 sample_id: Hashable = "sample", 36 blocksize_parameter: Optional[BlocksizeParameter] = None, 37 input_block_shape: Optional[Mapping[MemberId, Mapping[AxisId, int]]] = None, 38 skip_preprocessing: bool = False, 39 skip_postprocessing: bool = False, 40 save_output_path: Optional[Union[Path, str]] = None, 41) -> Sample: 42 """Run prediction for a single set of input(s) with a bioimage.io model 43 44 Args: 45 model: Model to predict with. 46 May be given as RDF source, model description or prediction pipeline. 47 inputs: the input sample or the named input(s) for this model as a dictionary 48 sample_id: the sample id. 49 The **sample_id** is used to format **save_output_path** 50 and to distinguish sample specific log messages. 51 blocksize_parameter: (optional) Tile the input into blocks parametrized by 52 **blocksize_parameter** according to any parametrized axis sizes defined 53 by the **model**. 54 See `bioimageio.spec.model.v0_5.ParameterizedSize` for details. 55 Note: For a predetermined, fixed block shape use **input_block_shape**. 56 input_block_shape: (optional) Tile the input sample tensors into blocks. 57 Note: Use **blocksize_parameter** for a parameterized block shape to 58 run prediction independent of the exact block shape. 59 skip_preprocessing: Flag to skip the model's preprocessing. 60 skip_postprocessing: Flag to skip the model's postprocessing. 61 save_output_path: A path with to save the output to. 
M 62 Must contain: 63 - `{output_id}` (or `{member_id}`) if the model has multiple output tensors 64 May contain: 65 - `{sample_id}` to avoid overwriting recurrent calls 66 """ 67 if isinstance(model, PredictionPipeline): 68 pp = model 69 model = pp.model_description 70 else: 71 if not isinstance(model, (v0_4.ModelDescr, v0_5.ModelDescr)): 72 loaded = load_description(model) 73 if not isinstance(loaded, (v0_4.ModelDescr, v0_5.ModelDescr)): 74 raise ValueError(f"expected model description, but got {loaded}") 75 model = loaded 76 77 pp = create_prediction_pipeline(model) 78 79 if save_output_path is not None: 80 if ( 81 "{output_id}" not in str(save_output_path) 82 and "{member_id}" not in str(save_output_path) 83 and len(model.outputs) > 1 84 ): 85 raise ValueError( 86 f"Missing `{{output_id}}` in save_output_path={save_output_path} to " 87 + "distinguish model outputs " 88 + str([get_member_id(d) for d in model.outputs]) 89 ) 90 91 if isinstance(inputs, Sample): 92 sample = inputs 93 else: 94 sample = create_sample_for_model( 95 pp.model_description, inputs=inputs, sample_id=sample_id 96 ) 97 98 if input_block_shape is not None: 99 if blocksize_parameter is not None: 100 logger.warning( 101 "ignoring blocksize_parameter={} in favor of input_block_shape={}", 102 blocksize_parameter, 103 input_block_shape, 104 ) 105 106 output = pp.predict_sample_with_fixed_blocking( 107 sample, 108 input_block_shape=input_block_shape, 109 skip_preprocessing=skip_preprocessing, 110 skip_postprocessing=skip_postprocessing, 111 ) 112 elif blocksize_parameter is not None: 113 output = pp.predict_sample_with_blocking( 114 sample, 115 skip_preprocessing=skip_preprocessing, 116 skip_postprocessing=skip_postprocessing, 117 ns=blocksize_parameter, 118 ) 119 else: 120 output = pp.predict_sample_without_blocking( 121 sample, 122 skip_preprocessing=skip_preprocessing, 123 skip_postprocessing=skip_postprocessing, 124 ) 125 if save_output_path: 126 save_sample(save_output_path, output) 127 128 
return output
Run prediction for a single set of input(s) with a bioimage.io model
Arguments:
- model: Model to predict with. May be given as RDF source, model description or prediction pipeline.
- inputs: the input sample or the named input(s) for this model as a dictionary
- sample_id: the sample id. The sample_id is used to format save_output_path and to distinguish sample specific log messages.
- blocksize_parameter: (optional) Tile the input into blocks parametrized by blocksize_parameter according to any parametrized axis sizes defined by the model. See `bioimageio.spec.model.v0_5.ParameterizedSize` for details. Note: For a predetermined, fixed block shape use input_block_shape.
- input_block_shape: (optional) Tile the input sample tensors into blocks. Note: Use blocksize_parameter for a parameterized block shape to run prediction independent of the exact block shape.
- skip_preprocessing: Flag to skip the model's preprocessing.
- skip_postprocessing: Flag to skip the model's postprocessing.
- save_output_path: A path to save the output to.
  Must contain:
  - `{output_id}` (or `{member_id}`) if the model has multiple output tensors
  May contain:
  - `{sample_id}` to avoid overwriting the outputs of repeated calls
51class PredictionPipeline: 52 """ 53 Represents model computation including preprocessing and postprocessing 54 Note: Ideally use the `PredictionPipeline` in a with statement 55 (as a context manager). 56 """ 57 58 def __init__( 59 self, 60 *, 61 name: str, 62 model_description: AnyModelDescr, 63 preprocessing: List[Processing], 64 postprocessing: List[Processing], 65 model_adapter: ModelAdapter, 66 default_ns: Optional[BlocksizeParameter] = None, 67 default_blocksize_parameter: BlocksizeParameter = 10, 68 default_batch_size: int = 1, 69 ) -> None: 70 """Use `create_prediction_pipeline` to create a `PredictionPipeline`""" 71 super().__init__() 72 default_blocksize_parameter = default_ns or default_blocksize_parameter 73 if default_ns is not None: 74 warnings.warn( 75 "Argument `default_ns` is deprecated in favor of" 76 + " `default_blocksize_paramter` and will be removed soon." 77 ) 78 del default_ns 79 80 if model_description.run_mode: 81 warnings.warn( 82 f"Not yet implemented inference for run mode '{model_description.run_mode.name}'" 83 ) 84 85 self.name = name 86 self._preprocessing = preprocessing 87 self._postprocessing = postprocessing 88 89 self.model_description = model_description 90 if isinstance(model_description, v0_4.ModelDescr): 91 self._default_input_halo: PerMember[PerAxis[Halo]] = {} 92 self._block_transform = None 93 else: 94 default_output_halo = { 95 t.id: { 96 a.id: Halo(a.halo, a.halo) 97 for a in t.axes 98 if isinstance(a, v0_5.WithHalo) 99 } 100 for t in model_description.outputs 101 } 102 self._default_input_halo = get_input_halo( 103 model_description, default_output_halo 104 ) 105 self._block_transform = get_block_transform(model_description) 106 107 self._default_blocksize_parameter = default_blocksize_parameter 108 self._default_batch_size = default_batch_size 109 110 self._input_ids = get_member_ids(model_description.inputs) 111 self._output_ids = get_member_ids(model_description.outputs) 112 113 self._adapter: ModelAdapter = 
model_adapter 114 115 def __enter__(self): 116 self.load() 117 return self 118 119 def __exit__(self, exc_type, exc_val, exc_tb): # type: ignore 120 self.unload() 121 return False 122 123 def predict_sample_block( 124 self, 125 sample_block: SampleBlockWithOrigin, 126 skip_preprocessing: bool = False, 127 skip_postprocessing: bool = False, 128 ) -> SampleBlock: 129 if isinstance(self.model_description, v0_4.ModelDescr): 130 raise NotImplementedError( 131 f"predict_sample_block not implemented for model {self.model_description.format_version}" 132 ) 133 else: 134 assert self._block_transform is not None 135 136 if not skip_preprocessing: 137 self.apply_preprocessing(sample_block) 138 139 output_meta = sample_block.get_transformed_meta(self._block_transform) 140 local_output = self._adapter.forward(sample_block) 141 142 output = output_meta.with_data(local_output.members, stat=local_output.stat) 143 if not skip_postprocessing: 144 self.apply_postprocessing(output) 145 146 return output 147 148 def predict_sample_without_blocking( 149 self, 150 sample: Sample, 151 skip_preprocessing: bool = False, 152 skip_postprocessing: bool = False, 153 ) -> Sample: 154 """predict a sample. 155 The sample's tensor shapes have to match the model's input tensor description. 156 If that is not the case, consider `predict_sample_with_blocking`""" 157 158 if not skip_preprocessing: 159 self.apply_preprocessing(sample) 160 161 output = self._adapter.forward(sample) 162 if not skip_postprocessing: 163 self.apply_postprocessing(output) 164 165 return output 166 167 def get_output_sample_id(self, input_sample_id: SampleId): 168 warnings.warn( 169 "`PredictionPipeline.get_output_sample_id()` is deprecated and will be" 170 + " removed soon. Output sample id is equal to input sample id, hence this" 171 + " function is not needed." 
172 ) 173 return input_sample_id 174 175 def predict_sample_with_fixed_blocking( 176 self, 177 sample: Sample, 178 input_block_shape: Mapping[MemberId, Mapping[AxisId, int]], 179 *, 180 skip_preprocessing: bool = False, 181 skip_postprocessing: bool = False, 182 ) -> Sample: 183 if not skip_preprocessing: 184 self.apply_preprocessing(sample) 185 186 n_blocks, input_blocks = sample.split_into_blocks( 187 input_block_shape, 188 halo=self._default_input_halo, 189 pad_mode="reflect", 190 ) 191 input_blocks = list(input_blocks) 192 predicted_blocks: List[SampleBlock] = [] 193 logger.info( 194 "split sample shape {} into {} blocks of {}.", 195 {k: dict(v) for k, v in sample.shape.items()}, 196 n_blocks, 197 {k: dict(v) for k, v in input_block_shape.items()}, 198 ) 199 for b in tqdm( 200 input_blocks, 201 desc=f"predict {sample.id or ''} with {self.model_description.id or self.model_description.name}", 202 unit="block", 203 unit_divisor=1, 204 total=n_blocks, 205 ): 206 predicted_blocks.append( 207 self.predict_sample_block( 208 b, skip_preprocessing=True, skip_postprocessing=True 209 ) 210 ) 211 212 predicted_sample = Sample.from_blocks(predicted_blocks) 213 if not skip_postprocessing: 214 self.apply_postprocessing(predicted_sample) 215 216 return predicted_sample 217 218 def predict_sample_with_blocking( 219 self, 220 sample: Sample, 221 skip_preprocessing: bool = False, 222 skip_postprocessing: bool = False, 223 ns: Optional[ 224 Union[ 225 v0_5.ParameterizedSize_N, 226 Mapping[Tuple[MemberId, AxisId], v0_5.ParameterizedSize_N], 227 ] 228 ] = None, 229 batch_size: Optional[int] = None, 230 ) -> Sample: 231 """predict a sample by splitting it into blocks according to the model and the `ns` parameter""" 232 233 if isinstance(self.model_description, v0_4.ModelDescr): 234 raise NotImplementedError( 235 "`predict_sample_with_blocking` not implemented for v0_4.ModelDescr" 236 + f" {self.model_description.name}." 
237 + " Consider using `predict_sample_with_fixed_blocking`" 238 ) 239 240 ns = ns or self._default_blocksize_parameter 241 if isinstance(ns, int): 242 ns = { 243 (ipt.id, a.id): ns 244 for ipt in self.model_description.inputs 245 for a in ipt.axes 246 if isinstance(a.size, v0_5.ParameterizedSize) 247 } 248 input_block_shape = self.model_description.get_tensor_sizes( 249 ns, batch_size or self._default_batch_size 250 ).inputs 251 252 return self.predict_sample_with_fixed_blocking( 253 sample, 254 input_block_shape=input_block_shape, 255 skip_preprocessing=skip_preprocessing, 256 skip_postprocessing=skip_postprocessing, 257 ) 258 259 # def predict( 260 # self, 261 # inputs: Predict_IO, 262 # skip_preprocessing: bool = False, 263 # skip_postprocessing: bool = False, 264 # ) -> Predict_IO: 265 # """Run model prediction **including** pre/postprocessing.""" 266 267 # if isinstance(inputs, Sample): 268 # return self.predict_sample_with_blocking( 269 # inputs, 270 # skip_preprocessing=skip_preprocessing, 271 # skip_postprocessing=skip_postprocessing, 272 # ) 273 # elif isinstance(inputs, collections.abc.Iterable): 274 # return ( 275 # self.predict( 276 # ipt, 277 # skip_preprocessing=skip_preprocessing, 278 # skip_postprocessing=skip_postprocessing, 279 # ) 280 # for ipt in inputs 281 # ) 282 # else: 283 # assert_never(inputs) 284 285 def apply_preprocessing(self, sample: Union[Sample, SampleBlockWithOrigin]) -> None: 286 """apply preprocessing in-place, also updates sample stats""" 287 for op in self._preprocessing: 288 op(sample) 289 290 def apply_postprocessing( 291 self, sample: Union[Sample, SampleBlock, SampleBlockWithOrigin] 292 ) -> None: 293 """apply postprocessing in-place, also updates samples stats""" 294 for op in self._postprocessing: 295 if isinstance(sample, (Sample, SampleBlockWithOrigin)): 296 op(sample) 297 elif not isinstance(op, BlockedOperator): 298 raise NotImplementedError( 299 "block wise update of output statistics not yet implemented" 300 ) 301 
else: 302 op(sample) 303 304 def load(self): 305 """ 306 optional step: load model onto devices before calling forward if not using it as context manager 307 """ 308 pass 309 310 def unload(self): 311 """ 312 free any device memory in use 313 """ 314 self._adapter.unload()
Represents model computation including preprocessing and postprocessing
Note: Ideally use the PredictionPipeline
in a with statement
(as a context manager).
58 def __init__( 59 self, 60 *, 61 name: str, 62 model_description: AnyModelDescr, 63 preprocessing: List[Processing], 64 postprocessing: List[Processing], 65 model_adapter: ModelAdapter, 66 default_ns: Optional[BlocksizeParameter] = None, 67 default_blocksize_parameter: BlocksizeParameter = 10, 68 default_batch_size: int = 1, 69 ) -> None: 70 """Use `create_prediction_pipeline` to create a `PredictionPipeline`""" 71 super().__init__() 72 default_blocksize_parameter = default_ns or default_blocksize_parameter 73 if default_ns is not None: 74 warnings.warn( 75 "Argument `default_ns` is deprecated in favor of" 76 + " `default_blocksize_paramter` and will be removed soon." 77 ) 78 del default_ns 79 80 if model_description.run_mode: 81 warnings.warn( 82 f"Not yet implemented inference for run mode '{model_description.run_mode.name}'" 83 ) 84 85 self.name = name 86 self._preprocessing = preprocessing 87 self._postprocessing = postprocessing 88 89 self.model_description = model_description 90 if isinstance(model_description, v0_4.ModelDescr): 91 self._default_input_halo: PerMember[PerAxis[Halo]] = {} 92 self._block_transform = None 93 else: 94 default_output_halo = { 95 t.id: { 96 a.id: Halo(a.halo, a.halo) 97 for a in t.axes 98 if isinstance(a, v0_5.WithHalo) 99 } 100 for t in model_description.outputs 101 } 102 self._default_input_halo = get_input_halo( 103 model_description, default_output_halo 104 ) 105 self._block_transform = get_block_transform(model_description) 106 107 self._default_blocksize_parameter = default_blocksize_parameter 108 self._default_batch_size = default_batch_size 109 110 self._input_ids = get_member_ids(model_description.inputs) 111 self._output_ids = get_member_ids(model_description.outputs) 112 113 self._adapter: ModelAdapter = model_adapter
Use create_prediction_pipeline
to create a PredictionPipeline
def predict_sample_block(
    self,
    sample_block: SampleBlockWithOrigin,
    skip_preprocessing: bool = False,
    skip_postprocessing: bool = False,
) -> SampleBlock:
    """Run inference on a single sample block.

    Args:
        sample_block: the input block (preprocessed in-place unless skipped).
        skip_preprocessing: if True, do not apply the preprocessing operators.
        skip_postprocessing: if True, do not apply the postprocessing operators.

    Returns:
        The output block: the block's transformed meta data filled with the
        adapter's output members and statistics.

    Raises:
        NotImplementedError: for v0_4 model descriptions.
    """
    descr = self.model_description
    if isinstance(descr, v0_4.ModelDescr):
        raise NotImplementedError(
            f"predict_sample_block not implemented for model {descr.format_version}"
        )

    block_transform = self._block_transform
    assert block_transform is not None

    if not skip_preprocessing:
        self.apply_preprocessing(sample_block)

    transformed_meta = sample_block.get_transformed_meta(block_transform)
    raw_output = self._adapter.forward(sample_block)

    result = transformed_meta.with_data(raw_output.members, stat=raw_output.stat)
    if not skip_postprocessing:
        self.apply_postprocessing(result)

    return result
def predict_sample_without_blocking(
    self,
    sample: Sample,
    skip_preprocessing: bool = False,
    skip_postprocessing: bool = False,
) -> Sample:
    """Predict a whole sample in a single forward call.

    The sample's tensor shapes have to match the model's input tensor description.
    If that is not the case, consider `predict_sample_with_blocking`.

    Args:
        sample: the input sample (preprocessed in-place unless skipped).
        skip_preprocessing: if True, do not apply the preprocessing operators.
        skip_postprocessing: if True, do not apply the postprocessing operators.

    Returns:
        The sample returned by the model adapter's `forward`.
    """
    run_pre = not skip_preprocessing
    run_post = not skip_postprocessing

    if run_pre:
        self.apply_preprocessing(sample)

    prediction = self._adapter.forward(sample)

    if run_post:
        self.apply_postprocessing(prediction)

    return prediction
predict a sample.
The sample's tensor shapes have to match the model's input tensor description.
If that is not the case, consider predict_sample_with_blocking
def get_output_sample_id(self, input_sample_id: SampleId):
    """Deprecated: return `input_sample_id` unchanged (after issuing a warning)."""
    deprecation_note = (
        "`PredictionPipeline.get_output_sample_id()` is deprecated and will be"
        " removed soon. Output sample id is equal to input sample id, hence this"
        " function is not needed."
    )
    warnings.warn(deprecation_note)
    return input_sample_id
def predict_sample_with_fixed_blocking(
    self,
    sample: Sample,
    input_block_shape: Mapping[MemberId, Mapping[AxisId, int]],
    *,
    skip_preprocessing: bool = False,
    skip_postprocessing: bool = False,
) -> Sample:
    """Predict a sample block by block using the given input block shape.

    Pre- and postprocessing are applied to the whole sample (not per block);
    the individual blocks are predicted with both steps skipped.

    Args:
        sample: the input sample (preprocessed in-place unless skipped).
        input_block_shape: block shape per input member and axis.
        skip_preprocessing: if True, do not apply the preprocessing operators.
        skip_postprocessing: if True, do not apply the postprocessing operators.

    Returns:
        The sample assembled from all predicted blocks.
    """
    if not skip_preprocessing:
        self.apply_preprocessing(sample)

    n_blocks, block_iter = sample.split_into_blocks(
        input_block_shape,
        halo=self._default_input_halo,
        pad_mode="reflect",
    )
    input_blocks = list(block_iter)
    logger.info(
        "split sample shape {} into {} blocks of {}.",
        {k: dict(v) for k, v in sample.shape.items()},
        n_blocks,
        {k: dict(v) for k, v in input_block_shape.items()},
    )

    progress = tqdm(
        input_blocks,
        desc=f"predict {sample.id or ''} with {self.model_description.id or self.model_description.name}",
        unit="block",
        unit_divisor=1,
        total=n_blocks,
    )
    predicted_blocks: List[SampleBlock] = [
        self.predict_sample_block(
            block, skip_preprocessing=True, skip_postprocessing=True
        )
        for block in progress
    ]

    merged = Sample.from_blocks(predicted_blocks)
    if not skip_postprocessing:
        self.apply_postprocessing(merged)

    return merged
def predict_sample_with_blocking(
    self,
    sample: Sample,
    skip_preprocessing: bool = False,
    skip_postprocessing: bool = False,
    ns: Optional[
        Union[
            v0_5.ParameterizedSize_N,
            Mapping[Tuple[MemberId, AxisId], v0_5.ParameterizedSize_N],
        ]
    ] = None,
    batch_size: Optional[int] = None,
) -> Sample:
    """predict a sample by splitting it into blocks according to the model and the `ns` parameter

    Args:
        sample: the input sample.
        skip_preprocessing: if True, do not apply the preprocessing operators.
        skip_postprocessing: if True, do not apply the postprocessing operators.
        ns: blocksize parameter; either a single integer used for every
            parameterized input axis or a mapping per (member, axis).
            If None, the pipeline's default blocksize parameter is used.
        batch_size: batch size used to determine the input block shape;
            falls back to the pipeline's default batch size.

    Returns:
        The predicted sample, see `predict_sample_with_fixed_blocking`.

    Raises:
        NotImplementedError: for v0_4 model descriptions.
    """

    if isinstance(self.model_description, v0_4.ModelDescr):
        raise NotImplementedError(
            "`predict_sample_with_blocking` not implemented for v0_4.ModelDescr"
            + f" {self.model_description.name}."
            + " Consider using `predict_sample_with_fixed_blocking`"
        )

    # Only fall back to the default when `ns` was not given at all;
    # `ns=0` is a meaningful value and must not be replaced by the default.
    if ns is None:
        ns = self._default_blocksize_parameter

    if isinstance(ns, int):
        # broadcast a single integer to every parameterized input axis
        ns = {
            (ipt.id, a.id): ns
            for ipt in self.model_description.inputs
            for a in ipt.axes
            if isinstance(a.size, v0_5.ParameterizedSize)
        }

    input_block_shape = self.model_description.get_tensor_sizes(
        ns, batch_size or self._default_batch_size
    ).inputs

    return self.predict_sample_with_fixed_blocking(
        sample,
        input_block_shape=input_block_shape,
        skip_preprocessing=skip_preprocessing,
        skip_postprocessing=skip_postprocessing,
    )
predict a sample by splitting it into blocks according to the model and the ns
parameter
def apply_preprocessing(self, sample: Union[Sample, SampleBlockWithOrigin]) -> None:
    """Run every preprocessing operator on `sample`, in order.

    The operators work in-place, which also updates the sample's statistics.
    """
    operators = self._preprocessing
    for operator in operators:
        operator(sample)
apply preprocessing in-place, also updates sample stats
def apply_postprocessing(
    self, sample: Union[Sample, SampleBlock, SampleBlockWithOrigin]
) -> None:
    """Run every postprocessing operator on `sample`, in order.

    The operators work in-place, which also updates the sample's statistics.

    Raises:
        NotImplementedError: if `sample` is neither a `Sample` nor a
            `SampleBlockWithOrigin` and an operator is not a `BlockedOperator`.
    """
    # whether `sample` is of a kind every operator may be applied to
    any_operator_allowed = isinstance(sample, (Sample, SampleBlockWithOrigin))
    for operator in self._postprocessing:
        if not any_operator_allowed and not isinstance(operator, BlockedOperator):
            raise NotImplementedError(
                "block wise update of output statistics not yet implemented"
            )

        operator(sample)
apply postprocessing in-place, also updates samples stats
@dataclass
class Sample:
    """A dataset sample.

    A `Sample` has `members`, which allows to combine multiple tensors into a single
    sample.
    For example a `Sample` from a dataset with masked images may contain a
    `MemberId("raw")` and `MemberId("mask")` image.
    """

    members: Dict[MemberId, Tensor]
    """The sample's tensors"""

    stat: Stat
    """Sample and dataset statistics"""

    id: SampleId
    """Identifies the `Sample` within the dataset -- typically a number or a string."""

    @property
    def shape(self) -> PerMember[PerAxis[int]]:
        """Axis sizes of every member tensor, keyed by member id."""
        per_member = {}
        for member_id, tensor in self.members.items():
            per_member[member_id] = tensor.sizes
        return per_member

    def as_arrays(self) -> Dict[str, NDArray[Any]]:
        """Return sample as dictionary of arrays."""
        arrays = {}
        for member_id, tensor in self.members.items():
            arrays[str(member_id)] = tensor.data.to_numpy()
        return arrays

    def split_into_blocks(
        self,
        block_shapes: PerMember[PerAxis[int]],
        halo: PerMember[PerAxis[HaloLike]],
        pad_mode: PadMode,
        broadcast: bool = False,
    ) -> Tuple[TotalNumberOfBlocks, Iterable[SampleBlockWithOrigin]]:
        """Split this sample into blocks.

        Args:
            block_shapes: block shape per member and axis; every member given
                here has to be a member of this sample.
            halo: halo per member and axis; only allowed for members that
                have a block shape.
            pad_mode: padding mode handed to the block generator.
            broadcast: forwarded to `split_multiple_shapes_into_blocks`.

        Returns:
            The total number of blocks and an iterable of the sample blocks.
        """
        unknown_members = [m for m in block_shapes if m not in self.members]
        assert not unknown_members, (
            f"`block_shapes` specified for unknown members: {unknown_members}"
        )
        halo_without_shape = [m for m in halo if m not in block_shapes]
        assert not halo_without_shape, (
            f"`halo` specified for members without `block_shape`: {halo_without_shape}"
        )

        n_blocks, block_metas = split_multiple_shapes_into_blocks(
            shapes=self.shape,
            block_shapes=block_shapes,
            halo=halo,
            broadcast=broadcast,
        )
        block_iter = sample_block_generator(
            block_metas, origin=self, pad_mode=pad_mode
        )
        return n_blocks, block_iter

    def as_single_block(self, halo: Optional[PerMember[PerAxis[Halo]]] = None):
        """Wrap the whole sample into one `SampleBlockWithOrigin`.

        Args:
            halo: optional halo per member and axis; members without an entry
                get an empty halo.
        """
        member_halos = {} if halo is None else halo
        full_shape = self.shape

        blocks = {}
        for member_id, tensor in self.members.items():
            # the inner slice spans each axis completely
            whole_axes = {
                axis: SliceInfo(0, length)
                for axis, length in tensor.tagged_shape.items()
            }
            blocks[member_id] = Block(
                sample_shape=full_shape[member_id],
                data=tensor,
                inner_slice=whole_axes,
                halo=member_halos.get(member_id, {}),
                block_index=0,
                blocks_in_sample=1,
            )

        return SampleBlockWithOrigin(
            sample_shape=full_shape,
            sample_id=self.id,
            blocks=blocks,
            stat=self.stat,
            origin=self,
            block_index=0,
            blocks_in_sample=1,
        )

    @classmethod
    def from_blocks(
        cls,
        sample_blocks: Iterable[SampleBlock],
        *,
        fill_value: float = float("nan"),
    ) -> Self:
        """Assemble a sample from sample blocks.

        Args:
            sample_blocks: blocks that all share the same sample id.
            fill_value: initial value of the member tensors before the
                blocks' inner data is written into them.

        Returns:
            A sample holding the merged members, the statistics of the last
            block and the blocks' sample id.

        Raises:
            NotImplementedError: if a block's sample shape contains `-1`.
        """
        merged: PerMember[Tensor] = {}
        latest_stat: Stat = {}
        common_id = None
        for sample_block in sample_blocks:
            assert common_id is None or common_id == sample_block.sample_id
            common_id = sample_block.sample_id
            latest_stat = sample_block.stat
            for member_id, block in sample_block.blocks.items():
                if member_id not in merged:
                    full_shape = block.sample_shape
                    if -1 in full_shape.values():
                        raise NotImplementedError(
                            "merging blocks with data dependent axis not yet implemented"
                        )

                    axes = block.data.dims
                    merged[member_id] = Tensor(
                        np.full(
                            tuple(full_shape[a] for a in axes),
                            fill_value,
                            dtype=block.data.dtype,
                        ),
                        dims=axes,
                    )

                merged[member_id][block.inner_slice] = block.inner_data

        return cls(members=merged, stat=latest_stat, id=common_id)
A dataset sample.
A Sample
has members
, which allows to combine multiple tensors into a single
sample.
For example a Sample
from a dataset with masked images may contain a
MemberId("raw")
and MemberId("mask")
image.
Sample and dataset statistics
def as_arrays(self) -> Dict[str, NDArray[Any]]:
    """Return sample as dictionary of arrays.

    Keys are the member ids converted to `str`.
    """
    arrays = {}
    for member_id, tensor in self.members.items():
        arrays[str(member_id)] = tensor.data.to_numpy()
    return arrays
Return sample as dictionary of arrays.
def split_into_blocks(
    self,
    block_shapes: PerMember[PerAxis[int]],
    halo: PerMember[PerAxis[HaloLike]],
    pad_mode: PadMode,
    broadcast: bool = False,
) -> Tuple[TotalNumberOfBlocks, Iterable[SampleBlockWithOrigin]]:
    """Split this sample into blocks.

    Args:
        block_shapes: block shape per member and axis; every member given
            here has to be a member of this sample.
        halo: halo per member and axis; only allowed for members that
            have a block shape.
        pad_mode: padding mode handed to the block generator.
        broadcast: forwarded to `split_multiple_shapes_into_blocks`.

    Returns:
        The total number of blocks and an iterable of the sample blocks.
    """
    unknown_members = [m for m in block_shapes if m not in self.members]
    assert not unknown_members, (
        f"`block_shapes` specified for unknown members: {unknown_members}"
    )
    halo_without_shape = [m for m in halo if m not in block_shapes]
    assert not halo_without_shape, (
        f"`halo` specified for members without `block_shape`: {halo_without_shape}"
    )

    n_blocks, block_metas = split_multiple_shapes_into_blocks(
        shapes=self.shape,
        block_shapes=block_shapes,
        halo=halo,
        broadcast=broadcast,
    )
    block_iter = sample_block_generator(block_metas, origin=self, pad_mode=pad_mode)
    return n_blocks, block_iter
def as_single_block(self, halo: Optional[PerMember[PerAxis[Halo]]] = None):
    """Wrap the whole sample into one `SampleBlockWithOrigin`.

    Args:
        halo: optional halo per member and axis; members without an entry
            get an empty halo.
    """
    member_halos = {} if halo is None else halo
    full_shape = self.shape

    blocks = {}
    for member_id, tensor in self.members.items():
        # the inner slice spans each axis completely
        whole_axes = {
            axis: SliceInfo(0, length) for axis, length in tensor.tagged_shape.items()
        }
        blocks[member_id] = Block(
            sample_shape=full_shape[member_id],
            data=tensor,
            inner_slice=whole_axes,
            halo=member_halos.get(member_id, {}),
            block_index=0,
            blocks_in_sample=1,
        )

    return SampleBlockWithOrigin(
        sample_shape=full_shape,
        sample_id=self.id,
        blocks=blocks,
        stat=self.stat,
        origin=self,
        block_index=0,
        blocks_in_sample=1,
    )
@classmethod
def from_blocks(
    cls,
    sample_blocks: Iterable[SampleBlock],
    *,
    fill_value: float = float("nan"),
) -> Self:
    """Assemble a sample from sample blocks.

    Args:
        sample_blocks: blocks that all share the same sample id.
        fill_value: initial value of the member tensors before the
            blocks' inner data is written into them.

    Returns:
        A sample holding the merged members, the statistics of the last
        block and the blocks' sample id.

    Raises:
        NotImplementedError: if a block's sample shape contains `-1`.
    """
    merged: PerMember[Tensor] = {}
    latest_stat: Stat = {}
    common_id = None
    for sample_block in sample_blocks:
        assert common_id is None or common_id == sample_block.sample_id
        common_id = sample_block.sample_id
        latest_stat = sample_block.stat
        for member_id, block in sample_block.blocks.items():
            if member_id not in merged:
                full_shape = block.sample_shape
                if -1 in full_shape.values():
                    raise NotImplementedError(
                        "merging blocks with data dependent axis not yet implemented"
                    )

                axes = block.data.dims
                merged[member_id] = Tensor(
                    np.full(
                        tuple(full_shape[a] for a in axes),
                        fill_value,
                        dtype=block.data.dtype,
                    ),
                    dims=axes,
                )

            merged[member_id][block.inner_slice] = block.inner_data

    return cls(members=merged, stat=latest_stat, id=common_id)
def save_bioimageio_package_as_folder(
    source: Union[BioimageioYamlSource, ResourceDescr],
    /,
    *,
    output_path: Union[NewPath, DirectoryPath, None] = None,
    weights_priority_order: Optional[  # model only
        Sequence[
            Literal[
                "keras_hdf5",
                "onnx",
                "pytorch_state_dict",
                "tensorflow_js",
                "tensorflow_saved_model_bundle",
                "torchscript",
            ]
        ]
    ] = None,
) -> DirectoryPath:
    """Write the content of a bioimage.io resource package to a folder.

    Args:
        source: bioimageio resource description
        output_path: folder to write the package content to
            (created if missing; a new temporary directory if None)
        weights_priority_order: If given only the first weights format present in the model is included.
            If none of the prioritized weights formats is found all are included.

    Returns:
        directory path to bioimageio package folder
    """
    package_content = _prepare_resource_package(
        source,
        weights_priority_order=weights_priority_order,
    )
    output_path = Path(mkdtemp()) if output_path is None else Path(output_path)
    output_path.mkdir(exist_ok=True, parents=True)

    for name, src in package_content.items():
        destination = output_path / name

        # mapping entries are written as YAML
        if isinstance(src, collections.abc.Mapping):
            write_yaml(src, destination)
            continue

        root = src.original_root
        root_is_path = isinstance(root, Path)

        # a file that already is at its destination is not copied
        if root_is_path and root / src.original_file_name == destination.resolve():
            logger.debug(
                f"Not copying {root / src.original_file_name} to itself."
            )
            continue

        if root_is_path:
            logger.debug(
                f"Copying from path {root / src.original_file_name} to {destination}."
            )
        else:
            logger.debug(
                f"Copying {root}/{src.original_file_name} to {destination}."
            )

        with destination.open("wb") as dest:
            _ = shutil.copyfileobj(src, dest)

    return output_path
Write the content of a bioimage.io resource package to a folder.
Arguments:
- source: bioimageio resource description
- output_path: file path to write package to
- weights_priority_order: If given only the first weights format present in the model is included. If none of the prioritized weights formats is found all are included.
Returns:
directory path to bioimageio package folder
def save_bioimageio_package(
    source: Union[BioimageioYamlSource, ResourceDescr],
    /,
    *,
    compression: int = ZIP_DEFLATED,
    compression_level: int = 1,
    output_path: Union[NewPath, FilePath, None] = None,
    weights_priority_order: Optional[  # model only
        Sequence[
            Literal[
                "keras_hdf5",
                "onnx",
                "pytorch_state_dict",
                "tensorflow_js",
                "tensorflow_saved_model_bundle",
                "torchscript",
            ]
        ]
    ] = None,
    allow_invalid: bool = False,
) -> FilePath:
    """Package a bioimageio resource as a zip file.

    Args:
        source: bioimageio resource description
        compression: The numeric constant of compression method.
        compression_level: Compression level to use when writing files to the archive.
            See https://docs.python.org/3/library/zipfile.html#zipfile.ZipFile
        output_path: file path to write package to
            (a new temporary '.bioimageio.zip' file if None)
        weights_priority_order: If given only the first weights format present in the model is included.
            If none of the prioritized weights formats is found all are included.
        allow_invalid: If True, an invalid exported package is only logged as
            an error instead of raising a ValueError.

    Returns:
        path to zipped bioimageio package

    Raises:
        ValueError: if the exported package is invalid and `allow_invalid` is False.
    """
    package_content = _prepare_resource_package(
        source,
        weights_priority_order=weights_priority_order,
    )
    if output_path is None:
        # Only the (persistent) file name is needed; close the handle right away
        # so it does not stay open while the zip file is written.
        with NamedTemporaryFile(suffix=".bioimageio.zip", delete=False) as tmp:
            output_path = Path(tmp.name)
    else:
        output_path = Path(output_path)

    write_zip(
        output_path,
        package_content,
        compression=compression,
        compression_level=compression_level,
    )
    # re-load the written package with warnings treated as errors
    with get_validation_context().replace(warning_level=ERROR):
        if isinstance((exported := load_description(output_path)), InvalidDescr):
            exported.validation_summary.display()
            msg = f"Exported package at '{output_path}' is invalid."
            if allow_invalid:
                logger.error(msg)
            else:
                raise ValueError(msg)

    return output_path
Package a bioimageio resource as a zip file.
Arguments:
- source: bioimageio resource description
- compression: The numeric constant of compression method.
- compression_level: Compression level to use when writing files to the archive. See https://docs.python.org/3/library/zipfile.html#zipfile.ZipFile
- output_path: file path to write package to
- weights_priority_order: If given only the first weights format present in the model is included. If none of the prioritized weights formats is found all are included.
Returns:
path to zipped bioimageio package
def save_bioimageio_yaml_only(
    rd: Union[ResourceDescr, BioimageioYamlContent, InvalidDescr],
    /,
    file: Union[NewPath, FilePath, TextIO],
    *,
    exclude_unset: bool = True,
    exclude_defaults: bool = False,
):
    """Write the metadata of a resource description (`rd`) to `file`
    without writing any of the referenced files in it.

    Args:
        rd: bioimageio resource description
        file: file or stream to save to
        exclude_unset: Exclude fields that have not explicitly been set.
        exclude_defaults: Exclude fields that have the default value (even if set explicitly).

    Note: To save a resource description with its associated files as a package,
    use `save_bioimageio_package` or `save_bioimageio_package_as_folder`.
    """
    # anything that is not a description object is written as given
    content = (
        dump_description(
            rd, exclude_unset=exclude_unset, exclude_defaults=exclude_defaults
        )
        if isinstance(rd, ResourceDescrBase)
        else rd
    )

    write_yaml(cast(YamlValue, content), file)
write the metadata of a resource description (rd
) to file
without writing any of the referenced files in it.
Arguments:
- rd: bioimageio resource description
- file: file or stream to save to
- exclude_unset: Exclude fields that have not explicitly been set.
- exclude_defaults: Exclude fields that have the default value (even if set explicitly).
Note: To save a resource description with its associated files as a package,
use save_bioimageio_package
or save_bioimageio_package_as_folder
.
class Tensor(MagicTensorOpsMixin):
    """A wrapper around an xr.DataArray for better integration with bioimageio.spec
    and improved type annotations."""

    _Compatible = Union["Tensor", xr.DataArray, _ScalarOrArray]

    def __init__(
        self,
        array: NDArray[Any],
        dims: Sequence[Union[AxisId, AxisLike]],
    ) -> None:
        super().__init__()
        # normalize every axis description to its axis id
        axes = tuple(
            a if isinstance(a, AxisId) else AxisInfo.create(a).id for a in dims
        )
        self._data = xr.DataArray(array, dims=axes)

    def __array__(self, dtype: DTypeLike = None):
        return np.asarray(self._data, dtype=dtype)

    def __getitem__(
        self,
        key: Union[
            SliceInfo,
            slice,
            int,
            PerAxis[Union[SliceInfo, slice, int]],
            Tensor,
            xr.DataArray,
        ],
    ) -> Self:
        # convert `SliceInfo`/`Tensor` keys to what the wrapped data array accepts
        if isinstance(key, SliceInfo):
            key = slice(*key)
        elif isinstance(key, collections.abc.Mapping):
            key = {
                a: s if isinstance(s, int) else s if isinstance(s, slice) else slice(*s)
                for a, s in key.items()
            }
        elif isinstance(key, Tensor):
            key = key._data

        return self.__class__.from_xarray(self._data[key])

    def __setitem__(
        self,
        key: Union[PerAxis[Union[SliceInfo, slice]], Tensor, xr.DataArray],
        value: Union[Tensor, xr.DataArray, float, int],
    ) -> None:
        if isinstance(key, Tensor):
            key = key._data
        elif isinstance(key, xr.DataArray):
            pass
        else:
            key = {a: s if isinstance(s, slice) else slice(*s) for a, s in key.items()}

        if isinstance(value, Tensor):
            value = value._data

        self._data[key] = value

    def __len__(self) -> int:
        return len(self.data)

    def _iter(self: Any) -> Iterator[Any]:
        for n in range(len(self)):
            yield self[n]

    def __iter__(self: Any) -> Iterator[Any]:
        if self.ndim == 0:
            raise TypeError("iteration over a 0-d array")
        return self._iter()

    def _binary_op(
        self,
        other: _Compatible,
        f: Callable[[Any, Any], Any],
        reflexive: bool = False,
    ) -> Self:
        data = self._data._binary_op(  # pyright: ignore[reportPrivateUsage]
            (other._data if isinstance(other, Tensor) else other),
            f,
            reflexive,
        )
        return self.__class__.from_xarray(data)

    def _inplace_binary_op(
        self,
        other: _Compatible,
        f: Callable[[Any, Any], Any],
    ) -> Self:
        # Unwrap `other` if its `data` attribute is a data array (e.g. a `Tensor`).
        # `getattr` needs a default: plain Python scalars have no `data`
        # attribute and would otherwise raise an AttributeError here.
        other_d = getattr(other, "data", None)
        _ = self._data._inplace_binary_op(  # pyright: ignore[reportPrivateUsage]
            other_d if isinstance(other_d, xr.DataArray) else other,
            f,
        )
        return self

    def _unary_op(self, f: Callable[[Any], Any], *args: Any, **kwargs: Any) -> Self:
        data = self._data._unary_op(  # pyright: ignore[reportPrivateUsage]
            f, *args, **kwargs
        )
        return self.__class__.from_xarray(data)

    @classmethod
    def from_xarray(cls, data_array: xr.DataArray) -> Self:
        """create a `Tensor` from an xarray data array

        note for internal use: this factory method is round-trip safe
            for any `Tensor`'s `data` property (an xarray.DataArray).
        """
        return cls(
            array=data_array.data, dims=tuple(AxisId(d) for d in data_array.dims)
        )

    @classmethod
    def from_numpy(
        cls,
        array: NDArray[Any],
        *,
        dims: Optional[Union[AxisLike, Sequence[AxisLike]]],
    ) -> Tensor:
        """create a `Tensor` from a numpy array

        Args:
            array: the nd numpy array
            dims: A description of the array's axes,
                if None axes are guessed (which might fail and raise a ValueError.)

        Raises:
            ValueError: if `dims` is None and axes guessing fails,
                or if the array shape does not map to the given `dims`.
        """

        if dims is None:
            return cls._interprete_array_wo_known_axes(array)
        elif isinstance(dims, AxisId) or not isinstance(
            dims, collections.abc.Sequence
        ):
            # A single `AxisId` describes exactly one axis.
            # NOTE(review): `AxisId` is presumably str-like (and thus a Sequence);
            # checking it first avoids splitting its name into characters.
            dim_seq = [dims]
        else:
            dim_seq = list(dims)

        axis_infos = [AxisInfo.create(a) for a in dim_seq]
        original_shape = tuple(array.shape)

        successful_view = _get_array_view(array, axis_infos)
        if successful_view is None:
            raise ValueError(
                f"Array shape {original_shape} does not map to axes {dims}"
            )

        return Tensor(successful_view, dims=tuple(a.id for a in axis_infos))

    @property
    def data(self):
        return self._data

    @property
    def dims(self):  # TODO: rename to `axes`?
        """Tuple of dimension names associated with this tensor."""
        return cast(Tuple[AxisId, ...], self._data.dims)

    @property
    def dtype(self) -> DTypeStr:
        dt = str(self.data.dtype)  # pyright: ignore[reportUnknownArgumentType]
        assert dt in get_args(DTypeStr)
        return dt  # pyright: ignore[reportReturnType]

    @property
    def ndim(self):
        """Number of tensor dimensions."""
        return self._data.ndim

    @property
    def shape(self):
        """Tuple of tensor axes lengths"""
        return self._data.shape

    @property
    def shape_tuple(self):
        """Tuple of tensor axes lengths"""
        return self._data.shape

    @property
    def size(self):
        """Number of elements in the tensor.

        Equal to math.prod(tensor.shape), i.e., the product of the tensors’ dimensions.
        """
        return self._data.size

    @property
    def sizes(self):
        """Ordered, immutable mapping from axis ids to axis lengths."""
        return cast(Mapping[AxisId, int], self.data.sizes)

    @property
    def tagged_shape(self):
        """(alias for `sizes`) Ordered, immutable mapping from axis ids to lengths."""
        return self.sizes

    def argmax(self) -> Mapping[AxisId, int]:
        ret = self._data.argmax(...)
        assert isinstance(ret, dict)
        return {cast(AxisId, k): cast(int, v.item()) for k, v in ret.items()}

    def astype(self, dtype: DTypeStr, *, copy: bool = False):
        """Return tensor cast to `dtype`

        note: if dtype is already satisfied copy if `copy`"""
        return self.__class__.from_xarray(self._data.astype(dtype, copy=copy))

    def clip(self, min: Optional[float] = None, max: Optional[float] = None):
        """Return a tensor whose values are limited to [min, max].
        At least one of max or min must be given."""
        return self.__class__.from_xarray(self._data.clip(min, max))

    def crop_to(
        self,
        sizes: PerAxis[int],
        crop_where: Union[
            CropWhere,
            PerAxis[CropWhere],
        ] = "left_and_right",
    ) -> Self:
        """crop to match `sizes`"""
        if isinstance(crop_where, str):
            crop_axis_where: PerAxis[CropWhere] = {a: crop_where for a in self.dims}
        else:
            crop_axis_where = crop_where

        slices: Dict[AxisId, SliceInfo] = {}

        for a, s_is in self.sizes.items():
            if a not in sizes or sizes[a] == s_is:
                pass
            elif sizes[a] > s_is:
                logger.warning(
                    "Cannot crop axis {} of size {} to larger size {}",
                    a,
                    s_is,
                    sizes[a],
                )
            elif a not in crop_axis_where:
                raise ValueError(
                    f"Don't know where to crop axis {a}, `crop_where`={crop_where}"
                )
            else:
                crop_this_axis_where = crop_axis_where[a]
                if crop_this_axis_where == "left":
                    slices[a] = SliceInfo(s_is - sizes[a], s_is)
                elif crop_this_axis_where == "right":
                    slices[a] = SliceInfo(0, sizes[a])
                elif crop_this_axis_where == "left_and_right":
                    slices[a] = SliceInfo(
                        start := (s_is - sizes[a]) // 2, sizes[a] + start
                    )
                else:
                    assert_never(crop_this_axis_where)

        return self[slices]

    def expand_dims(self, dims: Union[Sequence[AxisId], PerAxis[int]]) -> Self:
        return self.__class__.from_xarray(self._data.expand_dims(dims=dims))

    def item(
        self,
        key: Union[
            None, SliceInfo, slice, int, PerAxis[Union[SliceInfo, slice, int]]
        ] = None,
    ):
        """Copy a tensor element to a standard Python scalar and return it."""
        if key is None:
            ret = self._data.item()
        else:
            ret = self[key]._data.item()

        assert isinstance(ret, (bool, float, int))
        return ret

    def mean(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
        return self.__class__.from_xarray(self._data.mean(dim=dim))

    def pad(
        self,
        pad_width: PerAxis[PadWidthLike],
        mode: PadMode = "symmetric",
    ) -> Self:
        pad_width = {a: PadWidth.create(p) for a, p in pad_width.items()}
        return self.__class__.from_xarray(
            self._data.pad(pad_width=pad_width, mode=mode)
        )

    def pad_to(
        self,
        sizes: PerAxis[int],
        pad_where: Union[PadWhere, PerAxis[PadWhere]] = "left_and_right",
        mode: PadMode = "symmetric",
    ) -> Self:
        """pad `tensor` to match `sizes`"""
        if isinstance(pad_where, str):
            pad_axis_where: PerAxis[PadWhere] = {a: pad_where for a in self.dims}
        else:
            pad_axis_where = pad_where

        pad_width: Dict[AxisId, PadWidth] = {}
        for a, s_is in self.sizes.items():
            if a not in sizes or sizes[a] == s_is:
                pad_width[a] = PadWidth(0, 0)
            elif s_is > sizes[a]:
                pad_width[a] = PadWidth(0, 0)
                logger.warning(
                    "Cannot pad axis {} of size {} to smaller size {}",
                    a,
                    s_is,
                    sizes[a],
                )
            elif a not in pad_axis_where:
                raise ValueError(
                    f"Don't know where to pad axis {a}, `pad_where`={pad_where}"
                )
            else:
                pad_this_axis_where = pad_axis_where[a]
                d = sizes[a] - s_is
                if pad_this_axis_where == "left":
                    pad_width[a] = PadWidth(d, 0)
                elif pad_this_axis_where == "right":
                    pad_width[a] = PadWidth(0, d)
                elif pad_this_axis_where == "left_and_right":
                    pad_width[a] = PadWidth(left := d // 2, d - left)
                else:
                    assert_never(pad_this_axis_where)

        return self.pad(pad_width, mode)

    def quantile(
        self,
        q: Union[float, Sequence[float]],
        dim: Optional[Union[AxisId, Sequence[AxisId]]] = None,
    ) -> Self:
        assert (
            isinstance(q, (float, int))
            and q >= 0.0
            or not isinstance(q, (float, int))
            and all(qq >= 0.0 for qq in q)
        )
        assert (
            isinstance(q, (float, int))
            and q <= 1.0
            or not isinstance(q, (float, int))
            and all(qq <= 1.0 for qq in q)
        )
        assert dim is None or (
            (quantile_dim := AxisId("quantile")) != dim and quantile_dim not in set(dim)
        )
        return self.__class__.from_xarray(self._data.quantile(q, dim=dim))

    def resize_to(
        self,
        sizes: PerAxis[int],
        *,
        pad_where: Union[
            PadWhere,
            PerAxis[PadWhere],
        ] = "left_and_right",
        crop_where: Union[
            CropWhere,
            PerAxis[CropWhere],
        ] = "left_and_right",
        pad_mode: PadMode = "symmetric",
    ):
        """return cropped/padded tensor with `sizes`"""
        crop_to_sizes: Dict[AxisId, int] = {}
        pad_to_sizes: Dict[AxisId, int] = {}
        # axes in `sizes` that this tensor does not have yet
        new_axes = dict(sizes)
        for a, s_is in self.sizes.items():
            a = AxisId(str(a))
            _ = new_axes.pop(a, None)
            if a not in sizes or sizes[a] == s_is:
                pass
            elif s_is > sizes[a]:
                crop_to_sizes[a] = sizes[a]
            else:
                pad_to_sizes[a] = sizes[a]

        tensor = self
        if crop_to_sizes:
            tensor = tensor.crop_to(crop_to_sizes, crop_where=crop_where)

        if pad_to_sizes:
            tensor = tensor.pad_to(pad_to_sizes, pad_where=pad_where, mode=pad_mode)

        if new_axes:
            tensor = tensor.expand_dims(new_axes)

        return tensor

    def std(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
        return self.__class__.from_xarray(self._data.std(dim=dim))

    def sum(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
        """Reduce this Tensor's data by applying sum along some dimension(s)."""
        return self.__class__.from_xarray(self._data.sum(dim=dim))

    def transpose(
        self,
        axes: Sequence[AxisId],
    ) -> Self:
        """return a transposed tensor

        Args:
            axes: the desired tensor axes
        """
        # expand missing tensor axes
        missing_axes = tuple(a for a in axes if a not in self.dims)
        array = self._data
        if missing_axes:
            array = array.expand_dims(missing_axes)

        # transpose to the correct axis order
        return self.__class__.from_xarray(array.transpose(*axes))

    def var(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
        return self.__class__.from_xarray(self._data.var(dim=dim))

    @classmethod
    def _interprete_array_wo_known_axes(cls, array: NDArray[Any]):
        # guess the axes from the number of dimensions (and, for 3d, the shape)
        ndim = array.ndim
        if ndim == 2:
            current_axes = (
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[0]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[1]),
            )
        elif ndim == 3 and any(s <= 3 for s in array.shape):
            current_axes = (
                v0_5.ChannelAxis(
                    channel_names=[
                        v0_5.Identifier(f"channel{i}") for i in range(array.shape[0])
                    ]
                ),
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[1]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[2]),
            )
        elif ndim == 3:
            current_axes = (
                v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[0]),
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[1]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[2]),
            )
        elif ndim == 4:
            current_axes = (
                v0_5.ChannelAxis(
                    channel_names=[
                        v0_5.Identifier(f"channel{i}") for i in range(array.shape[0])
                    ]
                ),
                v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[1]),
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[2]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[3]),
            )
        elif ndim == 5:
            current_axes = (
                v0_5.BatchAxis(),
                v0_5.ChannelAxis(
                    channel_names=[
                        v0_5.Identifier(f"channel{i}") for i in range(array.shape[1])
                    ]
                ),
                v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[2]),
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[3]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[4]),
            )
        else:
            raise ValueError(f"Could not guess an axis mapping for {array.shape}")

        return cls(array, dims=tuple(a.id for a in current_axes))
A wrapper around an xr.DataArray for better integration with bioimageio.spec and improved type annotations.
@classmethod
def from_xarray(cls, data_array: xr.DataArray) -> Self:
    """Create a `Tensor` from an xarray data array.

    note for internal use: this factory method is round-trip safe
        for any `Tensor`'s `data` property (an xarray.DataArray).
    """
    axis_ids = tuple(AxisId(dim) for dim in data_array.dims)
    return cls(array=data_array.data, dims=axis_ids)
@classmethod
def from_numpy(
    cls,
    array: NDArray[Any],
    *,
    dims: Optional[Union[AxisLike, Sequence[AxisLike]]],
) -> Tensor:
    """create a `Tensor` from a numpy array

    Args:
        array: the nd numpy array
        dims: A description of the array's axes,
            if None axes are guessed (which might fail and raise a ValueError.)

    Raises:
        ValueError: if `dims` is None and axes guessing fails,
            or if the array shape does not map to the given `dims`.
    """

    if dims is None:
        return cls._interprete_array_wo_known_axes(array)
    elif isinstance(dims, AxisId) or not isinstance(dims, collections.abc.Sequence):
        # A single `AxisId` describes exactly one axis.
        # NOTE(review): `AxisId` is presumably str-like (and thus a Sequence);
        # checking it first avoids splitting its name into characters.
        dim_seq = [dims]
    else:
        dim_seq = list(dims)

    axis_infos = [AxisInfo.create(a) for a in dim_seq]
    original_shape = tuple(array.shape)

    successful_view = _get_array_view(array, axis_infos)
    if successful_view is None:
        raise ValueError(
            f"Array shape {original_shape} does not map to axes {dims}"
        )

    return Tensor(successful_view, dims=tuple(a.id for a in axis_infos))
create a Tensor
from a numpy array
Arguments:
- array: the nd numpy array
- dims: A description of the array's axes; if None, the axes are guessed (which might fail and raise a ValueError).
Raises:
- ValueError: if
dims
is None and axes guessing fails.
@property
def dims(self):  # TODO: rename to `axes`?
    """Dimension names of this tensor as a tuple (typed as `AxisId`s).

    The names are read from the wrapped data array; the `cast` only affects
    static typing.
    """
    dim_names = self._data.dims
    return cast(Tuple[AxisId, ...], dim_names)
Tuple of dimension names associated with this tensor.
@property
def shape(self):
    """Tuple of tensor axes lengths (the `shape` of the wrapped data array)."""
    wrapped = self._data
    return wrapped.shape
Tuple of tensor axes lengths
@property
def shape_tuple(self):
    """Tuple of tensor axes lengths (same value as the `shape` of the wrapped data array)."""
    wrapped = self._data
    return wrapped.shape
Tuple of tensor axes lengths
@property
def size(self):
    """Total number of elements in the tensor.

    Equal to math.prod(tensor.shape), i.e., the product of the tensors’ dimensions.
    """
    wrapped = self._data
    return wrapped.size
Number of elements in the tensor.
Equal to math.prod(tensor.shape), i.e., the product of the tensors’ dimensions.
@property
def sizes(self):
    """Ordered, immutable mapping from axis ids to axis lengths.

    Read from the `sizes` of this tensor's `data`; the `cast` only affects
    static typing.
    """
    axis_lengths = self.data.sizes
    return cast(Mapping[AxisId, int], axis_lengths)
Ordered, immutable mapping from axis ids to axis lengths.
@property
def tagged_shape(self):
    """(alias for `sizes`) Ordered, immutable mapping from axis ids to lengths."""
    axis_lengths = self.sizes
    return axis_lengths
(alias for `sizes`) Ordered, immutable mapping from axis ids to lengths.
def astype(self, dtype: DTypeStr, *, copy: bool = False):
    """Return a tensor whose data is cast to `dtype`.

    The cast is delegated to the wrapped data array's `astype`.

    note: if dtype is already satisfied copy if `copy`"""
    converted = self._data.astype(dtype, copy=copy)
    return self.__class__.from_xarray(converted)
Return tensor cast to dtype
note: if dtype is already satisfied copy if copy
def clip(self, min: Optional[float] = None, max: Optional[float] = None):
    """Return a tensor whose values are limited to [min, max].

    The clipping is delegated to the wrapped data array's `clip`.
    At least one of max or min must be given."""
    clipped = self._data.clip(min, max)
    return self.__class__.from_xarray(clipped)
Return a tensor whose values are limited to [min, max]. At least one of max or min must be given.
def crop_to(
    self,
    sizes: PerAxis[int],
    crop_where: Union[
        CropWhere,
        PerAxis[CropWhere],
    ] = "left_and_right",
) -> Self:
    """crop to match `sizes`

    Args:
        sizes: target length per axis; axes not listed (or already at the
            target length) are left untouched.
        crop_where: where to remove elements, either one value for all axes
            or a per-axis mapping. "left" drops leading elements, "right"
            drops trailing elements, "left_and_right" drops from both ends
            (the smaller half of an odd surplus is removed at the start).

    Raises:
        ValueError: if an axis needs cropping but has no entry in a per-axis
            **crop_where**.

    An axis that is smaller than its target size is not changed; a warning
    is logged instead.
    """
    where_per_axis: PerAxis[CropWhere] = (
        {a: crop_where for a in self.dims}
        if isinstance(crop_where, str)
        else crop_where
    )

    slices: Dict[AxisId, SliceInfo] = {}
    for axis, current in self.sizes.items():
        if axis not in sizes:
            continue

        target = sizes[axis]
        if target == current:
            continue

        if target > current:
            logger.warning(
                "Cannot crop axis {} of size {} to larger size {}",
                axis,
                current,
                target,
            )
            continue

        if axis not in where_per_axis:
            raise ValueError(
                f"Don't know where to crop axis {axis}, `crop_where`={crop_where}"
            )

        where = where_per_axis[axis]
        if where == "left":
            slices[axis] = SliceInfo(current - target, current)
        elif where == "right":
            slices[axis] = SliceInfo(0, target)
        elif where == "left_and_right":
            start = (current - target) // 2
            slices[axis] = SliceInfo(start, target + start)
        else:
            assert_never(where)

    return self[slices]
crop to match sizes
def item(
    self,
    key: Union[
        None, SliceInfo, slice, int, PerAxis[Union[SliceInfo, slice, int]]
    ] = None,
):
    """Copy a tensor element to a standard Python scalar and return it.

    Args:
        key: optional selection applied (via indexing this tensor) before the
            element is extracted; with `None` the whole tensor is used.
    """
    selected = self if key is None else self[key]
    ret = selected._data.item()

    assert isinstance(ret, (bool, float, int))
    return ret
Copy a tensor element to a standard Python scalar and return it.
def pad_to(
    self,
    sizes: PerAxis[int],
    pad_where: Union[PadWhere, PerAxis[PadWhere]] = "left_and_right",
    mode: PadMode = "symmetric",
) -> Self:
    """pad `tensor` to match `sizes`

    Args:
        sizes: target length per axis; axes not listed (or already at the
            target length) get zero padding.
        pad_where: where to add elements, either one value for all axes or a
            per-axis mapping. "left" pads before, "right" pads after,
            "left_and_right" splits the padding (the smaller half of an odd
            difference goes to the left).
        mode: padding mode forwarded to `pad`.

    Raises:
        ValueError: if an axis needs padding but has no entry in a per-axis
            **pad_where**.

    An axis that is larger than its target size is not changed; a warning is
    logged instead.
    """
    where_per_axis: PerAxis[PadWhere] = (
        {a: pad_where for a in self.dims}
        if isinstance(pad_where, str)
        else pad_where
    )

    pad_width: Dict[AxisId, PadWidth] = {}
    for axis, current in self.sizes.items():
        if axis not in sizes or sizes[axis] == current:
            pad_width[axis] = PadWidth(0, 0)
            continue

        target = sizes[axis]
        if current > target:
            pad_width[axis] = PadWidth(0, 0)
            logger.warning(
                "Cannot pad axis {} of size {} to smaller size {}",
                axis,
                current,
                target,
            )
            continue

        if axis not in where_per_axis:
            raise ValueError(
                f"Don't know where to pad axis {axis}, `pad_where`={pad_where}"
            )

        where = where_per_axis[axis]
        missing = target - current
        if where == "left":
            pad_width[axis] = PadWidth(missing, 0)
        elif where == "right":
            pad_width[axis] = PadWidth(0, missing)
        elif where == "left_and_right":
            left = missing // 2
            pad_width[axis] = PadWidth(left, missing - left)
        else:
            assert_never(where)

    return self.pad(pad_width, mode)
pad tensor
to match sizes
def quantile(
    self,
    q: Union[float, Sequence[float]],
    dim: Optional[Union[AxisId, Sequence[AxisId]]] = None,
) -> Self:
    """Return the quantile(s) `q` of this tensor, computed by the wrapped data array.

    Args:
        q: a single quantile or a sequence of quantiles; every value is
            asserted to lie within [0, 1].
        dim: axis or axes forwarded to the data array's `quantile`;
            asserted to neither be nor contain an axis named "quantile".
    """
    q_is_scalar = isinstance(q, (float, int))
    # lower bound
    assert q >= 0.0 if q_is_scalar else all(qq >= 0.0 for qq in q)
    # upper bound
    assert q <= 1.0 if q_is_scalar else all(qq <= 1.0 for qq in q)
    assert dim is None or (
        (reserved := AxisId("quantile")) != dim and reserved not in set(dim)
    )
    result = self._data.quantile(q, dim=dim)
    return self.__class__.from_xarray(result)
def resize_to(
    self,
    sizes: PerAxis[int],
    *,
    pad_where: Union[
        PadWhere,
        PerAxis[PadWhere],
    ] = "left_and_right",
    crop_where: Union[
        CropWhere,
        PerAxis[CropWhere],
    ] = "left_and_right",
    pad_mode: PadMode = "symmetric",
):
    """return cropped/padded tensor with `sizes`

    Existing axes that are too long are cropped (see `crop_to`), axes that
    are too short are padded (see `pad_to`), and axes listed in **sizes**
    that this tensor does not have are added via `expand_dims`.
    """
    shrink: Dict[AxisId, int] = {}
    grow: Dict[AxisId, int] = {}
    # entries of `sizes` left over after the loop belong to axes to be added
    remaining = dict(sizes)
    for raw_axis, current in self.sizes.items():
        axis = AxisId(str(raw_axis))
        _ = remaining.pop(axis, None)
        if axis not in sizes:
            continue

        target = sizes[axis]
        if current > target:
            shrink[axis] = target
        elif current < target:
            grow[axis] = target

    tensor = self
    if shrink:
        tensor = tensor.crop_to(shrink, crop_where=crop_where)

    if grow:
        tensor = tensor.pad_to(grow, pad_where=pad_where, mode=pad_mode)

    if remaining:
        tensor = tensor.expand_dims(remaining)

    return tensor
return cropped/padded tensor with sizes
def sum(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
    """Reduce this Tensor's data by applying sum along some dimension(s).

    Args:
        dim: axis or axes forwarded to the wrapped data array's `sum`.
    """
    reduced = self._data.sum(dim=dim)
    return self.__class__.from_xarray(reduced)
Reduce this Tensor's data by applying sum along some dimension(s).
def transpose(
    self,
    axes: Sequence[AxisId],
) -> Self:
    """return a transposed tensor

    Axes requested in **axes** that this tensor does not have yet are added
    via `expand_dims` before transposing.

    Args:
        axes: the desired tensor axes
    """
    present = self.dims
    to_add = tuple(axis for axis in axes if axis not in present)

    # expand missing tensor axes
    data = self._data.expand_dims(to_add) if to_add else self._data

    # transpose to the correct axis order
    return self.__class__.from_xarray(data.transpose(*axes))
return a transposed tensor
Arguments:
- axes: the desired tensor axes
def test_description(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Union[FormatVersionPlaceholder, str] = "discover",
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Optional[str] = None,
    sha256: Optional[Sha256] = None,
    stop_early: bool = True,
    runtime_env: Union[
        Literal["currently-active", "as-described"], Path, BioimageioCondaEnv
    ] = "currently-active",
    run_command: Callable[[Sequence[str]], None] = default_run_command,
    **deprecated: Unpack[DeprecatedKwargs],
) -> ValidationSummary:
    """Test a bioimage.io resource dynamically,
    for example run prediction of test tensors for models.

    Args:
        source: model description source.
        format_version: Forwarded to `load_description_and_test` when testing
            in the currently active environment.
            NOTE(review): not forwarded to `_test_in_env` for other values of
            **runtime_env** -- confirm this is intended.
        weight_format: Weight format to test.
            Default: All weight formats present in **source**.
        devices: Devices to test with, e.g. 'cpu', 'cuda'.
            Default (may be weight format dependent): ['cuda'] if available, ['cpu'] otherwise.
        determinism: Modes to improve reproducibility of test outputs.
        expected_type: Assert an expected resource description `type`.
        sha256: Expected SHA256 value of **source**.
            (Ignored if **source** already is a loaded `ResourceDescr` object.)
        stop_early: Do not run further subtests after a failed one.
        runtime_env: (Experimental feature!) The Python environment to run the tests in
            - `"currently-active"`: Use active Python interpreter.
            - `"as-described"`: Use `bioimageio.spec.get_conda_env` to generate a conda
                environment YAML file based on the model weights description.
            - A `BioimageioCondaEnv` or a path to a conda environment YAML file.
                Note: The `bioimageio.core` dependency will be added automatically if not present.
        run_command: (Experimental feature!) Function to execute (conda) terminal commands in a subprocess.
            The function should raise an exception if the command fails.
            **run_command** is ignored if **runtime_env** is `"currently-active"`.

    Raises:
        RuntimeError: if **run_command** does not raise for a deliberately
            failing probe command.
    """
    if runtime_env == "currently-active":
        # test in-process; `run_command` is not needed
        tested = load_description_and_test(
            source,
            format_version=format_version,
            weight_format=weight_format,
            devices=devices,
            determinism=determinism,
            expected_type=expected_type,
            sha256=sha256,
            stop_early=stop_early,
            **deprecated,
        )
        return tested.validation_summary

    if runtime_env == "as-described":
        conda_env = None
    elif isinstance(runtime_env, (str, Path)):
        env_content = read_yaml(Path(runtime_env))
        conda_env = BioimageioCondaEnv.model_validate(env_content)
    elif isinstance(runtime_env, BioimageioCondaEnv):
        conda_env = runtime_env
    else:
        assert_never(runtime_env)

    # probe `run_command`: it must signal failures by raising
    probe_raised = False
    try:
        run_command(["thiscommandshouldalwaysfail", "please"])
    except Exception:
        probe_raised = True

    if not probe_raised:
        raise RuntimeError(
            "given run_command does not raise an exception for a failing command"
        )

    # `ignore_cleanup_errors` is only passed on Python >= 3.10
    td_kwargs: Dict[str, Any] = (
        {"ignore_cleanup_errors": True} if sys.version_info >= (3, 10) else {}
    )
    with TemporaryDirectory(**td_kwargs) as tmp_dir:
        working_dir = Path(tmp_dir)
        if isinstance(source, (dict, ResourceDescrBase)):
            # in-memory descriptions are packaged to a file first
            file_source = save_bioimageio_package(
                source, output_path=working_dir / "package.zip"
            )
        else:
            file_source = source

        return _test_in_env(
            file_source,
            working_dir=working_dir,
            weight_format=weight_format,
            conda_env=conda_env,
            devices=devices,
            determinism=determinism,
            expected_type=expected_type,
            sha256=sha256,
            stop_early=stop_early,
            run_command=run_command,
            **deprecated,
        )
Test a bioimage.io resource dynamically, for example run prediction of test tensors for models.
Arguments:
- source: model description source.
- weight_format: Weight format to test. Default: All weight formats present in source.
- devices: Devices to test with, e.g. 'cpu', 'cuda'. Default (may be weight format dependent): ['cuda'] if available, ['cpu'] otherwise.
- determinism: Modes to improve reproducibility of test outputs.
- expected_type: Assert an expected resource description `type`.
- sha256: Expected SHA256 value of source. (Ignored if source already is a loaded `ResourceDescr` object.)
- stop_early: Do not run further subtests after a failed one.
- runtime_env: (Experimental feature!) The Python environment to run the tests in:
  - `"currently-active"`: Use the active Python interpreter.
  - `"as-described"`: Use `bioimageio.spec.get_conda_env` to generate a conda environment YAML file based on the model weights description.
  - A `BioimageioCondaEnv` or a path to a conda environment YAML file. Note: The `bioimageio.core` dependency will be added automatically if not present.
- run_command: (Experimental feature!) Function to execute (conda) terminal commands in a subprocess.
The function should raise an exception if the command fails.
run_command is ignored if runtime_env is `"currently-active"`.
def test_model(
    source: Union[v0_4.ModelDescr, v0_5.ModelDescr, PermissiveFileSource],
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[List[str]] = None,
    *,
    determinism: Literal["seed_only", "full"] = "seed_only",
    sha256: Optional[Sha256] = None,
    stop_early: bool = True,
    **deprecated: Unpack[DeprecatedKwargs],
) -> ValidationSummary:
    """Test model inference

    Thin wrapper around `test_description` that fixes `expected_type="model"`
    and forwards all other arguments unchanged.
    """
    summary = test_description(
        source,
        weight_format=weight_format,
        devices=devices,
        determinism=determinism,
        expected_type="model",
        sha256=sha256,
        stop_early=stop_early,
        **deprecated,
    )
    return summary
Test model inference
def test_description(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Union[FormatVersionPlaceholder, str] = "discover",
    weight_format: Optional[SupportedWeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Optional[str] = None,
    sha256: Optional[Sha256] = None,
    stop_early: bool = True,
    runtime_env: Union[
        Literal["currently-active", "as-described"], Path, BioimageioCondaEnv
    ] = "currently-active",
    run_command: Callable[[Sequence[str]], None] = default_run_command,
    **deprecated: Unpack[DeprecatedKwargs],
) -> ValidationSummary:
    """Test a bioimage.io resource dynamically,
    for example run prediction of test tensors for models.

    Args:
        source: model description source.
        format_version: Forwarded to `load_description_and_test` when testing
            in the currently active environment.
            NOTE(review): not forwarded to `_test_in_env` for other values of
            **runtime_env** -- confirm this is intended.
        weight_format: Weight format to test.
            Default: All weight formats present in **source**.
        devices: Devices to test with, e.g. 'cpu', 'cuda'.
            Default (may be weight format dependent): ['cuda'] if available, ['cpu'] otherwise.
        determinism: Modes to improve reproducibility of test outputs.
        expected_type: Assert an expected resource description `type`.
        sha256: Expected SHA256 value of **source**.
            (Ignored if **source** already is a loaded `ResourceDescr` object.)
        stop_early: Do not run further subtests after a failed one.
        runtime_env: (Experimental feature!) The Python environment to run the tests in
            - `"currently-active"`: Use active Python interpreter.
            - `"as-described"`: Use `bioimageio.spec.get_conda_env` to generate a conda
                environment YAML file based on the model weights description.
            - A `BioimageioCondaEnv` or a path to a conda environment YAML file.
                Note: The `bioimageio.core` dependency will be added automatically if not present.
        run_command: (Experimental feature!) Function to execute (conda) terminal commands in a subprocess.
            The function should raise an exception if the command fails.
            **run_command** is ignored if **runtime_env** is `"currently-active"`.

    Raises:
        RuntimeError: if **run_command** does not raise for a deliberately
            failing probe command.
    """
    if runtime_env == "currently-active":
        # test in-process; `run_command` is not needed
        tested = load_description_and_test(
            source,
            format_version=format_version,
            weight_format=weight_format,
            devices=devices,
            determinism=determinism,
            expected_type=expected_type,
            sha256=sha256,
            stop_early=stop_early,
            **deprecated,
        )
        return tested.validation_summary

    if runtime_env == "as-described":
        conda_env = None
    elif isinstance(runtime_env, (str, Path)):
        env_content = read_yaml(Path(runtime_env))
        conda_env = BioimageioCondaEnv.model_validate(env_content)
    elif isinstance(runtime_env, BioimageioCondaEnv):
        conda_env = runtime_env
    else:
        assert_never(runtime_env)

    # probe `run_command`: it must signal failures by raising
    probe_raised = False
    try:
        run_command(["thiscommandshouldalwaysfail", "please"])
    except Exception:
        probe_raised = True

    if not probe_raised:
        raise RuntimeError(
            "given run_command does not raise an exception for a failing command"
        )

    # `ignore_cleanup_errors` is only passed on Python >= 3.10
    td_kwargs: Dict[str, Any] = (
        {"ignore_cleanup_errors": True} if sys.version_info >= (3, 10) else {}
    )
    with TemporaryDirectory(**td_kwargs) as tmp_dir:
        working_dir = Path(tmp_dir)
        if isinstance(source, (dict, ResourceDescrBase)):
            # in-memory descriptions are packaged to a file first
            file_source = save_bioimageio_package(
                source, output_path=working_dir / "package.zip"
            )
        else:
            file_source = source

        return _test_in_env(
            file_source,
            working_dir=working_dir,
            weight_format=weight_format,
            conda_env=conda_env,
            devices=devices,
            determinism=determinism,
            expected_type=expected_type,
            sha256=sha256,
            stop_early=stop_early,
            run_command=run_command,
            **deprecated,
        )
alias of test_description
def validate_format(
    data: BioimageioYamlContent,
    /,
    *,
    format_version: Union[Literal["discover", "latest"], str] = DISCOVER,
    context: Optional[ValidationContext] = None,
) -> ValidationSummary:
    """Validate a dictionary holding a bioimageio description.
    See `bioimageio.spec.load_description_and_validate_format_only`
    to validate a file source.

    Args:
        data: Dictionary holding the raw bioimageio.yaml content.
        format_version:
            Format version to (update to and) use for validation.
            Note:
            - Use "latest" to convert to the latest available format version.
            - Use "discover" to use the format version specified in the RDF.
            - Only considers major.minor format version, ignores patch version.
            - Conversion to lower format versions is not supported.
        context: Validation context, see `bioimageio.spec.ValidationContext`.
            If not given (or falsy) the result of `get_validation_context()`
            is used.

    Note:
        Use `bioimageio.spec.load_description_and_validate_format_only` to validate a
        file source instead of loading the YAML content and creating the appropriate
        `ValidationContext`.

        Alternatively you can use `bioimageio.spec.load_description` and access the
        `validation_summary` attribute of the returned object.
    """
    active_context = context or get_validation_context()
    with active_context:
        rd = build_description(data, format_version=format_version)

    summary = rd.validation_summary
    assert summary is not None
    return summary
Validate a dictionary holding a bioimageio description.
See `bioimageio.spec.load_description_and_validate_format_only`
to validate a file source.
Arguments:
- data: Dictionary holding the raw bioimageio.yaml content.
- format_version: Format version to (update to and) use for validation.
Note:
- Use "latest" to convert to the latest available format version.
- Use "discover" to use the format version specified in the RDF.
- Only considers major.minor format version, ignores patch version.
- Conversion to lower format versions is not supported.
- context: Validation context, see
`bioimageio.spec.ValidationContext`
Note:
Use `bioimageio.spec.load_description_and_validate_format_only` to validate a file source instead of loading the YAML content and creating the appropriate `ValidationContext`.
Alternatively you can use `bioimageio.spec.load_description` and access the `validation_summary` attribute of the returned object.
class ValidationSummary(BaseModel, extra="allow"):
    """Summarizes output of all bioimageio validations and tests
    for one specific `ResourceDescr` instance."""

    name: str
    """Name of the validation"""
    source_name: str
    """Source of the validated bioimageio description"""
    id: Optional[str] = None
    """ID of the resource being validated"""
    type: str
    """Type of the resource being validated"""
    format_version: str
    """Format version of the resource being validated"""
    status: Literal["passed", "valid-format", "failed"]
    """overall status of the bioimageio validation"""
    metadata_completeness: Annotated[float, annotated_types.Interval(ge=0, le=1)] = 0.0
    """Estimate of completeness of the metadata in the resource description.

    Note: This completeness estimate may change with subsequent releases
        and should be considered bioimageio.spec version specific.
    """

    details: List[ValidationDetail]
    """List of validation details"""
    env: Set[InstalledPackage] = Field(
        default_factory=lambda: {
            InstalledPackage(
                name="bioimageio.spec",
                version=VERSION,
            )
        }
    )
    """List of selected, relevant package versions"""

    # cache for the `conda_list` property; filled lazily
    saved_conda_list: Optional[str] = None

    @field_serializer("saved_conda_list")
    def _save_conda_list(self, value: Optional[str]):
        # always serialize the (lazily computed) `conda_list`, ignoring `value`
        return self.conda_list

    @property
    def conda_list(self):
        """Output of `conda list`, computed once and cached in `saved_conda_list`.

        If the command produces no output, or cannot be started at all, a
        short explanatory message is cached instead.
        """
        if self.saved_conda_list is None:
            try:
                p = subprocess.run(
                    [CONDA_CMD, "list"],
                    stdout=subprocess.PIPE,
                    stderr=subprocess.STDOUT,
                    shell=False,
                    text=True,
                )
            except OSError as e:
                # e.g. the conda executable is not installed; without this the
                # serializer above would make every JSON dump of a summary fail
                self.saved_conda_list = f"`conda list` could not be run: {e}"
            else:
                self.saved_conda_list = (
                    p.stdout or f"`conda list` exited with {p.returncode}"
                )

        return self.saved_conda_list

    @property
    def status_icon(self):
        """Icon matching `status`."""
        if self.status == "passed":
            return "✔️"
        elif self.status == "valid-format":
            return "🟡"
        else:
            return "❌"

    @property
    def errors(self) -> List[ErrorEntry]:
        """All errors of all `details`, flattened into one list."""
        return list(chain.from_iterable(d.errors for d in self.details))

    @property
    def warnings(self) -> List[WarningEntry]:
        """All warnings of all `details`, flattened into one list."""
        return list(chain.from_iterable(d.warnings for d in self.details))

    def format(
        self,
        *,
        width: Optional[int] = None,
        include_conda_list: bool = False,
    ):
        """Format summary as Markdown string"""
        return self._format(
            width=width, target="md", include_conda_list=include_conda_list
        )

    format_md = format

    def format_html(
        self,
        *,
        width: Optional[int] = None,
        include_conda_list: bool = False,
    ):
        """Format summary as HTML string (rendered from the "html" target output)."""
        md_with_html = self._format(
            target="html", width=width, include_conda_list=include_conda_list
        )
        return markdown.markdown(
            md_with_html, extensions=["tables", "fenced_code", "nl2br"]
        )

    def display(
        self,
        *,
        width: Optional[int] = None,
        include_conda_list: bool = False,
        tab_size: int = 4,
        soft_wrap: bool = True,
    ) -> None:
        """Show the summary: as HTML inside IPython, otherwise via a rich console."""
        try:  # render as HTML in Jupyter notebook
            from IPython.core.getipython import get_ipython
            from IPython.display import (
                display_html,  # pyright: ignore[reportUnknownVariableType]
            )
        except ImportError:
            pass
        else:
            if get_ipython() is not None:
                _ = display_html(
                    self.format_html(
                        width=width, include_conda_list=include_conda_list
                    ),
                    raw=True,
                )
                return

        # render with rich
        _ = self._format(
            target=rich.console.Console(
                width=width,
                tab_size=tab_size,
                soft_wrap=soft_wrap,
            ),
            width=width,
            include_conda_list=include_conda_list,
        )

    def add_detail(self, detail: ValidationDetail):
        """Append **detail**; a failed detail sets the overall `status` to "failed"."""
        if detail.status == "failed":
            self.status = "failed"
        elif detail.status != "passed":
            assert_never(detail.status)

        self.details.append(detail)

    def log(
        self,
        to: Union[Literal["display"], Path, Sequence[Union[Literal["display"], Path]]],
    ) -> List[Path]:
        """Convenience method to display the validation summary in the terminal and/or
        save it to disk. See `save` for details."""
        if to == "display":
            display = True
            save_to = []
        elif isinstance(to, Path):
            display = False
            save_to = [to]
        else:
            display = "display" in to
            save_to = [p for p in to if p != "display"]

        if display:
            self.display()

        return self.save(save_to)

    def save(
        self, path: Union[Path, Sequence[Path]] = Path("{id}_summary_{now}")
    ) -> List[Path]:
        """Save the validation/test summary in JSON, Markdown or HTML format.

        Returns:
            List of file paths the summary was saved to
            (with the `{id}` and `{now}` placeholders filled in).

        Raises:
            ValueError: for a file suffix other than `.json`, `.md`, `.html`.

        Notes:
        - Format is chosen based on the suffix: `.json`, `.md`, `.html`.
        - If **path** has no suffix it is assumed to be a directory to which a
          `summary.json`, `summary.md` and `summary.html` are saved to.
        - `{id}` is replaced by the resource id (or "bioimageio" if unset) and
          `{now}` by the current UTC timestamp.
        """
        if isinstance(path, (str, Path)):
            path = [Path(path)]

        # folder to file paths
        file_paths: List[Path] = []
        for p in path:
            if p.suffix:
                file_paths.append(p)
            else:
                file_paths.extend(
                    [
                        p / "summary.json",
                        p / "summary.md",
                        p / "summary.html",
                    ]
                )

        now = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
        # collect the formatted paths: these are the files actually written
        saved_paths: List[Path] = []
        for p in file_paths:
            p = Path(str(p).format(id=self.id or "bioimageio", now=now))
            if p.suffix == ".json":
                self.save_json(p)
            elif p.suffix == ".md":
                self.save_markdown(p)
            elif p.suffix == ".html":
                self.save_html(p)
            else:
                raise ValueError(f"Unknown summary path suffix '{p.suffix}'")

            saved_paths.append(p)

        return saved_paths

    def save_json(
        self, path: Path = Path("summary.json"), *, indent: Optional[int] = 2
    ):
        """Save validation/test summary as JSON file."""
        json_str = self.model_dump_json(indent=indent)
        path.parent.mkdir(exist_ok=True, parents=True)
        _ = path.write_text(json_str, encoding="utf-8")
        logger.info("Saved summary to {}", path.absolute())

    def save_markdown(self, path: Path = Path("summary.md")):
        """Save rendered validation/test summary as Markdown file."""
        formatted = self.format_md()
        path.parent.mkdir(exist_ok=True, parents=True)
        _ = path.write_text(formatted, encoding="utf-8")
        logger.info("Saved Markdown formatted summary to {}", path.absolute())

    def save_html(self, path: Path = Path("summary.html")) -> None:
        """Save rendered validation/test summary as HTML file."""
        path.parent.mkdir(exist_ok=True, parents=True)

        html = self.format_html()
        _ = path.write_text(html, encoding="utf-8")
        logger.info("Saved HTML formatted summary to {}", path.absolute())

    @classmethod
    def load_json(cls, path: Path) -> Self:
        """Load validation/test summary from a suitable JSON file"""
        json_str = Path(path).read_text(encoding="utf-8")
        return cls.model_validate_json(json_str)

    @field_validator("env", mode="before")
    def _convert_dict(cls, value: List[Union[List[str], Dict[str, str]]]):
        """convert old env value for backwards compatibility"""
        if isinstance(value, list):
            return [
                (
                    (v["name"], v["version"], v.get("build", ""), v.get("channel", ""))
                    if isinstance(v, dict) and "name" in v and "version" in v
                    else v
                )
                for v in value
            ]
        else:
            return value

    def _format(
        self,
        *,
        target: Union[rich.console.Console, Literal["html", "md"]],
        width: Optional[int],
        include_conda_list: bool,
    ):
        # shared implementation of all renderers; a falsy width falls back to 100
        return _format_summary(
            self,
            target=target,
            width=width or 100,
            include_conda_list=include_conda_list,
        )
Summarizes output of all bioimageio validations and tests
for one specific ResourceDescr
instance.
Estimate of completeness of the metadata in the resource description.
Note: This completeness estimate may change with subsequent releases and should be considered bioimageio.spec version specific.
@property
def conda_list(self):
    """Output of `conda list`, computed once and cached in `saved_conda_list`.

    If the command produces no output, or cannot be started at all, a short
    explanatory message is cached instead.
    """
    if self.saved_conda_list is None:
        try:
            p = subprocess.run(
                [CONDA_CMD, "list"],
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,
                shell=False,
                text=True,
            )
        except OSError as e:
            # e.g. the conda executable is not installed; report this instead
            # of letting every access (and every serialization) crash
            self.saved_conda_list = f"`conda list` could not be run: {e}"
        else:
            self.saved_conda_list = (
                p.stdout or f"`conda list` exited with {p.returncode}"
            )

    return self.saved_conda_list
def format(
    self,
    *,
    width: Optional[int] = None,
    include_conda_list: bool = False,
):
    """Format summary as Markdown string

    Delegates to `_format` with the "md" target.
    """
    rendered = self._format(
        target="md",
        width=width,
        include_conda_list=include_conda_list,
    )
    return rendered
Format summary as Markdown string
def format(
    self,
    *,
    width: Optional[int] = None,
    include_conda_list: bool = False,
):
    """Format summary as Markdown string

    Delegates to `_format` with the "md" target.
    """
    rendered = self._format(
        target="md",
        width=width,
        include_conda_list=include_conda_list,
    )
    return rendered
Format summary as Markdown string
def format_html(
    self,
    *,
    width: Optional[int] = None,
    include_conda_list: bool = False,
):
    """Format summary as HTML string.

    The summary is first rendered via `_format` with the "html" target and
    then converted with `markdown.markdown`.
    """
    md_with_html = self._format(
        target="html",
        width=width,
        include_conda_list=include_conda_list,
    )
    md_extensions = ["tables", "fenced_code", "nl2br"]
    return markdown.markdown(md_with_html, extensions=md_extensions)
def display(
    self,
    *,
    width: Optional[int] = None,
    include_conda_list: bool = False,
    tab_size: int = 4,
    soft_wrap: bool = True,
) -> None:
    """Show the summary.

    If IPython can be imported and an IPython shell is active, the summary is
    displayed as HTML; otherwise it is rendered to a `rich` console created
    with **width**, **tab_size** and **soft_wrap**.
    """
    try:  # render as HTML in Jupyter notebook
        from IPython.core.getipython import get_ipython
        from IPython.display import (
            display_html,  # pyright: ignore[reportUnknownVariableType]
        )
    except ImportError:
        pass
    else:
        if get_ipython() is not None:
            html = self.format_html(
                width=width, include_conda_list=include_conda_list
            )
            _ = display_html(html, raw=True)
            return

    # render with rich
    console = rich.console.Console(
        width=width,
        tab_size=tab_size,
        soft_wrap=soft_wrap,
    )
    _ = self._format(
        target=console,
        width=width,
        include_conda_list=include_conda_list,
    )
def log(
    self,
    to: Union[Literal["display"], Path, Sequence[Union[Literal["display"], Path]]],
) -> List[Path]:
    """Convenience method to display the validation summary in the terminal and/or
    save it to disk. See `save` for details.

    Args:
        to: "display", a path, or a sequence mixing both. "display" triggers
            `display()`; every other entry is passed on to `save`.

    Returns:
        The return value of `save` for the collected paths.
    """
    # normalize the single-target forms to a list of targets
    if to == "display":
        targets = ["display"]
    elif isinstance(to, Path):
        targets = [to]
    else:
        targets = to

    show = "display" in targets
    save_to = [t for t in targets if t != "display"]

    if show:
        self.display()

    return self.save(save_to)
Convenience method to display the validation summary in the terminal and/or
save it to disk. See save
for details.
def save(
    self, path: Union[Path, Sequence[Path]] = Path("{id}_summary_{now}")
) -> List[Path]:
    """Save the validation/test summary in JSON, Markdown or HTML format.

    Returns:
        List of file paths the summary was saved to
        (with the `{id}` and `{now}` placeholders filled in).

    Raises:
        ValueError: for a file suffix other than `.json`, `.md`, `.html`.

    Notes:
    - Format is chosen based on the suffix: `.json`, `.md`, `.html`.
    - If **path** has no suffix it is assumed to be a directory to which a
      `summary.json`, `summary.md` and `summary.html` are saved to.
    - `{id}` is replaced by the resource id (or "bioimageio" if unset) and
      `{now}` by the current UTC timestamp.
    """
    if isinstance(path, (str, Path)):
        path = [Path(path)]

    # folder to file paths
    file_paths: List[Path] = []
    for p in path:
        if p.suffix:
            file_paths.append(p)
        else:
            file_paths.extend(
                [
                    p / "summary.json",
                    p / "summary.md",
                    p / "summary.html",
                ]
            )

    now = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
    # Collect the formatted paths: these (not the `{id}`/`{now}` templates in
    # `file_paths`) are the files that actually get written.
    saved_paths: List[Path] = []
    for p in file_paths:
        p = Path(str(p).format(id=self.id or "bioimageio", now=now))
        if p.suffix == ".json":
            self.save_json(p)
        elif p.suffix == ".md":
            self.save_markdown(p)
        elif p.suffix == ".html":
            self.save_html(p)
        else:
            raise ValueError(f"Unknown summary path suffix '{p.suffix}'")

        saved_paths.append(p)

    return saved_paths
Save the validation/test summary in JSON, Markdown or HTML format.
Returns:
List of file paths the summary was saved to.
Notes:
- Format is chosen based on the suffix: `.json`, `.md`, `.html`.
- If path has no suffix it is assumed to be a directory to which `summary.json`, `summary.md` and `summary.html` are saved.
def save_json(
    self, path: Path = Path("summary.json"), *, indent: Optional[int] = 2
):
    """Save validation/test summary as JSON file.

    Args:
        path: destination file; missing parent directories are created.
        indent: indentation passed to `model_dump_json`.
    """
    serialized = self.model_dump_json(indent=indent)
    target_dir = path.parent
    target_dir.mkdir(exist_ok=True, parents=True)
    _ = path.write_text(serialized, encoding="utf-8")
    logger.info("Saved summary to {}", path.absolute())
Save validation/test summary as JSON file.
def save_markdown(self, path: Path = Path("summary.md")):
    """Save rendered validation/test summary as Markdown file.

    Args:
        path: destination file; missing parent directories are created.
    """
    markdown_text = self.format_md()
    target_dir = path.parent
    target_dir.mkdir(exist_ok=True, parents=True)
    _ = path.write_text(markdown_text, encoding="utf-8")
    logger.info("Saved Markdown formatted summary to {}", path.absolute())
Save rendered validation/test summary as Markdown file.
def save_html(self, path: Path = Path("summary.html")) -> None:
    """Save rendered validation/test summary as HTML file.

    Args:
        path: destination file; missing parent directories are created.
    """
    target_dir = path.parent
    target_dir.mkdir(exist_ok=True, parents=True)

    rendered = self.format_html()
    _ = path.write_text(rendered, encoding="utf-8")
    logger.info("Saved HTML formatted summary to {}", path.absolute())
Save rendered validation/test summary as HTML file.
@classmethod
def load_json(cls, path: Path) -> Self:
    """Load validation/test summary from a suitable JSON file

    The file is read as UTF-8 text and validated with `model_validate_json`.
    """
    source_file = Path(path)
    json_str = source_file.read_text(encoding="utf-8")
    return cls.model_validate_json(json_str)
Load validation/test summary from a suitable JSON file