Coverage for bioimageio/core/axis.py: 82%
62 statements
« prev ^ index » next coverage.py v7.6.7, created at 2024-11-19 09:02 +0000
1from __future__ import annotations
3from dataclasses import dataclass
4from typing import Literal, Mapping, Optional, TypeVar, Union
6from typing_extensions import assert_never
8from bioimageio.spec.model import v0_5
def _get_axis_type(a: str) -> Literal["batch", "channel", "index", "space", "time"]:
    """Guess the axis type from a single-letter axis name or an axis id string.

    Args:
        a: Axis letter ("b", "t", "i", "c", "x", "y", "z") or a spelled-out
            axis id ("batch", "time", "index", "channel").

    Returns:
        The matching axis type. Any string that is not recognized maps to
        "index", the most unspecific axis type.
    """
    # `Axis.create` passes arbitrary strings here (not only single letters),
    # so spelled-out ids are recognized as well instead of silently
    # degrading to "index".
    if a in ("b", "batch"):
        return "batch"
    if a in ("t", "time"):
        return "time"
    if a in ("i", "index"):
        return "index"
    if a in ("c", "channel"):
        return "channel"
    if a in ("x", "y", "z"):
        return "space"

    return "index"  # return most unspecific axis
# Type variable restricted to `str` and its subclasses.
S = TypeVar("S", bound=str)

# Axis identifier type, taken from the v0_5 model description module.
AxisId = v0_5.AxisId

# Unconstrained type variable used to parameterize `PerAxis`.
T = TypeVar("T")
# Mapping from an axis id to some per-axis value of type `T`.
PerAxis = Mapping[AxisId, T]

BatchSize = int

# Single-letter axis names.
AxisLetter = Literal["b", "i", "t", "c", "z", "y", "x"]
# Everything `Axis.create` accepts as a description of an axis.
AxisLike = Union[AxisId, AxisLetter, v0_5.AnyAxis, "Axis"]
@dataclass
class Axis:
    """An axis described by its id and its type."""

    id: AxisId
    type: Literal["batch", "channel", "index", "space", "time"]

    @classmethod
    def create(cls, axis: AxisLike) -> Axis:
        """Build an `Axis` from any axis-like value.

        Args:
            axis: An existing `Axis`, a string (axis id or axis letter), or a
                `v0_5.AxisBase` description.

        Returns:
            `axis` itself if it is already an instance of `cls`; otherwise a
            new plain `Axis` carrying the id and type derived from `axis`.
        """
        if isinstance(axis, cls):
            # already the requested class: hand it back without copying
            return axis
        elif isinstance(axis, Axis):
            # an `Axis` that is not an instance of `cls`
            # (only possible when `cls` is a subclass): copy id and type
            return Axis(id=axis.id, type=axis.type)
        elif isinstance(axis, str):
            # the axis type is not given explicitly for strings; guess it
            return Axis(id=AxisId(axis), type=_get_axis_type(axis))
        elif isinstance(axis, v0_5.AxisBase):
            return Axis(id=AxisId(axis.id), type=axis.type)
        else:
            # exhaustiveness check over the members of `AxisLike`
            assert_never(axis)
@dataclass
class AxisInfo(Axis):
    """An `Axis` extended by whether the axis may have size one."""

    maybe_singleton: bool  # TODO: replace 'maybe_singleton' with size min/max for better axis guessing

    @classmethod
    def create(cls, axis: AxisLike, maybe_singleton: Optional[bool] = None) -> AxisInfo:
        """Build an `AxisInfo` from any axis-like value.

        Args:
            axis: An existing `AxisInfo`/`Axis`, a string (axis id or axis
                letter), or a v0_5 axis description.
            maybe_singleton: Whether the axis may have size one. If None, it
                is derived from `axis`: True for `Axis` and string inputs,
                otherwise determined from `axis.size`.

        Returns:
            `axis` itself if it is already an `AxisInfo` (in that case
            `maybe_singleton` is ignored); otherwise a new `AxisInfo`.
        """
        if isinstance(axis, AxisInfo):
            return axis

        axis_base = super().create(axis)
        if maybe_singleton is None:
            if isinstance(axis, (Axis, str)):
                # no size information available
                maybe_singleton = True
            elif axis.size is None:
                maybe_singleton = True
            elif isinstance(axis.size, int):
                maybe_singleton = axis.size == 1
            elif isinstance(axis.size, v0_5.SizeReference):
                maybe_singleton = (
                    True  # TODO: check if singleton is ok for a `SizeReference`
                )
            elif isinstance(
                axis.size, (v0_5.ParameterizedSize, v0_5.DataDependentSize)
            ):
                # A ValueError from `validate_size(1)` is treated as
                # "size one is not allowed"; otherwise the returned value
                # has to equal 1.
                try:
                    maybe_size_one = axis.size.validate_size(
                        1
                    )  # TODO: refactor validate_size() to have boolean func here
                except ValueError:
                    maybe_singleton = False
                else:
                    maybe_singleton = maybe_size_one == 1
            else:
                # exhaustiveness check over the possible size types
                assert_never(axis.size)

        return AxisInfo(
            id=axis_base.id, type=axis_base.type, maybe_singleton=maybe_singleton
        )