Return unique values from x, along with counts.
JAX implementation of numpy.unique_counts(); this is equivalent to calling jax.numpy.unique() with return_counts and equal_nan set to True.
Because the size of the output of unique_counts is data-dependent, the function is not typically compatible with jit() and other JAX transformations. The JAX version adds the optional size argument, which must be specified statically for jnp.unique_counts to be used in such contexts.
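Here is a minimal sketch of using size under jit(). The wrapper name counts_with_static_size is hypothetical. The only requirement it illustrates is that size is passed as a static Python int, so the output shape is known at trace time.

```python
from functools import partial

import jax
import jax.numpy as jnp

# Hypothetical wrapper: `size` fixes the output shape, so it is declared
# static (a hashable Python int) rather than traced.
@partial(jax.jit, static_argnames=["size"])
def counts_with_static_size(x, size):
    return jnp.unique_counts(x, size=size)

x = jnp.array([3, 4, 1, 3, 1])
result = counts_with_static_size(x, size=5)
print(result.values.shape, result.counts.shape)  # both (5,)

# Without `size`, tracing fails because the number of unique values depends
# on the data; the exact exception type may vary across JAX versions.
try:
    jax.jit(jnp.unique_counts)(x)
except Exception as err:
    print("jit without size failed:", type(err).__name__)
```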
x (ArrayLike) – N-dimensional array from which unique values will be extracted.
size (int | None) – if specified, return only the first size sorted unique elements. If there are fewer unique elements than size indicates, the return value will be padded with fill_value.
fill_value (ArrayLike | None) – when size is specified and there are fewer than the indicated number of elements, fill the remaining entries with fill_value. Defaults to the minimum unique value.
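The following sketch shows how size and fill_value interact, using the array from the examples. The sketch only inspects values in the padded cases. Padding is applied when size exceeds the number of unique values, and truncation to the first size sorted values happens when it is smaller.

```python
import jax.numpy as jnp

x = jnp.array([3, 4, 1, 3, 1])  # three unique values: 1, 3, 4

# size larger than the number of unique values: pad with fill_value.
padded = jnp.unique_counts(x, size=5, fill_value=-1)
print(padded.values)  # unique values followed by two -1 entries

# Without fill_value, padding uses the minimum unique value (here 1).
padded_default = jnp.unique_counts(x, size=5)
print(padded_default.values)

# size smaller than the number of unique values: keep only the first
# `size` sorted unique values and their counts.
truncated = jnp.unique_counts(x, size=2)
print(truncated.values, truncated.counts)
```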
Returns a tuple (values, counts) with the following properties:
values: an array of shape (n_unique,), or (size,) when size is given, containing the sorted unique values from x.
counts: an array of the same shape as values, containing the number of occurrences of each unique value in x.
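The two outputs fit together in a simple way. Each entry of values occurs exactly counts[i] times in x, so the counts sum to the input size. Repeating each sorted unique value by its count reproduces the sorted input. The sketch below checks both properties using jnp.repeat outside of jit, where the repeat counts are concrete.

```python
import jax.numpy as jnp

x = jnp.array([3, 4, 1, 3, 1])
values, counts = jnp.unique_counts(x)

# Every element of x is counted exactly once.
assert int(counts.sum()) == x.size

# Repeating each sorted unique value by its count rebuilds sort(x).
rebuilt = jnp.repeat(values, counts)
print(rebuilt)  # same as jnp.sort(x)
```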
Examples
Here we compute the unique values in a 1D array:
>>> x = jnp.array([3, 4, 1, 3, 1])
>>> result = jnp.unique_counts(x)
The result is a NamedTuple with two named attributes. The values attribute contains the unique values from the array:
>>> result.values
Array([1, 3, 4], dtype=int32)
The counts attribute contains the counts of each unique value in the input:
>>> result.counts
Array([2, 2, 1], dtype=int32)
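Since the result is a NamedTuple, it can also be unpacked positionally. The sketch below pairs each value with its count in a plain Python dict. The name histogram is illustrative only, and converting with tolist() copies the arrays to the host, so this step belongs outside jit.

```python
import jax.numpy as jnp

x = jnp.array([3, 4, 1, 3, 1])
result = jnp.unique_counts(x)

# NamedTuple results unpack like ordinary tuples.
values, counts = result

# Build a value -> count mapping on the host.
histogram = dict(zip(values.tolist(), counts.tolist()))
print(histogram)  # {1: 2, 3: 2, 4: 1}
```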
For examples of the size and fill_value arguments, see jax.numpy.unique().