← retour aux snippets

numpy: einsum et matmul batched

Exprimer des contractions et produits batched lisiblement.

python numpy #numpy#einsum#matmul

objectif

Exprimer des contractions et produits batched lisiblement.

code minimal

import numpy as np

A = np.arange(6).reshape(2,3)
b = np.arange(3)
print(np.einsum("ij,j->i", A, b).tolist())

utilisation

import numpy as np

A = np.random.default_rng(0).normal(size=(4,2,3))
B = np.random.default_rng(1).normal(size=(4,3,2))
print((A @ B).shape)

variante(s) utile(s)

import numpy as np

A = np.arange(8).reshape(2,2,2)
B = np.arange(4).reshape(2,2)
print(np.einsum("ijk,jk->i", A, B).tolist())

notes

  • einsum_path peut optimiser pour gros tenseurs.