Showing content from https://docs.jax.dev/en/latest/_autosummary/jax.debug.print.html below:

jax.debug.print — JAX documentation

jax.debug.print
jax.debug.print(fmt, *args, ordered=False, partitioned=False, **kwargs)

Prints values and works in staged-out JAX functions (for example, inside jax.jit).

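This function does not work with f-strings because formatting is delayed. An f-string is evaluated immediately, while the function is being traced, so it would see abstract tracers rather than concrete values. Instead of jax.debug.print(f"hello {bar}"), write jax.debug.print("hello {bar}", bar=bar).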
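A minimal sketch of deferred formatting, assuming a recent JAX release where jax.effects_barrier() is available. The function f and its input are illustrative only. Inside jax.jit, the placeholders are filled in only when the compiled function runs, so they show concrete numbers.

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):  # illustrative function
    y = x * 2
    # Deferred formatting: {x} and {y} are filled in at runtime with concrete values.
    jax.debug.print("x = {x}, y = {y}", x=x, y=y)
    # By contrast, an f-string such as f"y = {y}" would be formatted here,
    # during tracing, and would show a tracer object instead of a number.
    return y

result = f(jnp.float32(3.0))
jax.effects_barrier()  # wait until pending debug callbacks have run
```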

This function is a thin convenience wrapper around jax.debug.callback(). The implementation is essentially:

import jax

def debug_print(fmt: str, *args, **kwargs):
  # Formatting is deferred to a host callback that receives concrete values.
  jax.debug.callback(
      lambda *args, **kwargs: print(fmt.format(*args, **kwargs)),
      *args, **kwargs)

It may be useful to call jax.debug.callback() directly instead of this convenience wrapper. For example, to get debug printing in logs, you might use jax.debug.callback() together with logging.log.
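The following sketch shows one way to send debug output to logging instead of stdout. The logger name, log_y, and ListHandler are hypothetical names chosen for illustration. The handler exists only so that the captured messages can be inspected. The %s-style message is formatted on the host after the callback has received concrete values, so the same delayed-formatting rule applies as with jax.debug.print.

```python
import logging

import jax
import jax.numpy as jnp

logger = logging.getLogger("jax_debug_example")  # hypothetical logger name
logger.setLevel(logging.INFO)

records = []

class ListHandler(logging.Handler):
    """Hypothetical handler that collects formatted log messages in a list."""
    def emit(self, record):
        records.append(record.getMessage())

logger.addHandler(ListHandler())

def log_y(y):
    # Runs on the host with a concrete value, so %s sees a number, not a tracer.
    logger.log(logging.INFO, "y = %s", y)

@jax.jit
def step(x):
    y = jnp.sin(x)
    jax.debug.callback(log_y, y)  # staged out, like jax.debug.print
    return y

out = step(jnp.float32(0.0))
jax.effects_barrier()  # ensure the callback has run before reading records
```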

Parameters:
  • fmt (str) – A format string, e.g. "hello {x}", that will be used to format input arguments, like str.format. See the Python docs on string formatting and format string syntax.

  • *args – A list of positional arguments to be formatted, as if passed to fmt.format.

  • ordered (bool) – A keyword-only argument indicating whether the staged-out computation enforces the ordering of this jax.debug.print with respect to other ordered jax.debug.print calls.

  • partitioned (bool) – If True, print only the local shards, which avoids an all-gather of the operands. If False, print the full logical operands, which requires an all-gather of the operands first.

  • **kwargs – Additional keyword arguments to be formatted, as if passed to fmt.format.
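This sketch combines the parameters above. It assumes a single-device jax.jit setting in which ordered=True is supported. The function g is illustrative. As with str.format, the {} placeholder is filled from *args and {last} from **kwargs. Passing ordered=True asks JAX to keep the two ordered prints in program order relative to each other.

```python
import jax
import jax.numpy as jnp

@jax.jit
def g(x):  # illustrative function
    # "{}" takes the positional argument; "{last}" takes the keyword argument.
    jax.debug.print("first = {}, last = {last}", x[0], last=x[-1], ordered=True)
    # A second ordered print; ordered=True keeps it after the first one.
    jax.debug.print("sum = {s}", s=x.sum(), ordered=True)
    return x * 2

out = g(jnp.arange(3.0))
jax.effects_barrier()  # flush pending ordered callbacks
```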

Return type:

None

