Showing content from https://docs.jax.dev/en/latest/_autosummary/jax.numpy.rollaxis.html below:

jax.numpy.rollaxis — JAX documentation

jax.numpy.rollaxis
jax.numpy.rollaxis(a, axis, start=0)

Roll the specified axis to a given position.

JAX implementation of numpy.rollaxis().

This function exists for compatibility with NumPy, but in most cases you should use the newer jax.numpy.moveaxis() instead, because the meaning of its arguments is more intuitive.

Parameters:
  • a (ArrayLike) – input array.

  • axis (int) – index of the axis to roll forward.

  • start (int) – index toward which the axis will be rolled (default = 0). After negative values of axis and start are normalized, if start <= axis, the axis is moved to index start; if start > axis, the axis is moved to index start - 1, immediately before the original position start.
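The start rule can be expressed as a thin wrapper around jax.numpy.moveaxis(). The following is a minimal sketch, not the library's actual source. The function name rollaxis_via_moveaxis is hypothetical, and the bounds checks assume NumPy's conventions (axis in [-ndim, ndim), start in [-ndim, ndim]).

```python
import numpy as np
import jax.numpy as jnp


def rollaxis_via_moveaxis(a, axis, start=0):
    """Hypothetical re-implementation of rollaxis on top of jnp.moveaxis."""
    a = jnp.asarray(a)
    n = a.ndim
    # axis must name an existing axis; start may also equal n ("the end").
    if not -n <= axis < n:
        raise ValueError(f"axis {axis} is out of bounds for {n} dimensions")
    if not -n <= start <= n:
        raise ValueError(f"start {start} is out of bounds for {n} dimensions")
    # Normalize negative values to non-negative positions.
    if axis < 0:
        axis += n
    if start < 0:
        start += n
    # When start > axis, the axis lands just before the original position start.
    if start > axis:
        start -= 1
    return jnp.moveaxis(a, axis, start)


a = jnp.ones((2, 3, 4, 5))
print(rollaxis_via_moveaxis(a, 2).shape)          # (4, 2, 3, 5)
print(rollaxis_via_moveaxis(a, 1, a.ndim).shape)  # (2, 4, 5, 3)

# Cross-check against NumPy on an array whose entries are all distinct.
x = np.arange(120).reshape(2, 3, 4, 5)
print(np.array_equal(rollaxis_via_moveaxis(x, 3, 1), np.rollaxis(x, 3, 1)))  # True
```

The decrement in the start > axis branch is the only difference from a plain moveaxis() call, which is why moveaxis() is usually easier to reason about.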

Returns:

Copy of a with rolled axis.

Return type:

Array

Notes

Unlike numpy.rollaxis(), jax.numpy.rollaxis() will return a copy rather than a view of the input array. However, under JIT, the compiler will optimize away such copies when possible, so this doesn’t have performance impacts in practice.
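The copy-versus-view difference can be observed directly. The sketch below is illustrative; the names x_np, view, rolled, and rolled_sum are hypothetical. NumPy's result shares memory with its input, while JAX produces a new immutable array. Under jax.jit, the roll shows up in the traced program as a transpose operation, which XLA may fuse into the consumer rather than materialize; whether it does so is a compiler decision.

```python
import numpy as np
import jax
import jax.numpy as jnp

x_np = np.arange(24).reshape(2, 3, 4)

# NumPy returns a view: the result shares memory with the input.
view = np.rollaxis(x_np, 2)
print(np.shares_memory(view, x_np))  # True

# JAX returns a new (immutable) array holding the same values.
x = jnp.asarray(x_np)
rolled = jnp.rollaxis(x, 2)
print(rolled.shape)                              # (4, 2, 3)
print(np.array_equal(np.asarray(rolled), view))  # True


@jax.jit
def rolled_sum(x):
    # The rolled array is consumed immediately by the reduction.
    return jnp.rollaxis(x, 2).sum(axis=0)


print(rolled_sum(x).shape)  # (2, 3)

# Inspect the traced program; the roll appears as a transpose.
print(jax.make_jaxpr(lambda x: jnp.rollaxis(x, 2))(x))
```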

Examples

>>> import jax.numpy as jnp
>>> a = jnp.ones((2, 3, 4, 5))

Roll axis 2 to the start of the array:

>>> jnp.rollaxis(a, 2).shape
(4, 2, 3, 5)

Roll axis 1 to the end of the array:

>>> jnp.rollaxis(a, 1, a.ndim).shape
(2, 4, 5, 3)

The equivalent of these two calls using moveaxis():

>>> jnp.moveaxis(a, 2, 0).shape
(4, 2, 3, 5)
>>> jnp.moveaxis(a, 1, -1).shape
(2, 4, 5, 3)
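Negative values of axis and start are normalized against a.ndim before the start rule in the parameter description is applied. A short illustration, assuming NumPy-style normalization; the shapes in the comments follow from that rule, and each rollaxis() call is paired with its moveaxis() equivalent.

```python
import jax.numpy as jnp

a = jnp.ones((2, 3, 4, 5))

# axis=-1 normalizes to 3, so the last axis is rolled to the front.
print(jnp.rollaxis(a, -1).shape)     # (5, 2, 3, 4)
print(jnp.moveaxis(a, -1, 0).shape)  # (5, 2, 3, 4)

# start=-1 normalizes to 3; since 3 > axis=0, axis 0 lands at index 2.
print(jnp.rollaxis(a, 0, -1).shape)  # (3, 4, 2, 5)
print(jnp.moveaxis(a, 0, 2).shape)   # (3, 4, 2, 5)
```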
