Select elements from two arrays based on a condition.
JAX implementation of numpy.where().
Note
When only condition is provided, jnp.where(condition) is equivalent to jnp.nonzero(condition). For that case, refer to the documentation of jax.numpy.nonzero(). The docstring below focuses on the case where x and y are specified.
The three-term version of jnp.where lowers to jax.lax.select().
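The relationship to jax.lax.select() can be sketched as follows. This is a minimal illustration, not the library's internal code path: it assumes that jax.lax.select() needs operands of matching shape and dtype, so the broadcasting and casting that jnp.where performs automatically are done by hand here. The variable names are illustrative only.

```python
import jax.numpy as jnp
from jax import lax

cond = jnp.array([True, False, True])
x = jnp.array([1.0, 2.0, 3.0])
y = 0.0  # scalar; jnp.where broadcasts it against x

# High-level call: broadcasting and dtype promotion are handled for us.
via_where = jnp.where(cond, x, y)

# Lower-level call: lax.select does no implicit broadcasting or promotion,
# so y is cast to x's dtype and expanded to x's shape first.
y_full = jnp.broadcast_to(jnp.asarray(y, dtype=x.dtype), x.shape)
via_select = lax.select(cond, x, y_full)
```

Both calls should produce the same array; the jnp.where form is usually preferable because it accepts scalars and mixed dtypes directly.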
Parameters
condition – boolean array. Must be broadcast-compatible with x and y when they are specified.
x – arraylike. Should be broadcast-compatible with condition and y, and typecast-compatible with y.
y – arraylike. Should be broadcast-compatible with condition and x, and typecast-compatible with x.
size – integer, only referenced when x and y are None. For details, see jax.numpy.nonzero().
fill_value – only referenced when x and y are None. For details, see jax.numpy.nonzero().
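As a rough illustration of these parameters, the sketch below broadcasts a column-shaped condition against a row-shaped x and a scalar y, then uses size and fill_value in the one-argument form. The array values are made up for the example.

```python
import jax.numpy as jnp

# Three-argument form: condition, x and y are broadcast together.
cond = jnp.array([[True], [False]])  # shape (2, 1)
x = jnp.array([1, 2, 3])             # shape (3,)
y = -1                               # scalar
out = jnp.where(cond, x, y)          # shape (2, 3): row 0 from x, row 1 from y

# One-argument form: size fixes the output length (useful under jax.jit),
# and fill_value pads the result when fewer than `size` entries match.
idx = jnp.arange(10)
(padded,) = jnp.where(idx > 4, size=7, fill_value=-1)
```

Here padded holds the five matching indices followed by two padding entries equal to fill_value.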
Returns
An array of dtype jnp.result_type(x, y) with values drawn from x where condition is True, and from y where condition is False. If x and y are None, the function behaves differently; see jax.numpy.nonzero() for a description of the return type.
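A small check of the return dtype, assuming JAX's default type promotion and made-up input values: mixing an int32 x with a float32 y should yield a float32 result, matching jnp.result_type(x, y).

```python
import jax.numpy as jnp

cond = jnp.array([True, False, True])
x = jnp.array([1, 2, 3], dtype=jnp.int32)
y = jnp.array([0.5, 1.5, 2.5], dtype=jnp.float32)

out = jnp.where(cond, x, y)

# The output dtype follows the promotion of x and y, not the dtype of
# whichever branch happens to be selected.
expected_dtype = jnp.result_type(x, y)
```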
Notes
Special care is needed when the x or y input to jax.numpy.where() could have a value of NaN. Specifically, when a gradient is taken with jax.grad() (reverse-mode differentiation), a NaN in either x or y will propagate into the gradient, regardless of the value of condition. More information on this behavior and workarounds is available in the JAX FAQ.
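One common workaround is to apply where twice, so the unselected branch never evaluates at a problematic input. The sketch below uses a guarded logarithm as the example; naive_log and safe_log are illustrative names, not part of JAX.

```python
import jax
import jax.numpy as jnp

def naive_log(x):
    # jnp.log(x) is still evaluated (and differentiated) at x <= 0,
    # even though its value is discarded there.
    return jnp.where(x > 0.0, jnp.log(x), 0.0)

def safe_log(x):
    # Inner where: replace inputs outside the valid domain with a harmless
    # placeholder so the log branch never sees them.
    safe_x = jnp.where(x > 0.0, x, 1.0)
    # Outer where: select the intended output as before.
    return jnp.where(x > 0.0, jnp.log(safe_x), 0.0)

naive_grad = jax.grad(naive_log)(0.0)  # NaN propagates from the unused branch
safe_grad = jax.grad(safe_log)(0.0)    # finite gradient
```

The forward values of the two functions agree; only the gradient at the boundary differs.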
Examples
When x and y are not provided, where behaves equivalently to jax.numpy.nonzero():
>>> x = jnp.arange(10)
>>> jnp.where(x > 4)
(Array([5, 6, 7, 8, 9], dtype=int32),)
>>> jnp.nonzero(x > 4)
(Array([5, 6, 7, 8, 9], dtype=int32),)
When x and y are provided, where selects between them based on the specified condition:
>>> jnp.where(x > 4, x, 0)
Array([0, 0, 0, 0, 0, 5, 6, 7, 8, 9], dtype=int32)