Returns indices that partially sort an array.
JAX implementation of numpy.argpartition(). The JAX version differs from NumPy in the treatment of NaN entries: NaNs which have the sign bit set are sorted to the beginning of the array.
Returns: indices which partition a at the kth value along axis. The entries before kth are indices of values no larger than the pivot value, i.e. the kth-smallest value along axis (equivalently take(sort(a, axis), kth, axis)), and the entries after kth are indices of values no smaller than the pivot value.
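The contract above can be checked mechanically. The sketch below uses a hypothetical helper, is_valid_partition, written in plain NumPy so that it also accepts JAX arrays (np.asarray converts them). It assumes a 1-D input without NaNs and checks only the stated properties, not any particular index order.

```python
import numpy as np


def is_valid_partition(a, idx, kth):
    """Hypothetical helper: check the argpartition contract for a 1-D array.

    Assumes `a` contains no NaNs (NaN never compares equal to itself).
    """
    a = np.asarray(a)
    idx = np.asarray(idx)
    # idx must be a permutation of 0..n-1.
    if not np.array_equal(np.sort(idx), np.arange(a.shape[0])):
        return False
    part = a[idx]
    pivot = part[kth]
    # The value at position kth must be the kth-smallest value of `a` ...
    if pivot != np.sort(a)[kth]:
        return False
    # ... with no larger values before it and no smaller values after it.
    return bool(np.all(part[:kth] <= pivot) and np.all(part[kth + 1:] >= pivot))
```

Ties with the pivot may fall on either side, which is why the comparisons use <= and >= rather than strict inequalities.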
Note
The JAX version requires the kth argument to be a static integer rather than a general array. This is implemented via two calls to jax.lax.top_k(). If you're only accessing the top or bottom k values of the output, it may be more efficient to call jax.lax.top_k() directly.
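To illustrate the note, here is a minimal sketch of a 1-D argpartition built from two jax.lax.top_k() calls, plus a helper that gets the k smallest values with a single call. Both function names are hypothetical and this is not JAX's actual source; the sketch assumes a signed-integer or floating-point input without NaNs, since it relies on negating x.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax


@partial(jax.jit, static_argnames="kth")
def argpartition_1d(x, kth):
    """Hypothetical sketch: argpartition of a 1-D array via two top_k calls.

    Assumes x is signed-integer or float (negation) and contains no NaNs.
    """
    n = x.shape[0]
    # top_k(-x, kth + 1) returns the indices of the kth + 1 smallest values,
    # ordered from smallest to largest, so the pivot lands at position kth.
    bottom = lax.top_k(-x, kth + 1)[1]
    # Calling top_k on x itself could re-select indices that tie with the
    # pivot. Instead, mark already-chosen positions with 0 and the rest with
    # 1, so the second top_k picks exactly the remaining n - kth - 1 indices.
    proxy = jnp.ones(n).at[bottom].set(0.0)
    top = lax.top_k(proxy, n - kth - 1)[1]
    return jnp.concatenate([bottom, top])


@partial(jax.jit, static_argnames="k")
def smallest_k(x, k):
    """Hypothetical helper: the k smallest values and their indices."""
    # The largest entries of -x are the smallest entries of x.
    neg_values, indices = lax.top_k(-x, k)
    return -neg_values, indices
```

Because the k argument of jax.lax.top_k() determines the shape of its outputs, it must be known at trace time. That is why the sketch marks kth (and k) as static under jax.jit, mirroring the requirement that kth be a static integer.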
Examples
>>> import jax.numpy as jnp
>>> x = jnp.array([6, 8, 4, 3, 1, 9, 7, 5, 2, 3])
>>> kth = 4
>>> idx = jnp.argpartition(x, kth)
>>> idx
Array([4, 8, 3, 9, 2, 0, 1, 5, 6, 7], dtype=int32)
The result is a sequence of indices that partially sort the input. In this example, all indices before kth point to values smaller than the pivot value, and all indices after kth point to values larger than the pivot value:
>>> x_partitioned = x[idx]
>>> smallest_values = x_partitioned[:kth]
>>> pivot_value = x_partitioned[kth]
>>> largest_values = x_partitioned[kth + 1:]
>>> print(smallest_values, pivot_value, largest_values)
[1 2 3 3] 4 [6 8 9 7 5]
Notice that within smallest_values and largest_values, the returned order is arbitrary and implementation-dependent.
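Because the order within each side is unspecified, a robust way to compare an argpartition result against a reference is to sort each side first. The sketch below does this with a hypothetical helper, matches_full_sort: for any valid partition, the sorted left side, the pivot, and the sorted right side together reproduce jnp.sort(x).

```python
import jax.numpy as jnp


def matches_full_sort(x, kth):
    """Hypothetical helper: compare jnp.argpartition against a full sort.

    Sorting each side removes the arbitrary within-side order before comparing.
    """
    idx = jnp.argpartition(x, kth)
    part = x[idx]
    ref = jnp.sort(x)
    left_ok = bool(jnp.array_equal(jnp.sort(part[:kth]), ref[:kth]))
    pivot_ok = bool(part[kth] == ref[kth])
    right_ok = bool(jnp.array_equal(jnp.sort(part[kth + 1:]), ref[kth + 1:]))
    return left_ok and pivot_ok and right_ok
```

For the example above, matches_full_sort(x, kth) returns True even though smallest_values and largest_values are not themselves sorted.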