
Showing content from https://github.com/jax-ml/jax/issues/3038 below:

`jax.numpy.sum` and `jax.numpy.mean` inconsistency on Python lists · Issue #3038 · jax-ml/jax

This works:

```python
import jax.numpy as jnp

x = [jnp.ones((), dtype=jnp.float32), jnp.ones((), dtype=jnp.float32)]
jnp.sum(x)
```

and it returns `DeviceArray(2., dtype=float32)`. But this doesn't:

```python
import jax.numpy as jnp

x = [jnp.ones((), dtype=jnp.float32), jnp.ones((), dtype=jnp.float32)]
jnp.mean(x)
```

and it raises `TypeError: data type not understood`.

In NumPy, both work:

```python
import numpy as np

x = [np.ones((), dtype=np.float32), np.ones((), dtype=np.float32)]
np.mean(x)  # 1.0
np.sum(x)  # 2.0
```
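NumPy reductions accept any array_like argument and convert it to an ndarray before reducing, which is why the list works there. The sketch below makes that conversion explicit with `np.asarray`; the variable name `arr` is only illustrative.

```python
import numpy as np

# The list from the issue: two 0-d float32 arrays.
x = [np.ones((), dtype=np.float32), np.ones((), dtype=np.float32)]

# Converting the list explicitly yields a 1-d float32 array of length 2.
arr = np.asarray(x)
print(arr.shape, arr.dtype)

# Reducing the converted array matches passing the list directly.
print(np.mean(arr), np.mean(x))
print(np.sum(arr), np.sum(x))
```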

Is this intended?
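Until the intended behavior is settled, one possible workaround (a minimal sketch, not an official fix) is to turn the list into a single JAX array before reducing, so `jnp.sum` and `jnp.mean` both receive an array rather than a Python list. `jnp.stack(x)` would serve equally well here. How `jax.numpy` treats raw list arguments may differ between JAX versions; converting first avoids depending on it.

```python
import jax.numpy as jnp

x = [jnp.ones((), dtype=jnp.float32), jnp.ones((), dtype=jnp.float32)]

# Convert the list of 0-d arrays into one 1-d float32 array up front.
arr = jnp.asarray(x)

# Both reductions now operate on an array, matching NumPy's results.
total = jnp.sum(arr)      # 2.0
average = jnp.mean(arr)   # 1.0
print(total, average)
```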

