PIE
The PIE package implements Partially Interpretable Estimators (PIE), a framework that jointly trains an interpretable model and a black-box model to achieve high predictive performance together with partial model transparency.
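To make the "jointly trains" idea concrete, here is a minimal base-R sketch of one common way to combine the two parts. It models y as an interpretable term f(x) plus a black-box term g(x) and alternates between fitting each term to the residuals of the other (backfitting). This is an assumption for illustration only: lm() stands in for the interpretable model and a hand-written k-nearest-neighbour smoother (the hypothetical knn_smooth()) stands in for the black box. The example later in this README uses group lasso and XGBoost, and PIE's exact updates may differ.

```r
# Hypothetical black box: predict each row of X_new as the mean of r_train
# over its k nearest rows of X_train (squared Euclidean distance).
knn_smooth <- function(X_train, r_train, X_new, k = 10) {
  apply(X_new, 1, function(x) {
    d <- colSums((t(X_train) - x)^2)
    mean(r_train[order(d)[seq_len(k)]])
  })
}

# Backfitting sketch: total prediction = interpretable part f + black-box part g.
fit_additive_sketch <- function(X, y, iter = 5, k = 10) {
  g_hat <- rep(0, length(y))                # black-box part starts at zero
  for (step in seq_len(iter)) {
    # Interpretable step: linear model on what the black box has not explained
    lin <- lm(r ~ ., data = data.frame(X, r = y - g_hat))
    f_hat <- unname(fitted(lin))
    # Black-box step: smooth the residuals left by the interpretable part
    g_hat <- knn_smooth(X, y - f_hat, X, k = k)
  }
  list(linear = lin, f = f_hat, g = g_hat, total = f_hat + g_hat)
}

# Toy data: a linear effect of x1 plus a nonlinear effect of x2
set.seed(1)
n <- 200
X <- matrix(runif(2 * n), ncol = 2, dimnames = list(NULL, c("x1", "x2")))
y <- 2 * X[, "x1"] + sin(6 * X[, "x2"]) + rnorm(n, sd = 0.1)

fit_sketch <- fit_additive_sketch(X, y, iter = 5)
coef(fit_sketch$linear)   # the transparent, interpretable part of the model
```

The coefficients of the linear part stay directly readable, while the black-box part absorbs structure the linear model cannot express. That split is the "partial transparency" trade-off in miniature.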
To install the released version from CRAN, run the following:
# Install the R package from CRAN
install.packages("PIE")
Getting Started
This section demonstrates how to prepare a dataset, fit a PIE model, predict with it, and evaluate its performance.
Process Data
The function data_process() converts a dataset into the format expected by PIE, including cross-validation splits (training, validation and testing sets) and the group indicators used by group lasso.
library(PIE)
# Load the example dataset
data("winequality")
# Which columns are numerical?
num_col <- 1:11
# Which columns are categorical?
cat_col <- 12
# Which column is the response?
y_col <- ncol(winequality)
# Data Processing
dat <- data_process(X = as.matrix(winequality[, -y_col]),
y = winequality[, y_col],
num_col = num_col, cat_col = cat_col, y_col = y_col)
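The internal layout produced by data_process() is not shown here. As a hedged illustration of what "group indicators for group lasso" usually means, the sketch below assumes that each numeric column is expanded into a cubic B-spline basis with splines::bs() and that the categorical column is expanded into dummy columns. The spl_ prefix in names like dat$spl_train_X hints at a spline expansion, but that is an assumption. Every column derived from the same original variable receives the same group id. The fold assignment is a generic k-fold split and is not necessarily how data_process() splits the data. All object names are hypothetical.

```r
library(splines)

set.seed(42)
n <- 100
# Toy data: three numeric columns and one categorical column (hypothetical)
X_num  <- matrix(rnorm(n * 3), ncol = 3, dimnames = list(NULL, c("a", "b", "c")))
colour <- factor(sample(c("red", "white", "rose"), n, replace = TRUE))

df_per_col <- 4
# Expand each numeric column into df_per_col cubic B-spline basis columns
basis <- do.call(cbind, lapply(seq_len(ncol(X_num)), function(j) {
  bs(X_num[, j], df = df_per_col)
}))
# Dummy-code the categorical column (reference level dropped)
dummies <- model.matrix(~ colour)[, -1, drop = FALSE]

X_design <- cbind(basis, dummies)
# Group indicator: columns derived from the same original variable share an id,
# so group lasso keeps or drops them together
group <- c(rep(seq_len(ncol(X_num)), each = df_per_col),
           rep(ncol(X_num) + 1, ncol(dummies)))

# Random assignment of rows to 5 cross-validation folds
fold_id <- sample(rep(1:5, length.out = n))
table(group)
```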
Fitting PIE
Once the data is prepared, you can use the PIE_fit() function to train a PIE model. In this example, we run only 5 iterations, using group lasso for the interpretable component and XGBoost for the black-box component.
# Fit a PIE model
fold <- 1
fit <- PIE_fit(
X = dat$spl_train_X[[fold]],
y = dat$train_y[[fold]],
lasso_group = dat$lasso_group,
X_orig = dat$orig_train_X[[fold]],
lambda1 = 0.01, lambda2 = 0.01, iter = 5, eta = 0.05, nrounds = 200
)
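Group lasso penalises each group of coefficients as a block, so all the columns belonging to one variable enter or leave the model together. A standard way to apply this penalty inside an iterative fit is the block soft-thresholding operator. The sketch below implements it with the hypothetical helper group_soft_threshold(). How PIE_fit() uses lambda1 and lambda2 internally is not stated here, so treat this as the textbook operator rather than PIE's own code.

```r
# Proximal operator of the group lasso penalty lambda * ||beta_g||_2
# (block soft-thresholding), applied group by group.
group_soft_threshold <- function(beta, group, lambda) {
  out <- beta
  for (g in unique(group)) {
    idx <- which(group == g)
    norm_g <- sqrt(sum(beta[idx]^2))
    # Zero the whole block if its norm is at most lambda,
    # otherwise shrink it toward zero without changing its direction
    out[idx] <- if (norm_g <= lambda) 0 else (1 - lambda / norm_g) * beta[idx]
  }
  out
}

beta  <- c(3, 4, 0.1, -0.1)
group <- c(1, 1, 2, 2)
res <- group_soft_threshold(beta, group, lambda = 1)
res   # group 1 (norm 5) is scaled by 0.8; group 2 (norm ~0.14) is dropped
```

Larger penalties remove more groups entirely. This is what lets a group-lasso model report which original variables it actually uses.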
Predicting PIE
Once your PIE model is trained, you can call predict() on the fitted object to obtain predictions for new data. Here we predict on the validation set of the same fold.
# Prediction
pred <- predict(fit,
X = dat$spl_validation_X[[fold]],
X_orig = dat$orig_validation_X[[fold]])
Evaluate PIE
You can evaluate your PIE model's performance with RPE(), which has formula \(RPE=\frac{\sum_i(y_i-\hat{y}_i)^2}{\sum_i(y_i-\bar{y})^2}\), where \(\bar{y} = \frac{1}{n}\sum_{i=1}^n y_i\).
# Evaluate on the validation fold
val_rpe <- RPE(pred$total, dat$validation_y[[fold]])
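The formula can be checked with a small hand-written version. rpe_sketch() is a hypothetical helper and not the package's RPE(). It assumes the same argument order as the call above: predictions first, observed responses second.

```r
# Relative prediction error: squared error of the predictions divided by the
# squared error of always predicting the mean of y.
rpe_sketch <- function(y_hat, y) {
  sum((y - y_hat)^2) / sum((y - mean(y))^2)
}

rpe_sketch(c(1, 2, 4), c(1, 2, 3))  # (0 + 0 + 1) / (1 + 0 + 1) = 0.5
rpe_sketch(rep(2, 3), c(1, 2, 3))   # predicting the mean gives 1
```

A value of 0 means perfect predictions. A value of 1 means the model does no better than always predicting \(\bar{y}\), so lower is better.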