Coverage for src/bioimageio/spec/_internal/common_nodes.py: 87%

191 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-10-30 13:10 +0000

1from __future__ import annotations 

2 

3from abc import ABC 

4from io import BytesIO 

5from pathlib import Path 

6from types import MappingProxyType 

7from typing import ( 

8 IO, 

9 TYPE_CHECKING, 

10 Any, 

11 ClassVar, 

12 Dict, 

13 Iterable, 

14 List, 

15 Literal, 

16 Mapping, 

17 Optional, 

18 Protocol, 

19 Tuple, 

20 Union, 

21) 

22from zipfile import ZipFile 

23 

24import pydantic 

25from pydantic import DirectoryPath, PrivateAttr, model_validator 

26from pydantic_core import PydanticUndefined 

27from typing_extensions import Self 

28 

29from ..summary import ( 

30 WARNING_LEVEL_TO_NAME, 

31 ErrorEntry, 

32 ValidationDetail, 

33 ValidationSummary, 

34 WarningEntry, 

35) 

36from .field_warning import issue_warning 

37from .io import ( 

38 BioimageioYamlContent, 

39 BioimageioYamlContentView, 

40 FileDescr, 

41 deepcopy_yaml_value, 

42 extract_file_descrs, 

43 populate_cache, 

44) 

45from .io_basics import BIOIMAGEIO_YAML, FileName 

46from .io_utils import write_content_to_zip 

47from .node import Node 

48from .packaging_context import PackagingContext 

49from .root_url import RootHttpUrl 

50from .type_guards import is_dict 

51from .utils import get_format_version_tuple 

52from .validation_context import ValidationContext, get_validation_context 

53from .warning_levels import ALERT, ERROR, INFO 

54 

55 

class NodeWithExplicitlySetFields(Node):
    """Node that always sets the fields backed by `implemented_<field>` attributes.

    For every class attribute named `implemented_<field>` whose `<field>` is a
    model field, the attribute value is inserted into dict input that lacks
    `<field>` before validation runs (see `_set_fields_explicitly`).
    """

    # read-only mapping of field name -> value taken from `implemented_<field>`;
    # rebuilt for every subclass in `__pydantic_init_subclass__`
    _fields_to_set_explicitly: ClassVar[Mapping[str, Any]]

    @classmethod
    def __pydantic_init_subclass__(cls, **kwargs: Any) -> None:
        """Collect the `implemented_<field>` class attributes of a new subclass.

        Attributes whose suffix is not a model field are skipped. A matching
        model field must not declare a default of its own (asserted).
        """
        prefix = "implemented_"
        explicit_fields: Dict[str, Any] = {}
        for attr in dir(cls):
            if not attr.startswith(prefix):
                continue

            # Strip the leading prefix only; `str.replace` would also remove
            # later occurrences of "implemented_" inside the attribute name.
            field_name = attr[len(prefix) :]
            if field_name not in cls.model_fields:
                continue

            assert (
                cls.model_fields[field_name].get_default() is PydanticUndefined
            ), field_name
            explicit_fields[field_name] = getattr(cls, attr)

        cls._fields_to_set_explicitly = MappingProxyType(explicit_fields)
        return super().__pydantic_init_subclass__(**kwargs)

    @model_validator(mode="before")
    @classmethod
    def _set_fields_explicitly(
        cls, data: Union[Any, Dict[str, Any]]
    ) -> Union[Any, Dict[str, Any]]:
        """Insert the collected values for explicit fields missing from `data`.

        Only `dict` input is touched, and it is updated in place; any other
        input is returned unchanged.
        """
        if isinstance(data, dict):
            for name, default in cls._fields_to_set_explicitly.items():
                if name not in data:
                    data[name] = default

        return data  # pyright: ignore[reportUnknownVariableType]

88 

89 

if not TYPE_CHECKING:

    # At runtime the protocol is an empty placeholder base class.
    class _ResourceDescrBaseAbstractFieldsProtocol:
        pass

else:

    class _ResourceDescrBaseAbstractFieldsProtocol(Protocol):
        """workaround to add "abstract" fields to ResourceDescrBase"""

        # TODO: implement as proper abstract fields of ResourceDescrBase

        type: Any  # should be LiteralString
        format_version: Any  # should be LiteralString
        implemented_type: ClassVar[Any]
        implemented_format_version: ClassVar[Any]

106 

107 

class ResourceDescrBase(
    NodeWithExplicitlySetFields, ABC, _ResourceDescrBaseAbstractFieldsProtocol
):
    """base class for all resource descriptions"""

    # summary created by `_set_init_validation_summary` after validation
    _validation_summary: Optional[ValidationSummary] = None

    # `implemented_format_version` parsed into a tuple; assigned in
    # `__pydantic_init_subclass__` for subclasses with a `format_version` field
    implemented_format_version_tuple: ClassVar[Tuple[int, int, int]]

    # @field_validator("format_version", mode="before", check_fields=False)
    # field_validator on "format_version" is not possible, because we want to use
    # "format_version" in a discriminated Union higher up
    # (PydanticUserError: Cannot use a mode='before' validator in the discriminator
    #  field 'format_version' of Model 'CollectionDescr')
    @model_validator(mode="before")
    @classmethod
    def _ignore_future_patch(cls, data: Any, /) -> Any:
        """Treat a newer `format_version` of the same major version as implemented.

        If `data` is a dict whose `format_version` has the implemented major
        version but a later minor/patch part, an ALERT warning is issued and
        `data["format_version"]` is overwritten in place with
        `implemented_format_version`. Any other input is returned unchanged.
        """
        if (
            cls.implemented_format_version == "unknown"
            or not is_dict(data)
            or "format_version" not in data
        ):
            return data

        value = data["format_version"]
        fv = get_format_version_tuple(value)
        if fv is None:
            return data
        if (
            fv[0] == cls.implemented_format_version_tuple[0]
            and fv[1:] > cls.implemented_format_version_tuple[1:]
        ):
            issue_warning(
                "future format_version '{value}' treated as '{implemented}'",
                value=value,
                msg_context=dict(implemented=cls.implemented_format_version),
                severity=ALERT,
            )
            data["format_version"] = cls.implemented_format_version

        return data

    @model_validator(mode="after")
    def _set_init_validation_summary(self) -> Self:
        """Attach the initial `ValidationSummary` to the validated instance.

        An `InvalidDescr` gets status "failed" and no details; any other
        description gets status "valid-format" and one "passed" detail.
        """
        context = get_validation_context()

        self._validation_summary = ValidationSummary(
            name="bioimageio format validation",
            source_name=context.source_name,
            id=getattr(self, "id", None),
            type=self.type,
            format_version=self.format_version,
            status="failed" if isinstance(self, InvalidDescr) else "valid-format",
            metadata_completeness=self._get_metadata_completeness(),
            details=(
                []
                if isinstance(self, InvalidDescr)
                else [
                    ValidationDetail(
                        name=f"Successfully created `{self.__class__.__name__}` instance.",
                        status="passed",
                        context=context.summary,
                    )
                ]
            ),
        )
        return self

    @property
    def validation_summary(self) -> ValidationSummary:
        """The validation summary of this description (set during validation)."""
        assert self._validation_summary is not None, "access only after initialization"
        return self._validation_summary

    # root and file name are captured from the validation context that is
    # active when the instance is created
    _root: Union[RootHttpUrl, DirectoryPath, ZipFile] = PrivateAttr(
        default_factory=lambda: get_validation_context().root
    )

    _file_name: Optional[FileName] = PrivateAttr(
        default_factory=lambda: get_validation_context().file_name
    )

    @property
    def root(self) -> Union[RootHttpUrl, DirectoryPath, ZipFile]:
        """The URL/Path prefix to resolve any relative paths with."""
        return self._root

    @property
    def file_name(self) -> Optional[FileName]:
        """File name of the bioimageio.yaml file the description was loaded from."""
        return self._file_name

    @classmethod
    def __pydantic_init_subclass__(cls, **kwargs: Any) -> None:
        """Set `implemented_format_version_tuple` on subclasses with a `format_version` field."""
        super().__pydantic_init_subclass__(**kwargs)
        # set classvar implemented_format_version_tuple
        if "format_version" in cls.model_fields:
            if "." not in cls.implemented_format_version:
                # not a dotted version string (e.g. "unknown")
                cls.implemented_format_version_tuple = (0, 0, 0)
            else:
                fv_tuple = get_format_version_tuple(cls.implemented_format_version)
                assert fv_tuple is not None, (
                    f"failed to cast '{cls.implemented_format_version}' to tuple"
                )
                cls.implemented_format_version_tuple = fv_tuple

    @classmethod
    def load(
        cls,
        data: BioimageioYamlContentView,
        context: Optional[ValidationContext] = None,
    ) -> Union[Self, InvalidDescr]:
        """factory method to create a resource description object

        Args:
            data: content of a bioimageio.yaml file; it is deep-copied before
                validation.
            context: validation context to use; defaults to the currently
                active validation context.

        Returns:
            An instance of `cls`, or an `InvalidDescr` if validation failed.
            Errors and warnings are recorded as a detail of the returned
            object's `validation_summary`.
        """
        context = context or get_validation_context()
        if context.perform_io_checks:
            file_descrs = extract_file_descrs({k: v for k, v in data.items()})
            populate_cache(file_descrs)  # TODO: add progress bar

        with context.replace(log_warnings=context.warning_level <= INFO):
            rd, errors, val_warnings = cls._load_impl(deepcopy_yaml_value(data))

        if context.warning_level > INFO:
            all_warnings_context = context.replace(
                warning_level=INFO, log_warnings=False, raise_errors=False
            )
            # raise all validation warnings by reloading
            # (only the warnings of this second pass are kept)
            with all_warnings_context:
                _, _, val_warnings = cls._load_impl(deepcopy_yaml_value(data))

        format_status = "failed" if errors else "passed"
        rd.validation_summary.add_detail(
            ValidationDetail(
                errors=errors,
                name=(
                    "bioimageio.spec format validation"
                    f" {rd.type} {cls.implemented_format_version}"
                ),
                status=format_status,
                warnings=val_warnings,
            ),
            update_status=False,  # avoid updating status from 'valid-format' to 'passed', but ...
        )
        if format_status == "failed":
            # ... update status in case of failure
            rd.validation_summary.status = "failed"

        return rd

    def _get_metadata_completeness(self) -> float:
        """Return the share of (flattened) keys that were explicitly set.

        Keys of nested dicts are flattened to dotted paths. The result is the
        number of keys in the dump restricted to explicitly set fields divided
        by the number of keys in the full dump; `InvalidDescr` yields 0.0.
        """
        if isinstance(self, InvalidDescr):
            return 0.0

        given = self.model_dump(mode="json", exclude_unset=True, exclude_defaults=False)
        full = self.model_dump(mode="json", exclude_unset=False, exclude_defaults=False)

        def extract_flat_keys(d: Dict[Any, Any], key: str = "") -> Iterable[str]:
            for k, v in d.items():
                if is_dict(v):
                    yield from extract_flat_keys(v, key=f"{key}.{k}" if key else k)

                # the key of a nested dict is yielded as well
                yield f"{key}.{k}" if key else k

        given_keys = set(extract_flat_keys(given))
        full_keys = set(extract_flat_keys(full))
        assert len(full_keys) >= len(given_keys)
        return len(given_keys) / len(full_keys) if full_keys else 0.0

    @classmethod
    def _load_impl(
        cls, data: BioimageioYamlContent
    ) -> Tuple[Union[Self, InvalidDescr], List[ErrorEntry], List[WarningEntry]]:
        """Validate `data` and collect errors and warnings.

        Returns:
            The created description (an `InvalidDescr` if validation as `cls`
            failed), the list of errors and the list of warnings.

        Raises:
            Exception: validation errors are re-raised instead of collected if
                the active validation context has `raise_errors` set.
        """
        rd: Union[Self, InvalidDescr, None] = None
        val_errors: List[ErrorEntry] = []
        val_warnings: List[WarningEntry] = []

        context = get_validation_context()
        try:
            rd = cls.model_validate(data)
        except pydantic.ValidationError as e:
            for ee in e.errors(include_url=False):
                # entries carrying a severity below ERROR are reported as warnings
                if (severity := ee.get("ctx", {}).get("severity", ERROR)) < ERROR:
                    val_warnings.append(
                        WarningEntry(
                            loc=ee["loc"],
                            msg=ee["msg"],
                            type=ee["type"],
                            severity=severity,
                        )
                    )
                elif context.raise_errors:
                    raise
                else:
                    val_errors.append(
                        ErrorEntry(loc=ee["loc"], msg=ee["msg"], type=ee["type"])
                    )

            if len(val_errors) == 0:  # FIXME is this redundant?
                val_errors.append(
                    ErrorEntry(
                        loc=(),
                        msg=(
                            f"Encountered {len(val_warnings)} more severe than warning"
                            " level "
                            f"'{WARNING_LEVEL_TO_NAME[context.warning_level]}'"
                        ),
                        type="severe_warnings",
                    )
                )
        except Exception as e:
            if context.raise_errors:
                raise

            try:
                msg = str(e)
            except Exception:
                # fall back to the class name if the exception cannot be rendered
                msg = e.__class__.__name__ + " encountered"

            val_errors.append(
                ErrorEntry(
                    loc=(),
                    msg=msg,
                    type=type(e).__name__,
                    with_traceback=True,
                )
            )

        if rd is None:
            try:
                rd = InvalidDescr.model_validate(data)
            except Exception:
                if context.raise_errors:
                    raise
                resource_type = cls.model_fields["type"].default
                format_version = cls.implemented_format_version
                rd = InvalidDescr(type=resource_type, format_version=format_version)
                # NOTE(review): not reachable at this nesting level, because
                # `raise_errors` already re-raised above -- confirm intent
                if context.raise_errors:
                    raise ValueError(rd)

        return rd, val_errors, val_warnings

    def package(
        self, dest: Optional[Union[ZipFile, IO[bytes], Path, str]] = None, /
    ) -> ZipFile:
        """package the described resource as a zip archive

        Args:
            dest: (path/bytes stream of) destination zipfile;
                an in-memory `BytesIO` is used if omitted.

        Returns:
            The zip file the package content was written to.

        Raises:
            ValueError: if `dest` is a `ZipFile` opened in read mode.
        """
        if dest is None:
            dest = BytesIO()

        if isinstance(dest, ZipFile):
            archive = dest
            if "r" in archive.mode:
                raise ValueError(
                    f"zip file {dest} opened in '{archive.mode}' mode,"
                    + " but write access is needed for packaging."
                )
        else:
            archive = ZipFile(dest, mode="w")

        if archive.filename is None:
            # Prefer `id`, then `name`; skip attributes that exist but are None
            # so the archive is never named "None.zip".
            stem = getattr(self, "id", None)
            if stem is None:
                stem = getattr(self, "name", None)
            if stem is None:
                stem = "bioimageio"

            archive.filename = str(stem) + ".zip"

        content = self.get_package_content()
        write_content_to_zip(content, archive)
        return archive

    def get_package_content(
        self,
    ) -> Dict[FileName, Union[FileDescr, BioimageioYamlContent]]:
        """Returns package content without creating the package."""
        # handed to the packaging context as `file_sources`;
        # presumably filled while the model is dumped -- TODO confirm
        content: Dict[FileName, FileDescr] = {}
        with PackagingContext(
            bioimageio_yaml_file_name=BIOIMAGEIO_YAML,
            file_sources=content,
        ):
            rdf_content: BioimageioYamlContent = self.model_dump(
                mode="json", exclude_unset=True
            )

        _ = rdf_content.pop("rdf_source", None)

        return {**content, BIOIMAGEIO_YAML: rdf_content}

393 

394 

class InvalidDescr(
    ResourceDescrBase,
    extra="allow",
    title="An invalid resource description",
):
    """A representation of an invalid resource description"""

    # `implemented_type` / `implemented_format_version` are collected by
    # `NodeWithExplicitlySetFields.__pydantic_init_subclass__` and inserted into
    # dict input that lacks the corresponding field. That hook asserts the
    # fields have no default of their own, so the "unknown" defaults below are
    # only declared for static type checkers.

    implemented_type: ClassVar[Literal["unknown"]] = "unknown"
    if TYPE_CHECKING:  # see NodeWithExplicitlySetFields
        type: Any = "unknown"
    else:
        type: Any

    implemented_format_version: ClassVar[Literal["unknown"]] = "unknown"
    if TYPE_CHECKING:  # see NodeWithExplicitlySetFields
        format_version: Any = "unknown"
    else:
        format_version: Any

413 

414 

class KwargsNode(Node):
    """Node whose declared model fields can also be read like mapping items."""

    def get(self, item: str, default: Any = None) -> Any:
        """Return the value of field `item`, or `default` if it is not a model field."""
        return self[item] if item in self else default

    def __getitem__(self, item: str) -> Any:
        """Return the value of field `item`.

        Raises:
            KeyError: if `item` is not a declared model field.
        """
        if item not in self.__class__.model_fields:
            raise KeyError(item)

        return getattr(self, item)

    def __contains__(self, item: str) -> bool:
        """Return whether `item` is a declared model field of this class."""
        return item in self.__class__.model_fields