Coverage for src/bioimageio/core/digest_spec.py: 79%
217 statements
« prev ^ index » next coverage.py v7.11.3, created at 2025-11-13 11:02 +0000
« prev ^ index » next coverage.py v7.11.3, created at 2025-11-13 11:02 +0000
1from __future__ import annotations
3import collections.abc
4import importlib.util
5import sys
6from itertools import chain
7from pathlib import Path
8from tempfile import TemporaryDirectory
9from typing import (
10 Any,
11 Callable,
12 Dict,
13 Iterable,
14 List,
15 Mapping,
16 NamedTuple,
17 Optional,
18 Sequence,
19 Tuple,
20 Union,
21)
22from zipfile import ZipFile, is_zipfile
24import numpy as np
25import xarray as xr
26from bioimageio.spec._internal.io import HashKwargs
27from bioimageio.spec.common import FileDescr, FileSource, ZipPath
28from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5
29from bioimageio.spec.model.v0_4 import CallableFromDepencency, CallableFromFile
30from bioimageio.spec.model.v0_5 import (
31 ArchitectureFromFileDescr,
32 ArchitectureFromLibraryDescr,
33 ParameterizedSize_N,
34)
35from bioimageio.spec.utils import load_array
36from loguru import logger
37from numpy.typing import NDArray
38from typing_extensions import Unpack, assert_never
40from .axis import Axis, AxisId, AxisInfo, AxisLike, PerAxis
41from .block_meta import split_multiple_shapes_into_blocks
42from .common import Halo, MemberId, PerMember, SampleId, TotalNumberOfBlocks
43from .io import load_tensor
44from .sample import (
45 LinearSampleAxisTransform,
46 Sample,
47 SampleBlockMeta,
48 sample_block_meta_generator,
49)
50from .stat_measures import Stat
51from .tensor import Tensor
# sources `get_tensor` accepts: an in-memory tensor/array or a path to load from
TensorSource = Union[
    Tensor,
    xr.DataArray,
    NDArray[Any],
    Path,
]
def import_callable(
    node: Union[
        ArchitectureFromFileDescr,
        ArchitectureFromLibraryDescr,
        CallableFromDepencency,
        CallableFromFile,
    ],
    /,
    **kwargs: Unpack[HashKwargs],
) -> Callable[..., Any]:
    """import the callable (e.g. a torch.nn.Module) that the spec node `node` describes

    Args:
        node: description of a callable living in an importable module
            or in a source file.
        **kwargs: hash arguments; only forwarded for `CallableFromFile` nodes
            (an `ArchitectureFromFileDescr` carries its own `sha256`).

    Raises:
        ValueError: if the imported object is not callable.
    """
    if isinstance(node, CallableFromDepencency):
        imported = getattr(
            importlib.import_module(node.module_name), str(node.callable_name)
        )
    elif isinstance(node, ArchitectureFromLibraryDescr):
        imported = getattr(
            importlib.import_module(node.import_from), str(node.callable)
        )
    elif isinstance(node, CallableFromFile):
        imported = _import_from_file_impl(
            node.source_file, str(node.callable_name), **kwargs
        )
    elif isinstance(node, ArchitectureFromFileDescr):
        imported = _import_from_file_impl(
            node.source, str(node.callable), sha256=node.sha256
        )
    else:
        assert_never(node)

    if callable(imported):
        return imported

    raise ValueError(f"{node} (imported: {imported}) is not callable")
tmp_dirs_in_use: List[TemporaryDirectory[str]] = []
"""keep global reference to temporary directories created during import to delay cleanup"""


def _import_from_file_impl(
    source: FileSource, callable_name: str, **kwargs: Unpack[HashKwargs]
):
    """import the attribute `callable_name` from the custom module at `source`

    The source is written (or, for zip archives, extracted) to a temporary
    directory and imported under a module name derived from its file name and
    SHA-256, so that identical sources are imported only once (the module is
    cached in `sys.modules`).

    Args:
        source: a single Python source file or a zip archive. A zip archive is
            imported as a package, i.e. it is expected to contain an
            `__init__.py` at its root.
        callable_name: name of the attribute to fetch from the imported module.
        **kwargs: hash arguments forwarded to `FileDescr`.

    Returns:
        The attribute `callable_name` of the imported module.

    Raises:
        ImportError: if the module cannot be imported.
        AttributeError: if the imported module has no (accessible) attribute
            `callable_name`.
    """
    src_descr = FileDescr(source=source, **kwargs)
    # ensure sha is valid even if perform_io_checks=False
    # or the source has changed since last sha computation
    src_descr.validate_sha256(force_recompute=True)
    assert src_descr.sha256 is not None
    source_sha = src_descr.sha256

    reader = src_descr.get_reader()
    # make sure we have unique module name
    module_name = f"{reader.original_file_name.split('.')[0]}_{source_sha}"

    # make sure we have a unique and valid module name
    if not module_name.isidentifier():
        module_name = f"custom_module_{source_sha}"
        assert module_name.isidentifier(), module_name

    source_bytes = reader.read()

    module = sys.modules.get(module_name)
    if module is None:
        try:
            td_kwargs: Dict[str, Any] = (
                dict(ignore_cleanup_errors=True) if sys.version_info >= (3, 10) else {}
            )
            if sys.version_info >= (3, 12):
                td_kwargs["delete"] = False

            tmp_dir = TemporaryDirectory(**td_kwargs)
            # keep global ref to tmp_dir to delay cleanup until program exit
            # TODO: remove for py >= 3.12, when delete=False works
            tmp_dirs_in_use.append(tmp_dir)

            module_path = Path(tmp_dir.name) / module_name
            if reader.original_file_name.endswith(".zip") or is_zipfile(reader):
                module_path.mkdir()
                with ZipFile(reader) as zip_file:
                    zip_file.extractall(path=module_path)

                # A bare directory has no recognized module suffix, for which
                # `spec_from_file_location` returns None. Import the extracted
                # archive as a package through its `__init__.py` instead.
                spec_location = module_path / "__init__.py"
            else:
                spec_location = module_path.with_suffix(".py")
                _ = spec_location.write_bytes(source_bytes)

            importlib_spec = importlib.util.spec_from_file_location(
                module_name, str(spec_location)
            )

            if importlib_spec is None:
                raise ImportError(f"Failed to import {source}")

            module = importlib.util.module_from_spec(importlib_spec)

            sys.modules[module_name] = module  # cache this module

            assert importlib_spec.loader is not None
            importlib_spec.loader.exec_module(module)

        except Exception as e:
            # do not leave a partially initialized module in the cache
            if module_name in sys.modules:
                del sys.modules[module_name]

            raise ImportError(f"Failed to import {source}") from e

    try:
        callable_attr = getattr(module, callable_name)
    except AttributeError as e:
        raise AttributeError(
            f"Imported custom module from {source} has no `{callable_name}` attribute."
        ) from e
    except Exception as e:
        raise AttributeError(
            f"Failed to access `{callable_name}` attribute from custom module imported from {source} ."
        ) from e

    else:
        return callable_attr
def get_axes_infos(
    io_descr: Union[
        v0_4.InputTensorDescr,
        v0_4.OutputTensorDescr,
        v0_5.InputTensorDescr,
        v0_5.OutputTensorDescr,
    ],
) -> List[AxisInfo]:
    """get a unified, simplified axis representation from spec axes

    Axes derived from `v0_5.AxisBase` are converted via their `id` and `type`;
    any other axis is expected to be one of the single letter axes
    "b", "i", "t", "c", "z", "y" or "x".
    """

    def _as_axis_info(spec_axis: Any) -> AxisInfo:
        if isinstance(spec_axis, v0_5.AxisBase):
            return AxisInfo.create(Axis(id=spec_axis.id, type=spec_axis.type))

        assert spec_axis in ("b", "i", "t", "c", "z", "y", "x")
        return AxisInfo.create(spec_axis)

    return [_as_axis_info(spec_axis) for spec_axis in io_descr.axes]
def get_member_id(
    tensor_description: Union[
        v0_4.InputTensorDescr,
        v0_4.OutputTensorDescr,
        v0_5.InputTensorDescr,
        v0_5.OutputTensorDescr,
    ],
) -> MemberId:
    """get the normalized tensor ID, usable as a sample member ID

    v0_4 tensor descriptions are identified by their `name` (wrapped in
    `MemberId`), v0_5 tensor descriptions by their `id`.
    """
    v0_4_descrs = (v0_4.InputTensorDescr, v0_4.OutputTensorDescr)
    v0_5_descrs = (v0_5.InputTensorDescr, v0_5.OutputTensorDescr)

    if isinstance(tensor_description, v0_4_descrs):
        return MemberId(tensor_description.name)

    if isinstance(tensor_description, v0_5_descrs):
        return tensor_description.id

    assert_never(tensor_description)
def get_member_ids(
    tensor_descriptions: Sequence[
        Union[
            v0_4.InputTensorDescr,
            v0_4.OutputTensorDescr,
            v0_5.InputTensorDescr,
            v0_5.OutputTensorDescr,
        ]
    ],
) -> List[MemberId]:
    """get normalized tensor IDs to be used as sample member IDs

    The returned list has one ID per tensor description, in the given order.
    """
    return list(map(get_member_id, tensor_descriptions))
def get_test_input_sample(model: AnyModelDescr) -> Sample:
    """returns a model's test input sample"""
    if isinstance(model, v0_4.ModelDescr):
        # v0_4 models list their test inputs separately
        test_sources = model.test_inputs
    else:
        # otherwise the input descriptions themselves serve as test sources
        test_sources = model.inputs

    return _get_test_sample(model.inputs, test_sources)


get_test_inputs = get_test_input_sample
"""DEPRECATED: use `get_test_input_sample` instead"""
def get_test_output_sample(model: AnyModelDescr) -> Sample:
    """returns a model's test output sample"""
    return _get_test_sample(
        model.outputs,
        # v0_4 models list their test outputs separately; otherwise the
        # output descriptions themselves serve as test sources
        model.test_outputs if isinstance(model, v0_4.ModelDescr) else model.outputs,
    )


get_test_outputs = get_test_output_sample
"""DEPRECATED: use `get_test_output_sample` instead"""
def _get_test_sample(
    tensor_descrs: Sequence[
        Union[
            v0_4.InputTensorDescr,
            v0_4.OutputTensorDescr,
            v0_5.InputTensorDescr,
            v0_5.OutputTensorDescr,
        ]
    ],
    test_sources: Sequence[Union[FileSource, v0_5.TensorDescr]],
) -> Sample:
    """returns a model's input/output test sample

    Args:
        tensor_descrs: the model's input (or output) tensor descriptions;
            they determine member IDs and axes of the sample.
        test_sources: where to load the test arrays from: either file sources
            or v0_5 tensor descriptions, whose `test_tensor` is loaded.

    Raises:
        ValueError: if a v0_5 tensor description has no test tensor.
    """
    member_ids = get_member_ids(tensor_descrs)
    arrays: List[NDArray[Any]] = []
    for src in test_sources:
        if isinstance(src, (v0_5.InputTensorDescr, v0_5.OutputTensorDescr)):
            if src.test_tensor is None:
                # this helper serves inputs and outputs alike; name the right kind
                kind = "input" if isinstance(src, v0_5.InputTensorDescr) else "output"
                raise ValueError(
                    f"Model {kind} '{src.id}' has no test tensor defined, cannot create test sample."
                )
            arrays.append(load_array(src.test_tensor))
        else:
            arrays.append(load_array(src))

    axes = [get_axes_infos(t) for t in tensor_descrs]
    return Sample(
        members={
            m: Tensor.from_numpy(arr, dims=ax)
            for m, arr, ax in zip(member_ids, arrays, axes)
        },
        stat={},
        id="test-sample",
    )
class IO_SampleBlockMeta(NamedTuple):
    """pair of corresponding input and output sample block meta data"""

    # meta data of the input sample block
    input: SampleBlockMeta
    # meta data of the matching output sample block
    output: SampleBlockMeta
def get_input_halo(model: v0_5.ModelDescr, output_halo: PerMember[PerAxis[Halo]]):
    """returns which halo input tensors need to be divided into blocks with, such that
    `output_halo` can be cropped from their outputs without introducing gaps.

    Args:
        model: the model description providing tensor axes and their scales.
        output_halo: halo per output tensor and output axis.

    Returns:
        The halo per referenced tensor and referenced axis, i.e. keyed by the
        `tensor_id` and `axis_id` of the output axis' size reference.

    Raises:
        ValueError: if the size of an output axis with halo is not a
            `v0_5.SizeReference`.
    """
    input_halo: Dict[MemberId, Dict[AxisId, Halo]] = {}
    outputs = {t.id: t for t in model.outputs}
    all_tensors = {**{t.id: t for t in model.inputs}, **outputs}

    for t, th in output_halo.items():
        axes = {a.id: a for a in outputs[t].axes}

        for a, ah in th.items():
            axis = axes[a]
            s = axis.size
            if not isinstance(s, v0_5.SizeReference):
                raise ValueError(
                    f"Unable to map output halo for {t}.{a} to an input axis"
                )

            ref_axis = {ra.id: ra for ra in all_tensors[s.tensor_id].axes}[s.axis_id]

            # convert the halo from output axis units to referenced axis units
            input_halo_left = ah.left * axis.scale / ref_axis.scale
            input_halo_right = ah.right * axis.scale / ref_axis.scale
            assert input_halo_left == int(input_halo_left), f"{input_halo_left} not int"
            assert input_halo_right == int(input_halo_right), (
                f"{input_halo_right} not int"
            )

            # key by the *referenced* axis: it may be named differently
            # than the output axis `a`
            input_halo.setdefault(s.tensor_id, {})[s.axis_id] = Halo(
                int(input_halo_left), int(input_halo_right)
            )

    return input_halo
def get_block_transform(
    model: v0_5.ModelDescr,
) -> PerMember[PerAxis[Union[LinearSampleAxisTransform, int]]]:
    """returns how a model's output tensor shapes relates to its input shapes

    For every output tensor axis the result holds either

    - the transform of the first input batch axis (for axes without size),
    - the fixed integer size,
    - `-1` for a data dependent size, or
    - a linear transform of the axis referenced by a size reference.

    Raises:
        ValueError: if an output has an axis without size, but no input has a
            batch axis.
    """
    # transform of the first batch axis found in any input (None if there is none)
    batch_axis_trf = next(
        (
            LinearSampleAxisTransform(
                axis=axis.id, scale=1, offset=0, member=ipt.id
            )
            for ipt in model.inputs
            for axis in ipt.axes
            if axis.type == "batch"
        ),
        None,
    )

    axis_scales = {
        tensor.id: {axis.id: axis.scale for axis in tensor.axes}
        for tensor in chain(model.inputs, model.outputs)
    }

    def _describe_axis(
        out_id: Any, axis: Any
    ) -> Union[LinearSampleAxisTransform, int]:
        size = axis.size
        if size is None:
            assert axis.type == "batch"
            if batch_axis_trf is None:
                raise ValueError(
                    "no batch axis found in any input tensor, but output tensor"
                    + f" '{out_id}' has one."
                )
            return batch_axis_trf

        if isinstance(size, int):
            return size

        if isinstance(size, v0_5.DataDependentSize):
            return -1

        if isinstance(size, v0_5.SizeReference):
            return LinearSampleAxisTransform(
                axis=size.axis_id,
                scale=axis_scales[size.tensor_id][size.axis_id] / axis.scale,
                offset=size.offset,
                member=size.tensor_id,
            )

        assert_never(size)

    return {
        out.id: {axis.id: _describe_axis(out.id, axis) for axis in out.axes}
        for out in model.outputs
    }
def get_io_sample_block_metas(
    model: v0_5.ModelDescr,
    input_sample_shape: PerMember[PerAxis[int]],
    ns: Mapping[Tuple[MemberId, AxisId], ParameterizedSize_N],
    batch_size: int = 1,
) -> Tuple[TotalNumberOfBlocks, Iterable[IO_SampleBlockMeta]]:
    """returns an iterable yielding meta data for corresponding input and output samples

    Args:
        model: a v0_5 model description.
        input_sample_shape: shape of the input sample to split into blocks.
        ns: forwarded to `model.get_axis_sizes` to determine the block shape.
        batch_size: forwarded to `model.get_axis_sizes`.

    Returns:
        The total number of blocks and a lazy iterable of input/output block
        meta data pairs.

    Raises:
        TypeError: if `model` is not a `v0_5.ModelDescr`.
    """
    if not isinstance(model, v0_5.ModelDescr):
        raise TypeError(f"get_block_meta() not implemented for {type(model)}")

    block_axis_sizes = model.get_axis_sizes(ns=ns, batch_size=batch_size)

    # regroup the flat {(member, axis): size} mapping by member
    input_block_shape: Dict[MemberId, Dict[AxisId, int]] = {}
    for (member_id, axis_id), size in block_axis_sizes.inputs.items():
        input_block_shape.setdefault(member_id, {})[axis_id] = size

    # symmetric halo for every output axis that declares one
    output_halo = {
        out.id: {
            axis.id: Halo(axis.halo, axis.halo)
            for axis in out.axes
            if isinstance(axis, v0_5.WithHalo)
        }
        for out in model.outputs
    }
    input_halo = get_input_halo(model, output_halo)

    n_input_blocks, input_blocks = split_multiple_shapes_into_blocks(
        input_sample_shape, input_block_shape, halo=input_halo
    )
    block_transform = get_block_transform(model)

    input_metas = sample_block_meta_generator(
        input_blocks, sample_shape=input_sample_shape, sample_id=None
    )
    io_metas = (
        IO_SampleBlockMeta(meta, meta.get_transformed(block_transform))
        for meta in input_metas
    )
    return n_input_blocks, io_metas
def get_tensor(
    src: Union[ZipPath, TensorSource],
    ipt: Union[v0_4.InputTensorDescr, v0_5.InputTensorDescr],
):
    """helper to cast/load various tensor sources

    Args:
        src: a `Tensor` (returned as is), an `xr.DataArray`, a numpy array
            or anything else, which is passed on to `load_tensor`.
        ipt: input tensor description; provides the axes for numpy arrays and
            sources that are loaded.
    """
    if isinstance(src, Tensor):
        return src

    if isinstance(src, xr.DataArray):
        return Tensor.from_xarray(src)

    if isinstance(src, np.ndarray):
        return Tensor.from_numpy(src, dims=get_axes_infos(ipt))

    return load_tensor(src, axes=get_axes_infos(ipt))
def create_sample_for_model(
    model: AnyModelDescr,
    *,
    stat: Optional[Stat] = None,
    sample_id: SampleId = None,
    inputs: Union[PerMember[TensorSource], TensorSource],
) -> Sample:
    """Create a sample from a single set of input(s) for a specific bioimage.io model

    Args:
        model: a bioimage.io model description
        stat: dictionary with sample and dataset statistics (may be updated in-place!)
        sample_id: ID of the created sample
        inputs: the input(s) constituting a single sample.
            A single, unkeyed input is only accepted for models with one input.

    Raises:
        TypeError: if `inputs` is not a mapping, but the model has more (or
            less) than one input.
        ValueError: if `inputs` has unexpected keys or lacks a non-optional
            model input.
    """
    model_inputs = {get_member_id(descr): descr for descr in model.inputs}

    if isinstance(inputs, collections.abc.Mapping):
        inputs = {MemberId(key): value for key, value in inputs.items()}
    elif len(model_inputs) == 1:
        # a single unkeyed input belongs to the model's only input tensor
        inputs = {next(iter(model_inputs)): inputs}
    else:
        raise TypeError(
            f"Expected `inputs` to be a mapping with keys {tuple(model_inputs)}"
        )

    unknown = {key for key in inputs if key not in model_inputs}
    if unknown:
        raise ValueError(f"Got unexpected inputs: {unknown}")

    def _is_optional(descr: Any) -> bool:
        return isinstance(descr, v0_5.InputTensorDescr) and descr.optional

    missing = {
        key
        for key, descr in model_inputs.items()
        if key not in inputs and not _is_optional(descr)
    }
    if missing:
        raise ValueError(f"Missing non-optional model inputs: {missing}")

    members = {
        member_id: get_tensor(inputs[member_id], descr)
        for member_id, descr in model_inputs.items()
        if member_id in inputs
    }
    return Sample(
        members=members,
        stat={} if stat is None else stat,
        id=sample_id,
    )
def load_sample_for_model(
    *,
    model: AnyModelDescr,
    paths: PerMember[Path],
    axes: Optional[PerMember[Sequence[AxisLike]]] = None,
    stat: Optional[Stat] = None,
    sample_id: Optional[SampleId] = None,
):
    """load a single sample from `paths` that can be processed by `model`

    Args:
        model: a bioimage.io model description
        paths: path to load per model input
        axes: optional axes per model input; inputs without an entry are
            loaded with the axes of the model's input description.
        stat: sample statistics to attach to the sample
        sample_id: sample ID; if falsy, the sorted paths are used instead.

    Raises:
        ValueError: if `paths` or `axes` have keys that are not model inputs.
    """
    # make sure members are keyed by MemberId, not string
    paths = {MemberId(key): path for key, path in paths.items()}
    axes = {MemberId(key): value for key, value in (axes or {}).items()}

    model_inputs = {get_member_id(descr): descr for descr in model.inputs}

    unknown_paths = {key for key in paths if key not in model_inputs}
    if unknown_paths:
        raise ValueError(f"Got unexpected paths for {unknown_paths}")

    unknown_axes = {key for key in axes if key not in model_inputs}
    if unknown_axes:
        raise ValueError(f"Got unexpected axes hints for: {unknown_axes}")

    members: Dict[MemberId, Tensor] = {}
    for member_id, path in paths.items():
        if member_id in axes:
            member_axes = axes[member_id]
        else:
            # no axes hint given: fall back to the model's input description
            member_axes = get_axes_infos(model_inputs[member_id])
            logger.debug(
                "loading '{}' from {} with default input axes {} ",
                member_id,
                path,
                member_axes,
            )

        members[member_id] = load_tensor(path, member_axes)

    return Sample(
        members=members,
        stat={} if stat is None else stat,
        id=sample_id or tuple(sorted(paths.values())),
    )