Roll the specified axis to a given position.
JAX implementation of numpy.rollaxis().
This function exists for compatibility with NumPy, but in most cases the newer jax.numpy.moveaxis() should be used instead, because the meaning of its arguments is more intuitive.
Parameters:
a (ArrayLike) – input array.
axis (int) – index of the axis to roll forward.
start (int) – index toward which the axis will be rolled (default = 0). After normalizing negative axes, if start <= axis, the axis is rolled to the start index; if start > axis, the axis is rolled until the position before start.
Returns:
Copy of a with rolled axis.
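The start <= axis / start > axis rule is easiest to see by sweeping start over every valid value for a fixed axis. The snippet below is a small illustration using only jnp.ones and jnp.rollaxis; note that start == axis and start == axis + 1 both leave the array unchanged, because "the position before axis + 1" is axis itself.

```python
import jax.numpy as jnp

a = jnp.ones((2, 3, 4, 5))

# Roll axis 1 (size 3) toward every valid start position, 0 through a.ndim.
for start in range(a.ndim + 1):
    print(start, jnp.rollaxis(a, 1, start).shape)
# 0 (3, 2, 4, 5)   start <= axis: axis moves to index 0
# 1 (2, 3, 4, 5)   start == axis: no change
# 2 (2, 3, 4, 5)   start > axis: axis lands just before index 2, i.e. index 1
# 3 (2, 4, 3, 5)   axis lands at index 2
# 4 (2, 4, 5, 3)   axis lands at index 3, the end
```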
Notes
Unlike numpy.rollaxis(), jax.numpy.rollaxis() returns a copy rather than a view of the input array. However, under JIT the compiler will optimize away such copies when possible, so in practice this generally has no performance impact.
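The contrast between a NumPy view and a JAX value can be sketched as follows. The helper name roll_first is hypothetical, introduced only for this illustration. NumPy's result shares memory with its input, while JAX arrays are immutable, so the rolled result is a separate value. Printing the traced program with jax.make_jaxpr should show the roll expressed as a transpose primitive; whether XLA actually materializes a copy depends on the surrounding computation, which this snippet does not measure.

```python
import numpy as np
import jax
import jax.numpy as jnp

# NumPy: rollaxis returns a view that shares memory with its input.
x_np = np.arange(24).reshape(2, 3, 4)
view = np.rollaxis(x_np, 2)
print(np.shares_memory(view, x_np))  # True

# JAX: arrays are immutable, so the result is a new array value.
def roll_first(arr):  # hypothetical helper for this example
    return jnp.rollaxis(arr, 2)

x = jnp.arange(24).reshape(2, 3, 4)
rolled = jax.jit(roll_first)(x)
print(rolled.shape)  # (4, 2, 3)

# Inspect the traced program; the axis roll is lowered to a transpose.
print(jax.make_jaxpr(roll_first)(x))
```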
Examples
>>> import jax.numpy as jnp
>>> a = jnp.ones((2, 3, 4, 5))
Roll axis 2 to the start of the array:
>>> jnp.rollaxis(a, 2).shape
(4, 2, 3, 5)
Roll axis 1 to the end of the array:
>>> jnp.rollaxis(a, 1, a.ndim).shape
(2, 4, 5, 3)
The equivalent of these two operations with moveaxis():
>>> jnp.moveaxis(a, 2, 0).shape
(4, 2, 3, 5)
>>> jnp.moveaxis(a, 1, -1).shape
(2, 4, 5, 3)
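When migrating existing rollaxis calls to moveaxis, the start rule above translates directly into a moveaxis destination. The function below, rollaxis_as_moveaxis, is a hypothetical helper written for illustration, not part of jax.numpy; it assumes in-range arguments and skips the bounds checking that jnp.rollaxis performs.

```python
import jax.numpy as jnp

def rollaxis_as_moveaxis(a, axis, start=0):
    """Hypothetical helper: express jnp.rollaxis(a, axis, start) via jnp.moveaxis.

    Assumes -ndim <= axis < ndim and -ndim <= start <= ndim (not checked here).
    """
    ndim = jnp.ndim(a)
    # Normalize negative indices, as rollaxis does.
    axis = axis + ndim if axis < 0 else axis
    start = start + ndim if start < 0 else start
    # rollaxis places the axis *before* `start` when start > axis,
    # so the moveaxis destination is start - 1 in that case.
    destination = start - 1 if start > axis else start
    return jnp.moveaxis(a, axis, destination)

a = jnp.arange(120).reshape(2, 3, 4, 5)
print(rollaxis_as_moveaxis(a, 1, a.ndim).shape)  # (2, 4, 5, 3)
```

Once a call site is rewritten this way, the destination can usually be written directly as a literal (for example -1 for "move to the end"), which is what makes moveaxis easier to read.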