Compute the set intersection of two 1D arrays.
JAX implementation of numpy.intersect1d().
Because the size of the output of intersect1d is data-dependent, the function is not typically compatible with jit() and other JAX transformations. The JAX version adds the optional size argument, which must be specified statically for jnp.intersect1d to be used in such contexts.
Parameters:
ar1 (ArrayLike) – first array of values to intersect.
ar2 (ArrayLike) – second array of values to intersect.
assume_unique (bool) – if True, assume the input arrays contain unique values. This allows a more efficient implementation, but if assume_unique is True and the input arrays contain duplicates, the behavior is undefined. default: False.
return_indices (bool) – if True, return arrays of indices specifying where the intersected values first appear in the input arrays.
size (int | None) – if specified, return only the first size sorted elements. If there are fewer elements than size indicates, the return value will be padded with fill_value, and returned indices will be padded with an out-of-bound index.
fill_value (ArrayLike | None) – when size is specified and there are fewer than the indicated number of elements, fill the remaining entries with fill_value. Defaults to the smallest value in the intersection.
Returns:
An array intersection, or if return_indices=True, a tuple of arrays (intersection, ar1_indices, ar2_indices). Returned values are:
intersection: a 1D array containing each value that appears in both ar1 and ar2.
ar1_indices: (returned if return_indices=True) an array of shape intersection.shape containing the indices in flattened ar1 of values in intersection. For 1D inputs, intersection is equivalent to ar1[ar1_indices].
ar2_indices: (returned if return_indices=True) an array of shape intersection.shape containing the indices in flattened ar2 of values in intersection. For 1D inputs, intersection is equivalent to ar2[ar2_indices].
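Because the returned indices refer to the flattened inputs, they should be applied to ar1.ravel() rather than to ar1 itself when ar1 has more than one dimension. A small sketch with a 2D first argument (the input values are illustrative):

```python
import jax.numpy as jnp

ar1 = jnp.array([[1, 5],
                 [3, 7]])  # 2D input; indices refer to ar1.ravel() == [1, 5, 3, 7]
ar2 = jnp.array([7, 3, 9])

intersection, ar1_indices, ar2_indices = jnp.intersect1d(
    ar1, ar2, return_indices=True
)

# Gather from the flattened first array to recover the intersected values.
print(intersection)              # sorted common values
print(ar1.ravel()[ar1_indices])  # same values, taken from flattened ar1
print(ar2[ar2_indices])          # same values, taken from ar2
```

Since every value in this example occurs only once, each index array is fully determined: 3 and 7 sit at positions 2 and 3 of the flattened ar1 and at positions 1 and 0 of ar2.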
Examples
>>> ar1 = jnp.array([1, 2, 3, 4])
>>> ar2 = jnp.array([3, 4, 5, 6])
>>> jnp.intersect1d(ar1, ar2)
Array([3, 4], dtype=int32)
Computing intersection with indices:
>>> intersection, ar1_indices, ar2_indices = jnp.intersect1d(ar1, ar2, return_indices=True)
>>> intersection
Array([3, 4], dtype=int32)
ar1_indices gives the indices of the intersected values within ar1:
>>> ar1_indices
Array([2, 3], dtype=int32)
>>> jnp.all(intersection == ar1[ar1_indices])
Array(True, dtype=bool)
ar2_indices gives the indices of the intersected values within ar2:
>>> ar2_indices
Array([0, 1], dtype=int32)
>>> jnp.all(intersection == ar2[ar2_indices])
Array(True, dtype=bool)
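When the inputs are already known to contain no duplicates, assume_unique=True permits the more efficient implementation mentioned in the parameter list. A minimal sketch, with illustrative inputs that contain no repeated values, comparing it against the default path:

```python
import jax.numpy as jnp

# Neither array contains repeated values, so assume_unique=True is valid here.
ar1 = jnp.array([7, 1, 5, 3])
ar2 = jnp.array([5, 9, 1])

fast = jnp.intersect1d(ar1, ar2, assume_unique=True)
safe = jnp.intersect1d(ar1, ar2)  # default path also handles duplicate inputs

print(fast, safe)  # both hold the sorted common values
```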