Dataset Class in PyTorch

Overview

This article is a hands-on tutorial explaining how to write a custom PyTorch Dataset class and use it along with the PyTorch DataLoader class to preprocess data points and get the data ready to feed into neural networks for training. It demonstrates complete implementations of a custom dataset class for two different types of data - an image dataset and a text dataset.

Introduction

Garbage in, garbage out -- This is what deep learning models, or any modelling algorithm for that matter, live by.

To be able to pre-process our data and load it in a form that is ready to be fed into our neural network models, PyTorch provides us with two separate classes, namely -

  • PyTorch Dataset class - This is the one that deals with fetching the data (from the source) and preprocessing it, and hence eventually gets the data into the form the neural network requires for training.
  • DataLoader class - This is the one that deals with the parallel, and hence efficient, loading of data in batches to feed into the models, enabling easy access to the data samples.

With that discussed, some natural questions arise, such as -

  • Why do we even need separate classes for these tasks?
  • Why can't the work handled by the PyTorch Dataset class and the DataLoader class simply be folded into the model training pipeline?

These and similar questions are crucial to answer - in short, to improve readability and modularity, we want the code that processes data points to be separated from the code that deals with model training.

Further, keeping these concerns in two separate classes makes the code easier to maintain and keeps it from getting messy.

Another interesting dimension is that we do not want the preprocessing operations to become a bottleneck - the model should not have to sit waiting while the data gets preprocessed. The DataLoader class ensures this; we will see how later in the article.

The Dataset Class

PyTorch offers support for two different types of datasets:

  • Map-style datasets
  • Iterable-style datasets.

Iterable-style Datasets

  • Such a form of the dataset is used when the data comes from some form of a stream. An iterable-style dataset can be achieved as an instance of a subclass of IterableDataset that implements the __iter__() method.

  • These types of datasets represent an iterable of data samples and are particularly suited for cases where reading the data randomly is costly and/or improbable, and where the batch size depends on the data fetched.

  • See IterableDataset for more details on iterable-style datasets; a minimal sketch is shown below. Going forward, we will deal only with the second category - map-style datasets - which we define next.
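For intuition, here is a minimal sketch of an iterable-style dataset. The class name, file and parsing logic are purely illustrative - the key point is that the subclass implements __iter__() and yields samples one by one as they arrive from the stream:

```python
from torch.utils.data import IterableDataset, DataLoader

class LogStreamDataset(IterableDataset):
    def __init__(self, file_path):
        self.file_path = file_path  # the stream source (illustrative)

    def __iter__(self):
        # Samples are produced lazily, one at a time, as the stream is read -
        # there is no notion of indexing into the data here.
        with open(self.file_path) as f:
            for line in f:
                yield line.strip()

# stream_loader = DataLoader(LogStreamDataset("events.log"), batch_size=32)
```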

Map-style Datasets

  • A map-style dataset simply represents a map from indices/keys to data samples. The indices need not be integers - they can also be non-integral keys. We will talk about how integral and non-integral indices are handled differently later in the article.
  • torch.utils.data.Dataset is the class providing the prototype for map-style datasets. It is an abstract class implementing dedicated methods to deal with different steps in the data pipeline.

Let's first look at what abstract classes are while thoroughly understanding the structure of the PyTorch Dataset class.

Abstract Class

  • A superclass acts like a blueprint for its child/subclasses. The subclasses inherit the properties of the superclass.
  • An abstract class is a class that includes a set of methods called abstract methods. These methods are called abstract because they must be implemented within any subclass of the parent abstract class.
  • Another point to note about abstract methods is that they are only declared in the parent abstract class, not implemented there.

Hence, essentially, abstract methods are methods that are implemented through sub-classes.

  • Python has a dedicated module called abc for abstract classes. The @abstractmethod decorator of the abc module tells Python that the declared method is abstract and must be overridden in the subclasses. We only need to put this decorator over the method we want to be treated as abstract, and the abc module takes care of the rest. A small example follows.
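To make this concrete, here is a tiny, self-contained example (unrelated to PyTorch) of how the @abstractmethod decorator enforces implementation in subclasses:

```python
from abc import ABC, abstractmethod

class Animal(ABC):
    @abstractmethod
    def sound(self):
        # Only declared here - every subclass must provide an implementation.
        ...

class Dog(Animal):
    def sound(self):
        return "woof"

print(Dog().sound())  # works, because 'sound' is implemented
# Animal()            # raises TypeError: can't instantiate abstract class Animal
```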

As for torch.utils.data.Dataset being an abstract class, it has the following abstract method -

  • __getitem__()

That said, any class inheriting from torch.utils.data.Dataset must mandatorily implement the __getitem__() method.

torch.utils.data.Dataset

Now, let us look at the basic structure of the PyTorch dataset class torch.utils.data.Dataset and understand the mechanics of constructing a custom PyTorch Dataset class by subclassing it.

Here's the most basic structure of the PyTorch Dataset class torch.utils.data.Dataset - source
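Paraphrased below - the exact source carries extra typing and docstrings and differs slightly between PyTorch versions:

```python
from torch.utils.data import ConcatDataset

class Dataset(object):
    def __getitem__(self, index):
        # Abstract - subclasses must override this.
        raise NotImplementedError

    def __add__(self, other):
        # Concatenating two datasets with `+` yields a ConcatDataset.
        return ConcatDataset([self, other])

    # Note: no default __len__ is defined; subclasses are expected to provide
    # one so that samplers and the DataLoader can query the dataset size.
```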

Correspondingly, any custom PyTorch Dataset class should have a skeleton structure like the following -
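A representative skeleton (class and argument names are placeholders) looks like this:

```python
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    def __init__(self, arg1, arg2):
        # Store file paths, dataframes, transforms, tokenizers etc. here.
        ...

    def __len__(self):
        # Return the total number of samples in the dataset.
        ...

    def __getitem__(self, idx):
        # Fetch, preprocess and return the sample (and label) at index idx.
        ...
```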

The first point to note is that any custom dataset class should inherit from the PyTorch dataset class torch.utils.data.Dataset. Now, let us understand the general purpose of each method.

__init__(self, arg1, arg2, ...)

  • This method is meant to initialize some important variables that shall be used to get the data ready, like file path, root directory etc. Or, to define the image transformations, tokenizing strategy for text and so on.

Note: The actual data (say, the actual images for an image dataset) should not be read directly in this method; the reads should instead be deferred to __getitem__(). This is called lazy loading of data. Real-world datasets are huge, and it is rarely possible to fit all of the data in memory at once, which is why lazy loading is recommended.

__len__(self)

This method simply returns the length of the dataset.

  • While it is not strictly an abstract method, it is recommended that every subclass of the Dataset class implement it. The reason is that many implementations of the torch.utils.data.Sampler class and the default options of the torch.utils.data.DataLoader class expect this method to return the size of the dataset.

__getitem__(self, idx)

This method is the real deal.

  • This is where the "actual data" (both input feature data and the corresponding labels) are fetched in the memory, preprocessed with suitable steps and hence transformed into a form that's ready to be fed into the model for training.
  • Ideally, no data pre-processing steps should be coded anywhere else in the model training pipeline other than in this method. To summarise, the input features and the output labels are returned by this method in a form that is ready to be fed directly into the model.
  • One important point to remember about this method - although the data and labels it returns are in a form suitable to feed directly into the model, this in no way implies that the method should return the data after putting it on the GPUs. Doing that could lead to CUDA out-of-memory errors.

Note also: as we can see, the above methods all have one thing in common - their names begin and end with double underscores.

These actually belong to a category of methods called dunder methods. We will shortly look more into dunder methods while focussing on their utility when it comes to the PyTorch Dataset class.

Loading a Dataset in PyTorch

Now that we have gone through the mechanics of creating our custom dataset class in PyTorch, it is time to implement that in code.

But before that, let us understand the details of the other primitive that we talked about earlier - torch.utils.data.DataLoader.

DataLoader Class

  • The DataLoader class enables parallel data generation in the form of batches. This is necessary as we do not want the data loading part to be a bottleneck in our model training pipeline.

  • As for being a bottleneck, we are talking about a sequential pipeline where a batch of data is first fetched, preprocessed and made compatible with what the model expects; only then is it fed to the model for training.

    When the model is done processing one batch, it waits for the next batch to get ready - in essence, the GPUs and the model sit idle.

We'd rather want this process to be parallel in the sense that -

  • Ideally, the model and the GPUs should not have to wait till the time another batch of data gets fetched, preprocessed and made available for the model to use for training.
  • In other words, we would want to generate the batches of data in parallel as the model trains - this means that as soon as the model is done processing the current batch of data, a new batch of data should be instantly available for the model to process.

Hence, instead of a sequential process, we want this to be a parallel process.

This is ensured by the DataLoader class. The DataLoader class utilizes multiple "sub-processes" to generate data on multiple CPU cores in parallel. This way, batches of data are readily available for the model to use, and the "main process" (the one that deals with the model training part) does not need to bother with data loading.

PyTorch DataLoader

Let us look at the syntax of the DataLoader class (sketched below) while going through some of its parameters -
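An abridged form of the constructor is sketched below; defaults can differ slightly between PyTorch versions, and only the parameters discussed here are listed:

```python
from torch.utils.data import DataLoader

# Abridged signature - only the parameters discussed below are shown:
#
# DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
#            num_workers=0, pin_memory=False, drop_last=False,
#            pin_memory_device='')
```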

  • dataset - an instance of our custom Dataset class

  • batch_size - defines how many samples per batch to load

  • shuffle (bool, optional) – set to True to have the data reshuffled at every epoch (default: False).

  • num_workers (int, optional) – defines how many subprocesses to use for data loading. 0 means that the data will be loaded in the main process. (default: 0)

  • pin_memory (bool, optional) – If True, the data loader will copy Tensors into device/CUDA pinned memory before returning them.

  • drop_last (bool, optional) – set to True to drop the last incomplete batch, if the dataset size is not divisible by the batch size. If False and the size of the dataset is not divisible by the batch size, then the last batch will be smaller. (default: False) We might want to drop the smaller batch (the last one) as it might lead to noisy updates to our model parameters.

  • pin_memory_device (str, optional) – the dataloader will copy Tensors into the device's pinned memory before returning them if pin_memory is set to True. This later allows faster transfer of data from CPU to GPU.

Image Dataset

Now, let us practically implement a custom Dataset class for an image dataset. We will work with this Pawpularity dataset from Kaggle.

The training csv file contains an Id column identifying each image, a set of photo metadata columns, and the target Pawpularity column.

For the sake of brevity, we will work with only two columns for now as we write our custom Dataset class implementation, namely -

  • Id column - contains the names of jpg images in the train folder. This acts as our feature variable X.
  • Pawpularity column - contains the corresponding target values. This is Y - what we predict.

Let us one by one implement all the methods that we previously defined in our custom dataset class template -

  • __init__
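A sketch of how this might look, assuming the class receives the training dataframe, the image folder path and a transform pipeline (the exact arguments are a design choice, not prescribed by PyTorch):

```python
import os
import torch
from PIL import Image
from torch.utils.data import Dataset

class PawpularityImageData(Dataset):
    def __init__(self, df, root_dir, transform=None):
        self.id = df['Id']               # image file names (feature X)
        self.labels = df['Pawpularity']  # target values (label Y)
        self.root_dir = root_dir         # directory holding the jpg images
        self.transform = transform       # image transforms applied in __getitem__
```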

As discussed earlier, the __init__ method is used to define some standard variables. Here, we define id as the column of the train data frame df that contains the image ids.

Similarly, the variable labels holds the target column of the train data frame df, called 'Pawpularity'. The variable root_dir stores the path to the directory where the images are stored.

All of these variables shall be used later in the code when we fetch actual images from their paths.

  • __len__
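Continuing the same class, a minimal implementation:

```python
    def __len__(self):
        # The number of image ids is the number of samples in the dataset.
        return len(self.id)
```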

This method simply returns the length of the id column which is also the length of our dataset.

  • __getitem__
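A sketch along the lines described below; whether the '.jpg' extension needs to be appended depends on how the Id column stores the file names:

```python
    def __getitem__(self, idx):
        # 1. Build the full path to the idx-th image using root_dir.
        img_path = os.path.join(self.root_dir, self.id[idx] + '.jpg')

        # 2. Read the actual image lazily, only when it is requested.
        img = Image.open(img_path)

        # 3. Ensure a 3-channel RGB image, then apply the transforms.
        img = img.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)

        # 4. Return the image tensor and the label as a float32 tensor.
        label = torch.tensor(self.labels[idx], dtype=torch.float32)
        return img, label
```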

Let us quickly discuss in points what's happening in this piece of code -

  1. This method first uses the root_dir variable from earlier to create the complete path to the image in the row idx.
  2. Using this path, the actual image is read using the Image module from PIL and stored in the variable img.
  3. Now, img is converted to RGB mode, after which a set of transformations is applied to the image using the variable self.transform to make it compatible with what the model expects. We will see these transformations shortly.
  4. Finally, the actual image (not just its path or name) and the label (the Pawpularity value) corresponding to the particular idx are returned from the method.

We will shortly understand the significance of this method in the data-loading process.

The hidden caveats?

  • The conversion of images to RGB mode is necessary because the transformations we apply via self.transform expect a 3-channel RGB image.
  • Before returning a label, we transform it into a PyTorch tensor using torch.tensor(self.labels[idx], dtype = torch.float32). This is necessary as tensor is the core data structure when working with models in PyTorch. Consequently, models need the data to be in tensor format only. We will also ensure that the image data is transformed into PyTorch tensors as well - this will happen by means of self.transform only.

Now, let us compile these three methods to finally write our custom dataset class, along with the DataLoader, as shown below.
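Putting it together might look like the following; the csv path, image directory, image size and batch size are illustrative choices, not values prescribed by the dataset:

```python
import os
import pandas as pd
import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

class PawpularityImageData(Dataset):
    def __init__(self, df, root_dir, transform=None):
        self.id = df['Id']
        self.labels = df['Pawpularity']
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.id)

    def __getitem__(self, idx):
        img_path = os.path.join(self.root_dir, self.id[idx] + '.jpg')
        img = Image.open(img_path).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        label = torch.tensor(self.labels[idx], dtype=torch.float32)
        return img, label

df = pd.read_csv('train.csv')

train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

train_dataset = PawpularityImageData(df, root_dir='train', transform=train_transform)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=2)
```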

Note that in the above code snippet we do not call super().__init__(). Since the base class torch.utils.data.Dataset does not implement an __init__ method, calling it would not add value.

train_dataset is an instance of our custom PawpularityImageData class.

Now, to be able to load the data efficiently, we use the DataLoader class and hence train_loader is an instance of the DataLoader class that we create by passing the train_dataset instance to it.

Now, let us check if we were able to implement our custom PyTorch dataset class successfully. We will loop over the dataloader instance like so -
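A minimal check might look like this:

```python
for batch_idx, (images, labels) in enumerate(train_loader):
    # Each iteration yields one batch assembled by the DataLoader.
    print(batch_idx, images.shape, labels.shape)
    break  # one batch is enough to verify the pipeline works
```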

Output:

Perfect! But, what actually happened? How did the dataloader instance train_loader utilise our custom dataset instance?

Let's dive a bit deeper into answering such questions -

  • As we iterate over train_loader, num_workers worker processes are created, and the main process creates an integer-valued sampler to generate indices idx, which are sent across to the worker processes along with the dataset instance train_dataset that we passed to the dataloader.

  • These worker processes now deal with the data loading part, but how?

    Recall the __getitem__() method that we implemented in our custom dataset class PawpularityImageData to extract inputs and outputs corresponding to a particular index idx - that method is what the worker processes use to fetch the data.

  • Now recall something that was mentioned earlier in this article when we were introduced to the concept of map-style datasets: 'The indices need not be integers - they can also be non-integral keys.'

    By this time it should be clear how integral valued indices can be dealt with - by using the integer-valued sampler that the main process creates as discussed above.

What about non-integral indices?

  • Well, PyTorch provides torch.utils.data.Sampler classes that can be used to specify the sequence of indices/keys used in data loading. They represent iterable objects over the indices to datasets.

  • We can use the DataLoader's sampler argument to pass a custom Sampler object that yields the next index/key to fetch at each step, as in the sketch below.
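A minimal, purely illustrative sketch - a sampler that yields only the even indices of the dataset (note that shuffle must be left at its default False when a custom sampler is supplied):

```python
from torch.utils.data import Sampler, DataLoader

class EvenIndexSampler(Sampler):
    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        # Yield indices 0, 2, 4, ... - the DataLoader passes each of these
        # to the dataset's __getitem__.
        return iter(range(0, len(self.data_source), 2))

    def __len__(self):
        return (len(self.data_source) + 1) // 2

even_loader = DataLoader(train_dataset, batch_size=32,
                         sampler=EvenIndexSampler(train_dataset))
```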

Text Dataset

Now, let us strengthen our understanding by working with another common type of dataset - a text dataset.

We will work with the Spotify review dataset from Kaggle, and for the sake of brevity, we will deal with only two columns, namely -

  • Review column - it contains the text review data given by users; We will use these reviews as the features X to predict Y.
  • Rating column - the corresponding rating to the review; we will use this column as the label Y.

Let us once again one by one implement all the methods that we previously defined in our custom dataset class template to write our custom PyTorch Dataset class that we call TextDataset -
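A sketch of the class is shown below; the preprocess utility and max_len are illustrative choices, and the attention mask is returned alongside the token ids since BERT-style models expect it:

```python
import torch
from torch.utils.data import Dataset

def preprocess(text):
    # A small, illustrative clean-up utility.
    return text.lower().strip()

class TextDataset(Dataset):
    def __init__(self, df, tokenizer, max_len=128):
        self.reviews = df['Review']    # feature X - the review text
        self.ratings = df['Rating']    # label Y - the corresponding rating
        self.tokenizer = tokenizer     # converts text into numerical tensors
        self.max_len = max_len         # fixed length so samples stack into batches

    def __len__(self):
        return len(self.reviews)

    def __getitem__(self, idx):
        text = preprocess(self.reviews[idx])
        # return_tensors='pt' makes the tokenizer return PyTorch tensors.
        encoded = self.tokenizer(
            text,
            padding='max_length',
            truncation=True,
            max_length=self.max_len,
            return_tensors='pt',
        )
        # Squeeze away the extra batch dimension the tokenizer adds so that
        # the DataLoader can stack individual samples into a batch.
        input_ids = encoded['input_ids'].squeeze(0)
        attention_mask = encoded['attention_mask'].squeeze(0)
        label = torch.tensor(self.ratings[idx], dtype=torch.float32)
        return input_ids, attention_mask, label
```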

  • Like earlier, we define some variables in the __init__ method, one of which is tokenizer. We will get back to this variable later, but for now - it will be used to convert text reviews into numerical vector form, since computers understand numbers, not text!
  • Similarly, the __len__() method returns the length of the review column that is also the length of the dataset.
  • __getitem__() extracts the text review and label at the particular index idx and pre-processes the text using a small utility function named preprocess. After this, the text is tokenized using self.tokenizer (the return_tensors = 'pt' argument ensures the returned vectors are PyTorch tensors), and finally, the tokenized text review and the corresponding rating label are returned from the method - both as PyTorch tensors.

Let us now load the data using the DataLoader class, like so -
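Something along these lines, with the csv path and batch size being illustrative:

```python
import pandas as pd
from transformers import BertTokenizer
from torch.utils.data import DataLoader

df = pd.read_csv('spotify_reviews.csv')  # illustrative file name

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

train_dataset = TextDataset(df, tokenizer, max_len=128)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=2)
```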

tokenizer is an instance of the BertTokenizer class. tokenizer uses the vocabulary and tokenizing scheme of the pre-trained model 'bert-base-uncased' to convert text reviews into tokens and further encode these tokens into vectors.

As a debugging step, we will simply iterate over our train_loader instance to check that our implementation is correct -
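For example:

```python
for input_ids, attention_mask, labels in train_loader:
    print(input_ids.shape, attention_mask.shape, labels.shape)
    break  # one batch is enough to confirm the shapes line up
```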

Output:

That is it! Our implementation works perfectly fine.

Before wrapping up, let us look at dunder methods.

What are Dunder Methods?

  • Dunder methods or Magic methods in Python are a special category of methods that start and end with the double underscores, hence the name dunder.
  • Dunder methods are not meant to be invoked directly by the programmer, rather their invocation happens internally.
  • When an instance of the class is subjected to certain actions, these methods are invoked producing some result.

For example, when we add two numbers using the + plus operator, internally, the dunder method __add__() is called.

Let's see which actions on an instance of the PawpularityImageData class we defined earlier invoke these dunder methods.

__init__

This method is called internally as soon as an instance of the class is created. This means that any instance of the class, let's say, p1 has all the variables defined in the __init__ method associated with it. So, we could access them like so -
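For instance, assuming p1 is created with the training dataframe, image folder and transforms from earlier:

```python
p1 = PawpularityImageData(df, root_dir='train', transform=train_transform)

print(p1.root_dir)        # variables set in __init__ live on the instance
print(p1.labels.head())
```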

Output -

__len__

This method is called internally when we take the len of any instance of the class. Hence, len(p1) simply returns the length of the train dataframe, like so -
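```python
print(len(p1))   # internally calls p1.__len__()
```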

Output -

9912

__getitem__

This method is called internally whenever we index any instance of the class. So p1[0] gives us the transformed image and the Pawpularity score corresponding to the first row.
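```python
img, label = p1[0]   # internally calls p1.__getitem__(0)
print(img.shape, label)
```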

Conclusion

  • To conclude, in this article, we learned about the need for and the importance of the PyTorch Dataset class, while also learning to implement our own custom dataset class in PyTorch as a child class of the parent class torch.utils.data.Dataset.
  • We defined guidelines to implement a custom PyTorch dataset class for almost any type of dataset wherein we walked through the three most basic methods to be implemented in any custom dataset class in PyTorch.
  • We walked through fully implemented code examples for two datasets containing two different types of data - an image dataset and a text dataset. We also looked into the theoretical as well as practical details of the DataLoader class while focussing on how indispensable it is in any deep learning modelling pipeline. We touched upon how datasets with non-integral indices can be handled through the DataLoader itself.
  • Additionally, we also understood two Python concepts concerning the PyTorch dataset and dataloader class - abstract classes and dunder methods while demonstrating their use when it comes to the PyTorch dataset class using code examples.