Coverage for bioimageio/core/_resource_tests.py: 75%

169 statements  

« prev     ^ index     » next       coverage.py v7.6.7, created at 2024-11-19 09:02 +0000

1import traceback 

2import warnings 

3from itertools import product 

4from typing import Dict, Hashable, List, Literal, Optional, Sequence, Set, Tuple, Union 

5 

6import numpy as np 

7from loguru import logger 

8 

9from bioimageio.spec import ( 

10 InvalidDescr, 

11 ResourceDescr, 

12 build_description, 

13 dump_description, 

14 load_description, 

15) 

16from bioimageio.spec._internal.common_nodes import ResourceDescrBase 

17from bioimageio.spec.common import BioimageioYamlContent, PermissiveFileSource 

18from bioimageio.spec.get_conda_env import get_conda_env 

19from bioimageio.spec.model import v0_4, v0_5 

20from bioimageio.spec.model.v0_5 import WeightsFormat 

21from bioimageio.spec.summary import ( 

22 ErrorEntry, 

23 InstalledPackage, 

24 ValidationDetail, 

25 ValidationSummary, 

26) 

27 

28from ._prediction_pipeline import create_prediction_pipeline 

29from .axis import AxisId, BatchSize 

30from .digest_spec import get_test_inputs, get_test_outputs 

31from .sample import Sample 

32from .utils import VERSION 

33 

34 

def enable_determinism(mode: Literal["seed_only", "full"]):
    """Seed and configure ML frameworks for maximum reproducibility.
    May degrade performance. Only recommended for testing reproducibility!

    Seed any random generators and (if **mode**=="full") request ML frameworks to use
    deterministic algorithms.
    Notes:
        - **mode** == "full" might degrade performance and throw exceptions.
        - Subsequent inference calls might still differ. Call before each function
          (sequence) that is expected to be reproducible.
        - Degraded performance: Use for testing reproducibility only!
        - Recipes:
            - [PyTorch](https://pytorch.org/docs/stable/notes/randomness.html)
            - [Keras](https://keras.io/examples/keras_recipes/reproducibility_recipes/)
            - [NumPy](https://numpy.org/doc/2.0/reference/random/generated/numpy.random.seed.html)

    Every framework is handled independently and best-effort: frameworks that are
    not installed are skipped, and any error while configuring one is logged at
    debug level instead of being raised.
    """
    try:
        try:
            import numpy.random
        except ImportError:
            pass
        else:
            numpy.random.seed(0)
    except Exception as e:
        logger.debug(str(e))

    try:
        try:
            import torch
        except ImportError:
            pass
        else:
            _ = torch.manual_seed(0)
            # explicitly (re)set, so "seed_only" also turns off a previous "full" request
            torch.use_deterministic_algorithms(mode == "full")
    except Exception as e:
        logger.debug(str(e))

    try:
        try:
            import keras
        except ImportError:
            pass
        else:
            keras.utils.set_random_seed(0)
    except Exception as e:
        logger.debug(str(e))

    try:
        try:
            import tensorflow as tf  # pyright: ignore[reportMissingImports]
        except ImportError:
            pass
        else:
            # The global TensorFlow seed is set via `tf.random.set_seed`;
            # the previously used `tf.random.seed` is not a TensorFlow API and raised
            # an AttributeError, which skipped seeding and op determinism entirely.
            tf.random.set_seed(0)
            if mode == "full":
                tf.config.experimental.enable_op_determinism()
            # TODO: find possibility to switch it off again??
    except Exception as e:
        logger.debug(str(e))

94 

95 

def test_model(
    source: Union[v0_4.ModelDescr, v0_5.ModelDescr, PermissiveFileSource],
    weight_format: Optional[WeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    absolute_tolerance: float = 1.5e-4,
    relative_tolerance: float = 1e-4,
    decimal: Optional[int] = None,
    *,
    determinism: Literal["seed_only", "full"] = "seed_only",
) -> ValidationSummary:
    """Test model inference.

    Convenience wrapper around `test_description` that additionally requires the
    resource to be of type "model".

    Args:
        source: A model description or a source to load one from.
        weight_format: Weights format to test; if None, every weights format
            present in the model is tested.
        devices: Devices passed on to the prediction pipeline.
        absolute_tolerance: `atol` used when comparing outputs to test outputs.
        relative_tolerance: `rtol` used when comparing outputs to test outputs.
        decimal: Deprecated; if given, overrides the tolerances (see
            `load_description_and_test`).
        determinism: Mode passed to `enable_determinism` before testing.

    Returns:
        The validation summary including the dynamic test results.
    """
    return test_description(
        source,
        weight_format=weight_format,
        devices=devices,
        absolute_tolerance=absolute_tolerance,
        relative_tolerance=relative_tolerance,
        decimal=decimal,
        determinism=determinism,
        expected_type="model",
    )

117 

118 

def test_description(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Union[Literal["discover", "latest"], str] = "discover",
    weight_format: Optional[WeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    absolute_tolerance: float = 1.5e-4,
    relative_tolerance: float = 1e-4,
    decimal: Optional[int] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Optional[str] = None,
) -> ValidationSummary:
    """Test a bioimage.io resource dynamically, e.g. prediction of test tensors for models.

    Thin wrapper around `load_description_and_test` that discards the loaded
    description and hands back only its validation summary. All keyword arguments
    are forwarded unchanged.
    """
    tested = load_description_and_test(
        source,
        expected_type=expected_type,
        format_version=format_version,
        determinism=determinism,
        weight_format=weight_format,
        devices=devices,
        decimal=decimal,
        absolute_tolerance=absolute_tolerance,
        relative_tolerance=relative_tolerance,
    )
    summary = tested.validation_summary
    return summary

144 

145 

def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Union[Literal["discover", "latest"], str] = "discover",
    weight_format: Optional[WeightsFormat] = None,
    devices: Optional[Sequence[str]] = None,
    absolute_tolerance: float = 1.5e-4,
    relative_tolerance: float = 1e-4,
    decimal: Optional[int] = None,
    determinism: Literal["seed_only", "full"] = "seed_only",
    expected_type: Optional[str] = None,
) -> Union[ResourceDescr, InvalidDescr]:
    """Test RDF dynamically, e.g. model inference of test inputs.

    Args:
        source: A resource description, a source to load one from, or raw
            bioimage.io YAML content.
        format_version: Format version to validate with. If `source` is already
            a description of a different format version, it is dumped and
            re-built with this format version (a warning is issued).
        weight_format: (models only) Weights format to test; if None, every
            weights format present in the model is tested.
        devices: (models only) Devices passed on to the prediction pipeline.
        absolute_tolerance: (models only) `atol` for comparing outputs.
        relative_tolerance: (models only) `rtol` for comparing outputs.
        decimal: Deprecated. If given, replaces the tolerances by
            `atol = 1.5 * 10 ** -decimal` and `rtol = 0`.
        determinism: (models only) Mode passed to `enable_determinism` before
            running inference.
        expected_type: If given, a check that the resource type equals this
            value is added to the validation summary.

    Returns:
        The loaded description (possibly an `InvalidDescr`) whose
        `validation_summary` includes the results of the dynamic tests.
    """
    if (
        isinstance(source, ResourceDescrBase)
        and format_version != "discover"
        and source.format_version != format_version
    ):
        warnings.warn(
            f"deserializing source to ensure we validate and test using format {format_version}"
        )
        source = dump_description(source)

    if isinstance(source, ResourceDescrBase):
        rd = source
    elif isinstance(source, dict):
        rd = build_description(source, format_version=format_version)
    else:
        rd = load_description(source, format_version=format_version)

    rd.validation_summary.env.add(
        InstalledPackage(name="bioimageio.core", version=VERSION)
    )

    if expected_type is not None:
        _test_expected_resource_type(rd, expected_type)

    if isinstance(rd, (v0_4.ModelDescr, v0_5.ModelDescr)):
        if weight_format is None:
            # iterating the weights description yields (format name, entry or None)
            weight_formats: List[WeightsFormat] = [
                w for w, we in rd.weights if we is not None
            ]  # pyright: ignore[reportAssignmentType]
        else:
            weight_formats = [weight_format]

        if decimal is None:
            atol = absolute_tolerance
            rtol = relative_tolerance
        else:
            warnings.warn(
                "The argument `decimal` has been deprecated in favour of"
                + " `relative_tolerance` and `absolute_tolerance`, with different"
                + " validation logic, using `numpy.testing.assert_allclose`, see"
                + " 'https://numpy.org/doc/stable/reference/generated/"
                + "numpy.testing.assert_allclose.html'. Passing a value for `decimal`"
                + " will cause validation to revert to the old behaviour."
            )
            # mimics the old absolute-only comparison
            atol = 1.5 * 10 ** (-decimal)
            rtol = 0

        enable_determinism(determinism)
        for w in weight_formats:
            _test_model_inference(rd, w, devices, atol, rtol)
            # parametrized input sizes only exist in format 0.5 models
            if not isinstance(rd, v0_4.ModelDescr):
                _test_model_inference_parametrized(rd, w, devices)

    # TODO: add execution of jupyter notebooks
    # TODO: add more tests

    return rd

216 

217 

def _test_model_inference(
    model: Union[v0_4.ModelDescr, v0_5.ModelDescr],
    weight_format: WeightsFormat,
    devices: Optional[Sequence[str]],
    atol: float,
    rtol: float,
) -> None:
    """Check that `model` reproduces its test outputs from its test inputs.

    Runs the model with `weight_format` on its test inputs and compares every
    output member to the corresponding expected test output with
    `np.testing.assert_allclose(actual, expected, rtol=rtol, atol=atol)`.
    The outcome is added as one `ValidationDetail` to `model.validation_summary`,
    together with the recommended conda environment for the tested weights.
    Exceptions raised while loading test data, predicting or comparing are
    recorded as a failure including their traceback.
    """
    test_name = f"Reproduce test outputs from test inputs ({weight_format})"
    logger.info("starting '{}'", test_name)

    def first_mismatch(actual: Sample, reference: Sample) -> Optional[str]:
        # Describe the first disagreement between `actual` and `reference`, if any.
        n_expected = len(reference.members)
        n_actual = len(actual.members)
        if n_actual != n_expected:
            return f"Expected {n_expected} outputs, but got {n_actual}"

        for member_id, reference_tensor in reference.members.items():
            actual_tensor = actual.members.get(member_id)
            if actual_tensor is None:
                return "Output tensors for test case may not be None"
            try:
                np.testing.assert_allclose(
                    actual_tensor.data,
                    reference_tensor.data,
                    rtol=rtol,
                    atol=atol,
                )
            except AssertionError as mismatch:
                return f"Output and expected output disagree:\n {mismatch}"

        return None

    error: Optional[str] = None
    tb: List[str] = []
    try:
        test_inputs = get_test_inputs(model)
        reference_outputs = get_test_outputs(model)

        with create_prediction_pipeline(
            bioimageio_model=model, devices=devices, weight_format=weight_format
        ) as pipeline:
            actual_outputs = pipeline.predict_sample_without_blocking(test_inputs)

        error = first_mismatch(actual_outputs, reference_outputs)
    except Exception as exc:
        error = str(exc)
        tb = traceback.format_tb(exc.__traceback__)

    location = ("weights", weight_format)
    model.validation_summary.add_detail(
        ValidationDetail(
            name=test_name,
            loc=location,
            status="failed" if error is not None else "passed",
            recommended_env=get_conda_env(entry=dict(model.weights)[weight_format]),
            errors=(
                [
                    ErrorEntry(
                        loc=location,
                        msg=error,
                        type="bioimageio.core",
                        traceback=tb,
                    )
                ]
                if error is not None
                else []
            ),
        )
    )

282 

283 

def _test_model_inference_parametrized(
    model: v0_5.ModelDescr,
    weight_format: WeightsFormat,
    devices: Optional[Sequence[str]],
) -> None:
    """Run inference on resized test inputs and check the resulting output shapes.

    Test cases combine a size parameter `n` with a batch size:
    - `n` is in {0, 1, 2} if any input axis has a parameterized size, else just 0.
    - The batch size is the fixed batch size(s) given by the input batch axes,
      {1, 2} if all batch axes allow arbitrary batch sizes, or {1} without any
      batch axis.
    Cases that lead to identical input sizes are run only once. For each case the
    test inputs are resized, the model is run, and the number of outputs as well
    as every output axis size is checked against `model.get_axis_sizes`. Output
    axes that are missing from the result, or present but not expected, count as
    shape mismatches. One `ValidationDetail` per case is added to
    `model.validation_summary`. If setup or inference raises, a single failed
    detail with traceback is added instead and remaining cases are skipped.
    """
    if not any(
        isinstance(a.size, v0_5.ParameterizedSize)
        for ipt in model.inputs
        for a in ipt.axes
    ):
        # no parameterized sizes => set n=0
        ns: Set[v0_5.ParameterizedSize_N] = {0}
    else:
        ns = {0, 1, 2}

    given_batch_sizes = {
        a.size
        for ipt in model.inputs
        for a in ipt.axes
        if isinstance(a, v0_5.BatchAxis)
    }
    if given_batch_sizes:
        batch_sizes = {gbs for gbs in given_batch_sizes if gbs is not None}
        if not batch_sizes:
            # only arbitrary batch sizes
            batch_sizes = {1, 2}
    else:
        # no batch axis
        batch_sizes = {1}

    test_cases: Set[Tuple[v0_5.ParameterizedSize_N, BatchSize]] = {
        (n, b) for n, b in product(sorted(ns), sorted(batch_sizes))
    }
    logger.info(
        "Testing inference with {} different input tensor sizes", len(test_cases)
    )

    def generate_test_cases():
        # yields (n, batch_size, resized inputs, expected output sizes per tensor),
        # skipping cases whose input sizes were already covered
        tested: Set[Hashable] = set()

        def get_ns(n: int):
            return {
                (t.id, a.id): n
                for t in model.inputs
                for a in t.axes
                if isinstance(a.size, v0_5.ParameterizedSize)
            }

        for n, batch_size in sorted(test_cases):
            input_target_sizes, expected_output_sizes = model.get_axis_sizes(
                get_ns(n), batch_size=batch_size
            )
            hashable_target_size = tuple(
                (k, input_target_sizes[k]) for k in sorted(input_target_sizes)
            )
            if hashable_target_size in tested:
                continue
            else:
                tested.add(hashable_target_size)

            # `test_inputs` is bound below, before this generator is first advanced
            resized_test_inputs = Sample(
                members={
                    t.id: test_inputs.members[t.id].resize_to(
                        {
                            aid: s
                            for (tid, aid), s in input_target_sizes.items()
                            if tid == t.id
                        },
                    )
                    for t in model.inputs
                },
                stat=test_inputs.stat,
                id=test_inputs.id,
            )
            expected_output_shapes = {
                t.id: {
                    aid: s
                    for (tid, aid), s in expected_output_sizes.items()
                    if tid == t.id
                }
                for t in model.outputs
            }
            yield n, batch_size, resized_test_inputs, expected_output_shapes

    try:
        test_inputs = get_test_inputs(model)

        with create_prediction_pipeline(
            bioimageio_model=model, devices=devices, weight_format=weight_format
        ) as prediction_pipeline:
            for n, batch_size, inputs, expected_output_shapes in generate_test_cases():
                error: Optional[str] = None
                result = prediction_pipeline.predict_sample_without_blocking(inputs)
                if len(result.members) != len(expected_output_shapes):
                    error = (
                        f"Expected {len(expected_output_shapes)} outputs,"
                        + f" but got {len(result.members)}"
                    )

                else:
                    for m, exp in expected_output_shapes.items():
                        res = result.members.get(m)
                        if res is None:
                            error = "Output tensors may not be None for test case"
                            break

                        actual_sizes = {AxisId(a): s for a, s in res.sizes.items()}
                        diff: Dict[AxisId, int] = {}
                        for aid, s in actual_sizes.items():
                            expected_size = exp.get(aid)
                            if expected_size is None:
                                # axis is not expected at all (previously a KeyError
                                # that aborted all remaining test cases)
                                diff[aid] = s
                            elif isinstance(expected_size, int):
                                if s != expected_size:
                                    diff[aid] = s
                            elif s < expected_size.min or (
                                expected_size.max is not None
                                and s > expected_size.max
                            ):
                                # data dependent size outside of its [min, max] range
                                diff[aid] = s

                        # expected axes absent from the result were not detected before
                        missing = [aid for aid in exp if aid not in actual_sizes]
                        if diff or missing:
                            diff_info = f"diff: {diff}" + (
                                f", missing axes: {missing}" if missing else ""
                            )
                            error = (
                                f"(n={n}) Expected output shape {exp},"
                                + f" but got {res.sizes} ({diff_info})"
                            )
                            break

                model.validation_summary.add_detail(
                    ValidationDetail(
                        name=f"Run {weight_format} inference for inputs with"
                        + f" batch_size: {batch_size} and size parameter n: {n}",
                        loc=("weights", weight_format),
                        status="passed" if error is None else "failed",
                        errors=(
                            []
                            if error is None
                            else [
                                ErrorEntry(
                                    loc=("weights", weight_format),
                                    msg=error,
                                    type="bioimageio.core",
                                )
                            ]
                        ),
                    )
                )
    except Exception as e:
        error = str(e)
        tb = traceback.format_tb(e.__traceback__)
        model.validation_summary.add_detail(
            ValidationDetail(
                name=f"Run {weight_format} inference for parametrized inputs",
                status="failed",
                loc=("weights", weight_format),
                errors=[
                    ErrorEntry(
                        loc=("weights", weight_format),
                        msg=error,
                        type="bioimageio.core",
                        traceback=tb,
                    )
                ],
            )
        )

443 

444 

def _test_expected_resource_type(
    rd: Union[InvalidDescr, ResourceDescr], expected_type: str
):
    """Record whether `rd` has the resource type `expected_type`.

    Adds a "Has expected resource type" `ValidationDetail` to
    `rd.validation_summary`. Like the other tests in this module it goes through
    `add_detail`, rather than appending to `details` directly. This way a failed
    check is also reflected in the summary's overall status.
    """
    has_expected_type = rd.type == expected_type
    rd.validation_summary.add_detail(
        ValidationDetail(
            name="Has expected resource type",
            status="passed" if has_expected_type else "failed",
            loc=("type",),
            errors=(
                []
                if has_expected_type
                else [
                    ErrorEntry(
                        loc=("type",),
                        type="type",
                        msg=f"expected type {expected_type}, found {rd.type}",
                    )
                ]
            ),
        )
    )

467 

468 

469# TODO: Implement `debug_model()` 

470# def debug_model( 

471# model_rdf: Union[RawResourceDescr, ResourceDescr, URI, Path, str], 

472# *, 

473# weight_format: Optional[WeightsFormat] = None, 

474# devices: Optional[List[str]] = None, 

475# ): 

476# """Run the model test and return dict with inputs, results, expected results and intermediates. 

477 

478# Returns dict with tensors "inputs", "inputs_processed", "outputs_raw", "outputs", "expected" and "diff". 

479# """ 

480# inputs_raw: Optional = None 

481# inputs_processed: Optional = None 

482# outputs_raw: Optional = None 

483# outputs: Optional = None 

484# expected: Optional = None 

485# diff: Optional = None 

486 

487# model = load_description( 

488# model_rdf, weights_priority_order=None if weight_format is None else [weight_format] 

489# ) 

490# if not isinstance(model, Model): 

491# raise ValueError(f"Not a bioimageio.model: {model_rdf}") 

492 

493# prediction_pipeline = create_prediction_pipeline( 

494# bioimageio_model=model, devices=devices, weight_format=weight_format 

495# ) 

496# inputs = [ 

497# xr.DataArray(load_array(str(in_path)), dims=input_spec.axes) 

498# for in_path, input_spec in zip(model.test_inputs, model.inputs) 

499# ] 

500# input_dict = {input_spec.name: input for input_spec, input in zip(model.inputs, inputs)} 

501 

502# # keep track of the non-processed inputs 

503# inputs_raw = [deepcopy(input) for input in inputs] 

504 

505# computed_measures = {} 

506 

507# prediction_pipeline.apply_preprocessing(input_dict, computed_measures) 

508# inputs_processed = list(input_dict.values()) 

509# outputs_raw = prediction_pipeline.predict(*inputs_processed) 

510# output_dict = {output_spec.name: deepcopy(output) for output_spec, output in zip(model.outputs, outputs_raw)} 

511# prediction_pipeline.apply_postprocessing(output_dict, computed_measures) 

512# outputs = list(output_dict.values()) 

513 

514# if isinstance(outputs, (np.ndarray, xr.DataArray)): 

515# outputs = [outputs] 

516 

517# expected = [ 

518# xr.DataArray(load_array(str(out_path)), dims=output_spec.axes) 

519# for out_path, output_spec in zip(model.test_outputs, model.outputs) 

520# ] 

521# if len(outputs) != len(expected): 

522# error = f"Number of outputs and number of expected outputs disagree: {len(outputs)} != {len(expected)}" 

523# print(error) 

524# else: 

525# diff = [] 

526# for res, exp in zip(outputs, expected): 

527# diff.append(res - exp) 

528 

529# return { 

530# "inputs": inputs_raw, 

531# "inputs_processed": inputs_processed, 

532# "outputs_raw": outputs_raw, 

533# "outputs": outputs, 

534# "expected": expected, 

535# "diff": diff, 

536# }