jax.tree module
Utilities for working with tree-like container data structures.
The jax.tree namespace contains aliases of utilities from jax.tree_util.
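The following is a small sketch of that correspondence, assuming a JAX release that includes the jax.tree namespace. Each short name here matches a tree_-prefixed function in jax.tree_util (map and tree_map, leaves and tree_leaves, and so on). The two spellings return the same results for the same input:

```python
import jax

tree = {"a": 1, "b": [2, 3]}

# jax.tree.<name> corresponds to jax.tree_util.tree_<name>.
assert jax.tree.leaves(tree) == jax.tree_util.tree_leaves(tree)
assert jax.tree.structure(tree) == jax.tree_util.tree_structure(tree)
assert jax.tree.map(lambda x: x * 2, tree) == jax.tree_util.tree_map(lambda x: x * 2, tree)

print(jax.tree.leaves(tree))  # [1, 2, 3]
```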
all(tree, *[, is_leaf])
Call all() over the leaves of a tree.
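This is a minimal sketch of all on a nested container of Python booleans. The flags tree is made up for illustration. Because all() runs over the flattened leaves, the nesting depth of the container does not matter:

```python
import jax

# Hypothetical nested container whose leaves are all Python bools.
flags = {"a": True, "b": [True, True], "c": (True,)}
all_ok = jax.tree.all(flags)  # every leaf is truthy

flags["b"][1] = False
some_failed = jax.tree.all(flags)  # one leaf is now falsy

print(all_ok, some_failed)  # True False
```

With array leaves, all() calls bool() on each leaf, and that only works for scalars. One way to handle such trees is to reduce each array first, for example with jax.tree.all(jax.tree.map(jnp.all, tree)).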
flatten(tree[, is_leaf])
Flattens a pytree.
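This short sketch shows what flatten returns for an arbitrary example dict: a flat list of leaves, plus a treedef that records the container structure. The example also shows two behaviors worth knowing. Dict entries are visited in sorted key order, and None is treated as an empty subtree rather than as a leaf:

```python
import jax

tree = {"b": (2, 3), "a": [1], "c": None}
leaves, treedef = jax.tree.flatten(tree)

# Dict entries are visited in sorted key order ("a", "b", "c").
# None is an empty subtree, so it contributes no leaves.
print(leaves)              # [1, 2, 3]
print(treedef.num_leaves)  # 3
```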
flatten_with_path(tree[, is_leaf])
Flattens a pytree like flatten, but also returns each leaf's key path.
leaves(tree[, is_leaf])
Gets the leaves of a pytree.
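The sketch below shows two common uses of leaves on a hypothetical one-layer parameter dict. The first counts parameters by summing leaf sizes. The second uses is_leaf to stop flattening at a chosen node type; here, tuples are kept whole:

```python
import jax
import jax.numpy as jnp

# Hypothetical parameter tree for one dense layer.
params = {"dense": {"w": jnp.ones((2, 3)), "b": jnp.zeros(3)}}

# Count parameters by summing the sizes of all array leaves (3 + 6).
num_params = sum(x.size for x in jax.tree.leaves(params))
print(num_params)  # 9

# is_leaf stops descending at any node for which it returns True.
tuple_leaves = jax.tree.leaves({"a": (1, 2), "b": 3}, is_leaf=lambda x: isinstance(x, tuple))
print(tuple_leaves)  # [(1, 2), 3]
```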
leaves_with_path(tree[, is_leaf])
Gets the leaves of a pytree, like leaves, and also returns each leaf's key path.
map(f, tree, *rest[, is_leaf])
Maps a multi-input function over pytree args to produce a new pytree.
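This is a minimal sketch of map with two input trees whose values are arbitrary. The function f receives one leaf from each tree at the same position. For that reason, every tree in rest must have the same structure as tree, or have tree as a prefix:

```python
import jax

a = {"x": 1, "y": [2, 3]}
b = {"x": 10, "y": [20, 30]}

# f receives one leaf from each tree at matching positions.
summed = jax.tree.map(lambda u, v: u + v, a, b)
print(summed)  # {'x': 11, 'y': [22, 33]}
```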
map_with_path(f, tree, *rest[, is_leaf])
Maps a multi-input function over pytree key paths and args to produce a new pytree.
reduce(function, tree[, initializer, is_leaf])
Call reduce() over the leaves of a tree.
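This minimal sketch applies reduce to an arbitrary tree of integers. It behaves like Python's functools.reduce applied to the leaves in flattening order. The optional initializer seeds the fold:

```python
import operator

import jax

tree = {"a": 1, "b": [2, 3], "c": (4,)}

# Without an initializer, the leaves are folded left to right: ((1 + 2) + 3) + 4.
total = jax.tree.reduce(operator.add, tree)
# With an initializer, the fold starts from that value.
product = jax.tree.reduce(operator.mul, tree, 1)
# Any two-argument callable works, e.g. the builtin max.
largest = jax.tree.reduce(max, tree)

print(total, product, largest)  # 10 24 4
```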
structure(tree[, is_leaf])
Gets the treedef for a pytree.
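This sketch applies structure to arbitrary example values. The returned treedef records only the container layout. Trees that differ only in their leaf values therefore compare equal, while a list with a different number of elements does not:

```python
import jax

treedef = jax.tree.structure({"a": [1, 2], "b": None})
print(treedef.num_leaves)  # 2

# Structure ignores leaf values, so only the container shape matters.
same = treedef == jax.tree.structure({"a": [10, 20], "b": None})
different = treedef == jax.tree.structure({"a": [1, 2, 3], "b": None})
print(same, different)  # True False
```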
transpose(outer_treedef, inner_treedef, ...)
Transform a tree having tree structure (outer, inner) into one having structure (inner, outer).
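This sketch uses transpose to turn a list of records into a dict of columns, using made-up data. The outer structure is a two-element list, and the inner structure is a dict with keys x and y. Both treedefs are built with structure from placeholder trees. The tree being transposed is passed as the third argument, which the signature above elides:

```python
import jax

# A list of 2 records (outer), each a dict with keys "x" and "y" (inner).
records = [{"x": 1, "y": 2}, {"x": 3, "y": 4}]
outer = jax.tree.structure([0, 0])
inner = jax.tree.structure({"x": 0, "y": 0})

# Swap the nesting: the result is a dict (inner) of lists (outer).
columns = jax.tree.transpose(outer, inner, records)
print(columns)  # {'x': [1, 3], 'y': [2, 4]}
```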
unflatten(treedef, leaves)
Reconstructs a pytree from the treedef and the leaves.
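This minimal sketch shows the usual round trip on an arbitrary example tree. It flattens the tree, transforms the leaves as a plain list, and then rebuilds the original structure with unflatten. The new list must contain the same number of leaves, in the same order, as the list that flatten returned:

```python
import jax

tree = {"a": 1, "b": (2, 3)}
leaves, treedef = jax.tree.flatten(tree)

# Rebuild the same structure from a new list of leaves (same count, same order).
doubled = jax.tree.unflatten(treedef, [2 * x for x in leaves])
print(doubled)  # {'a': 2, 'b': (4, 6)}
```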