jax.numpy.matrix_transpose — JAX documentation

jax.numpy.matrix_transpose
jax.numpy.matrix_transpose(x, /)

Transpose the last two dimensions of an array.

JAX implementation of numpy.matrix_transpose(), built on jax.lax.transpose().
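Since the function is described as implemented in terms of jax.lax.transpose(), here is a minimal sketch of how such a reduction could look. It assumes the permutation keeps every leading (batch) axis in place and swaps only the final two. `matrix_transpose_via_lax` is a hypothetical helper for illustration, not part of JAX.

```python
import jax.numpy as jnp
from jax import lax


def matrix_transpose_via_lax(x):
    """Hypothetical helper: swap the last two axes with lax.transpose."""
    x = jnp.asarray(x)
    if x.ndim < 2:
        raise ValueError(f"expected ndim >= 2, got ndim={x.ndim}")
    # Leading axes 0..ndim-3 stay put; the last two axes trade places.
    perm = (*range(x.ndim - 2), x.ndim - 1, x.ndim - 2)
    return lax.transpose(x, perm)


x = jnp.arange(24).reshape(2, 3, 4)
y = matrix_transpose_via_lax(x)
print(y.shape)  # (2, 4, 3)
print(bool(jnp.array_equal(y, jnp.matrix_transpose(x))))  # True
```

For a 3-D input the permutation is (0, 2, 1); for a 2-D input it is (1, 0), which is an ordinary matrix transpose.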

Parameters:

x (ArrayLike) – input array. Must have x.ndim >= 2.
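To illustrate the x.ndim >= 2 requirement, the sketch below passes inputs with several batch axes, with no batch axes, and with only one axis. The exception type for the 1-D case is an assumption (ValueError, as raised by current JAX releases to the best of our knowledge); this page only states the requirement.

```python
import jax.numpy as jnp

# Any number of leading batch axes is allowed; only the last two are swapped.
x4 = jnp.zeros((2, 3, 4, 5))
print(jnp.matrix_transpose(x4).shape)  # (2, 3, 5, 4)

# A plain 2-D matrix (no batch axes) also satisfies x.ndim >= 2.
m = jnp.ones((3, 7))
print(jnp.matrix_transpose(m).shape)  # (7, 3)

# A 1-D input violates the requirement. ValueError is assumed here;
# widen the except clause if your JAX version raises something else.
v = jnp.arange(3)
rejected = False
try:
    jnp.matrix_transpose(v)
except ValueError:
    rejected = True
print(rejected)  # True
```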

Returns:

A copy of x with its last two axes swapped (a matrix transpose applied to every matrix in the batch).

Return type:

Array
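The return value can be characterized by two rules: the output shape equals x.shape with the last two entries swapped, and out[..., i, j] equals x[..., j, i]. The sketch below checks both on a batch of non-square matrices, where a wrong axis order would be easy to spot.

```python
import jax.numpy as jnp

x = jnp.arange(2 * 3 * 4).reshape(2, 3, 4)  # batch of two 3x4 matrices
out = jnp.matrix_transpose(x)

# Shape rule: leading (batch) axes unchanged, last two axes swapped.
expected_shape = x.shape[:-2] + (x.shape[-1], x.shape[-2])
print(out.shape, expected_shape)  # (2, 4, 3) (2, 4, 3)

# Element rule: out[..., i, j] == x[..., j, i] for every batch entry.
elementwise_ok = all(
    bool(jnp.array_equal(out[..., i, j], x[..., j, i]))
    for i in range(out.shape[-2])
    for j in range(out.shape[-1])
)
print(elementwise_ok)  # True
```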

Examples

Here is a 2x2x2 array representing a batch of two 2x2 matrices:

>>> import jax.numpy as jnp
>>> x = jnp.array([[[1, 2],
...                 [3, 4]],
...                [[5, 6],
...                 [7, 8]]])
>>> jnp.matrix_transpose(x)
Array([[[1, 3],
        [2, 4]],

       [[5, 7],
        [6, 8]]], dtype=int32)
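One way to see what "batched" means here: the single call above gives the same result as transposing each 2x2 matrix on its own. The sketch compares it against a Python loop over the batch and against jax.vmap applied to jnp.transpose (which, for a 2-D argument, reverses both axes).

```python
import jax
import jax.numpy as jnp

x = jnp.array([[[1, 2],
                [3, 4]],
               [[5, 6],
                [7, 8]]])

# Transpose each 2x2 matrix separately and restack along the batch axis.
per_matrix = jnp.stack([m.T for m in x])
# Map a 2-D transpose over the leading (batch) axis.
vmapped = jax.vmap(jnp.transpose)(x)
# One batched call.
batched = jnp.matrix_transpose(x)

print(bool(jnp.array_equal(per_matrix, batched)))  # True
print(bool(jnp.array_equal(vmapped, batched)))     # True
```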

For convenience, you can perform the same transpose via the mT property of jax.Array:

>>> x.mT
Array([[[1, 3],
        [2, 4]],

       [[5, 7],
        [6, 8]]], dtype=int32)
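Note that mT is not the same as the T property once an array has more than two dimensions: T reverses all axes, while mT swaps only the last two. The sketch below contrasts them on a 3-D array and shows that they coincide for a 2-D one.

```python
import jax.numpy as jnp

x = jnp.arange(24).reshape(2, 3, 4)

# .mT swaps only the last two axes, matching jnp.matrix_transpose.
print(x.mT.shape)  # (2, 4, 3)
print(bool(jnp.array_equal(x.mT, jnp.matrix_transpose(x))))  # True

# .T reverses *all* axes, so for ndim > 2 it is not a batched transpose.
print(x.T.shape)  # (4, 3, 2)

# For a 2-D array the two coincide.
m = jnp.arange(6).reshape(2, 3)
print(bool(jnp.array_equal(m.T, m.mT)))  # True
```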
