Coverage for src/bioimageio/core/digest_spec.py: 83%
239 statements
« prev ^ index » next coverage.py v7.14.2, created at 2026-06-22 16:54 +0000
« prev ^ index » next coverage.py v7.14.2, created at 2026-06-22 16:54 +0000
1from __future__ import annotations
3import collections.abc
4import importlib.util
5import sys
6from itertools import chain
7from pathlib import Path
8from tempfile import TemporaryDirectory
9from typing import (
10 Any,
11 Callable,
12 Dict,
13 Iterable,
14 List,
15 Mapping,
16 NamedTuple,
17 Optional,
18 Sequence,
19 Tuple,
20 Union,
21)
22from zipfile import ZipFile, is_zipfile
24import numpy as np
25import xarray as xr
26from loguru import logger
27from numpy.typing import NDArray
28from typing_extensions import TypeAlias, Unpack, assert_never
30from bioimageio.spec._internal.io import HashKwargs, PermissiveFileSource
31from bioimageio.spec.common import FileDescr, FileSource
32from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5
33from bioimageio.spec.model.v0_4 import CallableFromDepencency, CallableFromFile
34from bioimageio.spec.model.v0_5 import (
35 ArchitectureFromFileDescr,
36 ArchitectureFromLibraryDescr,
37 ParameterizedSize_N,
38)
39from bioimageio.spec.utils import load_array
41from .axis import AxisId, AxisInfo, AxisLike, PerAxis
42from .block_meta import split_multiple_shapes_into_blocks
43from .common import Halo, MemberId, PerMember, SampleId, TotalNumberOfBlocks
44from .io import load_tensor
45from .sample import (
46 LinearSampleAxisTransform,
47 Sample,
48 SampleBlockMeta,
49 SampleBlockWithOrigin,
50 sample_block_meta_generator,
51)
52from .stat_measures import Stat
53from .tensor import Tensor
# Everything `get_tensor` accepts as a tensor source: a `Tensor`, an xarray
# data array, a numpy array, or a file source handed on to `load_tensor`.
TensorSource: TypeAlias = Union[
    Tensor,
    xr.DataArray,
    NDArray[Any],
    PermissiveFileSource,
]
def import_callable(
    node: Union[
        ArchitectureFromFileDescr,
        ArchitectureFromLibraryDescr,
        CallableFromDepencency,
        CallableFromFile,
        v0_5.CustomProcessingDescr,
    ],
    /,
    **kwargs: Unpack[HashKwargs],
) -> Callable[..., Any]:
    """Import the callable (e.g. a torch.nn.Module) described by a spec node.

    Args:
        node: Spec node that names a callable, either inside an importable
            module or inside a source file.
        **kwargs: Hash keyword arguments. They are only forwarded for
            `CallableFromFile` nodes; the other file based nodes pass their
            own `sha256` instead.

    Returns:
        The imported callable.

    Raises:
        ValueError: If the imported attribute is not callable.
    """
    imported: Any
    if isinstance(node, CallableFromDepencency):
        imported = getattr(
            importlib.import_module(node.module_name), str(node.callable_name)
        )
    elif isinstance(node, ArchitectureFromLibraryDescr):
        imported = getattr(
            importlib.import_module(node.import_from), str(node.callable)
        )
    elif isinstance(node, CallableFromFile):
        imported = _import_from_file_impl(
            node.source_file, str(node.callable_name), **kwargs
        )
    elif isinstance(node, (ArchitectureFromFileDescr, v0_5.CustomProcessingDescr)):
        imported = _import_from_file_impl(
            node.source, str(node.callable), sha256=node.sha256
        )
    else:
        assert_never(node)

    if callable(imported):
        return imported

    raise ValueError(f"{node} (imported: {imported}) is not callable")
tmp_dirs_in_use: List[TemporaryDirectory[str]] = []
"""keep global reference to temporary directories created during import to delay cleanup"""


def _import_from_file_impl(
    source: FileSource, callable_name: str, **kwargs: Unpack[HashKwargs]
):
    """Import `callable_name` from a module materialized from `source`.

    The source is written (or, for zip archives, extracted) into a fresh
    temporary directory and imported under a module name that embeds the
    source's SHA-256. Modules are cached in `sys.modules` under that name, so
    a repeated import of the same content reuses the loaded module.

    Args:
        source: File source of a Python file or a zip archive.
        callable_name: Name of the attribute to fetch from the imported module.
        **kwargs: Hash keyword arguments passed on to `FileDescr`.

    Returns:
        The attribute `callable_name` of the imported module.

    Raises:
        ImportError: If the module could not be created or executed.
        AttributeError: If the attribute is missing or accessing it failed.
    """
    src_descr = FileDescr(source=source, **kwargs)
    # ensure sha is valid even if perform_io_checks=False
    # or the source has changed since last sha computation
    src_descr.validate_sha256(force_recompute=True)
    assert src_descr.sha256 is not None
    source_sha = src_descr.sha256

    reader = src_descr.get_reader()
    # make sure we have unique module name
    module_name = f"{reader.original_file_name.split('.')[0]}_{source_sha}"

    # make sure we have a unique and valid module name
    if not module_name.isidentifier():
        module_name = f"custom_module_{source_sha}"
        assert module_name.isidentifier(), module_name

    # NOTE: read before any zip probing below; the zip helpers reposition
    # the reader themselves.
    source_bytes = reader.read()

    module = sys.modules.get(module_name)
    if module is None:
        try:
            td_kwargs: Dict[str, Any] = (
                dict(ignore_cleanup_errors=True) if sys.version_info >= (3, 10) else {}
            )
            if sys.version_info >= (3, 12):
                td_kwargs["delete"] = False

            tmp_dir = TemporaryDirectory(**td_kwargs)
            # keep global ref to tmp_dir to delay cleanup until program exit
            # TODO: remove for py >= 3.12, when delete=False works
            tmp_dirs_in_use.append(tmp_dir)

            module_path = Path(tmp_dir.name) / module_name
            if reader.original_file_name.endswith(".zip") or is_zipfile(reader):
                module_path.mkdir()
                # release the archive handle once extraction is done
                with ZipFile(reader) as zip_file:
                    zip_file.extractall(path=module_path)
            else:
                module_path = module_path.with_suffix(".py")
                _ = module_path.write_bytes(source_bytes)

            importlib_spec = importlib.util.spec_from_file_location(
                module_name, str(module_path)
            )

            if importlib_spec is None:
                raise ImportError(f"Failed to import {source}")

            module = importlib.util.module_from_spec(importlib_spec)

            sys.modules[module_name] = module  # cache this module

            assert importlib_spec.loader is not None
            importlib_spec.loader.exec_module(module)

        except Exception as e:
            # do not leave a half-initialized module in the cache
            if module_name in sys.modules:
                del sys.modules[module_name]

            raise ImportError(f"Failed to import {source}") from e

    try:
        callable_attr = getattr(module, callable_name)
    except AttributeError as e:
        raise AttributeError(
            f"Imported custom module from {source} has no `{callable_name}` attribute."
        ) from e
    except Exception as e:
        raise AttributeError(
            f"Failed to access `{callable_name}` attribute from custom module imported from {source} ."
        ) from e

    else:
        return callable_attr
def get_axes_infos(
    io_descr: Union[
        v0_4.InputTensorDescr,
        v0_4.OutputTensorDescr,
        v0_5.InputTensorDescr,
        v0_5.OutputTensorDescr,
    ],
) -> List[AxisInfo]:
    """Return a unified, simplified `AxisInfo` for each axis of `io_descr`.

    The result has one entry per element of `io_descr.axes`, in the same order.
    """
    return list(map(AxisInfo.create, io_descr.axes))
def get_member_id(
    tensor_description: Union[
        v0_4.InputTensorDescr,
        v0_4.OutputTensorDescr,
        v0_5.InputTensorDescr,
        v0_5.OutputTensorDescr,
    ],
) -> MemberId:
    """Return the normalized tensor ID, usable as a sample member ID.

    v0_4 descriptions are identified by their `name` (wrapped in `MemberId`),
    v0_5 descriptions by their `id`.
    """
    if isinstance(tensor_description, (v0_4.InputTensorDescr, v0_4.OutputTensorDescr)):
        return MemberId(tensor_description.name)

    if isinstance(tensor_description, (v0_5.InputTensorDescr, v0_5.OutputTensorDescr)):
        return tensor_description.id

    assert_never(tensor_description)
def get_member_ids(
    tensor_descriptions: Iterable[
        Union[
            v0_4.InputTensorDescr,
            v0_4.OutputTensorDescr,
            v0_5.InputTensorDescr,
            v0_5.OutputTensorDescr,
        ]
    ],
) -> List[MemberId]:
    """Return the normalized tensor IDs of `tensor_descriptions`, in order.

    Each ID is obtained with `get_member_id` and is usable as a sample member ID.
    """
    return list(map(get_member_id, tensor_descriptions))
def get_test_input_sample(model: AnyModelDescr) -> Sample:
    """returns a model's test input sample"""
    if not isinstance(model, v0_4.ModelDescr):
        # non-v0_4 descriptions are keyed by their tensor id
        return _get_test_sample({ipt.id: ipt for ipt in model.inputs})

    # v0_4 keeps the test tensors in a separate list parallel to `inputs`
    return _get_test_sample(
        {
            MemberId(ipt.name): (ipt, test_src)
            for ipt, test_src in zip(model.inputs, model.test_inputs)
        }
    )


get_test_inputs = get_test_input_sample
"""DEPRECATED: use `get_test_input_sample` instead"""
def get_test_output_sample(model: AnyModelDescr) -> Sample:
    """returns a model's test output sample"""
    if not isinstance(model, v0_4.ModelDescr):
        # non-v0_4 descriptions are keyed by their tensor id
        return _get_test_sample({out.id: out for out in model.outputs})

    # v0_4 keeps the test tensors in a separate list parallel to `outputs`
    return _get_test_sample(
        {
            MemberId(out.name): (out, test_src)
            for out, test_src in zip(model.outputs, model.test_outputs)
        }
    )


get_test_outputs = get_test_output_sample
"""DEPRECATED: use `get_test_output_sample` instead"""
def _get_test_sample(
    info: Union[
        Mapping[MemberId, Union[v0_5.InputTensorDescr, v0_5.OutputTensorDescr]],
        Mapping[
            MemberId,
            Tuple[
                v0_4.InputTensorDescr,
                FileSource,
            ],
        ],
        Mapping[
            MemberId,
            Tuple[
                v0_4.OutputTensorDescr,
                FileSource,
            ],
        ],
    ],
) -> Sample:
    """Build the "test-sample" `Sample` from tensor descriptions.

    Args:
        info: Per member either a v0_5 tensor description (whose `test_tensor`
            is loaded) or a `(v0_4 tensor description, test tensor source)`
            tuple.

    Returns:
        A sample with one tensor per member, empty `stat` and id "test-sample".

    Raises:
        ValueError: If a v0_5 tensor description has no test tensor.
    """
    members: Dict[MemberId, Tensor] = {}
    for m, src in info.items():
        if isinstance(src, tuple):
            descr, test_source = src
            array = load_array(test_source)
        elif isinstance(src, (v0_5.InputTensorDescr, v0_5.OutputTensorDescr)):
            if src.test_tensor is None:
                # this helper serves inputs and outputs; name the right one
                kind = "input" if isinstance(src, v0_5.InputTensorDescr) else "output"
                raise ValueError(
                    f"Model {kind} '{m}' has no test tensor defined, cannot create test sample."
                )
            descr = src
            array = load_array(src.test_tensor)
        else:
            assert_never(src)

        members[m] = Tensor.from_numpy(array, dims=get_axes_infos(descr))

    return Sample(
        members=members,
        stat={},
        id="test-sample",
    )
class IO_SampleBlockMeta(NamedTuple):
    """Meta data of an input sample block paired with its output counterpart."""

    # meta data of the input sample block
    input: SampleBlockMeta
    # meta data of the corresponding output sample block
    output: SampleBlockMeta
def get_input_halo(
    model: v0_5.ModelDescr, output_halo: Optional[PerMember[PerAxis[Halo]]] = None
) -> Dict[MemberId, Dict[AxisId, Halo]]:
    """returns which halo input tensors need to be divided into blocks with, such that
    `output_halo` can be cropped from their outputs without introducing gaps.

    Args:
        model: The model description.
        output_halo: Halo per output tensor and axis. Defaults to the halo
            declared on the model's output axes.

    Returns:
        Halo per referenced tensor and referenced axis, scaled by the ratio of
        the output axis scale to the referenced axis scale.

    Raises:
        ValueError: If the size of an output axis with halo is not a
            `SizeReference`, so it cannot be mapped to another axis.
    """
    outputs = {t.id: t for t in model.outputs}
    all_tensors = {**{t.id: t for t in model.inputs}, **outputs}
    if output_halo is None:
        output_halo = {
            t.id: {
                a.id: Halo(a.halo, a.halo)
                for a in t.axes
                if isinstance(a, v0_5.WithHalo)
            }
            for t in model.outputs
        }

    input_halo: Dict[MemberId, Dict[AxisId, Halo]] = {}
    for t, th in output_halo.items():
        axes = {a.id: a for a in outputs[t].axes}

        for a, ah in th.items():
            axis = axes[a]
            s = axis.size
            if not isinstance(s, v0_5.SizeReference):
                raise ValueError(
                    f"Unable to map output halo for {t}.{a} to an input axis"
                )

            ref_axes = {ra.id: ra for ra in all_tensors[s.tensor_id].axes}
            ref_axis = ref_axes[s.axis_id]

            input_halo_left = ah.left * axis.scale / ref_axis.scale
            input_halo_right = ah.right * axis.scale / ref_axis.scale
            assert input_halo_left == int(input_halo_left), f"{input_halo_left} not int"
            assert input_halo_right == int(input_halo_right), (
                f"{input_halo_right} not int"
            )

            # The halo belongs to the *referenced* axis, so key it by that
            # axis' id; the output axis id may differ from it.
            input_halo.setdefault(s.tensor_id, {})[s.axis_id] = Halo(
                int(input_halo_left), int(input_halo_right)
            )

    return input_halo
def get_block_transform(
    model: v0_5.ModelDescr,
) -> PerMember[PerAxis[Union[LinearSampleAxisTransform, int]]]:
    """returns how a model's output tensor shapes relates to its input shapes

    For every output tensor and axis the result holds either

    - a `LinearSampleAxisTransform` (batch axes and axes whose size is a
      `SizeReference`),
    - the fixed integer size of the axis, or
    - `-1` for a `DataDependentSize`.

    Raises:
        ValueError: If an output has a batch axis but no input has one.
    """
    ret: Dict[MemberId, Dict[AxisId, Union[LinearSampleAxisTransform, int]]] = {}

    # the first batch axis found in any input defines the batch transform
    batch_axis_trf = None
    for ipt in model.inputs:
        for a in ipt.axes:
            if a.type == "batch":
                batch_axis_trf = LinearSampleAxisTransform(
                    axis=a.id, scale=1, offset=0, member=ipt.id
                )
                break
        if batch_axis_trf is not None:
            break

    axis_scales = {
        t.id: {a.id: a.scale for a in t.axes}
        for t in chain(model.inputs, model.outputs)
    }
    for out in model.outputs:
        new_axes: Dict[AxisId, Union[LinearSampleAxisTransform, int]] = {}
        for a in out.axes:
            if a.size is None:
                assert a.type == "batch"
                if batch_axis_trf is None:
                    raise ValueError(
                        "no batch axis found in any input tensor, but output tensor"
                        + f" '{out.id}' has one."
                    )
                s = batch_axis_trf
            elif isinstance(a.size, int):
                s = a.size
            elif isinstance(a.size, v0_5.DataDependentSize):
                s = -1
            elif isinstance(a.size, v0_5.SizeReference):
                s = LinearSampleAxisTransform(
                    axis=a.size.axis_id,
                    scale=axis_scales[a.size.tensor_id][a.size.axis_id] / a.scale,
                    offset=a.size.offset,
                    member=a.size.tensor_id,
                )
            else:
                assert_never(a.size)

            new_axes[a.id] = s

        # account for postprocessing that changes the number of output channels by
        # overwriting described output shape by the intermediate output shape
        c = AxisId("channel")
        if c in new_axes:
            for post in out.postprocessing:
                if post.id == "cellpose_flow_dynamics":
                    new_axes[c] = 3
                    break
                elif post.id == "stardist_postprocessing":
                    new_axes[c] = post.kwargs.n_rays + 1
                    break

        # always record the output, also when it has no channel axis
        ret[out.id] = new_axes

    return ret
def get_io_sample_block_metas(
    model: v0_5.ModelDescr,
    input_sample_shape: PerMember[PerAxis[int]],
    ns: Mapping[Tuple[MemberId, AxisId], ParameterizedSize_N],
    batch_size: int = 1,
) -> Tuple[TotalNumberOfBlocks, Iterable[IO_SampleBlockMeta]]:
    """returns an iterable yielding meta data for corresponding input and output samples

    Args:
        model: The (v0_5) model description.
        input_sample_shape: Shape of the input sample per member and axis.
        ns: Size parameter `n` per (member, axis), passed to
            `model.get_axis_sizes`.
        batch_size: Batch size passed to `model.get_axis_sizes`.

    Returns:
        The total number of blocks and a lazy iterable of
        `IO_SampleBlockMeta(input, output)` pairs.

    Raises:
        TypeError: If `model` is not a `v0_5.ModelDescr`.
    """
    if not isinstance(model, v0_5.ModelDescr):
        raise TypeError(
            f"get_io_sample_block_metas() not implemented for {type(model)}"
        )

    block_axis_sizes = model.get_axis_sizes(ns=ns, batch_size=batch_size)

    # regroup the flat {(member, axis): size} mapping per member in one pass
    input_block_shape: Dict[MemberId, Dict[AxisId, int]] = {}
    for (member_id, axis_id), size in block_axis_sizes.inputs.items():
        input_block_shape.setdefault(member_id, {})[axis_id] = size

    # without an explicit output halo `get_input_halo` uses the halo declared
    # on the model's output axes
    input_halo = get_input_halo(model)

    n_input_blocks, input_blocks = split_multiple_shapes_into_blocks(
        input_sample_shape, input_block_shape, halo=input_halo
    )
    block_transform = get_block_transform(model)
    return n_input_blocks, (
        IO_SampleBlockMeta(ipt, ipt.get_transformed(block_transform))
        for ipt in sample_block_meta_generator(
            input_blocks, sample_shape=input_sample_shape, sample_id=None
        )
    )
def get_tensor(
    src: TensorSource,
    descr: Union[
        v0_4.InputTensorDescr,
        v0_5.InputTensorDescr,
        v0_4.OutputTensorDescr,
        v0_5.OutputTensorDescr,
        Sequence[AxisInfo],
    ],
):
    """Cast or load `src` into a tensor with the axes described by `descr`.

    Args:
        src: A `Tensor`, xarray data array, numpy array, or a source that
            `load_tensor` can read.
        descr: A tensor description or the axes given directly.

    Returns:
        `Tensor` and `xr.DataArray` sources are transposed to the described
        axis order; numpy arrays are wrapped with the described axes as dims;
        anything else is handed to `load_tensor`.
    """
    descr_types = (
        v0_4.InputTensorDescr,
        v0_5.InputTensorDescr,
        v0_4.OutputTensorDescr,
        v0_5.OutputTensorDescr,
    )
    axes = get_axes_infos(descr) if isinstance(descr, descr_types) else descr

    if isinstance(src, Tensor):
        return src.transpose(axes=[a.id for a in axes])

    if isinstance(src, xr.DataArray):
        return Tensor.from_xarray(src).transpose(axes=[a.id for a in axes])

    if isinstance(src, np.ndarray):
        return Tensor.from_numpy(src, dims=axes)

    return load_tensor(src, axes=axes)
def create_sample_for_model(
    model: AnyModelDescr,
    *,
    stat: Optional[Stat] = None,
    sample_id: SampleId = None,
    inputs: Union[PerMember[TensorSource], TensorSource],
) -> Sample:
    """Create a sample from a single set of input(s) for a specific bioimage.io model

    Args:
        model: a bioimage.io model description
        stat: dictionary with sample and dataset statistics (may be updated in-place!)
        sample_id: id given to the created sample
        inputs: the input(s) constituting a single sample; either a mapping
            keyed by member id or, for single-input models, the bare source.

    Raises:
        TypeError: if `inputs` is not a mapping although the model has
            several inputs.
        ValueError: if `inputs` has unknown keys or lacks a non-optional input.
    """
    model_inputs = {get_member_id(d): d for d in model.inputs}

    if isinstance(inputs, collections.abc.Mapping):
        given = {MemberId(key): value for key, value in inputs.items()}
    elif len(model_inputs) == 1:
        (only_member,) = model_inputs
        given = {only_member: inputs}
    else:
        raise TypeError(
            f"Expected `inputs` to be a mapping with keys {tuple(model_inputs)}"
        )

    unknown = {k for k in given if k not in model_inputs}
    if unknown:
        raise ValueError(f"Got unexpected inputs: {unknown}")

    missing = set()
    for member, ipt_descr in model_inputs.items():
        if member in given:
            continue
        if isinstance(ipt_descr, v0_5.InputTensorDescr) and ipt_descr.optional:
            continue
        missing.add(member)

    if missing:
        raise ValueError(f"Missing non-optional model inputs: {missing}")

    members = {}
    for member, ipt_descr in model_inputs.items():
        if member in given:
            members[member] = get_tensor(given[member], ipt_descr)

    return Sample(
        members=members,
        stat={} if stat is None else stat,
        id=sample_id,
    )
def load_sample_for_model(
    *,
    model: AnyModelDescr,
    paths: PerMember[Path],
    axes: Optional[PerMember[Sequence[AxisLike]]] = None,
    stat: Optional[Stat] = None,
    sample_id: Optional[SampleId] = None,
):
    """Load a single sample from `paths` that can be processed by `model`.

    Args:
        model: a bioimage.io model description
        paths: path to load per model input
        axes: optional axes hints per model input; inputs without a hint use
            the axes of the model's input description
        stat: sample statistics to attach (an empty dict if omitted)
        sample_id: sample id; if falsy, the sorted tuple of paths is used

    Raises:
        ValueError: if `paths` or `axes` name a member that is not a model input.
    """
    # make sure members are keyed by MemberId, not string
    paths = {MemberId(k): v for k, v in paths.items()}
    axes = {} if axes is None else {MemberId(k): v for k, v in axes.items()}

    model_inputs = {get_member_id(d): d for d in model.inputs}

    unknown_paths = {k for k in paths if k not in model_inputs}
    if unknown_paths:
        raise ValueError(f"Got unexpected paths for {unknown_paths}")

    unknown_axes = {k for k in axes if k not in model_inputs}
    if unknown_axes:
        raise ValueError(f"Got unexpected axes hints for: {unknown_axes}")

    members: Dict[MemberId, Tensor] = {}
    for member, path in paths.items():
        if member not in axes:
            # fall back to the axes of the model's input description
            axes[member] = get_axes_infos(model_inputs[member])
            logger.info(
                "loading '{}' from {} with default input axes {} ",
                member,
                path,
                axes[member],
            )
        members[member] = load_tensor(path, axes[member])

    return Sample(
        members=members,
        stat={} if stat is None else stat,
        id=sample_id or tuple(sorted(paths.values())),
    )
def split_sample_into_blocks_for_model(
    sample: Sample,
    model: v0_5.ModelDescr,
    blocksize_parameter: int,
    batch_size: int = 1,
) -> Tuple[TotalNumberOfBlocks, Iterable[SampleBlockWithOrigin]]:
    """Split `sample` into blocks sized for `model`.

    Args:
        sample: the sample to split
        model: a v0_5 model description
        blocksize_parameter: size parameter `n` used for every input axis
            whose size is a `ParameterizedSize`
        batch_size: batch size passed to `model.get_tensor_sizes`

    Returns:
        Whatever `sample.split_into_blocks` returns for the model's input
        block shapes, input halo and per-input pad mode ("symmetric" if the
        input's `pad` is falsy).

    Raises:
        NotImplementedError: if `model` is a `v0_4.ModelDescr`.
    """
    if isinstance(model, v0_4.ModelDescr):
        raise NotImplementedError(
            "`predict_sample_with_blocking` not implemented for v0_4.ModelDescr"
            + f" {model.name}."
            + " Consider using `predict_sample_with_fixed_blocking` or update the model description to format version 0.5."
        )

    ns = {}
    for ipt in model.inputs:
        for axis in ipt.axes:
            if isinstance(axis.size, v0_5.ParameterizedSize):
                ns[(ipt.id, axis.id)] = blocksize_parameter

    block_shapes = model.get_tensor_sizes(ns, batch_size=batch_size).inputs
    pad_modes = {ipt.id: ipt.pad or "symmetric" for ipt in model.inputs}

    return sample.split_into_blocks(
        block_shapes=block_shapes,
        halo=get_input_halo(model),
        pad_mode=pad_modes,
    )