Broadcast an array to a specified shape.
JAX implementation of numpy.broadcast_to(). JAX uses NumPy-style broadcasting rules, which you can read more about at NumPy broadcasting.
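The rule that decides whether a target shape is reachable can be sketched in plain Python. `can_broadcast_to` is a hypothetical helper written for illustration and is not part of JAX. It assumes the one-directional NumPy rule used by `broadcast_to`: shapes are aligned from the rightmost dimension, the source may not have more dimensions than the target, and each source dimension must either be 1 or equal the matching target dimension.

```python
def can_broadcast_to(src_shape: tuple, target_shape: tuple) -> bool:
    """Hypothetical helper: True if src_shape can be broadcast to target_shape."""
    # broadcast_to can add leading dimensions but never remove them.
    if len(src_shape) > len(target_shape):
        return False
    # Align trailing dimensions; missing leading source dims behave like size 1.
    for s, t in zip(reversed(src_shape), reversed(target_shape)):
        # A source dim of 1 can be stretched; anything else must match exactly.
        if s != 1 and s != t:
            return False
    return True


print(can_broadcast_to((), (1, 4)))      # True: scalar, as in the first example
print(can_broadcast_to((3,), (2, 3)))    # True: new leading dimension
print(can_broadcast_to((2, 1), (2, 4)))  # True: size-1 axis stretched
print(can_broadcast_to((3,), (3, 2)))    # False: trailing 3 vs. 2
print(can_broadcast_to((2, 3), (3,)))    # False: target has fewer dimensions
```

Note that the rule is not symmetric: a source dimension of 3 cannot be "shrunk" to a target dimension of 1, which is where `broadcast_to` differs from the two-way broadcasting used by arithmetic between arrays.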
Parameters:
array (ArrayLike) – array to be broadcast.
shape (DimSize | Shape) – shape to which the array will be broadcast.
out_sharding (NamedSharding | P | None)
Returns:
a copy of array broadcast to the specified shape.
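The `shape` parameter is typed `DimSize | Shape`, so it can be given either as a tuple or as a single dimension size. The sketch below assumes that a single integer is treated as a 1-D shape, and that an unreachable target raises `ValueError` as it does in NumPy; both are expectations about current JAX behavior rather than guarantees stated on this page. The `out_sharding` parameter is not exercised here.

```python
import jax.numpy as jnp

x = jnp.array([1, 2, 3])

# `shape` given as a tuple (the Shape case).
y = jnp.broadcast_to(x, (2, 3))
print(y.shape)  # (2, 3)

# `shape` given as a single integer (the DimSize case); assumed to be
# treated as the 1-D shape (3,).
z = jnp.broadcast_to(jnp.int32(7), 3)
print(z.tolist())  # [7, 7, 7]

# The trailing dimension 3 cannot be broadcast to 2, so this target is
# unreachable; JAX is expected to raise ValueError, as NumPy does.
error_raised = False
try:
    jnp.broadcast_to(x, (3, 2))
except ValueError as err:
    error_raised = True
    print("ValueError:", err)
```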
Examples
>>> import jax.numpy as jnp
>>> x = jnp.int32(1)
>>> jnp.broadcast_to(x, (1, 4))
Array([[1, 1, 1, 1]], dtype=int32)

>>> x = jnp.array([1, 2, 3])
>>> jnp.broadcast_to(x, (2, 3))
Array([[1, 2, 3],
       [1, 2, 3]], dtype=int32)

>>> x = jnp.array([[2], [4]])
>>> jnp.broadcast_to(x, (2, 4))
Array([[2, 2, 2, 2],
       [4, 4, 4, 4]], dtype=int32)
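One way to check that `jnp.broadcast_to` follows NumPy's rules is to compare it against `numpy.broadcast_to` on a few shapes, including one that both adds a leading dimension and stretches a size-1 axis. This is a sketch that assumes NumPy is installed alongside JAX (JAX depends on it); the shapes are arbitrary choices.

```python
import numpy as np
import jax.numpy as jnp

# Each pair is (source shape, target shape); all are valid broadcasts.
cases = [
    ((), (1, 4)),         # scalar, as in the first example
    ((3,), (2, 3)),       # new leading dimension
    ((2, 1), (2, 4)),     # size-1 axis stretched
    ((3, 1), (2, 3, 4)),  # both at once
]

checked = []
for src_shape, target_shape in cases:
    # Distinct values in the source make a wrong layout visible.
    size = int(np.prod(src_shape))
    x_np = np.arange(size, dtype=np.int32).reshape(src_shape)
    expected = np.broadcast_to(x_np, target_shape)
    result = jnp.broadcast_to(jnp.asarray(x_np), target_shape)
    assert result.shape == target_shape
    assert np.array_equal(np.asarray(result), expected)
    checked.append((src_shape, target_shape))
    print(src_shape, "->", target_shape, "matches NumPy")
```

One difference the comparison does not show: `numpy.broadcast_to` returns a read-only view of its input, whereas the JAX version returns a JAX array. Because JAX arrays are immutable, this rarely matters in practice.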