import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
# Import machine learning methods
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.inspection import PartialDependenceDisplay, permutation_importance
# Import shap for shapley values
import shap
# Initialise JavaScript (needed for shap's interactive charts later on)
shap.initjs()
14 Explaining model predictions with PDPs, ICE plots, MDI, PFI and SHAP
Elliot Coyne and Sammi Rosser, HSMA Trainer
The notebook below is a modified version of one of Mike Allen's original Titanic notebooks.
Shapley values provide an estimate of how much any particular feature influences the model decision (prediction). When Shapley values are averaged they provide a measure of the overall influence of a feature.
Shapley values may be used across model types, and so provide a model-agnostic measure of a feature’s influence. This means that the influence of features may be compared across model types, and it allows black box models like neural networks to be explained, at least in part.
Here we will demonstrate Shapley values with random forests.
For more on Shapley values in general, see Christoph Molnar's excellent book chapter:
https://christophm.github.io/interpretable-ml-book/shapley.html
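To make the model-agnostic point concrete, here is a minimal, self-contained sketch (not part of the original notebook, and using the imports from the cell above) in which the same three shap calls are applied to a different model type on toy data; all names here are illustrative.

from sklearn.datasets import make_classification
from sklearn.ensemble import GradientBoostingClassifier

# Toy data and a different model type (illustrative only)
X_toy, y_toy = make_classification(n_samples=200, n_features=5, random_state=42)
gb_model = GradientBoostingClassifier(random_state=42).fit(X_toy, y_toy)

# The same pattern we use for the random forest below
toy_explainer = shap.Explainer(gb_model, X_toy)  # shap picks a suitable explainer
toy_shap_values = toy_explainer(X_toy)

# Averaging absolute Shapley values gives each feature's overall influence
print(np.abs(toy_shap_values.values).mean(axis=0))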
The shap package is installed if you have used the ml environment yaml file, but otherwise it may be installed with pip install shap.
More information on the shap library, including lots of useful examples, may be found at: https://shap.readthedocs.io/en/latest/index.html
Here we provide an example of using shap with Random Forests.
Shap values are returned in a slightly different way to logistic regression: there is a set of Shap values for each classification probability ('not survive', 'survive'), so we need slightly different syntax to access and use them.
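As a hedged preview of what that looks like (using the shap_values object we will create in section 14.3): the returned array is three-dimensional, with the class as the final axis.

# Sketch, assuming the `shap_values` object created in section 14.3:
# the array has shape (n_samples, n_features, n_classes)
print(shap_values.values.shape)                     # (223, 24, 2) for our test set
shap_values_survive = shap_values.values[:, :, 1]   # values for the 'survive' class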
14.1 Load data and fit model
14.1.1 Load modules
The required modules were imported in the code cell at the top of this notebook.
14.1.2 Load data
The section below downloads pre-processed data and saves it to a subfolder (relative to where this code is run). If the data has already been downloaded, that cell may be skipped.
Code that was used to pre-process the data ready for machine learning may be found at: https://github.com/MichaelAllen1966/1804_python_healthcare/blob/master/titanic/01_preprocessing.ipynb
download_required = True

if download_required:
    # Download processed data:
    address = 'https://raw.githubusercontent.com/MichaelAllen1966/' + \
        '1804_python_healthcare/master/titanic/data/processed_data.csv'

    data = pd.read_csv(address)

    # Create a data subfolder if one does not already exist
    import os
    data_directory = './datasets/'
    if not os.path.exists(data_directory):
        os.makedirs(data_directory)

    # Save data
    data.to_csv(data_directory + 'processed_data.csv', index=False)
data = pd.read_csv('datasets/processed_data.csv')
# Make all data 'float' type
data = data.astype(float)
14.1.3 Divide into X (features) and y (labels)
We will separate out our features (the data we use to make a prediction) from our label (what we are trying to predict). By convention our features are called X (usually upper case to denote multiple features), and the label (survived or not) is called y.
# Use `Survived` field as y, and drop for X
y = data['Survived']               # y = 'Survived' column from 'data'
X = data.drop('Survived', axis=1)  # X = all 'data' except the 'Survived' column
# Drop PassengerId
X.drop('PassengerId', axis=1, inplace=True)
14.1.4 Divide into training and test sets
When we test a machine learning model we should always test it on data that has not been used to train the model. We will use sklearn's train_test_split method to randomly split the data: 75% for training, and 25% for testing.
X_train, X_test, y_train, y_test = train_test_split(X,
                                                    y,
                                                    random_state=42,
                                                    test_size=0.25)
14.1.5 Fit Random Forest model
model = RandomForestClassifier(n_estimators=100,
                               n_jobs=-1,
                               class_weight='balanced',
                               random_state=42,
                               max_depth=7)

model.fit(X_train, y_train)
RandomForestClassifier(class_weight='balanced', max_depth=7, n_jobs=-1, random_state=42)
14.1.6 Predict values and get probabilities of survival
Now we can use the trained model to predict survival. We will test the accuracy of both the training and test data sets.
# Predict training and test set labels
y_pred_train = model.predict(X_train)
y_pred_test = model.predict(X_test)

# Predict probabilities of survival
y_prob_train = model.predict_proba(X_train)
y_prob_test = model.predict_proba(X_test)
14.1.7 Calculate accuracy
In this example we will measure accuracy simply as the proportion of passengers where we make the correct prediction. In a later notebook we will look at other measures of accuracy which explore false positives and false negatives in more detail.
accuracy_train = np.mean(y_pred_train == y_train)
accuracy_test = np.mean(y_pred_test == y_test)

print(f'Accuracy of predicting training data = {accuracy_train:.2%}')
print(f'Accuracy of predicting test data = {accuracy_test:.2%}')
Accuracy of predicting training data = 88.92%
Accuracy of predicting test data = 81.17%
14.2 Examining the model importances
As we have used a tree-based model, we can easily pull out the feature importances using the MDI (mean decrease in impurity) approach; they are stored in model.feature_importances_.
feature_names = X.columns.tolist()

feature_importances_mdi = model.feature_importances_

importances = pd.DataFrame(index=feature_names)
importances['importance_mdi'] = feature_importances_mdi
importances['rank'] = importances['importance_mdi'].rank(ascending=False).values

# View just the top 5
importances.sort_values('rank').head()
| | importance_mdi | rank |
| --- | --- | --- |
| male | 0.348332 | 1.0 |
| Fare | 0.148499 | 2.0 |
| Age | 0.116853 | 3.0 |
| Pclass | 0.074443 | 4.0 |
| CabinNumber | 0.064082 | 5.0 |
The three most influential features are:
- male
- Fare
- Age
Note: random forest importances do not tell us anything about the direction of effect of features (with random forests, the direction of effect may depend on the values of other features).
feature_importances_mdi = pd.Series(feature_importances_mdi, index=feature_names)

fig, ax = plt.subplots(figsize=(15,10))
feature_importances_mdi.plot.barh(ax=ax)
ax.set_title("Feature importances using MDI")
ax.set_xlabel("Mean decrease in impurity")
fig.tight_layout()
plt.show()
Note that because we’re using a random forest here, MDI is averaged across all the trees, so we can actually include error bars.
# Standard deviation of importance across the trees in the forest.
# Note: we avoid reassigning the name `importances`, which already holds
# our summary DataFrame and is added to again later.
std = np.std([tree.feature_importances_ for tree in model.estimators_], axis=0)

forest_importances = pd.Series(model.feature_importances_, index=feature_names)

fig, ax = plt.subplots(figsize=(15,10))
forest_importances.plot.barh(xerr=std, ax=ax)
ax.set_title("Feature importances using MDI")
ax.set_xlabel("Mean decrease in impurity")
fig.tight_layout()
Note that we could also use permutation feature importance (PFI) here as an alternative approach.
This works with tree-based models, but is actually model-agnostic.
Let's take a quick look at the output of that function.
result_pfi = permutation_importance(
    model, X_test, y_test, n_repeats=10, random_state=42, n_jobs=2
)

result_pfi
{'importances_mean': array([ 0.03183857, 0.02331839, 0.01076233, 0.00313901, 0.00089686,
0.00044843, 0. , -0.00493274, 0.00538117, -0.00044843,
0.17399103, -0.00134529, -0.00044843, 0.00134529, 0. ,
-0.00493274, 0.00358744, -0.00179372, 0.00044843, -0.00179372,
0. , 0. , 0. , -0.00179372]),
'importances_std': array([0.00735481, 0.00998702, 0.00538117, 0.00403587, 0.00892365,
0.00509319, 0. , 0.00372494, 0.00797147, 0.00678598,
0.01038192, 0.0063576 , 0.00134529, 0.00724462, 0. ,
0.00134529, 0.00179372, 0.00219685, 0.00134529, 0.00219685,
0. , 0. , 0. , 0.00410993]),
'importances': array([[ 0.02690583, 0.04035874, 0.02690583, 0.03139013, 0.03587444,
0.04484305, 0.03139013, 0.01793722, 0.02690583, 0.03587444],
[ 0.04035874, 0.02690583, 0.03587444, 0.00896861, 0.02690583,
0.02242152, 0.00896861, 0.02690583, 0.02242152, 0.01345291],
[ 0.00896861, 0.01345291, 0.00896861, 0.0044843 , 0.01345291,
0.01793722, 0.00896861, 0.01345291, 0.01793722, 0. ],
[ 0.0044843 , 0. , 0. , 0. , 0.0044843 ,
0.0044843 , 0. , 0.0044843 , 0. , 0.01345291],
[ 0.01345291, 0.0044843 , -0.00896861, 0. , 0.0044843 ,
-0.0044843 , 0. , 0.00896861, -0.01793722, 0.00896861],
[ 0. , -0.0044843 , 0. , -0.0044843 , 0.00896861,
0.00896861, 0. , -0.0044843 , 0.0044843 , -0.0044843 ],
[ 0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ],
[ 0. , -0.00896861, -0.0044843 , -0.0044843 , -0.0044843 ,
-0.0044843 , -0.0044843 , -0.0044843 , 0. , -0.01345291],
[ 0.00896861, 0.0044843 , 0.01345291, -0.0044843 , 0.01793722,
-0.00896861, 0.01345291, 0.0044843 , 0. , 0.0044843 ],
[ 0.0044843 , -0.0044843 , 0.0044843 , 0. , 0.00896861,
-0.00896861, 0.00896861, 0. , -0.00896861, -0.00896861],
[ 0.1838565 , 0.1793722 , 0.1838565 , 0.17488789, 0.1838565 ,
0.16591928, 0.1838565 , 0.16591928, 0.15246637, 0.16591928],
[ 0. , -0.0044843 , 0. , 0.0044843 , 0.0044843 ,
0. , -0.01793722, 0. , 0.0044843 , -0.0044843 ],
[ 0. , 0. , 0. , -0.0044843 , 0. ,
0. , 0. , 0. , 0. , 0. ],
[ 0. , -0.01345291, 0.00896861, 0.00896861, 0. ,
-0.0044843 , -0.0044843 , 0.00896861, 0.00896861, 0. ],
[ 0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ],
[-0.0044843 , -0.0044843 , -0.0044843 , -0.0044843 , -0.00896861,
-0.0044843 , -0.0044843 , -0.0044843 , -0.0044843 , -0.0044843 ],
[ 0.0044843 , 0.0044843 , 0.0044843 , 0. , 0. ,
0.0044843 , 0.0044843 , 0.0044843 , 0.0044843 , 0.0044843 ],
[ 0. , -0.0044843 , 0. , -0.0044843 , -0.0044843 ,
0. , -0.0044843 , 0. , 0. , 0. ],
[ 0. , 0. , 0. , 0. , 0. ,
0. , 0.0044843 , 0. , 0. , 0. ],
[-0.0044843 , -0.0044843 , 0. , 0. , -0.0044843 ,
0. , -0.0044843 , 0. , 0. , 0. ],
[ 0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ],
[ 0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ],
[ 0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ],
[ 0.0044843 , -0.0044843 , -0.0044843 , 0. , -0.0044843 ,
0. , -0.0044843 , 0. , 0.0044843 , -0.00896861]])}
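The result is a dictionary-like Bunch: importances holds one value per feature per shuffle repeat, while importances_mean and importances_std are its row-wise summaries. A quick way to confirm the layout:

# One row per feature, one column per permutation repeat
print(result_pfi.importances.shape)   # (24, 10): 24 features x 10 repeats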
Now let’s plot the output.
feature_importances_pfi = pd.Series(result_pfi.importances_mean, index=feature_names)

fig, ax = plt.subplots(figsize=(15,10))
feature_importances_pfi.plot.barh(xerr=result_pfi.importances_std, ax=ax)
ax.set_title("Feature importances using permutation on full model")
ax.set_xlabel("Mean accuracy decrease")
fig.tight_layout()
plt.show()
Let’s add this to our table too.
importances['importance_pfi'] = feature_importances_pfi
importances['rank_pfi'] = importances['importance_pfi'].rank(ascending=False).values
importances.sort_values('rank_pfi').head()
| | importance_mdi | rank | importance_pfi | rank_pfi |
| --- | --- | --- | --- | --- |
| male | 0.348332 | 1.0 | 0.173991 | 1.0 |
| Pclass | 0.074443 | 4.0 | 0.031839 | 2.0 |
| Age | 0.116853 | 3.0 | 0.023318 | 3.0 |
| SibSp | 0.049939 | 6.0 | 0.010762 | 4.0 |
| CabinNumber | 0.064082 | 5.0 | 0.005381 | 5.0 |
14.3 Get Shapley values
First we need to create a shap explainer object.
explainer = shap.Explainer(model, X_train)
shap_values = explainer(X_test)
shap_values_numeric = shap_values.values
Look at the explainer object.
explainer
<shap.explainers._tree.TreeExplainer at 0x1c91bcf1b90>
Look at the shap_values variable.
type(shap_values)
shap._explanation.Explanation
shap_values
.values =
array([[[ 0.01060336, -0.01060336],
[ 0.02870872, -0.02870872],
[-0.00961641, 0.00961641],
...,
[-0.00037315, 0.00037315],
[ 0. , 0. ],
[ 0.00809963, -0.00809963]],
[[-0.01520491, 0.01520491],
[ 0.02459478, -0.02459477],
[-0.00241162, 0.00241162],
...,
[-0.0004267 , 0.0004267 ],
[ 0. , 0. ],
[ 0.00473834, -0.00473834]],
[[ 0.01319181, -0.01319181],
[ 0.01232224, -0.01232224],
[-0.00262561, 0.00262561],
...,
[-0.00049012, 0.00049012],
[ 0. , 0. ],
[-0.00154064, 0.00154064]],
...,
[[ 0.01926633, -0.01926633],
[ 0.0221672 , -0.0221672 ],
[-0.00363846, 0.00363846],
...,
[-0.0004621 , 0.00046211],
[ 0. , 0. ],
[ 0.00050243, -0.00050243]],
[[-0.07960656, 0.07960655],
[ 0.01078217, -0.01078217],
[-0.01453268, 0.01453268],
...,
[-0.00054181, 0.00054181],
[ 0. , 0. ],
[ 0.0018481 , -0.0018481 ]],
[[ 0.02524583, -0.02524583],
[ 0.00986923, -0.00986923],
[-0.02022015, 0.02022015],
...,
[-0.00034948, 0.00034948],
[ 0. , 0. ],
[ 0.0036434 , -0.0036434 ]]])
.base_values =
array([[0.5630537, 0.4369463],
       [0.5630537, 0.4369463],
       [0.5630537, 0.4369463],
       ...,
       [0.5630537, 0.4369463],
       [0.5630537, 0.4369463],
       [0.5630537, 0.4369463]])
.data =
array([[ 3., 28., 1., ..., 0., 0., 1.],
[ 2., 31., 0., ..., 0., 0., 1.],
[ 3., 20., 0., ..., 0., 0., 1.],
...,
[ 3., 28., 0., ..., 0., 0., 1.],
[ 2., 24., 0., ..., 0., 0., 1.],
[ 3., 18., 1., ..., 0., 0., 1.]])
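Two hedged checks (not in the original notebook) can help make sense of this output. First, base_values is the same pair for every row because it is the model's average prediction over the background data (X_train here). Second, SHAP's local-accuracy property means the base value plus the sum of a row's Shap values should recover the model's prediction for that row, up to small numerical error.

# Base values should be (approximately) the mean predicted probabilities
# over the background data passed to the explainer
print(model.predict_proba(X_train).mean(axis=0))   # ~ [0.563, 0.437]

# Local accuracy: base value + sum of Shap values ~ prediction for each row
reconstructed = shap_values.base_values + shap_values.values.sum(axis=1)
print(np.allclose(reconstructed, model.predict_proba(X_test), atol=1e-3))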
Look at the shap_values_numeric variable.
shap_values_numeric
array([[[ 0.01060336, -0.01060336],
[ 0.02870872, -0.02870872],
[-0.00961641, 0.00961641],
...,
[-0.00037315, 0.00037315],
[ 0. , 0. ],
[ 0.00809963, -0.00809963]],
[[-0.01520491, 0.01520491],
[ 0.02459478, -0.02459477],
[-0.00241162, 0.00241162],
...,
[-0.0004267 , 0.0004267 ],
[ 0. , 0. ],
[ 0.00473834, -0.00473834]],
[[ 0.01319181, -0.01319181],
[ 0.01232224, -0.01232224],
[-0.00262561, 0.00262561],
...,
[-0.00049012, 0.00049012],
[ 0. , 0. ],
[-0.00154064, 0.00154064]],
...,
[[ 0.01926633, -0.01926633],
[ 0.0221672 , -0.0221672 ],
[-0.00363846, 0.00363846],
...,
[-0.0004621 , 0.00046211],
[ 0. , 0. ],
[ 0.00050243, -0.00050243]],
[[-0.07960656, 0.07960655],
[ 0.01078217, -0.01078217],
[-0.01453268, 0.01453268],
...,
[-0.00054181, 0.00054181],
[ 0. , 0. ],
[ 0.0018481 , -0.0018481 ]],
[[ 0.02524583, -0.02524583],
[ 0.00986923, -0.00986923],
[-0.02022015, 0.02022015],
...,
[-0.00034948, 0.00034948],
[ 0. , 0. ],
[ 0.0036434 , -0.0036434 ]]])
# Random forests seem to give us a slightly different output format,
# so we adjust the line below to just bring back the results from the
# positive class
shap.plots.bar(shap_values[:, :, 1])
Add Shap values to the importances table.
# Calculate mean Shap value for each feature in the test set
importances['mean_shap_values'] = np.mean(shap_values_numeric[:,:,1], axis=0)

# Calculate mean absolute Shap value for each feature in the test set
# This will give us the average importance of each feature
importances['mean_abs_shap_values'] = np.mean(
    np.abs(shap_values_numeric[:,:,1]), axis=0)

importances['rank_shap'] = importances['mean_abs_shap_values'].rank(ascending=False).values
importances.sort_values('rank_shap').head()
| | importance_mdi | rank | importance_pfi | rank_pfi | mean_shap_values | mean_abs_shap_values | rank_shap |
| --- | --- | --- | --- | --- | --- | --- | --- |
| male | 0.348332 | 1.0 | 0.173991 | 1.0 | 0.009733 | 0.163071 | 1.0 |
| Pclass | 0.074443 | 4.0 | 0.031839 | 2.0 | 0.011423 | 0.043382 | 2.0 |
| Fare | 0.148499 | 2.0 | 0.000897 | 9.0 | 0.001480 | 0.036203 | 3.0 |
| Age | 0.116853 | 3.0 | 0.023318 | 3.0 | -0.013086 | 0.028130 | 4.0 |
| CabinNumber | 0.064082 | 5.0 | 0.005381 | 5.0 | 0.005358 | 0.023901 | 5.0 |
Get the top 10 influential features for each method (MDI, PFI and SHAP).
# Get top 10 features
mdi_importance_top_10 = \
    importances.sort_values(by='importance_mdi', ascending=False).head(10).index

pfi_importance_top_10 = \
    importances.sort_values(by='importance_pfi', ascending=False).head(10).index

shapley_top_10 = \
    importances.sort_values(
        by='mean_abs_shap_values', ascending=False).head(10).index

# Add to DataFrame
top_10_features = pd.DataFrame()
top_10_features['importances_mdi'] = mdi_importance_top_10.values
top_10_features['importances_pfi'] = pfi_importance_top_10.values
top_10_features['Shap'] = shapley_top_10.values

# Display
top_10_features
| | importances_mdi | importances_pfi | Shap |
| --- | --- | --- | --- |
| 0 | male | male | male |
| 1 | Fare | Pclass | Pclass |
| 2 | Age | Age | Fare |
| 3 | Pclass | SibSp | Age |
| 4 | CabinNumber | CabinNumber | CabinNumber |
| 5 | SibSp | CabinLetter_B | CabinNumberImputed |
| 6 | Parch | Parch | Embarked_S |
| 7 | CabinNumberImputed | Embarked_S | AgeImputed |
| 8 | CabinLetter_missing | Fare | Embarked_C |
| 9 | CabinLetterImputed | AgeImputed | SibSp |
Let's quickly compare our Shap top 10 with the associated bar plot.
We can see a lot of overlap between the most important features as estimated by MDI and PFI and those estimated using mean absolute Shapley values, but they are not identical.
Plot a comparison of mean absolute Shapley values against MDI importances:
fig = plt.figure(figsize=(6,6))
ax = fig.add_subplot(111)

# Plot points
x = importances['importance_mdi']
y = importances['mean_abs_shap_values']

ax.scatter(x, y)
ax.set_title('Mean absolute Shap value vs MDI importance for each feature')
ax.set_ylabel('Mean absolute Shap value')
ax.set_xlabel('Feature importance (MDI)')

plt.grid()
plt.show()
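As a quick, hedged follow-up to the scatter plot (not in the original notebook), we can quantify the agreement between the two measures with a rank correlation:

# Spearman rank correlation between MDI importances and mean |Shap| values
print(importances['importance_mdi'].corr(
    importances['mean_abs_shap_values'], method='spearman'))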
15 Partial Dependence Plots (PDPs) and Individual Conditional Expectation (ICE) Plots
A partial dependence plot shows the average model prediction as a function of one feature, with all other features left as they are; an ICE plot draws one such line per individual row instead of averaging.
15.0.1 A single PDP
fig, ax = plt.subplots(figsize=(4, 5))  # Create an empty plot, specifying size

display = PartialDependenceDisplay.from_estimator(
    model,             # Your fitted model
    X_train,           # Your feature matrix
    features=['Age'],  # List of features to plot
    kind='average',    # Type of plot
    ax=ax,             # axis to plot on
    random_state=42    # avoidance of randomness
)

plt.show()
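To see what kind='average' is doing under the hood, here is a minimal manual sketch (not the original notebook's code; it assumes the model and X_train defined above): for each candidate age, set every passenger's Age to that value and average the predicted probability of survival.

# Manual partial dependence for 'Age' (illustrative sketch)
ages = np.linspace(X_train['Age'].min(), X_train['Age'].max(), 20)
mean_preds = []
for age in ages:
    X_modified = X_train.copy()
    X_modified['Age'] = age   # fix Age at this value for every passenger
    mean_preds.append(model.predict_proba(X_modified)[:, 1].mean())

plt.plot(ages, mean_preds)
plt.xlabel('Age')
plt.ylabel('Mean predicted probability of survival')
plt.show()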
15.0.2 Two PDPs side-by-side
fig, ax = plt.subplots(figsize=(10, 6))
display = PartialDependenceDisplay.from_estimator(
    model,                       # Your fitted model
    X_train,                     # Your feature matrix
    features=['Age', 'Pclass'],  # List of features to plot
    kind='average',              # Type of PDP
    ax=ax,
    random_state=42
)
plt.show()
15.0.3 A single ICE plot
fig, ax = plt.subplots(figsize=(4, 5))  # Create an empty plot, specifying size

display = PartialDependenceDisplay.from_estimator(
    model,              # Your fitted model
    X_train,            # Your feature matrix
    features=['Age'],   # List of features to plot
    kind='individual',  # Type of plot
    ax=ax,              # axis to plot on
    random_state=42     # avoidance of randomness
)
plt.show()
An ICE plot for a subsample of people
fig, ax = plt.subplots(figsize=(10, 6))
display = PartialDependenceDisplay.from_estimator(
    model,               # Your fitted model
    X_train.sample(20),  # Your feature matrix (a random sample of 20 passengers)
    features=['Age'],    # List of features to plot
    kind='individual',   # Type of PDP
    ax=ax,
    random_state=42
)
plt.show()
The plot below shows an alternative way to subsample to just a subset of the data.
fig, ax = plt.subplots(figsize=(4, 5))  # Create an empty plot, specifying size

display = PartialDependenceDisplay.from_estimator(
    model,              # Your fitted model
    X_train,            # Your feature matrix
    subsample=0.1,      # Proportion of data to subsample down to
    features=['Age'],   # List of features to plot
    kind='individual',  # Type of plot
    ax=ax,              # axis to plot on
    random_state=42     # avoidance of randomness
)
plt.show()
We can also overlay ICE plots for different subgroups on the same axes; below we colour male and female passengers separately and add a custom legend.
fig, ax = plt.subplots(figsize=(8, 7))  # Create an empty plot, specifying size

display = PartialDependenceDisplay.from_estimator(
    model,                          # Your fitted model
    X_train[X_train['male'] == 1],  # Your feature matrix (male passengers only)
    features=['Age'],               # List of features to plot
    kind='both',                    # Type of plot
    ax=ax,                          # axis to plot on
    random_state=42,                # avoidance of randomness
    ice_lines_kw={'color': 'red'},
    pd_line_kw={'color': 'maroon'}
)

PartialDependenceDisplay.from_estimator(
    model,                          # Your fitted model
    X_train[X_train['male'] == 0],  # Your feature matrix (female passengers only)
    features=['Age'],               # List of features to plot
    kind='both',                    # Type of plot
    ax=display.axes_,               # axis to plot on
    random_state=42,                # avoidance of randomness
    ice_lines_kw={'color': 'purple'},
    pd_line_kw={'color': 'darkslateblue'}
)

# Create custom legend handles and labels
legend_elements = [
    plt.Line2D([0], [0], color='purple', lw=2, label='Female'),
    plt.Line2D([0], [0], color='red', lw=2, label='Male')
]

# Add the legend
plt.legend(handles=legend_elements, loc="upper right")
plt.show()
15.0.4 Two ICE plots side-by-side
fig, ax = plt.subplots(figsize=(10, 6))
display = PartialDependenceDisplay.from_estimator(
    model,                       # Your fitted model
    X_train,                     # Your feature matrix
    features=['Age', 'Pclass'],  # List of features to plot
    kind='individual',           # Type of PDP
    ax=ax,
    random_state=42
)
plt.show()
15.0.5 Joint PDP/ICE Plot
fig, ax = plt.subplots(figsize=(10, 6))
display = PartialDependenceDisplay.from_estimator(
    model,                       # Your fitted model
    X_train,                     # Your feature matrix
    features=['Age', 'Pclass'],  # List of features to plot
    kind='both',                 # Type of PDP
    ax=ax,
    random_state=42
)
plt.show()
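One optional variant worth knowing about (assuming scikit-learn 1.1 or later; this is not part of the original notebook): ICE lines that start at different levels can be hard to compare, and centered=True anchors every line at its value for the first grid point.

fig, ax = plt.subplots(figsize=(10, 6))
display = PartialDependenceDisplay.from_estimator(
    model,
    X_train,
    features=['Age'],
    kind='both',
    centered=True,    # anchor each ICE line (and the PDP) at the first grid value
    ax=ax,
    random_state=42
)
plt.show()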
15.0.6 Bonus: 2D PDP!
fig, ax = plt.subplots(figsize=(10, 6))
display = PartialDependenceDisplay.from_estimator(
    model,                       # Your fitted model
    X_train,                     # Your feature matrix
    features=[('Age', 'male')],  # List of features to plot
    kind='average',              # Type of PDP
    ax=ax,
    random_state=42
)
plt.show()
fig, ax = plt.subplots(figsize=(10, 6))
display = PartialDependenceDisplay.from_estimator(
    model,                       # Your fitted model
    X_train,                     # Your feature matrix
    features=[('Age', 'Fare')],  # List of features to plot
    kind='average',              # Type of PDP
    ax=ax,
    random_state=42
)
plt.show()