Compute an N-dimensional histogram.
JAX implementation of numpy.histogramdd().
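Because this function mirrors numpy.histogramdd(), one way to check the correspondence is to run both on the same data. This is a minimal sketch that assumes NumPy and JAX are both installed. The random sample, bin count, and range are arbitrary choices for illustration, and np.allclose absorbs small float32/float64 differences between the two libraries.

```python
import numpy as np
import jax
import jax.numpy as jnp

# Arbitrary sample: 200 points in 2 dimensions, uniform on [0, 1).
key = jax.random.key(0)
sample = jax.random.uniform(key, (200, 2))

# Pass the same bins and range to both implementations.
bins = 5
hist_range = [(0.0, 1.0), (0.0, 1.0)]

jax_hist, jax_edges = jnp.histogramdd(sample, bins=bins, range=hist_range)
np_hist, np_edges = np.histogramdd(np.asarray(sample), bins=bins, range=hist_range)

# Counts and per-dimension bin edges should agree up to floating-point tolerance.
same_counts = np.allclose(np.asarray(jax_hist), np_hist)
same_edges = all(np.allclose(np.asarray(j), n) for j, n in zip(jax_edges, np_edges))
print(same_counts, same_edges)  # both should be True
```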
sample (ArrayLike) – input array of shape (N, D) representing N points in D dimensions.
bins (ArrayLike | list[ArrayLike]) – Specifies the number of bins in each dimension of the histogram (default: 10). May also be a length-D sequence of integers or arrays of bin edges.
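range (Sequence[None | Array | Sequence[ArrayLike]] | None) – Length-D sequence of pairs specifying the range for each dimension. An individual entry may be None to infer that dimension's range from the data. If range is not specified at all, the range of every dimension is inferred from the data.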
weights (ArrayLike | None) – An optional shape (N,) array specifying the weights of the data points, with one weight per row (point) of sample. If not specified, each data point is weighted equally.
density (bool | None) – If True, return the normalized histogram in units of counts per unit volume. If False or None (the default), return the (weighted) counts per bin.
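The sketch below exercises several of these parameters together: a per-dimension bins sequence that mixes an integer count with an explicit array of edges, a range with a None entry, and a shape (N,) weights array. The data and the names points, weights, and y_edges are illustrative only. The points are clipped so that every one falls inside the bins, and the range entry for the second dimension is None because its edges are given explicitly.

```python
import jax
import jax.numpy as jnp

key = jax.random.key(1)
# N=50 points in D=2 dimensions, clipped so all of them lie within the bins below.
points = jnp.clip(jax.random.normal(key, (50, 2)), -3.0, 3.0)
weights = jnp.full(points.shape[0], 0.5)  # shape (N,): one weight per point

# Dimension 0: 6 equal-width bins over (-3, 3).
# Dimension 1: explicit edges, which define 4 bins of unequal width.
y_edges = jnp.array([-3.0, -1.0, 0.0, 1.0, 3.0])
hist, edges = jnp.histogramdd(
    points,
    bins=[6, y_edges],
    range=[(-3.0, 3.0), None],
    weights=weights,
)
print(hist.shape)  # (6, 4)
print(hist.sum())  # total weight: 50 points * 0.5
```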
A tuple of arrays (histogram, bin_edges), where histogram contains the aggregated data and bin_edges is a list of D arrays giving the bin boundaries along each dimension.
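To make the return structure concrete, here is a small case that can be checked by hand. It passes range explicitly. In this sketch, the last point has a first coordinate outside that range, so it is not counted in any bin.

```python
import jax.numpy as jnp

sample = jnp.array([[0.1, 0.2],
                    [0.4, 0.9],
                    [0.6, 0.6],
                    [2.0, 0.5]])  # last point lies outside the x range below

# 2 bins along x over [0, 1], 3 bins along y over [0, 1].
hist, bin_edges = jnp.histogramdd(sample, bins=[2, 3], range=[(0, 1), (0, 1)])

print(hist.shape)                    # (2, 3): one entry per bin
print([e.shape for e in bin_edges])  # [(3,), (4,)]: bins + 1 edges per dimension
print(hist.tolist())                 # [[1, 0, 1], [0, 1, 0]]
```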
Examples
A histogram over 100 points in three dimensions:
>>> import jax
>>> import jax.numpy as jnp
>>> key = jax.random.key(42)
>>> a = jax.random.normal(key, (100, 3))
>>> counts, bin_edges = jnp.histogramdd(a, bins=6,
...                                     range=[(-3, 3), (-3, 3), (-3, 3)])
>>> counts.shape
(6, 6, 6)
>>> bin_edges
[Array([-3., -2., -1.,  0.,  1.,  2.,  3.], dtype=float32),
 Array([-3., -2., -1.,  0.,  1.,  2.,  3.], dtype=float32),
 Array([-3., -2., -1.,  0.,  1.,  2.,  3.], dtype=float32)]
Using density=True returns a normalized histogram:
>>> density, bin_edges = jnp.histogramdd(a, density=True)
>>> bin_widths = map(jnp.diff, bin_edges)
>>> dx, dy, dz = jnp.meshgrid(*bin_widths, indexing='ij')
>>> normed = jnp.sum(density * dx * dy * dz)
>>> jnp.allclose(normed, 1.0)
Array(True, dtype=bool)
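The density normalization can also be written out directly. For unweighted data, each entry equals the bin's count divided by the total count and by the bin's volume (the product of its per-dimension widths). The sketch below checks that relationship on the same kind of random sample, assuming 4 bins per dimension and an inferred range, so every point is counted.

```python
import jax
import jax.numpy as jnp

key = jax.random.key(42)
a = jax.random.normal(key, (100, 3))

counts, edges = jnp.histogramdd(a, bins=4)
density, _ = jnp.histogramdd(a, bins=4, density=True)

# Volume of each bin: outer product of the per-dimension bin widths.
dx, dy, dz = jnp.meshgrid(*[jnp.diff(e) for e in edges], indexing='ij')
volume = dx * dy * dz

# density = counts / (total count * bin volume)
manual = counts / counts.sum() / volume
print(jnp.allclose(density, manual))
```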