Image Classification using PyTorch Lightning

Overview

This tutorial article teaches the reader how to build an image classification model based on deep neural networks using PyTorch Lightning.

We will use PyTorch Lightning to build a convolutional neural network-based architecture to classify the images in the CIFAR-10 dataset.

Along the way, we cover the essential components of PyTorch Lightning, a framework for deep learning engineers and researchers that aims to give flexibility in building machine learning systems at scale while removing much of the boilerplate code.

What Are We Building?

In this article, we will build a classifier based on Convolutional Neural Networks (CNNs), a foundational computer vision architecture.

We will use the API provided by PyTorch Lightning to define our data pipeline, which covers downloading and accessing the dataset and creating the appropriate splits and dataloader instances. We will also define the training components, namely training, validating, and testing the model, and save our trained model as a checkpoint.

We will also use a third-party logger provided by Weights & Biases (wandb), which is one of the loggers supported by PyTorch Lightning.

For the full list of supported third-party loggers, refer to the PyTorch Lightning documentation.

Pre-requisites

This article assumes that the reader has a basic familiarity with the following:

  • CNN - A conceptual understanding of convolutional neural network architectures and of the Conv2d API that PyTorch provides for them. Refer to this article to get an understanding of CNNs in PyTorch.
  • PyTorch - It is assumed that the reader is familiar with the basic functioning of the deep learning library PyTorch, as PyTorch Lightning is based on it.
  • PyTorch Lightning - An introductory understanding of PyTorch Lightning will help the reader get the most out of this article. Refer to this article for an introduction to PyTorch Lightning and this article to learn how to migrate from PyTorch to PyTorch Lightning.

How Are We Going to Build This?

In this tutorial, we will use two core classes from PyTorch Lightning: the DataModule, to define a data pipeline for downloading and accessing the CIFAR-10 dataset, and the LightningModule, to define and train our convolutional neural network.

We will also leverage PyTorch Lightning's Trainer class to automate the logging and training process and save the trained model checkpoint.

Final Output

The final output of this tutorial will be a trained model checkpoint, saved automatically by PyTorch Lightning, for classifying the images in the CIFAR-10 dataset.

Requirements

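Let us first install PyTorch Lightning, after which we can import all the necessary dependencies. With pip, the command is:

    pip install pytorch-lightning

To install using the conda package manager instead, use the following command:

    conda install pytorch-lightning -c conda-forge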

Build an Image Classification Model with PyTorch Lightning

We will now define some classes required to build and train an image classification model using PyTorch Lightning while describing the use of each class and its methods progressively.

DataModule - The Data Pipeline

DataModules in PyTorch Lightning are a way to decouple the model definition and training from the data accessing and loading components of a machine learning project.

A DataModule therefore encapsulates all the steps needed to process data in one shareable and reusable class.

Broadly, five steps are encapsulated by this class:

  1. Download, tokenize, or otherwise access the data
  2. Preprocess the data and optionally save it to the disk
  3. Load the dataset as a PyTorch Dataset instance
  4. Apply transformations to the data
  5. Create PyTorch DataLoader instances for the different splits of the data (train, validation, and test)

We will now define our custom DataModule inheriting from the base class pl.LightningDataModule.

Let us go through its methods step by step.

__init__

This method is used to define the hyperparameters required for the above five steps, such as the batch size and the transformations to be applied to the data.

prepare_data

This function is called only within a single process on the CPU. It is used for tasks like downloading the data, tokenizing the data, and so on - essentially, all the steps required to prepare the data to make it available for use are to be carried out in this method.

For our project, we will download the CIFAR-10 dataset in this method, after which it will be available for use in the specified directory.

setup

This method runs on every process, and hence on every GPU in a multi-GPU setting, unlike prepare_data, which runs in a single process. It is therefore suited to operations such as creating dataset instances, building a vocabulary, or counting the number of classes.

Here, we will use this method to read the data from the directory it was downloaded to and create a PyTorch dataset for each split (train, validation, and test).

Note that this method expects a stage argument, which is used to separate the setup logic for trainer.{fit,validate,test,predict}.

train_dataloader

This method wraps the training dataset instance created in setup in PyTorch's DataLoader class. The resulting dataloader instance is then used by the fit() method of the Trainer class that we will look at shortly.

The method simply returns this dataloader instance.

The other two methods, val_dataloader and test_dataloader, are defined in the same way to create the dataloader instances for the validation and test datasets, respectively.

LightningModule

LightningModule is used to organize (not abstract) six major sections of our PyTorch code in one single class. The six sections are as follows:

  • Computations (__init__)
  • Train Loop (training_step)
  • Validation Loop (validation_step)
  • Test Loop (test_step)
  • Prediction Loop (predict_step)
  • Optimizers and LR Schedulers (configure_optimizers)

This class defines our model architecture along with the training, validation, testing, etc., loops. The optimization algorithm is also naturally included in this class as the training loop needs it.

The Trainer class automates whatever the LightningModule does not define itself, for example, iterating over the batches produced by the dataloader.

We will now work through the different methods of the LightningModule class step by step and understand which section is implemented by which method, and how.

Model Computations

Our custom LightningModule inherits from the base class pl.LightningModule. We use the __init__ method to define the model architecture much as it is done in plain PyTorch: we define four Conv2d layers with two pooling layers and a classifier head consisting of fully connected layers.

Similarly, a forward method is also defined here, just as it is done in PyTorch.

The difference lies in its usage and purpose: in PyTorch Lightning, forward is conventionally reserved for inference, as the training loop is defined using training_step, which we will look at next.

We also define two helper methods. _feature_extractor returns the output of the convolution block, and _get_output_shape returns the size of that output. This size is used to define the classifier head, a stack of feed-forward layers whose final layer has an output size of 10, equal to the number of classes.

Loops

The training, validation, and testing loops can be defined by overriding the training_step, validation_step, and test_step methods, respectively.

In all three methods, we also log the metrics epoch-wise by passing on_epoch=True to the self.log method.

The three loops are defined as methods of our LightningModule.

The training_step method requires batch and batch_idx as arguments which the Trainer takes care of automatically.

The testing loop is very similar to the validation loop; the only difference is that it is called only when trainer.test() is used.

Optimization

The configure_optimizers method can be used to define our optimization algorithm and learning rate schedulers - it even allows us to define multiple optimizers for networks like GANs.

In its simplest form, this method returns an instance of an optimizer class from the torch.optim package.

The Trainer Class - Train and Evaluate

We will now bring all the components together to set up the training for our model.

In particular, we will instantiate the DataModule and LightningModule classes we defined above and use the PyTorch Lightning Trainer that automates the following steps for us -

  • Epoch and batch iteration for training, testing etc.
  • Calling optimizer.step(), loss.backward(), and optimizer.zero_grad()
  • Calling model.eval() and enabling/disabling gradients during evaluation
  • Saving and loading model checkpoints
  • Logging through a third-party logger (wandb in our case)
  • Multi-GPU training support
  • TPU support
  • 16-bit training support

We will first instantiate these classes along with our third-party logger.

We will then create a Trainer instance using appropriate arguments.

Note that we are using a built-in callback called pl.callbacks.ModelCheckpoint, which is used by the Trainer to save the trained model checkpoint automatically.

After that, we will call trainer.fit, passing the model instance along with the DataModule instance to train the model, and similarly call trainer.test to evaluate our trained model.

These steps make up the complete training script.

Once training and testing finish, the test metrics are logged automatically and the model checkpoint is saved in the specified path.

We can now load the saved model from that location with a single line of code, using the load_from_checkpoint method of our LightningModule class.

The loaded model can now be used for inference or deployment.

Conclusion

In this tutorial, we learned to use PyTorch Lightning for building an image classification system for the CIFAR-10 dataset. Let us review what we learned in this article:

  • We learned about PyTorch Lightning's DataModule class and defined our custom class suited to our dataset and task.
  • Then we learned about PyTorch Lightning's LightningModule class and defined our custom class, understanding the purpose of each method therein.
  • After this, we brought these two classes together and used another class called the Trainer to train and evaluate our model. We also learned how PyTorch Lightning automatically saves model checkpoints and how those can be loaded back for inference.
  • Throughout the project, we also used a third-party logger by Weights and Biases to log the metrics.