Put values into the destination array by matching 1d index and data slices.
JAX implementation of numpy.put_along_axis().
The semantics of numpy.put_along_axis() are to modify arrays in-place, which is not possible for JAX’s immutable arrays. The JAX version returns a modified copy of the input and adds the inplace parameter, which must be set to False by the user as a reminder of this API difference.
Parameters:
arr (Array | ndarray | bool | number | bool | int | float | complex) – array into which values will be put.
indices (Array | ndarray | bool | number | bool | int | float | complex) – array of indices at which to put values.
values (Array | ndarray | bool | number | bool | int | float | complex) – array of values to put into the array.
axis (int | None) – the axis along which to put values. If not specified, the array will be flattened before indexing is applied.
inplace (bool) – must be set to False to indicate that the input is not modified in-place, but rather a modified copy is returned.
mode (str | None) – Out-of-bounds indexing mode. For more discussion of mode options, see jax.numpy.ndarray.at.
Returns: a copy of arr with the specified entries updated.
Examples
>>> from jax import numpy as jnp
>>> a = jnp.array([[10, 30, 20], [60, 40, 50]])
>>> i = jnp.argmax(a, axis=1, keepdims=True)
>>> print(i)
[[1]
 [0]]
>>> b = jnp.put_along_axis(a, i, 99, axis=1, inplace=False)
>>> print(b)
[[10 99 20]
 [99 40 50]]