
Showing content from https://github.com/jax-ml/jax/discussions/18480 below:

changes coming to JAX's RNG – auto-parallelizable by default · jax-ml/jax · Discussion #18480 · GitHub


Nov 10, 2023 · 3 comments · 2 replies


Partitionable Threefry RNG upgrade

Important

EDIT (January 2025): this upgrade is complete as of JAX v0.5.0; see the comment posted below for more details.

You can stop reading now if:

What is happening?

What do we mean by a one-time change to random values? Here is today's default behavior:

$ JAX_THREEFRY_PARTITIONABLE=False python -c '
> import jax
> print(jax.random.randint(jax.random.key(72), (), 0, 10))'
3

And here is what will happen when the default setting changes soon:

$ JAX_THREEFRY_PARTITIONABLE=True python -c '
> import jax
> print(jax.random.randint(jax.random.key(72), (), 0, 10))'
7

Same key, different generated value when JAX_THREEFRY_PARTITIONABLE=False versus JAX_THREEFRY_PARTITIONABLE=True.
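The same comparison can also be run inside one Python process by using the jax.threefry_partitionable context manager to scope each setting. This is a minimal sketch. It assumes your JAX version still accepts the jax_threefry_partitionable setting, which is slated for eventual deprecation.

```python
import jax

key = jax.random.key(72)

# Draw the same scalar under each setting. The context manager changes
# the flag only inside its block.
with jax.threefry_partitionable(False):
    legacy = int(jax.random.randint(key, (), 0, 10))
with jax.threefry_partitionable(True):
    partitionable = int(jax.random.randint(key, (), 0, 10))

print(legacy, partitionable)
```

The shell transcripts above report 3 and 7 for this key. Because only the distribution is guaranteed, the exact numbers you see may depend on your JAX version.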

JAX's RNG will remain deterministic. We try to change the output of JAX's pseudorandom functions only rarely. That said, our API policy promises stability in distribution, not in value. This particular change is broad-based, and is the first of its kind in a long time, so we're drawing extra attention to it.

This change will break code that depends on specific RNG keys generating specific RNG values. Common examples include reference tests, high-variance randomized tests, and machine learning experiments that depend on random values (e.g. model initialization) that you want to reproduce exactly.

Non-partitionable Threefry will later be deprecated. That is, at some point after we've upgraded the default value of the jax_threefry_partitionable setting, we will deprecate the flag entirely.

Who is affected?

JAX supports several RNG schemes. It currently has three built-in modes: threefry2x32, rbg, and unsafe_rbg. These modes can be selected using the impl argument to jax.random.PRNGKey and jax.random.key, or with the configuration flag jax_default_prng_impl.
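If you are not sure which scheme your code uses, the sketch below builds a key with each built-in implementation through the impl argument and inspects it. It is illustrative only. The shapes dictionary is a hypothetical name, and jax.random.key_data is used simply to expose the raw uint32 words behind each typed key.

```python
import jax

shapes = {}
for impl in ("threefry2x32", "rbg", "unsafe_rbg"):
    # Select the RNG scheme explicitly via the `impl` argument.
    key = jax.random.key(0, impl=impl)
    # key_data returns the raw uint32 array backing the typed key.
    shapes[impl] = jax.random.key_data(key).shape
    print(impl, shapes[impl], jax.random.uniform(key, (3,)))
```

A threefry2x32 key is backed by two 32-bit words, while both RBG variants use four.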

You are affected if you use either threefry2x32 or rbg. Specifically:

Opting in early

To try the upgrade now, you can set the configuration flag jax_threefry_partitionable to True in your code explicitly. This can be done with the environment variable JAX_THREEFRY_PARTITIONABLE=True, the command-line flag --jax_threefry_partitionable=True, or programmatically, using jax.config.update or the jax.threefry_partitionable context manager. For example:

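import jax

# Enable partitionable Threefry for the rest of the program:
jax.config.update("jax_threefry_partitionable", True)

# OR enable it only within a block:
with jax.threefry_partitionable(True):
  ...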


There is a typo in the "What is happening?" section: change jax_threefry_patitionable to jax_threefry_partitionable if you want.

Thanks for the heads up!


This comment was marked as off-topic.


This is unrelated to the topic here; please open a new discussion for this question.


v0.5.0 update: partitionable by default

As of JAX v0.5.0, partitionable Threefry is enabled by default, i.e. the default value of the jax_threefry_partitionable setting is True. If this breaks your code or project, what can you do?

One option, for now, is to ignore the issue and revert the upgrade by setting jax_threefry_partitionable to False.
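The revert can be done programmatically, as in the sketch below. The environment variable JAX_THREEFRY_PARTITIONABLE=False also works. Treat this as a stopgap under a stated assumption: because the flag is slated for deprecation, newer JAX releases may eventually stop accepting it.

```python
import jax

# Revert to non-partitionable Threefry for the whole program. Do this
# before generating any random values so all draws use the same scheme.
jax.config.update("jax_threefry_partitionable", False)

key = jax.random.key(42)
value = jax.random.normal(key)
print(jax.config.jax_threefry_partitionable, value)
```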

To investigate the issue further, consider where your code might possibly be sensitive to random values, rather than only to their (pseudorandom) distribution.

For example, consider this unit test:

import jax

def test_normal_negative():
  key = jax.random.key(42)
  value = jax.random.normal(key)
  assert value < 0.

Because pseudorandom values are not guaranteed stable, value-sensitive tests are fragile. For example, they can fail for different values of the random seed, or when JAX's PRNG changes.

The unit test above might pass at v0.5.0, but it will fail if we change the seed argument to jax.random.key from 42 to 727, because then value becomes positive. It may also fail if JAX's PRNG algorithm changes again, even if the distribution stays the same.

One way to identify value-sensitive tests heuristically is to run them under several different random seeds, which in turn changes the values they generate. If several such trials fail, the test may be too value-sensitive. To run a test under several seeds, you can use the environment variable JAX_RANDOM_SEED_OFFSET. Its value is an integer (zero by default) that JAX adds to every seed in the program (i.e. any argument to jax.random.key or jax.random.PRNGKey). For example:
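The same heuristic can be approximated in a single process by looping over seeds explicitly, rather than setting JAX_RANDOM_SEED_OFFSET across separate runs. This is a minimal sketch. normal_is_negative and failure_rate are hypothetical helpers wrapping the check from test_normal_negative, and the count of 20 seeds is an arbitrary choice.

```python
import jax

def normal_is_negative(seed):
    # Hypothetical seed-parameterized version of test_normal_negative's check.
    value = jax.random.normal(jax.random.key(seed))
    return bool(value < 0.0)

def failure_rate(check, seeds):
    # Fraction of seeds for which the check does not hold.
    seeds = list(seeds)
    failures = [s for s in seeds if not check(s)]
    return len(failures) / len(seeds)

rate = failure_rate(normal_is_negative, range(20))
print(f"failed for {rate:.0%} of seeds")
```

A rate near 50% is expected here, since a standard normal draw is negative about half the time. That signals the original assertion tests one particular value rather than a property of the distribution.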

$ JAX_RANDOM_SEED_OFFSET=0 python -c 'import jax.random as jr; print(jr.normal(jr.key(7)))'
0.45123515
$ JAX_RANDOM_SEED_OFFSET=3 python -c 'import jax.random as jr; print(jr.normal(jr.key(4)))'
0.45123515

Two common flavors of value-sensitive tests are "reference tests", which compare a randomized outcome against a fixed reference value, and tests with overly tight numerical tolerances. Well-behaved randomized tests may also be affected if they have some chance of false failure. For instance, if a test suite runs a thousand independent tests, each with a false failure probability of 5%, then we might expect roughly 50 of them to fail when the random bits that they consume change.
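The "roughly 50" figure is the mean of a binomial distribution. The arithmetic below treats each test as an independent trial that falsely fails with probability 5%. It also shows how much that count typically varies, and why at least one failure is all but certain.

```python
import math

n_tests = 1000
p_false_failure = 0.05

# Expected number of false failures: n * p.
expected = n_tests * p_false_failure
# Standard deviation of a binomial count: sqrt(n * p * (1 - p)).
spread = math.sqrt(n_tests * p_false_failure * (1 - p_false_failure))
# Probability that at least one test fails: 1 - (1 - p)^n.
p_any = 1 - (1 - p_false_failure) ** n_tests

print(expected, round(spread, 2), p_any)
```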

To address issues like this, consider (a) rewriting code to somehow be less sensitive to random values, or (b) changing its random seed (provided that the code is behaving as intended to begin with).
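As one illustration of option (a), the hypothetical test below replaces the single-value assertion in test_normal_negative with checks on properties of the distribution. This is a sketch only, and the sample size and tolerances are illustrative choices. For any given seed it still has a small chance of false failure, but that chance is far below the roughly 50% of the sign check.

```python
import jax
import jax.numpy as jnp

def test_normal_moments():
    key = jax.random.key(42)
    samples = jax.random.normal(key, (10_000,))
    # The standard error of the mean of 10,000 standard normal draws is
    # 1 / sqrt(10_000) = 0.01, so 0.05 is a 5-sigma tolerance.
    assert abs(float(jnp.mean(samples))) < 0.05
    # The sample standard deviation should be close to 1.
    assert abs(float(jnp.std(samples)) - 1.0) < 0.05

test_normal_moments()
```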
