Machine Learning Interpretation: IML and Alternatives
Using the Interpretable Machine Learning Package (IML)
We’re going to run through a random forest example (with a single-tree example as well) and use the Interpretable Machine Learning (IML) package to interpret the output of the random forest. Since we’ve distributed this lab as a standalone document, please feel free to ask us about anything in it during office hours.
Setup
Estimating a Single Tree
Report Model Statistics on Training Set
## Confusion Matrix and Statistics
##
##           Reference
## Prediction Clinton Trump
##    Clinton     266    13
##    Trump        36   219
##
## Accuracy : 0.9082
## 95% CI : (0.8805, 0.9313)
## No Information Rate : 0.5655
## P-Value [Acc > NIR] : < 2.2e-16
##
## Kappa : 0.8154
##
## Mcnemar's Test P-Value : 0.001673
##
## Sensitivity : 0.8808
## Specificity : 0.9440
## Pos Pred Value : 0.9534
## Neg Pred Value : 0.8588
## Prevalence : 0.5655
## Detection Rate : 0.4981
## Detection Prevalence : 0.5225
## Balanced Accuracy : 0.9124
##
## 'Positive' Class : Clinton
##
Report Model Statistics on Test Set
## Confusion Matrix and Statistics
##
##           Reference
## Prediction Clinton Trump
##    Clinton     266    13
##    Trump        36   219
##
## Accuracy : 0.9082
## 95% CI : (0.8805, 0.9313)
## No Information Rate : 0.5655
## P-Value [Acc > NIR] : < 2.2e-16
##
## Kappa : 0.8154
##
## Mcnemar's Test P-Value : 0.001673
##
## Sensitivity : 0.8808
## Specificity : 0.9440
## Pos Pred Value : 0.9534
## Neg Pred Value : 0.8588
## Prevalence : 0.5655
## Detection Rate : 0.4981
## Detection Prevalence : 0.5225
## Balanced Accuracy : 0.9124
##
## 'Positive' Class : Clinton
##
Setting the Grid and Parameters for Random Forest
# Set grid
rfGrid <- expand.grid(mtry = 1:8,
                      min.node.size = c(5, 10, 15),
                      splitrule = c("gini", "extratrees"))
# Set control parameters for model training
fitCtrl <- trainControl(method = "repeatedcv",
                        number = 5,
                        repeats = 2,
                        summaryFunction = twoClassSummary,
                        classProbs = TRUE,
                        savePredictions = TRUE,
                        search = "grid",
                        sampling = "down",     # down sampling
                        allowParallel = TRUE)  # parallel on
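To get a feel for what this search costs, the grid and the resampling scheme can simply be counted. This is a small base R sketch; the names `grid`, `n_combos`, `n_resamples`, and `n_fits` are ours, and the count leaves out the final refit on the full training set.

```r
# Same tuning grid as above, rebuilt with base R only
grid <- expand.grid(mtry = 1:8,
                    min.node.size = c(5, 10, 15),
                    splitrule = c("gini", "extratrees"))

n_combos    <- nrow(grid)  # 8 * 3 * 2 candidate settings
n_resamples <- 5 * 2       # 5 folds, repeated 2 times
n_fits      <- n_combos * n_resamples

cat("Candidate settings:", n_combos, "\n")
cat("Models fit during tuning:", n_fits, "\n")
```

With a grid this size, turning on parallel processing is worth the setup.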
Running the Random Forest
Print Tuning Results
##   mtry splitrule min.node.size
## 3    1      gini            10
Predict Values
Training Set Results
## Confusion Matrix and Statistics
##
##           Reference
## Prediction Clinton Trump
##    Clinton     611    49
##    Trump        41   549
##
## Accuracy : 0.928
## 95% CI : (0.9122, 0.9417)
## No Information Rate : 0.5216
## P-Value [Acc > NIR] : <2e-16
##
## Kappa : 0.8557
##
## Mcnemar's Test P-Value : 0.4606
##
## Sensitivity : 0.9371
## Specificity : 0.9181
## Pos Pred Value : 0.9258
## Neg Pred Value : 0.9305
## Prevalence : 0.5216
## Detection Rate : 0.4888
## Detection Prevalence : 0.5280
## Balanced Accuracy : 0.9276
##
## 'Positive' Class : Clinton
##
Test Set Results
## Confusion Matrix and Statistics
##
##           Reference
## Prediction Clinton Trump
##    Clinton     259    23
##    Trump        20   232
##
## Accuracy : 0.9195
## 95% CI : (0.8931, 0.9411)
## No Information Rate : 0.5225
## P-Value [Acc > NIR] : <2e-16
##
## Kappa : 0.8385
##
## Mcnemar's Test P-Value : 0.7604
##
## Sensitivity : 0.9283
## Specificity : 0.9098
## Pos Pred Value : 0.9184
## Neg Pred Value : 0.9206
## Prevalence : 0.5225
## Detection Rate : 0.4850
## Detection Prevalence : 0.5281
## Balanced Accuracy : 0.9191
##
## 'Positive' Class : Clinton
##
ROC Curve - pROC
ROC Curve - ggplot alternative
# new df for ggplot
df <- data.frame(real = c(train$presvote, test$presvote),
                 preds = c(pred.train, pred.test),
                 type = c(rep("Training Set", length(pred.train)),
                          rep("Testing Set", length(pred.test))))
# build the base ROC curve; geom_roc() (plotROC) puts the false positive rate
# (1 - specificity) on the x-axis and the true positive rate on the y-axis
roc.ggplot <- ggplot(df, aes(d = real, m = preds, color = type)) +
  geom_roc() + theme_bw() +
  scale_color_manual(values = c("gray50", "black")) +
  labs(color = "Set", x = "1 - Specificity", y = "Sensitivity",
       title = "ROC Curves for Training and Testing Sets")
roc.ggplot
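An ROC curve traces sensitivity (the true positive rate) against 1 - specificity (the false positive rate) as the classification threshold moves from strict to lenient. The sketch below builds the curve and its area by hand in base R. The `truth` and `score` vectors are made-up toy values, not the lab data.

```r
# Toy data (made up): 1 = positive class, 0 = negative class
truth <- c(1, 1, 0, 1, 0, 0, 1, 0)
score <- c(0.9, 0.8, 0.7, 0.6, 0.4, 0.3, 0.2, 0.1)

# Sweep the threshold from high to low: each step labels one more case positive
ord <- order(score, decreasing = TRUE)
tpr <- c(0, cumsum(truth[ord] == 1) / sum(truth == 1))  # sensitivity
fpr <- c(0, cumsum(truth[ord] == 0) / sum(truth == 0))  # 1 - specificity

# Area under the curve by the trapezoid rule
auc <- sum(diff(fpr) * (head(tpr, -1) + tail(tpr, -1)) / 2)

# Equivalent view: share of (positive, negative) pairs that are ranked correctly
auc_pairs <- mean(outer(score[truth == 1], score[truth == 0], ">"))

print(c(trapezoid = auc, pairs = auc_pairs))  # both are 0.75 for this toy data
```

The pairwise view is a handy way to read an AUC: it is the chance that a randomly chosen positive case gets a higher score than a randomly chosen negative case.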
Interpretable Machine Learning (IML)
This package helps with, well, interpreting machine learning algorithms and their outputs. It has functions that estimate feature importance, feature effects, and interactions, and that explain individual predictions with Shapley values.
IML can be used with many machine learning models, but for this lab we will be using it with a random forest.
Setting up IML and Feature Importance
# Set up parallel processing using the 'future' package
plan("multisession")
# Subset predictor variables, set prediction parameters
X <- dat2016[which(names(dat2016) != "presvote")]
predictor <- Predictor$new(rf.res, data = X,
                           y = dat2016$presvote, type = "prob")
# Measure feature importance (permutation test)
# This outputs a ggplot object
imp <- FeatureImp$new(predictor,
                      loss = "ce",
                      n.repetitions = 30)
imp.plot <- plot(imp)
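The idea behind permutation importance is simple enough to write out by hand: shuffle one feature, re-measure the error, and compare it to the original error. The base R sketch below does that on made-up toy data with a logistic regression standing in for the random forest. The helpers `ce()` and `perm_importance()` are our own, and reporting the median ratio is an assumption meant to resemble IML's default output rather than a copy of its internals.

```r
set.seed(42)

# Toy data (made up): x1 drives the outcome, x2 is pure noise
n   <- 500
toy <- data.frame(x1 = rnorm(n), x2 = rnorm(n))
toy$y <- rbinom(n, 1, plogis(2 * toy$x1))

# Stand-in "black box": a logistic regression
fit <- glm(y ~ x1 + x2, data = toy, family = binomial)

# Classification error: share of cases where the predicted class is wrong
ce <- function(model, data) {
  pred <- as.numeric(predict(model, newdata = data, type = "response") > 0.5)
  mean(pred != data$y)
}

# Permutation importance: error after shuffling one feature / original error
perm_importance <- function(model, data, feature, n_rep = 30) {
  base <- ce(model, data)
  shuffled <- replicate(n_rep, {
    d <- data
    d[[feature]] <- sample(d[[feature]])  # break the feature-outcome link
    ce(model, d)
  })
  median(shuffled) / base
}

imp_x1 <- perm_importance(fit, toy, "x1")
imp_x2 <- perm_importance(fit, toy, "x2")
print(c(x1 = imp_x1, x2 = imp_x2))
```

A ratio near 1 means shuffling the feature barely changes the error, so the model is not relying on it.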
Pull out IMP data for custom use
# extract the imp values from the imp object
imp.df <- imp$results
# sort by importance, largest first
imp.df <- arrange(imp.df, desc(importance))
# here's a simple custom ggplot
# note: the axis labels are hard-coded in order of importance, so update them if the ranking changes
ggplot(imp.df, aes(x = reorder(feature, importance), y = importance, fill = importance)) +
  geom_bar(stat = "identity", position = "dodge") +
  guides(fill = "none") +
  coord_flip() +
  xlab("") +
  ggtitle("Feature Importance") +
  ylab("Variable Importance (relative)") +
  scale_x_discrete(labels = rev(c("Pres Vote 2012",
                                  "Black Lives Matter\nTherm",
                                  "Gender Resentment",
                                  "Universal Health\nInsurance",
                                  "Union Therm",
                                  "Abortion",
                                  "Rich Therm",
                                  "Female",
                                  "Black"))) +
  theme_calc()
Feature Effects
We can examine feature effects through accumulated local effects (ALE), partial dependence plots (PDP), and individual conditional expectation (ICE) curves, all within the IML package.
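To make the relationship between ICE and PDP concrete, here is a from-scratch base R sketch on made-up toy data, with a logistic regression standing in for the random forest. An ICE curve forces one feature to each value on a grid for a single observation, and the PDP is the average of those curves. ALE is not covered here: it averages local changes in the prediction within small intervals of the feature instead.

```r
set.seed(1)

# Toy data (made up): the outcome depends on x1 and x2
n   <- 300
toy <- data.frame(x1 = runif(n, -2, 2), x2 = rnorm(n))
toy$y <- rbinom(n, 1, plogis(1.5 * toy$x1 - toy$x2))
fit <- glm(y ~ x1 + x2, data = toy, family = binomial)

# Grid of values for the feature of interest
grid <- seq(-2, 2, length.out = 21)

# ICE: for every observation, predict while x1 is forced to each grid value.
# The result is an n x length(grid) matrix, one row (curve) per observation.
ice <- sapply(grid, function(g) {
  d <- toy
  d$x1 <- g
  predict(fit, newdata = d, type = "response")
})

# PDP: the pointwise average of the ICE curves
pdp <- colMeans(ice)

# Quick look: a few ICE curves in gray, the PDP in black
matplot(grid, t(ice[1:25, ]), type = "l", lty = 1, col = "gray70",
        xlab = "x1", ylab = "Predicted probability")
lines(grid, pdp, lwd = 3)
```

When the ICE curves fan out or cross, the averaged PDP is hiding an interaction, which is why the two are often plotted together.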
ALE Code
ALE Plot
PDP Code
PDP+ICE Plot
Interaction Effects
Using Shapley Values to Measure Feature Effects
shapley <- Shapley$new(predictor, x.interest = X[1,])
shapley.plot <- shapley$plot()
shapley2 <- Shapley$new(predictor, x.interest = X[2,])
shapley2.plot <- shapley2$plot()
Shapley Plot for Observation 1
Shapley Plot for Observation 2
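For intuition about what a Shapley value is, the sketch below computes exact values for one observation by enumerating every coalition of features. It uses made-up toy data and a logistic regression as a stand-in model, and the helpers `value()`, `subsets()`, and `shapley_value()` are our own. Enumeration is only feasible with a handful of features, so packages approximate these values by sampling.

```r
set.seed(7)

# Toy data and a stand-in model (both made up for illustration)
n <- 200
X_toy <- data.frame(a = rnorm(n), b = rnorm(n), c = rnorm(n))
y_toy <- rbinom(n, 1, plogis(X_toy$a - 0.5 * X_toy$b + X_toy$a * X_toy$c))
fit <- glm(y ~ a * c + b, data = cbind(X_toy, y = y_toy), family = binomial)
f <- function(newdata) predict(fit, newdata = newdata, type = "response")

x_interest <- X_toy[1, ]
p <- ncol(X_toy)

# Value of a coalition S: average prediction when the features in S are fixed
# at the values of x_interest and the others keep their observed values
value <- function(S) {
  d <- X_toy
  for (j in S) d[[j]] <- x_interest[[j]]
  mean(f(d))
}

# All subsets of a vector of feature indices (including the empty set)
subsets <- function(idx) {
  out <- list(integer(0))
  for (j in idx) out <- c(out, lapply(out, function(s) c(s, j)))
  out
}

# Exact Shapley value: weighted average of feature j's marginal contributions
shapley_value <- function(j) {
  others <- setdiff(seq_len(p), j)
  sum(sapply(subsets(others), function(S) {
    w <- factorial(length(S)) * factorial(p - length(S) - 1) / factorial(p)
    w * (value(c(S, j)) - value(S))
  }))
}

phi <- sapply(seq_len(p), shapley_value)
names(phi) <- names(X_toy)
print(phi)

# The values add up to: prediction for x_interest minus the average prediction
print(c(sum_phi = sum(phi), gap = unname(f(x_interest)) - mean(f(X_toy))))
```

That last check is the property that makes Shapley plots readable: the bars for one observation account exactly for how far its prediction sits from the average.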
Surrogate Model
One interesting use of the IML package is creating a “surrogate model” that incorporates information from a more complicated model and makes it interpretable. For example, a single decision tree can act as a surrogate for a random forest. This is done by training the decision tree on the original features, with the complex model's predictions (rather than the observed outcomes) as the target.
#Install these packages if you haven't already to do this analysis.
#install.packages(c("partykit", "libcoin", "mvtnorm"))
tree.surr <- TreeSurrogate$new(predictor, maxdepth = 2)
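The surrogate recipe can be sketched in base R without any tree package. In this illustration the "complex" model is a logistic regression with a squared term, and the surrogate is a hand-rolled single split (a depth-1 tree). The toy data and the helper `fit_stump()` are our own and are much cruder than the trees IML fits. The key step is that the surrogate is trained on the black box's predictions, not on the observed outcome.

```r
set.seed(3)

# Toy data and a stand-in "complex" model (both made up)
n   <- 400
toy <- data.frame(x1 = rnorm(n), x2 = rnorm(n))
toy$y <- rbinom(n, 1, plogis(2 * toy$x1 + 0.5 * toy$x2^2 - 0.5))
black_box <- glm(y ~ x1 + I(x2^2), data = toy, family = binomial)

# Step 1: the surrogate's target is the black box's predicted class
bb_class <- as.numeric(predict(black_box, newdata = toy, type = "response") > 0.5)

# Step 2: find the single split that best reproduces those predictions
fit_stump <- function(data, features, target) {
  best <- list(agree = -Inf)
  for (feat in features) {
    for (cut in unique(data[[feat]])) {
      right <- data[[feat]] > cut
      for (right_class in c(0, 1)) {
        pred  <- ifelse(right, right_class, 1 - right_class)
        agree <- mean(pred == target)
        if (agree > best$agree) {
          best <- list(feature = feat, cut = cut,
                       right_class = right_class, agree = agree)
        }
      }
    }
  }
  best
}

stump <- fit_stump(toy, c("x1", "x2"), bb_class)

# Fidelity: how often the surrogate agrees with the black box
fidelity <- stump$agree
cat("Split on", stump$feature, "at", round(stump$cut, 2),
    "with fidelity", round(fidelity, 3), "\n")
```

Fidelity to the black box and accuracy against the observed outcome are different questions, so it is worth reporting both when evaluating a surrogate.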
Tree Plot
Cell Plot
Predictions from Surrogate Model
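surrogate.pred <- as.factor(unlist(tree.surr$predict(X, type = "class")))
# caret's signature is confusionMatrix(data, reference): as called here, the observed
# votes are the rows ("Prediction") and the surrogate's predictions are the columns
# ("Reference") in the output below
confusionMatrix(dat2016$presvote, surrogate.pred)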
## Confusion Matrix and Statistics
##
##           Reference
## Prediction Clinton Trump
##    Clinton     881    50
##    Trump       127   726
##
## Accuracy : 0.9008
## 95% CI : (0.886, 0.9143)
## No Information Rate : 0.565
## P-Value [Acc > NIR] : < 2.2e-16
##
## Kappa : 0.8004
##
## Mcnemar's Test P-Value : 1.113e-08
##
## Sensitivity : 0.8740
## Specificity : 0.9356
## Pos Pred Value : 0.9463
## Neg Pred Value : 0.8511
## Prevalence : 0.5650
## Detection Rate : 0.4938
## Detection Prevalence : 0.5219
## Balanced Accuracy : 0.9048
##
## 'Positive' Class : Clinton
##
Alternatives: VIP, PDP, FASTSHAP
The IML package is great, but it has competitors. Each has its own strengths, and some offer functionality that IML lacks. We’ll begin with the VIP package, which uses FASTSHAP as its backend for Shapley-based importance.
Parallel backend
Variable importance permuted scores
set.seed(2022)
vi_scores <- vi(object = rf.res,               # specify the model
                method = "permute",            # specify VI method (i.e., shap, permute, etc.)
                train = as.data.frame(train),  # training data
                metric = "auc",
                target = "presvote",           # specify Y
                reference_class = "Trump",     # specify reference class in Y
                pred_wrapper = predict,        # prediction function
                keep = TRUE,                   # keep all perm results
                nsim = 100,                    # number of sims
                parallel = TRUE)               # good for cross model comparison
Variable importance Shapley values
set.seed(2022)
# define a prediction function for the vi function
# (returns the predicted probability of the first class level)
predict.shap <- function(object, newdata) {
  predict(object = object, newdata = newdata, type = "prob")[, 1]
}
# calculate aggregated shapley importance values
shaps <- vi_shap(rf.res,
                 nsim = 20,
                 parallel = TRUE,
                 train = as.data.frame(X),
                 pred_wrapper = predict.shap)
# calculate individual shapley values
fullshaps <- fastshap::explain(object = rf.res,
                               parallel = TRUE,
                               feature_names = colnames(X),
                               X = as.data.frame(X),
                               pred_wrapper = predict.shap,
                               nsim = 20)
fullshaps <- as.data.frame(fullshaps)
# build df for plot
g.df <- data.frame(unions = X$thermometerunions,         # original values
                   shaps = fullshaps$thermometerunions)  # shap values
# build plot (no color aesthetic is mapped, so no color scale is needed)
ggplot(g.df, aes(x = unions, y = shaps)) +
  geom_jitter(alpha = .15, color = "forest green") +
  geom_smooth(se = TRUE, color = "black") +
  ylab("Shapley value") +
  xlab("Union Feeling Thermometer") +
  theme_gdocs()
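The step from individual Shapley values to a single importance score per feature is just an average of magnitudes. The sketch below shows that aggregation on a made-up table of Shapley values. The numbers are invented, and treating the mean absolute value as the aggregation rule is our understanding of how Shapley-based importance is usually computed, not a reproduction of VIP's code.

```r
# Toy table of Shapley values (made up): rows = observations, columns = features
shap_toy <- data.frame(
  thermometerunions = c(-0.20, 0.10, 0.30, -0.10),
  female            = c( 0.02, -0.02, 0.01, -0.01)
)

# Global importance: average magnitude of each feature's contributions
shap_importance <- sort(colMeans(abs(shap_toy)), decreasing = TRUE)
print(shap_importance)
```

Taking absolute values matters: a feature that pushes some predictions up and others down would otherwise average out to roughly zero and look unimportant.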
Simple VIP barplot
VIP’s basic bar plot is a decent first cut and works out of the box with ggplot.
Boxplot importance values with uncertainty estimates
Boxplots are another good option, with a better representation of uncertainty.
Partial Dependence Plots with the PDP package
Another way to make partial dependence plots is with the PDP package (very powerful! many options!).
Binary/categorical version
Conformal Inference and LOCO (Experimental!)
Leave-one-covariate-out (LOCO) variable importance with a simple random forest
# need to do some cleaning, loco only works with numeric matrices - no factors
dat2016$female <- as.numeric(dat2016$female)
dat2016$presvote2012 <- as.numeric(dat2016$presvote2012)
dat2016$black <- as.numeric(dat2016$black)
dat2016$presvote <- as.numeric(dat2016$presvote)
#
# we want to use random forest, it'll set up the functions for us.
# In practice, we would want to use the ideal settings from a CV procedure
funs <- conformalInference::rf.funs(ntree = 2000, nodesize = 5)
# set up the X matrix the way loco wants it
X_mat <- as.matrix(as.data.frame(subset(dat2016, select = -thermometerblacklivesmatter)))
#
# run the loco vi calcs, it does the splitting for us - use whole data frame
out.loco <- loco(x = X_mat,                               # X matrix
                 y = dat2016$thermometerblacklivesmatter, # DV
                 train.fun = funs$train.fun,              # training function, in this case rf
                 predict.fun = funs$predict.fun,          # function to extract predictions
                 active.fun = funs$active.fun,            # checks for active features from training model
                 alpha = .1,                              # uncertainty interval settings (i.e., 90% confidence intervals)
                 split = NULL,                            # allow the function to split for us (random split into two halves)
                 bonf.correct = TRUE,                     # correct p-values and confidence intervals for multiple testing
                 seed = 1995)                             # set the seed for replicability
# create a df for the plot
plot.loco <- data.frame(variable = colnames(X_mat), out.loco$inf.wilcox)
# plot
plot.loco %>%
  mutate(var_order = forcats::fct_reorder(variable, LowConfPt)) %>%
  ggplot(aes(x = var_order, ymax = UpConfPt, ymin = LowConfPt, color = UpConfPt)) +
  scale_color_gradient2(low = "firebrick1", mid = "blue3", high = "green2") +
  geom_errorbar(width = .1) +
  ggtitle("LOCO Variable Importance: RF Model of Therms BLM") +
  ylab("Median Test Error (Delta)") +
  xlab("") +
  geom_hline(yintercept = 0, linetype = "dashed") +
  theme_tufte() +
  theme(axis.text.x = element_text(angle = 45, hjust = 1))
To read more about LOCO, since we didn’t cover it in class, see here (section 6): https://www.stat.cmu.edu/~ryantibs/papers/conformal.pdf
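For a feel of what LOCO measures, here is a stripped-down base R sketch on made-up toy data, with a linear model standing in for the random forest. It splits the sample, fits the model with and without each covariate on one part, and compares absolute prediction errors on the other part. The helper `abs_err()` and all variable names are ours, and the sketch skips the multiple-testing correction used above.

```r
set.seed(1995)

# Toy data (made up): x1 matters, x2 is noise
n   <- 400
toy <- data.frame(x1 = rnorm(n), x2 = rnorm(n))
toy$y <- 2 * toy$x1 + rnorm(n)

# Split the sample: fit on one part, measure error on the other
idx       <- sample(n, floor(n / 2))
toy_train <- toy[idx, ]
toy_test  <- toy[-idx, ]

# Absolute test-set errors of a model described by a formula
abs_err <- function(formula) {
  fit <- lm(formula, data = toy_train)
  abs(toy_test$y - predict(fit, newdata = toy_test))
}

err_full <- abs_err(y ~ x1 + x2)

# LOCO: refit without one covariate and record the increase in test error
loco_x1 <- abs_err(y ~ x2) - err_full
loco_x2 <- abs_err(y ~ x1) - err_full

# Median excess error, with a 90% Wilcoxon interval for its center
ci_x1 <- wilcox.test(loco_x1, conf.int = TRUE, conf.level = 0.90)$conf.int
ci_x2 <- wilcox.test(loco_x2, conf.int = TRUE, conf.level = 0.90)$conf.int

print(rbind(x1 = c(median = median(loco_x1), low = ci_x1[1], high = ci_x1[2]),
            x2 = c(median = median(loco_x2), low = ci_x2[1], high = ci_x2[2])))
```

An interval that sits clearly above zero says that leaving the covariate out makes out-of-sample predictions worse, which is the sense in which LOCO calls a variable important.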