Coverage for src / bioimageio / spec / _internal / common_nodes.py: 87%

206 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-03-27 14:45 +0000

1from __future__ import annotations 

2 

3from abc import ABC 

4from inspect import signature 

5from io import BytesIO 

6from pathlib import Path 

7from types import MappingProxyType 

8from typing import ( 

9 IO, 

10 TYPE_CHECKING, 

11 Any, 

12 ClassVar, 

13 Dict, 

14 Iterable, 

15 List, 

16 Literal, 

17 Mapping, 

18 Optional, 

19 Protocol, 

20 Tuple, 

21 TypeVar, 

22 Union, 

23) 

24from zipfile import ZipFile 

25 

26import pydantic 

27from pydantic import DirectoryPath, PrivateAttr, model_validator 

28from pydantic_core import PydanticUndefined 

29from typing_extensions import Callable, ParamSpec, Self 

30 

31from ..summary import ( 

32 WARNING_LEVEL_TO_NAME, 

33 ErrorEntry, 

34 ValidationDetail, 

35 ValidationSummary, 

36 WarningEntry, 

37) 

38from .field_warning import issue_warning 

39from .io import ( 

40 BioimageioYamlContent, 

41 FileDescr, 

42 IncompleteDescr, 

43 IncompleteDescrView, 

44 deepcopy_incomplete_descr, 

45 extract_file_descrs, 

46 populate_cache, 

47) 

48from .io_basics import BIOIMAGEIO_YAML, FileName 

49from .io_utils import write_content_to_zip 

50from .node import Node 

51from .packaging_context import PackagingContext 

52from .root_url import RootHttpUrl 

53from .type_guards import is_dict 

54from .utils import get_format_version_tuple 

55from .validation_context import ValidationContext, get_validation_context 

56from .warning_levels import ALERT, ERROR, INFO 

57 

58 

class NodeWithExplicitlySetFields(Node):
    """Node that fills selected required fields from `implemented_*` class attributes.

    For every class attribute named `implemented_<field>` where `<field>` is a
    model field (which must not have a default), the attribute's value is
    recorded at subclass creation.  When a dict is validated and lacks
    `<field>`, the recorded value is inserted into that dict before validation.
    """

    # field name -> value inserted when the field is absent from dict input
    _fields_to_set_explicitly: ClassVar[Mapping[str, Any]]

    @classmethod
    def __pydantic_init_subclass__(cls, **kwargs: Any) -> None:
        prefix = "implemented_"
        collected: Dict[str, Any] = {}
        for attr_name in (a for a in dir(cls) if a.startswith(prefix)):
            target_field = attr_name.replace(prefix, "")
            if target_field in cls.model_fields:
                # an explicit value only makes sense for fields without a default
                assert (
                    cls.model_fields[target_field].get_default() is PydanticUndefined
                ), target_field
                collected[target_field] = getattr(cls, attr_name)

        # read-only view so subclasses cannot mutate the shared mapping
        cls._fields_to_set_explicitly = MappingProxyType(collected)
        return super().__pydantic_init_subclass__(**kwargs)

    @model_validator(mode="before")
    @classmethod
    def _set_fields_explicitly(
        cls, data: Union[Any, Dict[str, Any]]
    ) -> Union[Any, Dict[str, Any]]:
        """Insert recorded values for missing fields into dict input (in place)."""
        if not isinstance(data, dict):
            return data

        for name, default in cls._fields_to_set_explicitly.items():
            if name in data:
                continue
            data[name] = default

        return data  # pyright: ignore[reportUnknownVariableType]

91 

92 

# Static-typing-only declaration of the attributes every resource description
# provides.  Type checkers see a Protocol with these members mixed into
# ResourceDescrBase; at runtime an empty plain class is used instead, so no
# Protocol ends up in the runtime MRO (presumably to keep it away from the
# pydantic model machinery — NOTE(review): confirm the original motivation).
if TYPE_CHECKING:

    class _ResourceDescrBaseAbstractFieldsProtocol(Protocol):
        """workaround to add "abstract" fields to ResourceDescrBase

        Declares the instance fields `type` and `format_version` and their
        class-level counterparts `implemented_type` and
        `implemented_format_version` for type checkers only.
        """

        # TODO: implement as proper abstract fields of ResourceDescrBase

        type: Any  # should be LiteralString
        format_version: Any  # should be LiteralString
        implemented_type: ClassVar[Any]
        implemented_format_version: ClassVar[Any]

else:

    class _ResourceDescrBaseAbstractFieldsProtocol:
        # runtime stand-in: contributes nothing to subclasses
        pass


# type variables used to type `ResourceDescrBase.load_from_kwargs`, which
# mirrors the class constructor's parameters (P) and return type (T)
P = ParamSpec("P")
T = TypeVar("T")

113 

114 

class ResourceDescrBase(
    NodeWithExplicitlySetFields, ABC, _ResourceDescrBaseAbstractFieldsProtocol
):
    """base class for all resource descriptions

    Provides the machinery shared by all resource descriptions:
    tolerant handling of newer minor/patch `format_version`s, a validation
    summary attached on construction, the `load` factory that collects
    validation errors/warnings, and packaging as a zip archive.
    """

    # set by `_set_init_validation_summary` once the model has been validated
    _validation_summary: Optional[ValidationSummary] = None

    # parsed `implemented_format_version`; set in `__pydantic_init_subclass__`
    implemented_format_version_tuple: ClassVar[Tuple[int, int, int]]

    # @field_validator("format_version", mode="before", check_fields=False)
    # field_validator on "format_version" is not possible, because we want to use
    # "format_version" in a discriminated Union higher up
    # (PydanticUserError: Cannot use a mode='before' validator in the discriminator
    # field 'format_version' of Model 'CollectionDescr')
    @model_validator(mode="before")
    @classmethod
    def _ignore_future_patch(cls, data: Any, /) -> Any:
        """Treat a newer minor/patch `format_version` as the implemented one.

        If `data` is a dict whose `format_version` has the same major version as
        `implemented_format_version` but a greater (minor, patch), an ALERT
        warning is issued and `data["format_version"]` is overwritten in place
        with `implemented_format_version`. Any other input is returned as is.
        """
        if (
            cls.implemented_format_version == "unknown"
            or not is_dict(data)
            or "format_version" not in data
        ):
            return data

        value = data["format_version"]
        fv = get_format_version_tuple(value)
        if fv is None:
            return data
        if (
            fv[0] == cls.implemented_format_version_tuple[0]
            and fv[1:] > cls.implemented_format_version_tuple[1:]
        ):
            issue_warning(
                "future format_version '{value}' treated as '{implemented}'",
                value=value,
                msg_context=dict(implemented=cls.implemented_format_version),
                severity=ALERT,
            )
            data["format_version"] = cls.implemented_format_version

        return data

    @model_validator(mode="after")
    def _set_init_validation_summary(self) -> Self:
        """Attach an initial `ValidationSummary` to the validated instance.

        The status is "failed" for `InvalidDescr` instances and "valid-format"
        otherwise; `load` later adds the detailed format validation result.
        """
        context = get_validation_context()

        self._validation_summary = ValidationSummary(
            name="bioimageio format validation",
            source_name=context.source_name,
            id=getattr(self, "id", None),
            version=getattr(self, "version", None),
            type=self.type,
            format_version=self.format_version,
            status="failed" if isinstance(self, InvalidDescr) else "valid-format",
            metadata_completeness=self._get_metadata_completeness(),
            details=(
                []
                if isinstance(self, InvalidDescr)
                else [
                    ValidationDetail(
                        name=f"Successfully created `{self.__class__.__name__}` instance.",
                        status="passed",
                        context=context.summary,
                    )
                ]
            ),
        )
        return self

    @property
    def validation_summary(self) -> ValidationSummary:
        """Validation summary of this description (available after init)."""
        assert self._validation_summary is not None, "access only after initialization"
        return self._validation_summary

    # captured from the validation context active at construction time
    _root: Union[RootHttpUrl, DirectoryPath, ZipFile] = PrivateAttr(
        default_factory=lambda: get_validation_context().root
    )

    # captured from the validation context active at construction time
    _file_name: Optional[FileName] = PrivateAttr(
        default_factory=lambda: get_validation_context().file_name
    )

    @property
    def root(self) -> Union[RootHttpUrl, DirectoryPath, ZipFile]:
        """The URL/Path prefix to resolve any relative paths with."""
        return self._root

    @property
    def file_name(self) -> Optional[FileName]:
        """File name of the bioimageio.yaml file the description was loaded from."""
        return self._file_name

    @classmethod
    def __pydantic_init_subclass__(cls, **kwargs: Any):
        """Derive `implemented_format_version_tuple` for concrete subclasses.

        Versions without a "." (e.g. "unknown") map to (0, 0, 0).
        """
        super().__pydantic_init_subclass__(**kwargs)
        # set classvar implemented_format_version_tuple
        if "format_version" in cls.model_fields:
            if "." not in cls.implemented_format_version:
                cls.implemented_format_version_tuple = (0, 0, 0)
            else:
                fv_tuple = get_format_version_tuple(cls.implemented_format_version)
                assert fv_tuple is not None, (
                    f"failed to cast '{cls.implemented_format_version}' to tuple"
                )
                cls.implemented_format_version_tuple = fv_tuple

    @classmethod
    def load_from_kwargs(
        cls: Callable[P, T],
        context: Optional[ValidationContext] = None,
        *args: P.args,
        **kwargs: P.kwargs,
    ) -> Union[T, InvalidDescr]:
        """Create a description from constructor-style arguments via `load`.

        The arguments are bound to the class signature with
        `Signature.bind_partial` (so missing required fields do not raise here)
        and the resulting name -> value mapping is passed to `load`.
        """
        sig = signature(cls)
        bound = sig.bind_partial(*args, **kwargs)
        return cls.load(dict(bound.arguments), context=context)  # pyright: ignore[reportFunctionMemberAccess]

    @classmethod
    def load(
        cls,
        data: IncompleteDescrView,
        context: Optional[ValidationContext] = None,
    ) -> Union[Self, InvalidDescr]:
        """factory method to create a resource description object

        Args:
            data: raw (possibly incomplete) description content. Validation
                operates on deep copies made with `deepcopy_incomplete_descr`.
            context: validation context to use; defaults to the current one.

        Returns:
            The validated description or an `InvalidDescr`. In both cases the
            format validation result (errors and warnings) is added to
            `validation_summary`. With `context.raise_errors` set, validation
            errors are raised instead (see `_load_impl`).
        """
        context = context or get_validation_context()
        if context.perform_io_checks:
            # collect referenced files and populate the cache with them
            file_descrs = extract_file_descrs(data)
            populate_cache(file_descrs)  # TODO: add progress bar

        with context.replace(log_warnings=context.warning_level <= INFO):
            rd, errors, val_warnings = cls._load_impl(deepcopy_incomplete_descr(data))

        if context.warning_level > INFO:
            # reload at warning level INFO (never raising) to collect *all*
            # validation warnings, not only those at/above the configured level
            all_warnings_context = context.replace(
                warning_level=INFO, log_warnings=False, raise_errors=False
            )
            with all_warnings_context:
                _, _, val_warnings = cls._load_impl(deepcopy_incomplete_descr(data))

        format_status = "failed" if errors else "passed"
        rd.validation_summary.add_detail(
            ValidationDetail(
                errors=errors,
                name=(
                    "bioimageio.spec format validation"
                    f" {rd.type} {cls.implemented_format_version}"
                ),
                status=format_status,
                warnings=val_warnings,
            ),
            update_status=False,  # avoid updating status from 'valid-format' to 'passed', but ...
        )
        if format_status == "failed":
            # ... update status in case of failure
            rd.validation_summary.status = "failed"

        return rd

    def _get_metadata_completeness(self) -> float:
        """Fraction of explicitly set keys among all keys of the JSON dump.

        Nested dicts are flattened into dot-separated keys (list items are not
        descended into). Returns 0.0 for `InvalidDescr` or an empty dump.
        """
        if isinstance(self, InvalidDescr):
            return 0.0

        given = self.model_dump(mode="json", exclude_unset=True, exclude_defaults=False)
        full = self.model_dump(mode="json", exclude_unset=False, exclude_defaults=False)

        def extract_flat_keys(d: Dict[Any, Any], key: str = "") -> Iterable[str]:
            for k, v in d.items():
                if is_dict(v):
                    yield from extract_flat_keys(v, key=f"{key}.{k}" if key else k)

                yield f"{key}.{k}" if key else k

        given_keys = set(extract_flat_keys(given))
        full_keys = set(extract_flat_keys(full))
        assert len(full_keys) >= len(given_keys)
        return len(given_keys) / len(full_keys) if full_keys else 0.0

    @classmethod
    def _load_impl(
        cls, data: IncompleteDescr
    ) -> Tuple[Union[Self, InvalidDescr], List[ErrorEntry], List[WarningEntry]]:
        """Validate `data` as `cls`, falling back to `InvalidDescr` on failure.

        Returns:
            (description, errors, warnings). Pydantic errors whose `ctx` carries
            a severity below ERROR are reported as warnings.

        Raises:
            The validation/other exception if the current validation context
            has `raise_errors` set.
        """
        rd: Union[Self, InvalidDescr, None] = None
        val_errors: List[ErrorEntry] = []
        val_warnings: List[WarningEntry] = []

        context = get_validation_context()
        try:
            rd = cls.model_validate(data)
        except pydantic.ValidationError as e:
            for ee in e.errors(include_url=False):
                if (severity := ee.get("ctx", {}).get("severity", ERROR)) < ERROR:
                    val_warnings.append(
                        WarningEntry(
                            loc=ee["loc"],
                            msg=ee["msg"],
                            type=ee["type"],
                            severity=severity,
                        )
                    )
                elif context.raise_errors:
                    raise e
                else:
                    val_errors.append(
                        ErrorEntry(loc=ee["loc"], msg=ee["msg"], type=ee["type"])
                    )

            # validation failed solely due to warnings above the configured
            # warning level: still record the failure as an error
            if len(val_errors) == 0:  # FIXME is this redundant?
                val_errors.append(
                    ErrorEntry(
                        loc=(),
                        msg=(
                            f"Encountered {len(val_warnings)} more severe than warning"
                            " level "
                            f"'{WARNING_LEVEL_TO_NAME[context.warning_level]}'"
                        ),
                        type="severe_warnings",
                    )
                )
        except Exception as e:
            if context.raise_errors:
                raise e

            try:
                msg = str(e)
            except Exception:
                # str() of some exceptions can itself fail
                msg = e.__class__.__name__ + " encountered"

            val_errors.append(
                ErrorEntry(
                    loc=(),
                    msg=msg,
                    type=type(e).__name__,
                    with_traceback=True,
                )
            )

        if rd is None:
            try:
                rd = InvalidDescr.model_validate(data)
            except Exception as e:
                if context.raise_errors:
                    raise e
                # last resort: a minimal invalid description of this type/version
                resource_type = cls.model_fields["type"].default
                format_version = cls.implemented_format_version
                rd = InvalidDescr(type=resource_type, format_version=format_version)
                # NOTE(review): `raise_errors` was already handled at the top of
                # this except-clause, so this check cannot trigger here; confirm
                # whether it was meant to apply after the try/except instead.
                if context.raise_errors:
                    raise ValueError(rd)

        return rd, val_errors, val_warnings

    def package(
        self, dest: Optional[Union[ZipFile, IO[bytes], Path, str]] = None, /
    ) -> ZipFile:
        """package the described resource as a zip archive

        Args:
            dest: (path/bytes stream of) destination zipfile; an in-memory
                `BytesIO` is used if omitted.

        Returns:
            The (still open) `ZipFile` the package content was written to.

        Raises:
            ValueError: if `dest` is a `ZipFile` opened in read mode.
        """
        if dest is None:
            dest = BytesIO()

        if isinstance(dest, ZipFile):
            zip_file = dest
            if "r" in zip_file.mode:
                raise ValueError(
                    f"zip file {dest} opened in '{zip_file.mode}' mode,"
                    + " but write access is needed for packaging."
                )
        else:
            zip_file = ZipFile(dest, mode="w")

        if zip_file.filename is None:
            # in-memory archives have no file name; derive one from the
            # description. `id`/`name` may exist but be None, so fall back on
            # falsy values (not only on missing attributes) to avoid "None.zip".
            base_name = (
                getattr(self, "id", None) or getattr(self, "name", None) or "bioimageio"
            )
            zip_file.filename = str(base_name) + ".zip"

        content = self.get_package_content()
        write_content_to_zip(content, zip_file)
        return zip_file

    def get_package_content(
        self,
    ) -> Dict[FileName, Union[FileDescr, BioimageioYamlContent]]:
        """Returns package content without creating the package.

        The result maps file names to the file descriptions collected in the
        `PackagingContext` while dumping, plus the dumped description itself
        (without "rdf_source") under the bioimageio.yaml file name.
        """
        content: Dict[FileName, FileDescr] = {}
        # `content` is handed to the packaging context as `file_sources`;
        # presumably file fields register themselves there during serialization
        with PackagingContext(
            bioimageio_yaml_file_name=BIOIMAGEIO_YAML,
            file_sources=content,
        ):
            rdf_content: BioimageioYamlContent = self.model_dump(
                mode="json", exclude_unset=True
            )

        _ = rdf_content.pop("rdf_source", None)

        return {**content, BIOIMAGEIO_YAML: rdf_content}

413 

414 

class InvalidDescr(
    ResourceDescrBase,
    extra="allow",
    title="An invalid resource description",
):
    """A representation of an invalid resource description

    Accepts arbitrary extra content. `type` and `format_version` have no
    runtime default, so they are filled with "unknown" from the matching
    `implemented_*` class attributes when missing from dict input.
    """

    implemented_type: ClassVar[Literal["unknown"]] = "unknown"
    if TYPE_CHECKING:  # see NodeWithExplicitlySetFields
        type: Any = "unknown"
    else:
        type: Any

    implemented_format_version: ClassVar[Literal["unknown"]] = "unknown"
    if TYPE_CHECKING:  # see NodeWithExplicitlySetFields
        format_version: Any = "unknown"
    else:
        format_version: Any

    def get_reason(self) -> Optional[str]:
        """Get the reason why the description is invalid, if available.

        Collects "<loc>: <msg>" lines (newlines in messages replaced by spaces)
        from all failed validation details and joins them with "\\n- ".
        Returns None if there are none.
        """
        summary = self.validation_summary
        if not (summary and summary.details):
            return None

        reasons: List[str] = []
        for detail in summary.details:
            if detail.status != "failed" or not detail.errors:
                continue
            for error in detail.errors:
                single_line_msg = error.msg.replace("\n", " ")
                reasons.append(f"{error.loc}: {single_line_msg}")

        return "\n- ".join(reasons) if reasons else None

449 

450 

class KwargsNode(Node):
    """Node whose model fields can also be read with mapping-style access."""

    def get(self, item: str, default: Any = None) -> Any:
        """Return the value of field `item`, or `default` if it is not a field."""
        if item not in self:
            return default
        return self[item]

    def __getitem__(self, item: str) -> Any:
        """Return the value of model field `item`; raise KeyError otherwise."""
        if item not in type(self).model_fields:
            raise KeyError(item)
        return getattr(self, item)

    def __contains__(self, item: str) -> bool:
        """Whether `item` names a model field of this node's class."""
        known_fields = type(self).model_fields
        return item in known_fields