Return a sorted copy of an array.
JAX implementation of numpy.sort().
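Because this function mirrors numpy.sort(), its result can be checked directly against NumPy. The sketch below assumes jax and numpy are installed, and the input values are made up for illustration. It also shows that the input array is left unchanged, since JAX arrays are immutable and a new array is returned, and that the function can be traced under jax.jit.

```python
import jax
import jax.numpy as jnp
import numpy as np

x_np = np.array([3, 1, 2, 5, 4])
x = jnp.asarray(x_np)

# jnp.sort returns a new array; x itself is not modified.
y = jnp.sort(x)

# The sorted values should match NumPy's result element for element.
np.testing.assert_array_equal(np.asarray(y), np.sort(x_np))

# jnp.sort can also be called inside a jit-compiled function.
sort_jit = jax.jit(jnp.sort)
y_jit = sort_jit(x)
```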
Parameters:
a (Array | ndarray | bool | number | bool | int | float | complex) – array to sort
axis (int | None) – integer axis along which to sort. Defaults to -1, i.e. the last axis. If None, then a is flattened before being sorted.
stable (bool) – boolean specifying whether a stable sort should be used. Default=True.
descending (bool) – boolean specifying whether to sort in descending order. Default=False.
kind (None) – deprecated; instead, specify the sort algorithm using stable=True or stable=False.
order (None) – not supported by JAX.
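The stable and descending flags are passed by keyword. The sketch below assumes a JAX release recent enough to accept descending, and reuses the input from the first example. It shows that a descending value sort contains the same elements as the ascending sort in reverse order, and that stable=False does not change the sorted values.

```python
import jax.numpy as jnp

x = jnp.array([1, 3, 5, 4, 2, 1])

# Ascending (default) and descending sorts along the last axis.
asc = jnp.sort(x)
desc = jnp.sort(x, descending=True)

# For a plain value sort, descending order is the ascending result reversed.
assert (desc == asc[::-1]).all()

# stable=False requests a sort that need not be stable. For a value sort the
# output values are the same; stability only governs the relative order of
# equal elements, which is not observable in the sorted values themselves.
unstable = jnp.sort(x, stable=False)
```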
Returns:
Sorted array of shape a.shape (if axis is an integer) or of shape (a.size,) (if axis is None).
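The two possible result shapes can be checked directly. This is a minimal sketch using the same 2x3 input as the examples below: an integer axis preserves a.shape, while axis=None flattens the input and returns a 1-D array with a.size elements.

```python
import jax.numpy as jnp

a = jnp.array([[2, 1, 3],
               [4, 3, 6]])

# Integer axis: the result keeps the input's shape, (2, 3).
by_row = jnp.sort(a, axis=-1)
assert by_row.shape == a.shape

# axis=None: the array is flattened first, so the result has shape (6,).
flat = jnp.sort(a, axis=None)
assert flat.shape == (a.size,)
```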
Examples
Simple 1-dimensional sort:
>>> x = jnp.array([1, 3, 5, 4, 2, 1])
>>> jnp.sort(x)
Array([1, 1, 2, 3, 4, 5], dtype=int32)
Sort along the last axis of an array:
>>> x = jnp.array([[2, 1, 3],
...                [4, 3, 6]])
>>> jnp.sort(x, axis=1)
Array([[1, 2, 3],
       [3, 4, 6]], dtype=int32)
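Sorting along the first axis instead orders each column independently, and the descending flag can be combined with an integer axis. This is a short sketch with a hypothetical 2x3 input chosen so that the columns actually change order.

```python
import jax.numpy as jnp

x = jnp.array([[4, 1, 6],
               [2, 3, 5]])

# axis=0 sorts each column independently.
cols = jnp.sort(x, axis=0)

# axis=1 with descending=True sorts each row from largest to smallest.
rows_desc = jnp.sort(x, axis=1, descending=True)
```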