Coverage for src / bioimageio / core / _op_base.py: 88%

33 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-03-30 13:23 +0000

1from abc import ABC, abstractmethod 

2from dataclasses import dataclass 

3from typing import Collection, Generic, Union 

4 

5from typing_extensions import TypeVar, assert_never 

6 

7from .axis import PerAxis 

8from .block import Block 

9from .common import MemberId 

10from .sample import Sample, SampleBlock, SampleBlockWithOrigin 

11from .stat_measures import ( 

12 Measure, 

13 Stat, 

14) 

15from .tensor import Tensor 

16 

# Type variable for the kind of sample an operator accepts; bounded by the
# sample containers imported from `.sample`.
SampleT = TypeVar(
    "SampleT",
    bound=Union[Sample, SampleBlock, SampleBlockWithOrigin],
)


@dataclass
class Operator(Generic[SampleT], ABC):
    """Abstract root of the operator hierarchy.

    Concrete operators are callables that take a sample of type ``SampleT``
    and return ``None``, and that declare the statistical measures they use.
    """

    @abstractmethod
    def __call__(self, sample: SampleT) -> None:
        """Apply this operator to ``sample``."""

    @property
    @abstractmethod
    def required_measures(self) -> Collection[Measure]:
        """Statistical measures this operator depends on.

        NOTE(review): presumably these are expected to be present in the
        sample's statistics when the operator is called — confirm.
        """


@dataclass
class SamplewiseOperator(Operator[Sample]):
    """Operator base class restricted to complete ``Sample`` objects.

    Operators deriving from this class are typed to accept whole samples only,
    not individual sample blocks.
    """


@dataclass
class BlockwiseOperator(Operator[Union[Sample, SampleBlock]]):
    """Operator base class accepting either a whole ``Sample`` or a ``SampleBlock``."""


@dataclass
class SimpleOperator(BlockwiseOperator):
    """Convenience base class for blockwise operators with a single input and single output.

    Subclasses implement :meth:`_apply` (input tensor -> output tensor) and
    :meth:`get_output_shape`. :meth:`__call__` reads the ``input`` member,
    computes the result and stores it under ``output``, both for whole
    samples and for sample blocks.
    """

    input: MemberId  # member the operator reads from
    output: MemberId  # member the operator writes to

    @abstractmethod
    def get_output_shape(self, input_shape: PerAxis[int]) -> PerAxis[int]:
        """Return the output shape produced for an input of shape ``input_shape``."""

    def __call__(self, sample: Union[Sample, SampleBlock]) -> None:
        """Compute ``self.output`` from ``self.input`` and store it in ``sample``.

        Does nothing if ``self.input`` is not a member of ``sample``.

        Raises:
            ValueError: if ``self.output`` already exists in ``sample`` with a
                tagged shape that differs from the computed output's.
        """
        if self.input not in sample.members:
            return

        input_tensor = sample.members[self.input]
        output_tensor = self._apply(input_tensor, sample.stat)

        # Explicit check instead of `assert`, which is stripped under `python -O`
        # and would then let a mismatching output silently replace the member.
        if self.output in sample.members:
            existing_shape = sample.members[self.output].tagged_shape
            if existing_shape != output_tensor.tagged_shape:
                raise ValueError(
                    f"Output member '{self.output}' already exists with shape"
                    f" {existing_shape}, which differs from the computed output"
                    f" shape {output_tensor.tagged_shape}."
                )

        if isinstance(sample, Sample):
            sample.members[self.output] = output_tensor
        elif isinstance(sample, SampleBlock):
            # The output block copies the input block's inner slice, halo,
            # block index and block count.
            b = sample.blocks[self.input]
            sample.blocks[self.output] = Block(
                # NOTE(review): `sample.shape[self.input]` may be the shape of
                # this block rather than of the full sample it belongs to —
                # confirm it is the intended `sample_shape`.
                sample_shape=self.get_output_shape(sample.shape[self.input]),
                data=output_tensor,
                inner_slice=b.inner_slice,
                halo=b.halo,
                block_index=b.block_index,
                blocks_in_sample=b.blocks_in_sample,
            )
        else:
            assert_never(sample)

    @abstractmethod
    def _apply(self, x: Tensor, stat: Stat) -> Tensor:
        """Compute the output tensor from input tensor ``x`` and statistics ``stat``."""