Can we explain the output of complex tree models? We use different algorithms to improve the performance of a model, and when we feed it a new test datapoint, it produces an output. Have you ever explored which features drive that output? We can extract the overall feature importance from the model, but can we find out which features are responsible for an individual prediction?
If we use a decision tree, we can at least explain the output by plotting the tree structure. But it is not easy to explain the output of advanced tree-based algorithms like XGBoost, LightGBM, CatBoost or other scikit-learn ensemble models. To explain the output of such algorithms, researchers have come up with an approach called SHAP.
SHAP (SHapley Additive exPlanations) is a unified approach to explain the output of any machine learning model. SHAP connects game theory with local explanations, uniting several previous methods and representing the only possible consistent and locally accurate additive feature attribution method based on expectations. (Reference - SHAP Documentation)
Now, I will walk you through the Python code that helps you explain a model's output using SHAP values. Tree SHAP is a fast and exact method to estimate SHAP values for tree models and ensembles of trees, under several different possible assumptions about feature dependence.
# Import libraries
import shap
import xgboost
import pandas as pd

shap.initjs()
# Load Diabetes dataset
X, y = shap.datasets.diabetes()
You can explore more about the dataset here.
# Shape
X.shape, y.shape
((442, 10), (442,))
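Before going further, it helps to glance at the feature names, since the plots below refer to them:
# Peek at the feature names and the first few rows
print(X.columns.tolist())
X.head()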
# Distribution of target variable
pd.Series(y).plot(kind='hist')
The dataset has 442 datapoints and 10 features. The target variable is numerical. Let us use XGBoost Regressor to train the model.
# Train using XGBoost Regressor
XGB_model = xgboost.XGBRegressor()
XGB_model.fit(X, y)
We have trained the XGBoost model. Now, the game begins.
There are three main explainers in SHAP that use the trained model to estimate SHAP values: TreeExplainer, KernelExplainer and DeepExplainer. DeepExplainer is mainly used to explain the output of deep learning models. In this article, we will use TreeExplainer to estimate the SHAP values.
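As an aside, if your model were not tree-based, KernelExplainer would be the model-agnostic fallback. Here is a minimal sketch; the background sample of 50 rows and the five explained rows are arbitrary choices for illustration:
# KernelExplainer is model-agnostic but much slower than TreeExplainer.
# It needs only a prediction function and a background dataset.
background = shap.sample(X, 50)
kernel_explainer = shap.KernelExplainer(XGB_model.predict, background)
kernel_shap_values = kernel_explainer.shap_values(X.iloc[:5, :])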
# Create Tree explainer
explainer = shap.TreeExplainer(XGB_model)
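The "assumptions about feature dependence" mentioned earlier are controlled by TreeExplainer's feature_perturbation argument (the exact name may vary across SHAP versions); a hedged sketch of the interventional variant, which needs background data:
# "tree_path_dependent" follows the trees' split paths; "interventional"
# breaks feature dependence using a background dataset.
explainer_interventional = shap.TreeExplainer(
    XGB_model,
    data=shap.sample(X, 100),
    feature_perturbation="interventional",
)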
# Extract SHAP values to explain the model predictions
shap_values = explainer.shap_values(X)
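A quick sanity check: because SHAP is an additive feature attribution method, the base value plus the sum of a datapoint's SHAP values should reconstruct the model's prediction (up to numerical tolerance):
import numpy as np

# Local accuracy: base value + sum of SHAP values ~ model prediction
preds = XGB_model.predict(X)
reconstructed = explainer.expected_value + shap_values.sum(axis=1)
print(np.allclose(preds, reconstructed, atol=1e-3))  # expect True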
The explainer returns a matrix of SHAP values of shape (#samples x #features). The mean absolute value of the SHAP values for each feature gives us the SHAP feature importance for the model.
# Plot Feature Importance
shap.summary_plot(shap_values, X, plot_type="bar")
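The bar plot above is just the mean absolute SHAP value per feature, so you can reproduce the ranking yourself:
import numpy as np

# Mean |SHAP| per feature reproduces the bar plot's ordering
importance = pd.Series(np.abs(shap_values).mean(axis=0), index=X.columns)
print(importance.sort_values(ascending=False))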
We can also plot the SHAP value of every feature for each datapoint by changing the plot type to 'dot'. In this plot, we can see the relation between a feature's value and its SHAP value. For example, for the most important feature 's5', a high feature value corresponds to a high SHAP value and vice versa, which tells us that as the feature's value increases, the model's output increases. Remember: a high SHAP value means the feature is pushing the predicted value (or probability, for classifiers) higher.
# Plot Feature Importance - 'dot' type
shap.summary_plot(shap_values, X, plot_type='dot')
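To zoom in on a single feature such as 's5', SHAP also provides a dependence plot, which scatters the feature's value against its SHAP value for every datapoint:
# Feature value vs. SHAP value for 's5'; the colouring feature (a likely
# interaction) is chosen automatically
shap.dependence_plot('s5', shap_values, X)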
Using a force plot, we can visualize which features push the prediction above or pull it below the base value (the mean model output over the training dataset). This tells us a lot: with SHAP, we can explain the output to clients very easily, even for complex tree models, which makes our life easier.
# Visualize the explanation of the first prediction
shap.force_plot(explainer.expected_value, shap_values[0, :], X.iloc[0, :])
We can also see the above plot for all the datapoints in a single plot by rotating each explanation and stacking them horizontally.
# Visualize the training set using SHAP predictions
shap.force_plot(explainer.expected_value, shap_values, X)
There are many other features of SHAP to explore. You can read the SHAP documentation here, browse the GitHub repository of SHAP, and find the code related to this article here.
I hope this article helps you learn something new. If you have any queries, leave a comment in the comments section below; I would be more than happy to answer them.
Thank you for reading my blog and supporting me. Stay tuned for my next article. If you want to receive email updates, don’t forget to subscribe to my blog. Keep learning and sharing!!
Follow me here:
GitHub: https://github.com/Abhishekmamidi123
LinkedIn: https://www.linkedin.com/in/abhishekmamidi/
Kaggle: https://www.kaggle.com/abhishekmamidi
If you are looking for any specific blog, please do comment in the comment section below.