18  Exercise Solution: Explainable AI (Penguins Classification Dataset)

In this notebook, we’ll be exploring how to use a couple of different explainable AI techniques.

We’ll be using a different dataset to take a look into this today. The penguins dataset is a great dataset for practising classification problems. This data has been pulled using the excellent palmerpenguins package.

LINK

What we need to know is that

In this exercise you will need to go through the code and fill in any missing spaces.

By the end of this exercise you should know:

- how to calculate feature importance using the MDI method for tree-based models
- how to calculate feature importance for any model using the permutation feature importance method
- how to create partial dependence plots (PDPs) and individual conditional expectation (ICE) plots for any model
- how to use the SHAP library to understand a model

The SHAP code does vary subtly for different kinds of model; we will just be working with an XGBoost model in this case to match what we’ve done in the lecture.

18.0.0.1 Library Imports

from sklearn import datasets
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split
from xgboost import XGBClassifier
from sklearn import metrics
from sklearn.inspection import permutation_importance
from sklearn.inspection import PartialDependenceDisplay

import shap

# JavaScript for SHAP plots
shap.initjs()

# Helper function to see methods in object
# Might be useful when working through this exercise

def object_methods(obj):
    '''
    Helper function to list methods associated with an object
    '''
    try:
        methods = [method_name for method_name in dir(obj)
                   if callable(getattr(obj, method_name))]
        print('Below are the methods for object: ', obj)
        for method in methods:
            print(method)
    except Exception as e:
        print("Error:", e)

18.0.0.2 Load & Clean Data

Run this cell to load the dataframe.

penguins = pd.read_csv("../datasets/penguins.csv")

Examine the dataset with your choice(s) of function(s).

penguins.head()
bill_length_mm bill_depth_mm flipper_length_mm body_mass_g year island_Biscoe island_Dream island_Torgersen male target
0 32.1 15.5 188.0 3050.0 2009 0 1 0 0.0 0
1 33.1 16.1 178.0 2900.0 2008 0 1 0 0.0 0
2 33.5 19.0 190.0 3600.0 2008 0 0 1 0.0 0
3 34.0 17.1 185.0 3400.0 2008 0 1 0 0.0 0
4 34.1 18.1 193.0 3475.0 2007 0 0 1 NaN 0
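The `head()` output above already shows a `NaN`, so one useful way to examine the dataset is to count missing values. Here is a minimal sketch on a small made-up dataframe (the `toy` frame and its values are assumptions for illustration, not the real penguins data):

```python
import numpy as np
import pandas as pd

# Hypothetical toy frame with a couple of missing values
toy = pd.DataFrame({
    'bill_length_mm': [32.1, np.nan, 33.5],
    'male': [0.0, 1.0, np.nan],
})

# Number of missing values in each column
missing_per_column = toy.isna().sum()

# Number of rows that have at least one missing value
rows_with_missing = toy.isna().any(axis=1).sum()

print(missing_per_column)
print("Rows with any missing value:", rows_with_missing)
```

The same two calls should work unchanged on `penguins`.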

Run the code below to convert the classes 0, 1 and 2 into the relevant species names and add this as a new column.

Try to understand how this is working - it’s a useful little pattern to know for your own datasets!

First, we are going to create a dictionary. Can you remember what we call the parts before and the parts after the colon in the dictionary?

# Define the different classes/ species
class_dict = {0 : 'Adelie',
             1 : 'Chinstrap',
             2 : 'Gentoo'}

class_dict
{0: 'Adelie', 1: 'Chinstrap', 2: 'Gentoo'}

Now we are going to use our dictionary for creating the column.

How do you think this is working? You may want to look up the get method of the standard python dictionary, and the apply method of pandas, just to understand a little more about this very useful way of making new conditional columns.

# Add species into the dataframe
penguins['species'] = penguins['target'].apply(lambda x: class_dict.get(x))

# view a random sample of 10 rows
penguins.sample(10)
bill_length_mm bill_depth_mm flipper_length_mm body_mass_g year island_Biscoe island_Dream island_Torgersen male target species
197 50.7 19.7 203.0 4050.0 2009 0 1 0 1.0 1 Chinstrap
269 46.5 13.5 210.0 4550.0 2007 1 0 0 0.0 2 Gentoo
151 NaN NaN NaN NaN 2007 0 0 1 NaN 0 Adelie
169 46.4 17.8 191.0 3700.0 2008 0 1 0 0.0 1 Chinstrap
12 35.2 15.9 186.0 3050.0 2009 0 0 1 0.0 0 Adelie
245 45.1 14.5 215.0 5000.0 2007 1 0 0 0.0 2 Gentoo
307 49.2 15.2 221.0 6300.0 2007 1 0 0 1.0 2 Gentoo
252 45.4 14.6 211.0 4800.0 2007 1 0 0 0.0 2 Gentoo
68 38.3 19.2 189.0 3950.0 2008 0 1 0 1.0 0 Adelie
207 51.5 18.7 187.0 3250.0 2009 0 1 0 1.0 1 Chinstrap
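To see what the `get` and `apply` pattern above is doing, here is a small sketch on a made-up series of class codes. The code `3` is deliberately absent from the dictionary (an assumption for illustration) to show how unknown keys behave, and `Series.map` is shown as a common alternative for the same lookup.

```python
import pandas as pd

# Hypothetical toy data: the code 3 is deliberately missing from the dictionary
codes = pd.Series([0, 2, 1, 3])
class_dict = {0: 'Adelie', 1: 'Chinstrap', 2: 'Gentoo'}

# dict.get returns None for a missing key instead of raising a KeyError
via_apply = codes.apply(lambda x: class_dict.get(x))

# dict.get also accepts a fallback value for unknown keys
via_default = codes.apply(lambda x: class_dict.get(x, 'Unknown'))

# Series.map with a dictionary performs the same lookup; unknown keys become missing
via_map = codes.map(class_dict)

print(via_apply.tolist())
print(via_default.tolist())
print(via_map.tolist())
```

Using `get` rather than `class_dict[x]` means an unexpected code produces a missing value (or your chosen fallback) rather than an error part-way through the column.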

Let’s take a look at some stats about the data to get an idea of the scale and distribution of the different features. Run the cell below to do this.

# Take a look at some stats about the data
penguins.describe()
bill_length_mm bill_depth_mm flipper_length_mm body_mass_g year island_Biscoe island_Dream island_Torgersen male target
count 342.000000 342.000000 342.000000 342.000000 344.000000 344.000000 344.000000 344.000000 333.000000 344.000000
mean 43.921930 17.151170 200.915205 4201.754386 2008.029070 0.488372 0.360465 0.151163 0.504505 0.918605
std 5.459584 1.974793 14.061714 801.954536 0.818356 0.500593 0.480835 0.358729 0.500732 0.893320
min 32.100000 13.100000 172.000000 2700.000000 2007.000000 0.000000 0.000000 0.000000 0.000000 0.000000
25% 39.225000 15.600000 190.000000 3550.000000 2007.000000 0.000000 0.000000 0.000000 0.000000 0.000000
50% 44.450000 17.300000 197.000000 4050.000000 2008.000000 0.000000 0.000000 0.000000 1.000000 1.000000
75% 48.500000 18.700000 213.000000 4750.000000 2009.000000 1.000000 1.000000 0.000000 1.000000 2.000000
max 59.600000 21.500000 231.000000 6300.000000 2009.000000 1.000000 1.000000 1.000000 1.000000 2.000000

18.0.0.3 Plot the Data

Before we go any further, let’s plot the penguins dataset to see how bill length and bill depth relate to the species.

Fill in the gaps below to create the plot

adelie = penguins[penguins.species=='Adelie']
chinstrap = penguins[penguins.species == "Chinstrap"]
gentoo = penguins[penguins.species=='Gentoo']

fig, ax = plt.subplots()
fig.set_size_inches(13, 7) # adjusting the length and width of plot

# labels and scatter points
ax.scatter(adelie['bill_length_mm'], adelie['bill_depth_mm'], label="Adelie", facecolor="blue")
ax.scatter(chinstrap['bill_length_mm'], chinstrap['bill_depth_mm'], label="Chinstrap", facecolor="green")
ax.scatter(gentoo['bill_length_mm'], gentoo['bill_depth_mm'], label="Gentoo", facecolor="red")


ax.set_xlabel("Bill Length (mm)")
ax.set_ylabel("Bill Depth (mm)")
ax.grid()
ax.set_title("Penguin Bill Measurements by Species")
ax.legend()

Now it’s your turn. In the space below, make a copy of the plot, but this time for two other columns: flipper length and body mass.

adelie = penguins[penguins.species=='Adelie']
chinstrap = penguins[penguins.species == "Chinstrap"]
gentoo = penguins[penguins.species=='Gentoo']

fig, ax = plt.subplots()
fig.set_size_inches(13, 7) # adjusting the length and width of plot

# labels and scatter points
ax.scatter(adelie['flipper_length_mm'], adelie['body_mass_g'], label="Adelie", facecolor="blue")
ax.scatter(chinstrap['flipper_length_mm'], chinstrap['body_mass_g'], label="Chinstrap", facecolor="green")
ax.scatter(gentoo['flipper_length_mm'], gentoo['body_mass_g'], label="Gentoo", facecolor="red")


ax.set_xlabel("Flipper Length (mm)")
ax.set_ylabel("Body Mass (g)")
ax.grid()
ax.set_title("Penguin Flipper Length and Body Mass by Species")
ax.legend()

18.0.0.4 Performing Classification

Time to get on to the machine learning aspect.

It’s always important to spend some time understanding your dataset first, though! What we’ve done above is just the tip of the iceberg, but it’s a good start.

Now we’re going to process our dataset for machine learning as we did in the logistic regression, decision tree and boosted tree sessions - this time we’re going to be using XGBoost.

# Define features (X) and target (y)
# Dropping the target and species since we only need the measurements
X = penguins.drop(['target','species'], axis=1)
y = penguins['target']

# get class and features names
class_names = penguins.species.unique()
feature_names = X.columns

# Splitting into train and test
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2,
                                                    random_state=42)

# Instantiate an XGBoost model and fit it
model = XGBClassifier(random_state=42)

model.fit(X_train, y_train)
XGBClassifier(base_score=None, booster=None, callbacks=None,
              colsample_bylevel=None, colsample_bynode=None,
              colsample_bytree=None, device=None, early_stopping_rounds=None,
              enable_categorical=False, eval_metric=None, feature_types=None,
              gamma=None, grow_policy=None, importance_type=None,
              interaction_constraints=None, learning_rate=None, max_bin=None,
              max_cat_threshold=None, max_cat_to_onehot=None,
              max_delta_step=None, max_depth=None, max_leaves=None,
              min_child_weight=None, missing=nan, monotone_constraints=None,
              multi_strategy=None, n_estimators=None, n_jobs=None,
              num_parallel_tree=None, objective='multi:softprob', ...)

18.0.1 Feature Importance

18.0.1.0.1 Look at feature importance using the feature_importances_ attribute
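# Looking at the model's built-in feature importance
# With XGBoost's default settings this is gain-based (see the importance_type parameter),
# which plays the same role as mean decrease in impurity for tree-based models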
importances = model.feature_importances_
importances
array([0.10309331, 0.01894252, 0.5356301 , 0.00305671, 0.00406754,
       0.07005206, 0.1871977 , 0.07796   , 0.        ], dtype=float32)

18.0.1.1 Mean decrease in impurity

Generate a plot of the MDI feature importances.

feature_names = X.columns.tolist()

model_importances_mdi_series = pd.Series(importances, index=feature_names)

fig, ax = plt.subplots(figsize=(15,10))
model_importances_mdi_series.plot.bar(ax=ax)
ax.set_title("Feature importances using MDI")
ax.set_ylabel("Mean decrease in impurity")
fig.tight_layout()
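To get a feel for what impurity-based importances look like on data where we know the answer, here is a minimal sketch using a scikit-learn random forest on synthetic data. The dataset, the feature names and the choice of `RandomForestClassifier` are all assumptions for illustration; this is not the penguins model.

```python
import pandas as pd
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

# Synthetic data (illustrative only): with shuffle=False the first two
# columns are informative and the last two are pure noise
X_demo, y_demo = make_classification(
    n_samples=300, n_features=4, n_informative=2, n_redundant=0,
    shuffle=False, random_state=42)
demo_names = ['informative_0', 'informative_1', 'noise_0', 'noise_1']

forest = RandomForestClassifier(n_estimators=50, random_state=42)
forest.fit(X_demo, y_demo)

# Impurity-based importances are normalised, so they sum to 1
mdi = pd.Series(forest.feature_importances_, index=demo_names)
mdi = mdi.sort_values(ascending=False)

print(mdi)
print("Total:", mdi.sum())
```

Sorting the series before plotting, as done here, usually makes the bar chart easier to read.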

18.0.2 Permutation Feature Importance

Calculate permutation feature importance for this dataset and plot it.

This will include error bars.

# Permutation feature importance
result = permutation_importance(
    model, X_test, y_test, n_repeats=10, random_state=42)

model_importances_pfi_series = pd.Series(result.importances_mean, index=feature_names)

fig, ax = plt.subplots(figsize=(15,10))
model_importances_pfi_series.plot.bar(yerr=result.importances_std, ax=ax)
ax.set_title("Feature importances using permutation on full model")
ax.set_ylabel("Mean accuracy decrease")
fig.tight_layout()
plt.show()
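It can help to see what permutation importance does under the hood: shuffle one column at a time, re-score the model, and record how much the score drops. The sketch below is a simplified hand-rolled version on synthetic data with a logistic regression (both are assumptions for illustration); it is not scikit-learn's actual implementation, just the same idea.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

# Synthetic data (illustrative only): columns 0-1 informative, columns 2-3 noise
X_demo, y_demo = make_classification(
    n_samples=400, n_features=4, n_informative=2, n_redundant=0,
    n_clusters_per_class=1, shuffle=False, random_state=0)
X_tr, X_te, y_tr, y_te = train_test_split(
    X_demo, y_demo, test_size=0.25, random_state=0)

clf = LogisticRegression().fit(X_tr, y_tr)
baseline = clf.score(X_te, y_te)  # accuracy on untouched test data

rng = np.random.default_rng(0)
n_repeats = 10
drops = np.zeros((X_te.shape[1], n_repeats))

for j in range(X_te.shape[1]):
    for r in range(n_repeats):
        X_perm = X_te.copy()
        # Shuffle one column to break its link with the target
        X_perm[:, j] = rng.permutation(X_perm[:, j])
        drops[j, r] = baseline - clf.score(X_perm, y_te)

manual_mean = drops.mean(axis=1)  # comparable in spirit to importances_mean
manual_std = drops.std(axis=1)    # comparable in spirit to importances_std

print("Mean accuracy decrease:", manual_mean)
print("Std across repeats:   ", manual_std)
```

The spread across repeats is where the error bars in the plot above come from: each repeat uses a different random shuffle.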

18.0.2.1 Predictions

Use the model to make predictions for the training and test sets.

# Training and test predictions
y_pred_train = model.predict(X_train)
y_pred_test = model.predict(X_test)

18.0.2.2 Assessing Performance

Run this code to generate metrics for the training and test performance of this model.

accuracy_train = np.mean(y_pred_train == y_train)
accuracy_test = np.mean(y_pred_test == y_test)

print (f'Accuracy of predicting training data = {accuracy_train:.2%}')
print (f'Accuracy of predicting test data = {accuracy_test:.2%}')
Accuracy of predicting training data = 100.00%
Accuracy of predicting test data = 97.10%
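Accuracy alone does not show which classes are being confused with each other. As a small sketch, the `metrics` module imported earlier offers `accuracy_score` and `confusion_matrix`; the labels below are made up for illustration and are not the model's real predictions.

```python
import numpy as np
from sklearn import metrics

# Hypothetical true and predicted labels for three classes
y_true = np.array([0, 0, 1, 1, 2, 2])
y_pred = np.array([0, 0, 1, 2, 2, 2])

# The manual calculation used above and scikit-learn's helper agree
acc_manual = np.mean(y_pred == y_true)
acc_sklearn = metrics.accuracy_score(y_true, y_pred)

# Rows are true classes, columns are predicted classes
cm = metrics.confusion_matrix(y_true, y_pred)

print(f"Manual accuracy:  {acc_manual:.2%}")
print(f"sklearn accuracy: {acc_sklearn:.2%}")
print(cm)
```

Here one class-1 sample is predicted as class 2, which appears as the off-diagonal entry in the second row of the matrix.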

18.1 PDP Plots

Now let’s create a partial dependence plot for flipper length.

fig, ax = plt.subplots(figsize=(10, 6))
display = PartialDependenceDisplay.from_estimator(
    model,  # Your fitted model
    X,  # Your feature matrix
    features=['flipper_length_mm'],  # List of features to plot
    target=0,
    kind='average',  # Type of PDP
    ax=ax,
    random_state=42
)
plt.show()
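A partial dependence curve can be computed by hand: fix the chosen feature at a grid value for every row, predict, and average the predictions. The sketch below shows that idea on synthetic data with a logistic regression; the data, the model and the helper name `manual_pdp` are assumptions for illustration, and scikit-learn's own implementation has more options (grid percentiles, response method and so on).

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

# Synthetic binary classification data (illustrative only)
X_demo, y_demo = make_classification(
    n_samples=200, n_features=3, n_informative=2, n_redundant=0,
    n_clusters_per_class=1, random_state=1)
clf = LogisticRegression().fit(X_demo, y_demo)


def manual_pdp(model, X, feature_idx, grid):
    """Average predicted probability of class 1 with one feature forced to each grid value."""
    averages = []
    for value in grid:
        X_mod = X.copy()
        X_mod[:, feature_idx] = value  # overwrite the feature for every row
        averages.append(model.predict_proba(X_mod)[:, 1].mean())
    return np.array(averages)


grid = np.linspace(X_demo[:, 0].min(), X_demo[:, 0].max(), 20)
pdp = manual_pdp(clf, X_demo, feature_idx=0, grid=grid)

print(pdp.round(3))
```

Because every other feature keeps its observed values, the curve shows the model's average response to the chosen feature, not a causal effect in the real world.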

Now create two plots side-by-side for bill length and bill depth.

HINT: You don’t need to create multiple separate plots using matplotlib for this - you can do it from within the graphing function we’re using from scikit-learn.

fig, ax = plt.subplots(figsize=(10, 6))
display = PartialDependenceDisplay.from_estimator(
    model,  # Your fitted model
    X,  # Your feature matrix
    features=['bill_length_mm', 'bill_depth_mm'],  # List of features to plot
    target=0,
    kind='average',  # Type of PDP
    ax=ax,
    random_state=42
)
plt.show()

18.2 ICE Plots

Now create three ICE plots of the same feature - one for each class. Make sure to give each plot a title.

fig, ax = plt.subplots(figsize=(10, 6))
display = PartialDependenceDisplay.from_estimator(
    model,  # Your fitted model
    X,  # Your feature matrix
    features=['bill_length_mm'],  # List of features to plot
    target=0,
    kind='individual',  # Type of PDP
    ax=ax,
    random_state=42
)
plt.title("Adelie Penguins - Bill Length ICE Plot")
plt.show()

fig, ax = plt.subplots(figsize=(10, 6))
display = PartialDependenceDisplay.from_estimator(
    model,  # Your fitted model
    X,  # Your feature matrix
    features=['bill_length_mm'],  # List of features to plot
    target=1,
    kind='individual',  # Type of PDP
    ax=ax,
    random_state=42
)
plt.title("Chinstrap Penguins - Bill Length ICE Plot")
plt.show()

fig, ax = plt.subplots(figsize=(10, 6))
display = PartialDependenceDisplay.from_estimator(
    model,  # Your fitted model
    X,  # Your feature matrix
    features=['bill_length_mm'],  # List of features to plot
    target=2,
    kind='individual',  # Type of PDP
    ax=ax,
    random_state=42
)
plt.title("Gentoo Penguins - Bill Length ICE Plot")
plt.show()

Now, just for one of the classes, create an ICE plot for bill_length_mm that also shows the average of all the ICE plots - a joint PDP/ICE plot, effectively!

Again, make sure you provide a title.

fig, ax = plt.subplots(figsize=(10, 6))
display = PartialDependenceDisplay.from_estimator(
    model,  # Your fitted model
    X,  # Your feature matrix
    features=['bill_length_mm'],  # List of features to plot
    target=2,
    kind='both',  # ICE curves plus their average (the PDP)
    ax=ax,
    random_state=42
)
plt.title("Gentoo Penguins - Bill Length Joint ICE/PDP Plot")
plt.show()
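The joint plot works because the PDP line is simply the average of the individual ICE curves. The sketch below checks this numerically using `sklearn.inspection.partial_dependence` with `kind='both'`, on synthetic data with a logistic regression (both assumptions for illustration, not the penguins model).

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.inspection import partial_dependence
from sklearn.linear_model import LogisticRegression

# Synthetic binary classification data (illustrative only)
X_demo, y_demo = make_classification(
    n_samples=200, n_features=3, n_informative=2, n_redundant=0,
    n_clusters_per_class=1, random_state=1)
clf = LogisticRegression().fit(X_demo, y_demo)

# kind='both' returns the ICE curves and their average in one call
result = partial_dependence(
    clf, X_demo, features=[0], kind='both', grid_resolution=15)

ice = result['individual'][0]  # one curve per sample: (n_samples, n_grid_points)
pdp = result['average'][0]     # one averaged curve: (n_grid_points,)

print("ICE shape:", ice.shape)
print("PDP shape:", pdp.shape)
print("PDP equals mean of ICE curves:", np.allclose(ice.mean(axis=0), pdp))
```

When the ICE curves fan out or cross, the average can hide that variation, which is why plotting both together is often more informative than the PDP alone.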

18.2.0.1 BONUS: 2D PDP Plots

Now create a 2D plot of bill length and bill depth.

PartialDependenceDisplay.from_estimator(
    model,
    X_test,
    features=[('bill_length_mm', 'bill_depth_mm')],
    kind='average',
    target=0,
    random_state=0
)

18.3 SHAP

We have a multiclass problem with our penguins dataset.

This results in some slightly different outputs from our SHAP code, which can be confusing to deal with, so for now we’re just going to focus on a binary classification problem - is a penguin an Adelie, or not?

Run the code below to turn this into a binary classification problem and retrain the model.

penguins_binary = penguins.copy()

# If Adelie penguin, return 1, else return 0
penguins_binary['target'] = np.where(penguins_binary['target'] == 0, 1, 0)
penguins_binary['species'] = np.where(penguins_binary['species'] == "Adelie", "Adelie", "Not Adelie")

# Define features (X) and target (y)
# Dropping the target and species since we only need the measurements
X = penguins_binary.drop(['target','species'], axis=1)
y = penguins_binary['target']

# get class and features names
class_names = penguins_binary.species.unique()
feature_names = X.columns

# Splitting into train and test
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2,
                                                    random_state=42)

# Instantiate an XGBoost model and fit it
model = XGBClassifier(random_state=42)

model.fit(X_train, y_train)
XGBClassifier(base_score=None, booster=None, callbacks=None,
              colsample_bylevel=None, colsample_bynode=None,
              colsample_bytree=None, device=None, early_stopping_rounds=None,
              enable_categorical=False, eval_metric=None, feature_types=None,
              gamma=None, grow_policy=None, importance_type=None,
              interaction_constraints=None, learning_rate=None, max_bin=None,
              max_cat_threshold=None, max_cat_to_onehot=None,
              max_delta_step=None, max_depth=None, max_leaves=None,
              min_child_weight=None, missing=nan, monotone_constraints=None,
              multi_strategy=None, n_estimators=None, n_jobs=None,
              num_parallel_tree=None, random_state=42, ...)

18.3.1 Obtaining SHAP Values

Generate a SHAP explainer for our model, using X_train as the background data.

# Compute SHAP values
explainer = shap.Explainer(
    model,
    X_train
    )

Now create the shap_values object, using X_test as the foreground data.

shap_values = explainer.shap_values(X_test)
shap_values
array([[-5.17051824e+00,  1.21827941e+00, -5.74065812e-02,
        -5.04409203e-02,  0.00000000e+00,  0.00000000e+00,
        -3.68462939e-01, -3.35151969e-01, -1.36875771e-01],
       [ 1.13075847e-01,  1.78565094e+00, -4.36455202e-02,
        -5.04409203e-02,  0.00000000e+00,  0.00000000e+00,
        -3.68462939e-01, -2.35543021e-01, -1.91024542e-01],
       [-2.72566813e-02, -3.30757442e+00, -2.22220667e-01,
         4.70017666e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01, -3.31137460e-01, -1.91024542e-01],
       [-5.10160202e+00,  1.83535931e+00,  7.00742684e-02,
        -5.04409203e-02,  0.00000000e+00,  0.00000000e+00,
        -3.68462939e-01, -3.35151969e-01,  1.13559266e-01],
       [-4.14983763e+00, -1.87998490e+00, -1.60803193e-01,
         4.70017666e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01, -3.84721798e-01,  1.13559266e-01],
       [-4.44195123e+00, -1.58787129e+00, -1.60803193e-01,
         4.70017666e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01, -3.84721798e-01,  1.13559266e-01],
       [-4.44195123e+00, -1.58787129e+00, -1.60803193e-01,
         4.70017666e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01, -3.84721798e-01,  1.13559266e-01],
       [-4.01430746e+00, -1.98892677e+00, -1.60803193e-01,
         4.70017666e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01, -3.84721798e-01, -1.36875771e-01],
       [ 5.45284030e+00,  1.74075824e+00, -3.02373941e-01,
         1.54761915e-02,  0.00000000e+00,  0.00000000e+00,
        -3.68462939e-01, -8.35388462e-02,  1.08201245e-01],
       [ 1.91534915e+00, -3.58559378e+00, -3.02373941e-01,
         4.70017666e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01, -2.18033250e-01, -1.91024542e-01],
       [-4.44195123e+00, -1.58787129e+00, -1.60803193e-01,
         4.70017666e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01, -3.84721798e-01,  1.13559266e-01],
       [ 4.68026578e+00,  1.44937979e+00,  1.60711127e-01,
        -2.46472679e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01,  1.42733446e+00, -1.30585920e-01],
       [ 6.04219020e+00,  1.30955449e+00,  9.22989947e-02,
        -2.46472679e-02,  0.00000000e+00,  0.00000000e+00,
        -3.68462939e-01, -9.35887119e-02,  1.08201245e-01],
       [ 4.79211754e+00, -1.02131687e-03,  1.60711127e-01,
        -2.46472679e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01,  1.54396007e+00, -1.30585920e-01],
       [-4.13248416e+00, -1.89733836e+00, -1.60803193e-01,
         4.70017666e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01, -3.84721798e-01,  1.13559266e-01],
       [ 6.44204529e+00, -5.33143660e-01,  9.22989947e-02,
        -2.46472679e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01, -1.01846663e-01, -1.30585920e-01],
       [ 5.71326724e+00,  1.66361659e+00,  1.75061948e-01,
        -2.46472679e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01, -8.35388462e-02,  1.08201245e-01],
       [-4.44524486e-02, -3.30757442e+00, -2.22220667e-01,
        -5.04409203e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01, -3.31137460e-01, -1.91024542e-01],
       [-4.04089575e+00, -1.98892677e+00, -1.60803193e-01,
         4.70017666e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01, -3.84721798e-01,  1.13559266e-01],
       [ 4.97731472e+00,  1.03570520e+00,  1.60711127e-01,
        -2.46472679e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01,  1.54396007e+00, -1.30585920e-01],
       [ 3.33277721e+00,  2.57971638e+00,  1.75061948e-01,
         4.70017666e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01, -1.46053878e-01,  1.59685997e-01],
       [ 5.91554046e+00,  1.47107445e+00,  9.22989947e-02,
        -2.46472679e-02,  0.00000000e+00,  0.00000000e+00,
        -3.68462939e-01, -9.02227896e-02, -1.30585920e-01],
       [-4.10589588e+00, -1.89733836e+00, -1.60803193e-01,
         4.70017666e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01, -3.84721798e-01, -1.36875771e-01],
       [ 6.58941987e+00,  8.37323978e-01,  1.75061948e-01,
        -2.46472679e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01, -9.51627192e-02, -1.30585920e-01],
       [ 2.60384000e+00, -4.47425562e-01,  9.22989947e-02,
        -5.04409203e-02,  0.00000000e+00,  0.00000000e+00,
        -3.68462939e-01, -2.50891294e-01, -1.91024542e-01],
       [-5.13327016e+00, -1.40029112e-01,  7.00742684e-02,
        -5.04409203e-02,  0.00000000e+00,  0.00000000e+00,
        -3.68462939e-01, -3.83302972e-01,  1.13559266e-01],
       [ 5.68476218e+00,  1.66361659e+00,  9.22989947e-02,
        -2.46472679e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01, -9.02227896e-02,  1.08201245e-01],
       [-5.16359000e+00,  9.08194313e-01,  7.00742684e-02,
         4.70017666e-02,  0.00000000e+00,  0.00000000e+00,
        -3.68462939e-01, -3.44841380e-01,  1.13559266e-01],
       [-3.45011731e+00,  2.42737406e+00,  7.00742684e-02,
         4.70017666e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01, -3.18269527e-01,  1.59685997e-01],
       [-3.79519642e+00, -2.13501758e+00,  7.00742684e-02,
        -5.04409203e-02,  0.00000000e+00,  0.00000000e+00,
        -3.68462939e-01, -3.71910748e-01, -1.91024542e-01],
       [ 6.07069525e+00,  1.30955449e+00,  1.75061948e-01,
        -2.46472679e-02,  0.00000000e+00,  0.00000000e+00,
        -3.68462939e-01, -8.69047685e-02,  1.08201245e-01],
       [ 4.36485172e+00,  1.75760771e+00,  1.60711127e-01,
        -2.46472679e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01,  1.39628446e+00,  1.08201245e-01],
       [-5.06125901e+00,  4.69742856e-02,  4.39282524e-02,
        -5.04409203e-02,  0.00000000e+00,  0.00000000e+00,
        -3.68462939e-01, -3.85268838e-01, -1.36875771e-01],
       [ 2.94871511e+00,  1.16633540e+00,  1.27291407e-01,
        -5.04409203e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01,  2.49448815e+00, -1.91024542e-01],
       [ 5.94404552e+00,  1.47107445e+00,  1.75061948e-01,
        -2.46472679e-02,  0.00000000e+00,  0.00000000e+00,
        -3.68462939e-01, -8.35388462e-02, -1.30585920e-01],
       [ 3.45756491e+00,  6.67898319e-01,  1.75061948e-01,
        -5.04409203e-02,  0.00000000e+00,  0.00000000e+00,
        -3.68462939e-01, -2.16614424e-01,  1.59685997e-01],
       [-4.13248416e+00, -1.89733836e+00, -1.60803193e-01,
         4.70017666e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01, -3.84721798e-01,  1.13559266e-01],
       [-4.94008613e+00,  1.97020284e+00,  7.00742684e-02,
        -5.04409203e-02,  0.00000000e+00,  0.00000000e+00,
        -3.68462939e-01, -3.35151969e-01,  1.13559266e-01],
       [ 4.76751433e+00,  1.38394169e+00,  1.27291407e-01,
        -2.46472679e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01,  1.43894375e+00, -1.30585920e-01],
       [-5.19412194e+00,  9.21530487e-01, -5.74065812e-02,
        -5.04409203e-02,  0.00000000e+00,  0.00000000e+00,
        -3.68462939e-01, -3.44841380e-01,  1.13559266e-01],
       [-4.41536294e+00, -1.58787129e+00, -1.60803193e-01,
         4.70017666e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01, -3.84721798e-01, -1.36875771e-01],
       [-4.81358717e+00, -8.71694393e-01, -1.60803193e-01,
         4.70017666e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01, -3.83302972e-01,  1.13559266e-01],
       [-4.81358717e+00, -1.02284420e+00, -1.60803193e-01,
         4.70017666e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01, -3.83302972e-01,  1.13559266e-01],
       [ 4.60306368e+00,  1.51015618e+00,  1.27291407e-01,
        -2.46472679e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01,  1.43894375e+00,  1.08201245e-01],
       [ 6.44204529e+00, -3.81993851e-01,  9.22989947e-02,
        -2.46472679e-02,  0.00000000e+00,  0.00000000e+00,
        -3.68462939e-01, -1.01846663e-01, -1.30585920e-01],
       [ 4.78287797e+00, -2.68434791e-01,  1.27291407e-01,
        -2.46472679e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01,  1.58661936e+00, -1.30585920e-01],
       [-7.61286116e-01, -3.12572213e+00, -1.47042132e-01,
         4.70017666e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01, -3.31137460e-01, -1.91024542e-01],
       [ 4.49503812e+00,  1.58729784e+00,  1.60711127e-01,
         1.54761915e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01,  1.39628446e+00,  1.08201245e-01],
       [ 4.49106620e+00,  1.63139322e+00,  1.60711127e-01,
        -2.46472679e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01,  1.39628446e+00,  1.08201245e-01],
       [-5.11681622e+00,  1.45285302e+00, -5.74065812e-02,
        -5.04409203e-02,  0.00000000e+00,  0.00000000e+00,
        -3.68462939e-01, -3.35151969e-01,  1.13559266e-01],
       [ 4.96807515e+00,  1.03570520e+00,  1.27291407e-01,
        -2.46472679e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01,  1.58661936e+00, -1.30585920e-01],
       [-3.27354855e+00, -2.53607305e+00, -1.60803193e-01,
         4.70017666e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01, -3.71910748e-01, -1.91024542e-01],
       [-4.01430746e+00, -1.98892677e+00, -1.60803193e-01,
         4.70017666e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01, -3.84721798e-01, -1.36875771e-01],
       [ 6.08042636e+00,  1.30955449e+00,  9.22989947e-02,
        -2.46472679e-02,  0.00000000e+00,  0.00000000e+00,
        -3.68462939e-01, -9.35887119e-02, -1.30585920e-01],
       [ 4.93718394e+00, -3.09827117e-01,  1.27291407e-01,
        -2.46472679e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01,  1.58661936e+00,  1.08201245e-01],
       [-5.17991076e+00,  1.21827941e+00, -5.74065812e-02,
         4.70017666e-02,  0.00000000e+00,  0.00000000e+00,
        -3.68462939e-01, -3.35151969e-01,  1.13559266e-01],
       [ 3.65377114e+00,  2.45184202e+00,  9.22989947e-02,
         4.70017666e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01, -1.42782735e-01,  1.59685997e-01],
       [-5.07603471e+00,  3.36381122e-02,  7.00742684e-02,
        -5.04409203e-02,  0.00000000e+00,  0.00000000e+00,
        -3.68462939e-01, -3.83302972e-01, -1.36875771e-01],
       [-4.53595532e+00, -1.58787129e+00,  7.00742684e-02,
        -5.04409203e-02,  0.00000000e+00,  0.00000000e+00,
        -3.68462939e-01, -3.84721798e-01, -1.36875771e-01],
       [-3.45512069e+00, -2.34506634e+00, -1.60803193e-01,
         4.70017666e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01, -3.81345309e-01, -1.91024542e-01],
       [-5.12608560e+00,  9.08194313e-01,  4.39282524e-02,
        -5.04409203e-02,  0.00000000e+00,  0.00000000e+00,
        -3.68462939e-01, -3.46807246e-01, -1.36875771e-01],
       [ 5.71326724e+00,  1.66361659e+00,  1.75061948e-01,
        -2.46472679e-02,  0.00000000e+00,  0.00000000e+00,
        -3.68462939e-01, -8.35388462e-02,  1.08201245e-01],
       [-4.10589588e+00, -1.89733836e+00, -1.60803193e-01,
         4.70017666e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01, -3.84721798e-01, -1.36875771e-01],
       [ 6.58941987e+00,  8.37323978e-01,  1.75061948e-01,
        -2.46472679e-02,  0.00000000e+00,  0.00000000e+00,
        -3.68462939e-01, -9.51627192e-02, -1.30585920e-01],
       [-4.13248416e+00, -1.89733836e+00, -1.60803193e-01,
         4.70017666e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01, -3.84721798e-01,  1.13559266e-01],
       [-5.07603471e+00,  3.36381122e-02,  7.00742684e-02,
        -5.04409203e-02,  0.00000000e+00,  0.00000000e+00,
        -3.68462939e-01, -3.83302972e-01, -1.36875771e-01],
       [ 4.57339775e+00,  1.58729784e+00,  1.60711127e-01,
        -2.46472679e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01,  1.39628446e+00, -1.30585920e-01],
       [ 5.53214335e+00,  1.84474047e+00,  1.75061948e-01,
        -2.46472679e-02,  0.00000000e+00,  0.00000000e+00,
        -3.68462939e-01, -8.35388462e-02,  1.08201245e-01],
       [ 6.66092991e+00,  7.65813945e-01,  1.75061948e-01,
        -2.46472679e-02,  0.00000000e+00,  0.00000000e+00,
        -3.68462939e-01, -9.51627192e-02, -1.30585920e-01]])

It looks like it’s returned our outputs just as an array instead of a SHAP explanation object. Run the code below to turn our object into a proper shap.Explanation() object, as this is what all the plotting functions will be expecting.

# Create an Explanation object
shap_values = shap.Explanation(
    values=shap_values,
    base_values=explainer.expected_value,
    data=X_test.values,
    feature_names=X.columns
    )
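To make the numbers in the array a little less abstract: for an XGBoost classifier, SHAP values from a tree explainer are usually expressed in log-odds units, and the magnitudes well above 1 in the output above are consistent with that. Under that assumption, the base value plus the sum of one row gives the model's log-odds for that penguin, which the logistic function turns into a probability. The sketch below uses the first row printed earlier; the `base_value` is a made-up placeholder, because the real one lives in `explainer.expected_value` and is not shown here.

```python
import numpy as np

# First row of the SHAP values printed earlier (one value per feature)
row = np.array([-5.17051824e+00, 1.21827941e+00, -5.74065812e-02,
                -5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
                -3.68462939e-01, -3.35151969e-01, -1.36875771e-01])

# HYPOTHETICAL placeholder: substitute explainer.expected_value in practice
base_value = 0.0

# Assumption: SHAP values are in log-odds units, so contributions simply add up
log_odds = base_value + row.sum()

# Logistic function converts log-odds into a probability of the positive class
probability = 1.0 / (1.0 + np.exp(-log_odds))

print(f"Sum of SHAP values: {row.sum():.4f}")
print(f"Implied probability (with placeholder base value): {probability:.4f}")
```

The large negative contribution from the first feature dominates this row, pushing the prediction strongly away from the positive (Adelie) class.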

Now let’s see what this looks like instead.

shap_values
.values =
array([[-5.17051824e+00,  1.21827941e+00, -5.74065812e-02,
        -5.04409203e-02,  0.00000000e+00,  0.00000000e+00,
        -3.68462939e-01, -3.35151969e-01, -1.36875771e-01],
       [ 1.13075847e-01,  1.78565094e+00, -4.36455202e-02,
        -5.04409203e-02,  0.00000000e+00,  0.00000000e+00,
        -3.68462939e-01, -2.35543021e-01, -1.91024542e-01],
       [-2.72566813e-02, -3.30757442e+00, -2.22220667e-01,
         4.70017666e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01, -3.31137460e-01, -1.91024542e-01],
       [-5.10160202e+00,  1.83535931e+00,  7.00742684e-02,
        -5.04409203e-02,  0.00000000e+00,  0.00000000e+00,
        -3.68462939e-01, -3.35151969e-01,  1.13559266e-01],
       [-4.14983763e+00, -1.87998490e+00, -1.60803193e-01,
         4.70017666e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01, -3.84721798e-01,  1.13559266e-01],
       [-4.44195123e+00, -1.58787129e+00, -1.60803193e-01,
         4.70017666e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01, -3.84721798e-01,  1.13559266e-01],
       [-4.44195123e+00, -1.58787129e+00, -1.60803193e-01,
         4.70017666e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01, -3.84721798e-01,  1.13559266e-01],
       [-4.01430746e+00, -1.98892677e+00, -1.60803193e-01,
         4.70017666e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01, -3.84721798e-01, -1.36875771e-01],
       [ 5.45284030e+00,  1.74075824e+00, -3.02373941e-01,
         1.54761915e-02,  0.00000000e+00,  0.00000000e+00,
        -3.68462939e-01, -8.35388462e-02,  1.08201245e-01],
       [ 1.91534915e+00, -3.58559378e+00, -3.02373941e-01,
         4.70017666e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01, -2.18033250e-01, -1.91024542e-01],
       [-4.44195123e+00, -1.58787129e+00, -1.60803193e-01,
         4.70017666e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01, -3.84721798e-01,  1.13559266e-01],
       [ 4.68026578e+00,  1.44937979e+00,  1.60711127e-01,
        -2.46472679e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01,  1.42733446e+00, -1.30585920e-01],
       [ 6.04219020e+00,  1.30955449e+00,  9.22989947e-02,
        -2.46472679e-02,  0.00000000e+00,  0.00000000e+00,
        -3.68462939e-01, -9.35887119e-02,  1.08201245e-01],
       [ 4.79211754e+00, -1.02131687e-03,  1.60711127e-01,
        -2.46472679e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01,  1.54396007e+00, -1.30585920e-01],
       [-4.13248416e+00, -1.89733836e+00, -1.60803193e-01,
         4.70017666e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01, -3.84721798e-01,  1.13559266e-01],
       [ 6.44204529e+00, -5.33143660e-01,  9.22989947e-02,
        -2.46472679e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01, -1.01846663e-01, -1.30585920e-01],
       [ 5.71326724e+00,  1.66361659e+00,  1.75061948e-01,
        -2.46472679e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01, -8.35388462e-02,  1.08201245e-01],
       [-4.44524486e-02, -3.30757442e+00, -2.22220667e-01,
        -5.04409203e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01, -3.31137460e-01, -1.91024542e-01],
       [-4.04089575e+00, -1.98892677e+00, -1.60803193e-01,
         4.70017666e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01, -3.84721798e-01,  1.13559266e-01],
       [ 4.97731472e+00,  1.03570520e+00,  1.60711127e-01,
        -2.46472679e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01,  1.54396007e+00, -1.30585920e-01],
       [ 3.33277721e+00,  2.57971638e+00,  1.75061948e-01,
         4.70017666e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01, -1.46053878e-01,  1.59685997e-01],
       [ 5.91554046e+00,  1.47107445e+00,  9.22989947e-02,
        -2.46472679e-02,  0.00000000e+00,  0.00000000e+00,
        -3.68462939e-01, -9.02227896e-02, -1.30585920e-01],
       [-4.10589588e+00, -1.89733836e+00, -1.60803193e-01,
         4.70017666e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01, -3.84721798e-01, -1.36875771e-01],
       [ 6.58941987e+00,  8.37323978e-01,  1.75061948e-01,
        -2.46472679e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01, -9.51627192e-02, -1.30585920e-01],
       [ 2.60384000e+00, -4.47425562e-01,  9.22989947e-02,
        -5.04409203e-02,  0.00000000e+00,  0.00000000e+00,
        -3.68462939e-01, -2.50891294e-01, -1.91024542e-01],
       [-5.13327016e+00, -1.40029112e-01,  7.00742684e-02,
        -5.04409203e-02,  0.00000000e+00,  0.00000000e+00,
        -3.68462939e-01, -3.83302972e-01,  1.13559266e-01],
       [ 5.68476218e+00,  1.66361659e+00,  9.22989947e-02,
        -2.46472679e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01, -9.02227896e-02,  1.08201245e-01],
       [-5.16359000e+00,  9.08194313e-01,  7.00742684e-02,
         4.70017666e-02,  0.00000000e+00,  0.00000000e+00,
        -3.68462939e-01, -3.44841380e-01,  1.13559266e-01],
       [-3.45011731e+00,  2.42737406e+00,  7.00742684e-02,
         4.70017666e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01, -3.18269527e-01,  1.59685997e-01],
       [-3.79519642e+00, -2.13501758e+00,  7.00742684e-02,
        -5.04409203e-02,  0.00000000e+00,  0.00000000e+00,
        -3.68462939e-01, -3.71910748e-01, -1.91024542e-01],
       [ 6.07069525e+00,  1.30955449e+00,  1.75061948e-01,
        -2.46472679e-02,  0.00000000e+00,  0.00000000e+00,
        -3.68462939e-01, -8.69047685e-02,  1.08201245e-01],
       [ 4.36485172e+00,  1.75760771e+00,  1.60711127e-01,
        -2.46472679e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01,  1.39628446e+00,  1.08201245e-01],
       [-5.06125901e+00,  4.69742856e-02,  4.39282524e-02,
        -5.04409203e-02,  0.00000000e+00,  0.00000000e+00,
        -3.68462939e-01, -3.85268838e-01, -1.36875771e-01],
       [ 2.94871511e+00,  1.16633540e+00,  1.27291407e-01,
        -5.04409203e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01,  2.49448815e+00, -1.91024542e-01],
       [ 5.94404552e+00,  1.47107445e+00,  1.75061948e-01,
        -2.46472679e-02,  0.00000000e+00,  0.00000000e+00,
        -3.68462939e-01, -8.35388462e-02, -1.30585920e-01],
       [ 3.45756491e+00,  6.67898319e-01,  1.75061948e-01,
        -5.04409203e-02,  0.00000000e+00,  0.00000000e+00,
        -3.68462939e-01, -2.16614424e-01,  1.59685997e-01],
       [-4.13248416e+00, -1.89733836e+00, -1.60803193e-01,
         4.70017666e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01, -3.84721798e-01,  1.13559266e-01],
       [-4.94008613e+00,  1.97020284e+00,  7.00742684e-02,
        -5.04409203e-02,  0.00000000e+00,  0.00000000e+00,
        -3.68462939e-01, -3.35151969e-01,  1.13559266e-01],
       [ 4.76751433e+00,  1.38394169e+00,  1.27291407e-01,
        -2.46472679e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01,  1.43894375e+00, -1.30585920e-01],
       [-5.19412194e+00,  9.21530487e-01, -5.74065812e-02,
        -5.04409203e-02,  0.00000000e+00,  0.00000000e+00,
        -3.68462939e-01, -3.44841380e-01,  1.13559266e-01],
       [-4.41536294e+00, -1.58787129e+00, -1.60803193e-01,
         4.70017666e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01, -3.84721798e-01, -1.36875771e-01],
       [-4.81358717e+00, -8.71694393e-01, -1.60803193e-01,
         4.70017666e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01, -3.83302972e-01,  1.13559266e-01],
       [-4.81358717e+00, -1.02284420e+00, -1.60803193e-01,
         4.70017666e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01, -3.83302972e-01,  1.13559266e-01],
       [ 4.60306368e+00,  1.51015618e+00,  1.27291407e-01,
        -2.46472679e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01,  1.43894375e+00,  1.08201245e-01],
       [ 6.44204529e+00, -3.81993851e-01,  9.22989947e-02,
        -2.46472679e-02,  0.00000000e+00,  0.00000000e+00,
        -3.68462939e-01, -1.01846663e-01, -1.30585920e-01],
       [ 4.78287797e+00, -2.68434791e-01,  1.27291407e-01,
        -2.46472679e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01,  1.58661936e+00, -1.30585920e-01],
       [-7.61286116e-01, -3.12572213e+00, -1.47042132e-01,
         4.70017666e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01, -3.31137460e-01, -1.91024542e-01],
       [ 4.49503812e+00,  1.58729784e+00,  1.60711127e-01,
         1.54761915e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01,  1.39628446e+00,  1.08201245e-01],
       [ 4.49106620e+00,  1.63139322e+00,  1.60711127e-01,
        -2.46472679e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01,  1.39628446e+00,  1.08201245e-01],
       [-5.11681622e+00,  1.45285302e+00, -5.74065812e-02,
        -5.04409203e-02,  0.00000000e+00,  0.00000000e+00,
        -3.68462939e-01, -3.35151969e-01,  1.13559266e-01],
       [ 4.96807515e+00,  1.03570520e+00,  1.27291407e-01,
        -2.46472679e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01,  1.58661936e+00, -1.30585920e-01],
       [-3.27354855e+00, -2.53607305e+00, -1.60803193e-01,
         4.70017666e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01, -3.71910748e-01, -1.91024542e-01],
       [-4.01430746e+00, -1.98892677e+00, -1.60803193e-01,
         4.70017666e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01, -3.84721798e-01, -1.36875771e-01],
       [ 6.08042636e+00,  1.30955449e+00,  9.22989947e-02,
        -2.46472679e-02,  0.00000000e+00,  0.00000000e+00,
        -3.68462939e-01, -9.35887119e-02, -1.30585920e-01],
       [ 4.93718394e+00, -3.09827117e-01,  1.27291407e-01,
        -2.46472679e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01,  1.58661936e+00,  1.08201245e-01],
       [-5.17991076e+00,  1.21827941e+00, -5.74065812e-02,
         4.70017666e-02,  0.00000000e+00,  0.00000000e+00,
        -3.68462939e-01, -3.35151969e-01,  1.13559266e-01],
       [ 3.65377114e+00,  2.45184202e+00,  9.22989947e-02,
         4.70017666e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01, -1.42782735e-01,  1.59685997e-01],
       [-5.07603471e+00,  3.36381122e-02,  7.00742684e-02,
        -5.04409203e-02,  0.00000000e+00,  0.00000000e+00,
        -3.68462939e-01, -3.83302972e-01, -1.36875771e-01],
       [-4.53595532e+00, -1.58787129e+00,  7.00742684e-02,
        -5.04409203e-02,  0.00000000e+00,  0.00000000e+00,
        -3.68462939e-01, -3.84721798e-01, -1.36875771e-01],
       [-3.45512069e+00, -2.34506634e+00, -1.60803193e-01,
         4.70017666e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01, -3.81345309e-01, -1.91024542e-01],
       [-5.12608560e+00,  9.08194313e-01,  4.39282524e-02,
        -5.04409203e-02,  0.00000000e+00,  0.00000000e+00,
        -3.68462939e-01, -3.46807246e-01, -1.36875771e-01],
       [ 5.71326724e+00,  1.66361659e+00,  1.75061948e-01,
        -2.46472679e-02,  0.00000000e+00,  0.00000000e+00,
        -3.68462939e-01, -8.35388462e-02,  1.08201245e-01],
       [-4.10589588e+00, -1.89733836e+00, -1.60803193e-01,
         4.70017666e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01, -3.84721798e-01, -1.36875771e-01],
       [ 6.58941987e+00,  8.37323978e-01,  1.75061948e-01,
        -2.46472679e-02,  0.00000000e+00,  0.00000000e+00,
        -3.68462939e-01, -9.51627192e-02, -1.30585920e-01],
       [-4.13248416e+00, -1.89733836e+00, -1.60803193e-01,
         4.70017666e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01, -3.84721798e-01,  1.13559266e-01],
       [-5.07603471e+00,  3.36381122e-02,  7.00742684e-02,
        -5.04409203e-02,  0.00000000e+00,  0.00000000e+00,
        -3.68462939e-01, -3.83302972e-01, -1.36875771e-01],
       [ 4.57339775e+00,  1.58729784e+00,  1.60711127e-01,
        -2.46472679e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01,  1.39628446e+00, -1.30585920e-01],
       [ 5.53214335e+00,  1.84474047e+00,  1.75061948e-01,
        -2.46472679e-02,  0.00000000e+00,  0.00000000e+00,
        -3.68462939e-01, -8.35388462e-02,  1.08201245e-01],
       [ 6.66092991e+00,  7.65813945e-01,  1.75061948e-01,
        -2.46472679e-02,  0.00000000e+00,  0.00000000e+00,
        -3.68462939e-01, -9.51627192e-02, -1.30585920e-01]])

.base_values =
-0.7156927563131259

.data =
array([[ 32.1,  15.5, 188. , ...,   1. ,   0. ,   0. ],
       [ 33.1,  16.1, 178. , ...,   1. ,   0. ,   0. ],
       [ 33.5,  19. , 190. , ...,   0. ,   1. ,   0. ],
       ...,
       [ 55.9,  17. , 228. , ...,   0. ,   0. ,   1. ],
       [ 59.6,  17. , 230. , ...,   0. ,   0. ,   1. ],
       [  nan,   nan,   nan, ...,   0. ,   0. ,   nan]])

Finally, let’s grab just the numeric component (our actual shap values).

shap_values_numeric = shap_values.values
shap_values_numeric
array([[-5.17051824e+00,  1.21827941e+00, -5.74065812e-02,
        -5.04409203e-02,  0.00000000e+00,  0.00000000e+00,
        -3.68462939e-01, -3.35151969e-01, -1.36875771e-01],
       [ 1.13075847e-01,  1.78565094e+00, -4.36455202e-02,
        -5.04409203e-02,  0.00000000e+00,  0.00000000e+00,
        -3.68462939e-01, -2.35543021e-01, -1.91024542e-01],
       [-2.72566813e-02, -3.30757442e+00, -2.22220667e-01,
         4.70017666e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01, -3.31137460e-01, -1.91024542e-01],
       [-5.10160202e+00,  1.83535931e+00,  7.00742684e-02,
        -5.04409203e-02,  0.00000000e+00,  0.00000000e+00,
        -3.68462939e-01, -3.35151969e-01,  1.13559266e-01],
       [-4.14983763e+00, -1.87998490e+00, -1.60803193e-01,
         4.70017666e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01, -3.84721798e-01,  1.13559266e-01],
       [-4.44195123e+00, -1.58787129e+00, -1.60803193e-01,
         4.70017666e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01, -3.84721798e-01,  1.13559266e-01],
       [-4.44195123e+00, -1.58787129e+00, -1.60803193e-01,
         4.70017666e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01, -3.84721798e-01,  1.13559266e-01],
       [-4.01430746e+00, -1.98892677e+00, -1.60803193e-01,
         4.70017666e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01, -3.84721798e-01, -1.36875771e-01],
       [ 5.45284030e+00,  1.74075824e+00, -3.02373941e-01,
         1.54761915e-02,  0.00000000e+00,  0.00000000e+00,
        -3.68462939e-01, -8.35388462e-02,  1.08201245e-01],
       [ 1.91534915e+00, -3.58559378e+00, -3.02373941e-01,
         4.70017666e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01, -2.18033250e-01, -1.91024542e-01],
       [-4.44195123e+00, -1.58787129e+00, -1.60803193e-01,
         4.70017666e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01, -3.84721798e-01,  1.13559266e-01],
       [ 4.68026578e+00,  1.44937979e+00,  1.60711127e-01,
        -2.46472679e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01,  1.42733446e+00, -1.30585920e-01],
       [ 6.04219020e+00,  1.30955449e+00,  9.22989947e-02,
        -2.46472679e-02,  0.00000000e+00,  0.00000000e+00,
        -3.68462939e-01, -9.35887119e-02,  1.08201245e-01],
       [ 4.79211754e+00, -1.02131687e-03,  1.60711127e-01,
        -2.46472679e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01,  1.54396007e+00, -1.30585920e-01],
       [-4.13248416e+00, -1.89733836e+00, -1.60803193e-01,
         4.70017666e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01, -3.84721798e-01,  1.13559266e-01],
       [ 6.44204529e+00, -5.33143660e-01,  9.22989947e-02,
        -2.46472679e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01, -1.01846663e-01, -1.30585920e-01],
       [ 5.71326724e+00,  1.66361659e+00,  1.75061948e-01,
        -2.46472679e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01, -8.35388462e-02,  1.08201245e-01],
       [-4.44524486e-02, -3.30757442e+00, -2.22220667e-01,
        -5.04409203e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01, -3.31137460e-01, -1.91024542e-01],
       [-4.04089575e+00, -1.98892677e+00, -1.60803193e-01,
         4.70017666e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01, -3.84721798e-01,  1.13559266e-01],
       [ 4.97731472e+00,  1.03570520e+00,  1.60711127e-01,
        -2.46472679e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01,  1.54396007e+00, -1.30585920e-01],
       [ 3.33277721e+00,  2.57971638e+00,  1.75061948e-01,
         4.70017666e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01, -1.46053878e-01,  1.59685997e-01],
       [ 5.91554046e+00,  1.47107445e+00,  9.22989947e-02,
        -2.46472679e-02,  0.00000000e+00,  0.00000000e+00,
        -3.68462939e-01, -9.02227896e-02, -1.30585920e-01],
       [-4.10589588e+00, -1.89733836e+00, -1.60803193e-01,
         4.70017666e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01, -3.84721798e-01, -1.36875771e-01],
       [ 6.58941987e+00,  8.37323978e-01,  1.75061948e-01,
        -2.46472679e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01, -9.51627192e-02, -1.30585920e-01],
       [ 2.60384000e+00, -4.47425562e-01,  9.22989947e-02,
        -5.04409203e-02,  0.00000000e+00,  0.00000000e+00,
        -3.68462939e-01, -2.50891294e-01, -1.91024542e-01],
       [-5.13327016e+00, -1.40029112e-01,  7.00742684e-02,
        -5.04409203e-02,  0.00000000e+00,  0.00000000e+00,
        -3.68462939e-01, -3.83302972e-01,  1.13559266e-01],
       [ 5.68476218e+00,  1.66361659e+00,  9.22989947e-02,
        -2.46472679e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01, -9.02227896e-02,  1.08201245e-01],
       [-5.16359000e+00,  9.08194313e-01,  7.00742684e-02,
         4.70017666e-02,  0.00000000e+00,  0.00000000e+00,
        -3.68462939e-01, -3.44841380e-01,  1.13559266e-01],
       [-3.45011731e+00,  2.42737406e+00,  7.00742684e-02,
         4.70017666e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01, -3.18269527e-01,  1.59685997e-01],
       [-3.79519642e+00, -2.13501758e+00,  7.00742684e-02,
        -5.04409203e-02,  0.00000000e+00,  0.00000000e+00,
        -3.68462939e-01, -3.71910748e-01, -1.91024542e-01],
       [ 6.07069525e+00,  1.30955449e+00,  1.75061948e-01,
        -2.46472679e-02,  0.00000000e+00,  0.00000000e+00,
        -3.68462939e-01, -8.69047685e-02,  1.08201245e-01],
       [ 4.36485172e+00,  1.75760771e+00,  1.60711127e-01,
        -2.46472679e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01,  1.39628446e+00,  1.08201245e-01],
       [-5.06125901e+00,  4.69742856e-02,  4.39282524e-02,
        -5.04409203e-02,  0.00000000e+00,  0.00000000e+00,
        -3.68462939e-01, -3.85268838e-01, -1.36875771e-01],
       [ 2.94871511e+00,  1.16633540e+00,  1.27291407e-01,
        -5.04409203e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01,  2.49448815e+00, -1.91024542e-01],
       [ 5.94404552e+00,  1.47107445e+00,  1.75061948e-01,
        -2.46472679e-02,  0.00000000e+00,  0.00000000e+00,
        -3.68462939e-01, -8.35388462e-02, -1.30585920e-01],
       [ 3.45756491e+00,  6.67898319e-01,  1.75061948e-01,
        -5.04409203e-02,  0.00000000e+00,  0.00000000e+00,
        -3.68462939e-01, -2.16614424e-01,  1.59685997e-01],
       [-4.13248416e+00, -1.89733836e+00, -1.60803193e-01,
         4.70017666e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01, -3.84721798e-01,  1.13559266e-01],
       [-4.94008613e+00,  1.97020284e+00,  7.00742684e-02,
        -5.04409203e-02,  0.00000000e+00,  0.00000000e+00,
        -3.68462939e-01, -3.35151969e-01,  1.13559266e-01],
       [ 4.76751433e+00,  1.38394169e+00,  1.27291407e-01,
        -2.46472679e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01,  1.43894375e+00, -1.30585920e-01],
       [-5.19412194e+00,  9.21530487e-01, -5.74065812e-02,
        -5.04409203e-02,  0.00000000e+00,  0.00000000e+00,
        -3.68462939e-01, -3.44841380e-01,  1.13559266e-01],
       [-4.41536294e+00, -1.58787129e+00, -1.60803193e-01,
         4.70017666e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01, -3.84721798e-01, -1.36875771e-01],
       [-4.81358717e+00, -8.71694393e-01, -1.60803193e-01,
         4.70017666e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01, -3.83302972e-01,  1.13559266e-01],
       [-4.81358717e+00, -1.02284420e+00, -1.60803193e-01,
         4.70017666e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01, -3.83302972e-01,  1.13559266e-01],
       [ 4.60306368e+00,  1.51015618e+00,  1.27291407e-01,
        -2.46472679e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01,  1.43894375e+00,  1.08201245e-01],
       [ 6.44204529e+00, -3.81993851e-01,  9.22989947e-02,
        -2.46472679e-02,  0.00000000e+00,  0.00000000e+00,
        -3.68462939e-01, -1.01846663e-01, -1.30585920e-01],
       [ 4.78287797e+00, -2.68434791e-01,  1.27291407e-01,
        -2.46472679e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01,  1.58661936e+00, -1.30585920e-01],
       [-7.61286116e-01, -3.12572213e+00, -1.47042132e-01,
         4.70017666e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01, -3.31137460e-01, -1.91024542e-01],
       [ 4.49503812e+00,  1.58729784e+00,  1.60711127e-01,
         1.54761915e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01,  1.39628446e+00,  1.08201245e-01],
       [ 4.49106620e+00,  1.63139322e+00,  1.60711127e-01,
        -2.46472679e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01,  1.39628446e+00,  1.08201245e-01],
       [-5.11681622e+00,  1.45285302e+00, -5.74065812e-02,
        -5.04409203e-02,  0.00000000e+00,  0.00000000e+00,
        -3.68462939e-01, -3.35151969e-01,  1.13559266e-01],
       [ 4.96807515e+00,  1.03570520e+00,  1.27291407e-01,
        -2.46472679e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01,  1.58661936e+00, -1.30585920e-01],
       [-3.27354855e+00, -2.53607305e+00, -1.60803193e-01,
         4.70017666e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01, -3.71910748e-01, -1.91024542e-01],
       [-4.01430746e+00, -1.98892677e+00, -1.60803193e-01,
         4.70017666e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01, -3.84721798e-01, -1.36875771e-01],
       [ 6.08042636e+00,  1.30955449e+00,  9.22989947e-02,
        -2.46472679e-02,  0.00000000e+00,  0.00000000e+00,
        -3.68462939e-01, -9.35887119e-02, -1.30585920e-01],
       [ 4.93718394e+00, -3.09827117e-01,  1.27291407e-01,
        -2.46472679e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01,  1.58661936e+00,  1.08201245e-01],
       [-5.17991076e+00,  1.21827941e+00, -5.74065812e-02,
         4.70017666e-02,  0.00000000e+00,  0.00000000e+00,
        -3.68462939e-01, -3.35151969e-01,  1.13559266e-01],
       [ 3.65377114e+00,  2.45184202e+00,  9.22989947e-02,
         4.70017666e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01, -1.42782735e-01,  1.59685997e-01],
       [-5.07603471e+00,  3.36381122e-02,  7.00742684e-02,
        -5.04409203e-02,  0.00000000e+00,  0.00000000e+00,
        -3.68462939e-01, -3.83302972e-01, -1.36875771e-01],
       [-4.53595532e+00, -1.58787129e+00,  7.00742684e-02,
        -5.04409203e-02,  0.00000000e+00,  0.00000000e+00,
        -3.68462939e-01, -3.84721798e-01, -1.36875771e-01],
       [-3.45512069e+00, -2.34506634e+00, -1.60803193e-01,
         4.70017666e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01, -3.81345309e-01, -1.91024542e-01],
       [-5.12608560e+00,  9.08194313e-01,  4.39282524e-02,
        -5.04409203e-02,  0.00000000e+00,  0.00000000e+00,
        -3.68462939e-01, -3.46807246e-01, -1.36875771e-01],
       [ 5.71326724e+00,  1.66361659e+00,  1.75061948e-01,
        -2.46472679e-02,  0.00000000e+00,  0.00000000e+00,
        -3.68462939e-01, -8.35388462e-02,  1.08201245e-01],
       [-4.10589588e+00, -1.89733836e+00, -1.60803193e-01,
         4.70017666e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01, -3.84721798e-01, -1.36875771e-01],
       [ 6.58941987e+00,  8.37323978e-01,  1.75061948e-01,
        -2.46472679e-02,  0.00000000e+00,  0.00000000e+00,
        -3.68462939e-01, -9.51627192e-02, -1.30585920e-01],
       [-4.13248416e+00, -1.89733836e+00, -1.60803193e-01,
         4.70017666e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01, -3.84721798e-01,  1.13559266e-01],
       [-5.07603471e+00,  3.36381122e-02,  7.00742684e-02,
        -5.04409203e-02,  0.00000000e+00,  0.00000000e+00,
        -3.68462939e-01, -3.83302972e-01, -1.36875771e-01],
       [ 4.57339775e+00,  1.58729784e+00,  1.60711127e-01,
        -2.46472679e-02,  0.00000000e+00,  0.00000000e+00,
         2.77963270e-01,  1.39628446e+00, -1.30585920e-01],
       [ 5.53214335e+00,  1.84474047e+00,  1.75061948e-01,
        -2.46472679e-02,  0.00000000e+00,  0.00000000e+00,
        -3.68462939e-01, -8.35388462e-02,  1.08201245e-01],
       [ 6.66092991e+00,  7.65813945e-01,  1.75061948e-01,
        -2.46472679e-02,  0.00000000e+00,  0.00000000e+00,
        -3.68462939e-01, -9.51627192e-02, -1.30585920e-01]])

18.3.2 Exploring the SHAP outputs

First, let’s just get a list of our most important features according to SHAP.

# get feature importance for comparison using MDI method
features = list(X_train)
feature_importances = model.feature_importances_
importances = pd.DataFrame(index=features)
importances['importance'] = feature_importances
importances['rank'] = importances['importance'].rank(ascending=False).values
importances.sort_values('rank').head()

# Get Shapley importances
# Calculate the mean Shapley value for each feature in the training set
importances['mean_shapley_values'] = np.mean(
    shap_values_numeric, axis=0
    )

# Calculate the mean absolute Shapley value for each feature in the training set
# This will give us the average importance of each feature
importances['mean_abs_shapley_values'] = np.mean(
    np.abs(shap_values_numeric), axis=0
    )

importance_top = \
    importances.sort_values(
        by='importance', ascending=False
        ).index

shapley_top = \
    importances.sort_values(
        by='mean_abs_shapley_values',
        ascending=False).index

# Add to DataFrame
top_features = pd.DataFrame()
top_features['importances'] = importance_top.values
top_features['Shapley'] = shapley_top.values

# Display
top_features
   importances        Shapley
0  bill_length_mm     bill_length_mm
1  island_Torgersen   bill_depth_mm
2  bill_depth_mm      island_Torgersen
3  flipper_length_mm  island_Dream
4  male               flipper_length_mm
5  island_Dream       male
6  body_mass_g        body_mass_g
7  year               year
8  island_Biscoe      island_Biscoe
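
The two SHAP columns calculated in the cell above answer different questions. The plain mean lets positive and negative contributions cancel each other out, while the mean absolute value measures how far a feature moves predictions in either direction. Here is a minimal sketch with a made-up array (`toy_shap` and its numbers are purely illustrative and are not taken from the model in this notebook):

```python
import numpy as np

# Hypothetical SHAP matrix: 4 observations (rows) x 2 features (columns).
# These numbers are invented purely for illustration.
toy_shap = np.array([
    [ 2.0, 0.5],
    [-2.0, 0.5],
    [ 2.0, 0.5],
    [-2.0, 0.5],
])

# Signed mean: the large positive and negative contributions of feature 0 cancel out
mean_shap = toy_shap.mean(axis=0)

# Mean absolute value: feature 0 now stands out as the more influential feature
mean_abs_shap = np.abs(toy_shap).mean(axis=0)

print(mean_shap)      # [0.  0.5]
print(mean_abs_shap)  # [2.  0.5]
```

This is why the ranking in the table is based on `mean_abs_shapley_values` rather than `mean_shapley_values`.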

18.3.2.1 SHAP plots

Generate a bar plot of the SHAP values.

shap.plots.bar(shap_values)

Generate a beeswarm plot.

shap.plots.beeswarm(shap_values)
DimensionError: Feature and SHAP matrices must have the same number of rows!
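
The error message says that the SHAP matrix and the feature matrix have different numbers of rows. One quick way to confirm this kind of mismatch is to compare the first dimension of the two arrays. The sketch below uses small stand-in arrays (`values` and `data` are hypothetical; in this notebook the equivalents would be `shap_values.values` and `shap_values.data`):

```python
import numpy as np

# Stand-in arrays with deliberately different row counts (illustrative only)
values = np.zeros((5, 3))  # plays the role of the SHAP matrix
data = np.zeros((6, 3))    # plays the role of the feature matrix

# The plot needs one row of features for every row of SHAP values
rows_match = values.shape[0] == data.shape[0]
print(values.shape, data.shape, rows_match)
```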

Generate a waterfall plot for 5 different examples from the dataset.

shap.plots.waterfall(shap_values[0])
shap.plots.waterfall(shap_values[3])
shap.plots.waterfall(shap_values[5])
shap.plots.waterfall(shap_values[-1])
shap.plots.waterfall(shap_values[194])
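
A waterfall plot starts at the base value and adds each feature's SHAP value in turn to arrive at the model output for that observation. The arithmetic can be checked by hand using the first row of `shap_values_numeric` and the `.base_values` figure printed earlier. The conversion to a probability at the end assumes the explanation is in log-odds units; that is an assumption about how the explainer was set up, not something shown in the output above.

```python
import math

# Base value and first row of SHAP values, copied from the output above
base_value = -0.7156927563131259
row_0 = [-5.17051824e+00, 1.21827941e+00, -5.74065812e-02,
         -5.04409203e-02, 0.0, 0.0,
         -3.68462939e-01, -3.35151969e-01, -1.36875771e-01]

# Additivity: base value plus the sum of the SHAP values gives the model output
model_output = base_value + sum(row_0)

# ASSUMPTION: the output is in log-odds, so the sigmoid maps it to a probability
probability = 1 / (1 + math.exp(-model_output))

print(round(model_output, 4))
print(round(probability, 4))
```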

18.3.2.2 Dependence Plots for each Class (Species)

Let’s look at the columns in our dataset and the indices.

# Let's see the features and their respective index numbers
for idx, col_name in enumerate(X_test.columns):
    print(f"{idx} - {col_name}")

First, generate a scatter plot for the bill length.

shap.plots.scatter(shap_values[:, 'bill_length_mm'])

Now colour this by bill depth.

shap.plots.scatter(
    shap_values[:, 'bill_length_mm'],
    color=shap_values[:, 'bill_depth_mm']
    )

Now colour it by the most strongly interacting feature.

# create a dependence scatter plot to show the effect of a single feature across the whole dataset
shap.plots.scatter(shap_values[:, "bill_length_mm"], color=shap_values)

Now let’s iterate through and create scatter plots per column.

# dependence plots
fig, ax = plt.subplots(3, 3, figsize=(20,10))
ax = ax.ravel()

for idx, col_name in enumerate(feature_names):
    shap.plots.scatter(shap_values[:, col_name], show=False, ax=ax[idx])

18.3.2.3 Force Plots

Create a force plot for the whole dataset.

shap.plots.force(shap_values)

Create a force plot for five randomly chosen pieces of data.

shap.plots.force(shap_values[0])
shap.plots.force(shap_values[1])
shap.plots.force(shap_values[-1])
shap.plots.force(shap_values[185])
shap.plots.force(shap_values[247])
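
The indices above are typed in by hand. To pick observations at random instead, one option is NumPy's random generator. In this sketch `n_rows` is a placeholder for the number of rows that were explained (for example `shap_values.values.shape[0]`), and the seed is arbitrary:

```python
import numpy as np

# Placeholder: replace with the number of explained rows in your own data
n_rows = 100

# Seeded generator so that the same rows are picked on every run
rng = np.random.default_rng(seed=42)

# Draw five distinct row indices
random_rows = rng.choice(n_rows, size=5, replace=False)
print(sorted(random_rows.tolist()))
```

Each index could then be passed to the force plot in the same way as the cells above, e.g. `shap.plots.force(shap_values[int(i)])` for each `i` in `random_rows`.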