Return indices that sort an array.
JAX implementation of numpy.argsort().
a (Array | ndarray | bool | number | int | float | complex) – array to sort
axis (int | None) – integer axis along which to sort. Defaults to -1, i.e. the last axis. If None, then a is flattened before being sorted.
stable (bool) – boolean specifying whether a stable sort should be used. Default=True.
descending (bool) – boolean specifying whether to sort in descending order. Default=False.
kind (None) – deprecated; instead specify sort algorithm using stable=True or stable=False.
order (None) – not supported by JAX
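The stable and descending parameters above can be illustrated with a short sketch. It assumes the same input array used in the examples on this page; the variable names asc, desc, and unstable are illustrative only. Tie order is shown only for the default stable ascending case, since with stable=False the relative order of equal elements should not be relied on.

```python
import jax.numpy as jnp

x = jnp.array([1, 3, 5, 4, 2, 1])

# Default behaviour: ascending and stable, so the two 1s
# (positions 0 and 5) keep their original relative order.
asc = jnp.argsort(x)

# Descending order: gathering with these indices yields values largest-first.
desc = jnp.argsort(x, descending=True)

# stable=False permits a non-stable sort; the gathered values are still
# sorted, but the order of indices for tied values is not guaranteed.
unstable = jnp.argsort(x, stable=False)

print(asc)          # indices for ascending order
print(x[desc])      # values in descending order
print(x[unstable])  # values in ascending order
```

Comparing the gathered values rather than the raw indices is the safer check whenever the input may contain ties.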
Array of indices that sort an array. Returned array will be of shape a.shape (if axis is an integer) or of shape (a.size,) (if axis is None).
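The two return shapes described above can be checked with a minimal sketch. The array is an arbitrary 2x3 example, and the use of jnp.unravel_index to map flat indices back to row and column positions is an illustrative addition, not part of the argsort API.

```python
import jax.numpy as jnp

x = jnp.array([[2, 1, 3],
               [6, 4, 3]])

# Integer axis: the result has the same shape as the input.
per_column = jnp.argsort(x, axis=0)

# axis=None: the input is flattened first, so the result has shape (x.size,)
# and the indices refer to positions in the flattened array.
flat = jnp.argsort(x, axis=None)
sorted_flat = x.ravel()[flat]

# Optional: convert flat indices back to (row, column) coordinates.
rows, cols = jnp.unravel_index(flat, x.shape)

print(per_column.shape)  # same as x.shape
print(flat.shape)        # (x.size,)
print(sorted_flat)
```

Indexing x[rows, cols] gives the same sorted values as indexing the flattened array directly.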
Examples
Simple 1-dimensional sort:
>>> import jax.numpy as jnp
>>> x = jnp.array([1, 3, 5, 4, 2, 1])
>>> indices = jnp.argsort(x)
>>> indices
Array([0, 5, 4, 1, 3, 2], dtype=int32)
>>> x[indices]
Array([1, 1, 2, 3, 4, 5], dtype=int32)
Sort along the last axis of an array:
>>> x = jnp.array([[2, 1, 3],
...                [6, 4, 3]])
>>> indices = jnp.argsort(x, axis=1)
>>> indices
Array([[1, 0, 2],
       [2, 1, 0]], dtype=int32)
>>> jnp.take_along_axis(x, indices, axis=1)
Array([[1, 2, 3],
       [3, 4, 6]], dtype=int32)