Returns the specified diagonal of an array.
JAX implementation of numpy.diagonal().
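One quick way to exercise the "JAX implementation of numpy.diagonal()" claim is to use NumPy as an oracle. The sketch below assumes NumPy and JAX are both installed and compares results on a small 3-D integer array for a few offsets and axis pairs. It illustrates parity on these cases; it does not prove it in general.

```python
import numpy as np
import jax.numpy as jnp

# Small 3-D integer input; NumPy acts as the reference implementation.
x_np = np.arange(24, dtype=np.int32).reshape(2, 3, 4)
x = jnp.asarray(x_np)

for offset in (-1, 0, 2):
    for axis1, axis2 in ((0, 1), (1, 2), (0, 2)):
        expected = np.diagonal(x_np, offset=offset, axis1=axis1, axis2=axis2)
        result = jnp.diagonal(x, offset=offset, axis1=axis1, axis2=axis2)
        # assert_array_equal fails on a shape mismatch as well as on values.
        np.testing.assert_array_equal(np.asarray(result), expected)

print("jnp.diagonal matches np.diagonal on all tested cases")
```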
The JAX version always returns a copy of the input rather than a view; when it is used within JIT compilation, however, the compiler may avoid the copy.
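A minimal sketch of calling jnp.diagonal inside a JIT-compiled function. The helper name `diag_sum` is hypothetical. Because `offset` determines the length of the result, it has to be a concrete Python int at trace time, so the sketch marks it with `static_argnames`; each distinct `offset` value then triggers its own compilation.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Hypothetical helper: `offset` is static because it fixes the output shape.
@partial(jax.jit, static_argnames="offset")
def diag_sum(x, offset=0):
    # Under jit the compiler may avoid materializing the diagonal as a
    # separate copy before the reduction.
    return jnp.diagonal(x, offset=offset).sum()


x = jnp.arange(9).reshape(3, 3)
print(diag_sum(x))            # main diagonal: 0 + 4 + 8 = 12
print(diag_sum(x, offset=1))  # first superdiagonal: 1 + 5 = 6
```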
Parameters
a (ArrayLike) – Input array. Must be at least 2-dimensional.
offset (int) – optional, default=0. Offset of the diagonal from the main diagonal. Must be a static integer value. Positive values select diagonals above the main diagonal; negative values select diagonals below it.
axis1 (int) – optional, default=0. The first axis along which to take the diagonal.
axis2 (int) – optional, default=1. The second axis along which to take the diagonal.
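To make the roles of offset, axis1 and axis2 concrete, here is a hypothetical reference function, `diagonal_reference`, written with ordinary jnp indexing. It is a sketch of the documented semantics (the diagonal collects the entries whose index along axis2 minus their index along axis1 equals offset), not JAX's actual source. It also enforces the at-least-2-D requirement on `a`.

```python
import jax.numpy as jnp


def diagonal_reference(a, offset=0, axis1=0, axis2=1):
    """Hypothetical reference for jnp.diagonal semantics (illustration only)."""
    a = jnp.asarray(a)
    if a.ndim < 2:
        raise ValueError("diagonal requires an array of at least 2 dimensions")
    n, m = a.shape[axis1], a.shape[axis2]
    # Move the row axis (axis1) and column axis (axis2) to the end; the
    # remaining axes keep their relative order in front.
    a = jnp.moveaxis(a, (axis1, axis2), (-2, -1))
    if offset >= 0:
        # Entries (i, i + offset): on or above the main diagonal.
        length = max(0, min(n, m - offset))
        rows = jnp.arange(length)
        cols = rows + offset
    else:
        # Entries (i - offset, i): below the main diagonal.
        length = max(0, min(n + offset, m))
        cols = jnp.arange(length)
        rows = cols - offset
    # Paired integer index arrays gather one element per diagonal position,
    # producing a new trailing axis of size `length`.
    return a[..., rows, cols]


x = jnp.arange(12).reshape(3, 4)
print(diagonal_reference(x, offset=1))   # entries (0, 1), (1, 2), (2, 3)
print(diagonal_reference(x, offset=-1))  # entries (1, 0), (2, 1)
```

Comparing the sketch against jnp.diagonal for several offsets and axis pairs is a cheap way to confirm it agrees with the library on a given JAX version.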
Returns
A 1-D array for 2-D input and, in general, an (N-1)-dimensional array for N-dimensional input, with the diagonal along the last axis.
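Applied to a stack of matrices, this shape rule means the batch axis is kept and one diagonal is produced per matrix. The sketch below (array values chosen arbitrarily) takes diagonals over the last two axes, shows how the default axis1=0, axis2=1 places the diagonal axis last, and checks the per-matrix sums against jnp.trace.

```python
import jax.numpy as jnp

# A stack of five 3x4 matrices: shape (5, 3, 4), so N = 3.
x = jnp.arange(60).reshape(5, 3, 4)

# Diagonal over the last two axes: the batch axis survives, giving N-1 = 2 dims.
d = jnp.diagonal(x, axis1=-2, axis2=-1)
print(d.shape)  # (5, 3)

# With the defaults axis1=0, axis2=1 the diagonal runs over the batch and row
# axes; the leftover axis (size 4) comes first, the diagonal (length 3) last.
print(jnp.diagonal(x).shape)  # (4, 3)

# Summing along the trailing diagonal axis gives one trace per matrix.
traces = d.sum(axis=-1)
print(jnp.array_equal(traces, jnp.trace(x, axis1=-2, axis2=-1)))  # True
```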
Examples
>>> import jax.numpy as jnp
>>> x = jnp.array([[1, 2, 3],
...                [4, 5, 6],
...                [7, 8, 9]])
>>> jnp.diagonal(x)
Array([1, 5, 9], dtype=int32)
>>> jnp.diagonal(x, offset=1)
Array([2, 6], dtype=int32)
>>> jnp.diagonal(x, offset=-1)
Array([4, 8], dtype=int32)