Construct an array from repeated elements.
JAX implementation of numpy.repeat().
a (ArrayLike) – N-dimensional array
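repeats (ArrayLike) – integer or 1D integer array specifying the number of repeats. A scalar applies the same count to every element; an array must match the length of the repeated axis (or the size of the flattened input when axis is None).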
axis (int | None) – integer specifying the axis of a along which to construct the repeated array. If None (default), then a is first flattened.
total_repeat_length (int | None) – the length of the output along the repeated axis. This must be specified statically for jnp.repeat to be compatible with jit() and other JAX transformations. If sum(repeats) is larger than the specified total_repeat_length, the remaining values will be discarded. If sum(repeats) is smaller than total_repeat_length, the final value will be repeated.
out_sharding (NamedSharding | P | None)
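As a rough illustration of the repeats parameter, the sketch below (assuming jax and numpy are installed; variable names are illustrative only) checks that a scalar count behaves like a 1D count array holding that value once per element of the axis, and that both agree with numpy.repeat.

```python
import numpy as np
import jax.numpy as jnp

a_np = np.array([[1, 2], [3, 4]])
a = jnp.asarray(a_np)

# A scalar count repeats every element along axis 1 the same number of times.
scalar_counts = jnp.repeat(a, 2, axis=1)

# A 1D count array needs one entry per element along axis 1 (here, 2 entries).
array_counts = jnp.repeat(a, jnp.array([2, 2]), axis=1)

# NumPy's reference implementation, for comparison.
reference = np.repeat(a_np, 2, axis=1)

print(np.array_equal(scalar_counts, array_counts))  # True
print(np.array_equal(scalar_counts, reference))     # True
```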
Returns: an array constructed from repeated values of a.
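The truncate-or-pad rule for total_repeat_length can be modelled in plain NumPy. The sketch below is a hypothetical 1D reference model (the helper repeat_with_total_length is not part of JAX and is not how jnp.repeat is implemented); it assumes "final value" means the last element of x, which coincides with the last repeated value whenever repeats[-1] > 0, and compares the model against jnp.repeat for a few lengths.

```python
import numpy as np
import jax.numpy as jnp

def repeat_with_total_length(x, repeats, total_repeat_length):
    """Hypothetical NumPy reference model of the documented 1D behaviour."""
    full = np.repeat(x, repeats)
    if full.size >= total_repeat_length:
        # sum(repeats) is too large (or exact): drop the trailing values.
        return full[:total_repeat_length]
    # sum(repeats) is too small: pad with the final value.
    # Assumption: "final value" means x[-1]; this matches the last repeated
    # value whenever repeats[-1] > 0.
    pad = np.full(total_repeat_length - full.size, x[-1], dtype=full.dtype)
    return np.concatenate([full, pad])

x = np.array([1, 2])
repeats = np.array([2, 3])  # sum(repeats) == 5
for n in (4, 5, 7):
    expected = repeat_with_total_length(x, repeats, n)
    actual = jnp.repeat(jnp.asarray(x), jnp.asarray(repeats), total_repeat_length=n)
    print(n, expected.tolist(), np.array_equal(expected, actual))
# 4 [1, 1, 2, 2] True
# 5 [1, 1, 2, 2, 2] True
# 7 [1, 1, 2, 2, 2, 2, 2] True
```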
Examples
Repeat each value twice along the last axis:
>>> a = jnp.array([[1, 2],
...                [3, 4]])
>>> jnp.repeat(a, 2, axis=-1)
Array([[1, 1, 2, 2],
       [3, 3, 4, 4]], dtype=int32)
If axis is not specified, the input array will be flattened:
>>> jnp.repeat(a, 2)
Array([1, 1, 2, 2, 3, 3, 4, 4], dtype=int32)
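A small sketch of what flattening implies, assuming row-major flattening as done by jnp.ravel: repeating with axis=None matches repeating the raveled array along axis 0, and a per-element count array then needs one entry per element of the flattened input rather than one per entry of an axis.

```python
import numpy as np
import jax.numpy as jnp

a = jnp.array([[1, 2],
               [3, 4]])

# With axis=None the input is flattened before repeating, so this matches
# repeating the raveled (row-major) array along axis 0.
flat = jnp.repeat(a, 2)
via_ravel = jnp.repeat(jnp.ravel(a), 2, axis=0)
print(np.array_equal(flat, via_ravel))  # True

# A count array now needs a.size == 4 entries; a zero count drops that element.
counts = jnp.array([1, 0, 2, 1])
print(jnp.repeat(a, counts).tolist())  # [1, 3, 3, 4]
```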
Pass an array to repeats to repeat each value a different number of times:
>>> repeats = jnp.array([2, 3])
>>> jnp.repeat(a, repeats, axis=1)
Array([[1, 1, 2, 2, 2],
       [3, 3, 4, 4, 4]], dtype=int32)
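The same per-element counts work along any axis. As a brief sketch (names are illustrative), repeating along axis=0 with one count per row grows that axis to sum(row_counts) while leaving the other axes unchanged.

```python
import jax.numpy as jnp

a = jnp.array([[1, 2],
               [3, 4]])

# One count per row: keep row 0 once and duplicate row 1.
row_counts = jnp.array([1, 2])
rows = jnp.repeat(a, row_counts, axis=0)

# The repeated axis has length sum(row_counts); axis 1 is unchanged.
print(rows.shape)     # (3, 2)
print(rows.tolist())  # [[1, 2], [3, 4], [3, 4]]
```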
In order to use repeat within jit and other JAX transformations, the size of the output must be specified statically using total_repeat_length:
>>> jit_repeat = jax.jit(jnp.repeat, static_argnames=['axis', 'total_repeat_length'])
>>> jit_repeat(a, repeats, axis=1, total_repeat_length=5)
Array([[1, 1, 2, 2, 2],
       [3, 3, 4, 4, 4]], dtype=int32)
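The reason a static length is needed: under jit, repeats is a traced value, so sum(repeats), and therefore the output shape, is unknown at trace time. The sketch below assumes that omitting total_repeat_length in that situation raises jax.errors.ConcretizationTypeError; the wrapper names repeat_without_length and repeat_with_length are hypothetical.

```python
from functools import partial

import jax
import jax.numpy as jnp

a = jnp.array([[1, 2],
               [3, 4]])
repeats = jnp.array([2, 3])

@partial(jax.jit, static_argnames=['axis'])
def repeat_without_length(a, repeats, axis):
    # repeats is traced here, so sum(repeats) is not known at trace time.
    return jnp.repeat(a, repeats, axis=axis)

try:
    repeat_without_length(a, repeats, axis=1)
    raised = False
except jax.errors.ConcretizationTypeError:
    # JAX cannot build an output whose shape depends on traced values.
    raised = True
print(raised)  # True

# Supplying a static total_repeat_length fixes the output shape up front.
@partial(jax.jit, static_argnames=['axis', 'total_repeat_length'])
def repeat_with_length(a, repeats, axis, total_repeat_length):
    return jnp.repeat(a, repeats, axis=axis,
                      total_repeat_length=total_repeat_length)

print(repeat_with_length(a, repeats, axis=1, total_repeat_length=5).shape)  # (2, 5)
```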
If total_repeat_length is smaller than sum(repeats), the result will be truncated:
>>> jit_repeat(a, repeats, axis=1, total_repeat_length=4)
Array([[1, 1, 2, 2],
       [3, 3, 4, 4]], dtype=int32)
If it is larger, then the additional entries will be filled with the final value:
>>> jit_repeat(a, repeats, axis=1, total_repeat_length=7)
Array([[1, 1, 2, 2, 2, 2, 2],
       [3, 3, 4, 4, 4, 4, 4]], dtype=int32)
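The same static length is what lets repeat run under vmap when each batch element has different counts: every output row must share one length. A minimal sketch, assuming jnp.repeat accepts batched repeats under jax.vmap (repeat_row is a hypothetical helper); the second row's counts sum to only 2, so it is padded with its final value.

```python
import jax
import jax.numpy as jnp

rows = jnp.array([[1, 2],
                  [3, 4]])
# Each row gets its own counts; their sums (5 and 2) differ.
per_row_repeats = jnp.array([[2, 3],
                             [1, 1]])

def repeat_row(row, counts):
    # A fixed total_repeat_length gives every row the same output length,
    # which vmap requires.
    return jnp.repeat(row, counts, total_repeat_length=5)

out = jax.vmap(repeat_row)(rows, per_row_repeats)
print(out.tolist())  # [[1, 1, 2, 2, 2], [3, 4, 4, 4, 4]]
```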