Machine Learning Interpretation: IML and Alternatives

Using the Interpretable Machine Learning Package (IML)

We’re going to run through a random forest example (including a single-tree example) and use Interpretable Machine Learning (IML) to interpret the output of the random forest. As we’ve distributed this lab as a standalone document, please feel free to ask us about anything contained herein during office hours.

Setup

# Draw one partition on the outcome: 70% of rows go to training
trainIndex <- createDataPartition(
  dat2016$presvote,
  p = 0.7,
  times = 1,
  list = FALSE
)
# Rows picked by the partition train the models; all remaining rows are held out
train <- dat2016[trainIndex, ]
test <- dat2016[-trainIndex, ]

Estimating a Single Tree

# Grow a single classification tree for presvote from all other columns,
# with the rpart controls minsplit = 20 and minbucket = 5
singletree <- rpart(
  formula = presvote ~ .,
  data = train,
  control = rpart.control(minbucket = 5, minsplit = 20)
)
# Class-label predictions for the training and the held-out rows
train.pred <- predict(singletree, newdata = train, type = "class")
test.pred <- predict(singletree, newdata = test, type = "class")

Report Model Statistics on Training Set

# Training-set performance of the single tree: predicted classes are passed as
# `data` and observed classes as `reference`.
# NOTE: the printed output beneath this call was produced by an earlier call
# that used the test set; re-knit the document to refresh it.
confusionMatrix(data = train.pred, reference = train$presvote)
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction Clinton Trump
##    Clinton     266    13
##    Trump        36   219
##                                           
##                Accuracy : 0.9082          
##                  95% CI : (0.8805, 0.9313)
##     No Information Rate : 0.5655          
##     P-Value [Acc > NIR] : < 2.2e-16       
##                                           
##                   Kappa : 0.8154          
##                                           
##  Mcnemar's Test P-Value : 0.001673        
##                                           
##             Sensitivity : 0.8808          
##             Specificity : 0.9440          
##          Pos Pred Value : 0.9534          
##          Neg Pred Value : 0.8588          
##              Prevalence : 0.5655          
##          Detection Rate : 0.4981          
##    Detection Prevalence : 0.5225          
##       Balanced Accuracy : 0.9124          
##                                           
##        'Positive' Class : Clinton         
## 

Report Model Statistics on Test Set

# Report Model Statistics
# Predicted classes are passed as `data` and observed classes as `reference`,
# the same order used for the random forest confusion matrices in this lab.
# NOTE: the printed output beneath this call was generated with the two
# arguments swapped, so its table is transposed and Sensitivity/Pos Pred Value
# (and Specificity/Neg Pred Value) are exchanged; re-knit to refresh it.
confusionMatrix(data = test.pred, reference = test$presvote)
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction Clinton Trump
##    Clinton     266    13
##    Trump        36   219
##                                           
##                Accuracy : 0.9082          
##                  95% CI : (0.8805, 0.9313)
##     No Information Rate : 0.5655          
##     P-Value [Acc > NIR] : < 2.2e-16       
##                                           
##                   Kappa : 0.8154          
##                                           
##  Mcnemar's Test P-Value : 0.001673        
##                                           
##             Sensitivity : 0.8808          
##             Specificity : 0.9440          
##          Pos Pred Value : 0.9534          
##          Neg Pred Value : 0.8588          
##              Prevalence : 0.5655          
##          Detection Rate : 0.4981          
##    Detection Prevalence : 0.5225          
##       Balanced Accuracy : 0.9124          
##                                           
##        'Positive' Class : Clinton         
## 

Setting the Grid and Parameters for Random Forest

# Tuning grid: every combination of the candidate values listed below
rfGrid <- expand.grid(
  mtry = 1:8,
  min.node.size = c(5, 10, 15),
  splitrule = c("gini", "extratrees")
)
# Resampling and evaluation settings handed to model training
fitCtrl <- trainControl(
  method = "repeatedcv",
  number = 5,
  repeats = 2,
  search = "grid",
  sampling = "down", # down sampling
  classProbs = TRUE,
  summaryFunction = twoClassSummary,
  savePredictions = TRUE,
  allowParallel = TRUE # parallel on
)

Running the Random Forest

# Fix the RNG state before tuning
set.seed(1985)
# Tune a ranger random forest over rfGrid, selecting on the ROC metric
rf.res <- train(
  presvote ~ .,
  data = train,
  method = "ranger",
  metric = "ROC",
  tuneGrid = rfGrid,
  trControl = fitCtrl,
  verbose = FALSE
)
# Turn off the cluster when done and register the sequential backend again.
# NOTE(review): `cl` is not created in the code shown before this point --
# presumably it comes from setup code not included here; confirm.
stopCluster(cl)
registerDoSEQ()

Predict Values

# Keep only the "Trump" column of the predicted class probabilities,
# first for the training rows and then for the held-out rows
pred.train <- predict(rf.res, newdata = train, type = "prob")[, "Trump"]
pred.test <- predict(rf.res, newdata = test, type = "prob")[, "Trump"]

Training Set Results

# Random forest class predictions on the training rows vs. the observed vote
confusionMatrix(
  data = predict(rf.res, newdata = train, type = "raw"),
  reference = train$presvote
)
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction Clinton Trump
##    Clinton     611    49
##    Trump        41   549
##                                           
##                Accuracy : 0.928           
##                  95% CI : (0.9122, 0.9417)
##     No Information Rate : 0.5216          
##     P-Value [Acc > NIR] : <2e-16          
##                                           
##                   Kappa : 0.8557          
##                                           
##  Mcnemar's Test P-Value : 0.4606          
##                                           
##             Sensitivity : 0.9371          
##             Specificity : 0.9181          
##          Pos Pred Value : 0.9258          
##          Neg Pred Value : 0.9305          
##              Prevalence : 0.5216          
##          Detection Rate : 0.4888          
##    Detection Prevalence : 0.5280          
##       Balanced Accuracy : 0.9276          
##                                           
##        'Positive' Class : Clinton         
## 

Test Set Results

# Random forest class predictions on the held-out rows vs. the observed vote
confusionMatrix(
  data = predict(rf.res, newdata = test, type = "raw"),
  reference = test$presvote
)
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction Clinton Trump
##    Clinton     259    23
##    Trump        20   232
##                                           
##                Accuracy : 0.9195          
##                  95% CI : (0.8931, 0.9411)
##     No Information Rate : 0.5225          
##     P-Value [Acc > NIR] : <2e-16          
##                                           
##                   Kappa : 0.8385          
##                                           
##  Mcnemar's Test P-Value : 0.7604          
##                                           
##             Sensitivity : 0.9283          
##             Specificity : 0.9098          
##          Pos Pred Value : 0.9184          
##          Neg Pred Value : 0.9206          
##              Prevalence : 0.5225          
##          Detection Rate : 0.4850          
##    Detection Prevalence : 0.5281          
##       Balanced Accuracy : 0.9191          
##                                           
##        'Positive' Class : Clinton         
## 

ROC Curve - pROC

# Training-set ROC curve first ...
plot.roc(train$presvote, pred.train)
# ... then the test-set curve drawn onto the same plot in green
plot.roc(test$presvote, pred.test, col = "green", add = TRUE)

ROC Curve - ggplot alternative

# New data frame for ggplot: observed outcome, predicted probability, and a
# label saying which partition each row came from
df <- data.frame(
  real = c(train$presvote, test$presvote),
  preds = c(pred.train, pred.test),
  type = c(
    rep("Training Set", length(pred.train)),
    rep("Testing Set", length(pred.test))
  )
)
# Build the base ROC curve. geom_roc() puts the false positive fraction
# (1 - specificity) on the x-axis and the true positive fraction (sensitivity)
# on the y-axis, so the x-axis is labelled "1 - Specificity".
roc.ggplot <- ggplot(df, aes(d = real, m = preds, color = type)) +
  geom_roc() +
  theme_bw() +
  scale_color_manual(values = c("gray50", "black")) +
  labs(
    color = "Set",
    x = "1 - Specificity",
    y = "Sensitivity",
    title = "ROC Curves for Training and Testing Sets"
  )
roc.ggplot

# for interactive, run plot_interactive_roc(roc.ggplot)
# for multiple ROC curves, check out this guide: https://cran.r-project.org/web/packages/plotROC/vignettes/examples.html

Interpretable Machine Learning (IML)

This package helps with, well, interpreting machine learning algorithms and their outputs. It has functions that estimate feature importance and effects, interactions, and variable importance measured by Shapley values.

IML can be used with many machine learning models, but for this lab we will be using it with a random forest.

Setting up IML and Feature Importance

# Set up parallel processing using the 'future' package
plan("multisession")

# Predictors only: every column of dat2016 except the outcome
X <- dat2016[names(dat2016) != "presvote"]
# Bundle the fitted model, the predictors and the outcome for iml;
# predictions are requested with type = "prob"
predictor <- Predictor$new(
  rf.res,
  data = X,
  y = dat2016$presvote,
  type = "prob"
)

# Measure feature importance (permutation test) with the "ce" loss and
# 30 repetitions
imp <- FeatureImp$new(predictor, loss = "ce", n.repetitions = 30)

# Plotting the result outputs a ggplot object, stored here for later display
imp.plot <- plot(imp)

Pull out IMP data for custom use

# Extract the importance values from the imp object
imp.df <- imp$results
# Sort by importance, largest first. arrange() must be told which column to
# sort on; called with the data frame alone it leaves the rows untouched.
imp.df <- arrange(imp.df, desc(importance))
# Here's a simple custom ggplot.
# NOTE(review): the axis labels below are assigned by position, so they
# presumably rely on this exact importance ranking -- confirm after any refit.
ggplot(
  imp.df,
  aes(x = reorder(feature, importance), y = importance, fill = importance)
) +
  geom_bar(stat = "identity", position = "dodge") +
  guides(fill = "none") + # hide the fill legend ("none" replaces the deprecated F)
  coord_flip() +
  xlab("") +
  ggtitle("Feature Importance") +
  ylab("Variable Importance (relative)") +
  scale_x_discrete(labels = rev(c(
    "Pres Vote 2012",
    "Black Lives Matter\nTherm",
    "Gender Resentment",
    "Universal Health\nInsurance",
    "Union Therm",
    "Abortion",
    "Rich Therm",
    "Female",
    "Black"
  ))) +
  theme_calc()

Feature Effects

We can examine feature effects through accumulated local effects (ALE), partial dependence plots (PDP), and individual conditional expectation (ICE) curves, all within the IML package.

ALE Code

# ALE feature effect for the "thermometerrich" feature
ale <- FeatureEffect$new(
  predictor,
  method = "ale",
  feature = "thermometerrich"
)
# Store the plot object for later display
ale.plot <- ale$plot()

ALE Plot

PDP Code

# Feature effect of "thermometerblacklivesmatter" using the combined
# "pdp+ice" method
pdp.ice <- FeatureEffect$new(
  predictor,
  method = "pdp+ice",
  feature = "thermometerblacklivesmatter"
)
# Store the plot object for later display
pdp.ice.plot <- pdp.ice$plot()

PDP+ICE Plot

Dichotomous PDP+ICE Feature Effect

# Same "pdp+ice" feature effect, this time for the "black" feature.
# (The object names say "ale", but the method requested here is "pdp+ice".)
ale2 <- FeatureEffect$new(
  predictor,
  method = "pdp+ice",
  feature = "black"
)
# Store the plot object for later display
ale.plot2 <- ale2$plot()

Interaction Effects

# Interaction strength (Friedman's H-statistic) for a single variable:
# "presvote2012"
interact <- Interaction$new(
  predictor,
  feature = "presvote2012"
)
# Store the plot object for later display
interact.plot <- plot(interact)

All Interaction Effects

# No `feature` is given here, unlike the single-variable call
interact.all <- Interaction$new(predictor)
# Store the plot object for later display
interact.all.plot <- plot(interact.all)

Using Shapley Values to Measure Feature Effects

# Shapley explanation for the first row of X
shapley <- Shapley$new(predictor, x.interest = X[1, ])
shapley.plot <- shapley$plot()

# Shapley explanation for the second row of X
shapley2 <- Shapley$new(predictor, x.interest = X[2, ])
shapley2.plot <- shapley2$plot()

Shapley Plot for Observation 1

Shapley Plot for Observation 2

Surrogate Model

One interesting use of the IML package is creating a “surrogate model” that incorporates information from a more complicated model and makes it interpretable. For example, a single decision tree can act as a surrogate model for a random forest. This is done by training the decision tree on the original features and the predicted outcomes from the complex model.

# Install these packages if you haven't already to do this analysis.
# install.packages(c("partykit", "libcoin", "mvtnorm"))
# Surrogate tree built from the iml predictor, with maxdepth = 2
tree.surr <- TreeSurrogate$new(
  predictor,
  maxdepth = 2
)

Decision Tree Plots

# Plot object produced by plot() on the surrogate, stored for later display
cell.plot <- plot(tree.surr)
# The tree object held by the surrogate
tree <- tree.surr$tree

Tree Plot

Cell Plot

Predictions from Surrogate Model

# Class predictions from the surrogate tree for every row of X, flattened
# into a single factor
surrogate.pred <- as.factor(unlist(tree.surr$predict(X, type = "class")))
# Predicted classes are passed as `data` and observed classes as `reference`.
# NOTE: the printed output beneath this call was generated with the two
# arguments swapped, so its table is transposed and Sensitivity/Pos Pred Value
# (and Specificity/Neg Pred Value) are exchanged; re-knit to refresh it.
confusionMatrix(data = surrogate.pred, reference = dat2016$presvote)
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction Clinton Trump
##    Clinton     881    50
##    Trump       127   726
##                                          
##                Accuracy : 0.9008         
##                  95% CI : (0.886, 0.9143)
##     No Information Rate : 0.565          
##     P-Value [Acc > NIR] : < 2.2e-16      
##                                          
##                   Kappa : 0.8004         
##                                          
##  Mcnemar's Test P-Value : 1.113e-08      
##                                          
##             Sensitivity : 0.8740         
##             Specificity : 0.9356         
##          Pos Pred Value : 0.9463         
##          Neg Pred Value : 0.8511         
##              Prevalence : 0.5650         
##          Detection Rate : 0.4938         
##    Detection Prevalence : 0.5219         
##       Balanced Accuracy : 0.9048         
##                                          
##        'Positive' Class : Clinton        
## 

Alternatives: VIP, PDP, FASTSHAP

The IML package is great, but there are competing packages as well. These alternatives have different strengths and offer some additional functionality. We’ll begin with the VIP package, which uses FASTSHAP as a backend.

Parallel backend

plan(sequential) # end future plan, switch back to doParallel for next stage
# Leave four cores free but always request at least one worker: detectCores()
# can report four or fewer cores (or NA), and a non-positive or missing worker
# count cannot start a cluster.
n_workers <- max(1L, detectCores() - 4L, na.rm = TRUE)
cl <- makeCluster(n_workers, # how many cores allocated to our cluster
                  setup_timeout = 0.5)
registerDoParallel(cl)

Variable importance permuted scores

set.seed(2022)
# Permutation-based importance scores (good for cross model comparison)
vi_scores <- vi(
  object = rf.res, # specify the model
  method = "permute", # specify VI method (i.e., shap, permute, etc.)
  train = as.data.frame(train), # training data
  metric = "auc",
  target = "presvote", # specify Y
  reference_class = "Trump", # specify reference class in Y
  # AUC is computed from scores, so the wrapper returns the predicted
  # probability of the reference class rather than hard class labels
  pred_wrapper = function(object, newdata) {
    predict(object, newdata = newdata, type = "prob")[, "Trump"]
  },
  keep = TRUE, # keep all perm results
  nsim = 100, # number of sims
  parallel = TRUE # run the permutations in parallel
)

Variable importance Shapley values

set.seed(2022)
# Prediction wrapper for the Shapley routines: asks the model for "prob"
# predictions on `newdata` and returns the first column of the result
predict.shap <- function(object, newdata) {
  class_probs <- predict(object = object, newdata = newdata, type = "prob")
  class_probs[, 1]
}
# Calculate aggregated Shapley importance values
shaps <- vi_shap(
  rf.res,
  train = as.data.frame(X),
  pred_wrapper = predict.shap,
  nsim = 20,
  parallel = TRUE
)
# Calculate individual Shapley values and store them as a plain data frame
fullshaps <- as.data.frame(
  fastshap::explain(
    object = rf.res,
    X = as.data.frame(X),
    feature_names = colnames(X),
    pred_wrapper = predict.shap,
    nsim = 20,
    parallel = TRUE
  )
)
# Build df for plot: the original union thermometer values next to their
# Shapley values
g.df <- data.frame(
  unions = X$thermometerunions, # original values
  shaps = fullshaps$thermometerunions # shap values
)
# Build plot: jittered points with a smoothed trend and its interval.
# No colour aesthetic is mapped (both colours are fixed constants), so no
# colour scale is added.
ggplot(g.df, aes(x = unions, y = shaps)) +
  geom_jitter(alpha = .15, color = "forest green") +
  geom_smooth(se = TRUE, color = "black") +
  ylab("Shapley value") +
  xlab("Union Feeling Thermometer") +
  theme_gdocs()

Simple VIP barplot

VIP’s basic bar plot is a decent first cut and works out of the box with ggplot

# Default importance plot for the permutation scores, with the calc theme
vip(object = vi_scores) +
  theme_calc()

Boxplot importance values with uncertainty estimates

Boxplots are another good option with a better representation of uncertainty

# Box plot of the importance scores with presvote2012 (the LDV) pulled out
vip(
  subset(vi_scores, Variable != "presvote2012"),
  geom = "boxplot",
  all_permutations = FALSE # use this to display all permutations
) +
  theme_ggeffects()

Partial Dependence Plots with the PDP package

Another way to make PDP plots is using the PDP package (very powerful! many options!).

Binary/categorical version

# Two-variable partial dependence over healthinsurance and abortion
pd <- partial(
  rf.res, # model
  pred.var = c("healthinsurance", "abortion"), # which vars for grid
  chull = TRUE,
  plot = TRUE, # show plot or only save data
  plot.engine = "ggplot2", # lattice also available
  parallel = TRUE # run in parallel or not
)
# Restyle the returned plot
pd + theme_minimal()

Continuous-(ish) version

# Two-variable partial dependence over the two thermometer items; no plot is
# requested here, so the result is drawn afterwards with plotPartial()
pd <- partial(
  rf.res,
  pred.var = c("thermometerblacklivesmatter", "thermometerrich"),
  parallel = TRUE,
  chull = TRUE
)
plotPartial(pd)

Conformal Inference and LOCO (Experimental!)

Leave-one-covariate-out (LOCO) variable importance with a simple random forest

# Need to do some cleaning: loco only works with numeric matrices - no factors
dat2016[["female"]] <- as.numeric(dat2016[["female"]])
dat2016[["presvote2012"]] <- as.numeric(dat2016[["presvote2012"]])
dat2016[["black"]] <- as.numeric(dat2016[["black"]])
dat2016[["presvote"]] <- as.numeric(dat2016[["presvote"]])

# We want to use random forest; rf.funs() sets up the functions for us.
# In practice, would want to use ideal settings from a CV procedure.
funs <- conformalInference::rf.funs(nodesize = 5, ntree = 2000)

# Set up the X matrix the way loco wants it: every column except the DV
X_mat <- dat2016 %>%
  subset(select = -thermometerblacklivesmatter) %>%
  as.data.frame() %>%
  as.matrix()
#
# Run the LOCO variable importance calculations on the whole data frame;
# with split = NULL the function is left to do the splitting for us
out.loco <- loco(
  x = X_mat, # X matrix
  y = dat2016$thermometerblacklivesmatter, # DV
  train.fun = funs$train.fun, # training function, in this case rf
  predict.fun = funs$predict.fun, # function to extract predictions
  active.fun = funs$active.fun, # checks for active features from training model
  alpha = .1, # uncertainty interval settings (i.e., 90% confidence intervals)
  bonf.correct = TRUE, # correct p-values and intervals for multiple testing
  split = NULL, # no split supplied by us
  seed = 1995 # set the seed for replicability
)
# Create a df for the plot: the X_mat column names alongside the
# inf.wilcox results from the loco output
plot.loco <- data.frame(variable = colnames(X_mat), out.loco$inf.wilcox)
# Plot one error bar per variable, ordered by LowConfPt, with a dashed
# reference line at zero
plot.loco %>%
  mutate(var_order = forcats::fct_reorder(variable, LowConfPt)) %>%
  ggplot(aes(
    x = var_order,
    ymin = LowConfPt,
    ymax = UpConfPt,
    color = UpConfPt
  )) +
  scale_color_gradient2(low = "firebrick1", mid = "blue3", high = "green2") +
  geom_errorbar(width = .1) +
  geom_hline(yintercept = 0, linetype = "dashed") +
  labs(
    title = "LOCO Variable Importance: RF Model of Therms BLM",
    x = "",
    y = "Median Test Error (Delta)"
  ) +
  theme_tufte() +
  theme(axis.text.x = element_text(angle = 45, hjust = 1))

To read more about LOCO, since we didn’t cover it in class, see here (section 6): https://www.stat.cmu.edu/~ryantibs/papers/conformal.pdf