Explainable AI Framework Comparison

Part 1: Explaining MNIST Image Classification with SHAP

Kedion
14 min read · Jan 26, 2022
Photo by Shubham Dhage

Written by Tigran Avetisyan

Simple machine learning models can be easy to explain: their architecture is interpretable, and the logic behind their outputs is fairly easy to follow.

But as soon as we move on to complex models — like deep neural networks — things get much, much more difficult. Neural nets that have millions of parameters and hundreds of layers can be extremely difficult to interpret. This is why complex machine and deep learning models are often called black boxes.

This opaqueness is a huge hurdle for the adoption of artificial intelligence in many industries, especially in areas like healthcare. Without a precise understanding of how a model can predict pneumonia or cancer in a patient, medical institutions will have nothing but distrust toward machine learning and deep learning.

Researchers have been working hard to solve this issue, coming up with mathematical methods to help us understand how an AI model makes decisions. There’s still a lot of work to be done in the field of explainable AI, but we already have methods that can be considered somewhat reliable.

Some of these methods have been implemented in AI explainability frameworks. Knowing how to use these frameworks can help us demonstrate the usefulness and value of our AI solutions.

In this series, we are going to explore two popular AI explainability toolsets — SHAP and LIME. More concretely:

· In PART 1, we will focus entirely on SHAP, a framework that adopts a game-theoretic approach to AI explainability. We will use SHAP to explain the output of a convolutional neural network trained with TensorFlow/Keras.

· In PART 2, we will explore LIME — a similar framework in its goals, but very different in execution. We’ll again use a convolutional neural net.

· In PART 3, we will use SHAP or LIME to build a React application and make explanations accessible via a convenient user interface.

For now, let’s focus on SHAP!

What Is SHAP?

SHAP is an AI explainability framework that unifies a number of existing explainability methods to help us better interpret model predictions. SHAP stands for SHapley Additive exPlanations.

The framework was introduced in 2017 by Scott M. Lundberg and Su-In Lee in the research paper A Unified Approach to Interpreting Model Predictions.

The abstract of the paper reads:

. . . the highest accuracy for large modern datasets is often achieved by complex models that even experts struggle to interpret, such as ensemble or deep learning models, creating a tension between accuracy and interpretability. In response, various methods have recently been proposed to help users interpret the predictions of complex models, but it is often unclear how these methods are related and when one method is preferable over another. To address this problem, we present a unified framework for interpreting predictions, SHAP (SHapley Additive exPlanations). SHAP assigns each feature an importance value for a particular prediction. Its novel components include: (1) the identification of a new class of additive feature importance measures, and (2) theoretical results showing there is a unique solution in this class with a set of desirable properties. The new class unifies six existing methods, notable because several recent methods in the class lack the proposed desirable properties.

The paper then lists three results that SHAP brings to the space of explainable AI:

1. We introduce the perspective of viewing any explanation of a model’s prediction as a model itself, which we term the explanation model. This lets us define the class of additive feature attribution methods (Section 2), which unifies six current methods.

2. We then show that game theory results guaranteeing a unique solution apply to the entire class of additive feature attribution methods (Section 3) and propose SHAP values as a unified measure of feature importance that various methods approximate (Section 4).

3. We propose new SHAP value estimation methods and demonstrate that they are better aligned with human intuition as measured by user studies and more effectually discriminate among model output classes than several existing methods (Section 5).
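To make the idea of an additive feature attribution concrete, here is a small exact Shapley value computation for a toy three-feature function. It is a minimal sketch, not SHAP's implementation: the toy model and the choice to fill "missing" features from a single all-zeros baseline are assumptions for illustration (SHAP's explainers use a background dataset or a masker instead). Enumerating every coalition is only feasible for a handful of features, which is why SHAP relies on approximations.

```python
from itertools import combinations
from math import factorial


def shapley_values(f, x, baseline):
    """Exact Shapley values of f at x; features outside a coalition come from baseline."""
    n = len(x)

    def value(coalition):
        # Features in the coalition take their values from x, the rest from the baseline
        z = [x[i] if i in coalition else baseline[i] for i in range(n)]
        return f(z)

    phi = [0.0] * n
    for i in range(n):
        others = [j for j in range(n) if j != i]
        for size in range(n):
            for subset in combinations(others, size):
                s = set(subset)
                # Shapley weight: |S|! (n - |S| - 1)! / n!
                weight = factorial(len(s)) * factorial(n - len(s) - 1) / factorial(n)
                # Marginal contribution of feature i when added to coalition S
                phi[i] += weight * (value(s | {i}) - value(s))
    return phi


def model(z):
    # Hypothetical toy model: one linear term and two pairwise interactions
    return 2 * z[0] + z[1] * z[2] + 3 * z[0] * z[1]


x = [1.0, 2.0, 3.0]
baseline = [0.0, 0.0, 0.0]
phi = shapley_values(model, x, baseline)
print(phi, sum(phi), model(x) - model(baseline))
```

For this model the attributions work out to 5, 6, and 3: the linear term goes entirely to feature 0, and each interaction is split evenly between its two features. They sum to model(x) - model(baseline) = 14, which is the local accuracy property the paper builds on.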

We won’t be covering the theoretical foundations of SHAP in this guide — if you are interested in them, you should read the research paper. We are going to instead focus on the usage of SHAP.

Prerequisites for Using SHAP

In this guide, we are going to use SHAP to explain a convolutional neural network trained on the MNIST digits dataset. We will be using TensorFlow 2 (more precisely, TF 2.7.0) for training.

To run the code for this project, you’ll need the following libraries:

· SHAP and OpenCV. SHAP relies on OpenCV to mask (blur or inpaint) parts of images when explaining image models.

· TensorFlow. We’ll use TensorFlow 2 to train a convolutional neural network. Note that some TF 2 features were unsupported as of SHAP 0.40.0 because SHAP had originally been written for TensorFlow 1 and hadn’t been fully migrated to version 2. Below, we’ll show you how to fix some of the issues with TensorFlow 2.

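· Graphviz and Pydot. We’ll need these libraries to plot tf.keras models with tf.keras.utils.plot_model, which helps us visualize the architecture we are working with. Note that pip install graphviz installs only the Python bindings; plot_model also needs the Graphviz system binaries (the conda-forge graphviz package includes them).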

If you are using the pip package installer, run the following commands to set up dependencies:

pip install shap

pip install opencv-python

pip install tensorflow

pip install graphviz

pip install pydot

If you are using conda, use these commands instead:

conda install -c conda-forge shap

conda install -c conda-forge opencv

conda install -c conda-forge tensorflow

conda install -c conda-forge graphviz

conda install -c conda-forge pydot

You will also need NumPy and Matplotlib, but you probably have these already.

Using SHAP to Explore Image Predictions

After you set up your environment, you can start using SHAP!

Below, we are going to train a convolutional neural network on the MNIST dataset. Then, we’ll use SHAP to see how our model is making predictions. Let’s get started!

Importing dependencies

First up, we need to import a number of dependencies for our project, including SHAP and TensorFlow:

Training a convolutional neural network

Next, let’s train a convolutional neural network on the MNIST dataset. We’re using MNIST to keep things simple. Besides, because MNIST images are small (28 x 28), you should be able to easily train our neural net on any modern machine.

To get started with the training, let’s load and process our data:

In this code block, we:

1. Load the MNIST dataset and store its training and test sets in Python tuples (line 2).

2. Add a channel dimension to the image sets (line 5). The original images had a shape of (28, 28) — after the transformation on line 5, they will have a shape of (28, 28, 1).

3. Get the possible label values from the dataset (line 8). We’re obtaining the labels from the test set, but the train set would work as well. We’ll be using these labels later to have a look at images for each of the ten digits.

4. Convert the train and test labels from integers to a one-hot representation (line 11).
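The article’s code for these steps lives in the accompanying notebook. As a minimal sketch of the same preprocessing, the NumPy code below runs on small random arrays that stand in for MNIST (an assumption so the snippet runs without downloading data), so its layout does not match the line numbers above.

```python
import numpy as np

# Synthetic stand-ins for MNIST (hypothetical data: the real arrays would come from
# tf.keras.datasets.mnist.load_data(), which returns uint8 images of shape (N, 28, 28))
rng = np.random.default_rng(0)
train_images = rng.integers(0, 256, size=(100, 28, 28), dtype=np.uint8)
test_images = rng.integers(0, 256, size=(20, 28, 28), dtype=np.uint8)
train_labels = rng.integers(0, 10, size=100)
test_labels = np.arange(20) % 10  # guarantees every digit appears in the test set

# Step 2: add a trailing channel dimension, (N, 28, 28) -> (N, 28, 28, 1)
train_images = np.expand_dims(train_images, axis=-1)
test_images = np.expand_dims(test_images, axis=-1)

# Step 3: the distinct label values (0-9), taken from the test set
labels = np.unique(test_labels)

# Step 4: one-hot encode the labels; tf.keras.utils.to_categorical produces the same matrix
train_labels = np.eye(10, dtype=np.float32)[train_labels]
test_labels = np.eye(10, dtype=np.float32)[test_labels]

print(train_images.shape, test_images.shape, labels, test_labels.shape)
```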

Next, we define our neural net, using Keras’s Functional API (represented by the class Model):

Our model has three Conv2D and MaxPool2D blocks, followed by GlobalAveragePooling2D and Dense layers. We trained the model on a GPU, but you should be able to train it on a CPU as well because it is pretty simple.
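To see how the image shape changes through this architecture, the following pure-Python sketch traces output shapes layer by layer. It assumes Conv2D layers with padding="same" and 2x2 max pooling, which is consistent with the (None, 14, 14, 16) input of conv2d_1 discussed later; the filter counts of 32 and 64 for the second and third convolutions are hypothetical, and layer names follow Keras’s default naming.

```python
def conv_same(shape, filters):
    # Conv2D with padding="same" and stride 1 keeps height/width and sets the channel count
    h, w, _ = shape
    return (h, w, filters)


def max_pool(shape, pool=2):
    # MaxPool2D with a 2x2 window and default "valid" padding floors odd sizes
    h, w, c = shape
    return (h // pool, w // pool, c)


shape = (28, 28, 1)
trace = [("input", shape)]
# 16 matches the (None, 14, 14, 16) input of conv2d_1; 32 and 64 are hypothetical
for i, filters in enumerate([16, 32, 64]):
    suffix = "" if i == 0 else f"_{i}"
    shape = conv_same(shape, filters)
    trace.append((f"conv2d{suffix}", shape))
    shape = max_pool(shape)
    trace.append((f"max_pooling2d{suffix}", shape))
# Global average pooling collapses height and width, leaving one value per channel
trace.append(("global_average_pooling2d", (shape[-1],)))

for name, s in trace:
    print(f"{name:26s} (None, {', '.join(map(str, s))})")
```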

Finally, let’s compile the model and start training:

Epoch 1/5
469/469 [==============================] - 6s 7ms/step - loss: 0.6678 - accuracy: 0.7944 - val_loss: 0.2399 - val_accuracy: 0.9269
Epoch 2/5
469/469 [==============================] - 3s 7ms/step - loss: 0.2104 - accuracy: 0.9376 - val_loss: 0.1694 - val_accuracy: 0.9490
Epoch 3/5
469/469 [==============================] - 3s 7ms/step - loss: 0.1540 - accuracy: 0.9535 - val_loss: 0.1275 - val_accuracy: 0.9610
Epoch 4/5
469/469 [==============================] - 3s 7ms/step - loss: 0.1297 - accuracy: 0.9610 - val_loss: 0.1044 - val_accuracy: 0.9693
Epoch 5/5
469/469 [==============================] - 3s 6ms/step - loss: 0.1098 - accuracy: 0.9665 - val_loss: 0.0888 - val_accuracy: 0.9728

We’re using a batch size of 128 and are training for 5 epochs. If you encounter OOM (out of memory) issues or if training takes way too long, reduce the batch size, the number of epochs, and/or simplify the model architecture.
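The 469/469 in the training log is the number of batches per epoch. Assuming all 60,000 MNIST training images are used for training (with the test set passed separately as validation data), that count follows directly from the batch size:

```python
import math

train_size = 60_000  # number of MNIST training images
batch_size = 128

# Keras runs one final, smaller batch when the data doesn't divide evenly
steps_per_epoch = math.ceil(train_size / batch_size)
last_batch = train_size - (steps_per_epoch - 1) * batch_size
print(steps_per_epoch, last_batch)
```

Halving the batch size to 64 would roughly double the number of steps per epoch while lowering peak memory use per step.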

You should also save your model for later reuse:

Using SHAP to explain model predictions

Now that we have a model, we can use SHAP to explain its output. We are going to be using these three SHAP classes for explanation:

· Explainer — the basic explainer class in SHAP that works with different kinds of data.

· GradientExplainer — a method that combines SHAP, SmoothGrad, and Integrated Gradients.

· DeepExplainer — a method that combines SHAP and a variant of DeepLIFT.

But before using these classes, let’s get an image for each digit from the dataset. This is so that we can understand how the model behaves with each of the digits.

Picking images to explain

Here’s the code snippet that will allow us to get one image for each label:

In the code block above, we:

1. Create a Python list to store the indices of the selected images (line 2).

2. Iterate over our labels (lines 5 to 7), obtaining the index of the first test label that corresponds to the current label in the for loop (line 6). The index for each label is added to our indices list (line 7).

3. Get the test images at these indices (line 9).
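For illustration, here is a minimal NumPy sketch of the index-picking loop on random data. The one-hot labels and images are synthetic stand-ins, and the variable names (test_labels, labels, imgs_to_explain) are chosen to match the prose rather than copied from the notebook. Because the labels are one-hot encoded at this point, the sketch compares argmax values; np.unique with return_index=True gives the same first-occurrence indices without a loop.

```python
import numpy as np

# Hypothetical data: one-hot test labels (as after preprocessing) and random images
rng = np.random.default_rng(1)
int_labels = rng.integers(0, 10, size=200)
test_labels = np.eye(10)[int_labels]
test_images = rng.random((200, 28, 28, 1))
labels = np.unique(int_labels)

# Loop version: index of the first test sample whose label equals each digit
indices = []
for label in labels:
    idx = int(np.where(test_labels.argmax(axis=1) == label)[0][0])
    indices.append(idx)

# Vectorized alternative: np.unique returns first-occurrence indices directly
_, first_idx = np.unique(test_labels.argmax(axis=1), return_index=True)

# One image per digit, in label order
imgs_to_explain = test_images[indices]
print(indices, imgs_to_explain.shape)
```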

Just in case, let’s confirm that our code picks indices correctly:

All looks good! We can now start using SHAP!

Using Explainer

Explainer is the basic explainer class in SHAP. You can use it not only for images but also for text or tabular features.

Here’s how you instantiate Explainer:

In the code block above, we:

1. Create a masker for our images (line 2). SHAP uses this masker to hide parts of the input images while it evaluates the model: hidden regions are replaced with a blurred version of the image. We passed “blur(128, 128)” as our masking technique, with the numbers indicating the kernel size for blurring. Other supported mask values are “inpaint_telea” and “inpaint_ns”. Under the hood, SHAP relies on OpenCV for masking.

2. Create an Explainer object (line 6). We pass to it the prediction function of our model, the masker, and our unique labels as strings, and instruct Explainer to select an explanation algorithm automatically.
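Conceptually, a blur masker works like the sketch below: when SHAP hides a set of pixels, those pixels are replaced with the corresponding pixels of a blurred copy of the image, and the model is evaluated on the result. This is a simplified NumPy illustration under stated assumptions: the small 5x5 box blur stands in for the OpenCV blur SHAP uses, and the hidden region (the left half of the image) is a hypothetical choice.

```python
import numpy as np


def box_blur(img, k):
    # Mean filter with an odd k x k window; edges are padded by repeating border pixels
    pad = k // 2
    padded = np.pad(img, pad, mode="edge")
    out = np.empty_like(img, dtype=float)
    h, w = img.shape
    for r in range(h):
        for c in range(w):
            out[r, c] = padded[r:r + k, c:c + k].mean()
    return out


def apply_mask(img, keep, blurred):
    # Keep visible pixels where keep is True; replace hidden ones with blurred values
    return np.where(keep, img, blurred)


rng = np.random.default_rng(0)
img = rng.random((28, 28))
blurred = box_blur(img, 5)

# Hypothetical coalition: only the right half of the image stays visible
keep = np.zeros((28, 28), dtype=bool)
keep[:, 14:] = True
masked = apply_mask(img, keep, blurred)
print(masked.shape)
```

An explainer repeats this for many different keep masks and attributes changes in the model output to the regions that were hidden or revealed.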

Next, we pass images to the Explainer object to obtain SHAP values and plot them:

On line 6 in the code block above, outputs selects which model outputs (classes) to explain. In our case, we sort and flip the outputs so that classes are explained in order of predicted probability, from highest to lowest.
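Selecting the top classes is a plain sort. The NumPy sketch below uses hypothetical probabilities to show the argsort-then-flip pattern; SHAP’s own image examples express the same idea with outputs=shap.Explanation.argsort.flip[:4].

```python
import numpy as np

# Hypothetical class probabilities for two images over the ten digits
probs = np.array([
    [0.01, 0.02, 0.05, 0.70, 0.01, 0.02, 0.01, 0.03, 0.14, 0.01],
    [0.90, 0.01, 0.01, 0.01, 0.01, 0.01, 0.02, 0.01, 0.01, 0.01],
])

# argsort sorts ascending; flipping the last axis puts the most probable class first
order = np.flip(np.argsort(probs, axis=1), axis=1)
top4 = order[:, :4]
print(top4)
```

For the first image this yields classes 3, 8, 2, and 7, so the explanation for the predicted digit comes first, followed by the runner-up classes.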

The resulting plot looks like this:

In the plot, red areas increase the probability of the class, while blue areas decrease it. You can roughly see where the model is looking when predicting each class, which can be immensely helpful when trying to figure out whether your model makes its predictions for the right reasons.

Explainer evaluates the model many times on differently masked versions of each image; the more evaluations, the more refined the results. The plot above was generated with 1,000 evaluations (max_evals=1000). If the calculation takes too long for you, pass a smaller value to max_evals when calling explainer.

Below, we test a few values for max_evals to give you an idea of how it can impact results.

Here’s how the plot looks for 100 evaluations:

For 500:

And for 1,000:

You can see that the plots are really blocky at low values and get progressively smoother as we go through more evaluations. However, no matter the number of evaluations, our plots don’t seem convincing. In particular, the explanations for the digit 3 seem to have low detail. The other two methods may be able to provide better explanations.

Using GradientExplainer

Another method you can use to explain image outputs in SHAP is GradientExplainer. The GitHub README file of the framework describes GradientExplainer as follows:

Expected gradients combines ideas from Integrated Gradients, SHAP, and SmoothGrad into a single expected value equation. This allows an entire dataset to be used as the background distribution (as opposed to a single reference value) and allows local smoothing. If we approximate the model with a linear function between each background data sample and the current input to be explained, and we assume the input features are independent then expected gradients will compute approximate SHAP values.
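To make the expected gradients equation concrete, here is a minimal NumPy sketch for a toy model whose gradient is known analytically. It is a deterministic variant chosen so the result is reproducible: it averages over every background sample and uses a fixed grid of interpolation points, whereas SHAP’s GradientExplainer samples background points and interpolation points at random. The toy sigmoid model and all values are hypothetical.

```python
import numpy as np


def model(x, w):
    # Toy differentiable "model": sigmoid of a weighted sum (stand-in for a network output)
    return 1.0 / (1.0 + np.exp(-(x @ w)))


def model_grad(x, w):
    # Analytic gradient of the sigmoid model with respect to each input feature
    s = model(x, w)
    return (s * (1.0 - s))[..., None] * w


def expected_gradients(x, background, w, steps=200):
    # Deterministic variant: every background sample, midpoint-rule alphas in (0, 1)
    alphas = (np.arange(steps) + 0.5) / steps
    attributions = np.zeros_like(x)
    for ref in background:
        path = ref + alphas[:, None] * (x - ref)      # points between ref and x
        avg_grad = model_grad(path, w).mean(axis=0)   # average gradient along the path
        attributions += (x - ref) * avg_grad
    return attributions / len(background)


rng = np.random.default_rng(0)
w = np.array([1.5, -2.0, 0.5])
background = rng.normal(size=(50, 3))
x = np.array([1.0, 0.5, -1.0])

phi = expected_gradients(x, background, w)
gap = model(x, w) - model(background, w).mean()
print(phi, phi.sum(), gap)
```

Up to a small integration error, the attributions sum to the model output at x minus its average output over the background, which is the SHAP-style additivity the quote refers to.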

A very useful feature of GradientExplainer is that you can use it to inspect the intermediate layers in a model.

Note that the SHAP docs only show the usage of GradientExplainer with TensorFlow 1. We didn’t use TensorFlow 1 or TensorFlow 2’s compatibility features for GradientExplainer. Instead, we adapted the code from the documentation to TensorFlow 2, using the functionality of tf.keras.Model to explain intermediate layers.

To calculate SHAP values for intermediate layers, we will need to do a few tricks. To understand why, let’s have a look at the architecture of our model, using plot_model from tf.keras.utils:

Suppose that we want to inspect the second convolutional layer in our model — conv2d_1. To help us do this, GradientExplainer needs to repeatedly feed inputs into this layer to calculate SHAP values.

To be able to feed images directly to conv2d_1, we need to create a separate model that has conv2d_1 as the first layer (not counting the input). This model will look like this:

Note the input shape of this model: (None, 14, 14, 16), where None is the batch dimension. We can’t feed our (None, 28, 28, 1) MNIST images to this model directly because the input shape is wrong. But this is easy to fix: we just need to create another model that contains the layers before conv2d_1. This model will look like this:

This model will take our (None, 28, 28, 1) MNIST images and produce their intermediary representations, which will have a shape of (None, 14, 14, 16) — exactly what conv2d_1 needs!

And here’s how we create the two models from above:

For model_input, we reuse our original model’s input, but its output is the output of the layer just before the intermediate layer we are interested in (in this case, conv2d_1). This model will produce processed images in the shape required by conv2d_1, that is, (None, 14, 14, 16).

model_output takes the processed images and produces class predictions. We can use model_output to calculate and visualize SHAP values.
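The idea behind the two models can be sketched without TensorFlow. Below, a model is a hypothetical chain of NumPy "layers", and splitting it at layer k yields a model_input (layers before k) and a model_output (layer k onward) whose composition reproduces the full model. In Keras, one common way to build the first half is tf.keras.Model(inputs=model.input, outputs=model.layers[k - 1].output); the second half needs a new tf.keras.Input with the intermediate shape passed through the remaining layers. This works for a simple chain of layers like ours, not for arbitrary graphs with skip connections.

```python
import numpy as np


def relu(x):
    return np.maximum(x, 0.0)


rng = np.random.default_rng(0)
W1 = rng.normal(size=(8, 16))
W2 = rng.normal(size=(16, 16))
W3 = rng.normal(size=(16, 10))

# Hypothetical toy "layers" standing in for a chain of Keras layers
layers = [
    lambda x: relu(x @ W1),  # layer 0
    lambda x: relu(x @ W2),  # layer 1: the intermediate layer to explain
    lambda x: x @ W3,        # layer 2: class scores
]


def run(chain, x):
    # Apply layers in order, like calling a model on x
    for layer in chain:
        x = layer(x)
    return x


def split_at(chain, k):
    # model_input: layers before k; model_output: layer k and everything after it
    def model_input(x):
        return run(chain[:k], x)

    def model_output(h):
        return run(chain[k:], h)

    return model_input, model_output


x = rng.normal(size=(5, 8))
model_input, model_output = split_at(layers, 1)
hidden = model_input(x)  # what layer 1 receives: shape (5, 16)
print(hidden.shape, np.allclose(model_output(hidden), run(layers, x)))
```

GradientExplainer then only ever sees model_output, with the hidden representations from model_input serving as both the background data and the inputs to explain.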

To obtain SHAP values for an intermediary layer, we will be using the following function:

In the code block above, we:

1. Create a GradientExplainer object (lines 8 and 9), supplying model_output and model_input.predict(test_images). The latter returns the outputs of the layer that precedes the intermediate layer we are interested in, and GradientExplainer integrates over these samples to produce SHAP values.

2. Obtain the SHAP values for the processed images and the indices of the top three predicted classes for each image (lines 12 and 13).

3. Plot the SHAP values on top of the input images with respect to the predicted classes (lines 16 to 18).

Let’s use calculate_plot_gradient_explainer to explain all three convolutional layers in our model:

In this code block, we:

1. Get the indices of the convolutional layers in our original model (line 2).

2. Iterate over the indices (lines 4 to 16).

3. Create our models to process the images (lines 6 and 10).

4. Calculate the SHAP values and generate plots for the current layer index (lines 13 to 16).
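The layer-index bookkeeping in this loop can be sketched in pure Python. The layer names below are hypothetical, based on Keras’s default naming for the architecture described earlier (functional models list the InputLayer first); in real code, checking isinstance(layer, tf.keras.layers.Conv2D) on model.layers is more robust than matching names.

```python
# Hypothetical layer names following Keras's default naming for our architecture
layer_names = [
    "input_1",
    "conv2d", "max_pooling2d",
    "conv2d_1", "max_pooling2d_1",
    "conv2d_2", "max_pooling2d_2",
    "global_average_pooling2d",
    "dense",
]

# Step 1: positions of the convolutional layers
conv_indices = [i for i, name in enumerate(layer_names) if name.startswith("conv2d")]

# Steps 2-3: for each conv layer, model_input runs the layers before it and
# model_output starts at it (the InputLayer does no computation, so it is skipped)
splits = []
for idx in conv_indices:
    before = layer_names[1:idx]
    after = layer_names[idx:]
    splits.append((layer_names[idx], before, after))

for name, before, after in splits:
    print(f"{name}: model_input runs {before or 'nothing'}, model_output starts at {after[0]}")
```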

The resulting plots look like this for the first convolutional layer:

The second convolutional layer:

And the third convolutional layer:

You can generate SHAP values for the MaxPool2D layers as well, by the way.

From the plots above, you can see that the SHAP values follow the outlines and distinctive features in the images. If there were any irregularities or abnormalities in the plots, we could instantly notice them and tweak our model to hopefully obtain more sensible outputs.

You can also see that the plots get coarser as we get closer to the output layer. This is because the outputs of convolutional layers are downscaled by the max pooling layers.
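The downscaling is easy to see with a minimal NumPy sketch of 2x2 max pooling (stride 2, odd edges dropped, matching Keras’s default "valid" padding). Assuming the convolutions preserve spatial size, the maps reaching conv2d, conv2d_1, and conv2d_2 are 28x28, 14x14, and 7x7, so the third layer’s SHAP values live on a grid where each cell covers roughly a 4x4 patch of the original image. Only spatial dimensions are modeled here; channels are ignored.

```python
import numpy as np


def max_pool_2x2(fmap):
    # 2x2 max pooling with stride 2; an odd trailing row/column is dropped ("valid")
    h, w = fmap.shape[0] // 2 * 2, fmap.shape[1] // 2 * 2
    trimmed = fmap[:h, :w]
    # Group pixels into 2x2 blocks and take the maximum of each block
    return trimmed.reshape(h // 2, 2, w // 2, 2).max(axis=(1, 3))


fmap = np.arange(28 * 28, dtype=float).reshape(28, 28)
sizes = [fmap.shape]
for _ in range(3):
    fmap = max_pool_2x2(fmap)
    sizes.append(fmap.shape)
print(sizes)
```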

As an alternative to inspecting intermediate layers, you can also explain the input layer, like so:

Here, we just pass our original model to the parameter model and test_images to data. We also pass imgs_to_explain directly to gradient_explainer.shap_values.

Although we’re explaining the model’s input here rather than an intermediate layer, the SHAP values are the same as for the first convolutional layer we inspected earlier. That’s because the first convolutional layer comes right after the input, so the data it receives is the images themselves.

Using DeepExplainer

Finally, let’s try DeepExplainer. Here’s the description of this algorithm from SHAP’s GitHub README:

Deep SHAP is a high-speed approximation algorithm for SHAP values in deep learning models that builds on a connection with DeepLIFT described in the SHAP NIPS paper. The implementation here differs from the original DeepLIFT by using a distribution of background samples instead of a single reference value, and using Shapley equations to linearize components such as max, softmax, products, divisions, etc. Note that some of these enhancements have also been since integrated into DeepLIFT.

The usage of DeepExplainer is similar to that of Explainer and GradientExplainer. However, as of SHAP 0.40.0, DeepExplainer didn’t work with TensorFlow 2.

Because DeepExplainer doesn’t fully support TensorFlow versions 2.0 and above, we need to enable compatibility mode with TensorFlow 1. We can do this by calling this function:

After you run this function, TensorFlow will not use eager execution or tf.function — key components of TensorFlow 2.

After V2 behavior is disabled, we need to rebuild our model from scratch and load the saved weights. We need to do this so that TensorFlow builds the model as TF 1-style graph operations.

Next, we need to take a few samples from our training images, use them to generate SHAP values, and plot them:

Like GradientExplainer, DeepExplainer integrates over a background dataset (data) to produce explanations.
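To illustrate the "distribution of background samples" idea from the README quote, here is a minimal NumPy sketch of DeepLIFT’s rescale rule on a toy network with a single ReLU unit, averaged over a background sample. It is not SHAP’s implementation, and the network and values are hypothetical. For each reference, the attributions sum exactly to the change in output, so the averaged attributions sum to f(x) minus the mean output over the background.

```python
import numpy as np


def relu(z):
    return np.maximum(z, 0.0)


def f(x, w, b, v):
    # Toy one-unit network: v * relu(w . x + b)
    return v * relu(x @ w + b)


def deeplift_rescale(x, ref, w, b, v):
    # Attributions relative to a single reference using DeepLIFT's rescale rule
    z_x, z_r = x @ w + b, ref @ w + b
    dz = z_x - z_r
    if abs(dz) > 1e-12:
        m = (relu(z_x) - relu(z_r)) / dz  # finite-difference multiplier through the ReLU
    else:
        m = float(z_x > 0)                # fall back to the ReLU gradient
    return v * m * w * (x - ref)


def deep_shap_sketch(x, background, w, b, v):
    # Average single-reference attributions over the background sample
    return np.mean([deeplift_rescale(x, r, w, b, v) for r in background], axis=0)


rng = np.random.default_rng(0)
w, b, v = np.array([1.0, -2.0, 0.5]), 0.1, 3.0
background = rng.normal(size=(25, 3))
x = np.array([2.0, -1.0, 1.0])

phi = deep_shap_sketch(x, background, w, b, v)
print(phi, phi.sum(), f(x, w, b, v) - f(background, w, b, v).mean())
```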

The plot would look like this:

This plot is pretty interesting. We can see that the model finds features of the digit 8 in the digit 3. It’s easy to see why: this particular handwritten 3 looks very similar to an 8. Likewise, the model sees features of the digit 3 in the digit 8.

And once again, plots like this can help you see if there’s anything wrong with your model’s outputs.

Next Steps

The theoretical foundations of SHAP can be hard to grasp. But in practice, the framework is easy to use — if we disregard the spotty support of TF 2 and the quite lackluster documentation, that is.

Nonetheless, SHAP appears to be a strong choice for explainable AI. We’ve demonstrated its uses for image classification, but it can be used for tabular and text data as well.

In PART 2 of this series, we are going to be shifting our attention to LIME — another popular AI interpretability framework. We’ll not only have a look at the features and capabilities of LIME but will also compare it with SHAP.

Stay tuned for PART 2!

Code

You can find all code for this article in the Jupyter notebook here.


Written by Kedion

Kedion brings rapid development and product discovery to machine learning & AI solutions.
