Coverage for bioimageio/core/weight_converter/torch/_utils.py: 50%

14 statements  

« prev     ^ index     » next       coverage.py v7.6.7, created at 2024-11-19 09:02 +0000

1from typing import Union 

2 

3from bioimageio.core.model_adapters._pytorch_model_adapter import PytorchModelAdapter 

4from bioimageio.spec.model import v0_4, v0_5 

5from bioimageio.spec.utils import download 

6 

7try: 

8 import torch 

9except ImportError: 

10 torch = None 

11 

12 

# additional convenience for pytorch state dict, eventually we want this in python-bioimageio too
# and for each weight format
def load_torch_model(  # pyright: ignore[reportUnknownParameterType]
    node: Union[v0_4.PytorchStateDictWeightsDescr, v0_5.PytorchStateDictWeightsDescr],
):
    """Build the PyTorch network described by `node` and load its weights.

    The network is constructed via `PytorchModelAdapter.get_network(node)`.
    The weights file referenced by `node.source` is downloaded and loaded
    with `map_location="cpu"`. The resulting state dict is then applied to
    the network.

    Args:
        node: pytorch_state_dict weights description (spec v0.4 or v0.5).

    Returns:
        The network with the loaded state dict, switched to evaluation mode
        (`Module.eval()` returns the module itself).

    Raises:
        ImportError: if `torch` could not be imported.
        RuntimeError: raised by `load_state_dict` (strict mode) when the
            state dict keys do not match the network's keys.
    """
    if torch is None:
        # Explicit check instead of `assert`: asserts are stripped under
        # `python -O`, which would turn this into an opaque AttributeError below.
        raise ImportError("load_torch_model requires 'torch', which is not installed")

    model = PytorchModelAdapter.get_network(node)  # pyright: ignore[reportUnknownVariableType]
    weights_path = download(node.source).path
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code -- only load weights from trusted sources. Consider passing
    # weights_only=True once the supported torch version range allows it.
    state = torch.load(weights_path, map_location="cpu")
    # strict=True (PyTorch's default, spelled out here) raises on missing or
    # unexpected keys, so incompatible state dicts are not silently accepted.
    model.load_state_dict(state, strict=True)
    return model.eval()  # pyright: ignore[reportUnknownVariableType]