Coverage for bioimageio/spec/model/v0_5.py: 75%

1355 statements  

coverage.py v7.10.6, created at 2025-09-11 07:34 +0000

1from __future__ import annotations 

2 

3import collections.abc 

4import re 

5import string 

6import warnings 

7from abc import ABC 

8from copy import deepcopy 

9from itertools import chain 

10from math import ceil 

11from pathlib import Path, PurePosixPath 

12from tempfile import mkdtemp 

13from typing import ( 

14 TYPE_CHECKING, 

15 Any, 

16 Callable, 

17 ClassVar, 

18 Dict, 

19 Generic, 

20 List, 

21 Literal, 

22 Mapping, 

23 NamedTuple, 

24 Optional, 

25 Sequence, 

26 Set, 

27 Tuple, 

28 Type, 

29 TypeVar, 

30 Union, 

31 cast, 

32) 

33 

34import numpy as np 

35from annotated_types import Ge, Gt, Interval, MaxLen, MinLen, Predicate 

36from imageio.v3 import imread, imwrite # pyright: ignore[reportUnknownVariableType] 

37from loguru import logger 

38from numpy.typing import NDArray 

39from pydantic import ( 

40 AfterValidator, 

41 Discriminator, 

42 Field, 

43 RootModel, 

44 SerializationInfo, 

45 SerializerFunctionWrapHandler, 

46 StrictInt, 

47 Tag, 

48 ValidationInfo, 

49 WrapSerializer, 

50 field_validator, 

51 model_serializer, 

52 model_validator, 

53) 

54from typing_extensions import Annotated, Self, assert_never, get_args 

55 

56from .._internal.common_nodes import ( 

57 InvalidDescr, 

58 Node, 

59 NodeWithExplicitlySetFields, 

60) 

61from .._internal.constants import DTYPE_LIMITS 

62from .._internal.field_warning import issue_warning, warn 

63from .._internal.io import BioimageioYamlContent as BioimageioYamlContent 

64from .._internal.io import FileDescr as FileDescr 

65from .._internal.io import ( 

66 FileSource, 

67 WithSuffix, 

68 YamlValue, 

69 get_reader, 

70 wo_special_file_name, 

71) 

72from .._internal.io_basics import Sha256 as Sha256 

73from .._internal.io_packaging import ( 

74 FileDescr_, 

75 FileSource_, 

76 package_file_descr_serializer, 

77) 

78from .._internal.io_utils import load_array 

79from .._internal.node_converter import Converter 

80from .._internal.type_guards import is_dict, is_sequence 

81from .._internal.types import ( 

82 FAIR, 

83 AbsoluteTolerance, 

84 LowerCaseIdentifier, 

85 LowerCaseIdentifierAnno, 

86 MismatchedElementsPerMillion, 

87 RelativeTolerance, 

88) 

89from .._internal.types import Datetime as Datetime 

90from .._internal.types import Identifier as Identifier 

91from .._internal.types import NotEmpty as NotEmpty 

92from .._internal.types import SiUnit as SiUnit 

93from .._internal.url import HttpUrl as HttpUrl 

94from .._internal.validation_context import get_validation_context 

95from .._internal.validator_annotations import RestrictCharacters 

96from .._internal.version_type import Version as Version 

97from .._internal.warning_levels import INFO 

98from ..dataset.v0_2 import DatasetDescr as DatasetDescr02 

99from ..dataset.v0_2 import LinkedDataset as LinkedDataset02 

100from ..dataset.v0_3 import DatasetDescr as DatasetDescr 

101from ..dataset.v0_3 import DatasetId as DatasetId 

102from ..dataset.v0_3 import LinkedDataset as LinkedDataset 

103from ..dataset.v0_3 import Uploader as Uploader 

104from ..generic.v0_3 import ( 

105 VALID_COVER_IMAGE_EXTENSIONS as VALID_COVER_IMAGE_EXTENSIONS, 

106) 

107from ..generic.v0_3 import Author as Author 

108from ..generic.v0_3 import BadgeDescr as BadgeDescr 

109from ..generic.v0_3 import CiteEntry as CiteEntry 

110from ..generic.v0_3 import DeprecatedLicenseId as DeprecatedLicenseId 

111from ..generic.v0_3 import Doi as Doi 

112from ..generic.v0_3 import ( 

113 FileSource_documentation, 

114 GenericModelDescrBase, 

115 LinkedResourceBase, 

116 _author_conv, # pyright: ignore[reportPrivateUsage] 

117 _maintainer_conv, # pyright: ignore[reportPrivateUsage] 

118) 

119from ..generic.v0_3 import LicenseId as LicenseId 

120from ..generic.v0_3 import LinkedResource as LinkedResource 

121from ..generic.v0_3 import Maintainer as Maintainer 

122from ..generic.v0_3 import OrcidId as OrcidId 

123from ..generic.v0_3 import RelativeFilePath as RelativeFilePath 

124from ..generic.v0_3 import ResourceId as ResourceId 

125from .v0_4 import Author as _Author_v0_4 

126from .v0_4 import BinarizeDescr as _BinarizeDescr_v0_4 

127from .v0_4 import CallableFromDepencency as CallableFromDepencency 

128from .v0_4 import CallableFromDepencency as _CallableFromDepencency_v0_4 

129from .v0_4 import CallableFromFile as _CallableFromFile_v0_4 

130from .v0_4 import ClipDescr as _ClipDescr_v0_4 

131from .v0_4 import ClipKwargs as ClipKwargs 

132from .v0_4 import ImplicitOutputShape as _ImplicitOutputShape_v0_4 

133from .v0_4 import InputTensorDescr as _InputTensorDescr_v0_4 

134from .v0_4 import KnownRunMode as KnownRunMode 

135from .v0_4 import ModelDescr as _ModelDescr_v0_4 

136from .v0_4 import OutputTensorDescr as _OutputTensorDescr_v0_4 

137from .v0_4 import ParameterizedInputShape as _ParameterizedInputShape_v0_4 

138from .v0_4 import PostprocessingDescr as _PostprocessingDescr_v0_4 

139from .v0_4 import PreprocessingDescr as _PreprocessingDescr_v0_4 

140from .v0_4 import ProcessingKwargs as ProcessingKwargs 

141from .v0_4 import RunMode as RunMode 

142from .v0_4 import ScaleLinearDescr as _ScaleLinearDescr_v0_4 

143from .v0_4 import ScaleMeanVarianceDescr as _ScaleMeanVarianceDescr_v0_4 

144from .v0_4 import ScaleRangeDescr as _ScaleRangeDescr_v0_4 

145from .v0_4 import SigmoidDescr as _SigmoidDescr_v0_4 

146from .v0_4 import TensorName as _TensorName_v0_4 

147from .v0_4 import WeightsFormat as WeightsFormat 

148from .v0_4 import ZeroMeanUnitVarianceDescr as _ZeroMeanUnitVarianceDescr_v0_4 

149from .v0_4 import package_weights 

150 

151SpaceUnit = Literal[ 

152 "attometer", 

153 "angstrom", 

154 "centimeter", 

155 "decimeter", 

156 "exameter", 

157 "femtometer", 

158 "foot", 

159 "gigameter", 

160 "hectometer", 

161 "inch", 

162 "kilometer", 

163 "megameter", 

164 "meter", 

165 "micrometer", 

166 "mile", 

167 "millimeter", 

168 "nanometer", 

169 "parsec", 

170 "petameter", 

171 "picometer", 

172 "terameter", 

173 "yard", 

174 "yoctometer", 

175 "yottameter", 

176 "zeptometer", 

177 "zettameter", 

178] 

179"""Space unit compatible to the [OME-Zarr axes specification 0.5](https://ngff.openmicroscopy.org/0.5/#axes-md)""" 

180 

181TimeUnit = Literal[ 

182 "attosecond", 

183 "centisecond", 

184 "day", 

185 "decisecond", 

186 "exasecond", 

187 "femtosecond", 

188 "gigasecond", 

189 "hectosecond", 

190 "hour", 

191 "kilosecond", 

192 "megasecond", 

193 "microsecond", 

194 "millisecond", 

195 "minute", 

196 "nanosecond", 

197 "petasecond", 

198 "picosecond", 

199 "second", 

200 "terasecond", 

201 "yoctosecond", 

202 "yottasecond", 

203 "zeptosecond", 

204 "zettasecond", 

205] 

206"""Time unit compatible to the [OME-Zarr axes specification 0.5](https://ngff.openmicroscopy.org/0.5/#axes-md)""" 

207 

208AxisType = Literal["batch", "channel", "index", "time", "space"] 

209 

210_AXIS_TYPE_MAP: Mapping[str, AxisType] = { 

211 "b": "batch", 

212 "t": "time", 

213 "i": "index", 

214 "c": "channel", 

215 "x": "space", 

216 "y": "space", 

217 "z": "space", 

218} 

219 

220_AXIS_ID_MAP = { 

221 "b": "batch", 

222 "t": "time", 

223 "i": "index", 

224 "c": "channel", 

225} 

226 

227 

228class TensorId(LowerCaseIdentifier): 

229 root_model: ClassVar[Type[RootModel[Any]]] = RootModel[ 

230 Annotated[LowerCaseIdentifierAnno, MaxLen(32)] 

231 ] 

232 

233 

234def _normalize_axis_id(a: str): 

235 a = str(a) 

236 normalized = _AXIS_ID_MAP.get(a, a) 

237 if a != normalized: 

238 logger.opt(depth=3).warning( 

239 "Normalized axis id from '{}' to '{}'.", a, normalized 

240 ) 

241 return normalized 

242 

243 

244class AxisId(LowerCaseIdentifier): 

245 root_model: ClassVar[Type[RootModel[Any]]] = RootModel[ 

246 Annotated[ 

247 LowerCaseIdentifierAnno, 

248 MaxLen(16), 

249 AfterValidator(_normalize_axis_id), 

250 ] 

251 ] 

252 

253 

254def _is_batch(a: str) -> bool: 

255 return str(a) == "batch" 

256 

257 

258def _is_not_batch(a: str) -> bool: 

259 return not _is_batch(a) 

260 

261 

262NonBatchAxisId = Annotated[AxisId, Predicate(_is_not_batch)] 

263 
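# Illustrative sketch (editorial addition, not part of the original module):
# `AxisId` runs `_normalize_axis_id` as an `AfterValidator`, so single-letter
# v0.4-style ids from `_AXIS_ID_MAP` are expanded to their v0.5 names.
def _example_axis_id_normalization() -> None:  # hypothetical helper, never called
    assert AxisId("x") == "x"  # spatial letters keep their id
    assert AxisId("c") == "channel"  # normalized via _AXIS_ID_MAP (logs a warning)
    assert not _is_not_batch(AxisId("b"))  # "b" becomes "batch"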

264PreprocessingId = Literal[ 

265 "binarize", 

266 "clip", 

267 "ensure_dtype", 

268 "fixed_zero_mean_unit_variance", 

269 "scale_linear", 

270 "scale_range", 

271 "sigmoid", 

272 "softmax", 

273] 

274PostprocessingId = Literal[ 

275 "binarize", 

276 "clip", 

277 "ensure_dtype", 

278 "fixed_zero_mean_unit_variance", 

279 "scale_linear", 

280 "scale_mean_variance", 

281 "scale_range", 

282 "sigmoid", 

283 "softmax", 

284 "zero_mean_unit_variance", 

285] 

286 

287 

288SAME_AS_TYPE = "<same as type>" 

289 

290 

291ParameterizedSize_N = int 

292""" 

293Annotates an integer to calculate a concrete axis size from a `ParameterizedSize`. 

294""" 

295 

296 

297class ParameterizedSize(Node): 

298 """Describes a range of valid tensor axis sizes as `size = min + n*step`. 

299 

300 - **min** and **step** are given by the model description. 

301 - All blocksize parameters n = 0,1,2,... yield a valid `size`. 
302 - A greater blocksize parameter n = 0,1,2,... results in a greater **size**. 
303 This allows adjusting the axis size more generically. 

304 """ 

305 

306 N: ClassVar[Type[int]] = ParameterizedSize_N 

307 """Positive integer to parameterize this axis""" 

308 

309 min: Annotated[int, Gt(0)] 

310 step: Annotated[int, Gt(0)] 

311 

312 def validate_size(self, size: int) -> int: 

313 if size < self.min: 

314 raise ValueError(f"size {size} < {self.min}") 

315 if (size - self.min) % self.step != 0: 

316 raise ValueError( 

317 f"axis of size {size} is not parameterized by `min + n*step` =" 

318 + f" `{self.min} + n*{self.step}`" 

319 ) 

320 

321 return size 

322 

323 def get_size(self, n: ParameterizedSize_N) -> int: 

324 return self.min + self.step * n 

325 

326 def get_n(self, s: int) -> ParameterizedSize_N: 

327 """return smallest n parameterizing a size greater or equal than `s`""" 

328 return ceil((s - self.min) / self.step) 

329 

330 

331class DataDependentSize(Node): 

332 min: Annotated[int, Gt(0)] = 1 

333 max: Annotated[Optional[int], Gt(1)] = None 

334 

335 @model_validator(mode="after") 

336 def _validate_max_gt_min(self): 

337 if self.max is not None and self.min >= self.max: 

338 raise ValueError(f"expected `min` < `max`, but got {self.min}, {self.max}") 

339 

340 return self 

341 

342 def validate_size(self, size: int) -> int: 

343 if size < self.min: 

344 raise ValueError(f"size {size} < {self.min}") 

345 

346 if self.max is not None and size > self.max: 

347 raise ValueError(f"size {size} > {self.max}") 

348 

349 return size 

350 
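# Illustrative sketch (editorial addition, not part of the original module):
# a `DataDependentSize` only bounds an output extent that is known after inference.
def _example_data_dependent_size() -> None:  # hypothetical helper, never called
    s = DataDependentSize(min=1, max=512)
    assert s.validate_size(100) == 100  # sizes outside [min, max] raise ValueError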

351 

352class SizeReference(Node): 

353 """A tensor axis size (extent in pixels/frames) defined in relation to a reference axis. 

354 

355 `axis.size = reference.size * reference.scale / axis.scale + offset` 

356 

357 Note: 

358 1. The axis and the referenced axis need to have the same unit (or no unit). 

359 2. Batch axes may not be referenced. 

360 3. Fractions are rounded down. 

361 4. If the reference axis is `concatenable` the referencing axis is assumed to be 

362 `concatenable` as well with the same block order. 

363 

364 Example: 

365 An anisotropic input image of w*h=100*49 pixels depicts a physical space of 200*196 mm². 

366 Let's assume that we want to express the image height h in relation to its width w 

367 instead of only accepting input images of exactly 100*49 pixels 

368 (for example to express a range of valid image shapes by parametrizing w, see `ParameterizedSize`). 

369 

370 >>> w = SpaceInputAxis(id=AxisId("w"), size=100, unit="millimeter", scale=2) 

371 >>> h = SpaceInputAxis( 

372 ... id=AxisId("h"), 

373 ... size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("w"), offset=-1), 

374 ... unit="millimeter", 

375 ... scale=4, 

376 ... ) 

377 >>> print(h.size.get_size(h, w)) 

378 49 

379 

380 ⇒ h = w * w.scale / h.scale + offset = 100 * 2mm / 4mm - 1 = 49 

381 """ 

382 

383 tensor_id: TensorId 

384 """tensor id of the reference axis""" 

385 

386 axis_id: AxisId 

387 """axis id of the reference axis""" 

388 

389 offset: StrictInt = 0 

390 

391 def get_size( 

392 self, 

393 axis: Union[ 

394 ChannelAxis, 

395 IndexInputAxis, 

396 IndexOutputAxis, 

397 TimeInputAxis, 

398 SpaceInputAxis, 

399 TimeOutputAxis, 

400 TimeOutputAxisWithHalo, 

401 SpaceOutputAxis, 

402 SpaceOutputAxisWithHalo, 

403 ], 

404 ref_axis: Union[ 

405 ChannelAxis, 

406 IndexInputAxis, 

407 IndexOutputAxis, 

408 TimeInputAxis, 

409 SpaceInputAxis, 

410 TimeOutputAxis, 

411 TimeOutputAxisWithHalo, 

412 SpaceOutputAxis, 

413 SpaceOutputAxisWithHalo, 

414 ], 

415 n: ParameterizedSize_N = 0, 

416 ref_size: Optional[int] = None, 

417 ): 

418 """Compute the concrete size for a given axis and its reference axis. 

419 

420 Args: 

421 axis: The axis this `SizeReference` is the size of. 

422 ref_axis: The reference axis to compute the size from. 

423 n: If the **ref_axis** is parameterized (of type `ParameterizedSize`) 

424 and no fixed **ref_size** is given, 

425 **n** is used to compute the size of the parameterized **ref_axis**. 

426 ref_size: Overwrite the reference size instead of deriving it from 

427 **ref_axis** 

428 (**ref_axis.scale** is still used; any given **n** is ignored). 

429 """ 

430 assert axis.size == self, ( 

431 "Given `axis.size` is not defined by this `SizeReference`" 

432 ) 

433 

434 assert ref_axis.id == self.axis_id, ( 

435 f"Expected `ref_axis.id` to be {self.axis_id}, but got {ref_axis.id}." 

436 ) 

437 

438 assert axis.unit == ref_axis.unit, ( 

439 "`SizeReference` requires `axis` and `ref_axis` to have the same `unit`," 

440 f" but {axis.unit}!={ref_axis.unit}" 

441 ) 

442 if ref_size is None: 

443 if isinstance(ref_axis.size, (int, float)): 

444 ref_size = ref_axis.size 

445 elif isinstance(ref_axis.size, ParameterizedSize): 

446 ref_size = ref_axis.size.get_size(n) 

447 elif isinstance(ref_axis.size, DataDependentSize): 

448 raise ValueError( 

449 "Reference axis referenced in `SizeReference` may not be a `DataDependentSize`." 

450 ) 

451 elif isinstance(ref_axis.size, SizeReference): 

452 raise ValueError( 

453 "Reference axis referenced in `SizeReference` may not be sized by a" 

454 + " `SizeReference` itself." 

455 ) 

456 else: 

457 assert_never(ref_axis.size) 

458 

459 return int(ref_size * ref_axis.scale / axis.scale + self.offset) 

460 

461 @staticmethod 

462 def _get_unit( 

463 axis: Union[ 

464 ChannelAxis, 

465 IndexInputAxis, 

466 IndexOutputAxis, 

467 TimeInputAxis, 

468 SpaceInputAxis, 

469 TimeOutputAxis, 

470 TimeOutputAxisWithHalo, 

471 SpaceOutputAxis, 

472 SpaceOutputAxisWithHalo, 

473 ], 

474 ): 

475 return axis.unit 

476 

477 

478class AxisBase(NodeWithExplicitlySetFields): 

479 id: AxisId 

480 """An axis id unique across all axes of one tensor.""" 

481 

482 description: Annotated[str, MaxLen(128)] = "" 

483 """A short description of this axis beyond its type and id.""" 

484 

485 

486class WithHalo(Node): 

487 halo: Annotated[int, Ge(1)] 

488 """The halo should be cropped from the output tensor to avoid boundary effects. 

489 It is to be cropped from both sides, i.e. `size_after_crop = size - 2 * halo`. 

490 To document a halo that is already cropped by the model use `size.offset` instead.""" 

491 

492 size: Annotated[ 

493 SizeReference, 

494 Field( 

495 examples=[ 

496 10, 

497 SizeReference( 

498 tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5 

499 ).model_dump(mode="json"), 

500 ] 

501 ), 

502 ] 

503 """reference to another axis with an optional offset (see `SizeReference`)""" 

504 

505 

506BATCH_AXIS_ID = AxisId("batch") 

507 

508 

509class BatchAxis(AxisBase): 

510 implemented_type: ClassVar[Literal["batch"]] = "batch" 

511 if TYPE_CHECKING: 

512 type: Literal["batch"] = "batch" 

513 else: 

514 type: Literal["batch"] 

515 

516 id: Annotated[AxisId, Predicate(_is_batch)] = BATCH_AXIS_ID 

517 size: Optional[Literal[1]] = None 

518 """The batch size may be fixed to 1, 

519 otherwise (the default) it may be chosen arbitrarily depending on available memory""" 

520 

521 @property 

522 def scale(self): 

523 return 1.0 

524 

525 @property 

526 def concatenable(self): 

527 return True 

528 

529 @property 

530 def unit(self): 

531 return None 

532 

533 

534class ChannelAxis(AxisBase): 

535 implemented_type: ClassVar[Literal["channel"]] = "channel" 

536 if TYPE_CHECKING: 

537 type: Literal["channel"] = "channel" 

538 else: 

539 type: Literal["channel"] 

540 

541 id: NonBatchAxisId = AxisId("channel") 

542 

543 channel_names: NotEmpty[List[Identifier]] 

544 

545 @property 

546 def size(self) -> int: 

547 return len(self.channel_names) 

548 

549 @property 

550 def concatenable(self): 

551 return False 

552 

553 @property 

554 def scale(self) -> float: 

555 return 1.0 

556 

557 @property 

558 def unit(self): 

559 return None 

560 

561 

562class IndexAxisBase(AxisBase): 

563 implemented_type: ClassVar[Literal["index"]] = "index" 

564 if TYPE_CHECKING: 

565 type: Literal["index"] = "index" 

566 else: 

567 type: Literal["index"] 

568 

569 id: NonBatchAxisId = AxisId("index") 

570 

571 @property 

572 def scale(self) -> float: 

573 return 1.0 

574 

575 @property 

576 def unit(self): 

577 return None 

578 

579 

580class _WithInputAxisSize(Node): 

581 size: Annotated[ 

582 Union[Annotated[int, Gt(0)], ParameterizedSize, SizeReference], 

583 Field( 

584 examples=[ 

585 10, 

586 ParameterizedSize(min=32, step=16).model_dump(mode="json"), 

587 SizeReference( 

588 tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5 

589 ).model_dump(mode="json"), 

590 ] 

591 ), 

592 ] 

593 """The size/length of this axis can be specified as 

594 - fixed integer 

595 - parameterized series of valid sizes (`ParameterizedSize`) 

596 - reference to another axis with an optional offset (`SizeReference`) 

597 """ 

598 

599 

600class IndexInputAxis(IndexAxisBase, _WithInputAxisSize): 

601 concatenable: bool = False 

602 """If a model has a `concatenable` input axis, it can be processed blockwise, 

603 splitting a longer sample axis into blocks matching its input tensor description. 

604 Output axes are concatenable if they have a `SizeReference` to a concatenable 

605 input axis. 

606 """ 

607 

608 

609class IndexOutputAxis(IndexAxisBase): 

610 size: Annotated[ 

611 Union[Annotated[int, Gt(0)], SizeReference, DataDependentSize], 

612 Field( 

613 examples=[ 

614 10, 

615 SizeReference( 

616 tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5 

617 ).model_dump(mode="json"), 

618 ] 

619 ), 

620 ] 

621 """The size/length of this axis can be specified as 

622 - fixed integer 

623 - reference to another axis with an optional offset (`SizeReference`) 

624 - data dependent size using `DataDependentSize` (size is only known after model inference) 

625 """ 

626 

627 

628class TimeAxisBase(AxisBase): 

629 implemented_type: ClassVar[Literal["time"]] = "time" 

630 if TYPE_CHECKING: 

631 type: Literal["time"] = "time" 

632 else: 

633 type: Literal["time"] 

634 

635 id: NonBatchAxisId = AxisId("time") 

636 unit: Optional[TimeUnit] = None 

637 scale: Annotated[float, Gt(0)] = 1.0 

638 

639 

640class TimeInputAxis(TimeAxisBase, _WithInputAxisSize): 

641 concatenable: bool = False 

642 """If a model has a `concatenable` input axis, it can be processed blockwise, 

643 splitting a longer sample axis into blocks matching its input tensor description. 

644 Output axes are concatenable if they have a `SizeReference` to a concatenable 

645 input axis. 

646 """ 

647 

648 

649class SpaceAxisBase(AxisBase): 

650 implemented_type: ClassVar[Literal["space"]] = "space" 

651 if TYPE_CHECKING: 

652 type: Literal["space"] = "space" 

653 else: 

654 type: Literal["space"] 

655 

656 id: Annotated[NonBatchAxisId, Field(examples=["x", "y", "z"])] = AxisId("x") 

657 unit: Optional[SpaceUnit] = None 

658 scale: Annotated[float, Gt(0)] = 1.0 

659 

660 

661class SpaceInputAxis(SpaceAxisBase, _WithInputAxisSize): 

662 concatenable: bool = False 

663 """If a model has a `concatenable` input axis, it can be processed blockwise, 

664 splitting a longer sample axis into blocks matching its input tensor description. 

665 Output axes are concatenable if they have a `SizeReference` to a concatenable 

666 input axis. 

667 """ 

668 

669 

670INPUT_AXIS_TYPES = ( 

671 BatchAxis, 

672 ChannelAxis, 

673 IndexInputAxis, 

674 TimeInputAxis, 

675 SpaceInputAxis, 

676) 

677"""intended for isinstance comparisons in py<3.10""" 

678 

679_InputAxisUnion = Union[ 

680 BatchAxis, ChannelAxis, IndexInputAxis, TimeInputAxis, SpaceInputAxis 

681] 

682InputAxis = Annotated[_InputAxisUnion, Discriminator("type")] 

683 
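# Illustrative sketch (editorial addition, not part of the original module):
# the `type` discriminator selects the concrete input axis class when validating
# raw (e.g. YAML-derived) data; pydantic's `TypeAdapter` is assumed here.
def _example_parse_input_axis() -> None:  # hypothetical helper, never called
    from pydantic import TypeAdapter

    axis = TypeAdapter(InputAxis).validate_python(
        {"type": "space", "id": "x", "size": {"min": 32, "step": 16}}
    )
    assert isinstance(axis, SpaceInputAxis)
    assert isinstance(axis.size, ParameterizedSize)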

684 

685class _WithOutputAxisSize(Node): 

686 size: Annotated[ 

687 Union[Annotated[int, Gt(0)], SizeReference], 

688 Field( 

689 examples=[ 

690 10, 

691 SizeReference( 

692 tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5 

693 ).model_dump(mode="json"), 

694 ] 

695 ), 

696 ] 

697 """The size/length of this axis can be specified as 

698 - fixed integer 

699 - reference to another axis with an optional offset (see `SizeReference`) 

700 """ 

701 

702 

703class TimeOutputAxis(TimeAxisBase, _WithOutputAxisSize): 

704 pass 

705 

706 

707class TimeOutputAxisWithHalo(TimeAxisBase, WithHalo): 

708 pass 

709 

710 

711def _get_halo_axis_discriminator_value(v: Any) -> Literal["with_halo", "wo_halo"]: 

712 if isinstance(v, dict): 

713 return "with_halo" if "halo" in v else "wo_halo" 

714 else: 

715 return "with_halo" if hasattr(v, "halo") else "wo_halo" 

716 
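# Illustrative sketch (editorial addition, not part of the original module):
# output time/space axes are split into "with halo"/"without halo" variants purely
# by the presence of a `halo` key (or attribute) in the raw data.
def _example_halo_discriminator() -> None:  # hypothetical helper, never called
    assert _get_halo_axis_discriminator_value({"id": "y", "halo": 8}) == "with_halo"
    assert _get_halo_axis_discriminator_value({"id": "y"}) == "wo_halo"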

717 

718_TimeOutputAxisUnion = Annotated[ 

719 Union[ 

720 Annotated[TimeOutputAxis, Tag("wo_halo")], 

721 Annotated[TimeOutputAxisWithHalo, Tag("with_halo")], 

722 ], 

723 Discriminator(_get_halo_axis_discriminator_value), 

724] 

725 

726 

727class SpaceOutputAxis(SpaceAxisBase, _WithOutputAxisSize): 

728 pass 

729 

730 

731class SpaceOutputAxisWithHalo(SpaceAxisBase, WithHalo): 

732 pass 

733 
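# Illustrative sketch (editorial addition, not part of the original module):
# an output axis with a halo is sized by a `SizeReference`; the valid (cropped)
# extent is `size - 2 * halo`.
def _example_output_axis_with_halo() -> None:  # hypothetical helper, never called
    y_out = SpaceOutputAxisWithHalo(
        id=AxisId("y"),
        size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("y")),
        halo=16,  # e.g. a 128 pixel output keeps 128 - 2 * 16 = 96 valid pixels
    )
    assert y_out.halo == 16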

734 

735_SpaceOutputAxisUnion = Annotated[ 

736 Union[ 

737 Annotated[SpaceOutputAxis, Tag("wo_halo")], 

738 Annotated[SpaceOutputAxisWithHalo, Tag("with_halo")], 

739 ], 

740 Discriminator(_get_halo_axis_discriminator_value), 

741] 

742 

743 

744_OutputAxisUnion = Union[ 

745 BatchAxis, ChannelAxis, IndexOutputAxis, _TimeOutputAxisUnion, _SpaceOutputAxisUnion 

746] 

747OutputAxis = Annotated[_OutputAxisUnion, Discriminator("type")] 

748 

749OUTPUT_AXIS_TYPES = ( 

750 BatchAxis, 

751 ChannelAxis, 

752 IndexOutputAxis, 

753 TimeOutputAxis, 

754 TimeOutputAxisWithHalo, 

755 SpaceOutputAxis, 

756 SpaceOutputAxisWithHalo, 

757) 

758"""intended for isinstance comparisons in py<3.10""" 

759 

760 

761AnyAxis = Union[InputAxis, OutputAxis] 

762 

763ANY_AXIS_TYPES = INPUT_AXIS_TYPES + OUTPUT_AXIS_TYPES 

764"""intended for isinstance comparisons in py<3.10""" 

765 

766TVs = Union[ 

767 NotEmpty[List[int]], 

768 NotEmpty[List[float]], 

769 NotEmpty[List[bool]], 

770 NotEmpty[List[str]], 

771] 

772 

773 

774NominalOrOrdinalDType = Literal[ 

775 "float32", 

776 "float64", 

777 "uint8", 

778 "int8", 

779 "uint16", 

780 "int16", 

781 "uint32", 

782 "int32", 

783 "uint64", 

784 "int64", 

785 "bool", 

786] 

787 

788 

789class NominalOrOrdinalDataDescr(Node): 

790 values: TVs 

791 """A fixed set of nominal or an ascending sequence of ordinal values. 

792 In this case `data.type` is required to be an unsigned integer type, e.g. 'uint8'. 

793 String `values` are interpreted as labels for tensor values 0, ..., N. 

794 Note: as YAML 1.2 does not natively support a "set" datatype, 

795 nominal values should be given as a sequence (aka list/array) as well. 

796 """ 

797 

798 type: Annotated[ 

799 NominalOrOrdinalDType, 

800 Field( 

801 examples=[ 

802 "float32", 

803 "uint8", 

804 "uint16", 

805 "int64", 

806 "bool", 

807 ], 

808 ), 

809 ] = "uint8" 

810 

811 @model_validator(mode="after") 

812 def _validate_values_match_type( 

813 self, 

814 ) -> Self: 

815 incompatible: List[Any] = [] 

816 for v in self.values: 

817 if self.type == "bool": 

818 if not isinstance(v, bool): 

819 incompatible.append(v) 

820 elif self.type in DTYPE_LIMITS: 

821 if ( 

822 isinstance(v, (int, float)) 

823 and ( 

824 v < DTYPE_LIMITS[self.type].min 

825 or v > DTYPE_LIMITS[self.type].max 

826 ) 

827 or (isinstance(v, str) and "uint" not in self.type) 

828 or (isinstance(v, float) and "int" in self.type) 

829 ): 

830 incompatible.append(v) 

831 else: 

832 incompatible.append(v) 

833 

834 if len(incompatible) == 5: 

835 incompatible.append("...") 

836 break 

837 

838 if incompatible: 

839 raise ValueError( 

840 f"data type '{self.type}' incompatible with values {incompatible}" 

841 ) 

842 

843 return self 

844 

845 unit: Optional[Union[Literal["arbitrary unit"], SiUnit]] = None 

846 

847 @property 

848 def range(self): 

849 if isinstance(self.values[0], str): 

850 return 0, len(self.values) - 1 

851 else: 

852 return min(self.values), max(self.values) 

853 

854 
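# Illustrative sketch (editorial addition, not part of the original module):
# string `values` label the tensor values 0..N, numeric `values` are used directly.
def _example_nominal_or_ordinal_data() -> None:  # hypothetical helper, never called
    labels = NominalOrOrdinalDataDescr(values=["background", "foreground"], type="uint8")
    assert labels.range == (0, 1)
    ordinal = NominalOrOrdinalDataDescr(values=[0, 1, 2], type="uint8")
    assert ordinal.range == (0, 2)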

855IntervalOrRatioDType = Literal[ 

856 "float32", 

857 "float64", 

858 "uint8", 

859 "int8", 

860 "uint16", 

861 "int16", 

862 "uint32", 

863 "int32", 

864 "uint64", 

865 "int64", 

866] 

867 

868 

869class IntervalOrRatioDataDescr(Node): 

870 type: Annotated[ # TODO: rename to dtype 

871 IntervalOrRatioDType, 

872 Field( 

873 examples=["float32", "float64", "uint8", "uint16"], 

874 ), 

875 ] = "float32" 

876 range: Tuple[Optional[float], Optional[float]] = ( 

877 None, 

878 None, 

879 ) 

880 """Tuple `(minimum, maximum)` specifying the allowed range of the data in this tensor. 

881 `None` corresponds to min/max of what can be expressed by **type**.""" 

882 unit: Union[Literal["arbitrary unit"], SiUnit] = "arbitrary unit" 

883 scale: float = 1.0 

884 """Scale for data on an interval (or ratio) scale.""" 

885 offset: Optional[float] = None 

886 """Offset for data on a ratio scale.""" 

887 

888 @model_validator(mode="before") 

889 def _replace_inf(cls, data: Any): 

890 if is_dict(data): 

891 if "range" in data and is_sequence(data["range"]): 

892 forbidden = ( 

893 "inf", 

894 "-inf", 

895 ".inf", 

896 "-.inf", 

897 float("inf"), 

898 float("-inf"), 

899 ) 

900 if any(v in forbidden for v in data["range"]): 

901 issue_warning("replaced 'inf' value", value=data["range"]) 

902 

903 data["range"] = tuple( 

904 (None if v in forbidden else v) for v in data["range"] 

905 ) 

906 

907 return data 

908 
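# Illustrative sketch (editorial addition, not part of the original module):
# by default continuous tensor data is described as float32 with an open range.
def _example_interval_or_ratio_data() -> None:  # hypothetical helper, never called
    d = IntervalOrRatioDataDescr()
    assert d.type == "float32"
    assert d.range == (None, None)  # None means min/max of the dtype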

909 

910TensorDataDescr = Union[NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr] 

911 

912 

913class ProcessingDescrBase(NodeWithExplicitlySetFields, ABC): 

914 """processing base class""" 

915 

916 

917class BinarizeKwargs(ProcessingKwargs): 

918 """key word arguments for `BinarizeDescr`""" 

919 

920 threshold: float 

921 """The fixed threshold""" 

922 

923 

924class BinarizeAlongAxisKwargs(ProcessingKwargs): 

925 """key word arguments for `BinarizeDescr`""" 

926 

927 threshold: NotEmpty[List[float]] 

928 """The fixed threshold values along `axis`""" 

929 

930 axis: Annotated[NonBatchAxisId, Field(examples=["channel"])] 

931 """The `threshold` axis""" 

932 

933 

934class BinarizeDescr(ProcessingDescrBase): 

935 """Binarize the tensor with a fixed threshold. 

936 

937 Values above `BinarizeKwargs.threshold`/`BinarizeAlongAxisKwargs.threshold` 

938 will be set to one, values below the threshold to zero. 

939 

940 Examples: 

941 - in YAML 

942 ```yaml 

943 postprocessing: 

944 - id: binarize 

945 kwargs: 

946 axis: 'channel' 

947 threshold: [0.25, 0.5, 0.75] 

948 ``` 

949 - in Python: 

950 >>> postprocessing = [BinarizeDescr( 

951 ... kwargs=BinarizeAlongAxisKwargs( 

952 ... axis=AxisId('channel'), 

953 ... threshold=[0.25, 0.5, 0.75], 

954 ... ) 

955 ... )] 

956 """ 

957 

958 implemented_id: ClassVar[Literal["binarize"]] = "binarize" 

959 if TYPE_CHECKING: 

960 id: Literal["binarize"] = "binarize" 

961 else: 

962 id: Literal["binarize"] 

963 kwargs: Union[BinarizeKwargs, BinarizeAlongAxisKwargs] 

964 

965 

966class ClipDescr(ProcessingDescrBase): 

967 """Set tensor values below min to min and above max to max. 

968 

969 See `ScaleRangeDescr` for examples. 

970 """ 

971 

972 implemented_id: ClassVar[Literal["clip"]] = "clip" 

973 if TYPE_CHECKING: 

974 id: Literal["clip"] = "clip" 

975 else: 

976 id: Literal["clip"] 

977 

978 kwargs: ClipKwargs 

979 

980 

981class EnsureDtypeKwargs(ProcessingKwargs): 

982 """key word arguments for `EnsureDtypeDescr`""" 

983 

984 dtype: Literal[ 

985 "float32", 

986 "float64", 

987 "uint8", 

988 "int8", 

989 "uint16", 

990 "int16", 

991 "uint32", 

992 "int32", 

993 "uint64", 

994 "int64", 

995 "bool", 

996 ] 

997 

998 

999class EnsureDtypeDescr(ProcessingDescrBase): 

1000 """Cast the tensor data type to `EnsureDtypeKwargs.dtype` (if not matching). 

1001 

1002 This can for example be used to ensure the inner neural network model gets a 

1003 different input tensor data type than the fully described bioimage.io model does. 

1004 

1005 Examples: 

1006 The described bioimage.io model (incl. preprocessing) accepts any 

1007 float32-compatible tensor, normalizes it with percentiles and clipping and then 

1008 casts it to uint8, which is what the neural network in this example expects. 

1009 - in YAML 

1010 ```yaml 

1011 inputs: 

1012 - data: 

1013 type: float32 # described bioimage.io model is compatible with any float32 input tensor 

1014 preprocessing: 

1015 - id: scale_range 

1016 kwargs: 

1017 axes: ['y', 'x'] 

1018 max_percentile: 99.8 

1019 min_percentile: 5.0 

1020 - id: clip 

1021 kwargs: 

1022 min: 0.0 

1023 max: 1.0 

1024 - id: ensure_dtype # the neural network of the model requires uint8 

1025 kwargs: 

1026 dtype: uint8 

1027 ``` 

1028 - in Python: 

1029 >>> preprocessing = [ 

1030 ... ScaleRangeDescr( 

1031 ... kwargs=ScaleRangeKwargs( 

1032 ... axes= (AxisId('y'), AxisId('x')), 

1033 ... max_percentile= 99.8, 

1034 ... min_percentile= 5.0, 

1035 ... ) 

1036 ... ), 

1037 ... ClipDescr(kwargs=ClipKwargs(min=0.0, max=1.0)), 

1038 ... EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="uint8")), 

1039 ... ] 

1040 """ 

1041 

1042 implemented_id: ClassVar[Literal["ensure_dtype"]] = "ensure_dtype" 

1043 if TYPE_CHECKING: 

1044 id: Literal["ensure_dtype"] = "ensure_dtype" 

1045 else: 

1046 id: Literal["ensure_dtype"] 

1047 

1048 kwargs: EnsureDtypeKwargs 

1049 

1050 

1051class ScaleLinearKwargs(ProcessingKwargs): 

1052 """Key word arguments for `ScaleLinearDescr`""" 

1053 

1054 gain: float = 1.0 

1055 """multiplicative factor""" 

1056 

1057 offset: float = 0.0 

1058 """additive term""" 

1059 

1060 @model_validator(mode="after") 

1061 def _validate(self) -> Self: 

1062 if self.gain == 1.0 and self.offset == 0.0: 

1063 raise ValueError( 

1064 "Redundant linear scaling not allowd. Set `gain` != 1.0 and/or `offset`" 

1065 + " != 0.0." 

1066 ) 

1067 

1068 return self 

1069 

1070 

1071class ScaleLinearAlongAxisKwargs(ProcessingKwargs): 

1072 """Key word arguments for `ScaleLinearDescr`""" 

1073 

1074 axis: Annotated[NonBatchAxisId, Field(examples=["channel"])] 

1075 """The axis of gain and offset values.""" 

1076 

1077 gain: Union[float, NotEmpty[List[float]]] = 1.0 

1078 """multiplicative factor""" 

1079 

1080 offset: Union[float, NotEmpty[List[float]]] = 0.0 

1081 """additive term""" 

1082 

1083 @model_validator(mode="after") 

1084 def _validate(self) -> Self: 

1085 if isinstance(self.gain, list): 

1086 if isinstance(self.offset, list): 

1087 if len(self.gain) != len(self.offset): 

1088 raise ValueError( 

1089 f"Size of `gain` ({len(self.gain)}) and `offset` ({len(self.offset)}) must match." 

1090 ) 

1091 else: 

1092 self.offset = [float(self.offset)] * len(self.gain) 

1093 elif isinstance(self.offset, list): 

1094 self.gain = [float(self.gain)] * len(self.offset) 

1095 else: 

1096 raise ValueError( 

1097 "Do not specify an `axis` for scalar gain and offset values." 

1098 ) 

1099 

1100 if all(g == 1.0 for g in self.gain) and all(off == 0.0 for off in self.offset): 

1101 raise ValueError( 

1102 "Redundant linear scaling not allowd. Set `gain` != 1.0 and/or `offset`" 

1103 + " != 0.0." 

1104 ) 

1105 

1106 return self 

1107 
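# Illustrative sketch (editorial addition, not part of the original module):
# a scalar `offset` (or `gain`) is broadcast to match the per-axis list.
def _example_scale_linear_along_axis() -> None:  # hypothetical helper, never called
    kw = ScaleLinearAlongAxisKwargs(axis=AxisId("channel"), gain=[1.0, 2.0, 3.0])
    assert kw.offset == [0.0, 0.0, 0.0]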

1108 

1109class ScaleLinearDescr(ProcessingDescrBase): 

1110 """Fixed linear scaling. 

1111 

1112 Examples: 

1113 1. Scale with scalar gain and offset 

1114 - in YAML 

1115 ```yaml 

1116 preprocessing: 

1117 - id: scale_linear 

1118 kwargs: 

1119 gain: 2.0 

1120 offset: 3.0 

1121 ``` 

1122 - in Python: 

1123 >>> preprocessing = [ 

1124 ... ScaleLinearDescr(kwargs=ScaleLinearKwargs(gain= 2.0, offset=3.0)) 

1125 ... ] 

1126 

1127 2. Independent scaling along an axis 

1128 - in YAML 

1129 ```yaml 

1130 preprocessing: 

1131 - id: scale_linear 

1132 kwargs: 

1133 axis: 'channel' 

1134 gain: [1.0, 2.0, 3.0] 

1135 ``` 

1136 - in Python: 

1137 >>> preprocessing = [ 

1138 ... ScaleLinearDescr( 

1139 ... kwargs=ScaleLinearAlongAxisKwargs( 

1140 ... axis=AxisId("channel"), 

1141 ... gain=[1.0, 2.0, 3.0], 

1142 ... ) 

1143 ... ) 

1144 ... ] 

1145 

1146 """ 

1147 

1148 implemented_id: ClassVar[Literal["scale_linear"]] = "scale_linear" 

1149 if TYPE_CHECKING: 

1150 id: Literal["scale_linear"] = "scale_linear" 

1151 else: 

1152 id: Literal["scale_linear"] 

1153 kwargs: Union[ScaleLinearKwargs, ScaleLinearAlongAxisKwargs] 

1154 

1155 

1156class SigmoidDescr(ProcessingDescrBase): 

1157 """The logistic sigmoid function, a.k.a. expit function. 

1158 

1159 Examples: 

1160 - in YAML 

1161 ```yaml 

1162 postprocessing: 

1163 - id: sigmoid 

1164 ``` 

1165 - in Python: 

1166 >>> postprocessing = [SigmoidDescr()] 

1167 """ 

1168 

1169 implemented_id: ClassVar[Literal["sigmoid"]] = "sigmoid" 

1170 if TYPE_CHECKING: 

1171 id: Literal["sigmoid"] = "sigmoid" 

1172 else: 

1173 id: Literal["sigmoid"] 

1174 

1175 @property 

1176 def kwargs(self) -> ProcessingKwargs: 

1177 """empty kwargs""" 

1178 return ProcessingKwargs() 

1179 

1180 

1181class SoftmaxKwargs(ProcessingKwargs): 

1182 """key word arguments for `SoftmaxDescr`""" 

1183 

1184 axis: Annotated[NonBatchAxisId, Field(examples=["channel"])] = AxisId("channel") 

1185 """The axis to apply the softmax function along. 

1186 Note: 

1187 Defaults to 'channel' axis 

1188 (which may not exist, in which case 

1189 a different axis id has to be specified). 

1190 """ 

1191 

1192 

1193class SoftmaxDescr(ProcessingDescrBase): 

1194 """The softmax function. 

1195 

1196 Examples: 

1197 - in YAML 

1198 ```yaml 

1199 postprocessing: 

1200 - id: softmax 

1201 kwargs: 

1202 axis: channel 

1203 ``` 

1204 - in Python: 

1205 >>> postprocessing = [SoftmaxDescr(kwargs=SoftmaxKwargs(axis=AxisId("channel")))] 

1206 """ 

1207 

1208 implemented_id: ClassVar[Literal["softmax"]] = "softmax" 

1209 if TYPE_CHECKING: 

1210 id: Literal["softmax"] = "softmax" 

1211 else: 

1212 id: Literal["softmax"] 

1213 

1214 kwargs: SoftmaxKwargs = Field(default_factory=SoftmaxKwargs.model_construct) 

1215 

1216 

1217class FixedZeroMeanUnitVarianceKwargs(ProcessingKwargs): 

1218 """key word arguments for `FixedZeroMeanUnitVarianceDescr`""" 

1219 

1220 mean: float 

1221 """The mean value to normalize with.""" 

1222 

1223 std: Annotated[float, Ge(1e-6)] 

1224 """The standard deviation value to normalize with.""" 

1225 

1226 

1227class FixedZeroMeanUnitVarianceAlongAxisKwargs(ProcessingKwargs): 

1228 """key word arguments for `FixedZeroMeanUnitVarianceDescr`""" 

1229 

1230 mean: NotEmpty[List[float]] 

1231 """The mean value(s) to normalize with.""" 

1232 

1233 std: NotEmpty[List[Annotated[float, Ge(1e-6)]]] 

1234 """The standard deviation value(s) to normalize with. 

1235 Size must match `mean` values.""" 

1236 

1237 axis: Annotated[NonBatchAxisId, Field(examples=["channel", "index"])] 

1238 """The axis of the mean/std values to normalize each entry along that dimension 

1239 separately.""" 

1240 

1241 @model_validator(mode="after") 

1242 def _mean_and_std_match(self) -> Self: 

1243 if len(self.mean) != len(self.std): 

1244 raise ValueError( 

1245 f"Size of `mean` ({len(self.mean)}) and `std` ({len(self.std)})" 

1246 + " must match." 

1247 ) 

1248 

1249 return self 

1250 

1251 

1252class FixedZeroMeanUnitVarianceDescr(ProcessingDescrBase): 

1253 """Subtract a given mean and divide by the standard deviation. 

1254 

1255 Normalize with fixed, precomputed values for 

1256 `FixedZeroMeanUnitVarianceKwargs.mean` and `FixedZeroMeanUnitVarianceKwargs.std` 

1257 Use `FixedZeroMeanUnitVarianceAlongAxisKwargs` for independent scaling along given 

1258 axes. 

1259 

1260 Examples: 

1261 1. scalar value for whole tensor 

1262 - in YAML 

1263 ```yaml 

1264 preprocessing: 

1265 - id: fixed_zero_mean_unit_variance 

1266 kwargs: 

1267 mean: 103.5 

1268 std: 13.7 

1269 ``` 

1270 - in Python 

1271 >>> preprocessing = [FixedZeroMeanUnitVarianceDescr( 

1272 ... kwargs=FixedZeroMeanUnitVarianceKwargs(mean=103.5, std=13.7) 

1273 ... )] 

1274 

1275 2. independently along an axis 

1276 - in YAML 

1277 ```yaml 

1278 preprocessing: 

1279 - id: fixed_zero_mean_unit_variance 

1280 kwargs: 

1281 axis: channel 

1282 mean: [101.5, 102.5, 103.5] 

1283 std: [11.7, 12.7, 13.7] 

1284 ``` 

1285 - in Python 

1286 >>> preprocessing = [FixedZeroMeanUnitVarianceDescr( 

1287 ... kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs( 

1288 ... axis=AxisId("channel"), 

1289 ... mean=[101.5, 102.5, 103.5], 

1290 ... std=[11.7, 12.7, 13.7], 

1291 ... ) 

1292 ... )] 

1293 """ 

1294 

1295 implemented_id: ClassVar[Literal["fixed_zero_mean_unit_variance"]] = ( 

1296 "fixed_zero_mean_unit_variance" 

1297 ) 

1298 if TYPE_CHECKING: 

1299 id: Literal["fixed_zero_mean_unit_variance"] = "fixed_zero_mean_unit_variance" 

1300 else: 

1301 id: Literal["fixed_zero_mean_unit_variance"] 

1302 

1303 kwargs: Union[ 

1304 FixedZeroMeanUnitVarianceKwargs, FixedZeroMeanUnitVarianceAlongAxisKwargs 

1305 ] 

1306 

1307 

1308class ZeroMeanUnitVarianceKwargs(ProcessingKwargs): 

1309 """key word arguments for `ZeroMeanUnitVarianceDescr`""" 

1310 

1311 axes: Annotated[ 

1312 Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")]) 

1313 ] = None 

1314 """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std. 

1315 For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x') 

1316 resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`. 

1317 To normalize each sample independently leave out the 'batch' axis. 

1318 Default: Scale all axes jointly.""" 

1319 

1320 eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6 

1321 """epsilon for numeric stability: `out = (tensor - mean) / (std + eps)`.""" 

1322 

1323 

1324class ZeroMeanUnitVarianceDescr(ProcessingDescrBase): 

1325 """Subtract mean and divide by variance. 

1326 

1327 Examples: 

1328 Subtract the tensor mean and divide by its standard deviation 

1329 - in YAML 

1330 ```yaml 

1331 preprocessing: 

1332 - id: zero_mean_unit_variance 

1333 ``` 

1334 - in Python 

1335 >>> preprocessing = [ZeroMeanUnitVarianceDescr()] 

1336 """ 

1337 

1338 implemented_id: ClassVar[Literal["zero_mean_unit_variance"]] = ( 

1339 "zero_mean_unit_variance" 

1340 ) 

1341 if TYPE_CHECKING: 

1342 id: Literal["zero_mean_unit_variance"] = "zero_mean_unit_variance" 

1343 else: 

1344 id: Literal["zero_mean_unit_variance"] 

1345 

1346 kwargs: ZeroMeanUnitVarianceKwargs = Field( 

1347 default_factory=ZeroMeanUnitVarianceKwargs.model_construct 

1348 ) 

1349 

1350 

1351class ScaleRangeKwargs(ProcessingKwargs): 

1352 """key word arguments for `ScaleRangeDescr` 

1353 

1354 For `min_percentile`=0.0 (the default) and `max_percentile`=100 (the default) 

1355 this processing step normalizes data to the [0, 1] interval. 
1356 For other percentiles the normalized values will partially be outside the [0, 1] 
1357 interval. Use `ScaleRangeDescr` followed by `ClipDescr` if you want to limit the 

1358 normalized values to a range. 

1359 """ 

1360 

1361 axes: Annotated[ 

1362 Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")]) 

1363 ] = None 

1364 """The subset of axes to normalize jointly, i.e. axes to reduce to compute the min/max percentile value. 

1365 For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x') 

1366 resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`. 

1367 To normalize samples independently, leave out the "batch" axis. 

1368 Default: Scale all axes jointly.""" 

1369 

1370 min_percentile: Annotated[float, Interval(ge=0, lt=100)] = 0.0 

1371 """The lower percentile used to determine the value to align with zero.""" 

1372 

1373 max_percentile: Annotated[float, Interval(gt=1, le=100)] = 100.0 

1374 """The upper percentile used to determine the value to align with one. 

1375 Has to be bigger than `min_percentile`. 

1376 The range is 1 to 100 instead of 0 to 100 to avoid mistakenly 

1377 accepting percentiles specified in the range 0.0 to 1.0.""" 

1378 

1379 eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6 

1380 """Epsilon for numeric stability. 

1381 `out = (tensor - v_lower) / (v_upper - v_lower + eps)`; 

1382 with `v_lower,v_upper` values at the respective percentiles.""" 

1383 

1384 reference_tensor: Optional[TensorId] = None 

1385 """Tensor ID to compute the percentiles from. Default: The tensor itself. 

1386 For any tensor in `inputs` only input tensor references are allowed.""" 

1387 

1388 @field_validator("max_percentile", mode="after") 

1389 @classmethod 

1390 def min_smaller_max(cls, value: float, info: ValidationInfo) -> float: 

1391 if (min_p := info.data["min_percentile"]) >= value: 

1392 raise ValueError(f"min_percentile {min_p} >= max_percentile {value}") 

1393 

1394 return value 

1395 

1396 

1397class ScaleRangeDescr(ProcessingDescrBase): 

1398 """Scale with percentiles. 

1399 

1400 Examples: 

1401 1. Scale linearly to map 5th percentile to 0 and 99.8th percentile to 1.0 

1402 - in YAML 

1403 ```yaml 

1404 preprocessing: 

1405 - id: scale_range 

1406 kwargs: 

1407 axes: ['y', 'x'] 

1408 max_percentile: 99.8 

1409 min_percentile: 5.0 

1410 ``` 

1411 - in Python 

1412 >>> preprocessing = [ 

1413 ... ScaleRangeDescr( 

1414 ... kwargs=ScaleRangeKwargs( 

1415 ... axes= (AxisId('y'), AxisId('x')), 

1416 ... max_percentile= 99.8, 

1417 ... min_percentile= 5.0, 

1418 ... ) 

1419 ... ), 

1420 ... ClipDescr( 

1421 ... kwargs=ClipKwargs( 

1422 ... min=0.0, 

1423 ... max=1.0, 

1424 ... ) 

1425 ... ), 

1426 ... ] 

1427 

1428 2. Combine the above scaling with additional clipping to clip values outside the range given by the percentiles. 

1429 - in YAML 

1430 ```yaml 

1431 preprocessing: 

1432 - id: scale_range 

1433 kwargs: 

1434 axes: ['y', 'x'] 

1435 max_percentile: 99.8 

1436 min_percentile: 5.0 


1438 - id: clip 

1439 kwargs: 

1440 min: 0.0 

1441 max: 1.0 

1442 ``` 

1443 - in Python 

1444 >>> preprocessing = [ScaleRangeDescr( 

1445 ... kwargs=ScaleRangeKwargs( 

1446 ... axes= (AxisId('y'), AxisId('x')), 

1447 ... max_percentile= 99.8, 

1448 ... min_percentile= 5.0, 

1449 ... ) 

1450 ... )] 

1451 

1452 """ 

1453 

1454 implemented_id: ClassVar[Literal["scale_range"]] = "scale_range" 

1455 if TYPE_CHECKING: 

1456 id: Literal["scale_range"] = "scale_range" 

1457 else: 

1458 id: Literal["scale_range"] 

1459 kwargs: ScaleRangeKwargs = Field(default_factory=ScaleRangeKwargs.model_construct) 

1460 

1461 

1462class ScaleMeanVarianceKwargs(ProcessingKwargs): 

1463 """key word arguments for `ScaleMeanVarianceKwargs`""" 

1464 

1465 reference_tensor: TensorId 

1466 """Name of tensor to match.""" 

1467 

1468 axes: Annotated[ 

1469 Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")]) 

1470 ] = None 

1471 """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std. 

1472 For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x') 

1473 resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`. 

1474 To normalize samples independently, leave out the 'batch' axis. 

1475 Default: Scale all axes jointly.""" 

1476 

1477 eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6 

1478 """Epsilon for numeric stability: 

1479 `out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.`""" 

1480 

1481 

1482class ScaleMeanVarianceDescr(ProcessingDescrBase): 

1483 """Scale a tensor's data distribution to match another tensor's mean/std. 

1484 `out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.` 

1485 """ 

1486 

1487 implemented_id: ClassVar[Literal["scale_mean_variance"]] = "scale_mean_variance" 

1488 if TYPE_CHECKING: 

1489 id: Literal["scale_mean_variance"] = "scale_mean_variance" 

1490 else: 

1491 id: Literal["scale_mean_variance"] 

1492 kwargs: ScaleMeanVarianceKwargs 

1493 

1494 

1495PreprocessingDescr = Annotated[ 

1496 Union[ 

1497 BinarizeDescr, 

1498 ClipDescr, 

1499 EnsureDtypeDescr, 

1500 FixedZeroMeanUnitVarianceDescr, 

1501 ScaleLinearDescr, 

1502 ScaleRangeDescr, 

1503 SigmoidDescr, 

1504 SoftmaxDescr, 

1505 ZeroMeanUnitVarianceDescr, 

1506 ], 

1507 Discriminator("id"), 

1508] 

1509PostprocessingDescr = Annotated[ 

1510 Union[ 

1511 BinarizeDescr, 

1512 ClipDescr, 

1513 EnsureDtypeDescr, 

1514 FixedZeroMeanUnitVarianceDescr, 

1515 ScaleLinearDescr, 

1516 ScaleMeanVarianceDescr, 

1517 ScaleRangeDescr, 

1518 SigmoidDescr, 

1519 SoftmaxDescr, 

1520 ZeroMeanUnitVarianceDescr, 

1521 ], 

1522 Discriminator("id"), 

1523] 

1524 

1525IO_AxisT = TypeVar("IO_AxisT", InputAxis, OutputAxis) 

1526 

1527 

1528class TensorDescrBase(Node, Generic[IO_AxisT]): 

1529 id: TensorId 

1530 """Tensor id. No duplicates are allowed.""" 

1531 

1532 description: Annotated[str, MaxLen(128)] = "" 

1533 """free text description""" 

1534 

1535 axes: NotEmpty[Sequence[IO_AxisT]] 

1536 """tensor axes""" 

1537 

1538 @property 

1539 def shape(self): 

1540 return tuple(a.size for a in self.axes) 

1541 

1542 @field_validator("axes", mode="after", check_fields=False) 

1543 @classmethod 

1544 def _validate_axes(cls, axes: Sequence[AnyAxis]) -> Sequence[AnyAxis]: 

1545 batch_axes = [a for a in axes if a.type == "batch"] 

1546 if len(batch_axes) > 1: 

1547 raise ValueError( 

1548 f"Only one batch axis (per tensor) allowed, but got {batch_axes}" 

1549 ) 

1550 

1551 seen_ids: Set[AxisId] = set() 

1552 duplicate_axes_ids: Set[AxisId] = set() 

1553 for a in axes: 

1554 (duplicate_axes_ids if a.id in seen_ids else seen_ids).add(a.id) 

1555 

1556 if duplicate_axes_ids: 

1557 raise ValueError(f"Duplicate axis ids: {duplicate_axes_ids}") 

1558 

1559 return axes 

1560 

1561 test_tensor: FAIR[Optional[FileDescr_]] = None 

1562 """An example tensor to use for testing. 

1563 Using the model with the test input tensors is expected to yield the test output tensors. 

1564 Each test tensor has to be an ndarray in the 

1565 [numpy.lib file format](https://numpy.org/doc/stable/reference/generated/numpy.lib.format.html#module-numpy.lib.format). 

1566 The file extension must be '.npy'.""" 

1567 

1568 sample_tensor: FAIR[Optional[FileDescr_]] = None 

1569 """A sample tensor to illustrate a possible input/output for the model, 

1570 The sample image primarily serves to inform a human user about an example use case 

1571 and is typically stored as .hdf5, .png or .tiff. 

1572 It has to be readable by the [imageio library](https://imageio.readthedocs.io/en/stable/formats/index.html#supported-formats) 

1573 (numpy's `.npy` format is not supported). 

1574 The image dimensionality has to match the number of axes specified in this tensor description. 

1575 """ 

1576 

1577 @model_validator(mode="after") 

1578 def _validate_sample_tensor(self) -> Self: 

1579 if self.sample_tensor is None or not get_validation_context().perform_io_checks: 

1580 return self 

1581 

1582 reader = get_reader(self.sample_tensor.source, sha256=self.sample_tensor.sha256) 

1583 tensor: NDArray[Any] = imread( 

1584 reader.read(), 

1585 extension=PurePosixPath(reader.original_file_name).suffix, 

1586 ) 

1587 n_dims = len(tensor.squeeze().shape) 

1588 n_dims_min = n_dims_max = len(self.axes) 

1589 

1590 for a in self.axes: 

1591 if isinstance(a, BatchAxis): 

1592 n_dims_min -= 1 

1593 elif isinstance(a.size, int): 

1594 if a.size == 1: 

1595 n_dims_min -= 1 

1596 elif isinstance(a.size, (ParameterizedSize, DataDependentSize)): 

1597 if a.size.min == 1: 

1598 n_dims_min -= 1 

1599 elif isinstance(a.size, SizeReference): 

1600 if a.size.offset < 2: 

1601 # size reference may result in singleton axis 

1602 n_dims_min -= 1 

1603 else: 

1604 assert_never(a.size) 

1605 

1606 n_dims_min = max(0, n_dims_min) 

1607 if n_dims < n_dims_min or n_dims > n_dims_max: 

1608 raise ValueError( 

1609 f"Expected sample tensor to have {n_dims_min} to" 

1610 + f" {n_dims_max} dimensions, but found {n_dims} (shape: {tensor.shape})." 

1611 ) 

1612 

1613 return self 

1614 

1615 data: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]] = ( 

1616 IntervalOrRatioDataDescr() 

1617 ) 

1618 """Description of the tensor's data values, optionally per channel. 

1619 If specified per channel, the data `type` needs to match across channels.""" 

1620 

1621 @property 

1622 def dtype( 

1623 self, 

1624 ) -> Literal[ 

1625 "float32", 

1626 "float64", 

1627 "uint8", 

1628 "int8", 

1629 "uint16", 

1630 "int16", 

1631 "uint32", 

1632 "int32", 

1633 "uint64", 

1634 "int64", 

1635 "bool", 

1636 ]: 

1637 """dtype as specified under `data.type` or `data[i].type`""" 

1638 if isinstance(self.data, collections.abc.Sequence): 

1639 return self.data[0].type 

1640 else: 

1641 return self.data.type 

1642 

1643 @field_validator("data", mode="after") 

1644 @classmethod 

1645 def _check_data_type_across_channels( 

1646 cls, value: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]] 

1647 ) -> Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]]: 

1648 if not isinstance(value, list): 

1649 return value 

1650 

1651 dtypes = {t.type for t in value} 

1652 if len(dtypes) > 1: 

1653 raise ValueError( 

1654 "Tensor data descriptions per channel need to agree in their data" 

1655 + f" `type`, but found {dtypes}." 

1656 ) 

1657 

1658 return value 

1659 

1660 @model_validator(mode="after") 

1661 def _check_data_matches_channelaxis(self) -> Self: 

1662 if not isinstance(self.data, (list, tuple)): 

1663 return self 

1664 

1665 for a in self.axes: 

1666 if isinstance(a, ChannelAxis): 

1667 size = a.size 

1668 assert isinstance(size, int) 

1669 break 

1670 else: 

1671 return self 

1672 

1673 if len(self.data) != size: 

1674 raise ValueError( 

1675 f"Got tensor data descriptions for {len(self.data)} channels, but" 

1676 + f" '{a.id}' axis has size {size}." 

1677 ) 

1678 

1679 return self 

1680 

1681 def get_axis_sizes_for_array(self, array: NDArray[Any]) -> Dict[AxisId, int]: 

1682 if len(array.shape) != len(self.axes): 

1683 raise ValueError( 

1684 f"Dimension mismatch: array shape {array.shape} (#{len(array.shape)})" 

1685 + f" incompatible with {len(self.axes)} axes." 

1686 ) 

1687 return {a.id: array.shape[i] for i, a in enumerate(self.axes)} 

1688 

1689 

1690class InputTensorDescr(TensorDescrBase[InputAxis]): 

1691 id: TensorId = TensorId("input") 

1692 """Input tensor id. 

1693 No duplicates are allowed across all inputs and outputs.""" 

1694 

1695 optional: bool = False 

1696 """indicates that this tensor may be `None`""" 

1697 

1698 preprocessing: List[PreprocessingDescr] = Field( 

1699 default_factory=cast(Callable[[], List[PreprocessingDescr]], list) 

1700 ) 

1701 

1702 """Description of how this input should be preprocessed. 

1703 

1704 notes: 

1705 - If preprocessing does not start with an 'ensure_dtype' entry, it is added 

1706 to ensure an input tensor's data type matches the input tensor's data description. 

1707 - If preprocessing does not end with an 'ensure_dtype' or 'binarize' entry, an 

1708 'ensure_dtype' step is added to ensure preprocessing steps are not unintentionally 

1709 changing the data type. 

1710 """ 

1711 

1712 @model_validator(mode="after") 

1713 def _validate_preprocessing_kwargs(self) -> Self: 

1714 axes_ids = [a.id for a in self.axes] 

1715 for p in self.preprocessing: 

1716 kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes") 

1717 if kwargs_axes is None: 

1718 continue 

1719 

1720 if not isinstance(kwargs_axes, collections.abc.Sequence): 

1721 raise ValueError( 

1722 f"Expected `preprocessing.i.kwargs.axes` to be a sequence, but got {type(kwargs_axes)}" 

1723 ) 

1724 

1725 if any(a not in axes_ids for a in kwargs_axes): 

1726 raise ValueError( 

1727 "`preprocessing.i.kwargs.axes` needs to be subset of axes ids" 

1728 ) 

1729 

1730 if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)): 

1731 dtype = self.data.type 

1732 else: 

1733 dtype = self.data[0].type 

1734 

1735 # ensure `preprocessing` begins with `EnsureDtypeDescr` 

1736 if not self.preprocessing or not isinstance( 

1737 self.preprocessing[0], EnsureDtypeDescr 

1738 ): 

1739 self.preprocessing.insert( 

1740 0, EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype)) 

1741 ) 

1742 

1743 # ensure `preprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr` 

1744 if not isinstance(self.preprocessing[-1], (EnsureDtypeDescr, BinarizeDescr)): 

1745 self.preprocessing.append( 

1746 EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype)) 

1747 ) 

1748 

1749 return self 

1750 
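# Illustrative sketch (editorial addition, not part of the original module):
# constructing an input tensor description; note that an `ensure_dtype` step is
# inserted automatically so preprocessing starts and ends with a defined dtype.
def _example_input_tensor_descr() -> None:  # hypothetical helper, never called
    descr = InputTensorDescr(
        id=TensorId("raw"),
        axes=[
            BatchAxis(),
            ChannelAxis(channel_names=[Identifier("intensity")]),
            SpaceInputAxis(id=AxisId("y"), size=ParameterizedSize(min=32, step=16)),
            SpaceInputAxis(id=AxisId("x"), size=ParameterizedSize(min=32, step=16)),
        ],
    )
    assert isinstance(descr.preprocessing[0], EnsureDtypeDescr)
    assert descr.dtype == "float32"  # from the default IntervalOrRatioDataDescr
    sizes = descr.get_axis_sizes_for_array(np.zeros((1, 1, 48, 48), dtype=np.float32))
    assert sizes == {
        AxisId("batch"): 1,
        AxisId("channel"): 1,
        AxisId("y"): 48,
        AxisId("x"): 48,
    }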

1751 

1752def convert_axes( 

1753 axes: str, 

1754 *, 

1755 shape: Union[ 

1756 Sequence[int], _ParameterizedInputShape_v0_4, _ImplicitOutputShape_v0_4 

1757 ], 

1758 tensor_type: Literal["input", "output"], 

1759 halo: Optional[Sequence[int]], 

1760 size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]], 

1761): 

1762 ret: List[AnyAxis] = [] 

1763 for i, a in enumerate(axes): 

1764 axis_type = _AXIS_TYPE_MAP.get(a, a) 

1765 if axis_type == "batch": 

1766 ret.append(BatchAxis()) 

1767 continue 

1768 

1769 scale = 1.0 

1770 if isinstance(shape, _ParameterizedInputShape_v0_4): 

1771 if shape.step[i] == 0: 

1772 size = shape.min[i] 

1773 else: 

1774 size = ParameterizedSize(min=shape.min[i], step=shape.step[i]) 

1775 elif isinstance(shape, _ImplicitOutputShape_v0_4): 

1776 ref_t = str(shape.reference_tensor) 

1777 if ref_t.count(".") == 1: 

1778 t_id, orig_a_id = ref_t.split(".") 

1779 else: 

1780 t_id = ref_t 

1781 orig_a_id = a 

1782 

1783 a_id = _AXIS_ID_MAP.get(orig_a_id, a) 

1784 if not (orig_scale := shape.scale[i]): 

1785 # old way to insert a new axis dimension 

1786 size = int(2 * shape.offset[i]) 

1787 else: 

1788 scale = 1 / orig_scale 

1789 if axis_type in ("channel", "index"): 

1790 # these axes no longer have a scale 

1791 offset_from_scale = orig_scale * size_refs.get( 

1792 _TensorName_v0_4(t_id), {} 

1793 ).get(orig_a_id, 0) 

1794 else: 

1795 offset_from_scale = 0 

1796 size = SizeReference( 

1797 tensor_id=TensorId(t_id), 

1798 axis_id=AxisId(a_id), 

1799 offset=int(offset_from_scale + 2 * shape.offset[i]), 

1800 ) 

1801 else: 

1802 size = shape[i] 

1803 

1804 if axis_type == "time": 

1805 if tensor_type == "input": 

1806 ret.append(TimeInputAxis(size=size, scale=scale)) 

1807 else: 

1808 assert not isinstance(size, ParameterizedSize) 

1809 if halo is None: 

1810 ret.append(TimeOutputAxis(size=size, scale=scale)) 

1811 else: 

1812 assert not isinstance(size, int) 

1813 ret.append( 

1814 TimeOutputAxisWithHalo(size=size, scale=scale, halo=halo[i]) 

1815 ) 

1816 

1817 elif axis_type == "index": 

1818 if tensor_type == "input": 

1819 ret.append(IndexInputAxis(size=size)) 

1820 else: 

1821 if isinstance(size, ParameterizedSize): 

1822 size = DataDependentSize(min=size.min) 

1823 

1824 ret.append(IndexOutputAxis(size=size)) 

1825 elif axis_type == "channel": 

1826 assert not isinstance(size, ParameterizedSize) 

1827 if isinstance(size, SizeReference): 

1828 warnings.warn( 

1829 "Conversion of channel size from an implicit output shape may be" 

1830 + " wrong" 

1831 ) 

1832 ret.append( 

1833 ChannelAxis( 

1834 channel_names=[ 

1835 Identifier(f"channel{i}") for i in range(size.offset) 

1836 ] 

1837 ) 

1838 ) 

1839 else: 

1840 ret.append( 

1841 ChannelAxis( 

1842 channel_names=[Identifier(f"channel{i}") for i in range(size)] 

1843 ) 

1844 ) 

1845 elif axis_type == "space": 

1846 if tensor_type == "input": 

1847 ret.append(SpaceInputAxis(id=AxisId(a), size=size, scale=scale)) 

1848 else: 

1849 assert not isinstance(size, ParameterizedSize) 

1850 if halo is None or halo[i] == 0: 

1851 ret.append(SpaceOutputAxis(id=AxisId(a), size=size, scale=scale)) 

1852 elif isinstance(size, int): 

1853 raise NotImplementedError( 

1854 f"output axis with halo and fixed size (here {size}) not allowed" 

1855 ) 

1856 else: 

1857 ret.append( 

1858 SpaceOutputAxisWithHalo( 

1859 id=AxisId(a), size=size, scale=scale, halo=halo[i] 

1860 ) 

1861 ) 

1862 

1863 return ret 

1864 

1865 
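# Editorial example (sketch, not part of the original module): how `convert_axes`
# maps a v0.4 letter-style axes string with a fixed shape to v0.5 axis
# descriptions, assuming the usual letter-to-type mapping
# ("b" -> batch, "c" -> channel, "y"/"x" -> space):
#
#     convert_axes(
#         "bcyx",
#         shape=[1, 3, 256, 256],
#         tensor_type="input",
#         halo=None,
#         size_refs={},
#     )
#     # -> [BatchAxis(),
#     #     ChannelAxis(channel_names=[Identifier("channel0"), Identifier("channel1"), Identifier("channel2")]),
#     #     SpaceInputAxis(id=AxisId("y"), size=256),
#     #     SpaceInputAxis(id=AxisId("x"), size=256)]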

1866def _axes_letters_to_ids( 

1867 axes: Optional[str], 

1868) -> Optional[List[AxisId]]: 

1869 if axes is None: 

1870 return None 

1871 

1872 return [AxisId(a) for a in axes] 

1873 

1874 
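# Editorial example (sketch): v0.4 letter axes map one-to-one to v0.5 axis ids.
#
#     _axes_letters_to_ids("zyx")  # -> [AxisId("z"), AxisId("y"), AxisId("x")]
#     _axes_letters_to_ids(None)   # -> None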

1875def _get_complement_v04_axis( 

1876 tensor_axes: Sequence[str], axes: Optional[Sequence[str]] 

1877) -> Optional[AxisId]: 

1878 if axes is None: 

1879 return None 

1880 

1881 non_complement_axes = set(axes) | {"b"} 

1882 complement_axes = [a for a in tensor_axes if a not in non_complement_axes] 

1883 if len(complement_axes) > 1: 

1884 raise ValueError( 

1885 f"Expected none or a single complement axis, but axes '{axes}' " 

1886 + f"for tensor dims '{tensor_axes}' leave '{complement_axes}'." 

1887 ) 

1888 

1889 return None if not complement_axes else AxisId(complement_axes[0]) 

1890 

1891 
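# Editorial example (sketch): the "complement" axis is the single tensor axis
# not listed in a processing step's `axes` kwarg (the batch axis "b" is ignored).
#
#     _get_complement_v04_axis(("b", "c", "y", "x"), ("y", "x"))  # -> AxisId("c")
#     _get_complement_v04_axis(("b", "y", "x"), ("y", "x"))       # -> None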

1892def _convert_proc( 

1893 p: Union[_PreprocessingDescr_v0_4, _PostprocessingDescr_v0_4], 

1894 tensor_axes: Sequence[str], 

1895) -> Union[PreprocessingDescr, PostprocessingDescr]: 

1896 if isinstance(p, _BinarizeDescr_v0_4): 

1897 return BinarizeDescr(kwargs=BinarizeKwargs(threshold=p.kwargs.threshold)) 

1898 elif isinstance(p, _ClipDescr_v0_4): 

1899 return ClipDescr(kwargs=ClipKwargs(min=p.kwargs.min, max=p.kwargs.max)) 

1900 elif isinstance(p, _SigmoidDescr_v0_4): 

1901 return SigmoidDescr() 

1902 elif isinstance(p, _ScaleLinearDescr_v0_4): 

1903 axes = _axes_letters_to_ids(p.kwargs.axes) 

1904 if p.kwargs.axes is None: 

1905 axis = None 

1906 else: 

1907 axis = _get_complement_v04_axis(tensor_axes, p.kwargs.axes) 

1908 

1909 if axis is None: 

1910 assert not isinstance(p.kwargs.gain, list) 

1911 assert not isinstance(p.kwargs.offset, list) 

1912 kwargs = ScaleLinearKwargs(gain=p.kwargs.gain, offset=p.kwargs.offset) 

1913 else: 

1914 kwargs = ScaleLinearAlongAxisKwargs( 

1915 axis=axis, gain=p.kwargs.gain, offset=p.kwargs.offset 

1916 ) 

1917 return ScaleLinearDescr(kwargs=kwargs) 

1918 elif isinstance(p, _ScaleMeanVarianceDescr_v0_4): 

1919 return ScaleMeanVarianceDescr( 

1920 kwargs=ScaleMeanVarianceKwargs( 

1921 axes=_axes_letters_to_ids(p.kwargs.axes), 

1922 reference_tensor=TensorId(str(p.kwargs.reference_tensor)), 

1923 eps=p.kwargs.eps, 

1924 ) 

1925 ) 

1926 elif isinstance(p, _ZeroMeanUnitVarianceDescr_v0_4): 

1927 if p.kwargs.mode == "fixed": 

1928 mean = p.kwargs.mean 

1929 std = p.kwargs.std 

1930 assert mean is not None 

1931 assert std is not None 

1932 

1933 axis = _get_complement_v04_axis(tensor_axes, p.kwargs.axes) 

1934 

1935 if axis is None: 

1936 if isinstance(mean, list): 

1937 raise ValueError("Expected single float value for mean, not <list>") 

1938 if isinstance(std, list): 

1939 raise ValueError("Expected single float value for std, not <list>") 

1940 return FixedZeroMeanUnitVarianceDescr( 

1941 kwargs=FixedZeroMeanUnitVarianceKwargs.model_construct( 

1942 mean=mean, 

1943 std=std, 

1944 ) 

1945 ) 

1946 else: 

1947 if not isinstance(mean, list): 

1948 mean = [float(mean)] 

1949 if not isinstance(std, list): 

1950 std = [float(std)] 

1951 

1952 return FixedZeroMeanUnitVarianceDescr( 

1953 kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs( 

1954 axis=axis, mean=mean, std=std 

1955 ) 

1956 ) 

1957 

1958 else: 

1959 axes = _axes_letters_to_ids(p.kwargs.axes) or [] 

1960 if p.kwargs.mode == "per_dataset": 

1961 axes = [AxisId("batch")] + axes 

1962 if not axes: 

1963 axes = None 

1964 return ZeroMeanUnitVarianceDescr( 

1965 kwargs=ZeroMeanUnitVarianceKwargs(axes=axes, eps=p.kwargs.eps) 

1966 ) 

1967 

1968 elif isinstance(p, _ScaleRangeDescr_v0_4): 

1969 return ScaleRangeDescr( 

1970 kwargs=ScaleRangeKwargs( 

1971 axes=_axes_letters_to_ids(p.kwargs.axes), 

1972 min_percentile=p.kwargs.min_percentile, 

1973 max_percentile=p.kwargs.max_percentile, 

1974 eps=p.kwargs.eps, 

1975 ) 

1976 ) 

1977 else: 

1978 assert_never(p) 

1979 

1980 

1981class _InputTensorConv( 

1982 Converter[ 

1983 _InputTensorDescr_v0_4, 

1984 InputTensorDescr, 

1985 FileSource_, 

1986 Optional[FileSource_], 

1987 Mapping[_TensorName_v0_4, Mapping[str, int]], 

1988 ] 

1989): 

1990 def _convert( 

1991 self, 

1992 src: _InputTensorDescr_v0_4, 

1993 tgt: "type[InputTensorDescr] | type[dict[str, Any]]", 

1994 test_tensor: FileSource_, 

1995 sample_tensor: Optional[FileSource_], 

1996 size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]], 

1997 ) -> "InputTensorDescr | dict[str, Any]": 

1998 axes: List[InputAxis] = convert_axes( # pyright: ignore[reportAssignmentType] 

1999 src.axes, 

2000 shape=src.shape, 

2001 tensor_type="input", 

2002 halo=None, 

2003 size_refs=size_refs, 

2004 ) 

2005 prep: List[PreprocessingDescr] = [] 

2006 for p in src.preprocessing: 

2007 cp = _convert_proc(p, src.axes) 

2008 assert not isinstance(cp, ScaleMeanVarianceDescr) 

2009 prep.append(cp) 

2010 

2011 prep.append(EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="float32"))) 

2012 

2013 return tgt( 

2014 axes=axes, 

2015 id=TensorId(str(src.name)), 

2016 test_tensor=FileDescr(source=test_tensor), 

2017 sample_tensor=( 

2018 None if sample_tensor is None else FileDescr(source=sample_tensor) 

2019 ), 

2020 data=dict(type=src.data_type), # pyright: ignore[reportArgumentType] 

2021 preprocessing=prep, 

2022 ) 

2023 

2024 

2025_input_tensor_conv = _InputTensorConv(_InputTensorDescr_v0_4, InputTensorDescr) 

2026 

2027 

2028class OutputTensorDescr(TensorDescrBase[OutputAxis]): 

2029 id: TensorId = TensorId("output") 

2030 """Output tensor id. 

2031 No duplicates are allowed across all inputs and outputs.""" 

2032 

2033 postprocessing: List[PostprocessingDescr] = Field( 

2034 default_factory=cast(Callable[[], List[PostprocessingDescr]], list) 

2035 ) 

2036 """Description of how this output should be postprocessed. 

2037 

2038 note: `postprocessing` always ends with an 'ensure_dtype' operation.

2039 If not given, one is added to cast to this tensor's `data.type`.

2040 """ 

2041 

2042 @model_validator(mode="after") 

2043 def _validate_postprocessing_kwargs(self) -> Self: 

2044 axes_ids = [a.id for a in self.axes] 

2045 for p in self.postprocessing: 

2046 kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes") 

2047 if kwargs_axes is None: 

2048 continue 

2049 

2050 if not isinstance(kwargs_axes, collections.abc.Sequence): 

2051 raise ValueError( 

2052 f"Expected `axes` to be a sequence, but got {type(kwargs_axes)}"

2053 ) 

2054 

2055 if any(a not in axes_ids for a in kwargs_axes): 

2056 raise ValueError("`kwargs.axes` needs to be a subset of the axes ids")

2057 

2058 if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)): 

2059 dtype = self.data.type 

2060 else: 

2061 dtype = self.data[0].type 

2062 

2063 # ensure `postprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr` 

2064 if not self.postprocessing or not isinstance( 

2065 self.postprocessing[-1], (EnsureDtypeDescr, BinarizeDescr) 

2066 ): 

2067 self.postprocessing.append( 

2068 EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype)) 

2069 ) 

2070 return self 

2071 

2072 

2073class _OutputTensorConv( 

2074 Converter[ 

2075 _OutputTensorDescr_v0_4, 

2076 OutputTensorDescr, 

2077 FileSource_, 

2078 Optional[FileSource_], 

2079 Mapping[_TensorName_v0_4, Mapping[str, int]], 

2080 ] 

2081): 

2082 def _convert( 

2083 self, 

2084 src: _OutputTensorDescr_v0_4, 

2085 tgt: "type[OutputTensorDescr] | type[dict[str, Any]]", 

2086 test_tensor: FileSource_, 

2087 sample_tensor: Optional[FileSource_], 

2088 size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]], 

2089 ) -> "OutputTensorDescr | dict[str, Any]": 

2090 # TODO: split convert_axes into convert_output_axes and convert_input_axes 

2091 axes: List[OutputAxis] = convert_axes( # pyright: ignore[reportAssignmentType] 

2092 src.axes, 

2093 shape=src.shape, 

2094 tensor_type="output", 

2095 halo=src.halo, 

2096 size_refs=size_refs, 

2097 ) 

2098 data_descr: Dict[str, Any] = dict(type=src.data_type) 

2099 if data_descr["type"] == "bool": 

2100 data_descr["values"] = [False, True] 

2101 

2102 return tgt( 

2103 axes=axes, 

2104 id=TensorId(str(src.name)), 

2105 test_tensor=FileDescr(source=test_tensor), 

2106 sample_tensor=( 

2107 None if sample_tensor is None else FileDescr(source=sample_tensor) 

2108 ), 

2109 data=data_descr, # pyright: ignore[reportArgumentType] 

2110 postprocessing=[_convert_proc(p, src.axes) for p in src.postprocessing], 

2111 ) 

2112 

2113 

2114_output_tensor_conv = _OutputTensorConv(_OutputTensorDescr_v0_4, OutputTensorDescr) 

2115 

2116 

2117TensorDescr = Union[InputTensorDescr, OutputTensorDescr] 

2118 

2119 

2120def validate_tensors( 

2121 tensors: Mapping[TensorId, Tuple[TensorDescr, Optional[NDArray[Any]]]], 

2122 tensor_origin: Literal[ 

2123 "test_tensor" 

2124 ], # for more precise error messages, e.g. 'test_tensor' 

2125): 

2126 all_tensor_axes: Dict[TensorId, Dict[AxisId, Tuple[AnyAxis, Optional[int]]]] = {} 

2127 

2128 def e_msg(d: TensorDescr): 

2129 return f"{'inputs' if isinstance(d, InputTensorDescr) else 'outputs'}[{d.id}]" 

2130 

2131 for descr, array in tensors.values(): 

2132 if array is None: 

2133 axis_sizes = {a.id: None for a in descr.axes} 

2134 else: 

2135 try: 

2136 axis_sizes = descr.get_axis_sizes_for_array(array) 

2137 except ValueError as e: 

2138 raise ValueError(f"{e_msg(descr)} {e}") 

2139 

2140 all_tensor_axes[descr.id] = {a.id: (a, axis_sizes[a.id]) for a in descr.axes} 

2141 

2142 for descr, array in tensors.values(): 

2143 if array is None: 

2144 continue 

2145 

2146 if descr.dtype in ("float32", "float64"): 

2147 invalid_test_tensor_dtype = array.dtype.name not in ( 

2148 "float32", 

2149 "float64", 

2150 "uint8", 

2151 "int8", 

2152 "uint16", 

2153 "int16", 

2154 "uint32", 

2155 "int32", 

2156 "uint64", 

2157 "int64", 

2158 ) 

2159 else: 

2160 invalid_test_tensor_dtype = array.dtype.name != descr.dtype 

2161 

2162 if invalid_test_tensor_dtype: 

2163 raise ValueError( 

2164 f"{e_msg(descr)}.{tensor_origin}.dtype '{array.dtype.name}' does not" 

2165 + f" match described dtype '{descr.dtype}'" 

2166 ) 

2167 

2168 if array.min() > -1e-4 and array.max() < 1e-4: 

2169 raise ValueError( 

2170 "Tensor values are too small for reliable testing."

2171 + f" Values <= -1e-4 or >= 1e-4 must be present in {tensor_origin}"

2172 ) 

2173 

2174 for a in descr.axes: 

2175 actual_size = all_tensor_axes[descr.id][a.id][1] 

2176 if actual_size is None: 

2177 continue 

2178 

2179 if a.size is None: 

2180 continue 

2181 

2182 if isinstance(a.size, int): 

2183 if actual_size != a.size: 

2184 raise ValueError( 

2185 f"{e_msg(descr)}.{tensor_origin}: axis '{a.id}' " 

2186 + f"has incompatible size {actual_size}, expected {a.size}" 

2187 ) 

2188 elif isinstance(a.size, ParameterizedSize): 

2189 _ = a.size.validate_size(actual_size) 

2190 elif isinstance(a.size, DataDependentSize): 

2191 _ = a.size.validate_size(actual_size) 

2192 elif isinstance(a.size, SizeReference): 

2193 ref_tensor_axes = all_tensor_axes.get(a.size.tensor_id) 

2194 if ref_tensor_axes is None: 

2195 raise ValueError( 

2196 f"{e_msg(descr)}.axes[{a.id}].size.tensor_id: Unknown tensor" 

2197 + f" reference '{a.size.tensor_id}'" 

2198 ) 

2199 

2200 ref_axis, ref_size = ref_tensor_axes.get(a.size.axis_id, (None, None)) 

2201 if ref_axis is None or ref_size is None: 

2202 raise ValueError( 

2203 f"{e_msg(descr)}.axes[{a.id}].size.axis_id: Unknown tensor axis" 

2204 + f" reference '{a.size.tensor_id}.{a.size.axis_id}'"

2205 ) 

2206 

2207 if a.unit != ref_axis.unit: 

2208 raise ValueError( 

2209 f"{e_msg(descr)}.axes[{a.id}].size: `SizeReference` requires" 

2210 + " axis and reference axis to have the same `unit`, but" 

2211 + f" {a.unit}!={ref_axis.unit}" 

2212 ) 

2213 

2214 if actual_size != ( 

2215 expected_size := ( 

2216 ref_size * ref_axis.scale / a.scale + a.size.offset 

2217 ) 

2218 ): 

2219 raise ValueError( 

2220 f"{e_msg(descr)}.{tensor_origin}: axis '{a.id}' of size" 

2221 + f" {actual_size} invalid for referenced size {ref_size};" 

2222 + f" expected {expected_size}" 

2223 ) 

2224 else: 

2225 assert_never(a.size) 

2226 
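# Editorial note (sketch): the `SizeReference` check above compares the actual
# axis size against `ref_size * ref_axis.scale / a.scale + a.size.offset`.
# For example, a reference axis of size 128 with scale 1.0 and an axis with
# scale 0.5 and offset 0 yields an expected size of 128 * 1.0 / 0.5 + 0 = 256
# (halving the physical scale doubles the pixel count).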

2227 

2228FileDescr_dependencies = Annotated[ 

2229 FileDescr_, 

2230 WithSuffix((".yaml", ".yml"), case_sensitive=True), 

2231 Field(examples=[dict(source="environment.yaml")]), 

2232] 

2233 

2234 

2235class _ArchitectureCallableDescr(Node): 

2236 callable: Annotated[Identifier, Field(examples=["MyNetworkClass", "get_my_model"])] 

2237 """Identifier of the callable that returns a torch.nn.Module instance.""" 

2238 

2239 kwargs: Dict[str, YamlValue] = Field( 

2240 default_factory=cast(Callable[[], Dict[str, YamlValue]], dict) 

2241 ) 

2242 """Keyword arguments for the `callable`."""

2243 

2244 

2245class ArchitectureFromFileDescr(_ArchitectureCallableDescr, FileDescr): 

2246 source: Annotated[FileSource, AfterValidator(wo_special_file_name)] 

2247 """Architecture source file""" 

2248 

2249 @model_serializer(mode="wrap", when_used="unless-none") 

2250 def _serialize(self, nxt: SerializerFunctionWrapHandler, info: SerializationInfo): 

2251 return package_file_descr_serializer(self, nxt, info) 

2252 

2253 

2254class ArchitectureFromLibraryDescr(_ArchitectureCallableDescr): 

2255 import_from: str 

2256 """Where to import the callable from, i.e. `from <import_from> import <callable>`""" 

2257 

2258 

2259class _ArchFileConv( 

2260 Converter[ 

2261 _CallableFromFile_v0_4, 

2262 ArchitectureFromFileDescr, 

2263 Optional[Sha256], 

2264 Dict[str, Any], 

2265 ] 

2266): 

2267 def _convert( 

2268 self, 

2269 src: _CallableFromFile_v0_4, 

2270 tgt: "type[ArchitectureFromFileDescr | dict[str, Any]]", 

2271 sha256: Optional[Sha256], 

2272 kwargs: Dict[str, Any], 

2273 ) -> "ArchitectureFromFileDescr | dict[str, Any]": 

2274 if src.startswith("http") and src.count(":") == 2: 

2275 http, source, callable_ = src.split(":") 

2276 source = ":".join((http, source)) 

2277 elif not src.startswith("http") and src.count(":") == 1: 

2278 source, callable_ = src.split(":") 

2279 else: 

2280 source = str(src) 

2281 callable_ = str(src) 

2282 return tgt( 

2283 callable=Identifier(callable_), 

2284 source=cast(FileSource_, source), 

2285 sha256=sha256, 

2286 kwargs=kwargs, 

2287 ) 

2288 

2289 

2290_arch_file_conv = _ArchFileConv(_CallableFromFile_v0_4, ArchitectureFromFileDescr) 

2291 
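# Editorial example (sketch): how the converter above splits v0.4
# callable-from-file strings into `source` and `callable` (URLs keep their scheme):
#
#     "https://example.com/unet.py:UNet" -> source="https://example.com/unet.py", callable="UNet"
#     "models/unet.py:UNet"              -> source="models/unet.py",              callable="UNet"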

2292 

2293class _ArchLibConv( 

2294 Converter[ 

2295 _CallableFromDepencency_v0_4, ArchitectureFromLibraryDescr, Dict[str, Any] 

2296 ] 

2297): 

2298 def _convert( 

2299 self, 

2300 src: _CallableFromDepencency_v0_4, 

2301 tgt: "type[ArchitectureFromLibraryDescr | dict[str, Any]]", 

2302 kwargs: Dict[str, Any], 

2303 ) -> "ArchitectureFromLibraryDescr | dict[str, Any]": 

2304 *mods, callable_ = src.split(".") 

2305 import_from = ".".join(mods) 

2306 return tgt( 

2307 import_from=import_from, callable=Identifier(callable_), kwargs=kwargs 

2308 ) 

2309 

2310 

2311_arch_lib_conv = _ArchLibConv( 

2312 _CallableFromDepencency_v0_4, ArchitectureFromLibraryDescr 

2313) 

2314 

2315 

2316class WeightsEntryDescrBase(FileDescr): 

2317 type: ClassVar[WeightsFormat] 

2318 weights_format_name: ClassVar[str] # human readable 

2319 

2320 source: Annotated[FileSource, AfterValidator(wo_special_file_name)] 

2321 """Source of the weights file.""" 

2322 

2323 authors: Optional[List[Author]] = None 

2324 """Authors:

2325 Either the person(s) that trained this model, resulting in the original weights file

2326 (if this is the initial weights entry, i.e. it does not have a `parent`),

2327 or the person(s) who converted the weights to this weights format

2328 (if this is a child weights entry, i.e. it has a `parent` field).

2329 """

2330 

2331 parent: Annotated[ 

2332 Optional[WeightsFormat], Field(examples=["pytorch_state_dict"]) 

2333 ] = None 

2334 """The source weights these weights were converted from.

2335 For example, if a model's weights were converted from the `pytorch_state_dict` format to `torchscript`,

2336 the `pytorch_state_dict` weights entry has no `parent` and is the parent of the `torchscript` weights.

2337 All weights entries except one (the initial set of weights resulting from training the model)

2338 need to have this field."""

2339 

2340 comment: str = "" 

2341 """A comment about this weights entry, for example how these weights were created.""" 

2342 

2343 @model_validator(mode="after") 

2344 def _validate(self) -> Self: 

2345 if self.type == self.parent: 

2346 raise ValueError("Weights entry can't be its own parent.")

2347 

2348 return self 

2349 

2350 @model_serializer(mode="wrap", when_used="unless-none") 

2351 def _serialize(self, nxt: SerializerFunctionWrapHandler, info: SerializationInfo): 

2352 return package_file_descr_serializer(self, nxt, info) 

2353 

2354 

2355class KerasHdf5WeightsDescr(WeightsEntryDescrBase): 

2356 type = "keras_hdf5" 

2357 weights_format_name: ClassVar[str] = "Keras HDF5" 

2358 tensorflow_version: Version 

2359 """TensorFlow version used to create these weights.""" 

2360 

2361 

2362class OnnxWeightsDescr(WeightsEntryDescrBase): 

2363 type = "onnx" 

2364 weights_format_name: ClassVar[str] = "ONNX" 

2365 opset_version: Annotated[int, Ge(7)] 

2366 """ONNX opset version""" 

2367 

2368 

2369class PytorchStateDictWeightsDescr(WeightsEntryDescrBase): 

2370 type = "pytorch_state_dict" 

2371 weights_format_name: ClassVar[str] = "Pytorch State Dict" 

2372 architecture: Union[ArchitectureFromFileDescr, ArchitectureFromLibraryDescr] 

2373 pytorch_version: Version 

2374 """Version of the PyTorch library used. 

2375 If `architecture.dependencies` is specified, it has to include pytorch and any version pinning has to be compatible.

2376 """ 

2377 dependencies: Optional[FileDescr_dependencies] = None 

2378 """Custom dependencies beyond pytorch, described in a conda environment file.

2379 Allows specifying custom dependencies; see the conda docs:

2380 - [Exporting an environment file across platforms](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#exporting-an-environment-file-across-platforms) 

2381 - [Creating an environment file manually](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#creating-an-environment-file-manually) 

2382 

2383 The conda environment file should include pytorch and any version pinning has to be compatible with 

2384 **pytorch_version**. 

2385 """ 

2386 

2387 

2388class TensorflowJsWeightsDescr(WeightsEntryDescrBase): 

2389 type = "tensorflow_js" 

2390 weights_format_name: ClassVar[str] = "Tensorflow.js" 

2391 tensorflow_version: Version 

2392 """Version of the TensorFlow library used.""" 

2393 

2394 source: Annotated[FileSource, AfterValidator(wo_special_file_name)] 

2395 """The multi-file weights. 

2396 All required files/folders should be bundled in a zip archive."""

2397 

2398 

2399class TensorflowSavedModelBundleWeightsDescr(WeightsEntryDescrBase): 

2400 type = "tensorflow_saved_model_bundle" 

2401 weights_format_name: ClassVar[str] = "Tensorflow Saved Model" 

2402 tensorflow_version: Version 

2403 """Version of the TensorFlow library used.""" 

2404 

2405 dependencies: Optional[FileDescr_dependencies] = None 

2406 """Custom dependencies beyond tensorflow. 

2407 Should include tensorflow and any version pinning has to be compatible with **tensorflow_version**.""" 

2408 

2409 source: Annotated[FileSource, AfterValidator(wo_special_file_name)] 

2410 """The multi-file weights. 

2411 All required files/folders should be bundled in a zip archive."""

2412 

2413 

2414class TorchscriptWeightsDescr(WeightsEntryDescrBase): 

2415 type = "torchscript" 

2416 weights_format_name: ClassVar[str] = "TorchScript" 

2417 pytorch_version: Version 

2418 """Version of the PyTorch library used.""" 

2419 

2420 

2421class WeightsDescr(Node): 

2422 keras_hdf5: Optional[KerasHdf5WeightsDescr] = None 

2423 onnx: Optional[OnnxWeightsDescr] = None 

2424 pytorch_state_dict: Optional[PytorchStateDictWeightsDescr] = None 

2425 tensorflow_js: Optional[TensorflowJsWeightsDescr] = None 

2426 tensorflow_saved_model_bundle: Optional[TensorflowSavedModelBundleWeightsDescr] = ( 

2427 None 

2428 ) 

2429 torchscript: Optional[TorchscriptWeightsDescr] = None 

2430 

2431 @model_validator(mode="after") 

2432 def check_entries(self) -> Self: 

2433 entries = {wtype for wtype, entry in self if entry is not None} 

2434 

2435 if not entries: 

2436 raise ValueError("Missing weights entry") 

2437 

2438 entries_wo_parent = { 

2439 wtype 

2440 for wtype, entry in self 

2441 if entry is not None and hasattr(entry, "parent") and entry.parent is None 

2442 } 

2443 if len(entries_wo_parent) != 1: 

2444 issue_warning( 

2445 "Exactly one weights entry is expected to omit the `parent` field (got"

2446 + " {value}). That entry is considered the original set of model weights."

2447 + " Other weights formats are created through conversion of the original or"

2448 + " already converted weights. They have to reference the weights format"

2449 + " they were converted from as their `parent`.",

2450 value=len(entries_wo_parent), 

2451 field="weights", 

2452 ) 

2453 

2454 for wtype, entry in self: 

2455 if entry is None: 

2456 continue 

2457 

2458 assert hasattr(entry, "type") 

2459 assert hasattr(entry, "parent") 

2460 assert wtype == entry.type 

2461 if ( 

2462 entry.parent is not None and entry.parent not in entries 

2463 ): # self reference checked for `parent` field 

2464 raise ValueError( 

2465 f"`weights.{wtype}.parent={entry.parent}` not in specified weight"

2466 + f" formats: {entries}" 

2467 ) 

2468 

2469 return self 

2470 

2471 def __getitem__( 

2472 self, 

2473 key: Literal[ 

2474 "keras_hdf5", 

2475 "onnx", 

2476 "pytorch_state_dict", 

2477 "tensorflow_js", 

2478 "tensorflow_saved_model_bundle", 

2479 "torchscript", 

2480 ], 

2481 ): 

2482 if key == "keras_hdf5": 

2483 ret = self.keras_hdf5 

2484 elif key == "onnx": 

2485 ret = self.onnx 

2486 elif key == "pytorch_state_dict": 

2487 ret = self.pytorch_state_dict 

2488 elif key == "tensorflow_js": 

2489 ret = self.tensorflow_js 

2490 elif key == "tensorflow_saved_model_bundle": 

2491 ret = self.tensorflow_saved_model_bundle 

2492 elif key == "torchscript": 

2493 ret = self.torchscript 

2494 else: 

2495 raise KeyError(key) 

2496 

2497 if ret is None: 

2498 raise KeyError(key) 

2499 

2500 return ret 

2501 

2502 @property 

2503 def available_formats(self): 

2504 return { 

2505 **({} if self.keras_hdf5 is None else {"keras_hdf5": self.keras_hdf5}), 

2506 **({} if self.onnx is None else {"onnx": self.onnx}), 

2507 **( 

2508 {} 

2509 if self.pytorch_state_dict is None 

2510 else {"pytorch_state_dict": self.pytorch_state_dict} 

2511 ), 

2512 **( 

2513 {} 

2514 if self.tensorflow_js is None 

2515 else {"tensorflow_js": self.tensorflow_js} 

2516 ), 

2517 **( 

2518 {} 

2519 if self.tensorflow_saved_model_bundle is None 

2520 else { 

2521 "tensorflow_saved_model_bundle": self.tensorflow_saved_model_bundle 

2522 } 

2523 ), 

2524 **({} if self.torchscript is None else {"torchscript": self.torchscript}), 

2525 } 

2526 

2527 @property 

2528 def missing_formats(self): 

2529 return { 

2530 wf for wf in get_args(WeightsFormat) if wf not in self.available_formats 

2531 } 

2532 

2533 
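# Editorial example (sketch): typical access patterns for `WeightsDescr`,
# assuming a description that only contains a torchscript entry:
#
#     weights["torchscript"]             # -> TorchscriptWeightsDescr(...)
#     weights["onnx"]                    # raises KeyError (entry is None)
#     weights.available_formats          # -> {"torchscript": ...}
#     "onnx" in weights.missing_formats  # -> True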

2534class ModelId(ResourceId): 

2535 pass 

2536 

2537 

2538class LinkedModel(LinkedResourceBase): 

2539 """Reference to a bioimage.io model.""" 

2540 

2541 id: ModelId 

2542 """A valid model `id` from the bioimage.io collection.""" 

2543 

2544 

2545class _DataDepSize(NamedTuple): 

2546 min: StrictInt 

2547 max: Optional[StrictInt] 

2548 

2549 

2550class _AxisSizes(NamedTuple): 

2551 """The lengths of all axes of model inputs and outputs."""

2552 

2553 inputs: Dict[Tuple[TensorId, AxisId], int] 

2554 outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]] 

2555 

2556 

2557class _TensorSizes(NamedTuple): 

2558 """_AxisSizes as nested dicts""" 

2559 

2560 inputs: Dict[TensorId, Dict[AxisId, int]] 

2561 outputs: Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]] 

2562 

2563 

2564class ReproducibilityTolerance(Node, extra="allow"): 

2565 """Describes what small numerical differences -- if any -- may be tolerated 

2566 in the generated output when executing in different environments. 

2567 

2568 A tensor element *output* is considered mismatched to the **test_tensor** if 

2569 abs(*output* - **test_tensor**) > **absolute_tolerance** + **relative_tolerance** * abs(**test_tensor**). 

2570 (Internally we call [numpy.testing.assert_allclose](https://numpy.org/doc/stable/reference/generated/numpy.testing.assert_allclose.html).) 

2571 

2572 Motivation: 

2573 For testing we can request the respective deep learning frameworks to be as 

2574 reproducible as possible by setting seeds and choosing deterministic algorithms,

2575 but differences in operating systems, available hardware and installed drivers 

2576 may still lead to numerical differences. 

2577 """ 

2578 

2579 relative_tolerance: RelativeTolerance = 1e-3 

2580 """Maximum relative tolerance of reproduced test tensor.""" 

2581 

2582 absolute_tolerance: AbsoluteTolerance = 1e-4 

2583 """Maximum absolute tolerance of reproduced test tensor.""" 

2584 

2585 mismatched_elements_per_million: MismatchedElementsPerMillion = 100 

2586 """Maximum number of mismatched elements/pixels per million to tolerate.""" 

2587 

2588 output_ids: Sequence[TensorId] = () 

2589 """Limits the output tensor IDs these reproducibility details apply to.""" 

2590 

2591 weights_formats: Sequence[WeightsFormat] = () 

2592 """Limits the weights formats these details apply to.""" 

2593 

2594 
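# Editorial illustration (not part of the spec API): the per-element mismatch
# criterion from the docstring above, spelled out for scalars. The `atol`/`rtol`
# defaults mirror the field defaults of `ReproducibilityTolerance`.
def _example_is_mismatched(
    output: float, expected: float, atol: float = 1e-4, rtol: float = 1e-3
) -> bool:
    # an element counts as mismatched if it deviates by more than
    # absolute_tolerance + relative_tolerance * abs(expected)
    return abs(output - expected) > atol + rtol * abs(expected)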

2595class BioimageioConfig(Node, extra="allow"): 

2596 reproducibility_tolerance: Sequence[ReproducibilityTolerance] = () 

2597 """Tolerances to allow when reproducing the model's test outputs 

2598 from the model's test inputs. 

2599 Only the first entry matching tensor id and weights format is considered. 

2600 """ 

2601 

2602 

2603class Config(Node, extra="allow"): 

2604 bioimageio: BioimageioConfig = Field( 

2605 default_factory=BioimageioConfig.model_construct 

2606 ) 

2607 

2608 

2609class ModelDescr(GenericModelDescrBase): 

2610 """Specification of the fields used in a bioimage.io-compliant RDF to describe AI models with pretrained weights. 

2611 These fields are typically stored in a YAML file which we call a model resource description file (model RDF). 

2612 """ 

2613 

2614 implemented_format_version: ClassVar[Literal["0.5.5"]] = "0.5.5" 

2615 if TYPE_CHECKING: 

2616 format_version: Literal["0.5.5"] = "0.5.5" 

2617 else: 

2618 format_version: Literal["0.5.5"] 

2619 """Version of the bioimage.io model description specification used. 

2620 When creating a new model always use the latest micro/patch version described here. 

2621 The `format_version` is important for any consumer software to understand how to parse the fields. 

2622 """ 

2623 

2624 implemented_type: ClassVar[Literal["model"]] = "model" 

2625 if TYPE_CHECKING: 

2626 type: Literal["model"] = "model" 

2627 else: 

2628 type: Literal["model"] 

2629 """Specialized resource type 'model'""" 

2630 

2631 id: Optional[ModelId] = None 

2632 """bioimage.io-wide unique resource identifier 

2633 assigned by bioimage.io; version **un**specific.""" 

2634 

2635 authors: FAIR[List[Author]] = Field( 

2636 default_factory=cast(Callable[[], List[Author]], list) 

2637 ) 

2638 """The authors are the creators of the model RDF and the primary points of contact.""" 

2639 

2640 documentation: FAIR[Optional[FileSource_documentation]] = None 

2641 """URL or relative path to a markdown file with additional documentation. 

2642 The recommended documentation file name is `README.md`. An `.md` suffix is mandatory. 

2643 The documentation should include a '#[#] Validation' (sub)section 

2644 with details on how to quantitatively validate the model on unseen data.""" 

2645 

2646 @field_validator("documentation", mode="after") 

2647 @classmethod 

2648 def _validate_documentation( 

2649 cls, value: Optional[FileSource_documentation] 

2650 ) -> Optional[FileSource_documentation]: 

2651 if not get_validation_context().perform_io_checks or value is None: 

2652 return value 

2653 

2654 doc_reader = get_reader(value) 

2655 doc_content = doc_reader.read().decode(encoding="utf-8") 

2656 if not re.search("#.*[vV]alidation", doc_content): 

2657 issue_warning( 

2658 "No '# Validation' (sub)section found in {value}.", 

2659 value=value, 

2660 field="documentation", 

2661 ) 

2662 

2663 return value 

2664 

2665 inputs: NotEmpty[Sequence[InputTensorDescr]] 

2666 """Describes the input tensors expected by this model.""" 

2667 

2668 @field_validator("inputs", mode="after") 

2669 @classmethod 

2670 def _validate_input_axes( 

2671 cls, inputs: Sequence[InputTensorDescr] 

2672 ) -> Sequence[InputTensorDescr]: 

2673 input_size_refs = cls._get_axes_with_independent_size(inputs) 

2674 

2675 for i, ipt in enumerate(inputs): 

2676 valid_independent_refs: Dict[ 

2677 Tuple[TensorId, AxisId], 

2678 Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]], 

2679 ] = { 

2680 **{ 

2681 (ipt.id, a.id): (ipt, a, a.size) 

2682 for a in ipt.axes 

2683 if not isinstance(a, BatchAxis) 

2684 and isinstance(a.size, (int, ParameterizedSize)) 

2685 }, 

2686 **input_size_refs, 

2687 } 

2688 for a, ax in enumerate(ipt.axes): 

2689 cls._validate_axis( 

2690 "inputs", 

2691 i=i, 

2692 tensor_id=ipt.id, 

2693 a=a, 

2694 axis=ax, 

2695 valid_independent_refs=valid_independent_refs, 

2696 ) 

2697 return inputs 

2698 

2699 @staticmethod 

2700 def _validate_axis( 

2701 field_name: str, 

2702 i: int, 

2703 tensor_id: TensorId, 

2704 a: int, 

2705 axis: AnyAxis, 

2706 valid_independent_refs: Dict[ 

2707 Tuple[TensorId, AxisId], 

2708 Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]], 

2709 ], 

2710 ): 

2711 if isinstance(axis, BatchAxis) or isinstance( 

2712 axis.size, (int, ParameterizedSize, DataDependentSize) 

2713 ): 

2714 return 

2715 elif not isinstance(axis.size, SizeReference): 

2716 assert_never(axis.size) 

2717 

2718 # validate axis.size SizeReference 

2719 ref = (axis.size.tensor_id, axis.size.axis_id) 

2720 if ref not in valid_independent_refs: 

2721 raise ValueError( 

2722 "Invalid tensor axis reference at" 

2723 + f" {field_name}[{i}].axes[{a}].size: {axis.size}." 

2724 ) 

2725 if ref == (tensor_id, axis.id): 

2726 raise ValueError( 

2727 "Self-referencing not allowed for" 

2728 + f" {field_name}[{i}].axes[{a}].size: {axis.size}" 

2729 ) 

2730 if axis.type == "channel": 

2731 if valid_independent_refs[ref][1].type != "channel": 

2732 raise ValueError( 

2733 "A channel axis' size may only reference another fixed size" 

2734 + " channel axis." 

2735 ) 

2736 if isinstance(axis.channel_names, str) and "{i}" in axis.channel_names: 

2737 ref_size = valid_independent_refs[ref][2] 

2738 assert isinstance(ref_size, int), ( 

2739 "channel axis ref (another channel axis) has to specify fixed" 

2740 + " size" 

2741 ) 

2742 generated_channel_names = [ 

2743 Identifier(axis.channel_names.format(i=i)) 

2744 for i in range(1, ref_size + 1) 

2745 ] 

2746 axis.channel_names = generated_channel_names 

2747 

2748 if (ax_unit := getattr(axis, "unit", None)) != ( 

2749 ref_unit := getattr(valid_independent_refs[ref][1], "unit", None) 

2750 ): 

2751 raise ValueError( 

2752 "The units of an axis and its reference axis need to match, but" 

2753 + f" '{ax_unit}' != '{ref_unit}'." 

2754 ) 

2755 ref_axis = valid_independent_refs[ref][1] 

2756 if isinstance(ref_axis, BatchAxis): 

2757 raise ValueError( 

2758 f"Invalid reference axis '{ref_axis.id}' for {tensor_id}.{axis.id}" 

2759 + " (a batch axis is not allowed as reference)." 

2760 ) 

2761 

2762 if isinstance(axis, WithHalo): 

2763 min_size = axis.size.get_size(axis, ref_axis, n=0) 

2764 if (min_size - 2 * axis.halo) < 1: 

2765 raise ValueError( 

2766 f"axis {axis.id} with minimum size {min_size} is too small for halo" 

2767 + f" {axis.halo}." 

2768 ) 

2769 

2770 input_halo = axis.halo * axis.scale / ref_axis.scale 

2771 if input_halo != int(input_halo) or input_halo % 2 == 1: 

2772 raise ValueError( 

2773 f"input_halo {input_halo} (output_halo {axis.halo} *"

2774 + f" output_scale {axis.scale} / input_scale {ref_axis.scale})"

2775 + f" has to be an even integer for {tensor_id}.{axis.id}."

2776 ) 

2777 

2778 @model_validator(mode="after") 

2779 def _validate_test_tensors(self) -> Self: 

2780 if not get_validation_context().perform_io_checks: 

2781 return self 

2782 

2783 test_output_arrays = [ 

2784 None if descr.test_tensor is None else load_array(descr.test_tensor) 

2785 for descr in self.outputs 

2786 ] 

2787 test_input_arrays = [ 

2788 None if descr.test_tensor is None else load_array(descr.test_tensor) 

2789 for descr in self.inputs 

2790 ] 

2791 

2792 tensors = { 

2793 descr.id: (descr, array) 

2794 for descr, array in zip( 

2795 chain(self.inputs, self.outputs), test_input_arrays + test_output_arrays 

2796 ) 

2797 } 

2798 validate_tensors(tensors, tensor_origin="test_tensor") 

2799 

2800 output_arrays = { 

2801 descr.id: array for descr, array in zip(self.outputs, test_output_arrays) 

2802 } 

2803 for rep_tol in self.config.bioimageio.reproducibility_tolerance: 

2804 if not rep_tol.absolute_tolerance: 

2805 continue 

2806 

2807 if rep_tol.output_ids: 

2808 out_arrays = { 

2809 oid: a 

2810 for oid, a in output_arrays.items() 

2811 if oid in rep_tol.output_ids 

2812 } 

2813 else: 

2814 out_arrays = output_arrays 

2815 

2816 for out_id, array in out_arrays.items(): 

2817 if array is None: 

2818 continue 

2819 

2820 if rep_tol.absolute_tolerance > (max_test_value := array.max()) * 0.01: 

2821 raise ValueError( 

2822 "config.bioimageio.reproducibility_tolerance.absolute_tolerance=" 

2823 + f"{rep_tol.absolute_tolerance} > 0.01*{max_test_value}" 

2824 + f" (1% of the maximum value of the test tensor '{out_id}')" 

2825 ) 

2826 

2827 return self 

2828 

2829 @model_validator(mode="after") 

2830 def _validate_tensor_references_in_proc_kwargs(self, info: ValidationInfo) -> Self: 

2831 ipt_refs = {t.id for t in self.inputs} 

2832 out_refs = {t.id for t in self.outputs} 

2833 for ipt in self.inputs: 

2834 for p in ipt.preprocessing: 

2835 ref = p.kwargs.get("reference_tensor") 

2836 if ref is None: 

2837 continue 

2838 if ref not in ipt_refs: 

2839 raise ValueError( 

2840 f"`reference_tensor` '{ref}' not found. Valid input tensor" 

2841 + f" references are: {ipt_refs}." 

2842 ) 

2843 

2844 for out in self.outputs: 

2845 for p in out.postprocessing: 

2846 ref = p.kwargs.get("reference_tensor") 

2847 if ref is None: 

2848 continue 

2849 

2850 if ref not in ipt_refs and ref not in out_refs: 

2851 raise ValueError( 

2852 f"`reference_tensor` '{ref}' not found. Valid tensor references" 

2853 + f" are: {ipt_refs | out_refs}." 

2854 ) 

2855 

2856 return self 

2857 

2858 # TODO: use validate funcs in validate_test_tensors 

2859 # def validate_inputs(self, input_tensors: Mapping[TensorId, NDArray[Any]]) -> Mapping[TensorId, NDArray[Any]]: 

2860 

2861 name: Annotated[ 

2862 str, 

2863 RestrictCharacters(string.ascii_letters + string.digits + "_+- ()"), 

2864 MinLen(5), 

2865 MaxLen(128), 

2866 warn(MaxLen(64), "Name longer than 64 characters.", INFO), 

2867 ] 

2868 """A human-readable name of this model. 

2869 It should be no longer than 64 characters 

2870 and may only contain letters, numbers, underscores, minus signs, parentheses and spaces.

2871 We recommend choosing a name that refers to the model's task and image modality.

2872 """ 

2873 

2874 outputs: NotEmpty[Sequence[OutputTensorDescr]] 

2875 """Describes the output tensors.""" 

2876 

2877 @field_validator("outputs", mode="after") 

2878 @classmethod 

2879 def _validate_tensor_ids( 

2880 cls, outputs: Sequence[OutputTensorDescr], info: ValidationInfo 

2881 ) -> Sequence[OutputTensorDescr]: 

2882 tensor_ids = [ 

2883 t.id for t in info.data.get("inputs", []) + info.data.get("outputs", []) 

2884 ] 

2885 duplicate_tensor_ids: List[str] = [] 

2886 seen: Set[str] = set() 

2887 for t in tensor_ids: 

2888 if t in seen: 

2889 duplicate_tensor_ids.append(t) 

2890 

2891 seen.add(t) 

2892 

2893 if duplicate_tensor_ids: 

2894 raise ValueError(f"Duplicate tensor ids: {duplicate_tensor_ids}") 

2895 

2896 return outputs 

2897 

2898 @staticmethod 

2899 def _get_axes_with_parameterized_size( 

2900 io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]], 

2901 ): 

2902 return { 

2903 f"{t.id}.{a.id}": (t, a, a.size) 

2904 for t in io 

2905 for a in t.axes 

2906 if not isinstance(a, BatchAxis) and isinstance(a.size, ParameterizedSize) 

2907 } 

2908 

2909 @staticmethod 

2910 def _get_axes_with_independent_size( 

2911 io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]], 

2912 ): 

2913 return { 

2914 (t.id, a.id): (t, a, a.size) 

2915 for t in io 

2916 for a in t.axes 

2917 if not isinstance(a, BatchAxis) 

2918 and isinstance(a.size, (int, ParameterizedSize)) 

2919 } 

2920 

2921 @field_validator("outputs", mode="after") 

2922 @classmethod 

2923 def _validate_output_axes( 

2924 cls, outputs: List[OutputTensorDescr], info: ValidationInfo 

2925 ) -> List[OutputTensorDescr]: 

2926 input_size_refs = cls._get_axes_with_independent_size( 

2927 info.data.get("inputs", []) 

2928 ) 

2929 output_size_refs = cls._get_axes_with_independent_size(outputs) 

2930 

2931 for i, out in enumerate(outputs): 

2932 valid_independent_refs: Dict[ 

2933 Tuple[TensorId, AxisId], 

2934 Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]], 

2935 ] = { 

2936 **{ 

2937 (out.id, a.id): (out, a, a.size) 

2938 for a in out.axes 

2939 if not isinstance(a, BatchAxis) 

2940 and isinstance(a.size, (int, ParameterizedSize)) 

2941 }, 

2942 **input_size_refs, 

2943 **output_size_refs, 

2944 } 

2945 for a, ax in enumerate(out.axes): 

2946 cls._validate_axis( 

2947 "outputs", 

2948 i, 

2949 out.id, 

2950 a, 

2951 ax, 

2952 valid_independent_refs=valid_independent_refs, 

2953 ) 

2954 

2955 return outputs 

2956 

2957 packaged_by: List[Author] = Field( 

2958 default_factory=cast(Callable[[], List[Author]], list) 

2959 ) 

2960 """The persons that have packaged and uploaded this model. 

2961 Only required if those persons differ from the `authors`.""" 

2962 

2963 parent: Optional[LinkedModel] = None 

2964 """The model from which this model is derived, e.g. by fine-tuning the weights.""" 

2965 

2966 @model_validator(mode="after") 

2967 def _validate_parent_is_not_self(self) -> Self: 

2968 if self.parent is not None and self.parent.id == self.id: 

2969 raise ValueError("A model description may not reference itself as parent.") 

2970 

2971 return self 

2972 

2973 run_mode: Annotated[ 

2974 Optional[RunMode], 

2975 warn(None, "Run mode '{value}' has limited support across consumer software."),

2976 ] = None 

2977 """Custom run mode for this model: for more complex prediction procedures like test time 

2978 data augmentation that currently cannot be expressed in the specification. 

2979 No standard run modes are defined yet.""" 

2980 

2981 timestamp: Datetime = Field(default_factory=Datetime.now) 

2982 """Timestamp in [ISO 8601](https://en.wikipedia.org/wiki/ISO_8601) format

2983 with a few restrictions listed [here](https://docs.python.org/3/library/datetime.html#datetime.datetime.fromisoformat). 

2984 (In Python a datetime object is valid, too).""" 

2985 

2986 training_data: Annotated[ 

2987 Union[None, LinkedDataset, DatasetDescr, DatasetDescr02], 

2988 Field(union_mode="left_to_right"), 

2989 ] = None 

2990 """The dataset used to train this model""" 

2991 

2992 weights: Annotated[WeightsDescr, WrapSerializer(package_weights)] 

2993 """The weights for this model. 

2994 Weights can be given for different formats, but should otherwise be equivalent. 

2995 The available weight formats determine which consumers can use this model.""" 

2996 

2997 config: Config = Field(default_factory=Config.model_construct) 

2998 

2999 @model_validator(mode="after") 

3000 def _add_default_cover(self) -> Self: 

3001 if not get_validation_context().perform_io_checks or self.covers: 

3002 return self 

3003 

3004 try: 

3005 generated_covers = generate_covers( 

3006 [ 

3007 (t, load_array(t.test_tensor)) 

3008 for t in self.inputs 

3009 if t.test_tensor is not None 

3010 ], 

3011 [ 

3012 (t, load_array(t.test_tensor)) 

3013 for t in self.outputs 

3014 if t.test_tensor is not None 

3015 ], 

3016 ) 

3017 except Exception as e: 

3018 issue_warning( 

3019 "Failed to generate cover image(s): {e}", 

3020 value=self.covers, 

3021 msg_context=dict(e=e), 

3022 field="covers", 

3023 ) 

3024 else: 

3025 self.covers.extend(generated_covers) 

3026 

3027 return self 

3028 

3029 def get_input_test_arrays(self) -> List[NDArray[Any]]: 

3030 return self._get_test_arrays(self.inputs) 

3031 

3032 def get_output_test_arrays(self) -> List[NDArray[Any]]: 

3033 return self._get_test_arrays(self.outputs) 

3034 

3035 @staticmethod 

3036 def _get_test_arrays( 

3037 io_descr: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]], 

3038 ): 

3039 ts: List[FileDescr] = [] 

3040 for d in io_descr: 

3041 if d.test_tensor is None: 

3042 raise ValueError( 

3043 f"Failed to get test arrays: description of '{d.id}' is missing a `test_tensor`." 

3044 ) 

3045 ts.append(d.test_tensor) 

3046 

3047 data = [load_array(t) for t in ts] 

3048 assert all(isinstance(d, np.ndarray) for d in data) 

3049 return data 

3050 

3051 @staticmethod 

3052 def get_batch_size(tensor_sizes: Mapping[TensorId, Mapping[AxisId, int]]) -> int: 

3053 batch_size = 1 

3054 tensor_with_batchsize: Optional[TensorId] = None 

3055 for tid in tensor_sizes: 

3056 for aid, s in tensor_sizes[tid].items(): 

3057 if aid != BATCH_AXIS_ID or s == 1 or s == batch_size: 

3058 continue 

3059 

3060 if batch_size != 1: 

3061 assert tensor_with_batchsize is not None 

3062 raise ValueError( 

3063 f"batch size mismatch for tensors '{tensor_with_batchsize}' ({batch_size}) and '{tid}' ({s})" 

3064 ) 

3065 

3066 batch_size = s 

3067 tensor_with_batchsize = tid 

3068 

3069 return batch_size 

3070 
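# Editorial example (sketch): only the batch axis contributes; sizes of 1 are
# ignored and all tensors have to agree. Assumes the batch axis id is "batch".
#
#     ModelDescr.get_batch_size(
#         {TensorId("raw"): {AxisId("batch"): 4, AxisId("x"): 256}}
#     )  # -> 4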

3071 def get_output_tensor_sizes( 

3072 self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]] 

3073 ) -> Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]: 

3074 """Returns the tensor output sizes for given **input_sizes**. 

3075 The returned output sizes are exact only if **input_sizes** specifies a valid input shape.

3076 Otherwise they might be larger than the actual (valid) output sizes."""

3077 batch_size = self.get_batch_size(input_sizes) 

3078 ns = self.get_ns(input_sizes) 

3079 

3080 tensor_sizes = self.get_tensor_sizes(ns, batch_size=batch_size) 

3081 return tensor_sizes.outputs 

3082 

3083 def get_ns(self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]): 

3084 """get parameter `n` for each parameterized axis 

3085 such that the valid input size is >= the given input size""" 

3086 ret: Dict[Tuple[TensorId, AxisId], ParameterizedSize_N] = {} 

3087 axes = {t.id: {a.id: a for a in t.axes} for t in self.inputs} 

3088 for tid in input_sizes: 

3089 for aid, s in input_sizes[tid].items(): 

3090 size_descr = axes[tid][aid].size 

3091 if isinstance(size_descr, ParameterizedSize): 

3092 ret[(tid, aid)] = size_descr.get_n(s) 

3093 elif size_descr is None or isinstance(size_descr, (int, SizeReference)): 

3094 pass 

3095 else: 

3096 assert_never(size_descr) 

3097 

3098 return ret 

3099 
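# Editorial note (sketch): for an axis parameterized as size = min + n * step
# (e.g. min=64, step=16), the returned n is the smallest value whose valid size
# covers the requested input size, e.g. a requested size of 100 gives
# n = ceil((100 - 64) / 16) = 3, because 64 + 3 * 16 = 112 >= 100.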

3100 def get_tensor_sizes( 

3101 self, ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], batch_size: int 

3102 ) -> _TensorSizes: 

3103 axis_sizes = self.get_axis_sizes(ns, batch_size=batch_size) 

3104 return _TensorSizes( 

3105 { 

3106 t: { 

3107 aa: axis_sizes.inputs[(tt, aa)] 

3108 for tt, aa in axis_sizes.inputs 

3109 if tt == t 

3110 } 

3111 for t in {tt for tt, _ in axis_sizes.inputs} 

3112 }, 

3113 { 

3114 t: { 

3115 aa: axis_sizes.outputs[(tt, aa)] 

3116 for tt, aa in axis_sizes.outputs 

3117 if tt == t 

3118 } 

3119 for t in {tt for tt, _ in axis_sizes.outputs} 

3120 }, 

3121 ) 

3122 

3123 def get_axis_sizes( 

3124 self, 

3125 ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], 

3126 batch_size: Optional[int] = None, 

3127 *, 

3128 max_input_shape: Optional[Mapping[Tuple[TensorId, AxisId], int]] = None, 

3129 ) -> _AxisSizes: 

3130 """Determine input and output block shape for scale factors **ns** 

3131 of parameterized input sizes. 

3132 

3133 Args: 

3134 ns: Scale factor `n` for each axis (keyed by (tensor_id, axis_id)) 

3135 that is parameterized as `size = min + n * step`. 

3136 batch_size: The desired size of the batch dimension. 

3137 If given **batch_size** overwrites any batch size present in 

3138 **max_input_shape**. Default 1. 

3139 max_input_shape: Limits the derived block shapes. 

3140 Each axis for which the input size, parameterized by `n`, is larger 

3141 than **max_input_shape** is set to the minimal value `n_min` for which 

3142 this is still true. 

3143 Use this for small input samples or large values of **ns**. 

3144 Or simply whenever you know the full input shape. 

3145 

3146 Returns: 

3147 Resolved axis sizes for model inputs and outputs. 

3148 """ 

3149 max_input_shape = max_input_shape or {} 

3150 if batch_size is None: 

3151 for (_t_id, a_id), s in max_input_shape.items(): 

3152 if a_id == BATCH_AXIS_ID: 

3153 batch_size = s 

3154 break 

3155 else: 

3156 batch_size = 1 

3157 

3158 all_axes = { 

3159 t.id: {a.id: a for a in t.axes} for t in chain(self.inputs, self.outputs) 

3160 } 

3161 

3162 inputs: Dict[Tuple[TensorId, AxisId], int] = {} 

3163 outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]] = {} 

3164 

3165 def get_axis_size(a: Union[InputAxis, OutputAxis]): 

3166 if isinstance(a, BatchAxis): 

3167 if (t_descr.id, a.id) in ns: 

3168 logger.warning( 

3169 "Ignoring unexpected size increment factor (n) for batch axis" 

3170 + " of tensor '{}'.", 

3171 t_descr.id, 

3172 ) 

3173 return batch_size 

3174 elif isinstance(a.size, int): 

3175 if (t_descr.id, a.id) in ns: 

3176 logger.warning( 

3177 "Ignoring unexpected size increment factor (n) for fixed size" 

3178 + " axis '{}' of tensor '{}'.", 

3179 a.id, 

3180 t_descr.id, 

3181 ) 

3182 return a.size 

3183 elif isinstance(a.size, ParameterizedSize): 

3184 if (t_descr.id, a.id) not in ns: 

3185 raise ValueError( 

3186 "Size increment factor (n) missing for parametrized axis" 

3187 + f" '{a.id}' of tensor '{t_descr.id}'." 

3188 ) 

3189 n = ns[(t_descr.id, a.id)] 

3190 s_max = max_input_shape.get((t_descr.id, a.id)) 

3191 if s_max is not None: 

3192 n = min(n, a.size.get_n(s_max)) 

3193 

3194 return a.size.get_size(n) 

3195 

3196 elif isinstance(a.size, SizeReference): 

3197 if (t_descr.id, a.id) in ns: 

3198 logger.warning( 

3199 "Ignoring unexpected size increment factor (n) for axis '{}'" 

3200 + " of tensor '{}' with size reference.", 

3201 a.id, 

3202 t_descr.id, 

3203 ) 

3204 assert not isinstance(a, BatchAxis) 

3205 ref_axis = all_axes[a.size.tensor_id][a.size.axis_id] 

3206 assert not isinstance(ref_axis, BatchAxis) 

3207 ref_key = (a.size.tensor_id, a.size.axis_id) 

3208 ref_size = inputs.get(ref_key, outputs.get(ref_key)) 

3209 assert ref_size is not None, ref_key 

3210 assert not isinstance(ref_size, _DataDepSize), ref_key 

3211 return a.size.get_size( 

3212 axis=a, 

3213 ref_axis=ref_axis, 

3214 ref_size=ref_size, 

3215 ) 

3216 elif isinstance(a.size, DataDependentSize): 

3217 if (t_descr.id, a.id) in ns: 

3218 logger.warning( 

3219 "Ignoring unexpected increment factor (n) for data dependent" 

3220 + " size axis '{}' of tensor '{}'.", 

3221 a.id, 

3222 t_descr.id, 

3223 ) 

3224 return _DataDepSize(a.size.min, a.size.max) 

3225 else: 

3226 assert_never(a.size) 

3227 

3228 # first resolve all input axis sizes except the `SizeReference` ones

3229 for t_descr in self.inputs: 

3230 for a in t_descr.axes: 

3231 if not isinstance(a.size, SizeReference): 

3232 s = get_axis_size(a) 

3233 assert not isinstance(s, _DataDepSize) 

3234 inputs[t_descr.id, a.id] = s 

3235 

3236 # resolve all other input axis sizes 

3237 for t_descr in self.inputs: 

3238 for a in t_descr.axes: 

3239 if isinstance(a.size, SizeReference): 

3240 s = get_axis_size(a) 

3241 assert not isinstance(s, _DataDepSize) 

3242 inputs[t_descr.id, a.id] = s 

3243 

3244 # resolve all output axis sizes 

3245 for t_descr in self.outputs: 

3246 for a in t_descr.axes: 

3247 assert not isinstance(a.size, ParameterizedSize) 

3248 s = get_axis_size(a) 

3249 outputs[t_descr.id, a.id] = s 

3250 

3251 return _AxisSizes(inputs=inputs, outputs=outputs) 

3252 
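# Editorial example (sketch): resolving block sizes for one parameterized input
# axis. Assumes a hypothetical `model: ModelDescr` whose input "raw" has a space
# axis "x" described by ParameterizedSize(min=64, step=16):
#
#     sizes = model.get_axis_sizes({(TensorId("raw"), AxisId("x")): 2}, batch_size=1)
#     sizes.inputs[(TensorId("raw"), AxisId("x"))]  # -> 96  (= 64 + 2 * 16)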

3253 @model_validator(mode="before") 

3254 @classmethod 

3255 def _convert(cls, data: Dict[str, Any]) -> Dict[str, Any]: 

3256 cls.convert_from_old_format_wo_validation(data) 

3257 return data 

3258 

3259 @classmethod 

3260 def convert_from_old_format_wo_validation(cls, data: Dict[str, Any]) -> None: 

3261 """Convert metadata following an older format version to this class's format

3262 without validating the result. 

3263 """ 

3264 if ( 

3265 data.get("type") == "model" 

3266 and isinstance(fv := data.get("format_version"), str) 

3267 and fv.count(".") == 2 

3268 ): 

3269 fv_parts = fv.split(".") 

3270 if any(not p.isdigit() for p in fv_parts): 

3271 return 

3272 

3273 fv_tuple = tuple(map(int, fv_parts)) 

3274 

3275 assert cls.implemented_format_version_tuple[0:2] == (0, 5) 

3276 if fv_tuple[:2] in ((0, 3), (0, 4)): 

3277 m04 = _ModelDescr_v0_4.load(data) 

3278 if isinstance(m04, InvalidDescr): 

3279 try: 

3280 updated = _model_conv.convert_as_dict( 

3281 m04 # pyright: ignore[reportArgumentType] 

3282 ) 

3283 except Exception as e: 

3284 logger.error( 

3285 "Failed to convert from invalid model 0.4 description." 

3286 + f"\nerror: {e}" 

3287 + "\nProceeding with model 0.5 validation without conversion." 

3288 ) 

3289 updated = None 

3290 else: 

3291 updated = _model_conv.convert_as_dict(m04) 

3292 

3293 if updated is not None: 

3294 data.clear() 

3295 data.update(updated) 

3296 

3297 elif fv_tuple[:2] == (0, 5): 

3298 # bump patch version 

3299 data["format_version"] = cls.implemented_format_version 

3300 

3301 
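# Editorial example (sketch): metadata that already follows a 0.5.x format
# version is only bumped to the implemented patch version, in place:
#
#     data = {"type": "model", "format_version": "0.5.0"}
#     ModelDescr.convert_from_old_format_wo_validation(data)
#     data["format_version"]  # -> "0.5.5"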

3302class _ModelConv(Converter[_ModelDescr_v0_4, ModelDescr]): 

3303 def _convert( 

3304 self, src: _ModelDescr_v0_4, tgt: "type[ModelDescr] | type[dict[str, Any]]" 

3305 ) -> "ModelDescr | dict[str, Any]": 

3306 name = "".join( 

3307 c if c in string.ascii_letters + string.digits + "_+- ()" else " " 

3308 for c in src.name 

3309 ) 

3310 

3311 def conv_authors(auths: Optional[Sequence[_Author_v0_4]]): 

3312 conv = ( 

3313 _author_conv.convert if TYPE_CHECKING else _author_conv.convert_as_dict 

3314 ) 

3315 return None if auths is None else [conv(a) for a in auths] 

3316 

3317 if TYPE_CHECKING: 

3318 arch_file_conv = _arch_file_conv.convert 

3319 arch_lib_conv = _arch_lib_conv.convert 

3320 else: 

3321 arch_file_conv = _arch_file_conv.convert_as_dict 

3322 arch_lib_conv = _arch_lib_conv.convert_as_dict 

3323 

3324 input_size_refs = { 

3325 ipt.name: { 

3326 a: s 

3327 for a, s in zip( 

3328 ipt.axes, 

3329 ( 

3330 ipt.shape.min 

3331 if isinstance(ipt.shape, _ParameterizedInputShape_v0_4) 

3332 else ipt.shape 

3333 ), 

3334 ) 

3335 } 

3336 for ipt in src.inputs 

3337 if ipt.shape 

3338 } 

3339 output_size_refs = { 

3340 **{ 

3341 out.name: {a: s for a, s in zip(out.axes, out.shape)} 

3342 for out in src.outputs 

3343 if not isinstance(out.shape, _ImplicitOutputShape_v0_4) 

3344 }, 

3345 **input_size_refs, 

3346 } 

3347 

3348 return tgt( 

3349 attachments=( 

3350 [] 

3351 if src.attachments is None 

3352 else [FileDescr(source=f) for f in src.attachments.files] 

3353 ), 

3354 authors=[_author_conv.convert_as_dict(a) for a in src.authors], # pyright: ignore[reportArgumentType] 

3355 cite=[{"text": c.text, "doi": c.doi, "url": c.url} for c in src.cite], # pyright: ignore[reportArgumentType] 

3356 config=src.config, # pyright: ignore[reportArgumentType] 

3357 covers=src.covers, 

3358 description=src.description, 

3359 documentation=src.documentation, 

3360 format_version="0.5.5", 

3361 git_repo=src.git_repo, # pyright: ignore[reportArgumentType] 

3362 icon=src.icon, 

3363 id=None if src.id is None else ModelId(src.id), 

3364 id_emoji=src.id_emoji, 

3365 license=src.license, # type: ignore 

3366 links=src.links, 

3367 maintainers=[_maintainer_conv.convert_as_dict(m) for m in src.maintainers], # pyright: ignore[reportArgumentType] 

3368 name=name, 

3369 tags=src.tags, 

3370 type=src.type, 

3371 uploader=src.uploader, 

3372 version=src.version, 

3373 inputs=[ # pyright: ignore[reportArgumentType] 

3374 _input_tensor_conv.convert_as_dict(ipt, tt, st, input_size_refs) 

3375 for ipt, tt, st in zip( 

3376 src.inputs, 

3377 src.test_inputs, 

3378 src.sample_inputs or [None] * len(src.test_inputs), 

3379 ) 

3380 ], 

3381 outputs=[ # pyright: ignore[reportArgumentType] 

3382 _output_tensor_conv.convert_as_dict(out, tt, st, output_size_refs) 

3383 for out, tt, st in zip( 

3384 src.outputs, 

3385 src.test_outputs, 

3386 src.sample_outputs or [None] * len(src.test_outputs), 

3387 ) 

3388 ], 

3389 parent=( 

3390 None 

3391 if src.parent is None 

3392 else LinkedModel( 

3393 id=ModelId( 

3394 str(src.parent.id) 

3395 + ( 

3396 "" 

3397 if src.parent.version_number is None 

3398 else f"/{src.parent.version_number}" 

3399 ) 

3400 ) 

3401 ) 

3402 ), 

3403 training_data=( 

3404 None 

3405 if src.training_data is None 

3406 else ( 

3407 LinkedDataset( 

3408 id=DatasetId( 

3409 str(src.training_data.id) 

3410 + ( 

3411 "" 

3412 if src.training_data.version_number is None 

3413 else f"/{src.training_data.version_number}" 

3414 ) 

3415 ) 

3416 ) 

3417 if isinstance(src.training_data, LinkedDataset02) 

3418 else src.training_data 

3419 ) 

3420 ), 

3421 packaged_by=[_author_conv.convert_as_dict(a) for a in src.packaged_by], # pyright: ignore[reportArgumentType] 

3422 run_mode=src.run_mode, 

3423 timestamp=src.timestamp, 

3424 weights=(WeightsDescr if TYPE_CHECKING else dict)( 

3425 keras_hdf5=(w := src.weights.keras_hdf5) 

3426 and (KerasHdf5WeightsDescr if TYPE_CHECKING else dict)( 

3427 authors=conv_authors(w.authors), 

3428 source=w.source, 

3429 tensorflow_version=w.tensorflow_version or Version("1.15"), 

3430 parent=w.parent, 

3431 ), 

3432 onnx=(w := src.weights.onnx) 

3433 and (OnnxWeightsDescr if TYPE_CHECKING else dict)( 

3434 source=w.source, 

3435 authors=conv_authors(w.authors), 

3436 parent=w.parent, 

3437 opset_version=w.opset_version or 15, 

3438 ), 

3439 pytorch_state_dict=(w := src.weights.pytorch_state_dict) 

3440 and (PytorchStateDictWeightsDescr if TYPE_CHECKING else dict)( 

3441 source=w.source, 

3442 authors=conv_authors(w.authors), 

3443 parent=w.parent, 

3444 architecture=( 

3445 arch_file_conv( 

3446 w.architecture, 

3447 w.architecture_sha256, 

3448 w.kwargs, 

3449 ) 

3450 if isinstance(w.architecture, _CallableFromFile_v0_4) 

3451 else arch_lib_conv(w.architecture, w.kwargs) 

3452 ), 

3453 pytorch_version=w.pytorch_version or Version("1.10"), 

3454 dependencies=( 
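# v0.4 allowed dependencies of the form "conda:<env file>"; strip the "conda:" prefix
# (here and for the tensorflow saved model bundle below) to obtain a plain file source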

3455 None 

3456 if w.dependencies is None 

3457 else (FileDescr if TYPE_CHECKING else dict)( 

3458 source=cast( 

3459 FileSource, 

3460 str(deps := w.dependencies)[ 

3461 ( 

3462 len("conda:") 

3463 if str(deps).startswith("conda:") 

3464 else 0 

3465 ) : 

3466 ], 

3467 ) 

3468 ) 

3469 ), 

3470 ), 

3471 tensorflow_js=(w := src.weights.tensorflow_js) 

3472 and (TensorflowJsWeightsDescr if TYPE_CHECKING else dict)( 

3473 source=w.source, 

3474 authors=conv_authors(w.authors), 

3475 parent=w.parent, 

3476 tensorflow_version=w.tensorflow_version or Version("1.15"), 

3477 ), 

3478 tensorflow_saved_model_bundle=( 

3479 w := src.weights.tensorflow_saved_model_bundle 

3480 ) 

3481 and (TensorflowSavedModelBundleWeightsDescr if TYPE_CHECKING else dict)( 

3482 authors=conv_authors(w.authors), 

3483 parent=w.parent, 

3484 source=w.source, 

3485 tensorflow_version=w.tensorflow_version or Version("1.15"), 

3486 dependencies=( 

3487 None 

3488 if w.dependencies is None 

3489 else (FileDescr if TYPE_CHECKING else dict)( 

3490 source=cast( 

3491 FileSource, 

3492 ( 

3493 str(w.dependencies)[len("conda:") :] 

3494 if str(w.dependencies).startswith("conda:") 

3495 else str(w.dependencies) 

3496 ), 

3497 ) 

3498 ) 

3499 ), 

3500 ), 

3501 torchscript=(w := src.weights.torchscript) 

3502 and (TorchscriptWeightsDescr if TYPE_CHECKING else dict)( 

3503 source=w.source, 

3504 authors=conv_authors(w.authors), 

3505 parent=w.parent, 

3506 pytorch_version=w.pytorch_version or Version("1.10"), 

3507 ), 

3508 ), 

3509 ) 

3510 

3511 

3512_model_conv = _ModelConv(_ModelDescr_v0_4, ModelDescr) 

3513 

3514 

3515# create better cover images for 3d data and non-image outputs 

3516def generate_covers( 

3517 inputs: Sequence[Tuple[InputTensorDescr, NDArray[Any]]], 

3518 outputs: Sequence[Tuple[OutputTensorDescr, NDArray[Any]]], 

3519) -> List[Path]: 

3520 def squeeze( 

3521 data: NDArray[Any], axes: Sequence[AnyAxis] 

3522 ) -> Tuple[NDArray[Any], List[AnyAxis]]: 

3523 """apply numpy.ndarray.squeeze while keeping track of the axis descriptions remaining""" 

3524 if data.ndim != len(axes): 

3525 raise ValueError( 

3526 f"tensor shape {data.shape} does not match described axes" 

3527 + f" {[a.id for a in axes]}" 

3528 ) 

3529 

3530 axes = [deepcopy(a) for a, s in zip(axes, data.shape) if s != 1] 

3531 return data.squeeze(), axes 

3532 

3533 def normalize( 

3534 data: NDArray[Any], axis: Optional[Tuple[int, ...]], eps: float = 1e-7 

3535 ) -> NDArray[np.float32]: 
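"""min-max normalize `data` to [0, 1) along `axis`; `eps` guards against division by zero"""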

3536 data = data.astype("float32") 

3537 data -= data.min(axis=axis, keepdims=True) 

3538 data /= data.max(axis=axis, keepdims=True) + eps 

3539 return data 

3540 

3541 def to_2d_image(data: NDArray[Any], axes: Sequence[AnyAxis]): 
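"""reduce `data` to a 2d image with 3 (RGB) channels by slicing batch/index/z/space/time axes and mapping channel data to RGB"""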

3542 original_shape = data.shape 

3543 data, axes = squeeze(data, axes) 

3544 

3545 # take a slice from any batch or index axis if needed

3546 # and convert the first channel axis to RGB, taking a slice from any additional channel axes

3547 slices: Tuple[slice, ...] = () 

3548 ndim = data.ndim 

3549 ndim_need = 3 if any(isinstance(a, ChannelAxis) for a in axes) else 2 

3550 has_c_axis = False 

3551 for i, a in enumerate(axes): 

3552 s = data.shape[i] 

3553 assert s > 1 

3554 if ( 

3555 isinstance(a, (BatchAxis, IndexInputAxis, IndexOutputAxis)) 

3556 and ndim > ndim_need 

3557 ): 

3558 data = data[slices + (slice(s // 2 - 1, s // 2),)] 

3559 ndim -= 1 

3560 elif isinstance(a, ChannelAxis): 

3561 if has_c_axis: 

3562 # second channel axis 

3563 data = data[slices + (slice(0, 1),)] 

3564 ndim -= 1 

3565 else: 

3566 has_c_axis = True 

3567 if s == 2: 

3568 # visualize two channels with cyan and magenta 

3569 data = np.concatenate( 

3570 [ 

3571 data[slices + (slice(1, 2),)], 

3572 data[slices + (slice(0, 1),)], 

3573 ( 

3574 data[slices + (slice(0, 1),)] 

3575 + data[slices + (slice(1, 2),)] 

3576 ) 

3577 / 2, # TODO: take maximum instead? 

3578 ], 

3579 axis=i, 

3580 ) 

3581 elif data.shape[i] == 3: 

3582 pass # visualize 3 channels as RGB 

3583 else: 

3584 # visualize first 3 channels as RGB 

3585 data = data[slices + (slice(3),)] 

3586 

3587 assert data.shape[i] == 3 

3588 

3589 slices += (slice(None),) 

3590 

3591 data, axes = squeeze(data, axes) 

3592 assert len(axes) == ndim 

3593 # take slice from z axis if needed 

3594 slices = () 

3595 if ndim > ndim_need: 

3596 for i, a in enumerate(axes): 

3597 s = data.shape[i] 

3598 if a.id == AxisId("z"): 

3599 data = data[slices + (slice(s // 2 - 1, s // 2),)] 

3600 data, axes = squeeze(data, axes) 

3601 ndim -= 1 

3602 break 

3603 

3604 slices += (slice(None),) 

3605 

3606 # take slice from any space or time axis 

3607 slices = () 

3608 

3609 for i, a in enumerate(axes): 

3610 if ndim <= ndim_need: 

3611 break 

3612 

3613 s = data.shape[i] 

3614 assert s > 1 

3615 if isinstance( 

3616 a, (SpaceInputAxis, SpaceOutputAxis, TimeInputAxis, TimeOutputAxis) 

3617 ): 

3618 data = data[slices + (slice(s // 2 - 1, s // 2),)] 

3619 ndim -= 1 

3620 

3621 slices += (slice(None),) 

3622 

3623 del slices 

3624 data, axes = squeeze(data, axes) 

3625 assert len(axes) == ndim 

3626 

3627 if (has_c_axis and ndim != 3) or (not has_c_axis and ndim != 2):

3628 raise ValueError( 

3629 f"Failed to construct cover image from shape {original_shape}" 

3630 ) 

3631 

3632 if not has_c_axis: 

3633 assert ndim == 2 

3634 data = np.repeat(data[:, :, None], 3, axis=2) 

3635 axes.append(ChannelAxis(channel_names=list(map(Identifier, "RGB")))) 

3636 ndim += 1 

3637 

3638 assert ndim == 3 

3639 

3640 # transpose axis order such that longest axis comes first... 

3641 axis_order: List[int] = list(np.argsort(list(data.shape))) 

3642 axis_order.reverse() 

3643 # ... and channel axis is last 

3644 c = [i for i in range(3) if isinstance(axes[i], ChannelAxis)][0] 

3645 axis_order.append(axis_order.pop(c)) 

3646 axes = [axes[ao] for ao in axis_order] 

3647 data = data.transpose(axis_order) 

3648 

3649 # h, w = data.shape[:2] 

3650 # if h / w in (1.0, 2.0):

3651 # pass 

3652 # elif h / w < 2: 

3653 # TODO: enforce 2:1 or 1:1 aspect ratio for generated cover images 

3654 

3655 norm_along = ( 

3656 tuple(i for i, a in enumerate(axes) if a.type in ("space", "time")) or None 

3657 ) 

3658 # normalize the data and map to 8 bit 

3659 data = normalize(data, norm_along) 

3660 data = (data * 255).astype("uint8") 

3661 

3662 return data 

3663 

3664 def create_diagonal_split_image(im0: NDArray[Any], im1: NDArray[Any]): 
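"""compose a cover from two equally shaped uint8 RGB images: lower triangle from `im0`, upper triangle from `im1`"""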

3665 assert im0.dtype == im1.dtype == np.uint8 

3666 assert im0.shape == im1.shape 

3667 assert im0.ndim == 3 

3668 N, M, C = im0.shape 

3669 assert C == 3 

3670 out = np.ones((N, M, C), dtype="uint8") 

3671 for c in range(C): 

3672 outc = np.tril(im0[..., c]) 

3673 mask = outc == 0 

3674 outc[mask] = np.triu(im1[..., c])[mask] 

3675 out[..., c] = outc 

3676 

3677 return out 

3678 

3679 if not inputs: 

3680 raise ValueError("Missing test input tensor for cover generation.") 

3681 

3682 if not outputs: 

3683 raise ValueError("Missing test output tensor for cover generation.") 

3684 

3685 ipt_descr, ipt = inputs[0] 

3686 out_descr, out = outputs[0] 

3687 

3688 ipt_img = to_2d_image(ipt, ipt_descr.axes) 

3689 out_img = to_2d_image(out, out_descr.axes) 

3690 

3691 cover_folder = Path(mkdtemp()) 
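# if input and output reduce to the same 2d shape, write a single diagonally split cover;
# otherwise write separate input and output cover images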

3692 if ipt_img.shape == out_img.shape: 

3693 covers = [cover_folder / "cover.png"] 

3694 imwrite(covers[0], create_diagonal_split_image(ipt_img, out_img)) 

3695 else: 

3696 covers = [cover_folder / "input.png", cover_folder / "output.png"] 

3697 imwrite(covers[0], ipt_img) 

3698 imwrite(covers[1], out_img) 

3699 

3700 return covers
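
# --- Illustrative usage sketch (editor's addition, not part of v0_5.py) ---
# A minimal example of how `generate_covers` might be driven from a validated model
# description. It assumes `model` is a `ModelDescr` whose test tensors are locally
# available and that `load_array` accepts each tensor's `test_tensor` file description;
# adjust the loading step to however test tensors are resolved in your setup.
#
#   inputs = [(ipt, load_array(ipt.test_tensor)) for ipt in model.inputs]
#   outputs = [(out, load_array(out.test_tensor)) for out in model.outputs]
#   cover_paths = generate_covers(inputs, outputs)  # PNG file(s) in a temporary folder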