Efficiently compute matrix products between a sequence of arrays.
JAX implementation of numpy.linalg.multi_dot().
JAX internally uses the opt_einsum library to compute the most efficient operation order.
Parameters:
arrays (Sequence[ArrayLike]) – sequence of arrays. All must be two-dimensional, except the first and last, which may be one-dimensional.
precision (lax.PrecisionLike) – either None (default), which means the default precision for the backend, or a Precision enum value (Precision.DEFAULT, Precision.HIGH or Precision.HIGHEST).
Returns:
an array representing the equivalent of reduce(jnp.matmul, arrays), but evaluated in the optimal order.
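For illustration, the first and last entries may be one-dimensional, and an explicit precision can be requested (a minimal sketch; the shapes here are arbitrary):

>>> import jax
>>> import jax.numpy as jnp
>>> v = jnp.ones(3)          # 1D first array, treated as a row vector
>>> m = jnp.ones((3, 4))
>>> w = jnp.ones(4)          # 1D last array, treated as a column vector
>>> jnp.linalg.multi_dot([v, m, w], precision=jax.lax.Precision.HIGHEST).shape
()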
This function exists because the cost of computing sequences of matmul operations can differ vastly depending on the order in which the operations are evaluated. For a single matmul, the number of floating point operations (flops) required to compute a matrix product can be approximated this way:
>>> def approx_flops(x, y):
...   # for 2D x and y, with x.shape[1] == y.shape[0]
...   return 2 * x.shape[0] * x.shape[1] * y.shape[1]
Suppose we have three matrices that we’d like to multiply in sequence:
>>> key1, key2, key3 = jax.random.split(jax.random.key(0), 3)
>>> x = jax.random.normal(key1, shape=(200, 5))
>>> y = jax.random.normal(key2, shape=(5, 100))
>>> z = jax.random.normal(key3, shape=(100, 10))
Because of the associativity of matrix products, there are two orders in which we might evaluate the product x @ y @ z, and both produce equivalent outputs up to floating point precision:
>>> result1 = (x @ y) @ z
>>> result2 = x @ (y @ z)
>>> jnp.allclose(result1, result2, atol=1E-4)
Array(True, dtype=bool)
But the computational cost of the two orderings differs greatly:
>>> print("(x @ y) @ z flops:", approx_flops(x, y) + approx_flops(x @ y, z)) (x @ y) @ z flops: 600000 >>> print("x @ (y @ z) flops:", approx_flops(y, z) + approx_flops(x, y @ z)) x @ (y @ z) flops: 30000
The second approach is about 20x more efficient in terms of estimated flops!
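For longer chains, the cheapest parenthesization can be found with the classic matrix-chain dynamic program. The sketch below (plain Python, with a hypothetical helper name chain_min_flops; it mirrors the kind of search multi_dot delegates to opt_einsum, but is not the actual implementation) computes the minimum estimated flops for a chain of 2D shapes:

>>> def chain_min_flops(shapes):
...   # shapes: [(rows, cols), ...] for a chain of 2D matrices,
...   # with shapes[i][1] == shapes[i + 1][0].
...   n = len(shapes)
...   dims = [shapes[0][0]] + [s[1] for s in shapes]  # boundary dimensions of the chain
...   cost = [[0.0] * n for _ in range(n)]  # cost[i][j]: cheapest flops for sub-chain i..j
...   for length in range(2, n + 1):
...     for i in range(n - length + 1):
...       j = i + length - 1
...       cost[i][j] = min(
...           cost[i][k] + cost[k + 1][j] + 2 * dims[i] * dims[k + 1] * dims[j + 1]
...           for k in range(i, j))
...   return cost[0][n - 1]
>>> chain_min_flops([(200, 5), (5, 100), (100, 10)])
30000.0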
multi_dot is a function that will automatically choose the fastest computational path for such problems:
>>> result3 = jnp.linalg.multi_dot([x, y, z])
>>> jnp.allclose(result1, result3, atol=1E-4)
Array(True, dtype=bool)
We can use JAX’s ahead-of-time lowering and compilation tools to estimate the total flops of each approach, and confirm that multi_dot is choosing the more efficient option:
>>> jax.jit(lambda x, y, z: (x @ y) @ z).lower(x, y, z).cost_analysis()['flops']
600000.0
>>> jax.jit(lambda x, y, z: x @ (y @ z)).lower(x, y, z).cost_analysis()['flops']
30000.0
>>> jax.jit(jnp.linalg.multi_dot).lower([x, y, z]).cost_analysis()['flops']
30000.0
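As a related cross-check, the same product can be written with jnp.einsum, which likewise relies on opt_einsum to choose a contraction order (a small sketch; the exact lowering, and hence the reported flop count, may differ from multi_dot):

>>> result4 = jnp.einsum('ij,jk,kl->il', x, y, z)
>>> jnp.allclose(result1, result4, atol=1E-4)
Array(True, dtype=bool)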