Coverage for src / bioimageio / core / digest_spec.py: 85%

220 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-22 18:38 +0000

1from __future__ import annotations 

2 

3import collections.abc 

4import importlib.util 

5import sys 

6from itertools import chain 

7from pathlib import Path 

8from tempfile import TemporaryDirectory 

9from typing import ( 

10 Any, 

11 Callable, 

12 Dict, 

13 Iterable, 

14 List, 

15 Mapping, 

16 NamedTuple, 

17 Optional, 

18 Sequence, 

19 Tuple, 

20 Union, 

21) 

22from zipfile import ZipFile, is_zipfile 

23 

24import numpy as np 

25import xarray as xr 

26from loguru import logger 

27from numpy.typing import NDArray 

28from typing_extensions import Unpack, assert_never 

29 

30from bioimageio.spec._internal.io import HashKwargs, PermissiveFileSource 

31from bioimageio.spec.common import FileDescr, FileSource 

32from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5 

33from bioimageio.spec.model.v0_4 import CallableFromDepencency, CallableFromFile 

34from bioimageio.spec.model.v0_5 import ( 

35 ArchitectureFromFileDescr, 

36 ArchitectureFromLibraryDescr, 

37 ParameterizedSize_N, 

38) 

39from bioimageio.spec.utils import load_array 

40 

41from .axis import Axis, AxisId, AxisInfo, AxisLike, PerAxis 

42from .block_meta import split_multiple_shapes_into_blocks 

43from .common import Halo, MemberId, PerMember, SampleId, TotalNumberOfBlocks 

44from .io import load_tensor 

45from .sample import ( 

46 LinearSampleAxisTransform, 

47 Sample, 

48 SampleBlockMeta, 

49 sample_block_meta_generator, 

50) 

51from .stat_measures import Stat 

52from .tensor import Tensor 

53 

# Anything accepted as a tensor input: an in-memory tensor/array or a
# (permissive) file source to load one from.
TensorSource = Union[Tensor, xr.DataArray, NDArray[Any], PermissiveFileSource]

55 

56 

def import_callable(
    node: Union[
        ArchitectureFromFileDescr,
        ArchitectureFromLibraryDescr,
        CallableFromDepencency,
        CallableFromFile,
    ],
    /,
    **kwargs: Unpack[HashKwargs],
) -> Callable[..., Any]:
    """Import a callable (e.g. a torch.nn.Module) from a spec node describing it.

    Args:
        node: spec description of where the callable lives: an installed
            dependency, an importable library module, or a source file.
        **kwargs: hash information forwarded when importing from a file.

    Returns:
        The resolved callable.

    Raises:
        ValueError: if the resolved object is not callable.
    """
    if isinstance(node, CallableFromDepencency):
        resolved = getattr(
            importlib.import_module(node.module_name), str(node.callable_name)
        )
    elif isinstance(node, ArchitectureFromLibraryDescr):
        resolved = getattr(
            importlib.import_module(node.import_from), str(node.callable)
        )
    elif isinstance(node, CallableFromFile):
        resolved = _import_from_file_impl(
            node.source_file, str(node.callable_name), **kwargs
        )
    elif isinstance(node, ArchitectureFromFileDescr):
        resolved = _import_from_file_impl(
            node.source, str(node.callable), sha256=node.sha256
        )
    else:
        assert_never(node)

    if not callable(resolved):
        raise ValueError(f"{node} (imported: {resolved}) is not callable")

    return resolved

85 

86 

# Global registry of temporary directories created while importing modules
# from files/zip archives; holding references here delays their cleanup (and
# thus deletion of the imported module's source) until interpreter exit.
tmp_dirs_in_use: List[TemporaryDirectory[str]] = []
"""keep global reference to temporary directories created during import to delay cleanup"""

89 

90 

def _import_from_file_impl(
    source: FileSource, callable_name: str, **kwargs: Unpack[HashKwargs]
):
    """Import attribute `callable_name` from the Python file (or zip archive
    of a Python package) given by `source`.

    The file content is verified against its SHA-256, materialized in a
    temporary directory, imported under a content-derived unique module name,
    and cached in `sys.modules` so repeated imports of identical content
    reuse the same module object.

    Raises:
        ImportError: if importing/executing the module fails.
        AttributeError: if the imported module has no `callable_name` attribute.
    """
    src_descr = FileDescr(source=source, **kwargs)
    # ensure sha is valid even if perform_io_checks=False
    # or the source has changed since last sha computation
    src_descr.validate_sha256(force_recompute=True)
    assert src_descr.sha256 is not None
    source_sha = src_descr.sha256

    reader = src_descr.get_reader()
    # make sure we have unique module name
    # (file stem + content hash, so different content never collides)
    module_name = f"{reader.original_file_name.split('.')[0]}_{source_sha}"

    # make sure we have a unique and valid module name
    if not module_name.isidentifier():
        module_name = f"custom_module_{source_sha}"
        assert module_name.isidentifier(), module_name

    source_bytes = reader.read()

    # reuse a previously imported module with identical content, if any
    module = sys.modules.get(module_name)
    if module is None:
        try:
            # TemporaryDirectory kwargs depend on the running Python version
            td_kwargs: Dict[str, Any] = (
                dict(ignore_cleanup_errors=True) if sys.version_info >= (3, 10) else {}
            )
            if sys.version_info >= (3, 12):
                td_kwargs["delete"] = False

            tmp_dir = TemporaryDirectory(**td_kwargs)
            # keep global ref to tmp_dir to delay cleanup until program exit
            # TODO: remove for py >= 3.12, when delete=False works
            tmp_dirs_in_use.append(tmp_dir)

            module_path = Path(tmp_dir.name) / module_name
            if reader.original_file_name.endswith(".zip") or is_zipfile(reader):
                # zipped python package: extract into a directory named like the module
                module_path.mkdir()
                ZipFile(reader).extractall(path=module_path)
            else:
                # single .py file: write the (already read) bytes next to nothing else
                module_path = module_path.with_suffix(".py")
                _ = module_path.write_bytes(source_bytes)

            importlib_spec = importlib.util.spec_from_file_location(
                module_name, str(module_path)
            )

            if importlib_spec is None:
                raise ImportError(f"Failed to import {source}")

            module = importlib.util.module_from_spec(importlib_spec)

            sys.modules[module_name] = module  # cache this module

            assert importlib_spec.loader is not None
            importlib_spec.loader.exec_module(module)

        except Exception as e:
            # drop the half-initialized module from the cache before re-raising
            if module_name in sys.modules:
                del sys.modules[module_name]

            raise ImportError(f"Failed to import {source}") from e

    try:
        callable_attr = getattr(module, callable_name)
    except AttributeError as e:
        raise AttributeError(
            f"Imported custom module from {source} has no `{callable_name}` attribute."
        ) from e
    except Exception as e:
        raise AttributeError(
            f"Failed to access `{callable_name}` attribute from custom module imported from {source} ."
        ) from e

    else:
        return callable_attr

167 

168 

def get_axes_infos(
    io_descr: Union[
        v0_4.InputTensorDescr,
        v0_4.OutputTensorDescr,
        v0_5.InputTensorDescr,
        v0_5.OutputTensorDescr,
    ],
) -> List[AxisInfo]:
    """get a unified, simplified axis representation from spec axes"""

    def as_axis_info(axis: Any) -> AxisInfo:
        # v0.5 axes are rich objects; reduce them to id + type
        if isinstance(axis, v0_5.AxisBase):
            return AxisInfo.create(Axis(id=axis.id, type=axis.type))
        # v0.4 axes are single characters from a fixed alphabet
        assert axis in ("b", "i", "t", "c", "z", "y", "x")
        return AxisInfo.create(axis)

    return [as_axis_info(a) for a in io_descr.axes]

187 

188 

def get_member_id(
    tensor_description: Union[
        v0_4.InputTensorDescr,
        v0_4.OutputTensorDescr,
        v0_5.InputTensorDescr,
        v0_5.OutputTensorDescr,
    ],
) -> MemberId:
    """get the normalized tensor ID, usable as a sample member ID"""

    if isinstance(
        tensor_description, (v0_5.InputTensorDescr, v0_5.OutputTensorDescr)
    ):
        # v0.5 tensors already carry a proper id
        return tensor_description.id

    if isinstance(tensor_description, (v0_4.InputTensorDescr, v0_4.OutputTensorDescr)):
        # v0.4 tensors only have a free-form name; wrap it
        return MemberId(tensor_description.name)

    assert_never(tensor_description)

207 

208 

def get_member_ids(
    tensor_descriptions: Sequence[
        Union[
            v0_4.InputTensorDescr,
            v0_4.OutputTensorDescr,
            v0_5.InputTensorDescr,
            v0_5.OutputTensorDescr,
        ]
    ],
) -> List[MemberId]:
    """get normalized tensor IDs to be used as sample member IDs"""
    return list(map(get_member_id, tensor_descriptions))

221 

222 

def get_test_input_sample(model: AnyModelDescr) -> Sample:
    """returns a model's test input sample"""
    if isinstance(model, v0_4.ModelDescr):
        # v0.4 keeps test arrays separate from the input descriptions
        sources = model.test_inputs
    else:
        # v0.5 input descriptions embed their own test tensors
        sources = model.inputs

    return _get_test_sample(model.inputs, sources)

228 

229 

# Backwards-compatible alias kept for callers of the old name.
get_test_inputs = get_test_input_sample
"""DEPRECATED: use `get_test_input_sample` instead"""

232 

233 

def get_test_output_sample(model: AnyModelDescr) -> Sample:
    """returns a model's test output sample"""
    if isinstance(model, v0_4.ModelDescr):
        # v0.4 keeps test arrays separate from the output descriptions
        sources = model.test_outputs
    else:
        # v0.5 output descriptions embed their own test tensors
        sources = model.outputs

    return _get_test_sample(model.outputs, sources)

240 

241 

# Backwards-compatible alias kept for callers of the old name.
get_test_outputs = get_test_output_sample
"""DEPRECATED: use `get_test_output_sample` instead"""

244 

245 

def _get_test_sample(
    tensor_descrs: Sequence[
        Union[
            v0_4.InputTensorDescr,
            v0_4.OutputTensorDescr,
            v0_5.InputTensorDescr,
            v0_5.OutputTensorDescr,
        ]
    ],
    test_sources: Sequence[Union[FileSource, v0_5.TensorDescr]],
) -> Sample:
    """returns a model's input/output test sample"""

    def load_test_array(src: Union[FileSource, v0_5.TensorDescr]) -> NDArray[Any]:
        # v0.5 descriptors wrap the test tensor file and may lack one entirely
        if isinstance(src, (v0_5.InputTensorDescr, v0_5.OutputTensorDescr)):
            if src.test_tensor is None:
                raise ValueError(
                    f"Model input '{src.id}' has no test tensor defined, cannot create test sample."
                )
            return load_array(src.test_tensor)

        # otherwise `src` is a file source pointing directly at an array
        return load_array(src)

    member_ids = get_member_ids(tensor_descrs)
    arrays = [load_test_array(src) for src in test_sources]
    axes_per_tensor = [get_axes_infos(descr) for descr in tensor_descrs]

    members = {
        member: Tensor.from_numpy(array, dims=axis_infos)
        for member, array, axis_infos in zip(member_ids, arrays, axes_per_tensor)
    }
    return Sample(members=members, stat={}, id="test-sample")

279 

280 

class IO_SampleBlockMeta(NamedTuple):
    """pairs the meta data of an input sample block with that of the
    corresponding output sample block"""

    input: SampleBlockMeta  # meta data of the input block
    output: SampleBlockMeta  # meta data of the corresponding output block

284 

285 

def get_input_halo(model: v0_5.ModelDescr, output_halo: PerMember[PerAxis[Halo]]):
    """returns which halo input tensors need to be divided into blocks with, such that
    `output_halo` can be cropped from their outputs without introducing gaps.

    Args:
        model: a v0.5 model description
        output_halo: per output tensor, per axis: the halo to be cropped

    Raises:
        ValueError: if an output axis with a halo does not reference an input
            axis via a `SizeReference`.
    """
    input_halo: Dict[MemberId, Dict[AxisId, Halo]] = {}
    outputs = {t.id: t for t in model.outputs}
    # inputs and outputs by id; a SizeReference may point at either
    all_tensors = {**{t.id: t for t in model.inputs}, **outputs}

    for t, th in output_halo.items():
        axes = {a.id: a for a in outputs[t].axes}

        for a, ah in th.items():
            s = axes[a].size
            # only a SizeReference ties an output axis to a reference axis
            if not isinstance(s, v0_5.SizeReference):
                raise ValueError(
                    f"Unable to map output halo for {t}.{a} to an input axis"
                )

            axis = axes[a]
            ref_axis = {a.id: a for a in all_tensors[s.tensor_id].axes}[s.axis_id]

            # convert the halo from output-axis units to reference-axis units
            # via the ratio of the two axes' scales
            input_halo_left = ah.left * axis.scale / ref_axis.scale
            input_halo_right = ah.right * axis.scale / ref_axis.scale
            # rescaled halos must come out as whole pixels
            assert input_halo_left == int(input_halo_left), f"{input_halo_left} not int"
            assert input_halo_right == int(input_halo_right), (
                f"{input_halo_right} not int"
            )

            input_halo.setdefault(s.tensor_id, {})[a] = Halo(
                int(input_halo_left), int(input_halo_right)
            )

    return input_halo

318 

319 

def get_block_transform(
    model: v0_5.ModelDescr,
) -> PerMember[PerAxis[Union[LinearSampleAxisTransform, int]]]:
    """returns how a model's output tensor shapes relates to its input shapes

    Per output tensor and axis the result is either a fixed integer size, -1
    for a data dependent size, or a linear transform of an input axis.
    """
    ret: Dict[MemberId, Dict[AxisId, Union[LinearSampleAxisTransform, int]]] = {}
    # find a batch axis in any input; output batch axes map to it 1:1
    batch_axis_trf = None
    for ipt in model.inputs:
        for a in ipt.axes:
            if a.type == "batch":
                batch_axis_trf = LinearSampleAxisTransform(
                    axis=a.id, scale=1, offset=0, member=ipt.id
                )
                break
        if batch_axis_trf is not None:
            break
    # scale of every axis of every tensor, keyed by tensor id, then axis id
    axis_scales = {
        t.id: {a.id: a.scale for a in t.axes}
        for t in chain(model.inputs, model.outputs)
    }
    for out in model.outputs:
        new_axes: Dict[AxisId, Union[LinearSampleAxisTransform, int]] = {}
        for a in out.axes:
            if a.size is None:
                # only batch axes have no size; reuse the input batch transform
                assert a.type == "batch"
                if batch_axis_trf is None:
                    raise ValueError(
                        "no batch axis found in any input tensor, but output tensor"
                        + f" '{out.id}' has one."
                    )
                s = batch_axis_trf
            elif isinstance(a.size, int):
                # fixed output axis size, independent of the input
                s = a.size
            elif isinstance(a.size, v0_5.DataDependentSize):
                # size only known after inference; -1 serves as sentinel
                s = -1
            elif isinstance(a.size, v0_5.SizeReference):
                # output size is a linear function of the referenced axis;
                # the scale ratio converts between the two axes' units
                s = LinearSampleAxisTransform(
                    axis=a.size.axis_id,
                    scale=axis_scales[a.size.tensor_id][a.size.axis_id] / a.scale,
                    offset=a.size.offset,
                    member=a.size.tensor_id,
                )
            else:
                assert_never(a.size)

            new_axes[a.id] = s

        ret[out.id] = new_axes

    return ret

369 

370 

def get_io_sample_block_metas(
    model: v0_5.ModelDescr,
    input_sample_shape: PerMember[PerAxis[int]],
    ns: Mapping[Tuple[MemberId, AxisId], ParameterizedSize_N],
    batch_size: int = 1,
) -> Tuple[TotalNumberOfBlocks, Iterable[IO_SampleBlockMeta]]:
    """returns an iterable yielding meta data for corresponding input and output samples

    Args:
        model: a v0.5 model description
        input_sample_shape: per input tensor, per axis: the full sample size
        ns: per (tensor id, axis id): the parameterized size factor `n`
        batch_size: size of the batch axis of each block

    Raises:
        TypeError: if `model` is not a v0.5 model description
    """
    # runtime guard for callers ignoring the type annotation
    if not isinstance(model, v0_5.ModelDescr):
        raise TypeError(f"get_block_meta() not implemented for {type(model)}")

    # resolve parameterized axis sizes to concrete per-axis block sizes
    block_axis_sizes = model.get_axis_sizes(ns=ns, batch_size=batch_size)
    # regroup flat (tensor, axis) -> size mapping by tensor id
    input_block_shape = {
        t: {aa: s for (tt, aa), s in block_axis_sizes.inputs.items() if tt == t}
        for t in {tt for tt, _ in block_axis_sizes.inputs}
    }
    # halo of each output axis (only axes that declare one)
    output_halo = {
        t.id: {
            a.id: Halo(a.halo, a.halo) for a in t.axes if isinstance(a, v0_5.WithHalo)
        }
        for t in model.outputs
    }
    # map the output halo back onto the input tensors
    input_halo = get_input_halo(model, output_halo)

    n_input_blocks, input_blocks = split_multiple_shapes_into_blocks(
        input_sample_shape, input_block_shape, halo=input_halo
    )
    block_transform = get_block_transform(model)
    # lazily pair every input block meta with its transformed output block meta
    return n_input_blocks, (
        IO_SampleBlockMeta(ipt, ipt.get_transformed(block_transform))
        for ipt in sample_block_meta_generator(
            input_blocks, sample_shape=input_sample_shape, sample_id=None
        )
    )

404 

405 

def get_tensor(
    src: TensorSource,
    descr: Union[
        v0_4.InputTensorDescr,
        v0_5.InputTensorDescr,
        v0_4.OutputTensorDescr,
        v0_5.OutputTensorDescr,
        Sequence[AxisInfo],
    ],
):
    """helper to cast/load various tensor sources"""

    # resolve the target axes: either derive them from a tensor description
    # or take the explicitly given sequence of axis infos
    tensor_descr_types = (
        v0_4.InputTensorDescr,
        v0_5.InputTensorDescr,
        v0_4.OutputTensorDescr,
        v0_5.OutputTensorDescr,
    )
    axes = get_axes_infos(descr) if isinstance(descr, tensor_descr_types) else descr

    if isinstance(src, Tensor):
        return src.transpose(axes=[a.id for a in axes])

    if isinstance(src, xr.DataArray):
        return Tensor.from_xarray(src).transpose(axes=[a.id for a in axes])

    if isinstance(src, np.ndarray):
        return Tensor.from_numpy(src, dims=axes)

    # anything else is treated as a file source and loaded from disk
    return load_tensor(src, axes=axes)

439 

440 

def create_sample_for_model(
    model: AnyModelDescr,
    *,
    stat: Optional[Stat] = None,
    sample_id: SampleId = None,
    inputs: Union[PerMember[TensorSource], TensorSource],
) -> Sample:
    """Create a sample from a single set of input(s) for a specific bioimage.io model

    Args:
        model: a bioimage.io model description
        stat: dictionary with sample and dataset statistics (may be updated in-place!)
        inputs: the input(s) constituting a single sample.
    """

    descr_by_member = {get_member_id(d): d for d in model.inputs}
    if isinstance(inputs, collections.abc.Mapping):
        # normalize keys to MemberId (callers may pass plain strings)
        inputs = {MemberId(k): v for k, v in inputs.items()}
    elif len(descr_by_member) == 1:
        # a single bare tensor source is accepted for single-input models
        inputs = {next(iter(descr_by_member)): inputs}
    else:
        raise TypeError(
            f"Expected `inputs` to be a mapping with keys {tuple(descr_by_member)}"
        )

    unexpected = {member for member in inputs if member not in descr_by_member}
    if unexpected:
        raise ValueError(f"Got unexpected inputs: {unexpected}")

    # every non-optional model input must be provided
    absent = {
        member
        for member, descr in descr_by_member.items()
        if member not in inputs
        and not (isinstance(descr, v0_5.InputTensorDescr) and descr.optional)
    }
    if absent:
        raise ValueError(f"Missing non-optional model inputs: {absent}")

    return Sample(
        members={
            member: get_tensor(inputs[member], descr)
            for member, descr in descr_by_member.items()
            if member in inputs
        },
        stat=stat if stat is not None else {},
        id=sample_id,
    )

485 

486 

def load_sample_for_model(
    *,
    model: AnyModelDescr,
    paths: PerMember[Path],
    axes: Optional[PerMember[Sequence[AxisLike]]] = None,
    stat: Optional[Stat] = None,
    sample_id: Optional[SampleId] = None,
):
    """load a single sample from `paths` that can be processed by `model`"""

    # normalize keys to MemberId (callers may pass plain strings)
    member_paths = {MemberId(k): v for k, v in paths.items()}
    axis_hints = (
        {} if axes is None else {MemberId(k): v for k, v in axes.items()}
    )

    model_inputs = {get_member_id(d): d for d in model.inputs}

    unexpected_paths = {m for m in member_paths if m not in model_inputs}
    if unexpected_paths:
        raise ValueError(f"Got unexpected paths for {unexpected_paths}")

    unexpected_axes = {m for m in axis_hints if m not in model_inputs}
    if unexpected_axes:
        raise ValueError(f"Got unexpected axes hints for: {unexpected_axes}")

    members: Dict[MemberId, Tensor] = {}
    for member, path in member_paths.items():
        if member not in axis_hints:
            # no hint given: fall back to the axes declared by the model
            axis_hints[member] = get_axes_infos(model_inputs[member])
            logger.debug(
                "loading '{}' from {} with default input axes {} ",
                member,
                path,
                axis_hints[member],
            )
        members[member] = load_tensor(path, axis_hints[member])

    return Sample(
        members=members,
        stat=stat if stat is not None else {},
        id=sample_id or tuple(sorted(member_paths.values())),
    )