Coverage for bioimageio/core/digest_spec.py: 90%
162 statements
« prev ^ index » next coverage.py v7.6.7, created at 2024-11-19 09:02 +0000
« prev ^ index » next coverage.py v7.6.7, created at 2024-11-19 09:02 +0000
1from __future__ import annotations
3import importlib.util
4from itertools import chain
5from pathlib import Path
6from typing import (
7 Any,
8 Callable,
9 Dict,
10 Iterable,
11 List,
12 Mapping,
13 NamedTuple,
14 Optional,
15 Sequence,
16 Tuple,
17 Union,
18)
20import numpy as np
21import xarray as xr
22from loguru import logger
23from numpy.typing import NDArray
24from typing_extensions import Unpack, assert_never
26from bioimageio.spec._internal.io import resolve_and_extract
27from bioimageio.spec._internal.io_utils import HashKwargs
28from bioimageio.spec.common import FileSource
29from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5
30from bioimageio.spec.model.v0_4 import CallableFromDepencency, CallableFromFile
31from bioimageio.spec.model.v0_5 import (
32 ArchitectureFromFileDescr,
33 ArchitectureFromLibraryDescr,
34 ParameterizedSize_N,
35)
36from bioimageio.spec.utils import load_array
38from .axis import AxisId, AxisInfo, AxisLike, PerAxis
39from .block_meta import split_multiple_shapes_into_blocks
40from .common import Halo, MemberId, PerMember, SampleId, TotalNumberOfBlocks
41from .io import load_tensor
42from .sample import (
43 LinearSampleAxisTransform,
44 Sample,
45 SampleBlockMeta,
46 sample_block_meta_generator,
47)
48from .stat_measures import Stat
49from .tensor import Tensor
def import_callable(
    node: Union[CallableFromDepencency, ArchitectureFromLibraryDescr],
    /,
    **kwargs: Unpack[HashKwargs],
) -> Callable[..., Any]:
    """Import the callable (e.g. a torch.nn.Module) that a spec node refers to.

    Handled node kinds:
    - ``CallableFromDepencency`` / ``ArchitectureFromLibraryDescr``: the callable
      is looked up by name in an importable module.
    - ``CallableFromFile`` / ``ArchitectureFromFileDescr``: the referenced source
      file is executed as a module and the callable is looked up in it.

    Args:
        node: spec node describing where the callable lives.
        **kwargs: hash arguments forwarded when resolving a ``CallableFromFile``
            source; an ``ArchitectureFromFileDescr`` uses its own ``sha256``.

    Returns:
        The imported object.

    Raises:
        ValueError: if the imported object is not callable.
    """
    if isinstance(node, CallableFromDepencency):
        imported = getattr(
            importlib.import_module(node.module_name), str(node.callable_name)
        )
    elif isinstance(node, ArchitectureFromLibraryDescr):
        imported = getattr(
            importlib.import_module(node.import_from), str(node.callable)
        )
    elif isinstance(node, CallableFromFile):
        imported = _import_from_file_impl(
            node.source_file, str(node.callable_name), **kwargs
        )
    elif isinstance(node, ArchitectureFromFileDescr):
        imported = _import_from_file_impl(
            node.source, str(node.callable), sha256=node.sha256
        )
    else:
        assert_never(node)

    if callable(imported):
        return imported

    raise ValueError(f"{node} (imported: {imported}) is not callable")
def _import_from_file_impl(
    source: FileSource, callable_name: str, **kwargs: Unpack[HashKwargs]
) -> Any:
    """Execute a Python source file as a module and return one of its attributes.

    Args:
        source: file source, resolved to a local file via ``resolve_and_extract``.
        callable_name: name of the attribute to fetch from the executed module.
        **kwargs: hash arguments forwarded to ``resolve_and_extract``.

    Returns:
        The attribute ``callable_name`` of the executed module. Callers are
        responsible for checking that it is actually callable.

    Raises:
        ImportError: if no import spec, or a spec without a loader, is produced
            for the file.
        AttributeError: if the executed module has no attribute ``callable_name``.
    """
    local_file = resolve_and_extract(source, **kwargs)
    module_name = local_file.path.stem
    importlib_spec = importlib.util.spec_from_file_location(
        module_name, local_file.path
    )
    # `ModuleSpec.loader` is optional; without this check a loader-less spec
    # would fail below with an opaque AttributeError on `None.exec_module`.
    if importlib_spec is None or importlib_spec.loader is None:
        raise ImportError(f"Failed to import {module_name} from {source}.")

    dep = importlib.util.module_from_spec(importlib_spec)
    # NOTE: the module is executed without being registered in `sys.modules`.
    importlib_spec.loader.exec_module(dep)
    return getattr(dep, callable_name)
def get_axes_infos(
    io_descr: Union[
        v0_4.InputTensorDescr,
        v0_4.OutputTensorDescr,
        v0_5.InputTensorDescr,
        v0_5.OutputTensorDescr,
    ],
) -> List[AxisInfo]:
    """Build a unified, simplified ``AxisInfo`` list from a tensor's spec axes.

    Axes given as plain strings that are not one of the letters
    ``b, i, t, c, z, y, x`` are mapped to an ``"i"`` axis; every other axis
    (known letters as well as non-string axis objects) is passed to
    ``AxisInfo.create`` unchanged.
    """
    known_letters = ("b", "i", "t", "c", "z", "y", "x")

    def _to_axis_info(axis: Any) -> AxisInfo:
        if isinstance(axis, str) and axis not in known_letters:
            return AxisInfo.create("i")

        return AxisInfo.create(axis)

    return [_to_axis_info(axis) for axis in io_descr.axes]
def get_member_id(
    tensor_description: Union[
        v0_4.InputTensorDescr,
        v0_4.OutputTensorDescr,
        v0_5.InputTensorDescr,
        v0_5.OutputTensorDescr,
    ],
) -> MemberId:
    """Return the normalized tensor ID, usable as a sample member ID.

    v0_4 descriptions are identified by their ``name`` (wrapped in
    ``MemberId``); v0_5 descriptions carry an ``id`` that is returned as is.
    """
    v0_4_types = (v0_4.InputTensorDescr, v0_4.OutputTensorDescr)
    if isinstance(tensor_description, v0_4_types):
        return MemberId(tensor_description.name)

    v0_5_types = (v0_5.InputTensorDescr, v0_5.OutputTensorDescr)
    if isinstance(tensor_description, v0_5_types):
        return tensor_description.id

    assert_never(tensor_description)
def get_member_ids(
    tensor_descriptions: Sequence[
        Union[
            v0_4.InputTensorDescr,
            v0_4.OutputTensorDescr,
            v0_5.InputTensorDescr,
            v0_5.OutputTensorDescr,
        ]
    ],
) -> List[MemberId]:
    """Return the normalized member ID of every tensor description, in order
    (see ``get_member_id``)."""
    return list(map(get_member_id, tensor_descriptions))
def get_test_inputs(model: AnyModelDescr) -> Sample:
    """Return a model's test input sample.

    The test arrays come from ``model.test_inputs`` for v0_4 descriptions and
    from each input's ``test_tensor`` otherwise. Each array is wrapped in a
    ``Tensor`` using the axes of the corresponding input description. The
    sample has empty statistics and the id ``"test-sample"``.
    """
    input_descrs = model.inputs
    member_ids = get_member_ids(input_descrs)

    if isinstance(model, v0_4.ModelDescr):
        test_arrays = list(map(load_array, model.test_inputs))
    else:
        test_arrays = [load_array(descr.test_tensor) for descr in input_descrs]

    dims_per_input = list(map(get_axes_infos, input_descrs))

    members = {}
    for member_id, array, dims in zip(member_ids, test_arrays, dims_per_input):
        members[member_id] = Tensor.from_numpy(array, dims=dims)

    return Sample(members=members, stat={}, id="test-sample")
def get_test_outputs(model: AnyModelDescr) -> Sample:
    """Return a model's test output sample.

    The test arrays come from ``model.test_outputs`` for v0_4 descriptions and
    from each output's ``test_tensor`` otherwise. Each array is wrapped in a
    ``Tensor`` using the axes of the corresponding output description. The
    sample has empty statistics and the id ``"test-sample"``.
    """
    output_descrs = model.outputs
    member_ids = get_member_ids(output_descrs)

    if isinstance(model, v0_4.ModelDescr):
        test_arrays = list(map(load_array, model.test_outputs))
    else:
        test_arrays = [load_array(descr.test_tensor) for descr in output_descrs]

    dims_per_output = list(map(get_axes_infos, output_descrs))

    members = {}
    for member_id, array, dims in zip(member_ids, test_arrays, dims_per_output):
        members[member_id] = Tensor.from_numpy(array, dims=dims)

    return Sample(members=members, stat={}, id="test-sample")
class IO_SampleBlockMeta(NamedTuple):
    """Metadata of one input sample block paired with the output sample block
    derived from it (as produced by ``get_io_sample_block_metas``)."""

    # metadata of the input block
    input: SampleBlockMeta
    # metadata of the corresponding output block
    output: SampleBlockMeta
def get_input_halo(model: v0_5.ModelDescr, output_halo: PerMember[PerAxis[Halo]]):
    """Return the halo that input tensors need to be divided into blocks with, such
    that `output_halo` can be cropped from their outputs without introducing gaps.

    Every output axis with a halo must have its size given as a
    ``v0_5.SizeReference``. The total output halo is rescaled by
    ``output_axis.scale / referenced_axis.scale`` and split evenly to both sides
    of the *referenced* axis.

    Args:
        model: v0_5 model description.
        output_halo: halo per output member and output axis.

    Returns:
        Halo per referenced member and referenced axis id.

    Raises:
        ValueError: if an output axis with a halo is not size-referenced, or if
            the rescaled total halo is not an even integer.
    """
    input_halo: Dict[MemberId, Dict[AxisId, Halo]] = {}
    outputs = {t.id: t for t in model.outputs}
    all_tensors = {**{t.id: t for t in model.inputs}, **outputs}

    for t, th in output_halo.items():
        axes = {a.id: a for a in outputs[t].axes}

        for a, ah in th.items():
            axis = axes[a]
            s = axis.size
            if not isinstance(s, v0_5.SizeReference):
                raise ValueError(
                    f"Unable to map output halo for {t}.{a} to an input axis"
                )

            ref_axis = {ra.id: ra for ra in all_tensors[s.tensor_id].axes}[
                s.axis_id
            ]

            total_output_halo = sum(ah)
            total_input_halo = total_output_halo * axis.scale / ref_axis.scale
            # explicit check instead of `assert`, which is stripped under `-O`
            if total_input_halo != int(total_input_halo) or total_input_halo % 2 != 0:
                raise ValueError(
                    f"Output halo {ah} of {t}.{a} maps to a total halo of"
                    + f" {total_input_halo} for {s.tensor_id}.{s.axis_id}, which is"
                    + " not an even integer."
                )

            half = int(total_input_halo // 2)
            # Key by the *referenced* axis id: the halo belongs to the axis the
            # output size is derived from, whose id may differ from `a`.
            # NOTE(review): if several output axes reference the same axis, the
            # last one processed wins.
            input_halo.setdefault(s.tensor_id, {})[s.axis_id] = Halo(half, half)

    return input_halo
def get_block_transform(
    model: v0_5.ModelDescr,
) -> PerMember[PerAxis[Union[LinearSampleAxisTransform, int]]]:
    """Describe how each output axis size relates to the input shapes.

    For every output member and axis the mapping holds:
    - the first input batch axis as a ``LinearSampleAxisTransform`` (scale 1,
      offset 0) when the output axis ``size`` is ``None`` (must be a batch axis),
    - the fixed ``int`` size,
    - ``-1`` for a ``v0_5.DataDependentSize``,
    - a ``LinearSampleAxisTransform`` for a ``v0_5.SizeReference``, with scale
      ``referenced_axis.scale / output_axis.scale`` and the reference offset.

    Raises:
        ValueError: if an output has a size-less batch axis but no input has a
            batch axis.
    """
    # first batch axis over all inputs, in declaration order (None if absent)
    batch_axis_trf = next(
        (
            LinearSampleAxisTransform(axis=axis.id, scale=1, offset=0, member=ipt.id)
            for ipt in model.inputs
            for axis in ipt.axes
            if axis.type == "batch"
        ),
        None,
    )
    axis_scales = {
        descr.id: {axis.id: axis.scale for axis in descr.axes}
        for descr in chain(model.inputs, model.outputs)
    }

    def _transform_for(out_id: Any, axis: Any) -> Union[LinearSampleAxisTransform, int]:
        size = axis.size
        if size is None:
            assert axis.type == "batch"
            if batch_axis_trf is None:
                raise ValueError(
                    "no batch axis found in any input tensor, but output tensor"
                    + f" '{out_id}' has one."
                )
            return batch_axis_trf

        if isinstance(size, int):
            return size

        if isinstance(size, v0_5.DataDependentSize):
            return -1

        if isinstance(size, v0_5.SizeReference):
            return LinearSampleAxisTransform(
                axis=size.axis_id,
                scale=axis_scales[size.tensor_id][size.axis_id] / axis.scale,
                offset=size.offset,
                member=size.tensor_id,
            )

        assert_never(size)

    return {
        out.id: {axis.id: _transform_for(out.id, axis) for axis in out.axes}
        for out in model.outputs
    }
def get_io_sample_block_metas(
    model: v0_5.ModelDescr,
    input_sample_shape: PerMember[PerAxis[int]],
    ns: Mapping[Tuple[MemberId, AxisId], ParameterizedSize_N],
    batch_size: int = 1,
) -> Tuple[TotalNumberOfBlocks, Iterable[IO_SampleBlockMeta]]:
    """Return the number of input blocks and an iterable yielding metadata for
    corresponding input and output sample blocks.

    Args:
        model: model description; only v0_5 models are supported.
        input_sample_shape: full (unblocked) shape of each input member.
        ns: parameterized-size ``n`` per (member, axis), passed to
            ``model.get_axis_sizes`` to obtain the block axis sizes.
        batch_size: batch size passed to ``model.get_axis_sizes``.

    Returns:
        ``(n_input_blocks, metas)``, where ``metas`` lazily yields one
        ``IO_SampleBlockMeta`` per input block, pairing it with the block
        transformed by ``get_block_transform(model)``.

    Raises:
        TypeError: if `model` is not a ``v0_5.ModelDescr``.
    """
    # runtime guard: the annotation alone does not stop other description versions
    if not isinstance(model, v0_5.ModelDescr):
        raise TypeError(
            f"get_io_sample_block_metas() not implemented for {type(model)}"
        )

    block_axis_sizes = model.get_axis_sizes(ns=ns, batch_size=batch_size)
    # group the flat {(member, axis): size} mapping into {member: {axis: size}}
    # in a single pass
    input_block_shape: Dict[MemberId, Dict[AxisId, int]] = {}
    for (member_id, axis_id), size in block_axis_sizes.inputs.items():
        input_block_shape.setdefault(member_id, {})[axis_id] = size

    # output axes that declare a halo get it on both sides
    output_halo = {
        t.id: {
            a.id: Halo(a.halo, a.halo) for a in t.axes if isinstance(a, v0_5.WithHalo)
        }
        for t in model.outputs
    }
    input_halo = get_input_halo(model, output_halo)

    n_input_blocks, input_blocks = split_multiple_shapes_into_blocks(
        input_sample_shape, input_block_shape, halo=input_halo
    )
    block_transform = get_block_transform(model)
    return n_input_blocks, (
        IO_SampleBlockMeta(ipt, ipt.get_transformed(block_transform))
        for ipt in sample_block_meta_generator(
            input_blocks, sample_shape=input_sample_shape, sample_id=None
        )
    )
def get_tensor(
    src: Union[Tensor, xr.DataArray, NDArray[Any], Path],
    ipt: Union[v0_4.InputTensorDescr, v0_5.InputTensorDescr],
):
    """Cast or load a tensor source into a ``Tensor``.

    - ``Tensor``: returned unchanged.
    - ``xr.DataArray``: converted with ``Tensor.from_xarray``.
    - ``np.ndarray``: wrapped with the axes of `ipt`.
    - ``Path``: loaded with ``load_tensor`` using the axes of `ipt`.
    """
    if isinstance(src, Tensor):
        tensor = src
    elif isinstance(src, xr.DataArray):
        tensor = Tensor.from_xarray(src)
    elif isinstance(src, np.ndarray):
        tensor = Tensor.from_numpy(src, dims=get_axes_infos(ipt))
    elif isinstance(src, Path):
        tensor = load_tensor(src, axes=get_axes_infos(ipt))
    else:
        assert_never(src)

    return tensor
def create_sample_for_model(
    model: AnyModelDescr,
    *,
    stat: Optional[Stat] = None,
    sample_id: SampleId = None,
    inputs: Optional[
        PerMember[Union[Tensor, xr.DataArray, NDArray[Any], Path]]
    ] = None,  # TODO: make non-optional
    **kwargs: NDArray[Any],  # TODO: deprecate in favor of `inputs`
) -> Sample:
    """Create a sample from a single set of input(s) for a specific bioimage.io model.

    Args:
        model: a bioimage.io model description
        stat: dictionary with sample and dataset statistics (may be updated in-place!)
        sample_id: id assigned to the created sample.
        inputs: the input(s) constituting a single sample.
        **kwargs: additional inputs given by member name; on name clashes an
            entry in `inputs` takes precedence.

    Raises:
        ValueError: if an input does not match any model input, or if a
            non-optional model input is missing.
    """
    merged_inputs = {**kwargs, **(inputs or {})}
    provided = {MemberId(name): value for name, value in merged_inputs.items()}
    model_inputs = {get_member_id(descr): descr for descr in model.inputs}

    unknown = {k for k in provided if k not in model_inputs}
    if unknown:
        raise ValueError(f"Got unexpected inputs: {unknown}")

    # only v0_5 input descriptions can be marked optional
    missing = {
        k
        for k, descr in model_inputs.items()
        if k not in provided
        and (not isinstance(descr, v0_5.InputTensorDescr) or not descr.optional)
    }
    if missing:
        raise ValueError(f"Missing non-optional model inputs: {missing}")

    members = {
        m: get_tensor(provided[m], descr)
        for m, descr in model_inputs.items()
        if m in provided
    }
    return Sample(
        members=members,
        stat={} if stat is None else stat,
        id=sample_id,
    )
def load_sample_for_model(
    *,
    model: AnyModelDescr,
    paths: PerMember[Path],
    axes: Optional[PerMember[Sequence[AxisLike]]] = None,
    stat: Optional[Stat] = None,
    sample_id: Optional[SampleId] = None,
) -> Sample:
    """Load a single sample from `paths` that can be processed by `model`.

    Args:
        model: model description whose inputs the paths correspond to.
        paths: file path per input member.
        axes: optional axes hint per member; members without a hint are loaded
            with the axes from the model's input description.
        stat: statistics stored on the sample (a new empty dict if ``None``).
        sample_id: id of the sample; if ``None``, the sorted tuple of all
            paths is used.

    Returns:
        The loaded ``Sample``.

    Raises:
        ValueError: if `paths` or `axes` contain members unknown to the model.
    """
    # Make sure members are keyed by MemberId, not string. Building new dicts
    # also ensures the caller's `axes` mapping is never mutated below.
    paths = {MemberId(k): v for k, v in paths.items()}
    axes = {MemberId(k): v for k, v in (axes or {}).items()}

    model_inputs = {get_member_id(d): d for d in model.inputs}

    if unknown := {k for k in paths if k not in model_inputs}:
        raise ValueError(f"Got unexpected paths for {unknown}")

    if unknown := {k for k in axes if k not in model_inputs}:
        raise ValueError(f"Got unexpected axes hints for: {unknown}")

    members: Dict[MemberId, Tensor] = {}
    for m, p in paths.items():
        if m not in axes:
            axes[m] = get_axes_infos(model_inputs[m])
            logger.debug(
                "loading '{}' from {} with default input axes {} ",
                m,
                p,
                axes[m],
            )
        members[m] = load_tensor(p, axes[m])

    return Sample(
        members=members,
        stat={} if stat is None else stat,
        # `is None` rather than `or`: a falsy but valid id (e.g. 0 or "")
        # must not be replaced by the path-derived default
        id=sample_id if sample_id is not None else tuple(sorted(paths.values())),
    )