import shap
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import pandas as pd
from xgboost import XGBRegressor
# Import machine learning methods
from sklearn.model_selection import train_test_split
from sklearn import datasets
from sklearn.metrics import mean_absolute_error, mean_absolute_percentage_error, \
r2_score, root_mean_squared_error
# Import shap for shapley values
import shap
# Load SHAP's JavaScript into the notebook; the interactive force plots
# later on will not render without it
shap.initjs()
16 SHAP with regression trees (Diabetes Progression Dataset)
From the SKLearn documentation:
Number of Instances: 442
Number of Attributes: First 10 columns are numeric predictive values
Target: Column 11 is a quantitative measure of disease progression one year after baseline
Attribute Information:
age age in years
sex
bmi body mass index
bp average blood pressure
s1 tc, total serum cholesterol
s2 ldl, low-density lipoproteins
s3 hdl, high-density lipoproteins
s4 tch, total cholesterol / HDL
s5 ltg, possibly log of serum triglycerides level
s6 glu, blood sugar level
https://scikit-learn.org/stable/datasets/toy_dataset.html#diabetes-dataset
Note: By default, each of these 10 feature variables is mean centered and scaled by the standard deviation times the square root of n_samples (i.e. the sum of squares of each column totals 1). Here we load the data with `scaled=False`, so the features keep their original units, as the table below shows. Scaling isn’t necessary given we’ve opted to use a tree model, and the unscaled values are easier to interpret in the plots.
# Load the diabetes dataset as a pandas DataFrame (features) and Series
# (target). scaled=False keeps the features in their original units.
X, y = datasets.load_diabetes(return_X_y=True, as_frame=True, scaled=False)
X.head()
age | sex | bmi | bp | s1 | s2 | s3 | s4 | s5 | s6 | |
---|---|---|---|---|---|---|---|---|---|---|
0 | 59.0 | 2.0 | 32.1 | 101.0 | 157.0 | 93.2 | 38.0 | 4.0 | 4.8598 | 87.0 |
1 | 48.0 | 1.0 | 21.6 | 87.0 | 183.0 | 103.2 | 70.0 | 3.0 | 3.8918 | 69.0 |
2 | 72.0 | 2.0 | 30.5 | 93.0 | 156.0 | 93.6 | 41.0 | 4.0 | 4.6728 | 85.0 |
3 | 24.0 | 1.0 | 25.3 | 84.0 | 198.0 | 131.4 | 40.0 | 5.0 | 4.8903 | 89.0 |
4 | 50.0 | 1.0 | 23.0 | 101.0 | 192.0 | 125.4 | 52.0 | 4.0 | 4.2905 | 80.0 |
# Rename the terse serum columns to clearer names based on the data dictionary
X = X.rename(
    columns={
        "s1": "total_serum_cholesterol",
        "s2": "ldl_cholesterol",
        "s3": "hdl_cholesterol",
        # s4 (tch) is the total cholesterol / HDL ratio, not a raw total
        "s4": "total_cholesterol_hdl_ratio",
        "s5": "serum_triglycerides_log",
        "s6": "blood_sugar",
    }
)
y.head()
0 151.0
1 75.0
2 141.0
3 206.0
4 135.0
Name: target, dtype: float64
# Split the data into training/testing sets: 25% is held out for testing and
# the seed is fixed so the split is the same on every run
diabetes_X_train, diabetes_X_test, diabetes_y_train, diabetes_y_test = train_test_split(
    X,
    y,
    test_size=0.25,
    random_state=42,
)

model = XGBRegressor(random_state=42)

# Train the model using the training sets
model.fit(diabetes_X_train, diabetes_y_train)

# Make predictions using the testing set
diabetes_y_pred = model.predict(diabetes_X_test)
# Report test-set error metrics (f-strings throughout for consistency)
print(f"Mean absolute error: {mean_absolute_error(diabetes_y_test, diabetes_y_pred):.2f}")
print(f"Mean absolute percentage error: {mean_absolute_percentage_error(diabetes_y_test, diabetes_y_pred):.2%}")
print(f"Root Mean squared error: {root_mean_squared_error(diabetes_y_test, diabetes_y_pred):.2f}")
# The coefficient of determination: 1 is perfect prediction
print(f"Coefficient of determination: {r2_score(diabetes_y_test, diabetes_y_pred):.2f}")
Mean absolute error: 46.59
Mean absolute percentage error: 39.84%
Root Mean squared error: 58.01
Coefficient of determination: 0.39
# explain the model's test-set predictions using SHAP
explainer = shap.Explainer(model, diabetes_X_train)
shap_values = explainer(diabetes_X_test)
# display the resulting Explanation object (.values, .base_values, .data)
shap_values
.values =
array([[ 10.93827235, 2.8060772 , -1.00666967, ..., -4.88568904,
12.19191076, 3.60045137],
[ -3.45617191, 3.4392566 , 4.56441772, ..., 0.37978809,
-13.22084821, 13.30567497],
[ 10.71647202, -1.89627643, -4.11494224, ..., -1.75186097,
31.18730879, -0.99276986],
...,
[ 2.63871242, 4.10990574, 21.90833779, ..., 0.12410758,
33.50004507, 4.06646104],
[ -0.19391851, 7.1562762 , 21.04501401, ..., 0.28448169,
-46.11341172, 8.68782261],
[ 8.6260937 , -5.57821297, -20.22131193, ..., -3.46062915,
6.32432399, -1.93388196]])
.base_values =
array([167.10061748, 167.10061748, 167.10061748, 167.10061748,
167.10061748, 167.10061748, 167.10061748, 167.10061748,
167.10061748, 167.10061748, 167.10061748, 167.10061748,
167.10061748, 167.10061748, 167.10061748, 167.10061748,
167.10061748, 167.10061748, 167.10061748, 167.10061748,
167.10061748, 167.10061748, 167.10061748, 167.10061748,
167.10061748, 167.10061748, 167.10061748, 167.10061748,
167.10061748, 167.10061748, 167.10061748, 167.10061748,
167.10061748, 167.10061748, 167.10061748, 167.10061748,
167.10061748, 167.10061748, 167.10061748, 167.10061748,
167.10061748, 167.10061748, 167.10061748, 167.10061748,
167.10061748, 167.10061748, 167.10061748, 167.10061748,
167.10061748, 167.10061748, 167.10061748, 167.10061748,
167.10061748, 167.10061748, 167.10061748, 167.10061748,
167.10061748, 167.10061748, 167.10061748, 167.10061748,
167.10061748, 167.10061748, 167.10061748, 167.10061748,
167.10061748, 167.10061748, 167.10061748, 167.10061748,
167.10061748, 167.10061748, 167.10061748, 167.10061748,
167.10061748, 167.10061748, 167.10061748, 167.10061748,
167.10061748, 167.10061748, 167.10061748, 167.10061748,
167.10061748, 167.10061748, 167.10061748, 167.10061748,
167.10061748, 167.10061748, 167.10061748, 167.10061748,
167.10061748, 167.10061748, 167.10061748, 167.10061748,
167.10061748, 167.10061748, 167.10061748, 167.10061748,
167.10061748, 167.10061748, 167.10061748, 167.10061748,
167.10061748, 167.10061748, 167.10061748, 167.10061748,
167.10061748, 167.10061748, 167.10061748, 167.10061748,
167.10061748, 167.10061748, 167.10061748])
.data =
array([[ 61. , 1. , 25.8 , ..., 5. , 4.9972, 90. ],
[ 74. , 1. , 29.8 , ..., 3. , 4.3944, 86. ],
[ 66. , 2. , 26. , ..., 4. , 5.5683, 87. ],
...,
[ 55. , 1. , 28.2 , ..., 4. , 5.366 , 103. ],
[ 53. , 1. , 28.8 , ..., 3.15 , 4.0775, 85. ],
[ 67. , 2. , 23. , ..., 5. , 4.654 , 99. ]])
# Keep the raw array of SHAP values for plotting functions that take a
# plain array rather than an Explanation object
shap_values_numeric = shap_values.values
shap_values_numeric
array([[ 10.93827235, 2.8060772 , -1.00666967, ..., -4.88568904,
12.19191076, 3.60045137],
[ -3.45617191, 3.4392566 , 4.56441772, ..., 0.37978809,
-13.22084821, 13.30567497],
[ 10.71647202, -1.89627643, -4.11494224, ..., -1.75186097,
31.18730879, -0.99276986],
...,
[ 2.63871242, 4.10990574, 21.90833779, ..., 0.12410758,
33.50004507, 4.06646104],
[ -0.19391851, 7.1562762 , 21.04501401, ..., 0.28448169,
-46.11341172, 8.68782261],
[ 8.6260937 , -5.57821297, -20.22131193, ..., -3.46062915,
6.32432399, -1.93388196]])
17 Plots
# visualize the first prediction's explanation
shap.plots.waterfall(shap_values[0])
# visualize a later prediction's explanation
shap.plots.waterfall(shap_values[7])
17.1 Force plots
# visualize the first prediction's explanation with a force plot
shap.plots.force(shap_values[0])
Have you run `initjs()` in this notebook? If this notebook was from another user you must also trust this notebook (File -> Trust notebook). If you are viewing this notebook on github the Javascript has been stripped for security. If you are using JupyterLab this error is because a JupyterLab extension has not yet been written.
# visualize the explanations of all the test-set predictions in a single
# interactive force plot (needs the JavaScript loaded by shap.initjs())
shap.plots.force(shap_values)
Have you run `initjs()` in this notebook? If this notebook was from another user you must also trust this notebook (File -> Trust notebook). If you are viewing this notebook on github the Javascript has been stripped for security. If you are using JupyterLab this error is because a JupyterLab extension has not yet been written.
17.2 Dependence Plots
# dependence scatter plot showing the effect of a single feature (age)
# across all the explained rows
shap.plots.scatter(shap_values[:, "age"])
# same for bmi, passing the whole Explanation as the colour source
# (presumably shap then picks the colouring feature itself - see its docs)
shap.plots.scatter(shap_values[:, "bmi"], color=shap_values)
# and for blood_sugar
shap.plots.scatter(shap_values[:, "blood_sugar"], color=shap_values)
# choose the colouring feature explicitly: age coloured by bmi, and vice versa
shap.plots.scatter(shap_values[:, "age"], color=shap_values[:, "bmi"])
shap.plots.scatter(shap_values[:, "bmi"], color=shap_values[:, "age"])
shap.plots.scatter(shap_values[:, "age"], color=shap_values[:, "bmi"])
17.3 Beeswarm
# summarize the effects of all the features across the explained rows
# as a beeswarm plot
shap.plots.beeswarm(shap_values)
17.4 Violin
# summarize the effects of all the features across the explained rows
# as a violin plot
shap.plots.violin(shap_values)
17.5 Bar
# bar plot summarising the SHAP values across all the explained rows
shap.plots.bar(shap_values)
17.5.1 Splitting by cohorts
# Label every explained row by whether its bmi value is below 30, then plot
# the mean absolute SHAP value per feature separately for each cohort
bmi_category = ["<30" if bmi < 30 else ">=30" for bmi in shap_values[:, "bmi"].data]
shap.plots.bar(shap_values.cohorts(bmi_category).abs.mean(0))
17.5.2 Automatic cohort splitting
# Ask shap to split the explained rows into 2 cohorts automatically
shap.plots.bar(shap_values.cohorts(2).abs.mean(0))
Plot the bars for an individual.
# bar plot of the SHAP values for the first explained row only
shap.plots.bar(shap_values[0])
17.6 Heatmap
# heatmap of the SHAP values across all the explained rows
shap.plots.heatmap(shap_values)
17.7 Decision
# decision plot for every test-set prediction, starting from the
# explainer's expected value
shap.plots.decision(
    explainer.expected_value,
    explainer.shap_values(diabetes_X_test),
    feature_names=X.columns.tolist(),
)
17.7.1 Explanation plot for individual
# decision plot for a single test-set row (the first one)
shap.plots.decision(
    explainer.expected_value,
    explainer.shap_values(diabetes_X_test)[0],
    feature_names=X.columns.tolist(),
)
# and for another individual (test-set row 104)
shap.plots.decision(
    explainer.expected_value,
    explainer.shap_values(diabetes_X_test)[104],
    feature_names=X.columns.tolist(),
)
17.8 Group Difference
# Boolean mask over the explained rows: True where bmi is 30 or above
bmi_category_obese = np.asarray(shap_values[:, "bmi"].data) >= 30
# Compare the SHAP values of the two bmi groups
shap.plots.group_difference(
    shap_values_numeric,
    bmi_category_obese,
    feature_names=X.columns.tolist(),
)
17.9 Modifying plots
# Explain row 125 of the full dataset X (not only the test split).
# show=False stops shap from displaying the figure straight away, so it can
# still be modified with ordinary matplotlib calls afterwards.
shap.plots.decision(
    explainer.expected_value,
    explainer.shap_values(X)[125],
    feature_names=X.columns.tolist(),
    show=False,  ## NEW
)
# Grab the figure shap drew on, then add a title to its current axes
fig = plt.gcf()  ## NEW
plt.title("Here is my added title")  ## NEW
17.10 Subplots
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(15, 10))  ## NEW

# Compute the test-set SHAP values and feature names once; both plots reuse them
test_shap_values = explainer.shap_values(diabetes_X_test)
feature_names = X.columns.tolist()

# sca is 'set current axis'
# ensures next plot is put onto the axis we specify here - our first
# of the two subplots
plt.sca(ax1)  ## NEW
shap.plots.decision(
    explainer.expected_value,
    test_shap_values[104],
    feature_names=feature_names,
    show=False,  ## NEW
    # keep the figsize requested above: by default the decision plot resizes
    # the current figure to fit the number of features shown
    auto_size_plot=False,
)

# Change to the second axis
plt.sca(ax2)  ## NEW
shap.plots.decision(
    explainer.expected_value,
    test_shap_values[15],
    feature_names=feature_names,
    show=False,  ## NEW
    auto_size_plot=False,
)
# note: without auto_size_plot=False the figsize parameter appears to be
# partly ignored, because shap.plots.decision resizes the figure itself