
Showing content from https://github.com/jax-ml/jax/issues/7737 below:

Regularize jax.numpy API · Issue #7737 · jax-ml/jax · GitHub

I've been picking away at this for the past year or so (since #3038), but I wanted to track it a bit more formally here.

We want the jax.numpy API to have two properties:

In [1]: import jax.numpy as jnp

In [2]: from jax import jit

In [3]: f = jit(lambda x: jnp.mean(jnp.asarray(x)))

In [4]: %timeit f([float(i) for i in range(1000)])
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
1.79 ms ± 556 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

In [5]: %timeit f(jnp.array([float(i) for i in range(1000)]))
375 µs ± 4.76 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

Rejecting list inputs in jax.numpy functions may be disruptive, because it breaks users who currently pass lists, even though this input-type restriction has long been documented. Still, for each change I plan to run the full test suite and fix downstream packages where necessary.
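A minimal sketch of the kind of input check this restriction implies, assuming Python lists and tuples are rejected while arrays and scalars pass through. `require_arraylike` and `strict_mean` are hypothetical names for illustration, not part of `jax.numpy`, and the error text is made up for this sketch. The design pushes conversion into an explicit `jnp.asarray` call, so the cost of converting a list is visible at the call site instead of being hidden inside every function.

```python
import jax.numpy as jnp


def require_arraylike(fun_name, *args):
    """Hypothetical check: reject Python lists/tuples used as array inputs."""
    for arg in args:
        if isinstance(arg, (list, tuple)):
            raise TypeError(
                f"{fun_name} requires array or scalar arguments, got {type(arg)}; "
                "call jnp.asarray() explicitly to convert."
            )


def strict_mean(x):
    # Hypothetical strict wrapper around jnp.mean: validate, then compute.
    require_arraylike("strict_mean", x)
    return jnp.mean(x)


print(strict_mean(jnp.arange(4.0)))  # 1.5

try:
    strict_mean([0.0, 1.0, 2.0, 3.0])  # a bare list is rejected
except TypeError as e:
    print("TypeError:", e)

# Explicit conversion makes the list input acceptable.
print(strict_mean(jnp.asarray([0.0, 1.0, 2.0, 3.0])))  # 1.5
```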
