Coverage for bioimageio/core/digest_spec.py: 85%
204 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-16 15:20 +0000
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-16 15:20 +0000
1from __future__ import annotations
3import collections.abc
4import importlib.util
5import sys
6from itertools import chain
7from pathlib import Path
8from tempfile import TemporaryDirectory
9from typing import (
10 Any,
11 Callable,
12 Dict,
13 Iterable,
14 List,
15 Mapping,
16 NamedTuple,
17 Optional,
18 Sequence,
19 Tuple,
20 Union,
21)
22from zipfile import ZipFile, is_zipfile
24import numpy as np
25import xarray as xr
26from loguru import logger
27from numpy.typing import NDArray
28from typing_extensions import Unpack, assert_never
30from bioimageio.spec._internal.io import HashKwargs
31from bioimageio.spec.common import FileDescr, FileSource, ZipPath
32from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5
33from bioimageio.spec.model.v0_4 import CallableFromDepencency, CallableFromFile
34from bioimageio.spec.model.v0_5 import (
35 ArchitectureFromFileDescr,
36 ArchitectureFromLibraryDescr,
37 ParameterizedSize_N,
38)
39from bioimageio.spec.utils import load_array
41from .axis import Axis, AxisId, AxisInfo, AxisLike, PerAxis
42from .block_meta import split_multiple_shapes_into_blocks
43from .common import Halo, MemberId, PerMember, SampleId, TotalNumberOfBlocks
44from .io import load_tensor
45from .sample import (
46 LinearSampleAxisTransform,
47 Sample,
48 SampleBlockMeta,
49 sample_block_meta_generator,
50)
51from .stat_measures import Stat
52from .tensor import Tensor
# Sources that can be turned into a single sample member (see `get_tensor`):
# an already wrapped `Tensor`, an xarray `DataArray`, a plain numpy array,
# or a `pathlib.Path`.
TensorSource = Union[
    Tensor,
    xr.DataArray,
    NDArray[Any],
    Path,
]
def import_callable(
    node: Union[
        ArchitectureFromFileDescr,
        ArchitectureFromLibraryDescr,
        CallableFromDepencency,
        CallableFromFile,
    ],
    /,
    **kwargs: Unpack[HashKwargs],
) -> Callable[..., Any]:
    """Resolve a spec node describing a callable (e.g. a torch.nn.Module) to that callable.

    Args:
        node: description of where the callable lives; either an importable
            module plus attribute name, or a source file plus attribute name.
        **kwargs: hash keyword arguments; only forwarded for `CallableFromFile`
            nodes (an `ArchitectureFromFileDescr` supplies its own `sha256`).

    Returns:
        The imported object.

    Raises:
        ValueError: if the imported object is not callable.
    """
    if isinstance(node, CallableFromDepencency):
        imported = getattr(
            importlib.import_module(node.module_name), str(node.callable_name)
        )
    elif isinstance(node, ArchitectureFromLibraryDescr):
        imported = getattr(
            importlib.import_module(node.import_from), str(node.callable)
        )
    elif isinstance(node, CallableFromFile):
        imported = _import_from_file_impl(
            node.source_file, str(node.callable_name), **kwargs
        )
    elif isinstance(node, ArchitectureFromFileDescr):
        imported = _import_from_file_impl(
            node.source, str(node.callable), sha256=node.sha256
        )
    else:
        assert_never(node)

    if callable(imported):
        return imported

    raise ValueError(f"{node} (imported: {imported}) is not callable")
def _import_from_file_impl(
    source: FileSource, callable_name: str, **kwargs: Unpack[HashKwargs]
):
    """Import the attribute `callable_name` from the custom module stored at `source`.

    The source is materialized in a fresh temporary directory (a zip archive is
    extracted, anything else is written as a single ``.py`` file), imported
    under a module name derived from the file name and its SHA-256, and
    registered in `sys.modules`. A module already registered under that name
    is reused instead of being imported again.

    Args:
        source: location of the Python source file or zip archive.
        callable_name: name of the attribute to fetch from the imported module.
        **kwargs: hash keyword arguments forwarded to `FileDescr`.

    Returns:
        The attribute named `callable_name`. Whether it is actually callable
        is not checked here.

    Raises:
        ImportError: if the module cannot be created or executed.
        AttributeError: if the attribute is missing or accessing it fails.
    """
    src_descr = FileDescr(source=source, **kwargs)
    # ensure sha is valid even if perform_io_checks=False
    # or the source has changed since last sha computation
    src_descr.validate_sha256(force_recompute=True)
    assert src_descr.sha256 is not None
    source_sha = src_descr.sha256

    reader = src_descr.get_reader()
    # the sha suffix keeps module names unique per file content
    module_name = f"{reader.original_file_name.split('.')[0]}_{source_sha}"

    # fall back to a generic prefix if the file name does not yield a valid identifier
    if not module_name.isidentifier():
        module_name = f"custom_module_{source_sha}"
    assert module_name.isidentifier(), module_name

    source_bytes = reader.read()

    module = sys.modules.get(module_name)
    if module is None:
        try:
            # NOTE(review): the temporary directory is never cleaned up
            # explicitly here; presumably removal is left to the
            # `TemporaryDirectory` finalizer -- confirm this is intended.
            tmp_dir = TemporaryDirectory(ignore_cleanup_errors=True)
            module_path = Path(tmp_dir.name) / module_name
            if reader.original_file_name.endswith(".zip") or is_zipfile(reader):
                module_path.mkdir()
                ZipFile(reader).extractall(path=module_path)
            else:
                module_path = module_path.with_suffix(".py")
                _ = module_path.write_bytes(source_bytes)

            importlib_spec = importlib.util.spec_from_file_location(
                module_name, str(module_path)
            )

            if importlib_spec is None:
                raise ImportError(f"Failed to import {source}")

            module = importlib.util.module_from_spec(importlib_spec)

            sys.modules[module_name] = module  # cache this module

            assert importlib_spec.loader is not None
            importlib_spec.loader.exec_module(module)

        except Exception as e:
            # The failure may have happened before the module was registered
            # (e.g. while extracting the archive or creating the spec), so the
            # entry must be removed tolerantly; a plain `del` would raise a
            # KeyError here and hide the actual cause.
            _ = sys.modules.pop(module_name, None)
            raise ImportError(f"Failed to import {source}") from e

    try:
        callable_attr = getattr(module, callable_name)
    except AttributeError as e:
        raise AttributeError(
            f"Imported custom module from {source} has no `{callable_name}` attribute."
        ) from e
    except Exception as e:
        raise AttributeError(
            f"Failed to access `{callable_name}` attribute from custom module imported from {source} ."
        ) from e
    else:
        return callable_attr
def get_axes_infos(
    io_descr: Union[
        v0_4.InputTensorDescr,
        v0_4.OutputTensorDescr,
        v0_5.InputTensorDescr,
        v0_5.OutputTensorDescr,
    ],
) -> List[AxisInfo]:
    """Translate the axes of a tensor description into a list of `AxisInfo`.

    Axes derived from `v0_5.AxisBase` are converted via their `id` and `type`;
    any other axis is expected to be one of the single-letter axis names
    ``b, i, t, c, z, y, x`` and is handed to `AxisInfo.create` unchanged.
    """
    infos: List[AxisInfo] = []
    for axis in io_descr.axes:
        if not isinstance(axis, v0_5.AxisBase):
            assert axis in ("b", "i", "t", "c", "z", "y", "x")
            infos.append(AxisInfo.create(axis))
            continue

        infos.append(AxisInfo.create(Axis(id=axis.id, type=axis.type)))

    return infos
def get_member_id(
    tensor_description: Union[
        v0_4.InputTensorDescr,
        v0_4.OutputTensorDescr,
        v0_5.InputTensorDescr,
        v0_5.OutputTensorDescr,
    ],
) -> MemberId:
    """Return the tensor's identifier in a form usable as a sample member ID.

    For v0_4 descriptions the tensor `name` is wrapped in `MemberId`; v0_5
    descriptions already carry an `id` that is returned as is.
    """
    v0_4_descrs = (v0_4.InputTensorDescr, v0_4.OutputTensorDescr)
    if isinstance(tensor_description, v0_4_descrs):
        return MemberId(tensor_description.name)

    v0_5_descrs = (v0_5.InputTensorDescr, v0_5.OutputTensorDescr)
    if isinstance(tensor_description, v0_5_descrs):
        return tensor_description.id

    assert_never(tensor_description)
def get_member_ids(
    tensor_descriptions: Sequence[
        Union[
            v0_4.InputTensorDescr,
            v0_4.OutputTensorDescr,
            v0_5.InputTensorDescr,
            v0_5.OutputTensorDescr,
        ]
    ],
) -> List[MemberId]:
    """Return the sample member ID of every given tensor description, in order."""
    return list(map(get_member_id, tensor_descriptions))
def get_test_inputs(model: AnyModelDescr) -> Sample:
    """Load the test input tensors of `model` into a sample with id "test-sample".

    v0_4 models list their test inputs on the model itself (`test_inputs`);
    for other models each input description carries its own `test_tensor`.
    """
    ids = get_member_ids(model.inputs)
    sources = (
        model.test_inputs
        if isinstance(model, v0_4.ModelDescr)
        else [d.test_tensor for d in model.inputs]
    )
    arrays = list(map(load_array, sources))
    axes_per_member = list(map(get_axes_infos, model.inputs))

    members: Dict[MemberId, Tensor] = {}
    for member_id, array, dims in zip(ids, arrays, axes_per_member):
        members[member_id] = Tensor.from_numpy(array, dims=dims)

    return Sample(members=members, stat={}, id="test-sample")
def get_test_outputs(model: AnyModelDescr) -> Sample:
    """Load the test output tensors of `model` into a sample with id "test-sample".

    v0_4 models list their test outputs on the model itself (`test_outputs`);
    for other models each output description carries its own `test_tensor`.
    """
    ids = get_member_ids(model.outputs)
    sources = (
        model.test_outputs
        if isinstance(model, v0_4.ModelDescr)
        else [d.test_tensor for d in model.outputs]
    )
    arrays = list(map(load_array, sources))
    axes_per_member = list(map(get_axes_infos, model.outputs))

    members: Dict[MemberId, Tensor] = {}
    for member_id, array, dims in zip(ids, arrays, axes_per_member):
        members[member_id] = Tensor.from_numpy(array, dims=dims)

    return Sample(members=members, stat={}, id="test-sample")
class IO_SampleBlockMeta(NamedTuple):
    """Block meta data of an input sample paired with that of the corresponding output sample."""

    # block meta data on the model input side
    input: SampleBlockMeta
    # block meta data on the model output side
    output: SampleBlockMeta
def get_input_halo(model: v0_5.ModelDescr, output_halo: PerMember[PerAxis[Halo]]):
    """Return the halo that input tensors need to be divided into blocks with,
    such that `output_halo` can be cropped from their outputs without
    introducing gaps.

    Args:
        model: model description providing the input/output tensor axes.
        output_halo: halo per output tensor and output axis.

    Returns:
        A mapping from the referenced tensor id to the referenced axis id to
        the (symmetric) halo required on that axis.

    Raises:
        ValueError: if the size of an output axis with halo is not given as a
            `v0_5.SizeReference`, i.e. it cannot be mapped to another axis.
    """
    input_halo: Dict[MemberId, Dict[AxisId, Halo]] = {}
    outputs = {t.id: t for t in model.outputs}
    all_tensors = {**{t.id: t for t in model.inputs}, **outputs}

    for t, th in output_halo.items():
        axes = {a.id: a for a in outputs[t].axes}

        for a, ah in th.items():
            axis = axes[a]
            s = axis.size
            if not isinstance(s, v0_5.SizeReference):
                raise ValueError(
                    f"Unable to map output halo for {t}.{a} to an input axis"
                )

            ref_axis = {ra.id: ra for ra in all_tensors[s.tensor_id].axes}[s.axis_id]

            # rescale the halo from the output axis to the referenced axis
            total_output_halo = sum(ah)
            total_input_halo = total_output_halo * axis.scale / ref_axis.scale
            assert (
                total_input_halo == int(total_input_halo) and total_input_halo % 2 == 0
            ), f"halo for {t}.{a} does not map to an even, integral halo: {total_input_halo}"

            # Key the result by the *referenced* axis id: the halo applies to
            # the axis the output axis points to, whose id may differ from the
            # id of the output axis itself.
            half = int(total_input_halo // 2)
            input_halo.setdefault(s.tensor_id, {})[s.axis_id] = Halo(half, half)

    return input_halo
def get_block_transform(
    model: v0_5.ModelDescr,
) -> PerMember[PerAxis[Union[LinearSampleAxisTransform, int]]]:
    """Describe, per output tensor axis, how its size follows from the model inputs.

    Each output axis maps to one of:
      * the transform of the first batch axis found among the inputs
        (for axes whose size is `None`),
      * the fixed integer size of the axis,
      * ``-1`` for a `v0_5.DataDependentSize`,
      * a `LinearSampleAxisTransform` built from a `v0_5.SizeReference`.

    Raises:
        ValueError: if an output has a batch axis but no input has one.
    """
    # first batch axis encountered while walking the inputs in order (if any)
    batch_transform = next(
        (
            LinearSampleAxisTransform(
                axis=axis.id, scale=1, offset=0, member=tensor.id
            )
            for tensor in model.inputs
            for axis in tensor.axes
            if axis.type == "batch"
        ),
        None,
    )

    scale_of = {}
    for tensor in chain(model.inputs, model.outputs):
        scale_of[tensor.id] = {axis.id: axis.scale for axis in tensor.axes}

    transforms: Dict[MemberId, Dict[AxisId, Union[LinearSampleAxisTransform, int]]] = {}
    for out in model.outputs:
        per_axis: Dict[AxisId, Union[LinearSampleAxisTransform, int]] = {}
        for axis in out.axes:
            size = axis.size
            if size is None:
                assert axis.type == "batch"
                if batch_transform is None:
                    raise ValueError(
                        "no batch axis found in any input tensor, but output tensor"
                        + f" '{out.id}' has one."
                    )
                per_axis[axis.id] = batch_transform
            elif isinstance(size, int):
                per_axis[axis.id] = size
            elif isinstance(size, v0_5.DataDependentSize):
                per_axis[axis.id] = -1
            elif isinstance(size, v0_5.SizeReference):
                per_axis[axis.id] = LinearSampleAxisTransform(
                    axis=size.axis_id,
                    scale=scale_of[size.tensor_id][size.axis_id] / axis.scale,
                    offset=size.offset,
                    member=size.tensor_id,
                )
            else:
                assert_never(size)

        transforms[out.id] = per_axis

    return transforms
def get_io_sample_block_metas(
    model: v0_5.ModelDescr,
    input_sample_shape: PerMember[PerAxis[int]],
    ns: Mapping[Tuple[MemberId, AxisId], ParameterizedSize_N],
    batch_size: int = 1,
) -> Tuple[TotalNumberOfBlocks, Iterable[IO_SampleBlockMeta]]:
    """Return the number of blocks and an iterable yielding meta data for
    corresponding input and output sample blocks.

    Args:
        model: a v0_5 model description.
        input_sample_shape: shape of the (whole) input sample per member and axis.
        ns: passed on to `model.get_axis_sizes` to determine the block shape.
        batch_size: passed on to `model.get_axis_sizes`.

    Returns:
        The total number of input blocks and a generator of
        `IO_SampleBlockMeta` pairs (input block meta, transformed output
        block meta).

    Raises:
        TypeError: if `model` is not a `v0_5.ModelDescr`.
    """
    if not isinstance(model, v0_5.ModelDescr):
        raise TypeError(
            f"get_io_sample_block_metas() not implemented for {type(model)}"
        )

    block_axis_sizes = model.get_axis_sizes(ns=ns, batch_size=batch_size)

    # regroup the flat {(tensor, axis): size} mapping per tensor in one pass
    input_block_shape: Dict[MemberId, Dict[AxisId, int]] = {}
    for (tensor_id, axis_id), size in block_axis_sizes.inputs.items():
        input_block_shape.setdefault(tensor_id, {})[axis_id] = size

    output_halo = {
        t.id: {
            a.id: Halo(a.halo, a.halo) for a in t.axes if isinstance(a, v0_5.WithHalo)
        }
        for t in model.outputs
    }
    input_halo = get_input_halo(model, output_halo)

    n_input_blocks, input_blocks = split_multiple_shapes_into_blocks(
        input_sample_shape, input_block_shape, halo=input_halo
    )
    block_transform = get_block_transform(model)
    return n_input_blocks, (
        IO_SampleBlockMeta(ipt, ipt.get_transformed(block_transform))
        for ipt in sample_block_meta_generator(
            input_blocks, sample_shape=input_sample_shape, sample_id=None
        )
    )
def get_tensor(
    src: Union[ZipPath, TensorSource],
    ipt: Union[v0_4.InputTensorDescr, v0_5.InputTensorDescr],
):
    """Turn `src` into a `Tensor`, casting or loading it as needed.

    A `Tensor` is returned unchanged and an xarray `DataArray` is converted
    directly. A numpy array gets its dims from the axes of `ipt`; any other
    source is passed to `load_tensor` together with those axes.
    """
    if isinstance(src, Tensor):
        return src

    if isinstance(src, xr.DataArray):
        return Tensor.from_xarray(src)

    if isinstance(src, np.ndarray):
        return Tensor.from_numpy(src, dims=get_axes_infos(ipt))

    return load_tensor(src, axes=get_axes_infos(ipt))
def create_sample_for_model(
    model: AnyModelDescr,
    *,
    stat: Optional[Stat] = None,
    sample_id: SampleId = None,
    inputs: Union[PerMember[TensorSource], TensorSource],
) -> Sample:
    """Build a `Sample` for `model` from one set of input(s).

    Args:
        model: a bioimage.io model description
        stat: dictionary with sample and dataset statistics (may be updated in-place!)
        sample_id: used as the id of the created sample
        inputs: the input(s) constituting a single sample; either a mapping
            keyed by member id or, for single-input models, the bare source.

    Raises:
        TypeError: if `inputs` is not a mapping although the model has more
            or fewer than one input.
        ValueError: if `inputs` has keys unknown to the model or lacks a
            non-optional model input.
    """
    model_inputs = {}
    for descr in model.inputs:
        model_inputs[get_member_id(descr)] = descr

    if isinstance(inputs, collections.abc.Mapping):
        given = {MemberId(k): v for k, v in inputs.items()}
    elif len(model_inputs) != 1:
        raise TypeError(
            f"Expected `inputs` to be a mapping with keys {tuple(model_inputs)}"
        )
    else:
        # a bare source is attributed to the model's only input
        (only_member,) = model_inputs
        given = {only_member: inputs}

    unknown = {k for k in given if k not in model_inputs}
    if unknown:
        raise ValueError(f"Got unexpected inputs: {unknown}")

    missing = set()
    for member, descr in model_inputs.items():
        if member in given:
            continue
        if isinstance(descr, v0_5.InputTensorDescr) and descr.optional:
            continue
        missing.add(member)

    if missing:
        raise ValueError(f"Missing non-optional model inputs: {missing}")

    members = {}
    for member, descr in model_inputs.items():
        if member in given:
            members[member] = get_tensor(given[member], descr)

    return Sample(
        members=members,
        stat={} if stat is None else stat,
        id=sample_id,
    )
def load_sample_for_model(
    *,
    model: AnyModelDescr,
    paths: PerMember[Path],
    axes: Optional[PerMember[Sequence[AxisLike]]] = None,
    stat: Optional[Stat] = None,
    sample_id: Optional[SampleId] = None,
):
    """Load a single sample from `paths` that can be processed by `model`.

    Args:
        model: a bioimage.io model description
        paths: file path per model input (keyed by member id)
        axes: optional axes per member to load the file with; members without
            an entry use the axes of the corresponding model input.
        stat: statistics to attach to the sample (an empty dict if omitted)
        sample_id: id of the sample; if `None`, the sorted tuple of `paths`
            values is used instead.

    Raises:
        ValueError: if `paths` or `axes` contain members unknown to the model.
    """
    if axes is None:
        axes = {}

    # make sure members are keyed by MemberId, not string
    # (this also copies `axes`, so the caller's mapping is not modified below)
    paths = {MemberId(k): v for k, v in paths.items()}
    axes = {MemberId(k): v for k, v in axes.items()}

    model_inputs = {get_member_id(d): d for d in model.inputs}

    if unknown := {k for k in paths if k not in model_inputs}:
        raise ValueError(f"Got unexpected paths for {unknown}")

    if unknown := {k for k in axes if k not in model_inputs}:
        raise ValueError(f"Got unexpected axes hints for: {unknown}")

    members: Dict[MemberId, Tensor] = {}
    for m, p in paths.items():
        if m not in axes:
            axes[m] = get_axes_infos(model_inputs[m])
            logger.debug(
                "loading '{}' from {} with default input axes {} ",
                m,
                p,
                axes[m],
            )
        members[m] = load_tensor(p, axes[m])

    # Only fall back to the path-derived id when no id was given at all;
    # a truthiness test would also discard valid falsy ids such as 0 or "".
    if sample_id is None:
        sample_id = tuple(sorted(paths.values()))

    return Sample(
        members=members,
        stat={} if stat is None else stat,
        id=sample_id,
    )