import xgboost
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Import machine learning methods
from sklearn.model_selection import train_test_split
from xgboost import XGBClassifier
from lightgbm import LGBMClassifier
from sklearn.ensemble import RandomForestClassifier

# Import shap for Shapley values
import shap

# Load the JavaScript needed for the interactive charts later on
shap.initjs()
15 SHAP with XGBoost (Titanic Dataset)
download_required = True

if download_required:
    # Download processed data:
    address = 'https://raw.githubusercontent.com/MichaelAllen1966/' + \
        '1804_python_healthcare/master/titanic/data/processed_data.csv'
    data = pd.read_csv(address)

    # Create a data subfolder if one does not already exist
    import os
    data_directory = './datasets/'
    if not os.path.exists(data_directory):
        os.makedirs(data_directory)

    # Save data
    data.to_csv(data_directory + 'processed_data.csv', index=False)

data = pd.read_csv('datasets/processed_data.csv')

# Make all data 'float' type
data = data.astype(float)
# Use `Survived` field as y, and drop it from X
y = data['Survived']               # y = 'Survived' column from 'data'
X = data.drop('Survived', axis=1)  # X = all 'data' except the 'Survived' column

# Drop PassengerId
X.drop('PassengerId', axis=1, inplace=True)

X_train, X_test, y_train, y_test = train_test_split(X,
                                                    y,
                                                    random_state=42,
                                                    test_size=0.25)
model = XGBClassifier(random_state=42)
model.fit(X_train, y_train)

# Predict training and test set labels
y_pred_train = model.predict(X_train)
y_pred_test = model.predict(X_test)

# Predict probabilities of survival
y_prob_train = model.predict_proba(X_train)
y_prob_test = model.predict_proba(X_test)

accuracy_train = np.mean(y_pred_train == y_train)
accuracy_test = np.mean(y_pred_test == y_test)

print(f'Accuracy of predicting training data = {accuracy_train:.2%}')
print(f'Accuracy of predicting test data = {accuracy_test:.2%}')
Accuracy of predicting training data = 97.31%
Accuracy of predicting test data = 80.72%
# explain the model's predictions using SHAP
explainer = shap.Explainer(model, X_train)
shap_values = explainer(X_test)

shap_values
.values =
array([[-0.41472028, -0.82758431, 0.08631781, ..., 0. ,
0. , 0. ],
[ 0.34368675, 0.01027041, 0.20749824, ..., 0. ,
0. , 0. ],
[-0.49629094, -0.24858944, -0.07239207, ..., 0. ,
0. , 0. ],
...,
[-0.45505765, -0.95988004, 0.14636154, ..., 0. ,
0. , 0. ],
[ 1.38846514, -0.87314281, 0.06420199, ..., 0. ,
0. , 0. ],
[-0.51810456, -1.0059387 , 0.18350499, ..., 0. ,
0. , 0. ]])
.base_values =
array([-0.81725838, -0.81725838, -0.81725838, ..., -0.81725838,
       -0.81725838, -0.81725838])
.data =
array([[ 3., 28., 1., ..., 0., 0., 1.],
[ 2., 31., 0., ..., 0., 0., 1.],
[ 3., 20., 0., ..., 0., 0., 1.],
...,
[ 3., 28., 0., ..., 0., 0., 1.],
[ 2., 24., 0., ..., 0., 0., 1.],
[ 3., 18., 1., ..., 0., 0., 1.]])
shap_values_numeric = shap_values.values
shap_values_numeric
array([[-0.41472028, -0.82758431, 0.08631781, ..., 0. ,
0. , 0. ],
[ 0.34368675, 0.01027041, 0.20749824, ..., 0. ,
0. , 0. ],
[-0.49629094, -0.24858944, -0.07239207, ..., 0. ,
0. , 0. ],
...,
[-0.45505765, -0.95988004, 0.14636154, ..., 0. ,
0. , 0. ],
[ 1.38846514, -0.87314281, 0.06420199, ..., 0. ,
0. , 0. ],
[-0.51810456, -1.0059387 , 0.18350499, ..., 0. ,
0. , 0. ]])
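As a quick sanity check (a minimal sketch, assuming the explainer above returns contributions on the log-odds scale and that XGBoost's `output_margin` option is available), the SHAP values for a row plus the base value should approximately reconstruct the model's raw margin for that row:

# Sketch: verify SHAP additivity for the first test-set row
raw_margin = model.predict(X_test, output_margin=True)  # log-odds output from XGBoost
reconstructed = shap_values.base_values[0] + shap_values.values[0].sum()
print(raw_margin[0], reconstructed)  # these should agree closely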
16 SHAP - importance table
# Calculate mean SHAP value for each feature across the test set
importances = pd.DataFrame()
importances['features'] = X.columns.tolist()
importances['mean_shap_values'] = np.mean(shap_values_numeric, axis=0)

# Calculate mean absolute SHAP value for each feature
# This gives the average importance of each feature
importances['mean_abs_shap_values'] = np.mean(
    np.abs(shap_values_numeric), axis=0)

importances['rank_shap'] = importances['mean_abs_shap_values'].rank(ascending=False).values
importances.sort_values('rank_shap').head()
|    | features    | mean_shap_values | mean_abs_shap_values | rank_shap |
|----|-------------|------------------|----------------------|-----------|
| 10 | male        | 0.047098         | 1.872651             | 1.0       |
| 0  | Pclass      | 0.268006         | 1.072309             | 2.0       |
| 4  | Fare        | 0.084978         | 0.914949             | 3.0       |
| 1  | Age         | -0.325160        | 0.846433             | 4.0       |
| 8  | CabinNumber | 0.126974         | 0.344576             | 5.0       |
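For reference, the ranking above can also be plotted directly from the importance table (a minimal sketch using matplotlib; `shap.plots.bar` in the next section produces an equivalent mean(|SHAP|) chart automatically):

# Sketch: plot the mean absolute SHAP values computed above as a horizontal bar chart
(importances
     .sort_values('mean_abs_shap_values')
     .plot.barh(x='features', y='mean_abs_shap_values', legend=False))
plt.xlabel('mean(|SHAP value|)')
plt.show()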
17 SHAP Plots
shap.plots.bar(shap_values)
# visualize the first prediction's explanation
shap.plots.waterfall(shap_values[0])
X_test[['male', 'Embarked_C', 'Age', 'Pclass']]

|     | male | Embarked_C | Age  | Pclass |
|-----|------|------------|------|--------|
| 709 | 1.0  | 1.0        | 28.0 | 3.0    |
| 439 | 1.0  | 0.0        | 31.0 | 2.0    |
| 840 | 1.0  | 0.0        | 20.0 | 3.0    |
| 720 | 0.0  | 0.0        | 6.0  | 2.0    |
| 39  | 0.0  | 1.0        | 14.0 | 3.0    |
| ... | ...  | ...        | ...  | ...    |
| 880 | 0.0  | 0.0        | 25.0 | 2.0    |
| 425 | 1.0  | 0.0        | 28.0 | 3.0    |
| 101 | 1.0  | 0.0        | 28.0 | 3.0    |
| 199 | 0.0  | 0.0        | 24.0 | 2.0    |
| 424 | 1.0  | 0.0        | 18.0 | 3.0    |

223 rows × 4 columns
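Note that `shap_values` is indexed by position (0, 1, 2, ...), while `X_test` keeps the original DataFrame index (709, 439, ...). A minimal sketch to recover the original passenger index behind a given waterfall plot:

# Sketch: map the positional index used by shap_values back to the original DataFrame index
position = 0                         # the row plotted in the first waterfall above
original_index = X_test.index[position]
print(original_index)
print(X_test.iloc[position][['male', 'Embarked_C', 'Age', 'Pclass']].to_dict())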
# visualize another prediction's explanation
shap.plots.waterfall(shap_values[7])
17.0.0.1 Find examples with high or low probabilities
highest_prob = pd.Series(y_pred_test).sort_values(ascending=False).head(1)
highest_prob

69    1
dtype: int32

high_prob_index = highest_prob.index[0]
shap.plots.waterfall(shap_values[high_prob_index])

low_prob = pd.Series(y_pred_test).sort_values(ascending=False).tail(1)
low_prob

222    0
dtype: int32

low_prob_index = low_prob.index[0]
shap.plots.waterfall(shap_values[low_prob_index])
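The two cells above sort the predicted class labels (0/1), so within each class the "highest" and "lowest" rows are essentially arbitrary. A minimal sketch (assuming `y_prob_test` from the earlier `predict_proba` call) that sorts by the predicted probability of survival instead:

# Sketch: find the test-set rows with the highest and lowest predicted probability of survival
survival_prob = pd.Series(y_prob_test[:, 1])   # column 1 = probability of class 1 (survived)
high_prob_index = survival_prob.idxmax()
low_prob_index = survival_prob.idxmin()
shap.plots.waterfall(shap_values[high_prob_index])
shap.plots.waterfall(shap_values[low_prob_index])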
17.1 Force plots
# visualize the first prediction's explanation with a force plot
shap.plots.force(shap_values[0])
Visualization omitted, Javascript library not loaded!
Have you run `initjs()` in this notebook? If this notebook was from another user you must also trust this notebook (File -> Trust notebook). If you are viewing this notebook on github the Javascript has been stripped for security. If you are using JupyterLab this error is because a JupyterLab extension has not yet been written.
# visualize all the predictions
shap.plots.force(shap_values)
Visualization omitted, Javascript library not loaded!
Have you run `initjs()` in this notebook? If this notebook was from another user you must also trust this notebook (File -> Trust notebook). If you are viewing this notebook on github the Javascript has been stripped for security. If you are using JupyterLab this error is because a JupyterLab extension has not yet been written.
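When the JavaScript library cannot be rendered inline (as in the static output above), the interactive force plot can instead be written to a standalone HTML file with `shap.save_html`. A minimal sketch (the filename is an arbitrary choice for this example):

# Sketch: save the interactive force plot to a standalone HTML file
force_plot = shap.plots.force(shap_values)
shap.save_html('force_plot.html', force_plot)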
17.2 Dependence Plots
# create a dependence scatter plot to show the effect of a single feature across the whole dataset
shap.plots.scatter(shap_values[:, "Age"])

shap.plots.scatter(shap_values[:, "Age"], color=shap_values[:, "male"])

shap.plots.scatter(shap_values[:, "Age"], color=shap_values[:, "Pclass"])

shap.plots.scatter(shap_values[:, "Age"], color=shap_values[:, "Fare"])

# Passing the full Explanation object as `color` lets SHAP colour by the strongest interacting feature
shap.plots.scatter(shap_values[:, "Age"], color=shap_values)

shap.plots.scatter(shap_values[:, "Fare"], color=shap_values)

shap.plots.scatter(shap_values[:, "male"], color=shap_values)
"male"], color=shap_values) shap.plots.scatter(shap_values[:,
17.3 Beeswarm
# summarize the effects of all the features
shap.plots.beeswarm(shap_values)
17.4 Violin
# summarize the effects of all the features
shap.plots.violin(shap_values)
17.4.1 Bar: Cohorts
= ["Women" if shap_values[i, "male"].data == 0 else "Men" for i in range(shap_values.shape[0])]
sex abs.mean(0)) shap.plots.bar(shap_values.cohorts(sex).
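The same cohort mechanism works for any grouping of the rows. A minimal sketch that splits the test set by age instead (the cutoff of 16 is an arbitrary choice for illustration):

# Sketch: cohort the SHAP values by an illustrative age threshold
age_group = ["Child" if shap_values[i, "Age"].data < 16 else "Adult"
             for i in range(shap_values.shape[0])]
shap.plots.bar(shap_values.cohorts(age_group).abs.mean(0))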
Plot the bars for an individual.
shap.plots.bar(shap_values[1])
18 Decision
shap.plots.decision(
    explainer.expected_value,
    shap_values.values,
    feature_names=X.columns.tolist()
)
18.0.1 Decision plot for individual
shap.plots.decision(
    explainer.expected_value,
    explainer.shap_values(X_test)[0],    # one way of specifying the record to look at
    feature_names=X.columns.tolist()
)

shap.plots.decision(
    explainer.expected_value,
    shap_values.values[121],             # another way of specifying the record to look at
    feature_names=X.columns.tolist()
)

shap.plots.decision(
    explainer.expected_value,
    explainer.shap_values(X_test)[215],
    feature_names=X.columns.tolist()
)
18.1 SHAP: Probability Alternative
Recalculate the SHAP values as changes in probability instead of log odds.
# explain the model's predictions using SHAP
explainer_probability = shap.Explainer(model, X_train, model_output="probability")
shap_values_probability = explainer_probability(X_test)

shap_values_probability
.values =
array([[-0.05238777, -0.13358663, 0.01009056, ..., 0. ,
0. , 0. ],
[ 0.0511844 , -0.03143588, 0.0286387 , ..., 0. ,
0. , 0. ],
[-0.05579166, -0.05355721, -0.00460973, ..., 0. ,
0. , 0. ],
...,
[-0.0458467 , -0.10278892, 0.01064102, ..., 0. ,
0. , 0. ],
[ 0.14912559, -0.06840275, 0.00767626, ..., 0. ,
0. , 0. ],
[-0.04991312, -0.10697621, 0.00952425, ..., 0. ,
0. , 0. ]])
.base_values =
array([0.38372332, 0.38372332, 0.38372332, ..., 0.38372332,
       0.38372332, 0.38372332])
.data =
array([[ 3., 28., 1., ..., 0., 0., 1.],
[ 2., 31., 0., ..., 0., 0., 1.],
[ 3., 20., 0., ..., 0., 0., 1.],
...,
[ 3., 28., 0., ..., 0., 0., 1.],
[ 2., 24., 0., ..., 0., 0., 1.],
[ 3., 18., 1., ..., 0., 0., 1.]])
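As a quick check (a minimal sketch; the row index is an arbitrary choice), on the probability scale the base value plus the sum of a row's SHAP values should approximately reproduce the model's predicted probability of survival for that row:

# Sketch: verify additivity of the probability-scale SHAP values for one test-set row
row = 0
reconstructed = (shap_values_probability.base_values[row]
                 + shap_values_probability.values[row].sum())
print(reconstructed, model.predict_proba(X_test)[row, 1])   # should agree closely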
18.1.1 Beeswarm Plot: Probability
shap.plots.beeswarm(shap_values_probability)
18.1.1.1 Comparison with log odds plot
fig, (ax1, ax2) = plt.subplots(nrows=2, ncols=1, figsize=(10, 20))

## NEW
plt.sca(ax1)
shap.plots.beeswarm(shap_values, show=False)
plt.title("Log Odds")

# Change to the second axis
## NEW
plt.sca(ax2)
shap.plots.beeswarm(shap_values_probability, show=False)
plt.title("Probability")

plt.tight_layout()
plt.show()
18.1.2 Waterfall Plot: Probability
If we pull out the predicted probabilities for the passenger at position 56 in the test set, the predicted probability of class 0 (died) is 0.699, while the predicted probability of survival (class 1) is 0.301.

pd.DataFrame(model.predict_proba(X_test)).reset_index(drop=True).iloc[56]

0    0.69867
1    0.30133
Name: 56, dtype: float32
This matches what is now shown in the waterfall plot.
shap.plots.waterfall(shap_values_probability[56])
pd.DataFrame(model.predict_proba(X_test)).reset_index(drop=True).iloc[115]

0    0.97308
1    0.02692
Name: 115, dtype: float32
shap.plots.waterfall(shap_values_probability[115])

pd.DataFrame(model.predict_proba(X_test)).reset_index(drop=True).iloc[195]

0    0.000583
1    0.999417
Name: 195, dtype: float32
shap.plots.waterfall(shap_values_probability[195])
18.1.2.1 Comparison with log odds plot
fig, (ax1, ax2) = plt.subplots(nrows=2, ncols=1, figsize=(10, 20))

## NEW
plt.sca(ax1)
shap.plots.waterfall(shap_values[56], show=False)
plt.title("Log Odds")

# Change to the second axis
## NEW
plt.sca(ax2)
shap.plots.waterfall(shap_values_probability[56], show=False)
plt.title("Probability")

plt.tight_layout()
plt.show()