Coverage for bioimageio/spec/model/v0_5.py: 75%

1311 statements  

coverage.py v7.8.0, created at 2025-04-02 14:21 +0000

1from __future__ import annotations 

2 

3import collections.abc 

4import re 

5import string 

6import warnings 

7from abc import ABC 

8from copy import deepcopy 

9from itertools import chain 

10from math import ceil 

11from pathlib import Path, PurePosixPath 

12from tempfile import mkdtemp 

13from typing import ( 

14 TYPE_CHECKING, 

15 Any, 

16 ClassVar, 

17 Dict, 

18 Generic, 

19 List, 

20 Literal, 

21 Mapping, 

22 NamedTuple, 

23 Optional, 

24 Sequence, 

25 Set, 

26 Tuple, 

27 Type, 

28 TypeVar, 

29 Union, 

30 cast, 

31) 

32 

33import numpy as np 

34from annotated_types import Ge, Gt, Interval, MaxLen, MinLen, Predicate 

35from imageio.v3 import imread, imwrite # pyright: ignore[reportUnknownVariableType] 

36from loguru import logger 

37from numpy.typing import NDArray 

38from pydantic import ( 

39 AfterValidator, 

40 Discriminator, 

41 Field, 

42 RootModel, 

43 Tag, 

44 ValidationInfo, 

45 WrapSerializer, 

46 field_validator, 

47 model_validator, 

48) 

49from typing_extensions import Annotated, Self, assert_never, get_args 

50 

51from .._internal.common_nodes import ( 

52 InvalidDescr, 

53 Node, 

54 NodeWithExplicitlySetFields, 

55) 

56from .._internal.constants import DTYPE_LIMITS 

57from .._internal.field_warning import issue_warning, warn 

58from .._internal.io import BioimageioYamlContent as BioimageioYamlContent 

59from .._internal.io import FileDescr as FileDescr 

60from .._internal.io import WithSuffix, YamlValue, download 

61from .._internal.io_basics import AbsoluteFilePath as AbsoluteFilePath 

62from .._internal.io_basics import Sha256 as Sha256 

63from .._internal.io_utils import load_array 

64from .._internal.node_converter import Converter 

65from .._internal.types import ( 

66 AbsoluteTolerance, 

67 ImportantFileSource, 

68 LowerCaseIdentifier, 

69 LowerCaseIdentifierAnno, 

70 MismatchedElementsPerMillion, 

71 RelativeTolerance, 

72 SiUnit, 

73) 

74from .._internal.types import Datetime as Datetime 

75from .._internal.types import Identifier as Identifier 

76from .._internal.types import NotEmpty as NotEmpty 

77from .._internal.url import HttpUrl as HttpUrl 

78from .._internal.validation_context import get_validation_context 

79from .._internal.validator_annotations import RestrictCharacters 

80from .._internal.version_type import Version as Version 

81from .._internal.warning_levels import INFO 

82from ..dataset.v0_2 import DatasetDescr as DatasetDescr02 

83from ..dataset.v0_2 import LinkedDataset as LinkedDataset02 

84from ..dataset.v0_3 import DatasetDescr as DatasetDescr 

85from ..dataset.v0_3 import DatasetId as DatasetId 

86from ..dataset.v0_3 import LinkedDataset as LinkedDataset 

87from ..dataset.v0_3 import Uploader as Uploader 

88from ..generic.v0_3 import ( 

89 VALID_COVER_IMAGE_EXTENSIONS as VALID_COVER_IMAGE_EXTENSIONS, 

90) 

91from ..generic.v0_3 import Author as Author 

92from ..generic.v0_3 import BadgeDescr as BadgeDescr 

93from ..generic.v0_3 import CiteEntry as CiteEntry 

94from ..generic.v0_3 import DeprecatedLicenseId as DeprecatedLicenseId 

95from ..generic.v0_3 import ( 

96 DocumentationSource, 

97 GenericModelDescrBase, 

98 LinkedResourceBase, 

99 _author_conv, # pyright: ignore[reportPrivateUsage] 

100 _maintainer_conv, # pyright: ignore[reportPrivateUsage] 

101) 

102from ..generic.v0_3 import Doi as Doi 

103from ..generic.v0_3 import LicenseId as LicenseId 

104from ..generic.v0_3 import LinkedResource as LinkedResource 

105from ..generic.v0_3 import Maintainer as Maintainer 

106from ..generic.v0_3 import OrcidId as OrcidId 

107from ..generic.v0_3 import RelativeFilePath as RelativeFilePath 

108from ..generic.v0_3 import ResourceId as ResourceId 

109from .v0_4 import Author as _Author_v0_4 

110from .v0_4 import BinarizeDescr as _BinarizeDescr_v0_4 

111from .v0_4 import CallableFromDepencency as CallableFromDepencency 

112from .v0_4 import CallableFromDepencency as _CallableFromDepencency_v0_4 

113from .v0_4 import CallableFromFile as _CallableFromFile_v0_4 

114from .v0_4 import ClipDescr as _ClipDescr_v0_4 

115from .v0_4 import ClipKwargs as ClipKwargs 

116from .v0_4 import ImplicitOutputShape as _ImplicitOutputShape_v0_4 

117from .v0_4 import InputTensorDescr as _InputTensorDescr_v0_4 

118from .v0_4 import KnownRunMode as KnownRunMode 

119from .v0_4 import ModelDescr as _ModelDescr_v0_4 

120from .v0_4 import OutputTensorDescr as _OutputTensorDescr_v0_4 

121from .v0_4 import ParameterizedInputShape as _ParameterizedInputShape_v0_4 

122from .v0_4 import PostprocessingDescr as _PostprocessingDescr_v0_4 

123from .v0_4 import PreprocessingDescr as _PreprocessingDescr_v0_4 

124from .v0_4 import ProcessingKwargs as ProcessingKwargs 

125from .v0_4 import RunMode as RunMode 

126from .v0_4 import ScaleLinearDescr as _ScaleLinearDescr_v0_4 

127from .v0_4 import ScaleMeanVarianceDescr as _ScaleMeanVarianceDescr_v0_4 

128from .v0_4 import ScaleRangeDescr as _ScaleRangeDescr_v0_4 

129from .v0_4 import SigmoidDescr as _SigmoidDescr_v0_4 

130from .v0_4 import TensorName as _TensorName_v0_4 

131from .v0_4 import WeightsFormat as WeightsFormat 

132from .v0_4 import ZeroMeanUnitVarianceDescr as _ZeroMeanUnitVarianceDescr_v0_4 

133from .v0_4 import package_weights 

134 

135SpaceUnit = Literal[ 

136 "attometer", 

137 "angstrom", 

138 "centimeter", 

139 "decimeter", 

140 "exameter", 

141 "femtometer", 

142 "foot", 

143 "gigameter", 

144 "hectometer", 

145 "inch", 

146 "kilometer", 

147 "megameter", 

148 "meter", 

149 "micrometer", 

150 "mile", 

151 "millimeter", 

152 "nanometer", 

153 "parsec", 

154 "petameter", 

155 "picometer", 

156 "terameter", 

157 "yard", 

158 "yoctometer", 

159 "yottameter", 

160 "zeptometer", 

161 "zettameter", 

162] 

163"""Space unit compatible to the [OME-Zarr axes specification 0.5](https://ngff.openmicroscopy.org/0.5/#axes-md)""" 

164 

165TimeUnit = Literal[ 

166 "attosecond", 

167 "centisecond", 

168 "day", 

169 "decisecond", 

170 "exasecond", 

171 "femtosecond", 

172 "gigasecond", 

173 "hectosecond", 

174 "hour", 

175 "kilosecond", 

176 "megasecond", 

177 "microsecond", 

178 "millisecond", 

179 "minute", 

180 "nanosecond", 

181 "petasecond", 

182 "picosecond", 

183 "second", 

184 "terasecond", 

185 "yoctosecond", 

186 "yottasecond", 

187 "zeptosecond", 

188 "zettasecond", 

189] 

190"""Time unit compatible to the [OME-Zarr axes specification 0.5](https://ngff.openmicroscopy.org/0.5/#axes-md)""" 

191 

192AxisType = Literal["batch", "channel", "index", "time", "space"] 

193 

194_AXIS_TYPE_MAP: Mapping[str, AxisType] = { 

195 "b": "batch", 

196 "t": "time", 

197 "i": "index", 

198 "c": "channel", 

199 "x": "space", 

200 "y": "space", 

201 "z": "space", 

202} 

203 

204_AXIS_ID_MAP = { 

205 "b": "batch", 

206 "t": "time", 

207 "i": "index", 

208 "c": "channel", 

209} 

210 

211 

212class TensorId(LowerCaseIdentifier): 

213 root_model: ClassVar[Type[RootModel[Any]]] = RootModel[ 

214 Annotated[LowerCaseIdentifierAnno, MaxLen(32)] 

215 ] 

216 

217 

218def _normalize_axis_id(a: str): 

219 a = str(a) 

220 normalized = _AXIS_ID_MAP.get(a, a) 

221 if a != normalized: 

222 logger.opt(depth=3).warning( 

223 "Normalized axis id from '{}' to '{}'.", a, normalized 

224 ) 

225 return normalized 

226 

227 

228class AxisId(LowerCaseIdentifier): 

229 root_model: ClassVar[Type[RootModel[Any]]] = RootModel[ 

230 Annotated[ 

231 LowerCaseIdentifierAnno, 

232 MaxLen(16), 

233 AfterValidator(_normalize_axis_id), 

234 ] 

235 ] 

236 

237 

238def _is_batch(a: str) -> bool: 

239 return str(a) == "batch" 

240 

241 

242def _is_not_batch(a: str) -> bool: 

243 return not _is_batch(a) 

244 

245 

246NonBatchAxisId = Annotated[AxisId, Predicate(_is_not_batch)] 

247 

248PostprocessingId = Literal[ 

249 "binarize", 

250 "clip", 

251 "ensure_dtype", 

252 "fixed_zero_mean_unit_variance", 

253 "scale_linear", 

254 "scale_mean_variance", 

255 "scale_range", 

256 "sigmoid", 

257 "zero_mean_unit_variance", 

258] 

259PreprocessingId = Literal[ 

260 "binarize", 

261 "clip", 

262 "ensure_dtype", 

263 "scale_linear", 

264 "sigmoid", 

265 "zero_mean_unit_variance", 

266 "scale_range", 

267] 

268 

269 

270SAME_AS_TYPE = "<same as type>" 

271 

272 

273ParameterizedSize_N = int 

274""" 

275Annotates an integer to calculate a concrete axis size from a `ParameterizedSize`. 

276""" 

277 

278 

279class ParameterizedSize(Node): 

280 """Describes a range of valid tensor axis sizes as `size = min + n*step`. 

281 

282 - **min** and **step** are given by the model description. 

283 - All blocksize parameters n = 0,1,2,... yield a valid `size`.

284 - A greater blocksize parameter n = 0,1,2,... results in a greater **size**.

285 This allows adjusting the axis size more generically.

286 """ 

287 

288 N: ClassVar[Type[int]] = ParameterizedSize_N 

289 """Positive integer to parameterize this axis""" 

290 

291 min: Annotated[int, Gt(0)] 

292 step: Annotated[int, Gt(0)] 

293 

294 def validate_size(self, size: int) -> int: 

295 if size < self.min: 

296 raise ValueError(f"size {size} < {self.min}") 

297 if (size - self.min) % self.step != 0: 

298 raise ValueError( 

299 f"axis of size {size} is not parameterized by `min + n*step` =" 

300 + f" `{self.min} + n*{self.step}`" 

301 ) 

302 

303 return size 

304 

305 def get_size(self, n: ParameterizedSize_N) -> int: 

306 return self.min + self.step * n 

307 

308 def get_n(self, s: int) -> ParameterizedSize_N: 

309 """return smallest n parameterizing a size greater or equal than `s`""" 

310 return ceil((s - self.min) / self.step) 

311 
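# --- Editor's example (illustrative sketch, not part of the original source) ---
# The relation implemented above is `size = min + n*step`; with min=32 and
# step=16 the valid sizes are 32, 48, 64, ...:
# >>> p = ParameterizedSize(min=32, step=16)
# >>> p.get_size(2)
# 64
# >>> p.get_n(50)          # smallest n such that min + n*step >= 50
# 2
# >>> p.validate_size(48)  # 48 == 32 + 1*16
# 48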

312 

313class DataDependentSize(Node): 

314 min: Annotated[int, Gt(0)] = 1 

315 max: Annotated[Optional[int], Gt(1)] = None 

316 

317 @model_validator(mode="after") 

318 def _validate_max_gt_min(self): 

319 if self.max is not None and self.min >= self.max: 

320 raise ValueError(f"expected `min` < `max`, but got {self.min}, {self.max}") 

321 

322 return self 

323 

324 def validate_size(self, size: int) -> int: 

325 if size < self.min: 

326 raise ValueError(f"size {size} < {self.min}") 

327 

328 if self.max is not None and size > self.max: 

329 raise ValueError(f"size {size} > {self.max}") 

330 

331 return size 

332 
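# --- Editor's example (illustrative sketch, not part of the original source) ---
# A data dependent size only constrains the admissible range; the concrete size
# is only known after inference:
# >>> DataDependentSize(min=1, max=512).validate_size(100)
# 100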

333 

334class SizeReference(Node): 

335 """A tensor axis size (extent in pixels/frames) defined in relation to a reference axis. 

336 

337 `axis.size = reference.size * reference.scale / axis.scale + offset` 

338 

339 Note: 

340 1. The axis and the referenced axis need to have the same unit (or no unit). 

341 2. Batch axes may not be referenced. 

342 3. Fractions are rounded down. 

343 4. If the reference axis is `concatenable` the referencing axis is assumed to be 

344 `concatenable` as well with the same block order. 

345 

346 Example: 

347 An anisotropic input image of w*h=100*49 pixels depicts a physical space of 200*196mm².

348 Let's assume that we want to express the image height h in relation to its width w 

349 instead of only accepting input images of exactly 100*49 pixels 

350 (for example to express a range of valid image shapes by parametrizing w, see `ParameterizedSize`). 

351 

352 >>> w = SpaceInputAxis(id=AxisId("w"), size=100, unit="millimeter", scale=2) 

353 >>> h = SpaceInputAxis( 

354 ... id=AxisId("h"), 

355 ... size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("w"), offset=-1), 

356 ... unit="millimeter", 

357 ... scale=4, 

358 ... ) 

359 >>> print(h.size.get_size(h, w)) 

360 49 

361 

362 ⇒ h = w * w.scale / h.scale + offset = 100 * 2mm / 4mm - 1 = 49 

363 """ 

364 

365 tensor_id: TensorId 

366 """tensor id of the reference axis""" 

367 

368 axis_id: AxisId 

369 """axis id of the reference axis""" 

370 

371 offset: int = 0 

372 

373 def get_size( 

374 self, 

375 axis: Union[ 

376 ChannelAxis, 

377 IndexInputAxis, 

378 IndexOutputAxis, 

379 TimeInputAxis, 

380 SpaceInputAxis, 

381 TimeOutputAxis, 

382 TimeOutputAxisWithHalo, 

383 SpaceOutputAxis, 

384 SpaceOutputAxisWithHalo, 

385 ], 

386 ref_axis: Union[ 

387 ChannelAxis, 

388 IndexInputAxis, 

389 IndexOutputAxis, 

390 TimeInputAxis, 

391 SpaceInputAxis, 

392 TimeOutputAxis, 

393 TimeOutputAxisWithHalo, 

394 SpaceOutputAxis, 

395 SpaceOutputAxisWithHalo, 

396 ], 

397 n: ParameterizedSize_N = 0, 

398 ref_size: Optional[int] = None, 

399 ): 

400 """Compute the concrete size for a given axis and its reference axis. 

401 

402 Args: 

403 axis: The axis this `SizeReference` is the size of. 

404 ref_axis: The reference axis to compute the size from. 

405 n: If the **ref_axis** is parameterized (of type `ParameterizedSize`) 

406 and no fixed **ref_size** is given, 

407 **n** is used to compute the size of the parameterized **ref_axis**. 

408 ref_size: Overwrite the reference size instead of deriving it from 

409 **ref_axis** 

410 (**ref_axis.scale** is still used; any given **n** is ignored). 

411 """ 

412 assert ( 

413 axis.size == self 

414 ), "Given `axis.size` is not defined by this `SizeReference`" 

415 

416 assert ( 

417 ref_axis.id == self.axis_id 

418 ), f"Expected `ref_axis.id` to be {self.axis_id}, but got {ref_axis.id}." 

419 

420 assert axis.unit == ref_axis.unit, ( 

421 "`SizeReference` requires `axis` and `ref_axis` to have the same `unit`," 

422 f" but {axis.unit}!={ref_axis.unit}" 

423 ) 

424 if ref_size is None: 

425 if isinstance(ref_axis.size, (int, float)): 

426 ref_size = ref_axis.size 

427 elif isinstance(ref_axis.size, ParameterizedSize): 

428 ref_size = ref_axis.size.get_size(n) 

429 elif isinstance(ref_axis.size, DataDependentSize): 

430 raise ValueError( 

431 "Reference axis referenced in `SizeReference` may not be a `DataDependentSize`." 

432 ) 

433 elif isinstance(ref_axis.size, SizeReference): 

434 raise ValueError( 

435 "Reference axis referenced in `SizeReference` may not be sized by a" 

436 + " `SizeReference` itself." 

437 ) 

438 else: 

439 assert_never(ref_axis.size) 

440 

441 return int(ref_size * ref_axis.scale / axis.scale + self.offset) 

442 

443 @staticmethod 

444 def _get_unit( 

445 axis: Union[ 

446 ChannelAxis, 

447 IndexInputAxis, 

448 IndexOutputAxis, 

449 TimeInputAxis, 

450 SpaceInputAxis, 

451 TimeOutputAxis, 

452 TimeOutputAxisWithHalo, 

453 SpaceOutputAxis, 

454 SpaceOutputAxisWithHalo, 

455 ], 

456 ): 

457 return axis.unit 

458 
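# --- Editor's example (illustrative sketch, not part of the original source) ---
# Complementing the doctest in the class docstring: if the reference axis is
# parameterized, `n` selects its concrete size before the scale ratio and
# offset are applied (axis ids and sizes below are made up for illustration):
# >>> w = SpaceInputAxis(id=AxisId("w"), size=ParameterizedSize(min=32, step=16), scale=2)
# >>> h = SpaceInputAxis(
# ...     id=AxisId("h"),
# ...     size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("w")),
# ...     scale=4,
# ... )
# >>> h.size.get_size(h, w, n=3)  # ref_size = 32 + 3*16 = 80; 80 * 2 / 4 + 0 = 40
# 40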

459 

460class AxisBase(NodeWithExplicitlySetFields): 

461 id: AxisId 

462 """An axis id unique across all axes of one tensor.""" 

463 

464 description: Annotated[str, MaxLen(128)] = "" 

465 

466 

467class WithHalo(Node): 

468 halo: Annotated[int, Ge(1)] 

469 """The halo should be cropped from the output tensor to avoid boundary effects. 

470 It is to be cropped from both sides, i.e. `size_after_crop = size - 2 * halo`. 

471 To document a halo that is already cropped by the model use `size.offset` instead.""" 

472 

473 size: Annotated[ 

474 SizeReference, 

475 Field( 

476 examples=[ 

477 10, 

478 SizeReference( 

479 tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5 

480 ).model_dump(mode="json"), 

481 ] 

482 ), 

483 ] 

484 """reference to another axis with an optional offset (see `SizeReference`)""" 

485 

486 

487BATCH_AXIS_ID = AxisId("batch") 

488 

489 

490class BatchAxis(AxisBase): 

491 implemented_type: ClassVar[Literal["batch"]] = "batch" 

492 if TYPE_CHECKING: 

493 type: Literal["batch"] = "batch" 

494 else: 

495 type: Literal["batch"] 

496 

497 id: Annotated[AxisId, Predicate(_is_batch)] = BATCH_AXIS_ID 

498 size: Optional[Literal[1]] = None 

499 """The batch size may be fixed to 1, 

500 otherwise (the default) it may be chosen arbitrarily depending on available memory""" 

501 

502 @property 

503 def scale(self): 

504 return 1.0 

505 

506 @property 

507 def concatenable(self): 

508 return True 

509 

510 @property 

511 def unit(self): 

512 return None 

513 

514 

515class ChannelAxis(AxisBase): 

516 implemented_type: ClassVar[Literal["channel"]] = "channel" 

517 if TYPE_CHECKING: 

518 type: Literal["channel"] = "channel" 

519 else: 

520 type: Literal["channel"] 

521 

522 id: NonBatchAxisId = AxisId("channel") 

523 channel_names: NotEmpty[List[Identifier]] 

524 

525 @property 

526 def size(self) -> int: 

527 return len(self.channel_names) 

528 

529 @property 

530 def concatenable(self): 

531 return False 

532 

533 @property 

534 def scale(self) -> float: 

535 return 1.0 

536 

537 @property 

538 def unit(self): 

539 return None 

540 
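# --- Editor's example (illustrative sketch, not part of the original source) ---
# A channel axis derives its size from the number of channel names
# (the names below are made up for illustration):
# >>> rgb = ChannelAxis(channel_names=[Identifier("r"), Identifier("g"), Identifier("b")])
# >>> rgb.size
# 3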

541 

542class IndexAxisBase(AxisBase): 

543 implemented_type: ClassVar[Literal["index"]] = "index" 

544 if TYPE_CHECKING: 

545 type: Literal["index"] = "index" 

546 else: 

547 type: Literal["index"] 

548 

549 id: NonBatchAxisId = AxisId("index") 

550 

551 @property 

552 def scale(self) -> float: 

553 return 1.0 

554 

555 @property 

556 def unit(self): 

557 return None 

558 

559 

560class _WithInputAxisSize(Node): 

561 size: Annotated[ 

562 Union[Annotated[int, Gt(0)], ParameterizedSize, SizeReference], 

563 Field( 

564 examples=[ 

565 10, 

566 ParameterizedSize(min=32, step=16).model_dump(mode="json"), 

567 SizeReference( 

568 tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5 

569 ).model_dump(mode="json"), 

570 ] 

571 ), 

572 ] 

573 """The size/length of this axis can be specified as 

574 - fixed integer 

575 - parameterized series of valid sizes (`ParameterizedSize`) 

576 - reference to another axis with an optional offset (`SizeReference`) 

577 """ 

578 

579 

580class IndexInputAxis(IndexAxisBase, _WithInputAxisSize): 

581 concatenable: bool = False 

582 """If a model has a `concatenable` input axis, it can be processed blockwise, 

583 splitting a longer sample axis into blocks matching its input tensor description. 

584 Output axes are concatenable if they have a `SizeReference` to a concatenable 

585 input axis. 

586 """ 

587 

588 

589class IndexOutputAxis(IndexAxisBase): 

590 size: Annotated[ 

591 Union[Annotated[int, Gt(0)], SizeReference, DataDependentSize], 

592 Field( 

593 examples=[ 

594 10, 

595 SizeReference( 

596 tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5 

597 ).model_dump(mode="json"), 

598 ] 

599 ), 

600 ] 

601 """The size/length of this axis can be specified as 

602 - fixed integer 

603 - reference to another axis with an optional offset (`SizeReference`) 

604 - data dependent size using `DataDependentSize` (size is only known after model inference) 

605 """ 

606 

607 

608class TimeAxisBase(AxisBase): 

609 implemented_type: ClassVar[Literal["time"]] = "time" 

610 if TYPE_CHECKING: 

611 type: Literal["time"] = "time" 

612 else: 

613 type: Literal["time"] 

614 

615 id: NonBatchAxisId = AxisId("time") 

616 unit: Optional[TimeUnit] = None 

617 scale: Annotated[float, Gt(0)] = 1.0 

618 

619 

620class TimeInputAxis(TimeAxisBase, _WithInputAxisSize): 

621 concatenable: bool = False 

622 """If a model has a `concatenable` input axis, it can be processed blockwise, 

623 splitting a longer sample axis into blocks matching its input tensor description. 

624 Output axes are concatenable if they have a `SizeReference` to a concatenable 

625 input axis. 

626 """ 

627 

628 

629class SpaceAxisBase(AxisBase): 

630 implemented_type: ClassVar[Literal["space"]] = "space" 

631 if TYPE_CHECKING: 

632 type: Literal["space"] = "space" 

633 else: 

634 type: Literal["space"] 

635 

636 id: Annotated[NonBatchAxisId, Field(examples=["x", "y", "z"])] = AxisId("x") 

637 unit: Optional[SpaceUnit] = None 

638 scale: Annotated[float, Gt(0)] = 1.0 

639 

640 

641class SpaceInputAxis(SpaceAxisBase, _WithInputAxisSize): 

642 concatenable: bool = False 

643 """If a model has a `concatenable` input axis, it can be processed blockwise, 

644 splitting a longer sample axis into blocks matching its input tensor description. 

645 Output axes are concatenable if they have a `SizeReference` to a concatenable 

646 input axis. 

647 """ 

648 

649 

650INPUT_AXIS_TYPES = ( 

651 BatchAxis, 

652 ChannelAxis, 

653 IndexInputAxis, 

654 TimeInputAxis, 

655 SpaceInputAxis, 

656) 

657"""intended for isinstance comparisons in py<3.10""" 

658 

659_InputAxisUnion = Union[ 

660 BatchAxis, ChannelAxis, IndexInputAxis, TimeInputAxis, SpaceInputAxis 

661] 

662InputAxis = Annotated[_InputAxisUnion, Discriminator("type")] 

663 

664 

665class _WithOutputAxisSize(Node): 

666 size: Annotated[ 

667 Union[Annotated[int, Gt(0)], SizeReference], 

668 Field( 

669 examples=[ 

670 10, 

671 SizeReference( 

672 tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5 

673 ).model_dump(mode="json"), 

674 ] 

675 ), 

676 ] 

677 """The size/length of this axis can be specified as 

678 - fixed integer 

679 - reference to another axis with an optional offset (see `SizeReference`) 

680 """ 

681 

682 

683class TimeOutputAxis(TimeAxisBase, _WithOutputAxisSize): 

684 pass 

685 

686 

687class TimeOutputAxisWithHalo(TimeAxisBase, WithHalo): 

688 pass 

689 

690 

691def _get_halo_axis_discriminator_value(v: Any) -> Literal["with_halo", "wo_halo"]: 

692 if isinstance(v, dict): 

693 return "with_halo" if "halo" in v else "wo_halo" 

694 else: 

695 return "with_halo" if hasattr(v, "halo") else "wo_halo" 

696 
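# --- Editor's example (illustrative sketch, not part of the original source) ---
# The discriminator above selects the tagged union member by the presence of a
# `halo` key (for dicts) or attribute (for objects):
# >>> _get_halo_axis_discriminator_value({"size": 128, "halo": 16})
# 'with_halo'
# >>> _get_halo_axis_discriminator_value({"size": 128})
# 'wo_halo'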

697 

698_TimeOutputAxisUnion = Annotated[ 

699 Union[ 

700 Annotated[TimeOutputAxis, Tag("wo_halo")], 

701 Annotated[TimeOutputAxisWithHalo, Tag("with_halo")], 

702 ], 

703 Discriminator(_get_halo_axis_discriminator_value), 

704] 

705 

706 

707class SpaceOutputAxis(SpaceAxisBase, _WithOutputAxisSize): 

708 pass 

709 

710 

711class SpaceOutputAxisWithHalo(SpaceAxisBase, WithHalo): 

712 pass 

713 
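# --- Editor's example (illustrative sketch, not part of the original source) ---
# With `halo=16`, 2*16 = 32 boundary entries (16 per side) are to be cropped
# from this output axis (tensor/axis ids below are made up for illustration):
# >>> y_out = SpaceOutputAxisWithHalo(
# ...     id=AxisId("y"),
# ...     size=SizeReference(tensor_id=TensorId("raw"), axis_id=AxisId("y")),
# ...     halo=16,
# ... )
# >>> y_out.halo
# 16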

714 

715_SpaceOutputAxisUnion = Annotated[ 

716 Union[ 

717 Annotated[SpaceOutputAxis, Tag("wo_halo")], 

718 Annotated[SpaceOutputAxisWithHalo, Tag("with_halo")], 

719 ], 

720 Discriminator(_get_halo_axis_discriminator_value), 

721] 

722 

723 

724_OutputAxisUnion = Union[ 

725 BatchAxis, ChannelAxis, IndexOutputAxis, _TimeOutputAxisUnion, _SpaceOutputAxisUnion 

726] 

727OutputAxis = Annotated[_OutputAxisUnion, Discriminator("type")] 

728 

729OUTPUT_AXIS_TYPES = ( 

730 BatchAxis, 

731 ChannelAxis, 

732 IndexOutputAxis, 

733 TimeOutputAxis, 

734 TimeOutputAxisWithHalo, 

735 SpaceOutputAxis, 

736 SpaceOutputAxisWithHalo, 

737) 

738"""intended for isinstance comparisons in py<3.10""" 

739 

740 

741AnyAxis = Union[InputAxis, OutputAxis] 

742 

743ANY_AXIS_TYPES = INPUT_AXIS_TYPES + OUTPUT_AXIS_TYPES 

744"""intended for isinstance comparisons in py<3.10""" 

745 

746TVs = Union[ 

747 NotEmpty[List[int]], 

748 NotEmpty[List[float]], 

749 NotEmpty[List[bool]], 

750 NotEmpty[List[str]], 

751] 

752 

753 

754NominalOrOrdinalDType = Literal[ 

755 "float32", 

756 "float64", 

757 "uint8", 

758 "int8", 

759 "uint16", 

760 "int16", 

761 "uint32", 

762 "int32", 

763 "uint64", 

764 "int64", 

765 "bool", 

766] 

767 

768 

769class NominalOrOrdinalDataDescr(Node): 

770 values: TVs 

771 """A fixed set of nominal or an ascending sequence of ordinal values. 

772 In this case `data.type` is required to be an unsigned integer type, e.g. 'uint8'.

773 String `values` are interpreted as labels for tensor values 0, ..., N. 

774 Note: as YAML 1.2 does not natively support a "set" datatype, 

775 nominal values should be given as a sequence (aka list/array) as well. 

776 """ 

777 

778 type: Annotated[ 

779 NominalOrOrdinalDType, 

780 Field( 

781 examples=[ 

782 "float32", 

783 "uint8", 

784 "uint16", 

785 "int64", 

786 "bool", 

787 ], 

788 ), 

789 ] = "uint8" 

790 

791 @model_validator(mode="after") 

792 def _validate_values_match_type( 

793 self, 

794 ) -> Self: 

795 incompatible: List[Any] = [] 

796 for v in self.values: 

797 if self.type == "bool": 

798 if not isinstance(v, bool): 

799 incompatible.append(v) 

800 elif self.type in DTYPE_LIMITS: 

801 if ( 

802 isinstance(v, (int, float)) 

803 and ( 

804 v < DTYPE_LIMITS[self.type].min 

805 or v > DTYPE_LIMITS[self.type].max 

806 ) 

807 or (isinstance(v, str) and "uint" not in self.type) 

808 or (isinstance(v, float) and "int" in self.type) 

809 ): 

810 incompatible.append(v) 

811 else: 

812 incompatible.append(v) 

813 

814 if len(incompatible) == 5: 

815 incompatible.append("...") 

816 break 

817 

818 if incompatible: 

819 raise ValueError( 

820 f"data type '{self.type}' incompatible with values {incompatible}" 

821 ) 

822 

823 return self 

824 

825 unit: Optional[Union[Literal["arbitrary unit"], SiUnit]] = None 

826 

827 @property 

828 def range(self): 

829 if isinstance(self.values[0], str): 

830 return 0, len(self.values) - 1 

831 else: 

832 return min(self.values), max(self.values) 

833 
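# --- Editor's example (illustrative sketch, not part of the original source) ---
# String `values` act as labels for tensor values 0..N-1, so `range` reports
# label indices; numeric `values` report their own min/max:
# >>> NominalOrOrdinalDataDescr(values=["background", "cell", "border"], type="uint8").range
# (0, 2)
# >>> NominalOrOrdinalDataDescr(values=[0, 1, 5], type="uint8").range
# (0, 5)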

834 

835IntervalOrRatioDType = Literal[ 

836 "float32", 

837 "float64", 

838 "uint8", 

839 "int8", 

840 "uint16", 

841 "int16", 

842 "uint32", 

843 "int32", 

844 "uint64", 

845 "int64", 

846] 

847 

848 

849class IntervalOrRatioDataDescr(Node): 

850 type: Annotated[ # todo: rename to dtype 

851 IntervalOrRatioDType, 

852 Field( 

853 examples=["float32", "float64", "uint8", "uint16"], 

854 ), 

855 ] = "float32" 

856 range: Tuple[Optional[float], Optional[float]] = ( 

857 None, 

858 None, 

859 ) 

860 """Tuple `(minimum, maximum)` specifying the allowed range of the data in this tensor. 

861 `None` corresponds to min/max of what can be expressed by **type**.""" 

862 unit: Union[Literal["arbitrary unit"], SiUnit] = "arbitrary unit" 

863 scale: float = 1.0 

864 """Scale for data on an interval (or ratio) scale.""" 

865 offset: Optional[float] = None 

866 """Offset for data on a ratio scale.""" 

867 
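# --- Editor's example (illustrative sketch, not part of the original source) ---
# A float32 tensor whose values are expected to lie in [0, 1]:
# >>> IntervalOrRatioDataDescr(type="float32", range=(0.0, 1.0)).range
# (0.0, 1.0)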

868 

869TensorDataDescr = Union[NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr] 

870 

871 

872class ProcessingDescrBase(NodeWithExplicitlySetFields, ABC): 

873 """processing base class""" 

874 

875 

876class BinarizeKwargs(ProcessingKwargs): 

877 """key word arguments for `BinarizeDescr`""" 

878 

879 threshold: float 

880 """The fixed threshold""" 

881 

882 

883class BinarizeAlongAxisKwargs(ProcessingKwargs): 

884 """key word arguments for `BinarizeDescr`""" 

885 

886 threshold: NotEmpty[List[float]] 

887 """The fixed threshold values along `axis`""" 

888 

889 axis: Annotated[NonBatchAxisId, Field(examples=["channel"])] 

890 """The `threshold` axis""" 

891 

892 

893class BinarizeDescr(ProcessingDescrBase): 

894 """Binarize the tensor with a fixed threshold. 

895 

896 Values above `BinarizeKwargs.threshold`/`BinarizeAlongAxisKwargs.threshold` 

897 will be set to one, values below the threshold to zero. 

898 

899 Examples: 

900 - in YAML 

901 ```yaml 

902 postprocessing: 

903 - id: binarize 

904 kwargs: 

905 axis: 'channel' 

906 threshold: [0.25, 0.5, 0.75] 

907 ``` 

908 - in Python: 

909 >>> postprocessing = [BinarizeDescr( 

910 ... kwargs=BinarizeAlongAxisKwargs( 

911 ... axis=AxisId('channel'), 

912 ... threshold=[0.25, 0.5, 0.75], 

913 ... ) 

914 ... )] 

915 """ 

916 

917 implemented_id: ClassVar[Literal["binarize"]] = "binarize" 

918 if TYPE_CHECKING: 

919 id: Literal["binarize"] = "binarize" 

920 else: 

921 id: Literal["binarize"] 

922 kwargs: Union[BinarizeKwargs, BinarizeAlongAxisKwargs] 

923 

924 

925class ClipDescr(ProcessingDescrBase): 

926 """Set tensor values below min to min and above max to max. 

927 

928 See `ScaleRangeDescr` for examples. 

929 """ 

930 

931 implemented_id: ClassVar[Literal["clip"]] = "clip" 

932 if TYPE_CHECKING: 

933 id: Literal["clip"] = "clip" 

934 else: 

935 id: Literal["clip"] 

936 

937 kwargs: ClipKwargs 

938 

939 

940class EnsureDtypeKwargs(ProcessingKwargs): 

941 """key word arguments for `EnsureDtypeDescr`""" 

942 

943 dtype: Literal[ 

944 "float32", 

945 "float64", 

946 "uint8", 

947 "int8", 

948 "uint16", 

949 "int16", 

950 "uint32", 

951 "int32", 

952 "uint64", 

953 "int64", 

954 "bool", 

955 ] 

956 

957 

958class EnsureDtypeDescr(ProcessingDescrBase): 

959 """Cast the tensor data type to `EnsureDtypeKwargs.dtype` (if not matching). 

960 

961 This can for example be used to ensure the inner neural network model gets a 

962 different input tensor data type than the fully described bioimage.io model does. 

963 

964 Examples: 

965 The described bioimage.io model (incl. preprocessing) accepts any 

966 float32-compatible tensor, normalizes it with percentiles and clipping and then 

967 casts it to uint8, which is what the neural network in this example expects. 

968 - in YAML 

969 ```yaml 

970 inputs: 

971 - data: 

972 type: float32 # described bioimage.io model is compatible with any float32 input tensor 

973 preprocessing: 

974 - id: scale_range 

975 kwargs: 

976 axes: ['y', 'x'] 

977 max_percentile: 99.8 

978 min_percentile: 5.0 

979 - id: clip 

980 kwargs: 

981 min: 0.0 

982 max: 1.0 

983 - id: ensure_dtype 

984 kwargs: 

985 dtype: uint8 

986 ``` 

987 - in Python: 

988 >>> preprocessing = [ 

989 ... ScaleRangeDescr( 

990 ... kwargs=ScaleRangeKwargs( 

991 ... axes= (AxisId('y'), AxisId('x')), 

992 ... max_percentile= 99.8, 

993 ... min_percentile= 5.0, 

994 ... ) 

995 ... ), 

996 ... ClipDescr(kwargs=ClipKwargs(min=0.0, max=1.0)), 

997 ... EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="uint8")), 

998 ... ] 

999 """ 

1000 

1001 implemented_id: ClassVar[Literal["ensure_dtype"]] = "ensure_dtype" 

1002 if TYPE_CHECKING: 

1003 id: Literal["ensure_dtype"] = "ensure_dtype" 

1004 else: 

1005 id: Literal["ensure_dtype"] 

1006 

1007 kwargs: EnsureDtypeKwargs 

1008 

1009 

1010class ScaleLinearKwargs(ProcessingKwargs): 

1011 """Key word arguments for `ScaleLinearDescr`""" 

1012 

1013 gain: float = 1.0 

1014 """multiplicative factor""" 

1015 

1016 offset: float = 0.0 

1017 """additive term""" 

1018 

1019 @model_validator(mode="after") 

1020 def _validate(self) -> Self: 

1021 if self.gain == 1.0 and self.offset == 0.0: 

1022 raise ValueError( 

1023 "Redundant linear scaling not allowd. Set `gain` != 1.0 and/or `offset`" 

1024 + " != 0.0." 

1025 ) 

1026 

1027 return self 

1028 

1029 

1030class ScaleLinearAlongAxisKwargs(ProcessingKwargs): 

1031 """Key word arguments for `ScaleLinearDescr`""" 

1032 

1033 axis: Annotated[NonBatchAxisId, Field(examples=["channel"])] 

1034 """The axis of gain and offset values.""" 

1035 

1036 gain: Union[float, NotEmpty[List[float]]] = 1.0 

1037 """multiplicative factor""" 

1038 

1039 offset: Union[float, NotEmpty[List[float]]] = 0.0 

1040 """additive term""" 

1041 

1042 @model_validator(mode="after") 

1043 def _validate(self) -> Self: 

1044 

1045 if isinstance(self.gain, list): 

1046 if isinstance(self.offset, list): 

1047 if len(self.gain) != len(self.offset): 

1048 raise ValueError( 

1049 f"Size of `gain` ({len(self.gain)}) and `offset` ({len(self.offset)}) must match." 

1050 ) 

1051 else: 

1052 self.offset = [float(self.offset)] * len(self.gain) 

1053 elif isinstance(self.offset, list): 

1054 self.gain = [float(self.gain)] * len(self.offset) 

1055 else: 

1056 raise ValueError( 

1057 "Do not specify an `axis` for scalar gain and offset values." 

1058 ) 

1059 

1060 if all(g == 1.0 for g in self.gain) and all(off == 0.0 for off in self.offset): 

1061 raise ValueError( 

1062 "Redundant linear scaling not allowd. Set `gain` != 1.0 and/or `offset`" 

1063 + " != 0.0." 

1064 ) 

1065 

1066 return self 

1067 
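# --- Editor's example (illustrative sketch, not part of the original source) ---
# The validator above broadcasts a scalar `offset` (or `gain`) to the length of
# the per-axis list:
# >>> kw = ScaleLinearAlongAxisKwargs(axis=AxisId("channel"), gain=[1.0, 2.0, 3.0], offset=0.5)
# >>> kw.offset
# [0.5, 0.5, 0.5]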

1068 

1069class ScaleLinearDescr(ProcessingDescrBase): 

1070 """Fixed linear scaling. 

1071 

1072 Examples: 

1073 1. Scale with scalar gain and offset 

1074 - in YAML 

1075 ```yaml 

1076 preprocessing: 

1077 - id: scale_linear 

1078 kwargs: 

1079 gain: 2.0 

1080 offset: 3.0 

1081 ``` 

1082 - in Python: 

1083 >>> preprocessing = [ 

1084 ... ScaleLinearDescr(kwargs=ScaleLinearKwargs(gain= 2.0, offset=3.0)) 

1085 ... ] 

1086 

1087 2. Independent scaling along an axis 

1088 - in YAML 

1089 ```yaml 

1090 preprocessing: 

1091 - id: scale_linear 

1092 kwargs: 

1093 axis: 'channel' 

1094 gain: [1.0, 2.0, 3.0] 

1095 ``` 

1096 - in Python: 

1097 >>> preprocessing = [ 

1098 ... ScaleLinearDescr( 

1099 ... kwargs=ScaleLinearAlongAxisKwargs( 

1100 ... axis=AxisId("channel"), 

1101 ... gain=[1.0, 2.0, 3.0], 

1102 ... ) 

1103 ... ) 

1104 ... ] 

1105 

1106 """ 

1107 

1108 implemented_id: ClassVar[Literal["scale_linear"]] = "scale_linear" 

1109 if TYPE_CHECKING: 

1110 id: Literal["scale_linear"] = "scale_linear" 

1111 else: 

1112 id: Literal["scale_linear"] 

1113 kwargs: Union[ScaleLinearKwargs, ScaleLinearAlongAxisKwargs] 

1114 

1115 

1116class SigmoidDescr(ProcessingDescrBase): 

1117 """The logistic sigmoid funciton, a.k.a. expit function. 

1118 

1119 Examples: 

1120 - in YAML 

1121 ```yaml 

1122 postprocessing: 

1123 - id: sigmoid 

1124 ``` 

1125 - in Python: 

1126 >>> postprocessing = [SigmoidDescr()] 

1127 """ 

1128 

1129 implemented_id: ClassVar[Literal["sigmoid"]] = "sigmoid" 

1130 if TYPE_CHECKING: 

1131 id: Literal["sigmoid"] = "sigmoid" 

1132 else: 

1133 id: Literal["sigmoid"] 

1134 

1135 @property 

1136 def kwargs(self) -> ProcessingKwargs: 

1137 """empty kwargs""" 

1138 return ProcessingKwargs() 

1139 

1140 

1141class FixedZeroMeanUnitVarianceKwargs(ProcessingKwargs): 

1142 """key word arguments for `FixedZeroMeanUnitVarianceDescr`""" 

1143 

1144 mean: float 

1145 """The mean value to normalize with.""" 

1146 

1147 std: Annotated[float, Ge(1e-6)] 

1148 """The standard deviation value to normalize with.""" 

1149 

1150 

1151class FixedZeroMeanUnitVarianceAlongAxisKwargs(ProcessingKwargs): 

1152 """key word arguments for `FixedZeroMeanUnitVarianceDescr`""" 

1153 

1154 mean: NotEmpty[List[float]] 

1155 """The mean value(s) to normalize with.""" 

1156 

1157 std: NotEmpty[List[Annotated[float, Ge(1e-6)]]] 

1158 """The standard deviation value(s) to normalize with. 

1159 Size must match `mean` values.""" 

1160 

1161 axis: Annotated[NonBatchAxisId, Field(examples=["channel", "index"])] 

1162 """The axis of the mean/std values to normalize each entry along that dimension 

1163 separately.""" 

1164 

1165 @model_validator(mode="after") 

1166 def _mean_and_std_match(self) -> Self: 

1167 if len(self.mean) != len(self.std): 

1168 raise ValueError( 

1169 f"Size of `mean` ({len(self.mean)}) and `std` ({len(self.std)})" 

1170 + " must match." 

1171 ) 

1172 

1173 return self 

1174 

1175 

1176class FixedZeroMeanUnitVarianceDescr(ProcessingDescrBase): 

1177 """Subtract a given mean and divide by the standard deviation. 

1178 

1179 Normalize with fixed, precomputed values for 

1180 `FixedZeroMeanUnitVarianceKwargs.mean` and `FixedZeroMeanUnitVarianceKwargs.std`.

1181 Use `FixedZeroMeanUnitVarianceAlongAxisKwargs` for independent scaling along given 

1182 axes. 

1183 

1184 Examples: 

1185 1. scalar value for whole tensor 

1186 - in YAML 

1187 ```yaml 

1188 preprocessing: 

1189 - id: fixed_zero_mean_unit_variance 

1190 kwargs: 

1191 mean: 103.5 

1192 std: 13.7 

1193 ``` 

1194 - in Python 

1195 >>> preprocessing = [FixedZeroMeanUnitVarianceDescr( 

1196 ... kwargs=FixedZeroMeanUnitVarianceKwargs(mean=103.5, std=13.7) 

1197 ... )] 

1198 

1199 2. independently along an axis 

1200 - in YAML 

1201 ```yaml 

1202 preprocessing: 

1203 - id: fixed_zero_mean_unit_variance 

1204 kwargs: 

1205 axis: channel 

1206 mean: [101.5, 102.5, 103.5] 

1207 std: [11.7, 12.7, 13.7] 

1208 ``` 

1209 - in Python 

1210 >>> preprocessing = [FixedZeroMeanUnitVarianceDescr( 

1211 ... kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs( 

1212 ... axis=AxisId("channel"), 

1213 ... mean=[101.5, 102.5, 103.5], 

1214 ... std=[11.7, 12.7, 13.7], 

1215 ... ) 

1216 ... )] 

1217 """ 

1218 

1219 implemented_id: ClassVar[Literal["fixed_zero_mean_unit_variance"]] = ( 

1220 "fixed_zero_mean_unit_variance" 

1221 ) 

1222 if TYPE_CHECKING: 

1223 id: Literal["fixed_zero_mean_unit_variance"] = "fixed_zero_mean_unit_variance" 

1224 else: 

1225 id: Literal["fixed_zero_mean_unit_variance"] 

1226 

1227 kwargs: Union[ 

1228 FixedZeroMeanUnitVarianceKwargs, FixedZeroMeanUnitVarianceAlongAxisKwargs 

1229 ] 

1230 

1231 

1232class ZeroMeanUnitVarianceKwargs(ProcessingKwargs): 

1233 """key word arguments for `ZeroMeanUnitVarianceDescr`""" 

1234 

1235 axes: Annotated[ 

1236 Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")]) 

1237 ] = None 

1238 """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std. 

1239 For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x') 

1240 resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`. 

1241 To normalize each sample independently leave out the 'batch' axis. 

1242 Default: Scale all axes jointly.""" 

1243 

1244 eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6 

1245 """epsilon for numeric stability: `out = (tensor - mean) / (std + eps)`.""" 

1246 

1247 

1248class ZeroMeanUnitVarianceDescr(ProcessingDescrBase): 

1249 """Subtract mean and divide by variance. 

1250 

1251 Examples: 

1252 Subtract the tensor mean and divide by its standard deviation

1253 - in YAML 

1254 ```yaml 

1255 preprocessing: 

1256 - id: zero_mean_unit_variance 

1257 ``` 

1258 - in Python 

1259 >>> preprocessing = [ZeroMeanUnitVarianceDescr()] 

1260 """ 

1261 

1262 implemented_id: ClassVar[Literal["zero_mean_unit_variance"]] = ( 

1263 "zero_mean_unit_variance" 

1264 ) 

1265 if TYPE_CHECKING: 

1266 id: Literal["zero_mean_unit_variance"] = "zero_mean_unit_variance" 

1267 else: 

1268 id: Literal["zero_mean_unit_variance"] 

1269 

1270 kwargs: ZeroMeanUnitVarianceKwargs = Field( 

1271 default_factory=ZeroMeanUnitVarianceKwargs 

1272 ) 

1273 

1274 

1275class ScaleRangeKwargs(ProcessingKwargs): 

1276 """key word arguments for `ScaleRangeDescr` 

1277 

1278 For `min_percentile`=0.0 (the default) and `max_percentile`=100 (the default) 

1279 this processing step normalizes data to the [0, 1] intervall. 

1280 For other percentiles the normalized values will partially be outside the [0, 1] 

1281 interval. Use `ScaleRangeDescr` followed by `ClipDescr` if you want to limit the

1282 normalized values to a range. 

1283 """ 

1284 

1285 axes: Annotated[ 

1286 Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")]) 

1287 ] = None 

1288 """The subset of axes to normalize jointly, i.e. axes to reduce to compute the min/max percentile value. 

1289 For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x') 

1290 resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`. 

1291 To normalize samples independently, leave out the "batch" axis. 

1292 Default: Scale all axes jointly.""" 

1293 

1294 min_percentile: Annotated[float, Interval(ge=0, lt=100)] = 0.0 

1295 """The lower percentile used to determine the value to align with zero.""" 

1296 

1297 max_percentile: Annotated[float, Interval(gt=1, le=100)] = 100.0 

1298 """The upper percentile used to determine the value to align with one. 

1299 Has to be bigger than `min_percentile`. 

1300 The range is 1 to 100 instead of 0 to 100 to avoid mistakenly 

1301 accepting percentiles specified in the range 0.0 to 1.0.""" 

1302 

1303 eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6 

1304 """Epsilon for numeric stability. 

1305 `out = (tensor - v_lower) / (v_upper - v_lower + eps)`; 

1306 with `v_lower,v_upper` values at the respective percentiles.""" 

1307 

1308 reference_tensor: Optional[TensorId] = None 

1309 """Tensor ID to compute the percentiles from. Default: The tensor itself. 

1310 For any tensor in `inputs` only input tensor references are allowed.""" 

1311 

1312 @field_validator("max_percentile", mode="after") 

1313 @classmethod 

1314 def min_smaller_max(cls, value: float, info: ValidationInfo) -> float: 

1315 if (min_p := info.data["min_percentile"]) >= value: 

1316 raise ValueError(f"min_percentile {min_p} >= max_percentile {value}") 

1317 

1318 return value 

1319 

1320 

1321class ScaleRangeDescr(ProcessingDescrBase): 

1322 """Scale with percentiles. 

1323 

1324 Examples: 

1325 1. Scale linearly to map 5th percentile to 0 and 99.8th percentile to 1.0 

1326 - in YAML 

1327 ```yaml 

1328 preprocessing: 

1329 - id: scale_range 

1330 kwargs: 

1331 axes: ['y', 'x'] 

1332 max_percentile: 99.8 

1333 min_percentile: 5.0 

1334 ``` 

1335 - in Python 

1336 >>> preprocessing = [ 

1337 ... ScaleRangeDescr( 

1338 ... kwargs=ScaleRangeKwargs( 

1339 ... axes= (AxisId('y'), AxisId('x')), 

1340 ... max_percentile= 99.8, 

1341 ... min_percentile= 5.0, 

1342 ... ) 

1343 ... ), 

1344 ... ClipDescr( 

1345 ... kwargs=ClipKwargs( 

1346 ... min=0.0, 

1347 ... max=1.0, 

1348 ... ) 

1349 ... ), 

1350 ... ] 

1351 

1352 2. Combine the above scaling with additional clipping to clip values outside the range given by the percentiles. 

1353 - in YAML 

1354 ```yaml 

1355 preprocessing: 

1356 - id: scale_range 

1357 kwargs: 

1358 axes: ['y', 'x'] 

1359 max_percentile: 99.8 

1360 min_percentile: 5.0 

1361 - id: scale_range 

1362 - id: clip 

1363 kwargs: 

1364 min: 0.0 

1365 max: 1.0 

1366 ``` 

1367 - in Python 

1368 >>> preprocessing = [ScaleRangeDescr( 

1369 ... kwargs=ScaleRangeKwargs( 

1370 ... axes= (AxisId('y'), AxisId('x')), 

1371 ... max_percentile= 99.8, 

1372 ... min_percentile= 5.0, 

1373 ... ) 

1374 ... )] 

1375 

1376 """ 

1377 

1378 implemented_id: ClassVar[Literal["scale_range"]] = "scale_range" 

1379 if TYPE_CHECKING: 

1380 id: Literal["scale_range"] = "scale_range" 

1381 else: 

1382 id: Literal["scale_range"] 

1383 kwargs: ScaleRangeKwargs 

1384 

1385 

1386class ScaleMeanVarianceKwargs(ProcessingKwargs): 

1387 """key word arguments for `ScaleMeanVarianceKwargs`""" 

1388 

1389 reference_tensor: TensorId 

1390 """Name of tensor to match.""" 

1391 

1392 axes: Annotated[ 

1393 Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")]) 

1394 ] = None 

1395 """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std. 

1396 For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x') 

1397 resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`. 

1398 To normalize samples independently, leave out the 'batch' axis. 

1399 Default: Scale all axes jointly.""" 

1400 

1401 eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6 

1402 """Epsilon for numeric stability: 

1403 `out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.`""" 

1404 

1405 

1406class ScaleMeanVarianceDescr(ProcessingDescrBase): 

1407 """Scale a tensor's data distribution to match another tensor's mean/std. 

1408 `out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.` 

1409 """ 

1410 

1411 implemented_id: ClassVar[Literal["scale_mean_variance"]] = "scale_mean_variance" 

1412 if TYPE_CHECKING: 

1413 id: Literal["scale_mean_variance"] = "scale_mean_variance" 

1414 else: 

1415 id: Literal["scale_mean_variance"] 

1416 kwargs: ScaleMeanVarianceKwargs 

1417 
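# --- Editor's example (illustrative sketch, not part of the original source) ---
# Match an output tensor's distribution to that of an input tensor; the tensor
# id "raw" is made up for illustration:
# >>> post = ScaleMeanVarianceDescr(
# ...     kwargs=ScaleMeanVarianceKwargs(reference_tensor=TensorId("raw"))
# ... )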

1418 

1419PreprocessingDescr = Annotated[ 

1420 Union[ 

1421 BinarizeDescr, 

1422 ClipDescr, 

1423 EnsureDtypeDescr, 

1424 ScaleLinearDescr, 

1425 SigmoidDescr, 

1426 FixedZeroMeanUnitVarianceDescr, 

1427 ZeroMeanUnitVarianceDescr, 

1428 ScaleRangeDescr, 

1429 ], 

1430 Discriminator("id"), 

1431] 

1432PostprocessingDescr = Annotated[ 

1433 Union[ 

1434 BinarizeDescr, 

1435 ClipDescr, 

1436 EnsureDtypeDescr, 

1437 ScaleLinearDescr, 

1438 SigmoidDescr, 

1439 FixedZeroMeanUnitVarianceDescr, 

1440 ZeroMeanUnitVarianceDescr, 

1441 ScaleRangeDescr, 

1442 ScaleMeanVarianceDescr, 

1443 ], 

1444 Discriminator("id"), 

1445] 

1446 

1447IO_AxisT = TypeVar("IO_AxisT", InputAxis, OutputAxis) 

1448 

1449 

1450class TensorDescrBase(Node, Generic[IO_AxisT]): 

1451 id: TensorId 

1452 """Tensor id. No duplicates are allowed.""" 

1453 

1454 description: Annotated[str, MaxLen(128)] = "" 

1455 """free text description""" 

1456 

1457 axes: NotEmpty[Sequence[IO_AxisT]] 

1458 """tensor axes""" 

1459 

1460 @property 

1461 def shape(self): 

1462 return tuple(a.size for a in self.axes) 

1463 

1464 @field_validator("axes", mode="after", check_fields=False) 

1465 @classmethod 

1466 def _validate_axes(cls, axes: Sequence[AnyAxis]) -> Sequence[AnyAxis]: 

1467 batch_axes = [a for a in axes if a.type == "batch"] 

1468 if len(batch_axes) > 1: 

1469 raise ValueError( 

1470 f"Only one batch axis (per tensor) allowed, but got {batch_axes}" 

1471 ) 

1472 

1473 seen_ids: Set[AxisId] = set() 

1474 duplicate_axes_ids: Set[AxisId] = set() 

1475 for a in axes: 

1476 (duplicate_axes_ids if a.id in seen_ids else seen_ids).add(a.id) 

1477 

1478 if duplicate_axes_ids: 

1479 raise ValueError(f"Duplicate axis ids: {duplicate_axes_ids}") 

1480 

1481 return axes 

1482 

1483 test_tensor: FileDescr 

1484 """An example tensor to use for testing. 

1485 Using the model with the test input tensors is expected to yield the test output tensors. 

1486 Each test tensor has to be an ndarray in the

1487 [numpy.lib file format](https://numpy.org/doc/stable/reference/generated/numpy.lib.format.html#module-numpy.lib.format). 

1488 The file extension must be '.npy'.""" 

1489 

1490 sample_tensor: Optional[FileDescr] = None 

1491 """A sample tensor to illustrate a possible input/output for the model, 

1492 The sample image primarily serves to inform a human user about an example use case 

1493 and is typically stored as .hdf5, .png or .tiff. 

1494 It has to be readable by the [imageio library](https://imageio.readthedocs.io/en/stable/formats/index.html#supported-formats) 

1495 (numpy's `.npy` format is not supported). 

1496 The image dimensionality has to match the number of axes specified in this tensor description. 

1497 """ 

1498 

1499 @model_validator(mode="after") 

1500 def _validate_sample_tensor(self) -> Self: 

1501 if self.sample_tensor is None or not get_validation_context().perform_io_checks: 

1502 return self 

1503 

1504 local = download(self.sample_tensor.source, sha256=self.sample_tensor.sha256) 

1505 tensor: NDArray[Any] = imread( 

1506 local.path.read_bytes(), 

1507 extension=PurePosixPath(local.original_file_name).suffix, 

1508 ) 

1509 n_dims = len(tensor.squeeze().shape) 

1510 n_dims_min = n_dims_max = len(self.axes) 

1511 

1512 for a in self.axes: 

1513 if isinstance(a, BatchAxis): 

1514 n_dims_min -= 1 

1515 elif isinstance(a.size, int): 

1516 if a.size == 1: 

1517 n_dims_min -= 1 

1518 elif isinstance(a.size, (ParameterizedSize, DataDependentSize)): 

1519 if a.size.min == 1: 

1520 n_dims_min -= 1 

1521 elif isinstance(a.size, SizeReference): 

1522 if a.size.offset < 2: 

1523 # size reference may result in singleton axis 

1524 n_dims_min -= 1 

1525 else: 

1526 assert_never(a.size) 

1527 

1528 n_dims_min = max(0, n_dims_min) 

1529 if n_dims < n_dims_min or n_dims > n_dims_max: 

1530 raise ValueError( 

1531 f"Expected sample tensor to have {n_dims_min} to" 

1532 + f" {n_dims_max} dimensions, but found {n_dims} (shape: {tensor.shape})." 

1533 ) 

1534 

1535 return self 

1536 

1537 data: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]] = ( 

1538 IntervalOrRatioDataDescr() 

1539 ) 

1540 """Description of the tensor's data values, optionally per channel. 

1541 If specified per channel, the data `type` needs to match across channels.""" 

1542 

1543 @property 

1544 def dtype( 

1545 self, 

1546 ) -> Literal[ 

1547 "float32", 

1548 "float64", 

1549 "uint8", 

1550 "int8", 

1551 "uint16", 

1552 "int16", 

1553 "uint32", 

1554 "int32", 

1555 "uint64", 

1556 "int64", 

1557 "bool", 

1558 ]: 

1559 """dtype as specified under `data.type` or `data[i].type`""" 

1560 if isinstance(self.data, collections.abc.Sequence): 

1561 return self.data[0].type 

1562 else: 

1563 return self.data.type 

1564 

1565 @field_validator("data", mode="after") 

1566 @classmethod 

1567 def _check_data_type_across_channels( 

1568 cls, value: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]] 

1569 ) -> Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]]: 

1570 if not isinstance(value, list): 

1571 return value 

1572 

1573 dtypes = {t.type for t in value} 

1574 if len(dtypes) > 1: 

1575 raise ValueError( 

1576 "Tensor data descriptions per channel need to agree in their data" 

1577 + f" `type`, but found {dtypes}." 

1578 ) 

1579 

1580 return value 

1581 

1582 @model_validator(mode="after") 

1583 def _check_data_matches_channelaxis(self) -> Self: 

1584 if not isinstance(self.data, (list, tuple)): 

1585 return self 

1586 

1587 for a in self.axes: 

1588 if isinstance(a, ChannelAxis): 

1589 size = a.size 

1590 assert isinstance(size, int) 

1591 break 

1592 else: 

1593 return self 

1594 

1595 if len(self.data) != size: 

1596 raise ValueError( 

1597 f"Got tensor data descriptions for {len(self.data)} channels, but" 

1598 + f" '{a.id}' axis has size {size}." 

1599 ) 

1600 

1601 return self 

1602 

1603 def get_axis_sizes_for_array(self, array: NDArray[Any]) -> Dict[AxisId, int]: 

1604 if len(array.shape) != len(self.axes): 

1605 raise ValueError( 

1606 f"Dimension mismatch: array shape {array.shape} (#{len(array.shape)})" 

1607 + f" incompatible with {len(self.axes)} axes." 

1608 ) 

1609 return {a.id: array.shape[i] for i, a in enumerate(self.axes)} 

1610 

1611 

1612class InputTensorDescr(TensorDescrBase[InputAxis]): 

1613 id: TensorId = TensorId("input") 

1614 """Input tensor id. 

1615 No duplicates are allowed across all inputs and outputs.""" 

1616 

1617 optional: bool = False 

1618 """indicates that this tensor may be `None`""" 

1619 

1620 preprocessing: List[PreprocessingDescr] = Field(default_factory=list) 

1621 """Description of how this input should be preprocessed. 

1622 

1623 notes: 

1624 - If preprocessing does not start with an 'ensure_dtype' entry, it is added 

1625 to ensure an input tensor's data type matches the input tensor's data description. 

1626 - If preprocessing does not end with an 'ensure_dtype' or 'binarize' entry, an 

1627 'ensure_dtype' step is added to ensure preprocessing steps are not unintentionally 

1628 changing the data type. 

1629 """ 

1630 

1631 @model_validator(mode="after") 

1632 def _validate_preprocessing_kwargs(self) -> Self: 

1633 axes_ids = [a.id for a in self.axes] 

1634 for p in self.preprocessing: 

1635 kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes") 

1636 if kwargs_axes is None: 

1637 continue 

1638 

1639 if not isinstance(kwargs_axes, collections.abc.Sequence): 

1640 raise ValueError( 

1641 f"Expected `preprocessing.i.kwargs.axes` to be a sequence, but got {type(kwargs_axes)}" 

1642 ) 

1643 

1644 if any(a not in axes_ids for a in kwargs_axes): 

1645 raise ValueError( 

1646 "`preprocessing.i.kwargs.axes` needs to be subset of axes ids" 

1647 ) 

1648 

1649 if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)): 

1650 dtype = self.data.type 

1651 else: 

1652 dtype = self.data[0].type 

1653 

1654 # ensure `preprocessing` begins with `EnsureDtypeDescr` 

1655 if not self.preprocessing or not isinstance( 

1656 self.preprocessing[0], EnsureDtypeDescr 

1657 ): 

1658 self.preprocessing.insert( 

1659 0, EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype)) 

1660 ) 

1661 

1662 # ensure `preprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr` 

1663 if not isinstance(self.preprocessing[-1], (EnsureDtypeDescr, BinarizeDescr)): 

1664 self.preprocessing.append( 

1665 EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype)) 

1666 ) 

1667 

1668 return self 

1669 

1670 

1671def convert_axes( 

1672 axes: str, 

1673 *, 

1674 shape: Union[ 

1675 Sequence[int], _ParameterizedInputShape_v0_4, _ImplicitOutputShape_v0_4 

1676 ], 

1677 tensor_type: Literal["input", "output"], 

1678 halo: Optional[Sequence[int]], 

1679 size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]], 

1680): 

1681 ret: List[AnyAxis] = [] 

1682 for i, a in enumerate(axes): 

1683 axis_type = _AXIS_TYPE_MAP.get(a, a) 

1684 if axis_type == "batch": 

1685 ret.append(BatchAxis()) 

1686 continue 

1687 

1688 scale = 1.0 

1689 if isinstance(shape, _ParameterizedInputShape_v0_4): 

1690 if shape.step[i] == 0: 

1691 size = shape.min[i] 

1692 else: 

1693 size = ParameterizedSize(min=shape.min[i], step=shape.step[i]) 

1694 elif isinstance(shape, _ImplicitOutputShape_v0_4): 

1695 ref_t = str(shape.reference_tensor) 

1696 if ref_t.count(".") == 1: 

1697 t_id, orig_a_id = ref_t.split(".") 

1698 else: 

1699 t_id = ref_t 

1700 orig_a_id = a 

1701 

1702 a_id = _AXIS_ID_MAP.get(orig_a_id, a) 

1703 if not (orig_scale := shape.scale[i]): 

1704 # old way to insert a new axis dimension 

1705 size = int(2 * shape.offset[i]) 

1706 else: 

1707 scale = 1 / orig_scale 

1708 if axis_type in ("channel", "index"): 

1709 # these axes no longer have a scale 

1710 offset_from_scale = orig_scale * size_refs.get( 

1711 _TensorName_v0_4(t_id), {} 

1712 ).get(orig_a_id, 0) 

1713 else: 

1714 offset_from_scale = 0 

1715 size = SizeReference( 

1716 tensor_id=TensorId(t_id), 

1717 axis_id=AxisId(a_id), 

1718 offset=int(offset_from_scale + 2 * shape.offset[i]), 

1719 ) 

1720 else: 

1721 size = shape[i] 

1722 

1723 if axis_type == "time": 

1724 if tensor_type == "input": 

1725 ret.append(TimeInputAxis(size=size, scale=scale)) 

1726 else: 

1727 assert not isinstance(size, ParameterizedSize) 

1728 if halo is None: 

1729 ret.append(TimeOutputAxis(size=size, scale=scale)) 

1730 else: 

1731 assert not isinstance(size, int) 

1732 ret.append( 

1733 TimeOutputAxisWithHalo(size=size, scale=scale, halo=halo[i]) 

1734 ) 

1735 

1736 elif axis_type == "index": 

1737 if tensor_type == "input": 

1738 ret.append(IndexInputAxis(size=size)) 

1739 else: 

1740 if isinstance(size, ParameterizedSize): 

1741 size = DataDependentSize(min=size.min) 

1742 

1743 ret.append(IndexOutputAxis(size=size)) 

1744 elif axis_type == "channel": 

1745 assert not isinstance(size, ParameterizedSize) 

1746 if isinstance(size, SizeReference): 

1747 warnings.warn( 

1748 "Conversion of channel size from an implicit output shape may be" 

1749 + " wrong" 

1750 ) 

1751 ret.append( 

1752 ChannelAxis( 

1753 channel_names=[ 

1754 Identifier(f"channel{i}") for i in range(size.offset) 

1755 ] 

1756 ) 

1757 ) 

1758 else: 

1759 ret.append( 

1760 ChannelAxis( 

1761 channel_names=[Identifier(f"channel{i}") for i in range(size)] 

1762 ) 

1763 ) 

1764 elif axis_type == "space": 

1765 if tensor_type == "input": 

1766 ret.append(SpaceInputAxis(id=AxisId(a), size=size, scale=scale)) 

1767 else: 

1768 assert not isinstance(size, ParameterizedSize) 

1769 if halo is None or halo[i] == 0: 

1770 ret.append(SpaceOutputAxis(id=AxisId(a), size=size, scale=scale)) 

1771 elif isinstance(size, int): 

1772 raise NotImplementedError( 

1773 f"output axis with halo and fixed size (here {size}) not allowed" 

1774 ) 

1775 else: 

1776 ret.append( 

1777 SpaceOutputAxisWithHalo( 

1778 id=AxisId(a), size=size, scale=scale, halo=halo[i] 

1779 ) 

1780 ) 

1781 

1782 return ret 

1783 
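# --- Editor's example (illustrative sketch, not part of the original source) ---
# Converting a v0.4 axes string with an explicit shape to v0.5 axis
# descriptions (the shape values are made up for illustration):
# >>> axes = convert_axes(
# ...     "bcyx", shape=[1, 2, 256, 256], tensor_type="input", halo=None, size_refs={}
# ... )
# >>> [str(a.id) for a in axes]
# ['batch', 'channel', 'y', 'x']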

1784 

1785def _axes_letters_to_ids( 

1786 axes: Optional[str], 

1787) -> Optional[List[AxisId]]: 

1788 if axes is None: 

1789 return None 

1790 

1791 return [AxisId(a) for a in axes] 

1792 

1793 

1794def _get_complement_v04_axis( 

1795 tensor_axes: Sequence[str], axes: Optional[Sequence[str]] 

1796) -> Optional[AxisId]: 

1797 if axes is None: 

1798 return None 

1799 

1800 non_complement_axes = set(axes) | {"b"} 

1801 complement_axes = [a for a in tensor_axes if a not in non_complement_axes] 

1802 if len(complement_axes) > 1: 

1803 raise ValueError( 

1804 f"Expected none or a single complement axis, but axes '{axes}' " 

1805 + f"for tensor dims '{tensor_axes}' leave '{complement_axes}'." 

1806 ) 

1807 

1808 return None if not complement_axes else AxisId(complement_axes[0]) 

1809 

1810 

1811def _convert_proc( 

1812 p: Union[_PreprocessingDescr_v0_4, _PostprocessingDescr_v0_4], 

1813 tensor_axes: Sequence[str], 

1814) -> Union[PreprocessingDescr, PostprocessingDescr]: 

1815 if isinstance(p, _BinarizeDescr_v0_4): 

1816 return BinarizeDescr(kwargs=BinarizeKwargs(threshold=p.kwargs.threshold)) 

1817 elif isinstance(p, _ClipDescr_v0_4): 

1818 return ClipDescr(kwargs=ClipKwargs(min=p.kwargs.min, max=p.kwargs.max)) 

1819 elif isinstance(p, _SigmoidDescr_v0_4): 

1820 return SigmoidDescr() 

1821 elif isinstance(p, _ScaleLinearDescr_v0_4): 

1822 axes = _axes_letters_to_ids(p.kwargs.axes) 

1823 if p.kwargs.axes is None: 

1824 axis = None 

1825 else: 

1826 axis = _get_complement_v04_axis(tensor_axes, p.kwargs.axes) 

1827 

1828 if axis is None: 

1829 assert not isinstance(p.kwargs.gain, list) 

1830 assert not isinstance(p.kwargs.offset, list) 

1831 kwargs = ScaleLinearKwargs(gain=p.kwargs.gain, offset=p.kwargs.offset) 

1832 else: 

1833 kwargs = ScaleLinearAlongAxisKwargs( 

1834 axis=axis, gain=p.kwargs.gain, offset=p.kwargs.offset 

1835 ) 

1836 return ScaleLinearDescr(kwargs=kwargs) 

1837 elif isinstance(p, _ScaleMeanVarianceDescr_v0_4): 

1838 return ScaleMeanVarianceDescr( 

1839 kwargs=ScaleMeanVarianceKwargs( 

1840 axes=_axes_letters_to_ids(p.kwargs.axes), 

1841 reference_tensor=TensorId(str(p.kwargs.reference_tensor)), 

1842 eps=p.kwargs.eps, 

1843 ) 

1844 ) 

1845 elif isinstance(p, _ZeroMeanUnitVarianceDescr_v0_4): 

1846 if p.kwargs.mode == "fixed": 

1847 mean = p.kwargs.mean 

1848 std = p.kwargs.std 

1849 assert mean is not None 

1850 assert std is not None 

1851 

1852 axis = _get_complement_v04_axis(tensor_axes, p.kwargs.axes) 

1853 

1854 if axis is None: 

1855 return FixedZeroMeanUnitVarianceDescr( 

1856 kwargs=FixedZeroMeanUnitVarianceKwargs( 

1857 mean=mean, std=std # pyright: ignore[reportArgumentType] 

1858 ) 

1859 ) 

1860 else: 

1861 if not isinstance(mean, list): 

1862 mean = [float(mean)] 

1863 if not isinstance(std, list): 

1864 std = [float(std)] 

1865 

1866 return FixedZeroMeanUnitVarianceDescr( 

1867 kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs( 

1868 axis=axis, mean=mean, std=std 

1869 ) 

1870 ) 

1871 

1872 else: 

1873 axes = _axes_letters_to_ids(p.kwargs.axes) or [] 

1874 if p.kwargs.mode == "per_dataset": 

1875 axes = [AxisId("batch")] + axes 

1876 if not axes: 

1877 axes = None 

1878 return ZeroMeanUnitVarianceDescr( 

1879 kwargs=ZeroMeanUnitVarianceKwargs(axes=axes, eps=p.kwargs.eps) 

1880 ) 

1881 

1882 elif isinstance(p, _ScaleRangeDescr_v0_4): 

1883 return ScaleRangeDescr( 

1884 kwargs=ScaleRangeKwargs( 

1885 axes=_axes_letters_to_ids(p.kwargs.axes), 

1886 min_percentile=p.kwargs.min_percentile, 

1887 max_percentile=p.kwargs.max_percentile, 

1888 eps=p.kwargs.eps, 

1889 ) 

1890 ) 

1891 else: 

1892 assert_never(p) 

1893 
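# --- Illustrative sketch (not part of the original source) -------------------
# For the non-"fixed" `zero_mean_unit_variance` modes handled above, the v0.4
# axis letters become v0.5 axis ids and `mode="per_dataset"` additionally
# prepends the batch axis. Hypothetical values:
def _sketch_zmuv_axes_conversion() -> None:
    axes = _axes_letters_to_ids("xy") or []
    per_dataset_axes = [AxisId("batch")] + axes  # mode == "per_dataset"
    assert per_dataset_axes == [AxisId("batch"), AxisId("x"), AxisId("y")]
# ------------------------------------------------------------------------------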

1894 

1895class _InputTensorConv( 

1896 Converter[ 

1897 _InputTensorDescr_v0_4, 

1898 InputTensorDescr, 

1899 ImportantFileSource, 

1900 Optional[ImportantFileSource], 

1901 Mapping[_TensorName_v0_4, Mapping[str, int]], 

1902 ] 

1903): 

1904 def _convert( 

1905 self, 

1906 src: _InputTensorDescr_v0_4, 

1907 tgt: "type[InputTensorDescr] | type[dict[str, Any]]", 

1908 test_tensor: ImportantFileSource, 

1909 sample_tensor: Optional[ImportantFileSource], 

1910 size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]], 

1911 ) -> "InputTensorDescr | dict[str, Any]": 

1912 axes: List[InputAxis] = convert_axes( # pyright: ignore[reportAssignmentType] 

1913 src.axes, 

1914 shape=src.shape, 

1915 tensor_type="input", 

1916 halo=None, 

1917 size_refs=size_refs, 

1918 ) 

1919 prep: List[PreprocessingDescr] = [] 

1920 for p in src.preprocessing: 

1921 cp = _convert_proc(p, src.axes) 

1922 assert not isinstance(cp, ScaleMeanVarianceDescr) 

1923 prep.append(cp) 

1924 

1925 prep.append(EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="float32"))) 

1926 

1927 return tgt( 

1928 axes=axes, 

1929 id=TensorId(str(src.name)), 

1930 test_tensor=FileDescr(source=test_tensor), 

1931 sample_tensor=( 

1932 None if sample_tensor is None else FileDescr(source=sample_tensor) 

1933 ), 

1934 data=dict(type=src.data_type), # pyright: ignore[reportArgumentType] 

1935 preprocessing=prep, 

1936 ) 

1937 

1938 

1939_input_tensor_conv = _InputTensorConv(_InputTensorDescr_v0_4, InputTensorDescr) 

1940 

1941 

1942class OutputTensorDescr(TensorDescrBase[OutputAxis]): 

1943 id: TensorId = TensorId("output") 

1944 """Output tensor id. 

1945 No duplicates are allowed across all inputs and outputs.""" 

1946 

1947 postprocessing: List[PostprocessingDescr] = Field(default_factory=list) 

1948 """Description of how this output should be postprocessed. 

1949 

1950 note: `postprocessing` always ends with an 'ensure_dtype' operation. 

1951 If not given this is added to cast to this tensor's `data.type`. 

1952 """ 

1953 

1954 @model_validator(mode="after") 

1955 def _validate_postprocessing_kwargs(self) -> Self: 

1956 axes_ids = [a.id for a in self.axes] 

1957 for p in self.postprocessing: 

1958 kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes") 

1959 if kwargs_axes is None: 

1960 continue 

1961 

1962 if not isinstance(kwargs_axes, collections.abc.Sequence): 

1963 raise ValueError( 

1964 f"expected `axes` sequence, but got {type(kwargs_axes)}" 

1965 ) 

1966 

1967 if any(a not in axes_ids for a in kwargs_axes): 

1968 raise ValueError("`kwargs.axes` needs to be subset of axes ids") 

1969 

1970 if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)): 

1971 dtype = self.data.type 

1972 else: 

1973 dtype = self.data[0].type 

1974 

1975 # ensure `postprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr` 

1976 if not self.postprocessing or not isinstance( 

1977 self.postprocessing[-1], (EnsureDtypeDescr, BinarizeDescr) 

1978 ): 

1979 self.postprocessing.append( 

1980 EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype)) 

1981 ) 

1982 return self 

1983 

1984 

1985class _OutputTensorConv( 

1986 Converter[ 

1987 _OutputTensorDescr_v0_4, 

1988 OutputTensorDescr, 

1989 ImportantFileSource, 

1990 Optional[ImportantFileSource], 

1991 Mapping[_TensorName_v0_4, Mapping[str, int]], 

1992 ] 

1993): 

1994 def _convert( 

1995 self, 

1996 src: _OutputTensorDescr_v0_4, 

1997 tgt: "type[OutputTensorDescr] | type[dict[str, Any]]", 

1998 test_tensor: ImportantFileSource, 

1999 sample_tensor: Optional[ImportantFileSource], 

2000 size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]], 

2001 ) -> "OutputTensorDescr | dict[str, Any]": 

2002 # TODO: split convert_axes into convert_output_axes and convert_input_axes 

2003 axes: List[OutputAxis] = convert_axes( # pyright: ignore[reportAssignmentType] 

2004 src.axes, 

2005 shape=src.shape, 

2006 tensor_type="output", 

2007 halo=src.halo, 

2008 size_refs=size_refs, 

2009 ) 

2010 data_descr: Dict[str, Any] = dict(type=src.data_type) 

2011 if data_descr["type"] == "bool": 

2012 data_descr["values"] = [False, True] 

2013 

2014 return tgt( 

2015 axes=axes, 

2016 id=TensorId(str(src.name)), 

2017 test_tensor=FileDescr(source=test_tensor), 

2018 sample_tensor=( 

2019 None if sample_tensor is None else FileDescr(source=sample_tensor) 

2020 ), 

2021 data=data_descr, # pyright: ignore[reportArgumentType] 

2022 postprocessing=[_convert_proc(p, src.axes) for p in src.postprocessing], 

2023 ) 

2024 

2025 

2026_output_tensor_conv = _OutputTensorConv(_OutputTensorDescr_v0_4, OutputTensorDescr) 

2027 

2028 

2029TensorDescr = Union[InputTensorDescr, OutputTensorDescr] 

2030 

2031 

2032def validate_tensors( 

2033 tensors: Mapping[TensorId, Tuple[TensorDescr, NDArray[Any]]], 

2034 tensor_origin: Literal[ 

2035 "test_tensor" 

2036 ], # for more precise error messages, e.g. 'test_tensor' 

2037): 

2038 all_tensor_axes: Dict[TensorId, Dict[AxisId, Tuple[AnyAxis, int]]] = {} 

2039 

2040 def e_msg(d: TensorDescr): 

2041 return f"{'inputs' if isinstance(d, InputTensorDescr) else 'outputs'}[{d.id}]" 

2042 

2043 for descr, array in tensors.values(): 

2044 try: 

2045 axis_sizes = descr.get_axis_sizes_for_array(array) 

2046 except ValueError as e: 

2047 raise ValueError(f"{e_msg(descr)} {e}") 

2048 else: 

2049 all_tensor_axes[descr.id] = { 

2050 a.id: (a, axis_sizes[a.id]) for a in descr.axes 

2051 } 

2052 

2053 for descr, array in tensors.values(): 

2054 if descr.dtype in ("float32", "float64"): 

2055 invalid_test_tensor_dtype = array.dtype.name not in ( 

2056 "float32", 

2057 "float64", 

2058 "uint8", 

2059 "int8", 

2060 "uint16", 

2061 "int16", 

2062 "uint32", 

2063 "int32", 

2064 "uint64", 

2065 "int64", 

2066 ) 

2067 else: 

2068 invalid_test_tensor_dtype = array.dtype.name != descr.dtype 

2069 

2070 if invalid_test_tensor_dtype: 

2071 raise ValueError( 

2072 f"{e_msg(descr)}.{tensor_origin}.dtype '{array.dtype.name}' does not" 

2073 + f" match described dtype '{descr.dtype}'" 

2074 ) 

2075 

2076 if array.min() > -1e-4 and array.max() < 1e-4: 

2077 raise ValueError( 

2078 "Output values are too small for reliable testing." 

2079 + f" Values <-1e5 or >=1e5 must be present in {tensor_origin}" 

2080 ) 

2081 

2082 for a in descr.axes: 

2083 actual_size = all_tensor_axes[descr.id][a.id][1] 

2084 if a.size is None: 

2085 continue 

2086 

2087 if isinstance(a.size, int): 

2088 if actual_size != a.size: 

2089 raise ValueError( 

2090 f"{e_msg(descr)}.{tensor_origin}: axis '{a.id}' " 

2091 + f"has incompatible size {actual_size}, expected {a.size}" 

2092 ) 

2093 elif isinstance(a.size, ParameterizedSize): 

2094 _ = a.size.validate_size(actual_size) 

2095 elif isinstance(a.size, DataDependentSize): 

2096 _ = a.size.validate_size(actual_size) 

2097 elif isinstance(a.size, SizeReference): 

2098 ref_tensor_axes = all_tensor_axes.get(a.size.tensor_id) 

2099 if ref_tensor_axes is None: 

2100 raise ValueError( 

2101 f"{e_msg(descr)}.axes[{a.id}].size.tensor_id: Unknown tensor" 

2102 + f" reference '{a.size.tensor_id}'" 

2103 ) 

2104 

2105 ref_axis, ref_size = ref_tensor_axes.get(a.size.axis_id, (None, None)) 

2106 if ref_axis is None or ref_size is None: 

2107 raise ValueError( 

2108 f"{e_msg(descr)}.axes[{a.id}].size.axis_id: Unknown tensor axis" 

2109 + f" reference '{a.size.tensor_id}.{a.size.axis_id}" 

2110 ) 

2111 

2112 if a.unit != ref_axis.unit: 

2113 raise ValueError( 

2114 f"{e_msg(descr)}.axes[{a.id}].size: `SizeReference` requires" 

2115 + " axis and reference axis to have the same `unit`, but" 

2116 + f" {a.unit}!={ref_axis.unit}" 

2117 ) 

2118 

2119 if actual_size != ( 

2120 expected_size := ( 

2121 ref_size * ref_axis.scale / a.scale + a.size.offset 

2122 ) 

2123 ): 

2124 raise ValueError( 

2125 f"{e_msg(descr)}.{tensor_origin}: axis '{a.id}' of size" 

2126 + f" {actual_size} invalid for referenced size {ref_size};" 

2127 + f" expected {expected_size}" 

2128 ) 

2129 else: 

2130 assert_never(a.size) 

2131 
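# --- Illustrative sketch (not part of the original source) -------------------
# `validate_tensors` checks an axis with a `SizeReference` against
#   expected = ref_size * ref_axis.scale / axis.scale + size.offset
# e.g. (hypothetical values) a referenced axis of size 128 with scale 1.0 and a
# checked axis with scale 2.0 and offset 0 is expected to have size 64:
def _sketch_size_reference_check() -> None:
    ref_size, ref_scale, scale, offset = 128, 1.0, 2.0, 0
    assert ref_size * ref_scale / scale + offset == 64
# ------------------------------------------------------------------------------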

2132 

2133class EnvironmentFileDescr(FileDescr): 

2134 source: Annotated[ 

2135 ImportantFileSource, 

2136 WithSuffix((".yaml", ".yml"), case_sensitive=True), 

2137 Field( 

2138 examples=["environment.yaml"], 

2139 ), 

2140 ] 

2141 """∈📦 Conda environment file. 

2142 Allows specifying custom dependencies; see the conda docs: 

2143 - [Exporting an environment file across platforms](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#exporting-an-environment-file-across-platforms) 

2144 - [Creating an environment file manually](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#creating-an-environment-file-manually) 

2145 """ 

2146 

2147 

2148class _ArchitectureCallableDescr(Node): 

2149 callable: Annotated[Identifier, Field(examples=["MyNetworkClass", "get_my_model"])] 

2150 """Identifier of the callable that returns a torch.nn.Module instance.""" 

2151 

2152 kwargs: Dict[str, YamlValue] = Field(default_factory=dict) 

2153 """key word arguments for the `callable`""" 

2154 

2155 

2156class ArchitectureFromFileDescr(_ArchitectureCallableDescr, FileDescr): 

2157 pass 

2158 

2159 

2160class ArchitectureFromLibraryDescr(_ArchitectureCallableDescr): 

2161 import_from: str 

2162 """Where to import the callable from, i.e. `from <import_from> import <callable>`""" 

2163 

2164 

2165ArchitectureDescr = Annotated[ 

2166 Union[ArchitectureFromFileDescr, ArchitectureFromLibraryDescr], 

2167 Field(union_mode="left_to_right"), 

2168] 

2169 

2170 

2171class _ArchFileConv( 

2172 Converter[ 

2173 _CallableFromFile_v0_4, 

2174 ArchitectureFromFileDescr, 

2175 Optional[Sha256], 

2176 Dict[str, Any], 

2177 ] 

2178): 

2179 def _convert( 

2180 self, 

2181 src: _CallableFromFile_v0_4, 

2182 tgt: "type[ArchitectureFromFileDescr | dict[str, Any]]", 

2183 sha256: Optional[Sha256], 

2184 kwargs: Dict[str, Any], 

2185 ) -> "ArchitectureFromFileDescr | dict[str, Any]": 

2186 if src.startswith("http") and src.count(":") == 2: 

2187 http, source, callable_ = src.split(":") 

2188 source = ":".join((http, source)) 

2189 elif not src.startswith("http") and src.count(":") == 1: 

2190 source, callable_ = src.split(":") 

2191 else: 

2192 source = str(src) 

2193 callable_ = str(src) 

2194 return tgt( 

2195 callable=Identifier(callable_), 

2196 source=cast(ImportantFileSource, source), 

2197 sha256=sha256, 

2198 kwargs=kwargs, 

2199 ) 

2200 

2201 

2202_arch_file_conv = _ArchFileConv(_CallableFromFile_v0_4, ArchitectureFromFileDescr) 

2203 
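# --- Illustrative sketch (not part of the original source) -------------------
# `_ArchFileConv._convert` splits a v0.4 "source:callable" string; one extra
# colon is tolerated for http(s) URLs. Hypothetical inputs:
#   "model.py:MyNet"                     -> source "model.py", callable "MyNet"
#   "https://example.com/model.py:MyNet" -> source "https://example.com/model.py", callable "MyNet"
def _sketch_arch_file_split() -> None:
    src = "https://example.com/model.py:MyNet"
    assert src.startswith("http") and src.count(":") == 2
    http, path, callable_ = src.split(":")
    assert ":".join((http, path)) == "https://example.com/model.py"
    assert callable_ == "MyNet"
# ------------------------------------------------------------------------------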

2204 

2205class _ArchLibConv( 

2206 Converter[ 

2207 _CallableFromDepencency_v0_4, ArchitectureFromLibraryDescr, Dict[str, Any] 

2208 ] 

2209): 

2210 def _convert( 

2211 self, 

2212 src: _CallableFromDepencency_v0_4, 

2213 tgt: "type[ArchitectureFromLibraryDescr | dict[str, Any]]", 

2214 kwargs: Dict[str, Any], 

2215 ) -> "ArchitectureFromLibraryDescr | dict[str, Any]": 

2216 *mods, callable_ = src.split(".") 

2217 import_from = ".".join(mods) 

2218 return tgt( 

2219 import_from=import_from, callable=Identifier(callable_), kwargs=kwargs 

2220 ) 

2221 

2222 

2223_arch_lib_conv = _ArchLibConv( 

2224 _CallableFromDepencency_v0_4, ArchitectureFromLibraryDescr 

2225) 

2226 

2227 

2228class WeightsEntryDescrBase(FileDescr): 

2229 type: ClassVar[WeightsFormat] 

2230 weights_format_name: ClassVar[str] # human readable 

2231 

2232 source: ImportantFileSource 

2233 """∈📦 The weights file.""" 

2234 

2235 authors: Optional[List[Author]] = None 

2236 """Authors 

2237 Either the person(s) that have trained this model, resulting in the original weights file 

2238 (if this is the initial weights entry, i.e. it does not have a `parent`), 

2239 or the person(s) who have converted the weights to this weights format 

2240 (if this is a child weights entry, i.e. it has a `parent` field). 

2241 """ 

2242 

2243 parent: Annotated[ 

2244 Optional[WeightsFormat], Field(examples=["pytorch_state_dict"]) 

2245 ] = None 

2246 """The source weights these weights were converted from. 

2247 For example, if a model's weights were converted from the `pytorch_state_dict` format to `torchscript`, 

2248 The `pytorch_state_dict` weights entry has no `parent` and is the parent of the `torchscript` weights. 

2249 All weight entries except one (the initial set of weights resulting from training the model), 

2250 need to have this field.""" 

2251 

2252 comment: str = "" 

2253 """A comment about this weights entry, for example how these weights were created.""" 

2254 

2255 @model_validator(mode="after") 

2256 def check_parent_is_not_self(self) -> Self: 

2257 if self.type == self.parent: 

2258 raise ValueError("Weights entry can't be it's own parent.") 

2259 

2260 return self 

2261 

2262 

2263class KerasHdf5WeightsDescr(WeightsEntryDescrBase): 

2264 type = "keras_hdf5" 

2265 weights_format_name: ClassVar[str] = "Keras HDF5" 

2266 tensorflow_version: Version 

2267 """TensorFlow version used to create these weights.""" 

2268 

2269 

2270class OnnxWeightsDescr(WeightsEntryDescrBase): 

2271 type = "onnx" 

2272 weights_format_name: ClassVar[str] = "ONNX" 

2273 opset_version: Annotated[int, Ge(7)] 

2274 """ONNX opset version""" 

2275 

2276 

2277class PytorchStateDictWeightsDescr(WeightsEntryDescrBase): 

2278 type = "pytorch_state_dict" 

2279 weights_format_name: ClassVar[str] = "Pytorch State Dict" 

2280 architecture: ArchitectureDescr 

2281 pytorch_version: Version 

2282 """Version of the PyTorch library used. 

2283 If `architecture.dependencies` is specified, it has to include pytorch and any version pinning has to be compatible. 

2284 """ 

2285 dependencies: Optional[EnvironmentFileDescr] = None 

2286 """Custom depencies beyond pytorch. 

2287 The conda environment file should include pytorch and any version pinning has to be compatible with 

2288 `pytorch_version`. 

2289 """ 

2290 

2291 

2292class TensorflowJsWeightsDescr(WeightsEntryDescrBase): 

2293 type = "tensorflow_js" 

2294 weights_format_name: ClassVar[str] = "Tensorflow.js" 

2295 tensorflow_version: Version 

2296 """Version of the TensorFlow library used.""" 

2297 

2298 source: ImportantFileSource 

2299 """∈📦 The multi-file weights. 

2300 All required files/folders should be packaged in a zip archive.""" 

2301 

2302 

2303class TensorflowSavedModelBundleWeightsDescr(WeightsEntryDescrBase): 

2304 type = "tensorflow_saved_model_bundle" 

2305 weights_format_name: ClassVar[str] = "Tensorflow Saved Model" 

2306 tensorflow_version: Version 

2307 """Version of the TensorFlow library used.""" 

2308 

2309 dependencies: Optional[EnvironmentFileDescr] = None 

2310 """Custom dependencies beyond tensorflow. 

2311 Should include tensorflow and any version pinning has to be compatible with `tensorflow_version`.""" 

2312 

2313 source: ImportantFileSource 

2314 """∈📦 The multi-file weights. 

2315 All required files/folders should be packaged in a zip archive.""" 

2316 

2317 

2318class TorchscriptWeightsDescr(WeightsEntryDescrBase): 

2319 type = "torchscript" 

2320 weights_format_name: ClassVar[str] = "TorchScript" 

2321 pytorch_version: Version 

2322 """Version of the PyTorch library used.""" 

2323 

2324 

2325class WeightsDescr(Node): 

2326 keras_hdf5: Optional[KerasHdf5WeightsDescr] = None 

2327 onnx: Optional[OnnxWeightsDescr] = None 

2328 pytorch_state_dict: Optional[PytorchStateDictWeightsDescr] = None 

2329 tensorflow_js: Optional[TensorflowJsWeightsDescr] = None 

2330 tensorflow_saved_model_bundle: Optional[TensorflowSavedModelBundleWeightsDescr] = ( 

2331 None 

2332 ) 

2333 torchscript: Optional[TorchscriptWeightsDescr] = None 

2334 

2335 @model_validator(mode="after") 

2336 def check_entries(self) -> Self: 

2337 entries = {wtype for wtype, entry in self if entry is not None} 

2338 

2339 if not entries: 

2340 raise ValueError("Missing weights entry") 

2341 

2342 entries_wo_parent = { 

2343 wtype 

2344 for wtype, entry in self 

2345 if entry is not None and hasattr(entry, "parent") and entry.parent is None 

2346 } 

2347 if len(entries_wo_parent) != 1: 

2348 issue_warning( 

2349 "Exactly one weights entry may not specify the `parent` field (got" 

2350 + " {value}). That entry is considered the original set of model weights." 

2351 + " Other weight formats are created through conversion of the orignal or" 

2352 + " already converted weights. They have to reference the weights format" 

2353 + " they were converted from as their `parent`.", 

2354 value=len(entries_wo_parent), 

2355 field="weights", 

2356 ) 

2357 

2358 for wtype, entry in self: 

2359 if entry is None: 

2360 continue 

2361 

2362 assert hasattr(entry, "type") 

2363 assert hasattr(entry, "parent") 

2364 assert wtype == entry.type 

2365 if ( 

2366 entry.parent is not None and entry.parent not in entries 

2367 ): # self reference checked for `parent` field 

2368 raise ValueError( 

2369 f"`weights.{wtype}.parent={entry.parent} not in specified weight" 

2370 + f" formats: {entries}" 

2371 ) 

2372 

2373 return self 

2374 

2375 def __getitem__( 

2376 self, 

2377 key: Literal[ 

2378 "keras_hdf5", 

2379 "onnx", 

2380 "pytorch_state_dict", 

2381 "tensorflow_js", 

2382 "tensorflow_saved_model_bundle", 

2383 "torchscript", 

2384 ], 

2385 ): 

2386 if key == "keras_hdf5": 

2387 ret = self.keras_hdf5 

2388 elif key == "onnx": 

2389 ret = self.onnx 

2390 elif key == "pytorch_state_dict": 

2391 ret = self.pytorch_state_dict 

2392 elif key == "tensorflow_js": 

2393 ret = self.tensorflow_js 

2394 elif key == "tensorflow_saved_model_bundle": 

2395 ret = self.tensorflow_saved_model_bundle 

2396 elif key == "torchscript": 

2397 ret = self.torchscript 

2398 else: 

2399 raise KeyError(key) 

2400 

2401 if ret is None: 

2402 raise KeyError(key) 

2403 

2404 return ret 

2405 

2406 @property 

2407 def available_formats(self): 

2408 return { 

2409 **({} if self.keras_hdf5 is None else {"keras_hdf5": self.keras_hdf5}), 

2410 **({} if self.onnx is None else {"onnx": self.onnx}), 

2411 **( 

2412 {} 

2413 if self.pytorch_state_dict is None 

2414 else {"pytorch_state_dict": self.pytorch_state_dict} 

2415 ), 

2416 **( 

2417 {} 

2418 if self.tensorflow_js is None 

2419 else {"tensorflow_js": self.tensorflow_js} 

2420 ), 

2421 **( 

2422 {} 

2423 if self.tensorflow_saved_model_bundle is None 

2424 else { 

2425 "tensorflow_saved_model_bundle": self.tensorflow_saved_model_bundle 

2426 } 

2427 ), 

2428 **({} if self.torchscript is None else {"torchscript": self.torchscript}), 

2429 } 

2430 

2431 @property 

2432 def missing_formats(self): 

2433 return { 

2434 wf for wf in get_args(WeightsFormat) if wf not in self.available_formats 

2435 } 

2436 
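# --- Illustrative sketch (not part of the original source) -------------------
# `WeightsDescr.check_entries` (above) expects exactly one entry without a
# `parent` (the original training result); every converted entry references the
# format it was converted from. A hypothetical, schematic `weights` mapping
# (real entries need additional fields such as versions and architecture):
_SKETCH_WEIGHTS = {
    "pytorch_state_dict": {"source": "weights.pt"},  # no parent: original weights
    "torchscript": {"source": "weights.pt.ts", "parent": "pytorch_state_dict"},
    "onnx": {"source": "weights.onnx", "parent": "pytorch_state_dict"},
}
# ------------------------------------------------------------------------------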

2437 

2438class ModelId(ResourceId): 

2439 pass 

2440 

2441 

2442class LinkedModel(LinkedResourceBase): 

2443 """Reference to a bioimage.io model.""" 

2444 

2445 id: ModelId 

2446 """A valid model `id` from the bioimage.io collection.""" 

2447 

2448 

2449class _DataDepSize(NamedTuple): 

2450 min: int 

2451 max: Optional[int] 

2452 

2453 

2454class _AxisSizes(NamedTuple): 

2455 """the lenghts of all axes of model inputs and outputs""" 

2456 

2457 inputs: Dict[Tuple[TensorId, AxisId], int] 

2458 outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]] 

2459 

2460 

2461class _TensorSizes(NamedTuple): 

2462 """_AxisSizes as nested dicts""" 

2463 

2464 inputs: Dict[TensorId, Dict[AxisId, int]] 

2465 outputs: Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]] 

2466 

2467 

2468class ReproducibilityTolerance(Node, extra="allow"): 

2469 """Describes what small numerical differences -- if any -- may be tolerated 

2470 in the generated output when executing in different environments. 

2471 

2472 A tensor element *output* is considered mismatched to the **test_tensor** if 

2473 abs(*output* - **test_tensor**) > **absolute_tolerance** + **relative_tolerance** * abs(**test_tensor**). 

2474 (Internally we call [numpy.testing.assert_allclose](https://numpy.org/doc/stable/reference/generated/numpy.testing.assert_allclose.html).) 

2475 

2476 Motivation: 

2477 For testing we can request the respective deep learning frameworks to be as 

2478 reproducible as possible by setting seeds and choosing deterministic algorithms, 

2479 but differences in operating systems, available hardware and installed drivers 

2480 may still lead to numerical differences. 

2481 """ 

2482 

2483 relative_tolerance: RelativeTolerance = 1e-3 

2484 """Maximum relative tolerance of reproduced test tensor.""" 

2485 

2486 absolute_tolerance: AbsoluteTolerance = 1e-4 

2487 """Maximum absolute tolerance of reproduced test tensor.""" 

2488 

2489 mismatched_elements_per_million: MismatchedElementsPerMillion = 0 

2490 """Maximum number of mismatched elements/pixels per million to tolerate.""" 

2491 

2492 output_ids: Sequence[TensorId] = () 

2493 """Limits the output tensor IDs these reproducibility details apply to.""" 

2494 

2495 weights_formats: Sequence[WeightsFormat] = () 

2496 """Limits the weights formats these details apply to.""" 

2497 

2498 
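# --- Illustrative sketch (not part of the original source) -------------------
# The mismatch criterion documented above, written out with numpy (`np` is the
# numpy import at the top of this module; arrays and tolerances are
# hypothetical):
def _sketch_reproducibility_check(
    output: NDArray[Any], test_tensor: NDArray[Any]
) -> bool:
    rtol, atol, mismatched_per_million = 1e-3, 1e-4, 0
    mismatched = np.abs(output - test_tensor) > atol + rtol * np.abs(test_tensor)
    return int(mismatched.sum()) <= mismatched_per_million * output.size / 1e6
# ------------------------------------------------------------------------------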

2499class BioimageioConfig(Node, extra="allow"): 

2500 reproducibility_tolerance: Sequence[ReproducibilityTolerance] = () 

2501 """Tolerances to allow when reproducing the model's test outputs 

2502 from the model's test inputs. 

2503 Only the first entry matching tensor id and weights format is considered. 

2504 """ 

2505 

2506 

2507class Config(Node, extra="allow"): 

2508 bioimageio: BioimageioConfig = Field(default_factory=BioimageioConfig) 

2509 

2510 

2511class ModelDescr(GenericModelDescrBase): 

2512 """Specification of the fields used in a bioimage.io-compliant RDF to describe AI models with pretrained weights. 

2513 These fields are typically stored in a YAML file which we call a model resource description file (model RDF). 

2514 """ 

2515 

2516 implemented_format_version: ClassVar[Literal["0.5.4"]] = "0.5.4" 

2517 if TYPE_CHECKING: 

2518 format_version: Literal["0.5.4"] = "0.5.4" 

2519 else: 

2520 format_version: Literal["0.5.4"] 

2521 """Version of the bioimage.io model description specification used. 

2522 When creating a new model always use the latest micro/patch version described here. 

2523 The `format_version` is important for any consumer software to understand how to parse the fields. 

2524 """ 

2525 

2526 implemented_type: ClassVar[Literal["model"]] = "model" 

2527 if TYPE_CHECKING: 

2528 type: Literal["model"] = "model" 

2529 else: 

2530 type: Literal["model"] 

2531 """Specialized resource type 'model'""" 

2532 

2533 id: Optional[ModelId] = None 

2534 """bioimage.io-wide unique resource identifier 

2535 assigned by bioimage.io; version **un**specific.""" 

2536 

2537 authors: NotEmpty[List[Author]] 

2538 """The authors are the creators of the model RDF and the primary points of contact.""" 

2539 

2540 documentation: Annotated[ 

2541 DocumentationSource, 

2542 Field( 

2543 examples=[ 

2544 "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/unet2d_nuclei_broad/README.md", 

2545 "README.md", 

2546 ], 

2547 ), 

2548 ] 

2549 """∈📦 URL or relative path to a markdown file with additional documentation. 

2550 The recommended documentation file name is `README.md`. An `.md` suffix is mandatory. 

2551 The documentation should include a '#[#] Validation' (sub)section 

2552 with details on how to quantitatively validate the model on unseen data.""" 

2553 

2554 @field_validator("documentation", mode="after") 

2555 @classmethod 

2556 def _validate_documentation(cls, value: DocumentationSource) -> DocumentationSource: 

2557 if not get_validation_context().perform_io_checks: 

2558 return value 

2559 

2560 doc_path = download(value).path 

2561 doc_content = doc_path.read_text(encoding="utf-8") 

2562 assert isinstance(doc_content, str) 

2563 if not re.search("#.*[vV]alidation", doc_content): 

2564 issue_warning( 

2565 "No '# Validation' (sub)section found in {value}.", 

2566 value=value, 

2567 field="documentation", 

2568 ) 

2569 

2570 return value 

2571 

2572 inputs: NotEmpty[Sequence[InputTensorDescr]] 

2573 """Describes the input tensors expected by this model.""" 

2574 

2575 @field_validator("inputs", mode="after") 

2576 @classmethod 

2577 def _validate_input_axes( 

2578 cls, inputs: Sequence[InputTensorDescr] 

2579 ) -> Sequence[InputTensorDescr]: 

2580 input_size_refs = cls._get_axes_with_independent_size(inputs) 

2581 

2582 for i, ipt in enumerate(inputs): 

2583 valid_independent_refs: Dict[ 

2584 Tuple[TensorId, AxisId], 

2585 Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]], 

2586 ] = { 

2587 **{ 

2588 (ipt.id, a.id): (ipt, a, a.size) 

2589 for a in ipt.axes 

2590 if not isinstance(a, BatchAxis) 

2591 and isinstance(a.size, (int, ParameterizedSize)) 

2592 }, 

2593 **input_size_refs, 

2594 } 

2595 for a, ax in enumerate(ipt.axes): 

2596 cls._validate_axis( 

2597 "inputs", 

2598 i=i, 

2599 tensor_id=ipt.id, 

2600 a=a, 

2601 axis=ax, 

2602 valid_independent_refs=valid_independent_refs, 

2603 ) 

2604 return inputs 

2605 

2606 @staticmethod 

2607 def _validate_axis( 

2608 field_name: str, 

2609 i: int, 

2610 tensor_id: TensorId, 

2611 a: int, 

2612 axis: AnyAxis, 

2613 valid_independent_refs: Dict[ 

2614 Tuple[TensorId, AxisId], 

2615 Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]], 

2616 ], 

2617 ): 

2618 if isinstance(axis, BatchAxis) or isinstance( 

2619 axis.size, (int, ParameterizedSize, DataDependentSize) 

2620 ): 

2621 return 

2622 elif not isinstance(axis.size, SizeReference): 

2623 assert_never(axis.size) 

2624 

2625 # validate axis.size SizeReference 

2626 ref = (axis.size.tensor_id, axis.size.axis_id) 

2627 if ref not in valid_independent_refs: 

2628 raise ValueError( 

2629 "Invalid tensor axis reference at" 

2630 + f" {field_name}[{i}].axes[{a}].size: {axis.size}." 

2631 ) 

2632 if ref == (tensor_id, axis.id): 

2633 raise ValueError( 

2634 "Self-referencing not allowed for" 

2635 + f" {field_name}[{i}].axes[{a}].size: {axis.size}" 

2636 ) 

2637 if axis.type == "channel": 

2638 if valid_independent_refs[ref][1].type != "channel": 

2639 raise ValueError( 

2640 "A channel axis' size may only reference another fixed size" 

2641 + " channel axis." 

2642 ) 

2643 if isinstance(axis.channel_names, str) and "{i}" in axis.channel_names: 

2644 ref_size = valid_independent_refs[ref][2] 

2645 assert isinstance(ref_size, int), ( 

2646 "channel axis ref (another channel axis) has to specify fixed" 

2647 + " size" 

2648 ) 

2649 generated_channel_names = [ 

2650 Identifier(axis.channel_names.format(i=i)) 

2651 for i in range(1, ref_size + 1) 

2652 ] 

2653 axis.channel_names = generated_channel_names 

2654 

2655 if (ax_unit := getattr(axis, "unit", None)) != ( 

2656 ref_unit := getattr(valid_independent_refs[ref][1], "unit", None) 

2657 ): 

2658 raise ValueError( 

2659 "The units of an axis and its reference axis need to match, but" 

2660 + f" '{ax_unit}' != '{ref_unit}'." 

2661 ) 

2662 ref_axis = valid_independent_refs[ref][1] 

2663 if isinstance(ref_axis, BatchAxis): 

2664 raise ValueError( 

2665 f"Invalid reference axis '{ref_axis.id}' for {tensor_id}.{axis.id}" 

2666 + " (a batch axis is not allowed as reference)." 

2667 ) 

2668 

2669 if isinstance(axis, WithHalo): 

2670 min_size = axis.size.get_size(axis, ref_axis, n=0) 

2671 if (min_size - 2 * axis.halo) < 1: 

2672 raise ValueError( 

2673 f"axis {axis.id} with minimum size {min_size} is too small for halo" 

2674 + f" {axis.halo}." 

2675 ) 

2676 

2677 input_halo = axis.halo * axis.scale / ref_axis.scale 

2678 if input_halo != int(input_halo) or input_halo % 2 == 1: 

2679 raise ValueError( 

2680 f"input_halo {input_halo} (output_halo {axis.halo} *" 

2681 + f" output_scale {axis.scale} / input_scale {ref_axis.scale})" 

2682 + f" {tensor_id}.{axis.id}." 

2683 ) 

2684 
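# --- Illustrative sketch (not part of the original source) -------------------
# The halo checks above with hypothetical numbers: an output axis with halo=16
# and scale=1.0 referencing an input axis with scale=0.5 implies
#   input_halo = 16 * 1.0 / 0.5 = 32.0  (an even integer -> accepted),
# whereas halo=17 with equal scales gives input_halo = 17 (odd -> rejected).
# The minimum-size check requires `min_size - 2 * halo >= 1`, so halo=16 needs
# a minimum output size of at least 33.
# ------------------------------------------------------------------------------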

2685 @model_validator(mode="after") 

2686 def _validate_test_tensors(self) -> Self: 

2687 if not get_validation_context().perform_io_checks: 

2688 return self 

2689 

2690 test_output_arrays = [ 

2691 load_array(descr.test_tensor.download().path) for descr in self.outputs 

2692 ] 

2693 test_input_arrays = [ 

2694 load_array(descr.test_tensor.download().path) for descr in self.inputs 

2695 ] 

2696 

2697 tensors = { 

2698 descr.id: (descr, array) 

2699 for descr, array in zip( 

2700 chain(self.inputs, self.outputs), test_input_arrays + test_output_arrays 

2701 ) 

2702 } 

2703 validate_tensors(tensors, tensor_origin="test_tensor") 

2704 

2705 output_arrays = { 

2706 descr.id: array for descr, array in zip(self.outputs, test_output_arrays) 

2707 } 

2708 for rep_tol in self.config.bioimageio.reproducibility_tolerance: 

2709 if not rep_tol.absolute_tolerance: 

2710 continue 

2711 

2712 if rep_tol.output_ids: 

2713 out_arrays = { 

2714 oid: a 

2715 for oid, a in output_arrays.items() 

2716 if oid in rep_tol.output_ids 

2717 } 

2718 else: 

2719 out_arrays = output_arrays 

2720 

2721 for out_id, array in out_arrays.items(): 

2722 if rep_tol.absolute_tolerance > (max_test_value := array.max()) * 0.01: 

2723 raise ValueError( 

2724 "config.bioimageio.reproducibility_tolerance.absolute_tolerance=" 

2725 + f"{rep_tol.absolute_tolerance} > 0.01*{max_test_value}" 

2726 + f" (1% of the maximum value of the test tensor '{out_id}')" 

2727 ) 

2728 

2729 return self 

2730 

2731 @model_validator(mode="after") 

2732 def _validate_tensor_references_in_proc_kwargs(self, info: ValidationInfo) -> Self: 

2733 ipt_refs = {t.id for t in self.inputs} 

2734 out_refs = {t.id for t in self.outputs} 

2735 for ipt in self.inputs: 

2736 for p in ipt.preprocessing: 

2737 ref = p.kwargs.get("reference_tensor") 

2738 if ref is None: 

2739 continue 

2740 if ref not in ipt_refs: 

2741 raise ValueError( 

2742 f"`reference_tensor` '{ref}' not found. Valid input tensor" 

2743 + f" references are: {ipt_refs}." 

2744 ) 

2745 

2746 for out in self.outputs: 

2747 for p in out.postprocessing: 

2748 ref = p.kwargs.get("reference_tensor") 

2749 if ref is None: 

2750 continue 

2751 

2752 if ref not in ipt_refs and ref not in out_refs: 

2753 raise ValueError( 

2754 f"`reference_tensor` '{ref}' not found. Valid tensor references" 

2755 + f" are: {ipt_refs | out_refs}." 

2756 ) 

2757 

2758 return self 

2759 

2760 # TODO: use validate funcs in validate_test_tensors 

2761 # def validate_inputs(self, input_tensors: Mapping[TensorId, NDArray[Any]]) -> Mapping[TensorId, NDArray[Any]]: 

2762 

2763 name: Annotated[ 

2764 Annotated[ 

2765 str, RestrictCharacters(string.ascii_letters + string.digits + "_+- ()") 

2766 ], 

2767 MinLen(5), 

2768 MaxLen(128), 

2769 warn(MaxLen(64), "Name longer than 64 characters.", INFO), 

2770 ] 

2771 """A human-readable name of this model. 

2772 It should be no longer than 64 characters 

2773 and may only contain letters, numbers, underscores, minus signs, parentheses and spaces. 

2774 We recommend choosing a name that refers to the model's task and image modality. 

2775 """ 

2776 

2777 outputs: NotEmpty[Sequence[OutputTensorDescr]] 

2778 """Describes the output tensors.""" 

2779 

2780 @field_validator("outputs", mode="after") 

2781 @classmethod 

2782 def _validate_tensor_ids( 

2783 cls, outputs: Sequence[OutputTensorDescr], info: ValidationInfo 

2784 ) -> Sequence[OutputTensorDescr]: 

2785 tensor_ids = [ 

2786 t.id for t in info.data.get("inputs", []) + info.data.get("outputs", []) 

2787 ] 

2788 duplicate_tensor_ids: List[str] = [] 

2789 seen: Set[str] = set() 

2790 for t in tensor_ids: 

2791 if t in seen: 

2792 duplicate_tensor_ids.append(t) 

2793 

2794 seen.add(t) 

2795 

2796 if duplicate_tensor_ids: 

2797 raise ValueError(f"Duplicate tensor ids: {duplicate_tensor_ids}") 

2798 

2799 return outputs 

2800 

2801 @staticmethod 

2802 def _get_axes_with_parameterized_size( 

2803 io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]], 

2804 ): 

2805 return { 

2806 f"{t.id}.{a.id}": (t, a, a.size) 

2807 for t in io 

2808 for a in t.axes 

2809 if not isinstance(a, BatchAxis) and isinstance(a.size, ParameterizedSize) 

2810 } 

2811 

2812 @staticmethod 

2813 def _get_axes_with_independent_size( 

2814 io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]], 

2815 ): 

2816 return { 

2817 (t.id, a.id): (t, a, a.size) 

2818 for t in io 

2819 for a in t.axes 

2820 if not isinstance(a, BatchAxis) 

2821 and isinstance(a.size, (int, ParameterizedSize)) 

2822 } 

2823 

2824 @field_validator("outputs", mode="after") 

2825 @classmethod 

2826 def _validate_output_axes( 

2827 cls, outputs: List[OutputTensorDescr], info: ValidationInfo 

2828 ) -> List[OutputTensorDescr]: 

2829 input_size_refs = cls._get_axes_with_independent_size( 

2830 info.data.get("inputs", []) 

2831 ) 

2832 output_size_refs = cls._get_axes_with_independent_size(outputs) 

2833 

2834 for i, out in enumerate(outputs): 

2835 valid_independent_refs: Dict[ 

2836 Tuple[TensorId, AxisId], 

2837 Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]], 

2838 ] = { 

2839 **{ 

2840 (out.id, a.id): (out, a, a.size) 

2841 for a in out.axes 

2842 if not isinstance(a, BatchAxis) 

2843 and isinstance(a.size, (int, ParameterizedSize)) 

2844 }, 

2845 **input_size_refs, 

2846 **output_size_refs, 

2847 } 

2848 for a, ax in enumerate(out.axes): 

2849 cls._validate_axis( 

2850 "outputs", 

2851 i, 

2852 out.id, 

2853 a, 

2854 ax, 

2855 valid_independent_refs=valid_independent_refs, 

2856 ) 

2857 

2858 return outputs 

2859 

2860 packaged_by: List[Author] = Field(default_factory=list) 

2861 """The persons that have packaged and uploaded this model. 

2862 Only required if those persons differ from the `authors`.""" 

2863 

2864 parent: Optional[LinkedModel] = None 

2865 """The model from which this model is derived, e.g. by fine-tuning the weights.""" 

2866 

2867 @model_validator(mode="after") 

2868 def _validate_parent_is_not_self(self) -> Self: 

2869 if self.parent is not None and self.parent.id == self.id: 

2870 raise ValueError("A model description may not reference itself as parent.") 

2871 

2872 return self 

2873 

2874 run_mode: Annotated[ 

2875 Optional[RunMode], 

2876 warn(None, "Run mode '{value}' has limited support across consumer software."), 

2877 ] = None 

2878 """Custom run mode for this model: for more complex prediction procedures like test time 

2879 data augmentation that currently cannot be expressed in the specification. 

2880 No standard run modes are defined yet.""" 

2881 

2882 timestamp: Datetime = Field(default_factory=Datetime.now) 

2883 """Timestamp in [ISO 8601](#https://en.wikipedia.org/wiki/ISO_8601) format 

2884 with a few restrictions listed [here](https://docs.python.org/3/library/datetime.html#datetime.datetime.fromisoformat). 

2885 (In Python a datetime object is valid, too).""" 

2886 

2887 training_data: Annotated[ 

2888 Union[None, LinkedDataset, DatasetDescr, DatasetDescr02], 

2889 Field(union_mode="left_to_right"), 

2890 ] = None 

2891 """The dataset used to train this model""" 

2892 

2893 weights: Annotated[WeightsDescr, WrapSerializer(package_weights)] 

2894 """The weights for this model. 

2895 Weights can be given for different formats, but should otherwise be equivalent. 

2896 The available weight formats determine which consumers can use this model.""" 

2897 

2898 config: Config = Field(default_factory=Config) 

2899 

2900 @model_validator(mode="after") 

2901 def _add_default_cover(self) -> Self: 

2902 if not get_validation_context().perform_io_checks or self.covers: 

2903 return self 

2904 

2905 try: 

2906 generated_covers = generate_covers( 

2907 [(t, load_array(t.test_tensor.download().path)) for t in self.inputs], 

2908 [(t, load_array(t.test_tensor.download().path)) for t in self.outputs], 

2909 ) 

2910 except Exception as e: 

2911 issue_warning( 

2912 "Failed to generate cover image(s): {e}", 

2913 value=self.covers, 

2914 msg_context=dict(e=e), 

2915 field="covers", 

2916 ) 

2917 else: 

2918 self.covers.extend(generated_covers) 

2919 

2920 return self 

2921 

2922 def get_input_test_arrays(self) -> List[NDArray[Any]]: 

2923 data = [load_array(ipt.test_tensor.download().path) for ipt in self.inputs] 

2924 assert all(isinstance(d, np.ndarray) for d in data) 

2925 return data 

2926 

2927 def get_output_test_arrays(self) -> List[NDArray[Any]]: 

2928 data = [load_array(out.test_tensor.download().path) for out in self.outputs] 

2929 assert all(isinstance(d, np.ndarray) for d in data) 

2930 return data 

2931 

2932 @staticmethod 

2933 def get_batch_size(tensor_sizes: Mapping[TensorId, Mapping[AxisId, int]]) -> int: 

2934 batch_size = 1 

2935 tensor_with_batchsize: Optional[TensorId] = None 

2936 for tid in tensor_sizes: 

2937 for aid, s in tensor_sizes[tid].items(): 

2938 if aid != BATCH_AXIS_ID or s == 1 or s == batch_size: 

2939 continue 

2940 

2941 if batch_size != 1: 

2942 assert tensor_with_batchsize is not None 

2943 raise ValueError( 

2944 f"batch size mismatch for tensors '{tensor_with_batchsize}' ({batch_size}) and '{tid}' ({s})" 

2945 ) 

2946 

2947 batch_size = s 

2948 tensor_with_batchsize = tid 

2949 

2950 return batch_size 

2951 

2952 def get_output_tensor_sizes( 

2953 self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]] 

2954 ) -> Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]: 

2955 """Returns the tensor output sizes for given **input_sizes**. 

2956 The returned output sizes are exact only if **input_sizes** describes a valid input shape. 

2957 Otherwise they may be larger than the actual (valid) output sizes.""" 

2958 batch_size = self.get_batch_size(input_sizes) 

2959 ns = self.get_ns(input_sizes) 

2960 

2961 tensor_sizes = self.get_tensor_sizes(ns, batch_size=batch_size) 

2962 return tensor_sizes.outputs 

2963 

2964 def get_ns(self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]): 

2965 """get parameter `n` for each parameterized axis 

2966 such that the valid input size is >= the given input size""" 

2967 ret: Dict[Tuple[TensorId, AxisId], ParameterizedSize_N] = {} 

2968 axes = {t.id: {a.id: a for a in t.axes} for t in self.inputs} 

2969 for tid in input_sizes: 

2970 for aid, s in input_sizes[tid].items(): 

2971 size_descr = axes[tid][aid].size 

2972 if isinstance(size_descr, ParameterizedSize): 

2973 ret[(tid, aid)] = size_descr.get_n(s) 

2974 elif size_descr is None or isinstance(size_descr, (int, SizeReference)): 

2975 pass 

2976 else: 

2977 assert_never(size_descr) 

2978 

2979 return ret 

2980 

2981 def get_tensor_sizes( 

2982 self, ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], batch_size: int 

2983 ) -> _TensorSizes: 

2984 axis_sizes = self.get_axis_sizes(ns, batch_size=batch_size) 

2985 return _TensorSizes( 

2986 { 

2987 t: { 

2988 aa: axis_sizes.inputs[(tt, aa)] 

2989 for tt, aa in axis_sizes.inputs 

2990 if tt == t 

2991 } 

2992 for t in {tt for tt, _ in axis_sizes.inputs} 

2993 }, 

2994 { 

2995 t: { 

2996 aa: axis_sizes.outputs[(tt, aa)] 

2997 for tt, aa in axis_sizes.outputs 

2998 if tt == t 

2999 } 

3000 for t in {tt for tt, _ in axis_sizes.outputs} 

3001 }, 

3002 ) 

3003 

3004 def get_axis_sizes( 

3005 self, 

3006 ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], 

3007 batch_size: Optional[int] = None, 

3008 *, 

3009 max_input_shape: Optional[Mapping[Tuple[TensorId, AxisId], int]] = None, 

3010 ) -> _AxisSizes: 

3011 """Determine input and output block shape for scale factors **ns** 

3012 of parameterized input sizes. 

3013 

3014 Args: 

3015 ns: Scale factor `n` for each axis (keyed by (tensor_id, axis_id)) 

3016 that is parameterized as `size = min + n * step`. 

3017 batch_size: The desired size of the batch dimension. 

3018 If given **batch_size** overwrites any batch size present in 

3019 **max_input_shape**. Default 1. 

3020 max_input_shape: Limits the derived block shapes. 

3021 For each axis whose parameterized input size (given `n`) exceeds 

3022 **max_input_shape**, `n` is reduced to the smallest value for which the 

3023 parameterized size still covers **max_input_shape**. 

3024 Use this for small input samples or large values of **ns**. 

3025 Or simply whenever you know the full input shape. 

3026 

3027 Returns: 

3028 Resolved axis sizes for model inputs and outputs. 

3029 """ 

3030 max_input_shape = max_input_shape or {} 

3031 if batch_size is None: 

3032 for (_t_id, a_id), s in max_input_shape.items(): 

3033 if a_id == BATCH_AXIS_ID: 

3034 batch_size = s 

3035 break 

3036 else: 

3037 batch_size = 1 

3038 

3039 all_axes = { 

3040 t.id: {a.id: a for a in t.axes} for t in chain(self.inputs, self.outputs) 

3041 } 

3042 

3043 inputs: Dict[Tuple[TensorId, AxisId], int] = {} 

3044 outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]] = {} 

3045 

3046 def get_axis_size(a: Union[InputAxis, OutputAxis]): 

3047 if isinstance(a, BatchAxis): 

3048 if (t_descr.id, a.id) in ns: 

3049 logger.warning( 

3050 "Ignoring unexpected size increment factor (n) for batch axis" 

3051 + " of tensor '{}'.", 

3052 t_descr.id, 

3053 ) 

3054 return batch_size 

3055 elif isinstance(a.size, int): 

3056 if (t_descr.id, a.id) in ns: 

3057 logger.warning( 

3058 "Ignoring unexpected size increment factor (n) for fixed size" 

3059 + " axis '{}' of tensor '{}'.", 

3060 a.id, 

3061 t_descr.id, 

3062 ) 

3063 return a.size 

3064 elif isinstance(a.size, ParameterizedSize): 

3065 if (t_descr.id, a.id) not in ns: 

3066 raise ValueError( 

3067 "Size increment factor (n) missing for parametrized axis" 

3068 + f" '{a.id}' of tensor '{t_descr.id}'." 

3069 ) 

3070 n = ns[(t_descr.id, a.id)] 

3071 s_max = max_input_shape.get((t_descr.id, a.id)) 

3072 if s_max is not None: 

3073 n = min(n, a.size.get_n(s_max)) 

3074 

3075 return a.size.get_size(n) 

3076 

3077 elif isinstance(a.size, SizeReference): 

3078 if (t_descr.id, a.id) in ns: 

3079 logger.warning( 

3080 "Ignoring unexpected size increment factor (n) for axis '{}'" 

3081 + " of tensor '{}' with size reference.", 

3082 a.id, 

3083 t_descr.id, 

3084 ) 

3085 assert not isinstance(a, BatchAxis) 

3086 ref_axis = all_axes[a.size.tensor_id][a.size.axis_id] 

3087 assert not isinstance(ref_axis, BatchAxis) 

3088 ref_key = (a.size.tensor_id, a.size.axis_id) 

3089 ref_size = inputs.get(ref_key, outputs.get(ref_key)) 

3090 assert ref_size is not None, ref_key 

3091 assert not isinstance(ref_size, _DataDepSize), ref_key 

3092 return a.size.get_size( 

3093 axis=a, 

3094 ref_axis=ref_axis, 

3095 ref_size=ref_size, 

3096 ) 

3097 elif isinstance(a.size, DataDependentSize): 

3098 if (t_descr.id, a.id) in ns: 

3099 logger.warning( 

3100 "Ignoring unexpected increment factor (n) for data dependent" 

3101 + " size axis '{}' of tensor '{}'.", 

3102 a.id, 

3103 t_descr.id, 

3104 ) 

3105 return _DataDepSize(a.size.min, a.size.max) 

3106 else: 

3107 assert_never(a.size) 

3108 

3109 # first resolve all input axis sizes except the `SizeReference` ones 

3110 for t_descr in self.inputs: 

3111 for a in t_descr.axes: 

3112 if not isinstance(a.size, SizeReference): 

3113 s = get_axis_size(a) 

3114 assert not isinstance(s, _DataDepSize) 

3115 inputs[t_descr.id, a.id] = s 

3116 

3117 # resolve all other input axis sizes 

3118 for t_descr in self.inputs: 

3119 for a in t_descr.axes: 

3120 if isinstance(a.size, SizeReference): 

3121 s = get_axis_size(a) 

3122 assert not isinstance(s, _DataDepSize) 

3123 inputs[t_descr.id, a.id] = s 

3124 

3125 # resolve all output axis sizes 

3126 for t_descr in self.outputs: 

3127 for a in t_descr.axes: 

3128 assert not isinstance(a.size, ParameterizedSize) 

3129 s = get_axis_size(a) 

3130 outputs[t_descr.id, a.id] = s 

3131 

3132 return _AxisSizes(inputs=inputs, outputs=outputs) 

3133 
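# --- Illustrative sketch (not part of the original source) -------------------
# A parameterized axis follows `size = min + n * step`. With hypothetical
# values min=64 and step=16, the factor n=3 yields an input size of
# 64 + 3 * 16 = 112. Conversely, `get_ns` picks the smallest n whose valid size
# covers a requested input size: for a requested size of 100 that is n=3
# (n=2 gives only 96).
# ------------------------------------------------------------------------------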

3134 @model_validator(mode="before") 

3135 @classmethod 

3136 def _convert(cls, data: Dict[str, Any]) -> Dict[str, Any]: 

3137 cls.convert_from_old_format_wo_validation(data) 

3138 return data 

3139 

3140 @classmethod 

3141 def convert_from_old_format_wo_validation(cls, data: Dict[str, Any]) -> None: 

3142 """Convert metadata following an older format version to this classes' format 

3143 without validating the result. 

3144 """ 

3145 if ( 

3146 data.get("type") == "model" 

3147 and isinstance(fv := data.get("format_version"), str) 

3148 and fv.count(".") == 2 

3149 ): 

3150 fv_parts = fv.split(".") 

3151 if any(not p.isdigit() for p in fv_parts): 

3152 return 

3153 

3154 fv_tuple = tuple(map(int, fv_parts)) 

3155 

3156 assert cls.implemented_format_version_tuple[0:2] == (0, 5) 

3157 if fv_tuple[:2] in ((0, 3), (0, 4)): 

3158 m04 = _ModelDescr_v0_4.load(data) 

3159 if isinstance(m04, InvalidDescr): 

3160 try: 

3161 updated = _model_conv.convert_as_dict( 

3162 m04 # pyright: ignore[reportArgumentType] 

3163 ) 

3164 except Exception as e: 

3165 logger.error( 

3166 "Failed to convert from invalid model 0.4 description." 

3167 + f"\nerror: {e}" 

3168 + "\nProceeding with model 0.5 validation without conversion." 

3169 ) 

3170 updated = None 

3171 else: 

3172 updated = _model_conv.convert_as_dict(m04) 

3173 

3174 if updated is not None: 

3175 data.clear() 

3176 data.update(updated) 

3177 

3178 elif fv_tuple[:2] == (0, 5): 

3179 # bump patch version 

3180 data["format_version"] = cls.implemented_format_version 

3181 
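# --- Illustrative sketch (not part of the original source) -------------------
# `convert_from_old_format_wo_validation` mutates the raw RDF dict in place. A
# hypothetical, schematic call converting a model 0.4 description before
# validation:
#   data = {"type": "model", "format_version": "0.4.10", ...}
#   ModelDescr.convert_from_old_format_wo_validation(data)
#   # `data` now carries format_version "0.5.4", or is left unchanged if the
#   # 0.4 description could not be converted.
# ------------------------------------------------------------------------------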

3182 

3183class _ModelConv(Converter[_ModelDescr_v0_4, ModelDescr]): 

3184 def _convert( 

3185 self, src: _ModelDescr_v0_4, tgt: "type[ModelDescr] | type[dict[str, Any]]" 

3186 ) -> "ModelDescr | dict[str, Any]": 

3187 name = "".join( 

3188 c if c in string.ascii_letters + string.digits + "_+- ()" else " " 

3189 for c in src.name 

3190 ) 

3191 

3192 def conv_authors(auths: Optional[Sequence[_Author_v0_4]]): 

3193 conv = ( 

3194 _author_conv.convert if TYPE_CHECKING else _author_conv.convert_as_dict 

3195 ) 

3196 return None if auths is None else [conv(a) for a in auths] 

3197 

3198 if TYPE_CHECKING: 

3199 arch_file_conv = _arch_file_conv.convert 

3200 arch_lib_conv = _arch_lib_conv.convert 

3201 else: 

3202 arch_file_conv = _arch_file_conv.convert_as_dict 

3203 arch_lib_conv = _arch_lib_conv.convert_as_dict 

3204 

3205 input_size_refs = { 

3206 ipt.name: { 

3207 a: s 

3208 for a, s in zip( 

3209 ipt.axes, 

3210 ( 

3211 ipt.shape.min 

3212 if isinstance(ipt.shape, _ParameterizedInputShape_v0_4) 

3213 else ipt.shape 

3214 ), 

3215 ) 

3216 } 

3217 for ipt in src.inputs 

3218 if ipt.shape 

3219 } 

3220 output_size_refs = { 

3221 **{ 

3222 out.name: {a: s for a, s in zip(out.axes, out.shape)} 

3223 for out in src.outputs 

3224 if not isinstance(out.shape, _ImplicitOutputShape_v0_4) 

3225 }, 

3226 **input_size_refs, 

3227 } 

3228 

3229 return tgt( 

3230 attachments=( 

3231 [] 

3232 if src.attachments is None 

3233 else [FileDescr(source=f) for f in src.attachments.files] 

3234 ), 

3235 authors=[ 

3236 _author_conv.convert_as_dict(a) for a in src.authors 

3237 ], # pyright: ignore[reportArgumentType] 

3238 cite=[ 

3239 {"text": c.text, "doi": c.doi, "url": c.url} for c in src.cite 

3240 ], # pyright: ignore[reportArgumentType] 

3241 config=src.config, # pyright: ignore[reportArgumentType] 

3242 covers=src.covers, 

3243 description=src.description, 

3244 documentation=src.documentation, 

3245 format_version="0.5.4", 

3246 git_repo=src.git_repo, # pyright: ignore[reportArgumentType] 

3247 icon=src.icon, 

3248 id=None if src.id is None else ModelId(src.id), 

3249 id_emoji=src.id_emoji, 

3250 license=src.license, # type: ignore 

3251 links=src.links, 

3252 maintainers=[ 

3253 _maintainer_conv.convert_as_dict(m) for m in src.maintainers 

3254 ], # pyright: ignore[reportArgumentType] 

3255 name=name, 

3256 tags=src.tags, 

3257 type=src.type, 

3258 uploader=src.uploader, 

3259 version=src.version, 

3260 inputs=[ # pyright: ignore[reportArgumentType] 

3261 _input_tensor_conv.convert_as_dict(ipt, tt, st, input_size_refs) 

3262 for ipt, tt, st, in zip( 

3263 src.inputs, 

3264 src.test_inputs, 

3265 src.sample_inputs or [None] * len(src.test_inputs), 

3266 ) 

3267 ], 

3268 outputs=[ # pyright: ignore[reportArgumentType] 

3269 _output_tensor_conv.convert_as_dict(out, tt, st, output_size_refs) 

3270 for out, tt, st, in zip( 

3271 src.outputs, 

3272 src.test_outputs, 

3273 src.sample_outputs or [None] * len(src.test_outputs), 

3274 ) 

3275 ], 

3276 parent=( 

3277 None 

3278 if src.parent is None 

3279 else LinkedModel( 

3280 id=ModelId( 

3281 str(src.parent.id) 

3282 + ( 

3283 "" 

3284 if src.parent.version_number is None 

3285 else f"/{src.parent.version_number}" 

3286 ) 

3287 ) 

3288 ) 

3289 ), 

3290 training_data=( 

3291 None 

3292 if src.training_data is None 

3293 else ( 

3294 LinkedDataset( 

3295 id=DatasetId( 

3296 str(src.training_data.id) 

3297 + ( 

3298 "" 

3299 if src.training_data.version_number is None 

3300 else f"/{src.training_data.version_number}" 

3301 ) 

3302 ) 

3303 ) 

3304 if isinstance(src.training_data, LinkedDataset02) 

3305 else src.training_data 

3306 ) 

3307 ), 

3308 packaged_by=[ 

3309 _author_conv.convert_as_dict(a) for a in src.packaged_by 

3310 ], # pyright: ignore[reportArgumentType] 

3311 run_mode=src.run_mode, 

3312 timestamp=src.timestamp, 

3313 weights=(WeightsDescr if TYPE_CHECKING else dict)( 

3314 keras_hdf5=(w := src.weights.keras_hdf5) 

3315 and (KerasHdf5WeightsDescr if TYPE_CHECKING else dict)( 

3316 authors=conv_authors(w.authors), 

3317 source=w.source, 

3318 tensorflow_version=w.tensorflow_version or Version("1.15"), 

3319 parent=w.parent, 

3320 ), 

3321 onnx=(w := src.weights.onnx) 

3322 and (OnnxWeightsDescr if TYPE_CHECKING else dict)( 

3323 source=w.source, 

3324 authors=conv_authors(w.authors), 

3325 parent=w.parent, 

3326 opset_version=w.opset_version or 15, 

3327 ), 

3328 pytorch_state_dict=(w := src.weights.pytorch_state_dict) 

3329 and (PytorchStateDictWeightsDescr if TYPE_CHECKING else dict)( 

3330 source=w.source, 

3331 authors=conv_authors(w.authors), 

3332 parent=w.parent, 

3333 architecture=( 

3334 arch_file_conv( 

3335 w.architecture, 

3336 w.architecture_sha256, 

3337 w.kwargs, 

3338 ) 

3339 if isinstance(w.architecture, _CallableFromFile_v0_4) 

3340 else arch_lib_conv(w.architecture, w.kwargs) 

3341 ), 

3342 pytorch_version=w.pytorch_version or Version("1.10"), 

3343 dependencies=( 

3344 None 

3345 if w.dependencies is None 

3346 else (EnvironmentFileDescr if TYPE_CHECKING else dict)( 

3347 source=cast( 

3348 ImportantFileSource, 

3349 str(deps := w.dependencies)[ 

3350 ( 

3351 len("conda:") 

3352 if str(deps).startswith("conda:") 

3353 else 0 

3354 ) : 

3355 ], 

3356 ) 

3357 ) 

3358 ), 

3359 ), 

3360 tensorflow_js=(w := src.weights.tensorflow_js) 

3361 and (TensorflowJsWeightsDescr if TYPE_CHECKING else dict)( 

3362 source=w.source, 

3363 authors=conv_authors(w.authors), 

3364 parent=w.parent, 

3365 tensorflow_version=w.tensorflow_version or Version("1.15"), 

3366 ), 

3367 tensorflow_saved_model_bundle=( 

3368 w := src.weights.tensorflow_saved_model_bundle 

3369 ) 

3370 and (TensorflowSavedModelBundleWeightsDescr if TYPE_CHECKING else dict)( 

3371 authors=conv_authors(w.authors), 

3372 parent=w.parent, 

3373 source=w.source, 

3374 tensorflow_version=w.tensorflow_version or Version("1.15"), 

3375 dependencies=( 

3376 None 

3377 if w.dependencies is None 

3378 else (EnvironmentFileDescr if TYPE_CHECKING else dict)( 

3379 source=cast( 

3380 ImportantFileSource, 

3381 ( 

3382 str(w.dependencies)[len("conda:") :] 

3383 if str(w.dependencies).startswith("conda:") 

3384 else str(w.dependencies) 

3385 ), 

3386 ) 

3387 ) 

3388 ), 

3389 ), 

3390 torchscript=(w := src.weights.torchscript) 

3391 and (TorchscriptWeightsDescr if TYPE_CHECKING else dict)( 

3392 source=w.source, 

3393 authors=conv_authors(w.authors), 

3394 parent=w.parent, 

3395 pytorch_version=w.pytorch_version or Version("1.10"), 

3396 ), 

3397 ), 

3398 ) 

3399 

3400 

3401_model_conv = _ModelConv(_ModelDescr_v0_4, ModelDescr) 

3402 

3403 

3404# create better cover images for 3d data and non-image outputs 

3405def generate_covers( 

3406 inputs: Sequence[Tuple[InputTensorDescr, NDArray[Any]]], 

3407 outputs: Sequence[Tuple[OutputTensorDescr, NDArray[Any]]], 

3408) -> List[Path]: 

3409 def squeeze( 

3410 data: NDArray[Any], axes: Sequence[AnyAxis] 

3411 ) -> Tuple[NDArray[Any], List[AnyAxis]]: 

3412 """apply numpy.ndarray.squeeze while keeping track of the axis descriptions remaining""" 

3413 if data.ndim != len(axes): 

3414 raise ValueError( 

3415 f"tensor shape {data.shape} does not match described axes" 

3416 + f" {[a.id for a in axes]}" 

3417 ) 

3418 

3419 axes = [deepcopy(a) for a, s in zip(axes, data.shape) if s != 1] 

3420 return data.squeeze(), axes 

3421 

3422 def normalize( 

3423 data: NDArray[Any], axis: Optional[Tuple[int, ...]], eps: float = 1e-7 

3424 ) -> NDArray[np.float32]: 

3425 data = data.astype("float32") 

3426 data -= data.min(axis=axis, keepdims=True) 

3427 data /= data.max(axis=axis, keepdims=True) + eps 

3428 return data 

3429 
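# --- Illustrative sketch (not part of the original source) -------------------
# `normalize` min-max scales the data towards [0, 1) along the given axes
# before the cover image is mapped to uint8 further below. For a hypothetical
# array with values [0, 5, 10]:
#   (x - 0) / (10 + 1e-7) -> [0.0, ~0.5, ~1.0], then * 255 -> uint8 [0, 127, 254]
# ------------------------------------------------------------------------------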

3430 def to_2d_image(data: NDArray[Any], axes: Sequence[AnyAxis]): 

3431 original_shape = data.shape 

3432 data, axes = squeeze(data, axes) 

3433 

3434 # take slice from any batch or index axis if needed 

3435 # and convert the first channel axis and take a slice from any additional channel axes 

3436 slices: Tuple[slice, ...] = () 

3437 ndim = data.ndim 

3438 ndim_need = 3 if any(isinstance(a, ChannelAxis) for a in axes) else 2 

3439 has_c_axis = False 

3440 for i, a in enumerate(axes): 

3441 s = data.shape[i] 

3442 assert s > 1 

3443 if ( 

3444 isinstance(a, (BatchAxis, IndexInputAxis, IndexOutputAxis)) 

3445 and ndim > ndim_need 

3446 ): 

3447 data = data[slices + (slice(s // 2 - 1, s // 2),)] 

3448 ndim -= 1 

3449 elif isinstance(a, ChannelAxis): 

3450 if has_c_axis: 

3451 # second channel axis 

3452 data = data[slices + (slice(0, 1),)] 

3453 ndim -= 1 

3454 else: 

3455 has_c_axis = True 

3456 if s == 2: 

3457 # visualize two channels with cyan and magenta 

3458 data = np.concatenate( 

3459 [ 

3460 data[slices + (slice(1, 2),)], 

3461 data[slices + (slice(0, 1),)], 

3462 ( 

3463 data[slices + (slice(0, 1),)] 

3464 + data[slices + (slice(1, 2),)] 

3465 ) 

3466 / 2, # TODO: take maximum instead? 

3467 ], 

3468 axis=i, 

3469 ) 

3470 elif data.shape[i] == 3: 

3471 pass # visualize 3 channels as RGB 

3472 else: 

3473 # visualize first 3 channels as RGB 

3474 data = data[slices + (slice(3),)] 

3475 

3476 assert data.shape[i] == 3 

3477 

3478 slices += (slice(None),) 

3479 

3480 data, axes = squeeze(data, axes) 

3481 assert len(axes) == ndim 

3482 # take slice from z axis if needed 

3483 slices = () 

3484 if ndim > ndim_need: 

3485 for i, a in enumerate(axes): 

3486 s = data.shape[i] 

3487 if a.id == AxisId("z"): 

3488 data = data[slices + (slice(s // 2 - 1, s // 2),)] 

3489 data, axes = squeeze(data, axes) 

3490 ndim -= 1 

3491 break 

3492 

3493 slices += (slice(None),) 

3494 

3495 # take slice from any space or time axis 

3496 slices = () 

3497 

3498 for i, a in enumerate(axes): 

3499 if ndim <= ndim_need: 

3500 break 

3501 

3502 s = data.shape[i] 

3503 assert s > 1 

3504 if isinstance( 

3505 a, (SpaceInputAxis, SpaceOutputAxis, TimeInputAxis, TimeOutputAxis) 

3506 ): 

3507 data = data[slices + (slice(s // 2 - 1, s // 2),)] 

3508 ndim -= 1 

3509 

3510 slices += (slice(None),) 

3511 

3512 del slices 

3513 data, axes = squeeze(data, axes) 

3514 assert len(axes) == ndim 

3515 

3516        if (has_c_axis and ndim != 3) or (not has_c_axis and ndim != 2): 

3517 raise ValueError( 

3518 f"Failed to construct cover image from shape {original_shape}" 

3519 ) 

3520 

3521 if not has_c_axis: 

3522 assert ndim == 2 

3523 data = np.repeat(data[:, :, None], 3, axis=2) 

3524 axes.append(ChannelAxis(channel_names=list(map(Identifier, "RGB")))) 

3525 ndim += 1 

3526 

3527 assert ndim == 3 

3528 

3529 # transpose axis order such that longest axis comes first... 

3530 axis_order: List[int] = list(np.argsort(list(data.shape))) 

3531 axis_order.reverse() 

3532 # ... and channel axis is last 

3533 c = [i for i in range(3) if isinstance(axes[i], ChannelAxis)][0] 

3534        axis_order.append(axis_order.pop(axis_order.index(c))) 

3535 axes = [axes[ao] for ao in axis_order] 

3536 data = data.transpose(axis_order) 

3537 

3538 # h, w = data.shape[:2] 

3539        # if h / w in (1.0, 2.0): 

3540 # pass 

3541 # elif h / w < 2: 

3542 # TODO: enforce 2:1 or 1:1 aspect ratio for generated cover images 

3543 

3544 norm_along = ( 

3545 tuple(i for i, a in enumerate(axes) if a.type in ("space", "time")) or None 

3546 ) 

3547 # normalize the data and map to 8 bit 

3548 data = normalize(data, norm_along) 

3549 data = (data * 255).astype("uint8") 

3550 

3551 return data 

3552 
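# Editor's note (illustrative sketch, not part of v0_5.py): two ingredients of
# ``to_2d_image`` shown in isolation. A centered width-1 slice such as
# ``slice(s // 2 - 1, s // 2)`` picks one plane from a surplus axis (e.g. z or
# time), and transposing by descending size with the channel axis moved last
# yields an (h, w, 3) array that an image writer can consume. The array below
# is a made-up (channel, z, y, x) volume.
import numpy as np

_vol = np.random.rand(3, 16, 64, 32)                  # (channel, z, y, x)
_z = _vol.shape[1]
_plane = _vol[:, _z // 2 - 1 : _z // 2].squeeze(1)    # middle z plane -> (3, 64, 32)

_order = list(np.argsort(_plane.shape))[::-1]         # axes by descending size
_order.append(_order.pop(_order.index(0)))            # channel axis (index 0) goes last
_img = _plane.transpose(_order)                       # -> (64, 32, 3)
assert _img.shape == (64, 32, 3)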

3553 def create_diagonal_split_image(im0: NDArray[Any], im1: NDArray[Any]): 

3554 assert im0.dtype == im1.dtype == np.uint8 

3555 assert im0.shape == im1.shape 

3556 assert im0.ndim == 3 

3557 N, M, C = im0.shape 

3558 assert C == 3 

3559 out = np.ones((N, M, C), dtype="uint8") 

3560 for c in range(C): 

3561 outc = np.tril(im0[..., c]) 

3562 mask = outc == 0 

3563 outc[mask] = np.triu(im1[..., c])[mask] 

3564 out[..., c] = outc 

3565 

3566 return out 

3567 
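# Editor's note (illustrative sketch, not part of v0_5.py):
# ``create_diagonal_split_image`` composes the cover from the lower triangle
# (np.tril) of the input image and the upper triangle (np.triu) of the output
# image. A 3x3 single-channel toy example of that composition:
import numpy as np

_a = np.full((3, 3), 10, dtype="uint8")   # stand-in "input" image, all 10s
_b = np.full((3, 3), 99, dtype="uint8")   # stand-in "output" image, all 99s
_lower = np.tril(_a)                      # keeps 10s on and below the diagonal
_mask = _lower == 0                       # positions strictly above the diagonal
_lower[_mask] = np.triu(_b)[_mask]        # fill those from the output image
assert _lower[2, 0] == 10 and _lower[0, 2] == 99 and _lower[1, 1] == 10
# Note: genuine zero pixels in the lower triangle of im0 are also overwritten
# by im1; for normalized cover images this is an acceptable simplification.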

3568 ipt_descr, ipt = inputs[0] 

3569 out_descr, out = outputs[0] 

3570 

3571 ipt_img = to_2d_image(ipt, ipt_descr.axes) 

3572 out_img = to_2d_image(out, out_descr.axes) 

3573 

3574 cover_folder = Path(mkdtemp()) 

3575 if ipt_img.shape == out_img.shape: 

3576 covers = [cover_folder / "cover.png"] 

3577 imwrite(covers[0], create_diagonal_split_image(ipt_img, out_img)) 

3578 else: 

3579 covers = [cover_folder / "input.png", cover_folder / "output.png"] 

3580 imwrite(covers[0], ipt_img) 

3581 imwrite(covers[1], out_img) 

3582 

3583 return covers
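# Editor's note (illustrative sketch, not part of v0_5.py): the tail of
# ``generate_covers`` writes the composed uint8 image(s) into a fresh temporary
# directory and returns their paths. A minimal standalone version of that step,
# using a random stand-in cover image:
from pathlib import Path
from tempfile import mkdtemp

import numpy as np
from imageio.v3 import imwrite

_cover_dir = Path(mkdtemp())
_cover = np.random.randint(0, 256, size=(64, 128, 3), dtype="uint8")
_cover_path = _cover_dir / "cover.png"
imwrite(_cover_path, _cover)   # format inferred from the .png suffix
assert _cover_path.exists()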