Coverage for src / bioimageio / core / digest_spec.py: 85%
230 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-18 12:35 +0000
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-18 12:35 +0000
1from __future__ import annotations
3import collections.abc
4import importlib.util
5import sys
6from itertools import chain
7from pathlib import Path
8from tempfile import TemporaryDirectory
9from typing import (
10 Any,
11 Callable,
12 Dict,
13 Iterable,
14 List,
15 Mapping,
16 NamedTuple,
17 Optional,
18 Sequence,
19 Tuple,
20 Union,
21)
22from zipfile import ZipFile, is_zipfile
24import numpy as np
25import xarray as xr
26from loguru import logger
27from numpy.typing import NDArray
28from typing_extensions import Unpack, assert_never
30from bioimageio.spec._internal.io import HashKwargs, PermissiveFileSource
31from bioimageio.spec.common import FileDescr, FileSource
32from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5
33from bioimageio.spec.model.v0_4 import CallableFromDepencency, CallableFromFile
34from bioimageio.spec.model.v0_5 import (
35 ArchitectureFromFileDescr,
36 ArchitectureFromLibraryDescr,
37 ParameterizedSize_N,
38)
39from bioimageio.spec.utils import load_array
41from .axis import AxisId, AxisInfo, AxisLike, PerAxis
42from .block_meta import split_multiple_shapes_into_blocks
43from .common import Halo, MemberId, PerMember, SampleId, TotalNumberOfBlocks
44from .io import load_tensor
45from .sample import (
46 LinearSampleAxisTransform,
47 Sample,
48 SampleBlockMeta,
49 sample_block_meta_generator,
50)
51from .stat_measures import Stat
52from .tensor import Tensor
# Everything `get_tensor` accepts as tensor data: an in-memory tensor/array or
# a file source that is handed to `load_tensor`.
TensorSource = Union[
    Tensor,
    xr.DataArray,
    NDArray[Any],
    PermissiveFileSource,
]
def import_callable(
    node: Union[
        ArchitectureFromFileDescr,
        ArchitectureFromLibraryDescr,
        CallableFromDepencency,
        CallableFromFile,
        v0_5.CustomProcessingDescr,
    ],
    /,
    **kwargs: Unpack[HashKwargs],
) -> Callable[..., Any]:
    """resolve a spec node to the callable (e.g. a torch.nn.Module) it describes

    Library style nodes are looked up by attribute name in an importable module;
    file style nodes are imported from their source file.

    Args:
        node: spec node naming the callable.
        **kwargs: hash keyword arguments; only forwarded for `CallableFromFile`
            nodes. `ArchitectureFromFileDescr`/`CustomProcessingDescr` nodes
            pass their own `sha256` instead.

    Raises:
        ValueError: if the imported object is not callable.
    """
    if isinstance(node, CallableFromDepencency):
        imported = getattr(
            importlib.import_module(node.module_name), str(node.callable_name)
        )
    elif isinstance(node, ArchitectureFromLibraryDescr):
        imported = getattr(
            importlib.import_module(node.import_from), str(node.callable)
        )
    elif isinstance(node, CallableFromFile):
        imported = _import_from_file_impl(
            node.source_file, str(node.callable_name), **kwargs
        )
    elif isinstance(node, (ArchitectureFromFileDescr, v0_5.CustomProcessingDescr)):
        imported = _import_from_file_impl(
            node.source, str(node.callable), sha256=node.sha256
        )
    else:
        assert_never(node)

    if callable(imported):
        return imported

    raise ValueError(f"{node} (imported: {imported}) is not callable")
tmp_dirs_in_use: List[TemporaryDirectory[str]] = []
"""keep global reference to temporary directories created during import to delay cleanup"""


def _import_from_file_impl(
    source: FileSource, callable_name: str, **kwargs: Unpack[HashKwargs]
):
    """import `callable_name` from the Python source (or zip archive) at `source`

    The source is written to a temporary directory and imported under a module
    name derived from its file name and SHA-256, so identical sources are only
    imported once (cached in `sys.modules`).

    Args:
        source: location of the source file.
        callable_name: attribute to retrieve from the imported module.
        **kwargs: hash keyword arguments forwarded to `FileDescr`.

    Returns:
        The attribute `callable_name` of the imported module.

    Raises:
        ImportError: if creating or executing the module fails.
        AttributeError: if the module attribute cannot be accessed.
    """
    src_descr = FileDescr(source=source, **kwargs)
    # ensure sha is valid even if perform_io_checks=False
    # or the source has changed since last sha computation
    src_descr.validate_sha256(force_recompute=True)
    assert src_descr.sha256 is not None
    source_sha = src_descr.sha256

    reader = src_descr.get_reader()
    # the content hash makes the module name unique per source version
    module_name = f"{reader.original_file_name.split('.')[0]}_{source_sha}"

    # fall back to a generic prefix if the file stem yields no valid identifier
    if not module_name.isidentifier():
        module_name = f"custom_module_{source_sha}"
        assert module_name.isidentifier(), module_name

    module = sys.modules.get(module_name)
    if module is None:
        # only read the source when the module is not cached yet
        source_bytes = reader.read()
        try:
            td_kwargs: Dict[str, Any] = (
                dict(ignore_cleanup_errors=True) if sys.version_info >= (3, 10) else {}
            )
            if sys.version_info >= (3, 12):
                td_kwargs["delete"] = False

            tmp_dir = TemporaryDirectory(**td_kwargs)
            # keep global ref to tmp_dir to delay cleanup until program exit
            # TODO: remove for py >= 3.12, when delete=False works
            tmp_dirs_in_use.append(tmp_dir)

            module_path = Path(tmp_dir.name) / module_name
            if reader.original_file_name.endswith(".zip") or is_zipfile(reader):
                module_path.mkdir()
                # close the archive handle deterministically after extraction
                with ZipFile(reader) as archive:
                    archive.extractall(path=module_path)
                # NOTE(review): `spec_from_file_location` infers the loader from
                # the file suffix; confirm that a suffix-less directory path is
                # importable below (it may require pointing at `__init__.py`).
            else:
                module_path = module_path.with_suffix(".py")
                _ = module_path.write_bytes(source_bytes)

            importlib_spec = importlib.util.spec_from_file_location(
                module_name, str(module_path)
            )

            if importlib_spec is None:
                raise ImportError(f"Failed to import {source}")

            module = importlib.util.module_from_spec(importlib_spec)

            sys.modules[module_name] = module  # cache this module

            assert importlib_spec.loader is not None
            importlib_spec.loader.exec_module(module)

        except Exception as e:
            # do not leave a partially initialized module in the cache
            if module_name in sys.modules:
                del sys.modules[module_name]

            raise ImportError(f"Failed to import {source}") from e

    try:
        callable_attr = getattr(module, callable_name)
    except AttributeError as e:
        raise AttributeError(
            f"Imported custom module from {source} has no `{callable_name}` attribute."
        ) from e
    except Exception as e:
        raise AttributeError(
            f"Failed to access `{callable_name}` attribute from custom module imported from {source} ."
        ) from e
    else:
        return callable_attr
def get_axes_infos(
    io_descr: Union[
        v0_4.InputTensorDescr,
        v0_4.OutputTensorDescr,
        v0_5.InputTensorDescr,
        v0_5.OutputTensorDescr,
    ],
) -> List[AxisInfo]:
    """convert the axes of a tensor description to a list of `AxisInfo`

    Args:
        io_descr: input or output tensor description (v0_4 or v0_5).

    Returns:
        One `AxisInfo` per entry in `io_descr.axes`, in the same order.
    """
    infos: List[AxisInfo] = []
    for spec_axis in io_descr.axes:
        infos.append(AxisInfo.create(spec_axis))

    return infos
def get_member_id(
    tensor_description: Union[
        v0_4.InputTensorDescr,
        v0_4.OutputTensorDescr,
        v0_5.InputTensorDescr,
        v0_5.OutputTensorDescr,
    ],
) -> MemberId:
    """return the tensor ID in the normalized form used for sample members

    v0_4 tensor descriptions are identified by their `name` (wrapped in
    `MemberId`), v0_5 tensor descriptions by their `id`.
    """
    v0_4_descrs = (v0_4.InputTensorDescr, v0_4.OutputTensorDescr)
    v0_5_descrs = (v0_5.InputTensorDescr, v0_5.OutputTensorDescr)

    if isinstance(tensor_description, v0_4_descrs):
        return MemberId(tensor_description.name)

    if isinstance(tensor_description, v0_5_descrs):
        return tensor_description.id

    assert_never(tensor_description)
def get_member_ids(
    tensor_descriptions: Iterable[
        Union[
            v0_4.InputTensorDescr,
            v0_4.OutputTensorDescr,
            v0_5.InputTensorDescr,
            v0_5.OutputTensorDescr,
        ]
    ],
) -> List[MemberId]:
    """return the normalized tensor ID of every given tensor description

    The result has one entry per description, in iteration order.
    """
    member_ids: List[MemberId] = []
    for tensor_descr in tensor_descriptions:
        member_ids.append(get_member_id(tensor_descr))

    return member_ids
def get_test_input_sample(model: AnyModelDescr) -> Sample:
    """returns a model's test input sample"""
    if not isinstance(model, v0_4.ModelDescr):
        # newer descriptions are passed on as they are, keyed by tensor id
        return _get_test_sample({ipt.id: ipt for ipt in model.inputs})

    # v0_4 lists the test tensors separately from the input descriptions
    described_with_source = zip(model.inputs, model.test_inputs)
    return _get_test_sample(
        {MemberId(ipt.name): (ipt, test_src) for ipt, test_src in described_with_source}
    )


get_test_inputs = get_test_input_sample
"""DEPRECATED: use `get_test_input_sample` instead"""
def get_test_output_sample(model: AnyModelDescr) -> Sample:
    """returns a model's test output sample"""
    if not isinstance(model, v0_4.ModelDescr):
        # newer descriptions are passed on as they are, keyed by tensor id
        return _get_test_sample({out.id: out for out in model.outputs})

    # v0_4 lists the test tensors separately from the output descriptions
    described_with_source = zip(model.outputs, model.test_outputs)
    return _get_test_sample(
        {MemberId(out.name): (out, test_src) for out, test_src in described_with_source}
    )


get_test_outputs = get_test_output_sample
"""DEPRECATED: use `get_test_output_sample` instead"""
def _get_test_sample(
    info: Union[
        Mapping[MemberId, Union[v0_5.InputTensorDescr, v0_5.OutputTensorDescr]],
        Mapping[
            MemberId,
            Tuple[
                v0_4.InputTensorDescr,
                FileSource,
            ],
        ],
        Mapping[
            MemberId,
            Tuple[
                v0_4.OutputTensorDescr,
                FileSource,
            ],
        ],
    ],
) -> Sample:
    """load the test tensors described by `info` into a sample

    Args:
        info: per member either a v0_5 tensor description (its `test_tensor`
            is loaded) or a tuple of a v0_4 tensor description and the source
            of the corresponding test tensor.

    Returns:
        A sample with id "test-sample" and empty statistics.

    Raises:
        ValueError: if a v0_5 tensor description has no test tensor.
    """
    arrays: Dict[MemberId, NDArray[Any]] = {}
    for m, src in info.items():
        if isinstance(src, tuple):
            arrays[m] = load_array(src[1])
        elif isinstance(src, (v0_5.InputTensorDescr, v0_5.OutputTensorDescr)):
            if src.test_tensor is None:
                # this helper serves inputs and outputs alike; name the right one
                kind = "input" if isinstance(src, v0_5.InputTensorDescr) else "output"
                raise ValueError(
                    f"Model {kind} '{m}' has no test tensor defined, cannot create test sample."
                )
            arrays[m] = load_array(src.test_tensor)
        else:
            assert_never(src)

    axes = {
        m: get_axes_infos(t[0] if isinstance(t, tuple) else t) for m, t in info.items()
    }
    return Sample(
        members={m: Tensor.from_numpy(arrays[m], dims=axes[m]) for m in info},
        stat={},
        id="test-sample",
    )
class IO_SampleBlockMeta(NamedTuple):
    """block meta data of an input sample paired with its output counterpart"""

    # meta data of the input sample block
    input: SampleBlockMeta
    # meta data of the corresponding output sample block
    output: SampleBlockMeta
def get_input_halo(model: v0_5.ModelDescr, output_halo: PerMember[PerAxis[Halo]]):
    """returns which halo input tensors need to be divided into blocks with, such that
    `output_halo` can be cropped from their outputs without introducing gaps.

    Args:
        model: model description providing the input/output tensor axes.
        output_halo: halo per output tensor and output axis.

    Returns:
        The halo per referenced tensor and referenced axis. If several output
        axes refer to the same axis, the largest left/right halo is kept.

    Raises:
        ValueError: if the size of an output axis with halo is not given as a
            `SizeReference`, i.e. it cannot be mapped to another axis.
    """
    input_halo: Dict[MemberId, Dict[AxisId, Halo]] = {}
    outputs = {t.id: t for t in model.outputs}
    all_tensors = {**{t.id: t for t in model.inputs}, **outputs}

    for t, th in output_halo.items():
        axes = {a.id: a for a in outputs[t].axes}

        for a, ah in th.items():
            axis = axes[a]
            s = axis.size
            if not isinstance(s, v0_5.SizeReference):
                raise ValueError(
                    f"Unable to map output halo for {t}.{a} to an input axis"
                )

            ref_axis = {ra.id: ra for ra in all_tensors[s.tensor_id].axes}[s.axis_id]

            # convert the halo from output axis units to reference axis units
            input_halo_left = ah.left * axis.scale / ref_axis.scale
            input_halo_right = ah.right * axis.scale / ref_axis.scale
            assert input_halo_left == int(input_halo_left), f"{input_halo_left} not int"
            assert input_halo_right == int(input_halo_right), (
                f"{input_halo_right} not int"
            )

            left, right = int(input_halo_left), int(input_halo_right)
            # key by the *referenced* axis (not the output axis id) and never
            # shrink a halo that another output axis already requires
            ref_halos = input_halo.setdefault(s.tensor_id, {})
            known = ref_halos.get(s.axis_id)
            if known is not None:
                left = max(left, known.left)
                right = max(right, known.right)

            ref_halos[s.axis_id] = Halo(left, right)

    return input_halo
def get_block_transform(
    model: v0_5.ModelDescr,
) -> PerMember[PerAxis[Union[LinearSampleAxisTransform, int]]]:
    """returns how a model's output tensor shapes relate to its input shapes

    Args:
        model: model description whose inputs and outputs are inspected.

    Returns:
        For every output tensor and each of its axes either a fixed size
        (`-1` for data dependent sizes) or a `LinearSampleAxisTransform`
        referring to the tensor axis the size is derived from.

    Raises:
        ValueError: if an output tensor has a batch axis, but no input has one.
    """
    ret: Dict[MemberId, Dict[AxisId, Union[LinearSampleAxisTransform, int]]] = {}

    # the first batch axis found among the inputs defines the batch transform
    batch_axis_trf = None
    for ipt in model.inputs:
        for a in ipt.axes:
            if a.type == "batch":
                batch_axis_trf = LinearSampleAxisTransform(
                    axis=a.id, scale=1, offset=0, member=ipt.id
                )
                break
        if batch_axis_trf is not None:
            break

    axis_scales = {
        t.id: {a.id: a.scale for a in t.axes}
        for t in chain(model.inputs, model.outputs)
    }
    for out in model.outputs:
        new_axes: Dict[AxisId, Union[LinearSampleAxisTransform, int]] = {}
        for a in out.axes:
            if a.size is None:
                assert a.type == "batch"
                if batch_axis_trf is None:
                    raise ValueError(
                        "no batch axis found in any input tensor, but output tensor"
                        + f" '{out.id}' has one."
                    )
                s = batch_axis_trf
            elif isinstance(a.size, int):
                s = a.size
            elif isinstance(a.size, v0_5.DataDependentSize):
                s = -1
            elif isinstance(a.size, v0_5.SizeReference):
                s = LinearSampleAxisTransform(
                    axis=a.size.axis_id,
                    scale=axis_scales[a.size.tensor_id][a.size.axis_id] / a.scale,
                    offset=a.size.offset,
                    member=a.size.tensor_id,
                )
            else:
                assert_never(a.size)

            new_axes[a.id] = s

        # account for postprocessing that changes the number of output channels by
        # overwriting described output shape by the intermediate output shape
        c = AxisId("channel")
        if c in new_axes:
            for post in out.postprocessing:
                if post.id == "cellpose_flow_dynamics":
                    new_axes[c] = 3
                    break
                elif post.id == "stardist_postprocessing":
                    new_axes[c] = post.kwargs.n_rays + 1
                    break

        # register every output, also those without a channel axis
        ret[out.id] = new_axes

    return ret
def get_io_sample_block_metas(
    model: v0_5.ModelDescr,
    input_sample_shape: PerMember[PerAxis[int]],
    ns: Mapping[Tuple[MemberId, AxisId], ParameterizedSize_N],
    batch_size: int = 1,
) -> Tuple[TotalNumberOfBlocks, Iterable[IO_SampleBlockMeta]]:
    """returns an iterable yielding meta data for corresponding input and output samples

    Args:
        model: model description (v0_5 only).
        input_sample_shape: shape of the input sample per member and axis.
        ns: forwarded to `model.get_axis_sizes` to determine the block shape.
        batch_size: forwarded to `model.get_axis_sizes`.

    Returns:
        The number of input blocks and a lazy generator of
        `IO_SampleBlockMeta`, pairing each input block with its transformed
        output block.

    Raises:
        TypeError: if `model` is not a `v0_5.ModelDescr`.
    """
    if not isinstance(model, v0_5.ModelDescr):
        raise TypeError(
            f"get_io_sample_block_metas() not implemented for {type(model)}"
        )

    block_axis_sizes = model.get_axis_sizes(ns=ns, batch_size=batch_size)

    # group the flat (member, axis) -> size mapping per member in a single pass
    input_block_shape: Dict[MemberId, Dict[AxisId, int]] = {}
    for (member, axis_id), size in block_axis_sizes.inputs.items():
        input_block_shape.setdefault(member, {})[axis_id] = size

    output_halo = {
        t.id: {
            a.id: Halo(a.halo, a.halo) for a in t.axes if isinstance(a, v0_5.WithHalo)
        }
        for t in model.outputs
    }
    input_halo = get_input_halo(model, output_halo)

    n_input_blocks, input_blocks = split_multiple_shapes_into_blocks(
        input_sample_shape, input_block_shape, halo=input_halo
    )
    block_transform = get_block_transform(model)
    return n_input_blocks, (
        IO_SampleBlockMeta(ipt, ipt.get_transformed(block_transform))
        for ipt in sample_block_meta_generator(
            input_blocks, sample_shape=input_sample_shape, sample_id=None
        )
    )
def get_tensor(
    src: TensorSource,
    descr: Union[
        v0_4.InputTensorDescr,
        v0_5.InputTensorDescr,
        v0_4.OutputTensorDescr,
        v0_5.OutputTensorDescr,
        Sequence[AxisInfo],
    ],
):
    """helper to cast/load various tensor sources

    Args:
        src: a `Tensor`, `xr.DataArray`, numpy array or a file source.
        descr: a tensor description or the axes to interpret `src` with.

    Returns:
        A tensor created from `src`: `Tensor`/`xr.DataArray` sources are
        transposed to the axis order given by `descr`, numpy arrays are wrapped
        with these axes and anything else is passed to `load_tensor`.
    """
    tensor_descr_types = (
        v0_4.InputTensorDescr,
        v0_5.InputTensorDescr,
        v0_4.OutputTensorDescr,
        v0_5.OutputTensorDescr,
    )
    axes = get_axes_infos(descr) if isinstance(descr, tensor_descr_types) else descr

    if isinstance(src, Tensor):
        return src.transpose(axes=[a.id for a in axes])

    if isinstance(src, xr.DataArray):
        as_tensor = Tensor.from_xarray(src)
        return as_tensor.transpose(axes=[a.id for a in axes])

    if isinstance(src, np.ndarray):
        return Tensor.from_numpy(src, dims=axes)

    # anything else is treated as a file source
    return load_tensor(src, axes=axes)
def create_sample_for_model(
    model: AnyModelDescr,
    *,
    stat: Optional[Stat] = None,
    sample_id: SampleId = None,
    inputs: Union[PerMember[TensorSource], TensorSource],
) -> Sample:
    """Create a sample from a single set of input(s) for a specific bioimage.io model

    Args:
        model: a bioimage.io model description
        stat: dictionary with sample and dataset statistics (may be updated in-place!)
        sample_id: used as the id of the created sample
        inputs: the input(s) constituting a single sample. A single tensor
            source (instead of a mapping) is only accepted for models with
            exactly one input.

    Raises:
        TypeError: if `inputs` is no mapping although the model does not have
            exactly one input.
        ValueError: if `inputs` has keys unknown to the model or lacks a
            non-optional model input.
    """
    model_inputs = {get_member_id(d): d for d in model.inputs}

    if isinstance(inputs, collections.abc.Mapping):
        given = {MemberId(k): v for k, v in inputs.items()}
    elif len(model_inputs) != 1:
        raise TypeError(
            f"Expected `inputs` to be a mapping with keys {tuple(model_inputs)}"
        )
    else:
        # a bare tensor source belongs to the model's only input
        (only_member,) = model_inputs
        given = {only_member: inputs}

    if unknown := {k for k in given if k not in model_inputs}:
        raise ValueError(f"Got unexpected inputs: {unknown}")

    if missing := {
        k
        for k, v in model_inputs.items()
        if k not in given and not (isinstance(v, v0_5.InputTensorDescr) and v.optional)
    }:
        raise ValueError(f"Missing non-optional model inputs: {missing}")

    # iterate the model inputs (not `given`) to keep the model's input order
    members = {
        m: get_tensor(given[m], ipt) for m, ipt in model_inputs.items() if m in given
    }
    return Sample(
        members=members,
        stat={} if stat is None else stat,
        id=sample_id,
    )
def load_sample_for_model(
    *,
    model: AnyModelDescr,
    paths: PerMember[Path],
    axes: Optional[PerMember[Sequence[AxisLike]]] = None,
    stat: Optional[Stat] = None,
    sample_id: Optional[SampleId] = None,
):
    """load a single sample from `paths` that can be processed by `model`

    Args:
        model: a bioimage.io model description
        paths: path to load per model input
        axes: optional axes per model input; the axes described by the model
            are used for any input without an entry
        stat: sample and dataset statistics to attach to the sample
        sample_id: id of the sample; if `None`, the sorted paths are used

    Raises:
        ValueError: if `paths` or `axes` have keys that are no model inputs.
    """

    if axes is None:
        axes = {}

    # make sure members are keyed by MemberId, not string
    paths = {MemberId(k): v for k, v in paths.items()}
    axes = {MemberId(k): v for k, v in axes.items()}

    model_inputs = {get_member_id(d): d for d in model.inputs}

    if unknown := {k for k in paths if k not in model_inputs}:
        raise ValueError(f"Got unexpected paths for {unknown}")

    if unknown := {k for k in axes if k not in model_inputs}:
        raise ValueError(f"Got unexpected axes hints for: {unknown}")

    members: Dict[MemberId, Tensor] = {}
    for m, p in paths.items():
        if m not in axes:
            axes[m] = get_axes_infos(model_inputs[m])
            logger.info(
                "loading '{}' from {} with default input axes {} ",
                m,
                p,
                axes[m],
            )
        members[m] = load_tensor(p, axes[m])

    return Sample(
        members=members,
        stat={} if stat is None else stat,
        # only fall back to the paths if no id was given; a falsy id such as
        # `0` is a valid, explicitly chosen sample id
        id=tuple(sorted(paths.values())) if sample_id is None else sample_id,
    )