If you call lax.cond in the form lax.cond(predicate, function, function, callable_pytree, callable_pytree), then lax.cond will bind to the old function signature <Signature (pred, true_operand, true_fun: Callable, false_operand, false_fun: Callable)> and swap the arguments in an unexpected way.
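The binding described above can be illustrated with the standard library alone. The following is a minimal sketch, not the JAX source: `old_style_cond`, `looks_like_old_call`, and `CallableOperand` are hypothetical names introduced for illustration, and the detection rule (bind to the five-parameter signature, then check that both function slots are callable) is an assumption about how such a fallback could behave.

```python
import inspect


# Hypothetical stand-in for the deprecated five-argument form; only its
# parameter names matter here, not its body.
def old_style_cond(pred, true_operand, true_fun, false_operand, false_fun):
    raise NotImplementedError


def looks_like_old_call(*args, **kwargs):
    """Return True if the call would be read as the old five-argument form."""
    try:
        bound = inspect.signature(old_style_cond).bind(*args, **kwargs)
    except TypeError:
        return False  # wrong arity, so it cannot be the old form
    true_fun = bound.arguments["true_fun"]
    false_fun = bound.arguments["false_fun"]
    return callable(true_fun) and callable(false_fun)


def true_branch(f, g):
    return f(g(1.0))


def false_branch(f, g):
    return g(f(1.0))


class CallableOperand:
    """Plain-Python stand-in for a callable pytree such as jtu.Partial."""

    def __call__(self, x):
        return x


add_one, add_two = CallableOperand(), CallableOperand()

# Intended as (pred, true_fun, false_fun, *operands), but the third and fifth
# positional arguments are both callable, so the old signature matches.
ambiguous = looks_like_old_call(True, true_branch, false_branch, add_one, add_two)

# With non-callable operands the same call shape is not mistaken for the old form.
unambiguous = looks_like_old_call(True, true_branch, false_branch, 1.0, 2.0)

print(ambiguous, unambiguous)
```

In this sketch, packing both operands into a single tuple changes the arity to four, so the old signature can no longer bind; whether the same holds for a given JAX release would need to be checked against that release.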
Here is a reproducing example:
import jax.lax as lax
import jax.numpy as jnp
import jax.tree_util as jtu

def true_branch(add_one, add_two):
    return add_one(add_two(jnp.array(1.)))

def false_branch(add_one, add_two):
    return add_two(add_one(jnp.array(1.)))

add_one = jtu.Partial(jnp.add, jnp.array(1.))  # A callable pytree
add_two = jtu.Partial(jnp.add, jnp.array(2.))

four = lax.cond(True, true_branch, false_branch, add_one, add_two)  # TypeError

What jax/jaxlib version are you using?
0.4.11
Which accelerator(s) are you using?
CPU
Additional system info
Linux
NVIDIA GPU info
No response