Explainable AI Framework Comparison
Written by Tigran Avetisyan
This is PART 2 of our 3 PART series “Explainable AI Framework Comparison.”
See PART 1 here.
In PART 1, we explored SHAP — a popular framework for explainable AI that uses Shapley values to explain machine learning/deep learning models. We used SHAP to explain MNIST digit classification with TensorFlow.
In PART 2, we are going to have a look at LIME — another very popular AI explainability framework. Although LIME pursues the same goals as SHAP, it’s very different from SHAP in terms of implementation and capabilities.
Below, we are going to use LIME to once again explain MNIST digit classification with TensorFlow. This will allow us to have an apples-to-apples comparison with SHAP.
Then, we’ll outline the distinctions between SHAP and LIME to hopefully help you pick the right framework for your tasks.
Let’s get started!
What is LIME?
LIME (Local Interpretable Model-agnostic Explanations) is an AI explainability framework that was introduced in 2016 by Marco Tulio Ribeiro, Sameer Singh, and Carlos Guestrin in the research paper “Why Should I Trust You?”: Explaining the Predictions of Any Classifier.
LIME works by perturbing input data to observe how the changes affect predictions. In more technical terms, here’s how the authors describe the intuition behind LIME:
Because we want to be model-agnostic, what we can do to learn the behavior of the underlying model is to perturb the input and see how the predictions change. This turns out to be a benefit in terms of interpretability, because we can perturb the input by changing components that make sense to humans (e.g., words or parts of an image), even if the model is using much more complicated components as features (e.g., word embeddings).
We generate an explanation by approximating the underlying model by an interpretable one (such as a linear model with only a few non-zero coefficients), learned on perturbations of the original instance (e.g., removing words or hiding parts of the image). The key intuition behind LIME is that it is much easier to approximate a black-box model by a simple model locally (in the neighborhood of the prediction we want to explain), as opposed to trying to approximate a model globally. This is done by weighting the perturbed images by their similarity to the instance we want to explain. Going back to our example of a flu prediction, the three highlighted symptoms may be a faithful approximation of the black-box model for patients who look like the one being inspected, but they probably do not represent how the model behaves for all patients.
The authors then give an example to illustrate how LIME works for image classification, which is what we are interested in today:
With that, the intuition behind LIME is pretty simple. If you want to find out more, be sure to read the LIME research paper.
Prerequisites for Using LIME
To use LIME for image explanation, you will, first of all, need LIME itself and scikit-image (skimage). LIME uses scikit-image for image segmentation.
To install LIME and scikit-image, run the following commands in the terminal if you are using the pip package manager:
pip install lime
pip install scikit-image
If you are using conda, use the following commands instead:
conda install -c conda-forge lime
conda install scikit-image
You’ll also need TensorFlow:
pip install tensorflow
Or with conda:
conda install -c conda-forge tensorflow
Finally, if you don’t already have them, be sure to install Matplotlib and NumPy.
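If you need them, both can be installed with pip:
pip install matplotlib numpy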
Using LIME to Explore Image Predictions
Now that we understand what LIME is and how it works, let’s have a look at the framework in action!
Importing dependencies
As always, we begin by importing the dependencies for our project:
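The exact imports are in the notebook linked at the end of the article; a minimal sketch could look like this (the font size value is just an example):

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf

from lime import lime_image
from lime.wrappers.scikit_image import SegmentationAlgorithm
from skimage.segmentation import mark_boundaries

# Increase the Matplotlib font size so that plot labels are easy to read
plt.rcParams.update({"font.size": 18})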
We are additionally increasing the font size in Matplotlib so that the image labels in the plots are nice and visible.
Training a convolutional neural network
Next, let’s import and process the MNIST dataset for training:
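The full preprocessing code is in the notebook; a condensed sketch, arranged so that the line references in the next paragraph still line up, might look like this (the pixel scaling step is an assumption):

# Load the MNIST dataset and scale pixel values to [0, 1]
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()
train_images, test_images = train_images / 255.0, test_images / 255.0
# Add a channel dimension to the images
train_images = np.expand_dims(train_images, axis=-1)
test_images = np.expand_dims(test_images, axis=-1)

# Convert the grayscale images to 3-channel RGB
train_images = np.repeat(train_images, 3, axis=-1)
test_images = np.repeat(test_images, 3, axis=-1)

# Get the unique labels in the dataset
labels = np.unique(train_labels)

# Convert the labels to categoricals
train_labels = tf.keras.utils.to_categorical(train_labels)
test_labels = tf.keras.utils.to_categorical(test_labels)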
Like in PART 1, we add a channel dimension to the images (lines 5 and 6), get the unique labels in the dataset (line 13), and convert labels to categoricals (lines 16 and 17). We additionally convert the images from grayscale to RGB (lines 9 and 10) because LIME expects RGB images by default.
Next, we create our model, which is nearly identical to the model we’ve used in PART 1:
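The exact architecture is in the notebook; a comparable sketch (the layer sizes here are assumptions rather than the exact values from PART 1) might be:

model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(28, 28, 3)),  # 3-channel RGB input
    tf.keras.layers.Conv2D(32, (3, 3), activation="relu"),
    tf.keras.layers.MaxPooling2D((2, 2)),
    tf.keras.layers.Conv2D(64, (3, 3), activation="relu"),
    tf.keras.layers.MaxPooling2D((2, 2)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation="relu"),
    tf.keras.layers.Dense(10, activation="softmax"),
])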
The only difference in this model is that the input layer expects 3-channel RGB images.
Finally, let’s compile and train our model:
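A sketch of the compile and fit calls (the optimizer, epoch count, and batch size are assumptions):

model.compile(optimizer="adam",
              loss="categorical_crossentropy",
              metrics=["accuracy"])

model.fit(train_images, train_labels,
          epochs=5,
          batch_size=128,
          validation_data=(test_images, test_labels))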
After training, save the model for later reuse:
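For example (the file name is arbitrary):

model.save("mnist_cnn.h5")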
Using LIME to explain model predictions
Now, we can start using LIME to explain model predictions!
Like in PART 1, let’s get one image per label so that we can explain all digits. We’ll need these later.
https://gist.github.com/tavetisyan95/9798b8b91c580e5d6417839e7e944969
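The gist above contains the original code; one way to collect a single image per digit (the variable names follow how they are used later in the article, but the implementation itself is a guess) is:

# Collect one sample image for each of the ten digits
class_ids = np.argmax(test_labels, axis=1)
labels_to_explain = np.unique(class_ids)
imgs_to_explain = np.array([test_images[np.argmax(class_ids == label)]
                            for label in labels_to_explain])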
Next, let’s explain the outputs of our model. LIME implements the class LimeImageExplainer to explain image models. You instantiate this class as follows:
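LimeImageExplainer has no required constructor arguments, so a minimal instantiation looks like this:

explainer = lime_image.LimeImageExplainer()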
Next, we can define a segmentation algorithm for LIME to use. explainer will use this algorithm to highlight areas on the image that count positively and negatively toward predictions.
To handle segmentation, LIME implements the class SegmentationAlgorithm, which is a wrapper on top of three scikit-image segmentation algorithms — quickshift, felzenszwalb, and slic. By default, LimeImageExplainer uses quickshift, but its default arguments didn’t work for MNIST digit explanation with our model.
Here’s how we instantiate a segmenter with LIME:
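Reconstructed from the three arguments described below, the call looks roughly like this:

segmenter = SegmentationAlgorithm(algo_type="quickshift",
                                  kernel_size=1,
                                  max_dist=2)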
In this code block, we pass three arguments to SegmentationAlgorithm:
· algo_type="quickshift" — defines the segmentation algorithm to use — quickshift in our case.
· kernel_size=1 — sets the width (standard deviation) of the Gaussian kernel used in smoothing the sample density. Higher values lead to fewer clusters and less detailed explanations.
· max_dist=2 — defines the cut-off point for data distances, with higher values meaning fewer clusters.
Note that only the parameter algo_type belongs to SegmentationAlgorithm. The rest belong to quickshift. LIME takes the arguments after algo_type and passes them to the algorithm defined in algo_type. Check out the documentation of skimage.segmentation to learn more about the parameters implemented in each of the supported algorithms.
Once we select a segmentation algorithm, we can generate explanations for a prediction, using the method explainer.explain_instance:
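Reconstructed from the argument list below, the call looks roughly like this:

explanation = explainer.explain_instance(image=imgs_to_explain[1],
                                         classifier_fn=model.predict,
                                         top_labels=10,
                                         num_samples=500,
                                         segmentation_fn=segmenter,
                                         random_seed=5)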
This piece of code will store image explanations in explanation.
We’ve passed a number of arguments to explainer.explain_instance:
· image=imgs_to_explain[1] — a sample image to generate an explanation for.
· classifier_fn=model.predict — a function that outputs prediction probabilities. In our case, we can simply use model.predict.
· top_labels=10 — the number of labels with the highest probabilities that explainer should produce explanations for.
· num_samples=500 — the size of the neighborhood to use in the linear model.
· segmentation_fn=segmenter — the segmentation algorithm used for explanation.
· random_seed=5 — an integer value used as a random seed for the segmentation algorithm.
Finally, to get explanations, run the following piece of code:
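The exact code is in the notebook; a sketch reconstructed from the description below (and laid out so that the line 2 and line 9 references still apply) could be:

# Get the explanation image and mask for one label
temp, mask = explanation.get_image_and_mask(
    label=labels_to_explain[1],
    positive_only=True,
    negative_only=False,
    hide_rest=False,
)
# Superimpose the mask boundaries onto the image and plot the result
plt.imshow(mark_boundaries(temp, mask))
plt.show()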
Here, explanation.get_image_and_mask (line 2) returns:
· The original image with a heatmap on top of it.
· An explanation mask with respect to label=labels_to_explain[1].
Instead of labels_to_explain, you can use the predictions stored in explanation.top_labels, which contains predictions sorted from the highest to lowest probabilities.
Note the following arguments passed to explanation.get_image_and_mask:
· positive_only — whether LIME should only highlight areas that positively contribute to the prediction.
· negative_only — whether LIME should only highlight areas that negatively contribute to the prediction.
· hide_rest — whether to hide irrelevant areas in the explanation. If positive_only=True, LIME will hide negative areas. If positive_only=False and negative_only=True, LIME will hide positive areas.
On line 9, we use the function mark_boundaries from scikit-image to superimpose the boundaries from mask onto temp. We then use plt.imshow to plot the result.
And here’s what the explanation looks like:
In this image, green corresponds to areas that count positively toward the prediction of the digit 1, while red corresponds to areas that count negatively toward the class.
We can inspect temp and mask to get a better idea of what they are. Let’s start with temp:
temp is our original image with a heatmap over it. As for mask:
mask consists of the following values, from dark to bright:
· -1 — areas that count negatively toward the class. These correspond to the red areas in the image explanation from above.
· 0 — areas that don’t count toward the class. These correspond to the gray areas in the explanation.
· 1 — areas that count positively toward the class. These correspond to the green areas in the explanation.
Exploring the capabilities of image explanations in LIME
Above, we’ve only looked at the basics of LIME image explanation. However, there are many parameters that we can tweak to get better explanations. Let’s now have a look at some of the things that LIME can do!
To avoid repetition in the code and make adjustments to the explainer quicker, let’s define several helper objects.
Creating a segmenter
First up, let’s set up a function to help us create segmenters:
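A sketch of such a helper (the function name is a guess):

def create_segmenter(algo_type, **kwargs):
    # Forward any algorithm-specific parameters to SegmentationAlgorithm
    return SegmentationAlgorithm(algo_type, **kwargs)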
We are using **kwargs for the arguments because each of the supported segmentation algorithms has a different set of parameters.
Handling explanation generation and plotting
Next, let’s create a class to help us quickly generate and plot explanations. We are using a class so that we can generate explanations only once and then make a variety of plots by just changing the arguments.
The code for the class looks like this:
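The full class is in the notebook (and is what the line numbers referenced below point to); a simplified skeleton, with guessed method signatures and relying on the imports from earlier, might look like this:

class Explainer:
    def __init__(self, imgs_to_explain):
        # Sample images of the ten digits, plus a list for their explanations
        self.imgs_to_explain = imgs_to_explain
        self.explanations = []

    def explain_instances(self, model, segmenter, top_labels=10,
                          num_samples=1000, random_seed=5):
        # Purge explanations from previous runs
        self.explanations = []
        explainer = lime_image.LimeImageExplainer()
        for img in self.imgs_to_explain:
            self.explanations.append(explainer.explain_instance(
                image=img,
                classifier_fn=model.predict,
                top_labels=top_labels,
                num_samples=num_samples,
                segmentation_fn=segmenter,
                random_seed=random_seed))

    def plot_explanations(self, image_indices, top_predictions=3,
                          positive_only=True, negative_only=False,
                          hide_rest=False):
        # One row per image, one column per top predicted class
        fig, axes = plt.subplots(len(image_indices), top_predictions,
                                 figsize=(4 * top_predictions,
                                          4 * len(image_indices)))
        axes = np.array(axes).reshape(len(image_indices), top_predictions)
        for row, idx in enumerate(image_indices):
            explanation = self.explanations[idx]
            for col in range(top_predictions):
                label = explanation.top_labels[col]
                temp, mask = explanation.get_image_and_mask(
                    label, positive_only=positive_only,
                    negative_only=negative_only, hide_rest=hide_rest)
                axes[row, col].imshow(mark_boundaries(temp, mask))
                axes[row, col].set_title(f"Label: {label}")
                axes[row, col].axis("off")
        plt.show()

    def plot_explanations_for_single_image(self, image_index, labels,
                                           positive_only=True,
                                           negative_only=False,
                                           hide_rest=False):
        # Explanations for one image with respect to each possible class
        explanation = self.explanations[image_index]
        fig, axes = plt.subplots(2, 5, figsize=(20, 8))
        for ax, label in zip(axes.flatten(), labels):
            temp, mask = explanation.get_image_and_mask(
                label, positive_only=positive_only,
                negative_only=negative_only, hide_rest=hide_rest)
            ax.imshow(mark_boundaries(temp, mask))
            ax.set_title(f"Label: {label}")
            ax.axis("off")
        plt.show()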
Here are a few things to keep in mind with this class:
· In the class’s constructor (lines 4 to 7), we define an empty list self.explanations. This list will later store the explanations for imgs_to_explain, which are our sample images of ten digits. We also create an internal variable imgs_to_explain that will store the images of the ten digits for reuse.
· In the method explain_instances (lines 10 to 29), we iterate over the provided images, generate explanations for each of them, and then store the explanations in self.explanations. Before generating new explanations, we purge explanations from previous runs (line 19).
· The method plot_explanations (lines 32 to 84) accepts a number of arguments, including image_indices and top_predictions. image_indices determines the indices of the explainer objects in self.explanations that we want to use. image_indices effectively are the digits that we want to produce explanations for. As an example, if you passed top_labels=3 to explain_instances and wanted to explain all images, image_indices would need to be [0, 1, 2]. top_predictions determines the number of the top predictions for image_indices that we want to explain.
· The method plot_explanations_for_single_image (lines 86 to 126) expects parameters like image_index and labels. This method is intended to generate explanations for a single image with respect to each of the ten possible classes. image_index determines the index of the explainer that we want to use, while labels should be a list of the ten possible classes.
Generating explanations with quickshift
Let’s put our utility tools into action and produce explanations for our images with respect to their top 3 predicted classes. To start, we need to create a segmenter object and an object of the class Explainer:
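For example (assuming the hypothetical create_segmenter helper and the Explainer skeleton sketched above):

segmenter = create_segmenter("quickshift", kernel_size=1, max_dist=2)
my_explainer = Explainer(imgs_to_explain)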
Let’s now generate explanations for imgs_to_explain, which will be stored inside my_explainer:
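With the skeleton above, the call might look like this:

my_explainer.explain_instances(model=model, segmenter=segmenter,
                               top_labels=10, num_samples=1000)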
With num_samples=1000, it may take quite some time for LIME to generate explanations. You can reduce num_samples if the process is taking too long for you.
Once the explaining is done, we can plot the explanations for the top three predicted classes for each of the ten digits:
https://gist.github.com/tavetisyan95/984f811d9539194d0f3e1a6f0522d7b4
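With the sketched class, the call would be along these lines:

my_explainer.plot_explanations(image_indices=list(range(10)),
                               top_predictions=3)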
The resulting plot looks like this:
This plot only outlines the areas that positively contribute to the given prediction. All in all, we can see how different areas in the images contribute to different classes. We can also see that the model mistook 3 for 8 because this particular instance of 3 is similar to 8.
However, the explanations are somewhat unclear because in some of them, you can’t easily tell which areas are positive and which ones are negative. One possible solution to this is to pass hide_rest=True to our plotting function. This will hide negative areas in the explanations, leaving only the positive ones:
And here’s what the resulting plot looks like:
We can now better see which areas in the explanations contribute to the predictions.
Alternatively, we can pass positive_only=False to our plotting function, which will let us see both positive and negative areas in the explanations. This will help us better distinguish between the two.
The result is as follows:
The explanations are clearer here, and we can more easily tell which areas contribute positively (green) or negatively (red) to the predictions.
If you want to see only negative areas in the explanations, pass positive_only=False and negative_only=True to the plotting function. You can additionally pass hide_rest=True to hide everything else:
The plot is as follows:
Let’s now generate explanations for a single image with respect to all the possible labels. As an example, let’s pick the digit 8 — we should get interesting explanations because 8 resembles other digits like 3 or 0.
Here’s what we get with positive_only=True (the default in plot_explanations_for_single_image):
Then with hide_rest=True:
With positive_only=False:
In this particular set of explanations, it’s unclear which areas are positive and which are negative. You could try solving this by passing a larger value to the max_dist parameter of quickshift.
And finally, here’s what we get with positive_only=False, negative_only=True, and hide_rest=True:
Generating explanations with felzenszwalb and slic
Now, let’s also quickly take a look at the two other segmenters that LIME supports — felzenszwalb and slic. For these segmentation algorithms, we’ve tried different combinations of parameters to obtain clear explanations. The parameters you’ll find below should work for you as well, but you could try playing around with them to see how they affect explanations.
To switch segmenters, we need to create a new segmenter and explain our images from scratch. Let’s start with felzenszwalb:
To keep things short, let’s only plot explanations with positive_only=False. For the ten digits, the explanations would look like this with felzenszwalb:
felzenszwalb appears to highlight digit outlines more tightly than quickshift does. quickshift tended to highlight large areas in the background.
And for the digit 8 with respect to all possible classes, the explanations look like this:
And finally, let’s try slic:
For the ten images, slic explanations would look like this:
And for the digit 8 with respect to possible labels, the explanations would be as follows:
In our case, slic seems to provide more fine-grained explanations than felzenszwalb. The segments in the explanations are smaller and perhaps more precise.
All in all, the three segmentation algorithms appear to work, but we struggle to point out a clear winner. With that in mind, rather than give preference to one algorithm over the others, we think you should use them all to explain images from different angles. Remember that the explanations are approximate, so it would be a safer bet to combine different methods to get a better overall picture.
LIME vs SHAP — How Do They Compare?
Now, as promised, let’s briefly go over the differences and similarities between LIME and SHAP. Below are some of the areas where we observed noticeable differences between the two frameworks. But note that the points we will mention relate to MNIST image classification with TensorFlow 2 — your experience might vary depending on what data and models you use.
Ease of Use — Winner: SHAP
All in all, we found that SHAP was more intuitive. If we don’t count the weak support for TensorFlow 2 — which is a separate issue — it was much easier to make image explanations with SHAP.
In LIME, we had to play around with the parameters of the segmenters to make sound explanations. The default values defined in scikit-image and LIME didn’t quite work with the MNIST dataset and our neural network, so we had to try different combinations to arrive at the plots that you saw earlier.
We assume that with other datasets and other models, you might again need to test different parameters to get good explanations.
SHAP was more intuitive because it produced clear explanations pretty much out of the box. The only issue we had was that the maskers inpaint_telea and inpaint_ns wouldn’t generate good explanations. We ended up choosing blur(128, 128) for PART 1 because it worked for our dataset and model.
Other than that, generating explanations with SHAP was just a matter of calling a few functions while passing the desirable arguments to them.
SHAP also has built-in plotting functions, so displaying explanations was very easy. With LIME, you need to make plots yourself — this isn’t at all difficult, but it takes time.
Explanation clarity — Winner: Draw
It’s a tough call when it comes to explanation clarity. We managed to get convincing and good-looking explanations with either of the frameworks.
With that being said, SHAP took relatively little effort to make clear explanations. Provide as much data or allow SHAP to run for as many iterations as you can, and you get pretty good results.
However, although LIME’s segmenters needed some tweaking, it allows you to independently inspect positive and negative areas in the images. SHAP doesn’t offer comparable functionality, though you might be able to achieve similar results if you manually filter the SHAP values.
TensorFlow 2 Support — Winner: LIME
Both SHAP and LIME work with TensorFlow 2, but LIME was the only one that gave us no compatibility issues. With SHAP, if you remember, we had to enable compatibility mode in TensorFlow to use DeepExplainer. Besides, many of the examples in SHAP’s documentation were written for TensorFlow 1, so we had to make some modifications to the code to get it to work.
We think that part of the reason for LIME’s better support for TensorFlow 2 is that LIME’s backend doesn’t rely on TensorFlow/Keras functions much, if at all.
SHAP, in contrast, does some very TensorFlow-specific things under the hood. And as of version 0.40.0, SHAP hadn’t been fully transitioned to TF 2, which caused some compatibility issues for us. So although you can use SHAP with TensorFlow 2, LIME seems to work with it way better.
Documentation — Winner: SHAP
Both LIME and SHAP have weak documentation, but overall, SHAP docs were easier to work with. The SHAP documentation has many usage examples for different tasks, which made getting started quicker for us. Although LIME does have image explanation examples in its GitHub repo, they are just Jupyter notebooks with very little commentary.
Both frameworks had pretty good API references — you can find the API reference for SHAP here and for LIME here. With that said, the SHAP reference seems to be incomplete. As an example, there was no reference to the class GradientExplainer when we published PART 1, and the method __call__ (used to generate explanations) wasn’t listed under Explainer.
In the end, with either of the frameworks, you might often need to dive into source code to figure out what is going on. But SHAP at least has fairly decent examples in the documentation to help you get started quickly.
Next Steps
LIME is quite distinct from SHAP in its implementation, usage, and capabilities. However, we think there is no clear winner in the SHAP vs LIME comparison — each framework has its own advantages and disadvantages. Ideally, you would use both to get a wider set of explanations for your model.
This concludes PART 2! In PART 3 of this series, we are going to integrate LIME into a React application. We’re choosing LIME because of its adjustability and better support for TensorFlow 2.
Stay tuned for PART 3!
Code
You can find all code for this article in the Jupyter notebook here.