from sklearn import datasets
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from xgboost import XGBClassifier
from sklearn import metrics
from sklearn.inspection import permutation_importance
from sklearn.inspection import PartialDependenceDisplay
import shap
# JavaScript for SHAP plots
shap.initjs()
# Helper function to see methods in object
# Might be useful when working through this exercise
def object_methods(obj):
    '''
    Helper function to list methods associated with an object
    '''
    try:
        methods = [method_name for method_name in dir(obj)
                   if callable(getattr(obj, method_name))]
        print('Below are the methods for object: ', obj)
        for method in methods:
            print(method)
    except Exception:
        print("Error")
18 Exercise Solution: Explainable AI (Penguins Classification Dataset)
In this notebook, we’ll explore how to use several explainable AI techniques.
We’ll be using a different dataset today: the penguins dataset, which is a great dataset for practising classification problems. The data has been pulled using the excellent palmerpenguins package.
What we need to know is that
- The dataset is made up of 344 rows/ instances
- Each row has columns pertaining to sex, species, island on which they are found, bill length, bill depth, flipper length, and body mass.
- There are three species (classes) to consider: Adelie, Chinstrap and Gentoo.
- These are the targets.
In this exercise you will need to go through the code and fill in any missing spaces.
By the end of this exercise you should know:
- how to calculate feature importance using the MDI method for tree-based models
- how to calculate feature importance for any model using the permutation feature importance method
- how to create partial dependence plots (PDPs) and individual conditional expectation (ICE) plots for any model
- how to use the SHAP library to understand a model
The SHAP code does vary subtly for different kinds of model; we will just be working with an XGBoost model in this case to match what we’ve done in the lecture.
18.0.0.1 Library Imports
18.0.0.2 Load & Clean Data
Run this cell to load the dataframe.
penguins = pd.read_csv("../datasets/penguins.csv")
Examine the dataset with your choice(s) of function(s).
penguins.head()
bill_length_mm | bill_depth_mm | flipper_length_mm | body_mass_g | year | island_Biscoe | island_Dream | island_Torgersen | male | target | |
---|---|---|---|---|---|---|---|---|---|---|
0 | 32.1 | 15.5 | 188.0 | 3050.0 | 2009 | 0 | 1 | 0 | 0.0 | 0 |
1 | 33.1 | 16.1 | 178.0 | 2900.0 | 2008 | 0 | 1 | 0 | 0.0 | 0 |
2 | 33.5 | 19.0 | 190.0 | 3600.0 | 2008 | 0 | 0 | 1 | 0.0 | 0 |
3 | 34.0 | 17.1 | 185.0 | 3400.0 | 2008 | 0 | 1 | 0 | 0.0 | 0 |
4 | 34.1 | 18.1 | 193.0 | 3475.0 | 2007 | 0 | 0 | 1 | NaN | 0 |
Run the code below to convert the classes 0, 1 and 2 into the relevant species names and add this as a new column.
Try to understand how this is working - it’s a useful little pattern to know for your own datasets!
First, we are going to create a dictionary. Can you remember what we call the parts before and the parts after the colon in the dictionary?
# Define the different classes/ species
class_dict = {0 : 'Adelie',
              1 : 'Chinstrap',
              2 : 'Gentoo'}
class_dict
{0: 'Adelie', 1: 'Chinstrap', 2: 'Gentoo'}
Now we are going to use our dictionary for creating the column.
How do you think this is working? You may want to look up the get method of the standard Python dictionary and the apply method of pandas, just to understand a little more about this very useful way of making new conditional columns.
# Add species into the dataframe
penguins['species'] = penguins['target'].apply(lambda x: class_dict.get(x))

# view a random sample of 10 rows
penguins.sample(10)
bill_length_mm | bill_depth_mm | flipper_length_mm | body_mass_g | year | island_Biscoe | island_Dream | island_Torgersen | male | target | species | |
---|---|---|---|---|---|---|---|---|---|---|---|
197 | 50.7 | 19.7 | 203.0 | 4050.0 | 2009 | 0 | 1 | 0 | 1.0 | 1 | Chinstrap |
269 | 46.5 | 13.5 | 210.0 | 4550.0 | 2007 | 1 | 0 | 0 | 0.0 | 2 | Gentoo |
151 | NaN | NaN | NaN | NaN | 2007 | 0 | 0 | 1 | NaN | 0 | Adelie |
169 | 46.4 | 17.8 | 191.0 | 3700.0 | 2008 | 0 | 1 | 0 | 0.0 | 1 | Chinstrap |
12 | 35.2 | 15.9 | 186.0 | 3050.0 | 2009 | 0 | 0 | 1 | 0.0 | 0 | Adelie |
245 | 45.1 | 14.5 | 215.0 | 5000.0 | 2007 | 1 | 0 | 0 | 0.0 | 2 | Gentoo |
307 | 49.2 | 15.2 | 221.0 | 6300.0 | 2007 | 1 | 0 | 0 | 1.0 | 2 | Gentoo |
252 | 45.4 | 14.6 | 211.0 | 4800.0 | 2007 | 1 | 0 | 0 | 0.0 | 2 | Gentoo |
68 | 38.3 | 19.2 | 189.0 | 3950.0 | 2008 | 0 | 1 | 0 | 1.0 | 0 | Adelie |
207 | 51.5 | 18.7 | 187.0 | 3250.0 | 2009 | 0 | 1 | 0 | 1.0 | 1 | Chinstrap |
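The dictionary lookup pattern used for the species column can be tried on a tiny frame. The sketch below uses made-up rows (the `toy` frame and its values are illustrative, not taken from the penguins data) and also shows `Series.map`, which should give the same result for a plain dictionary lookup.

```python
import pandas as pd

# Same mapping as in the exercise
class_dict = {0: 'Adelie', 1: 'Chinstrap', 2: 'Gentoo'}

# Hypothetical toy frame, for illustration only
toy = pd.DataFrame({'target': [0, 2, 1, 0]})

# apply() calls the lambda once per value; dict.get() returns the matching
# species name, or None if the key is missing from the dictionary
toy['species'] = toy['target'].apply(lambda x: class_dict.get(x))

# Series.map accepts the dictionary directly and performs the same lookup
toy['species_via_map'] = toy['target'].map(class_dict)

print(toy)
```

Either form works here; `map` is a little shorter, while `apply` with a lambda leaves room for extra logic inside the function.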
Let’s take a look at some stats about the data to get an idea of the scale and distribution of the different features. Run the cell below to do this.
# Take a look at some stats about the data
penguins.describe()
bill_length_mm | bill_depth_mm | flipper_length_mm | body_mass_g | year | island_Biscoe | island_Dream | island_Torgersen | male | target | |
---|---|---|---|---|---|---|---|---|---|---|
count | 342.000000 | 342.000000 | 342.000000 | 342.000000 | 344.000000 | 344.000000 | 344.000000 | 344.000000 | 333.000000 | 344.000000 |
mean | 43.921930 | 17.151170 | 200.915205 | 4201.754386 | 2008.029070 | 0.488372 | 0.360465 | 0.151163 | 0.504505 | 0.918605 |
std | 5.459584 | 1.974793 | 14.061714 | 801.954536 | 0.818356 | 0.500593 | 0.480835 | 0.358729 | 0.500732 | 0.893320 |
min | 32.100000 | 13.100000 | 172.000000 | 2700.000000 | 2007.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
25% | 39.225000 | 15.600000 | 190.000000 | 3550.000000 | 2007.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
50% | 44.450000 | 17.300000 | 197.000000 | 4050.000000 | 2008.000000 | 0.000000 | 0.000000 | 0.000000 | 1.000000 | 1.000000 |
75% | 48.500000 | 18.700000 | 213.000000 | 4750.000000 | 2009.000000 | 1.000000 | 1.000000 | 0.000000 | 1.000000 | 2.000000 |
max | 59.600000 | 21.500000 | 231.000000 | 6300.000000 | 2009.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 2.000000 |
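The count row is lower than 344 for some columns (342 for the four measurements, 333 for male), which suggests missing values. A minimal sketch of counting them directly, using a made-up three-row frame rather than the real data:

```python
import numpy as np
import pandas as pd

# Hypothetical toy rows; the real frame is `penguins` in the notebook
toy = pd.DataFrame({
    'bill_length_mm': [39.1, np.nan, 46.5],
    'male': [1.0, np.nan, np.nan],
    'year': [2007, 2008, 2009],
})

# isna() flags missing cells; summing the flags counts them per column
missing_per_column = toy.isna().sum()
print(missing_per_column)
```

On the real frame the equivalent call would be `penguins.isna().sum()`. XGBoost can handle missing values natively, which is presumably why the rows with NaN are not dropped in this exercise.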
18.0.0.3 Plot the Data
Before we go any further, let’s plot the penguins dataset to see how bill length and bill depth relate to the species.
Fill in the gaps below to create the plot
adelie = penguins[penguins.species=='Adelie']
chinstrap = penguins[penguins.species == "Chinstrap"]
gentoo = penguins[penguins.species=='Gentoo']

fig, ax = plt.subplots()
fig.set_size_inches(13, 7) # adjusting the length and width of plot

# labels and scatter points
ax.scatter(adelie['bill_length_mm'], adelie['bill_depth_mm'], label="Adelie", facecolor="blue")
ax.scatter(chinstrap['bill_length_mm'], chinstrap['bill_depth_mm'], label="Chinstrap", facecolor="green")
ax.scatter(gentoo['bill_length_mm'], gentoo['bill_depth_mm'], label="Gentoo", facecolor="red")

ax.set_xlabel("Bill Length (mm)")
ax.set_ylabel("Bill Depth (mm)")
ax.grid()
ax.set_title("Penguin Measurements")
ax.legend()
Now it’s your turn: create this plot again, but this time we are interested in the other two measurements.
In the space below, make a copy of the plot that looks at flipper length and body mass.
adelie = penguins[penguins.species=='Adelie']
chinstrap = penguins[penguins.species == "Chinstrap"]
gentoo = penguins[penguins.species=='Gentoo']

fig, ax = plt.subplots()
fig.set_size_inches(13, 7) # adjusting the length and width of plot

# labels and scatter points
ax.scatter(adelie['flipper_length_mm'], adelie['body_mass_g'], label="Adelie", facecolor="blue")
ax.scatter(chinstrap['flipper_length_mm'], chinstrap['body_mass_g'], label="Chinstrap", facecolor="green")
ax.scatter(gentoo['flipper_length_mm'], gentoo['body_mass_g'], label="Gentoo", facecolor="red")

ax.set_xlabel("Flipper Length (mm)")
ax.set_ylabel("Body Mass (g)")
ax.grid()
ax.set_title("Penguin Measurements")
ax.legend()
18.0.0.4 Performing Classification
Time to get on to the machine learning aspect.
It’s always important to spend some time understanding your dataset first, though! What we’ve done above is just the tip of the iceberg, but it’s a good start.
Now we’re going to process our dataset for machine learning as we did in the logistic regression, decision tree and boosted tree sessions - this time we’re going to be using XGBoost.
# Dropping the target and species since we only need the measurements
X = penguins.drop(['target','species'], axis=1)

# Define features (X) and target (y)
X = X
y = penguins['target']

# get class and features names
class_names = penguins.species.unique()
feature_names = X.columns

# Splitting into train and test
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2,
                                                    random_state=42)

# Instantiate an XGBoost model and fit it
model = XGBClassifier(random_state=42)

model.fit(X_train, y_train)
XGBClassifier(base_score=None, booster=None, callbacks=None, colsample_bylevel=None, colsample_bynode=None, colsample_bytree=None, device=None, early_stopping_rounds=None, enable_categorical=False, eval_metric=None, feature_types=None, gamma=None, grow_policy=None, importance_type=None, interaction_constraints=None, learning_rate=None, max_bin=None, max_cat_threshold=None, max_cat_to_onehot=None, max_delta_step=None, max_depth=None, max_leaves=None, min_child_weight=None, missing=nan, monotone_constraints=None, multi_strategy=None, n_estimators=None, n_jobs=None, num_parallel_tree=None, objective='multi:softprob', ...)
18.0.1 Feature Importance
18.0.1.0.1 Look at feature importance using the feature_importances_ attribute
# Looking at standard feature importance
# This attribute is the mean decrease in impurity for each feature
importances = model.feature_importances_
importances
array([0.10309331, 0.01894252, 0.5356301 , 0.00305671, 0.00406754,
0.07005206, 0.1871977 , 0.07796 , 0. ], dtype=float32)
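The raw array is easier to read once each value is paired with its feature name. The sketch below copies the printed values and assumes the feature order matches the columns of X as they appear in the tables earlier in this notebook.

```python
import pandas as pd

# Assumed feature order: the columns of X as shown in the earlier tables
feature_names = ['bill_length_mm', 'bill_depth_mm', 'flipper_length_mm',
                 'body_mass_g', 'year', 'island_Biscoe', 'island_Dream',
                 'island_Torgersen', 'male']

# Values copied from the printed feature_importances_ output
importances = [0.10309331, 0.01894252, 0.5356301, 0.00305671, 0.00406754,
               0.07005206, 0.1871977, 0.07796, 0.0]

# Pair names with values and sort from most to least important
ranked = pd.Series(importances, index=feature_names).sort_values(ascending=False)
print(ranked)
print('Sum of importances:', round(ranked.sum(), 4))
```

Read this way, flipper length dominates and the values sum to roughly 1. Note that in XGBoost this attribute is governed by `importance_type`, which for tree boosters defaults to a gain-based measure, so "mean decrease in impurity" is best read as a close analogy here rather than an exact description.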
18.0.1.1 Mean decrease in impurity
Generate a plot of the MDI feature importances.
feature_names = X.columns.tolist()

model_importances_mdi_series = pd.Series(importances, index=feature_names)

fig, ax = plt.subplots(figsize=(15,10))
model_importances_mdi_series.plot.bar(ax=ax)
ax.set_title("Feature importances using MDI")
ax.set_ylabel("Mean decrease in impurity")
fig.tight_layout()
18.0.2 Permutation Feature Importance
Calculate permutation feature importance for this dataset and plot it.
This will include error bars.
# Permutation feature importance
result = permutation_importance(
    model, X_test, y_test, n_repeats=10, random_state=42)

model_importances_pfi_series = pd.Series(result.importances_mean, index=feature_names)

fig, ax = plt.subplots(figsize=(15,10))
model_importances_pfi_series.plot.bar(yerr=result.importances_std, ax=ax)
ax.set_title("Feature importances using permutation on full model")
ax.set_ylabel("Mean accuracy decrease")

fig.tight_layout()
plt.show()
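To see what permutation importance is doing under the hood, here is a minimal hand-rolled sketch. It is not scikit-learn's implementation: the function `manual_permutation_importance`, the synthetic data and the decision tree stand-in model are all assumptions made for illustration (the notebook itself uses XGBClassifier on the penguins data).

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier

def manual_permutation_importance(model, X, y, n_repeats=10, seed=42):
    """Mean and std of the accuracy drop when each column is shuffled."""
    rng = np.random.default_rng(seed)
    baseline = np.mean(model.predict(X) == y)
    drops = np.zeros((X.shape[1], n_repeats))
    for col in range(X.shape[1]):
        for rep in range(n_repeats):
            X_shuffled = X.copy()
            # Shuffling one column breaks its link with the target
            X_shuffled[:, col] = rng.permutation(X_shuffled[:, col])
            shuffled_accuracy = np.mean(model.predict(X_shuffled) == y)
            drops[col, rep] = baseline - shuffled_accuracy
    return drops.mean(axis=1), drops.std(axis=1)

# Synthetic data: column 0 determines the label, column 1 is pure noise
data_rng = np.random.default_rng(0)
X_train = data_rng.normal(size=(500, 2))
y_train = (X_train[:, 0] > 0).astype(int)
X_test = data_rng.normal(size=(200, 2))
y_test = (X_test[:, 0] > 0).astype(int)

# Stand-in model for the sketch
toy_model = DecisionTreeClassifier(random_state=42).fit(X_train, y_train)

importances_mean, importances_std = manual_permutation_importance(
    toy_model, X_test, y_test)
print('Mean accuracy drop per feature:', importances_mean)
print('Std of accuracy drop per feature:', importances_std)
```

Shuffling the informative column should cost a lot of accuracy, while shuffling the noise column should cost nothing. The standard deviation across repeats is what the error bars in the notebook's plot represent.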
18.0.2.1 Predictions
Use the model to make predictions for the training and test set
# Training and test predictions
y_pred_train = model.predict(X_train)
y_pred_test = model.predict(X_test)
18.0.2.2 Assessing Performance
Run this code to generate metrics for the training and test performance of this model.
accuracy_train = np.mean(y_pred_train == y_train)
accuracy_test = np.mean(y_pred_test == y_test)

print(f'Accuracy of predicting training data = {accuracy_train:.2%}')
print(f'Accuracy of predicting test data = {accuracy_test:.2%}')
Accuracy of predicting training data = 100.00%
Accuracy of predicting test data = 97.10%
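Taking `np.mean` of a boolean comparison gives the fraction of True values, which is why the cell above yields accuracy. The sketch below works through the arithmetic under the assumption that all 344 rows go into the split: a test_size of 0.2 would then give 69 test rows, and 67 matches out of 69 would print as 97.10%. The labels are placeholders, not the model's real predictions.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# 344 placeholder rows, matching the size of the penguins dataset
rows = np.arange(344)
train_rows, test_rows = train_test_split(rows, test_size=0.2, random_state=42)
print(len(train_rows), len(test_rows))

# Illustrative labels: all but two of the test predictions match
y_true = np.zeros(len(test_rows), dtype=int)
y_pred = y_true.copy()
y_pred[:2] = 1  # two deliberate mismatches

# Mean of a boolean array = proportion of correct predictions
accuracy = np.mean(y_pred == y_true)
print(f'{accuracy:.2%}')
```

With a test set this small, each misclassified penguin moves the accuracy by about 1.4 percentage points, so small differences between runs should not be over-interpreted.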
18.1 PDP Plots
Now let’s create a partial dependence plot for flipper length.
fig, ax = plt.subplots(figsize=(10, 6))
display = PartialDependenceDisplay.from_estimator(
    model,  # Your fitted model
    X,  # Your feature matrix
    features=['flipper_length_mm'],  # List of features to plot
    target=0,
    kind='average',  # Type of PDP
    ax=ax,
    random_state=42
)
plt.show()
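A partial dependence curve can also be computed by hand, which may help make the plot less of a black box. The sketch below is a simplified illustration rather than scikit-learn's implementation: `manual_partial_dependence`, the synthetic data and the logistic regression stand-in model are assumptions for the example.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

def manual_partial_dependence(model, X, feature, grid, target_class=1):
    """Average predicted probability of target_class at each grid value."""
    averages = []
    for value in grid:
        X_modified = X.copy()
        # Force every row to share the same value for the chosen feature
        X_modified[:, feature] = value
        probabilities = model.predict_proba(X_modified)[:, target_class]
        averages.append(probabilities.mean())
    return np.array(averages)

# Synthetic data: the label depends mostly on column 0
rng = np.random.default_rng(0)
X = rng.normal(size=(400, 2))
y = (X[:, 0] + 0.5 * rng.normal(size=400) > 0).astype(int)

# Stand-in model for the sketch (the notebook uses XGBClassifier)
toy_model = LogisticRegression().fit(X, y)

# Evaluate the curve on 20 evenly spaced values of column 0
grid = np.linspace(X[:, 0].min(), X[:, 0].max(), 20)
pdp_curve = manual_partial_dependence(toy_model, X, feature=0, grid=grid)
print(np.round(pdp_curve, 3))
```

The `target_class` argument plays the same role as `target` in the notebook's call: for a multiclass model, it picks which class's predicted probability is averaged.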
Now create two plots side-by-side for bill length and bill depth.
HINT: You don’t need to create multiple separate plots using matplotlib for this - you can do it from within the graphing function we’re using from scikit-learn.
fig, ax = plt.subplots(figsize=(10, 6))
display = PartialDependenceDisplay.from_estimator(
    model,  # Your fitted model
    X,  # Your feature matrix
    features=['bill_length_mm', 'bill_depth_mm'],  # List of features to plot
    target=0,
    kind='average',  # Type of PDP
    ax=ax,
    random_state=42
)
plt.show()
18.2 ICE Plots
Now create three ICE plots of the same feature - one for each class. Make sure to give each plot a title.
fig, ax = plt.subplots(figsize=(10, 6))
display = PartialDependenceDisplay.from_estimator(
    model,  # Your fitted model
    X,  # Your feature matrix
    features=['bill_length_mm'],  # List of features to plot
    target=0,
    kind='individual',  # Type of PDP
    ax=ax,
    random_state=42
)
plt.title("Adelie Penguins - Bill Length ICE Plot")
plt.show()
fig, ax = plt.subplots(figsize=(10, 6))
display = PartialDependenceDisplay.from_estimator(
    model,  # Your fitted model
    X,  # Your feature matrix
    features=['bill_length_mm'],  # List of features to plot
    target=1,
    kind='individual',  # Type of PDP
    ax=ax,
    random_state=42
)
plt.title("Chinstrap Penguins - Bill Length ICE Plot")
plt.show()
fig, ax = plt.subplots(figsize=(10, 6))
display = PartialDependenceDisplay.from_estimator(
    model,  # Your fitted model
    X,  # Your feature matrix
    features=['bill_length_mm'],  # List of features to plot
    target=2,
    kind='individual',  # Type of PDP
    ax=ax,
    random_state=42
)
plt.title("Gentoo Penguins - Bill Length ICE Plot")
plt.show()
Now, just for one of the classes, create an ICE plot for bill_length_mm that also shows the average of all the ICE plots - a joint PDP/ICE plot, effectively!
Again, make sure you provide a title.
fig, ax = plt.subplots(figsize=(10, 6))
display = PartialDependenceDisplay.from_estimator(
    model,  # Your fitted model
    X,  # Your feature matrix
    features=['bill_length_mm'],  # List of features to plot
    target=2,  # Class 2 is Gentoo
    kind='both',  # ICE curves plus their average (the PDP)
    ax=ax,
    random_state=42
)
plt.title("Gentoo Penguins - Joint ICE/PDP Plot")
plt.show()
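The relationship between the two kinds of curve in a joint plot can be sketched numerically: each ICE curve is one row's prediction as the feature is varied, and the PDP is the average of those curves. The example below uses hypothetical synthetic data with a deliberate interaction and a random forest as a stand-in model, so none of the numbers relate to the penguins.

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier

def ice_curves(model, X, feature, grid, target_class=1):
    """Return an array with one row per sample and one column per grid value."""
    curves = np.empty((X.shape[0], len(grid)))
    for j, value in enumerate(grid):
        X_modified = X.copy()
        # Vary only the chosen feature; every other column keeps its real value
        X_modified[:, feature] = value
        curves[:, j] = model.predict_proba(X_modified)[:, target_class]
    return curves

# Synthetic data: the effect of column 0 flips with the sign of column 1
rng = np.random.default_rng(1)
X = rng.normal(size=(300, 2))
y = (X[:, 0] * np.sign(X[:, 1]) > 0).astype(int)

# Stand-in model for the sketch (the notebook uses XGBClassifier)
toy_model = RandomForestClassifier(n_estimators=50, random_state=0).fit(X, y)

grid = np.linspace(X[:, 0].min(), X[:, 0].max(), 15)
ice = ice_curves(toy_model, X, feature=0, grid=grid)

# The PDP is the column-wise mean of the ICE curves
pdp_curve = ice.mean(axis=0)
print('ICE array shape:', ice.shape)
print('PDP:', np.round(pdp_curve, 2))
```

In this constructed case roughly half the individual curves rise and half fall, so the averaged curve is likely to look fairly flat. That is the main argument for drawing ICE curves alongside a PDP: the average can hide opposing effects.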
18.2.0.1 BONUS: 2D PDP Plots
Now create a 2D plot of bill length and bill depth.
PartialDependenceDisplay.from_estimator(
    model,
    X_test,
    features=[('bill_length_mm', 'bill_depth_mm')],
    kind='average',
    target=0,
    random_state=0
)
18.3 SHAP
We have a multiclass problem with our penguins dataset.
This results in some slightly different outputs from our SHAP code, which can be confusing to deal with, so for now we’re just going to focus on a binary classification problem - is a penguin an Adelie, or not?
Run the code below to turn this into a binary classification problem and retrain the model.
= penguins.copy()
penguins_binary
# If Adelie penguin, return 1, else return 0
'target'] = np.where(penguins_binary['target'] == 0, 1, 0)
penguins_binary['species'] = np.where(penguins_binary['species'] == "Adelie", "Adelie", "Not Adelie")
penguins_binary[
# Droping the target and species since we only need the measurements
= penguins_binary.drop(['target','species'], axis=1)
X
# Define features (X) and target (y)
= X
X = penguins_binary['target']
y
# get class and features names
= penguins_binary.species.unique()
class_names = X.columns
feature_names
# Splitting into train and test
= train_test_split(X, y, test_size=0.2,
X_train, X_test, y_train, y_test =42)
random_state
# Instantiate an XGBoost model and fit it
= XGBClassifier(random_state=42)
model
model.fit(X_train, y_train)
XGBClassifier(base_score=None, booster=None, callbacks=None, colsample_bylevel=None, colsample_bynode=None, colsample_bytree=None, device=None, early_stopping_rounds=None, enable_categorical=False, eval_metric=None, feature_types=None, gamma=None, grow_policy=None, importance_type=None, interaction_constraints=None, learning_rate=None, max_bin=None, max_cat_threshold=None, max_cat_to_onehot=None, max_delta_step=None, max_depth=None, max_leaves=None, min_child_weight=None, missing=nan, monotone_constraints=None, multi_strategy=None, n_estimators=None, n_jobs=None, num_parallel_tree=None, random_state=42, ...)
18.3.1 Obtaining Shap Values
Generate a SHAP explainer for our model, using X_train as the background data.
# Compute SHAP values
explainer = shap.Explainer(
    model,
    X_train
)
Now create the shap_values object, using X_test as the foreground data.
shap_values = explainer.shap_values(X_test)
shap_values
array([[-5.17051824e+00, 1.21827941e+00, -5.74065812e-02,
-5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -3.35151969e-01, -1.36875771e-01],
[ 1.13075847e-01, 1.78565094e+00, -4.36455202e-02,
-5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -2.35543021e-01, -1.91024542e-01],
[-2.72566813e-02, -3.30757442e+00, -2.22220667e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.31137460e-01, -1.91024542e-01],
[-5.10160202e+00, 1.83535931e+00, 7.00742684e-02,
-5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -3.35151969e-01, 1.13559266e-01],
[-4.14983763e+00, -1.87998490e+00, -1.60803193e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.84721798e-01, 1.13559266e-01],
[-4.44195123e+00, -1.58787129e+00, -1.60803193e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.84721798e-01, 1.13559266e-01],
[-4.44195123e+00, -1.58787129e+00, -1.60803193e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.84721798e-01, 1.13559266e-01],
[-4.01430746e+00, -1.98892677e+00, -1.60803193e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.84721798e-01, -1.36875771e-01],
[ 5.45284030e+00, 1.74075824e+00, -3.02373941e-01,
1.54761915e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -8.35388462e-02, 1.08201245e-01],
[ 1.91534915e+00, -3.58559378e+00, -3.02373941e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -2.18033250e-01, -1.91024542e-01],
[-4.44195123e+00, -1.58787129e+00, -1.60803193e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.84721798e-01, 1.13559266e-01],
[ 4.68026578e+00, 1.44937979e+00, 1.60711127e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, 1.42733446e+00, -1.30585920e-01],
[ 6.04219020e+00, 1.30955449e+00, 9.22989947e-02,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -9.35887119e-02, 1.08201245e-01],
[ 4.79211754e+00, -1.02131687e-03, 1.60711127e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, 1.54396007e+00, -1.30585920e-01],
[-4.13248416e+00, -1.89733836e+00, -1.60803193e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.84721798e-01, 1.13559266e-01],
[ 6.44204529e+00, -5.33143660e-01, 9.22989947e-02,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -1.01846663e-01, -1.30585920e-01],
[ 5.71326724e+00, 1.66361659e+00, 1.75061948e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -8.35388462e-02, 1.08201245e-01],
[-4.44524486e-02, -3.30757442e+00, -2.22220667e-01,
-5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.31137460e-01, -1.91024542e-01],
[-4.04089575e+00, -1.98892677e+00, -1.60803193e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.84721798e-01, 1.13559266e-01],
[ 4.97731472e+00, 1.03570520e+00, 1.60711127e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, 1.54396007e+00, -1.30585920e-01],
[ 3.33277721e+00, 2.57971638e+00, 1.75061948e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -1.46053878e-01, 1.59685997e-01],
[ 5.91554046e+00, 1.47107445e+00, 9.22989947e-02,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -9.02227896e-02, -1.30585920e-01],
[-4.10589588e+00, -1.89733836e+00, -1.60803193e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.84721798e-01, -1.36875771e-01],
[ 6.58941987e+00, 8.37323978e-01, 1.75061948e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -9.51627192e-02, -1.30585920e-01],
[ 2.60384000e+00, -4.47425562e-01, 9.22989947e-02,
-5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -2.50891294e-01, -1.91024542e-01],
[-5.13327016e+00, -1.40029112e-01, 7.00742684e-02,
-5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -3.83302972e-01, 1.13559266e-01],
[ 5.68476218e+00, 1.66361659e+00, 9.22989947e-02,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -9.02227896e-02, 1.08201245e-01],
[-5.16359000e+00, 9.08194313e-01, 7.00742684e-02,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -3.44841380e-01, 1.13559266e-01],
[-3.45011731e+00, 2.42737406e+00, 7.00742684e-02,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.18269527e-01, 1.59685997e-01],
[-3.79519642e+00, -2.13501758e+00, 7.00742684e-02,
-5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -3.71910748e-01, -1.91024542e-01],
[ 6.07069525e+00, 1.30955449e+00, 1.75061948e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -8.69047685e-02, 1.08201245e-01],
[ 4.36485172e+00, 1.75760771e+00, 1.60711127e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, 1.39628446e+00, 1.08201245e-01],
[-5.06125901e+00, 4.69742856e-02, 4.39282524e-02,
-5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -3.85268838e-01, -1.36875771e-01],
[ 2.94871511e+00, 1.16633540e+00, 1.27291407e-01,
-5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, 2.49448815e+00, -1.91024542e-01],
[ 5.94404552e+00, 1.47107445e+00, 1.75061948e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -8.35388462e-02, -1.30585920e-01],
[ 3.45756491e+00, 6.67898319e-01, 1.75061948e-01,
-5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -2.16614424e-01, 1.59685997e-01],
[-4.13248416e+00, -1.89733836e+00, -1.60803193e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.84721798e-01, 1.13559266e-01],
[-4.94008613e+00, 1.97020284e+00, 7.00742684e-02,
-5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -3.35151969e-01, 1.13559266e-01],
[ 4.76751433e+00, 1.38394169e+00, 1.27291407e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, 1.43894375e+00, -1.30585920e-01],
[-5.19412194e+00, 9.21530487e-01, -5.74065812e-02,
-5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -3.44841380e-01, 1.13559266e-01],
[-4.41536294e+00, -1.58787129e+00, -1.60803193e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.84721798e-01, -1.36875771e-01],
[-4.81358717e+00, -8.71694393e-01, -1.60803193e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.83302972e-01, 1.13559266e-01],
[-4.81358717e+00, -1.02284420e+00, -1.60803193e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.83302972e-01, 1.13559266e-01],
[ 4.60306368e+00, 1.51015618e+00, 1.27291407e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, 1.43894375e+00, 1.08201245e-01],
[ 6.44204529e+00, -3.81993851e-01, 9.22989947e-02,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -1.01846663e-01, -1.30585920e-01],
[ 4.78287797e+00, -2.68434791e-01, 1.27291407e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, 1.58661936e+00, -1.30585920e-01],
[-7.61286116e-01, -3.12572213e+00, -1.47042132e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.31137460e-01, -1.91024542e-01],
[ 4.49503812e+00, 1.58729784e+00, 1.60711127e-01,
1.54761915e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, 1.39628446e+00, 1.08201245e-01],
[ 4.49106620e+00, 1.63139322e+00, 1.60711127e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, 1.39628446e+00, 1.08201245e-01],
[-5.11681622e+00, 1.45285302e+00, -5.74065812e-02,
-5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -3.35151969e-01, 1.13559266e-01],
[ 4.96807515e+00, 1.03570520e+00, 1.27291407e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, 1.58661936e+00, -1.30585920e-01],
[-3.27354855e+00, -2.53607305e+00, -1.60803193e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.71910748e-01, -1.91024542e-01],
[-4.01430746e+00, -1.98892677e+00, -1.60803193e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.84721798e-01, -1.36875771e-01],
[ 6.08042636e+00, 1.30955449e+00, 9.22989947e-02,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -9.35887119e-02, -1.30585920e-01],
[ 4.93718394e+00, -3.09827117e-01, 1.27291407e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, 1.58661936e+00, 1.08201245e-01],
[-5.17991076e+00, 1.21827941e+00, -5.74065812e-02,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -3.35151969e-01, 1.13559266e-01],
[ 3.65377114e+00, 2.45184202e+00, 9.22989947e-02,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -1.42782735e-01, 1.59685997e-01],
[-5.07603471e+00, 3.36381122e-02, 7.00742684e-02,
-5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -3.83302972e-01, -1.36875771e-01],
[-4.53595532e+00, -1.58787129e+00, 7.00742684e-02,
-5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -3.84721798e-01, -1.36875771e-01],
[-3.45512069e+00, -2.34506634e+00, -1.60803193e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.81345309e-01, -1.91024542e-01],
[-5.12608560e+00, 9.08194313e-01, 4.39282524e-02,
-5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -3.46807246e-01, -1.36875771e-01],
[ 5.71326724e+00, 1.66361659e+00, 1.75061948e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -8.35388462e-02, 1.08201245e-01],
[-4.10589588e+00, -1.89733836e+00, -1.60803193e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.84721798e-01, -1.36875771e-01],
[ 6.58941987e+00, 8.37323978e-01, 1.75061948e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -9.51627192e-02, -1.30585920e-01],
[-4.13248416e+00, -1.89733836e+00, -1.60803193e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.84721798e-01, 1.13559266e-01],
[-5.07603471e+00, 3.36381122e-02, 7.00742684e-02,
-5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -3.83302972e-01, -1.36875771e-01],
[ 4.57339775e+00, 1.58729784e+00, 1.60711127e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, 1.39628446e+00, -1.30585920e-01],
[ 5.53214335e+00, 1.84474047e+00, 1.75061948e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -8.35388462e-02, 1.08201245e-01],
[ 6.66092991e+00, 7.65813945e-01, 1.75061948e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -9.51627192e-02, -1.30585920e-01]])
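Each row of this array holds one test penguin's per-feature contributions. The sketch below shows how a row could be read, under two assumptions that the notebook does not confirm: that the values are in log-odds units (which is usually the case for XGBoost classifiers explained with default settings), and that the feature order matches the columns of X. The base value is a made-up placeholder, since `explainer.expected_value` is not printed here.

```python
import math

# Assumed feature order: the columns of X as shown in the earlier tables
feature_names = ['bill_length_mm', 'bill_depth_mm', 'flipper_length_mm',
                 'body_mass_g', 'year', 'island_Biscoe', 'island_Dream',
                 'island_Torgersen', 'male']

# First row of the printed SHAP array
first_row = [-5.17051824e+00, 1.21827941e+00, -5.74065812e-02,
             -5.04409203e-02, 0.0, 0.0,
             -3.68462939e-01, -3.35151969e-01, -1.36875771e-01]

# Hypothetical placeholder: the real value is explainer.expected_value
assumed_base_value = 0.0

# Assumption: contributions are additive in log-odds space
raw_score = assumed_base_value + sum(first_row)
# The logistic function converts log-odds into a probability
probability = 1 / (1 + math.exp(-raw_score))

for name, value in zip(feature_names, first_row):
    print(f'{name:>18}: {value:+.3f}')
print(f'Raw score (log-odds): {raw_score:.3f}')
print(f'Implied probability of Adelie: {probability:.3f}')
```

Under those assumptions, the large negative contribution from bill length is what pushes this penguin away from the Adelie class, with bill depth pulling partly in the other direction.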
It looks like it’s returned our outputs just as an array instead of a SHAP explanation object. Run the code below to turn our object into a proper shap.Explanation() object, as this is what all the plotting functions will be expecting.
# Create an Explanation object
shap_values = shap.Explanation(
    values=shap_values,
    base_values=explainer.expected_value,
    data=X_test.values,
    feature_names=X.columns
)
Now let’s see what this looks like instead.
shap_values
.values =
array([[-5.17051824e+00, 1.21827941e+00, -5.74065812e-02,
-5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -3.35151969e-01, -1.36875771e-01],
[ 1.13075847e-01, 1.78565094e+00, -4.36455202e-02,
-5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -2.35543021e-01, -1.91024542e-01],
[-2.72566813e-02, -3.30757442e+00, -2.22220667e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.31137460e-01, -1.91024542e-01],
[-5.10160202e+00, 1.83535931e+00, 7.00742684e-02,
-5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -3.35151969e-01, 1.13559266e-01],
[-4.14983763e+00, -1.87998490e+00, -1.60803193e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.84721798e-01, 1.13559266e-01],
[-4.44195123e+00, -1.58787129e+00, -1.60803193e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.84721798e-01, 1.13559266e-01],
[-4.44195123e+00, -1.58787129e+00, -1.60803193e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.84721798e-01, 1.13559266e-01],
[-4.01430746e+00, -1.98892677e+00, -1.60803193e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.84721798e-01, -1.36875771e-01],
[ 5.45284030e+00, 1.74075824e+00, -3.02373941e-01,
1.54761915e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -8.35388462e-02, 1.08201245e-01],
[ 1.91534915e+00, -3.58559378e+00, -3.02373941e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -2.18033250e-01, -1.91024542e-01],
[-4.44195123e+00, -1.58787129e+00, -1.60803193e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.84721798e-01, 1.13559266e-01],
[ 4.68026578e+00, 1.44937979e+00, 1.60711127e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, 1.42733446e+00, -1.30585920e-01],
[ 6.04219020e+00, 1.30955449e+00, 9.22989947e-02,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -9.35887119e-02, 1.08201245e-01],
[ 4.79211754e+00, -1.02131687e-03, 1.60711127e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, 1.54396007e+00, -1.30585920e-01],
[-4.13248416e+00, -1.89733836e+00, -1.60803193e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.84721798e-01, 1.13559266e-01],
[ 6.44204529e+00, -5.33143660e-01, 9.22989947e-02,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -1.01846663e-01, -1.30585920e-01],
[ 5.71326724e+00, 1.66361659e+00, 1.75061948e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -8.35388462e-02, 1.08201245e-01],
[-4.44524486e-02, -3.30757442e+00, -2.22220667e-01,
-5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.31137460e-01, -1.91024542e-01],
[-4.04089575e+00, -1.98892677e+00, -1.60803193e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.84721798e-01, 1.13559266e-01],
[ 4.97731472e+00, 1.03570520e+00, 1.60711127e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, 1.54396007e+00, -1.30585920e-01],
[ 3.33277721e+00, 2.57971638e+00, 1.75061948e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -1.46053878e-01, 1.59685997e-01],
[ 5.91554046e+00, 1.47107445e+00, 9.22989947e-02,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -9.02227896e-02, -1.30585920e-01],
[-4.10589588e+00, -1.89733836e+00, -1.60803193e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.84721798e-01, -1.36875771e-01],
[ 6.58941987e+00, 8.37323978e-01, 1.75061948e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -9.51627192e-02, -1.30585920e-01],
[ 2.60384000e+00, -4.47425562e-01, 9.22989947e-02,
-5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -2.50891294e-01, -1.91024542e-01],
[-5.13327016e+00, -1.40029112e-01, 7.00742684e-02,
-5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -3.83302972e-01, 1.13559266e-01],
[ 5.68476218e+00, 1.66361659e+00, 9.22989947e-02,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -9.02227896e-02, 1.08201245e-01],
[-5.16359000e+00, 9.08194313e-01, 7.00742684e-02,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -3.44841380e-01, 1.13559266e-01],
[-3.45011731e+00, 2.42737406e+00, 7.00742684e-02,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.18269527e-01, 1.59685997e-01],
[-3.79519642e+00, -2.13501758e+00, 7.00742684e-02,
-5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -3.71910748e-01, -1.91024542e-01],
[ 6.07069525e+00, 1.30955449e+00, 1.75061948e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -8.69047685e-02, 1.08201245e-01],
[ 4.36485172e+00, 1.75760771e+00, 1.60711127e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, 1.39628446e+00, 1.08201245e-01],
[-5.06125901e+00, 4.69742856e-02, 4.39282524e-02,
-5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -3.85268838e-01, -1.36875771e-01],
[ 2.94871511e+00, 1.16633540e+00, 1.27291407e-01,
-5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, 2.49448815e+00, -1.91024542e-01],
[ 5.94404552e+00, 1.47107445e+00, 1.75061948e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -8.35388462e-02, -1.30585920e-01],
[ 3.45756491e+00, 6.67898319e-01, 1.75061948e-01,
-5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -2.16614424e-01, 1.59685997e-01],
[-4.13248416e+00, -1.89733836e+00, -1.60803193e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.84721798e-01, 1.13559266e-01],
[-4.94008613e+00, 1.97020284e+00, 7.00742684e-02,
-5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -3.35151969e-01, 1.13559266e-01],
[ 4.76751433e+00, 1.38394169e+00, 1.27291407e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, 1.43894375e+00, -1.30585920e-01],
[-5.19412194e+00, 9.21530487e-01, -5.74065812e-02,
-5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -3.44841380e-01, 1.13559266e-01],
[-4.41536294e+00, -1.58787129e+00, -1.60803193e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.84721798e-01, -1.36875771e-01],
[-4.81358717e+00, -8.71694393e-01, -1.60803193e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.83302972e-01, 1.13559266e-01],
[-4.81358717e+00, -1.02284420e+00, -1.60803193e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.83302972e-01, 1.13559266e-01],
[ 4.60306368e+00, 1.51015618e+00, 1.27291407e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, 1.43894375e+00, 1.08201245e-01],
[ 6.44204529e+00, -3.81993851e-01, 9.22989947e-02,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -1.01846663e-01, -1.30585920e-01],
[ 4.78287797e+00, -2.68434791e-01, 1.27291407e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, 1.58661936e+00, -1.30585920e-01],
[-7.61286116e-01, -3.12572213e+00, -1.47042132e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.31137460e-01, -1.91024542e-01],
[ 4.49503812e+00, 1.58729784e+00, 1.60711127e-01,
1.54761915e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, 1.39628446e+00, 1.08201245e-01],
[ 4.49106620e+00, 1.63139322e+00, 1.60711127e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, 1.39628446e+00, 1.08201245e-01],
[-5.11681622e+00, 1.45285302e+00, -5.74065812e-02,
-5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -3.35151969e-01, 1.13559266e-01],
[ 4.96807515e+00, 1.03570520e+00, 1.27291407e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, 1.58661936e+00, -1.30585920e-01],
[-3.27354855e+00, -2.53607305e+00, -1.60803193e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.71910748e-01, -1.91024542e-01],
[-4.01430746e+00, -1.98892677e+00, -1.60803193e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.84721798e-01, -1.36875771e-01],
[ 6.08042636e+00, 1.30955449e+00, 9.22989947e-02,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -9.35887119e-02, -1.30585920e-01],
[ 4.93718394e+00, -3.09827117e-01, 1.27291407e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, 1.58661936e+00, 1.08201245e-01],
[-5.17991076e+00, 1.21827941e+00, -5.74065812e-02,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -3.35151969e-01, 1.13559266e-01],
[ 3.65377114e+00, 2.45184202e+00, 9.22989947e-02,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -1.42782735e-01, 1.59685997e-01],
[-5.07603471e+00, 3.36381122e-02, 7.00742684e-02,
-5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -3.83302972e-01, -1.36875771e-01],
[-4.53595532e+00, -1.58787129e+00, 7.00742684e-02,
-5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -3.84721798e-01, -1.36875771e-01],
[-3.45512069e+00, -2.34506634e+00, -1.60803193e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.81345309e-01, -1.91024542e-01],
[-5.12608560e+00, 9.08194313e-01, 4.39282524e-02,
-5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -3.46807246e-01, -1.36875771e-01],
[ 5.71326724e+00, 1.66361659e+00, 1.75061948e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -8.35388462e-02, 1.08201245e-01],
[-4.10589588e+00, -1.89733836e+00, -1.60803193e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.84721798e-01, -1.36875771e-01],
[ 6.58941987e+00, 8.37323978e-01, 1.75061948e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -9.51627192e-02, -1.30585920e-01],
[-4.13248416e+00, -1.89733836e+00, -1.60803193e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.84721798e-01, 1.13559266e-01],
[-5.07603471e+00, 3.36381122e-02, 7.00742684e-02,
-5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -3.83302972e-01, -1.36875771e-01],
[ 4.57339775e+00, 1.58729784e+00, 1.60711127e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, 1.39628446e+00, -1.30585920e-01],
[ 5.53214335e+00, 1.84474047e+00, 1.75061948e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -8.35388462e-02, 1.08201245e-01],
[ 6.66092991e+00, 7.65813945e-01, 1.75061948e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -9.51627192e-02, -1.30585920e-01]])
.base_values =
-0.7156927563131259
.data =
array([[ 32.1, 15.5, 188. , ..., 1. , 0. , 0. ],
[ 33.1, 16.1, 178. , ..., 1. , 0. , 0. ],
[ 33.5, 19. , 190. , ..., 0. , 1. , 0. ],
...,
[ 55.9, 17. , 228. , ..., 0. , 0. , 1. ],
[ 59.6, 17. , 230. , ..., 0. , 0. , 1. ],
[ nan, nan, nan, ..., 0. , 0. , nan]])
Finally, let’s grab just the numeric component (the actual SHAP values).
shap_values_numeric = shap_values.values
shap_values_numeric
array([[-5.17051824e+00, 1.21827941e+00, -5.74065812e-02,
-5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -3.35151969e-01, -1.36875771e-01],
[ 1.13075847e-01, 1.78565094e+00, -4.36455202e-02,
-5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -2.35543021e-01, -1.91024542e-01],
[-2.72566813e-02, -3.30757442e+00, -2.22220667e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.31137460e-01, -1.91024542e-01],
[-5.10160202e+00, 1.83535931e+00, 7.00742684e-02,
-5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -3.35151969e-01, 1.13559266e-01],
[-4.14983763e+00, -1.87998490e+00, -1.60803193e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.84721798e-01, 1.13559266e-01],
[-4.44195123e+00, -1.58787129e+00, -1.60803193e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.84721798e-01, 1.13559266e-01],
[-4.44195123e+00, -1.58787129e+00, -1.60803193e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.84721798e-01, 1.13559266e-01],
[-4.01430746e+00, -1.98892677e+00, -1.60803193e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.84721798e-01, -1.36875771e-01],
[ 5.45284030e+00, 1.74075824e+00, -3.02373941e-01,
1.54761915e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -8.35388462e-02, 1.08201245e-01],
[ 1.91534915e+00, -3.58559378e+00, -3.02373941e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -2.18033250e-01, -1.91024542e-01],
[-4.44195123e+00, -1.58787129e+00, -1.60803193e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.84721798e-01, 1.13559266e-01],
[ 4.68026578e+00, 1.44937979e+00, 1.60711127e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, 1.42733446e+00, -1.30585920e-01],
[ 6.04219020e+00, 1.30955449e+00, 9.22989947e-02,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -9.35887119e-02, 1.08201245e-01],
[ 4.79211754e+00, -1.02131687e-03, 1.60711127e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, 1.54396007e+00, -1.30585920e-01],
[-4.13248416e+00, -1.89733836e+00, -1.60803193e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.84721798e-01, 1.13559266e-01],
[ 6.44204529e+00, -5.33143660e-01, 9.22989947e-02,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -1.01846663e-01, -1.30585920e-01],
[ 5.71326724e+00, 1.66361659e+00, 1.75061948e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -8.35388462e-02, 1.08201245e-01],
[-4.44524486e-02, -3.30757442e+00, -2.22220667e-01,
-5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.31137460e-01, -1.91024542e-01],
[-4.04089575e+00, -1.98892677e+00, -1.60803193e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.84721798e-01, 1.13559266e-01],
[ 4.97731472e+00, 1.03570520e+00, 1.60711127e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, 1.54396007e+00, -1.30585920e-01],
[ 3.33277721e+00, 2.57971638e+00, 1.75061948e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -1.46053878e-01, 1.59685997e-01],
[ 5.91554046e+00, 1.47107445e+00, 9.22989947e-02,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -9.02227896e-02, -1.30585920e-01],
[-4.10589588e+00, -1.89733836e+00, -1.60803193e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.84721798e-01, -1.36875771e-01],
[ 6.58941987e+00, 8.37323978e-01, 1.75061948e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -9.51627192e-02, -1.30585920e-01],
[ 2.60384000e+00, -4.47425562e-01, 9.22989947e-02,
-5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -2.50891294e-01, -1.91024542e-01],
[-5.13327016e+00, -1.40029112e-01, 7.00742684e-02,
-5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -3.83302972e-01, 1.13559266e-01],
[ 5.68476218e+00, 1.66361659e+00, 9.22989947e-02,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -9.02227896e-02, 1.08201245e-01],
[-5.16359000e+00, 9.08194313e-01, 7.00742684e-02,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -3.44841380e-01, 1.13559266e-01],
[-3.45011731e+00, 2.42737406e+00, 7.00742684e-02,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.18269527e-01, 1.59685997e-01],
[-3.79519642e+00, -2.13501758e+00, 7.00742684e-02,
-5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -3.71910748e-01, -1.91024542e-01],
[ 6.07069525e+00, 1.30955449e+00, 1.75061948e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -8.69047685e-02, 1.08201245e-01],
[ 4.36485172e+00, 1.75760771e+00, 1.60711127e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, 1.39628446e+00, 1.08201245e-01],
[-5.06125901e+00, 4.69742856e-02, 4.39282524e-02,
-5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -3.85268838e-01, -1.36875771e-01],
[ 2.94871511e+00, 1.16633540e+00, 1.27291407e-01,
-5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, 2.49448815e+00, -1.91024542e-01],
[ 5.94404552e+00, 1.47107445e+00, 1.75061948e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -8.35388462e-02, -1.30585920e-01],
[ 3.45756491e+00, 6.67898319e-01, 1.75061948e-01,
-5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -2.16614424e-01, 1.59685997e-01],
[-4.13248416e+00, -1.89733836e+00, -1.60803193e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.84721798e-01, 1.13559266e-01],
[-4.94008613e+00, 1.97020284e+00, 7.00742684e-02,
-5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -3.35151969e-01, 1.13559266e-01],
[ 4.76751433e+00, 1.38394169e+00, 1.27291407e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, 1.43894375e+00, -1.30585920e-01],
[-5.19412194e+00, 9.21530487e-01, -5.74065812e-02,
-5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -3.44841380e-01, 1.13559266e-01],
[-4.41536294e+00, -1.58787129e+00, -1.60803193e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.84721798e-01, -1.36875771e-01],
[-4.81358717e+00, -8.71694393e-01, -1.60803193e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.83302972e-01, 1.13559266e-01],
[-4.81358717e+00, -1.02284420e+00, -1.60803193e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.83302972e-01, 1.13559266e-01],
[ 4.60306368e+00, 1.51015618e+00, 1.27291407e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, 1.43894375e+00, 1.08201245e-01],
[ 6.44204529e+00, -3.81993851e-01, 9.22989947e-02,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -1.01846663e-01, -1.30585920e-01],
[ 4.78287797e+00, -2.68434791e-01, 1.27291407e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, 1.58661936e+00, -1.30585920e-01],
[-7.61286116e-01, -3.12572213e+00, -1.47042132e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.31137460e-01, -1.91024542e-01],
[ 4.49503812e+00, 1.58729784e+00, 1.60711127e-01,
1.54761915e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, 1.39628446e+00, 1.08201245e-01],
[ 4.49106620e+00, 1.63139322e+00, 1.60711127e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, 1.39628446e+00, 1.08201245e-01],
[-5.11681622e+00, 1.45285302e+00, -5.74065812e-02,
-5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -3.35151969e-01, 1.13559266e-01],
[ 4.96807515e+00, 1.03570520e+00, 1.27291407e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, 1.58661936e+00, -1.30585920e-01],
[-3.27354855e+00, -2.53607305e+00, -1.60803193e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.71910748e-01, -1.91024542e-01],
[-4.01430746e+00, -1.98892677e+00, -1.60803193e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.84721798e-01, -1.36875771e-01],
[ 6.08042636e+00, 1.30955449e+00, 9.22989947e-02,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -9.35887119e-02, -1.30585920e-01],
[ 4.93718394e+00, -3.09827117e-01, 1.27291407e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, 1.58661936e+00, 1.08201245e-01],
[-5.17991076e+00, 1.21827941e+00, -5.74065812e-02,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -3.35151969e-01, 1.13559266e-01],
[ 3.65377114e+00, 2.45184202e+00, 9.22989947e-02,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -1.42782735e-01, 1.59685997e-01],
[-5.07603471e+00, 3.36381122e-02, 7.00742684e-02,
-5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -3.83302972e-01, -1.36875771e-01],
[-4.53595532e+00, -1.58787129e+00, 7.00742684e-02,
-5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -3.84721798e-01, -1.36875771e-01],
[-3.45512069e+00, -2.34506634e+00, -1.60803193e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.81345309e-01, -1.91024542e-01],
[-5.12608560e+00, 9.08194313e-01, 4.39282524e-02,
-5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -3.46807246e-01, -1.36875771e-01],
[ 5.71326724e+00, 1.66361659e+00, 1.75061948e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -8.35388462e-02, 1.08201245e-01],
[-4.10589588e+00, -1.89733836e+00, -1.60803193e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.84721798e-01, -1.36875771e-01],
[ 6.58941987e+00, 8.37323978e-01, 1.75061948e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -9.51627192e-02, -1.30585920e-01],
[-4.13248416e+00, -1.89733836e+00, -1.60803193e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.84721798e-01, 1.13559266e-01],
[-5.07603471e+00, 3.36381122e-02, 7.00742684e-02,
-5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -3.83302972e-01, -1.36875771e-01],
[ 4.57339775e+00, 1.58729784e+00, 1.60711127e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, 1.39628446e+00, -1.30585920e-01],
[ 5.53214335e+00, 1.84474047e+00, 1.75061948e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -8.35388462e-02, 1.08201245e-01],
[ 6.66092991e+00, 7.65813945e-01, 1.75061948e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -9.51627192e-02, -1.30585920e-01]])
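SHAP values are additive: for any one row, the base value plus that row's SHAP values should reproduce the model's raw output for that row. The sketch below checks this by hand using the base value and the first row printed above. Converting the raw output to a probability with a sigmoid is an assumption here: it only applies if the raw output is a log-odds score for a single class.

```python
import numpy as np

# Base value and first row of SHAP values, copied from the outputs printed above
base_value = -0.7156927563131259
first_row = np.array([
    -5.17051824e+00, 1.21827941e+00, -5.74065812e-02,
    -5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
    -3.68462939e-01, -3.35151969e-01, -1.36875771e-01,
])

# Additivity: raw model output = base value + sum of per-feature contributions
raw_output = base_value + first_row.sum()

# ASSUMPTION: the raw output is a log-odds score, so a sigmoid maps it to a probability
probability = 1.0 / (1.0 + np.exp(-raw_output))

print(f"raw output: {raw_output:.4f}")
print(f"implied probability: {probability:.4f}")
```

A strongly negative raw output like this one means the features, taken together, push the prediction well below the baseline for this row.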
18.3.2 Exploring the SHAP outputs
First, let’s just get a list of our most important features according to SHAP.
# Get feature importance for comparison using the MDI method
features = list(X_train)
feature_importances = model.feature_importances_
importances = pd.DataFrame(index=features)
importances['importance'] = feature_importances
importances['rank'] = importances['importance'].rank(ascending=False).values
importances.sort_values('rank').head()

# Get Shapley importances
# Calculate mean Shapley value for each feature in training set
importances['mean_shapley_values'] = np.mean(
    shap_values_numeric, axis=0
    )

# Calculate mean absolute Shapley value for each feature in training set
# This will give us the average importance of each feature
importances['mean_abs_shapley_values'] = np.mean(
    np.abs(shap_values_numeric), axis=0
    )

# Feature names ordered by MDI importance (highest first)
importance_top = \
    importances.sort_values(
        by='importance', ascending=False
        ).index

# Feature names ordered by mean absolute Shapley value (highest first)
shapley_top = \
    importances.sort_values(
        by='mean_abs_shapley_values',
        ascending=False).index

# Add to DataFrame
top_features = pd.DataFrame()
top_features['importances'] = importance_top.values
top_features['Shapley'] = shapley_top.values

# Display
top_features
| | importances | Shapley |
|---|---|---|
| 0 | bill_length_mm | bill_length_mm |
| 1 | island_Torgersen | bill_depth_mm |
| 2 | bill_depth_mm | island_Torgersen |
| 3 | flipper_length_mm | island_Dream |
| 4 | male | flipper_length_mm |
| 5 | island_Dream | male |
| 6 | body_mass_g | body_mass_g |
| 7 | year | year |
| 8 | island_Biscoe | island_Biscoe |
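The ranking uses the mean absolute SHAP value rather than the plain mean because positive and negative contributions cancel out when averaged. A minimal sketch with a made-up SHAP matrix shows the effect; the feature names and numbers are hypothetical and are not taken from the penguins model.

```python
import numpy as np

# Hypothetical toy SHAP matrix: 4 rows (instances) x 3 features
toy_feature_names = ["feat_a", "feat_b", "feat_c"]
toy_shap = np.array([
    [ 2.0, -0.5,  0.0],
    [-2.0,  0.5,  0.1],
    [ 1.0, -1.5,  0.0],
    [-1.0,  1.5, -0.1],
])

# Signed mean: large contributions in opposite directions cancel to zero
mean_signed = toy_shap.mean(axis=0)

# Mean absolute value: measures how much each feature moves predictions on average
mean_abs = np.abs(toy_shap).mean(axis=0)

# Rank features from most to least important by mean absolute SHAP value
order = np.argsort(-mean_abs)
ranked = [toy_feature_names[i] for i in order]

print("signed mean:  ", mean_signed)
print("mean absolute:", mean_abs)
print("ranking:      ", ranked)
```

Here every feature has a signed mean of zero, yet `feat_a` clearly has the largest effect on individual predictions, which only the absolute version captures.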
18.3.2.1 SHAP plots
Generate a bar plot of the SHAP values.
shap.plots.bar(shap_values)
Generate a beeswarm plot.
shap.plots.beeswarm(shap_values)
DimensionError: Feature and SHAP matrices must have the same number of rows!
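The error says the SHAP matrix and the feature matrix have different numbers of rows. One way to investigate is to compare the row counts and look for rows containing missing values; the `.data` array printed earlier ends in a row of `nan`, which may be worth checking. The helper below is a hypothetical diagnostic run on toy arrays, not part of the `shap` API. In the notebook the analogous inputs would presumably be `shap_values.values` and `shap_values.data`.

```python
import numpy as np

def describe_row_mismatch(values, data):
    """Hypothetical helper: return (n_value_rows, n_data_rows, indices of data rows with NaN)."""
    values = np.asarray(values, dtype=float)
    data = np.asarray(data, dtype=float)
    nan_rows = np.flatnonzero(np.isnan(data).any(axis=1))
    return values.shape[0], data.shape[0], nan_rows.tolist()

# Toy arrays standing in for the SHAP matrix and the feature matrix
toy_values = np.zeros((3, 2))
toy_data = np.array([
    [1.0, 2.0],
    [3.0, 4.0],
    [5.0, 6.0],
    [np.nan, np.nan],  # an extra, all-missing row
])

n_values, n_data, nan_rows = describe_row_mismatch(toy_values, toy_data)
print(f"SHAP rows: {n_values}, feature rows: {n_data}, rows with NaN: {nan_rows}")
```

If the counts differ, the fix belongs upstream, where the data passed to the explainer is built, rather than in the plotting call.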
Generate a waterfall plot for 5 different examples from the dataset.
shap.plots.waterfall(shap_values[0])
shap.plots.waterfall(shap_values[3])
shap.plots.waterfall(shap_values[5])
shap.plots.waterfall(shap_values[-1])
shap.plots.waterfall(shap_values[194])
18.3.2.2 Dependence Plots for each Class (Species)
Let’s look at the columns in our dataset and the indices.
# Let's see the features and respective index numbers
for e, i in enumerate(X_test.columns):
    print(f"{e} - {i}")
First, generate a scatter plot for the bill length.
shap.plots.scatter(shap_values[:, 'bill_length_mm'])
Now colour this by bill depth.
shap.plots.scatter(shap_values[:, 'bill_length_mm'], color=shap_values[:, "bill_depth_mm"])
Now colour it by the most strongly interacting feature.
# create a dependence scatter plot to show the effect of a single feature across the whole dataset
shap.plots.scatter(shap_values[:, "bill_length_mm"], color=shap_values)
Now let’s iterate through and create scatter plots per column.
# dependence plots
fig, ax = plt.subplots(3, 3, figsize=(20,10))
ax = ax.ravel()

for idx, col_name in enumerate(feature_names):
    shap.plots.scatter(shap_values[:, col_name], show=False, ax=ax[idx])
18.3.2.3 Force Plots
Create a force plot for the whole dataset.
shap.plots.force(shap_values)
Create a force plot for five randomly chosen pieces of data.
shap.plots.force(shap_values[0])
shap.plots.force(shap_values[1])
shap.plots.force(shap_values[-1])
shap.plots.force(shap_values[185])
shap.plots.force(shap_values[247])
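The indices above are hard-coded. If you would rather pick the rows at random, and guarantee that each index actually exists, one option is to sample them with NumPy. This is a minimal sketch: `n_rows` is a hypothetical stand-in for the number of rows in the SHAP matrix, and the seed is arbitrary.

```python
import numpy as np

# HYPOTHETICAL row count; in the notebook this would be shap_values.values.shape[0]
n_rows = 69

# Seeded generator so the same five rows are picked on every run
rng = np.random.default_rng(42)

# Sample five distinct, valid row indices
sample_idx = rng.choice(n_rows, size=5, replace=False)

for i in sample_idx:
    # In the notebook, this is where you would call:
    #   shap.plots.force(shap_values[int(i)])
    print(int(i))
```

Sampling with `replace=False` avoids plotting the same row twice, and drawing from `range(n_rows)` rules out an out-of-bounds index.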