bioimageio.core.digest_spec
1from __future__ import annotations 2 3import collections.abc 4import hashlib 5import importlib.util 6import sys 7from itertools import chain 8from pathlib import Path 9from typing import ( 10 Any, 11 Callable, 12 Dict, 13 Iterable, 14 List, 15 Mapping, 16 NamedTuple, 17 Optional, 18 Sequence, 19 Tuple, 20 Union, 21) 22 23import numpy as np 24import xarray as xr 25from loguru import logger 26from numpy.typing import NDArray 27from typing_extensions import Unpack, assert_never 28 29from bioimageio.spec._internal.io import HashKwargs 30from bioimageio.spec.common import FileDescr, FileSource, ZipPath 31from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5 32from bioimageio.spec.model.v0_4 import CallableFromDepencency, CallableFromFile 33from bioimageio.spec.model.v0_5 import ( 34 ArchitectureFromFileDescr, 35 ArchitectureFromLibraryDescr, 36 ParameterizedSize_N, 37) 38from bioimageio.spec.utils import download, load_array 39 40from ._settings import settings 41from .axis import Axis, AxisId, AxisInfo, AxisLike, PerAxis 42from .block_meta import split_multiple_shapes_into_blocks 43from .common import Halo, MemberId, PerMember, SampleId, TotalNumberOfBlocks 44from .io import load_tensor 45from .sample import ( 46 LinearSampleAxisTransform, 47 Sample, 48 SampleBlockMeta, 49 sample_block_meta_generator, 50) 51from .stat_measures import Stat 52from .tensor import Tensor 53 54TensorSource = Union[Tensor, xr.DataArray, NDArray[Any], Path] 55 56 57def import_callable( 58 node: Union[ 59 ArchitectureFromFileDescr, 60 ArchitectureFromLibraryDescr, 61 CallableFromDepencency, 62 CallableFromFile, 63 ], 64 /, 65 **kwargs: Unpack[HashKwargs], 66) -> Callable[..., Any]: 67 """import a callable (e.g. 
a torch.nn.Module) from a spec node describing it""" 68 if isinstance(node, CallableFromDepencency): 69 module = importlib.import_module(node.module_name) 70 c = getattr(module, str(node.callable_name)) 71 elif isinstance(node, ArchitectureFromLibraryDescr): 72 module = importlib.import_module(node.import_from) 73 c = getattr(module, str(node.callable)) 74 elif isinstance(node, CallableFromFile): 75 c = _import_from_file_impl(node.source_file, str(node.callable_name), **kwargs) 76 elif isinstance(node, ArchitectureFromFileDescr): 77 c = _import_from_file_impl(node.source, str(node.callable), sha256=node.sha256) 78 else: 79 assert_never(node) 80 81 if not callable(c): 82 raise ValueError(f"{node} (imported: {c}) is not callable") 83 84 return c 85 86 87def _import_from_file_impl( 88 source: FileSource, callable_name: str, **kwargs: Unpack[HashKwargs] 89): 90 src_descr = FileDescr(source=source, **kwargs) 91 # ensure sha is valid even if perform_io_checks=False 92 src_descr.validate_sha256() 93 assert src_descr.sha256 is not None 94 95 local_source = src_descr.download() 96 97 source_bytes = local_source.path.read_bytes() 98 assert isinstance(source_bytes, bytes) 99 source_sha = hashlib.sha256(source_bytes).hexdigest() 100 101 # make sure we have unique module name 102 module_name = f"{local_source.path.stem}_{source_sha}" 103 104 # make sure we have a valid module name 105 if not module_name.isidentifier(): 106 module_name = f"custom_module_{source_sha}" 107 assert module_name.isidentifier(), module_name 108 109 module = sys.modules.get(module_name) 110 if module is None: 111 try: 112 if isinstance(local_source.path, Path): 113 module_path = local_source.path 114 elif isinstance(local_source.path, ZipPath): 115 # save extract source to cache 116 # loading from a file from disk ensure we get readable tracebacks 117 # if any errors occur 118 module_path = ( 119 settings.cache_path / f"{source_sha}-{local_source.path.name}" 120 ) 121 _ = 
module_path.write_bytes(source_bytes) 122 else: 123 assert_never(local_source.path) 124 125 importlib_spec = importlib.util.spec_from_file_location( 126 module_name, module_path 127 ) 128 129 if importlib_spec is None: 130 raise ImportError(f"Failed to import {source}") 131 132 module = importlib.util.module_from_spec(importlib_spec) 133 assert importlib_spec.loader is not None 134 importlib_spec.loader.exec_module(module) 135 136 except Exception as e: 137 raise ImportError(f"Failed to import {source}") from e 138 else: 139 sys.modules[module_name] = module # cache this module 140 141 try: 142 callable_attr = getattr(module, callable_name) 143 except AttributeError as e: 144 raise AttributeError( 145 f"Imported custom module from {source} has no `{callable_name}` attribute." 146 ) from e 147 except Exception as e: 148 raise AttributeError( 149 f"Failed to access `{callable_name}` attribute from custom module imported from {source} ." 150 ) from e 151 152 else: 153 return callable_attr 154 155 156def get_axes_infos( 157 io_descr: Union[ 158 v0_4.InputTensorDescr, 159 v0_4.OutputTensorDescr, 160 v0_5.InputTensorDescr, 161 v0_5.OutputTensorDescr, 162 ], 163) -> List[AxisInfo]: 164 """get a unified, simplified axis representation from spec axes""" 165 ret: List[AxisInfo] = [] 166 for a in io_descr.axes: 167 if isinstance(a, v0_5.AxisBase): 168 ret.append(AxisInfo.create(Axis(id=a.id, type=a.type))) 169 else: 170 assert a in ("b", "i", "t", "c", "z", "y", "x") 171 ret.append(AxisInfo.create(a)) 172 173 return ret 174 175 176def get_member_id( 177 tensor_description: Union[ 178 v0_4.InputTensorDescr, 179 v0_4.OutputTensorDescr, 180 v0_5.InputTensorDescr, 181 v0_5.OutputTensorDescr, 182 ], 183) -> MemberId: 184 """get the normalized tensor ID, usable as a sample member ID""" 185 186 if isinstance(tensor_description, (v0_4.InputTensorDescr, v0_4.OutputTensorDescr)): 187 return MemberId(tensor_description.name) 188 elif isinstance( 189 tensor_description, 
(v0_5.InputTensorDescr, v0_5.OutputTensorDescr) 190 ): 191 return tensor_description.id 192 else: 193 assert_never(tensor_description) 194 195 196def get_member_ids( 197 tensor_descriptions: Sequence[ 198 Union[ 199 v0_4.InputTensorDescr, 200 v0_4.OutputTensorDescr, 201 v0_5.InputTensorDescr, 202 v0_5.OutputTensorDescr, 203 ] 204 ], 205) -> List[MemberId]: 206 """get normalized tensor IDs to be used as sample member IDs""" 207 return [get_member_id(descr) for descr in tensor_descriptions] 208 209 210def get_test_inputs(model: AnyModelDescr) -> Sample: 211 """returns a model's test input sample""" 212 member_ids = get_member_ids(model.inputs) 213 if isinstance(model, v0_4.ModelDescr): 214 arrays = [load_array(tt) for tt in model.test_inputs] 215 else: 216 arrays = [load_array(d.test_tensor) for d in model.inputs] 217 218 axes = [get_axes_infos(t) for t in model.inputs] 219 return Sample( 220 members={ 221 m: Tensor.from_numpy(arr, dims=ax) 222 for m, arr, ax in zip(member_ids, arrays, axes) 223 }, 224 stat={}, 225 id="test-sample", 226 ) 227 228 229def get_test_outputs(model: AnyModelDescr) -> Sample: 230 """returns a model's test output sample""" 231 member_ids = get_member_ids(model.outputs) 232 233 if isinstance(model, v0_4.ModelDescr): 234 arrays = [load_array(tt) for tt in model.test_outputs] 235 else: 236 arrays = [load_array(d.test_tensor) for d in model.outputs] 237 238 axes = [get_axes_infos(t) for t in model.outputs] 239 240 return Sample( 241 members={ 242 m: Tensor.from_numpy(arr, dims=ax) 243 for m, arr, ax in zip(member_ids, arrays, axes) 244 }, 245 stat={}, 246 id="test-sample", 247 ) 248 249 250class IO_SampleBlockMeta(NamedTuple): 251 input: SampleBlockMeta 252 output: SampleBlockMeta 253 254 255def get_input_halo(model: v0_5.ModelDescr, output_halo: PerMember[PerAxis[Halo]]): 256 """returns which halo input tensors need to be divided into blocks with, such that 257 `output_halo` can be cropped from their outputs without introducing gaps.""" 258 
input_halo: Dict[MemberId, Dict[AxisId, Halo]] = {} 259 outputs = {t.id: t for t in model.outputs} 260 all_tensors = {**{t.id: t for t in model.inputs}, **outputs} 261 262 for t, th in output_halo.items(): 263 axes = {a.id: a for a in outputs[t].axes} 264 265 for a, ah in th.items(): 266 s = axes[a].size 267 if not isinstance(s, v0_5.SizeReference): 268 raise ValueError( 269 f"Unable to map output halo for {t}.{a} to an input axis" 270 ) 271 272 axis = axes[a] 273 ref_axis = {a.id: a for a in all_tensors[s.tensor_id].axes}[s.axis_id] 274 275 total_output_halo = sum(ah) 276 total_input_halo = total_output_halo * axis.scale / ref_axis.scale 277 assert ( 278 total_input_halo == int(total_input_halo) and total_input_halo % 2 == 0 279 ) 280 input_halo.setdefault(s.tensor_id, {})[a] = Halo( 281 int(total_input_halo // 2), int(total_input_halo // 2) 282 ) 283 284 return input_halo 285 286 287def get_block_transform( 288 model: v0_5.ModelDescr, 289) -> PerMember[PerAxis[Union[LinearSampleAxisTransform, int]]]: 290 """returns how a model's output tensor shapes relates to its input shapes""" 291 ret: Dict[MemberId, Dict[AxisId, Union[LinearSampleAxisTransform, int]]] = {} 292 batch_axis_trf = None 293 for ipt in model.inputs: 294 for a in ipt.axes: 295 if a.type == "batch": 296 batch_axis_trf = LinearSampleAxisTransform( 297 axis=a.id, scale=1, offset=0, member=ipt.id 298 ) 299 break 300 if batch_axis_trf is not None: 301 break 302 axis_scales = { 303 t.id: {a.id: a.scale for a in t.axes} 304 for t in chain(model.inputs, model.outputs) 305 } 306 for out in model.outputs: 307 new_axes: Dict[AxisId, Union[LinearSampleAxisTransform, int]] = {} 308 for a in out.axes: 309 if a.size is None: 310 assert a.type == "batch" 311 if batch_axis_trf is None: 312 raise ValueError( 313 "no batch axis found in any input tensor, but output tensor" 314 + f" '{out.id}' has one." 
315 ) 316 s = batch_axis_trf 317 elif isinstance(a.size, int): 318 s = a.size 319 elif isinstance(a.size, v0_5.DataDependentSize): 320 s = -1 321 elif isinstance(a.size, v0_5.SizeReference): 322 s = LinearSampleAxisTransform( 323 axis=a.size.axis_id, 324 scale=axis_scales[a.size.tensor_id][a.size.axis_id] / a.scale, 325 offset=a.size.offset, 326 member=a.size.tensor_id, 327 ) 328 else: 329 assert_never(a.size) 330 331 new_axes[a.id] = s 332 333 ret[out.id] = new_axes 334 335 return ret 336 337 338def get_io_sample_block_metas( 339 model: v0_5.ModelDescr, 340 input_sample_shape: PerMember[PerAxis[int]], 341 ns: Mapping[Tuple[MemberId, AxisId], ParameterizedSize_N], 342 batch_size: int = 1, 343) -> Tuple[TotalNumberOfBlocks, Iterable[IO_SampleBlockMeta]]: 344 """returns an iterable yielding meta data for corresponding input and output samples""" 345 if not isinstance(model, v0_5.ModelDescr): 346 raise TypeError(f"get_block_meta() not implemented for {type(model)}") 347 348 block_axis_sizes = model.get_axis_sizes(ns=ns, batch_size=batch_size) 349 input_block_shape = { 350 t: {aa: s for (tt, aa), s in block_axis_sizes.inputs.items() if tt == t} 351 for t in {tt for tt, _ in block_axis_sizes.inputs} 352 } 353 output_halo = { 354 t.id: { 355 a.id: Halo(a.halo, a.halo) for a in t.axes if isinstance(a, v0_5.WithHalo) 356 } 357 for t in model.outputs 358 } 359 input_halo = get_input_halo(model, output_halo) 360 361 n_input_blocks, input_blocks = split_multiple_shapes_into_blocks( 362 input_sample_shape, input_block_shape, halo=input_halo 363 ) 364 block_transform = get_block_transform(model) 365 return n_input_blocks, ( 366 IO_SampleBlockMeta(ipt, ipt.get_transformed(block_transform)) 367 for ipt in sample_block_meta_generator( 368 input_blocks, sample_shape=input_sample_shape, sample_id=None 369 ) 370 ) 371 372 373def get_tensor( 374 src: Union[ZipPath, TensorSource], 375 ipt: Union[v0_4.InputTensorDescr, v0_5.InputTensorDescr], 376): 377 """helper to cast/load various 
tensor sources""" 378 379 if isinstance(src, Tensor): 380 return src 381 382 if isinstance(src, xr.DataArray): 383 return Tensor.from_xarray(src) 384 385 if isinstance(src, np.ndarray): 386 return Tensor.from_numpy(src, dims=get_axes_infos(ipt)) 387 388 if isinstance(src, FileDescr): 389 src = download(src).path 390 391 if isinstance(src, (ZipPath, Path, str)): 392 return load_tensor(src, axes=get_axes_infos(ipt)) 393 394 assert_never(src) 395 396 397def create_sample_for_model( 398 model: AnyModelDescr, 399 *, 400 stat: Optional[Stat] = None, 401 sample_id: SampleId = None, 402 inputs: Union[PerMember[TensorSource], TensorSource], 403) -> Sample: 404 """Create a sample from a single set of input(s) for a specific bioimage.io model 405 406 Args: 407 model: a bioimage.io model description 408 stat: dictionary with sample and dataset statistics (may be updated in-place!) 409 inputs: the input(s) constituting a single sample. 410 """ 411 412 model_inputs = {get_member_id(d): d for d in model.inputs} 413 if isinstance(inputs, collections.abc.Mapping): 414 inputs = {MemberId(k): v for k, v in inputs.items()} 415 elif len(model_inputs) == 1: 416 inputs = {list(model_inputs)[0]: inputs} 417 else: 418 raise TypeError( 419 f"Expected `inputs` to be a mapping with keys {tuple(model_inputs)}" 420 ) 421 422 if unknown := {k for k in inputs if k not in model_inputs}: 423 raise ValueError(f"Got unexpected inputs: {unknown}") 424 425 if missing := { 426 k 427 for k, v in model_inputs.items() 428 if k not in inputs and not (isinstance(v, v0_5.InputTensorDescr) and v.optional) 429 }: 430 raise ValueError(f"Missing non-optional model inputs: {missing}") 431 432 return Sample( 433 members={ 434 m: get_tensor(inputs[m], ipt) 435 for m, ipt in model_inputs.items() 436 if m in inputs 437 }, 438 stat={} if stat is None else stat, 439 id=sample_id, 440 ) 441 442 443def load_sample_for_model( 444 *, 445 model: AnyModelDescr, 446 paths: PerMember[Path], 447 axes: 
Optional[PerMember[Sequence[AxisLike]]] = None, 448 stat: Optional[Stat] = None, 449 sample_id: Optional[SampleId] = None, 450): 451 """load a single sample from `paths` that can be processed by `model`""" 452 453 if axes is None: 454 axes = {} 455 456 # make sure members are keyed by MemberId, not string 457 paths = {MemberId(k): v for k, v in paths.items()} 458 axes = {MemberId(k): v for k, v in axes.items()} 459 460 model_inputs = {get_member_id(d): d for d in model.inputs} 461 462 if unknown := {k for k in paths if k not in model_inputs}: 463 raise ValueError(f"Got unexpected paths for {unknown}") 464 465 if unknown := {k for k in axes if k not in model_inputs}: 466 raise ValueError(f"Got unexpected axes hints for: {unknown}") 467 468 members: Dict[MemberId, Tensor] = {} 469 for m, p in paths.items(): 470 if m not in axes: 471 axes[m] = get_axes_infos(model_inputs[m]) 472 logger.debug( 473 "loading '{}' from {} with default input axes {} ", 474 m, 475 p, 476 axes[m], 477 ) 478 members[m] = load_tensor(p, axes[m]) 479 480 return Sample( 481 members=members, 482 stat={} if stat is None else stat, 483 id=sample_id or tuple(sorted(paths.values())), 484 )
def import_callable(
    node: Union[
        ArchitectureFromFileDescr,
        ArchitectureFromLibraryDescr,
        CallableFromDepencency,
        CallableFromFile,
    ],
    /,
    **kwargs: Unpack[HashKwargs],
) -> Callable[..., Any]:
    """Resolve a spec node describing a callable (e.g. a torch.nn.Module) and
    return the imported object.

    Nodes naming an importable module are imported by module name; nodes naming
    a source file are loaded through `_import_from_file_impl`. `kwargs` are only
    forwarded for `CallableFromFile` nodes; an `ArchitectureFromFileDescr` node
    supplies its own `sha256`.

    Raises:
        ValueError: if the imported attribute is not callable.
    """
    if isinstance(node, CallableFromDepencency):
        c = getattr(
            importlib.import_module(node.module_name), str(node.callable_name)
        )
    elif isinstance(node, ArchitectureFromLibraryDescr):
        c = getattr(importlib.import_module(node.import_from), str(node.callable))
    elif isinstance(node, CallableFromFile):
        c = _import_from_file_impl(
            node.source_file, str(node.callable_name), **kwargs
        )
    elif isinstance(node, ArchitectureFromFileDescr):
        c = _import_from_file_impl(
            node.source, str(node.callable), sha256=node.sha256
        )
    else:
        assert_never(node)

    if callable(c):
        return c

    raise ValueError(f"{node} (imported: {c}) is not callable")
import a callable (e.g. a torch.nn.Module) from a spec node describing it
def get_axes_infos(
    io_descr: Union[
        v0_4.InputTensorDescr,
        v0_4.OutputTensorDescr,
        v0_5.InputTensorDescr,
        v0_5.OutputTensorDescr,
    ],
) -> List[AxisInfo]:
    """Convert the axes of a tensor description into a list of `AxisInfo`.

    `v0_5.AxisBase` axes are converted via their `id` and `type`; any other
    axis has to be one of the letters "b", "i", "t", "c", "z", "y", "x".
    """

    def _to_axis_info(a: Any) -> AxisInfo:
        # a single spec axis -> AxisInfo
        if isinstance(a, v0_5.AxisBase):
            return AxisInfo.create(Axis(id=a.id, type=a.type))

        assert a in ("b", "i", "t", "c", "z", "y", "x")
        return AxisInfo.create(a)

    return [_to_axis_info(a) for a in io_descr.axes]
get a unified, simplified axis representation from spec axes
def get_member_id(
    tensor_description: Union[
        v0_4.InputTensorDescr,
        v0_4.OutputTensorDescr,
        v0_5.InputTensorDescr,
        v0_5.OutputTensorDescr,
    ],
) -> MemberId:
    """Return the sample member ID for a tensor description.

    v0.4 tensor descriptions are identified by their `name`, v0.5 tensor
    descriptions by their `id`.
    """
    v0_4_descr = (v0_4.InputTensorDescr, v0_4.OutputTensorDescr)
    v0_5_descr = (v0_5.InputTensorDescr, v0_5.OutputTensorDescr)

    if isinstance(tensor_description, v0_4_descr):
        return MemberId(tensor_description.name)

    if isinstance(tensor_description, v0_5_descr):
        return tensor_description.id

    assert_never(tensor_description)
get the normalized tensor ID, usable as a sample member ID
def get_member_ids(
    tensor_descriptions: Sequence[
        Union[
            v0_4.InputTensorDescr,
            v0_4.OutputTensorDescr,
            v0_5.InputTensorDescr,
            v0_5.OutputTensorDescr,
        ]
    ],
) -> List[MemberId]:
    """Return the sample member IDs of `tensor_descriptions`, in the same order."""
    return list(map(get_member_id, tensor_descriptions))
get normalized tensor IDs to be used as sample member IDs
def get_test_inputs(model: AnyModelDescr) -> Sample:
    """Load a model's test input tensors into a sample with id "test-sample".

    v0.4 models list their test tensors in `model.test_inputs`; other models
    carry a `test_tensor` on each input description.
    """
    member_ids = get_member_ids(model.inputs)

    if isinstance(model, v0_4.ModelDescr):
        sources = list(model.test_inputs)
    else:
        sources = [d.test_tensor for d in model.inputs]

    arrays = list(map(load_array, sources))
    axes = list(map(get_axes_infos, model.inputs))

    # pair every loaded array with the axes of its tensor description
    tensors = (Tensor.from_numpy(arr, dims=ax) for arr, ax in zip(arrays, axes))
    return Sample(
        members=dict(zip(member_ids, tensors)),
        stat={},
        id="test-sample",
    )
returns a model's test input sample
def get_test_outputs(model: AnyModelDescr) -> Sample:
    """Load a model's test output tensors into a sample with id "test-sample".

    v0.4 models list their test tensors in `model.test_outputs`; other models
    carry a `test_tensor` on each output description.
    """
    member_ids = get_member_ids(model.outputs)

    if isinstance(model, v0_4.ModelDescr):
        sources = list(model.test_outputs)
    else:
        sources = [d.test_tensor for d in model.outputs]

    arrays = list(map(load_array, sources))
    axes = list(map(get_axes_infos, model.outputs))

    # pair every loaded array with the axes of its tensor description
    tensors = (Tensor.from_numpy(arr, dims=ax) for arr, ax in zip(arrays, axes))
    return Sample(
        members=dict(zip(member_ids, tensors)),
        stat={},
        id="test-sample",
    )
returns a model's test output sample
IO_SampleBlockMeta(input, output)
def get_input_halo(model: v0_5.ModelDescr, output_halo: PerMember[PerAxis[Halo]]):
    """returns the halo with which input tensors need to be divided into blocks,
    such that `output_halo` can be cropped from their outputs without
    introducing gaps.

    Args:
        model: model description providing the input and output tensor axes.
        output_halo: halo per output tensor and output axis.

    Returns:
        Halo per referenced tensor and referenced axis. The total output halo
        (left + right) is rescaled by the ratio of the axis scales and split
        evenly between both sides.

    Raises:
        ValueError: if the size of an output axis with halo is not a
            `v0_5.SizeReference`, or if the rescaled total halo is not an even
            integer.
    """
    input_halo: Dict[MemberId, Dict[AxisId, Halo]] = {}
    outputs = {t.id: t for t in model.outputs}
    all_tensors = {**{t.id: t for t in model.inputs}, **outputs}

    for t, th in output_halo.items():
        axes = {a.id: a for a in outputs[t].axes}

        for a, ah in th.items():
            axis = axes[a]
            s = axis.size
            if not isinstance(s, v0_5.SizeReference):
                raise ValueError(
                    f"Unable to map output halo for {t}.{a} to an input axis"
                )

            ref_axis = {ra.id: ra for ra in all_tensors[s.tensor_id].axes}[s.axis_id]

            total_output_halo = sum(ah)
            total_input_halo = total_output_halo * axis.scale / ref_axis.scale
            # explicit check instead of `assert`, which is skipped with `-O`
            if total_input_halo != int(total_input_halo) or total_input_halo % 2 != 0:
                raise ValueError(
                    f"Output halo for {t}.{a} corresponds to a total halo of"
                    + f" {total_input_halo} for {s.tensor_id}.{s.axis_id},"
                    + " which is not an even integer"
                )

            half_input_halo = int(total_input_halo // 2)
            # key by the *referenced* axis: its id may differ from the output axis id
            input_halo.setdefault(s.tensor_id, {})[s.axis_id] = Halo(
                half_input_halo, half_input_halo
            )

    return input_halo
returns the halo with which input tensors need to be divided into blocks, such that
output_halo
can be cropped from their outputs without introducing gaps.
def get_block_transform(
    model: v0_5.ModelDescr,
) -> PerMember[PerAxis[Union[LinearSampleAxisTransform, int]]]:
    """Describe each output axis size in terms of the model's input axes.

    Returns:
        For every output tensor and axis either
        - a `LinearSampleAxisTransform` referring to the first batch axis found
          in the inputs (for axes without size) or to the referenced axis (for
          axes with a `v0_5.SizeReference`),
        - the fixed integer size, or
        - `-1` for axes with a `v0_5.DataDependentSize`.

    Raises:
        ValueError: if an output has an axis without size, but no input tensor
            has a batch axis.
    """
    # the first batch axis of the first input that has one (None if there is none)
    batch_axis_trf = next(
        (
            LinearSampleAxisTransform(axis=a.id, scale=1, offset=0, member=ipt.id)
            for ipt in model.inputs
            for a in ipt.axes
            if a.type == "batch"
        ),
        None,
    )

    axis_scales = {}
    for t in chain(model.inputs, model.outputs):
        axis_scales[t.id] = {a.id: a.scale for a in t.axes}

    ret: Dict[MemberId, Dict[AxisId, Union[LinearSampleAxisTransform, int]]] = {}
    for out in model.outputs:
        new_axes: Dict[AxisId, Union[LinearSampleAxisTransform, int]] = {}
        for a in out.axes:
            size = a.size
            if size is None:
                assert a.type == "batch"
                if batch_axis_trf is None:
                    raise ValueError(
                        "no batch axis found in any input tensor, but output tensor"
                        + f" '{out.id}' has one."
                    )
                new_axes[a.id] = batch_axis_trf
            elif isinstance(size, int):
                new_axes[a.id] = size
            elif isinstance(size, v0_5.DataDependentSize):
                new_axes[a.id] = -1
            elif isinstance(size, v0_5.SizeReference):
                ref_scale = axis_scales[size.tensor_id][size.axis_id]
                new_axes[a.id] = LinearSampleAxisTransform(
                    axis=size.axis_id,
                    scale=ref_scale / a.scale,
                    offset=size.offset,
                    member=size.tensor_id,
                )
            else:
                assert_never(size)

        ret[out.id] = new_axes

    return ret
returns how a model's output tensor shapes relate to its input shapes
def get_io_sample_block_metas(
    model: v0_5.ModelDescr,
    input_sample_shape: PerMember[PerAxis[int]],
    ns: Mapping[Tuple[MemberId, AxisId], ParameterizedSize_N],
    batch_size: int = 1,
) -> Tuple[TotalNumberOfBlocks, Iterable[IO_SampleBlockMeta]]:
    """returns an iterable yielding meta data for corresponding input and output samples

    Args:
        model: model description (only `v0_5.ModelDescr` is supported).
        input_sample_shape: shape of the input sample per member and axis.
        ns: forwarded to `model.get_axis_sizes` together with `batch_size` to
            determine the input block shape.
        batch_size: forwarded to `model.get_axis_sizes`.

    Returns:
        The number of input blocks and a generator of `IO_SampleBlockMeta`.

    Raises:
        TypeError: if `model` is not a `v0_5.ModelDescr`.
    """
    if not isinstance(model, v0_5.ModelDescr):
        # name this function in the message (it used to name `get_block_meta()`)
        raise TypeError(
            f"get_io_sample_block_metas() not implemented for {type(model)}"
        )

    block_axis_sizes = model.get_axis_sizes(ns=ns, batch_size=batch_size)
    # regroup the flat {(tensor, axis): size} mapping per tensor
    input_block_shape = {
        t: {aa: s for (tt, aa), s in block_axis_sizes.inputs.items() if tt == t}
        for t in {tt for tt, _ in block_axis_sizes.inputs}
    }
    # the same halo is used on both sides of an output axis
    output_halo = {
        t.id: {
            a.id: Halo(a.halo, a.halo) for a in t.axes if isinstance(a, v0_5.WithHalo)
        }
        for t in model.outputs
    }
    input_halo = get_input_halo(model, output_halo)

    n_input_blocks, input_blocks = split_multiple_shapes_into_blocks(
        input_sample_shape, input_block_shape, halo=input_halo
    )
    block_transform = get_block_transform(model)
    return n_input_blocks, (
        IO_SampleBlockMeta(ipt, ipt.get_transformed(block_transform))
        for ipt in sample_block_meta_generator(
            input_blocks, sample_shape=input_sample_shape, sample_id=None
        )
    )
returns an iterable yielding meta data for corresponding input and output samples
def get_tensor(
    src: Union[ZipPath, TensorSource],
    ipt: Union[v0_4.InputTensorDescr, v0_5.InputTensorDescr],
):
    """Turn a tensor source into a `Tensor`.

    Args:
        src: a `Tensor` (returned as is), an xarray data array, a numpy array,
            a `FileDescr` (downloaded first) or a path to load from.
        ipt: input tensor description; provides the axes for numpy arrays and
            for tensors loaded from a path.
    """
    if isinstance(src, Tensor):
        return src
    elif isinstance(src, xr.DataArray):
        return Tensor.from_xarray(src)
    elif isinstance(src, np.ndarray):
        return Tensor.from_numpy(src, dims=get_axes_infos(ipt))

    # file descriptions are resolved to a local path before loading
    path = download(src).path if isinstance(src, FileDescr) else src
    if isinstance(path, (ZipPath, Path, str)):
        return load_tensor(path, axes=get_axes_infos(ipt))

    assert_never(path)
helper to cast/load various tensor sources
def create_sample_for_model(
    model: AnyModelDescr,
    *,
    stat: Optional[Stat] = None,
    sample_id: SampleId = None,
    inputs: Union[PerMember[TensorSource], TensorSource],
) -> Sample:
    """Build a sample for a bioimage.io model from one set of input tensor sources.

    Args:
        model: a bioimage.io model description
        stat: dictionary with sample and dataset statistics (may be updated in-place!)
        sample_id: id of the created sample
        inputs: the input(s) constituting a single sample. A single tensor
            source (not a mapping) is only accepted for models with exactly
            one input.

    Raises:
        TypeError: if `inputs` is not a mapping although the model has several inputs.
        ValueError: if `inputs` has unknown members or lacks a non-optional one.
    """

    def _is_optional(descr: Any) -> bool:
        # only v0.5 input descriptions can be flagged as optional
        return isinstance(descr, v0_5.InputTensorDescr) and descr.optional

    model_inputs = {get_member_id(d): d for d in model.inputs}

    if isinstance(inputs, collections.abc.Mapping):
        inputs = {MemberId(k): v for k, v in inputs.items()}
    elif len(model_inputs) != 1:
        raise TypeError(
            f"Expected `inputs` to be a mapping with keys {tuple(model_inputs)}"
        )
    else:
        (single_member,) = model_inputs
        inputs = {single_member: inputs}

    unknown = {k for k in inputs if k not in model_inputs}
    if unknown:
        raise ValueError(f"Got unexpected inputs: {unknown}")

    missing = {
        k for k, v in model_inputs.items() if not (k in inputs or _is_optional(v))
    }
    if missing:
        raise ValueError(f"Missing non-optional model inputs: {missing}")

    members = {}
    for m, ipt in model_inputs.items():
        if m in inputs:
            members[m] = get_tensor(inputs[m], ipt)

    return Sample(
        members=members,
        stat=stat if stat is not None else {},
        id=sample_id,
    )
Create a sample from a single set of input(s) for a specific bioimage.io model
Arguments:
- model: a bioimage.io model description
- stat: dictionary with sample and dataset statistics (may be updated in-place!)
- inputs: the input(s) constituting a single sample.
def load_sample_for_model(
    *,
    model: AnyModelDescr,
    paths: PerMember[Path],
    axes: Optional[PerMember[Sequence[AxisLike]]] = None,
    stat: Optional[Stat] = None,
    sample_id: Optional[SampleId] = None,
):
    """load a single sample from `paths` that can be processed by `model`

    Args:
        model: a bioimage.io model description
        paths: path to load per model input
        axes: optional axes per model input; inputs without an entry are loaded
            with the axes of the corresponding model input description.
        stat: sample and dataset statistics to attach to the sample
        sample_id: id of the sample; if None, the sorted tuple of `paths`
            values is used.

    Raises:
        ValueError: if `paths` or `axes` refer to members that are not model inputs.
    """

    if axes is None:
        axes = {}

    # make sure members are keyed by MemberId, not string
    paths = {MemberId(k): v for k, v in paths.items()}
    axes = {MemberId(k): v for k, v in axes.items()}

    model_inputs = {get_member_id(d): d for d in model.inputs}

    if unknown := {k for k in paths if k not in model_inputs}:
        raise ValueError(f"Got unexpected paths for {unknown}")

    if unknown := {k for k in axes if k not in model_inputs}:
        raise ValueError(f"Got unexpected axes hints for: {unknown}")

    members: Dict[MemberId, Tensor] = {}
    for m, p in paths.items():
        if m not in axes:
            axes[m] = get_axes_infos(model_inputs[m])
            logger.debug(
                "loading '{}' from {} with default input axes {} ",
                m,
                p,
                axes[m],
            )
        members[m] = load_tensor(p, axes[m])

    return Sample(
        members=members,
        stat={} if stat is None else stat,
        # `is None` (not truthiness): falsy ids such as 0 are valid sample ids
        id=tuple(sorted(paths.values())) if sample_id is None else sample_id,
    )
load a single sample from paths
that can be processed by model