In [1]:
%load_ext autoreload
%autoreload 2
In [2]:
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
import seaborn as sns

from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.linear_model import LogisticRegressionCV, LogisticRegression
from xgboost import XGBRegressor
from lightgbm import LGBMRegressor
from sklearn.metrics import mean_absolute_error
from sklearn.metrics import mean_squared_error as mse
from scipy.stats import entropy
import warnings

from causalml.inference.meta import LRSRegressor
from causalml.inference.meta import XGBTRegressor, MLPTRegressor
from causalml.inference.meta import BaseXRegressor, BaseRRegressor, BaseSRegressor, BaseTRegressor
from causalml.inference.nn import DragonNet
from causalml.match import NearestNeighborMatch, MatchOptimizer, create_table_one
from causalml.propensity import ElasticNetPropensityModel
from causalml.dataset.regression import *
from causalml.metrics import *

import os, sys

%matplotlib inline

warnings.filterwarnings('ignore')
plt.style.use('fivethirtyeight')
sns.set_palette('Paired')
plt.rcParams['figure.figsize'] = (12,8)

IHDP semi-synthetic dataset

Hill (2011) introduced a semi-synthetic dataset constructed from the Infant Health and Development Program (IHDP). The dataset is based on a randomized experiment investigating the effect of home visits by specialists on the future cognitive test scores of premature infants. The data contain 747 observations (rows). The IHDP simulation is considered the de facto standard benchmark for neural-network-based treatment effect estimation methods.

The original paper uses 1,000 realizations from the NPCI package, but for illustration purposes we use a single dataset (realization) below.
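Because the data are semi-synthetic, each row carries both the factual and counterfactual outcomes along with the noiseless response surfaces mu0 and mu1, so the ground truth is available for evaluation. As a minimal sketch (assuming the column names assigned two cells below, and run after the data are loaded), the noiseless individual treatment effect can also be read off directly:

# Hedged sketch: with the columns named as below, the noiseless
# individual treatment effect is the difference of the response surfaces.
tau_noiseless = df['mu1'] - df['mu0']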

In [3]:
df = pd.read_csv('data/ihdp_npci_3.csv', header=None)
cols = ["treatment", "y_factual", "y_cfactual", "mu0", "mu1"] + [f'x{i}' for i in range(1, 26)]
df.columns = cols
In [4]:
df.shape
Out[4]:
(747, 30)
In [5]:
df.head()
Out[5]:
treatment y_factual y_cfactual mu0 mu1 x1 x2 x3 x4 x5 ... x16 x17 x18 x19 x20 x21 x22 x23 x24 x25
0 1 5.931652 3.500591 2.253801 7.136441 -0.528603 -0.343455 1.128554 0.161703 -0.316603 ... 1 1 1 1 0 0 0 0 0 0
1 0 2.175966 5.952101 1.257592 6.553022 -1.736945 -1.802002 0.383828 2.244320 -0.629189 ... 1 1 1 1 0 0 0 0 0 0
2 0 2.180294 7.175734 2.384100 7.192645 -0.807451 -0.202946 -0.360898 -0.879606 0.808706 ... 1 0 1 1 0 0 0 0 0 0
3 0 3.587662 7.787537 4.009365 7.712456 0.390083 0.596582 -1.850350 -0.879606 -0.004017 ... 1 0 1 1 0 0 0 0 0 0
4 0 2.372618 5.461871 2.481631 7.232739 -1.045229 -0.602710 0.011465 0.161703 0.683672 ... 1 1 1 1 0 0 0 0 0 0

5 rows × 30 columns

In [6]:
df['treatment'].value_counts(normalize=True)
Out[6]:
0    0.813922
1    0.186078
Name: treatment, dtype: float64
In [7]:
X = df.loc[:, 'x1':]
treatment = df['treatment']
y = df['y_factual']
# Recover the true individual treatment effect tau = Y(1) - Y(0) from the
# factual and counterfactual outcomes.
tau = df.apply(lambda d: d['y_factual'] - d['y_cfactual'] if d['treatment'] == 1
               else d['y_cfactual'] - d['y_factual'],
               axis=1)
In [8]:
# p_model = LogisticRegressionCV(penalty='elasticnet', solver='saga', l1_ratios=np.linspace(0,1,5),
#                                cv=StratifiedKFold(n_splits=4, shuffle=True))
# p_model.fit(X, treatment)
# p = p_model.predict_proba(X)[:, 1]
In [9]:
p_model = ElasticNetPropensityModel()
p = p_model.fit_predict(X, treatment)
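Before passing the propensity scores to the meta-learners, it is worth checking overlap between the two groups. A quick diagnostic (plain matplotlib; nothing causalml-specific assumed):

# Overlap check: scores near 0 or 1 would make the X- and R-learners unstable
plt.hist(p[treatment == 1], bins=30, alpha=0.5, density=True, label='treated')
plt.hist(p[treatment == 0], bins=30, alpha=0.5, density=True, label='control')
plt.xlabel('estimated propensity score')
plt.legend()
plt.show()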
In [10]:
s_learner = BaseSRegressor(LGBMRegressor())
s_ate = s_learner.estimate_ate(X, treatment, y)[0]
s_ite = s_learner.fit_predict(X, treatment, y)

t_learner = BaseTRegressor(LGBMRegressor())
t_ate = t_learner.estimate_ate(X, treatment, y)[0][0]
t_ite = t_learner.fit_predict(X, treatment, y)

x_learner = BaseXRegressor(LGBMRegressor())
x_ate = x_learner.estimate_ate(X, treatment, y, p)[0][0]
x_ite = x_learner.fit_predict(X, treatment, y, p)

r_learner = BaseRRegressor(LGBMRegressor())
r_ate = r_learner.estimate_ate(X, treatment, y, p)[0][0]
r_ite = r_learner.fit_predict(X, treatment, y, p)
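For intuition, the T-learner above amounts to fitting one outcome model per arm and differencing the predictions; the other meta-learners build on the same ingredients. A hand-rolled sketch of the T-learner idea (not the causalml internals):

# Fit separate outcome models on the control and treated subsets
m0 = LGBMRegressor().fit(X[treatment == 0], y[treatment == 0])
m1 = LGBMRegressor().fit(X[treatment == 1], y[treatment == 1])
t_ite_manual = m1.predict(X) - m0.predict(X)  # per-row CATE estimate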
In [11]:
dragon = DragonNet(neurons_per_layer=200, targeted_reg=True)
dragon_ite = dragon.fit_predict(X, treatment, y, return_components=False)
dragon_ate = dragon_ite.mean()
Train on 597 samples, validate on 150 samples
Epoch 1/30
597/597 [==============================] - 1s 1ms/step - loss: 1153.1169 - regression_loss: 526.5245 - binary_classification_loss: 34.2278 - treatment_accuracy: 0.7999 - track_epsilon: 0.0516 - val_loss: 356.0019 - val_regression_loss: 126.8068 - val_binary_classification_loss: 34.7623 - val_treatment_accuracy: 0.7244 - val_track_epsilon: 0.0513
Epoch 2/30
597/597 [==============================] - 0s 67us/step - loss: 343.8514 - regression_loss: 142.3123 - binary_classification_loss: 28.2888 - treatment_accuracy: 0.8434 - track_epsilon: 0.0513 - val_loss: 230.0812 - val_regression_loss: 81.2849 - val_binary_classification_loss: 34.9740 - val_treatment_accuracy: 0.7244 - val_track_epsilon: 0.0496
Epoch 3/30
597/597 [==============================] - 0s 73us/step - loss: 255.0366 - regression_loss: 108.9301 - binary_classification_loss: 26.8012 - treatment_accuracy: 0.8465 - track_epsilon: 0.0490 - val_loss: 235.1863 - val_regression_loss: 82.9400 - val_binary_classification_loss: 35.9143 - val_treatment_accuracy: 0.7244 - val_track_epsilon: 0.0493
Epoch 4/30
597/597 [==============================] - 0s 83us/step - loss: 214.3295 - regression_loss: 84.8636 - binary_classification_loss: 26.3836 - treatment_accuracy: 0.8561 - track_epsilon: 0.0496 - val_loss: 206.7090 - val_regression_loss: 66.4528 - val_binary_classification_loss: 36.8853 - val_treatment_accuracy: 0.7244 - val_track_epsilon: 0.0489
Epoch 5/30
597/597 [==============================] - 0s 63us/step - loss: 193.1023 - regression_loss: 77.6289 - binary_classification_loss: 25.8865 - treatment_accuracy: 0.8497 - track_epsilon: 0.0478 - val_loss: 204.8226 - val_regression_loss: 71.0998 - val_binary_classification_loss: 35.7694 - val_treatment_accuracy: 0.7244 - val_track_epsilon: 0.0470
Epoch 6/30
597/597 [==============================] - 0s 62us/step - loss: 181.7809 - regression_loss: 71.4368 - binary_classification_loss: 25.3941 - treatment_accuracy: 0.8593 - track_epsilon: 0.0469 - val_loss: 209.0668 - val_regression_loss: 68.9204 - val_binary_classification_loss: 36.2566 - val_treatment_accuracy: 0.7244 - val_track_epsilon: 0.0464
Epoch 7/30
597/597 [==============================] - 0s 61us/step - loss: 176.5884 - regression_loss: 70.6928 - binary_classification_loss: 25.1117 - treatment_accuracy: 0.8561 - track_epsilon: 0.0455 - val_loss: 203.3805 - val_regression_loss: 69.1391 - val_binary_classification_loss: 35.8173 - val_treatment_accuracy: 0.7244 - val_track_epsilon: 0.0446
Epoch 8/30
597/597 [==============================] - 0s 65us/step - loss: 170.7210 - regression_loss: 66.4146 - binary_classification_loss: 24.8363 - treatment_accuracy: 0.8401 - track_epsilon: 0.0441 - val_loss: 192.5185 - val_regression_loss: 62.5455 - val_binary_classification_loss: 36.8282 - val_treatment_accuracy: 0.7244 - val_track_epsilon: 0.0433
Epoch 9/30
597/597 [==============================] - 0s 66us/step - loss: 160.6429 - regression_loss: 61.6206 - binary_classification_loss: 24.6174 - treatment_accuracy: 0.8497 - track_epsilon: 0.0426 - val_loss: 194.9871 - val_regression_loss: 64.1374 - val_binary_classification_loss: 36.2175 - val_treatment_accuracy: 0.7244 - val_track_epsilon: 0.0418
Epoch 10/30
597/597 [==============================] - 0s 65us/step - loss: 160.4497 - regression_loss: 61.8506 - binary_classification_loss: 24.4592 - treatment_accuracy: 0.8497 - track_epsilon: 0.0412 - val_loss: 188.0958 - val_regression_loss: 60.7865 - val_binary_classification_loss: 36.4476 - val_treatment_accuracy: 0.7244 - val_track_epsilon: 0.0403
Epoch 11/30
597/597 [==============================] - 0s 62us/step - loss: 159.8468 - regression_loss: 62.8502 - binary_classification_loss: 24.3127 - treatment_accuracy: 0.8529 - track_epsilon: 0.0395 - val_loss: 197.3698 - val_regression_loss: 63.0735 - val_binary_classification_loss: 36.6958 - val_treatment_accuracy: 0.7244 - val_track_epsilon: 0.0390
Epoch 12/30
597/597 [==============================] - 0s 57us/step - loss: 159.8472 - regression_loss: 61.2275 - binary_classification_loss: 24.2195 - treatment_accuracy: 0.8497 - track_epsilon: 0.0383 - val_loss: 190.8406 - val_regression_loss: 64.3669 - val_binary_classification_loss: 35.1488 - val_treatment_accuracy: 0.7244 - val_track_epsilon: 0.0372
Train on 597 samples, validate on 150 samples
Epoch 1/300
597/597 [==============================] - 1s 1ms/step - loss: 151.0525 - regression_loss: 58.8814 - binary_classification_loss: 24.1191 - treatment_accuracy: 0.8529 - track_epsilon: 0.0377 - val_loss: 184.5767 - val_regression_loss: 59.0096 - val_binary_classification_loss: 35.9360 - val_treatment_accuracy: 0.7244 - val_track_epsilon: 0.0390
Epoch 2/300
597/597 [==============================] - 0s 67us/step - loss: 150.0762 - regression_loss: 56.8447 - binary_classification_loss: 24.1029 - treatment_accuracy: 0.8497 - track_epsilon: 0.0326 - val_loss: 184.1211 - val_regression_loss: 59.6037 - val_binary_classification_loss: 35.7386 - val_treatment_accuracy: 0.7244 - val_track_epsilon: 0.0227
Epoch 3/300
597/597 [==============================] - 0s 67us/step - loss: 149.7746 - regression_loss: 57.0947 - binary_classification_loss: 24.0785 - treatment_accuracy: 0.8561 - track_epsilon: 0.0181 - val_loss: 181.9517 - val_regression_loss: 59.1016 - val_binary_classification_loss: 35.8648 - val_treatment_accuracy: 0.7244 - val_track_epsilon: 0.0142
Epoch 4/300
597/597 [==============================] - 0s 68us/step - loss: 148.7084 - regression_loss: 57.0758 - binary_classification_loss: 24.0558 - treatment_accuracy: 0.8561 - track_epsilon: 0.0127 - val_loss: 182.8566 - val_regression_loss: 59.1128 - val_binary_classification_loss: 35.9316 - val_treatment_accuracy: 0.7244 - val_track_epsilon: 0.0076
Epoch 5/300
597/597 [==============================] - 0s 69us/step - loss: 150.4725 - regression_loss: 56.8933 - binary_classification_loss: 24.0455 - treatment_accuracy: 0.8529 - track_epsilon: 0.0040 - val_loss: 182.9057 - val_regression_loss: 59.1165 - val_binary_classification_loss: 36.0808 - val_treatment_accuracy: 0.7244 - val_track_epsilon: 0.0022
Epoch 6/300
597/597 [==============================] - 0s 68us/step - loss: 147.7774 - regression_loss: 56.4606 - binary_classification_loss: 24.0391 - treatment_accuracy: 0.8593 - track_epsilon: 0.0013 - val_loss: 183.9675 - val_regression_loss: 59.4084 - val_binary_classification_loss: 36.1876 - val_treatment_accuracy: 0.7244 - val_track_epsilon: 0.0024
Epoch 7/300
597/597 [==============================] - 0s 67us/step - loss: 149.8826 - regression_loss: 57.2671 - binary_classification_loss: 24.0319 - treatment_accuracy: 0.8529 - track_epsilon: 0.0028 - val_loss: 186.5590 - val_regression_loss: 60.4098 - val_binary_classification_loss: 36.1753 - val_treatment_accuracy: 0.7244 - val_track_epsilon: 4.0377e-04
Epoch 8/300
597/597 [==============================] - 0s 70us/step - loss: 148.1314 - regression_loss: 56.3730 - binary_classification_loss: 24.0128 - treatment_accuracy: 0.8561 - track_epsilon: 0.0021 - val_loss: 183.1079 - val_regression_loss: 59.5408 - val_binary_classification_loss: 36.1076 - val_treatment_accuracy: 0.7244 - val_track_epsilon: 0.0031
Epoch 9/300
597/597 [==============================] - 0s 70us/step - loss: 148.6218 - regression_loss: 56.6761 - binary_classification_loss: 23.9945 - treatment_accuracy: 0.8561 - track_epsilon: 0.0017 - val_loss: 183.6684 - val_regression_loss: 59.4958 - val_binary_classification_loss: 36.1848 - val_treatment_accuracy: 0.7244 - val_track_epsilon: 0.0041
Epoch 10/300
597/597 [==============================] - 0s 80us/step - loss: 147.2199 - regression_loss: 55.6598 - binary_classification_loss: 23.9914 - treatment_accuracy: 0.8561 - track_epsilon: 0.0037 - val_loss: 187.5044 - val_regression_loss: 60.6762 - val_binary_classification_loss: 36.0621 - val_treatment_accuracy: 0.7244 - val_track_epsilon: 0.0018
Epoch 11/300
597/597 [==============================] - 0s 67us/step - loss: 149.2038 - regression_loss: 56.6065 - binary_classification_loss: 23.9720 - treatment_accuracy: 0.8465 - track_epsilon: 0.0028 - val_loss: 185.2099 - val_regression_loss: 59.7681 - val_binary_classification_loss: 36.0292 - val_treatment_accuracy: 0.7244 - val_track_epsilon: 1.1392e-04
Epoch 12/300
597/597 [==============================] - 0s 65us/step - loss: 144.7243 - regression_loss: 56.0806 - binary_classification_loss: 23.9684 - treatment_accuracy: 0.8401 - track_epsilon: 0.0012 - val_loss: 182.7289 - val_regression_loss: 59.1949 - val_binary_classification_loss: 36.1361 - val_treatment_accuracy: 0.7244 - val_track_epsilon: 0.0018
Epoch 13/300
597/597 [==============================] - 0s 82us/step - loss: 146.8869 - regression_loss: 56.1869 - binary_classification_loss: 23.9454 - treatment_accuracy: 0.8593 - track_epsilon: 0.0012 - val_loss: 181.4378 - val_regression_loss: 58.5800 - val_binary_classification_loss: 36.0047 - val_treatment_accuracy: 0.7244 - val_track_epsilon: 0.0020
Epoch 14/300
597/597 [==============================] - 0s 64us/step - loss: 145.5166 - regression_loss: 55.1947 - binary_classification_loss: 23.9264 - treatment_accuracy: 0.8497 - track_epsilon: 0.0028 - val_loss: 183.9117 - val_regression_loss: 59.8171 - val_binary_classification_loss: 36.0495 - val_treatment_accuracy: 0.7244 - val_track_epsilon: 0.0023
Epoch 15/300
597/597 [==============================] - 0s 67us/step - loss: 147.9824 - regression_loss: 55.9960 - binary_classification_loss: 23.9193 - treatment_accuracy: 0.8561 - track_epsilon: 0.0013 - val_loss: 184.8934 - val_regression_loss: 60.1771 - val_binary_classification_loss: 36.0228 - val_treatment_accuracy: 0.7244 - val_track_epsilon: 0.0027
Epoch 16/300
597/597 [==============================] - 0s 67us/step - loss: 146.7458 - regression_loss: 55.8055 - binary_classification_loss: 23.8981 - treatment_accuracy: 0.8561 - track_epsilon: 0.0022 - val_loss: 184.1797 - val_regression_loss: 59.5255 - val_binary_classification_loss: 35.9737 - val_treatment_accuracy: 0.7244 - val_track_epsilon: 9.6159e-04
Epoch 17/300
597/597 [==============================] - 0s 76us/step - loss: 145.5521 - regression_loss: 55.3490 - binary_classification_loss: 23.8978 - treatment_accuracy: 0.8529 - track_epsilon: 0.0014 - val_loss: 183.2418 - val_regression_loss: 59.2208 - val_binary_classification_loss: 35.7738 - val_treatment_accuracy: 0.7244 - val_track_epsilon: 0.0039

Epoch 00017: ReduceLROnPlateau reducing learning rate to 4.999999873689376e-06.
Epoch 18/300
597/597 [==============================] - 0s 74us/step - loss: 144.7616 - regression_loss: 54.9449 - binary_classification_loss: 23.8797 - treatment_accuracy: 0.8561 - track_epsilon: 0.0050 - val_loss: 183.1350 - val_regression_loss: 59.2228 - val_binary_classification_loss: 35.7351 - val_treatment_accuracy: 0.7244 - val_track_epsilon: 0.0039
Epoch 19/300
597/597 [==============================] - 0s 67us/step - loss: 141.8471 - regression_loss: 54.6760 - binary_classification_loss: 23.8693 - treatment_accuracy: 0.8561 - track_epsilon: 0.0020 - val_loss: 182.4961 - val_regression_loss: 59.0138 - val_binary_classification_loss: 35.8385 - val_treatment_accuracy: 0.7244 - val_track_epsilon: 2.2382e-04
Epoch 20/300
597/597 [==============================] - 0s 75us/step - loss: 143.4988 - regression_loss: 54.6465 - binary_classification_loss: 23.8661 - treatment_accuracy: 0.8593 - track_epsilon: 9.6414e-04 - val_loss: 183.4780 - val_regression_loss: 59.2525 - val_binary_classification_loss: 35.8081 - val_treatment_accuracy: 0.7244 - val_track_epsilon: 6.3370e-04
Epoch 21/300
597/597 [==============================] - 0s 69us/step - loss: 143.2713 - regression_loss: 54.8240 - binary_classification_loss: 23.8655 - treatment_accuracy: 0.8529 - track_epsilon: 5.8381e-04 - val_loss: 182.7529 - val_regression_loss: 59.1405 - val_binary_classification_loss: 35.8905 - val_treatment_accuracy: 0.7244 - val_track_epsilon: 0.0014
Epoch 22/300
597/597 [==============================] - 0s 73us/step - loss: 144.5639 - regression_loss: 54.9520 - binary_classification_loss: 23.8562 - treatment_accuracy: 0.8497 - track_epsilon: 0.0011 - val_loss: 182.2272 - val_regression_loss: 58.9541 - val_binary_classification_loss: 35.8026 - val_treatment_accuracy: 0.7244 - val_track_epsilon: 0.0020
Epoch 23/300
597/597 [==============================] - 0s 88us/step - loss: 144.3322 - regression_loss: 54.4709 - binary_classification_loss: 23.8485 - treatment_accuracy: 0.8465 - track_epsilon: 0.0033 - val_loss: 183.0935 - val_regression_loss: 59.1250 - val_binary_classification_loss: 35.7517 - val_treatment_accuracy: 0.7244 - val_track_epsilon: 0.0026
Epoch 24/300
597/597 [==============================] - 0s 65us/step - loss: 143.6903 - regression_loss: 54.4800 - binary_classification_loss: 23.8423 - treatment_accuracy: 0.8561 - track_epsilon: 0.0013 - val_loss: 182.7994 - val_regression_loss: 59.0775 - val_binary_classification_loss: 35.7825 - val_treatment_accuracy: 0.7244 - val_track_epsilon: 0.0011

Epoch 00024: ReduceLROnPlateau reducing learning rate to 2.499999936844688e-06.
Epoch 25/300
597/597 [==============================] - 0s 69us/step - loss: 142.5934 - regression_loss: 54.3459 - binary_classification_loss: 23.8378 - treatment_accuracy: 0.8529 - track_epsilon: 0.0012 - val_loss: 182.6808 - val_regression_loss: 59.0681 - val_binary_classification_loss: 35.7840 - val_treatment_accuracy: 0.7244 - val_track_epsilon: 0.0015
Epoch 26/300
597/597 [==============================] - 0s 69us/step - loss: 144.1265 - regression_loss: 54.4636 - binary_classification_loss: 23.8337 - treatment_accuracy: 0.8593 - track_epsilon: 0.0011 - val_loss: 183.0977 - val_regression_loss: 59.1001 - val_binary_classification_loss: 35.7414 - val_treatment_accuracy: 0.7244 - val_track_epsilon: 0.0011
Epoch 27/300
597/597 [==============================] - 0s 67us/step - loss: 143.5707 - regression_loss: 54.1999 - binary_classification_loss: 23.8293 - treatment_accuracy: 0.8497 - track_epsilon: 0.0016 - val_loss: 182.1685 - val_regression_loss: 58.8281 - val_binary_classification_loss: 35.7402 - val_treatment_accuracy: 0.7244 - val_track_epsilon: 0.0019
Epoch 28/300
597/597 [==============================] - 0s 70us/step - loss: 144.1436 - regression_loss: 54.1982 - binary_classification_loss: 23.8266 - treatment_accuracy: 0.8561 - track_epsilon: 0.0018 - val_loss: 182.2616 - val_regression_loss: 58.8418 - val_binary_classification_loss: 35.7468 - val_treatment_accuracy: 0.7244 - val_track_epsilon: 0.0017
Epoch 29/300
597/597 [==============================] - 0s 69us/step - loss: 143.1436 - regression_loss: 54.2246 - binary_classification_loss: 23.8253 - treatment_accuracy: 0.8497 - track_epsilon: 0.0017 - val_loss: 182.5233 - val_regression_loss: 58.9060 - val_binary_classification_loss: 35.7543 - val_treatment_accuracy: 0.7244 - val_track_epsilon: 0.0017

Epoch 00029: ReduceLROnPlateau reducing learning rate to 1.249999968422344e-06.
Epoch 30/300
597/597 [==============================] - 0s 69us/step - loss: 142.9970 - regression_loss: 54.0639 - binary_classification_loss: 23.8208 - treatment_accuracy: 0.8625 - track_epsilon: 0.0016 - val_loss: 182.8976 - val_regression_loss: 59.0591 - val_binary_classification_loss: 35.7240 - val_treatment_accuracy: 0.7244 - val_track_epsilon: 0.0017
Epoch 31/300
597/597 [==============================] - 0s 67us/step - loss: 143.8003 - regression_loss: 54.1442 - binary_classification_loss: 23.8190 - treatment_accuracy: 0.8529 - track_epsilon: 0.0021 - val_loss: 182.6798 - val_regression_loss: 59.0072 - val_binary_classification_loss: 35.7270 - val_treatment_accuracy: 0.7244 - val_track_epsilon: 0.0022
Epoch 32/300
597/597 [==============================] - 0s 66us/step - loss: 143.4029 - regression_loss: 54.1355 - binary_classification_loss: 23.8157 - treatment_accuracy: 0.8561 - track_epsilon: 0.0023 - val_loss: 182.5541 - val_regression_loss: 58.9682 - val_binary_classification_loss: 35.7180 - val_treatment_accuracy: 0.7244 - val_track_epsilon: 0.0020
Epoch 33/300
597/597 [==============================] - 0s 73us/step - loss: 142.1901 - regression_loss: 54.0516 - binary_classification_loss: 23.8148 - treatment_accuracy: 0.8529 - track_epsilon: 0.0018 - val_loss: 183.0714 - val_regression_loss: 59.1151 - val_binary_classification_loss: 35.7216 - val_treatment_accuracy: 0.7244 - val_track_epsilon: 0.0015
Epoch 34/300
597/597 [==============================] - 0s 65us/step - loss: 140.2360 - regression_loss: 54.0345 - binary_classification_loss: 23.8139 - treatment_accuracy: 0.8497 - track_epsilon: 0.0016 - val_loss: 182.7475 - val_regression_loss: 59.0084 - val_binary_classification_loss: 35.7426 - val_treatment_accuracy: 0.7244 - val_track_epsilon: 0.0015
Epoch 35/300
597/597 [==============================] - 0s 81us/step - loss: 142.8741 - regression_loss: 54.0038 - binary_classification_loss: 23.8122 - treatment_accuracy: 0.8433 - track_epsilon: 0.0013 - val_loss: 182.6587 - val_regression_loss: 58.9828 - val_binary_classification_loss: 35.7345 - val_treatment_accuracy: 0.7244 - val_track_epsilon: 0.0014
Epoch 36/300
597/597 [==============================] - 0s 73us/step - loss: 143.2542 - regression_loss: 54.0470 - binary_classification_loss: 23.8112 - treatment_accuracy: 0.8497 - track_epsilon: 0.0015 - val_loss: 182.7340 - val_regression_loss: 59.0171 - val_binary_classification_loss: 35.7291 - val_treatment_accuracy: 0.7244 - val_track_epsilon: 0.0016
Epoch 37/300
597/597 [==============================] - 0s 63us/step - loss: 143.1216 - regression_loss: 53.9242 - binary_classification_loss: 23.8101 - treatment_accuracy: 0.8497 - track_epsilon: 0.0018 - val_loss: 182.6380 - val_regression_loss: 58.9966 - val_binary_classification_loss: 35.7090 - val_treatment_accuracy: 0.7244 - val_track_epsilon: 0.0019
Epoch 38/300
597/597 [==============================] - 0s 74us/step - loss: 142.9598 - regression_loss: 53.9560 - binary_classification_loss: 23.8082 - treatment_accuracy: 0.8497 - track_epsilon: 0.0019 - val_loss: 182.5107 - val_regression_loss: 58.9566 - val_binary_classification_loss: 35.7025 - val_treatment_accuracy: 0.7244 - val_track_epsilon: 0.0018
Epoch 39/300
597/597 [==============================] - 0s 70us/step - loss: 142.1619 - regression_loss: 53.9813 - binary_classification_loss: 23.8070 - treatment_accuracy: 0.8497 - track_epsilon: 0.0015 - val_loss: 182.6606 - val_regression_loss: 58.9962 - val_binary_classification_loss: 35.7107 - val_treatment_accuracy: 0.7244 - val_track_epsilon: 0.0015

Epoch 00039: ReduceLROnPlateau reducing learning rate to 6.24999984211172e-07.
Epoch 40/300
597/597 [==============================] - 0s 65us/step - loss: 143.1522 - regression_loss: 53.9099 - binary_classification_loss: 23.8051 - treatment_accuracy: 0.8561 - track_epsilon: 0.0017 - val_loss: 182.5675 - val_regression_loss: 58.9788 - val_binary_classification_loss: 35.6982 - val_treatment_accuracy: 0.7244 - val_track_epsilon: 0.0018
Epoch 41/300
597/597 [==============================] - 0s 70us/step - loss: 142.9669 - regression_loss: 54.0113 - binary_classification_loss: 23.8046 - treatment_accuracy: 0.8465 - track_epsilon: 0.0017 - val_loss: 182.7173 - val_regression_loss: 59.0139 - val_binary_classification_loss: 35.6968 - val_treatment_accuracy: 0.7244 - val_track_epsilon: 0.0018
Epoch 42/300
597/597 [==============================] - 0s 71us/step - loss: 142.9812 - regression_loss: 53.9480 - binary_classification_loss: 23.8039 - treatment_accuracy: 0.8625 - track_epsilon: 0.0019 - val_loss: 182.5140 - val_regression_loss: 58.9547 - val_binary_classification_loss: 35.7111 - val_treatment_accuracy: 0.7244 - val_track_epsilon: 0.0020
Epoch 43/300
597/597 [==============================] - 0s 73us/step - loss: 143.3660 - regression_loss: 53.9199 - binary_classification_loss: 23.8027 - treatment_accuracy: 0.8529 - track_epsilon: 0.0018 - val_loss: 182.6215 - val_regression_loss: 58.9790 - val_binary_classification_loss: 35.7084 - val_treatment_accuracy: 0.7244 - val_track_epsilon: 0.0018
Epoch 44/300
597/597 [==============================] - 0s 80us/step - loss: 142.2327 - regression_loss: 53.8626 - binary_classification_loss: 23.8023 - treatment_accuracy: 0.8561 - track_epsilon: 0.0019 - val_loss: 182.6031 - val_regression_loss: 58.9846 - val_binary_classification_loss: 35.7026 - val_treatment_accuracy: 0.7244 - val_track_epsilon: 0.0019

Epoch 00044: ReduceLROnPlateau reducing learning rate to 3.12499992105586e-07.
Epoch 45/300
597/597 [==============================] - 0s 70us/step - loss: 141.4800 - regression_loss: 53.8688 - binary_classification_loss: 23.8012 - treatment_accuracy: 0.8497 - track_epsilon: 0.0018 - val_loss: 182.5443 - val_regression_loss: 58.9709 - val_binary_classification_loss: 35.7019 - val_treatment_accuracy: 0.7244 - val_track_epsilon: 0.0019
Epoch 46/300
597/597 [==============================] - 0s 75us/step - loss: 141.3155 - regression_loss: 53.8269 - binary_classification_loss: 23.8007 - treatment_accuracy: 0.8497 - track_epsilon: 0.0018 - val_loss: 182.6176 - val_regression_loss: 58.9883 - val_binary_classification_loss: 35.7016 - val_treatment_accuracy: 0.7244 - val_track_epsilon: 0.0018
Epoch 47/300
597/597 [==============================] - 0s 73us/step - loss: 143.1740 - regression_loss: 53.8460 - binary_classification_loss: 23.8005 - treatment_accuracy: 0.8497 - track_epsilon: 0.0018 - val_loss: 182.6459 - val_regression_loss: 59.0014 - val_binary_classification_loss: 35.6936 - val_treatment_accuracy: 0.7244 - val_track_epsilon: 0.0018
Epoch 48/300
597/597 [==============================] - 0s 85us/step - loss: 142.8012 - regression_loss: 53.8343 - binary_classification_loss: 23.7998 - treatment_accuracy: 0.8593 - track_epsilon: 0.0019 - val_loss: 182.6606 - val_regression_loss: 59.0031 - val_binary_classification_loss: 35.6939 - val_treatment_accuracy: 0.7244 - val_track_epsilon: 0.0019
Epoch 49/300
597/597 [==============================] - 0s 71us/step - loss: 142.5543 - regression_loss: 53.8559 - binary_classification_loss: 23.7995 - treatment_accuracy: 0.8497 - track_epsilon: 0.0019 - val_loss: 182.6408 - val_regression_loss: 58.9984 - val_binary_classification_loss: 35.6932 - val_treatment_accuracy: 0.7244 - val_track_epsilon: 0.0019

Epoch 00049: ReduceLROnPlateau reducing learning rate to 1.56249996052793e-07.
Epoch 50/300
597/597 [==============================] - 0s 74us/step - loss: 142.2061 - regression_loss: 53.8305 - binary_classification_loss: 23.7990 - treatment_accuracy: 0.8465 - track_epsilon: 0.0019 - val_loss: 182.6622 - val_regression_loss: 59.0048 - val_binary_classification_loss: 35.6935 - val_treatment_accuracy: 0.7244 - val_track_epsilon: 0.0019
Epoch 51/300
597/597 [==============================] - 0s 71us/step - loss: 144.1153 - regression_loss: 53.8202 - binary_classification_loss: 23.7987 - treatment_accuracy: 0.8561 - track_epsilon: 0.0019 - val_loss: 182.6417 - val_regression_loss: 58.9983 - val_binary_classification_loss: 35.6931 - val_treatment_accuracy: 0.7244 - val_track_epsilon: 0.0019
Epoch 52/300
597/597 [==============================] - 0s 75us/step - loss: 142.2625 - regression_loss: 53.8170 - binary_classification_loss: 23.7987 - treatment_accuracy: 0.8497 - track_epsilon: 0.0018 - val_loss: 182.6301 - val_regression_loss: 58.9968 - val_binary_classification_loss: 35.6929 - val_treatment_accuracy: 0.7244 - val_track_epsilon: 0.0018
Epoch 53/300
597/597 [==============================] - 0s 77us/step - loss: 142.2695 - regression_loss: 53.8087 - binary_classification_loss: 23.7985 - treatment_accuracy: 0.8593 - track_epsilon: 0.0018 - val_loss: 182.6330 - val_regression_loss: 58.9989 - val_binary_classification_loss: 35.6917 - val_treatment_accuracy: 0.7244 - val_track_epsilon: 0.0018
In [12]:
df_preds = pd.DataFrame([s_ite.ravel(),
                          t_ite.ravel(),
                          x_ite.ravel(),
                          r_ite.ravel(),
                          dragon_ite.ravel(),
                          tau.ravel(),
                          treatment.ravel(),
                          y.ravel()],
                       index=['S','T','X','R','dragonnet','tau','w','y']).T

df_cumgain = get_cumgain(df_preds)
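Note that get_cumgain (and auuc_score below) looks, by default, for columns named 'y', 'w', and, when available, the true effect 'tau', which is why df_preds is assembled with exactly those names. Conceptually, the gain curve ranks rows by each model's predicted effect and accumulates the known true effect; a sketch for a single score column:

# Hedged sketch of the cumulative-gain idea for one score column
ranked = df_preds.sort_values('S', ascending=False)
manual_gain = ranked['tau'].cumsum().values  # true effect, accumulated best-first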
In [13]:
df_result = pd.DataFrame([s_ate, t_ate, x_ate, r_ate, dragon_ate, tau.mean()],
                     index=['S','T','X','R','dragonnet','actual'], columns=['ATE'])
df_result['MAE'] = [mean_absolute_error(t,p) for t,p in zip([s_ite, t_ite, x_ite, r_ite, dragon_ite],
                                                            [tau.values.reshape(-1,1)]*5 )
                ] + [None]
df_result['AUUC'] = auuc_score(df_preds)
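auuc_score summarizes each gain curve as the normalized area under it. A rough cross-check of that reduction (an assumption about the internals, using the same default column names):

# Hedged cross-check: AUUC ~ mean of the normalized cumulative gain
cumgain_norm = get_cumgain(df_preds, normalize=True)
manual_auuc = cumgain_norm.sum() / cumgain_norm.shape[0]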
In [14]:
df_result
Out[14]:
ATE MAE AUUC
S 4.054511 1.027666 0.575822
T 4.100199 0.980788 0.580929
X 4.020589 1.115693 0.564634
R 3.867016 2.033445 0.557536
dragonnet 4.003578 1.182555 0.553948
actual 4.098887 NaN NaN
In [15]:
plot_gain(df_preds)

causalml Synthetic Data Generation Method

In [16]:
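# simulate_nuisance_and_easy_treatment (from causalml.dataset) follows the
# "difficult nuisance components, easy treatment effect" synthetic setup:
# it returns the observed outcome y, covariates X, treatment w, the true
# effect tau, the baseline main effect b(X), and the true propensity
# e = P(W=1|X).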
y, X, w, tau, b, e = simulate_nuisance_and_easy_treatment(n=1000)

X_train, X_val, y_train, y_val, w_train, w_val, tau_train, tau_val, b_train, b_val, e_train, e_val = \
    train_test_split(X, y, w, tau, b, e, test_size=0.2, random_state=123, shuffle=True)

preds_dict_train = {}
preds_dict_valid = {}

preds_dict_train['Actuals'] = tau_train
preds_dict_valid['Actuals'] = tau_val

preds_dict_train['generated_data'] = {
    'y': y_train,
    'X': X_train,
    'w': w_train,
    'tau': tau_train,
    'b': b_train,
    'e': e_train}
preds_dict_valid['generated_data'] = {
    'y': y_val,
    'X': X_val,
    'w': w_val,
    'tau': tau_val,
    'b': b_val,
    'e': e_val}

# Estimate p_hat because the true propensity e would not be directly observed in real life
p_model = ElasticNetPropensityModel()
p_hat_train = p_model.fit_predict(X_train, w_train)
p_hat_val = p_model.fit_predict(X_val, w_val)
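# Hedged aside: refitting the propensity model on the validation split, as
# above, peeks at validation treatment labels; a stricter protocol would
# reuse the train-fit model, e.g.:
#     p_model.fit(X_train, w_train)
#     p_hat_val = p_model.predict(X_val)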

for base_learner, label_l in zip([BaseSRegressor, BaseTRegressor, BaseXRegressor, BaseRRegressor],
                                 ['S', 'T', 'X', 'R']):
    for model, label_m in zip([LinearRegression, XGBRegressor], ['LR', 'XGB']):
        # The R-learner needs the propensity score p_hat at fit time
        if label_l != 'R':
            learner = base_learner(model())
            # fit the model on training data only
            learner.fit(X=X_train, treatment=w_train, y=y_train)
            try:
                preds_dict_train['{} Learner ({})'.format(
                    label_l, label_m)] = learner.predict(X=X_train, p=p_hat_train).flatten()
                preds_dict_valid['{} Learner ({})'.format(
                    label_l, label_m)] = learner.predict(X=X_val, p=p_hat_val).flatten()
            except TypeError:
                preds_dict_train['{} Learner ({})'.format(
                    label_l, label_m)] = learner.predict(X=X_train, treatment=w_train, y=y_train).flatten()
                preds_dict_valid['{} Learner ({})'.format(
                    label_l, label_m)] = learner.predict(X=X_val, treatment=w_val, y=y_val).flatten()
        else:
            learner = base_learner(model())
            learner.fit(X=X_train, p=p_hat_train, treatment=w_train, y=y_train)
            preds_dict_train['{} Learner ({})'.format(
                label_l, label_m)] = learner.predict(X=X_train).flatten()
            preds_dict_valid['{} Learner ({})'.format(
                label_l, label_m)] = learner.predict(X=X_val).flatten()

learner = DragonNet(verbose=False)
learner.fit(X_train, treatment=w_train, y=y_train)
preds_dict_train['DragonNet'] = learner.predict_tau(X=X_train).flatten()
preds_dict_valid['DragonNet'] = learner.predict_tau(X=X_val).flatten()
In [17]:
actuals_train = preds_dict_train['Actuals']
actuals_validation = preds_dict_valid['Actuals']

synthetic_summary_train = pd.DataFrame({label: [preds.mean(), mse(preds, actuals_train)] for label, preds
                                        in preds_dict_train.items() if 'generated' not in label.lower()},
                                       index=['ATE', 'MSE']).T
synthetic_summary_train['Abs % Error of ATE'] = np.abs(
    (synthetic_summary_train['ATE']/synthetic_summary_train.loc['Actuals', 'ATE']) - 1)

synthetic_summary_validation = pd.DataFrame({label: [preds.mean(), mse(preds, actuals_validation)]
                                             for label, preds in preds_dict_valid.items()
                                             if 'generated' not in label.lower()},
                                            index=['ATE', 'MSE']).T
synthetic_summary_validation['Abs % Error of ATE'] = np.abs(
    (synthetic_summary_validation['ATE']/synthetic_summary_validation.loc['Actuals', 'ATE']) - 1)

# calculate kl divergence for training
for label in synthetic_summary_train.index:
    stacked_values = np.hstack((preds_dict_train[label], actuals_train))
    stacked_low = np.percentile(stacked_values, 0.1)
    stacked_high = np.percentile(stacked_values, 99.9)
    bins = np.linspace(stacked_low, stacked_high, 100)

    distr = np.histogram(preds_dict_train[label], bins=bins)[0]
    distr = np.clip(distr/distr.sum(), 0.001, 0.999)
    true_distr = np.histogram(actuals_train, bins=bins)[0]
    true_distr = np.clip(true_distr/true_distr.sum(), 0.001, 0.999)

    kl = entropy(distr, true_distr)
    synthetic_summary_train.loc[label, 'KL Divergence'] = kl

# calculate kl divergence for validation
for label in synthetic_summary_validation.index:
    stacked_values = np.hstack((preds_dict_valid[label], actuals_validation))
    stacked_low = np.percentile(stacked_values, 0.1)
    stacked_high = np.percentile(stacked_values, 99.9)
    bins = np.linspace(stacked_low, stacked_high, 100)

    distr = np.histogram(preds_dict_valid[label], bins=bins)[0]
    distr = np.clip(distr/distr.sum(), 0.001, 0.999)
    true_distr = np.histogram(actuals_validation, bins=bins)[0]
    true_distr = np.clip(true_distr/true_distr.sum(), 0.001, 0.999)

    kl = entropy(distr, true_distr)
    synthetic_summary_validation.loc[label, 'KL Divergence'] = kl
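The two loops above differ only in which split they score; a small helper with the same binning and clipping logic (a sketch, not part of causalml) would remove the duplication:

def kl_to_actuals(preds, actuals, n_bins=100):
    """KL divergence between the prediction and actuals histograms,
    using the same percentile binning and clipping as the loops above."""
    stacked = np.hstack((preds, actuals))
    bins = np.linspace(np.percentile(stacked, 0.1),
                       np.percentile(stacked, 99.9), n_bins)
    distr = np.histogram(preds, bins=bins)[0]
    distr = np.clip(distr / distr.sum(), 0.001, 0.999)
    true_distr = np.histogram(actuals, bins=bins)[0]
    true_distr = np.clip(true_distr / true_distr.sum(), 0.001, 0.999)
    return entropy(distr, true_distr)

# e.g. synthetic_summary_train.loc[label, 'KL Divergence'] = \
#          kl_to_actuals(preds_dict_train[label], actuals_train)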
In [18]:
df_preds_train = pd.DataFrame([preds_dict_train['S Learner (LR)'].ravel(),
                               preds_dict_train['S Learner (XGB)'].ravel(),
                               preds_dict_train['T Learner (LR)'].ravel(),
                               preds_dict_train['T Learner (XGB)'].ravel(),
                               preds_dict_train['X Learner (LR)'].ravel(),
                               preds_dict_train['X Learner (XGB)'].ravel(),
                               preds_dict_train['R Learner (LR)'].ravel(),
                               preds_dict_train['R Learner (XGB)'].ravel(),                               
                               preds_dict_train['DragonNet'].ravel(),
                               preds_dict_train['generated_data']['tau'].ravel(),
                               preds_dict_train['generated_data']['w'].ravel(),
                               preds_dict_train['generated_data']['y'].ravel()],
                              index=['S Learner (LR)','S Learner (XGB)',
                                     'T Learner (LR)','T Learner (XGB)',
                                     'X Learner (LR)','X Learner (XGB)',
                                     'R Learner (LR)','R Learner (XGB)',
                                     'DragonNet','tau','w','y']).T

synthetic_summary_train['AUUC'] = auuc_score(df_preds_train).iloc[:-1]
In [19]:
df_preds_validation = pd.DataFrame([preds_dict_valid['S Learner (LR)'].ravel(),
                               preds_dict_valid['S Learner (XGB)'].ravel(),
                               preds_dict_valid['T Learner (LR)'].ravel(),
                               preds_dict_valid['T Learner (XGB)'].ravel(),
                               preds_dict_valid['X Learner (LR)'].ravel(),
                               preds_dict_valid['X Learner (XGB)'].ravel(),
                               preds_dict_valid['R Learner (LR)'].ravel(),
                               preds_dict_valid['R Learner (XGB)'].ravel(),                               
                               preds_dict_valid['DragonNet'].ravel(),
                               preds_dict_valid['generated_data']['tau'].ravel(),
                               preds_dict_valid['generated_data']['w'].ravel(),
                               preds_dict_valid['generated_data']['y'].ravel()],
                              index=['S Learner (LR)','S Learner (XGB)',
                                     'T Learner (LR)','T Learner (XGB)',
                                     'X Learner (LR)','X Learner (XGB)',
                                     'R Learner (LR)','R Learner (XGB)',
                                     'DragonNet','tau','w','y']).T

synthetic_summary_validation['AUUC'] = auuc_score(df_preds_validation).iloc[:-1]
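The two prediction frames are assembled identically; a small builder (a sketch under the same naming assumptions) captures the pattern:

def build_preds_frame(preds_dict):
    """Collect model scores plus the generated tau/w/y into one frame,
    matching the layout used in the two cells above."""
    data = {label: np.ravel(scores) for label, scores in preds_dict.items()
            if label not in ('Actuals', 'generated_data')}
    for col in ('tau', 'w', 'y'):
        data[col] = np.ravel(preds_dict['generated_data'][col])
    return pd.DataFrame(data)

# df_preds_train = build_preds_frame(preds_dict_train)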
In [20]:
synthetic_summary_train
Out[20]:
ATE MSE Abs % Error of ATE KL Divergence AUUC
Actuals 0.484486 0.000000 0.000000 0.000000 NaN
S Learner (LR) 0.528743 0.044194 0.091349 3.473087 0.492660
S Learner (XGB) 0.358208 0.310652 0.260643 0.817620 0.544115
T Learner (LR) 0.493815 0.022688 0.019255 0.289978 0.610855
T Learner (XGB) 0.397053 1.350928 0.180465 1.452143 0.521719
X Learner (LR) 0.493815 0.022688 0.019255 0.289978 0.610855
X Learner (XGB) 0.341013 0.620823 0.296134 1.098308 0.534908
R Learner (LR) 0.471610 0.030968 0.026577 0.378494 0.614607
R Learner (XGB) 0.413902 4.850255 0.145688 1.950556 0.510872
DragonNet 0.415214 0.038613 0.142980 0.405291 0.612157
In [21]:
synthetic_summary_validation
Out[21]:
ATE MSE Abs % Error of ATE KL Divergence AUUC
Actuals 0.511242 0.000000 0.000000 0.000000 NaN
S Learner (LR) 0.528743 0.042236 0.034233 4.574498 0.494022
S Learner (XGB) 0.434208 0.260496 0.150680 0.854890 0.544212
T Learner (LR) 0.541503 0.025840 0.059191 0.686602 0.604712
T Learner (XGB) 0.483404 0.679398 0.054451 1.215394 0.526918
X Learner (LR) 0.541503 0.025840 0.059191 0.686602 0.604712
X Learner (XGB) 0.328046 0.352812 0.358335 1.310631 0.535895
R Learner (LR) 0.526797 0.034872 0.030426 0.732823 0.608290
R Learner (XGB) 0.377533 2.174835 0.261537 1.734253 0.512412
DragonNet 0.464221 0.037349 0.091973 0.695660 0.606139
In [22]:
plot_gain(df_preds_train)
In [23]:
plot_gain(df_preds_validation)