Create an array of evenly-spaced values.
JAX implementation of numpy.arange(), implemented in terms of jax.lax.iota().
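The relationship to jax.lax.iota() can be sketched as follows. This is a conceptual illustration, not JAX's actual implementation. The helper arange_via_iota is hypothetical and handles only concrete Python numbers. It also assumes the element count follows the same ceil((stop - start) / step) rule as numpy.arange().

```python
import math

import jax
import jax.numpy as jnp
import numpy as np


def arange_via_iota(start, stop, step=1, dtype=jnp.float32):
    """Hypothetical helper: a conceptual sketch of arange built on lax.iota.

    Not JAX's actual implementation; accepts only concrete Python numbers.
    """
    # Number of elements in [start, stop), clamped at zero for empty ranges.
    size = max(0, math.ceil((stop - start) / step))
    # lax.iota(dtype, size) produces [0, 1, ..., size - 1] in the given dtype.
    idx = jax.lax.iota(dtype, size)
    # Scale the index vector by step and shift it by start.
    return start + step * idx


# Usage: compare the sketch against jnp.arange and numpy.arange.
x = arange_via_iota(0, 2, 0.5)
np.testing.assert_allclose(x, jnp.arange(0, 2, 0.5))
np.testing.assert_allclose(x, np.arange(0, 2, 0.5))
```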
Similar to Python’s range() function, this can be called with a few different positional signatures:
jnp.arange(stop): generate values from 0 to stop, stepping by 1.
jnp.arange(start, stop): generate values from start to stop, stepping by 1.
jnp.arange(start, stop, step): generate values from start to stop, stepping by step.
As with Python’s range() function, the start value is inclusive and the stop value is exclusive.
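One way to check these signatures, along with the inclusive start and exclusive stop, is to compare jnp.arange against Python's range() for integer arguments. This sketch assumes integer inputs only. The argument tuples are arbitrary examples, and one of them uses a negative step.

```python
import jax.numpy as jnp

# Each positional signature mirrors the corresponding range() call:
# the start value is included and the stop value is excluded.
for args in [(5,), (2, 7), (1, 10, 3), (5, 0, -2)]:
    result = jnp.arange(*args).tolist()
    assert result == list(range(*args)), args
    print(args, result)
```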
Parameters
start (ArrayLike | DimSize) – start of the interval, inclusive.
stop (ArrayLike | DimSize | None) – optional end of the interval, exclusive. If not specified, then (start, stop) = (0, start).
step (ArrayLike | None) – optional step size for the interval. Default = 1.
dtype (DTypeLike | None) – optional dtype for the returned array; if not specified, it is determined via type promotion of start, stop, and step.
device (xc.Device | Sharding | None) – optional Device or Sharding to which the created array will be committed.
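The dtype rule can be illustrated directly. The exact printed types depend on whether JAX's 64-bit mode (the jax_enable_x64 flag) is enabled. With the default configuration they should be the 32-bit types seen in the examples on this page, so the checks below test only the dtype category where the width may vary.

```python
import jax.numpy as jnp

# Only integer arguments: the result has JAX's default integer type.
print(jnp.arange(3).dtype)
# A float among start/stop/step promotes the result to the default float type.
print(jnp.arange(0, 3, 0.5).dtype)
# An explicit dtype overrides the type-promotion result.
print(jnp.arange(3, dtype=jnp.float16).dtype)
```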
Returns
Array of evenly-spaced values from start to stop, separated by step.
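Because stop is exclusive, the length of the returned array can be predicted from the arguments. The sketch below assumes the numpy.arange convention, max(0, ceil((stop - start) / step)). The helper expected_length is hypothetical, and the argument triples are arbitrary examples, including a negative step and an empty range.

```python
import math

import jax.numpy as jnp


def expected_length(start, stop, step=1):
    # Hypothetical helper: number of values in [start, stop) spaced by step,
    # clamped at zero when the range is empty.
    return max(0, math.ceil((stop - start) / step))


for start, stop, step in [(0, 10, 3), (1, 6, 1), (0, 2, 0.5), (5, 0, -2), (3, 3, 1)]:
    n = jnp.arange(start, stop, step).shape[0]
    assert n == expected_length(start, stop, step), (start, stop, step)
```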
Note
Using arange with a floating-point step argument can lead to unexpected results due to accumulation of floating-point errors, especially with lower-precision data types like float8_* and bfloat16. To avoid precision errors, consider generating a range of integers and scaling it to the desired range. For example, instead of this:
jnp.arange(-1, 1, 0.01, dtype='bfloat16')
it can be more accurate to generate a sequence of integers, and scale them:
(jnp.arange(-100, 100) * 0.01).astype('bfloat16')
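The two approaches from this note can be compared against an exact float64 reference. In this sketch, max_abs_error is a hypothetical helper. It builds the reference with the same number of points as the array it checks, so it works even if the two approaches produce different lengths. The printed error values depend on your JAX installation and are not reproduced here.

```python
import jax.numpy as jnp
import numpy as np


def max_abs_error(x, start, step):
    # Hypothetical helper: compare x against start + step * k in float64,
    # using the same number of points as x.
    x64 = np.asarray(x).astype(np.float64)
    reference = start + step * np.arange(x64.shape[0], dtype=np.float64)
    return float(np.max(np.abs(x64 - reference)))


# Floating-point step computed directly in bfloat16.
float_step = jnp.arange(-1, 1, 0.01, dtype='bfloat16')
# Integers scaled in float32, then rounded once to bfloat16.
scaled_ints = (jnp.arange(-100, 100) * 0.01).astype('bfloat16')

print("float step :", max_abs_error(float_step, -1.0, 0.01))
print("scaled ints:", max_abs_error(scaled_ints, -1.0, 0.01))
```

With the integer approach, each value is rounded to bfloat16 only once, so its error is bounded by a single bfloat16 rounding step rather than an accumulated sum of rounding errors.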
Examples
Single-argument version specifies only the stop value:
>>> jnp.arange(4)
Array([0, 1, 2, 3], dtype=int32)
Passing a floating-point stop value leads to a floating-point result:
>>> jnp.arange(4.0)
Array([0., 1., 2., 3.], dtype=float32)
Two-argument version specifies start and stop, with step=1:
>>> jnp.arange(1, 6)
Array([1, 2, 3, 4, 5], dtype=int32)
Three-argument version specifies start, stop, and step:
>>> jnp.arange(0, 2, 0.5)
Array([0. , 0.5, 1. , 1.5], dtype=float32)