Coverage for src/bioimageio/core/remote_backends/gradio/serializer.py: 0%

44 statements  

« prev     ^ index     » next       coverage.py v7.14.2, created at 2026-06-22 16:54 +0000

1import tempfile 

2from pathlib import Path 

3from typing import Dict, List, Mapping, Union 

4 

5import numpy as np 

6from gradio_client import handle_file 

7from pydantic import BaseModel 

8from typing_extensions import Self 

9 

10from ..._common_annotations import PerMemberAnno 

11from ..._description_serializer import DescriptionSerializer as DescriptionSerializer 

12from ..._sample_serializer import SampleSerializer 

13from ...common import MemberId 

14from ...io import JsonValue, load_stat, save_tensor, serialize_stat 

15from ...sample import SampleBlock, SampleBlockMeta 

16from ...tensor import Tensor 

17 

18 

class _SerializableBlock(BaseModel, frozen=True):
    """Reference to one tensor stored in a `.npy` file.

    The fields match the mapping returned by gradio's `handle_file`, so an
    instance is created by validating that mapping directly.
    """

    path: Path  # location of the `.npy` file holding the tensor data
    meta: Mapping[str, str]  # string metadata passed through from `handle_file`
    orig_name: str  # file name as reported by `handle_file`

    @classmethod
    def from_tensor(cls, tensor: Tensor) -> Self:
        """Save `tensor` to a temporary `.npy` file and wrap it as a file reference.

        On success the temporary file is deliberately left on disk. The returned
        object only refers to it by path, so it has to outlive this call
        (NOTE(review): presumably it is consumed/uploaded later by the gradio
        client — confirm who is responsible for removing it).
        If saving or wrapping fails, the temporary file is removed again before
        the exception propagates.

        Args:
            tensor: tensor to serialize.

        Returns:
            A `_SerializableBlock` validated from the `handle_file` result for
            the written file.
        """
        # Only reserve a unique file name here. The handle is closed before
        # writing, so `save_tensor` never has to reopen a named temporary file
        # that is still open (whether that works is platform-dependent).
        with tempfile.NamedTemporaryFile(suffix=".npy", delete=False) as tmp:
            tmp_name = tmp.name

        try:
            save_tensor(tmp_name, tensor)
            handled = handle_file(Path(tmp_name))
            return cls.model_validate(handled)
        except BaseException:
            # With `delete=False` nothing else cleans up a half-written or
            # unused file, so remove it here and re-raise the original error.
            Path(tmp_name).unlink(missing_ok=True)
            raise

31 

32 

class _SerializableSampleBlock(BaseModel):
    """JSON-serializable form of a `SampleBlock`.

    Tensor data is not stored inline; each member is represented by a reference
    to a `.npy` file on disk.
    """

    model_config = {"frozen": True}

    # Block metadata as produced by `SampleBlock.get_meta()`.
    meta: SampleBlockMeta
    # Per-member file reference: either a full `_SerializableBlock` record or a
    # bare path (NOTE(review): presumably the bare-path form covers payloads in
    # which only the file path is sent back — confirm).
    data: PerMemberAnno[Union[_SerializableBlock, Path]]
    # Sample statistics in the JSON form produced by `serialize_stat`.
    serialized_stat: List[JsonValue]

37 

38 

# Serialized form of a sample block: the return type of
# `GradioSampleSerializer.serialize_sample_block` and the input of
# `GradioSampleSerializer.deserialize_sample_block`. It holds the JSON-mode
# `model_dump` of a `_SerializableSampleBlock`, i.e. a dict with string keys and
# JSON-compatible values.
SerializedSampleBlock = Dict[str, JsonValue]

40 

41 

class GradioSampleSerializer(SampleSerializer[SerializedSampleBlock]):
    """Convert `SampleBlock`s to and from JSON-compatible dicts.

    Member tensors are not embedded in the dict. Each one is written to a `.npy`
    file and referenced by path; the block metadata and statistics are carried
    alongside in JSON form.
    """

    @staticmethod
    def serialize_sample_block(sample_block: SampleBlock) -> SerializedSampleBlock:
        """Serialize `sample_block` into a JSON-compatible dict.

        Args:
            sample_block: block whose members, metadata and stats are serialized.

        Returns:
            The JSON-mode dump of a `_SerializableSampleBlock` that references one
            `.npy` file per member.
        """
        file_refs: Dict[MemberId, _SerializableBlock] = {
            member_id: _SerializableBlock.from_tensor(tensor)
            for member_id, tensor in sample_block.members.items()
        }
        return _SerializableSampleBlock(
            data=file_refs,
            meta=sample_block.get_meta(),
            serialized_stat=serialize_stat(sample_block.stat),
        ).model_dump(mode="json")

    @staticmethod
    def deserialize_sample_block(serialized: SerializedSampleBlock) -> SampleBlock:
        """Rebuild a `SampleBlock` from the output of `serialize_sample_block`.

        Each member's array is loaded from its referenced `.npy` file. Its axes
        are taken, in order, from the member's entry in the metadata `shape`,
        because the files hold only raw array data.

        Args:
            serialized: dict produced by `serialize_sample_block` (or an
                equivalent payload that validates as `_SerializableSampleBlock`).

        Returns:
            The reconstructed sample block, including its statistics.
        """
        parsed = _SerializableSampleBlock.model_validate(serialized)
        block_meta = parsed.meta

        tensors: Dict[MemberId, Tensor] = {}
        for member_id, ref in parsed.data.items():
            # A reference may be a bare path or a full file record.
            npy_path = ref if isinstance(ref, Path) else ref.path
            tensors[member_id] = Tensor.from_numpy(
                np.load(npy_path),
                dims=list(block_meta.shape[member_id]),
            )

        return SampleBlock.from_meta(
            block_meta,
            data=tensors,
            stat=load_stat(parsed.serialized_stat),
        )