Deprecate jax.experimental.host_callback in favor of JAX external callbacks · Issue #20385 · jax-ml/jax · GitHub

We marked the host_callback APIs as deprecated on March 21, 2024 (JAX version 0.4.26). They will be removed in October 2024. Users should instead use the new JAX external callbacks.

Quick temporary migration

As of October 1st, 2024 (JAX version 0.4.34), the jax.experimental.host_callback APIs are implemented in terms of jax.experimental.io_callback. This is controlled by the configuration variable --jax_host_callback_legacy=False (or the environment variable JAX_HOST_CALLBACK_LEGACY=False).

For a very limited time, you can obtain the old behavior by setting the configuration variable to True.
Very soon this configuration flag will be removed, so it is best to take the time to do the migration as explained below.

Real migration

It is best to study the different flavors of JAX external callbacks (jax.pure_callback, jax.experimental.io_callback, and jax.debug.callback) to pick the right one for your use case.

In general, io_callback(ordered=True) most closely matches the behavior of the existing host_callback.
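A minimal sketch of what ordered=True provides, assuming a recent JAX release; record and log are hypothetical names used only for illustration. The jitted function issues three io_callback calls, and with ordered=True the host calls run in program order, which mirrors the sequential behavior that host_callback users typically rely on.

```python
import jax
import jax.numpy as jnp
from jax.experimental import io_callback

log = []  # hypothetical host-side record of the values seen by the callback


def record(v):
    # Runs on the host; returns None to match result_shape_dtypes=None.
    log.append(int(v))


@jax.jit
def f(x):
    # Three side-effecting host calls; ordered=True keeps them in program order.
    for i in range(3):
        io_callback(record, None, x + i, ordered=True)
    return x


f(jnp.int32(0))
jax.effects_barrier()  # wait for all pending callbacks before reading `log`
print(log)
```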

In general, you should replace calls to id_tap and call with io_callback, except when you need these calls to work under vmap, grad, jvp, scan, or cond, in which case you should use jax.debug.callback. Note that jax.debug.callback does not support returning values from the callback, so it can only be used in place of host_callback.id_print, host_callback.id_tap, or host_callback.call with result_shape=None.
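As a sketch of the jax.debug.callback route, the example below logs a value from inside a function that is differentiated with jax.grad; log_value and seen are hypothetical names. The callback receives the primal (non-tangent) value, and anything it returns is discarded.

```python
import jax

seen = []  # hypothetical host-side record


def log_value(x):
    # jax.debug.callback ignores whatever this function returns.
    seen.append(float(x))


def loss(x):
    # Usable under jax.grad; the callback sees the primal value of x.
    jax.debug.callback(log_value, x)
    return x ** 2


g = jax.grad(loss)(3.0)
jax.effects_barrier()
print(g, seen)
```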

Known migration issues

Using io_callback in place of host_callback.call

For example,

from jax.experimental import host_callback
res = host_callback.call(fn, arg, result_shape=result_shape_dtypes)

should be replaced with

from jax.experimental import io_callback
res = io_callback(fn, result_shape_dtypes, arg)
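A self-contained version of this replacement, assuming a hypothetical host function host_fn that maps a float32 array to a float32 array of the same shape. Here jax.ShapeDtypeStruct plays the role of host_callback.call's result_shape, and the callback's return value must match it in shape and dtype.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import io_callback


def host_fn(x):
    # Hypothetical host function: runs in Python and computes with NumPy.
    return np.sin(np.asarray(x))


@jax.jit
def f(x):
    # Describes the callback's result, like result_shape in host_callback.call.
    result_shape_dtypes = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return io_callback(host_fn, result_shape_dtypes, x)


out = f(jnp.arange(3.0))
print(out)
```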

Using io_callback in place of host_callback.id_tap

Similarly, id_tap can be replaced with an io_callback with result_shape_dtypes=None:

callback = lambda x, transforms: do_something(x)
res = host_callback.id_tap(callback, x_in)

should be replaced with

callback = lambda x: do_something(x)
io_callback(callback, None, x_in)
res = x_in  # Simulates the return value of `id_tap`

Note that the callback no longer takes the transforms argument; the new callbacks do not support it.
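A runnable sketch of the id_tap replacement inside a jitted function. The do_something function and the tapped list are hypothetical stand-ins for user code, and ordered=True is added here (it is not in the snippet above) to stay close to host_callback's sequential behavior.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import io_callback

tapped = []  # hypothetical host-side log filled by do_something


def do_something(x):
    # Hypothetical tap logic; must return None since result_shape_dtypes is None.
    tapped.append(np.asarray(x).copy())


@jax.jit
def f(x_in):
    callback = lambda x: do_something(x)  # no `transforms` argument any more
    io_callback(callback, None, x_in, ordered=True)
    res = x_in  # id_tap returned its argument; we do that explicitly
    return res * 2


out = f(jnp.arange(3.0))
jax.effects_barrier()  # make sure the callback has run before reading `tapped`
```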

If you use the result parameter of id_tap, you can replace:

results = id_tap(
    lambda arg, transform: done_callback(arg),
    arg,
    result=the_results,
)

with

io_callback(
    lambda arg: done_callback(arg),
    None,
    arg
)
results = the_results

Using jax.debug.print in place of host_callback.id_print

For id_print, you should instead use jax.debug.print. For example,

id_print(x) can be replaced by jax.debug.print('{}', x).

If you use the name parameter, you can replace
id_print(x, name="my_x") with jax.debug.print('name: my_x\n{}', x).

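If you use the output_stream parameter, you can replace
id_print(x, output_stream=s) with jax.experimental.io_callback(lambda x: print(x, file=s, end=''), None, x). The callback must return None when result_shape_dtypes is None, so it uses print rather than s.write, which returns the number of characters written.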
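A combined sketch of both id_print replacements under jax.jit, assuming an io.StringIO object s as a hypothetical stand-in for the original output_stream. jax.debug.print writes to stdout, while the io_callback writes str(x) to the stream.

```python
import io

import jax
import jax.numpy as jnp
from jax.experimental import io_callback

s = io.StringIO()  # hypothetical stand-in for id_print's output_stream


@jax.jit
def f(x):
    # Replaces id_print(x, name="my_x"); output goes to stdout.
    jax.debug.print("name: my_x\n{}", x)
    # Replaces id_print(x, output_stream=s); print() returns None, as
    # io_callback requires when result_shape_dtypes is None.
    io_callback(lambda v: print(v, file=s, end=""), None, x, ordered=True)
    return x + 1


f(jnp.arange(3))
jax.effects_barrier()  # flush pending callbacks before inspecting the stream
print(repr(s.getvalue()))
```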

Using jax.effects_barrier in place of host_callback.barrier_wait

Finally, host_callback.barrier_wait should be replaced with jax.effects_barrier().

Callbacks and jax.vmap

Under vmap, the new callbacks behave differently from host_callback. The latter makes a single call with the batched (vector) value, while the new callbacks behave like a loop and make a separate call for each element of the batch. For example, the code

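import jax
import numpy as np
from jax.experimental.host_callback import id_tap

def host_fn(x, transforms=None):
  # id_tap passes (arg, transforms); the default lets
  # jax.debug.callback call host_fn with a single argument.
  print(x)

def fn(x):
  res = 2 * x
  id_tap(host_fn, res)
  return res

jax.vmap(fn)(np.arange(3))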

makes one call to host_fn with the vector [0, 2, 4], and if we replace id_tap(host_fn, res) with jax.debug.callback(host_fn, res) we will get 3 separate calls with 0, 2, and 4, respectively.
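The new behavior can be observed directly with jax.debug.callback by recording each host call. This is a sketch; calls is a hypothetical list used only to count invocations, and the test below checks the values without assuming a particular call order.

```python
import jax
import numpy as np

calls = []  # hypothetical record of each host call


def host_fn(x):
    # Under vmap, x is one (scalar) element of the batch per call.
    calls.append(np.asarray(x).tolist())
    print(x)


def fn(x):
    res = 2 * x
    jax.debug.callback(host_fn, res)
    return res


out = jax.vmap(fn)(np.arange(3))
jax.effects_barrier()
print(calls)
```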

