# Easy JAX training loops with Flax and Optax
In my previous blog post, I discussed JAX — a framework for high-performance numerical computing and machine learning — in an atypical manner. I didn't create a single training loop, and only showed a couple of patterns that looked vaguely machine learning-like. If you haven't read that blog post yet, you can read it here.
This approach was deliberate as I felt that JAX — although designed for machine learning research — is more general-purpose than that. The steps to use it are to define what you want to happen, wrap it within `jax.jit`, and let JAX trace out your function into an intermediate graph representation, which is then passed to XLA where it is compiled and optimised. The result is a single, heavily-optimised binary blob, ready and waiting to receive your data. This approach is a natural fit for many machine learning applications, as well as other scientific computing tasks. Therefore, targeting machine learning only didn't make sense. It is also ground that has already been extensively covered — I wanted to do a different take on introductory JAX.
In the previous post, I mentioned that it is possible to develop a full machine learning training loop — models, optimisers and all — in pure JAX. This is self-evident as JAX is general-purpose. It is a good exercise, but not a strategy I like to employ. In this blog post I want to introduce two higher-level libraries built on top of JAX that do a lot of the heavy lifting for us when writing machine learning applications. These libraries are Flax and Optax.
To summarise the libraries:
- Flax — provides a high-level neural network API that lets the developer reason about models in terms of components, as in PyTorch, rather than with JAX functions that take parameters as inputs.
- Optax — a library containing a vast array of model training utilities, such as optimisers, loss functions, learning rate schedulers, and more! Very batteries-included.
At the end of this post, we will have implemented and trained a very simple class-conditioned image generation model called a variational autoencoder (VAE) to generate MNIST digits.
## Neural Network API with Flax
The high-level structure of a training loop in pure JAX looks something like this:
```python
dataset = ...  # initialise training dataset that we can iterate over
params = ...   # initialise trainable parameters of our model
epochs = ...

def model_forward(params, batch):
    ...  # perform a forward pass of our model on `batch` using `params`
    return outputs

def loss_fn(params, batch):
    model_output = model_forward(params, batch)
    loss = ...  # compute a loss based on `batch` and `model_output`
    return loss

@jax.jit
def train_step(params, batch):
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    grads = ...  # transform `grads` (clipping, multiply by learning rate, etc.)
    params = ...  # update `params` using `grads` (such as via SGD)
    return params, loss

for _ in range(epochs):
    for batch in dataset:
        params, loss = train_step(params, batch)
        ...  # report metrics like loss, accuracy, and the like.
```
We define our model in a functional manner: a function that takes the model parameters and a batch as input, and returns the output of the model. Similarly, we define the loss function that also takes the parameters and a batch as input, but returns the loss instead.
Our final function is the train step itself, which we wrap in `jax.jit` – giving XLA maximum context to compile and optimise the training step. This first computes the loss and its gradient using the function transform `jax.value_and_grad`, manipulates the returned gradients (perhaps scaling by a learning rate), and updates the parameters. We return the new parameters, and use them on the next call to `train_step`. This is called in a loop, fetching new batches from the dataset before each training step.
Most machine learning programs follow a pattern such as the one above. But in frameworks like PyTorch, we package together the model forward pass and the management of model parameters into a stateful object representing our model – simplifying the training loop. It would be nice if we could imitate this behaviour in stateless JAX to allow the developer to reason about models in a class-based way. This is what Flax's neural network API – `flax.linen` – aims to achieve.
Whether or not writing models in a purely stateless, functional way is better than a stateful, class-based way, is not the topic of this blog post. Both have merits. Regardless, during execution the final result is the same whether we use Flax or not. We get a stateless, heavily-optimised, binary blob that we throw data at. It's all JAX after all.
There are two main ways to define a module in Flax: one is PyTorch-like and the other is a compact representation:
```python
import jax
import jax.numpy as jnp
import flax.linen as nn
from typing import Callable

class Model(nn.Module):
    dim: int
    activation_fn: Callable = nn.relu

    def setup(self):
        self.layer = nn.Dense(self.dim)

    def __call__(self, x):
        x = self.layer(x)
        return self.activation_fn(x)

class ModelCompact(nn.Module):
    dim: int
    activation_fn: Callable = nn.relu

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(self.dim)(x)
        return self.activation_fn(x)
```
If we have complex initialisation logic, the former may be more appropriate. Instead, if the module is relatively simple, we can make use of the `nn.compact` representation to automatically define the module by the forward pass alone.
Like other frameworks, we can nest modules within each other to implement complex model behaviour. As we've already seen, `flax.linen` provides some pre-baked modules like `nn.Dense` (the same as PyTorch's `nn.Linear`). I won't enumerate them all, but the usual candidates are all there, like convolutions, embeddings, and more.
Something to bear in mind if you are porting models from PyTorch to Flax is that the default weight initialisation may be different. For example, Flax's `nn.Dense` initialises its kernel with LeCun normal and its bias to zero, whereas PyTorch's `nn.Linear` uses (Kaiming-style) uniform initialisations for both the weight and the bias.
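If you want to match another framework's behaviour, the initialisers can be overridden per layer. A minimal sketch, assuming you want He/Kaiming-style initialisation (the scales here are illustrative, not PyTorch's exact rule):

```python
layer = nn.Dense(
    features=4,
    kernel_init=nn.initializers.he_uniform(),      # instead of the default lecun_normal()
    bias_init=nn.initializers.uniform(scale=1e-2),  # instead of the default zeros; scale is illustrative
)
```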
At this point, we cannot actually call this model, even if we were to instantiate the class itself. There simply aren't any parameters to use. Furthermore, the module is never a container for parameters. An instance of a Flax module is simply a hollow shell that loosely associates operations with parameters and inputs, both of which are passed in later.
To see what I mean, let's initialise some parameters for our model:
```python
key = jax.random.PRNGKey(0xffff)
key, model_key = jax.random.split(key)

model = Model(dim=4)
params = model.init(model_key, jnp.zeros((1, 8)))
params
```
===
Out:
FrozenDict({
params: {
layer: {
kernel: Array([[-0.05412389, -0.28172645, -0.07438638, 0.5238516 ],
[-0.13562573, -0.17592733, 0.45305118, -0.0650041 ],
[ 0.25177842, 0.13981569, -0.41496065, -0.15681015],
[ 0.13783392, -0.6254694 , -0.09966562, -0.04283331],
[ 0.48194656, 0.07596914, 0.0429794 , -0.2127948 ],
[-0.6694777 , 0.15849823, -0.4057232 , 0.26767966],
[ 0.22948688, 0.00706845, 0.0145666 , -0.1280596 ],
[ 0.62309605, 0.12575962, -0.05112049, -0.316764 ]], dtype=float32),
bias: Array([0., 0., 0., 0.], dtype=float32),
},
},
})
In the above cell, we first instantiated our model class, which returns an instance of `Model` that we assign to the variable `model`. Like I said, it does not contain any parameters; it is just a hollow shell that we pass parameters and inputs to. We can see this by printing the `model` variable itself:
model
===
Out: Model(
# attributes
dim = 4
activation_fn = relu
)
We can also try to call the module directly, which will fail even though we have defined the `__call__` method:

```python
model(jnp.zeros((1, 8)))
```
===
Out:
AttributeError: "Model" object has no attribute "layer". If "layer" is defined in '.setup()', remember these fields are only accessible from inside 'init' or 'apply'.
To initialise the parameters, we passed a PRNG key and some dummy inputs to the model's `init` function, with the same shape and dtype as the inputs we will use later. In this simple case we just pass `x`, as in the original module's `__call__` definition, but this could be multiple arrays, PyTrees, or PRNG keys. Flax needs the input shapes and dtypes in order to determine the shape and dtype of the model parameters.
From the `model.init` call, we get a nested `FrozenDict` holding our model's
parameters. If you have seen PyTorch state dictionaries, the format of the
parameters is similar: nested dictionaries with meaningful named keys, with
parameter arrays as values. If you've read my previous blog post or read about
JAX before, you will know that this structure is a PyTree. Not only does Flax
help developers loosely associate parameters and operations, **it also helps
initialise model parameters based on the model definition**.
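Because the parameters are just a PyTree, standard JAX tree utilities work on them directly. For example, a quick way to inspect the parameter shapes without printing every value:

```python
# Map every leaf array in the parameter PyTree to its shape.
print(jax.tree_util.tree_map(lambda p: p.shape, params))
```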
With the parameters, we can call the model using `model.apply` – providing the
parameters and inputs:
```python
key, x_key = jax.random.split(key)
x = jax.random.normal(x_key, (1, 8))
y = model.apply(params, x)
y
```
===
Out: Array([[0.9296505 , 0.25998798, 0.01101626, 0. ]], dtype=float32)
There is nothing special about the PyTree returned by `model.init` – it is just a regular PyTree storing the model's parameters. `params` can be swapped with any other PyTree that contains the parameters `model` expects:
```python
zero_params = jax.tree_map(jnp.zeros_like, params)  # generates a PyTree with the same structure as `params`, with all values set to 0
print(zero_params)
model.apply(zero_params, x)
```
Out:
FrozenDict({
    params: {
        layer: {
            bias: Array([0., 0., 0., 0.], dtype=float32),
            kernel: Array([[0., 0., 0., 0.],
                   [0., 0., 0., 0.],
                   [0., 0., 0., 0.],
                   [0., 0., 0., 0.],
                   [0., 0., 0., 0.],
                   [0., 0., 0., 0.],
                   [0., 0., 0., 0.],
                   [0., 0., 0., 0.]], dtype=float32),
        },
    },
})
Array([[0., 0., 0., 0.]], dtype=float32)
Forcing model calls to take the parameters explicitly keeps the model stateless, and returning the parameters as a regular PyTree makes Flax interoperable with JAX functions – as well as with other libraries built on JAX. **Essentially, by using Flax we aren't forced to use any other specific frameworks, and we keep access to all regular JAX features.**
If you are used to frameworks like PyTorch, calling models like this feels
unnatural at first. However, I personally quite like it this way – it feels
rather elegant to pass different parameters to the model to get different
behaviour rather than "load" the weights. A bit subjective and fuzzy, I know,
but I like it.
> To summarise the difference, if we aim to implement $f_\theta(x)$, a PyTorch
module is basically $f_\theta$ (which we can call on $x$). A Flax module is
simply $f$, which needs to be provided parameters $\theta$ before it can be
called on $x$ – or alternatively, we call $f$ on $(\theta, x)$.
All in all, the point of Flax is to **provide a familiar stateful API for
development** whilst **preserving JAX statelessness during runtime**. We can
build our neural network modules in terms of classes and objects, but **the
final result is a stateless function `model.apply` that takes in our inputs and
a PyTree of parameters.**
This is identical behaviour to what we began with (recall our `model_forward` function at the start of this section), just now tied up nicely together. Therefore, our function containing `model.apply`, which takes our PyTree of parameters as input, can be safely jit-compiled. The result is the same: a heavily-optimised binary blob we bombard with data. Nothing changes during runtime; it just makes development easier for those who prefer reasoning about neural networks in a class-based way, whilst remaining interoperable with, and keeping the performance of, JAX.
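To make that concrete, here is a minimal sketch of wrapping `model.apply` in `jax.jit`, using the `model`, `params`, and `x` from above:

```python
@jax.jit
def forward(params, x):
    # `model` is closed over statically; only `params` and `x` are traced inputs.
    return model.apply(params, x)

y = forward(params, x)
```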
There's a lot more to Flax than this, especially outside the `flax.linen`
neural network API. For now though, we will move on to developing a full
training loop using Flax and **Optax**. We will swing back around to some extra
Flax points later, but I feel some concepts are hard to explain without first
showing a training loop.
## A full training loop with Optax and Flax
We've shown how Flax reduces the complexity of writing model code and initialising parameters. We can push this further by relying on Optax to handle the gradient manipulation and parameter updates in `train_step`. For simple optimisers like SGD these steps are easy enough to write by hand, but for more complex optimisers or gradient transformation behaviour they can become awkward to implement in JAX alone. Optax packages this behaviour into a simple API.
```python
import optax
optimiser = optax.sgd(learning_rate=1e-3)
optimiser
```
===
Out: GradientTransformationExtraArgs(init=<function chain.<locals>.init_fn at 0x7fa7185503a0>, update=<function chain.<locals>.update_fn at 0x7fa718550550>)
Not pretty, but we can see that the optimiser is just a gradient transformation – in fact, all optimisers in Optax are implemented as gradient transformations. A gradient transformation is defined to be a pair of pure functions, `init` and `update`. Like a Flax model, an Optax optimiser keeps no state internally: it must be initialised before it can be used, and any state must be passed by the developer to `update`:
```python
optimiser_state = optimiser.init(params)
optimiser_state
```
Out: (EmptyState(), EmptyState())
Of course, as SGD is a stateless optimiser, the initialisation call simply
returns an empty state. It must return this to maintain the API of a gradient
transformation. Let's try with a more complex optimiser like Adam:
```python
optimiser = optax.adam(learning_rate=1e-3)
optimiser_state = optimiser.init(params)
optimiser_state
```
===
Out: (ScaleByAdamState(count=Array(0, dtype=int32), mu=FrozenDict({
params: {
layer: {
bias: Array([0., 0., 0., 0.], dtype=float32),
kernel: Array([[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.]], dtype=float32),
},
},
}), nu=FrozenDict({
params: {
layer: {
bias: Array([0., 0., 0., 0.], dtype=float32),
kernel: Array([[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.]], dtype=float32),
},
},
})),
EmptyState())
Here, we can see the first and second order statistics of the Adam optimiser, as well as a count storing the number of optimiser updates. As with SGD, this state needs to be passed to `update` when called.
Like Flax parameters, the optimiser state is just a PyTree. Any PyTree with a compatible structure could also be used. Again, this allows interoperability with JAX and `jax.jit`, as well as with other libraries built on top of JAX.
Concretely, Optax gradient transformations are simply a named tuple containing the pure functions `init` and `update`. `init` takes in an example instance of the gradients to be transformed and returns the initial optimiser state. In the case of `optax.sgd` this returns an empty state regardless of the example provided. For `optax.adam`, we get a more complex state containing the first and second order statistics, with the same PyTree structure as the provided example.
`update` takes in a PyTree of updates with the same structure as the example instance provided to `init`. In addition, it takes the optimiser state returned by `init` and, optionally, the parameters of the model itself, which may be needed for some optimisers. This function returns the transformed gradients (which could be another set of gradients, or the actual parameter updates) and the new optimiser state.
This is explained quite nicely in the Optax documentation.
Providing some dummy data, we get the following:

```python
import optax

params = jnp.array([0.0, 1.0, 2.0])  # some dummy parameters
optimiser = optax.adam(learning_rate=0.01)
opt_state = optimiser.init(params)

grads = jnp.array([4.0, 0.6, -3.0])  # some dummy gradients
updates, opt_state = optimiser.update(grads, opt_state, params)
updates
```
===
Out: Array([-0.00999993, -0.00999993, 0.00999993], dtype=float32)
Optax provides a helper function to apply the updates to our parameters:

```python
new_params = optax.apply_updates(params, updates)
new_params
```
Out: Array([-0.00999993, 0.99000007, 2.01 ], dtype=float32)
It is important to emphasise that Optax optimisers are gradient transformations,
**but gradient transformations are not just optimisers.** We'll see more of that
later after we finish the training loop.
On that note, let's begin with said training loop. Recall that our goal is to train a class-conditioned variational autoencoder (VAE) on the MNIST dataset.
> I chose this example as it is slightly more interesting than the typical
classification example found in most tutorials.
Not strictly related to JAX, Flax, or Optax, but it is worth describing what a
VAE is. First, an autoencoder model is one that maps some input $x$ in our data
space to a **latent vector** $z$ in the **latent space** (a space with smaller
dimensionality than the data space) and back to the data space. It is trained to
minimise the reconstruction loss between the input and the output, essentially
learning the identity function through an **information bottleneck**.
The portion of the network that maps from the data space to the latent space is
called the **encoder** and the portion that maps from the latent space to the
data space is called the **decoder**. Applying the encoder is somewhat
analogous to lossy compression. Likewise, applying the decoder is akin to
lossy decompression.
What makes a VAE different to an autoencoder is that the encoder does not
output the latent vector directly. Instead, **it outputs the mean and
log-variance of a Gaussian distribution, which we then sample from in order
to obtain our latent vector**. We apply an extra loss term to make these mean and
log-variance outputs roughly follow the standard normal distribution.
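For reference, the closed form of this KL term between the predicted Gaussian $\mathcal{N}(\mu, \sigma^2)$ and the standard normal – the exact expression the loss code later computes from the mean and log-variance – is:

$$ D_{\mathrm{KL}}\left(\mathcal{N}(\mu, \sigma^2) \,\|\, \mathcal{N}(0, 1)\right) = -\frac{1}{2} \sum_i \left(1 + \log \sigma_i^2 - \mu_i^2 - \sigma_i^2\right) $$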
> Interestingly, defining the encoder this way means for every given input $x$
we have many possible latent vectors which are sampled stochastically. Our
encoder is almost mapping to a sphere of possible latents centred at the mean
vector with radius scaling with log-variance.
The decoder is the same as before. However, now we can sample **a latent from
the normal distribution and pass it to the decoder in order to generate samples
like those in the dataset**! Adding the variational component turns our
autoencoder compression model into a VAE generative model.
![image/jpeg](https://cdn-uploads.huggingface.co/production/uploads/6079e94c88160e14e4e2e4b1/KDnjRhV8eRodv_8c20Vsw.jpeg)
> Abstract diagram of a VAE, pilfered from [this AWS blog](https://aws.amazon.com/blogs/machine-learning/deploying-variational-autoencoders-for-anomaly-detection-with-tensorflow-serving-on-amazon-sagemaker/)
Our goal is to implement the model code for the VAE as well as the training loop
with both the reconstruction and variational loss terms. Then, we can sample new
digits that look like those in the MNIST dataset! Additionally, we will provide
an extra input to the model – the class index – so we can control which number
we want to generate.
Let's begin by defining our configuration. For this educational example, we will
just define some constants in a cell:
```python
batch_size = 16
latent_dim = 32
kl_weight = 0.5
num_classes = 10
seed = 0xffff
```
Along with some imports and PRNG initialisation:
```python
import jax  # install the correct wheel for the accelerator you want to use
import flax
import optax
import orbax
import flax.linen as nn
import jax.numpy as jnp
import numpy as np

from jax.typing import ArrayLike
from typing import Tuple, Callable
from math import sqrt

import torchvision.transforms as T
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader

key = jax.random.PRNGKey(seed)
```
Let's grab our MNIST dataset while we are here:

```python
train_dataset = MNIST('data', train=True, transform=T.ToTensor(), download=True)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
```
JAX, Flax, and Optax do not have data loading utilities, so I just use the perfectly serviceable PyTorch implementation of the MNIST dataset here.
Now to our first real Flax model. We begin by defining a submodule, `FeedForward`, that implements a stack of linear layers with intermediate non-linearities:
```python
class FeedForward(nn.Module):
    dimensions: Tuple[int] = (256, 128, 64)
    activation_fn: Callable = nn.relu
    drop_last_activation: bool = False

    @nn.compact
    def __call__(self, x: ArrayLike) -> ArrayLike:
        for i, d in enumerate(self.dimensions):
            x = nn.Dense(d)(x)
            if i != len(self.dimensions) - 1 or not self.drop_last_activation:
                x = self.activation_fn(x)
        return x

key, model_key = jax.random.split(key)

model = FeedForward(dimensions=(4, 2, 1), drop_last_activation=True)
print(model)

params = model.init(model_key, jnp.zeros((1, 8)))
print(params)

key, x_key = jax.random.split(key)
x = jax.random.normal(x_key, (1, 8))
y = model.apply(params, x)
y
```
===
Out:
FeedForward(
# attributes
dimensions = (4, 2, 1)
activation_fn = relu
drop_last_activation = True
)
FrozenDict({
params: {
Dense_0: {
kernel: Array([[ 0.0840368 , -0.18825287, 0.49946404, -0.4610112 ],
[ 0.4370267 , 0.21035315, -0.19604324, 0.39427406],
[ 0.00632685, -0.02732705, 0.16799504, -0.44181877],
[ 0.26044282, 0.42476758, -0.14758752, -0.29886967],
[-0.57811564, -0.18126923, -0.19411889, -0.10860331],
[-0.20605426, -0.16065307, -0.3016759 , 0.44704655],
[ 0.35531637, -0.14256613, 0.13841921, 0.11269159],
[-0.430825 , -0.0171169 , -0.52949774, 0.4862139 ]], dtype=float32),
bias: Array([0., 0., 0., 0.], dtype=float32),
},
Dense_1: {
kernel: Array([[ 0.03389561, -0.00805947],
[ 0.47362345, 0.37944487],
[ 0.41766328, -0.15580587],
[ 0.5538078 , 0.18003668]], dtype=float32),
bias: Array([0., 0.], dtype=float32),
},
Dense_2: {
kernel: Array([[ 1.175035 ],
[-1.1607001]], dtype=float32),
bias: Array([0.], dtype=float32),
},
},
})
Array([[0.5336972]], dtype=float32)
We use the `nn.compact` decorator here as the logic is relatively simple. We iterate over the tuple `self.dimensions` and pass our current activations through an `nn.Dense` module, followed by applying `self.activation_fn`. This activation can optionally be dropped for the final linear layer in `FeedForward`. This is needed as `nn.relu` only outputs non-negative values, whereas sometimes we need negative outputs!
Using `FeedForward`, we can define our full VAE model:
```python
class VAE(nn.Module):
    encoder_dimensions: Tuple[int] = (256, 128, 64)
    decoder_dimensions: Tuple[int] = (128, 256, 784)
    latent_dim: int = 4
    activation_fn: Callable = nn.relu

    def setup(self):
        self.encoder = FeedForward(self.encoder_dimensions, self.activation_fn)
        self.pre_latent_proj = nn.Dense(self.latent_dim * 2)
        self.post_latent_proj = nn.Dense(self.encoder_dimensions[-1])
        self.class_proj = nn.Dense(self.encoder_dimensions[-1])
        self.decoder = FeedForward(self.decoder_dimensions, self.activation_fn, drop_last_activation=False)

    def reparam(self, mean: ArrayLike, logvar: ArrayLike, key: jax.random.PRNGKey) -> ArrayLike:
        std = jnp.exp(logvar * 0.5)
        eps = jax.random.normal(key, mean.shape)
        return eps * std + mean

    def encode(self, x: ArrayLike):
        x = self.encoder(x)
        mean, logvar = jnp.split(self.pre_latent_proj(x), 2, axis=-1)
        return mean, logvar

    def decode(self, x: ArrayLike, c: ArrayLike):
        x = self.post_latent_proj(x)
        x = x + self.class_proj(c)
        x = self.decoder(x)
        return x

    def __call__(
        self, x: ArrayLike, c: ArrayLike, key: jax.random.PRNGKey) -> Tuple[ArrayLike, ArrayLike, ArrayLike]:
        mean, logvar = self.encode(x)
        z = self.reparam(mean, logvar, key)
        y = self.decode(z, c)
        return y, mean, logvar

key = jax.random.PRNGKey(0x1234)
key, model_key = jax.random.split(key)

model = VAE(latent_dim=4)
print(model)

key, call_key = jax.random.split(key)
params = model.init(model_key, jnp.zeros((batch_size, 784)), jnp.zeros((batch_size, num_classes)), call_key)
recon, mean, logvar = model.apply(params, jnp.zeros((batch_size, 784)), jnp.zeros((batch_size, num_classes)), call_key)

recon.shape, mean.shape, logvar.shape
```
===
Out:
VAE(
# attributes
encoder_dimensions = (256, 128, 64)
decoder_dimensions = (128, 256, 784)
latent_dim = 4
activation_fn = relu
)
((16, 784), (16, 4), (16, 4))
There is a lot to the above cell. Knowing the specifics of how this model works isn't too important to understanding the training loop later, as we can treat the model as a bit of a black box. Simply substitute your own model of choice. Saying that, I'll unpack each function briefly:
- `setup`: creates the submodules of the network, namely two `FeedForward` stacks and two `nn.Dense` layers that project to and from the latent space. Additionally, it initialises a third `nn.Dense` layer that projects our class conditioning vector to the same dimensionality as the last encoder layer.
- `reparam`: sampling a latent directly from a random Gaussian is not differentiable, hence we employ the reparameterisation trick. This involves sampling a random vector, scaling it by the standard deviation, then adding the mean. As it involves random array generation, we take as input a PRNG key in addition to the mean and log-variance.
- `encode`: applies the encoder and the projection to the latent space to the input. Note that the output of the projection is actually double the size of the latent space, as we split it in two to obtain our mean and log-variance.
- `decode`: applies a projection from the latent space to `x`, followed by adding the output of `class_proj` applied to the conditioning vector. This is how we inject the class information into the model. Finally, it passes the result through the decoder stack.
- `__call__`: this is simply the full model forward pass: `encode`, then `reparam`, then `decode`. This is used during training.
The above example also demonstrates that we can add other functions to our Flax modules aside from `setup` and `__call__`. This is useful for more complex behaviour, or if we want to only execute parts of the model (more on this later).
We now have our model, optimiser, and dataset. The next step is to write the function that implements our training step and then jit-compile it:
```python
def create_train_step(key, model, optimiser):
    params = model.init(key, jnp.zeros((batch_size, 784)), jnp.zeros((batch_size, num_classes)), jax.random.PRNGKey(0))  # dummy key just as example input
    opt_state = optimiser.init(params)

    def loss_fn(params, x, c, key):
        reduce_dims = list(range(1, len(x.shape)))
        c = jax.nn.one_hot(c, num_classes)  # one hot encode the class index
        recon, mean, logvar = model.apply(params, x, c, key)
        mse_loss = optax.l2_loss(recon, x).sum(axis=reduce_dims).mean()
        kl_loss = jnp.mean(-0.5 * jnp.sum(1 + logvar - mean ** 2 - jnp.exp(logvar), axis=reduce_dims))  # KL loss term to keep encoder output close to standard normal distribution
        loss = mse_loss + kl_weight * kl_loss
        return loss, (mse_loss, kl_loss)

    @jax.jit
    def train_step(params, opt_state, x, c, key):
        losses, grads = jax.value_and_grad(loss_fn, has_aux=True)(params, x, c, key)
        loss, (mse_loss, kl_loss) = losses

        updates, opt_state = optimiser.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)

        return params, opt_state, loss, mse_loss, kl_loss

    return train_step, params, opt_state
```
Here, I don't define the training step directly, but rather define a function that returns the training step function given a target model and optimiser, along with returning the freshly initialised parameters and optimiser state.
Let's unpack it all:
- First, it initialises our model using example inputs: a batch of 784-dimensional arrays (the flattened MNIST digits), a batch of class-conditioning vectors, and a dummy PRNG key.
- It also initialises the optimiser state using the parameters we just initialised.
- Next, it defines the loss function. This is simply a `model.apply` call which returns the model's reconstruction of the input, along with the predicted mean and log-variance. We then compute the mean-squared error loss and the KL divergence, before finally computing a weighted sum to get our final loss. The KL loss term is what keeps the encoder outputs close to a standard normal distribution.
- Finally, the actual train step definition. This begins by transforming `loss_fn` using our old friend `jax.value_and_grad`, which will return the loss and also the gradients. We must set `has_aux=True` as we return all individual loss terms for logging purposes. We provide the gradients, optimiser state, and parameters to `optimiser.update`, which returns the transformed gradients and the new optimiser state. The transformed gradients are then applied to the parameters. Finally, we return the new parameters, optimiser state, and loss terms – followed by wrapping the whole thing in `jax.jit`. Phew..
A function that generates the training step is just a pattern I quite like, and there is nothing stopping you from just writing the training step directly.
Let's call `create_train_step`:

```python
key, model_key = jax.random.split(key)

model = VAE(latent_dim=latent_dim)
optimiser = optax.adamw(learning_rate=1e-4)

train_step, params, opt_state = create_train_step(model_key, model, optimiser)
```
When we call the above, we get a `train_step` function ready to be compiled and to accept our parameters, optimiser state, and data at blistering speeds. As always with jit-compiled functions, the first call with a given set of input shapes will be slow, but subsequent calls will be fast as we skip the compilation and optimisation process.
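If you want to see this for yourself, a rough sketch of timing the first (compiling) and second (cached) calls might look like the following; the dummy inputs and timings are purely illustrative:

```python
import time

x_dummy = jnp.zeros((batch_size, 784))
c_dummy = jnp.zeros((batch_size,), dtype=jnp.int32)

for label in ("first call (compiles)", "second call (cached)"):
    start = time.perf_counter()
    out = train_step(params, opt_state, x_dummy, c_dummy, jax.random.PRNGKey(0))
    jax.block_until_ready(out)  # wait for the async dispatch to finish before stopping the timer
    print(f"{label}: {time.perf_counter() - start:.3f}s")
```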
We are now in a position to write our training loop and train the model!
```python
freq = 100

for epoch in range(10):
    total_loss, total_mse, total_kl = 0.0, 0.0, 0.0
    for i, (batch, c) in enumerate(train_loader):
        key, subkey = jax.random.split(key)

        batch = batch.numpy().reshape(batch_size, 784)
        c = c.numpy()

        params, opt_state, loss, mse_loss, kl_loss = train_step(params, opt_state, batch, c, subkey)

        total_loss += loss
        total_mse += mse_loss
        total_kl += kl_loss

        if i > 0 and not i % freq:
            print(f"epoch {epoch} | step {i} | loss: {total_loss / freq} ~ mse: {total_mse / freq}. kl: {total_kl / freq}")
            total_loss = 0.0
            total_mse, total_kl = 0.0, 0.0
```
===
Out:
epoch 0 | step 100 | loss: 49.439998626708984 ~ mse: 49.060447692871094. kl: 0.7591156363487244
epoch 0 | step 200 | loss: 37.1823616027832 ~ mse: 36.82903289794922. kl: 0.7066375613212585
epoch 0 | step 300 | loss: 33.82365036010742 ~ mse: 33.49456024169922. kl: 0.6581906080245972
epoch 0 | step 400 | loss: 31.904821395874023 ~ mse: 31.570871353149414. kl: 0.6679074764251709
epoch 0 | step 500 | loss: 31.095705032348633 ~ mse: 30.763246536254883. kl: 0.6649144887924194
epoch 0 | step 600 | loss: 29.771989822387695 ~ mse: 29.42426872253418. kl: 0.6954278349876404
...
epoch 9 | step 3100 | loss: 14.035745620727539 ~ mse: 10.833460807800293. kl: 6.404574871063232
epoch 9 | step 3200 | loss: 14.31241226196289 ~ mse: 11.043667793273926. kl: 6.53748893737793
epoch 9 | step 3300 | loss: 14.26440143585205 ~ mse: 11.01070785522461. kl: 6.5073771476745605
epoch 9 | step 3400 | loss: 13.96005630493164 ~ mse: 10.816412925720215. kl: 6.28728723526001
epoch 9 | step 3500 | loss: 14.166285514831543 ~ mse: 10.919700622558594. kl: 6.493169784545898
epoch 9 | step 3600 | loss: 13.819541931152344 ~ mse: 10.632755279541016. kl: 6.373570919036865
epoch 9 | step 3700 | loss: 14.452215194702148 ~ mse: 11.186063766479492. kl: 6.532294750213623
Now that we have our `train_step` function, the training loop itself is just repeatedly fetching data, calling our uber-fast `train_step` function, and logging results so we can track training. We can see that the loss is decreasing, which means our model is training!
Note that the KL loss term increases during training. This is okay so long as it doesn't get too high, in which case sampling from the model becomes impossible. Tuning the hyperparameter `kl_weight` is quite important: too low and we get perfect reconstructions but no sampling capabilities – too high and the outputs will become blurry.
Let's sample from the model so we can see that it does indeed produce some reasonable samples:
```python
def build_sample_fn(model, params):
    @jax.jit
    def sample_fn(z: jnp.array, c: jnp.array) -> jnp.array:
        return model.apply(params, z, c, method=model.decode)
    return sample_fn

sample_fn = build_sample_fn(model, params)

num_samples = 100
h, w = 10, 10

key, z_key = jax.random.split(key)
z = jax.random.normal(z_key, (num_samples, latent_dim))
c = np.repeat(np.arange(h)[:, np.newaxis], w, axis=-1).flatten()
c = jax.nn.one_hot(c, num_classes)

sample = sample_fn(z, c)
z.shape, c.shape, sample.shape
```
===
Out: ((100, 32), (100, 10), (100, 784))
The above cell generates 100 samples – 10 examples from each of the 10 classes. We jit-compile our sample function in case we want to sample again later. We only call the `model.decode` method, rather than the full model, as we only need to decode our randomly sampled latents. This is achieved by specifying `method=model.decode` in the `model.apply` call.
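The same `method` argument works for any method of the module. For example, an encoding-only call on a hypothetical batch `x` of flattened images would look like:

```python
# Run only the encoder half of the VAE.
mean, logvar = model.apply(params, x, method=model.encode)
```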
Let's visualise the results using matplotlib:
import matplotlib.pyplot as plt
import math
from numpy import einsum
sample = einsum('ikjl', np.asarray(sample).reshape(h, w, 28, 28)).reshape(28*h, 28*w)
plt.imshow(sample, cmap='gray')
plt.show()
It seems our model did indeed train and can be sampled from! Additionally, the model is capable of using the class conditioning signal so that we can control which digits are generated. Therefore, we have succeeded in building a full training loop using Flax and Optax!
## Extra Flax and Optax Tidbits
I'd like to finish this blog post by highlighting some interesting and useful features that may prove useful in your own applications. I won't delve into great detail with any of them, but simply summarise and point you in the right direction.
You may have noticed already that when we add parameters, optimiser states, and a bunch of other metrics to the return of `train_step`, it gets a bit unwieldy to handle all the state. It could get worse if we later need a more complex state. One solution would be to return a `namedtuple` so we can at least package the state together somewhat. However, Flax provides its own solution, `flax.training.train_state.TrainState`, which has some extra functions that make updating the combined state (model and optimiser state) easier.
It is easiest to show by simply taking our earlier `train_step` and refactoring it with `TrainState`:
```python
from flax.training.train_state import TrainState

def create_train_step(key, model, optimiser):
    params = model.init(key, jnp.zeros((batch_size, 784)), jnp.zeros((batch_size, num_classes)), jax.random.PRNGKey(0))
    state = TrainState.create(apply_fn=model.apply, params=params, tx=optimiser)

    def loss_fn(params, x, c, key):
        reduce_dims = list(range(1, len(x.shape)))
        c = jax.nn.one_hot(c, num_classes)
        recon, mean, logvar = state.apply_fn(params, x, c, key)
        mse_loss = optax.l2_loss(recon, x).sum(axis=reduce_dims).mean()
        kl_loss = jnp.mean(-0.5 * jnp.sum(1 + logvar - mean ** 2 - jnp.exp(logvar), axis=reduce_dims))
        loss = mse_loss + kl_weight * kl_loss
        return loss, (mse_loss, kl_loss)

    @jax.jit
    def train_step(state, x, c, key):
        # differentiate with respect to the parameters only, not the whole state
        losses, grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params, x, c, key)
        loss, (mse_loss, kl_loss) = losses

        state = state.apply_gradients(grads=grads)

        return state, loss, mse_loss, kl_loss

    return train_step, state
```
We begin `create_train_step` by initialising our parameters as before. However, the next step is now to create the state using `TrainState.create`, passing our model's forward call, the initialised parameters, and the optimiser we want to use. Internally, `TrainState.create` will initialise and store the optimiser state for us.
In `loss_fn`, rather than call `model.apply` we can use `state.apply_fn` instead. The two are equivalent; it is just that sometimes we may not have `model` in scope and so can't access `model.apply`.
The largest change is in `train_step` itself. Rather than call `optimiser.update` followed by `optax.apply_updates`, we simply call `state.apply_gradients`, which internally updates the optimiser state and the parameters. It then returns the new state, which we return and pass to the next call of `train_step` – as we would with `params` and `opt_state`.
It is possible to add extra attributes to `TrainState` by subclassing it, for example adding attributes to store the latest loss.
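A minimal sketch of what that might look like, assuming we want to carry a hypothetical `last_loss` field around with the rest of the state:

```python
from flax.training import train_state

class TrainStateWithLoss(train_state.TrainState):
    last_loss: float = 0.0  # hypothetical extra field, e.g. for logging

state = TrainStateWithLoss.create(apply_fn=model.apply, params=params, tx=optimiser, last_loss=0.0)
state = state.replace(last_loss=1.23)  # the state is immutable, so use .replace to update it
```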
In conclusion, `TrainState` makes it easier to pass around state in the training loop, as well as abstracting away the optimiser and parameter updates.
Another useful feature of Flax is the ability to bind parameters to a model, yielding an interactive instance that can be called directly, as if it were a PyTorch model with internal state. This state is static and can only change if we bind it again, which makes it unusable for training; however, it can be handy for interactive debugging or inference.
The API is pretty simple:
```python
key, model_key = jax.random.split(key)

model = nn.Dense(2)
params = model.init(model_key, jnp.zeros(8))

bound_model = model.bind(params)
bound_model(jnp.ones(8))
```
===
Out: Array([ 0.45935923, -0.691003 ], dtype=float32)
We can get back the unbound model and its parameters by calling `model.unbind`:

```python
bound_model.unbind()
```
Out:
(Dense(
    # attributes
    features = 2
    use_bias = True
    dtype = None
    param_dtype = float32
    precision = None
    kernel_init = init
    bias_init = zeros
    dot_general = dot_general
),
FrozenDict({
    params: {
        kernel: Array([[-0.11450272, -0.2808447 ],
               [-0.45104247, -0.3774913 ],
               [ 0.07462895, 0.3622056 ],
               [ 0.59189916, -0.34050766],
               [-0.10401642, -0.36226135],
               [ 0.157985  ,  0.00198693],
               [-0.00792678, -0.1142673 ],
               [ 0.31233454, 0.4201768 ]], dtype=float32),
        bias: Array([0., 0.], dtype=float32),
    },
}))
I said I wouldn't enumerate layers in Flax as I don't see much value in doing
so, but I will highlight two particularly interesting ones. First is
`nn.Dropout` which is numerically the same as its PyTorch counterpart, but like
anything random in JAX, requires a PRNG key as input.
The dropout layer takes its random key by internally calling
`self.make_rng('dropout')`, which pulls and splits from a PRNG stream named
`'dropout'`. This means when we call `model.apply` we will need to define the
starting key for this PRNG stream. This can be done by passing a dictionary
mapping stream names to PRNG keys, to the `rngs` argument in `model.apply`:
```python
key, x_key = jax.random.split(key)
key, drop_key = jax.random.split(key)
x = jax.random.normal(x_key, (3,3))
model = nn.Dropout(0.5, deterministic=False)
y = model.apply({}, x, rngs={'dropout': drop_key}) # there is no state, just pass empty dictionary :)
x, y
```
===
Out: (Array([[ 1.7353934, -1.741734 , -1.3312583],
[-1.615281 , -0.6381292, 1.3057163],
[ 1.2640097, -1.986926 , 1.7818599]], dtype=float32),
Array([[ 3.4707868, 0. , -2.6625166],
[ 0. , 0. , 2.6114326],
[ 0. , -3.973852 , 0. ]], dtype=float32))
`model.init` also accepts a dictionary of PRNG keys. If you pass in a single key, like we have done so far, it starts a stream named `'params'`. This is equivalent to passing `{'params': rng}` instead.
The streams are accessible to submodules, so `nn.Dropout` can call `self.make_rng('dropout')` regardless of where it is in the model. We can define our own PRNG streams by specifying them in the `model.apply` call. In our VAE example, we could forgo passing in the key manually, and instead get keys for random sampling using `self.make_rng('noise')` or similar, then passing a starting key in `rngs` in `model.apply`. For models with lots of randomness, it may be worth doing this.
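As a small illustration (not the VAE above), a hypothetical module can pull keys from its own `'noise'` stream like this:

```python
class NoisyIdentity(nn.Module):
    @nn.compact
    def __call__(self, x):
        # Draw a key from the 'noise' stream supplied via `rngs` (or at init time).
        eps = jax.random.normal(self.make_rng('noise'), x.shape)
        return x + eps

noisy = NoisyIdentity()
noisy_params = noisy.init({'params': jax.random.PRNGKey(0), 'noise': jax.random.PRNGKey(1)}, jnp.ones((2, 3)))
y = noisy.apply(noisy_params, jnp.ones((2, 3)), rngs={'noise': jax.random.PRNGKey(2)})
```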
The second useful built-in module is `nn.Sequential`, which is again like its PyTorch counterpart. This simply chains together many modules, such that the outputs of one module flow into the inputs of the next. Useful if we want to define large stacks of layers quickly.
Now onto some Optax tidbits! First, Optax comes with a bunch of learning rate schedulers. Instead of passing a float value to `learning_rate` when creating the optimiser, we can pass a scheduler. When applying updates, Optax will automatically select the correct learning rate. Let's define a simple, linear schedule:
```python
start_lr, end_lr = 1e-3, 1e-5
steps = 10_000

lr_scheduler = optax.linear_schedule(
    init_value=start_lr,
    end_value=end_lr,
    transition_steps=steps,
)

optimiser = optax.adam(learning_rate=lr_scheduler)
```
You can join schedulers together using `optax.join_schedules` in order to get more complex behaviour, like learning rate warmup followed by decay:
```python
warmup_start_lr, warmup_steps = 1e-6, 1000
start_lr, end_lr, steps = 1e-2, 1e-5, 10_000

lr_scheduler = optax.join_schedules(
    [
        optax.linear_schedule(
            warmup_start_lr,
            start_lr,
            warmup_steps,
        ),
        optax.linear_schedule(
            start_lr,
            end_lr,
            steps - warmup_steps,
        ),
    ],
    [warmup_steps],
)

optimiser = optax.adam(lr_scheduler)
```
The last argument to `optax.join_schedules` should be a sequence of integers defining the step boundaries between the different schedules. In this case, we switch from warmup to decay after `warmup_steps` steps.
Optax keeps track of the number of optimiser steps in its `opt_state`, so we don't need to track this ourselves. It will use this count to automatically pick the correct learning rate.
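Common combinations also exist as single built-in schedules. For instance, if you want warmup followed by cosine decay, a sketch using Optax's built-in schedule might look like this (the values are illustrative):

```python
lr_scheduler = optax.warmup_cosine_decay_schedule(
    init_value=1e-6,     # learning rate at step 0
    peak_value=1e-2,     # learning rate at the end of warmup
    warmup_steps=1000,
    decay_steps=10_000,  # total number of steps the schedule runs for
    end_value=1e-5,
)
optimiser = optax.adam(learning_rate=lr_scheduler)
```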
Similar to joining schedulers, Optax supports chaining optimisers together. More specifically, the chaining of gradient transformations:
```python
optimiser = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.adam(1e-2),
)
```
When calling `optimiser.update`, the gradients will first be clipped before then doing the regular Adam update. Chaining together transformations like this is quite an elegant API and allows for complex behaviour. To illustrate, adding an exponential moving average (EMA) of our updates in something like PyTorch is non-trivial, whereas in Optax it is as simple as adding `optax.ema` to our `optax.chain` call:
```python
optimiser = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.adam(1e-2),
    optax.ema(decay=0.999),
)
```
In this case, `optax.ema` is a transformation on the final updates, rather than on the unprocessed gradients.
Gradient accumulation is implemented in Optax as an optimiser wrapper, rather than as a gradient transformation:
```python
grad_accum = 4
optimiser = optax.MultiSteps(optax.adam(1e-2), grad_accum)
```
The returned optimiser collects updates over successive `optimiser.update` calls until `grad_accum` steps have occurred. In the intermediate steps, the returned updates will be a PyTree of zeros with the same shape as `params`, resulting in no update. Every `grad_accum` steps, the accumulated updates will be returned. `grad_accum` can also be a function, which gives us a way to vary the batch size during training by adjusting the number of steps between parameter updates.
How about if we only want to train certain parameters? For example, when finetuning a pretrained model. Nowadays, this is a pretty common thing to do, taking pretrained large language models and adapting them for specific downstream tasks.
Let's grab a pretrained BERT model from the Huggingface hub:

```python
from transformers import FlaxBertForSequenceClassification

model = FlaxBertForSequenceClassification.from_pretrained('bert-base-uncased')
model.params.keys()
```
Out: dict_keys(['bert', 'classifier'])
> Huggingface provides Flax versions of *most* of their models. The API to use
them is a bit different, calling `model(**inputs, params=params)` rather than
`model.apply`. Providing no parameters will use the pretrained weights stored
in `model.params` which is useful for inference-only tasks, but for training we
need to pass the current parameters to the call.
We can see there are two top-level keys in the parameter PyTree: `bert` and `classifier`. Suppose we only want to finetune the classifier head and leave the BERT backbone alone; we can achieve this using `optax.multi_transform`:
```python
optimiser = optax.multi_transform(
    {'train': optax.adam(1e-3), 'freeze': optax.set_to_zero()},
    {'bert': 'freeze', 'classifier': 'train'},
)
opt_state = optimiser.init(model.params)

grads = jax.tree_map(jnp.ones_like, model.params)
updates, opt_state = optimiser.update(grads, opt_state, model.params)
```
`optax.multi_transform` takes two inputs: the first is a mapping from labels to gradient transformations; the second is a PyTree with the same structure as the updates, or a prefix of it (in the case above we use the prefix approach), mapping to labels. The transformation matching the label of a given update will be applied. This allows us to partition the parameters and apply different updates to different parts.
The second argument can also be a function that, given the updates PyTree, returns such a PyTree mapping updates (or their prefix) to labels.
This can be used for other cases, like having different optimisers for different layers (such as disabling weight decay for certain layers), but in our case we simply use `optax.adam` for the trainable parameters, and zero out the gradients for the other regions using the stateless transform `optax.set_to_zero`.
In a jit-compiled function, the gradients that have `optax.set_to_zero` applied to them won't be computed, due to the optimisation process seeing that they will always be zero. Hence, we get the expected memory savings from only finetuning a subset of layers!
Let's print the updates so that we can see that we do indeed have no updates in the BERT backbone, and do have updates in the classifier head:

```python
updates['classifier'], updates['bert']['embeddings']['token_type_embeddings']
```
Out:
{'bias': Array([-0.00100002, -0.00100002], dtype=float32),
 'kernel': Array([[-0.00100002, -0.00100002],
        [-0.00100002, -0.00100002],
        [-0.00100002, -0.00100002],
        ...,
        [-0.00100002, -0.00100002],
        [-0.00100002, -0.00100002],
        [-0.00100002, -0.00100002]], dtype=float32)}
{'embedding': Array([[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)}
We can verify that all updates are zero using `jax.tree_util.tree_reduce`:
```python
jax.tree_util.tree_reduce(lambda c, p: c and (jnp.count_nonzero(p) == 0), updates['bert'], True)
```
===
Out: Array(True, dtype=bool)
Both Flax and Optax are quite feature-rich despite the relative infancy of the JAX ecosystem. I'd recommend just opening the Flax or Optax API reference and searching for layers, optimisers, loss functions, and features you are used to having in other frameworks.
The last thing I want to talk about involves an entirely different library built on JAX. Orbax provides PyTree checkpointing utilities for saving and restoring arbitrary PyTrees. I won't go into great detail but will show basic usage here. There is nothing worse than spending hours training only to realise you forgot to add checkpointing code!
Here is basic usage saving the BERT classifier parameters:
```python
import orbax
import orbax.checkpoint
from flax.training import orbax_utils

orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
save_args = orbax_utils.save_args_from_target(model.params['classifier'])
orbax_checkpointer.save('classifier.ckpt', model.params['classifier'], save_args=save_args)

!ls
```
===
Out: classifier.ckpt
Which we can restore by executing:

```python
orbax_checkpointer.restore('classifier.ckpt')
```
Out:
{'bias': array([0., 0.], dtype=float32),
 'kernel': array([[-0.06871808, -0.06338844],
        [-0.03397266, 0.00899913],
        [-0.00669084, -0.06431466],
        ...,
        [-0.02699363, -0.03812294],
        [-0.00148801, 0.01149782],
        [-0.01051403, -0.00801195]], dtype=float32)}
Which returns the raw PyTree. If you are using a custom dataclass with objects
that can't be serialised (such as a Flax train state where `apply_fn` and `tx`
can't be serialised) you can pass an example PyTree to `item` in the `restore`
call, to let Orbax know the structure you want.
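For example, a sketch of restoring into a known target structure (the shapes here are just the classifier head's, and the variable names are illustrative):

```python
# An example PyTree with the structure and dtypes we expect back.
restore_target = {'bias': jnp.zeros(2), 'kernel': jnp.zeros((768, 2))}
restored = orbax_checkpointer.restore('classifier.ckpt', item=restore_target)
```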
Manually saving checkpoints like this is a bit old-fashioned. Orbax has a bunch
of automatic versioning and scheduling features built in, such as automatic
deleting of old checkpoints, tracking the best metric, and more. To use these
features, wrap the `orbax_checkpointer` in
`orbax.checkpoint.CheckpointManager`:
```python
options = orbax.checkpoint.CheckpointManagerOptions(max_to_keep=4, create=True)
checkpoint_manager = orbax.checkpoint.CheckpointManager(
    'managed-checkpoint', orbax_checkpointer, options)

for step in range(10):
    checkpoint_manager.save(step, model.params['classifier'], save_kwargs={'save_args': save_args})

!ls -l managed-checkpoint/*
```
===
Out:
managed-checkpoint/6:
total 4
drwxr-xr-x 2 root root 4096 Jun 3 09:07 default
managed-checkpoint/7:
total 4
drwxr-xr-x 2 root root 4096 Jun 3 09:07 default
managed-checkpoint/8:
total 4
drwxr-xr-x 2 root root 4096 Jun 3 09:07 default
managed-checkpoint/9:
total 4
drwxr-xr-x 2 root root 4096 Jun 3 09:07 default
As we set `max_to_keep=4`, only the last four checkpoints have been kept.
We can view which steps have checkpoints:

```python
checkpoint_manager.all_steps()
```
Out: [6, 7, 8, 9]
We can also ask whether the manager would save a new checkpoint at a given step (not whether one already exists):
```python
checkpoint_manager.should_save(6)
```
===
Out: False
And what the latest saved step was:

```python
checkpoint_manager.latest_step()
```
Out: 9
We can restore using the checkpoint manager. Rather than provide a path to the
`restore` function, we provide the step we want to restore:
```python
step = checkpoint_manager.latest_step()
checkpoint_manager.restore(step)
```
===
Out: {'bias': array([0., 0.], dtype=float32),
'kernel': array([[-0.06871808, -0.06338844],
[-0.03397266, 0.00899913],
[-0.00669084, -0.06431466],
...,
[-0.02699363, -0.03812294],
[-0.00148801, 0.01149782],
[-0.01051403, -0.00801195]], dtype=float32)}
For especially large checkpoints, Orbax supports asynchronous checkpointing, which moves checkpointing to a background thread. You can do this by swapping the `orbax.checkpoint.PyTreeCheckpointer` we created earlier for an `orbax.checkpoint.AsyncCheckpointer`.
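A sketch of what that swap might look like (check the Orbax docs for the exact API in your version; `wait_until_finished` blocks until the background save completes):

```python
async_checkpointer = orbax.checkpoint.AsyncCheckpointer(orbax.checkpoint.PyTreeCheckpointHandler())
async_checkpointer.save('classifier_async.ckpt', model.params['classifier'], save_args=save_args)
# ... training continues while the checkpoint is written in the background ...
async_checkpointer.wait_until_finished()
```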
You may see reference online to Flax checkpointing utilities. However, these utilities are being deprecated and it is recommended to start using Orbax instead.
The documentation for Orbax is a bit spartan, but it has a fair few options to choose from. It is worth reading the `CheckpointManagerOptions` class here and seeing the available features.
## Conclusion
In this blog post, I've introduced two libraries built on top of JAX: Flax and Optax. This has been more of a practical guide on how you can implement training loops easily in JAX using these libraries, rather than an ideological discussion like my previous blog post on JAX.
To summarise this post:
- Flax provides a neural network API that allows the developer to build neural network modules in a class-based way. Unlike other frameworks, these modules do not contain state within them; they are essentially hollow shells that loosely associate functions with parameters and inputs, and provide easy methods to initialise the parameters.
- Optax provides a large suite of optimisers for updating our parameters. These, like Flax modules, do not contain state and must have state passed to them manually. All optimisers are simply gradient transformations: a pair of pure functions, `init` and `update`. Optax also provides other gradient transformations and wrappers to allow for more complex behaviour, such as gradient clipping and parameter freezing.
- Both libraries simply operate on and return PyTrees, and can easily interoperate with base JAX – crucially with `jax.jit`. This also makes them interoperable with other libraries based on JAX. For example, by choosing Flax, we aren't locked into using Optax, and vice versa.
There is a lot more to these two libraries than described here, but I hope this is a good starting point that enables you to create your own training loops in JAX. A good exercise now would be to take the training loop and model code in this blog post and adapt it for your own tasks, such as another generative model.
If you liked this post please consider following me on Twitter or visiting my website for more ramblings about machine learning and other topics. Thank you for reading this far and I hope you found it useful!
## Acknowledgements and Extra Resources
Some good extra resources:
- My previous blog post on JAX
- Aleksa Gordic’s JAX and Flax tutorial series
- Flax documentation
- Optax documentation
- Orbax source code
Some alternatives to Flax include Equinox and Haiku.
I am not aware of relatively mature alternatives to Optax. If you know of some, please let me know!
Found something wrong with this blog post? Let me know via email or Twitter!