Full Blog TOC

Full Blog Table Of Content with Keywords Available HERE

Monday, August 26, 2024

AutoEncoder

 



In this post we present an AutoEncoder implementation in PyTorch.
This is based on the Convolutional Autoencoder code by Sebastian Raschka.


We use the MNIST dataset and encode each 28x28 grayscale digit image into only 2 numbers!

At the top of this post is an image showing the original digit images and the corresponding decoded images. The reconstruction is not perfect, but it is very close. Performance could be improved with additional training (but I don't have a GPU).


Note that adding randomness to the training data loader seems to make the results worse. Again, this might be because more training is required.

Also, removing the duplicate 64-channel convolution layers seems to give similar results.



import time

import matplotlib.pyplot as plt
import torch
import torchvision
from torch.utils.data import DataLoader, SubsetRandomSampler
from torchvision import datasets

# File that Trainer.save() writes the model's state_dict to and Trainer.load() reads it from.
MODEL_PATH = "auto-encoder.model"
# Side length in pixels of the square MNIST images; used to crop the decoder
# output (Trim) and to reshape images for plotting.
IMAGE_WIDTH = 28


class Trim(torch.nn.Module):
    """Crop dimensions 2 and 3 of a tensor (height/width of an NCHW batch).

    The first positional constructor argument is the maximum size kept along
    both dimensions; any further positional arguments are ignored. Dimensions
    already smaller than that size are left untouched (plain slicing).
    """

    def __init__(self, *args):
        super().__init__()
        first_argument = args[0]
        self.size = first_argument

    def forward(self, x):
        """Return a view of ``x`` keeping only the leading ``self.size`` rows and columns."""
        window = slice(None, self.size)
        return x[:, :, window, window]


class AutoEncoder(torch.nn.Module):
    """Convolutional autoencoder that compresses an image into a 2-value code.

    The shapes below assume a 1x28x28 input, which is what the
    64 * 7 * 7 = 3136 bottleneck size requires:

    * encoder: 1x28x28 -> 32x28x28 -> 64x14x14 -> 64x7x7 -> 64x7x7
      -> flatten (3136) -> linear -> 2
    * decoder: 2 -> linear (3136) -> 64x7x7 -> 64x7x7 -> 64x13x13
      -> 32x27x27 -> 1x29x29 -> crop to IMAGE_WIDTH -> sigmoid
    """

    def __init__(self):
        super().__init__()

        negative_slope = 0.01
        bottleneck_channels = 64
        bottleneck_side = 7
        bottleneck_features = bottleneck_channels * bottleneck_side * bottleneck_side  # 3136
        code_size = 2

        # Layer order fixes both the parameter-initialisation order and the
        # state_dict keys (encoder.0, encoder.2, ...) used by saved models.
        self.encoder = torch.nn.Sequential(
            torch.nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
            torch.nn.LeakyReLU(negative_slope),
            torch.nn.Conv2d(32, bottleneck_channels, kernel_size=3, stride=2, padding=1),
            torch.nn.LeakyReLU(negative_slope),
            torch.nn.Conv2d(bottleneck_channels, bottleneck_channels, kernel_size=3, stride=2, padding=1),
            torch.nn.LeakyReLU(negative_slope),
            # The last convolution feeds the linear projection directly (no activation).
            torch.nn.Conv2d(bottleneck_channels, bottleneck_channels, kernel_size=3, stride=1, padding=1),
            torch.nn.Flatten(),
            torch.nn.Linear(bottleneck_features, code_size),
        )

        self.decoder = torch.nn.Sequential(
            torch.nn.Linear(code_size, bottleneck_features),
            torch.nn.Unflatten(1, (bottleneck_channels, bottleneck_side, bottleneck_side)),
            torch.nn.ConvTranspose2d(bottleneck_channels, bottleneck_channels, kernel_size=3, stride=1, padding=1),
            torch.nn.LeakyReLU(negative_slope),
            torch.nn.ConvTranspose2d(bottleneck_channels, bottleneck_channels, kernel_size=3, stride=2, padding=1),
            torch.nn.LeakyReLU(negative_slope),
            torch.nn.ConvTranspose2d(bottleneck_channels, 32, kernel_size=3, stride=2, padding=0),
            torch.nn.LeakyReLU(negative_slope),
            torch.nn.ConvTranspose2d(32, 1, kernel_size=3, stride=1, padding=0),
            # The transposed convolutions produce 29x29; crop back to the image size.
            Trim(IMAGE_WIDTH),
            # Squash reconstructed pixel values into (0, 1).
            torch.nn.Sigmoid(),
        )

    def forward(self, x):
        """Encode ``x`` to its 2-value code and decode that code back to an image."""
        return self.decoder(self.encoder(x))


class Trainer:
    """Train, save, reload and visualise an ``AutoEncoder`` on MNIST.

    The constructor downloads MNIST into ``local_cache_folder``, builds the
    train/test ``DataLoader``s, creates the model (on CUDA when available) and
    sets up an Adam optimizer with an MSE reconstruction loss.
    """

    def __init__(self):
        self.number_of_epochs = 5
        batch_size = 32
        learning_rate = 0.0005

        self.loss_train_per_batch = []
        self.loss_train_per_epoch = []
        self.loss_test_per_epoch = []

        # Set to a small int (e.g. 13) to train only on the first samples while
        # debugging; None trains on the full training set.
        limit_size = None
        if limit_size is None:
            train_sampler = None
            shuffle = True
        else:
            train_indexes = torch.arange(limit_size)
            train_sampler = SubsetRandomSampler(train_indexes)
            # DataLoader rejects shuffle=True together with a sampler; the
            # SubsetRandomSampler already randomises the order.
            shuffle = False

        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        print('using device:', self.device)

        transform_train = torchvision.transforms.Compose(
            [
                # Optional random augmentation, currently disabled:
                # torchvision.transforms.Resize(size=(IMAGE_WIDTH + 4, IMAGE_WIDTH + 4)),
                # torchvision.transforms.RandomCrop(size=(IMAGE_WIDTH, IMAGE_WIDTH)),
                # torchvision.transforms.RandomRotation(degrees=20),
                torchvision.transforms.ToTensor(),
            ]
        )

        transform_test = torchvision.transforms.Compose(
            [
                torchvision.transforms.ToTensor(),
            ]
        )

        dataset_train = datasets.MNIST(root='local_cache_folder',
                                       train=True,
                                       transform=transform_train,
                                       download=True)

        dataset_test = datasets.MNIST(root='local_cache_folder',
                                      train=False,
                                      transform=transform_test)

        print('train samples', dataset_train.data.shape[0])
        print('test samples', dataset_test.data.shape[0])

        self.loader_train = DataLoader(dataset=dataset_train,
                                       batch_size=batch_size,
                                       shuffle=shuffle,
                                       sampler=train_sampler,
                                       )

        self.loader_test = DataLoader(dataset=dataset_test,
                                      batch_size=batch_size,
                                      shuffle=False,
                                      )

        # Peek at one batch to report tensor shapes and the per-image feature count.
        for images, labels in self.loader_train:
            print('single batch dimensions:', images.shape)
            print('single batch label dimensions:', labels.shape)
            self.number_of_features = images.shape[2] * images.shape[3]
            print('number of features', self.number_of_features)
            break

        self.model = AutoEncoder()
        self.model = self.model.to(device=self.device)

        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=learning_rate)

        self.loss_function = torch.nn.functional.mse_loss

    def run_batches(self, data_loader, batch_callback=None, print_details=False):
        """Run the model over every batch of ``data_loader`` and average the loss.

        Args:
            data_loader: iterable of ``(images, labels)`` pairs. Labels are
                ignored because the reconstruction target is the input itself.
            batch_callback: optional ``callback(reconstructions, batch_loss,
                batch_samples)`` called once per batch (training uses it to
                take the optimizer step).
            print_details: print the current batch loss at most every ~5 seconds.

        Returns:
            ``(average_loss, batch_count)``. ``average_loss`` is a Python float:
            the batch-size-weighted mean of the per-batch losses (for the
            mean-reduced MSE this is the MSE over every pixel of the loader),
            or NaN for an empty loader. ``batch_count`` is the number of batches.
        """
        start_time = time.time()
        total_loss = 0.0
        total_samples = 0
        batch_count = 0
        for batch_index, (batch_x, _) in enumerate(data_loader):
            batch_count += 1
            batch_x = batch_x.to(device=self.device)
            batch_samples = batch_x.shape[0]
            total_samples += batch_samples
            reconstructions = self.model(batch_x)
            batch_loss = self.loss_function(reconstructions, batch_x)
            if batch_callback is not None:
                batch_callback(reconstructions, batch_loss, batch_samples)
            # The loss is already averaged over the batch: weight it by the
            # batch size so a short final batch is not over-counted. Using
            # .item() also avoids keeping every batch's autograd graph alive.
            total_loss += batch_loss.item() * batch_samples

            if print_details and time.time() - start_time > 5:
                start_time = time.time()
                print('batch %05d loss %.5f' % (batch_index, batch_loss.item()))

        average_loss = total_loss / total_samples if total_samples else float('nan')
        return average_loss, batch_count

    def batch_callback_train(self, _, batch_loss, batch_samples):
        """Record ``batch_loss`` and take one optimizer step on it.

        ``batch_samples`` is unused but kept to match the ``run_batches``
        callback signature.
        """
        # batch_loss is already a per-batch mean, which is on the same scale as
        # the epoch averages returned by run_batches.
        self.loss_train_per_batch.append(batch_loss.item())

        self.optimizer.zero_grad()
        batch_loss.backward()
        self.optimizer.step()

    def train_epoch(self, epoch_index):
        """Train for one epoch, then evaluate the loss on the train and test sets.

        Returns:
            The number of batches in the training loader.
        """
        start_time = time.time()
        # The evaluation below switches to eval mode, so training mode must be
        # restored at the start of every epoch, not only the first.
        self.model.train()
        self.run_batches(data_loader=self.loader_train, batch_callback=self.batch_callback_train,
                         print_details=True)

        with torch.no_grad():
            self.model.eval()
            epoch_loss_train, batches_in_epoch_count = self.run_batches(
                data_loader=self.loader_train)
            epoch_loss_test, _ = self.run_batches(data_loader=self.loader_test)

        self.loss_train_per_epoch.append(epoch_loss_train)
        self.loss_test_per_epoch.append(epoch_loss_test)

        passed_seconds = time.time() - start_time
        print('epoch %05d/%05d, epoch duration %03.0f seconds, loss train %.5f, loss test %.5f' % (
            epoch_index + 1, self.number_of_epochs, passed_seconds, epoch_loss_train, epoch_loss_test))

        return batches_in_epoch_count

    def train(self):
        """Train for ``number_of_epochs`` epochs and save a loss plot to ``loss.pdf``."""
        self.loss_train_per_batch = []
        self.loss_train_per_epoch = []
        self.loss_test_per_epoch = []
        self.model.train()

        start_time = time.time()
        batches_in_epoch_count = 0
        for epoch_index in range(self.number_of_epochs):
            batches_in_epoch_count = self.train_epoch(epoch_index)

        passed_seconds = time.time() - start_time
        print('total training duration %03.0f seconds' % passed_seconds)

        self.model.eval()

        # Repeat each per-epoch value once per batch so all three curves share
        # the per-batch x axis.
        self.loss_train_per_epoch = self.spread_points(self.loss_train_per_epoch, batches_in_epoch_count)
        self.loss_test_per_epoch = self.spread_points(self.loss_test_per_epoch, batches_in_epoch_count)

        fig, ax = plt.subplots()
        ax.plot(self.loss_train_per_batch, color='b', label='train batch')
        ax.plot(self.loss_train_per_epoch, color='g', label='train epoch')
        ax.plot(self.loss_test_per_epoch, color='r', label='test')
        ax.legend()
        ax.set_ylim(bottom=0)
        ax.set_ylabel('Loss')
        ax.set_xlabel('Batch')
        fig.savefig("loss.pdf")
        plt.close(fig)
        self.loss_train_per_batch = []

    def save(self):
        """Write the model's state_dict to ``MODEL_PATH``."""
        torch.save(self.model.state_dict(), MODEL_PATH)

    def load(self):
        """Load weights from ``MODEL_PATH`` into the model and switch to eval mode."""
        # map_location lets a checkpoint saved on GPU be loaded on a CPU-only
        # host. Loading into the existing model keeps self.optimizer bound to
        # the parameters that are actually used.
        state_dict = torch.load(MODEL_PATH, map_location=self.device)
        self.model.load_state_dict(state_dict)
        self.model = self.model.to(self.device)
        self.model.eval()

    def show_examples(self, examples_amount=10):
        """Save the first test images and their reconstructions to ``examples.pdf``.

        The top row shows the originals and the bottom row the decoded images.

        Args:
            examples_amount: maximum number of images to plot (at most one batch).
        """
        self.model.eval()
        with torch.no_grad():
            original_images, _ = next(iter(self.loader_test))
            original_images = original_images[:examples_amount].to(device=self.device)
            decoded_images = self.model(original_images)

        original_images = original_images.cpu()
        decoded_images = decoded_images.cpu()
        batch_samples = original_images.shape[0]

        # squeeze=False keeps `axes` two-dimensional even for a single column.
        fig, axes = plt.subplots(nrows=2, ncols=batch_samples, sharex=True, sharey=True,
                                 figsize=(20, 2.5), squeeze=False)
        for row_axes, images in zip(axes, (original_images, decoded_images)):
            for ax, image in zip(row_axes, images):
                ax.imshow(image.view(IMAGE_WIDTH, IMAGE_WIDTH), cmap='binary')
        fig.savefig("examples.pdf")
        plt.close(fig)

    @staticmethod
    def spread_points(points, spread_factor):
        """Repeat every element of ``points`` ``spread_factor`` times, keeping order.

        >>> Trainer.spread_points([1, 2], 3)
        [1, 1, 1, 2, 2, 2]
        """
        return [point for point in points for _ in range(spread_factor)]

def set_seed(random_seed=123):
    """Seed PyTorch's random number generators for reproducible runs.

    On CUDA machines, also disable cuDNN autotuning and request
    deterministic algorithms.

    Args:
        random_seed: seed passed to ``torch.manual_seed`` and
            ``torch.cuda.manual_seed`` (defaults to 123).
    """
    if torch.cuda.is_available():
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
        # torch.set_deterministic is deprecated and absent from recent PyTorch
        # releases; use_deterministic_algorithms is its replacement.
        if hasattr(torch, "use_deterministic_algorithms"):
            torch.use_deterministic_algorithms(True)
        else:
            torch.set_deterministic(True)

    torch.manual_seed(random_seed)
    torch.cuda.manual_seed(random_seed)


def main():
    """Train the autoencoder, save it, reload it and plot example reconstructions."""
    trainer = Trainer()
    trainer.train()
    trainer.save()
    # Reload the saved weights so the examples come from the persisted model.
    trainer.load()
    trainer.show_examples()


# Run the (slow, file-writing) pipeline only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    main()




No comments:

Post a Comment