← retour aux snippets

optuna: pruning avec MedianPruner

Arrêter tôt les essais médiocres pour accélérer la recherche.

python tuning #optuna#pruning#speed

objectif

Arrêter tôt les essais médiocres pour accélérer la recherche.

code minimal

import optuna
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import cross_val_score
from sklearn.ensemble import GradientBoostingClassifier

X, y = load_breast_cancer(return_X_y=True)

def objective(trial):
    trial.report(0.0, step=0)
    clf = GradientBoostingClassifier(
        n_estimators=trial.suggest_int("n_estimators", 50, 400),
        learning_rate=trial.suggest_float("learning_rate", 0.01, 0.2, log=True),
        max_depth=trial.suggest_int("max_depth", 2, 5),
        random_state=0,
    )
    score = cross_val_score(clf, X, y, cv=3, scoring="roc_auc").mean()
    trial.report(score, step=1)
    if trial.should_prune():
        raise optuna.TrialPruned()
    return score

study = optuna.create_study(direction="maximize", pruner=optuna.pruners.MedianPruner())
study.optimize(objective, n_trials=20)
print(len(study.trials) >= 1)

utilisation

# Visualiser l'importance des hyperparamètres
imp = optuna.importance.get_param_importances(study)
print(isinstance(imp, dict))

variante(s) utile(s)

# Sauvegarder la meilleure config
best = study.best_params
print(isinstance(best, dict))

notes

  • Les pruners (Median, SuccessiveHalving, Hyperband) stoppent tôt; nécessite des reports (trial.report).