Join arrays along a new axis.
JAX implementation of numpy.stack().
Parameters:
arrays (np.ndarray | Array | Sequence[ArrayLike]) – a sequence of arrays to stack; each must have the same shape. If a single array is given, it is treated as equivalent to arrays = unstack(arrays), but the implementation avoids explicit unstacking.
axis (int) – the axis of the result along which the input arrays are stacked. Defaults to 0.
out (None) – unused by JAX.
dtype (DTypeLike | None) – optional dtype of the resulting array. If not specified, the dtype will be determined via type promotion rules described in Type promotion semantics.
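To make the axis and dtype parameters concrete, here is a small sketch. It assumes two 2-D inputs of shape (3, 4) with deliberately different dtypes (int32 and float32) so that both the shape rule and type promotion are visible. The new axis has length len(arrays) and is inserted at position axis of the result; negative values count from the end of the result's shape.

```python
import jax.numpy as jnp

# Two inputs of the same shape but different dtypes (chosen for illustration).
a = jnp.zeros((3, 4), dtype=jnp.int32)
b = jnp.ones((3, 4), dtype=jnp.float32)

# `axis` picks where the new length-2 axis (one entry per input) appears in
# the result; negative values count from the end of the result's shape.
shapes = {axis: jnp.stack([a, b], axis=axis).shape for axis in (0, 1, 2, -1)}
print(shapes)  # {0: (2, 3, 4), 1: (3, 2, 4), 2: (3, 4, 2), -1: (3, 4, 2)}

# Without `dtype`, the result dtype follows JAX's type promotion rules,
# matching jnp.result_type of the inputs.
promoted = jnp.stack([a, b])
print(promoted.dtype == jnp.result_type(a, b))  # True

# With `dtype`, the result is produced in the requested dtype instead.
forced = jnp.stack([a, a], dtype=jnp.float32)
print(forced.dtype)  # float32
```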
Returns:
the stacked result, which has one more dimension than each input array.
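For intuition, the result can be described with two other jax.numpy functions: give each input a new length-1 axis at position axis with jnp.expand_dims, then join the inputs along that axis with jnp.concatenate. The helper stack_via_concatenate below is hypothetical and only illustrates this equivalence; it is a sketch, not a description of how JAX implements stack.

```python
import jax.numpy as jnp

def stack_via_concatenate(arrays, axis=0):
    # Hypothetical helper for illustration: add a length-1 axis at `axis`
    # to every input, then concatenate the inputs along that new axis.
    return jnp.concatenate([jnp.expand_dims(a, axis) for a in arrays], axis=axis)

x = jnp.array([1, 2, 3])
y = jnp.array([4, 5, 6])

# The two formulations agree for positive and negative axes.
for axis in (0, 1, -1, -2):
    assert jnp.array_equal(stack_via_concatenate([x, y], axis),
                           jnp.stack([x, y], axis=axis))
```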
Examples
>>> x = jnp.array([1, 2, 3])
>>> y = jnp.array([4, 5, 6])
>>> jnp.stack([x, y])
Array([[1, 2, 3],
       [4, 5, 6]], dtype=int32)
>>> jnp.stack([x, y], axis=1)
Array([[1, 4],
       [2, 5],
       [3, 6]], dtype=int32)
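The arrays parameter also accepts a single array, which is treated like unstack(arrays), and it requires every element of a sequence to share one shape. The sketch below checks both points; the (2, 3) input is an arbitrary choice, and catching ValueError reflects the error current JAX versions raise for mismatched shapes.

```python
import jax.numpy as jnp

arr = jnp.arange(6).reshape(2, 3)  # treated as a sequence of 2 rows of shape (3,)

# Passing the array itself...
direct = jnp.stack(arr, axis=1)
# ...gives the same result as unstacking it along axis 0 first.
explicit = jnp.stack(jnp.unstack(arr), axis=1)

print(direct.shape)                             # (3, 2)
print(bool(jnp.array_equal(direct, explicit)))  # True

# When a sequence is given, every element must have the same shape.
try:
    jnp.stack([jnp.zeros(3), jnp.zeros(4)])
    mismatch_rejected = False
except ValueError:
    mismatch_rejected = True
print(mismatch_rejected)  # True
```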
unstack() performs the inverse operation:
>>> arr = jnp.stack([x, y], axis=1)
>>> x, y = jnp.unstack(arr, axis=1)
>>> x
Array([1, 2, 3], dtype=int32)
>>> y
Array([4, 5, 6], dtype=int32)
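The inverse relationship also holds in the other direction: unstacking an array along some axis and stacking the pieces back along the same axis reconstructs the original. A short sketch, assuming an arbitrary 3-D input of shape (2, 3, 4) and checking every valid axis, including negative ones:

```python
import jax.numpy as jnp

arr = jnp.arange(24).reshape(2, 3, 4)

# For each axis, unstack into arr.shape[axis] pieces, then re-stack them
# along the same axis; the round trip reproduces the original array.
for axis in range(-arr.ndim, arr.ndim):
    pieces = jnp.unstack(arr, axis=axis)
    assert len(pieces) == arr.shape[axis]
    assert jnp.array_equal(jnp.stack(pieces, axis=axis), arr)
```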