Swap two axes of an array.
JAX implementation of numpy.swapaxes(), implemented in terms of jax.lax.transpose().
Returns: a copy of a with the specified axes swapped.
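Because the description says the function is built on jax.lax.transpose(), the following is a minimal sketch of how a swap can be expressed as a transpose: start from the identity permutation of the axes and exchange two entries. The helper name swapaxes_via_transpose is hypothetical, and this is an illustration under that assumption, not JAX's actual source code.

```python
import jax.numpy as jnp
from jax import lax


def swapaxes_via_transpose(a, axis1: int, axis2: int):
    """Hypothetical helper: swap axis1 and axis2 of `a` via lax.transpose.

    Illustrative sketch only; not the JAX library implementation.
    """
    a = jnp.asarray(a)
    # Reject axes outside [-ndim, ndim), as NumPy-style APIs do.
    for ax in (axis1, axis2):
        if not -a.ndim <= ax < a.ndim:
            raise ValueError(f"axis {ax} out of range for ndim {a.ndim}")
    # Normalize negative indices (e.g. -1 -> ndim - 1).
    axis1 %= a.ndim
    axis2 %= a.ndim
    # Identity permutation with the two requested axes exchanged.
    perm = list(range(a.ndim))
    perm[axis1], perm[axis2] = perm[axis2], perm[axis1]
    # lax.transpose reorders dimensions according to `perm`.
    return lax.transpose(a, tuple(perm))


a = jnp.ones((2, 3, 4, 5))
print(swapaxes_via_transpose(a, 1, 3).shape)  # (2, 5, 4, 3)
```

Building the permutation explicitly also shows why a swap is a special case of transpose: any axis permutation works, and a swap is the one that moves exactly two axes.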
Notes
Unlike numpy.swapaxes(), jax.numpy.swapaxes() returns a copy rather than a view of the input array. However, under JIT the compiler optimizes away such copies when possible, so in practice this does not affect performance.
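To make the copy-versus-view distinction concrete, this sketch writes through a NumPy swapaxes view (which changes the original) and performs the analogous update in JAX with the .at[...].set() functional update (which leaves the input unchanged). The jitted function swapped_sum is a hypothetical example; the sketch does not measure performance, and whether XLA actually elides the transpose is left to the compiler, as the note above says.

```python
import numpy as np
import jax
import jax.numpy as jnp

# NumPy: swapaxes returns a view sharing memory with the input.
x_np = np.zeros((2, 3))
v_np = np.swapaxes(x_np, 0, 1)
v_np[0, 1] = 1.0            # write through the view
print(x_np[1, 0])           # 1.0: the original array changed

# JAX: arrays are immutable, so the swapped result acts as an independent copy.
x = jnp.zeros((2, 3))
y = jnp.swapaxes(x, 0, 1)
y = y.at[0, 1].set(1.0)     # functional update returns a new array
print(x[1, 0])              # 0.0: the input is unchanged


# Hypothetical jitted function: the swap is traced into the compiled program,
# where XLA may fuse it with the reduction instead of materializing a copy.
@jax.jit
def swapped_sum(a):
    return jnp.swapaxes(a, 0, 1).sum(axis=0)


print(swapped_sum(jnp.ones((2, 3))))  # [3. 3.]
```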
Examples
>>> import jax.numpy as jnp
>>> a = jnp.ones((2, 3, 4, 5))
>>> jnp.swapaxes(a, 1, 3).shape
(2, 5, 4, 3)
Equivalent output via the swapaxes array method:
>>> a.swapaxes(1, 3).shape
(2, 5, 4, 3)
Equivalent output via transpose():
>>> a.transpose(0, 3, 2, 1).shape
(2, 5, 4, 3)
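The examples above compare only shapes, and an array of ones cannot reveal a wrong axis order. A small sketch that fills the array with distinct values and checks that the three spellings agree element for element, and that swapping is symmetric in its arguments and accepts negative axes:

```python
import jax.numpy as jnp

# Distinct values so that a wrong permutation would be detectable.
a = jnp.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)

s1 = jnp.swapaxes(a, 1, 3)
s2 = a.swapaxes(1, 3)
s3 = a.transpose(0, 3, 2, 1)

print(bool(jnp.array_equal(s1, s2)), bool(jnp.array_equal(s1, s3)))  # True True

# Element s1[i, l, k, j] comes from a[i, j, k, l].
print(bool(s1[1, 4, 2, 0] == a[1, 0, 2, 4]))  # True

# Argument order does not matter, and axis -1 refers to axis 3 here.
print(bool(jnp.array_equal(jnp.swapaxes(a, 3, 1), s1)))   # True
print(bool(jnp.array_equal(jnp.swapaxes(a, 1, -1), s1)))  # True
```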