Coverage for bioimageio/spec/model/v0_4.py: 91%

595 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-09-11 07:34 +0000

1from __future__ import annotations 

2 

3import collections.abc 

4from typing import ( 

5 TYPE_CHECKING, 

6 Any, 

7 Callable, 

8 ClassVar, 

9 Dict, 

10 List, 

11 Literal, 

12 Optional, 

13 Sequence, 

14 Tuple, 

15 Type, 

16 Union, 

17 cast, 

18) 

19 

20import numpy as np 

21from annotated_types import Ge, Interval, MaxLen, MinLen, MultipleOf 

22from numpy.typing import NDArray 

23from pydantic import ( 

24 AllowInfNan, 

25 Discriminator, 

26 Field, 

27 RootModel, 

28 SerializationInfo, 

29 SerializerFunctionWrapHandler, 

30 StringConstraints, 

31 TypeAdapter, 

32 ValidationInfo, 

33 WrapSerializer, 

34 field_validator, 

35 model_validator, 

36) 

37from typing_extensions import Annotated, Self, assert_never, get_args 

38 

39from .._internal.common_nodes import ( 

40 KwargsNode, 

41 Node, 

42 NodeWithExplicitlySetFields, 

43) 

44from .._internal.constants import SHA256_HINT 

45from .._internal.field_validation import validate_unique_entries 

46from .._internal.field_warning import issue_warning, warn 

47from .._internal.io import BioimageioYamlContent, WithSuffix 

48from .._internal.io import FileDescr as FileDescr 

49from .._internal.io_basics import Sha256 as Sha256 

50from .._internal.io_packaging import include_in_package 

51from .._internal.io_utils import load_array 

52from .._internal.packaging_context import packaging_context_var 

53from .._internal.types import Datetime as Datetime 

54from .._internal.types import FileSource_, LowerCaseIdentifier 

55from .._internal.types import Identifier as Identifier 

56from .._internal.types import LicenseId as LicenseId 

57from .._internal.types import NotEmpty as NotEmpty 

58from .._internal.url import HttpUrl as HttpUrl 

59from .._internal.validated_string_with_inner_node import ValidatedStringWithInnerNode 

60from .._internal.validator_annotations import AfterValidator, RestrictCharacters 

61from .._internal.version_type import Version as Version 

62from .._internal.warning_levels import ALERT, INFO 

63from ..dataset.v0_2 import VALID_COVER_IMAGE_EXTENSIONS as VALID_COVER_IMAGE_EXTENSIONS 

64from ..dataset.v0_2 import DatasetDescr as DatasetDescr 

65from ..dataset.v0_2 import LinkedDataset as LinkedDataset 

66from ..generic.v0_2 import AttachmentsDescr as AttachmentsDescr 

67from ..generic.v0_2 import Author as Author 

68from ..generic.v0_2 import BadgeDescr as BadgeDescr 

69from ..generic.v0_2 import CiteEntry as CiteEntry 

70from ..generic.v0_2 import Doi as Doi 

71from ..generic.v0_2 import GenericModelDescrBase 

72from ..generic.v0_2 import LinkedResource as LinkedResource 

73from ..generic.v0_2 import Maintainer as Maintainer 

74from ..generic.v0_2 import OrcidId as OrcidId 

75from ..generic.v0_2 import RelativeFilePath as RelativeFilePath 

76from ..generic.v0_2 import ResourceId as ResourceId 

77from ..generic.v0_2 import Uploader as Uploader 

78from ._v0_4_converter import convert_from_older_format 

79 

80 

# Identifier of a model resource. Adds nothing to `ResourceId` itself;
# presumably it exists to give model ids a distinct type (confirm with callers).
class ModelId(ResourceId):
    pass

83 

84 

# Axes string of a tensor description. Only the characters b, i, t, c, z, y, x
# are allowed; `validate_unique_entries` runs afterwards (presumably rejecting
# repeated axis characters).
AxesStr = Annotated[
    str, RestrictCharacters("bitczyx"), AfterValidator(validate_unique_entries)
]
# Axes subset used by processing kwargs. Restricted to c, z, y, x, with the same
# uniqueness validator applied.
AxesInCZYX = Annotated[
    str, RestrictCharacters("czyx"), AfterValidator(validate_unique_entries)
]

91 

# Names of the postprocessing operations of this spec version.
# Same as `PreprocessingName` plus "scale_mean_variance".
PostprocessingName = Literal[
    "binarize",
    "clip",
    "scale_linear",
    "sigmoid",
    "zero_mean_unit_variance",
    "scale_range",
    "scale_mean_variance",
]
# Names of the preprocessing operations of this spec version.
PreprocessingName = Literal[
    "binarize",
    "clip",
    "scale_linear",
    "sigmoid",
    "zero_mean_unit_variance",
    "scale_range",
]

109 

110 

# Name of an input/output tensor; validated as a `LowerCaseIdentifier`.
class TensorName(LowerCaseIdentifier):
    pass

113 

114 

class CallableFromDepencencyNode(Node):
    # Parsed form of a `<module>.<callable>` reference into an installed dependency.

    # Shared adapter used to validate each dotted component of `module_name`.
    _submodule_adapter: ClassVar[TypeAdapter[Identifier]] = TypeAdapter(Identifier)

    module_name: str
    """The Python module that implements **callable_name**."""

    callable_name: Identifier
    """The callable Python identifier implemented in module **module_name**."""

    @field_validator("module_name", mode="after")
    @classmethod
    def _check_submodules(cls, module_name: str) -> str:
        """Validate every "."-separated component of `module_name` as an `Identifier`.

        Returns:
            `module_name` unchanged.

        Raises:
            The validation error of `_submodule_adapter` for the first component
            that is not a valid `Identifier` (e.g. an empty component from "a..b").
        """
        for submod in module_name.split("."):
            _ = cls._submodule_adapter.validate_python(submod)

        return module_name

130 

131 

class CallableFromDepencency(ValidatedStringWithInnerNode[CallableFromDepencencyNode]):
    # String form `<module path>.<callable name>`, e.g. "my_module.submodule.get_model".
    # The string is parsed into a `CallableFromDepencencyNode`.
    _inner_node_class = CallableFromDepencencyNode
    root_model: ClassVar[Type[RootModel[Any]]] = RootModel[
        Annotated[
            str,
            StringConstraints(strip_whitespace=True, pattern=r"^.+\..+$"),
        ]
    ]

    @classmethod
    def _get_data(cls, valid_string_data: str):
        # The last "." separates the module path (which may contain further dots)
        # from the callable name.
        module_path, _sep, attr_name = valid_string_data.rpartition(".")
        return dict(module_name=module_path, callable_name=attr_name)

    @property
    def module_name(self):
        """The Python module that implements **callable_name**."""
        inner = self._inner_node
        return inner.module_name

    @property
    def callable_name(self):
        """The callable Python identifier implemented in module **module_name**."""
        inner = self._inner_node
        return inner.callable_name

155 

156 

class CallableFromFileNode(Node):
    # Parsed form of a `<source file>:<callable>` reference.

    # Tried as a relative file path first, then as an HTTP URL (left_to_right);
    # `include_in_package` marks the file for packaging.
    source_file: Annotated[
        Union[RelativeFilePath, HttpUrl],
        Field(union_mode="left_to_right"),
        include_in_package,
    ]
    """The Python source file that implements **callable_name**."""
    callable_name: Identifier
    """The callable Python identifier implemented in **source_file**."""

166 

167 

class CallableFromFile(ValidatedStringWithInnerNode[CallableFromFileNode]):
    # String form `<source file>:<callable name>`, e.g. "my_function.py:MyNetworkClass".
    # The string is parsed into a `CallableFromFileNode`.
    _inner_node_class = CallableFromFileNode
    root_model: ClassVar[Type[RootModel[Any]]] = RootModel[
        Annotated[
            str,
            StringConstraints(strip_whitespace=True, pattern=r"^.+:.+$"),
        ]
    ]

    @classmethod
    def _get_data(cls, valid_string_data: str):
        # Split at the LAST ":" so the file part may itself contain ":"
        # (e.g. a URL scheme).
        file_part, _sep, attr_name = valid_string_data.rpartition(":")
        return dict(source_file=file_part, callable_name=attr_name)

    @property
    def source_file(self):
        """The Python source file that implements **callable_name**."""
        inner = self._inner_node
        return inner.source_file

    @property
    def callable_name(self):
        """The callable Python identifier implemented in **source_file**."""
        inner = self._inner_node
        return inner.callable_name

191 

192 

# A callable given either as `<file>:<name>` or as `<module>.<name>`.
# With union_mode="left_to_right" the `CallableFromFile` form is tried first.
CustomCallable = Annotated[
    Union[CallableFromFile, CallableFromDepencency], Field(union_mode="left_to_right")
]

196 

197 

class DependenciesNode(Node):
    # Parsed form of a `<manager>:<file>` dependency specification.
    manager: Annotated[NotEmpty[str], Field(examples=["conda", "maven", "pip"])]
    """Dependency manager"""

    file: FileSource_
    """Dependency file"""

204 

205 

class Dependencies(ValidatedStringWithInnerNode[DependenciesNode]):
    # String form `<manager>:<file>`, e.g. "conda:environment.yaml".
    # The string is parsed into a `DependenciesNode`.
    _inner_node_class = DependenciesNode
    root_model: ClassVar[Type[RootModel[Any]]] = RootModel[
        Annotated[
            str,
            StringConstraints(strip_whitespace=True, pattern=r"^.+:.+$"),
        ]
    ]

    @classmethod
    def _get_data(cls, valid_string_data: str):
        # Only the FIRST ":" separates the manager; the file part keeps any others.
        mgr, _sep, dep_file = valid_string_data.partition(":")
        return dict(manager=mgr, file=dep_file)

    @property
    def manager(self):
        """Dependency manager"""
        node = self._inner_node
        return node.manager

    @property
    def file(self):
        """Dependency file"""
        node = self._inner_node
        return node.file

229 

230 

# Keys of the supported weights formats. These are the field names of
# `WeightsDescr` and the `type` class attribute of each weights entry class.
WeightsFormat = Literal[
    "keras_hdf5",
    "onnx",
    "pytorch_state_dict",
    "tensorflow_js",
    "tensorflow_saved_model_bundle",
    "torchscript",
]

239 

240 

class WeightsEntryDescrBase(FileDescr):
    # Common base of all weights entries. Each subclass assigns the class-level
    # `type` (its `WeightsFormat` key) and a human readable `weights_format_name`.
    type: ClassVar[WeightsFormat]
    weights_format_name: ClassVar[str]  # human readable

    source: FileSource_
    """The weights file."""

    # `warn(None, ..., ALERT)` presumably reports an ALERT-level warning when a
    # value other than None is given (see `.._internal.field_warning`).
    attachments: Annotated[
        Union[AttachmentsDescr, None],
        warn(None, "Weights entry depends on additional attachments.", ALERT),
    ] = None
    """Attachments that are specific to this weights entry."""

    authors: Union[List[Author], None] = None
    """Authors
    Either the person(s) that have trained this model resulting in the original weights file.
        (If this is the initial weights entry, i.e. it does not have a `parent`)
    Or the person(s) who have converted the weights to this weights format.
        (If this is a child weight, i.e. it has a `parent` field)
    """

    # Custom dependencies are discouraged; `warn` presumably reports them (the
    # message interpolates the given `{value}`).
    dependencies: Annotated[
        Optional[Dependencies],
        warn(
            None,
            "Custom dependencies ({value}) specified. Avoid this whenever possible "
            + "to allow execution in a wider range of software environments.",
        ),
        Field(
            examples=[
                "conda:environment.yaml",
                "maven:./pom.xml",
                "pip:./requirements.txt",
            ]
        ),
    ] = None
    """Dependency manager and dependency file, specified as `<dependency manager>:<relative file path>`."""

    parent: Annotated[
        Optional[WeightsFormat], Field(examples=["pytorch_state_dict"])
    ] = None
    """The source weights these weights were converted from.
    For example, if a model's weights were converted from the `pytorch_state_dict` format to `torchscript`,
    the `pytorch_state_dict` weights entry has no `parent` and is the parent of the `torchscript` weights.
    All weight entries except one (the initial set of weights resulting from training the model),
    need to have this field."""

    @model_validator(mode="after")
    def check_parent_is_not_self(self) -> Self:
        # `type` is the subclass's own format key; an entry must not name its own
        # format as `parent`.
        if self.type == self.parent:
            raise ValueError("Weights entry can't be it's own parent.")

        return self

294 

295 

class KerasHdf5WeightsDescr(WeightsEntryDescrBase):
    # Weights entry for the "keras_hdf5" format.
    type = "keras_hdf5"
    weights_format_name: ClassVar[str] = "Keras HDF5"
    tensorflow_version: Optional[Version] = None
    """TensorFlow version used to create these weights"""

    @field_validator("tensorflow_version", mode="after")
    @classmethod
    def _tfv(cls, value: Any):
        # A None version is accepted, but an ALERT-level warning is issued. The value
        # is always returned unchanged. Whether this also runs for the default depends
        # on the model config's `validate_default` (not visible here).
        if value is None:
            issue_warning(
                "missing. Please specify the TensorFlow version"
                + " these weights were created with.",
                value=value,
                severity=ALERT,
                field="tensorflow_version",
            )
        return value

314 

315 

class OnnxWeightsDescr(WeightsEntryDescrBase):
    # Weights entry for the "onnx" format.
    type = "onnx"
    weights_format_name: ClassVar[str] = "ONNX"
    # If given, the opset version must be at least 7 (`Ge(7)`).
    opset_version: Optional[Annotated[int, Ge(7)]] = None
    """ONNX opset version"""

    @field_validator("opset_version", mode="after")
    @classmethod
    def _ov(cls, value: Any):
        # A None opset version is accepted, but an ALERT-level warning is issued.
        # The value is returned unchanged.
        if value is None:
            issue_warning(
                "Missing ONNX opset version (aka ONNX opset number). "
                + "Please specify the ONNX opset version these weights were created"
                + " with.",
                value=value,
                severity=ALERT,
                field="opset_version",
            )
        return value

335 

336 

class PytorchStateDictWeightsDescr(WeightsEntryDescrBase):
    # Weights entry for the "pytorch_state_dict" format: a state dict plus the
    # callable (`architecture`) that builds the network to load it into.
    type = "pytorch_state_dict"
    weights_format_name: ClassVar[str] = "Pytorch State Dict"
    architecture: CustomCallable = Field(
        examples=["my_function.py:MyNetworkClass", "my_module.submodule.get_my_model"]
    )
    """callable returning a torch.nn.Module instance.
    Local implementation: `<relative path to file>:<identifier of implementation within the file>`.
    Implementation in a dependency: `<dependency-package>.<[dependency-module]>.<identifier>`."""

    architecture_sha256: Annotated[
        Optional[Sha256],
        Field(
            description=(
                "The SHA256 of the architecture source file, if the architecture is not"
                " defined in a module listed in `dependencies`\n"
            )
            + SHA256_HINT,
        ),
    ] = None
    """The SHA256 of the architecture source file,
    if the architecture is not defined in a module listed in `dependencies`"""

    @model_validator(mode="after")
    def check_architecture_sha256(self) -> Self:
        # `architecture_sha256` is required exactly when `architecture` points to a
        # source file (`CallableFromFile`) and must be absent otherwise.
        if isinstance(self.architecture, CallableFromFile):
            if self.architecture_sha256 is None:
                raise ValueError(
                    "Missing required `architecture_sha256` for `architecture` with"
                    + " source file."
                )
        elif self.architecture_sha256 is not None:
            raise ValueError(
                "Got `architecture_sha256` for architecture that does not have a source"
                + " file."
            )

        return self

    # A fresh empty dict per instance (the cast only satisfies the type checker).
    kwargs: Dict[str, Any] = Field(
        default_factory=cast(Callable[[], Dict[str, Any]], dict)
    )
    """key word arguments for the `architecture` callable"""

    pytorch_version: Optional[Version] = None
    """Version of the PyTorch library used.
    If `dependencies` is specified it should include pytorch and the version has to match.
    (`dependencies` overrules `pytorch_version`)"""

    @field_validator("pytorch_version", mode="after")
    @classmethod
    def _ptv(cls, value: Any):
        # A None version is accepted, but an ALERT-level warning is issued.
        if value is None:
            issue_warning(
                "missing. Please specify the PyTorch version these"
                + " PyTorch state dict weights were created with.",
                value=value,
                severity=ALERT,
                field="pytorch_version",
            )
        return value

397 return value 

398 

399 

class TorchscriptWeightsDescr(WeightsEntryDescrBase):
    # Weights entry for the "torchscript" format.
    type = "torchscript"
    weights_format_name: ClassVar[str] = "TorchScript"
    pytorch_version: Optional[Version] = None
    """Version of the PyTorch library used."""

    @field_validator("pytorch_version", mode="after")
    @classmethod
    def _ptv(cls, value: Any):
        # A None version is accepted, but an ALERT-level warning is issued.
        if value is None:
            issue_warning(
                "missing. Please specify the PyTorch version these"
                + " Torchscript weights were created with.",
                value=value,
                severity=ALERT,
                field="pytorch_version",
            )
        return value

418 

419 

class TensorflowJsWeightsDescr(WeightsEntryDescrBase):
    # Weights entry for the "tensorflow_js" format.
    type = "tensorflow_js"
    weights_format_name: ClassVar[str] = "Tensorflow.js"
    tensorflow_version: Optional[Version] = None
    """Version of the TensorFlow library used."""

    @field_validator("tensorflow_version", mode="after")
    @classmethod
    def _tfv(cls, value: Any):
        # A None version is accepted, but an ALERT-level warning is issued.
        if value is None:
            issue_warning(
                "missing. Please specify the TensorFlow version"
                + " these TensorflowJs weights were created with.",
                value=value,
                severity=ALERT,
                field="tensorflow_version",
            )
        return value

    # Re-declared only to document that this format's source is a zip archive.
    source: FileSource_
    """The multi-file weights.
    All required files/folders should be a zip archive."""

442 

443 

class TensorflowSavedModelBundleWeightsDescr(WeightsEntryDescrBase):
    # Weights entry for the "tensorflow_saved_model_bundle" format.
    type = "tensorflow_saved_model_bundle"
    weights_format_name: ClassVar[str] = "Tensorflow Saved Model"
    tensorflow_version: Optional[Version] = None
    """Version of the TensorFlow library used."""

    @field_validator("tensorflow_version", mode="after")
    @classmethod
    def _tfv(cls, value: Any):
        # A None version is accepted, but an ALERT-level warning is issued.
        if value is None:
            issue_warning(
                "missing. Please specify the TensorFlow version"
                + " these Tensorflow saved model bundle weights were created with.",
                value=value,
                severity=ALERT,
                field="tensorflow_version",
            )
        return value

462 

463 

class WeightsDescr(Node):
    # One optional entry per `WeightsFormat`; at least one entry must be present.
    keras_hdf5: Optional[KerasHdf5WeightsDescr] = None
    onnx: Optional[OnnxWeightsDescr] = None
    pytorch_state_dict: Optional[PytorchStateDictWeightsDescr] = None
    tensorflow_js: Optional[TensorflowJsWeightsDescr] = None
    tensorflow_saved_model_bundle: Optional[TensorflowSavedModelBundleWeightsDescr] = (
        None
    )
    torchscript: Optional[TorchscriptWeightsDescr] = None

    @model_validator(mode="after")
    def check_one_entry(self) -> Self:
        """Reject a description in which every weights entry is unset."""
        if not self.available_formats:
            raise ValueError("Missing weights entry")

        return self

    def __getitem__(
        self,
        key: WeightsFormat,
    ):
        """Return the weights entry for format `key`.

        Raises:
            KeyError: if `key` is not a known weights format or its entry is unset.
        """
        # Membership test on the tuple of format names (uses ==, so unhashable
        # keys also end in KeyError).
        if key not in get_args(WeightsFormat):
            raise KeyError(key)

        entry = self.available_formats.get(key)
        if entry is None:
            raise KeyError(key)

        return entry

    @property
    def available_formats(self):
        """Map each format name to its entry, for entries that are set.

        Keys follow the field declaration order.
        """
        candidates = (
            ("keras_hdf5", self.keras_hdf5),
            ("onnx", self.onnx),
            ("pytorch_state_dict", self.pytorch_state_dict),
            ("tensorflow_js", self.tensorflow_js),
            ("tensorflow_saved_model_bundle", self.tensorflow_saved_model_bundle),
            ("torchscript", self.torchscript),
        )
        return {fmt: entry for fmt, entry in candidates if entry is not None}

    @property
    def missing_formats(self):
        """Set of the format names that have no entry."""
        present = self.available_formats
        return {wf for wf in get_args(WeightsFormat) if wf not in present}

545 

546 

class ParameterizedInputShape(Node):
    """A sequence of valid shapes given by `shape_k = min + k * step for k in {0, 1, ...}`."""

    min: NotEmpty[List[int]]
    """The minimum input shape"""

    step: NotEmpty[List[int]]
    """The minimum shape change"""

    def __len__(self) -> int:
        # Number of dimensions; `matching_lengths` ensures `step` has the same length.
        return len(self.min)

    @model_validator(mode="after")
    def matching_lengths(self) -> Self:
        # `min` and `step` need exactly one entry per axis.
        if len(self.min) != len(self.step):
            raise ValueError("`min` and `step` required to have the same length")

        return self

565 

566 

class ImplicitOutputShape(Node):
    """Output tensor shape depending on an input tensor shape.
    `shape(output_tensor) = shape(input_tensor) * scale + 2 * offset`"""

    reference_tensor: TensorName
    """Name of the reference tensor."""

    scale: NotEmpty[List[Optional[float]]]
    """output_pix/input_pix for each dimension.
    'null' values indicate new dimensions, whose length is defined by 2*`offset`"""

    # Each offset is an int or a float that is a multiple of 0.5.
    offset: NotEmpty[List[Union[int, Annotated[float, MultipleOf(0.5)]]]]
    """Position of origin wrt to input."""

    def __len__(self) -> int:
        # Number of dimensions (length of `scale`; `offset` is validated to match).
        return len(self.scale)

    @model_validator(mode="after")
    def matching_lengths(self) -> Self:
        if len(self.scale) != len(self.offset):
            raise ValueError(
                f"scale {self.scale} has to have same length as offset {self.offset}!"
            )
        # if we have an expanded dimension, make sure that its offset is not zero
        # NOTE(review): only `sc is None` is checked here; the message also mentions
        # a zero scale, which is not rejected.
        for sc, off in zip(self.scale, self.offset):
            if sc is None and not off:
                raise ValueError("`offset` must not be zero if `scale` is none/zero")

        return self

596 

597 

class TensorDescrBase(Node):
    # Fields shared by `InputTensorDescr` and `OutputTensorDescr`.
    name: TensorName
    """Tensor name. No duplicates are allowed."""

    description: str = ""

    axes: AxesStr
    """Axes identifying characters. Same length and order as the axes in `shape`.
    | axis | description |
    | --- | --- |
    | b | batch (groups multiple samples) |
    | i | instance/index/element |
    | t | time |
    | c | channel |
    | z | spatial dimension z |
    | y | spatial dimension y |
    | x | spatial dimension x |
    """

    # AllowInfNan(True): either bound may be given as inf/-inf (or nan).
    data_range: Optional[
        Tuple[Annotated[float, AllowInfNan(True)], Annotated[float, AllowInfNan(True)]]
    ] = None
    """Tuple `(minimum, maximum)` specifying the allowed range of the data in this tensor.
    If not specified, the full data range that can be expressed in `data_type` is allowed."""

622 

623 

class ProcessingKwargs(KwargsNode):
    """base class for pre-/postprocessing key word arguments"""

626 

627 

class ProcessingDescrBase(NodeWithExplicitlySetFields):
    """processing base class"""

    # Subclasses define `implemented_name`, a `name` discriminator field and,
    # where applicable, a `kwargs` field.

630 

631 

class BinarizeKwargs(ProcessingKwargs):
    """key word arguments for `BinarizeDescr`"""

    threshold: float
    """The fixed threshold"""

637 

638 

class BinarizeDescr(ProcessingDescrBase):
    """Binarize the tensor with a fixed `BinarizeKwargs.threshold`.
    Values above the threshold will be set to one, values below the threshold to zero.
    """

    implemented_name: ClassVar[Literal["binarize"]] = "binarize"
    # Type checkers see a default so `name` may be omitted in constructor calls.
    # At runtime `name` is a required field; it is the union discriminator.
    if TYPE_CHECKING:
        name: Literal["binarize"] = "binarize"
    else:
        name: Literal["binarize"]

    kwargs: BinarizeKwargs

651 

652 

class ClipKwargs(ProcessingKwargs):
    """key word arguments for `ClipDescr`"""

    # NOTE(review): min <= max is not validated in this class.
    min: float
    """minimum value for clipping"""
    max: float
    """maximum value for clipping"""

660 

661 

class ClipDescr(ProcessingDescrBase):
    """Clip tensor values to a range.

    Set tensor values below `ClipKwargs.min` to `ClipKwargs.min`
    and above `ClipKwargs.max` to `ClipKwargs.max`.
    """

    implemented_name: ClassVar[Literal["clip"]] = "clip"
    # Default only for type checkers; required discriminator field at runtime.
    if TYPE_CHECKING:
        name: Literal["clip"] = "clip"
    else:
        name: Literal["clip"]

    kwargs: ClipKwargs

676 

677 

class ScaleLinearKwargs(ProcessingKwargs):
    """key word arguments for `ScaleLinearDescr`"""

    axes: Annotated[Optional[AxesInCZYX], Field(examples=["xy"])] = None
    """The subset of axes to scale jointly.
    For example xy to scale the two image axes for 2d data jointly."""

    gain: Union[float, List[float]] = 1.0
    """multiplicative factor"""

    offset: Union[float, List[float]] = 0.0
    """additive term"""

    @model_validator(mode="after")
    def either_gain_or_offset(self) -> Self:
        # A gain of all ones combined with an offset of all zeros would be a no-op;
        # such a processing step is rejected.
        if isinstance(self.gain, list):
            gain_is_identity = all(g == 1.0 for g in self.gain)
        else:
            gain_is_identity = self.gain == 1.0

        if isinstance(self.offset, list):
            offset_is_identity = all(off == 0.0 for off in self.offset)
        else:
            offset_is_identity = self.offset == 0.0

        if not (gain_is_identity and offset_is_identity):
            return self

        raise ValueError(
            "Redunt linear scaling not allowd. Set `gain` != 1.0 and/or `offset` !="
            + " 0.0."
        )

708 

709 

class ScaleLinearDescr(ProcessingDescrBase):
    """Fixed linear scaling."""

    implemented_name: ClassVar[Literal["scale_linear"]] = "scale_linear"
    # Default only for type checkers; required discriminator field at runtime.
    if TYPE_CHECKING:
        name: Literal["scale_linear"] = "scale_linear"
    else:
        name: Literal["scale_linear"]

    kwargs: ScaleLinearKwargs

720 

721 

class SigmoidDescr(ProcessingDescrBase):
    """The logistic sigmoid function, a.k.a. expit function."""

    implemented_name: ClassVar[Literal["sigmoid"]] = "sigmoid"
    # Default only for type checkers; required discriminator field at runtime.
    if TYPE_CHECKING:
        name: Literal["sigmoid"] = "sigmoid"
    else:
        name: Literal["sigmoid"]

    @property
    def kwargs(self) -> ProcessingKwargs:
        """empty kwargs"""
        # Sigmoid takes no parameters; a new empty kwargs object is returned on each
        # access, matching the `kwargs` interface of the other processing descrs.
        return ProcessingKwargs()

735 

736 

class ZeroMeanUnitVarianceKwargs(ProcessingKwargs):
    """key word arguments for `ZeroMeanUnitVarianceDescr`"""

    mode: Literal["fixed", "per_dataset", "per_sample"] = "fixed"
    """Mode for computing mean and variance.
    | mode | description |
    | ----------- | ------------------------------------ |
    | fixed | Fixed values for mean and variance |
    | per_dataset | Compute for the entire dataset |
    | per_sample | Compute for each sample individually |
    """
    axes: Annotated[AxesInCZYX, Field(examples=["xy"])]
    """The subset of axes to normalize jointly.
    For example `xy` to normalize the two image axes for 2d data jointly."""

    mean: Annotated[
        Union[float, NotEmpty[List[float]], None], Field(examples=[(1.1, 2.2, 3.3)])
    ] = None
    """The mean value(s) to use for `mode: fixed`.
    For example `[1.1, 2.2, 3.3]` in the case of a 3 channel image with `axes: xy`."""
    # todo: check if means match input axes (for mode 'fixed')

    std: Annotated[
        Union[float, NotEmpty[List[float]], None], Field(examples=[(0.1, 0.2, 0.3)])
    ] = None
    """The standard deviation values to use for `mode: fixed`. Analogous to mean."""

    # Must satisfy 0 < eps <= 0.1.
    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
    """epsilon for numeric stability: `out = (tensor - mean) / (std + eps)`."""

    @model_validator(mode="after")
    def mean_and_std_match_mode(self) -> Self:
        # `mean` and `std` are both required for mode "fixed" and both forbidden for
        # the computed modes.
        if self.mode == "fixed" and (self.mean is None or self.std is None):
            raise ValueError("`mean` and `std` are required for `mode: fixed`.")
        elif self.mode != "fixed" and (self.mean is not None or self.std is not None):
            raise ValueError(f"`mean` and `std` not allowed for `mode: {self.mode}`")

        return self

775 

776 

class ZeroMeanUnitVarianceDescr(ProcessingDescrBase):
    """Subtract mean and divide by standard deviation."""

    implemented_name: ClassVar[Literal["zero_mean_unit_variance"]] = (
        "zero_mean_unit_variance"
    )
    # Default only for type checkers; required discriminator field at runtime.
    if TYPE_CHECKING:
        name: Literal["zero_mean_unit_variance"] = "zero_mean_unit_variance"
    else:
        name: Literal["zero_mean_unit_variance"]

    kwargs: ZeroMeanUnitVarianceKwargs

789 

790 

class ScaleRangeKwargs(ProcessingKwargs):
    """key word arguments for `ScaleRangeDescr`

    For `min_percentile`=0.0 (the default) and `max_percentile`=100 (the default)
    this processing step normalizes data to the [0, 1] interval.
    For other percentiles the normalized values will partially be outside the [0, 1]
    interval. Use `ScaleRangeDescr` followed by `ClipDescr` if you want to limit the
    normalized values to a range.
    """

    mode: Literal["per_dataset", "per_sample"]
    """Mode for computing percentiles.
    | mode | description |
    | ----------- | ------------------------------------ |
    | per_dataset | compute for the entire dataset |
    | per_sample | compute for each sample individually |
    """
    axes: Annotated[AxesInCZYX, Field(examples=["xy"])]
    """The subset of axes to normalize jointly.
    For example xy to normalize the two image axes for 2d data jointly."""

    min_percentile: Annotated[Union[int, float], Interval(ge=0, lt=100)] = 0.0
    """The lower percentile used to determine the value to align with zero."""

    max_percentile: Annotated[Union[int, float], Interval(gt=1, le=100)] = 100.0
    """The upper percentile used to determine the value to align with one.
    Has to be bigger than `min_percentile`.
    The range is 1 to 100 instead of 0 to 100 to avoid mistakenly
    accepting percentiles specified in the range 0.0 to 1.0."""

    @model_validator(mode="after")
    def min_smaller_max(self, info: ValidationInfo) -> Self:
        # Require min_percentile < max_percentile.
        # NOTE(review): `info` is unused here; other "after" validators in this file
        # take only `self`.
        if self.min_percentile >= self.max_percentile:
            raise ValueError(
                f"min_percentile {self.min_percentile} >= max_percentile"
                + f" {self.max_percentile}"
            )

        return self

    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
    """Epsilon for numeric stability.
    `out = (tensor - v_lower) / (v_upper - v_lower + eps)`;
    with `v_lower,v_upper` values at the respective percentiles."""

    reference_tensor: Optional[TensorName] = None
    """Tensor name to compute the percentiles from. Default: The tensor itself.
    For any tensor in `inputs` only input tensor references are allowed.
    For a tensor in `outputs` only input tensor references are allowed if `mode: per_dataset`"""

840 

841 

class ScaleRangeDescr(ProcessingDescrBase):
    """Scale with percentiles."""

    implemented_name: ClassVar[Literal["scale_range"]] = "scale_range"
    # Default only for type checkers; required discriminator field at runtime.
    if TYPE_CHECKING:
        name: Literal["scale_range"] = "scale_range"
    else:
        name: Literal["scale_range"]

    kwargs: ScaleRangeKwargs

852 

853 

class ScaleMeanVarianceKwargs(ProcessingKwargs):
    """key word arguments for `ScaleMeanVarianceDescr`"""

    mode: Literal["per_dataset", "per_sample"]
    """Mode for computing mean and variance.
    | mode | description |
    | ----------- | ------------------------------------ |
    | per_dataset | Compute for the entire dataset |
    | per_sample | Compute for each sample individually |
    """

    reference_tensor: TensorName
    """Name of tensor to match."""

    axes: Annotated[Optional[AxesInCZYX], Field(examples=["xy"])] = None
    """The subset of axes to scale jointly.
    For example xy to normalize the two image axes for 2d data jointly.
    Default: scale all non-batch axes jointly."""

    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
    """Epsilon for numeric stability:
    `out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean`."""

876 

877 

class ScaleMeanVarianceDescr(ProcessingDescrBase):
    """Scale the tensor s.t. its mean and variance match a reference tensor."""

    implemented_name: ClassVar[Literal["scale_mean_variance"]] = "scale_mean_variance"
    # Default only for type checkers; required discriminator field at runtime.
    if TYPE_CHECKING:
        name: Literal["scale_mean_variance"] = "scale_mean_variance"
    else:
        name: Literal["scale_mean_variance"]

    kwargs: ScaleMeanVarianceKwargs

888 

889 

# Discriminated union of preprocessing steps; the `name` field selects the class.
PreprocessingDescr = Annotated[
    Union[
        BinarizeDescr,
        ClipDescr,
        ScaleLinearDescr,
        SigmoidDescr,
        ZeroMeanUnitVarianceDescr,
        ScaleRangeDescr,
    ],
    Discriminator("name"),
]
# Discriminated union of postprocessing steps: the preprocessing steps plus
# `ScaleMeanVarianceDescr`.
PostprocessingDescr = Annotated[
    Union[
        BinarizeDescr,
        ClipDescr,
        ScaleLinearDescr,
        SigmoidDescr,
        ZeroMeanUnitVarianceDescr,
        ScaleRangeDescr,
        ScaleMeanVarianceDescr,
    ],
    Discriminator("name"),
]

913 

914 

class InputTensorDescr(TensorDescrBase):
    # NOTE(review): the docstring below says inputs are expected as `float32`,
    # while the Literal also accepts "uint8" and "uint16".
    data_type: Literal["float32", "uint8", "uint16"]
    """For now an input tensor is expected to be given as `float32`.
    The data flow in bioimage.io models is explained
    [in this diagram.](https://docs.google.com/drawings/d/1FTw8-Rn6a6nXdkZ_SkMumtcjvur9mtIhRqLwnKqZNHM/edit)."""

    shape: Annotated[
        Union[Sequence[int], ParameterizedInputShape],
        Field(
            examples=[(1, 512, 512, 1), dict(min=(1, 64, 64, 1), step=(0, 32, 32, 0))]
        ),
    ]
    """Specification of input tensor shape."""

    preprocessing: List[PreprocessingDescr] = Field(
        default_factory=cast(  # TODO: (py>3.8) use list[PreprocessingDescr]
            Callable[[], List[PreprocessingDescr]], list
        )
    )
    """Description of how this input should be preprocessed."""

    @model_validator(mode="after")
    def zero_batch_step_and_one_batch_size(self) -> Self:
        # If there is a batch axis "b", the (minimum) shape must be 1 along it.
        # For a parameterized shape the step along "b" must also be 0.
        bidx = self.axes.find("b")
        if bidx == -1:
            return self

        if isinstance(self.shape, ParameterizedInputShape):
            step = self.shape.step
            shape = self.shape.min
            if step[bidx] != 0:
                raise ValueError(
                    "Input shape step has to be zero in the batch dimension (the batch"
                    + " dimension can always be increased, but `step` should specify how"
                    + " to increase the minimal shape to find the largest single batch"
                    + " shape)"
                )
        else:
            shape = self.shape

        if shape[bidx] != 1:
            raise ValueError("Input shape has to be 1 in the batch dimension b.")

        return self

    @model_validator(mode="after")
    def validate_preprocessing_kwargs(self) -> Self:
        # Any string `axes` in a preprocessing step's kwargs may only use axes of this
        # tensor. Assumes the kwargs object (a `KwargsNode`) offers a dict-like `.get`;
        # confirm in `.._internal.common_nodes`.
        for p in self.preprocessing:
            kwargs_axes = p.kwargs.get("axes")
            if isinstance(kwargs_axes, str) and any(
                a not in self.axes for a in kwargs_axes
            ):
                raise ValueError("`kwargs.axes` needs to be subset of `axes`")

        return self

970 

971 

class OutputTensorDescr(TensorDescrBase):
    # element type produced for this model output
    data_type: Literal[
        "float32",
        "float64",
        "uint8",
        "int8",
        "uint16",
        "int16",
        "uint32",
        "int32",
        "uint64",
        "int64",
        "bool",
    ]
    """Data type.
    The data flow in bioimage.io models is explained
    [in this diagram.](https://docs.google.com/drawings/d/1FTw8-Rn6a6nXdkZ_SkMumtcjvur9mtIhRqLwnKqZNHM/edit)."""

    # explicit shape, or a shape derived from a reference tensor
    shape: Union[Sequence[int], ImplicitOutputShape]
    """Output tensor shape."""

    halo: Optional[Sequence[int]] = None
    """The `halo` that should be cropped from the output tensor to avoid boundary effects.
    The `halo` is to be cropped from both sides, i.e. `shape_after_crop = shape - 2 * halo`.
    To document a `halo` that is already cropped by the model `shape.offset` has to be used instead."""

    # empty list when no postprocessing is declared
    postprocessing: List[PostprocessingDescr] = Field(
        default_factory=cast(Callable[[], List[PostprocessingDescr]], list)
    )
    """Description of how this output should be postprocessed."""

    @model_validator(mode="after")
    def matching_halo_length(self) -> Self:
        """Require a non-empty `halo` to have one entry per dimension of `shape`.

        Raises:
            ValueError: on a length mismatch.
        """
        if self.halo and len(self.halo) != len(self.shape):
            raise ValueError(
                f"halo {self.halo} has to have same length as shape {self.shape}!"
            )

        return self

    @model_validator(mode="after")
    def validate_postprocessing_kwargs(self) -> Self:
        """Require `kwargs.axes` of each postprocessing step to only use axes of
        this tensor.

        Steps without an `axes` kwarg and steps whose `axes` is left unset
        (`None`) are not constrained, in line with the preprocessing check of
        `InputTensorDescr`.

        Raises:
            ValueError: if `axes` is set to a non-string value, or names an axis
                missing from `axes`.
        """
        for p in self.postprocessing:
            kwargs_axes = p.kwargs.get("axes", "")
            if kwargs_axes is None:
                # `axes` is declared by the step's kwargs but was not given;
                # previously this was rejected as "Expected None to be a string".
                continue

            if not isinstance(kwargs_axes, str):
                raise ValueError(f"Expected {kwargs_axes} to be a string")

            if any(a not in self.axes for a in kwargs_axes):
                raise ValueError("`kwargs.axes` needs to be subset of axes")

        return self

1023 

1024 

# Run mode names known to this spec version; other strings are accepted by
# `RunMode.name` as well, but `warn` flags them as unknown.
KnownRunMode = Literal["deepimagej"]


class RunMode(Node):
    # either a known run mode or an arbitrary string (with a warning)
    name: Annotated[
        Union[KnownRunMode, str],
        warn(KnownRunMode, "Unknown run mode '{value}'."),
    ]
    """Run mode name"""

    # free-form options; an empty dict unless given
    kwargs: Dict[str, Any] = Field(
        default_factory=cast(Callable[[], Dict[str, Any]], dict)
    )
    """Run mode specific key word arguments"""

1038 

1039 

class LinkedModel(Node):
    """Reference to a bioimage.io model."""

    # required model identifier
    id: Annotated[
        ModelId,
        Field(examples=["affable-shark", "ambitious-sloth"]),
    ]
    """A valid model `id` from the bioimage.io collection."""

    # optional; defaults to None
    version_number: Optional[int] = None
    """version number (n-th published version, not the semantic version) of linked model"""

1048 

1049 

def package_weights(
    value: Node,  # Union[v0_4.WeightsDescr, v0_5.WeightsDescr]
    handler: SerializerFunctionWrapHandler,
    info: SerializationInfo,
):
    """Wrap-serializer for `weights` that keeps a single format while packaging.

    Without an active packaging context, or when that context has no
    `weights_priority_order`, `value` is serialized unchanged. Otherwise the
    first format in `weights_priority_order` that is set on `value` is selected,
    and a new weights description containing only that format is serialized.

    Raises:
        ValueError: if none of the prioritized weight formats is set on `value`.
    """
    ctxt = packaging_context_var.get()
    priority_order = None if ctxt is None else ctxt.weights_priority_order

    if priority_order is not None:
        for weights_format in priority_order:
            entry = getattr(value, weights_format, None)
            if entry is not None:
                break
        else:
            raise ValueError(
                "None of the weight formats in `weights_priority_order`"
                f" ({priority_order}) is present in the given model."
            )

        assert isinstance(entry, Node), type(entry)
        # copy the selected entry without its `parent` field ...
        stripped_entry = entry.model_construct(
            **{key: val for key, val in entry if key != "parent"}
        )
        # ... and build a weights description holding only that format
        value = value.model_construct(None, **{weights_format: stripped_entry})

    return handler(
        value,
        info,  # pyright: ignore[reportArgumentType] # taken from pydantic docs
    )

1076 

1077 

1078class ModelDescr(GenericModelDescrBase): 

1079 """Specification of the fields used in a bioimage.io-compliant RDF that describes AI models with pretrained weights. 

1080 

1081 These fields are typically stored in a YAML file which we call a model resource description file (model RDF). 

1082 """ 

1083 

1084 implemented_format_version: ClassVar[Literal["0.4.10"]] = "0.4.10" 

1085 if TYPE_CHECKING: 

1086 format_version: Literal["0.4.10"] = "0.4.10" 

1087 else: 

1088 format_version: Literal["0.4.10"] 

1089 """Version of the bioimage.io model description specification used. 

1090 When creating a new model always use the latest micro/patch version described here. 

1091 The `format_version` is important for any consumer software to understand how to parse the fields. 

1092 """ 

1093 

1094 implemented_type: ClassVar[Literal["model"]] = "model" 

1095 if TYPE_CHECKING: 

1096 type: Literal["model"] = "model" 

1097 else: 

1098 type: Literal["model"] 

1099 """Specialized resource type 'model'""" 

1100 

1101 id: Optional[ModelId] = None 

1102 """bioimage.io-wide unique resource identifier 

1103 assigned by bioimage.io; version **un**specific.""" 

1104 

1105 authors: NotEmpty[ # pyright: ignore[reportGeneralTypeIssues] # make mandatory 

1106 List[Author] 

1107 ] 

1108 """The authors are the creators of the model RDF and the primary points of contact.""" 

1109 

1110 documentation: Annotated[ 

1111 FileSource_, 

1112 Field( 

1113 examples=[ 

1114 "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/unet2d_nuclei_broad/README.md", 

1115 "README.md", 

1116 ], 

1117 ), 

1118 ] 

1119 """URL or relative path to a markdown file with additional documentation. 

1120 The recommended documentation file name is `README.md`. An `.md` suffix is mandatory. 

1121 The documentation should include a '[#[#]]# Validation' (sub)section 

1122 with details on how to quantitatively validate the model on unseen data.""" 

1123 

1124 inputs: NotEmpty[List[InputTensorDescr]] 

1125 """Describes the input tensors expected by this model.""" 

1126 

1127 license: Annotated[ 

1128 Union[LicenseId, str], 

1129 warn(LicenseId, "Unknown license id '{value}'."), 

1130 Field(examples=["CC0-1.0", "MIT", "BSD-2-Clause"]), 

1131 ] 

1132 """A [SPDX license identifier](https://spdx.org/licenses/). 

1133 We do notsupport custom license beyond the SPDX license list, if you need that please 

1134 [open a GitHub issue](https://github.com/bioimage-io/spec-bioimage-io/issues/new/choose 

1135 ) to discuss your intentions with the community.""" 

1136 

1137 name: Annotated[ 

1138 str, 

1139 MinLen(1), 

1140 warn(MinLen(5), "Name shorter than 5 characters.", INFO), 

1141 warn(MaxLen(64), "Name longer than 64 characters.", INFO), 

1142 ] 

1143 """A human-readable name of this model. 

1144 It should be no longer than 64 characters and only contain letter, number, underscore, minus or space characters.""" 

1145 

1146 outputs: NotEmpty[List[OutputTensorDescr]] 

1147 """Describes the output tensors.""" 

1148 

1149 @field_validator("inputs", "outputs") 

1150 @classmethod 

1151 def unique_tensor_descr_names( 

1152 cls, value: Sequence[Union[InputTensorDescr, OutputTensorDescr]] 

1153 ) -> Sequence[Union[InputTensorDescr, OutputTensorDescr]]: 

1154 unique_names = {str(v.name) for v in value} 

1155 if len(unique_names) != len(value): 

1156 raise ValueError("Duplicate tensor descriptor names") 

1157 

1158 return value 

1159 

1160 @model_validator(mode="after") 

1161 def unique_io_names(self) -> Self: 

1162 unique_names = {str(ss.name) for s in (self.inputs, self.outputs) for ss in s} 

1163 if len(unique_names) != (len(self.inputs) + len(self.outputs)): 

1164 raise ValueError("Duplicate tensor descriptor names across inputs/outputs") 

1165 

1166 return self 

1167 

1168 @model_validator(mode="after") 

1169 def minimum_shape2valid_output(self) -> Self: 

1170 tensors_by_name: Dict[ 

1171 TensorName, Union[InputTensorDescr, OutputTensorDescr] 

1172 ] = {t.name: t for t in self.inputs + self.outputs} 

1173 

1174 for out in self.outputs: 

1175 if isinstance(out.shape, ImplicitOutputShape): 

1176 ndim_ref = len(tensors_by_name[out.shape.reference_tensor].shape) 

1177 ndim_out_ref = len( 

1178 [scale for scale in out.shape.scale if scale is not None] 

1179 ) 

1180 if ndim_ref != ndim_out_ref: 

1181 expanded_dim_note = ( 

1182 " Note that expanded dimensions (`scale`: null) are not" 

1183 + f" counted for {out.name}'sdimensionality here." 

1184 if None in out.shape.scale 

1185 else "" 

1186 ) 

1187 raise ValueError( 

1188 f"Referenced tensor '{out.shape.reference_tensor}' with" 

1189 + f" {ndim_ref} dimensions does not match output tensor" 

1190 + f" '{out.name}' with" 

1191 + f" {ndim_out_ref} dimensions.{expanded_dim_note}" 

1192 ) 

1193 

1194 min_out_shape = self._get_min_shape(out, tensors_by_name) 

1195 if out.halo: 

1196 halo = out.halo 

1197 halo_msg = f" for halo {out.halo}" 

1198 else: 

1199 halo = [0] * len(min_out_shape) 

1200 halo_msg = "" 

1201 

1202 if any([s - 2 * h < 1 for s, h in zip(min_out_shape, halo)]): 

1203 raise ValueError( 

1204 f"Minimal shape {min_out_shape} of output {out.name} is too" 

1205 + f" small{halo_msg}." 

1206 ) 

1207 

1208 return self 

1209 

1210 @classmethod 

1211 def _get_min_shape( 

1212 cls, 

1213 t: Union[InputTensorDescr, OutputTensorDescr], 

1214 tensors_by_name: Dict[TensorName, Union[InputTensorDescr, OutputTensorDescr]], 

1215 ) -> Sequence[int]: 

1216 """output with subtracted halo has to result in meaningful output even for the minimal input 

1217 see https://github.com/bioimage-io/spec-bioimage-io/issues/392 

1218 """ 

1219 if isinstance(t.shape, collections.abc.Sequence): 

1220 return t.shape 

1221 elif isinstance(t.shape, ParameterizedInputShape): 

1222 return t.shape.min 

1223 elif isinstance(t.shape, ImplicitOutputShape): 

1224 pass 

1225 else: 

1226 assert_never(t.shape) 

1227 

1228 ref_shape = cls._get_min_shape( 

1229 tensors_by_name[t.shape.reference_tensor], tensors_by_name 

1230 ) 

1231 

1232 if None not in t.shape.scale: 

1233 scale: Sequence[float, ...] = t.shape.scale # type: ignore 

1234 else: 

1235 expanded_dims = [idx for idx, sc in enumerate(t.shape.scale) if sc is None] 

1236 new_ref_shape: List[int] = [] 

1237 for idx in range(len(t.shape.scale)): 

1238 ref_idx = idx - sum(int(exp < idx) for exp in expanded_dims) 

1239 new_ref_shape.append(1 if idx in expanded_dims else ref_shape[ref_idx]) 

1240 

1241 ref_shape = new_ref_shape 

1242 assert len(ref_shape) == len(t.shape.scale) 

1243 scale = [0.0 if sc is None else sc for sc in t.shape.scale] 

1244 

1245 offset = t.shape.offset 

1246 assert len(offset) == len(scale) 

1247 return [int(rs * s + 2 * off) for rs, s, off in zip(ref_shape, scale, offset)] 

1248 

1249 @model_validator(mode="after") 

1250 def validate_tensor_references_in_inputs(self) -> Self: 

1251 for t in self.inputs: 

1252 for proc in t.preprocessing: 

1253 if "reference_tensor" not in proc.kwargs: 

1254 continue 

1255 

1256 ref_tensor = proc.kwargs["reference_tensor"] 

1257 if ref_tensor is not None and str(ref_tensor) not in { 

1258 str(t.name) for t in self.inputs 

1259 }: 

1260 raise ValueError(f"'{ref_tensor}' not found in inputs") 

1261 

1262 if ref_tensor == t.name: 

1263 raise ValueError( 

1264 f"invalid self reference for preprocessing of tensor {t.name}" 

1265 ) 

1266 

1267 return self 

1268 

1269 @model_validator(mode="after") 

1270 def validate_tensor_references_in_outputs(self) -> Self: 

1271 for t in self.outputs: 

1272 for proc in t.postprocessing: 

1273 if "reference_tensor" not in proc.kwargs: 

1274 continue 

1275 ref_tensor = proc.kwargs["reference_tensor"] 

1276 if ref_tensor is not None and str(ref_tensor) not in { 

1277 str(t.name) for t in self.inputs 

1278 }: 

1279 raise ValueError(f"{ref_tensor} not found in inputs") 

1280 

1281 return self 

1282 

1283 packaged_by: List[Author] = Field( 

1284 default_factory=cast(Callable[[], List[Author]], list) 

1285 ) 

1286 """The persons that have packaged and uploaded this model. 

1287 Only required if those persons differ from the `authors`.""" 

1288 

1289 parent: Optional[LinkedModel] = None 

1290 """The model from which this model is derived, e.g. by fine-tuning the weights.""" 

1291 

1292 @field_validator("parent", mode="before") 

1293 @classmethod 

1294 def ignore_url_parent(cls, parent: Any): 

1295 if isinstance(parent, dict): 

1296 return None 

1297 

1298 else: 

1299 return parent 

1300 

1301 run_mode: Optional[RunMode] = None 

1302 """Custom run mode for this model: for more complex prediction procedures like test time 

1303 data augmentation that currently cannot be expressed in the specification. 

1304 No standard run modes are defined yet.""" 

1305 

1306 sample_inputs: List[FileSource_] = Field( 

1307 default_factory=cast(Callable[[], List[FileSource_]], list) 

1308 ) 

1309 """URLs/relative paths to sample inputs to illustrate possible inputs for the model, 

1310 for example stored as PNG or TIFF images. 

1311 The sample files primarily serve to inform a human user about an example use case""" 

1312 

1313 sample_outputs: List[FileSource_] = Field( 

1314 default_factory=cast(Callable[[], List[FileSource_]], list) 

1315 ) 

1316 """URLs/relative paths to sample outputs corresponding to the `sample_inputs`.""" 

1317 

1318 test_inputs: NotEmpty[ 

1319 List[Annotated[FileSource_, WithSuffix(".npy", case_sensitive=True)]] 

1320 ] 

1321 """Test input tensors compatible with the `inputs` description for a **single test case**. 

1322 This means if your model has more than one input, you should provide one URL/relative path for each input. 

1323 Each test input should be a file with an ndarray in 

1324 [numpy.lib file format](https://numpy.org/doc/stable/reference/generated/numpy.lib.format.html#module-numpy.lib.format). 

1325 The extension must be '.npy'.""" 

1326 

1327 test_outputs: NotEmpty[ 

1328 List[Annotated[FileSource_, WithSuffix(".npy", case_sensitive=True)]] 

1329 ] 

1330 """Analog to `test_inputs`.""" 

1331 

1332 timestamp: Datetime 

1333 """Timestamp in [ISO 8601](#https://en.wikipedia.org/wiki/ISO_8601) format 

1334 with a few restrictions listed [here](https://docs.python.org/3/library/datetime.html#datetime.datetime.fromisoformat).""" 

1335 

1336 training_data: Union[LinkedDataset, DatasetDescr, None] = None 

1337 """The dataset used to train this model""" 

1338 

1339 weights: Annotated[WeightsDescr, WrapSerializer(package_weights)] 

1340 """The weights for this model. 

1341 Weights can be given for different formats, but should otherwise be equivalent. 

1342 The available weight formats determine which consumers can use this model.""" 

1343 

1344 @model_validator(mode="before") 

1345 @classmethod 

1346 def _convert_from_older_format( 

1347 cls, data: BioimageioYamlContent, / 

1348 ) -> BioimageioYamlContent: 

1349 convert_from_older_format(data) 

1350 return data 

1351 

1352 def get_input_test_arrays(self) -> List[NDArray[Any]]: 

1353 data = [load_array(ipt) for ipt in self.test_inputs] 

1354 assert all(isinstance(d, np.ndarray) for d in data) 

1355 return data 

1356 

1357 def get_output_test_arrays(self) -> List[NDArray[Any]]: 

1358 data = [load_array(out) for out in self.test_outputs] 

1359 assert all(isinstance(d, np.ndarray) for d in data) 

1360 return data