Coverage for src/bioimageio/core/digest_spec.py: 85%

230 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-18 12:35 +0000

1from __future__ import annotations 

2 

3import collections.abc 

4import importlib.util 

5import sys 

6from itertools import chain 

7from pathlib import Path 

8from tempfile import TemporaryDirectory 

9from typing import ( 

10 Any, 

11 Callable, 

12 Dict, 

13 Iterable, 

14 List, 

15 Mapping, 

16 NamedTuple, 

17 Optional, 

18 Sequence, 

19 Tuple, 

20 Union, 

21) 

22from zipfile import ZipFile, is_zipfile 

23 

24import numpy as np 

25import xarray as xr 

26from loguru import logger 

27from numpy.typing import NDArray 

28from typing_extensions import Unpack, assert_never 

29 

30from bioimageio.spec._internal.io import HashKwargs, PermissiveFileSource 

31from bioimageio.spec.common import FileDescr, FileSource 

32from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5 

33from bioimageio.spec.model.v0_4 import CallableFromDepencency, CallableFromFile 

34from bioimageio.spec.model.v0_5 import ( 

35 ArchitectureFromFileDescr, 

36 ArchitectureFromLibraryDescr, 

37 ParameterizedSize_N, 

38) 

39from bioimageio.spec.utils import load_array 

40 

41from .axis import AxisId, AxisInfo, AxisLike, PerAxis 

42from .block_meta import split_multiple_shapes_into_blocks 

43from .common import Halo, MemberId, PerMember, SampleId, TotalNumberOfBlocks 

44from .io import load_tensor 

45from .sample import ( 

46 LinearSampleAxisTransform, 

47 Sample, 

48 SampleBlockMeta, 

49 sample_block_meta_generator, 

50) 

51from .stat_measures import Stat 

52from .tensor import Tensor 

53 

# Anything `get_tensor` can turn into a `Tensor`: an existing `Tensor`, an
# `xarray.DataArray`, a numpy array, or a (permissive) file source that is
# handed to `load_tensor`.
TensorSource = Union[Tensor, xr.DataArray, NDArray[Any], PermissiveFileSource]

55 

56 

def import_callable(
    node: Union[
        ArchitectureFromFileDescr,
        ArchitectureFromLibraryDescr,
        CallableFromDepencency,
        CallableFromFile,
        v0_5.CustomProcessingDescr,
    ],
    /,
    **kwargs: Unpack[HashKwargs],
) -> Callable[..., Any]:
    """Import the callable (e.g. a torch.nn.Module) that a spec node describes.

    Library-based nodes are resolved with `importlib.import_module` followed by
    `getattr`; file-based nodes are imported from their source file via
    `_import_from_file_impl`.

    Args:
        node: spec node pointing at the callable.
        **kwargs: hash arguments, forwarded only for v0.4 `CallableFromFile`
            nodes (v0.5 file-based nodes supply their own `sha256`).

    Returns:
        The imported object.

    Raises:
        ValueError: if the imported object is not callable.
    """
    if isinstance(node, CallableFromDepencency):
        imported = getattr(
            importlib.import_module(node.module_name), str(node.callable_name)
        )
    elif isinstance(node, ArchitectureFromLibraryDescr):
        imported = getattr(importlib.import_module(node.import_from), str(node.callable))
    elif isinstance(node, CallableFromFile):
        imported = _import_from_file_impl(
            node.source_file, str(node.callable_name), **kwargs
        )
    elif isinstance(node, (ArchitectureFromFileDescr, v0_5.CustomProcessingDescr)):
        imported = _import_from_file_impl(
            node.source, str(node.callable), sha256=node.sha256
        )
    else:
        assert_never(node)

    if callable(imported):
        return imported

    raise ValueError(f"{node} (imported: {imported}) is not callable")

86 

87 

# Registry of the temporary directories that `_import_from_file_impl` writes
# imported source files into. Holding a reference here keeps each
# `TemporaryDirectory` object alive, so its cleanup is delayed while the
# module imported from it may still be in use.
tmp_dirs_in_use: List[TemporaryDirectory[str]] = []
"""keep global reference to temporary directories created during import to delay cleanup"""

90 

91 

def _import_from_file_impl(
    source: FileSource, callable_name: str, **kwargs: Unpack[HashKwargs]
):
    """Import the module stored at `source` and return its `callable_name` attribute.

    The source is either a single Python file or a zip archive (detected by a
    `.zip` file name or by its content). It is written to a temporary directory
    and imported from there. The module is registered in `sys.modules` under a
    name that contains the source's SHA-256 hash, so importing identical
    content again reuses the cached module.

    A zip archive is imported as a package through the `__init__.py` at the
    archive root. A plain directory path has no import loader of its own, so
    `importlib.util.spec_from_file_location` cannot use it directly.

    Args:
        source: location of the Python file or zip archive.
        callable_name: name of the attribute to fetch from the imported module.
        **kwargs: hash arguments forwarded to `FileDescr`.

    Returns:
        The `callable_name` attribute of the imported module. Whether it is
        callable is not checked here.

    Raises:
        ImportError: if the module cannot be created or executed.
        AttributeError: if the module has no `callable_name` attribute or
            accessing it fails.
    """
    src_descr = FileDescr(source=source, **kwargs)
    # ensure sha is valid even if perform_io_checks=False
    # or the source has changed since last sha computation
    src_descr.validate_sha256(force_recompute=True)
    assert src_descr.sha256 is not None
    source_sha = src_descr.sha256

    reader = src_descr.get_reader()
    # make sure we have unique module name (content hash as suffix)
    module_name = f"{reader.original_file_name.split('.')[0]}_{source_sha}"

    # make sure we have a unique and valid module name
    if not module_name.isidentifier():
        module_name = f"custom_module_{source_sha}"
        assert module_name.isidentifier(), module_name

    # read the content first; `is_zipfile`/`ZipFile` below operate on the same reader
    source_bytes = reader.read()

    module = sys.modules.get(module_name)
    if module is None:
        try:
            td_kwargs: Dict[str, Any] = (
                dict(ignore_cleanup_errors=True) if sys.version_info >= (3, 10) else {}
            )
            if sys.version_info >= (3, 12):
                td_kwargs["delete"] = False

            tmp_dir = TemporaryDirectory(**td_kwargs)
            # keep global ref to tmp_dir to delay cleanup until program exit
            # TODO: remove for py >= 3.12, when delete=False works
            tmp_dirs_in_use.append(tmp_dir)

            module_path = Path(tmp_dir.name) / module_name
            if reader.original_file_name.endswith(".zip") or is_zipfile(reader):
                module_path.mkdir()
                with ZipFile(reader) as zip_file:
                    zip_file.extractall(path=module_path)

                # import the extracted archive as a package; a suffix-less
                # directory location would make `spec_from_file_location` return None
                importlib_spec = importlib.util.spec_from_file_location(
                    module_name,
                    str(module_path / "__init__.py"),
                    submodule_search_locations=[str(module_path)],
                )
            else:
                module_path = module_path.with_suffix(".py")
                _ = module_path.write_bytes(source_bytes)
                importlib_spec = importlib.util.spec_from_file_location(
                    module_name, str(module_path)
                )

            if importlib_spec is None:
                raise ImportError(f"Failed to import {source}")

            module = importlib.util.module_from_spec(importlib_spec)

            sys.modules[module_name] = module  # cache this module

            assert importlib_spec.loader is not None
            importlib_spec.loader.exec_module(module)

        except Exception as e:
            # do not leave a partially initialized module in the cache
            if module_name in sys.modules:
                del sys.modules[module_name]

            raise ImportError(f"Failed to import {source}") from e

    try:
        callable_attr = getattr(module, callable_name)
    except AttributeError as e:
        raise AttributeError(
            f"Imported custom module from {source} has no `{callable_name}` attribute."
        ) from e
    except Exception as e:
        raise AttributeError(
            f"Failed to access `{callable_name}` attribute from custom module imported from {source} ."
        ) from e

    else:
        return callable_attr

168 

169 

def get_axes_infos(
    io_descr: Union[
        v0_4.InputTensorDescr,
        v0_4.OutputTensorDescr,
        v0_5.InputTensorDescr,
        v0_5.OutputTensorDescr,
    ],
) -> List[AxisInfo]:
    """Convert the spec axes of a tensor description into `AxisInfo` objects.

    Returns one `AxisInfo` per entry of `io_descr.axes`, in the same order.
    """
    return list(map(AxisInfo.create, io_descr.axes))

180 

181 

def get_member_id(
    tensor_description: Union[
        v0_4.InputTensorDescr,
        v0_4.OutputTensorDescr,
        v0_5.InputTensorDescr,
        v0_5.OutputTensorDescr,
    ],
) -> MemberId:
    """Return the member ID under which this tensor appears in a `Sample`.

    v0.4 descriptions are identified by their `name` (wrapped in `MemberId`);
    v0.5 descriptions already carry an `id`.
    """
    v0_4_types = (v0_4.InputTensorDescr, v0_4.OutputTensorDescr)
    if isinstance(tensor_description, v0_4_types):
        return MemberId(tensor_description.name)

    v0_5_types = (v0_5.InputTensorDescr, v0_5.OutputTensorDescr)
    if isinstance(tensor_description, v0_5_types):
        return tensor_description.id

    assert_never(tensor_description)

200 

201 

def get_member_ids(
    tensor_descriptions: Iterable[
        Union[
            v0_4.InputTensorDescr,
            v0_4.OutputTensorDescr,
            v0_5.InputTensorDescr,
            v0_5.OutputTensorDescr,
        ]
    ],
) -> List[MemberId]:
    """Map each tensor description to its sample member ID (see `get_member_id`).

    The returned list preserves the order of `tensor_descriptions`.
    """
    return list(map(get_member_id, tensor_descriptions))

214 

215 

def get_test_input_sample(model: AnyModelDescr) -> Sample:
    """Return a model's test input sample.

    For v0.4 models, the test input sources in `model.test_inputs` are paired
    by position with `model.inputs`. For v0.5 models, each input description
    carries its own test tensor.
    """
    if not isinstance(model, v0_4.ModelDescr):
        return _get_test_sample({d.id: d for d in model.inputs})

    paired = zip(model.inputs, model.test_inputs)
    return _get_test_sample({MemberId(d.name): (d, t) for d, t in paired})


get_test_inputs = get_test_input_sample
"""DEPRECATED: use `get_test_input_sample` instead"""

229 

230 

def get_test_output_sample(model: AnyModelDescr) -> Sample:
    """Return a model's test output sample.

    For v0.4 models, the test output sources in `model.test_outputs` are
    paired by position with `model.outputs`. For v0.5 models, each output
    description carries its own test tensor.
    """
    if not isinstance(model, v0_4.ModelDescr):
        return _get_test_sample({d.id: d for d in model.outputs})

    paired = zip(model.outputs, model.test_outputs)
    return _get_test_sample({MemberId(d.name): (d, t) for d, t in paired})


get_test_outputs = get_test_output_sample
"""DEPRECATED: use `get_test_output_sample` instead"""

245 

246 

def _get_test_sample(
    info: Union[
        Mapping[MemberId, Union[v0_5.InputTensorDescr, v0_5.OutputTensorDescr]],
        Mapping[
            MemberId,
            Tuple[
                v0_4.InputTensorDescr,
                FileSource,
            ],
        ],
        Mapping[
            MemberId,
            Tuple[
                v0_4.OutputTensorDescr,
                FileSource,
            ],
        ],
    ],
) -> Sample:
    """Load the test tensors described by `info` into a single `Sample`.

    Args:
        info: per member either a v0.5 tensor description (whose `test_tensor`
            is loaded) or a v0.4 `(tensor description, test tensor source)` pair.

    Returns:
        A sample with one `Tensor` per entry of `info` (axes from
        `get_axes_infos`), an empty `stat` and the id "test-sample".

    Raises:
        ValueError: if a v0.5 tensor description has no `test_tensor`.
    """
    arrays: Dict[MemberId, NDArray[Any]] = {}
    for m, src in info.items():
        if isinstance(src, tuple):
            arrays[m] = load_array(src[1])
        elif isinstance(src, (v0_5.InputTensorDescr, v0_5.OutputTensorDescr)):
            if src.test_tensor is None:
                # this helper builds both test input and test output samples,
                # so name the correct tensor kind in the error
                io_kind = (
                    "input" if isinstance(src, v0_5.InputTensorDescr) else "output"
                )
                raise ValueError(
                    f"Model {io_kind} '{m}' has no test tensor defined, cannot create test sample."
                )
            arrays[m] = load_array(src.test_tensor)
        else:
            assert_never(src)

    axes = {
        m: get_axes_infos(t[0] if isinstance(t, tuple) else t) for m, t in info.items()
    }
    return Sample(
        members={m: Tensor.from_numpy(arrays[m], dims=axes[m]) for m in info},
        stat={},
        id="test-sample",
    )

287 

288 

class IO_SampleBlockMeta(NamedTuple):
    """Block meta data of an input sample block together with its output counterpart."""

    # meta data of the block taken from the input sample
    input: SampleBlockMeta
    # meta data of the corresponding output block
    # (produced in `get_io_sample_block_metas` via `input.get_transformed(...)`)
    output: SampleBlockMeta

def get_input_halo(model: v0_5.ModelDescr, output_halo: PerMember[PerAxis[Halo]]):
    """Return the halo that input tensors need when divided into blocks, such that
    `output_halo` can be cropped from their outputs without introducing gaps.

    Every output axis listed in `output_halo` must have its size given by a
    `v0_5.SizeReference`. Its halo is rescaled from the output axis' scale to
    the referenced axis' scale. It is then stored under the referenced tensor
    and the referenced axis id, because the result is used to split the input
    tensors, whose axis ids may differ from the output axis ids. If several
    output axes reference the same axis, the larger halo is kept on each side.

    Args:
        model: v0.5 model description.
        output_halo: halo per output tensor and output axis.

    Returns:
        Halo per referenced tensor and referenced axis.

    Raises:
        ValueError: if an output axis with a halo is not sized by a reference,
            or the rescaled halo is not a whole number.
    """
    input_halo: Dict[MemberId, Dict[AxisId, Halo]] = {}
    outputs = {t.id: t for t in model.outputs}
    all_tensors = {**{t.id: t for t in model.inputs}, **outputs}

    for t, th in output_halo.items():
        axes = {a.id: a for a in outputs[t].axes}

        for a, ah in th.items():
            axis = axes[a]
            s = axis.size
            if not isinstance(s, v0_5.SizeReference):
                raise ValueError(
                    f"Unable to map output halo for {t}.{a} to an input axis"
                )

            ref_axis = {ra.id: ra for ra in all_tensors[s.tensor_id].axes}[s.axis_id]

            # convert the halo from output pixels to pixels of the referenced axis
            input_halo_left = ah.left * axis.scale / ref_axis.scale
            input_halo_right = ah.right * axis.scale / ref_axis.scale
            if input_halo_left != int(input_halo_left):
                raise ValueError(f"{input_halo_left} not int")
            if input_halo_right != int(input_halo_right):
                raise ValueError(f"{input_halo_right} not int")

            new_halo = Halo(int(input_halo_left), int(input_halo_right))
            tensor_halo = input_halo.setdefault(s.tensor_id, {})
            previous = tensor_halo.get(s.axis_id)
            if previous is not None:
                # the input halo must satisfy every output referencing this axis
                new_halo = Halo(
                    max(previous.left, new_halo.left),
                    max(previous.right, new_halo.right),
                )

            tensor_halo[s.axis_id] = new_halo

    return input_halo

326 

327 

def get_block_transform(
    model: v0_5.ModelDescr,
) -> PerMember[PerAxis[Union[LinearSampleAxisTransform, int]]]:
    """Return how a model's output tensor shapes relate to its input shapes.

    Every output tensor and axis is mapped to one of:
      - a `LinearSampleAxisTransform`. Batch axes follow the first input batch
        axis found; axes sized by a `v0_5.SizeReference` follow the referenced
        axis, with scale `referenced_axis.scale / axis.scale` and the
        reference's offset.
      - a fixed `int` size, or `-1` for a `v0_5.DataDependentSize`.

    If an output has a "channel" axis and a `cellpose_flow_dynamics` or
    `stardist_postprocessing` step, the channel entry is replaced by the
    channel count before that postprocessing (3, or `n_rays + 1`).

    Raises:
        ValueError: if an output has a batch axis but no input has one.
    """
    ret: Dict[MemberId, Dict[AxisId, Union[LinearSampleAxisTransform, int]]] = {}
    batch_axis_trf = None
    for ipt in model.inputs:
        for a in ipt.axes:
            if a.type == "batch":
                batch_axis_trf = LinearSampleAxisTransform(
                    axis=a.id, scale=1, offset=0, member=ipt.id
                )
                break
        if batch_axis_trf is not None:
            break

    axis_scales = {
        t.id: {a.id: a.scale for a in t.axes}
        for t in chain(model.inputs, model.outputs)
    }
    for out in model.outputs:
        new_axes: Dict[AxisId, Union[LinearSampleAxisTransform, int]] = {}
        for a in out.axes:
            if a.size is None:
                assert a.type == "batch"
                if batch_axis_trf is None:
                    raise ValueError(
                        "no batch axis found in any input tensor, but output tensor"
                        + f" '{out.id}' has one."
                    )
                s = batch_axis_trf
            elif isinstance(a.size, int):
                s = a.size
            elif isinstance(a.size, v0_5.DataDependentSize):
                s = -1
            elif isinstance(a.size, v0_5.SizeReference):
                s = LinearSampleAxisTransform(
                    axis=a.size.axis_id,
                    scale=axis_scales[a.size.tensor_id][a.size.axis_id] / a.scale,
                    offset=a.size.offset,
                    member=a.size.tensor_id,
                )
            else:
                assert_never(a.size)

            new_axes[a.id] = s

        # account for postprocessing that changes the number of output channels by
        # overwriting described output shape by the intermediate output shape
        c = AxisId("channel")
        if c in new_axes:
            for post in out.postprocessing:
                if post.id == "cellpose_flow_dynamics":
                    new_axes[c] = 3
                    break
                elif post.id == "stardist_postprocessing":
                    new_axes[c] = post.kwargs.n_rays + 1
                    break

        # record every output; outputs without a channel axis need a transform too
        ret[out.id] = new_axes

    return ret

390 

391 

def get_io_sample_block_metas(
    model: v0_5.ModelDescr,
    input_sample_shape: PerMember[PerAxis[int]],
    ns: Mapping[Tuple[MemberId, AxisId], ParameterizedSize_N],
    batch_size: int = 1,
) -> Tuple[TotalNumberOfBlocks, Iterable[IO_SampleBlockMeta]]:
    """Return the number of blocks and an iterable yielding meta data for
    corresponding input and output sample blocks.

    Input block shapes come from `model.get_axis_sizes(ns=ns, batch_size=batch_size)`.
    The input halo is derived from the output axes' halos via `get_input_halo`.
    Output block meta data is computed lazily, by a generator, from each input
    block using `get_block_transform(model)`.

    Args:
        model: v0.5 model description.
        input_sample_shape: full shape per input tensor and axis.
        ns: parameterized size `n` per (tensor id, axis id), passed to
            `model.get_axis_sizes`.
        batch_size: batch size passed to `model.get_axis_sizes`.

    Raises:
        TypeError: if `model` is not a v0.5 model description.
    """
    if not isinstance(model, v0_5.ModelDescr):
        raise TypeError(
            f"get_io_sample_block_metas() not implemented for {type(model)}"
        )

    block_axis_sizes = model.get_axis_sizes(ns=ns, batch_size=batch_size)
    input_block_shape = {
        t: {aa: s for (tt, aa), s in block_axis_sizes.inputs.items() if tt == t}
        for t in {tt for tt, _ in block_axis_sizes.inputs}
    }
    output_halo = {
        t.id: {
            a.id: Halo(a.halo, a.halo) for a in t.axes if isinstance(a, v0_5.WithHalo)
        }
        for t in model.outputs
    }
    input_halo = get_input_halo(model, output_halo)

    n_input_blocks, input_blocks = split_multiple_shapes_into_blocks(
        input_sample_shape, input_block_shape, halo=input_halo
    )
    block_transform = get_block_transform(model)
    return n_input_blocks, (
        IO_SampleBlockMeta(ipt, ipt.get_transformed(block_transform))
        for ipt in sample_block_meta_generator(
            input_blocks, sample_shape=input_sample_shape, sample_id=None
        )
    )

425 

426 

def get_tensor(
    src: TensorSource,
    descr: Union[
        v0_4.InputTensorDescr,
        v0_5.InputTensorDescr,
        v0_4.OutputTensorDescr,
        v0_5.OutputTensorDescr,
        Sequence[AxisInfo],
    ],
):
    """Cast or load a tensor source into a `Tensor` with the described axes.

    Args:
        src: an existing `Tensor`, an `xarray.DataArray`, a numpy array, or a
            file source passed on to `load_tensor`.
        descr: a tensor description (its axes are derived with
            `get_axes_infos`) or the axes themselves.

    Returns:
        The tensor. `Tensor` and `DataArray` sources are transposed to the
        described axis order.
    """
    tensor_descr_types = (
        v0_4.InputTensorDescr,
        v0_5.InputTensorDescr,
        v0_4.OutputTensorDescr,
        v0_5.OutputTensorDescr,
    )
    axes = get_axes_infos(descr) if isinstance(descr, tensor_descr_types) else descr

    if isinstance(src, Tensor):
        tensor = src
    elif isinstance(src, xr.DataArray):
        tensor = Tensor.from_xarray(src)
    elif isinstance(src, np.ndarray):
        return Tensor.from_numpy(src, dims=axes)
    else:
        return load_tensor(src, axes=axes)

    # only reached for Tensor / DataArray sources: reorder to the described axes
    return tensor.transpose(axes=[a.id for a in axes])

460 

461 

def create_sample_for_model(
    model: AnyModelDescr,
    *,
    stat: Optional[Stat] = None,
    sample_id: SampleId = None,
    inputs: Union[PerMember[TensorSource], TensorSource],
) -> Sample:
    """Create a sample from a single set of input(s) for a specific bioimage.io model

    Args:
        model: a bioimage.io model description
        stat: dictionary with sample and dataset statistics (may be updated in-place!)
        sample_id: id assigned to the created sample
        inputs: the input(s) constituting a single sample. Either a mapping
            keyed by input tensor id, or a single tensor source if the model
            has exactly one input.

    Raises:
        TypeError: if `inputs` is not a mapping but the model has several inputs.
        ValueError: if `inputs` has unknown keys or lacks a non-optional input.
    """
    descr_by_id = {get_member_id(d): d for d in model.inputs}

    if isinstance(inputs, collections.abc.Mapping):
        given = {MemberId(k): v for k, v in inputs.items()}
    elif len(descr_by_id) == 1:
        (only_id,) = descr_by_id
        given = {only_id: inputs}
    else:
        raise TypeError(
            f"Expected `inputs` to be a mapping with keys {tuple(descr_by_id)}"
        )

    unknown = {k for k in given if k not in descr_by_id}
    if unknown:
        raise ValueError(f"Got unexpected inputs: {unknown}")

    # only v0.5 input descriptions can be marked optional
    missing = {
        k
        for k, v in descr_by_id.items()
        if k not in given and not (isinstance(v, v0_5.InputTensorDescr) and v.optional)
    }
    if missing:
        raise ValueError(f"Missing non-optional model inputs: {missing}")

    members = {m: get_tensor(given[m], d) for m, d in descr_by_id.items() if m in given}
    return Sample(
        members=members,
        stat=stat if stat is not None else {},
        id=sample_id,
    )

506 

507 

def load_sample_for_model(
    *,
    model: AnyModelDescr,
    paths: PerMember[Path],
    axes: Optional[PerMember[Sequence[AxisLike]]] = None,
    stat: Optional[Stat] = None,
    sample_id: Optional[SampleId] = None,
) -> Sample:
    """Load a single sample from `paths` that can be processed by `model`.

    Args:
        model: a bioimage.io model description.
        paths: file path per model input id.
        axes: optional axes hint per model input id. Inputs without a hint use
            the axes of the model's input description.
        stat: dictionary with sample and dataset statistics. An empty dict is
            used if None.
        sample_id: id of the sample. If None, the sorted tuple of the `paths`
            values is used.

    Raises:
        ValueError: if `paths` or `axes` contain keys that are not model inputs.
    """
    if axes is None:
        axes = {}

    # make sure members are keyed by MemberId, not string
    paths = {MemberId(k): v for k, v in paths.items()}
    axes = {MemberId(k): v for k, v in axes.items()}

    model_inputs = {get_member_id(d): d for d in model.inputs}

    if unknown := {k for k in paths if k not in model_inputs}:
        raise ValueError(f"Got unexpected paths for {unknown}")

    if unknown := {k for k in axes if k not in model_inputs}:
        raise ValueError(f"Got unexpected axes hints for: {unknown}")

    members: Dict[MemberId, Tensor] = {}
    for m, p in paths.items():
        if m not in axes:
            axes[m] = get_axes_infos(model_inputs[m])
            logger.info(
                "loading '{}' from {} with default input axes {} ",
                m,
                p,
                axes[m],
            )
        members[m] = load_tensor(p, axes[m])

    return Sample(
        members=members,
        stat={} if stat is None else stat,
        # fall back only when no id was given; an explicit falsy id (e.g. 0 or "")
        # must not be replaced
        id=tuple(sorted(paths.values())) if sample_id is None else sample_id,
    )

549 )