Einstein summation
JAX implementation of numpy.einsum().
einsum is a powerful and generic API for computing various reductions, inner products, outer products, axis reorderings, and combinations thereof across one or more input arrays. It has a somewhat complicated overloaded API; the arguments below reflect the most common calling convention. The Examples section below demonstrates some of the alternative calling conventions.
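As a minimal sketch of how one subscripts string can combine several of these operations, the following computes a batched matrix product and reorders the output axes in a single call. The array names (lhs, rhs) and shapes are illustrative assumptions, not taken from the original documentation.

```python
import jax.numpy as jnp

# Illustrative inputs; the shapes are chosen arbitrarily for this sketch.
lhs = jnp.arange(24).reshape(2, 3, 4)   # batch of two 3x4 matrices
rhs = jnp.arange(40).reshape(2, 4, 5)   # batch of two 4x5 matrices

# Contract over j for each batch index b, and emit the result with the
# last two axes swapped (k before i).
combined = jnp.einsum('bij,bjk->bki', lhs, rhs)

# The same computation written as two separate steps.
reference = jnp.swapaxes(jnp.matmul(lhs, rhs), -1, -2)
```

Both expressions should produce an array of shape (2, 5, 3); the einsum form simply states the contraction and the axis order together.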
Parameters
subscripts – string containing axis names separated by commas.
*operands – sequence of one or more arrays corresponding to the subscripts.
optimize (str | bool | list[tuple[int, ...]]) – specify how to optimize the order of computation. In JAX this defaults to "auto", which produces optimized expressions via the opt_einsum package. Other options are True (same as "optimal"), False (unoptimized), or any string supported by opt_einsum, which includes "optimal", "greedy", "eager", and others. It may also be a pre-computed path (see einsum_path()).
precision (None | str | Precision | tuple[str, str] | tuple[Precision, Precision] | DotAlgorithm | DotAlgorithmPreset) – either None (default), which means the default precision for the backend, or a Precision enum value (Precision.DEFAULT, Precision.HIGH, or Precision.HIGHEST).
preferred_element_type (str | type[Any] | dtype | SupportsDType | None) – either None (default), which means the default accumulation type for the input types, or a datatype, indicating to accumulate results to and return a result with that datatype.
out (None) – unsupported by JAX.
_dot_general (Callable[[...], Array]) – optionally override the dot_general callable used by einsum. This parameter is experimental and may be removed without warning at any time.
Returns
An array containing the result of the Einstein summation.
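The optimize options described above can be sketched as follows. This is a small illustration under assumed inputs (the arrays a, b, c and their shapes are hypothetical); it only shows that the different settings select a contraction order and leave the result unchanged.

```python
import jax.numpy as jnp

# Illustrative operands for a three-matrix chain (shapes are arbitrary).
a = jnp.arange(6).reshape(2, 3)
b = jnp.arange(12).reshape(3, 4)
c = jnp.arange(8).reshape(4, 2)

# Default behaviour: optimize="auto" lets opt_einsum choose the order.
r_auto = jnp.einsum('ij,jk,kl->il', a, b, c)

# An explicit opt_einsum strategy, and path optimization turned off.
r_greedy = jnp.einsum('ij,jk,kl->il', a, b, c, optimize='greedy')
r_unopt = jnp.einsum('ij,jk,kl->il', a, b, c, optimize=False)

# Pre-compute a contraction path once with einsum_path() and reuse it.
path, info = jnp.einsum_path('ij,jk,kl->il', a, b, c, optimize='optimal')
r_path = jnp.einsum('ij,jk,kl->il', a, b, c, optimize=path)
```

For operands this small the choice makes no practical difference; a pre-computed path is mainly of interest when the same expression is evaluated repeatedly on large arrays.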
Examples
The mechanics of einsum are perhaps best demonstrated by example. Here we show how to use einsum to compute a number of quantities from one or more arrays. For more discussion and examples of einsum, see the documentation of numpy.einsum().
>>> M = jnp.arange(16).reshape(4, 4)
>>> x = jnp.arange(4)
>>> y = jnp.array([5, 4, 3, 2])
Vector product
>>> jnp.einsum('i,i', x, y)
Array(16, dtype=int32)
>>> jnp.vecdot(x, y)
Array(16, dtype=int32)
Here are some alternative einsum calling conventions to compute the same result:
>>> jnp.einsum('i,i->', x, y)  # explicit form
Array(16, dtype=int32)
>>> jnp.einsum(x, (0,), y, (0,))  # implicit form via indices
Array(16, dtype=int32)
>>> jnp.einsum(x, (0,), y, (0,), ())  # explicit form via indices
Array(16, dtype=int32)
Matrix product
>>> jnp.einsum('ij,j->i', M, x)  # explicit form
Array([14, 38, 62, 86], dtype=int32)
>>> jnp.matmul(M, x)
Array([14, 38, 62, 86], dtype=int32)
Here are some alternative einsum calling conventions to compute the same result:
>>> jnp.einsum('ij,j', M, x)  # implicit form
Array([14, 38, 62, 86], dtype=int32)
>>> jnp.einsum(M, (0, 1), x, (1,), (0,))  # explicit form via indices
Array([14, 38, 62, 86], dtype=int32)
>>> jnp.einsum(M, (0, 1), x, (1,))  # implicit form via indices
Array([14, 38, 62, 86], dtype=int32)
Outer product
>>> jnp.einsum("i,j->ij", x, y)
Array([[ 0,  0,  0,  0],
       [ 5,  4,  3,  2],
       [10,  8,  6,  4],
       [15, 12,  9,  6]], dtype=int32)
>>> jnp.outer(x, y)
Array([[ 0,  0,  0,  0],
       [ 5,  4,  3,  2],
       [10,  8,  6,  4],
       [15, 12,  9,  6]], dtype=int32)
Some other ways of computing outer products:
>>> jnp.einsum("i,j", x, y)  # implicit form
Array([[ 0,  0,  0,  0],
       [ 5,  4,  3,  2],
       [10,  8,  6,  4],
       [15, 12,  9,  6]], dtype=int32)
>>> jnp.einsum(x, (0,), y, (1,), (0, 1))  # explicit form via indices
Array([[ 0,  0,  0,  0],
       [ 5,  4,  3,  2],
       [10,  8,  6,  4],
       [15, 12,  9,  6]], dtype=int32)
>>> jnp.einsum(x, (0,), y, (1,))  # implicit form via indices
Array([[ 0,  0,  0,  0],
       [ 5,  4,  3,  2],
       [10,  8,  6,  4],
       [15, 12,  9,  6]], dtype=int32)
1D array sum
>>> jnp.einsum("i->", x)  # requires explicit form
Array(6, dtype=int32)
>>> jnp.einsum(x, (0,), ())  # explicit form via indices
Array(6, dtype=int32)
>>> jnp.sum(x)
Array(6, dtype=int32)
Sum along an axis
>>> jnp.einsum("...j->...", M)  # requires explicit form
Array([ 6, 22, 38, 54], dtype=int32)
>>> jnp.einsum(M, (..., 0), (...,))  # explicit form via indices
Array([ 6, 22, 38, 54], dtype=int32)
>>> M.sum(-1)
Array([ 6, 22, 38, 54], dtype=int32)
Matrix transpose
>>> y = jnp.array([[1, 2, 3],
...                [4, 5, 6]])
>>> jnp.einsum("ij->ji", y)  # explicit form
Array([[1, 4],
       [2, 5],
       [3, 6]], dtype=int32)
>>> jnp.einsum("ji", y)  # implicit form
Array([[1, 4],
       [2, 5],
       [3, 6]], dtype=int32)
>>> jnp.einsum(y, (1, 0))  # implicit form via indices
Array([[1, 4],
       [2, 5],
       [3, 6]], dtype=int32)
>>> jnp.einsum(y, (0, 1), (1, 0))  # explicit form via indices
Array([[1, 4],
       [2, 5],
       [3, 6]], dtype=int32)
>>> jnp.transpose(y)
Array([[1, 4],
       [2, 5],
       [3, 6]], dtype=int32)
Matrix diagonal
>>> jnp.einsum("ii->i", M)
Array([ 0,  5, 10, 15], dtype=int32)
>>> jnp.diagonal(M)
Array([ 0,  5, 10, 15], dtype=int32)
Matrix trace
>>> jnp.einsum("ii", M)
Array(30, dtype=int32)
>>> jnp.trace(M)
Array(30, dtype=int32)
Tensor products
>>> x = jnp.arange(30).reshape(2, 3, 5)
>>> y = jnp.arange(60).reshape(3, 4, 5)
>>> jnp.einsum('ijk,jlk->il', x, y)  # explicit form
Array([[ 3340,  3865,  4390,  4915],
       [ 8290,  9940, 11590, 13240]], dtype=int32)
>>> jnp.tensordot(x, y, axes=[(1, 2), (0, 2)])
Array([[ 3340,  3865,  4390,  4915],
       [ 8290,  9940, 11590, 13240]], dtype=int32)
>>> jnp.einsum('ijk,jlk', x, y)  # implicit form
Array([[ 3340,  3865,  4390,  4915],
       [ 8290,  9940, 11590, 13240]], dtype=int32)
>>> jnp.einsum(x, (0, 1, 2), y, (1, 3, 2), (0, 3))  # explicit form via indices
Array([[ 3340,  3865,  4390,  4915],
       [ 8290,  9940, 11590, 13240]], dtype=int32)
>>> jnp.einsum(x, (0, 1, 2), y, (1, 3, 2))  # implicit form via indices
Array([[ 3340,  3865,  4390,  4915],
       [ 8290,  9940, 11590, 13240]], dtype=int32)
Chained dot products
>>> w = jnp.arange(5, 9).reshape(2, 2)
>>> x = jnp.arange(6).reshape(2, 3)
>>> y = jnp.arange(-2, 4).reshape(3, 2)
>>> z = jnp.array([[2, 4, 6], [3, 5, 7]])
>>> jnp.einsum('ij,jk,kl,lm->im', w, x, y, z)
Array([[ 481,  831, 1181],
       [ 651, 1125, 1599]], dtype=int32)
>>> jnp.einsum(w, (0, 1), x, (1, 2), y, (2, 3), z, (3, 4))  # implicit, via indices
Array([[ 481,  831, 1181],
       [ 651, 1125, 1599]], dtype=int32)
>>> w @ x @ y @ z  # direct chain of matmuls
Array([[ 481,  831, 1181],
       [ 651, 1125, 1599]], dtype=int32)
>>> jnp.linalg.multi_dot([w, x, y, z])
Array([[ 481,  831, 1181],
       [ 651, 1125, 1599]], dtype=int32)
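The precision and preferred_element_type parameters listed among the arguments can be sketched in the same spirit. This is an illustrative example under assumed inputs (all array names and values here are hypothetical); the numerical effect of precision depends on the backend and may be invisible on CPU.

```python
import jax
import jax.numpy as jnp

# Illustrative float32 operands; values are chosen so the result is obvious.
a = jnp.ones((4, 8), dtype=jnp.float32)
b = jnp.ones((8, 2), dtype=jnp.float32)

# Request the highest available precision for the underlying contraction.
hi = jnp.einsum('ij,jk->ik', a, b, precision=jax.lax.Precision.HIGHEST)

# Accumulate and return in int32 even though both inputs are int8.
p = jnp.array([1, 2, 3], dtype=jnp.int8)
q = jnp.array([4, 5, 6], dtype=jnp.int8)
wide = jnp.einsum('i,i->', p, q, preferred_element_type=jnp.int32)
```

Here hi has shape (4, 2), and wide is returned with dtype int32 rather than the int8 of its inputs.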