This works:
import jax.numpy as jnp
x = [jnp.ones((), dtype=jnp.float32), jnp.ones((), dtype=jnp.float32)]
jnp.sum(x)
and it returns DeviceArray(2., dtype=float32). But this doesn't:
import jax.numpy as jnp
x = [jnp.ones((), dtype=jnp.float32), jnp.ones((), dtype=jnp.float32)]
jnp.mean(x)
and it raises TypeError: data type not understood.
In NumPy, both work:
import numpy as np
x = [np.ones((), dtype=np.float32), np.ones((), dtype=np.float32)]
np.mean(x) # 1.0
np.sum(x) # 2.0
Is this inconsistency between jnp.sum and jnp.mean intended?
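In the meantime, a possible workaround (a sketch, assuming the goal is simply to get the NumPy-style result) is to turn the list into a single JAX array explicitly, here with jnp.stack, before calling the reduction. Converting up front means the result no longer depends on how each jnp function happens to handle Python lists:

```python
import jax.numpy as jnp

# Same input as above: a Python list of 0-d float32 arrays.
x = [jnp.ones((), dtype=jnp.float32), jnp.ones((), dtype=jnp.float32)]

# Join the 0-d arrays into one array of shape (2,) before reducing.
stacked = jnp.stack(x)

mean = jnp.mean(stacked)
total = jnp.sum(stacked)

print(stacked.shape)  # (2,)
print(float(mean))    # 1.0
print(float(total))   # 2.0
```

jnp.asarray(x) should work the same way for a list of equally shaped arrays; jnp.stack just makes the "join along a new axis" intent explicit.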