← retour aux snippets

Numba @njit: accélérer une boucle Python

compiler une fonction numérique pour gagner en vitesse

Numba @njit: accélérer une boucle Python

objectif

Expliquer et montrer comment compiler une fonction numérique pour gagner en vitesse.

code minimal

import numpy as np
from numba import njit

@njit
def pairwise_l2(a, b):
    m, n = a.shape[0], b.shape[0]
    out = np.empty((m, n), dtype=np.float64)
    for i in range(m):
        for j in range(n):
            d = 0.0
            for k in range(a.shape[1]):
                diff = a[i,k] - b[j,k]
                d += diff * diff
            out[i,j] = d
    return out

A = np.random.rand(200, 8)
B = np.random.rand(300, 8)
D = pairwise_l2(A, B)
D.shape

utilisation

# premier appel compile, suivants plus rapides
D2 = pairwise_l2(A[:10], B[:10])
D2.shape

variante(s) utile(s)

# paralléliser avec prange
from numba import prange, njit
@njit(parallel=True)
def row_sum(x):
    out = np.empty(x.shape[0])
    for i in prange(x.shape[0]):
        s = 0.0
        for j in range(x.shape[1]):
            s += x[i,j]
        out[i] = s
    return out

notes

  • Évitez les objets Python dans les fonctions numba (mode nopython).
  • Le premier appel inclut le temps de compilation.