Coverage for src/bioimageio/core/digest_spec.py: 79%

217 statements  

from __future__ import annotations

import collections.abc
import importlib.util
import sys
from itertools import chain
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import (
    Any,
    Callable,
    Dict,
    Iterable,
    List,
    Mapping,
    NamedTuple,
    Optional,
    Sequence,
    Tuple,
    Union,
)
from zipfile import ZipFile, is_zipfile

import numpy as np
import xarray as xr
from bioimageio.spec._internal.io import HashKwargs
from bioimageio.spec.common import FileDescr, FileSource, ZipPath
from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5
from bioimageio.spec.model.v0_4 import CallableFromDepencency, CallableFromFile
from bioimageio.spec.model.v0_5 import (
    ArchitectureFromFileDescr,
    ArchitectureFromLibraryDescr,
    ParameterizedSize_N,
)
from bioimageio.spec.utils import load_array
from loguru import logger
from numpy.typing import NDArray
from typing_extensions import Unpack, assert_never

from .axis import Axis, AxisId, AxisInfo, AxisLike, PerAxis
from .block_meta import split_multiple_shapes_into_blocks
from .common import Halo, MemberId, PerMember, SampleId, TotalNumberOfBlocks
from .io import load_tensor
from .sample import (
    LinearSampleAxisTransform,
    Sample,
    SampleBlockMeta,
    sample_block_meta_generator,
)
from .stat_measures import Stat
from .tensor import Tensor

TensorSource = Union[Tensor, xr.DataArray, NDArray[Any], Path]


def import_callable(
    node: Union[
        ArchitectureFromFileDescr,
        ArchitectureFromLibraryDescr,
        CallableFromDepencency,
        CallableFromFile,
    ],
    /,
    **kwargs: Unpack[HashKwargs],
) -> Callable[..., Any]:
    """import a callable (e.g. a torch.nn.Module) from a spec node describing it"""
    if isinstance(node, CallableFromDepencency):
        module = importlib.import_module(node.module_name)
        c = getattr(module, str(node.callable_name))
    elif isinstance(node, ArchitectureFromLibraryDescr):
        module = importlib.import_module(node.import_from)
        c = getattr(module, str(node.callable))
    elif isinstance(node, CallableFromFile):
        c = _import_from_file_impl(node.source_file, str(node.callable_name), **kwargs)
    elif isinstance(node, ArchitectureFromFileDescr):
        c = _import_from_file_impl(node.source, str(node.callable), sha256=node.sha256)
    else:
        assert_never(node)

    if not callable(c):
        raise ValueError(f"{node} (imported: {c}) is not callable")

    return c

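
# Usage sketch (illustrative only; the torchvision names are hypothetical and not
# part of this module): `import_callable` resolves an architecture description to
# the callable it points at, which can then be instantiated with its kwargs.
#
#   arch = ArchitectureFromLibraryDescr(
#       callable="resnet18", import_from="torchvision.models", kwargs={}
#   )
#   net_factory = import_callable(arch)  # -> torchvision.models.resnet18
#   net = net_factory(**arch.kwargs)
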

tmp_dirs_in_use: List[TemporaryDirectory[str]] = []
"""keep global reference to temporary directories created during import to delay cleanup"""


def _import_from_file_impl(
    source: FileSource, callable_name: str, **kwargs: Unpack[HashKwargs]
):
    src_descr = FileDescr(source=source, **kwargs)
    # ensure sha is valid even if perform_io_checks=False
    # or the source has changed since last sha computation
    src_descr.validate_sha256(force_recompute=True)
    assert src_descr.sha256 is not None
    source_sha = src_descr.sha256

    reader = src_descr.get_reader()
    # make sure we have a unique module name
    module_name = f"{reader.original_file_name.split('.')[0]}_{source_sha}"

    # make sure we have a unique and valid module name
    if not module_name.isidentifier():
        module_name = f"custom_module_{source_sha}"
        assert module_name.isidentifier(), module_name

    source_bytes = reader.read()

    module = sys.modules.get(module_name)
    if module is None:
        try:
            td_kwargs: Dict[str, Any] = (
                dict(ignore_cleanup_errors=True) if sys.version_info >= (3, 10) else {}
            )
            if sys.version_info >= (3, 12):
                td_kwargs["delete"] = False

            tmp_dir = TemporaryDirectory(**td_kwargs)
            # keep global ref to tmp_dir to delay cleanup until program exit
            # TODO: remove for py >= 3.12, when delete=False works
            tmp_dirs_in_use.append(tmp_dir)

            module_path = Path(tmp_dir.name) / module_name
            if reader.original_file_name.endswith(".zip") or is_zipfile(reader):
                module_path.mkdir()
                ZipFile(reader).extractall(path=module_path)
            else:
                module_path = module_path.with_suffix(".py")
                _ = module_path.write_bytes(source_bytes)

            importlib_spec = importlib.util.spec_from_file_location(
                module_name, str(module_path)
            )

            if importlib_spec is None:
                raise ImportError(f"Failed to import {source}")

            module = importlib.util.module_from_spec(importlib_spec)

            sys.modules[module_name] = module  # cache this module

            assert importlib_spec.loader is not None
            importlib_spec.loader.exec_module(module)

        except Exception as e:
            if module_name in sys.modules:
                del sys.modules[module_name]

            raise ImportError(f"Failed to import {source}") from e

    try:
        callable_attr = getattr(module, callable_name)
    except AttributeError as e:
        raise AttributeError(
            f"Imported custom module from {source} has no `{callable_name}` attribute."
        ) from e
    except Exception as e:
        raise AttributeError(
            f"Failed to access `{callable_name}` attribute from custom module imported from {source}."
        ) from e
    else:
        return callable_attr


def get_axes_infos(
    io_descr: Union[
        v0_4.InputTensorDescr,
        v0_4.OutputTensorDescr,
        v0_5.InputTensorDescr,
        v0_5.OutputTensorDescr,
    ],
) -> List[AxisInfo]:
    """get a unified, simplified axis representation from spec axes"""
    ret: List[AxisInfo] = []
    for a in io_descr.axes:
        if isinstance(a, v0_5.AxisBase):
            ret.append(AxisInfo.create(Axis(id=a.id, type=a.type)))
        else:
            assert a in ("b", "i", "t", "c", "z", "y", "x")
            ret.append(AxisInfo.create(a))

    return ret

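
# Illustration (hypothetical v0_4-style axes string, not taken from a real model):
# for a tensor description with `axes == "bcyx"`, `get_axes_infos` returns one
# `AxisInfo` per letter, roughly equivalent to
#
#   [AxisInfo.create(a) for a in "bcyx"]
#
# while v0_5 axis objects are first normalized to `Axis(id=..., type=...)`.
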

def get_member_id(
    tensor_description: Union[
        v0_4.InputTensorDescr,
        v0_4.OutputTensorDescr,
        v0_5.InputTensorDescr,
        v0_5.OutputTensorDescr,
    ],
) -> MemberId:
    """get the normalized tensor ID, usable as a sample member ID"""

    if isinstance(tensor_description, (v0_4.InputTensorDescr, v0_4.OutputTensorDescr)):
        return MemberId(tensor_description.name)
    elif isinstance(
        tensor_description, (v0_5.InputTensorDescr, v0_5.OutputTensorDescr)
    ):
        return tensor_description.id
    else:
        assert_never(tensor_description)


def get_member_ids(
    tensor_descriptions: Sequence[
        Union[
            v0_4.InputTensorDescr,
            v0_4.OutputTensorDescr,
            v0_5.InputTensorDescr,
            v0_5.OutputTensorDescr,
        ]
    ],
) -> List[MemberId]:
    """get normalized tensor IDs to be used as sample member IDs"""
    return [get_member_id(descr) for descr in tensor_descriptions]

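
# Illustration (hedged; assumes `model` is an already loaded model description):
# member IDs key the tensors of a `Sample`, e.g.
#
#   input_ids = get_member_ids(model.inputs)    # e.g. [MemberId("raw")]
#   output_ids = get_member_ids(model.outputs)  # e.g. [MemberId("prediction")]
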

def get_test_input_sample(model: AnyModelDescr) -> Sample:
    """returns a model's test input sample"""
    return _get_test_sample(
        model.inputs,
        model.test_inputs if isinstance(model, v0_4.ModelDescr) else model.inputs,
    )


get_test_inputs = get_test_input_sample
"""DEPRECATED: use `get_test_input_sample` instead"""


def get_test_output_sample(model: AnyModelDescr) -> Sample:
    """returns a model's test output sample"""
    return _get_test_sample(
        model.outputs,
        model.test_outputs if isinstance(model, v0_4.ModelDescr) else model.outputs,
    )


get_test_outputs = get_test_output_sample
"""DEPRECATED: use `get_test_output_sample` instead"""


def _get_test_sample(
    tensor_descrs: Sequence[
        Union[
            v0_4.InputTensorDescr,
            v0_4.OutputTensorDescr,
            v0_5.InputTensorDescr,
            v0_5.OutputTensorDescr,
        ]
    ],
    test_sources: Sequence[Union[FileSource, v0_5.TensorDescr]],
) -> Sample:
    """returns a model's input/output test sample"""
    member_ids = get_member_ids(tensor_descrs)
    arrays: List[NDArray[Any]] = []
    for src in test_sources:
        if isinstance(src, (v0_5.InputTensorDescr, v0_5.OutputTensorDescr)):
            if src.test_tensor is None:
                raise ValueError(
                    f"Model input '{src.id}' has no test tensor defined, cannot create test sample."
                )
            arrays.append(load_array(src.test_tensor))
        else:
            arrays.append(load_array(src))

    axes = [get_axes_infos(t) for t in tensor_descrs]
    return Sample(
        members={
            m: Tensor.from_numpy(arr, dims=ax)
            for m, arr, ax in zip(member_ids, arrays, axes)
        },
        stat={},
        id="test-sample",
    )

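
# Usage sketch (illustrative; assumes `load_description` is available from
# bioimageio.core and that the model ID is a placeholder):
#
#   from bioimageio.core import load_description
#
#   model = load_description("affable-shark")
#   test_in = get_test_input_sample(model)
#   test_out = get_test_output_sample(model)
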

class IO_SampleBlockMeta(NamedTuple):
    input: SampleBlockMeta
    output: SampleBlockMeta


def get_input_halo(model: v0_5.ModelDescr, output_halo: PerMember[PerAxis[Halo]]):
    """returns the halo with which input tensors need to be divided into blocks,
    such that `output_halo` can be cropped from the resulting outputs without
    introducing gaps."""
    input_halo: Dict[MemberId, Dict[AxisId, Halo]] = {}
    outputs = {t.id: t for t in model.outputs}
    all_tensors = {**{t.id: t for t in model.inputs}, **outputs}

    for t, th in output_halo.items():
        axes = {a.id: a for a in outputs[t].axes}

        for a, ah in th.items():
            s = axes[a].size
            if not isinstance(s, v0_5.SizeReference):
                raise ValueError(
                    f"Unable to map output halo for {t}.{a} to an input axis"
                )

            axis = axes[a]
            ref_axis = {a.id: a for a in all_tensors[s.tensor_id].axes}[s.axis_id]

            input_halo_left = ah.left * axis.scale / ref_axis.scale
            input_halo_right = ah.right * axis.scale / ref_axis.scale
            assert input_halo_left == int(input_halo_left), f"{input_halo_left} not int"
            assert input_halo_right == int(input_halo_right), (
                f"{input_halo_right} not int"
            )

            input_halo.setdefault(s.tensor_id, {})[a] = Halo(
                int(input_halo_left), int(input_halo_right)
            )

    return input_halo

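
# Worked example (hypothetical numbers): if an output axis has scale 2.0, the
# referenced input axis has scale 1.0, and the requested output halo is
# Halo(8, 8), then the input halo becomes Halo(int(8 * 2.0 / 1.0), ...) ==
# Halo(16, 16), i.e. the input blocks overlap by 16 samples on each side of
# that axis.
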

def get_block_transform(
    model: v0_5.ModelDescr,
) -> PerMember[PerAxis[Union[LinearSampleAxisTransform, int]]]:
    """returns how a model's output tensor shapes relate to its input shapes"""
    ret: Dict[MemberId, Dict[AxisId, Union[LinearSampleAxisTransform, int]]] = {}
    batch_axis_trf = None
    for ipt in model.inputs:
        for a in ipt.axes:
            if a.type == "batch":
                batch_axis_trf = LinearSampleAxisTransform(
                    axis=a.id, scale=1, offset=0, member=ipt.id
                )
                break
        if batch_axis_trf is not None:
            break
    axis_scales = {
        t.id: {a.id: a.scale for a in t.axes}
        for t in chain(model.inputs, model.outputs)
    }
    for out in model.outputs:
        new_axes: Dict[AxisId, Union[LinearSampleAxisTransform, int]] = {}
        for a in out.axes:
            if a.size is None:
                assert a.type == "batch"
                if batch_axis_trf is None:
                    raise ValueError(
                        "no batch axis found in any input tensor, but output tensor"
                        + f" '{out.id}' has one."
                    )
                s = batch_axis_trf
            elif isinstance(a.size, int):
                s = a.size
            elif isinstance(a.size, v0_5.DataDependentSize):
                s = -1
            elif isinstance(a.size, v0_5.SizeReference):
                s = LinearSampleAxisTransform(
                    axis=a.size.axis_id,
                    scale=axis_scales[a.size.tensor_id][a.size.axis_id] / a.scale,
                    offset=a.size.offset,
                    member=a.size.tensor_id,
                )
            else:
                assert_never(a.size)

            new_axes[a.id] = s

        ret[out.id] = new_axes

    return ret

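
# Note (interpretation, hedged): for a `SizeReference`d output axis the resulting
# `LinearSampleAxisTransform` encodes, per the v0_5 size semantics, roughly
#
#   output_size ≈ referenced_input_size * scale + offset
#
# with `scale == referenced_axis_scale / output_axis_scale`; plain integer entries
# are fixed sizes and -1 marks a data-dependent size.
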

def get_io_sample_block_metas(
    model: v0_5.ModelDescr,
    input_sample_shape: PerMember[PerAxis[int]],
    ns: Mapping[Tuple[MemberId, AxisId], ParameterizedSize_N],
    batch_size: int = 1,
) -> Tuple[TotalNumberOfBlocks, Iterable[IO_SampleBlockMeta]]:
    """returns the total number of blocks and an iterable yielding metadata for
    corresponding input and output sample blocks"""
    if not isinstance(model, v0_5.ModelDescr):
        raise TypeError(
            f"get_io_sample_block_metas() not implemented for {type(model)}"
        )

    block_axis_sizes = model.get_axis_sizes(ns=ns, batch_size=batch_size)
    input_block_shape = {
        t: {aa: s for (tt, aa), s in block_axis_sizes.inputs.items() if tt == t}
        for t in {tt for tt, _ in block_axis_sizes.inputs}
    }
    output_halo = {
        t.id: {
            a.id: Halo(a.halo, a.halo) for a in t.axes if isinstance(a, v0_5.WithHalo)
        }
        for t in model.outputs
    }
    input_halo = get_input_halo(model, output_halo)

    n_input_blocks, input_blocks = split_multiple_shapes_into_blocks(
        input_sample_shape, input_block_shape, halo=input_halo
    )
    block_transform = get_block_transform(model)
    return n_input_blocks, (
        IO_SampleBlockMeta(ipt, ipt.get_transformed(block_transform))
        for ipt in sample_block_meta_generator(
            input_blocks, sample_shape=input_sample_shape, sample_id=None
        )
    )

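
# Simplified usage sketch (hypothetical member/axis names; a real call would list
# every axis of the sample shape, including batch and channel axes):
#
#   n_blocks, block_metas = get_io_sample_block_metas(
#       model,
#       input_sample_shape={MemberId("raw"): {AxisId("x"): 1024, AxisId("y"): 1024}},
#       ns={(MemberId("raw"), AxisId("x")): 4, (MemberId("raw"), AxisId("y")): 4},
#   )
#   for io_meta in block_metas:
#       ...  # process io_meta.input / io_meta.output
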

def get_tensor(
    src: Union[ZipPath, TensorSource],
    ipt: Union[v0_4.InputTensorDescr, v0_5.InputTensorDescr],
):
    """helper to cast/load various tensor sources"""

    if isinstance(src, Tensor):
        return src
    elif isinstance(src, xr.DataArray):
        return Tensor.from_xarray(src)
    elif isinstance(src, np.ndarray):
        return Tensor.from_numpy(src, dims=get_axes_infos(ipt))
    else:
        return load_tensor(src, axes=get_axes_infos(ipt))

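
# Illustration (hypothetical shape; assumes `model` is in scope): a raw numpy
# array is wrapped using the axis layout of the matching input description, e.g.
#
#   arr = np.zeros((1, 1, 256, 256), dtype="float32")
#   tensor = get_tensor(arr, model.inputs[0])
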

def create_sample_for_model(
    model: AnyModelDescr,
    *,
    stat: Optional[Stat] = None,
    sample_id: SampleId = None,
    inputs: Union[PerMember[TensorSource], TensorSource],
) -> Sample:
    """Create a sample from a single set of input(s) for a specific bioimage.io model

    Args:
        model: a bioimage.io model description
        stat: dictionary with sample and dataset statistics (may be updated in-place!)
        sample_id: the created sample's identifier
        inputs: the input(s) constituting a single sample.
    """

    model_inputs = {get_member_id(d): d for d in model.inputs}
    if isinstance(inputs, collections.abc.Mapping):
        inputs = {MemberId(k): v for k, v in inputs.items()}
    elif len(model_inputs) == 1:
        inputs = {list(model_inputs)[0]: inputs}
    else:
        raise TypeError(
            f"Expected `inputs` to be a mapping with keys {tuple(model_inputs)}"
        )

    if unknown := {k for k in inputs if k not in model_inputs}:
        raise ValueError(f"Got unexpected inputs: {unknown}")

    if missing := {
        k
        for k, v in model_inputs.items()
        if k not in inputs and not (isinstance(v, v0_5.InputTensorDescr) and v.optional)
    }:
        raise ValueError(f"Missing non-optional model inputs: {missing}")

    return Sample(
        members={
            m: get_tensor(inputs[m], ipt)
            for m, ipt in model_inputs.items()
            if m in inputs
        },
        stat={} if stat is None else stat,
        id=sample_id,
    )

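
# Usage sketch (illustrative; the member name "raw" and the array shape are
# hypothetical): for a single-input model the array may be passed directly,
# otherwise a mapping keyed by member ID is expected.
#
#   sample = create_sample_for_model(
#       model,
#       inputs={"raw": np.zeros((1, 1, 256, 256), dtype="float32")},
#       sample_id="my-sample",
#   )
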

def load_sample_for_model(
    *,
    model: AnyModelDescr,
    paths: PerMember[Path],
    axes: Optional[PerMember[Sequence[AxisLike]]] = None,
    stat: Optional[Stat] = None,
    sample_id: Optional[SampleId] = None,
):
    """load a single sample from `paths` that can be processed by `model`"""

    if axes is None:
        axes = {}

    # make sure members are keyed by MemberId, not string
    paths = {MemberId(k): v for k, v in paths.items()}
    axes = {MemberId(k): v for k, v in axes.items()}

    model_inputs = {get_member_id(d): d for d in model.inputs}

    if unknown := {k for k in paths if k not in model_inputs}:
        raise ValueError(f"Got unexpected paths for {unknown}")

    if unknown := {k for k in axes if k not in model_inputs}:
        raise ValueError(f"Got unexpected axes hints for: {unknown}")

    members: Dict[MemberId, Tensor] = {}
    for m, p in paths.items():
        if m not in axes:
            axes[m] = get_axes_infos(model_inputs[m])
            logger.debug(
                "loading '{}' from {} with default input axes {} ",
                m,
                p,
                axes[m],
            )
        members[m] = load_tensor(p, axes[m])

    return Sample(
        members=members,
        stat={} if stat is None else stat,
        id=sample_id or tuple(sorted(paths.values())),
    )
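

# Usage sketch (illustrative; the member name and file path are hypothetical):
#
#   sample = load_sample_for_model(
#       model=model,
#       paths={"raw": Path("my_image.tif")},
#       sample_id="my-image",
#   )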