Numba @njit: accélérer une boucle Python
objectif
Expliquer et montrer comment compiler une fonction numérique pour gagner en vitesse.
code minimal
import numpy as np
from numba import njit
@njit
def pairwise_l2(a, b):
m, n = a.shape[0], b.shape[0]
out = np.empty((m, n), dtype=np.float64)
for i in range(m):
for j in range(n):
d = 0.0
for k in range(a.shape[1]):
diff = a[i,k] - b[j,k]
d += diff * diff
out[i,j] = d
return out
A = np.random.rand(200, 8)
B = np.random.rand(300, 8)
D = pairwise_l2(A, B)
D.shape
utilisation
# premier appel compile, suivants plus rapides
D2 = pairwise_l2(A[:10], B[:10])
D2.shape
variante(s) utile(s)
# paralléliser avec prange
from numba import prange, njit
@njit(parallel=True)
def row_sum(x):
out = np.empty(x.shape[0])
for i in prange(x.shape[0]):
s = 0.0
for j in range(x.shape[1]):
s += x[i,j]
out[i] = s
return out
notes
- Évitez les objets Python dans les fonctions numba (mode nopython).
- Le premier appel inclut le temps de compilation.