Round input evenly to the given number of decimals.
JAX implementation of numpy.round().
Parameters:
a (ArrayLike) – input array or scalar.
decimals (int) – default=0. Number of decimal places to round the input to. It must be specified statically. Not implemented for decimals < 0.
out (None) – Unused by JAX.
Returns:
An array containing the input values rounded to the specified number of decimals, with the same shape and dtype as a.
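Because decimals must be specified statically, one way to call jnp.round inside a jitted function is to mark decimals as a static argument. The sketch below assumes this pattern. `round_to` is a hypothetical helper name used only for illustration. It also shows that the result keeps the shape and dtype of the input.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Hypothetical helper: `decimals` is declared static, so jit sees it as a
# plain Python int at trace time, which is what jnp.round requires.
@partial(jax.jit, static_argnames="decimals")
def round_to(x, decimals=0):
    return jnp.round(x, decimals=decimals)


x = jnp.array([1.532, 3.267, 6.149])
rounded = round_to(x, decimals=2)  # values 1.53, 3.27, 6.15
print(rounded.shape, rounded.dtype)  # same shape and dtype as x
```

Changing decimals triggers a retrace, since each distinct static value is compiled separately.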
Note
For values exactly halfway between two rounded decimal values, jnp.round rounds to the even neighbor, i.e. the one whose last retained digit is even (round half to even), as numpy.round does.
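A common way to round to a given number of decimals under this rule is to scale the value so that the target digit becomes the units digit, round half to even, and then undo the scaling. The sketch below assumes that approach. It is not necessarily JAX's exact implementation, and `round_scaled` is a hypothetical name. It relies on jnp.rint, which rounds to the nearest integer with ties going to the even integer.

```python
import jax.numpy as jnp


def round_scaled(x, decimals=0):
    # Hypothetical helper (illustrative assumption, not JAX internals):
    # shift the target decimal into the units place, round half to even
    # with jnp.rint, then shift back.
    factor = 10.0 ** decimals
    return jnp.rint(x * factor) / factor


x = jnp.array([1.532, 3.267, 6.149])
print(round_scaled(x, decimals=2))  # values 1.53, 3.27, 6.15
print(round_scaled(jnp.array([0.5, 1.5, 2.5])))  # ties go to even: 0, 2, 2
```

Because most decimal fractions are not exactly representable in binary floating point, a value that looks like a tie (such as 2.675) may not be an exact tie after scaling. The rounded result then follows the stored binary value rather than the written decimal.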
Examples
>>> import jax.numpy as jnp
>>> x = jnp.array([1.532, 3.267, 6.149])
>>> jnp.round(x)
Array([2., 3., 6.], dtype=float32)
>>> jnp.round(x, decimals=2)
Array([1.53, 3.27, 6.15], dtype=float32)
For values exactly halfway between rounded values:
>>> x1 = jnp.array([10.5, 21.5, 12.5, 31.5])
>>> jnp.round(x1)
Array([10., 22., 12., 32.], dtype=float32)
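To see what round half to even changes, the sketch below compares jnp.round with round half away from zero on the same halfway values. The away-from-zero expression is built from jnp.sign, jnp.floor and jnp.abs for comparison only. It is an illustrative assumption, not a JAX rounding mode.

```python
import jax.numpy as jnp

x1 = jnp.array([10.5, 21.5, 12.5, 31.5])

# Round half to even: ties go to the even neighbor.
even = jnp.round(x1)  # values 10, 22, 12, 32

# Round half away from zero, for comparison: ties always move outward.
away = jnp.sign(x1) * jnp.floor(jnp.abs(x1) + 0.5)  # values 11, 22, 13, 32

print(even)
print(away)
```

The two rules agree on 21.5 and 31.5 and differ on 10.5 and 12.5. Half to even avoids the upward bias that always rounding ties outward introduces when many values are summed.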