
evhub/better_einsum


better_einsum

np.einsum but better:

  • better syntax ("C[i,k] = A[i,j] B[j,k]" instead of "ij, jk -> ik"),
  • names and indices can be arbitrary variable names, not just single letters,
  • support for keyword arguments (einsum("C = A[i] B[i]", A=..., B=...)),
  • warnings on common bugs, and
  • an einsum.exec method for executing the einsum assignment in the calling scope.
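One way to see what the better syntax buys you is to translate it back into a classic np.einsum subscript string. The sketch below uses a hypothetical helper, to_classic_subscripts, written for illustration only; it is not better_einsum's parser. It assumes the "OUT[...] = IN[...] IN[...]" form used in this README, ignores the optional *, and maps each distinct index name, however long, to a single letter.

```python
import re
import string


def to_classic_subscripts(expr: str) -> str:
    """Translate e.g. "C[i,k] = A[i,j] B[j,k]" into "ab,bc->ac".

    Hypothetical helper for illustration, not better_einsum's implementation.
    Each distinct index name is mapped to one letter in order of first
    appearance on the right-hand side; "..." is passed through unchanged.
    """
    lhs, rhs = expr.split("=", 1)
    # An operand is a name optionally followed by [indices]; "*" is skipped.
    term = re.compile(r"(\w+)(?:\[([^\]]*)\])?")
    letters = iter(string.ascii_letters)
    mapping: dict[str, str] = {}

    def convert(indices: str) -> str:
        out = []
        for name in filter(None, (s.strip() for s in indices.split(","))):
            if name == "...":
                out.append("...")
                continue
            if name not in mapping:
                mapping[name] = next(letters)
            out.append(mapping[name])
        return "".join(out)

    inputs = [convert(idx) for _, idx in term.findall(rhs)]
    [(_, out_idx)] = term.findall(lhs)  # exactly one output operand
    return ",".join(inputs) + "->" + convert(out_idx)
```

For example, to_classic_subscripts("out[row, col] = X[row, inner] * Y[inner, col]") gives "ab,bc->ac", the same string np.einsum would need for a matrix product.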

Install with pip install better_einsum, then:

>>> import numpy as np
>>> from better_einsum import einsum

>>> A = np.array([[1, 2], [3, 4]])
>>> B = np.array([[5, 6], [7, 8]])

>>> einsum("C[i,k] = A[i,j] * B[j,k]", A=A, B=B)  # equivalent to A.dot(B)
array([[19, 22],
       [43, 50]])

>>> einsum("C = A[i,j] * B[i,j]", A=A, B=B)  # equivalent to np.sum(A * B)
70

>>> einsum("C[...] = A[i,...] * B[i,...]", A=A, B=B)  # equivalent to np.sum(A * B, axis=0)
array([26, 44])
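For comparison, the three calls above correspond to the following classic np.einsum subscript strings. This is a hand translation for reference, not output produced by better_einsum.

```python
import numpy as np

A = np.array([[1, 2], [3, 4]])
B = np.array([[5, 6], [7, 8]])

# "C[i,k] = A[i,j] * B[j,k]": contract over j, i.e. a matrix product
C1 = np.einsum("ij,jk->ik", A, B)

# "C = A[i,j] * B[i,j]": an empty output spec sums over every index
C2 = np.einsum("ij,ij->", A, B)

# "C[...] = A[i,...] * B[i,...]": sum over the leading axis only
C3 = np.einsum("i...,i...->...", A, B)
```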

>>> einsum("C[i,k] = A[i,j] B[j,k]", A, B)  # * is optional; positional args are also supported
array([[19, 22],
       [43, 50]])

>>> einsum("C[i,k] = A[i,j] * B[j,k]", A, A)  # better_einsum will catch common mistakes for you
better_einsum.py: UserWarning: better_einsum: variable 'B' in calling scope points to a different object than was passed in; this usually denotes an error
array([[ 7, 10],
       [15, 22]])
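The warning above can be produced by comparing each operand with whatever the same name is bound to in the caller's scope. A minimal sketch of such a check follows; check_operand_names is a hypothetical helper, the names would in practice come from parsing the expression, and better_einsum's real implementation may differ.

```python
import inspect
import warnings


def check_operand_names(names, operands):
    """Warn if a name in the caller's scope is bound to a different object
    than the operand passed in under that name.

    Hypothetical sketch of the kind of check described above.
    """
    frame = inspect.currentframe()
    caller = frame.f_back  # the frame that called check_operand_names
    try:
        # Locals shadow globals, mirroring normal name lookup.
        scope = {**caller.f_globals, **caller.f_locals}
    finally:
        del frame, caller  # avoid reference cycles through frame objects
    for name, obj in zip(names, operands):
        if name == "_":  # assumed: "_" placeholders are never checked
            continue
        if name in scope and scope[name] is not obj:
            warnings.warn(
                f"variable {name!r} in calling scope points to a different "
                "object than was passed in; this usually denotes an error",
                UserWarning,
                stacklevel=2,
            )
```

Calling check_operand_names(["A", "B"], [A, A]) where B is also defined in the calling scope warns about 'B', matching the mistake in the example above.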

>>> einsum("_[i,k] = _[i,j] * _[j,k]", A, B)  # use placeholders if you don't want to name your variables
array([[19, 22],
       [43, 50]])

>>> einsum.exec("C[i,k] = A[i,j] * B[j,k]")  # directly assigns to C and looks up A and B
array([[19, 22],
       [43, 50]])
>>> C
array([[19, 22],
       [43, 50]])
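einsum.exec reads its operands from, and writes its result into, the calling scope. A rough sketch of that mechanism is below, under two stated assumptions: indices are single letters, and the result is written into the caller's global namespace, as at the interactive prompt above. einsum_exec is a hypothetical name, not part of better_einsum's API.

```python
import inspect
import re

import numpy as np


def einsum_exec(expr: str):
    """Evaluate "OUT[idx] = X[idx] Y[idx] ..." with operands from the caller.

    Hypothetical sketch: single-letter indices only; the result is stored in
    the caller's *global* namespace under the left-hand-side name.
    """
    caller = inspect.currentframe().f_back
    try:
        scope = {**caller.f_globals, **caller.f_locals}
        lhs, rhs = expr.split("=", 1)
        term = re.compile(r"(\w+)(?:\[([^\]]*)\])?")
        [(out_name, out_idx)] = term.findall(lhs)
        operands = term.findall(rhs)  # list of (name, "i,j") pairs
        # "i,j" -> "ij" for each input, then append "->" and the output spec
        spec = ",".join(idx.replace(",", "").replace(" ", "") for _, idx in operands)
        spec += "->" + out_idx.replace(",", "").replace(" ", "")
        result = np.einsum(spec, *(scope[name] for name, _ in operands))
        caller.f_globals[out_name] = result  # bind the result in the caller
        return result
    finally:
        del caller  # avoid keeping the frame alive
```

The sketch targets globals because writing into a function's local variables through frame objects is not reliable on every Python version.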

>>> import jax.numpy as jnp
>>> from functools import partial
>>> jnp_einsum = partial(einsum, base_einsum_func=jnp.einsum)  # better_einsum for JAX
