"""Module containing an interface to a trained PyTorch model."""
from dice_ml.model_interfaces.base_model import BaseModel
import torch
class PyTorchModel(BaseModel):
    """Interface to a trained PyTorch model, exposed through ``BaseModel``."""

    def __init__(self, model=None, model_path='', backend='PYT'):
        """Init method

        :param model: trained PyTorch Model.
        :param model_path: path to trained model.
        :param backend: "PYT" for PyTorch framework.
        """
        super().__init__(model, model_path, backend)

    def load_model(self, map_location=None, **load_kwargs):
        """Load the model stored at ``self.model_path`` into ``self.model``.

        When no path is set (empty string or ``None``), ``self.model`` is
        left unchanged.

        :param map_location: forwarded to ``torch.load``; for example
            ``'cpu'`` to load a GPU-saved model on a CPU-only machine.
            ``None`` (the default) reproduces torch's default behaviour.
        :param load_kwargs: any further keyword arguments for ``torch.load``
            (for example ``weights_only``, whose default differs between
            PyTorch releases).
        """
        # A truthiness check also covers model_path=None, which the previous
        # ``!= ''`` comparison would have passed straight to torch.load.
        if not self.model_path:
            return
        # SECURITY: torch.load unpickles the file, which can execute arbitrary
        # code. Only load model files from trusted sources.
        self.model = torch.load(
            self.model_path, map_location=map_location, **load_kwargs
        )

    def get_output(self, input_tensor):
        """Run the wrapped model on ``input_tensor``.

        :param input_tensor: input passed directly to ``self.model``.
        :return: the model's output converted with ``Tensor.float()``.
        """
        return self.model(input_tensor).float()

    def set_eval_mode(self):
        """Switch the wrapped model into evaluation mode via ``.eval()``."""
        self.model.eval()

    def get_gradient(self, input):
        """Gradient computation is not implemented for PyTorch models yet.

        :param input: unused. The name shadows the ``input`` builtin but is
            kept so that keyword callers keep working.
        :raises NotImplementedError: always.
        """
        # Future Support
        raise NotImplementedError("Future Support")