Full Blog TOC

Full Blog Table of Contents with Keywords Available HERE

Saturday, August 17, 2024

PyTorch DataLoaders and Transforms


In this post we present an example of an image data loader. It both loads the images and augments them, so that every epoch sees a differently augmented version of each image.


import os

import pandas as pd
import torch
import torchvision
from PIL import Image
from torch.utils.data import DataLoader


class MyDataset(torch.utils.data.Dataset):
    """Image dataset described by a CSV index file.

    The CSV must contain a ``file name`` column (image file names relative to
    ``images_folder``) and a ``label`` column. Every item is an
    ``(image, label)`` pair. The transform, when given, is applied each time
    an item is fetched, so random transforms yield a different result on
    every access.
    """

    def __init__(self, csv_path, images_folder, transform=None):
        """Read the CSV index; the images themselves are opened on demand.

        Args:
            csv_path: path of the CSV file with the ``file name`` and
                ``label`` columns.
            images_folder: folder that contains the image files.
            transform: optional callable applied to each loaded PIL image.
        """
        self.images_folder = images_folder

        df = pd.read_csv(csv_path)
        self.images_names = df['file name']
        self.images_labels = df['label']
        self.transform = transform

    def __getitem__(self, item_index):
        """Return the ``(image, label)`` pair at position ``item_index``.

        Raises:
            IndexError: if ``item_index`` is out of range.
        """
        # Positional lookup: an out-of-range index raises IndexError (which
        # plain ``for item in dataset`` iteration relies on to stop) and
        # negative indices count from the end.
        image_name = self.images_names.iloc[item_index]
        image_path = os.path.join(self.images_folder, image_name)

        # Image.open is lazy; read the pixel data while the file is open so
        # the file handle is released right away instead of lingering.
        with Image.open(image_path) as image:
            image.load()
        if self.transform is not None:
            image = self.transform(image)

        label = self.images_labels.iloc[item_index]
        return image, label

    def __len__(self):
        """Return the number of rows (images) listed in the CSV."""
        return len(self.images_names)


def main():
    """Load the images in shuffled batches and save every augmented batch.

    Runs three epochs over the dataset described by ``images.csv`` and the
    ``images`` folder. The random transforms are re-applied whenever an image
    is loaded, so each epoch yields differently augmented images. Each batch
    is saved to ``augmented/<running index>.png``.
    """
    # NOTE: the test/validation datasets need transforms as well. Typically
    # only the deterministic steps (Resize, ToTensor, normalization) are used
    # there, while the random augmentations stay train-only.
    custom_transform = torchvision.transforms.Compose(
        [
            torchvision.transforms.Resize(size=(32, 32)),
            torchvision.transforms.RandomCrop(size=(28, 28)),
            torchvision.transforms.RandomRotation(
                degrees=30,
                interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
            ),
            torchvision.transforms.ToTensor(),
            # torchvision.transforms.Lambda(lambda item: item / 256.0),
            # torchvision.transforms.Normalize(mean=(0.5,), std=(0.5,))
        ]
    )
    dataset = MyDataset(
        csv_path='images.csv',
        images_folder='images',
        transform=custom_transform,
    )

    data_loader = DataLoader(
        dataset=dataset,
        batch_size=100,
        drop_last=False,  # keep the final batch even if it is smaller than batch_size
        shuffle=True,
        num_workers=1,  # load the data in one background worker process
    )

    # save_image does not create missing folders, so make sure the output
    # folder exists before the first batch is written.
    output_folder = 'augmented'
    os.makedirs(output_folder, exist_ok=True)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    number_of_epochs = 3
    augmented_index = 0
    for epoch_index in range(number_of_epochs):
        print(f'epoch {epoch_index}')
        for batch_index, (x, y) in enumerate(data_loader):
            print(f'batch {batch_index}')
            x = x.to(device)
            y = y.to(device)
            print(x.shape, y.shape)
            # Bring the batch back to the CPU before writing it to disk.
            augmented_image = x.cpu()
            augmented_index += 1
            output_path = os.path.join(output_folder, f'{augmented_index}.png')
            torchvision.utils.save_image(augmented_image, output_path)


# Guard the entry point: the DataLoader uses a worker process, and on
# platforms that start workers by re-importing this module an unguarded call
# would run main() again inside every worker.
if __name__ == '__main__':
    main()


The original images are:





And the augmented images per epoch are:










No comments:

Post a Comment