Compute the differences of the elements of the flattened array.
JAX implementation of numpy.ediff1d().
Parameters:
ary (ArrayLike) – input array or scalar.
to_end (ArrayLike | None) – scalar or array, optional, default=None. Specifies the numbers to append to the resulting array.
to_begin (ArrayLike | None) – scalar or array, optional, default=None. Specifies the numbers to prepend to the resulting array.
Returns:
An array containing the differences between consecutive elements of the flattened input array, with to_begin prepended and to_end appended when they are given.
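To make the roles of to_begin and to_end concrete, here is a minimal sketch of the documented behaviour built from ravel, slicing and concatenate. The function name ediff1d_sketch is hypothetical and this is not JAX's source code; it only assumes what the parameter and return descriptions state, plus the casting to the dtype of ary described in the Note.

```python
import jax.numpy as jnp

def ediff1d_sketch(ary, to_end=None, to_begin=None):
    """Hypothetical re-implementation for illustration only."""
    # Flatten first, so N-dimensional inputs become a single 1-D sequence.
    flat = jnp.ravel(jnp.asarray(ary))
    # Differences between consecutive elements: flat[i + 1] - flat[i].
    parts = [flat[1:] - flat[:-1]]
    if to_begin is not None:
        # Cast to the input's dtype, then flatten so scalars and arrays both work.
        parts.insert(0, jnp.ravel(jnp.asarray(to_begin, dtype=flat.dtype)))
    if to_end is not None:
        parts.append(jnp.ravel(jnp.asarray(to_end, dtype=flat.dtype)))
    return jnp.concatenate(parts)

a = jnp.array([2, 3, 5, 9, 1, 4])
print(ediff1d_sketch(a, to_begin=-10, to_end=jnp.array([20, 30])))
print(jnp.ediff1d(a, to_begin=-10, to_end=jnp.array([20, 30])))
```

Both calls should print the same values, which is a quick way to check the sketch against the real function.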
Note
Unlike NumPy’s implementation of ediff1d, jax.numpy.ediff1d() will not issue an error if casting to_end or to_begin to the type of ary loses precision.
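The difference described in the Note can be seen with an integer input and a non-integer to_begin. In this sketch JAX casts 1.5 to the integer dtype of ary (XLA's float-to-int conversion truncates toward zero, giving 1), while NumPy's ediff1d raises TypeError because float to int is not a same_kind cast. The variable names are illustrative only.

```python
import numpy as np
import jax.numpy as jnp

ary = jnp.array([1, 2, 4])       # integer input
to_begin = jnp.array([1.5])      # float value that the integer dtype cannot hold exactly

# JAX casts to_begin to ary's dtype without complaint; 1.5 becomes 1.
jax_result = jnp.ediff1d(ary, to_begin=to_begin)
print(jax_result)

# NumPy rejects the same lossy cast.
try:
    np.ediff1d(np.array([1, 2, 4]), to_begin=np.array([1.5]))
    numpy_raised = False
except TypeError:
    numpy_raised = True
print(numpy_raised)
```

If silent truncation is a concern, cast to_begin and to_end explicitly, or check their dtypes, before calling jax.numpy.ediff1d().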
See also
jax.numpy.diff(): Computes the n-th order difference between elements of the array along a given axis.
jax.numpy.cumsum(): Computes the cumulative sum of the elements of the array along a given axis.
jax.numpy.gradient(): Computes the gradient of an N-dimensional array.
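The relationship to the first two related functions can be sketched briefly: for a 1-D input, ediff1d gives the same values as first-order jnp.diff, and prepending the first element with to_begin and then applying jnp.cumsum reconstructs the original array. This is an illustration under those assumptions, not a statement about how either function is implemented.

```python
import jax.numpy as jnp

a = jnp.array([2, 3, 5, 9, 1, 4])

# For a 1-D input, ediff1d matches first-order jnp.diff.
same_as_diff = jnp.array_equal(jnp.ediff1d(a), jnp.diff(a))

# Keeping the first element via to_begin lets cumsum undo the differencing.
reconstructed = jnp.cumsum(jnp.ediff1d(a, to_begin=a[0]))

print(same_as_diff)
print(reconstructed)
```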
Examples
>>> import jax.numpy as jnp
>>> a = jnp.array([2, 3, 5, 9, 1, 4])
>>> jnp.ediff1d(a)
Array([ 1,  2,  4, -8,  3], dtype=int32)
>>> jnp.ediff1d(a, to_begin=-10)
Array([-10,   1,   2,   4,  -8,   3], dtype=int32)
>>> jnp.ediff1d(a, to_end=jnp.array([20, 30]))
Array([ 1,  2,  4, -8,  3, 20, 30], dtype=int32)
>>> jnp.ediff1d(a, to_begin=-10, to_end=jnp.array([20, 30]))
Array([-10,   1,   2,   4,  -8,   3,  20,  30], dtype=int32)
For an array with ndim > 1, the differences are computed after flattening the input array.
>>> a1 = jnp.array([[2, -1, 4, 7],
...                 [3, 5, -6, 9]])
>>> jnp.ediff1d(a1)
Array([ -3,   5,   3,  -4,   2, -11,  15], dtype=int32)
>>> a2 = jnp.array([2, -1, 4, 7, 3, 5, -6, 9])
>>> jnp.ediff1d(a2)
Array([ -3,   5,   3,  -4,   2, -11,  15], dtype=int32)
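Because of the flattening, ediff1d also includes the difference across the row boundary (3 - 7 = -4 in the example above), whereas jnp.diff works along one axis and keeps the array's shape. The sketch below contrasts the two on the same a1 and checks that flattening explicitly before jnp.diff reproduces ediff1d.

```python
import jax.numpy as jnp

a1 = jnp.array([[2, -1, 4, 7],
                [3, 5, -6, 9]])

# ediff1d flattens first: 7 values, including 3 - 7 across the row boundary.
flat_diffs = jnp.ediff1d(a1)

# jnp.diff differences along the last axis by default: shape (2, 3), no cross-row term.
row_diffs = jnp.diff(a1, axis=-1)

# Flattening explicitly before jnp.diff gives the same result as ediff1d.
matches = jnp.array_equal(flat_diffs, jnp.diff(jnp.ravel(a1)))

print(flat_diffs.shape, row_diffs.shape, matches)
```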