Coverage for bioimageio/spec/model/v0_5.py: 75%

1325 statements  


1from __future__ import annotations 

2 

3import collections.abc 

4import re 

5import string 

6import warnings 

7from abc import ABC 

8from copy import deepcopy 

9from itertools import chain 

10from math import ceil 

11from pathlib import Path, PurePosixPath 

12from tempfile import mkdtemp 

13from typing import ( 

14 TYPE_CHECKING, 

15 Any, 

16 Callable, 

17 ClassVar, 

18 Dict, 

19 Generic, 

20 List, 

21 Literal, 

22 Mapping, 

23 NamedTuple, 

24 Optional, 

25 Sequence, 

26 Set, 

27 Tuple, 

28 Type, 

29 TypeVar, 

30 Union, 

31 cast, 

32) 

33 

34import numpy as np 

35from annotated_types import Ge, Gt, Interval, MaxLen, MinLen, Predicate 

36from imageio.v3 import imread, imwrite # pyright: ignore[reportUnknownVariableType] 

37from loguru import logger 

38from numpy.typing import NDArray 

39from pydantic import ( 

40 AfterValidator, 

41 Discriminator, 

42 Field, 

43 RootModel, 

44 SerializationInfo, 

45 SerializerFunctionWrapHandler, 

46 StrictInt, 

47 Tag, 

48 ValidationInfo, 

49 WrapSerializer, 

50 field_validator, 

51 model_serializer, 

52 model_validator, 

53) 

54from typing_extensions import Annotated, Self, assert_never, get_args 

55 

56from .._internal.common_nodes import ( 

57 InvalidDescr, 

58 Node, 

59 NodeWithExplicitlySetFields, 

60) 

61from .._internal.constants import DTYPE_LIMITS 

62from .._internal.field_warning import issue_warning, warn 

63from .._internal.io import BioimageioYamlContent as BioimageioYamlContent 

64from .._internal.io import FileDescr as FileDescr 

65from .._internal.io import ( 

66 FileSource, 

67 WithSuffix, 

68 YamlValue, 

69 get_reader, 

70 wo_special_file_name, 

71) 

72from .._internal.io_basics import Sha256 as Sha256 

73from .._internal.io_packaging import ( 

74 FileDescr_, 

75 FileSource_, 

76 package_file_descr_serializer, 

77) 

78from .._internal.io_utils import load_array 

79from .._internal.node_converter import Converter 

80from .._internal.type_guards import is_dict, is_sequence 

81from .._internal.types import ( 

82 AbsoluteTolerance, 

83 LowerCaseIdentifier, 

84 LowerCaseIdentifierAnno, 

85 MismatchedElementsPerMillion, 

86 RelativeTolerance, 

87) 

88from .._internal.types import Datetime as Datetime 

89from .._internal.types import Identifier as Identifier 

90from .._internal.types import NotEmpty as NotEmpty 

91from .._internal.types import SiUnit as SiUnit 

92from .._internal.url import HttpUrl as HttpUrl 

93from .._internal.validation_context import get_validation_context 

94from .._internal.validator_annotations import RestrictCharacters 

95from .._internal.version_type import Version as Version 

96from .._internal.warning_levels import INFO 

97from ..dataset.v0_2 import DatasetDescr as DatasetDescr02 

98from ..dataset.v0_2 import LinkedDataset as LinkedDataset02 

99from ..dataset.v0_3 import DatasetDescr as DatasetDescr 

100from ..dataset.v0_3 import DatasetId as DatasetId 

101from ..dataset.v0_3 import LinkedDataset as LinkedDataset 

102from ..dataset.v0_3 import Uploader as Uploader 

103from ..generic.v0_3 import ( 

104 VALID_COVER_IMAGE_EXTENSIONS as VALID_COVER_IMAGE_EXTENSIONS, 

105) 

106from ..generic.v0_3 import Author as Author 

107from ..generic.v0_3 import BadgeDescr as BadgeDescr 

108from ..generic.v0_3 import CiteEntry as CiteEntry 

109from ..generic.v0_3 import DeprecatedLicenseId as DeprecatedLicenseId 

110from ..generic.v0_3 import Doi as Doi 

111from ..generic.v0_3 import ( 

112 FileSource_documentation, 

113 GenericModelDescrBase, 

114 LinkedResourceBase, 

115 _author_conv, # pyright: ignore[reportPrivateUsage] 

116 _maintainer_conv, # pyright: ignore[reportPrivateUsage] 

117) 

118from ..generic.v0_3 import LicenseId as LicenseId 

119from ..generic.v0_3 import LinkedResource as LinkedResource 

120from ..generic.v0_3 import Maintainer as Maintainer 

121from ..generic.v0_3 import OrcidId as OrcidId 

122from ..generic.v0_3 import RelativeFilePath as RelativeFilePath 

123from ..generic.v0_3 import ResourceId as ResourceId 

124from .v0_4 import Author as _Author_v0_4 

125from .v0_4 import BinarizeDescr as _BinarizeDescr_v0_4 

126from .v0_4 import CallableFromDepencency as CallableFromDepencency 

127from .v0_4 import CallableFromDepencency as _CallableFromDepencency_v0_4 

128from .v0_4 import CallableFromFile as _CallableFromFile_v0_4 

129from .v0_4 import ClipDescr as _ClipDescr_v0_4 

130from .v0_4 import ClipKwargs as ClipKwargs 

131from .v0_4 import ImplicitOutputShape as _ImplicitOutputShape_v0_4 

132from .v0_4 import InputTensorDescr as _InputTensorDescr_v0_4 

133from .v0_4 import KnownRunMode as KnownRunMode 

134from .v0_4 import ModelDescr as _ModelDescr_v0_4 

135from .v0_4 import OutputTensorDescr as _OutputTensorDescr_v0_4 

136from .v0_4 import ParameterizedInputShape as _ParameterizedInputShape_v0_4 

137from .v0_4 import PostprocessingDescr as _PostprocessingDescr_v0_4 

138from .v0_4 import PreprocessingDescr as _PreprocessingDescr_v0_4 

139from .v0_4 import ProcessingKwargs as ProcessingKwargs 

140from .v0_4 import RunMode as RunMode 

141from .v0_4 import ScaleLinearDescr as _ScaleLinearDescr_v0_4 

142from .v0_4 import ScaleMeanVarianceDescr as _ScaleMeanVarianceDescr_v0_4 

143from .v0_4 import ScaleRangeDescr as _ScaleRangeDescr_v0_4 

144from .v0_4 import SigmoidDescr as _SigmoidDescr_v0_4 

145from .v0_4 import TensorName as _TensorName_v0_4 

146from .v0_4 import WeightsFormat as WeightsFormat 

147from .v0_4 import ZeroMeanUnitVarianceDescr as _ZeroMeanUnitVarianceDescr_v0_4 

148from .v0_4 import package_weights 

149 

150SpaceUnit = Literal[ 

151 "attometer", 

152 "angstrom", 

153 "centimeter", 

154 "decimeter", 

155 "exameter", 

156 "femtometer", 

157 "foot", 

158 "gigameter", 

159 "hectometer", 

160 "inch", 

161 "kilometer", 

162 "megameter", 

163 "meter", 

164 "micrometer", 

165 "mile", 

166 "millimeter", 

167 "nanometer", 

168 "parsec", 

169 "petameter", 

170 "picometer", 

171 "terameter", 

172 "yard", 

173 "yoctometer", 

174 "yottameter", 

175 "zeptometer", 

176 "zettameter", 

177] 

178"""Space unit compatible to the [OME-Zarr axes specification 0.5](https://ngff.openmicroscopy.org/0.5/#axes-md)""" 

179 

180TimeUnit = Literal[ 

181 "attosecond", 

182 "centisecond", 

183 "day", 

184 "decisecond", 

185 "exasecond", 

186 "femtosecond", 

187 "gigasecond", 

188 "hectosecond", 

189 "hour", 

190 "kilosecond", 

191 "megasecond", 

192 "microsecond", 

193 "millisecond", 

194 "minute", 

195 "nanosecond", 

196 "petasecond", 

197 "picosecond", 

198 "second", 

199 "terasecond", 

200 "yoctosecond", 

201 "yottasecond", 

202 "zeptosecond", 

203 "zettasecond", 

204] 

205"""Time unit compatible to the [OME-Zarr axes specification 0.5](https://ngff.openmicroscopy.org/0.5/#axes-md)""" 

206 

207AxisType = Literal["batch", "channel", "index", "time", "space"] 

208 

209_AXIS_TYPE_MAP: Mapping[str, AxisType] = { 

210 "b": "batch", 

211 "t": "time", 

212 "i": "index", 

213 "c": "channel", 

214 "x": "space", 

215 "y": "space", 

216 "z": "space", 

217} 

218 

219_AXIS_ID_MAP = { 

220 "b": "batch", 

221 "t": "time", 

222 "i": "index", 

223 "c": "channel", 

224} 

225 

226 

227class TensorId(LowerCaseIdentifier): 

228 root_model: ClassVar[Type[RootModel[Any]]] = RootModel[ 

229 Annotated[LowerCaseIdentifierAnno, MaxLen(32)] 

230 ] 

231 

232 

233def _normalize_axis_id(a: str): 

234 a = str(a) 

235 normalized = _AXIS_ID_MAP.get(a, a) 

236 if a != normalized: 

237 logger.opt(depth=3).warning( 

238 "Normalized axis id from '{}' to '{}'.", a, normalized 

239 ) 

240 return normalized 

241 

242 

243class AxisId(LowerCaseIdentifier): 

244 root_model: ClassVar[Type[RootModel[Any]]] = RootModel[ 

245 Annotated[ 

246 LowerCaseIdentifierAnno, 

247 MaxLen(16), 

248 AfterValidator(_normalize_axis_id), 

249 ] 

250 ] 

251 
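# A minimal sketch of axis id normalization (illustrative only; assumes `AxisId`
# exposes the validated string value like other `LowerCaseIdentifier`s):
#
#     AxisId("b")   # logs "Normalized axis id from 'b' to 'batch'." and yields AxisId("batch")
#     AxisId("x")   # kept as-is; only 'b', 't', 'i' and 'c' are mapped by `_AXIS_ID_MAP`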

252 

253def _is_batch(a: str) -> bool: 

254 return str(a) == "batch" 

255 

256 

257def _is_not_batch(a: str) -> bool: 

258 return not _is_batch(a) 

259 

260 

261NonBatchAxisId = Annotated[AxisId, Predicate(_is_not_batch)] 

262 

263PostprocessingId = Literal[ 

264 "binarize", 

265 "clip", 

266 "ensure_dtype", 

267 "fixed_zero_mean_unit_variance", 

268 "scale_linear", 

269 "scale_mean_variance", 

270 "scale_range", 

271 "sigmoid", 

272 "zero_mean_unit_variance", 

273] 

274PreprocessingId = Literal[ 

275 "binarize", 

276 "clip", 

277 "ensure_dtype", 

278 "scale_linear", 

279 "sigmoid", 

280 "zero_mean_unit_variance", 

281 "scale_range", 

282] 

283 

284 

285SAME_AS_TYPE = "<same as type>" 

286 

287 

288ParameterizedSize_N = int 

289""" 

290Annotates an integer to calculate a concrete axis size from a `ParameterizedSize`. 

291""" 

292 

293 

294class ParameterizedSize(Node): 

295 """Describes a range of valid tensor axis sizes as `size = min + n*step`. 

296 

297 - **min** and **step** are given by the model description. 

298 - All blocksize parameters n = 0,1,2,... yield a valid `size`.

299 - A greater blocksize parameter n results in a greater **size**.

300 This allows the axis size to be adjusted generically.

301 """ 

302 

303 N: ClassVar[Type[int]] = ParameterizedSize_N 

304 """Positive integer to parameterize this axis""" 

305 

306 min: Annotated[int, Gt(0)] 

307 step: Annotated[int, Gt(0)] 

308 

309 def validate_size(self, size: int) -> int: 

310 if size < self.min: 

311 raise ValueError(f"size {size} < {self.min}") 

312 if (size - self.min) % self.step != 0: 

313 raise ValueError( 

314 f"axis of size {size} is not parameterized by `min + n*step` =" 

315 + f" `{self.min} + n*{self.step}`" 

316 ) 

317 

318 return size 

319 

320 def get_size(self, n: ParameterizedSize_N) -> int: 

321 return self.min + self.step * n 

322 

323 def get_n(self, s: int) -> ParameterizedSize_N: 

324 """return smallest n parameterizing a size greater or equal than `s`""" 

325 return ceil((s - self.min) / self.step) 

326 
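# A minimal usage sketch for `ParameterizedSize` (illustrative only):
#
#     p = ParameterizedSize(min=32, step=16)
#     assert p.get_size(2) == 64          # 32 + 2 * 16
#     assert p.get_n(60) == 2             # smallest n with 32 + n * 16 >= 60
#     assert p.validate_size(48) == 48    # 48 = 32 + 1 * 16 is a valid size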

327 

328class DataDependentSize(Node): 

329 min: Annotated[int, Gt(0)] = 1 

330 max: Annotated[Optional[int], Gt(1)] = None 

331 

332 @model_validator(mode="after") 

333 def _validate_max_gt_min(self): 

334 if self.max is not None and self.min >= self.max: 

335 raise ValueError(f"expected `min` < `max`, but got {self.min}, {self.max}") 

336 

337 return self 

338 

339 def validate_size(self, size: int) -> int: 

340 if size < self.min: 

341 raise ValueError(f"size {size} < {self.min}") 

342 

343 if self.max is not None and size > self.max: 

344 raise ValueError(f"size {size} > {self.max}") 

345 

346 return size 

347 
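# A minimal usage sketch for `DataDependentSize` (illustrative only):
#
#     s = DataDependentSize(min=1, max=256)
#     assert s.validate_size(128) == 128   # within [min, max]
#     # s.validate_size(300) would raise a ValueError ("size 300 > 256")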

348 

349class SizeReference(Node): 

350 """A tensor axis size (extent in pixels/frames) defined in relation to a reference axis. 

351 

352 `axis.size = reference.size * reference.scale / axis.scale + offset` 

353 

354 Note: 

355 1. The axis and the referenced axis need to have the same unit (or no unit). 

356 2. Batch axes may not be referenced. 

357 3. Fractions are rounded down. 

358 4. If the reference axis is `concatenable` the referencing axis is assumed to be 

359 `concatenable` as well with the same block order. 

360 

361 Example: 

362 An anisotropic input image of w*h=100*49 pixels depicts a physical space of 200*196 mm².

363 Let's assume that we want to express the image height h in relation to its width w 

364 instead of only accepting input images of exactly 100*49 pixels 

365 (for example to express a range of valid image shapes by parametrizing w, see `ParameterizedSize`). 

366 

367 >>> w = SpaceInputAxis(id=AxisId("w"), size=100, unit="millimeter", scale=2) 

368 >>> h = SpaceInputAxis( 

369 ... id=AxisId("h"), 

370 ... size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("w"), offset=-1), 

371 ... unit="millimeter", 

372 ... scale=4, 

373 ... ) 

374 >>> print(h.size.get_size(h, w)) 

375 49 

376 

377 ⇒ h = w * w.scale / h.scale + offset = 100 * 2mm / 4mm - 1 = 49 

378 """ 

379 

380 tensor_id: TensorId 

381 """tensor id of the reference axis""" 

382 

383 axis_id: AxisId 

384 """axis id of the reference axis""" 

385 

386 offset: StrictInt = 0 

387 

388 def get_size( 

389 self, 

390 axis: Union[ 

391 ChannelAxis, 

392 IndexInputAxis, 

393 IndexOutputAxis, 

394 TimeInputAxis, 

395 SpaceInputAxis, 

396 TimeOutputAxis, 

397 TimeOutputAxisWithHalo, 

398 SpaceOutputAxis, 

399 SpaceOutputAxisWithHalo, 

400 ], 

401 ref_axis: Union[ 

402 ChannelAxis, 

403 IndexInputAxis, 

404 IndexOutputAxis, 

405 TimeInputAxis, 

406 SpaceInputAxis, 

407 TimeOutputAxis, 

408 TimeOutputAxisWithHalo, 

409 SpaceOutputAxis, 

410 SpaceOutputAxisWithHalo, 

411 ], 

412 n: ParameterizedSize_N = 0, 

413 ref_size: Optional[int] = None, 

414 ): 

415 """Compute the concrete size for a given axis and its reference axis. 

416 

417 Args: 

418 axis: The axis this `SizeReference` is the size of. 

419 ref_axis: The reference axis to compute the size from. 

420 n: If the **ref_axis** is parameterized (of type `ParameterizedSize`) 

421 and no fixed **ref_size** is given, 

422 **n** is used to compute the size of the parameterized **ref_axis**. 

423 ref_size: Overwrite the reference size instead of deriving it from 

424 **ref_axis** 

425 (**ref_axis.scale** is still used; any given **n** is ignored). 

426 """ 

427 assert ( 

428 axis.size == self 

429 ), "Given `axis.size` is not defined by this `SizeReference`" 

430 

431 assert ( 

432 ref_axis.id == self.axis_id 

433 ), f"Expected `ref_axis.id` to be {self.axis_id}, but got {ref_axis.id}." 

434 

435 assert axis.unit == ref_axis.unit, ( 

436 "`SizeReference` requires `axis` and `ref_axis` to have the same `unit`," 

437 f" but {axis.unit}!={ref_axis.unit}" 

438 ) 

439 if ref_size is None: 

440 if isinstance(ref_axis.size, (int, float)): 

441 ref_size = ref_axis.size 

442 elif isinstance(ref_axis.size, ParameterizedSize): 

443 ref_size = ref_axis.size.get_size(n) 

444 elif isinstance(ref_axis.size, DataDependentSize): 

445 raise ValueError( 

446 "Reference axis referenced in `SizeReference` may not be a `DataDependentSize`." 

447 ) 

448 elif isinstance(ref_axis.size, SizeReference): 

449 raise ValueError( 

450 "Reference axis referenced in `SizeReference` may not be sized by a" 

451 + " `SizeReference` itself." 

452 ) 

453 else: 

454 assert_never(ref_axis.size) 

455 

456 return int(ref_size * ref_axis.scale / axis.scale + self.offset) 

457 

458 @staticmethod 

459 def _get_unit( 

460 axis: Union[ 

461 ChannelAxis, 

462 IndexInputAxis, 

463 IndexOutputAxis, 

464 TimeInputAxis, 

465 SpaceInputAxis, 

466 TimeOutputAxis, 

467 TimeOutputAxisWithHalo, 

468 SpaceOutputAxis, 

469 SpaceOutputAxisWithHalo, 

470 ], 

471 ): 

472 return axis.unit 

473 

474 

475class AxisBase(NodeWithExplicitlySetFields): 

476 id: AxisId 

477 """An axis id unique across all axes of one tensor.""" 

478 

479 description: Annotated[str, MaxLen(128)] = "" 

480 

481 

482class WithHalo(Node): 

483 halo: Annotated[int, Ge(1)] 

484 """The halo should be cropped from the output tensor to avoid boundary effects. 

485 It is to be cropped from both sides, i.e. `size_after_crop = size - 2 * halo`. 

486 To document a halo that is already cropped by the model use `size.offset` instead.""" 

487 

488 size: Annotated[ 

489 SizeReference, 

490 Field( 

491 examples=[ 

492 10, 

493 SizeReference( 

494 tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5 

495 ).model_dump(mode="json"), 

496 ] 

497 ), 

498 ] 

499 """reference to another axis with an optional offset (see `SizeReference`)""" 

500 
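# Worked example for the halo relation above (illustrative only): an output axis whose
# `size` resolves to 128 with `halo = 16` is cropped on both sides, so
# `size_after_crop = 128 - 2 * 16 = 96`.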

501 

502BATCH_AXIS_ID = AxisId("batch") 

503 

504 

505class BatchAxis(AxisBase): 

506 implemented_type: ClassVar[Literal["batch"]] = "batch" 

507 if TYPE_CHECKING: 

508 type: Literal["batch"] = "batch" 

509 else: 

510 type: Literal["batch"] 

511 

512 id: Annotated[AxisId, Predicate(_is_batch)] = BATCH_AXIS_ID 

513 size: Optional[Literal[1]] = None 

514 """The batch size may be fixed to 1, 

515 otherwise (the default) it may be chosen arbitrarily depending on available memory""" 

516 

517 @property 

518 def scale(self): 

519 return 1.0 

520 

521 @property 

522 def concatenable(self): 

523 return True 

524 

525 @property 

526 def unit(self): 

527 return None 

528 

529 

530class ChannelAxis(AxisBase): 

531 implemented_type: ClassVar[Literal["channel"]] = "channel" 

532 if TYPE_CHECKING: 

533 type: Literal["channel"] = "channel" 

534 else: 

535 type: Literal["channel"] 

536 

537 id: NonBatchAxisId = AxisId("channel") 

538 channel_names: NotEmpty[List[Identifier]] 

539 

540 @property 

541 def size(self) -> int: 

542 return len(self.channel_names) 

543 

544 @property 

545 def concatenable(self): 

546 return False 

547 

548 @property 

549 def scale(self) -> float: 

550 return 1.0 

551 

552 @property 

553 def unit(self): 

554 return None 

555 

556 

557class IndexAxisBase(AxisBase): 

558 implemented_type: ClassVar[Literal["index"]] = "index" 

559 if TYPE_CHECKING: 

560 type: Literal["index"] = "index" 

561 else: 

562 type: Literal["index"] 

563 

564 id: NonBatchAxisId = AxisId("index") 

565 

566 @property 

567 def scale(self) -> float: 

568 return 1.0 

569 

570 @property 

571 def unit(self): 

572 return None 

573 

574 

575class _WithInputAxisSize(Node): 

576 size: Annotated[ 

577 Union[Annotated[int, Gt(0)], ParameterizedSize, SizeReference], 

578 Field( 

579 examples=[ 

580 10, 

581 ParameterizedSize(min=32, step=16).model_dump(mode="json"), 

582 SizeReference( 

583 tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5 

584 ).model_dump(mode="json"), 

585 ] 

586 ), 

587 ] 

588 """The size/length of this axis can be specified as 

589 - fixed integer 

590 - parameterized series of valid sizes (`ParameterizedSize`) 

591 - reference to another axis with an optional offset (`SizeReference`) 

592 """ 

593 

594 

595class IndexInputAxis(IndexAxisBase, _WithInputAxisSize): 

596 concatenable: bool = False 

597 """If a model has a `concatenable` input axis, it can be processed blockwise, 

598 splitting a longer sample axis into blocks matching its input tensor description. 

599 Output axes are concatenable if they have a `SizeReference` to a concatenable 

600 input axis. 

601 """ 

602 

603 

604class IndexOutputAxis(IndexAxisBase): 

605 size: Annotated[ 

606 Union[Annotated[int, Gt(0)], SizeReference, DataDependentSize], 

607 Field( 

608 examples=[ 

609 10, 

610 SizeReference( 

611 tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5 

612 ).model_dump(mode="json"), 

613 ] 

614 ), 

615 ] 

616 """The size/length of this axis can be specified as 

617 - fixed integer 

618 - reference to another axis with an optional offset (`SizeReference`) 

619 - data dependent size using `DataDependentSize` (size is only known after model inference) 

620 """ 

621 

622 

623class TimeAxisBase(AxisBase): 

624 implemented_type: ClassVar[Literal["time"]] = "time" 

625 if TYPE_CHECKING: 

626 type: Literal["time"] = "time" 

627 else: 

628 type: Literal["time"] 

629 

630 id: NonBatchAxisId = AxisId("time") 

631 unit: Optional[TimeUnit] = None 

632 scale: Annotated[float, Gt(0)] = 1.0 

633 

634 

635class TimeInputAxis(TimeAxisBase, _WithInputAxisSize): 

636 concatenable: bool = False 

637 """If a model has a `concatenable` input axis, it can be processed blockwise, 

638 splitting a longer sample axis into blocks matching its input tensor description. 

639 Output axes are concatenable if they have a `SizeReference` to a concatenable 

640 input axis. 

641 """ 

642 

643 

644class SpaceAxisBase(AxisBase): 

645 implemented_type: ClassVar[Literal["space"]] = "space" 

646 if TYPE_CHECKING: 

647 type: Literal["space"] = "space" 

648 else: 

649 type: Literal["space"] 

650 

651 id: Annotated[NonBatchAxisId, Field(examples=["x", "y", "z"])] = AxisId("x") 

652 unit: Optional[SpaceUnit] = None 

653 scale: Annotated[float, Gt(0)] = 1.0 

654 

655 

656class SpaceInputAxis(SpaceAxisBase, _WithInputAxisSize): 

657 concatenable: bool = False 

658 """If a model has a `concatenable` input axis, it can be processed blockwise, 

659 splitting a longer sample axis into blocks matching its input tensor description. 

660 Output axes are concatenable if they have a `SizeReference` to a concatenable 

661 input axis. 

662 """ 

663 

664 

665INPUT_AXIS_TYPES = ( 

666 BatchAxis, 

667 ChannelAxis, 

668 IndexInputAxis, 

669 TimeInputAxis, 

670 SpaceInputAxis, 

671) 

672"""intended for isinstance comparisons in py<3.10""" 

673 

674_InputAxisUnion = Union[ 

675 BatchAxis, ChannelAxis, IndexInputAxis, TimeInputAxis, SpaceInputAxis 

676] 

677InputAxis = Annotated[_InputAxisUnion, Discriminator("type")] 

678 

679 

680class _WithOutputAxisSize(Node): 

681 size: Annotated[ 

682 Union[Annotated[int, Gt(0)], SizeReference], 

683 Field( 

684 examples=[ 

685 10, 

686 SizeReference( 

687 tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5 

688 ).model_dump(mode="json"), 

689 ] 

690 ), 

691 ] 

692 """The size/length of this axis can be specified as 

693 - fixed integer 

694 - reference to another axis with an optional offset (see `SizeReference`) 

695 """ 

696 

697 

698class TimeOutputAxis(TimeAxisBase, _WithOutputAxisSize): 

699 pass 

700 

701 

702class TimeOutputAxisWithHalo(TimeAxisBase, WithHalo): 

703 pass 

704 

705 

706def _get_halo_axis_discriminator_value(v: Any) -> Literal["with_halo", "wo_halo"]: 

707 if isinstance(v, dict): 

708 return "with_halo" if "halo" in v else "wo_halo" 

709 else: 

710 return "with_halo" if hasattr(v, "halo") else "wo_halo" 

711 
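# Illustrative only: how the tagged unions below discriminate raw output axis data.
#
#     _get_halo_axis_discriminator_value({"size": 10})              # -> "wo_halo"
#     _get_halo_axis_discriminator_value({"size": 10, "halo": 8})   # -> "with_halo"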

712 

713_TimeOutputAxisUnion = Annotated[ 

714 Union[ 

715 Annotated[TimeOutputAxis, Tag("wo_halo")], 

716 Annotated[TimeOutputAxisWithHalo, Tag("with_halo")], 

717 ], 

718 Discriminator(_get_halo_axis_discriminator_value), 

719] 

720 

721 

722class SpaceOutputAxis(SpaceAxisBase, _WithOutputAxisSize): 

723 pass 

724 

725 

726class SpaceOutputAxisWithHalo(SpaceAxisBase, WithHalo): 

727 pass 

728 

729 

730_SpaceOutputAxisUnion = Annotated[ 

731 Union[ 

732 Annotated[SpaceOutputAxis, Tag("wo_halo")], 

733 Annotated[SpaceOutputAxisWithHalo, Tag("with_halo")], 

734 ], 

735 Discriminator(_get_halo_axis_discriminator_value), 

736] 

737 

738 

739_OutputAxisUnion = Union[ 

740 BatchAxis, ChannelAxis, IndexOutputAxis, _TimeOutputAxisUnion, _SpaceOutputAxisUnion 

741] 

742OutputAxis = Annotated[_OutputAxisUnion, Discriminator("type")] 

743 

744OUTPUT_AXIS_TYPES = ( 

745 BatchAxis, 

746 ChannelAxis, 

747 IndexOutputAxis, 

748 TimeOutputAxis, 

749 TimeOutputAxisWithHalo, 

750 SpaceOutputAxis, 

751 SpaceOutputAxisWithHalo, 

752) 

753"""intended for isinstance comparisons in py<3.10""" 

754 

755 

756AnyAxis = Union[InputAxis, OutputAxis] 

757 

758ANY_AXIS_TYPES = INPUT_AXIS_TYPES + OUTPUT_AXIS_TYPES 

759"""intended for isinstance comparisons in py<3.10""" 

760 

761TVs = Union[ 

762 NotEmpty[List[int]], 

763 NotEmpty[List[float]], 

764 NotEmpty[List[bool]], 

765 NotEmpty[List[str]], 

766] 

767 

768 

769NominalOrOrdinalDType = Literal[ 

770 "float32", 

771 "float64", 

772 "uint8", 

773 "int8", 

774 "uint16", 

775 "int16", 

776 "uint32", 

777 "int32", 

778 "uint64", 

779 "int64", 

780 "bool", 

781] 

782 

783 

784class NominalOrOrdinalDataDescr(Node): 

785 values: TVs 

786 """A fixed set of nominal or an ascending sequence of ordinal values. 

787 In this case `data.type` is required to be an unsigned integer type, e.g. 'uint8'.

788 String `values` are interpreted as labels for tensor values 0, ..., N. 

789 Note: as YAML 1.2 does not natively support a "set" datatype, 

790 nominal values should be given as a sequence (aka list/array) as well. 

791 """ 

792 

793 type: Annotated[ 

794 NominalOrOrdinalDType, 

795 Field( 

796 examples=[ 

797 "float32", 

798 "uint8", 

799 "uint16", 

800 "int64", 

801 "bool", 

802 ], 

803 ), 

804 ] = "uint8" 

805 

806 @model_validator(mode="after") 

807 def _validate_values_match_type( 

808 self, 

809 ) -> Self: 

810 incompatible: List[Any] = [] 

811 for v in self.values: 

812 if self.type == "bool": 

813 if not isinstance(v, bool): 

814 incompatible.append(v) 

815 elif self.type in DTYPE_LIMITS: 

816 if ( 

817 isinstance(v, (int, float)) 

818 and ( 

819 v < DTYPE_LIMITS[self.type].min 

820 or v > DTYPE_LIMITS[self.type].max 

821 ) 

822 or (isinstance(v, str) and "uint" not in self.type) 

823 or (isinstance(v, float) and "int" in self.type) 

824 ): 

825 incompatible.append(v) 

826 else: 

827 incompatible.append(v) 

828 

829 if len(incompatible) == 5: 

830 incompatible.append("...") 

831 break 

832 

833 if incompatible: 

834 raise ValueError( 

835 f"data type '{self.type}' incompatible with values {incompatible}" 

836 ) 

837 

838 return self 

839 

840 unit: Optional[Union[Literal["arbitrary unit"], SiUnit]] = None 

841 

842 @property 

843 def range(self): 

844 if isinstance(self.values[0], str): 

845 return 0, len(self.values) - 1 

846 else: 

847 return min(self.values), max(self.values) 

848 
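# A minimal usage sketch for `NominalOrOrdinalDataDescr` (illustrative only):
#
#     labels = NominalOrOrdinalDataDescr(
#         values=["background", "cell", "border"], type="uint8"
#     )
#     assert labels.range == (0, 2)   # string values label the tensor values 0..N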

849 

850IntervalOrRatioDType = Literal[ 

851 "float32", 

852 "float64", 

853 "uint8", 

854 "int8", 

855 "uint16", 

856 "int16", 

857 "uint32", 

858 "int32", 

859 "uint64", 

860 "int64", 

861] 

862 

863 

864class IntervalOrRatioDataDescr(Node): 

865 type: Annotated[ # todo: rename to dtype 

866 IntervalOrRatioDType, 

867 Field( 

868 examples=["float32", "float64", "uint8", "uint16"], 

869 ), 

870 ] = "float32" 

871 range: Tuple[Optional[float], Optional[float]] = ( 

872 None, 

873 None, 

874 ) 

875 """Tuple `(minimum, maximum)` specifying the allowed range of the data in this tensor. 

876 `None` corresponds to min/max of what can be expressed by **type**.""" 

877 unit: Union[Literal["arbitrary unit"], SiUnit] = "arbitrary unit" 

878 scale: float = 1.0 

879 """Scale for data on an interval (or ratio) scale.""" 

880 offset: Optional[float] = None 

881 """Offset for data on a ratio scale.""" 

882 

883 @model_validator(mode="before") 

884 def _replace_inf(cls, data: Any): 

885 if is_dict(data): 

886 if "range" in data and is_sequence(data["range"]): 

887 forbidden = ( 

888 "inf", 

889 "-inf", 

890 ".inf", 

891 "-.inf", 

892 float("inf"), 

893 float("-inf"), 

894 ) 

895 if any(v in forbidden for v in data["range"]): 

896 issue_warning("replaced 'inf' value", value=data["range"]) 

897 

898 data["range"] = tuple( 

899 (None if v in forbidden else v) for v in data["range"] 

900 ) 

901 

902 return data 

903 
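# A minimal usage sketch for `IntervalOrRatioDataDescr` (illustrative only):
#
#     data = IntervalOrRatioDataDescr(type="uint8", range=(0, 255))
#     # a serialized `range` containing "inf"/"-inf" is replaced by None (= dtype min/max)
#     # with a warning by `_replace_inf` above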

904 

905TensorDataDescr = Union[NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr] 

906 

907 

908class ProcessingDescrBase(NodeWithExplicitlySetFields, ABC): 

909 """processing base class""" 

910 

911 

912class BinarizeKwargs(ProcessingKwargs): 

913 """key word arguments for `BinarizeDescr`""" 

914 

915 threshold: float 

916 """The fixed threshold""" 

917 

918 

919class BinarizeAlongAxisKwargs(ProcessingKwargs): 

920 """key word arguments for `BinarizeDescr`""" 

921 

922 threshold: NotEmpty[List[float]] 

923 """The fixed threshold values along `axis`""" 

924 

925 axis: Annotated[NonBatchAxisId, Field(examples=["channel"])] 

926 """The `threshold` axis""" 

927 

928 

929class BinarizeDescr(ProcessingDescrBase): 

930 """Binarize the tensor with a fixed threshold. 

931 

932 Values above `BinarizeKwargs.threshold`/`BinarizeAlongAxisKwargs.threshold` 

933 will be set to one, values below the threshold to zero. 

934 

935 Examples: 

936 - in YAML 

937 ```yaml 

938 postprocessing: 

939 - id: binarize 

940 kwargs: 

941 axis: 'channel' 

942 threshold: [0.25, 0.5, 0.75] 

943 ``` 

944 - in Python: 

945 >>> postprocessing = [BinarizeDescr( 

946 ... kwargs=BinarizeAlongAxisKwargs( 

947 ... axis=AxisId('channel'), 

948 ... threshold=[0.25, 0.5, 0.75], 

949 ... ) 

950 ... )] 

951 """ 

952 

953 implemented_id: ClassVar[Literal["binarize"]] = "binarize" 

954 if TYPE_CHECKING: 

955 id: Literal["binarize"] = "binarize" 

956 else: 

957 id: Literal["binarize"] 

958 kwargs: Union[BinarizeKwargs, BinarizeAlongAxisKwargs] 

959 

960 

961class ClipDescr(ProcessingDescrBase): 

962 """Set tensor values below min to min and above max to max. 

963 

964 See `ScaleRangeDescr` for examples. 

965 """ 

966 

967 implemented_id: ClassVar[Literal["clip"]] = "clip" 

968 if TYPE_CHECKING: 

969 id: Literal["clip"] = "clip" 

970 else: 

971 id: Literal["clip"] 

972 

973 kwargs: ClipKwargs 

974 

975 

976class EnsureDtypeKwargs(ProcessingKwargs): 

977 """key word arguments for `EnsureDtypeDescr`""" 

978 

979 dtype: Literal[ 

980 "float32", 

981 "float64", 

982 "uint8", 

983 "int8", 

984 "uint16", 

985 "int16", 

986 "uint32", 

987 "int32", 

988 "uint64", 

989 "int64", 

990 "bool", 

991 ] 

992 

993 

994class EnsureDtypeDescr(ProcessingDescrBase): 

995 """Cast the tensor data type to `EnsureDtypeKwargs.dtype` (if not matching). 

996 

997 This can for example be used to ensure the inner neural network model gets a 

998 different input tensor data type than the fully described bioimage.io model does. 

999 

1000 Examples: 

1001 The described bioimage.io model (incl. preprocessing) accepts any 

1002 float32-compatible tensor, normalizes it with percentiles and clipping and then 

1003 casts it to uint8, which is what the neural network in this example expects. 

1004 - in YAML 

1005 ```yaml 

1006 inputs: 

1007 - data: 

1008 type: float32 # described bioimage.io model is compatible with any float32 input tensor 

1009 preprocessing: 

1010 - id: scale_range 

1011 kwargs: 

1012 axes: ['y', 'x'] 

1013 max_percentile: 99.8 

1014 min_percentile: 5.0 

1015 - id: clip 

1016 kwargs: 

1017 min: 0.0 

1018 max: 1.0 

1019 - id: ensure_dtype # the neural network of the model requires uint8 

1020 kwargs: 

1021 dtype: uint8 

1022 ``` 

1023 - in Python: 

1024 >>> preprocessing = [ 

1025 ... ScaleRangeDescr( 

1026 ... kwargs=ScaleRangeKwargs( 

1027 ... axes= (AxisId('y'), AxisId('x')), 

1028 ... max_percentile= 99.8, 

1029 ... min_percentile= 5.0, 

1030 ... ) 

1031 ... ), 

1032 ... ClipDescr(kwargs=ClipKwargs(min=0.0, max=1.0)), 

1033 ... EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="uint8")), 

1034 ... ] 

1035 """ 

1036 

1037 implemented_id: ClassVar[Literal["ensure_dtype"]] = "ensure_dtype" 

1038 if TYPE_CHECKING: 

1039 id: Literal["ensure_dtype"] = "ensure_dtype" 

1040 else: 

1041 id: Literal["ensure_dtype"] 

1042 

1043 kwargs: EnsureDtypeKwargs 

1044 

1045 

1046class ScaleLinearKwargs(ProcessingKwargs): 

1047 """Key word arguments for `ScaleLinearDescr`""" 

1048 

1049 gain: float = 1.0 

1050 """multiplicative factor""" 

1051 

1052 offset: float = 0.0 

1053 """additive term""" 

1054 

1055 @model_validator(mode="after") 

1056 def _validate(self) -> Self: 

1057 if self.gain == 1.0 and self.offset == 0.0: 

1058 raise ValueError( 

1059 "Redundant linear scaling not allowd. Set `gain` != 1.0 and/or `offset`" 

1060 + " != 0.0." 

1061 ) 

1062 

1063 return self 

1064 

1065 

1066class ScaleLinearAlongAxisKwargs(ProcessingKwargs): 

1067 """Key word arguments for `ScaleLinearDescr`""" 

1068 

1069 axis: Annotated[NonBatchAxisId, Field(examples=["channel"])] 

1070 """The axis of gain and offset values.""" 

1071 

1072 gain: Union[float, NotEmpty[List[float]]] = 1.0 

1073 """multiplicative factor""" 

1074 

1075 offset: Union[float, NotEmpty[List[float]]] = 0.0 

1076 """additive term""" 

1077 

1078 @model_validator(mode="after") 

1079 def _validate(self) -> Self: 

1080 

1081 if isinstance(self.gain, list): 

1082 if isinstance(self.offset, list): 

1083 if len(self.gain) != len(self.offset): 

1084 raise ValueError( 

1085 f"Size of `gain` ({len(self.gain)}) and `offset` ({len(self.offset)}) must match." 

1086 ) 

1087 else: 

1088 self.offset = [float(self.offset)] * len(self.gain) 

1089 elif isinstance(self.offset, list): 

1090 self.gain = [float(self.gain)] * len(self.offset) 

1091 else: 

1092 raise ValueError( 

1093 "Do not specify an `axis` for scalar gain and offset values." 

1094 ) 

1095 

1096 if all(g == 1.0 for g in self.gain) and all(off == 0.0 for off in self.offset): 

1097 raise ValueError( 

1098 "Redundant linear scaling not allowd. Set `gain` != 1.0 and/or `offset`" 

1099 + " != 0.0." 

1100 ) 

1101 

1102 return self 

1103 
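# A minimal sketch of the gain/offset broadcasting above (illustrative only):
#
#     kw = ScaleLinearAlongAxisKwargs(axis=AxisId("channel"), gain=[1.0, 2.0, 3.0])
#     assert kw.offset == [0.0, 0.0, 0.0]   # the scalar default offset is broadcast to match `gain`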

1104 

1105class ScaleLinearDescr(ProcessingDescrBase): 

1106 """Fixed linear scaling. 

1107 

1108 Examples: 

1109 1. Scale with scalar gain and offset 

1110 - in YAML 

1111 ```yaml 

1112 preprocessing: 

1113 - id: scale_linear 

1114 kwargs: 

1115 gain: 2.0 

1116 offset: 3.0 

1117 ``` 

1118 - in Python: 

1119 >>> preprocessing = [ 

1120 ... ScaleLinearDescr(kwargs=ScaleLinearKwargs(gain= 2.0, offset=3.0)) 

1121 ... ] 

1122 

1123 2. Independent scaling along an axis 

1124 - in YAML 

1125 ```yaml 

1126 preprocessing: 

1127 - id: scale_linear 

1128 kwargs: 

1129 axis: 'channel' 

1130 gain: [1.0, 2.0, 3.0] 

1131 ``` 

1132 - in Python: 

1133 >>> preprocessing = [ 

1134 ... ScaleLinearDescr( 

1135 ... kwargs=ScaleLinearAlongAxisKwargs( 

1136 ... axis=AxisId("channel"), 

1137 ... gain=[1.0, 2.0, 3.0], 

1138 ... ) 

1139 ... ) 

1140 ... ] 

1141 

1142 """ 

1143 

1144 implemented_id: ClassVar[Literal["scale_linear"]] = "scale_linear" 

1145 if TYPE_CHECKING: 

1146 id: Literal["scale_linear"] = "scale_linear" 

1147 else: 

1148 id: Literal["scale_linear"] 

1149 kwargs: Union[ScaleLinearKwargs, ScaleLinearAlongAxisKwargs] 

1150 

1151 

1152class SigmoidDescr(ProcessingDescrBase): 

1153 """The logistic sigmoid funciton, a.k.a. expit function. 

1154 

1155 Examples: 

1156 - in YAML 

1157 ```yaml 

1158 postprocessing: 

1159 - id: sigmoid 

1160 ``` 

1161 - in Python: 

1162 >>> postprocessing = [SigmoidDescr()] 

1163 """ 

1164 

1165 implemented_id: ClassVar[Literal["sigmoid"]] = "sigmoid" 

1166 if TYPE_CHECKING: 

1167 id: Literal["sigmoid"] = "sigmoid" 

1168 else: 

1169 id: Literal["sigmoid"] 

1170 

1171 @property 

1172 def kwargs(self) -> ProcessingKwargs: 

1173 """empty kwargs""" 

1174 return ProcessingKwargs() 

1175 

1176 

1177class FixedZeroMeanUnitVarianceKwargs(ProcessingKwargs): 

1178 """key word arguments for `FixedZeroMeanUnitVarianceDescr`""" 

1179 

1180 mean: float 

1181 """The mean value to normalize with.""" 

1182 

1183 std: Annotated[float, Ge(1e-6)] 

1184 """The standard deviation value to normalize with.""" 

1185 

1186 

1187class FixedZeroMeanUnitVarianceAlongAxisKwargs(ProcessingKwargs): 

1188 """key word arguments for `FixedZeroMeanUnitVarianceDescr`""" 

1189 

1190 mean: NotEmpty[List[float]] 

1191 """The mean value(s) to normalize with.""" 

1192 

1193 std: NotEmpty[List[Annotated[float, Ge(1e-6)]]] 

1194 """The standard deviation value(s) to normalize with. 

1195 Size must match `mean` values.""" 

1196 

1197 axis: Annotated[NonBatchAxisId, Field(examples=["channel", "index"])] 

1198 """The axis of the mean/std values to normalize each entry along that dimension 

1199 separately.""" 

1200 

1201 @model_validator(mode="after") 

1202 def _mean_and_std_match(self) -> Self: 

1203 if len(self.mean) != len(self.std): 

1204 raise ValueError( 

1205 f"Size of `mean` ({len(self.mean)}) and `std` ({len(self.std)})" 

1206 + " must match." 

1207 ) 

1208 

1209 return self 

1210 

1211 

1212class FixedZeroMeanUnitVarianceDescr(ProcessingDescrBase): 

1213 """Subtract a given mean and divide by the standard deviation. 

1214 

1215 Normalize with fixed, precomputed values for 

1216 `FixedZeroMeanUnitVarianceKwargs.mean` and `FixedZeroMeanUnitVarianceKwargs.std`.

1217 Use `FixedZeroMeanUnitVarianceAlongAxisKwargs` for independent scaling along given 

1218 axes. 

1219 

1220 Examples: 

1221 1. scalar value for whole tensor 

1222 - in YAML 

1223 ```yaml 

1224 preprocessing: 

1225 - id: fixed_zero_mean_unit_variance 

1226 kwargs: 

1227 mean: 103.5 

1228 std: 13.7 

1229 ``` 

1230 - in Python 

1231 >>> preprocessing = [FixedZeroMeanUnitVarianceDescr( 

1232 ... kwargs=FixedZeroMeanUnitVarianceKwargs(mean=103.5, std=13.7) 

1233 ... )] 

1234 

1235 2. independently along an axis 

1236 - in YAML 

1237 ```yaml 

1238 preprocessing: 

1239 - id: fixed_zero_mean_unit_variance 

1240 kwargs: 

1241 axis: channel 

1242 mean: [101.5, 102.5, 103.5] 

1243 std: [11.7, 12.7, 13.7] 

1244 ``` 

1245 - in Python 

1246 >>> preprocessing = [FixedZeroMeanUnitVarianceDescr( 

1247 ... kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs( 

1248 ... axis=AxisId("channel"), 

1249 ... mean=[101.5, 102.5, 103.5], 

1250 ... std=[11.7, 12.7, 13.7], 

1251 ... ) 

1252 ... )] 

1253 """ 

1254 

1255 implemented_id: ClassVar[Literal["fixed_zero_mean_unit_variance"]] = ( 

1256 "fixed_zero_mean_unit_variance" 

1257 ) 

1258 if TYPE_CHECKING: 

1259 id: Literal["fixed_zero_mean_unit_variance"] = "fixed_zero_mean_unit_variance" 

1260 else: 

1261 id: Literal["fixed_zero_mean_unit_variance"] 

1262 

1263 kwargs: Union[ 

1264 FixedZeroMeanUnitVarianceKwargs, FixedZeroMeanUnitVarianceAlongAxisKwargs 

1265 ] 

1266 

1267 

1268class ZeroMeanUnitVarianceKwargs(ProcessingKwargs): 

1269 """key word arguments for `ZeroMeanUnitVarianceDescr`""" 

1270 

1271 axes: Annotated[ 

1272 Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")]) 

1273 ] = None 

1274 """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std. 

1275 For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x') 

1276 resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`. 

1277 To normalize each sample independently leave out the 'batch' axis. 

1278 Default: Scale all axes jointly.""" 

1279 

1280 eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6 

1281 """epsilon for numeric stability: `out = (tensor - mean) / (std + eps)`.""" 

1282 

1283 

1284class ZeroMeanUnitVarianceDescr(ProcessingDescrBase): 

1285 """Subtract mean and divide by variance. 

1286 

1287 Examples: 

1288 Normalize a whole tensor to zero mean and unit variance

1289 - in YAML 

1290 ```yaml 

1291 preprocessing: 

1292 - id: zero_mean_unit_variance 

1293 ``` 

1294 - in Python 

1295 >>> preprocessing = [ZeroMeanUnitVarianceDescr()] 

1296 """ 

1297 

1298 implemented_id: ClassVar[Literal["zero_mean_unit_variance"]] = ( 

1299 "zero_mean_unit_variance" 

1300 ) 

1301 if TYPE_CHECKING: 

1302 id: Literal["zero_mean_unit_variance"] = "zero_mean_unit_variance" 

1303 else: 

1304 id: Literal["zero_mean_unit_variance"] 

1305 

1306 kwargs: ZeroMeanUnitVarianceKwargs = Field( 

1307 default_factory=ZeroMeanUnitVarianceKwargs 

1308 ) 

1309 

1310 

1311class ScaleRangeKwargs(ProcessingKwargs): 

1312 """key word arguments for `ScaleRangeDescr` 

1313 

1314 For `min_percentile`=0.0 (the default) and `max_percentile`=100 (the default) 

1315 this processing step normalizes data to the [0, 1] interval.

1316 For other percentiles the normalized values will partially be outside the [0, 1]

1317 interval. Use `ScaleRangeDescr` followed by `ClipDescr` if you want to limit the

1318 normalized values to a range. 

1319 """ 

1320 

1321 axes: Annotated[ 

1322 Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")]) 

1323 ] = None 

1324 """The subset of axes to normalize jointly, i.e. axes to reduce to compute the min/max percentile value. 

1325 For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x') 

1326 resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`. 

1327 To normalize samples independently, leave out the "batch" axis. 

1328 Default: Scale all axes jointly.""" 

1329 

1330 min_percentile: Annotated[float, Interval(ge=0, lt=100)] = 0.0 

1331 """The lower percentile used to determine the value to align with zero.""" 

1332 

1333 max_percentile: Annotated[float, Interval(gt=1, le=100)] = 100.0 

1334 """The upper percentile used to determine the value to align with one. 

1335 Has to be bigger than `min_percentile`. 

1336 The range is 1 to 100 instead of 0 to 100 to avoid mistakenly 

1337 accepting percentiles specified in the range 0.0 to 1.0.""" 

1338 

1339 eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6 

1340 """Epsilon for numeric stability. 

1341 `out = (tensor - v_lower) / (v_upper - v_lower + eps)`; 

1342 with `v_lower,v_upper` values at the respective percentiles.""" 

1343 

1344 reference_tensor: Optional[TensorId] = None 

1345 """Tensor ID to compute the percentiles from. Default: The tensor itself. 

1346 For any tensor in `inputs` only input tensor references are allowed.""" 

1347 

1348 @field_validator("max_percentile", mode="after") 

1349 @classmethod 

1350 def min_smaller_max(cls, value: float, info: ValidationInfo) -> float: 

1351 if (min_p := info.data["min_percentile"]) >= value: 

1352 raise ValueError(f"min_percentile {min_p} >= max_percentile {value}") 

1353 

1354 return value 

1355 

1356 

1357class ScaleRangeDescr(ProcessingDescrBase): 

1358 """Scale with percentiles. 

1359 

1360 Examples: 

1361 1. Scale linearly to map 5th percentile to 0 and 99.8th percentile to 1.0 

1362 - in YAML 

1363 ```yaml 

1364 preprocessing: 

1365 - id: scale_range 

1366 kwargs: 

1367 axes: ['y', 'x'] 

1368 max_percentile: 99.8 

1369 min_percentile: 5.0 

1370 ``` 

1371 - in Python 

1372 >>> preprocessing = [ 

1373 ... ScaleRangeDescr( 

1374 ... kwargs=ScaleRangeKwargs( 

1375 ... axes= (AxisId('y'), AxisId('x')), 

1376 ... max_percentile= 99.8, 

1377 ... min_percentile= 5.0, 

1378 ... ) 

1379 ... ), 

1380 ... ClipDescr( 

1381 ... kwargs=ClipKwargs( 

1382 ... min=0.0, 

1383 ... max=1.0, 

1384 ... ) 

1385 ... ), 

1386 ... ] 

1387 

1388 2. Combine the above scaling with additional clipping to clip values outside the range given by the percentiles. 

1389 - in YAML 

1390 ```yaml 

1391 preprocessing: 

1392 - id: scale_range 

1393 kwargs: 

1394 axes: ['y', 'x'] 

1395 max_percentile: 99.8 

1396 min_percentile: 5.0 

1398 - id: clip 

1399 kwargs: 

1400 min: 0.0 

1401 max: 1.0 

1402 ``` 

1403 - in Python 

1404 >>> preprocessing = [ScaleRangeDescr( 

1405 ... kwargs=ScaleRangeKwargs( 

1406 ... axes= (AxisId('y'), AxisId('x')), 

1407 ... max_percentile= 99.8, 

1408 ... min_percentile= 5.0, 

1409 ... ) 

1410 ... )] 

1411 

1412 """ 

1413 

1414 implemented_id: ClassVar[Literal["scale_range"]] = "scale_range" 

1415 if TYPE_CHECKING: 

1416 id: Literal["scale_range"] = "scale_range" 

1417 else: 

1418 id: Literal["scale_range"] 

1419 kwargs: ScaleRangeKwargs 

1420 

1421 

1422class ScaleMeanVarianceKwargs(ProcessingKwargs): 

1423 """key word arguments for `ScaleMeanVarianceKwargs`""" 

1424 

1425 reference_tensor: TensorId 

1426 """Name of tensor to match.""" 

1427 

1428 axes: Annotated[ 

1429 Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")]) 

1430 ] = None 

1431 """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std. 

1432 For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x') 

1433 resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`. 

1434 To normalize samples independently, leave out the 'batch' axis. 

1435 Default: Scale all axes jointly.""" 

1436 

1437 eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6 

1438 """Epsilon for numeric stability: 

1439 `out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.`""" 

1440 

1441 

1442class ScaleMeanVarianceDescr(ProcessingDescrBase): 

1443 """Scale a tensor's data distribution to match another tensor's mean/std. 

1444 `out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.` 

1445 """ 

1446 

1447 implemented_id: ClassVar[Literal["scale_mean_variance"]] = "scale_mean_variance" 

1448 if TYPE_CHECKING: 

1449 id: Literal["scale_mean_variance"] = "scale_mean_variance" 

1450 else: 

1451 id: Literal["scale_mean_variance"] 

1452 kwargs: ScaleMeanVarianceKwargs 

1453 
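# A minimal usage sketch for `ScaleMeanVarianceDescr` (illustrative only; the tensor id
# "input" is a placeholder):
#
#     postprocessing = [
#         ScaleMeanVarianceDescr(
#             kwargs=ScaleMeanVarianceKwargs(
#                 reference_tensor=TensorId("input"),
#                 axes=(AxisId("y"), AxisId("x")),
#             )
#         )
#     ]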

1454 

1455PreprocessingDescr = Annotated[ 

1456 Union[ 

1457 BinarizeDescr, 

1458 ClipDescr, 

1459 EnsureDtypeDescr, 

1460 ScaleLinearDescr, 

1461 SigmoidDescr, 

1462 FixedZeroMeanUnitVarianceDescr, 

1463 ZeroMeanUnitVarianceDescr, 

1464 ScaleRangeDescr, 

1465 ], 

1466 Discriminator("id"), 

1467] 

1468PostprocessingDescr = Annotated[ 

1469 Union[ 

1470 BinarizeDescr, 

1471 ClipDescr, 

1472 EnsureDtypeDescr, 

1473 ScaleLinearDescr, 

1474 SigmoidDescr, 

1475 FixedZeroMeanUnitVarianceDescr, 

1476 ZeroMeanUnitVarianceDescr, 

1477 ScaleRangeDescr, 

1478 ScaleMeanVarianceDescr, 

1479 ], 

1480 Discriminator("id"), 

1481] 

1482 

1483IO_AxisT = TypeVar("IO_AxisT", InputAxis, OutputAxis) 

1484 

1485 

1486class TensorDescrBase(Node, Generic[IO_AxisT]): 

1487 id: TensorId 

1488 """Tensor id. No duplicates are allowed.""" 

1489 

1490 description: Annotated[str, MaxLen(128)] = "" 

1491 """free text description""" 

1492 

1493 axes: NotEmpty[Sequence[IO_AxisT]] 

1494 """tensor axes""" 

1495 

1496 @property 

1497 def shape(self): 

1498 return tuple(a.size for a in self.axes) 

1499 

1500 @field_validator("axes", mode="after", check_fields=False) 

1501 @classmethod 

1502 def _validate_axes(cls, axes: Sequence[AnyAxis]) -> Sequence[AnyAxis]: 

1503 batch_axes = [a for a in axes if a.type == "batch"] 

1504 if len(batch_axes) > 1: 

1505 raise ValueError( 

1506 f"Only one batch axis (per tensor) allowed, but got {batch_axes}" 

1507 ) 

1508 

1509 seen_ids: Set[AxisId] = set() 

1510 duplicate_axes_ids: Set[AxisId] = set() 

1511 for a in axes: 

1512 (duplicate_axes_ids if a.id in seen_ids else seen_ids).add(a.id) 

1513 

1514 if duplicate_axes_ids: 

1515 raise ValueError(f"Duplicate axis ids: {duplicate_axes_ids}") 

1516 

1517 return axes 

1518 

1519 test_tensor: FileDescr_ 

1520 """An example tensor to use for testing. 

1521 Using the model with the test input tensors is expected to yield the test output tensors. 

1522 Each test tensor has to be an ndarray in the

1523 [numpy.lib file format](https://numpy.org/doc/stable/reference/generated/numpy.lib.format.html#module-numpy.lib.format). 

1524 The file extension must be '.npy'.""" 

1525 

1526 sample_tensor: Optional[FileDescr_] = None 

1527 """A sample tensor to illustrate a possible input/output for the model, 

1528 The sample image primarily serves to inform a human user about an example use case 

1529 and is typically stored as .hdf5, .png or .tiff. 

1530 It has to be readable by the [imageio library](https://imageio.readthedocs.io/en/stable/formats/index.html#supported-formats) 

1531 (numpy's `.npy` format is not supported). 

1532 The image dimensionality has to match the number of axes specified in this tensor description. 

1533 """ 

1534 

1535 @model_validator(mode="after") 

1536 def _validate_sample_tensor(self) -> Self: 

1537 if self.sample_tensor is None or not get_validation_context().perform_io_checks: 

1538 return self 

1539 

1540 reader = get_reader(self.sample_tensor.source, sha256=self.sample_tensor.sha256) 

1541 tensor: NDArray[Any] = imread( 

1542 reader.read(), 

1543 extension=PurePosixPath(reader.original_file_name).suffix, 

1544 ) 

1545 n_dims = len(tensor.squeeze().shape) 

1546 n_dims_min = n_dims_max = len(self.axes) 

1547 

1548 for a in self.axes: 

1549 if isinstance(a, BatchAxis): 

1550 n_dims_min -= 1 

1551 elif isinstance(a.size, int): 

1552 if a.size == 1: 

1553 n_dims_min -= 1 

1554 elif isinstance(a.size, (ParameterizedSize, DataDependentSize)): 

1555 if a.size.min == 1: 

1556 n_dims_min -= 1 

1557 elif isinstance(a.size, SizeReference): 

1558 if a.size.offset < 2: 

1559 # size reference may result in singleton axis 

1560 n_dims_min -= 1 

1561 else: 

1562 assert_never(a.size) 

1563 

1564 n_dims_min = max(0, n_dims_min) 

1565 if n_dims < n_dims_min or n_dims > n_dims_max: 

1566 raise ValueError( 

1567 f"Expected sample tensor to have {n_dims_min} to" 

1568 + f" {n_dims_max} dimensions, but found {n_dims} (shape: {tensor.shape})." 

1569 ) 

1570 

1571 return self 

1572 

1573 data: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]] = ( 

1574 IntervalOrRatioDataDescr() 

1575 ) 

1576 """Description of the tensor's data values, optionally per channel. 

1577 If specified per channel, the data `type` needs to match across channels.""" 

1578 

1579 @property 

1580 def dtype( 

1581 self, 

1582 ) -> Literal[ 

1583 "float32", 

1584 "float64", 

1585 "uint8", 

1586 "int8", 

1587 "uint16", 

1588 "int16", 

1589 "uint32", 

1590 "int32", 

1591 "uint64", 

1592 "int64", 

1593 "bool", 

1594 ]: 

1595 """dtype as specified under `data.type` or `data[i].type`""" 

1596 if isinstance(self.data, collections.abc.Sequence): 

1597 return self.data[0].type 

1598 else: 

1599 return self.data.type 

1600 

1601 @field_validator("data", mode="after") 

1602 @classmethod 

1603 def _check_data_type_across_channels( 

1604 cls, value: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]] 

1605 ) -> Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]]: 

1606 if not isinstance(value, list): 

1607 return value 

1608 

1609 dtypes = {t.type for t in value} 

1610 if len(dtypes) > 1: 

1611 raise ValueError( 

1612 "Tensor data descriptions per channel need to agree in their data" 

1613 + f" `type`, but found {dtypes}." 

1614 ) 

1615 

1616 return value 

1617 

1618 @model_validator(mode="after") 

1619 def _check_data_matches_channelaxis(self) -> Self: 

1620 if not isinstance(self.data, (list, tuple)): 

1621 return self 

1622 

1623 for a in self.axes: 

1624 if isinstance(a, ChannelAxis): 

1625 size = a.size 

1626 assert isinstance(size, int) 

1627 break 

1628 else: 

1629 return self 

1630 

1631 if len(self.data) != size: 

1632 raise ValueError( 

1633 f"Got tensor data descriptions for {len(self.data)} channels, but" 

1634 + f" '{a.id}' axis has size {size}." 

1635 ) 

1636 

1637 return self 

1638 

1639 def get_axis_sizes_for_array(self, array: NDArray[Any]) -> Dict[AxisId, int]: 

1640 if len(array.shape) != len(self.axes): 

1641 raise ValueError( 

1642 f"Dimension mismatch: array shape {array.shape} (#{len(array.shape)})" 

1643 + f" incompatible with {len(self.axes)} axes." 

1644 ) 

1645 return {a.id: array.shape[i] for i, a in enumerate(self.axes)} 

1646 
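# Illustrative only: `get_axis_sizes_for_array` maps positional array dimensions to axis ids.
# For a tensor described with axes (batch, channel, y, x) and an array of shape
# (1, 3, 512, 512) it returns {AxisId("batch"): 1, AxisId("channel"): 3, AxisId("y"): 512, AxisId("x"): 512}.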

1647 

1648class InputTensorDescr(TensorDescrBase[InputAxis]): 

1649 id: TensorId = TensorId("input") 

1650 """Input tensor id. 

1651 No duplicates are allowed across all inputs and outputs.""" 

1652 

1653 optional: bool = False 

1654 """indicates that this tensor may be `None`""" 

1655 

1656 preprocessing: List[PreprocessingDescr] = Field( 

1657 default_factory=cast(Callable[[], List[PreprocessingDescr]], list) 

1658 ) 

1659 

1660 """Description of how this input should be preprocessed. 

1661 

1662 notes: 

1663 - If preprocessing does not start with an 'ensure_dtype' entry, it is added 

1664 to ensure an input tensor's data type matches the input tensor's data description. 

1665 - If preprocessing does not end with an 'ensure_dtype' or 'binarize' entry, an 

1666 'ensure_dtype' step is added to ensure preprocessing steps are not unintentionally 

1667 changing the data type. 

1668 """ 

1669 

1670 @model_validator(mode="after") 

1671 def _validate_preprocessing_kwargs(self) -> Self: 

1672 axes_ids = [a.id for a in self.axes] 

1673 for p in self.preprocessing: 

1674 kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes") 

1675 if kwargs_axes is None: 

1676 continue 

1677 

1678 if not isinstance(kwargs_axes, collections.abc.Sequence): 

1679 raise ValueError( 

1680 f"Expected `preprocessing.i.kwargs.axes` to be a sequence, but got {type(kwargs_axes)}" 

1681 ) 

1682 

1683 if any(a not in axes_ids for a in kwargs_axes): 

1684 raise ValueError( 

1685 "`preprocessing.i.kwargs.axes` needs to be subset of axes ids" 

1686 ) 

1687 

1688 if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)): 

1689 dtype = self.data.type 

1690 else: 

1691 dtype = self.data[0].type 

1692 

1693 # ensure `preprocessing` begins with `EnsureDtypeDescr` 

1694 if not self.preprocessing or not isinstance( 

1695 self.preprocessing[0], EnsureDtypeDescr 

1696 ): 

1697 self.preprocessing.insert( 

1698 0, EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype)) 

1699 ) 

1700 

1701 # ensure `preprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr` 

1702 if not isinstance(self.preprocessing[-1], (EnsureDtypeDescr, BinarizeDescr)): 

1703 self.preprocessing.append( 

1704 EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype)) 

1705 ) 

1706 

1707 return self 

1708 

1709 

1710def convert_axes( 

1711 axes: str, 

1712 *, 

1713 shape: Union[ 

1714 Sequence[int], _ParameterizedInputShape_v0_4, _ImplicitOutputShape_v0_4 

1715 ], 

1716 tensor_type: Literal["input", "output"], 

1717 halo: Optional[Sequence[int]], 

1718 size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]], 

1719): 

1720 ret: List[AnyAxis] = [] 

1721 for i, a in enumerate(axes): 

1722 axis_type = _AXIS_TYPE_MAP.get(a, a) 

1723 if axis_type == "batch": 

1724 ret.append(BatchAxis()) 

1725 continue 

1726 

1727 scale = 1.0 

1728 if isinstance(shape, _ParameterizedInputShape_v0_4): 

1729 if shape.step[i] == 0: 

1730 size = shape.min[i] 

1731 else: 

1732 size = ParameterizedSize(min=shape.min[i], step=shape.step[i]) 

1733 elif isinstance(shape, _ImplicitOutputShape_v0_4): 

1734 ref_t = str(shape.reference_tensor) 

1735 if ref_t.count(".") == 1: 

1736 t_id, orig_a_id = ref_t.split(".") 

1737 else: 

1738 t_id = ref_t 

1739 orig_a_id = a 

1740 

1741 a_id = _AXIS_ID_MAP.get(orig_a_id, a) 

1742 if not (orig_scale := shape.scale[i]): 

1743 # old way to insert a new axis dimension 

1744 size = int(2 * shape.offset[i]) 

1745 else: 

1746 scale = 1 / orig_scale 

1747 if axis_type in ("channel", "index"): 

1748 # these axes no longer have a scale 

1749 offset_from_scale = orig_scale * size_refs.get( 

1750 _TensorName_v0_4(t_id), {} 

1751 ).get(orig_a_id, 0) 

1752 else: 

1753 offset_from_scale = 0 

1754 size = SizeReference( 

1755 tensor_id=TensorId(t_id), 

1756 axis_id=AxisId(a_id), 

1757 offset=int(offset_from_scale + 2 * shape.offset[i]), 

1758 ) 

1759 else: 

1760 size = shape[i] 

1761 

1762 if axis_type == "time": 

1763 if tensor_type == "input": 

1764 ret.append(TimeInputAxis(size=size, scale=scale)) 

1765 else: 

1766 assert not isinstance(size, ParameterizedSize) 

1767 if halo is None: 

1768 ret.append(TimeOutputAxis(size=size, scale=scale)) 

1769 else: 

1770 assert not isinstance(size, int) 

1771 ret.append( 

1772 TimeOutputAxisWithHalo(size=size, scale=scale, halo=halo[i]) 

1773 ) 

1774 

1775 elif axis_type == "index": 

1776 if tensor_type == "input": 

1777 ret.append(IndexInputAxis(size=size)) 

1778 else: 

1779 if isinstance(size, ParameterizedSize): 

1780 size = DataDependentSize(min=size.min) 

1781 

1782 ret.append(IndexOutputAxis(size=size)) 

1783 elif axis_type == "channel": 

1784 assert not isinstance(size, ParameterizedSize) 

1785 if isinstance(size, SizeReference): 

1786 warnings.warn( 

1787 "Conversion of channel size from an implicit output shape may be" 

1788 + " wrong" 

1789 ) 

1790 ret.append( 

1791 ChannelAxis( 

1792 channel_names=[ 

1793 Identifier(f"channel{i}") for i in range(size.offset) 

1794 ] 

1795 ) 

1796 ) 

1797 else: 

1798 ret.append( 

1799 ChannelAxis( 

1800 channel_names=[Identifier(f"channel{i}") for i in range(size)] 

1801 ) 

1802 ) 

1803 elif axis_type == "space": 

1804 if tensor_type == "input": 

1805 ret.append(SpaceInputAxis(id=AxisId(a), size=size, scale=scale)) 

1806 else: 

1807 assert not isinstance(size, ParameterizedSize) 

1808 if halo is None or halo[i] == 0: 

1809 ret.append(SpaceOutputAxis(id=AxisId(a), size=size, scale=scale)) 

1810 elif isinstance(size, int): 

1811 raise NotImplementedError( 

1812 f"output axis with halo and fixed size (here {size}) not allowed" 

1813 ) 

1814 else: 

1815 ret.append( 

1816 SpaceOutputAxisWithHalo( 

1817 id=AxisId(a), size=size, scale=scale, halo=halo[i] 

1818 ) 

1819 ) 

1820 

1821 return ret 

1822 
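# A minimal usage sketch for `convert_axes` (illustrative only): converting the v0.4
# axes string "bcyx" with a fixed input shape to v0.5 axis descriptions.
#
#     axes = convert_axes(
#         "bcyx",
#         shape=[1, 3, 256, 256],
#         tensor_type="input",
#         halo=None,
#         size_refs={},
#     )
#     # -> [BatchAxis(), ChannelAxis(channel_names=[...]),
#     #     SpaceInputAxis(id=AxisId("y"), size=256), SpaceInputAxis(id=AxisId("x"), size=256)]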

1823 

1824def _axes_letters_to_ids( 

1825 axes: Optional[str], 

1826) -> Optional[List[AxisId]]: 

1827 if axes is None: 

1828 return None 

1829 

1830 return [AxisId(a) for a in axes] 

1831 

1832 

1833def _get_complement_v04_axis( 

1834 tensor_axes: Sequence[str], axes: Optional[Sequence[str]] 

1835) -> Optional[AxisId]: 

1836 if axes is None: 

1837 return None 

1838 

1839 non_complement_axes = set(axes) | {"b"} 

1840 complement_axes = [a for a in tensor_axes if a not in non_complement_axes] 

1841 if len(complement_axes) > 1: 

1842 raise ValueError( 

1843 f"Expected none or a single complement axis, but axes '{axes}' " 

1844 + f"for tensor dims '{tensor_axes}' leave '{complement_axes}'." 

1845 ) 

1846 

1847 return None if not complement_axes else AxisId(complement_axes[0]) 

1848 
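# Illustrative example: _get_complement_v04_axis("bcyx", ["x", "y"]) removes the
# given axes and the implicit batch axis "b", leaving only "c", so AxisId("c") is
# returned; if more than one axis remained, a ValueError would be raised.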

1849 

1850def _convert_proc( 

1851 p: Union[_PreprocessingDescr_v0_4, _PostprocessingDescr_v0_4], 

1852 tensor_axes: Sequence[str], 

1853) -> Union[PreprocessingDescr, PostprocessingDescr]: 

1854 if isinstance(p, _BinarizeDescr_v0_4): 

1855 return BinarizeDescr(kwargs=BinarizeKwargs(threshold=p.kwargs.threshold)) 

1856 elif isinstance(p, _ClipDescr_v0_4): 

1857 return ClipDescr(kwargs=ClipKwargs(min=p.kwargs.min, max=p.kwargs.max)) 

1858 elif isinstance(p, _SigmoidDescr_v0_4): 

1859 return SigmoidDescr() 

1860 elif isinstance(p, _ScaleLinearDescr_v0_4): 

1861 axes = _axes_letters_to_ids(p.kwargs.axes) 

1862 if p.kwargs.axes is None: 

1863 axis = None 

1864 else: 

1865 axis = _get_complement_v04_axis(tensor_axes, p.kwargs.axes) 

1866 

1867 if axis is None: 

1868 assert not isinstance(p.kwargs.gain, list) 

1869 assert not isinstance(p.kwargs.offset, list) 

1870 kwargs = ScaleLinearKwargs(gain=p.kwargs.gain, offset=p.kwargs.offset) 

1871 else: 

1872 kwargs = ScaleLinearAlongAxisKwargs( 

1873 axis=axis, gain=p.kwargs.gain, offset=p.kwargs.offset 

1874 ) 

1875 return ScaleLinearDescr(kwargs=kwargs) 

1876 elif isinstance(p, _ScaleMeanVarianceDescr_v0_4): 

1877 return ScaleMeanVarianceDescr( 

1878 kwargs=ScaleMeanVarianceKwargs( 

1879 axes=_axes_letters_to_ids(p.kwargs.axes), 

1880 reference_tensor=TensorId(str(p.kwargs.reference_tensor)), 

1881 eps=p.kwargs.eps, 

1882 ) 

1883 ) 

1884 elif isinstance(p, _ZeroMeanUnitVarianceDescr_v0_4): 

1885 if p.kwargs.mode == "fixed": 

1886 mean = p.kwargs.mean 

1887 std = p.kwargs.std 

1888 assert mean is not None 

1889 assert std is not None 

1890 

1891 axis = _get_complement_v04_axis(tensor_axes, p.kwargs.axes) 

1892 

1893 if axis is None: 

1894 return FixedZeroMeanUnitVarianceDescr( 

1895 kwargs=FixedZeroMeanUnitVarianceKwargs( 

1896 mean=mean, std=std # pyright: ignore[reportArgumentType] 

1897 ) 

1898 ) 

1899 else: 

1900 if not isinstance(mean, list): 

1901 mean = [float(mean)] 

1902 if not isinstance(std, list): 

1903 std = [float(std)] 

1904 

1905 return FixedZeroMeanUnitVarianceDescr( 

1906 kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs( 

1907 axis=axis, mean=mean, std=std 

1908 ) 

1909 ) 

1910 

1911 else: 

1912 axes = _axes_letters_to_ids(p.kwargs.axes) or [] 

1913 if p.kwargs.mode == "per_dataset": 

1914 axes = [AxisId("batch")] + axes 

1915 if not axes: 

1916 axes = None 

1917 return ZeroMeanUnitVarianceDescr( 

1918 kwargs=ZeroMeanUnitVarianceKwargs(axes=axes, eps=p.kwargs.eps) 

1919 ) 

1920 

1921 elif isinstance(p, _ScaleRangeDescr_v0_4): 

1922 return ScaleRangeDescr( 

1923 kwargs=ScaleRangeKwargs( 

1924 axes=_axes_letters_to_ids(p.kwargs.axes), 

1925 min_percentile=p.kwargs.min_percentile, 

1926 max_percentile=p.kwargs.max_percentile, 

1927 eps=p.kwargs.eps, 

1928 ) 

1929 ) 

1930 else: 

1931 assert_never(p) 

1932 

1933 

1934class _InputTensorConv( 

1935 Converter[ 

1936 _InputTensorDescr_v0_4, 

1937 InputTensorDescr, 

1938 FileSource_, 

1939 Optional[FileSource_], 

1940 Mapping[_TensorName_v0_4, Mapping[str, int]], 

1941 ] 

1942): 

1943 def _convert( 

1944 self, 

1945 src: _InputTensorDescr_v0_4, 

1946 tgt: "type[InputTensorDescr] | type[dict[str, Any]]", 

1947 test_tensor: FileSource_, 

1948 sample_tensor: Optional[FileSource_], 

1949 size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]], 

1950 ) -> "InputTensorDescr | dict[str, Any]": 

1951 axes: List[InputAxis] = convert_axes( # pyright: ignore[reportAssignmentType] 

1952 src.axes, 

1953 shape=src.shape, 

1954 tensor_type="input", 

1955 halo=None, 

1956 size_refs=size_refs, 

1957 ) 

1958 prep: List[PreprocessingDescr] = [] 

1959 for p in src.preprocessing: 

1960 cp = _convert_proc(p, src.axes) 

1961 assert not isinstance(cp, ScaleMeanVarianceDescr) 

1962 prep.append(cp) 

1963 

1964 prep.append(EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="float32"))) 

1965 

1966 return tgt( 

1967 axes=axes, 

1968 id=TensorId(str(src.name)), 

1969 test_tensor=FileDescr(source=test_tensor), 

1970 sample_tensor=( 

1971 None if sample_tensor is None else FileDescr(source=sample_tensor) 

1972 ), 

1973 data=dict(type=src.data_type), # pyright: ignore[reportArgumentType] 

1974 preprocessing=prep, 

1975 ) 

1976 

1977 

1978_input_tensor_conv = _InputTensorConv(_InputTensorDescr_v0_4, InputTensorDescr) 

1979 

1980 

1981class OutputTensorDescr(TensorDescrBase[OutputAxis]): 

1982 id: TensorId = TensorId("output") 

1983 """Output tensor id. 

1984 No duplicates are allowed across all inputs and outputs.""" 

1985 

1986 postprocessing: List[PostprocessingDescr] = Field( 

1987 default_factory=cast(Callable[[], List[PostprocessingDescr]], list) 

1988 ) 

1989 """Description of how this output should be postprocessed. 

1990 

1991 note: `postprocessing` always ends with an 'ensure_dtype' operation. 

1992 If not given this is added to cast to this tensor's `data.type`. 

1993 """ 

1994 

1995 @model_validator(mode="after") 

1996 def _validate_postprocessing_kwargs(self) -> Self: 

1997 axes_ids = [a.id for a in self.axes] 

1998 for p in self.postprocessing: 

1999 kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes") 

2000 if kwargs_axes is None: 

2001 continue 

2002 

2003 if not isinstance(kwargs_axes, collections.abc.Sequence): 

2004 raise ValueError( 

2005 f"expected `axes` sequence, but got {type(kwargs_axes)}" 

2006 ) 

2007 

2008 if any(a not in axes_ids for a in kwargs_axes): 

2009 raise ValueError("`kwargs.axes` needs to be subset of axes ids") 

2010 

2011 if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)): 

2012 dtype = self.data.type 

2013 else: 

2014 dtype = self.data[0].type 

2015 

2016 # ensure `postprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr` 

2017 if not self.postprocessing or not isinstance( 

2018 self.postprocessing[-1], (EnsureDtypeDescr, BinarizeDescr) 

2019 ): 

2020 self.postprocessing.append( 

2021 EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype)) 

2022 ) 

2023 return self 

2024 

2025 

2026class _OutputTensorConv( 

2027 Converter[ 

2028 _OutputTensorDescr_v0_4, 

2029 OutputTensorDescr, 

2030 FileSource_, 

2031 Optional[FileSource_], 

2032 Mapping[_TensorName_v0_4, Mapping[str, int]], 

2033 ] 

2034): 

2035 def _convert( 

2036 self, 

2037 src: _OutputTensorDescr_v0_4, 

2038 tgt: "type[OutputTensorDescr] | type[dict[str, Any]]", 

2039 test_tensor: FileSource_, 

2040 sample_tensor: Optional[FileSource_], 

2041 size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]], 

2042 ) -> "OutputTensorDescr | dict[str, Any]": 

2043 # TODO: split convert_axes into convert_output_axes and convert_input_axes 

2044 axes: List[OutputAxis] = convert_axes( # pyright: ignore[reportAssignmentType] 

2045 src.axes, 

2046 shape=src.shape, 

2047 tensor_type="output", 

2048 halo=src.halo, 

2049 size_refs=size_refs, 

2050 ) 

2051 data_descr: Dict[str, Any] = dict(type=src.data_type) 

2052 if data_descr["type"] == "bool": 

2053 data_descr["values"] = [False, True] 

2054 

2055 return tgt( 

2056 axes=axes, 

2057 id=TensorId(str(src.name)), 

2058 test_tensor=FileDescr(source=test_tensor), 

2059 sample_tensor=( 

2060 None if sample_tensor is None else FileDescr(source=sample_tensor) 

2061 ), 

2062 data=data_descr, # pyright: ignore[reportArgumentType] 

2063 postprocessing=[_convert_proc(p, src.axes) for p in src.postprocessing], 

2064 ) 

2065 

2066 

2067_output_tensor_conv = _OutputTensorConv(_OutputTensorDescr_v0_4, OutputTensorDescr) 

2068 

2069 

2070TensorDescr = Union[InputTensorDescr, OutputTensorDescr] 

2071 

2072 

2073def validate_tensors( 

2074 tensors: Mapping[TensorId, Tuple[TensorDescr, NDArray[Any]]], 

2075 tensor_origin: Literal[ 

2076 "test_tensor" 

2077 ], # for more precise error messages, e.g. 'test_tensor' 

2078): 

2079 all_tensor_axes: Dict[TensorId, Dict[AxisId, Tuple[AnyAxis, int]]] = {} 

2080 

2081 def e_msg(d: TensorDescr): 

2082 return f"{'inputs' if isinstance(d, InputTensorDescr) else 'outputs'}[{d.id}]" 

2083 

2084 for descr, array in tensors.values(): 

2085 try: 

2086 axis_sizes = descr.get_axis_sizes_for_array(array) 

2087 except ValueError as e: 

2088 raise ValueError(f"{e_msg(descr)} {e}") 

2089 else: 

2090 all_tensor_axes[descr.id] = { 

2091 a.id: (a, axis_sizes[a.id]) for a in descr.axes 

2092 } 

2093 

2094 for descr, array in tensors.values(): 

2095 if descr.dtype in ("float32", "float64"): 

2096 invalid_test_tensor_dtype = array.dtype.name not in ( 

2097 "float32", 

2098 "float64", 

2099 "uint8", 

2100 "int8", 

2101 "uint16", 

2102 "int16", 

2103 "uint32", 

2104 "int32", 

2105 "uint64", 

2106 "int64", 

2107 ) 

2108 else: 

2109 invalid_test_tensor_dtype = array.dtype.name != descr.dtype 

2110 

2111 if invalid_test_tensor_dtype: 

2112 raise ValueError( 

2113 f"{e_msg(descr)}.{tensor_origin}.dtype '{array.dtype.name}' does not" 

2114 + f" match described dtype '{descr.dtype}'" 

2115 ) 

2116 

2117 if array.min() > -1e-4 and array.max() < 1e-4: 

2118 raise ValueError( 

2119 "Output values are too small for reliable testing." 

2120 + f" Values <-1e5 or >=1e5 must be present in {tensor_origin}" 

2121 ) 

2122 

2123 for a in descr.axes: 

2124 actual_size = all_tensor_axes[descr.id][a.id][1] 

2125 if a.size is None: 

2126 continue 

2127 

2128 if isinstance(a.size, int): 

2129 if actual_size != a.size: 

2130 raise ValueError( 

2131 f"{e_msg(descr)}.{tensor_origin}: axis '{a.id}' " 

2132 + f"has incompatible size {actual_size}, expected {a.size}" 

2133 ) 

2134 elif isinstance(a.size, ParameterizedSize): 

2135 _ = a.size.validate_size(actual_size) 

2136 elif isinstance(a.size, DataDependentSize): 

2137 _ = a.size.validate_size(actual_size) 

2138 elif isinstance(a.size, SizeReference): 

2139 ref_tensor_axes = all_tensor_axes.get(a.size.tensor_id) 

2140 if ref_tensor_axes is None: 

2141 raise ValueError( 

2142 f"{e_msg(descr)}.axes[{a.id}].size.tensor_id: Unknown tensor" 

2143 + f" reference '{a.size.tensor_id}'" 

2144 ) 

2145 

2146 ref_axis, ref_size = ref_tensor_axes.get(a.size.axis_id, (None, None)) 

2147 if ref_axis is None or ref_size is None: 

2148 raise ValueError( 

2149 f"{e_msg(descr)}.axes[{a.id}].size.axis_id: Unknown tensor axis" 

2150 + f" reference '{a.size.tensor_id}.{a.size.axis_id}" 

2151 ) 

2152 

2153 if a.unit != ref_axis.unit: 

2154 raise ValueError( 

2155 f"{e_msg(descr)}.axes[{a.id}].size: `SizeReference` requires" 

2156 + " axis and reference axis to have the same `unit`, but" 

2157 + f" {a.unit}!={ref_axis.unit}" 

2158 ) 

2159 
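                # worked example (added comment): with ref_size=128, ref_axis.scale=1.0,
                # a.scale=0.5 and a.size.offset=0 the expected size below is
                # 128 * 1.0 / 0.5 + 0 = 256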

2160 if actual_size != ( 

2161 expected_size := ( 

2162 ref_size * ref_axis.scale / a.scale + a.size.offset 

2163 ) 

2164 ): 

2165 raise ValueError( 

2166 f"{e_msg(descr)}.{tensor_origin}: axis '{a.id}' of size" 

2167 + f" {actual_size} invalid for referenced size {ref_size};" 

2168 + f" expected {expected_size}" 

2169 ) 

2170 else: 

2171 assert_never(a.size) 

2172 

2173 

2174FileDescr_dependencies = Annotated[ 

2175 FileDescr_, 

2176 WithSuffix((".yaml", ".yml"), case_sensitive=True), 

2177 Field(examples=[dict(source="environment.yaml")]), 

2178] 

2179 

2180 

2181class _ArchitectureCallableDescr(Node): 

2182 callable: Annotated[Identifier, Field(examples=["MyNetworkClass", "get_my_model"])] 

2183 """Identifier of the callable that returns a torch.nn.Module instance.""" 

2184 

2185 kwargs: Dict[str, YamlValue] = Field( 

2186 default_factory=cast(Callable[[], Dict[str, YamlValue]], dict) 

2187 ) 

2188 """key word arguments for the `callable`""" 

2189 

2190 

2191class ArchitectureFromFileDescr(_ArchitectureCallableDescr, FileDescr): 

2192 source: Annotated[FileSource, AfterValidator(wo_special_file_name)] 

2193 """Architecture source file""" 

2194 

2195 @model_serializer(mode="wrap", when_used="unless-none") 

2196 def _serialize(self, nxt: SerializerFunctionWrapHandler, info: SerializationInfo): 

2197 return package_file_descr_serializer(self, nxt, info) 

2198 

2199 

2200class ArchitectureFromLibraryDescr(_ArchitectureCallableDescr): 

2201 import_from: str 

2202 """Where to import the callable from, i.e. `from <import_from> import <callable>`""" 

2203 

2204 

2205class _ArchFileConv( 

2206 Converter[ 

2207 _CallableFromFile_v0_4, 

2208 ArchitectureFromFileDescr, 

2209 Optional[Sha256], 

2210 Dict[str, Any], 

2211 ] 

2212): 

2213 def _convert( 

2214 self, 

2215 src: _CallableFromFile_v0_4, 

2216 tgt: "type[ArchitectureFromFileDescr | dict[str, Any]]", 

2217 sha256: Optional[Sha256], 

2218 kwargs: Dict[str, Any], 

2219 ) -> "ArchitectureFromFileDescr | dict[str, Any]": 

2220 if src.startswith("http") and src.count(":") == 2: 

2221 http, source, callable_ = src.split(":") 

2222 source = ":".join((http, source)) 

2223 elif not src.startswith("http") and src.count(":") == 1: 

2224 source, callable_ = src.split(":") 

2225 else: 

2226 source = str(src) 

2227 callable_ = str(src) 

2228 return tgt( 

2229 callable=Identifier(callable_), 

2230 source=cast(FileSource_, source), 

2231 sha256=sha256, 

2232 kwargs=kwargs, 

2233 ) 

2234 

2235 

2236_arch_file_conv = _ArchFileConv(_CallableFromFile_v0_4, ArchitectureFromFileDescr) 

2237 

2238 

2239class _ArchLibConv( 

2240 Converter[ 

2241 _CallableFromDepencency_v0_4, ArchitectureFromLibraryDescr, Dict[str, Any] 

2242 ] 

2243): 

2244 def _convert( 

2245 self, 

2246 src: _CallableFromDepencency_v0_4, 

2247 tgt: "type[ArchitectureFromLibraryDescr | dict[str, Any]]", 

2248 kwargs: Dict[str, Any], 

2249 ) -> "ArchitectureFromLibraryDescr | dict[str, Any]": 

2250 *mods, callable_ = src.split(".") 

2251 import_from = ".".join(mods) 

2252 return tgt( 

2253 import_from=import_from, callable=Identifier(callable_), kwargs=kwargs 

2254 ) 

2255 

2256 

2257_arch_lib_conv = _ArchLibConv( 

2258 _CallableFromDepencency_v0_4, ArchitectureFromLibraryDescr 

2259) 

2260 

2261 

2262class WeightsEntryDescrBase(FileDescr): 

2263 type: ClassVar[WeightsFormat] 

2264 weights_format_name: ClassVar[str] # human readable 

2265 

2266 source: Annotated[FileSource, AfterValidator(wo_special_file_name)] 

2267 """Source of the weights file.""" 

2268 

2269 authors: Optional[List[Author]] = None 

2270 """Authors 

2271 Either the person(s) that have trained this model resulting in the original weights file. 

2272 (If this is the initial weights entry, i.e. it does not have a `parent`) 

2273 Or the person(s) who have converted the weights to this weights format. 

2274 (If this is a child weight, i.e. it has a `parent` field) 

2275 """ 

2276 

2277 parent: Annotated[ 

2278 Optional[WeightsFormat], Field(examples=["pytorch_state_dict"]) 

2279 ] = None 

2280 """The source weights these weights were converted from. 

2281 For example, if a model's weights were converted from the `pytorch_state_dict` format to `torchscript`, 

2282 the `pytorch_state_dict` weights entry has no `parent` and is the parent of the `torchscript` weights. 

2283 All weight entries except one (the initial set of weights resulting from training the model) 

2284 need to have this field.""" 

2285 
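    # Illustrative, abridged RDF snippet (required metadata such as version fields
    # omitted): converted torchscript weights reference the original
    # pytorch_state_dict entry as their parent:
    #
    #   weights:
    #     pytorch_state_dict:
    #       source: weights.pt
    #     torchscript:
    #       source: weights.torchscript.pt
    #       parent: pytorch_state_dict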

2286 comment: str = "" 

2287 """A comment about this weights entry, for example how these weights were created.""" 

2288 

2289 @model_validator(mode="after") 

2290 def _validate(self) -> Self: 

2291 if self.type == self.parent: 

2292 raise ValueError("Weights entry can't be it's own parent.") 

2293 

2294 return self 

2295 

2296 @model_serializer(mode="wrap", when_used="unless-none") 

2297 def _serialize(self, nxt: SerializerFunctionWrapHandler, info: SerializationInfo): 

2298 return package_file_descr_serializer(self, nxt, info) 

2299 

2300 

2301class KerasHdf5WeightsDescr(WeightsEntryDescrBase): 

2302 type = "keras_hdf5" 

2303 weights_format_name: ClassVar[str] = "Keras HDF5" 

2304 tensorflow_version: Version 

2305 """TensorFlow version used to create these weights.""" 

2306 

2307 

2308class OnnxWeightsDescr(WeightsEntryDescrBase): 

2309 type = "onnx" 

2310 weights_format_name: ClassVar[str] = "ONNX" 

2311 opset_version: Annotated[int, Ge(7)] 

2312 """ONNX opset version""" 

2313 

2314 

2315class PytorchStateDictWeightsDescr(WeightsEntryDescrBase): 

2316 type = "pytorch_state_dict" 

2317 weights_format_name: ClassVar[str] = "Pytorch State Dict" 

2318 architecture: Union[ArchitectureFromFileDescr, ArchitectureFromLibraryDescr] 

2319 pytorch_version: Version 

2320 """Version of the PyTorch library used. 

2321 If `dependencies` is specified it has to include pytorch and any version pinning has to be compatible. 

2322 """ 

2323 dependencies: Optional[FileDescr_dependencies] = None 

2324 """Custom depencies beyond pytorch described in a Conda environment file. 

2325 Allows to specify custom dependencies, see conda docs: 

2326 - [Exporting an environment file across platforms](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#exporting-an-environment-file-across-platforms) 

2327 - [Creating an environment file manually](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#creating-an-environment-file-manually) 

2328 

2329 The conda environment file should include pytorch and any version pinning has to be compatible with 

2330 **pytorch_version**. 

2331 """ 

2332 

2333 

2334class TensorflowJsWeightsDescr(WeightsEntryDescrBase): 

2335 type = "tensorflow_js" 

2336 weights_format_name: ClassVar[str] = "Tensorflow.js" 

2337 tensorflow_version: Version 

2338 """Version of the TensorFlow library used.""" 

2339 

2340 source: Annotated[FileSource, AfterValidator(wo_special_file_name)] 

2341 """The multi-file weights. 

2342 All required files/folders should be bundled in a zip archive.""" 

2343 

2344 

2345class TensorflowSavedModelBundleWeightsDescr(WeightsEntryDescrBase): 

2346 type = "tensorflow_saved_model_bundle" 

2347 weights_format_name: ClassVar[str] = "Tensorflow Saved Model" 

2348 tensorflow_version: Version 

2349 """Version of the TensorFlow library used.""" 

2350 

2351 dependencies: Optional[FileDescr_dependencies] = None 

2352 """Custom dependencies beyond tensorflow. 

2353 The file should include tensorflow, and any version pinning has to be compatible with **tensorflow_version**.""" 

2354 

2355 source: Annotated[FileSource, AfterValidator(wo_special_file_name)] 

2356 """The multi-file weights. 

2357 All required files/folders should be bundled in a zip archive.""" 

2358 

2359 

2360class TorchscriptWeightsDescr(WeightsEntryDescrBase): 

2361 type = "torchscript" 

2362 weights_format_name: ClassVar[str] = "TorchScript" 

2363 pytorch_version: Version 

2364 """Version of the PyTorch library used.""" 

2365 

2366 

2367class WeightsDescr(Node): 

2368 keras_hdf5: Optional[KerasHdf5WeightsDescr] = None 

2369 onnx: Optional[OnnxWeightsDescr] = None 

2370 pytorch_state_dict: Optional[PytorchStateDictWeightsDescr] = None 

2371 tensorflow_js: Optional[TensorflowJsWeightsDescr] = None 

2372 tensorflow_saved_model_bundle: Optional[TensorflowSavedModelBundleWeightsDescr] = ( 

2373 None 

2374 ) 

2375 torchscript: Optional[TorchscriptWeightsDescr] = None 

2376 

2377 @model_validator(mode="after") 

2378 def check_entries(self) -> Self: 

2379 entries = {wtype for wtype, entry in self if entry is not None} 

2380 

2381 if not entries: 

2382 raise ValueError("Missing weights entry") 

2383 

2384 entries_wo_parent = { 

2385 wtype 

2386 for wtype, entry in self 

2387 if entry is not None and hasattr(entry, "parent") and entry.parent is None 

2388 } 

2389 if len(entries_wo_parent) != 1: 

2390 issue_warning( 

2391 "Exactly one weights entry may not specify the `parent` field (got" 

2392 + " {value}). That entry is considered the original set of model weights." 

2393 + " Other weight formats are created through conversion of the orignal or" 

2394 + " already converted weights. They have to reference the weights format" 

2395 + " they were converted from as their `parent`.", 

2396 value=len(entries_wo_parent), 

2397 field="weights", 

2398 ) 

2399 

2400 for wtype, entry in self: 

2401 if entry is None: 

2402 continue 

2403 

2404 assert hasattr(entry, "type") 

2405 assert hasattr(entry, "parent") 

2406 assert wtype == entry.type 

2407 if ( 

2408 entry.parent is not None and entry.parent not in entries 

2409 ): # self reference checked for `parent` field 

2410 raise ValueError( 

2411 f"`weights.{wtype}.parent={entry.parent} not in specified weight" 

2412 + f" formats: {entries}" 

2413 ) 

2414 

2415 return self 

2416 

2417 def __getitem__( 

2418 self, 

2419 key: Literal[ 

2420 "keras_hdf5", 

2421 "onnx", 

2422 "pytorch_state_dict", 

2423 "tensorflow_js", 

2424 "tensorflow_saved_model_bundle", 

2425 "torchscript", 

2426 ], 

2427 ): 

2428 if key == "keras_hdf5": 

2429 ret = self.keras_hdf5 

2430 elif key == "onnx": 

2431 ret = self.onnx 

2432 elif key == "pytorch_state_dict": 

2433 ret = self.pytorch_state_dict 

2434 elif key == "tensorflow_js": 

2435 ret = self.tensorflow_js 

2436 elif key == "tensorflow_saved_model_bundle": 

2437 ret = self.tensorflow_saved_model_bundle 

2438 elif key == "torchscript": 

2439 ret = self.torchscript 

2440 else: 

2441 raise KeyError(key) 

2442 

2443 if ret is None: 

2444 raise KeyError(key) 

2445 

2446 return ret 

2447 

2448 @property 

2449 def available_formats(self): 

2450 return { 

2451 **({} if self.keras_hdf5 is None else {"keras_hdf5": self.keras_hdf5}), 

2452 **({} if self.onnx is None else {"onnx": self.onnx}), 

2453 **( 

2454 {} 

2455 if self.pytorch_state_dict is None 

2456 else {"pytorch_state_dict": self.pytorch_state_dict} 

2457 ), 

2458 **( 

2459 {} 

2460 if self.tensorflow_js is None 

2461 else {"tensorflow_js": self.tensorflow_js} 

2462 ), 

2463 **( 

2464 {} 

2465 if self.tensorflow_saved_model_bundle is None 

2466 else { 

2467 "tensorflow_saved_model_bundle": self.tensorflow_saved_model_bundle 

2468 } 

2469 ), 

2470 **({} if self.torchscript is None else {"torchscript": self.torchscript}), 

2471 } 

2472 

2473 @property 

2474 def missing_formats(self): 

2475 return { 

2476 wf for wf in get_args(WeightsFormat) if wf not in self.available_formats 

2477 } 

2478 

2479 

2480class ModelId(ResourceId): 

2481 pass 

2482 

2483 

2484class LinkedModel(LinkedResourceBase): 

2485 """Reference to a bioimage.io model.""" 

2486 

2487 id: ModelId 

2488 """A valid model `id` from the bioimage.io collection.""" 

2489 

2490 

2491class _DataDepSize(NamedTuple): 

2492 min: StrictInt 

2493 max: Optional[StrictInt] 

2494 

2495 

2496class _AxisSizes(NamedTuple): 

2497 """the lenghts of all axes of model inputs and outputs""" 

2498 

2499 inputs: Dict[Tuple[TensorId, AxisId], int] 

2500 outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]] 

2501 

2502 

2503class _TensorSizes(NamedTuple): 

2504 """_AxisSizes as nested dicts""" 

2505 

2506 inputs: Dict[TensorId, Dict[AxisId, int]] 

2507 outputs: Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]] 

2508 

2509 

2510class ReproducibilityTolerance(Node, extra="allow"): 

2511 """Describes what small numerical differences -- if any -- may be tolerated 

2512 in the generated output when executing in different environments. 

2513 

2514 A tensor element *output* is considered mismatched to the **test_tensor** if 

2515 abs(*output* - **test_tensor**) > **absolute_tolerance** + **relative_tolerance** * abs(**test_tensor**). 

2516 (Internally we call [numpy.testing.assert_allclose](https://numpy.org/doc/stable/reference/generated/numpy.testing.assert_allclose.html).) 

2517 

2518 Motivation: 

2519 For testing we can request the respective deep learning frameworks to be as 

2520 reproducible as possible by setting seeds and choosing deterministic algorithms, 

2521 but differences in operating systems, available hardware and installed drivers 

2522 may still lead to numerical differences. 

2523 """ 

2524 
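    # Worked example (added for clarity): with the defaults below
    # (relative_tolerance=1e-3, absolute_tolerance=1e-4) an expected value of 2.0
    # tolerates deviations up to 1e-4 + 1e-3 * abs(2.0) = 0.0021, mirroring
    #
    #   import numpy as np
    #   np.testing.assert_allclose(2.002, 2.0, rtol=1e-3, atol=1e-4)  # passes
    #   np.testing.assert_allclose(2.003, 2.0, rtol=1e-3, atol=1e-4)  # raises AssertionError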

2525 relative_tolerance: RelativeTolerance = 1e-3 

2526 """Maximum relative tolerance of reproduced test tensor.""" 

2527 

2528 absolute_tolerance: AbsoluteTolerance = 1e-4 

2529 """Maximum absolute tolerance of reproduced test tensor.""" 

2530 

2531 mismatched_elements_per_million: MismatchedElementsPerMillion = 100 

2532 """Maximum number of mismatched elements/pixels per million to tolerate.""" 

2533 

2534 output_ids: Sequence[TensorId] = () 

2535 """Limits the output tensor IDs these reproducibility details apply to.""" 

2536 

2537 weights_formats: Sequence[WeightsFormat] = () 

2538 """Limits the weights formats these details apply to.""" 

2539 

2540 

2541class BioimageioConfig(Node, extra="allow"): 

2542 reproducibility_tolerance: Sequence[ReproducibilityTolerance] = () 

2543 """Tolerances to allow when reproducing the model's test outputs 

2544 from the model's test inputs. 

2545 Only the first entry matching tensor id and weights format is considered. 

2546 """ 

2547 

2548 

2549class Config(Node, extra="allow"): 

2550 bioimageio: BioimageioConfig = Field(default_factory=BioimageioConfig) 

2551 

2552 

2553class ModelDescr(GenericModelDescrBase): 

2554 """Specification of the fields used in a bioimage.io-compliant RDF to describe AI models with pretrained weights. 

2555 These fields are typically stored in a YAML file which we call a model resource description file (model RDF). 

2556 """ 

2557 

2558 implemented_format_version: ClassVar[Literal["0.5.4"]] = "0.5.4" 

2559 if TYPE_CHECKING: 

2560 format_version: Literal["0.5.4"] = "0.5.4" 

2561 else: 

2562 format_version: Literal["0.5.4"] 

2563 """Version of the bioimage.io model description specification used. 

2564 When creating a new model always use the latest micro/patch version described here. 

2565 The `format_version` is important for any consumer software to understand how to parse the fields. 

2566 """ 

2567 

2568 implemented_type: ClassVar[Literal["model"]] = "model" 

2569 if TYPE_CHECKING: 

2570 type: Literal["model"] = "model" 

2571 else: 

2572 type: Literal["model"] 

2573 """Specialized resource type 'model'""" 

2574 

2575 id: Optional[ModelId] = None 

2576 """bioimage.io-wide unique resource identifier 

2577 assigned by bioimage.io; version **un**specific.""" 

2578 

2579 authors: NotEmpty[List[Author]] 

2580 """The authors are the creators of the model RDF and the primary points of contact.""" 

2581 

2582 documentation: FileSource_documentation 

2583 """URL or relative path to a markdown file with additional documentation. 

2584 The recommended documentation file name is `README.md`. An `.md` suffix is mandatory. 

2585 The documentation should include a '#[#] Validation' (sub)section 

2586 with details on how to quantitatively validate the model on unseen data.""" 

2587 

2588 @field_validator("documentation", mode="after") 

2589 @classmethod 

2590 def _validate_documentation( 

2591 cls, value: FileSource_documentation 

2592 ) -> FileSource_documentation: 

2593 if not get_validation_context().perform_io_checks: 

2594 return value 

2595 

2596 doc_reader = get_reader(value) 

2597 doc_content = doc_reader.read().decode(encoding="utf-8") 

2598 if not re.search("#.*[vV]alidation", doc_content): 

2599 issue_warning( 

2600 "No '# Validation' (sub)section found in {value}.", 

2601 value=value, 

2602 field="documentation", 

2603 ) 

2604 

2605 return value 

2606 

2607 inputs: NotEmpty[Sequence[InputTensorDescr]] 

2608 """Describes the input tensors expected by this model.""" 

2609 

2610 @field_validator("inputs", mode="after") 

2611 @classmethod 

2612 def _validate_input_axes( 

2613 cls, inputs: Sequence[InputTensorDescr] 

2614 ) -> Sequence[InputTensorDescr]: 

2615 input_size_refs = cls._get_axes_with_independent_size(inputs) 

2616 

2617 for i, ipt in enumerate(inputs): 

2618 valid_independent_refs: Dict[ 

2619 Tuple[TensorId, AxisId], 

2620 Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]], 

2621 ] = { 

2622 **{ 

2623 (ipt.id, a.id): (ipt, a, a.size) 

2624 for a in ipt.axes 

2625 if not isinstance(a, BatchAxis) 

2626 and isinstance(a.size, (int, ParameterizedSize)) 

2627 }, 

2628 **input_size_refs, 

2629 } 

2630 for a, ax in enumerate(ipt.axes): 

2631 cls._validate_axis( 

2632 "inputs", 

2633 i=i, 

2634 tensor_id=ipt.id, 

2635 a=a, 

2636 axis=ax, 

2637 valid_independent_refs=valid_independent_refs, 

2638 ) 

2639 return inputs 

2640 

2641 @staticmethod 

2642 def _validate_axis( 

2643 field_name: str, 

2644 i: int, 

2645 tensor_id: TensorId, 

2646 a: int, 

2647 axis: AnyAxis, 

2648 valid_independent_refs: Dict[ 

2649 Tuple[TensorId, AxisId], 

2650 Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]], 

2651 ], 

2652 ): 

2653 if isinstance(axis, BatchAxis) or isinstance( 

2654 axis.size, (int, ParameterizedSize, DataDependentSize) 

2655 ): 

2656 return 

2657 elif not isinstance(axis.size, SizeReference): 

2658 assert_never(axis.size) 

2659 

2660 # validate axis.size SizeReference 

2661 ref = (axis.size.tensor_id, axis.size.axis_id) 

2662 if ref not in valid_independent_refs: 

2663 raise ValueError( 

2664 "Invalid tensor axis reference at" 

2665 + f" {field_name}[{i}].axes[{a}].size: {axis.size}." 

2666 ) 

2667 if ref == (tensor_id, axis.id): 

2668 raise ValueError( 

2669 "Self-referencing not allowed for" 

2670 + f" {field_name}[{i}].axes[{a}].size: {axis.size}" 

2671 ) 

2672 if axis.type == "channel": 

2673 if valid_independent_refs[ref][1].type != "channel": 

2674 raise ValueError( 

2675 "A channel axis' size may only reference another fixed size" 

2676 + " channel axis." 

2677 ) 

2678 if isinstance(axis.channel_names, str) and "{i}" in axis.channel_names: 

2679 ref_size = valid_independent_refs[ref][2] 

2680 assert isinstance(ref_size, int), ( 

2681 "channel axis ref (another channel axis) has to specify fixed" 

2682 + " size" 

2683 ) 

2684 generated_channel_names = [ 

2685 Identifier(axis.channel_names.format(i=i)) 

2686 for i in range(1, ref_size + 1) 

2687 ] 

2688 axis.channel_names = generated_channel_names 

2689 

2690 if (ax_unit := getattr(axis, "unit", None)) != ( 

2691 ref_unit := getattr(valid_independent_refs[ref][1], "unit", None) 

2692 ): 

2693 raise ValueError( 

2694 "The units of an axis and its reference axis need to match, but" 

2695 + f" '{ax_unit}' != '{ref_unit}'." 

2696 ) 

2697 ref_axis = valid_independent_refs[ref][1] 

2698 if isinstance(ref_axis, BatchAxis): 

2699 raise ValueError( 

2700 f"Invalid reference axis '{ref_axis.id}' for {tensor_id}.{axis.id}" 

2701 + " (a batch axis is not allowed as reference)." 

2702 ) 

2703 

2704 if isinstance(axis, WithHalo): 

2705 min_size = axis.size.get_size(axis, ref_axis, n=0) 

2706 if (min_size - 2 * axis.halo) < 1: 

2707 raise ValueError( 

2708 f"axis {axis.id} with minimum size {min_size} is too small for halo" 

2709 + f" {axis.halo}." 

2710 ) 

2711 
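                # e.g. an output halo of 8 with output scale 0.5 and input scale 1.0
                # corresponds to an input halo of 8 * 0.5 / 1.0 = 4, which has to be an even integer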

2712 input_halo = axis.halo * axis.scale / ref_axis.scale 

2713 if input_halo != int(input_halo) or input_halo % 2 == 1: 

2714 raise ValueError( 

2715 f"input_halo {input_halo} (output_halo {axis.halo} *" 

2716 + f" output_scale {axis.scale} / input_scale {ref_axis.scale})" 

2717 + f" {tensor_id}.{axis.id}." 

2718 ) 

2719 

2720 @model_validator(mode="after") 

2721 def _validate_test_tensors(self) -> Self: 

2722 if not get_validation_context().perform_io_checks: 

2723 return self 

2724 

2725 test_output_arrays = [load_array(descr.test_tensor) for descr in self.outputs] 

2726 test_input_arrays = [load_array(descr.test_tensor) for descr in self.inputs] 

2727 

2728 tensors = { 

2729 descr.id: (descr, array) 

2730 for descr, array in zip( 

2731 chain(self.inputs, self.outputs), test_input_arrays + test_output_arrays 

2732 ) 

2733 } 

2734 validate_tensors(tensors, tensor_origin="test_tensor") 

2735 

2736 output_arrays = { 

2737 descr.id: array for descr, array in zip(self.outputs, test_output_arrays) 

2738 } 

2739 for rep_tol in self.config.bioimageio.reproducibility_tolerance: 

2740 if not rep_tol.absolute_tolerance: 

2741 continue 

2742 

2743 if rep_tol.output_ids: 

2744 out_arrays = { 

2745 oid: a 

2746 for oid, a in output_arrays.items() 

2747 if oid in rep_tol.output_ids 

2748 } 

2749 else: 

2750 out_arrays = output_arrays 

2751 

2752 for out_id, array in out_arrays.items(): 

2753 if rep_tol.absolute_tolerance > (max_test_value := array.max()) * 0.01: 

2754 raise ValueError( 

2755 "config.bioimageio.reproducibility_tolerance.absolute_tolerance=" 

2756 + f"{rep_tol.absolute_tolerance} > 0.01*{max_test_value}" 

2757 + f" (1% of the maximum value of the test tensor '{out_id}')" 

2758 ) 

2759 

2760 return self 

2761 

2762 @model_validator(mode="after") 

2763 def _validate_tensor_references_in_proc_kwargs(self, info: ValidationInfo) -> Self: 

2764 ipt_refs = {t.id for t in self.inputs} 

2765 out_refs = {t.id for t in self.outputs} 

2766 for ipt in self.inputs: 

2767 for p in ipt.preprocessing: 

2768 ref = p.kwargs.get("reference_tensor") 

2769 if ref is None: 

2770 continue 

2771 if ref not in ipt_refs: 

2772 raise ValueError( 

2773 f"`reference_tensor` '{ref}' not found. Valid input tensor" 

2774 + f" references are: {ipt_refs}." 

2775 ) 

2776 

2777 for out in self.outputs: 

2778 for p in out.postprocessing: 

2779 ref = p.kwargs.get("reference_tensor") 

2780 if ref is None: 

2781 continue 

2782 

2783 if ref not in ipt_refs and ref not in out_refs: 

2784 raise ValueError( 

2785 f"`reference_tensor` '{ref}' not found. Valid tensor references" 

2786 + f" are: {ipt_refs | out_refs}." 

2787 ) 

2788 

2789 return self 

2790 

2791 # TODO: use validate funcs in validate_test_tensors 

2792 # def validate_inputs(self, input_tensors: Mapping[TensorId, NDArray[Any]]) -> Mapping[TensorId, NDArray[Any]]: 

2793 

2794 name: Annotated[ 

2795 Annotated[ 

2796 str, RestrictCharacters(string.ascii_letters + string.digits + "_+- ()") 

2797 ], 

2798 MinLen(5), 

2799 MaxLen(128), 

2800 warn(MaxLen(64), "Name longer than 64 characters.", INFO), 

2801 ] 

2802 """A human-readable name of this model. 

2803 It should be no longer than 64 characters 

2804 and may only contain letters, numbers, underscores, minus signs, parentheses and spaces. 

2805 We recommend choosing a name that refers to the model's task and image modality. 

2806 """ 

2807 

2808 outputs: NotEmpty[Sequence[OutputTensorDescr]] 

2809 """Describes the output tensors.""" 

2810 

2811 @field_validator("outputs", mode="after") 

2812 @classmethod 

2813 def _validate_tensor_ids( 

2814 cls, outputs: Sequence[OutputTensorDescr], info: ValidationInfo 

2815 ) -> Sequence[OutputTensorDescr]: 

2816 tensor_ids = [ 

2817 t.id for t in info.data.get("inputs", []) + info.data.get("outputs", []) 

2818 ] 

2819 duplicate_tensor_ids: List[str] = [] 

2820 seen: Set[str] = set() 

2821 for t in tensor_ids: 

2822 if t in seen: 

2823 duplicate_tensor_ids.append(t) 

2824 

2825 seen.add(t) 

2826 

2827 if duplicate_tensor_ids: 

2828 raise ValueError(f"Duplicate tensor ids: {duplicate_tensor_ids}") 

2829 

2830 return outputs 

2831 

2832 @staticmethod 

2833 def _get_axes_with_parameterized_size( 

2834 io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]], 

2835 ): 

2836 return { 

2837 f"{t.id}.{a.id}": (t, a, a.size) 

2838 for t in io 

2839 for a in t.axes 

2840 if not isinstance(a, BatchAxis) and isinstance(a.size, ParameterizedSize) 

2841 } 

2842 

2843 @staticmethod 

2844 def _get_axes_with_independent_size( 

2845 io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]], 

2846 ): 

2847 return { 

2848 (t.id, a.id): (t, a, a.size) 

2849 for t in io 

2850 for a in t.axes 

2851 if not isinstance(a, BatchAxis) 

2852 and isinstance(a.size, (int, ParameterizedSize)) 

2853 } 

2854 

2855 @field_validator("outputs", mode="after") 

2856 @classmethod 

2857 def _validate_output_axes( 

2858 cls, outputs: List[OutputTensorDescr], info: ValidationInfo 

2859 ) -> List[OutputTensorDescr]: 

2860 input_size_refs = cls._get_axes_with_independent_size( 

2861 info.data.get("inputs", []) 

2862 ) 

2863 output_size_refs = cls._get_axes_with_independent_size(outputs) 

2864 

2865 for i, out in enumerate(outputs): 

2866 valid_independent_refs: Dict[ 

2867 Tuple[TensorId, AxisId], 

2868 Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]], 

2869 ] = { 

2870 **{ 

2871 (out.id, a.id): (out, a, a.size) 

2872 for a in out.axes 

2873 if not isinstance(a, BatchAxis) 

2874 and isinstance(a.size, (int, ParameterizedSize)) 

2875 }, 

2876 **input_size_refs, 

2877 **output_size_refs, 

2878 } 

2879 for a, ax in enumerate(out.axes): 

2880 cls._validate_axis( 

2881 "outputs", 

2882 i, 

2883 out.id, 

2884 a, 

2885 ax, 

2886 valid_independent_refs=valid_independent_refs, 

2887 ) 

2888 

2889 return outputs 

2890 

2891 packaged_by: List[Author] = Field( 

2892 default_factory=cast(Callable[[], List[Author]], list) 

2893 ) 

2894 """The persons that have packaged and uploaded this model. 

2895 Only required if those persons differ from the `authors`.""" 

2896 

2897 parent: Optional[LinkedModel] = None 

2898 """The model from which this model is derived, e.g. by fine-tuning the weights.""" 

2899 

2900 @model_validator(mode="after") 

2901 def _validate_parent_is_not_self(self) -> Self: 

2902 if self.parent is not None and self.parent.id == self.id: 

2903 raise ValueError("A model description may not reference itself as parent.") 

2904 

2905 return self 

2906 

2907 run_mode: Annotated[ 

2908 Optional[RunMode], 

2909 warn(None, "Run mode '{value}' has limited support across consumer softwares."), 

2910 ] = None 

2911 """Custom run mode for this model: for more complex prediction procedures like test time 

2912 data augmentation that currently cannot be expressed in the specification. 

2913 No standard run modes are defined yet.""" 

2914 

2915 timestamp: Datetime = Field(default_factory=Datetime.now) 

2916 """Timestamp in [ISO 8601](#https://en.wikipedia.org/wiki/ISO_8601) format 

2917 with a few restrictions listed [here](https://docs.python.org/3/library/datetime.html#datetime.datetime.fromisoformat). 

2918 (In Python a datetime object is valid, too).""" 

2919 

2920 training_data: Annotated[ 

2921 Union[None, LinkedDataset, DatasetDescr, DatasetDescr02], 

2922 Field(union_mode="left_to_right"), 

2923 ] = None 

2924 """The dataset used to train this model""" 

2925 

2926 weights: Annotated[WeightsDescr, WrapSerializer(package_weights)] 

2927 """The weights for this model. 

2928 Weights can be given for different formats, but should otherwise be equivalent. 

2929 The available weight formats determine which consumers can use this model.""" 

2930 

2931 config: Config = Field(default_factory=Config) 

2932 

2933 @model_validator(mode="after") 

2934 def _add_default_cover(self) -> Self: 

2935 if not get_validation_context().perform_io_checks or self.covers: 

2936 return self 

2937 

2938 try: 

2939 generated_covers = generate_covers( 

2940 [(t, load_array(t.test_tensor)) for t in self.inputs], 

2941 [(t, load_array(t.test_tensor)) for t in self.outputs], 

2942 ) 

2943 except Exception as e: 

2944 issue_warning( 

2945 "Failed to generate cover image(s): {e}", 

2946 value=self.covers, 

2947 msg_context=dict(e=e), 

2948 field="covers", 

2949 ) 

2950 else: 

2951 self.covers.extend(generated_covers) 

2952 

2953 return self 

2954 

2955 def get_input_test_arrays(self) -> List[NDArray[Any]]: 

2956 data = [load_array(ipt.test_tensor) for ipt in self.inputs] 

2957 assert all(isinstance(d, np.ndarray) for d in data) 

2958 return data 

2959 

2960 def get_output_test_arrays(self) -> List[NDArray[Any]]: 

2961 data = [load_array(out.test_tensor) for out in self.outputs] 

2962 assert all(isinstance(d, np.ndarray) for d in data) 

2963 return data 

2964 

2965 @staticmethod 

2966 def get_batch_size(tensor_sizes: Mapping[TensorId, Mapping[AxisId, int]]) -> int: 

2967 batch_size = 1 

2968 tensor_with_batchsize: Optional[TensorId] = None 

2969 for tid in tensor_sizes: 

2970 for aid, s in tensor_sizes[tid].items(): 

2971 if aid != BATCH_AXIS_ID or s == 1 or s == batch_size: 

2972 continue 

2973 

2974 if batch_size != 1: 

2975 assert tensor_with_batchsize is not None 

2976 raise ValueError( 

2977 f"batch size mismatch for tensors '{tensor_with_batchsize}' ({batch_size}) and '{tid}' ({s})" 

2978 ) 

2979 

2980 batch_size = s 

2981 tensor_with_batchsize = tid 

2982 

2983 return batch_size 

2984 

2985 def get_output_tensor_sizes( 

2986 self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]] 

2987 ) -> Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]: 

2988 """Returns the tensor output sizes for given **input_sizes**. 

2989 The output sizes are exact only if **input_sizes** describes a valid input shape. 

2990 Otherwise they might be larger than the actual (valid) outputs.""" 

2991 batch_size = self.get_batch_size(input_sizes) 

2992 ns = self.get_ns(input_sizes) 

2993 

2994 tensor_sizes = self.get_tensor_sizes(ns, batch_size=batch_size) 

2995 return tensor_sizes.outputs 

2996 

2997 def get_ns(self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]): 

2998 """get parameter `n` for each parameterized axis 

2999 such that the valid input size is >= the given input size""" 
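        # Illustrative example (assuming `ParameterizedSize.get_n` returns the smallest n
        # whose size covers the given value): for an axis with ParameterizedSize(min=16, step=8)
        # and a requested input size of 20, n = 1 is returned, i.e. valid size 16 + 1 * 8 = 24 >= 20.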

3000 ret: Dict[Tuple[TensorId, AxisId], ParameterizedSize_N] = {} 

3001 axes = {t.id: {a.id: a for a in t.axes} for t in self.inputs} 

3002 for tid in input_sizes: 

3003 for aid, s in input_sizes[tid].items(): 

3004 size_descr = axes[tid][aid].size 

3005 if isinstance(size_descr, ParameterizedSize): 

3006 ret[(tid, aid)] = size_descr.get_n(s) 

3007 elif size_descr is None or isinstance(size_descr, (int, SizeReference)): 

3008 pass 

3009 else: 

3010 assert_never(size_descr) 

3011 

3012 return ret 

3013 

3014 def get_tensor_sizes( 

3015 self, ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], batch_size: int 

3016 ) -> _TensorSizes: 

3017 axis_sizes = self.get_axis_sizes(ns, batch_size=batch_size) 

3018 return _TensorSizes( 

3019 { 

3020 t: { 

3021 aa: axis_sizes.inputs[(tt, aa)] 

3022 for tt, aa in axis_sizes.inputs 

3023 if tt == t 

3024 } 

3025 for t in {tt for tt, _ in axis_sizes.inputs} 

3026 }, 

3027 { 

3028 t: { 

3029 aa: axis_sizes.outputs[(tt, aa)] 

3030 for tt, aa in axis_sizes.outputs 

3031 if tt == t 

3032 } 

3033 for t in {tt for tt, _ in axis_sizes.outputs} 

3034 }, 

3035 ) 

3036 

3037 def get_axis_sizes( 

3038 self, 

3039 ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], 

3040 batch_size: Optional[int] = None, 

3041 *, 

3042 max_input_shape: Optional[Mapping[Tuple[TensorId, AxisId], int]] = None, 

3043 ) -> _AxisSizes: 

3044 """Determine input and output block shape for scale factors **ns** 

3045 of parameterized input sizes. 

3046 

3047 Args: 

3048 ns: Scale factor `n` for each axis (keyed by (tensor_id, axis_id)) 

3049 that is parameterized as `size = min + n * step`. 

3050 batch_size: The desired size of the batch dimension. 

3051 If given, **batch_size** overwrites any batch size present in 

3052 **max_input_shape**. Defaults to 1. 

3053 max_input_shape: Limits the derived block shapes. 

3054 For each axis whose input size, parameterized by `n`, is larger than 

3055 **max_input_shape**, `n` is reduced to the smallest value for which the 

3056 parameterized size still covers **max_input_shape**. 

3057 Use this for small input samples or large values of **ns**, 

3058 or simply whenever you know the full input shape. 

3059 

3060 Returns: 

3061 Resolved axis sizes for model inputs and outputs. 

3062 """ 

3063 max_input_shape = max_input_shape or {} 

3064 if batch_size is None: 

3065 for (_t_id, a_id), s in max_input_shape.items(): 

3066 if a_id == BATCH_AXIS_ID: 

3067 batch_size = s 

3068 break 

3069 else: 

3070 batch_size = 1 

3071 

3072 all_axes = { 

3073 t.id: {a.id: a for a in t.axes} for t in chain(self.inputs, self.outputs) 

3074 } 

3075 

3076 inputs: Dict[Tuple[TensorId, AxisId], int] = {} 

3077 outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]] = {} 

3078 

3079 def get_axis_size(a: Union[InputAxis, OutputAxis]): 

3080 if isinstance(a, BatchAxis): 

3081 if (t_descr.id, a.id) in ns: 

3082 logger.warning( 

3083 "Ignoring unexpected size increment factor (n) for batch axis" 

3084 + " of tensor '{}'.", 

3085 t_descr.id, 

3086 ) 

3087 return batch_size 

3088 elif isinstance(a.size, int): 

3089 if (t_descr.id, a.id) in ns: 

3090 logger.warning( 

3091 "Ignoring unexpected size increment factor (n) for fixed size" 

3092 + " axis '{}' of tensor '{}'.", 

3093 a.id, 

3094 t_descr.id, 

3095 ) 

3096 return a.size 

3097 elif isinstance(a.size, ParameterizedSize): 

3098 if (t_descr.id, a.id) not in ns: 

3099 raise ValueError( 

3100 "Size increment factor (n) missing for parametrized axis" 

3101 + f" '{a.id}' of tensor '{t_descr.id}'." 

3102 ) 

3103 n = ns[(t_descr.id, a.id)] 

3104 s_max = max_input_shape.get((t_descr.id, a.id)) 

3105 if s_max is not None: 

3106 n = min(n, a.size.get_n(s_max)) 

3107 

3108 return a.size.get_size(n) 

3109 

3110 elif isinstance(a.size, SizeReference): 

3111 if (t_descr.id, a.id) in ns: 

3112 logger.warning( 

3113 "Ignoring unexpected size increment factor (n) for axis '{}'" 

3114 + " of tensor '{}' with size reference.", 

3115 a.id, 

3116 t_descr.id, 

3117 ) 

3118 assert not isinstance(a, BatchAxis) 

3119 ref_axis = all_axes[a.size.tensor_id][a.size.axis_id] 

3120 assert not isinstance(ref_axis, BatchAxis) 

3121 ref_key = (a.size.tensor_id, a.size.axis_id) 

3122 ref_size = inputs.get(ref_key, outputs.get(ref_key)) 

3123 assert ref_size is not None, ref_key 

3124 assert not isinstance(ref_size, _DataDepSize), ref_key 

3125 return a.size.get_size( 

3126 axis=a, 

3127 ref_axis=ref_axis, 

3128 ref_size=ref_size, 

3129 ) 

3130 elif isinstance(a.size, DataDependentSize): 

3131 if (t_descr.id, a.id) in ns: 

3132 logger.warning( 

3133 "Ignoring unexpected increment factor (n) for data dependent" 

3134 + " size axis '{}' of tensor '{}'.", 

3135 a.id, 

3136 t_descr.id, 

3137 ) 

3138 return _DataDepSize(a.size.min, a.size.max) 

3139 else: 

3140 assert_never(a.size) 

3141 

3142 # first resolve all input axis sizes except those given as a `SizeReference` 

3143 for t_descr in self.inputs: 

3144 for a in t_descr.axes: 

3145 if not isinstance(a.size, SizeReference): 

3146 s = get_axis_size(a) 

3147 assert not isinstance(s, _DataDepSize) 

3148 inputs[t_descr.id, a.id] = s 

3149 

3150 # resolve all other input axis sizes 

3151 for t_descr in self.inputs: 

3152 for a in t_descr.axes: 

3153 if isinstance(a.size, SizeReference): 

3154 s = get_axis_size(a) 

3155 assert not isinstance(s, _DataDepSize) 

3156 inputs[t_descr.id, a.id] = s 

3157 

3158 # resolve all output axis sizes 

3159 for t_descr in self.outputs: 

3160 for a in t_descr.axes: 

3161 assert not isinstance(a.size, ParameterizedSize) 

3162 s = get_axis_size(a) 

3163 outputs[t_descr.id, a.id] = s 

3164 

3165 return _AxisSizes(inputs=inputs, outputs=outputs) 

3166 

3167 @model_validator(mode="before") 

3168 @classmethod 

3169 def _convert(cls, data: Dict[str, Any]) -> Dict[str, Any]: 

3170 cls.convert_from_old_format_wo_validation(data) 

3171 return data 

3172 

3173 @classmethod 

3174 def convert_from_old_format_wo_validation(cls, data: Dict[str, Any]) -> None: 

3175 """Convert metadata following an older format version to this classes' format 

3176 without validating the result. 

3177 """ 

3178 if ( 

3179 data.get("type") == "model" 

3180 and isinstance(fv := data.get("format_version"), str) 

3181 and fv.count(".") == 2 

3182 ): 

3183 fv_parts = fv.split(".") 

3184 if any(not p.isdigit() for p in fv_parts): 

3185 return 

3186 

3187 fv_tuple = tuple(map(int, fv_parts)) 

3188 

3189 assert cls.implemented_format_version_tuple[0:2] == (0, 5) 

3190 if fv_tuple[:2] in ((0, 3), (0, 4)): 

3191 m04 = _ModelDescr_v0_4.load(data) 

3192 if isinstance(m04, InvalidDescr): 

3193 try: 

3194 updated = _model_conv.convert_as_dict( 

3195 m04 # pyright: ignore[reportArgumentType] 

3196 ) 

3197 except Exception as e: 

3198 logger.error( 

3199 "Failed to convert from invalid model 0.4 description." 

3200 + f"\nerror: {e}" 

3201 + "\nProceeding with model 0.5 validation without conversion." 

3202 ) 

3203 updated = None 

3204 else: 

3205 updated = _model_conv.convert_as_dict(m04) 

3206 

3207 if updated is not None: 

3208 data.clear() 

3209 data.update(updated) 

3210 

3211 elif fv_tuple[:2] == (0, 5): 

3212 # bump patch version 

3213 data["format_version"] = cls.implemented_format_version 

3214 

3215 

3216class _ModelConv(Converter[_ModelDescr_v0_4, ModelDescr]): 

3217 def _convert( 

3218 self, src: _ModelDescr_v0_4, tgt: "type[ModelDescr] | type[dict[str, Any]]" 

3219 ) -> "ModelDescr | dict[str, Any]": 

3220 name = "".join( 

3221 c if c in string.ascii_letters + string.digits + "_+- ()" else " " 

3222 for c in src.name 

3223 ) 

3224 

3225 def conv_authors(auths: Optional[Sequence[_Author_v0_4]]): 

3226 conv = ( 

3227 _author_conv.convert if TYPE_CHECKING else _author_conv.convert_as_dict 

3228 ) 

3229 return None if auths is None else [conv(a) for a in auths] 

3230 

3231 if TYPE_CHECKING: 

3232 arch_file_conv = _arch_file_conv.convert 

3233 arch_lib_conv = _arch_lib_conv.convert 

3234 else: 

3235 arch_file_conv = _arch_file_conv.convert_as_dict 

3236 arch_lib_conv = _arch_lib_conv.convert_as_dict 

3237 

3238 input_size_refs = { 

3239 ipt.name: { 

3240 a: s 

3241 for a, s in zip( 

3242 ipt.axes, 

3243 ( 

3244 ipt.shape.min 

3245 if isinstance(ipt.shape, _ParameterizedInputShape_v0_4) 

3246 else ipt.shape 

3247 ), 

3248 ) 

3249 } 

3250 for ipt in src.inputs 

3251 if ipt.shape 

3252 } 

3253 output_size_refs = { 

3254 **{ 

3255 out.name: {a: s for a, s in zip(out.axes, out.shape)} 

3256 for out in src.outputs 

3257 if not isinstance(out.shape, _ImplicitOutputShape_v0_4) 

3258 }, 

3259 **input_size_refs, 

3260 } 

3261 

3262 return tgt( 

3263 attachments=( 

3264 [] 

3265 if src.attachments is None 

3266 else [FileDescr(source=f) for f in src.attachments.files] 

3267 ), 

3268 authors=[ 

3269 _author_conv.convert_as_dict(a) for a in src.authors 

3270 ], # pyright: ignore[reportArgumentType] 

3271 cite=[ 

3272 {"text": c.text, "doi": c.doi, "url": c.url} for c in src.cite 

3273 ], # pyright: ignore[reportArgumentType] 

3274 config=src.config, # pyright: ignore[reportArgumentType] 

3275 covers=src.covers, 

3276 description=src.description, 

3277 documentation=src.documentation, 

3278 format_version="0.5.4", 

3279 git_repo=src.git_repo, # pyright: ignore[reportArgumentType] 

3280 icon=src.icon, 

3281 id=None if src.id is None else ModelId(src.id), 

3282 id_emoji=src.id_emoji, 

3283 license=src.license, # type: ignore 

3284 links=src.links, 

3285 maintainers=[ 

3286 _maintainer_conv.convert_as_dict(m) for m in src.maintainers 

3287 ], # pyright: ignore[reportArgumentType] 

3288 name=name, 

3289 tags=src.tags, 

3290 type=src.type, 

3291 uploader=src.uploader, 

3292 version=src.version, 

3293 inputs=[ # pyright: ignore[reportArgumentType] 

3294 _input_tensor_conv.convert_as_dict(ipt, tt, st, input_size_refs) 

3295 for ipt, tt, st, in zip( 

3296 src.inputs, 

3297 src.test_inputs, 

3298 src.sample_inputs or [None] * len(src.test_inputs), 

3299 ) 

3300 ], 

3301 outputs=[ # pyright: ignore[reportArgumentType] 

3302 _output_tensor_conv.convert_as_dict(out, tt, st, output_size_refs) 

3303 for out, tt, st, in zip( 

3304 src.outputs, 

3305 src.test_outputs, 

3306 src.sample_outputs or [None] * len(src.test_outputs), 

3307 ) 

3308 ], 

3309 parent=( 

3310 None 

3311 if src.parent is None 

3312 else LinkedModel( 

3313 id=ModelId( 

3314 str(src.parent.id) 

3315 + ( 

3316 "" 

3317 if src.parent.version_number is None 

3318 else f"/{src.parent.version_number}" 

3319 ) 

3320 ) 

3321 ) 

3322 ), 

3323 training_data=( 

3324 None 

3325 if src.training_data is None 

3326 else ( 

3327 LinkedDataset( 

3328 id=DatasetId( 

3329 str(src.training_data.id) 

3330 + ( 

3331 "" 

3332 if src.training_data.version_number is None 

3333 else f"/{src.training_data.version_number}" 

3334 ) 

3335 ) 

3336 ) 

3337 if isinstance(src.training_data, LinkedDataset02) 

3338 else src.training_data 

3339 ) 

3340 ), 

3341 packaged_by=[ 

3342 _author_conv.convert_as_dict(a) for a in src.packaged_by 

3343 ], # pyright: ignore[reportArgumentType] 

3344 run_mode=src.run_mode, 

3345 timestamp=src.timestamp, 

3346 weights=(WeightsDescr if TYPE_CHECKING else dict)( 

3347 keras_hdf5=(w := src.weights.keras_hdf5) 

3348 and (KerasHdf5WeightsDescr if TYPE_CHECKING else dict)( 

3349 authors=conv_authors(w.authors), 

3350 source=w.source, 

3351 tensorflow_version=w.tensorflow_version or Version("1.15"), 

3352 parent=w.parent, 

3353 ), 

3354 onnx=(w := src.weights.onnx) 

3355 and (OnnxWeightsDescr if TYPE_CHECKING else dict)( 

3356 source=w.source, 

3357 authors=conv_authors(w.authors), 

3358 parent=w.parent, 

3359 opset_version=w.opset_version or 15, 

3360 ), 

3361 pytorch_state_dict=(w := src.weights.pytorch_state_dict) 

3362 and (PytorchStateDictWeightsDescr if TYPE_CHECKING else dict)( 

3363 source=w.source, 

3364 authors=conv_authors(w.authors), 

3365 parent=w.parent, 

3366 architecture=( 

3367 arch_file_conv( 

3368 w.architecture, 

3369 w.architecture_sha256, 

3370 w.kwargs, 

3371 ) 

3372 if isinstance(w.architecture, _CallableFromFile_v0_4) 

3373 else arch_lib_conv(w.architecture, w.kwargs) 

3374 ), 

3375 pytorch_version=w.pytorch_version or Version("1.10"), 

3376 dependencies=( 

3377 None 

3378 if w.dependencies is None 

3379 else (FileDescr if TYPE_CHECKING else dict)( 

3380 source=cast( 

3381 FileSource, 

3382 str(deps := w.dependencies)[ 

3383 ( 

3384 len("conda:") 

3385 if str(deps).startswith("conda:") 

3386 else 0 

3387 ) : 

3388 ], 

3389 ) 

3390 ) 

3391 ), 

3392 ), 

3393 tensorflow_js=(w := src.weights.tensorflow_js) 

3394 and (TensorflowJsWeightsDescr if TYPE_CHECKING else dict)( 

3395 source=w.source, 

3396 authors=conv_authors(w.authors), 

3397 parent=w.parent, 

3398 tensorflow_version=w.tensorflow_version or Version("1.15"), 

3399 ), 

3400 tensorflow_saved_model_bundle=( 

3401 w := src.weights.tensorflow_saved_model_bundle 

3402 ) 

3403 and (TensorflowSavedModelBundleWeightsDescr if TYPE_CHECKING else dict)( 

3404 authors=conv_authors(w.authors), 

3405 parent=w.parent, 

3406 source=w.source, 

3407 tensorflow_version=w.tensorflow_version or Version("1.15"), 

3408 dependencies=( 

3409 None 

3410 if w.dependencies is None 

3411 else (FileDescr if TYPE_CHECKING else dict)( 

3412 source=cast( 

3413 FileSource, 

3414 ( 

3415 str(w.dependencies)[len("conda:") :] 

3416 if str(w.dependencies).startswith("conda:") 

3417 else str(w.dependencies) 

3418 ), 

3419 ) 

3420 ) 

3421 ), 

3422 ), 

3423 torchscript=(w := src.weights.torchscript) 

3424 and (TorchscriptWeightsDescr if TYPE_CHECKING else dict)( 

3425 source=w.source, 

3426 authors=conv_authors(w.authors), 

3427 parent=w.parent, 

3428 pytorch_version=w.pytorch_version or Version("1.10"), 

3429 ), 

3430 ), 

3431 ) 

3432 

3433 

3434_model_conv = _ModelConv(_ModelDescr_v0_4, ModelDescr) 

3435 

3436 

3437# create better cover images for 3d data and non-image outputs
3438def generate_covers(
3439    inputs: Sequence[Tuple[InputTensorDescr, NDArray[Any]]],
3440    outputs: Sequence[Tuple[OutputTensorDescr, NDArray[Any]]],
3441) -> List[Path]:
3442    def squeeze(
3443        data: NDArray[Any], axes: Sequence[AnyAxis]
3444    ) -> Tuple[NDArray[Any], List[AnyAxis]]:
3445        """apply numpy.ndarray.squeeze while keeping track of the axis descriptions remaining"""
3446        if data.ndim != len(axes):
3447            raise ValueError(
3448                f"tensor shape {data.shape} does not match described axes"
3449                + f" {[a.id for a in axes]}"
3450            )
3451
3452        axes = [deepcopy(a) for a, s in zip(axes, data.shape) if s != 1]
3453        return data.squeeze(), axes
3454
3455    def normalize(
3456        data: NDArray[Any], axis: Optional[Tuple[int, ...]], eps: float = 1e-7
3457    ) -> NDArray[np.float32]:
3458        data = data.astype("float32")
3459        data -= data.min(axis=axis, keepdims=True)
3460        data /= data.max(axis=axis, keepdims=True) + eps
3461        return data
3462
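`normalize` is a plain min-max rescaling with a small `eps` guarding the division, so constant inputs map to zeros instead of NaNs. A quick numeric check of what it computes (plain numpy, illustration only):

import numpy as np

x = np.array([2.0, 4.0, 6.0], dtype="float32")
y = (x - x.min()) / (x.max() - x.min() + 1e-7)  # same arithmetic as normalize(x, axis=None)
# y is approximately [0.0, 0.5, 1.0]; the maximum stays just below 1 because of eps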

3463    def to_2d_image(data: NDArray[Any], axes: Sequence[AnyAxis]):
3464        original_shape = data.shape
3465        data, axes = squeeze(data, axes)
3466
3467        # take slice from any batch or index axis if needed
3468        # and convert the first channel axis; take a slice from any additional channel axes
3469        slices: Tuple[slice, ...] = ()
3470        ndim = data.ndim
3471        ndim_need = 3 if any(isinstance(a, ChannelAxis) for a in axes) else 2
3472        has_c_axis = False
3473        for i, a in enumerate(axes):
3474            s = data.shape[i]
3475            assert s > 1
3476            if (
3477                isinstance(a, (BatchAxis, IndexInputAxis, IndexOutputAxis))
3478                and ndim > ndim_need
3479            ):
3480                data = data[slices + (slice(s // 2 - 1, s // 2),)]
3481                ndim -= 1
3482            elif isinstance(a, ChannelAxis):
3483                if has_c_axis:
3484                    # second channel axis
3485                    data = data[slices + (slice(0, 1),)]
3486                    ndim -= 1
3487                else:
3488                    has_c_axis = True
3489                    if s == 2:
3490                        # visualize two channels with cyan and magenta
3491                        data = np.concatenate(
3492                            [
3493                                data[slices + (slice(1, 2),)],
3494                                data[slices + (slice(0, 1),)],
3495                                (
3496                                    data[slices + (slice(0, 1),)]
3497                                    + data[slices + (slice(1, 2),)]
3498                                )
3499                                / 2,  # TODO: take maximum instead?
3500                            ],
3501                            axis=i,
3502                        )
3503                    elif data.shape[i] == 3:
3504                        pass  # visualize 3 channels as RGB
3505                    else:
3506                        # visualize first 3 channels as RGB
3507                        data = data[slices + (slice(3),)]
3508
3509                    assert data.shape[i] == 3
3510
3511            slices += (slice(None),)
3512
3513        data, axes = squeeze(data, axes)
3514        assert len(axes) == ndim
3515        # take slice from z axis if needed
3516        slices = ()
3517        if ndim > ndim_need:
3518            for i, a in enumerate(axes):
3519                s = data.shape[i]
3520                if a.id == AxisId("z"):
3521                    data = data[slices + (slice(s // 2 - 1, s // 2),)]
3522                    data, axes = squeeze(data, axes)
3523                    ndim -= 1
3524                    break
3525
3526                slices += (slice(None),)
3527
3528        # take slice from any space or time axis
3529        slices = ()
3530
3531        for i, a in enumerate(axes):
3532            if ndim <= ndim_need:
3533                break
3534
3535            s = data.shape[i]
3536            assert s > 1
3537            if isinstance(
3538                a, (SpaceInputAxis, SpaceOutputAxis, TimeInputAxis, TimeOutputAxis)
3539            ):
3540                data = data[slices + (slice(s // 2 - 1, s // 2),)]
3541                ndim -= 1
3542
3543            slices += (slice(None),)
3544
3545        del slices
3546        data, axes = squeeze(data, axes)
3547        assert len(axes) == ndim
3548
3549        if (has_c_axis and ndim != 3) or (not has_c_axis and ndim != 2):
3550            raise ValueError(
3551                f"Failed to construct cover image from shape {original_shape}"
3552            )
3553
3554        if not has_c_axis:
3555            assert ndim == 2
3556            data = np.repeat(data[:, :, None], 3, axis=2)
3557            axes.append(ChannelAxis(channel_names=list(map(Identifier, "RGB"))))
3558            ndim += 1
3559
3560        assert ndim == 3
3561
3562        # transpose axis order such that longest axis comes first...
3563        axis_order: List[int] = list(np.argsort(list(data.shape)))
3564        axis_order.reverse()
3565        # ... and channel axis is last
3566        c = [i for i in range(3) if isinstance(axes[i], ChannelAxis)][0]
3567        axis_order.append(axis_order.pop(c))
3568        axes = [axes[ao] for ao in axis_order]
3569        data = data.transpose(axis_order)
3570
3571        # h, w = data.shape[:2]
3572        # if h / w in (1.0 or 2.0):
3573        #     pass
3574        # elif h / w < 2:
3575        # TODO: enforce 2:1 or 1:1 aspect ratio for generated cover images
3576
3577        norm_along = (
3578            tuple(i for i, a in enumerate(axes) if a.type in ("space", "time")) or None
3579        )
3580        # normalize the data and map to 8 bit
3581        data = normalize(data, norm_along)
3582        data = (data * 255).astype("uint8")
3583
3584        return data
3585
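Note that throughout to_2d_image, extra axes are reduced by taking a length-1 slice around the middle (`slice(s // 2 - 1, s // 2)`) rather than integer indexing, so the array keeps its dimensionality until the next `squeeze` call drops the size-1 axes together with their descriptions. A small illustration of that indexing (plain numpy, not library code):

import numpy as np

vol = np.zeros((8, 64, 64))             # e.g. a z-stack with 8 slices
s = vol.shape[0]
mid = vol[slice(s // 2 - 1, s // 2)]    # shape (1, 64, 64): middle slice, ndim preserved
assert mid.shape == (1, 64, 64)
assert np.squeeze(mid).shape == (64, 64)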

3586    def create_diagonal_split_image(im0: NDArray[Any], im1: NDArray[Any]):
3587        assert im0.dtype == im1.dtype == np.uint8
3588        assert im0.shape == im1.shape
3589        assert im0.ndim == 3
3590        N, M, C = im0.shape
3591        assert C == 3
3592        out = np.ones((N, M, C), dtype="uint8")
3593        for c in range(C):
3594            outc = np.tril(im0[..., c])
3595            mask = outc == 0
3596            outc[mask] = np.triu(im1[..., c])[mask]
3597            out[..., c] = outc
3598
3599        return out
3600
3601    ipt_descr, ipt = inputs[0]
3602    out_descr, out = outputs[0]
3603
3604    ipt_img = to_2d_image(ipt, ipt_descr.axes)
3605    out_img = to_2d_image(out, out_descr.axes)
3606
3607    cover_folder = Path(mkdtemp())
3608    if ipt_img.shape == out_img.shape:
3609        covers = [cover_folder / "cover.png"]
3610        imwrite(covers[0], create_diagonal_split_image(ipt_img, out_img))
3611    else:
3612        covers = [cover_folder / "input.png", cover_folder / "output.png"]
3613        imwrite(covers[0], ipt_img)
3614        imwrite(covers[1], out_img)
3615
3616    return covers
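The diagonal split used for the combined cover works per channel with np.tril/np.triu masks: the lower triangle keeps the input rendering, the upper triangle is filled from the output rendering. A toy check of that composition (shapes and values are made up for illustration):

import numpy as np

im0 = np.full((4, 4, 3), 10, dtype="uint8")    # stand-in "input" rendering
im1 = np.full((4, 4, 3), 200, dtype="uint8")   # stand-in "output" rendering

out = np.ones_like(im0)
for c in range(3):
    lower = np.tril(im0[..., c])               # keep input on and below the diagonal
    mask = lower == 0                          # positions above the diagonal
    lower[mask] = np.triu(im1[..., c])[mask]   # fill them from the output image
    out[..., c] = lower

assert out[3, 0, 0] == 10 and out[0, 3, 0] == 200

generate_covers then writes either a single diagonal-split cover.png (when both renderings have the same shape) or separate input.png/output.png files into a temporary directory and returns their paths.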