Count the number of occurrences of each value in an integer array.

JAX implementation of numpy.bincount().

For an array of non-negative integers x, this function returns an array counts of size x.max() + 1, such that counts[i] contains the number of occurrences of the value i in x.
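The definition above can be illustrated in plain Python. The helper `bincount_reference` is hypothetical and is not the JAX implementation. It assumes a non-empty list of non-negative integers.

```python
def bincount_reference(x):
    """Count occurrences of each value in a non-empty list of non-negative ints.

    Hypothetical plain-Python illustration of the definition, not JAX code.
    """
    counts = [0] * (max(x) + 1)  # output size is x.max() + 1
    for value in x:
        counts[value] += 1  # counts[i] ends up as the number of times i occurs
    return counts


print(bincount_reference([1, 1, 2, 3, 3, 3]))  # [0, 2, 1, 3]
```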
The JAX version has a few differences from the NumPy version:

- In NumPy, passing an array x with negative entries results in an error. In JAX, negative values are clipped to zero.
- JAX adds an optional length parameter, which can be used to statically specify the length of the output array so that this function can be used with transformations like jax.jit(). In this case, values greater than or equal to length are dropped.
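The clip-and-drop behavior can be sketched with JAX's indexed-update syntax, `.at[...].add(...)`. This is a minimal sketch that assumes a scatter-add formulation. `bincount_sketch` is a hypothetical name and is not part of JAX. Because `length` is marked static, the output shape is known when the function is traced, which is what `jax.jit()` needs.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames=["length"])
def bincount_sketch(x, weights=None, *, length):
    """Hypothetical scatter-add sketch of JAX-style bincount semantics."""
    if weights is None:
        # Each entry contributes a count of 1.
        weights = jnp.ones_like(x)
    # Clip negatives into bin 0; otherwise an index of -1 would wrap to the end.
    idx = jnp.clip(x, 0)
    # Scatter-add the weights; indices >= length are skipped, not clamped.
    return jnp.zeros(length, dtype=weights.dtype).at[idx].add(weights, mode="drop")


x = jnp.array([-1, -1, 1, 3, 10])
print(bincount_sketch(x, length=5))  # [2 1 0 1 0]
```

Passing `mode="drop"` explicitly documents that out-of-range indices contribute nothing. With `mode="clip"`, the value 10 would instead be added into the last bin.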
Parameters:

- x (ArrayLike) – 1-dimensional array of non-negative integers.
- weights (ArrayLike | None) – optional array of weights associated with x. If not specified, the weight for each entry is 1.
- minlength (int) – the minimum length of the output counts array.
- length (int | None) – the length of the output counts array. Must be specified statically for bincount to be used with jax.jit() and other JAX transformations.
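The interaction between weights and minlength, for the case where length is not given, can be modeled in plain Python. This sketch assumes a non-empty x with non-negative entries. It also assumes the output size is the larger of minlength and x.max() + 1, following NumPy's documented rule that minlength sets a minimum number of bins. `weighted_bincount` is a hypothetical helper, not a JAX function.

```python
def weighted_bincount(x, weights=None, minlength=0):
    """Hypothetical model: sum weights per value, with at least minlength bins."""
    if weights is None:
        weights = [1] * len(x)  # default weight of 1 gives plain counting
    if len(weights) != len(x):
        raise ValueError("weights must have the same length as x")
    # Enough bins to hold max(x), and never fewer than minlength.
    size = max(minlength, max(x) + 1)
    out = [0] * size
    for value, weight in zip(x, weights):
        out[value] += weight  # accumulate the weight instead of adding 1
    return out


x = [1, 1, 2, 3, 3, 3]
print(weighted_bincount(x, [1, 2, 3, 4, 5, 6]))  # [0, 3, 3, 15]
print(weighted_bincount(x, minlength=6))  # [0, 2, 1, 3, 0, 0]
```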
Returns:

An array of counts or summed weights reflecting the number of occurrences of values in x.
Examples

Basic bincount:

>>> import jax
>>> import jax.numpy as jnp
>>> x = jnp.array([1, 1, 2, 3, 3, 3])
>>> jnp.bincount(x)
Array([0, 2, 1, 3], dtype=int32)
Weighted bincount:

>>> weights = jnp.array([1, 2, 3, 4, 5, 6])
>>> jnp.bincount(x, weights)
Array([ 0,  3,  3, 15], dtype=int32)
Specifying a static length makes this jit-compatible:

>>> jit_bincount = jax.jit(jnp.bincount, static_argnames=['length'])
>>> jit_bincount(x, length=5)
Array([0, 2, 1, 3, 0], dtype=int32)
Any negative numbers are clipped to the first bin, and numbers greater than or equal to the specified length are dropped:

>>> x = jnp.array([-1, -1, 1, 3, 10])
>>> jnp.bincount(x, length=5)
Array([2, 1, 0, 1, 0], dtype=int32)