bioimageio.core.digest_spec
from __future__ import annotations

import collections.abc
import importlib.util
import sys
from itertools import chain
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import (
    Any,
    Callable,
    Dict,
    Iterable,
    List,
    Mapping,
    NamedTuple,
    Optional,
    Sequence,
    Tuple,
    Union,
)
from zipfile import ZipFile, is_zipfile

import numpy as np
import xarray as xr
from loguru import logger
from numpy.typing import NDArray
from typing_extensions import Unpack, assert_never

from bioimageio.spec._internal.io import HashKwargs
from bioimageio.spec.common import FileDescr, FileSource, ZipPath
from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5
from bioimageio.spec.model.v0_4 import CallableFromDepencency, CallableFromFile
from bioimageio.spec.model.v0_5 import (
    ArchitectureFromFileDescr,
    ArchitectureFromLibraryDescr,
    ParameterizedSize_N,
)
from bioimageio.spec.utils import load_array

from .axis import Axis, AxisId, AxisInfo, AxisLike, PerAxis
from .block_meta import split_multiple_shapes_into_blocks
from .common import Halo, MemberId, PerMember, SampleId, TotalNumberOfBlocks
from .io import load_tensor
from .sample import (
    LinearSampleAxisTransform,
    Sample,
    SampleBlockMeta,
    sample_block_meta_generator,
)
from .stat_measures import Stat
from .tensor import Tensor

# tensor-like inputs accepted by `get_tensor`; a `Path` is loaded via `load_tensor`
TensorSource = Union[Tensor, xr.DataArray, NDArray[Any], Path]


def import_callable(
    node: Union[
        ArchitectureFromFileDescr,
        ArchitectureFromLibraryDescr,
        CallableFromDepencency,
        CallableFromFile,
    ],
    /,
    **kwargs: Unpack[HashKwargs],
) -> Callable[..., Any]:
    """import a callable (e.g. a torch.nn.Module) from a spec node describing it

    Args:
        node: spec node naming the callable either by importable module
            (`CallableFromDepencency`, `ArchitectureFromLibraryDescr`) or by
            source file (`CallableFromFile`, `ArchitectureFromFileDescr`).
        kwargs: hash arguments forwarded when importing from a
            `CallableFromFile`; an `ArchitectureFromFileDescr` uses its own
            `sha256` instead.

    Raises:
        ValueError: if the imported object is not callable.
    """
    if isinstance(node, CallableFromDepencency):
        module = importlib.import_module(node.module_name)
        c = getattr(module, str(node.callable_name))
    elif isinstance(node, ArchitectureFromLibraryDescr):
        module = importlib.import_module(node.import_from)
        c = getattr(module, str(node.callable))
    elif isinstance(node, CallableFromFile):
        c = _import_from_file_impl(node.source_file, str(node.callable_name), **kwargs)
    elif isinstance(node, ArchitectureFromFileDescr):
        c = _import_from_file_impl(node.source, str(node.callable), sha256=node.sha256)
    else:
        assert_never(node)

    if not callable(c):
        raise ValueError(f"{node} (imported: {c}) is not callable")

    return c


def _import_from_file_impl(
    source: FileSource, callable_name: str, **kwargs: Unpack[HashKwargs]
):
    """import attribute `callable_name` from the module file (or zip) at `source`

    The module is cached in `sys.modules` under a name containing the source's
    sha256, so importing identical content again reuses the loaded module.

    Raises:
        ImportError: if the module cannot be created or executed.
        AttributeError: if the module has no (accessible) `callable_name` attribute.
    """
    src_descr = FileDescr(source=source, **kwargs)
    # ensure sha is valid even if perform_io_checks=False
    # or the source has changed since last sha computation
    src_descr.validate_sha256(force_recompute=True)
    assert src_descr.sha256 is not None
    source_sha = src_descr.sha256

    reader = src_descr.get_reader()
    # make sure we have unique module name
    module_name = f"{reader.original_file_name.split('.')[0]}_{source_sha}"

    # make sure we have a unique and valid module name
    if not module_name.isidentifier():
        module_name = f"custom_module_{source_sha}"
        assert module_name.isidentifier(), module_name

    source_bytes = reader.read()

    module = sys.modules.get(module_name)
    if module is None:
        try:
            # NOTE(review): `tmp_dir` is not kept alive beyond this call; the
            # directory may be removed once the object is garbage collected.
            tmp_dir = TemporaryDirectory(ignore_cleanup_errors=True)
            module_path = Path(tmp_dir.name) / module_name
            if reader.original_file_name.endswith(".zip") or is_zipfile(reader):
                module_path.mkdir()
                ZipFile(reader).extractall(path=module_path)
                # NOTE(review): `spec_from_file_location` picks a loader by file
                # suffix, so for this extracted *directory* it may return None
                # (-> ImportError below) — confirm the zip branch works.
            else:
                module_path = module_path.with_suffix(".py")
                _ = module_path.write_bytes(source_bytes)

            importlib_spec = importlib.util.spec_from_file_location(
                module_name, str(module_path)
            )

            if importlib_spec is None:
                raise ImportError(f"Failed to import {source}")

            module = importlib.util.module_from_spec(importlib_spec)

            sys.modules[module_name] = module  # cache this module

            assert importlib_spec.loader is not None
            importlib_spec.loader.exec_module(module)

        except Exception as e:
            # the failure may occur before the module was registered above;
            # `pop` avoids a KeyError masking the original error
            _ = sys.modules.pop(module_name, None)
            raise ImportError(f"Failed to import {source}") from e

    try:
        callable_attr = getattr(module, callable_name)
    except AttributeError as e:
        raise AttributeError(
            f"Imported custom module from {source} has no `{callable_name}` attribute."
        ) from e
    except Exception as e:
        raise AttributeError(
            f"Failed to access `{callable_name}` attribute from custom module imported from {source} ."
        ) from e

    else:
        return callable_attr


def get_axes_infos(
    io_descr: Union[
        v0_4.InputTensorDescr,
        v0_4.OutputTensorDescr,
        v0_5.InputTensorDescr,
        v0_5.OutputTensorDescr,
    ],
) -> List[AxisInfo]:
    """get a unified, simplified axis representation from spec axes

    v0.5 axes (`v0_5.AxisBase`) are converted via their `id` and `type`;
    v0.4 axes must be one of the letters b, i, t, c, z, y, x.
    """
    ret: List[AxisInfo] = []
    for a in io_descr.axes:
        if isinstance(a, v0_5.AxisBase):
            ret.append(AxisInfo.create(Axis(id=a.id, type=a.type)))
        else:
            assert a in ("b", "i", "t", "c", "z", "y", "x")
            ret.append(AxisInfo.create(a))

    return ret


def get_member_id(
    tensor_description: Union[
        v0_4.InputTensorDescr,
        v0_4.OutputTensorDescr,
        v0_5.InputTensorDescr,
        v0_5.OutputTensorDescr,
    ],
) -> MemberId:
    """get the normalized tensor ID, usable as a sample member ID

    v0.4 descriptions are identified by `name`, v0.5 descriptions by `id`.
    """

    if isinstance(tensor_description, (v0_4.InputTensorDescr, v0_4.OutputTensorDescr)):
        return MemberId(tensor_description.name)
    elif isinstance(
        tensor_description, (v0_5.InputTensorDescr, v0_5.OutputTensorDescr)
    ):
        return tensor_description.id
    else:
        assert_never(tensor_description)


def get_member_ids(
    tensor_descriptions: Sequence[
        Union[
            v0_4.InputTensorDescr,
            v0_4.OutputTensorDescr,
            v0_5.InputTensorDescr,
            v0_5.OutputTensorDescr,
        ]
    ],
) -> List[MemberId]:
    """get normalized tensor IDs to be used as sample member IDs (order preserved)"""
    return [get_member_id(descr) for descr in tensor_descriptions]


def get_test_inputs(model: AnyModelDescr) -> Sample:
    """returns a model's test input sample

    v0.4 models list test inputs in `model.test_inputs`; v0.5 input
    descriptions carry a `test_tensor` each.
    """
    member_ids = get_member_ids(model.inputs)
    if isinstance(model, v0_4.ModelDescr):
        arrays = [load_array(tt) for tt in model.test_inputs]
    else:
        arrays = [load_array(d.test_tensor) for d in model.inputs]

    axes = [get_axes_infos(t) for t in model.inputs]
    return Sample(
        members={
            m: Tensor.from_numpy(arr, dims=ax)
            for m, arr, ax in zip(member_ids, arrays, axes)
        },
        stat={},
        id="test-sample",
    )


def get_test_outputs(model: AnyModelDescr) -> Sample:
    """returns a model's test output sample

    v0.4 models list test outputs in `model.test_outputs`; v0.5 output
    descriptions carry a `test_tensor` each.
    """
    member_ids = get_member_ids(model.outputs)

    if isinstance(model, v0_4.ModelDescr):
        arrays = [load_array(tt) for tt in model.test_outputs]
    else:
        arrays = [load_array(d.test_tensor) for d in model.outputs]

    axes = [get_axes_infos(t) for t in model.outputs]

    return Sample(
        members={
            m: Tensor.from_numpy(arr, dims=ax)
            for m, arr, ax in zip(member_ids, arrays, axes)
        },
        stat={},
        id="test-sample",
    )


class IO_SampleBlockMeta(NamedTuple):
    """pair of corresponding input and output sample block metadata"""

    input: SampleBlockMeta  # metadata of an input block
    output: SampleBlockMeta  # metadata of the output block computed from `input`


def get_input_halo(model: v0_5.ModelDescr, output_halo: PerMember[PerAxis[Halo]]):
    """returns the halo with which input tensors need to be divided into blocks,
    such that `output_halo` can be cropped from their outputs without introducing gaps.

    Every output axis with a halo must have its size given by a
    `v0_5.SizeReference`; the halo is mapped onto the referenced axis and rescaled
    by (output axis scale / referenced axis scale).

    Returns:
        halo per referenced tensor id and referenced axis id, split evenly to both sides.

    Raises:
        ValueError: if an output axis with halo has no `SizeReference`, or if the
            rescaled total halo is not an even integer.
    """
    input_halo: Dict[MemberId, Dict[AxisId, Halo]] = {}
    outputs = {t.id: t for t in model.outputs}
    all_tensors = {**{t.id: t for t in model.inputs}, **outputs}

    for t, th in output_halo.items():
        axes = {a.id: a for a in outputs[t].axes}

        for a, ah in th.items():
            axis = axes[a]
            s = axis.size
            if not isinstance(s, v0_5.SizeReference):
                raise ValueError(
                    f"Unable to map output halo for {t}.{a} to an input axis"
                )

            ref_axis = {ra.id: ra for ra in all_tensors[s.tensor_id].axes}[s.axis_id]

            total_output_halo = sum(ah)
            # output and referenced axis may be sampled at different scales
            total_input_halo = total_output_halo * axis.scale / ref_axis.scale
            # the halo is split evenly to both sides -> must be an even integer
            if total_input_halo != int(total_input_halo) or total_input_halo % 2 != 0:
                raise ValueError(
                    f"Output halo {ah} for {t}.{a} maps to a total halo of"
                    + f" {total_input_halo} for {s.tensor_id}.{s.axis_id},"
                    + " which is not an even integer."
                )

            half = int(total_input_halo // 2)
            # key by the referenced axis; the output axis id `a` may differ from it
            input_halo.setdefault(s.tensor_id, {})[s.axis_id] = Halo(half, half)

    return input_halo


def get_block_transform(
    model: v0_5.ModelDescr,
) -> PerMember[PerAxis[Union[LinearSampleAxisTransform, int]]]:
    """returns how a model's output tensor shapes relate to its input shapes

    Per output axis the value is: the transform onto the first input batch axis
    (for output batch axes), a fixed int size, -1 for a data dependent size, or a
    `LinearSampleAxisTransform` derived from a `v0_5.SizeReference`.

    Raises:
        ValueError: if an output has a batch axis but no input has one.
    """
    ret: Dict[MemberId, Dict[AxisId, Union[LinearSampleAxisTransform, int]]] = {}
    batch_axis_trf = None
    for ipt in model.inputs:
        for a in ipt.axes:
            if a.type == "batch":
                batch_axis_trf = LinearSampleAxisTransform(
                    axis=a.id, scale=1, offset=0, member=ipt.id
                )
                break
        if batch_axis_trf is not None:
            break
    axis_scales = {
        t.id: {a.id: a.scale for a in t.axes}
        for t in chain(model.inputs, model.outputs)
    }
    for out in model.outputs:
        new_axes: Dict[AxisId, Union[LinearSampleAxisTransform, int]] = {}
        for a in out.axes:
            if a.size is None:
                assert a.type == "batch"
                if batch_axis_trf is None:
                    raise ValueError(
                        "no batch axis found in any input tensor, but output tensor"
                        + f" '{out.id}' has one."
                    )
                s = batch_axis_trf
            elif isinstance(a.size, int):
                s = a.size
            elif isinstance(a.size, v0_5.DataDependentSize):
                s = -1
            elif isinstance(a.size, v0_5.SizeReference):
                s = LinearSampleAxisTransform(
                    axis=a.size.axis_id,
                    scale=axis_scales[a.size.tensor_id][a.size.axis_id] / a.scale,
                    offset=a.size.offset,
                    member=a.size.tensor_id,
                )
            else:
                assert_never(a.size)

            new_axes[a.id] = s

        ret[out.id] = new_axes

    return ret


def get_io_sample_block_metas(
    model: v0_5.ModelDescr,
    input_sample_shape: PerMember[PerAxis[int]],
    ns: Mapping[Tuple[MemberId, AxisId], ParameterizedSize_N],
    batch_size: int = 1,
) -> Tuple[TotalNumberOfBlocks, Iterable[IO_SampleBlockMeta]]:
    """returns the total number of blocks and an iterable yielding metadata for
    corresponding input and output sample blocks

    Args:
        model: v0.5 model description (other description types are rejected).
        input_sample_shape: shape of the whole input sample per member and axis.
        ns: `n` per parameterized input axis, forwarded to `model.get_axis_sizes`.
        batch_size: batch size forwarded to `model.get_axis_sizes`.

    Raises:
        TypeError: if `model` is not a `v0_5.ModelDescr`.
    """
    if not isinstance(model, v0_5.ModelDescr):
        raise TypeError(
            f"get_io_sample_block_metas() not implemented for {type(model)}"
        )

    block_axis_sizes = model.get_axis_sizes(ns=ns, batch_size=batch_size)
    # regroup {(member, axis): size} into {member: {axis: size}} in one pass
    input_block_shape: Dict[MemberId, Dict[AxisId, int]] = {}
    for (member_id, axis_id), size in block_axis_sizes.inputs.items():
        input_block_shape.setdefault(member_id, {})[axis_id] = size

    # symmetric halo for every output axis that declares one
    output_halo = {
        t.id: {
            a.id: Halo(a.halo, a.halo) for a in t.axes if isinstance(a, v0_5.WithHalo)
        }
        for t in model.outputs
    }
    input_halo = get_input_halo(model, output_halo)

    n_input_blocks, input_blocks = split_multiple_shapes_into_blocks(
        input_sample_shape, input_block_shape, halo=input_halo
    )
    block_transform = get_block_transform(model)
    # lazily pair each input block with its transformed output block
    return n_input_blocks, (
        IO_SampleBlockMeta(ipt, ipt.get_transformed(block_transform))
        for ipt in sample_block_meta_generator(
            input_blocks, sample_shape=input_sample_shape, sample_id=None
        )
    )


def get_tensor(
    src: Union[ZipPath, TensorSource],
    ipt: Union[v0_4.InputTensorDescr, v0_5.InputTensorDescr],
):
    """helper to cast/load various tensor sources

    `Tensor`s are returned as is, `xr.DataArray`s keep their own dims, numpy
    arrays and everything else (paths) use the axes described by `ipt`.
    """

    if isinstance(src, Tensor):
        return src
    elif isinstance(src, xr.DataArray):
        return Tensor.from_xarray(src)
    elif isinstance(src, np.ndarray):
        return Tensor.from_numpy(src, dims=get_axes_infos(ipt))
    else:
        return load_tensor(src, axes=get_axes_infos(ipt))


def create_sample_for_model(
    model: AnyModelDescr,
    *,
    stat: Optional[Stat] = None,
    sample_id: SampleId = None,
    inputs: Union[PerMember[TensorSource], TensorSource],
) -> Sample:
    """Create a sample from a single set of input(s) for a specific bioimage.io model

    Args:
        model: a bioimage.io model description
        stat: dictionary with sample and dataset statistics (may be updated in-place!)
        sample_id: id of the created sample
        inputs: the input(s) constituting a single sample. A mapping per member
            id, or a single tensor source for models with exactly one input.

    Raises:
        TypeError: if a single input is given for a model with several inputs.
        ValueError: on unknown members or missing non-optional inputs.
    """

    model_inputs = {get_member_id(d): d for d in model.inputs}
    if isinstance(inputs, collections.abc.Mapping):
        inputs = {MemberId(k): v for k, v in inputs.items()}
    elif len(model_inputs) == 1:
        inputs = {list(model_inputs)[0]: inputs}
    else:
        raise TypeError(
            f"Expected `inputs` to be a mapping with keys {tuple(model_inputs)}"
        )

    if unknown := {k for k in inputs if k not in model_inputs}:
        raise ValueError(f"Got unexpected inputs: {unknown}")

    # only v0.5 input descriptions can be marked optional
    if missing := {
        k
        for k, v in model_inputs.items()
        if k not in inputs and not (isinstance(v, v0_5.InputTensorDescr) and v.optional)
    }:
        raise ValueError(f"Missing non-optional model inputs: {missing}")

    return Sample(
        members={
            m: get_tensor(inputs[m], ipt)
            for m, ipt in model_inputs.items()
            if m in inputs
        },
        stat={} if stat is None else stat,
        id=sample_id,
    )


def load_sample_for_model(
    *,
    model: AnyModelDescr,
    paths: PerMember[Path],
    axes: Optional[PerMember[Sequence[AxisLike]]] = None,
    stat: Optional[Stat] = None,
    sample_id: Optional[SampleId] = None,
):
    """load a single sample from `paths` that can be processed by `model`

    Args:
        model: model description the sample is loaded for.
        paths: file path per input member.
        axes: optional axes hints per member; members without a hint use the
            axes of the corresponding model input description.
        stat: dictionary with sample and dataset statistics (may be updated in-place!)
        sample_id: sample id; if None, the sorted tuple of all paths is used.

    Raises:
        ValueError: if `paths` or `axes` name members that are not model inputs.
    """

    if axes is None:
        axes = {}

    # make sure members are keyed by MemberId, not string
    # (this also copies `axes`, so the caller's mapping is not mutated below)
    paths = {MemberId(k): v for k, v in paths.items()}
    axes = {MemberId(k): v for k, v in axes.items()}

    model_inputs = {get_member_id(d): d for d in model.inputs}

    if unknown := {k for k in paths if k not in model_inputs}:
        raise ValueError(f"Got unexpected paths for {unknown}")

    if unknown := {k for k in axes if k not in model_inputs}:
        raise ValueError(f"Got unexpected axes hints for: {unknown}")

    members: Dict[MemberId, Tensor] = {}
    for m, p in paths.items():
        if m not in axes:
            axes[m] = get_axes_infos(model_inputs[m])
            logger.debug(
                "loading '{}' from {} with default input axes {} ",
                m,
                p,
                axes[m],
            )
        members[m] = load_tensor(p, axes[m])

    return Sample(
        members=members,
        stat={} if stat is None else stat,
        # explicit None check: falsy ids such as 0 or "" are valid sample ids
        id=tuple(sorted(paths.values())) if sample_id is None else sample_id,
    )
def import_callable(
    node: Union[
        ArchitectureFromFileDescr,
        ArchitectureFromLibraryDescr,
        CallableFromDepencency,
        CallableFromFile,
    ],
    /,
    **kwargs: Unpack[HashKwargs],
) -> Callable[..., Any]:
    """Import a callable (e.g. a `torch.nn.Module`) from the spec node describing it.

    Args:
        node: spec node naming the callable, either by importable module
            (`CallableFromDepencency`, `ArchitectureFromLibraryDescr`) or by
            source file (`CallableFromFile`, `ArchitectureFromFileDescr`).
        kwargs: hash arguments forwarded when importing from a
            `CallableFromFile`; an `ArchitectureFromFileDescr` passes its own
            `sha256` instead.

    Returns:
        The imported callable.

    Raises:
        ValueError: if the imported object is not callable.
    """
    imported: Any
    if isinstance(node, CallableFromDepencency):
        # callable provided by an importable module (v0.4 style)
        imported = getattr(
            importlib.import_module(node.module_name), str(node.callable_name)
        )
    elif isinstance(node, ArchitectureFromLibraryDescr):
        # callable provided by an importable library (v0.5 style)
        imported = getattr(
            importlib.import_module(node.import_from), str(node.callable)
        )
    elif isinstance(node, CallableFromFile):
        imported = _import_from_file_impl(
            node.source_file, str(node.callable_name), **kwargs
        )
    elif isinstance(node, ArchitectureFromFileDescr):
        imported = _import_from_file_impl(
            node.source, str(node.callable), sha256=node.sha256
        )
    else:
        assert_never(node)

    if callable(imported):
        return imported

    raise ValueError(f"{node} (imported: {imported}) is not callable")
import a callable (e.g. a torch.nn.Module) from a spec node describing it
def get_axes_infos(
    io_descr: Union[
        v0_4.InputTensorDescr,
        v0_4.OutputTensorDescr,
        v0_5.InputTensorDescr,
        v0_5.OutputTensorDescr,
    ],
) -> List[AxisInfo]:
    """Get a unified, simplified axis representation from the spec axes.

    v0.5 axes (instances of `v0_5.AxisBase`) are converted via their `id` and
    `type`. Any other axis is a v0.4 axis letter, which must be one of
    b, i, t, c, z, y, x and is passed to `AxisInfo.create` unchanged.
    """

    def to_axis_info(axis: Any) -> AxisInfo:
        if not isinstance(axis, v0_5.AxisBase):
            # v0.4 axes are plain single-letter strings
            assert axis in ("b", "i", "t", "c", "z", "y", "x")
            return AxisInfo.create(axis)

        return AxisInfo.create(Axis(id=axis.id, type=axis.type))

    return [to_axis_info(axis) for axis in io_descr.axes]
get a unified, simplified axis representation from spec axes
def get_member_id(
    tensor_description: Union[
        v0_4.InputTensorDescr,
        v0_4.OutputTensorDescr,
        v0_5.InputTensorDescr,
        v0_5.OutputTensorDescr,
    ],
) -> MemberId:
    """Get the normalized tensor ID, usable as a sample member ID.

    v0.4 tensor descriptions are identified by their `name` (wrapped in a
    `MemberId`); v0.5 tensor descriptions carry their `id` directly.
    """
    v0_4_descr_types = (v0_4.InputTensorDescr, v0_4.OutputTensorDescr)
    v0_5_descr_types = (v0_5.InputTensorDescr, v0_5.OutputTensorDescr)

    if isinstance(tensor_description, v0_4_descr_types):
        return MemberId(tensor_description.name)

    if isinstance(tensor_description, v0_5_descr_types):
        return tensor_description.id

    assert_never(tensor_description)
get the normalized tensor ID, usable as a sample member ID
def get_member_ids(
    tensor_descriptions: Sequence[
        Union[
            v0_4.InputTensorDescr,
            v0_4.OutputTensorDescr,
            v0_5.InputTensorDescr,
            v0_5.OutputTensorDescr,
        ]
    ],
) -> List[MemberId]:
    """Get the normalized tensor IDs to be used as sample member IDs.

    Applies `get_member_id` to every description, preserving their order.
    """
    return list(map(get_member_id, tensor_descriptions))
get normalized tensor IDs to be used as sample member IDs
def get_test_inputs(model: AnyModelDescr) -> Sample:
    """Return a model's test input sample.

    v0.4 models list their test input sources in `model.test_inputs`; v0.5
    input descriptions each carry a `test_tensor`. Every loaded array is
    wrapped as a `Tensor` with the axes of the matching input description.
    The sample has empty statistics and the id "test-sample".
    """
    member_ids = get_member_ids(model.inputs)
    test_sources = (
        model.test_inputs
        if isinstance(model, v0_4.ModelDescr)
        else [descr.test_tensor for descr in model.inputs]
    )
    arrays = [load_array(src) for src in test_sources]
    dims_per_member = [get_axes_infos(descr) for descr in model.inputs]

    members: Dict[MemberId, Tensor] = {}
    for member_id, array, dims in zip(member_ids, arrays, dims_per_member):
        members[member_id] = Tensor.from_numpy(array, dims=dims)

    return Sample(members=members, stat={}, id="test-sample")
returns a model's test input sample
def get_test_outputs(model: AnyModelDescr) -> Sample:
    """Return a model's test output sample.

    v0.4 models list their test output sources in `model.test_outputs`; v0.5
    output descriptions each carry a `test_tensor`. Every loaded array is
    wrapped as a `Tensor` with the axes of the matching output description.
    The sample has empty statistics and the id "test-sample".
    """
    member_ids = get_member_ids(model.outputs)
    test_sources = (
        model.test_outputs
        if isinstance(model, v0_4.ModelDescr)
        else [descr.test_tensor for descr in model.outputs]
    )
    arrays = [load_array(src) for src in test_sources]
    dims_per_member = [get_axes_infos(descr) for descr in model.outputs]

    members: Dict[MemberId, Tensor] = {}
    for member_id, array, dims in zip(member_ids, arrays, dims_per_member):
        members[member_id] = Tensor.from_numpy(array, dims=dims)

    return Sample(members=members, stat={}, id="test-sample")
returns a model's test output sample
IO_SampleBlockMeta(input, output)
def get_input_halo(model: v0_5.ModelDescr, output_halo: PerMember[PerAxis[Halo]]):
    """Return the halo with which input tensors need to be divided into blocks,
    such that `output_halo` can be cropped from their outputs without
    introducing gaps.

    Every output axis with a halo must have its size defined by a
    `v0_5.SizeReference`. The halo is mapped onto the referenced tensor axis
    and rescaled by (output axis scale / referenced axis scale).

    Args:
        model: the v0.5 model description.
        output_halo: halo per output tensor id and output axis id.

    Returns:
        Halo per referenced tensor id (`SizeReference.tensor_id`) and
        referenced axis id (`SizeReference.axis_id`), split evenly to both sides.

    Raises:
        ValueError: if an output axis with halo has no `SizeReference` size, or
            if the rescaled total halo is not an even integer.
    """
    input_halo: Dict[MemberId, Dict[AxisId, Halo]] = {}
    outputs = {t.id: t for t in model.outputs}
    all_tensors = {**{t.id: t for t in model.inputs}, **outputs}

    for t, th in output_halo.items():
        axes = {a.id: a for a in outputs[t].axes}

        for a, ah in th.items():
            axis = axes[a]
            s = axis.size
            if not isinstance(s, v0_5.SizeReference):
                raise ValueError(
                    f"Unable to map output halo for {t}.{a} to an input axis"
                )

            ref_axis = {ra.id: ra for ra in all_tensors[s.tensor_id].axes}[s.axis_id]

            total_output_halo = sum(ah)
            # output and referenced axis may be sampled at different scales
            total_input_halo = total_output_halo * axis.scale / ref_axis.scale
            # the halo is split evenly to both sides -> must be an even integer
            # (explicit check instead of `assert`, which is stripped under -O)
            if total_input_halo != int(total_input_halo) or total_input_halo % 2 != 0:
                raise ValueError(
                    f"Output halo {ah} for {t}.{a} maps to a total halo of"
                    + f" {total_input_halo} for {s.tensor_id}.{s.axis_id},"
                    + " which is not an even integer."
                )

            half = int(total_input_halo // 2)
            # key by the *referenced* axis; the output axis id `a` may differ from it
            input_halo.setdefault(s.tensor_id, {})[s.axis_id] = Halo(half, half)

    return input_halo
returns the halo with which input tensors need to be divided into blocks, such that `output_halo` can be cropped from their outputs without introducing gaps.
def get_block_transform(
    model: v0_5.ModelDescr,
) -> PerMember[PerAxis[Union[LinearSampleAxisTransform, int]]]:
    """Return how a model's output tensor shapes relate to its input shapes.

    For every axis of every output tensor the mapping holds:
      - for a batch axis (`size is None`): a transform onto the first batch axis
        found among the inputs,
      - for an integer size: that fixed size,
      - for a `v0_5.DataDependentSize`: -1,
      - for a `v0_5.SizeReference`: a `LinearSampleAxisTransform` onto the
        referenced axis, scaled by (referenced axis scale / output axis scale)
        and offset by the reference's offset.

    Raises:
        ValueError: if an output tensor has a batch axis but no input has one.
    """
    # identity transform onto the first batch axis among all inputs (if any)
    batch_axis_trf = next(
        (
            LinearSampleAxisTransform(axis=a.id, scale=1, offset=0, member=ipt.id)
            for ipt in model.inputs
            for a in ipt.axes
            if a.type == "batch"
        ),
        None,
    )
    axis_scales = {
        t.id: {a.id: a.scale for a in t.axes}
        for t in chain(model.inputs, model.outputs)
    }

    def transform_axis(
        out_id: MemberId, axis: Any
    ) -> Union[LinearSampleAxisTransform, int]:
        size = axis.size
        if size is None:
            assert axis.type == "batch"
            if batch_axis_trf is None:
                raise ValueError(
                    "no batch axis found in any input tensor, but output tensor"
                    + f" '{out_id}' has one."
                )
            return batch_axis_trf

        if isinstance(size, int):
            return size

        if isinstance(size, v0_5.DataDependentSize):
            return -1

        if isinstance(size, v0_5.SizeReference):
            return LinearSampleAxisTransform(
                axis=size.axis_id,
                scale=axis_scales[size.tensor_id][size.axis_id] / axis.scale,
                offset=size.offset,
                member=size.tensor_id,
            )

        assert_never(size)

    return {
        out.id: {a.id: transform_axis(out.id, a) for a in out.axes}
        for out in model.outputs
    }
returns how a model's output tensor shapes relate to its input shapes
def get_io_sample_block_metas(
    model: v0_5.ModelDescr,
    input_sample_shape: PerMember[PerAxis[int]],
    ns: Mapping[Tuple[MemberId, AxisId], ParameterizedSize_N],
    batch_size: int = 1,
) -> Tuple[TotalNumberOfBlocks, Iterable[IO_SampleBlockMeta]]:
    """Return the total number of blocks and an iterable yielding metadata for
    corresponding input and output sample blocks.

    Args:
        model: v0.5 model description (other description types are rejected).
        input_sample_shape: shape of the whole input sample per member and axis.
        ns: `n` per parameterized input axis size, keyed by (member id, axis id);
            forwarded to `model.get_axis_sizes`.
        batch_size: batch size forwarded to `model.get_axis_sizes`.

    Raises:
        TypeError: if `model` is not a `v0_5.ModelDescr`.
    """
    if not isinstance(model, v0_5.ModelDescr):
        raise TypeError(
            f"get_io_sample_block_metas() not implemented for {type(model)}"
        )

    block_axis_sizes = model.get_axis_sizes(ns=ns, batch_size=batch_size)
    # regroup {(member, axis): size} into {member: {axis: size}} in a single pass
    # (deterministic member order instead of iterating a set of member ids)
    input_block_shape: Dict[MemberId, Dict[AxisId, int]] = {}
    for (member_id, axis_id), size in block_axis_sizes.inputs.items():
        input_block_shape.setdefault(member_id, {})[axis_id] = size

    # symmetric halo for every output axis that declares one
    output_halo = {
        t.id: {
            a.id: Halo(a.halo, a.halo) for a in t.axes if isinstance(a, v0_5.WithHalo)
        }
        for t in model.outputs
    }
    input_halo = get_input_halo(model, output_halo)

    n_input_blocks, input_blocks = split_multiple_shapes_into_blocks(
        input_sample_shape, input_block_shape, halo=input_halo
    )
    block_transform = get_block_transform(model)
    # lazily pair each input block with its transformed output block
    return n_input_blocks, (
        IO_SampleBlockMeta(ipt, ipt.get_transformed(block_transform))
        for ipt in sample_block_meta_generator(
            input_blocks, sample_shape=input_sample_shape, sample_id=None
        )
    )
returns the total number of blocks and an iterable yielding metadata for corresponding input and output sample blocks
def get_tensor(
    src: Union[ZipPath, TensorSource],
    ipt: Union[v0_4.InputTensorDescr, v0_5.InputTensorDescr],
):
    """Cast or load a tensor source to a `Tensor`.

    - a `Tensor` is returned unchanged,
    - an `xr.DataArray` is converted with `Tensor.from_xarray` (keeping its dims),
    - a numpy array is wrapped with the axes described by `ipt`,
    - anything else (e.g. a path) is loaded via `load_tensor` with the axes of `ipt`.
    """
    if isinstance(src, Tensor):
        return src

    if isinstance(src, xr.DataArray):
        return Tensor.from_xarray(src)

    # both remaining cases use the axes described by the input spec
    axes = get_axes_infos(ipt)
    if isinstance(src, np.ndarray):
        return Tensor.from_numpy(src, dims=axes)

    return load_tensor(src, axes=axes)
helper to cast/load various tensor sources
def create_sample_for_model(
    model: AnyModelDescr,
    *,
    stat: Optional[Stat] = None,
    sample_id: SampleId = None,
    inputs: Union[PerMember[TensorSource], TensorSource],
) -> Sample:
    """Create a sample from a single set of input(s) for a specific bioimage.io model.

    Args:
        model: a bioimage.io model description
        stat: dictionary with sample and dataset statistics (may be updated in-place!)
        sample_id: id assigned to the created sample
        inputs: the input(s) constituting a single sample: a mapping from member
            id to tensor source, or a single tensor source for a model with
            exactly one input.

    Raises:
        TypeError: if a single (non-mapping) input is given and the model does
            not have exactly one input.
        ValueError: if `inputs` contains unknown members or lacks a
            non-optional model input.
    """
    model_inputs = {get_member_id(d): d for d in model.inputs}

    if not isinstance(inputs, collections.abc.Mapping):
        if len(model_inputs) != 1:
            raise TypeError(
                f"Expected `inputs` to be a mapping with keys {tuple(model_inputs)}"
            )
        (only_member,) = model_inputs
        inputs = {only_member: inputs}
    else:
        inputs = {MemberId(k): v for k, v in inputs.items()}

    unknown = {k for k in inputs if k not in model_inputs}
    if unknown:
        raise ValueError(f"Got unexpected inputs: {unknown}")

    def is_required(descr: Any) -> bool:
        # only v0.5 input descriptions can be marked optional
        return not (isinstance(descr, v0_5.InputTensorDescr) and descr.optional)

    missing = {
        k for k, descr in model_inputs.items() if k not in inputs and is_required(descr)
    }
    if missing:
        raise ValueError(f"Missing non-optional model inputs: {missing}")

    members = {
        m: get_tensor(inputs[m], descr)
        for m, descr in model_inputs.items()
        if m in inputs
    }
    return Sample(
        members=members,
        stat=stat if stat is not None else {},
        id=sample_id,
    )
Create a sample from a single set of input(s) for a specific bioimage.io model
Arguments:
- model: a bioimage.io model description
- stat: dictionary with sample and dataset statistics (may be updated in-place!)
- inputs: the input(s) constituting a single sample.
def load_sample_for_model(
    *,
    model: AnyModelDescr,
    paths: PerMember[Path],
    axes: Optional[PerMember[Sequence[AxisLike]]] = None,
    stat: Optional[Stat] = None,
    sample_id: Optional[SampleId] = None,
):
    """Load a single sample from `paths` that can be processed by `model`.

    Args:
        model: the model description the sample is loaded for.
        paths: file path per input member.
        axes: optional axes hints per input member; members without a hint use
            the axes of the corresponding model input description.
        stat: dictionary with sample and dataset statistics (may be updated
            in-place!); a new empty dict is used if None.
        sample_id: id of the sample; if None, the sorted tuple of all paths is used.

    Raises:
        ValueError: if `paths` or `axes` name members that are not model inputs.
    """
    # key members by MemberId, not string; building new dicts also keeps the
    # caller's `axes` mapping from being mutated below
    paths = {MemberId(k): v for k, v in paths.items()}
    axes = {} if axes is None else {MemberId(k): v for k, v in axes.items()}

    model_inputs = {get_member_id(d): d for d in model.inputs}

    if unknown := {k for k in paths if k not in model_inputs}:
        raise ValueError(f"Got unexpected paths for {unknown}")

    if unknown := {k for k in axes if k not in model_inputs}:
        raise ValueError(f"Got unexpected axes hints for: {unknown}")

    members: Dict[MemberId, Tensor] = {}
    for m, p in paths.items():
        if m not in axes:
            axes[m] = get_axes_infos(model_inputs[m])
            logger.debug(
                "loading '{}' from {} with default input axes {} ",
                m,
                p,
                axes[m],
            )
        members[m] = load_tensor(p, axes[m])

    return Sample(
        members=members,
        stat={} if stat is None else stat,
        # explicit None check: falsy ids such as 0 or "" are valid sample ids
        id=tuple(sorted(paths.values())) if sample_id is None else sample_id,
    )
load a single sample from `paths` that can be processed by `model`