
Showing content from https://docs.jax.dev/en/latest/_autosummary/jax.numpy.lexsort.html below:

jax.numpy.lexsort — JAX documentation

jax.numpy.lexsort
jax.numpy.lexsort(keys, axis=-1)

Sort a sequence of keys in lexicographic order.

JAX implementation of numpy.lexsort().
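Since this mirrors numpy.lexsort(), a quick way to pin down what "lexicographic order" means is to compare it with NumPy and with Python's built-in sort over (primary, secondary) tuples. The sketch below uses the same key values as the Examples section. The variable names are illustrative, and the tuple comparison is only a reference model, not how either library computes the result.

```python
import numpy as np
import jax.numpy as jnp

# Illustrative keys (same values as in the Examples section).
key1 = [4, 2, 3, 2, 5]  # secondary key
key2 = [2, 1, 1, 2, 2]  # primary key (last in the sequence)

jax_order = jnp.lexsort([jnp.array(key1), jnp.array(key2)])
np_order = np.lexsort([np.array(key1), np.array(key2)])

# Lexicographic order compares (primary, secondary) tuples. Python's sort is
# stable, so fully tied entries keep their original relative order.
py_order = sorted(range(len(key1)), key=lambda i: (key2[i], key1[i]))

print(jax_order.tolist(), np_order.tolist(), py_order)  # all [1, 2, 3, 0, 4]
```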

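Parameters:

keys – a sequence of arrays of the same shape to be sorted. The last key in the sequence is the primary sort key.
axis – the axis along which to sort (default: -1).

Returns:

An array of integers of shape keys[0].shape giving the indices of the entries in lexicographically sorted order.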

Return type:

Array
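The returned indices can be used directly to reorder data. Here is a minimal sketch using a hypothetical two-column integer table, where column 0 is a group id and column 1 is a score. To sort by group and break ties by score, the keys are listed from least to most significant, because lexsort() treats the last key as primary.

```python
import jax.numpy as jnp

# Hypothetical table: column 0 is a group id, column 1 is a score.
table = jnp.array([[2, 7],
                   [1, 9],
                   [2, 3],
                   [1, 4]])

# Least significant key first: score breaks ties, group id is primary.
order = jnp.lexsort([table[:, 1], table[:, 0]])

# order has shape keys[0].shape == (4,), so it can index the rows directly.
sorted_table = table[order]
print(order.tolist())         # [3, 1, 2, 0]
print(sorted_table.tolist())  # [[1, 4], [1, 9], [2, 3], [2, 7]]
```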

Examples

lexsort() with a single key is equivalent to argsort():

>>> import jax.numpy as jnp
>>> key1 = jnp.array([4, 2, 3, 2, 5])
>>> jnp.lexsort([key1])
Array([1, 3, 2, 0, 4], dtype=int32)
>>> jnp.argsort(key1)
Array([1, 3, 2, 0, 4], dtype=int32)

With multiple keys, lexsort() uses the last key as the primary key:

>>> key2 = jnp.array([2, 1, 1, 2, 2])
>>> jnp.lexsort([key1, key2])
Array([1, 2, 3, 0, 4], dtype=int32)

The meaning of the indices becomes clearer when printing the sorted keys:

>>> indices = jnp.lexsort([key1, key2])
>>> print(f"{key1[indices]}\n{key2[indices]}")
[2 3 2 4 5]
[1 1 2 2 2]

Notice that the elements of key2 appear in order, and within each run of duplicated key2 values the corresponding elements of key1 also appear in order.
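One way to see why the last key ends up as the primary key is to reproduce the result with repeated stable argsorts, starting from the least significant key. Because each pass is stable, ties in a later (more significant) key preserve the order established by earlier keys. This is a sketch for 1-D keys only. `lexsort_via_argsort` is a hypothetical helper for illustration and is not necessarily how JAX implements lexsort(). It also relies on jnp.argsort being stable by default.

```python
import jax.numpy as jnp

def lexsort_via_argsort(keys):
    """Hypothetical helper: lexicographic sort of 1-D keys via stable argsort passes.

    keys[0] is the least significant key and keys[-1] the most significant,
    matching the convention of jnp.lexsort.
    """
    order = jnp.arange(keys[0].shape[0])
    for key in keys:
        # Re-sort the current permutation by this key. jnp.argsort is stable by
        # default, so entries tied on this key keep their previous relative order.
        order = order[jnp.argsort(key[order])]
    return order

key1 = jnp.array([4, 2, 3, 2, 5])
key2 = jnp.array([2, 1, 1, 2, 2])
print(lexsort_via_argsort([key1, key2]).tolist())  # [1, 2, 3, 0, 4], same as jnp.lexsort
```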

For multi-dimensional inputs, lexsort() defaults to sorting along the last axis:

>>> key1 = jnp.array([[2, 4, 2, 3],
...                   [3, 1, 2, 2]])
>>> key2 = jnp.array([[1, 2, 1, 3],
...                   [2, 1, 2, 1]])
>>> jnp.lexsort([key1, key2])
Array([[0, 2, 1, 3],
       [1, 3, 2, 0]], dtype=int32)
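Because the indices are computed independently for each row, applying them to the keys requires gathering along the same axis rather than plain indexing. The following minimal sketch uses jnp.take_along_axis with the keys from the example above, which are redefined so the snippet is self-contained.

```python
import jax.numpy as jnp

key1 = jnp.array([[2, 4, 2, 3],
                  [3, 1, 2, 2]])
key2 = jnp.array([[1, 2, 1, 3],
                  [2, 1, 2, 1]])

indices = jnp.lexsort([key1, key2])  # one permutation per row (axis=-1)

# Apply each row's permutation to that row of each key.
sorted_key1 = jnp.take_along_axis(key1, indices, axis=-1)
sorted_key2 = jnp.take_along_axis(key2, indices, axis=-1)

print(sorted_key2.tolist())  # [[1, 1, 2, 3], [1, 1, 2, 2]]  (primary key, sorted per row)
print(sorted_key1.tolist())  # [[2, 2, 4, 3], [1, 2, 2, 3]]  (ordered within key2 ties)
```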

A different sort axis can be chosen using the axis keyword; here we sort along the leading axis:

>>> jnp.lexsort([key1, key2], axis=0)
Array([[0, 1, 0, 1],
       [1, 0, 1, 0]], dtype=int32)
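To make the meaning of the axis argument concrete: sorting along axis=0 should be equivalent to moving that axis to the end, sorting along the last axis, and moving the result's axis back. The sketch below checks this with jnp.moveaxis. It illustrates the semantics only and makes no claim about how the axis argument is implemented.

```python
import jax.numpy as jnp

key1 = jnp.array([[2, 4, 2, 3],
                  [3, 1, 2, 2]])
key2 = jnp.array([[1, 2, 1, 3],
                  [2, 1, 2, 1]])

along_axis0 = jnp.lexsort([key1, key2], axis=0)

# Move axis 0 of each key to the last position, sort along the last axis,
# then move the result's last axis back to position 0.
moved = [jnp.moveaxis(k, 0, -1) for k in (key1, key2)]
via_moveaxis = jnp.moveaxis(jnp.lexsort(moved), -1, 0)

print(bool((along_axis0 == via_moveaxis).all()))  # True
```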
