Coverage for bioimageio/spec/model/v0_4.py: 90%

595 statements  

coverage.py v7.8.0, created at 2025-04-02 14:21 +0000

from __future__ import annotations

import collections.abc
from typing import (
    TYPE_CHECKING,
    Any,
    ClassVar,
    Dict,
    List,
    Literal,
    Optional,
    Sequence,
    Tuple,
    Type,
    Union,
)

import numpy as np
from annotated_types import Ge, Interval, MaxLen, MinLen, MultipleOf
from numpy.typing import NDArray
from pydantic import (
    AllowInfNan,
    Discriminator,
    Field,
    RootModel,
    SerializationInfo,
    SerializerFunctionWrapHandler,
    StringConstraints,
    TypeAdapter,
    ValidationInfo,
    WrapSerializer,
    field_validator,
    model_validator,
)
from typing_extensions import Annotated, Self, assert_never, get_args

from .._internal.common_nodes import (
    KwargsNode,
    Node,
    NodeWithExplicitlySetFields,
)
from .._internal.constants import SHA256_HINT
from .._internal.field_validation import validate_unique_entries
from .._internal.field_warning import issue_warning, warn
from .._internal.io import (
    BioimageioYamlContent,
    WithSuffix,
    download,
    include_in_package_serializer,
)
from .._internal.io import FileDescr as FileDescr
from .._internal.io_basics import AbsoluteFilePath as AbsoluteFilePath
from .._internal.io_basics import Sha256 as Sha256
from .._internal.io_utils import load_array
from .._internal.packaging_context import packaging_context_var
from .._internal.types import Datetime as Datetime
from .._internal.types import Identifier as Identifier
from .._internal.types import ImportantFileSource, LowerCaseIdentifier
from .._internal.types import LicenseId as LicenseId
from .._internal.types import NotEmpty as NotEmpty
from .._internal.url import HttpUrl as HttpUrl
from .._internal.validated_string_with_inner_node import ValidatedStringWithInnerNode
from .._internal.validator_annotations import AfterValidator, RestrictCharacters
from .._internal.version_type import Version as Version
from .._internal.warning_levels import ALERT, INFO
from ..dataset.v0_2 import VALID_COVER_IMAGE_EXTENSIONS as VALID_COVER_IMAGE_EXTENSIONS
from ..dataset.v0_2 import DatasetDescr as DatasetDescr
from ..dataset.v0_2 import LinkedDataset as LinkedDataset
from ..generic.v0_2 import AttachmentsDescr as AttachmentsDescr
from ..generic.v0_2 import Author as Author
from ..generic.v0_2 import BadgeDescr as BadgeDescr
from ..generic.v0_2 import CiteEntry as CiteEntry
from ..generic.v0_2 import Doi as Doi
from ..generic.v0_2 import GenericModelDescrBase
from ..generic.v0_2 import LinkedResource as LinkedResource
from ..generic.v0_2 import Maintainer as Maintainer
from ..generic.v0_2 import OrcidId as OrcidId
from ..generic.v0_2 import RelativeFilePath as RelativeFilePath
from ..generic.v0_2 import ResourceId as ResourceId
from ..generic.v0_2 import Uploader as Uploader
from ._v0_4_converter import convert_from_older_format

class ModelId(ResourceId):
    pass


AxesStr = Annotated[
    str, RestrictCharacters("bitczyx"), AfterValidator(validate_unique_entries)
]
AxesInCZYX = Annotated[
    str, RestrictCharacters("czyx"), AfterValidator(validate_unique_entries)
]

PostprocessingName = Literal[
    "binarize",
    "clip",
    "scale_linear",
    "sigmoid",
    "zero_mean_unit_variance",
    "scale_range",
    "scale_mean_variance",
]
PreprocessingName = Literal[
    "binarize",
    "clip",
    "scale_linear",
    "sigmoid",
    "zero_mean_unit_variance",
    "scale_range",
]


class TensorName(LowerCaseIdentifier):
    pass

class CallableFromDepencencyNode(Node):
    _submodule_adapter: ClassVar[TypeAdapter[Identifier]] = TypeAdapter(Identifier)

    module_name: str
    """The Python module that implements **callable_name**."""

    @field_validator("module_name", mode="after")
    def _check_submodules(cls, module_name: str) -> str:
        for submod in module_name.split("."):
            _ = cls._submodule_adapter.validate_python(submod)

        return module_name

    callable_name: Identifier
    """The callable Python identifier implemented in module **module_name**."""


class CallableFromDepencency(ValidatedStringWithInnerNode[CallableFromDepencencyNode]):
    _inner_node_class = CallableFromDepencencyNode
    root_model: ClassVar[Type[RootModel[Any]]] = RootModel[
        Annotated[
            str,
            StringConstraints(strip_whitespace=True, pattern=r"^.+\..+$"),
        ]
    ]

    @classmethod
    def _get_data(cls, valid_string_data: str):
        *mods, callname = valid_string_data.split(".")
        return dict(module_name=".".join(mods), callable_name=callname)

    @property
    def module_name(self):
        """The Python module that implements **callable_name**."""
        return self._inner_node.module_name

    @property
    def callable_name(self):
        """The callable Python identifier implemented in module **module_name**."""
        return self._inner_node.callable_name
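# Example (illustrative sketch, not part of the original source): a dotted
# import path is split on "." so that everything before the last dot becomes
# the module and the final segment the callable:
#
#     c = CallableFromDepencency("my_module.submodule.get_my_model")
#     c.module_name    # -> "my_module.submodule"
#     c.callable_name  # -> "get_my_model"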

class CallableFromFileNode(Node):
    source_file: Annotated[
        Union[RelativeFilePath, HttpUrl],
        Field(union_mode="left_to_right"),
        include_in_package_serializer,
    ]
    """The Python source file that implements **callable_name**."""
    callable_name: Identifier
    """The callable Python identifier implemented in **source_file**."""


class CallableFromFile(ValidatedStringWithInnerNode[CallableFromFileNode]):
    _inner_node_class = CallableFromFileNode
    root_model: ClassVar[Type[RootModel[Any]]] = RootModel[
        Annotated[
            str,
            StringConstraints(strip_whitespace=True, pattern=r"^.+:.+$"),
        ]
    ]

    @classmethod
    def _get_data(cls, valid_string_data: str):
        *file_parts, callname = valid_string_data.split(":")
        return dict(source_file=":".join(file_parts), callable_name=callname)

    @property
    def source_file(self):
        """The Python source file that implements **callable_name**."""
        return self._inner_node.source_file

    @property
    def callable_name(self):
        """The callable Python identifier implemented in **source_file**."""
        return self._inner_node.callable_name


CustomCallable = Annotated[
    Union[CallableFromFile, CallableFromDepencency], Field(union_mode="left_to_right")
]
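# Example (illustrative sketch): splitting on ":" and re-joining all but the
# last segment keeps URL schemes intact:
#
#     c = CallableFromFile("my_function.py:MyNetworkClass")
#     c.callable_name  # -> "MyNetworkClass"
#     c.source_file    # -> the file "my_function.py" (as a RelativeFilePath)
#
#     c = CallableFromFile("https://example.com/arch.py:MyNet")
#     c.source_file    # -> "https://example.com/arch.py" (the "https:" part survives)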

class DependenciesNode(Node):
    manager: Annotated[NotEmpty[str], Field(examples=["conda", "maven", "pip"])]
    """Dependency manager"""

    file: ImportantFileSource
    """Dependency file"""


class Dependencies(ValidatedStringWithInnerNode[DependenciesNode]):
    _inner_node_class = DependenciesNode
    root_model: ClassVar[Type[RootModel[Any]]] = RootModel[
        Annotated[
            str,
            StringConstraints(strip_whitespace=True, pattern=r"^.+:.+$"),
        ]
    ]

    @classmethod
    def _get_data(cls, valid_string_data: str):
        manager, *file_parts = valid_string_data.split(":")
        return dict(manager=manager, file=":".join(file_parts))

    @property
    def manager(self):
        """Dependency manager"""
        return self._inner_node.manager

    @property
    def file(self):
        """Dependency file"""
        return self._inner_node.file


WeightsFormat = Literal[
    "keras_hdf5",
    "onnx",
    "pytorch_state_dict",
    "tensorflow_js",
    "tensorflow_saved_model_bundle",
    "torchscript",
]
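# Example (illustrative sketch): here the *first* ":"-separated segment is the
# manager and the remainder is the dependency file reference:
#
#     deps = Dependencies("conda:environment.yaml")
#     deps.manager  # -> "conda"
#     deps.file     # -> "environment.yaml" (as a validated file source)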

class WeightsDescr(Node):
    keras_hdf5: Optional[KerasHdf5WeightsDescr] = None
    onnx: Optional[OnnxWeightsDescr] = None
    pytorch_state_dict: Optional[PytorchStateDictWeightsDescr] = None
    tensorflow_js: Optional[TensorflowJsWeightsDescr] = None
    tensorflow_saved_model_bundle: Optional[TensorflowSavedModelBundleWeightsDescr] = (
        None
    )
    torchscript: Optional[TorchscriptWeightsDescr] = None

    @model_validator(mode="after")
    def check_one_entry(self) -> Self:
        if all(
            entry is None
            for entry in [
                self.keras_hdf5,
                self.onnx,
                self.pytorch_state_dict,
                self.tensorflow_js,
                self.tensorflow_saved_model_bundle,
                self.torchscript,
            ]
        ):
            raise ValueError("Missing weights entry")

        return self

    def __getitem__(
        self,
        key: WeightsFormat,
    ):
        if key == "keras_hdf5":
            ret = self.keras_hdf5
        elif key == "onnx":
            ret = self.onnx
        elif key == "pytorch_state_dict":
            ret = self.pytorch_state_dict
        elif key == "tensorflow_js":
            ret = self.tensorflow_js
        elif key == "tensorflow_saved_model_bundle":
            ret = self.tensorflow_saved_model_bundle
        elif key == "torchscript":
            ret = self.torchscript
        else:
            raise KeyError(key)

        if ret is None:
            raise KeyError(key)

        return ret

    @property
    def available_formats(self):
        return {
            **({} if self.keras_hdf5 is None else {"keras_hdf5": self.keras_hdf5}),
            **({} if self.onnx is None else {"onnx": self.onnx}),
            **(
                {}
                if self.pytorch_state_dict is None
                else {"pytorch_state_dict": self.pytorch_state_dict}
            ),
            **(
                {}
                if self.tensorflow_js is None
                else {"tensorflow_js": self.tensorflow_js}
            ),
            **(
                {}
                if self.tensorflow_saved_model_bundle is None
                else {
                    "tensorflow_saved_model_bundle": self.tensorflow_saved_model_bundle
                }
            ),
            **({} if self.torchscript is None else {"torchscript": self.torchscript}),
        }

    @property
    def missing_formats(self):
        return {
            wf for wf in get_args(WeightsFormat) if wf not in self.available_formats
        }
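# Example (illustrative sketch; the field values are hypothetical): at least
# one weights entry must be set, and lookups behave like a mapping:
#
#     wd = WeightsDescr(
#         onnx=OnnxWeightsDescr(source="weights.onnx", opset_version=15)
#     )
#     wd["onnx"] is wd.onnx                # -> True
#     sorted(wd.available_formats)         # -> ["onnx"]
#     "torchscript" in wd.missing_formats  # -> True
#     WeightsDescr()                       # raises: "Missing weights entry"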

class WeightsEntryDescrBase(FileDescr):
    type: ClassVar[WeightsFormat]
    weights_format_name: ClassVar[str]  # human readable

    source: ImportantFileSource
    """The weights file."""

    attachments: Annotated[
        Union[AttachmentsDescr, None],
        warn(None, "Weights entry depends on additional attachments.", ALERT),
    ] = None
    """Attachments that are specific to this weights entry."""

    authors: Union[List[Author], None] = None
    """Authors
    Either the person(s) that have trained this model resulting in the original weights file.
    (If this is the initial weights entry, i.e. it does not have a `parent`.)
    Or the person(s) who have converted the weights to this weights format.
    (If this is a child weights entry, i.e. it has a `parent` field.)
    """

    dependencies: Annotated[
        Optional[Dependencies],
        warn(
            None,
            "Custom dependencies ({value}) specified. Avoid this whenever possible "
            + "to allow execution in a wider range of software environments.",
        ),
        Field(
            examples=[
                "conda:environment.yaml",
                "maven:./pom.xml",
                "pip:./requirements.txt",
            ]
        ),
    ] = None
    """Dependency manager and dependency file, specified as `<dependency manager>:<relative file path>`."""

    parent: Annotated[
        Optional[WeightsFormat], Field(examples=["pytorch_state_dict"])
    ] = None
    """The source weights these weights were converted from.
    For example, if a model's weights were converted from the `pytorch_state_dict` format to `torchscript`,
    the `pytorch_state_dict` weights entry has no `parent` and is the parent of the `torchscript` weights.
    All weight entries except one (the initial set of weights resulting from training the model)
    need to have this field."""

    @model_validator(mode="after")
    def check_parent_is_not_self(self) -> Self:
        if self.type == self.parent:
            raise ValueError("Weights entry can't be its own parent.")

        return self

class KerasHdf5WeightsDescr(WeightsEntryDescrBase):
    type = "keras_hdf5"
    weights_format_name: ClassVar[str] = "Keras HDF5"
    tensorflow_version: Optional[Version] = None
    """TensorFlow version used to create these weights"""

    @field_validator("tensorflow_version", mode="after")
    @classmethod
    def _tfv(cls, value: Any):
        if value is None:
            issue_warning(
                "missing. Please specify the TensorFlow version"
                + " these weights were created with.",
                value=value,
                severity=ALERT,
                field="tensorflow_version",
            )
        return value


class OnnxWeightsDescr(WeightsEntryDescrBase):
    type = "onnx"
    weights_format_name: ClassVar[str] = "ONNX"
    opset_version: Optional[Annotated[int, Ge(7)]] = None
    """ONNX opset version"""

    @field_validator("opset_version", mode="after")
    @classmethod
    def _ov(cls, value: Any):
        if value is None:
            issue_warning(
                "Missing ONNX opset version (aka ONNX opset number). "
                + "Please specify the ONNX opset version these weights were created"
                + " with.",
                value=value,
                severity=ALERT,
                field="opset_version",
            )
        return value

class PytorchStateDictWeightsDescr(WeightsEntryDescrBase):
    type = "pytorch_state_dict"
    weights_format_name: ClassVar[str] = "Pytorch State Dict"
    architecture: CustomCallable = Field(
        examples=["my_function.py:MyNetworkClass", "my_module.submodule.get_my_model"]
    )
    """callable returning a torch.nn.Module instance.
    Local implementation: `<relative path to file>:<identifier of implementation within the file>`.
    Implementation in a dependency: `<dependency-package>.<[dependency-module]>.<identifier>`."""

    architecture_sha256: Annotated[
        Optional[Sha256],
        Field(
            description=(
                "The SHA256 of the architecture source file, if the architecture is not"
                " defined in a module listed in `dependencies`\n"
            )
            + SHA256_HINT,
        ),
    ] = None
    """The SHA256 of the architecture source file,
    if the architecture is not defined in a module listed in `dependencies`"""

    @model_validator(mode="after")
    def check_architecture_sha256(self) -> Self:
        if isinstance(self.architecture, CallableFromFile):
            if self.architecture_sha256 is None:
                raise ValueError(
                    "Missing required `architecture_sha256` for `architecture` with"
                    + " source file."
                )
        elif self.architecture_sha256 is not None:
            raise ValueError(
                "Got `architecture_sha256` for architecture that does not have a source"
                + " file."
            )

        return self

    kwargs: Dict[str, Any] = Field(default_factory=dict)
    """key word arguments for the `architecture` callable"""

    pytorch_version: Optional[Version] = None
    """Version of the PyTorch library used.
    If `dependencies` is specified it should include pytorch and the version has to match.
    (`dependencies` overrules `pytorch_version`)"""

    @field_validator("pytorch_version", mode="after")
    @classmethod
    def _ptv(cls, value: Any):
        if value is None:
            issue_warning(
                "missing. Please specify the PyTorch version these"
                + " PyTorch state dict weights were created with.",
                value=value,
                severity=ALERT,
                field="pytorch_version",
            )
        return value

class TorchscriptWeightsDescr(WeightsEntryDescrBase):
    type = "torchscript"
    weights_format_name: ClassVar[str] = "TorchScript"
    pytorch_version: Optional[Version] = None
    """Version of the PyTorch library used."""

    @field_validator("pytorch_version", mode="after")
    @classmethod
    def _ptv(cls, value: Any):
        if value is None:
            issue_warning(
                "missing. Please specify the PyTorch version these"
                + " Torchscript weights were created with.",
                value=value,
                severity=ALERT,
                field="pytorch_version",
            )
        return value


class TensorflowJsWeightsDescr(WeightsEntryDescrBase):
    type = "tensorflow_js"
    weights_format_name: ClassVar[str] = "Tensorflow.js"
    tensorflow_version: Optional[Version] = None
    """Version of the TensorFlow library used."""

    @field_validator("tensorflow_version", mode="after")
    @classmethod
    def _tfv(cls, value: Any):
        if value is None:
            issue_warning(
                "missing. Please specify the TensorFlow version"
                + " these TensorflowJs weights were created with.",
                value=value,
                severity=ALERT,
                field="tensorflow_version",
            )
        return value

    source: ImportantFileSource
    """The multi-file weights.
    All required files/folders should be packaged in a zip archive."""


class TensorflowSavedModelBundleWeightsDescr(WeightsEntryDescrBase):
    type = "tensorflow_saved_model_bundle"
    weights_format_name: ClassVar[str] = "Tensorflow Saved Model"
    tensorflow_version: Optional[Version] = None
    """Version of the TensorFlow library used."""

    @field_validator("tensorflow_version", mode="after")
    @classmethod
    def _tfv(cls, value: Any):
        if value is None:
            issue_warning(
                "missing. Please specify the TensorFlow version"
                + " these Tensorflow saved model bundle weights were created with.",
                value=value,
                severity=ALERT,
                field="tensorflow_version",
            )
        return value

class ParameterizedInputShape(Node):
    """A sequence of valid shapes given by `shape_k = min + k * step for k in {0, 1, ...}`."""

    min: NotEmpty[List[int]]
    """The minimum input shape"""

    step: NotEmpty[List[int]]
    """The minimum shape change"""

    def __len__(self) -> int:
        return len(self.min)

    @model_validator(mode="after")
    def matching_lengths(self) -> Self:
        if len(self.min) != len(self.step):
            raise ValueError("`min` and `step` are required to have the same length")

        return self
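# Example (illustrative sketch): valid shapes are min + k * step, elementwise,
# exactly as the class docstring states:
#
#     s = ParameterizedInputShape(min=[1, 64, 64, 1], step=[0, 32, 32, 0])
#     # k=0 -> [1, 64, 64, 1]; k=1 -> [1, 96, 96, 1]; k=2 -> [1, 128, 128, 1]; ...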

class ImplicitOutputShape(Node):
    """Output tensor shape depending on an input tensor shape.
    `shape(output_tensor) = shape(input_tensor) * scale + 2 * offset`"""

    reference_tensor: TensorName
    """Name of the reference tensor."""

    scale: NotEmpty[List[Optional[float]]]
    """output_pix/input_pix for each dimension.
    'null' values indicate new dimensions, whose length is defined by 2*`offset`"""

    offset: NotEmpty[List[Union[int, Annotated[float, MultipleOf(0.5)]]]]
    """Position of origin with respect to the input."""

    def __len__(self) -> int:
        return len(self.scale)

    @model_validator(mode="after")
    def matching_lengths(self) -> Self:
        if len(self.scale) != len(self.offset):
            raise ValueError(
                f"scale {self.scale} has to have same length as offset {self.offset}!"
            )
        # if we have an expanded dimension, make sure that its offset is not zero
        for sc, off in zip(self.scale, self.offset):
            if sc is None and not off:
                raise ValueError("`offset` must not be zero if `scale` is None")

        return self
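# Example (illustrative sketch; tensor names are hypothetical): an output that
# is half the reference input's spatial size, padded by an offset:
#
#     s = ImplicitOutputShape(
#         reference_tensor="raw", scale=[1.0, 1.0, 0.5, 0.5], offset=[0, 0, 4, 4]
#     )
#     # for a reference shape [1, 1, 128, 128]:
#     # output shape = shape * scale + 2 * offset = [1, 1, 72, 72]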

class TensorDescrBase(Node):
    name: TensorName
    """Tensor name. No duplicates are allowed."""

    description: str = ""

    axes: AxesStr
    """Axes identifying characters. Same length and order as the axes in `shape`.
    | axis | description |
    | --- | --- |
    | b | batch (groups multiple samples) |
    | i | instance/index/element |
    | t | time |
    | c | channel |
    | z | spatial dimension z |
    | y | spatial dimension y |
    | x | spatial dimension x |
    """

    data_range: Optional[
        Tuple[Annotated[float, AllowInfNan(True)], Annotated[float, AllowInfNan(True)]]
    ] = None
    """Tuple `(minimum, maximum)` specifying the allowed range of the data in this tensor.
    If not specified, the full data range that can be expressed in `data_type` is allowed."""

class ProcessingKwargs(KwargsNode):
    """base class for pre-/postprocessing key word arguments"""


class ProcessingDescrBase(NodeWithExplicitlySetFields):
    """processing base class"""


class BinarizeKwargs(ProcessingKwargs):
    """key word arguments for `BinarizeDescr`"""

    threshold: float
    """The fixed threshold"""


class BinarizeDescr(ProcessingDescrBase):
    """Binarize the tensor with a fixed `BinarizeKwargs.threshold`.
    Values above the threshold will be set to one, values below the threshold to zero.
    """

    implemented_name: ClassVar[Literal["binarize"]] = "binarize"
    if TYPE_CHECKING:
        name: Literal["binarize"] = "binarize"
    else:
        name: Literal["binarize"]

    kwargs: BinarizeKwargs


class ClipKwargs(ProcessingKwargs):
    """key word arguments for `ClipDescr`"""

    min: float
    """minimum value for clipping"""
    max: float
    """maximum value for clipping"""


class ClipDescr(ProcessingDescrBase):
    """Clip tensor values to a range.

    Set tensor values below `ClipKwargs.min` to `ClipKwargs.min`
    and above `ClipKwargs.max` to `ClipKwargs.max`.
    """

    implemented_name: ClassVar[Literal["clip"]] = "clip"
    if TYPE_CHECKING:
        name: Literal["clip"] = "clip"
    else:
        name: Literal["clip"]

    kwargs: ClipKwargs

class ScaleLinearKwargs(ProcessingKwargs):
    """key word arguments for `ScaleLinearDescr`"""

    axes: Annotated[Optional[AxesInCZYX], Field(examples=["xy"])] = None
    """The subset of axes to scale jointly.
    For example xy to scale the two image axes for 2d data jointly."""

    gain: Union[float, List[float]] = 1.0
    """multiplicative factor"""

    offset: Union[float, List[float]] = 0.0
    """additive term"""

    @model_validator(mode="after")
    def either_gain_or_offset(self) -> Self:
        if (
            self.gain == 1.0
            or isinstance(self.gain, list)
            and all(g == 1.0 for g in self.gain)
        ) and (
            self.offset == 0.0
            or isinstance(self.offset, list)
            and all(off == 0.0 for off in self.offset)
        ):
            raise ValueError(
                "Redundant linear scaling not allowed. Set `gain` != 1.0 and/or"
                + " `offset` != 0.0."
            )

        return self


class ScaleLinearDescr(ProcessingDescrBase):
    """Fixed linear scaling."""

    implemented_name: ClassVar[Literal["scale_linear"]] = "scale_linear"
    if TYPE_CHECKING:
        name: Literal["scale_linear"] = "scale_linear"
    else:
        name: Literal["scale_linear"]

    kwargs: ScaleLinearDescr
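# Example (illustrative sketch): scale_linear applies `out = gain * tensor +
# offset` (assuming the standard interpretation of the gain/offset fields
# above); an identity scaling is rejected by the validator:
#
#     ScaleLinearKwargs(axes="xy", gain=2.0, offset=0.5)  # ok
#     ScaleLinearKwargs()  # raises: redundant linear scaling not allowed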

class SigmoidDescr(ProcessingDescrBase):
    """The logistic sigmoid function, a.k.a. expit function."""

    implemented_name: ClassVar[Literal["sigmoid"]] = "sigmoid"
    if TYPE_CHECKING:
        name: Literal["sigmoid"] = "sigmoid"
    else:
        name: Literal["sigmoid"]

    @property
    def kwargs(self) -> ProcessingKwargs:
        """empty kwargs"""
        return ProcessingKwargs()

class ZeroMeanUnitVarianceKwargs(ProcessingKwargs):
    """key word arguments for `ZeroMeanUnitVarianceDescr`"""

    mode: Literal["fixed", "per_dataset", "per_sample"] = "fixed"
    """Mode for computing mean and variance.
    | mode | description |
    | ----------- | ------------------------------------ |
    | fixed | Fixed values for mean and variance |
    | per_dataset | Compute for the entire dataset |
    | per_sample | Compute for each sample individually |
    """
    axes: Annotated[AxesInCZYX, Field(examples=["xy"])]
    """The subset of axes to normalize jointly.
    For example `xy` to normalize the two image axes for 2d data jointly."""

    mean: Annotated[
        Union[float, NotEmpty[List[float]], None], Field(examples=[(1.1, 2.2, 3.3)])
    ] = None
    """The mean value(s) to use for `mode: fixed`.
    For example `[1.1, 2.2, 3.3]` in the case of a 3 channel image with `axes: xy`."""
    # todo: check if means match input axes (for mode 'fixed')

    std: Annotated[
        Union[float, NotEmpty[List[float]], None], Field(examples=[(0.1, 0.2, 0.3)])
    ] = None
    """The standard deviation values to use for `mode: fixed`. Analogous to mean."""

    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
    """epsilon for numeric stability: `out = (tensor - mean) / (std + eps)`."""

    @model_validator(mode="after")
    def mean_and_std_match_mode(self) -> Self:
        if self.mode == "fixed" and (self.mean is None or self.std is None):
            raise ValueError("`mean` and `std` are required for `mode: fixed`.")
        elif self.mode != "fixed" and (self.mean is not None or self.std is not None):
            raise ValueError(f"`mean` and `std` not allowed for `mode: {self.mode}`")

        return self


class ZeroMeanUnitVarianceDescr(ProcessingDescrBase):
    """Subtract mean and divide by standard deviation (plus `eps`)."""

    implemented_name: ClassVar[Literal["zero_mean_unit_variance"]] = (
        "zero_mean_unit_variance"
    )
    if TYPE_CHECKING:
        name: Literal["zero_mean_unit_variance"] = "zero_mean_unit_variance"
    else:
        name: Literal["zero_mean_unit_variance"]

    kwargs: ZeroMeanUnitVarianceKwargs
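# Example (illustrative sketch): per-sample normalization as defined by the
# `eps` docstring above, expressed with numpy:
#
#     import numpy as np
#     x = np.array([2.0, 4.0, 6.0])
#     out = (x - x.mean()) / (x.std() + 1e-6)
#     # out ~ [-1.2247, 0.0, 1.2247]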

class ScaleRangeKwargs(ProcessingKwargs):
    """key word arguments for `ScaleRangeDescr`

    For `min_percentile`=0.0 (the default) and `max_percentile`=100 (the default)
    this processing step normalizes data to the [0, 1] interval.
    For other percentiles the normalized values will partially be outside the [0, 1]
    interval. Use `ScaleRangeDescr` followed by `ClipDescr` if you want to limit the
    normalized values to a range.
    """

    mode: Literal["per_dataset", "per_sample"]
    """Mode for computing percentiles.
    | mode | description |
    | ----------- | ------------------------------------ |
    | per_dataset | compute for the entire dataset |
    | per_sample | compute for each sample individually |
    """
    axes: Annotated[AxesInCZYX, Field(examples=["xy"])]
    """The subset of axes to normalize jointly.
    For example xy to normalize the two image axes for 2d data jointly."""

    min_percentile: Annotated[Union[int, float], Interval(ge=0, lt=100)] = 0.0
    """The lower percentile used to determine the value to align with zero."""

    max_percentile: Annotated[Union[int, float], Interval(gt=1, le=100)] = 100.0
    """The upper percentile used to determine the value to align with one.
    Has to be bigger than `min_percentile`.
    The range is 1 to 100 instead of 0 to 100 to avoid mistakenly
    accepting percentiles specified in the range 0.0 to 1.0."""

    @model_validator(mode="after")
    def min_smaller_max(self, info: ValidationInfo) -> Self:
        if self.min_percentile >= self.max_percentile:
            raise ValueError(
                f"min_percentile {self.min_percentile} >= max_percentile"
                + f" {self.max_percentile}"
            )

        return self

    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
    """Epsilon for numeric stability.
    `out = (tensor - v_lower) / (v_upper - v_lower + eps)`;
    with `v_lower,v_upper` values at the respective percentiles."""

    reference_tensor: Optional[TensorName] = None
    """Tensor name to compute the percentiles from. Default: The tensor itself.
    For any tensor in `inputs` only input tensor references are allowed.
    For a tensor in `outputs` only input tensor references are allowed if `mode: per_dataset`"""


class ScaleRangeDescr(ProcessingDescrBase):
    """Scale with percentiles."""

    implemented_name: ClassVar[Literal["scale_range"]] = "scale_range"
    if TYPE_CHECKING:
        name: Literal["scale_range"] = "scale_range"
    else:
        name: Literal["scale_range"]

    kwargs: ScaleRangeKwargs
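# Example (illustrative sketch): a per-sample scale_range step following the
# `eps` docstring above, expressed with numpy:
#
#     import numpy as np
#     x = np.random.rand(64, 64)
#     v_lower, v_upper = np.percentile(x, [1.0, 99.8])
#     out = (x - v_lower) / (v_upper - v_lower + 1e-6)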

class ScaleMeanVarianceKwargs(ProcessingKwargs):
    """key word arguments for `ScaleMeanVarianceDescr`"""

    mode: Literal["per_dataset", "per_sample"]
    """Mode for computing mean and variance.
    | mode | description |
    | ----------- | ------------------------------------ |
    | per_dataset | Compute for the entire dataset |
    | per_sample | Compute for each sample individually |
    """

    reference_tensor: TensorName
    """Name of tensor to match."""

    axes: Annotated[Optional[AxesInCZYX], Field(examples=["xy"])] = None
    """The subset of axes to scale jointly.
    For example xy to normalize the two image axes for 2d data jointly.
    Default: scale all non-batch axes jointly."""

    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
    """Epsilon for numeric stability:
    `out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean`."""


class ScaleMeanVarianceDescr(ProcessingDescrBase):
    """Scale the tensor s.t. its mean and variance match a reference tensor."""

    implemented_name: ClassVar[Literal["scale_mean_variance"]] = "scale_mean_variance"
    if TYPE_CHECKING:
        name: Literal["scale_mean_variance"] = "scale_mean_variance"
    else:
        name: Literal["scale_mean_variance"]

    kwargs: ScaleMeanVarianceKwargs

PreprocessingDescr = Annotated[
    Union[
        BinarizeDescr,
        ClipDescr,
        ScaleLinearDescr,
        SigmoidDescr,
        ZeroMeanUnitVarianceDescr,
        ScaleRangeDescr,
    ],
    Discriminator("name"),
]
PostprocessingDescr = Annotated[
    Union[
        BinarizeDescr,
        ClipDescr,
        ScaleLinearDescr,
        SigmoidDescr,
        ZeroMeanUnitVarianceDescr,
        ScaleRangeDescr,
        ScaleMeanVarianceDescr,
    ],
    Discriminator("name"),
]

class InputTensorDescr(TensorDescrBase):
    data_type: Literal["float32", "uint8", "uint16"]
    """For now an input tensor is expected to be given as `float32`.
    The data flow in bioimage.io models is explained
    [in this diagram.](https://docs.google.com/drawings/d/1FTw8-Rn6a6nXdkZ_SkMumtcjvur9mtIhRqLwnKqZNHM/edit)."""

    shape: Annotated[
        Union[Sequence[int], ParameterizedInputShape],
        Field(
            examples=[(1, 512, 512, 1), dict(min=(1, 64, 64, 1), step=(0, 32, 32, 0))]
        ),
    ]
    """Specification of input tensor shape."""

    preprocessing: List[PreprocessingDescr] = Field(default_factory=list)
    """Description of how this input should be preprocessed."""

    @model_validator(mode="after")
    def zero_batch_step_and_one_batch_size(self) -> Self:
        bidx = self.axes.find("b")
        if bidx == -1:
            return self

        if isinstance(self.shape, ParameterizedInputShape):
            step = self.shape.step
            shape = self.shape.min
            if step[bidx] != 0:
                raise ValueError(
                    "Input shape step has to be zero in the batch dimension (the batch"
                    + " dimension can always be increased, but `step` should specify how"
                    + " to increase the minimal shape to find the largest single batch"
                    + " shape)"
                )
        else:
            shape = self.shape

        if shape[bidx] != 1:
            raise ValueError("Input shape has to be 1 in the batch dimension b.")

        return self

    @model_validator(mode="after")
    def validate_preprocessing_kwargs(self) -> Self:
        for p in self.preprocessing:
            kwargs_axes = p.kwargs.get("axes")
            if isinstance(kwargs_axes, str) and any(
                a not in self.axes for a in kwargs_axes
            ):
                raise ValueError("`kwargs.axes` needs to be subset of `axes`")

        return self
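# Example (illustrative sketch; all field values are hypothetical): a 2d,
# single-channel input with a parameterized shape and one preprocessing step:
#
#     InputTensorDescr(
#         name="raw",
#         axes="bcyx",
#         data_type="float32",
#         shape=ParameterizedInputShape(min=[1, 1, 64, 64], step=[0, 0, 32, 32]),
#         preprocessing=[
#             ZeroMeanUnitVarianceDescr(
#                 name="zero_mean_unit_variance",
#                 kwargs=ZeroMeanUnitVarianceKwargs(mode="per_sample", axes="yx"),
#             )
#         ],
#     )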

class OutputTensorDescr(TensorDescrBase):
    data_type: Literal[
        "float32",
        "float64",
        "uint8",
        "int8",
        "uint16",
        "int16",
        "uint32",
        "int32",
        "uint64",
        "int64",
        "bool",
    ]
    """Data type.
    The data flow in bioimage.io models is explained
    [in this diagram.](https://docs.google.com/drawings/d/1FTw8-Rn6a6nXdkZ_SkMumtcjvur9mtIhRqLwnKqZNHM/edit)."""

    shape: Union[Sequence[int], ImplicitOutputShape]
    """Output tensor shape."""

    halo: Optional[Sequence[int]] = None
    """The `halo` that should be cropped from the output tensor to avoid boundary effects.
    The `halo` is to be cropped from both sides, i.e. `shape_after_crop = shape - 2 * halo`.
    To document a `halo` that is already cropped by the model `shape.offset` has to be used instead."""

    postprocessing: List[PostprocessingDescr] = Field(default_factory=list)
    """Description of how this output should be postprocessed."""

    @model_validator(mode="after")
    def matching_halo_length(self) -> Self:
        if self.halo and len(self.halo) != len(self.shape):
            raise ValueError(
                f"halo {self.halo} has to have same length as shape {self.shape}!"
            )

        return self

    @model_validator(mode="after")
    def validate_postprocessing_kwargs(self) -> Self:
        for p in self.postprocessing:
            kwargs_axes = p.kwargs.get("axes", "")
            if not isinstance(kwargs_axes, str):
                raise ValueError(f"Expected {kwargs_axes} to be a string")

            if any(a not in self.axes for a in kwargs_axes):
                raise ValueError("`kwargs.axes` needs to be subset of axes")

        return self
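# Example (illustrative sketch): how consumers apply `halo` per the docstring
# above. For an output of shape [1, 1, 128, 128] with halo [0, 0, 8, 8], the
# cropped result is shape - 2 * halo = [1, 1, 112, 112].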

KnownRunMode = Literal["deepimagej"]


class RunMode(Node):
    name: Annotated[
        Union[KnownRunMode, str], warn(KnownRunMode, "Unknown run mode '{value}'.")
    ]
    """Run mode name"""

    kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Run mode specific key word arguments"""


class LinkedModel(Node):
    """Reference to a bioimage.io model."""

    id: Annotated[ModelId, Field(examples=["affable-shark", "ambitious-sloth"])]
    """A valid model `id` from the bioimage.io collection."""

    version_number: Optional[int] = None
    """version number (n-th published version, not the semantic version) of linked model"""

def package_weights(
    value: Node,  # Union[v0_4.WeightsDescr, v0_5.WeightsDescr]
    handler: SerializerFunctionWrapHandler,
    info: SerializationInfo,
):
    ctxt = packaging_context_var.get()
    if ctxt is not None and ctxt.weights_priority_order is not None:
        for wf in ctxt.weights_priority_order:
            w = getattr(value, wf, None)
            if w is not None:
                break
        else:
            raise ValueError(
                "None of the weight formats in `weights_priority_order`"
                + f" ({ctxt.weights_priority_order}) is present in the given model."
            )

        assert isinstance(w, Node), type(w)
        # construct WeightsDescr with new single weight format entry
        new_w = w.model_construct(**{k: v for k, v in w if k != "parent"})
        value = value.model_construct(None, **{wf: new_w})

    return handler(
        value, info  # pyright: ignore[reportArgumentType]  # taken from pydantic docs
    )

class ModelDescr(GenericModelDescrBase):
    """Specification of the fields used in a bioimage.io-compliant RDF that describes AI models with pretrained weights.

    These fields are typically stored in a YAML file which we call a model resource description file (model RDF).
    """

    implemented_format_version: ClassVar[Literal["0.4.10"]] = "0.4.10"
    if TYPE_CHECKING:
        format_version: Literal["0.4.10"] = "0.4.10"
    else:
        format_version: Literal["0.4.10"]
        """Version of the bioimage.io model description specification used.
        When creating a new model always use the latest micro/patch version described here.
        The `format_version` is important for any consumer software to understand how to parse the fields.
        """

    implemented_type: ClassVar[Literal["model"]] = "model"
    if TYPE_CHECKING:
        type: Literal["model"] = "model"
    else:
        type: Literal["model"]
        """Specialized resource type 'model'"""

    id: Optional[ModelId] = None
    """bioimage.io-wide unique resource identifier
    assigned by bioimage.io; version **un**specific."""

    authors: NotEmpty[  # pyright: ignore[reportGeneralTypeIssues]  # make mandatory
        List[Author]
    ]
    """The authors are the creators of the model RDF and the primary points of contact."""

    documentation: Annotated[
        ImportantFileSource,
        Field(
            examples=[
                "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/unet2d_nuclei_broad/README.md",
                "README.md",
            ],
        ),
    ]
    """URL or relative path to a markdown file with additional documentation.
    The recommended documentation file name is `README.md`. An `.md` suffix is mandatory.
    The documentation should include a '[#[#]]# Validation' (sub)section
    with details on how to quantitatively validate the model on unseen data."""

    inputs: NotEmpty[List[InputTensorDescr]]
    """Describes the input tensors expected by this model."""

    license: Annotated[
        Union[LicenseId, str],
        warn(LicenseId, "Unknown license id '{value}'."),
        Field(examples=["CC0-1.0", "MIT", "BSD-2-Clause"]),
    ]
    """A [SPDX license identifier](https://spdx.org/licenses/).
    We do not support custom licenses beyond the SPDX license list; if you need that, please
    [open a GitHub issue](https://github.com/bioimage-io/spec-bioimage-io/issues/new/choose)
    to discuss your intentions with the community."""

    name: Annotated[
        str,
        MinLen(1),
        warn(MinLen(5), "Name shorter than 5 characters.", INFO),
        warn(MaxLen(64), "Name longer than 64 characters.", INFO),
    ]
    """A human-readable name of this model.
    It should be no longer than 64 characters and only contain letter, number, underscore, minus or space characters."""

    outputs: NotEmpty[List[OutputTensorDescr]]
    """Describes the output tensors."""

    @field_validator("inputs", "outputs")
    @classmethod
    def unique_tensor_descr_names(
        cls, value: Sequence[Union[InputTensorDescr, OutputTensorDescr]]
    ) -> Sequence[Union[InputTensorDescr, OutputTensorDescr]]:
        unique_names = {str(v.name) for v in value}
        if len(unique_names) != len(value):
            raise ValueError("Duplicate tensor descriptor names")

        return value

    @model_validator(mode="after")
    def unique_io_names(self) -> Self:
        unique_names = {str(ss.name) for s in (self.inputs, self.outputs) for ss in s}
        if len(unique_names) != (len(self.inputs) + len(self.outputs)):
            raise ValueError("Duplicate tensor descriptor names across inputs/outputs")

        return self

    @model_validator(mode="after")
    def minimum_shape2valid_output(self) -> Self:
        tensors_by_name: Dict[
            TensorName, Union[InputTensorDescr, OutputTensorDescr]
        ] = {t.name: t for t in self.inputs + self.outputs}

        for out in self.outputs:
            if isinstance(out.shape, ImplicitOutputShape):
                ndim_ref = len(tensors_by_name[out.shape.reference_tensor].shape)
                ndim_out_ref = len(
                    [scale for scale in out.shape.scale if scale is not None]
                )
                if ndim_ref != ndim_out_ref:
                    expanded_dim_note = (
                        " Note that expanded dimensions (`scale`: null) are not"
                        + f" counted for {out.name}'s dimensionality here."
                        if None in out.shape.scale
                        else ""
                    )
                    raise ValueError(
                        f"Referenced tensor '{out.shape.reference_tensor}' with"
                        + f" {ndim_ref} dimensions does not match output tensor"
                        + f" '{out.name}' with"
                        + f" {ndim_out_ref} dimensions.{expanded_dim_note}"
                    )

            min_out_shape = self._get_min_shape(out, tensors_by_name)
            if out.halo:
                halo = out.halo
                halo_msg = f" for halo {out.halo}"
            else:
                halo = [0] * len(min_out_shape)
                halo_msg = ""

            if any([s - 2 * h < 1 for s, h in zip(min_out_shape, halo)]):
                raise ValueError(
                    f"Minimal shape {min_out_shape} of output {out.name} is too"
                    + f" small{halo_msg}."
                )

        return self

    @classmethod
    def _get_min_shape(
        cls,
        t: Union[InputTensorDescr, OutputTensorDescr],
        tensors_by_name: Dict[TensorName, Union[InputTensorDescr, OutputTensorDescr]],
    ) -> Sequence[int]:
        """output with subtracted halo has to result in meaningful output even for the minimal input
        see https://github.com/bioimage-io/spec-bioimage-io/issues/392
        """
        if isinstance(t.shape, collections.abc.Sequence):
            return t.shape
        elif isinstance(t.shape, ParameterizedInputShape):
            return t.shape.min
        elif isinstance(t.shape, ImplicitOutputShape):
            pass
        else:
            assert_never(t.shape)

        ref_shape = cls._get_min_shape(
            tensors_by_name[t.shape.reference_tensor], tensors_by_name
        )

        if None not in t.shape.scale:
            scale: Sequence[float] = t.shape.scale  # type: ignore
        else:
            expanded_dims = [idx for idx, sc in enumerate(t.shape.scale) if sc is None]
            new_ref_shape: List[int] = []
            for idx in range(len(t.shape.scale)):
                ref_idx = idx - sum(int(exp < idx) for exp in expanded_dims)
                new_ref_shape.append(1 if idx in expanded_dims else ref_shape[ref_idx])

            ref_shape = new_ref_shape
            assert len(ref_shape) == len(t.shape.scale)
            scale = [0.0 if sc is None else sc for sc in t.shape.scale]

        offset = t.shape.offset
        assert len(offset) == len(scale)
        return [int(rs * s + 2 * off) for rs, s, off in zip(ref_shape, scale, offset)]
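    # Example (illustrative sketch): the return expression above computes, per
    # dimension, int(ref_shape * scale + 2 * offset). For a minimal reference
    # shape [1, 1, 64, 64] and an ImplicitOutputShape with
    # scale=[1, 1, 0.5, 0.5] and offset=[0, 0, 4, 4], the minimal output shape
    # is [1, 1, 40, 40].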

    @model_validator(mode="after")
    def validate_tensor_references_in_inputs(self) -> Self:
        for t in self.inputs:
            for proc in t.preprocessing:
                if "reference_tensor" not in proc.kwargs:
                    continue

                ref_tensor = proc.kwargs["reference_tensor"]
                if ref_tensor is not None and str(ref_tensor) not in {
                    str(t.name) for t in self.inputs
                }:
                    raise ValueError(f"'{ref_tensor}' not found in inputs")

                if ref_tensor == t.name:
                    raise ValueError(
                        f"invalid self reference for preprocessing of tensor {t.name}"
                    )

        return self

    @model_validator(mode="after")
    def validate_tensor_references_in_outputs(self) -> Self:
        for t in self.outputs:
            for proc in t.postprocessing:
                if "reference_tensor" not in proc.kwargs:
                    continue
                ref_tensor = proc.kwargs["reference_tensor"]
                if ref_tensor is not None and str(ref_tensor) not in {
                    str(t.name) for t in self.inputs
                }:
                    raise ValueError(f"{ref_tensor} not found in inputs")

        return self

    packaged_by: List[Author] = Field(default_factory=list)
    """The persons that have packaged and uploaded this model.
    Only required if those persons differ from the `authors`."""

    parent: Optional[LinkedModel] = None
    """The model from which this model is derived, e.g. by fine-tuning the weights."""

    @field_validator("parent", mode="before")
    @classmethod
    def ignore_url_parent(cls, parent: Any):
        if isinstance(parent, dict):
            return None

        else:
            return parent

    run_mode: Optional[RunMode] = None
    """Custom run mode for this model: for more complex prediction procedures like test-time
    data augmentation that currently cannot be expressed in the specification.
    No standard run modes are defined yet."""

    sample_inputs: List[ImportantFileSource] = Field(default_factory=list)
    """URLs/relative paths to sample inputs to illustrate possible inputs for the model,
    for example stored as PNG or TIFF images.
    The sample files primarily serve to inform a human user about an example use case."""

    sample_outputs: List[ImportantFileSource] = Field(default_factory=list)
    """URLs/relative paths to sample outputs corresponding to the `sample_inputs`."""

    test_inputs: NotEmpty[
        List[Annotated[ImportantFileSource, WithSuffix(".npy", case_sensitive=True)]]
    ]
    """Test input tensors compatible with the `inputs` description for a **single test case**.
    This means if your model has more than one input, you should provide one URL/relative path for each input.
    Each test input should be a file with an ndarray in
    [numpy.lib file format](https://numpy.org/doc/stable/reference/generated/numpy.lib.format.html#module-numpy.lib.format).
    The extension must be '.npy'."""

    test_outputs: NotEmpty[
        List[Annotated[ImportantFileSource, WithSuffix(".npy", case_sensitive=True)]]
    ]
    """Analogous to `test_inputs`."""

    timestamp: Datetime
    """Timestamp in [ISO 8601](https://en.wikipedia.org/wiki/ISO_8601) format
    with a few restrictions listed [here](https://docs.python.org/3/library/datetime.html#datetime.datetime.fromisoformat)."""

    training_data: Union[LinkedDataset, DatasetDescr, None] = None
    """The dataset used to train this model"""

    weights: Annotated[WeightsDescr, WrapSerializer(package_weights)]
    """The weights for this model.
    Weights can be given for different formats, but should otherwise be equivalent.
    The available weight formats determine which consumers can use this model."""

    @model_validator(mode="before")
    @classmethod
    def _convert_from_older_format(
        cls, data: BioimageioYamlContent, /
    ) -> BioimageioYamlContent:
        convert_from_older_format(data)
        return data

    def get_input_test_arrays(self) -> List[NDArray[Any]]:
        data = [load_array(download(ipt).path) for ipt in self.test_inputs]
        assert all(isinstance(d, np.ndarray) for d in data)
        return data

    def get_output_test_arrays(self) -> List[NDArray[Any]]:
        data = [load_array(download(out).path) for out in self.test_outputs]
        assert all(isinstance(d, np.ndarray) for d in data)
        return data
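# Example (illustrative sketch): loading a model RDF and inspecting it.
# `load_description` is assumed to come from the public bioimageio.spec API and
# the path is hypothetical; adjust both to your installed version:
#
#     from bioimageio.spec import load_description
#
#     descr = load_description("unet2d_nuclei_broad/rdf.yaml")
#     if isinstance(descr, ModelDescr):
#         print(descr.name, sorted(descr.weights.available_formats))
#         test_inputs = descr.get_input_test_arrays()  # list of numpy arrays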