Generate feasible counterfactual explanations using a VAE¶
This notebook presents the variational-inference-based approach to generating feasible counterfactuals, in which we first train an encoder-decoder (VAE) framework and then use it to generate counterfactuals. More details about the framework can be found in the paper: https://arxiv.org/abs/1912.03277
[1]:
# import DiCE
import dice_ml
from dice_ml.utils import helpers # helper functions
%load_ext autoreload
%autoreload 2
DiCE requires two inputs: a training dataset and a pre-trained ML model. It can also work without access to the full dataset (see this notebook for advanced examples).
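When the full dataset cannot be shared, DiCE can also be constructed from feature metadata alone. The sketch below is a hedged illustration of that mode; it assumes dice_ml.Data accepts a features dictionary mapping continuous features to their ranges and categorical features to their allowed values (the exact parameter and the ranges shown are assumptions, so check the DiCE version you have installed):

```python
# Hedged sketch: a DiCE Data object built from feature metadata only.
# The `features` argument and the ranges below are assumptions for illustration.
d_private = dice_ml.Data(
    features={
        'age': [17, 90],                 # assumed min/max range
        'workclass': ['Government', 'Other/Unknown', 'Private', 'Self-Employed'],
        'hours_per_week': [1, 99],       # assumed min/max range
    },
    outcome_name='income')
```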
Loading dataset¶
We use the “adult” income dataset from UCI Machine Learning Repository (https://archive.ics.uci.edu/ml/datasets/adult). For demonstration purposes, we transform the data as described in dice_ml.utils.helpers module.
[2]:
dataset = helpers.load_adult_income_dataset()
This dataset has 8 features. The outcome is income, which is binarized to 0 (low-income, <=50K) or 1 (high-income, >50K).
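The helper returns an already-transformed dataset, so no preprocessing is needed here. Purely for illustration, binarizing a raw income string column could look like the sketch below; this is not the actual code inside helpers.load_adult_income_dataset, and the label strings are assumptions about the raw UCI file:

```python
import pandas as pd

# Illustrative only: map raw income strings to the 0/1 outcome used above.
raw_df = pd.DataFrame({'income': ['<=50K', '>50K', '<=50K']})   # stand-in for the raw column
raw_df['income'] = (raw_df['income'].str.strip() == '>50K').astype(int)
print(raw_df['income'].tolist())  # [0, 1, 0]
```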
[3]:
dataset.head()
[3]:
|   | age | workclass | education | marital_status | occupation | race | gender | hours_per_week | income |
|---|-----|-----------|-----------|----------------|------------|------|--------|----------------|--------|
| 0 | 28 | Private | Bachelors | Single | White-Collar | White | Female | 60 | 0 |
| 1 | 30 | Self-Employed | Assoc | Married | Professional | White | Male | 65 | 1 |
| 2 | 32 | Private | Some-college | Married | White-Collar | White | Male | 50 | 0 |
| 3 | 20 | Private | Some-college | Single | Service | White | Female | 35 | 0 |
| 4 | 41 | Self-Employed | Some-college | Married | White-Collar | White | Male | 50 | 0 |
[4]:
# description of transformed features
adult_info = helpers.get_adult_data_info()
adult_info
[4]:
{'age': 'age',
'workclass': 'type of industry (Government, Other/Unknown, Private, Self-Employed)',
'education': 'education level (Assoc, Bachelors, Doctorate, HS-grad, Masters, Prof-school, School, Some-college)',
'marital_status': 'marital status (Divorced, Married, Separated, Single, Widowed)',
'occupation': 'occupation (Blue-Collar, Other/Unknown, Professional, Sales, Service, White-Collar)',
'race': 'white or other race?',
'gender': 'male or female?',
'hours_per_week': 'total work hours per week',
'income': '0 (<=50K) vs 1 (>50K)'}
Given this dataset, we construct a data object for DiCE. Since continuous and categorical features are perturbed differently, we need to specify the names of the continuous features. DiCE also requires the name of the output variable that the ML model will predict.
[5]:
d = dice_ml.Data(dataframe=dataset, continuous_features=['age', 'hours_per_week'],
outcome_name='income', data_name='adult', test_size=0.1)
Loading the ML model¶
Below, we use a pre-trained ML model whose accuracy is comparable to other baselines. For convenience, we include this sample trained model with the DiCE package.
Note that we need to specify the explainer in the model backend, since the model and the explainer must use the same backend library (PyTorch or TensorFlow).
[6]:
backend = {'model': 'pytorch_model.PyTorchModel',
'explainer': 'feasible_base_vae.FeasibleBaseVAE'}
ML_modelpath = helpers.get_adult_income_modelpath(backend='PYT')
ML_modelpath = ML_modelpath[:-4] + '_2nodes.pth'
m = dice_ml.Model(model_path=ML_modelpath, backend=backend)
m.load_model()
print('ML Model', m.model)
ML Model Sequential(
(0): Linear(in_features=29, out_features=20, bias=True)
(1): ReLU()
(2): Linear(in_features=20, out_features=2, bias=True)
(3): Softmax(dim=None)
)
/home/amit/py-envs/env3.8/lib/python3.8/site-packages/torch/serialization.py:658: SourceChangeWarning: source code of class 'torch.nn.modules.container.Sequential' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.
warnings.warn(msg, SourceChangeWarning)
/home/amit/py-envs/env3.8/lib/python3.8/site-packages/torch/serialization.py:658: SourceChangeWarning: source code of class 'torch.nn.modules.linear.Linear' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.
warnings.warn(msg, SourceChangeWarning)
/home/amit/py-envs/env3.8/lib/python3.8/site-packages/torch/serialization.py:658: SourceChangeWarning: source code of class 'torch.nn.modules.activation.ReLU' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.
warnings.warn(msg, SourceChangeWarning)
/home/amit/py-envs/env3.8/lib/python3.8/site-packages/torch/serialization.py:658: SourceChangeWarning: source code of class 'torch.nn.modules.activation.Softmax' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.
warnings.warn(msg, SourceChangeWarning)
Generate counterfactuals using a VAE model¶
Based on the data object d and the model object m, we can now instantiate the DiCE class for generating explanations. Here we use the variational-inference-based approach, which first trains an encoder-decoder (VAE) framework and then uses it to generate counterfactuals.
The FeasibleBaseVAE class has a method train(), which trains the variational encoder-decoder framework on the input dataframe. train() takes an argument pre_trained: if set to 0, the framework is re-trained every time counterfactuals are generated; if set to 1, the latest fitted VAE model is loaded instead, avoiding repeated training.
[7]:
# initiate DiCE
exp = dice_ml.Dice(d, m, encoded_size=10, lr=1e-2,
batch_size=2048, validity_reg=42.0, margin=0.165, epochs=25,
wm1=1e-2, wm2=1e-2, wm3=1e-2)
exp.train(pre_trained=1)
Dataset Shape: (26048, 30)
Datasets Columns: Index(['age', 'hours_per_week', 'workclass_Government',
'workclass_Other/Unknown', 'workclass_Private',
'workclass_Self-Employed', 'education_Assoc', 'education_Bachelors',
'education_Doctorate', 'education_HS-grad', 'education_Masters',
'education_Prof-school', 'education_School', 'education_Some-college',
'marital_status_Divorced', 'marital_status_Married',
'marital_status_Separated', 'marital_status_Single',
'marital_status_Widowed', 'occupation_Blue-Collar',
'occupation_Other/Unknown', 'occupation_Professional',
'occupation_Sales', 'occupation_Service', 'occupation_White-Collar',
'race_Other', 'race_White', 'gender_Female', 'gender_Male', 'income'],
dtype='object')
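To re-train the VAE from scratch instead of loading the previously fitted model (for example, after changing hyperparameters), set pre_trained=0 as described above; a minimal sketch:

```python
# Re-train the VAE instead of loading the saved one.
# Note: this re-runs training on every call, so it is slower.
exp.train(pre_trained=0)
```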
DiCE is a form of local explanation and requires a query input whose outcome needs to be explained. Below, we provide a sample input whose outcome is 0 (low-income) according to the ML model object m.
[8]:
# query instance in the form of a dictionary; keys: feature name, values: feature value
query_instance = {'age': 41,
'workclass': 'Private',
'education': 'HS-grad',
'marital_status': 'Single',
'occupation': 'Service',
'race': 'White',
'gender': 'Female',
'hours_per_week': 45}
Given the query input, we can now generate counterfactual explanations to show perturbed inputs from the original input where the ML model outputs class 1 (high-income).
[9]:
# generate counterfactuals
dice_exp = exp.generate_counterfactuals(query_instance, total_CFs=5, desired_class="opposite")
# visualize the results
dice_exp.visualize_as_dataframe(show_only_changes=True)
Query instance (original outcome : 0)
/home/amit/py-envs/env3.8/lib/python3.8/site-packages/torch/nn/modules/container.py:117: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.
input = module(input)
/home/amit/py-envs/env3.8/lib/python3.8/site-packages/dice_ml/utils/sample_architecture/vae_model.py:121: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
c = torch.tensor(c).float()
|   | age | workclass | education | marital_status | occupation | race | gender | hours_per_week | income |
|---|-----|-----------|-----------|----------------|------------|------|--------|----------------|--------|
| 0 | 41.0 | Private | HS-grad | Single | Service | White | Female | 45.0 | 0 |
Counterfactual set (new outcome: 1.0)
|   | age | workclass | education | marital_status | occupation | race | gender | hours_per_week | income |
|---|-----|-----------|-----------|----------------|------------|------|--------|----------------|--------|
| 0 | 42.0 | - | Doctorate | Married | White-Collar | - | Male | 40.0 | 1 |
| 1 | 40.0 | - | Doctorate | Married | White-Collar | - | Male | 40.0 | 1 |
| 2 | 42.0 | - | Doctorate | Married | White-Collar | - | Male | 40.0 | 1 |
| 3 | - | - | Doctorate | Married | White-Collar | - | Male | 40.0 | 1 |
| 4 | 37.0 | - | Doctorate | Married | White-Collar | - | Male | 40.0 | 1 |
That’s it! You can try generating counterfactual explanations for other examples using the same code. You can also compare the running time of this VAE-based method to DiCE’s default method: once the VAE is trained, generating counterfactuals only requires forward passes through the model, so it is typically much faster.
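To check the speed on your own machine, you can time the call directly; a minimal sketch using the objects defined above (timings will vary with hardware and hyperparameters):

```python
import time

# Time counterfactual generation with the trained VAE explainer.
start = time.time()
dice_exp = exp.generate_counterfactuals(query_instance, total_CFs=5, desired_class="opposite")
print(f"VAE-based generation took {time.time() - start:.2f} seconds")
```

Repeating the same measurement with DiCE’s default explainer gives a direct comparison.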
Adding feasibility constraints¶
However, you might notice that for some examples, the above method can still return infeasible counterfactuals. This requires our base framework to be adapted for producing feasible counterfactuals. A detailed description of how we adapt the method under different assumptions is provided in this paper.
In the section below, we show an adaptation of our base approach that preserves the Age-Ed constraint: Age and Education can never decrease, and an increase in Education implies an increase in Age. This approach is called ModelApprox, and it adapts our base approach to simple unary and binary constraints.
ModelApprox¶
Similar to the FeasibleBaseVAE class above, the FeasibleModelApprox class has a method train() with the argument pre_trained, which determines whether to re-train the framework or load the latest fitted model. However, train() takes additional positional arguments (see the annotated sketch after this list):
The first argument determines whether the constraint to be preserved is unary or monotonic.
The second argument provides the list of constraint variable names: [[Effect, Cause_1, .., Cause_n]]. In the case of a unary constraint, there are no causes, only a single constrained variable.
The third argument provides the intended direction of change for the constrained variables: a value of 1 means that the constrained variable is only allowed to increase from the data point to its counterfactual, and vice versa.
The fourth argument is the penalty weight for infeasibility under the given constraint.
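Putting this together, the call used below is exp.train(1, [[0]], 1, 87, pre_trained=1). The annotated version that follows maps each positional value to the descriptions above; the comments are our interpretation (for instance, [[0]] presumably indexes the constrained variable 'age'), not an authoritative API reference:

```python
# Annotated version of the train() call used below (comments are interpretive).
exp.train(
    1,              # constraint type flag (unary vs. monotonic constraint)
    [[0]],          # constraint variable list; presumably the index of 'age'
    1,              # direction of change: 1 allows only an increase from query to counterfactual
    87,             # penalty weight for infeasibility under the constraint
    pre_trained=1,  # load the latest fitted model instead of re-training
)
```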
Initialize the Model and Explainer for FeasibleModelApprox¶
[10]:
backend = {'model': 'pytorch_model.PyTorchModel',
'explainer': 'feasible_model_approx.FeasibleModelApprox'}
ML_modelpath = helpers.get_adult_income_modelpath(backend='PYT')
ML_modelpath = ML_modelpath[:-4] + '_2nodes.pth'
m = dice_ml.Model(model_path=ML_modelpath, backend=backend)
m.load_model()
print('ML Model', m.model)
ML Model Sequential(
(0): Linear(in_features=29, out_features=20, bias=True)
(1): ReLU()
(2): Linear(in_features=20, out_features=2, bias=True)
(3): Softmax(dim=None)
)
/home/amit/py-envs/env3.8/lib/python3.8/site-packages/torch/serialization.py:658: SourceChangeWarning: source code of class 'torch.nn.modules.container.Sequential' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.
warnings.warn(msg, SourceChangeWarning)
/home/amit/py-envs/env3.8/lib/python3.8/site-packages/torch/serialization.py:658: SourceChangeWarning: source code of class 'torch.nn.modules.linear.Linear' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.
warnings.warn(msg, SourceChangeWarning)
/home/amit/py-envs/env3.8/lib/python3.8/site-packages/torch/serialization.py:658: SourceChangeWarning: source code of class 'torch.nn.modules.activation.ReLU' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.
warnings.warn(msg, SourceChangeWarning)
/home/amit/py-envs/env3.8/lib/python3.8/site-packages/torch/serialization.py:658: SourceChangeWarning: source code of class 'torch.nn.modules.activation.Softmax' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.
warnings.warn(msg, SourceChangeWarning)
[11]:
# initiate DiCE
exp = dice_ml.Dice(d, m, encoded_size=10, lr=1e-2, batch_size=2048,
validity_reg=76.0, margin=0.344, epochs=25,
wm1=1e-2, wm2=1e-2, wm3=1e-2)
exp.train(1, [[0]], 1, 87, pre_trained=1)
Dataset Shape: (26048, 30)
Datasets Columns: Index(['age', 'hours_per_week', 'workclass_Government',
'workclass_Other/Unknown', 'workclass_Private',
'workclass_Self-Employed', 'education_Assoc', 'education_Bachelors',
'education_Doctorate', 'education_HS-grad', 'education_Masters',
'education_Prof-school', 'education_School', 'education_Some-college',
'marital_status_Divorced', 'marital_status_Married',
'marital_status_Separated', 'marital_status_Single',
'marital_status_Widowed', 'occupation_Blue-Collar',
'occupation_Other/Unknown', 'occupation_Professional',
'occupation_Sales', 'occupation_Service', 'occupation_White-Collar',
'race_Other', 'race_White', 'gender_Female', 'gender_Male', 'income'],
dtype='object')
[12]:
# generate counterfactuals
dice_exp = exp.generate_counterfactuals(query_instance, total_CFs=5, desired_class="opposite")
# visualize the results
dice_exp.visualize_as_dataframe(show_only_changes=True)
Query instance (original outcome : 0)
/home/amit/py-envs/env3.8/lib/python3.8/site-packages/torch/nn/modules/container.py:117: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.
input = module(input)
/home/amit/py-envs/env3.8/lib/python3.8/site-packages/dice_ml/utils/sample_architecture/vae_model.py:121: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
c = torch.tensor(c).float()
|   | age | workclass | education | marital_status | occupation | race | gender | hours_per_week | income |
|---|-----|-----------|-----------|----------------|------------|------|--------|----------------|--------|
| 0 | 41.0 | Private | HS-grad | Single | Service | White | Female | 45.0 | 0 |
Counterfactual set (new outcome: 1.0)
|   | age | workclass | education | marital_status | occupation | race | gender | hours_per_week | income |
|---|-----|-----------|-----------|----------------|------------|------|--------|----------------|--------|
| 0 | 37.0 | - | Doctorate | Married | White-Collar | - | Male | 40.0 | 1 |
| 1 | 43.0 | - | Doctorate | Married | White-Collar | - | Male | 40.0 | 1 |
| 2 | 43.0 | - | Doctorate | Married | White-Collar | - | Male | 40.0 | 1 |
| 3 | 45.0 | - | Doctorate | Married | White-Collar | - | Male | 40.0 | 1 |
| 4 | 44.0 | - | Doctorate | Married | White-Collar | - | Male | 40.0 | 1 |
The results for ModelApprox show that, unlike with the BaseVAE method, Age also increases along with Education in the counterfactual explanations. You can experiment with ModelApprox to preserve unary and monotonic constraints on other datasets too. Examples of even more advanced approaches such as SCMGenCF and OracleGenCF, where we learn to generate feasible counterfactuals under complex feasibility constraints, will be added to this repository soon. More details can be found in our paper.
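To check the age constraint programmatically on your own runs, a rough sketch is shown below; cf_df stands for whatever dataframe of generated counterfactuals you extract from the explanation object (the exact attribute name varies across DiCE versions, so this is illustrative only):

```python
import pandas as pd

def age_never_decreases(query_instance, cf_df):
    """Return True if every counterfactual's age is >= the query instance's age."""
    return bool((cf_df['age'] >= query_instance['age']).all())

# Toy usage with a hand-built frame (not the notebook's actual output):
toy_cfs = pd.DataFrame({'age': [42.0, 45.0, 41.0]})
print(age_never_decreases({'age': 41}, toy_cfs))  # True
```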