Since numpy 1.16, indexing with a list in place of a tuple has raised a FutureWarning
(see numpy/numpy#9686 for a discussion of the rationale):
>>> import numpy as np
>>> x = np.arange(6).reshape(2, 3)
>>> idx = [[0], [1]]
>>> x[idx]
FutureWarning: Using a non-tuple sequence for multidimensional indexing is deprecated; use `arr[tuple(seq)]` instead of `arr[seq]`. In the future this will be interpreted as an array index, `arr[np.array(seq)]`, which will result either in an error or a different result.
array([1])
As mentioned in the warning, the current behavior treats the list of indices as if it were a tuple:
>>> x[tuple(idx)]
array([1])
while in the future, the indices will be treated as an array:
>>> x[np.array(idx)]
array([[[0, 1, 2]],

       [[3, 4, 5]]])
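To make the difference concrete, here is a small sketch that uses only the two explicit spellings, so it should behave the same on numpy versions before and after the deprecation. The variable names are just for illustration:

```python
import numpy as np

x = np.arange(6).reshape(2, 3)
idx = [[0], [1]]

# Legacy meaning: each inner list indexes one axis, i.e. x[[0], [1]] -> element x[0, 1].
as_tuple = x[tuple(idx)]

# Future meaning: the whole list becomes one integer array that indexes axis 0,
# so each entry picks a full row and the index shape (2, 1) is prepended to (3,).
as_array = x[np.array(idx)]

print(as_tuple)        # [1]
print(as_array.shape)  # (2, 1, 3)
```

The shapes differ, (1,) versus (2, 1, 3), so code that silently relies on the legacy reading would break rather than degrade gracefully once the semantics change.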
JAX currently implements the old, deprecated behavior, without any warning:
>>> import jax.numpy as jnp
>>> jnp.array(x)[idx]
DeviceArray([1], dtype=int32)
This sets us up for a future where numpy and JAX have different indexing semantics for lists of indices. I propose that we follow numpy and start warning about this behavior now, so that when a numpy release finally switches to the new behavior, JAX will be ready to follow suit immediately.
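As a rough sketch of what "warn now" could look like: a normalization step that runs before indexing, detects a list that would be read under the legacy tuple interpretation, emits the same FutureWarning text numpy uses, and then keeps today's behavior by converting to a tuple. The helper name `normalize_index` is hypothetical, and its detection rule (warn when the list contains a slice, `None`, `Ellipsis`, or a nested sequence/array) is an assumption that approximates numpy's legacy heuristic; it is not JAX's actual indexing code:

```python
import warnings
import numpy as np

def normalize_index(idx):
    """Hypothetical helper: warn on list indices that get the legacy tuple reading."""
    # Assumed rule: a list containing slices, None, Ellipsis, or nested sequences
    # is currently treated as a tuple of per-axis indices.
    if isinstance(idx, list) and any(
        item is None
        or item is Ellipsis
        or isinstance(item, (slice, list, tuple, np.ndarray))
        for item in idx
    ):
        warnings.warn(
            "Using a non-tuple sequence for multidimensional indexing is deprecated; "
            "use `arr[tuple(seq)]` instead of `arr[seq]`.",
            FutureWarning,
            stacklevel=2,
        )
        return tuple(idx)  # keep the current (legacy) semantics for now
    return idx  # plain integer lists are already array indices; leave them alone

x = np.arange(6).reshape(2, 3)
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    out = x[normalize_index([[0], [1]])]  # warns, then indexes like x[[0], [1]]
    plain = x[normalize_index([0, 1])]    # list of ints: selects rows 0 and 1, no warning
```

A plain list of integers such as `[0, 1]` means the same thing under both the old and new rules, so only the ambiguous case triggers the warning.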
Thoughts?