← retour aux snippets

shap: TreeExplainer pour modèles d'arbres

Expliquer une prédiction tabulaire d'un modèle de type gradient boosting.

objectif

Expliquer une prédiction tabulaire d’un modèle de type gradient boosting.

code minimal

import shap
from xgboost import XGBClassifier
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split

X, y = load_breast_cancer(return_X_y=True)
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=0, stratify=y)

model = XGBClassifier(n_estimators=200, max_depth=3, learning_rate=0.1, subsample=0.9, colsample_bytree=0.9, tree_method="hist", random_state=0)
model.fit(X_train, y_train)

explainer = shap.TreeExplainer(model)
shap_values = explainer.shap_values(X_val[:10])
print(len(shap_values) == 10)

utilisation

# Valeur de base (expected value)
base = explainer.expected_value
print(base is not None)

variante(s) utile(s)

# Force plot / waterfall (les figures sont objets Matplotlib)
# shap.plots.waterfall(shap.Explanation(values=shap_values[0], base_values=base, data=X_val[0]))
print(True)

notes

  • SHAP nécessite des dépendances graphiques pour afficher; pour CPU-only, TreeExplainer est rapide sur modèles d’arbres.