Coverage for bioimageio/spec/model/v0_5.py: 75%

1315 statements  

coverage.py v7.9.1, created at 2025-06-27 09:20 +0000

1from __future__ import annotations 

2 

3import collections.abc 

4import re 

5import string 

6import warnings 

7from abc import ABC 

8from copy import deepcopy 

9from itertools import chain 

10from math import ceil 

11from pathlib import Path, PurePosixPath 

12from tempfile import mkdtemp 

13from typing import ( 

14 TYPE_CHECKING, 

15 Any, 

16 Callable, 

17 ClassVar, 

18 Dict, 

19 Generic, 

20 List, 

21 Literal, 

22 Mapping, 

23 NamedTuple, 

24 Optional, 

25 Sequence, 

26 Set, 

27 Tuple, 

28 Type, 

29 TypeVar, 

30 Union, 

31 cast, 

32) 

33 

34import numpy as np 

35from annotated_types import Ge, Gt, Interval, MaxLen, MinLen, Predicate 

36from imageio.v3 import imread, imwrite # pyright: ignore[reportUnknownVariableType] 

37from loguru import logger 

38from numpy.typing import NDArray 

39from pydantic import ( 

40 AfterValidator, 

41 Discriminator, 

42 Field, 

43 RootModel, 

44 SerializationInfo, 

45 SerializerFunctionWrapHandler, 

46 Tag, 

47 ValidationInfo, 

48 WrapSerializer, 

49 field_validator, 

50 model_serializer, 

51 model_validator, 

52) 

53from typing_extensions import Annotated, Self, assert_never, get_args 

54 

55from .._internal.common_nodes import ( 

56 InvalidDescr, 

57 Node, 

58 NodeWithExplicitlySetFields, 

59) 

60from .._internal.constants import DTYPE_LIMITS 

61from .._internal.field_warning import issue_warning, warn 

62from .._internal.io import BioimageioYamlContent as BioimageioYamlContent 

63from .._internal.io import FileDescr as FileDescr 

64from .._internal.io import ( 

65 FileSource, 

66 WithSuffix, 

67 YamlValue, 

68 get_reader, 

69 wo_special_file_name, 

70) 

71from .._internal.io_basics import Sha256 as Sha256 

72from .._internal.io_packaging import ( 

73 FileDescr_, 

74 FileSource_, 

75 package_file_descr_serializer, 

76) 

77from .._internal.io_utils import load_array 

78from .._internal.node_converter import Converter 

79from .._internal.types import ( 

80 AbsoluteTolerance, 

81 LowerCaseIdentifier, 

82 LowerCaseIdentifierAnno, 

83 MismatchedElementsPerMillion, 

84 RelativeTolerance, 

85) 

86from .._internal.types import Datetime as Datetime 

87from .._internal.types import Identifier as Identifier 

88from .._internal.types import NotEmpty as NotEmpty 

89from .._internal.types import SiUnit as SiUnit 

90from .._internal.url import HttpUrl as HttpUrl 

91from .._internal.validation_context import get_validation_context 

92from .._internal.validator_annotations import RestrictCharacters 

93from .._internal.version_type import Version as Version 

94from .._internal.warning_levels import INFO 

95from ..dataset.v0_2 import DatasetDescr as DatasetDescr02 

96from ..dataset.v0_2 import LinkedDataset as LinkedDataset02 

97from ..dataset.v0_3 import DatasetDescr as DatasetDescr 

98from ..dataset.v0_3 import DatasetId as DatasetId 

99from ..dataset.v0_3 import LinkedDataset as LinkedDataset 

100from ..dataset.v0_3 import Uploader as Uploader 

101from ..generic.v0_3 import ( 

102 VALID_COVER_IMAGE_EXTENSIONS as VALID_COVER_IMAGE_EXTENSIONS, 

103) 

104from ..generic.v0_3 import Author as Author 

105from ..generic.v0_3 import BadgeDescr as BadgeDescr 

106from ..generic.v0_3 import CiteEntry as CiteEntry 

107from ..generic.v0_3 import DeprecatedLicenseId as DeprecatedLicenseId 

108from ..generic.v0_3 import Doi as Doi 

109from ..generic.v0_3 import ( 

110 FileSource_documentation, 

111 GenericModelDescrBase, 

112 LinkedResourceBase, 

113 _author_conv, # pyright: ignore[reportPrivateUsage] 

114 _maintainer_conv, # pyright: ignore[reportPrivateUsage] 

115) 

116from ..generic.v0_3 import LicenseId as LicenseId 

117from ..generic.v0_3 import LinkedResource as LinkedResource 

118from ..generic.v0_3 import Maintainer as Maintainer 

119from ..generic.v0_3 import OrcidId as OrcidId 

120from ..generic.v0_3 import RelativeFilePath as RelativeFilePath 

121from ..generic.v0_3 import ResourceId as ResourceId 

122from .v0_4 import Author as _Author_v0_4 

123from .v0_4 import BinarizeDescr as _BinarizeDescr_v0_4 

124from .v0_4 import CallableFromDepencency as CallableFromDepencency 

125from .v0_4 import CallableFromDepencency as _CallableFromDepencency_v0_4 

126from .v0_4 import CallableFromFile as _CallableFromFile_v0_4 

127from .v0_4 import ClipDescr as _ClipDescr_v0_4 

128from .v0_4 import ClipKwargs as ClipKwargs 

129from .v0_4 import ImplicitOutputShape as _ImplicitOutputShape_v0_4 

130from .v0_4 import InputTensorDescr as _InputTensorDescr_v0_4 

131from .v0_4 import KnownRunMode as KnownRunMode 

132from .v0_4 import ModelDescr as _ModelDescr_v0_4 

133from .v0_4 import OutputTensorDescr as _OutputTensorDescr_v0_4 

134from .v0_4 import ParameterizedInputShape as _ParameterizedInputShape_v0_4 

135from .v0_4 import PostprocessingDescr as _PostprocessingDescr_v0_4 

136from .v0_4 import PreprocessingDescr as _PreprocessingDescr_v0_4 

137from .v0_4 import ProcessingKwargs as ProcessingKwargs 

138from .v0_4 import RunMode as RunMode 

139from .v0_4 import ScaleLinearDescr as _ScaleLinearDescr_v0_4 

140from .v0_4 import ScaleMeanVarianceDescr as _ScaleMeanVarianceDescr_v0_4 

141from .v0_4 import ScaleRangeDescr as _ScaleRangeDescr_v0_4 

142from .v0_4 import SigmoidDescr as _SigmoidDescr_v0_4 

143from .v0_4 import TensorName as _TensorName_v0_4 

144from .v0_4 import WeightsFormat as WeightsFormat 

145from .v0_4 import ZeroMeanUnitVarianceDescr as _ZeroMeanUnitVarianceDescr_v0_4 

146from .v0_4 import package_weights 

147 

148SpaceUnit = Literal[ 

149 "attometer", 

150 "angstrom", 

151 "centimeter", 

152 "decimeter", 

153 "exameter", 

154 "femtometer", 

155 "foot", 

156 "gigameter", 

157 "hectometer", 

158 "inch", 

159 "kilometer", 

160 "megameter", 

161 "meter", 

162 "micrometer", 

163 "mile", 

164 "millimeter", 

165 "nanometer", 

166 "parsec", 

167 "petameter", 

168 "picometer", 

169 "terameter", 

170 "yard", 

171 "yoctometer", 

172 "yottameter", 

173 "zeptometer", 

174 "zettameter", 

175] 

176"""Space unit compatible to the [OME-Zarr axes specification 0.5](https://ngff.openmicroscopy.org/0.5/#axes-md)""" 

177 

178TimeUnit = Literal[ 

179 "attosecond", 

180 "centisecond", 

181 "day", 

182 "decisecond", 

183 "exasecond", 

184 "femtosecond", 

185 "gigasecond", 

186 "hectosecond", 

187 "hour", 

188 "kilosecond", 

189 "megasecond", 

190 "microsecond", 

191 "millisecond", 

192 "minute", 

193 "nanosecond", 

194 "petasecond", 

195 "picosecond", 

196 "second", 

197 "terasecond", 

198 "yoctosecond", 

199 "yottasecond", 

200 "zeptosecond", 

201 "zettasecond", 

202] 

203"""Time unit compatible to the [OME-Zarr axes specification 0.5](https://ngff.openmicroscopy.org/0.5/#axes-md)""" 

204 

205AxisType = Literal["batch", "channel", "index", "time", "space"] 

206 

207_AXIS_TYPE_MAP: Mapping[str, AxisType] = { 

208 "b": "batch", 

209 "t": "time", 

210 "i": "index", 

211 "c": "channel", 

212 "x": "space", 

213 "y": "space", 

214 "z": "space", 

215} 

216 

217_AXIS_ID_MAP = { 

218 "b": "batch", 

219 "t": "time", 

220 "i": "index", 

221 "c": "channel", 

222} 

223 

224 

225class TensorId(LowerCaseIdentifier): 

226 root_model: ClassVar[Type[RootModel[Any]]] = RootModel[ 

227 Annotated[LowerCaseIdentifierAnno, MaxLen(32)] 

228 ] 

229 

230 

231def _normalize_axis_id(a: str): 

232 a = str(a) 

233 normalized = _AXIS_ID_MAP.get(a, a) 

234 if a != normalized: 

235 logger.opt(depth=3).warning( 

236 "Normalized axis id from '{}' to '{}'.", a, normalized 

237 ) 

238 return normalized 

239 

240 

241class AxisId(LowerCaseIdentifier): 

242 root_model: ClassVar[Type[RootModel[Any]]] = RootModel[ 

243 Annotated[ 

244 LowerCaseIdentifierAnno, 

245 MaxLen(16), 

246 AfterValidator(_normalize_axis_id), 

247 ] 

248 ] 

249 
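# Editor's sketch (not part of the original v0_5.py): how single-letter axis ids
# from the 0.4 convention are normalized by `_normalize_axis_id` (and hence by
# `AxisId`). Mapped ids log a warning; unmapped ids such as "x", "y", "z" pass through.
assert _normalize_axis_id("b") == "batch"  # expanded via _AXIS_ID_MAP
assert _normalize_axis_id("x") == "x"      # not in the map: returned as-is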

250 

251def _is_batch(a: str) -> bool: 

252 return str(a) == "batch" 

253 

254 

255def _is_not_batch(a: str) -> bool: 

256 return not _is_batch(a) 

257 

258 

259NonBatchAxisId = Annotated[AxisId, Predicate(_is_not_batch)] 

260 

261PostprocessingId = Literal[ 

262 "binarize", 

263 "clip", 

264 "ensure_dtype", 

265 "fixed_zero_mean_unit_variance", 

266 "scale_linear", 

267 "scale_mean_variance", 

268 "scale_range", 

269 "sigmoid", 

270 "zero_mean_unit_variance", 

271] 

272PreprocessingId = Literal[ 

273 "binarize", 

274 "clip", 

275 "ensure_dtype", 

276 "scale_linear", 

277 "sigmoid", 

278 "zero_mean_unit_variance", 

279 "scale_range", 

280] 

281 

282 

283SAME_AS_TYPE = "<same as type>" 

284 

285 

286ParameterizedSize_N = int 

287""" 

288Annotates an integer to calculate a concrete axis size from a `ParameterizedSize`. 

289""" 

290 

291 

292class ParameterizedSize(Node): 

293 """Describes a range of valid tensor axis sizes as `size = min + n*step`. 

294 

295 - **min** and **step** are given by the model description. 

296 - All blocksize parameters n = 0,1,2,... yield a valid `size`. 

297 - A greater blocksize parameter n results in a greater **size**. 

298 This allows adjusting the axis size more generically. 

299 """ 

300 

301 N: ClassVar[Type[int]] = ParameterizedSize_N 

302 """Positive integer to parameterize this axis""" 

303 

304 min: Annotated[int, Gt(0)] 

305 step: Annotated[int, Gt(0)] 

306 

307 def validate_size(self, size: int) -> int: 

308 if size < self.min: 

309 raise ValueError(f"size {size} < {self.min}") 

310 if (size - self.min) % self.step != 0: 

311 raise ValueError( 

312 f"axis of size {size} is not parameterized by `min + n*step` =" 

313 + f" `{self.min} + n*{self.step}`" 

314 ) 

315 

316 return size 

317 

318 def get_size(self, n: ParameterizedSize_N) -> int: 

319 return self.min + self.step * n 

320 

321 def get_n(self, s: int) -> ParameterizedSize_N: 

322 """return smallest n parameterizing a size greater or equal than `s`""" 

323 return ceil((s - self.min) / self.step) 

324 
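# Editor's sketch (not part of the original v0_5.py): hypothetical usage of
# `ParameterizedSize`, showing how the blocksize parameter n maps to a concrete size.
_example_psize = ParameterizedSize(min=32, step=16)
assert _example_psize.get_size(2) == 64        # 32 + 2 * 16
assert _example_psize.get_n(50) == 2           # smallest n with 32 + n * 16 >= 50
assert _example_psize.validate_size(64) == 64  # (64 - 32) is a multiple of 16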

325 

326class DataDependentSize(Node): 

327 min: Annotated[int, Gt(0)] = 1 

328 max: Annotated[Optional[int], Gt(1)] = None 

329 

330 @model_validator(mode="after") 

331 def _validate_max_gt_min(self): 

332 if self.max is not None and self.min >= self.max: 

333 raise ValueError(f"expected `min` < `max`, but got {self.min}, {self.max}") 

334 

335 return self 

336 

337 def validate_size(self, size: int) -> int: 

338 if size < self.min: 

339 raise ValueError(f"size {size} < {self.min}") 

340 

341 if self.max is not None and size > self.max: 

342 raise ValueError(f"size {size} > {self.max}") 

343 

344 return size 

345 
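# Editor's sketch (not part of the original v0_5.py): `DataDependentSize` only bounds
# the axis size; the concrete size is known only after inference.
_example_dds = DataDependentSize(min=4, max=512)
assert _example_dds.validate_size(4) == 4  # within [min, max]
# _example_dds.validate_size(1024) would raise a ValueError (1024 > max)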

346 

347class SizeReference(Node): 

348 """A tensor axis size (extent in pixels/frames) defined in relation to a reference axis. 

349 

350 `axis.size = reference.size * reference.scale / axis.scale + offset` 

351 

352 Note: 

353 1. The axis and the referenced axis need to have the same unit (or no unit). 

354 2. Batch axes may not be referenced. 

355 3. Fractions are rounded down. 

356 4. If the reference axis is `concatenable` the referencing axis is assumed to be 

357 `concatenable` as well with the same block order. 

358 

359 Example: 

360 An anisotropic input image of w*h=100*49 pixels depicts a physical space of 200*196 mm². 

361 Let's assume that we want to express the image height h in relation to its width w 

362 instead of only accepting input images of exactly 100*49 pixels 

363 (for example to express a range of valid image shapes by parametrizing w, see `ParameterizedSize`). 

364 

365 >>> w = SpaceInputAxis(id=AxisId("w"), size=100, unit="millimeter", scale=2) 

366 >>> h = SpaceInputAxis( 

367 ... id=AxisId("h"), 

368 ... size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("w"), offset=-1), 

369 ... unit="millimeter", 

370 ... scale=4, 

371 ... ) 

372 >>> print(h.size.get_size(h, w)) 

373 49 

374 

375 ⇒ h = w * w.scale / h.scale + offset = 100 * 2mm / 4mm - 1 = 49 

376 """ 

377 

378 tensor_id: TensorId 

379 """tensor id of the reference axis""" 

380 

381 axis_id: AxisId 

382 """axis id of the reference axis""" 

383 

384 offset: int = 0 

385 

386 def get_size( 

387 self, 

388 axis: Union[ 

389 ChannelAxis, 

390 IndexInputAxis, 

391 IndexOutputAxis, 

392 TimeInputAxis, 

393 SpaceInputAxis, 

394 TimeOutputAxis, 

395 TimeOutputAxisWithHalo, 

396 SpaceOutputAxis, 

397 SpaceOutputAxisWithHalo, 

398 ], 

399 ref_axis: Union[ 

400 ChannelAxis, 

401 IndexInputAxis, 

402 IndexOutputAxis, 

403 TimeInputAxis, 

404 SpaceInputAxis, 

405 TimeOutputAxis, 

406 TimeOutputAxisWithHalo, 

407 SpaceOutputAxis, 

408 SpaceOutputAxisWithHalo, 

409 ], 

410 n: ParameterizedSize_N = 0, 

411 ref_size: Optional[int] = None, 

412 ): 

413 """Compute the concrete size for a given axis and its reference axis. 

414 

415 Args: 

416 axis: The axis this `SizeReference` is the size of. 

417 ref_axis: The reference axis to compute the size from. 

418 n: If the **ref_axis** is parameterized (of type `ParameterizedSize`) 

419 and no fixed **ref_size** is given, 

420 **n** is used to compute the size of the parameterized **ref_axis**. 

421 ref_size: Overwrite the reference size instead of deriving it from 

422 **ref_axis** 

423 (**ref_axis.scale** is still used; any given **n** is ignored). 

424 """ 

425 assert ( 

426 axis.size == self 

427 ), "Given `axis.size` is not defined by this `SizeReference`" 

428 

429 assert ( 

430 ref_axis.id == self.axis_id 

431 ), f"Expected `ref_axis.id` to be {self.axis_id}, but got {ref_axis.id}." 

432 

433 assert axis.unit == ref_axis.unit, ( 

434 "`SizeReference` requires `axis` and `ref_axis` to have the same `unit`," 

435 f" but {axis.unit}!={ref_axis.unit}" 

436 ) 

437 if ref_size is None: 

438 if isinstance(ref_axis.size, (int, float)): 

439 ref_size = ref_axis.size 

440 elif isinstance(ref_axis.size, ParameterizedSize): 

441 ref_size = ref_axis.size.get_size(n) 

442 elif isinstance(ref_axis.size, DataDependentSize): 

443 raise ValueError( 

444 "Reference axis referenced in `SizeReference` may not be a `DataDependentSize`." 

445 ) 

446 elif isinstance(ref_axis.size, SizeReference): 

447 raise ValueError( 

448 "Reference axis referenced in `SizeReference` may not be sized by a" 

449 + " `SizeReference` itself." 

450 ) 

451 else: 

452 assert_never(ref_axis.size) 

453 

454 return int(ref_size * ref_axis.scale / axis.scale + self.offset) 

455 

456 @staticmethod 

457 def _get_unit( 

458 axis: Union[ 

459 ChannelAxis, 

460 IndexInputAxis, 

461 IndexOutputAxis, 

462 TimeInputAxis, 

463 SpaceInputAxis, 

464 TimeOutputAxis, 

465 TimeOutputAxisWithHalo, 

466 SpaceOutputAxis, 

467 SpaceOutputAxisWithHalo, 

468 ], 

469 ): 

470 return axis.unit 

471 

472 

473class AxisBase(NodeWithExplicitlySetFields): 

474 id: AxisId 

475 """An axis id unique across all axes of one tensor.""" 

476 

477 description: Annotated[str, MaxLen(128)] = "" 

478 

479 

480class WithHalo(Node): 

481 halo: Annotated[int, Ge(1)] 

482 """The halo should be cropped from the output tensor to avoid boundary effects. 

483 It is to be cropped from both sides, i.e. `size_after_crop = size - 2 * halo`. 

484 To document a halo that is already cropped by the model use `size.offset` instead.""" 

485 

486 size: Annotated[ 

487 SizeReference, 

488 Field( 

489 examples=[ 

490 10, 

491 SizeReference( 

492 tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5 

493 ).model_dump(mode="json"), 

494 ] 

495 ), 

496 ] 

497 """reference to another axis with an optional offset (see `SizeReference`)""" 

498 
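# Editor's sketch (not part of the original v0_5.py, hypothetical numbers): the halo
# is cropped from both sides, so a halo of 16 on a 256-pixel output axis leaves
# 256 - 2 * 16 == 224 valid pixels.
_example_halo, _example_output_size = 16, 256
assert _example_output_size - 2 * _example_halo == 224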

499 

500BATCH_AXIS_ID = AxisId("batch") 

501 

502 

503class BatchAxis(AxisBase): 

504 implemented_type: ClassVar[Literal["batch"]] = "batch" 

505 if TYPE_CHECKING: 

506 type: Literal["batch"] = "batch" 

507 else: 

508 type: Literal["batch"] 

509 

510 id: Annotated[AxisId, Predicate(_is_batch)] = BATCH_AXIS_ID 

511 size: Optional[Literal[1]] = None 

512 """The batch size may be fixed to 1, 

513 otherwise (the default) it may be chosen arbitrarily depending on available memory""" 

514 

515 @property 

516 def scale(self): 

517 return 1.0 

518 

519 @property 

520 def concatenable(self): 

521 return True 

522 

523 @property 

524 def unit(self): 

525 return None 

526 

527 

528class ChannelAxis(AxisBase): 

529 implemented_type: ClassVar[Literal["channel"]] = "channel" 

530 if TYPE_CHECKING: 

531 type: Literal["channel"] = "channel" 

532 else: 

533 type: Literal["channel"] 

534 

535 id: NonBatchAxisId = AxisId("channel") 

536 channel_names: NotEmpty[List[Identifier]] 

537 

538 @property 

539 def size(self) -> int: 

540 return len(self.channel_names) 

541 

542 @property 

543 def concatenable(self): 

544 return False 

545 

546 @property 

547 def scale(self) -> float: 

548 return 1.0 

549 

550 @property 

551 def unit(self): 

552 return None 

553 

554 

555class IndexAxisBase(AxisBase): 

556 implemented_type: ClassVar[Literal["index"]] = "index" 

557 if TYPE_CHECKING: 

558 type: Literal["index"] = "index" 

559 else: 

560 type: Literal["index"] 

561 

562 id: NonBatchAxisId = AxisId("index") 

563 

564 @property 

565 def scale(self) -> float: 

566 return 1.0 

567 

568 @property 

569 def unit(self): 

570 return None 

571 

572 

573class _WithInputAxisSize(Node): 

574 size: Annotated[ 

575 Union[Annotated[int, Gt(0)], ParameterizedSize, SizeReference], 

576 Field( 

577 examples=[ 

578 10, 

579 ParameterizedSize(min=32, step=16).model_dump(mode="json"), 

580 SizeReference( 

581 tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5 

582 ).model_dump(mode="json"), 

583 ] 

584 ), 

585 ] 

586 """The size/length of this axis can be specified as 

587 - fixed integer 

588 - parameterized series of valid sizes (`ParameterizedSize`) 

589 - reference to another axis with an optional offset (`SizeReference`) 

590 """ 

591 

592 

593class IndexInputAxis(IndexAxisBase, _WithInputAxisSize): 

594 concatenable: bool = False 

595 """If a model has a `concatenable` input axis, it can be processed blockwise, 

596 splitting a longer sample axis into blocks matching its input tensor description. 

597 Output axes are concatenable if they have a `SizeReference` to a concatenable 

598 input axis. 

599 """ 

600 

601 

602class IndexOutputAxis(IndexAxisBase): 

603 size: Annotated[ 

604 Union[Annotated[int, Gt(0)], SizeReference, DataDependentSize], 

605 Field( 

606 examples=[ 

607 10, 

608 SizeReference( 

609 tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5 

610 ).model_dump(mode="json"), 

611 ] 

612 ), 

613 ] 

614 """The size/length of this axis can be specified as 

615 - fixed integer 

616 - reference to another axis with an optional offset (`SizeReference`) 

617 - data dependent size using `DataDependentSize` (size is only known after model inference) 

618 """ 

619 

620 

621class TimeAxisBase(AxisBase): 

622 implemented_type: ClassVar[Literal["time"]] = "time" 

623 if TYPE_CHECKING: 

624 type: Literal["time"] = "time" 

625 else: 

626 type: Literal["time"] 

627 

628 id: NonBatchAxisId = AxisId("time") 

629 unit: Optional[TimeUnit] = None 

630 scale: Annotated[float, Gt(0)] = 1.0 

631 

632 

633class TimeInputAxis(TimeAxisBase, _WithInputAxisSize): 

634 concatenable: bool = False 

635 """If a model has a `concatenable` input axis, it can be processed blockwise, 

636 splitting a longer sample axis into blocks matching its input tensor description. 

637 Output axes are concatenable if they have a `SizeReference` to a concatenable 

638 input axis. 

639 """ 

640 

641 

642class SpaceAxisBase(AxisBase): 

643 implemented_type: ClassVar[Literal["space"]] = "space" 

644 if TYPE_CHECKING: 

645 type: Literal["space"] = "space" 

646 else: 

647 type: Literal["space"] 

648 

649 id: Annotated[NonBatchAxisId, Field(examples=["x", "y", "z"])] = AxisId("x") 

650 unit: Optional[SpaceUnit] = None 

651 scale: Annotated[float, Gt(0)] = 1.0 

652 

653 

654class SpaceInputAxis(SpaceAxisBase, _WithInputAxisSize): 

655 concatenable: bool = False 

656 """If a model has a `concatenable` input axis, it can be processed blockwise, 

657 splitting a longer sample axis into blocks matching its input tensor description. 

658 Output axes are concatenable if they have a `SizeReference` to a concatenable 

659 input axis. 

660 """ 

661 

662 

663INPUT_AXIS_TYPES = ( 

664 BatchAxis, 

665 ChannelAxis, 

666 IndexInputAxis, 

667 TimeInputAxis, 

668 SpaceInputAxis, 

669) 

670"""intended for isinstance comparisons in py<3.10""" 

671 

672_InputAxisUnion = Union[ 

673 BatchAxis, ChannelAxis, IndexInputAxis, TimeInputAxis, SpaceInputAxis 

674] 

675InputAxis = Annotated[_InputAxisUnion, Discriminator("type")] 

676 

677 

678class _WithOutputAxisSize(Node): 

679 size: Annotated[ 

680 Union[Annotated[int, Gt(0)], SizeReference], 

681 Field( 

682 examples=[ 

683 10, 

684 SizeReference( 

685 tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5 

686 ).model_dump(mode="json"), 

687 ] 

688 ), 

689 ] 

690 """The size/length of this axis can be specified as 

691 - fixed integer 

692 - reference to another axis with an optional offset (see `SizeReference`) 

693 """ 

694 

695 

696class TimeOutputAxis(TimeAxisBase, _WithOutputAxisSize): 

697 pass 

698 

699 

700class TimeOutputAxisWithHalo(TimeAxisBase, WithHalo): 

701 pass 

702 

703 

704def _get_halo_axis_discriminator_value(v: Any) -> Literal["with_halo", "wo_halo"]: 

705 if isinstance(v, dict): 

706 return "with_halo" if "halo" in v else "wo_halo" 

707 else: 

708 return "with_halo" if hasattr(v, "halo") else "wo_halo" 

709 
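# Editor's sketch (not part of the original v0_5.py): the discriminator routes raw
# axis data to the halo/no-halo variant based on the presence of a `halo` entry.
assert _get_halo_axis_discriminator_value({"size": 128, "halo": 16}) == "with_halo"
assert _get_halo_axis_discriminator_value({"size": 128}) == "wo_halo"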

710 

711_TimeOutputAxisUnion = Annotated[ 

712 Union[ 

713 Annotated[TimeOutputAxis, Tag("wo_halo")], 

714 Annotated[TimeOutputAxisWithHalo, Tag("with_halo")], 

715 ], 

716 Discriminator(_get_halo_axis_discriminator_value), 

717] 

718 

719 

720class SpaceOutputAxis(SpaceAxisBase, _WithOutputAxisSize): 

721 pass 

722 

723 

724class SpaceOutputAxisWithHalo(SpaceAxisBase, WithHalo): 

725 pass 

726 

727 

728_SpaceOutputAxisUnion = Annotated[ 

729 Union[ 

730 Annotated[SpaceOutputAxis, Tag("wo_halo")], 

731 Annotated[SpaceOutputAxisWithHalo, Tag("with_halo")], 

732 ], 

733 Discriminator(_get_halo_axis_discriminator_value), 

734] 

735 

736 

737_OutputAxisUnion = Union[ 

738 BatchAxis, ChannelAxis, IndexOutputAxis, _TimeOutputAxisUnion, _SpaceOutputAxisUnion 

739] 

740OutputAxis = Annotated[_OutputAxisUnion, Discriminator("type")] 

741 

742OUTPUT_AXIS_TYPES = ( 

743 BatchAxis, 

744 ChannelAxis, 

745 IndexOutputAxis, 

746 TimeOutputAxis, 

747 TimeOutputAxisWithHalo, 

748 SpaceOutputAxis, 

749 SpaceOutputAxisWithHalo, 

750) 

751"""intended for isinstance comparisons in py<3.10""" 

752 

753 

754AnyAxis = Union[InputAxis, OutputAxis] 

755 

756ANY_AXIS_TYPES = INPUT_AXIS_TYPES + OUTPUT_AXIS_TYPES 

757"""intended for isinstance comparisons in py<3.10""" 

758 

759TVs = Union[ 

760 NotEmpty[List[int]], 

761 NotEmpty[List[float]], 

762 NotEmpty[List[bool]], 

763 NotEmpty[List[str]], 

764] 

765 

766 

767NominalOrOrdinalDType = Literal[ 

768 "float32", 

769 "float64", 

770 "uint8", 

771 "int8", 

772 "uint16", 

773 "int16", 

774 "uint32", 

775 "int32", 

776 "uint64", 

777 "int64", 

778 "bool", 

779] 

780 

781 

782class NominalOrOrdinalDataDescr(Node): 

783 values: TVs 

784 """A fixed set of nominal or an ascending sequence of ordinal values. 

785 In this case `data.type` is required to be an unsigned integer type, e.g. 'uint8'. 

786 String `values` are interpreted as labels for tensor values 0, ..., N. 

787 Note: as YAML 1.2 does not natively support a "set" datatype, 

788 nominal values should be given as a sequence (aka list/array) as well. 

789 """ 

790 

791 type: Annotated[ 

792 NominalOrOrdinalDType, 

793 Field( 

794 examples=[ 

795 "float32", 

796 "uint8", 

797 "uint16", 

798 "int64", 

799 "bool", 

800 ], 

801 ), 

802 ] = "uint8" 

803 

804 @model_validator(mode="after") 

805 def _validate_values_match_type( 

806 self, 

807 ) -> Self: 

808 incompatible: List[Any] = [] 

809 for v in self.values: 

810 if self.type == "bool": 

811 if not isinstance(v, bool): 

812 incompatible.append(v) 

813 elif self.type in DTYPE_LIMITS: 

814 if ( 

815 isinstance(v, (int, float)) 

816 and ( 

817 v < DTYPE_LIMITS[self.type].min 

818 or v > DTYPE_LIMITS[self.type].max 

819 ) 

820 or (isinstance(v, str) and "uint" not in self.type) 

821 or (isinstance(v, float) and "int" in self.type) 

822 ): 

823 incompatible.append(v) 

824 else: 

825 incompatible.append(v) 

826 

827 if len(incompatible) == 5: 

828 incompatible.append("...") 

829 break 

830 

831 if incompatible: 

832 raise ValueError( 

833 f"data type '{self.type}' incompatible with values {incompatible}" 

834 ) 

835 

836 return self 

837 

838 unit: Optional[Union[Literal["arbitrary unit"], SiUnit]] = None 

839 

840 @property 

841 def range(self): 

842 if isinstance(self.values[0], str): 

843 return 0, len(self.values) - 1 

844 else: 

845 return min(self.values), max(self.values) 

846 
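# Editor's sketch (not part of the original v0_5.py, hypothetical values): string
# `values` act as labels for tensor values 0..N, so `range` spans the label indices.
_example_labels = NominalOrOrdinalDataDescr(values=["background", "foreground"], type="uint8")
assert _example_labels.range == (0, 1)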

847 

848IntervalOrRatioDType = Literal[ 

849 "float32", 

850 "float64", 

851 "uint8", 

852 "int8", 

853 "uint16", 

854 "int16", 

855 "uint32", 

856 "int32", 

857 "uint64", 

858 "int64", 

859] 

860 

861 

862class IntervalOrRatioDataDescr(Node): 

863 type: Annotated[ # todo: rename to dtype 

864 IntervalOrRatioDType, 

865 Field( 

866 examples=["float32", "float64", "uint8", "uint16"], 

867 ), 

868 ] = "float32" 

869 range: Tuple[Optional[float], Optional[float]] = ( 

870 None, 

871 None, 

872 ) 

873 """Tuple `(minimum, maximum)` specifying the allowed range of the data in this tensor. 

874 `None` corresponds to min/max of what can be expressed by **type**.""" 

875 unit: Union[Literal["arbitrary unit"], SiUnit] = "arbitrary unit" 

876 scale: float = 1.0 

877 """Scale for data on an interval (or ratio) scale.""" 

878 offset: Optional[float] = None 

879 """Offset for data on a ratio scale.""" 

880 

881 

882TensorDataDescr = Union[NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr] 

883 

884 

885class ProcessingDescrBase(NodeWithExplicitlySetFields, ABC): 

886 """processing base class""" 

887 

888 

889class BinarizeKwargs(ProcessingKwargs): 

890 """key word arguments for `BinarizeDescr`""" 

891 

892 threshold: float 

893 """The fixed threshold""" 

894 

895 

896class BinarizeAlongAxisKwargs(ProcessingKwargs): 

897 """key word arguments for `BinarizeDescr`""" 

898 

899 threshold: NotEmpty[List[float]] 

900 """The fixed threshold values along `axis`""" 

901 

902 axis: Annotated[NonBatchAxisId, Field(examples=["channel"])] 

903 """The `threshold` axis""" 

904 

905 

906class BinarizeDescr(ProcessingDescrBase): 

907 """Binarize the tensor with a fixed threshold. 

908 

909 Values above `BinarizeKwargs.threshold`/`BinarizeAlongAxisKwargs.threshold` 

910 will be set to one, values below the threshold to zero. 

911 

912 Examples: 

913 - in YAML 

914 ```yaml 

915 postprocessing: 

916 - id: binarize 

917 kwargs: 

918 axis: 'channel' 

919 threshold: [0.25, 0.5, 0.75] 

920 ``` 

921 - in Python: 

922 >>> postprocessing = [BinarizeDescr( 

923 ... kwargs=BinarizeAlongAxisKwargs( 

924 ... axis=AxisId('channel'), 

925 ... threshold=[0.25, 0.5, 0.75], 

926 ... ) 

927 ... )] 

928 """ 

929 

930 implemented_id: ClassVar[Literal["binarize"]] = "binarize" 

931 if TYPE_CHECKING: 

932 id: Literal["binarize"] = "binarize" 

933 else: 

934 id: Literal["binarize"] 

935 kwargs: Union[BinarizeKwargs, BinarizeAlongAxisKwargs] 

936 

937 

938class ClipDescr(ProcessingDescrBase): 

939 """Set tensor values below min to min and above max to max. 

940 

941 See `ScaleRangeDescr` for examples. 

942 """ 

943 

944 implemented_id: ClassVar[Literal["clip"]] = "clip" 

945 if TYPE_CHECKING: 

946 id: Literal["clip"] = "clip" 

947 else: 

948 id: Literal["clip"] 

949 

950 kwargs: ClipKwargs 

951 

952 

953class EnsureDtypeKwargs(ProcessingKwargs): 

954 """key word arguments for `EnsureDtypeDescr`""" 

955 

956 dtype: Literal[ 

957 "float32", 

958 "float64", 

959 "uint8", 

960 "int8", 

961 "uint16", 

962 "int16", 

963 "uint32", 

964 "int32", 

965 "uint64", 

966 "int64", 

967 "bool", 

968 ] 

969 

970 

971class EnsureDtypeDescr(ProcessingDescrBase): 

972 """Cast the tensor data type to `EnsureDtypeKwargs.dtype` (if not matching). 

973 

974 This can for example be used to ensure the inner neural network model gets a 

975 different input tensor data type than the fully described bioimage.io model does. 

976 

977 Examples: 

978 The described bioimage.io model (incl. preprocessing) accepts any 

979 float32-compatible tensor, normalizes it with percentiles and clipping and then 

980 casts it to uint8, which is what the neural network in this example expects. 

981 - in YAML 

982 ```yaml 

983 inputs: 

984 - data: 

985 type: float32 # described bioimage.io model is compatible with any float32 input tensor 

986 preprocessing: 

987 - id: scale_range 

988 kwargs: 

989 axes: ['y', 'x'] 

990 max_percentile: 99.8 

991 min_percentile: 5.0 

992 - id: clip 

993 kwargs: 

994 min: 0.0 

995 max: 1.0 

996 - id: ensure_dtype # the neural network of the model requires uint8 

997 kwargs: 

998 dtype: uint8 

999 ``` 

1000 - in Python: 

1001 >>> preprocessing = [ 

1002 ... ScaleRangeDescr( 

1003 ... kwargs=ScaleRangeKwargs( 

1004 ... axes= (AxisId('y'), AxisId('x')), 

1005 ... max_percentile= 99.8, 

1006 ... min_percentile= 5.0, 

1007 ... ) 

1008 ... ), 

1009 ... ClipDescr(kwargs=ClipKwargs(min=0.0, max=1.0)), 

1010 ... EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="uint8")), 

1011 ... ] 

1012 """ 

1013 

1014 implemented_id: ClassVar[Literal["ensure_dtype"]] = "ensure_dtype" 

1015 if TYPE_CHECKING: 

1016 id: Literal["ensure_dtype"] = "ensure_dtype" 

1017 else: 

1018 id: Literal["ensure_dtype"] 

1019 

1020 kwargs: EnsureDtypeKwargs 

1021 

1022 

1023class ScaleLinearKwargs(ProcessingKwargs): 

1024 """Key word arguments for `ScaleLinearDescr`""" 

1025 

1026 gain: float = 1.0 

1027 """multiplicative factor""" 

1028 

1029 offset: float = 0.0 

1030 """additive term""" 

1031 

1032 @model_validator(mode="after") 

1033 def _validate(self) -> Self: 

1034 if self.gain == 1.0 and self.offset == 0.0: 

1035 raise ValueError( 

1036 "Redundant linear scaling not allowd. Set `gain` != 1.0 and/or `offset`" 

1037 + " != 0.0." 

1038 ) 

1039 

1040 return self 

1041 

1042 

1043class ScaleLinearAlongAxisKwargs(ProcessingKwargs): 

1044 """Key word arguments for `ScaleLinearDescr`""" 

1045 

1046 axis: Annotated[NonBatchAxisId, Field(examples=["channel"])] 

1047 """The axis of gain and offset values.""" 

1048 

1049 gain: Union[float, NotEmpty[List[float]]] = 1.0 

1050 """multiplicative factor""" 

1051 

1052 offset: Union[float, NotEmpty[List[float]]] = 0.0 

1053 """additive term""" 

1054 

1055 @model_validator(mode="after") 

1056 def _validate(self) -> Self: 

1057 

1058 if isinstance(self.gain, list): 

1059 if isinstance(self.offset, list): 

1060 if len(self.gain) != len(self.offset): 

1061 raise ValueError( 

1062 f"Size of `gain` ({len(self.gain)}) and `offset` ({len(self.offset)}) must match." 

1063 ) 

1064 else: 

1065 self.offset = [float(self.offset)] * len(self.gain) 

1066 elif isinstance(self.offset, list): 

1067 self.gain = [float(self.gain)] * len(self.offset) 

1068 else: 

1069 raise ValueError( 

1070 "Do not specify an `axis` for scalar gain and offset values." 

1071 ) 

1072 

1073 if all(g == 1.0 for g in self.gain) and all(off == 0.0 for off in self.offset): 

1074 raise ValueError( 

1075 "Redundant linear scaling not allowd. Set `gain` != 1.0 and/or `offset`" 

1076 + " != 0.0." 

1077 ) 

1078 

1079 return self 

1080 
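# Editor's sketch (not part of the original v0_5.py, hypothetical values): the
# validator above broadcasts a scalar offset to match a per-channel gain list.
_example_scale = ScaleLinearAlongAxisKwargs(
    axis=AxisId("channel"), gain=[1.0, 2.0, 3.0], offset=0.5
)
assert _example_scale.offset == [0.5, 0.5, 0.5]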

1081 

1082class ScaleLinearDescr(ProcessingDescrBase): 

1083 """Fixed linear scaling. 

1084 

1085 Examples: 

1086 1. Scale with scalar gain and offset 

1087 - in YAML 

1088 ```yaml 

1089 preprocessing: 

1090 - id: scale_linear 

1091 kwargs: 

1092 gain: 2.0 

1093 offset: 3.0 

1094 ``` 

1095 - in Python: 

1096 >>> preprocessing = [ 

1097 ... ScaleLinearDescr(kwargs=ScaleLinearKwargs(gain= 2.0, offset=3.0)) 

1098 ... ] 

1099 

1100 2. Independent scaling along an axis 

1101 - in YAML 

1102 ```yaml 

1103 preprocessing: 

1104 - id: scale_linear 

1105 kwargs: 

1106 axis: 'channel' 

1107 gain: [1.0, 2.0, 3.0] 

1108 ``` 

1109 - in Python: 

1110 >>> preprocessing = [ 

1111 ... ScaleLinearDescr( 

1112 ... kwargs=ScaleLinearAlongAxisKwargs( 

1113 ... axis=AxisId("channel"), 

1114 ... gain=[1.0, 2.0, 3.0], 

1115 ... ) 

1116 ... ) 

1117 ... ] 

1118 

1119 """ 

1120 

1121 implemented_id: ClassVar[Literal["scale_linear"]] = "scale_linear" 

1122 if TYPE_CHECKING: 

1123 id: Literal["scale_linear"] = "scale_linear" 

1124 else: 

1125 id: Literal["scale_linear"] 

1126 kwargs: Union[ScaleLinearKwargs, ScaleLinearAlongAxisKwargs] 

1127 

1128 

1129class SigmoidDescr(ProcessingDescrBase): 

1130 """The logistic sigmoid funciton, a.k.a. expit function. 

1131 

1132 Examples: 

1133 - in YAML 

1134 ```yaml 

1135 postprocessing: 

1136 - id: sigmoid 

1137 ``` 

1138 - in Python: 

1139 >>> postprocessing = [SigmoidDescr()] 

1140 """ 

1141 

1142 implemented_id: ClassVar[Literal["sigmoid"]] = "sigmoid" 

1143 if TYPE_CHECKING: 

1144 id: Literal["sigmoid"] = "sigmoid" 

1145 else: 

1146 id: Literal["sigmoid"] 

1147 

1148 @property 

1149 def kwargs(self) -> ProcessingKwargs: 

1150 """empty kwargs""" 

1151 return ProcessingKwargs() 

1152 

1153 

1154class FixedZeroMeanUnitVarianceKwargs(ProcessingKwargs): 

1155 """key word arguments for `FixedZeroMeanUnitVarianceDescr`""" 

1156 

1157 mean: float 

1158 """The mean value to normalize with.""" 

1159 

1160 std: Annotated[float, Ge(1e-6)] 

1161 """The standard deviation value to normalize with.""" 

1162 

1163 

1164class FixedZeroMeanUnitVarianceAlongAxisKwargs(ProcessingKwargs): 

1165 """key word arguments for `FixedZeroMeanUnitVarianceDescr`""" 

1166 

1167 mean: NotEmpty[List[float]] 

1168 """The mean value(s) to normalize with.""" 

1169 

1170 std: NotEmpty[List[Annotated[float, Ge(1e-6)]]] 

1171 """The standard deviation value(s) to normalize with. 

1172 Size must match `mean` values.""" 

1173 

1174 axis: Annotated[NonBatchAxisId, Field(examples=["channel", "index"])] 

1175 """The axis of the mean/std values to normalize each entry along that dimension 

1176 separately.""" 

1177 

1178 @model_validator(mode="after") 

1179 def _mean_and_std_match(self) -> Self: 

1180 if len(self.mean) != len(self.std): 

1181 raise ValueError( 

1182 f"Size of `mean` ({len(self.mean)}) and `std` ({len(self.std)})" 

1183 + " must match." 

1184 ) 

1185 

1186 return self 

1187 

1188 

1189class FixedZeroMeanUnitVarianceDescr(ProcessingDescrBase): 

1190 """Subtract a given mean and divide by the standard deviation. 

1191 

1192 Normalize with fixed, precomputed values for 

1193 `FixedZeroMeanUnitVarianceKwargs.mean` and `FixedZeroMeanUnitVarianceKwargs.std` 

1194 Use `FixedZeroMeanUnitVarianceAlongAxisKwargs` for independent scaling along given 

1195 axes. 

1196 

1197 Examples: 

1198 1. scalar value for whole tensor 

1199 - in YAML 

1200 ```yaml 

1201 preprocessing: 

1202 - id: fixed_zero_mean_unit_variance 

1203 kwargs: 

1204 mean: 103.5 

1205 std: 13.7 

1206 ``` 

1207 - in Python 

1208 >>> preprocessing = [FixedZeroMeanUnitVarianceDescr( 

1209 ... kwargs=FixedZeroMeanUnitVarianceKwargs(mean=103.5, std=13.7) 

1210 ... )] 

1211 

1212 2. independently along an axis 

1213 - in YAML 

1214 ```yaml 

1215 preprocessing: 

1216 - id: fixed_zero_mean_unit_variance 

1217 kwargs: 

1218 axis: channel 

1219 mean: [101.5, 102.5, 103.5] 

1220 std: [11.7, 12.7, 13.7] 

1221 ``` 

1222 - in Python 

1223 >>> preprocessing = [FixedZeroMeanUnitVarianceDescr( 

1224 ... kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs( 

1225 ... axis=AxisId("channel"), 

1226 ... mean=[101.5, 102.5, 103.5], 

1227 ... std=[11.7, 12.7, 13.7], 

1228 ... ) 

1229 ... )] 

1230 """ 

1231 

1232 implemented_id: ClassVar[Literal["fixed_zero_mean_unit_variance"]] = ( 

1233 "fixed_zero_mean_unit_variance" 

1234 ) 

1235 if TYPE_CHECKING: 

1236 id: Literal["fixed_zero_mean_unit_variance"] = "fixed_zero_mean_unit_variance" 

1237 else: 

1238 id: Literal["fixed_zero_mean_unit_variance"] 

1239 

1240 kwargs: Union[ 

1241 FixedZeroMeanUnitVarianceKwargs, FixedZeroMeanUnitVarianceAlongAxisKwargs 

1242 ] 

1243 

1244 

1245class ZeroMeanUnitVarianceKwargs(ProcessingKwargs): 

1246 """key word arguments for `ZeroMeanUnitVarianceDescr`""" 

1247 

1248 axes: Annotated[ 

1249 Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")]) 

1250 ] = None 

1251 """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std. 

1252 For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x') 

1253 resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`. 

1254 To normalize each sample independently leave out the 'batch' axis. 

1255 Default: Scale all axes jointly.""" 

1256 

1257 eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6 

1258 """epsilon for numeric stability: `out = (tensor - mean) / (std + eps)`.""" 

1259 

1260 

1261class ZeroMeanUnitVarianceDescr(ProcessingDescrBase): 

1262 """Subtract mean and divide by variance. 

1263 

1264 Examples: 

1265 Subtract tensor mean and variance 

1266 - in YAML 

1267 ```yaml 

1268 preprocessing: 

1269 - id: zero_mean_unit_variance 

1270 ``` 

1271 - in Python 

1272 >>> preprocessing = [ZeroMeanUnitVarianceDescr()] 

1273 """ 

1274 

1275 implemented_id: ClassVar[Literal["zero_mean_unit_variance"]] = ( 

1276 "zero_mean_unit_variance" 

1277 ) 

1278 if TYPE_CHECKING: 

1279 id: Literal["zero_mean_unit_variance"] = "zero_mean_unit_variance" 

1280 else: 

1281 id: Literal["zero_mean_unit_variance"] 

1282 

1283 kwargs: ZeroMeanUnitVarianceKwargs = Field( 

1284 default_factory=ZeroMeanUnitVarianceKwargs 

1285 ) 

1286 

1287 

1288class ScaleRangeKwargs(ProcessingKwargs): 

1289 """key word arguments for `ScaleRangeDescr` 

1290 

1291 For `min_percentile`=0.0 (the default) and `max_percentile`=100 (the default) 

1292 this processing step normalizes data to the [0, 1] interval. 

1293 For other percentiles the normalized values will partially be outside the [0, 1] 

1294 interval. Use `ScaleRange` followed by `ClipDescr` if you want to limit the 

1295 normalized values to a range. 

1296 """ 

1297 

1298 axes: Annotated[ 

1299 Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")]) 

1300 ] = None 

1301 """The subset of axes to normalize jointly, i.e. axes to reduce to compute the min/max percentile value. 

1302 For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x') 

1303 resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`. 

1304 To normalize samples independently, leave out the "batch" axis. 

1305 Default: Scale all axes jointly.""" 

1306 

1307 min_percentile: Annotated[float, Interval(ge=0, lt=100)] = 0.0 

1308 """The lower percentile used to determine the value to align with zero.""" 

1309 

1310 max_percentile: Annotated[float, Interval(gt=1, le=100)] = 100.0 

1311 """The upper percentile used to determine the value to align with one. 

1312 Has to be bigger than `min_percentile`. 

1313 The range is 1 to 100 instead of 0 to 100 to avoid mistakenly 

1314 accepting percentiles specified in the range 0.0 to 1.0.""" 

1315 

1316 eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6 

1317 """Epsilon for numeric stability. 

1318 `out = (tensor - v_lower) / (v_upper - v_lower + eps)`; 

1319 with `v_lower,v_upper` values at the respective percentiles.""" 

1320 

1321 reference_tensor: Optional[TensorId] = None 

1322 """Tensor ID to compute the percentiles from. Default: The tensor itself. 

1323 For any tensor in `inputs` only input tensor references are allowed.""" 

1324 

1325 @field_validator("max_percentile", mode="after") 

1326 @classmethod 

1327 def min_smaller_max(cls, value: float, info: ValidationInfo) -> float: 

1328 if (min_p := info.data["min_percentile"]) >= value: 

1329 raise ValueError(f"min_percentile {min_p} >= max_percentile {value}") 

1330 

1331 return value 

1332 
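# Editor's sketch (not part of the original v0_5.py): a minimal numpy rendering of the
# scale_range formula documented above. It ignores the `axes` and `reference_tensor`
# options; the actual processing is implemented elsewhere (e.g. in bioimageio.core).
def _scale_range_sketch(
    tensor: NDArray[Any],
    min_percentile: float = 0.0,
    max_percentile: float = 100.0,
    eps: float = 1e-6,
) -> NDArray[Any]:
    v_lower = np.percentile(tensor, min_percentile)
    v_upper = np.percentile(tensor, max_percentile)
    return (tensor - v_lower) / (v_upper - v_lower + eps)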

1333 

1334class ScaleRangeDescr(ProcessingDescrBase): 

1335 """Scale with percentiles. 

1336 

1337 Examples: 

1338 1. Scale linearly to map 5th percentile to 0 and 99.8th percentile to 1.0 

1339 - in YAML 

1340 ```yaml 

1341 preprocessing: 

1342 - id: scale_range 

1343 kwargs: 

1344 axes: ['y', 'x'] 

1345 max_percentile: 99.8 

1346 min_percentile: 5.0 

1347 ``` 

1348 - in Python 

1349 >>> preprocessing = [ 

1350 ... ScaleRangeDescr( 

1351 ... kwargs=ScaleRangeKwargs( 

1352 ... axes= (AxisId('y'), AxisId('x')), 

1353 ... max_percentile= 99.8, 

1354 ... min_percentile= 5.0, 

1355 ... ) 

1356 ... ), 

1357 ... ClipDescr( 

1358 ... kwargs=ClipKwargs( 

1359 ... min=0.0, 

1360 ... max=1.0, 

1361 ... ) 

1362 ... ), 

1363 ... ] 

1364 

1365 2. Combine the above scaling with additional clipping to clip values outside the range given by the percentiles. 

1366 - in YAML 

1367 ```yaml 

1368 preprocessing: 

1369 - id: scale_range 

1370 kwargs: 

1371 axes: ['y', 'x'] 

1372 max_percentile: 99.8 

1373 min_percentile: 5.0 

1374 

1375 - id: clip 

1376 kwargs: 

1377 min: 0.0 

1378 max: 1.0 

1379 ``` 

1380 - in Python 

1381 >>> preprocessing = [ScaleRangeDescr( 

1382 ... kwargs=ScaleRangeKwargs( 

1383 ... axes= (AxisId('y'), AxisId('x')), 

1384 ... max_percentile= 99.8, 

1385 ... min_percentile= 5.0, 

1386 ... ) 

1387 ... )] 

1388 

1389 """ 

1390 

1391 implemented_id: ClassVar[Literal["scale_range"]] = "scale_range" 

1392 if TYPE_CHECKING: 

1393 id: Literal["scale_range"] = "scale_range" 

1394 else: 

1395 id: Literal["scale_range"] 

1396 kwargs: ScaleRangeKwargs 

1397 

1398 

1399class ScaleMeanVarianceKwargs(ProcessingKwargs): 

1400 """key word arguments for `ScaleMeanVarianceKwargs`""" 

1401 

1402 reference_tensor: TensorId 

1403 """Name of tensor to match.""" 

1404 

1405 axes: Annotated[ 

1406 Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")]) 

1407 ] = None 

1408 """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std. 

1409 For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x') 

1410 resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`. 

1411 To normalize samples independently, leave out the 'batch' axis. 

1412 Default: Scale all axes jointly.""" 

1413 

1414 eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6 

1415 """Epsilon for numeric stability: 

1416 `out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.`""" 

1417 

1418 

1419class ScaleMeanVarianceDescr(ProcessingDescrBase): 

1420 """Scale a tensor's data distribution to match another tensor's mean/std. 

1421 `out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.` 

1422 """ 

1423 

1424 implemented_id: ClassVar[Literal["scale_mean_variance"]] = "scale_mean_variance" 

1425 if TYPE_CHECKING: 

1426 id: Literal["scale_mean_variance"] = "scale_mean_variance" 

1427 else: 

1428 id: Literal["scale_mean_variance"] 

1429 kwargs: ScaleMeanVarianceKwargs 

1430 

1431 

1432PreprocessingDescr = Annotated[ 

1433 Union[ 

1434 BinarizeDescr, 

1435 ClipDescr, 

1436 EnsureDtypeDescr, 

1437 ScaleLinearDescr, 

1438 SigmoidDescr, 

1439 FixedZeroMeanUnitVarianceDescr, 

1440 ZeroMeanUnitVarianceDescr, 

1441 ScaleRangeDescr, 

1442 ], 

1443 Discriminator("id"), 

1444] 

1445PostprocessingDescr = Annotated[ 

1446 Union[ 

1447 BinarizeDescr, 

1448 ClipDescr, 

1449 EnsureDtypeDescr, 

1450 ScaleLinearDescr, 

1451 SigmoidDescr, 

1452 FixedZeroMeanUnitVarianceDescr, 

1453 ZeroMeanUnitVarianceDescr, 

1454 ScaleRangeDescr, 

1455 ScaleMeanVarianceDescr, 

1456 ], 

1457 Discriminator("id"), 

1458] 

1459 

1460IO_AxisT = TypeVar("IO_AxisT", InputAxis, OutputAxis) 

1461 

1462 

1463class TensorDescrBase(Node, Generic[IO_AxisT]): 

1464 id: TensorId 

1465 """Tensor id. No duplicates are allowed.""" 

1466 

1467 description: Annotated[str, MaxLen(128)] = "" 

1468 """free text description""" 

1469 

1470 axes: NotEmpty[Sequence[IO_AxisT]] 

1471 """tensor axes""" 

1472 

1473 @property 

1474 def shape(self): 

1475 return tuple(a.size for a in self.axes) 

1476 

1477 @field_validator("axes", mode="after", check_fields=False) 

1478 @classmethod 

1479 def _validate_axes(cls, axes: Sequence[AnyAxis]) -> Sequence[AnyAxis]: 

1480 batch_axes = [a for a in axes if a.type == "batch"] 

1481 if len(batch_axes) > 1: 

1482 raise ValueError( 

1483 f"Only one batch axis (per tensor) allowed, but got {batch_axes}" 

1484 ) 

1485 

1486 seen_ids: Set[AxisId] = set() 

1487 duplicate_axes_ids: Set[AxisId] = set() 

1488 for a in axes: 

1489 (duplicate_axes_ids if a.id in seen_ids else seen_ids).add(a.id) 

1490 

1491 if duplicate_axes_ids: 

1492 raise ValueError(f"Duplicate axis ids: {duplicate_axes_ids}") 

1493 

1494 return axes 

1495 

1496 test_tensor: FileDescr_ 

1497 """An example tensor to use for testing. 

1498 Using the model with the test input tensors is expected to yield the test output tensors. 

1499 Each test tensor has to be an ndarray in the 

1500 [numpy.lib file format](https://numpy.org/doc/stable/reference/generated/numpy.lib.format.html#module-numpy.lib.format). 

1501 The file extension must be '.npy'.""" 

1502 

1503 sample_tensor: Optional[FileDescr_] = None 

1504 """A sample tensor to illustrate a possible input/output for the model, 

1505 The sample image primarily serves to inform a human user about an example use case 

1506 and is typically stored as .hdf5, .png or .tiff. 

1507 It has to be readable by the [imageio library](https://imageio.readthedocs.io/en/stable/formats/index.html#supported-formats) 

1508 (numpy's `.npy` format is not supported). 

1509 The image dimensionality has to match the number of axes specified in this tensor description. 

1510 """ 

1511 

1512 @model_validator(mode="after") 

1513 def _validate_sample_tensor(self) -> Self: 

1514 if self.sample_tensor is None or not get_validation_context().perform_io_checks: 

1515 return self 

1516 

1517 reader = get_reader(self.sample_tensor.source, sha256=self.sample_tensor.sha256) 

1518 tensor: NDArray[Any] = imread( 

1519 reader.read(), 

1520 extension=PurePosixPath(reader.original_file_name).suffix, 

1521 ) 

1522 n_dims = len(tensor.squeeze().shape) 

1523 n_dims_min = n_dims_max = len(self.axes) 

1524 

1525 for a in self.axes: 

1526 if isinstance(a, BatchAxis): 

1527 n_dims_min -= 1 

1528 elif isinstance(a.size, int): 

1529 if a.size == 1: 

1530 n_dims_min -= 1 

1531 elif isinstance(a.size, (ParameterizedSize, DataDependentSize)): 

1532 if a.size.min == 1: 

1533 n_dims_min -= 1 

1534 elif isinstance(a.size, SizeReference): 

1535 if a.size.offset < 2: 

1536 # size reference may result in singleton axis 

1537 n_dims_min -= 1 

1538 else: 

1539 assert_never(a.size) 

1540 

1541 n_dims_min = max(0, n_dims_min) 

1542 if n_dims < n_dims_min or n_dims > n_dims_max: 

1543 raise ValueError( 

1544 f"Expected sample tensor to have {n_dims_min} to" 

1545 + f" {n_dims_max} dimensions, but found {n_dims} (shape: {tensor.shape})." 

1546 ) 

1547 

1548 return self 

1549 

1550 data: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]] = ( 

1551 IntervalOrRatioDataDescr() 

1552 ) 

1553 """Description of the tensor's data values, optionally per channel. 

1554 If specified per channel, the data `type` needs to match across channels.""" 

1555 

1556 @property 

1557 def dtype( 

1558 self, 

1559 ) -> Literal[ 

1560 "float32", 

1561 "float64", 

1562 "uint8", 

1563 "int8", 

1564 "uint16", 

1565 "int16", 

1566 "uint32", 

1567 "int32", 

1568 "uint64", 

1569 "int64", 

1570 "bool", 

1571 ]: 

1572 """dtype as specified under `data.type` or `data[i].type`""" 

1573 if isinstance(self.data, collections.abc.Sequence): 

1574 return self.data[0].type 

1575 else: 

1576 return self.data.type 

1577 

1578 @field_validator("data", mode="after") 

1579 @classmethod 

1580 def _check_data_type_across_channels( 

1581 cls, value: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]] 

1582 ) -> Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]]: 

1583 if not isinstance(value, list): 

1584 return value 

1585 

1586 dtypes = {t.type for t in value} 

1587 if len(dtypes) > 1: 

1588 raise ValueError( 

1589 "Tensor data descriptions per channel need to agree in their data" 

1590 + f" `type`, but found {dtypes}." 

1591 ) 

1592 

1593 return value 

1594 

1595 @model_validator(mode="after") 

1596 def _check_data_matches_channelaxis(self) -> Self: 

1597 if not isinstance(self.data, (list, tuple)): 

1598 return self 

1599 

1600 for a in self.axes: 

1601 if isinstance(a, ChannelAxis): 

1602 size = a.size 

1603 assert isinstance(size, int) 

1604 break 

1605 else: 

1606 return self 

1607 

1608 if len(self.data) != size: 

1609 raise ValueError( 

1610 f"Got tensor data descriptions for {len(self.data)} channels, but" 

1611 + f" '{a.id}' axis has size {size}." 

1612 ) 

1613 

1614 return self 

1615 

1616 def get_axis_sizes_for_array(self, array: NDArray[Any]) -> Dict[AxisId, int]: 

1617 if len(array.shape) != len(self.axes): 

1618 raise ValueError( 

1619 f"Dimension mismatch: array shape {array.shape} (#{len(array.shape)})" 

1620 + f" incompatible with {len(self.axes)} axes." 

1621 ) 

1622 return {a.id: array.shape[i] for i, a in enumerate(self.axes)} 

1623 
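# Editor's sketch (not part of the original v0_5.py): for a tensor description with
# axes (batch, channel, y, x) and an array of shape (1, 3, 512, 512),
# `get_axis_sizes_for_array` returns
# {AxisId("batch"): 1, AxisId("channel"): 3, AxisId("y"): 512, AxisId("x"): 512};
# an array with a different number of dimensions raises a ValueError.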

1624 

1625class InputTensorDescr(TensorDescrBase[InputAxis]): 

1626 id: TensorId = TensorId("input") 

1627 """Input tensor id. 

1628 No duplicates are allowed across all inputs and outputs.""" 

1629 

1630 optional: bool = False 

1631 """indicates that this tensor may be `None`""" 

1632 

1633 preprocessing: List[PreprocessingDescr] = Field( 

1634 default_factory=cast(Callable[[], List[PreprocessingDescr]], list) 

1635 ) 

1636 

1637 """Description of how this input should be preprocessed. 

1638 

1639 notes: 

1640 - If preprocessing does not start with an 'ensure_dtype' entry, it is added 

1641 to ensure an input tensor's data type matches the input tensor's data description. 

1642 - If preprocessing does not end with an 'ensure_dtype' or 'binarize' entry, an 

1643 'ensure_dtype' step is added to ensure preprocessing steps are not unintentionally 

1644 changing the data type. 

1645 """ 

1646 

1647 @model_validator(mode="after") 

1648 def _validate_preprocessing_kwargs(self) -> Self: 

1649 axes_ids = [a.id for a in self.axes] 

1650 for p in self.preprocessing: 

1651 kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes") 

1652 if kwargs_axes is None: 

1653 continue 

1654 

1655 if not isinstance(kwargs_axes, collections.abc.Sequence): 

1656 raise ValueError( 

1657 f"Expected `preprocessing.i.kwargs.axes` to be a sequence, but got {type(kwargs_axes)}" 

1658 ) 

1659 

1660 if any(a not in axes_ids for a in kwargs_axes): 

1661 raise ValueError( 

1662 "`preprocessing.i.kwargs.axes` needs to be subset of axes ids" 

1663 ) 

1664 

1665 if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)): 

1666 dtype = self.data.type 

1667 else: 

1668 dtype = self.data[0].type 

1669 

1670 # ensure `preprocessing` begins with `EnsureDtypeDescr` 

1671 if not self.preprocessing or not isinstance( 

1672 self.preprocessing[0], EnsureDtypeDescr 

1673 ): 

1674 self.preprocessing.insert( 

1675 0, EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype)) 

1676 ) 

1677 

1678 # ensure `preprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr` 

1679 if not isinstance(self.preprocessing[-1], (EnsureDtypeDescr, BinarizeDescr)): 

1680 self.preprocessing.append( 

1681 EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype)) 

1682 ) 

1683 

1684 return self 

1685 

1686 

1687def convert_axes( 

1688 axes: str, 

1689 *, 

1690 shape: Union[ 

1691 Sequence[int], _ParameterizedInputShape_v0_4, _ImplicitOutputShape_v0_4 

1692 ], 

1693 tensor_type: Literal["input", "output"], 

1694 halo: Optional[Sequence[int]], 

1695 size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]], 

1696): 

1697 ret: List[AnyAxis] = [] 

1698 for i, a in enumerate(axes): 

1699 axis_type = _AXIS_TYPE_MAP.get(a, a) 

1700 if axis_type == "batch": 

1701 ret.append(BatchAxis()) 

1702 continue 

1703 

1704 scale = 1.0 

1705 if isinstance(shape, _ParameterizedInputShape_v0_4): 

1706 if shape.step[i] == 0: 

1707 size = shape.min[i] 

1708 else: 

1709 size = ParameterizedSize(min=shape.min[i], step=shape.step[i]) 

1710 elif isinstance(shape, _ImplicitOutputShape_v0_4): 

1711 ref_t = str(shape.reference_tensor) 

1712 if ref_t.count(".") == 1: 

1713 t_id, orig_a_id = ref_t.split(".") 

1714 else: 

1715 t_id = ref_t 

1716 orig_a_id = a 

1717 

1718 a_id = _AXIS_ID_MAP.get(orig_a_id, a) 

1719 if not (orig_scale := shape.scale[i]): 

1720 # old way to insert a new axis dimension 

1721 size = int(2 * shape.offset[i]) 

1722 else: 

1723 scale = 1 / orig_scale 

1724 if axis_type in ("channel", "index"): 

1725 # these axes no longer have a scale 

1726 offset_from_scale = orig_scale * size_refs.get( 

1727 _TensorName_v0_4(t_id), {} 

1728 ).get(orig_a_id, 0) 

1729 else: 

1730 offset_from_scale = 0 

1731 size = SizeReference( 

1732 tensor_id=TensorId(t_id), 

1733 axis_id=AxisId(a_id), 

1734 offset=int(offset_from_scale + 2 * shape.offset[i]), 

1735 ) 

1736 else: 

1737 size = shape[i] 

1738 

1739 if axis_type == "time": 

1740 if tensor_type == "input": 

1741 ret.append(TimeInputAxis(size=size, scale=scale)) 

1742 else: 

1743 assert not isinstance(size, ParameterizedSize) 

1744 if halo is None: 

1745 ret.append(TimeOutputAxis(size=size, scale=scale)) 

1746 else: 

1747 assert not isinstance(size, int) 

1748 ret.append( 

1749 TimeOutputAxisWithHalo(size=size, scale=scale, halo=halo[i]) 

1750 ) 

1751 

1752 elif axis_type == "index": 

1753 if tensor_type == "input": 

1754 ret.append(IndexInputAxis(size=size)) 

1755 else: 

1756 if isinstance(size, ParameterizedSize): 

1757 size = DataDependentSize(min=size.min) 

1758 

1759 ret.append(IndexOutputAxis(size=size)) 

1760 elif axis_type == "channel": 

1761 assert not isinstance(size, ParameterizedSize) 

1762 if isinstance(size, SizeReference): 

1763 warnings.warn( 

1764 "Conversion of channel size from an implicit output shape may be" 

1765 + " wrong" 

1766 ) 

1767 ret.append( 

1768 ChannelAxis( 

1769 channel_names=[ 

1770 Identifier(f"channel{i}") for i in range(size.offset) 

1771 ] 

1772 ) 

1773 ) 

1774 else: 

1775 ret.append( 

1776 ChannelAxis( 

1777 channel_names=[Identifier(f"channel{i}") for i in range(size)] 

1778 ) 

1779 ) 

1780 elif axis_type == "space": 

1781 if tensor_type == "input": 

1782 ret.append(SpaceInputAxis(id=AxisId(a), size=size, scale=scale)) 

1783 else: 

1784 assert not isinstance(size, ParameterizedSize) 

1785 if halo is None or halo[i] == 0: 

1786 ret.append(SpaceOutputAxis(id=AxisId(a), size=size, scale=scale)) 

1787 elif isinstance(size, int): 

1788 raise NotImplementedError( 

1789 f"output axis with halo and fixed size (here {size}) not allowed" 

1790 ) 

1791 else: 

1792 ret.append( 

1793 SpaceOutputAxisWithHalo( 

1794 id=AxisId(a), size=size, scale=scale, halo=halo[i] 

1795 ) 

1796 ) 

1797 

1798 return ret 

1799 
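# Editor's sketch (not part of the original v0_5.py, hypothetical call): converting a
# 0.4-style "bcyx" input with a fixed shape and no halo yields one 0.5 axis per letter.
_example_converted = convert_axes(
    "bcyx", shape=[1, 3, 512, 512], tensor_type="input", halo=None, size_refs={}
)
assert [type(a).__name__ for a in _example_converted] == [
    "BatchAxis", "ChannelAxis", "SpaceInputAxis", "SpaceInputAxis"
]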

1800 

1801def _axes_letters_to_ids( 

1802 axes: Optional[str], 

1803) -> Optional[List[AxisId]]: 

1804 if axes is None: 

1805 return None 

1806 

1807 return [AxisId(a) for a in axes] 

1808 

1809 

1810def _get_complement_v04_axis( 

1811 tensor_axes: Sequence[str], axes: Optional[Sequence[str]] 

1812) -> Optional[AxisId]: 

1813 if axes is None: 

1814 return None 

1815 

1816 non_complement_axes = set(axes) | {"b"} 

1817 complement_axes = [a for a in tensor_axes if a not in non_complement_axes] 

1818 if len(complement_axes) > 1: 

1819 raise ValueError( 

1820 f"Expected none or a single complement axis, but axes '{axes}' " 

1821 + f"for tensor dims '{tensor_axes}' leave '{complement_axes}'." 

1822 ) 

1823 

1824 return None if not complement_axes else AxisId(complement_axes[0]) 

1825 
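# Illustrative sketch (assumed example axis letters): for a v0.4 tensor with
# axes "bczyx" and a kwarg restricted to the spatial axes, the single remaining
# letter becomes the complement axis:
#
#     _get_complement_v04_axis("bczyx", ["z", "y", "x"])  # -> AxisId("c")
#
# "b" is always excluded; more than one remaining letter raises a ValueError.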

1826 

1827def _convert_proc( 

1828 p: Union[_PreprocessingDescr_v0_4, _PostprocessingDescr_v0_4], 

1829 tensor_axes: Sequence[str], 

1830) -> Union[PreprocessingDescr, PostprocessingDescr]: 

1831 if isinstance(p, _BinarizeDescr_v0_4): 

1832 return BinarizeDescr(kwargs=BinarizeKwargs(threshold=p.kwargs.threshold)) 

1833 elif isinstance(p, _ClipDescr_v0_4): 

1834 return ClipDescr(kwargs=ClipKwargs(min=p.kwargs.min, max=p.kwargs.max)) 

1835 elif isinstance(p, _SigmoidDescr_v0_4): 

1836 return SigmoidDescr() 

1837 elif isinstance(p, _ScaleLinearDescr_v0_4): 

1838 axes = _axes_letters_to_ids(p.kwargs.axes) 

1839 if p.kwargs.axes is None: 

1840 axis = None 

1841 else: 

1842 axis = _get_complement_v04_axis(tensor_axes, p.kwargs.axes) 

1843 

1844 if axis is None: 

1845 assert not isinstance(p.kwargs.gain, list) 

1846 assert not isinstance(p.kwargs.offset, list) 

1847 kwargs = ScaleLinearKwargs(gain=p.kwargs.gain, offset=p.kwargs.offset) 

1848 else: 

1849 kwargs = ScaleLinearAlongAxisKwargs( 

1850 axis=axis, gain=p.kwargs.gain, offset=p.kwargs.offset 

1851 ) 

1852 return ScaleLinearDescr(kwargs=kwargs) 

1853 elif isinstance(p, _ScaleMeanVarianceDescr_v0_4): 

1854 return ScaleMeanVarianceDescr( 

1855 kwargs=ScaleMeanVarianceKwargs( 

1856 axes=_axes_letters_to_ids(p.kwargs.axes), 

1857 reference_tensor=TensorId(str(p.kwargs.reference_tensor)), 

1858 eps=p.kwargs.eps, 

1859 ) 

1860 ) 

1861 elif isinstance(p, _ZeroMeanUnitVarianceDescr_v0_4): 

1862 if p.kwargs.mode == "fixed": 

1863 mean = p.kwargs.mean 

1864 std = p.kwargs.std 

1865 assert mean is not None 

1866 assert std is not None 

1867 

1868 axis = _get_complement_v04_axis(tensor_axes, p.kwargs.axes) 

1869 

1870 if axis is None: 

1871 return FixedZeroMeanUnitVarianceDescr( 

1872 kwargs=FixedZeroMeanUnitVarianceKwargs( 

1873 mean=mean, std=std # pyright: ignore[reportArgumentType] 

1874 ) 

1875 ) 

1876 else: 

1877 if not isinstance(mean, list): 

1878 mean = [float(mean)] 

1879 if not isinstance(std, list): 

1880 std = [float(std)] 

1881 

1882 return FixedZeroMeanUnitVarianceDescr( 

1883 kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs( 

1884 axis=axis, mean=mean, std=std 

1885 ) 

1886 ) 

1887 

1888 else: 

1889 axes = _axes_letters_to_ids(p.kwargs.axes) or [] 

1890 if p.kwargs.mode == "per_dataset": 

1891 axes = [AxisId("batch")] + axes 

1892 if not axes: 

1893 axes = None 

1894 return ZeroMeanUnitVarianceDescr( 

1895 kwargs=ZeroMeanUnitVarianceKwargs(axes=axes, eps=p.kwargs.eps) 

1896 ) 

1897 

1898 elif isinstance(p, _ScaleRangeDescr_v0_4): 

1899 return ScaleRangeDescr( 

1900 kwargs=ScaleRangeKwargs( 

1901 axes=_axes_letters_to_ids(p.kwargs.axes), 

1902 min_percentile=p.kwargs.min_percentile, 

1903 max_percentile=p.kwargs.max_percentile, 

1904 eps=p.kwargs.eps, 

1905 ) 

1906 ) 

1907 else: 

1908 assert_never(p) 

1909 

1910 

1911class _InputTensorConv( 

1912 Converter[ 

1913 _InputTensorDescr_v0_4, 

1914 InputTensorDescr, 

1915 FileSource_, 

1916 Optional[FileSource_], 

1917 Mapping[_TensorName_v0_4, Mapping[str, int]], 

1918 ] 

1919): 

1920 def _convert( 

1921 self, 

1922 src: _InputTensorDescr_v0_4, 

1923 tgt: "type[InputTensorDescr] | type[dict[str, Any]]", 

1924 test_tensor: FileSource_, 

1925 sample_tensor: Optional[FileSource_], 

1926 size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]], 

1927 ) -> "InputTensorDescr | dict[str, Any]": 

1928 axes: List[InputAxis] = convert_axes( # pyright: ignore[reportAssignmentType] 

1929 src.axes, 

1930 shape=src.shape, 

1931 tensor_type="input", 

1932 halo=None, 

1933 size_refs=size_refs, 

1934 ) 

1935 prep: List[PreprocessingDescr] = [] 

1936 for p in src.preprocessing: 

1937 cp = _convert_proc(p, src.axes) 

1938 assert not isinstance(cp, ScaleMeanVarianceDescr) 

1939 prep.append(cp) 

1940 

1941 prep.append(EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="float32"))) 

1942 

1943 return tgt( 

1944 axes=axes, 

1945 id=TensorId(str(src.name)), 

1946 test_tensor=FileDescr(source=test_tensor), 

1947 sample_tensor=( 

1948 None if sample_tensor is None else FileDescr(source=sample_tensor) 

1949 ), 

1950 data=dict(type=src.data_type), # pyright: ignore[reportArgumentType] 

1951 preprocessing=prep, 

1952 ) 

1953 

1954 

1955_input_tensor_conv = _InputTensorConv(_InputTensorDescr_v0_4, InputTensorDescr) 

1956 

1957 

1958class OutputTensorDescr(TensorDescrBase[OutputAxis]): 

1959 id: TensorId = TensorId("output") 

1960 """Output tensor id. 

1961 No duplicates are allowed across all inputs and outputs.""" 

1962 

1963 postprocessing: List[PostprocessingDescr] = Field( 

1964 default_factory=cast(Callable[[], List[PostprocessingDescr]], list) 

1965 ) 

1966 """Description of how this output should be postprocessed. 

1967 

1968 note: `postprocessing` always ends with an 'ensure_dtype' operation. 

1969 If not given this is added to cast to this tensor's `data.type`. 

1970 """ 

1971 

1972 @model_validator(mode="after") 

1973 def _validate_postprocessing_kwargs(self) -> Self: 

1974 axes_ids = [a.id for a in self.axes] 

1975 for p in self.postprocessing: 

1976 kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes") 

1977 if kwargs_axes is None: 

1978 continue 

1979 

1980 if not isinstance(kwargs_axes, collections.abc.Sequence): 

1981 raise ValueError( 

1982 f"expected `axes` sequence, but got {type(kwargs_axes)}" 

1983 ) 

1984 

1985 if any(a not in axes_ids for a in kwargs_axes): 

1986 raise ValueError("`kwargs.axes` needs to be subset of axes ids") 

1987 

1988 if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)): 

1989 dtype = self.data.type 

1990 else: 

1991 dtype = self.data[0].type 

1992 

1993 # ensure `postprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr` 

1994 if not self.postprocessing or not isinstance( 

1995 self.postprocessing[-1], (EnsureDtypeDescr, BinarizeDescr) 

1996 ): 

1997 self.postprocessing.append( 

1998 EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype)) 

1999 ) 

2000 return self 

2001 

2002 

2003class _OutputTensorConv( 

2004 Converter[ 

2005 _OutputTensorDescr_v0_4, 

2006 OutputTensorDescr, 

2007 FileSource_, 

2008 Optional[FileSource_], 

2009 Mapping[_TensorName_v0_4, Mapping[str, int]], 

2010 ] 

2011): 

2012 def _convert( 

2013 self, 

2014 src: _OutputTensorDescr_v0_4, 

2015 tgt: "type[OutputTensorDescr] | type[dict[str, Any]]", 

2016 test_tensor: FileSource_, 

2017 sample_tensor: Optional[FileSource_], 

2018 size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]], 

2019 ) -> "OutputTensorDescr | dict[str, Any]": 

2020 # TODO: split convert_axes into convert_output_axes and convert_input_axes 

2021 axes: List[OutputAxis] = convert_axes( # pyright: ignore[reportAssignmentType] 

2022 src.axes, 

2023 shape=src.shape, 

2024 tensor_type="output", 

2025 halo=src.halo, 

2026 size_refs=size_refs, 

2027 ) 

2028 data_descr: Dict[str, Any] = dict(type=src.data_type) 

2029 if data_descr["type"] == "bool": 

2030 data_descr["values"] = [False, True] 

2031 

2032 return tgt( 

2033 axes=axes, 

2034 id=TensorId(str(src.name)), 

2035 test_tensor=FileDescr(source=test_tensor), 

2036 sample_tensor=( 

2037 None if sample_tensor is None else FileDescr(source=sample_tensor) 

2038 ), 

2039 data=data_descr, # pyright: ignore[reportArgumentType] 

2040 postprocessing=[_convert_proc(p, src.axes) for p in src.postprocessing], 

2041 ) 

2042 

2043 

2044_output_tensor_conv = _OutputTensorConv(_OutputTensorDescr_v0_4, OutputTensorDescr) 

2045 

2046 

2047TensorDescr = Union[InputTensorDescr, OutputTensorDescr] 

2048 

2049 

2050def validate_tensors( 

2051 tensors: Mapping[TensorId, Tuple[TensorDescr, NDArray[Any]]], 

2052 tensor_origin: Literal[ 

2053 "test_tensor" 

2054 ], # for more precise error messages, e.g. 'test_tensor' 

2055): 

2056 all_tensor_axes: Dict[TensorId, Dict[AxisId, Tuple[AnyAxis, int]]] = {} 

2057 

2058 def e_msg(d: TensorDescr): 

2059 return f"{'inputs' if isinstance(d, InputTensorDescr) else 'outputs'}[{d.id}]" 

2060 

2061 for descr, array in tensors.values(): 

2062 try: 

2063 axis_sizes = descr.get_axis_sizes_for_array(array) 

2064 except ValueError as e: 

2065 raise ValueError(f"{e_msg(descr)} {e}") 

2066 else: 

2067 all_tensor_axes[descr.id] = { 

2068 a.id: (a, axis_sizes[a.id]) for a in descr.axes 

2069 } 

2070 

2071 for descr, array in tensors.values(): 

2072 if descr.dtype in ("float32", "float64"): 

2073 invalid_test_tensor_dtype = array.dtype.name not in ( 

2074 "float32", 

2075 "float64", 

2076 "uint8", 

2077 "int8", 

2078 "uint16", 

2079 "int16", 

2080 "uint32", 

2081 "int32", 

2082 "uint64", 

2083 "int64", 

2084 ) 

2085 else: 

2086 invalid_test_tensor_dtype = array.dtype.name != descr.dtype 

2087 

2088 if invalid_test_tensor_dtype: 

2089 raise ValueError( 

2090 f"{e_msg(descr)}.{tensor_origin}.dtype '{array.dtype.name}' does not" 

2091 + f" match described dtype '{descr.dtype}'" 

2092 ) 

2093 

2094 if array.min() > -1e-4 and array.max() < 1e-4: 

2095 raise ValueError( 

2096 "Output values are too small for reliable testing." 

2097                + f" Values <=-1e-4 or >=1e-4 must be present in {tensor_origin}" 

2098 ) 

2099 

2100 for a in descr.axes: 

2101 actual_size = all_tensor_axes[descr.id][a.id][1] 

2102 if a.size is None: 

2103 continue 

2104 

2105 if isinstance(a.size, int): 

2106 if actual_size != a.size: 

2107 raise ValueError( 

2108 f"{e_msg(descr)}.{tensor_origin}: axis '{a.id}' " 

2109 + f"has incompatible size {actual_size}, expected {a.size}" 

2110 ) 

2111 elif isinstance(a.size, ParameterizedSize): 

2112 _ = a.size.validate_size(actual_size) 

2113 elif isinstance(a.size, DataDependentSize): 

2114 _ = a.size.validate_size(actual_size) 

2115 elif isinstance(a.size, SizeReference): 

2116 ref_tensor_axes = all_tensor_axes.get(a.size.tensor_id) 

2117 if ref_tensor_axes is None: 

2118 raise ValueError( 

2119 f"{e_msg(descr)}.axes[{a.id}].size.tensor_id: Unknown tensor" 

2120 + f" reference '{a.size.tensor_id}'" 

2121 ) 

2122 

2123 ref_axis, ref_size = ref_tensor_axes.get(a.size.axis_id, (None, None)) 

2124 if ref_axis is None or ref_size is None: 

2125 raise ValueError( 

2126 f"{e_msg(descr)}.axes[{a.id}].size.axis_id: Unknown tensor axis" 

2127 + f" reference '{a.size.tensor_id}.{a.size.axis_id}" 

2128 ) 

2129 

2130 if a.unit != ref_axis.unit: 

2131 raise ValueError( 

2132 f"{e_msg(descr)}.axes[{a.id}].size: `SizeReference` requires" 

2133 + " axis and reference axis to have the same `unit`, but" 

2134 + f" {a.unit}!={ref_axis.unit}" 

2135 ) 

2136 

2137 if actual_size != ( 

2138 expected_size := ( 

2139 ref_size * ref_axis.scale / a.scale + a.size.offset 

2140 ) 

2141 ): 

2142 raise ValueError( 

2143 f"{e_msg(descr)}.{tensor_origin}: axis '{a.id}' of size" 

2144 + f" {actual_size} invalid for referenced size {ref_size};" 

2145 + f" expected {expected_size}" 

2146 ) 

2147 else: 

2148 assert_never(a.size) 

2149 
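# Illustrative usage sketch for `validate_tensors` (the tensor ids, descriptions
# and file names below are assumed examples):
#
#     validate_tensors(
#         {
#             TensorId("raw"): (input_descr, np.load("test_input.npy")),
#             TensorId("mask"): (output_descr, np.load("test_output.npy")),
#         },
#         tensor_origin="test_tensor",
#     )
#
# A ValueError is raised if an array's dtype or axis sizes are incompatible
# with its description.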

2150 

2151FileDescr_dependencies = Annotated[ 

2152 FileDescr_, 

2153 WithSuffix((".yaml", ".yml"), case_sensitive=True), 

2154 Field(examples=[dict(source="environment.yaml")]), 

2155] 

2156 

2157 

2158class _ArchitectureCallableDescr(Node): 

2159 callable: Annotated[Identifier, Field(examples=["MyNetworkClass", "get_my_model"])] 

2160 """Identifier of the callable that returns a torch.nn.Module instance.""" 

2161 

2162 kwargs: Dict[str, YamlValue] = Field( 

2163 default_factory=cast(Callable[[], Dict[str, YamlValue]], dict) 

2164 ) 

2165    """keyword arguments for the `callable`""" 

2166 

2167 

2168class ArchitectureFromFileDescr(_ArchitectureCallableDescr, FileDescr): 

2169 source: Annotated[FileSource, AfterValidator(wo_special_file_name)] 

2170 """Architecture source file""" 

2171 

2172 @model_serializer(mode="wrap", when_used="unless-none") 

2173 def _serialize(self, nxt: SerializerFunctionWrapHandler, info: SerializationInfo): 

2174 return package_file_descr_serializer(self, nxt, info) 

2175 

2176 

2177class ArchitectureFromLibraryDescr(_ArchitectureCallableDescr): 

2178 import_from: str 

2179 """Where to import the callable from, i.e. `from <import_from> import <callable>`""" 

2180 
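# Illustrative sketch: an architecture importable from an installed library,
# assuming a hypothetical package `mylib.models` that provides `UNet`:
#
#     arch = ArchitectureFromLibraryDescr(
#         import_from="mylib.models",
#         callable=Identifier("UNet"),
#         kwargs={"in_channels": 1, "out_channels": 2},
#     )
#
# Consumers resolve this as `from mylib.models import UNet` and call the
# callable with the given keyword arguments.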

2181 

2182class _ArchFileConv( 

2183 Converter[ 

2184 _CallableFromFile_v0_4, 

2185 ArchitectureFromFileDescr, 

2186 Optional[Sha256], 

2187 Dict[str, Any], 

2188 ] 

2189): 

2190 def _convert( 

2191 self, 

2192 src: _CallableFromFile_v0_4, 

2193 tgt: "type[ArchitectureFromFileDescr | dict[str, Any]]", 

2194 sha256: Optional[Sha256], 

2195 kwargs: Dict[str, Any], 

2196 ) -> "ArchitectureFromFileDescr | dict[str, Any]": 

2197 if src.startswith("http") and src.count(":") == 2: 

2198 http, source, callable_ = src.split(":") 

2199 source = ":".join((http, source)) 

2200 elif not src.startswith("http") and src.count(":") == 1: 

2201 source, callable_ = src.split(":") 

2202 else: 

2203 source = str(src) 

2204 callable_ = str(src) 

2205 return tgt( 

2206 callable=Identifier(callable_), 

2207 source=cast(FileSource_, source), 

2208 sha256=sha256, 

2209 kwargs=kwargs, 

2210 ) 

2211 

2212 

2213_arch_file_conv = _ArchFileConv(_CallableFromFile_v0_4, ArchitectureFromFileDescr) 

2214 
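# Illustrative sketch of the v0.4 "source:callable" convention handled by
# `_ArchFileConv` above (assumed example values):
#
#     "models/unet.py:UNet"              -> source="models/unet.py", callable="UNet"
#     "https://example.com/unet.py:UNet" -> source="https://example.com/unet.py", callable="UNet"
#
# A URL contributes an extra ":" after the scheme, which is why the first two
# parts are re-joined before the callable name is extracted.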

2215 

2216class _ArchLibConv( 

2217 Converter[ 

2218 _CallableFromDepencency_v0_4, ArchitectureFromLibraryDescr, Dict[str, Any] 

2219 ] 

2220): 

2221 def _convert( 

2222 self, 

2223 src: _CallableFromDepencency_v0_4, 

2224 tgt: "type[ArchitectureFromLibraryDescr | dict[str, Any]]", 

2225 kwargs: Dict[str, Any], 

2226 ) -> "ArchitectureFromLibraryDescr | dict[str, Any]": 

2227 *mods, callable_ = src.split(".") 

2228 import_from = ".".join(mods) 

2229 return tgt( 

2230 import_from=import_from, callable=Identifier(callable_), kwargs=kwargs 

2231 ) 

2232 

2233 

2234_arch_lib_conv = _ArchLibConv( 

2235 _CallableFromDepencency_v0_4, ArchitectureFromLibraryDescr 

2236) 

2237 

2238 

2239class WeightsEntryDescrBase(FileDescr): 

2240 type: ClassVar[WeightsFormat] 

2241 weights_format_name: ClassVar[str] # human readable 

2242 

2243 source: Annotated[FileSource, AfterValidator(wo_special_file_name)] 

2244 """Source of the weights file.""" 

2245 

2246 authors: Optional[List[Author]] = None 

2247 """Authors 

2248 Either the person(s) that have trained this model, resulting in the original weights file 

2249 (if this is the initial weights entry, i.e. it does not have a `parent`), 

2250 or the person(s) who have converted the weights to this weights format 

2251 (if this is a child weight, i.e. it has a `parent` field). 

2252 """ 

2253 

2254 parent: Annotated[ 

2255 Optional[WeightsFormat], Field(examples=["pytorch_state_dict"]) 

2256 ] = None 

2257 """The source weights these weights were converted from. 

2258 For example, if a model's weights were converted from the `pytorch_state_dict` format to `torchscript`, 

2259 the `pytorch_state_dict` weights entry has no `parent` and is the parent of the `torchscript` weights. 

2260 All weight entries except one (the initial set of weights resulting from training the model) 

2261 need to have this field.""" 

2262 

2263 comment: str = "" 

2264 """A comment about this weights entry, for example how these weights were created.""" 

2265 

2266 @model_validator(mode="after") 

2267 def _validate(self) -> Self: 

2268 if self.type == self.parent: 

2269            raise ValueError("Weights entry can't be its own parent.") 

2270 

2271 return self 

2272 

2273 @model_serializer(mode="wrap", when_used="unless-none") 

2274 def _serialize(self, nxt: SerializerFunctionWrapHandler, info: SerializationInfo): 

2275 return package_file_descr_serializer(self, nxt, info) 

2276 

2277 

2278class KerasHdf5WeightsDescr(WeightsEntryDescrBase): 

2279 type = "keras_hdf5" 

2280 weights_format_name: ClassVar[str] = "Keras HDF5" 

2281 tensorflow_version: Version 

2282 """TensorFlow version used to create these weights.""" 

2283 

2284 

2285class OnnxWeightsDescr(WeightsEntryDescrBase): 

2286 type = "onnx" 

2287 weights_format_name: ClassVar[str] = "ONNX" 

2288 opset_version: Annotated[int, Ge(7)] 

2289 """ONNX opset version""" 

2290 

2291 

2292class PytorchStateDictWeightsDescr(WeightsEntryDescrBase): 

2293 type = "pytorch_state_dict" 

2294 weights_format_name: ClassVar[str] = "Pytorch State Dict" 

2295 architecture: Union[ArchitectureFromFileDescr, ArchitectureFromLibraryDescr] 

2296 pytorch_version: Version 

2297 """Version of the PyTorch library used. 

2298    If `architecture.dependencies` is specified, it has to include pytorch and any version pinning has to be compatible. 

2299 """ 

2300 dependencies: Optional[FileDescr_dependencies] = None 

2301 """Custom depencies beyond pytorch described in a Conda environment file. 

2302 Allows to specify custom dependencies, see conda docs: 

2303 - [Exporting an environment file across platforms](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#exporting-an-environment-file-across-platforms) 

2304 - [Creating an environment file manually](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#creating-an-environment-file-manually) 

2305 

2306 The conda environment file should include pytorch and any version pinning has to be compatible with 

2307 **pytorch_version**. 

2308 """ 

2309 

2310 

2311class TensorflowJsWeightsDescr(WeightsEntryDescrBase): 

2312 type = "tensorflow_js" 

2313 weights_format_name: ClassVar[str] = "Tensorflow.js" 

2314 tensorflow_version: Version 

2315 """Version of the TensorFlow library used.""" 

2316 

2317 source: Annotated[FileSource, AfterValidator(wo_special_file_name)] 

2318 """The multi-file weights. 

2319    All required files/folders should be packaged in a zip archive.""" 

2320 

2321 

2322class TensorflowSavedModelBundleWeightsDescr(WeightsEntryDescrBase): 

2323 type = "tensorflow_saved_model_bundle" 

2324 weights_format_name: ClassVar[str] = "Tensorflow Saved Model" 

2325 tensorflow_version: Version 

2326 """Version of the TensorFlow library used.""" 

2327 

2328 dependencies: Optional[FileDescr_dependencies] = None 

2329 """Custom dependencies beyond tensorflow. 

2330    Should include tensorflow; any version pinning has to be compatible with **tensorflow_version**.""" 

2331 

2332 source: Annotated[FileSource, AfterValidator(wo_special_file_name)] 

2333 """The multi-file weights. 

2334    All required files/folders should be packaged in a zip archive.""" 

2335 

2336 

2337class TorchscriptWeightsDescr(WeightsEntryDescrBase): 

2338 type = "torchscript" 

2339 weights_format_name: ClassVar[str] = "TorchScript" 

2340 pytorch_version: Version 

2341 """Version of the PyTorch library used.""" 

2342 

2343 

2344class WeightsDescr(Node): 

2345 keras_hdf5: Optional[KerasHdf5WeightsDescr] = None 

2346 onnx: Optional[OnnxWeightsDescr] = None 

2347 pytorch_state_dict: Optional[PytorchStateDictWeightsDescr] = None 

2348 tensorflow_js: Optional[TensorflowJsWeightsDescr] = None 

2349 tensorflow_saved_model_bundle: Optional[TensorflowSavedModelBundleWeightsDescr] = ( 

2350 None 

2351 ) 

2352 torchscript: Optional[TorchscriptWeightsDescr] = None 

2353 

2354 @model_validator(mode="after") 

2355 def check_entries(self) -> Self: 

2356 entries = {wtype for wtype, entry in self if entry is not None} 

2357 

2358 if not entries: 

2359 raise ValueError("Missing weights entry") 

2360 

2361 entries_wo_parent = { 

2362 wtype 

2363 for wtype, entry in self 

2364 if entry is not None and hasattr(entry, "parent") and entry.parent is None 

2365 } 

2366 if len(entries_wo_parent) != 1: 

2367 issue_warning( 

2368 "Exactly one weights entry may not specify the `parent` field (got" 

2369 + " {value}). That entry is considered the original set of model weights." 

2370 + " Other weight formats are created through conversion of the orignal or" 

2371 + " already converted weights. They have to reference the weights format" 

2372 + " they were converted from as their `parent`.", 

2373 value=len(entries_wo_parent), 

2374 field="weights", 

2375 ) 

2376 

2377 for wtype, entry in self: 

2378 if entry is None: 

2379 continue 

2380 

2381 assert hasattr(entry, "type") 

2382 assert hasattr(entry, "parent") 

2383 assert wtype == entry.type 

2384 if ( 

2385 entry.parent is not None and entry.parent not in entries 

2386 ): # self reference checked for `parent` field 

2387 raise ValueError( 

2388 f"`weights.{wtype}.parent={entry.parent} not in specified weight" 

2389 + f" formats: {entries}" 

2390 ) 

2391 

2392 return self 

2393 

2394 def __getitem__( 

2395 self, 

2396 key: Literal[ 

2397 "keras_hdf5", 

2398 "onnx", 

2399 "pytorch_state_dict", 

2400 "tensorflow_js", 

2401 "tensorflow_saved_model_bundle", 

2402 "torchscript", 

2403 ], 

2404 ): 

2405 if key == "keras_hdf5": 

2406 ret = self.keras_hdf5 

2407 elif key == "onnx": 

2408 ret = self.onnx 

2409 elif key == "pytorch_state_dict": 

2410 ret = self.pytorch_state_dict 

2411 elif key == "tensorflow_js": 

2412 ret = self.tensorflow_js 

2413 elif key == "tensorflow_saved_model_bundle": 

2414 ret = self.tensorflow_saved_model_bundle 

2415 elif key == "torchscript": 

2416 ret = self.torchscript 

2417 else: 

2418 raise KeyError(key) 

2419 

2420 if ret is None: 

2421 raise KeyError(key) 

2422 

2423 return ret 

2424 

2425 @property 

2426 def available_formats(self): 

2427 return { 

2428 **({} if self.keras_hdf5 is None else {"keras_hdf5": self.keras_hdf5}), 

2429 **({} if self.onnx is None else {"onnx": self.onnx}), 

2430 **( 

2431 {} 

2432 if self.pytorch_state_dict is None 

2433 else {"pytorch_state_dict": self.pytorch_state_dict} 

2434 ), 

2435 **( 

2436 {} 

2437 if self.tensorflow_js is None 

2438 else {"tensorflow_js": self.tensorflow_js} 

2439 ), 

2440 **( 

2441 {} 

2442 if self.tensorflow_saved_model_bundle is None 

2443 else { 

2444 "tensorflow_saved_model_bundle": self.tensorflow_saved_model_bundle 

2445 } 

2446 ), 

2447 **({} if self.torchscript is None else {"torchscript": self.torchscript}), 

2448 } 

2449 

2450 @property 

2451 def missing_formats(self): 

2452 return { 

2453 wf for wf in get_args(WeightsFormat) if wf not in self.available_formats 

2454 } 

2455 
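# Illustrative sketch (assumed example file names; validation may perform I/O
# checks depending on the active validation context): the original
# `pytorch_state_dict` entry has no `parent`, while the converted `torchscript`
# entry references it as its parent:
#
#     weights = WeightsDescr(
#         pytorch_state_dict=PytorchStateDictWeightsDescr(
#             source="weights.pt",
#             architecture=ArchitectureFromLibraryDescr(
#                 import_from="mylib.models", callable=Identifier("UNet")
#             ),
#             pytorch_version=Version("2.1"),
#         ),
#         torchscript=TorchscriptWeightsDescr(
#             source="weights_torchscript.pt",
#             pytorch_version=Version("2.1"),
#             parent="pytorch_state_dict",
#         ),
#     )
#
#     weights["torchscript"]             # entry lookup by weights format name
#     sorted(weights.available_formats)  # ["pytorch_state_dict", "torchscript"]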

2456 

2457class ModelId(ResourceId): 

2458 pass 

2459 

2460 

2461class LinkedModel(LinkedResourceBase): 

2462 """Reference to a bioimage.io model.""" 

2463 

2464 id: ModelId 

2465 """A valid model `id` from the bioimage.io collection.""" 

2466 

2467 

2468class _DataDepSize(NamedTuple): 

2469 min: int 

2470 max: Optional[int] 

2471 

2472 

2473class _AxisSizes(NamedTuple): 

2474 """the lenghts of all axes of model inputs and outputs""" 

2475 

2476 inputs: Dict[Tuple[TensorId, AxisId], int] 

2477 outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]] 

2478 

2479 

2480class _TensorSizes(NamedTuple): 

2481 """_AxisSizes as nested dicts""" 

2482 

2483 inputs: Dict[TensorId, Dict[AxisId, int]] 

2484 outputs: Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]] 

2485 

2486 

2487class ReproducibilityTolerance(Node, extra="allow"): 

2488 """Describes what small numerical differences -- if any -- may be tolerated 

2489 in the generated output when executing in different environments. 

2490 

2491 A tensor element *output* is considered mismatched to the **test_tensor** if 

2492 abs(*output* - **test_tensor**) > **absolute_tolerance** + **relative_tolerance** * abs(**test_tensor**). 

2493 (Internally we call [numpy.testing.assert_allclose](https://numpy.org/doc/stable/reference/generated/numpy.testing.assert_allclose.html).) 

2494 

2495 Motivation: 

2496 For testing we can request the respective deep learning frameworks to be as 

2497        reproducible as possible by setting seeds and choosing deterministic algorithms, 

2498 but differences in operating systems, available hardware and installed drivers 

2499 may still lead to numerical differences. 

2500 """ 

2501 

2502 relative_tolerance: RelativeTolerance = 1e-3 

2503 """Maximum relative tolerance of reproduced test tensor.""" 

2504 

2505 absolute_tolerance: AbsoluteTolerance = 1e-4 

2506 """Maximum absolute tolerance of reproduced test tensor.""" 

2507 

2508 mismatched_elements_per_million: MismatchedElementsPerMillion = 100 

2509 """Maximum number of mismatched elements/pixels per million to tolerate.""" 

2510 

2511 output_ids: Sequence[TensorId] = () 

2512 """Limits the output tensor IDs these reproducibility details apply to.""" 

2513 

2514 weights_formats: Sequence[WeightsFormat] = () 

2515 """Limits the weights formats these details apply to.""" 

2516 
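# Illustrative sketch of the mismatch criterion described above, in plain numpy
# (assumed example values):
#
#     output = np.array([1.0, 2.003, 3.0])       # reproduced output
#     test_tensor = np.array([1.0, 2.0, 3.0])    # described test tensor
#     atol, rtol = 1e-4, 1e-3
#
#     mismatched = np.abs(output - test_tensor) > atol + rtol * np.abs(test_tensor)
#     # -> [False, True, False]; |2.003 - 2.0| = 3e-3 > 1e-4 + 1e-3 * 2.0 = 2.1e-3
#     mismatched_per_million = mismatched.sum() / output.size * 1e6  # ~333333
#
# A ReproducibilityTolerance entry additionally caps this rate via
# `mismatched_elements_per_million`.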

2517 

2518class BioimageioConfig(Node, extra="allow"): 

2519 reproducibility_tolerance: Sequence[ReproducibilityTolerance] = () 

2520 """Tolerances to allow when reproducing the model's test outputs 

2521 from the model's test inputs. 

2522 Only the first entry matching tensor id and weights format is considered. 

2523 """ 

2524 

2525 

2526class Config(Node, extra="allow"): 

2527 bioimageio: BioimageioConfig = Field(default_factory=BioimageioConfig) 

2528 

2529 

2530class ModelDescr(GenericModelDescrBase): 

2531 """Specification of the fields used in a bioimage.io-compliant RDF to describe AI models with pretrained weights. 

2532 These fields are typically stored in a YAML file which we call a model resource description file (model RDF). 

2533 """ 

2534 

2535 implemented_format_version: ClassVar[Literal["0.5.4"]] = "0.5.4" 

2536 if TYPE_CHECKING: 

2537 format_version: Literal["0.5.4"] = "0.5.4" 

2538 else: 

2539 format_version: Literal["0.5.4"] 

2540 """Version of the bioimage.io model description specification used. 

2541 When creating a new model always use the latest micro/patch version described here. 

2542 The `format_version` is important for any consumer software to understand how to parse the fields. 

2543 """ 

2544 

2545 implemented_type: ClassVar[Literal["model"]] = "model" 

2546 if TYPE_CHECKING: 

2547 type: Literal["model"] = "model" 

2548 else: 

2549 type: Literal["model"] 

2550 """Specialized resource type 'model'""" 

2551 

2552 id: Optional[ModelId] = None 

2553 """bioimage.io-wide unique resource identifier 

2554 assigned by bioimage.io; version **un**specific.""" 

2555 

2556 authors: NotEmpty[List[Author]] 

2557 """The authors are the creators of the model RDF and the primary points of contact.""" 

2558 

2559 documentation: FileSource_documentation 

2560 """URL or relative path to a markdown file with additional documentation. 

2561 The recommended documentation file name is `README.md`. An `.md` suffix is mandatory. 

2562 The documentation should include a '#[#] Validation' (sub)section 

2563 with details on how to quantitatively validate the model on unseen data.""" 

2564 

2565 @field_validator("documentation", mode="after") 

2566 @classmethod 

2567 def _validate_documentation( 

2568 cls, value: FileSource_documentation 

2569 ) -> FileSource_documentation: 

2570 if not get_validation_context().perform_io_checks: 

2571 return value 

2572 

2573 doc_reader = get_reader(value) 

2574 doc_content = doc_reader.read().decode(encoding="utf-8") 

2575 if not re.search("#.*[vV]alidation", doc_content): 

2576 issue_warning( 

2577 "No '# Validation' (sub)section found in {value}.", 

2578 value=value, 

2579 field="documentation", 

2580 ) 

2581 

2582 return value 

2583 

2584 inputs: NotEmpty[Sequence[InputTensorDescr]] 

2585 """Describes the input tensors expected by this model.""" 

2586 

2587 @field_validator("inputs", mode="after") 

2588 @classmethod 

2589 def _validate_input_axes( 

2590 cls, inputs: Sequence[InputTensorDescr] 

2591 ) -> Sequence[InputTensorDescr]: 

2592 input_size_refs = cls._get_axes_with_independent_size(inputs) 

2593 

2594 for i, ipt in enumerate(inputs): 

2595 valid_independent_refs: Dict[ 

2596 Tuple[TensorId, AxisId], 

2597 Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]], 

2598 ] = { 

2599 **{ 

2600 (ipt.id, a.id): (ipt, a, a.size) 

2601 for a in ipt.axes 

2602 if not isinstance(a, BatchAxis) 

2603 and isinstance(a.size, (int, ParameterizedSize)) 

2604 }, 

2605 **input_size_refs, 

2606 } 

2607 for a, ax in enumerate(ipt.axes): 

2608 cls._validate_axis( 

2609 "inputs", 

2610 i=i, 

2611 tensor_id=ipt.id, 

2612 a=a, 

2613 axis=ax, 

2614 valid_independent_refs=valid_independent_refs, 

2615 ) 

2616 return inputs 

2617 

2618 @staticmethod 

2619 def _validate_axis( 

2620 field_name: str, 

2621 i: int, 

2622 tensor_id: TensorId, 

2623 a: int, 

2624 axis: AnyAxis, 

2625 valid_independent_refs: Dict[ 

2626 Tuple[TensorId, AxisId], 

2627 Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]], 

2628 ], 

2629 ): 

2630 if isinstance(axis, BatchAxis) or isinstance( 

2631 axis.size, (int, ParameterizedSize, DataDependentSize) 

2632 ): 

2633 return 

2634 elif not isinstance(axis.size, SizeReference): 

2635 assert_never(axis.size) 

2636 

2637 # validate axis.size SizeReference 

2638 ref = (axis.size.tensor_id, axis.size.axis_id) 

2639 if ref not in valid_independent_refs: 

2640 raise ValueError( 

2641 "Invalid tensor axis reference at" 

2642 + f" {field_name}[{i}].axes[{a}].size: {axis.size}." 

2643 ) 

2644 if ref == (tensor_id, axis.id): 

2645 raise ValueError( 

2646 "Self-referencing not allowed for" 

2647 + f" {field_name}[{i}].axes[{a}].size: {axis.size}" 

2648 ) 

2649 if axis.type == "channel": 

2650 if valid_independent_refs[ref][1].type != "channel": 

2651 raise ValueError( 

2652 "A channel axis' size may only reference another fixed size" 

2653 + " channel axis." 

2654 ) 

2655 if isinstance(axis.channel_names, str) and "{i}" in axis.channel_names: 

2656 ref_size = valid_independent_refs[ref][2] 

2657 assert isinstance(ref_size, int), ( 

2658 "channel axis ref (another channel axis) has to specify fixed" 

2659 + " size" 

2660 ) 

2661 generated_channel_names = [ 

2662 Identifier(axis.channel_names.format(i=i)) 

2663 for i in range(1, ref_size + 1) 

2664 ] 

2665 axis.channel_names = generated_channel_names 

2666 

2667 if (ax_unit := getattr(axis, "unit", None)) != ( 

2668 ref_unit := getattr(valid_independent_refs[ref][1], "unit", None) 

2669 ): 

2670 raise ValueError( 

2671 "The units of an axis and its reference axis need to match, but" 

2672 + f" '{ax_unit}' != '{ref_unit}'." 

2673 ) 

2674 ref_axis = valid_independent_refs[ref][1] 

2675 if isinstance(ref_axis, BatchAxis): 

2676 raise ValueError( 

2677 f"Invalid reference axis '{ref_axis.id}' for {tensor_id}.{axis.id}" 

2678 + " (a batch axis is not allowed as reference)." 

2679 ) 

2680 

2681 if isinstance(axis, WithHalo): 

2682 min_size = axis.size.get_size(axis, ref_axis, n=0) 

2683 if (min_size - 2 * axis.halo) < 1: 

2684 raise ValueError( 

2685 f"axis {axis.id} with minimum size {min_size} is too small for halo" 

2686 + f" {axis.halo}." 

2687 ) 

2688 

2689 input_halo = axis.halo * axis.scale / ref_axis.scale 

2690 if input_halo != int(input_halo) or input_halo % 2 == 1: 

2691 raise ValueError( 

2692 f"input_halo {input_halo} (output_halo {axis.halo} *" 

2693 + f" output_scale {axis.scale} / input_scale {ref_axis.scale})" 

2694 + f" {tensor_id}.{axis.id}." 

2695 ) 

2696 
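    # Worked example for the halo check above (assumed values): an output axis
    # with halo=8 and scale=1.0 referencing an input axis with scale=0.5 gives
    # input_halo = 8 * 1.0 / 0.5 = 16, an even integer, and is accepted;
    # halo=3 with equal scales gives input_halo = 3 (odd) and is rejected.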

2697 @model_validator(mode="after") 

2698 def _validate_test_tensors(self) -> Self: 

2699 if not get_validation_context().perform_io_checks: 

2700 return self 

2701 

2702 test_output_arrays = [load_array(descr.test_tensor) for descr in self.outputs] 

2703 test_input_arrays = [load_array(descr.test_tensor) for descr in self.inputs] 

2704 

2705 tensors = { 

2706 descr.id: (descr, array) 

2707 for descr, array in zip( 

2708 chain(self.inputs, self.outputs), test_input_arrays + test_output_arrays 

2709 ) 

2710 } 

2711 validate_tensors(tensors, tensor_origin="test_tensor") 

2712 

2713 output_arrays = { 

2714 descr.id: array for descr, array in zip(self.outputs, test_output_arrays) 

2715 } 

2716 for rep_tol in self.config.bioimageio.reproducibility_tolerance: 

2717 if not rep_tol.absolute_tolerance: 

2718 continue 

2719 

2720 if rep_tol.output_ids: 

2721 out_arrays = { 

2722 oid: a 

2723 for oid, a in output_arrays.items() 

2724 if oid in rep_tol.output_ids 

2725 } 

2726 else: 

2727 out_arrays = output_arrays 

2728 

2729 for out_id, array in out_arrays.items(): 

2730 if rep_tol.absolute_tolerance > (max_test_value := array.max()) * 0.01: 

2731 raise ValueError( 

2732 "config.bioimageio.reproducibility_tolerance.absolute_tolerance=" 

2733 + f"{rep_tol.absolute_tolerance} > 0.01*{max_test_value}" 

2734 + f" (1% of the maximum value of the test tensor '{out_id}')" 

2735 ) 

2736 

2737 return self 

2738 

2739 @model_validator(mode="after") 

2740 def _validate_tensor_references_in_proc_kwargs(self, info: ValidationInfo) -> Self: 

2741 ipt_refs = {t.id for t in self.inputs} 

2742 out_refs = {t.id for t in self.outputs} 

2743 for ipt in self.inputs: 

2744 for p in ipt.preprocessing: 

2745 ref = p.kwargs.get("reference_tensor") 

2746 if ref is None: 

2747 continue 

2748 if ref not in ipt_refs: 

2749 raise ValueError( 

2750 f"`reference_tensor` '{ref}' not found. Valid input tensor" 

2751 + f" references are: {ipt_refs}." 

2752 ) 

2753 

2754 for out in self.outputs: 

2755 for p in out.postprocessing: 

2756 ref = p.kwargs.get("reference_tensor") 

2757 if ref is None: 

2758 continue 

2759 

2760 if ref not in ipt_refs and ref not in out_refs: 

2761 raise ValueError( 

2762 f"`reference_tensor` '{ref}' not found. Valid tensor references" 

2763 + f" are: {ipt_refs | out_refs}." 

2764 ) 

2765 

2766 return self 

2767 

2768 # TODO: use validate funcs in validate_test_tensors 

2769 # def validate_inputs(self, input_tensors: Mapping[TensorId, NDArray[Any]]) -> Mapping[TensorId, NDArray[Any]]: 

2770 

2771 name: Annotated[ 

2772 Annotated[ 

2773 str, RestrictCharacters(string.ascii_letters + string.digits + "_+- ()") 

2774 ], 

2775 MinLen(5), 

2776 MaxLen(128), 

2777 warn(MaxLen(64), "Name longer than 64 characters.", INFO), 

2778 ] 

2779 """A human-readable name of this model. 

2780 It should be no longer than 64 characters 

2781    and may only contain letters, numbers, underscores, minus signs, parentheses and spaces. 

2782    We recommend choosing a name that refers to the model's task and image modality. 

2783 """ 

2784 

2785 outputs: NotEmpty[Sequence[OutputTensorDescr]] 

2786 """Describes the output tensors.""" 

2787 

2788 @field_validator("outputs", mode="after") 

2789 @classmethod 

2790 def _validate_tensor_ids( 

2791 cls, outputs: Sequence[OutputTensorDescr], info: ValidationInfo 

2792 ) -> Sequence[OutputTensorDescr]: 

2793 tensor_ids = [ 

2794 t.id for t in info.data.get("inputs", []) + info.data.get("outputs", []) 

2795 ] 

2796 duplicate_tensor_ids: List[str] = [] 

2797 seen: Set[str] = set() 

2798 for t in tensor_ids: 

2799 if t in seen: 

2800 duplicate_tensor_ids.append(t) 

2801 

2802 seen.add(t) 

2803 

2804 if duplicate_tensor_ids: 

2805 raise ValueError(f"Duplicate tensor ids: {duplicate_tensor_ids}") 

2806 

2807 return outputs 

2808 

2809 @staticmethod 

2810 def _get_axes_with_parameterized_size( 

2811 io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]], 

2812 ): 

2813 return { 

2814 f"{t.id}.{a.id}": (t, a, a.size) 

2815 for t in io 

2816 for a in t.axes 

2817 if not isinstance(a, BatchAxis) and isinstance(a.size, ParameterizedSize) 

2818 } 

2819 

2820 @staticmethod 

2821 def _get_axes_with_independent_size( 

2822 io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]], 

2823 ): 

2824 return { 

2825 (t.id, a.id): (t, a, a.size) 

2826 for t in io 

2827 for a in t.axes 

2828 if not isinstance(a, BatchAxis) 

2829 and isinstance(a.size, (int, ParameterizedSize)) 

2830 } 

2831 

2832 @field_validator("outputs", mode="after") 

2833 @classmethod 

2834 def _validate_output_axes( 

2835 cls, outputs: List[OutputTensorDescr], info: ValidationInfo 

2836 ) -> List[OutputTensorDescr]: 

2837 input_size_refs = cls._get_axes_with_independent_size( 

2838 info.data.get("inputs", []) 

2839 ) 

2840 output_size_refs = cls._get_axes_with_independent_size(outputs) 

2841 

2842 for i, out in enumerate(outputs): 

2843 valid_independent_refs: Dict[ 

2844 Tuple[TensorId, AxisId], 

2845 Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]], 

2846 ] = { 

2847 **{ 

2848 (out.id, a.id): (out, a, a.size) 

2849 for a in out.axes 

2850 if not isinstance(a, BatchAxis) 

2851 and isinstance(a.size, (int, ParameterizedSize)) 

2852 }, 

2853 **input_size_refs, 

2854 **output_size_refs, 

2855 } 

2856 for a, ax in enumerate(out.axes): 

2857 cls._validate_axis( 

2858 "outputs", 

2859 i, 

2860 out.id, 

2861 a, 

2862 ax, 

2863 valid_independent_refs=valid_independent_refs, 

2864 ) 

2865 

2866 return outputs 

2867 

2868 packaged_by: List[Author] = Field( 

2869 default_factory=cast(Callable[[], List[Author]], list) 

2870 ) 

2871 """The persons that have packaged and uploaded this model. 

2872 Only required if those persons differ from the `authors`.""" 

2873 

2874 parent: Optional[LinkedModel] = None 

2875 """The model from which this model is derived, e.g. by fine-tuning the weights.""" 

2876 

2877 @model_validator(mode="after") 

2878 def _validate_parent_is_not_self(self) -> Self: 

2879 if self.parent is not None and self.parent.id == self.id: 

2880 raise ValueError("A model description may not reference itself as parent.") 

2881 

2882 return self 

2883 

2884 run_mode: Annotated[ 

2885 Optional[RunMode], 

2886        warn(None, "Run mode '{value}' has limited support across consumer software."), 

2887 ] = None 

2888 """Custom run mode for this model: for more complex prediction procedures like test time 

2889 data augmentation that currently cannot be expressed in the specification. 

2890 No standard run modes are defined yet.""" 

2891 

2892 timestamp: Datetime = Field(default_factory=Datetime.now) 

2893 """Timestamp in [ISO 8601](#https://en.wikipedia.org/wiki/ISO_8601) format 

2894 with a few restrictions listed [here](https://docs.python.org/3/library/datetime.html#datetime.datetime.fromisoformat). 

2895 (In Python a datetime object is valid, too).""" 

2896 

2897 training_data: Annotated[ 

2898 Union[None, LinkedDataset, DatasetDescr, DatasetDescr02], 

2899 Field(union_mode="left_to_right"), 

2900 ] = None 

2901 """The dataset used to train this model""" 

2902 

2903 weights: Annotated[WeightsDescr, WrapSerializer(package_weights)] 

2904 """The weights for this model. 

2905 Weights can be given for different formats, but should otherwise be equivalent. 

2906 The available weight formats determine which consumers can use this model.""" 

2907 

2908 config: Config = Field(default_factory=Config) 

2909 

2910 @model_validator(mode="after") 

2911 def _add_default_cover(self) -> Self: 

2912 if not get_validation_context().perform_io_checks or self.covers: 

2913 return self 

2914 

2915 try: 

2916 generated_covers = generate_covers( 

2917 [(t, load_array(t.test_tensor)) for t in self.inputs], 

2918 [(t, load_array(t.test_tensor)) for t in self.outputs], 

2919 ) 

2920 except Exception as e: 

2921 issue_warning( 

2922 "Failed to generate cover image(s): {e}", 

2923 value=self.covers, 

2924 msg_context=dict(e=e), 

2925 field="covers", 

2926 ) 

2927 else: 

2928 self.covers.extend(generated_covers) 

2929 

2930 return self 

2931 

2932 def get_input_test_arrays(self) -> List[NDArray[Any]]: 

2933 data = [load_array(ipt.test_tensor) for ipt in self.inputs] 

2934 assert all(isinstance(d, np.ndarray) for d in data) 

2935 return data 

2936 

2937 def get_output_test_arrays(self) -> List[NDArray[Any]]: 

2938 data = [load_array(out.test_tensor) for out in self.outputs] 

2939 assert all(isinstance(d, np.ndarray) for d in data) 

2940 return data 

2941 

2942 @staticmethod 

2943 def get_batch_size(tensor_sizes: Mapping[TensorId, Mapping[AxisId, int]]) -> int: 

2944 batch_size = 1 

2945 tensor_with_batchsize: Optional[TensorId] = None 

2946 for tid in tensor_sizes: 

2947 for aid, s in tensor_sizes[tid].items(): 

2948 if aid != BATCH_AXIS_ID or s == 1 or s == batch_size: 

2949 continue 

2950 

2951 if batch_size != 1: 

2952 assert tensor_with_batchsize is not None 

2953 raise ValueError( 

2954 f"batch size mismatch for tensors '{tensor_with_batchsize}' ({batch_size}) and '{tid}' ({s})" 

2955 ) 

2956 

2957 batch_size = s 

2958 tensor_with_batchsize = tid 

2959 

2960 return batch_size 

2961 
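    # Illustrative sketch (assumed example tensor/axis ids): all batch axes have
    # to agree (or be 1) for a single batch size to be returned:
    #
    #     ModelDescr.get_batch_size(
    #         {
    #             TensorId("raw"): {AxisId("batch"): 4, AxisId("x"): 256},
    #             TensorId("mask"): {AxisId("batch"): 4, AxisId("x"): 256},
    #         }
    #     )  # -> 4
    #
    # Two differing batch sizes (other than 1) raise a ValueError.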

2962 def get_output_tensor_sizes( 

2963 self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]] 

2964 ) -> Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]: 

2965 """Returns the tensor output sizes for given **input_sizes**. 

2966        The returned output sizes are exact only if **input_sizes** is a valid input shape. 

2967        Otherwise they might be larger than the actual (valid) output sizes.""" 

2968 batch_size = self.get_batch_size(input_sizes) 

2969 ns = self.get_ns(input_sizes) 

2970 

2971 tensor_sizes = self.get_tensor_sizes(ns, batch_size=batch_size) 

2972 return tensor_sizes.outputs 

2973 

2974 def get_ns(self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]): 

2975 """get parameter `n` for each parameterized axis 

2976 such that the valid input size is >= the given input size""" 

2977 ret: Dict[Tuple[TensorId, AxisId], ParameterizedSize_N] = {} 

2978 axes = {t.id: {a.id: a for a in t.axes} for t in self.inputs} 

2979 for tid in input_sizes: 

2980 for aid, s in input_sizes[tid].items(): 

2981 size_descr = axes[tid][aid].size 

2982 if isinstance(size_descr, ParameterizedSize): 

2983 ret[(tid, aid)] = size_descr.get_n(s) 

2984 elif size_descr is None or isinstance(size_descr, (int, SizeReference)): 

2985 pass 

2986 else: 

2987 assert_never(size_descr) 

2988 

2989 return ret 

2990 
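    # Illustrative sketch (assumed example ids and a parameterized x axis with
    # min=64, step=16): `get_n` picks the smallest n whose parameterized size
    # min + n * step covers the requested size, so a requested x size of 100
    # maps to n=3 (64 + 3 * 16 = 112 >= 100):
    #
    #     model.get_ns({TensorId("raw"): {AxisId("x"): 100, AxisId("batch"): 1}})
    #     # -> {(TensorId("raw"), AxisId("x")): 3}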

2991 def get_tensor_sizes( 

2992 self, ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], batch_size: int 

2993 ) -> _TensorSizes: 

2994 axis_sizes = self.get_axis_sizes(ns, batch_size=batch_size) 

2995 return _TensorSizes( 

2996 { 

2997 t: { 

2998 aa: axis_sizes.inputs[(tt, aa)] 

2999 for tt, aa in axis_sizes.inputs 

3000 if tt == t 

3001 } 

3002 for t in {tt for tt, _ in axis_sizes.inputs} 

3003 }, 

3004 { 

3005 t: { 

3006 aa: axis_sizes.outputs[(tt, aa)] 

3007 for tt, aa in axis_sizes.outputs 

3008 if tt == t 

3009 } 

3010 for t in {tt for tt, _ in axis_sizes.outputs} 

3011 }, 

3012 ) 

3013 

3014 def get_axis_sizes( 

3015 self, 

3016 ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], 

3017 batch_size: Optional[int] = None, 

3018 *, 

3019 max_input_shape: Optional[Mapping[Tuple[TensorId, AxisId], int]] = None, 

3020 ) -> _AxisSizes: 

3021 """Determine input and output block shape for scale factors **ns** 

3022 of parameterized input sizes. 

3023 

3024 Args: 

3025 ns: Scale factor `n` for each axis (keyed by (tensor_id, axis_id)) 

3026 that is parameterized as `size = min + n * step`. 

3027 batch_size: The desired size of the batch dimension. 

3028 If given **batch_size** overwrites any batch size present in 

3029 **max_input_shape**. Default 1. 

3030 max_input_shape: Limits the derived block shapes. 

3031 Each axis for which the input size, parameterized by `n`, is larger 

3032 than **max_input_shape** is set to the minimal value `n_min` for which 

3033 this is still true. 

3034                Use this for small input samples, large values of **ns**, 

3035                or simply whenever you know the full input shape. 

3036 

3037 Returns: 

3038 Resolved axis sizes for model inputs and outputs. 

3039 """ 

3040 max_input_shape = max_input_shape or {} 

3041 if batch_size is None: 

3042 for (_t_id, a_id), s in max_input_shape.items(): 

3043 if a_id == BATCH_AXIS_ID: 

3044 batch_size = s 

3045 break 

3046 else: 

3047 batch_size = 1 

3048 

3049 all_axes = { 

3050 t.id: {a.id: a for a in t.axes} for t in chain(self.inputs, self.outputs) 

3051 } 

3052 

3053 inputs: Dict[Tuple[TensorId, AxisId], int] = {} 

3054 outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]] = {} 

3055 

3056 def get_axis_size(a: Union[InputAxis, OutputAxis]): 

3057 if isinstance(a, BatchAxis): 

3058 if (t_descr.id, a.id) in ns: 

3059 logger.warning( 

3060 "Ignoring unexpected size increment factor (n) for batch axis" 

3061 + " of tensor '{}'.", 

3062 t_descr.id, 

3063 ) 

3064 return batch_size 

3065 elif isinstance(a.size, int): 

3066 if (t_descr.id, a.id) in ns: 

3067 logger.warning( 

3068 "Ignoring unexpected size increment factor (n) for fixed size" 

3069 + " axis '{}' of tensor '{}'.", 

3070 a.id, 

3071 t_descr.id, 

3072 ) 

3073 return a.size 

3074 elif isinstance(a.size, ParameterizedSize): 

3075 if (t_descr.id, a.id) not in ns: 

3076 raise ValueError( 

3077 "Size increment factor (n) missing for parametrized axis" 

3078 + f" '{a.id}' of tensor '{t_descr.id}'." 

3079 ) 

3080 n = ns[(t_descr.id, a.id)] 

3081 s_max = max_input_shape.get((t_descr.id, a.id)) 

3082 if s_max is not None: 

3083 n = min(n, a.size.get_n(s_max)) 

3084 

3085 return a.size.get_size(n) 

3086 

3087 elif isinstance(a.size, SizeReference): 

3088 if (t_descr.id, a.id) in ns: 

3089 logger.warning( 

3090 "Ignoring unexpected size increment factor (n) for axis '{}'" 

3091 + " of tensor '{}' with size reference.", 

3092 a.id, 

3093 t_descr.id, 

3094 ) 

3095 assert not isinstance(a, BatchAxis) 

3096 ref_axis = all_axes[a.size.tensor_id][a.size.axis_id] 

3097 assert not isinstance(ref_axis, BatchAxis) 

3098 ref_key = (a.size.tensor_id, a.size.axis_id) 

3099 ref_size = inputs.get(ref_key, outputs.get(ref_key)) 

3100 assert ref_size is not None, ref_key 

3101 assert not isinstance(ref_size, _DataDepSize), ref_key 

3102 return a.size.get_size( 

3103 axis=a, 

3104 ref_axis=ref_axis, 

3105 ref_size=ref_size, 

3106 ) 

3107 elif isinstance(a.size, DataDependentSize): 

3108 if (t_descr.id, a.id) in ns: 

3109 logger.warning( 

3110 "Ignoring unexpected increment factor (n) for data dependent" 

3111 + " size axis '{}' of tensor '{}'.", 

3112 a.id, 

3113 t_descr.id, 

3114 ) 

3115 return _DataDepSize(a.size.min, a.size.max) 

3116 else: 

3117 assert_never(a.size) 

3118 

3119        # first resolve all but the `SizeReference` input sizes 

3120 for t_descr in self.inputs: 

3121 for a in t_descr.axes: 

3122 if not isinstance(a.size, SizeReference): 

3123 s = get_axis_size(a) 

3124 assert not isinstance(s, _DataDepSize) 

3125 inputs[t_descr.id, a.id] = s 

3126 

3127 # resolve all other input axis sizes 

3128 for t_descr in self.inputs: 

3129 for a in t_descr.axes: 

3130 if isinstance(a.size, SizeReference): 

3131 s = get_axis_size(a) 

3132 assert not isinstance(s, _DataDepSize) 

3133 inputs[t_descr.id, a.id] = s 

3134 

3135 # resolve all output axis sizes 

3136 for t_descr in self.outputs: 

3137 for a in t_descr.axes: 

3138 assert not isinstance(a.size, ParameterizedSize) 

3139 s = get_axis_size(a) 

3140 outputs[t_descr.id, a.id] = s 

3141 

3142 return _AxisSizes(inputs=inputs, outputs=outputs) 

3143 

3144 @model_validator(mode="before") 

3145 @classmethod 

3146 def _convert(cls, data: Dict[str, Any]) -> Dict[str, Any]: 

3147 cls.convert_from_old_format_wo_validation(data) 

3148 return data 

3149 

3150 @classmethod 

3151 def convert_from_old_format_wo_validation(cls, data: Dict[str, Any]) -> None: 

3152 """Convert metadata following an older format version to this classes' format 

3153 without validating the result. 

3154 """ 

3155 if ( 

3156 data.get("type") == "model" 

3157 and isinstance(fv := data.get("format_version"), str) 

3158 and fv.count(".") == 2 

3159 ): 

3160 fv_parts = fv.split(".") 

3161 if any(not p.isdigit() for p in fv_parts): 

3162 return 

3163 

3164 fv_tuple = tuple(map(int, fv_parts)) 

3165 

3166 assert cls.implemented_format_version_tuple[0:2] == (0, 5) 

3167 if fv_tuple[:2] in ((0, 3), (0, 4)): 

3168 m04 = _ModelDescr_v0_4.load(data) 

3169 if isinstance(m04, InvalidDescr): 

3170 try: 

3171 updated = _model_conv.convert_as_dict( 

3172 m04 # pyright: ignore[reportArgumentType] 

3173 ) 

3174 except Exception as e: 

3175 logger.error( 

3176 "Failed to convert from invalid model 0.4 description." 

3177 + f"\nerror: {e}" 

3178 + "\nProceeding with model 0.5 validation without conversion." 

3179 ) 

3180 updated = None 

3181 else: 

3182 updated = _model_conv.convert_as_dict(m04) 

3183 

3184 if updated is not None: 

3185 data.clear() 

3186 data.update(updated) 

3187 

3188 elif fv_tuple[:2] == (0, 5): 

3189 # bump patch version 

3190 data["format_version"] = cls.implemented_format_version 

3191 
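    # Illustrative usage sketch: given a dict `data` loaded from a v0.4
    # rdf.yaml (e.g. with format_version "0.4.10"), the conversion happens in
    # place before validation:
    #
    #     ModelDescr.convert_from_old_format_wo_validation(data)
    #     data["format_version"]  # -> "0.5.4" if the conversion succeeded
    #
    # Descriptions already in the 0.5 line only get their patch version bumped.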

3192 

3193class _ModelConv(Converter[_ModelDescr_v0_4, ModelDescr]): 

3194 def _convert( 

3195 self, src: _ModelDescr_v0_4, tgt: "type[ModelDescr] | type[dict[str, Any]]" 

3196 ) -> "ModelDescr | dict[str, Any]": 

3197 name = "".join( 

3198 c if c in string.ascii_letters + string.digits + "_+- ()" else " " 

3199 for c in src.name 

3200 ) 

3201 

3202 def conv_authors(auths: Optional[Sequence[_Author_v0_4]]): 

3203 conv = ( 

3204 _author_conv.convert if TYPE_CHECKING else _author_conv.convert_as_dict 

3205 ) 

3206 return None if auths is None else [conv(a) for a in auths] 

3207 

3208 if TYPE_CHECKING: 

3209 arch_file_conv = _arch_file_conv.convert 

3210 arch_lib_conv = _arch_lib_conv.convert 

3211 else: 

3212 arch_file_conv = _arch_file_conv.convert_as_dict 

3213 arch_lib_conv = _arch_lib_conv.convert_as_dict 

3214 

3215 input_size_refs = { 

3216 ipt.name: { 

3217 a: s 

3218 for a, s in zip( 

3219 ipt.axes, 

3220 ( 

3221 ipt.shape.min 

3222 if isinstance(ipt.shape, _ParameterizedInputShape_v0_4) 

3223 else ipt.shape 

3224 ), 

3225 ) 

3226 } 

3227 for ipt in src.inputs 

3228 if ipt.shape 

3229 } 

3230 output_size_refs = { 

3231 **{ 

3232 out.name: {a: s for a, s in zip(out.axes, out.shape)} 

3233 for out in src.outputs 

3234 if not isinstance(out.shape, _ImplicitOutputShape_v0_4) 

3235 }, 

3236 **input_size_refs, 

3237 } 

3238 

3239 return tgt( 

3240 attachments=( 

3241 [] 

3242 if src.attachments is None 

3243 else [FileDescr(source=f) for f in src.attachments.files] 

3244 ), 

3245 authors=[ 

3246 _author_conv.convert_as_dict(a) for a in src.authors 

3247 ], # pyright: ignore[reportArgumentType] 

3248 cite=[ 

3249 {"text": c.text, "doi": c.doi, "url": c.url} for c in src.cite 

3250 ], # pyright: ignore[reportArgumentType] 

3251 config=src.config, # pyright: ignore[reportArgumentType] 

3252 covers=src.covers, 

3253 description=src.description, 

3254 documentation=src.documentation, 

3255 format_version="0.5.4", 

3256 git_repo=src.git_repo, # pyright: ignore[reportArgumentType] 

3257 icon=src.icon, 

3258 id=None if src.id is None else ModelId(src.id), 

3259 id_emoji=src.id_emoji, 

3260 license=src.license, # type: ignore 

3261 links=src.links, 

3262 maintainers=[ 

3263 _maintainer_conv.convert_as_dict(m) for m in src.maintainers 

3264 ], # pyright: ignore[reportArgumentType] 

3265 name=name, 

3266 tags=src.tags, 

3267 type=src.type, 

3268 uploader=src.uploader, 

3269 version=src.version, 

3270 inputs=[ # pyright: ignore[reportArgumentType] 

3271 _input_tensor_conv.convert_as_dict(ipt, tt, st, input_size_refs) 

3272 for ipt, tt, st, in zip( 

3273 src.inputs, 

3274 src.test_inputs, 

3275 src.sample_inputs or [None] * len(src.test_inputs), 

3276 ) 

3277 ], 

3278 outputs=[ # pyright: ignore[reportArgumentType] 

3279 _output_tensor_conv.convert_as_dict(out, tt, st, output_size_refs) 

3280 for out, tt, st, in zip( 

3281 src.outputs, 

3282 src.test_outputs, 

3283 src.sample_outputs or [None] * len(src.test_outputs), 

3284 ) 

3285 ], 

3286 parent=( 

3287 None 

3288 if src.parent is None 

3289 else LinkedModel( 

3290 id=ModelId( 

3291 str(src.parent.id) 

3292 + ( 

3293 "" 

3294 if src.parent.version_number is None 

3295 else f"/{src.parent.version_number}" 

3296 ) 

3297 ) 

3298 ) 

3299 ), 

3300 training_data=( 

3301 None 

3302 if src.training_data is None 

3303 else ( 

3304 LinkedDataset( 

3305 id=DatasetId( 

3306 str(src.training_data.id) 

3307 + ( 

3308 "" 

3309 if src.training_data.version_number is None 

3310 else f"/{src.training_data.version_number}" 

3311 ) 

3312 ) 

3313 ) 

3314 if isinstance(src.training_data, LinkedDataset02) 

3315 else src.training_data 

3316 ) 

3317 ), 

3318 packaged_by=[ 

3319 _author_conv.convert_as_dict(a) for a in src.packaged_by 

3320 ], # pyright: ignore[reportArgumentType] 

3321 run_mode=src.run_mode, 

3322 timestamp=src.timestamp, 

3323 weights=(WeightsDescr if TYPE_CHECKING else dict)( 

3324 keras_hdf5=(w := src.weights.keras_hdf5) 

3325 and (KerasHdf5WeightsDescr if TYPE_CHECKING else dict)( 

3326 authors=conv_authors(w.authors), 

3327 source=w.source, 

3328 tensorflow_version=w.tensorflow_version or Version("1.15"), 

3329 parent=w.parent, 

3330 ), 

3331 onnx=(w := src.weights.onnx) 

3332 and (OnnxWeightsDescr if TYPE_CHECKING else dict)( 

3333 source=w.source, 

3334 authors=conv_authors(w.authors), 

3335 parent=w.parent, 

3336 opset_version=w.opset_version or 15, 

3337 ), 

3338 pytorch_state_dict=(w := src.weights.pytorch_state_dict) 

3339 and (PytorchStateDictWeightsDescr if TYPE_CHECKING else dict)( 

3340 source=w.source, 

3341 authors=conv_authors(w.authors), 

3342 parent=w.parent, 

3343 architecture=( 

3344 arch_file_conv( 

3345 w.architecture, 

3346 w.architecture_sha256, 

3347 w.kwargs, 

3348 ) 

3349 if isinstance(w.architecture, _CallableFromFile_v0_4) 

3350 else arch_lib_conv(w.architecture, w.kwargs) 

3351 ), 

3352 pytorch_version=w.pytorch_version or Version("1.10"), 

3353 dependencies=( 

3354 None 

3355 if w.dependencies is None 

3356 else (FileDescr if TYPE_CHECKING else dict)( 

3357 source=cast( 

3358 FileSource, 

3359 str(deps := w.dependencies)[ 

3360 ( 

3361 len("conda:") 

3362 if str(deps).startswith("conda:") 

3363 else 0 

3364 ) : 

3365 ], 

3366 ) 

3367 ) 

3368 ), 

3369 ), 

3370 tensorflow_js=(w := src.weights.tensorflow_js) 

3371 and (TensorflowJsWeightsDescr if TYPE_CHECKING else dict)( 

3372 source=w.source, 

3373 authors=conv_authors(w.authors), 

3374 parent=w.parent, 

3375 tensorflow_version=w.tensorflow_version or Version("1.15"), 

3376 ), 

3377 tensorflow_saved_model_bundle=( 

3378 w := src.weights.tensorflow_saved_model_bundle 

3379 ) 

3380 and (TensorflowSavedModelBundleWeightsDescr if TYPE_CHECKING else dict)( 

3381 authors=conv_authors(w.authors), 

3382 parent=w.parent, 

3383 source=w.source, 

3384 tensorflow_version=w.tensorflow_version or Version("1.15"), 

3385 dependencies=( 

3386 None 

3387 if w.dependencies is None 

3388 else (FileDescr if TYPE_CHECKING else dict)( 

3389 source=cast( 

3390 FileSource, 

3391 ( 

3392 str(w.dependencies)[len("conda:") :] 

3393 if str(w.dependencies).startswith("conda:") 

3394 else str(w.dependencies) 

3395 ), 

3396 ) 

3397 ) 

3398 ), 

3399 ), 

3400 torchscript=(w := src.weights.torchscript) 

3401 and (TorchscriptWeightsDescr if TYPE_CHECKING else dict)( 

3402 source=w.source, 

3403 authors=conv_authors(w.authors), 

3404 parent=w.parent, 

3405 pytorch_version=w.pytorch_version or Version("1.10"), 

3406 ), 

3407 ), 

3408 ) 

3409 

3410 

3411_model_conv = _ModelConv(_ModelDescr_v0_4, ModelDescr) 

3412 

3413 

3414# create better cover images for 3d data and non-image outputs 

3415def generate_covers( 

3416 inputs: Sequence[Tuple[InputTensorDescr, NDArray[Any]]], 

3417 outputs: Sequence[Tuple[OutputTensorDescr, NDArray[Any]]], 

3418) -> List[Path]: 

3419 def squeeze( 

3420 data: NDArray[Any], axes: Sequence[AnyAxis] 

3421 ) -> Tuple[NDArray[Any], List[AnyAxis]]: 

3422 """apply numpy.ndarray.squeeze while keeping track of the axis descriptions remaining""" 

3423 if data.ndim != len(axes): 

3424 raise ValueError( 

3425 f"tensor shape {data.shape} does not match described axes" 

3426 + f" {[a.id for a in axes]}" 

3427 ) 

3428 

3429 axes = [deepcopy(a) for a, s in zip(axes, data.shape) if s != 1] 

3430 return data.squeeze(), axes 

3431 

3432 def normalize( 

3433 data: NDArray[Any], axis: Optional[Tuple[int, ...]], eps: float = 1e-7 

3434 ) -> NDArray[np.float32]: 

3435 data = data.astype("float32") 

3436 data -= data.min(axis=axis, keepdims=True) 

3437 data /= data.max(axis=axis, keepdims=True) + eps 

3438 return data 
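# A quick numeric example of the min-max normalization above (the epsilon is
# assumed to only guard against division by zero for constant-valued data):
import numpy as np

x = np.array([0.0, 5.0, 10.0], dtype="float32")
x -= x.min()
x /= x.max() + 1e-7
# values end up in [0, 1): approximately [0.0, 0.5, 1.0]
assert x[0] == 0.0 and abs(x[1] - 0.5) < 1e-6 and x[2] < 1.0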

3439 

3440 def to_2d_image(data: NDArray[Any], axes: Sequence[AnyAxis]): 

3441 original_shape = data.shape 

3442 data, axes = squeeze(data, axes) 

3443 

3444 # take a slice from any batch or index axis if needed, 

3445 # convert the first channel axis to three (RGB) channels, and take a slice from any additional channel axes 

3446 slices: Tuple[slice, ...] = () 

3447 ndim = data.ndim 

3448 ndim_need = 3 if any(isinstance(a, ChannelAxis) for a in axes) else 2 

3449 has_c_axis = False 

3450 for i, a in enumerate(axes): 

3451 s = data.shape[i] 

3452 assert s > 1 

3453 if ( 

3454 isinstance(a, (BatchAxis, IndexInputAxis, IndexOutputAxis)) 

3455 and ndim > ndim_need 

3456 ): 

3457 data = data[slices + (slice(s // 2 - 1, s // 2),)] 

3458 ndim -= 1 

3459 elif isinstance(a, ChannelAxis): 

3460 if has_c_axis: 

3461 # second channel axis 

3462 data = data[slices + (slice(0, 1),)] 

3463 ndim -= 1 

3464 else: 

3465 has_c_axis = True 

3466 if s == 2: 

3467 # visualize two channels with cyan and magenta 

3468 data = np.concatenate( 

3469 [ 

3470 data[slices + (slice(1, 2),)], 

3471 data[slices + (slice(0, 1),)], 

3472 ( 

3473 data[slices + (slice(0, 1),)] 

3474 + data[slices + (slice(1, 2),)] 

3475 ) 

3476 / 2, # TODO: take maximum instead? 

3477 ], 

3478 axis=i, 

3479 ) 

3480 elif data.shape[i] == 3: 

3481 pass # visualize 3 channels as RGB 

3482 else: 

3483 # visualize first 3 channels as RGB 

3484 data = data[slices + (slice(3),)] 

3485 

3486 assert data.shape[i] == 3 

3487 

3488 slices += (slice(None),) 

3489 

3490 data, axes = squeeze(data, axes) 

3491 assert len(axes) == ndim 

3492 # take a slice from the z axis if needed 

3493 slices = () 

3494 if ndim > ndim_need: 

3495 for i, a in enumerate(axes): 

3496 s = data.shape[i] 

3497 if a.id == AxisId("z"): 

3498 data = data[slices + (slice(s // 2 - 1, s // 2),)] 

3499 data, axes = squeeze(data, axes) 

3500 ndim -= 1 

3501 break 

3502 

3503 slices += (slice(None),) 

3504 

3505 # take a slice from any space or time axis 

3506 slices = () 

3507 

3508 for i, a in enumerate(axes): 

3509 if ndim <= ndim_need: 

3510 break 

3511 

3512 s = data.shape[i] 

3513 assert s > 1 

3514 if isinstance( 

3515 a, (SpaceInputAxis, SpaceOutputAxis, TimeInputAxis, TimeOutputAxis) 

3516 ): 

3517 data = data[slices + (slice(s // 2 - 1, s // 2),)] 

3518 ndim -= 1 

3519 

3520 slices += (slice(None),) 

3521 

3522 del slices 

3523 data, axes = squeeze(data, axes) 

3524 assert len(axes) == ndim 

3525 

3526 if (has_c_axis and ndim != 3) or (not has_c_axis and ndim != 2): 

3527 raise ValueError( 

3528 f"Failed to construct cover image from shape {original_shape}" 

3529 ) 

3530 

3531 if not has_c_axis: 

3532 assert ndim == 2 

3533 data = np.repeat(data[:, :, None], 3, axis=2) 

3534 axes.append(ChannelAxis(channel_names=list(map(Identifier, "RGB")))) 

3535 ndim += 1 

3536 

3537 assert ndim == 3 

3538 

3539 # transpose axis order such that longest axis comes first... 

3540 axis_order: List[int] = list(np.argsort(list(data.shape))) 

3541 axis_order.reverse() 

3542 # ... and channel axis is last 

3543 c = [i for i in range(3) if isinstance(axes[i], ChannelAxis)][0] 

3544 axis_order.append(axis_order.pop(axis_order.index(c))) 

3545 axes = [axes[ao] for ao in axis_order] 

3546 data = data.transpose(axis_order) 

3547 

3548 # h, w = data.shape[:2] 

3549 # if h / w in (1.0, 2.0): 

3550 # pass 

3551 # elif h / w < 2: 

3552 # TODO: enforce 2:1 or 1:1 aspect ratio for generated cover images 

3553 

3554 norm_along = ( 

3555 tuple(i for i, a in enumerate(axes) if a.type in ("space", "time")) or None 

3556 ) 

3557 # normalize the data and map it to 8-bit 

3558 data = normalize(data, norm_along) 

3559 data = (data * 255).astype("uint8") 

3560 

3561 return data 
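# A minimal sketch of the final axis reordering above (longest axis first,
# channel last), assuming a channel-first (3, 100, 80) array:
import numpy as np

shape = (3, 100, 80)
order = list(np.argsort(list(shape)))    # ascending by size: [0, 2, 1]
order.reverse()                          # longest axis first: [1, 2, 0]
c = 0                                    # index of the channel dimension
order.append(order.pop(order.index(c)))  # ... and move the channel to the end
assert tuple(shape[i] for i in order) == (100, 80, 3)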

3562 

3563 def create_diagonal_split_image(im0: NDArray[Any], im1: NDArray[Any]): 

3564 assert im0.dtype == im1.dtype == np.uint8 

3565 assert im0.shape == im1.shape 

3566 assert im0.ndim == 3 

3567 N, M, C = im0.shape 

3568 assert C == 3 

3569 out = np.ones((N, M, C), dtype="uint8") 

3570 for c in range(C): 

3571 outc = np.tril(im0[..., c]) 

3572 mask = outc == 0 

3573 outc[mask] = np.triu(im1[..., c])[mask] 

3574 out[..., c] = outc 

3575 

3576 return out 
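# A single-channel illustration of the tril/triu split used above (one known
# simplification: zeros in the lower triangle of the first image are also
# overwritten, which is acceptable for cover images):
import numpy as np

a = np.full((4, 4), 10, dtype="uint8")   # "input" channel, constant value 10
b = np.full((4, 4), 200, dtype="uint8")  # "output" channel, constant value 200
out = np.tril(a)                         # keep the lower-left triangle of a ...
mask = out == 0
out[mask] = np.triu(b)[mask]             # ... and fill the rest from b
assert out[3, 0] == 10 and out[0, 3] == 200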

3577 

3578 ipt_descr, ipt = inputs[0] 

3579 out_descr, out = outputs[0] 

3580 

3581 ipt_img = to_2d_image(ipt, ipt_descr.axes) 

3582 out_img = to_2d_image(out, out_descr.axes) 

3583 

3584 cover_folder = Path(mkdtemp()) 

3585 if ipt_img.shape == out_img.shape: 

3586 covers = [cover_folder / "cover.png"] 

3587 imwrite(covers[0], create_diagonal_split_image(ipt_img, out_img)) 

3588 else: 

3589 covers = [cover_folder / "input.png", cover_folder / "output.png"] 

3590 imwrite(covers[0], ipt_img) 

3591 imwrite(covers[1], out_img) 

3592 

3593 return covers
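# A minimal round-trip sketch of the PNG writing used above (assumptions: each
# cover is an 8-bit RGB array and is written to a freshly created temporary folder):
from pathlib import Path
from tempfile import mkdtemp

import numpy as np
from imageio.v3 import imread, imwrite

cover = np.zeros((32, 64, 3), dtype="uint8")
path = Path(mkdtemp()) / "cover.png"
imwrite(path, cover)
assert imread(path).shape == (32, 64, 3)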