Return the elements of an array that satisfy a condition.
JAX implementation of numpy.extract().
Parameters
condition (ArrayLike) – array of conditions. Will be converted to boolean and flattened to 1D.
arr (ArrayLike) – array of values to extract. Will be flattened to 1D.
size (int | None) – optional static size for output. Must be specified in order for extract to be compatible with JAX transformations like jit() or vmap().
fill_value (ArrayLike) – if size is specified, fill padded entries with this value (default: 0).
Returns
1D array of extracted entries. If size is specified, the result will have shape (size,) and be right-padded with fill_value. If size is not specified, the output shape will depend on the number of True entries in condition.
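The parameter descriptions above say that both inputs are flattened and that condition is converted to boolean. The sketch below illustrates both points on 2D inputs, assuming row-major (C-order) flattening as with jnp.ravel; the arrays are made up for illustration, and size is kept no larger than the flattened input.

```python
import jax.numpy as jnp

# 2D values and a 2D integer "condition" with the same shape (illustrative data).
arr = jnp.array([[10, 20, 30],
                 [40, 50, 60]])
cond = jnp.array([[0, 2, 0],
                  [1, 0, -1]])  # any nonzero entry is treated as True

# Both inputs are flattened row by row before selection:
#   cond -> [0, 2, 0, 1, 0, -1] -> [False, True, False, True, False, True]
#   arr  -> [10, 20, 30, 40, 50, 60]
out = jnp.extract(cond, arr)
print(out.tolist())  # [20, 40, 60]

# With size set, the result has the static shape (size,) and is right-padded.
padded = jnp.extract(cond, arr, size=5, fill_value=-1)
print(padded.tolist())  # [20, 40, 60, -1, -1]
```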
Notes
This function does not require strict shape agreement between condition and arr. If condition.size > arr.size, then condition will be truncated, and if arr.size > condition.size, then arr will be truncated.
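One way to make the truncation rule concrete is an eager reference written with slicing and boolean indexing. The helper extract_truncated below is hypothetical: it is not part of JAX and is not how jnp.extract is implemented, and because it relies on boolean indexing it only works outside jit. It is checked here only against the mask shapes used in the examples below.

```python
import jax.numpy as jnp

def extract_truncated(condition, arr):
    """Hypothetical eager reference for the truncation rule (not JAX's implementation)."""
    cond = jnp.ravel(jnp.asarray(condition)).astype(bool)  # convert to boolean, flatten
    vals = jnp.ravel(jnp.asarray(arr))                     # flatten the values
    n = min(cond.size, vals.size)                          # keep only the overlapping prefix
    # Boolean indexing needs concrete values, so this helper cannot run under jit.
    return vals[:n][cond[:n]]

x = jnp.array([1, 2, 3, 4, 5, 6])
short_mask = jnp.array([False, True])
long_mask = jnp.array([True, False, True, False, False, False, False, False])

print(extract_truncated(short_mask, x).tolist())  # [2]
print(extract_truncated(long_mask, x).tolist())   # [1, 3]
```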
Examples
Extract values from a 1D array:
>>> x = jnp.array([1, 2, 3, 4, 5, 6])
>>> mask = (x % 2 == 0)
>>> jnp.extract(mask, x)
Array([2, 4, 6], dtype=int32)
In the simplest case, this is equivalent to boolean indexing:
>>> x[mask]
Array([2, 4, 6], dtype=int32)
For use with JAX transformations, you can pass the size argument to specify a static shape for the output, along with an optional fill_value that defaults to zero:
>>> jnp.extract(mask, x, size=len(x), fill_value=0)
Array([2, 4, 6, 0, 0, 0], dtype=int32)
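As a sketch of what the static size enables: under jax.jit, boolean indexing like x[mask] is not traceable because its output shape depends on the data, whereas jnp.extract with a fixed size produces a shape known at trace time. The function name evens_padded and the batch data are illustrative only, and size is kept equal to the input length in both cases.

```python
import jax
import jax.numpy as jnp

@jax.jit
def evens_padded(x):
    # size comes from the static shape of x, so the output shape (len(x),)
    # is known when tracing; slots without a match hold fill_value.
    return jnp.extract(x % 2 == 0, x, size=x.shape[0], fill_value=-1)

x = jnp.array([1, 2, 3, 4, 5, 6])
print(evens_padded(x).tolist())  # [2, 4, 6, -1, -1, -1]

# Under vmap, every row yields an output of the same static length (size=4).
batch = jnp.array([[1, 2, 3, 4],
                   [5, 7, 8, 9]])
rows = jax.vmap(lambda row: jnp.extract(row > 3, row, size=4))(batch)
print(rows.tolist())  # [[4, 0, 0, 0], [5, 7, 8, 9]]
```

Fixing size ahead of time trades memory for traceability: the padded output is always size elements long, so choose the smallest bound that is guaranteed to hold every match.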
Note that, unlike boolean indexing, extract does not require strict agreement between the sizes of the array and condition, and will effectively truncate both to the smaller of the two sizes:
>>> short_mask = jnp.array([False, True])
>>> jnp.extract(short_mask, x)
Array([2], dtype=int32)
>>> long_mask = jnp.array([True, False, True, False, False, False, False, False])
>>> jnp.extract(long_mask, x)
Array([1, 3], dtype=int32)