Coverage for src / bioimageio / spec / model / v0_4.py: 91%

593 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-08 13:52 +0000

1from __future__ import annotations 

2 

3import collections.abc 

4from typing import ( 

5 TYPE_CHECKING, 

6 Any, 

7 Callable, 

8 ClassVar, 

9 Dict, 

10 List, 

11 Literal, 

12 Optional, 

13 Sequence, 

14 Tuple, 

15 Type, 

16 Union, 

17 cast, 

18) 

19 

20import numpy as np 

21from annotated_types import Ge, Interval, MaxLen, MinLen, MultipleOf 

22from numpy.typing import NDArray 

23from pydantic import ( 

24 AllowInfNan, 

25 Discriminator, 

26 Field, 

27 RootModel, 

28 SerializationInfo, 

29 SerializerFunctionWrapHandler, 

30 StringConstraints, 

31 TypeAdapter, 

32 ValidationInfo, 

33 WrapSerializer, 

34 field_validator, 

35 model_validator, 

36) 

37from typing_extensions import Annotated, Self, assert_never, get_args 

38 

39from .._internal.common_nodes import ( 

40 KwargsNode, 

41 Node, 

42 NodeWithExplicitlySetFields, 

43) 

44from .._internal.constants import SHA256_HINT 

45from .._internal.field_validation import validate_unique_entries 

46from .._internal.field_warning import issue_warning, warn 

47from .._internal.io import BioimageioYamlContent, WithSuffix 

48from .._internal.io import FileDescr as FileDescr 

49from .._internal.io_basics import Sha256 as Sha256 

50from .._internal.io_packaging import include_in_package 

51from .._internal.io_utils import load_array 

52from .._internal.packaging_context import packaging_context_var 

53from .._internal.types import Datetime as Datetime 

54from .._internal.types import FileSource_, LowerCaseIdentifier 

55from .._internal.types import Identifier as Identifier 

56from .._internal.types import LicenseId as LicenseId 

57from .._internal.types import NotEmpty as NotEmpty 

58from .._internal.url import HttpUrl as HttpUrl 

59from .._internal.validated_string_with_inner_node import ValidatedStringWithInnerNode 

60from .._internal.validator_annotations import AfterValidator, RestrictCharacters 

61from .._internal.version_type import Version as Version 

62from .._internal.warning_levels import ALERT, INFO 

63from ..dataset.v0_2 import VALID_COVER_IMAGE_EXTENSIONS as VALID_COVER_IMAGE_EXTENSIONS 

64from ..dataset.v0_2 import DatasetDescr as DatasetDescr 

65from ..dataset.v0_2 import LinkedDataset as LinkedDataset 

66from ..generic.v0_2 import AttachmentsDescr as AttachmentsDescr 

67from ..generic.v0_2 import Author as Author 

68from ..generic.v0_2 import BadgeDescr as BadgeDescr 

69from ..generic.v0_2 import CiteEntry as CiteEntry 

70from ..generic.v0_2 import Doi as Doi 

71from ..generic.v0_2 import GenericModelDescrBase 

72from ..generic.v0_2 import LinkedResource as LinkedResource 

73from ..generic.v0_2 import Maintainer as Maintainer 

74from ..generic.v0_2 import OrcidId as OrcidId 

75from ..generic.v0_2 import RelativeFilePath as RelativeFilePath 

76from ..generic.v0_2 import ResourceId as ResourceId 

77from ..generic.v0_2 import Uploader as Uploader 

78from ._v0_4_converter import convert_from_older_format 

79 

80 

class ModelId(ResourceId):
    # `ResourceId` subclass with no additional behavior; it only provides a
    # distinct type for model identifiers.
    pass

83 

84 

# Axes string of a tensor: only the characters b, i, t, c, z, y, x are allowed,
# and the entries are checked for uniqueness by `validate_unique_entries`.
AxesStr = Annotated[
    str, RestrictCharacters("bitczyx"), AfterValidator(validate_unique_entries)
]
# Like `AxesStr`, but restricted to the channel and spatial axes c, z, y, x.
AxesInCZYX = Annotated[
    str, RestrictCharacters("czyx"), AfterValidator(validate_unique_entries)
]

# Names of the postprocessing operations; identical to `PreprocessingName`
# plus "scale_mean_variance".
PostprocessingName = Literal[
    "binarize",
    "clip",
    "scale_linear",
    "sigmoid",
    "zero_mean_unit_variance",
    "scale_range",
    "scale_mean_variance",
]
# Names of the preprocessing operations.
PreprocessingName = Literal[
    "binarize",
    "clip",
    "scale_linear",
    "sigmoid",
    "zero_mean_unit_variance",
    "scale_range",
]

109 

110 

class TensorName(LowerCaseIdentifier):
    # `LowerCaseIdentifier` subclass with no additional behavior; a distinct
    # type for tensor names (used e.g. by `TensorDescrBase.name`).
    pass

113 

114 

class CallableFromDepencencyNode(Node):
    # Parsed form of a `<module path>.<callable name>` reference.
    _submodule_adapter: ClassVar[TypeAdapter[Identifier]] = TypeAdapter(Identifier)

    module_name: str
    """The Python module that implements **callable_name**."""

    @field_validator("module_name", mode="after")
    def _check_submodules(cls, module_name: str) -> str:
        # Validate each dot-separated component of the module path as an
        # `Identifier`; the adapter raises on the first invalid component.
        validate_component = cls._submodule_adapter.validate_python
        components = module_name.split(".")
        for component in components:
            _ = validate_component(component)

        return module_name

    callable_name: Identifier
    """The callable Python identifier implemented in module **module_name**."""

130 

131 

class CallableFromDepencency(ValidatedStringWithInnerNode[CallableFromDepencencyNode]):
    # String of the form `<module path>.<callable name>`; must contain a dot.
    _inner_node_class = CallableFromDepencencyNode
    root_model: ClassVar[Type[RootModel[Any]]] = RootModel[
        Annotated[
            str,
            StringConstraints(strip_whitespace=True, pattern=r"^.+\..+$"),
        ]
    ]

    @classmethod
    def _get_data(cls, valid_string_data: str):
        # Everything before the last dot is the module path, the rest is the
        # callable name (same result as splitting on all dots and re-joining).
        module_path, _, name = valid_string_data.rpartition(".")
        return dict(module_name=module_path, callable_name=name)

    @property
    def module_name(self):
        """The Python module that implements **callable_name**."""
        return self._inner_node.module_name

    @property
    def callable_name(self):
        """The callable Python identifier implemented in module **module_name**."""
        return self._inner_node.callable_name

155 

156 

class CallableFromFileNode(Node):
    # Parsed form of a `<source file>:<callable name>` reference.
    source_file: Annotated[
        Union[RelativeFilePath, HttpUrl],
        # try `RelativeFilePath` first, then `HttpUrl`
        Field(union_mode="left_to_right"),
        # presumably marks the referenced file for inclusion when packaging
        # (see `io_packaging.include_in_package`) -- TODO confirm
        include_in_package,
    ]
    """The Python source file that implements **callable_name**."""
    callable_name: Identifier
    """The callable Python identifier implemented in **source_file**."""

166 

167 

class CallableFromFile(ValidatedStringWithInnerNode[CallableFromFileNode]):
    # String of the form `<source file>:<callable name>`; must contain a colon.
    _inner_node_class = CallableFromFileNode
    root_model: ClassVar[Type[RootModel[Any]]] = RootModel[
        Annotated[
            str,
            StringConstraints(strip_whitespace=True, pattern=r"^.+:.+$"),
        ]
    ]

    @classmethod
    def _get_data(cls, valid_string_data: str):
        # Split at the LAST colon so that colons inside the file part (e.g. URL
        # schemes or drive letters) stay part of `source_file`.
        file_part, _, name = valid_string_data.rpartition(":")
        return dict(source_file=file_part, callable_name=name)

    @property
    def source_file(self):
        """The Python source file that implements **callable_name**."""
        return self._inner_node.source_file

    @property
    def callable_name(self):
        """The callable Python identifier implemented in **source_file**."""
        return self._inner_node.callable_name

191 

192 

# A custom callable reference: parsed as `CallableFromFile` (`file:name`) if
# possible, otherwise as `CallableFromDepencency` (`module.name`).
CustomCallable = Annotated[
    Union[CallableFromFile, CallableFromDepencency], Field(union_mode="left_to_right")
]

196 

197 

class DependenciesNode(Node):
    # Parsed form of a `<manager>:<file>` dependency string (see `Dependencies`).
    manager: Annotated[NotEmpty[str], Field(examples=["conda", "maven", "pip"])]
    """Dependency manager"""

    file: FileSource_
    """Dependency file"""

204 

205 

class Dependencies(ValidatedStringWithInnerNode[DependenciesNode]):
    # String of the form `<dependency manager>:<dependency file>`.
    _inner_node_class = DependenciesNode
    root_model: ClassVar[Type[RootModel[Any]]] = RootModel[
        Annotated[
            str,
            StringConstraints(strip_whitespace=True, pattern=r"^.+:.+$"),
        ]
    ]

    @classmethod
    def _get_data(cls, valid_string_data: str):
        # Split at the FIRST colon: the manager name never contains one, while
        # the file part may (e.g. a URL).
        manager_name, _, dependency_file = valid_string_data.partition(":")
        return dict(manager=manager_name, file=dependency_file)

    @property
    def manager(self):
        """Dependency manager"""
        return self._inner_node.manager

    @property
    def file(self):
        """Dependency file"""
        return self._inner_node.file

229 

230 

# Supported weights format identifiers; each one matches a field name of
# `WeightsDescr` and the `type` class variable of a weights entry class.
WeightsFormat = Literal[
    "keras_hdf5",
    "onnx",
    "pytorch_state_dict",
    "tensorflow_js",
    "tensorflow_saved_model_bundle",
    "torchscript",
]

239 

240 

class WeightsEntryDescrBase(FileDescr):
    # Common base of the weights entry descriptions; each subclass sets the
    # `type` and `weights_format_name` class variables.
    type: ClassVar[WeightsFormat]
    weights_format_name: ClassVar[str]  # human readable

    source: FileSource_
    """The weights file."""

    # `warn(None, ..., ALERT)`: presumably emits an ALERT-level warning when a
    # value other than `None` is given -- see `field_warning.warn` to confirm.
    attachments: Annotated[
        Union[AttachmentsDescr, None],
        warn(None, "Weights entry depends on additional attachments.", ALERT),
    ] = None
    """Attachments that are specific to this weights entry."""

    authors: Union[List[Author], None] = None
    """Authors
    Either the person(s) that have trained this model resulting in the original weights file.
        (If this is the initial weights entry, i.e. it does not have a `parent`)
    Or the person(s) who have converted the weights to this weights format.
        (If this is a child weight, i.e. it has a `parent` field)
    """

    # Same `warn(None, ...)` mechanism as for `attachments` (default severity);
    # `{value}` is interpolated into the warning message.
    dependencies: Annotated[
        Optional[Dependencies],
        warn(
            None,
            "Custom dependencies ({value}) specified. Avoid this whenever possible "
            + "to allow execution in a wider range of software environments.",
        ),
        Field(
            examples=[
                "conda:environment.yaml",
                "maven:./pom.xml",
                "pip:./requirements.txt",
            ]
        ),
    ] = None
    """Dependency manager and dependency file, specified as `<dependency manager>:<relative file path>`."""

    parent: Annotated[
        Optional[WeightsFormat], Field(examples=["pytorch_state_dict"])
    ] = None
    """The source weights these weights were converted from.
    For example, if a model's weights were converted from the `pytorch_state_dict` format to `torchscript`,
    The `pytorch_state_dict` weights entry has no `parent` and is the parent of the `torchscript` weights.
    All weight entries except one (the initial set of weights resulting from training the model),
    need to have this field."""

    @model_validator(mode="after")
    def check_parent_is_not_self(self) -> Self:
        # An entry may not name its own weights format as `parent`.
        if self.type == self.parent:
            raise ValueError("Weights entry can't be it's own parent.")

        return self

294 

295 

class KerasHdf5WeightsDescr(WeightsEntryDescrBase):
    type: ClassVar[WeightsFormat] = "keras_hdf5"
    weights_format_name: ClassVar[str] = "Keras HDF5"
    tensorflow_version: Optional[Version] = None
    """TensorFlow version used to create these weights"""

    @field_validator("tensorflow_version", mode="after")
    @classmethod
    def _tfv(cls, value: Any):
        # A missing TensorFlow version is accepted, but triggers an ALERT warning.
        if value is not None:
            return value

        issue_warning(
            "missing. Please specify the TensorFlow version"
            + " these weights were created with.",
            value=value,
            severity=ALERT,
            field="tensorflow_version",
        )
        return value

314 

315 

class OnnxWeightsDescr(WeightsEntryDescrBase):
    type: ClassVar[WeightsFormat] = "onnx"
    weights_format_name: ClassVar[str] = "ONNX"
    opset_version: Optional[Annotated[int, Ge(7)]] = None
    """ONNX opset version"""

    @field_validator("opset_version", mode="after")
    @classmethod
    def _ov(cls, value: Any):
        # A missing opset version is accepted, but triggers an ALERT warning.
        if value is not None:
            return value

        issue_warning(
            "Missing ONNX opset version (aka ONNX opset number). "
            + "Please specify the ONNX opset version these weights were created"
            + " with.",
            value=value,
            severity=ALERT,
            field="opset_version",
        )
        return value

335 

336 

class PytorchStateDictWeightsDescr(WeightsEntryDescrBase):
    type: ClassVar[WeightsFormat] = "pytorch_state_dict"
    weights_format_name: ClassVar[str] = "Pytorch State Dict"
    architecture: CustomCallable = Field(
        examples=["my_function.py:MyNetworkClass", "my_module.submodule.get_my_model"]
    )
    """callable returning a torch.nn.Module instance.
    Local implementation: `<relative path to file>:<identifier of implementation within the file>`.
    Implementation in a dependency: `<dependency-package>.<[dependency-module]>.<identifier>`."""

    architecture_sha256: Annotated[
        Optional[Sha256],
        Field(
            description=(
                "The SHA256 of the architecture source file, if the architecture is not"
                " defined in a module listed in `dependencies`\n"
            )
            + SHA256_HINT,
        ),
    ] = None
    """The SHA256 of the architecture source file,
    if the architecture is not defined in a module listed in `dependencies`"""

    @model_validator(mode="after")
    def check_architecture_sha256(self) -> Self:
        # `architecture_sha256` is required exactly when the architecture is
        # loaded from a source file, and forbidden otherwise.
        from_source_file = isinstance(self.architecture, CallableFromFile)
        hash_given = self.architecture_sha256 is not None

        if from_source_file and not hash_given:
            raise ValueError(
                "Missing required `architecture_sha256` for `architecture` with"
                + " source file."
            )
        if hash_given and not from_source_file:
            raise ValueError(
                "Got `architecture_sha256` for architecture that does not have a source"
                + " file."
            )

        return self

    kwargs: Dict[str, Any] = Field(
        default_factory=cast(Callable[[], Dict[str, Any]], dict)
    )
    """key word arguments for the `architecture` callable"""

    pytorch_version: Optional[Version] = None
    """Version of the PyTorch library used.
    If `dependencies` is specified it should include pytorch and the version has to match.
    (`dependencies` overrules `pytorch_version`)"""

    @field_validator("pytorch_version", mode="after")
    @classmethod
    def _ptv(cls, value: Any):
        # A missing PyTorch version is accepted, but triggers an ALERT warning.
        if value is not None:
            return value

        issue_warning(
            "missing. Please specify the PyTorch version these"
            + " PyTorch state dict weights were created with.",
            value=value,
            severity=ALERT,
            field="pytorch_version",
        )
        return value

398 

399 

class TorchscriptWeightsDescr(WeightsEntryDescrBase):
    type: ClassVar[WeightsFormat] = "torchscript"
    weights_format_name: ClassVar[str] = "TorchScript"
    pytorch_version: Optional[Version] = None
    """Version of the PyTorch library used."""

    @field_validator("pytorch_version", mode="after")
    @classmethod
    def _ptv(cls, value: Any):
        # A missing PyTorch version is accepted, but triggers an ALERT warning.
        if value is not None:
            return value

        issue_warning(
            "missing. Please specify the PyTorch version these"
            + " Torchscript weights were created with.",
            value=value,
            severity=ALERT,
            field="pytorch_version",
        )
        return value

418 

419 

class TensorflowJsWeightsDescr(WeightsEntryDescrBase):
    type: ClassVar[WeightsFormat] = "tensorflow_js"
    weights_format_name: ClassVar[str] = "Tensorflow.js"
    tensorflow_version: Optional[Version] = None
    """Version of the TensorFlow library used."""

    @field_validator("tensorflow_version", mode="after")
    @classmethod
    def _tfv(cls, value: Any):
        # A missing TensorFlow version is accepted, but triggers an ALERT warning.
        if value is not None:
            return value

        issue_warning(
            "missing. Please specify the TensorFlow version"
            + " these TensorflowJs weights were created with.",
            value=value,
            severity=ALERT,
            field="tensorflow_version",
        )
        return value

    # re-declared to document the multi-file (zip) nature of this format
    source: FileSource_
    """The multi-file weights.
    All required files/folders should be a zip archive."""

442 

443 

class TensorflowSavedModelBundleWeightsDescr(WeightsEntryDescrBase):
    type: ClassVar[WeightsFormat] = "tensorflow_saved_model_bundle"
    weights_format_name: ClassVar[str] = "Tensorflow Saved Model"
    tensorflow_version: Optional[Version] = None
    """Version of the TensorFlow library used."""

    @field_validator("tensorflow_version", mode="after")
    @classmethod
    def _tfv(cls, value: Any):
        # A missing TensorFlow version is accepted, but triggers an ALERT warning.
        if value is not None:
            return value

        issue_warning(
            "missing. Please specify the TensorFlow version"
            + " these Tensorflow saved model bundle weights were created with.",
            value=value,
            severity=ALERT,
            field="tensorflow_version",
        )
        return value

462 

463 

class WeightsDescr(Node):
    # One optional entry per supported weights format; at least one is required.
    keras_hdf5: Optional[KerasHdf5WeightsDescr] = None
    onnx: Optional[OnnxWeightsDescr] = None
    pytorch_state_dict: Optional[PytorchStateDictWeightsDescr] = None
    tensorflow_js: Optional[TensorflowJsWeightsDescr] = None
    tensorflow_saved_model_bundle: Optional[TensorflowSavedModelBundleWeightsDescr] = (
        None
    )
    torchscript: Optional[TorchscriptWeightsDescr] = None

    @model_validator(mode="after")
    def check_one_entry(self) -> Self:
        # Reject descriptions in which every weights format is left unset.
        given = [
            entry
            for entry in (
                self.keras_hdf5,
                self.onnx,
                self.pytorch_state_dict,
                self.tensorflow_js,
                self.tensorflow_saved_model_bundle,
                self.torchscript,
            )
            if entry is not None
        ]
        if not given:
            raise ValueError("Missing weights entry")

        return self

    def __getitem__(
        self,
        key: WeightsFormat,
    ):
        """Return the weights entry for the weights format `key`.

        Raises:
            KeyError: if `key` is not a known weights format, or if no entry is
                given for that format.
        """
        candidates = (
            ("keras_hdf5", self.keras_hdf5),
            ("onnx", self.onnx),
            ("pytorch_state_dict", self.pytorch_state_dict),
            ("tensorflow_js", self.tensorflow_js),
            ("tensorflow_saved_model_bundle", self.tensorflow_saved_model_bundle),
            ("torchscript", self.torchscript),
        )
        for weights_format, entry in candidates:
            if key == weights_format:
                if entry is None:
                    raise KeyError(key)
                return entry

        raise KeyError(key)

    @property
    def available_formats(self):
        """Mapping from weights format name to entry, for all entries that are set.

        Keys appear in field declaration order.
        """
        all_entries = {
            "keras_hdf5": self.keras_hdf5,
            "onnx": self.onnx,
            "pytorch_state_dict": self.pytorch_state_dict,
            "tensorflow_js": self.tensorflow_js,
            "tensorflow_saved_model_bundle": self.tensorflow_saved_model_bundle,
            "torchscript": self.torchscript,
        }
        return {
            weights_format: entry
            for weights_format, entry in all_entries.items()
            if entry is not None
        }

    @property
    def missing_formats(self):
        """Set of weights format names for which no entry is given."""
        available = self.available_formats
        return {wf for wf in get_args(WeightsFormat) if wf not in available}

545 

546 

class ParameterizedInputShape(Node):
    """A sequence of valid shapes given by `shape_k = min + k * step for k in {0, 1, ...}`."""

    min: NotEmpty[List[int]]
    """The minimum input shape"""

    step: NotEmpty[List[int]]
    """The minimum shape change"""

    def __len__(self) -> int:
        # number of dimensions, taken from `min` (`step` is validated to match)
        return len(self.min)

    @model_validator(mode="after")
    def matching_lengths(self) -> Self:
        # `min` and `step` describe the same dimensions, so lengths must agree.
        if len(self.min) == len(self.step):
            return self

        raise ValueError("`min` and `step` required to have the same length")

565 

566 

class ImplicitOutputShape(Node):
    """Output tensor shape depending on an input tensor shape.
    `shape(output_tensor) = shape(input_tensor) * scale + 2 * offset`"""

    reference_tensor: TensorName
    """Name of the reference tensor."""

    scale: NotEmpty[List[Optional[float]]]
    """output_pix/input_pix for each dimension.
    'null' values indicate new dimensions, whose length is defined by 2*`offset`"""

    offset: NotEmpty[List[Union[int, Annotated[float, MultipleOf(0.5)]]]]
    """Position of origin wrt to input."""

    def __len__(self) -> int:
        # number of dimensions, taken from `scale` (`offset` is validated to match)
        return len(self.scale)

    @model_validator(mode="after")
    def matching_lengths(self) -> Self:
        if len(self.scale) != len(self.offset):
            raise ValueError(
                f"scale {self.scale} has to have same length as offset {self.offset}!"
            )

        # A `None` scale marks a new (expanded) dimension whose length comes
        # solely from `offset`, so its offset must not be zero.
        expanded_without_offset = any(
            sc is None and not off for sc, off in zip(self.scale, self.offset)
        )
        if expanded_without_offset:
            raise ValueError("`offset` must not be zero if `scale` is none/zero")

        return self

596 

597 

class TensorDescrBase(Node):
    # Fields shared by `InputTensorDescr` and `OutputTensorDescr`.
    name: TensorName
    """Tensor name. No duplicates are allowed."""

    # free-text description of the tensor; empty by default
    description: str = ""

    axes: AxesStr
    """Axes identifying characters. Same length and order as the axes in `shape`.
    | axis | description |
    | --- | --- |
    |  b  |  batch (groups multiple samples) |
    |  i  |  instance/index/element |
    |  t  |  time |
    |  c  |  channel |
    |  z  |  spatial dimension z |
    |  y  |  spatial dimension y |
    |  x  |  spatial dimension x |
    """

    # both bounds may be +/-inf or NaN (`AllowInfNan(True)`)
    data_range: Optional[
        Tuple[Annotated[float, AllowInfNan(True)], Annotated[float, AllowInfNan(True)]]
    ] = None
    """Tuple `(minimum, maximum)` specifying the allowed range of the data in this tensor.
    If not specified, the full data range that can be expressed in `data_type` is allowed."""

622 

623 

class BinarizeKwargs(KwargsNode):
    """key word arguments for `BinarizeDescr`"""

    threshold: float
    """The fixed threshold"""

629 

630 

class BinarizeDescr(NodeWithExplicitlySetFields):
    """Binarize the tensor with a fixed `BinarizeKwargs.threshold`.
    Values above the threshold will be set to one, values below the threshold to zero.
    """

    implemented_name: ClassVar[Literal["binarize"]] = "binarize"
    # `name` has a default only for static type checkers; at runtime it is
    # declared without a default (it is the discriminator of the processing unions).
    if TYPE_CHECKING:
        name: Literal["binarize"] = "binarize"
    else:
        name: Literal["binarize"]

    kwargs: BinarizeKwargs

643 

644 

class ClipKwargs(KwargsNode):
    """key word arguments for `ClipDescr`"""

    min: float
    """minimum value for clipping"""
    max: float
    """maximum value for clipping"""

652 

653 

class ClipDescr(NodeWithExplicitlySetFields):
    """Clip tensor values to a range.

    Set tensor values below `ClipKwargs.min` to `ClipKwargs.min`
    and above `ClipKwargs.max` to `ClipKwargs.max`.
    """

    implemented_name: ClassVar[Literal["clip"]] = "clip"
    # `name` has a default only for static type checkers; at runtime it is
    # declared without a default (it is the discriminator of the processing unions).
    if TYPE_CHECKING:
        name: Literal["clip"] = "clip"
    else:
        name: Literal["clip"]

    kwargs: ClipKwargs

668 

669 

class ScaleLinearKwargs(KwargsNode):
    """key word arguments for `ScaleLinearDescr`"""

    axes: Annotated[Optional[AxesInCZYX], Field(examples=["xy"])] = None
    """The subset of axes to scale jointly.
    For example xy to scale the two image axes for 2d data jointly."""

    gain: Union[float, List[float]] = 1.0
    """multiplicative factor"""

    offset: Union[float, List[float]] = 0.0
    """additive term"""

    @model_validator(mode="after")
    def either_gain_or_offset(self) -> Self:
        # Reject a no-op scaling: every gain is 1.0 AND every offset is 0.0.
        gain_is_identity = self.gain == 1.0 or (
            isinstance(self.gain, list) and all(g == 1.0 for g in self.gain)
        )
        offset_is_zero = self.offset == 0.0 or (
            isinstance(self.offset, list) and all(o == 0.0 for o in self.offset)
        )
        if gain_is_identity and offset_is_zero:
            raise ValueError(
                "Redunt linear scaling not allowd. Set `gain` != 1.0 and/or `offset` !="
                + " 0.0."
            )

        return self

700 

701 

class ScaleLinearDescr(NodeWithExplicitlySetFields):
    """Fixed linear scaling."""

    implemented_name: ClassVar[Literal["scale_linear"]] = "scale_linear"
    # `name` has a default only for static type checkers; at runtime it is
    # declared without a default (it is the discriminator of the processing unions).
    if TYPE_CHECKING:
        name: Literal["scale_linear"] = "scale_linear"
    else:
        name: Literal["scale_linear"]

    kwargs: ScaleLinearKwargs

712 

713 

class SigmoidDescr(NodeWithExplicitlySetFields):
    """The logistic sigmoid function, a.k.a. expit function."""

    implemented_name: ClassVar[Literal["sigmoid"]] = "sigmoid"
    # `name` has a default only for static type checkers; at runtime it is
    # declared without a default (it is the discriminator of the processing unions).
    if TYPE_CHECKING:
        name: Literal["sigmoid"] = "sigmoid"
    else:
        name: Literal["sigmoid"]

    # Sigmoid takes no parameters; `kwargs` is a read-only property (not a
    # field) so this node has the same `kwargs` attribute as the other steps.
    @property
    def kwargs(self) -> KwargsNode:
        """empty kwargs"""
        return KwargsNode()

727 

728 

class ZeroMeanUnitVarianceKwargs(KwargsNode):
    """key word arguments for `ZeroMeanUnitVarianceDescr`"""

    mode: Literal["fixed", "per_dataset", "per_sample"] = "fixed"
    """Mode for computing mean and variance.
    |     mode    |             description              |
    | ----------- | ------------------------------------ |
    |   fixed     | Fixed values for mean and variance   |
    | per_dataset | Compute for the entire dataset       |
    | per_sample  | Compute for each sample individually |
    """
    axes: Annotated[AxesInCZYX, Field(examples=["xy"])]
    """The subset of axes to normalize jointly.
    For example `xy` to normalize the two image axes for 2d data jointly."""

    mean: Annotated[
        Union[float, NotEmpty[List[float]], None], Field(examples=[(1.1, 2.2, 3.3)])
    ] = None
    """The mean value(s) to use for `mode: fixed`.
    For example `[1.1, 2.2, 3.3]` in the case of a 3 channel image with `axes: xy`."""
    # todo: check if means match input axes (for mode 'fixed')

    std: Annotated[
        Union[float, NotEmpty[List[float]], None], Field(examples=[(0.1, 0.2, 0.3)])
    ] = None
    """The standard deviation values to use for `mode: fixed`. Analogous to mean."""

    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
    """epsilon for numeric stability: `out = (tensor - mean) / (std + eps)`."""

    @model_validator(mode="after")
    def mean_and_std_match_mode(self) -> Self:
        # `mean` and `std` must both be given for mode "fixed" and must both be
        # omitted for the modes that compute them.
        any_missing = self.mean is None or self.std is None
        any_given = self.mean is not None or self.std is not None

        if self.mode == "fixed":
            if any_missing:
                raise ValueError("`mean` and `std` are required for `mode: fixed`.")
        elif any_given:
            raise ValueError(f"`mean` and `std` not allowed for `mode: {self.mode}`")

        return self

767 

768 

class ZeroMeanUnitVarianceDescr(NodeWithExplicitlySetFields):
    """Subtract mean and divide by standard deviation."""

    implemented_name: ClassVar[Literal["zero_mean_unit_variance"]] = (
        "zero_mean_unit_variance"
    )
    # `name` has a default only for static type checkers; at runtime it is
    # declared without a default (it is the discriminator of the processing unions).
    if TYPE_CHECKING:
        name: Literal["zero_mean_unit_variance"] = "zero_mean_unit_variance"
    else:
        name: Literal["zero_mean_unit_variance"]

    kwargs: ZeroMeanUnitVarianceKwargs

781 

782 

class ScaleRangeKwargs(KwargsNode):
    """key word arguments for `ScaleRangeDescr`

    For `min_percentile`=0.0 (the default) and `max_percentile`=100 (the default)
    this processing step normalizes data to the [0, 1] interval.
    For other percentiles the normalized values will partially be outside the [0, 1]
    interval. Use `ScaleRangeDescr` followed by `ClipDescr` if you want to limit the
    normalized values to a range.
    """

    mode: Literal["per_dataset", "per_sample"]
    """Mode for computing percentiles.
    |     mode    |             description              |
    | ----------- | ------------------------------------ |
    | per_dataset | compute for the entire dataset       |
    | per_sample  | compute for each sample individually |
    """
    axes: Annotated[AxesInCZYX, Field(examples=["xy"])]
    """The subset of axes to normalize jointly.
    For example xy to normalize the two image axes for 2d data jointly."""

    min_percentile: Annotated[Union[int, float], Interval(ge=0, lt=100)] = 0.0
    """The lower percentile used to determine the value to align with zero."""

    max_percentile: Annotated[Union[int, float], Interval(gt=1, le=100)] = 100.0
    """The upper percentile used to determine the value to align with one.
    Has to be bigger than `min_percentile`.
    The range is 1 to 100 instead of 0 to 100 to avoid mistakenly
    accepting percentiles specified in the range 0.0 to 1.0."""

    @model_validator(mode="after")
    def min_smaller_max(self, info: ValidationInfo) -> Self:
        # `info` is not used; the check only compares the two percentiles.
        lower, upper = self.min_percentile, self.max_percentile
        if lower >= upper:
            raise ValueError(f"min_percentile {lower} >= max_percentile" + f" {upper}")

        return self

    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
    """Epsilon for numeric stability.
    `out = (tensor - v_lower) / (v_upper - v_lower + eps)`;
    with `v_lower,v_upper` values at the respective percentiles."""

    reference_tensor: Optional[TensorName] = None
    """Tensor name to compute the percentiles from. Default: The tensor itself.
    For any tensor in `inputs` only input tensor references are allowed.
    For a tensor in `outputs` only input tensor references are allowed if `mode: per_dataset`"""

832 

833 

class ScaleRangeDescr(NodeWithExplicitlySetFields):
    """Scale with percentiles."""

    implemented_name: ClassVar[Literal["scale_range"]] = "scale_range"
    # `name` has a default only for static type checkers; at runtime it is
    # declared without a default (it is the discriminator of the processing unions).
    if TYPE_CHECKING:
        name: Literal["scale_range"] = "scale_range"
    else:
        name: Literal["scale_range"]

    kwargs: ScaleRangeKwargs

844 

845 

class ScaleMeanVarianceKwargs(KwargsNode):
    """key word arguments for `ScaleMeanVarianceDescr`"""

    mode: Literal["per_dataset", "per_sample"]
    """Mode for computing mean and variance.
    |     mode    |             description              |
    | ----------- | ------------------------------------ |
    | per_dataset | Compute for the entire dataset       |
    | per_sample  | Compute for each sample individually |
    """

    reference_tensor: TensorName
    """Name of tensor to match."""

    axes: Annotated[Optional[AxesInCZYX], Field(examples=["xy"])] = None
    """The subset of axes to scale jointly.
    For example xy to normalize the two image axes for 2d data jointly.
    Default: scale all non-batch axes jointly."""

    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
    """Epsilon for numeric stability:
    `out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean`."""

868 

869 

class ScaleMeanVarianceDescr(NodeWithExplicitlySetFields):
    """Scale the tensor s.t. its mean and variance match a reference tensor."""

    implemented_name: ClassVar[Literal["scale_mean_variance"]] = "scale_mean_variance"
    # `name` has a default only for static type checkers; at runtime it is
    # declared without a default (it is the discriminator of `PostprocessingDescr`).
    if TYPE_CHECKING:
        name: Literal["scale_mean_variance"] = "scale_mean_variance"
    else:
        name: Literal["scale_mean_variance"]

    kwargs: ScaleMeanVarianceKwargs

880 

881 

# Discriminated union of the preprocessing step descriptions; the member class
# is selected by the value of the `name` field.
PreprocessingDescr = Annotated[
    Union[
        BinarizeDescr,
        ClipDescr,
        ScaleLinearDescr,
        SigmoidDescr,
        ZeroMeanUnitVarianceDescr,
        ScaleRangeDescr,
    ],
    Discriminator("name"),
]
# Same members as `PreprocessingDescr` plus `ScaleMeanVarianceDescr`, which is
# only available for postprocessing.
PostprocessingDescr = Annotated[
    Union[
        BinarizeDescr,
        ClipDescr,
        ScaleLinearDescr,
        SigmoidDescr,
        ZeroMeanUnitVarianceDescr,
        ScaleRangeDescr,
        ScaleMeanVarianceDescr,
    ],
    Discriminator("name"),
]

905 

906 

class InputTensorDescr(TensorDescrBase):
    # Description of a model input tensor: data type, (parameterized) shape and
    # the preprocessing steps applied to it.
    data_type: Literal["float32", "uint8", "uint16"]
    """For now an input tensor is expected to be given as `float32`.
    The data flow in bioimage.io models is explained
    [in this diagram.](https://docs.google.com/drawings/d/1FTw8-Rn6a6nXdkZ_SkMumtcjvur9mtIhRqLwnKqZNHM/edit)."""

    shape: Annotated[
        Union[Sequence[int], ParameterizedInputShape],
        Field(
            examples=[(1, 512, 512, 1), dict(min=(1, 64, 64, 1), step=(0, 32, 32, 0))]
        ),
    ]
    """Specification of input tensor shape."""

    preprocessing: List[PreprocessingDescr] = Field(
        default_factory=cast(  # TODO: (py>3.8) use list[PreprocessingDescr]
            Callable[[], List[PreprocessingDescr]], list
        )
    )
    """Description of how this input should be preprocessed."""

    @model_validator(mode="after")
    def zero_batch_step_and_one_batch_size(self) -> Self:
        """Require batch size 1 (and, for parameterized shapes, batch step 0).

        Tensors without a batch axis `b` are accepted unchanged.

        Raises:
            ValueError: if `shape` has no entry at the index of `b` in `axes`,
                if a parameterized `step` is non-zero along `b`, or if the
                (minimal) shape is not 1 along `b`.
        """
        bidx = self.axes.find("b")
        if bidx == -1:
            return self

        if isinstance(self.shape, ParameterizedInputShape):
            # `min` and `step` are validated to have equal lengths
            step = self.shape.step
            shape = self.shape.min
        else:
            step = None
            shape = self.shape

        # Nothing above guarantees that `shape` is as long as `axes`. Indexing
        # past its end would raise an IndexError, which pydantic does not turn
        # into a validation error, so report it as a ValueError instead.
        if bidx >= len(shape):
            raise ValueError(
                f"`shape` has {len(shape)} entries, but `axes` '{self.axes}' places"
                + f" the batch axis 'b' at index {bidx}."
            )

        if step is not None and step[bidx] != 0:
            raise ValueError(
                "Input shape step has to be zero in the batch dimension (the batch"
                + " dimension can always be increased, but `step` should specify how"
                + " to increase the minimal shape to find the largest single batch"
                + " shape)"
            )

        if shape[bidx] != 1:
            raise ValueError("Input shape has to be 1 in the batch dimension b.")

        return self

    @model_validator(mode="after")
    def validate_preprocessing_kwargs(self) -> Self:
        # Any `axes` given in a preprocessing step's kwargs must only reference
        # axes of this tensor.
        for p in self.preprocessing:
            kwargs_axes = p.kwargs.get("axes")
            if isinstance(kwargs_axes, str) and any(
                a not in self.axes for a in kwargs_axes
            ):
                raise ValueError("`kwargs.axes` needs to be subset of `axes`")

        return self

962 

963 

# Description of one model output: data type, explicit or implicit
# (reference-based) shape, optional halo, and the postprocessing applied to it.
class OutputTensorDescr(TensorDescrBase):
    data_type: Literal[
        "float32",
        "float64",
        "uint8",
        "int8",
        "uint16",
        "int16",
        "uint32",
        "int32",
        "uint64",
        "int64",
        "bool",
    ]
    """Data type.
    The data flow in bioimage.io models is explained
    [in this diagram.](https://docs.google.com/drawings/d/1FTw8-Rn6a6nXdkZ_SkMumtcjvur9mtIhRqLwnKqZNHM/edit)."""

    shape: Union[Sequence[int], ImplicitOutputShape]
    """Output tensor shape."""

    halo: Optional[Sequence[int]] = None
    """The `halo` that should be cropped from the output tensor to avoid boundary effects.
    The `halo` is to be cropped from both sides, i.e. `shape_after_crop = shape - 2 * halo`.
    To document a `halo` that is already cropped by the model `shape.offset` has to be used instead."""

    postprocessing: List[PostprocessingDescr] = Field(
        # `cast` only informs type checkers; the runtime factory is `list`.
        default_factory=cast(Callable[[], List[PostprocessingDescr]], list)
    )
    """Description of how this output should be postprocessed."""

    @model_validator(mode="after")
    def matching_halo_length(self) -> Self:
        """Require one `halo` entry per dimension of `shape` when a non-empty halo is given.

        Raises:
            ValueError: if the lengths of `halo` and `shape` differ.
        """
        if self.halo and len(self.halo) != len(self.shape):
            raise ValueError(
                f"halo {self.halo} has to have same length as shape {self.shape}!"
            )

        return self

    @model_validator(mode="after")
    def validate_postprocessing_kwargs(self) -> Self:
        """Ensure a postprocessing `axes` kwarg only names axes of this tensor.

        Steps without an `axes` kwarg, or with `axes` set to `None`, impose no
        constraint.

        Raises:
            ValueError: if `axes` is neither a string nor `None`, or if it
                contains an axis that is not in `self.axes`.
        """
        for p in self.postprocessing:
            kwargs_axes = p.kwargs.get("axes", "")
            if kwargs_axes is None:
                # An optional `axes` kwarg left at `None` means "no restriction".
                # Previously this raised "Expected None to be a string", unlike the
                # preprocessing check, which ignores non-string axes.
                continue

            if not isinstance(kwargs_axes, str):
                raise ValueError(f"Expected {kwargs_axes} to be a string")

            if any(a not in self.axes for a in kwargs_axes):
                raise ValueError("`kwargs.axes` needs to be subset of axes")

        return self

1015 

1016 

# Run mode names known to this spec version. `RunMode.name` also accepts any
# other string, but flags it via `warn(...)` with an "Unknown run mode" message.
KnownRunMode = Literal["deepimagej"]

1018 

1019 

# Custom run mode of a model: a run mode name plus free-form keyword arguments.
class RunMode(Node):
    # Any string is accepted; names outside `KnownRunMode` are flagged via
    # `warn(...)` with an "Unknown run mode" message.
    name: Annotated[
        Union[KnownRunMode, str], warn(KnownRunMode, "Unknown run mode '{value}'.")
    ]
    """Run mode name"""

    # Defaults to an empty dict; `cast` only narrows the factory type for
    # type checkers.
    kwargs: Dict[str, Any] = Field(
        default_factory=cast(Callable[[], Dict[str, Any]], dict)
    )
    """Run mode specific key word arguments"""

1030 

1031 

class LinkedModel(Node):
    """Reference to a bioimage.io model."""

    # `examples` is schema metadata only; validation is done by `ModelId`.
    id: Annotated[ModelId, Field(examples=["affable-shark", "ambitious-sloth"])]
    """A valid model `id` from the bioimage.io collection."""

    # Optional; defaults to `None`.
    version_number: Optional[int] = None
    """version number (n-th published version, not the semantic version) of linked model"""

1040 

1041 

def package_weights(
    value: Node,  # Union[v0_4.WeightsDescr, v0_5.WeightsDescr]
    handler: SerializerFunctionWrapHandler,
    info: SerializationInfo,
):
    """Wrap-serializer for model weights that can restrict them to a single format.

    When a packaging context with a `weights_priority_order` is active, the
    first format in that order that is present on `value` is selected. A new
    weights object holding only that entry (rebuilt without its `parent`
    field) is then serialized. Without such a context, `value` is serialized
    as is.

    Raises:
        ValueError: if none of the formats in `weights_priority_order` is present.
    """
    packaging = packaging_context_var.get()
    priority_order = None if packaging is None else packaging.weights_priority_order

    if priority_order is not None:
        candidates = ((fmt, getattr(value, fmt, None)) for fmt in priority_order)
        chosen = next(
            ((fmt, entry) for fmt, entry in candidates if entry is not None), None
        )
        if chosen is None:
            raise ValueError(
                "None of the weight formats in `weights_priority_order`"
                + f" ({priority_order}) is present in the given model."
            )

        fmt, entry = chosen
        assert isinstance(entry, Node), type(entry)
        # Rebuild the chosen entry without `parent`. Presumably this is because
        # `parent` may name a format that is not kept here (TODO confirm).
        stripped_entry = entry.model_construct(
            **{field: val for field, val in entry if field != "parent"}
        )
        # New weights object containing only the chosen format.
        value = value.model_construct(None, **{fmt: stripped_entry})

    return handler(
        value,
        info,  # pyright: ignore[reportArgumentType] # taken from pydantic docs
    )

1068 

1069 

class ModelDescr(GenericModelDescrBase):
    """Specification of the fields used in a bioimage.io-compliant RDF that describes AI models with pretrained weights.

    These fields are typically stored in a YAML file which we call a model resource description file (model RDF).
    """

    implemented_format_version: ClassVar[Literal["0.4.10"]] = "0.4.10"
    # Type checkers see a default value; at runtime the field is declared
    # without one.
    if TYPE_CHECKING:
        format_version: Literal["0.4.10"] = "0.4.10"
    else:
        format_version: Literal["0.4.10"]
        """Version of the bioimage.io model description specification used.
        When creating a new model always use the latest micro/patch version described here.
        The `format_version` is important for any consumer software to understand how to parse the fields.
        """

    implemented_type: ClassVar[Literal["model"]] = "model"
    # Same TYPE_CHECKING pattern as for `format_version`.
    if TYPE_CHECKING:
        type: Literal["model"] = "model"
    else:
        type: Literal["model"]
        """Specialized resource type 'model'"""

    id: Optional[ModelId] = None
    """bioimage.io-wide unique resource identifier
    assigned by bioimage.io; version **un**specific."""

    authors: NotEmpty[  # pyright: ignore[reportGeneralTypeIssues] # make mandatory
        List[Author]
    ]
    """The authors are the creators of the model RDF and the primary points of contact."""

    documentation: Annotated[
        FileSource_,
        Field(
            examples=[
                "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/unet2d_nuclei_broad/README.md",
                "README.md",
            ],
        ),
    ]
    """URL or relative path to a markdown file with additional documentation.
    The recommended documentation file name is `README.md`. An `.md` suffix is mandatory.
    The documentation should include a '[#[#]]# Validation' (sub)section
    with details on how to quantitatively validate the model on unseen data."""

    inputs: NotEmpty[List[InputTensorDescr]]
    """Describes the input tensors expected by this model."""

    # Any string is accepted; ids outside `LicenseId` are flagged via `warn(...)`.
    license: Annotated[
        Union[LicenseId, str],
        warn(LicenseId, "Unknown license id '{value}'."),
        Field(examples=["CC0-1.0", "MIT", "BSD-2-Clause"]),
    ]
    """A [SPDX license identifier](https://spdx.org/licenses/).
    We do not support custom license beyond the SPDX license list, if you need that please
    [open a GitHub issue](https://github.com/bioimage-io/spec-bioimage-io/issues/new/choose
    ) to discuss your intentions with the community."""

    # Must be non-empty; lengths below 5 or above 64 only produce INFO-level
    # warnings.
    name: Annotated[
        str,
        MinLen(1),
        warn(MinLen(5), "Name shorter than 5 characters.", INFO),
        warn(MaxLen(64), "Name longer than 64 characters.", INFO),
    ]
    """A human-readable name of this model.
    It should be no longer than 64 characters and only contain letter, number, underscore, minus or space characters."""

    outputs: NotEmpty[List[OutputTensorDescr]]
    """Describes the output tensors."""

    @field_validator("inputs", "outputs")
    @classmethod
    def unique_tensor_descr_names(
        cls, value: Sequence[Union[InputTensorDescr, OutputTensorDescr]]
    ) -> Sequence[Union[InputTensorDescr, OutputTensorDescr]]:
        """Require unique tensor names within `inputs` and within `outputs`."""
        unique_names = {str(v.name) for v in value}
        if len(unique_names) != len(value):
            raise ValueError("Duplicate tensor descriptor names")

        return value

    @model_validator(mode="after")
    def unique_io_names(self) -> Self:
        """Require tensor names to be unique across `inputs` and `outputs` combined."""
        unique_names = {str(ss.name) for s in (self.inputs, self.outputs) for ss in s}
        if len(unique_names) != (len(self.inputs) + len(self.outputs)):
            raise ValueError("Duplicate tensor descriptor names across inputs/outputs")

        return self

    @model_validator(mode="after")
    def minimum_shape2valid_output(self) -> Self:
        """Check output shapes against their references and halos.

        For every output with an implicit shape, the referenced tensor must
        exist, and its number of dimensions must equal the number of output
        dimensions whose `scale` is not null.

        For every output, the minimal output shape (derived from the minimal
        input shapes) minus `2 * halo` must stay at least 1 in every dimension.

        Raises:
            ValueError: on a missing reference, a dimensionality mismatch,
                a circular reference chain, or a too small minimal shape.
        """
        tensors_by_name: Dict[
            TensorName, Union[InputTensorDescr, OutputTensorDescr]
        ] = {t.name: t for t in self.inputs + self.outputs}

        for out in self.outputs:
            if isinstance(out.shape, ImplicitOutputShape):
                ref_name = out.shape.reference_tensor
                if ref_name not in tensors_by_name:
                    # Raise a ValueError instead of letting a KeyError escape,
                    # so pydantic reports it as a validation error.
                    raise ValueError(
                        f"Referenced tensor '{ref_name}' of output '{out.name}'"
                        + " not found."
                    )

                ndim_ref = len(tensors_by_name[ref_name].shape)
                ndim_out_ref = len(
                    [scale for scale in out.shape.scale if scale is not None]
                )
                if ndim_ref != ndim_out_ref:
                    expanded_dim_note = (
                        " Note that expanded dimensions (`scale`: null) are not"
                        + f" counted for {out.name}'s dimensionality here."
                        if None in out.shape.scale
                        else ""
                    )
                    raise ValueError(
                        f"Referenced tensor '{out.shape.reference_tensor}' with"
                        + f" {ndim_ref} dimensions does not match output tensor"
                        + f" '{out.name}' with"
                        + f" {ndim_out_ref} dimensions.{expanded_dim_note}"
                    )

            min_out_shape = self._get_min_shape(out, tensors_by_name)
            if out.halo:
                halo = out.halo
                halo_msg = f" for halo {out.halo}"
            else:
                halo = [0] * len(min_out_shape)
                halo_msg = ""

            if any(s - 2 * h < 1 for s, h in zip(min_out_shape, halo)):
                raise ValueError(
                    f"Minimal shape {min_out_shape} of output {out.name} is too"
                    + f" small{halo_msg}."
                )

        return self

    @classmethod
    def _get_min_shape(
        cls,
        t: Union[InputTensorDescr, OutputTensorDescr],
        tensors_by_name: Dict[TensorName, Union[InputTensorDescr, OutputTensorDescr]],
        _visited: Tuple[TensorName, ...] = (),
    ) -> Sequence[int]:
        """output with subtracted halo has to result in meaningful output even for the minimal input
        see https://github.com/bioimage-io/spec-bioimage-io/issues/392

        Returns the minimal shape of tensor `t`:
        - an explicit shape as given;
        - the `min` of a parameterized input shape;
        - for an implicit output shape, `int(ref * scale + 2 * offset)` per
          dimension, where `ref` is the (recursively computed) minimal shape of
          the referenced tensor. Expanded dimensions (`scale` null) use
          `ref = 1` and `scale = 0`, so they evaluate to `2 * offset`.

        `_visited` holds the implicit-shape tensors already on the current
        reference chain and is used to detect circular references.

        Raises:
            ValueError: if a referenced tensor does not exist or the
                `reference_tensor` chain is circular.
        """
        if isinstance(t.shape, collections.abc.Sequence):
            return t.shape
        elif isinstance(t.shape, ParameterizedInputShape):
            return t.shape.min
        elif isinstance(t.shape, ImplicitOutputShape):
            pass
        else:
            assert_never(t.shape)

        if t.name in _visited:
            # Without this check a (self-)circular reference recursed until RecursionError.
            raise ValueError(
                "Circular `reference_tensor` chain in implicit output shapes: "
                + " -> ".join(str(n) for n in (*_visited, t.name))
            )

        ref_name = t.shape.reference_tensor
        if ref_name not in tensors_by_name:
            raise ValueError(
                f"Referenced tensor '{ref_name}' of '{t.name}' not found."
            )

        ref_shape = cls._get_min_shape(
            tensors_by_name[ref_name], tensors_by_name, (*_visited, t.name)
        )

        scale: Sequence[float]
        if None not in t.shape.scale:
            scale = cast(Sequence[float], t.shape.scale)
        else:
            # Insert a size-1 reference dimension for every expanded
            # (`scale` null) output dimension, and use scale 0 there.
            expanded_dims = [idx for idx, sc in enumerate(t.shape.scale) if sc is None]
            new_ref_shape: List[int] = []
            for idx in range(len(t.shape.scale)):
                ref_idx = idx - sum(int(exp < idx) for exp in expanded_dims)
                new_ref_shape.append(1 if idx in expanded_dims else ref_shape[ref_idx])

            ref_shape = new_ref_shape
            assert len(ref_shape) == len(t.shape.scale)
            scale = [0.0 if sc is None else sc for sc in t.shape.scale]

        offset = t.shape.offset
        assert len(offset) == len(scale)
        return [int(rs * s + 2 * off) for rs, s, off in zip(ref_shape, scale, offset)]

    @model_validator(mode="after")
    def validate_tensor_references_in_inputs(self) -> Self:
        """A preprocessing `reference_tensor` (if set) must name another input tensor."""
        input_names = {str(ipt.name) for ipt in self.inputs}
        for t in self.inputs:
            for proc in t.preprocessing:
                if "reference_tensor" not in proc.kwargs:
                    continue

                ref_tensor = proc.kwargs["reference_tensor"]
                if ref_tensor is not None and str(ref_tensor) not in input_names:
                    raise ValueError(f"'{ref_tensor}' not found in inputs")

                if ref_tensor == t.name:
                    raise ValueError(
                        f"invalid self reference for preprocessing of tensor {t.name}"
                    )

        return self

    @model_validator(mode="after")
    def validate_tensor_references_in_outputs(self) -> Self:
        """A postprocessing `reference_tensor` (if set) must name an input tensor."""
        input_names = {str(ipt.name) for ipt in self.inputs}
        for t in self.outputs:
            for proc in t.postprocessing:
                if "reference_tensor" not in proc.kwargs:
                    continue

                ref_tensor = proc.kwargs["reference_tensor"]
                if ref_tensor is not None and str(ref_tensor) not in input_names:
                    raise ValueError(f"{ref_tensor} not found in inputs")

        return self

    packaged_by: List[Author] = Field(
        default_factory=cast(Callable[[], List[Author]], list)
    )
    """The persons that have packaged and uploaded this model.
    Only required if those persons differ from the `authors`."""

    parent: Optional[LinkedModel] = None
    """The model from which this model is derived, e.g. by fine-tuning the weights."""

    @field_validator("parent", mode="before")
    @classmethod
    def ignore_url_parent(cls, parent: Any):
        """Replace a mapping given as `parent` with `None`; pass anything else through."""
        # NOTE(review): every dict is discarded here, including one that would
        # validate as `LinkedModel` (e.g. `{id: ...}`). Presumably this is meant
        # to drop legacy URL-based parent entries -- confirm the intent.
        if isinstance(parent, dict):
            return None

        else:
            return parent

    run_mode: Optional[RunMode] = None
    """Custom run mode for this model: for more complex prediction procedures like test time
    data augmentation that currently cannot be expressed in the specification.
    No standard run modes are defined yet."""

    sample_inputs: List[FileSource_] = Field(
        default_factory=cast(Callable[[], List[FileSource_]], list)
    )
    """URLs/relative paths to sample inputs to illustrate possible inputs for the model,
    for example stored as PNG or TIFF images.
    The sample files primarily serve to inform a human user about an example use case"""

    sample_outputs: List[FileSource_] = Field(
        default_factory=cast(Callable[[], List[FileSource_]], list)
    )
    """URLs/relative paths to sample outputs corresponding to the `sample_inputs`."""

    # Each entry must have the (case-sensitive) suffix `.npy`.
    test_inputs: NotEmpty[
        List[Annotated[FileSource_, WithSuffix(".npy", case_sensitive=True)]]
    ]
    """Test input tensors compatible with the `inputs` description for a **single test case**.
    This means if your model has more than one input, you should provide one URL/relative path for each input.
    Each test input should be a file with an ndarray in
    [numpy.lib file format](https://numpy.org/doc/stable/reference/generated/numpy.lib.format.html#module-numpy.lib.format).
    The extension must be '.npy'."""

    test_outputs: NotEmpty[
        List[Annotated[FileSource_, WithSuffix(".npy", case_sensitive=True)]]
    ]
    """Analog to `test_inputs`."""

    timestamp: Datetime
    """Timestamp in [ISO 8601](https://en.wikipedia.org/wiki/ISO_8601) format
    with a few restrictions listed [here](https://docs.python.org/3/library/datetime.html#datetime.datetime.fromisoformat)."""

    training_data: Union[LinkedDataset, DatasetDescr, None] = None
    """The dataset used to train this model"""

    # Serialized through `package_weights`, which reduces the weights to a
    # single format when a packaging context sets `weights_priority_order`.
    weights: Annotated[WeightsDescr, WrapSerializer(package_weights)]
    """The weights for this model.
    Weights can be given for different formats, but should otherwise be equivalent.
    The available weight formats determine which consumers can use this model."""

    @model_validator(mode="before")
    @classmethod
    def _convert_from_older_format(
        cls, data: BioimageioYamlContent, /
    ) -> BioimageioYamlContent:
        """Upgrade raw content of older format versions before field validation."""
        # The helper's return value is ignored, so it is expected to modify
        # `data` in place.
        convert_from_older_format(data)
        return data

    def get_input_test_arrays(self) -> List[NDArray[Any]]:
        """Load the arrays referenced by `test_inputs` (in order)."""
        data = [load_array(ipt) for ipt in self.test_inputs]
        assert all(isinstance(d, np.ndarray) for d in data)
        return data

    def get_output_test_arrays(self) -> List[NDArray[Any]]:
        """Load the arrays referenced by `test_outputs` (in order)."""
        data = [load_array(out) for out in self.test_outputs]
        assert all(isinstance(d, np.ndarray) for d in data)
        return data