Coverage for src/bioimageio/core/digest_spec.py: 83%

239 statements  

« prev     ^ index     » next       coverage.py v7.14.2, created at 2026-06-22 16:54 +0000

1from __future__ import annotations 

2 

3import collections.abc 

4import importlib.util 

5import sys 

6from itertools import chain 

7from pathlib import Path 

8from tempfile import TemporaryDirectory 

9from typing import ( 

10 Any, 

11 Callable, 

12 Dict, 

13 Iterable, 

14 List, 

15 Mapping, 

16 NamedTuple, 

17 Optional, 

18 Sequence, 

19 Tuple, 

20 Union, 

21) 

22from zipfile import ZipFile, is_zipfile 

23 

24import numpy as np 

25import xarray as xr 

26from loguru import logger 

27from numpy.typing import NDArray 

28from typing_extensions import TypeAlias, Unpack, assert_never 

29 

30from bioimageio.spec._internal.io import HashKwargs, PermissiveFileSource 

31from bioimageio.spec.common import FileDescr, FileSource 

32from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5 

33from bioimageio.spec.model.v0_4 import CallableFromDepencency, CallableFromFile 

34from bioimageio.spec.model.v0_5 import ( 

35 ArchitectureFromFileDescr, 

36 ArchitectureFromLibraryDescr, 

37 ParameterizedSize_N, 

38) 

39from bioimageio.spec.utils import load_array 

40 

41from .axis import AxisId, AxisInfo, AxisLike, PerAxis 

42from .block_meta import split_multiple_shapes_into_blocks 

43from .common import Halo, MemberId, PerMember, SampleId, TotalNumberOfBlocks 

44from .io import load_tensor 

45from .sample import ( 

46 LinearSampleAxisTransform, 

47 Sample, 

48 SampleBlockMeta, 

49 SampleBlockWithOrigin, 

50 sample_block_meta_generator, 

51) 

52from .stat_measures import Stat 

53from .tensor import Tensor 

54 

# Anything `get_tensor` accepts as tensor data: an in-memory tensor/array
# (`Tensor`, `xr.DataArray`, numpy array) or a file source to load it from.
TensorSource: TypeAlias = Union[
    Tensor, xr.DataArray, NDArray[Any], PermissiveFileSource
]

58 

59 

def import_callable(
    node: Union[
        ArchitectureFromFileDescr,
        ArchitectureFromLibraryDescr,
        CallableFromDepencency,
        CallableFromFile,
        v0_5.CustomProcessingDescr,
    ],
    /,
    **kwargs: Unpack[HashKwargs],
) -> Callable[..., Any]:
    """import a callable (e.g. a torch.nn.Module) from a spec node describing it

    Args:
        node: description of where the callable lives; either an importable
            module (`CallableFromDepencency`, `ArchitectureFromLibraryDescr`)
            or a source file (`CallableFromFile`, `ArchitectureFromFileDescr`,
            `v0_5.CustomProcessingDescr`).
        **kwargs: hash keyword arguments; only forwarded for `CallableFromFile`
            nodes (the v0_5 file based nodes pass their own `sha256` instead).

    Returns:
        The imported object.

    Raises:
        ValueError: if the imported object is not callable.
    """
    if isinstance(node, CallableFromDepencency):
        # v0_4: callable from an installed dependency
        imported = getattr(
            importlib.import_module(node.module_name), str(node.callable_name)
        )
    elif isinstance(node, ArchitectureFromLibraryDescr):
        # v0_5: architecture from an installed library
        imported = getattr(
            importlib.import_module(node.import_from), str(node.callable)
        )
    elif isinstance(node, CallableFromFile):
        # v0_4: callable defined in a source file
        imported = _import_from_file_impl(
            node.source_file, str(node.callable_name), **kwargs
        )
    elif isinstance(node, (ArchitectureFromFileDescr, v0_5.CustomProcessingDescr)):
        # v0_5: callable defined in a source file with its own sha256
        imported = _import_from_file_impl(
            node.source, str(node.callable), sha256=node.sha256
        )
    else:
        assert_never(node)

    if callable(imported):
        return imported

    raise ValueError(f"{node} (imported: {imported}) is not callable")

89 

90 

tmp_dirs_in_use: List[TemporaryDirectory[str]] = []
"""keep global reference to temporary directories created during import to delay cleanup"""


def _import_from_file_impl(
    source: FileSource, callable_name: str, **kwargs: Unpack[HashKwargs]
):
    """import the attribute `callable_name` from the custom module at `source`

    The source is written (or, for zip archives, extracted) into a temporary
    directory and imported under a module name derived from the file name and
    the source's SHA-256. Imported modules are cached in `sys.modules`, so the
    same source content is only executed once.

    Args:
        source: file source of the module to import.
        callable_name: name of the attribute to fetch from the imported module.
        **kwargs: hash keyword arguments forwarded to `FileDescr`.

    Returns:
        The attribute `callable_name` of the imported module.

    Raises:
        ImportError: if creating or executing the module fails.
        AttributeError: if the attribute cannot be accessed on the module.
    """
    file_descr = FileDescr(source=source, **kwargs)
    # ensure sha is valid even if perform_io_checks=False
    # or the source has changed since last sha computation
    file_descr.validate_sha256(force_recompute=True)
    assert file_descr.sha256 is not None
    source_sha = file_descr.sha256

    reader = file_descr.get_reader()
    file_stem = reader.original_file_name.split(".")[0]

    # the sha suffix makes the module name unique per source content;
    # fall back to a generic prefix if the file stem is no valid identifier
    module_name = f"{file_stem}_{source_sha}"
    if not module_name.isidentifier():
        module_name = f"custom_module_{source_sha}"
    assert module_name.isidentifier(), module_name

    source_bytes = reader.read()

    module = sys.modules.get(module_name)
    if module is None:
        try:
            td_kwargs: Dict[str, Any] = {}
            if sys.version_info >= (3, 10):
                td_kwargs["ignore_cleanup_errors"] = True
            if sys.version_info >= (3, 12):
                td_kwargs["delete"] = False

            tmp_dir = TemporaryDirectory(**td_kwargs)
            # keep global ref to tmp_dir to delay cleanup until program exit
            # TODO: remove for py >= 3.12, when delete=False works
            tmp_dirs_in_use.append(tmp_dir)

            module_path = Path(tmp_dir.name) / module_name
            if reader.original_file_name.endswith(".zip") or is_zipfile(reader):
                # NOTE(review): the spec below is created from the extraction
                # directory itself -- confirm this yields an importable module
                module_path.mkdir()
                ZipFile(reader).extractall(path=module_path)
            else:
                module_path = module_path.with_suffix(".py")
                _ = module_path.write_bytes(source_bytes)

            importlib_spec = importlib.util.spec_from_file_location(
                module_name, str(module_path)
            )
            if importlib_spec is None:
                raise ImportError(f"Failed to import {source}")

            module = importlib.util.module_from_spec(importlib_spec)
            # cache this module (before executing it)
            sys.modules[module_name] = module

            assert importlib_spec.loader is not None
            importlib_spec.loader.exec_module(module)
        except Exception as e:
            # do not leave a partially initialized module in the cache
            _ = sys.modules.pop(module_name, None)
            raise ImportError(f"Failed to import {source}") from e

    try:
        return getattr(module, callable_name)
    except AttributeError as e:
        raise AttributeError(
            f"Imported custom module from {source} has no `{callable_name}` attribute."
        ) from e
    except Exception as e:
        raise AttributeError(
            f"Failed to access `{callable_name}` attribute from custom module imported from {source} ."
        ) from e

171 

172 

def get_axes_infos(
    io_descr: Union[
        v0_4.InputTensorDescr,
        v0_4.OutputTensorDescr,
        v0_5.InputTensorDescr,
        v0_5.OutputTensorDescr,
    ],
) -> List[AxisInfo]:
    """get a unified, simplified axis representation from spec axes

    Args:
        io_descr: an input or output tensor description (format 0.4 or 0.5).

    Returns:
        One `AxisInfo` per entry of `io_descr.axes`, in the same order.
    """
    return list(map(AxisInfo.create, io_descr.axes))

183 

184 

def get_member_id(
    tensor_description: Union[
        v0_4.InputTensorDescr,
        v0_4.OutputTensorDescr,
        v0_5.InputTensorDescr,
        v0_5.OutputTensorDescr,
    ],
) -> MemberId:
    """get the normalized tensor ID, usable as a sample member ID

    Format 0.4 tensors are identified by their `name`, format 0.5 tensors by
    their `id`.
    """
    v0_4_descrs = (v0_4.InputTensorDescr, v0_4.OutputTensorDescr)
    if isinstance(tensor_description, v0_4_descrs):
        return MemberId(tensor_description.name)

    v0_5_descrs = (v0_5.InputTensorDescr, v0_5.OutputTensorDescr)
    if isinstance(tensor_description, v0_5_descrs):
        return tensor_description.id

    assert_never(tensor_description)

203 

204 

def get_member_ids(
    tensor_descriptions: Iterable[
        Union[
            v0_4.InputTensorDescr,
            v0_4.OutputTensorDescr,
            v0_5.InputTensorDescr,
            v0_5.OutputTensorDescr,
        ]
    ],
) -> List[MemberId]:
    """get normalized tensor IDs to be used as sample member IDs

    Returns:
        The result of `get_member_id` for each given description, in order.
    """
    return list(map(get_member_id, tensor_descriptions))

217 

218 

def get_test_input_sample(model: AnyModelDescr) -> Sample:
    """returns a model's test input sample"""
    if not isinstance(model, v0_4.ModelDescr):
        # format 0.5: the test tensor is part of each input description
        return _get_test_sample({ipt.id: ipt for ipt in model.inputs})

    # format 0.4: test inputs are listed separately from the input descriptions
    return _get_test_sample(
        {
            MemberId(ipt.name): (ipt, test_ipt)
            for ipt, test_ipt in zip(model.inputs, model.test_inputs)
        }
    )


get_test_inputs = get_test_input_sample
"""DEPRECATED: use `get_test_input_sample` instead"""

232 

233 

def get_test_output_sample(model: AnyModelDescr) -> Sample:
    """returns a model's test output sample"""
    if not isinstance(model, v0_4.ModelDescr):
        # format 0.5: the test tensor is part of each output description
        return _get_test_sample({out.id: out for out in model.outputs})

    # format 0.4: test outputs are listed separately from the output descriptions
    return _get_test_sample(
        {
            MemberId(out.name): (out, test_out)
            for out, test_out in zip(model.outputs, model.test_outputs)
        }
    )


get_test_outputs = get_test_output_sample
"""DEPRECATED: use `get_test_output_sample` instead"""

248 

249 

def _get_test_sample(
    info: Union[
        Mapping[MemberId, Union[v0_5.InputTensorDescr, v0_5.OutputTensorDescr]],
        Mapping[
            MemberId,
            Tuple[
                v0_4.InputTensorDescr,
                FileSource,
            ],
        ],
        Mapping[
            MemberId,
            Tuple[
                v0_4.OutputTensorDescr,
                FileSource,
            ],
        ],
    ],
) -> Sample:
    """create the test sample from the given tensor descriptions

    Args:
        info: per member either a format 0.5 tensor description (which carries
            its own `test_tensor`) or a `(format 0.4 description, test tensor
            source)` tuple. Used for model inputs as well as model outputs.

    Returns:
        A sample with id "test-sample", empty `stat` and one tensor per member.

    Raises:
        ValueError: if a format 0.5 tensor description has no test tensor.
    """
    members: Dict[MemberId, Tensor] = {}
    for m, src in info.items():
        if isinstance(src, tuple):
            descr, test_source = src
            array = load_array(test_source)
        elif isinstance(src, (v0_5.InputTensorDescr, v0_5.OutputTensorDescr)):
            if src.test_tensor is None:
                # worded neutrally: this helper serves inputs and outputs alike
                raise ValueError(
                    f"Model tensor '{m}' has no test tensor defined, cannot create test sample."
                )
            descr = src
            array = load_array(src.test_tensor)
        else:
            assert_never(src)

        members[m] = Tensor.from_numpy(array, dims=get_axes_infos(descr))

    return Sample(members=members, stat={}, id="test-sample")

290 

291 

class IO_SampleBlockMeta(NamedTuple):
    """meta data of an input sample block paired with its output sample block"""

    # meta data of the input sample block
    input: SampleBlockMeta
    # meta data of the corresponding output sample block; in
    # `get_io_sample_block_metas` this is `input.get_transformed(...)`
    output: SampleBlockMeta

295 

296 

def get_input_halo(
    model: v0_5.ModelDescr, output_halo: Optional[PerMember[PerAxis[Halo]]] = None
):
    """returns which halo input tensors need to be divided into blocks with, such that
    `output_halo` can be cropped from their outputs without introducing gaps.

    Args:
        model: model description (format 0.5).
        output_halo: halo per output tensor and axis. Defaults to the symmetric
            halo of every output axis that defines one (`v0_5.WithHalo`).

    Returns:
        Halo per referenced tensor and referenced axis.

    Raises:
        ValueError: if the size of an output axis with halo is not described by
            a `v0_5.SizeReference` and thus cannot be mapped to another axis.
    """
    input_halo: Dict[MemberId, Dict[AxisId, Halo]] = {}
    outputs = {t.id: t for t in model.outputs}
    all_tensors = {**{t.id: t for t in model.inputs}, **outputs}
    if output_halo is None:
        output_halo = {
            t.id: {
                a.id: Halo(a.halo, a.halo)
                for a in t.axes
                if isinstance(a, v0_5.WithHalo)
            }
            for t in model.outputs
        }

    for t, th in output_halo.items():
        axes = {a.id: a for a in outputs[t].axes}

        for a, ah in th.items():
            axis = axes[a]
            s = axis.size
            if not isinstance(s, v0_5.SizeReference):
                raise ValueError(
                    f"Unable to map output halo for {t}.{a} to an input axis"
                )

            ref_axes = {ra.id: ra for ra in all_tensors[s.tensor_id].axes}
            ref_axis = ref_axes[s.axis_id]

            # rescale the halo from the output axis to the referenced axis
            input_halo_left = ah.left * axis.scale / ref_axis.scale
            input_halo_right = ah.right * axis.scale / ref_axis.scale
            assert input_halo_left == int(input_halo_left), f"{input_halo_left} not int"
            assert input_halo_right == int(input_halo_right), (
                f"{input_halo_right} not int"
            )

            # key the halo by the *referenced* axis: the output axis id may
            # differ from the id of the axis its size refers to
            input_halo.setdefault(s.tensor_id, {})[s.axis_id] = Halo(
                int(input_halo_left), int(input_halo_right)
            )

    return input_halo

340 

341 

def get_block_transform(
    model: v0_5.ModelDescr,
) -> PerMember[PerAxis[Union[LinearSampleAxisTransform, int]]]:
    """returns how a model's output tensor shapes relates to its input shapes

    Returns:
        For every output tensor and each of its axes either
        - a `LinearSampleAxisTransform` relative to another tensor axis
          (batch axes and axes sized by a `v0_5.SizeReference`),
        - the fixed integer size of the axis, or
        - `-1` for axes with a data dependent size.

    Raises:
        ValueError: if an output has a batch axis, but no input tensor has one.
    """
    # the first batch axis found in any input serves as reference for all
    # output batch axes
    batch_axis_trf = next(
        (
            LinearSampleAxisTransform(axis=a.id, scale=1, offset=0, member=ipt.id)
            for ipt in model.inputs
            for a in ipt.axes
            if a.type == "batch"
        ),
        None,
    )
    axis_scales = {
        t.id: {a.id: a.scale for a in t.axes}
        for t in chain(model.inputs, model.outputs)
    }
    channel = AxisId("channel")

    ret: Dict[MemberId, Dict[AxisId, Union[LinearSampleAxisTransform, int]]] = {}
    for out in model.outputs:
        new_axes: Dict[AxisId, Union[LinearSampleAxisTransform, int]] = {}
        for a in out.axes:
            if a.size is None:
                assert a.type == "batch"
                if batch_axis_trf is None:
                    raise ValueError(
                        "no batch axis found in any input tensor, but output tensor"
                        + f" '{out.id}' has one."
                    )
                s = batch_axis_trf
            elif isinstance(a.size, int):
                s = a.size
            elif isinstance(a.size, v0_5.DataDependentSize):
                s = -1
            elif isinstance(a.size, v0_5.SizeReference):
                s = LinearSampleAxisTransform(
                    axis=a.size.axis_id,
                    scale=axis_scales[a.size.tensor_id][a.size.axis_id] / a.scale,
                    offset=a.size.offset,
                    member=a.size.tensor_id,
                )
            else:
                assert_never(a.size)

            new_axes[a.id] = s

        # account for postprocessing that changes the number of output channels by
        # overwriting described output shape by the intermediate output shape
        if channel in new_axes:
            for post in out.postprocessing:
                if post.id == "cellpose_flow_dynamics":
                    new_axes[channel] = 3
                    break
                elif post.id == "stardist_postprocessing":
                    new_axes[channel] = post.kwargs.n_rays + 1
                    break

        # always register the output, also when it has no channel axis
        ret[out.id] = new_axes

    return ret

404 

405 

def get_io_sample_block_metas(
    model: v0_5.ModelDescr,
    input_sample_shape: PerMember[PerAxis[int]],
    ns: Mapping[Tuple[MemberId, AxisId], ParameterizedSize_N],
    batch_size: int = 1,
) -> Tuple[TotalNumberOfBlocks, Iterable[IO_SampleBlockMeta]]:
    """returns an iterable yielding meta data for corresponding input and output samples

    Args:
        model: model description (format 0.5).
        input_sample_shape: shape of the input sample per member and axis.
        ns: size parameter `n` per (tensor, axis), forwarded to
            `model.get_axis_sizes`.
        batch_size: batch size, forwarded to `model.get_axis_sizes`.

    Returns:
        The total number of blocks and a generator of `IO_SampleBlockMeta`.

    Raises:
        TypeError: if `model` is not a `v0_5.ModelDescr`.
    """
    if not isinstance(model, v0_5.ModelDescr):
        raise TypeError(
            f"get_io_sample_block_metas() not implemented for {type(model)}"
        )

    block_axis_sizes = model.get_axis_sizes(ns=ns, batch_size=batch_size)

    # regroup the flat {(tensor, axis): size} mapping per tensor in one pass
    input_block_shape: Dict[MemberId, Dict[AxisId, int]] = {}
    for (tensor_id, axis_id), size in block_axis_sizes.inputs.items():
        input_block_shape.setdefault(tensor_id, {})[axis_id] = size

    # without an explicit output halo `get_input_halo` derives it from the
    # halo of the model's output axes
    input_halo = get_input_halo(model)

    n_input_blocks, input_blocks = split_multiple_shapes_into_blocks(
        input_sample_shape, input_block_shape, halo=input_halo
    )
    block_transform = get_block_transform(model)
    return n_input_blocks, (
        IO_SampleBlockMeta(ipt, ipt.get_transformed(block_transform))
        for ipt in sample_block_meta_generator(
            input_blocks, sample_shape=input_sample_shape, sample_id=None
        )
    )

439 

440 

def get_tensor(
    src: TensorSource,
    descr: Union[
        v0_4.InputTensorDescr,
        v0_5.InputTensorDescr,
        v0_4.OutputTensorDescr,
        v0_5.OutputTensorDescr,
        Sequence[AxisInfo],
    ],
):
    """helper to cast/load various tensor sources

    Args:
        src: tensor data; a `Tensor`, an `xr.DataArray`, a numpy array or
            anything else, which is passed on to `load_tensor`.
        descr: tensor description to take the axes from, or the axes themselves.

    Returns:
        The result of transposing (`Tensor`, `xr.DataArray`), wrapping (numpy
        array) or loading (anything else) `src` with the given axes.
    """
    tensor_descr_types = (
        v0_4.InputTensorDescr,
        v0_5.InputTensorDescr,
        v0_4.OutputTensorDescr,
        v0_5.OutputTensorDescr,
    )
    axes = get_axes_infos(descr) if isinstance(descr, tensor_descr_types) else descr

    if isinstance(src, Tensor):
        return src.transpose(axes=[a.id for a in axes])

    if isinstance(src, xr.DataArray):
        return Tensor.from_xarray(src).transpose(axes=[a.id for a in axes])

    if isinstance(src, np.ndarray):
        return Tensor.from_numpy(src, dims=axes)

    return load_tensor(src, axes=axes)

474 

475 

def create_sample_for_model(
    model: AnyModelDescr,
    *,
    stat: Optional[Stat] = None,
    sample_id: SampleId = None,
    inputs: Union[PerMember[TensorSource], TensorSource],
) -> Sample:
    """Create a sample from a single set of input(s) for a specific bioimage.io model

    Args:
        model: a bioimage.io model description
        stat: dictionary with sample and dataset statistics (may be updated in-place!)
        sample_id: the id to assign to the created sample
        inputs: the input(s) constituting a single sample.
            A single (non-mapping) source is accepted if the model has exactly
            one input.

    Raises:
        TypeError: if `inputs` is not a mapping although the model does not
            have exactly one input.
        ValueError: if `inputs` contains unknown members or lacks a
            non-optional model input.
    """
    model_inputs = {get_member_id(d): d for d in model.inputs}

    if isinstance(inputs, collections.abc.Mapping):
        inputs = {MemberId(name): source for name, source in inputs.items()}
    elif len(model_inputs) == 1:
        # a single tensor source for the model's only input
        inputs = {next(iter(model_inputs)): inputs}
    else:
        raise TypeError(
            f"Expected `inputs` to be a mapping with keys {tuple(model_inputs)}"
        )

    unknown = {k for k in inputs if k not in model_inputs}
    if unknown:
        raise ValueError(f"Got unexpected inputs: {unknown}")

    missing = {
        k
        for k, v in model_inputs.items()
        if k not in inputs and not (isinstance(v, v0_5.InputTensorDescr) and v.optional)
    }
    if missing:
        raise ValueError(f"Missing non-optional model inputs: {missing}")

    # only optional inputs may be absent at this point
    members = {
        m: get_tensor(inputs[m], ipt) for m, ipt in model_inputs.items() if m in inputs
    }
    return Sample(
        members=members,
        stat={} if stat is None else stat,
        id=sample_id,
    )

520 

521 

def load_sample_for_model(
    *,
    model: AnyModelDescr,
    paths: PerMember[Path],
    axes: Optional[PerMember[Sequence[AxisLike]]] = None,
    stat: Optional[Stat] = None,
    sample_id: Optional[SampleId] = None,
):
    """load a single sample from `paths` that can be processed by `model`

    Args:
        model: a bioimage.io model description
        paths: path to load each member (model input) from
        axes: optional axes per member; members without an entry are loaded
            with the axes of the corresponding model input description
        stat: sample and dataset statistics to attach to the sample
        sample_id: id of the sample; if `None`, the sorted tuple of `paths`
            values is used

    Raises:
        ValueError: if `paths` or `axes` contain members unknown to `model`.
    """

    if axes is None:
        axes = {}

    # make sure members are keyed by MemberId, not string
    # (this also copies `axes`, so the caller's mapping is never modified)
    paths = {MemberId(k): v for k, v in paths.items()}
    axes = {MemberId(k): v for k, v in axes.items()}

    model_inputs = {get_member_id(d): d for d in model.inputs}

    if unknown := {k for k in paths if k not in model_inputs}:
        raise ValueError(f"Got unexpected paths for {unknown}")

    if unknown := {k for k in axes if k not in model_inputs}:
        raise ValueError(f"Got unexpected axes hints for: {unknown}")

    members: Dict[MemberId, Tensor] = {}
    for m, p in paths.items():
        if m not in axes:
            axes[m] = get_axes_infos(model_inputs[m])
            logger.info(
                "loading '{}' from {} with default input axes {} ",
                m,
                p,
                axes[m],
            )
        members[m] = load_tensor(p, axes[m])

    # compare against None explicitly: falsy ids such as 0 or "" are valid
    if sample_id is None:
        sample_id = tuple(sorted(paths.values()))

    return Sample(
        members=members,
        stat={} if stat is None else stat,
        id=sample_id,
    )

564 

565 

def split_sample_into_blocks_for_model(
    sample: Sample,
    model: v0_5.ModelDescr,
    blocksize_parameter: int,
    batch_size: int = 1,
) -> Tuple[TotalNumberOfBlocks, Iterable[SampleBlockWithOrigin]]:
    """split `sample` into blocks sized according to the input axes of `model`

    Args:
        sample: the sample to split.
        model: model description (format 0.5).
        blocksize_parameter: size parameter `n` used for every input axis with
            a `v0_5.ParameterizedSize`.
        batch_size: batch size, forwarded to `model.get_tensor_sizes`.

    Returns:
        The result of `sample.split_into_blocks`.

    Raises:
        NotImplementedError: if `model` is a `v0_4.ModelDescr`.
    """
    if isinstance(model, v0_4.ModelDescr):
        raise NotImplementedError(
            "`predict_sample_with_blocking` not implemented for v0_4.ModelDescr"
            f" {model.name}."
            " Consider using `predict_sample_with_fixed_blocking` or update the model description to format version 0.5."
        )

    # use the same size parameter for all parameterized input axes
    ns: Dict[Tuple[MemberId, AxisId], ParameterizedSize_N] = {}
    for ipt in model.inputs:
        for a in ipt.axes:
            if isinstance(a.size, v0_5.ParameterizedSize):
                ns[(ipt.id, a.id)] = blocksize_parameter

    halo = get_input_halo(model)
    tensor_sizes = model.get_tensor_sizes(ns, batch_size=batch_size)
    pad_mode = {ipt.id: ipt.pad or "symmetric" for ipt in model.inputs}

    return sample.split_into_blocks(
        block_shapes=tensor_sizes.inputs,
        halo=halo,
        pad_mode=pad_mode,
    )