Source code for dasf.ml.inference.loader.torch

import inspect
import os

import torch

from .base import BaseLoader


class TorchLoader(BaseLoader):
    """ Model Loader for PyTorch models """

    def __init__(
        self, model_class_or_file, dtype=torch.float32, checkpoint=None, device=None
    ):
        """
        model_class_or_file: class or file with the model definition
        dtype: data type of the model input
        checkpoint: model checkpoint file
        device: device to place the model on ("cpu" or "gpu")
        """
        super().__init__()
        self.model_class_or_file = model_class_or_file
        self.dtype = dtype
        self.checkpoint = checkpoint
        self.device = device
    def load_model(self, **kwargs):
        device = torch.device("cuda" if self.device == "gpu" else "cpu")
        if inspect.isclass(self.model_class_or_file):
            # Instantiate the model class and optionally restore its weights.
            model = self.model_class_or_file(**kwargs)
            if self.checkpoint:
                state_dict = torch.load(self.checkpoint, map_location=device)
                # Some frameworks (e.g. PyTorch Lightning) nest the weights
                # under a "state_dict" key; unwrap them if present.
                state_dict = (
                    state_dict["state_dict"]
                    if "state_dict" in state_dict
                    else state_dict
                )
                model.load_state_dict(state_dict)
        elif os.path.isfile(self.model_class_or_file):
            # A fully serialized model: map it onto the target device on load.
            model = torch.load(self.model_class_or_file, map_location=device)
        else:
            raise ValueError(
                "model_class_or_file must be a model class or a path to a model file"
            )
        model.to(device=device)
        model.eval()
        return model
    def inference(self, model, data):
        data = torch.from_numpy(data)
        device = torch.device("cuda" if self.device == "gpu" else "cpu")
        data = data.to(device, dtype=self.dtype)
        # Disable gradient tracking for inference.
        with torch.no_grad():
            output = model(data)
        # Move the result back to host memory before converting to NumPy.
        return output.cpu().numpy() if self.device == "gpu" else output.numpy()
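
A minimal usage sketch, not part of the dasf source: SimpleNet, its layer sizes, and the random input below are hypothetical, used only to illustrate the load-then-infer flow on CPU.

import numpy as np

class SimpleNet(torch.nn.Module):
    # Hypothetical two-output linear model standing in for a real network.
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 2)

    def forward(self, x):
        return self.linear(x)

loader = TorchLoader(SimpleNet, dtype=torch.float32, device="cpu")
model = loader.load_model()  # instantiates SimpleNet and calls eval()
result = loader.inference(model, np.random.rand(8, 4).astype(np.float32))
print(result.shape)  # (8, 2)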