Return indices of nonzero elements in a flattened array
JAX implementation of numpy.flatnonzero().
jnp.flatnonzero(a) is equivalent to nonzero(ravel(a))[0]. For a full discussion of the parameters to this function, refer to jax.numpy.nonzero().
Parameters:
a (ArrayLike) – N-dimensional array.
size (int | None) – optional static integer specifying the number of nonzero entries to return. See jax.numpy.nonzero() for more discussion of this parameter.
fill_value (None | ArrayLike | tuple[ArrayLike, ...]) – optional padding value when size is specified. Defaults to 0. See jax.numpy.nonzero() for more discussion of this parameter.
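The size and fill_value parameters matter mainly under jax.jit, where the output shape must be known at trace time. The following is a minimal sketch, assuming a padding value of -1; the helper name padded_flat_indices is hypothetical and not part of JAX:

```python
import functools

import jax
import jax.numpy as jnp


# Hypothetical helper: `size` is marked static so it is a concrete Python int
# when jax.jit traces the function (flatnonzero needs it to fix the output shape).
@functools.partial(jax.jit, static_argnames="size")
def padded_flat_indices(a, size):
    # Returns exactly `size` indices into the flattened array.
    # Missing slots are filled with -1; nonzero entries beyond `size` are dropped.
    return jnp.flatnonzero(a, size=size, fill_value=-1)


x = jnp.array([[0, 5, 0],
               [6, 0, 8]])
print(padded_flat_indices(x, size=5))  # three real indices, then two -1 pads
print(padded_flat_indices(x, size=2))  # truncated to the first two indices
```

Calling jnp.flatnonzero(a) inside jax.jit without a static size raises an error, because the number of nonzero entries depends on the input values.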
Returns:
Array containing the indices of each nonzero value in the flattened array.
Examples
>>> x = jnp.array([[0, 5, 0],
...                [6, 0, 8]])
>>> jnp.flatnonzero(x)
Array([1, 3, 5], dtype=int32)
This is equivalent to calling nonzero() on the flattened array and extracting the first entry in the resulting tuple:
>>> jnp.nonzero(x.ravel())[0]
Array([1, 3, 5], dtype=int32)
The returned indices can be used to extract nonzero entries from the flattened array:
>>> indices = jnp.flatnonzero(x)
>>> x.ravel()[indices]
Array([5, 6, 8], dtype=int32)
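Because the returned indices refer to the flattened (row-major) array, they can also be mapped back to N-dimensional coordinates with jax.numpy.unravel_index(). The sketch below, using the same example array, checks that the recovered coordinates agree with jax.numpy.nonzero() applied to the original, unflattened array:

```python
import jax.numpy as jnp

x = jnp.array([[0, 5, 0],
               [6, 0, 8]])

flat = jnp.flatnonzero(x)                      # indices into x.ravel()
rows, cols = jnp.unravel_index(flat, x.shape)  # convert back to (row, col) pairs

# nonzero() on the 2-D array yields the same coordinates directly.
nz_rows, nz_cols = jnp.nonzero(x)
assert (rows == nz_rows).all() and (cols == nz_cols).all()
```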