# Source code for dice_ml.dice

"""Module that selects among the DiCE implementations for different frameworks, such as TensorFlow or PyTorch."""

import tensorflow as tf


class Dice:
    """Interface class that dispatches to a concrete DiCE implementation.

    Instantiating ``Dice`` replaces the new instance's class with the class
    returned by :func:`decide` for the given interfaces. It then runs
    initialisation again under that class, so callers receive an object of
    the backend-specific type.
    """

    def __init__(self, data_interface, model_interface, **kwargs):
        """Initialise by delegating to the backend-specific implementation.

        :param data_interface: an interface to access data related params.
        :param model_interface: an interface to access the output or gradients
            of a trained ML model.
        :param kwargs: extra keyword arguments, forwarded unchanged to the
            selected implementation's ``__init__``.
        """
        self.decide_implementation_type(data_interface, model_interface, **kwargs)

    def decide_implementation_type(self, data_interface, model_interface, **kwargs):
        """Re-class this instance as the selected DiCE implementation, then initialise it.

        :param data_interface: passed to :func:`decide` and then to the
            selected class's ``__init__``.
        :param model_interface: passed to :func:`decide` and then to the
            selected class's ``__init__``.
        :param kwargs: forwarded unchanged to the selected class's ``__init__``.
        """
        chosen_impl = decide(data_interface, model_interface)
        # After rebinding __class__, ``self.__init__`` resolves on chosen_impl,
        # so the implementation's own initialiser runs with the same arguments.
        self.__class__ = chosen_impl
        self.__init__(data_interface, model_interface, **kwargs)
# To add a new DiCE implementation, add its class to the explainer_interfaces subpackage, then import and return that class in a new elif branch of the method below.
def decide(data_interface, model_interface):
    """Decide which DiCE implementation class to use for the given model.

    :param data_interface: an interface to access data related params. It is
        not consulted by the current selection logic.
    :param model_interface: an interface to the trained ML model. Its
        ``backend`` attribute selects the implementation. It is either one
        of the names ``'TF1'``, ``'TF2'`` or ``'PYT'``, or a mapping whose
        ``'explainer'`` entry names the class as ``"<module>.<ClassName>"``
        relative to ``dice_ml.explainer_interfaces``. The module part may
        itself be dotted, e.g. ``"subpkg.module.ClassName"``.
    :return: the selected DiCE implementation class (not an instance).
    :raises TypeError: if ``backend`` is a string other than the names above.
    """
    import importlib

    backend = model_interface.backend
    if backend == 'TF1':
        # pretrained Keras Sequential model with Tensorflow 1.x backend
        from dice_ml.explainer_interfaces.dice_tensorflow1 import DiceTensorFlow1
        return DiceTensorFlow1
    elif backend == 'TF2':
        # pretrained Keras Sequential model with Tensorflow 2.x backend
        from dice_ml.explainer_interfaces.dice_tensorflow2 import DiceTensorFlow2
        return DiceTensorFlow2
    elif backend == 'PYT':
        # PyTorch backend
        from dice_ml.explainer_interfaces.dice_pytorch import DicePyTorch
        return DicePyTorch

    # All other backends: import the class named by backend['explainer'].
    if isinstance(backend, str):
        # Indexing a str with 'explainer' would fail with an opaque
        # "string indices must be integers" TypeError; say what was expected.
        raise TypeError(
            "Unsupported backend {!r}: expected 'TF1', 'TF2', 'PYT', or a dict "
            "with an 'explainer' key.".format(backend)
        )
    backend_dice = backend['explainer']
    # Split on the last dot only, so the module part may contain dots itself.
    module_name, class_name = backend_dice.rsplit('.', 1)
    module = importlib.import_module("dice_ml.explainer_interfaces." + module_name)
    return getattr(module, class_name)