Coverage for bioimageio/spec/model/v0_5.py: 75%

1351 statements  

coverage.py v7.10.3, created at 2025-08-12 17:44 +0000

1from __future__ import annotations 

2 

3import collections.abc 

4import re 

5import string 

6import warnings 

7from abc import ABC 

8from copy import deepcopy 

9from itertools import chain 

10from math import ceil 

11from pathlib import Path, PurePosixPath 

12from tempfile import mkdtemp 

13from typing import ( 

14 TYPE_CHECKING, 

15 Any, 

16 Callable, 

17 ClassVar, 

18 Dict, 

19 Generic, 

20 List, 

21 Literal, 

22 Mapping, 

23 NamedTuple, 

24 Optional, 

25 Sequence, 

26 Set, 

27 Tuple, 

28 Type, 

29 TypeVar, 

30 Union, 

31 cast, 

32) 

33 

34import numpy as np 

35from annotated_types import Ge, Gt, Interval, MaxLen, MinLen, Predicate 

36from imageio.v3 import imread, imwrite # pyright: ignore[reportUnknownVariableType] 

37from loguru import logger 

38from numpy.typing import NDArray 

39from pydantic import ( 

40 AfterValidator, 

41 Discriminator, 

42 Field, 

43 RootModel, 

44 SerializationInfo, 

45 SerializerFunctionWrapHandler, 

46 StrictInt, 

47 Tag, 

48 ValidationInfo, 

49 WrapSerializer, 

50 field_validator, 

51 model_serializer, 

52 model_validator, 

53) 

54from typing_extensions import Annotated, Self, assert_never, get_args 

55 

56from .._internal.common_nodes import ( 

57 InvalidDescr, 

58 Node, 

59 NodeWithExplicitlySetFields, 

60) 

61from .._internal.constants import DTYPE_LIMITS 

62from .._internal.field_warning import issue_warning, warn 

63from .._internal.io import BioimageioYamlContent as BioimageioYamlContent 

64from .._internal.io import FileDescr as FileDescr 

65from .._internal.io import ( 

66 FileSource, 

67 WithSuffix, 

68 YamlValue, 

69 get_reader, 

70 wo_special_file_name, 

71) 

72from .._internal.io_basics import Sha256 as Sha256 

73from .._internal.io_packaging import ( 

74 FileDescr_, 

75 FileSource_, 

76 package_file_descr_serializer, 

77) 

78from .._internal.io_utils import load_array 

79from .._internal.node_converter import Converter 

80from .._internal.type_guards import is_dict, is_sequence 

81from .._internal.types import ( 

82 FAIR, 

83 AbsoluteTolerance, 

84 LowerCaseIdentifier, 

85 LowerCaseIdentifierAnno, 

86 MismatchedElementsPerMillion, 

87 RelativeTolerance, 

88) 

89from .._internal.types import Datetime as Datetime 

90from .._internal.types import Identifier as Identifier 

91from .._internal.types import NotEmpty as NotEmpty 

92from .._internal.types import SiUnit as SiUnit 

93from .._internal.url import HttpUrl as HttpUrl 

94from .._internal.validation_context import get_validation_context 

95from .._internal.validator_annotations import RestrictCharacters 

96from .._internal.version_type import Version as Version 

97from .._internal.warning_levels import INFO 

98from ..dataset.v0_2 import DatasetDescr as DatasetDescr02 

99from ..dataset.v0_2 import LinkedDataset as LinkedDataset02 

100from ..dataset.v0_3 import DatasetDescr as DatasetDescr 

101from ..dataset.v0_3 import DatasetId as DatasetId 

102from ..dataset.v0_3 import LinkedDataset as LinkedDataset 

103from ..dataset.v0_3 import Uploader as Uploader 

104from ..generic.v0_3 import ( 

105 VALID_COVER_IMAGE_EXTENSIONS as VALID_COVER_IMAGE_EXTENSIONS, 

106) 

107from ..generic.v0_3 import Author as Author 

108from ..generic.v0_3 import BadgeDescr as BadgeDescr 

109from ..generic.v0_3 import CiteEntry as CiteEntry 

110from ..generic.v0_3 import DeprecatedLicenseId as DeprecatedLicenseId 

111from ..generic.v0_3 import Doi as Doi 

112from ..generic.v0_3 import ( 

113 FileSource_documentation, 

114 GenericModelDescrBase, 

115 LinkedResourceBase, 

116 _author_conv, # pyright: ignore[reportPrivateUsage] 

117 _maintainer_conv, # pyright: ignore[reportPrivateUsage] 

118) 

119from ..generic.v0_3 import LicenseId as LicenseId 

120from ..generic.v0_3 import LinkedResource as LinkedResource 

121from ..generic.v0_3 import Maintainer as Maintainer 

122from ..generic.v0_3 import OrcidId as OrcidId 

123from ..generic.v0_3 import RelativeFilePath as RelativeFilePath 

124from ..generic.v0_3 import ResourceId as ResourceId 

125from .v0_4 import Author as _Author_v0_4 

126from .v0_4 import BinarizeDescr as _BinarizeDescr_v0_4 

127from .v0_4 import CallableFromDepencency as CallableFromDepencency 

128from .v0_4 import CallableFromDepencency as _CallableFromDepencency_v0_4 

129from .v0_4 import CallableFromFile as _CallableFromFile_v0_4 

130from .v0_4 import ClipDescr as _ClipDescr_v0_4 

131from .v0_4 import ClipKwargs as ClipKwargs 

132from .v0_4 import ImplicitOutputShape as _ImplicitOutputShape_v0_4 

133from .v0_4 import InputTensorDescr as _InputTensorDescr_v0_4 

134from .v0_4 import KnownRunMode as KnownRunMode 

135from .v0_4 import ModelDescr as _ModelDescr_v0_4 

136from .v0_4 import OutputTensorDescr as _OutputTensorDescr_v0_4 

137from .v0_4 import ParameterizedInputShape as _ParameterizedInputShape_v0_4 

138from .v0_4 import PostprocessingDescr as _PostprocessingDescr_v0_4 

139from .v0_4 import PreprocessingDescr as _PreprocessingDescr_v0_4 

140from .v0_4 import ProcessingKwargs as ProcessingKwargs 

141from .v0_4 import RunMode as RunMode 

142from .v0_4 import ScaleLinearDescr as _ScaleLinearDescr_v0_4 

143from .v0_4 import ScaleMeanVarianceDescr as _ScaleMeanVarianceDescr_v0_4 

144from .v0_4 import ScaleRangeDescr as _ScaleRangeDescr_v0_4 

145from .v0_4 import SigmoidDescr as _SigmoidDescr_v0_4 

146from .v0_4 import TensorName as _TensorName_v0_4 

147from .v0_4 import WeightsFormat as WeightsFormat 

148from .v0_4 import ZeroMeanUnitVarianceDescr as _ZeroMeanUnitVarianceDescr_v0_4 

149from .v0_4 import package_weights 

150 

151SpaceUnit = Literal[ 

152 "attometer", 

153 "angstrom", 

154 "centimeter", 

155 "decimeter", 

156 "exameter", 

157 "femtometer", 

158 "foot", 

159 "gigameter", 

160 "hectometer", 

161 "inch", 

162 "kilometer", 

163 "megameter", 

164 "meter", 

165 "micrometer", 

166 "mile", 

167 "millimeter", 

168 "nanometer", 

169 "parsec", 

170 "petameter", 

171 "picometer", 

172 "terameter", 

173 "yard", 

174 "yoctometer", 

175 "yottameter", 

176 "zeptometer", 

177 "zettameter", 

178] 

179"""Space unit compatible to the [OME-Zarr axes specification 0.5](https://ngff.openmicroscopy.org/0.5/#axes-md)""" 

180 

181TimeUnit = Literal[ 

182 "attosecond", 

183 "centisecond", 

184 "day", 

185 "decisecond", 

186 "exasecond", 

187 "femtosecond", 

188 "gigasecond", 

189 "hectosecond", 

190 "hour", 

191 "kilosecond", 

192 "megasecond", 

193 "microsecond", 

194 "millisecond", 

195 "minute", 

196 "nanosecond", 

197 "petasecond", 

198 "picosecond", 

199 "second", 

200 "terasecond", 

201 "yoctosecond", 

202 "yottasecond", 

203 "zeptosecond", 

204 "zettasecond", 

205] 

206"""Time unit compatible to the [OME-Zarr axes specification 0.5](https://ngff.openmicroscopy.org/0.5/#axes-md)""" 

207 

208AxisType = Literal["batch", "channel", "index", "time", "space"] 

209 

210_AXIS_TYPE_MAP: Mapping[str, AxisType] = { 

211 "b": "batch", 

212 "t": "time", 

213 "i": "index", 

214 "c": "channel", 

215 "x": "space", 

216 "y": "space", 

217 "z": "space", 

218} 

219 

220_AXIS_ID_MAP = { 

221 "b": "batch", 

222 "t": "time", 

223 "i": "index", 

224 "c": "channel", 

225} 

226 

227 

228class TensorId(LowerCaseIdentifier): 

229 root_model: ClassVar[Type[RootModel[Any]]] = RootModel[ 

230 Annotated[LowerCaseIdentifierAnno, MaxLen(32)] 

231 ] 

232 

233 

234def _normalize_axis_id(a: str): 

235 a = str(a) 

236 normalized = _AXIS_ID_MAP.get(a, a) 

237 if a != normalized: 

238 logger.opt(depth=3).warning( 

239 "Normalized axis id from '{}' to '{}'.", a, normalized 

240 ) 

241 return normalized 

242 
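# A brief illustration of the normalization above (an editorial sketch, not part
# of the original module): single-letter ids from the v0.4 convention are
# expanded via `_AXIS_ID_MAP`, everything else is passed through unchanged.
# >>> _normalize_axis_id("b")   # emits a warning about the normalization
# 'batch'
# >>> _normalize_axis_id("x")   # not in `_AXIS_ID_MAP`, returned as-is
# 'x'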

243 

244class AxisId(LowerCaseIdentifier): 

245 root_model: ClassVar[Type[RootModel[Any]]] = RootModel[ 

246 Annotated[ 

247 LowerCaseIdentifierAnno, 

248 MaxLen(16), 

249 AfterValidator(_normalize_axis_id), 

250 ] 

251 ] 

252 

253 

254def _is_batch(a: str) -> bool: 

255 return str(a) == "batch" 

256 

257 

258def _is_not_batch(a: str) -> bool: 

259 return not _is_batch(a) 

260 

261 

262NonBatchAxisId = Annotated[AxisId, Predicate(_is_not_batch)] 

263 

264PreprocessingId = Literal[ 

265 "binarize", 

266 "clip", 

267 "ensure_dtype", 

268 "fixed_zero_mean_unit_variance", 

269 "scale_linear", 

270 "scale_range", 

271 "sigmoid", 

272 "softmax", 

273] 

274PostprocessingId = Literal[ 

275 "binarize", 

276 "clip", 

277 "ensure_dtype", 

278 "fixed_zero_mean_unit_variance", 

279 "scale_linear", 

280 "scale_mean_variance", 

281 "scale_range", 

282 "sigmoid", 

283 "softmax", 

284 "zero_mean_unit_variance", 

285] 

286 

287 

288SAME_AS_TYPE = "<same as type>" 

289 

290 

291ParameterizedSize_N = int 

292""" 

293Annotates an integer to calculate a concrete axis size from a `ParameterizedSize`. 

294""" 

295 

296 

297class ParameterizedSize(Node): 

298 """Describes a range of valid tensor axis sizes as `size = min + n*step`. 

299 

300 - **min** and **step** are given by the model description. 

301 - All blocksize parameters n = 0,1,2,... yield a valid `size`.

302 - A greater blocksize parameter n = 0,1,2,... results in a greater **size**.

303 This allows adjusting the axis size more generically.

304 """ 

305 

306 N: ClassVar[Type[int]] = ParameterizedSize_N 

307 """Positive integer to parameterize this axis""" 

308 

309 min: Annotated[int, Gt(0)] 

310 step: Annotated[int, Gt(0)] 

311 

312 def validate_size(self, size: int) -> int: 

313 if size < self.min: 

314 raise ValueError(f"size {size} < {self.min}") 

315 if (size - self.min) % self.step != 0: 

316 raise ValueError( 

317 f"axis of size {size} is not parameterized by `min + n*step` =" 

318 + f" `{self.min} + n*{self.step}`" 

319 ) 

320 

321 return size 

322 

323 def get_size(self, n: ParameterizedSize_N) -> int: 

324 return self.min + self.step * n 

325 

326 def get_n(self, s: int) -> ParameterizedSize_N: 

327 """return smallest n parameterizing a size greater or equal than `s`""" 

328 return ceil((s - self.min) / self.step) 

329 
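# Example usage (an illustrative sketch, not part of the original module): with
# `ParameterizedSize(min=32, step=16)` the valid sizes are 32, 48, 64, ... and a
# blocksize parameter n maps to a concrete size via `size = min + n*step`.
# >>> p = ParameterizedSize(min=32, step=16)
# >>> p.get_size(3)        # 32 + 3*16
# 80
# >>> p.validate_size(48)  # 48 = 32 + 1*16, so it is accepted
# 48
# >>> p.get_n(70)          # smallest n with 32 + n*16 >= 70
# 3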

330 

331class DataDependentSize(Node): 

332 min: Annotated[int, Gt(0)] = 1 

333 max: Annotated[Optional[int], Gt(1)] = None 

334 

335 @model_validator(mode="after") 

336 def _validate_max_gt_min(self): 

337 if self.max is not None and self.min >= self.max: 

338 raise ValueError(f"expected `min` < `max`, but got {self.min}, {self.max}") 

339 

340 return self 

341 

342 def validate_size(self, size: int) -> int: 

343 if size < self.min: 

344 raise ValueError(f"size {size} < {self.min}") 

345 

346 if self.max is not None and size > self.max: 

347 raise ValueError(f"size {size} > {self.max}") 

348 

349 return size 

350 
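# Example usage (an illustrative sketch, not part of the original module):
# `DataDependentSize` only bounds the axis size; the concrete value is only
# known after inference.
# >>> d = DataDependentSize(min=16, max=1024)
# >>> d.validate_size(16)
# 16
# >>> d.validate_size(8)   # below `min`, raises ValueError
# Traceback (most recent call last):
# ...
# ValueError: size 8 < 16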

351 

352class SizeReference(Node): 

353 """A tensor axis size (extent in pixels/frames) defined in relation to a reference axis. 

354 

355 `axis.size = reference.size * reference.scale / axis.scale + offset` 

356 

357 Note: 

358 1. The axis and the referenced axis need to have the same unit (or no unit). 

359 2. Batch axes may not be referenced. 

360 3. Fractions are rounded down. 

361 4. If the reference axis is `concatenable` the referencing axis is assumed to be 

362 `concatenable` as well with the same block order. 

363 

364 Example: 

365 An anisotropic input image of w*h=100*49 pixels depicts a physical space of 200*196 mm².

366 Let's assume that we want to express the image height h in relation to its width w 

367 instead of only accepting input images of exactly 100*49 pixels 

368 (for example to express a range of valid image shapes by parametrizing w, see `ParameterizedSize`). 

369 

370 >>> w = SpaceInputAxis(id=AxisId("w"), size=100, unit="millimeter", scale=2) 

371 >>> h = SpaceInputAxis( 

372 ... id=AxisId("h"), 

373 ... size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("w"), offset=-1), 

374 ... unit="millimeter", 

375 ... scale=4, 

376 ... ) 

377 >>> print(h.size.get_size(h, w)) 

378 49 

379 

380 ⇒ h = w * w.scale / h.scale + offset = 100 * 2mm / 4mm - 1 = 49 

381 """ 

382 

383 tensor_id: TensorId 

384 """tensor id of the reference axis""" 

385 

386 axis_id: AxisId 

387 """axis id of the reference axis""" 

388 

389 offset: StrictInt = 0 

390 

391 def get_size( 

392 self, 

393 axis: Union[ 

394 ChannelAxis, 

395 IndexInputAxis, 

396 IndexOutputAxis, 

397 TimeInputAxis, 

398 SpaceInputAxis, 

399 TimeOutputAxis, 

400 TimeOutputAxisWithHalo, 

401 SpaceOutputAxis, 

402 SpaceOutputAxisWithHalo, 

403 ], 

404 ref_axis: Union[ 

405 ChannelAxis, 

406 IndexInputAxis, 

407 IndexOutputAxis, 

408 TimeInputAxis, 

409 SpaceInputAxis, 

410 TimeOutputAxis, 

411 TimeOutputAxisWithHalo, 

412 SpaceOutputAxis, 

413 SpaceOutputAxisWithHalo, 

414 ], 

415 n: ParameterizedSize_N = 0, 

416 ref_size: Optional[int] = None, 

417 ): 

418 """Compute the concrete size for a given axis and its reference axis. 

419 

420 Args: 

421 axis: The axis this `SizeReference` is the size of. 

422 ref_axis: The reference axis to compute the size from. 

423 n: If the **ref_axis** is parameterized (of type `ParameterizedSize`) 

424 and no fixed **ref_size** is given, 

425 **n** is used to compute the size of the parameterized **ref_axis**. 

426 ref_size: Overwrite the reference size instead of deriving it from 

427 **ref_axis** 

428 (**ref_axis.scale** is still used; any given **n** is ignored). 

429 """ 

430 assert ( 

431 axis.size == self 

432 ), "Given `axis.size` is not defined by this `SizeReference`" 

433 

434 assert ( 

435 ref_axis.id == self.axis_id 

436 ), f"Expected `ref_axis.id` to be {self.axis_id}, but got {ref_axis.id}." 

437 

438 assert axis.unit == ref_axis.unit, ( 

439 "`SizeReference` requires `axis` and `ref_axis` to have the same `unit`," 

440 f" but {axis.unit}!={ref_axis.unit}" 

441 ) 

442 if ref_size is None: 

443 if isinstance(ref_axis.size, (int, float)): 

444 ref_size = ref_axis.size 

445 elif isinstance(ref_axis.size, ParameterizedSize): 

446 ref_size = ref_axis.size.get_size(n) 

447 elif isinstance(ref_axis.size, DataDependentSize): 

448 raise ValueError( 

449 "Reference axis referenced in `SizeReference` may not be a `DataDependentSize`." 

450 ) 

451 elif isinstance(ref_axis.size, SizeReference): 

452 raise ValueError( 

453 "Reference axis referenced in `SizeReference` may not be sized by a" 

454 + " `SizeReference` itself." 

455 ) 

456 else: 

457 assert_never(ref_axis.size) 

458 

459 return int(ref_size * ref_axis.scale / axis.scale + self.offset) 

460 

461 @staticmethod 

462 def _get_unit( 

463 axis: Union[ 

464 ChannelAxis, 

465 IndexInputAxis, 

466 IndexOutputAxis, 

467 TimeInputAxis, 

468 SpaceInputAxis, 

469 TimeOutputAxis, 

470 TimeOutputAxisWithHalo, 

471 SpaceOutputAxis, 

472 SpaceOutputAxisWithHalo, 

473 ], 

474 ): 

475 return axis.unit 

476 

477 

478class AxisBase(NodeWithExplicitlySetFields): 

479 id: AxisId 

480 """An axis id unique across all axes of one tensor.""" 

481 

482 description: Annotated[str, MaxLen(128)] = "" 

483 """A short description of this axis beyond its type and id.""" 

484 

485 

486class WithHalo(Node): 

487 halo: Annotated[int, Ge(1)] 

488 """The halo should be cropped from the output tensor to avoid boundary effects. 

489 It is to be cropped from both sides, i.e. `size_after_crop = size - 2 * halo`. 

490 To document a halo that is already cropped by the model use `size.offset` instead.""" 

491 

492 size: Annotated[ 

493 SizeReference, 

494 Field( 

495 examples=[ 

496 10, 

497 SizeReference( 

498 tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5 

499 ).model_dump(mode="json"), 

500 ] 

501 ), 

502 ] 

503 """reference to another axis with an optional offset (see `SizeReference`)""" 

504 
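# Worked example for the halo semantics above (an editorial sketch): an output
# axis with `halo=16` and an effective size of 128 yields
# 128 - 2*16 = 96 valid pixels/frames after cropping the halo from both sides.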

505 

506BATCH_AXIS_ID = AxisId("batch") 

507 

508 

509class BatchAxis(AxisBase): 

510 implemented_type: ClassVar[Literal["batch"]] = "batch" 

511 if TYPE_CHECKING: 

512 type: Literal["batch"] = "batch" 

513 else: 

514 type: Literal["batch"] 

515 

516 id: Annotated[AxisId, Predicate(_is_batch)] = BATCH_AXIS_ID 

517 size: Optional[Literal[1]] = None 

518 """The batch size may be fixed to 1, 

519 otherwise (the default) it may be chosen arbitrarily depending on available memory""" 

520 

521 @property 

522 def scale(self): 

523 return 1.0 

524 

525 @property 

526 def concatenable(self): 

527 return True 

528 

529 @property 

530 def unit(self): 

531 return None 

532 

533 

534class ChannelAxis(AxisBase): 

535 implemented_type: ClassVar[Literal["channel"]] = "channel" 

536 if TYPE_CHECKING: 

537 type: Literal["channel"] = "channel" 

538 else: 

539 type: Literal["channel"] 

540 

541 id: NonBatchAxisId = AxisId("channel") 

542 

543 channel_names: NotEmpty[List[Identifier]] 

544 

545 @property 

546 def size(self) -> int: 

547 return len(self.channel_names) 

548 

549 @property 

550 def concatenable(self): 

551 return False 

552 

553 @property 

554 def scale(self) -> float: 

555 return 1.0 

556 

557 @property 

558 def unit(self): 

559 return None 

560 

561 

562class IndexAxisBase(AxisBase): 

563 implemented_type: ClassVar[Literal["index"]] = "index" 

564 if TYPE_CHECKING: 

565 type: Literal["index"] = "index" 

566 else: 

567 type: Literal["index"] 

568 

569 id: NonBatchAxisId = AxisId("index") 

570 

571 @property 

572 def scale(self) -> float: 

573 return 1.0 

574 

575 @property 

576 def unit(self): 

577 return None 

578 

579 

580class _WithInputAxisSize(Node): 

581 size: Annotated[ 

582 Union[Annotated[int, Gt(0)], ParameterizedSize, SizeReference], 

583 Field( 

584 examples=[ 

585 10, 

586 ParameterizedSize(min=32, step=16).model_dump(mode="json"), 

587 SizeReference( 

588 tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5 

589 ).model_dump(mode="json"), 

590 ] 

591 ), 

592 ] 

593 """The size/length of this axis can be specified as 

594 - fixed integer 

595 - parameterized series of valid sizes (`ParameterizedSize`) 

596 - reference to another axis with an optional offset (`SizeReference`) 

597 """ 

598 

599 

600class IndexInputAxis(IndexAxisBase, _WithInputAxisSize): 

601 concatenable: bool = False 

602 """If a model has a `concatenable` input axis, it can be processed blockwise, 

603 splitting a longer sample axis into blocks matching its input tensor description. 

604 Output axes are concatenable if they have a `SizeReference` to a concatenable 

605 input axis. 

606 """ 

607 

608 

609class IndexOutputAxis(IndexAxisBase): 

610 size: Annotated[ 

611 Union[Annotated[int, Gt(0)], SizeReference, DataDependentSize], 

612 Field( 

613 examples=[ 

614 10, 

615 SizeReference( 

616 tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5 

617 ).model_dump(mode="json"), 

618 ] 

619 ), 

620 ] 

621 """The size/length of this axis can be specified as 

622 - fixed integer 

623 - reference to another axis with an optional offset (`SizeReference`) 

624 - data dependent size using `DataDependentSize` (size is only known after model inference) 

625 """ 

626 

627 

628class TimeAxisBase(AxisBase): 

629 implemented_type: ClassVar[Literal["time"]] = "time" 

630 if TYPE_CHECKING: 

631 type: Literal["time"] = "time" 

632 else: 

633 type: Literal["time"] 

634 

635 id: NonBatchAxisId = AxisId("time") 

636 unit: Optional[TimeUnit] = None 

637 scale: Annotated[float, Gt(0)] = 1.0 

638 

639 

640class TimeInputAxis(TimeAxisBase, _WithInputAxisSize): 

641 concatenable: bool = False 

642 """If a model has a `concatenable` input axis, it can be processed blockwise, 

643 splitting a longer sample axis into blocks matching its input tensor description. 

644 Output axes are concatenable if they have a `SizeReference` to a concatenable 

645 input axis. 

646 """ 

647 

648 

649class SpaceAxisBase(AxisBase): 

650 implemented_type: ClassVar[Literal["space"]] = "space" 

651 if TYPE_CHECKING: 

652 type: Literal["space"] = "space" 

653 else: 

654 type: Literal["space"] 

655 

656 id: Annotated[NonBatchAxisId, Field(examples=["x", "y", "z"])] = AxisId("x") 

657 unit: Optional[SpaceUnit] = None 

658 scale: Annotated[float, Gt(0)] = 1.0 

659 

660 

661class SpaceInputAxis(SpaceAxisBase, _WithInputAxisSize): 

662 concatenable: bool = False 

663 """If a model has a `concatenable` input axis, it can be processed blockwise, 

664 splitting a longer sample axis into blocks matching its input tensor description. 

665 Output axes are concatenable if they have a `SizeReference` to a concatenable 

666 input axis. 

667 """ 

668 

669 

670INPUT_AXIS_TYPES = ( 

671 BatchAxis, 

672 ChannelAxis, 

673 IndexInputAxis, 

674 TimeInputAxis, 

675 SpaceInputAxis, 

676) 

677"""intended for isinstance comparisons in py<3.10""" 

678 

679_InputAxisUnion = Union[ 

680 BatchAxis, ChannelAxis, IndexInputAxis, TimeInputAxis, SpaceInputAxis 

681] 

682InputAxis = Annotated[_InputAxisUnion, Discriminator("type")] 

683 

684 

685class _WithOutputAxisSize(Node): 

686 size: Annotated[ 

687 Union[Annotated[int, Gt(0)], SizeReference], 

688 Field( 

689 examples=[ 

690 10, 

691 SizeReference( 

692 tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5 

693 ).model_dump(mode="json"), 

694 ] 

695 ), 

696 ] 

697 """The size/length of this axis can be specified as 

698 - fixed integer 

699 - reference to another axis with an optional offset (see `SizeReference`) 

700 """ 

701 

702 

703class TimeOutputAxis(TimeAxisBase, _WithOutputAxisSize): 

704 pass 

705 

706 

707class TimeOutputAxisWithHalo(TimeAxisBase, WithHalo): 

708 pass 

709 

710 

711def _get_halo_axis_discriminator_value(v: Any) -> Literal["with_halo", "wo_halo"]: 

712 if isinstance(v, dict): 

713 return "with_halo" if "halo" in v else "wo_halo" 

714 else: 

715 return "with_halo" if hasattr(v, "halo") else "wo_halo" 

716 

717 

718_TimeOutputAxisUnion = Annotated[ 

719 Union[ 

720 Annotated[TimeOutputAxis, Tag("wo_halo")], 

721 Annotated[TimeOutputAxisWithHalo, Tag("with_halo")], 

722 ], 

723 Discriminator(_get_halo_axis_discriminator_value), 

724] 

725 

726 

727class SpaceOutputAxis(SpaceAxisBase, _WithOutputAxisSize): 

728 pass 

729 

730 

731class SpaceOutputAxisWithHalo(SpaceAxisBase, WithHalo): 

732 pass 

733 

734 

735_SpaceOutputAxisUnion = Annotated[ 

736 Union[ 

737 Annotated[SpaceOutputAxis, Tag("wo_halo")], 

738 Annotated[SpaceOutputAxisWithHalo, Tag("with_halo")], 

739 ], 

740 Discriminator(_get_halo_axis_discriminator_value), 

741] 

742 

743 

744_OutputAxisUnion = Union[ 

745 BatchAxis, ChannelAxis, IndexOutputAxis, _TimeOutputAxisUnion, _SpaceOutputAxisUnion 

746] 

747OutputAxis = Annotated[_OutputAxisUnion, Discriminator("type")] 

748 

749OUTPUT_AXIS_TYPES = ( 

750 BatchAxis, 

751 ChannelAxis, 

752 IndexOutputAxis, 

753 TimeOutputAxis, 

754 TimeOutputAxisWithHalo, 

755 SpaceOutputAxis, 

756 SpaceOutputAxisWithHalo, 

757) 

758"""intended for isinstance comparisons in py<3.10""" 

759 

760 

761AnyAxis = Union[InputAxis, OutputAxis] 

762 

763ANY_AXIS_TYPES = INPUT_AXIS_TYPES + OUTPUT_AXIS_TYPES 

764"""intended for isinstance comparisons in py<3.10""" 

765 

766TVs = Union[ 

767 NotEmpty[List[int]], 

768 NotEmpty[List[float]], 

769 NotEmpty[List[bool]], 

770 NotEmpty[List[str]], 

771] 

772 

773 

774NominalOrOrdinalDType = Literal[ 

775 "float32", 

776 "float64", 

777 "uint8", 

778 "int8", 

779 "uint16", 

780 "int16", 

781 "uint32", 

782 "int32", 

783 "uint64", 

784 "int64", 

785 "bool", 

786] 

787 

788 

789class NominalOrOrdinalDataDescr(Node): 

790 values: TVs 

791 """A fixed set of nominal or an ascending sequence of ordinal values. 

792 In this case `data.type` is required to be an unsigned integer type, e.g. 'uint8'.

793 String `values` are interpreted as labels for tensor values 0, ..., N. 

794 Note: as YAML 1.2 does not natively support a "set" datatype, 

795 nominal values should be given as a sequence (aka list/array) as well. 

796 """ 

797 

798 type: Annotated[ 

799 NominalOrOrdinalDType, 

800 Field( 

801 examples=[ 

802 "float32", 

803 "uint8", 

804 "uint16", 

805 "int64", 

806 "bool", 

807 ], 

808 ), 

809 ] = "uint8" 

810 

811 @model_validator(mode="after") 

812 def _validate_values_match_type( 

813 self, 

814 ) -> Self: 

815 incompatible: List[Any] = [] 

816 for v in self.values: 

817 if self.type == "bool": 

818 if not isinstance(v, bool): 

819 incompatible.append(v) 

820 elif self.type in DTYPE_LIMITS: 

821 if ( 

822 isinstance(v, (int, float)) 

823 and ( 

824 v < DTYPE_LIMITS[self.type].min 

825 or v > DTYPE_LIMITS[self.type].max 

826 ) 

827 or (isinstance(v, str) and "uint" not in self.type) 

828 or (isinstance(v, float) and "int" in self.type) 

829 ): 

830 incompatible.append(v) 

831 else: 

832 incompatible.append(v) 

833 

834 if len(incompatible) == 5: 

835 incompatible.append("...") 

836 break 

837 

838 if incompatible: 

839 raise ValueError( 

840 f"data type '{self.type}' incompatible with values {incompatible}" 

841 ) 

842 

843 return self 

844 

845 unit: Optional[Union[Literal["arbitrary unit"], SiUnit]] = None 

846 

847 @property 

848 def range(self): 

849 if isinstance(self.values[0], str): 

850 return 0, len(self.values) - 1 

851 else: 

852 return min(self.values), max(self.values) 

853 
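# Example usage (an illustrative sketch, not part of the original module):
# string values act as labels for the tensor values 0..N-1, so the `range`
# property spans those indices.
# >>> labels = NominalOrOrdinalDataDescr(values=["background", "foreground"], type="uint8")
# >>> labels.range
# (0, 1)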

854 

855IntervalOrRatioDType = Literal[ 

856 "float32", 

857 "float64", 

858 "uint8", 

859 "int8", 

860 "uint16", 

861 "int16", 

862 "uint32", 

863 "int32", 

864 "uint64", 

865 "int64", 

866] 

867 

868 

869class IntervalOrRatioDataDescr(Node): 

870 type: Annotated[ # TODO: rename to dtype 

871 IntervalOrRatioDType, 

872 Field( 

873 examples=["float32", "float64", "uint8", "uint16"], 

874 ), 

875 ] = "float32" 

876 range: Tuple[Optional[float], Optional[float]] = ( 

877 None, 

878 None, 

879 ) 

880 """Tuple `(minimum, maximum)` specifying the allowed range of the data in this tensor. 

881 `None` corresponds to min/max of what can be expressed by **type**.""" 

882 unit: Union[Literal["arbitrary unit"], SiUnit] = "arbitrary unit" 

883 scale: float = 1.0 

884 """Scale for data on an interval (or ratio) scale.""" 

885 offset: Optional[float] = None 

886 """Offset for data on a ratio scale.""" 

887 

888 @model_validator(mode="before") 

889 def _replace_inf(cls, data: Any): 

890 if is_dict(data): 

891 if "range" in data and is_sequence(data["range"]): 

892 forbidden = ( 

893 "inf", 

894 "-inf", 

895 ".inf", 

896 "-.inf", 

897 float("inf"), 

898 float("-inf"), 

899 ) 

900 if any(v in forbidden for v in data["range"]): 

901 issue_warning("replaced 'inf' value", value=data["range"]) 

902 

903 data["range"] = tuple( 

904 (None if v in forbidden else v) for v in data["range"] 

905 ) 

906 

907 return data 

908 
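# Example usage (an illustrative sketch, not part of the original module):
# infinite bounds in `range` are replaced by `None` (with a validation warning)
# by the `_replace_inf` validator above.
# >>> descr = IntervalOrRatioDataDescr.model_validate(
# ...     {"type": "float32", "range": ["-inf", 100]}
# ... )
# >>> descr.range
# (None, 100.0)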

909 

910TensorDataDescr = Union[NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr] 

911 

912 

913class ProcessingDescrBase(NodeWithExplicitlySetFields, ABC): 

914 """processing base class""" 

915 

916 

917class BinarizeKwargs(ProcessingKwargs): 

918 """key word arguments for `BinarizeDescr`""" 

919 

920 threshold: float 

921 """The fixed threshold""" 

922 

923 

924class BinarizeAlongAxisKwargs(ProcessingKwargs): 

925 """key word arguments for `BinarizeDescr`""" 

926 

927 threshold: NotEmpty[List[float]] 

928 """The fixed threshold values along `axis`""" 

929 

930 axis: Annotated[NonBatchAxisId, Field(examples=["channel"])] 

931 """The `threshold` axis""" 

932 

933 

934class BinarizeDescr(ProcessingDescrBase): 

935 """Binarize the tensor with a fixed threshold. 

936 

937 Values above `BinarizeKwargs.threshold`/`BinarizeAlongAxisKwargs.threshold` 

938 will be set to one, values below the threshold to zero. 

939 

940 Examples: 

941 - in YAML 

942 ```yaml 

943 postprocessing: 

944 - id: binarize 

945 kwargs: 

946 axis: 'channel' 

947 threshold: [0.25, 0.5, 0.75] 

948 ``` 

949 - in Python: 

950 >>> postprocessing = [BinarizeDescr( 

951 ... kwargs=BinarizeAlongAxisKwargs( 

952 ... axis=AxisId('channel'), 

953 ... threshold=[0.25, 0.5, 0.75], 

954 ... ) 

955 ... )] 

956 """ 

957 

958 implemented_id: ClassVar[Literal["binarize"]] = "binarize" 

959 if TYPE_CHECKING: 

960 id: Literal["binarize"] = "binarize" 

961 else: 

962 id: Literal["binarize"] 

963 kwargs: Union[BinarizeKwargs, BinarizeAlongAxisKwargs] 

964 

965 

966class ClipDescr(ProcessingDescrBase): 

967 """Set tensor values below min to min and above max to max. 

968 

969 See `ScaleRangeDescr` for examples. 

970 """ 

971 

972 implemented_id: ClassVar[Literal["clip"]] = "clip" 

973 if TYPE_CHECKING: 

974 id: Literal["clip"] = "clip" 

975 else: 

976 id: Literal["clip"] 

977 

978 kwargs: ClipKwargs 

979 

980 

981class EnsureDtypeKwargs(ProcessingKwargs): 

982 """key word arguments for `EnsureDtypeDescr`""" 

983 

984 dtype: Literal[ 

985 "float32", 

986 "float64", 

987 "uint8", 

988 "int8", 

989 "uint16", 

990 "int16", 

991 "uint32", 

992 "int32", 

993 "uint64", 

994 "int64", 

995 "bool", 

996 ] 

997 

998 

999class EnsureDtypeDescr(ProcessingDescrBase): 

1000 """Cast the tensor data type to `EnsureDtypeKwargs.dtype` (if not matching). 

1001 

1002 This can for example be used to ensure the inner neural network model gets a 

1003 different input tensor data type than the fully described bioimage.io model does. 

1004 

1005 Examples: 

1006 The described bioimage.io model (incl. preprocessing) accepts any 

1007 float32-compatible tensor, normalizes it with percentiles and clipping and then 

1008 casts it to uint8, which is what the neural network in this example expects. 

1009 - in YAML 

1010 ```yaml 

1011 inputs: 

1012 - data: 

1013 type: float32 # described bioimage.io model is compatible with any float32 input tensor 

1014 preprocessing: 

1015 - id: scale_range 

1016 kwargs: 

1017 axes: ['y', 'x'] 

1018 max_percentile: 99.8 

1019 min_percentile: 5.0 

1020 - id: clip 

1021 kwargs: 

1022 min: 0.0 

1023 max: 1.0 

1024 - id: ensure_dtype # the neural network of the model requires uint8 

1025 kwargs: 

1026 dtype: uint8 

1027 ``` 

1028 - in Python: 

1029 >>> preprocessing = [ 

1030 ... ScaleRangeDescr( 

1031 ... kwargs=ScaleRangeKwargs( 

1032 ... axes= (AxisId('y'), AxisId('x')), 

1033 ... max_percentile= 99.8, 

1034 ... min_percentile= 5.0, 

1035 ... ) 

1036 ... ), 

1037 ... ClipDescr(kwargs=ClipKwargs(min=0.0, max=1.0)), 

1038 ... EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="uint8")), 

1039 ... ] 

1040 """ 

1041 

1042 implemented_id: ClassVar[Literal["ensure_dtype"]] = "ensure_dtype" 

1043 if TYPE_CHECKING: 

1044 id: Literal["ensure_dtype"] = "ensure_dtype" 

1045 else: 

1046 id: Literal["ensure_dtype"] 

1047 

1048 kwargs: EnsureDtypeKwargs 

1049 

1050 

1051class ScaleLinearKwargs(ProcessingKwargs): 

1052 """Key word arguments for `ScaleLinearDescr`""" 

1053 

1054 gain: float = 1.0 

1055 """multiplicative factor""" 

1056 

1057 offset: float = 0.0 

1058 """additive term""" 

1059 

1060 @model_validator(mode="after") 

1061 def _validate(self) -> Self: 

1062 if self.gain == 1.0 and self.offset == 0.0: 

1063 raise ValueError( 

1064 "Redundant linear scaling not allowd. Set `gain` != 1.0 and/or `offset`" 

1065 + " != 0.0." 

1066 ) 

1067 

1068 return self 

1069 

1070 

1071class ScaleLinearAlongAxisKwargs(ProcessingKwargs): 

1072 """Key word arguments for `ScaleLinearDescr`""" 

1073 

1074 axis: Annotated[NonBatchAxisId, Field(examples=["channel"])] 

1075 """The axis of gain and offset values.""" 

1076 

1077 gain: Union[float, NotEmpty[List[float]]] = 1.0 

1078 """multiplicative factor""" 

1079 

1080 offset: Union[float, NotEmpty[List[float]]] = 0.0 

1081 """additive term""" 

1082 

1083 @model_validator(mode="after") 

1084 def _validate(self) -> Self: 

1085 

1086 if isinstance(self.gain, list): 

1087 if isinstance(self.offset, list): 

1088 if len(self.gain) != len(self.offset): 

1089 raise ValueError( 

1090 f"Size of `gain` ({len(self.gain)}) and `offset` ({len(self.offset)}) must match." 

1091 ) 

1092 else: 

1093 self.offset = [float(self.offset)] * len(self.gain) 

1094 elif isinstance(self.offset, list): 

1095 self.gain = [float(self.gain)] * len(self.offset) 

1096 else: 

1097 raise ValueError( 

1098 "Do not specify an `axis` for scalar gain and offset values." 

1099 ) 

1100 

1101 if all(g == 1.0 for g in self.gain) and all(off == 0.0 for off in self.offset): 

1102 raise ValueError( 

1103 "Redundant linear scaling not allowd. Set `gain` != 1.0 and/or `offset`" 

1104 + " != 0.0." 

1105 ) 

1106 

1107 return self 

1108 
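# Example usage (an illustrative sketch, not part of the original module): a
# scalar `offset` is broadcast to match a per-channel `gain` list by the
# validator above.
# >>> kwargs = ScaleLinearAlongAxisKwargs(axis=AxisId("channel"), gain=[1.0, 2.0, 3.0])
# >>> kwargs.offset
# [0.0, 0.0, 0.0]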

1109 

1110class ScaleLinearDescr(ProcessingDescrBase): 

1111 """Fixed linear scaling. 

1112 

1113 Examples: 

1114 1. Scale with scalar gain and offset 

1115 - in YAML 

1116 ```yaml 

1117 preprocessing: 

1118 - id: scale_linear 

1119 kwargs: 

1120 gain: 2.0 

1121 offset: 3.0 

1122 ``` 

1123 - in Python: 

1124 >>> preprocessing = [ 

1125 ... ScaleLinearDescr(kwargs=ScaleLinearKwargs(gain= 2.0, offset=3.0)) 

1126 ... ] 

1127 

1128 2. Independent scaling along an axis 

1129 - in YAML 

1130 ```yaml 

1131 preprocessing: 

1132 - id: scale_linear 

1133 kwargs: 

1134 axis: 'channel' 

1135 gain: [1.0, 2.0, 3.0] 

1136 ``` 

1137 - in Python: 

1138 >>> preprocessing = [ 

1139 ... ScaleLinearDescr( 

1140 ... kwargs=ScaleLinearAlongAxisKwargs( 

1141 ... axis=AxisId("channel"), 

1142 ... gain=[1.0, 2.0, 3.0], 

1143 ... ) 

1144 ... ) 

1145 ... ] 

1146 

1147 """ 

1148 

1149 implemented_id: ClassVar[Literal["scale_linear"]] = "scale_linear" 

1150 if TYPE_CHECKING: 

1151 id: Literal["scale_linear"] = "scale_linear" 

1152 else: 

1153 id: Literal["scale_linear"] 

1154 kwargs: Union[ScaleLinearKwargs, ScaleLinearAlongAxisKwargs] 

1155 

1156 

1157class SigmoidDescr(ProcessingDescrBase): 

1158 """The logistic sigmoid function, a.k.a. expit function. 

1159 

1160 Examples: 

1161 - in YAML 

1162 ```yaml 

1163 postprocessing: 

1164 - id: sigmoid 

1165 ``` 

1166 - in Python: 

1167 >>> postprocessing = [SigmoidDescr()] 

1168 """ 

1169 

1170 implemented_id: ClassVar[Literal["sigmoid"]] = "sigmoid" 

1171 if TYPE_CHECKING: 

1172 id: Literal["sigmoid"] = "sigmoid" 

1173 else: 

1174 id: Literal["sigmoid"] 

1175 

1176 @property 

1177 def kwargs(self) -> ProcessingKwargs: 

1178 """empty kwargs""" 

1179 return ProcessingKwargs() 

1180 

1181 

1182class SoftmaxKwargs(ProcessingKwargs): 

1183 """key word arguments for `SoftmaxDescr`""" 

1184 

1185 axis: Annotated[NonBatchAxisId, Field(examples=["channel"])] = AxisId("channel") 

1186 """The axis to apply the softmax function along. 

1187 Note: 

1188 Defaults to 'channel' axis 

1189 (which may not exist, in which case 

1190 a different axis id has to be specified). 

1191 """ 

1192 

1193 

1194class SoftmaxDescr(ProcessingDescrBase): 

1195 """The softmax function. 

1196 

1197 Examples: 

1198 - in YAML 

1199 ```yaml 

1200 postprocessing: 

1201 - id: softmax 

1202 kwargs: 

1203 axis: channel 

1204 ``` 

1205 - in Python: 

1206 >>> postprocessing = [SoftmaxDescr(kwargs=SoftmaxKwargs(axis=AxisId("channel")))] 

1207 """ 

1208 

1209 implemented_id: ClassVar[Literal["softmax"]] = "softmax" 

1210 if TYPE_CHECKING: 

1211 id: Literal["softmax"] = "softmax" 

1212 else: 

1213 id: Literal["softmax"] 

1214 

1215 kwargs: SoftmaxKwargs = Field(default_factory=SoftmaxKwargs.model_construct) 

1216 

1217 

1218class FixedZeroMeanUnitVarianceKwargs(ProcessingKwargs): 

1219 """key word arguments for `FixedZeroMeanUnitVarianceDescr`""" 

1220 

1221 mean: float 

1222 """The mean value to normalize with.""" 

1223 

1224 std: Annotated[float, Ge(1e-6)] 

1225 """The standard deviation value to normalize with.""" 

1226 

1227 

1228class FixedZeroMeanUnitVarianceAlongAxisKwargs(ProcessingKwargs): 

1229 """key word arguments for `FixedZeroMeanUnitVarianceDescr`""" 

1230 

1231 mean: NotEmpty[List[float]] 

1232 """The mean value(s) to normalize with.""" 

1233 

1234 std: NotEmpty[List[Annotated[float, Ge(1e-6)]]] 

1235 """The standard deviation value(s) to normalize with. 

1236 Size must match `mean` values.""" 

1237 

1238 axis: Annotated[NonBatchAxisId, Field(examples=["channel", "index"])] 

1239 """The axis of the mean/std values to normalize each entry along that dimension 

1240 separately.""" 

1241 

1242 @model_validator(mode="after") 

1243 def _mean_and_std_match(self) -> Self: 

1244 if len(self.mean) != len(self.std): 

1245 raise ValueError( 

1246 f"Size of `mean` ({len(self.mean)}) and `std` ({len(self.std)})" 

1247 + " must match." 

1248 ) 

1249 

1250 return self 

1251 

1252 

1253class FixedZeroMeanUnitVarianceDescr(ProcessingDescrBase): 

1254 """Subtract a given mean and divide by the standard deviation. 

1255 

1256 Normalize with fixed, precomputed values for 

1257 `FixedZeroMeanUnitVarianceKwargs.mean` and `FixedZeroMeanUnitVarianceKwargs.std` 

1258 Use `FixedZeroMeanUnitVarianceAlongAxisKwargs` for independent scaling along given 

1259 axes. 

1260 

1261 Examples: 

1262 1. scalar value for whole tensor 

1263 - in YAML 

1264 ```yaml 

1265 preprocessing: 

1266 - id: fixed_zero_mean_unit_variance 

1267 kwargs: 

1268 mean: 103.5 

1269 std: 13.7 

1270 ``` 

1271 - in Python 

1272 >>> preprocessing = [FixedZeroMeanUnitVarianceDescr( 

1273 ... kwargs=FixedZeroMeanUnitVarianceKwargs(mean=103.5, std=13.7) 

1274 ... )] 

1275 

1276 2. independently along an axis 

1277 - in YAML 

1278 ```yaml 

1279 preprocessing: 

1280 - id: fixed_zero_mean_unit_variance 

1281 kwargs: 

1282 axis: channel 

1283 mean: [101.5, 102.5, 103.5] 

1284 std: [11.7, 12.7, 13.7] 

1285 ``` 

1286 - in Python 

1287 >>> preprocessing = [FixedZeroMeanUnitVarianceDescr( 

1288 ... kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs( 

1289 ... axis=AxisId("channel"), 

1290 ... mean=[101.5, 102.5, 103.5], 

1291 ... std=[11.7, 12.7, 13.7], 

1292 ... ) 

1293 ... )] 

1294 """ 

1295 

1296 implemented_id: ClassVar[Literal["fixed_zero_mean_unit_variance"]] = ( 

1297 "fixed_zero_mean_unit_variance" 

1298 ) 

1299 if TYPE_CHECKING: 

1300 id: Literal["fixed_zero_mean_unit_variance"] = "fixed_zero_mean_unit_variance" 

1301 else: 

1302 id: Literal["fixed_zero_mean_unit_variance"] 

1303 

1304 kwargs: Union[ 

1305 FixedZeroMeanUnitVarianceKwargs, FixedZeroMeanUnitVarianceAlongAxisKwargs 

1306 ] 

1307 

1308 

1309class ZeroMeanUnitVarianceKwargs(ProcessingKwargs): 

1310 """key word arguments for `ZeroMeanUnitVarianceDescr`""" 

1311 

1312 axes: Annotated[ 

1313 Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")]) 

1314 ] = None 

1315 """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std. 

1316 For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x') 

1317 resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`. 

1318 To normalize each sample independently leave out the 'batch' axis. 

1319 Default: Scale all axes jointly.""" 

1320 

1321 eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6 

1322 """epsilon for numeric stability: `out = (tensor - mean) / (std + eps)`.""" 

1323 

1324 

1325class ZeroMeanUnitVarianceDescr(ProcessingDescrBase): 

1326 """Subtract mean and divide by variance. 

1327 

1328 Examples: 

1329 Subtract the tensor mean and divide by its standard deviation

1330 - in YAML 

1331 ```yaml 

1332 preprocessing: 

1333 - id: zero_mean_unit_variance 

1334 ``` 

1335 - in Python 

1336 >>> preprocessing = [ZeroMeanUnitVarianceDescr()] 

1337 """ 

1338 

1339 implemented_id: ClassVar[Literal["zero_mean_unit_variance"]] = ( 

1340 "zero_mean_unit_variance" 

1341 ) 

1342 if TYPE_CHECKING: 

1343 id: Literal["zero_mean_unit_variance"] = "zero_mean_unit_variance" 

1344 else: 

1345 id: Literal["zero_mean_unit_variance"] 

1346 

1347 kwargs: ZeroMeanUnitVarianceKwargs = Field( 

1348 default_factory=ZeroMeanUnitVarianceKwargs.model_construct 

1349 ) 

1350 

1351 

1352class ScaleRangeKwargs(ProcessingKwargs): 

1353 """key word arguments for `ScaleRangeDescr` 

1354 

1355 For `min_percentile`=0.0 (the default) and `max_percentile`=100 (the default) 

1356 this processing step normalizes data to the [0, 1] interval.

1357 For other percentiles the normalized values will partially be outside the [0, 1] 

1358 interval. Use `ScaleRange` followed by `ClipDescr` if you want to limit the

1359 normalized values to a range. 

1360 """ 

1361 

1362 axes: Annotated[ 

1363 Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")]) 

1364 ] = None 

1365 """The subset of axes to normalize jointly, i.e. axes to reduce to compute the min/max percentile value. 

1366 For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x') 

1367 resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`. 

1368 To normalize samples independently, leave out the "batch" axis. 

1369 Default: Scale all axes jointly.""" 

1370 

1371 min_percentile: Annotated[float, Interval(ge=0, lt=100)] = 0.0 

1372 """The lower percentile used to determine the value to align with zero.""" 

1373 

1374 max_percentile: Annotated[float, Interval(gt=1, le=100)] = 100.0 

1375 """The upper percentile used to determine the value to align with one. 

1376 Has to be bigger than `min_percentile`. 

1377 The range is 1 to 100 instead of 0 to 100 to avoid mistakenly 

1378 accepting percentiles specified in the range 0.0 to 1.0.""" 

1379 

1380 eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6 

1381 """Epsilon for numeric stability. 

1382 `out = (tensor - v_lower) / (v_upper - v_lower + eps)`; 

1383 with `v_lower,v_upper` values at the respective percentiles.""" 

1384 

1385 reference_tensor: Optional[TensorId] = None 

1386 """Tensor ID to compute the percentiles from. Default: The tensor itself. 

1387 For any tensor in `inputs` only input tensor references are allowed.""" 

1388 

1389 @field_validator("max_percentile", mode="after") 

1390 @classmethod 

1391 def min_smaller_max(cls, value: float, info: ValidationInfo) -> float: 

1392 if (min_p := info.data["min_percentile"]) >= value: 

1393 raise ValueError(f"min_percentile {min_p} >= max_percentile {value}") 

1394 

1395 return value 

1396 

1397 

1398class ScaleRangeDescr(ProcessingDescrBase): 

1399 """Scale with percentiles. 

1400 

1401 Examples: 

1402 1. Scale linearly to map 5th percentile to 0 and 99.8th percentile to 1.0 

1403 - in YAML 

1404 ```yaml 

1405 preprocessing: 

1406 - id: scale_range 

1407 kwargs: 

1408 axes: ['y', 'x'] 

1409 max_percentile: 99.8 

1410 min_percentile: 5.0 

1411 ``` 

1412 - in Python 

1413 >>> preprocessing = [ 

1414 ... ScaleRangeDescr( 

1415 ... kwargs=ScaleRangeKwargs( 

1416 ... axes= (AxisId('y'), AxisId('x')), 

1417 ... max_percentile= 99.8, 

1418 ... min_percentile= 5.0, 

1419 ... ) 

1420 ... ), 

1421 ... ClipDescr( 

1422 ... kwargs=ClipKwargs( 

1423 ... min=0.0, 

1424 ... max=1.0, 

1425 ... ) 

1426 ... ), 

1427 ... ] 

1428 

1429 2. Combine the above scaling with additional clipping to clip values outside the range given by the percentiles. 

1430 - in YAML 

1431 ```yaml 

1432 preprocessing: 

1433 - id: scale_range 

1434 kwargs: 

1435 axes: ['y', 'x'] 

1436 max_percentile: 99.8 

1437 min_percentile: 5.0 

1438

1439 - id: clip 

1440 kwargs: 

1441 min: 0.0 

1442 max: 1.0 

1443 ``` 

1444 - in Python 

1445 >>> preprocessing = [ScaleRangeDescr( 

1446 ... kwargs=ScaleRangeKwargs( 

1447 ... axes= (AxisId('y'), AxisId('x')), 

1448 ... max_percentile= 99.8, 

1449 ... min_percentile= 5.0, 

1450 ... ) 

1451 ... )] 

1452 

1453 """ 

1454 

1455 implemented_id: ClassVar[Literal["scale_range"]] = "scale_range" 

1456 if TYPE_CHECKING: 

1457 id: Literal["scale_range"] = "scale_range" 

1458 else: 

1459 id: Literal["scale_range"] 

1460 kwargs: ScaleRangeKwargs = Field(default_factory=ScaleRangeKwargs.model_construct) 

1461 

1462 

1463class ScaleMeanVarianceKwargs(ProcessingKwargs): 

1464 """key word arguments for `ScaleMeanVarianceKwargs`""" 

1465 

1466 reference_tensor: TensorId 

1467 """Name of tensor to match.""" 

1468 

1469 axes: Annotated[ 

1470 Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")]) 

1471 ] = None 

1472 """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std. 

1473 For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x') 

1474 resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`. 

1475 To normalize samples independently, leave out the 'batch' axis. 

1476 Default: Scale all axes jointly.""" 

1477 

1478 eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6 

1479 """Epsilon for numeric stability: 

1480 `out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.`""" 

1481 

1482 

1483class ScaleMeanVarianceDescr(ProcessingDescrBase): 

1484 """Scale a tensor's data distribution to match another tensor's mean/std. 

1485 `out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.` 

1486 """ 

1487 

1488 implemented_id: ClassVar[Literal["scale_mean_variance"]] = "scale_mean_variance" 

1489 if TYPE_CHECKING: 

1490 id: Literal["scale_mean_variance"] = "scale_mean_variance" 

1491 else: 

1492 id: Literal["scale_mean_variance"] 

1493 kwargs: ScaleMeanVarianceKwargs 

1494 

1495 

1496PreprocessingDescr = Annotated[ 

1497 Union[ 

1498 BinarizeDescr, 

1499 ClipDescr, 

1500 EnsureDtypeDescr, 

1501 FixedZeroMeanUnitVarianceDescr, 

1502 ScaleLinearDescr, 

1503 ScaleRangeDescr, 

1504 SigmoidDescr, 

1505 SoftmaxDescr, 

1506 ZeroMeanUnitVarianceDescr, 

1507 ], 

1508 Discriminator("id"), 

1509] 

1510PostprocessingDescr = Annotated[ 

1511 Union[ 

1512 BinarizeDescr, 

1513 ClipDescr, 

1514 EnsureDtypeDescr, 

1515 FixedZeroMeanUnitVarianceDescr, 

1516 ScaleLinearDescr, 

1517 ScaleMeanVarianceDescr, 

1518 ScaleRangeDescr, 

1519 SigmoidDescr, 

1520 SoftmaxDescr, 

1521 ZeroMeanUnitVarianceDescr, 

1522 ], 

1523 Discriminator("id"), 

1524] 

1525 

1526IO_AxisT = TypeVar("IO_AxisT", InputAxis, OutputAxis) 

1527 

1528 

1529class TensorDescrBase(Node, Generic[IO_AxisT]): 

1530 id: TensorId 

1531 """Tensor id. No duplicates are allowed.""" 

1532 

1533 description: Annotated[str, MaxLen(128)] = "" 

1534 """free text description""" 

1535 

1536 axes: NotEmpty[Sequence[IO_AxisT]] 

1537 """tensor axes""" 

1538 

1539 @property 

1540 def shape(self): 

1541 return tuple(a.size for a in self.axes) 

1542 

1543 @field_validator("axes", mode="after", check_fields=False) 

1544 @classmethod 

1545 def _validate_axes(cls, axes: Sequence[AnyAxis]) -> Sequence[AnyAxis]: 

1546 batch_axes = [a for a in axes if a.type == "batch"] 

1547 if len(batch_axes) > 1: 

1548 raise ValueError( 

1549 f"Only one batch axis (per tensor) allowed, but got {batch_axes}" 

1550 ) 

1551 

1552 seen_ids: Set[AxisId] = set() 

1553 duplicate_axes_ids: Set[AxisId] = set() 

1554 for a in axes: 

1555 (duplicate_axes_ids if a.id in seen_ids else seen_ids).add(a.id) 

1556 

1557 if duplicate_axes_ids: 

1558 raise ValueError(f"Duplicate axis ids: {duplicate_axes_ids}") 

1559 

1560 return axes 

1561 

1562 test_tensor: FAIR[Optional[FileDescr_]] = None 

1563 """An example tensor to use for testing. 

1564 Using the model with the test input tensors is expected to yield the test output tensors. 

1565 Each test tensor has to be an ndarray in the

1566 [numpy.lib file format](https://numpy.org/doc/stable/reference/generated/numpy.lib.format.html#module-numpy.lib.format). 

1567 The file extension must be '.npy'.""" 

1568 

1569 sample_tensor: FAIR[Optional[FileDescr_]] = None 

1570 """A sample tensor to illustrate a possible input/output for the model, 

1571 The sample image primarily serves to inform a human user about an example use case 

1572 and is typically stored as .hdf5, .png or .tiff. 

1573 It has to be readable by the [imageio library](https://imageio.readthedocs.io/en/stable/formats/index.html#supported-formats) 

1574 (numpy's `.npy` format is not supported). 

1575 The image dimensionality has to match the number of axes specified in this tensor description. 

1576 """ 

1577 

1578 @model_validator(mode="after") 

1579 def _validate_sample_tensor(self) -> Self: 

1580 if self.sample_tensor is None or not get_validation_context().perform_io_checks: 

1581 return self 

1582 

1583 reader = get_reader(self.sample_tensor.source, sha256=self.sample_tensor.sha256) 

1584 tensor: NDArray[Any] = imread( 

1585 reader.read(), 

1586 extension=PurePosixPath(reader.original_file_name).suffix, 

1587 ) 

1588 n_dims = len(tensor.squeeze().shape) 

1589 n_dims_min = n_dims_max = len(self.axes) 

1590 

1591 for a in self.axes: 

1592 if isinstance(a, BatchAxis): 

1593 n_dims_min -= 1 

1594 elif isinstance(a.size, int): 

1595 if a.size == 1: 

1596 n_dims_min -= 1 

1597 elif isinstance(a.size, (ParameterizedSize, DataDependentSize)): 

1598 if a.size.min == 1: 

1599 n_dims_min -= 1 

1600 elif isinstance(a.size, SizeReference): 

1601 if a.size.offset < 2: 

1602 # size reference may result in singleton axis 

1603 n_dims_min -= 1 

1604 else: 

1605 assert_never(a.size) 

1606 

1607 n_dims_min = max(0, n_dims_min) 

1608 if n_dims < n_dims_min or n_dims > n_dims_max: 

1609 raise ValueError( 

1610 f"Expected sample tensor to have {n_dims_min} to" 

1611 + f" {n_dims_max} dimensions, but found {n_dims} (shape: {tensor.shape})." 

1612 ) 

1613 

1614 return self 

1615 

1616 data: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]] = ( 

1617 IntervalOrRatioDataDescr() 

1618 ) 

1619 """Description of the tensor's data values, optionally per channel. 

1620 If specified per channel, the data `type` needs to match across channels.""" 

1621 

1622 @property 

1623 def dtype( 

1624 self, 

1625 ) -> Literal[ 

1626 "float32", 

1627 "float64", 

1628 "uint8", 

1629 "int8", 

1630 "uint16", 

1631 "int16", 

1632 "uint32", 

1633 "int32", 

1634 "uint64", 

1635 "int64", 

1636 "bool", 

1637 ]: 

1638 """dtype as specified under `data.type` or `data[i].type`""" 

1639 if isinstance(self.data, collections.abc.Sequence): 

1640 return self.data[0].type 

1641 else: 

1642 return self.data.type 

1643 

1644 @field_validator("data", mode="after") 

1645 @classmethod 

1646 def _check_data_type_across_channels( 

1647 cls, value: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]] 

1648 ) -> Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]]: 

1649 if not isinstance(value, list): 

1650 return value 

1651 

1652 dtypes = {t.type for t in value} 

1653 if len(dtypes) > 1: 

1654 raise ValueError( 

1655 "Tensor data descriptions per channel need to agree in their data" 

1656 + f" `type`, but found {dtypes}." 

1657 ) 

1658 

1659 return value 

1660 

1661 @model_validator(mode="after") 

1662 def _check_data_matches_channelaxis(self) -> Self: 

1663 if not isinstance(self.data, (list, tuple)): 

1664 return self 

1665 

1666 for a in self.axes: 

1667 if isinstance(a, ChannelAxis): 

1668 size = a.size 

1669 assert isinstance(size, int) 

1670 break 

1671 else: 

1672 return self 

1673 

1674 if len(self.data) != size: 

1675 raise ValueError( 

1676 f"Got tensor data descriptions for {len(self.data)} channels, but" 

1677 + f" '{a.id}' axis has size {size}." 

1678 ) 

1679 

1680 return self 

1681 

1682 def get_axis_sizes_for_array(self, array: NDArray[Any]) -> Dict[AxisId, int]: 

1683 if len(array.shape) != len(self.axes): 

1684 raise ValueError( 

1685 f"Dimension mismatch: array shape {array.shape} (#{len(array.shape)})" 

1686 + f" incompatible with {len(self.axes)} axes." 

1687 ) 

1688 return {a.id: array.shape[i] for i, a in enumerate(self.axes)} 

1689 

1690 

1691class InputTensorDescr(TensorDescrBase[InputAxis]): 

1692 id: TensorId = TensorId("input") 

1693 """Input tensor id. 

1694 No duplicates are allowed across all inputs and outputs.""" 

1695 

1696 optional: bool = False 

1697 """indicates that this tensor may be `None`""" 

1698 

1699 preprocessing: List[PreprocessingDescr] = Field( 

1700 default_factory=cast(Callable[[], List[PreprocessingDescr]], list) 

1701 ) 

1702 

1703 """Description of how this input should be preprocessed. 

1704 

1705 notes: 

1706 - If preprocessing does not start with an 'ensure_dtype' entry, it is added 

1707 to ensure an input tensor's data type matches the input tensor's data description. 

1708 - If preprocessing does not end with an 'ensure_dtype' or 'binarize' entry, an 

1709 'ensure_dtype' step is added to ensure preprocessing steps are not unintentionally 

1710 changing the data type. 

1711 """ 

1712 

1713 @model_validator(mode="after") 

1714 def _validate_preprocessing_kwargs(self) -> Self: 

1715 axes_ids = [a.id for a in self.axes] 

1716 for p in self.preprocessing: 

1717 kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes") 

1718 if kwargs_axes is None: 

1719 continue 

1720 

1721 if not isinstance(kwargs_axes, collections.abc.Sequence): 

1722 raise ValueError( 

1723 f"Expected `preprocessing.i.kwargs.axes` to be a sequence, but got {type(kwargs_axes)}" 

1724 ) 

1725 

1726 if any(a not in axes_ids for a in kwargs_axes): 

1727 raise ValueError( 

1728 "`preprocessing.i.kwargs.axes` needs to be subset of axes ids" 

1729 ) 

1730 

1731 if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)): 

1732 dtype = self.data.type 

1733 else: 

1734 dtype = self.data[0].type 

1735 

1736 # ensure `preprocessing` begins with `EnsureDtypeDescr` 

1737 if not self.preprocessing or not isinstance( 

1738 self.preprocessing[0], EnsureDtypeDescr 

1739 ): 

1740 self.preprocessing.insert( 

1741 0, EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype)) 

1742 ) 

1743 

1744 # ensure `preprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr` 

1745 if not isinstance(self.preprocessing[-1], (EnsureDtypeDescr, BinarizeDescr)): 

1746 self.preprocessing.append( 

1747 EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype)) 

1748 ) 

1749 

1750 return self 

1751 

1752 

1753def convert_axes( 

1754 axes: str, 

1755 *, 

1756 shape: Union[ 

1757 Sequence[int], _ParameterizedInputShape_v0_4, _ImplicitOutputShape_v0_4 

1758 ], 

1759 tensor_type: Literal["input", "output"], 

1760 halo: Optional[Sequence[int]], 

1761 size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]], 

1762): 

1763 ret: List[AnyAxis] = [] 

1764 for i, a in enumerate(axes): 

1765 axis_type = _AXIS_TYPE_MAP.get(a, a) 

1766 if axis_type == "batch": 

1767 ret.append(BatchAxis()) 

1768 continue 

1769 

1770 scale = 1.0 

1771 if isinstance(shape, _ParameterizedInputShape_v0_4): 

1772 if shape.step[i] == 0: 

1773 size = shape.min[i] 

1774 else: 

1775 size = ParameterizedSize(min=shape.min[i], step=shape.step[i]) 

1776 elif isinstance(shape, _ImplicitOutputShape_v0_4): 

1777 ref_t = str(shape.reference_tensor) 

1778 if ref_t.count(".") == 1: 

1779 t_id, orig_a_id = ref_t.split(".") 

1780 else: 

1781 t_id = ref_t 

1782 orig_a_id = a 

1783 

1784 a_id = _AXIS_ID_MAP.get(orig_a_id, a) 

1785 if not (orig_scale := shape.scale[i]): 

1786 # old way to insert a new axis dimension 

1787 size = int(2 * shape.offset[i]) 

1788 else: 

1789 scale = 1 / orig_scale 

1790 if axis_type in ("channel", "index"): 

1791 # these axes no longer have a scale 

1792 offset_from_scale = orig_scale * size_refs.get( 

1793 _TensorName_v0_4(t_id), {} 

1794 ).get(orig_a_id, 0) 

1795 else: 

1796 offset_from_scale = 0 

1797 size = SizeReference( 

1798 tensor_id=TensorId(t_id), 

1799 axis_id=AxisId(a_id), 

1800 offset=int(offset_from_scale + 2 * shape.offset[i]), 

1801 ) 

1802 else: 

1803 size = shape[i] 

1804 

1805 if axis_type == "time": 

1806 if tensor_type == "input": 

1807 ret.append(TimeInputAxis(size=size, scale=scale)) 

1808 else: 

1809 assert not isinstance(size, ParameterizedSize) 

1810 if halo is None: 

1811 ret.append(TimeOutputAxis(size=size, scale=scale)) 

1812 else: 

1813 assert not isinstance(size, int) 

1814 ret.append( 

1815 TimeOutputAxisWithHalo(size=size, scale=scale, halo=halo[i]) 

1816 ) 

1817 

1818 elif axis_type == "index": 

1819 if tensor_type == "input": 

1820 ret.append(IndexInputAxis(size=size)) 

1821 else: 

1822 if isinstance(size, ParameterizedSize): 

1823 size = DataDependentSize(min=size.min) 

1824 

1825 ret.append(IndexOutputAxis(size=size)) 

1826 elif axis_type == "channel": 

1827 assert not isinstance(size, ParameterizedSize) 

1828 if isinstance(size, SizeReference): 

1829 warnings.warn( 

1830 "Conversion of channel size from an implicit output shape may be" 

1831 + " wrong" 

1832 ) 

1833 ret.append( 

1834 ChannelAxis( 

1835 channel_names=[ 

1836 Identifier(f"channel{i}") for i in range(size.offset) 

1837 ] 

1838 ) 

1839 ) 

1840 else: 

1841 ret.append( 

1842 ChannelAxis( 

1843 channel_names=[Identifier(f"channel{i}") for i in range(size)] 

1844 ) 

1845 ) 

1846 elif axis_type == "space": 

1847 if tensor_type == "input": 

1848 ret.append(SpaceInputAxis(id=AxisId(a), size=size, scale=scale)) 

1849 else: 

1850 assert not isinstance(size, ParameterizedSize) 

1851 if halo is None or halo[i] == 0: 

1852 ret.append(SpaceOutputAxis(id=AxisId(a), size=size, scale=scale)) 

1853 elif isinstance(size, int): 

1854 raise NotImplementedError( 

1855 f"output axis with halo and fixed size (here {size}) not allowed" 

1856 ) 

1857 else: 

1858 ret.append( 

1859 SpaceOutputAxisWithHalo( 

1860 id=AxisId(a), size=size, scale=scale, halo=halo[i] 

1861 ) 

1862 ) 

1863 

1864 return ret 

1865 
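# --- Illustrative usage sketch (not part of this module) ---
# Assumes `_AXIS_TYPE_MAP` maps the common v0.4 letters "b", "c", "y", "x";
# the axes string and fixed shape below are hypothetical.
_example_input_axes = convert_axes(
    "bcyx",
    shape=[1, 3, 256, 256],
    tensor_type="input",
    halo=None,
    size_refs={},
)
# expected result (roughly): [BatchAxis(), ChannelAxis(channel_names=[...]),
#                             SpaceInputAxis(id=AxisId("y"), size=256),
#                             SpaceInputAxis(id=AxisId("x"), size=256)]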

1866 

1867def _axes_letters_to_ids( 

1868 axes: Optional[str], 

1869) -> Optional[List[AxisId]]: 

1870 if axes is None: 

1871 return None 

1872 

1873 return [AxisId(a) for a in axes] 

1874 

1875 

1876def _get_complement_v04_axis( 

1877 tensor_axes: Sequence[str], axes: Optional[Sequence[str]] 

1878) -> Optional[AxisId]: 

1879 if axes is None: 

1880 return None 

1881 

1882 non_complement_axes = set(axes) | {"b"} 

1883 complement_axes = [a for a in tensor_axes if a not in non_complement_axes] 

1884 if len(complement_axes) > 1: 

1885 raise ValueError( 

1886 f"Expected none or a single complement axis, but axes '{axes}' " 

1887 + f"for tensor dims '{tensor_axes}' leave '{complement_axes}'." 

1888 ) 

1889 

1890 return None if not complement_axes else AxisId(complement_axes[0]) 

1891 
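# Illustrative (not part of this module): for hypothetical v0.4 tensor axes "bcyx"
# and kwargs axes ("y", "x"), the only remaining non-batch axis is "c":
_example_complement = _get_complement_v04_axis("bcyx", ["y", "x"])  # AxisId("c")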

1892 

1893def _convert_proc( 

1894 p: Union[_PreprocessingDescr_v0_4, _PostprocessingDescr_v0_4], 

1895 tensor_axes: Sequence[str], 

1896) -> Union[PreprocessingDescr, PostprocessingDescr]: 

1897 if isinstance(p, _BinarizeDescr_v0_4): 

1898 return BinarizeDescr(kwargs=BinarizeKwargs(threshold=p.kwargs.threshold)) 

1899 elif isinstance(p, _ClipDescr_v0_4): 

1900 return ClipDescr(kwargs=ClipKwargs(min=p.kwargs.min, max=p.kwargs.max)) 

1901 elif isinstance(p, _SigmoidDescr_v0_4): 

1902 return SigmoidDescr() 

1903 elif isinstance(p, _ScaleLinearDescr_v0_4): 

1904 axes = _axes_letters_to_ids(p.kwargs.axes) 

1905 if p.kwargs.axes is None: 

1906 axis = None 

1907 else: 

1908 axis = _get_complement_v04_axis(tensor_axes, p.kwargs.axes) 

1909 

1910 if axis is None: 

1911 assert not isinstance(p.kwargs.gain, list) 

1912 assert not isinstance(p.kwargs.offset, list) 

1913 kwargs = ScaleLinearKwargs(gain=p.kwargs.gain, offset=p.kwargs.offset) 

1914 else: 

1915 kwargs = ScaleLinearAlongAxisKwargs( 

1916 axis=axis, gain=p.kwargs.gain, offset=p.kwargs.offset 

1917 ) 

1918 return ScaleLinearDescr(kwargs=kwargs) 

1919 elif isinstance(p, _ScaleMeanVarianceDescr_v0_4): 

1920 return ScaleMeanVarianceDescr( 

1921 kwargs=ScaleMeanVarianceKwargs( 

1922 axes=_axes_letters_to_ids(p.kwargs.axes), 

1923 reference_tensor=TensorId(str(p.kwargs.reference_tensor)), 

1924 eps=p.kwargs.eps, 

1925 ) 

1926 ) 

1927 elif isinstance(p, _ZeroMeanUnitVarianceDescr_v0_4): 

1928 if p.kwargs.mode == "fixed": 

1929 mean = p.kwargs.mean 

1930 std = p.kwargs.std 

1931 assert mean is not None 

1932 assert std is not None 

1933 

1934 axis = _get_complement_v04_axis(tensor_axes, p.kwargs.axes) 

1935 

1936 if axis is None: 

1937 return FixedZeroMeanUnitVarianceDescr( 

1938 kwargs=FixedZeroMeanUnitVarianceKwargs( 

1939 mean=mean, std=std # pyright: ignore[reportArgumentType] 

1940 ) 

1941 ) 

1942 else: 

1943 if not isinstance(mean, list): 

1944 mean = [float(mean)] 

1945 if not isinstance(std, list): 

1946 std = [float(std)] 

1947 

1948 return FixedZeroMeanUnitVarianceDescr( 

1949 kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs( 

1950 axis=axis, mean=mean, std=std 

1951 ) 

1952 ) 

1953 

1954 else: 

1955 axes = _axes_letters_to_ids(p.kwargs.axes) or [] 

1956 if p.kwargs.mode == "per_dataset": 

1957 axes = [AxisId("batch")] + axes 

1958 if not axes: 

1959 axes = None 

1960 return ZeroMeanUnitVarianceDescr( 

1961 kwargs=ZeroMeanUnitVarianceKwargs(axes=axes, eps=p.kwargs.eps) 

1962 ) 

1963 

1964 elif isinstance(p, _ScaleRangeDescr_v0_4): 

1965 return ScaleRangeDescr( 

1966 kwargs=ScaleRangeKwargs( 

1967 axes=_axes_letters_to_ids(p.kwargs.axes), 

1968 min_percentile=p.kwargs.min_percentile, 

1969 max_percentile=p.kwargs.max_percentile, 

1970 eps=p.kwargs.eps, 

1971 ) 

1972 ) 

1973 else: 

1974 assert_never(p) 

1975 

1976 

1977class _InputTensorConv( 

1978 Converter[ 

1979 _InputTensorDescr_v0_4, 

1980 InputTensorDescr, 

1981 FileSource_, 

1982 Optional[FileSource_], 

1983 Mapping[_TensorName_v0_4, Mapping[str, int]], 

1984 ] 

1985): 

1986 def _convert( 

1987 self, 

1988 src: _InputTensorDescr_v0_4, 

1989 tgt: "type[InputTensorDescr] | type[dict[str, Any]]", 

1990 test_tensor: FileSource_, 

1991 sample_tensor: Optional[FileSource_], 

1992 size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]], 

1993 ) -> "InputTensorDescr | dict[str, Any]": 

1994 axes: List[InputAxis] = convert_axes( # pyright: ignore[reportAssignmentType] 

1995 src.axes, 

1996 shape=src.shape, 

1997 tensor_type="input", 

1998 halo=None, 

1999 size_refs=size_refs, 

2000 ) 

2001 prep: List[PreprocessingDescr] = [] 

2002 for p in src.preprocessing: 

2003 cp = _convert_proc(p, src.axes) 

2004 assert not isinstance(cp, ScaleMeanVarianceDescr) 

2005 prep.append(cp) 

2006 

2007 prep.append(EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="float32"))) 

2008 

2009 return tgt( 

2010 axes=axes, 

2011 id=TensorId(str(src.name)), 

2012 test_tensor=FileDescr(source=test_tensor), 

2013 sample_tensor=( 

2014 None if sample_tensor is None else FileDescr(source=sample_tensor) 

2015 ), 

2016 data=dict(type=src.data_type), # pyright: ignore[reportArgumentType] 

2017 preprocessing=prep, 

2018 ) 

2019 

2020 

2021_input_tensor_conv = _InputTensorConv(_InputTensorDescr_v0_4, InputTensorDescr) 

2022 

2023 

2024class OutputTensorDescr(TensorDescrBase[OutputAxis]): 

2025 id: TensorId = TensorId("output") 

2026 """Output tensor id. 

2027 No duplicates are allowed across all inputs and outputs.""" 

2028 

2029 postprocessing: List[PostprocessingDescr] = Field( 

2030 default_factory=cast(Callable[[], List[PostprocessingDescr]], list) 

2031 ) 

2032 """Description of how this output should be postprocessed. 

2033 

2034 note: `postprocessing` always ends with an 'ensure_dtype' operation. 

2035 If not given this is added to cast to this tensor's `data.type`. 

2036 """ 

2037 

2038 @model_validator(mode="after") 

2039 def _validate_postprocessing_kwargs(self) -> Self: 

2040 axes_ids = [a.id for a in self.axes] 

2041 for p in self.postprocessing: 

2042 kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes") 

2043 if kwargs_axes is None: 

2044 continue 

2045 

2046 if not isinstance(kwargs_axes, collections.abc.Sequence): 

2047 raise ValueError( 

2048 f"expected `axes` sequence, but got {type(kwargs_axes)}" 

2049 ) 

2050 

2051 if any(a not in axes_ids for a in kwargs_axes): 

2052 raise ValueError("`kwargs.axes` needs to be a subset of the axes ids") 

2053 

2054 if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)): 

2055 dtype = self.data.type 

2056 else: 

2057 dtype = self.data[0].type 

2058 

2059 # ensure `postprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr` 

2060 if not self.postprocessing or not isinstance( 

2061 self.postprocessing[-1], (EnsureDtypeDescr, BinarizeDescr) 

2062 ): 

2063 self.postprocessing.append( 

2064 EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype)) 

2065 ) 

2066 return self 

2067 

2068 

2069class _OutputTensorConv( 

2070 Converter[ 

2071 _OutputTensorDescr_v0_4, 

2072 OutputTensorDescr, 

2073 FileSource_, 

2074 Optional[FileSource_], 

2075 Mapping[_TensorName_v0_4, Mapping[str, int]], 

2076 ] 

2077): 

2078 def _convert( 

2079 self, 

2080 src: _OutputTensorDescr_v0_4, 

2081 tgt: "type[OutputTensorDescr] | type[dict[str, Any]]", 

2082 test_tensor: FileSource_, 

2083 sample_tensor: Optional[FileSource_], 

2084 size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]], 

2085 ) -> "OutputTensorDescr | dict[str, Any]": 

2086 # TODO: split convert_axes into convert_output_axes and convert_input_axes 

2087 axes: List[OutputAxis] = convert_axes( # pyright: ignore[reportAssignmentType] 

2088 src.axes, 

2089 shape=src.shape, 

2090 tensor_type="output", 

2091 halo=src.halo, 

2092 size_refs=size_refs, 

2093 ) 

2094 data_descr: Dict[str, Any] = dict(type=src.data_type) 

2095 if data_descr["type"] == "bool": 

2096 data_descr["values"] = [False, True] 

2097 

2098 return tgt( 

2099 axes=axes, 

2100 id=TensorId(str(src.name)), 

2101 test_tensor=FileDescr(source=test_tensor), 

2102 sample_tensor=( 

2103 None if sample_tensor is None else FileDescr(source=sample_tensor) 

2104 ), 

2105 data=data_descr, # pyright: ignore[reportArgumentType] 

2106 postprocessing=[_convert_proc(p, src.axes) for p in src.postprocessing], 

2107 ) 

2108 

2109 

2110_output_tensor_conv = _OutputTensorConv(_OutputTensorDescr_v0_4, OutputTensorDescr) 

2111 

2112 

2113TensorDescr = Union[InputTensorDescr, OutputTensorDescr] 

2114 

2115 

2116def validate_tensors( 

2117 tensors: Mapping[TensorId, Tuple[TensorDescr, Optional[NDArray[Any]]]], 

2118 tensor_origin: Literal[ 

2119 "test_tensor" 

2120 ], # for more precise error messages, e.g. 'test_tensor' 

2121): 

2122 all_tensor_axes: Dict[TensorId, Dict[AxisId, Tuple[AnyAxis, Optional[int]]]] = {} 

2123 

2124 def e_msg(d: TensorDescr): 

2125 return f"{'inputs' if isinstance(d, InputTensorDescr) else 'outputs'}[{d.id}]" 

2126 

2127 for descr, array in tensors.values(): 

2128 if array is None: 

2129 axis_sizes = {a.id: None for a in descr.axes} 

2130 else: 

2131 try: 

2132 axis_sizes = descr.get_axis_sizes_for_array(array) 

2133 except ValueError as e: 

2134 raise ValueError(f"{e_msg(descr)} {e}") 

2135 

2136 all_tensor_axes[descr.id] = {a.id: (a, axis_sizes[a.id]) for a in descr.axes} 

2137 

2138 for descr, array in tensors.values(): 

2139 if array is None: 

2140 continue 

2141 

2142 if descr.dtype in ("float32", "float64"): 

2143 invalid_test_tensor_dtype = array.dtype.name not in ( 

2144 "float32", 

2145 "float64", 

2146 "uint8", 

2147 "int8", 

2148 "uint16", 

2149 "int16", 

2150 "uint32", 

2151 "int32", 

2152 "uint64", 

2153 "int64", 

2154 ) 

2155 else: 

2156 invalid_test_tensor_dtype = array.dtype.name != descr.dtype 

2157 

2158 if invalid_test_tensor_dtype: 

2159 raise ValueError( 

2160 f"{e_msg(descr)}.{tensor_origin}.dtype '{array.dtype.name}' does not" 

2161 + f" match described dtype '{descr.dtype}'" 

2162 ) 

2163 

2164 if array.min() > -1e-4 and array.max() < 1e-4: 

2165 raise ValueError( 

2166 "Output values are too small for reliable testing." 

2167 + f" Values <-1e5 or >=1e5 must be present in {tensor_origin}" 

2168 ) 

2169 

2170 for a in descr.axes: 

2171 actual_size = all_tensor_axes[descr.id][a.id][1] 

2172 if actual_size is None: 

2173 continue 

2174 

2175 if a.size is None: 

2176 continue 

2177 

2178 if isinstance(a.size, int): 

2179 if actual_size != a.size: 

2180 raise ValueError( 

2181 f"{e_msg(descr)}.{tensor_origin}: axis '{a.id}' " 

2182 + f"has incompatible size {actual_size}, expected {a.size}" 

2183 ) 

2184 elif isinstance(a.size, ParameterizedSize): 

2185 _ = a.size.validate_size(actual_size) 

2186 elif isinstance(a.size, DataDependentSize): 

2187 _ = a.size.validate_size(actual_size) 

2188 elif isinstance(a.size, SizeReference): 

2189 ref_tensor_axes = all_tensor_axes.get(a.size.tensor_id) 

2190 if ref_tensor_axes is None: 

2191 raise ValueError( 

2192 f"{e_msg(descr)}.axes[{a.id}].size.tensor_id: Unknown tensor" 

2193 + f" reference '{a.size.tensor_id}'" 

2194 ) 

2195 

2196 ref_axis, ref_size = ref_tensor_axes.get(a.size.axis_id, (None, None)) 

2197 if ref_axis is None or ref_size is None: 

2198 raise ValueError( 

2199 f"{e_msg(descr)}.axes[{a.id}].size.axis_id: Unknown tensor axis" 

2200 + f" reference '{a.size.tensor_id}.{a.size.axis_id}" 

2201 ) 

2202 

2203 if a.unit != ref_axis.unit: 

2204 raise ValueError( 

2205 f"{e_msg(descr)}.axes[{a.id}].size: `SizeReference` requires" 

2206 + " axis and reference axis to have the same `unit`, but" 

2207 + f" {a.unit}!={ref_axis.unit}" 

2208 ) 

2209 

2210 if actual_size != ( 

2211 expected_size := ( 

2212 ref_size * ref_axis.scale / a.scale + a.size.offset 

2213 ) 

2214 ): 

2215 raise ValueError( 

2216 f"{e_msg(descr)}.{tensor_origin}: axis '{a.id}' of size" 

2217 + f" {actual_size} invalid for referenced size {ref_size};" 

2218 + f" expected {expected_size}" 

2219 ) 

2220 else: 

2221 assert_never(a.size) 

2222 

2223 

2224FileDescr_dependencies = Annotated[ 

2225 FileDescr_, 

2226 WithSuffix((".yaml", ".yml"), case_sensitive=True), 

2227 Field(examples=[dict(source="environment.yaml")]), 

2228] 

2229 

2230 

2231class _ArchitectureCallableDescr(Node): 

2232 callable: Annotated[Identifier, Field(examples=["MyNetworkClass", "get_my_model"])] 

2233 """Identifier of the callable that returns a torch.nn.Module instance.""" 

2234 

2235 kwargs: Dict[str, YamlValue] = Field( 

2236 default_factory=cast(Callable[[], Dict[str, YamlValue]], dict) 

2237 ) 

2238 """key word arguments for the `callable`""" 

2239 

2240 

2241class ArchitectureFromFileDescr(_ArchitectureCallableDescr, FileDescr): 

2242 source: Annotated[FileSource, AfterValidator(wo_special_file_name)] 

2243 """Architecture source file""" 

2244 

2245 @model_serializer(mode="wrap", when_used="unless-none") 

2246 def _serialize(self, nxt: SerializerFunctionWrapHandler, info: SerializationInfo): 

2247 return package_file_descr_serializer(self, nxt, info) 

2248 

2249 

2250class ArchitectureFromLibraryDescr(_ArchitectureCallableDescr): 

2251 import_from: str 

2252 """Where to import the callable from, i.e. `from <import_from> import <callable>`""" 

2253 
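# Illustrative sketch (not part of this module): describing an architecture importable
# from an installed package; the module, class name and kwargs are hypothetical.
_example_arch = ArchitectureFromLibraryDescr(
    import_from="mynet.models",
    callable=Identifier("UNet2d"),
    kwargs={"in_channels": 1, "out_channels": 2},
)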

2254 

2255class _ArchFileConv( 

2256 Converter[ 

2257 _CallableFromFile_v0_4, 

2258 ArchitectureFromFileDescr, 

2259 Optional[Sha256], 

2260 Dict[str, Any], 

2261 ] 

2262): 

2263 def _convert( 

2264 self, 

2265 src: _CallableFromFile_v0_4, 

2266 tgt: "type[ArchitectureFromFileDescr | dict[str, Any]]", 

2267 sha256: Optional[Sha256], 

2268 kwargs: Dict[str, Any], 

2269 ) -> "ArchitectureFromFileDescr | dict[str, Any]": 

2270 if src.startswith("http") and src.count(":") == 2: 

2271 http, source, callable_ = src.split(":") 

2272 source = ":".join((http, source)) 

2273 elif not src.startswith("http") and src.count(":") == 1: 

2274 source, callable_ = src.split(":") 

2275 else: 

2276 source = str(src) 

2277 callable_ = str(src) 

2278 return tgt( 

2279 callable=Identifier(callable_), 

2280 source=cast(FileSource_, source), 

2281 sha256=sha256, 

2282 kwargs=kwargs, 

2283 ) 

2284 

2285 

2286_arch_file_conv = _ArchFileConv(_CallableFromFile_v0_4, ArchitectureFromFileDescr) 

2287 

2288 

2289class _ArchLibConv( 

2290 Converter[ 

2291 _CallableFromDepencency_v0_4, ArchitectureFromLibraryDescr, Dict[str, Any] 

2292 ] 

2293): 

2294 def _convert( 

2295 self, 

2296 src: _CallableFromDepencency_v0_4, 

2297 tgt: "type[ArchitectureFromLibraryDescr | dict[str, Any]]", 

2298 kwargs: Dict[str, Any], 

2299 ) -> "ArchitectureFromLibraryDescr | dict[str, Any]": 

2300 *mods, callable_ = src.split(".") 

2301 import_from = ".".join(mods) 

2302 return tgt( 

2303 import_from=import_from, callable=Identifier(callable_), kwargs=kwargs 

2304 ) 

2305 

2306 

2307_arch_lib_conv = _ArchLibConv( 

2308 _CallableFromDepencency_v0_4, ArchitectureFromLibraryDescr 

2309) 

2310 
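# Note (illustrative): a dotted v0.4 callable such as the hypothetical
# "mynet.models.UNet" is split on its last "." by `_ArchLibConv._convert` above,
# yielding import_from="mynet.models" and callable="UNet".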

2311 

2312class WeightsEntryDescrBase(FileDescr): 

2313 type: ClassVar[WeightsFormat] 

2314 weights_format_name: ClassVar[str] # human readable 

2315 

2316 source: Annotated[FileSource, AfterValidator(wo_special_file_name)] 

2317 """Source of the weights file.""" 

2318 

2319 authors: Optional[List[Author]] = None 

2320 """Authors 

2321 Either the person(s) that trained this model, resulting in the original weights file 

2322 (if this is the initial weights entry, i.e. it does not have a `parent`), 

2323 or the person(s) who converted the weights to this weights format 

2324 (if this is a child weights entry, i.e. it has a `parent` field). 

2325 """ 

2326 

2327 parent: Annotated[ 

2328 Optional[WeightsFormat], Field(examples=["pytorch_state_dict"]) 

2329 ] = None 

2330 """The source weights these weights were converted from. 

2331 For example, if a model's weights were converted from the `pytorch_state_dict` format to `torchscript`, 

2332 the `pytorch_state_dict` weights entry has no `parent` and is the parent of the `torchscript` weights. 

2333 All weight entries except one (the initial set of weights resulting from training the model) 

2334 need to have this field.""" 

2335 

2336 comment: str = "" 

2337 """A comment about this weights entry, for example how these weights were created.""" 

2338 

2339 @model_validator(mode="after") 

2340 def _validate(self) -> Self: 

2341 if self.type == self.parent: 

2342 raise ValueError("Weights entry can't be it's own parent.") 

2343 

2344 return self 

2345 

2346 @model_serializer(mode="wrap", when_used="unless-none") 

2347 def _serialize(self, nxt: SerializerFunctionWrapHandler, info: SerializationInfo): 

2348 return package_file_descr_serializer(self, nxt, info) 

2349 

2350 

2351class KerasHdf5WeightsDescr(WeightsEntryDescrBase): 

2352 type = "keras_hdf5" 

2353 weights_format_name: ClassVar[str] = "Keras HDF5" 

2354 tensorflow_version: Version 

2355 """TensorFlow version used to create these weights.""" 

2356 

2357 

2358class OnnxWeightsDescr(WeightsEntryDescrBase): 

2359 type = "onnx" 

2360 weights_format_name: ClassVar[str] = "ONNX" 

2361 opset_version: Annotated[int, Ge(7)] 

2362 """ONNX opset version""" 

2363 

2364 

2365class PytorchStateDictWeightsDescr(WeightsEntryDescrBase): 

2366 type = "pytorch_state_dict" 

2367 weights_format_name: ClassVar[str] = "Pytorch State Dict" 

2368 architecture: Union[ArchitectureFromFileDescr, ArchitectureFromLibraryDescr] 

2369 pytorch_version: Version 

2370 """Version of the PyTorch library used. 

2371 If `architecture.dependencies` is specified, it has to include pytorch and any version pinning has to be compatible. 

2372 """ 

2373 dependencies: Optional[FileDescr_dependencies] = None 

2374 """Custom depencies beyond pytorch described in a Conda environment file. 

2375 Allows to specify custom dependencies, see conda docs: 

2376 - [Exporting an environment file across platforms](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#exporting-an-environment-file-across-platforms) 

2377 - [Creating an environment file manually](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#creating-an-environment-file-manually) 

2378 

2379 The conda environment file should include pytorch and any version pinning has to be compatible with 

2380 **pytorch_version**. 

2381 """ 

2382 

2383 

2384class TensorflowJsWeightsDescr(WeightsEntryDescrBase): 

2385 type = "tensorflow_js" 

2386 weights_format_name: ClassVar[str] = "Tensorflow.js" 

2387 tensorflow_version: Version 

2388 """Version of the TensorFlow library used.""" 

2389 

2390 source: Annotated[FileSource, AfterValidator(wo_special_file_name)] 

2391 """The multi-file weights. 

2392 All required files/folders should be bundled in a zip archive.""" 

2393 

2394 

2395class TensorflowSavedModelBundleWeightsDescr(WeightsEntryDescrBase): 

2396 type = "tensorflow_saved_model_bundle" 

2397 weights_format_name: ClassVar[str] = "Tensorflow Saved Model" 

2398 tensorflow_version: Version 

2399 """Version of the TensorFlow library used.""" 

2400 

2401 dependencies: Optional[FileDescr_dependencies] = None 

2402 """Custom dependencies beyond tensorflow. 

2403 Should include tensorflow; any version pinning has to be compatible with **tensorflow_version**.""" 

2404 

2405 source: Annotated[FileSource, AfterValidator(wo_special_file_name)] 

2406 """The multi-file weights. 

2407 All required files/folders should be bundled in a zip archive.""" 

2408 

2409 

2410class TorchscriptWeightsDescr(WeightsEntryDescrBase): 

2411 type = "torchscript" 

2412 weights_format_name: ClassVar[str] = "TorchScript" 

2413 pytorch_version: Version 

2414 """Version of the PyTorch library used.""" 

2415 

2416 

2417class WeightsDescr(Node): 

2418 keras_hdf5: Optional[KerasHdf5WeightsDescr] = None 

2419 onnx: Optional[OnnxWeightsDescr] = None 

2420 pytorch_state_dict: Optional[PytorchStateDictWeightsDescr] = None 

2421 tensorflow_js: Optional[TensorflowJsWeightsDescr] = None 

2422 tensorflow_saved_model_bundle: Optional[TensorflowSavedModelBundleWeightsDescr] = ( 

2423 None 

2424 ) 

2425 torchscript: Optional[TorchscriptWeightsDescr] = None 

2426 

2427 @model_validator(mode="after") 

2428 def check_entries(self) -> Self: 

2429 entries = {wtype for wtype, entry in self if entry is not None} 

2430 

2431 if not entries: 

2432 raise ValueError("Missing weights entry") 

2433 

2434 entries_wo_parent = { 

2435 wtype 

2436 for wtype, entry in self 

2437 if entry is not None and hasattr(entry, "parent") and entry.parent is None 

2438 } 

2439 if len(entries_wo_parent) != 1: 

2440 issue_warning( 

2441 "Exactly one weights entry may not specify the `parent` field (got" 

2442 + " {value}). That entry is considered the original set of model weights." 

2443 + " Other weight formats are created through conversion of the orignal or" 

2444 + " already converted weights. They have to reference the weights format" 

2445 + " they were converted from as their `parent`.", 

2446 value=len(entries_wo_parent), 

2447 field="weights", 

2448 ) 

2449 

2450 for wtype, entry in self: 

2451 if entry is None: 

2452 continue 

2453 

2454 assert hasattr(entry, "type") 

2455 assert hasattr(entry, "parent") 

2456 assert wtype == entry.type 

2457 if ( 

2458 entry.parent is not None and entry.parent not in entries 

2459 ): # self reference checked for `parent` field 

2460 raise ValueError( 

2461 f"`weights.{wtype}.parent={entry.parent} not in specified weight" 

2462 + f" formats: {entries}" 

2463 ) 

2464 

2465 return self 

2466 

2467 def __getitem__( 

2468 self, 

2469 key: Literal[ 

2470 "keras_hdf5", 

2471 "onnx", 

2472 "pytorch_state_dict", 

2473 "tensorflow_js", 

2474 "tensorflow_saved_model_bundle", 

2475 "torchscript", 

2476 ], 

2477 ): 

2478 if key == "keras_hdf5": 

2479 ret = self.keras_hdf5 

2480 elif key == "onnx": 

2481 ret = self.onnx 

2482 elif key == "pytorch_state_dict": 

2483 ret = self.pytorch_state_dict 

2484 elif key == "tensorflow_js": 

2485 ret = self.tensorflow_js 

2486 elif key == "tensorflow_saved_model_bundle": 

2487 ret = self.tensorflow_saved_model_bundle 

2488 elif key == "torchscript": 

2489 ret = self.torchscript 

2490 else: 

2491 raise KeyError(key) 

2492 

2493 if ret is None: 

2494 raise KeyError(key) 

2495 

2496 return ret 

2497 

2498 @property 

2499 def available_formats(self): 

2500 return { 

2501 **({} if self.keras_hdf5 is None else {"keras_hdf5": self.keras_hdf5}), 

2502 **({} if self.onnx is None else {"onnx": self.onnx}), 

2503 **( 

2504 {} 

2505 if self.pytorch_state_dict is None 

2506 else {"pytorch_state_dict": self.pytorch_state_dict} 

2507 ), 

2508 **( 

2509 {} 

2510 if self.tensorflow_js is None 

2511 else {"tensorflow_js": self.tensorflow_js} 

2512 ), 

2513 **( 

2514 {} 

2515 if self.tensorflow_saved_model_bundle is None 

2516 else { 

2517 "tensorflow_saved_model_bundle": self.tensorflow_saved_model_bundle 

2518 } 

2519 ), 

2520 **({} if self.torchscript is None else {"torchscript": self.torchscript}), 

2521 } 

2522 

2523 @property 

2524 def missing_formats(self): 

2525 return { 

2526 wf for wf in get_args(WeightsFormat) if wf not in self.available_formats 

2527 } 

2528 
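    # Illustrative usage (not part of this module); `model` is a hypothetical,
    # already validated `ModelDescr` instance:
    #
    #     model.weights.available_formats   # e.g. {"pytorch_state_dict": ..., "torchscript": ...}
    #     model.weights.missing_formats     # the remaining `WeightsFormat` literals
    #     model.weights["torchscript"]      # raises KeyError if that entry is absent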

2529 

2530class ModelId(ResourceId): 

2531 pass 

2532 

2533 

2534class LinkedModel(LinkedResourceBase): 

2535 """Reference to a bioimage.io model.""" 

2536 

2537 id: ModelId 

2538 """A valid model `id` from the bioimage.io collection.""" 

2539 

2540 

2541class _DataDepSize(NamedTuple): 

2542 min: StrictInt 

2543 max: Optional[StrictInt] 

2544 

2545 

2546class _AxisSizes(NamedTuple): 

2547 """the lenghts of all axes of model inputs and outputs""" 

2548 

2549 inputs: Dict[Tuple[TensorId, AxisId], int] 

2550 outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]] 

2551 

2552 

2553class _TensorSizes(NamedTuple): 

2554 """_AxisSizes as nested dicts""" 

2555 

2556 inputs: Dict[TensorId, Dict[AxisId, int]] 

2557 outputs: Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]] 

2558 

2559 

2560class ReproducibilityTolerance(Node, extra="allow"): 

2561 """Describes what small numerical differences -- if any -- may be tolerated 

2562 in the generated output when executing in different environments. 

2563 

2564 A tensor element *output* is considered mismatched to the **test_tensor** if 

2565 abs(*output* - **test_tensor**) > **absolute_tolerance** + **relative_tolerance** * abs(**test_tensor**). 

2566 (Internally we call [numpy.testing.assert_allclose](https://numpy.org/doc/stable/reference/generated/numpy.testing.assert_allclose.html).) 

2567 

2568 Motivation: 

2569 For testing we can request the respective deep learning frameworks to be as 

2570 reproducible as possible by setting seeds and choosing deterministic algorithms, 

2571 but differences in operating systems, available hardware and installed drivers 

2572 may still lead to numerical differences. 

2573 """ 

2574 

2575 relative_tolerance: RelativeTolerance = 1e-3 

2576 """Maximum relative tolerance of reproduced test tensor.""" 

2577 

2578 absolute_tolerance: AbsoluteTolerance = 1e-4 

2579 """Maximum absolute tolerance of reproduced test tensor.""" 

2580 

2581 mismatched_elements_per_million: MismatchedElementsPerMillion = 100 

2582 """Maximum number of mismatched elements/pixels per million to tolerate.""" 

2583 

2584 output_ids: Sequence[TensorId] = () 

2585 """Limits the output tensor IDs these reproducibility details apply to.""" 

2586 

2587 weights_formats: Sequence[WeightsFormat] = () 

2588 """Limits the weights formats these details apply to.""" 

2589 
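# A rough numpy sketch (not part of this module) of the mismatch criterion described
# in the `ReproducibilityTolerance` docstring above; helper name and defaults are
# illustrative only.
def _count_mismatched_per_million(
    output: NDArray[Any],
    test_tensor: NDArray[Any],
    rtol: float = 1e-3,
    atol: float = 1e-4,
) -> int:
    # an element is mismatched if |output - test| > atol + rtol * |test|
    mismatched = np.abs(output - test_tensor) > atol + rtol * np.abs(test_tensor)
    return int(round(mismatched.sum() / mismatched.size * 1e6))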

2590 

2591class BioimageioConfig(Node, extra="allow"): 

2592 reproducibility_tolerance: Sequence[ReproducibilityTolerance] = () 

2593 """Tolerances to allow when reproducing the model's test outputs 

2594 from the model's test inputs. 

2595 Only the first entry matching tensor id and weights format is considered. 

2596 """ 

2597 

2598 

2599class Config(Node, extra="allow"): 

2600 bioimageio: BioimageioConfig = Field( 

2601 default_factory=BioimageioConfig.model_construct 

2602 ) 

2603 
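# Illustrative sketch (not part of this module): declaring a custom tolerance for a
# single, hypothetical output tensor "mask" when reproducing test outputs.
_example_config = Config(
    bioimageio=BioimageioConfig(
        reproducibility_tolerance=[
            ReproducibilityTolerance(
                absolute_tolerance=1e-3,
                mismatched_elements_per_million=10,
                output_ids=[TensorId("mask")],
            )
        ]
    )
)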

2604 

2605class ModelDescr(GenericModelDescrBase): 

2606 """Specification of the fields used in a bioimage.io-compliant RDF to describe AI models with pretrained weights. 

2607 These fields are typically stored in a YAML file which we call a model resource description file (model RDF). 

2608 """ 

2609 

2610 implemented_format_version: ClassVar[Literal["0.5.5"]] = "0.5.5" 

2611 if TYPE_CHECKING: 

2612 format_version: Literal["0.5.5"] = "0.5.5" 

2613 else: 

2614 format_version: Literal["0.5.5"] 

2615 """Version of the bioimage.io model description specification used. 

2616 When creating a new model always use the latest micro/patch version described here. 

2617 The `format_version` is important for any consumer software to understand how to parse the fields. 

2618 """ 

2619 

2620 implemented_type: ClassVar[Literal["model"]] = "model" 

2621 if TYPE_CHECKING: 

2622 type: Literal["model"] = "model" 

2623 else: 

2624 type: Literal["model"] 

2625 """Specialized resource type 'model'""" 

2626 

2627 id: Optional[ModelId] = None 

2628 """bioimage.io-wide unique resource identifier 

2629 assigned by bioimage.io; version **un**specific.""" 

2630 

2631 authors: FAIR[List[Author]] = Field( 

2632 default_factory=cast(Callable[[], List[Author]], list) 

2633 ) 

2634 """The authors are the creators of the model RDF and the primary points of contact.""" 

2635 

2636 documentation: FAIR[Optional[FileSource_documentation]] = None 

2637 """URL or relative path to a markdown file with additional documentation. 

2638 The recommended documentation file name is `README.md`. An `.md` suffix is mandatory. 

2639 The documentation should include a '#[#] Validation' (sub)section 

2640 with details on how to quantitatively validate the model on unseen data.""" 

2641 

2642 @field_validator("documentation", mode="after") 

2643 @classmethod 

2644 def _validate_documentation( 

2645 cls, value: Optional[FileSource_documentation] 

2646 ) -> Optional[FileSource_documentation]: 

2647 if not get_validation_context().perform_io_checks or value is None: 

2648 return value 

2649 

2650 doc_reader = get_reader(value) 

2651 doc_content = doc_reader.read().decode(encoding="utf-8") 

2652 if not re.search("#.*[vV]alidation", doc_content): 

2653 issue_warning( 

2654 "No '# Validation' (sub)section found in {value}.", 

2655 value=value, 

2656 field="documentation", 

2657 ) 

2658 

2659 return value 

2660 

2661 inputs: NotEmpty[Sequence[InputTensorDescr]] 

2662 """Describes the input tensors expected by this model.""" 

2663 

2664 @field_validator("inputs", mode="after") 

2665 @classmethod 

2666 def _validate_input_axes( 

2667 cls, inputs: Sequence[InputTensorDescr] 

2668 ) -> Sequence[InputTensorDescr]: 

2669 input_size_refs = cls._get_axes_with_independent_size(inputs) 

2670 

2671 for i, ipt in enumerate(inputs): 

2672 valid_independent_refs: Dict[ 

2673 Tuple[TensorId, AxisId], 

2674 Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]], 

2675 ] = { 

2676 **{ 

2677 (ipt.id, a.id): (ipt, a, a.size) 

2678 for a in ipt.axes 

2679 if not isinstance(a, BatchAxis) 

2680 and isinstance(a.size, (int, ParameterizedSize)) 

2681 }, 

2682 **input_size_refs, 

2683 } 

2684 for a, ax in enumerate(ipt.axes): 

2685 cls._validate_axis( 

2686 "inputs", 

2687 i=i, 

2688 tensor_id=ipt.id, 

2689 a=a, 

2690 axis=ax, 

2691 valid_independent_refs=valid_independent_refs, 

2692 ) 

2693 return inputs 

2694 

2695 @staticmethod 

2696 def _validate_axis( 

2697 field_name: str, 

2698 i: int, 

2699 tensor_id: TensorId, 

2700 a: int, 

2701 axis: AnyAxis, 

2702 valid_independent_refs: Dict[ 

2703 Tuple[TensorId, AxisId], 

2704 Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]], 

2705 ], 

2706 ): 

2707 if isinstance(axis, BatchAxis) or isinstance( 

2708 axis.size, (int, ParameterizedSize, DataDependentSize) 

2709 ): 

2710 return 

2711 elif not isinstance(axis.size, SizeReference): 

2712 assert_never(axis.size) 

2713 

2714 # validate axis.size SizeReference 

2715 ref = (axis.size.tensor_id, axis.size.axis_id) 

2716 if ref not in valid_independent_refs: 

2717 raise ValueError( 

2718 "Invalid tensor axis reference at" 

2719 + f" {field_name}[{i}].axes[{a}].size: {axis.size}." 

2720 ) 

2721 if ref == (tensor_id, axis.id): 

2722 raise ValueError( 

2723 "Self-referencing not allowed for" 

2724 + f" {field_name}[{i}].axes[{a}].size: {axis.size}" 

2725 ) 

2726 if axis.type == "channel": 

2727 if valid_independent_refs[ref][1].type != "channel": 

2728 raise ValueError( 

2729 "A channel axis' size may only reference another fixed size" 

2730 + " channel axis." 

2731 ) 

2732 if isinstance(axis.channel_names, str) and "{i}" in axis.channel_names: 

2733 ref_size = valid_independent_refs[ref][2] 

2734 assert isinstance(ref_size, int), ( 

2735 "channel axis ref (another channel axis) has to specify fixed" 

2736 + " size" 

2737 ) 

2738 generated_channel_names = [ 

2739 Identifier(axis.channel_names.format(i=i)) 

2740 for i in range(1, ref_size + 1) 

2741 ] 

2742 axis.channel_names = generated_channel_names 

2743 

2744 if (ax_unit := getattr(axis, "unit", None)) != ( 

2745 ref_unit := getattr(valid_independent_refs[ref][1], "unit", None) 

2746 ): 

2747 raise ValueError( 

2748 "The units of an axis and its reference axis need to match, but" 

2749 + f" '{ax_unit}' != '{ref_unit}'." 

2750 ) 

2751 ref_axis = valid_independent_refs[ref][1] 

2752 if isinstance(ref_axis, BatchAxis): 

2753 raise ValueError( 

2754 f"Invalid reference axis '{ref_axis.id}' for {tensor_id}.{axis.id}" 

2755 + " (a batch axis is not allowed as reference)." 

2756 ) 

2757 

2758 if isinstance(axis, WithHalo): 

2759 min_size = axis.size.get_size(axis, ref_axis, n=0) 

2760 if (min_size - 2 * axis.halo) < 1: 

2761 raise ValueError( 

2762 f"axis {axis.id} with minimum size {min_size} is too small for halo" 

2763 + f" {axis.halo}." 

2764 ) 

2765 

2766 input_halo = axis.halo * axis.scale / ref_axis.scale 

2767 if input_halo != int(input_halo) or input_halo % 2 == 1: 

2768 raise ValueError( 

2769 f"input_halo {input_halo} (output_halo {axis.halo} *" 

2770 + f" output_scale {axis.scale} / input_scale {ref_axis.scale})" 

2771 + f" {tensor_id}.{axis.id}." 

2772 ) 

2773 

2774 @model_validator(mode="after") 

2775 def _validate_test_tensors(self) -> Self: 

2776 if not get_validation_context().perform_io_checks: 

2777 return self 

2778 

2779 test_output_arrays = [ 

2780 None if descr.test_tensor is None else load_array(descr.test_tensor) 

2781 for descr in self.outputs 

2782 ] 

2783 test_input_arrays = [ 

2784 None if descr.test_tensor is None else load_array(descr.test_tensor) 

2785 for descr in self.inputs 

2786 ] 

2787 

2788 tensors = { 

2789 descr.id: (descr, array) 

2790 for descr, array in zip( 

2791 chain(self.inputs, self.outputs), test_input_arrays + test_output_arrays 

2792 ) 

2793 } 

2794 validate_tensors(tensors, tensor_origin="test_tensor") 

2795 

2796 output_arrays = { 

2797 descr.id: array for descr, array in zip(self.outputs, test_output_arrays) 

2798 } 

2799 for rep_tol in self.config.bioimageio.reproducibility_tolerance: 

2800 if not rep_tol.absolute_tolerance: 

2801 continue 

2802 

2803 if rep_tol.output_ids: 

2804 out_arrays = { 

2805 oid: a 

2806 for oid, a in output_arrays.items() 

2807 if oid in rep_tol.output_ids 

2808 } 

2809 else: 

2810 out_arrays = output_arrays 

2811 

2812 for out_id, array in out_arrays.items(): 

2813 if array is None: 

2814 continue 

2815 

2816 if rep_tol.absolute_tolerance > (max_test_value := array.max()) * 0.01: 

2817 raise ValueError( 

2818 "config.bioimageio.reproducibility_tolerance.absolute_tolerance=" 

2819 + f"{rep_tol.absolute_tolerance} > 0.01*{max_test_value}" 

2820 + f" (1% of the maximum value of the test tensor '{out_id}')" 

2821 ) 

2822 

2823 return self 

2824 

2825 @model_validator(mode="after") 

2826 def _validate_tensor_references_in_proc_kwargs(self, info: ValidationInfo) -> Self: 

2827 ipt_refs = {t.id for t in self.inputs} 

2828 out_refs = {t.id for t in self.outputs} 

2829 for ipt in self.inputs: 

2830 for p in ipt.preprocessing: 

2831 ref = p.kwargs.get("reference_tensor") 

2832 if ref is None: 

2833 continue 

2834 if ref not in ipt_refs: 

2835 raise ValueError( 

2836 f"`reference_tensor` '{ref}' not found. Valid input tensor" 

2837 + f" references are: {ipt_refs}." 

2838 ) 

2839 

2840 for out in self.outputs: 

2841 for p in out.postprocessing: 

2842 ref = p.kwargs.get("reference_tensor") 

2843 if ref is None: 

2844 continue 

2845 

2846 if ref not in ipt_refs and ref not in out_refs: 

2847 raise ValueError( 

2848 f"`reference_tensor` '{ref}' not found. Valid tensor references" 

2849 + f" are: {ipt_refs | out_refs}." 

2850 ) 

2851 

2852 return self 

2853 

2854 # TODO: use validate funcs in validate_test_tensors 

2855 # def validate_inputs(self, input_tensors: Mapping[TensorId, NDArray[Any]]) -> Mapping[TensorId, NDArray[Any]]: 

2856 

2857 name: Annotated[ 

2858 str, 

2859 RestrictCharacters(string.ascii_letters + string.digits + "_+- ()"), 

2860 MinLen(5), 

2861 MaxLen(128), 

2862 warn(MaxLen(64), "Name longer than 64 characters.", INFO), 

2863 ] 

2864 """A human-readable name of this model. 

2865 It should be no longer than 64 characters 

2866 and may only contain letters, numbers, underscores, minus signs, parentheses and spaces. 

2867 We recommend choosing a name that refers to the model's task and image modality. 

2868 """ 

2869 

2870 outputs: NotEmpty[Sequence[OutputTensorDescr]] 

2871 """Describes the output tensors.""" 

2872 

2873 @field_validator("outputs", mode="after") 

2874 @classmethod 

2875 def _validate_tensor_ids( 

2876 cls, outputs: Sequence[OutputTensorDescr], info: ValidationInfo 

2877 ) -> Sequence[OutputTensorDescr]: 

2878 tensor_ids = [ 

2879 t.id for t in info.data.get("inputs", []) + info.data.get("outputs", []) 

2880 ] 

2881 duplicate_tensor_ids: List[str] = [] 

2882 seen: Set[str] = set() 

2883 for t in tensor_ids: 

2884 if t in seen: 

2885 duplicate_tensor_ids.append(t) 

2886 

2887 seen.add(t) 

2888 

2889 if duplicate_tensor_ids: 

2890 raise ValueError(f"Duplicate tensor ids: {duplicate_tensor_ids}") 

2891 

2892 return outputs 

2893 

2894 @staticmethod 

2895 def _get_axes_with_parameterized_size( 

2896 io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]], 

2897 ): 

2898 return { 

2899 f"{t.id}.{a.id}": (t, a, a.size) 

2900 for t in io 

2901 for a in t.axes 

2902 if not isinstance(a, BatchAxis) and isinstance(a.size, ParameterizedSize) 

2903 } 

2904 

2905 @staticmethod 

2906 def _get_axes_with_independent_size( 

2907 io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]], 

2908 ): 

2909 return { 

2910 (t.id, a.id): (t, a, a.size) 

2911 for t in io 

2912 for a in t.axes 

2913 if not isinstance(a, BatchAxis) 

2914 and isinstance(a.size, (int, ParameterizedSize)) 

2915 } 

2916 

2917 @field_validator("outputs", mode="after") 

2918 @classmethod 

2919 def _validate_output_axes( 

2920 cls, outputs: List[OutputTensorDescr], info: ValidationInfo 

2921 ) -> List[OutputTensorDescr]: 

2922 input_size_refs = cls._get_axes_with_independent_size( 

2923 info.data.get("inputs", []) 

2924 ) 

2925 output_size_refs = cls._get_axes_with_independent_size(outputs) 

2926 

2927 for i, out in enumerate(outputs): 

2928 valid_independent_refs: Dict[ 

2929 Tuple[TensorId, AxisId], 

2930 Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]], 

2931 ] = { 

2932 **{ 

2933 (out.id, a.id): (out, a, a.size) 

2934 for a in out.axes 

2935 if not isinstance(a, BatchAxis) 

2936 and isinstance(a.size, (int, ParameterizedSize)) 

2937 }, 

2938 **input_size_refs, 

2939 **output_size_refs, 

2940 } 

2941 for a, ax in enumerate(out.axes): 

2942 cls._validate_axis( 

2943 "outputs", 

2944 i, 

2945 out.id, 

2946 a, 

2947 ax, 

2948 valid_independent_refs=valid_independent_refs, 

2949 ) 

2950 

2951 return outputs 

2952 

2953 packaged_by: List[Author] = Field( 

2954 default_factory=cast(Callable[[], List[Author]], list) 

2955 ) 

2956 """The persons that have packaged and uploaded this model. 

2957 Only required if those persons differ from the `authors`.""" 

2958 

2959 parent: Optional[LinkedModel] = None 

2960 """The model from which this model is derived, e.g. by fine-tuning the weights.""" 

2961 

2962 @model_validator(mode="after") 

2963 def _validate_parent_is_not_self(self) -> Self: 

2964 if self.parent is not None and self.parent.id == self.id: 

2965 raise ValueError("A model description may not reference itself as parent.") 

2966 

2967 return self 

2968 

2969 run_mode: Annotated[ 

2970 Optional[RunMode], 

2971 warn(None, "Run mode '{value}' has limited support across consumer software."), 

2972 ] = None 

2973 """Custom run mode for this model: for more complex prediction procedures like test time 

2974 data augmentation that currently cannot be expressed in the specification. 

2975 No standard run modes are defined yet.""" 

2976 

2977 timestamp: Datetime = Field(default_factory=Datetime.now) 

2978 """Timestamp in [ISO 8601](#https://en.wikipedia.org/wiki/ISO_8601) format 

2979 with a few restrictions listed [here](https://docs.python.org/3/library/datetime.html#datetime.datetime.fromisoformat). 

2980 (In Python a datetime object is valid, too).""" 

2981 

2982 training_data: Annotated[ 

2983 Union[None, LinkedDataset, DatasetDescr, DatasetDescr02], 

2984 Field(union_mode="left_to_right"), 

2985 ] = None 

2986 """The dataset used to train this model""" 

2987 

2988 weights: Annotated[WeightsDescr, WrapSerializer(package_weights)] 

2989 """The weights for this model. 

2990 Weights can be given for different formats, but should otherwise be equivalent. 

2991 The available weight formats determine which consumers can use this model.""" 

2992 

2993 config: Config = Field(default_factory=Config.model_construct) 

2994 

2995 @model_validator(mode="after") 

2996 def _add_default_cover(self) -> Self: 

2997 if not get_validation_context().perform_io_checks or self.covers: 

2998 return self 

2999 

3000 try: 

3001 generated_covers = generate_covers( 

3002 [ 

3003 (t, load_array(t.test_tensor)) 

3004 for t in self.inputs 

3005 if t.test_tensor is not None 

3006 ], 

3007 [ 

3008 (t, load_array(t.test_tensor)) 

3009 for t in self.outputs 

3010 if t.test_tensor is not None 

3011 ], 

3012 ) 

3013 except Exception as e: 

3014 issue_warning( 

3015 "Failed to generate cover image(s): {e}", 

3016 value=self.covers, 

3017 msg_context=dict(e=e), 

3018 field="covers", 

3019 ) 

3020 else: 

3021 self.covers.extend(generated_covers) 

3022 

3023 return self 

3024 

3025 def get_input_test_arrays(self) -> List[NDArray[Any]]: 

3026 return self._get_test_arrays(self.inputs) 

3027 

3028 def get_output_test_arrays(self) -> List[NDArray[Any]]: 

3029 return self._get_test_arrays(self.outputs) 

3030 

3031 @staticmethod 

3032 def _get_test_arrays( 

3033 io_descr: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]], 

3034 ): 

3035 ts: List[FileDescr] = [] 

3036 for d in io_descr: 

3037 if d.test_tensor is None: 

3038 raise ValueError( 

3039 f"Failed to get test arrays: description of '{d.id}' is missing a `test_tensor`." 

3040 ) 

3041 ts.append(d.test_tensor) 

3042 

3043 data = [load_array(t) for t in ts] 

3044 assert all(isinstance(d, np.ndarray) for d in data) 

3045 return data 

3046 

3047 @staticmethod 

3048 def get_batch_size(tensor_sizes: Mapping[TensorId, Mapping[AxisId, int]]) -> int: 

3049 batch_size = 1 

3050 tensor_with_batchsize: Optional[TensorId] = None 

3051 for tid in tensor_sizes: 

3052 for aid, s in tensor_sizes[tid].items(): 

3053 if aid != BATCH_AXIS_ID or s == 1 or s == batch_size: 

3054 continue 

3055 

3056 if batch_size != 1: 

3057 assert tensor_with_batchsize is not None 

3058 raise ValueError( 

3059 f"batch size mismatch for tensors '{tensor_with_batchsize}' ({batch_size}) and '{tid}' ({s})" 

3060 ) 

3061 

3062 batch_size = s 

3063 tensor_with_batchsize = tid 

3064 

3065 return batch_size 

3066 
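    # Illustrative (not part of this class): two hypothetical tensors that share batch size 4:
    #
    #     ModelDescr.get_batch_size(
    #         {
    #             TensorId("raw"): {BATCH_AXIS_ID: 4, AxisId("x"): 256},
    #             TensorId("mask"): {BATCH_AXIS_ID: 4, AxisId("y"): 256},
    #         }
    #     )  # -> 4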

3067 def get_output_tensor_sizes( 

3068 self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]] 

3069 ) -> Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]: 

3070 """Returns the tensor output sizes for given **input_sizes**. 

3071 The returned output sizes are exact only if **input_sizes** describes a valid input shape. 

3072 Otherwise they may be larger than the actual (valid) output sizes.""" 

3073 batch_size = self.get_batch_size(input_sizes) 

3074 ns = self.get_ns(input_sizes) 

3075 

3076 tensor_sizes = self.get_tensor_sizes(ns, batch_size=batch_size) 

3077 return tensor_sizes.outputs 

3078 

3079 def get_ns(self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]): 

3080 """get parameter `n` for each parameterized axis 

3081 such that the valid input size is >= the given input size""" 

3082 ret: Dict[Tuple[TensorId, AxisId], ParameterizedSize_N] = {} 

3083 axes = {t.id: {a.id: a for a in t.axes} for t in self.inputs} 

3084 for tid in input_sizes: 

3085 for aid, s in input_sizes[tid].items(): 

3086 size_descr = axes[tid][aid].size 

3087 if isinstance(size_descr, ParameterizedSize): 

3088 ret[(tid, aid)] = size_descr.get_n(s) 

3089 elif size_descr is None or isinstance(size_descr, (int, SizeReference)): 

3090 pass 

3091 else: 

3092 assert_never(size_descr) 

3093 

3094 return ret 

3095 
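    # Illustrative (not part of this class): for a hypothetical input axis "x"
    # parameterized as ParameterizedSize(min=64, step=16), an actual size of 100
    # yields n=3, since 64 + 3*16 = 112 is the smallest valid size >= 100:
    #
    #     model.get_ns({TensorId("raw"): {AxisId("x"): 100}})
    #     # -> {(TensorId("raw"), AxisId("x")): 3}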

3096 def get_tensor_sizes( 

3097 self, ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], batch_size: int 

3098 ) -> _TensorSizes: 

3099 axis_sizes = self.get_axis_sizes(ns, batch_size=batch_size) 

3100 return _TensorSizes( 

3101 { 

3102 t: { 

3103 aa: axis_sizes.inputs[(tt, aa)] 

3104 for tt, aa in axis_sizes.inputs 

3105 if tt == t 

3106 } 

3107 for t in {tt for tt, _ in axis_sizes.inputs} 

3108 }, 

3109 { 

3110 t: { 

3111 aa: axis_sizes.outputs[(tt, aa)] 

3112 for tt, aa in axis_sizes.outputs 

3113 if tt == t 

3114 } 

3115 for t in {tt for tt, _ in axis_sizes.outputs} 

3116 }, 

3117 ) 

3118 

3119 def get_axis_sizes( 

3120 self, 

3121 ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], 

3122 batch_size: Optional[int] = None, 

3123 *, 

3124 max_input_shape: Optional[Mapping[Tuple[TensorId, AxisId], int]] = None, 

3125 ) -> _AxisSizes: 

3126 """Determine input and output block shape for scale factors **ns** 

3127 of parameterized input sizes. 

3128 

3129 Args: 

3130 ns: Scale factor `n` for each axis (keyed by (tensor_id, axis_id)) 

3131 that is parameterized as `size = min + n * step`. 

3132 batch_size: The desired size of the batch dimension. 

3133 If given **batch_size** overwrites any batch size present in 

3134 **max_input_shape**. Default 1. 

3135 max_input_shape: Limits the derived block shapes. 

3136 For each axis whose input size, parameterized by the requested `n`, exceeds 

3137 **max_input_shape**, `n` is reduced to the smallest value for which the 

3138 parameterized size still reaches **max_input_shape**. 

3139 Use this for small input samples or large values of **ns**, 

3140 or simply whenever you know the full input shape. 

3141 

3142 Returns: 

3143 Resolved axis sizes for model inputs and outputs. 

3144 """ 

3145 max_input_shape = max_input_shape or {} 

3146 if batch_size is None: 

3147 for (_t_id, a_id), s in max_input_shape.items(): 

3148 if a_id == BATCH_AXIS_ID: 

3149 batch_size = s 

3150 break 

3151 else: 

3152 batch_size = 1 

3153 

3154 all_axes = { 

3155 t.id: {a.id: a for a in t.axes} for t in chain(self.inputs, self.outputs) 

3156 } 

3157 

3158 inputs: Dict[Tuple[TensorId, AxisId], int] = {} 

3159 outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]] = {} 

3160 

3161 def get_axis_size(a: Union[InputAxis, OutputAxis]): 

3162 if isinstance(a, BatchAxis): 

3163 if (t_descr.id, a.id) in ns: 

3164 logger.warning( 

3165 "Ignoring unexpected size increment factor (n) for batch axis" 

3166 + " of tensor '{}'.", 

3167 t_descr.id, 

3168 ) 

3169 return batch_size 

3170 elif isinstance(a.size, int): 

3171 if (t_descr.id, a.id) in ns: 

3172 logger.warning( 

3173 "Ignoring unexpected size increment factor (n) for fixed size" 

3174 + " axis '{}' of tensor '{}'.", 

3175 a.id, 

3176 t_descr.id, 

3177 ) 

3178 return a.size 

3179 elif isinstance(a.size, ParameterizedSize): 

3180 if (t_descr.id, a.id) not in ns: 

3181 raise ValueError( 

3182 "Size increment factor (n) missing for parametrized axis" 

3183 + f" '{a.id}' of tensor '{t_descr.id}'." 

3184 ) 

3185 n = ns[(t_descr.id, a.id)] 

3186 s_max = max_input_shape.get((t_descr.id, a.id)) 

3187 if s_max is not None: 

3188 n = min(n, a.size.get_n(s_max)) 

3189 

3190 return a.size.get_size(n) 

3191 

3192 elif isinstance(a.size, SizeReference): 

3193 if (t_descr.id, a.id) in ns: 

3194 logger.warning( 

3195 "Ignoring unexpected size increment factor (n) for axis '{}'" 

3196 + " of tensor '{}' with size reference.", 

3197 a.id, 

3198 t_descr.id, 

3199 ) 

3200 assert not isinstance(a, BatchAxis) 

3201 ref_axis = all_axes[a.size.tensor_id][a.size.axis_id] 

3202 assert not isinstance(ref_axis, BatchAxis) 

3203 ref_key = (a.size.tensor_id, a.size.axis_id) 

3204 ref_size = inputs.get(ref_key, outputs.get(ref_key)) 

3205 assert ref_size is not None, ref_key 

3206 assert not isinstance(ref_size, _DataDepSize), ref_key 

3207 return a.size.get_size( 

3208 axis=a, 

3209 ref_axis=ref_axis, 

3210 ref_size=ref_size, 

3211 ) 

3212 elif isinstance(a.size, DataDependentSize): 

3213 if (t_descr.id, a.id) in ns: 

3214 logger.warning( 

3215 "Ignoring unexpected increment factor (n) for data dependent" 

3216 + " size axis '{}' of tensor '{}'.", 

3217 a.id, 

3218 t_descr.id, 

3219 ) 

3220 return _DataDepSize(a.size.min, a.size.max) 

3221 else: 

3222 assert_never(a.size) 

3223 

3224 # first resolve all input axis sizes except those given as a `SizeReference` 

3225 for t_descr in self.inputs: 

3226 for a in t_descr.axes: 

3227 if not isinstance(a.size, SizeReference): 

3228 s = get_axis_size(a) 

3229 assert not isinstance(s, _DataDepSize) 

3230 inputs[t_descr.id, a.id] = s 

3231 

3232 # resolve all other input axis sizes 

3233 for t_descr in self.inputs: 

3234 for a in t_descr.axes: 

3235 if isinstance(a.size, SizeReference): 

3236 s = get_axis_size(a) 

3237 assert not isinstance(s, _DataDepSize) 

3238 inputs[t_descr.id, a.id] = s 

3239 

3240 # resolve all output axis sizes 

3241 for t_descr in self.outputs: 

3242 for a in t_descr.axes: 

3243 assert not isinstance(a.size, ParameterizedSize) 

3244 s = get_axis_size(a) 

3245 outputs[t_descr.id, a.id] = s 

3246 

3247 return _AxisSizes(inputs=inputs, outputs=outputs) 

3248 
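    # Illustrative (not part of this class): resolving concrete block shapes for a
    # hypothetical model with one parameterized "x" axis, using n=2 and batch_size=1:
    #
    #     axis_sizes = model.get_axis_sizes(
    #         ns={(TensorId("raw"), AxisId("x")): 2}, batch_size=1
    #     )
    #     # axis_sizes.inputs maps (tensor_id, axis_id) -> int;
    #     # axis_sizes.outputs may also contain _DataDepSize(min, max) entries.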

3249 @model_validator(mode="before") 

3250 @classmethod 

3251 def _convert(cls, data: Dict[str, Any]) -> Dict[str, Any]: 

3252 cls.convert_from_old_format_wo_validation(data) 

3253 return data 

3254 

3255 @classmethod 

3256 def convert_from_old_format_wo_validation(cls, data: Dict[str, Any]) -> None: 

3257 """Convert metadata following an older format version to this classes' format 

3258 without validating the result. 

3259 """ 

3260 if ( 

3261 data.get("type") == "model" 

3262 and isinstance(fv := data.get("format_version"), str) 

3263 and fv.count(".") == 2 

3264 ): 

3265 fv_parts = fv.split(".") 

3266 if any(not p.isdigit() for p in fv_parts): 

3267 return 

3268 

3269 fv_tuple = tuple(map(int, fv_parts)) 

3270 

3271 assert cls.implemented_format_version_tuple[0:2] == (0, 5) 

3272 if fv_tuple[:2] in ((0, 3), (0, 4)): 

3273 m04 = _ModelDescr_v0_4.load(data) 

3274 if isinstance(m04, InvalidDescr): 

3275 try: 

3276 updated = _model_conv.convert_as_dict( 

3277 m04 # pyright: ignore[reportArgumentType] 

3278 ) 

3279 except Exception as e: 

3280 logger.error( 

3281 "Failed to convert from invalid model 0.4 description." 

3282 + f"\nerror: {e}" 

3283 + "\nProceeding with model 0.5 validation without conversion." 

3284 ) 

3285 updated = None 

3286 else: 

3287 updated = _model_conv.convert_as_dict(m04) 

3288 

3289 if updated is not None: 

3290 data.clear() 

3291 data.update(updated) 

3292 

3293 elif fv_tuple[:2] == (0, 5): 

3294 # bump patch version 

3295 data["format_version"] = cls.implemented_format_version 

3296 

3297 

3298class _ModelConv(Converter[_ModelDescr_v0_4, ModelDescr]): 

3299 def _convert( 

3300 self, src: _ModelDescr_v0_4, tgt: "type[ModelDescr] | type[dict[str, Any]]" 

3301 ) -> "ModelDescr | dict[str, Any]": 

3302 name = "".join( 

3303 c if c in string.ascii_letters + string.digits + "_+- ()" else " " 

3304 for c in src.name 

3305 ) 

3306 

3307 def conv_authors(auths: Optional[Sequence[_Author_v0_4]]): 

3308 conv = ( 

3309 _author_conv.convert if TYPE_CHECKING else _author_conv.convert_as_dict 

3310 ) 

3311 return None if auths is None else [conv(a) for a in auths] 

3312 

3313 if TYPE_CHECKING: 

3314 arch_file_conv = _arch_file_conv.convert 

3315 arch_lib_conv = _arch_lib_conv.convert 

3316 else: 

3317 arch_file_conv = _arch_file_conv.convert_as_dict 

3318 arch_lib_conv = _arch_lib_conv.convert_as_dict 

3319 

3320 input_size_refs = { 

3321 ipt.name: { 

3322 a: s 

3323 for a, s in zip( 

3324 ipt.axes, 

3325 ( 

3326 ipt.shape.min 

3327 if isinstance(ipt.shape, _ParameterizedInputShape_v0_4) 

3328 else ipt.shape 

3329 ), 

3330 ) 

3331 } 

3332 for ipt in src.inputs 

3333 if ipt.shape 

3334 } 

3335 output_size_refs = { 

3336 **{ 

3337 out.name: {a: s for a, s in zip(out.axes, out.shape)} 

3338 for out in src.outputs 

3339 if not isinstance(out.shape, _ImplicitOutputShape_v0_4) 

3340 }, 

3341 **input_size_refs, 

3342 } 

3343 

3344 return tgt( 

3345 attachments=( 

3346 [] 

3347 if src.attachments is None 

3348 else [FileDescr(source=f) for f in src.attachments.files] 

3349 ), 

3350 authors=[ 

3351 _author_conv.convert_as_dict(a) for a in src.authors 

3352 ], # pyright: ignore[reportArgumentType] 

3353 cite=[ 

3354 {"text": c.text, "doi": c.doi, "url": c.url} for c in src.cite 

3355 ], # pyright: ignore[reportArgumentType] 

3356 config=src.config, # pyright: ignore[reportArgumentType] 

3357 covers=src.covers, 

3358 description=src.description, 

3359 documentation=src.documentation, 

3360 format_version="0.5.5", 

3361 git_repo=src.git_repo, # pyright: ignore[reportArgumentType] 

3362 icon=src.icon, 

3363 id=None if src.id is None else ModelId(src.id), 

3364 id_emoji=src.id_emoji, 

3365 license=src.license, # type: ignore 

3366 links=src.links, 

3367 maintainers=[ 

3368 _maintainer_conv.convert_as_dict(m) for m in src.maintainers 

3369 ], # pyright: ignore[reportArgumentType] 

3370 name=name, 

3371 tags=src.tags, 

3372 type=src.type, 

3373 uploader=src.uploader, 

3374 version=src.version, 

3375 inputs=[ # pyright: ignore[reportArgumentType] 

3376 _input_tensor_conv.convert_as_dict(ipt, tt, st, input_size_refs) 

3377 for ipt, tt, st, in zip( 

3378 src.inputs, 

3379 src.test_inputs, 

3380 src.sample_inputs or [None] * len(src.test_inputs), 

3381 ) 

3382 ], 

3383 outputs=[ # pyright: ignore[reportArgumentType] 

3384 _output_tensor_conv.convert_as_dict(out, tt, st, output_size_refs) 

3385 for out, tt, st, in zip( 

3386 src.outputs, 

3387 src.test_outputs, 

3388 src.sample_outputs or [None] * len(src.test_outputs), 

3389 ) 

3390 ], 

3391 parent=( 

3392 None 

3393 if src.parent is None 

3394 else LinkedModel( 

3395 id=ModelId( 

3396 str(src.parent.id) 

3397 + ( 

3398 "" 

3399 if src.parent.version_number is None 

3400 else f"/{src.parent.version_number}" 

3401 ) 

3402 ) 

3403 ) 

3404 ), 

3405 training_data=( 

3406 None 

3407 if src.training_data is None 

3408 else ( 

3409 LinkedDataset( 

3410 id=DatasetId( 

3411 str(src.training_data.id) 

3412 + ( 

3413 "" 

3414 if src.training_data.version_number is None 

3415 else f"/{src.training_data.version_number}" 

3416 ) 

3417 ) 

3418 ) 

3419 if isinstance(src.training_data, LinkedDataset02) 

3420 else src.training_data 

3421 ) 

3422 ), 

3423 packaged_by=[ 

3424 _author_conv.convert_as_dict(a) for a in src.packaged_by 

3425 ], # pyright: ignore[reportArgumentType] 

3426 run_mode=src.run_mode, 

3427 timestamp=src.timestamp, 

3428 weights=(WeightsDescr if TYPE_CHECKING else dict)( 

3429 keras_hdf5=(w := src.weights.keras_hdf5) 

3430 and (KerasHdf5WeightsDescr if TYPE_CHECKING else dict)( 

3431 authors=conv_authors(w.authors), 

3432 source=w.source, 

3433 tensorflow_version=w.tensorflow_version or Version("1.15"), 

3434 parent=w.parent, 

3435 ), 

3436 onnx=(w := src.weights.onnx) 

3437 and (OnnxWeightsDescr if TYPE_CHECKING else dict)( 

3438 source=w.source, 

3439 authors=conv_authors(w.authors), 

3440 parent=w.parent, 

3441 opset_version=w.opset_version or 15, 

3442 ), 

3443 pytorch_state_dict=(w := src.weights.pytorch_state_dict) 

3444 and (PytorchStateDictWeightsDescr if TYPE_CHECKING else dict)( 

3445 source=w.source, 

3446 authors=conv_authors(w.authors), 

3447 parent=w.parent, 

3448 architecture=( 

3449 arch_file_conv( 

3450 w.architecture, 

3451 w.architecture_sha256, 

3452 w.kwargs, 

3453 ) 

3454 if isinstance(w.architecture, _CallableFromFile_v0_4) 

3455 else arch_lib_conv(w.architecture, w.kwargs) 

3456 ), 

3457 pytorch_version=w.pytorch_version or Version("1.10"), 

3458 dependencies=( 

3459 None 

3460 if w.dependencies is None 

3461 else (FileDescr if TYPE_CHECKING else dict)( 

3462 source=cast( 

3463 FileSource, 

3464 str(deps := w.dependencies)[ 

3465 ( 

3466 len("conda:") 

3467 if str(deps).startswith("conda:") 

3468 else 0 

3469 ) : 

3470 ], 

3471 ) 

3472 ) 

3473 ), 

3474 ), 

3475 tensorflow_js=(w := src.weights.tensorflow_js) 

3476 and (TensorflowJsWeightsDescr if TYPE_CHECKING else dict)( 

3477 source=w.source, 

3478 authors=conv_authors(w.authors), 

3479 parent=w.parent, 

3480 tensorflow_version=w.tensorflow_version or Version("1.15"), 

3481 ), 

3482 tensorflow_saved_model_bundle=( 

3483 w := src.weights.tensorflow_saved_model_bundle 

3484 ) 

3485 and (TensorflowSavedModelBundleWeightsDescr if TYPE_CHECKING else dict)( 

3486 authors=conv_authors(w.authors), 

3487 parent=w.parent, 

3488 source=w.source, 

3489 tensorflow_version=w.tensorflow_version or Version("1.15"), 

3490 dependencies=( 

3491 None 

3492 if w.dependencies is None 

3493 else (FileDescr if TYPE_CHECKING else dict)( 

3494 source=cast( 

3495 FileSource, 

3496 ( 

3497 str(w.dependencies)[len("conda:") :] 

3498 if str(w.dependencies).startswith("conda:") 

3499 else str(w.dependencies) 

3500 ), 

3501 ) 

3502 ) 

3503 ), 

3504 ), 

3505 torchscript=(w := src.weights.torchscript) 

3506 and (TorchscriptWeightsDescr if TYPE_CHECKING else dict)( 

3507 source=w.source, 

3508 authors=conv_authors(w.authors), 

3509 parent=w.parent, 

3510 pytorch_version=w.pytorch_version or Version("1.10"), 

3511 ), 

3512 ), 

3513 ) 
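# Editor's note on the `(X if TYPE_CHECKING else dict)(...)` pattern used for the
# weights entries above: while type checking, the concrete descriptor classes
# (WeightsDescr, OnnxWeightsDescr, ...) are named so pyright can check the keyword
# arguments, but at runtime plain dicts are built and the actual validation and
# coercion is left to pydantic when the target model class constructs the
# converted description.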

3514 

3515 

3516_model_conv = _ModelConv(_ModelDescr_v0_4, ModelDescr) 
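# Illustrative usage (added by the editor; a sketch, not part of the original
# module). `convert_as_dict` is the same Converter entry point used for
# `_author_conv` and `_maintainer_conv` above; the 0.4 description passed in is a
# hypothetical placeholder and the dict return type is assumed.
def _example_upgrade_to_v0_5(descr_v0_4: _ModelDescr_v0_4) -> Dict[str, Any]:
    # returns the 0.5 representation as a plain dict (format_version "0.5.5")
    return _model_conv.convert_as_dict(descr_v0_4)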

3517 

3518 

3519# create better cover images for 3d data and non-image outputs 

3520def generate_covers( 

3521 inputs: Sequence[Tuple[InputTensorDescr, NDArray[Any]]], 

3522 outputs: Sequence[Tuple[OutputTensorDescr, NDArray[Any]]], 

3523) -> List[Path]: 

3524 def squeeze( 

3525 data: NDArray[Any], axes: Sequence[AnyAxis] 

3526 ) -> Tuple[NDArray[Any], List[AnyAxis]]: 

3527 """apply numpy.ndarray.squeeze while keeping track of the axis descriptions remaining""" 

3528 if data.ndim != len(axes): 

3529 raise ValueError( 

3530 f"tensor shape {data.shape} does not match described axes" 

3531 + f" {[a.id for a in axes]}" 

3532 ) 

3533 

3534 axes = [deepcopy(a) for a, s in zip(axes, data.shape) if s != 1] 

3535 return data.squeeze(), axes 

3536 

3537 def normalize( 

3538 data: NDArray[Any], axis: Optional[Tuple[int, ...]], eps: float = 1e-7 

3539 ) -> NDArray[np.float32]: 

3540 data = data.astype("float32") 

3541 data -= data.min(axis=axis, keepdims=True) 

3542 data /= data.max(axis=axis, keepdims=True) + eps 

3543 return data 
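# Worked example (added by the editor): `normalize` min-max scales along the given
# axes, e.g. normalize(np.array([2.0, 4.0, 6.0]), axis=None) yields roughly
# [0.0, 0.5, 1.0]; the eps in the denominator keeps constant tensors from dividing
# by zero.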

3544 

3545 def to_2d_image(data: NDArray[Any], axes: Sequence[AnyAxis]): 

3546 original_shape = data.shape 

3547 data, axes = squeeze(data, axes) 

3548 

3549 # take a slice from any batch or index axis if needed 

3550 # and convert the first channel axis to RGB; take a slice from any additional channel axes 

3551 slices: Tuple[slice, ...] = () 

3552 ndim = data.ndim 

3553 ndim_need = 3 if any(isinstance(a, ChannelAxis) for a in axes) else 2 

3554 has_c_axis = False 

3555 for i, a in enumerate(axes): 

3556 s = data.shape[i] 

3557 assert s > 1 

3558 if ( 

3559 isinstance(a, (BatchAxis, IndexInputAxis, IndexOutputAxis)) 

3560 and ndim > ndim_need 

3561 ): 

3562 data = data[slices + (slice(s // 2 - 1, s // 2),)] 

3563 ndim -= 1 

3564 elif isinstance(a, ChannelAxis): 

3565 if has_c_axis: 

3566 # second channel axis 

3567 data = data[slices + (slice(0, 1),)] 

3568 ndim -= 1 

3569 else: 

3570 has_c_axis = True 

3571 if s == 2: 

3572 # visualize two channels with cyan and magenta 

3573 data = np.concatenate( 

3574 [ 

3575 data[slices + (slice(1, 2),)], 

3576 data[slices + (slice(0, 1),)], 

3577 ( 

3578 data[slices + (slice(0, 1),)] 

3579 + data[slices + (slice(1, 2),)] 

3580 ) 

3581 / 2, # TODO: take maximum instead? 

3582 ], 

3583 axis=i, 

3584 ) 

3585 elif data.shape[i] == 3: 

3586 pass # visualize 3 channels as RGB 

3587 else: 

3588 # visualize first 3 channels as RGB 

3589 data = data[slices + (slice(3),)] 

3590 

3591 assert data.shape[i] == 3 

3592 

3593 slices += (slice(None),) 

3594 

3595 data, axes = squeeze(data, axes) 

3596 assert len(axes) == ndim 

3597 # take slice from z axis if needed 

3598 slices = () 

3599 if ndim > ndim_need: 

3600 for i, a in enumerate(axes): 

3601 s = data.shape[i] 

3602 if a.id == AxisId("z"): 

3603 data = data[slices + (slice(s // 2 - 1, s // 2),)] 

3604 data, axes = squeeze(data, axes) 

3605 ndim -= 1 

3606 break 

3607 

3608 slices += (slice(None),) 

3609 

3610 # take slice from any space or time axis 

3611 slices = () 

3612 

3613 for i, a in enumerate(axes): 

3614 if ndim <= ndim_need: 

3615 break 

3616 

3617 s = data.shape[i] 

3618 assert s > 1 

3619 if isinstance( 

3620 a, (SpaceInputAxis, SpaceOutputAxis, TimeInputAxis, TimeOutputAxis) 

3621 ): 

3622 data = data[slices + (slice(s // 2 - 1, s // 2),)] 

3623 ndim -= 1 

3624 

3625 slices += (slice(None),) 

3626 

3627 del slices 

3628 data, axes = squeeze(data, axes) 

3629 assert len(axes) == ndim 

3630 

3631 if (has_c_axis and ndim != 3) or (not has_c_axis and ndim != 2): 

3632 raise ValueError( 

3633 f"Failed to construct cover image from shape {original_shape}" 

3634 ) 

3635 

3636 if not has_c_axis: 

3637 assert ndim == 2 

3638 data = np.repeat(data[:, :, None], 3, axis=2) 

3639 axes.append(ChannelAxis(channel_names=list(map(Identifier, "RGB")))) 

3640 ndim += 1 

3641 

3642 assert ndim == 3 

3643 

3644 # transpose axis order such that longest axis comes first... 

3645 axis_order: List[int] = list(np.argsort(list(data.shape))) 

3646 axis_order.reverse() 

3647 # ... and channel axis is last 

3648 c = [i for i in range(3) if isinstance(axes[i], ChannelAxis)][0] 

3649 axis_order.append(axis_order.pop(c)) 

3650 axes = [axes[ao] for ao in axis_order] 

3651 data = data.transpose(axis_order) 

3652 

3653 # h, w = data.shape[:2] 

3654 # if h / w in (1.0, 2.0): 

3655 # pass 

3656 # elif h / w < 2: 

3657 # TODO: enforce 2:1 or 1:1 aspect ratio for generated cover images 

3658 

3659 norm_along = ( 

3660 tuple(i for i, a in enumerate(axes) if a.type in ("space", "time")) or None 

3661 ) 

3662 # normalize the data and map to 8 bit 

3663 data = normalize(data, norm_along) 

3664 data = (data * 255).astype("uint8") 

3665 

3666 return data 
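# Summary (added by the editor): `to_2d_image` reduces an arbitrary test tensor to
# an H x W x 3 uint8 image by squeezing singleton axes, slicing away batch, index
# and surplus channel axes, mapping 1/2/3+ channels to RGB, slicing z and any
# remaining space/time axes down to two dimensions, normalizing to [0, 1) and
# scaling to 8 bit, and finally transposing so the longest axis comes first and
# the channel axis last.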

3667 

3668 def create_diagonal_split_image(im0: NDArray[Any], im1: NDArray[Any]): 

3669 assert im0.dtype == im1.dtype == np.uint8 

3670 assert im0.shape == im1.shape 

3671 assert im0.ndim == 3 

3672 N, M, C = im0.shape 

3673 assert C == 3 

3674 out = np.ones((N, M, C), dtype="uint8") 

3675 for c in range(C): 

3676 outc = np.tril(im0[..., c]) 

3677 mask = outc == 0 

3678 outc[mask] = np.triu(im1[..., c])[mask] 

3679 out[..., c] = outc 

3680 

3681 return out 
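# Editor's note: `create_diagonal_split_image` builds the combined cover by keeping
# the lower triangle (np.tril) of the input image and filling the zeroed positions,
# essentially the upper triangle, with the corresponding pixels of the output image
# (np.triu), channel by channel, so that input and output appear split along the
# image diagonal.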

3682 

3683 if not inputs: 

3684 raise ValueError("Missing test input tensor for cover generation.") 

3685 

3686 if not outputs: 

3687 raise ValueError("Missing test output tensor for cover generation.") 

3688 

3689 ipt_descr, ipt = inputs[0] 

3690 out_descr, out = outputs[0] 

3691 

3692 ipt_img = to_2d_image(ipt, ipt_descr.axes) 

3693 out_img = to_2d_image(out, out_descr.axes) 

3694 

3695 cover_folder = Path(mkdtemp()) 

3696 if ipt_img.shape == out_img.shape: 

3697 covers = [cover_folder / "cover.png"] 

3698 imwrite(covers[0], create_diagonal_split_image(ipt_img, out_img)) 

3699 else: 

3700 covers = [cover_folder / "input.png", cover_folder / "output.png"] 

3701 imwrite(covers[0], ipt_img) 

3702 imwrite(covers[1], out_img) 

3703 

3704 return covers
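# Illustrative usage (added by the editor; a sketch, not part of the original
# module). It assumes the `test_tensor` file description of each tensor can be
# passed directly to `load_array`; `model` is a hypothetical, fully validated
# ModelDescr.
def _example_generate_covers(model: ModelDescr) -> List[Path]:
    inputs = [(t, load_array(t.test_tensor)) for t in model.inputs]
    outputs = [(t, load_array(t.test_tensor)) for t in model.outputs]
    # a single "cover.png" is written if input and output reduce to the same 2D
    # shape, otherwise separate "input.png" and "output.png" files
    return generate_covers(inputs, outputs)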