Coverage for bioimageio/spec/_internal/common_nodes.py: 88%

171 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-04-02 14:21 +0000

1from __future__ import annotations 

2 

3from abc import ABC 

4from copy import deepcopy 

5from io import BytesIO 

6from pathlib import Path 

7from types import MappingProxyType 

8from typing import ( 

9 IO, 

10 TYPE_CHECKING, 

11 Any, 

12 ClassVar, 

13 Dict, 

14 List, 

15 Literal, 

16 Mapping, 

17 Optional, 

18 Protocol, 

19 Tuple, 

20 Union, 

21) 

22from zipfile import ZipFile 

23 

24import pydantic 

25from pydantic import DirectoryPath, PrivateAttr, model_validator 

26from pydantic_core import PydanticUndefined 

27from typing_extensions import Self 

28 

29from bioimageio.spec._internal.type_guards import is_dict 

30 

31from ..summary import ( 

32 WARNING_LEVEL_TO_NAME, 

33 ErrorEntry, 

34 ValidationDetail, 

35 ValidationSummary, 

36 WarningEntry, 

37) 

38from .field_warning import issue_warning 

39from .io import BioimageioYamlContent 

40from .io_basics import BIOIMAGEIO_YAML, AbsoluteFilePath, FileName, ZipPath 

41from .io_utils import write_content_to_zip 

42from .node import Node 

43from .packaging_context import PackagingContext 

44from .root_url import RootHttpUrl 

45from .url import HttpUrl 

46from .utils import get_format_version_tuple 

47from .validation_context import ValidationContext, get_validation_context 

48from .warning_levels import ALERT, ERROR, INFO 

49 

50 

class NodeWithExplicitlySetFields(Node):
    """Node that fills required fields from `implemented_<field>` class attributes.

    For every class attribute named `implemented_<field>` where `<field>` is a
    model field, the attribute's value is inserted into dict input data before
    validation whenever `<field>` is missing from that data. Such fields must not
    have a pydantic default. Presumably this is so the injected value is part of
    the input and therefore counts as explicitly set (e.g. kept by
    `model_dump(exclude_unset=True)`).
    """

    # field name -> value injected into the input data when that field is missing;
    # computed per subclass in `__pydantic_init_subclass__`
    _fields_to_set_explicitly: ClassVar[Mapping[str, Any]]

    @classmethod
    def __pydantic_init_subclass__(cls, **kwargs: Any) -> None:
        """Collect `implemented_<field>` class attributes that match model fields."""
        prefix = "implemented_"
        explicit_fields: Dict[str, Any] = {}
        for attr in dir(cls):
            if not attr.startswith(prefix):
                continue

            # strip only the leading prefix; `str.replace` would also drop any
            # later occurrence of "implemented_" inside the attribute name
            field_name = attr[len(prefix) :]
            if field_name not in cls.model_fields:
                # e.g. `implemented_format_version_tuple` has no matching field
                continue

            # the field must be required, otherwise the explicitly injected value
            # and a pydantic default would compete
            assert (
                cls.model_fields[field_name].get_default() is PydanticUndefined
            ), field_name
            explicit_fields[field_name] = getattr(cls, attr)

        # read-only view so the class-level mapping cannot be mutated by accident
        cls._fields_to_set_explicitly = MappingProxyType(explicit_fields)
        return super().__pydantic_init_subclass__(**kwargs)

    @model_validator(mode="before")
    @classmethod
    def _set_fields_explicitly(
        cls, data: Union[Any, Dict[str, Any]]
    ) -> Union[Any, Dict[str, Any]]:
        """Insert the collected `implemented_*` values for missing keys.

        Dict input is modified IN PLACE (and returned); any other input is
        returned unchanged.
        """
        if isinstance(data, dict):
            for name, default in cls._fields_to_set_explicitly.items():
                if name not in data:
                    data[name] = default

        return data  # pyright: ignore[reportUnknownVariableType]

83 

84 

if not TYPE_CHECKING:

    class _ResourceDescrBaseAbstractFieldsProtocol:
        # runtime stand-in: an empty mixin that contributes no attributes
        pass

else:

    class _ResourceDescrBaseAbstractFieldsProtocol(Protocol):
        """Static-typing-only declaration of the fields every resource description has.

        Workaround to give ResourceDescrBase "abstract" fields that type checkers
        know about; at runtime an empty class is used instead.
        """

        # TODO: implement as proper abstract fields of ResourceDescrBase

        implemented_type: ClassVar[Any]
        implemented_format_version: ClassVar[Any]

        type: Any  # should be LiteralString
        format_version: Any  # should be LiteralString

101 

102 

class ResourceDescrBase(
    NodeWithExplicitlySetFields, ABC, _ResourceDescrBaseAbstractFieldsProtocol
):
    """base class for all resource descriptions"""

    # set by the `_set_init_validation_summary` after-validator;
    # read through the `validation_summary` property
    _validation_summary: Optional[ValidationSummary] = None

    # derived from `implemented_format_version` in `__pydantic_init_subclass__`
    implemented_format_version_tuple: ClassVar[Tuple[int, int, int]]

    # @field_validator("format_version", mode="before", check_fields=False)
    # field_validator on "format_version" is not possible, because we want to use
    # "format_version" in a discriminated Union higher up
    # (PydanticUserError: Cannot use a mode='before' validator in the discriminator
    # field 'format_version' of Model 'CollectionDescr')
    @model_validator(mode="before")
    @classmethod
    def _ignore_future_patch(cls, data: Any, /) -> Any:
        """Treat a newer `format_version` of the same major version as the implemented one.

        If `data["format_version"]` parses to a version tuple with the same major
        version as `cls.implemented_format_version_tuple` but a larger
        (minor, patch) part, an ALERT-level warning is issued and the value is
        replaced in place by `cls.implemented_format_version`.
        NOTE(review): because `(minor, patch)` is compared, a newer *minor*
        version is downgraded too, despite the method name — confirm intended.
        """
        if (
            cls.implemented_format_version == "unknown"
            or not is_dict(data)
            or "format_version" not in data
        ):
            return data

        value = data["format_version"]
        fv = get_format_version_tuple(value)
        if fv is None:
            return data
        if (
            fv[0] == cls.implemented_format_version_tuple[0]
            and fv[1:] > cls.implemented_format_version_tuple[1:]
        ):
            issue_warning(
                "future format_version '{value}' treated as '{implemented}'",
                value=value,
                msg_context=dict(implemented=cls.implemented_format_version),
                severity=ALERT,
            )
            data["format_version"] = cls.implemented_format_version

        return data

    @model_validator(mode="after")
    def _set_init_validation_summary(self) -> Self:
        """Attach an initial `ValidationSummary` describing this instance's creation."""
        context = get_validation_context()
        invalid = isinstance(self, InvalidDescr)
        detail_name = (
            "Created" if invalid else "Successfully created"
        ) + f" `{self.__class__.__name__}` instance."
        self._validation_summary = ValidationSummary(
            name="bioimageio format validation",
            source_name=context.source_name,
            id=getattr(self, "id", None),
            type=self.type,
            format_version=self.format_version,
            status="failed" if invalid else "valid-format",
            details=[
                ValidationDetail(
                    name=detail_name,
                    status="failed" if invalid else "passed",
                    context=context.summary,
                )
            ],
        )
        return self

    @property
    def validation_summary(self) -> ValidationSummary:
        """Validation summary of this description.

        Available only once model validation has finished (it is set by an
        after-validator); accessing it earlier fails an assertion.
        """
        assert self._validation_summary is not None, "access only after initialization"
        return self._validation_summary

    # taken from the validation context active when the instance is created
    _root: Union[RootHttpUrl, DirectoryPath, ZipFile] = PrivateAttr(
        default_factory=lambda: get_validation_context().root
    )

    # taken from the validation context active when the instance is created
    _file_name: Optional[FileName] = PrivateAttr(
        default_factory=lambda: get_validation_context().file_name
    )

    @property
    def root(self) -> Union[RootHttpUrl, DirectoryPath, ZipFile]:
        """The URL/Path prefix to resolve any relative paths with."""
        return self._root

    @property
    def file_name(self) -> Optional[FileName]:
        """File name of the bioimageio.yaml file the description was loaded from."""
        return self._file_name

    @classmethod
    def __pydantic_init_subclass__(cls, **kwargs: Any):
        super().__pydantic_init_subclass__(**kwargs)
        # set classvar implemented_format_version_tuple
        if "format_version" in cls.model_fields:
            if "." not in cls.implemented_format_version:
                # e.g. "unknown" (see InvalidDescr)
                cls.implemented_format_version_tuple = (0, 0, 0)
            else:
                fv_tuple = get_format_version_tuple(cls.implemented_format_version)
                assert (
                    fv_tuple is not None
                ), f"failed to cast '{cls.implemented_format_version}' to tuple"
                cls.implemented_format_version_tuple = fv_tuple

    @classmethod
    def load(
        cls, data: BioimageioYamlContent, context: Optional[ValidationContext] = None
    ) -> Union[Self, InvalidDescr]:
        """factory method to create a resource description object

        Args:
            data: raw bioimageio.yaml content (must be a dict); it is deep-copied
                before validation, so the caller's dict is not modified.
            context: validation context to validate in; defaults to the current
                `get_validation_context()`.

        Returns:
            An instance of `cls`, or an `InvalidDescr` if validation failed. In
            both cases a format validation detail (errors and warnings) is added
            to its `validation_summary`.

        Raises:
            AssertionError: if `data` is not a dict.
            Exception: validation errors propagate if `context.raise_errors` is set.
        """
        if context is None:
            context = get_validation_context()

        assert isinstance(data, dict)
        with context:
            rd, errors, val_warnings = cls._load_impl(deepcopy(data))

        if context.warning_level > INFO:
            # raise all validation warnings by reloading
            # (presumably warnings below `context.warning_level` are not reported
            # by the first pass; logging is disabled to avoid duplicate output)
            all_warnings_context = context.replace(
                warning_level=INFO, log_warnings=False
            )
            with all_warnings_context:
                _, _, val_warnings = cls._load_impl(deepcopy(data))

        rd.validation_summary.add_detail(
            ValidationDetail(
                errors=errors,
                name=(
                    "bioimageio.spec format validation"
                    f" {rd.type} {cls.implemented_format_version}"
                ),
                status="failed" if errors else "passed",
                warnings=val_warnings,
                context=context.summary,  # context for format validation detail is identical
            )
        )

        return rd

    @classmethod
    def _load_impl(
        cls, data: BioimageioYamlContent
    ) -> Tuple[Union[Self, InvalidDescr], List[ErrorEntry], List[WarningEntry]]:
        """Validate `data` as `cls`, falling back to `InvalidDescr` on failure.

        Uses the currently active validation context.

        Returns:
            A tuple `(description, errors, warnings)`. `description` is an
            instance of `cls` on success, otherwise an `InvalidDescr`.

        Raises:
            Exception: validation errors are re-raised if the active context has
                `raise_errors` set (warning-level entries do not trigger this).
        """
        rd: Union[Self, InvalidDescr, None] = None
        val_errors: List[ErrorEntry] = []
        val_warnings: List[WarningEntry] = []

        context = get_validation_context()
        try:
            rd = cls.model_validate(data)
        except pydantic.ValidationError as e:
            for ee in e.errors(include_url=False):
                # entries whose ctx carries a severity below ERROR are warnings
                if (severity := ee.get("ctx", {}).get("severity", ERROR)) < ERROR:
                    val_warnings.append(
                        WarningEntry(
                            loc=ee["loc"],
                            msg=ee["msg"],
                            type=ee["type"],
                            severity=severity,
                        )
                    )
                elif context.raise_errors:
                    raise
                else:
                    val_errors.append(
                        ErrorEntry(loc=ee["loc"], msg=ee["msg"], type=ee["type"])
                    )

            if not val_errors:
                # validation failed only because of warning-level entries; record
                # an error so the failure is not reported as "passed"
                val_errors.append(
                    ErrorEntry(
                        loc=(),
                        msg=(
                            f"Encountered {len(val_warnings)} warnings more severe"
                            " than warning level "
                            f"'{WARNING_LEVEL_TO_NAME[context.warning_level]}'"
                        ),
                        type="severe_warnings",
                    )
                )
        except Exception as e:
            if context.raise_errors:
                raise

            val_errors.append(
                ErrorEntry(
                    loc=(),
                    msg=str(e),
                    type=type(e).__name__,
                    with_traceback=True,
                )
            )

        if rd is None:
            try:
                rd = InvalidDescr.model_validate(data)
            except Exception:
                if context.raise_errors:
                    raise
                # `type` is a required field without a default (fields with an
                # `implemented_*` counterpart must not have one, see
                # NodeWithExplicitlySetFields), so its `.default` would be
                # PydanticUndefined; use the implemented class values instead.
                rd = InvalidDescr(
                    type=cls.implemented_type,
                    format_version=cls.implemented_format_version,
                )
                # NOTE(review): cannot trigger, `raise_errors` already re-raised
                # above; a warnings-only failure therefore returns an InvalidDescr
                # even with `raise_errors` set — confirm that is intended.
                if context.raise_errors:
                    raise ValueError(rd)

        return rd, val_errors, val_warnings

    def package(
        self, dest: Optional[Union[ZipFile, IO[bytes], Path, str]] = None, /
    ) -> ZipFile:
        """package the described resource as a zip archive

        Args:
            dest: (path/bytes stream of) destination zipfile. An in-memory
                `BytesIO` is used if omitted; a given `ZipFile` must be writable.

        Returns:
            The `ZipFile` the package content was written to. It is not closed here.

        Raises:
            ValueError: if `dest` is a `ZipFile` opened in read mode.
        """
        if dest is None:
            dest = BytesIO()

        if isinstance(dest, ZipFile):
            zip_file = dest
            if "r" in zip_file.mode:
                raise ValueError(
                    f"zip file {dest} opened in '{zip_file.mode}' mode,"
                    + " but write access is needed for packaging."
                )
        else:
            zip_file = ZipFile(dest, mode="w")

        if zip_file.filename is None:
            # streams have no file name: derive one from `id`, then `name`,
            # skipping missing or empty/None values (which would give "None.zip")
            stem = (
                getattr(self, "id", None) or getattr(self, "name", None) or "bioimageio"
            )
            zip_file.filename = f"{stem}.zip"

        content = self.get_package_content()
        write_content_to_zip(content, zip_file)
        return zip_file

    def get_package_content(
        self,
    ) -> Dict[
        FileName, Union[HttpUrl, AbsoluteFilePath, BioimageioYamlContent, ZipPath]
    ]:
        """Returns package content without creating the package.

        The result maps file names to their sources plus `BIOIMAGEIO_YAML` to the
        serialized description (only explicitly set fields, without `rdf_source`).
        """
        # presumably filled by field serializers during `model_dump` while the
        # PackagingContext is active
        content: Dict[FileName, Union[HttpUrl, AbsoluteFilePath, ZipPath]] = {}
        with PackagingContext(
            bioimageio_yaml_file_name=BIOIMAGEIO_YAML,
            file_sources=content,
        ):
            rdf_content: BioimageioYamlContent = self.model_dump(
                mode="json", exclude_unset=True
            )

        _ = rdf_content.pop("rdf_source", None)

        return {**content, BIOIMAGEIO_YAML: rdf_content}

354 

355 

class InvalidDescr(
    ResourceDescrBase,
    extra="allow",
    title="An invalid resource description",
):
    """A representation of an invalid resource description.

    Extra keys are allowed, so arbitrary (invalid) content can be kept.
    """

    # values injected for missing `type`/`format_version` input keys
    # (see NodeWithExplicitlySetFields)
    implemented_type: ClassVar[Literal["unknown"]] = "unknown"
    implemented_format_version: ClassVar[Literal["unknown"]] = "unknown"

    if TYPE_CHECKING:
        # static view: both fields look optional with "unknown" as default
        type: Any = "unknown"
        format_version: Any = "unknown"
    else:
        # runtime: no pydantic default; the before-validator of
        # NodeWithExplicitlySetFields supplies the implemented values instead
        type: Any
        format_version: Any

374 

375 

class KwargsNode(Node):
    """A `Node` that also offers read-only, dict-like access to its declared fields.

    Only names listed in the model's `model_fields` are reachable by key.
    """

    def get(self, item: str, default: Any = None) -> Any:
        """Return the value of field `item`, or `default` if it is not a model field."""
        return self[item] if item in self else default

    def __getitem__(self, item: str) -> Any:
        """Return the value of field `item`.

        Raises:
            KeyError: if `item` is not a model field.
        """
        # class-level access: reading `model_fields` on an instance is deprecated
        # in recent pydantic versions
        if item in type(self).model_fields:
            return getattr(self, item)
        else:
            raise KeyError(item)

    def __contains__(self, item: str) -> bool:
        """Whether `item` names a model field (whether or not it was set)."""
        return item in type(self).model_fields