← retour aux snippets

numpy: argpartition (Top-K)

Extraire indices des plus grands/petits éléments.

objectif

Extraire indices des plus grands/petits éléments.

code minimal

import numpy as np
x = np.array([5,1,4,3])
k = 2
idx = np.argpartition(-x, k-1)[:k]
print(sorted(x[idx].tolist()) == [4,5])

utilisation

import numpy as np
x = np.array([1,3,2,4])
print(np.argpartition(x, 1)[:2].size)

variante(s) utile(s)

import numpy as np
x = np.array([1,2,3,4])
top = x[np.argpartition(-x, 0)[:1]][0]
print(top == 4)

notes

  • Plus rapide que argsort pour grands tableaux et petit K.