yashkatariya@
TL;DR

JAX switched its default array implementation to the new jax.Array as of version 0.4.1. This guide explains the reasoning behind this, the impact it might have on your code, and how to (temporarily) switch back to the old behavior.
jax.Array is a unified array type that subsumes the DeviceArray, ShardedDeviceArray, and GlobalDeviceArray types in JAX. The jax.Array type helps make parallelism a core feature of JAX, simplifies and unifies JAX internals, and allows us to unify jit and pjit. If your code doesn’t mention DeviceArray vs ShardedDeviceArray vs GlobalDeviceArray, no changes are needed. But code that depends on details of these separate classes may need to be tweaked to work with the unified jax.Array.

After the migration is complete, jax.Array will be the only type of array in JAX.

This doc explains how to migrate existing codebases to jax.Array. For more information on using jax.Array and JAX parallelism APIs, see the Distributed arrays and automatic parallelization tutorial.
You can enable jax.Array by:

- setting the shell environment variable JAX_ARRAY to something true-like (e.g., 1);
- setting the boolean flag jax_array to something true-like if your code parses flags with absl;
- using this statement at the top of your main file:

```python
import jax
jax.config.update('jax_array', True)
```
The easiest way to tell if jax.Array is responsible for any problems is to disable jax.Array and see if the issues go away.
Through March 15, 2023 it will be possible to disable jax.Array by:

- setting the shell environment variable JAX_ARRAY to something falsey (e.g., 0);
- setting the boolean flag jax_array to something falsey if your code parses flags with absl;
- using this statement at the top of your main file:

```python
import jax
jax.config.update('jax_array', False)
```
Currently JAX has three array types: DeviceArray, ShardedDeviceArray, and GlobalDeviceArray. jax.Array merges these three types and cleans up JAX’s internals while adding new parallelism features.

We also introduce a new Sharding abstraction that describes how a logical Array is physically sharded out across one or more devices, such as TPUs or GPUs. The change also upgrades, simplifies, and merges the parallelism features of pjit into jit. Functions decorated with jit will be able to operate over sharded arrays without copying data onto a single device.
Features you get with jax.Array:

- C++ pjit dispatch path
- Op-by-op parallelism (even if the array is distributed across multiple devices, across multiple hosts)
- Simpler batch data parallelism with pjit/jit
- Ways to create Shardings that don’t necessarily consist of a mesh and partition spec. You can fully utilize the flexibility of OpSharding, or any other Sharding, if you want.
- and many more
Example:

```python
import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P
import numpy as np

x = jnp.arange(8)

# Let's say there are 8 devices in jax.devices()
mesh = jax.sharding.Mesh(np.array(jax.devices()).reshape(4, 2), ('x', 'y'))
sharding = jax.sharding.NamedSharding(mesh, P('x'))

sharded_x = jax.device_put(x, sharding)

# `matmul_sharded_x` and `sin_sharded_x` are sharded. `jit` is able to operate over a
# sharded array without copying data to a single device.
matmul_sharded_x = sharded_x @ sharded_x.T
sin_sharded_x = jnp.sin(sharded_x)

# Even jnp.copy preserves the sharding on the output.
copy_sharded_x = jnp.copy(sharded_x)

# double_out is also sharded
double_out = jax.jit(lambda x: x * 2)(sharded_x)
```

What issues can arise when jax.Array is switched on?

New public type named jax.Array
All isinstance(..., jnp.DeviceArray) or isinstance(..., jax.xla.DeviceArray) checks, and other variants of DeviceArray, should be switched to isinstance(..., jax.Array).
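For instance, a check written against the old classes might be updated like this. This is a minimal sketch; `x` stands for any value produced by a JAX computation:

```python
import jax
import jax.numpy as jnp

x = jnp.arange(4)  # any value produced by a JAX computation

# Old-style check against one of the legacy classes:
# isinstance(x, jnp.DeviceArray)

# New-style check against the unified type:
assert isinstance(x, jax.Array)
```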
Since jax.Array can represent DA, SDA and GDA, you can differentiate between those 3 types via the following checks (see the sketch after this list):

- x.is_fully_addressable and len(x.sharding.device_set) == 1 – this means that the jax.Array is like a DA
- x.is_fully_addressable and len(x.sharding.device_set) > 1 – this means that the jax.Array is like an SDA
- not x.is_fully_addressable – this means that the jax.Array is like a GDA and spans across multiple processes
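A minimal sketch combining these three checks into one helper (the function name and return labels are illustrative, not a JAX API; `x` is assumed to be a jax.Array):

```python
import jax

def classify(x: jax.Array) -> str:
  """Rough analogue of the old DeviceArray / SDA / GDA distinction."""
  if not x.is_fully_addressable:
    # Shards live on devices across multiple processes, like a GDA.
    return "GDA-like"
  if len(x.sharding.device_set) == 1:
    # All data lives on a single device, like a DeviceArray.
    return "DeviceArray-like"
  # Fully addressable but spread over multiple local devices, like an SDA.
  return "ShardedDeviceArray-like"
```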
For ShardedDeviceArray, you can move isinstance(..., pxla.ShardedDeviceArray) to isinstance(..., jax.Array) and x.is_fully_addressable and len(x.sharding.device_set) > 1.

In general it is not possible to differentiate a ShardedDeviceArray on 1 device from any other kind of single-device Array.
GDA’s local_shards and local_data have been deprecated. Please use addressable_shards and addressable_data, which are compatible with both jax.Array and GDA.
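A small, hedged sketch of the replacement (assuming `arr` is any jax.Array; the printed fields are just for illustration):

```python
import jax.numpy as jnp

arr = jnp.arange(8)

# addressable_shards works for both single- and multi-process arrays.
for shard in arr.addressable_shards:
  # Each shard records the device it lives on, its index into the global
  # array, and the per-device data.
  print(shard.device, shard.index, shard.data.shape)

# addressable_data(i) returns the data of the i-th addressable shard.
first_local = arr.addressable_data(0)
```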
All JAX functions will output jax.Array when the jax_array flag is True. If you were using the GlobalDeviceArray.from_callback, make_sharded_device_array, or make_device_array functions to explicitly create the respective JAX data types, you will need to switch them to use jax.make_array_from_callback() or jax.make_array_from_single_device_arrays().
For GDA:

GlobalDeviceArray.from_callback(shape, mesh, pspec, callback) can become jax.make_array_from_callback(shape, jax.sharding.NamedSharding(mesh, pspec), callback) in a 1:1 switch (see the sketch below).

If you were using the raw GDA constructor to create GDAs, then do this: GlobalDeviceArray(shape, mesh, pspec, buffers) can become jax.make_array_from_single_device_arrays(shape, jax.sharding.NamedSharding(mesh, pspec), buffers).
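A hedged before/after sketch of the from_callback switch. The mesh, partition spec, and callback below are illustrative placeholders, and the example assumes the number of devices divides the leading dimension:

```python
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices())
mesh = Mesh(devices, ('x',))
pspec = P('x')
global_shape = (8, 2)  # assumes len(devices) divides 8

global_data = np.arange(np.prod(global_shape)).reshape(global_shape)

def callback(index):
  # `index` is the slice of the global array that belongs to one device.
  return global_data[index]

# Old: GlobalDeviceArray.from_callback(global_shape, mesh, pspec, callback)
# New:
arr = jax.make_array_from_callback(
    global_shape, NamedSharding(mesh, pspec), callback)
```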
For SDA:

make_sharded_device_array(aval, sharding_spec, device_buffers, indices) can become jax.make_array_from_single_device_arrays(shape, sharding, device_buffers).

The right sharding depends on why you were creating the SDAs (see the sketch after this list):

- If it was created to give as an input to pmap, then the sharding can be jax.sharding.PmapSharding(devices, sharding_spec).
- If it was created to give as an input to pjit, then the sharding can be jax.sharding.NamedSharding(mesh, pspec).
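A rough sketch of the SDA replacement under these two cases. The data, shapes, and single-process setup here are illustrative, and `sharding_spec` stands for whatever spec you previously passed to make_sharded_device_array:

```python
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = jax.local_devices()

# One single-device buffer per local device (illustrative data).
device_buffers = [jax.device_put(np.arange(4), d) for d in devices]
global_shape = (4 * len(devices),)

# If the SDA was destined for pjit, a NamedSharding is the natural choice:
mesh = Mesh(np.array(devices), ('x',))
arr = jax.make_array_from_single_device_arrays(
    global_shape, NamedSharding(mesh, P('x')), device_buffers)

# If it was destined for pmap, build a PmapSharding from the old spec instead:
# sharding = jax.sharding.PmapSharding(np.array(devices), sharding_spec)
# arr = jax.make_array_from_single_device_arrays(
#     global_shape, sharding, device_buffers)
```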
If you are exclusively using GDA arguments to pjit, you can skip this section! 🎉
With jax.Array enabled, all inputs to pjit must be globally shaped. This is a breaking change from the previous behavior where pjit would concatenate process-local arguments into a global value; this concatenation no longer occurs.

Why are we making this breaking change? Each array now says explicitly how its local shards fit into a global whole, rather than leaving it implicit. The more explicit representation also unlocks additional flexibility, for example the use of non-contiguous meshes with pjit, which can improve efficiency on some TPU models.
Running a multi-process pjit computation and passing host-local inputs when jax.Array is enabled can lead to an error similar to this:

Example:

Mesh = {'x': 2, 'y': 2, 'z': 2}, host-local input shape == (4,), and pspec = P(('x', 'y', 'z'))

Since pjit doesn’t lift host-local shapes to global shapes with jax.Array, you get the following error:

Note: You will only see this error if your host-local shape is smaller than the shape of the mesh.

```
ValueError: One of pjit arguments was given the sharding of NamedSharding(mesh={'x': 2, 'y': 2, 'chips': 2}, partition_spec=PartitionSpec(('x', 'y', 'chips'),)), which implies that the global size of its dimension 0 should be divisible by 8, but it is equal to 4
```

The error makes sense: the mesh has 2 × 2 × 2 = 8 devices and the partition spec shards dimension 0 across all three mesh axes, so dimension 0 must be divisible by 8, but the host-local value only has size 4 on that dimension.
How can you migrate if you still pass host-local inputs to pjit? We are providing transitional APIs to help you migrate:

Note: You don’t need these utilities if you run your pjitted computation on a single process.

```python
from jax.experimental import multihost_utils

global_inps = multihost_utils.host_local_array_to_global_array(
    local_inputs, mesh, in_pspecs)

global_outputs = pjit(f, in_shardings=in_pspecs,
                      out_shardings=out_pspecs)(global_inps)

local_outs = multihost_utils.global_array_to_host_local_array(
    global_outputs, mesh, out_pspecs)
```
host_local_array_to_global_array is a type cast: it looks at a value with only local shards and changes its local shape to the shape that pjit would previously have assumed if that value had been passed in before the change. For example, if each of four processes passes a host-local array of shape (4,) with pspec P('x') over a mesh spanning all four processes, the resulting global jax.Array has shape (16,), since the per-process shards are concatenated along dimension 0.

Passing in fully replicated inputs, i.e. inputs with the same shape on each process and P(None) as in_axis_resources, is still supported. In this case you do not have to use host_local_array_to_global_array because the shape is already global.
```python
key = jax.random.PRNGKey(1)

# As you can see, using host_local_array_to_global_array is not required since
# in_axis_resources says that the input is fully replicated via P(None)
pjit(f, in_shardings=None, out_shardings=None)(key)

# Mixing inputs
global_inp = multihost_utils.host_local_array_to_global_array(
    local_inp, mesh, P('data'))
global_out = pjit(f, in_shardings=(P(None), P('data')),
                  out_shardings=...)(key, global_inp)
```

FROM_GDA and jax.Array
If you were using FROM_GDA in the in_axis_resources argument to pjit, then with jax.Array there is no need to pass anything to in_axis_resources, as jax.Array follows computation-follows-sharding semantics.

For example:

pjit(f, in_shardings=FROM_GDA, out_shardings=...) can be replaced by pjit(f, out_shardings=...)
If you have PartitionSpecs mixed in with FROM_GDA for inputs like numpy arrays, etc., then use host_local_array_to_global_array to convert them to jax.Array.

For example, if you had this:

```python
pjitted_f = pjit(
    f,
    in_shardings=(FROM_GDA, P('x'), FROM_GDA, P(None)),
    out_shardings=...)
pjitted_f(gda1, np_array1, gda2, np_array2)
```

then you can replace it with:

```python
pjitted_f = pjit(f, out_shardings=...)
array2, array4 = multihost_utils.host_local_array_to_global_array(
    (np_array1, np_array2), mesh, (P('x'), P(None)))
# gda1 and gda2 are already global, so they can be passed in directly.
pjitted_f(gda1, array2, gda2, array4)
```

live_buffers replaced with live_arrays
The live_buffers attribute on jax Device has been deprecated. Please use jax.live_arrays() instead, which is compatible with jax.Array.
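A minimal sketch of the replacement (the old attribute access is shown only as a comment):

```python
import jax

# Old (deprecated): per-device buffers via the Device attribute
# bufs = jax.devices()[0].live_buffers

# New: the live jax.Arrays in the current process
arrs = jax.live_arrays()
print(f"{len(arrs)} live arrays")
```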
If you are passing host-local inputs to pjit in a multi-process environment, then please use multihost_utils.host_local_array_to_global_array to convert the batch to a global jax.Array and then pass that to pjit.

The most common example of such a host-local input is a batch of input data, but this works for any host-local input (not just a batch of input data).
```python
from jax.experimental import multihost_utils

batch = multihost_utils.host_local_array_to_global_array(
    batch, mesh, batch_partition_spec)
```
See the pjit section above for more details about this change and more examples.
RecursionError: Recursively calling jit

This happens when some part of your code has jax.Array disabled and you then enable it only for some other part. For example, if you use some third_party code which has jax.Array disabled, get a DeviceArray from that library, enable jax.Array in your own library, and then pass that DeviceArray to JAX functions, it will lead to a RecursionError.

This error should go away when jax.Array is enabled by default, so that all libraries return jax.Array unless they explicitly disable it.