← retour aux snippets

sklearn GroupKFold et GroupShuffleSplit

validation croisée en respectant les groupes (leakage évité)

python ml #sklearn#cv#groups

sklearn GroupKFold et GroupShuffleSplit

objectif

Expliquer et montrer comment validation croisée en respectant les groupes (leakage évité).

code minimal

import numpy as np
from sklearn.model_selection import GroupKFold, GroupShuffleSplit

X = np.random.randn(20, 3)
y = (X[:,0] > 0).astype(int)
groups = np.repeat(np.arange(5), 4)  # 5 groupes de 4
gkf = GroupKFold(n_splits=5)

for tr, te in gkf.split(X, y, groups):
    assert len(set(groups[tr]) & set(groups[te])) == 0

utilisation

# split aléatoire par groupes
gss = GroupShuffleSplit(n_splits=3, test_size=0.2, random_state=0)
for tr, te in gss.split(X, y, groups):
    print("groupes test:", sorted(set(groups[te])))

variante(s) utile(s)

# TimeSeriesSplit par groupe (custom): trier par groupe puis splitter
order = np.argsort(groups)
Xo, yo, go = X[order], y[order], groups[order]
# implémenter un split séquentiel manuel si nécessaire

notes

  • Indispensable quand plusieurs lignes appartiennent au même individu ou session.
  • Empêche le leakage entre train et test.