I've been picking away at this for the past year or so (since #3038), but I wanted to track it a bit more formally here.
We want the jax.numpy API to have two properties:
- Static arguments should generally be checked with core.concrete_or_error(). This helps localize concretization errors and provides a more uniform user experience.
- Dynamic arguments should generally be validated with _check_arraylike. Passing lists to jax functions can be a quiet source of performance degradation because lists are treated as pytrees, so each element is handled as a separate argument. For example:
    In [1]: import jax.numpy as jnp
    In [2]: from jax import jit
    In [3]: f = jit(lambda x: jnp.mean(jnp.asarray(x)))
    In [4]: %timeit f([float(i) for i in range(1000)])
    WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
    1.79 ms ± 556 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
    In [5]: %timeit f(jnp.array([float(i) for i in range(1000)]))
    375 µs ± 4.76 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
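A minimal sketch of the kind of check the second property describes, written against plain NumPy so it runs anywhere. The names is_arraylike and check_arraylike are hypothetical stand-ins, not the private jax helper, and the acceptance rule is an assumption: anything with shape and dtype attributes, anything defining __jax_array__, or a Python scalar.

```python
import numpy as np


def is_arraylike(x):
    """Hypothetical rule: arrays and scalars pass, lists and tuples do not."""
    return (
        hasattr(x, "__jax_array__")
        or (hasattr(x, "shape") and hasattr(x, "dtype"))
        or isinstance(x, (bool, int, float, complex))
    )


def check_arraylike(fun_name, *args):
    """Raise a TypeError naming the first argument that is not array-like."""
    for pos, arg in enumerate(args):
        if not is_arraylike(arg):
            raise TypeError(
                f"{fun_name} requires ndarray or scalar arguments, "
                f"got {type(arg).__name__} at position {pos}."
            )


def mean(x):
    # Validate before converting, so a list fails loudly instead of being
    # silently accepted.
    check_arraylike("mean", x)
    return np.mean(np.asarray(x))


values = [float(i) for i in range(1000)]
result = mean(np.array(values))  # explicit conversion at the call site

try:
    mean(values)  # a bare list is rejected
    rejected = False
except TypeError:
    rejected = True
```

With a check like this, the caller converts the list once with an explicit array constructor, which corresponds to the faster of the two timings above.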
This change may be problematic because it breaks users who may be passing lists to jax.numpy functions, but this input type restriction has long been documented. Still, with each change I plan to run a full set of tests and fix downstream packages if necessary.
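For the first property, a rough sketch of validating a static argument at the API boundary. It does not call core.concrete_or_error(); require_static_int and zeros are hypothetical stand-ins, built on operator.index from the standard library, on the assumption that a value which cannot be turned into a Python int should fail with a message naming the function and argument.

```python
import operator


def require_static_int(value, context):
    """Hypothetical helper: coerce a static argument to int or fail clearly."""
    try:
        return operator.index(value)
    except TypeError as err:
        raise TypeError(
            f"{context} must be a concrete integer, got {type(value).__name__}."
        ) from err


def zeros(n):
    # n determines the output size, so it is treated as static and checked
    # before any other work is done.
    n = require_static_int(n, "argument n of zeros")
    return [0.0] * n


ok = zeros(3)

try:
    zeros(3.0)  # a float is not a valid static size
    failed = False
except TypeError:
    failed = True
```

Checking at the top of the function means the error points at the offending argument rather than at whatever internal operation first needed a concrete value.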