Prints values and works in staged-out JAX functions, such as those compiled with jax.jit.
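This function does not work with f-strings, because formatting is delayed until the runtime values are available, while an f-string is evaluated immediately, when the value may still be an abstract tracer. So instead of jax.debug.print(f"hello {bar}"), write jax.debug.print("hello {bar}", bar=bar).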
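A minimal sketch of the deferred-formatting pattern inside a jax.jit-compiled function. It assumes a JAX release that provides jax.debug.print and jax.effects_barrier; the function f and its input are illustrative, and contextlib.redirect_stdout is used only so the printed line can be inspected.

```python
import contextlib
import io

import jax
import jax.numpy as jnp


@jax.jit
def f(bar):
    # Avoid: jax.debug.print(f"hello {bar}") would build the string at trace
    # time, when `bar` is still an abstract tracer rather than a number.
    # Instead, pass the value separately so formatting happens at runtime.
    jax.debug.print("hello {bar}", bar=bar)
    return bar * 2


buf = io.StringIO()
with contextlib.redirect_stdout(buf):
    result = f(jnp.float32(3.0))
    # Debug callbacks may run asynchronously; wait for them to finish
    # before leaving the redirect block.
    jax.effects_barrier()

print(buf.getvalue(), end="")
```

The f-string variant is left as a comment because, inside jit, it would print a description of the tracer instead of the value 3.0.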
This function is a thin convenience wrapper around jax.debug.callback(). The implementation is essentially:
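    import jax

    def debug_print(fmt: str, *args, **kwargs):
      # The lambda runs on the host with concrete values, so str.format
      # sees real numbers rather than tracers.
      jax.debug.callback(
          lambda *args, **kwargs: print(fmt.format(*args, **kwargs)),
          *args, **kwargs)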
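The wrapper above can be exercised as written. The sketch below is not JAX's actual implementation: debug_print_sketch and g are hypothetical names, and unlike jax.debug.print this version performs no argument checking and does not treat ordered or partitioned specially.

```python
import contextlib
import io

import jax
import jax.numpy as jnp


def debug_print_sketch(fmt: str, *args, **kwargs):
    # Hypothetical re-creation of the wrapper: positional and keyword
    # arguments are forwarded to the host-side lambda unchanged.
    jax.debug.callback(
        lambda *a, **kw: print(fmt.format(*a, **kw)), *args, **kwargs
    )


@jax.jit
def g(x, y):
    # "{}" is filled by the positional x, "{y}" by the keyword y.
    debug_print_sketch("x={} y={y}", x, y=y)
    return x + y


buf = io.StringIO()
with contextlib.redirect_stdout(buf):
    out = g(jnp.int32(2), jnp.int32(5))
    jax.effects_barrier()  # wait for the callback before restoring stdout
```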
It may be useful to call jax.debug.callback() directly instead of this convenience wrapper. For example, to get debug printing in logs, you might use jax.debug.callback() together with logging.log.
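One way to route debug output through logging is a host-side function passed to jax.debug.callback. This is a sketch under assumptions: the logger name, the ListHandler class, and the helpers log_value and step are hypothetical, and it calls Logger.log, the per-logger counterpart of logging.log, so the messages can be collected without configuring the root logger.

```python
import logging

import jax
import jax.numpy as jnp

# Hypothetical logger name; any configured logger works the same way.
logger = logging.getLogger("jax_debug_sketch")
logger.setLevel(logging.INFO)

captured = []


class ListHandler(logging.Handler):
    # Collects formatted messages so the example can inspect them.
    def emit(self, record):
        captured.append(record.getMessage())


logger.addHandler(ListHandler())


def log_value(x):
    # Runs on the host with the concrete runtime value of `x`.
    logger.log(logging.INFO, "x = %s", x)


@jax.jit
def step(x):
    jax.debug.callback(log_value, x)
    return x + 1


y = step(jnp.float32(1.0))
jax.effects_barrier()  # make sure the callback has run
```

In a real program the ListHandler would typically be replaced by whatever handlers the application already configures.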
Parameters:
fmt (str) – A format string, e.g. "hello {x}", that will be used to format input arguments, like str.format. See the Python docs on string formatting and format string syntax.
*args – A list of positional arguments to be formatted, as if passed to fmt.format.
ordered (bool) – A keyword-only argument indicating whether the staged-out computation will enforce ordering of this jax.debug.print with respect to other ordered jax.debug.print calls.
partitioned (bool) – If True, then print local shards only; this option avoids an all-gather of the operands. If False, print with logical operands; this option requires an all-gather of operands first.
**kwargs – Additional keyword arguments to be formatted, as if passed to fmt.format.
Returns:
None
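The fmt, *args, **kwargs, and ordered parameters can be combined as in the sketch below. It assumes a single-device jax.jit setting; the function update and its arguments are hypothetical, and stdout is captured only so the output order can be checked. partitioned is left at its default because its effect is only visible when operands are sharded across devices.

```python
import contextlib
import io

import jax
import jax.numpy as jnp


@jax.jit
def update(step, loss):
    # Positional arguments fill "{}" fields in order; keyword arguments
    # fill named fields, as with str.format.
    jax.debug.print("step {}: loss={loss}", step, loss=loss, ordered=True)
    # With ordered=True on both calls, this line is printed after the one above.
    jax.debug.print("done with step {s}", s=step, ordered=True)
    return loss * 0.5


buf = io.StringIO()
with contextlib.redirect_stdout(buf):
    new_loss = update(jnp.int32(3), jnp.float32(2.0))
    jax.effects_barrier()  # flush pending debug callbacks

printed = buf.getvalue()
```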