Coverage for src / bioimageio / spec / _internal / common_nodes.py: 87%

205 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-15 08:15 +0000

1from __future__ import annotations 

2 

3from abc import ABC 

4from inspect import signature 

5from io import BytesIO 

6from pathlib import Path 

7from types import MappingProxyType 

8from typing import ( 

9 IO, 

10 TYPE_CHECKING, 

11 Any, 

12 ClassVar, 

13 Dict, 

14 Iterable, 

15 List, 

16 Literal, 

17 Mapping, 

18 Optional, 

19 Protocol, 

20 Tuple, 

21 TypeVar, 

22 Union, 

23) 

24from zipfile import ZipFile 

25 

26import pydantic 

27from pydantic import DirectoryPath, PrivateAttr, model_validator 

28from pydantic_core import PydanticUndefined 

29from typing_extensions import Callable, ParamSpec, Self 

30 

31from ..summary import ( 

32 WARNING_LEVEL_TO_NAME, 

33 ErrorEntry, 

34 ValidationDetail, 

35 ValidationSummary, 

36 WarningEntry, 

37) 

38from .field_warning import issue_warning 

39from .io import ( 

40 BioimageioYamlContent, 

41 FileDescr, 

42 IncompleteDescr, 

43 IncompleteDescrView, 

44 deepcopy_incomplete_descr, 

45 extract_file_descrs, 

46 populate_cache, 

47) 

48from .io_basics import BIOIMAGEIO_YAML, FileName 

49from .io_utils import write_content_to_zip 

50from .node import Node 

51from .packaging_context import PackagingContext 

52from .root_url import RootHttpUrl 

53from .type_guards import is_dict 

54from .utils import get_format_version_tuple 

55from .validation_context import ValidationContext, get_validation_context 

56from .warning_levels import ALERT, ERROR, INFO 

57 

58 

# Node base class whose `implemented_<field>` class attributes provide the value
# for the like-named model field whenever the raw input dict omits that field.
# (Documented with comments rather than a class docstring so that the model's
# `__doc__` stays unset.)
class NodeWithExplicitlySetFields(Node):
    # read-only mapping: field name -> value to insert when missing from the input
    _fields_to_set_explicitly: ClassVar[Mapping[str, Any]]

    @classmethod
    def __pydantic_init_subclass__(cls, **kwargs: Any) -> None:
        """Collect the `implemented_<field>` class attributes of `cls`.

        Every class attribute named `implemented_<field>` whose `<field>` is a
        model field is recorded in `cls._fields_to_set_explicitly`.
        Such a field must not declare a default of its own (asserted).
        """
        prefix = "implemented_"
        explicit_fields: Dict[str, Any] = {}
        for attr in dir(cls):
            if not attr.startswith(prefix):
                continue

            # Strip only the leading prefix; `str.replace` would also remove any
            # later occurrence of "implemented_" inside the attribute name.
            field_name = attr[len(prefix) :]
            if field_name not in cls.model_fields:
                continue

            assert (
                cls.model_fields[field_name].get_default() is PydanticUndefined
            ), field_name
            explicit_fields[field_name] = getattr(cls, attr)

        cls._fields_to_set_explicitly = MappingProxyType(explicit_fields)
        return super().__pydantic_init_subclass__(**kwargs)

    @model_validator(mode="before")
    @classmethod
    def _set_fields_explicitly(
        cls, data: Union[Any, Dict[str, Any]]
    ) -> Union[Any, Dict[str, Any]]:
        """Insert the collected `implemented_*` values for fields missing in `data`.

        Only `dict` input is touched (it is updated in place); any other input
        is returned unchanged.
        """
        if isinstance(data, dict):
            for name, default in cls._fields_to_set_explicitly.items():
                if name not in data:
                    data[name] = default

        return data  # pyright: ignore[reportUnknownVariableType]

91 

92 

# Static type checkers see the Protocol variant, which declares the fields and
# class variables every concrete resource description is expected to provide.
# At runtime an empty placeholder class with the same name is defined instead.
if TYPE_CHECKING:

    class _ResourceDescrBaseAbstractFieldsProtocol(Protocol):
        """workaround to add "abstract" fields to ResourceDescrBase"""

        # TODO: implement as proper abstract fields of ResourceDescrBase

        type: Any  # should be LiteralString
        format_version: Any  # should be LiteralString
        implemented_type: ClassVar[Any]
        implemented_format_version: ClassVar[Any]

else:

    # runtime placeholder: contributes nothing to the class hierarchy
    class _ResourceDescrBaseAbstractFieldsProtocol:
        pass

109 

110 

# Parameter specification and result type variable used to annotate
# `ResourceDescrBase.load_from_kwargs` (its `cls` is typed `Callable[P, T]`).
P = ParamSpec("P")
T = TypeVar("T")

113 

114 

class ResourceDescrBase(
    NodeWithExplicitlySetFields, ABC, _ResourceDescrBaseAbstractFieldsProtocol
):
    """base class for all resource descriptions"""

    # set by `_set_init_validation_summary` once the model has been validated
    _validation_summary: Optional[ValidationSummary] = None

    # (major, minor, patch) of `implemented_format_version`;
    # assigned in `__pydantic_init_subclass__`
    implemented_format_version_tuple: ClassVar[Tuple[int, int, int]]

    # @field_validator("format_version", mode="before", check_fields=False)
    # field_validator on "format_version" is not possible, because we want to use
    # "format_version" in a discriminated Union higher up
    # (PydanticUserError: Cannot use a mode='before' validator in the discriminator
    # field 'format_version' of Model 'CollectionDescr')
    @model_validator(mode="before")
    @classmethod
    def _ignore_future_patch(cls, data: Any, /) -> Any:
        """Treat a newer minor/patch `format_version` as the implemented one.

        If `data` is a dict whose "format_version" shares the implemented major
        version but has a greater (minor, patch), a warning of severity `ALERT`
        is issued and `data["format_version"]` is overwritten (in place) with
        `cls.implemented_format_version`. All other input is returned unchanged.
        """
        if (
            cls.implemented_format_version == "unknown"
            or not is_dict(data)
            or "format_version" not in data
        ):
            return data

        value = data["format_version"]
        fv = get_format_version_tuple(value)
        if fv is None:
            return data
        if (
            fv[0] == cls.implemented_format_version_tuple[0]
            and fv[1:] > cls.implemented_format_version_tuple[1:]
        ):
            issue_warning(
                "future format_version '{value}' treated as '{implemented}'",
                value=value,
                msg_context=dict(implemented=cls.implemented_format_version),
                severity=ALERT,
            )
            data["format_version"] = cls.implemented_format_version

        return data

    @model_validator(mode="after")
    def _set_init_validation_summary(self) -> Self:
        """Attach the initial `ValidationSummary` to the validated instance.

        The summary is "failed" without details for an `InvalidDescr`;
        otherwise it is "valid-format" with one "passed" detail.
        """
        context = get_validation_context()

        self._validation_summary = ValidationSummary(
            name="bioimageio format validation",
            source_name=context.source_name,
            id=getattr(self, "id", None),
            version=getattr(self, "version", None),
            type=self.type,
            format_version=self.format_version,
            status="failed" if isinstance(self, InvalidDescr) else "valid-format",
            metadata_completeness=self._get_metadata_completeness(),
            details=(
                []
                if isinstance(self, InvalidDescr)
                else [
                    ValidationDetail(
                        name=f"Successfully created `{self.__class__.__name__}` instance.",
                        status="passed",
                        context=context.summary,
                    )
                ]
            ),
        )
        return self

    @property
    def validation_summary(self) -> ValidationSummary:
        """The validation summary; only available after initialization."""
        assert self._validation_summary is not None, "access only after initialization"
        return self._validation_summary

    # both private attributes default to the values of the validation context
    # that is active when the instance is created
    _root: Union[RootHttpUrl, DirectoryPath, ZipFile] = PrivateAttr(
        default_factory=lambda: get_validation_context().root
    )

    _file_name: Optional[FileName] = PrivateAttr(
        default_factory=lambda: get_validation_context().file_name
    )

    @property
    def root(self) -> Union[RootHttpUrl, DirectoryPath, ZipFile]:
        """The URL/Path prefix to resolve any relative paths with."""
        return self._root

    @property
    def file_name(self) -> Optional[FileName]:
        """File name of the bioimageio.yaml file the description was loaded from."""
        return self._file_name

    @classmethod
    def __pydantic_init_subclass__(cls, **kwargs: Any):
        """Set `implemented_format_version_tuple` for subclasses with a `format_version` field."""
        super().__pydantic_init_subclass__(**kwargs)
        # set classvar implemented_format_version_tuple
        if "format_version" in cls.model_fields:
            if "." not in cls.implemented_format_version:
                # not a dotted version string (e.g. "unknown")
                cls.implemented_format_version_tuple = (0, 0, 0)
            else:
                fv_tuple = get_format_version_tuple(cls.implemented_format_version)
                assert fv_tuple is not None, (
                    f"failed to cast '{cls.implemented_format_version}' to tuple"
                )
                cls.implemented_format_version_tuple = fv_tuple

    @classmethod
    def load_from_kwargs(
        cls: Callable[P, T],
        context: Optional[ValidationContext] = None,
        *args: P.args,
        **kwargs: P.kwargs,
    ) -> Union[T, InvalidDescr]:
        """Like `load`, but takes the description content as call arguments.

        The arguments are bound (partially) against the signature of `cls` and
        the resulting name -> value mapping is passed on to `load`.
        """
        sig = signature(cls)
        bound = sig.bind_partial(*args, **kwargs)
        return cls.load(dict(bound.arguments), context=context)  # pyright: ignore[reportFunctionMemberAccess]

    @classmethod
    def load(
        cls,
        data: IncompleteDescrView,
        context: Optional[ValidationContext] = None,
    ) -> Union[Self, InvalidDescr]:
        """factory method to create a resource description object

        Args:
            data: content to validate; it is deep-copied before validation.
            context: validation context; defaults to the currently active one.

        Returns:
            The validated description, or an `InvalidDescr` if validation
            failed. A format validation detail listing the errors and warnings
            is added to the returned object's validation summary.
        """

        context = context or get_validation_context()
        if context.perform_io_checks:
            file_descrs = extract_file_descrs(data)
            populate_cache(file_descrs)  # TODO: add progress bar

        with context.replace(log_warnings=context.warning_level <= INFO):
            rd, errors, val_warnings = cls._load_impl(deepcopy_incomplete_descr(data))

        if context.warning_level > INFO:
            all_warnings_context = context.replace(
                warning_level=INFO, log_warnings=False, raise_errors=False
            )
            # raise all validation warnings by reloading
            with all_warnings_context:
                _, _, val_warnings = cls._load_impl(deepcopy_incomplete_descr(data))

        format_status = "failed" if errors else "passed"
        rd.validation_summary.add_detail(
            ValidationDetail(
                errors=errors,
                name=(
                    "bioimageio.spec format validation"
                    f" {rd.type} {cls.implemented_format_version}"
                ),
                status=format_status,
                warnings=val_warnings,
            ),
            update_status=False,  # this special validation detail needs manual format updating below
        )
        assert format_status != "failed" or isinstance(rd, InvalidDescr)

        return rd

    def _get_metadata_completeness(self) -> float:
        """Return the fraction of (flattened) field keys that were explicitly set.

        Compares the keys of a JSON-mode dump restricted to explicitly set
        fields with the keys of a complete dump. Nested dict keys are flattened
        to dotted paths. Always `0.0` for an `InvalidDescr`.
        """
        if isinstance(self, InvalidDescr):
            return 0.0

        given = self.model_dump(mode="json", exclude_unset=True, exclude_defaults=False)
        full = self.model_dump(mode="json", exclude_unset=False, exclude_defaults=False)

        def extract_flat_keys(d: Dict[Any, Any], key: str = "") -> Iterable[str]:
            # yields the dotted path of every key, descending into nested dicts
            for k, v in d.items():
                if is_dict(v):
                    yield from extract_flat_keys(v, key=f"{key}.{k}" if key else k)

                yield f"{key}.{k}" if key else k

        given_keys = set(extract_flat_keys(given))
        full_keys = set(extract_flat_keys(full))
        assert len(full_keys) >= len(given_keys)
        return len(given_keys) / len(full_keys) if full_keys else 0.0

    @classmethod
    def _load_impl(
        cls, data: IncompleteDescr
    ) -> Tuple[Union[Self, InvalidDescr], List[ErrorEntry], List[WarningEntry]]:
        """Validate `data` as `cls` and collect validation errors and warnings.

        Returns:
            `(description, errors, warnings)`. If validation as `cls` fails, the
            description is an `InvalidDescr`, built from `data` where possible.

        Raises:
            Exception: the encountered error is re-raised if the active
                validation context has `raise_errors` set.
        """
        rd: Union[Self, InvalidDescr, None] = None
        val_errors: List[ErrorEntry] = []
        val_warnings: List[WarningEntry] = []

        context = get_validation_context()
        try:
            rd = cls.model_validate(data)
        except pydantic.ValidationError as e:
            for ee in e.errors(include_url=False):
                # entries carrying a severity below ERROR are recorded as warnings
                if (severity := ee.get("ctx", {}).get("severity", ERROR)) < ERROR:
                    val_warnings.append(
                        WarningEntry(
                            loc=ee["loc"],
                            msg=ee["msg"],
                            type=ee["type"],
                            severity=severity,
                        )
                    )
                elif context.raise_errors:
                    raise e
                else:
                    val_errors.append(
                        ErrorEntry(loc=ee["loc"], msg=ee["msg"], type=ee["type"])
                    )

            if len(val_errors) == 0:  # FIXME is this redundant?
                val_errors.append(
                    ErrorEntry(
                        loc=(),
                        msg=(
                            f"Encountered {len(val_warnings)} more severe than warning"
                            " level "
                            f"'{WARNING_LEVEL_TO_NAME[context.warning_level]}'"
                        ),
                        type="severe_warnings",
                    )
                )
        except Exception as e:
            if context.raise_errors:
                raise e

            try:
                msg = str(e)
            except Exception:
                msg = e.__class__.__name__ + " encountered"

            val_errors.append(
                ErrorEntry(
                    loc=(),
                    msg=msg,
                    type=type(e).__name__,
                    with_traceback=True,
                )
            )

        if rd is None:
            try:
                rd = InvalidDescr.model_validate(data)
            except Exception as e:
                if context.raise_errors:
                    raise e
                # last resort: an InvalidDescr carrying only type and format version
                resource_type = cls.model_fields["type"].default
                format_version = cls.implemented_format_version
                rd = InvalidDescr(type=resource_type, format_version=format_version)
                # NOTE(review): nesting reconstructed from a flattened listing;
                # confirm whether this check belongs one level further out.
                if context.raise_errors:
                    raise ValueError(rd)

        return rd, val_errors, val_warnings

    def package(
        self, dest: Optional[Union[ZipFile, IO[bytes], Path, str]] = None, /
    ) -> ZipFile:
        """package the described resource as a zip archive

        Args:
            dest: (path/bytes stream of) destination zipfile;
                an in-memory `BytesIO` is used if omitted.

        Returns:
            The zip archive the package content was written to (not closed).

        Raises:
            ValueError: if `dest` is a `ZipFile` opened in read mode.
        """
        if dest is None:
            dest = BytesIO()

        if isinstance(dest, ZipFile):
            archive = dest
            owns_archive = False
            if "r" in archive.mode:
                raise ValueError(
                    f"zip file {dest} opened in '{archive.mode}' mode,"
                    + " but write access is needed for packaging."
                )
        else:
            archive = ZipFile(dest, mode="w")
            owns_archive = True

        try:
            if archive.filename is None:
                # Prefer `id`, then `name`; skip attributes that exist but are
                # None so the archive is never named "None.zip".
                stem = getattr(self, "id", None)
                if stem is None:
                    stem = getattr(self, "name", None)
                if stem is None:
                    stem = "bioimageio"

                archive.filename = str(stem) + ".zip"

            content = self.get_package_content()
            write_content_to_zip(content, archive)
        except BaseException:
            # do not leak an archive that was opened here
            if owns_archive:
                archive.close()
            raise

        return archive

    def get_package_content(
        self,
    ) -> Dict[FileName, Union[FileDescr, BioimageioYamlContent]]:
        """Returns package content without creating the package."""
        content: Dict[FileName, FileDescr] = {}
        # `content` is handed to the packaging context as `file_sources`
        # while the description is dumped
        with PackagingContext(
            bioimageio_yaml_file_name=BIOIMAGEIO_YAML,
            file_sources=content,
        ):
            rdf_content: BioimageioYamlContent = self.model_dump(
                mode="json", exclude_unset=True
            )

        _ = rdf_content.pop("rdf_source", None)

        return {**content, BIOIMAGEIO_YAML: rdf_content}

411 

412 

class InvalidDescr(
    ResourceDescrBase,
    extra="allow",
    title="An invalid resource description",
):
    """A representation of an invalid resource description"""

    implemented_type: ClassVar[Literal["unknown"]] = "unknown"
    if TYPE_CHECKING:  # see NodeWithExplicitlySetFields
        type: Any = "unknown"
    else:
        type: Any

    implemented_format_version: ClassVar[Literal["unknown"]] = "unknown"
    if TYPE_CHECKING:  # see NodeWithExplicitlySetFields
        format_version: Any = "unknown"
    else:
        format_version: Any

    def get_reason(self) -> Optional[str]:
        """Summarize the errors of all failed validation details.

        Returns:
            One "<loc>: <msg>" entry per error (newlines inside a message are
            replaced by spaces), joined with "\\n- ";
            `None` if there are no such errors.
        """
        summary = self.validation_summary
        if not summary or not summary.details:
            return None

        reasons: List[str] = []
        for detail in summary.details:
            if detail.status != "failed" or not detail.errors:
                continue

            for error in detail.errors:
                single_line_msg = error.msg.replace("\n", " ")
                reasons.append(f"{error.loc}: {single_line_msg}")

        if not reasons:
            return None

        return "\n- ".join(reasons)

447 

448 

# Node offering read-only, mapping-style access (`in`, `[]`, `.get`) to its
# declared model fields.
class KwargsNode(Node):
    def get(self, item: str, default: Any = None) -> Any:
        """Return the value of field `item`, or `default` if it is not a model field."""
        if item in self:
            return self[item]

        return default

    def __getitem__(self, item: str) -> Any:
        """Return the value of field `item`; raise `KeyError` if it is not a model field."""
        if item not in self.__class__.model_fields:
            raise KeyError(item)

        return getattr(self, item)

    def __contains__(self, item: str) -> bool:
        """Whether `item` is the name of a model field of this class."""
        known_fields = self.__class__.model_fields
        return item in known_fields