
Showing content from https://docs.jax.dev/en/latest/jax.random.html below:

jax.random module — JAX documentation

jax.random module

Utilities for pseudo-random number generation.

The jax.random package provides a number of routines for deterministic generation of sequences of pseudorandom numbers.

Basic usage
>>> import jax
>>> seed = 1701
>>> num_steps = 100
>>> key = jax.random.key(seed)
>>> for i in range(num_steps):
...   key, subkey = jax.random.split(key)
...   params = compiled_update(subkey, params, next(batches))  # compiled_update, params and batches are user-defined
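
The loop above leaves compiled_update, params and batches undefined. The following is a minimal runnable sketch that fills them in with hypothetical stand-ins: a jitted gradient step on a toy linear-regression problem, with a little key-dependent noise so that the subkey is actually consumed. Only the split-then-consume pattern comes from the text; the model, loss, learning rate and noise scale are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy objective: mean squared error of a linear model.
def loss(params, batch):
    x, y = batch
    return jnp.mean((x @ params - y) ** 2)

# Hypothetical stand-in for compiled_update: one gradient step plus
# a small amount of Gaussian noise drawn from the provided subkey.
@jax.jit
def compiled_update(subkey, params, batch, lr=0.1, noise_scale=1e-3):
    grads = jax.grad(loss)(params, batch)
    noise = noise_scale * jax.random.normal(subkey, params.shape)
    return params - lr * grads + noise

# Hypothetical infinite batch generator; each batch uses a fresh subkey.
def make_batches(key, true_w, batch_size=32):
    while True:
        key, subkey = jax.random.split(key)
        x = jax.random.normal(subkey, (batch_size, true_w.shape[0]))
        yield x, x @ true_w

seed = 1701
num_steps = 100
key = jax.random.key(seed)
true_w = jnp.array([2.0, -1.0, 0.5])

# Reserve a separate key for data generation so it never overlaps with update keys.
key, data_key = jax.random.split(key)
batches = make_batches(data_key, true_w)
params = jnp.zeros(3)

for i in range(num_steps):
    key, subkey = jax.random.split(key)
    params = compiled_update(subkey, params, next(batches))
```

The loop rebinds key on every step and hands only subkey to the update, so no key is consumed twice.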

PRNG keys

Unlike the stateful pseudorandom number generators (PRNGs) that users of NumPy and SciPy may be accustomed to, JAX random functions all require an explicit PRNG state to be passed as a first argument. The random state is described by a special array element type that we call a key, usually generated by the jax.random.key() function:

>>> from jax import random
>>> key = random.key(0)
>>> key
Array((), dtype=key<fry>) overlaying:
[0 0]

This key can then be used in any of JAX’s random number generation routines:

>>> random.uniform(key)
Array(0.947667, dtype=float32)

Note that using a key does not modify it, so reusing the same key will lead to the same result:

>>> random.uniform(key)
Array(0.947667, dtype=float32)

If you need a new random number, you can use jax.random.split() to generate new subkeys:

>>> key, subkey = random.split(key)
>>> random.uniform(subkey)
Array(0.00729382, dtype=float32)
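
As a sketch of the points above (using a key does not modify it, and split produces fresh keys), the following uses only jax.random.key, split, uniform and normal. The split counts are arbitrary choices for illustration.

```python
import jax
from jax import random

key = random.key(0)

# Reusing a key reproduces the same draw; the key itself is never mutated.
a = random.uniform(key)
b = random.uniform(key)

# split(key, num) returns a key array with a new leading axis of length num.
# Keep one key to carry forward and consume the other three.
key, *subkeys = random.split(key, 4)
draws = [random.uniform(k) for k in subkeys]

# A key array can also be consumed in a single batched call via jax.vmap.
batch_keys = random.split(key, 5)
batch = jax.vmap(lambda k: random.normal(k, (2,)))(batch_keys)
```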

Note

Typed key arrays, with element types such as key<fry> above, were introduced in JAX v0.4.16. Before then, keys were conventionally represented in uint32 arrays, whose final dimension represented the key’s bit-level representation.

Both forms of key array can still be created and used with the jax.random module. New-style typed key arrays are made with jax.random.key(). Legacy uint32 key arrays are made with jax.random.PRNGKey().

To convert between the two, use jax.random.key_data() and jax.random.wrap_key_data(). The legacy key format may be needed when interfacing with systems outside of JAX (e.g. exporting arrays to a serializable format), or when passing keys to JAX-based libraries that assume the legacy format.

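Otherwise, typed keys are recommended. Caveats of legacy keys relative to typed ones include:

  1. They have an extra trailing dimension.

  2. They have a numeric dtype (uint32), which allows arithmetic operations that are typically not meant to be carried out on keys.

  3. They do not carry information about the RNG implementation. When legacy keys are passed to jax.random functions, a global configuration setting determines the implementation (see Advanced RNG configuration below).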

To learn more about this upgrade, and the design of key types, see JEP 9263.
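
A minimal sketch of converting between the two key formats, assuming the default Threefry implementation, so that random.key(0) and random.PRNGKey(0) share the same underlying bits (as the [0 0] output earlier suggests). It also shows the trailing dimension and uint32 dtype that distinguish legacy keys.

```python
import jax
import jax.numpy as jnp
from jax import random

typed = random.key(0)         # new-style key: shape (), dtype key<fry>
legacy = random.PRNGKey(0)    # legacy key: shape (2,), dtype uint32

# key_data exposes the raw uint32 bits underlying a typed key.
bits = random.key_data(typed)

# wrap_key_data turns raw bits (here, a legacy key) back into a typed key.
rewrapped = random.wrap_key_data(legacy)

# Typed keys have a dedicated key dtype; legacy keys are plain unsigned integers.
is_typed = jnp.issubdtype(typed.dtype, jax.dtypes.prng_key)
is_legacy_typed = jnp.issubdtype(legacy.dtype, jax.dtypes.prng_key)

# Both forms are accepted by the samplers; with identical bits and the
# default implementation they produce identical values.
same = bool(random.uniform(typed) == random.uniform(legacy))
```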

Advanced

Design and background

TLDR: JAX PRNG = Threefry counter PRNG + a functional array-oriented splitting model

See docs/jep/263-prng.md for more details.

To summarize, among other requirements, the JAX PRNG aims to:

  1. ensure reproducibility,

  2. parallelize well, both in terms of vectorization (generating array values) and multi-replica, multi-core computation. In particular it should not use sequencing constraints between random function calls.
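
One way to see both design goals in code, as a sketch assuming the default Threefry implementation: split is a pure function of its input key, and because each key is independent state, a batch of keys can be consumed with jax.vmap with no ordering between draws. We expect the vectorized and sequential results to agree under Threefry; other implementations may not make that guarantee.

```python
import jax
import jax.numpy as jnp
from jax import random

key = random.key(42)

# Reproducibility: splitting the same key twice yields identical key arrays.
k1 = random.split(key, 3)
k2 = random.split(key, 3)

# Parallelism: no sequencing between draws, so four keys can be consumed
# in one vectorized call and compared against a plain Python loop.
keys = random.split(key, 4)
vectorized = jax.vmap(lambda k: random.uniform(k, (3,)))(keys)
sequential = jnp.stack([random.uniform(k, (3,)) for k in keys])
```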

Advanced RNG configuration

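JAX provides several PRNG implementations. A specific one can be selected with the optional impl keyword argument to jax.random.key. When no impl option is passed to the key constructor, the implementation is determined by the global jax_default_prng_impl configuration flag. The string names of available implementations are:

  1. "threefry2x32" (default): a counter-based PRNG based on a variant of the Threefish hash function.

  2. "rbg" and "unsafe_rbg": PRNGs built atop XLA's Random Bit Generator (RBG) algorithm. "rbg" uses XLA RBG for random number generation but derives keys (as in jax.random.split and jax.random.fold_in) the same way as "threefry2x32"; "unsafe_rbg" uses XLA RBG for both generation and key derivation.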
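
A sketch of choosing an implementation per key via the impl keyword, using the implementation names rbg and unsafe_rbg. The key-data shapes reflect how we understand these implementations to be laid out (two uint32 words for Threefry, four for the RBG-based ones); treat them as an observation rather than a documented contract.

```python
import jax
from jax import random

# Select an implementation per key; without impl, jax_default_prng_impl applies.
k_default = random.key(0)
k_rbg = random.key(0, impl="rbg")
k_unsafe = random.key(0, impl="unsafe_rbg")

# The underlying key data has a different size per implementation.
shapes = {
    name: random.key_data(k).shape
    for name, k in [("default", k_default), ("rbg", k_rbg), ("unsafe_rbg", k_unsafe)]
}

# Samplers accept keys from any implementation.
x = random.normal(k_rbg, (3,))

# To change the default for keys created without impl, one could instead run:
# jax.config.update("jax_default_prng_impl", "rbg")
```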

Reasons to use an alternative to the default (Threefry) implementation include that it:

  1. may be slow to compile for TPUs.

  2. is relatively slow to execute on TPUs.

Automatic partitioning:

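In order for jax.jit to efficiently auto-partition functions that generate sharded random number arrays (or key arrays), all PRNG implementations require extra flags:

  1. For "threefry2x32", and for "rbg" key derivation, set jax_threefry_partitionable=True.

  2. For "unsafe_rbg", and for "rbg" random generation, set the XLA flag --xla_tpu_spmd_rng_bit_generator_unsafe=1.

The XLA flag can be set using the XLA_FLAGS environment variable, e.g. as XLA_FLAGS=--xla_tpu_spmd_rng_bit_generator_unsafe=1.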
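
A sketch of opting in to the Threefry flag from Python with jax.config.update. The XLA flag is read from the environment when the backend starts and is TPU-specific, so it appears only as a comment; train.py there is a hypothetical script name. Newer JAX releases may already enable jax_threefry_partitionable by default, in which case the update changes nothing.

```python
import jax
from jax import random

# Opt in to partitionable Threefry key derivation for subsequent computations.
jax.config.update("jax_threefry_partitionable", True)

# The RBG-related XLA flag must be set in the environment before JAX starts,
# for example (train.py is a hypothetical script):
#   XLA_FLAGS=--xla_tpu_spmd_rng_bit_generator_unsafe=1 python train.py

@jax.jit
def sample(key):
    # When the output is sharded, jit may partition this sampling across devices.
    return random.normal(key, (8, 4))

x = sample(random.key(0))
```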

For more about jax_threefry_partitionable, see https://docs.jax.dev/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#generating-random-numbers

Summary:

Implementations compared: Threefry, Threefry*, rbg, unsafe_rbg, rbg**, unsafe_rbg**

Properties compared: fastest on TPU; efficiently shardable (w/ pjit); identical across shardings; identical across CPU/GPU/TPU; exact jax.vmap over keys

(*): with jax_threefry_partitionable=1 set

(**): with XLA_FLAGS=--xla_tpu_spmd_rng_bit_generator_unsafe=1 set

API Reference

Key Creation & Manipulation

key(seed, *[, impl])

Create a pseudo-random number generator (PRNG) key given an integer seed.

key_data(keys)

Recover the bits of key data underlying a PRNG key array.

wrap_key_data(key_bits_array, *[, impl])

Wrap an array of key data bits into a PRNG key array.

fold_in(key, data)

Folds in data to a PRNG key to form a new PRNG key.

split(key[, num])

Splits a PRNG key into num new keys by adding a leading axis.

clone(key)

Clone a key for reuse.

PRNGKey(seed, *[, impl])

Create a legacy PRNG key given an integer seed.
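
A short sketch of fold_in and split(key, num) from the list above. Folding a step counter into a base key is one common way to derive per-step keys; the step values, shapes and the helper name noise_at are arbitrary choices for illustration.

```python
import jax
import jax.numpy as jnp
from jax import random

base = random.key(0)

# fold_in mixes an integer into a key, deriving a new key per step.
step_keys = [random.fold_in(base, step) for step in range(3)]

# The derivation is deterministic: the same (key, data) pair gives the same key.
again = random.fold_in(base, 1)

# fold_in also works under jit with a traced step value (noise_at is hypothetical).
noise_at = jax.jit(lambda key, step: random.normal(random.fold_in(key, step), (2,)))

# split(key, num) adds a leading axis of length num.
split_keys = random.split(base, 3)
```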

Random Samplers

ball(key, d[, p, shape, dtype])

Sample uniformly from the unit Lp ball.

bernoulli(key[, p, shape, mode])

Sample Bernoulli random values with given shape and mean.

beta(key, a, b[, shape, dtype])

Sample Beta random values with given shape and float dtype.

binomial(key, n, p[, shape, dtype])

Sample Binomial random values with given shape and float dtype.

bits(key[, shape, dtype, out_sharding])

Sample uniform bits in the form of unsigned integers.

categorical(key, logits[, axis, shape, replace])

Sample random values from categorical distributions.

cauchy(key[, shape, dtype])

Sample Cauchy random values with given shape and float dtype.

chisquare(key, df[, shape, dtype])

Sample Chisquare random values with given shape and float dtype.

choice(key, a[, shape, replace, p, axis])

Generates a random sample from a given array.

dirichlet(key, alpha[, shape, dtype])

Sample Dirichlet random values with given shape and float dtype.

double_sided_maxwell(key, loc, scale[, ...])

Sample from a double sided Maxwell distribution.

exponential(key[, shape, dtype])

Sample Exponential random values with given shape and float dtype.

f(key, dfnum, dfden[, shape, dtype])

Sample F-distribution random values with given shape and float dtype.

gamma(key, a[, shape, dtype])

Sample Gamma random values with given shape and float dtype.

generalized_normal(key, p[, shape, dtype])

Sample from the generalized normal distribution.

geometric(key, p[, shape, dtype])

Sample Geometric random values with given shape and float dtype.

gumbel(key[, shape, dtype, mode])

Sample Gumbel random values with given shape and float dtype.

laplace(key[, shape, dtype])

Sample Laplace random values with given shape and float dtype.

loggamma(key, a[, shape, dtype])

Sample log-gamma random values with given shape and float dtype.

logistic(key[, shape, dtype])

Sample logistic random values with given shape and float dtype.

lognormal(key[, sigma, shape, dtype])

Sample lognormal random values with given shape and float dtype.

maxwell(key[, shape, dtype])

Sample from a one sided Maxwell distribution.

multinomial(key, n, p, *[, shape, dtype, unroll])

Sample from a multinomial distribution.

multivariate_normal(key, mean, cov[, shape, ...])

Sample multivariate normal random values with given mean and covariance.

normal(key[, shape, dtype, out_sharding])

Sample standard normal random values with given shape and float dtype.

orthogonal(key, n[, shape, dtype, m])

Sample uniformly from the orthogonal group O(n).

pareto(key, b[, shape, dtype])

Sample Pareto random values with given shape and float dtype.

permutation(key, x[, axis, independent, ...])

Returns a randomly permuted array or range.

poisson(key, lam[, shape, dtype])

Sample Poisson random values with given shape and integer dtype.

rademacher(key[, shape, dtype])

Sample from a Rademacher distribution.

randint(key, shape, minval, maxval[, dtype, ...])

Sample uniform random values in [minval, maxval) with given shape/dtype.

rayleigh(key, scale[, shape, dtype])

Sample Rayleigh random values with given shape and float dtype.

t(key, df[, shape, dtype])

Sample Student's t random values with given shape and float dtype.

triangular(key, left, mode, right[, shape, ...])

Sample Triangular random values with given shape and float dtype.

truncated_normal(key, lower, upper[, shape, ...])

Sample truncated standard normal random values with given shape and dtype.

uniform(key[, shape, dtype, minval, maxval, ...])

Sample uniform random values in [minval, maxval) with given shape/dtype.

wald(key, mean[, shape, dtype])

Sample Wald random values with given shape and float dtype.

weibull_min(key, scale, concentration[, ...])

Sample from a Weibull distribution.
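
A sketch exercising a handful of the samplers listed above with explicit shapes and bounds. Each sampler gets its own subkey, following the splitting discipline described earlier; the particular bounds, probabilities and sizes are arbitrary.

```python
import jax
import jax.numpy as jnp
from jax import random

key = random.key(0)
k_u, k_i, k_n, k_c, k_p, k_b, k_cat = random.split(key, 7)

u = random.uniform(k_u, (1000,), minval=-1.0, maxval=1.0)  # floats in [-1, 1)
i = random.randint(k_i, (1000,), minval=0, maxval=10)      # integers in [0, 10)
n = random.normal(k_n, (2, 3), dtype=jnp.float32)          # standard normal draws
c = random.choice(k_c, 10, shape=(5,), replace=False)      # 5 distinct values from range(10)
p = random.permutation(k_p, jnp.arange(6))                 # shuffled copy of 0..5
b = random.bernoulli(k_b, p=0.3, shape=(1000,))            # booleans, True with probability 0.3

# categorical takes unnormalized log-probabilities and returns class indices.
logits = jnp.log(jnp.array([0.2, 0.5, 0.3]))
cat = random.categorical(k_cat, logits, shape=(1000,))
```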

