19 Exercise Solution: Explainable AI (Stroke Thrombolysis Dataset)
The data loaded in this exercise is for seven acute stroke units, and whether a patient receives clot-busting treatment for stroke. There are lots of features; a description of each can be found in the file stroke_data_feature_descriptions.csv.
Train a decision tree model to try to predict whether or not a stroke patient receives clot-busting treatment. Use the prompts below to write each section of code.
Run the code below to import the dataset.
import pandas as pd
import numpy as np

# Download data
# (not required if running locally and data has previously been downloaded)
download_required = True

if download_required:
    # Download processed data
    address = 'https://raw.githubusercontent.com/MichaelAllen1966/' + \
        '2004_titanic/master/jupyter_notebooks/data/hsma_stroke.csv'
    data = pd.read_csv(address)

    # Create a data subfolder if one does not already exist
    import os
    data_directory = './data/'
    if not os.path.exists(data_directory):
        os.makedirs(data_directory)

    # Save data to data subfolder
    data.to_csv(data_directory + 'hsma_stroke.csv', index=False)

# Load data
data = pd.read_csv('data/hsma_stroke.csv')

# Make all data 'float' type
data = data.astype(float)
Import the libraries you may need! What was used last time?
What additional libraries might you need to import to try out the different boosted trees?
# Import machine learning methods
from xgboost.sklearn import XGBClassifier
from sklearn.model_selection import train_test_split
from sklearn.tree import plot_tree
import plotly.express as px
import matplotlib.pyplot as plt
from sklearn.metrics import auc, roc_curve, RocCurveDisplay, f1_score, precision_score, \
    recall_score, confusion_matrix, ConfusionMatrixDisplay, classification_report

# Additional imports for explainable AI
from sklearn.inspection import PartialDependenceDisplay, permutation_importance

# Import shap for Shapley values
import shap

# Initialise JavaScript - important for the interactive SHAP charts later on
shap.initjs()
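If you want to experiment with other boosted tree implementations (as the prompt above suggests), LightGBM and CatBoost are common alternatives. A minimal sketch, assuming those packages are installed; they are not used in the rest of this solution:
# Optional: other boosted-tree classifiers to experiment with
# (assumes lightgbm and catboost are installed - skip otherwise)
try:
    from lightgbm import LGBMClassifier
    from catboost import CatBoostClassifier
except ImportError:
    print('lightgbm / catboost not installed - continuing with XGBoost only')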
# X = all 'data' except the 'Clotbuster given' column
X = data.drop('Clotbuster given', axis=1)
# y = 'Clotbuster given' column from 'data'
y = data['Clotbuster given']

# Split into training and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=42)

# Define and fit an XGBoost model
model = XGBClassifier(random_state=42)
model = model.fit(X_train, y_train)

# Predict training and test set labels
y_pred_train = model.predict(X_train)
y_pred_test = model.predict(X_test)

# Calculate accuracy on each set
accuracy_train = np.mean(y_pred_train == y_train)
accuracy_test = np.mean(y_pred_test == y_test)

print (f'Accuracy of predicting training data = {accuracy_train}')
print (f'Accuracy of predicting test data = {accuracy_test}')
Accuracy of predicting training data = 0.9978510028653295
Accuracy of predicting test data = 0.8154506437768241
20 Feature Importance
20.0.1 MDI
features = list(X_train)

feature_importances_xgb = model.feature_importances_
importances_xgb = pd.DataFrame(index=features)
importances_xgb['importance_xgb'] = feature_importances_xgb
importances_xgb['rank_xgb'] = importances_xgb['importance_xgb'].rank(ascending=False).values
importances_xgb.sort_values('rank_xgb').head()
 | importance_xgb | rank_xgb |
---|---|---|
Stroke Type_I | 0.292050 | 1.0 |
Onset Time Known Type_BE | 0.060007 | 2.0 |
S2NihssArrival | 0.047182 | 3.0 |
Anticoag before stroke_1 | 0.042693 | 4.0 |
S2RankinBeforeStroke | 0.038013 | 5.0 |
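A bar chart of the same importances can be easier to scan than the table. A minimal sketch using the importances_xgb DataFrame built above:
# Plot the ten highest-ranked MDI importances
top_mdi = importances_xgb.sort_values('importance_xgb', ascending=False).head(10)
fig, ax = plt.subplots(figsize=(8, 5))
top_mdi['importance_xgb'].plot.bar(ax=ax)
ax.set_title('Top 10 features by MDI importance')
ax.set_ylabel('Importance')
fig.tight_layout()
plt.show()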
20.0.2 PFI
feature_names = X.columns.tolist()

result_dt_pfi = permutation_importance(
    model, X_test, y_test, n_repeats=10, random_state=42, n_jobs=2
)

importances_pfi_dt = pd.Series(result_dt_pfi.importances_mean, index=feature_names)

fig, ax = plt.subplots(figsize=(15, 10))
importances_pfi_dt.plot.bar(yerr=result_dt_pfi.importances_std, ax=ax)
ax.set_title("Feature importances using permutation on full model")
ax.set_ylabel("Mean accuracy decrease")
fig.tight_layout()
plt.show()
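The same results can also be tabulated for direct comparison with the MDI ranking. A minimal sketch using the result_dt_pfi object from above:
# Tabulate permutation importances with their spread across repeats
pfi_summary = pd.DataFrame({
    'mean_decrease': result_dt_pfi.importances_mean,
    'std': result_dt_pfi.importances_std,
}, index=feature_names)
pfi_summary.sort_values('mean_decrease', ascending=False).head()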
21 PDP + ICE
fig, ax = plt.subplots(figsize=(10, 6))
display = PartialDependenceDisplay.from_estimator(
    model,             # Your fitted model
    X_train,           # Your feature matrix
    features=['Age'],  # List of features to plot
    target=0,
    kind='individual', # ICE: one line per individual sample
    ax=ax,
    random_state=42
)
plt.show()
fig, ax = plt.subplots(figsize=(10, 6))
display = PartialDependenceDisplay.from_estimator(
    model,             # Your fitted model
    X_train,           # Your feature matrix
    features=['Age'],  # List of features to plot
    target=0,
    kind='both',       # ICE lines plus the averaged PDP line
    ax=ax,
    random_state=42
)
plt.show()
fig, ax = plt.subplots(figsize=(10, 6))
display = PartialDependenceDisplay.from_estimator(
    model,             # Your fitted model
    X_train,           # Your feature matrix
    features=['Age'],  # List of features to plot
    target=0,
    kind='average',    # Averaged PDP line only
    ax=ax,
    random_state=42
)
plt.show()
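PartialDependenceDisplay can also show two-way interactions if you pass a tuple of feature names; two-way plots only support the averaged view. A sketch, with the feature pair chosen purely for illustration:
fig, ax = plt.subplots(figsize=(10, 6))
display = PartialDependenceDisplay.from_estimator(
    model,
    X_train,
    features=[('Age', 'S2NihssArrival')],  # tuple of two features -> interaction plot
    kind='average',                        # two-way PDPs must use the averaged view
    ax=ax
)
plt.show()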
22 SHAP
# Explain the model's predictions using SHAP
explainer = shap.Explainer(model, X_train)
shap_values = explainer(X_test)
shap_values
.values =
array([[ 0.57191953, -0.00599206, -0.02803201, ..., 0.53735721,
-0.11320076, -0.18555162],
[ 0.01019792, -0.01376853, -0.20353487, ..., 0.0578289 ,
0.08007584, -0.10086989],
[-0.06371776, -0.05613441, 0.00877733, ..., -0.40052104,
-0.15933883, -0.24292226],
...,
[-0.00602321, 0.17367847, 0.0584153 , ..., -0.00764367,
0.18520108, -0.16023535],
[-0.02695623, -0.01631852, 0.035347 , ..., 0.75401062,
0.12778105, -0.06493782],
[-0.08913676, -0.06648799, 0.06200275, ..., -0.47741059,
-0.15630577, -0.15140597]])
.base_values =
array([-1.97564877, -1.97564877, -1.97564877, ..., -1.97564877,
       -1.97564877, -1.97564877])
.data =
array([[1., 0., 0., ..., 2., 0., 0.],
[0., 0., 1., ..., 1., 0., 0.],
[0., 0., 1., ..., 0., 0., 0.],
...,
[0., 1., 0., ..., 1., 1., 0.],
[0., 0., 0., ..., 2., 1., 1.],
[0., 0., 0., ..., 0., 0., 0.]])
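Note that every row shares the same base value: it is the model's average output over the background data. For XGBoost the explainer works on the raw log-odds output by default, so (on that assumption) a sigmoid transform recovers the baseline probability:
# Convert the log-odds base value to a baseline probability
# (assumes the explainer output is log-odds, the default for XGBoost models)
base_log_odds = shap_values.base_values[0]
base_probability = 1 / (1 + np.exp(-base_log_odds))
print(f'Baseline probability of receiving a clot-buster: {base_probability:0.3f}')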
shap_values_numeric = shap_values.values
shap_values_numeric
array([[ 0.57191953, -0.00599206, -0.02803201, ..., 0.53735721,
-0.11320076, -0.18555162],
[ 0.01019792, -0.01376853, -0.20353487, ..., 0.0578289 ,
0.08007584, -0.10086989],
[-0.06371776, -0.05613441, 0.00877733, ..., -0.40052104,
-0.15933883, -0.24292226],
...,
[-0.00602321, 0.17367847, 0.0584153 , ..., -0.00764367,
0.18520108, -0.16023535],
[-0.02695623, -0.01631852, 0.035347 , ..., 0.75401062,
0.12778105, -0.06493782],
[-0.08913676, -0.06648799, 0.06200275, ..., -0.47741059,
-0.15630577, -0.15140597]])
22.0.1 Feature table
# Get feature importance for comparison using MDI method
features = list(X_train)
feature_importances = model.feature_importances_
importances = pd.DataFrame(index=features)
importances['importance'] = feature_importances
importances['rank'] = importances['importance'].rank(ascending=False).values
importances.sort_values('rank').head()

# Get Shapley importances
# Calculate mean Shapley value for each feature in the test set
importances['mean_shapley_values'] = np.mean(
    shap_values_numeric, axis=0
)

# Calculate mean absolute Shapley value for each feature in the test set
# This will give us the average importance of each feature
importances['mean_abs_shapley_values'] = np.mean(
    np.abs(shap_values_numeric), axis=0
)

importances
 | importance | rank | mean_shapley_values | mean_abs_shapley_values |
---|---|---|---|---|
Hosp_1 | 0.016272 | 14.0 | 0.048697 | 0.113734 |
Hosp_2 | 0.009090 | 41.0 | 0.007687 | 0.070485 |
Hosp_3 | 0.011619 | 30.0 | 0.009452 | 0.054508 |
Hosp_4 | 0.025572 | 6.0 | -0.006476 | 0.071526 |
Hosp_5 | 0.016513 | 13.0 | 0.000117 | 0.025848 |
Hosp_6 | 0.018912 | 8.0 | -0.017381 | 0.070901 |
Hosp_7 | 0.009989 | 37.0 | 0.020218 | 0.148842 |
Male | 0.011162 | 32.0 | 0.011670 | 0.089067 |
Age | 0.012259 | 28.0 | 0.095522 | 0.513481 |
80+ | 0.000000 | 48.0 | 0.000000 | 0.000000 |
Onset Time Known Type_BE | 0.060007 | 2.0 | -0.035751 | 0.513144 |
Onset Time Known Type_NK | 0.000000 | 48.0 | 0.000000 | 0.000000 |
Onset Time Known Type_P | 0.021683 | 7.0 | -0.012942 | 0.105682 |
# Comorbidities | 0.012480 | 27.0 | 0.008446 | 0.140819 |
2+ comorbidotes | 0.000000 | 48.0 | 0.000000 | 0.000000 |
Congestive HF | 0.008597 | 42.0 | -0.004122 | 0.008599 |
Hypertension | 0.012239 | 29.0 | -0.070380 | 0.250935 |
Atrial Fib | 0.016634 | 12.0 | 0.011129 | 0.097199 |
Diabetes | 0.017148 | 11.0 | -0.002177 | 0.020622 |
TIA | 0.016100 | 15.0 | -0.013201 | 0.363766 |
Co-mordity | 0.000000 | 48.0 | 0.000000 | 0.000000 |
Antiplatelet_0 | 0.009424 | 39.0 | 0.002172 | 0.017998 |
Antiplatelet_1 | 0.006572 | 43.0 | -0.001211 | 0.024097 |
Antiplatelet_NK | 0.013325 | 23.0 | 0.028037 | 0.113407 |
Anticoag before stroke_0 | 0.014763 | 17.0 | -0.055988 | 0.227291 |
Anticoag before stroke_1 | 0.042693 | 4.0 | 0.023494 | 0.148024 |
Anticoag before stroke_NK | 0.014519 | 18.0 | -0.002497 | 0.019437 |
Stroke severity group_1. No stroke symtpoms | 0.004698 | 44.0 | 0.003031 | 0.014462 |
Stroke severity group_2. Minor | 0.002441 | 45.0 | -0.000203 | 0.020181 |
Stroke severity group_3. Moderate | 0.009636 | 38.0 | 0.015730 | 0.187784 |
Stroke severity group_4. Moderate to severe | 0.010140 | 36.0 | -0.000302 | 0.010772 |
Stroke severity group_5. Severe | 0.011573 | 31.0 | -0.003786 | 0.015024 |
Stroke Type_I | 0.292050 | 1.0 | 0.263943 | 1.526787 |
Stroke Type_PIH | 0.000000 | 48.0 | 0.000000 | 0.000000 |
S2RankinBeforeStroke | 0.038013 | 5.0 | -0.010677 | 0.460929 |
S2NihssArrival | 0.047182 | 3.0 | 0.167792 | 1.423348 |
S2NihssArrivalLocQuestions | 0.013432 | 22.0 | -0.009531 | 0.252463 |
S2NihssArrivalLocCommands | 0.013001 | 25.0 | -0.004716 | 0.189615 |
S2NihssArrivalBestGaze | 0.017906 | 10.0 | -0.049027 | 0.294514 |
S2NihssArrivalVisual | 0.010915 | 34.0 | 0.016386 | 0.203184 |
S2NihssArrivalFacialPalsy | 0.013921 | 19.0 | 0.044696 | 0.199442 |
S2NihssArrivalMotorArmLeft | 0.018599 | 9.0 | -0.028893 | 0.237647 |
S2NihssArrivalMotorArmRight | 0.009192 | 40.0 | 0.018928 | 0.069969 |
S2NihssArrivalMotorLegLeft | 0.011116 | 33.0 | 0.020311 | 0.108674 |
S2NihssArrivalMotorLegRight | 0.013778 | 21.0 | -0.002636 | 0.231136 |
S2NihssArrivalLimbAtaxia | 0.010385 | 35.0 | -0.008956 | 0.038973 |
S2NihssArrivalSensory | 0.012529 | 26.0 | -0.020345 | 0.162434 |
S2NihssArrivalBestLanguage | 0.013232 | 24.0 | -0.006986 | 0.421014 |
S2NihssArrivalDysarthria | 0.014873 | 16.0 | -0.016975 | 0.148263 |
S2NihssArrivalExtinctionInattention | 0.013818 | 20.0 | 0.011012 | 0.232744 |
22.0.1.1 Feature importance comparison
importance_top_10 = \
    importances.sort_values(
        by='importance', ascending=False
    ).head(10).index

shapley_top_10 = \
    importances.sort_values(
        by='mean_abs_shapley_values',
        ascending=False
    ).head(10).index

# Add to DataFrame
top_10_features = pd.DataFrame()
top_10_features['importances'] = importance_top_10.values
top_10_features['Shapley'] = shapley_top_10.values

# Display
top_10_features
 | importances | Shapley |
---|---|---|
0 | Stroke Type_I | Stroke Type_I |
1 | Onset Time Known Type_BE | S2NihssArrival |
2 | S2NihssArrival | Age |
3 | Anticoag before stroke_1 | Onset Time Known Type_BE |
4 | S2RankinBeforeStroke | S2RankinBeforeStroke |
5 | Hosp_4 | S2NihssArrivalBestLanguage |
6 | Onset Time Known Type_P | TIA |
7 | Hosp_6 | S2NihssArrivalBestGaze |
8 | S2NihssArrivalMotorArmLeft | S2NihssArrivalLocQuestions |
9 | S2NihssArrivalBestGaze | Hypertension |
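To quantify how far the two rankings agree across all features, not just the top ten, one option is a rank correlation. A minimal sketch, assuming scipy is available:
# Rank correlation between MDI importance and mean absolute SHAP value
from scipy.stats import spearmanr
rho, p_value = spearmanr(importances['importance'],
                         importances['mean_abs_shapley_values'])
print(f'Spearman rank correlation: {rho:0.3f} (p = {p_value:0.3g})')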
22.0.2 Global: Beeswarm
# Summarize the effects of all the features
shap.plots.beeswarm(shap_values, max_display=25)
22.0.3 Global: Bar
shap.plots.bar(shap_values, max_display=20)
22.0.4 Bar: by factor
= ["Women" if shap_values[i, "Male"].data == 0 else "Men" for i in range(shap_values.shape[0])]
sex abs.mean(0)) shap.plots.bar(shap_values.cohorts(sex).
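The same cohorting idea works for any grouping you can express as a list of labels, for example splitting by age (a sketch, assuming the Age column is recorded in years):
# Cohort the SHAP values by an age threshold
# (assumes 'Age' is in years; the threshold of 80 mirrors the '80+' feature)
age_group = ["80 or over" if shap_values[i, "Age"].data >= 80 else "Under 80"
             for i in range(shap_values.shape[0])]
shap.plots.bar(shap_values.cohorts(age_group).abs.mean(0))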
22.1 Local: Waterfall plots
# Visualize the first prediction's explanation
shap.plots.waterfall(shap_values[0])

# Visualize another prediction's explanation
shap.plots.waterfall(shap_values[7], max_display=15)

# Visualize another prediction's explanation
shap.plots.waterfall(shap_values[145], max_display=15)
22.1.1 Local: Force Plots
# Visualize the first prediction's explanation with a force plot
shap.plots.force(shap_values[0])
Have you run `initjs()` in this notebook? If this notebook was from another user you must also trust this notebook (File -> Trust notebook). If you are viewing this notebook on github the Javascript has been stripped for security. If you are using JupyterLab this error is because a JupyterLab extension has not yet been written.
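If the interactive JavaScript rendering is unavailable (as the message above describes), single-prediction force plots can fall back to a static matplotlib rendering:
# Static fallback for a single prediction's force plot
shap.plots.force(shap_values[0], matplotlib=True)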
22.1.2 Global: Force Plots
# visualize all the predictions
# this struggles with a large number of values so we'll sample a small set
shap.plots.force(shap.utils.sample(shap_values, 1000))
Have you run `initjs()` in this notebook? If this notebook was from another user you must also trust this notebook (File -> Trust notebook). If you are viewing this notebook on github the Javascript has been stripped for security. If you are using JupyterLab this error is because a JupyterLab extension has not yet been written.
22.1.3 Dependence Plots
X.columns
Index(['Hosp_1', 'Hosp_2', 'Hosp_3', 'Hosp_4', 'Hosp_5', 'Hosp_6', 'Hosp_7',
'Male', 'Age', '80+', 'Onset Time Known Type_BE',
'Onset Time Known Type_NK', 'Onset Time Known Type_P',
'# Comorbidities', '2+ comorbidotes', 'Congestive HF', 'Hypertension',
'Atrial Fib', 'Diabetes', 'TIA', 'Co-mordity', 'Antiplatelet_0',
'Antiplatelet_1', 'Antiplatelet_NK', 'Anticoag before stroke_0',
'Anticoag before stroke_1', 'Anticoag before stroke_NK',
'Stroke severity group_1. No stroke symtpoms',
'Stroke severity group_2. Minor', 'Stroke severity group_3. Moderate',
'Stroke severity group_4. Moderate to severe',
'Stroke severity group_5. Severe', 'Stroke Type_I', 'Stroke Type_PIH',
'S2RankinBeforeStroke', 'S2NihssArrival', 'S2NihssArrivalLocQuestions',
'S2NihssArrivalLocCommands', 'S2NihssArrivalBestGaze',
'S2NihssArrivalVisual', 'S2NihssArrivalFacialPalsy',
'S2NihssArrivalMotorArmLeft', 'S2NihssArrivalMotorArmRight',
'S2NihssArrivalMotorLegLeft', 'S2NihssArrivalMotorLegRight',
'S2NihssArrivalLimbAtaxia', 'S2NihssArrivalSensory',
'S2NihssArrivalBestLanguage', 'S2NihssArrivalDysarthria',
'S2NihssArrivalExtinctionInattention'],
dtype='object')
22.1.3.1 Simple scatter of a single feature
# Create a dependence scatter plot to show the effect of a single feature across the whole dataset
shap.plots.scatter(shap_values[:, 'Age'])
shap.plots.scatter(shap_values[:, 'Diabetes'])
shap.plots.scatter(shap_values[:, '# Comorbidities'])
22.1.4 Scatter of multiple features
Passing shap_values to the color argument colours each point by the feature that interacts most strongly with the plotted feature.
# Create a dependence scatter plot to show the effect of a single feature across the whole dataset
shap.plots.scatter(shap_values[:, "Male"], color=shap_values)
shap.plots.scatter(shap_values[:, "Age"], color=shap_values)
Alternatively we can choose to colour by a specific column.
shap.plots.scatter(shap_values[:, "Age"], color=shap_values[:, "Male"])
shap.plots.scatter(shap_values[:, "Male"], color=shap_values[:, "Age"])
22.2 Prediction Uncertainty
# Number of bootstrap samples/models
splits = 50
train_set = []

# Create bootstrap samples (sampling with replacement)
for i in range(splits):
    train_set.append(X_train.join(y_train).sample(frac=1, replace=True))
# Set up lists for models and probability predictions
models = []
results = []
accuracies = []

for i in range(splits):
    # Get X and y for this bootstrap sample
    X_train = train_set[i].drop('Clotbuster given', axis=1)
    y_train = train_set[i]['Clotbuster given']

    # Define and train model; use a different random seed for each model
    model = XGBClassifier(random_state=42+i)
    model.fit(X_train, y_train)
    models.append(model)

    # Get predicted probabilities and class
    y_probs = model.predict_proba(X_test)[:, 1]
    y_pred = y_probs > 0.5
    results.append([y_probs])

    # Record accuracy
    accuracy = np.mean(y_pred == y_test)
    accuracies.append(accuracy)

# Reshape results to one row per patient, one column per model
results = np.array(results)
results = results.T.reshape(-1, splits)
print (f'Mean accuracy: {np.mean(accuracies):0.3f}')
Mean accuracy: 0.803
fig = plt.figure()
ax = fig.add_subplot()
ax.boxplot(accuracies, whis=999)
ax.set_ylabel('Model accuracy')
# Remove xtick labels
ax.axes.xaxis.set_ticklabels([])
plt.show()
# Classify patients by majority vote across the bootstrap models
classification = results >= 0.5
consensus = classification.sum(axis=1) >= splits / 2
consensus_accuracy = np.mean(consensus == y_test)
print (f'Consensus accuracy: {consensus_accuracy:0.3f}')
Consensus accuracy: 0.826
# Sort patients by mean predicted probability
results = results[np.mean(results, axis=1).argsort()]
mean = np.mean(results, axis=1)
stdev = np.std(results, axis=1)

fig = plt.figure()
ax = fig.add_subplot()
ax.errorbar(range(len(mean)), mean, yerr=stdev, label='Standard Deviation', zorder=1)
ax.plot(mean, 'o', c='r', markersize=2, label='Mean probability', zorder=2)
ax.axes.xaxis.set_ticklabels([])
ax.set_xlabel('Patient')
ax.set_ylabel('Probability of Clotbuster Being Given')
ax.set_xticks([])
ax.grid()
ax.legend()
plt.show()
# Repeat, but show the standard error of the mean rather than the standard deviation
results = results[np.mean(results, axis=1).argsort()]
mean = np.mean(results, axis=1)
stdev = np.std(results, axis=1)
se = stdev / np.sqrt(splits)

fig = plt.figure()
ax = fig.add_subplot()
ax.errorbar(range(len(mean)), mean, yerr=se, label='Standard Error', zorder=1)
ax.plot(mean, 'o', c='r', markersize=2, label='Mean probability', zorder=2)
ax.axes.xaxis.set_ticklabels([])
ax.set_xlabel('Patient')
ax.set_ylabel('Probability of Clotbuster Being Given')
ax.set_xticks([])
ax.grid()
ax.legend()
plt.show()
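One practical use of these spreads is to flag the patients the bootstrap models disagree about most. A minimal sketch using the stdev array computed above; the threshold is illustrative, not part of the original exercise:
# Flag patients with high disagreement across the bootstrap models
uncertain = stdev > 0.2  # illustrative threshold
print(f'{uncertain.sum()} of {len(stdev)} patients have a prediction '
      f'standard deviation above 0.2')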