Convert flat indices into multi-dimensional indices.
JAX implementation of numpy.unravel_index(). The JAX version differs in its treatment of out-of-bound indices: unlike NumPy, negative indices are supported, and out-of-bound indices are clipped to the nearest valid value.
Parameters
indices (ArrayLike) – integer array of flat indices
shape (Shape) – shape of multidimensional array to index into
Returns
Tuple of unraveled indices
Examples
Start with a 1D array of values and an array of indices:
>>> x = jnp.array([2., 3., 4., 5., 6., 7.])
>>> indices = jnp.array([1, 3, 5])
>>> print(x[indices])
[3. 5. 7.]
Now if x is reshaped, unravel_index can be used to convert the flat indices into a tuple of indices that access the same entries:
>>> shape = (2, 3)
>>> x_2D = x.reshape(shape)
>>> indices_2D = jnp.unravel_index(indices, shape)
>>> indices_2D
(Array([0, 1, 1], dtype=int32), Array([1, 0, 2], dtype=int32))
>>> print(x_2D[indices_2D])
[3. 5. 7.]
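The arithmetic behind this conversion can be sketched with repeated divmod over the dimensions in row-major (C) order, from the last axis to the first. This is only an illustration for in-bounds, non-negative indices. The helper unravel_row_major is a hypothetical name introduced here, not part of JAX, and it is not meant to reproduce how the library implements the function.

```python
import jax.numpy as jnp


def unravel_row_major(flat, shape):
    """Hypothetical helper: unravel in-bounds flat indices in row-major order."""
    out = []
    # Peel off the fastest-varying (last) axis first.
    for size in reversed(shape):
        flat, remainder = jnp.divmod(flat, size)
        out.append(remainder)
    # The loop collected axes last-to-first, so restore the original order.
    return tuple(reversed(out))


flat = jnp.array([1, 3, 5])
shape = (2, 3)

manual = unravel_row_major(flat, shape)
builtin = jnp.unravel_index(flat, shape)

# For in-bounds indices the sketch agrees with jnp.unravel_index.
print([m.tolist() for m in manual])   # [[0, 1, 1], [1, 0, 2]]
print([b.tolist() for b in builtin])  # [[0, 1, 1], [1, 0, 2]]
```

For example, flat index 5 in shape (2, 3) gives 5 // 3 = 1 for the row and 5 % 3 = 2 for the column.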
The inverse function, ravel_multi_index, can be used to obtain the original indices:
>>> jnp.ravel_multi_index(indices_2D, shape)
Array([1, 3, 5], dtype=int32)
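The out-of-bound behavior noted in the description can be illustrated with a small sketch. The input values below are chosen only for illustration, and the results shown reflect the clipping and negative-index behavior as described above; they may be worth re-checking against the JAX version in use.

```python
import jax.numpy as jnp

shape = (2, 3)  # valid flat indices are 0..5 (or -6..-1 counting from the end)

# -1 is a valid negative index, 10 is past the end, -7 is before the start.
flat = jnp.array([-1, 10, -7])
rows, cols = jnp.unravel_index(flat, shape)

# -1 -> (1, 2): negative indices count from the end.
# 10 -> (1, 2): clipped to the last valid position.
# -7 -> (0, 0): clipped to the first valid position.
print(rows.tolist(), cols.tolist())  # [1, 1, 0] [2, 2, 0]
```

By contrast, numpy.unravel_index raises an error for such out-of-bound inputs, so code that relies on that error for validation needs an explicit bounds check when ported to JAX.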