Nov 10, 2023 · 3 comments · 2 replies
Partitionable Threefry RNG upgrade

Important (EDIT, January 2025): this upgrade is complete as of JAX v0.5.0; see the comment posted below for more details.

You can stop reading now if:
- you already set `jax_threefry_partitionable` to `True`; or
- you use the `unsafe_rbg` JAX RNG algorithm.

JAX's default RNG algorithm ("Threefry") is changing under the hood to make random number generation efficiently auto-parallelizable ("partitionable"). This makes random number generation faster with multiple devices.
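As a rough illustration of what "partitionable" is for, here is a minimal sketch that asks `jax.jit` to return a random array sharded across all available devices. The mesh axis name `"x"`, the helper name `sample`, and the array size are arbitrary choices for this example. On a single device the sharding is trivial, but the same code runs unchanged on several.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Opt in explicitly (this is already the default from JAX v0.5.0 on).
jax.config.update("jax_threefry_partitionable", True)

# A 1-D mesh over every available device; the axis name "x" is arbitrary.
mesh = Mesh(np.array(jax.devices()), ("x",))
sharding = NamedSharding(mesh, P("x"))

n = 1024 * len(jax.devices())  # divisible by the number of devices

def sample(key):
    # Hypothetical helper: one standard-normal draw per element.
    return jax.random.normal(key, (n,))

# Ask jit to lay the output out across devices. With partitionable Threefry,
# the compiler can split the bit generation along with the output.
sharded_sample = jax.jit(sample, out_shardings=sharding)
x = sharded_sample(jax.random.key(0))
print(x.shape, len(x.addressable_shards))
```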
The current behavior corresponds to setting the config field `jax_threefry_partitionable` to `False`, its current default value. The new behavior corresponds to setting `jax_threefry_partitionable` to `True`, its future default value.
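To check which behavior a program is using, a minimal sketch reads the flag through attribute access on `jax.config` and flips it with `jax.config.update`. The variable names are illustrative only.

```python
import jax

# Attribute-style read of the flag's current value in this JAX version.
print("partitionable:", jax.config.jax_threefry_partitionable)

# Select the new behavior for the whole program...
jax.config.update("jax_threefry_partitionable", True)
new_behavior = jax.config.jax_threefry_partitionable

# ...or the old (non-partitionable) behavior.
jax.config.update("jax_threefry_partitionable", False)
old_behavior = jax.config.jax_threefry_partitionable
```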
Some time soon, we will change the default value of the configuration flag `jax_threefry_partitionable` from `False` to `True`. This will cause a one-time change in the random values generated from a given RNG key. See code below.
Want to ensure this goes smoothly? Try flipping `jax_threefry_partitionable` to `True` today to detect any issues in your code ahead of the upgrade. See code below.
What do we mean by a one-time change to random values? Here is today's default behavior:
```
$ JAX_THREEFRY_PARTITIONABLE=False python -c '
> import jax
> print(jax.random.randint(jax.random.key(72), (), 0, 10))'
3
```
And here is what will happen when the default setting changes soon:
```
$ JAX_THREEFRY_PARTITIONABLE=True python -c '
> import jax
> print(jax.random.randint(jax.random.key(72), (), 0, 10))'
7
```
Same key, different generated value when `JAX_THREEFRY_PARTITIONABLE=False` versus `JAX_THREEFRY_PARTITIONABLE=True`.
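The same comparison can be made inside one Python process. This sketch assumes the `jax.threefry_partitionable` context manager is available in your JAX version; the exact printed values depend on that version, so none are claimed here.

```python
import jax

key = jax.random.key(72)

# Same key, sampled once under each setting of the flag.
with jax.threefry_partitionable(False):
    old = int(jax.random.randint(key, (), 0, 10))
with jax.threefry_partitionable(True):
    new = int(jax.random.randint(key, (), 0, 10))

print(old, new)  # the two settings may produce different values
```

Either way, each setting stays deterministic: rerunning a given line with the same key and flag value reproduces its result.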
JAX's RNG will remain deterministic. We try to rarely change the output of JAX's pseudorandom functions. That said, our API policy promises stability in distribution, not in value. This particular change is broad-based, and is the first of its kind in a long time, so we're drawing extra attention to it.
This change will break code that depends on specific RNG keys generating specific RNG values. Common examples include reference tests, high-variance randomized tests, or a machine learning experiment that depends on random values (e.g. model initialization) that you want to reproduce exactly.
Non-partitionable Threefry will later be deprecated. That is, at some point after we've upgraded the default value of the `jax_threefry_partitionable` setting, we will deprecate the flag entirely.
JAX supports several RNG schemes. Its current three built-in modes are called `threefry2x32`, `rbg` and `unsafe_rbg`. These modes can be set using the `impl` argument to `jax.random.PRNGKey` and `jax.random.key`, or with the configuration flag `jax_default_prng_impl`.
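A minimal sketch of both ways to select a mode: per key with `impl=`, and globally with `jax_default_prng_impl` (restored afterwards so later code is unaffected). Seed 0 is arbitrary, and the printed shapes and values are not claimed here.

```python
import jax
import numpy as np

# Select an implementation per key via the `impl` argument.
keys = {impl: jax.random.key(0, impl=impl)
        for impl in ("threefry2x32", "rbg", "unsafe_rbg")}
for impl, key in keys.items():
    # key_data exposes the raw uint32 words backing each key.
    key_shape = np.shape(jax.random.key_data(key))
    print(impl, key_shape, float(jax.random.normal(key)))

# Or change the default used when `impl` is omitted, then restore it.
jax.config.update("jax_default_prng_impl", "rbg")
default_key = jax.random.key(0)
jax.config.update("jax_default_prng_impl", "threefry2x32")
```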
You are affected if you use either `threefry2x32` or `rbg`. Specifically:
- Under `threefry2x32`, the random values generated from a given key will change.
- Under `rbg` mode, key derivation will change. That is, the keys derived from a specific key using `jax.random.split` or `jax.random.fold_in` will differ from before. In turn, random values generated from such derived keys will change.

To try the upgrade now, you can explicitly set the configuration flag `jax_threefry_partitionable` to `True` in your code. This can be done with the environment variable `JAX_THREEFRY_PARTITIONABLE=True`, the command-line flag `--jax_threefry_partitionable=True`, or programmatically, using `jax.config.update` or the `jax.threefry_partitionable` context manager. For example:
```python
import jax

jax.config.update("jax_threefry_partitionable", True)
# OR
with jax.threefry_partitionable(True):
    ...
```
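To see whether key derivation in your own setup is affected, one sketch derives keys with `split` and `fold_in` under both flag settings and compares the raw key data for each built-in mode. The helper `derived_key_data`, the seed 0, and the fold-in value 1 are hypothetical choices for illustration.

```python
import jax
import numpy as np

def derived_key_data(partitionable, impl):
    """Hypothetical helper: raw data of keys derived under one flag setting."""
    with jax.threefry_partitionable(partitionable):
        key = jax.random.key(0, impl=impl)
        k1, k2 = jax.random.split(key)
        k3 = jax.random.fold_in(key, 1)
        return [np.asarray(jax.random.key_data(k)) for k in (k1, k2, k3)]

for impl in ("threefry2x32", "rbg", "unsafe_rbg"):
    before = derived_key_data(False, impl)
    after = derived_key_data(True, impl)
    changed = any(not np.array_equal(a, b) for a, b in zip(before, after))
    print(f"{impl}: derived keys changed = {changed}")
```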
There is a typo in what is happening: change `jax_threefry_patitionable` to `jax_threefry_partitionable` if you want.
Thanks for the heads up!
This comment was marked as off-topic.
This is unrelated to the topic here - please open a new discussion for this question.
v0.5.0 update: partitionable by default

As of JAX v0.5.0, partitionable Threefry is enabled by default, i.e. the default value of the `jax_threefry_partitionable` setting is `True`. If this breaks your code or project, what can you do?
One option, for now, is to ignore the issue and revert the upgrade by setting `jax_threefry_partitionable` to `False`.
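If only part of a project depends on the old values, for instance reproducing a model initialization recorded before v0.5.0, the revert can be scoped with the context manager instead of changing the global default, for as long as the flag exists. In this sketch, `legacy_init` and its shapes are hypothetical stand-ins for your own code.

```python
import jax

def legacy_init(key):
    """Hypothetical initializer whose outputs were recorded before v0.5.0."""
    w_key, b_key = jax.random.split(key)
    return {"w": jax.random.normal(w_key, (3, 2)),
            "b": jax.random.normal(b_key, (2,))}

# Only this call uses non-partitionable Threefry; code outside the
# `with` block keeps whatever the global setting is.
with jax.threefry_partitionable(False):
    params = legacy_init(jax.random.key(0))
```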
To investigate the issue further, consider where your code might possibly be sensitive to random values, rather than only to their (pseudorandom) distribution.
For example, consider this unit test:
```python
import jax

def test_normal_negative():
  key = jax.random.key(42)
  value = jax.random.normal(key)
  assert value < 0.
```
Because pseudorandom values are not guaranteed stable, value-sensitive tests are fragile. For example, they can fail for different values of the random seed, or when JAX's PRNG changes.
The unit test above might pass at v0.5.0, but it will fail if we change the seed argument to `jax.random.key` from 42 to 727, because then `value` becomes positive. It may also fail if we change JAX's PRNG algorithm again, even if the distribution stays the same.
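One way to quantify that fragility is to evaluate the test's condition over many seeds. For a standard normal, `value < 0` holds with probability one half, so a sketch like the following should report a pass rate near 50%, i.e. the test's outcome is essentially a coin flip decided by the seed. The helper `normal_is_negative` and the range of 100 seeds are illustrative choices.

```python
import jax
import numpy as np

def normal_is_negative(seed):
    """Hypothetical helper: the check from test_normal_negative, per seed."""
    return bool(jax.random.normal(jax.random.key(seed)) < 0.)

seeds = range(100)
pass_rate = np.mean([normal_is_negative(s) for s in seeds])
print(f"passes for {pass_rate:.0%} of seeds")
```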
One way to identify value-sensitive tests heuristically is to run them under several different random seeds, which in turn changes the values they generate. If several such trials fail, the test may be too value-sensitive. To run a test under several seeds, you can use the environment variable `JAX_RANDOM_SEED_OFFSET`. Its value is an integer (zero by default) that JAX adds to every seed in the program (i.e. any argument to `jax.random.key` or `jax.random.PRNGKey`). For example:
```
$ JAX_RANDOM_SEED_OFFSET=0 python -c 'import jax.random as jr; print(jr.normal(jr.key(7)))'
0.45123515
$ JAX_RANDOM_SEED_OFFSET=3 python -c 'import jax.random as jr; print(jr.normal(jr.key(4)))'
0.45123515
```

Both commands use the same effective seed (0 + 7 = 3 + 4 = 7), so they print the same value.
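To automate such trials, a small driver can rerun the same code under several offsets. Because JAX reads its environment-variable configuration when it is imported, this sketch starts a fresh interpreter for each trial. The helper `run_with_offset` and the one-line snippet standing in for a real test are hypothetical.

```python
import os
import subprocess
import sys

def run_with_offset(offset, seed=7):
    """Hypothetical driver: sample once in a fresh interpreter with an offset."""
    snippet = ("import jax.random as jr; "
               f"print(float(jr.normal(jr.key({seed}))))")
    env = dict(os.environ, JAX_RANDOM_SEED_OFFSET=str(offset))
    out = subprocess.run([sys.executable, "-c", snippet], env=env,
                         capture_output=True, text=True, check=True)
    return float(out.stdout)

# Same code, three different effective seeds.
for offset in range(3):
    print(offset, run_with_offset(offset))
```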
Two common flavors of value-sensitive tests are "reference tests," that compare a randomized outcome against a fixed reference value, and tests with overly tight numerical tolerances. Well-behaved randomized tests may also be affected if they have some chance of false failure. For instance, if a test suite runs a thousand independent tests, each with a false failure probability of 5%, then we might expect roughly 50 tests to fail when the random bits that they consume change.
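The arithmetic behind that estimate, as a sketch that assumes the thousand tests fail independently: the number of false failures follows a binomial distribution with mean n·p and standard deviation sqrt(n·p·(1 − p)).

```python
import math

n_tests, p_fail = 1000, 0.05

expected = n_tests * p_fail                          # binomial mean
stddev = math.sqrt(n_tests * p_fail * (1 - p_fail))  # binomial std. deviation
p_any = 1 - (1 - p_fail) ** n_tests                  # P(at least one failure)

print(f"expected {expected:.0f} +/- {stddev:.1f} failures; P(any) = {p_any:.6f}")
```

So roughly 50 failures, give or take about 7, and at least one failure is a near certainty.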
To address issues like this, consider (a) rewriting code to somehow be less sensitive to random values, or (b) changing its random seed (provided that the code is behaving as intended to begin with).
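As one illustration of option (a), the value-sensitive `test_normal_negative` could be replaced by a check on the distribution of many samples. In this sketch, the test name, the sample size of 10,000, and the 0.05 tolerances are assumptions chosen so that each bound is several standard errors wide (about 1/sqrt(n) for the mean and 1/sqrt(2n) for the standard deviation), which makes a change of seed or PRNG very unlikely to flip the outcome.

```python
import jax
import jax.numpy as jnp

def test_normal_moments(seed=42, n=10_000):
    """Hypothetical test: check sample moments rather than a single value."""
    x = jax.random.normal(jax.random.key(seed), (n,))
    # Mean: standard error about 0.01, so 0.05 is about 5 standard errors.
    assert abs(float(jnp.mean(x))) < 0.05
    # Std. deviation: standard error about 0.007, so 0.05 is about 7.
    assert abs(float(jnp.std(x)) - 1.0) < 0.05

# The same test should pass across seeds, not just for one lucky value.
for seed in range(5):
    test_normal_moments(seed)
```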