Move an array axis to a new position
JAX implementation of numpy.moveaxis(), implemented in terms of jax.lax.transpose().
Returns a copy of a with axes moved from source to destination.
Notes
Unlike numpy.moveaxis(), jax.numpy.moveaxis() will return a copy rather than a view of the input array. However, under JIT, the compiler will optimize away such copies when possible, so this doesn’t have performance impacts in practice.
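The view-versus-copy distinction can be seen with a small sketch. The NumPy half shows that writing through the result of numpy.moveaxis() changes the original array. The JAX half wraps jnp.moveaxis in a jitted function; scaled_sum is a hypothetical name used only for illustration, and whether the compiler actually elides the copy depends on the backend and is not checked here.

```python
import numpy as np
import jax
import jax.numpy as jnp

# NumPy: moveaxis returns a view, so it shares memory with the input.
x_np = np.zeros((2, 3))
v = np.moveaxis(x_np, 0, -1)          # shape (3, 2)
shares = np.shares_memory(x_np, v)    # view and input overlap in memory
v[0, 0] = 1.0                         # writes through to x_np[0, 0]

# JAX: arrays are immutable and moveaxis returns a new array.
# Under jit, the transpose is compiled together with the ops around it.
@jax.jit
def scaled_sum(x):                    # hypothetical example function
    moved = jnp.moveaxis(x, 0, -1)    # shape (3, 2)
    return (moved * 2.0).sum(axis=-1)

out = scaled_sum(jnp.ones((2, 3)))    # shape (3,)
```

Because JAX arrays cannot be modified in place, the aliasing behavior shown in the NumPy half has no JAX counterpart; the difference matters only for memory use and speed.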
Examples
>>> a = jnp.ones((2, 3, 4, 5))
Move axis 1 to the end of the array:
>>> jnp.moveaxis(a, 1, -1).shape
(2, 4, 5, 3)
Move the last axis to position 1:
>>> jnp.moveaxis(a, -1, 1).shape
(2, 5, 3, 4)
Move multiple axes:
>>> jnp.moveaxis(a, (0, 1), (-1, -2)).shape
(4, 5, 3, 2)
This can also be accomplished via transpose():
>>> a.transpose(2, 3, 1, 0).shape
(4, 5, 3, 2)
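To see where the permutation (2, 3, 1, 0) comes from, the following sketch derives a transpose permutation from source and destination. The helper moveaxis_permutation is hypothetical and written for illustration only. It follows the usual approach of keeping the unmoved axes in order and inserting the moved ones at their destinations, and it is not taken from the JAX source.

```python
import jax.numpy as jnp

def moveaxis_permutation(ndim, source, destination):
    """Hypothetical helper: permutation equivalent to a moveaxis call."""
    # Normalize negative axes into the range [0, ndim).
    src = [s % ndim for s in source]
    dst = [d % ndim for d in destination]
    # Axes that are not moved keep their relative order.
    order = [ax for ax in range(ndim) if ax not in src]
    # Insert each moved axis at its destination, lowest destination first.
    for d, s in sorted(zip(dst, src)):
        order.insert(d, s)
    return tuple(order)

a = jnp.arange(120).reshape(2, 3, 4, 5)
perm = moveaxis_permutation(a.ndim, (0, 1), (-1, -2))
same = bool(jnp.array_equal(jnp.moveaxis(a, (0, 1), (-1, -2)),
                            a.transpose(perm)))
```

Here the array is filled with distinct values rather than ones, so the comparison checks element placement and not only the resulting shape.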