← retour aux snippets

xgboost: xgb.cv et early stopping

Validation croisée intégrée avec arrêt anticipé pour choisir n_estimators.

objectif

Validation croisée intégrée avec arrêt anticipé pour choisir n_estimators.

code minimal

import numpy as np, xgboost as xgb
rng = np.random.RandomState(0)
X = rng.randn(300, 5); y = (X[:,0]*1.5 + X[:,1]*-1.0 + rng.randn(300)*0.1 > 0).astype(int)
dtrain = xgb.DMatrix(X, label=y)
params = {"objective":"binary:logistic", "eta":0.1, "max_depth":4, "subsample":0.8, "colsample_bytree":0.8, "tree_method":"hist", "seed":0}
cv = xgb.cv(params, dtrain, num_boost_round=2000, nfold=5, metrics=["logloss","auc"], early_stopping_rounds=50, seed=0, verbose_eval=False)
best_n = int(cv.shape[0])
print(best_n > 0)

utilisation

# Entraîner le modèle final avec best_n
bst = xgb.train(params, dtrain, num_boost_round=best_n)
print(len(bst.get_dump()) > 0)

variante(s) utile(s)

# Courbes disponibles dans cv DataFrame
print(set(cv.columns) >= {"test-auc-mean","train-auc-mean"})

notes

  • xgb.cv fournit métriques moyennes et écart-types; utilisez best_n pour figer l’itération.