Transpose the last two dimensions of an array.
JAX implementation of numpy.matrix_transpose(), implemented in terms of jax.lax.transpose().
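A minimal sketch of how such an implementation could look, assuming only the behavior described on this page: the leading (batch) axes keep their order and the last two axes are swapped through a permutation passed to jax.lax.transpose(). The function name matrix_transpose_sketch is hypothetical, and this is not JAX's actual source.

```python
import jax.numpy as jnp
from jax import lax


def matrix_transpose_sketch(x):
    """Hypothetical re-implementation: swap the last two axes via lax.transpose."""
    x = jnp.asarray(x)
    if x.ndim < 2:
        # Mirrors the documented requirement x.ndim >= 2.
        raise ValueError(f"expected x.ndim >= 2, got {x.ndim}")
    # Keep all leading axes in place, then swap the final two.
    perm = (*range(x.ndim - 2), x.ndim - 1, x.ndim - 2)
    return lax.transpose(x, perm)


x = jnp.arange(24).reshape(2, 3, 4)
y = matrix_transpose_sketch(x)
print(y.shape)  # (2, 4, 3)
print(bool(jnp.array_equal(y, jnp.matrix_transpose(x))))  # True
```

The permutation for a 4D input, for example, is (0, 1, 3, 2): only the final pair of axes changes position.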
Parameters: x (ArrayLike) – input array. Must have x.ndim >= 2.
Returns: a matrix-transposed copy of the array.
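To illustrate what "the last two dimensions" means for inputs with more than one leading axis, the short sketch below uses an arbitrarily chosen 4D shape. It checks the output shape, compares one slice with the ordinary 2D transpose, and contrasts the result with .T, which reverses every axis rather than only the last two.

```python
import jax.numpy as jnp

# A 5 x 2 grid of 3x4 matrices: shape (5, 2, 3, 4).
x = jnp.arange(5 * 2 * 3 * 4, dtype=jnp.float32).reshape(5, 2, 3, 4)
y = jnp.matrix_transpose(x)

# Only the last two axes are swapped; the leading axes are untouched.
print(y.shape)  # (5, 2, 4, 3)

# Each output matrix is the ordinary transpose of the matching input matrix.
print(bool(jnp.array_equal(y[1, 0], x[1, 0].T)))  # True

# Equivalent to swapping axes -1 and -2 explicitly.
print(bool(jnp.array_equal(y, jnp.swapaxes(x, -1, -2))))  # True

# By contrast, .T reverses the order of all axes.
print(x.T.shape)  # (4, 3, 2, 5)
```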
Examples
Here is a 2x2x2 array representing a batch of two 2x2 matrices:
>>> x = jnp.array([[[1, 2],
...                 [3, 4]],
...                [[5, 6],
...                 [7, 8]]])
>>> jnp.matrix_transpose(x)
Array([[[1, 3],
        [2, 4]],

       [[5, 7],
        [6, 8]]], dtype=int32)
For convenience, you can perform the same transpose via the mT property of jax.Array:
>>> x.mT
Array([[[1, 3],
        [2, 4]],

       [[5, 7],
        [6, 8]]], dtype=int32)
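As a sketch of where mT can be handy, assume a hypothetical batch of 3x2 data matrices. The expression a.mT @ a computes one Gram matrix per batch entry, because the @ operator broadcasts over leading axes while mT swaps only the trailing two. The loop-based comparison is included only as a sanity check.

```python
import jax.numpy as jnp

# Hypothetical batch of 4 data matrices, each 3x2.
a = jnp.arange(4 * 3 * 2, dtype=jnp.float32).reshape(4, 3, 2)

# Batched Gram matrices A_i^T A_i; matmul broadcasts over the leading axis.
gram = a.mT @ a
print(gram.shape)  # (4, 2, 2)

# Matches a per-matrix loop that uses the ordinary 2D transpose.
expected = jnp.stack([m.T @ m for m in a])
print(bool(jnp.allclose(gram, expected)))  # True
```

Writing a.mT instead of jnp.swapaxes(a, -1, -2) keeps the batched expression close to the textbook formula AᵀA.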