Coverage for src / bioimageio / core / digest_spec.py: 85%
220 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-22 18:38 +0000
1from __future__ import annotations
3import collections.abc
4import importlib.util
5import sys
6from itertools import chain
7from pathlib import Path
8from tempfile import TemporaryDirectory
9from typing import (
10 Any,
11 Callable,
12 Dict,
13 Iterable,
14 List,
15 Mapping,
16 NamedTuple,
17 Optional,
18 Sequence,
19 Tuple,
20 Union,
21)
22from zipfile import ZipFile, is_zipfile
24import numpy as np
25import xarray as xr
26from loguru import logger
27from numpy.typing import NDArray
28from typing_extensions import Unpack, assert_never
30from bioimageio.spec._internal.io import HashKwargs, PermissiveFileSource
31from bioimageio.spec.common import FileDescr, FileSource
32from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5
33from bioimageio.spec.model.v0_4 import CallableFromDepencency, CallableFromFile
34from bioimageio.spec.model.v0_5 import (
35 ArchitectureFromFileDescr,
36 ArchitectureFromLibraryDescr,
37 ParameterizedSize_N,
38)
39from bioimageio.spec.utils import load_array
41from .axis import Axis, AxisId, AxisInfo, AxisLike, PerAxis
42from .block_meta import split_multiple_shapes_into_blocks
43from .common import Halo, MemberId, PerMember, SampleId, TotalNumberOfBlocks
44from .io import load_tensor
45from .sample import (
46 LinearSampleAxisTransform,
47 Sample,
48 SampleBlockMeta,
49 sample_block_meta_generator,
50)
51from .stat_measures import Stat
52from .tensor import Tensor
# any value accepted as a tensor source by `get_tensor`/`create_sample_for_model`:
# an in-memory tensor/array or a file source to load one from
TensorSource = Union[Tensor, xr.DataArray, NDArray[Any], PermissiveFileSource]
def import_callable(
    node: Union[
        ArchitectureFromFileDescr,
        ArchitectureFromLibraryDescr,
        CallableFromDepencency,
        CallableFromFile,
    ],
    /,
    **kwargs: Unpack[HashKwargs],
) -> Callable[..., Any]:
    """Import a callable (e.g. a torch.nn.Module) from a spec node describing it.

    Args:
        node: spec description of where the callable lives (an importable
            module/library or a source file) and under what name.
        **kwargs: hash kwargs forwarded when importing from a file source.

    Raises:
        ValueError: if the imported object is not callable.
    """
    if isinstance(node, CallableFromDepencency):
        imported = getattr(
            importlib.import_module(node.module_name), str(node.callable_name)
        )
    elif isinstance(node, ArchitectureFromLibraryDescr):
        imported = getattr(
            importlib.import_module(node.import_from), str(node.callable)
        )
    elif isinstance(node, CallableFromFile):
        imported = _import_from_file_impl(
            node.source_file, str(node.callable_name), **kwargs
        )
    elif isinstance(node, ArchitectureFromFileDescr):
        imported = _import_from_file_impl(node.source, str(node.callable), sha256=node.sha256)
    else:
        assert_never(node)

    if not callable(imported):
        raise ValueError(f"{node} (imported: {imported}) is not callable")

    return imported
tmp_dirs_in_use: List[TemporaryDirectory[str]] = []
"""Global references to temporary directories created during import
(`_import_from_file_impl`) to delay their cleanup until interpreter exit."""
def _import_from_file_impl(
    source: FileSource, callable_name: str, **kwargs: Unpack[HashKwargs]
) -> Any:
    """Import `callable_name` from the Python file or zip archive at `source`.

    The file content is hashed; the hash is part of the generated module name,
    so the same source is imported (and cached in `sys.modules`) only once,
    while changed content yields a fresh module.

    Args:
        source: file source of a Python module (``.py`` file or zip archive).
        callable_name: attribute to retrieve from the imported module.
        **kwargs: hash kwargs (e.g. an expected sha256) for the file source.

    Raises:
        ImportError: if the module cannot be imported.
        AttributeError: if the imported module has no `callable_name` attribute.
    """
    src_descr = FileDescr(source=source, **kwargs)
    # ensure sha is valid even if perform_io_checks=False
    # or the source has changed since last sha computation
    src_descr.validate_sha256(force_recompute=True)
    assert src_descr.sha256 is not None
    source_sha = src_descr.sha256

    reader = src_descr.get_reader()
    # make sure we have unique module name
    module_name = f"{reader.original_file_name.split('.')[0]}_{source_sha}"

    # make sure we have a unique and valid module name
    if not module_name.isidentifier():
        # file name stem was not a valid identifier; fall back to a generic name
        module_name = f"custom_module_{source_sha}"
        assert module_name.isidentifier(), module_name

    source_bytes = reader.read()

    module = sys.modules.get(module_name)
    if module is None:
        try:
            td_kwargs: Dict[str, Any] = (
                dict(ignore_cleanup_errors=True) if sys.version_info >= (3, 10) else {}
            )
            if sys.version_info >= (3, 12):
                td_kwargs["delete"] = False

            tmp_dir = TemporaryDirectory(**td_kwargs)
            # keep global ref to tmp_dir to delay cleanup until program exit
            # TODO: remove for py >= 3.12, when delete=False works
            tmp_dirs_in_use.append(tmp_dir)

            module_path = Path(tmp_dir.name) / module_name
            if reader.original_file_name.endswith(".zip") or is_zipfile(reader):
                # zip archives are extracted to a package directory
                module_path.mkdir()
                ZipFile(reader).extractall(path=module_path)
            else:
                # plain python source is written to a single .py file
                module_path = module_path.with_suffix(".py")
                _ = module_path.write_bytes(source_bytes)

            importlib_spec = importlib.util.spec_from_file_location(
                module_name, str(module_path)
            )

            if importlib_spec is None:
                raise ImportError(f"Failed to import {source}")

            module = importlib.util.module_from_spec(importlib_spec)

            sys.modules[module_name] = module  # cache this module

            assert importlib_spec.loader is not None
            importlib_spec.loader.exec_module(module)

        except Exception as e:
            # do not leave a half-initialized module behind in the cache
            if module_name in sys.modules:
                del sys.modules[module_name]

            raise ImportError(f"Failed to import {source}") from e

    try:
        callable_attr = getattr(module, callable_name)
    except AttributeError as e:
        raise AttributeError(
            f"Imported custom module from {source} has no `{callable_name}` attribute."
        ) from e
    except Exception as e:
        raise AttributeError(
            f"Failed to access `{callable_name}` attribute from custom module imported from {source} ."
        ) from e

    else:
        return callable_attr
def get_axes_infos(
    io_descr: Union[
        v0_4.InputTensorDescr,
        v0_4.OutputTensorDescr,
        v0_5.InputTensorDescr,
        v0_5.OutputTensorDescr,
    ],
) -> List[AxisInfo]:
    """Get a unified, simplified axis representation from spec axes."""

    def _as_info(axis: Any) -> AxisInfo:
        # v0.5 axes are objects; v0.4 axes are single-letter strings
        if isinstance(axis, v0_5.AxisBase):
            return AxisInfo.create(Axis(id=axis.id, type=axis.type))
        assert axis in ("b", "i", "t", "c", "z", "y", "x")
        return AxisInfo.create(axis)

    return [_as_info(a) for a in io_descr.axes]
def get_member_id(
    tensor_description: Union[
        v0_4.InputTensorDescr,
        v0_4.OutputTensorDescr,
        v0_5.InputTensorDescr,
        v0_5.OutputTensorDescr,
    ],
) -> MemberId:
    """Get the normalized tensor ID, usable as a sample member ID."""

    # v0.5 descriptors already carry a proper id
    if isinstance(tensor_description, (v0_5.InputTensorDescr, v0_5.OutputTensorDescr)):
        return tensor_description.id

    # v0.4 descriptors only have a free-form name; normalize it
    if isinstance(tensor_description, (v0_4.InputTensorDescr, v0_4.OutputTensorDescr)):
        return MemberId(tensor_description.name)

    assert_never(tensor_description)
def get_member_ids(
    tensor_descriptions: Sequence[
        Union[
            v0_4.InputTensorDescr,
            v0_4.OutputTensorDescr,
            v0_5.InputTensorDescr,
            v0_5.OutputTensorDescr,
        ]
    ],
) -> List[MemberId]:
    """Get normalized tensor IDs to be used as sample member IDs."""
    return list(map(get_member_id, tensor_descriptions))
def get_test_input_sample(model: AnyModelDescr) -> Sample:
    """returns a model's test input sample"""
    if isinstance(model, v0_4.ModelDescr):
        # v0.4: test tensors are separate file sources
        test_sources = model.test_inputs
    else:
        # v0.5: each input descriptor carries its own test tensor
        test_sources = model.inputs
    return _get_test_sample(model.inputs, test_sources)
get_test_inputs = get_test_input_sample
"""DEPRECATED alias; use `get_test_input_sample` instead"""
def get_test_output_sample(model: AnyModelDescr) -> Sample:
    """returns a model's test output sample"""
    if isinstance(model, v0_4.ModelDescr):
        # v0.4: test tensors are separate file sources
        test_sources = model.test_outputs
    else:
        # v0.5: each output descriptor carries its own test tensor
        test_sources = model.outputs
    return _get_test_sample(model.outputs, test_sources)
get_test_outputs = get_test_output_sample
"""DEPRECATED alias; use `get_test_output_sample` instead"""
def _get_test_sample(
    tensor_descrs: Sequence[
        Union[
            v0_4.InputTensorDescr,
            v0_4.OutputTensorDescr,
            v0_5.InputTensorDescr,
            v0_5.OutputTensorDescr,
        ]
    ],
    test_sources: Sequence[Union[FileSource, v0_5.TensorDescr]],
) -> Sample:
    """Assemble a model's input/output test sample from its test tensor sources."""

    def _load(src: Union[FileSource, v0_5.TensorDescr]) -> NDArray[Any]:
        # v0.5 descriptors reference their test tensor; anything else is a file source
        if not isinstance(src, (v0_5.InputTensorDescr, v0_5.OutputTensorDescr)):
            return load_array(src)
        if src.test_tensor is None:
            raise ValueError(
                f"Model input '{src.id}' has no test tensor defined, cannot create test sample."
            )
        return load_array(src.test_tensor)

    members: Dict[MemberId, Tensor] = {}
    for member_id, descr, src in zip(
        get_member_ids(tensor_descrs), tensor_descrs, test_sources
    ):
        members[member_id] = Tensor.from_numpy(_load(src), dims=get_axes_infos(descr))

    return Sample(members=members, stat={}, id="test-sample")
class IO_SampleBlockMeta(NamedTuple):
    """Pair of corresponding input and output sample block meta data."""

    input: SampleBlockMeta
    output: SampleBlockMeta
def get_input_halo(model: v0_5.ModelDescr, output_halo: PerMember[PerAxis[Halo]]):
    """Derive which halo input tensors need to be divided into blocks with, such
    that `output_halo` can be cropped from their outputs without introducing gaps."""
    result: Dict[MemberId, Dict[AxisId, Halo]] = {}
    output_descrs = {t.id: t for t in model.outputs}
    tensor_descrs: Dict[Any, Any] = {t.id: t for t in model.inputs}
    tensor_descrs.update(output_descrs)

    for member, tensor_halo in output_halo.items():
        axis_descrs = {a.id: a for a in output_descrs[member].axes}
        for axis_id, halo in tensor_halo.items():
            out_axis = axis_descrs[axis_id]
            size = out_axis.size
            # the output axis must reference an input axis to map the halo back
            if not isinstance(size, v0_5.SizeReference):
                raise ValueError(
                    f"Unable to map output halo for {member}.{axis_id} to an input axis"
                )

            ref_axis = {a.id: a for a in tensor_descrs[size.tensor_id].axes}[
                size.axis_id
            ]

            # rescale the halo from output-axis units to referenced-axis units
            left = halo.left * out_axis.scale / ref_axis.scale
            right = halo.right * out_axis.scale / ref_axis.scale
            assert left == int(left), f"{left} not int"
            assert right == int(right), f"{right} not int"

            result.setdefault(size.tensor_id, {})[axis_id] = Halo(
                int(left), int(right)
            )

    return result
def get_block_transform(
    model: v0_5.ModelDescr,
) -> PerMember[PerAxis[Union[LinearSampleAxisTransform, int]]]:
    """Describe how a model's output tensor shapes relate to its input shapes."""
    # first batch axis among the inputs (if any); output batch axes map onto it
    batch_axis_trf = next(
        (
            LinearSampleAxisTransform(axis=a.id, scale=1, offset=0, member=ipt.id)
            for ipt in model.inputs
            for a in ipt.axes
            if a.type == "batch"
        ),
        None,
    )
    axis_scales = {
        t.id: {a.id: a.scale for a in t.axes}
        for t in chain(model.inputs, model.outputs)
    }
    transform: Dict[MemberId, Dict[AxisId, Union[LinearSampleAxisTransform, int]]] = {}
    for out in model.outputs:
        per_axis: Dict[AxisId, Union[LinearSampleAxisTransform, int]] = {}
        for a in out.axes:
            size = a.size
            if size is None:
                # only batch axes may be sizeless
                assert a.type == "batch"
                if batch_axis_trf is None:
                    raise ValueError(
                        "no batch axis found in any input tensor, but output tensor"
                        + f" '{out.id}' has one."
                    )
                per_axis[a.id] = batch_axis_trf
            elif isinstance(size, int):
                per_axis[a.id] = size
            elif isinstance(size, v0_5.DataDependentSize):
                per_axis[a.id] = -1  # unknown until computed
            elif isinstance(size, v0_5.SizeReference):
                per_axis[a.id] = LinearSampleAxisTransform(
                    axis=size.axis_id,
                    scale=axis_scales[size.tensor_id][size.axis_id] / a.scale,
                    offset=size.offset,
                    member=size.tensor_id,
                )
            else:
                assert_never(size)

        transform[out.id] = per_axis

    return transform
def get_io_sample_block_metas(
    model: v0_5.ModelDescr,
    input_sample_shape: PerMember[PerAxis[int]],
    ns: Mapping[Tuple[MemberId, AxisId], ParameterizedSize_N],
    batch_size: int = 1,
) -> Tuple[TotalNumberOfBlocks, Iterable[IO_SampleBlockMeta]]:
    """Returns the total number of input blocks and an iterable yielding meta data
    for corresponding input and output sample blocks.

    Args:
        model: model description; only v0.5 models are supported.
        input_sample_shape: per-member, per-axis shape of the input sample.
        ns: number of parameter steps per parameterized (member, axis) size.
        batch_size: size of the batch axis in each block.

    Raises:
        TypeError: if `model` is not a `v0_5.ModelDescr`.
    """
    if not isinstance(model, v0_5.ModelDescr):
        # fixed: error message used to reference a nonexistent `get_block_meta()`
        raise TypeError(
            f"get_io_sample_block_metas() not implemented for {type(model)}"
        )

    block_axis_sizes = model.get_axis_sizes(ns=ns, batch_size=batch_size)
    input_block_shape = {
        t: {aa: s for (tt, aa), s in block_axis_sizes.inputs.items() if tt == t}
        for t in {tt for tt, _ in block_axis_sizes.inputs}
    }
    # every output axis with a halo must be cropped; derive the matching input halo
    output_halo = {
        t.id: {
            a.id: Halo(a.halo, a.halo) for a in t.axes if isinstance(a, v0_5.WithHalo)
        }
        for t in model.outputs
    }
    input_halo = get_input_halo(model, output_halo)

    n_input_blocks, input_blocks = split_multiple_shapes_into_blocks(
        input_sample_shape, input_block_shape, halo=input_halo
    )
    block_transform = get_block_transform(model)
    # lazily pair each input block meta with its transformed output block meta
    return n_input_blocks, (
        IO_SampleBlockMeta(ipt, ipt.get_transformed(block_transform))
        for ipt in sample_block_meta_generator(
            input_blocks, sample_shape=input_sample_shape, sample_id=None
        )
    )
def get_tensor(
    src: TensorSource,
    descr: Union[
        v0_4.InputTensorDescr,
        v0_5.InputTensorDescr,
        v0_4.OutputTensorDescr,
        v0_5.OutputTensorDescr,
    Sequence[AxisInfo],
    ],
):
    """Cast or load a tensor source into a `Tensor` laid out per `descr`."""

    descr_types = (
        v0_4.InputTensorDescr,
        v0_5.InputTensorDescr,
        v0_4.OutputTensorDescr,
        v0_5.OutputTensorDescr,
    )
    # `descr` is either a spec tensor description or already a sequence of AxisInfo
    axes = get_axes_infos(descr) if isinstance(descr, descr_types) else descr

    if isinstance(src, Tensor):
        tensor = src
    elif isinstance(src, xr.DataArray):
        tensor = Tensor.from_xarray(src)
    elif isinstance(src, np.ndarray):
        return Tensor.from_numpy(src, dims=axes)
    else:
        return load_tensor(src, axes=axes)

    # in-memory tensors are reordered to the described axis order
    return tensor.transpose(axes=[a.id for a in axes])
def create_sample_for_model(
    model: AnyModelDescr,
    *,
    stat: Optional[Stat] = None,
    sample_id: SampleId = None,
    inputs: Union[PerMember[TensorSource], TensorSource],
) -> Sample:
    """Create a sample from a single set of input(s) for a specific bioimage.io model

    Args:
        model: a bioimage.io model description
        stat: dictionary with sample and dataset statistics (may be updated in-place!)
        sample_id: identifier assigned to the created sample
        inputs: the input(s) constituting a single sample.

    Raises:
        TypeError: if `inputs` is a single tensor source but the model has
            more than one input.
        ValueError: on unexpected or missing (non-optional) inputs.
    """
    model_inputs = {get_member_id(d): d for d in model.inputs}

    # normalize `inputs` to a mapping keyed by MemberId
    if isinstance(inputs, collections.abc.Mapping):
        normalized = {MemberId(k): v for k, v in inputs.items()}
    elif len(model_inputs) == 1:
        # a bare tensor source is accepted for single-input models
        normalized = {next(iter(model_inputs)): inputs}
    else:
        raise TypeError(
            f"Expected `inputs` to be a mapping with keys {tuple(model_inputs)}"
        )

    unknown = {k for k in normalized if k not in model_inputs}
    if unknown:
        raise ValueError(f"Got unexpected inputs: {unknown}")

    missing = {
        k
        for k, v in model_inputs.items()
        if k not in normalized
        and not (isinstance(v, v0_5.InputTensorDescr) and v.optional)
    }
    if missing:
        raise ValueError(f"Missing non-optional model inputs: {missing}")

    members = {
        m: get_tensor(normalized[m], ipt)
        for m, ipt in model_inputs.items()
        if m in normalized
    }
    return Sample(
        members=members,
        stat=stat if stat is not None else {},
        id=sample_id,
    )
def load_sample_for_model(
    *,
    model: AnyModelDescr,
    paths: PerMember[Path],
    axes: Optional[PerMember[Sequence[AxisLike]]] = None,
    stat: Optional[Stat] = None,
    sample_id: Optional[SampleId] = None,
):
    """load a single sample from `paths` that can be processed by `model`"""

    # normalize keys to MemberId (callers may pass plain strings)
    member_paths = {MemberId(k): v for k, v in paths.items()}
    member_axes = {MemberId(k): v for k, v in (axes or {}).items()}

    model_inputs = {get_member_id(d): d for d in model.inputs}

    unexpected_paths = {k for k in member_paths if k not in model_inputs}
    if unexpected_paths:
        raise ValueError(f"Got unexpected paths for {unexpected_paths}")

    unexpected_axes = {k for k in member_axes if k not in model_inputs}
    if unexpected_axes:
        raise ValueError(f"Got unexpected axes hints for: {unexpected_axes}")

    members: Dict[MemberId, Tensor] = {}
    for member, path in member_paths.items():
        if member not in member_axes:
            # no axes hint given: fall back to the model's declared input axes
            member_axes[member] = get_axes_infos(model_inputs[member])
            logger.debug(
                "loading '{}' from {} with default input axes {} ",
                member,
                path,
                member_axes[member],
            )
        members[member] = load_tensor(path, member_axes[member])

    return Sample(
        members=members,
        stat={} if stat is None else stat,
        id=sample_id or tuple(sorted(member_paths.values())),
    )