← retour aux snippets

optuna: recherche d'hyperparamètres

Optimiser automatiquement les hyperparamètres avec Optuna et pruning.

objectif

Optimiser automatiquement les hyperparamètres avec Optuna et pruning.

code minimal

import optuna
from sklearn.model_selection import cross_val_score
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import load_breast_cancer

X, y = load_breast_cancer(return_X_y=True)

def objective(trial):
    n_estimators = trial.suggest_int("n_estimators", 50, 500)
    max_depth = trial.suggest_int("max_depth", 2, 12)
    min_samples_split = trial.suggest_int("min_samples_split", 2, 10)
    clf = RandomForestClassifier(
        n_estimators=n_estimators,
        max_depth=max_depth,
        min_samples_split=min_samples_split,
        random_state=0,
        n_jobs=-1,
    )
    score = cross_val_score(clf, X, y, cv=3, scoring="roc_auc").mean()
    return score

study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=20)
print(len(study.trials) >= 1)

utilisation

# Meilleurs paramètres et valeur
best = study.best_trial
print(isinstance(best.params, dict))

variante(s) utile(s)

# Importance des hyperparamètres
imp = optuna.importance.get_param_importances(study)
print(isinstance(imp, dict))

notes

  • Optuna propose TPE (Tree-structured Parzen Estimator) par défaut; utilisez pruners pour gagner du temps sur essais médiocres.