Apply a function repeatedly over specified axes.
JAX implementation of numpy.apply_over_axes().
Parameters:
func (Callable[[ArrayLike, int], Array]) – the function to apply, with signature func(Array, int) -> Array, where y = func(x, axis) must satisfy y.ndim in [x.ndim, x.ndim - 1].
a (ArrayLike) – N-dimensional array over which to apply the function.
axes (Sequence[int]) – the sequence of axes over which to apply the function.
Returns: an N-dimensional array containing the result of the repeated function application.
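The repeated application described above can be sketched in plain Python on top of jax.numpy. This is a minimal sketch that assumes NumPy's documented semantics: func is called once per axis, in order, and if it drops the axis, the axis is restored with length 1 via expand_dims. The helper name `apply_over_axes_sketch` is hypothetical; this is not JAX's actual source.

```python
from typing import Callable, Sequence

import jax
import jax.numpy as jnp
from jax.typing import ArrayLike


def apply_over_axes_sketch(
    func: Callable[[jax.Array, int], jax.Array],
    a: ArrayLike,
    axes: Sequence[int],
) -> jax.Array:
    """Hypothetical re-implementation of the apply_over_axes semantics."""
    result = jnp.asarray(a)
    for axis in axes:
        # The rank never changes between steps, so normalize negative
        # axes against the current (= original) number of dimensions.
        if axis < 0:
            axis += result.ndim
        y = func(result, axis)
        if y.ndim == result.ndim:
            # func kept the axis (e.g. cumsum): use the result as-is.
            result = y
        elif y.ndim == result.ndim - 1:
            # func reduced the axis away: reinsert it with length 1.
            result = jnp.expand_dims(y, axis)
        else:
            raise ValueError("func must return an array with ndim x.ndim or x.ndim - 1")
    return result


x = jnp.array([[1, 2, 3],
               [4, 5, 6]])
print(apply_over_axes_sketch(jnp.sum, x, [0, 1]))  # a (1, 1) array holding 21
```

Restoring the dropped axis after every step is what keeps the rank fixed at N, so every entry in axes still refers to the same dimension of the input regardless of how many reductions came before it.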
Examples
This function is designed to have semantics similar to typical associative jax.numpy reductions over one or more axes with keepdims=True. For example:
>>> import jax.numpy as jnp
>>> x = jnp.array([[1, 2, 3],
...                [4, 5, 6]])
>>> jnp.apply_over_axes(jnp.sum, x, [0])
Array([[5, 7, 9]], dtype=int32)
>>> jnp.sum(x, [0], keepdims=True)
Array([[5, 7, 9]], dtype=int32)

>>> jnp.apply_over_axes(jnp.min, x, [1])
Array([[1],
       [4]], dtype=int32)
>>> jnp.min(x, [1], keepdims=True)
Array([[1],
       [4]], dtype=int32)

>>> jnp.apply_over_axes(jnp.prod, x, [0, 1])
Array([[720]], dtype=int32)
>>> jnp.prod(x, [0, 1], keepdims=True)
Array([[720]], dtype=int32)
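Two further cases follow from the constraint on func. A function that keeps the axis, such as jnp.cumsum, returns an array of the same rank, so each result is passed on unchanged to the next step. And the "associative" qualifier matters: for a reduction such as jnp.std, reducing one axis at a time generally differs from a single reduction over all axes. The sketch below uses only existing jax.numpy calls; the variable names are illustrative.

```python
import jax.numpy as jnp

x = jnp.array([[1, 2, 3],
               [4, 5, 6]])

# cumsum keeps ndim, so the steps simply chain: cumsum over axis 0, then axis 1.
chained = jnp.apply_over_axes(jnp.cumsum, x, [0, 1])
manual = jnp.cumsum(jnp.cumsum(x, 0), 1)
print(chained.tolist())  # [[1, 3, 6], [5, 12, 21]]

# std is not associative across axes: the column-wise stds are all 1.5,
# and the std of those identical values is 0.
per_axis = jnp.apply_over_axes(jnp.std, x, [0, 1])
# A single std over both axes uses all six values at once.
joint = jnp.std(x, (0, 1), keepdims=True)
print(per_axis.tolist())  # [[0.0]]
print(joint.tolist())     # approximately [[1.708]]
```

For associative reductions like sum, min, and prod the two approaches agree, which is why they can be used interchangeably with the keepdims=True form above.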