Writing a Custom Training Loop

Overview

Training deep neural networks is a crucial and time-consuming part of building AI-powered applications. This article walks through the components of an efficient custom training loop in PyTorch and shows how PyTorch's API support simplifies each of them.

Introduction

After deciding on the model architecture, the next step, and a crucial part of building any deep learning-based system or product, is model training.

Model training is an iterative process in which we compute the model's output and quantify how much it differs from the expected output (also called the target) using a suitable criterion, known in machine learning as the loss function.

The loss value is then used to backpropagate the errors: the gradients of the loss function with respect to the model parameters are calculated, and an optimization algorithm uses these gradients to update the model parameters.
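For plain gradient descent, the update in the last step can be written as below. The notation is introduced here for illustration: θ stands for the model parameters, η for the learning rate, f_θ(x) for the model output on a mini-batch x, y for the targets, and 𝓛 for the loss. Optimizers such as Adam, chosen later in this article, modify this basic step by rescaling it per parameter.

$$\theta \leftarrow \theta - \eta \, \nabla_{\theta} \, \mathcal{L}\big(f_{\theta}(x),\, y\big)$$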

To this end, we want our training routine to be as efficient and reusable as possible. By reusable, we mean that the same script can be used for different architectures with minimal modifications.

How to Write a Custom Training Loop?

Setting it Up

Let us first import all the dependencies like so:

Now, we will download the MNIST dataset using torchvision.datasets, thus creating Dataset objects for the training and testing divisions. We will also wrap these Dataset objects inside the DataLoader class to create DataLoader instances for both divisions for easy and efficient data loading.

We also define some hyperparameter variables to be used later in the code.
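Illustrative hyperparameter values; none of them come from the original code, and they are a reasonable starting point rather than tuned settings.

```python
# Hyperparameters (illustrative values, not tuned)
batch_size = 64          # samples per mini-batch
input_size = 28 * 28     # each MNIST image is flattened to a 784-dimensional vector
hidden_size = 128        # width of the hidden linear layers
num_classes = 10         # digits 0-9
dropout_p = 0.2          # probability of zeroing an activation during training
learning_rate = 1e-3     # step size for the optimizer
num_epochs = 5           # full passes over the training set
```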

Defining a Simple Model

To build the classifier, we will define a simple feed-forward neural network consisting of three linear layers (the input layer, the hidden layer, and the output layer) and two more layers: a ReLU activation layer to introduce non-linearity and a dropout layer for regularization.

Although a feed-forward neural network is not the best-suited network for image classification, we keep things simple here to focus primarily on the training loop's components.

The model:
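A minimal sketch of the simple model, written as a plain Python class (not yet an `nn.Module`) that keeps its layers in a list. The hidden width of 128, the dropout probability of 0.2, and the ordering of the layers are assumptions. Following the description there is a single ReLU, so the last two linear layers have no activation between them.

```python
import torch
from torch import nn

class MNISTClassifier:
    """Plain Python class (not an nn.Module) that keeps its layers in a list."""

    def __init__(self, input_size=784, hidden_size=128, num_classes=10, dropout_p=0.2):
        self.layers = [
            nn.Linear(input_size, hidden_size),   # input layer
            nn.ReLU(),                            # non-linearity
            nn.Dropout(dropout_p),                # regularization
            nn.Linear(hidden_size, hidden_size),  # hidden layer
            nn.Linear(hidden_size, num_classes),  # output layer: one logit per class
        ]

    def __call__(self, x):
        x = x.view(x.size(0), -1)  # flatten (N, 1, 28, 28) -> (N, 784)
        for layer in self.layers:
            x = layer(x)
        return x

model = MNISTClassifier()
print(model(torch.randn(4, 1, 28, 28)).shape)  # torch.Size([4, 10])
```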

After specifying the custom model architecture, we instantiate it and define a loss function and an optimization algorithm to train the network.

We choose cross-entropy loss as the loss function, since this is a multi-class, single-label problem, and the Adam optimizer as the optimization algorithm.
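The snippet below shows what `nn.CrossEntropyLoss` expects for a multi-class, single-label problem: raw logits of shape (batch, classes) and one integer class index per sample. It also checks that the loss equals the mean negative log-softmax score of the correct class, which is why the model needs no softmax layer of its own. The random logits and the targets are made up for illustration.

```python
import torch
from torch import nn

loss_fn = nn.CrossEntropyLoss()

# Unnormalized scores (logits) for a batch of 3 samples over 10 classes
logits = torch.randn(3, 10)
# Exactly one integer class index per sample: single-label classification
targets = torch.tensor([3, 0, 7])

loss = loss_fn(logits, targets)  # averaged over the batch, a 0-dim tensor

# Same value computed by hand: log-softmax, pick the target class, negate, average
manual = -torch.log_softmax(logits, dim=1)[torch.arange(3), targets].mean()
print(torch.allclose(loss, manual))  # True
```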

Taking all these components together, we will now write a custom training loop from scratch.

Writing the Custom Training Loop from Scratch

In the training loop, we translate into code what we described in words as training a neural network: compute the model output, measure the loss against the targets, backpropagate, and update the parameters.
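A from-scratch sketch of that loop. Random tensors shaped like MNIST batches stand in for the real data loader so it runs offline, the layers are kept in a plain list as described for the simple model (sizes assumed), and the update is plain gradient descent with an assumed learning rate of 0.1 rather than Adam. The roles of `loss.backward()` and `torch.no_grad()` are explained in the next two subsections.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Random stand-in data (not MNIST) so the sketch runs without a download
images = torch.randn(256, 1, 28, 28)
labels = torch.randint(0, 10, (256,))
train_loader = DataLoader(TensorDataset(images, labels), batch_size=64, shuffle=True)

layers = [nn.Linear(784, 128), nn.ReLU(), nn.Dropout(0.2),
          nn.Linear(128, 128), nn.Linear(128, 10)]

def forward(x):
    x = x.view(x.size(0), -1)  # flatten to (N, 784)
    for layer in layers:
        x = layer(x)
    return x

loss_fn = nn.CrossEntropyLoss()
learning_rate = 0.1
num_epochs = 3

for epoch in range(num_epochs):
    total = 0.0
    for xb, yb in train_loader:
        preds = forward(xb)           # 1. forward pass
        loss = loss_fn(preds, yb)     # 2. loss against the targets
        loss.backward()               # 3. fill .grad of every parameter
        with torch.no_grad():         # 4. manual update, not tracked by autograd
            for layer in layers:
                if hasattr(layer, "weight"):  # only the Linear layers have parameters
                    layer.weight -= learning_rate * layer.weight.grad
                    layer.bias -= learning_rate * layer.bias.grad
                    layer.weight.grad.zero_()  # .grad accumulates, so reset it
                    layer.bias.grad.zero_()
        total += loss.item()
    print(f"epoch {epoch + 1}: mean loss {total / len(train_loader):.4f}")
```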

Understanding loss.backward()

PyTorch's automatic differentiation engine, autograd, keeps track of every operation on tensors that require gradients, building a computation graph of all the operations those tensors undergo.

In a neural network, the parameter tensors (weights and biases) are what we want to optimize. By default, these are leaf tensors with requires_grad set to True. As explained above, we optimize these parameters with respect to a criterion called the model loss. The loss is a function of the model output and the target tensor, and the model output is, in turn, a function of all the model parameters.

The target tensor is fixed, so it plays no part in the optimization process that adjusts and updates the network, and it does not require gradients.

Hence, the model parameters are the leaf tensors in the graph of the loss, and their requires_grad attribute is set to True.

When we call .backward() on the loss tensor, autograd traverses the graph recorded during the forward pass in reverse. It computes the gradients of the loss (the tensor on which backward is called) with respect to the leaf tensors (the model parameters) and accumulates them into each leaf's grad attribute.

These gradients are what we need to update the model parameters; we access them through weight.grad and bias.grad to update the weights and biases, respectively.
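A small worked example of this behavior, using two hand-made leaf tensors `w` and `b` in place of model parameters (the values are arbitrary). With output = w·x + b = 2.5 and loss = (output - target)² = 2.25, the gradient with respect to w is 2(output - target)x = 3x. The last lines show that a second backward pass adds to `.grad` instead of overwriting it.

```python
import torch

# Leaf tensors playing the role of model parameters
w = torch.tensor([2.0, -1.0], requires_grad=True)
b = torch.tensor(0.5, requires_grad=True)

x = torch.tensor([3.0, 4.0])   # input: no gradient needed
target = torch.tensor(1.0)     # fixed target: no gradient needed

output = (w * x).sum() + b     # 2*3 + (-1)*4 + 0.5 = 2.5
loss = (output - target) ** 2  # (2.5 - 1)^2 = 2.25

print(w.is_leaf, output.is_leaf)  # True False
loss.backward()

# d(loss)/dw = 2 * (output - target) * x = 3 * x
print(w.grad)       # tensor([ 9., 12.])
print(b.grad)       # tensor(3.)
print(target.grad)  # None: the target never required gradients

# Gradients accumulate: a second backward pass adds to .grad
loss2 = ((w * x).sum() + b - target) ** 2
loss2.backward()
print(w.grad)       # tensor([18., 24.])
```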

Understanding torch.no_grad()

torch.no_grad() is a context manager that tells the autograd engine to look away. Whatever operations are performed under torch.no_grad() are not tracked by autograd and hence are not included in any computation graph.

When updating the model, we perform the parameter updates and zero out the gradients inside the no_grad context manager. This ensures that autograd does not record the updates, so they never become part of a computation graph. It is also required: modifying a leaf tensor that requires gradients in place outside no_grad raises a RuntimeError. The update only changes the parameter values and needs no gradient of its own. Zeroing matters because backward() accumulates gradients into grad rather than overwriting them.
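A minimal sketch of the update step on a single hand-made tensor `w` (the learning rate of 0.1 is arbitrary): the in-place update is rejected outside `torch.no_grad()`, succeeds inside it, and the gradient is zeroed afterwards.

```python
import torch

w = torch.tensor([1.0, 2.0], requires_grad=True)
loss = (w ** 2).sum()
loss.backward()                # w.grad = 2 * w = [2., 4.]

# Outside no_grad, an in-place update of a leaf that requires grad is rejected
try:
    w -= 0.1 * w.grad
    raised = False
except RuntimeError:
    raised = True
print("in-place update outside no_grad raised:", raised)  # True

lr = 0.1
with torch.no_grad():
    w -= lr * w.grad           # applied to the values, not recorded by autograd
    w.grad.zero_()             # reset the accumulated gradient for the next step
    y = w * 3                  # any tensor created here has requires_grad=False

print(w)                       # tensor([0.8000, 1.6000], requires_grad=True)
print(y.requires_grad)         # False
```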

Training with Abstract Components in the Loop

The above training loop works, but it is clumsy: we need to loop over every layer and access its weights and biases separately before updating them.

We would like a way to loop over all the parameters in one go, for example with model.parameters(). We will look at how to do just that next.

We first rework our model class so that it keeps track of all its layers, making every parameter accessible through model.parameters().

In the new model class, we override Python's __setattr__ dunder method, which is called every time an attribute is assigned. We use it to record every model layer in a dictionary, which the parameters method later walks to create a generator that yields the model parameters.

We can now access the model parameters, like so:
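A sketch of such a class; the attribute names, the `_layers` dictionary, and the layer sizes are assumptions. It is a toy imitation of the bookkeeping that `nn.Module` performs, not PyTorch's actual implementation.

```python
import torch
from torch import nn

class MNISTClassifier:
    """Plain class that records its layers via __setattr__ (a toy nn.Module)."""

    def __init__(self, input_size=784, hidden_size=128, num_classes=10, dropout_p=0.2):
        self._layers = {}  # private name: skipped by __setattr__ below
        self.linear1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout_p)
        self.linear2 = nn.Linear(hidden_size, hidden_size)
        self.linear3 = nn.Linear(hidden_size, num_classes)

    def __setattr__(self, name, value):
        # Runs on every attribute assignment; remember public attributes as layers
        if not name.startswith("_"):
            self._layers[name] = value
        super().__setattr__(name, value)

    def parameters(self):
        # Generator over the parameters of every recorded layer, in assignment order
        for layer in self._layers.values():
            yield from layer.parameters()

    def __call__(self, x):
        x = x.view(x.size(0), -1)
        x = self.dropout(self.relu(self.linear1(x)))
        return self.linear3(self.linear2(x))

model = MNISTClassifier()
for p in model.parameters():
    print(p.shape)  # (128, 784), (128,), (128, 128), (128,), (10, 128), (10,)
```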

PyTorch already implements all of this: we only need to inherit from the base class nn.Module and call super().__init__() first in the constructor.

With that said, our custom model class now subclasses nn.Module, calls super().__init__() at the start of its constructor, and assigns each layer as an attribute.

We can now access all the layers of our model using model.named_children(), like so:
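A sketch of the `nn.Module` version, with layer names and sizes assumed. `named_children()` yields `(name, module)` pairs for the direct submodules in the order they were assigned, and `parameters()` now comes from `nn.Module` itself.

```python
import torch
from torch import nn

class MNISTClassifier(nn.Module):
    def __init__(self, input_size=784, hidden_size=128, num_classes=10, dropout_p=0.2):
        super().__init__()  # must run first: sets up nn.Module's internal registries
        self.linear1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout_p)
        self.linear2 = nn.Linear(hidden_size, hidden_size)
        self.linear3 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        x = x.flatten(start_dim=1)  # (N, 1, 28, 28) -> (N, 784)
        x = self.dropout(self.relu(self.linear1(x)))
        return self.linear3(self.linear2(x))

model = MNISTClassifier()
for name, layer in model.named_children():
    print(name, "->", layer)  # e.g. linear1 -> Linear(in_features=784, out_features=128, bias=True)
```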

We can also abstract away some components of our previous training loop by accessing the parameters directly through model.parameters(), like so:
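A sketch of the simplified loop as a reusable function. The names `train_manual_sgd` and `TinyClassifier` are hypothetical, the data are random stand-ins for MNIST, and the update is still plain gradient descent. The per-layer weight and bias updates collapse into one loop over `model.parameters()`, and `model.zero_grad()` resets every gradient in a single call (recent PyTorch versions set them to `None` by default rather than filling them with zeros; either way the next `backward()` starts fresh).

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

def train_manual_sgd(model, loader, loss_fn, lr, num_epochs):
    """Plain gradient descent that relies only on model.parameters()."""
    for epoch in range(num_epochs):
        for xb, yb in loader:
            loss = loss_fn(model(xb), yb)
            loss.backward()
            with torch.no_grad():
                for p in model.parameters():  # every weight and bias, no per-layer code
                    p -= lr * p.grad
                model.zero_grad()             # reset gradients on all parameters
        print(f"epoch {epoch + 1}: last batch loss {loss.item():.4f}")

class TinyClassifier(nn.Module):  # hypothetical small stand-in for the MNIST model
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(784, 128)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(128, 10)

    def forward(self, x):
        return self.linear2(self.relu(self.linear1(x.flatten(start_dim=1))))

torch.manual_seed(0)
data = TensorDataset(torch.randn(256, 1, 28, 28), torch.randint(0, 10, (256,)))
model = TinyClassifier()
train_manual_sgd(model, DataLoader(data, batch_size=64, shuffle=True),
                 nn.CrossEntropyLoss(), lr=0.1, num_epochs=2)
```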

If we store the layers in a plain Python list, as we did at the beginning, model.parameters() cannot find them, because nn.Module only registers submodules that are assigned directly as attributes. To make them visible, we need to register each layer explicitly in the module's _modules using add_module, like so:
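A sketch of registering a list of layers with `add_module`; the `layer_{i}` naming scheme and the layer sizes are assumptions. The `Unregistered` class illustrates the problem described above: storing the same list without registration leaves `parameters()` empty.

```python
import torch
from torch import nn

class MNISTClassifier(nn.Module):
    def __init__(self, layers):
        super().__init__()
        self.layers = layers  # a plain list: nn.Module does not look inside it
        for i, layer in enumerate(layers):
            self.add_module(f"layer_{i}", layer)  # explicit registration in self._modules

    def forward(self, x):
        x = x.flatten(start_dim=1)
        for layer in self.layers:
            x = layer(x)
        return x

class Unregistered(nn.Module):
    def __init__(self, layers):
        super().__init__()
        self.layers = layers  # same list, never registered

layers = [nn.Linear(784, 128), nn.ReLU(), nn.Dropout(0.2),
          nn.Linear(128, 128), nn.Linear(128, 10)]
model = MNISTClassifier(layers)
print(len(list(model.parameters())))                 # 6: weight and bias of 3 Linear layers
print(len(list(Unregistered(layers).parameters())))  # 0
```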

PyTorch also offers the nn.ModuleList container, to which we can pass a list of modules (layers) so that each one is registered automatically, like so:
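The same model with `nn.ModuleList`, with layer sizes assumed. `ModuleList` registers its contents but, unlike `nn.Sequential`, has no `forward` method of its own, so the loop over the layers stays in our `forward`.

```python
import torch
from torch import nn

class MNISTClassifier(nn.Module):
    def __init__(self, layers):
        super().__init__()
        self.layers = nn.ModuleList(layers)  # registers every layer automatically

    def forward(self, x):
        x = x.flatten(start_dim=1)
        for layer in self.layers:  # ModuleList is iterable like a Python list
            x = layer(x)
        return x

model = MNISTClassifier([nn.Linear(784, 128), nn.ReLU(), nn.Dropout(0.2),
                         nn.Linear(128, 128), nn.Linear(128, 10)])
print(len(list(model.parameters())))            # 6
print(model(torch.randn(2, 1, 28, 28)).shape)   # torch.Size([2, 10])
```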

We could also use the nn.Sequential container for this, which additionally calls the layers one after another in its forward pass, like so:
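With `nn.Sequential` the whole model fits in one expression, since its forward pass runs the layers in the order given. `nn.Flatten` takes over the reshaping that the custom classes did by hand; the layer sizes are assumptions.

```python
import torch
from torch import nn

model = nn.Sequential(
    nn.Flatten(),            # (N, 1, 28, 28) -> (N, 784)
    nn.Linear(784, 128),     # input layer
    nn.ReLU(),
    nn.Dropout(0.2),
    nn.Linear(128, 128),     # hidden layer
    nn.Linear(128, 10),      # output layer
)
print(len(list(model.parameters())))            # 6
print(model(torch.randn(2, 1, 28, 28)).shape)   # torch.Size([2, 10])
```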

To finish with a few more abstractions, we make use of the optimization algorithms offered by PyTorch's torch.optim API. Specifically, we define an optimizer instance and use its step method to perform all the parameter updates and its zero_grad method to reset the gradients.
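A small check of what `optimizer.step()` does, on a toy `nn.Linear(4, 2)` layer with random data. `torch.optim.SGD` without momentum or weight decay is used here because its step is exactly the manual update `p - lr * p.grad`. The article's choice, Adam, is created the same way with `torch.optim.Adam(model.parameters(), lr=...)`, but it rescales each step using running statistics of the gradients, so it would not match this manual formula.

```python
import torch
from torch import nn

torch.manual_seed(0)
layer = nn.Linear(4, 2)
x, y = torch.randn(8, 4), torch.randint(0, 2, (8,))
lr = 0.1

optimizer = torch.optim.SGD(layer.parameters(), lr=lr)

loss = nn.functional.cross_entropy(layer(x), y)
optimizer.zero_grad()   # clear any stale gradients
loss.backward()

# What the manual loop would compute for each parameter
expected = [(p - lr * p.grad).detach() for p in layer.parameters()]

optimizer.step()        # applies the update to every parameter it was given

print(all(torch.allclose(p, e) for p, e in zip(layer.parameters(), expected)))  # True
```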

Our final training loop, along with all the other components, then looks like the following:
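A sketch of the final routine under stated assumptions: random stand-in data instead of the MNIST loaders, the `nn.Sequential` form of the model, Adam with a learning rate of 1e-3, and three epochs. The helper names `train_one_epoch`, `evaluate`, and `fake_loader` are hypothetical. `model.train()` and `model.eval()` switch dropout on for training and off for evaluation.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

def train_one_epoch(model, loader, loss_fn, optimizer):
    model.train()                         # dropout active
    total_loss = 0.0
    for xb, yb in loader:
        optimizer.zero_grad()             # clear gradients from the previous step
        loss = loss_fn(model(xb), yb)     # forward pass and loss
        loss.backward()                   # backpropagate
        optimizer.step()                  # update all parameters
        total_loss += loss.item() * xb.size(0)
    return total_loss / len(loader.dataset)

@torch.no_grad()                          # no graph is needed for evaluation
def evaluate(model, loader):
    model.eval()                          # dropout disabled
    correct = 0
    for xb, yb in loader:
        correct += (model(xb).argmax(dim=1) == yb).sum().item()
    return correct / len(loader.dataset)

def fake_loader(n, shuffle):
    # Random tensors shaped like MNIST batches, so the sketch runs offline
    data = TensorDataset(torch.randn(n, 1, 28, 28), torch.randint(0, 10, (n,)))
    return DataLoader(data, batch_size=64, shuffle=shuffle)

torch.manual_seed(0)
train_loader, test_loader = fake_loader(512, True), fake_loader(128, False)

model = nn.Sequential(nn.Flatten(), nn.Linear(784, 128), nn.ReLU(), nn.Dropout(0.2),
                      nn.Linear(128, 128), nn.Linear(128, 10))
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(3):
    train_loss = train_one_epoch(model, train_loader, loss_fn, optimizer)
    acc = evaluate(model, test_loader)
    print(f"epoch {epoch + 1}: train loss {train_loss:.4f}, test accuracy {acc:.3f}")
```

To train on the real data, replace the fake loaders with the DataLoader objects built from torchvision.datasets.MNIST; the rest of the routine stays unchanged, which is what makes it reusable across architectures.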

With these abstractions in place, our training loop is much cleaner and easier to reuse. It also leaves plenty of room to experiment with different optimization algorithms and different hyperparameter values, such as the number of epochs and the learning rate.

That was all about how to write a custom training routine in PyTorch to train deep neural networks.

Conclusion

This article discussed one of the most crucial aspects of building deep neural networks: training the models. Training is how models learn patterns from the data. In particular:

  • We briefly reviewed the model training process and where it fits in a deep learning modeling pipeline.
  • We downloaded the MNIST dataset and created data loaders for its training and testing splits.
  • After this, we wrote a training loop from scratch, implementing in a verbose, manual way what we know from theory.
  • Then, we gradually made our training routine more concise and reusable. Finally, we wrote a simple and clean version of it using PyTorch's optimization API, torch.optim.