Introduction to JAX (AI Adventures)

Alex, 16 May

NumPy is fast, but how can we make it even faster? In this article, we're going to look at a new library from Google Research called JAX and see how it can speed up machine learning.

JAX can automatically differentiate native Python and NumPy functions. It can differentiate through loops, branches, recursion, and closures, and it can take derivatives of derivatives of derivatives. It supports reverse-mode differentiation (also known as back-propagation) through the grad function, as well as forward-mode differentiation, and the two can be composed arbitrarily, in any order you want. Still, these days it can seem like every other library supports automatic differentiation.
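Here is a minimal sketch of those claims, assuming a recent JAX release and the conventional jnp alias for jax.numpy (the article's own code imports it as np). The functions piecewise and poly are made up for illustration: grad follows an ordinary Python branch and loop, nested grad calls give higher-order derivatives, and jacfwd (forward mode) composes with jacrev (reverse mode) to build a Hessian.

```python
import jax.numpy as jnp
from jax import grad, jacfwd, jacrev

def piecewise(x):
    # An ordinary Python branch: grad differentiates whichever path runs.
    if x < 3:
        return 3.0 * x ** 2
    return -4.0 * x

def poly(x):
    # An ordinary Python loop that builds x + x**2 + x**3.
    total = 0.0
    for k in range(1, 4):
        total = total + x ** k
    return total

d_low = grad(piecewise)(2.0)    # derivative of 3x^2 at x=2 is 6*2 = 12
d_high = grad(piecewise)(4.0)   # derivative of -4x is -4
# Derivatives of derivatives: the third derivative of x + x^2 + x^3 is 6.
d3 = grad(grad(grad(poly)))(1.5)

# Forward mode over reverse mode gives the Hessian of sum(v**3), i.e. diag(6v).
hessian = jacfwd(jacrev(lambda v: jnp.sum(v ** 3)))
H = hessian(jnp.array([1.0, 2.0]))
print(d_low, d_high, d3)
print(H)
```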

So what else can JAX do? It can also speed up your code, sometimes significantly, by running it through a special compiler under the hood.

Accelerated Linear Algebra, or XLA, is a domain-specific compiler for linear algebra. It can perform optimizations such as fusing operations together, so that intermediate results don't need to be written out to memory and read back in. Instead of that time-consuming round trip, the data is streamed straight into the next operation, which makes processing faster and more efficient.
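To see what XLA actually produces, you can ask JAX for the compiled program of a small function. This sketch assumes a recent JAX release that provides the ahead-of-time API jax.jit(f).lower(...).compile(); chain is a hypothetical example function. On most backends its elementwise operations are fused into a single kernel, which typically shows up as a fusion computation in the printed HLO, but the exact text varies with JAX version and hardware, so treat the printout as illustrative.

```python
import jax
import jax.numpy as jnp

def chain(x):
    # A chain of elementwise ops; unfused, each would write a full intermediate array.
    return jnp.exp(jnp.sin(x) * 2.0 + 1.0)

x = jnp.linspace(0.0, 1.0, 1_000_000)
compiled = jax.jit(chain).lower(x).compile()
hlo_text = compiled.as_text()   # post-optimization HLO produced by XLA
print(hlo_text)

# The compiled program computes the same values as running op by op.
same = jnp.allclose(compiled(x), chain(x), rtol=1e-5, atol=1e-5)
print(same)
```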

JAX uses XLA to compile and run your NumPy programs on GPUs and TPUs. Compilation happens transparently, so NumPy-style library calls get sped up without any changes on your part. But JAX goes a step further than XLA alone: it lets you just-in-time compile your very own Python functions into XLA-optimized kernels, using a single-function API called jit.
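A minimal sketch of jit applied to your own Python function, assuming a recent JAX release (selu, the scaled exponential linear unit, is just an example). The first call traces the Python code and compiles it with XLA; later calls with the same input shapes and dtypes reuse the compiled kernel. Because JAX dispatches work asynchronously, block_until_ready() is used to wait for the actual result.

```python
import jax
import jax.numpy as jnp
from jax import jit

def selu(x, alpha=1.67, lmbda=1.05):
    # Elementwise SELU activation written with jax.numpy.
    return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

selu_jit = jit(selu)  # same function, compiled into one XLA computation

x = jax.random.normal(jax.random.PRNGKey(0), (1_000_000,))
y_eager = selu(x).block_until_ready()
y_jit = selu_jit(x).block_until_ready()  # first call: trace + compile + run
print(jnp.max(jnp.abs(y_eager - y_jit)))
```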

Compilation and automatic differentiation can be composed arbitrarily, so you can express sophisticated algorithms and get maximum performance, all without leaving Python.
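As a sketch of that composition, assuming a recent JAX release: the least-squares loss below is a hypothetical example, not the MNIST loss used later. grad produces a gradient function, jit compiles it, and the result can be checked against the closed-form gradient 2 X^T (X w - y). Composing in the other order, grad(jit(loss)), would work as well.

```python
import jax.numpy as jnp
from jax import grad, jit, random

def loss(w, X, y):
    # Least-squares loss: sum of squared residuals.
    residual = X @ w - y
    return jnp.sum(residual ** 2)

# Compose the transformations: differentiate, then compile the gradient.
fast_grad = jit(grad(loss))

key_x, key_y, key_w = random.split(random.PRNGKey(0), 3)
X = random.normal(key_x, (32, 3))
y = random.normal(key_y, (32,))
w = random.normal(key_w, (3,))

g = fast_grad(w, X, y)
g_closed_form = 2.0 * X.T @ (X @ w - y)   # analytic gradient for comparison
print(g, g_closed_form)
```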

What else is there besides jit? There's also pmap. When you apply pmap, the function you write is compiled by XLA, just like with jit, and is then replicated and executed in parallel across devices; that's what the p in pmap stands for. This means you can run computations on multiple GPUs or TPU cores all at once with pmap, and then differentiate through all of them.
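A minimal pmap sketch, assuming a JAX release that still provides jax.pmap. It runs on however many local devices are present, including a single CPU, so the leading axis of the input must equal jax.local_device_count(). The function sum_of_squares and the axis name "devices" are illustrative choices: each device computes the gradient of its own slice, and jax.lax.psum then sums values across devices.

```python
import jax
import jax.numpy as jnp

n_devices = jax.local_device_count()
# One row per device: pmap maps over the leading axis.
xs = jnp.arange(n_devices * 4, dtype=jnp.float32).reshape(n_devices, 4)

def sum_of_squares(x):
    return jnp.sum(x ** 2)

# Compiled like jit, replicated, and run on every device in parallel.
per_device_grads = jax.pmap(jax.grad(sum_of_squares))(xs)   # 2 * xs

# Collective across devices: every device receives the sum of all rows.
summed = jax.pmap(lambda x: jax.lax.psum(x, axis_name="devices"),
                  axis_name="devices")(xs)
print(per_device_grads)
print(summed)
```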

At its core, JAX is an extensible system for composable function transformations. The main ones today are:

  • grad
  • jit
  • pmap
  • vmap

The vmap transformation is used for automatic vectorization: with a single wrapper function, it turns a function that can handle only one data point into a function that can handle a batch of data points of any size.
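A minimal sketch of what that single wrapper buys you (the function score and its inputs are hypothetical). score handles one example; vmap with in_axes=(None, 0) shares the weight vector across the batch and maps over the first axis of the inputs, giving the same result as an explicit Python loop.

```python
import jax.numpy as jnp
from jax import vmap, random

def score(weights, example):
    # Written for a single example: a vector of shape (3,).
    return jnp.tanh(jnp.dot(weights, example))

weights = jnp.array([0.5, -1.0, 2.0])
batch = random.normal(random.PRNGKey(0), (8, 3))   # 8 examples

# None: don't map over weights; 0: map over axis 0 of batch.
batched_score = vmap(score, in_axes=(None, 0))
vectorized = batched_score(weights, batch)
looped = jnp.stack([score(weights, ex) for ex in batch])
print(vectorized.shape)
```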

Let's take a look at how this all comes together using a familiar example, training a deep neural network on the MNIST data set.

import jax.numpy as np
from jax import grad, jit, vmap
from jax import random

# A helper function to randomly initialize weights and biases
# for a dense neural network layer
def random_layer_params(m, n, key, scale=1e-2):
    w_key, b_key = random.split(key)
    return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))

# Initialize all layers for a fully-connected neural network with sizes "sizes"
def init_network_params(sizes, key):
    keys = random.split(key, len(sizes))
    return [random_layer_params(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]

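layer_sizes = [784, 512, 512, 10]
param_scale = 0.1   # not used below; random_layer_params keeps its default scale=1e-2
step_size = 0.0001  # learning rate for gradient descent in update()
num_epochs = 10
batch_size = 128
n_targets = 10
params = init_network_params(layer_sizes, random.PRNGKey(0))

# Print the shape of each layer's weights and biases
for W, b in params:
    print(f'Weights: {W.shape}, bias: {b.shape}')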

This code starts by defining two utility functions that build a neural network with randomly initialized parameters. I printed out the dimensions of each layer for convenience.
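Unlike NumPy's global random state, JAX's random functions take an explicit key, which is why random_layer_params and init_network_params pass keys around with random.split. A small sketch using the same random.PRNGKey API as the code above: reusing a key reproduces the same numbers, while split keys give independent streams.

```python
from jax import random

key = random.PRNGKey(0)
a = random.normal(key, (3,))
b = random.normal(key, (3,))          # same key, so identical numbers

w_key, b_key = random.split(key)       # two new, independent keys
c = random.normal(w_key, (3,))
d = random.normal(b_key, (3,))

keys = random.split(key, 4)            # several keys at once, like init_network_params
print(a, b)
print(c, d)
print(len(keys))
```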

> Weights: (512, 784), bias: (512,)
> Weights: (512, 512), bias: (512,)
> Weights: (10, 512), bias: (10,)

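We can see that the network takes a 784-unit input (a flattened 28 by 28 image) and passes it through two hidden layers of 512 units each. The output has the usual 10 classes, since we're predicting which digit (0 through 9) appears in the image. Note that each weight matrix is stored as (output units, input units), so a layer computes np.dot(W, x) + b.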
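As a quick worked check of the architecture: a dense layer from m inputs to n outputs stores an (n, m) weight matrix plus n biases, so the total parameter count follows directly from layer_sizes. This plain-Python sketch repeats the list so it runs on its own.

```python
layer_sizes = [784, 512, 512, 10]

# Each layer: n*m weights (stored as shape (n, m)) plus n biases.
per_layer = [n * m + n for m, n in zip(layer_sizes[:-1], layer_sizes[1:])]
total_params = sum(per_layer)
print(per_layer)      # [401920, 262656, 5130]
print(total_params)   # 669706
```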

Next, let's take a look at predict, the function that runs a single image through our network.

from jax.scipy.special import logsumexp

def relu(x):
    return np.maximum(0, x)

def predict(params, image):
    # per-example predictions
    activations = image
    for w, b in params[:-1]:
        outputs = np.dot(w, activations) + b
        activations = relu(outputs)
  
    final_w, final_b = params[-1]
    logits = np.dot(final_w, activations) + final_b
    # Subtracting logsumexp normalizes the logits into log-probabilities (log-softmax)
    return logits - logsumexp(logits)

random_flattened_image = random.normal(random.PRNGKey(1), (28 * 28,))
preds = predict(params, random_flattened_image)
print(preds.shape)

> (10,)

Our predict function handles only one image at a time. We can confirm this by passing in a single random image of the correct dimensions, which gives us a vector of size 10: the log-probabilities of the 10 classes, computed from the logits coming out of the final layer of the network.
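Because predict subtracts logsumexp(logits), its outputs are log-probabilities (a log-softmax) rather than raw logits. A small check with made-up logits: exponentiating the result gives probabilities that sum to 1, the values match jax.nn.log_softmax, and the argmax is unchanged, which is why the accuracy function can still take the argmax of these outputs.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

logits = jnp.array([2.0, -1.0, 0.5, 3.0])   # arbitrary example logits
log_probs = logits - logsumexp(logits)      # same normalization as predict

total_prob = jnp.sum(jnp.exp(log_probs))
print(total_prob)                           # approximately 1.0
print(jnp.argmax(log_probs))                # index 3, same as argmax(logits)
```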

# Doesn't work with a batch
random_flattened_images = random.normal(random.PRNGKey(1), (10, 28 * 28))
try:
    preds = predict(params, random_flattened_images)
except TypeError:
    print('Invalid shapes!')

> Invalid shapes!

But when we try a whole batch of images, say 10 of them, it fails, since the array dimensions no longer line up. Luckily, wrapping our predict function in vmap lets us take advantage of matrix multiplication: the per-example matrix-vector products become matrix-matrix products, so all 10 images run through the model in a single pass rather than one by one.

# Let's upgrade it to handle batches using `vmap`

# Make a batched version of the `predict` function:
# in_axes=(None, 0) shares params across the batch and maps over axis 0 of the images
batched_predict = vmap(predict, in_axes=(None, 0))

# `batched_predict` has the same call signature as `predict`
batched_preds = batched_predict(params, random_flattened_images)
print(batched_preds.shape)

> (10, 10)

The resulting function can handle a batch of arbitrary size, and we don't have to modify predict one bit. Notice that the output is now 10 by 10: the 10 class log-probabilities for each of the 10 examples in the batch.
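To see that vmap really turns per-example matrix-vector products into a batched product, you can inspect the traced program with jax.make_jaxpr. In this sketch, layer is a hypothetical single dense layer written for one input vector; after vmap, the jaxpr contains a dot_general over the whole batch rather than a Python-level loop, and the result matches the hand-batched X @ W.T + b. The exact jaxpr text depends on the JAX version.

```python
import jax
import jax.numpy as jnp
from jax import vmap, random

def layer(W, b, x):
    # One example: W has shape (n, m), x has shape (m,).
    return jnp.dot(W, x) + b

k_w, k_x = random.split(random.PRNGKey(0))
W = random.normal(k_w, (4, 3))
b = jnp.ones(4)
X = random.normal(k_x, (10, 3))            # a batch of 10 examples

batched_layer = vmap(layer, in_axes=(None, None, 0))
jaxpr_text = str(jax.make_jaxpr(batched_layer)(W, b, X))
print(jaxpr_text)

out = batched_layer(W, b, X)               # shape (10, 4)
manual = X @ W.T + b                       # the same computation batched by hand
```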

Now let's see how we can use the grad and jit functions to build out the remainder of our model, as well as our training code.

def one_hot(x, k, dtype=np.float32):
  """Create a one-hot encoding of x of size k."""
  return np.array(x[:, None] == np.arange(k), dtype)
  
def accuracy(params, images, targets):
  # Fraction of examples whose predicted class matches the one-hot target
  target_class = np.argmax(targets, axis=1)
  predicted_class = np.argmax(batched_predict(params, images), axis=1)
  return np.mean(predicted_class == target_class)

def loss(params, images, targets):
  # Cross-entropy summed over the batch (preds are log-probabilities)
  preds = batched_predict(params, images)
  return -np.sum(preds * targets)

def update(params, x, y):
  # One step of gradient descent on every (weights, bias) pair
  grads = grad(loss)(params, x, y)
  return [(w - step_size * dw, b - step_size * db)
          for (w, b), (dw, db) in zip(params, grads)]

We add a function to one-hot encode our labels, plus a couple more functions to calculate accuracy and loss. Finally, we put together our update function, which applies grad to the loss function: grad takes care of the back-propagation and returns the gradient of the loss with respect to every parameter, and update uses those gradients to take a gradient-descent step and return the updated parameters of the model. We're almost ready to run our model. We just need some code that uses TensorFlow Datasets to bring in the MNIST data set.
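One detail worth seeing in isolation: grad(loss) returns gradients with exactly the same nested structure as params (a list of (w, b) tuples), which is what lets update zip the two together. This sketch uses a tiny hypothetical one-layer model and shows an equivalent way to write the step with jax.tree_util.tree_map; it is an alternative formulation, not the article's code.

```python
import jax
import jax.numpy as jnp
from jax import grad

step_size = 0.1

def tiny_loss(params, x, y):
    # Hypothetical model: a single dense layer with squared-error loss.
    w, b = params[0]
    pred = jnp.dot(w, x) + b
    return jnp.sum((pred - y) ** 2)

params = [(jnp.ones((2, 3)), jnp.zeros(2))]
x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([1.0, -1.0])

grads = grad(tiny_loss)(params, x, y)   # same structure: [(dw, db)]

# Gradient descent on every leaf of the parameter tree at once.
new_params = jax.tree_util.tree_map(lambda p, g: p - step_size * g, params, grads)
same_structure = (jax.tree_util.tree_structure(grads)
                  == jax.tree_util.tree_structure(params))
print(same_structure)
```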

import tensorflow_datasets as tfds

data_dir = '/tmp/tfds'

# Fetch full datasets for evaluation
# tfds.load returns tf.Tensors (or tf.data.Datasets if batch_size != -1)
# You can convert them to NumPy arrays (or iterables of NumPy arrays) with tfds.as_numpy
mnist_data, info = tfds.load(name="mnist", batch_size=-1, data_dir=data_dir, with_info=True)
mnist_data = tfds.as_numpy(mnist_data)
train_data, test_data = mnist_data['train'], mnist_data['test']
num_labels = info.features['label'].num_classes
h, w, c = info.features['image'].shape
num_pixels = h * w * c

# Full train set
train_images, train_labels = train_data['image'], train_data['label']
train_images = np.reshape(train_images, (len(train_images), num_pixels))
train_labels = one_hot(train_labels, num_labels)

# Full test set
test_images, test_labels = test_data['image'], test_data['label']
test_images = np.reshape(test_images, (len(test_images), num_pixels))
test_labels = one_hot(test_labels, num_labels)

JAX purposely does not include data set loading functionality, as it's focused on program transformations and accelerator-backed NumPy.

So now we're ready to train our model.

print('Train:', train_images.shape, train_labels.shape)
print('Test:', test_images.shape, test_labels.shape)

> Train: (60000, 784) (60000, 10)
> Test: (10000, 784) (10000, 10)

Our training loop is set for 10 epochs, and we've added a timer as well, because we want to see how it performs. So let's run this.

import time

def get_train_batches():
  # as_supervised=True gives us the (image, label) as a tuple instead of a dict
  ds = tfds.load(name='mnist', split='train', as_supervised=True, data_dir=data_dir)
  # You can build up an arbitrary tf.data input pipeline
  ds = ds.batch(batch_size).prefetch(1)
  # tfds.as_numpy converts the tf.data.Dataset into an iterable of NumPy arrays
  return tfds.as_numpy(ds)

for epoch in range(num_epochs):
  start_time = time.time()
  for x, y in get_train_batches():
    x = np.reshape(x, (len(x), num_pixels))
    y = one_hot(y, num_labels)
    params = update(params, x, y)
  epoch_time = time.time() - start_time

  train_acc = accuracy(params, train_images, train_labels)
  test_acc = accuracy(params, test_images, test_labels)
  print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
  print("Training set accuracy {}".format(train_acc))
  print("Test set accuracy {}".format(test_acc))

And we can see that across 10 epochs, we ended up spending about 34 seconds per epoch (the first epoch took a bit longer, at about 41 seconds).

Epoch 0 in 41.37 sec
Training set accuracy 0.9998833537101746
Test set accuracy 0.9827000498771667
Epoch 1 in 34.83 sec
Training set accuracy 0.9998833537101746
Test set accuracy 0.9830000400543213
Epoch 2 in 34.31 sec
Training set accuracy 0.999916672706604
Test set accuracy 0.9830000400543213
Epoch 3 in 34.52 sec
Training set accuracy 0.999916672706604
Test set accuracy 0.9830000400543213
Epoch 4 in 34.43 sec
Training set accuracy 0.999916672706604
Test set accuracy 0.9831000566482544
Epoch 5 in 34.01 sec
Training set accuracy 0.9999499917030334
Test set accuracy 0.9832000732421875
Epoch 6 in 34.16 sec
Training set accuracy 0.9999666810035706
Test set accuracy 0.9833000302314758
Epoch 7 in 34.21 sec
Training set accuracy 0.9999666810035706
Test set accuracy 0.9834000468254089
Epoch 8 in 34.39 sec
Training set accuracy 1.0
Test set accuracy 0.9834000468254089
Epoch 9 in 35.13 sec
Training set accuracy 1.0
Test set accuracy 0.9834000468254089

You might be thinking: wait, wait, didn't Yufeng mention something about using the jit function? Did we ever add that? Good catch. Let's add the @jit decorator above our update function, rename it jit_update, and call jit_update in the training loop in place of update.

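# jit compiles the whole step (forward pass, back-propagation, and parameter
# update) into a single XLA computation
@jit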
def jit_update(params, x, y):
  grads = grad(loss)(params, x, y)
  return [(w - step_size * dw, b - step_size * db)
          for (w, b), (dw, db) in zip(params, grads)]
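It helps to know what @jit does on each call: it traces the Python function once per distinct combination of input shapes and dtypes, compiles that trace with XLA, and caches the result. In the MNIST loop this likely means two compilations of jit_update, because 60,000 is not a multiple of 128 and the last batch of each epoch holds only 96 images. The sketch below makes tracing visible with a Python counter, which only runs while JAX is tracing; scaled_sum and trace_count are hypothetical names.

```python
import jax
import jax.numpy as jnp

trace_count = 0

@jax.jit
def scaled_sum(x):
    global trace_count
    trace_count += 1          # Python side effect: runs only during tracing
    return jnp.sum(x * 2.0)

scaled_sum(jnp.ones((128, 784)))   # traced and compiled
scaled_sum(jnp.ones((128, 784)))   # cache hit: same shape and dtype
scaled_sum(jnp.ones((96, 784)))    # new shape, so traced and compiled again
print(trace_count)                 # 2
```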

Rerunning the same training loop with jit_update gives us a before-and-after comparison.

Epoch 0 in 14.44 sec
Training set accuracy 0.9742833375930786
Test set accuracy 0.9678000211715698
Epoch 1 in 13.17 sec
Training set accuracy 0.9841333627700806
Test set accuracy 0.9734000563621521
Epoch 2 in 13.10 sec
Training set accuracy 0.9886000156402588
Test set accuracy 0.976300060749054
Epoch 3 in 13.16 sec
Training set accuracy 0.9919500350952148
Test set accuracy 0.9783000349998474
Epoch 4 in 13.09 sec
Training set accuracy 0.9937333464622498
Test set accuracy 0.9790000319480896
Epoch 5 in 13.30 sec
Training set accuracy 0.9961166977882385
Test set accuracy 0.9806000590324402
Epoch 6 in 13.20 sec
Training set accuracy 0.9974666833877563
Test set accuracy 0.9808000326156616
Epoch 7 in 13.29 sec
Training set accuracy 0.9983500242233276
Test set accuracy 0.9815000295639038
Epoch 8 in 13.14 sec
Training set accuracy 0.9984666705131531
Test set accuracy 0.9819000363349915
Epoch 9 in 13.18 sec
Training set accuracy 0.9994833469390869
Test set accuracy 0.9828000664710999

So this is taking way less time: about 13 seconds per epoch, compared with about 34 seconds before. And all we had to do was add four characters, @jit, above the update function. Now that's "le-jit."
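If you want to reproduce a before-and-after comparison like this on a smaller scale, two details matter: JAX dispatches work asynchronously, so you should call block_until_ready() before stopping the clock, and the first call to a jitted function includes tracing and compilation, so it is best treated as a warm-up. A minimal sketch with a made-up elementwise function f; the actual speedup depends on your hardware and JAX version, so no particular number is implied.

```python
import time
import jax
import jax.numpy as jnp

def f(x):
    # Many small elementwise ops: a good candidate for XLA fusion.
    for _ in range(10):
        x = jnp.sin(x) * 0.9 + 0.1
    return x

f_jit = jax.jit(f)
x = jnp.ones((1000, 1000))

def time_call(fn, *args, repeats=5):
    fn(*args).block_until_ready()              # warm-up (compiles f_jit)
    start = time.perf_counter()
    for _ in range(repeats):
        fn(*args).block_until_ready()          # wait for async dispatch
    return (time.perf_counter() - start) / repeats

eager_seconds = time_call(f, x)
jit_seconds = time_call(f_jit, x)
print(f"eager: {eager_seconds:.4f} s, jit: {jit_seconds:.4f} s")
```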

Before I close this out, I want to remind you that JAX is still a research project and not an official Google product, so you may encounter bugs and sharp edges. The team has even put together a list of gotchas and a gotchas notebook to help you out. Since this list will keep evolving, be sure to check the current state of things if you plan on using JAX for your project.
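Two of the best-known sharp edges, shown as a minimal sketch assuming a recent JAX release (the function names are hypothetical). JAX arrays are immutable, so in-place assignment raises an error and the functional .at[...].set(...) form is used instead. And inside jit, Python control flow cannot depend on the values of traced arrays; jnp.where (or marking the argument as static) is the usual workaround.

```python
import jax
import jax.numpy as jnp

x = jnp.arange(5.0)

# Sharp edge 1: JAX arrays are immutable.
try:
    x[0] = 10.0
    in_place_worked = True
except TypeError:
    in_place_worked = False
y = x.at[0].set(10.0)          # returns a new array; x is unchanged

# Sharp edge 2: a Python `if` on a traced value fails under jit.
@jax.jit
def relu_branch(v):
    if v > 0:
        return v
    return jnp.zeros_like(v)

try:
    relu_branch(jnp.array(2.0))
    branch_worked = True
except jax.errors.ConcretizationTypeError:
    branch_worked = False

@jax.jit
def relu_where(v):
    # Select elementwise instead of branching in Python.
    return jnp.where(v > 0, v, 0.0)

print(in_place_worked, branch_worked, relu_where(jnp.array([-1.0, 2.0])))
```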
