Towards a better understanding of why machine learning models make the decisions they do, and why it matters
Model explainability is one of the most important problems in machine learning today. “Black box” models such as deep neural networks are often deployed to production, where they run critical systems in everything from your workplace security cameras to your smartphone. It’s a scary thought that not even the developers of these algorithms understand exactly why the algorithms make the decisions they do or, even worse, how to prevent an adversary from exploiting them.
While there are many challenges facing the designer of a “black box” algorithm, it’s not completely hopeless. There are actually many different ways to illuminate the decisions a model makes. It’s even possible to understand which features are the most salient in a model’s predictions.
In this article, I give a comprehensive overview of model explainability for deeper models in machine learning. I hope to explain how deeper models traditionally considered “black boxes” can actually be surprisingly explainable. I focus on model-agnostic methods, which can be applied to all different kinds of black box models.
A partial dependence plot shows the effect of a feature on the outcome of an ML model.
Partial dependence works by marginalizing the machine learning model output over the distribution of the features we are not interested in (the features in set C). As a result, the partial dependence function shows the relationship between the features we do care about (the features in set S) and the predicted outcome. By marginalizing over the other features, we get a function that depends only on the features in S. This makes it easy to understand how varying a specific feature influences model predictions. For example, here are 3 PDPs for Temperature, Humidity and Wind Speed as they relate to bike rentals predicted by a linear model.
PDPs can even be used for categorical features. Here’s one for the effect of season on bike rentals.
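Written out, the marginalization described above looks roughly like this. The notation is my own assumption rather than something fixed by the text: \hat{f} is the fitted model, x_S are the features of interest, X_C are the remaining features, and x_C^{(i)} are their observed values in the n rows of the dataset. The second expression is the usual estimate obtained by averaging over the data.

```latex
\hat{f}_S(x_S) = \mathbb{E}_{X_C}\!\left[\hat{f}(x_S, X_C)\right]
               = \int \hat{f}(x_S, x_C)\, d\mathbb{P}(x_C)

\hat{f}_S(x_S) \approx \frac{1}{n} \sum_{i=1}^{n} \hat{f}\!\left(x_S, x_C^{(i)}\right)
```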
For classification, the partial dependence plot displays the probability for a certain class given different values for features. A good way to deal with multi-class problems is to have one PDP per class.
The partial dependence plot method is useful because it is global. It makes a point about the global relationship between a certain feature and a target outcome across all values of that feature.
Partial Dependence Plots are highly intuitive. The partial dependence function for a feature at a value represents the average prediction if we have all data points assume that feature value.
A partial dependence plot can realistically show at most two features at a time, because anything beyond that is hard to visualize.
Assumption of independence: You are assuming the features that you are plotting are not correlated with any other features. For example, if you are predicting blood pressure from height and weight, you have to assume that height is not correlated with weight. This is because you have to average over the marginal distribution of weight if you are plotting height (or vice versa). This means, for example, that you can end up with very small weights for someone who is quite tall, which you would probably not see in your actual dataset.
Great! I want to implement a PDP for my model. Where do I start?
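One possible starting point is a from-scratch sketch that only assumes your model exposes a prediction function. The names `partial_dependence_1d` and `toy_model` below are hypothetical, and the toy model stands in for a real fitted model.

```python
import numpy as np

def partial_dependence_1d(predict, X, feature, grid):
    """Average prediction when every row takes each grid value for one feature."""
    X = np.asarray(X, dtype=float)
    averages = []
    for value in grid:
        X_mod = X.copy()
        X_mod[:, feature] = value                # fix the feature of interest (set S)
        averages.append(predict(X_mod).mean())   # average over the other features (set C)
    return np.array(averages)

# Hypothetical stand-in for a fitted model: any function mapping rows to predictions.
def toy_model(X):
    return 3.0 * X[:, 0] + X[:, 1] ** 2

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 2))
grid = np.linspace(-2.0, 2.0, 5)
pd_values = partial_dependence_1d(toy_model, X, feature=0, grid=grid)
print(pd_values)  # rises by 3 for every unit step in feature 0
```

If you already use scikit-learn, its `sklearn.inspection` module provides `partial_dependence` and `PartialDependenceDisplay.from_estimator`, which cover the same idea for fitted estimators.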
Permutation feature importance is a way to measure the importance of a feature by calculating the change in a model’s prediction error after permuting the feature. A feature is “important” if permuting its values increases the model error, and “unimportant” if permuting the values leaves the model error unchanged.
The algorithm works as follows:
Input: a model f, feature matrix X, target vector Y and error measure L(y, f).
1. Estimate the original model error e⁰ = L(Y, f(X)).
2. For each feature j:
- Generate feature matrix X' by permuting feature j in the original feature matrix X.
- Estimate the new error e¹ = L(Y, f(X')) based on the model's predictions for the new data X'.
- Calculate the permutation feature importance FI = e¹/e⁰. You can also use e¹ - e⁰.
3. Sort the features by descending FI.
After you have sorted the features by descending FI, you can plot the results. Here is the permutation feature importance plot for the bike rentals problem.
Interpretability: Feature importance is just how much the error increases when a feature is distorted. This is easy to explain and visualize.
Permutation feature importance provides global insight into the model’s behavior.
Permutation feature importance does not require training a new model or retraining an existing model. It only requires shuffling feature values.
It’s not clear whether you should use training or test data for your plot.
If features are correlated, you can get unrealistic samples after permuting features, biasing the outcome.
Adding a correlated feature to your model can decrease the importance of another feature.
Great! I want to implement Permutation Feature Importance for my model. Where do I start?
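A minimal sketch of the algorithm above is shown here, using the ratio e¹/e⁰ as the importance. The function names, the toy model, and the `n_repeats` averaging over several shuffles are my own assumptions added for illustration; the algorithm as described uses a single permutation per feature.

```python
import numpy as np

def permutation_feature_importance(predict, X, y, loss, n_repeats=5, seed=0):
    """Return FI = e1 / e0 per feature and the feature order by descending FI."""
    rng = np.random.default_rng(seed)
    X = np.asarray(X, dtype=float)
    base_error = loss(y, predict(X))                  # e0: original model error
    importances = np.zeros(X.shape[1])
    for j in range(X.shape[1]):
        ratios = []
        for _ in range(n_repeats):                    # assumption: average several shuffles
            X_perm = X.copy()
            X_perm[:, j] = rng.permutation(X_perm[:, j])
            ratios.append(loss(y, predict(X_perm)) / base_error)  # e1 / e0
        importances[j] = np.mean(ratios)
    order = np.argsort(importances)[::-1]             # descending FI
    return importances, order

def mse(y_true, y_pred):
    return float(np.mean((y_true - y_pred) ** 2))

# Hypothetical stand-in for a fitted model that ignores feature 1 entirely.
def toy_model(X):
    return 2.0 * X[:, 0]

rng = np.random.default_rng(1)
X = rng.normal(size=(300, 2))
y = 2.0 * X[:, 0] + rng.normal(scale=0.5, size=300)
importances, order = permutation_feature_importance(toy_model, X, y, mse)
print(importances, order)
```

Because the toy model never reads feature 1, shuffling it leaves the error unchanged and its ratio stays at 1, while shuffling feature 0 inflates the error. For fitted scikit-learn estimators, `sklearn.inspection.permutation_importance` implements a similar procedure based on the drop in score.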
ALE plots are a faster and unbiased alternative to partial dependence plots. They measure how features influence the prediction of a model. Because they are unbiased, they handle correlated features much better than PDPs do.
If features of a machine learning model are correlated, the partial dependence plot cannot be trusted, because you can generate samples that are very unlikely in reality by varying a single feature. ALE plots solve this problem by working with the conditional distribution of the features and by calculating differences in predictions instead of averages. One way to interpret this is by thinking of the ALE as saying
“Let me show you how the model predictions change in a small “window” of the feature.”
Here’s a visual interpretation of what is going on in an ALE plot.
This can also be done with two features.
Once you have computed the differences in predictions over each window, you can generate an ALE plot.
ALE plots can also be done for categorical features.
ALE plots are unbiased, meaning they work with correlated features.
ALE plots are fast to compute.
The interpretation of the ALE plot is clear.
The implementation of ALE plots is complicated and difficult to understand.
Interpretation still remains difficult if features are strongly correlated.
Second-order or 2D ALE plots can be hard to interpret.
Generally, it is better to use ALE plots over PDPs, especially if you expect correlated features.
Great! I want to implement ALE plots for my model. Where do I start?
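As a rough starting point, here is a sketch of a first-order ALE for a single numeric feature. It follows the "window" idea described above: split the feature into quantile-based windows, average the prediction differences across each window using only the points that fall inside it, then accumulate and center. The function name, the quantile binning, and the centering rule are assumptions on my part; library implementations differ in these details.

```python
import numpy as np

def ale_1d(predict, X, feature, n_bins=10):
    """First-order ALE of one numeric feature, evaluated at the window edges."""
    X = np.asarray(X, dtype=float)
    x = X[:, feature]
    # Quantile-based edges so each window holds a similar number of points.
    edges = np.unique(np.quantile(x, np.linspace(0.0, 1.0, n_bins + 1)))
    n_windows = len(edges) - 1
    # Window k covers (edges[k], edges[k + 1]]; the minimum value joins window 0.
    window = np.clip(np.searchsorted(edges, x, side="left") - 1, 0, n_windows - 1)

    local_effects = np.zeros(n_windows)
    counts = np.zeros(n_windows)
    for k in range(n_windows):
        rows = X[window == k]
        if len(rows) == 0:
            continue
        lower, upper = rows.copy(), rows.copy()
        lower[:, feature] = edges[k]
        upper[:, feature] = edges[k + 1]
        # Prediction difference across the window, for points inside it only.
        local_effects[k] = np.mean(predict(upper) - predict(lower))
        counts[k] = len(rows)

    ale = np.concatenate([[0.0], np.cumsum(local_effects)])  # accumulate the effects
    # Center so that the count-weighted average effect is zero.
    midpoints = (ale[:-1] + ale[1:]) / 2.0
    ale -= np.sum(midpoints * counts) / np.sum(counts)
    return edges, ale

# Hypothetical stand-in for a fitted model.
def toy_model(X):
    return 3.0 * X[:, 0] + X[:, 1] ** 2

rng = np.random.default_rng(0)
x0 = rng.normal(size=500)
x1 = x0 + rng.normal(scale=0.3, size=500)   # strongly correlated with x0
X = np.column_stack([x0, x1])
edges, ale = ale_1d(toy_model, X, feature=0)
```

Plotting `ale` against `edges` gives the ALE curve. For this toy model the curve rises with slope 3, even though the two features are strongly correlated.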
Individual Conditional Expectation (ICE) plots display one line per data point. Each line shows how the model’s prediction for that data point changes as a feature varies. The plot below shows the ICE plots for temperature, humidity and wind speed across all instances in the bike rental training set.
Looking at this plot, you may ask yourself: what is the point of looking at an ICE plot instead of a PDP? It seems much less interpretable.
PDPs can only show you what the average relationship between a feature and a prediction looks like. This only works well if the interactions between the features for which the PDP is calculated and the other features are weak. In the case of strong interactions, the ICE plot will be more insightful.
Like PDP plots, ICE plots are very intuitive to understand.
ICE plots can uncover heterogeneous relationships better than PDP plots can.
ICE curves can only display one feature at a time.
The plots generated by this method can be hard to read and overcrowded.
Great! I want to implement ICE for my model. Where do I start?
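Here is a small sketch of how ICE curves can be computed, assuming only a prediction function. The names `ice_curves` and `toy_model` are hypothetical. The toy model contains an interaction on purpose, so that the individual curves disagree with each other while their average (the PDP) is flat.

```python
import numpy as np

def ice_curves(predict, X, feature, grid):
    """One row per data point, one column per grid value of the chosen feature."""
    X = np.asarray(X, dtype=float)
    curves = np.empty((X.shape[0], len(grid)))
    for g, value in enumerate(grid):
        X_mod = X.copy()
        X_mod[:, feature] = value        # vary one feature, keep the rest as observed
        curves[:, g] = predict(X_mod)
    return curves

# Hypothetical model with an interaction: the effect of feature 0 flips
# depending on the sign of feature 1.
def toy_model(X):
    return X[:, 0] * X[:, 1]

rng = np.random.default_rng(0)
X = np.column_stack([rng.normal(size=100), np.tile([-1.0, 1.0], 50)])
grid = np.linspace(-1.0, 1.0, 5)
curves = ice_curves(toy_model, X, feature=0, grid=grid)
pdp = curves.mean(axis=0)                # averaging the ICE curves gives the PDP
```

Half of the curves slope upward and half slope downward, which the averaged PDP hides completely. In scikit-learn, `PartialDependenceDisplay.from_estimator` accepts `kind="individual"` or `kind="both"` to draw ICE curves for a fitted estimator.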
A surrogate model is an interpretable model (such as a decision tree or linear model) that is trained to approximate the predictions of a black box. We can understand the black box better by interpreting the surrogate model’s decisions.
The algorithm for generating a surrogate model is straightforward.
1. Select a dataset X that you can run your black box model on.
2. For the selected dataset X, get the predictions of your black box model.
3. Select an interpretable model type (e.g. linear model or decision tree).
4. Train the interpretable model on the dataset X and the black box's predictions.
5. Measure how well the surrogate model replicates the predictions of the black box model.
6. Interpret the surrogate model.
One way to measure how well the surrogate replicates the black box is the R-squared metric.
The R-squared metric is a way to measure the variance captured by the surrogate model. An R-squared value close to 1 implies the surrogate model captures the variance well, and close to 0 implies that it is capturing very little variance, and not explaining the black box model well.
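For reference, the metric can be written as follows. The symbols are my own notation: \hat{y}_i is the black box prediction for instance i, \hat{y}^{*}_i is the surrogate's prediction for the same instance, and \bar{\hat{y}} is the mean of the black box predictions over the n instances.

```latex
R^2 = 1 - \frac{\sum_{i=1}^{n} \left(\hat{y}^{*}_i - \hat{y}_i\right)^2}
               {\sum_{i=1}^{n} \left(\hat{y}_i - \bar{\hat{y}}\right)^2}
```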
This approach is intuitive: you are learning what the black box model thinks is important by approximating it.
Easy to measure: It’s clear how well the interpretable model performs in approximating the black box through the R-squared metric.
The surrogate model (for example, a linear model) may not approximate the black box model well.
You are drawing conclusions about the black box model and not the actual data, as you are using the black box’s model predictions as labels without seeing the ground truth.
Even if you do approximate the black box model well, the explainability of the “interpretable” model may not actually represent what the black box model has learned.
It may be difficult to explain the interpretable model.
Great! I want to implement a surrogate model. Where do I start?
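The steps above can be sketched with a linear surrogate fitted by ordinary least squares. This is a minimal illustration under my own assumptions: `black_box` is a hypothetical stand-in for a real model, and `fit_linear_surrogate` is a made-up helper name.

```python
import numpy as np

def fit_linear_surrogate(black_box_predict, X):
    """Fit a linear surrogate to black box predictions and report R-squared."""
    X = np.asarray(X, dtype=float)
    y_bb = black_box_predict(X)                             # step 2: black box predictions
    design = np.column_stack([np.ones(len(X)), X])          # intercept plus features
    coef, *_ = np.linalg.lstsq(design, y_bb, rcond=None)    # step 4: train the surrogate
    y_sur = design @ coef
    ss_err = np.sum((y_sur - y_bb) ** 2)
    ss_tot = np.sum((y_bb - y_bb.mean()) ** 2)
    r_squared = 1.0 - ss_err / ss_tot                       # step 5: fidelity to the black box
    return coef, r_squared

# Hypothetical black box: mostly linear, with a small nonlinear wiggle.
def black_box(X):
    return 4.0 * X[:, 0] - 2.0 * X[:, 1] + 0.5 * np.sin(5.0 * X[:, 0])

rng = np.random.default_rng(0)
X = rng.uniform(-1.0, 1.0, size=(500, 2))
coef, r_squared = fit_linear_surrogate(black_box, X)
print(coef, r_squared)   # coef holds [intercept, weight of feature 0, weight of feature 1]
```

Note that the R-squared here compares the surrogate with the black box's predictions, not with ground-truth labels, so a high value only says the surrogate mimics the black box well.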
As machine learning becomes more prominent in daily life, interpretability is more important than ever before. I believe a few trends will characterize the future of interpretability, and these will shape how we interact with AI models in the future.
Model-agnostic interpretability focus
All the trends in deep learning research point to the fact that deep networks are not saturating with our current computing and data limits. As our models get deeper and deeper in everything from image recognition to text generation, there is a need for methods that can provide interpretability across all types of models. This generalizability will become more and more useful as machine learning takes hold in different fields. The methods I discussed in this blog post are a start, but we need to take interpretability more seriously as a whole to better understand why the machine learning systems powering our day-to-day lives are making the decisions they do.
Models that explain themselves
One trend that has not yet taken hold in most ML systems, but that I believe will in the future, is the self-explaining model. Most systems today simply make a decision with reasons that are opaque to the user. In the future, I believe that will change. If a self-driving car makes a decision to stop, we will know why. If Alexa cannot understand our sentence, it will tell us in specific detail what went wrong and how we can phrase our query more clearly. With models that explain themselves, we can better understand how the ML systems in our lives work.
Increased model scrutiny
Finally, I believe that we as a society have swept black-box model scrutiny under the rug. We do not understand the decisions that our models are making, and that doesn’t seem to be bothering anyone in particular. This will have to change in the future. Engineers and data scientists will be held accountable as models start to make mistakes, and this will lead to a trend where we examine the decisions our models make with the same rigor we would apply to the datasets they are trained on.
I hope you enjoyed this post. I certainly found it illuminating to write, and I hope it helps you with your studies or research in the field of machine learning.
Most of the examples here are inspired by the excellent Interpretable Machine Learning book.
(Molnar, Christoph. “Interpretable machine learning. A Guide for Making Black Box Models Explainable”, 2019. https://christophm.github.io/interpretable-ml-book/)
I highly encourage you to buy it if you wish to further your knowledge in the topic.