bioimageio.core.digest_spec
from __future__ import annotations

import collections.abc
import importlib.util
import sys
from itertools import chain
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import (
    Any,
    Callable,
    Dict,
    Iterable,
    List,
    Mapping,
    NamedTuple,
    Optional,
    Sequence,
    Tuple,
    Union,
)
from zipfile import ZipFile, is_zipfile

import numpy as np
import xarray as xr
from loguru import logger
from numpy.typing import NDArray
from typing_extensions import Unpack, assert_never

from bioimageio.spec._internal.io import HashKwargs
from bioimageio.spec.common import FileDescr, FileSource, ZipPath
from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5
from bioimageio.spec.model.v0_4 import CallableFromDepencency, CallableFromFile
from bioimageio.spec.model.v0_5 import (
    ArchitectureFromFileDescr,
    ArchitectureFromLibraryDescr,
    ParameterizedSize_N,
)
from bioimageio.spec.utils import load_array

from .axis import Axis, AxisId, AxisInfo, AxisLike, PerAxis
from .block_meta import split_multiple_shapes_into_blocks
from .common import Halo, MemberId, PerMember, SampleId, TotalNumberOfBlocks
from .io import load_tensor
from .sample import (
    LinearSampleAxisTransform,
    Sample,
    SampleBlockMeta,
    sample_block_meta_generator,
)
from .stat_measures import Stat
from .tensor import Tensor

TensorSource = Union[Tensor, xr.DataArray, NDArray[Any], Path]


def import_callable(
    node: Union[
        ArchitectureFromFileDescr,
        ArchitectureFromLibraryDescr,
        CallableFromDepencency,
        CallableFromFile,
    ],
    /,
    **kwargs: Unpack[HashKwargs],
) -> Callable[..., Any]:
    """import a callable (e.g. a torch.nn.Module) from a spec node describing it"""
    if isinstance(node, CallableFromDepencency):
        module = importlib.import_module(node.module_name)
        c = getattr(module, str(node.callable_name))
    elif isinstance(node, ArchitectureFromLibraryDescr):
        module = importlib.import_module(node.import_from)
        c = getattr(module, str(node.callable))
    elif isinstance(node, CallableFromFile):
        c = _import_from_file_impl(node.source_file, str(node.callable_name), **kwargs)
    elif isinstance(node, ArchitectureFromFileDescr):
        c = _import_from_file_impl(node.source, str(node.callable), sha256=node.sha256)
    else:
        assert_never(node)

    if not callable(c):
        raise ValueError(f"{node} (imported: {c}) is not callable")

    return c


tmp_dirs_in_use: List[TemporaryDirectory[str]] = []
"""keep global reference to temporary directories created during import to delay cleanup"""


def _import_from_file_impl(
    source: FileSource, callable_name: str, **kwargs: Unpack[HashKwargs]
):
    src_descr = FileDescr(source=source, **kwargs)
    # ensure sha is valid even if perform_io_checks=False
    # or the source has changed since last sha computation
    src_descr.validate_sha256(force_recompute=True)
    assert src_descr.sha256 is not None
    source_sha = src_descr.sha256

    reader = src_descr.get_reader()
    # make sure we have unique module name
    module_name = f"{reader.original_file_name.split('.')[0]}_{source_sha}"

    # make sure we have a unique and valid module name
    if not module_name.isidentifier():
        module_name = f"custom_module_{source_sha}"
        assert module_name.isidentifier(), module_name

    source_bytes = reader.read()

    module = sys.modules.get(module_name)
    if module is None:
        try:
            td_kwargs: Dict[str, Any] = (
                dict(ignore_cleanup_errors=True) if sys.version_info >= (3, 10) else {}
            )
            if sys.version_info >= (3, 12):
                td_kwargs["delete"] = False

            tmp_dir = TemporaryDirectory(**td_kwargs)
            # keep global ref to tmp_dir to delay cleanup until program exit
            # TODO: remove for py >= 3.12, when delete=False works
            tmp_dirs_in_use.append(tmp_dir)

            module_path = Path(tmp_dir.name) / module_name
            if reader.original_file_name.endswith(".zip") or is_zipfile(reader):
                module_path.mkdir()
                ZipFile(reader).extractall(path=module_path)
            else:
                module_path = module_path.with_suffix(".py")
                _ = module_path.write_bytes(source_bytes)

            importlib_spec = importlib.util.spec_from_file_location(
                module_name, str(module_path)
            )

            if importlib_spec is None:
                raise ImportError(f"Failed to import {source}")

            module = importlib.util.module_from_spec(importlib_spec)

            sys.modules[module_name] = module  # cache this module

            assert importlib_spec.loader is not None
            importlib_spec.loader.exec_module(module)

        except Exception as e:
            del sys.modules[module_name]
            raise ImportError(f"Failed to import {source}") from e

    try:
        callable_attr = getattr(module, callable_name)
    except AttributeError as e:
        raise AttributeError(
            f"Imported custom module from {source} has no `{callable_name}` attribute."
        ) from e
    except Exception as e:
        raise AttributeError(
            f"Failed to access `{callable_name}` attribute from custom module imported from {source} ."
        ) from e

    else:
        return callable_attr


def get_axes_infos(
    io_descr: Union[
        v0_4.InputTensorDescr,
        v0_4.OutputTensorDescr,
        v0_5.InputTensorDescr,
        v0_5.OutputTensorDescr,
    ],
) -> List[AxisInfo]:
    """get a unified, simplified axis representation from spec axes"""
    ret: List[AxisInfo] = []
    for a in io_descr.axes:
        if isinstance(a, v0_5.AxisBase):
            ret.append(AxisInfo.create(Axis(id=a.id, type=a.type)))
        else:
            assert a in ("b", "i", "t", "c", "z", "y", "x")
            ret.append(AxisInfo.create(a))

    return ret


def get_member_id(
    tensor_description: Union[
        v0_4.InputTensorDescr,
        v0_4.OutputTensorDescr,
        v0_5.InputTensorDescr,
        v0_5.OutputTensorDescr,
    ],
) -> MemberId:
    """get the normalized tensor ID, usable as a sample member ID"""

    if isinstance(tensor_description, (v0_4.InputTensorDescr, v0_4.OutputTensorDescr)):
        return MemberId(tensor_description.name)
    elif isinstance(
        tensor_description, (v0_5.InputTensorDescr, v0_5.OutputTensorDescr)
    ):
        return tensor_description.id
    else:
        assert_never(tensor_description)


def get_member_ids(
    tensor_descriptions: Sequence[
        Union[
            v0_4.InputTensorDescr,
            v0_4.OutputTensorDescr,
            v0_5.InputTensorDescr,
            v0_5.OutputTensorDescr,
        ]
    ],
) -> List[MemberId]:
    """get normalized tensor IDs to be used as sample member IDs"""
    return [get_member_id(descr) for descr in tensor_descriptions]


def get_test_input_sample(model: AnyModelDescr) -> Sample:
    return _get_test_sample(
        model.inputs,
        model.test_inputs if isinstance(model, v0_4.ModelDescr) else model.inputs,
    )


get_test_inputs = get_test_input_sample
"""DEPRECATED: use `get_test_input_sample` instead"""


def get_test_output_sample(model: AnyModelDescr) -> Sample:
    """returns a model's test output sample"""
    return _get_test_sample(
        model.outputs,
        model.test_outputs if isinstance(model, v0_4.ModelDescr) else model.outputs,
    )


get_test_outputs = get_test_output_sample
"""DEPRECATED: use `get_test_input_sample` instead"""


def _get_test_sample(
    tensor_descrs: Sequence[
        Union[
            v0_4.InputTensorDescr,
            v0_4.OutputTensorDescr,
            v0_5.InputTensorDescr,
            v0_5.OutputTensorDescr,
        ]
    ],
    test_sources: Sequence[Union[FileSource, v0_5.TensorDescr]],
) -> Sample:
    """returns a model's input/output test sample"""
    member_ids = get_member_ids(tensor_descrs)
    arrays: List[NDArray[Any]] = []
    for src in test_sources:
        if isinstance(src, (v0_5.InputTensorDescr, v0_5.OutputTensorDescr)):
            if src.test_tensor is None:
                raise ValueError(
                    f"Model input '{src.id}' has no test tensor defined, cannot create test sample."
                )
            arrays.append(load_array(src.test_tensor))
        else:
            arrays.append(load_array(src))

    axes = [get_axes_infos(t) for t in tensor_descrs]
    return Sample(
        members={
            m: Tensor.from_numpy(arr, dims=ax)
            for m, arr, ax in zip(member_ids, arrays, axes)
        },
        stat={},
        id="test-sample",
    )


class IO_SampleBlockMeta(NamedTuple):
    input: SampleBlockMeta
    output: SampleBlockMeta


def get_input_halo(model: v0_5.ModelDescr, output_halo: PerMember[PerAxis[Halo]]):
    """returns which halo input tensors need to be divided into blocks with, such that
    `output_halo` can be cropped from their outputs without introducing gaps."""
    input_halo: Dict[MemberId, Dict[AxisId, Halo]] = {}
    outputs = {t.id: t for t in model.outputs}
    all_tensors = {**{t.id: t for t in model.inputs}, **outputs}

    for t, th in output_halo.items():
        axes = {a.id: a for a in outputs[t].axes}

        for a, ah in th.items():
            s = axes[a].size
            if not isinstance(s, v0_5.SizeReference):
                raise ValueError(
                    f"Unable to map output halo for {t}.{a} to an input axis"
                )

            axis = axes[a]
            ref_axis = {a.id: a for a in all_tensors[s.tensor_id].axes}[s.axis_id]

            total_output_halo = sum(ah)
            total_input_halo = total_output_halo * axis.scale / ref_axis.scale
            assert (
                total_input_halo == int(total_input_halo) and total_input_halo % 2 == 0
            )
            input_halo.setdefault(s.tensor_id, {})[a] = Halo(
                int(total_input_halo // 2), int(total_input_halo // 2)
            )

    return input_halo


def get_block_transform(
    model: v0_5.ModelDescr,
) -> PerMember[PerAxis[Union[LinearSampleAxisTransform, int]]]:
    """returns how a model's output tensor shapes relates to its input shapes"""
    ret: Dict[MemberId, Dict[AxisId, Union[LinearSampleAxisTransform, int]]] = {}
    batch_axis_trf = None
    for ipt in model.inputs:
        for a in ipt.axes:
            if a.type == "batch":
                batch_axis_trf = LinearSampleAxisTransform(
                    axis=a.id, scale=1, offset=0, member=ipt.id
                )
                break
        if batch_axis_trf is not None:
            break
    axis_scales = {
        t.id: {a.id: a.scale for a in t.axes}
        for t in chain(model.inputs, model.outputs)
    }
    for out in model.outputs:
        new_axes: Dict[AxisId, Union[LinearSampleAxisTransform, int]] = {}
        for a in out.axes:
            if a.size is None:
                assert a.type == "batch"
                if batch_axis_trf is None:
                    raise ValueError(
                        "no batch axis found in any input tensor, but output tensor"
                        + f" '{out.id}' has one."
                    )
                s = batch_axis_trf
            elif isinstance(a.size, int):
                s = a.size
            elif isinstance(a.size, v0_5.DataDependentSize):
                s = -1
            elif isinstance(a.size, v0_5.SizeReference):
                s = LinearSampleAxisTransform(
                    axis=a.size.axis_id,
                    scale=axis_scales[a.size.tensor_id][a.size.axis_id] / a.scale,
                    offset=a.size.offset,
                    member=a.size.tensor_id,
                )
            else:
                assert_never(a.size)

            new_axes[a.id] = s

        ret[out.id] = new_axes

    return ret


def get_io_sample_block_metas(
    model: v0_5.ModelDescr,
    input_sample_shape: PerMember[PerAxis[int]],
    ns: Mapping[Tuple[MemberId, AxisId], ParameterizedSize_N],
    batch_size: int = 1,
) -> Tuple[TotalNumberOfBlocks, Iterable[IO_SampleBlockMeta]]:
    """returns an iterable yielding meta data for corresponding input and output samples"""
    if not isinstance(model, v0_5.ModelDescr):
        raise TypeError(f"get_block_meta() not implemented for {type(model)}")

    block_axis_sizes = model.get_axis_sizes(ns=ns, batch_size=batch_size)
    input_block_shape = {
        t: {aa: s for (tt, aa), s in block_axis_sizes.inputs.items() if tt == t}
        for t in {tt for tt, _ in block_axis_sizes.inputs}
    }
    output_halo = {
        t.id: {
            a.id: Halo(a.halo, a.halo) for a in t.axes if isinstance(a, v0_5.WithHalo)
        }
        for t in model.outputs
    }
    input_halo = get_input_halo(model, output_halo)

    n_input_blocks, input_blocks = split_multiple_shapes_into_blocks(
        input_sample_shape, input_block_shape, halo=input_halo
    )
    block_transform = get_block_transform(model)
    return n_input_blocks, (
        IO_SampleBlockMeta(ipt, ipt.get_transformed(block_transform))
        for ipt in sample_block_meta_generator(
            input_blocks, sample_shape=input_sample_shape, sample_id=None
        )
    )


def get_tensor(
    src: Union[ZipPath, TensorSource],
    ipt: Union[v0_4.InputTensorDescr, v0_5.InputTensorDescr],
):
    """helper to cast/load various tensor sources"""

    if isinstance(src, Tensor):
        return src
    elif isinstance(src, xr.DataArray):
        return Tensor.from_xarray(src)
    elif isinstance(src, np.ndarray):
        return Tensor.from_numpy(src, dims=get_axes_infos(ipt))
    else:
        return load_tensor(src, axes=get_axes_infos(ipt))


def create_sample_for_model(
    model: AnyModelDescr,
    *,
    stat: Optional[Stat] = None,
    sample_id: SampleId = None,
    inputs: Union[PerMember[TensorSource], TensorSource],
) -> Sample:
    """Create a sample from a single set of input(s) for a specific bioimage.io model

    Args:
        model: a bioimage.io model description
        stat: dictionary with sample and dataset statistics (may be updated in-place!)
        inputs: the input(s) constituting a single sample.
    """

    model_inputs = {get_member_id(d): d for d in model.inputs}
    if isinstance(inputs, collections.abc.Mapping):
        inputs = {MemberId(k): v for k, v in inputs.items()}
    elif len(model_inputs) == 1:
        inputs = {list(model_inputs)[0]: inputs}
    else:
        raise TypeError(
            f"Expected `inputs` to be a mapping with keys {tuple(model_inputs)}"
        )

    if unknown := {k for k in inputs if k not in model_inputs}:
        raise ValueError(f"Got unexpected inputs: {unknown}")

    if missing := {
        k
        for k, v in model_inputs.items()
        if k not in inputs and not (isinstance(v, v0_5.InputTensorDescr) and v.optional)
    }:
        raise ValueError(f"Missing non-optional model inputs: {missing}")

    return Sample(
        members={
            m: get_tensor(inputs[m], ipt)
            for m, ipt in model_inputs.items()
            if m in inputs
        },
        stat={} if stat is None else stat,
        id=sample_id,
    )


def load_sample_for_model(
    *,
    model: AnyModelDescr,
    paths: PerMember[Path],
    axes: Optional[PerMember[Sequence[AxisLike]]] = None,
    stat: Optional[Stat] = None,
    sample_id: Optional[SampleId] = None,
):
    """load a single sample from `paths` that can be processed by `model`"""

    if axes is None:
        axes = {}

    # make sure members are keyed by MemberId, not string
    paths = {MemberId(k): v for k, v in paths.items()}
    axes = {MemberId(k): v for k, v in axes.items()}

    model_inputs = {get_member_id(d): d for d in model.inputs}

    if unknown := {k for k in paths if k not in model_inputs}:
        raise ValueError(f"Got unexpected paths for {unknown}")

    if unknown := {k for k in axes if k not in model_inputs}:
        raise ValueError(f"Got unexpected axes hints for: {unknown}")

    members: Dict[MemberId, Tensor] = {}
    for m, p in paths.items():
        if m not in axes:
            axes[m] = get_axes_infos(model_inputs[m])
            logger.debug(
                "loading '{}' from {} with default input axes {} ",
                m,
                p,
                axes[m],
            )
        members[m] = load_tensor(p, axes[m])

    return Sample(
        members=members,
        stat={} if stat is None else stat,
        id=sample_id or tuple(sorted(paths.values())),
    )
def import_callable(
    node: Union[
        ArchitectureFromFileDescr,
        ArchitectureFromLibraryDescr,
        CallableFromDepencency,
        CallableFromFile,
    ],
    /,
    **kwargs: Unpack[HashKwargs],
) -> Callable[..., Any]:
    """import a callable (e.g. a torch.nn.Module) from a spec node describing it"""
    if isinstance(node, CallableFromDepencency):
        module = importlib.import_module(node.module_name)
        c = getattr(module, str(node.callable_name))
    elif isinstance(node, ArchitectureFromLibraryDescr):
        module = importlib.import_module(node.import_from)
        c = getattr(module, str(node.callable))
    elif isinstance(node, CallableFromFile):
        c = _import_from_file_impl(node.source_file, str(node.callable_name), **kwargs)
    elif isinstance(node, ArchitectureFromFileDescr):
        c = _import_from_file_impl(node.source, str(node.callable), sha256=node.sha256)
    else:
        assert_never(node)

    if not callable(c):
        raise ValueError(f"{node} (imported: {c}) is not callable")

    return c
import a callable (e.g. a torch.nn.Module) from a spec node describing it
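A minimal usage sketch, not part of the module itself: it assumes a locally available model description whose pytorch_state_dict weights carry an architecture node; the path is a placeholder.

from bioimageio.core import load_description
from bioimageio.core.digest_spec import import_callable

model = load_description("path/to/model/rdf.yaml")  # placeholder model source
# assumption: the model ships pytorch_state_dict weights with an architecture description
arch = model.weights.pytorch_state_dict.architecture
net_factory = import_callable(arch)  # e.g. a torch.nn.Module subclass or factory function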
keep global reference to temporary directories created during import to delay cleanup
def get_axes_infos(
    io_descr: Union[
        v0_4.InputTensorDescr,
        v0_4.OutputTensorDescr,
        v0_5.InputTensorDescr,
        v0_5.OutputTensorDescr,
    ],
) -> List[AxisInfo]:
    """get a unified, simplified axis representation from spec axes"""
    ret: List[AxisInfo] = []
    for a in io_descr.axes:
        if isinstance(a, v0_5.AxisBase):
            ret.append(AxisInfo.create(Axis(id=a.id, type=a.type)))
        else:
            assert a in ("b", "i", "t", "c", "z", "y", "x")
            ret.append(AxisInfo.create(a))

    return ret
get a unified, simplified axis representation from spec axes
def get_member_id(
    tensor_description: Union[
        v0_4.InputTensorDescr,
        v0_4.OutputTensorDescr,
        v0_5.InputTensorDescr,
        v0_5.OutputTensorDescr,
    ],
) -> MemberId:
    """get the normalized tensor ID, usable as a sample member ID"""

    if isinstance(tensor_description, (v0_4.InputTensorDescr, v0_4.OutputTensorDescr)):
        return MemberId(tensor_description.name)
    elif isinstance(
        tensor_description, (v0_5.InputTensorDescr, v0_5.OutputTensorDescr)
    ):
        return tensor_description.id
    else:
        assert_never(tensor_description)
get the normalized tensor ID, usable as a sample member ID
def get_member_ids(
    tensor_descriptions: Sequence[
        Union[
            v0_4.InputTensorDescr,
            v0_4.OutputTensorDescr,
            v0_5.InputTensorDescr,
            v0_5.OutputTensorDescr,
        ]
    ],
) -> List[MemberId]:
    """get normalized tensor IDs to be used as sample member IDs"""
    return [get_member_id(descr) for descr in tensor_descriptions]
get normalized tensor IDs to be used as sample member IDs
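A short sketch tying the helpers above together; the model path is a placeholder and the model description is assumed to be loadable.

from bioimageio.core import load_description
from bioimageio.core.digest_spec import get_axes_infos, get_member_id, get_member_ids

model = load_description("path/to/model/rdf.yaml")  # placeholder model source
print(get_member_ids(model.inputs))                 # normalized ids of all inputs
first = model.inputs[0]
print(get_member_id(first), get_axes_infos(first))  # id and unified axis infos of one input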
def get_test_input_sample(model: AnyModelDescr) -> Sample:
    return _get_test_sample(
        model.inputs,
        model.test_inputs if isinstance(model, v0_4.ModelDescr) else model.inputs,
    )
DEPRECATED: use `get_test_input_sample` instead
def get_test_output_sample(model: AnyModelDescr) -> Sample:
    """returns a model's test output sample"""
    return _get_test_sample(
        model.outputs,
        model.test_outputs if isinstance(model, v0_4.ModelDescr) else model.outputs,
    )
returns a model's test output sample
def get_test_output_sample(model: AnyModelDescr) -> Sample:
    """returns a model's test output sample"""
    return _get_test_sample(
        model.outputs,
        model.test_outputs if isinstance(model, v0_4.ModelDescr) else model.outputs,
    )
DEPRECATED: use `get_test_output_sample` instead
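A hedged example of fetching the test data packaged with a model (placeholder path; works for v0.4 and v0.5 model descriptions alike).

from bioimageio.core import load_description
from bioimageio.core.digest_spec import get_test_input_sample, get_test_output_sample

model = load_description("path/to/model/rdf.yaml")  # placeholder model source
input_sample = get_test_input_sample(model)      # Sample keyed by normalized member ids
expected_sample = get_test_output_sample(model)  # the expected test outputs
print(list(input_sample.members), list(expected_sample.members))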
IO_SampleBlockMeta(input, output)
def get_input_halo(model: v0_5.ModelDescr, output_halo: PerMember[PerAxis[Halo]]):
    """returns which halo input tensors need to be divided into blocks with, such that
    `output_halo` can be cropped from their outputs without introducing gaps."""
    input_halo: Dict[MemberId, Dict[AxisId, Halo]] = {}
    outputs = {t.id: t for t in model.outputs}
    all_tensors = {**{t.id: t for t in model.inputs}, **outputs}

    for t, th in output_halo.items():
        axes = {a.id: a for a in outputs[t].axes}

        for a, ah in th.items():
            s = axes[a].size
            if not isinstance(s, v0_5.SizeReference):
                raise ValueError(
                    f"Unable to map output halo for {t}.{a} to an input axis"
                )

            axis = axes[a]
            ref_axis = {a.id: a for a in all_tensors[s.tensor_id].axes}[s.axis_id]

            total_output_halo = sum(ah)
            total_input_halo = total_output_halo * axis.scale / ref_axis.scale
            assert (
                total_input_halo == int(total_input_halo) and total_input_halo % 2 == 0
            )
            input_halo.setdefault(s.tensor_id, {})[a] = Halo(
                int(total_input_halo // 2), int(total_input_halo // 2)
            )

    return input_halo
returns the halo with which input tensors need to be divided into blocks, such that `output_halo` can be cropped from the corresponding outputs without introducing gaps.
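A sketch of deriving the input halo from the halos declared on a v0.5 model's outputs, mirroring how get_io_sample_block_metas builds its output_halo argument; `model` is assumed to be a loaded v0_5.ModelDescr.

from bioimageio.core.common import Halo
from bioimageio.core.digest_spec import get_input_halo
from bioimageio.spec.model import v0_5

output_halo = {
    t.id: {a.id: Halo(a.halo, a.halo) for a in t.axes if isinstance(a, v0_5.WithHalo)}
    for t in model.outputs  # `model`: a loaded v0_5.ModelDescr
}
input_halo = get_input_halo(model, output_halo)  # per input member, per axis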
def get_block_transform(
    model: v0_5.ModelDescr,
) -> PerMember[PerAxis[Union[LinearSampleAxisTransform, int]]]:
    """returns how a model's output tensor shapes relates to its input shapes"""
    ret: Dict[MemberId, Dict[AxisId, Union[LinearSampleAxisTransform, int]]] = {}
    batch_axis_trf = None
    for ipt in model.inputs:
        for a in ipt.axes:
            if a.type == "batch":
                batch_axis_trf = LinearSampleAxisTransform(
                    axis=a.id, scale=1, offset=0, member=ipt.id
                )
                break
        if batch_axis_trf is not None:
            break
    axis_scales = {
        t.id: {a.id: a.scale for a in t.axes}
        for t in chain(model.inputs, model.outputs)
    }
    for out in model.outputs:
        new_axes: Dict[AxisId, Union[LinearSampleAxisTransform, int]] = {}
        for a in out.axes:
            if a.size is None:
                assert a.type == "batch"
                if batch_axis_trf is None:
                    raise ValueError(
                        "no batch axis found in any input tensor, but output tensor"
                        + f" '{out.id}' has one."
                    )
                s = batch_axis_trf
            elif isinstance(a.size, int):
                s = a.size
            elif isinstance(a.size, v0_5.DataDependentSize):
                s = -1
            elif isinstance(a.size, v0_5.SizeReference):
                s = LinearSampleAxisTransform(
                    axis=a.size.axis_id,
                    scale=axis_scales[a.size.tensor_id][a.size.axis_id] / a.scale,
                    offset=a.size.offset,
                    member=a.size.tensor_id,
                )
            else:
                assert_never(a.size)

            new_axes[a.id] = s

        ret[out.id] = new_axes

    return ret
returns how a model's output tensor shapes relate to its input shapes
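A small sketch of inspecting the transform, again assuming `model` holds a loaded v0.5 model description.

from bioimageio.core.digest_spec import get_block_transform

block_transform = get_block_transform(model)  # `model`: a loaded v0_5.ModelDescr
for out_member, per_axis in block_transform.items():
    # each value is a fixed int size, -1 for data-dependent sizes,
    # or a LinearSampleAxisTransform pointing at an input axis
    print(out_member, per_axis)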
def get_io_sample_block_metas(
    model: v0_5.ModelDescr,
    input_sample_shape: PerMember[PerAxis[int]],
    ns: Mapping[Tuple[MemberId, AxisId], ParameterizedSize_N],
    batch_size: int = 1,
) -> Tuple[TotalNumberOfBlocks, Iterable[IO_SampleBlockMeta]]:
    """returns an iterable yielding meta data for corresponding input and output samples"""
    if not isinstance(model, v0_5.ModelDescr):
        raise TypeError(f"get_block_meta() not implemented for {type(model)}")

    block_axis_sizes = model.get_axis_sizes(ns=ns, batch_size=batch_size)
    input_block_shape = {
        t: {aa: s for (tt, aa), s in block_axis_sizes.inputs.items() if tt == t}
        for t in {tt for tt, _ in block_axis_sizes.inputs}
    }
    output_halo = {
        t.id: {
            a.id: Halo(a.halo, a.halo) for a in t.axes if isinstance(a, v0_5.WithHalo)
        }
        for t in model.outputs
    }
    input_halo = get_input_halo(model, output_halo)

    n_input_blocks, input_blocks = split_multiple_shapes_into_blocks(
        input_sample_shape, input_block_shape, halo=input_halo
    )
    block_transform = get_block_transform(model)
    return n_input_blocks, (
        IO_SampleBlockMeta(ipt, ipt.get_transformed(block_transform))
        for ipt in sample_block_meta_generator(
            input_blocks, sample_shape=input_sample_shape, sample_id=None
        )
    )
returns the total number of blocks and an iterable yielding metadata for corresponding input and output sample blocks
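A hedged sketch; the member and axis ids, the sample shape, and the block counts in `ns` are made up and must match the model at hand.

from bioimageio.core.axis import AxisId
from bioimageio.core.common import MemberId
from bioimageio.core.digest_spec import get_io_sample_block_metas

member = MemberId("input0")  # hypothetical input id
n_blocks, block_metas = get_io_sample_block_metas(
    model,  # a loaded v0_5.ModelDescr
    input_sample_shape={member: {AxisId("batch"): 1, AxisId("y"): 1024, AxisId("x"): 1024}},
    ns={(member, AxisId("y")): 10, (member, AxisId("x")): 10},  # n per parameterized axis
)
for io_meta in block_metas:
    print(io_meta.input, io_meta.output)  # corresponding input/output block metadata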
def get_tensor(
    src: Union[ZipPath, TensorSource],
    ipt: Union[v0_4.InputTensorDescr, v0_5.InputTensorDescr],
):
    """helper to cast/load various tensor sources"""

    if isinstance(src, Tensor):
        return src
    elif isinstance(src, xr.DataArray):
        return Tensor.from_xarray(src)
    elif isinstance(src, np.ndarray):
        return Tensor.from_numpy(src, dims=get_axes_infos(ipt))
    else:
        return load_tensor(src, axes=get_axes_infos(ipt))
helper to cast/load various tensor sources
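For illustration only (the array shape is arbitrary and `model` is assumed to be a loaded model description):

import numpy as np
from bioimageio.core.digest_spec import get_tensor

ipt = model.inputs[0]                               # first input descriptor
arr = np.zeros((1, 1, 128, 128), dtype="float32")   # placeholder data
tensor = get_tensor(arr, ipt)  # numpy input is wrapped using the descriptor's axes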
def create_sample_for_model(
    model: AnyModelDescr,
    *,
    stat: Optional[Stat] = None,
    sample_id: SampleId = None,
    inputs: Union[PerMember[TensorSource], TensorSource],
) -> Sample:
    """Create a sample from a single set of input(s) for a specific bioimage.io model

    Args:
        model: a bioimage.io model description
        stat: dictionary with sample and dataset statistics (may be updated in-place!)
        inputs: the input(s) constituting a single sample.
    """

    model_inputs = {get_member_id(d): d for d in model.inputs}
    if isinstance(inputs, collections.abc.Mapping):
        inputs = {MemberId(k): v for k, v in inputs.items()}
    elif len(model_inputs) == 1:
        inputs = {list(model_inputs)[0]: inputs}
    else:
        raise TypeError(
            f"Expected `inputs` to be a mapping with keys {tuple(model_inputs)}"
        )

    if unknown := {k for k in inputs if k not in model_inputs}:
        raise ValueError(f"Got unexpected inputs: {unknown}")

    if missing := {
        k
        for k, v in model_inputs.items()
        if k not in inputs and not (isinstance(v, v0_5.InputTensorDescr) and v.optional)
    }:
        raise ValueError(f"Missing non-optional model inputs: {missing}")

    return Sample(
        members={
            m: get_tensor(inputs[m], ipt)
            for m, ipt in model_inputs.items()
            if m in inputs
        },
        stat={} if stat is None else stat,
        id=sample_id,
    )
Create a sample from a single set of input(s) for a specific bioimage.io model
Arguments:
- model: a bioimage.io model description
- stat: dictionary with sample and dataset statistics (may be updated in-place!)
- inputs: the input(s) constituting a single sample.
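A minimal usage sketch, assuming a single-input model loaded from a placeholder path; per the code above, a bare array is then mapped to the only input member.

import numpy as np
from bioimageio.core import load_description
from bioimageio.core.digest_spec import create_sample_for_model

model = load_description("path/to/model/rdf.yaml")  # placeholder model source
data = np.random.rand(1, 1, 256, 256).astype("float32")  # shape chosen for illustration
sample = create_sample_for_model(model, inputs=data, sample_id="demo-sample")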
def load_sample_for_model(
    *,
    model: AnyModelDescr,
    paths: PerMember[Path],
    axes: Optional[PerMember[Sequence[AxisLike]]] = None,
    stat: Optional[Stat] = None,
    sample_id: Optional[SampleId] = None,
):
    """load a single sample from `paths` that can be processed by `model`"""

    if axes is None:
        axes = {}

    # make sure members are keyed by MemberId, not string
    paths = {MemberId(k): v for k, v in paths.items()}
    axes = {MemberId(k): v for k, v in axes.items()}

    model_inputs = {get_member_id(d): d for d in model.inputs}

    if unknown := {k for k in paths if k not in model_inputs}:
        raise ValueError(f"Got unexpected paths for {unknown}")

    if unknown := {k for k in axes if k not in model_inputs}:
        raise ValueError(f"Got unexpected axes hints for: {unknown}")

    members: Dict[MemberId, Tensor] = {}
    for m, p in paths.items():
        if m not in axes:
            axes[m] = get_axes_infos(model_inputs[m])
            logger.debug(
                "loading '{}' from {} with default input axes {} ",
                m,
                p,
                axes[m],
            )
        members[m] = load_tensor(p, axes[m])

    return Sample(
        members=members,
        stat={} if stat is None else stat,
        id=sample_id or tuple(sorted(paths.values())),
    )
load a single sample from `paths` that can be processed by `model`
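A hedged usage sketch with placeholder paths and a hypothetical member id "input0"; the member ids must match the model's inputs, and the axes hint is optional.

from pathlib import Path
from bioimageio.core import load_description
from bioimageio.core.common import MemberId
from bioimageio.core.digest_spec import load_sample_for_model

model = load_description("path/to/model/rdf.yaml")  # placeholder model source
sample = load_sample_for_model(
    model=model,
    paths={MemberId("input0"): Path("image.tif")},  # keys must match the model inputs
    axes={MemberId("input0"): ["y", "x"]},          # optional hint; defaults to the spec axes
    sample_id="my-image",
)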