Apart from TensorFlow and PyTorch, Google's JAX framework has become increasingly popular, and with good reason. Essentially, JAX was developed to accelerate machine learning workloads and make NumPy-style numerical computing in Python faster and easier to use. Even though deep learning is only a subset of what JAX can do, it gained ground after it was used in Google's Vision Transformer (ViT) and DeepMind engineers published a blog post explaining why it suited several of their projects. Quite simply, JAX is a high-performance Python library for numerical computing, especially in research.
There are several reasons to use JAX, and a few reasons not to. Let's weigh them up:
Why JAX?
Speed: All JAX operations are built on XLA, or Accelerated Linear Algebra, which is responsible for JAX's speed. Also developed by Google, XLA is a domain-specific compiler for linear algebra that uses whole-program optimisations to accelerate computation. XLA makes BERT's training speed almost 7.3 times faster. More importantly, using XLA lowers memory usage, which enables gradient accumulation and boosts overall computational throughput by around 12 times.
Source: AssemblyAI
JAX also allows users to transform their functions into just-in-time (JIT) compiled versions. With JIT, the speed of subsequent executions can be improved by adding a simple function decorator. However, not every function can be JIT-compiled in JAX; the JAX documentation notes the exceptions to this rule.
Source: TensorFlow
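To illustrate, here is a minimal sketch of JIT compilation with the standard jax.jit transformation; the selu function, array size and timing pattern are illustrative choices, not code from the article.

```python
import jax
import jax.numpy as jnp

# A plain numerical function written with jax.numpy (illustrative).
def selu(x, alpha=1.67, lmbda=1.05):
    return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

# jax.jit asks XLA to compile the whole function on its first call;
# subsequent calls with the same input shapes reuse the compiled version.
selu_jit = jax.jit(selu)

x = jnp.linspace(-3.0, 3.0, 1_000_000)
selu_jit(x).block_until_ready()  # first call includes compilation time
selu_jit(x).block_until_ready()  # later calls run the cached, compiled code
```

As the documentation points out, functions whose behaviour depends on the runtime values of traced arrays (for example value-dependent Python control flow) need workarounds such as static arguments before they can be jitted.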
Compatibility with GPUs: Unlike NumPy, which runs only on CPUs, JAX runs on both CPUs and GPUs, and its API closely mirrors NumPy's. This is why JAX is able to compile code directly for accelerators like GPUs and TPUs without any changes, making the process seamless. A user can write their code just once using NumPy-like syntax, try it out on the CPU and then shift it to a GPU cluster smoothly.
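As a rough sketch of what that looks like in practice, the snippet below writes ordinary NumPy-style code with jax.numpy; the arrays and operations are made up for illustration, and the same code runs on whichever backend (CPU, GPU or TPU) happens to be available.

```python
import jax
import jax.numpy as jnp
import numpy as np

# NumPy-style code written once; JAX dispatches it to the default
# backend (CPU if no accelerator is present, GPU/TPU otherwise).
x = jnp.linspace(0.0, 1.0, 10_000)
y = jnp.dot(jnp.sin(x), jnp.cos(x))   # runs on the default device

print(jax.devices())   # e.g. a single CPU device on a CPU-only machine

# Interop with classic NumPy stays simple: NumPy arrays are accepted
# as inputs, and results convert back with np.asarray.
a = np.random.rand(3, 3)
b = jnp.linalg.inv(a)       # NumPy input, JAX computation
c = np.asarray(b)           # back to a regular NumPy array
```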
Automatic differentiation: JAX can automatically differentiate native Python and NumPy functions. Most optimisation algorithms in machine learning use the gradients of functions to minimise losses. JAX simplifies differentiation with the help of an updated version of Autograd.
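A minimal sketch of how that works with jax.grad, using a hypothetical mean-squared-error loss and made-up shapes:

```python
import jax
import jax.numpy as jnp

# A scalar-valued loss written as ordinary Python + jax.numpy.
def loss(w, x, y):
    pred = jnp.dot(x, w)
    return jnp.mean((pred - y) ** 2)

# jax.grad returns a new function that computes d(loss)/d(w);
# by default it differentiates with respect to the first argument.
grad_loss = jax.grad(loss)

w = jnp.zeros(3)
x = jnp.ones((5, 3))
y = jnp.ones(5)
g = grad_loss(w, x, y)       # gradient with the same shape as w

# One gradient-descent step (illustrative learning rate of 0.1).
w = w - 0.1 * g
```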
Vectorisation: JAX offers automatic vectorisation via the vmap transformation, which makes life easier for developers. In ML research, a single function is often applied to a lot of data, for example to calculate the loss across a batch or to evaluate per-example gradients for differentially private learning. When the data is too large for a single accelerator, JAX performs large-scale data parallelism using the related pmap transformation.
Source: Developpaper.com
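The following is a small sketch of vmap (and, by composition, per-example gradients); the per_example_loss function, batch size and in_axes layout are illustrative assumptions rather than code from the article.

```python
import jax
import jax.numpy as jnp

# Per-example loss for a single input/target pair.
def per_example_loss(w, x, y):
    return (jnp.dot(x, w) - y) ** 2

w = jnp.ones(3)
xs = jnp.ones((8, 3))   # a batch of 8 examples
ys = jnp.ones(8)

# vmap maps the function over the leading axis of xs and ys while
# keeping w shared, producing 8 losses in one vectorised call.
batched_loss = jax.vmap(per_example_loss, in_axes=(None, 0, 0))
losses = batched_loss(w, xs, ys)

# Per-example gradients (useful e.g. for differentially private
# learning) compose the same way: vmap over grad.
per_example_grads = jax.vmap(jax.grad(per_example_loss),
                             in_axes=(None, 0, 0))(w, xs, ys)
```

pmap exposes the same mapping interface, but it shards the mapped axis across multiple devices instead of vectorising on a single one.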
Deep learning: While JAX is not just a deep learning framework, it has proven to be a solid foundation for deep learning tasks. Libraries like Flax, Haiku and Elegy have been built on top of JAX for deep learning. Higher-order optimisation techniques in deep learning rely on Hessians, and JAX is efficient at computing them: thanks to XLA, it can compute Hessians much faster than PyTorch.
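A brief sketch of Hessian computation in JAX, using jax.hessian and its explicit jacfwd/jacrev composition on a made-up scalar function:

```python
import jax
import jax.numpy as jnp

# A scalar function of a vector input (illustrative).
def f(x):
    return jnp.sum(x ** 3 + 2.0 * x ** 2)

x = jnp.array([1.0, 2.0, 3.0])

# jax.hessian is shorthand for composing forward- and reverse-mode
# differentiation.
H = jax.hessian(f)(x)          # 3x3 matrix of second derivatives

# The same composition written explicitly and jit-compiled, so XLA
# can fuse the whole computation.
hess = jax.jit(jax.jacfwd(jax.jacrev(f)))
H2 = hess(x)
```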
Why not JAX?
Source: Developpaper.com
Despite its drawbacks and gradual growth, JAX is already employed in a range of projects beyond deep learning, such as Bayesian methods and robotics. Last week, DeepMind announced four new libraries joining their JAX ecosystem: Mctx, which offers AlphaZero- and MuZero-style Monte Carlo tree search; KFAC-JAX, a library for second-order optimisation of neural networks and for computing scalable curvature approximations; DM_AUX, which handles audio signal processing in JAX, providing tools for spectrogram extraction and SpecAugment; and TF2JAX, a library for converting TensorFlow functions and graphs into JAX functions.
https://twitter.com/DeepMind/status/1517146462571794433?s=20&t=oMtOMASe5_zWnWPUcoQhcA