Coverage for bioimageio/spec/model/v0_5.py: 45%

1233 statements  

coverage.py v7.6.10, created at 2025-02-05 13:53 +0000

from __future__ import annotations

import collections.abc
import re
import string
import warnings
from abc import ABC
from copy import deepcopy
from datetime import datetime
from itertools import chain
from math import ceil
from pathlib import Path, PurePosixPath
from tempfile import mkdtemp
from typing import (
    TYPE_CHECKING,
    Any,
    ClassVar,
    Dict,
    FrozenSet,
    Generic,
    List,
    Literal,
    Mapping,
    NamedTuple,
    Optional,
    Sequence,
    Set,
    Tuple,
    Type,
    TypeVar,
    Union,
    cast,
)

import numpy as np
from annotated_types import Ge, Gt, Interval, MaxLen, MinLen, Predicate
from imageio.v3 import imread, imwrite  # pyright: ignore[reportUnknownVariableType]
from loguru import logger
from numpy.typing import NDArray
from pydantic import (
    Discriminator,
    Field,
    RootModel,
    Tag,
    ValidationInfo,
    WrapSerializer,
    field_validator,
    model_validator,
)
from typing_extensions import Annotated, LiteralString, Self, assert_never, get_args

from .._internal.common_nodes import (
    InvalidDescr,
    Node,
    NodeWithExplicitlySetFields,
)
from .._internal.constants import DTYPE_LIMITS
from .._internal.field_warning import issue_warning, warn
from .._internal.io import BioimageioYamlContent as BioimageioYamlContent
from .._internal.io import FileDescr as FileDescr
from .._internal.io import WithSuffix, YamlValue, download
from .._internal.io_basics import AbsoluteFilePath as AbsoluteFilePath
from .._internal.io_basics import Sha256 as Sha256
from .._internal.io_utils import load_array
from .._internal.node_converter import Converter
from .._internal.types import Datetime as Datetime
from .._internal.types import Identifier as Identifier
from .._internal.types import (
    ImportantFileSource,
    LowerCaseIdentifier,
    LowerCaseIdentifierAnno,
    SiUnit,
)
from .._internal.types import NotEmpty as NotEmpty
from .._internal.url import HttpUrl as HttpUrl
from .._internal.validation_context import validation_context_var
from .._internal.validator_annotations import RestrictCharacters
from .._internal.version_type import Version as Version
from .._internal.warning_levels import INFO
from ..dataset.v0_2 import DatasetDescr as DatasetDescr02
from ..dataset.v0_2 import LinkedDataset as LinkedDataset02
from ..dataset.v0_3 import DatasetDescr as DatasetDescr
from ..dataset.v0_3 import DatasetId as DatasetId
from ..dataset.v0_3 import LinkedDataset as LinkedDataset
from ..dataset.v0_3 import Uploader as Uploader
from ..generic.v0_3 import (
    VALID_COVER_IMAGE_EXTENSIONS as VALID_COVER_IMAGE_EXTENSIONS,
)
from ..generic.v0_3 import Author as Author
from ..generic.v0_3 import BadgeDescr as BadgeDescr
from ..generic.v0_3 import CiteEntry as CiteEntry
from ..generic.v0_3 import DeprecatedLicenseId as DeprecatedLicenseId
from ..generic.v0_3 import (
    DocumentationSource,
    GenericModelDescrBase,
    LinkedResourceBase,
    _author_conv,  # pyright: ignore[reportPrivateUsage]
    _maintainer_conv,  # pyright: ignore[reportPrivateUsage]
)
from ..generic.v0_3 import Doi as Doi
from ..generic.v0_3 import LicenseId as LicenseId
from ..generic.v0_3 import LinkedResource as LinkedResource
from ..generic.v0_3 import Maintainer as Maintainer
from ..generic.v0_3 import OrcidId as OrcidId
from ..generic.v0_3 import RelativeFilePath as RelativeFilePath
from ..generic.v0_3 import ResourceId as ResourceId
from .v0_4 import Author as _Author_v0_4
from .v0_4 import BinarizeDescr as _BinarizeDescr_v0_4
from .v0_4 import CallableFromDepencency as CallableFromDepencency
from .v0_4 import CallableFromDepencency as _CallableFromDepencency_v0_4
from .v0_4 import CallableFromFile as _CallableFromFile_v0_4
from .v0_4 import ClipDescr as _ClipDescr_v0_4
from .v0_4 import ClipKwargs as ClipKwargs
from .v0_4 import ImplicitOutputShape as _ImplicitOutputShape_v0_4
from .v0_4 import InputTensorDescr as _InputTensorDescr_v0_4
from .v0_4 import KnownRunMode as KnownRunMode
from .v0_4 import ModelDescr as _ModelDescr_v0_4
from .v0_4 import OutputTensorDescr as _OutputTensorDescr_v0_4
from .v0_4 import ParameterizedInputShape as _ParameterizedInputShape_v0_4
from .v0_4 import PostprocessingDescr as _PostprocessingDescr_v0_4
from .v0_4 import PreprocessingDescr as _PreprocessingDescr_v0_4
from .v0_4 import ProcessingKwargs as ProcessingKwargs
from .v0_4 import RunMode as RunMode
from .v0_4 import ScaleLinearDescr as _ScaleLinearDescr_v0_4
from .v0_4 import ScaleMeanVarianceDescr as _ScaleMeanVarianceDescr_v0_4
from .v0_4 import ScaleRangeDescr as _ScaleRangeDescr_v0_4
from .v0_4 import SigmoidDescr as _SigmoidDescr_v0_4
from .v0_4 import TensorName as _TensorName_v0_4
from .v0_4 import WeightsFormat as WeightsFormat
from .v0_4 import ZeroMeanUnitVarianceDescr as _ZeroMeanUnitVarianceDescr_v0_4
from .v0_4 import package_weights


SpaceUnit = Literal[
    "attometer",
    "angstrom",
    "centimeter",
    "decimeter",
    "exameter",
    "femtometer",
    "foot",
    "gigameter",
    "hectometer",
    "inch",
    "kilometer",
    "megameter",
    "meter",
    "micrometer",
    "mile",
    "millimeter",
    "nanometer",
    "parsec",
    "petameter",
    "picometer",
    "terameter",
    "yard",
    "yoctometer",
    "yottameter",
    "zeptometer",
    "zettameter",
]
"""Space unit compatible with the [OME-Zarr axes specification 0.5](https://ngff.openmicroscopy.org/0.5/#axes-md)"""

TimeUnit = Literal[
    "attosecond",
    "centisecond",
    "day",
    "decisecond",
    "exasecond",
    "femtosecond",
    "gigasecond",
    "hectosecond",
    "hour",
    "kilosecond",
    "megasecond",
    "microsecond",
    "millisecond",
    "minute",
    "nanosecond",
    "petasecond",
    "picosecond",
    "second",
    "terasecond",
    "yoctosecond",
    "yottasecond",
    "zeptosecond",
    "zettasecond",
]
"""Time unit compatible with the [OME-Zarr axes specification 0.5](https://ngff.openmicroscopy.org/0.5/#axes-md)"""

AxisType = Literal["batch", "channel", "index", "time", "space"]


class TensorId(LowerCaseIdentifier):
    root_model: ClassVar[Type[RootModel[Any]]] = RootModel[
        Annotated[LowerCaseIdentifierAnno, MaxLen(32)]
    ]


class AxisId(LowerCaseIdentifier):
    root_model: ClassVar[Type[RootModel[Any]]] = RootModel[
        Annotated[LowerCaseIdentifierAnno, MaxLen(16)]
    ]


def _is_batch(a: str) -> bool:
    return a == BATCH_AXIS_ID


def _is_not_batch(a: str) -> bool:
    return not _is_batch(a)


NonBatchAxisId = Annotated[AxisId, Predicate(_is_not_batch)]

PostprocessingId = Literal[
    "binarize",
    "clip",
    "ensure_dtype",
    "fixed_zero_mean_unit_variance",
    "scale_linear",
    "scale_mean_variance",
    "scale_range",
    "sigmoid",
    "zero_mean_unit_variance",
]
PreprocessingId = Literal[
    "binarize",
    "clip",
    "ensure_dtype",
    "scale_linear",
    "sigmoid",
    "zero_mean_unit_variance",
    "scale_range",
]


SAME_AS_TYPE = "<same as type>"


ParameterizedSize_N = int


class ParameterizedSize(Node):
    """Describes a range of valid tensor axis sizes as `size = min + n*step`."""

    N: ClassVar[Type[int]] = ParameterizedSize_N
    """integer to parameterize this axis"""

    min: Annotated[int, Gt(0)]
    step: Annotated[int, Gt(0)]

    def validate_size(self, size: int) -> int:
        if size < self.min:
            raise ValueError(f"size {size} < {self.min}")
        if (size - self.min) % self.step != 0:
            raise ValueError(
                f"axis of size {size} is not parameterized by `min + n*step` ="
                + f" `{self.min} + n*{self.step}`"
            )

        return size

    def get_size(self, n: ParameterizedSize_N) -> int:
        return self.min + self.step * n

    def get_n(self, s: int) -> ParameterizedSize_N:
        """return the smallest n parameterizing a size greater than or equal to `s`"""
        return ceil((s - self.min) / self.step)

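# --- Editor's sketch (not part of the original module): minimal usage of
# `ParameterizedSize`, illustrating `size = min + n*step`:
#
#   >>> p = ParameterizedSize(min=32, step=16)
#   >>> p.get_size(2)        # 32 + 2*16
#   64
#   >>> p.get_n(50)          # smallest n such that get_size(n) >= 50
#   2
#   >>> p.validate_size(48)  # 48 = 32 + 1*16 is a valid size
#   48
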

class DataDependentSize(Node):
    min: Annotated[int, Gt(0)] = 1
    max: Annotated[Optional[int], Gt(1)] = None

    @model_validator(mode="after")
    def _validate_max_gt_min(self):
        if self.max is not None and self.min >= self.max:
            raise ValueError(f"expected `min` < `max`, but got {self.min}, {self.max}")

        return self

    def validate_size(self, size: int) -> int:
        if size < self.min:
            raise ValueError(f"size {size} < {self.min}")

        if self.max is not None and size > self.max:
            raise ValueError(f"size {size} > {self.max}")

        return size


class SizeReference(Node):
    """A tensor axis size (extent in pixels/frames) defined in relation to a reference axis.

    `axis.size = reference.size * reference.scale / axis.scale + offset`

    Note:
    1. The axis and the referenced axis need to have the same unit (or no unit).
    2. Batch axes may not be referenced.
    3. Fractions are rounded down.
    4. If the reference axis is `concatenable` the referencing axis is assumed to be
       `concatenable` as well with the same block order.

    Example:
    An anisotropic input image of w*h=100*49 pixels depicts a physical space of 200*196 mm².
    Let's assume that we want to express the image height h in relation to its width w
    instead of only accepting input images of exactly 100*49 pixels
    (for example to express a range of valid image shapes by parameterizing w, see `ParameterizedSize`).

    >>> w = SpaceInputAxis(id=AxisId("w"), size=100, unit="millimeter", scale=2)
    >>> h = SpaceInputAxis(
    ...     id=AxisId("h"),
    ...     size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("w"), offset=-1),
    ...     unit="millimeter",
    ...     scale=4,
    ... )
    >>> print(h.size.get_size(h, w))
    49

    ⇒ h = w * w.scale / h.scale + offset = 100 * 2mm / 4mm - 1 = 49
    """

    tensor_id: TensorId
    """tensor id of the reference axis"""

    axis_id: AxisId
    """axis id of the reference axis"""

    offset: int = 0

    def get_size(
        self,
        axis: Union[
            ChannelAxis,
            IndexInputAxis,
            IndexOutputAxis,
            TimeInputAxis,
            SpaceInputAxis,
            TimeOutputAxis,
            TimeOutputAxisWithHalo,
            SpaceOutputAxis,
            SpaceOutputAxisWithHalo,
        ],
        ref_axis: Union[
            ChannelAxis,
            IndexInputAxis,
            IndexOutputAxis,
            TimeInputAxis,
            SpaceInputAxis,
            TimeOutputAxis,
            TimeOutputAxisWithHalo,
            SpaceOutputAxis,
            SpaceOutputAxisWithHalo,
        ],
        n: ParameterizedSize_N = 0,
        ref_size: Optional[int] = None,
    ):
        """Compute the concrete size for a given axis and its reference axis.

        Args:
            axis: The axis this `SizeReference` is the size of.
            ref_axis: The reference axis to compute the size from.
            n: If the **ref_axis** is parameterized (of type `ParameterizedSize`)
                and no fixed **ref_size** is given,
                **n** is used to compute the size of the parameterized **ref_axis**.
            ref_size: Overwrite the reference size instead of deriving it from
                **ref_axis**
                (**ref_axis.scale** is still used; any given **n** is ignored).
        """
        assert (
            axis.size == self
        ), "Given `axis.size` is not defined by this `SizeReference`"

        assert (
            ref_axis.id == self.axis_id
        ), f"Expected `ref_axis.id` to be {self.axis_id}, but got {ref_axis.id}."

        assert axis.unit == ref_axis.unit, (
            "`SizeReference` requires `axis` and `ref_axis` to have the same `unit`,"
            f" but {axis.unit}!={ref_axis.unit}"
        )
        if ref_size is None:
            if isinstance(ref_axis.size, (int, float)):
                ref_size = ref_axis.size
            elif isinstance(ref_axis.size, ParameterizedSize):
                ref_size = ref_axis.size.get_size(n)
            elif isinstance(ref_axis.size, DataDependentSize):
                raise ValueError(
                    "Reference axis referenced in `SizeReference` may not be a `DataDependentSize`."
                )
            elif isinstance(ref_axis.size, SizeReference):
                raise ValueError(
                    "Reference axis referenced in `SizeReference` may not be sized by a"
                    + " `SizeReference` itself."
                )
            else:
                assert_never(ref_axis.size)

        return int(ref_size * ref_axis.scale / axis.scale + self.offset)

    @staticmethod
    def _get_unit(
        axis: Union[
            ChannelAxis,
            IndexInputAxis,
            IndexOutputAxis,
            TimeInputAxis,
            SpaceInputAxis,
            TimeOutputAxis,
            TimeOutputAxisWithHalo,
            SpaceOutputAxis,
            SpaceOutputAxisWithHalo,
        ],
    ):
        return axis.unit


class AxisBase(NodeWithExplicitlySetFields):
    fields_to_set_explicitly: ClassVar[FrozenSet[LiteralString]] = frozenset({"type"})

    id: AxisId
    """An axis id unique across all axes of one tensor."""

    description: Annotated[str, MaxLen(128)] = ""


class WithHalo(Node):
    halo: Annotated[int, Ge(1)]
    """The halo should be cropped from the output tensor to avoid boundary effects.
    It is to be cropped from both sides, i.e. `size_after_crop = size - 2 * halo`.
    To document a halo that is already cropped by the model use `size.offset` instead."""

    size: Annotated[
        SizeReference,
        Field(
            examples=[
                10,
                SizeReference(
                    tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
                ).model_dump(mode="json"),
            ]
        ),
    ]
    """reference to another axis with an optional offset (see `SizeReference`)"""


BATCH_AXIS_ID = AxisId("batch")


class BatchAxis(AxisBase):
    type: Literal["batch"] = "batch"
    id: Annotated[AxisId, Predicate(_is_batch)] = BATCH_AXIS_ID
    size: Optional[Literal[1]] = None
    """The batch size may be fixed to 1,
    otherwise (the default) it may be chosen arbitrarily depending on available memory"""

    @property
    def scale(self):
        return 1.0

    @property
    def concatenable(self):
        return True

    @property
    def unit(self):
        return None


class ChannelAxis(AxisBase):
    type: Literal["channel"] = "channel"
    id: NonBatchAxisId = AxisId("channel")
    channel_names: NotEmpty[List[Identifier]]

    @property
    def size(self) -> int:
        return len(self.channel_names)

    @property
    def concatenable(self):
        return False

    @property
    def scale(self) -> float:
        return 1.0

    @property
    def unit(self):
        return None


class IndexAxisBase(AxisBase):
    type: Literal["index"] = "index"
    id: NonBatchAxisId = AxisId("index")

    @property
    def scale(self) -> float:
        return 1.0

    @property
    def unit(self):
        return None

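# --- Editor's sketch (not part of the original module): a `ChannelAxis` derives
# its size from the number of channel names, while a `BatchAxis` size is
# typically left open:
#
#   >>> ChannelAxis(channel_names=[Identifier("dapi"), Identifier("gfp")]).size
#   2
#   >>> BatchAxis().size is None
#   True
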

class _WithInputAxisSize(Node):
    size: Annotated[
        Union[Annotated[int, Gt(0)], ParameterizedSize, SizeReference],
        Field(
            examples=[
                10,
                ParameterizedSize(min=32, step=16).model_dump(mode="json"),
                SizeReference(
                    tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
                ).model_dump(mode="json"),
            ]
        ),
    ]
    """The size/length of this axis can be specified as
    - fixed integer
    - parameterized series of valid sizes (`ParameterizedSize`)
    - reference to another axis with an optional offset (`SizeReference`)
    """

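# --- Editor's sketch (not part of the original module): the three input size
# variants on a space axis (`SpaceInputAxis` is defined below):
#
#   SpaceInputAxis(id=AxisId("x"), size=64)                                  # fixed
#   SpaceInputAxis(id=AxisId("x"), size=ParameterizedSize(min=32, step=16))  # 32, 48, 64, ...
#   SpaceInputAxis(
#       id=AxisId("x"),
#       size=SizeReference(tensor_id=TensorId("input"), axis_id=AxisId("y")),
#   )  # sized relative to axis 'y' of tensor 'input' (scale/offset apply)
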

class IndexInputAxis(IndexAxisBase, _WithInputAxisSize):
    concatenable: bool = False
    """If a model has a `concatenable` input axis, it can be processed blockwise,
    splitting a longer sample axis into blocks matching its input tensor description.
    Output axes are concatenable if they have a `SizeReference` to a concatenable
    input axis.
    """


class IndexOutputAxis(IndexAxisBase):
    size: Annotated[
        Union[Annotated[int, Gt(0)], SizeReference, DataDependentSize],
        Field(
            examples=[
                10,
                SizeReference(
                    tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
                ).model_dump(mode="json"),
            ]
        ),
    ]
    """The size/length of this axis can be specified as
    - fixed integer
    - reference to another axis with an optional offset (`SizeReference`)
    - data dependent size using `DataDependentSize` (size is only known after model inference)
    """


class TimeAxisBase(AxisBase):
    type: Literal["time"] = "time"
    id: NonBatchAxisId = AxisId("time")
    unit: Optional[TimeUnit] = None
    scale: Annotated[float, Gt(0)] = 1.0


class TimeInputAxis(TimeAxisBase, _WithInputAxisSize):
    concatenable: bool = False
    """If a model has a `concatenable` input axis, it can be processed blockwise,
    splitting a longer sample axis into blocks matching its input tensor description.
    Output axes are concatenable if they have a `SizeReference` to a concatenable
    input axis.
    """


class SpaceAxisBase(AxisBase):
    type: Literal["space"] = "space"
    id: Annotated[NonBatchAxisId, Field(examples=["x", "y", "z"])] = AxisId("x")
    unit: Optional[SpaceUnit] = None
    scale: Annotated[float, Gt(0)] = 1.0


class SpaceInputAxis(SpaceAxisBase, _WithInputAxisSize):
    concatenable: bool = False
    """If a model has a `concatenable` input axis, it can be processed blockwise,
    splitting a longer sample axis into blocks matching its input tensor description.
    Output axes are concatenable if they have a `SizeReference` to a concatenable
    input axis.
    """


_InputAxisUnion = Union[
    BatchAxis, ChannelAxis, IndexInputAxis, TimeInputAxis, SpaceInputAxis
]
InputAxis = Annotated[_InputAxisUnion, Discriminator("type")]


class _WithOutputAxisSize(Node):
    size: Annotated[
        Union[Annotated[int, Gt(0)], SizeReference],
        Field(
            examples=[
                10,
                SizeReference(
                    tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
                ).model_dump(mode="json"),
            ]
        ),
    ]
    """The size/length of this axis can be specified as
    - fixed integer
    - reference to another axis with an optional offset (see `SizeReference`)
    """


class TimeOutputAxis(TimeAxisBase, _WithOutputAxisSize):
    pass


class TimeOutputAxisWithHalo(TimeAxisBase, WithHalo):
    pass


def _get_halo_axis_discriminator_value(v: Any) -> Literal["with_halo", "wo_halo"]:
    if isinstance(v, dict):
        return "with_halo" if "halo" in v else "wo_halo"
    else:
        return "with_halo" if hasattr(v, "halo") else "wo_halo"


_TimeOutputAxisUnion = Annotated[
    Union[
        Annotated[TimeOutputAxis, Tag("wo_halo")],
        Annotated[TimeOutputAxisWithHalo, Tag("with_halo")],
    ],
    Discriminator(_get_halo_axis_discriminator_value),
]


class SpaceOutputAxis(SpaceAxisBase, _WithOutputAxisSize):
    pass


class SpaceOutputAxisWithHalo(SpaceAxisBase, WithHalo):
    pass


_SpaceOutputAxisUnion = Annotated[
    Union[
        Annotated[SpaceOutputAxis, Tag("wo_halo")],
        Annotated[SpaceOutputAxisWithHalo, Tag("with_halo")],
    ],
    Discriminator(_get_halo_axis_discriminator_value),
]

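# --- Editor's sketch (not part of the original module): the callable
# discriminator above routes a value to `...WithHalo` iff a `halo` field is
# present, otherwise to the plain output axis. For example:
#
#   >>> from pydantic import TypeAdapter
#   >>> adapter = TypeAdapter(_TimeOutputAxisUnion)
#   >>> type(adapter.validate_python({"type": "time", "size": 10})).__name__
#   'TimeOutputAxis'
#   >>> with_halo = adapter.validate_python(
#   ...     {"type": "time", "size": {"tensor_id": "t", "axis_id": "a"}, "halo": 2}
#   ... )
#   >>> type(with_halo).__name__
#   'TimeOutputAxisWithHalo'
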

_OutputAxisUnion = Union[
    BatchAxis, ChannelAxis, IndexOutputAxis, _TimeOutputAxisUnion, _SpaceOutputAxisUnion
]
OutputAxis = Annotated[_OutputAxisUnion, Discriminator("type")]

AnyAxis = Union[InputAxis, OutputAxis]

TVs = Union[
    NotEmpty[List[int]],
    NotEmpty[List[float]],
    NotEmpty[List[bool]],
    NotEmpty[List[str]],
]


NominalOrOrdinalDType = Literal[
    "float32",
    "float64",
    "uint8",
    "int8",
    "uint16",
    "int16",
    "uint32",
    "int32",
    "uint64",
    "int64",
    "bool",
]


class NominalOrOrdinalDataDescr(Node):
    values: TVs
    """A fixed set of nominal or an ascending sequence of ordinal values.
    In this case `data_type` is required to be an unsigned integer type, e.g. 'uint8'.
    String `values` are interpreted as labels for tensor values 0, ..., N.
    Note: as YAML 1.2 does not natively support a "set" datatype,
    nominal values should be given as a sequence (aka list/array) as well.
    """

    type: Annotated[
        NominalOrOrdinalDType,
        Field(
            examples=[
                "float32",
                "uint8",
                "uint16",
                "int64",
                "bool",
            ],
        ),
    ] = "uint8"

    @model_validator(mode="after")
    def _validate_values_match_type(
        self,
    ) -> Self:
        incompatible: List[Any] = []
        for v in self.values:
            if self.type == "bool":
                if not isinstance(v, bool):
                    incompatible.append(v)
            elif self.type in DTYPE_LIMITS:
                if (
                    isinstance(v, (int, float))
                    and (
                        v < DTYPE_LIMITS[self.type].min
                        or v > DTYPE_LIMITS[self.type].max
                    )
                    or (isinstance(v, str) and "uint" not in self.type)
                    or (isinstance(v, float) and "int" in self.type)
                ):
                    incompatible.append(v)
            else:
                incompatible.append(v)

            if len(incompatible) == 5:
                incompatible.append("...")
                break

        if incompatible:
            raise ValueError(
                f"data type '{self.type}' incompatible with values {incompatible}"
            )

        return self

    unit: Optional[Union[Literal["arbitrary unit"], SiUnit]] = None

    @property
    def range(self):
        if isinstance(self.values[0], str):
            return 0, len(self.values) - 1
        else:
            return min(self.values), max(self.values)

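# --- Editor's sketch (not part of the original module): string `values` label
# the tensor values 0, ..., N, so `range` spans the label indices:
#
#   >>> labels = NominalOrOrdinalDataDescr(
#   ...     values=["background", "cell", "border"], type="uint8"
#   ... )
#   >>> labels.range
#   (0, 2)
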

IntervalOrRatioDType = Literal[
    "float32",
    "float64",
    "uint8",
    "int8",
    "uint16",
    "int16",
    "uint32",
    "int32",
    "uint64",
    "int64",
]


class IntervalOrRatioDataDescr(Node):
    type: Annotated[  # todo: rename to dtype
        IntervalOrRatioDType,
        Field(
            examples=["float32", "float64", "uint8", "uint16"],
        ),
    ] = "float32"
    range: Tuple[Optional[float], Optional[float]] = (
        None,
        None,
    )
    """Tuple `(minimum, maximum)` specifying the allowed range of the data in this tensor.
    `None` corresponds to min/max of what can be expressed by `data_type`."""
    unit: Union[Literal["arbitrary unit"], SiUnit] = "arbitrary unit"
    scale: float = 1.0
    """Scale for data on an interval (or ratio) scale."""
    offset: Optional[float] = None
    """Offset for data on a ratio scale."""


TensorDataDescr = Union[NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr]


class ProcessingDescrBase(NodeWithExplicitlySetFields, ABC):
    """processing base class"""

    # id: Literal[PreprocessingId, PostprocessingId]  # make abstract field
    fields_to_set_explicitly: ClassVar[FrozenSet[LiteralString]] = frozenset({"id"})


class BinarizeKwargs(ProcessingKwargs):
    """keyword arguments for `BinarizeDescr`"""

    threshold: float
    """The fixed threshold"""


class BinarizeAlongAxisKwargs(ProcessingKwargs):
    """keyword arguments for `BinarizeDescr`"""

    threshold: NotEmpty[List[float]]
    """The fixed threshold values along `axis`"""

    axis: Annotated[NonBatchAxisId, Field(examples=["channel"])]
    """The `threshold` axis"""


class BinarizeDescr(ProcessingDescrBase):
    """Binarize the tensor with a fixed threshold.

    Values above `BinarizeKwargs.threshold`/`BinarizeAlongAxisKwargs.threshold`
    will be set to one, values below the threshold to zero.

    Examples:
    - in YAML
        ```yaml
        postprocessing:
          - id: binarize
            kwargs:
              axis: 'channel'
              threshold: [0.25, 0.5, 0.75]
        ```
    - in Python:
        >>> postprocessing = [BinarizeDescr(
        ...     kwargs=BinarizeAlongAxisKwargs(
        ...         axis=AxisId('channel'),
        ...         threshold=[0.25, 0.5, 0.75],
        ...     )
        ... )]
    """

    id: Literal["binarize"] = "binarize"
    kwargs: Union[BinarizeKwargs, BinarizeAlongAxisKwargs]


class ClipDescr(ProcessingDescrBase):
    """Set tensor values below min to min and above max to max.

    See `ScaleRangeDescr` for examples.
    """

    id: Literal["clip"] = "clip"
    kwargs: ClipKwargs


class EnsureDtypeKwargs(ProcessingKwargs):
    """keyword arguments for `EnsureDtypeDescr`"""

    dtype: Literal[
        "float32",
        "float64",
        "uint8",
        "int8",
        "uint16",
        "int16",
        "uint32",
        "int32",
        "uint64",
        "int64",
        "bool",
    ]


class EnsureDtypeDescr(ProcessingDescrBase):
    """Cast the tensor data type to `EnsureDtypeKwargs.dtype` (if not matching).

    This can for example be used to ensure the inner neural network model gets a
    different input tensor data type than the fully described bioimage.io model does.

    Examples:
        The described bioimage.io model (incl. preprocessing) accepts any
        float32-compatible tensor, normalizes it with percentiles and clipping and then
        casts it to uint8, which is what the neural network in this example expects.
        - in YAML
            ```yaml
            inputs:
            - data:
                type: float32  # described bioimage.io model is compatible with any float32 input tensor
              preprocessing:
              - id: scale_range
                kwargs:
                  axes: ['y', 'x']
                  max_percentile: 99.8
                  min_percentile: 5.0
              - id: clip
                kwargs:
                  min: 0.0
                  max: 1.0
              - id: ensure_dtype
                kwargs:
                  dtype: uint8
            ```
        - in Python:
            >>> preprocessing = [
            ...     ScaleRangeDescr(
            ...         kwargs=ScaleRangeKwargs(
            ...             axes=(AxisId('y'), AxisId('x')),
            ...             max_percentile=99.8,
            ...             min_percentile=5.0,
            ...         )
            ...     ),
            ...     ClipDescr(kwargs=ClipKwargs(min=0.0, max=1.0)),
            ...     EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype="uint8")),
            ... ]
    """

    id: Literal["ensure_dtype"] = "ensure_dtype"
    kwargs: EnsureDtypeKwargs


class ScaleLinearKwargs(ProcessingKwargs):
    """keyword arguments for `ScaleLinearDescr`"""

    gain: float = 1.0
    """multiplicative factor"""

    offset: float = 0.0
    """additive term"""

    @model_validator(mode="after")
    def _validate(self) -> Self:
        if self.gain == 1.0 and self.offset == 0.0:
            raise ValueError(
                "Redundant linear scaling not allowed. Set `gain` != 1.0 and/or `offset`"
                + " != 0.0."
            )

        return self


class ScaleLinearAlongAxisKwargs(ProcessingKwargs):
    """keyword arguments for `ScaleLinearDescr`"""

    axis: Annotated[NonBatchAxisId, Field(examples=["channel"])]
    """The axis of the gain and offset values."""

    gain: Union[float, NotEmpty[List[float]]] = 1.0
    """multiplicative factor"""

    offset: Union[float, NotEmpty[List[float]]] = 0.0
    """additive term"""

    @model_validator(mode="after")
    def _validate(self) -> Self:
        if isinstance(self.gain, list):
            if isinstance(self.offset, list):
                if len(self.gain) != len(self.offset):
                    raise ValueError(
                        f"Size of `gain` ({len(self.gain)}) and `offset` ({len(self.offset)}) must match."
                    )
            else:
                self.offset = [float(self.offset)] * len(self.gain)
        elif isinstance(self.offset, list):
            self.gain = [float(self.gain)] * len(self.offset)
        else:
            raise ValueError(
                "Do not specify an `axis` for scalar gain and offset values."
            )

        if all(g == 1.0 for g in self.gain) and all(off == 0.0 for off in self.offset):
            raise ValueError(
                "Redundant linear scaling not allowed. Set `gain` != 1.0 and/or `offset`"
                + " != 0.0."
            )

        return self


class ScaleLinearDescr(ProcessingDescrBase):
    """Fixed linear scaling.

    Examples:
    1. Scale with scalar gain and offset
        - in YAML
        ```yaml
        preprocessing:
          - id: scale_linear
            kwargs:
              gain: 2.0
              offset: 3.0
        ```
        - in Python:
        >>> preprocessing = [
        ...     ScaleLinearDescr(kwargs=ScaleLinearKwargs(gain=2.0, offset=3.0))
        ... ]

    2. Independent scaling along an axis
        - in YAML
        ```yaml
        preprocessing:
          - id: scale_linear
            kwargs:
              axis: 'channel'
              gain: [1.0, 2.0, 3.0]
        ```
        - in Python:
        >>> preprocessing = [
        ...     ScaleLinearDescr(
        ...         kwargs=ScaleLinearAlongAxisKwargs(
        ...             axis=AxisId("channel"),
        ...             gain=[1.0, 2.0, 3.0],
        ...         )
        ...     )
        ... ]
    """

    id: Literal["scale_linear"] = "scale_linear"
    kwargs: Union[ScaleLinearKwargs, ScaleLinearAlongAxisKwargs]


class SigmoidDescr(ProcessingDescrBase):
    """The logistic sigmoid function, a.k.a. expit function.

    Examples:
    - in YAML
        ```yaml
        postprocessing:
          - id: sigmoid
        ```
    - in Python:
        >>> postprocessing = [SigmoidDescr()]
    """

    id: Literal["sigmoid"] = "sigmoid"

    @property
    def kwargs(self) -> ProcessingKwargs:
        """empty kwargs"""
        return ProcessingKwargs()

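# --- Editor's sketch (not part of the original module): the sigmoid applied by
# this step is `out = 1 / (1 + exp(-tensor))`, e.g. with numpy:
#
#   >>> x = np.array([-1.0, 0.0, 1.0], dtype=np.float32)
#   >>> out = 1.0 / (1.0 + np.exp(-x))  # values in (0, 1)
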

class FixedZeroMeanUnitVarianceKwargs(ProcessingKwargs):
    """keyword arguments for `FixedZeroMeanUnitVarianceDescr`"""

    mean: float
    """The mean value to normalize with."""

    std: Annotated[float, Ge(1e-6)]
    """The standard deviation value to normalize with."""


class FixedZeroMeanUnitVarianceAlongAxisKwargs(ProcessingKwargs):
    """keyword arguments for `FixedZeroMeanUnitVarianceDescr`"""

    mean: NotEmpty[List[float]]
    """The mean value(s) to normalize with."""

    std: NotEmpty[List[Annotated[float, Ge(1e-6)]]]
    """The standard deviation value(s) to normalize with.
    Size must match `mean` values."""

    axis: Annotated[NonBatchAxisId, Field(examples=["channel", "index"])]
    """The axis of the mean/std values to normalize each entry along that dimension
    separately."""

    @model_validator(mode="after")
    def _mean_and_std_match(self) -> Self:
        if len(self.mean) != len(self.std):
            raise ValueError(
                f"Size of `mean` ({len(self.mean)}) and `std` ({len(self.std)})"
                + " must match."
            )

        return self


class FixedZeroMeanUnitVarianceDescr(ProcessingDescrBase):
    """Subtract a given mean and divide by a given standard deviation.

    Normalize with fixed, precomputed values for
    `FixedZeroMeanUnitVarianceKwargs.mean` and `FixedZeroMeanUnitVarianceKwargs.std`.
    Use `FixedZeroMeanUnitVarianceAlongAxisKwargs` for independent scaling along given
    axes.

    Examples:
    1. scalar values for the whole tensor
        - in YAML
        ```yaml
        preprocessing:
          - id: fixed_zero_mean_unit_variance
            kwargs:
              mean: 103.5
              std: 13.7
        ```
        - in Python
        >>> preprocessing = [FixedZeroMeanUnitVarianceDescr(
        ...     kwargs=FixedZeroMeanUnitVarianceKwargs(mean=103.5, std=13.7)
        ... )]

    2. independently along an axis
        - in YAML
        ```yaml
        preprocessing:
          - id: fixed_zero_mean_unit_variance
            kwargs:
              axis: channel
              mean: [101.5, 102.5, 103.5]
              std: [11.7, 12.7, 13.7]
        ```
        - in Python
        >>> preprocessing = [FixedZeroMeanUnitVarianceDescr(
        ...     kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs(
        ...         axis=AxisId("channel"),
        ...         mean=[101.5, 102.5, 103.5],
        ...         std=[11.7, 12.7, 13.7],
        ...     )
        ... )]
    """

    id: Literal["fixed_zero_mean_unit_variance"] = "fixed_zero_mean_unit_variance"
    kwargs: Union[
        FixedZeroMeanUnitVarianceKwargs, FixedZeroMeanUnitVarianceAlongAxisKwargs
    ]


class ZeroMeanUnitVarianceKwargs(ProcessingKwargs):
    """keyword arguments for `ZeroMeanUnitVarianceDescr`"""

    axes: Annotated[
        Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
    ] = None
    """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std.
    For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
    resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
    To normalize each sample independently leave out the 'batch' axis.
    Default: Scale all axes jointly."""

    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
    """epsilon for numeric stability: `out = (tensor - mean) / (std + eps)`."""

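# --- Editor's sketch (not part of the original module): with numpy, normalizing
# a ('batch', 'channel', 'y', 'x') tensor with `axes=('batch', 'x', 'y')`
# reduces over those axes and keeps per-channel statistics:
#
#   >>> t = np.random.rand(2, 3, 64, 64)          # (batch, channel, y, x)
#   >>> mean = t.mean(axis=(0, 2, 3), keepdims=True)
#   >>> std = t.std(axis=(0, 2, 3), keepdims=True)
#   >>> out = (t - mean) / (std + 1e-6)           # same shape, per-channel stats
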

class ZeroMeanUnitVarianceDescr(ProcessingDescrBase):
    """Subtract the mean and divide by the standard deviation.

    Examples:
    Normalize with the tensor's own mean and standard deviation
    - in YAML
        ```yaml
        preprocessing:
          - id: zero_mean_unit_variance
        ```
    - in Python
        >>> preprocessing = [ZeroMeanUnitVarianceDescr()]
    """

    id: Literal["zero_mean_unit_variance"] = "zero_mean_unit_variance"
    kwargs: ZeroMeanUnitVarianceKwargs = Field(
        default_factory=ZeroMeanUnitVarianceKwargs
    )


class ScaleRangeKwargs(ProcessingKwargs):
    """keyword arguments for `ScaleRangeDescr`

    For `min_percentile`=0.0 (the default) and `max_percentile`=100 (the default)
    this processing step normalizes data to the [0, 1] interval.
    For other percentiles the normalized values will partially be outside the [0, 1]
    interval. Use `ScaleRangeDescr` followed by `ClipDescr` if you want to limit the
    normalized values to a range.
    """

    axes: Annotated[
        Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
    ] = None
    """The subset of axes to normalize jointly, i.e. axes to reduce to compute the min/max percentile value.
    For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
    resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
    To normalize samples independently, leave out the 'batch' axis.
    Default: Scale all axes jointly."""

    min_percentile: Annotated[float, Interval(ge=0, lt=100)] = 0.0
    """The lower percentile used to determine the value to align with zero."""

    max_percentile: Annotated[float, Interval(gt=1, le=100)] = 100.0
    """The upper percentile used to determine the value to align with one.
    Has to be bigger than `min_percentile`.
    The range is 1 to 100 instead of 0 to 100 to avoid mistakenly
    accepting percentiles specified in the range 0.0 to 1.0."""

    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
    """Epsilon for numeric stability.
    `out = (tensor - v_lower) / (v_upper - v_lower + eps)`;
    with `v_lower,v_upper` values at the respective percentiles."""

    reference_tensor: Optional[TensorId] = None
    """Tensor ID to compute the percentiles from. Default: The tensor itself.
    For any tensor in `inputs` only input tensor references are allowed."""

    @field_validator("max_percentile", mode="after")
    @classmethod
    def min_smaller_max(cls, value: float, info: ValidationInfo) -> float:
        if (min_p := info.data["min_percentile"]) >= value:
            raise ValueError(f"min_percentile {min_p} >= max_percentile {value}")

        return value

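# --- Editor's sketch (not part of the original module): the scale_range formula
# from the docstring above, written out with numpy:
#
#   >>> t = np.random.rand(1, 3, 64, 64)
#   >>> v_lower = np.percentile(t, 5.0, axis=(0, 2, 3), keepdims=True)
#   >>> v_upper = np.percentile(t, 99.8, axis=(0, 2, 3), keepdims=True)
#   >>> out = (t - v_lower) / (v_upper - v_lower + 1e-6)
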

class ScaleRangeDescr(ProcessingDescrBase):
    """Scale with percentiles.

    Examples:
    1. Scale linearly to map the 5th percentile to 0 and the 99.8th percentile to 1.0
        - in YAML
        ```yaml
        preprocessing:
          - id: scale_range
            kwargs:
              axes: ['y', 'x']
              max_percentile: 99.8
              min_percentile: 5.0
        ```
        - in Python
        >>> preprocessing = [ScaleRangeDescr(
        ...     kwargs=ScaleRangeKwargs(
        ...         axes=(AxisId('y'), AxisId('x')),
        ...         max_percentile=99.8,
        ...         min_percentile=5.0,
        ...     )
        ... )]

    2. Combine the above scaling with additional clipping to clip values outside the range given by the percentiles.
        - in YAML
        ```yaml
        preprocessing:
          - id: scale_range
            kwargs:
              axes: ['y', 'x']
              max_percentile: 99.8
              min_percentile: 5.0
          - id: clip
            kwargs:
              min: 0.0
              max: 1.0
        ```
        - in Python
        >>> preprocessing = [
        ...     ScaleRangeDescr(
        ...         kwargs=ScaleRangeKwargs(
        ...             axes=(AxisId('y'), AxisId('x')),
        ...             max_percentile=99.8,
        ...             min_percentile=5.0,
        ...         )
        ...     ),
        ...     ClipDescr(
        ...         kwargs=ClipKwargs(
        ...             min=0.0,
        ...             max=1.0,
        ...         )
        ...     ),
        ... ]
    """

    id: Literal["scale_range"] = "scale_range"
    kwargs: ScaleRangeKwargs


class ScaleMeanVarianceKwargs(ProcessingKwargs):
    """keyword arguments for `ScaleMeanVarianceDescr`"""

    reference_tensor: TensorId
    """Name of tensor to match."""

    axes: Annotated[
        Optional[Sequence[AxisId]], Field(examples=[("batch", "x", "y")])
    ] = None
    """The subset of axes to normalize jointly, i.e. axes to reduce to compute mean/std.
    For example to normalize 'batch', 'x' and 'y' jointly in a tensor ('batch', 'channel', 'y', 'x')
    resulting in a tensor of equal shape normalized per channel, specify `axes=('batch', 'x', 'y')`.
    To normalize samples independently, leave out the 'batch' axis.
    Default: Scale all axes jointly."""

    eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
    """Epsilon for numeric stability:
    `out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.`"""


class ScaleMeanVarianceDescr(ProcessingDescrBase):
    """Scale a tensor's data distribution to match another tensor's mean/std.
    `out = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean.`
    """

    id: Literal["scale_mean_variance"] = "scale_mean_variance"
    kwargs: ScaleMeanVarianceKwargs

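# --- Editor's sketch (not part of the original module): the scale_mean_variance
# formula with numpy, matching `tensor` statistics to those of `ref`:
#
#   >>> tensor, ref = np.random.rand(3, 64, 64), np.random.rand(3, 64, 64)
#   >>> mean, std = tensor.mean(), tensor.std()
#   >>> ref_mean, ref_std = ref.mean(), ref.std()
#   >>> out = (tensor - mean) / (std + 1e-6) * (ref_std + 1e-6) + ref_mean
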

PreprocessingDescr = Annotated[
    Union[
        BinarizeDescr,
        ClipDescr,
        EnsureDtypeDescr,
        ScaleLinearDescr,
        SigmoidDescr,
        FixedZeroMeanUnitVarianceDescr,
        ZeroMeanUnitVarianceDescr,
        ScaleRangeDescr,
    ],
    Discriminator("id"),
]
PostprocessingDescr = Annotated[
    Union[
        BinarizeDescr,
        ClipDescr,
        EnsureDtypeDescr,
        ScaleLinearDescr,
        SigmoidDescr,
        FixedZeroMeanUnitVarianceDescr,
        ZeroMeanUnitVarianceDescr,
        ScaleRangeDescr,
        ScaleMeanVarianceDescr,
    ],
    Discriminator("id"),
]

IO_AxisT = TypeVar("IO_AxisT", InputAxis, OutputAxis)


class TensorDescrBase(Node, Generic[IO_AxisT]):
    id: TensorId
    """Tensor id. No duplicates are allowed."""

    description: Annotated[str, MaxLen(128)] = ""
    """free text description"""

    axes: NotEmpty[Sequence[IO_AxisT]]
    """tensor axes"""

    @property
    def shape(self):
        return tuple(a.size for a in self.axes)

    @field_validator("axes", mode="after", check_fields=False)
    @classmethod
    def _validate_axes(cls, axes: Sequence[AnyAxis]) -> Sequence[AnyAxis]:
        batch_axes = [a for a in axes if a.type == "batch"]
        if len(batch_axes) > 1:
            raise ValueError(
                f"Only one batch axis (per tensor) allowed, but got {batch_axes}"
            )

        seen_ids: Set[AxisId] = set()
        duplicate_axes_ids: Set[AxisId] = set()
        for a in axes:
            (duplicate_axes_ids if a.id in seen_ids else seen_ids).add(a.id)

        if duplicate_axes_ids:
            raise ValueError(f"Duplicate axis ids: {duplicate_axes_ids}")

        return axes

    test_tensor: FileDescr
    """An example tensor to use for testing.
    Using the model with the test input tensors is expected to yield the test output tensors.
    Each test tensor has to be an ndarray in the
    [numpy.lib file format](https://numpy.org/doc/stable/reference/generated/numpy.lib.format.html#module-numpy.lib.format).
    The file extension must be '.npy'."""

    sample_tensor: Optional[FileDescr] = None
    """A sample tensor to illustrate a possible input/output for the model.
    The sample image primarily serves to inform a human user about an example use case
    and is typically stored as .hdf5, .png or .tiff.
    It has to be readable by the [imageio library](https://imageio.readthedocs.io/en/stable/formats/index.html#supported-formats)
    (numpy's `.npy` format is not supported).
    The image dimensionality has to match the number of axes specified in this tensor description.
    """

    @model_validator(mode="after")
    def _validate_sample_tensor(self) -> Self:
        if (
            self.sample_tensor is None
            or not validation_context_var.get().perform_io_checks
        ):
            return self

        local = download(self.sample_tensor.source, sha256=self.sample_tensor.sha256)
        tensor: NDArray[Any] = imread(
            local.path.read_bytes(),
            extension=PurePosixPath(local.original_file_name).suffix,
        )
        n_dims = len(tensor.squeeze().shape)
        n_dims_min = n_dims_max = len(self.axes)

        for a in self.axes:
            if isinstance(a, BatchAxis):
                n_dims_min -= 1
            elif isinstance(a.size, int):
                if a.size == 1:
                    n_dims_min -= 1
            elif isinstance(a.size, (ParameterizedSize, DataDependentSize)):
                if a.size.min == 1:
                    n_dims_min -= 1
            elif isinstance(a.size, SizeReference):
                if a.size.offset < 2:
                    # a size reference may result in a singleton axis
                    n_dims_min -= 1
            else:
                assert_never(a.size)

        n_dims_min = max(0, n_dims_min)
        if n_dims < n_dims_min or n_dims > n_dims_max:
            raise ValueError(
                f"Expected sample tensor to have {n_dims_min} to"
                + f" {n_dims_max} dimensions, but found {n_dims} (shape: {tensor.shape})."
            )

        return self

    data: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]] = (
        IntervalOrRatioDataDescr()
    )
    """Description of the tensor's data values, optionally per channel.
    If specified per channel, the data `type` needs to match across channels."""

    @property
    def dtype(
        self,
    ) -> Literal[
        "float32",
        "float64",
        "uint8",
        "int8",
        "uint16",
        "int16",
        "uint32",
        "int32",
        "uint64",
        "int64",
        "bool",
    ]:
        """dtype as specified under `data.type` or `data[i].type`"""
        if isinstance(self.data, collections.abc.Sequence):
            return self.data[0].type
        else:
            return self.data.type

    @field_validator("data", mode="after")
    @classmethod
    def _check_data_type_across_channels(
        cls, value: Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]]
    ) -> Union[TensorDataDescr, NotEmpty[Sequence[TensorDataDescr]]]:
        if not isinstance(value, list):
            return value

        dtypes = {t.type for t in value}
        if len(dtypes) > 1:
            raise ValueError(
                "Tensor data descriptions per channel need to agree in their data"
                + f" `type`, but found {dtypes}."
            )

        return value

    @model_validator(mode="after")
    def _check_data_matches_channelaxis(self) -> Self:
        if not isinstance(self.data, (list, tuple)):
            return self

        for a in self.axes:
            if isinstance(a, ChannelAxis):
                size = a.size
                assert isinstance(size, int)
                break
        else:
            return self

        if len(self.data) != size:
            raise ValueError(
                f"Got tensor data descriptions for {len(self.data)} channels, but"
                + f" '{a.id}' axis has size {size}."
            )

        return self

    def get_axis_sizes_for_array(self, array: NDArray[Any]) -> Dict[AxisId, int]:
        if len(array.shape) != len(self.axes):
            raise ValueError(
                f"Dimension mismatch: array shape {array.shape} (#{len(array.shape)})"
                + f" incompatible with {len(self.axes)} axes."
            )
        return {a.id: array.shape[i] for i, a in enumerate(self.axes)}


class InputTensorDescr(TensorDescrBase[InputAxis]):
    id: TensorId = TensorId("input")
    """Input tensor id.
    No duplicates are allowed across all inputs and outputs."""

    optional: bool = False
    """indicates that this tensor may be `None`"""

    preprocessing: List[PreprocessingDescr] = Field(default_factory=list)
    """Description of how this input should be preprocessed.

    notes:
    - If preprocessing does not start with an 'ensure_dtype' entry, it is added
      to ensure an input tensor's data type matches the input tensor's data description.
    - If preprocessing does not end with an 'ensure_dtype' or 'binarize' entry, an
      'ensure_dtype' step is added to ensure preprocessing steps are not unintentionally
      changing the data type.
    """

    @model_validator(mode="after")
    def _validate_preprocessing_kwargs(self) -> Self:
        axes_ids = [a.id for a in self.axes]
        for p in self.preprocessing:
            kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes")
            if kwargs_axes is None:
                continue

            if not isinstance(kwargs_axes, collections.abc.Sequence):
                raise ValueError(
                    f"Expected `preprocessing.i.kwargs.axes` to be a sequence, but got {type(kwargs_axes)}"
                )

            if any(a not in axes_ids for a in kwargs_axes):
                raise ValueError(
                    "`preprocessing.i.kwargs.axes` needs to be a subset of the axes ids"
                )

        if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)):
            dtype = self.data.type
        else:
            dtype = self.data[0].type

        # ensure `preprocessing` begins with `EnsureDtypeDescr`
        if not self.preprocessing or not isinstance(
            self.preprocessing[0], EnsureDtypeDescr
        ):
            self.preprocessing.insert(
                0, EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype))
            )

        # ensure `preprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr`
        if not isinstance(self.preprocessing[-1], (EnsureDtypeDescr, BinarizeDescr)):
            self.preprocessing.append(
                EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype))
            )

        return self


def convert_axes(
    axes: str,
    *,
    shape: Union[
        Sequence[int], _ParameterizedInputShape_v0_4, _ImplicitOutputShape_v0_4
    ],
    tensor_type: Literal["input", "output"],
    halo: Optional[Sequence[int]],
    size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]],
):
    ret: List[AnyAxis] = []
    for i, a in enumerate(axes):
        axis_type = _AXIS_TYPE_MAP.get(a, a)
        if axis_type == "batch":
            ret.append(BatchAxis())
            continue

        scale = 1.0
        if isinstance(shape, _ParameterizedInputShape_v0_4):
            if shape.step[i] == 0:
                size = shape.min[i]
            else:
                size = ParameterizedSize(min=shape.min[i], step=shape.step[i])
        elif isinstance(shape, _ImplicitOutputShape_v0_4):
            ref_t = str(shape.reference_tensor)
            if ref_t.count(".") == 1:
                t_id, orig_a_id = ref_t.split(".")
            else:
                t_id = ref_t
                orig_a_id = a

            a_id = _AXIS_ID_MAP.get(orig_a_id, a)
            if not (orig_scale := shape.scale[i]):
                # old way to insert a new axis dimension
                size = int(2 * shape.offset[i])
            else:
                scale = 1 / orig_scale
                if axis_type in ("channel", "index"):
                    # these axes no longer have a scale
                    offset_from_scale = orig_scale * size_refs.get(
                        _TensorName_v0_4(t_id), {}
                    ).get(orig_a_id, 0)
                else:
                    offset_from_scale = 0
                size = SizeReference(
                    tensor_id=TensorId(t_id),
                    axis_id=AxisId(a_id),
                    offset=int(offset_from_scale + 2 * shape.offset[i]),
                )
        else:
            size = shape[i]

        if axis_type == "time":
            if tensor_type == "input":
                ret.append(TimeInputAxis(size=size, scale=scale))
            else:
                assert not isinstance(size, ParameterizedSize)
                if halo is None:
                    ret.append(TimeOutputAxis(size=size, scale=scale))
                else:
                    assert not isinstance(size, int)
                    ret.append(
                        TimeOutputAxisWithHalo(size=size, scale=scale, halo=halo[i])
                    )

        elif axis_type == "index":
            if tensor_type == "input":
                ret.append(IndexInputAxis(size=size))
            else:
                if isinstance(size, ParameterizedSize):
                    size = DataDependentSize(min=size.min)

                ret.append(IndexOutputAxis(size=size))
        elif axis_type == "channel":
            assert not isinstance(size, ParameterizedSize)
            if isinstance(size, SizeReference):
                warnings.warn(
                    "Conversion of channel size from an implicit output shape may be"
                    + " wrong"
                )
                ret.append(
                    ChannelAxis(
                        channel_names=[
                            Identifier(f"channel{i}") for i in range(size.offset)
                        ]
                    )
                )
            else:
                ret.append(
                    ChannelAxis(
                        channel_names=[Identifier(f"channel{i}") for i in range(size)]
                    )
                )
        elif axis_type == "space":
            if tensor_type == "input":
                ret.append(SpaceInputAxis(id=AxisId(a), size=size, scale=scale))
            else:
                assert not isinstance(size, ParameterizedSize)
                if halo is None or halo[i] == 0:
                    ret.append(SpaceOutputAxis(id=AxisId(a), size=size, scale=scale))
                elif isinstance(size, int):
                    raise NotImplementedError(
                        f"output axis with halo and fixed size (here {size}) not allowed"
                    )
                else:
                    ret.append(
                        SpaceOutputAxisWithHalo(
                            id=AxisId(a), size=size, scale=scale, halo=halo[i]
                        )
                    )

    return ret

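# --- Editor's sketch (not part of the original module): converting a v0.4 axes
# letter string with an explicit shape yields v0.5 axis objects, e.g.
#
#   >>> convert_axes(
#   ...     "bcyx", shape=[1, 2, 256, 256], tensor_type="input", halo=None, size_refs={}
#   ... )
#
# should produce a BatchAxis, a ChannelAxis with channel names 'channel0' and
# 'channel1', and SpaceInputAxis objects 'y' and 'x' of size 256 each.
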

_AXIS_TYPE_MAP = {
    "b": "batch",
    "t": "time",
    "i": "index",
    "c": "channel",
    "x": "space",
    "y": "space",
    "z": "space",
}

_AXIS_ID_MAP = {
    "b": "batch",
    "t": "time",
    "i": "index",
    "c": "channel",
}


def _axes_letters_to_ids(
    axes: Optional[str],
) -> Optional[List[AxisId]]:
    if axes is None:
        return None
    return [AxisId(_AXIS_ID_MAP.get(a, a)) for a in map(str, axes)]


def _get_complement_v04_axis(
    tensor_axes: Sequence[str], axes: Optional[Sequence[str]]
) -> Optional[AxisId]:
    if axes is None:
        return None

    non_complement_axes = set(axes) | {"b"}
    complement_axes = [a for a in tensor_axes if a not in non_complement_axes]
    if len(complement_axes) > 1:
        raise ValueError(
            f"Expected none or a single complement axis, but axes '{axes}' "
            + f"for tensor dims '{tensor_axes}' leave '{complement_axes}'."
        )

    return None if not complement_axes else AxisId(complement_axes[0])


def _convert_proc(
    p: Union[_PreprocessingDescr_v0_4, _PostprocessingDescr_v0_4],
    tensor_axes: Sequence[str],
) -> Union[PreprocessingDescr, PostprocessingDescr]:
    if isinstance(p, _BinarizeDescr_v0_4):
        return BinarizeDescr(kwargs=BinarizeKwargs(threshold=p.kwargs.threshold))
    elif isinstance(p, _ClipDescr_v0_4):
        return ClipDescr(kwargs=ClipKwargs(min=p.kwargs.min, max=p.kwargs.max))
    elif isinstance(p, _SigmoidDescr_v0_4):
        return SigmoidDescr()
    elif isinstance(p, _ScaleLinearDescr_v0_4):
        axes = _axes_letters_to_ids(p.kwargs.axes)
        if p.kwargs.axes is None:
            axis = None
        else:
            axis = _get_complement_v04_axis(tensor_axes, p.kwargs.axes)

        if axis is None:
            assert not isinstance(p.kwargs.gain, list)
            assert not isinstance(p.kwargs.offset, list)
            kwargs = ScaleLinearKwargs(gain=p.kwargs.gain, offset=p.kwargs.offset)
        else:
            kwargs = ScaleLinearAlongAxisKwargs(
                axis=axis, gain=p.kwargs.gain, offset=p.kwargs.offset
            )
        return ScaleLinearDescr(kwargs=kwargs)
    elif isinstance(p, _ScaleMeanVarianceDescr_v0_4):
        return ScaleMeanVarianceDescr(
            kwargs=ScaleMeanVarianceKwargs(
                axes=_axes_letters_to_ids(p.kwargs.axes),
                reference_tensor=TensorId(str(p.kwargs.reference_tensor)),
                eps=p.kwargs.eps,
            )
        )
    elif isinstance(p, _ZeroMeanUnitVarianceDescr_v0_4):
        if p.kwargs.mode == "fixed":
            mean = p.kwargs.mean
            std = p.kwargs.std
            assert mean is not None
            assert std is not None

            axis = _get_complement_v04_axis(tensor_axes, p.kwargs.axes)

            if axis is None:
                return FixedZeroMeanUnitVarianceDescr(
                    kwargs=FixedZeroMeanUnitVarianceKwargs(
                        mean=mean, std=std  # pyright: ignore[reportArgumentType]
                    )
                )
            else:
                if not isinstance(mean, list):
                    mean = [float(mean)]
                if not isinstance(std, list):
                    std = [float(std)]

                return FixedZeroMeanUnitVarianceDescr(
                    kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs(
                        axis=axis, mean=mean, std=std
                    )
                )

        else:
            axes = _axes_letters_to_ids(p.kwargs.axes) or []
            if p.kwargs.mode == "per_dataset":
                axes = [AxisId("batch")] + axes
            if not axes:
                axes = None
            return ZeroMeanUnitVarianceDescr(
                kwargs=ZeroMeanUnitVarianceKwargs(axes=axes, eps=p.kwargs.eps)
            )

    elif isinstance(p, _ScaleRangeDescr_v0_4):
        return ScaleRangeDescr(
            kwargs=ScaleRangeKwargs(
                axes=_axes_letters_to_ids(p.kwargs.axes),
                min_percentile=p.kwargs.min_percentile,
                max_percentile=p.kwargs.max_percentile,
                eps=p.kwargs.eps,
            )
        )
    else:
        assert_never(p)


class _InputTensorConv(
    Converter[
        _InputTensorDescr_v0_4,
        InputTensorDescr,
        ImportantFileSource,
        Optional[ImportantFileSource],
        Mapping[_TensorName_v0_4, Mapping[str, int]],
    ]
):
    def _convert(
        self,
        src: _InputTensorDescr_v0_4,
        tgt: "type[InputTensorDescr] | type[dict[str, Any]]",
        test_tensor: ImportantFileSource,
        sample_tensor: Optional[ImportantFileSource],
        size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]],
    ) -> "InputTensorDescr | dict[str, Any]":
        axes: List[InputAxis] = convert_axes(  # pyright: ignore[reportAssignmentType]
            src.axes,
            shape=src.shape,
            tensor_type="input",
            halo=None,
            size_refs=size_refs,
        )
        prep: List[PreprocessingDescr] = []
        for p in src.preprocessing:
            cp = _convert_proc(p, src.axes)
            assert not isinstance(cp, ScaleMeanVarianceDescr)
            prep.append(cp)

        return tgt(
            axes=axes,
            id=TensorId(str(src.name)),
            test_tensor=FileDescr(source=test_tensor),
            sample_tensor=(
                None if sample_tensor is None else FileDescr(source=sample_tensor)
            ),
            data=dict(type=src.data_type),  # pyright: ignore[reportArgumentType]
            preprocessing=prep,
        )


_input_tensor_conv = _InputTensorConv(_InputTensorDescr_v0_4, InputTensorDescr)


1830class OutputTensorDescr(TensorDescrBase[OutputAxis]): 

1831 id: TensorId = TensorId("output") 

1832 """Output tensor id. 

1833 No duplicates are allowed across all inputs and outputs.""" 

1834 

1835 postprocessing: List[PostprocessingDescr] = Field(default_factory=list) 

1836 """Description of how this output should be postprocessed. 

1837 

1838 note: `postprocessing` always ends with an 'ensure_dtype' operation. 

1839 If not given, it is added automatically to cast to this tensor's `data.type`. 

1840 """ 

1841 

1842 @model_validator(mode="after") 

1843 def _validate_postprocessing_kwargs(self) -> Self: 

1844 axes_ids = [a.id for a in self.axes] 

1845 for p in self.postprocessing: 

1846 kwargs_axes: Optional[Sequence[Any]] = p.kwargs.get("axes") 

1847 if kwargs_axes is None: 

1848 continue 

1849 

1850 if not isinstance(kwargs_axes, collections.abc.Sequence): 

1851 raise ValueError( 

1852 f"expected `axes` sequence, but got {type(kwargs_axes)}" 

1853 ) 

1854 

1855 if any(a not in axes_ids for a in kwargs_axes): 

1856 raise ValueError("`kwargs.axes` needs to be a subset of the axes ids") 

1857 

1858 if isinstance(self.data, (NominalOrOrdinalDataDescr, IntervalOrRatioDataDescr)): 

1859 dtype = self.data.type 

1860 else: 

1861 dtype = self.data[0].type 

1862 

1863 # ensure `postprocessing` ends with `EnsureDtypeDescr` or `BinarizeDescr` 

1864 if not self.postprocessing or not isinstance( 

1865 self.postprocessing[-1], (EnsureDtypeDescr, BinarizeDescr) 

1866 ): 

1867 self.postprocessing.append( 

1868 EnsureDtypeDescr(kwargs=EnsureDtypeKwargs(dtype=dtype)) 

1869 ) 

1870 return self 

1871 

1872 
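# Illustrative sketch, not part of the original module: the validator above
# appends an `ensure_dtype` step if `postprocessing` does not already end in
# one. All ids, the test tensor URL and the data description are made up;
# with io checks enabled the test tensor source would be downloaded.
#
#   out = OutputTensorDescr(
#       id=TensorId("mask"),
#       axes=[BatchAxis(), SpaceOutputAxis(id=AxisId("x"), size=64)],
#       test_tensor=FileDescr(source="https://example.com/test_mask.npy"),
#       data=IntervalOrRatioDataDescr(type="float32"),
#   )
#   assert isinstance(out.postprocessing[-1], EnsureDtypeDescr)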

1873class _OutputTensorConv( 

1874 Converter[ 

1875 _OutputTensorDescr_v0_4, 

1876 OutputTensorDescr, 

1877 ImportantFileSource, 

1878 Optional[ImportantFileSource], 

1879 Mapping[_TensorName_v0_4, Mapping[str, int]], 

1880 ] 

1881): 

1882 def _convert( 

1883 self, 

1884 src: _OutputTensorDescr_v0_4, 

1885 tgt: "type[OutputTensorDescr] | type[dict[str, Any]]", 

1886 test_tensor: ImportantFileSource, 

1887 sample_tensor: Optional[ImportantFileSource], 

1888 size_refs: Mapping[_TensorName_v0_4, Mapping[str, int]], 

1889 ) -> "OutputTensorDescr | dict[str, Any]": 

1890 # TODO: split convert_axes into convert_output_axes and convert_input_axes 

1891 axes: List[OutputAxis] = convert_axes( # pyright: ignore[reportAssignmentType] 

1892 src.axes, 

1893 shape=src.shape, 

1894 tensor_type="output", 

1895 halo=src.halo, 

1896 size_refs=size_refs, 

1897 ) 

1898 data_descr: Dict[str, Any] = dict(type=src.data_type) 

1899 if data_descr["type"] == "bool": 

1900 data_descr["values"] = [False, True] 

1901 

1902 return tgt( 

1903 axes=axes, 

1904 id=TensorId(str(src.name)), 

1905 test_tensor=FileDescr(source=test_tensor), 

1906 sample_tensor=( 

1907 None if sample_tensor is None else FileDescr(source=sample_tensor) 

1908 ), 

1909 data=data_descr, # pyright: ignore[reportArgumentType] 

1910 postprocessing=[_convert_proc(p, src.axes) for p in src.postprocessing], 

1911 ) 

1912 

1913 

1914_output_tensor_conv = _OutputTensorConv(_OutputTensorDescr_v0_4, OutputTensorDescr) 

1915 

1916 

1917TensorDescr = Union[InputTensorDescr, OutputTensorDescr] 

1918 

1919 

1920def validate_tensors( 

1921 tensors: Mapping[TensorId, Tuple[TensorDescr, NDArray[Any]]], 

1922 tensor_origin: str, # for more precise error messages, e.g. 'test_tensor' 

1923): 

1924 all_tensor_axes: Dict[TensorId, Dict[AxisId, Tuple[AnyAxis, int]]] = {} 

1925 

1926 def e_msg(d: TensorDescr): 

1927 return f"{'inputs' if isinstance(d, InputTensorDescr) else 'outputs'}[{d.id}]" 

1928 

1929 for descr, array in tensors.values(): 

1930 try: 

1931 axis_sizes = descr.get_axis_sizes_for_array(array) 

1932 except ValueError as e: 

1933 raise ValueError(f"{e_msg(descr)} {e}") 

1934 else: 

1935 all_tensor_axes[descr.id] = { 

1936 a.id: (a, axis_sizes[a.id]) for a in descr.axes 

1937 } 

1938 

1939 for descr, array in tensors.values(): 

1940 if array.dtype.name != descr.dtype: 

1941 raise ValueError( 

1942 f"{e_msg(descr)}.{tensor_origin}.dtype '{array.dtype.name}' does not" 

1943 + f" match described dtype '{descr.dtype}'" 

1944 ) 

1945 

1946 for a in descr.axes: 

1947 actual_size = all_tensor_axes[descr.id][a.id][1] 

1948 if a.size is None: 

1949 continue 

1950 

1951 if isinstance(a.size, int): 

1952 if actual_size != a.size: 

1953 raise ValueError( 

1954 f"{e_msg(descr)}.{tensor_origin}: axis '{a.id}' " 

1955 + f"has incompatible size {actual_size}, expected {a.size}" 

1956 ) 

1957 elif isinstance(a.size, ParameterizedSize): 

1958 _ = a.size.validate_size(actual_size) 

1959 elif isinstance(a.size, DataDependentSize): 

1960 _ = a.size.validate_size(actual_size) 

1961 elif isinstance(a.size, SizeReference): 

1962 ref_tensor_axes = all_tensor_axes.get(a.size.tensor_id) 

1963 if ref_tensor_axes is None: 

1964 raise ValueError( 

1965 f"{e_msg(descr)}.axes[{a.id}].size.tensor_id: Unknown tensor" 

1966 + f" reference '{a.size.tensor_id}'" 

1967 ) 

1968 

1969 ref_axis, ref_size = ref_tensor_axes.get(a.size.axis_id, (None, None)) 

1970 if ref_axis is None or ref_size is None: 

1971 raise ValueError( 

1972 f"{e_msg(descr)}.axes[{a.id}].size.axis_id: Unknown tensor axis" 

1973 + f" reference '{a.size.tensor_id}.{a.size.axis_id}" 

1974 ) 

1975 

1976 if a.unit != ref_axis.unit: 

1977 raise ValueError( 

1978 f"{e_msg(descr)}.axes[{a.id}].size: `SizeReference` requires" 

1979 + " axis and reference axis to have the same `unit`, but" 

1980 + f" {a.unit}!={ref_axis.unit}" 

1981 ) 

1982 

1983 if actual_size != ( 

1984 expected_size := ( 

1985 ref_size * ref_axis.scale / a.scale + a.size.offset 

1986 ) 

1987 ): 

1988 raise ValueError( 

1989 f"{e_msg(descr)}.{tensor_origin}: axis '{a.id}' of size" 

1990 + f" {actual_size} invalid for referenced size {ref_size};" 

1991 + f" expected {expected_size}" 

1992 ) 

1993 else: 

1994 assert_never(a.size) 

1995 

1996 
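# Illustrative sketch, not part of the original module: checking an array
# against its tensor description. `descr` stands for an existing
# InputTensorDescr whose axes match a (1, 1, 64, 64) float32 array.
#
#   import numpy as np
#   arr = np.zeros((1, 1, 64, 64), dtype="float32")
#   validate_tensors({descr.id: (descr, arr)}, tensor_origin="test_tensor")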

1997class EnvironmentFileDescr(FileDescr): 

1998 source: Annotated[ 

1999 ImportantFileSource, 

2000 WithSuffix((".yaml", ".yml"), case_sensitive=True), 

2001 Field( 

2002 examples=["environment.yaml"], 

2003 ), 

2004 ] 

2005 """∈📦 Conda environment file. 

2006 Allows specifying custom dependencies; see the conda docs: 

2007 - [Exporting an environment file across platforms](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#exporting-an-environment-file-across-platforms) 

2008 - [Creating an environment file manually](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#creating-an-environment-file-manually) 

2009 """ 

2010 

2011 
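# Illustrative sketch, not part of the original module: a minimal conda
# environment file such a descriptor could point to (package pins made up);
# with io checks enabled the source file would be resolved.
#
#   # environment.yaml
#   #   name: my-model-env
#   #   channels: [conda-forge]
#   #   dependencies:
#   #     - python=3.10
#   #     - pytorch=2.1
#
#   env = EnvironmentFileDescr(source="environment.yaml")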

2012class _ArchitectureCallableDescr(Node): 

2013 callable: Annotated[Identifier, Field(examples=["MyNetworkClass", "get_my_model"])] 

2014 """Identifier of the callable that returns a torch.nn.Module instance.""" 

2015 

2016 kwargs: Dict[str, YamlValue] = Field(default_factory=dict) 

2017 """key word arguments for the `callable`""" 

2018 

2019 

2020class ArchitectureFromFileDescr(_ArchitectureCallableDescr, FileDescr): 

2021 pass 

2022 

2023 

2024class ArchitectureFromLibraryDescr(_ArchitectureCallableDescr): 

2025 import_from: str 

2026 """Where to import the callable from, i.e. `from <import_from> import <callable>`""" 

2027 

2028 

2029ArchitectureDescr = Annotated[ 

2030 Union[ArchitectureFromFileDescr, ArchitectureFromLibraryDescr], 

2031 Field(union_mode="left_to_right"), 

2032] 

2033 

2034 
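# Illustrative sketch, not part of the original module: the two architecture
# description variants. Module path, callable name and kwargs are made up;
# for `ArchitectureFromFileDescr` a real `sha256` of the file would be given.
#
#   arch_lib = ArchitectureFromLibraryDescr(
#       import_from="my_models.unet",  # i.e. `from my_models.unet import UNet2d`
#       callable=Identifier("UNet2d"),
#       kwargs=dict(in_channels=1, out_channels=2),
#   )
#   arch_file = ArchitectureFromFileDescr(
#       source="https://example.com/unet.py",
#       callable=Identifier("UNet2d"),
#       kwargs=dict(in_channels=1, out_channels=2),
#   )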

2035class _ArchFileConv( 

2036 Converter[ 

2037 _CallableFromFile_v0_4, 

2038 ArchitectureFromFileDescr, 

2039 Optional[Sha256], 

2040 Dict[str, Any], 

2041 ] 

2042): 

2043 def _convert( 

2044 self, 

2045 src: _CallableFromFile_v0_4, 

2046 tgt: "type[ArchitectureFromFileDescr | dict[str, Any]]", 

2047 sha256: Optional[Sha256], 

2048 kwargs: Dict[str, Any], 

2049 ) -> "ArchitectureFromFileDescr | dict[str, Any]": 

2050 if src.startswith("http") and src.count(":") == 2: 

2051 http, source, callable_ = src.split(":") 

2052 source = ":".join((http, source)) 

2053 elif not src.startswith("http") and src.count(":") == 1: 

2054 source, callable_ = src.split(":") 

2055 else: 

2056 source = str(src) 

2057 callable_ = str(src) 

2058 return tgt( 

2059 callable=Identifier(callable_), 

2060 source=cast(ImportantFileSource, source), 

2061 sha256=sha256, 

2062 kwargs=kwargs, 

2063 ) 

2064 

2065 

2066_arch_file_conv = _ArchFileConv(_CallableFromFile_v0_4, ArchitectureFromFileDescr) 

2067 

2068 
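# Illustrative examples, not part of the original module: the v0.4 callable
# strings handled by `_ArchFileConv._convert` above. The final colon
# separates the callable name; a URL scheme colon is re-joined into `source`.
#
#   "https://example.com/unet.py:UNet2d" -> source "https://example.com/unet.py", callable "UNet2d"
#   "unet.py:UNet2d"                     -> source "unet.py",                     callable "UNet2d"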

2069class _ArchLibConv( 

2070 Converter[ 

2071 _CallableFromDepencency_v0_4, ArchitectureFromLibraryDescr, Dict[str, Any] 

2072 ] 

2073): 

2074 def _convert( 

2075 self, 

2076 src: _CallableFromDepencency_v0_4, 

2077 tgt: "type[ArchitectureFromLibraryDescr | dict[str, Any]]", 

2078 kwargs: Dict[str, Any], 

2079 ) -> "ArchitectureFromLibraryDescr | dict[str, Any]": 

2080 *mods, callable_ = src.split(".") 

2081 import_from = ".".join(mods) 

2082 return tgt( 

2083 import_from=import_from, callable=Identifier(callable_), kwargs=kwargs 

2084 ) 

2085 

2086 

2087_arch_lib_conv = _ArchLibConv( 

2088 _CallableFromDepencency_v0_4, ArchitectureFromLibraryDescr 

2089) 

2090 

2091 

2092class WeightsEntryDescrBase(FileDescr): 

2093 type: ClassVar[WeightsFormat] 

2094 weights_format_name: ClassVar[str] # human readable 

2095 

2096 source: ImportantFileSource 

2097 """∈📦 The weights file.""" 

2098 

2099 authors: Optional[List[Author]] = None 

2100 """Authors 

2101 Either the person(s) that have trained this model resulting in the original weights file. 

2102 (If this is the initial weights entry, i.e. it does not have a `parent`) 

2103 Or the person(s) who have converted the weights to this weights format. 

2104 (If this is a child weight, i.e. it has a `parent` field) 

2105 """ 

2106 

2107 parent: Annotated[ 

2108 Optional[WeightsFormat], Field(examples=["pytorch_state_dict"]) 

2109 ] = None 

2110 """The source weights these weights were converted from. 

2111 For example, if a model's weights were converted from the `pytorch_state_dict` format to `torchscript`, 

2112 the `pytorch_state_dict` weights entry has no `parent` and is the parent of the `torchscript` weights. 

2113 All weights entries except one (the initial set of weights resulting from training the model) 

2114 need to have this field."""

2115 

2116 @model_validator(mode="after") 

2117 def check_parent_is_not_self(self) -> Self: 

2118 if self.type == self.parent: 

2119 raise ValueError("Weights entry can't be its own parent.") 

2120 

2121 return self 

2122 

2123 

2124class KerasHdf5WeightsDescr(WeightsEntryDescrBase): 

2125 type = "keras_hdf5" 

2126 weights_format_name: ClassVar[str] = "Keras HDF5" 

2127 tensorflow_version: Version 

2128 """TensorFlow version used to create these weights.""" 

2129 

2130 

2131class OnnxWeightsDescr(WeightsEntryDescrBase): 

2132 type = "onnx" 

2133 weights_format_name: ClassVar[str] = "ONNX" 

2134 opset_version: Annotated[int, Ge(7)] 

2135 """ONNX opset version""" 

2136 

2137 

2138class PytorchStateDictWeightsDescr(WeightsEntryDescrBase): 

2139 type = "pytorch_state_dict" 

2140 weights_format_name: ClassVar[str] = "Pytorch State Dict" 

2141 architecture: ArchitectureDescr 

2142 pytorch_version: Version 

2143 """Version of the PyTorch library used. 

2144 If `architecture.dependencies` is specified, it has to include pytorch and any version pinning has to be compatible. 

2145 """ 

2146 dependencies: Optional[EnvironmentFileDescr] = None 

2147 """Custom depencies beyond pytorch. 

2148 The conda environment file should include pytorch and any version pinning has to be compatible with 

2149 `pytorch_version`. 

2150 """ 

2151 

2152 

2153class TensorflowJsWeightsDescr(WeightsEntryDescrBase): 

2154 type = "tensorflow_js" 

2155 weights_format_name: ClassVar[str] = "Tensorflow.js" 

2156 tensorflow_version: Version 

2157 """Version of the TensorFlow library used.""" 

2158 

2159 source: ImportantFileSource 

2160 """∈📦 The multi-file weights. 

2161 All required files/folders should be packaged in a zip archive.""" 

2162 

2163 

2164class TensorflowSavedModelBundleWeightsDescr(WeightsEntryDescrBase): 

2165 type = "tensorflow_saved_model_bundle" 

2166 weights_format_name: ClassVar[str] = "Tensorflow Saved Model" 

2167 tensorflow_version: Version 

2168 """Version of the TensorFlow library used.""" 

2169 

2170 dependencies: Optional[EnvironmentFileDescr] = None 

2171 """Custom dependencies beyond tensorflow. 

2172 Should include tensorflow and any version pinning has to be compatible with `tensorflow_version`.""" 

2173 

2174 source: ImportantFileSource 

2175 """∈📦 The multi-file weights. 

2176 All required files/folders should be packaged in a zip archive.""" 

2177 

2178 

2179class TorchscriptWeightsDescr(WeightsEntryDescrBase): 

2180 type = "torchscript" 

2181 weights_format_name: ClassVar[str] = "TorchScript" 

2182 pytorch_version: Version 

2183 """Version of the PyTorch library used.""" 

2184 

2185 

2186class WeightsDescr(Node): 

2187 keras_hdf5: Optional[KerasHdf5WeightsDescr] = None 

2188 onnx: Optional[OnnxWeightsDescr] = None 

2189 pytorch_state_dict: Optional[PytorchStateDictWeightsDescr] = None 

2190 tensorflow_js: Optional[TensorflowJsWeightsDescr] = None 

2191 tensorflow_saved_model_bundle: Optional[TensorflowSavedModelBundleWeightsDescr] = ( 

2192 None 

2193 ) 

2194 torchscript: Optional[TorchscriptWeightsDescr] = None 

2195 

2196 @model_validator(mode="after") 

2197 def check_entries(self) -> Self: 

2198 entries = {wtype for wtype, entry in self if entry is not None} 

2199 

2200 if not entries: 

2201 raise ValueError("Missing weights entry") 

2202 

2203 entries_wo_parent = { 

2204 wtype 

2205 for wtype, entry in self 

2206 if entry is not None and hasattr(entry, "parent") and entry.parent is None 

2207 } 

2208 if len(entries_wo_parent) != 1: 

2209 issue_warning( 

2210 "Exactly one weights entry may not specify the `parent` field (got" 

2211 + " {value}). That entry is considered the original set of model weights." 

2212 + " Other weight formats are created through conversion of the orignal or" 

2213 + " already converted weights. They have to reference the weights format" 

2214 + " they were converted from as their `parent`.", 

2215 value=len(entries_wo_parent), 

2216 field="weights", 

2217 ) 

2218 

2219 for wtype, entry in self: 

2220 if entry is None: 

2221 continue 

2222 

2223 assert hasattr(entry, "type") 

2224 assert hasattr(entry, "parent") 

2225 assert wtype == entry.type 

2226 if ( 

2227 entry.parent is not None and entry.parent not in entries 

2228 ): # self reference checked for `parent` field 

2229 raise ValueError( 

2230 f"`weights.{wtype}.parent={entry.parent} not in specified weight" 

2231 + f" formats: {entries}" 

2232 ) 

2233 

2234 return self 

2235 

2236 def __getitem__( 

2237 self, 

2238 key: Literal[ 

2239 "keras_hdf5", 

2240 "onnx", 

2241 "pytorch_state_dict", 

2242 "tensorflow_js", 

2243 "tensorflow_saved_model_bundle", 

2244 "torchscript", 

2245 ], 

2246 ): 

2247 if key == "keras_hdf5": 

2248 ret = self.keras_hdf5 

2249 elif key == "onnx": 

2250 ret = self.onnx 

2251 elif key == "pytorch_state_dict": 

2252 ret = self.pytorch_state_dict 

2253 elif key == "tensorflow_js": 

2254 ret = self.tensorflow_js 

2255 elif key == "tensorflow_saved_model_bundle": 

2256 ret = self.tensorflow_saved_model_bundle 

2257 elif key == "torchscript": 

2258 ret = self.torchscript 

2259 else: 

2260 raise KeyError(key) 

2261 

2262 if ret is None: 

2263 raise KeyError(key) 

2264 

2265 return ret 

2266 

2267 @property 

2268 def available_formats(self): 

2269 return { 

2270 **({} if self.keras_hdf5 is None else {"keras_hdf5": self.keras_hdf5}), 

2271 **({} if self.onnx is None else {"onnx": self.onnx}), 

2272 **( 

2273 {} 

2274 if self.pytorch_state_dict is None 

2275 else {"pytorch_state_dict": self.pytorch_state_dict} 

2276 ), 

2277 **( 

2278 {} 

2279 if self.tensorflow_js is None 

2280 else {"tensorflow_js": self.tensorflow_js} 

2281 ), 

2282 **( 

2283 {} 

2284 if self.tensorflow_saved_model_bundle is None 

2285 else { 

2286 "tensorflow_saved_model_bundle": self.tensorflow_saved_model_bundle 

2287 } 

2288 ), 

2289 **({} if self.torchscript is None else {"torchscript": self.torchscript}), 

2290 } 

2291 

2292 @property 

2293 def missing_formats(self): 

2294 return { 

2295 wf for wf in get_args(WeightsFormat) if wf not in self.available_formats 

2296 } 

2297 

2298 
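# Illustrative sketch, not part of the original module: a weights description
# with a single torchscript entry (source URL made up; with io checks enabled
# the source would be downloaded and its hash verified if given).
#
#   weights = WeightsDescr(
#       torchscript=TorchscriptWeightsDescr(
#           source="https://example.com/weights.pt",
#           pytorch_version=Version("2.1"),
#       )
#   )
#   weights["torchscript"]              # returns the entry; KeyError if unset
#   assert "onnx" in weights.missing_formats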

2299class ModelId(ResourceId): 

2300 pass 

2301 

2302 

2303class LinkedModel(LinkedResourceBase): 

2304 """Reference to a bioimage.io model.""" 

2305 

2306 id: ModelId 

2307 """A valid model `id` from the bioimage.io collection.""" 

2308 

2309 

2310class _DataDepSize(NamedTuple): 

2311 min: int 

2312 max: Optional[int] 

2313 

2314 

2315class _AxisSizes(NamedTuple): 

2316 """the lenghts of all axes of model inputs and outputs""" 

2317 

2318 inputs: Dict[Tuple[TensorId, AxisId], int] 

2319 outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]] 

2320 

2321 

2322class _TensorSizes(NamedTuple): 

2323 """_AxisSizes as nested dicts""" 

2324 

2325 inputs: Dict[TensorId, Dict[AxisId, int]] 

2326 outputs: Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]] 

2327 

2328 

2329class ModelDescr(GenericModelDescrBase): 

2330 """Specification of the fields used in a bioimage.io-compliant RDF to describe AI models with pretrained weights. 

2331 These fields are typically stored in a YAML file which we call a model resource description file (model RDF). 

2332 """ 

2333 

2334 format_version: Literal["0.5.3"] = "0.5.3" 

2335 """Version of the bioimage.io model description specification used. 

2336 When creating a new model always use the latest micro/patch version described here. 

2337 The `format_version` is important for any consumer software to understand how to parse the fields. 

2338 """ 

2339 

2340 type: Literal["model"] = "model" 

2341 """Specialized resource type 'model'""" 

2342 

2343 id: Optional[ModelId] = None 

2344 """bioimage.io-wide unique resource identifier 

2345 assigned by bioimage.io; version **un**specific.""" 

2346 

2347 authors: NotEmpty[List[Author]] 

2348 """The authors are the creators of the model RDF and the primary points of contact.""" 

2349 

2350 documentation: Annotated[ 

2351 DocumentationSource, 

2352 Field( 

2353 examples=[ 

2354 "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/unet2d_nuclei_broad/README.md", 

2355 "README.md", 

2356 ], 

2357 ), 

2358 ] 

2359 """∈📦 URL or relative path to a markdown file with additional documentation. 

2360 The recommended documentation file name is `README.md`. An `.md` suffix is mandatory. 

2361 The documentation should include a '#[#] Validation' (sub)section 

2362 with details on how to quantitatively validate the model on unseen data.""" 

2363 

2364 @field_validator("documentation", mode="after") 

2365 @classmethod 

2366 def _validate_documentation(cls, value: DocumentationSource) -> DocumentationSource: 

2367 if not validation_context_var.get().perform_io_checks: 

2368 return value 

2369 

2370 doc_path = download(value).path 

2371 doc_content = doc_path.read_text(encoding="utf-8") 

2372 assert isinstance(doc_content, str) 

2373 if not re.search("#.*[vV]alidation", doc_content):  # search anywhere in the document 

2374 issue_warning( 

2375 "No '# Validation' (sub)section found in {value}.", 

2376 value=value, 

2377 field="documentation", 

2378 ) 

2379 

2380 return value 

2381 

2382 inputs: NotEmpty[Sequence[InputTensorDescr]] 

2383 """Describes the input tensors expected by this model.""" 

2384 

2385 @field_validator("inputs", mode="after") 

2386 @classmethod 

2387 def _validate_input_axes( 

2388 cls, inputs: Sequence[InputTensorDescr] 

2389 ) -> Sequence[InputTensorDescr]: 

2390 input_size_refs = cls._get_axes_with_independent_size(inputs) 

2391 

2392 for i, ipt in enumerate(inputs): 

2393 valid_independent_refs: Dict[ 

2394 Tuple[TensorId, AxisId], 

2395 Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]], 

2396 ] = { 

2397 **{ 

2398 (ipt.id, a.id): (ipt, a, a.size) 

2399 for a in ipt.axes 

2400 if not isinstance(a, BatchAxis) 

2401 and isinstance(a.size, (int, ParameterizedSize)) 

2402 }, 

2403 **input_size_refs, 

2404 } 

2405 for a, ax in enumerate(ipt.axes): 

2406 cls._validate_axis( 

2407 "inputs", 

2408 i=i, 

2409 tensor_id=ipt.id, 

2410 a=a, 

2411 axis=ax, 

2412 valid_independent_refs=valid_independent_refs, 

2413 ) 

2414 return inputs 

2415 

2416 @staticmethod 

2417 def _validate_axis( 

2418 field_name: str, 

2419 i: int, 

2420 tensor_id: TensorId, 

2421 a: int, 

2422 axis: AnyAxis, 

2423 valid_independent_refs: Dict[ 

2424 Tuple[TensorId, AxisId], 

2425 Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]], 

2426 ], 

2427 ): 

2428 if isinstance(axis, BatchAxis) or isinstance( 

2429 axis.size, (int, ParameterizedSize, DataDependentSize) 

2430 ): 

2431 return 

2432 elif not isinstance(axis.size, SizeReference): 

2433 assert_never(axis.size) 

2434 

2435 # validate axis.size SizeReference 

2436 ref = (axis.size.tensor_id, axis.size.axis_id) 

2437 if ref not in valid_independent_refs: 

2438 raise ValueError( 

2439 "Invalid tensor axis reference at" 

2440 + f" {field_name}[{i}].axes[{a}].size: {axis.size}." 

2441 ) 

2442 if ref == (tensor_id, axis.id): 

2443 raise ValueError( 

2444 "Self-referencing not allowed for" 

2445 + f" {field_name}[{i}].axes[{a}].size: {axis.size}" 

2446 ) 

2447 if axis.type == "channel": 

2448 if valid_independent_refs[ref][1].type != "channel": 

2449 raise ValueError( 

2450 "A channel axis' size may only reference another fixed size" 

2451 + " channel axis." 

2452 ) 

2453 if isinstance(axis.channel_names, str) and "{i}" in axis.channel_names: 

2454 ref_size = valid_independent_refs[ref][2] 

2455 assert isinstance(ref_size, int), ( 

2456 "channel axis ref (another channel axis) has to specify fixed" 

2457 + " size" 

2458 ) 

2459 generated_channel_names = [ 

2460 Identifier(axis.channel_names.format(i=i)) 

2461 for i in range(1, ref_size + 1) 

2462 ] 

2463 axis.channel_names = generated_channel_names 

2464 

2465 if (ax_unit := getattr(axis, "unit", None)) != ( 

2466 ref_unit := getattr(valid_independent_refs[ref][1], "unit", None) 

2467 ): 

2468 raise ValueError( 

2469 "The units of an axis and its reference axis need to match, but" 

2470 + f" '{ax_unit}' != '{ref_unit}'." 

2471 ) 

2472 ref_axis = valid_independent_refs[ref][1] 

2473 if isinstance(ref_axis, BatchAxis): 

2474 raise ValueError( 

2475 f"Invalid reference axis '{ref_axis.id}' for {tensor_id}.{axis.id}" 

2476 + " (a batch axis is not allowed as reference)." 

2477 ) 

2478 

2479 if isinstance(axis, WithHalo): 

2480 min_size = axis.size.get_size(axis, ref_axis, n=0) 

2481 if (min_size - 2 * axis.halo) < 1: 

2482 raise ValueError( 

2483 f"axis {axis.id} with minimum size {min_size} is too small for halo" 

2484 + f" {axis.halo}." 

2485 ) 

2486 

2487 input_halo = axis.halo * axis.scale / ref_axis.scale 

2488 if input_halo != int(input_halo) or input_halo % 2 == 1: 

2489 raise ValueError( 

2490 f"input_halo {input_halo} (output_halo {axis.halo} *" 

2491 + f" output_scale {axis.scale} / input_scale {ref_axis.scale})" 

2492 + f" is not an even integer for {tensor_id}.{axis.id}." 

2493 ) 

2494 

2495 @model_validator(mode="after") 

2496 def _validate_test_tensors(self) -> Self: 

2497 if not validation_context_var.get().perform_io_checks: 

2498 return self 

2499 

2500 test_arrays = [ 

2501 load_array(descr.test_tensor.download().path) 

2502 for descr in chain(self.inputs, self.outputs) 

2503 ] 

2504 tensors = { 

2505 descr.id: (descr, array) 

2506 for descr, array in zip(chain(self.inputs, self.outputs), test_arrays) 

2507 } 

2508 validate_tensors(tensors, tensor_origin="test_tensor") 

2509 return self 

2510 

2511 @model_validator(mode="after") 

2512 def _validate_tensor_references_in_proc_kwargs(self, info: ValidationInfo) -> Self: 

2513 ipt_refs = {t.id for t in self.inputs} 

2514 out_refs = {t.id for t in self.outputs} 

2515 for ipt in self.inputs: 

2516 for p in ipt.preprocessing: 

2517 ref = p.kwargs.get("reference_tensor") 

2518 if ref is None: 

2519 continue 

2520 if ref not in ipt_refs: 

2521 raise ValueError( 

2522 f"`reference_tensor` '{ref}' not found. Valid input tensor" 

2523 + f" references are: {ipt_refs}." 

2524 ) 

2525 

2526 for out in self.outputs: 

2527 for p in out.postprocessing: 

2528 ref = p.kwargs.get("reference_tensor") 

2529 if ref is None: 

2530 continue 

2531 

2532 if ref not in ipt_refs and ref not in out_refs: 

2533 raise ValueError( 

2534 f"`reference_tensor` '{ref}' not found. Valid tensor references" 

2535 + f" are: {ipt_refs | out_refs}." 

2536 ) 

2537 

2538 return self 

2539 

2540 # TODO: use validate funcs in validate_test_tensors 

2541 # def validate_inputs(self, input_tensors: Mapping[TensorId, NDArray[Any]]) -> Mapping[TensorId, NDArray[Any]]: 

2542 

2543 name: Annotated[ 

2544 Annotated[ 

2545 str, RestrictCharacters(string.ascii_letters + string.digits + "_- ()") 

2546 ], 

2547 MinLen(5), 

2548 MaxLen(128), 

2549 warn(MaxLen(64), "Name longer than 64 characters.", INFO), 

2550 ] 

2551 """A human-readable name of this model. 

2552 It should be no longer than 64 characters 

2553 and may only contain letters, numbers, underscores, minus signs, parentheses and spaces. 

2554 We recommend choosing a name that refers to the model's task and image modality. 

2555 """ 

2556 

2557 outputs: NotEmpty[Sequence[OutputTensorDescr]] 

2558 """Describes the output tensors.""" 

2559 

2560 @field_validator("outputs", mode="after") 

2561 @classmethod 

2562 def _validate_tensor_ids( 

2563 cls, outputs: Sequence[OutputTensorDescr], info: ValidationInfo 

2564 ) -> Sequence[OutputTensorDescr]: 

2565 tensor_ids = [ 

2566 t.id for t in info.data.get("inputs", []) + info.data.get("outputs", []) 

2567 ] 

2568 duplicate_tensor_ids: List[str] = [] 

2569 seen: Set[str] = set() 

2570 for t in tensor_ids: 

2571 if t in seen: 

2572 duplicate_tensor_ids.append(t) 

2573 

2574 seen.add(t) 

2575 

2576 if duplicate_tensor_ids: 

2577 raise ValueError(f"Duplicate tensor ids: {duplicate_tensor_ids}") 

2578 

2579 return outputs 

2580 

2581 @staticmethod 

2582 def _get_axes_with_parameterized_size( 

2583 io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]], 

2584 ): 

2585 return { 

2586 f"{t.id}.{a.id}": (t, a, a.size) 

2587 for t in io 

2588 for a in t.axes 

2589 if not isinstance(a, BatchAxis) and isinstance(a.size, ParameterizedSize) 

2590 } 

2591 

2592 @staticmethod 

2593 def _get_axes_with_independent_size( 

2594 io: Union[Sequence[InputTensorDescr], Sequence[OutputTensorDescr]], 

2595 ): 

2596 return { 

2597 (t.id, a.id): (t, a, a.size) 

2598 for t in io 

2599 for a in t.axes 

2600 if not isinstance(a, BatchAxis) 

2601 and isinstance(a.size, (int, ParameterizedSize)) 

2602 } 

2603 

2604 @field_validator("outputs", mode="after") 

2605 @classmethod 

2606 def _validate_output_axes( 

2607 cls, outputs: List[OutputTensorDescr], info: ValidationInfo 

2608 ) -> List[OutputTensorDescr]: 

2609 input_size_refs = cls._get_axes_with_independent_size( 

2610 info.data.get("inputs", []) 

2611 ) 

2612 output_size_refs = cls._get_axes_with_independent_size(outputs) 

2613 

2614 for i, out in enumerate(outputs): 

2615 valid_independent_refs: Dict[ 

2616 Tuple[TensorId, AxisId], 

2617 Tuple[TensorDescr, AnyAxis, Union[int, ParameterizedSize]], 

2618 ] = { 

2619 **{ 

2620 (out.id, a.id): (out, a, a.size) 

2621 for a in out.axes 

2622 if not isinstance(a, BatchAxis) 

2623 and isinstance(a.size, (int, ParameterizedSize)) 

2624 }, 

2625 **input_size_refs, 

2626 **output_size_refs, 

2627 } 

2628 for a, ax in enumerate(out.axes): 

2629 cls._validate_axis( 

2630 "outputs", 

2631 i, 

2632 out.id, 

2633 a, 

2634 ax, 

2635 valid_independent_refs=valid_independent_refs, 

2636 ) 

2637 

2638 return outputs 

2639 

2640 packaged_by: List[Author] = Field(default_factory=list) 

2641 """The persons that have packaged and uploaded this model. 

2642 Only required if those persons differ from the `authors`.""" 

2643 

2644 parent: Optional[LinkedModel] = None 

2645 """The model from which this model is derived, e.g. by fine-tuning the weights.""" 

2646 

2647 # todo: add parent self check once we have `id` 

2648 # @model_validator(mode="after") 

2649 # def validate_parent_is_not_self(self) -> Self: 

2650 # if self.parent is not None and self.parent == self.id: 

2651 # raise ValueError("The model may not reference itself as parent model") 

2652 

2653 # return self 

2654 

2655 run_mode: Annotated[ 

2656 Optional[RunMode], 

2657 warn(None, "Run mode '{value}' has limited support across consumer software."), 

2658 ] = None 

2659 """Custom run mode for this model: for more complex prediction procedures like test time 

2660 data augmentation that currently cannot be expressed in the specification. 

2661 No standard run modes are defined yet.""" 

2662 

2663 timestamp: Datetime = Datetime(datetime.now()) 

2664 """Timestamp in [ISO 8601](#https://en.wikipedia.org/wiki/ISO_8601) format 

2665 with a few restrictions listed [here](https://docs.python.org/3/library/datetime.html#datetime.datetime.fromisoformat). 

2666 (In Python a datetime object is valid, too).""" 

2667 

2668 training_data: Annotated[ 

2669 Union[None, LinkedDataset, DatasetDescr, DatasetDescr02], 

2670 Field(union_mode="left_to_right"), 

2671 ] = None 

2672 """The dataset used to train this model""" 

2673 

2674 weights: Annotated[WeightsDescr, WrapSerializer(package_weights)] 

2675 """The weights for this model. 

2676 Weights can be given for different formats, but should otherwise be equivalent. 

2677 The available weight formats determine which consumers can use this model.""" 

2678 

2679 @model_validator(mode="after") 

2680 def _add_default_cover(self) -> Self: 

2681 if not validation_context_var.get().perform_io_checks or self.covers: 

2682 return self 

2683 

2684 try: 

2685 generated_covers = generate_covers( 

2686 [(t, load_array(t.test_tensor.download().path)) for t in self.inputs], 

2687 [(t, load_array(t.test_tensor.download().path)) for t in self.outputs], 

2688 ) 

2689 except Exception as e: 

2690 issue_warning( 

2691 "Failed to generate cover image(s): {e}", 

2692 value=self.covers, 

2693 msg_context=dict(e=e), 

2694 field="covers", 

2695 ) 

2696 else: 

2697 self.covers.extend(generated_covers) 

2698 

2699 return self 

2700 

2701 def get_input_test_arrays(self) -> List[NDArray[Any]]: 

2702 data = [load_array(ipt.test_tensor.download().path) for ipt in self.inputs] 

2703 assert all(isinstance(d, np.ndarray) for d in data) 

2704 return data 

2705 

2706 def get_output_test_arrays(self) -> List[NDArray[Any]]: 

2707 data = [load_array(out.test_tensor.download().path) for out in self.outputs] 

2708 assert all(isinstance(d, np.ndarray) for d in data) 

2709 return data 

2710 

2711 @staticmethod 

2712 def get_batch_size(tensor_sizes: Mapping[TensorId, Mapping[AxisId, int]]) -> int: 

2713 batch_size = 1 

2714 tensor_with_batchsize: Optional[TensorId] = None 

2715 for tid in tensor_sizes: 

2716 for aid, s in tensor_sizes[tid].items(): 

2717 if aid != BATCH_AXIS_ID or s == 1 or s == batch_size: 

2718 continue 

2719 

2720 if batch_size != 1: 

2721 assert tensor_with_batchsize is not None 

2722 raise ValueError( 

2723 f"batch size mismatch for tensors '{tensor_with_batchsize}' ({batch_size}) and '{tid}' ({s})" 

2724 ) 

2725 

2726 batch_size = s 

2727 tensor_with_batchsize = tid 

2728 

2729 return batch_size 

2730 
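# Illustrative sketch, not part of the original module: extracting a common
# batch size from per-tensor axis sizes (tensor/axis ids made up).
#
#   sizes = {TensorId("raw"): {AxisId("batch"): 4, AxisId("x"): 128}}
#   assert ModelDescr.get_batch_size(sizes) == 4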

2731 def get_output_tensor_sizes( 

2732 self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]] 

2733 ) -> Dict[TensorId, Dict[AxisId, Union[int, _DataDepSize]]]: 

2734 """Returns the tensor output sizes for given **input_sizes**. 

2735 The tensor output sizes are exact only if **input_sizes** specifies a valid input shape. 

2736 Otherwise they might be larger than the actual (valid) output sizes.""" 

2737 batch_size = self.get_batch_size(input_sizes) 

2738 ns = self.get_ns(input_sizes) 

2739 

2740 tensor_sizes = self.get_tensor_sizes(ns, batch_size=batch_size) 

2741 return tensor_sizes.outputs 

2742 

2743 def get_ns(self, input_sizes: Mapping[TensorId, Mapping[AxisId, int]]): 

2744 """get parameter `n` for each parameterized axis 

2745 such that the valid input size is >= the given input size""" 

2746 ret: Dict[Tuple[TensorId, AxisId], ParameterizedSize_N] = {} 

2747 axes = {t.id: {a.id: a for a in t.axes} for t in self.inputs} 

2748 for tid in input_sizes: 

2749 for aid, s in input_sizes[tid].items(): 

2750 size_descr = axes[tid][aid].size 

2751 if isinstance(size_descr, ParameterizedSize): 

2752 ret[(tid, aid)] = size_descr.get_n(s) 

2753 elif size_descr is None or isinstance(size_descr, (int, SizeReference)): 

2754 pass 

2755 else: 

2756 assert_never(size_descr) 

2757 

2758 return ret 

2759 
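# Illustrative sketch, not part of the original module: the `n` returned by
# `get_ns` for a parameterized axis. The constructor kwargs `min`/`step` are
# assumptions based on the `size = min + n * step` convention documented below.
#
#   size = ParameterizedSize(min=64, step=16)  # valid sizes: 64, 80, 96, ...
#   n = size.get_n(100)                        # smallest n with size >= 100 -> 3
#   assert size.get_size(n) == 112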

2760 def get_tensor_sizes( 

2761 self, ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], batch_size: int 

2762 ) -> _TensorSizes: 

2763 axis_sizes = self.get_axis_sizes(ns, batch_size=batch_size) 

2764 return _TensorSizes( 

2765 { 

2766 t: { 

2767 aa: axis_sizes.inputs[(tt, aa)] 

2768 for tt, aa in axis_sizes.inputs 

2769 if tt == t 

2770 } 

2771 for t in {tt for tt, _ in axis_sizes.inputs} 

2772 }, 

2773 { 

2774 t: { 

2775 aa: axis_sizes.outputs[(tt, aa)] 

2776 for tt, aa in axis_sizes.outputs 

2777 if tt == t 

2778 } 

2779 for t in {tt for tt, _ in axis_sizes.outputs} 

2780 }, 

2781 ) 

2782 

2783 def get_axis_sizes( 

2784 self, 

2785 ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], 

2786 batch_size: Optional[int] = None, 

2787 *, 

2788 max_input_shape: Optional[Mapping[Tuple[TensorId, AxisId], int]] = None, 

2789 ) -> _AxisSizes: 

2790 """Determine input and output block shape for scale factors **ns** 

2791 of parameterized input sizes. 

2792 

2793 Args: 

2794 ns: Scale factor `n` for each axis (keyed by (tensor_id, axis_id)) 

2795 that is parameterized as `size = min + n * step`. 

2796 batch_size: The desired size of the batch dimension. 

2797 If given, **batch_size** overrides any batch size present in 

2798 **max_input_shape**. Defaults to 1. 

2799 max_input_shape: Limits the derived block shapes. 

2800 Each axis for which the input size, parameterized by `n`, is larger 

2801 than **max_input_shape** is set to the minimal value `n_min` for which 

2802 this is still true. 

2803 Use this for small input samples, for large values of **ns**, 

2804 or simply whenever you know the full input shape. 

2805 

2806 Returns: 

2807 Resolved axis sizes for model inputs and outputs. 

2808 """ 

2809 max_input_shape = max_input_shape or {} 

2810 if batch_size is None: 

2811 for (_t_id, a_id), s in max_input_shape.items(): 

2812 if a_id == BATCH_AXIS_ID: 

2813 batch_size = s 

2814 break 

2815 else: 

2816 batch_size = 1 

2817 

2818 all_axes = { 

2819 t.id: {a.id: a for a in t.axes} for t in chain(self.inputs, self.outputs) 

2820 } 

2821 

2822 inputs: Dict[Tuple[TensorId, AxisId], int] = {} 

2823 outputs: Dict[Tuple[TensorId, AxisId], Union[int, _DataDepSize]] = {} 

2824 

2825 def get_axis_size(a: Union[InputAxis, OutputAxis]): 

2826 if isinstance(a, BatchAxis): 

2827 if (t_descr.id, a.id) in ns: 

2828 logger.warning( 

2829 "Ignoring unexpected size increment factor (n) for batch axis" 

2830 + " of tensor '{}'.", 

2831 t_descr.id, 

2832 ) 

2833 return batch_size 

2834 elif isinstance(a.size, int): 

2835 if (t_descr.id, a.id) in ns: 

2836 logger.warning( 

2837 "Ignoring unexpected size increment factor (n) for fixed size" 

2838 + " axis '{}' of tensor '{}'.", 

2839 a.id, 

2840 t_descr.id, 

2841 ) 

2842 return a.size 

2843 elif isinstance(a.size, ParameterizedSize): 

2844 if (t_descr.id, a.id) not in ns: 

2845 raise ValueError( 

2846 "Size increment factor (n) missing for parametrized axis" 

2847 + f" '{a.id}' of tensor '{t_descr.id}'." 

2848 ) 

2849 n = ns[(t_descr.id, a.id)] 

2850 s_max = max_input_shape.get((t_descr.id, a.id)) 

2851 if s_max is not None: 

2852 n = min(n, a.size.get_n(s_max)) 

2853 

2854 return a.size.get_size(n) 

2855 

2856 elif isinstance(a.size, SizeReference): 

2857 if (t_descr.id, a.id) in ns: 

2858 logger.warning( 

2859 "Ignoring unexpected size increment factor (n) for axis '{}'" 

2860 + " of tensor '{}' with size reference.", 

2861 a.id, 

2862 t_descr.id, 

2863 ) 

2864 assert not isinstance(a, BatchAxis) 

2865 ref_axis = all_axes[a.size.tensor_id][a.size.axis_id] 

2866 assert not isinstance(ref_axis, BatchAxis) 

2867 ref_key = (a.size.tensor_id, a.size.axis_id) 

2868 ref_size = inputs.get(ref_key, outputs.get(ref_key)) 

2869 assert ref_size is not None, ref_key 

2870 assert not isinstance(ref_size, _DataDepSize), ref_key 

2871 return a.size.get_size( 

2872 axis=a, 

2873 ref_axis=ref_axis, 

2874 ref_size=ref_size, 

2875 ) 

2876 elif isinstance(a.size, DataDependentSize): 

2877 if (t_descr.id, a.id) in ns: 

2878 logger.warning( 

2879 "Ignoring unexpected increment factor (n) for data dependent" 

2880 + " size axis '{}' of tensor '{}'.", 

2881 a.id, 

2882 t_descr.id, 

2883 ) 

2884 return _DataDepSize(a.size.min, a.size.max) 

2885 else: 

2886 assert_never(a.size) 

2887 

2888 # first resolve all input axis sizes except the `SizeReference` ones 

2889 for t_descr in self.inputs: 

2890 for a in t_descr.axes: 

2891 if not isinstance(a.size, SizeReference): 

2892 s = get_axis_size(a) 

2893 assert not isinstance(s, _DataDepSize) 

2894 inputs[t_descr.id, a.id] = s 

2895 

2896 # resolve all other input axis sizes 

2897 for t_descr in self.inputs: 

2898 for a in t_descr.axes: 

2899 if isinstance(a.size, SizeReference): 

2900 s = get_axis_size(a) 

2901 assert not isinstance(s, _DataDepSize) 

2902 inputs[t_descr.id, a.id] = s 

2903 

2904 # resolve all output axis sizes 

2905 for t_descr in self.outputs: 

2906 for a in t_descr.axes: 

2907 assert not isinstance(a.size, ParameterizedSize) 

2908 s = get_axis_size(a) 

2909 outputs[t_descr.id, a.id] = s 

2910 

2911 return _AxisSizes(inputs=inputs, outputs=outputs) 

2912 
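# Illustrative sketch, not part of the original module: resolving concrete
# block shapes for a model with one parameterized input axis ("raw"/"x" are
# made-up ids of an existing `model: ModelDescr`).
#
#   ns = {(TensorId("raw"), AxisId("x")): 2}
#   sizes = model.get_axis_sizes(ns, batch_size=1)
#   sizes.inputs[(TensorId("raw"), AxisId("x"))]   # -> min + 2 * step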

2913 @model_validator(mode="before") 

2914 @classmethod 

2915 def _convert(cls, data: Dict[str, Any]) -> Dict[str, Any]: 

2916 if ( 

2917 data.get("type") == "model" 

2918 and isinstance(fv := data.get("format_version"), str) 

2919 and fv.count(".") == 2 

2920 ): 

2921 fv_parts = fv.split(".") 

2922 if any(not p.isdigit() for p in fv_parts): 

2923 return data 

2924 

2925 fv_tuple = tuple(map(int, fv_parts)) 

2926 

2927 assert cls.implemented_format_version_tuple[0:2] == (0, 5) 

2928 if fv_tuple[:2] in ((0, 3), (0, 4)): 

2929 m04 = _ModelDescr_v0_4.load(data) 

2930 if not isinstance(m04, InvalidDescr): 

2931 return _model_conv.convert_as_dict(m04) 

2932 elif fv_tuple[:2] == (0, 5): 

2933 # bump patch version 

2934 data["format_version"] = cls.implemented_format_version 

2935 

2936 return data 

2937 

2938 

2939class _ModelConv(Converter[_ModelDescr_v0_4, ModelDescr]): 

2940 def _convert( 

2941 self, src: _ModelDescr_v0_4, tgt: "type[ModelDescr] | type[dict[str, Any]]" 

2942 ) -> "ModelDescr | dict[str, Any]": 

2943 name = "".join( 

2944 c if c in string.ascii_letters + string.digits + "_- ()" else " " 

2945 for c in src.name 

2946 ) 

2947 

2948 def conv_authors(auths: Optional[Sequence[_Author_v0_4]]): 

2949 conv = ( 

2950 _author_conv.convert if TYPE_CHECKING else _author_conv.convert_as_dict 

2951 ) 

2952 return None if auths is None else [conv(a) for a in auths] 

2953 

2954 if TYPE_CHECKING: 

2955 arch_file_conv = _arch_file_conv.convert 

2956 arch_lib_conv = _arch_lib_conv.convert 

2957 else: 

2958 arch_file_conv = _arch_file_conv.convert_as_dict 

2959 arch_lib_conv = _arch_lib_conv.convert_as_dict 

2960 

2961 input_size_refs = { 

2962 ipt.name: { 

2963 a: s 

2964 for a, s in zip( 

2965 ipt.axes, 

2966 ( 

2967 ipt.shape.min 

2968 if isinstance(ipt.shape, _ParameterizedInputShape_v0_4) 

2969 else ipt.shape 

2970 ), 

2971 ) 

2972 } 

2973 for ipt in src.inputs 

2974 if ipt.shape 

2975 } 

2976 output_size_refs = { 

2977 **{ 

2978 out.name: {a: s for a, s in zip(out.axes, out.shape)} 

2979 for out in src.outputs 

2980 if not isinstance(out.shape, _ImplicitOutputShape_v0_4) 

2981 }, 

2982 **input_size_refs, 

2983 } 

2984 

2985 return tgt( 

2986 attachments=( 

2987 [] 

2988 if src.attachments is None 

2989 else [FileDescr(source=f) for f in src.attachments.files] 

2990 ), 

2991 authors=[ 

2992 _author_conv.convert_as_dict(a) for a in src.authors 

2993 ], # pyright: ignore[reportArgumentType] 

2994 cite=[ 

2995 {"text": c.text, "doi": c.doi, "url": c.url} for c in src.cite 

2996 ], # pyright: ignore[reportArgumentType] 

2997 config=src.config, 

2998 covers=src.covers, 

2999 description=src.description, 

3000 documentation=src.documentation, 

3001 format_version="0.5.3", 

3002 git_repo=src.git_repo, # pyright: ignore[reportArgumentType] 

3003 icon=src.icon, 

3004 id=None if src.id is None else ModelId(src.id), 

3005 id_emoji=src.id_emoji, 

3006 license=src.license, # type: ignore 

3007 links=src.links, 

3008 maintainers=[ 

3009 _maintainer_conv.convert_as_dict(m) for m in src.maintainers 

3010 ], # pyright: ignore[reportArgumentType] 

3011 name=name, 

3012 tags=src.tags, 

3013 type=src.type, 

3014 uploader=src.uploader, 

3015 version=src.version, 

3016 inputs=[ # pyright: ignore[reportArgumentType] 

3017 _input_tensor_conv.convert_as_dict(ipt, tt, st, input_size_refs) 

3018 for ipt, tt, st in zip( 

3019 src.inputs, 

3020 src.test_inputs, 

3021 src.sample_inputs or [None] * len(src.test_inputs), 

3022 ) 

3023 ], 

3024 outputs=[ # pyright: ignore[reportArgumentType] 

3025 _output_tensor_conv.convert_as_dict(out, tt, st, output_size_refs) 

3026 for out, tt, st in zip( 

3027 src.outputs, 

3028 src.test_outputs, 

3029 src.sample_outputs or [None] * len(src.test_outputs), 

3030 ) 

3031 ], 

3032 parent=( 

3033 None 

3034 if src.parent is None 

3035 else LinkedModel( 

3036 id=ModelId( 

3037 str(src.parent.id) 

3038 + ( 

3039 "" 

3040 if src.parent.version_number is None 

3041 else f"/{src.parent.version_number}" 

3042 ) 

3043 ) 

3044 ) 

3045 ), 

3046 training_data=( 

3047 None 

3048 if src.training_data is None 

3049 else ( 

3050 LinkedDataset( 

3051 id=DatasetId( 

3052 str(src.training_data.id) 

3053 + ( 

3054 "" 

3055 if src.training_data.version_number is None 

3056 else f"/{src.training_data.version_number}" 

3057 ) 

3058 ) 

3059 ) 

3060 if isinstance(src.training_data, LinkedDataset02) 

3061 else src.training_data 

3062 ) 

3063 ), 

3064 packaged_by=[ 

3065 _author_conv.convert_as_dict(a) for a in src.packaged_by 

3066 ], # pyright: ignore[reportArgumentType] 

3067 run_mode=src.run_mode, 

3068 timestamp=src.timestamp, 

3069 weights=(WeightsDescr if TYPE_CHECKING else dict)( 

3070 keras_hdf5=(w := src.weights.keras_hdf5) 

3071 and (KerasHdf5WeightsDescr if TYPE_CHECKING else dict)( 

3072 authors=conv_authors(w.authors), 

3073 source=w.source, 

3074 tensorflow_version=w.tensorflow_version or Version("1.15"), 

3075 parent=w.parent, 

3076 ), 

3077 onnx=(w := src.weights.onnx) 

3078 and (OnnxWeightsDescr if TYPE_CHECKING else dict)( 

3079 source=w.source, 

3080 authors=conv_authors(w.authors), 

3081 parent=w.parent, 

3082 opset_version=w.opset_version or 15, 

3083 ), 

3084 pytorch_state_dict=(w := src.weights.pytorch_state_dict) 

3085 and (PytorchStateDictWeightsDescr if TYPE_CHECKING else dict)( 

3086 source=w.source, 

3087 authors=conv_authors(w.authors), 

3088 parent=w.parent, 

3089 architecture=( 

3090 arch_file_conv( 

3091 w.architecture, 

3092 w.architecture_sha256, 

3093 w.kwargs, 

3094 ) 

3095 if isinstance(w.architecture, _CallableFromFile_v0_4) 

3096 else arch_lib_conv(w.architecture, w.kwargs) 

3097 ), 

3098 pytorch_version=w.pytorch_version or Version("1.10"), 

3099 dependencies=( 

3100 None 

3101 if w.dependencies is None 

3102 else (EnvironmentFileDescr if TYPE_CHECKING else dict)( 

3103 source=cast( 

3104 ImportantFileSource, 

3105 str(deps := w.dependencies)[ 

3106 ( 

3107 len("conda:") 

3108 if str(deps).startswith("conda:") 

3109 else 0 

3110 ) : 

3111 ], 

3112 ) 

3113 ) 

3114 ), 

3115 ), 

3116 tensorflow_js=(w := src.weights.tensorflow_js) 

3117 and (TensorflowJsWeightsDescr if TYPE_CHECKING else dict)( 

3118 source=w.source, 

3119 authors=conv_authors(w.authors), 

3120 parent=w.parent, 

3121 tensorflow_version=w.tensorflow_version or Version("1.15"), 

3122 ), 

3123 tensorflow_saved_model_bundle=( 

3124 w := src.weights.tensorflow_saved_model_bundle 

3125 ) 

3126 and (TensorflowSavedModelBundleWeightsDescr if TYPE_CHECKING else dict)( 

3127 authors=conv_authors(w.authors), 

3128 parent=w.parent, 

3129 source=w.source, 

3130 tensorflow_version=w.tensorflow_version or Version("1.15"), 

3131 dependencies=( 

3132 None 

3133 if w.dependencies is None 

3134 else (EnvironmentFileDescr if TYPE_CHECKING else dict)( 

3135 source=cast( 

3136 ImportantFileSource, 

3137 ( 

3138 str(w.dependencies)[len("conda:") :] 

3139 if str(w.dependencies).startswith("conda:") 

3140 else str(w.dependencies) 

3141 ), 

3142 ) 

3143 ) 

3144 ), 

3145 ), 

3146 torchscript=(w := src.weights.torchscript) 

3147 and (TorchscriptWeightsDescr if TYPE_CHECKING else dict)( 

3148 source=w.source, 

3149 authors=conv_authors(w.authors), 

3150 parent=w.parent, 

3151 pytorch_version=w.pytorch_version or Version("1.10"), 

3152 ), 

3153 ), 

3154 ) 

3155 

3156 

3157_model_conv = _ModelConv(_ModelDescr_v0_4, ModelDescr) 

3158 

3159 

3160# create better cover images for 3d data and non-image outputs 

3161def generate_covers( 

3162 inputs: Sequence[Tuple[InputTensorDescr, NDArray[Any]]], 

3163 outputs: Sequence[Tuple[OutputTensorDescr, NDArray[Any]]], 

3164) -> List[Path]: 

3165 def squeeze( 

3166 data: NDArray[Any], axes: Sequence[AnyAxis] 

3167 ) -> Tuple[NDArray[Any], List[AnyAxis]]: 

3168 """apply numpy.ndarray.squeeze while keeping track of the axis descriptions remaining""" 

3169 if data.ndim != len(axes): 

3170 raise ValueError( 

3171 f"tensor shape {data.shape} does not match described axes" 

3172 + f" {[a.id for a in axes]}" 

3173 ) 

3174 

3175 axes = [deepcopy(a) for a, s in zip(axes, data.shape) if s != 1] 

3176 return data.squeeze(), axes 

3177 

3178 def normalize( 

3179 data: NDArray[Any], axis: Optional[Tuple[int, ...]], eps: float = 1e-7 

3180 ) -> NDArray[np.float32]: 

3181 data = data.astype("float32") 

3182 data -= data.min(axis=axis, keepdims=True) 

3183 data /= data.max(axis=axis, keepdims=True) + eps 

3184 return data 

3185 

3186 def to_2d_image(data: NDArray[Any], axes: Sequence[AnyAxis]): 

3187 original_shape = data.shape 

3188 data, axes = squeeze(data, axes) 

3189 

3190 # take a slice from any batch or index axis if needed, 

3191 # convert the first channel axis, and take a slice from any additional channel axes 

3192 slices: Tuple[slice, ...] = () 

3193 ndim = data.ndim 

3194 ndim_need = 3 if any(isinstance(a, ChannelAxis) for a in axes) else 2 

3195 has_c_axis = False 

3196 for i, a in enumerate(axes): 

3197 s = data.shape[i] 

3198 assert s > 1 

3199 if ( 

3200 isinstance(a, (BatchAxis, IndexInputAxis, IndexOutputAxis)) 

3201 and ndim > ndim_need 

3202 ): 

3203 data = data[slices + (slice(s // 2 - 1, s // 2),)] 

3204 ndim -= 1 

3205 elif isinstance(a, ChannelAxis): 

3206 if has_c_axis: 

3207 # second channel axis 

3208 data = data[slices + (slice(0, 1),)] 

3209 ndim -= 1 

3210 else: 

3211 has_c_axis = True 

3212 if s == 2: 

3213 # visualize two channels with cyan and magenta 

3214 data = np.concatenate( 

3215 [ 

3216 data[slices + (slice(1, 2),)], 

3217 data[slices + (slice(0, 1),)], 

3218 ( 

3219 data[slices + (slice(0, 1),)] 

3220 + data[slices + (slice(1, 2),)] 

3221 ) 

3222 / 2, # TODO: take maximum instead? 

3223 ], 

3224 axis=i, 

3225 ) 

3226 elif data.shape[i] == 3: 

3227 pass # visualize 3 channels as RGB 

3228 else: 

3229 # visualize first 3 channels as RGB 

3230 data = data[slices + (slice(3),)] 

3231 

3232 assert data.shape[i] == 3 

3233 

3234 slices += (slice(None),) 

3235 

3236 data, axes = squeeze(data, axes) 

3237 assert len(axes) == ndim 

3238 # take slice from z axis if needed 

3239 slices = () 

3240 if ndim > ndim_need: 

3241 for i, a in enumerate(axes): 

3242 s = data.shape[i] 

3243 if a.id == AxisId("z"): 

3244 data = data[slices + (slice(s // 2 - 1, s // 2),)] 

3245 data, axes = squeeze(data, axes) 

3246 ndim -= 1 

3247 break 

3248 

3249 slices += (slice(None),) 

3250 

3251 # take slice from any space or time axis 

3252 slices = () 

3253 

3254 for i, a in enumerate(axes): 

3255 if ndim <= ndim_need: 

3256 break 

3257 

3258 s = data.shape[i] 

3259 assert s > 1 

3260 if isinstance( 

3261 a, (SpaceInputAxis, SpaceOutputAxis, TimeInputAxis, TimeOutputAxis) 

3262 ): 

3263 data = data[slices + (slice(s // 2 - 1, s // 2),)] 

3264 ndim -= 1 

3265 

3266 slices += (slice(None),) 

3267 

3268 del slices 

3269 data, axes = squeeze(data, axes) 

3270 assert len(axes) == ndim 

3271 

3272 if (has_c_axis and ndim != 3) or (not has_c_axis and ndim != 2): 

3273 raise ValueError( 

3274 f"Failed to construct cover image from shape {original_shape}" 

3275 ) 

3276 

3277 if not has_c_axis: 

3278 assert ndim == 2 

3279 data = np.repeat(data[:, :, None], 3, axis=2) 

3280 axes.append(ChannelAxis(channel_names=list(map(Identifier, "RGB")))) 

3281 ndim += 1 

3282 

3283 assert ndim == 3 

3284 

3285 # transpose axis order such that longest axis comes first... 

3286 axis_order = list(np.argsort(list(data.shape))) 

3287 axis_order.reverse() 

3288 # ... and channel axis is last 

3289 c = [i for i in range(3) if isinstance(axes[i], ChannelAxis)][0] 

3290 axis_order.append(axis_order.pop(c)) 

3291 axes = [axes[ao] for ao in axis_order] 

3292 data = data.transpose(axis_order) 

3293 

3294 # h, w = data.shape[:2] 

3295 # if h / w in (1.0, 2.0): 

3296 # pass 

3297 # elif h / w < 2: 

3298 # TODO: enforce 2:1 or 1:1 aspect ratio for generated cover images 

3299 

3300 norm_along = ( 

3301 tuple(i for i, a in enumerate(axes) if a.type in ("space", "time")) or None 

3302 ) 

3303 # normalize the data and map to 8 bit 

3304 data = normalize(data, norm_along) 

3305 data = (data * 255).astype("uint8") 

3306 

3307 return data 

3308 

3309 def create_diagonal_split_image(im0: NDArray[Any], im1: NDArray[Any]): 

3310 assert im0.dtype == im1.dtype == np.uint8 

3311 assert im0.shape == im1.shape 

3312 assert im0.ndim == 3 

3313 N, M, C = im0.shape 

3314 assert C == 3 

3315 out = np.ones((N, M, C), dtype="uint8") 

3316 for c in range(C): 

3317 outc = np.tril(im0[..., c]) 

3318 mask = outc == 0 

3319 outc[mask] = np.triu(im1[..., c])[mask] 

3320 out[..., c] = outc 

3321 

3322 return out 

3323 

3324 ipt_descr, ipt = inputs[0] 

3325 out_descr, out = outputs[0] 

3326 

3327 ipt_img = to_2d_image(ipt, ipt_descr.axes) 

3328 out_img = to_2d_image(out, out_descr.axes) 

3329 

3330 cover_folder = Path(mkdtemp()) 

3331 if ipt_img.shape == out_img.shape: 

3332 covers = [cover_folder / "cover.png"] 

3333 imwrite(covers[0], create_diagonal_split_image(ipt_img, out_img)) 

3334 else: 

3335 covers = [cover_folder / "input.png", cover_folder / "output.png"] 

3336 imwrite(covers[0], ipt_img) 

3337 imwrite(covers[1], out_img) 

3338 

3339 return covers
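# Illustrative sketch, not part of the original module: generating cover
# images for a validated `model: ModelDescr`, mirroring the use in
# `ModelDescr._add_default_cover` above.
#
#   covers = generate_covers(
#       [(t, load_array(t.test_tensor.download().path)) for t in model.inputs],
#       [(t, load_array(t.test_tensor.download().path)) for t in model.outputs],
#   )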