Coverage for src/bioimageio/core/_op_base.py: 88%
33 statements
coverage.py v7.13.5, created at 2026-03-30 13:23 +0000
1from abc import ABC, abstractmethod
2from dataclasses import dataclass
3from typing import Collection, Generic, Union
5from typing_extensions import TypeVar, assert_never
7from .axis import PerAxis
8from .block import Block
9from .common import MemberId
10from .sample import Sample, SampleBlock, SampleBlockWithOrigin
11from .stat_measures import (
12 Measure,
13 Stat,
14)
15from .tensor import Tensor
# Type variable for the sample-like object an operator acts on:
# a whole `Sample`, a `SampleBlock`, or a `SampleBlockWithOrigin`.
SampleT = TypeVar(
    "SampleT",
    bound=Union[Sample, SampleBlock, SampleBlockWithOrigin],
)
@dataclass
class Operator(Generic[SampleT], ABC):
    """Abstract root of every operator.

    An operator is called with a `SampleT` (a whole sample, a sample block,
    or a sample block with origin) and returns nothing, i.e. any effect it
    has is applied to the object it is given. Concrete operators must also
    expose `required_measures`, a collection of `Measure` objects.
    """

    @property
    @abstractmethod
    def required_measures(self) -> Collection[Measure]:
        """Collection of `Measure` objects associated with this operator."""

    @abstractmethod
    def __call__(self, sample: SampleT) -> None:
        """Apply the operator to `sample`; implementations return `None`."""
@dataclass
class SamplewiseOperator(Operator[Sample]):
    """Operator base whose type parameter is `Sample` alone.

    Use it for operators that must see the complete sample and therefore
    are not meant to be applied to individual sample blocks.
    """
@dataclass
class BlockwiseOperator(Operator[Union[Sample, SampleBlock]]):
    """Operator base that accepts either a whole `Sample` or a `SampleBlock`.

    Suitable for operators that may run on the full sample as well as
    block by block.
    """
@dataclass
class SimpleOperator(BlockwiseOperator):
    """Convenience base class for blockwise operators with a single input and single output.

    Subclasses implement `_apply` (the tensor transformation) and
    `get_output_shape` (the mapping from input shape to output shape).
    `__call__` reads the `input` member from a `Sample` or `SampleBlock`,
    transforms it and stores the result under `output`.
    """

    # Member id of the tensor that is read and transformed.
    input: MemberId
    # Member id under which the transformed tensor is stored.
    output: MemberId

    @abstractmethod
    def get_output_shape(self, input_shape: PerAxis[int]) -> PerAxis[int]:
        """Map the shape of the input member to the shape of the output member.

        When operating on a `SampleBlock`, this is applied to the input's
        full-sample shape to obtain the output block's `sample_shape`.
        """
        ...

    def __call__(self, sample: Union[Sample, SampleBlock]) -> None:
        """Apply the operator to `sample` in place.

        Does nothing if `sample` has no member `self.input`. Otherwise the
        input tensor is transformed by `_apply` (which also receives
        `sample.stat`) and the result is written under `self.output`:
        into `sample.members` for a `Sample`, or as a new `Block` in
        `sample.blocks` for a `SampleBlock`. The new block copies the input
        block's `inner_slice`, `halo`, `block_index` and `blocks_in_sample`.

        Raises:
            AssertionError: if `self.output` already exists in `sample` with a
                tagged shape different from the computed output (only while
                assertions are enabled, i.e. not under `python -O`).
        """
        if self.input not in sample.members:
            return

        input_tensor = sample.members[self.input]
        output_tensor = self._apply(input_tensor, sample.stat)

        # An already present output member must agree in shape with what we
        # are about to write; report both shapes to ease debugging.
        if self.output in sample.members:
            existing_shape = sample.members[self.output].tagged_shape
            assert existing_shape == output_tensor.tagged_shape, (
                f"{type(self).__name__}: output member '{self.output}' already "
                f"exists with shape {existing_shape}, but the computed output "
                f"has shape {output_tensor.tagged_shape}"
            )

        if isinstance(sample, Sample):
            sample.members[self.output] = output_tensor
        elif isinstance(sample, SampleBlock):
            input_block = sample.blocks[self.input]
            sample.blocks[self.output] = Block(
                # `sample_shape` is the shape of the whole sample this block is a
                # part of, so derive it from the input block's own `sample_shape`.
                # NOTE(review): the previous `sample.shape[self.input]` appears to
                # describe the (possibly halo-padded) block rather than the full
                # sample — confirm against `SampleBlock.shape` in sample.py.
                sample_shape=self.get_output_shape(input_block.sample_shape),
                data=output_tensor,
                inner_slice=input_block.inner_slice,
                halo=input_block.halo,
                block_index=input_block.block_index,
                blocks_in_sample=input_block.blocks_in_sample,
            )
        else:
            assert_never(sample)

    @abstractmethod
    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        """Transform the input tensor `x`.

        `stat` is the `stat` attribute of the sample or block being processed.
        """
        ...