Call body_fun repeatedly in a loop while cond_fun is True.
The Haskell-like type signature in brief is:

    while_loop :: (a -> Bool) -> (a -> a) -> a -> a

The semantics of while_loop are given by this Python implementation:

    def while_loop(cond_fun, body_fun, init_val):
        val = init_val
        while cond_fun(val):
            val = body_fun(val)
        return val
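As a minimal usage sketch of the semantics above, the snippet below doubles a scalar until it reaches 100. The function names `keep_going` and `double` are illustrative and not part of the API.

```python
import jax.numpy as jnp
from jax import lax

def keep_going(val):
    # cond_fun: continue while the carried value is below 100.
    return val < 100.0

def double(val):
    # body_fun: returns a value with the same shape and dtype as its input.
    return val * 2.0

# The carry starts at 1.0 and is doubled until the condition is False.
result = lax.while_loop(keep_going, double, jnp.float32(1.0))
```

Starting from 1.0, the carry passes through 2, 4, ..., 64 and the loop stops once it reaches 128.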
Unlike that Python version, while_loop is a JAX primitive and is lowered to a single WhileOp. That makes it useful for reducing compilation times for jit-compiled functions, since native Python loop constructs in an @jit function are unrolled, leading to large XLA computations.
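One way to observe the difference is to compare traced programs with `jax.make_jaxpr`. This is a sketch only: it inspects the jaxpr (JAX's intermediate representation) as a proxy for the lowered XLA computation, and the helper names `python_loop` and `lax_loop` are illustrative.

```python
import jax
import jax.numpy as jnp
from jax import lax

def python_loop(x):
    # A native Python loop: tracing records one addition per iteration.
    for _ in range(10):
        x = x + 1.0
    return x

def lax_loop(x):
    def cond_fun(carry):
        i, _ = carry
        return i < 10

    def body_fun(carry):
        i, val = carry
        return i + 1, val + 1.0

    # The carry is an (iteration counter, value) pair.
    _, out = lax.while_loop(cond_fun, body_fun, (jnp.int32(0), x))
    return out

x = jnp.float32(0.0)

# Names of the primitives recorded at the top level of each traced program.
unrolled_names = [e.primitive.name for e in jax.make_jaxpr(python_loop)(x).jaxpr.eqns]
looped_names = [e.primitive.name for e in jax.make_jaxpr(lax_loop)(x).jaxpr.eqns]
```

The unrolled trace grows with the iteration count, while the `while_loop` trace records the loop as a single `while` equation regardless of how many iterations run.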
Also unlike the Python analogue, the loop-carried value val must hold a fixed shape and dtype across all iterations (and not just be consistent up to NumPy rank/shape broadcasting and dtype promotion rules, for example). In other words, the type a in the type signature above represents an array with a fixed shape and dtype (or a nested tuple/list/dict container data structure with a fixed structure and arrays with fixed shape and dtype at the leaves).
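The fixed-type requirement can be illustrated with a dict carry whose leaves keep their shape and dtype, followed by a body that changes the carry shape. This is a sketch: the key names `step` and `acc` and the function `growing_body` are illustrative, and catching `TypeError` reflects the error JAX is expected to raise at trace time for a mismatched carry.

```python
import jax.numpy as jnp
from jax import lax

# A pytree carry: every leaf keeps the same shape and dtype on each iteration.
init_val = {"step": jnp.int32(0), "acc": jnp.zeros(3, dtype=jnp.float32)}

def cond_fun(val):
    return val["step"] < 4

def body_fun(val):
    # Same dict structure, same leaf shapes and dtypes as the input.
    return {"step": val["step"] + 1, "acc": val["acc"] + 1.0}

result = lax.while_loop(cond_fun, body_fun, init_val)

def growing_body(val):
    # Doubles the array length, so the carry shape would change per iteration.
    return jnp.concatenate([val, val])

try:
    lax.while_loop(lambda val: jnp.sum(val) < 100.0, growing_body, jnp.ones(2))
    shape_change_rejected = False
except TypeError:
    # Assumption: the carry type mismatch is reported as a TypeError.
    shape_change_rejected = True
```

The same `growing_body` would run without complaint in the pure-Python reference implementation, which is the contrast the paragraph above draws.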
Another difference from using Python-native loop constructs is that while_loop is not reverse-mode differentiable because XLA computations require static bounds on memory requirements.
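The sketch below contrasts the two differentiation modes. It relies on behavior not stated in the text above: forward-mode differentiation (`jax.jvp`) is understood to work through while_loop, and the reverse-mode failure is assumed to surface as a `ValueError`. The function name `double_until_100` is illustrative.

```python
import jax
import jax.numpy as jnp
from jax import lax

def double_until_100(x):
    # Doubles x until it is at least 100; the trip count depends on x.
    return lax.while_loop(lambda v: v < 100.0, lambda v: v * 2.0, x)

x = jnp.float32(1.0)

# Forward mode: the tangent is carried through the loop alongside the primal.
primal_out, tangent_out = jax.jvp(double_until_100, (x,), (jnp.float32(1.0),))

# Reverse mode: expected to fail, since the loop has no static trip count.
try:
    jax.grad(double_until_100)(x)
    reverse_mode_failed = False
except ValueError:
    # Assumption: JAX reports the unsupported transpose as a ValueError.
    reverse_mode_failed = True
```

When reverse-mode gradients are needed and the trip count can be bounded statically, `lax.scan` or `lax.fori_loop` with static bounds is the usual alternative.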
Note: while_loop() compiles cond_fun and body_fun, so while it can be combined with jit(), it’s usually unnecessary.
Parameters:
cond_fun (Callable[[T], BooleanNumeric]) – function of type a -> Bool.
body_fun (Callable[[T], T]) – function of type a -> a.
init_val (T) – value of type a, a type that can be a scalar, array, or any pytree (nested Python tuple/list/dict) thereof, representing the initial loop carry value.
Returns:
The output from the final iteration of body_fun, of type a.
Return type:
T