Introduction to Neural Network in JAX

Matthew Leung
Aug 24, 2022

JAX is a functional computation library developed by Google. It provides primitive building blocks for vectorised computation and automatic differentiation.

The simplest way to describe training a neural network is that it searches for the model parameters that minimize the loss between the actual values and the model predictions (the output of the forward-feed function). Training is therefore an optimization problem, and a common way to solve it is to follow the gradient of the loss function with respect to the parameters. JAX provides automatic differentiation of JAX functions through jax.grad, which makes it a natural fit for neural network training.
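As a minimal sketch of that idea, the toy example below fits a single weight by plain gradient descent using jax.grad. The one-parameter model, the data, the step size of 0.05 and the 100 iterations are all illustrative assumptions, not part of the original article.

```python
import jax
import jax.numpy as jnp

def loss(w, x, y):
    """Mean squared error of a one-parameter linear model pred = w * x."""
    pred = w * x
    return jnp.mean(jnp.square(pred - y))

x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([2.0, 4.0, 6.0])  # toy data generated with w = 2

# jax.grad differentiates with respect to the first argument (w) by default.
grad_fn = jax.grad(loss)

w = 0.0
for _ in range(100):
    w = w - 0.05 * grad_fn(w, x, y)  # one plain gradient descent step

print(w)  # close to 2.0
```

The loss is an ordinary Python function with no hidden state, which is what lets jax.grad differentiate it.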

In order to make use of the JAX automatic differentiation, we have to:

(1) transform the loss function into a stateless, pure JAX function,

(2) separate the model parameters from the model object.

This can be done easily with the Haiku Python library. The core of Haiku is hk.Module and hk.transform.

Remember that our goal is to minimize the loss function by computing its gradient. To achieve that, the loss function must be a pure JAX function.

For a simple regression problem, the loss function can be the mean squared difference between the prediction and the actual value: jnp.mean(jnp.square(pred - y)). JAX provides mean and square functions that mirror their NumPy counterparts. However, the prediction pred is computed by the model's forward-feed function, which must itself be transformed into a pure JAX function.

The forward-feed function is defined in unroll_net, which I will explain in more detail later. The transformation is done by model = hk.transform(unroll_net), which turns unroll_net into an object (model) that exposes two pure functions, init and apply. These two methods separate the model parameters from the stateful neural network model.

The model.init method randomly initializes the model parameters, using the dummy data sample_x to determine their shapes. It does not store the parameters; it returns them as params so that they can be updated later during training.

The model.apply method performs the forward-feed computation given the model parameters params, a random key (None here, because the forward pass uses no randomness), and the input data x. Here is the code:

Transform to pure JAX function:

model = hk.transform(unroll_net)

Initialize model parameters (params):

# Initialize model parameters with dummy data sample_x
rng = jax.random.PRNGKey(428)
params = model.init(rng, sample_x)

Model forward-feed computation:

pred, _ = model.apply(params, None, x)

The final loss function:

def loss(params, x, y):
    pred, _ = model.apply(params, None, x)
    return jnp.mean(jnp.square(pred - y))

As model.apply and jnp.mean/jnp.square are pure JAX functions, we can compute the gradient of this pure JAX loss function. jax.value_and_grad(loss) returns a function that accepts the same arguments as the loss function and returns both the loss value (l) and the gradient (grads) with respect to the first argument, params.

l, grads = jax.value_and_grad(loss)(params, x, y)
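The gradient has the same nested structure as params. The sketch below illustrates this with a hand-written linear model in place of model.apply; the dictionary keys "w" and "b" and the all-ones data are assumptions made for the example.

```python
import jax
import jax.numpy as jnp

def loss(params, x, y):
    # Hand-written linear model standing in for model.apply.
    pred = x @ params["w"] + params["b"]
    return jnp.mean(jnp.square(pred - y))

params = {"w": jnp.zeros((3,)), "b": jnp.zeros(())}
x = jnp.ones((8, 3))
y = jnp.ones((8,))

# Returns the loss value first, then the gradient with respect to params.
l, grads = jax.value_and_grad(loss)(params, x, y)

print(l)           # 1.0, since every prediction is 0 and every target is 1
print(grads["w"])  # same shape as params["w"]
print(grads["b"])  # same shape as params["b"]
```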

The gradient is used to update the model parameters (params) with the Adam optimizer, opt = optax.adam(1e-4).

Initialize optimizer parameters:

opt_state = opt.init(params)

Update model parameters and optimizer parameters:

# opt.update turns raw gradients into parameter updates (scaled by Adam)
updates, opt_state = opt.update(grads, opt_state)
params = optax.apply_updates(params, updates)

The final training loop will repeat the gradient and loss computation, and update the model parameters for each batch of data points.

Now that we understand how model training works in general, let’s drill down into the model definition, beginning with the linear layers of a neural network for classification. A simple forward-feed function just passes the data point (x) through a sequence (mlp) of three linear layers (hk.Linear).

def forward_fn(x: jnp.ndarray) -> jnp.ndarray:
    mlp = hk.Sequential([
        hk.Flatten(),
        hk.Linear(300), jax.nn.relu,
        hk.Linear(100), jax.nn.relu,
        hk.Linear(NUM_CLASSES),
    ])
    return mlp(x)

Apart from the basic linear layer (hk.Linear), Haiku ships many other predefined modules, such as the LSTM. However, the forward-feed function for a recurrent model is more complicated. Here is the definition of the model forward-feed function (unroll_net):

def unroll_net(seqs: jnp.ndarray):
    """Unrolls an LSTM over seqs, mapping each output to a scalar."""
    # seqs is [T, B, F].
    core = hk.LSTM(32)
    batch_size = seqs.shape[1]
    outs, state = hk.dynamic_unroll(
        core, seqs, core.initial_state(batch_size))
    return hk.BatchApply(hk.Linear(1))(outs), state

The core network module is hk.LSTM(32), which has a hidden state of size 32. Because an RNN/LSTM carries hidden memory state from one time step to the next, we cannot just pass the data through the module as in the linear case. Instead, the computation is performed by hk.dynamic_unroll, which calls the core on each element of the input sequence (seqs) in a loop, carrying the state through.

Although Haiku has many predefined common modules, sometimes we need to write our own when none of them fits our application. In that case, a custom module can be defined by subclassing hk.Module, much as a PyTorch module subclasses torch.nn.Module.

We only need to implement two methods: __init__ and __call__, the latter playing the role of forward in PyTorch. Below is the example shown on the Haiku GitHub:

import haiku as hk
import jax.numpy as jnp
import numpy as np

class MyLinear(hk.Module):
    def __init__(self, output_size, name=None):
        super().__init__(name=name)
        self.output_size = output_size
    def __call__(self, x):
        j, k = x.shape[-1], self.output_size
        # Parameters are requested by name instead of being stored on self.
        w_init = hk.initializers.TruncatedNormal(1. / np.sqrt(j))
        w = hk.get_parameter("w", shape=[j, k], dtype=x.dtype, init=w_init)
        b = hk.get_parameter("b", shape=[k], dtype=x.dtype, init=jnp.zeros)
        return jnp.dot(x, w) + b

In order to separate the state from the module, each named model parameter must be retrieved through hk.get_parameter in the forward-feed pass (__call__). Parameters are not stored as object attributes so that the function can be converted into a pure JAX function using hk.transform, as explained at the beginning.

There are many code samples in the Haiku GitHub repository. For example, I can easily modify haiku_lstms.ipynb to create a time series predictor for stock quote data.
