Coverage for bioimageio/core/digest_spec.py: 86%

205 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-04-01 09:51 +0000

1from __future__ import annotations 

2 

3import collections.abc 

4import hashlib 

5import importlib.util 

6import sys 

7from itertools import chain 

8from pathlib import Path 

9from typing import ( 

10 Any, 

11 Callable, 

12 Dict, 

13 Iterable, 

14 List, 

15 Mapping, 

16 NamedTuple, 

17 Optional, 

18 Sequence, 

19 Tuple, 

20 Union, 

21) 

22 

23import numpy as np 

24import xarray as xr 

25from loguru import logger 

26from numpy.typing import NDArray 

27from typing_extensions import Unpack, assert_never 

28 

29from bioimageio.spec._internal.io import HashKwargs 

30from bioimageio.spec.common import FileDescr, FileSource, ZipPath 

31from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5 

32from bioimageio.spec.model.v0_4 import CallableFromDepencency, CallableFromFile 

33from bioimageio.spec.model.v0_5 import ( 

34 ArchitectureFromFileDescr, 

35 ArchitectureFromLibraryDescr, 

36 ParameterizedSize_N, 

37) 

38from bioimageio.spec.utils import download, load_array 

39 

40from ._settings import settings 

41from .axis import Axis, AxisId, AxisInfo, AxisLike, PerAxis 

42from .block_meta import split_multiple_shapes_into_blocks 

43from .common import Halo, MemberId, PerMember, SampleId, TotalNumberOfBlocks 

44from .io import load_tensor 

45from .sample import ( 

46 LinearSampleAxisTransform, 

47 Sample, 

48 SampleBlockMeta, 

49 sample_block_meta_generator, 

50) 

51from .stat_measures import Stat 

52from .tensor import Tensor 

53 

# Kinds of objects accepted as the source of a single tensor;
# `get_tensor` below casts/loads each of them (and additionally `ZipPath`).
TensorSource = Union[Tensor, xr.DataArray, NDArray[Any], Path]

55 

56 

def import_callable(
    node: Union[
        ArchitectureFromFileDescr,
        ArchitectureFromLibraryDescr,
        CallableFromDepencency,
        CallableFromFile,
    ],
    /,
    **kwargs: Unpack[HashKwargs],
) -> Callable[..., Any]:
    """import the callable (e.g. a torch.nn.Module) that a spec node points to

    Args:
        node: description of a callable living either in an importable module
            or in a (possibly remote) source file.
        **kwargs: hash keyword arguments; only forwarded for `CallableFromFile`
            nodes (an `ArchitectureFromFileDescr` carries its own `sha256`).

    Returns:
        The imported object.

    Raises:
        ValueError: if the imported object is not callable.
    """
    if isinstance(node, CallableFromDepencency):
        # v0_4 style: callable from an installed dependency
        imported = getattr(
            importlib.import_module(node.module_name), str(node.callable_name)
        )
    elif isinstance(node, ArchitectureFromLibraryDescr):
        # v0_5 style: callable from an installed library
        imported = getattr(
            importlib.import_module(node.import_from), str(node.callable)
        )
    elif isinstance(node, CallableFromFile):
        imported = _import_from_file_impl(
            node.source_file, str(node.callable_name), **kwargs
        )
    elif isinstance(node, ArchitectureFromFileDescr):
        imported = _import_from_file_impl(
            node.source, str(node.callable), sha256=node.sha256
        )
    else:
        assert_never(node)

    if callable(imported):
        return imported

    raise ValueError(f"{node} (imported: {imported}) is not callable")

85 

86 

def _import_from_file_impl(
    source: FileSource, callable_name: str, **kwargs: Unpack[HashKwargs]
):
    """import the attribute `callable_name` from the Python source file `source`

    The file is loaded as a module whose name contains the SHA-256 of the file
    content; a module that was imported before is reused from `sys.modules`.

    Args:
        source: location of the Python source file.
        callable_name: name of the attribute to retrieve from the module.
        **kwargs: hash keyword arguments forwarded to `FileDescr`.

    Returns:
        The attribute named `callable_name` of the imported module.

    Raises:
        ImportError: if the module cannot be loaded or executed.
        AttributeError: if the attribute is missing or accessing it fails.
    """
    src_descr = FileDescr(source=source, **kwargs)
    # ensure sha is valid even if perform_io_checks=False
    src_descr.validate_sha256()
    assert src_descr.sha256 is not None

    local_source = src_descr.download()

    source_bytes = local_source.path.read_bytes()
    assert isinstance(source_bytes, bytes)
    source_sha = hashlib.sha256(source_bytes).hexdigest()

    # make sure we have unique module name
    module_name = f"{local_source.path.stem}_{source_sha}"

    # make sure we have a valid module name
    if not module_name.isidentifier():
        module_name = f"custom_module_{source_sha}"
        assert module_name.isidentifier(), module_name

    module = sys.modules.get(module_name)
    if module is None:
        try:
            if isinstance(local_source.path, Path):
                module_path = local_source.path
            elif isinstance(local_source.path, ZipPath):
                # save extract source to cache
                # loading from a file from disk ensure we get readable tracebacks
                # if any errors occur
                module_path = (
                    settings.cache_path / f"{source_sha}-{local_source.path.name}"
                )
                _ = module_path.write_bytes(source_bytes)
            else:
                assert_never(local_source.path)

            importlib_spec = importlib.util.spec_from_file_location(
                module_name, module_path
            )

            if importlib_spec is None:
                raise ImportError(f"Failed to import {source}")

            module = importlib.util.module_from_spec(importlib_spec)
            assert importlib_spec.loader is not None

            # Register (cache) the module *before* executing it, as recommended
            # by the importlib documentation: code that runs during module
            # execution may look the module up via `sys.modules[__name__]`.
            sys.modules[module_name] = module
            importlib_spec.loader.exec_module(module)

        except Exception as e:
            # do not leave a partially initialized module in the cache
            _ = sys.modules.pop(module_name, None)
            raise ImportError(f"Failed to import {source}") from e

    try:
        callable_attr = getattr(module, callable_name)
    except AttributeError as e:
        raise AttributeError(
            f"Imported custom module from {source} has no `{callable_name}` attribute."
        ) from e
    except Exception as e:
        raise AttributeError(
            f"Failed to access `{callable_name}` attribute from custom module imported from {source} ."
        ) from e

    else:
        return callable_attr

154 

155 

def get_axes_infos(
    io_descr: Union[
        v0_4.InputTensorDescr,
        v0_4.OutputTensorDescr,
        v0_5.InputTensorDescr,
        v0_5.OutputTensorDescr,
    ],
) -> List[AxisInfo]:
    """translate the axes of a tensor description into a list of `AxisInfo`

    Axes deriving from `v0_5.AxisBase` are converted via their `id` and `type`;
    any other axis has to be one of the single letter axis names
    "b", "i", "t", "c", "z", "y" or "x".
    """
    infos: List[AxisInfo] = []
    for axis in io_descr.axes:
        if not isinstance(axis, v0_5.AxisBase):
            # single letter axis
            assert axis in ("b", "i", "t", "c", "z", "y", "x")
            axis_like = axis
        else:
            axis_like = Axis(id=axis.id, type=axis.type)

        infos.append(AxisInfo.create(axis_like))

    return infos

174 

175 

def get_member_id(
    tensor_description: Union[
        v0_4.InputTensorDescr,
        v0_4.OutputTensorDescr,
        v0_5.InputTensorDescr,
        v0_5.OutputTensorDescr,
    ],
) -> MemberId:
    """return the tensor's identifier in the form used for sample member IDs

    v0_4 tensor descriptions are identified by their `name` (wrapped in
    `MemberId`), v0_5 tensor descriptions by their `id`.
    """
    v0_4_descrs = (v0_4.InputTensorDescr, v0_4.OutputTensorDescr)
    if isinstance(tensor_description, v0_4_descrs):
        return MemberId(tensor_description.name)

    v0_5_descrs = (v0_5.InputTensorDescr, v0_5.OutputTensorDescr)
    if isinstance(tensor_description, v0_5_descrs):
        return tensor_description.id

    assert_never(tensor_description)

194 

195 

def get_member_ids(
    tensor_descriptions: Sequence[
        Union[
            v0_4.InputTensorDescr,
            v0_4.OutputTensorDescr,
            v0_5.InputTensorDescr,
            v0_5.OutputTensorDescr,
        ]
    ],
) -> List[MemberId]:
    """return the sample member ID of each given tensor description, in order"""
    return list(map(get_member_id, tensor_descriptions))

208 

209 

def get_test_inputs(model: AnyModelDescr) -> Sample:
    """assemble a sample (id "test-sample") from the test inputs of `model`"""
    member_ids = get_member_ids(model.inputs)

    # v0_4 lists test tensors on the model, later versions on each tensor descr
    if isinstance(model, v0_4.ModelDescr):
        test_sources = iter(model.test_inputs)
    else:
        test_sources = (descr.test_tensor for descr in model.inputs)

    arrays = [load_array(source) for source in test_sources]
    axes = [get_axes_infos(descr) for descr in model.inputs]

    members = {}
    for member_id, array, dims in zip(member_ids, arrays, axes):
        members[member_id] = Tensor.from_numpy(array, dims=dims)

    return Sample(members=members, stat={}, id="test-sample")

227 

228 

def get_test_outputs(model: AnyModelDescr) -> Sample:
    """assemble a sample (id "test-sample") from the test outputs of `model`"""
    member_ids = get_member_ids(model.outputs)

    # v0_4 lists test tensors on the model, later versions on each tensor descr
    if isinstance(model, v0_4.ModelDescr):
        test_sources = iter(model.test_outputs)
    else:
        test_sources = (descr.test_tensor for descr in model.outputs)

    arrays = [load_array(source) for source in test_sources]
    axes = [get_axes_infos(descr) for descr in model.outputs]

    members = {}
    for member_id, array, dims in zip(member_ids, arrays, axes):
        members[member_id] = Tensor.from_numpy(array, dims=dims)

    return Sample(members=members, stat={}, id="test-sample")

248 

249 

class IO_SampleBlockMeta(NamedTuple):
    """meta data of an input sample block paired with the meta data of its output

    `get_io_sample_block_metas` yields instances of this class, deriving
    `output` from `input` via `SampleBlockMeta.get_transformed`.
    """

    # meta data of the input sample block
    input: SampleBlockMeta
    # meta data of the corresponding output sample block
    output: SampleBlockMeta

253 

254 

def get_input_halo(model: v0_5.ModelDescr, output_halo: PerMember[PerAxis[Halo]]):
    """returns which halo input tensors need to be divided into blocks with, such that
    `output_halo` can be cropped from their outputs without introducing gaps.

    Args:
        model: model description providing the input/output tensor axes.
        output_halo: halo per output tensor and output axis.

    Returns:
        Halo per referenced tensor and referenced axis. The summed output halo
        is rescaled to the referenced axis and split evenly into both sides.

    Raises:
        ValueError: if the size of an output axis with halo is not given as a
            `v0_5.SizeReference`.
    """
    input_halo: Dict[MemberId, Dict[AxisId, Halo]] = {}
    outputs = {t.id: t for t in model.outputs}
    all_tensors = {**{t.id: t for t in model.inputs}, **outputs}

    for t, th in output_halo.items():
        axes = {a.id: a for a in outputs[t].axes}

        for a, ah in th.items():
            axis = axes[a]
            s = axis.size
            if not isinstance(s, v0_5.SizeReference):
                raise ValueError(
                    f"Unable to map output halo for {t}.{a} to an input axis"
                )

            # NOTE(review): the reference may also point to another output
            # tensor (`all_tensors` includes outputs) - confirm this is intended
            ref_axis = {ra.id: ra for ra in all_tensors[s.tensor_id].axes}[s.axis_id]

            total_output_halo = sum(ah)
            total_input_halo = total_output_halo * axis.scale / ref_axis.scale
            assert (
                total_input_halo == int(total_input_halo) and total_input_halo % 2 == 0
            ), (
                f"total input halo {total_input_halo} derived for {t}.{a}"
                + " is not an even integer"
            )

            half_input_halo = int(total_input_halo // 2)
            # key by the *referenced* axis: the output axis id `a` may differ
            # from the id of the axis it references in tensor `s.tensor_id`
            input_halo.setdefault(s.tensor_id, {})[s.axis_id] = Halo(
                half_input_halo, half_input_halo
            )

    return input_halo

285 

286 

def get_block_transform(
    model: v0_5.ModelDescr,
) -> PerMember[PerAxis[Union[LinearSampleAxisTransform, int]]]:
    """describe, per output tensor axis, how its size follows from the inputs

    Each output axis is mapped to either

    - a `LinearSampleAxisTransform` relative to a reference axis
      (batch axes and axes sized by a `v0_5.SizeReference`),
    - its fixed integer size, or
    - `-1` for a `v0_5.DataDependentSize`.

    Raises:
        ValueError: if an output has a batch axis, but no input has one.
    """
    # the first batch axis found among the inputs is the reference for all
    # output batch axes
    batch_reference = next(
        (
            LinearSampleAxisTransform(
                axis=in_axis.id, scale=1, offset=0, member=ipt.id
            )
            for ipt in model.inputs
            for in_axis in ipt.axes
            if in_axis.type == "batch"
        ),
        None,
    )

    scales = {}
    for tensor in chain(model.inputs, model.outputs):
        scales[tensor.id] = {axis.id: axis.scale for axis in tensor.axes}

    transforms: Dict[MemberId, Dict[AxisId, Union[LinearSampleAxisTransform, int]]] = {}
    for out in model.outputs:
        per_axis: Dict[AxisId, Union[LinearSampleAxisTransform, int]] = {}
        for a in out.axes:
            if a.size is None:
                # only batch axes come without a size
                assert a.type == "batch"
                if batch_reference is None:
                    raise ValueError(
                        "no batch axis found in any input tensor, but output tensor"
                        + f" '{out.id}' has one."
                    )
                trf = batch_reference
            elif isinstance(a.size, int):
                trf = a.size
            elif isinstance(a.size, v0_5.DataDependentSize):
                trf = -1
            elif isinstance(a.size, v0_5.SizeReference):
                ref_scale = scales[a.size.tensor_id][a.size.axis_id]
                trf = LinearSampleAxisTransform(
                    axis=a.size.axis_id,
                    scale=ref_scale / a.scale,
                    offset=a.size.offset,
                    member=a.size.tensor_id,
                )
            else:
                assert_never(a.size)

            per_axis[a.id] = trf

        transforms[out.id] = per_axis

    return transforms

336 

337 

def get_io_sample_block_metas(
    model: v0_5.ModelDescr,
    input_sample_shape: PerMember[PerAxis[int]],
    ns: Mapping[Tuple[MemberId, AxisId], ParameterizedSize_N],
    batch_size: int = 1,
) -> Tuple[TotalNumberOfBlocks, Iterable[IO_SampleBlockMeta]]:
    """returns an iterable yielding meta data for corresponding input and output samples

    Args:
        model: a v0_5 model description.
        input_sample_shape: shape of the input sample per member and axis.
        ns: passed on to `model.get_axis_sizes` to determine the block sizes.
        batch_size: passed on to `model.get_axis_sizes`.

    Returns:
        The number of input blocks and a lazily evaluated iterable of
        `IO_SampleBlockMeta` (input block meta, transformed output block meta).

    Raises:
        TypeError: if `model` is not a `v0_5.ModelDescr`.
    """
    if not isinstance(model, v0_5.ModelDescr):
        raise TypeError(
            f"get_io_sample_block_metas() not implemented for {type(model)}"
        )

    block_axis_sizes = model.get_axis_sizes(ns=ns, batch_size=batch_size)

    # group the flat {(tensor, axis): size} mapping by tensor in a single pass
    input_block_shape: Dict[MemberId, Dict[AxisId, int]] = {}
    for (member_id, axis_id), size in block_axis_sizes.inputs.items():
        input_block_shape.setdefault(member_id, {})[axis_id] = size

    # halos are specified on output axes and need to be mapped to the inputs
    output_halo = {
        t.id: {
            a.id: Halo(a.halo, a.halo) for a in t.axes if isinstance(a, v0_5.WithHalo)
        }
        for t in model.outputs
    }
    input_halo = get_input_halo(model, output_halo)

    n_input_blocks, input_blocks = split_multiple_shapes_into_blocks(
        input_sample_shape, input_block_shape, halo=input_halo
    )
    block_transform = get_block_transform(model)
    return n_input_blocks, (
        IO_SampleBlockMeta(ipt, ipt.get_transformed(block_transform))
        for ipt in sample_block_meta_generator(
            input_blocks, sample_shape=input_sample_shape, sample_id=None
        )
    )

371 

372 

def get_tensor(
    src: Union[ZipPath, TensorSource],
    ipt: Union[v0_4.InputTensorDescr, v0_5.InputTensorDescr],
):
    """turn a tensor source into a `Tensor`

    Args:
        src: a `Tensor` (returned unchanged), an `xarray.DataArray`, a numpy
            array, a `FileDescr` (downloaded first) or a path to load from.
        ipt: input tensor description; provides the axes for numpy arrays
            and for tensors loaded from a path.
    """
    if isinstance(src, Tensor):
        return src
    elif isinstance(src, xr.DataArray):
        return Tensor.from_xarray(src)
    elif isinstance(src, np.ndarray):
        return Tensor.from_numpy(src, dims=get_axes_infos(ipt))

    # resolve file descriptions to the path of the downloaded file
    location = download(src).path if isinstance(src, FileDescr) else src

    if isinstance(location, (ZipPath, Path, str)):
        return load_tensor(location, axes=get_axes_infos(ipt))

    assert_never(location)

395 

396 

def create_sample_for_model(
    model: AnyModelDescr,
    *,
    stat: Optional[Stat] = None,
    sample_id: SampleId = None,
    inputs: Union[PerMember[TensorSource], TensorSource],
) -> Sample:
    """Create a sample from a single set of input(s) for a specific bioimage.io model

    Args:
        model: a bioimage.io model description
        stat: dictionary with sample and dataset statistics (may be updated in-place!)
        sample_id: the id assigned to the created sample
        inputs: the input(s) constituting a single sample; either a mapping
            from member id to tensor source or, for models with exactly one
            input, a single tensor source.

    Raises:
        TypeError: if `inputs` is not a mapping although the model has not
            exactly one input.
        ValueError: on inputs unknown to the model or missing non-optional inputs.
    """
    model_inputs = {get_member_id(descr): descr for descr in model.inputs}

    if isinstance(inputs, collections.abc.Mapping):
        inputs = {MemberId(key): value for key, value in inputs.items()}
    elif len(model_inputs) != 1:
        raise TypeError(
            f"Expected `inputs` to be a mapping with keys {tuple(model_inputs)}"
        )
    else:
        # a bare tensor source belongs to the model's only input
        inputs = {next(iter(model_inputs)): inputs}

    unknown = {key for key in inputs if key not in model_inputs}
    if unknown:
        raise ValueError(f"Got unexpected inputs: {unknown}")

    missing = {
        key
        for key, descr in model_inputs.items()
        if key not in inputs
        and not (isinstance(descr, v0_5.InputTensorDescr) and descr.optional)
    }
    if missing:
        raise ValueError(f"Missing non-optional model inputs: {missing}")

    members = {}
    for member_id, descr in model_inputs.items():
        if member_id in inputs:
            members[member_id] = get_tensor(inputs[member_id], descr)

    if stat is None:
        stat = {}

    return Sample(members=members, stat=stat, id=sample_id)

441 

442 

def load_sample_for_model(
    *,
    model: AnyModelDescr,
    paths: PerMember[Path],
    axes: Optional[PerMember[Sequence[AxisLike]]] = None,
    stat: Optional[Stat] = None,
    sample_id: Optional[SampleId] = None,
):
    """load a single sample from `paths` that can be processed by `model`

    Args:
        model: a bioimage.io model description
        paths: path to load each sample member from
        axes: optional axes per member; members without an entry are loaded
            with the axes of the corresponding model input.
        stat: sample statistics to attach to the sample (empty if None)
        sample_id: id of the sample; if None, the sorted tuple of `paths`
            values is used.

    Raises:
        ValueError: if `paths` or `axes` contain members unknown to `model`.
    """

    if axes is None:
        axes = {}

    # make sure members are keyed by MemberId, not string
    paths = {MemberId(k): v for k, v in paths.items()}
    axes = {MemberId(k): v for k, v in axes.items()}

    model_inputs = {get_member_id(d): d for d in model.inputs}

    if unknown := {k for k in paths if k not in model_inputs}:
        raise ValueError(f"Got unexpected paths for {unknown}")

    if unknown := {k for k in axes if k not in model_inputs}:
        raise ValueError(f"Got unexpected axes hints for: {unknown}")

    members: Dict[MemberId, Tensor] = {}
    for m, p in paths.items():
        if m not in axes:
            axes[m] = get_axes_infos(model_inputs[m])
            logger.debug(
                "loading '{}' from {} with default input axes {} ",
                m,
                p,
                axes[m],
            )
        members[m] = load_tensor(p, axes[m])

    # only fall back to the paths if no id was given at all;
    # falsy ids such as `0` or `""` are valid and must be kept
    if sample_id is None:
        sample_id = tuple(sorted(paths.values()))

    return Sample(
        members=members,
        stat={} if stat is None else stat,
        id=sample_id,
    )