Return a transposed version of an N-dimensional array.
JAX implementation of numpy.transpose(), implemented in terms of jax.lax.transpose().
Parameters
a (ArrayLike) – input array
axes (Sequence[int] | None) – optionally specify the permutation using a length-a.ndim sequence of integers i satisfying 0 <= i < a.ndim. Defaults to range(a.ndim)[::-1], i.e. reverses the order of all axes.
Returns
transposed copy of the array.
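The effect of the axes permutation can be checked directly: output axis i corresponds to input axis axes[i]. The following is a minimal sketch; the array contents and the permutation (2, 0, 1) are arbitrary values chosen for illustration only.

```python
import jax.numpy as jnp

# Illustrative input: 24 distinct values so each element is identifiable.
a = jnp.arange(24).reshape(2, 3, 4)

# Hypothetical permutation, chosen only for this example.
axes = (2, 0, 1)
out = jnp.transpose(a, axes)

# Output axis i is input axis axes[i], so the shape is permuted the same way.
expected_shape = tuple(a.shape[ax] for ax in axes)
print(out.shape, expected_shape)

# Element-wise, out[i, j, k] == a[j, k, i] for this permutation.
print(int(out[3, 1, 2]), int(a[1, 2, 3]))

# Omitting axes is the same as passing the reversed axis order explicitly.
default = jnp.transpose(a)
reversed_explicit = jnp.transpose(a, tuple(range(a.ndim))[::-1])
print(bool(jnp.array_equal(default, reversed_explicit)))
```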
Note
Unlike numpy.transpose(), jax.numpy.transpose() will return a copy rather than a view of the input array. However, under JIT, the compiler will optimize away such copies when possible, so this typically has no performance impact in practice.
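The view-versus-copy distinction is easiest to see from the NumPy side, where writing to the transposed view also changes the original. The sketch below assumes small illustrative arrays. On the JAX side, arrays are immutable, so updates go through the .at[...].set(...) syntax and always produce a new array.

```python
import numpy as np
import jax.numpy as jnp

# NumPy: the transpose is a view that shares memory with the input.
a_np = np.arange(6).reshape(2, 3)
t_np = np.transpose(a_np)
print(np.shares_memory(a_np, t_np))

# Writing through the view modifies the original array as well.
t_np[0, 1] = 100
print(a_np[1, 0])

# JAX: arrays are immutable, so an "update" returns a new array and
# neither the input nor the transposed result is modified in place.
a_jx = jnp.arange(6).reshape(2, 3)
t_jx = jnp.transpose(a_jx)
t_jx_updated = t_jx.at[0, 1].set(100)
print(int(a_jx[1, 0]), int(t_jx[0, 1]), int(t_jx_updated[0, 1]))
```

Code ported from NumPy that relies on mutating a transposed view in place would need to be rewritten in this functional style.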
Examples
For a 1D array, the transpose is the identity:
>>> x = jnp.array([1, 2, 3, 4])
>>> jnp.transpose(x)
Array([1, 2, 3, 4], dtype=int32)
For a 2D array, the transpose is a matrix transpose:
>>> x = jnp.array([[1, 2],
...                [3, 4]])
>>> jnp.transpose(x)
Array([[1, 3],
       [2, 4]], dtype=int32)
For an N-dimensional array, the transpose reverses the order of the axes:
>>> x = jnp.zeros(shape=(3, 4, 5))
>>> jnp.transpose(x).shape
(5, 4, 3)
The axes argument can be specified to change this default behavior:
>>> jnp.transpose(x, (0, 2, 1)).shape
(3, 5, 4)
Since swapping the last two axes is a common operation, it can be done via its own API, jax.numpy.matrix_transpose():
>>> jnp.matrix_transpose(x).shape
(3, 5, 4)
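As a quick check of this equivalence, the sketch below builds the "swap the last two axes" permutation by hand and compares it with jnp.matrix_transpose() and jnp.swapaxes(). The input shape is an arbitrary choice for illustration.

```python
import jax.numpy as jnp

# Illustrative batch of three 4x5 matrices.
x = jnp.arange(60).reshape(3, 4, 5)

# Keep the leading (batch) axes in place and swap only the last two.
axes = tuple(range(x.ndim - 2)) + (x.ndim - 1, x.ndim - 2)

via_transpose = jnp.transpose(x, axes)
via_matrix_transpose = jnp.matrix_transpose(x)
via_swapaxes = jnp.swapaxes(x, -1, -2)

print(axes)
print(via_transpose.shape)
print(bool(jnp.array_equal(via_transpose, via_matrix_transpose)))
print(bool(jnp.array_equal(via_transpose, via_swapaxes)))
```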
For convenience, transposes may also be performed using the jax.Array.transpose() method or the jax.Array.T property:
>>> x = jnp.array([[1, 2],
...                [3, 4]])
>>> x.transpose()
Array([[1, 3],
       [2, 4]], dtype=int32)
>>> x.T
Array([[1, 3],
       [2, 4]], dtype=int32)
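The same spellings can be compared on a higher-dimensional array. This is a small sketch with an arbitrary 3D input; it shows that the method, the property, and the function all reverse every axis by default, and that the method also accepts an explicit permutation.

```python
import jax.numpy as jnp

# Illustrative 3D input.
x = jnp.arange(24).reshape(2, 3, 4)

# With no axes given, all three spellings reverse the axis order.
from_function = jnp.transpose(x)
from_method = x.transpose()
from_property = x.T
print(from_function.shape, from_method.shape, from_property.shape)

# The method form with an explicit permutation matches the function form.
method_with_axes = x.transpose((0, 2, 1))
function_with_axes = jnp.transpose(x, (0, 2, 1))
print(bool(jnp.array_equal(method_with_axes, function_with_axes)))
```

Note that for arrays with more than two dimensions, x.T reverses all axes rather than swapping only the last two.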