bioimageio.core.digest_spec
from __future__ import annotations

import collections.abc
import importlib.util
import sys
from itertools import chain
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import (
    Any,
    Callable,
    Dict,
    Iterable,
    List,
    Mapping,
    NamedTuple,
    Optional,
    Sequence,
    Tuple,
    Union,
)
from zipfile import ZipFile, is_zipfile

import numpy as np
import xarray as xr
from bioimageio.spec._internal.io import HashKwargs
from bioimageio.spec.common import FileDescr, FileSource, ZipPath
from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5
from bioimageio.spec.model.v0_4 import CallableFromDepencency, CallableFromFile
from bioimageio.spec.model.v0_5 import (
    ArchitectureFromFileDescr,
    ArchitectureFromLibraryDescr,
    ParameterizedSize_N,
)
from bioimageio.spec.utils import load_array
from loguru import logger
from numpy.typing import NDArray
from typing_extensions import Unpack, assert_never

from .axis import Axis, AxisId, AxisInfo, AxisLike, PerAxis
from .block_meta import split_multiple_shapes_into_blocks
from .common import Halo, MemberId, PerMember, SampleId, TotalNumberOfBlocks
from .io import load_tensor
from .sample import (
    LinearSampleAxisTransform,
    Sample,
    SampleBlockMeta,
    sample_block_meta_generator,
)
from .stat_measures import Stat
from .tensor import Tensor

TensorSource = Union[Tensor, xr.DataArray, NDArray[Any], Path]


def import_callable(
    node: Union[
        ArchitectureFromFileDescr,
        ArchitectureFromLibraryDescr,
        CallableFromDepencency,
        CallableFromFile,
    ],
    /,
    **kwargs: Unpack[HashKwargs],
) -> Callable[..., Any]:
    """import a callable (e.g. a torch.nn.Module) from a spec node describing it"""
    if isinstance(node, CallableFromDepencency):
        module = importlib.import_module(node.module_name)
        c = getattr(module, str(node.callable_name))
    elif isinstance(node, ArchitectureFromLibraryDescr):
        module = importlib.import_module(node.import_from)
        c = getattr(module, str(node.callable))
    elif isinstance(node, CallableFromFile):
        c = _import_from_file_impl(node.source_file, str(node.callable_name), **kwargs)
    elif isinstance(node, ArchitectureFromFileDescr):
        c = _import_from_file_impl(node.source, str(node.callable), sha256=node.sha256)
    else:
        assert_never(node)

    if not callable(c):
        raise ValueError(f"{node} (imported: {c}) is not callable")

    return c


tmp_dirs_in_use: List[TemporaryDirectory[str]] = []
"""keep global reference to temporary directories created during import to delay cleanup"""


def _import_from_file_impl(
    source: FileSource, callable_name: str, **kwargs: Unpack[HashKwargs]
):
    src_descr = FileDescr(source=source, **kwargs)
    # ensure sha is valid even if perform_io_checks=False
    # or the source has changed since last sha computation
    src_descr.validate_sha256(force_recompute=True)
    assert src_descr.sha256 is not None
    source_sha = src_descr.sha256

    reader = src_descr.get_reader()
    # make sure we have unique module name
    module_name = f"{reader.original_file_name.split('.')[0]}_{source_sha}"

    # make sure we have a unique and valid module name
    if not module_name.isidentifier():
        module_name = f"custom_module_{source_sha}"
        assert module_name.isidentifier(), module_name

    source_bytes = reader.read()

    module = sys.modules.get(module_name)
    if module is None:
        try:
            td_kwargs: Dict[str, Any] = (
                dict(ignore_cleanup_errors=True) if sys.version_info >= (3, 10) else {}
            )
            if sys.version_info >= (3, 12):
                td_kwargs["delete"] = False

            tmp_dir = TemporaryDirectory(**td_kwargs)
            # keep global ref to tmp_dir to delay cleanup until program exit
            # TODO: remove for py >= 3.12, when delete=False works
            tmp_dirs_in_use.append(tmp_dir)

            module_path = Path(tmp_dir.name) / module_name
            if reader.original_file_name.endswith(".zip") or is_zipfile(reader):
                module_path.mkdir()
                ZipFile(reader).extractall(path=module_path)
            else:
                module_path = module_path.with_suffix(".py")
                _ = module_path.write_bytes(source_bytes)

            importlib_spec = importlib.util.spec_from_file_location(
                module_name, str(module_path)
            )

            if importlib_spec is None:
                raise ImportError(f"Failed to import {source}")

            module = importlib.util.module_from_spec(importlib_spec)

            sys.modules[module_name] = module  # cache this module

            assert importlib_spec.loader is not None
            importlib_spec.loader.exec_module(module)

        except Exception as e:
            if module_name in sys.modules:
                del sys.modules[module_name]

            raise ImportError(f"Failed to import {source}") from e

    try:
        callable_attr = getattr(module, callable_name)
    except AttributeError as e:
        raise AttributeError(
            f"Imported custom module from {source} has no `{callable_name}` attribute."
        ) from e
    except Exception as e:
        raise AttributeError(
            f"Failed to access `{callable_name}` attribute from custom module imported from {source}."
        ) from e
    else:
        return callable_attr


def get_axes_infos(
    io_descr: Union[
        v0_4.InputTensorDescr,
        v0_4.OutputTensorDescr,
        v0_5.InputTensorDescr,
        v0_5.OutputTensorDescr,
    ],
) -> List[AxisInfo]:
    """get a unified, simplified axis representation from spec axes"""
    ret: List[AxisInfo] = []
    for a in io_descr.axes:
        if isinstance(a, v0_5.AxisBase):
            ret.append(AxisInfo.create(Axis(id=a.id, type=a.type)))
        else:
            assert a in ("b", "i", "t", "c", "z", "y", "x")
            ret.append(AxisInfo.create(a))

    return ret


def get_member_id(
    tensor_description: Union[
        v0_4.InputTensorDescr,
        v0_4.OutputTensorDescr,
        v0_5.InputTensorDescr,
        v0_5.OutputTensorDescr,
    ],
) -> MemberId:
    """get the normalized tensor ID, usable as a sample member ID"""

    if isinstance(tensor_description, (v0_4.InputTensorDescr, v0_4.OutputTensorDescr)):
        return MemberId(tensor_description.name)
    elif isinstance(
        tensor_description, (v0_5.InputTensorDescr, v0_5.OutputTensorDescr)
    ):
        return tensor_description.id
    else:
        assert_never(tensor_description)


def get_member_ids(
    tensor_descriptions: Sequence[
        Union[
            v0_4.InputTensorDescr,
            v0_4.OutputTensorDescr,
            v0_5.InputTensorDescr,
            v0_5.OutputTensorDescr,
        ]
    ],
) -> List[MemberId]:
    """get normalized tensor IDs to be used as sample member IDs"""
    return [get_member_id(descr) for descr in tensor_descriptions]


def get_test_input_sample(model: AnyModelDescr) -> Sample:
    return _get_test_sample(
        model.inputs,
        model.test_inputs if isinstance(model, v0_4.ModelDescr) else model.inputs,
    )


get_test_inputs = get_test_input_sample
"""DEPRECATED: use `get_test_input_sample` instead"""


def get_test_output_sample(model: AnyModelDescr) -> Sample:
    """returns a model's test output sample"""
    return _get_test_sample(
        model.outputs,
        model.test_outputs if isinstance(model, v0_4.ModelDescr) else model.outputs,
    )


get_test_outputs = get_test_output_sample
"""DEPRECATED: use `get_test_output_sample` instead"""


def _get_test_sample(
    tensor_descrs: Sequence[
        Union[
            v0_4.InputTensorDescr,
            v0_4.OutputTensorDescr,
            v0_5.InputTensorDescr,
            v0_5.OutputTensorDescr,
        ]
    ],
    test_sources: Sequence[Union[FileSource, v0_5.TensorDescr]],
) -> Sample:
    """returns a model's input/output test sample"""
    member_ids = get_member_ids(tensor_descrs)
    arrays: List[NDArray[Any]] = []
    for src in test_sources:
        if isinstance(src, (v0_5.InputTensorDescr, v0_5.OutputTensorDescr)):
            if src.test_tensor is None:
                raise ValueError(
                    f"Model input '{src.id}' has no test tensor defined, cannot create test sample."
                )
            arrays.append(load_array(src.test_tensor))
        else:
            arrays.append(load_array(src))

    axes = [get_axes_infos(t) for t in tensor_descrs]
    return Sample(
        members={
            m: Tensor.from_numpy(arr, dims=ax)
            for m, arr, ax in zip(member_ids, arrays, axes)
        },
        stat={},
        id="test-sample",
    )


class IO_SampleBlockMeta(NamedTuple):
    input: SampleBlockMeta
    output: SampleBlockMeta


def get_input_halo(model: v0_5.ModelDescr, output_halo: PerMember[PerAxis[Halo]]):
    """returns which halo input tensors need to be divided into blocks with, such that
    `output_halo` can be cropped from their outputs without introducing gaps."""
    input_halo: Dict[MemberId, Dict[AxisId, Halo]] = {}
    outputs = {t.id: t for t in model.outputs}
    all_tensors = {**{t.id: t for t in model.inputs}, **outputs}

    for t, th in output_halo.items():
        axes = {a.id: a for a in outputs[t].axes}

        for a, ah in th.items():
            s = axes[a].size
            if not isinstance(s, v0_5.SizeReference):
                raise ValueError(
                    f"Unable to map output halo for {t}.{a} to an input axis"
                )

            axis = axes[a]
            ref_axis = {a.id: a for a in all_tensors[s.tensor_id].axes}[s.axis_id]

            input_halo_left = ah.left * axis.scale / ref_axis.scale
            input_halo_right = ah.right * axis.scale / ref_axis.scale
            assert input_halo_left == int(input_halo_left), f"{input_halo_left} not int"
            assert input_halo_right == int(input_halo_right), (
                f"{input_halo_right} not int"
            )

            input_halo.setdefault(s.tensor_id, {})[a] = Halo(
                int(input_halo_left), int(input_halo_right)
            )

    return input_halo


def get_block_transform(
    model: v0_5.ModelDescr,
) -> PerMember[PerAxis[Union[LinearSampleAxisTransform, int]]]:
    """returns how a model's output tensor shapes relate to its input shapes"""
    ret: Dict[MemberId, Dict[AxisId, Union[LinearSampleAxisTransform, int]]] = {}
    batch_axis_trf = None
    for ipt in model.inputs:
        for a in ipt.axes:
            if a.type == "batch":
                batch_axis_trf = LinearSampleAxisTransform(
                    axis=a.id, scale=1, offset=0, member=ipt.id
                )
                break
        if batch_axis_trf is not None:
            break
    axis_scales = {
        t.id: {a.id: a.scale for a in t.axes}
        for t in chain(model.inputs, model.outputs)
    }
    for out in model.outputs:
        new_axes: Dict[AxisId, Union[LinearSampleAxisTransform, int]] = {}
        for a in out.axes:
            if a.size is None:
                assert a.type == "batch"
                if batch_axis_trf is None:
                    raise ValueError(
                        "no batch axis found in any input tensor, but output tensor"
                        + f" '{out.id}' has one."
                    )
                s = batch_axis_trf
            elif isinstance(a.size, int):
                s = a.size
            elif isinstance(a.size, v0_5.DataDependentSize):
                s = -1
            elif isinstance(a.size, v0_5.SizeReference):
                s = LinearSampleAxisTransform(
                    axis=a.size.axis_id,
                    scale=axis_scales[a.size.tensor_id][a.size.axis_id] / a.scale,
                    offset=a.size.offset,
                    member=a.size.tensor_id,
                )
            else:
                assert_never(a.size)

            new_axes[a.id] = s

        ret[out.id] = new_axes

    return ret


def get_io_sample_block_metas(
    model: v0_5.ModelDescr,
    input_sample_shape: PerMember[PerAxis[int]],
    ns: Mapping[Tuple[MemberId, AxisId], ParameterizedSize_N],
    batch_size: int = 1,
) -> Tuple[TotalNumberOfBlocks, Iterable[IO_SampleBlockMeta]]:
    """returns an iterable yielding metadata for corresponding input and output samples"""
    if not isinstance(model, v0_5.ModelDescr):
        raise TypeError(f"get_block_meta() not implemented for {type(model)}")

    block_axis_sizes = model.get_axis_sizes(ns=ns, batch_size=batch_size)
    input_block_shape = {
        t: {aa: s for (tt, aa), s in block_axis_sizes.inputs.items() if tt == t}
        for t in {tt for tt, _ in block_axis_sizes.inputs}
    }
    output_halo = {
        t.id: {
            a.id: Halo(a.halo, a.halo) for a in t.axes if isinstance(a, v0_5.WithHalo)
        }
        for t in model.outputs
    }
    input_halo = get_input_halo(model, output_halo)

    n_input_blocks, input_blocks = split_multiple_shapes_into_blocks(
        input_sample_shape, input_block_shape, halo=input_halo
    )
    block_transform = get_block_transform(model)
    return n_input_blocks, (
        IO_SampleBlockMeta(ipt, ipt.get_transformed(block_transform))
        for ipt in sample_block_meta_generator(
            input_blocks, sample_shape=input_sample_shape, sample_id=None
        )
    )


def get_tensor(
    src: Union[ZipPath, TensorSource],
    ipt: Union[v0_4.InputTensorDescr, v0_5.InputTensorDescr],
):
    """helper to cast/load various tensor sources"""

    if isinstance(src, Tensor):
        return src
    elif isinstance(src, xr.DataArray):
        return Tensor.from_xarray(src)
    elif isinstance(src, np.ndarray):
        return Tensor.from_numpy(src, dims=get_axes_infos(ipt))
    else:
        return load_tensor(src, axes=get_axes_infos(ipt))


def create_sample_for_model(
    model: AnyModelDescr,
    *,
    stat: Optional[Stat] = None,
    sample_id: SampleId = None,
    inputs: Union[PerMember[TensorSource], TensorSource],
) -> Sample:
    """Create a sample from a single set of input(s) for a specific bioimage.io model

    Args:
        model: a bioimage.io model description
        stat: dictionary with sample and dataset statistics (may be updated in-place!)
        inputs: the input(s) constituting a single sample.
    """

    model_inputs = {get_member_id(d): d for d in model.inputs}
    if isinstance(inputs, collections.abc.Mapping):
        inputs = {MemberId(k): v for k, v in inputs.items()}
    elif len(model_inputs) == 1:
        inputs = {list(model_inputs)[0]: inputs}
    else:
        raise TypeError(
            f"Expected `inputs` to be a mapping with keys {tuple(model_inputs)}"
        )

    if unknown := {k for k in inputs if k not in model_inputs}:
        raise ValueError(f"Got unexpected inputs: {unknown}")

    if missing := {
        k
        for k, v in model_inputs.items()
        if k not in inputs and not (isinstance(v, v0_5.InputTensorDescr) and v.optional)
    }:
        raise ValueError(f"Missing non-optional model inputs: {missing}")

    return Sample(
        members={
            m: get_tensor(inputs[m], ipt)
            for m, ipt in model_inputs.items()
            if m in inputs
        },
        stat={} if stat is None else stat,
        id=sample_id,
    )


def load_sample_for_model(
    *,
    model: AnyModelDescr,
    paths: PerMember[Path],
    axes: Optional[PerMember[Sequence[AxisLike]]] = None,
    stat: Optional[Stat] = None,
    sample_id: Optional[SampleId] = None,
):
    """load a single sample from `paths` that can be processed by `model`"""

    if axes is None:
        axes = {}

    # make sure members are keyed by MemberId, not string
    paths = {MemberId(k): v for k, v in paths.items()}
    axes = {MemberId(k): v for k, v in axes.items()}

    model_inputs = {get_member_id(d): d for d in model.inputs}

    if unknown := {k for k in paths if k not in model_inputs}:
        raise ValueError(f"Got unexpected paths for {unknown}")

    if unknown := {k for k in axes if k not in model_inputs}:
        raise ValueError(f"Got unexpected axes hints for: {unknown}")

    members: Dict[MemberId, Tensor] = {}
    for m, p in paths.items():
        if m not in axes:
            axes[m] = get_axes_infos(model_inputs[m])
            logger.debug(
                "loading '{}' from {} with default input axes {} ",
                m,
                p,
                axes[m],
            )
        members[m] = load_tensor(p, axes[m])

    return Sample(
        members=members,
        stat={} if stat is None else stat,
        id=sample_id or tuple(sorted(paths.values())),
    )
def import_callable(
    node: Union[
        ArchitectureFromFileDescr,
        ArchitectureFromLibraryDescr,
        CallableFromDepencency,
        CallableFromFile,
    ],
    /,
    **kwargs: Unpack[HashKwargs],
) -> Callable[..., Any]:
    """import a callable (e.g. a torch.nn.Module) from a spec node describing it"""
    if isinstance(node, CallableFromDepencency):
        module = importlib.import_module(node.module_name)
        c = getattr(module, str(node.callable_name))
    elif isinstance(node, ArchitectureFromLibraryDescr):
        module = importlib.import_module(node.import_from)
        c = getattr(module, str(node.callable))
    elif isinstance(node, CallableFromFile):
        c = _import_from_file_impl(node.source_file, str(node.callable_name), **kwargs)
    elif isinstance(node, ArchitectureFromFileDescr):
        c = _import_from_file_impl(node.source, str(node.callable), sha256=node.sha256)
    else:
        assert_never(node)

    if not callable(c):
        raise ValueError(f"{node} (imported: {c}) is not callable")

    return c
import a callable (e.g. a torch.nn.Module) from a spec node describing it
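For illustration, a minimal usage sketch; the module `mylib.models` and the callable `UNet2d` are hypothetical placeholders, not part of bioimageio.core:

    # Sketch: resolve a callable from a library-based architecture description.
    from bioimageio.spec.model.v0_5 import ArchitectureFromLibraryDescr
    from bioimageio.core.digest_spec import import_callable

    arch = ArchitectureFromLibraryDescr(
        import_from="mylib.models",  # hypothetical importable module
        callable="UNet2d",           # hypothetical callable defined in that module
        kwargs={"in_channels": 1, "out_channels": 2},
    )
    network_class = import_callable(arch)  # roughly: `from mylib.models import UNet2d`
    network = network_class(**arch.kwargs)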
keep global reference to temporary directories created during import to delay cleanup
def get_axes_infos(
    io_descr: Union[
        v0_4.InputTensorDescr,
        v0_4.OutputTensorDescr,
        v0_5.InputTensorDescr,
        v0_5.OutputTensorDescr,
    ],
) -> List[AxisInfo]:
    """get a unified, simplified axis representation from spec axes"""
    ret: List[AxisInfo] = []
    for a in io_descr.axes:
        if isinstance(a, v0_5.AxisBase):
            ret.append(AxisInfo.create(Axis(id=a.id, type=a.type)))
        else:
            assert a in ("b", "i", "t", "c", "z", "y", "x")
            ret.append(AxisInfo.create(a))

    return ret
get a unified, simplified axis representation from spec axes
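A short sketch of inspecting the unified axis view, assuming `model` is a loaded model description (e.g. via `bioimageio.spec.load_model_description`):

    # Sketch: print unified axis IDs and types for each model input.
    from bioimageio.core.digest_spec import get_axes_infos

    for ipt in model.inputs:
        infos = get_axes_infos(ipt)
        # e.g. [('batch', 'batch'), ('channel', 'channel'), ('y', 'space'), ('x', 'space')]
        print([(a.id, a.type) for a in infos])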
def get_member_id(
    tensor_description: Union[
        v0_4.InputTensorDescr,
        v0_4.OutputTensorDescr,
        v0_5.InputTensorDescr,
        v0_5.OutputTensorDescr,
    ],
) -> MemberId:
    """get the normalized tensor ID, usable as a sample member ID"""

    if isinstance(tensor_description, (v0_4.InputTensorDescr, v0_4.OutputTensorDescr)):
        return MemberId(tensor_description.name)
    elif isinstance(
        tensor_description, (v0_5.InputTensorDescr, v0_5.OutputTensorDescr)
    ):
        return tensor_description.id
    else:
        assert_never(tensor_description)
get the normalized tensor ID, usable as a sample member ID
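For example (the member names below are illustrative): a v0_4 description yields the `MemberId` from the tensor `name`, a v0_5 description from the tensor `id`.

    # Sketch: collect sample member IDs for all inputs and outputs of a loaded model description.
    from bioimageio.core.digest_spec import get_member_id, get_member_ids

    first_input = get_member_id(model.inputs[0])
    input_ids = get_member_ids(model.inputs)     # e.g. [MemberId('raw')]
    output_ids = get_member_ids(model.outputs)   # e.g. [MemberId('prediction')]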
def get_member_ids(
    tensor_descriptions: Sequence[
        Union[
            v0_4.InputTensorDescr,
            v0_4.OutputTensorDescr,
            v0_5.InputTensorDescr,
            v0_5.OutputTensorDescr,
        ]
    ],
) -> List[MemberId]:
    """get normalized tensor IDs to be used as sample member IDs"""
    return [get_member_id(descr) for descr in tensor_descriptions]
get normalized tensor IDs to be used as sample member IDs
def get_test_input_sample(model: AnyModelDescr) -> Sample:
    return _get_test_sample(
        model.inputs,
        model.test_inputs if isinstance(model, v0_4.ModelDescr) else model.inputs,
    )
DEPRECATED: use get_test_input_sample instead
def get_test_output_sample(model: AnyModelDescr) -> Sample:
    """returns a model's test output sample"""
    return _get_test_sample(
        model.outputs,
        model.test_outputs if isinstance(model, v0_4.ModelDescr) else model.outputs,
    )
returns a model's test output sample
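A typical use is to reproduce a model's bundled test case; a minimal sketch (the test samples are commonly fed to a prediction pipeline and compared against the expected outputs):

    # Sketch: load the test input and the expected test output bundled with a model description.
    from bioimageio.core.digest_spec import get_test_input_sample, get_test_output_sample

    test_input = get_test_input_sample(model)
    expected_output = get_test_output_sample(model)
    print(list(test_input.members))       # member IDs of the test input sample
    print(list(expected_output.members))  # member IDs of the expected output sample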
def get_test_output_sample(model: AnyModelDescr) -> Sample:
    """returns a model's test output sample"""
    return _get_test_sample(
        model.outputs,
        model.test_outputs if isinstance(model, v0_4.ModelDescr) else model.outputs,
    )
DEPRECATED: use get_test_output_sample instead
IO_SampleBlockMeta(input, output)
def get_input_halo(model: v0_5.ModelDescr, output_halo: PerMember[PerAxis[Halo]]):
    """returns which halo input tensors need to be divided into blocks with, such that
    `output_halo` can be cropped from their outputs without introducing gaps."""
    input_halo: Dict[MemberId, Dict[AxisId, Halo]] = {}
    outputs = {t.id: t for t in model.outputs}
    all_tensors = {**{t.id: t for t in model.inputs}, **outputs}

    for t, th in output_halo.items():
        axes = {a.id: a for a in outputs[t].axes}

        for a, ah in th.items():
            s = axes[a].size
            if not isinstance(s, v0_5.SizeReference):
                raise ValueError(
                    f"Unable to map output halo for {t}.{a} to an input axis"
                )

            axis = axes[a]
            ref_axis = {a.id: a for a in all_tensors[s.tensor_id].axes}[s.axis_id]

            input_halo_left = ah.left * axis.scale / ref_axis.scale
            input_halo_right = ah.right * axis.scale / ref_axis.scale
            assert input_halo_left == int(input_halo_left), f"{input_halo_left} not int"
            assert input_halo_right == int(input_halo_right), (
                f"{input_halo_right} not int"
            )

            input_halo.setdefault(s.tensor_id, {})[a] = Halo(
                int(input_halo_left), int(input_halo_right)
            )

    return input_halo
returns the halo with which input tensors need to be divided into blocks, such that output_halo can be cropped from their outputs without introducing gaps.
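The mapping follows input_halo = output_halo * output_axis.scale / referenced_input_axis.scale (see the source above). A small worked sketch with hypothetical member and axis names:

    # Sketch: an output axis with scale 2.0 referencing an input axis with scale 1.0
    # and an output halo of 8 per side needs an input halo of 8 * 2.0 / 1.0 = 16.
    from bioimageio.core.axis import AxisId
    from bioimageio.core.common import Halo, MemberId
    from bioimageio.core.digest_spec import get_input_halo

    output_halo = {MemberId("prediction"): {AxisId("y"): Halo(8, 8), AxisId("x"): Halo(8, 8)}}
    input_halo = get_input_halo(model, output_halo)  # model: a v0_5.ModelDescr with SizeReference axes
    # e.g. {MemberId('raw'): {AxisId('y'): Halo(left=16, right=16), AxisId('x'): Halo(left=16, right=16)}}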
def get_block_transform(
    model: v0_5.ModelDescr,
) -> PerMember[PerAxis[Union[LinearSampleAxisTransform, int]]]:
    """returns how a model's output tensor shapes relate to its input shapes"""
    ret: Dict[MemberId, Dict[AxisId, Union[LinearSampleAxisTransform, int]]] = {}
    batch_axis_trf = None
    for ipt in model.inputs:
        for a in ipt.axes:
            if a.type == "batch":
                batch_axis_trf = LinearSampleAxisTransform(
                    axis=a.id, scale=1, offset=0, member=ipt.id
                )
                break
        if batch_axis_trf is not None:
            break
    axis_scales = {
        t.id: {a.id: a.scale for a in t.axes}
        for t in chain(model.inputs, model.outputs)
    }
    for out in model.outputs:
        new_axes: Dict[AxisId, Union[LinearSampleAxisTransform, int]] = {}
        for a in out.axes:
            if a.size is None:
                assert a.type == "batch"
                if batch_axis_trf is None:
                    raise ValueError(
                        "no batch axis found in any input tensor, but output tensor"
                        + f" '{out.id}' has one."
                    )
                s = batch_axis_trf
            elif isinstance(a.size, int):
                s = a.size
            elif isinstance(a.size, v0_5.DataDependentSize):
                s = -1
            elif isinstance(a.size, v0_5.SizeReference):
                s = LinearSampleAxisTransform(
                    axis=a.size.axis_id,
                    scale=axis_scales[a.size.tensor_id][a.size.axis_id] / a.scale,
                    offset=a.size.offset,
                    member=a.size.tensor_id,
                )
            else:
                assert_never(a.size)

            new_axes[a.id] = s

        ret[out.id] = new_axes

    return ret
returns how a model's output tensor shapes relate to its input shapes
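A sketch of how the returned structure can be inspected, assuming `model` is a v0_5 model description:

    # Sketch: per output member and axis, the transform is either a fixed int size,
    # -1 for data-dependent sizes, or a LinearSampleAxisTransform relating it to an input axis.
    from bioimageio.core.digest_spec import get_block_transform

    transform = get_block_transform(model)
    for out_id, per_axis in transform.items():
        for axis_id, t in per_axis.items():
            print(out_id, axis_id, t)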
def get_io_sample_block_metas(
    model: v0_5.ModelDescr,
    input_sample_shape: PerMember[PerAxis[int]],
    ns: Mapping[Tuple[MemberId, AxisId], ParameterizedSize_N],
    batch_size: int = 1,
) -> Tuple[TotalNumberOfBlocks, Iterable[IO_SampleBlockMeta]]:
    """returns an iterable yielding metadata for corresponding input and output samples"""
    if not isinstance(model, v0_5.ModelDescr):
        raise TypeError(f"get_block_meta() not implemented for {type(model)}")

    block_axis_sizes = model.get_axis_sizes(ns=ns, batch_size=batch_size)
    input_block_shape = {
        t: {aa: s for (tt, aa), s in block_axis_sizes.inputs.items() if tt == t}
        for t in {tt for tt, _ in block_axis_sizes.inputs}
    }
    output_halo = {
        t.id: {
            a.id: Halo(a.halo, a.halo) for a in t.axes if isinstance(a, v0_5.WithHalo)
        }
        for t in model.outputs
    }
    input_halo = get_input_halo(model, output_halo)

    n_input_blocks, input_blocks = split_multiple_shapes_into_blocks(
        input_sample_shape, input_block_shape, halo=input_halo
    )
    block_transform = get_block_transform(model)
    return n_input_blocks, (
        IO_SampleBlockMeta(ipt, ipt.get_transformed(block_transform))
        for ipt in sample_block_meta_generator(
            input_blocks, sample_shape=input_sample_shape, sample_id=None
        )
    )
returns an iterable yielding metadata for corresponding input and output samples
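A usage sketch with hypothetical member/axis names and parameterized-size factors `ns` (the sample shape must cover every axis of each input):

    # Sketch: enumerate matching input/output block metadata for a tiled sample.
    from bioimageio.core.axis import AxisId
    from bioimageio.core.common import MemberId
    from bioimageio.core.digest_spec import get_io_sample_block_metas

    ns = {(MemberId("raw"), AxisId("y")): 2, (MemberId("raw"), AxisId("x")): 2}
    sample_shape = {
        MemberId("raw"): {AxisId("batch"): 1, AxisId("channel"): 1, AxisId("y"): 512, AxisId("x"): 512}
    }
    n_blocks, block_metas = get_io_sample_block_metas(model, input_sample_shape=sample_shape, ns=ns)
    for io in block_metas:
        print(io.input, io.output)  # corresponding SampleBlockMeta pair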
def get_tensor(
    src: Union[ZipPath, TensorSource],
    ipt: Union[v0_4.InputTensorDescr, v0_5.InputTensorDescr],
):
    """helper to cast/load various tensor sources"""

    if isinstance(src, Tensor):
        return src
    elif isinstance(src, xr.DataArray):
        return Tensor.from_xarray(src)
    elif isinstance(src, np.ndarray):
        return Tensor.from_numpy(src, dims=get_axes_infos(ipt))
    else:
        return load_tensor(src, axes=get_axes_infos(ipt))
helper to cast/load various tensor sources
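For example (array shape and file path are hypothetical; for raw arrays and paths the axes are taken from the input description):

    # Sketch: the same helper accepts a Tensor, xarray DataArray, numpy array, or a file path.
    import numpy as np
    from pathlib import Path
    from bioimageio.core.digest_spec import get_tensor

    ipt = model.inputs[0]
    t_from_array = get_tensor(np.zeros((1, 1, 256, 256), dtype="float32"), ipt)
    t_from_file = get_tensor(Path("raw.tif"), ipt)  # loaded via load_tensor with the description's axes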
def create_sample_for_model(
    model: AnyModelDescr,
    *,
    stat: Optional[Stat] = None,
    sample_id: SampleId = None,
    inputs: Union[PerMember[TensorSource], TensorSource],
) -> Sample:
    """Create a sample from a single set of input(s) for a specific bioimage.io model

    Args:
        model: a bioimage.io model description
        stat: dictionary with sample and dataset statistics (may be updated in-place!)
        inputs: the input(s) constituting a single sample.
    """

    model_inputs = {get_member_id(d): d for d in model.inputs}
    if isinstance(inputs, collections.abc.Mapping):
        inputs = {MemberId(k): v for k, v in inputs.items()}
    elif len(model_inputs) == 1:
        inputs = {list(model_inputs)[0]: inputs}
    else:
        raise TypeError(
            f"Expected `inputs` to be a mapping with keys {tuple(model_inputs)}"
        )

    if unknown := {k for k in inputs if k not in model_inputs}:
        raise ValueError(f"Got unexpected inputs: {unknown}")

    if missing := {
        k
        for k, v in model_inputs.items()
        if k not in inputs and not (isinstance(v, v0_5.InputTensorDescr) and v.optional)
    }:
        raise ValueError(f"Missing non-optional model inputs: {missing}")

    return Sample(
        members={
            m: get_tensor(inputs[m], ipt)
            for m, ipt in model_inputs.items()
            if m in inputs
        },
        stat={} if stat is None else stat,
        id=sample_id,
    )
Create a sample from a single set of input(s) for a specific bioimage.io model
Arguments:
- model: a bioimage.io model description
- stat: dictionary with sample and dataset statistics (may be updated in-place!)
- inputs: the input(s) constituting a single sample.
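A minimal sketch (the array shape and member name are illustrative): for single-input models a bare array suffices, otherwise `inputs` must be a mapping keyed by member ID.

    # Sketch: build a Sample from raw input data for a loaded model description.
    import numpy as np
    from bioimageio.core.digest_spec import create_sample_for_model

    arr = np.zeros((1, 1, 256, 256), dtype="float32")
    sample = create_sample_for_model(model, inputs=arr, sample_id="my-sample")
    # equivalent, explicitly keyed (required for multi-input models):
    sample = create_sample_for_model(model, inputs={"raw": arr}, sample_id="my-sample")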
def load_sample_for_model(
    *,
    model: AnyModelDescr,
    paths: PerMember[Path],
    axes: Optional[PerMember[Sequence[AxisLike]]] = None,
    stat: Optional[Stat] = None,
    sample_id: Optional[SampleId] = None,
):
    """load a single sample from `paths` that can be processed by `model`"""

    if axes is None:
        axes = {}

    # make sure members are keyed by MemberId, not string
    paths = {MemberId(k): v for k, v in paths.items()}
    axes = {MemberId(k): v for k, v in axes.items()}

    model_inputs = {get_member_id(d): d for d in model.inputs}

    if unknown := {k for k in paths if k not in model_inputs}:
        raise ValueError(f"Got unexpected paths for {unknown}")

    if unknown := {k for k in axes if k not in model_inputs}:
        raise ValueError(f"Got unexpected axes hints for: {unknown}")

    members: Dict[MemberId, Tensor] = {}
    for m, p in paths.items():
        if m not in axes:
            axes[m] = get_axes_infos(model_inputs[m])
            logger.debug(
                "loading '{}' from {} with default input axes {} ",
                m,
                p,
                axes[m],
            )
        members[m] = load_tensor(p, axes[m])

    return Sample(
        members=members,
        stat={} if stat is None else stat,
        id=sample_id or tuple(sorted(paths.values())),
    )
load a single sample from paths that can be processed by model
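A usage sketch (file paths and member name are hypothetical; `axes` may be omitted to fall back to the axes declared in the model description):

    # Sketch: load one sample member per image file.
    from pathlib import Path
    from bioimageio.core.digest_spec import load_sample_for_model

    sample = load_sample_for_model(
        model=model,
        paths={"raw": Path("image0.tif")},
        axes={"raw": ["y", "x"]},  # optional axis hint per member
        sample_id="image0",
    )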