This introductory tutorial walks through the steps of a CausalEGM workflow.
After installation, CausalEGM can be used through the Python API, the R API, or a single command line.
First, install CausalEGM; please refer to the install page.
Use CausalEGM Python API
import CausalEGM as cegm
print("Currently use version v%s of CausalEGM."%cegm.__version__)
Currently use version v0.4.0 of CausalEGM.
Configuring a CausalEGM model
Before creating a CausalEGM model, create a Python dict object that holds the model's hyperparameters, including the dimensions of the latent features and the neural network architecture.
The hyperparameters are described below.
| Config Parameters | Description |
|---|---|
| output_dir | Output directory to save the results during the model training. Default: "." |
| dataset | Dataset name for indicating the input data. Default: "Mydata" |
| z_dims | Latent dimensions of the encoder outputs (e(V)_0~3). Default: [3,6,3,6] |
| v_dim | Dimension of covariates. |
| lr | Learning rate. Default: 0.0002 |
| g_units | Number of units for decoder/generator network G. Default: [64,64,64,64,64]. |
| e_units | Number of units for encoder network E. Default: [64,64,64,64,64]. |
| f_units | Number of units for F network. Default: [64,32,8]. |
| h_units | Number of units for H network. Default: [64,32,8]. |
| dz_units | Number of units for discriminator network in latent space. Default: [64,32,8]. |
| dv_units | Number of units for discriminator network in covariate space. Default: [64,32,8]. |
| alpha | Coefficient for reconstruction loss. Default: 1. |
| beta | Coefficient for roundtrip loss. Default: 1. |
| gamma | Coefficient for gradient penalty loss. Default: 10. |
| g_d_freq | Frequency for updating discriminators and generators. Default: 5. |
| save_res | Whether to save results during the model training. Default: True. |
| save_model | Whether to save the model weights. Default: False. |
| binary_treatment | Whether to use binary treatment setting. Default: True. |
| use_z_rec | Use the reconstruction for latent features. Default: True. |
| use_v_gan | Use the GAN distribution match for covariates. Default: True. |
| x_min | Left bound for dose-response interval in continuous treatment settings. Default: 0. |
| x_max | Right bound for dose-response interval in continuous treatment settings. Default: 3. |
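As a rough illustration, the defaults listed above can be collected into a plain Python dict. This is only a sketch: the helper name `make_default_params` is hypothetical and not part of CausalEGM, the values are copied from the table above, and `v_dim` is passed in because no default is listed for it. The covariate-space discriminator is keyed as `dv_units`, the name that appears in the Semi_acic template output and in the CLI help.

```python
import copy

# Defaults copied from the hyperparameter table; v_dim has no listed default.
DEFAULTS = {
    "output_dir": ".",
    "dataset": "Mydata",
    "z_dims": [3, 6, 3, 6],
    "lr": 0.0002,
    "g_units": [64, 64, 64, 64, 64],
    "e_units": [64, 64, 64, 64, 64],
    "f_units": [64, 32, 8],
    "h_units": [64, 32, 8],
    "dz_units": [64, 32, 8],   # discriminator in latent space
    "dv_units": [64, 32, 8],   # discriminator in covariate space
    "alpha": 1,
    "beta": 1,
    "gamma": 10,
    "g_d_freq": 5,
    "save_res": True,
    "save_model": False,
    "binary_treatment": True,
    "use_z_rec": True,
    "use_v_gan": True,
    "x_min": 0,
    "x_max": 3,
}


def make_default_params(v_dim, **overrides):
    """Hypothetical helper: return a params dict filled with the listed defaults."""
    unknown = set(overrides) - set(DEFAULTS)
    if unknown:
        raise KeyError(f"Unknown config parameter(s): {sorted(unknown)}")
    params = copy.deepcopy(DEFAULTS)  # deep copy so the list defaults are not shared
    params["v_dim"] = int(v_dim)
    params.update(overrides)
    return params


params = make_default_params(v_dim=177, dataset="Semi_acic")
print(params["dataset"], params["v_dim"], params["z_dims"])
```

Rejecting unknown keys is a design choice of this sketch; it catches misspelled parameter names before they reach the model.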
Tips
Config parameters are required for creating a CausalEGM model. Here are some tips for configuring them.
z_dims has a noticeable impact on performance; please refer to src/configs for guidance.
If save_res is True, results during training will be saved at output_dir.
use_v_gan is recommended to be True under the binary treatment setting and False under the continuous treatment setting.
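The use_v_gan recommendation can be expressed as a small adjustment to an existing params dict. The following is a minimal sketch, assuming params is a plain dict with the keys described in the hyperparameter table; the function name `for_continuous_treatment` is hypothetical and not part of CausalEGM.

```python
def for_continuous_treatment(params, x_min=0, x_max=3):
    """Hypothetical helper: copy params and switch to the continuous setting."""
    if x_min >= x_max:
        raise ValueError("x_min must be smaller than x_max")
    adjusted = dict(params)            # shallow copy; the input dict is left untouched
    adjusted["binary_treatment"] = False
    adjusted["use_v_gan"] = False      # recommended for continuous treatment
    adjusted["x_min"] = x_min          # left bound of the dose-response interval
    adjusted["x_max"] = x_max          # right bound of the dose-response interval
    return adjusted


binary_params = {"binary_treatment": True, "use_v_gan": True}
continuous_params = for_continuous_treatment(binary_params, x_min=0, x_max=3)
print(continuous_params)
```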
We provide many templates of the hyperparameters in the CausalEGM/src/configs folder for different datasets and settings.
Users can use yaml to load the hyperparameters as a Python dict object easily.
import yaml
# Load the hyperparameter template for the Semi_acic dataset as a Python dict.
with open('../../src/configs/Semi_acic.yaml', 'r') as f:
    params = yaml.safe_load(f)
print(params)
{'dataset': 'Semi_acic', 'output_dir': '.', 'v_dim': 177, 'z_dims': [3, 6, 3, 6], 'lr': 0.0002, 'alpha': 1, 'beta': 1, 'gamma': 10, 'g_d_freq': 5, 'g_units': [64, 64, 64, 64, 64], 'e_units': [64, 64, 64, 64, 64], 'f_units': [64, 32, 8], 'h_units': [64, 32, 8], 'dz_units': [64, 32, 8], 'dv_units': [64, 32, 8], 'save_res': True, 'save_model': False, 'binary_treatment': True, 'use_z_rec': True, 'use_v_gan': True}
Initializing a CausalEGM model
Creating a CausalEGM model is straightforward once the hyperparameters (params) are prepared.
timestamp should be set to None if you want to train a model from scratch rather than load a pretrained model.
random_seed denotes the random seed used for reproducing the results.
model = cegm.CausalEGM(params=params,random_seed=123)
Data preparation
Before training a CausalEGM model, we need to provide the data as a triplet, which contains the treatment (x), potential outcome (y), and covariates (v).
Note that the treatment (x) and potential outcome (y) should each be either a 1-dimensional array or an array with an additional axis of length one. Covariates (v) should be a two-dimensional array.
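The shape rules can be checked before training. The sketch below works on shape tuples, such as the `.shape` attribute of an array, so it needs no extra packages. The function name `check_triplet_shapes` is hypothetical, and the requirement that x, y, and v share the same number of rows is an assumption of this sketch (each row being one sample), not something stated by the library.

```python
def check_triplet_shapes(x_shape, y_shape, v_shape):
    """Return the sample count if the shapes follow the rules, else raise ValueError."""

    def is_column(shape):
        # Accept a 1-dimensional shape (n,) or a column shape (n, 1).
        return len(shape) == 1 or (len(shape) == 2 and shape[1] == 1)

    if not is_column(x_shape):
        raise ValueError(f"x must have shape (n,) or (n, 1), got {x_shape}")
    if not is_column(y_shape):
        raise ValueError(f"y must have shape (n,) or (n, 1), got {y_shape}")
    if len(v_shape) != 2:
        raise ValueError(f"v must be two-dimensional, got {v_shape}")
    n = x_shape[0]
    # Assumption: all three arrays describe the same n samples.
    if y_shape[0] != n or v_shape[0] != n:
        raise ValueError("x, y, and v must have the same number of samples")
    return n


# Shapes reported for the ACIC 2018 example in this tutorial.
n_samples = check_triplet_shapes((50000, 1), (50000, 1), (50000, 177))
print(n_samples)
```

With arrays at hand, the call would be `check_triplet_shapes(x.shape, y.shape, v.shape)`.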
Tips
There are three different ways to feed the training data to a CausalEGM model.
Loading an existing dataset from a data sampler.
Loading data from a python triplet list [x,y,v].
Loading data from a csv, txt, or npz file; an example is provided at [path_to_CausalEGM]/test/demo.csv.
# Get the data from the ACIC 2018 competition dataset with a specified ufid.
x,y,v = cegm.Semi_acic_sampler(path='data/ACIC_2018',ufid='d5bd8e4814904c58a79d7cdcd7c2a1bb').load_all()
print(x.shape,y.shape,v.shape)
(50000, 1) (50000, 1) (50000, 177)
Run CausalEGM model training
Once the data is ready, CausalEGM can be trained with the following parameters.
| Training parameters | Description |
|---|---|
| data | List object containing the triplet data [X,Y,V]. Default: None. |
| data_file | Str object denoting the path to the input file (csv, txt, npz). Default: None. |
| sep | Str object denoting the delimiter for the input file. Default: \t. |
| header | Int object denoting row number(s) to use as the column names. Default: 0. |
| normalize | Bool object denoting whether to apply standard normalization to covariates. Default: False. |
| batch_size | Int object denoting the batch size in training. Default: 32. |
| n_iter | Int object denoting the training iterations. Default: 30000. |
| batches_per_eval | Int object denoting the number of iterations per evaluation. Default: 500. |
| batches_per_save | Int object denoting the number of iterations per save. Default: 10000. |
| startoff | Int object denoting the number of initial iterations to skip without saving or evaluation. Default: 0. |
| verbose | Bool object denoting whether to show the progress bar. Default: True. |
| save_format | Str object denoting the format (csv, txt, npz) to save the results. Default: txt. |
model.train(data=[x,y,v],n_iter=100,save_format='npy',verbose=False)
The average treatment effect (ATE) is -0.0064516705
Here we train CausalEGM for only 100 iterations for illustration purposes; n_iter is recommended to be 30000.
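Training from a file rather than from an in-memory triplet could look like the following. The keyword names and default values are those listed in the training parameter table; the comma separator for test/demo.csv is an assumption and should be checked against the file. The call to model.train is commented out so that the snippet runs without a model instance.

```python
# Keyword arguments taken from the training parameter table.
train_kwargs = {
    "data_file": "test/demo.csv",  # demo file shipped with the repository
    "sep": ",",                    # assumption: demo.csv is comma-separated
    "header": 0,
    "normalize": False,
    "batch_size": 32,
    "n_iter": 30000,               # recommended number of iterations
    "batches_per_eval": 500,
    "batches_per_save": 10000,
    "startoff": 0,
    "verbose": True,
    "save_format": "txt",
}

# With a model created as shown earlier in this tutorial:
# model.train(**train_kwargs)

for name, value in train_kwargs.items():
    print(f"{name} = {value!r}")
```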
The results are saved under the directory given by the output_dir parameter, where causal_pre_at_[iter_number].[format] contains the individual treatment effect (ITE) in binary treatment settings and the average dose-response values in continuous treatment settings. iter_number denotes the training iteration, and format is determined by save_format, which can be csv, txt, or npz.
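Since the ATE is the mean of the ITEs, a saved ITE file can be summarized in a few lines. This is a sketch under stated assumptions: the file is assumed to be plain text with one ITE value per line (the actual layout written by CausalEGM may differ), the result file is assumed to sit directly in the given directory, and `result_path` and `ate_from_ite_file` are hypothetical helpers rather than library functions. The demonstration uses a small synthetic file, not real CausalEGM output.

```python
import os
import tempfile
from statistics import fmean


def result_path(directory, iter_number, save_format="txt"):
    """Build the file name pattern causal_pre_at_[iter_number].[format]."""
    return os.path.join(directory, f"causal_pre_at_{iter_number}.{save_format}")


def ate_from_ite_file(path):
    """Average the ITE values, assuming one numeric value per line."""
    with open(path) as f:
        values = [float(line) for line in f if line.strip()]
    return fmean(values)


# Demonstration on synthetic values written to a temporary directory.
with tempfile.TemporaryDirectory() as tmp:
    path = result_path(tmp, 100)
    with open(path, "w") as f:
        f.write("0.5\n-0.25\n0.0\n0.75\n")
    ate = ate_from_ite_file(path)

print(ate)
```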
When CausalEGM is installed with pip, setuptools adds the console script to PATH and makes it available for general use. This has the advantage that CausalEGM can also be called from non-Python scripts. The CLI takes a text file as input.
usage: causalEGM [-h] -output_dir OUTPUT_DIR -input INPUT [-dataset DATASET]
                 [--save-model | --no-save-model]
                 [--binary-treatment | --no-binary-treatment]
                 [-z_dims Z_DIMS [Z_DIMS ...]] [-lr LR] [-alpha ALPHA]
                 [-beta BETA] [-gamma GAMMA] [-g_d_freq G_D_FREQ]
                 [-g_units G_UNITS [G_UNITS ...]]
                 [-e_units E_UNITS [E_UNITS ...]]
                 [-f_units F_UNITS [F_UNITS ...]]
                 [-h_units H_UNITS [H_UNITS ...]]
                 [-dz_units DZ_UNITS [DZ_UNITS ...]]
                 [-dv_units DV_UNITS [DV_UNITS ...]]
                 [--use-z-rec | --no-use-z-rec] [--use-v-gan | --no-use-v-gan]
                 [-batch_size BATCH_SIZE] [-n_iter N_ITER] [-startoff STARTOFF]
                 [-batches_per_eval BATCHES_PER_EVAL]
                 [-save_format SAVE_FORMAT] [--save_res | --no-save_res]
                 [-seed SEED]

CausalEGM: A general causal inference framework by encoding generative modeling - v0.4.0

optional arguments:
  -h, --help            show this help message and exit
  -output_dir OUTPUT_DIR
                        Output directory
  -input INPUT          Input data file must be in csv or txt or npz format
  -dataset DATASET      Dataset name
  --save-model, --no-save-model
                        whether to save model. (default: True)
  --binary-treatment, --no-binary-treatment
                        whether use binary treatment setting. (default: True)
  -z_dims Z_DIMS [Z_DIMS ...]
                        Latent dimensions of the four encoder outputs e(V)_0~3.
  -lr LR                Learning rate for the optimizer (default: 0.0002).
  -alpha ALPHA          Coefficient for reconstruction loss (default: 1).
  -beta BETA            Coefficient for treatment and outcome MSE loss (default: 1).
  -gamma GAMMA          Coefficient for gradient penalty loss (default: 10).
  -g_d_freq G_D_FREQ    Frequency for updating discriminators and generators (default: 5).
  -g_units G_UNITS [G_UNITS ...]
                        Number of units for generator/decoder network (default: [64,64,64,64,64]).
  -e_units E_UNITS [E_UNITS ...]
                        Number of units for encoder network (default: [64,64,64,64,64]).
  -f_units F_UNITS [F_UNITS ...]
                        Number of units for f network (default: [64,32,8]).
  -h_units H_UNITS [H_UNITS ...]
                        Number of units for h network (default: [64,32,8]).
  -dz_units DZ_UNITS [DZ_UNITS ...]
                        Number of units for discriminator network in latent space (default: [64,32,8]).
  -dv_units DV_UNITS [DV_UNITS ...]
                        Number of units for discriminator network in confounder space (default: [64,32,8]).
  --use-z-rec, --no-use-z-rec
                        Use the reconstruction for latent features. (default: True)
  --use-v-gan, --no-use-v-gan
                        Use the GAN distribution match for covariates. (default: True)
  -batch_size BATCH_SIZE
                        Batch size (default: 32).
  -n_iter N_ITER        Number of iterations (default: 30000).
  -startoff STARTOFF    Iteration for starting evaluation (default: 0).
  -batches_per_eval BATCHES_PER_EVAL
                        Number of iterations per evaluation (default: 500).
  -save_format SAVE_FORMAT
                        Saving format (default: txt)
  --save_res, --no-save_res
                        Whether to save results during training. (default: True)
  -seed SEED            Random seed for reproduction (default: 123).
The parameters are consistent with the Python API. Here, we use the demo data as an example.
!causalEGM -input test/demo.csv -output_dir ./ -n_iter 100 -startoff 0 -batches_per_eval 50
2023-03-20 12:57:23.620713: W tensorflow/stream_executor/cuda/cuda_driver.cc:374] A non-primary context 0x57fa5c0 for device 0 exists before initializing the StreamExecutor. The primary context is now 0. We haven't verified StreamExecutor works with that.
2023-03-20 12:57:23.620890: F tensorflow/core/platform/statusor.cc:33] Attempting to fetch value instead of handling error INTERNAL: failed initializing StreamExecutor for CUDA device ordinal 0: INTERNAL: failed call to cuDevicePrimaryCtxRetain: CUDA_ERROR_DEVICE_UNAVAILABLE: CUDA-capable device(s) is/are busy or unavailable
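The same command can also be launched from a Python script. The sketch below only assembles the argument list from flags that appear in the usage text; the helper name `build_causalegm_cmd` is hypothetical. Setting CUDA_VISIBLE_DEVICES to an empty string hides GPUs from CUDA applications, which is one common way to fall back to the CPU when the GPU is busy or unavailable, as in the log above; whether that is appropriate depends on your setup. The subprocess call is commented out because it requires CausalEGM to be installed.

```python
import os
import shlex


def build_causalegm_cmd(input_path, output_dir, n_iter=30000,
                        startoff=0, batches_per_eval=500):
    """Hypothetical helper: assemble the causalEGM command as an argument list."""
    return [
        "causalEGM",
        "-input", str(input_path),
        "-output_dir", str(output_dir),
        "-n_iter", str(n_iter),
        "-startoff", str(startoff),
        "-batches_per_eval", str(batches_per_eval),
    ]


cmd = build_causalegm_cmd("test/demo.csv", "./", n_iter=100, batches_per_eval=50)
print(shlex.join(cmd))

# Optional: copy the environment and hide GPUs so that the run uses the CPU.
env = dict(os.environ, CUDA_VISIBLE_DEVICES="")

# To actually run the command:
# import subprocess
# subprocess.run(cmd, env=env, check=True)
```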