We have marked the host_callback APIs deprecated on March 21, 2024 (JAX version 0.4.26). They will be removed in October 2024. Users should instead use the new JAX external callbacks.
Quick temporary migration
As of October 1st, 2024 (JAX version 0.4.34), if you use the jax.experimental.host_callback APIs, they will be implemented in terms of jax.experimental.io_callback. This is controlled by the configuration variable --jax_host_callback_legacy=False (or the environment variable JAX_HOST_CALLBACK_LEGACY=False).
For a very limited time, you can obtain the old behavior by setting the configuration variable to True. This configuration flag will be removed very soon, so it is best to take the time to do the migration as explained below.
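As a sketch of how the environment variable can be set from Python, assuming this runs before jax is first imported (flag defaults are read from the environment at import time). On JAX versions where this flag no longer exists, we assume the variable is simply ignored.

```python
import os

# Set this before the first `import jax`: JAX reads flag defaults from the
# environment when it is imported. "False" selects the io_callback-based
# implementation of host_callback; "True" temporarily restores the legacy one.
os.environ["JAX_HOST_CALLBACK_LEGACY"] = "False"

import jax  # noqa: E402  (deliberately imported after setting the variable)

print(jax.__version__)
```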
It is best to study the different flavors of JAX external callbacks to pick the right one for your use case.
In general, io_callback(ordered=True) offers the support closest to that of the existing host_callback.
In general, you should replace calls to id_tap and call with io_callback, except when you need these calls to work under vmap, grad, jvp, scan, or cond, in which case you should use jax.debug.callback. Note that jax.debug.callback does not support returning values from the callback, so it can be used only in lieu of host_callback.id_print or host_callback.id_tap, or in lieu of host_callback.call when result_shape=None.
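To illustrate the distinction, here is a minimal sketch of jax.debug.callback used for a side effect inside a function differentiated with jax.grad, a setting where io_callback does not support differentiation. The callback log_value and the list seen are hypothetical names used only to make the effect observable; whatever the callback returns is discarded.

```python
import jax

seen = []  # host-side record of the values observed by the callback

def log_value(x):
    # Side effect only: jax.debug.callback ignores anything returned here.
    seen.append(float(x))

def f(x):
    jax.debug.callback(log_value, x)
    return x ** 2

g = jax.grad(f)(3.0)
jax.effects_barrier()  # make sure the pending callback has run
print(g, seen)
```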
Note the following differences and limitations of the new callbacks:
- The tap_with_device option for id_tap and the call_with_device option for call are not supported. You must change the callbacks so that they do not need the device argument. If you use JAX_HOST_CALLBACK_LEGACY=False, you will get an error.
- The transforms argument to the callback called from id_tap is not supported. If you use JAX_HOST_CALLBACK_LEGACY=False, the callback will be passed the empty tuple (no transforms).
- The host_callback APIs passed np.ndarray objects to the callback, while the new JAX external callbacks pass jax.Array. This is usually fine, except that it may lead to a deadlock if the code making the call is already running on CPU: the callback will try to invoke JAX functions on its arguments and will find the device busy. The solution is to add input = np.array(input) at the start of your callback function.
- If you use io_callback(ordered=True) with jax.grad, you will get an error that io_callback does not support JVP. Try to use jax.debug.callback instead.
- If you use io_callback(ordered=True) with jax.pmap, you will get an error that ordered effects are not supported under jax.pmap. Try to use ordered=False instead.
Using io_callback in place of host_callback.call
For example,
from jax.experimental import host_callback
res = host_callback.call(fn, arg, result_shape=result_shape_dtypes)
should be replaced with
from jax.experimental import io_callback
res = io_callback(fn, result_shape_dtypes, arg)
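A self-contained version of this replacement might look like the sketch below. The host function host_fn and its cumulative-sum computation are hypothetical; result_shape_dtypes is built with jax.ShapeDtypeStruct, and the callback converts its jax.Array argument to NumPy first, which also sidesteps the CPU deadlock mentioned above.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import io_callback

def host_fn(x):
    # Convert the incoming jax.Array to NumPy before doing any work on the host.
    x = np.asarray(x)
    # The result must match result_shape_dtypes in shape and dtype.
    return np.cumsum(x).astype(np.float32)

@jax.jit
def f(arg):
    result_shape_dtypes = jax.ShapeDtypeStruct(arg.shape, jnp.float32)
    return io_callback(host_fn, result_shape_dtypes, arg)

res = f(jnp.arange(4.0))
print(res)  # cumulative sums 0, 1, 3, 6
```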
Using io_callback in place of host_callback.id_tap
Similarly, id_tap can be replaced with an io_callback with result_shape_dtypes=None:
callback = lambda x, transforms: do_something(x)
res = host_callback.id_tap(callback, x_in)
should be replaced with
callback = lambda x: do_something(x)
io_callback(callback, None, x_in)
res = x_in # Simulates the return value of `id_tap`
Note that we have removed the transforms callback argument (this is not supported by the new callbacks).
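Putting the id_tap replacement together inside a jitted function, a minimal sketch could be as follows; do_something is a hypothetical host-side action, and the list seen exists only so the effect can be observed.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import io_callback

seen = []

def do_something(x):
    # Hypothetical host-side action; note there is no `transforms` parameter.
    seen.append(np.asarray(x).tolist())

@jax.jit
def f(x_in):
    io_callback(do_something, None, x_in)  # result_shape_dtypes=None: no outputs
    res = x_in  # plays the role of id_tap's return value
    return res * 2

out = f(jnp.array([1.0, 2.0]))
jax.effects_barrier()  # wait for the callback before inspecting `seen`
```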
If you use the result parameter with id_tap, then you can replace:
results = id_tap(
    lambda arg, transform: done_callback(arg),
    arg,
    result=the_results,
)
with
io_callback(
    lambda arg: done_callback(arg),
    None,
    arg,
)
results = the_results
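If many call sites use the result parameter, a small wrapper can keep them readable. The helper id_tap_compat below is hypothetical and not part of JAX, as are train_step and the arbitrary update used for the_results. Because io_callback is called here without ordered=True, the relative execution order of several such taps is not guaranteed.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import io_callback

def id_tap_compat(callback, arg, *, result=None):
    """Hypothetical helper mimicking host_callback.id_tap on top of io_callback.

    Calls `callback(arg)` on the host and returns `result` if given, else `arg`.
    """
    io_callback(callback, None, arg)
    return arg if result is None else result

done = []

def done_callback(arg):
    # `arg` arrives with the same pytree structure it was passed with.
    loss, step = arg
    done.append((float(np.asarray(loss)), int(np.asarray(step))))

@jax.jit
def train_step(x, step):
    loss = jnp.sum(x ** 2)
    the_results = 0.5 * x  # arbitrary value standing in for the real results
    return id_tap_compat(done_callback, (loss, step), result=the_results)

results = train_step(jnp.array([1.0, 2.0]), jnp.int32(7))
jax.effects_barrier()
```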
Using jax.debug.print in place of host_callback.id_print
For id_print, you should instead use jax.debug.print. E.g., id_print(x) can be replaced by jax.debug.print('{}', x).
If you use the name parameter, you can replace id_print(x, name="my_x") with jax.debug.print('name: my_x\n{}', x).
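If you use the output_stream parameter, you can replace id_print(x, output_stream=s) by jax.experimental.io_callback(lambda x: print(x, file=s), None, x). The callback should return None to match result_shape_dtypes=None; a bare s.write(str(x)) would return the number of characters written.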
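The three id_print replacements can be seen side by side in the sketch below. The io.StringIO object s stands in for whatever stream was previously passed as output_stream, and the helper write_to_stream is a hypothetical name; it uses print(x, file=s) so that it returns None, matching result_shape_dtypes=None.

```python
import io

import jax
import jax.numpy as jnp
from jax.experimental import io_callback

s = io.StringIO()  # stands in for the old `output_stream` argument

def write_to_stream(x):
    # print() returns None, which matches result_shape_dtypes=None.
    print(x, file=s)

@jax.jit
def f(x):
    jax.debug.print('{}', x)               # was: id_print(x)
    jax.debug.print('name: my_x\n{}', x)   # was: id_print(x, name="my_x")
    io_callback(write_to_stream, None, x)  # was: id_print(x, output_stream=s)
    return x * 2

out = f(jnp.arange(3))
jax.effects_barrier()  # flush pending callbacks before reading `s`
print(s.getvalue())
```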
Using jax.effects_barrier in place of host_callback.barrier_wait
Finally, host_callback.barrier_wait should be replaced with jax.effects_barrier().
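Because callbacks may run asynchronously with respect to the Python program, jax.effects_barrier() marks the point after which it is safe to inspect their side effects. A minimal sketch, with an illustrative calls list:

```python
import jax
import jax.numpy as jnp
import numpy as np

calls = []

@jax.jit
def f(x):
    # Unordered by default, so callbacks from different calls may complete in any order.
    jax.debug.callback(lambda v: calls.append(int(np.asarray(v))), x)
    return x + 1

for i in range(3):
    f(jnp.int32(i))

jax.effects_barrier()  # was: host_callback.barrier_wait()
print(sorted(calls))
```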
Behavior under jax.vmap
Under vmap, the new callbacks behave differently from host_callback. The latter makes a single call with a vector value, while the new callbacks behave like a loop and make a separate call for each element in the vmap. For example, the code
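import jax
import numpy as np
from jax.experimental.host_callback import id_tap

def host_fn(x, transforms=None):
    # id_tap calls host_fn(arg, transforms); jax.debug.callback passes only arg.
    print(x)

def fn(x):
    res = 2 * x
    id_tap(host_fn, res)
    return res

jax.vmap(fn)(np.arange(3))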
makes one call to host_fn with the vector [0, 2, 4], and if we replace id_tap(host_fn, res) with jax.debug.callback(host_fn, res), we will get 3 separate calls with 0, 2, and 4, respectively.
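The per-element behavior of the new callbacks can be checked directly. In this sketch, host_fn records its argument in a list (an illustrative change from printing) so that the number of calls is visible.

```python
import jax
import numpy as np

calls = []

def host_fn(x):
    # Record each value the callback receives.
    calls.append(np.asarray(x).tolist())

def fn(x):
    res = 2 * x
    jax.debug.callback(host_fn, res)
    return res

out = jax.vmap(fn)(np.arange(3))
jax.effects_barrier()
print(sorted(calls))  # one scalar per vmapped element
```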