Return indices for accessing the main diagonal of a multidimensional array.
JAX implementation of numpy.diag_indices().
Returns a tuple of ndim arrays (two by default), each of length n, containing the indices to access the main diagonal.
Examples
>>> jnp.diag_indices(3)
(Array([0, 1, 2], dtype=int32), Array([0, 1, 2], dtype=int32))
>>> jnp.diag_indices(4, ndim=3)
(Array([0, 1, 2, 3], dtype=int32),
 Array([0, 1, 2, 3], dtype=int32),
 Array([0, 1, 2, 3], dtype=int32))
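The returned tuple can be passed directly as an index. The following is a minimal sketch of that usage, not part of the original docstring; the names x, rows, cols, diagonal and y are illustrative only.

```python
import jax.numpy as jnp

# A 3x3 matrix with entries 0..8 (illustrative data).
x = jnp.arange(9).reshape(3, 3)

# Index arrays for the main diagonal of a 3x3 array.
rows, cols = jnp.diag_indices(3)

# Read the diagonal entries x[0, 0], x[1, 1], x[2, 2].
diagonal = x[rows, cols]

# JAX arrays are immutable, so .at[...].set(...) returns an updated copy
# and leaves x unchanged.
y = x.at[rows, cols].set(-1)
```

Here diagonal holds the values 0, 4 and 8, and y is a copy of x with -1 along its main diagonal.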