Return indices of nonzero elements of an array.
JAX implementation of numpy.nonzero().
Because the size of the output of nonzero is data-dependent, the function is not compatible with JIT and other transformations. The JAX version adds the optional size argument, which must be specified statically for jnp.nonzero to be used within JAX’s transformations.
Parameters
a (ArrayLike) – N-dimensional array.
size (int | None) – optional static integer specifying the number of nonzero entries to return. If there are more nonzero elements than the specified size, then indices will be truncated at the end. If there are fewer nonzero elements than the specified size, then indices will be padded with fill_value, which defaults to zero.
fill_value (ArrayLike | tuple[ArrayLike, ...] | None) – optional padding value when size is specified. Defaults to 0.
Returns
Tuple of JAX Arrays of length a.ndim, containing the indices of each nonzero value.
Examples
One-dimensional array returns a length-1 tuple of indices:
>>> x = jnp.array([0, 5, 0, 6, 0, 7])
>>> jnp.nonzero(x)
(Array([1, 3, 5], dtype=int32),)
Two-dimensional array returns a length-2 tuple of indices:
>>> x = jnp.array([[0, 5, 0],
...                [6, 0, 7]])
>>> jnp.nonzero(x)
(Array([0, 1, 1], dtype=int32), Array([1, 0, 2], dtype=int32))
In either case, the resulting tuple of indices can be used directly to extract the nonzero values:
>>> indices = jnp.nonzero(x)
>>> x[indices]
Array([5, 6, 7], dtype=int32)
The output of nonzero has a dynamic shape, because the number of returned indices depends on the contents of the input array. As such, it is incompatible with JIT and other JAX transformations:
>>> x = jnp.array([0, 5, 0, 6, 0, 7])
>>> jax.jit(jnp.nonzero)(x)
Traceback (most recent call last):
  ...
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[].
The size argument of jnp.nonzero must be statically specified to use jnp.nonzero within JAX transformations.
This can be addressed by passing a static size parameter to specify the desired output shape:
>>> nonzero_jit = jax.jit(jnp.nonzero, static_argnames='size')
>>> nonzero_jit(x, size=3)
(Array([1, 3, 5], dtype=int32),)
If size does not match the true size, the result will be either truncated or padded:
>>> nonzero_jit(x, size=2)  # size < 3: indices are truncated
(Array([1, 3], dtype=int32),)
>>> nonzero_jit(x, size=5)  # size > 3: indices are padded with zeros.
(Array([1, 3, 5, 0, 0], dtype=int32),)
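Because the default padding value 0 is also a valid index, padded entries cannot be told apart from a genuine nonzero element at position 0 by looking at the result alone. One possible way to recover that information, sketched here, is to build a validity mask from jnp.count_nonzero. The helper name nonzero_with_mask is hypothetical and not part of JAX.

```python
import jax
import jax.numpy as jnp

def nonzero_with_mask(x, size):
    # Hypothetical helper (not a JAX API): returns padded indices for a
    # 1-D array plus a boolean mask marking which entries are real.
    (idx,) = jnp.nonzero(x, size=size)
    # The first count_nonzero(x) slots hold real indices; the rest are padding.
    valid = jnp.arange(size) < jnp.count_nonzero(x)
    return idx, valid

# size determines the output shape, so it is marked static for jit.
nonzero_with_mask_jit = jax.jit(nonzero_with_mask, static_argnames="size")

x = jnp.array([0, 5, 0, 6, 0, 7])
idx, valid = nonzero_with_mask_jit(x, size=5)
print(idx)    # padded indices
print(valid)  # True for real entries, False for padding
```

The mask has the same static shape as the index array, so it can be carried through later jitted computations, for example with jnp.where.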
You can specify a custom fill value for the padding using the fill_value argument:
>>> nonzero_jit(x, size=5, fill_value=len(x))
(Array([1, 3, 5, 6, 6], dtype=int32),)
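Since fill_value also accepts a tuple, a multi-dimensional input can use a different padding value for each axis. The sketch below assumes it is convenient to pad with the array's own shape, so that every padded index is out of bounds, and then gathers values with the at[...].get indexing helper in fill mode. The variable names are illustrative only.

```python
import jax.numpy as jnp

x2d = jnp.array([[0, 5, 0],
                 [6, 0, 7]])

# One fill value per axis: pad row indices with 2 and column indices with 3,
# i.e. one past the last valid index along each dimension.
rows, cols = jnp.nonzero(x2d, size=5, fill_value=x2d.shape)

# Gather the values; out-of-bounds (padded) positions yield the fill value 0.
values = x2d.at[rows, cols].get(mode="fill", fill_value=0)

print(rows)
print(cols)
print(values)
```

Choosing an out-of-bounds padding value keeps padded slots distinguishable from real indices, at the cost of requiring an indexing mode that tolerates them.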