bioimageio.core.digest_spec
"""Derive runtime objects from bioimage.io model descriptions.

Helpers in this module import callables described in a spec, translate tensor
descriptions into axis infos and member IDs, build (test) samples for a model,
and compute block metadata (halos, shape transforms) for blockwise processing.
"""

from __future__ import annotations

import importlib.util
from itertools import chain
from pathlib import Path
from typing import (
    Any,
    Callable,
    Dict,
    Iterable,
    List,
    Mapping,
    NamedTuple,
    Optional,
    Sequence,
    Tuple,
    Union,
)

import numpy as np
import xarray as xr
from loguru import logger
from numpy.typing import NDArray
from typing_extensions import Unpack, assert_never

from bioimageio.spec._internal.io import resolve_and_extract
from bioimageio.spec._internal.io_utils import HashKwargs
from bioimageio.spec.common import FileSource
from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5
from bioimageio.spec.model.v0_4 import CallableFromDepencency, CallableFromFile
from bioimageio.spec.model.v0_5 import (
    ArchitectureFromFileDescr,
    ArchitectureFromLibraryDescr,
    ParameterizedSize_N,
)
from bioimageio.spec.utils import load_array

from .axis import AxisId, AxisInfo, AxisLike, PerAxis
from .block_meta import split_multiple_shapes_into_blocks
from .common import Halo, MemberId, PerMember, SampleId, TotalNumberOfBlocks
from .io import load_tensor
from .sample import (
    LinearSampleAxisTransform,
    Sample,
    SampleBlockMeta,
    sample_block_meta_generator,
)
from .stat_measures import Stat
from .tensor import Tensor


def import_callable(
    node: Union[
        ArchitectureFromFileDescr,
        ArchitectureFromLibraryDescr,
        CallableFromDepencency,
        CallableFromFile,
    ],
    /,
    **kwargs: Unpack[HashKwargs],
) -> Callable[..., Any]:
    """import a callable (e.g. a torch.nn.Module) from a spec node describing it

    Args:
        node: spec node describing where the callable lives: an importable
            module (`CallableFromDepencency`, `ArchitectureFromLibraryDescr`)
            or a source file (`CallableFromFile`, `ArchitectureFromFileDescr`).
        kwargs: hash arguments forwarded when resolving the source file of a
            `CallableFromFile` node. `ArchitectureFromFileDescr` nodes pass
            their own `sha256` instead, so `kwargs` are not used for them.

    Returns:
        The imported callable.

    Raises:
        ValueError: if the imported object is not callable.
    """
    if isinstance(node, CallableFromDepencency):
        module = importlib.import_module(node.module_name)
        c = getattr(module, str(node.callable_name))
    elif isinstance(node, ArchitectureFromLibraryDescr):
        module = importlib.import_module(node.import_from)
        c = getattr(module, str(node.callable))
    elif isinstance(node, CallableFromFile):
        c = _import_from_file_impl(node.source_file, str(node.callable_name), **kwargs)
    elif isinstance(node, ArchitectureFromFileDescr):
        c = _import_from_file_impl(node.source, str(node.callable), sha256=node.sha256)
    else:
        assert_never(node)

    if not callable(c):
        raise ValueError(f"{node} (imported: {c}) is not callable")

    return c


def _import_from_file_impl(
    source: FileSource, callable_name: str, **kwargs: Unpack[HashKwargs]
):
    """Execute the Python file at `source` as a module and return `callable_name` from it.

    `kwargs` are forwarded to `resolve_and_extract`. The module is named after
    the file stem; this function does not register it in `sys.modules`.

    Raises:
        ImportError: if no import spec or loader can be created for the file.
    """
    local_file = resolve_and_extract(source, **kwargs)
    module_name = local_file.path.stem
    importlib_spec = importlib.util.spec_from_file_location(
        module_name, local_file.path
    )
    # `ModuleSpec.loader` is optional; check it too so we raise an informative
    # ImportError instead of an AttributeError on `None.exec_module`.
    if importlib_spec is None or importlib_spec.loader is None:
        raise ImportError(f"Failed to import {module_name} from {source}.")

    dep = importlib.util.module_from_spec(importlib_spec)
    importlib_spec.loader.exec_module(dep)
    return getattr(dep, callable_name)


def get_axes_infos(
    io_descr: Union[
        v0_4.InputTensorDescr,
        v0_4.OutputTensorDescr,
        v0_5.InputTensorDescr,
        v0_5.OutputTensorDescr,
    ],
) -> List[AxisInfo]:
    """Build a unified, simplified axis representation from the spec axes.

    String axes outside the known letters ``b, i, t, c, z, y, x`` are mapped to
    the generic ``"i"`` axis; every other axis is passed to ``AxisInfo.create``
    unchanged.
    """
    known_axis_letters = ("b", "i", "t", "c", "z", "y", "x")
    infos: List[AxisInfo] = []
    for axis in io_descr.axes:
        is_unknown_letter = isinstance(axis, str) and axis not in known_axis_letters
        infos.append(AxisInfo.create("i" if is_unknown_letter else axis))

    return infos


def get_member_id(
    tensor_description: Union[
        v0_4.InputTensorDescr,
        v0_4.OutputTensorDescr,
        v0_5.InputTensorDescr,
        v0_5.OutputTensorDescr,
    ],
) -> MemberId:
    """Return the normalized tensor ID of a tensor description.

    The result is usable as a sample member ID: v0.4 descriptions contribute
    their ``name`` (wrapped in ``MemberId``), v0.5 descriptions their ``id``.
    """
    v0_4_types = (v0_4.InputTensorDescr, v0_4.OutputTensorDescr)
    if isinstance(tensor_description, v0_4_types):
        return MemberId(tensor_description.name)

    v0_5_types = (v0_5.InputTensorDescr, v0_5.OutputTensorDescr)
    if isinstance(tensor_description, v0_5_types):
        return tensor_description.id

    assert_never(tensor_description)


def get_member_ids(
    tensor_descriptions: Sequence[
        Union[
            v0_4.InputTensorDescr,
            v0_4.OutputTensorDescr,
            v0_5.InputTensorDescr,
            v0_5.OutputTensorDescr,
        ]
    ],
) -> List[MemberId]:
    """Return the normalized tensor ID of every description, preserving order.

    Each ID is computed by ``get_member_id`` and is usable as a sample member ID.
    """
    return list(map(get_member_id, tensor_descriptions))


def get_test_inputs(model: AnyModelDescr) -> Sample:
    """Return the test input sample shipped with `model`.

    v0.4 model descriptions list their test input files in `model.test_inputs`;
    other descriptions attach a `test_tensor` to every input description.
    Each array is wrapped in a `Tensor` with the axes of its input description.
    """
    ids = get_member_ids(model.inputs)
    if isinstance(model, v0_4.ModelDescr):
        loaded = [load_array(source) for source in model.test_inputs]
    else:
        loaded = [load_array(descr.test_tensor) for descr in model.inputs]

    dims_per_member = [get_axes_infos(descr) for descr in model.inputs]
    members = {}
    for member_id, array, dims in zip(ids, loaded, dims_per_member):
        members[member_id] = Tensor.from_numpy(array, dims=dims)

    return Sample(members=members, stat={}, id="test-sample")


def get_test_outputs(model: AnyModelDescr) -> Sample:
    """Return the test output sample shipped with `model`.

    v0.4 model descriptions list their test output files in `model.test_outputs`;
    other descriptions attach a `test_tensor` to every output description.
    Each array is wrapped in a `Tensor` with the axes of its output description.
    """
    output_descrs = model.outputs
    ids = get_member_ids(output_descrs)

    if isinstance(model, v0_4.ModelDescr):
        loaded = [load_array(source) for source in model.test_outputs]
    else:
        loaded = [load_array(descr.test_tensor) for descr in output_descrs]

    dims_per_member = [get_axes_infos(descr) for descr in output_descrs]
    members = {
        member_id: Tensor.from_numpy(array, dims=dims)
        for member_id, array, dims in zip(ids, loaded, dims_per_member)
    }
    return Sample(members=members, stat={}, id="test-sample")


class IO_SampleBlockMeta(NamedTuple):
    """Metadata of an input sample block paired with its output sample block."""

    input: SampleBlockMeta
    output: SampleBlockMeta


def get_input_halo(model: v0_5.ModelDescr, output_halo: PerMember[PerAxis[Halo]]):
    """returns which halo input tensors need to be divided into blocks with, such that
    `output_halo` can be cropped from their outputs without introducing gaps.

    Every output axis listed in `output_halo` must have a `v0_5.SizeReference`
    size. Its total halo is rescaled by `output_axis.scale / referenced_axis.scale`
    and assigned, split evenly into left and right halo, to the *referenced*
    axis. If several output axes reference the same axis, the largest required
    halo is kept.

    Args:
        model: model description (format 0.5)
        output_halo: halo per output member and output axis

    Returns:
        halo per referenced member and referenced axis

    Raises:
        ValueError: if an output axis with halo does not reference another axis,
            or if its rescaled total halo is not an even integer.
    """
    input_halo: Dict[MemberId, Dict[AxisId, Halo]] = {}
    outputs = {t.id: t for t in model.outputs}
    all_tensors = {**{t.id: t for t in model.inputs}, **outputs}

    for t, th in output_halo.items():
        axes = {a.id: a for a in outputs[t].axes}

        for a, ah in th.items():
            axis = axes[a]
            s = axis.size
            if not isinstance(s, v0_5.SizeReference):
                raise ValueError(
                    f"Unable to map output halo for {t}.{a} to an input axis"
                )

            ref_axis = {ra.id: ra for ra in all_tensors[s.tensor_id].axes}[s.axis_id]

            total_output_halo = sum(ah)
            total_input_halo = total_output_halo * axis.scale / ref_axis.scale
            if total_input_halo != int(total_input_halo) or total_input_halo % 2 != 0:
                raise ValueError(
                    f"Output halo {ah} for {t}.{a} maps to a total halo of"
                    + f" {total_input_halo} for {s.tensor_id}.{s.axis_id},"
                    + " which is not an even integer."
                )

            half = int(total_input_halo) // 2
            ref_halos = input_halo.setdefault(s.tensor_id, {})
            previous = ref_halos.get(s.axis_id)
            if previous is not None:
                # another output axis references the same axis: keep the larger halo
                half = max(half, *previous)

            # key by the referenced axis id, not by the output axis id `a`
            ref_halos[s.axis_id] = Halo(half, half)

    return input_halo


def get_block_transform(
    model: v0_5.ModelDescr,
) -> PerMember[PerAxis[Union[LinearSampleAxisTransform, int]]]:
    """Describe how each output tensor's shape relates to the model's input shapes.

    For every output axis the result holds one of:
    - a ``LinearSampleAxisTransform``: batch axes mirror the first input batch
      axis found; ``SizeReference`` axes scale the referenced axis by
      ``referenced_axis.scale / output_axis.scale`` plus the reference offset,
    - the fixed ``int`` size,
    - ``-1`` for ``DataDependentSize`` axes.

    Raises:
        ValueError: if an output tensor has a batch axis but no input does.
    """
    batch_axis_trf = next(
        (
            LinearSampleAxisTransform(axis=axis.id, scale=1, offset=0, member=ipt.id)
            for ipt in model.inputs
            for axis in ipt.axes
            if axis.type == "batch"
        ),
        None,
    )

    scales_by_tensor = {
        descr.id: {axis.id: axis.scale for axis in descr.axes}
        for descr in chain(model.inputs, model.outputs)
    }

    transforms: Dict[
        MemberId, Dict[AxisId, Union[LinearSampleAxisTransform, int]]
    ] = {}
    for out in model.outputs:
        per_axis: Dict[AxisId, Union[LinearSampleAxisTransform, int]] = {}
        for axis in out.axes:
            size = axis.size
            if size is None:
                assert axis.type == "batch"
                if batch_axis_trf is None:
                    raise ValueError(
                        "no batch axis found in any input tensor, but output tensor"
                        + f" '{out.id}' has one."
                    )
                per_axis[axis.id] = batch_axis_trf
            elif isinstance(size, int):
                per_axis[axis.id] = size
            elif isinstance(size, v0_5.DataDependentSize):
                per_axis[axis.id] = -1
            elif isinstance(size, v0_5.SizeReference):
                ref_scale = scales_by_tensor[size.tensor_id][size.axis_id]
                per_axis[axis.id] = LinearSampleAxisTransform(
                    axis=size.axis_id,
                    scale=ref_scale / axis.scale,
                    offset=size.offset,
                    member=size.tensor_id,
                )
            else:
                assert_never(size)

        transforms[out.id] = per_axis

    return transforms


def get_io_sample_block_metas(
    model: v0_5.ModelDescr,
    input_sample_shape: PerMember[PerAxis[int]],
    ns: Mapping[Tuple[MemberId, AxisId], ParameterizedSize_N],
    batch_size: int = 1,
) -> Tuple[TotalNumberOfBlocks, Iterable[IO_SampleBlockMeta]]:
    """returns an iterable yielding meta data for corresponding input and output samples

    Args:
        model: model description (format 0.5)
        input_sample_shape: shape of each input member of the full sample
        ns: `n` values for parameterized axis sizes, passed to
            `model.get_axis_sizes` to determine the block shape
        batch_size: batch size passed to `model.get_axis_sizes`

    Returns:
        the total number of input blocks and a lazy iterable of
        `IO_SampleBlockMeta`, pairing each input block with its output block

    Raises:
        TypeError: if `model` is not a v0.5 model description
    """
    if not isinstance(model, v0_5.ModelDescr):
        raise TypeError(
            f"get_io_sample_block_metas() not implemented for {type(model)}"
        )

    block_axis_sizes = model.get_axis_sizes(ns=ns, batch_size=batch_size)
    # regroup {(member, axis): size} into {member: {axis: size}} in a single pass
    input_block_shape: Dict[MemberId, Dict[AxisId, int]] = {}
    for (member_id, axis_id), size in block_axis_sizes.inputs.items():
        input_block_shape.setdefault(member_id, {})[axis_id] = size

    # output axes that support a halo get a symmetric halo of `a.halo`
    output_halo = {
        t.id: {
            a.id: Halo(a.halo, a.halo) for a in t.axes if isinstance(a, v0_5.WithHalo)
        }
        for t in model.outputs
    }
    input_halo = get_input_halo(model, output_halo)

    n_input_blocks, input_blocks = split_multiple_shapes_into_blocks(
        input_sample_shape, input_block_shape, halo=input_halo
    )
    block_transform = get_block_transform(model)
    return n_input_blocks, (
        IO_SampleBlockMeta(ipt, ipt.get_transformed(block_transform))
        for ipt in sample_block_meta_generator(
            input_blocks, sample_shape=input_sample_shape, sample_id=None
        )
    )


def get_tensor(
    src: Union[Tensor, xr.DataArray, NDArray[Any], Path, str],
    ipt: Union[v0_4.InputTensorDescr, v0_5.InputTensorDescr],
) -> Tensor:
    """helper to cast/load various tensor sources

    Args:
        src: the tensor source. A `Tensor` is returned as-is, an
            `xr.DataArray` is converted via `Tensor.from_xarray` (the axes of
            `ipt` are not applied), a numpy array is wrapped with the axes of
            `ipt`, and a file path (`Path` or `str`) is loaded with the axes
            of `ipt`.
        ipt: input tensor description providing the axes

    Returns:
        the source as a `Tensor`
    """
    if isinstance(src, Tensor):
        return src

    if isinstance(src, xr.DataArray):
        return Tensor.from_xarray(src)

    if isinstance(src, np.ndarray):
        return Tensor.from_numpy(src, dims=get_axes_infos(ipt))

    if isinstance(src, (Path, str)):
        # accept plain string paths in addition to `Path` objects
        return load_tensor(Path(src), axes=get_axes_infos(ipt))

    assert_never(src)


def create_sample_for_model(
    model: AnyModelDescr,
    *,
    stat: Optional[Stat] = None,
    sample_id: SampleId = None,
    inputs: Optional[
        PerMember[Union[Tensor, xr.DataArray, NDArray[Any], Path]]
    ] = None,  # TODO: make non-optional
    **kwargs: NDArray[Any],  # TODO: deprecate in favor of `inputs`
) -> Sample:
    """Create a sample from a single set of input(s) for a specific bioimage.io model

    Args:
        model: a bioimage.io model description
        stat: dictionary with sample and dataset statistics (may be updated in-place!)
        sample_id: ID assigned to the created sample
        inputs: the input(s) constituting a single sample.
        kwargs: (to be deprecated) inputs given as keyword arguments; entries
            in `inputs` take precedence over them.

    Raises:
        ValueError: if an input does not correspond to any model input, or if a
            non-optional model input is missing.
    """
    merged = dict(kwargs)
    merged.update(inputs or {})
    given = {MemberId(key): value for key, value in merged.items()}

    descr_by_id = {get_member_id(descr): descr for descr in model.inputs}

    unknown = {member for member in given if member not in descr_by_id}
    if unknown:
        raise ValueError(f"Got unexpected inputs: {unknown}")

    missing = {
        member
        for member, descr in descr_by_id.items()
        if member not in given
        and not (isinstance(descr, v0_5.InputTensorDescr) and descr.optional)
    }
    if missing:
        raise ValueError(f"Missing non-optional model inputs: {missing}")

    members = {
        member: get_tensor(given[member], descr)
        for member, descr in descr_by_id.items()
        if member in given
    }
    return Sample(
        members=members,
        stat={} if stat is None else stat,
        id=sample_id,
    )


def load_sample_for_model(
    *,
    model: AnyModelDescr,
    paths: PerMember[Path],
    axes: Optional[PerMember[Sequence[AxisLike]]] = None,
    stat: Optional[Stat] = None,
    sample_id: Optional[SampleId] = None,
) -> Sample:
    """load a single sample from `paths` that can be processed by `model`

    Args:
        model: a bioimage.io model description
        paths: file path to load per input member
        axes: axes per member; members without an entry use the axes of the
            corresponding model input description
        stat: dictionary with sample and dataset statistics
        sample_id: ID of the sample; if `None`, the sorted tuple of the
            `paths` values is used

    Raises:
        ValueError: if `paths` or `axes` refer to members that are not model inputs
    """
    # make sure members are keyed by MemberId, not string
    paths = {MemberId(k): v for k, v in paths.items()}
    axes = {MemberId(k): v for k, v in (axes or {}).items()}

    model_inputs = {get_member_id(d): d for d in model.inputs}

    if unknown := {k for k in paths if k not in model_inputs}:
        raise ValueError(f"Got unexpected paths for {unknown}")

    if unknown := {k for k in axes if k not in model_inputs}:
        raise ValueError(f"Got unexpected axes hints for: {unknown}")

    members: Dict[MemberId, Tensor] = {}
    for m, p in paths.items():
        if m not in axes:
            axes[m] = get_axes_infos(model_inputs[m])
            logger.debug(
                "loading '{}' from {} with default input axes {} ",
                m,
                p,
                axes[m],
            )
        members[m] = load_tensor(p, axes[m])

    return Sample(
        members=members,
        stat={} if stat is None else stat,
        # compare against None so that falsy IDs such as 0 or "" are kept
        id=tuple(sorted(paths.values())) if sample_id is None else sample_id,
    )
def import_callable(
    node: Union[
        ArchitectureFromFileDescr,
        ArchitectureFromLibraryDescr,
        CallableFromDepencency,
        CallableFromFile,
    ],
    /,
    **kwargs: Unpack[HashKwargs],
) -> Callable[..., Any]:
    """import a callable (e.g. a torch.nn.Module) from a spec node describing it

    Args:
        node: spec node describing where the callable lives: an importable
            module (`CallableFromDepencency`, `ArchitectureFromLibraryDescr`)
            or a source file (`CallableFromFile`, `ArchitectureFromFileDescr`).
        kwargs: hash arguments forwarded when resolving the source file of a
            `CallableFromFile` node. `ArchitectureFromFileDescr` nodes pass
            their own `sha256` instead, so `kwargs` are not used for them.

    Returns:
        The imported callable.

    Raises:
        ValueError: if the imported object is not callable.
    """
    if isinstance(node, CallableFromDepencency):
        module = importlib.import_module(node.module_name)
        c = getattr(module, str(node.callable_name))
    elif isinstance(node, ArchitectureFromLibraryDescr):
        module = importlib.import_module(node.import_from)
        c = getattr(module, str(node.callable))
    elif isinstance(node, CallableFromFile):
        c = _import_from_file_impl(node.source_file, str(node.callable_name), **kwargs)
    elif isinstance(node, ArchitectureFromFileDescr):
        c = _import_from_file_impl(node.source, str(node.callable), sha256=node.sha256)
    else:
        assert_never(node)

    if not callable(c):
        raise ValueError(f"{node} (imported: {c}) is not callable")

    return c
Import a callable (e.g. a `torch.nn.Module`) from a spec node describing it.
def get_axes_infos(
    io_descr: Union[
        v0_4.InputTensorDescr,
        v0_4.OutputTensorDescr,
        v0_5.InputTensorDescr,
        v0_5.OutputTensorDescr,
    ],
) -> List[AxisInfo]:
    """Build a unified, simplified axis representation from the spec axes.

    String axes outside the known letters ``b, i, t, c, z, y, x`` are mapped to
    the generic ``"i"`` axis; every other axis is passed to ``AxisInfo.create``
    unchanged.
    """
    known_axis_letters = ("b", "i", "t", "c", "z", "y", "x")
    infos: List[AxisInfo] = []
    for axis in io_descr.axes:
        is_unknown_letter = isinstance(axis, str) and axis not in known_axis_letters
        infos.append(AxisInfo.create("i" if is_unknown_letter else axis))

    return infos
get a unified, simplified axis representation from spec axes
def get_member_id(
    tensor_description: Union[
        v0_4.InputTensorDescr,
        v0_4.OutputTensorDescr,
        v0_5.InputTensorDescr,
        v0_5.OutputTensorDescr,
    ],
) -> MemberId:
    """Return the normalized tensor ID of a tensor description.

    The result is usable as a sample member ID: v0.4 descriptions contribute
    their ``name`` (wrapped in ``MemberId``), v0.5 descriptions their ``id``.
    """
    v0_4_types = (v0_4.InputTensorDescr, v0_4.OutputTensorDescr)
    if isinstance(tensor_description, v0_4_types):
        return MemberId(tensor_description.name)

    v0_5_types = (v0_5.InputTensorDescr, v0_5.OutputTensorDescr)
    if isinstance(tensor_description, v0_5_types):
        return tensor_description.id

    assert_never(tensor_description)
get the normalized tensor ID, usable as a sample member ID
def get_member_ids(
    tensor_descriptions: Sequence[
        Union[
            v0_4.InputTensorDescr,
            v0_4.OutputTensorDescr,
            v0_5.InputTensorDescr,
            v0_5.OutputTensorDescr,
        ]
    ],
) -> List[MemberId]:
    """Return the normalized tensor ID of every description, preserving order.

    Each ID is computed by ``get_member_id`` and is usable as a sample member ID.
    """
    return list(map(get_member_id, tensor_descriptions))
get normalized tensor IDs to be used as sample member IDs
def get_test_inputs(model: AnyModelDescr) -> Sample:
    """Return the test input sample shipped with `model`.

    v0.4 model descriptions list their test input files in `model.test_inputs`;
    other descriptions attach a `test_tensor` to every input description.
    Each array is wrapped in a `Tensor` with the axes of its input description.
    """
    ids = get_member_ids(model.inputs)
    if isinstance(model, v0_4.ModelDescr):
        loaded = [load_array(source) for source in model.test_inputs]
    else:
        loaded = [load_array(descr.test_tensor) for descr in model.inputs]

    dims_per_member = [get_axes_infos(descr) for descr in model.inputs]
    members = {}
    for member_id, array, dims in zip(ids, loaded, dims_per_member):
        members[member_id] = Tensor.from_numpy(array, dims=dims)

    return Sample(members=members, stat={}, id="test-sample")
returns a model's test input sample
def get_test_outputs(model: AnyModelDescr) -> Sample:
    """Return the test output sample shipped with `model`.

    v0.4 model descriptions list their test output files in `model.test_outputs`;
    other descriptions attach a `test_tensor` to every output description.
    Each array is wrapped in a `Tensor` with the axes of its output description.
    """
    output_descrs = model.outputs
    ids = get_member_ids(output_descrs)

    if isinstance(model, v0_4.ModelDescr):
        loaded = [load_array(source) for source in model.test_outputs]
    else:
        loaded = [load_array(descr.test_tensor) for descr in output_descrs]

    dims_per_member = [get_axes_infos(descr) for descr in output_descrs]
    members = {
        member_id: Tensor.from_numpy(array, dims=dims)
        for member_id, array, dims in zip(ids, loaded, dims_per_member)
    }
    return Sample(members=members, stat={}, id="test-sample")
returns a model's test output sample
IO_SampleBlockMeta(input, output)
def get_input_halo(model: v0_5.ModelDescr, output_halo: PerMember[PerAxis[Halo]]):
    """returns which halo input tensors need to be divided into blocks with, such that
    `output_halo` can be cropped from their outputs without introducing gaps.

    Every output axis listed in `output_halo` must have a `v0_5.SizeReference`
    size. Its total halo is rescaled by `output_axis.scale / referenced_axis.scale`
    and assigned, split evenly into left and right halo, to the *referenced*
    axis. If several output axes reference the same axis, the largest required
    halo is kept.

    Args:
        model: model description (format 0.5)
        output_halo: halo per output member and output axis

    Returns:
        halo per referenced member and referenced axis

    Raises:
        ValueError: if an output axis with halo does not reference another axis,
            or if its rescaled total halo is not an even integer.
    """
    input_halo: Dict[MemberId, Dict[AxisId, Halo]] = {}
    outputs = {t.id: t for t in model.outputs}
    all_tensors = {**{t.id: t for t in model.inputs}, **outputs}

    for t, th in output_halo.items():
        axes = {a.id: a for a in outputs[t].axes}

        for a, ah in th.items():
            axis = axes[a]
            s = axis.size
            if not isinstance(s, v0_5.SizeReference):
                raise ValueError(
                    f"Unable to map output halo for {t}.{a} to an input axis"
                )

            ref_axis = {ra.id: ra for ra in all_tensors[s.tensor_id].axes}[s.axis_id]

            total_output_halo = sum(ah)
            total_input_halo = total_output_halo * axis.scale / ref_axis.scale
            if total_input_halo != int(total_input_halo) or total_input_halo % 2 != 0:
                raise ValueError(
                    f"Output halo {ah} for {t}.{a} maps to a total halo of"
                    + f" {total_input_halo} for {s.tensor_id}.{s.axis_id},"
                    + " which is not an even integer."
                )

            half = int(total_input_halo) // 2
            ref_halos = input_halo.setdefault(s.tensor_id, {})
            previous = ref_halos.get(s.axis_id)
            if previous is not None:
                # another output axis references the same axis: keep the larger halo
                half = max(half, *previous)

            # key by the referenced axis id, not by the output axis id `a`
            ref_halos[s.axis_id] = Halo(half, half)

    return input_halo
Returns the halo with which input tensors need to be divided into blocks, such that `output_halo` can be cropped from their outputs without introducing gaps.
def get_block_transform(
    model: v0_5.ModelDescr,
) -> PerMember[PerAxis[Union[LinearSampleAxisTransform, int]]]:
    """Describe how each output tensor's shape relates to the model's input shapes.

    For every output axis the result holds one of:
    - a ``LinearSampleAxisTransform``: batch axes mirror the first input batch
      axis found; ``SizeReference`` axes scale the referenced axis by
      ``referenced_axis.scale / output_axis.scale`` plus the reference offset,
    - the fixed ``int`` size,
    - ``-1`` for ``DataDependentSize`` axes.

    Raises:
        ValueError: if an output tensor has a batch axis but no input does.
    """
    batch_axis_trf = next(
        (
            LinearSampleAxisTransform(axis=axis.id, scale=1, offset=0, member=ipt.id)
            for ipt in model.inputs
            for axis in ipt.axes
            if axis.type == "batch"
        ),
        None,
    )

    scales_by_tensor = {
        descr.id: {axis.id: axis.scale for axis in descr.axes}
        for descr in chain(model.inputs, model.outputs)
    }

    transforms: Dict[
        MemberId, Dict[AxisId, Union[LinearSampleAxisTransform, int]]
    ] = {}
    for out in model.outputs:
        per_axis: Dict[AxisId, Union[LinearSampleAxisTransform, int]] = {}
        for axis in out.axes:
            size = axis.size
            if size is None:
                assert axis.type == "batch"
                if batch_axis_trf is None:
                    raise ValueError(
                        "no batch axis found in any input tensor, but output tensor"
                        + f" '{out.id}' has one."
                    )
                per_axis[axis.id] = batch_axis_trf
            elif isinstance(size, int):
                per_axis[axis.id] = size
            elif isinstance(size, v0_5.DataDependentSize):
                per_axis[axis.id] = -1
            elif isinstance(size, v0_5.SizeReference):
                ref_scale = scales_by_tensor[size.tensor_id][size.axis_id]
                per_axis[axis.id] = LinearSampleAxisTransform(
                    axis=size.axis_id,
                    scale=ref_scale / axis.scale,
                    offset=size.offset,
                    member=size.tensor_id,
                )
            else:
                assert_never(size)

        transforms[out.id] = per_axis

    return transforms
Returns how a model's output tensor shapes relate to its input shapes.
def get_io_sample_block_metas(
    model: v0_5.ModelDescr,
    input_sample_shape: PerMember[PerAxis[int]],
    ns: Mapping[Tuple[MemberId, AxisId], ParameterizedSize_N],
    batch_size: int = 1,
) -> Tuple[TotalNumberOfBlocks, Iterable[IO_SampleBlockMeta]]:
    """returns an iterable yielding meta data for corresponding input and output samples

    Args:
        model: model description (format 0.5)
        input_sample_shape: shape of each input member of the full sample
        ns: `n` values for parameterized axis sizes, passed to
            `model.get_axis_sizes` to determine the block shape
        batch_size: batch size passed to `model.get_axis_sizes`

    Returns:
        the total number of input blocks and a lazy iterable of
        `IO_SampleBlockMeta`, pairing each input block with its output block

    Raises:
        TypeError: if `model` is not a v0.5 model description
    """
    if not isinstance(model, v0_5.ModelDescr):
        raise TypeError(
            f"get_io_sample_block_metas() not implemented for {type(model)}"
        )

    block_axis_sizes = model.get_axis_sizes(ns=ns, batch_size=batch_size)
    # regroup {(member, axis): size} into {member: {axis: size}} in a single pass
    input_block_shape: Dict[MemberId, Dict[AxisId, int]] = {}
    for (member_id, axis_id), size in block_axis_sizes.inputs.items():
        input_block_shape.setdefault(member_id, {})[axis_id] = size

    # output axes that support a halo get a symmetric halo of `a.halo`
    output_halo = {
        t.id: {
            a.id: Halo(a.halo, a.halo) for a in t.axes if isinstance(a, v0_5.WithHalo)
        }
        for t in model.outputs
    }
    input_halo = get_input_halo(model, output_halo)

    n_input_blocks, input_blocks = split_multiple_shapes_into_blocks(
        input_sample_shape, input_block_shape, halo=input_halo
    )
    block_transform = get_block_transform(model)
    return n_input_blocks, (
        IO_SampleBlockMeta(ipt, ipt.get_transformed(block_transform))
        for ipt in sample_block_meta_generator(
            input_blocks, sample_shape=input_sample_shape, sample_id=None
        )
    )
Returns the total number of blocks together with an iterable yielding metadata for corresponding input and output sample blocks.
def get_tensor(
    src: Union[Tensor, xr.DataArray, NDArray[Any], Path, str],
    ipt: Union[v0_4.InputTensorDescr, v0_5.InputTensorDescr],
) -> Tensor:
    """helper to cast/load various tensor sources

    Args:
        src: the tensor source. A `Tensor` is returned as-is, an
            `xr.DataArray` is converted via `Tensor.from_xarray` (the axes of
            `ipt` are not applied), a numpy array is wrapped with the axes of
            `ipt`, and a file path (`Path` or `str`) is loaded with the axes
            of `ipt`.
        ipt: input tensor description providing the axes

    Returns:
        the source as a `Tensor`
    """
    if isinstance(src, Tensor):
        return src

    if isinstance(src, xr.DataArray):
        return Tensor.from_xarray(src)

    if isinstance(src, np.ndarray):
        return Tensor.from_numpy(src, dims=get_axes_infos(ipt))

    if isinstance(src, (Path, str)):
        # accept plain string paths in addition to `Path` objects
        return load_tensor(Path(src), axes=get_axes_infos(ipt))

    assert_never(src)
helper to cast/load various tensor sources
def create_sample_for_model(
    model: AnyModelDescr,
    *,
    stat: Optional[Stat] = None,
    sample_id: SampleId = None,
    inputs: Optional[
        PerMember[Union[Tensor, xr.DataArray, NDArray[Any], Path]]
    ] = None,  # TODO: make non-optional
    **kwargs: NDArray[Any],  # TODO: deprecate in favor of `inputs`
) -> Sample:
    """Create a sample from a single set of input(s) for a specific bioimage.io model

    Args:
        model: a bioimage.io model description
        stat: dictionary with sample and dataset statistics (may be updated in-place!)
        sample_id: ID assigned to the created sample
        inputs: the input(s) constituting a single sample.
        kwargs: (to be deprecated) inputs given as keyword arguments; entries
            in `inputs` take precedence over them.

    Raises:
        ValueError: if an input does not correspond to any model input, or if a
            non-optional model input is missing.
    """
    merged = dict(kwargs)
    merged.update(inputs or {})
    given = {MemberId(key): value for key, value in merged.items()}

    descr_by_id = {get_member_id(descr): descr for descr in model.inputs}

    unknown = {member for member in given if member not in descr_by_id}
    if unknown:
        raise ValueError(f"Got unexpected inputs: {unknown}")

    missing = {
        member
        for member, descr in descr_by_id.items()
        if member not in given
        and not (isinstance(descr, v0_5.InputTensorDescr) and descr.optional)
    }
    if missing:
        raise ValueError(f"Missing non-optional model inputs: {missing}")

    members = {
        member: get_tensor(given[member], descr)
        for member, descr in descr_by_id.items()
        if member in given
    }
    return Sample(
        members=members,
        stat={} if stat is None else stat,
        id=sample_id,
    )
Create a sample from a single set of input(s) for a specific bioimage.io model
Arguments:
- model: a bioimage.io model description
- stat: dictionary with sample and dataset statistics (may be updated in-place!)
- inputs: the input(s) constituting a single sample.
def load_sample_for_model(
    *,
    model: AnyModelDescr,
    paths: PerMember[Path],
    axes: Optional[PerMember[Sequence[AxisLike]]] = None,
    stat: Optional[Stat] = None,
    sample_id: Optional[SampleId] = None,
) -> Sample:
    """load a single sample from `paths` that can be processed by `model`

    Args:
        model: a bioimage.io model description
        paths: file path to load per input member
        axes: axes per member; members without an entry use the axes of the
            corresponding model input description
        stat: dictionary with sample and dataset statistics
        sample_id: ID of the sample; if `None`, the sorted tuple of the
            `paths` values is used

    Raises:
        ValueError: if `paths` or `axes` refer to members that are not model inputs
    """
    # make sure members are keyed by MemberId, not string
    paths = {MemberId(k): v for k, v in paths.items()}
    axes = {MemberId(k): v for k, v in (axes or {}).items()}

    model_inputs = {get_member_id(d): d for d in model.inputs}

    if unknown := {k for k in paths if k not in model_inputs}:
        raise ValueError(f"Got unexpected paths for {unknown}")

    if unknown := {k for k in axes if k not in model_inputs}:
        raise ValueError(f"Got unexpected axes hints for: {unknown}")

    members: Dict[MemberId, Tensor] = {}
    for m, p in paths.items():
        if m not in axes:
            axes[m] = get_axes_infos(model_inputs[m])
            logger.debug(
                "loading '{}' from {} with default input axes {} ",
                m,
                p,
                axes[m],
            )
        members[m] = load_tensor(p, axes[m])

    return Sample(
        members=members,
        stat={} if stat is None else stat,
        # compare against None so that falsy IDs such as 0 or "" are kept
        id=tuple(sorted(paths.values())) if sample_id is None else sample_id,
    )
Load a single sample from `paths` that can be processed by `model`.