Coverage for bioimageio/core/weight_converter/torch/_utils.py: 50%
14 statements
« prev ^ index » next coverage.py v7.6.7, created at 2024-11-19 09:02 +0000
« prev ^ index » next coverage.py v7.6.7, created at 2024-11-19 09:02 +0000
1from typing import Union
3from bioimageio.core.model_adapters._pytorch_model_adapter import PytorchModelAdapter
4from bioimageio.spec.model import v0_4, v0_5
5from bioimageio.spec.utils import download
7try:
8 import torch
9except ImportError:
10 torch = None
# Additional convenience loader for PyTorch state-dict weights; eventually we want this
# in python-bioimageio too, with an equivalent loader for each weight format.
def load_torch_model(  # pyright: ignore[reportUnknownParameterType]
    node: Union[v0_4.PytorchStateDictWeightsDescr, v0_5.PytorchStateDictWeightsDescr],
):
    """Build the network described by `node` and load its state-dict weights.

    The network is constructed with `PytorchModelAdapter.get_network(node)`.
    The weights file referenced by `node.source` is obtained through `download`
    and loaded with `torch.load(..., map_location="cpu")`, and the resulting
    state dict is applied to the network.

    Args:
        node: PyTorch state-dict weights description (bioimageio spec v0.4 or v0.5).

    Returns:
        The result of calling `.eval()` on the network after its weights are loaded.

    Raises:
        ImportError: if the optional `torch` dependency could not be imported.
    """
    # Explicit check instead of `assert`: asserts are removed under `python -O`,
    # which would turn a missing torch into an opaque AttributeError on `None.load`.
    if torch is None:
        raise ImportError(
            "load_torch_model requires `torch`, which could not be imported"
        )

    model = (  # pyright: ignore[reportUnknownVariableType]
        PytorchModelAdapter.get_network(node)
    )
    # NOTE(security): `torch.load` unpickles the file, which can execute arbitrary
    # code if the weights source is untrusted. Consider `weights_only=True` once the
    # minimum supported torch version allows it — TODO confirm compatibility.
    state = torch.load(download(node.source).path, map_location="cpu")
    # FIXME: check incompatible keys? (NOTE(review): with its default strict=True,
    # `load_state_dict` is documented to raise on missing/unexpected keys — confirm
    # whether a more descriptive report is still wanted.)
    model.load_state_dict(state)
    return model.eval()  # pyright: ignore[reportUnknownVariableType]