In this post we present an example of an image data loader. It both loads the images and augments them differently in each epoch.
import os

import pandas as pd
import torch
import torchvision
from PIL import Image
from torch.utils.data import DataLoader
class MyDataset(torch.utils.data.Dataset):
    """Map-style dataset of labelled images listed in a CSV file.

    The CSV file must contain a ``file name`` column (image file names inside
    ``images_folder``) and a ``label`` column. Each item is an
    ``(image, label)`` pair. The transform is applied on every access, so a
    random transform produces a different result each time the same index is
    requested.
    """

    def __init__(self, csv_path, images_folder, transform=None):
        """Read the image index from the CSV file.

        Args:
            csv_path: Path of the CSV file with ``file name`` and ``label``
                columns.
            images_folder: Directory that contains the image files.
            transform: Optional callable applied to each loaded PIL image.
        """
        self.images_folder = images_folder
        df = pd.read_csv(csv_path)
        self.images_names = df['file name']
        self.images_labels = df['label']
        self.transform = transform

    def __getitem__(self, item_index):
        """Return the ``(image, label)`` pair at position ``item_index``.

        Raises:
            IndexError: If ``item_index`` is out of range.
        """
        # Positional lookup: an out-of-range index raises IndexError (which
        # the sequence protocol relies on to stop iteration) and negative
        # indices count from the end, independent of the Series labels.
        file_name = self.images_names.iloc[item_index]
        label = self.images_labels.iloc[item_index]

        image_path = os.path.join(self.images_folder, file_name)
        # Image.open is lazy; load the pixels while the file is open so the
        # file handle is released even when no transform is applied.
        with Image.open(image_path) as image:
            image.load()

        if self.transform is not None:
            image = self.transform(image)
        return image, label

    def __len__(self):
        """Return the number of images listed in the CSV file."""
        return len(self.images_names)
def main():
    """Load the image dataset for a few epochs and save every augmented batch.

    Reads ``images.csv`` and the ``images`` folder, applies a random
    augmentation pipeline, and writes each batch as a PNG file into the
    ``augmented`` folder (created if missing).
    """
    # NOTE: validation/test datasets also need a transform pipeline so that
    # their tensors have the same size and scaling as the training ones.
    custom_transform = torchvision.transforms.Compose(
        [
            torchvision.transforms.Resize(size=(32, 32)),
            torchvision.transforms.RandomCrop(size=(28, 28)),
            torchvision.transforms.RandomRotation(
                degrees=30,
                interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
            ),
            # ToTensor already rescales pixel values to [0, 1], so no manual
            # division is needed; a Normalize step could be appended here.
            torchvision.transforms.ToTensor(),
        ]
    )

    dataset = MyDataset(
        csv_path='images.csv',
        images_folder='images',
        transform=custom_transform,
    )

    data_loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=100,
        drop_last=False,  # keep the final, possibly smaller, batch
        shuffle=True,
        num_workers=1,  # load batches in a background worker process
    )

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # save_image does not create missing directories, so make sure the
    # output folder exists before the first batch is written.
    output_folder = 'augmented'
    os.makedirs(output_folder, exist_ok=True)

    number_of_epochs = 3
    augmented_index = 0
    for epoch_index in range(number_of_epochs):
        print(f'epoch {epoch_index}')
        for batch_index, (x, y) in enumerate(data_loader):
            print(f'batch {batch_index}')
            x = x.to(device)
            y = y.to(device)
            print(x.shape, y.shape)

            # Bring the batch back to the CPU and store it as one image file
            # so the augmentations of each epoch can be inspected.
            augmented_image = x.cpu()
            augmented_index += 1
            torchvision.utils.save_image(
                augmented_image,
                os.path.join(output_folder, f'{augmented_index}.png'),
            )
# Guard the entry point: the DataLoader in main() uses a worker process, and
# on platforms that start workers by re-importing this module an unguarded
# call would run main() again inside every worker.
if __name__ == '__main__':
    main()
The original images are:
And the augmented images per epoch are:
No comments:
Post a Comment