
sklearn: GridSearchCV


objective

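Exhaustive hyperparameter search with cross-validation (CV): every combination of the given values is evaluated, and the one with the best mean CV score is kept.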

minimal code

from sklearn.model_selection import GridSearchCV
from sklearn.svm import SVC
from sklearn.datasets import load_iris
X, y = load_iris(return_X_y=True)
# 2 values of C x 2 kernels = 4 combinations, each scored with 3-fold CV
g = GridSearchCV(SVC(), {"C": [0.1, 1], "kernel": ["linear", "rbf"]}, cv=3).fit(X, y)
print(len(g.cv_results_["params"]) == 4)  # True: one entry per combination
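To make "exhaustive" concrete, here is a sketch that redoes the same search by hand: ParameterGrid enumerates every combination (2 x 2 = 4 here) and cross_val_score scores each one on the same 3 stratified folds. It assumes default settings on both sides (accuracy scoring, no shuffling) and only illustrates the principle; GridSearchCV itself also handles parallelism, the final refit and fit errors.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV, ParameterGrid, cross_val_score
from sklearn.svm import SVC

X, y = load_iris(return_X_y=True)
param_grid = {"C": [0.1, 1], "kernel": ["linear", "rbf"]}

# Manual exhaustive search: each combination is scored by 3-fold CV
# (for a classifier, cv=3 means StratifiedKFold(3); default scoring = accuracy).
manual_means = []
for params in ParameterGrid(param_grid):
    scores = cross_val_score(SVC(**params), X, y, cv=3)
    manual_means.append(scores.mean())
manual_means = np.array(manual_means)

g = GridSearchCV(SVC(), param_grid, cv=3).fit(X, y)

# Same candidate order and same folds, hence the same mean scores.
print(np.allclose(manual_means, g.cv_results_["mean_test_score"]))  # True
print(g.best_params_, g.best_score_)
```

The cost is therefore (number of combinations) x (number of folds) fits, here 4 x 3 = 12, plus one final refit on all the data with the best parameters (refit=True by default): each added hyperparameter multiplies the size of the grid.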

usage

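from sklearn.model_selection import RandomizedSearchCV
from sklearn.ensemble import RandomForestClassifier
# Samples n_iter settings instead of trying every combination
rs = RandomizedSearchCV(RandomForestClassifier(), {"n_estimators": [10, 50]}, n_iter=2, cv=2, random_state=0)
print(hasattr(rs, "fit"))  # True; nothing is fitted yet, call rs.fit(X, y) to run the search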
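RandomizedSearchCV evaluates only n_iter sampled settings. Lists are sampled uniformly, while objects that expose an rvs method (for example scipy.stats distributions) are sampled directly, which makes it possible to explore a continuous range. Minimal sketch, assuming SciPy is installed (scipy.stats.loguniform, available in recent SciPy versions); the C range and n_iter=5 are arbitrary choices for illustration.

```python
from scipy.stats import loguniform
from sklearn.datasets import load_iris
from sklearn.model_selection import RandomizedSearchCV
from sklearn.svm import SVC

X, y = load_iris(return_X_y=True)

# C drawn log-uniformly in [0.01, 100]; kernel drawn uniformly from the list.
param_distributions = {"C": loguniform(1e-2, 1e2), "kernel": ["linear", "rbf"]}

rs = RandomizedSearchCV(SVC(), param_distributions, n_iter=5, cv=3, random_state=0).fit(X, y)

print(len(rs.cv_results_["params"]))  # 5: exactly n_iter candidates
print(rs.best_params_)
```

With a fixed budget, sampling C on a log scale covers several orders of magnitude, where a grid would need one value per order of magnitude.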

useful variant(s)

from sklearn.model_selection import StratifiedKFold
# Splitter that preserves class proportions in each fold; can be passed as cv=
print(hasattr(StratifiedKFold(n_splits=3), "split"))  # True
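Sketch of passing a StratifiedKFold splitter explicitly as cv, for example to shuffle the data reproducibly (shuffle=True and random_state=0 are illustrative choices). The loop prints the per-class counts of each test fold to show that the iris class proportions (50 samples per class) are roughly preserved.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV, StratifiedKFold
from sklearn.svm import SVC

X, y = load_iris(return_X_y=True)

# Shuffled but reproducible stratified folds
cv = StratifiedKFold(n_splits=3, shuffle=True, random_state=0)

# Each test fold holds about a third of each class (16 or 17 of the 50).
for train_idx, test_idx in cv.split(X, y):
    print(np.bincount(y[test_idx]))

# The splitter object is passed directly instead of an integer
g = GridSearchCV(SVC(), {"C": [0.1, 1], "kernel": ["linear", "rbf"]}, cv=cv).fit(X, y)
print(g.n_splits_)  # 3
```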

notes

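  • Choose a scoring metric suited to the problem (scoring parameter; by default the estimator's score method, i.e. accuracy for a classifier); validate with stratified CV (for a classifier, an integer cv already uses StratifiedKFold, without shuffling).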
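For the scoring note, a sketch with several metrics at once: scoring accepts a dict of scorer names, and refit names the metric used to select best_params_ and best_estimator_. Each metric then gets its own mean_test_<name>, std_test_<name> and rank_test_<name> entries in cv_results_. The pairing of accuracy and f1_macro is only an example.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV, StratifiedKFold
from sklearn.svm import SVC

X, y = load_iris(return_X_y=True)

g = GridSearchCV(
    SVC(),
    {"C": [0.1, 1], "kernel": ["linear", "rbf"]},
    scoring={"acc": "accuracy", "f1": "f1_macro"},  # two metrics computed per candidate
    refit="f1",  # best_params_ / best_estimator_ chosen on f1_macro
    cv=StratifiedKFold(n_splits=3, shuffle=True, random_state=0),
).fit(X, y)

print(sorted(k for k in g.cv_results_ if k.startswith("mean_test_")))
# ['mean_test_acc', 'mean_test_f1']
print(g.best_params_, g.best_score_)  # best_score_ is the mean f1_macro
```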