# bioimageio.core

Python-specific core utilities for bioimage.io resources (in particular DL models).
## Get started

To get started we recommend installing bioimageio.core with conda together with a deep
learning framework, e.g. pytorch, and running a few `bioimageio` commands to see what
bioimageio.core has to offer:

- install with conda (for more details on conda environments, check out the conda docs)

  ```console
  conda install -c conda-forge bioimageio.core pytorch
  ```
- test a model

  ```console
  $ bioimageio test powerful-chipmunk
  ...
  ```

  <details>
  <summary>(Click to expand output)</summary>

  ```console
  ✔️ bioimageio validation passed
  source            https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/powerful-chipmunk/1/files/rdf.yaml
  format version    model 0.4.10
  bioimageio.spec   0.5.3post4
  bioimageio.core   0.6.8

  ❓ location  detail
  ✔️           initialized ModelDescr to describe model 0.4.10
  ✔️           bioimageio.spec format validation model 0.4.10
  🔍           context.perform_io_checks  True
  🔍           context.root  https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/powerful-chipmunk/1/files
  🔍           context.known_files.weights.pt  3bd9c518c8473f1e35abb7624f82f3aa92f1015e66fb1f6a9d08444e1f2f5698
  🔍           context.known_files.weights-torchscript.pt  4e568fd81c0ffa06ce13061327c3f673e1bac808891135badd3b0fcdacee086b
  🔍           context.warning_level  error
  ✔️           Reproduce test outputs from test inputs
  ✔️           Reproduce test outputs from test inputs
  ```

  </details>
  or

  ```console
  $ bioimageio test impartial-shrimp
  ...
  ```

  <details>
  <summary>(Click to expand output)</summary>

  ```console
  ✔️ bioimageio validation passed
  source            https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/impartial-shrimp/1.1/files/rdf.yaml
  format version    model 0.5.3
  bioimageio.spec   0.5.3.2
  bioimageio.core   0.6.9

  ❓ location  detail
  ✔️           initialized ModelDescr to describe model 0.5.3
  ✔️           bioimageio.spec format validation model 0.5.3
  🔍           context.perform_io_checks  False
  🔍           context.warning_level  error
  ✔️           Reproduce test outputs from test inputs (pytorch_state_dict)
  ✔️           Run pytorch_state_dict inference for inputs with batch_size: 1 and size parameter n: 0
  ✔️           Run pytorch_state_dict inference for inputs with batch_size: 2 and size parameter n: 0
  ✔️           Run pytorch_state_dict inference for inputs with batch_size: 1 and size parameter n: 1
  ✔️           Run pytorch_state_dict inference for inputs with batch_size: 2 and size parameter n: 1
  ✔️           Run pytorch_state_dict inference for inputs with batch_size: 1 and size parameter n: 2
  ✔️           Run pytorch_state_dict inference for inputs with batch_size: 2 and size parameter n: 2
  ✔️           Reproduce test outputs from test inputs (torchscript)
  ✔️           Run torchscript inference for inputs with batch_size: 1 and size parameter n: 0
  ✔️           Run torchscript inference for inputs with batch_size: 2 and size parameter n: 0
  ✔️           Run torchscript inference for inputs with batch_size: 1 and size parameter n: 1
  ✔️           Run torchscript inference for inputs with batch_size: 2 and size parameter n: 1
  ✔️           Run torchscript inference for inputs with batch_size: 1 and size parameter n: 2
  ✔️           Run torchscript inference for inputs with batch_size: 2 and size parameter n: 2
  ```

  </details>
- run prediction on your data

  display the `bioimageio predict` command help to get an overview:

  ```console
  $ bioimageio predict --help
  ...
  ```

  <details>
  <summary>(Click to expand output)</summary>

  ````console
  usage: bioimageio predict [-h] [--inputs Sequence[Union[str,Annotated[Tuple[str,...],MinLen(min_length=1)]]]]
                            [--outputs {str,Tuple[str,...]}] [--overwrite bool] [--blockwise bool]
                            [--stats Path] [--preview bool]
                            [--weight_format {typing.Literal['keras_hdf5','onnx','pytorch_state_dict','tensorflow_js','tensorflow_saved_model_bundle','torchscript'],any}]
                            [--example bool]
                            SOURCE

  bioimageio-predict - Run inference on your data with a bioimage.io model.

  positional arguments:
    SOURCE                Url/path to a bioimageio.yaml/rdf.yaml file
                          or a bioimage.io resource identifier, e.g. 'affable-shark'

  optional arguments:
    -h, --help            show this help message and exit
    --inputs Sequence[Union[str,Annotated[Tuple[str,...],MinLen(min_length=1)]]]
                          Model input sample paths (for each input tensor)

                          The input paths are expected to have shape...
                           - (n_samples,) or (n_samples,1) for models expecting a single input tensor
                           - (n_samples,) containing the substring '{input_id}', or
                           - (n_samples, n_model_inputs) to provide each input tensor path explicitly.

                          All substrings that are replaced by metadata from the model description:
                           - '{model_id}'
                           - '{input_id}'

                          Example inputs to process sample 'a' and 'b'
                          for a model expecting a 'raw' and a 'mask' input tensor:
                          --inputs="[[\"a_raw.tif\",\"a_mask.tif\"],[\"b_raw.tif\",\"b_mask.tif\"]]"
                          (Note that JSON double quotes need to be escaped.)

                          Alternatively a `bioimageio-cli.yaml` (or `bioimageio-cli.json`) file
                          may provide the arguments, e.g.:
                          ```yaml
                          inputs:
                          - [a_raw.tif, a_mask.tif]
                          - [b_raw.tif, b_mask.tif]
                          ```

                          `.npy` and any file extension supported by imageio are supported.
                          Available formats are listed at
                          https://imageio.readthedocs.io/en/stable/formats/index.html#all-formats.
                          Some formats have additional dependencies. (default: ('{input_id}/001.tif',))
    --outputs {str,Tuple[str,...]}
                          Model output path pattern (per output tensor)

                          All substrings that are replaced:
                          - '{model_id}' (from model description)
                          - '{output_id}' (from model description)
                          - '{sample_id}' (extracted from input paths)
                          (default: outputs_{model_id}/{output_id}/{sample_id}.tif)
    --overwrite bool      allow overwriting existing output files (default: False)
    --blockwise bool      process inputs blockwise (default: False)
    --stats Path          path to dataset statistics (will be written if it does not exist,
                          but the model requires statistical dataset measures)
                          (default: dataset_statistics.json)
    --preview bool        preview which files would be processed and what outputs would be
                          generated. (default: False)
    --weight_format {typing.Literal['keras_hdf5','onnx','pytorch_state_dict','tensorflow_js','tensorflow_saved_model_bundle','torchscript'],any}
                          The weight format to use. (default: any)
    --example bool        generate and run an example

                          1. downloads example model inputs
                          2. creates a `{model_id}_example` folder
                          3. writes input arguments to `{model_id}_example/bioimageio-cli.yaml`
                          4. executes a preview dry-run
                          5. executes prediction with example input
                          (default: False)
  ````

  </details>
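The `{model_id}`, `{output_id}` and `{sample_id}` substitutions described in the help text behave like ordinary Python format fields; as a rough illustration of the default output pattern (the concrete values below are made up for demonstration):

```python
# Illustration of how the CLI's output path pattern is filled in.
# The placeholder names come from the help text above; the concrete
# values are made up for this example.
pattern = "outputs_{model_id}/{output_id}/{sample_id}.tif"

out_path = pattern.format(
    model_id="impartial-shrimp",  # from the model description
    output_id="output0",          # from the model description
    sample_id="1",                # extracted from the input paths
)
print(out_path)  # outputs_impartial-shrimp/output0/1.tif
```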
- create an example and run prediction locally!

  ```console
  $ bioimageio predict impartial-shrimp --example=True
  ...
  ```

  <details>
  <summary>(Click to expand output)</summary>
  ```console
  🛈 bioimageio prediction preview structure:
  {'{sample_id}': {'inputs': {'{input_id}': '<input path>'},
                   'outputs': {'{output_id}': '<output path>'}}}
  🔎 bioimageio prediction preview output:
  {'1': {'inputs': {'input0': 'impartial-shrimp_example/input0/001.tif'},
         'outputs': {'output0': 'impartial-shrimp_example/outputs/output0/1.tif'}}}
  predict with impartial-shrimp: 100%|███████████████████████████████████████████████████| 1/1 [00:21<00:00, 21.76s/sample]
  🎉 Successfully ran example prediction!

  To predict the example input using the CLI example config file impartial-shrimp_example\bioimageio-cli.yaml, execute `bioimageio predict` from impartial-shrimp_example:

  $ cd impartial-shrimp_example
  $ bioimageio predict "impartial-shrimp"

  Alternatively run the following command in the current working directory, not the example folder:

  $ bioimageio predict --preview=False --overwrite=True --stats="impartial-shrimp_example/dataset_statistics.json" --inputs="[[\"impartial-shrimp_example/input0/001.tif\"]]" --outputs="impartial-shrimp_example/outputs/{output_id}/{sample_id}.tif" "impartial-shrimp"

  (note that a local 'bioimageio-cli.json' or 'bioimageio-cli.yaml' may interfere with this)
  ```
</details>
## Installation

### Via Mamba/Conda

The `bioimageio.core` package can be installed from conda-forge via

```console
mamba install -c conda-forge bioimageio.core
```

If you do not install any additional deep learning libraries, you will only be able to use the general convenience functionality, but not any functionality for model prediction. To install additional deep learning libraries, use:

#### PyTorch/TorchScript

CPU installation (if you don't have an NVIDIA graphics card):

```console
mamba install -c pytorch -c conda-forge bioimageio.core pytorch torchvision cpuonly
```

GPU installation (for CUDA 11.8; please choose the appropriate CUDA version for your system):

```console
mamba install -c pytorch -c nvidia -c conda-forge bioimageio.core pytorch torchvision pytorch-cuda=11.8
```

Note that the PyTorch installation instructions may change in the future. For the latest instructions please refer to pytorch.org.
#### TensorFlow

Currently only the CPU version is supported.

```console
mamba install -c conda-forge bioimageio.core tensorflow
```

#### ONNXRuntime

Currently only the CPU version is supported.

```console
mamba install -c conda-forge bioimageio.core onnxruntime
```
### Via pip

The package is also available via pip (e.g. with recommended extras `onnx` and `pytorch`):

```console
pip install "bioimageio.core[onnx,pytorch]"
```
### Set up Development Environment

To set up a development conda environment run the following commands:

```console
mamba env create -f dev/env.yaml
mamba activate core
pip install -e . --no-deps
```

There are different environment files available that only install tensorflow or pytorch as dependencies.
## 💻 Use the Command Line Interface

`bioimageio.core` installs a command line interface (CLI) for testing models and other functionality. You can list all the available commands via:

```console
bioimageio
```
### CLI inputs from file

For convenience, the command line options (not arguments) may be given in a `bioimageio-cli.json` or `bioimageio-cli.yaml` file, e.g.:

```yaml
# bioimageio-cli.yaml
inputs: inputs/*_{tensor_id}.h5
outputs: outputs_{model_id}/{sample_id}_{tensor_id}.h5
overwrite: true
blockwise: true
stats: inputs/dataset_statistics.json
```
## 🐍 Use in Python

`bioimageio.core` is a Python package that implements prediction with bioimageio models, including standardized pre- and postprocessing operations. These models are described by, and can be loaded with, the bioimageio.spec package.

In addition, bioimageio.core provides functionality to convert model weight formats.

To get an overview of this functionality, check out the example notebooks and the developer documentation.
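As a minimal sketch of the Python API (assuming `bioimageio.core` and a suitable DL framework are installed; 'affable-shark' is the example resource id used in the CLI help above, and `check_model` is a hypothetical helper for this illustration):

```python
# Sketch only: wraps the documented entry points `test_model` and
# `load_description`. The imports are done lazily so the snippet parses
# even without bioimageio.core installed; `check_model` is hypothetical.
def check_model(model_id: str = "affable-shark"):
    from bioimageio.core import load_description, test_model

    summary = test_model(model_id)  # runs the model's bundled test inputs
    print(summary.status)
    return load_description(model_id)  # load the full model description
```

Calling `check_model()` downloads the resource, so run it in an environment with network access and the relevant DL framework installed.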
### Logging level

`bioimageio.spec` and `bioimageio.core` use loguru for logging, hence the logging level may be controlled with the `LOGURU_LEVEL` environment variable.
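For example, to see debug messages, set the variable before loguru is imported, either in the shell (`LOGURU_LEVEL=DEBUG bioimageio test ...`) or at the very top of a script:

```python
import os

# Must be set before loguru (and hence bioimageio.*) is imported,
# because loguru reads LOGURU_LEVEL at import time.
os.environ["LOGURU_LEVEL"] = "DEBUG"
```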
## Model Specification

The model specification and its validation tools can be found at https://github.com/bioimage-io/spec-bioimage-io.
## Changelog

### 0.7.0

- breaking:
  - bioimageio CLI now has implicit boolean flags
- non-breaking:
  - use new `ValidationDetail.recommended_env` in `ValidationSummary`
  - improve `get_io_sample_block_metas()`
    - now works for sufficiently large, but not exactly shaped inputs
  - update to support `zipfile.ZipFile` object with bioimageio.spec==0.5.3.5
  - add io helpers `resolve` and `resolve_and_extract`
  - added `enable_determinism` function and `determinism` input argument for testing with seeded random generators and optionally (`determinism=="full"`) instructing DL frameworks to use deterministic algorithms
  - use new

### 0.6.10

- fix #423

### 0.6.9

- improve bioimageio command line interface (details in #157)
  - add `predict` command
  - package command input `path` is now required
- add

### 0.6.8

- testing model inference will now check all weight formats (previously only the first one for which model adapter creation succeeded had been checked)
- fix predict with blocking (Thanks @thodkatz)

### 0.6.7

- `predict()` argument `inputs` may be a `Sample`

### 0.6.6

- add aliases to match previous API more closely

### 0.6.5

- improve adapter error messages

### 0.6.4

- add `bioimageio validate-format` command
- improve error messages and display of command results

### 0.6.3

- Fix #386
- (in model inference testing) stop assuming model inputs are tileable

### 0.6.2

- Fix #384

### 0.6.1

### 0.6.0

- add compatibility with new bioimageio.spec 0.5 (0.5.2post1)
- improve interfaces

### 0.5.10
```python
"""
.. include:: ../../README.md
"""

from bioimageio.spec import (
    build_description,
    dump_description,
    load_dataset_description,
    load_description,
    load_description_and_validate_format_only,
    load_model_description,
    save_bioimageio_package,
    save_bioimageio_package_as_folder,
    save_bioimageio_yaml_only,
    validate_format,
)

from . import (
    axis,
    block_meta,
    cli,
    commands,
    common,
    digest_spec,
    io,
    model_adapters,
    prediction,
    proc_ops,
    proc_setup,
    sample,
    stat_calculators,
    stat_measures,
    tensor,
)
from ._prediction_pipeline import PredictionPipeline, create_prediction_pipeline
from ._resource_tests import (
    enable_determinism,
    load_description_and_test,
    test_description,
    test_model,
)
from ._settings import settings
from .axis import Axis, AxisId
from .block_meta import BlockMeta
from .common import MemberId
from .prediction import predict, predict_many
from .sample import Sample
from .stat_calculators import compute_dataset_measures
from .stat_measures import Stat
from .tensor import Tensor
from .utils import VERSION

__version__ = VERSION


# aliases
test_resource = test_description
"""alias of `test_description`"""
load_resource = load_description
"""alias of `load_description`"""
load_model = load_model_description
"""alias of `load_model_description`"""

__all__ = [
    "__version__",
    "axis",
    "Axis",
    "AxisId",
    "block_meta",
    "BlockMeta",
    "build_description",
    "cli",
    "commands",
    "common",
    "compute_dataset_measures",
    "create_prediction_pipeline",
    "digest_spec",
    "dump_description",
    "enable_determinism",
    "io",
    "load_dataset_description",
    "load_description_and_test",
    "load_description_and_validate_format_only",
    "load_description",
    "load_model_description",
    "load_model",
    "load_resource",
    "MemberId",
    "model_adapters",
    "predict_many",
    "predict",
    "prediction",
    "PredictionPipeline",
    "proc_ops",
    "proc_setup",
    "sample",
    "Sample",
    "save_bioimageio_package_as_folder",
    "save_bioimageio_package",
    "save_bioimageio_yaml_only",
    "settings",
    "stat_calculators",
    "stat_measures",
    "Stat",
    "tensor",
    "Tensor",
    "test_description",
    "test_model",
    "test_resource",
    "validate_format",
]
```
```python
@dataclass
class Axis:
    id: AxisId
    type: Literal["batch", "channel", "index", "space", "time"]

    @classmethod
    def create(cls, axis: AxisLike) -> Axis:
        if isinstance(axis, cls):
            return axis
        elif isinstance(axis, Axis):
            return Axis(id=axis.id, type=axis.type)
        elif isinstance(axis, str):
            return Axis(id=AxisId(axis), type=_get_axis_type(axis))
        elif isinstance(axis, v0_5.AxisBase):
            return Axis(id=AxisId(axis.id), type=axis.type)
        else:
            assert_never(axis)
```
```python
class AxisId(LowerCaseIdentifier):
    root_model: ClassVar[Type[RootModel[Any]]] = RootModel[
        Annotated[LowerCaseIdentifierAnno, MaxLen(16)]
    ]
```
````python
@dataclass(frozen=True)
class BlockMeta:
    """Block meta data of a sample member (a tensor in a sample)

    Figure for illustration:
    The first 2d block (dashed) of a sample member (**bold**).
    The inner slice (thin) is expanded by a halo in both dimensions on both sides.
    The outer slice reaches from the sample member origin (0, 0) to the right halo point.

    ```terminal
    ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┐
    ╷ halo(left)                         ╷
    ╷                                    ╷
    ╷  (0, 0)┏━━━━━━━━━━━━━━━━━┯━━━━━━━━━┯━━━➔
    ╷        ┃                 │         ╷  sample member
    ╷        ┃      inner      │         ╷
    ╷        ┃   (and outer)   │  outer  ╷
    ╷        ┃      slice      │  slice  ╷
    ╷        ┃                 │         ╷
    ╷        ┣─────────────────┘         ╷
    ╷        ┃   outer slice             ╷
    ╷        ┃              halo(right)  ╷
    └ ─ ─ ─ ─┃─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─┘
             ⬇
    ```

    note:
    - Inner and outer slices are specified in sample member coordinates.
    - The outer_slice of a block at the sample edge may overlap by more than the
      halo with the neighboring block (the inner slices will not overlap though).

    """

    sample_shape: PerAxis[int]
    """the axis sizes of the whole (unblocked) sample"""

    inner_slice: PerAxis[SliceInfo]
    """inner region (without halo) wrt the sample"""

    halo: PerAxis[Halo]
    """halo enlarging the inner region to the block's sizes"""

    block_index: BlockIndex
    """the i-th block of the sample"""

    blocks_in_sample: TotalNumberOfBlocks
    """total number of blocks in the sample"""

    @cached_property
    def shape(self) -> PerAxis[int]:
        """axis lengths of the block"""
        return Frozen(
            {
                a: s.stop - s.start + (sum(self.halo[a]) if a in self.halo else 0)
                for a, s in self.inner_slice.items()
            }
        )

    @cached_property
    def padding(self) -> PerAxis[PadWidth]:
        """padding to realize the halo at the sample edge
        where we cannot simply enlarge the inner slice"""
        return Frozen(
            {
                a: PadWidth(
                    (
                        self.halo[a].left
                        - (self.inner_slice[a].start - self.outer_slice[a].start)
                        if a in self.halo
                        else 0
                    ),
                    (
                        self.halo[a].right
                        - (self.outer_slice[a].stop - self.inner_slice[a].stop)
                        if a in self.halo
                        else 0
                    ),
                )
                for a in self.inner_slice
            }
        )

    @cached_property
    def outer_slice(self) -> PerAxis[SliceInfo]:
        """slice of the outer block (without padding) wrt the sample"""
        return Frozen(
            {
                a: SliceInfo(
                    max(
                        0,
                        min(
                            self.inner_slice[a].start
                            - (self.halo[a].left if a in self.halo else 0),
                            self.sample_shape[a]
                            - self.inner_shape[a]
                            - (self.halo[a].left if a in self.halo else 0),
                        ),
                    ),
                    min(
                        self.sample_shape[a],
                        self.inner_slice[a].stop
                        + (self.halo[a].right if a in self.halo else 0),
                    ),
                )
                for a in self.inner_slice
            }
        )

    @cached_property
    def inner_shape(self) -> PerAxis[int]:
        """axis lengths of the inner region (without halo)"""
        return Frozen({a: s.stop - s.start for a, s in self.inner_slice.items()})

    @cached_property
    def local_slice(self) -> PerAxis[SliceInfo]:
        """inner slice wrt the block, **not** the sample"""
        return Frozen(
            {
                a: SliceInfo(
                    self.halo[a].left,
                    self.halo[a].left + self.inner_shape[a],
                )
                for a in self.inner_slice
            }
        )

    @property
    def dims(self) -> Collection[AxisId]:
        return set(self.inner_shape)

    @property
    def tagged_shape(self) -> PerAxis[int]:
        """alias for shape"""
        return self.shape

    @property
    def inner_slice_wo_overlap(self):
        """subslice of the inner slice, such that all `inner_slice_wo_overlap` can be
        stitched together trivially to form the original sample.

        This can also be used to calculate statistics
        without overrepresenting block edge regions."""
        # TODO: update inner_slice_wo_overlap when adding block overlap
        return self.inner_slice

    def __post_init__(self):
        # freeze mutable inputs
        if not isinstance(self.sample_shape, Frozen):
            object.__setattr__(self, "sample_shape", Frozen(self.sample_shape))

        if not isinstance(self.inner_slice, Frozen):
            object.__setattr__(self, "inner_slice", Frozen(self.inner_slice))

        if not isinstance(self.halo, Frozen):
            object.__setattr__(self, "halo", Frozen(self.halo))

        assert all(
            a in self.sample_shape for a in self.inner_slice
        ), "block has axes not present in sample"

        assert all(
            a in self.inner_slice for a in self.halo
        ), "halo has axes not present in block"

        if any(s > self.sample_shape[a] for a, s in self.shape.items()):
            logger.warning(
                "block {} larger than sample {}", self.shape, self.sample_shape
            )

    def get_transformed(
        self, new_axes: PerAxis[Union[LinearAxisTransform, int]]
    ) -> Self:
        return self.__class__(
            sample_shape={
                a: (
                    trf
                    if isinstance(trf, int)
                    else trf.compute(self.sample_shape[trf.axis])
                )
                for a, trf in new_axes.items()
            },
            inner_slice={
                a: (
                    SliceInfo(0, trf)
                    if isinstance(trf, int)
                    else SliceInfo(
                        trf.compute(self.inner_slice[trf.axis].start),
                        trf.compute(self.inner_slice[trf.axis].stop),
                    )
                )
                for a, trf in new_axes.items()
            },
            halo={
                a: (
                    Halo(0, 0)
                    if isinstance(trf, int)
                    else Halo(self.halo[trf.axis].left, self.halo[trf.axis].right)
                )
                for a, trf in new_axes.items()
            },
            block_index=self.block_index,
            blocks_in_sample=self.blocks_in_sample,
        )
````
Block meta data of a sample member (a tensor in a sample)
Figure for illustration: The first 2d block (dashed) of a sample member (bold). The inner slice (thin) is expanded by a halo in both dimensions on both sides. The outer slice reaches from the sample member origin (0, 0) to the right halo point.
┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┐
╷ halo(left) ╷
╷ ╷
╷ (0, 0)┏━━━━━━━━━━━━━━━━━┯━━━━━━━━━┯━━━➔
╷ ┃ │ ╷ sample member
╷ ┃ inner │ ╷
╷ ┃ (and outer) │ outer ╷
╷ ┃ slice │ slice ╷
╷ ┃ │ ╷
╷ ┣─────────────────┘ ╷
╷ ┃ outer slice ╷
╷ ┃ halo(right) ╷
└ ─ ─ ─ ─┃─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─┘
⬇
note:
- Inner and outer slices are specified in sample member coordinates.
- The outer_slice of a block at the sample edge may overlap by more than the halo with the neighboring block (the inner slices will not overlap though).
inner region (without halo) wrt the sample
halo enlarging the inner region to the block's sizes
axis lengths of the block
padding to realize the halo at the sample edge where we cannot simply enlarge the inner slice
slice of the outer block (without padding) wrt the sample
axis lengths of the inner region (without halo)
inner slice wrt the block, not the sample
alias for shape
subslice of the inner slice, such that all `inner_slice_wo_overlap` can be stitched together trivially to form the original sample.

This can also be used to calculate statistics without overrepresenting block edge regions.
```python
def build_description(
    content: BioimageioYamlContent,
    /,
    *,
    context: Optional[ValidationContext] = None,
    format_version: Union[FormatVersionPlaceholder, str] = DISCOVER,
) -> Union[ResourceDescr, InvalidDescr]:
    """build a bioimage.io resource description from an RDF's content.

    Use `load_description` if you want to build a resource description from an rdf.yaml
    or bioimage.io zip-package.

    Args:
        content: loaded rdf.yaml file (loaded with YAML, not bioimageio.spec)
        context: validation context to use during validation
        format_version: (optional) use this argument to load the resource and
            convert its metadata to a higher format_version

    Returns:
        An object holding all metadata of the bioimage.io resource

    """

    return build_description_impl(
        content,
        context=context,
        format_version=format_version,
        get_rd_class=_get_rd_class,
    )
```
build a bioimage.io resource description from an RDF's content.
Use `load_description` if you want to build a resource description from an rdf.yaml or bioimage.io zip-package.
Arguments:
- content: loaded rdf.yaml file (loaded with YAML, not bioimageio.spec)
- context: validation context to use during validation
- format_version: (optional) use this argument to load the resource and convert its metadata to a higher format_version
Returns:
An object holding all metadata of the bioimage.io resource
```python
def compute_dataset_measures(
    measures: Iterable[DatasetMeasure], dataset: Iterable[Sample]
) -> Dict[DatasetMeasure, MeasureValue]:
    """compute all dataset `measures` for the given `dataset`"""
    sample_calculators, calculators = get_measure_calculators(measures)
    assert not sample_calculators

    ret: Dict[DatasetMeasure, MeasureValue] = {}

    for sample in dataset:
        for calc in calculators:
            calc.update(sample)

    for calc in calculators:
        ret.update(calc.finalize().items())

    return ret
```
compute all dataset `measures` for the given `dataset`
```python
def create_prediction_pipeline(
    bioimageio_model: AnyModelDescr,
    *,
    devices: Optional[Sequence[str]] = None,
    weight_format: Optional[WeightsFormat] = None,
    weights_format: Optional[WeightsFormat] = None,
    dataset_for_initial_statistics: Iterable[Union[Sample, Sequence[Tensor]]] = tuple(),
    keep_updating_initial_dataset_statistics: bool = False,
    fixed_dataset_statistics: Mapping[DatasetMeasure, MeasureValue] = MappingProxyType(
        {}
    ),
    model_adapter: Optional[ModelAdapter] = None,
    ns: Union[
        v0_5.ParameterizedSize_N,
        Mapping[Tuple[MemberId, AxisId], v0_5.ParameterizedSize_N],
    ] = 10,
    **deprecated_kwargs: Any,
) -> PredictionPipeline:
    """
    Creates prediction pipeline which includes:
    * computation of input statistics
    * preprocessing
    * model prediction
    * computation of output statistics
    * postprocessing
    """
    weights_format = weight_format or weights_format
    del weight_format
    if deprecated_kwargs:
        warnings.warn(
            f"deprecated create_prediction_pipeline kwargs: {set(deprecated_kwargs)}"
        )

    model_adapter = model_adapter or create_model_adapter(
        model_description=bioimageio_model,
        devices=devices,
        weight_format_priority_order=weights_format and (weights_format,),
    )

    input_ids = get_member_ids(bioimageio_model.inputs)

    def dataset():
        common_stat: Stat = {}
        for i, x in enumerate(dataset_for_initial_statistics):
            if isinstance(x, Sample):
                yield x
            else:
                yield Sample(members=dict(zip(input_ids, x)), stat=common_stat, id=i)

    preprocessing, postprocessing = setup_pre_and_postprocessing(
        bioimageio_model,
        dataset(),
        keep_updating_initial_dataset_stats=keep_updating_initial_dataset_statistics,
        fixed_dataset_stats=fixed_dataset_statistics,
    )

    return PredictionPipeline(
        name=bioimageio_model.name,
        model_description=bioimageio_model,
        model_adapter=model_adapter,
        preprocessing=preprocessing,
        postprocessing=postprocessing,
        default_ns=ns,
    )
```
Creates prediction pipeline which includes:
- computation of input statistics
- preprocessing
- model prediction
- computation of output statistics
- postprocessing
```python
def dump_description(
    rd: Union[ResourceDescr, InvalidDescr], exclude_unset: bool = True
) -> BioimageioYamlContent:
    """Converts a resource to a dictionary containing only simple types that can directly be serialized to YAML."""
    return rd.model_dump(mode="json", exclude_unset=exclude_unset)
```
Converts a resource to a dictionary containing only simple types that can directly be serialized to YAML.
```python
def enable_determinism(mode: Literal["seed_only", "full"]):
    """Seed and configure ML frameworks for maximum reproducibility.
    May degrade performance. Only recommended for testing reproducibility!

    Seed any random generators and (if **mode**=="full") request ML frameworks to use
    deterministic algorithms.
    Notes:
        - **mode** == "full" might degrade performance and throw exceptions.
        - Subsequent inference calls might still differ. Call before each function
          (sequence) that is expected to be reproducible.
        - Degraded performance: Use for testing reproducibility only!
        - Recipes:
            - [PyTorch](https://pytorch.org/docs/stable/notes/randomness.html)
            - [Keras](https://keras.io/examples/keras_recipes/reproducibility_recipes/)
            - [NumPy](https://numpy.org/doc/2.0/reference/random/generated/numpy.random.seed.html)
    """
    try:
        try:
            import numpy.random
        except ImportError:
            pass
        else:
            numpy.random.seed(0)
    except Exception as e:
        logger.debug(str(e))

    try:
        try:
            import torch
        except ImportError:
            pass
        else:
            _ = torch.manual_seed(0)
            torch.use_deterministic_algorithms(mode == "full")
    except Exception as e:
        logger.debug(str(e))

    try:
        try:
            import keras
        except ImportError:
            pass
        else:
            keras.utils.set_random_seed(0)
    except Exception as e:
        logger.debug(str(e))

    try:
        try:
            import tensorflow as tf  # pyright: ignore[reportMissingImports]
        except ImportError:
            pass
        else:
            tf.random.seed(0)
            if mode == "full":
                tf.config.experimental.enable_op_determinism()
            # TODO: find possibility to switch it off again??
    except Exception as e:
        logger.debug(str(e))
```
Seed and configure ML frameworks for maximum reproducibility. May degrade performance; only recommended for testing reproducibility!
Seeds any random generators and (if mode == "full") requests that ML frameworks use deterministic algorithms.
Notes:
- mode == "full" might degrade performance and throw exceptions.
- Subsequent inference calls might still differ. Call before each function (sequence) that is expected to be reproducible.
- Degraded performance: use for testing reproducibility only!
- Recipes: PyTorch (https://pytorch.org/docs/stable/notes/randomness.html), Keras (https://keras.io/examples/keras_recipes/reproducibility_recipes/), NumPy (https://numpy.org/doc/2.0/reference/random/generated/numpy.random.seed.html)
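The guarded seeding pattern used by enable_determinism (seed each framework only if it is importable, and log rather than raise on failure) can be sketched with just the stdlib random module. The seed_stdlib_random helper below is an illustrative stand-in, not bioimageio.core API:

```python
import random


def seed_stdlib_random(seed: int = 0) -> None:
    """Mimic enable_determinism's guarded seeding for the stdlib
    random module (illustrative stand-in, not bioimageio.core API)."""
    try:
        random.seed(seed)
    except Exception as e:
        # enable_determinism only logs failures instead of raising,
        # so one broken framework does not abort seeding the others
        print(f"seeding failed: {e}")


seed_stdlib_random(0)
a = [random.random() for _ in range(3)]
seed_stdlib_random(0)
b = [random.random() for _ in range(3)]
assert a == b  # identical sequences after re-seeding
```

As the notes warn, seeding alone does not guarantee reproducibility across calls; re-seed before each sequence that must be reproducible.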
def load_dataset_description(
    source: Union[PermissiveFileSource, ZipFile],
    /,
    *,
    format_version: Union[Literal["discover"], Literal["latest"], str] = DISCOVER,
    perform_io_checks: bool = settings.perform_io_checks,
    known_files: Optional[Dict[str, Sha256]] = None,
) -> AnyDatasetDescr:
    """Same as `load_description`, but additionally ensures that the loaded
    description is valid and of type 'dataset'.
    """
    rd = load_description(
        source,
        format_version=format_version,
        perform_io_checks=perform_io_checks,
        known_files=known_files,
    )
    return ensure_description_is_dataset(rd)
Same as load_description, but additionally ensures that the loaded description is valid and of type 'dataset'.
def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Union[Literal["discover", "latest"], str] = "discover",
    weight_format: Optional[WeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    absolute_tolerance: float = 1.5e-4,
    relative_tolerance: float = 1e-4,
    decimal: Optional[int] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Optional[str] = None,
) -> Union[ResourceDescr, InvalidDescr]:
    """Test RDF dynamically, e.g. model inference of test inputs"""
    if (
        isinstance(source, ResourceDescrBase)
        and format_version != "discover"
        and source.format_version != format_version
    ):
        warnings.warn(
            f"deserializing source to ensure we validate and test using format {format_version}"
        )
        source = dump_description(source)

    if isinstance(source, ResourceDescrBase):
        rd = source
    elif isinstance(source, dict):
        rd = build_description(source, format_version=format_version)
    else:
        rd = load_description(source, format_version=format_version)

    rd.validation_summary.env.add(
        InstalledPackage(name="bioimageio.core", version=VERSION)
    )

    if expected_type is not None:
        _test_expected_resource_type(rd, expected_type)

    if isinstance(rd, (v0_4.ModelDescr, v0_5.ModelDescr)):
        if weight_format is None:
            weight_formats: List[WeightsFormat] = [
                w for w, we in rd.weights if we is not None
            ]  # pyright: ignore[reportAssignmentType]
        else:
            weight_formats = [weight_format]

        if decimal is None:
            atol = absolute_tolerance
            rtol = relative_tolerance
        else:
            warnings.warn(
                "The argument `decimal` has been deprecated in favour of"
                + " `relative_tolerance` and `absolute_tolerance`, with different"
                + " validation logic, using `numpy.testing.assert_allclose, see"
                + " 'https://numpy.org/doc/stable/reference/generated/"
                + " numpy.testing.assert_allclose.html'. Passing a value for `decimal`"
                + " will cause validation to revert to the old behaviour."
            )
            atol = 1.5 * 10 ** (-decimal)
            rtol = 0

        enable_determinism(determinism)
        for w in weight_formats:
            _test_model_inference(rd, w, devices, atol, rtol)
            if not isinstance(rd, v0_4.ModelDescr):
                _test_model_inference_parametrized(rd, w, devices)

    # TODO: add execution of jupyter notebooks
    # TODO: add more tests

    return rd
Test an RDF dynamically, e.g. run model inference on its test inputs.
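The fallback for the deprecated decimal argument shown in the source is plain arithmetic: it reproduces the tolerance that numpy.testing.assert_almost_equal uses for a given number of decimal places. A stand-alone check of that conversion (the helper name is ours, not bioimageio.core API):

```python
def decimal_to_tolerances(decimal: int) -> tuple:
    """Mirror load_description_and_test's fallback: a `decimal` value
    reverts validation to the old absolute-tolerance-only behaviour."""
    # same tolerance numpy.testing.assert_almost_equal applies for `decimal` places
    atol = 1.5 * 10 ** (-decimal)
    rtol = 0.0
    return atol, rtol


atol, rtol = decimal_to_tolerances(4)
assert abs(atol - 1.5e-4) < 1e-12  # matches the default absolute_tolerance
assert rtol == 0.0
```

Note that decimal=4 lands exactly on the function's default absolute_tolerance of 1.5e-4, so omitting decimal is the backwards-compatible choice.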
def load_description_and_validate_format_only(
    source: Union[PermissiveFileSource, ZipFile],
    /,
    *,
    format_version: Union[Literal["discover"], Literal["latest"], str] = DISCOVER,
    perform_io_checks: bool = settings.perform_io_checks,
    known_files: Optional[Dict[str, Sha256]] = None,
) -> ValidationSummary:
    """Load a bioimage.io resource description.

    Args:
        source: Path or URL to an rdf.yaml or a bioimage.io package
            (zip file with an rdf.yaml in it).
        format_version: (optional) Use this argument to load the resource and
            convert its metadata to a higher format_version.
        perform_io_checks: Whether or not to perform validation that requires file IO,
            e.g. downloading remote files. The existence of local
            absolute file paths is still being checked.
        known_files: Allows bypassing download and hashing of referenced files
            (even if perform_io_checks is True).

    Returns:
        Validation summary of the bioimage.io resource found at `source`.
    """
    rd = load_description(
        source,
        format_version=format_version,
        perform_io_checks=perform_io_checks,
        known_files=known_files,
    )
    assert rd.validation_summary is not None
    return rd.validation_summary
Load a bioimage.io resource description and validate its format only.
Arguments:
- source: Path or URL to an rdf.yaml or a bioimage.io package (zip file with an rdf.yaml in it).
- format_version: (optional) Use this argument to load the resource and convert its metadata to a higher format_version.
- perform_io_checks: Whether or not to perform validation that requires file IO, e.g. downloading remote files. The existence of local absolute file paths is still checked.
- known_files: Allows bypassing download and hashing of referenced files (even if perform_io_checks is True).
Returns:
Validation summary of the bioimage.io resource found at source.
def load_description(
    source: Union[PermissiveFileSource, ZipFile],
    /,
    *,
    format_version: Union[Literal["discover"], Literal["latest"], str] = DISCOVER,
    perform_io_checks: bool = settings.perform_io_checks,
    known_files: Optional[Dict[str, Sha256]] = None,
) -> Union[ResourceDescr, InvalidDescr]:
    """Load a bioimage.io resource description.

    Args:
        source: Path or URL to an rdf.yaml or a bioimage.io package
            (zip file with an rdf.yaml in it).
        format_version: (optional) Use this argument to load the resource and
            convert its metadata to a higher format_version.
        perform_io_checks: Whether or not to perform validation that requires file IO,
            e.g. downloading remote files. The existence of local
            absolute file paths is still being checked.
        known_files: Allows bypassing download and hashing of referenced files
            (even if perform_io_checks is True).

    Returns:
        An object holding all metadata of the bioimage.io resource
    """
    if isinstance(source, ResourceDescrBase):
        name = getattr(source, "name", f"{str(source)[:10]}...")
        logger.warning("returning already loaded description '{}' as is", name)
        return source  # pyright: ignore[reportReturnType]

    opened = open_bioimageio_yaml(source)

    context = validation_context_var.get().replace(
        root=opened.original_root,
        file_name=opened.original_file_name,
        perform_io_checks=perform_io_checks,
        known_files=known_files,
    )

    return build_description(
        opened.content,
        context=context,
        format_version=format_version,
    )
Load a bioimage.io resource description.
Arguments:
- source: Path or URL to an rdf.yaml or a bioimage.io package (zip file with an rdf.yaml in it).
- format_version: (optional) Use this argument to load the resource and convert its metadata to a higher format_version.
- perform_io_checks: Whether or not to perform validation that requires file IO, e.g. downloading remote files. The existence of local absolute file paths is still checked.
- known_files: Allows bypassing download and hashing of referenced files (even if perform_io_checks is True).
Returns:
An object holding all metadata of the bioimage.io resource.
def load_model_description(
    source: Union[PermissiveFileSource, ZipFile],
    /,
    *,
    format_version: Union[Literal["discover"], Literal["latest"], str] = DISCOVER,
    perform_io_checks: bool = settings.perform_io_checks,
    known_files: Optional[Dict[str, Sha256]] = None,
) -> AnyModelDescr:
    """Same as `load_description`, but additionally ensures that the loaded
    description is valid and of type 'model'.

    Raises:
        ValueError: for invalid or non-model resources
    """
    rd = load_description(
        source,
        format_version=format_version,
        perform_io_checks=perform_io_checks,
        known_files=known_files,
    )
    return ensure_description_is_model(rd)
Same as load_description, but additionally ensures that the loaded description is valid and of type 'model'.
Raises:
- ValueError: for invalid or non-model resources
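The load-then-narrow pattern above (load a generic description, then assert its resource type) can be sketched with plain stand-ins. The Descr class and ensure_is_model helper below are hypothetical illustrations, not bioimageio.core API:

```python
from dataclasses import dataclass


@dataclass
class Descr:
    """Hypothetical stand-in for a loaded resource description."""
    type: str
    name: str


def ensure_is_model(rd: Descr) -> Descr:
    # mirrors the role of ensure_description_is_model: narrow a generic
    # resource description to a model, raising ValueError otherwise
    if rd.type != "model":
        raise ValueError(f"expected a model description, got type '{rd.type}'")
    return rd


model = ensure_is_model(Descr(type="model", name="powerful-chipmunk"))
assert model.name == "powerful-chipmunk"

try:
    ensure_is_model(Descr(type="dataset", name="some-dataset"))
except ValueError as e:
    assert "dataset" in str(e)
```

The same pattern backs load_dataset_description, with 'dataset' in place of 'model'.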
alias of load_model_description
alias of load_description
def predict_many(
    *,
    model: Union[
        PermissiveFileSource, v0_4.ModelDescr, v0_5.ModelDescr, PredictionPipeline
    ],
    inputs: Iterable[PerMember[Union[Tensor, xr.DataArray, NDArray[Any], Path]]],
    sample_id: str = "sample{i:03}",
    blocksize_parameter: Optional[
        Union[
            v0_5.ParameterizedSize_N,
            Mapping[Tuple[MemberId, AxisId], v0_5.ParameterizedSize_N],
        ]
    ] = None,
    skip_preprocessing: bool = False,
    skip_postprocessing: bool = False,
    save_output_path: Optional[Union[Path, str]] = None,
) -> Iterator[Sample]:
    """Run prediction for multiple sets of inputs with a bioimage.io model

    Args:
        model: model to predict with.
            May be given as RDF source, model description or prediction pipeline.
        inputs: An iterable of the named input(s) for this model as a dictionary.
        sample_id: the sample id.
            note: `{i}` will be formatted as the i-th sample.
            If `{i}` (or `{i:`) is not present and `inputs` is an iterable `{i:03}` is appended.
        blocksize_parameter: (optional) tile the input into blocks parametrized by
            blocksize according to any parametrized axis sizes defined in the model RDF
        skip_preprocessing: flag to skip the model's preprocessing
        skip_postprocessing: flag to skip the model's postprocessing
        save_output_path: A path with `{member_id}` `{sample_id}` in it
            to save the output to.
    """
    if save_output_path is not None:
        if "{member_id}" not in str(save_output_path):
            raise ValueError(
                f"Missing `{{member_id}}` in save_output_path={save_output_path}"
            )

        if not isinstance(inputs, collections.abc.Mapping) and "{sample_id}" not in str(
            save_output_path
        ):
            raise ValueError(
                f"Missing `{{sample_id}}` in save_output_path={save_output_path}"
            )

    if isinstance(model, PredictionPipeline):
        pp = model
    else:
        if not isinstance(model, (v0_4.ModelDescr, v0_5.ModelDescr)):
            loaded = load_description(model)
            if not isinstance(loaded, (v0_4.ModelDescr, v0_5.ModelDescr)):
                raise ValueError(f"expected model description, but got {loaded}")
            model = loaded

        pp = create_prediction_pipeline(model)

    if not isinstance(inputs, collections.abc.Mapping):
        sample_id = str(sample_id)
        if "{i}" not in sample_id and "{i:" not in sample_id:
            sample_id += "{i:03}"

        total = len(inputs) if isinstance(inputs, collections.abc.Sized) else None

        for i, ipts in tqdm(enumerate(inputs), total=total):
            yield predict(
                model=pp,
                inputs=ipts,
                sample_id=sample_id.format(i=i),
                blocksize_parameter=blocksize_parameter,
                skip_preprocessing=skip_preprocessing,
                skip_postprocessing=skip_postprocessing,
                save_output_path=save_output_path,
            )
Run prediction for multiple sets of inputs with a bioimage.io model.
Arguments:
- model: model to predict with. May be given as RDF source, model description or prediction pipeline.
- inputs: An iterable of the named input(s) for this model as a dictionary.
- sample_id: the sample id. Note: {i} will be formatted as the i-th sample. If {i} (or {i:) is not present and inputs is an iterable, {i:03} is appended.
- blocksize_parameter: (optional) tile the input into blocks parametrized by blocksize according to any parametrized axis sizes defined in the model RDF.
- skip_preprocessing: flag to skip the model's preprocessing.
- skip_postprocessing: flag to skip the model's postprocessing.
- save_output_path: A path with {member_id} and {sample_id} in it to save the output to.
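The sample_id formatting rule described above is plain string logic and can be checked in isolation. The resolve_sample_id helper below mirrors the lines shown in the source; it is a stand-alone sketch, not bioimageio.core API:

```python
def resolve_sample_id(sample_id: str, i: int) -> str:
    """Mirror predict_many's handling of sample_id: append the default
    "{i:03}" pattern when no "{i}"/"{i:" placeholder is present."""
    sample_id = str(sample_id)
    if "{i}" not in sample_id and "{i:" not in sample_id:
        sample_id += "{i:03}"
    return sample_id.format(i=i)


assert resolve_sample_id("sample{i:03}", 7) == "sample007"  # default pattern
assert resolve_sample_id("my_sample", 7) == "my_sample007"  # pattern appended
assert resolve_sample_id("s{i}", 7) == "s7"                 # bare {i} kept as-is
```

So any custom sample_id without an {i} placeholder still yields unique, zero-padded per-sample ids.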
def predict(
    *,
    model: Union[
        PermissiveFileSource, v0_4.ModelDescr, v0_5.ModelDescr, PredictionPipeline
    ],
    inputs: Union[Sample, PerMember[Union[Tensor, xr.DataArray, NDArray[Any], Path]]],
    sample_id: Hashable = "sample",
    blocksize_parameter: Optional[
        Union[
            v0_5.ParameterizedSize_N,
            Mapping[Tuple[MemberId, AxisId], v0_5.ParameterizedSize_N],
        ]
    ] = None,
    input_block_shape: Optional[Mapping[MemberId, Mapping[AxisId, int]]] = None,
    skip_preprocessing: bool = False,
    skip_postprocessing: bool = False,
    save_output_path: Optional[Union[Path, str]] = None,
) -> Sample:
    """Run prediction for a single set of input(s) with a bioimage.io model

    Args:
        model: model to predict with.
            May be given as RDF source, model description or prediction pipeline.
        inputs: the input sample or the named input(s) for this model as a dictionary
        sample_id: the sample id.
        blocksize_parameter: (optional) tile the input into blocks parametrized by
            blocksize according to any parametrized axis sizes defined in the model RDF.
            Note: For a predetermined, fixed block shape use `input_block_shape`
        input_block_shape: (optional) tile the input sample tensors into blocks.
            Note: For a parameterized block shape, not dealing with the exact block shape,
            use `blocksize_parameter`.
        skip_preprocessing: flag to skip the model's preprocessing
        skip_postprocessing: flag to skip the model's postprocessing
        save_output_path: A path with `{member_id}` `{sample_id}` in it
            to save the output to.
    """
    if save_output_path is not None:
        if "{member_id}" not in str(save_output_path):
            raise ValueError(
                f"Missing `{{member_id}}` in save_output_path={save_output_path}"
            )

    if isinstance(model, PredictionPipeline):
        pp = model
    else:
        if not isinstance(model, (v0_4.ModelDescr, v0_5.ModelDescr)):
            loaded = load_description(model)
            if not isinstance(loaded, (v0_4.ModelDescr, v0_5.ModelDescr)):
                raise ValueError(f"expected model description, but got {loaded}")
            model = loaded

        pp = create_prediction_pipeline(model)

    if isinstance(inputs, Sample):
        sample = inputs
    else:
        sample = create_sample_for_model(
            pp.model_description, inputs=inputs, sample_id=sample_id
        )

    if input_block_shape is not None:
        if blocksize_parameter is not None:
            logger.warning(
                "ignoring blocksize_parameter={} in favor of input_block_shape={}",
                blocksize_parameter,
                input_block_shape,
            )

        output = pp.predict_sample_with_fixed_blocking(
            sample,
            input_block_shape=input_block_shape,
            skip_preprocessing=skip_preprocessing,
            skip_postprocessing=skip_postprocessing,
        )
    elif blocksize_parameter is not None:
        output = pp.predict_sample_with_blocking(
            sample,
            skip_preprocessing=skip_preprocessing,
            skip_postprocessing=skip_postprocessing,
            ns=blocksize_parameter,
        )
    else:
        output = pp.predict_sample_without_blocking(
            sample,
            skip_preprocessing=skip_preprocessing,
            skip_postprocessing=skip_postprocessing,
        )
    if save_output_path:
        save_sample(save_output_path, output)

    return output
Run prediction for a single set of input(s) with a bioimage.io model
Arguments:
- model: model to predict with. May be given as RDF source, model description or prediction pipeline.
- inputs: the input sample or the named input(s) for this model as a dictionary.
- sample_id: the sample id.
- blocksize_parameter: (optional) tile the input into blocks parametrized by blocksize according to any parametrized axis sizes defined in the model RDF. Note: for a predetermined, fixed block shape use input_block_shape instead.
- input_block_shape: (optional) tile the input sample tensors into blocks. Note: for a parameterized block shape, not dealing with the exact block shape, use blocksize_parameter instead.
- skip_preprocessing: flag to skip the model's preprocessing.
- skip_postprocessing: flag to skip the model's postprocessing.
- save_output_path: A path with {member_id} and {sample_id} in it to save the output to.
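The save_output_path placeholder check shown in the source is plain string logic and can be verified on its own. The check_save_output_path helper below mirrors it; it is a stand-alone sketch, not bioimageio.core API:

```python
def check_save_output_path(save_output_path: str) -> None:
    """Mirror predict's validation: a save path must contain a
    {member_id} placeholder so each output tensor gets its own file."""
    if "{member_id}" not in str(save_output_path):
        raise ValueError(
            f"Missing `{{member_id}}` in save_output_path={save_output_path}"
        )


check_save_output_path("out/{sample_id}_{member_id}.tif")  # accepted

try:
    check_save_output_path("out/result.tif")
except ValueError as e:
    assert "{member_id}" in str(e)
```

predict_many additionally requires a {sample_id} placeholder when inputs is an iterable, so that outputs of different samples do not overwrite each other.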
class PredictionPipeline:
    """
    Represents model computation including preprocessing and postprocessing
    Note: Ideally use the PredictionPipeline as a context manager
    """

    def __init__(
        self,
        *,
        name: str,
        model_description: AnyModelDescr,
        preprocessing: List[Processing],
        postprocessing: List[Processing],
        model_adapter: ModelAdapter,
        default_ns: Union[
            v0_5.ParameterizedSize_N,
            Mapping[Tuple[MemberId, AxisId], v0_5.ParameterizedSize_N],
        ] = 10,
        default_batch_size: int = 1,
    ) -> None:
        super().__init__()
        if model_description.run_mode:
            warnings.warn(
                f"Not yet implemented inference for run mode '{model_description.run_mode.name}'"
            )

        self.name = name
        self._preprocessing = preprocessing
        self._postprocessing = postprocessing

        self.model_description = model_description
        if isinstance(model_description, v0_4.ModelDescr):
            self._default_input_halo: PerMember[PerAxis[Halo]] = {}
            self._block_transform = None
        else:
            default_output_halo = {
                t.id: {
                    a.id: Halo(a.halo, a.halo)
                    for a in t.axes
                    if isinstance(a, v0_5.WithHalo)
                }
                for t in model_description.outputs
            }
            self._default_input_halo = get_input_halo(
                model_description, default_output_halo
            )
            self._block_transform = get_block_transform(model_description)

        self._default_ns = default_ns
        self._default_batch_size = default_batch_size

        self._input_ids = get_member_ids(model_description.inputs)
        self._output_ids = get_member_ids(model_description.outputs)

        self._adapter: ModelAdapter = model_adapter

    def __enter__(self):
        self.load()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):  # type: ignore
        self.unload()
        return False

    def predict_sample_block(
        self,
        sample_block: SampleBlockWithOrigin,
        skip_preprocessing: bool = False,
        skip_postprocessing: bool = False,
    ) -> SampleBlock:
        if isinstance(self.model_description, v0_4.ModelDescr):
            raise NotImplementedError(
                f"predict_sample_block not implemented for model {self.model_description.format_version}"
            )
        else:
            assert self._block_transform is not None

        if not skip_preprocessing:
            self.apply_preprocessing(sample_block)

        output_meta = sample_block.get_transformed_meta(self._block_transform)
        output = output_meta.with_data(
            {
                tid: out
                for tid, out in zip(
                    self._output_ids,
                    self._adapter.forward(
                        *(sample_block.members.get(t) for t in self._input_ids)
                    ),
                )
                if out is not None
            },
            stat=sample_block.stat,
        )
        if not skip_postprocessing:
            self.apply_postprocessing(output)

        return output

    def predict_sample_without_blocking(
        self,
        sample: Sample,
        skip_preprocessing: bool = False,
        skip_postprocessing: bool = False,
    ) -> Sample:
        """predict a sample.
        The sample's tensor shapes have to match the model's input tensor description.
        If that is not the case, consider `predict_sample_with_blocking`"""

        if not skip_preprocessing:
            self.apply_preprocessing(sample)

        output = Sample(
            members={
                out_id: out
                for out_id, out in zip(
                    self._output_ids,
                    self._adapter.forward(
                        *(sample.members.get(in_id) for in_id in self._input_ids)
                    ),
                )
                if out is not None
            },
            stat=sample.stat,
            id=sample.id,
        )
        if not skip_postprocessing:
            self.apply_postprocessing(output)

        return output

    def get_output_sample_id(self, input_sample_id: SampleId):
        warnings.warn(
            "`PredictionPipeline.get_output_sample_id()` is deprecated and will be"
            + " removed soon. Output sample id is equal to input sample id, hence this"
            + " function is not needed."
        )
        return input_sample_id

    def predict_sample_with_fixed_blocking(
        self,
        sample: Sample,
        input_block_shape: Mapping[MemberId, Mapping[AxisId, int]],
        *,
        skip_preprocessing: bool = False,
        skip_postprocessing: bool = False,
    ) -> Sample:
        if not skip_preprocessing:
            self.apply_preprocessing(sample)

        n_blocks, input_blocks = sample.split_into_blocks(
            input_block_shape,
            halo=self._default_input_halo,
            pad_mode="reflect",
        )
        input_blocks = list(input_blocks)
        predicted_blocks: List[SampleBlock] = []
        for b in tqdm(
            input_blocks,
            desc=f"predict sample {sample.id or ''} with {self.model_description.id or self.model_description.name}",
            unit="block",
            unit_divisor=1,
            total=n_blocks,
        ):
            predicted_blocks.append(
                self.predict_sample_block(
                    b, skip_preprocessing=True, skip_postprocessing=True
                )
            )

        predicted_sample = Sample.from_blocks(predicted_blocks)
        if not skip_postprocessing:
            self.apply_postprocessing(predicted_sample)

        return predicted_sample

    def predict_sample_with_blocking(
        self,
        sample: Sample,
        skip_preprocessing: bool = False,
        skip_postprocessing: bool = False,
        ns: Optional[
            Union[
                v0_5.ParameterizedSize_N,
                Mapping[Tuple[MemberId, AxisId], v0_5.ParameterizedSize_N],
            ]
        ] = None,
        batch_size: Optional[int] = None,
    ) -> Sample:
        """predict a sample by splitting it into blocks according to the model and the `ns` parameter"""

        if isinstance(self.model_description, v0_4.ModelDescr):
            raise NotImplementedError(
                "`predict_sample_with_blocking` not implemented for v0_4.ModelDescr"
                + f" {self.model_description.name}."
                + " Consider using `predict_sample_with_fixed_blocking`"
            )

        ns = ns or self._default_ns
        if isinstance(ns, int):
            ns = {
                (ipt.id, a.id): ns
                for ipt in self.model_description.inputs
                for a in ipt.axes
                if isinstance(a.size, v0_5.ParameterizedSize)
            }
        input_block_shape = self.model_description.get_tensor_sizes(
            ns, batch_size or self._default_batch_size
        ).inputs

        return self.predict_sample_with_fixed_blocking(
            sample,
            input_block_shape=input_block_shape,
            skip_preprocessing=skip_preprocessing,
            skip_postprocessing=skip_postprocessing,
        )

    # def predict(
    #     self,
    #     inputs: Predict_IO,
    #     skip_preprocessing: bool = False,
    #     skip_postprocessing: bool = False,
    # ) -> Predict_IO:
    #     """Run model prediction **including** pre/postprocessing."""

    #     if isinstance(inputs, Sample):
    #         return self.predict_sample_with_blocking(
    #             inputs,
    #             skip_preprocessing=skip_preprocessing,
    #             skip_postprocessing=skip_postprocessing,
    #         )
    #     elif isinstance(inputs, collections.abc.Iterable):
    #         return (
    #             self.predict(
    #                 ipt,
    #                 skip_preprocessing=skip_preprocessing,
    #                 skip_postprocessing=skip_postprocessing,
    #             )
    #             for ipt in inputs
    #         )
    #     else:
    #         assert_never(inputs)

    def apply_preprocessing(self, sample: Union[Sample, SampleBlockWithOrigin]) -> None:
        """apply preprocessing in-place, also updates sample stats"""
        for op in self._preprocessing:
            op(sample)

    def apply_postprocessing(
        self, sample: Union[Sample, SampleBlock, SampleBlockWithOrigin]
    ) -> None:
        """apply postprocessing in-place, also updates sample stats"""
        for op in self._postprocessing:
            if isinstance(sample, (Sample, SampleBlockWithOrigin)):
                op(sample)
            elif not isinstance(op, BlockedOperator):
                raise NotImplementedError(
                    "block wise update of output statistics not yet implemented"
                )
            else:
                op(sample)

    def load(self):
        """
        optional step: load model onto devices before calling forward if not using it as context manager
        """
        pass

    def unload(self):
        """
        free any device memory in use
        """
        self._adapter.unload()
Represents model computation including preprocessing and postprocessing.
Note: Ideally use the PredictionPipeline as a context manager.
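The context-manager usage the note recommends follows the standard __enter__/__exit__ protocol: load() on entry, unload() on exit. A minimal stand-in sketch of that shape (the MiniPipeline class is hypothetical, not bioimageio.core code):

```python
class MiniPipeline:
    """Hypothetical stand-in showing PredictionPipeline's context-manager shape."""

    def __init__(self) -> None:
        self.loaded = False

    def load(self) -> None:
        # corresponds to PredictionPipeline.load: optionally put the model on devices
        self.loaded = True

    def unload(self) -> None:
        # corresponds to PredictionPipeline.unload: free any device memory in use
        self.loaded = False

    def __enter__(self):
        self.load()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.unload()
        return False  # do not swallow exceptions


with MiniPipeline() as pp:
    assert pp.loaded      # resources held inside the with-block
assert not pp.loaded      # released on exit, even if an exception occurred
```

Using the real pipeline as a context manager guarantees unload() runs, so device memory is released even when prediction raises.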
    def __init__(
        self,
        *,
        name: str,
        model_description: AnyModelDescr,
        preprocessing: List[Processing],
        postprocessing: List[Processing],
        model_adapter: ModelAdapter,
        default_ns: Union[
            v0_5.ParameterizedSize_N,
            Mapping[Tuple[MemberId, AxisId], v0_5.ParameterizedSize_N],
        ] = 10,
        default_batch_size: int = 1,
    ) -> None:
        super().__init__()
        if model_description.run_mode:
            warnings.warn(
                f"Not yet implemented inference for run mode '{model_description.run_mode.name}'"
            )

        self.name = name
        self._preprocessing = preprocessing
        self._postprocessing = postprocessing

        self.model_description = model_description
        if isinstance(model_description, v0_4.ModelDescr):
            self._default_input_halo: PerMember[PerAxis[Halo]] = {}
            self._block_transform = None
        else:
            default_output_halo = {
                t.id: {
                    a.id: Halo(a.halo, a.halo)
                    for a in t.axes
                    if isinstance(a, v0_5.WithHalo)
                }
                for t in model_description.outputs
            }
            self._default_input_halo = get_input_halo(
                model_description, default_output_halo
            )
            self._block_transform = get_block_transform(model_description)

        self._default_ns = default_ns
        self._default_batch_size = default_batch_size

        self._input_ids = get_member_ids(model_description.inputs)
        self._output_ids = get_member_ids(model_description.outputs)

        self._adapter: ModelAdapter = model_adapter
    def predict_sample_block(
        self,
        sample_block: SampleBlockWithOrigin,
        skip_preprocessing: bool = False,
        skip_postprocessing: bool = False,
    ) -> SampleBlock:
        if isinstance(self.model_description, v0_4.ModelDescr):
            raise NotImplementedError(
                f"predict_sample_block not implemented for model {self.model_description.format_version}"
            )
        else:
            assert self._block_transform is not None

        if not skip_preprocessing:
            self.apply_preprocessing(sample_block)

        output_meta = sample_block.get_transformed_meta(self._block_transform)
        output = output_meta.with_data(
            {
                tid: out
                for tid, out in zip(
                    self._output_ids,
                    self._adapter.forward(
                        *(sample_block.members.get(t) for t in self._input_ids)
                    ),
                )
                if out is not None
            },
            stat=sample_block.stat,
        )
        if not skip_postprocessing:
            self.apply_postprocessing(output)

        return output
    def predict_sample_without_blocking(
        self,
        sample: Sample,
        skip_preprocessing: bool = False,
        skip_postprocessing: bool = False,
    ) -> Sample:
        """predict a sample.
        The sample's tensor shapes have to match the model's input tensor description.
        If that is not the case, consider `predict_sample_with_blocking`"""

        if not skip_preprocessing:
            self.apply_preprocessing(sample)

        output = Sample(
            members={
                out_id: out
                for out_id, out in zip(
                    self._output_ids,
                    self._adapter.forward(
                        *(sample.members.get(in_id) for in_id in self._input_ids)
                    ),
                )
                if out is not None
            },
            stat=sample.stat,
            id=sample.id,
        )
        if not skip_postprocessing:
            self.apply_postprocessing(output)

        return output
predict a sample.
The sample's tensor shapes have to match the model's input tensor description.
If that is not the case, consider predict_sample_with_blocking
    def get_output_sample_id(self, input_sample_id: SampleId):
        warnings.warn(
            "`PredictionPipeline.get_output_sample_id()` is deprecated and will be"
            + " removed soon. Output sample id is equal to input sample id, hence this"
            + " function is not needed."
        )
        return input_sample_id
    def predict_sample_with_fixed_blocking(
        self,
        sample: Sample,
        input_block_shape: Mapping[MemberId, Mapping[AxisId, int]],
        *,
        skip_preprocessing: bool = False,
        skip_postprocessing: bool = False,
    ) -> Sample:
        if not skip_preprocessing:
            self.apply_preprocessing(sample)

        n_blocks, input_blocks = sample.split_into_blocks(
            input_block_shape,
            halo=self._default_input_halo,
            pad_mode="reflect",
        )
        input_blocks = list(input_blocks)
        predicted_blocks: List[SampleBlock] = []
        for b in tqdm(
            input_blocks,
            desc=f"predict sample {sample.id or ''} with {self.model_description.id or self.model_description.name}",
            unit="block",
            unit_divisor=1,
            total=n_blocks,
        ):
            predicted_blocks.append(
                self.predict_sample_block(
                    b, skip_preprocessing=True, skip_postprocessing=True
                )
            )

        predicted_sample = Sample.from_blocks(predicted_blocks)
        if not skip_postprocessing:
            self.apply_postprocessing(predicted_sample)

        return predicted_sample
    def predict_sample_with_blocking(
        self,
        sample: Sample,
        skip_preprocessing: bool = False,
        skip_postprocessing: bool = False,
        ns: Optional[
            Union[
                v0_5.ParameterizedSize_N,
                Mapping[Tuple[MemberId, AxisId], v0_5.ParameterizedSize_N],
            ]
        ] = None,
        batch_size: Optional[int] = None,
    ) -> Sample:
        """predict a sample by splitting it into blocks according to the model and the `ns` parameter"""

        if isinstance(self.model_description, v0_4.ModelDescr):
            raise NotImplementedError(
                "`predict_sample_with_blocking` not implemented for v0_4.ModelDescr"
                + f" {self.model_description.name}."
                + " Consider using `predict_sample_with_fixed_blocking`"
            )

        ns = ns or self._default_ns
        if isinstance(ns, int):
            ns = {
                (ipt.id, a.id): ns
                for ipt in self.model_description.inputs
                for a in ipt.axes
                if isinstance(a.size, v0_5.ParameterizedSize)
            }
        input_block_shape = self.model_description.get_tensor_sizes(
            ns, batch_size or self._default_batch_size
        ).inputs

        return self.predict_sample_with_fixed_blocking(
            sample,
            input_block_shape=input_block_shape,
            skip_preprocessing=skip_preprocessing,
            skip_postprocessing=skip_postprocessing,
        )
```python
def apply_preprocessing(self, sample: Union[Sample, SampleBlockWithOrigin]) -> None:
    """apply preprocessing in-place, also updates sample stats"""
    for op in self._preprocessing:
        op(sample)
```
```python
def apply_postprocessing(
    self, sample: Union[Sample, SampleBlock, SampleBlockWithOrigin]
) -> None:
    """apply postprocessing in-place, also updates sample stats"""
    for op in self._postprocessing:
        if isinstance(sample, (Sample, SampleBlockWithOrigin)):
            op(sample)
        elif not isinstance(op, BlockedOperator):
            raise NotImplementedError(
                "block wise update of output statistics not yet implemented"
            )
        else:
            op(sample)
```
```python
@dataclass
class Sample:
    """A dataset sample"""

    members: Dict[MemberId, Tensor]
    """the sample's tensors"""

    stat: Stat
    """sample and dataset statistics"""

    id: SampleId
    """identifier within the sample's dataset"""

    @property
    def shape(self) -> PerMember[PerAxis[int]]:
        return {tid: t.sizes for tid, t in self.members.items()}

    def split_into_blocks(
        self,
        block_shapes: PerMember[PerAxis[int]],
        halo: PerMember[PerAxis[HaloLike]],
        pad_mode: PadMode,
        broadcast: bool = False,
    ) -> Tuple[TotalNumberOfBlocks, Iterable[SampleBlockWithOrigin]]:
        assert not (
            missing := [m for m in block_shapes if m not in self.members]
        ), f"`block_shapes` specified for unknown members: {missing}"
        assert not (
            missing := [m for m in halo if m not in block_shapes]
        ), f"`halo` specified for members without `block_shape`: {missing}"

        n_blocks, blocks = split_multiple_shapes_into_blocks(
            shapes=self.shape,
            block_shapes=block_shapes,
            halo=halo,
            broadcast=broadcast,
        )
        return n_blocks, sample_block_generator(blocks, origin=self, pad_mode=pad_mode)

    def as_single_block(self, halo: Optional[PerMember[PerAxis[Halo]]] = None):
        if halo is None:
            halo = {}
        return SampleBlockWithOrigin(
            sample_shape=self.shape,
            sample_id=self.id,
            blocks={
                m: Block(
                    sample_shape=self.shape[m],
                    data=data,
                    inner_slice={
                        a: SliceInfo(0, s) for a, s in data.tagged_shape.items()
                    },
                    halo=halo.get(m, {}),
                    block_index=0,
                    blocks_in_sample=1,
                )
                for m, data in self.members.items()
            },
            stat=self.stat,
            origin=self,
            block_index=0,
            blocks_in_sample=1,
        )

    @classmethod
    def from_blocks(
        cls,
        sample_blocks: Iterable[SampleBlock],
        *,
        fill_value: float = float("nan"),
    ) -> Self:
        members: PerMember[Tensor] = {}
        stat: Stat = {}
        sample_id = None
        for sample_block in sample_blocks:
            assert sample_id is None or sample_id == sample_block.sample_id
            sample_id = sample_block.sample_id
            stat = sample_block.stat
            for m, block in sample_block.blocks.items():
                if m not in members:
                    if -1 in block.sample_shape.values():
                        raise NotImplementedError(
                            "merging blocks with data dependent axis not yet implemented"
                        )

                    members[m] = Tensor(
                        np.full(
                            tuple(block.sample_shape[a] for a in block.data.dims),
                            fill_value,
                            dtype=block.data.dtype,
                        ),
                        dims=block.data.dims,
                    )

                members[m][block.inner_slice] = block.inner_data

        return cls(members=members, stat=stat, id=sample_id)
```
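The `from_blocks` reassembly strategy — allocate each member once, filled with `fill_value`, then paste every block's inner (halo-free) region — can be reduced to a few lines of plain numpy. This is a simplified sketch, not the `Sample` API:

```python
import numpy as np

def merge_blocks(sample_shape, blocks, fill_value=float("nan")):
    """Reassemble a full array from (inner_slice, inner_data) pairs,
    mirroring the core of `Sample.from_blocks`: allocate once with
    `fill_value`, then write each block's inner region into place.
    """
    out = np.full(sample_shape, fill_value)
    for inner_slice, inner_data in blocks:
        out[inner_slice] = inner_data
    return out

merged = merge_blocks(
    (4,),
    [
        ((slice(0, 2),), np.array([1.0, 2.0])),
        ((slice(2, 4),), np.array([3.0, 4.0])),
    ],
)
```

Because the inner slices tile the sample exactly, no `fill_value` remains after a complete merge; a leftover NaN would indicate a block bookkeeping bug.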
```python
def save_bioimageio_package_as_folder(
    source: Union[BioimageioYamlSource, ResourceDescr],
    /,
    *,
    output_path: Union[NewPath, DirectoryPath, None] = None,
    weights_priority_order: Optional[  # model only
        Sequence[
            Literal[
                "keras_hdf5",
                "onnx",
                "pytorch_state_dict",
                "tensorflow_js",
                "tensorflow_saved_model_bundle",
                "torchscript",
            ]
        ]
    ] = None,
) -> DirectoryPath:
    """Write the content of a bioimage.io resource package to a folder.

    Args:
        source: bioimageio resource description
        output_path: file path to write package to
        weights_priority_order: If given, only the first weights format present in the model is included.
            If none of the prioritized weights formats is found, all are included.

    Returns:
        directory path to bioimageio package folder
    """
    package_content = _prepare_resource_package(
        source,
        weights_priority_order=weights_priority_order,
    )
    if output_path is None:
        output_path = Path(mkdtemp())
    else:
        output_path = Path(output_path)

    output_path.mkdir(exist_ok=True, parents=True)
    for name, src in package_content.items():
        if isinstance(src, collections.abc.Mapping):
            write_yaml(cast(YamlValue, src), output_path / name)
        elif isinstance(src, ZipPath):
            extracted = Path(src.root.extract(src.name, output_path))
            if extracted.name != src.name:
                try:
                    shutil.move(str(extracted), output_path / src.name)
                except Exception as e:
                    raise RuntimeError(
                        f"Failed to rename extracted file '{extracted.name}'"
                        + f" to '{src.name}'."
                        + f" (extracted from '{src.name}' in '{src.root.filename}')"
                    ) from e
        else:
            shutil.copy(src, output_path / name)

    return output_path
```
```python
def save_bioimageio_package(
    source: Union[BioimageioYamlSource, ResourceDescr],
    /,
    *,
    compression: int = ZIP_DEFLATED,
    compression_level: int = 1,
    output_path: Union[NewPath, FilePath, None] = None,
    weights_priority_order: Optional[  # model only
        Sequence[
            Literal[
                "keras_hdf5",
                "onnx",
                "pytorch_state_dict",
                "tensorflow_js",
                "tensorflow_saved_model_bundle",
                "torchscript",
            ]
        ]
    ] = None,
) -> FilePath:
    """Package a bioimageio resource as a zip file.

    Args:
        source: bioimageio resource description
        compression: The numeric constant of compression method.
        compression_level: Compression level to use when writing files to the archive.
            See https://docs.python.org/3/library/zipfile.html#zipfile.ZipFile
        output_path: file path to write package to
        weights_priority_order: If given, only the first weights format present in the model is included.
            If none of the prioritized weights formats is found, all are included.

    Returns:
        path to zipped bioimageio package
    """
    package_content = _prepare_resource_package(
        source,
        weights_priority_order=weights_priority_order,
    )
    if output_path is None:
        output_path = Path(
            NamedTemporaryFile(suffix=".bioimageio.zip", delete=False).name
        )
    else:
        output_path = Path(output_path)

    write_zip(
        output_path,
        package_content,
        compression=compression,
        compression_level=compression_level,
    )
    with validation_context_var.get().replace(warning_level=ERROR):
        if isinstance((exported := load_description(output_path)), InvalidDescr):
            raise ValueError(
                f"Exported package '{output_path}' is invalid:"
                + f" {exported.validation_summary}"
            )

    return output_path
```
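At its core, packaging writes a mapping of archive names to file content into a zip with the chosen compression settings. A stdlib-only sketch of that step (`write_package_zip` is hypothetical and assumes the YAML content was already serialized by the caller; the real `write_zip` also handles file and zip-member sources):

```python
import zipfile
from pathlib import Path
from tempfile import mkdtemp

def write_package_zip(path, content, compression=zipfile.ZIP_DEFLATED, level=1):
    """Write a {archive_name: text} mapping to a zip file."""
    with zipfile.ZipFile(
        path, "w", compression=compression, compresslevel=level
    ) as zf:
        for name, text in content.items():
            zf.writestr(name, text)
    return Path(path)

out = write_package_zip(
    Path(mkdtemp()) / "model.bioimageio.zip",
    {"rdf.yaml": "format_version: 0.5.3\n"},
)
```

Note that `save_bioimageio_package` additionally reloads the written zip and raises if the exported description is invalid, so a returned path is guaranteed to point at a loadable package.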
```python
def save_bioimageio_yaml_only(
    rd: Union[ResourceDescr, BioimageioYamlContent, InvalidDescr],
    /,
    file: Union[NewPath, FilePath, TextIO],
):
    """write the metadata of a resource description (`rd`) to `file`
    without writing any of the referenced files in it.

    Note: To save a resource description with its associated files as a package,
    use `save_bioimageio_package` or `save_bioimageio_package_as_folder`.
    """
    if isinstance(rd, ResourceDescrBase):
        content = dump_description(rd)
    else:
        content = rd

    write_yaml(cast(YamlValue, content), file)
```
```python
class Tensor(MagicTensorOpsMixin):
    """A wrapper around an xr.DataArray for better integration with bioimageio.spec
    and improved type annotations."""

    _Compatible = Union["Tensor", xr.DataArray, _ScalarOrArray]

    def __init__(
        self,
        array: NDArray[Any],
        dims: Sequence[Union[AxisId, AxisLike]],
    ) -> None:
        super().__init__()
        axes = tuple(
            a if isinstance(a, AxisId) else AxisInfo.create(a).id for a in dims
        )
        self._data = xr.DataArray(array, dims=axes)

    def __array__(self, dtype: DTypeLike = None):
        return np.asarray(self._data, dtype=dtype)

    def __getitem__(
        self, key: Union[SliceInfo, slice, int, PerAxis[Union[SliceInfo, slice, int]]]
    ) -> Self:
        if isinstance(key, SliceInfo):
            key = slice(*key)
        elif isinstance(key, collections.abc.Mapping):
            key = {
                a: s if isinstance(s, int) else s if isinstance(s, slice) else slice(*s)
                for a, s in key.items()
            }
        return self.__class__.from_xarray(self._data[key])

    def __setitem__(self, key: PerAxis[Union[SliceInfo, slice]], value: Tensor) -> None:
        key = {a: s if isinstance(s, slice) else slice(*s) for a, s in key.items()}
        self._data[key] = value._data

    def __len__(self) -> int:
        return len(self.data)

    def _iter(self: Any) -> Iterator[Any]:
        for n in range(len(self)):
            yield self[n]

    def __iter__(self: Any) -> Iterator[Any]:
        if self.ndim == 0:
            raise TypeError("iteration over a 0-d array")
        return self._iter()

    def _binary_op(
        self,
        other: _Compatible,
        f: Callable[[Any, Any], Any],
        reflexive: bool = False,
    ) -> Self:
        data = self._data._binary_op(  # pyright: ignore[reportPrivateUsage]
            (other._data if isinstance(other, Tensor) else other),
            f,
            reflexive,
        )
        return self.__class__.from_xarray(data)

    def _inplace_binary_op(
        self,
        other: _Compatible,
        f: Callable[[Any, Any], Any],
    ) -> Self:
        _ = self._data._inplace_binary_op(  # pyright: ignore[reportPrivateUsage]
            (
                other_d
                # default of None avoids an AttributeError for operands
                # (e.g. plain scalars) without a `data` attribute
                if (other_d := getattr(other, "data", None)) is not None
                and isinstance(
                    other_d,
                    xr.DataArray,
                )
                else other
            ),
            f,
        )
        return self

    def _unary_op(self, f: Callable[[Any], Any], *args: Any, **kwargs: Any) -> Self:
        data = self._data._unary_op(  # pyright: ignore[reportPrivateUsage]
            f, *args, **kwargs
        )
        return self.__class__.from_xarray(data)

    @classmethod
    def from_xarray(cls, data_array: xr.DataArray) -> Self:
        """create a `Tensor` from an xarray data array

        note for internal use: this factory method is round-trip safe
        for any `Tensor`'s `data` property (an xarray.DataArray).
        """
        return cls(
            array=data_array.data, dims=tuple(AxisId(d) for d in data_array.dims)
        )

    @classmethod
    def from_numpy(
        cls,
        array: NDArray[Any],
        *,
        dims: Optional[Union[AxisLike, Sequence[AxisLike]]],
    ) -> Tensor:
        """create a `Tensor` from a numpy array

        Args:
            array: the nd numpy array
            dims: A description of the array's axes;
                if None, axes are guessed (which might fail and raise a ValueError).

        Raises:
            ValueError: if `dims` is None and axes guessing fails.
        """

        if dims is None:
            return cls._interprete_array_wo_known_axes(array)
        elif isinstance(dims, (str, Axis, v0_5.AxisBase)):
            dims = [dims]

        axis_infos = [AxisInfo.create(a) for a in dims]
        original_shape = tuple(array.shape)

        successful_view = _get_array_view(array, axis_infos)
        if successful_view is None:
            raise ValueError(
                f"Array shape {original_shape} does not map to axes {dims}"
            )

        return Tensor(successful_view, dims=tuple(a.id for a in axis_infos))

    @property
    def data(self):
        return self._data

    @property
    def dims(self):  # TODO: rename to `axes`?
        """Tuple of dimension names associated with this tensor."""
        return cast(Tuple[AxisId, ...], self._data.dims)

    @property
    def tagged_shape(self):
        """(alias for `sizes`) Ordered, immutable mapping from axis ids to lengths."""
        return self.sizes

    @property
    def shape_tuple(self):
        """Tuple of tensor axes lengths"""
        return self._data.shape

    @property
    def size(self):
        """Number of elements in the tensor.

        Equal to math.prod(tensor.shape), i.e., the product of the tensor's dimensions.
        """
        return self._data.size

    def sum(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
        """Reduce this Tensor's data by applying sum along some dimension(s)."""
        return self.__class__.from_xarray(self._data.sum(dim=dim))

    @property
    def ndim(self):
        """Number of tensor dimensions."""
        return self._data.ndim

    @property
    def dtype(self) -> DTypeStr:
        dt = str(self.data.dtype)  # pyright: ignore[reportUnknownArgumentType]
        assert dt in get_args(DTypeStr)
        return dt  # pyright: ignore[reportReturnType]

    @property
    def sizes(self):
        """Ordered, immutable mapping from axis ids to axis lengths."""
        return cast(Mapping[AxisId, int], self.data.sizes)

    def astype(self, dtype: DTypeStr, *, copy: bool = False):
        """Return tensor cast to `dtype`

        note: if dtype is already satisfied, copy only if `copy`"""
        return self.__class__.from_xarray(self._data.astype(dtype, copy=copy))

    def clip(self, min: Optional[float] = None, max: Optional[float] = None):
        """Return a tensor whose values are limited to [min, max].
        At least one of max or min must be given."""
        return self.__class__.from_xarray(self._data.clip(min, max))

    def crop_to(
        self,
        sizes: PerAxis[int],
        crop_where: Union[
            CropWhere,
            PerAxis[CropWhere],
        ] = "left_and_right",
    ) -> Self:
        """crop to match `sizes`"""
        if isinstance(crop_where, str):
            crop_axis_where: PerAxis[CropWhere] = {a: crop_where for a in self.dims}
        else:
            crop_axis_where = crop_where

        slices: Dict[AxisId, SliceInfo] = {}

        for a, s_is in self.sizes.items():
            if a not in sizes or sizes[a] == s_is:
                pass
            elif sizes[a] > s_is:
                logger.warning(
                    "Cannot crop axis {} of size {} to larger size {}",
                    a,
                    s_is,
                    sizes[a],
                )
            elif a not in crop_axis_where:
                raise ValueError(
                    f"Don't know where to crop axis {a}, `crop_where`={crop_where}"
                )
            else:
                crop_this_axis_where = crop_axis_where[a]
                if crop_this_axis_where == "left":
                    slices[a] = SliceInfo(s_is - sizes[a], s_is)
                elif crop_this_axis_where == "right":
                    slices[a] = SliceInfo(0, sizes[a])
                elif crop_this_axis_where == "left_and_right":
                    slices[a] = SliceInfo(
                        start := (s_is - sizes[a]) // 2, sizes[a] + start
                    )
                else:
                    assert_never(crop_this_axis_where)

        return self[slices]

    def expand_dims(self, dims: Union[Sequence[AxisId], PerAxis[int]]) -> Self:
        return self.__class__.from_xarray(self._data.expand_dims(dims=dims))

    def mean(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
        return self.__class__.from_xarray(self._data.mean(dim=dim))

    def std(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
        return self.__class__.from_xarray(self._data.std(dim=dim))

    def var(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self:
        return self.__class__.from_xarray(self._data.var(dim=dim))

    def pad(
        self,
        pad_width: PerAxis[PadWidthLike],
        mode: PadMode = "symmetric",
    ) -> Self:
        pad_width = {a: PadWidth.create(p) for a, p in pad_width.items()}
        return self.__class__.from_xarray(
            self._data.pad(pad_width=pad_width, mode=mode)
        )

    def pad_to(
        self,
        sizes: PerAxis[int],
        pad_where: Union[PadWhere, PerAxis[PadWhere]] = "left_and_right",
        mode: PadMode = "symmetric",
    ) -> Self:
        """pad `tensor` to match `sizes`"""
        if isinstance(pad_where, str):
            pad_axis_where: PerAxis[PadWhere] = {a: pad_where for a in self.dims}
        else:
            pad_axis_where = pad_where

        pad_width: Dict[AxisId, PadWidth] = {}
        for a, s_is in self.sizes.items():
            if a not in sizes or sizes[a] == s_is:
                pad_width[a] = PadWidth(0, 0)
            elif s_is > sizes[a]:
                pad_width[a] = PadWidth(0, 0)
                logger.warning(
                    "Cannot pad axis {} of size {} to smaller size {}",
                    a,
                    s_is,
                    sizes[a],
                )
            elif a not in pad_axis_where:
                raise ValueError(
                    f"Don't know where to pad axis {a}, `pad_where`={pad_where}"
                )
            else:
                pad_this_axis_where = pad_axis_where[a]
                d = sizes[a] - s_is
                if pad_this_axis_where == "left":
                    pad_width[a] = PadWidth(d, 0)
                elif pad_this_axis_where == "right":
                    pad_width[a] = PadWidth(0, d)
                elif pad_this_axis_where == "left_and_right":
                    pad_width[a] = PadWidth(left := d // 2, d - left)
                else:
                    assert_never(pad_this_axis_where)

        return self.pad(pad_width, mode)

    def quantile(
        self,
        q: Union[float, Sequence[float]],
        dim: Optional[Union[AxisId, Sequence[AxisId]]] = None,
    ) -> Self:
        assert (
            isinstance(q, (float, int))
            and q >= 0.0
            or not isinstance(q, (float, int))
            and all(qq >= 0.0 for qq in q)
        )
        assert (
            isinstance(q, (float, int))
            and q <= 1.0
            or not isinstance(q, (float, int))
            and all(qq <= 1.0 for qq in q)
        )
        assert dim is None or (
            (quantile_dim := AxisId("quantile")) != dim and quantile_dim not in set(dim)
        )
        return self.__class__.from_xarray(self._data.quantile(q, dim=dim))

    def resize_to(
        self,
        sizes: PerAxis[int],
        *,
        pad_where: Union[
            PadWhere,
            PerAxis[PadWhere],
        ] = "left_and_right",
        crop_where: Union[
            CropWhere,
            PerAxis[CropWhere],
        ] = "left_and_right",
        pad_mode: PadMode = "symmetric",
    ):
        """return cropped/padded tensor with `sizes`"""
        crop_to_sizes: Dict[AxisId, int] = {}
        pad_to_sizes: Dict[AxisId, int] = {}
        new_axes = dict(sizes)
        for a, s_is in self.sizes.items():
            a = AxisId(str(a))
            _ = new_axes.pop(a, None)
            if a not in sizes or sizes[a] == s_is:
                pass
            elif s_is > sizes[a]:
                crop_to_sizes[a] = sizes[a]
            else:
                pad_to_sizes[a] = sizes[a]

        tensor = self
        if crop_to_sizes:
            tensor = tensor.crop_to(crop_to_sizes, crop_where=crop_where)

        if pad_to_sizes:
            tensor = tensor.pad_to(pad_to_sizes, pad_where=pad_where, mode=pad_mode)

        if new_axes:
            tensor = tensor.expand_dims(new_axes)

        return tensor

    def transpose(
        self,
        axes: Sequence[AxisId],
    ) -> Self:
        """return a transposed tensor

        Args:
            axes: the desired tensor axes
        """
        # expand missing tensor axes
        missing_axes = tuple(a for a in axes if a not in self.dims)
        array = self._data
        if missing_axes:
            array = array.expand_dims(missing_axes)

        # transpose to the correct axis order
        return self.__class__.from_xarray(array.transpose(*axes))

    @classmethod
    def _interprete_array_wo_known_axes(cls, array: NDArray[Any]):
        ndim = array.ndim
        if ndim == 2:
            current_axes = (
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[0]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[1]),
            )
        elif ndim == 3 and any(s <= 3 for s in array.shape):
            current_axes = (
                v0_5.ChannelAxis(
                    channel_names=[
                        v0_5.Identifier(f"channel{i}") for i in range(array.shape[0])
                    ]
                ),
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[1]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[2]),
            )
        elif ndim == 3:
            current_axes = (
                v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[0]),
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[1]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[2]),
            )
        elif ndim == 4:
            current_axes = (
                v0_5.ChannelAxis(
                    channel_names=[
                        v0_5.Identifier(f"channel{i}") for i in range(array.shape[0])
                    ]
                ),
                v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[1]),
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[2]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[3]),
            )
        elif ndim == 5:
            current_axes = (
                v0_5.BatchAxis(),
                v0_5.ChannelAxis(
                    channel_names=[
                        v0_5.Identifier(f"channel{i}") for i in range(array.shape[1])
                    ]
                ),
                v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[2]),
                v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[3]),
                v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[4]),
            )
        else:
            raise ValueError(f"Could not guess an axis mapping for {array.shape}")

        return cls(array, dims=tuple(a.id for a in current_axes))
```
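The per-axis branch logic in `Tensor.pad_to` (and, mirrored, in `crop_to`) boils down to computing a `(left, right)` width from the current and target size. A standalone sketch of that arithmetic (`pad_width` here is an illustrative helper, not the `PadWidth` class):

```python
def pad_width(size_is: int, size_want: int, where: str = "left_and_right"):
    """Compute (left, right) padding to grow an axis from `size_is` to
    `size_want`, matching the branch logic in `Tensor.pad_to`.
    Returns (0, 0) when the axis is already large enough (pad_to only
    warns in that case; it never shrinks an axis).
    """
    d = size_want - size_is
    if d <= 0:
        return (0, 0)
    if where == "left":
        return (d, 0)
    if where == "right":
        return (0, d)
    # "left_and_right": split as evenly as possible, extra element on the right
    left = d // 2
    return (left, d - left)
```

`resize_to` then simply routes each axis to `crop_to` or `pad_to` depending on the sign of the size difference, and expands any axes the tensor does not have yet.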
A wrapper around an xr.DataArray for better integration with bioimageio.spec and improved type annotations.
135 @classmethod 136 def from_xarray(cls, data_array: xr.DataArray) -> Self: 137 """create a `Tensor` from an xarray data array 138 139 note for internal use: this factory method is round-trip save 140 for any `Tensor`'s `data` property (an xarray.DataArray). 141 """ 142 return cls( 143 array=data_array.data, dims=tuple(AxisId(d) for d in data_array.dims) 144 )
146 @classmethod 147 def from_numpy( 148 cls, 149 array: NDArray[Any], 150 *, 151 dims: Optional[Union[AxisLike, Sequence[AxisLike]]], 152 ) -> Tensor: 153 """create a `Tensor` from a numpy array 154 155 Args: 156 array: the nd numpy array 157 axes: A description of the array's axes, 158 if None axes are guessed (which might fail and raise a ValueError.) 159 160 Raises: 161 ValueError: if `axes` is None and axes guessing fails. 162 """ 163 164 if dims is None: 165 return cls._interprete_array_wo_known_axes(array) 166 elif isinstance(dims, (str, Axis, v0_5.AxisBase)): 167 dims = [dims] 168 169 axis_infos = [AxisInfo.create(a) for a in dims] 170 original_shape = tuple(array.shape) 171 172 successful_view = _get_array_view(array, axis_infos) 173 if successful_view is None: 174 raise ValueError( 175 f"Array shape {original_shape} does not map to axes {dims}" 176 ) 177 178 return Tensor(successful_view, dims=tuple(a.id for a in axis_infos))
create a Tensor
from a numpy array
Arguments:
- array: the nd numpy array
- axes: A description of the array's axes, if None axes are guessed (which might fail and raise a ValueError.)
Raises:
- ValueError: if
axes
is None and axes guessing fails.
184 @property 185 def dims(self): # TODO: rename to `axes`? 186 """Tuple of dimension names associated with this tensor.""" 187 return cast(Tuple[AxisId, ...], self._data.dims)
Tuple of dimension names associated with this tensor.
189 @property 190 def tagged_shape(self): 191 """(alias for `sizes`) Ordered, immutable mapping from axis ids to lengths.""" 192 return self.sizes
(alias for sizes
) Ordered, immutable mapping from axis ids to lengths.
194 @property 195 def shape_tuple(self): 196 """Tuple of tensor axes lengths""" 197 return self._data.shape
Tuple of tensor axes lengths
199 @property 200 def size(self): 201 """Number of elements in the tensor. 202 203 Equal to math.prod(tensor.shape), i.e., the product of the tensors’ dimensions. 204 """ 205 return self._data.size
Number of elements in the tensor.
Equal to math.prod(tensor.shape), i.e., the product of the tensors’ dimensions.
207 def sum(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self: 208 """Reduce this Tensor's data by applying sum along some dimension(s).""" 209 return self.__class__.from_xarray(self._data.sum(dim=dim))
Reduce this Tensor's data by applying sum along some dimension(s).
222 @property 223 def sizes(self): 224 """Ordered, immutable mapping from axis ids to axis lengths.""" 225 return cast(Mapping[AxisId, int], self.data.sizes)
Ordered, immutable mapping from axis ids to axis lengths.
227 def astype(self, dtype: DTypeStr, *, copy: bool = False): 228 """Return tensor cast to `dtype` 229 230 note: if dtype is already satisfied copy if `copy`""" 231 return self.__class__.from_xarray(self._data.astype(dtype, copy=copy))
Return tensor cast to dtype
note: if dtype is already satisfied copy if copy
233 def clip(self, min: Optional[float] = None, max: Optional[float] = None): 234 """Return a tensor whose values are limited to [min, max]. 235 At least one of max or min must be given.""" 236 return self.__class__.from_xarray(self._data.clip(min, max))
Return a tensor whose values are limited to [min, max]. At least one of max or min must be given.
238 def crop_to( 239 self, 240 sizes: PerAxis[int], 241 crop_where: Union[ 242 CropWhere, 243 PerAxis[CropWhere], 244 ] = "left_and_right", 245 ) -> Self: 246 """crop to match `sizes`""" 247 if isinstance(crop_where, str): 248 crop_axis_where: PerAxis[CropWhere] = {a: crop_where for a in self.dims} 249 else: 250 crop_axis_where = crop_where 251 252 slices: Dict[AxisId, SliceInfo] = {} 253 254 for a, s_is in self.sizes.items(): 255 if a not in sizes or sizes[a] == s_is: 256 pass 257 elif sizes[a] > s_is: 258 logger.warning( 259 "Cannot crop axis {} of size {} to larger size {}", 260 a, 261 s_is, 262 sizes[a], 263 ) 264 elif a not in crop_axis_where: 265 raise ValueError( 266 f"Don't know where to crop axis {a}, `crop_where`={crop_where}" 267 ) 268 else: 269 crop_this_axis_where = crop_axis_where[a] 270 if crop_this_axis_where == "left": 271 slices[a] = SliceInfo(s_is - sizes[a], s_is) 272 elif crop_this_axis_where == "right": 273 slices[a] = SliceInfo(0, sizes[a]) 274 elif crop_this_axis_where == "left_and_right": 275 slices[a] = SliceInfo( 276 start := (s_is - sizes[a]) // 2, sizes[a] + start 277 ) 278 else: 279 assert_never(crop_this_axis_where) 280 281 return self[slices]
crop to match sizes
305 def pad_to( 306 self, 307 sizes: PerAxis[int], 308 pad_where: Union[PadWhere, PerAxis[PadWhere]] = "left_and_right", 309 mode: PadMode = "symmetric", 310 ) -> Self: 311 """pad `tensor` to match `sizes`""" 312 if isinstance(pad_where, str): 313 pad_axis_where: PerAxis[PadWhere] = {a: pad_where for a in self.dims} 314 else: 315 pad_axis_where = pad_where 316 317 pad_width: Dict[AxisId, PadWidth] = {} 318 for a, s_is in self.sizes.items(): 319 if a not in sizes or sizes[a] == s_is: 320 pad_width[a] = PadWidth(0, 0) 321 elif s_is > sizes[a]: 322 pad_width[a] = PadWidth(0, 0) 323 logger.warning( 324 "Cannot pad axis {} of size {} to smaller size {}", 325 a, 326 s_is, 327 sizes[a], 328 ) 329 elif a not in pad_axis_where: 330 raise ValueError( 331 f"Don't know where to pad axis {a}, `pad_where`={pad_where}" 332 ) 333 else: 334 pad_this_axis_where = pad_axis_where[a] 335 d = sizes[a] - s_is 336 if pad_this_axis_where == "left": 337 pad_width[a] = PadWidth(d, 0) 338 elif pad_this_axis_where == "right": 339 pad_width[a] = PadWidth(0, d) 340 elif pad_this_axis_where == "left_and_right": 341 pad_width[a] = PadWidth(left := d // 2, d - left) 342 else: 343 assert_never(pad_this_axis_where) 344 345 return self.pad(pad_width, mode)
pad tensor
to match sizes
347 def quantile( 348 self, 349 q: Union[float, Sequence[float]], 350 dim: Optional[Union[AxisId, Sequence[AxisId]]] = None, 351 ) -> Self: 352 assert ( 353 isinstance(q, (float, int)) 354 and q >= 0.0 355 or not isinstance(q, (float, int)) 356 and all(qq >= 0.0 for qq in q) 357 ) 358 assert ( 359 isinstance(q, (float, int)) 360 and q <= 1.0 361 or not isinstance(q, (float, int)) 362 and all(qq <= 1.0 for qq in q) 363 ) 364 assert dim is None or ( 365 (quantile_dim := AxisId("quantile")) != dim and quantile_dim not in set(dim) 366 ) 367 return self.__class__.from_xarray(self._data.quantile(q, dim=dim))
369 def resize_to( 370 self, 371 sizes: PerAxis[int], 372 *, 373 pad_where: Union[ 374 PadWhere, 375 PerAxis[PadWhere], 376 ] = "left_and_right", 377 crop_where: Union[ 378 CropWhere, 379 PerAxis[CropWhere], 380 ] = "left_and_right", 381 pad_mode: PadMode = "symmetric", 382 ): 383 """return cropped/padded tensor with `sizes`""" 384 crop_to_sizes: Dict[AxisId, int] = {} 385 pad_to_sizes: Dict[AxisId, int] = {} 386 new_axes = dict(sizes) 387 for a, s_is in self.sizes.items(): 388 a = AxisId(str(a)) 389 _ = new_axes.pop(a, None) 390 if a not in sizes or sizes[a] == s_is: 391 pass 392 elif s_is > sizes[a]: 393 crop_to_sizes[a] = sizes[a] 394 else: 395 pad_to_sizes[a] = sizes[a] 396 397 tensor = self 398 if crop_to_sizes: 399 tensor = tensor.crop_to(crop_to_sizes, crop_where=crop_where) 400 401 if pad_to_sizes: 402 tensor = tensor.pad_to(pad_to_sizes, pad_where=pad_where, mode=pad_mode) 403 404 if new_axes: 405 tensor = tensor.expand_dims(new_axes) 406 407 return tensor
return cropped/padded tensor with sizes
409 def transpose( 410 self, 411 axes: Sequence[AxisId], 412 ) -> Self: 413 """return a transposed tensor 414 415 Args: 416 axes: the desired tensor axes 417 """ 418 # expand missing tensor axes 419 missing_axes = tuple(a for a in axes if a not in self.dims) 420 array = self._data 421 if missing_axes: 422 array = array.expand_dims(missing_axes) 423 424 # transpose to the correct axis order 425 return self.__class__.from_xarray(array.transpose(*axes))
return a transposed tensor
Arguments:
- axes: the desired tensor axes
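`transpose` first expands any axis requested in `axes` that the tensor does not yet have, then reorders. A NumPy sketch of those two steps (positional axes standing in for named `AxisId`s):

```python
import numpy as np

arr = np.zeros((2, 3))        # imagine axes ("y", "x")
arr = arr[np.newaxis, ...]    # step 1: expand a missing leading "batch" axis
arr = arr.transpose(0, 2, 1)  # step 2: reorder to ("batch", "x", "y")
print(arr.shape)  # (1, 3, 2)
```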
```python
def test_description(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Union[Literal["discover", "latest"], str] = "discover",
    weight_format: Optional[WeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    absolute_tolerance: float = 1.5e-4,
    relative_tolerance: float = 1e-4,
    decimal: Optional[int] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Optional[str] = None,
) -> ValidationSummary:
    """Test a bioimage.io resource dynamically, e.g. prediction of test tensors for models"""
    rd = load_description_and_test(
        source,
        format_version=format_version,
        weight_format=weight_format,
        devices=devices,
        absolute_tolerance=absolute_tolerance,
        relative_tolerance=relative_tolerance,
        decimal=decimal,
        determinism=determinism,
        expected_type=expected_type,
    )
    return rd.validation_summary
```
Test a bioimage.io resource dynamically, e.g. prediction of test tensors for models
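`absolute_tolerance` and `relative_tolerance` control how closely reproduced outputs must match the packaged test outputs. A sketch of `np.allclose`-style comparison semantics with the defaults above (the exact comparison used internally is an assumption):

```python
import numpy as np

expected = np.array([0.0, 1.0, 2.0])
actual = expected + 1e-5  # small reproduction error
# a deviation passes when |actual - expected| <= atol + rtol * |expected|
ok = np.allclose(actual, expected, atol=1.5e-4, rtol=1e-4)
print(ok)  # True
```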
```python
def test_model(
    source: Union[v0_5.ModelDescr, PermissiveFileSource],
    weight_format: Optional[WeightsFormat] = None,
    devices: Optional[List[str]] = None,
    absolute_tolerance: float = 1.5e-4,
    relative_tolerance: float = 1e-4,
    decimal: Optional[int] = None,
    *,
    determinism: Literal["seed_only", "full"] = "seed_only",
) -> ValidationSummary:
    """Test model inference"""
    return test_description(
        source,
        weight_format=weight_format,
        devices=devices,
        absolute_tolerance=absolute_tolerance,
        relative_tolerance=relative_tolerance,
        decimal=decimal,
        determinism=determinism,
        expected_type="model",
    )
```
Test model inference
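A hedged usage sketch of `test_model` on a zoo model by id ("powerful-chipmunk" is the id from the example at the top of this page; actually running the test requires bioimageio.core, a pytorch install, and network access, so the sketch degrades gracefully when those are missing):

```python
def run_model_test(model_id: str) -> str:
    """Run the dynamic model test and return its status (sketch)."""
    try:
        from bioimageio.core import test_model
        summary = test_model(model_id, determinism="seed_only")
        return summary.status  # e.g. "passed" or "failed"
    except Exception as exc:  # missing deps / offline: report instead of crash
        return f"skipped: {exc}"

result = run_model_test("powerful-chipmunk")
print(result)
```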
alias of `test_description`
```python
def validate_format(
    data: BioimageioYamlContent,
    /,
    *,
    format_version: Union[Literal["discover", "latest"], str] = DISCOVER,
    context: Optional[ValidationContext] = None,
) -> ValidationSummary:
    """validate a bioimageio.yaml file (RDF)"""
    with context or validation_context_var.get():
        rd = build_description(data, format_version=format_version)

    assert rd.validation_summary is not None
    return rd.validation_summary
```
validate a bioimageio.yaml file (RDF)
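Unlike the dynamic tests above, `validate_format` only checks an RDF statically. A hedged sketch validating an in-memory RDF dict (the field values are illustrative, and the sketch degrades gracefully when bioimageio.core is not installed):

```python
# a minimal, illustrative resource description (RDF) as plain YAML content
rdf = {
    "type": "dataset",
    "format_version": "0.3.0",
    "name": "example dataset",
    "description": "a minimal RDF for illustration",
}

def check_format(data) -> str:
    """Return the validation status for an in-memory RDF dict (sketch)."""
    try:
        from bioimageio.core import validate_format
        return validate_format(data).status  # e.g. "passed" or "failed"
    except Exception as exc:  # bioimageio.core not installed, etc.
        return f"skipped: {exc}"

print(check_format(rdf))
```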