Compute the inner product of two arrays.

JAX implementation of numpy.inner(). Unlike jax.numpy.matmul() or jax.numpy.dot(), this always performs a contraction along the last dimension of each input.
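As an illustrative sketch of this contraction rule (the arrays here are assumed for illustration, not taken from the reference examples), for 2D inputs jnp.inner(a, b) should agree with jnp.matmul(a, b.T), since both contract the last axis of a with the last axis of b:

>>> import jax.numpy as jnp
>>> a = jnp.arange(6.).reshape(2, 3)
>>> b = jnp.arange(12.).reshape(4, 3)
>>> jnp.allclose(jnp.inner(a, b), jnp.matmul(a, b.T))
Array(True, dtype=bool)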
Parameters:
a (Array | ndarray | bool | number | bool | int | float | complex) – array of shape (..., N)
b (Array | ndarray | bool | number | bool | int | float | complex) – array of shape (..., N)
precision (None | str | Precision | tuple[str, str] | tuple[Precision, Precision] | DotAlgorithm | DotAlgorithmPreset) – either None (default), which means the default precision for the backend; a Precision enum value (Precision.DEFAULT, Precision.HIGH, or Precision.HIGHEST); or a tuple of two such values indicating the precision of a and b.
preferred_element_type (str | type[Any] | dtype | SupportsDType | None) – either None (default), which means the default accumulation type for the input types, or a datatype, indicating to accumulate results to and return a result with that datatype (see the usage sketch below).
Returns:
array of shape (*a.shape[:-1], *b.shape[:-1]) containing the batched vector product of the inputs.
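As a usage sketch for the precision and preferred_element_type arguments above (the bfloat16 inputs and the particular settings are assumptions chosen for illustration), one might request the highest matmul precision together with a float32 accumulation and result type:

>>> import jax
>>> import jax.numpy as jnp
>>> a = jnp.array([1., 2., 3.], dtype=jnp.bfloat16)
>>> b = jnp.array([4., 5., 6.], dtype=jnp.bfloat16)
>>> jnp.inner(a, b, precision=jax.lax.Precision.HIGHEST,
...           preferred_element_type=jnp.float32)
Array(32., dtype=float32)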
Examples
For 1D inputs, this implements standard (non-conjugate) vector multiplication:
>>> a = jnp.array([1j, 3j, 4j])
>>> b = jnp.array([4., 2., 5.])
>>> jnp.inner(a, b)
Array(0.+30.j, dtype=complex64)
For multi-dimensional inputs, batch dimensions are stacked rather than broadcast:
>>> a = jnp.ones((2, 3))
>>> b = jnp.ones((5, 3))
>>> jnp.inner(a, b).shape
(2, 5)
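The stacked result above can also be checked, as a rough equivalence sketch (assuming 2D inputs as in the example), against an explicit einsum that contracts the shared last axis while keeping one output index per input:

>>> a = jnp.ones((2, 3))
>>> b = jnp.ones((5, 3))
>>> jnp.allclose(jnp.inner(a, b), jnp.einsum('ik,jk->ij', a, b))
Array(True, dtype=bool)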