Create a function that returns the jaxpr of fun given example args.
fun – The function whose jaxpr is to be computed. Its positional arguments and return value should be arrays, scalars, or standard Python containers (tuple/list/dict) thereof.
static_argnums – See the jax.jit() docstring.
axis_env – Optional, a sequence of pairs where the first element is an axis name and the second element is a positive integer representing the size of the mapped axis with that name. This parameter is useful when lowering functions that involve parallel communication collectives, and it specifies the axis name/size environment that would be set up by applications of jax.pmap().
return_shape – Optional boolean, defaults to False. If True, the wrapped function returns a pair where the first element is the ClosedJaxpr representation of fun and the second element is a pytree with the same structure as the output of fun and where the leaves are objects with shape and dtype attributes representing the corresponding types of the output leaves.
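The following is a minimal sketch of how static_argnums and axis_env might be used, not an authoritative recipe. The functions power and total, the axis name "i", and the axis size 4 are illustrative assumptions that do not come from the documentation above.

```python
import jax

def power(x, n):
    # n drives a Python loop, so it must be a concrete int at trace time.
    result = x
    for _ in range(n - 1):
        result = result * x
    return result

# Mark positional argument 1 (n) as static: its value is baked into the
# trace instead of appearing as an input of the jaxpr.
power_jaxpr = jax.make_jaxpr(power, static_argnums=1)(2.0, 3)
print(power_jaxpr)

def total(x):
    # psum refers to the axis name "i"; outside of pmap, axis_env is what
    # makes that name (and its size) known during tracing.
    return jax.lax.psum(x, "i")

# The axis size 4 is an arbitrary choice for illustration.
total_jaxpr = jax.make_jaxpr(total, axis_env=[("i", 4)])(1.0)
print(total_jaxpr)
```

With n static, the first jaxpr should take only x as input and unroll the loop into multiplications. The second should contain a psum equation over the named axis.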
A wrapped version of fun that, when applied to example arguments, returns a ClosedJaxpr representation of fun on those arguments. If the argument return_shape is True, then the returned function instead returns a pair where the first element is the ClosedJaxpr representation of fun and the second element is a pytree representing the structure, shape, dtypes, and named shapes of the output of fun.
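As a rough illustration of the return_shape=True behaviour described above, the sketch below traces a function that returns a dict and then reads the shape and dtype of each output leaf. The function stats and the input x are hypothetical examples chosen only for this sketch.

```python
import jax
import jax.numpy as jnp

def stats(x):
    # Returns a dict, so the output pytree has two named leaves.
    mean = jnp.mean(x)
    return {"mean": mean, "centered": x - mean}

x = jnp.ones((3, 2), dtype=jnp.float32)

# With return_shape=True the wrapped function returns a pair:
# (ClosedJaxpr, pytree of objects carrying .shape and .dtype).
closed_jaxpr, out_shape = jax.make_jaxpr(stats, return_shape=True)(x)

print(closed_jaxpr)
for name, leaf in out_shape.items():
    print(name, leaf.shape, leaf.dtype)
```

The second element mirrors the structure of the function's output, which makes it possible to learn output shapes without running the computation on real data.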
A jaxpr is JAX’s intermediate representation for program traces. The jaxpr language is based on the simply-typed first-order lambda calculus with let-bindings. make_jaxpr() adapts a function to return its jaxpr, which we can inspect to understand what JAX is doing internally. The jaxpr returned is a trace of fun abstracted to ShapedArray level. Other levels of abstraction exist internally.
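To make the ShapedArray-level abstraction concrete, here is a small sketch: the trace depends on the shape and dtype of the example argument, not on its value. The inputs 3.0, -7.5, and jnp.zeros(4) are arbitrary choices for illustration.

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(jnp.cos(x))

# Two scalar inputs with different values but the same shape and dtype.
jaxpr_a = jax.make_jaxpr(f)(3.0)
jaxpr_b = jax.make_jaxpr(f)(-7.5)

# The printed jaxprs should match, since only shape/dtype is recorded.
same_trace = str(jaxpr_a) == str(jaxpr_b)

# The ClosedJaxpr can also be inspected programmatically.
primitive_names = [eqn.primitive.name for eqn in jaxpr_a.jaxpr.eqns]
input_aval = jaxpr_a.in_avals[0]
print(same_trace, primitive_names, input_aval.shape, input_aval.dtype)

# A different input shape produces a jaxpr with a different input type.
jaxpr_vec = jax.make_jaxpr(f)(jnp.zeros(4))
print(jaxpr_vec)
```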
We do not describe the semantics of the jaxpr language in detail here, but instead give a few examples.
>>> import jax
>>>
>>> def f(x): return jax.numpy.sin(jax.numpy.cos(x))
>>> print(f(3.0))
-0.83602
>>> jax.make_jaxpr(f)(3.0)
{ lambda ; a:f32[]. let b:f32[] = cos a; c:f32[] = sin b in (c,) }
>>> jax.make_jaxpr(jax.grad(f))(3.0)
{ lambda ; a:f32[]. let
    b:f32[] = cos a
    c:f32[] = sin a
    _:f32[] = sin b
    d:f32[] = cos b
    e:f32[] = mul 1.0:f32[] d
    f:f32[] = neg e
    g:f32[] = mul f c
  in (g,) }