Cumulative product along the axis of an array.
JAX implementation of numpy.cumulative_prod().
Parameters:
x (ArrayLike) – N-dimensional array.
axis (int | None) – integer axis along which to accumulate. If x is one-dimensional, this argument is optional and defaults to zero.
dtype (DTypeLike | None) – optional dtype of the output.
include_initial (bool) – if True, include the initial value (the multiplicative identity, 1) at the start of the cumulative product along axis. Default is False.
Returns:
An array containing the accumulated values.
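To make the effect of include_initial=True concrete, here is a minimal sketch, assuming it is equivalent to prepending a slice of ones (the multiplicative identity) along axis to the result of jax.numpy.cumprod(). The helper name cumprod_with_initial is hypothetical and is not part of the JAX API.

```python
import jax.numpy as jnp


def cumprod_with_initial(x, axis):
    """Hypothetical helper (not a JAX API): emulate include_initial=True."""
    x = jnp.asarray(x)
    # Ordinary cumulative product along the requested axis.
    out = jnp.cumprod(x, axis=axis)
    # Build a slice of ones with length 1 along `axis`; 1 is the identity for products.
    shape = list(out.shape)
    shape[axis] = 1
    ones = jnp.ones(shape, dtype=out.dtype)
    # Prepend the identity, so the result is one element longer along `axis`.
    return jnp.concatenate([ones, out], axis=axis)


x = jnp.array([[1, 2, 3],
               [4, 5, 6]])
result = cumprod_with_initial(x, axis=1)
print(result)
```

For the 2×3 array used in the Examples section, this sketch should produce the same values as jnp.cumulative_prod(x, axis=1, include_initial=True), with the output growing from 3 to 4 columns.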
See also
jax.numpy.cumprod(): alternative API for cumulative product.
jax.numpy.nancumprod(): cumulative product while ignoring NaN values.
jax.numpy.multiply.accumulate(): cumulative product via the ufunc API.
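The difference between a plain cumulative product and jax.numpy.nancumprod() shows up as soon as the input contains a NaN. In this small sketch, a NaN propagates through every later position of jnp.cumprod(), while jnp.nancumprod() treats it as 1 and keeps accumulating.

```python
import jax.numpy as jnp

x = jnp.array([2.0, jnp.nan, 3.0])

# Plain cumulative product: once a NaN appears, every later entry is NaN.
prod = jnp.cumprod(x)

# NaN-ignoring variant: NaN entries are treated as 1, the multiplicative identity.
nan_prod = jnp.nancumprod(x)

print(prod)
print(nan_prod)
```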
Examples
>>> import jax.numpy as jnp
>>> x = jnp.array([[1, 2, 3],
...                [4, 5, 6]])
>>> jnp.cumulative_prod(x, axis=1)
Array([[  1,   2,   6],
       [  4,  20, 120]], dtype=int32)
>>> jnp.cumulative_prod(x, axis=1, include_initial=True)
Array([[  1,   1,   2,   6],
       [  1,   4,  20, 120]], dtype=int32)