Quick introduction to generating counterfactual explanations using DiCE¶
[1]:
# import DiCE
import dice_ml
from dice_ml.utils import helpers # helper functions
# suppress deprecation warnings from TF
import tensorflow as tf
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
DiCE requires two inputs: a training dataset and a pre-trained ML model. It can also work without access to the full dataset (see this notebook for advanced examples).
Loading dataset¶
We use the “adult” income dataset from the UCI Machine Learning Repository (https://archive.ics.uci.edu/ml/datasets/adult). For demonstration purposes, we transform the data as described in the dice_ml.utils.helpers module.
[2]:
dataset = helpers.load_adult_income_dataset()
This dataset has 8 features. The outcome is income which is binarized to 0 (low-income, <=50K) or 1 (high-income, >50K).
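As a rough illustration of what binarizing the outcome means, the sketch below maps raw income labels to 0/1 with pandas. The raw label strings and the stray whitespace are assumptions made for the example; this is not the code used inside helpers.load_adult_income_dataset().

```python
import pandas as pd

# Hypothetical raw labels (assumed for illustration), some with stray whitespace.
raw = pd.DataFrame({"income": ["<=50K", ">50K", " >50K", "<=50K "]})

# Strip whitespace, then map ">50K" to 1 (high-income) and everything else to 0.
raw["income"] = (raw["income"].str.strip() == ">50K").astype(int)

print(raw["income"].tolist())  # [0, 1, 1, 0]
```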
[3]:
dataset.head()
[3]:
| | age | workclass | education | marital_status | occupation | race | gender | hours_per_week | income |
|---|---|---|---|---|---|---|---|---|---|
| 0 | 39 | Government | Bachelors | Single | White-Collar | White | Male | 40 | 0 |
| 1 | 50 | Self-Employed | Bachelors | Married | White-Collar | White | Male | 13 | 0 |
| 2 | 38 | Private | HS-grad | Divorced | Blue-Collar | White | Male | 40 | 0 |
| 3 | 53 | Private | School | Married | Blue-Collar | Other | Male | 40 | 0 |
| 4 | 28 | Private | Bachelors | Married | Professional | Other | Female | 40 | 0 |
[4]:
# description of transformed features
adult_info = helpers.get_adult_data_info()
adult_info
[4]:
{'age': 'age',
'workclass': 'type of industry (Government, Other/Unknown, Private, Self-Employed)',
'education': 'education level (Assoc, Bachelors, Doctorate, HS-grad, Masters, Prof-school, School, Some-college)',
'marital_status': 'marital status (Divorced, Married, Separated, Single, Widowed)',
'occupation': 'occupation (Blue-Collar, Other/Unknown, Professional, Sales, Service, White-Collar)',
'race': 'white or other race?',
'gender': 'male or female?',
'hours_per_week': 'total work hours per week',
'income': '0 (<=50K) vs 1 (>50K)'}
Given this dataset, we construct a data object for DiCE. Because continuous and categorical features are perturbed in different ways, we need to specify the names of the continuous features. DiCE also requires the name of the outcome variable that the ML model predicts.
[5]:
d = dice_ml.Data(dataframe=dataset, continuous_features=['age', 'hours_per_week'], outcome_name='income')
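If you are unsure which columns to list as continuous, one possible starting point is to inspect the column dtypes with pandas. The sketch below uses a small stand-in DataFrame built from the first rows shown earlier; the variable names are illustrative only. Integer-coded categorical columns would be misclassified by this heuristic, so the result should be reviewed by hand.

```python
import pandas as pd

# A tiny stand-in for the transformed dataset (values taken from the table above).
df = pd.DataFrame({
    "age": [39, 50, 38],
    "workclass": ["Government", "Self-Employed", "Private"],
    "hours_per_week": [40, 13, 40],
    "income": [0, 0, 0],
})

outcome_name = "income"

# Numeric columns other than the outcome are candidates for continuous features.
continuous_features = [
    col for col in df.select_dtypes(include="number").columns
    if col != outcome_name
]

print(continuous_features)  # ['age', 'hours_per_week']
```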
Loading the ML model¶
Below, we use a pre-trained ML model whose accuracy is comparable to other baselines. For convenience, the sample trained model is included with the DiCE package.
The variable backend below indicates which implementation of DiCE we want to use. We use TensorFlow 1.x in the notebooks with backend='TF1'. You can set backend to 'TF2' or 'PYT' to use DiCE with TensorFlow 2.x or PyTorch, respectively. Note that finding counterfactuals takes significantly longer with TensorFlow 2.x's eager execution than with TensorFlow 1.x's graph execution.
[6]:
backend = 'TF' + tf.__version__[0]  # 'TF1' or 'TF2', depending on the installed TensorFlow
ML_modelpath = helpers.get_adult_income_modelpath(backend=backend)
m = dice_ml.Model(model_path=ML_modelpath, backend=backend)
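The backend string above is derived from the first character of the TensorFlow version. A slightly more defensive variant, sketched below, splits on the dot and rejects unexpected major versions. The function backend_from_tf_version is a hypothetical helper written for this example and is not part of DiCE.

```python
def backend_from_tf_version(version: str) -> str:
    """Return 'TF1' or 'TF2' for a TensorFlow version string such as '2.4.1'."""
    major = version.split(".")[0]
    if major not in ("1", "2"):
        raise ValueError(f"unsupported TensorFlow version: {version}")
    return "TF" + major


print(backend_from_tf_version("1.15.0"))  # TF1
print(backend_from_tf_version("2.4.1"))   # TF2
```

In the notebook, this could be called as backend_from_tf_version(tf.__version__).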
For an example of how to train your own model, check out this notebook.
Generate diverse counterfactuals¶
Based on the data object d and the model object m, we can now instantiate the DiCE class for generating explanations.
[7]:
# instantiate DiCE
exp = dice_ml.Dice(d, m)
DiCE provides local explanations for the model m and requires a query input whose outcome is to be explained. Below, we provide a sample input whose outcome is 0 (low-income) according to the ML model object m.
[8]:
# query instance in the form of a dictionary; keys: feature name, values: feature value
query_instance = {'age': 22,
                  'workclass': 'Private',
                  'education': 'HS-grad',
                  'marital_status': 'Single',
                  'occupation': 'Service',
                  'race': 'White',
                  'gender': 'Female',
                  'hours_per_week': 45}
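Before passing a query to DiCE, it can help to check that it names every feature and uses only known category values. The sketch below is one way to do that in plain Python. The category sets are copied from the output of helpers.get_adult_data_info() shown earlier, except for race and gender, whose values are assumed from the data preview. check_query is a hypothetical helper, not a DiCE function.

```python
# Allowed categories (race and gender values are assumed from the data preview).
CATEGORIES = {
    "workclass": {"Government", "Other/Unknown", "Private", "Self-Employed"},
    "education": {"Assoc", "Bachelors", "Doctorate", "HS-grad", "Masters",
                  "Prof-school", "School", "Some-college"},
    "marital_status": {"Divorced", "Married", "Separated", "Single", "Widowed"},
    "occupation": {"Blue-Collar", "Other/Unknown", "Professional", "Sales",
                   "Service", "White-Collar"},
    "race": {"White", "Other"},
    "gender": {"Male", "Female"},
}
CONTINUOUS = {"age", "hours_per_week"}


def check_query(query):
    """Return a list of problems found in a query dictionary (empty if none)."""
    problems = []
    expected = set(CATEGORIES) | CONTINUOUS
    for name in sorted(expected - set(query)):
        problems.append(f"missing feature: {name}")
    for name in sorted(set(query) - expected):
        problems.append(f"unknown feature: {name}")
    for name, allowed in CATEGORIES.items():
        if name in query and query[name] not in allowed:
            problems.append(f"invalid value for {name}: {query[name]!r}")
    for name in sorted(CONTINUOUS):
        if name in query and not isinstance(query[name], (int, float)):
            problems.append(f"non-numeric value for {name}: {query[name]!r}")
    return problems


query_instance = {"age": 22, "workclass": "Private", "education": "HS-grad",
                  "marital_status": "Single", "occupation": "Service",
                  "race": "White", "gender": "Female", "hours_per_week": 45}

print(check_query(query_instance))  # []
```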
Given the query input, we can now generate counterfactual explanations: perturbed versions of the original input for which the ML model outputs class 1 (high-income).
[9]:
# generate counterfactuals
dice_exp = exp.generate_counterfactuals(query_instance, total_CFs=4, desired_class="opposite")
Diverse Counterfactuals found! total time taken: 00 min 02 sec
[10]:
# visualize the results
dice_exp.visualize_as_dataframe()
Query instance (original outcome : 0)
| | age | workclass | education | marital_status | occupation | race | gender | hours_per_week | income |
|---|---|---|---|---|---|---|---|---|---|
| 0 | 22.0 | Private | HS-grad | Single | Service | White | Female | 45.0 | 0.01904 |
Diverse Counterfactual set (new outcome : 1)
| | age | workclass | education | marital_status | occupation | race | gender | hours_per_week | income |
|---|---|---|---|---|---|---|---|---|---|
| 0 | 70.0 | Private | Masters | Single | White-Collar | White | Female | 51.0 | 0.534 |
| 1 | 22.0 | Self-Employed | Doctorate | Married | Service | White | Female | 45.0 | 0.861 |
| 2 | 47.0 | Private | HS-grad | Married | Service | White | Female | 45.0 | 0.589 |
| 3 | 36.0 | Private | Prof-school | Married | Service | White | Female | 62.0 | 0.937 |
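The income column in these tables holds the model's output score rather than a hard label. The sketch below shows how the scores above line up with the reported outcomes, assuming a decision threshold of 0.5. That threshold is an assumption consistent with the tables, not something confirmed here, and to_class is a hypothetical helper.

```python
# Scores copied from the income column of the tables above.
query_score = 0.01904
cf_scores = [0.534, 0.861, 0.589, 0.937]


def to_class(score, threshold=0.5):
    """Map a model score to a class label (threshold of 0.5 is assumed)."""
    return int(score >= threshold)


print(to_class(query_score))             # 0
print([to_class(s) for s in cf_scores])  # [1, 1, 1, 1]
```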
To display only the features that changed relative to the query instance, set the parameter show_only_changes=True.
[11]:
# highlight only the changes
dice_exp.visualize_as_dataframe(show_only_changes=True)
Query instance (original outcome : 0)
| | age | workclass | education | marital_status | occupation | race | gender | hours_per_week | income |
|---|---|---|---|---|---|---|---|---|---|
| 0 | 22.0 | Private | HS-grad | Single | Service | White | Female | 45.0 | 0.01904 |
Diverse Counterfactual set (new outcome : 1)
| | age | workclass | education | marital_status | occupation | race | gender | hours_per_week | income |
|---|---|---|---|---|---|---|---|---|---|
| 0 | 70.0 | - | Masters | - | White-Collar | - | - | 51.0 | 0.534 |
| 1 | - | Self-Employed | Doctorate | Married | - | - | - | - | 0.861 |
| 2 | 47.0 | - | - | Married | - | - | - | - | 0.589 |
| 3 | 36.0 | - | Prof-school | Married | - | - | - | 62.0 | 0.937 |
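The idea behind this view can be sketched in a few lines of plain Python: keep a counterfactual's value where it differs from the query and show a placeholder otherwise. The function only_changes is a hypothetical helper written for illustration and is not how DiCE implements show_only_changes. The values are taken from the query and from counterfactual 0 in the table above.

```python
def only_changes(query, counterfactual, placeholder="-"):
    """Replace every value that matches the query with a placeholder."""
    return {
        name: placeholder if counterfactual[name] == query[name] else counterfactual[name]
        for name in query
    }


query = {"age": 22.0, "workclass": "Private", "education": "HS-grad",
         "marital_status": "Single", "occupation": "Service",
         "race": "White", "gender": "Female", "hours_per_week": 45.0}

# Counterfactual 0 from the table above.
cf0 = {"age": 70.0, "workclass": "Private", "education": "Masters",
       "marital_status": "Single", "occupation": "White-Collar",
       "race": "White", "gender": "Female", "hours_per_week": 51.0}

diff = only_changes(query, cf0)
print(diff)
```

Counting the entries in diff that are not the placeholder gives the number of features a counterfactual changes, which is 4 for this row.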
That’s it! You can try generating counterfactual explanations for other examples using the same code. You can also try out a use-case for sensitive data in this notebook.
If you are curious about changing the number of explanations shown, the feasibility of the explanations, or how to weight different features when perturbing, check out how to change DiCE’s behavior in this advanced notebook.
The counterfactuals generated above differ slightly from those shown in our paper, where the loss convergence condition was made more conservative for rigorous experimentation. To replicate the results in the paper, pass the argument loss_converge_maxiter=2 (the default value is 1) to the exp.generate_counterfactuals() method above. For more information, see the generate_counterfactuals() method in dice_ml.dice_interfaces.dice_tensorflow.py.
Working with PyTorch¶
[6]:
try:
    import torch
    print('PyTorch installed.')
except ImportError as e:
    print("Import Error!", e.name, "not found. Please install from https://pytorch.org/")
PyTorch installed.
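If you would rather check for a package without importing it, the standard library's importlib.util.find_spec offers an alternative. The sketch below wraps it in a small helper; is_installed is a name chosen for this example.

```python
import importlib.util


def is_installed(module_name: str) -> bool:
    """Return True if a top-level module can be located without importing it."""
    return importlib.util.find_spec(module_name) is not None


print(is_installed("json"))   # True (json ships with the standard library)
print(is_installed("torch"))  # depends on your environment
```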
To use DiCE with PyTorch, just change the backend variable to 'PYT'. Below, we use a pre-trained PyTorch model whose accuracy is comparable to other baselines. For convenience, the sample trained model is included with the DiCE package.
[7]:
backend = 'PYT'
ML_modelpath = helpers.get_adult_income_modelpath(backend=backend)
m = dice_ml.Model(model_path=ML_modelpath, backend=backend)
Instantiate the DiCE class with the new PyTorch model object m.
[8]:
exp = dice_ml.Dice(d, m)
[9]:
# query instance in the form of a dictionary; keys: feature name, values: feature value
query_instance = {'age': 22,
                  'workclass': 'Private',
                  'education': 'HS-grad',
                  'marital_status': 'Single',
                  'occupation': 'Service',
                  'race': 'White',
                  'gender': 'Female',
                  'hours_per_week': 45}
[10]:
# generate counterfactuals
dice_exp = exp.generate_counterfactuals(query_instance, total_CFs=4, desired_class="opposite")
Diverse Counterfactuals found! total time taken: 00 min 05 sec
[11]:
# highlight only the changes
dice_exp.visualize_as_dataframe(show_only_changes=True)
Query instance (original outcome : 0)
| | age | workclass | education | marital_status | occupation | race | gender | hours_per_week | income |
|---|---|---|---|---|---|---|---|---|---|
| 0 | 22.0 | Private | HS-grad | Single | Service | White | Female | 45.0 | 0.000042 |
Diverse Counterfactual set (new outcome : 1)
| | age | workclass | education | marital_status | occupation | race | gender | hours_per_week | income |
|---|---|---|---|---|---|---|---|---|---|
| 0 | 57.0 | - | Doctorate | Married | White-Collar | - | - | - | 0.993 |
| 1 | 33.0 | - | Prof-school | Married | - | - | Male | 39.0 | 0.964 |
| 2 | - | Self-Employed | Prof-school | Married | - | - | - | - | 0.748 |
| 3 | 49.0 | - | Masters | Married | - | - | - | 62.0 | 0.957 |