Showing content from https://github.com/jax-ml/jax/issues/4564 below:

Should JAX deprecate indexing with lists? · Issue #4564 · jax-ml/jax · GitHub

Since numpy 1.16, indexing with a list in place of a tuple has raised a FutureWarning (see numpy/numpy#9686 for a discussion of the rationale):

```python
>>> import numpy as np
>>> x = np.arange(6).reshape(2, 3)
>>> idx = [[0], [1]]
>>> x[idx]
FutureWarning: Using a non-tuple sequence for multidimensional indexing is deprecated; use `arr[tuple(seq)]` instead of `arr[seq]`. In the future this will be interpreted as an array index, `arr[np.array(seq)]`, which will result either in an error or a different result.
array([1])
```

As the warning says, the current behavior treats the list exactly as if it were a tuple:

```python
>>> x[tuple(idx)]
array([1])
```

while in the future, the indices will be treated as an array:

```python
>>> x[np.array(idx)]
array([[[0, 1, 2]],
       [[3, 4, 5]]])
```
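To make the difference between the two interpretations concrete, here is a small plain-NumPy sketch (reusing the `x` and `idx` from above) that checks each result against an equivalent explicit spelling. It avoids the deprecated `x[idx]` form entirely, so it should behave the same whichever NumPy version is installed.

```python
import numpy as np

x = np.arange(6).reshape(2, 3)
idx = [[0], [1]]

# Tuple interpretation: each list entry indexes a different axis,
# so this is x[[0], [1]], i.e. the single element x[0, 1].
as_tuple = x[tuple(idx)]

# Array interpretation: the whole (2, 1) integer array indexes axis 0 only,
# which is the same as taking rows along axis 0 with np.take.
as_array = x[np.array(idx)]

print(np.array_equal(as_tuple, x[[0], [1]]))  # True
print(np.array_equal(as_array, np.take(x, np.array(idx), axis=0)))  # True
print(as_tuple.shape, as_array.shape)  # (1,) (2, 1, 3)
```

The shapes show why silently switching from one meaning to the other is dangerous: the same index expression can go from selecting one element to selecting whole rows.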

JAX currently implements the old, deprecated behavior, without any warning:

```python
>>> import jax.numpy as jnp
>>> jnp.array(x)[idx]
DeviceArray([1], dtype=int32)
```

This is setting us up for a future in which numpy and JAX have different indexing semantics for lists of indices. I propose that we follow numpy and start warning about this behavior now, so that when a numpy release finally switches to the new array interpretation, JAX will be ready to follow suit immediately.
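One way the proposed warning could look is sketched below. This is not JAX's implementation: `normalize_list_index` is a hypothetical helper, and the rule it applies (treat a list as a tuple only when it contains a slice, `Ellipsis`, `None`, or a nested sequence/array) is an assumed approximation of numpy's legacy heuristic. It emits a FutureWarning reusing the first sentence of numpy's message, then falls back to the tuple interpretation so existing code keeps its current result.

```python
import warnings

import numpy as np


def normalize_list_index(idx):
    """Hypothetical helper: warn about, then emulate, list-as-tuple indexing.

    Assumption: a list is treated as a tuple of per-axis indices (the legacy
    behavior) only when it contains a slice, Ellipsis, None, or a nested
    sequence/array. Any other index is returned unchanged.
    """
    if not isinstance(idx, list):
        return idx
    legacy = any(
        item is None
        or item is Ellipsis
        or isinstance(item, (slice, list, tuple, np.ndarray))
        for item in idx
    )
    if not legacy:
        # A flat list of ints is already an ordinary integer-array index.
        return idx
    warnings.warn(
        "Using a non-tuple sequence for multidimensional indexing is "
        "deprecated; use `arr[tuple(seq)]` instead of `arr[seq]`.",
        FutureWarning,
        stacklevel=2,
    )
    return tuple(idx)


x = np.arange(6).reshape(2, 3)
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    result = x[normalize_list_index([[0], [1]])]
print(result)  # [1]
print(caught[0].category.__name__)  # FutureWarning
```

A flat list such as `[0, 1]` passes through without a warning, since numpy already treats that kind of list as an integer-array index rather than as a tuple.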

Thoughts?

