Coverage for bioimageio/core/digest_spec.py: 90%

162 statements  

« prev     ^ index     » next       coverage.py v7.6.7, created at 2024-11-19 09:02 +0000

1from __future__ import annotations 

2 

3import importlib.util 

4from itertools import chain 

5from pathlib import Path 

6from typing import ( 

7 Any, 

8 Callable, 

9 Dict, 

10 Iterable, 

11 List, 

12 Mapping, 

13 NamedTuple, 

14 Optional, 

15 Sequence, 

16 Tuple, 

17 Union, 

18) 

19 

20import numpy as np 

21import xarray as xr 

22from loguru import logger 

23from numpy.typing import NDArray 

24from typing_extensions import Unpack, assert_never 

25 

26from bioimageio.spec._internal.io import resolve_and_extract 

27from bioimageio.spec._internal.io_utils import HashKwargs 

28from bioimageio.spec.common import FileSource 

29from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5 

30from bioimageio.spec.model.v0_4 import CallableFromDepencency, CallableFromFile 

31from bioimageio.spec.model.v0_5 import ( 

32 ArchitectureFromFileDescr, 

33 ArchitectureFromLibraryDescr, 

34 ParameterizedSize_N, 

35) 

36from bioimageio.spec.utils import load_array 

37 

38from .axis import AxisId, AxisInfo, AxisLike, PerAxis 

39from .block_meta import split_multiple_shapes_into_blocks 

40from .common import Halo, MemberId, PerMember, SampleId, TotalNumberOfBlocks 

41from .io import load_tensor 

42from .sample import ( 

43 LinearSampleAxisTransform, 

44 Sample, 

45 SampleBlockMeta, 

46 sample_block_meta_generator, 

47) 

48from .stat_measures import Stat 

49from .tensor import Tensor 

50 

51 

def import_callable(
    node: Union[
        ArchitectureFromFileDescr,
        ArchitectureFromLibraryDescr,
        CallableFromDepencency,
        CallableFromFile,
    ],
    /,
    **kwargs: Unpack[HashKwargs],
) -> Callable[..., Any]:
    """import a callable (e.g. a torch.nn.Module) from a spec node describing it

    Args:
        node: description of where to find the callable: either an importable
            module (`CallableFromDepencency`, `ArchitectureFromLibraryDescr`)
            or a source file (`CallableFromFile`, `ArchitectureFromFileDescr`).
        kwargs: hash arguments forwarded when resolving the source file of a
            `CallableFromFile`; `ArchitectureFromFileDescr` nodes use their own
            `sha256` instead.

    Returns:
        the imported callable

    Raises:
        ValueError: if the imported object is not callable.
    """
    if isinstance(node, CallableFromDepencency):
        module = importlib.import_module(node.module_name)
        c = getattr(module, str(node.callable_name))
    elif isinstance(node, ArchitectureFromLibraryDescr):
        module = importlib.import_module(node.import_from)
        c = getattr(module, str(node.callable))
    elif isinstance(node, CallableFromFile):
        c = _import_from_file_impl(node.source_file, str(node.callable_name), **kwargs)
    elif isinstance(node, ArchitectureFromFileDescr):
        # the architecture description carries its own hash; `kwargs` is not used
        c = _import_from_file_impl(node.source, str(node.callable), sha256=node.sha256)
    else:
        assert_never(node)

    if not callable(c):
        raise ValueError(f"{node} (imported: {c}) is not callable")

    return c

76 

77 

def _import_from_file_impl(
    source: FileSource, callable_name: str, **kwargs: Unpack[HashKwargs]
):
    """Import `callable_name` from the Python source file referenced by `source`.

    `source` is resolved to a local file via `resolve_and_extract` (with `kwargs`
    forwarded to it), which is then executed as a module named after the
    file's stem.

    Raises:
        ImportError: if no module spec, or no loader, can be created for the file.
        AttributeError: if the executed module has no attribute `callable_name`.
    """
    local_file = resolve_and_extract(source, **kwargs)
    module_name = local_file.path.stem
    importlib_spec = importlib.util.spec_from_file_location(
        module_name, local_file.path
    )
    # `spec_from_file_location` may return None and `ModuleSpec.loader` is
    # Optional; report both as an ImportError instead of crashing later on
    # `None.exec_module`.
    if importlib_spec is None or importlib_spec.loader is None:
        raise ImportError(f"Failed to import {module_name} from {source}.")

    dep = importlib.util.module_from_spec(importlib_spec)
    importlib_spec.loader.exec_module(dep)
    return getattr(dep, callable_name)

92 

93 

def get_axes_infos(
    io_descr: Union[
        v0_4.InputTensorDescr,
        v0_4.OutputTensorDescr,
        v0_5.InputTensorDescr,
        v0_5.OutputTensorDescr,
    ],
) -> List[AxisInfo]:
    """Build a unified, simplified axis representation from a tensor description's axes.

    Axes given as strings that are not one of the letters
    "b", "i", "t", "c", "z", "y", "x" are treated as index axes ("i");
    every other axis is passed to `AxisInfo.create` unchanged.
    """
    known_letters = ("b", "i", "t", "c", "z", "y", "x")
    infos: List[AxisInfo] = []
    for axis in io_descr.axes:
        is_unknown_letter = isinstance(axis, str) and axis not in known_letters
        infos.append(AxisInfo.create("i" if is_unknown_letter else axis))

    return infos

111 

112 

def get_member_id(
    tensor_description: Union[
        v0_4.InputTensorDescr,
        v0_4.OutputTensorDescr,
        v0_5.InputTensorDescr,
        v0_5.OutputTensorDescr,
    ],
) -> MemberId:
    """Return the normalized ID of a tensor description, usable as a sample member ID.

    v0_4 descriptions are identified by their `name` (wrapped in `MemberId`),
    v0_5 descriptions by their `id`.
    """
    if isinstance(tensor_description, (v0_4.InputTensorDescr, v0_4.OutputTensorDescr)):
        # v0_4 tensors have a `name`, not an `id`
        return MemberId(tensor_description.name)

    if isinstance(tensor_description, (v0_5.InputTensorDescr, v0_5.OutputTensorDescr)):
        return tensor_description.id

    assert_never(tensor_description)

131 

132 

def get_member_ids(
    tensor_descriptions: Sequence[
        Union[
            v0_4.InputTensorDescr,
            v0_4.OutputTensorDescr,
            v0_5.InputTensorDescr,
            v0_5.OutputTensorDescr,
        ]
    ],
) -> List[MemberId]:
    """Return the normalized IDs of `tensor_descriptions`, in the same order,
    for use as sample member IDs (see `get_member_id`)."""
    return list(map(get_member_id, tensor_descriptions))

145 

146 

def get_test_inputs(model: AnyModelDescr) -> Sample:
    """returns a model's test input sample

    For v0_4 models the arrays are loaded from `model.test_inputs`; for newer
    models from each input description's `test_tensor`.

    Raises:
        ValueError: if the number of test arrays does not match the number of
            input tensor descriptions.
    """
    if isinstance(model, v0_4.ModelDescr):
        arrays = [load_array(tt) for tt in model.test_inputs]
    else:
        arrays = [load_array(d.test_tensor) for d in model.inputs]

    return _get_test_sample(model.inputs, arrays, "input")


def get_test_outputs(model: AnyModelDescr) -> Sample:
    """returns a model's test output sample

    For v0_4 models the arrays are loaded from `model.test_outputs`; for newer
    models from each output description's `test_tensor`.

    Raises:
        ValueError: if the number of test arrays does not match the number of
            output tensor descriptions.
    """
    if isinstance(model, v0_4.ModelDescr):
        arrays = [load_array(tt) for tt in model.test_outputs]
    else:
        arrays = [load_array(d.test_tensor) for d in model.outputs]

    return _get_test_sample(model.outputs, arrays, "output")


def _get_test_sample(
    tensor_descrs: Sequence[
        Union[
            v0_4.InputTensorDescr,
            v0_4.OutputTensorDescr,
            v0_5.InputTensorDescr,
            v0_5.OutputTensorDescr,
        ]
    ],
    arrays: Sequence[NDArray[Any]],
    kind: str,
) -> Sample:
    """Pair test `arrays` with their tensor descriptions into a test sample.

    `kind` ("input"/"output") is only used in the error message.
    """
    # `zip` would silently drop unpaired entries; fail loudly instead
    if len(arrays) != len(tensor_descrs):
        raise ValueError(
            f"Got {len(arrays)} test {kind} array(s) for"
            + f" {len(tensor_descrs)} {kind} tensor description(s)."
        )

    return Sample(
        members={
            get_member_id(d): Tensor.from_numpy(arr, dims=get_axes_infos(d))
            for d, arr in zip(tensor_descrs, arrays)
        },
        stat={},
        # input and output test samples share this id
        id="test-sample",
    )

185 

186 

class IO_SampleBlockMeta(NamedTuple):
    """Meta data of an input sample block together with the matching output block."""

    # meta data of the input sample block
    input: SampleBlockMeta
    # meta data of the output sample block corresponding to `input`
    output: SampleBlockMeta

190 

191 

def get_input_halo(
    model: v0_5.ModelDescr, output_halo: PerMember[PerAxis[Halo]]
) -> Dict[MemberId, Dict[AxisId, Halo]]:
    """returns which halo input tensors need to be divided into blocks with, such that
    `output_halo` can be cropped from their outputs without introducing gaps.

    Every output axis with a halo must have its size given by a
    `v0_5.SizeReference`. The halo is mapped onto that referenced axis and
    rescaled by the ratio of the two axes' scales.

    Args:
        model: the model description
        output_halo: halo per output member and output axis id

    Returns:
        halo per referenced member and *referenced* axis id

    Raises:
        ValueError: if an output axis with a halo is not sized by reference, or
            if the rescaled halo is not a whole number of pixels.
    """
    input_halo: Dict[MemberId, Dict[AxisId, Halo]] = {}
    outputs = {t.id: t for t in model.outputs}
    all_tensors = {**{t.id: t for t in model.inputs}, **outputs}

    for t, th in output_halo.items():
        out_axes = {a.id: a for a in outputs[t].axes}

        for a, ah in th.items():
            axis = out_axes[a]
            s = axis.size
            if not isinstance(s, v0_5.SizeReference):
                raise ValueError(
                    f"Unable to map output halo for {t}.{a} to an input axis"
                )

            # NOTE(review): `s.tensor_id` may also name another output tensor
            # (`all_tensors` includes outputs); such chains are not resolved here.
            ref_axis = {ra.id: ra for ra in all_tensors[s.tensor_id].axes}[s.axis_id]

            # output size along `a` scales with ref_axis.scale / axis.scale
            # relative to the referenced axis, so the halo maps back with the
            # inverse ratio
            ratio = axis.scale / ref_axis.scale
            left, right = ah
            input_left = left * ratio
            input_right = right * ratio
            # explicit check instead of `assert`, which is stripped under `-O`
            if input_left != int(input_left) or input_right != int(input_right):
                raise ValueError(
                    f"Output halo {tuple(ah)} of {t}.{a} does not map to a whole"
                    + f" number of pixels on {s.tensor_id}.{s.axis_id}"
                    + f" (got {input_left}, {input_right})."
                )

            # key by the referenced axis: the halo applies to axis `s.axis_id` of
            # `s.tensor_id`, whose id need not equal the output axis id `a`
            input_halo.setdefault(s.tensor_id, {})[s.axis_id] = Halo(
                int(input_left), int(input_right)
            )

    return input_halo

222 

223 

def get_block_transform(
    model: v0_5.ModelDescr,
) -> PerMember[PerAxis[Union[LinearSampleAxisTransform, int]]]:
    """Describe how each output tensor's shape relates to the model's input shapes.

    For every output member and axis, the result holds one of:
    - a `LinearSampleAxisTransform` of a referenced axis (size references, and
      batch axes without an explicit size),
    - a fixed `int` size, or
    - `-1` for a data dependent size.

    Output batch axes without a size map 1:1 (scale 1, offset 0) onto the first
    batch axis found among the model inputs.

    Raises:
        ValueError: if an output has a batch axis but no input does.
    """
    # first batch axis across all inputs, in declaration order (None if absent)
    batch_axis_trf = next(
        (
            LinearSampleAxisTransform(axis=a.id, scale=1, offset=0, member=ipt.id)
            for ipt in model.inputs
            for a in ipt.axes
            if a.type == "batch"
        ),
        None,
    )

    scales_by_member = {
        t.id: {a.id: a.scale for a in t.axes}
        for t in chain(model.inputs, model.outputs)
    }

    transforms: Dict[MemberId, Dict[AxisId, Union[LinearSampleAxisTransform, int]]] = {}
    for out in model.outputs:
        per_axis: Dict[AxisId, Union[LinearSampleAxisTransform, int]] = {}
        for a in out.axes:
            size = a.size
            if size is None:
                assert a.type == "batch"
                if batch_axis_trf is None:
                    raise ValueError(
                        "no batch axis found in any input tensor, but output tensor"
                        + f" '{out.id}' has one."
                    )
                per_axis[a.id] = batch_axis_trf
            elif isinstance(size, int):
                per_axis[a.id] = size
            elif isinstance(size, v0_5.DataDependentSize):
                per_axis[a.id] = -1
            elif isinstance(size, v0_5.SizeReference):
                per_axis[a.id] = LinearSampleAxisTransform(
                    axis=size.axis_id,
                    scale=scales_by_member[size.tensor_id][size.axis_id] / a.scale,
                    offset=size.offset,
                    member=size.tensor_id,
                )
            else:
                assert_never(size)

        transforms[out.id] = per_axis

    return transforms

273 

274 

def get_io_sample_block_metas(
    model: v0_5.ModelDescr,
    input_sample_shape: PerMember[PerAxis[int]],
    ns: Mapping[Tuple[MemberId, AxisId], ParameterizedSize_N],
    batch_size: int = 1,
) -> Tuple[TotalNumberOfBlocks, Iterable[IO_SampleBlockMeta]]:
    """returns an iterable yielding meta data for corresponding input and output samples

    Args:
        model: model description (only v0_5 descriptions are supported)
        input_sample_shape: shape of the whole input sample per member and axis
        ns: parameterized-size `n` per (member, axis), passed to
            `model.get_axis_sizes` to determine the block shape
        batch_size: batch size passed to `model.get_axis_sizes`

    Returns:
        the number of input blocks, and a lazy iterable pairing each input block's
        meta data with the meta data of its corresponding output block

    Raises:
        TypeError: if `model` is not a `v0_5.ModelDescr`.
    """
    if not isinstance(model, v0_5.ModelDescr):
        raise TypeError(
            f"get_io_sample_block_metas() not implemented for {type(model)}"
        )

    block_axis_sizes = model.get_axis_sizes(ns=ns, batch_size=batch_size)
    # regroup the flat {(member, axis): size} mapping into {member: {axis: size}}
    # in a single pass (keeps the order in which members first appear)
    input_block_shape: Dict[MemberId, Dict[AxisId, int]] = {}
    for (member, axis_id), size in block_axis_sizes.inputs.items():
        input_block_shape.setdefault(member, {})[axis_id] = size

    # symmetric halo for every output axis that declares one
    output_halo = {
        t.id: {
            a.id: Halo(a.halo, a.halo) for a in t.axes if isinstance(a, v0_5.WithHalo)
        }
        for t in model.outputs
    }
    input_halo = get_input_halo(model, output_halo)

    n_input_blocks, input_blocks = split_multiple_shapes_into_blocks(
        input_sample_shape, input_block_shape, halo=input_halo
    )
    block_transform = get_block_transform(model)
    # lazily pair each input block with the output block it is transformed into
    return n_input_blocks, (
        IO_SampleBlockMeta(ipt, ipt.get_transformed(block_transform))
        for ipt in sample_block_meta_generator(
            input_blocks, sample_shape=input_sample_shape, sample_id=None
        )
    )

308 

309 

def get_tensor(
    src: Union[Tensor, xr.DataArray, NDArray[Any], Path],
    ipt: Union[v0_4.InputTensorDescr, v0_5.InputTensorDescr],
):
    """Cast or load a tensor source to a `Tensor`.

    - a `Tensor` is returned as is,
    - an `xr.DataArray` is converted with `Tensor.from_xarray`,
    - a numpy array or a `Path` gets its axes from the input description `ipt`.
    """
    if isinstance(src, Tensor):
        tensor = src
    elif isinstance(src, xr.DataArray):
        tensor = Tensor.from_xarray(src)
    elif isinstance(src, np.ndarray):
        tensor = Tensor.from_numpy(src, dims=get_axes_infos(ipt))
    elif isinstance(src, Path):
        tensor = load_tensor(src, axes=get_axes_infos(ipt))
    else:
        assert_never(src)

    return tensor

329 

330 

def create_sample_for_model(
    model: AnyModelDescr,
    *,
    stat: Optional[Stat] = None,
    sample_id: Optional[SampleId] = None,
    inputs: Optional[
        PerMember[Union[Tensor, xr.DataArray, NDArray[Any], Path]]
    ] = None,  # TODO: make non-optional
    **kwargs: NDArray[Any],  # TODO: deprecate in favor of `inputs`
) -> Sample:
    """Create a sample from a single set of input(s) for a specific bioimage.io model

    Args:
        model: a bioimage.io model description
        stat: dictionary with sample and dataset statistics (may be updated in-place!)
        sample_id: ID of the created sample
        inputs: the input(s) constituting a single sample.
        kwargs: (to be deprecated) inputs given as keyword arguments; on key
            collisions the entry in `inputs` takes precedence.

    Returns:
        a `Sample` with one member per provided model input

    Raises:
        ValueError: if an input is given that the model has no input for, or if
            a non-optional model input is missing.
    """
    # later entries win: `inputs` overrides `kwargs`
    inputs = {MemberId(k): v for k, v in {**kwargs, **(inputs or {})}.items()}

    model_inputs = {get_member_id(d): d for d in model.inputs}
    if unknown := {k for k in inputs if k not in model_inputs}:
        raise ValueError(f"Got unexpected inputs: {unknown}")

    # only v0_5 input descriptions can be marked optional
    if missing := {
        k
        for k, v in model_inputs.items()
        if k not in inputs and not (isinstance(v, v0_5.InputTensorDescr) and v.optional)
    }:
        raise ValueError(f"Missing non-optional model inputs: {missing}")

    return Sample(
        members={
            m: get_tensor(inputs[m], ipt)
            for m, ipt in model_inputs.items()
            if m in inputs
        },
        stat={} if stat is None else stat,
        id=sample_id,
    )

370 

371 

def load_sample_for_model(
    *,
    model: AnyModelDescr,
    paths: PerMember[Path],
    axes: Optional[PerMember[Sequence[AxisLike]]] = None,
    stat: Optional[Stat] = None,
    sample_id: Optional[SampleId] = None,
) -> Sample:
    """load a single sample from `paths` that can be processed by `model`

    Args:
        model: the model description the sample is loaded for
        paths: file to load per input member
        axes: optional axes hint per input member; members without a hint use
            the axes derived from the model's input description
        stat: statistics to attach to the sample (a new empty dict if None)
        sample_id: ID of the sample; if None, the sorted tuple of `paths` is used

    Returns:
        a `Sample` with one member per entry in `paths`

    Raises:
        ValueError: if `paths` or `axes` contain members the model has no input for.
    """

    if axes is None:
        axes = {}

    # make sure members are keyed by MemberId, not string
    # (these are new dicts, so the caller's mappings are not modified below)
    paths = {MemberId(k): v for k, v in paths.items()}
    axes = {MemberId(k): v for k, v in axes.items()}

    model_inputs = {get_member_id(d): d for d in model.inputs}

    if unknown := {k for k in paths if k not in model_inputs}:
        raise ValueError(f"Got unexpected paths for {unknown}")

    if unknown := {k for k in axes if k not in model_inputs}:
        raise ValueError(f"Got unexpected axes hints for: {unknown}")

    members: Dict[MemberId, Tensor] = {}
    for m, p in paths.items():
        if m not in axes:
            axes[m] = get_axes_infos(model_inputs[m])
            logger.debug(
                "loading '{}' from {} with default input axes {} ",
                m,
                p,
                axes[m],
            )
        members[m] = load_tensor(p, axes[m])

    return Sample(
        members=members,
        stat={} if stat is None else stat,
        # test for None explicitly: a falsy but valid ID (e.g. 0 or "") must
        # not be replaced by the paths fallback
        id=tuple(sorted(paths.values())) if sample_id is None else sample_id,
    )