
Source: https://github.com/jax-ml/jax/issues/16413

`lax.cond` can bind to unexpected function signature · Issue #16413 · jax-ml/jax · GitHub

Description

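If you call lax.cond in the form lax.cond(predicate, function, function, callable_pytree, callable_pytree), lax.cond binds to the old function signature <Signature (pred, true_operand, true_fun: Callable, false_operand, false_fun: Callable)> and the arguments are swapped in an unexpected way. Matched positionally against that signature, the first branch function becomes true_operand, the second branch function becomes true_fun, and the two callable pytrees become false_operand and false_fun.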
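To see why this call shape is ambiguous, here is a minimal sketch of signature-based dispatch. It assumes a heuristic of roughly this form: try to bind the positional arguments to the old five-parameter signature, and pick the old form if the values landing in the true_fun and false_fun slots are both callable. The names old_cond and dispatches_to_old_cond are hypothetical, this is not JAX's source, and functools.partial stands in for jtu.Partial so the sketch runs without JAX.

```python
import functools
import inspect
import operator
from typing import Callable


# Hypothetical stand-in carrying the old signature quoted in this issue.
# Only its parameter list matters; the body is never run.
def old_cond(pred, true_operand, true_fun: Callable, false_operand, false_fun: Callable):
    raise NotImplementedError


def dispatches_to_old_cond(*args) -> bool:
    """Hypothetical heuristic: treat a call as the old form when the positional
    arguments bind to the old signature and the values in the true_fun and
    false_fun slots are both callable."""
    try:
        bound = inspect.signature(old_cond).bind(*args)
    except TypeError:
        # Wrong number of arguments: cannot be the old five-argument form.
        return False
    return callable(bound.arguments["true_fun"]) and callable(bound.arguments["false_fun"])


def true_branch(add_one, add_two):
    return add_one(add_two(1.0))


def false_branch(add_one, add_two):
    return add_two(add_one(1.0))


# Plain-Python stand-ins for the callable pytrees in the report.
add_one = functools.partial(operator.add, 1.0)
add_two = functools.partial(operator.add, 2.0)

# The call from this issue: false_branch and add_two land in the
# true_fun/false_fun slots, so the old form is selected.
print(dispatches_to_old_cond(True, true_branch, false_branch, add_one, add_two))  # True
# Packing the operands into one tuple leaves four arguments, so binding fails.
print(dispatches_to_old_cond(True, true_branch, false_branch, (add_one, add_two)))  # False
# Non-callable operands in the last slot do not look like the old form.
print(dispatches_to_old_cond(True, true_branch, false_branch, 1.0, 2.0))  # False
```

The point of the sketch is that any check based only on arity and callability cannot tell a callable pytree operand apart from a branch function.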

Here is a reproducing example:

```python
import jax.lax as lax
import jax.numpy as jnp
import jax.tree_util as jtu

def true_branch(add_one, add_two):
    return add_one(add_two(jnp.array(1.)))

def false_branch(add_one, add_two):
    return add_two(add_one(jnp.array(1.)))

add_one = jtu.Partial(jnp.add, jnp.array(1.)) # A callable pytree
add_two = jtu.Partial(jnp.add, jnp.array(2.))
# Expected 4.0 (1 + 2 + 1); raises TypeError instead
four = lax.cond(True, true_branch, false_branch, add_one, add_two)
```

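A possible workaround, sketched under the assumption that the new form lax.cond(pred, true_fun, false_fun, *operands) used in the repro is what you want: pack both callable pytrees into a single tuple operand. The call then has only four positional arguments, so it cannot bind to the five-parameter old signature, and each branch unpacks the tuple itself. This is an illustration, not a fix proposed in the issue.

```python
import jax.lax as lax
import jax.numpy as jnp
import jax.tree_util as jtu


def true_branch(add_one, add_two):
    return add_one(add_two(jnp.array(1.)))


def false_branch(add_one, add_two):
    return add_two(add_one(jnp.array(1.)))


add_one = jtu.Partial(jnp.add, jnp.array(1.))  # callable pytree: leaves are the bound args
add_two = jtu.Partial(jnp.add, jnp.array(2.))

# One tuple operand instead of two separate operands: four positional
# arguments in total, so the old five-argument signature cannot match.
four = lax.cond(
    True,
    lambda ops: true_branch(*ops),
    lambda ops: false_branch(*ops),
    (add_one, add_two),
)
print(float(four))  # 4.0
```

Because jtu.Partial is a pytree, the arrays it wraps are passed through lax.cond as ordinary operands and the Partial objects are rebuilt inside each branch.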
What jax/jaxlib version are you using?

0.4.11

Which accelerator(s) are you using?

CPU

Additional system info

Linux

NVIDIA GPU info

No response
