Coverage for bioimageio/spec/model/v0_4.py: 90%

595 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-27 09:20 +0000

1from __future__ import annotations 

2 

3import collections.abc 

4from typing import ( 

5 TYPE_CHECKING, 

6 Any, 

7 Callable, 

8 ClassVar, 

9 Dict, 

10 List, 

11 Literal, 

12 Optional, 

13 Sequence, 

14 Tuple, 

15 Type, 

16 Union, 

17 cast, 

18) 

19 

20import numpy as np 

21from annotated_types import Ge, Interval, MaxLen, MinLen, MultipleOf 

22from numpy.typing import NDArray 

23from pydantic import ( 

24 AllowInfNan, 

25 Discriminator, 

26 Field, 

27 RootModel, 

28 SerializationInfo, 

29 SerializerFunctionWrapHandler, 

30 StringConstraints, 

31 TypeAdapter, 

32 ValidationInfo, 

33 WrapSerializer, 

34 field_validator, 

35 model_validator, 

36) 

37from typing_extensions import Annotated, Self, assert_never, get_args 

38 

39from .._internal.common_nodes import ( 

40 KwargsNode, 

41 Node, 

42 NodeWithExplicitlySetFields, 

43) 

44from .._internal.constants import SHA256_HINT 

45from .._internal.field_validation import validate_unique_entries 

46from .._internal.field_warning import issue_warning, warn 

47from .._internal.io import BioimageioYamlContent, WithSuffix 

48from .._internal.io import FileDescr as FileDescr 

49from .._internal.io_basics import Sha256 as Sha256 

50from .._internal.io_packaging import include_in_package 

51from .._internal.io_utils import load_array 

52from .._internal.packaging_context import packaging_context_var 

53from .._internal.types import Datetime as Datetime 

54from .._internal.types import FileSource_, LowerCaseIdentifier 

55from .._internal.types import Identifier as Identifier 

56from .._internal.types import LicenseId as LicenseId 

57from .._internal.types import NotEmpty as NotEmpty 

58from .._internal.url import HttpUrl as HttpUrl 

59from .._internal.validated_string_with_inner_node import ValidatedStringWithInnerNode 

60from .._internal.validator_annotations import AfterValidator, RestrictCharacters 

61from .._internal.version_type import Version as Version 

62from .._internal.warning_levels import ALERT, INFO 

63from ..dataset.v0_2 import VALID_COVER_IMAGE_EXTENSIONS as VALID_COVER_IMAGE_EXTENSIONS 

64from ..dataset.v0_2 import DatasetDescr as DatasetDescr 

65from ..dataset.v0_2 import LinkedDataset as LinkedDataset 

66from ..generic.v0_2 import AttachmentsDescr as AttachmentsDescr 

67from ..generic.v0_2 import Author as Author 

68from ..generic.v0_2 import BadgeDescr as BadgeDescr 

69from ..generic.v0_2 import CiteEntry as CiteEntry 

70from ..generic.v0_2 import Doi as Doi 

71from ..generic.v0_2 import GenericModelDescrBase 

72from ..generic.v0_2 import LinkedResource as LinkedResource 

73from ..generic.v0_2 import Maintainer as Maintainer 

74from ..generic.v0_2 import OrcidId as OrcidId 

75from ..generic.v0_2 import RelativeFilePath as RelativeFilePath 

76from ..generic.v0_2 import ResourceId as ResourceId 

77from ..generic.v0_2 import Uploader as Uploader 

78from ._v0_4_converter import convert_from_older_format 

79 

80 

class ModelId(ResourceId):
    # Dedicated ID type for models; adds nothing on top of `ResourceId`.
    pass

83 

84 

# Axes string for tensors: characters restricted to "bitczyx"; the value is
# additionally passed through `validate_unique_entries`.
AxesStr = Annotated[
    str,
    RestrictCharacters("bitczyx"),
    AfterValidator(validate_unique_entries),
]
# Same construction as `AxesStr`, restricted to the characters "czyx".
AxesInCZYX = Annotated[
    str,
    RestrictCharacters("czyx"),
    AfterValidator(validate_unique_entries),
]

# Names accepted for postprocessing steps.
PostprocessingName = Literal[
    "binarize",
    "clip",
    "scale_linear",
    "sigmoid",
    "zero_mean_unit_variance",
    "scale_range",
    "scale_mean_variance",
]
# Names accepted for preprocessing steps (no "scale_mean_variance").
PreprocessingName = Literal[
    "binarize",
    "clip",
    "scale_linear",
    "sigmoid",
    "zero_mean_unit_variance",
    "scale_range",
]

109 

110 

class TensorName(LowerCaseIdentifier):
    # Dedicated type for tensor names; adds nothing on top of
    # `LowerCaseIdentifier`.
    pass

113 

114 

class CallableFromDepencencyNode(Node):
    # Node holding the two parts of a `<module>.<callable>` reference.

    # Adapter used to validate each dot-separated part of `module_name`.
    _submodule_adapter: ClassVar[TypeAdapter[Identifier]] = TypeAdapter(Identifier)

    module_name: str
    """The Python module that implements **callable_name**."""

    @field_validator("module_name", mode="after")
    @classmethod
    def _check_submodules(cls, module_name: str) -> str:
        # Each dot-separated component has to validate as an `Identifier`.
        # Any validation error raised by the adapter propagates; the value
        # itself is returned unchanged.
        for submod in module_name.split("."):
            _ = cls._submodule_adapter.validate_python(submod)

        return module_name

    callable_name: Identifier
    """The callable Python identifier implemented in module **module_name**."""

130 

131 

class CallableFromDepencency(ValidatedStringWithInnerNode[CallableFromDepencencyNode]):
    # String of the form `<module path>.<callable name>`; the parts are
    # exposed through the inner `CallableFromDepencencyNode`.
    _inner_node_class = CallableFromDepencencyNode
    root_model: ClassVar[Type[RootModel[Any]]] = RootModel[
        Annotated[
            str,
            StringConstraints(strip_whitespace=True, pattern=r"^.+\..+$"),
        ]
    ]

    @classmethod
    def _get_data(cls, valid_string_data: str):
        # Everything before the last "." is the module path, the rest is the
        # callable name.
        module_path, _, callname = valid_string_data.rpartition(".")
        return {"module_name": module_path, "callable_name": callname}

    @property
    def module_name(self):
        """The Python module that implements **callable_name**."""
        return self._inner_node.module_name

    @property
    def callable_name(self):
        """The callable Python identifier implemented in module **module_name**."""
        return self._inner_node.callable_name

155 

156 

class CallableFromFileNode(Node):
    # Node holding the two parts of a `<source file>:<callable>` reference.

    # Relative path is tried before URL (`union_mode="left_to_right"`); the
    # field also carries the `include_in_package` annotation.
    source_file: Annotated[
        Union[RelativeFilePath, HttpUrl],
        Field(union_mode="left_to_right"),
        include_in_package,
    ]
    """The Python source file that implements **callable_name**."""
    callable_name: Identifier
    """The callable Python identifier implemented in **source_file**."""

166 

167 

class CallableFromFile(ValidatedStringWithInnerNode[CallableFromFileNode]):
    # String of the form `<source file>:<callable name>`; the parts are
    # exposed through the inner `CallableFromFileNode`.
    _inner_node_class = CallableFromFileNode
    root_model: ClassVar[Type[RootModel[Any]]] = RootModel[
        Annotated[
            str,
            StringConstraints(strip_whitespace=True, pattern=r"^.+:.+$"),
        ]
    ]

    @classmethod
    def _get_data(cls, valid_string_data: str):
        # Split at the LAST ":" so that colons inside the file part (e.g. in
        # a URL) stay with the source file.
        file_part, _, callname = valid_string_data.rpartition(":")
        return {"source_file": file_part, "callable_name": callname}

    @property
    def source_file(self):
        """The Python source file that implements **callable_name**."""
        return self._inner_node.source_file

    @property
    def callable_name(self):
        """The callable Python identifier implemented in **source_file**."""
        return self._inner_node.callable_name

191 

192 

# A callable given either as `<file>:<name>` or as `<module>.<name>`; the
# file-based form is tried first (`union_mode="left_to_right"`).
CustomCallable = Annotated[
    Union[CallableFromFile, CallableFromDepencency],
    Field(union_mode="left_to_right"),
]

196 

197 

class DependenciesNode(Node):
    # Node holding the two parts of a `<manager>:<file>` dependency string.

    manager: Annotated[NotEmpty[str], Field(examples=["conda", "maven", "pip"])]
    """Dependency manager"""

    file: FileSource_
    """Dependency file"""

204 

205 

class Dependencies(ValidatedStringWithInnerNode[DependenciesNode]):
    # String of the form `<manager>:<file>`; the parts are exposed through
    # the inner `DependenciesNode`.
    _inner_node_class = DependenciesNode
    root_model: ClassVar[Type[RootModel[Any]]] = RootModel[
        Annotated[
            str,
            StringConstraints(strip_whitespace=True, pattern=r"^.+:.+$"),
        ]
    ]

    @classmethod
    def _get_data(cls, valid_string_data: str):
        # Split at the FIRST ":" so that colons inside the file part stay
        # with the file.
        manager, _, file_part = valid_string_data.partition(":")
        return {"manager": manager, "file": file_part}

    @property
    def manager(self):
        """Dependency manager"""
        return self._inner_node.manager

    @property
    def file(self):
        """Dependency file"""
        return self._inner_node.file

229 

230 

# Closed set of weights format keys; each one is also a field name of
# `WeightsDescr`.
WeightsFormat = Literal[
    "keras_hdf5",
    "onnx",
    "pytorch_state_dict",
    "tensorflow_js",
    "tensorflow_saved_model_bundle",
    "torchscript",
]

239 

240 

class WeightsEntryDescrBase(FileDescr):
    # Common fields of all weights entries. Subclasses set the class-level
    # `type` and `weights_format_name`.
    type: ClassVar[WeightsFormat]
    weights_format_name: ClassVar[str]  # human readable

    source: FileSource_
    """The weights file."""

    attachments: Annotated[
        Union[AttachmentsDescr, None],
        warn(None, "Weights entry depends on additional attachments.", ALERT),
    ] = None
    """Attachments that are specific to this weights entry."""

    authors: Union[List[Author], None] = None
    """Authors
    Either the person(s) that have trained this model resulting in the original weights file.
    (If this is the initial weights entry, i.e. it does not have a `parent`)
    Or the person(s) who have converted the weights to this weights format.
    (If this is a child weight, i.e. it has a `parent` field)
    """

    dependencies: Annotated[
        Optional[Dependencies],
        warn(
            None,
            "Custom dependencies ({value}) specified. Avoid this whenever possible "
            "to allow execution in a wider range of software environments.",
        ),
        Field(
            examples=[
                "conda:environment.yaml",
                "maven:./pom.xml",
                "pip:./requirements.txt",
            ]
        ),
    ] = None
    """Dependency manager and dependency file, specified as `<dependency manager>:<relative file path>`."""

    parent: Annotated[
        Optional[WeightsFormat], Field(examples=["pytorch_state_dict"])
    ] = None
    """The source weights these weights were converted from.
    For example, if a model's weights were converted from the `pytorch_state_dict` format to `torchscript`,
    The `pytorch_state_dict` weights entry has no `parent` and is the parent of the `torchscript` weights.
    All weight entries except one (the initial set of weights resulting from training the model),
    need to have this field."""

    @model_validator(mode="after")
    def check_parent_is_not_self(self) -> Self:
        # `type` is the format of this entry; `parent` names the format it
        # was converted from, so the two must differ.
        is_own_parent = self.type == self.parent
        if is_own_parent:
            raise ValueError("Weights entry can't be it's own parent.")

        return self

294 

295 

class KerasHdf5WeightsDescr(WeightsEntryDescrBase):
    # Weights entry for the "keras_hdf5" format.
    type = "keras_hdf5"
    weights_format_name: ClassVar[str] = "Keras HDF5"
    tensorflow_version: Optional[Version] = None
    """TensorFlow version used to create these weights"""

    @field_validator("tensorflow_version", mode="after")
    @classmethod
    def _tfv(cls, value: Any):
        # A given version needs no further handling.
        if value is not None:
            return value

        # Missing version: warn (ALERT severity) but still accept the value.
        issue_warning(
            "missing. Please specify the TensorFlow version"
            " these weights were created with.",
            value=value,
            severity=ALERT,
            field="tensorflow_version",
        )
        return value

314 

315 

class OnnxWeightsDescr(WeightsEntryDescrBase):
    # Weights entry for the "onnx" format.
    type = "onnx"
    weights_format_name: ClassVar[str] = "ONNX"
    opset_version: Optional[Annotated[int, Ge(7)]] = None
    """ONNX opset version"""

    @field_validator("opset_version", mode="after")
    @classmethod
    def _ov(cls, value: Any):
        # A given opset version needs no further handling.
        if value is not None:
            return value

        # Missing version: warn (ALERT severity) but still accept the value.
        issue_warning(
            "Missing ONNX opset version (aka ONNX opset number). "
            "Please specify the ONNX opset version these weights were created"
            " with.",
            value=value,
            severity=ALERT,
            field="opset_version",
        )
        return value

335 

336 

class PytorchStateDictWeightsDescr(WeightsEntryDescrBase):
    # Weights entry for the "pytorch_state_dict" format; additionally needs
    # the architecture callable that builds the network.
    type = "pytorch_state_dict"
    weights_format_name: ClassVar[str] = "Pytorch State Dict"
    architecture: CustomCallable = Field(
        examples=["my_function.py:MyNetworkClass", "my_module.submodule.get_my_model"]
    )
    """callable returning a torch.nn.Module instance.
    Local implementation: `<relative path to file>:<identifier of implementation within the file>`.
    Implementation in a dependency: `<dependency-package>.<[dependency-module]>.<identifier>`."""

    architecture_sha256: Annotated[
        Optional[Sha256],
        Field(
            description=(
                "The SHA256 of the architecture source file, if the architecture is not"
                " defined in a module listed in `dependencies`\n"
            )
            + SHA256_HINT,
        ),
    ] = None
    """The SHA256 of the architecture source file,
    if the architecture is not defined in a module listed in `dependencies`"""

    @model_validator(mode="after")
    def check_architecture_sha256(self) -> Self:
        # The hash is required exactly when the architecture comes from a
        # source file, and must be absent otherwise.
        from_file = isinstance(self.architecture, CallableFromFile)
        has_sha = self.architecture_sha256 is not None

        if from_file and not has_sha:
            raise ValueError(
                "Missing required `architecture_sha256` for `architecture` with"
                " source file."
            )

        if has_sha and not from_file:
            raise ValueError(
                "Got `architecture_sha256` for architecture that does not have a source"
                " file."
            )

        return self

    kwargs: Dict[str, Any] = Field(
        default_factory=cast(Callable[[], Dict[str, Any]], dict)
    )
    """key word arguments for the `architecture` callable"""

    pytorch_version: Optional[Version] = None
    """Version of the PyTorch library used.
    If `depencencies` is specified it should include pytorch and the verison has to match.
    (`dependencies` overrules `pytorch_version`)"""

    @field_validator("pytorch_version", mode="after")
    @classmethod
    def _ptv(cls, value: Any):
        # A given version needs no further handling.
        if value is not None:
            return value

        # Missing version: warn (ALERT severity) but still accept the value.
        issue_warning(
            "missing. Please specify the PyTorch version these"
            " PyTorch state dict weights were created with.",
            value=value,
            severity=ALERT,
            field="pytorch_version",
        )
        return value

398 

399 

class TorchscriptWeightsDescr(WeightsEntryDescrBase):
    # Weights entry for the "torchscript" format.
    type = "torchscript"
    weights_format_name: ClassVar[str] = "TorchScript"
    pytorch_version: Optional[Version] = None
    """Version of the PyTorch library used."""

    @field_validator("pytorch_version", mode="after")
    @classmethod
    def _ptv(cls, value: Any):
        # A given version needs no further handling.
        if value is not None:
            return value

        # Missing version: warn (ALERT severity) but still accept the value.
        issue_warning(
            "missing. Please specify the PyTorch version these"
            " Torchscript weights were created with.",
            value=value,
            severity=ALERT,
            field="pytorch_version",
        )
        return value

418 

419 

class TensorflowJsWeightsDescr(WeightsEntryDescrBase):
    # Weights entry for the "tensorflow_js" format.
    type = "tensorflow_js"
    weights_format_name: ClassVar[str] = "Tensorflow.js"
    tensorflow_version: Optional[Version] = None
    """Version of the TensorFlow library used."""

    @field_validator("tensorflow_version", mode="after")
    @classmethod
    def _tfv(cls, value: Any):
        # A given version needs no further handling.
        if value is not None:
            return value

        # Missing version: warn (ALERT severity) but still accept the value.
        issue_warning(
            "missing. Please specify the TensorFlow version"
            " these TensorflowJs weights were created with.",
            value=value,
            severity=ALERT,
            field="tensorflow_version",
        )
        return value

    # Re-declared with the same type as in the base class to attach a
    # format-specific docstring.
    source: FileSource_
    """The multi-file weights.
    All required files/folders should be a zip archive."""

442 

443 

class TensorflowSavedModelBundleWeightsDescr(WeightsEntryDescrBase):
    # Weights entry for the "tensorflow_saved_model_bundle" format.
    type = "tensorflow_saved_model_bundle"
    weights_format_name: ClassVar[str] = "Tensorflow Saved Model"
    tensorflow_version: Optional[Version] = None
    """Version of the TensorFlow library used."""

    @field_validator("tensorflow_version", mode="after")
    @classmethod
    def _tfv(cls, value: Any):
        # A given version needs no further handling.
        if value is not None:
            return value

        # Missing version: warn (ALERT severity) but still accept the value.
        issue_warning(
            "missing. Please specify the TensorFlow version"
            " these Tensorflow saved model bundle weights were created with.",
            value=value,
            severity=ALERT,
            field="tensorflow_version",
        )
        return value

462 

463 

class WeightsDescr(Node):
    # Container with one optional entry per weights format. At least one
    # entry has to be set (see `check_one_entry`).
    keras_hdf5: Optional[KerasHdf5WeightsDescr] = None
    onnx: Optional[OnnxWeightsDescr] = None
    pytorch_state_dict: Optional[PytorchStateDictWeightsDescr] = None
    tensorflow_js: Optional[TensorflowJsWeightsDescr] = None
    tensorflow_saved_model_bundle: Optional[TensorflowSavedModelBundleWeightsDescr] = (
        None
    )
    torchscript: Optional[TorchscriptWeightsDescr] = None

    @model_validator(mode="after")
    def check_one_entry(self) -> Self:
        # `available_formats` is empty exactly when every entry is None.
        if not self.available_formats:
            raise ValueError("Missing weights entry")

        return self

    def __getitem__(
        self,
        key: WeightsFormat,
    ):
        # Look up the entry for the weights format `key`.
        # Raises KeyError for an unknown key as well as for a known format
        # whose entry is not set.
        # The explicit chain (rather than `getattr`) keeps the precise
        # per-format types visible to static type checkers.
        if key == "keras_hdf5":
            ret = self.keras_hdf5
        elif key == "onnx":
            ret = self.onnx
        elif key == "pytorch_state_dict":
            ret = self.pytorch_state_dict
        elif key == "tensorflow_js":
            ret = self.tensorflow_js
        elif key == "tensorflow_saved_model_bundle":
            ret = self.tensorflow_saved_model_bundle
        elif key == "torchscript":
            ret = self.torchscript
        else:
            raise KeyError(key)

        if ret is None:
            raise KeyError(key)

        return ret

    @property
    def available_formats(self):
        # Mapping of format name -> entry for all entries that are set.
        # A new dict is built on every access.
        return {
            **({} if self.keras_hdf5 is None else {"keras_hdf5": self.keras_hdf5}),
            **({} if self.onnx is None else {"onnx": self.onnx}),
            **(
                {}
                if self.pytorch_state_dict is None
                else {"pytorch_state_dict": self.pytorch_state_dict}
            ),
            **(
                {}
                if self.tensorflow_js is None
                else {"tensorflow_js": self.tensorflow_js}
            ),
            **(
                {}
                if self.tensorflow_saved_model_bundle is None
                else {
                    "tensorflow_saved_model_bundle": self.tensorflow_saved_model_bundle
                }
            ),
            **({} if self.torchscript is None else {"torchscript": self.torchscript}),
        }

    @property
    def missing_formats(self):
        # Set of format names for which no entry is set.
        # Evaluate the property once; it rebuilds its dict on every access.
        available = self.available_formats
        return {wf for wf in get_args(WeightsFormat) if wf not in available}

545 

546 

class ParameterizedInputShape(Node):
    """A sequence of valid shapes given by `shape_k = min + k * step for k in {0, 1, ...}`."""

    min: NotEmpty[List[int]]
    """The minimum input shape"""

    step: NotEmpty[List[int]]
    """The minimum shape change"""

    def __len__(self) -> int:
        # Number of dimensions, taken from `min`.
        return len(self.min)

    @model_validator(mode="after")
    def matching_lengths(self) -> Self:
        # `min` and `step` describe the same dimensions, entry by entry.
        n_min, n_step = len(self.min), len(self.step)
        if n_min != n_step:
            raise ValueError("`min` and `step` required to have the same length")

        return self

565 

566 

class ImplicitOutputShape(Node):
    """Output tensor shape depending on an input tensor shape.
    `shape(output_tensor) = shape(input_tensor) * scale + 2 * offset`"""

    reference_tensor: TensorName
    """Name of the reference tensor."""

    scale: NotEmpty[List[Optional[float]]]
    """output_pix/input_pix for each dimension.
    'null' values indicate new dimensions, whose length is defined by 2*`offset`"""

    offset: NotEmpty[List[Union[int, Annotated[float, MultipleOf(0.5)]]]]
    """Position of origin wrt to input."""

    def __len__(self) -> int:
        # Number of dimensions, taken from `scale`.
        return len(self.scale)

    @model_validator(mode="after")
    def matching_lengths(self) -> Self:
        if len(self.scale) != len(self.offset):
            raise ValueError(
                f"scale {self.scale} has to have same length as offset {self.offset}!"
            )

        # An expanded dimension (scale is None) gets its length from the
        # offset alone, so that offset must not be zero.
        expanded_without_offset = any(
            sc is None and not off for sc, off in zip(self.scale, self.offset)
        )
        if expanded_without_offset:
            raise ValueError("`offset` must not be zero if `scale` is none/zero")

        return self

596 

597 

class TensorDescrBase(Node):
    # Fields shared by input and output tensor descriptions.

    name: TensorName
    """Tensor name. No duplicates are allowed."""

    description: str = ""

    axes: AxesStr
    """Axes identifying characters. Same length and order as the axes in `shape`.
    | axis | description |
    | --- | --- |
    | b | batch (groups multiple samples) |
    | i | instance/index/element |
    | t | time |
    | c | channel |
    | z | spatial dimension z |
    | y | spatial dimension y |
    | x | spatial dimension x |
    """

    # Both bounds explicitly allow inf/nan (`AllowInfNan(True)`).
    data_range: Optional[
        Tuple[Annotated[float, AllowInfNan(True)], Annotated[float, AllowInfNan(True)]]
    ] = None
    """Tuple `(minimum, maximum)` specifying the allowed range of the data in this tensor.
    If not specified, the full data range that can be expressed in `data_type` is allowed."""

622 

623 

class ProcessingKwargs(KwargsNode):
    """base class for pre-/postprocessing key word arguments"""

    # Intentionally empty: subclasses declare the actual arguments.

626 

627 

class ProcessingDescrBase(NodeWithExplicitlySetFields):
    """processing base class"""

    # Intentionally empty: subclasses declare `name` and `kwargs`.

630 

631 

class BinarizeKwargs(ProcessingKwargs):
    """key word arguments for `BinarizeDescr`"""

    # Required; there is no default threshold.
    threshold: float
    """The fixed threshold"""

637 

638 

class BinarizeDescr(ProcessingDescrBase):
    """BinarizeDescr the tensor with a fixed `BinarizeKwargs.threshold`.
    Values above the threshold will be set to one, values below the threshold to zero.
    """

    implemented_name: ClassVar[Literal["binarize"]] = "binarize"
    # Static type checkers see a default for `name`; at runtime the field is
    # declared without one.
    if TYPE_CHECKING:
        name: Literal["binarize"] = "binarize"
    else:
        name: Literal["binarize"]

    kwargs: BinarizeKwargs

651 

652 

class ClipKwargs(ProcessingKwargs):
    """key word arguments for `ClipDescr`"""

    # Both bounds are required; no ordering between them is checked here.
    min: float
    """minimum value for clipping"""
    max: float
    """maximum value for clipping"""

660 

661 

class ClipDescr(ProcessingDescrBase):
    """Clip tensor values to a range.

    Set tensor values below `ClipKwargs.min` to `ClipKwargs.min`
    and above `ClipKwargs.max` to `ClipKwargs.max`.
    """

    implemented_name: ClassVar[Literal["clip"]] = "clip"
    # Static type checkers see a default for `name`; at runtime the field is
    # declared without one.
    if TYPE_CHECKING:
        name: Literal["clip"] = "clip"
    else:
        name: Literal["clip"]

    kwargs: ClipKwargs

676 

677 

class ScaleLinearKwargs(ProcessingKwargs):
    """key word arguments for `ScaleLinearDescr`"""

    axes: Annotated[Optional[AxesInCZYX], Field(examples=["xy"])] = None
    """The subset of axes to scale jointly.
    For example xy to scale the two image axes for 2d data jointly."""

    gain: Union[float, List[float]] = 1.0
    """multiplicative factor"""

    offset: Union[float, List[float]] = 0.0
    """additive term"""

    @model_validator(mode="after")
    def either_gain_or_offset(self) -> Self:
        # A gain of 1.0 (scalar, or every list entry) changes nothing ...
        gain_is_neutral = self.gain == 1.0 or (
            isinstance(self.gain, list) and all(g == 1.0 for g in self.gain)
        )
        # ... and neither does an offset of 0.0 (scalar, or every list entry).
        offset_is_neutral = self.offset == 0.0 or (
            isinstance(self.offset, list) and all(off == 0.0 for off in self.offset)
        )

        # Reject a step that would leave the tensor unchanged.
        if gain_is_neutral and offset_is_neutral:
            raise ValueError(
                "Redunt linear scaling not allowd. Set `gain` != 1.0 and/or `offset` !="
                " 0.0."
            )

        return self

708 

709 

class ScaleLinearDescr(ProcessingDescrBase):
    """Fixed linear scaling."""

    implemented_name: ClassVar[Literal["scale_linear"]] = "scale_linear"
    # Static type checkers see a default for `name`; at runtime the field is
    # declared without one.
    if TYPE_CHECKING:
        name: Literal["scale_linear"] = "scale_linear"
    else:
        name: Literal["scale_linear"]

    kwargs: ScaleLinearKwargs

720 

721 

class SigmoidDescr(ProcessingDescrBase):
    """The logistic sigmoid funciton, a.k.a. expit function."""

    implemented_name: ClassVar[Literal["sigmoid"]] = "sigmoid"
    # Static type checkers see a default for `name`; at runtime the field is
    # declared without one.
    if TYPE_CHECKING:
        name: Literal["sigmoid"] = "sigmoid"
    else:
        name: Literal["sigmoid"]

    # Sigmoid takes no arguments, so `kwargs` is a read-only property rather
    # than a field; each access returns a fresh, empty `ProcessingKwargs`.
    @property
    def kwargs(self) -> ProcessingKwargs:
        """empty kwargs"""
        return ProcessingKwargs()

735 

736 

class ZeroMeanUnitVarianceKwargs(ProcessingKwargs):
    """key word arguments for `ZeroMeanUnitVarianceDescr`"""

    mode: Literal["fixed", "per_dataset", "per_sample"] = "fixed"
    """Mode for computing mean and variance.
    | mode | description |
    | ----------- | ------------------------------------ |
    | fixed | Fixed values for mean and variance |
    | per_dataset | Compute for the entire dataset |
    | per_sample | Compute for each sample individually |
    """
    axes: Annotated[AxesInCZYX, Field(examples=["xy"])]
    """The subset of axes to normalize jointly.
    For example `xy` to normalize the two image axes for 2d data jointly."""

    mean: Annotated[
        Union[float, NotEmpty[List[float]], None], Field(examples=[(1.1, 2.2, 3.3)])
    ] = None
    """The mean value(s) to use for `mode: fixed`.
    For example `[1.1, 2.2, 3.3]` in the case of a 3 channel image with `axes: xy`."""
    # todo: check if means match input axes (for mode 'fixed')

    std: Annotated[
        Union[float, NotEmpty[List[float]], None], Field(examples=[(0.1, 0.2, 0.3)])
    ] = None
    """The standard deviation values to use for `mode: fixed`. Analogous to mean."""

    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
    """epsilon for numeric stability: `out = (tensor - mean) / (std + eps)`."""

    @model_validator(mode="after")
    def mean_and_std_match_mode(self) -> Self:
        # Fixed mode needs both statistics; the other modes accept neither.
        if self.mode == "fixed":
            if self.mean is None or self.std is None:
                raise ValueError("`mean` and `std` are required for `mode: fixed`.")
        elif self.mean is not None or self.std is not None:
            raise ValueError(f"`mean` and `std` not allowed for `mode: {self.mode}`")

        return self

775 

776 

class ZeroMeanUnitVarianceDescr(ProcessingDescrBase):
    """Subtract mean and divide by variance."""

    # NOTE(review): the `eps` docstring of `ZeroMeanUnitVarianceKwargs` gives
    # `out = (tensor - mean) / (std + eps)`, i.e. a division by the standard
    # deviation, not the variance -- confirm and align the docstring above.

    implemented_name: ClassVar[Literal["zero_mean_unit_variance"]] = (
        "zero_mean_unit_variance"
    )
    # Static type checkers see a default for `name`; at runtime the field is
    # declared without one.
    if TYPE_CHECKING:
        name: Literal["zero_mean_unit_variance"] = "zero_mean_unit_variance"
    else:
        name: Literal["zero_mean_unit_variance"]

    kwargs: ZeroMeanUnitVarianceKwargs

789 

790 

class ScaleRangeKwargs(ProcessingKwargs):
    """key word arguments for `ScaleRangeDescr`

    For `min_percentile`=0.0 (the default) and `max_percentile`=100 (the default)
    this processing step normalizes data to the [0, 1] intervall.
    For other percentiles the normalized values will partially be outside the [0, 1]
    intervall. Use `ScaleRange` followed by `ClipDescr` if you want to limit the
    normalized values to a range.
    """

    mode: Literal["per_dataset", "per_sample"]
    """Mode for computing percentiles.
    | mode | description |
    | ----------- | ------------------------------------ |
    | per_dataset | compute for the entire dataset |
    | per_sample | compute for each sample individually |
    """
    axes: Annotated[AxesInCZYX, Field(examples=["xy"])]
    """The subset of axes to normalize jointly.
    For example xy to normalize the two image axes for 2d data jointly."""

    min_percentile: Annotated[Union[int, float], Interval(ge=0, lt=100)] = 0.0
    """The lower percentile used to determine the value to align with zero."""

    max_percentile: Annotated[Union[int, float], Interval(gt=1, le=100)] = 100.0
    """The upper percentile used to determine the value to align with one.
    Has to be bigger than `min_percentile`.
    The range is 1 to 100 instead of 0 to 100 to avoid mistakenly
    accepting percentiles specified in the range 0.0 to 1.0."""

    @model_validator(mode="after")
    def min_smaller_max(self, info: ValidationInfo) -> Self:
        # `info` is accepted but not used here.
        lower, upper = self.min_percentile, self.max_percentile
        if lower >= upper:
            raise ValueError(f"min_percentile {lower} >= max_percentile {upper}")

        return self

    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
    """Epsilon for numeric stability.
    `out = (tensor - v_lower) / (v_upper - v_lower + eps)`;
    with `v_lower,v_upper` values at the respective percentiles."""

    reference_tensor: Optional[TensorName] = None
    """Tensor name to compute the percentiles from. Default: The tensor itself.
    For any tensor in `inputs` only input tensor references are allowed.
    For a tensor in `outputs` only input tensor refereences are allowed if `mode: per_dataset`"""

840 

841 

class ScaleRangeDescr(ProcessingDescrBase):
    """Scale with percentiles."""

    implemented_name: ClassVar[Literal["scale_range"]] = "scale_range"
    # Static type checkers see a default for `name`; at runtime the field is
    # declared without one.
    if TYPE_CHECKING:
        name: Literal["scale_range"] = "scale_range"
    else:
        name: Literal["scale_range"]

    kwargs: ScaleRangeKwargs

852 

853 

class ScaleMeanVarianceKwargs(ProcessingKwargs):
    """key word arguments for `ScaleMeanVarianceDescr`"""

    mode: Literal["per_dataset", "per_sample"]
    """Mode for computing mean and variance.
    | mode | description |
    | ----------- | ------------------------------------ |
    | per_dataset | Compute for the entire dataset |
    | per_sample | Compute for each sample individually |
    """

    # Required, unlike `ScaleRangeKwargs.reference_tensor`.
    reference_tensor: TensorName
    """Name of tensor to match."""

    axes: Annotated[Optional[AxesInCZYX], Field(examples=["xy"])] = None
    """The subset of axes to scale jointly.
    For example xy to normalize the two image axes for 2d data jointly.
    Default: scale all non-batch axes jointly."""

    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
    """Epsilon for numeric stability:
    "`out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean."""

876 

877 

class ScaleMeanVarianceDescr(ProcessingDescrBase):
    """Scale the tensor s.t. its mean and variance match a reference tensor."""

    implemented_name: ClassVar[Literal["scale_mean_variance"]] = "scale_mean_variance"
    # Static type checkers see a default for `name`; at runtime the field is
    # declared without one.
    if TYPE_CHECKING:
        name: Literal["scale_mean_variance"] = "scale_mean_variance"
    else:
        name: Literal["scale_mean_variance"]

    kwargs: ScaleMeanVarianceKwargs

888 

889 

# Union of all preprocessing step descriptions; the member is selected by
# the `name` field (discriminated union).
PreprocessingDescr = Annotated[
    Union[
        BinarizeDescr,
        ClipDescr,
        ScaleLinearDescr,
        SigmoidDescr,
        ZeroMeanUnitVarianceDescr,
        ScaleRangeDescr,
    ],
    Discriminator("name"),
]
# Union of all postprocessing step descriptions: the preprocessing members
# plus `ScaleMeanVarianceDescr`; also discriminated by `name`.
PostprocessingDescr = Annotated[
    Union[
        BinarizeDescr,
        ClipDescr,
        ScaleLinearDescr,
        SigmoidDescr,
        ZeroMeanUnitVarianceDescr,
        ScaleRangeDescr,
        ScaleMeanVarianceDescr,
    ],
    Discriminator("name"),
]

913 

914 

class InputTensorDescr(TensorDescrBase):
    # Description of a model input tensor: data type, shape and the
    # preprocessing steps to apply.
    data_type: Literal["float32", "uint8", "uint16"]
    """For now an input tensor is expected to be given as `float32`.
    The data flow in bioimage.io models is explained
    [in this diagram.](https://docs.google.com/drawings/d/1FTw8-Rn6a6nXdkZ_SkMumtcjvur9mtIhRqLwnKqZNHM/edit)."""

    shape: Annotated[
        Union[Sequence[int], ParameterizedInputShape],
        Field(
            examples=[(1, 512, 512, 1), dict(min=(1, 64, 64, 1), step=(0, 32, 32, 0))]
        ),
    ]
    """Specification of input tensor shape."""

    preprocessing: List[PreprocessingDescr] = Field(
        default_factory=cast(  # TODO: (py>3.8) use list[PreprocessingDesr]
            Callable[[], List[PreprocessingDescr]], list
        )
    )
    """Description of how this input should be preprocessed."""

    @model_validator(mode="after")
    def zero_batch_step_and_one_batch_size(self) -> Self:
        # Checks on the batch axis "b" (skipped if `axes` has none):
        # the shape entry must be 1 and, for a parameterized shape, the step
        # entry must be 0.
        bidx = self.axes.find("b")
        if bidx == -1:
            return self

        if isinstance(self.shape, ParameterizedInputShape):
            step = self.shape.step
            shape = self.shape.min
        else:
            step = None
            shape = self.shape

        # Guard the index accesses below: a `shape` shorter than the batch
        # axis position would otherwise raise an IndexError, which is not
        # reported as a validation error.
        if bidx >= len(shape) or (step is not None and bidx >= len(step)):
            raise ValueError(
                f"Input shape has {len(shape)} entries, but `axes` '{self.axes}'"
                + f" has the batch dimension b at index {bidx}."
            )

        if step is not None and step[bidx] != 0:
            raise ValueError(
                "Input shape step has to be zero in the batch dimension (the batch"
                + " dimension can always be increased, but `step` should specify how"
                + " to increase the minimal shape to find the largest single batch"
                + " shape)"
            )

        if shape[bidx] != 1:
            raise ValueError("Input shape has to be 1 in the batch dimension b.")

        return self

    @model_validator(mode="after")
    def validate_preprocessing_kwargs(self) -> Self:
        # Any string-valued `axes` in a preprocessing step's kwargs may only
        # name axes that this tensor has.
        for p in self.preprocessing:
            kwargs_axes = p.kwargs.get("axes")
            if isinstance(kwargs_axes, str) and any(
                a not in self.axes for a in kwargs_axes
            ):
                raise ValueError("`kwargs.axes` needs to be subset of `axes`")

        return self

970 

971 

# Description of a single model output tensor: data type, shape (explicit or
# derived from a reference tensor), halo and the postprocessing steps.
class OutputTensorDescr(TensorDescrBase):
    data_type: Literal[
        "float32",
        "float64",
        "uint8",
        "int8",
        "uint16",
        "int16",
        "uint32",
        "int32",
        "uint64",
        "int64",
        "bool",
    ]
    """Data type.
    The data flow in bioimage.io models is explained
    [in this diagram.](https://docs.google.com/drawings/d/1FTw8-Rn6a6nXdkZ_SkMumtcjvur9mtIhRqLwnKqZNHM/edit)."""

    shape: Union[Sequence[int], ImplicitOutputShape]
    """Output tensor shape."""

    halo: Optional[Sequence[int]] = None
    """The `halo` that should be cropped from the output tensor to avoid boundary effects.
    The `halo` is to be cropped from both sides, i.e. `shape_after_crop = shape - 2 * halo`.
    To document a `halo` that is already cropped by the model `shape.offset` has to be used instead."""

    postprocessing: List[PostprocessingDescr] = Field(
        default_factory=cast(Callable[[], List[PostprocessingDescr]], list)
    )
    """Description of how this output should be postprocessed."""

    @model_validator(mode="after")
    def matching_halo_length(self) -> Self:
        """Require a non-empty `halo` to have as many entries as `shape`.

        Raises:
            ValueError: if `halo` is given (and non-empty) with a different length
                than `shape`.
        """
        if self.halo and len(self.halo) != len(self.shape):
            raise ValueError(
                f"halo {self.halo} has to have same length as shape {self.shape}!"
            )

        return self

    @model_validator(mode="after")
    def validate_postprocessing_kwargs(self) -> Self:
        """Require `kwargs.axes` of each postprocessing step to only name tensor axes.

        A missing `axes` entry and an explicit `None` both mean there are no axes
        to check. (`InputTensorDescr.validate_preprocessing_kwargs` likewise
        ignores an `axes` value of `None`.)

        Raises:
            ValueError: if `axes` is neither `None` nor a string, or if it contains
                a character that is not part of this tensor's `axes`.
        """
        for p in self.postprocessing:
            kwargs_axes = p.kwargs.get("axes", "")
            if kwargs_axes is None:
                # explicit `axes: null` -> nothing to check for this step
                continue

            if not isinstance(kwargs_axes, str):
                raise ValueError(f"Expected {kwargs_axes} to be a string")

            if any(a not in self.axes for a in kwargs_axes):
                raise ValueError("`kwargs.axes` needs to be subset of axes")

        return self

1023 

1024 

# Run mode names known to this specification version.
KnownRunMode = Literal["deepimagej"]


# A custom run mode, given as a name plus free-form keyword arguments.
class RunMode(Node):
    # Any string is accepted by the type; the `warn(...)` annotation carries the
    # "Unknown run mode" message for names outside `KnownRunMode`
    # (presumably issued as a validation warning -- `warn` is defined elsewhere).
    name: Annotated[
        Union[KnownRunMode, str], warn(KnownRunMode, "Unknown run mode '{value}'.")
    ]
    """Run mode name"""

    # Defaults to a new empty dict per instance (via `default_factory`).
    kwargs: Dict[str, Any] = Field(
        default_factory=cast(Callable[[], Dict[str, Any]], dict)
    )
    """Run mode specific key word arguments"""

1038 

1039 

class LinkedModel(Node):
    """Reference to a bioimage.io model."""

    # Identifier of the referenced model, validated as `ModelId`.
    id: Annotated[ModelId, Field(examples=["affable-shark", "ambitious-sloth"])]
    """A valid model `id` from the bioimage.io collection."""

    # Optional; defaults to `None`.
    version_number: Optional[int] = None
    """version number (n-th published version, not the semantic version) of linked model"""

1048 

1049 

def package_weights(
    value: Node,  # Union[v0_4.WeightsDescr, v0_5.WeightsDescr]
    handler: SerializerFunctionWrapHandler,
    info: SerializationInfo,
):
    """Serialize a weights description, optionally reduced to one weight format.

    If the current packaging context defines a `weights_priority_order`, the
    first format in that order for which `value` has a non-`None` attribute is
    selected, and `value` is replaced by a newly constructed description holding
    only that entry before being passed on to `handler`.

    Raises:
        ValueError: if a priority order is given but none of its formats is
            present on `value`.
    """
    packaging_ctxt = packaging_context_var.get()
    priority_order = (
        None if packaging_ctxt is None else packaging_ctxt.weights_priority_order
    )

    if priority_order is not None:
        selected_format = None
        selected_entry = None
        for candidate_format in priority_order:
            candidate_entry = getattr(value, candidate_format, None)
            if candidate_entry is not None:
                selected_format = candidate_format
                selected_entry = candidate_entry
                break

        if selected_entry is None:
            raise ValueError(
                "None of the weight formats in `weights_priority_order`"
                f" ({priority_order}) is present in the given model."
            )

        assert isinstance(selected_entry, Node), type(selected_entry)
        # construct WeightsDescr with new single weight format entry
        entry_fields = {
            key: val for key, val in selected_entry if key != "parent"
        }
        single_entry = selected_entry.model_construct(**entry_fields)
        value = value.model_construct(None, **{selected_format: single_entry})

    return handler(
        value, info  # pyright: ignore[reportArgumentType]  # taken from pydantic docs
    )

1075 

1076 

class ModelDescr(GenericModelDescrBase):
    """Specification of the fields used in a bioimage.io-compliant RDF that describes AI models with pretrained weights.

    These fields are typically stored in a YAML file which we call a model resource description file (model RDF).
    """

    implemented_format_version: ClassVar[Literal["0.4.10"]] = "0.4.10"
    if TYPE_CHECKING:
        format_version: Literal["0.4.10"] = "0.4.10"
    else:
        format_version: Literal["0.4.10"]
        """Version of the bioimage.io model description specification used.
        When creating a new model always use the latest micro/patch version described here.
        The `format_version` is important for any consumer software to understand how to parse the fields.
        """

    implemented_type: ClassVar[Literal["model"]] = "model"
    if TYPE_CHECKING:
        type: Literal["model"] = "model"
    else:
        type: Literal["model"]
        """Specialized resource type 'model'"""

    id: Optional[ModelId] = None
    """bioimage.io-wide unique resource identifier
    assigned by bioimage.io; version **un**specific."""

    authors: NotEmpty[  # pyright: ignore[reportGeneralTypeIssues]  # make mandatory
        List[Author]
    ]
    """The authors are the creators of the model RDF and the primary points of contact."""

    documentation: Annotated[
        FileSource_,
        Field(
            examples=[
                "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/unet2d_nuclei_broad/README.md",
                "README.md",
            ],
        ),
    ]
    """URL or relative path to a markdown file with additional documentation.
    The recommended documentation file name is `README.md`. An `.md` suffix is mandatory.
    The documentation should include a '[#[#]]# Validation' (sub)section
    with details on how to quantitatively validate the model on unseen data."""

    inputs: NotEmpty[List[InputTensorDescr]]
    """Describes the input tensors expected by this model."""

    license: Annotated[
        Union[LicenseId, str],
        warn(LicenseId, "Unknown license id '{value}'."),
        Field(examples=["CC0-1.0", "MIT", "BSD-2-Clause"]),
    ]
    """A [SPDX license identifier](https://spdx.org/licenses/).
    We do not support custom license beyond the SPDX license list, if you need that please
    [open a GitHub issue](https://github.com/bioimage-io/spec-bioimage-io/issues/new/choose
    ) to discuss your intentions with the community."""

    name: Annotated[
        str,
        MinLen(1),
        warn(MinLen(5), "Name shorter than 5 characters.", INFO),
        warn(MaxLen(64), "Name longer than 64 characters.", INFO),
    ]
    """A human-readable name of this model.
    It should be no longer than 64 characters and only contain letter, number, underscore, minus or space characters."""

    outputs: NotEmpty[List[OutputTensorDescr]]
    """Describes the output tensors."""

    @field_validator("inputs", "outputs")
    @classmethod
    def unique_tensor_descr_names(
        cls, value: Sequence[Union[InputTensorDescr, OutputTensorDescr]]
    ) -> Sequence[Union[InputTensorDescr, OutputTensorDescr]]:
        """Require tensor names to be unique within `inputs` and within `outputs`.

        Raises:
            ValueError: if two tensor descriptions in `value` share a name.
        """
        unique_names = {str(v.name) for v in value}
        if len(unique_names) != len(value):
            raise ValueError("Duplicate tensor descriptor names")

        return value

    @model_validator(mode="after")
    def unique_io_names(self) -> Self:
        """Require tensor names to be unique across `inputs` and `outputs` combined.

        Raises:
            ValueError: if any name occurs more than once among all tensors.
        """
        unique_names = {str(ss.name) for s in (self.inputs, self.outputs) for ss in s}
        if len(unique_names) != (len(self.inputs) + len(self.outputs)):
            raise ValueError("Duplicate tensor descriptor names across inputs/outputs")

        return self

    @model_validator(mode="after")
    def minimum_shape2valid_output(self) -> Self:
        """Require every output to stay non-empty after cropping its halo.

        For each output the minimal shape is determined (see `_get_min_shape`) and
        `shape - 2 * halo` has to be at least 1 along every axis. For outputs with
        an implicit shape the number of non-null `scale` entries additionally has
        to match the dimensionality of the reference tensor.

        Raises:
            ValueError: if a reference tensor is unknown, dimensionalities do not
                match, or the minimal output shape is too small.
        """
        tensors_by_name: Dict[
            TensorName, Union[InputTensorDescr, OutputTensorDescr]
        ] = {t.name: t for t in self.inputs + self.outputs}

        for out in self.outputs:
            if isinstance(out.shape, ImplicitOutputShape):
                ref_name = out.shape.reference_tensor
                # Raise ValueError (instead of letting a KeyError escape) so that
                # an unknown reference is reported as a validation error.
                if ref_name not in tensors_by_name:
                    raise ValueError(
                        f"Reference tensor '{ref_name}' of output tensor"
                        + f" '{out.name}' not found in inputs/outputs."
                    )

                ndim_ref = len(tensors_by_name[ref_name].shape)
                # expanded dimensions (`scale`: null) do not exist in the reference
                ndim_out_ref = len(
                    [scale for scale in out.shape.scale if scale is not None]
                )
                if ndim_ref != ndim_out_ref:
                    expanded_dim_note = (
                        " Note that expanded dimensions (`scale`: null) are not"
                        + f" counted for {out.name}'s dimensionality here."
                        if None in out.shape.scale
                        else ""
                    )
                    raise ValueError(
                        f"Referenced tensor '{out.shape.reference_tensor}' with"
                        + f" {ndim_ref} dimensions does not match output tensor"
                        + f" '{out.name}' with"
                        + f" {ndim_out_ref} dimensions.{expanded_dim_note}"
                    )

            min_out_shape = self._get_min_shape(out, tensors_by_name)
            if out.halo:
                halo = out.halo
                halo_msg = f" for halo {out.halo}"
            else:
                halo = [0] * len(min_out_shape)
                halo_msg = ""

            if any(s - 2 * h < 1 for s, h in zip(min_out_shape, halo)):
                raise ValueError(
                    f"Minimal shape {min_out_shape} of output {out.name} is too"
                    + f" small{halo_msg}."
                )

        return self

    @classmethod
    def _get_min_shape(
        cls,
        t: Union[InputTensorDescr, OutputTensorDescr],
        tensors_by_name: Dict[TensorName, Union[InputTensorDescr, OutputTensorDescr]],
    ) -> Sequence[int]:
        """output with subtracted halo has to result in meaningful output even for the minimal input
        see https://github.com/bioimage-io/spec-bioimage-io/issues/392

        Returns the explicit shape, the `min` of a parameterized shape, or - for an
        implicit shape - `int(ref * scale + 2 * offset)` per axis, where `ref` is
        the (recursively determined) minimal shape of the reference tensor.

        Raises:
            ValueError: if the reference tensor of an implicit shape is unknown.
        """
        # NOTE(review): cyclic tensor references are not detected here.
        if isinstance(t.shape, collections.abc.Sequence):
            return t.shape
        elif isinstance(t.shape, ParameterizedInputShape):
            return t.shape.min
        elif isinstance(t.shape, ImplicitOutputShape):
            pass
        else:
            assert_never(t.shape)

        ref_tensor = tensors_by_name.get(t.shape.reference_tensor)
        if ref_tensor is None:
            # ValueError rather than KeyError: report as a validation error
            raise ValueError(
                f"Reference tensor '{t.shape.reference_tensor}' of tensor"
                + f" '{t.name}' not found in inputs/outputs."
            )

        ref_shape = cls._get_min_shape(ref_tensor, tensors_by_name)

        if None not in t.shape.scale:
            scale: Sequence[float] = t.shape.scale  # type: ignore
        else:
            # Axes with `scale` null are expanded dimensions that do not exist in
            # the reference tensor: insert them with size 1 and scale 0.
            expanded_dims = [idx for idx, sc in enumerate(t.shape.scale) if sc is None]
            new_ref_shape: List[int] = []
            for idx in range(len(t.shape.scale)):
                # index into the reference shape, skipping preceding expanded dims
                ref_idx = idx - sum(int(exp < idx) for exp in expanded_dims)
                new_ref_shape.append(1 if idx in expanded_dims else ref_shape[ref_idx])

            ref_shape = new_ref_shape
            assert len(ref_shape) == len(t.shape.scale)
            scale = [0.0 if sc is None else sc for sc in t.shape.scale]

        offset = t.shape.offset
        assert len(offset) == len(scale)
        return [int(rs * s + 2 * off) for rs, s, off in zip(ref_shape, scale, offset)]

    @model_validator(mode="after")
    def validate_tensor_references_in_inputs(self) -> Self:
        """Require `reference_tensor` kwargs of preprocessing steps to name another input.

        Raises:
            ValueError: if a referenced tensor is not an input or is the tensor
                being preprocessed itself.
        """
        # computed once; does not depend on the loop variables below
        input_names = {str(ipt.name) for ipt in self.inputs}
        for t in self.inputs:
            for proc in t.preprocessing:
                if "reference_tensor" not in proc.kwargs:
                    continue

                ref_tensor = proc.kwargs["reference_tensor"]
                if ref_tensor is not None and str(ref_tensor) not in input_names:
                    raise ValueError(f"'{ref_tensor}' not found in inputs")

                if ref_tensor == t.name:
                    raise ValueError(
                        f"invalid self reference for preprocessing of tensor {t.name}"
                    )

        return self

    @model_validator(mode="after")
    def validate_tensor_references_in_outputs(self) -> Self:
        """Require `reference_tensor` kwargs of postprocessing steps to name an input.

        Raises:
            ValueError: if a referenced tensor is not among the inputs.
        """
        # computed once; does not depend on the loop variables below
        input_names = {str(ipt.name) for ipt in self.inputs}
        for t in self.outputs:
            for proc in t.postprocessing:
                if "reference_tensor" not in proc.kwargs:
                    continue
                ref_tensor = proc.kwargs["reference_tensor"]
                if ref_tensor is not None and str(ref_tensor) not in input_names:
                    raise ValueError(f"{ref_tensor} not found in inputs")

        return self

    packaged_by: List[Author] = Field(
        default_factory=cast(Callable[[], List[Author]], list)
    )
    """The persons that have packaged and uploaded this model.
    Only required if those persons differ from the `authors`."""

    parent: Optional[LinkedModel] = None
    """The model from which this model is derived, e.g. by fine-tuning the weights."""

    @field_validator("parent", mode="before")
    @classmethod
    def ignore_url_parent(cls, parent: Any):
        """Replace a `parent` given as a dict by `None`; pass other values through."""
        if isinstance(parent, dict):
            return None

        else:
            return parent

    run_mode: Optional[RunMode] = None
    """Custom run mode for this model: for more complex prediction procedures like test time
    data augmentation that currently cannot be expressed in the specification.
    No standard run modes are defined yet."""

    sample_inputs: List[FileSource_] = Field(
        default_factory=cast(Callable[[], List[FileSource_]], list)
    )
    """URLs/relative paths to sample inputs to illustrate possible inputs for the model,
    for example stored as PNG or TIFF images.
    The sample files primarily serve to inform a human user about an example use case"""

    sample_outputs: List[FileSource_] = Field(
        default_factory=cast(Callable[[], List[FileSource_]], list)
    )
    """URLs/relative paths to sample outputs corresponding to the `sample_inputs`."""

    test_inputs: NotEmpty[
        List[Annotated[FileSource_, WithSuffix(".npy", case_sensitive=True)]]
    ]
    """Test input tensors compatible with the `inputs` description for a **single test case**.
    This means if your model has more than one input, you should provide one URL/relative path for each input.
    Each test input should be a file with an ndarray in
    [numpy.lib file format](https://numpy.org/doc/stable/reference/generated/numpy.lib.format.html#module-numpy.lib.format).
    The extension must be '.npy'."""

    test_outputs: NotEmpty[
        List[Annotated[FileSource_, WithSuffix(".npy", case_sensitive=True)]]
    ]
    """Analog to `test_inputs`."""

    timestamp: Datetime
    """Timestamp in [ISO 8601](https://en.wikipedia.org/wiki/ISO_8601) format
    with a few restrictions listed [here](https://docs.python.org/3/library/datetime.html#datetime.datetime.fromisoformat)."""

    training_data: Union[LinkedDataset, DatasetDescr, None] = None
    """The dataset used to train this model"""

    weights: Annotated[WeightsDescr, WrapSerializer(package_weights)]
    """The weights for this model.
    Weights can be given for different formats, but should otherwise be equivalent.
    The available weight formats determine which consumers can use this model."""

    @model_validator(mode="before")
    @classmethod
    def _convert_from_older_format(
        cls, data: BioimageioYamlContent, /
    ) -> BioimageioYamlContent:
        """Pass the raw `data` through `convert_from_older_format` before validation.

        The helper's return value is not used; the same `data` object is returned.
        """
        convert_from_older_format(data)
        return data

    def get_input_test_arrays(self) -> List[NDArray[Any]]:
        """Load every entry of `test_inputs` with `load_array` and return the arrays."""
        data = [load_array(ipt) for ipt in self.test_inputs]
        assert all(isinstance(d, np.ndarray) for d in data)
        return data

    def get_output_test_arrays(self) -> List[NDArray[Any]]:
        """Load every entry of `test_outputs` with `load_array` and return the arrays."""
        data = [load_array(out) for out in self.test_outputs]
        assert all(isinstance(d, np.ndarray) for d in data)
        return data