
"""Module containing a template class to generate counterfactual explanations.
   Subclasses implement interfaces for different ML frameworks such as TensorFlow or PyTorch.
   All methods are in dice_ml.explainer_interfaces"""

from abc import ABC, abstractmethod
from collections.abc import Iterable

import numpy as np
import pandas as pd
from sklearn.neighbors import KDTree
from tqdm import tqdm

from dice_ml.constants import ModelTypes
from dice_ml.counterfactual_explanations import CounterfactualExplanations
from dice_ml.utils.exception import UserConfigValidationException


class ExplainerBase(ABC):

    def __init__(self, data_interface, model_interface=None):
        """Init method

        :param data_interface: an interface class to access data related params.
        :param model_interface: an interface class to access trained ML model.
        """
        # initiating data and model related parameters
        self.data_interface = data_interface
        if model_interface is not None:
            self.model = model_interface
            self.model.load_model()  # loading pickled trained model if applicable
            self.model.transformer.feed_data_params(data_interface)
            self.model.transformer.initialize_transform_func()
        # Note: computation of min/max ranges for continuous features, encoded
        # categorical feature indexes, and decimal precisions has moved to
        # methods in public_data_interface.
    def generate_counterfactuals(self, query_instances, total_CFs,
                                 desired_class="opposite", desired_range=None,
                                 permitted_range=None, features_to_vary="all",
                                 stopping_threshold=0.5, posthoc_sparsity_param=0.1,
                                 posthoc_sparsity_algorithm="linear", verbose=False, **kwargs):
        """General method for generating counterfactuals.

        :param query_instances: Input point(s) for which counterfactuals are to be generated.
                                This can be a dataframe with one or more rows.
        :param total_CFs: Total number of counterfactuals required.
        :param desired_class: Desired counterfactual class - can take 0 or 1. Default value is
                              "opposite" to the outcome class of query_instance for binary classification.
        :param desired_range: For regression problems. Contains the outcome range to
                              generate counterfactuals in.
        :param permitted_range: Dictionary with feature names as keys and permitted range in list as values.
                                Defaults to the range inferred from training data.
                                If None, uses the parameters initialized in data_interface.
        :param features_to_vary: Either a string "all" or a list of feature names to vary.
        :param stopping_threshold: Minimum threshold for the counterfactual target class probability.
        :param posthoc_sparsity_param: Parameter for the post-hoc operation on continuous features
                                       to enhance sparsity.
        :param posthoc_sparsity_algorithm: Perform either linear or binary search. Takes "linear" or "binary".
                                           Prefer binary search when a feature range is large
                                           (for instance, income varying from 10k to 1000k) and only if the
                                           features share a monotonic relationship with the predicted outcome
                                           in the model.
        :param verbose: Whether to output detailed messages.
        :param kwargs: Other parameters accepted by a specific explanation method, e.g.
                       sample_size (sampling size) and random_seed (random seed for reproducibility).

        :returns: A CounterfactualExplanations object that contains the list of
                  counterfactual examples per query_instance as one of its attributes.
        """
        if total_CFs <= 0:
            raise UserConfigValidationException(
                "The number of counterfactuals generated per query instance (total_CFs) "
                "should be a positive integer.")

        cf_examples_arr = []
        query_instances_list = []
        if isinstance(query_instances, pd.DataFrame):
            for ix in range(query_instances.shape[0]):
                query_instances_list.append(query_instances[ix:(ix+1)])
        elif isinstance(query_instances, Iterable):
            query_instances_list = query_instances
        for query_instance in tqdm(query_instances_list):
            self.data_interface.set_continuous_feature_indexes(query_instance)
            res = self._generate_counterfactuals(
                query_instance, total_CFs,
                desired_class=desired_class,
                desired_range=desired_range,
                permitted_range=permitted_range,
                features_to_vary=features_to_vary,
                stopping_threshold=stopping_threshold,
                posthoc_sparsity_param=posthoc_sparsity_param,
                posthoc_sparsity_algorithm=posthoc_sparsity_algorithm,
                verbose=verbose,
                **kwargs)
            cf_examples_arr.append(res)
        return CounterfactualExplanations(cf_examples_list=cf_examples_arr)
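    # Illustrative usage sketch (not part of this class). Assumes `train_df`
    # is a pandas dataframe with an 'income' outcome column and `clf` is an
    # already-fitted sklearn classifier; these names are hypothetical:
    #
    #   import dice_ml
    #   d = dice_ml.Data(dataframe=train_df,
    #                    continuous_features=['age', 'hours_per_week'],
    #                    outcome_name='income')
    #   m = dice_ml.Model(model=clf, backend='sklearn')
    #   exp = dice_ml.Dice(d, m, method='random')
    #   cf = exp.generate_counterfactuals(train_df.drop(columns='income').head(2),
    #                                     total_CFs=4, desired_class="opposite")
    #   cf.visualize_as_dataframe(show_only_changes=True)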
    @abstractmethod
    def _generate_counterfactuals(self, query_instance, total_CFs,
                                  desired_class="opposite", desired_range=None,
                                  permitted_range=None, features_to_vary="all",
                                  stopping_threshold=0.5, posthoc_sparsity_param=0.1,
                                  posthoc_sparsity_algorithm="linear", verbose=False, **kwargs):
        """Internal method for generating counterfactuals for a given query instance.
        Any explainer class inheriting from this class needs to implement this abstract method.

        :param query_instance: Input point for which counterfactuals are to be generated.
                               This can be a dataframe with one row.
        :param total_CFs: Total number of counterfactuals required.
        :param desired_class: Desired counterfactual class - can take 0 or 1. Default value is
                              "opposite" to the outcome class of query_instance for binary classification.
        :param desired_range: For regression problems. Contains the outcome range to
                              generate counterfactuals in.
        :param permitted_range: Dictionary with feature names as keys and permitted range in list as values.
                                Defaults to the range inferred from training data.
                                If None, uses the parameters initialized in data_interface.
        :param features_to_vary: Either a string "all" or a list of feature names to vary.
        :param stopping_threshold: Minimum threshold for the counterfactual target class probability.
        :param posthoc_sparsity_param: Parameter for the post-hoc operation on continuous features
                                       to enhance sparsity.
        :param posthoc_sparsity_algorithm: Perform either linear or binary search. Takes "linear" or "binary".
                                           Prefer binary search when a feature range is large
                                           (for instance, income varying from 10k to 1000k) and only if the
                                           features share a monotonic relationship with the predicted outcome
                                           in the model.
        :param verbose: Whether to output detailed messages.
        :param kwargs: Other parameters accepted by a specific explanation method, e.g.
                       sample_size (sampling size) and random_seed (random seed for reproducibility).

        :returns: A CounterfactualExplanations object that contains the list of
                  counterfactual examples per query_instance as one of its attributes.
        """
        pass
    def setup(self, features_to_vary, permitted_range, query_instance, feature_weights):
        self.data_interface.check_features_to_vary(features_to_vary=features_to_vary)
        self.data_interface.check_permitted_range(permitted_range)

        if features_to_vary == 'all':
            features_to_vary = self.data_interface.feature_names

        if permitted_range is None:  # use the precomputed default
            self.feature_range = self.data_interface.permitted_range
            feature_ranges_orig = self.feature_range
        else:  # compute the new ranges based on user input
            self.feature_range, feature_ranges_orig = \
                self.data_interface.get_features_range(permitted_range)

        self.check_query_instance_validity(features_to_vary, permitted_range,
                                           query_instance, feature_ranges_orig)
        # check feature MAD validity and throw warnings
        self.data_interface.check_mad_validity(feature_weights)

        return features_to_vary
    def check_query_instance_validity(self, features_to_vary, permitted_range,
                                      query_instance, feature_ranges_orig):
        for feature in query_instance:
            if feature == self.data_interface.outcome_name:
                raise ValueError("Target", self.data_interface.outcome_name, "present in query instance")

            if feature not in self.data_interface.feature_names:
                raise ValueError("Feature", feature, "not present in training data!")

        for feature in self.data_interface.categorical_feature_names:
            if query_instance[feature].values[0] not in feature_ranges_orig[feature]:
                raise ValueError("Feature", feature, "has a value outside the dataset.")

        # A feature that is not allowed to vary must already lie inside any
        # user-specified permitted range. This applies to both continuous and
        # categorical features.
        if permitted_range is not None:
            for feature in permitted_range:
                if feature in features_to_vary:
                    continue
                if feature in self.data_interface.continuous_feature_names:
                    if not permitted_range[feature][0] <= query_instance[feature].values[0] <= permitted_range[feature][1]:
                        raise ValueError("Feature:", feature, "is outside the permitted range and isn't allowed to vary.")
                elif feature in self.data_interface.categorical_feature_names:
                    if query_instance[feature].values[0] not in self.feature_range[feature]:
                        raise ValueError("Feature:", feature, "is outside the permitted range and isn't allowed to vary.")
    def local_feature_importance(self, query_instances, cf_examples_list=None,
                                 total_CFs=10,
                                 desired_class="opposite", desired_range=None,
                                 permitted_range=None, features_to_vary="all",
                                 stopping_threshold=0.5, posthoc_sparsity_param=0.1,
                                 posthoc_sparsity_algorithm="linear", **kwargs):
        """Estimate local feature importance scores for the given inputs.

        :param query_instances: A list of inputs for which to compute the
                                feature importances. These can be provided as a dataframe.
        :param cf_examples_list: If precomputed, a list of counterfactual
                                 examples for every input point. If cf_examples_list is provided, then
                                 all the following parameters are ignored.
        :param total_CFs: The number of counterfactuals to generate per input
                          (default is 10)
        :param other_parameters: These are the same as the generate_counterfactuals method.

        :returns: An object of class CounterfactualExplanations that includes
                  the list of counterfactuals per input and the local feature
                  importances per input.
        """
        if cf_examples_list is not None:
            if any(len(cf_examples.final_cfs_df) < 10 for cf_examples in cf_examples_list):
                raise UserConfigValidationException(
                    "The number of counterfactuals generated per query instance should be "
                    "greater than or equal to 10")
        elif total_CFs < 10:
            raise UserConfigValidationException(
                "The number of counterfactuals generated per query instance should be "
                "greater than or equal to 10")
        importances = self.feature_importance(
            query_instances,
            cf_examples_list=cf_examples_list,
            total_CFs=total_CFs,
            local_importance=True,
            global_importance=False,
            desired_class=desired_class,
            desired_range=desired_range,
            permitted_range=permitted_range,
            features_to_vary=features_to_vary,
            stopping_threshold=stopping_threshold,
            posthoc_sparsity_param=posthoc_sparsity_param,
            posthoc_sparsity_algorithm=posthoc_sparsity_algorithm,
            **kwargs)
        return importances
    def global_feature_importance(self, query_instances, cf_examples_list=None,
                                  total_CFs=10, local_importance=True,
                                  desired_class="opposite", desired_range=None,
                                  permitted_range=None, features_to_vary="all",
                                  stopping_threshold=0.5, posthoc_sparsity_param=0.1,
                                  posthoc_sparsity_algorithm="linear", **kwargs):
        """Estimate global feature importance scores for the given inputs.

        :param query_instances: A list of inputs for which to compute the
                                feature importances. These can be provided as a dataframe.
        :param cf_examples_list: If precomputed, a list of counterfactual
                                 examples for every input point. If cf_examples_list is provided, then
                                 all the following parameters are ignored.
        :param total_CFs: The number of counterfactuals to generate per input
                          (default is 10)
        :param local_importance: Binary flag indicating whether local feature
                                 importance values should also be returned for each query instance.
        :param other_parameters: These are the same as the generate_counterfactuals method.

        :returns: An object of class CounterfactualExplanations that includes
                  the list of counterfactuals per input, local feature importances per
                  input, and the global feature importance summarized over all inputs.
        """
        if query_instances is not None and len(query_instances) < 10:
            raise UserConfigValidationException(
                "The number of query instances should be greater than or equal to 10")
        if cf_examples_list is not None:
            if any(len(cf_examples.final_cfs_df) < 10 for cf_examples in cf_examples_list):
                raise UserConfigValidationException(
                    "The number of counterfactuals generated per query instance should be "
                    "greater than or equal to 10")
        elif total_CFs < 10:
            raise UserConfigValidationException(
                "The number of counterfactuals generated per query instance should be greater "
                "than or equal to 10")
        importances = self.feature_importance(
            query_instances,
            cf_examples_list=cf_examples_list,
            total_CFs=total_CFs,
            local_importance=local_importance,
            global_importance=True,
            desired_class=desired_class,
            desired_range=desired_range,
            permitted_range=permitted_range,
            features_to_vary=features_to_vary,
            stopping_threshold=stopping_threshold,
            posthoc_sparsity_param=posthoc_sparsity_param,
            posthoc_sparsity_algorithm=posthoc_sparsity_algorithm,
            **kwargs)
        return importances
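    # Illustrative usage sketch (hypothetical `exp` and `query_df`, constructed
    # as in the generate_counterfactuals example above; global importance
    # requires at least 10 query rows and at least 10 CFs per row):
    #
    #   imp = exp.global_feature_importance(query_df.head(10), total_CFs=10)
    #   print(imp.summary_importance)   # e.g. {'age': 0.72, 'education': 0.31, ...}
    #   loc = exp.local_feature_importance(query_df.head(2), total_CFs=10)
    #   print(loc.local_importance)     # one dict of per-feature scores per query row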
    def feature_importance(self, query_instances, cf_examples_list=None,
                           total_CFs=10, local_importance=True, global_importance=True,
                           desired_class="opposite", desired_range=None,
                           permitted_range=None, features_to_vary="all",
                           stopping_threshold=0.5, posthoc_sparsity_param=0.1,
                           posthoc_sparsity_algorithm="linear", **kwargs):
        """Estimate feature importance scores for the given inputs.

        :param query_instances: A list of inputs for which to compute the
                                feature importances. These can be provided as a dataframe.
        :param cf_examples_list: If precomputed, a list of counterfactual
                                 examples for every input point. If cf_examples_list is provided, then
                                 all the following parameters are ignored.
        :param total_CFs: The number of counterfactuals to generate per input
                          (default is 10)
        :param other_parameters: These are the same as the generate_counterfactuals method.

        :returns: An object of class CounterfactualExplanations that includes
                  the list of counterfactuals per input, local feature importances per
                  input, and the global feature importance summarized over all inputs.
        """
        if cf_examples_list is None:
            cf_examples_list = self.generate_counterfactuals(
                query_instances, total_CFs,
                desired_class=desired_class,
                desired_range=desired_range,
                permitted_range=permitted_range,
                features_to_vary=features_to_vary,
                stopping_threshold=stopping_threshold,
                posthoc_sparsity_param=posthoc_sparsity_param,
                posthoc_sparsity_algorithm=posthoc_sparsity_algorithm,
                **kwargs).cf_examples_list
        allcols = self.data_interface.categorical_feature_names + \
            self.data_interface.continuous_feature_names
        summary_importance = None
        local_importances = None
        if global_importance:
            # initializing the global importance vector
            summary_importance = {col: 0 for col in allcols}

        if local_importance:
            # initializing local importance for each query instance
            local_importances = [{col: 0 for col in allcols}
                                 for _ in range(len(cf_examples_list))]

        overall_num_cfs = 0
        # summarizing the found counterfactuals
        for i in range(len(cf_examples_list)):
            cf_examples = cf_examples_list[i]
            org_instance = cf_examples.test_instance_df

            if cf_examples.final_cfs_df_sparse is not None:
                df = cf_examples.final_cfs_df_sparse
            else:
                df = cf_examples.final_cfs_df

            if df is None:
                continue

            per_query_point_cfs = 0
            for _, row in df.iterrows():
                per_query_point_cfs += 1
                for col in self.data_interface.continuous_feature_names:
                    if not np.isclose(org_instance[col].iat[0], row[col]):
                        if summary_importance is not None:
                            summary_importance[col] += 1
                        if local_importances is not None:
                            local_importances[i][col] += 1
                for col in self.data_interface.categorical_feature_names:
                    if org_instance[col].iat[0] != row[col]:
                        if summary_importance is not None:
                            summary_importance[col] += 1
                        if local_importances is not None:
                            local_importances[i][col] += 1

            if local_importances is not None and per_query_point_cfs > 0:
                for col in allcols:
                    local_importances[i][col] /= per_query_point_cfs

            overall_num_cfs += per_query_point_cfs

        if summary_importance is not None and overall_num_cfs > 0:
            for col in allcols:
                summary_importance[col] /= overall_num_cfs

        return CounterfactualExplanations(
            cf_examples_list,
            local_importance=local_importances,
            summary_importance=summary_importance)
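    # Worked example of the counting above (numbers are illustrative): if 10
    # CFs are generated for one query instance and 'age' differs from the
    # original value in 7 of them, the local importance of 'age' for that
    # instance is 7/10 = 0.7. The global (summary) score pools the same counts
    # over all query instances and divides by the total number of CFs found.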
    def predict_fn(self, input_instance):
        """Prediction function."""
        return self.model.get_output(input_instance)

    def predict_fn_for_sparsity(self, input_instance):
        """Prediction function used for post-hoc sparsity correction."""
        return self.model.get_output(input_instance)
    def do_posthoc_sparsity_enhancement(self, final_cfs_sparse, query_instance,
                                        posthoc_sparsity_param, posthoc_sparsity_algorithm):
        """Post-hoc method to encourage sparsity in the generated counterfactuals.

        :param final_cfs_sparse: Final CFs in original user-fed format, in a pandas dataframe.
        :param query_instance: Query instance in original user-fed format, in a pandas dataframe.
        :param posthoc_sparsity_param: Parameter for the post-hoc operation on continuous features
                                       to enhance sparsity.
        :param posthoc_sparsity_algorithm: Perform either linear or binary search. Prefer binary search
                                           when a feature range is large (for instance, income varying
                                           from 10k to 1000k) and only if the features share a monotonic
                                           relationship with the predicted outcome in the model.
        """
        if final_cfs_sparse is None:
            return final_cfs_sparse

        # quantiles of the deviation from median for every continuous feature
        quantiles = self.data_interface.get_quantiles_from_training_data(quantile=posthoc_sparsity_param)
        mads = self.data_interface.get_valid_mads()
        # Capping each feature's quantile at its MAD, so that the maximum
        # allowed deviation for a feature is its MAD.
        for feature in quantiles:
            quantiles[feature] = min(quantiles[feature], mads[feature])

        # Sorting features so that the feature with the highest quantile
        # deviation comes first
        features_sorted = sorted(quantiles.items(), key=lambda kv: kv[1], reverse=True)
        for ix in range(len(features_sorted)):
            features_sorted[ix] = features_sorted[ix][0]
        precs = self.data_interface.get_decimal_precisions()
        decimal_prec = dict(zip(self.data_interface.continuous_feature_names, precs))

        cfs_preds_sparse = []

        for cf_ix in list(final_cfs_sparse.index):
            current_pred = self.predict_fn_for_sparsity(
                final_cfs_sparse.loc[[cf_ix]][self.data_interface.feature_names])
            for feature in features_sorted:
                diff = query_instance[feature].iat[0] - int(final_cfs_sparse.at[cf_ix, feature])
                if abs(diff) <= quantiles[feature]:
                    if posthoc_sparsity_algorithm == "linear":
                        final_cfs_sparse = self.do_linear_search(
                            diff, decimal_prec, query_instance, cf_ix,
                            feature, final_cfs_sparse, current_pred)
                    elif posthoc_sparsity_algorithm == "binary":
                        final_cfs_sparse = self.do_binary_search(
                            diff, decimal_prec, query_instance, cf_ix,
                            feature, final_cfs_sparse, current_pred)
            temp_preds = self.predict_fn_for_sparsity(
                final_cfs_sparse.loc[[cf_ix]][self.data_interface.feature_names])
            cfs_preds_sparse.append(temp_preds)
        final_cfs_sparse[self.data_interface.outcome_name] = \
            self.get_model_output_from_scores(cfs_preds_sparse)
        return final_cfs_sparse
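    # Illustrative numbers for the capping above: if the posthoc_sparsity_param
    # quantile of |deviation from median| for 'age' is 12 years but its MAD is
    # 8, the threshold becomes min(12, 8) = 8, so a CF whose 'age' lies within
    # 8 of the query's 'age' becomes a candidate for being moved back toward
    # the original value (via linear or binary search, depending on
    # posthoc_sparsity_algorithm).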
    def misc_init(self, stopping_threshold, desired_class, desired_range, test_pred):
        self.stopping_threshold = stopping_threshold
        if self.model.model_type == ModelTypes.Classifier:
            self.target_cf_class = np.array(
                [[self.infer_target_cfs_class(desired_class, test_pred, self.num_output_nodes)]],
                dtype=np.float32)
            desired_class = self.target_cf_class[0][0]
            if self.target_cf_class == 0 and self.stopping_threshold > 0.5:
                self.stopping_threshold = 0.25
            elif self.target_cf_class == 1 and self.stopping_threshold < 0.5:
                self.stopping_threshold = 0.75

        elif self.model.model_type == ModelTypes.Regressor:
            self.target_cf_range = self.infer_target_cfs_range(desired_range)
        return desired_class
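    # Example of the threshold adjustment above: with the default
    # stopping_threshold of 0.5, targeting class 0 tightens the threshold to
    # 0.25 (the positive-class score must drop to <= 0.25), while targeting
    # class 1 tightens it to 0.75 (the score must reach >= 0.75), so accepted
    # counterfactuals land more confidently on the desired side.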
    def infer_target_cfs_class(self, desired_class_input, original_pred, num_output_nodes):
        """Infer the target class for generating CFs. Only called when
        model_type == "classifier".
        TODO: Add support for opposite desired class in multiclass.
        Downstream methods should decide whether it is allowed or not.
        """
        if desired_class_input == "opposite":
            if num_output_nodes == 2:
                original_pred_1 = np.argmax(original_pred)
                target_class = int(1 - original_pred_1)
                return target_class
            elif num_output_nodes > 2:
                raise UserConfigValidationException(
                    "Desired class cannot be opposite if the number of classes is more than 2.")
        elif isinstance(desired_class_input, int):
            if 0 <= desired_class_input < num_output_nodes:
                target_class = desired_class_input
                return target_class
            else:
                raise UserConfigValidationException("Desired class not present in training data!")
        else:
            raise UserConfigValidationException(
                "The target class for {0} could not be identified".format(desired_class_input))
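    # Example of the inference above: for a binary model with
    # original_pred = [0.3, 0.7], np.argmax gives class 1, so
    # desired_class="opposite" resolves to target class 0. An explicit integer
    # such as desired_class=2 is accepted only when the model has more than
    # two output nodes.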
    def infer_target_cfs_range(self, desired_range_input):
        if desired_range_input is None:
            raise ValueError(
                "Need to provide a desired_range for the target counterfactuals for a regression model.")
        if desired_range_input[0] > desired_range_input[1]:
            raise ValueError("Invalid Range!")
        return desired_range_input
    def decide_cf_validity(self, model_outputs):
        validity = np.zeros(len(model_outputs), dtype=np.int32)
        for i in range(len(model_outputs)):
            pred = model_outputs[i]
            if self.model.model_type == ModelTypes.Classifier:
                if self.num_output_nodes == 2:  # binary
                    pred_1 = pred[self.num_output_nodes-1]
                    validity[i] = 1 if \
                        ((self.target_cf_class == 0 and pred_1 <= self.stopping_threshold) or
                         (self.target_cf_class == 1 and pred_1 >= self.stopping_threshold)) else 0
                else:  # multiclass
                    if np.argmax(pred) == self.target_cf_class:
                        validity[i] = 1
            elif self.model.model_type == ModelTypes.Regressor:
                if self.target_cf_range[0] <= pred <= self.target_cf_range[1]:
                    validity[i] = 1
        return validity
    def is_cf_valid(self, model_score):
        """Check if a cf belongs to the target class or target range."""
        # Converting to a single prediction if the prediction is provided as a
        # singleton array
        correct_dim = 1 if self.model.model_type == ModelTypes.Classifier else 0
        if hasattr(model_score, "shape") and len(model_score.shape) > correct_dim:
            model_score = model_score[0]
        # Converting target_cf_class to a scalar (tf/torch have it as (1,1) shape)
        if self.model.model_type == ModelTypes.Classifier:
            target_cf_class = self.target_cf_class
            if hasattr(self.target_cf_class, "shape"):
                if len(self.target_cf_class.shape) == 1:
                    target_cf_class = self.target_cf_class[0]
                elif len(self.target_cf_class.shape) == 2:
                    target_cf_class = self.target_cf_class[0][0]
            target_cf_class = int(target_cf_class)

            if self.num_output_nodes == 1:  # for tensorflow/pytorch models
                pred_1 = model_score[0]
                return (target_cf_class == 0 and pred_1 <= self.stopping_threshold) or \
                       (target_cf_class == 1 and pred_1 >= self.stopping_threshold)
            if self.num_output_nodes == 2:  # binary
                pred_1 = model_score[self.num_output_nodes-1]
                return (target_cf_class == 0 and pred_1 <= self.stopping_threshold) or \
                       (target_cf_class == 1 and pred_1 >= self.stopping_threshold)
            else:  # multiclass
                return np.argmax(model_score) == target_cf_class
        else:
            return self.target_cf_range[0] <= model_score <= self.target_cf_range[1]
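    # Example of the validity check above (illustrative values): with
    # target_cf_class = 1 and stopping_threshold = 0.75, a binary score of
    # [0.2, 0.8] is valid (0.8 >= 0.75) while [0.4, 0.6] is not; for a
    # regressor with target_cf_range = [50, 100], a prediction of 72.5 is
    # valid.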
    def get_model_output_from_scores(self, model_scores):
        if self.model.model_type == ModelTypes.Classifier:
            output_type = np.int32
        else:
            output_type = np.float32
        model_output = np.zeros(len(model_scores), dtype=output_type)
        for i in range(len(model_scores)):
            if self.model.model_type == ModelTypes.Classifier:
                model_output[i] = np.argmax(model_scores[i])
            elif self.model.model_type == ModelTypes.Regressor:
                model_output[i] = model_scores[i]
        return model_output
    def check_permitted_range(self, permitted_range):
        """Checks the permitted range for continuous features.
        TODO: add comments as to where this is used if this function is
        necessary, else remove.
        """
        if permitted_range is not None:
            self.data_interface.permitted_range = permitted_range
            self.minx, self.maxx = self.data_interface.get_minx_maxx(normalized=True)
            self.cont_minx = []
            self.cont_maxx = []
            for feature in self.data_interface.continuous_feature_names:
                self.cont_minx.append(self.data_interface.permitted_range[feature][0])
                self.cont_maxx.append(self.data_interface.permitted_range[feature][1])
    def sigmoid(self, z):
        """This is used in VAE-based CF explainers."""
        return 1 / (1 + np.exp(-z))
    def build_KD_tree(self, data_df_copy, desired_range, desired_class, predicted_outcome_name):
        # Stores the predictions on the training data
        dataset_instance = self.data_interface.prepare_query_instance(
            query_instance=data_df_copy[self.data_interface.feature_names])
        predictions = self.model.model.predict(dataset_instance)
        # TODO: Is it okay to insert a column in the original dataframe with
        # the predicted outcome? This is memory-efficient.
        data_df_copy[predicted_outcome_name] = predictions

        # segmenting the dataset according to outcome
        dataset_with_predictions = None
        if self.model.model_type == ModelTypes.Classifier:
            dataset_with_predictions = data_df_copy.loc[
                [i == desired_class for i in predictions]].copy()
        elif self.model.model_type == ModelTypes.Regressor:
            dataset_with_predictions = data_df_copy.loc[
                [desired_range[0] <= pred <= desired_range[1] for pred in predictions]].copy()

        KD_tree = None
        # Prepares the KD trees for DiCE
        if len(dataset_with_predictions) > 0:
            dummies = pd.get_dummies(dataset_with_predictions[self.data_interface.feature_names])
            KD_tree = KDTree(dummies)

        return dataset_with_predictions, KD_tree, predictions
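    # Illustrative query sketch for the tree built above (hypothetical names;
    # assumes `query_enc` is one-hot encoded with the same pd.get_dummies
    # columns as `dummies`):
    #
    #   dataset_with_preds, kd_tree, _ = self.build_KD_tree(
    #       df.copy(), desired_range=None, desired_class=1,
    #       predicted_outcome_name='income_pred')
    #   if kd_tree is not None:
    #       distances, indices = kd_tree.query(query_enc, k=5)  # 5 nearest candidates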
    def round_to_precision(self):
        # to display the values with the same precision as the original data
        precisions = self.data_interface.get_decimal_precisions()
        for ix, feature in enumerate(self.data_interface.continuous_feature_names):
            self.final_cfs_df[feature] = self.final_cfs_df[feature].astype(float).round(precisions[ix])
            if self.final_cfs_df_sparse is not None:
                self.final_cfs_df_sparse[feature] = \
                    self.final_cfs_df_sparse[feature].astype(float).round(precisions[ix])