Monday, August 26, 2024

Variational AutoEncoder

 

In this post we present a Variational AutoEncoder (VAE). It is a modified version of the AutoEncoder presented in a previous post.

This work is based on class lessons by Sebastian Raschka.
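
The main change relative to the plain AutoEncoder is that the encoder does not produce a latent vector directly; it produces a mean and a log-variance, and the latent code is sampled from that distribution using the reparameterization trick. A minimal sketch of the idea, using the same names as the forward method further down:

    # epsilon is standard normal noise, so "encoded" is a sample from N(mean, std^2)
    std = torch.exp(log_variance / 2.)
    encoded = mean + epsilon * std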


We can see the training loss graph:
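
The loss being minimized is the reconstruction error summed over the pixels of each image plus the KL divergence of the latent distribution from a standard normal, both averaged over the batch. Roughly, in the notation of run_batches below:

    # reconstruction: squared error summed over the pixels of each image
    pixelwise = ((decoded - batch_x) ** 2).view(batch_size, -1).sum(axis=1).mean()
    # KL divergence between N(mean, std^2) and the standard normal, in closed form
    kl_div = (-0.5 * (1 + log_variance - mean ** 2 - torch.exp(log_variance)).sum(axis=1)).mean()
    loss = reconstruction_term_weight * pixelwise + kl_div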




Encoding and decoding of actual samples:



The variational autoencoder also gives us a new capability: generating brand new images by sampling the latent space from a normal distribution:
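
Because the KL term keeps the latent codes close to a standard normal, we can generate images by sampling the latent vector directly from N(0, I) and running only the decoder; this is what generate_images does in the code. A minimal usage sketch (assuming model is a loaded VariationAutoEncoder on the CPU):

    with torch.no_grad():
        z = torch.randn((10, 8))           # 10 latent vectors, 8 latent features
        new_images = model.decoder(z)      # shape: (10, 1, 28, 28)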





And the actual implementation:


import time

import matplotlib.pyplot as plt
import torch
import torchvision
from torch.utils.data import DataLoader, SubsetRandomSampler
from torchvision import datasets

MODEL_PATH = "auto-encoder.model"
IMAGE_WIDTH = 28


class Trim(torch.nn.Module):
    # crop the decoder output back to the target image size (e.g. 29 X 29 -> 28 X 28)
    def __init__(self, *args):
        super().__init__()
        self.size = args[0]

    def forward(self, x):
        return x[:, :, :self.size, :self.size]


class VariationAutoEncoder(torch.nn.Module):

    def __init__(self):
        super(VariationAutoEncoder, self).__init__()

        relu_slope = 0.01
        convolution_kernel = 3
        latent_space_external_width = 7
        latent_space_external_channels = 64
        latent_space_external_features = latent_space_external_channels * latent_space_external_width * latent_space_external_width
        self.latent_space_internal_features = 8

        self.encoder = torch.nn.Sequential(

            # images input: 28 X 28

            torch.nn.Conv2d(
                in_channels=1,
                out_channels=32,
                stride=1,
                kernel_size=convolution_kernel,
                padding=1,
            ),

            # channel size: 28 X 28

            torch.nn.LeakyReLU(relu_slope),

            torch.nn.Conv2d(
                in_channels=32,
                out_channels=64,
                stride=2,
                kernel_size=convolution_kernel,
                padding=1,
            ),

            # channel size: 14 X 14

            torch.nn.LeakyReLU(relu_slope),

            torch.nn.Conv2d(
                in_channels=64,
                out_channels=latent_space_external_channels,
                stride=2,
                kernel_size=convolution_kernel,
                padding=1,
            ),

            # channel size: 7 X 7

            torch.nn.Flatten(),
        )

        # two heads on top of the encoder: mean and log-variance of the latent distribution
        self.encoded_mean = torch.nn.Linear(
            in_features=latent_space_external_features,
            out_features=self.latent_space_internal_features,
        )

        self.encoded_log_variance = torch.nn.Linear(
            in_features=latent_space_external_features,
            out_features=self.latent_space_internal_features,
        )

        self.decoder = torch.nn.Sequential(

            torch.nn.Linear(
                in_features=self.latent_space_internal_features,
                out_features=latent_space_external_features,
            ),

            torch.nn.Unflatten(
                dim=1,
                unflattened_size=(
                    latent_space_external_channels,
                    latent_space_external_width,
                    latent_space_external_width,
                ),
            ),

            # channel size: 7 X 7

            torch.nn.ConvTranspose2d(
                in_channels=64,
                out_channels=64,
                stride=2,
                kernel_size=convolution_kernel,
                padding=1,
            ),

            # channel size: 13 X 13

            torch.nn.LeakyReLU(relu_slope),

            torch.nn.ConvTranspose2d(
                in_channels=64,
                out_channels=32,
                stride=2,
                kernel_size=convolution_kernel,
                padding=0,
            ),

            # channel size: 27 X 27

            torch.nn.LeakyReLU(relu_slope),

            torch.nn.ConvTranspose2d(
                in_channels=32,
                out_channels=1,
                stride=1,
                kernel_size=convolution_kernel,
                padding=0,
            ),

            # channel size: 29 X 29, trimmed back to 28 X 28

            Trim(IMAGE_WIDTH),

            torch.nn.Sigmoid(),
        )

    def forward(self, x):
        device = x.get_device()
        z = self.encoder(x)

        mean = self.encoded_mean(z)
        log_variance = self.encoded_log_variance(z)

        # reparameterization trick: sample epsilon ~ N(0, I), then shift and scale it
        epsilon = torch.randn(mean.shape)
        if device != -1:
            epsilon = epsilon.to(device)

        std = torch.exp(log_variance / 2.)
        encoded = mean + epsilon * std

        decoded = self.decoder(encoded)
        return encoded, mean, log_variance, decoded

    def generate_images(self, number_of_images):
        # sample directly from the latent prior N(0, I) on the model's device
        device = next(self.parameters()).device
        random_encoding = torch.randn((number_of_images, self.latent_space_internal_features), device=device)
        images = self.decoder(random_encoding)
        return images


class Trainer:

    def __init__(self):
        self.examples_amount = 10
        self.number_of_epochs = 3
        self.reconstruction_term_weight = 1
        batch_size = 32
        learning_rate = 0.0005

        self.loss_train_per_batch = []
        self.loss_train_per_epoch = []
        self.loss_test_per_epoch = []

        limit_size = None
        # limit_size = 13
        if limit_size is None:
            train_sampler = None
            shuffle = True
        else:
            train_indexes = torch.arange(limit_size)
            train_sampler = SubsetRandomSampler(train_indexes)
            shuffle = False

        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        print('using device:', self.device)

        transform_train = torchvision.transforms.Compose(
            [
                # torchvision.transforms.Resize(size=(IMAGE_WIDTH + 4, IMAGE_WIDTH + 4)),
                # torchvision.transforms.RandomCrop(size=(IMAGE_WIDTH, IMAGE_WIDTH)),
                # torchvision.transforms.RandomRotation(degrees=20),
                torchvision.transforms.ToTensor(),
            ]
        )

        transform_test = torchvision.transforms.Compose(
            [
                torchvision.transforms.ToTensor(),
            ]
        )

        dataset_train = datasets.MNIST(root='local_cache_folder',
                                       train=True,
                                       transform=transform_train,
                                       download=True)

        dataset_test = datasets.MNIST(root='local_cache_folder',
                                      train=False,
                                      transform=transform_test)

        print('train samples', dataset_train.data.shape[0])
        print('test samples', dataset_test.data.shape[0])

        self.loader_train = DataLoader(dataset=dataset_train,
                                       batch_size=batch_size,
                                       shuffle=shuffle,
                                       sampler=train_sampler,
                                       )

        self.loader_test = DataLoader(dataset=dataset_test,
                                      batch_size=batch_size,
                                      shuffle=False,
                                      )

        for images, labels in self.loader_train:
            print('single batch dimensions:', images.shape)
            print('single batch label dimensions:', labels.shape)
            self.number_of_features = images.shape[2] * images.shape[3]
            print('number of features', self.number_of_features)
            break

        self.model = VariationAutoEncoder()
        self.model = self.model.to(device=self.device)

        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=learning_rate)

        self.loss_function = torch.nn.functional.mse_loss

    def run_batches(self, data_loader, batch_callback=None, print_details=False):
        start_time = time.time()
        total_loss = 0
        total_samples = 0
        batches_in_epoch_count = 0
        number_of_batches = len(data_loader)
        for batch_index, (batch_x, _) in enumerate(data_loader):
            batches_in_epoch_count += 1
            batch_x = batch_x.to(device=self.device)
            batch_samples = batch_x.shape[0]
            total_samples += batch_samples
            encoded, mean, log_variance, decoded = self.model(batch_x)

            kl_div = -0.5 * torch.sum(1 + log_variance
                                      - mean ** 2
                                      - torch.exp(log_variance),
                                      axis=1)  # sum over latent dimension

            batchsize = kl_div.size(0)
            kl_div = kl_div.mean()  # average over batch dimension

            pixelwise = self.loss_function(decoded, batch_x, reduction='none')
            pixelwise = pixelwise.view(batchsize, -1).sum(axis=1)  # sum over pixels
            pixelwise = pixelwise.mean()  # average over batch dimension

            batch_loss = self.reconstruction_term_weight * pixelwise + kl_div

            if batch_callback is not None:
                batch_callback(batch_loss, batch_samples)
            total_loss += batch_loss

            if print_details:
                passed_seconds = time.time() - start_time
                if passed_seconds > 5:
                    start_time = time.time()
                    print('batch %05d/%05d loss %.5f' % (batch_index, number_of_batches, batch_loss.item()))

        average_loss = total_loss / total_samples

        return average_loss.cpu(), batches_in_epoch_count

    def batch_callback_train(self, batch_loss, batch_samples):
        self.loss_train_per_batch.append(batch_loss.item() / batch_samples)

        self.optimizer.zero_grad()
        batch_loss.backward()
        self.optimizer.step()

    def train_epoch(self, epoch_index):
        start_time = time.time()
        self.run_batches(data_loader=self.loader_train, batch_callback=self.batch_callback_train,
                         print_details=True)

        with torch.no_grad():
            self.model.eval()

            epoch_loss_train, batches_in_epoch_count = self.run_batches(
                data_loader=self.loader_train)

            epoch_loss_test, _ = self.run_batches(data_loader=self.loader_test)

            self.loss_train_per_epoch.append(epoch_loss_train)
            self.loss_test_per_epoch.append(epoch_loss_test)

        # switch back to training mode for the next epoch
        self.model.train()

        passed_seconds = time.time() - start_time
        print('epoch %05d/%05d, epoch duration %03.0f seconds, loss train %.5f, loss test %.5f' % (
            epoch_index + 1, self.number_of_epochs, passed_seconds, epoch_loss_train, epoch_loss_test))

        return batches_in_epoch_count

    def train(self):
        self.loss_train_per_batch = []
        self.loss_train_per_epoch = []
        self.loss_test_per_epoch = []
        self.model.train()

        start_time = time.time()
        batches_in_epoch_count = 0
        for epoch_index in range(self.number_of_epochs):
            batches_in_epoch_count = self.train_epoch(epoch_index)

        passed_seconds = time.time() - start_time
        print('total training duration %03.0f seconds' % passed_seconds)

        self.model.eval()

        self.loss_train_per_epoch = self.spread_points(self.loss_train_per_epoch, batches_in_epoch_count)
        self.loss_test_per_epoch = self.spread_points(self.loss_test_per_epoch, batches_in_epoch_count)
        plt.clf()
        plt.plot(self.loss_train_per_batch, color='b', label='train batch')
        plt.plot(self.loss_train_per_epoch, color='g', label='train epoch')
        plt.plot(self.loss_test_per_epoch, color='r', label='test')
        plt.legend()
        plt.ylim(0, self.loss_train_per_epoch[0])
        plt.ylabel('Loss')
        plt.xlabel('Epoch')
        plt.savefig("loss.pdf")
        self.loss_train_per_batch = []

    def save(self):
        torch.save(self.model.state_dict(), MODEL_PATH)

    def load(self):
        self.model = VariationAutoEncoder()
        self.model.load_state_dict(torch.load(MODEL_PATH))
        self.model = self.model.to(self.device)
        self.model.eval()

    def show_examples(self):
        with torch.no_grad():
            for _, (original_images, _) in enumerate(self.loader_test):
                original_images = original_images.to(device=self.device)
                original_images = original_images[:self.examples_amount]
                _, _, _, decoded_images = self.model(original_images)
                original_images = original_images.to('cpu')
                decoded_images = decoded_images.to('cpu')
                images_rows = [original_images, decoded_images]
                self.plot_images("decoded.pdf", images_rows)
                return

    def generate_examples(self):
        with torch.no_grad():
            images = self.model.generate_images(self.examples_amount)
            images = images.to('cpu')
            images_rows = [images]
            self.plot_images("generated.pdf", images_rows)

    @staticmethod
    def plot_images(file_name, images_rows):
        first_row = images_rows[0]
        number_of_columns = len(first_row)
        plt.clf()
        fig, axes = plt.subplots(
            nrows=len(images_rows),
            ncols=number_of_columns,
            sharex=True,
            sharey=True,
            figsize=(20, 2.5),
        )

        for row in range(len(images_rows)):
            ax_row = axes[row]
            images_row = images_rows[row]
            for col in range(len(first_row)):
                if len(images_rows) == 1:
                    # a single row of subplots is returned as a flat array of axes
                    ax = axes[col]
                else:
                    ax = ax_row[col]
                image = images_row[col]
                ax.imshow(image.view((IMAGE_WIDTH, IMAGE_WIDTH)), cmap='binary')
        plt.savefig(file_name)

    @staticmethod
    def spread_points(points, spread_factor):
        # repeat each per-epoch point so the epoch curves align with the per-batch curve
        result = []
        for point in points:
            for _ in range(spread_factor):
                result.append(point)
        return result


def set_seed():
    random_seed = 123

    if torch.cuda.is_available():
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
        torch.use_deterministic_algorithms(True)

    torch.manual_seed(random_seed)
    torch.cuda.manual_seed(random_seed)


def main():
    trainer = Trainer()
    trainer.train()
    trainer.save()
    trainer.load()
    trainer.show_examples()
    trainer.generate_examples()


if __name__ == '__main__':
    main()


