Neural networks with JAX
Flax Linen delivers an end-to-end and flexible user experience for researchers who use JAX with neural networks. Flax exposes the full power of JAX. It is made up of loosely coupled libraries, which are showcased with end-to-end integrated guides and examples.
Flax Linen is used by hundreds of projects (and growing), both in the open source community (like Hugging Face) and at Google (like Gemini, Imagen, Scenic, and Big Vision).
Features
Safety
Flax is designed for correctness and safety. Thanks to its immutable Modules and Functional API, Flax helps mitigate bugs that arise when handling state in JAX.
Control
Flax gives finer-grained control and more expressivity than most neural network frameworks through its Variable Collections, RNG Collections, and Mutability conditions.
Functional API
Flax’s functional API extends what Modules can do through lifted transformations such as vmap and scan, and it integrates directly with other JAX libraries such as Optax and Chex.
Terse code
Flax’s compact Modules enable submodules to be defined directly at their call site, which makes code easier to read and avoids repetition.
Installation
pip install flax
# or, to install the latest version of Flax from GitHub:
pip install --upgrade git+https://github.com/google/flax.git
Flax installs the vanilla CPU version of JAX. If you need a custom version (for example, with GPU or TPU support), please check out JAX’s installation page.
Basic usage
from flax import linen as nn
import jax.numpy as jnp
from jax import random

class MLP(nn.Module):  # create a Flax Module dataclass
  out_dims: int

  @nn.compact
  def __call__(self, x):
    x = x.reshape((x.shape[0], -1))
    x = nn.Dense(128)(x)  # create inline Flax Module submodules
    x = nn.relu(x)
    x = nn.Dense(self.out_dims)(x)  # shape inference
    return x

model = MLP(out_dims=10)  # instantiate the MLP model

x = jnp.empty((4, 28, 28, 1))  # placeholder input batch (not random data)
variables = model.init(random.key(42), x)  # initialize the weights
y = model.apply(variables, x)  # make forward pass
Ecosystem
Notable projects built with Flax include:
NLP and computer vision models
Model for text-to-image generation
540-billion parameter model for text generation
Text-to-image diffusion models
Libraries for large-scale computer vision
Large-scale computer vision models
Open source high performance LLM
On-device differentiable reinforcement learning environments