Returns a partially-sorted copy of an array.
JAX implementation of numpy.partition(). The JAX version differs from NumPy in the treatment of NaN entries: NaNs with the sign bit set are sorted to the beginning of the array.
Returns
A copy of a partitioned at the kth value along axis. The entries before kth are values less than or equal to take(a, kth, axis), and the entries after kth are values greater than or equal to take(a, kth, axis).
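A minimal sketch of the property described above, applied along a non-default axis. The array values and the choice axis=0 are illustrative assumptions. The checks use only comparisons against the pivot, so they hold regardless of the arbitrary order within each side.

```python
import jax.numpy as jnp

# Illustrative 2-D input (values chosen arbitrarily for this sketch).
a = jnp.array([[5, 1, 9],
               [2, 8, 3],
               [7, 4, 6],
               [0, 3, 2]])
kth = 1

# Partition each column (axis=0) around its kth-smallest entry.
p = jnp.partition(a, kth, axis=0)

pivot = jnp.take(p, kth, axis=0)  # one pivot per column, shape (3,)
before = p[:kth]                  # rows above the pivot row
after = p[kth + 1:]               # rows below the pivot row

# Every entry before kth is <= its column's pivot; every entry after is >=.
print(bool(jnp.all(before <= pivot)), bool(jnp.all(after >= pivot)))
# True True

# The pivot equals the kth entry of each fully sorted column.
print(bool(jnp.all(pivot == jnp.sort(a, axis=0)[kth])))
# True
```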
Note
The JAX version requires the kth argument to be a static integer rather than a general array. The function is implemented via two calls to jax.lax.top_k(). If you’re only accessing the top or bottom k values of the output, it may be more efficient to call jax.lax.top_k() directly.
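The Note above says the JAX version is built from two calls to jax.lax.top_k(). The sketch below shows one way such a scheme can work on a 1-D array; the helper name partition_via_top_k is hypothetical, and this is an illustration rather than JAX's actual source. It assumes a signed integer or floating-point input, because it negates the array to turn "largest" into "smallest" (negation wraps for unsigned integers and overflows for the minimum signed value). kth is marked static, mirroring the static-integer requirement, because it fixes the k passed to each top_k call and therefore the output shapes.

```python
import functools

import jax
import jax.numpy as jnp


# Hypothetical helper: a two-top_k partition on a 1-D array.
# kth must be static because top_k needs a concrete k to know its output shape.
@functools.partial(jax.jit, static_argnames="kth")
def partition_via_top_k(x, kth):
    n = x.shape[-1]
    # The kth + 1 smallest values: top_k of -x picks the largest entries of -x,
    # and negating back yields them in ascending order.
    bottom = -jax.lax.top_k(-x, kth + 1)[0]
    # The remaining n - kth - 1 largest values, in descending order.
    top = jax.lax.top_k(x, n - kth - 1)[0]
    return jnp.concatenate([bottom, top])


x = jnp.array([6, 8, 4, 3, 1, 9, 7, 5, 2, 3])
print(partition_via_top_k(x, 4))
# [1 2 3 3 4 9 8 7 6 5]

# If only the 3 smallest values are needed, a single top_k call suffices:
print(-jax.lax.top_k(-x, 3)[0])
# [1 2 3]
```

The second call illustrates the Note's advice: when only one end of the output matters, one top_k call does half the work of a full partition.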
Examples
>>> x = jnp.array([6, 8, 4, 3, 1, 9, 7, 5, 2, 3])
>>> kth = 4
>>> x_partitioned = jnp.partition(x, kth)
>>> x_partitioned
Array([1, 2, 3, 3, 4, 9, 8, 7, 6, 5], dtype=int32)
The result is a partially-sorted copy of the input. All values before kth are smaller than the pivot value, and all values after kth are larger than the pivot value:
>>> smallest_values = x_partitioned[:kth]
>>> pivot_value = x_partitioned[kth]
>>> largest_values = x_partitioned[kth + 1:]
>>> print(smallest_values, pivot_value, largest_values)
[1 2 3 3] 4 [9 8 7 6 5]
Notice that within smallest_values and within largest_values, the order of the entries is arbitrary and implementation-dependent.
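Because the order within each side is implementation-dependent, a test that compares jnp.partition with numpy.partition should compare each side as a multiset rather than element by element. A minimal sketch, assuming NumPy is installed alongside JAX; the input reuses the array from the example above.

```python
import numpy as np
import jax.numpy as jnp

x = np.array([6, 8, 4, 3, 1, 9, 7, 5, 2, 3])
kth = 4

jax_out = np.asarray(jnp.partition(jnp.asarray(x), kth))
np_out = np.partition(x, kth)

# The pivot itself must agree: it is the kth-smallest value in both.
assert jax_out[kth] == np_out[kth]
# Each side holds the same values, possibly in a different order,
# so sort both sides before comparing.
assert np.array_equal(np.sort(jax_out[:kth]), np.sort(np_out[:kth]))
assert np.array_equal(np.sort(jax_out[kth + 1:]), np.sort(np_out[kth + 1:]))

# If a deterministic order is needed, sort only the slice you care about.
print(np.sort(jax_out[:kth]))
# [1 2 3 3]
```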