← retour aux snippets

xgboost + optuna: tuning de bout en bout

Optimiser XGBClassifier avec Optuna et réentraîner au best_iteration.

python tuning #xgboost #optuna #tuning

objectif

Optimiser XGBClassifier avec Optuna et réentraîner au best_iteration.

code minimal

import optuna
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score
from xgboost import XGBClassifier
from sklearn.datasets import load_breast_cancer

# Load the breast-cancer dataset as (features, labels) arrays, then hold out a
# stratified 20% validation split with a fixed seed; the validation part is
# used both for early stopping and for scoring in the objective below.
X, y = load_breast_cancer(return_X_y=True)
Xtr, Xva, ytr, yva = train_test_split(
    X,
    y,
    stratify=y,
    test_size=0.2,
    random_state=0,
)

def objective(trial):
    """Optuna objective: train an XGBClassifier with sampled hyperparameters.

    Samples max depth, learning rate, row/column subsampling, L2
    regularisation and minimum child weight from ``trial``. Fits on
    ``Xtr``/``ytr`` with early stopping monitored on ``(Xva, yva)`` and
    returns the validation ROC-AUC (the study maximises it).

    The iteration selected by early stopping is stored in the trial's user
    attributes under ``"best_iteration"``. XGBoost reports it 0-based, so the
    corresponding number of trees is ``best_iteration + 1``.

    Note: the AUC is measured on the same set that drives early stopping, so
    it is slightly optimistic; use a separate hold-out or CV for an unbiased
    estimate.
    """
    clf = XGBClassifier(
        n_estimators=2000,  # upper bound; early stopping picks the real count
        max_depth=trial.suggest_int("max_depth", 3, 8),
        learning_rate=trial.suggest_float("learning_rate", 0.01, 0.2, log=True),
        subsample=trial.suggest_float("subsample", 0.6, 1.0),
        colsample_bytree=trial.suggest_float("colsample_bytree", 0.6, 1.0),
        reg_lambda=trial.suggest_float("reg_lambda", 1e-3, 10.0, log=True),
        min_child_weight=trial.suggest_float("min_child_weight", 1e-2, 10.0, log=True),
        tree_method="hist",
        # Since XGBoost 1.6 these are constructor parameters; passing them to
        # fit() was deprecated and is rejected (TypeError) from XGBoost 2.0 on.
        eval_metric="auc",
        early_stopping_rounds=50,
        random_state=0,
    )
    clf.fit(Xtr, ytr, eval_set=[(Xva, yva)], verbose=False)
    auc = roc_auc_score(yva, clf.predict_proba(Xva)[:, 1])
    trial.set_user_attr("best_iteration", int(clf.best_iteration))
    return auc

# Seed the TPE sampler so the sequence of proposed hyperparameters (and hence
# the best trial) is reproducible, matching random_state=0 used in the model.
study = optuna.create_study(
    direction="maximize",
    sampler=optuna.samplers.TPESampler(seed=0),
)
study.optimize(objective, n_trials=20)
best = study.best_trial
print(best.value <= 1.0)  # sanity check: a ROC-AUC can never exceed 1

utilisation

# Refit on the full data (train + validation) with the best hyperparameters
# and the tree count selected by early stopping.
params = study.best_trial.params
best_it = study.best_trial.user_attrs.get("best_iteration")
# XGBoost's best_iteration is 0-based: the best model contains
# best_iteration + 1 trees. Fall back to 200 trees if the attribute is missing.
best_n = best_it + 1 if best_it is not None else 200
final = XGBClassifier(n_estimators=best_n, tree_method="hist", random_state=0, **params).fit(X, y)
print(hasattr(final, "predict_proba"))

variante(s) utile(s)

# Export every trial of the study as a DataFrame, then check that the trial
# number and objective value columns are present.
df = study.trials_dataframe()
print({"number", "value"} <= set(df.columns))

notes

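  • Conservez best_iteration pour limiter le surapprentissage ; attention, il est indexé à partir de 0 : le modèle final doit donc utiliser best_iteration + 1 arbres. Fixez random_state (ainsi que la graine du sampler Optuna, p. ex. TPESampler(seed=0)) pour la reproductibilité. L'AUC renvoyée par l'objectif est mesurée sur le jeu qui sert aussi à l'early stopping : elle est donc légèrement optimiste.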