Asma QuickStart Guide: Dataset and DataLoader in PyTorch

AsmaTheWizard
7 min read · Feb 6, 2023


Credit: DALL-E

In this post, we’ll explore how to use the Dataset and DataLoader classes to efficiently load, pre-process, and batch your data, with examples of how to use them. Whether you’re a beginner or an experienced PyTorch user, this post will give you a comprehensive understanding of how to use these essential components in your deep-learning projects.

What is PyTorch?

It is an open-source software library for machine learning (ML), primarily used for natural language processing (NLP) and computer vision tasks. It provides a rich set of tools and libraries for data loading, processing, and model training, making it a popular choice for both research and production use.

What is a Dataset in PyTorch?

In PyTorch, a dataset is a collection of data that can be used to train an ML model. It is a key component of the PyTorch data loading and processing pipeline.

It’s worth mentioning that the ML pipeline in PyTorch can be represented as follows:

Preprocess data / create dataset -> load data -> train -> evaluate

To create a dataset, you need to do the following steps:

1- Import libraries.

2- Create a class that inherits from the Dataset class.

3- Override the __len__ and __getitem__ methods: the first returns the number of samples in the dataset, while the second returns a specific sample from the dataset given an index. A minimal skeleton is sketched right after this list.
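
To make these steps concrete, here is a minimal sketch of a dataset wrapping an in-memory list of samples; all names here are illustrative, not part of PyTorch itself:

import torch
from torch.utils.data import Dataset

class ListDataset(Dataset):
    # a toy dataset wrapping in-memory (feature, label) pairs
    def __init__(self, features, labels):
        self.features = features
        self.labels = labels

    def __len__(self):
        # number of samples in the dataset
        return len(self.features)

    def __getitem__(self, idx):
        # return one sample given an index
        return self.features[idx], self.labels[idx]

# usage: 10 random 3-dimensional feature vectors with binary labels
data = ListDataset(torch.randn(10, 3), torch.randint(0, 2, (10,)))
print(len(data))  # 10
print(data[0])    # (feature tensor of shape [3], label tensor)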

What is a DataLoader in PyTorch?

It’s used to easily load a dataset and iterate over its samples efficiently. Not only that, but it also provides batching, shuffling, and multi-process loading features.

You can use the datasets built by the PyTorch team for famous datasets such as MNIST, or build your own custom dataset from scratch, which is what I will cover in this post.
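
As a quick illustration of the built-in route, here is a minimal sketch that loads MNIST through torchvision (assuming torchvision is installed) and batches it with a DataLoader:

import torchvision
import torchvision.transforms as T
from torch.utils.data import DataLoader

# download MNIST and convert each PIL image to a tensor
mnist = torchvision.datasets.MNIST(root='./data', train=True,
                                   download=True, transform=T.ToTensor())

loader = DataLoader(mnist, batch_size=64, shuffle=True)
images, labels = next(iter(loader))
print(images.shape)  # torch.Size([64, 1, 28, 28])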

Structure of Data:

Annotations:

You can have your annotations saved as one JSON file, or as one TXT or XML file per image. It doesn’t make that huge a difference because, at the end of the day, we just want to locate the bounding boxes and labels for each image. In my case, I had the annotations saved as XML files, one XML file per image, with each file name matching its image name.
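
For reference, one such VOC-style XML file might look roughly like this (the tag names match what the parsing code later in this post expects; the class name and coordinates are made up):

<annotation>
  <filename>image_001.jpg</filename>
  <size>
    <width>640</width>
    <height>480</height>
  </size>
  <object>
    <name>cat</name>
    <bndbox>
      <xmin>120</xmin>
      <ymin>80</ymin>
      <xmax>340</xmax>
      <ymax>260</ymax>
    </bndbox>
  </object>
</annotation>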

Folder Structure:

You can save all of your images and labels in one folder or separate them. Personally, I prefer to follow this folder structure to keep things simple and straightforward (a sketch mapping this tree to path constants follows the listing):

Root Folder:
- Dataset Folder
- - Train Folder
- - - IMAGES Folder
- - - LABELS Folder
- - Test Folder
- - - IMAGES Folder
- - - LABELS Folder
- Model Folder
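
Assuming that layout, the path constants used later in this post can be built like this (the root path here is a placeholder; adjust it to your own project):

import os

ROOT = '.'  # hypothetical project root

IMG_TRAIN_DIR = os.path.join(ROOT, 'Dataset', 'Train', 'IMAGES')
ANN_TRAIN_DIR = os.path.join(ROOT, 'Dataset', 'Train', 'LABELS')
IMG_TEST_DIR = os.path.join(ROOT, 'Dataset', 'Test', 'IMAGES')
ANN_TEST_DIR = os.path.join(ROOT, 'Dataset', 'Test', 'LABELS')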

Bounding Boxes (2D):

I mentioned earlier that the annotation file format doesn’t make a huge difference, which is true, but the format of the bounding boxes in the annotation file will make or break your project. You can convert bounding boxes before the dataset step or during the dataset step.

Bounding box formats can be (a conversion sketch follows the list):

  • x, y, width, height: top-left corner of the bounding box plus its width and height. Commonly known as the COCO format (from the COCO dataset).
  • x1, y1, x2, y2: top-left and bottom-right corners of the bounding box; it can also be written as x_min, y_min, x_max, y_max, both are the same. Commonly known as the PASCAL VOC format (from the PASCAL VOC dataset).
  • x_center, y_center, width, height: the center coordinates of the bounding box plus its width and height, usually normalized by the image width and height. So far, I have only seen this format used with YOLO models.
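
Since converting between these formats is a common chore, here is a small sketch of the arithmetic in plain Python (function names are mine, no library assumed):

def coco_to_pascal_voc(x, y, w, h):
    # (x, y, width, height) -> (x_min, y_min, x_max, y_max)
    return x, y, x + w, y + h

def pascal_voc_to_yolo(x_min, y_min, x_max, y_max, img_w, img_h):
    # corner coordinates -> normalized (x_center, y_center, width, height)
    w = x_max - x_min
    h = y_max - y_min
    return ((x_min + w / 2) / img_w, (y_min + h / 2) / img_h,
            w / img_w, h / img_h)

# example: a 220x180 box at (120, 80) in a 640x480 image
print(coco_to_pascal_voc(120, 80, 220, 180))            # (120, 80, 340, 260)
print(pascal_voc_to_yolo(120, 80, 340, 260, 640, 480))  # (0.359..., 0.354..., 0.343..., 0.375)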

Code:

Libraries

#libraries
import os
import albumentations as A  # for image transformations
from albumentations.pytorch import ToTensorV2  # converts numpy images to torch tensors
import cv2  # load and resize images
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

#because my annotation files are XML, I'll use ElementTree to extract the data
from xml.etree import ElementTree as et

Dataset Class Structure

We will create a class; you can name it whatever you want, here I named it CustomDataset :) The class will inherit from Dataset, and it will have 4 methods:

  • __init__: initializes variable values and builds the list of images we have (just names). The dataset will not work without this method.
  • get_image_files: a helper method to get the list of image names.
  • __getitem__: here we will get the image, its labels, and its bounding boxes by index. The dataset will not work without this method.
  • __len__: returns the length of the image names list that we created in the __init__ method.

class CustomDataset(Dataset):
    def __init__(self):
        ...

    def get_image_files(self):
        ...

    def __getitem__(self, idx):
        ...

    def __len__(self):
        ...

__init__

You don’t have to follow the code exactly. What we need to do here is:

1- set paths for the images folder and annotations folder; you can also pass only the root path and build the subdirectory paths inside the code.

2- pass the list of classes/label names; the order is important here because a label’s id is its index in this list.

3- pass the new height and width of an image, to resize it later in __getitem__. You can resize all images beforehand and skip resizing each image in __getitem__, but whenever you resize an image you need to recalculate its bounding boxes.

4- create a list of all image names. Again, to simplify things, save the names in the list without extensions.

5- pass the transform function; we will create global functions using the albumentations library, one for training and another one for testing.

def __init__(self, dir_img_path, dir_annotation_path, classes, width, height, transforms=None):
    self.dir_img_path = dir_img_path
    self.dir_annotation_path = dir_annotation_path
    self.classes = classes
    self.new_height = height
    self.new_width = width
    self.all_images = self.get_image_files()
    self.transforms = transforms
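
The get_image_files helper isn’t shown in the original snippets, so here is a minimal sketch of what it could look like inside the class (assuming .jpg images; stripping the extension matches step 4 above):

def get_image_files(self):
    # collect image file names without their extension, so we can
    # derive both the image path and the matching .xml path later
    return [os.path.splitext(f)[0]
            for f in sorted(os.listdir(self.dir_img_path))
            if f.endswith('.jpg')]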

__getitem__

Here, we mainly do 2 things:

1- load the image and process it (convert the color channels, resize), and apply the transform.

2- load the annotation file, extract data from it (the extraction part will vary depending on your annotation file and format), process the bounding boxes (resize them if we resized the image, convert from one format to another if needed, normalize if needed), and extract the labels (convert from text to class ids).

def __getitem__(self, idx):
    # get the image name by index (names are stored without extension,
    # see step 4 above, so add the extension back; assuming .jpg images)
    image_name = self.all_images[idx]
    image_path = os.path.join(self.dir_img_path, image_name + '.jpg')

    # read the image
    image = cv2.imread(image_path)

    # convert BGR to RGB color format, because cv2 reads images as BGR and we need RGB
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32)
    image_resized = cv2.resize(image, (self.new_width, self.new_height))  # resize image
    image_resized /= 255.0  # scale pixel values to [0, 1]

    # get the annotation file (same name as the image)
    annot_filename = image_name + '.xml'
    annot_file_path = os.path.join(self.dir_annotation_path, annot_filename)

    boxes = []
    labels = []

    # because my annotation files are XML
    tree = et.parse(annot_file_path)
    root = tree.getroot()

    # get the original height and width of the image
    image_width = image.shape[1]
    image_height = image.shape[0]

    for member in root.findall('object'):
        labels.append(self.classes.index(member.find('name').text))
        xmin = int(member.find('bndbox').find('xmin').text)
        xmax = int(member.find('bndbox').find('xmax').text)
        ymin = int(member.find('bndbox').find('ymin').text)
        ymax = int(member.find('bndbox').find('ymax').text)

        # rescale the bounding box to match the resized image
        xmin_final = (xmin / image_width) * self.new_width
        xmax_final = (xmax / image_width) * self.new_width
        ymin_final = (ymin / image_height) * self.new_height
        ymax_final = (ymax / image_height) * self.new_height

        # if you need to normalize the bounding boxes, you should do it here
        # before appending the bounding box to the list
        boxes.append([xmin_final, ymin_final, xmax_final, ymax_final])

    # bounding boxes to tensor
    boxes = torch.as_tensor(boxes, dtype=torch.float32)
    # area of the bounding boxes
    area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
    # no crowd instances
    iscrowd = torch.zeros((boxes.shape[0],), dtype=torch.int64)
    # labels to tensor
    labels = torch.as_tensor(labels, dtype=torch.int64)

    # prepare the final `target` dictionary
    target = {}
    target["boxes"] = boxes
    target["labels"] = labels
    target["area"] = area
    target["iscrowd"] = iscrowd
    image_id = torch.tensor([idx])  # convert the image id to a tensor as well
    target["image_id"] = image_id

    # apply the image transforms
    if self.transforms:
        sample = self.transforms(image=image_resized,
                                 bboxes=target['boxes'],
                                 labels=labels)
        image_resized = sample['image']
        target['boxes'] = torch.Tensor(sample['bboxes'])

    return image_resized, target

__len__

Just return the length of the image names list.

def __len__(self):
    return len(self.all_images)

Transforms

A transform pipeline in PyTorch built with Albumentations is a data augmentation technique used to increase the effective size of the training dataset by applying random transformations to the existing images. Here, I created 2 functions, one for the training dataset and the other one for the testing dataset. You can play with the configuration here; the only thing you need to keep in mind is that you must tell it the bounding box format by using bbox_params. Otherwise, it will not perform as you hope.

# the training transforms
def get_train_transform():
    return A.Compose([
        A.Flip(p=0.5),
        A.RandomRotate90(p=0.5),
        ToTensorV2(p=1.0),
    ], bbox_params={
        'format': 'pascal_voc',
        'label_fields': ['labels']
    })

# the test transforms
def get_test_transform():
    return A.Compose([
        ToTensorV2(p=1.0),
    ], bbox_params={
        'format': 'pascal_voc',
        'label_fields': ['labels']
    })
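
One more thing before creating the datasets: the constants used below. This is a hypothetical configuration; your class list, image size, and batch size will differ:

# hypothetical configuration - adjust to your own data
CLASSES = ['background', 'cat', 'dog']  # order matters: a label's id is its index here
WIDTH = 416    # resize target width
HEIGHT = 416   # resize target height
BATCH_SIZE = 8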

Now, we will create 2 datasets as follows:

train_dataset = CustomDataset(IMG_TRAIN_DIR, ANN_TRAIN_DIR, CLASSES, WIDTH, HEIGHT, get_train_transform())
valid_dataset = CustomDataset(IMG_TEST_DIR, ANN_TEST_DIR, CLASSES, WIDTH, HEIGHT, get_test_transform())
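
A quick sanity check that the dataset returns what we expect:

image, target = train_dataset[0]
print(image.shape)       # transformed image tensor, e.g. torch.Size([3, 416, 416])
print(target['boxes'])   # one [x_min, y_min, x_max, y_max] row per object
print(target['labels'])  # one class id per object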

DataLoader:

Now we can load the dataset to use it later for training. We have to specify the batch size and pass a function that combines a list of samples into a batch.

# used to combine a list of samples into a batch, keeping the
# variable-size targets as tuples instead of stacking them
def collate_fn(batch):
    return tuple(zip(*batch))

train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,
    collate_fn=collate_fn
)

valid_loader = DataLoader(
    valid_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=0,
    collate_fn=collate_fn
)
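
With this collate_fn, each batch comes back as a tuple of images and a tuple of target dictionaries, rather than stacked tensors, because each image can have a different number of boxes. A typical training-loop iteration might then look like this (the detection model itself is out of scope here):

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

for images, targets in train_loader:
    # move every image and every target tensor to the training device
    images = [img.to(device) for img in images]
    targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
    # images and targets are now ready to feed a detection model,
    # e.g. loss_dict = model(images, targets) for torchvision detectors
    break  # just demonstrating one batch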

In conclusion, the Dataset and DataLoader classes are fundamental components of the PyTorch framework that make it easy to load and process your data. By using these classes, you can efficiently load and batch your data, apply pre-processing transformations, and feed them into your model for training. Understanding how to use these classes will help you build robust and efficient deep-learning pipelines, and will enable you to tackle a wide range of tasks with ease. This post has provided a comprehensive introduction to using the Dataset and DataLoader classes in PyTorch, and I hope it has helped you to gain a deeper understanding of these important concepts. Happy learning!

And that’s a wrap!


