Coverage for bioimageio/core/digest_spec.py: 86%
205 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-01 09:51 +0000
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-01 09:51 +0000
1from __future__ import annotations
3import collections.abc
4import hashlib
5import importlib.util
6import sys
7from itertools import chain
8from pathlib import Path
9from typing import (
10 Any,
11 Callable,
12 Dict,
13 Iterable,
14 List,
15 Mapping,
16 NamedTuple,
17 Optional,
18 Sequence,
19 Tuple,
20 Union,
21)
23import numpy as np
24import xarray as xr
25from loguru import logger
26from numpy.typing import NDArray
27from typing_extensions import Unpack, assert_never
29from bioimageio.spec._internal.io import HashKwargs
30from bioimageio.spec.common import FileDescr, FileSource, ZipPath
31from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5
32from bioimageio.spec.model.v0_4 import CallableFromDepencency, CallableFromFile
33from bioimageio.spec.model.v0_5 import (
34 ArchitectureFromFileDescr,
35 ArchitectureFromLibraryDescr,
36 ParameterizedSize_N,
37)
38from bioimageio.spec.utils import download, load_array
40from ._settings import settings
41from .axis import Axis, AxisId, AxisInfo, AxisLike, PerAxis
42from .block_meta import split_multiple_shapes_into_blocks
43from .common import Halo, MemberId, PerMember, SampleId, TotalNumberOfBlocks
44from .io import load_tensor
45from .sample import (
46 LinearSampleAxisTransform,
47 Sample,
48 SampleBlockMeta,
49 sample_block_meta_generator,
50)
51from .stat_measures import Stat
52from .tensor import Tensor
# Anything `get_tensor` can turn into a `Tensor`: an existing `Tensor`,
# an xarray `DataArray`, a plain numpy array, or a path to a file on disk.
TensorSource = Union[Tensor, xr.DataArray, NDArray[Any], Path]
def import_callable(
    node: Union[
        ArchitectureFromFileDescr,
        ArchitectureFromLibraryDescr,
        CallableFromDepencency,
        CallableFromFile,
    ],
    /,
    **kwargs: Unpack[HashKwargs],
) -> Callable[..., Any]:
    """Resolve a spec node to the Python callable it describes.

    Library-style nodes (`CallableFromDepencency`,
    `ArchitectureFromLibraryDescr`) are resolved by importing the named module
    and fetching the attribute; file-style nodes (`CallableFromFile`,
    `ArchitectureFromFileDescr`) are delegated to `_import_from_file_impl`.

    Args:
        node: spec node describing where the callable lives.
        **kwargs: hash keyword arguments; only forwarded for `CallableFromFile`
            nodes (an `ArchitectureFromFileDescr` carries its own `sha256`).

    Returns:
        The imported callable (e.g. a `torch.nn.Module` subclass).

    Raises:
        ValueError: if the imported object is not callable.
    """
    if isinstance(node, CallableFromDepencency):
        imported = getattr(
            importlib.import_module(node.module_name), str(node.callable_name)
        )
    elif isinstance(node, ArchitectureFromLibraryDescr):
        imported = getattr(
            importlib.import_module(node.import_from), str(node.callable)
        )
    elif isinstance(node, CallableFromFile):
        imported = _import_from_file_impl(
            node.source_file, str(node.callable_name), **kwargs
        )
    elif isinstance(node, ArchitectureFromFileDescr):
        imported = _import_from_file_impl(
            node.source, str(node.callable), sha256=node.sha256
        )
    else:
        assert_never(node)

    if callable(imported):
        return imported

    raise ValueError(f"{node} (imported: {imported}) is not callable")
def _import_from_file_impl(
    source: FileSource, callable_name: str, **kwargs: Unpack[HashKwargs]
):
    """Import the attribute `callable_name` from the Python file at `source`.

    The file is downloaded (if necessary), hashed, and imported under a module
    name derived from its stem and SHA-256, so that different file contents
    never share a `sys.modules` entry. A module that was already imported under
    that name is reused.

    Args:
        source: location of the Python source file.
        callable_name: name of the attribute to fetch from the imported module.
        **kwargs: hash keyword arguments forwarded to `FileDescr`.

    Returns:
        The attribute named `callable_name` of the imported module.

    Raises:
        ImportError: if the file cannot be imported as a module.
        AttributeError: if the module has no (accessible) `callable_name`.
    """
    src_descr = FileDescr(source=source, **kwargs)
    # ensure sha is valid even if perform_io_checks=False
    src_descr.validate_sha256()
    assert src_descr.sha256 is not None

    local_source = src_descr.download()

    source_bytes = local_source.path.read_bytes()
    assert isinstance(source_bytes, bytes)
    source_sha = hashlib.sha256(source_bytes).hexdigest()

    # make sure we have unique module name
    module_name = f"{local_source.path.stem}_{source_sha}"

    # make sure we have a valid module name
    if not module_name.isidentifier():
        module_name = f"custom_module_{source_sha}"
        assert module_name.isidentifier(), module_name

    module = sys.modules.get(module_name)
    if module is None:
        try:
            if isinstance(local_source.path, Path):
                module_path = local_source.path
            elif isinstance(local_source.path, ZipPath):
                # save extract source to cache
                # loading from a file from disk ensure we get readable tracebacks
                # if any errors occur
                module_path = (
                    settings.cache_path / f"{source_sha}-{local_source.path.name}"
                )
                # the cache directory may not exist yet
                module_path.parent.mkdir(parents=True, exist_ok=True)
                _ = module_path.write_bytes(source_bytes)
            else:
                assert_never(local_source.path)

            importlib_spec = importlib.util.spec_from_file_location(
                module_name, module_path
            )

            if importlib_spec is None:
                raise ImportError(f"Failed to import {source}")

            module = importlib.util.module_from_spec(importlib_spec)
            assert importlib_spec.loader is not None
            # Register the module *before* executing it (as in the importlib
            # "importing a source file directly" recipe): code that looks up
            # `sys.modules[__name__]` while the module body runs would
            # otherwise fail.
            sys.modules[module_name] = module
            importlib_spec.loader.exec_module(module)

        except Exception as e:
            # do not leave a partially initialized module in the cache
            _ = sys.modules.pop(module_name, None)
            raise ImportError(f"Failed to import {source}") from e

    try:
        callable_attr = getattr(module, callable_name)
    except AttributeError as e:
        raise AttributeError(
            f"Imported custom module from {source} has no `{callable_name}` attribute."
        ) from e
    except Exception as e:
        raise AttributeError(
            f"Failed to access `{callable_name}` attribute from custom module imported from {source} ."
        ) from e

    else:
        return callable_attr
def get_axes_infos(
    io_descr: Union[
        v0_4.InputTensorDescr,
        v0_4.OutputTensorDescr,
        v0_5.InputTensorDescr,
        v0_5.OutputTensorDescr,
    ],
) -> List[AxisInfo]:
    """Translate the axes of a tensor description into a list of `AxisInfo`.

    `v0_5.AxisBase` axes are converted via their `id` and `type`; any other
    axis entry is expected to be one of the single-letter axis names
    ("b", "i", "t", "c", "z", "y", "x").
    """

    def _to_axis_info(axis: Any) -> AxisInfo:
        if isinstance(axis, v0_5.AxisBase):
            return AxisInfo.create(Axis(id=axis.id, type=axis.type))

        assert axis in ("b", "i", "t", "c", "z", "y", "x")
        return AxisInfo.create(axis)

    return [_to_axis_info(axis) for axis in io_descr.axes]
def get_member_id(
    tensor_description: Union[
        v0_4.InputTensorDescr,
        v0_4.OutputTensorDescr,
        v0_5.InputTensorDescr,
        v0_5.OutputTensorDescr,
    ],
) -> MemberId:
    """Return the sample member ID that corresponds to a tensor description.

    For v0_4 descriptions this is the tensor's `name` wrapped in `MemberId`;
    for v0_5 descriptions it is the tensor's `id`.
    """
    v0_4_descrs = (v0_4.InputTensorDescr, v0_4.OutputTensorDescr)
    if isinstance(tensor_description, v0_4_descrs):
        return MemberId(tensor_description.name)

    v0_5_descrs = (v0_5.InputTensorDescr, v0_5.OutputTensorDescr)
    if isinstance(tensor_description, v0_5_descrs):
        return tensor_description.id

    assert_never(tensor_description)
def get_member_ids(
    tensor_descriptions: Sequence[
        Union[
            v0_4.InputTensorDescr,
            v0_4.OutputTensorDescr,
            v0_5.InputTensorDescr,
            v0_5.OutputTensorDescr,
        ]
    ],
) -> List[MemberId]:
    """Return the sample member ID of each tensor description, in order."""
    return list(map(get_member_id, tensor_descriptions))
def get_test_inputs(model: AnyModelDescr) -> Sample:
    """Load the test input tensors of `model` into a single `Sample`.

    v0_4 models list their test inputs in `model.test_inputs`; otherwise each
    input description provides its own `test_tensor`.
    """
    input_descrs = model.inputs
    ids = get_member_ids(input_descrs)

    if isinstance(model, v0_4.ModelDescr):
        test_sources = model.test_inputs
    else:
        test_sources = [descr.test_tensor for descr in input_descrs]

    loaded = [load_array(src) for src in test_sources]
    dims_per_input = [get_axes_infos(descr) for descr in input_descrs]

    members = {}
    for member_id, array, dims in zip(ids, loaded, dims_per_input):
        members[member_id] = Tensor.from_numpy(array, dims=dims)

    return Sample(members=members, stat={}, id="test-sample")
def get_test_outputs(model: AnyModelDescr) -> Sample:
    """Load the test output tensors of `model` into a single `Sample`.

    v0_4 models list their test outputs in `model.test_outputs`; otherwise each
    output description provides its own `test_tensor`.
    """
    output_descrs = model.outputs
    ids = get_member_ids(output_descrs)

    if isinstance(model, v0_4.ModelDescr):
        test_sources = model.test_outputs
    else:
        test_sources = [descr.test_tensor for descr in output_descrs]

    loaded = [load_array(src) for src in test_sources]
    dims_per_output = [get_axes_infos(descr) for descr in output_descrs]

    members = {}
    for member_id, array, dims in zip(ids, loaded, dims_per_output):
        members[member_id] = Tensor.from_numpy(array, dims=dims)

    return Sample(members=members, stat={}, id="test-sample")
class IO_SampleBlockMeta(NamedTuple):
    """Pair of block meta data for an input sample block and its output block."""

    # meta data of the input sample block
    input: SampleBlockMeta
    # meta data of the corresponding output sample block
    output: SampleBlockMeta
def get_input_halo(model: v0_5.ModelDescr, output_halo: PerMember[PerAxis[Halo]]):
    """Return the halo with which input tensors need to be divided into blocks,
    such that `output_halo` can be cropped from their outputs without
    introducing gaps.

    Args:
        model: model description providing the input/output tensor axes.
        output_halo: halo per output tensor and output axis.

    Returns:
        A mapping from the referenced tensor id to a mapping from the referenced
        axis id to a symmetric `Halo`.

    Raises:
        ValueError: if an output axis with a halo does not have a
            `v0_5.SizeReference` as its size.
    """
    input_halo: Dict[MemberId, Dict[AxisId, Halo]] = {}
    outputs = {t.id: t for t in model.outputs}
    all_tensors = {**{t.id: t for t in model.inputs}, **outputs}

    for t, th in output_halo.items():
        axes = {a.id: a for a in outputs[t].axes}

        for a, ah in th.items():
            axis = axes[a]
            s = axis.size
            if not isinstance(s, v0_5.SizeReference):
                raise ValueError(
                    f"Unable to map output halo for {t}.{a} to an input axis"
                )

            ref_axis = {ra.id: ra for ra in all_tensors[s.tensor_id].axes}[s.axis_id]

            # rescale the summed output halo to the referenced axis
            total_output_halo = sum(ah)
            total_input_halo = total_output_halo * axis.scale / ref_axis.scale
            assert (
                total_input_halo == int(total_input_halo) and total_input_halo % 2 == 0
            ), f"halo for {t}.{a} does not map to an even, whole input halo"

            # Key the result by the *referenced* axis id: the halo applies to
            # the axis that `s` points at, whose id may differ from the output
            # axis id `a`.
            half_input_halo = int(total_input_halo // 2)
            input_halo.setdefault(s.tensor_id, {})[s.axis_id] = Halo(
                half_input_halo, half_input_halo
            )

    return input_halo
def get_block_transform(
    model: v0_5.ModelDescr,
) -> PerMember[PerAxis[Union[LinearSampleAxisTransform, int]]]:
    """Describe how each output axis size of `model` derives from its inputs.

    For every output tensor and axis the result holds either

    - a `LinearSampleAxisTransform` (batch axes and size references),
    - the fixed integer size of the axis, or
    - `-1` for axes with a data dependent size.

    Raises:
        ValueError: if an output tensor has a batch axis but no input tensor
            has one.
    """
    # transform for the first batch axis found among the inputs (if any)
    batch_axis_trf = next(
        (
            LinearSampleAxisTransform(
                axis=axis.id, scale=1, offset=0, member=ipt.id
            )
            for ipt in model.inputs
            for axis in ipt.axes
            if axis.type == "batch"
        ),
        None,
    )

    axis_scales = {}
    for tensor in chain(model.inputs, model.outputs):
        axis_scales[tensor.id] = {axis.id: axis.scale for axis in tensor.axes}

    def _transform_for(
        out_id: Any, axis: Any
    ) -> Union[LinearSampleAxisTransform, int]:
        size = axis.size
        if size is None:
            assert axis.type == "batch"
            if batch_axis_trf is None:
                raise ValueError(
                    "no batch axis found in any input tensor, but output tensor"
                    + f" '{out_id}' has one."
                )
            return batch_axis_trf

        if isinstance(size, int):
            return size

        if isinstance(size, v0_5.DataDependentSize):
            return -1

        if isinstance(size, v0_5.SizeReference):
            return LinearSampleAxisTransform(
                axis=size.axis_id,
                scale=axis_scales[size.tensor_id][size.axis_id] / axis.scale,
                offset=size.offset,
                member=size.tensor_id,
            )

        assert_never(size)

    return {
        out.id: {axis.id: _transform_for(out.id, axis) for axis in out.axes}
        for out in model.outputs
    }
def get_io_sample_block_metas(
    model: v0_5.ModelDescr,
    input_sample_shape: PerMember[PerAxis[int]],
    ns: Mapping[Tuple[MemberId, AxisId], ParameterizedSize_N],
    batch_size: int = 1,
) -> Tuple[TotalNumberOfBlocks, Iterable[IO_SampleBlockMeta]]:
    """Return an iterable yielding meta data for corresponding input and output
    sample blocks.

    Args:
        model: model description (v0_5 only).
        input_sample_shape: shape of the input sample per member and axis.
        ns: forwarded to `model.get_axis_sizes` to determine the block sizes.
        batch_size: forwarded to `model.get_axis_sizes`.

    Returns:
        The number of input blocks and a lazy iterable of `IO_SampleBlockMeta`.

    Raises:
        TypeError: if `model` is not a `v0_5.ModelDescr`.
    """
    if not isinstance(model, v0_5.ModelDescr):
        raise TypeError(
            f"get_io_sample_block_metas() not implemented for {type(model)}"
        )

    block_axis_sizes = model.get_axis_sizes(ns=ns, batch_size=batch_size)

    # group the flat {(tensor, axis): size} mapping by tensor in a single pass
    input_block_shape: Dict[MemberId, Dict[AxisId, int]] = {}
    for (tensor_id, axis_id), size in block_axis_sizes.inputs.items():
        input_block_shape.setdefault(tensor_id, {})[axis_id] = size

    # the halo declared on an output axis is applied on both sides
    output_halo = {
        t.id: {
            a.id: Halo(a.halo, a.halo) for a in t.axes if isinstance(a, v0_5.WithHalo)
        }
        for t in model.outputs
    }
    input_halo = get_input_halo(model, output_halo)

    n_input_blocks, input_blocks = split_multiple_shapes_into_blocks(
        input_sample_shape, input_block_shape, halo=input_halo
    )
    block_transform = get_block_transform(model)
    return n_input_blocks, (
        IO_SampleBlockMeta(ipt, ipt.get_transformed(block_transform))
        for ipt in sample_block_meta_generator(
            input_blocks, sample_shape=input_sample_shape, sample_id=None
        )
    )
def get_tensor(
    src: Union[ZipPath, TensorSource],
    ipt: Union[v0_4.InputTensorDescr, v0_5.InputTensorDescr],
):
    """Convert or load `src` into a `Tensor`.

    A `Tensor` is returned as is, an xarray `DataArray` is wrapped, a numpy
    array gets the axes described by `ipt`, a `FileDescr` is downloaded first,
    and paths (or strings) are loaded from disk with the axes described by
    `ipt`.
    """
    if isinstance(src, Tensor):
        tensor = src
    elif isinstance(src, xr.DataArray):
        tensor = Tensor.from_xarray(src)
    elif isinstance(src, np.ndarray):
        tensor = Tensor.from_numpy(src, dims=get_axes_infos(ipt))
    else:
        # resolve file descriptions to a local path before loading
        location = download(src).path if isinstance(src, FileDescr) else src
        if not isinstance(location, (ZipPath, Path, str)):
            assert_never(location)

        tensor = load_tensor(location, axes=get_axes_infos(ipt))

    return tensor
def create_sample_for_model(
    model: AnyModelDescr,
    *,
    stat: Optional[Stat] = None,
    sample_id: SampleId = None,
    inputs: Union[PerMember[TensorSource], TensorSource],
) -> Sample:
    """Build a `Sample` from one set of input(s) for a bioimage.io model.

    Args:
        model: a bioimage.io model description
        stat: dictionary with sample and dataset statistics (may be updated in-place!)
        sample_id: id assigned to the created sample
        inputs: the input(s) constituting a single sample; either a mapping
            keyed by member id or, for single-input models, the bare input.

    Raises:
        TypeError: if `inputs` is not a mapping although the model has several
            inputs.
        ValueError: on unexpected inputs or missing non-optional inputs.
    """
    descr_by_id = {get_member_id(descr): descr for descr in model.inputs}

    if isinstance(inputs, collections.abc.Mapping):
        given = {MemberId(key): value for key, value in inputs.items()}
    elif len(descr_by_id) == 1:
        # a bare input is accepted for models with exactly one input
        given = {next(iter(descr_by_id)): inputs}
    else:
        raise TypeError(
            f"Expected `inputs` to be a mapping with keys {tuple(descr_by_id)}"
        )

    unknown = {key for key in given if key not in descr_by_id}
    if unknown:
        raise ValueError(f"Got unexpected inputs: {unknown}")

    missing = set()
    for key, descr in descr_by_id.items():
        if key in given:
            continue

        if isinstance(descr, v0_5.InputTensorDescr) and descr.optional:
            continue

        missing.add(key)

    if missing:
        raise ValueError(f"Missing non-optional model inputs: {missing}")

    members = {}
    for member_id, descr in descr_by_id.items():
        if member_id in given:
            members[member_id] = get_tensor(given[member_id], descr)

    return Sample(
        members=members,
        stat={} if stat is None else stat,
        id=sample_id,
    )
def load_sample_for_model(
    *,
    model: AnyModelDescr,
    paths: PerMember[Path],
    axes: Optional[PerMember[Sequence[AxisLike]]] = None,
    stat: Optional[Stat] = None,
    sample_id: Optional[SampleId] = None,
) -> Sample:
    """Load a single sample from `paths` that can be processed by `model`.

    Args:
        model: a bioimage.io model description
        paths: file path per input member
        axes: optional axes hints per input member; members without a hint use
            the axes of the corresponding model input description
        stat: dictionary with sample and dataset statistics (used as is)
        sample_id: id of the sample; if `None`, the sorted tuple of `paths`
            values is used

    Raises:
        ValueError: if `paths` or `axes` refer to members unknown to `model`.
    """

    if axes is None:
        axes = {}

    # make sure members are keyed by MemberId, not string
    # (this also copies the mappings, so the caller's are never mutated)
    paths = {MemberId(k): v for k, v in paths.items()}
    axes = {MemberId(k): v for k, v in axes.items()}

    model_inputs = {get_member_id(d): d for d in model.inputs}

    if unknown := {k for k in paths if k not in model_inputs}:
        raise ValueError(f"Got unexpected paths for {unknown}")

    if unknown := {k for k in axes if k not in model_inputs}:
        raise ValueError(f"Got unexpected axes hints for: {unknown}")

    members: Dict[MemberId, Tensor] = {}
    for m, p in paths.items():
        if m not in axes:
            axes[m] = get_axes_infos(model_inputs[m])
            logger.debug(
                "loading '{}' from {} with default input axes {} ",
                m,
                p,
                axes[m],
            )
        members[m] = load_tensor(p, axes[m])

    # Only fall back to a path-derived id if no id was given at all; falsy ids
    # such as `0` or `""` are kept.
    if sample_id is None:
        sample_id = tuple(sorted(paths.values()))

    return Sample(
        members=members,
        stat={} if stat is None else stat,
        id=sample_id,
    )