15  SHAP with XGBoost (Titanic Dataset)

import xgboost
import shap
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Import machine learning methods
from sklearn.model_selection import train_test_split
from xgboost import XGBClassifier
from lightgbm import LGBMClassifier
from sklearn.ensemble import RandomForestClassifier

# Import shap for shapley values
import shap

# Load SHAP's JavaScript bundle into the notebook. The interactive force
# plots later on report "Javascript library not loaded" without this call.
shap.initjs()
# Set to False to skip the network fetch once
# ./datasets/processed_data.csv has been saved locally.
download_required = True

if download_required:

    # Download processed data.
    # NOTE(review): the second half of this URL was missing from the file.
    # It has been restored as the upstream Titanic `processed_data.csv`
    # path -- confirm that it still resolves.
    address = 'https://raw.githubusercontent.com/MichaelAllen1966/' + \
        '1804_python_healthcare/master/titanic/data/processed_data.csv'

    data = pd.read_csv(address)

    # Create a data subfolder if one does not already exist.
    # exist_ok=True makes this a no-op when the folder is already there.
    import os
    data_directory = './datasets/'
    os.makedirs(data_directory, exist_ok=True)

    # Save data
    data.to_csv(data_directory + 'processed_data.csv', index=False)

# Load from the local copy, whether or not it was just downloaded
data = pd.read_csv('datasets/processed_data.csv')
# Convert every column to float
data = data.astype(float)

# Target: the `Survived` column
y = data['Survived']

# Features: every column except the target and `PassengerId`
X = data.drop(columns=['Survived', 'PassengerId'])

# Hold out part of the data for testing.
# NOTE(review): the arguments after `X` were missing from the file.
# test_size=0.25 and random_state=42 are restored on the assumption that
# they match the 223-row test set displayed later and the seed used for the
# model below -- confirm against the original notebook.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, random_state=42)

# Fit an XGBoost classifier; the fixed seed keeps runs repeatable
model = XGBClassifier(random_state=42)
model.fit(X_train, y_train)

# Class labels predicted for the training and test sets
y_pred_train = model.predict(X_train)
y_pred_test = model.predict(X_test)

# Class probabilities predicted for the training and test sets
y_prob_train = model.predict_proba(X_train)
y_prob_test = model.predict_proba(X_test)

# Accuracy: the fraction of predicted labels that equal the true label
accuracy_train = (y_pred_train == y_train).mean()
accuracy_test = (y_pred_test == y_test).mean()

print(f'Accuracy of predicting training data = {accuracy_train:.2%}')
print(f'Accuracy of predicting test data = {accuracy_test:.2%}')
Accuracy of predicting training data = 97.31%
Accuracy of predicting test data = 80.72%
# Explain the model's predictions using SHAP.
# X_train is passed as the explainer's second argument (the background data,
# per the SHAP docs), and the explainer is then applied to X_test.
# No `model_output` is requested here; the notebook later refers to these
# values as log odds when comparing them with the probability version.
explainer = shap.Explainer(model, X_train)
shap_values = explainer(X_test)

.values =
array([[-0.41472028, -0.82758431,  0.08631781, ...,  0.        ,
         0.        ,  0.        ],
       [ 0.34368675,  0.01027041,  0.20749824, ...,  0.        ,
         0.        ,  0.        ],
       [-0.49629094, -0.24858944, -0.07239207, ...,  0.        ,
         0.        ,  0.        ],
       [-0.45505765, -0.95988004,  0.14636154, ...,  0.        ,
         0.        ,  0.        ],
       [ 1.38846514, -0.87314281,  0.06420199, ...,  0.        ,
         0.        ,  0.        ],
       [-0.51810456, -1.0059387 ,  0.18350499, ...,  0.        ,
         0.        ,  0.        ]])

.base_values =
array([-0.81725838, -0.81725838, -0.81725838, -0.81725838, -0.81725838,
       -0.81725838, -0.81725838, -0.81725838, -0.81725838, -0.81725838,
       -0.81725838, -0.81725838, -0.81725838, -0.81725838, -0.81725838,
       -0.81725838, -0.81725838, -0.81725838, -0.81725838, -0.81725838,
       -0.81725838, -0.81725838, -0.81725838, -0.81725838, -0.81725838,
       -0.81725838, -0.81725838, -0.81725838, -0.81725838, -0.81725838,
       -0.81725838, -0.81725838, -0.81725838, -0.81725838, -0.81725838,
       -0.81725838, -0.81725838, -0.81725838, -0.81725838, -0.81725838,
       -0.81725838, -0.81725838, -0.81725838, -0.81725838, -0.81725838,
       -0.81725838, -0.81725838, -0.81725838, -0.81725838, -0.81725838,
       -0.81725838, -0.81725838, -0.81725838, -0.81725838, -0.81725838,
       -0.81725838, -0.81725838, -0.81725838, -0.81725838, -0.81725838,
       -0.81725838, -0.81725838, -0.81725838, -0.81725838, -0.81725838,
       -0.81725838, -0.81725838, -0.81725838, -0.81725838, -0.81725838,
       -0.81725838, -0.81725838, -0.81725838, -0.81725838, -0.81725838,
       -0.81725838, -0.81725838, -0.81725838, -0.81725838, -0.81725838,
       -0.81725838, -0.81725838, -0.81725838, -0.81725838, -0.81725838,
       -0.81725838, -0.81725838, -0.81725838, -0.81725838, -0.81725838,
       -0.81725838, -0.81725838, -0.81725838, -0.81725838, -0.81725838,
       -0.81725838, -0.81725838, -0.81725838, -0.81725838, -0.81725838,
       -0.81725838, -0.81725838, -0.81725838, -0.81725838, -0.81725838,
       -0.81725838, -0.81725838, -0.81725838, -0.81725838, -0.81725838,
       -0.81725838, -0.81725838, -0.81725838, -0.81725838, -0.81725838,
       -0.81725838, -0.81725838, -0.81725838, -0.81725838, -0.81725838,
       -0.81725838, -0.81725838, -0.81725838, -0.81725838, -0.81725838,
       -0.81725838, -0.81725838, -0.81725838, -0.81725838, -0.81725838,
       -0.81725838, -0.81725838, -0.81725838, -0.81725838, -0.81725838,
       -0.81725838, -0.81725838, -0.81725838, -0.81725838, -0.81725838,
       -0.81725838, -0.81725838, -0.81725838, -0.81725838, -0.81725838,
       -0.81725838, -0.81725838, -0.81725838, -0.81725838, -0.81725838,
       -0.81725838, -0.81725838, -0.81725838, -0.81725838, -0.81725838,
       -0.81725838, -0.81725838, -0.81725838, -0.81725838, -0.81725838,
       -0.81725838, -0.81725838, -0.81725838, -0.81725838, -0.81725838,
       -0.81725838, -0.81725838, -0.81725838, -0.81725838, -0.81725838,
       -0.81725838, -0.81725838, -0.81725838, -0.81725838, -0.81725838,
       -0.81725838, -0.81725838, -0.81725838, -0.81725838, -0.81725838,
       -0.81725838, -0.81725838, -0.81725838, -0.81725838, -0.81725838,
       -0.81725838, -0.81725838, -0.81725838, -0.81725838, -0.81725838,
       -0.81725838, -0.81725838, -0.81725838, -0.81725838, -0.81725838,
       -0.81725838, -0.81725838, -0.81725838, -0.81725838, -0.81725838,
       -0.81725838, -0.81725838, -0.81725838, -0.81725838, -0.81725838,
       -0.81725838, -0.81725838, -0.81725838, -0.81725838, -0.81725838,
       -0.81725838, -0.81725838, -0.81725838, -0.81725838, -0.81725838,
       -0.81725838, -0.81725838, -0.81725838, -0.81725838, -0.81725838,
       -0.81725838, -0.81725838, -0.81725838])

.data =
array([[ 3., 28.,  1., ...,  0.,  0.,  1.],
       [ 2., 31.,  0., ...,  0.,  0.,  1.],
       [ 3., 20.,  0., ...,  0.,  0.,  1.],
       [ 3., 28.,  0., ...,  0.,  0.,  1.],
       [ 2., 24.,  0., ...,  0.,  0.,  1.],
       [ 3., 18.,  1., ...,  0.,  0.,  1.]])
# The SHAP values as a plain NumPy array (one row per explained row of
# X_test, one column per feature), used to build the importance table
shap_values_numeric = shap_values.values
array([[-0.41472028, -0.82758431,  0.08631781, ...,  0.        ,
         0.        ,  0.        ],
       [ 0.34368675,  0.01027041,  0.20749824, ...,  0.        ,
         0.        ,  0.        ],
       [-0.49629094, -0.24858944, -0.07239207, ...,  0.        ,
         0.        ,  0.        ],
       [-0.45505765, -0.95988004,  0.14636154, ...,  0.        ,
         0.        ,  0.        ],
       [ 1.38846514, -0.87314281,  0.06420199, ...,  0.        ,
         0.        ,  0.        ],
       [-0.51810456, -1.0059387 ,  0.18350499, ...,  0.        ,
         0.        ,  0.        ]])

16 SHAP - importance table

# Calculate mean SHAP value for each feature. The SHAP values were computed
# for X_test, so these are averages over the test set.
importances = pd.DataFrame()
importances['features'] = X.columns.tolist()
importances['mean_shap_values'] = np.mean(shap_values_numeric, axis=0)

# Calculate mean absolute SHAP value for each feature.
# This gives the average importance of each feature: taking the magnitude
# first stops positive and negative contributions cancelling each other out.
importances['mean_abs_shap_values'] = np.mean(
    np.abs(shap_values_numeric), axis=0)

# Rank the features; rank 1 has the largest mean absolute SHAP value
importances['rank_shap'] = importances['mean_abs_shap_values'].rank(
    ascending=False).values

# Show the five highest-ranked features
importances.sort_values('rank_shap').head()
features mean_shap_values mean_abs_shap_values rank_shap
10 male 0.047098 1.872651 1.0
0 Pclass 0.268006 1.072309 2.0
4 Fare 0.084978 0.914949 3.0
1 Age -0.325160 0.846433 4.0
8 CabinNumber 0.126974 0.344576 5.0

17 SHAP Plots


# visualize the first prediction's explanation
shap.plots.waterfall(shap_values[0])

# Values of a few selected features for the rows of the test set
X_test[['male', 'Embarked_C', 'Age', 'Pclass']]
male Embarked_C Age Pclass
709 1.0 1.0 28.0 3.0
439 1.0 0.0 31.0 2.0
840 1.0 0.0 20.0 3.0
720 0.0 0.0 6.0 2.0
39 0.0 1.0 14.0 3.0
... ... ... ... ...
880 0.0 0.0 25.0 2.0
425 1.0 0.0 28.0 3.0
101 1.0 0.0 28.0 3.0
199 0.0 0.0 24.0 2.0
424 1.0 0.0 18.0 3.0

223 rows × 4 columns

# visualize another prediction's explanation
shap.plots.waterfall(shap_values[7])

Find examples with high or low probabilities

# Pick the test example with the highest predicted probability of class 1.
# Column 1 of predict_proba is ranked rather than the 0/1 labels in
# y_pred_test: the labels tie, so sorting them cannot identify the most
# confident prediction.
# NOTE: any recorded output below that shows a 0/1 label predates this change.
highest_prob = pd.Series(y_prob_test[:, 1]).sort_values(ascending=False).head(1)
69    1
dtype: int32
# Index label of the selected example. The Series above is built without an
# explicit index, so this is its row position in the test-set predictions.
high_prob_index = highest_prob.index[0]

# Pick the test example with the lowest predicted probability of class 1.
# Column 1 of predict_proba is ranked rather than the 0/1 labels in
# y_pred_test: the labels tie, so sorting them cannot identify the least
# likely prediction.
# NOTE: any recorded output below that shows a 0/1 label predates this change.
low_prob = pd.Series(y_prob_test[:, 1]).sort_values(ascending=False).tail(1)
222    0
dtype: int32
# Index label of the selected example. The Series above is built without an
# explicit index, so this is its row position in the test-set predictions.
low_prob_index = low_prob.index[0]

17.1 Force plots

# visualize the first prediction's explanation with a force plot
shap.plots.force(shap_values[0])
Visualization omitted, Javascript library not loaded!
Have you run `initjs()` in this notebook? If this notebook was from another user you must also trust this notebook (File -> Trust notebook). If you are viewing this notebook on github the Javascript has been stripped for security. If you are using JupyterLab this error is because a JupyterLab extension has not yet been written.
# visualize all the predictions
shap.plots.force(shap_values)
Visualization omitted, Javascript library not loaded!
Have you run `initjs()` in this notebook? If this notebook was from another user you must also trust this notebook (File -> Trust notebook). If you are viewing this notebook on github the Javascript has been stripped for security. If you are using JupyterLab this error is because a JupyterLab extension has not yet been written.

17.2 Dependence Plots

# Dependence scatter plots: each shows the effect of a single feature across
# the whole explained dataset.
age_effect = shap_values[:, "Age"]

# Age on its own
shap.plots.scatter(age_effect)

# Age again, coloured by each of a few chosen features in turn
for colour_feature in ("male", "Pclass", "Fare"):
    shap.plots.scatter(age_effect, color=shap_values[:, colour_feature])

# Passing the whole Explanation as `color` leaves the choice of colouring
# feature to SHAP. Each feature is plotted once.
for feature in ("Age", "Fare", "male"):
    shap.plots.scatter(shap_values[:, feature], color=shap_values)

17.3 Beeswarm

# summarize the effects of all the features as a beeswarm plot
shap.plots.beeswarm(shap_values)

17.4 Violin

# summarize the effects of all the features as a violin plot
shap.plots.violin(shap_values)

17.4.1 Bar: Cohorts

# Label every explained row "Women" or "Men" from its `male` feature value
# (0 -> "Women", anything else -> "Men"), for use as cohort labels
sex = np.where(shap_values[:, "male"].data == 0, "Women", "Men").tolist()

Plot the bars for an individual.


18 Decision


18.0.1 Decision plot for individual

    explainer.shap_values(X_test)[0], # one way of specifying the record to look at

    shap_values.values[121], # another way of specifying the record to look at


18.1 SHAP: Probability Alternative

Recalculate the SHAP values as changes in probability instead of log odds.

# Explain the model's predictions using SHAP, this time requesting
# model_output="probability" so that the values are changes in predicted
# probability rather than log odds. X_train is again passed as the
# explainer's second argument, and X_test is the data being explained.
explainer_probability = shap.Explainer(model, X_train, model_output="probability")
shap_values_probability = explainer_probability(X_test)

.values =
array([[-0.05238777, -0.13358663,  0.01009056, ...,  0.        ,
         0.        ,  0.        ],
       [ 0.0511844 , -0.03143588,  0.0286387 , ...,  0.        ,
         0.        ,  0.        ],
       [-0.05579166, -0.05355721, -0.00460973, ...,  0.        ,
         0.        ,  0.        ],
       [-0.0458467 , -0.10278892,  0.01064102, ...,  0.        ,
         0.        ,  0.        ],
       [ 0.14912559, -0.06840275,  0.00767626, ...,  0.        ,
         0.        ,  0.        ],
       [-0.04991312, -0.10697621,  0.00952425, ...,  0.        ,
         0.        ,  0.        ]])

.base_values =
array([0.38372332, 0.38372332, 0.38372332, 0.38372332, 0.38372332,
       0.38372332, 0.38372332, 0.38372332, 0.38372332, 0.38372332,
       0.38372332, 0.38372332, 0.38372332, 0.38372332, 0.38372332,
       0.38372332, 0.38372332, 0.38372332, 0.38372332, 0.38372332,
       0.38372332, 0.38372332, 0.38372332, 0.38372332, 0.38372332,
       0.38372332, 0.38372332, 0.38372332, 0.38372332, 0.38372332,
       0.38372332, 0.38372332, 0.38372332, 0.38372332, 0.38372332,
       0.38372332, 0.38372332, 0.38372332, 0.38372332, 0.38372332,
       0.38372332, 0.38372332, 0.38372332, 0.38372332, 0.38372332,
       0.38372332, 0.38372332, 0.38372332, 0.38372332, 0.38372332,
       0.38372332, 0.38372332, 0.38372332, 0.38372332, 0.38372332,
       0.38372332, 0.38372332, 0.38372332, 0.38372332, 0.38372332,
       0.38372332, 0.38372332, 0.38372332, 0.38372332, 0.38372332,
       0.38372332, 0.38372332, 0.38372332, 0.38372332, 0.38372332,
       0.38372332, 0.38372332, 0.38372332, 0.38372332, 0.38372332,
       0.38372332, 0.38372332, 0.38372332, 0.38372332, 0.38372332,
       0.38372332, 0.38372332, 0.38372332, 0.38372332, 0.38372332,
       0.38372332, 0.38372332, 0.38372332, 0.38372332, 0.38372332,
       0.38372332, 0.38372332, 0.38372332, 0.38372332, 0.38372332,
       0.38372332, 0.38372332, 0.38372332, 0.38372332, 0.38372332,
       0.38372332, 0.38372332, 0.38372332, 0.38372332, 0.38372332,
       0.38372332, 0.38372332, 0.38372332, 0.38372332, 0.38372332,
       0.38372332, 0.38372332, 0.38372332, 0.38372332, 0.38372332,
       0.38372332, 0.38372332, 0.38372332, 0.38372332, 0.38372332,
       0.38372332, 0.38372332, 0.38372332, 0.38372332, 0.38372332,
       0.38372332, 0.38372332, 0.38372332, 0.38372332, 0.38372332,
       0.38372332, 0.38372332, 0.38372332, 0.38372332, 0.38372332,
       0.38372332, 0.38372332, 0.38372332, 0.38372332, 0.38372332,
       0.38372332, 0.38372332, 0.38372332, 0.38372332, 0.38372332,
       0.38372332, 0.38372332, 0.38372332, 0.38372332, 0.38372332,
       0.38372332, 0.38372332, 0.38372332, 0.38372332, 0.38372332,
       0.38372332, 0.38372332, 0.38372332, 0.38372332, 0.38372332,
       0.38372332, 0.38372332, 0.38372332, 0.38372332, 0.38372332,
       0.38372332, 0.38372332, 0.38372332, 0.38372332, 0.38372332,
       0.38372332, 0.38372332, 0.38372332, 0.38372332, 0.38372332,
       0.38372332, 0.38372332, 0.38372332, 0.38372332, 0.38372332,
       0.38372332, 0.38372332, 0.38372332, 0.38372332, 0.38372332,
       0.38372332, 0.38372332, 0.38372332, 0.38372332, 0.38372332,
       0.38372332, 0.38372332, 0.38372332, 0.38372332, 0.38372332,
       0.38372332, 0.38372332, 0.38372332, 0.38372332, 0.38372332,
       0.38372332, 0.38372332, 0.38372332, 0.38372332, 0.38372332,
       0.38372332, 0.38372332, 0.38372332, 0.38372332, 0.38372332,
       0.38372332, 0.38372332, 0.38372332, 0.38372332, 0.38372332,
       0.38372332, 0.38372332, 0.38372332, 0.38372332, 0.38372332,
       0.38372332, 0.38372332, 0.38372332])

.data =
array([[ 3., 28.,  1., ...,  0.,  0.,  1.],
       [ 2., 31.,  0., ...,  0.,  0.,  1.],
       [ 3., 20.,  0., ...,  0.,  0.,  1.],
       [ 3., 28.,  0., ...,  0.,  0.,  1.],
       [ 2., 24.,  0., ...,  0.,  0.,  1.],
       [ 3., 18.,  1., ...,  0.,  0.,  1.]])

18.1.1 Beeswarm Plot: Probability

shap.plots.beeswarm(shap_values_probability)

Comparison with log odds plot

# Stack the two beeswarm plots vertically so the log-odds and probability
# versions can be compared directly.
fig, (ax1, ax2) = plt.subplots(nrows=2, ncols=1, figsize=(10,20))

# Make the first axis current. show=False stops SHAP displaying the plot
# straight away, so a title can be added and the second panel drawn.
plt.sca(ax1)
shap.plots.beeswarm(shap_values, show=False)
plt.title("Log Odds")

# Change to the second axis
plt.sca(ax2)
shap.plots.beeswarm(shap_values_probability, show=False)
plt.title("Probability")  # title this panel too, matching the first

# Nothing has been displayed yet because of show=False
plt.show()

18.1.2 Waterfall Plot: Probability

If we pull out the predicted probabilities for this passenger, we can see that the predicted probability of class 0 (died) is 0.699, while the predicted probability of class 1 (survived) is 0.301.

0    0.69867
1    0.30133
Name: 56, dtype: float32

This matches what is now shown in the waterfall plot.


0    0.97308
1    0.02692
Name: 115, dtype: float32

0    0.000583
1    0.999417
Name: 195, dtype: float32
shap.plots.waterfall(shap_values_probability[195])

Comparison with log odds plot

# Two stacked axes for comparing the waterfall plots of example 56:
# log odds on the first axis, probability on the second.
fig, (ax1, ax2) = plt.subplots(nrows=2, ncols=1, figsize=(10,20))

plt.sca(ax1) # make the first axis the current one before plotting
shap.plots.waterfall(shap_values[56], show=False)
plt.title("Log Odds")
# Change to the second axis
plt.sca(ax2) # make the second axis the current one before plotting
shap.plots.waterfall(shap_values_probability[56], show=False)