shap.dependence_plot
¶
This notebook is designed to demonstrate (and so document) how to use the shap.dependence_plot
function. It uses an XGBoost model trained on the classic UCI adult income dataset (which is classification task to predict if people made over 50k in the 90s).
In [2]:
import xgboost import shap # train XGBoost model X,y = shap.datasets.adult() model = xgboost.XGBClassifier().fit(X, y) # compute SHAP values explainer = shap.TreeExplainer(model) shap_values = explainer.shap_values(X)Simple dependence plot¶
A dependence plot is a scatter plot that shows the effect a single feature has on the predictions made by the model. In this example the log-odds of making over 50k increases significantly between age 20 and 40.
In [3]:
# The first argument is the index of the feature we want to plot # The second argument is the matrix of SHAP values (it is the same shape as the data matrix) # The third argument is the data matrix (a pandas dataframe or numpy array) shap.dependence_plot(0, shap_values, X)Other ways to make the same plot¶
In [4]:
# If we pass a numpy array instead of a data frame then we # need pass the feature names in separately shap.dependence_plot(0, shap_values, X.values, feature_names=X.columns)
In [5]:
# We can pass a feature name instead of an index shap.dependence_plot("Age", shap_values, X)
In [6]:
# We can also use the special "rank(i)" systax to specify the i'th most # important feature to the model. As measured by: np.abs(shap_values).mean(0) # In this example age is the second most important feature. shap.dependence_plot("rank(1)", shap_values, X)
In [7]:
# The interaction_index argument can be used to explicitly # set which feature gets used for coloring shap.dependence_plot("rank(1)", shap_values, X, interaction_index="Education-Num")Exploring different interaction colorings¶
In [8]:
# we can turn off interaction coloring shap.dependence_plot("Age", shap_values, X, interaction_index=None)
In [9]:
# we can use shap.approximate_interactions to guess which features # may interact with age inds = shap.approximate_interactions("Age", shap_values, X) # make plots colored by each of the top three possible interacting features for i in range(3): shap.dependence_plot("Age", shap_values, X, interaction_index=inds[i])
In [10]:
import matplotlib.pyplot as plt # you can use the cmap parameter to provide your own custom color map shap.dependence_plot("Age", shap_values, X, cmap=plt.get_cmap("cool"))Customizing the figure properties¶
In [12]:
# by passing show=False you can prevent shap.dependence_plot from calling # the matplotlib show() function, and so you can keep customizing the plot # before eventually calling show yourself shap.dependence_plot(0, shap_values, X, show=False) plt.title("Age dependence plot") plt.ylabel("SHAP value for the 'Age' feature") # plt.savefig("my_dependence_plot.pdf") # we can save a PDF of the figure if we want plt.show()
In [13]:
# you can use xmax and xmin with a percentile notation to hide outliers shap.dependence_plot(0, shap_values, X, xmin="percentile(1)", xmax="percentile(99)")
In [14]:
# transparency can help reveal dense vs. sparse areas of the scatter plot shap.dependence_plot(0, shap_values, X, alpha=0.1)
In [15]:
# an alternative to transparency is to reduce the dot size shap.dependence_plot(0, shap_values, X, dot_size=2)
In [16]:
# for categorical (or binned) data adding a small amount of x-jitter makes # thin columns of dots more readable shap.dependence_plot(0, shap_values, X, x_jitter=1, dot_size=1)Using string categorical features¶
In [109]:
X_cat = X.copy() relationship_decoding = { 0: 'Not-in-family', 1: 'Unmarried', 2: 'Other-relative', 3: 'Own-child', 4: 'Husband', 5: 'Wife' } X_cat["Relationship"] = X_cat["Relationship"].map(relationship_decoding) X_cat.head(3)
Out[109]:
Age Workclass Education-Num Marital Status Occupation Relationship Race Sex Capital Gain Capital Loss Hours per week Country 0 39.0 7 13.0 4 1 Not-in-family 4 1 2174.0 0.0 40.0 39 1 50.0 6 13.0 2 4 Husband 4 1 0.0 0.0 13.0 39 2 38.0 4 9.0 0 6 Not-in-family 4 1 0.0 0.0 40.0 39In [108]:
# You can use string-valued category features shap.dependence_plot("Relationship", shap_values, X_cat)
In [107]:
# It is also possible to use string-valued features to plot interaction effect shap.dependence_plot(0, shap_values, X_cat, interaction_index="Relationship")
Have an idea for more helpful examples? Pull requests that add to this documentation notebook are encouraged!
RetroSearch is an open source project built by @garambo | Open a GitHub Issue
Search and Browse the WWW like it's 1997 | Search results from DuckDuckGo
HTML:
3.2
| Encoding:
UTF-8
| Version:
0.7.4