An excerpt from JAX in Action by Grigory Sapunov. This excerpt covers what JAX is, how it compares to NumPy, PyTorch, and TensorFlow, and why you might want to use it.
Read it if you’re a Python developer or machine learning practitioner who is interested in JAX and what it can do.
Take 25% off JAX in Action by entering fccsapunov into the discount code box at checkout at manning.com.
JAX is gaining popularity as more and more researchers adopt it for their work, and large companies such as DeepMind have started to contribute to its ecosystem.
In this excerpt, we will introduce JAX and its powerful ecosystem. We will explain what JAX is and how it relates to NumPy, PyTorch, and TensorFlow. We will go through JAX’s strengths and see how they work together to give you a very powerful tool for deep learning research and high-performance computing.
What is JAX?
JAX is a Python mathematics library with a NumPy interface, developed by Google (the Google Brain team, to be specific). It is heavily used for machine learning research, but it is not limited to it; many other problems can be solved with JAX.
JAX’s creators describe it as Autograd and XLA. Do not worry if you are unfamiliar with these names; that’s normal, especially if you are just getting into the field.
Autograd (https://github.com/hips/autograd) is a library that efficiently computes derivatives of NumPy code; it is the predecessor of JAX. Autograd’s main developers are now working on JAX. In a few words, Autograd lets you automatically calculate gradients for your computations, which is the essence of deep learning and many other fields, including numerical optimization, physics simulations, and, more generally, differentiable programming.
XLA (Accelerated Linear Algebra) is Google’s domain-specific compiler for linear algebra. It compiles your Python functions with linear algebra operations into high-performance code that runs on GPUs or TPUs.
Let’s start with the NumPy part.
JAX as NumPy
NumPy is the workhorse of Python numerical computing. It is so widely used in industry and science that the NumPy API has become the de facto standard for working with multidimensional arrays in Python. JAX provides a NumPy-compatible API but offers many other features that are absent in NumPy, so some people call JAX ‘NumPy on steroids’.
JAX provides a multidimensional array data structure called DeviceArray (recent JAX versions replace it with the unified jax.Array type) that implements many typical properties and methods of numpy.ndarray. There is also the jax.numpy package that implements the NumPy API with many well-known functions like abs(), convolve(), exp(), and so on.
JAX tries to follow the NumPy API as closely as possible, and in many cases, you can switch from numpy to jax.numpy without changing your program.
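Here is a minimal sketch of what such a switch looks like. The softmax function and the sample values are our own illustration, not code from the book. The same function body runs unchanged whether xp is bound to numpy or to jax.numpy. Passing the module in as the parameter xp is only a convenience for this side-by-side comparison. In a real program you would typically change just the import line.

```python
import numpy as np
import jax.numpy as jnp

def softmax(x, xp):
    # xp is either numpy or jax.numpy; the code is identical for both
    z = x - xp.max(x)  # subtract the max for numerical stability
    e = xp.exp(z)
    return e / xp.sum(e)

values = [1.0, 2.0, 3.0]
np_result = softmax(np.asarray(values), np)
jax_result = softmax(jnp.asarray(values), jnp)

# Results agree up to floating-point precision (JAX defaults to float32)
print(np_result, jax_result)
```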
There are still some limitations, and not all NumPy code can be used with JAX. JAX promotes a functional programming paradigm and requires pure functions without side effects. As a result, JAX arrays are immutable, while NumPy programs frequently use in-place updates, like arr[i] += 10. JAX works around this with an alternative, purely functional API that replaces in-place updates with pure indexed update functions. For this particular case, you would write arr = arr.at[i].add(10). There are a few other differences, which are addressed in the book.
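The sketch below shows the immutability point in practice; the array size, the index, and the variable names are arbitrary choices for illustration. A NumPy-style in-place update raises a TypeError on a JAX array. The .at[] form returns a new array and leaves the original unchanged.

```python
import jax.numpy as jnp

arr = jnp.zeros(5)
i = 2

# NumPy-style in-place update: JAX rejects it because its arrays are immutable
try:
    arr[i] += 10
    in_place_failed = False
except TypeError as err:
    in_place_failed = True
    print("TypeError:", err)

# Functional update: returns a new array and leaves `arr` unchanged
new_arr = arr.at[i].add(10)
print(arr)      # still all zeros
print(new_arr)  # 10.0 at index 2

# The same .at[] pattern covers other updates, e.g. .set() and .multiply()
replaced = new_arr.at[0].set(7.0)
doubled = new_arr.at[i].multiply(2)
```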
So you can use almost all the power of NumPy and write your programs the way you are accustomed to with NumPy. But JAX lets you do much more.
Composable transformations
JAX is much more than NumPy. It provides a set of composable function transformations for Python+NumPy code. At its core, JAX is an extensible system for transforming numerical functions, with four main transformations (which doesn’t mean more transformations won’t come!):
- Taking the gradient of your code, or differentiating it. This is the essence of deep learning and many other fields, including numerical optimization, physics simulations, and, more generally, differentiable programming. JAX uses an approach called automatic differentiation (or autodiff for short). Automatic differentiation lets you focus on your code rather than deal directly with derivatives; the framework takes care of them. This is typically done with the grad() function, but other advanced options exist.
- Compiling your code with jit(), or Just-in-Time compilation. JIT uses Google’s XLA to compile and produce efficient code for GPUs (typically NVIDIA ones through CUDA, though AMD ROCm platform support is in progress) and TPUs (Google’s Tensor Processing Units). XLA is the backend that powers machine learning frameworks, originally TensorFlow, on various devices, including CPUs, GPUs, and TPUs.
- Auto-vectorization of your code with vmap(), the vectorizing map. If you are familiar with functional programming, you probably know what a map is. If not, do not worry; we describe what it means in detail in the book. vmap() takes care of the batch dimensions of your arrays and can easily convert your code from processing a single item of data to processing many items (called a batch) at once. You may call it auto-batching. By doing this, you vectorize the computation, which typically gives you a significant boost on modern hardware that can efficiently parallelize matrix computations.
- Parallelizing your code to run on multiple accelerators, say, GPUs or TPUs. This is done with pmap(), which helps write single-program multiple-data (SPMD) programs. pmap() compiles a function with XLA, then replicates it and executes each replica on its own XLA device in parallel.
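Here is a minimal sketch of how these transformations compose. The loss function, the data, and every name below are our own illustrative assumptions, not examples from the book. grad() differentiates a per-example loss, vmap() turns that into per-example gradients for a whole batch, and jit() compiles the result with XLA. The last lines apply pmap() across however many devices are available. On a machine without accelerators that is a single CPU device, so the leading axis has size 1.

```python
import jax
import jax.numpy as jnp

def loss(w, x, y):
    # Squared error of a linear model on a single example (x, y)
    pred = jnp.dot(x, w)
    return (pred - y) ** 2

# grad(): differentiate with respect to the first argument (w) by default
grad_loss = jax.grad(loss)

# vmap(): map the per-example gradient over a batch; in_axes=(None, 0, 0)
# shares w across the batch and maps over axis 0 of x and y
batched_grad = jax.vmap(grad_loss, in_axes=(None, 0, 0))

# jit(): compile the composed function with XLA
fast_batched_grad = jax.jit(batched_grad)

w = jnp.array([1.0, -2.0])
xs = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
ys = jnp.array([1.0, 0.0, 0.0])

per_example_grads = fast_batched_grad(w, xs, ys)
print(per_example_grads)  # one gradient row per example, shape (3, 2)

# pmap(): replicate a function across devices; the leading axis of the
# input must not exceed the number of available devices
n_devices = jax.local_device_count()
squares = jax.pmap(lambda v: v ** 2)(jnp.arange(n_devices * 2.0).reshape(n_devices, 2))
print(squares.shape)  # (n_devices, 2)
```

Working out the gradient 2(x·w - y)x by hand for the three examples gives the rows [0, 0], [0, -4], and [-2, -2]. Each transformation takes a function and returns a function, so you can wrap them in whichever order suits your program.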
Why use JAX?
JAX is gaining momentum. The well-known State of AI 2021 report labeled JAX a new framework challenger.
Deep learning researchers and practitioners love JAX, and more and more new research is being done with it. Recent examples include the Vision Transformer (ViT) and MLP-Mixer papers by Google. DeepMind announced that they are using JAX to accelerate their research. JAX is easy to adopt because both Python and NumPy are widely used and familiar. Its composable function transformations help support machine learning research, and it has enabled rapid experimentation with novel algorithms and architectures. JAX now underpins many of DeepMind’s recent publications. Among them, I’d highlight BYOL (“Bootstrap Your Own Latent”), a new approach to self-supervised learning; Perceiver IO, a general transformer-based architecture for structured inputs and outputs; and research on large language models, such as the 280-billion-parameter Gopher and the 70-billion-parameter Chinchilla.
In mid-2021, Hugging Face made JAX/Flax the third officially supported framework in its well-known Transformers library. As of April 2022, the Hugging Face collection of pretrained models already had more than twice as many JAX models (5,530) as TensorFlow models (2,221). PyTorch is still ahead of both with 24,467 models, and porting models from PyTorch to JAX/Flax is an ongoing effort.
GPT-J-6B, EleutherAI’s open-source, GPT-like transformer language model with 6 billion parameters, was trained with JAX on Google Cloud. The authors state that JAX was the right set of tools for developing large-scale models rapidly.
JAX might not be very suitable for production deployment right now, as it focuses primarily on research, but PyTorch took precisely the same path. The gap between research and production will most likely close soon; the Hugging Face and GPT-J-6B cases are already moving in that direction. Given Google’s clout and the rapid expansion of the community, I’d expect a bright future for JAX.
JAX is not limited to deep learning. There are many exciting applications and libraries on top of JAX for physics, including molecular dynamics, fluid dynamics, rigid body simulation, quantum computing, astrophysics, ocean modeling, and so on. There are libraries for distributed matrix factorization, streaming data processing, protein folding, and chemical modeling, with other new applications emerging constantly.
How is JAX different from TensorFlow/PyTorch?
We already discussed how JAX compares to NumPy. Let’s compare JAX with the two other leading modern deep learning frameworks: PyTorch and TensorFlow.
We mentioned that JAX promotes a functional approach, in contrast to the object-oriented approach common to PyTorch and TensorFlow. This is the first very tangible difference you face when you start programming with JAX. It changes how you structure your code and requires changing some habits. At the same time, it gives you powerful function transformations, pushes you to write clean code, and brings rich compositionality.
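Random number generation is one place where the functional style shows up early. Here is a minimal sketch, our own example rather than one from the book. NumPy's legacy np.random functions rely on a hidden global seed. JAX sampling functions instead take an explicit PRNG key, so calling them twice with the same key gives the same numbers. To get fresh randomness, you split the key yourself.

```python
import jax

key = jax.random.PRNGKey(0)

# Same key in, same numbers out: the sampling function has no hidden state
a = jax.random.normal(key, (3,))
b = jax.random.normal(key, (3,))

# For new randomness, split the key explicitly and use the fresh subkey
key, subkey = jax.random.split(key)
c = jax.random.normal(subkey, (3,))

print(a, b, c)  # a and b are identical; c differs
```

Threading the key through your code can feel verbose at first. In return, every random function stays pure, which is exactly what jit(), vmap(), and the other transformations require.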
Another tangible thing you soon notice is that JAX is minimalistic: it does not implement everything. TensorFlow and PyTorch, the two most popular and well-developed deep learning frameworks, come with all the bells and whistles. Compared to them, JAX is so minimalistic that it’s hard to even call it a framework; it’s more of a library.
For example, JAX does not provide data loaders, because other libraries (e.g., PyTorch or TensorFlow) already do this well. JAX’s authors did not want to reimplement everything; they wanted to focus on the core. This is precisely the case where you can, and should, mix JAX with other deep learning frameworks. It is fine to take the data-loading code from, say, PyTorch and use it: PyTorch has excellent data loaders, so let each library play to its strengths. There is no need to reinvent the wheel if someone else has already made a good one.
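Here is a minimal sketch of that division of labor using only NumPy and JAX. The generator numpy_batches is a hypothetical stand-in for an external loader, such as a PyTorch DataLoader set up to yield NumPy arrays; PyTorch is not imported here. The jitted batch_mean function, also hypothetical, plays the JAX side of the program and consumes the NumPy batches directly.

```python
import numpy as np
import jax
import jax.numpy as jnp

def numpy_batches(data, batch_size):
    # Hypothetical stand-in for an external data loader (e.g. a PyTorch
    # DataLoader configured to return NumPy arrays)
    for start in range(0, len(data), batch_size):
        yield data[start:start + batch_size]

@jax.jit
def batch_mean(batch):
    # JAX accepts NumPy arrays as inputs and moves them to the device
    return jnp.mean(batch, axis=0)

data = np.arange(12, dtype=np.float32).reshape(6, 2)
means = [batch_mean(b) for b in numpy_batches(data, batch_size=3)]
print(means)
```

jit() compiles once per input shape. A loader whose last batch is smaller than the others therefore triggers one extra compilation, which you can avoid by dropping or padding the final batch. Here, the 6 rows split evenly into two batches of 3.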
Another noticeable thing is that JAX primitives are pretty low-level, and writing large neural networks in terms of matrix multiplications can be time-consuming. Hence, you need a higher-level language to specify such models. JAX does not provide such high-level APIs out of the box (similar to TensorFlow 1 before Keras became the central high-level API of TensorFlow 2). This is not a problem, though, because the JAX ecosystem has high-level libraries as well.
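To make the point concrete, here is a minimal sketch of a tiny multilayer perceptron written directly with JAX primitives. It uses explicit weight matrices, matrix multiplications, a hand-written gradient-descent step, and jax.tree_util.tree_map to update every parameter. The names init_params, mlp, mse_loss, and sgd_step, the layer sizes, the learning rate, and the toy regression target are all our own illustrative assumptions. High-level libraries such as Flax, Haiku, and Optax exist precisely to replace this kind of boilerplate.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # assumed value for this toy example

def init_params(key, sizes):
    # One (W, b) pair per layer, kept in a plain list of tuples (a JAX pytree)
    params = []
    for n_in, n_out in zip(sizes[:-1], sizes[1:]):
        key, subkey = jax.random.split(key)
        W = jax.random.normal(subkey, (n_in, n_out)) * 0.1
        b = jnp.zeros(n_out)
        params.append((W, b))
    return params

def mlp(params, x):
    # Hidden layers: matrix multiplication plus bias, followed by ReLU
    for W, b in params[:-1]:
        x = jax.nn.relu(x @ W + b)
    W, b = params[-1]
    return x @ W + b  # linear output layer

def mse_loss(params, x, y):
    return jnp.mean((mlp(params, x) - y) ** 2)

@jax.jit
def sgd_step(params, x, y):
    # Plain gradient descent written by hand; an optimizer library would supply this
    grads = jax.grad(mse_loss)(params, x, y)
    return jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)

params = init_params(jax.random.PRNGKey(42), [2, 16, 1])
x = jax.random.normal(jax.random.PRNGKey(1), (32, 2))
y = jnp.sum(x, axis=1, keepdims=True)  # toy regression target: y = x1 + x2

loss_before = mse_loss(params, x, y)
for _ in range(100):
    params = sgd_step(params, x, y)
loss_after = mse_loss(params, x, y)
print(loss_before, loss_after)
```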
You don’t need to write your neural networks with NumPy-like primitives. There are excellent neural network libraries like Flax by Google and Haiku by DeepMind, as well as the Optax library with its collection of state-of-the-art optimizers, and many more!
Figure 3 visualizes the difference between PyTorch/TensorFlow and JAX.
Figure 3. JAX vs. PyTorch/TensorFlow comparison
Because JAX is an extensible system for composable function transformations, it is easy to build separate modules for everything and mix and match according to what you need.
What you will learn from this book
In short, this book will teach you how to leverage JAX for machine and deep learning. You’ll learn how to use JAX’s ecosystem of high-level libraries and modules, and also how to combine TensorFlow and PyTorch with JAX for data loading and deployment.
Who is this book for?
This book is intended for intermediate Python programmers who are familiar with deep learning and/or machine learning. Anyone who is interested in JAX, however, will find this book to be an engaging and informative read.
If you want to learn more about the book, check it out at manning.com.