Find the indices of nonzero array elements
JAX implementation of numpy.argwhere().
jnp.argwhere(x) is essentially equivalent to jnp.column_stack(jnp.nonzero(x)), with special handling for zero-dimensional (i.e. scalar) inputs.
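The equivalence can be checked directly. The sketch below assumes NumPy is installed alongside JAX (it normally is) and uses an arbitrary 3-D boolean mask, chosen only for illustration, to compare the two JAX spellings with each other and with numpy.argwhere().

```python
import jax.numpy as jnp
import numpy as np

# A 3-D boolean mask; True counts as nonzero. Only 0, 5, 10, 15 and 20 in
# arange(24) are multiples of 5, so the mask has 5 nonzero entries.
mask = jnp.arange(24).reshape(2, 3, 4) % 5 == 0

idx_argwhere = jnp.argwhere(mask)
idx_stack = jnp.column_stack(jnp.nonzero(mask))

print(idx_argwhere.shape)  # (5, 3): one row per nonzero entry, one column per axis
print(bool(jnp.array_equal(idx_argwhere, idx_stack)))  # True

# NumPy's argwhere yields the same indices in the same row-major order.
print(np.array_equal(np.asarray(idx_argwhere), np.argwhere(np.asarray(mask))))  # True
```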
Because the size of the output of argwhere is data-dependent, the function is generally not compatible with JIT or other JAX transformations. The JAX version adds the optional size argument, which specifies the size of the leading dimension of the output; it must be specified statically for jnp.argwhere to be compiled with non-static operands. See jax.numpy.nonzero() for a full discussion of size and its semantics.
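A minimal sketch of the size requirement under jax.jit(). Here first_two_nonzero is a hypothetical helper, and the error check assumes that tracing jnp.argwhere without size raises jax.errors.ConcretizationTypeError, which is how current JAX releases report a shape that depends on traced values.

```python
import jax
import jax.numpy as jnp

x = jnp.array([[1, 0, 2],
               [0, 3, 0]])  # 3 nonzero elements

# Hypothetical helper: a static size fixes the output shape at (2, arr.ndim),
# so the function can be traced and compiled.
@jax.jit
def first_two_nonzero(arr):
    return jnp.argwhere(arr, size=2)

result = first_two_nonzero(x)  # only the first 2 of the 3 indices are kept
print(result)

# Without size, the number of output rows depends on the values of x, which
# are unknown while tracing, so compilation fails.
try:
    jax.jit(jnp.argwhere)(x)
    raised = False
except jax.errors.ConcretizationTypeError:
    raised = True
print(raised)  # True
```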
Parameters:
a (ArrayLike) – array for which to find nonzero elements.
size (int | None) – optional integer specifying statically the number of expected nonzero elements. This must be specified in order to use argwhere within JAX transformations like jax.jit(). See jax.numpy.nonzero() for more information.
fill_value (ArrayLike | None) – optional array specifying the fill value when size is specified. See jax.numpy.nonzero() for more information.
Returns:
a two-dimensional array of shape [size, a.ndim]. If size is not specified as an argument, it is equal to the number of nonzero elements in a.
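The sketch below illustrates how size and fill_value shape the result on a small 2-D input with three nonzero elements (the input values are arbitrary). It relies on the documented jax.numpy.nonzero() behavior that surplus rows are truncated and missing rows are padded, with 0 as the default padding.

```python
import jax.numpy as jnp

x = jnp.array([[1, 0, 2],
               [0, 3, 0]])  # 3 nonzero elements, x.ndim == 2

# Without size: one row per nonzero element, one column per dimension.
full = jnp.argwhere(x)                           # shape (3, 2)

# size larger than the nonzero count: extra rows hold fill_value.
padded = jnp.argwhere(x, size=5, fill_value=-1)  # shape (5, 2)

# size smaller than the nonzero count: only the first `size` rows are kept.
truncated = jnp.argwhere(x, size=2)              # shape (2, 2)

# Without fill_value the padding is 0, which cannot be told apart from the
# genuine index (0, 0).
zero_padded = jnp.argwhere(x, size=5)

# The column count always matches the input's ndim, even with no nonzeros.
empty = jnp.argwhere(jnp.zeros((2, 2, 2)))       # shape (0, 3)

print(padded)
```

Choosing a fill_value that cannot occur as a real index, such as -1, makes the padded rows easy to detect afterwards.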
Examples
Two-dimensional array:
>>> x = jnp.array([[1, 0, 2],
...                [0, 3, 0]])
>>> jnp.argwhere(x)
Array([[0, 0],
       [0, 2],
       [1, 1]], dtype=int32)
Equivalent computation using jax.numpy.column_stack() and jax.numpy.nonzero():
>>> jnp.column_stack(jnp.nonzero(x))
Array([[0, 0],
       [0, 2],
       [1, 1]], dtype=int32)
Special case for zero-dimensional (i.e. scalar) inputs:
>>> jnp.argwhere(1)
Array([], shape=(1, 0), dtype=int32)
>>> jnp.argwhere(0)
Array([], shape=(0, 0), dtype=int32)
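A common follow-up is to use the returned rows to gather the nonzero values themselves. This is a sketch of one way to do that with standard advanced indexing; it is not part of the argwhere API, and the input is arbitrary.

```python
import jax.numpy as jnp

x = jnp.array([[1, 0, 2],
               [0, 3, 0]])
idx = jnp.argwhere(x)  # shape (3, 2)

# Transposing gives one index array per axis; a tuple of those arrays is an
# advanced index that selects the nonzero entries in row-major order.
values = x[tuple(idx.T)]
print(values)  # [1 2 3]
```

Rows padded via size still hold valid-looking indices (0 by default), so they would gather real elements; mask them out if size may exceed the nonzero count.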