Full Blog TOC

Full Blog Table Of Content with Keywords Available HERE

Thursday, August 29, 2024

Using GAN to Generate Digits




In this post we present a simple GAN implementation to generate digits.

This post is based on the GAN explanation video by Sebastian Raschka.


GAN is an abbreviation for Generative Adversarial Network. It includes 2 networks: the generator and the discriminator. The generator generates new images, and the discriminator receives both real images and generated images.

The discriminator's goal is to classify whether the received image is real or generated, while the generator's goal is to fool the discriminator into classifying the generated images as real ones.

This implementation uses the MNIST database as input, and fully connected networks with a single hidden layer.

We can see the results of the generated images at the top of this post. These are not perfect, but they appear to be on the way there. More training, and possibly using convolutional networks, would likely yield better results.

The training was run for 50 epochs, and the loss graph is:





The related code implementation is:



import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
from torch.utils.data import DataLoader, SubsetRandomSampler
from torchvision import datasets

MODEL_PATH = "gan.model"


class GenerativeAdversarialNetwork(torch.nn.Module):
    """Generator/discriminator pair built from small fully connected networks.

    ``generator`` maps a noise vector of length ``noise_width`` to
    ``image_width * image_width`` values squashed into (0, 1) by a sigmoid.
    ``discriminator`` maps a flattened ``image_width x image_width`` image to
    a single raw (un-squashed) score.
    """

    def __init__(self):
        super().__init__()

        self.image_width = 28
        self.noise_width = 100
        hidden_generator = 128
        hidden_discriminator = 128
        pixel_count = self.image_width * self.image_width

        # The layer order is fixed so that the state_dict keys stay
        # generator.{1,3}.* and discriminator.{1,3}.*.
        generator_layers = [
            torch.nn.Flatten(),
            torch.nn.Linear(in_features=self.noise_width, out_features=hidden_generator),
            torch.nn.LeakyReLU(inplace=True),
            torch.nn.Linear(in_features=hidden_generator, out_features=pixel_count),
            torch.nn.Sigmoid(),
        ]
        self.generator = torch.nn.Sequential(*generator_layers)

        discriminator_layers = [
            torch.nn.Flatten(start_dim=1),
            torch.nn.Linear(in_features=pixel_count, out_features=hidden_discriminator),
            torch.nn.LeakyReLU(inplace=True),
            # There is no final activation here: the output is a raw logit.
            torch.nn.Linear(in_features=hidden_discriminator, out_features=1),
        ]
        self.discriminator = torch.nn.Sequential(*discriminator_layers)

    def forward_generator(self, x):
        """Turn noise ``x`` into images of shape (batch, 1, image_width, image_width)."""
        flat_pixels = self.generator(x)
        sample_count = flat_pixels.size(0)
        return flat_pixels.view(sample_count, 1, self.image_width, self.image_width)

    def forward_discriminator(self, x):
        """Score images ``x``; returns one raw logit per sample."""
        return self.discriminator(x)


class Trainer:
    """Train a GenerativeAdversarialNetwork on MNIST.

    The generator and the discriminator each have their own Adam optimizer.
    The discriminator is trained with BCE-with-logits against real=1 / generated=0
    targets. The generator is trained to make the discriminator output "real"
    on generated samples. Per-batch and per-epoch losses are recorded for plotting.
    """

    def __init__(self):
        self.number_of_epochs = 50
        batch_size = 32
        learning_rate = 0.0002

        self.loss_generator_per_batch = []
        self.loss_generator_per_epoch = []
        self.loss_discriminator_per_batch = []
        self.loss_discriminator_per_epoch = []

        # Set limit_size to a small integer to train on a fixed subset (debugging).
        limit_size = None
        # limit_size = 13
        if limit_size is None:
            train_sampler = None
            shuffle = True
        else:
            train_indexes = torch.arange(limit_size)
            train_sampler = SubsetRandomSampler(train_indexes)
            # DataLoader does not accept shuffle=True together with a sampler.
            shuffle = False

        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        print('using device:', self.device)

        transform_train = torchvision.transforms.Compose(
            [
                torchvision.transforms.ToTensor(),
            ]
        )

        dataset_train = datasets.MNIST(root='local_cache_folder',
                                       train=True,
                                       transform=transform_train,
                                       download=True)

        print('train samples', dataset_train.data.shape[0])

        self.loader_train = DataLoader(dataset=dataset_train,
                                       batch_size=batch_size,
                                       shuffle=shuffle,
                                       sampler=train_sampler,
                                       )

        # Peek at one batch only to report tensor shapes.
        for images, labels in self.loader_train:
            print('single batch dimensions:', images.shape)
            print('single batch label dimensions:', labels.shape)
            self.number_of_features = images.shape[2] * images.shape[3]
            print('number of features', self.number_of_features)
            break

        self.model = GenerativeAdversarialNetwork()
        self.model = self.model.to(device=self.device)

        self.optimizer_generator = torch.optim.Adam(self.model.generator.parameters(), lr=learning_rate)
        self.optimizer_discriminator = torch.optim.Adam(self.model.discriminator.parameters(), lr=learning_rate)

        # The discriminator returns raw logits, so the sigmoid is fused into the loss.
        self.loss_function = torch.nn.functional.binary_cross_entropy_with_logits

    def train_single_batch(self, batch_images):
        """Run one discriminator step and then one generator step.

        Args:
            batch_images: a batch of real images.

        Returns:
            (discriminator_loss, generator_loss) as Python floats.
        """
        batch_images = batch_images.to(device=self.device)
        batch_number_of_samples = batch_images.shape[0]

        # Targets and noise must be on the same device as the model.
        ones = torch.ones(batch_number_of_samples, device=self.device)
        zeros = torch.zeros(batch_number_of_samples, device=self.device)
        random_noise = torch.randn(batch_number_of_samples, self.model.noise_width).to(self.device)
        generated_images = self.model.forward_generator(random_noise)

        # --- train discriminator ---
        self.optimizer_discriminator.zero_grad()

        # view(-1) rather than squeeze(): a batch of one sample stays 1-D, so it
        # still matches the (batch,) target shape that the loss requires.
        predictions_real = self.model.forward_discriminator(batch_images).view(-1)
        loss_real = self.loss_function(predictions_real, ones)

        # detach(): this step updates only the discriminator, so no gradient should
        # flow back into the generator.
        predictions_generated = self.model.forward_discriminator(generated_images.detach()).view(-1)
        loss_generated = self.loss_function(predictions_generated, zeros)

        discriminator_batch_loss = (loss_real + loss_generated) / 2
        self.loss_discriminator_per_batch.append(discriminator_batch_loss.item())

        discriminator_batch_loss.backward()
        self.optimizer_discriminator.step()

        # --- train generator ---
        self.optimizer_generator.zero_grad()
        # Re-score the generated images with the graph kept through the generator.
        # This backward also fills the discriminator's .grad; those gradients are
        # cleared by optimizer_discriminator.zero_grad() on the next call.
        predictions_generated = self.model.forward_discriminator(generated_images).view(-1)
        generator_batch_loss = self.loss_function(predictions_generated, ones)
        self.loss_generator_per_batch.append(generator_batch_loss.item())
        generator_batch_loss.backward()
        self.optimizer_generator.step()

        return discriminator_batch_loss.item(), generator_batch_loss.item()

    def train_epoch(self, epoch_index):
        """Make one pass over ``loader_train`` and record the mean losses of the epoch.

        Returns:
            The number of batches processed in this epoch.
        """
        epoch_start_time = time.time()
        last_log_time = time.time()
        discriminator_epoch_loss = 0
        generator_epoch_loss = 0
        number_of_batches_in_epoch = 0
        number_of_batches = len(self.loader_train)
        for batch_index, (batch_images, _) in enumerate(self.loader_train):
            number_of_batches_in_epoch += 1
            # train_single_batch already returns plain floats.
            discriminator_batch_loss, generator_batch_loss = self.train_single_batch(batch_images)
            discriminator_epoch_loss += discriminator_batch_loss
            generator_epoch_loss += generator_batch_loss

            passed_seconds = time.time() - last_log_time
            if passed_seconds > 5:
                last_log_time = time.time()
                print('batch {:05d}/{:05d} discriminator loss {:.5f} generator loss {:.5f}'.format(
                    batch_index,
                    number_of_batches,
                    discriminator_batch_loss,
                    generator_batch_loss,
                ))

        discriminator_epoch_loss = discriminator_epoch_loss / number_of_batches_in_epoch
        self.loss_discriminator_per_epoch.append(discriminator_epoch_loss)
        generator_epoch_loss = generator_epoch_loss / number_of_batches_in_epoch
        self.loss_generator_per_epoch.append(generator_epoch_loss)

        epoch_seconds = time.time() - epoch_start_time
        print(
            'epoch {:05d}/{:05d}, duration {:03.0f} seconds, discriminator loss {:.5f} generator loss {:.5f}'.format(
                epoch_index + 1,
                self.number_of_epochs,
                epoch_seconds,
                discriminator_epoch_loss,
                generator_epoch_loss,
            ))

        return number_of_batches_in_epoch

    def show_examples(self, noise, file_name):
        """Generate images from ``noise`` and save them as a 10-column grid to ``file_name``."""
        import os

        directory = os.path.dirname(file_name)
        if directory:
            # savefig does not create missing directories (e.g. 'output/').
            os.makedirs(directory, exist_ok=True)

        generated_images = self.model.forward_generator(noise.to(self.device))
        combined_image = torchvision.utils.make_grid(generated_images, padding=2, nrow=10)

        # tensor CHW -> plt HWC; matplotlib needs a detached CPU array.
        combined_transpose = combined_image.detach().cpu().permute(1, 2, 0).numpy()

        figure = plt.figure(figsize=(8, 8))
        plt.axis('off')
        plt.imshow(combined_transpose)
        plt.savefig(file_name)
        # This runs every epoch; figures that are never closed pile up in memory.
        plt.close(figure)

    def plot_loss_graph(self, number_of_batches_in_epoch):
        """Save per-batch and per-epoch generator/discriminator losses to loss.pdf.

        Each per-epoch value is repeated ``number_of_batches_in_epoch`` times so it
        can be drawn on the same x axis as the per-batch values.
        """
        loss_discriminator_per_epoch = self.spread_points(self.loss_discriminator_per_epoch, number_of_batches_in_epoch)
        loss_generator_per_epoch = self.spread_points(self.loss_generator_per_epoch, number_of_batches_in_epoch)

        plt.clf()
        plt.plot(self.loss_generator_per_batch, color='y', label='generator batch')
        plt.plot(loss_generator_per_epoch, color='r', label='generator epoch')
        plt.plot(self.loss_discriminator_per_batch, color='c', label='discriminator batch')
        plt.plot(loss_discriminator_per_epoch, color='b', label='discriminator epoch')
        plt.legend()
        plt.ylabel('Loss')
        plt.xlabel('Epoch')
        plt.savefig("loss.pdf")
        plt.close()

    def train(self):
        """Train for ``number_of_epochs``.

        After each epoch, save the loss graph and a grid of samples generated
        from fixed noise.
        """
        self.loss_generator_per_batch = []
        self.loss_generator_per_epoch = []
        self.loss_discriminator_per_batch = []
        self.loss_discriminator_per_epoch = []
        self.model.train()

        start_time = time.time()
        # Fixed noise keeps the per-epoch sample grids comparable. It is drawn on
        # the CPU, and show_examples moves it to the model's device.
        static_noise = torch.randn(100, self.model.noise_width)

        for epoch_index in range(self.number_of_epochs):
            number_of_batches_in_epoch = self.train_epoch(epoch_index)
            self.plot_loss_graph(number_of_batches_in_epoch)
            with torch.no_grad():
                file_name = 'output/generated_epoch_{:03d}.pdf'.format(epoch_index)
                self.show_examples(static_noise, file_name)

        passed_seconds = time.time() - start_time
        print('total training duration %03.0f seconds' % passed_seconds)

        self.model.eval()

    def save(self):
        """Write the model's state_dict to MODEL_PATH."""
        torch.save(self.model.state_dict(), MODEL_PATH)

    def load(self):
        """Load the weights in MODEL_PATH into the current model and switch it to eval mode."""
        # Load in place so the existing optimizers still reference the live
        # parameters. map_location lets a GPU checkpoint load on a CPU-only machine.
        state = torch.load(MODEL_PATH, map_location=self.device)
        self.model.load_state_dict(state)
        self.model = self.model.to(self.device)
        self.model.eval()

    @staticmethod
    def spread_points(points, spread_factor):
        """Return ``points`` with each element repeated ``spread_factor`` times in order."""
        return [point for point in points for _ in range(spread_factor)]


def set_seed():
    """Seed torch's CPU and CUDA RNGs.

    When CUDA is available, also request deterministic kernels.
    """
    import os

    random_seed = 123

    if torch.cuda.is_available():
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
        # With deterministic algorithms enabled, cuBLAS (used by the Linear layers)
        # raises unless this workspace setting is present. It must be set before
        # the first cuBLAS call.
        os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
        torch.use_deterministic_algorithms(True)

    torch.manual_seed(random_seed)
    torch.cuda.manual_seed(random_seed)


def main():
    """Seed the RNGs and train the GAN. Saving and loading are left disabled."""
    set_seed()
    trainer = Trainer()
    trainer.train()
    # trainer.save()
    # trainer.load()


# Guard the entry point so importing this module does not start training.
if __name__ == "__main__":
    main()

 




Tuesday, August 27, 2024

Update Image Features Using Variational AutoEncoder



 

In the Variational AutoEncoder post, we created a variational autoencoder that can generate new images.

In this post, we use the encodings and labels to steer images toward a specific digit: we take a specific digit and slowly morph it into another digit.

The implementation is the following:

  1. Target_Encoding =  average encoding for all images of digit X
  2. Non_Target_Encoding = average encoding for all images of other digits
  3. Convert_Vector = Target_Encoding - Non_Target_Encoding

Next, we can take any source image and update it to be similar to digit X by factor alpha:

  1. Encode the image
  2. Add alpha*Convert_Vector to the encoding
  3. Decode


The image at the top displays the gradual conversion of each digit to the target digit '1'.


The implementation is done by adding the following method to the trainer class in the Variational AutoEncoder post.



def variation_examples(self, variation_label):
    """Morph one sample of every digit toward ``variation_label`` in latent space.

    Intended as a method of the variational-autoencoder Trainer. The method:

    1. Encodes all of ``self.loader_train``. It averages the encodings of samples
       labelled ``variation_label`` and, separately, of all other samples.
    2. Saves the decoded two averages to ``variation_base_<label>.pdf``.
    3. Keeps, for each label 0-9, the last image seen with that label. It adds
       ``0.5 * k * (matching_average - non_matching_average)`` to that image's
       encoding for k = 0..9, decodes each result, and saves one row per label to
       ``variation_convert_<label>.pdf``.

    Raises:
        ValueError: if the loader yields no sample with ``variation_label``, or
            none without it (one of the averages would be undefined).
    """
    matching_encoding_samples = 0
    non_matching_encoding_samples = 0

    latent_features = self.model.latent_space_internal_features
    # Keep the accumulators on the model's device so the in-place += works on GPU.
    matching_encoding_sum = torch.zeros(latent_features, dtype=torch.float, device=self.device)
    non_matching_encoding_sum = torch.zeros(latent_features, dtype=torch.float, device=self.device)

    start_time = time.time()

    source_image_by_label = {}

    # Inference only: without no_grad, every batch would extend one autograd
    # graph through the running sums for the whole dataset.
    with torch.no_grad():
        for batch_index, (batch_x, batch_y) in enumerate(self.loader_train):
            batch_x = batch_x.to(device=self.device)
            encoded_converted, _, _, _ = self.model(batch_x)

            # Remember the last-seen image for each label as a conversion source.
            for batch_sample in range(batch_x.shape[0]):
                label = batch_y[batch_sample].item()
                source_image_by_label[label] = batch_x[batch_sample]

            matching_label_ids = (batch_y == variation_label).to(encoded_converted.device)
            non_matching_label_ids = ~matching_label_ids

            matching_encoding = encoded_converted[matching_label_ids]
            non_matching_encoding = encoded_converted[non_matching_label_ids]

            matching_encoding_sum += torch.sum(matching_encoding, dim=0)
            non_matching_encoding_sum += torch.sum(non_matching_encoding, dim=0)

            # Count the selected rows. The mask length would be the whole batch size.
            matching_encoding_samples += matching_encoding.shape[0]
            non_matching_encoding_samples += non_matching_encoding.shape[0]
            passed_seconds = time.time() - start_time
            if passed_seconds > 5:
                start_time = time.time()
                print('batch %05d/%05d' % (batch_index, len(self.loader_train)))

        if matching_encoding_samples == 0 or non_matching_encoding_samples == 0:
            raise ValueError(
                'variation_examples needs samples both with and without label {}'.format(variation_label))

        matching_encoding_average = matching_encoding_sum / matching_encoding_samples
        non_matching_encoding_average = non_matching_encoding_sum / non_matching_encoding_samples

        matching_encodings = torch.unsqueeze(matching_encoding_average, dim=0)
        non_matching_encodings = torch.unsqueeze(non_matching_encoding_average, dim=0)

        matching_image = self.model.decoder(matching_encodings)
        non_matching_image = self.model.decoder(non_matching_encodings)

        image_row = [matching_image, non_matching_image]
        images_rows = [image_row]
        self.plot_images("variation_base_{}.pdf".format(variation_label), images_rows)

        images_rows = []
        # Latent direction that points from "other digits" toward the target digit.
        convert_vector = matching_encoding_average - non_matching_encoding_average
        for label in range(10):
            source_image = source_image_by_label[label]
            source_images = torch.unsqueeze(source_image, dim=0)
            label_row = [source_image]
            encoded_source, _, _, _ = self.model(source_images)
            for convert_factor in range(10):
                encoded_converted = encoded_source + convert_vector * convert_factor * 0.5
                decoded_items = self.model.decoder(encoded_converted)
                label_row.append(decoded_items[0])
            images_rows.append(label_row)
        self.plot_images("variation_convert_{}.pdf".format(variation_label), images_rows)


Monday, August 26, 2024

Variational AutoEncoder

 

In this post we present a Variational AutoEncoder. This is a modified version of the AutoEncoder presented in a previous post.

This work is based on the class lesson of Sebastian Raschka.


We can see the training loss graph:




Encoding and decoding from actual samples:



The variational autoencoder also adds a new capability: generating brand-new images by sampling the latent space from a normal distribution:





And the actual implementation:


import time

import matplotlib.pyplot as plt
import torch
import torchvision
from torch.utils.data import DataLoader, SubsetRandomSampler
from torchvision import datasets

MODEL_PATH = "auto-encoder.model"
IMAGE_WIDTH = 28


class Trim(torch.nn.Module):
    """Crop dimensions 2 and 3 of a tensor to at most ``size`` entries each.

    The first positional constructor argument is the target size; any other
    positional arguments are ignored.
    """

    def __init__(self, *args):
        super().__init__()
        self.size = args[0]

    def forward(self, x):
        keep = slice(None, self.size)
        return x[:, :, keep, keep]


class VariationAutoEncoder(torch.nn.Module):
    """Convolutional variational autoencoder for 1x28x28 images.

    The encoder reduces an image to a flattened 64x7x7 feature map. Two linear
    heads turn it into the mean and log-variance of an 8-dimensional latent
    Gaussian. The decoder maps a latent vector back to an image in (0, 1)
    (sigmoid output).
    """

    def __init__(self):
        super(VariationAutoEncoder, self).__init__()

        relu_slope = 0.01
        convolution_kernel = 3
        latent_space_external_width = 7
        latent_space_external_channels = 64
        latent_space_external_features = (
            latent_space_external_channels * latent_space_external_width * latent_space_external_width
        )
        self.latent_space_internal_features = 8

        self.encoder = torch.nn.Sequential(
            # images input: 1 X 28 X 28
            torch.nn.Conv2d(
                in_channels=1,
                out_channels=32,
                stride=1,
                kernel_size=convolution_kernel,
                padding=1,
            ),
            # channel size: 28 X 28
            torch.nn.LeakyReLU(relu_slope),
            torch.nn.Conv2d(
                in_channels=32,
                out_channels=64,
                stride=2,
                kernel_size=convolution_kernel,
                padding=1,
            ),
            # channel size: 14 X 14
            torch.nn.LeakyReLU(relu_slope),
            torch.nn.Conv2d(
                in_channels=64,
                out_channels=latent_space_external_channels,
                stride=2,
                kernel_size=convolution_kernel,
                padding=1,
            ),
            # channel size: 7 X 7
            torch.nn.Flatten(),
        )

        self.encoded_mean = torch.nn.Linear(
            in_features=latent_space_external_features,
            out_features=self.latent_space_internal_features,
        )

        self.encoded_log_variance = torch.nn.Linear(
            in_features=latent_space_external_features,
            out_features=self.latent_space_internal_features,
        )

        self.decoder = torch.nn.Sequential(
            torch.nn.Linear(
                in_features=self.latent_space_internal_features,
                out_features=latent_space_external_features,
            ),
            torch.nn.Unflatten(
                dim=1,
                unflattened_size=(
                    latent_space_external_channels,
                    latent_space_external_width,
                    latent_space_external_width,
                ),
            ),
            torch.nn.ConvTranspose2d(
                in_channels=latent_space_external_channels,
                out_channels=64,
                stride=2,
                kernel_size=convolution_kernel,
                padding=1,
            ),
            # channel size: 13 X 13
            torch.nn.LeakyReLU(relu_slope),
            torch.nn.ConvTranspose2d(
                in_channels=64,
                out_channels=32,
                stride=2,
                kernel_size=convolution_kernel,
                padding=0,
            ),
            # channel size: 27 X 27
            torch.nn.LeakyReLU(relu_slope),
            torch.nn.ConvTranspose2d(
                in_channels=32,
                out_channels=1,
                stride=1,
                kernel_size=convolution_kernel,
                padding=0,
            ),
            # channel size: 29 X 29 -> cropped back to the image width
            Trim(IMAGE_WIDTH),
            torch.nn.Sigmoid(),
        )

    def forward(self, x):
        """Encode ``x``, sample a latent vector with the reparameterization trick, and decode it.

        Returns:
            (encoded, mean, log_variance, decoded): the sampled latent vectors, the
            Gaussian parameters, and the reconstructed images.
        """
        z = self.encoder(x)

        mean = self.encoded_mean(z)
        log_variance = self.encoded_log_variance(z)

        # Noise is drawn from the CPU generator (as before) and then moved to mean's
        # device, so seeded runs reproduce the same samples.
        epsilon = torch.randn(mean.shape, dtype=mean.dtype).to(mean.device)

        std = torch.exp(log_variance / 2.)
        encoded = mean + epsilon * std

        decoded = self.decoder(encoded)
        return encoded, mean, log_variance, decoded

    def generate_images(self, number_of_images):
        """Decode ``number_of_images`` latent vectors drawn from a standard normal distribution."""
        # Match the decoder's parameters. A plain CPU float32 tensor would fail once
        # the model is on the GPU or has been converted to another dtype.
        reference = next(self.decoder.parameters())
        random_encoding = torch.randn(
            (number_of_images, self.latent_space_internal_features), dtype=reference.dtype
        ).to(reference.device)
        images = self.decoder(random_encoding)
        return images


class Trainer:
    """Train a VariationAutoEncoder on MNIST, then plot losses, reconstructions and samples."""

    def __init__(self):
        self.examples_amount = 10
        self.number_of_epochs = 3
        # Weight of the reconstruction term relative to the KL term in the loss.
        self.reconstruction_term_weight = 1
        batch_size = 32
        learning_rate = 0.0005

        self.loss_train_per_batch = []
        self.loss_train_per_epoch = []
        self.loss_test_per_epoch = []

        # Set limit_size to a small integer to train on a fixed subset (debugging).
        limit_size = None
        # limit_size = 13
        if limit_size is None:
            train_sampler = None
            shuffle = True
        else:
            train_indexes = torch.arange(limit_size)
            train_sampler = SubsetRandomSampler(train_indexes)
            # DataLoader does not accept shuffle=True together with a sampler.
            shuffle = False

        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        print('using device:', self.device)

        transform_train = torchvision.transforms.Compose(
            [
                torchvision.transforms.ToTensor(),
            ]
        )

        transform_test = torchvision.transforms.Compose(
            [
                torchvision.transforms.ToTensor(),
            ]
        )

        dataset_train = datasets.MNIST(root='local_cache_folder',
                                       train=True,
                                       transform=transform_train,
                                       download=True)

        dataset_test = datasets.MNIST(root='local_cache_folder',
                                      train=False,
                                      transform=transform_test)

        print('train samples', dataset_train.data.shape[0])
        print('test samples', dataset_test.data.shape[0])

        self.loader_train = DataLoader(dataset=dataset_train,
                                       batch_size=batch_size,
                                       shuffle=shuffle,
                                       sampler=train_sampler,
                                       )

        self.loader_test = DataLoader(dataset=dataset_test,
                                      batch_size=batch_size,
                                      shuffle=False,
                                      )

        # Peek at one batch only to report tensor shapes.
        for images, labels in self.loader_train:
            print('single batch dimensions:', images.shape)
            print('single batch label dimensions:', labels.shape)
            self.number_of_features = images.shape[2] * images.shape[3]
            print('number of features', self.number_of_features)
            break

        self.model = VariationAutoEncoder()
        self.model = self.model.to(device=self.device)

        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=learning_rate)

        self.loss_function = torch.nn.functional.mse_loss

    def run_batches(self, data_loader, batch_callback=None, print_details=False):
        """Compute the VAE loss for every batch in ``data_loader``.

        Per-batch loss = reconstruction_term_weight * reconstruction + KL. The
        reconstruction term is each image's sum of squared pixel errors; the KL
        term is summed over latent dimensions. Both are averaged over the batch.

        Args:
            data_loader: iterable of (images, labels) batches.
            batch_callback: optional callable(batch_loss, batch_samples), e.g. an
                optimizer step.
            print_details: print progress at most every 5 seconds.

        Returns:
            (average_loss, number_of_batches). average_loss is the sum of per-batch
            losses divided by the total sample count, as a 0-d CPU tensor.
        """
        start_time = time.time()
        total_loss = 0.0
        total_samples = 0
        batches_in_epoch_count = 0
        number_of_batches = len(data_loader)
        for batch_index, (batch_x, _) in enumerate(data_loader):
            batches_in_epoch_count += 1
            batch_x = batch_x.to(device=self.device)
            batch_samples = batch_x.shape[0]
            total_samples += batch_samples
            _, mean, log_variance, decoded = self.model(batch_x)

            kl_div = -0.5 * torch.sum(1 + log_variance
                                      - mean ** 2
                                      - torch.exp(log_variance),
                                      dim=1)  # sum over latent dimension
            kl_div = kl_div.mean()  # average over batch dimension

            pixelwise = self.loss_function(decoded, batch_x, reduction='none')
            pixelwise = pixelwise.reshape(batch_samples, -1).sum(dim=1)  # sum over pixels
            pixelwise = pixelwise.mean()  # average over batch dimension

            batch_loss = self.reconstruction_term_weight * pixelwise + kl_div

            if batch_callback is not None:
                batch_callback(batch_loss, batch_samples)
            # Accumulate a plain float. Adding the tensor would chain every batch's
            # autograd graph onto total_loss for the whole epoch.
            total_loss += batch_loss.item()

            if print_details:
                passed_seconds = time.time() - start_time
                if passed_seconds > 5:
                    start_time = time.time()
                    print('batch %05d/%05d loss %.5f' % (batch_index, number_of_batches, batch_loss.item()))

        average_loss = total_loss / total_samples

        return torch.tensor(average_loss), batches_in_epoch_count

    def batch_callback_train(self, batch_loss, batch_samples):
        """Record the batch loss divided by the batch size, then take one optimizer step."""
        self.loss_train_per_batch.append(batch_loss.item() / batch_samples)

        self.optimizer.zero_grad()
        batch_loss.backward()
        self.optimizer.step()

    def train_epoch(self, epoch_index):
        """Train for one epoch, then re-evaluate the train and test losses without gradients.

        Returns:
            The number of batches in the training loader.
        """
        start_time = time.time()
        # The evaluation at the end of the previous epoch left the model in eval mode.
        self.model.train()
        self.run_batches(data_loader=self.loader_train, batch_callback=self.batch_callback_train,
                         print_details=True)

        with torch.no_grad():
            self.model.eval()

            epoch_loss_train, batches_in_epoch_count = self.run_batches(
                data_loader=self.loader_train)

            epoch_loss_test, _ = self.run_batches(data_loader=self.loader_test)

        self.loss_train_per_epoch.append(epoch_loss_train)
        self.loss_test_per_epoch.append(epoch_loss_test)

        passed_seconds = time.time() - start_time
        print('epoch %05d/%05d, batch duration %03.0f seconds, loss train %.5f, loss test %.5f' % (
            epoch_index + 1, self.number_of_epochs, passed_seconds, epoch_loss_train, epoch_loss_test))

        return batches_in_epoch_count

    def train(self):
        """Train for ``number_of_epochs`` and save the loss graph to loss.pdf."""
        self.loss_train_per_batch = []
        self.loss_train_per_epoch = []
        self.loss_test_per_epoch = []
        self.model.train()

        start_time = time.time()
        batches_in_epoch_count = 0
        for epoch_index in range(self.number_of_epochs):
            batches_in_epoch_count = self.train_epoch(epoch_index)

        passed_seconds = time.time() - start_time
        print('total training duration %03.0f seconds' % passed_seconds)

        self.model.eval()

        # Repeat each per-epoch value so it can share the per-batch x axis.
        self.loss_train_per_epoch = self.spread_points(self.loss_train_per_epoch, batches_in_epoch_count)
        self.loss_test_per_epoch = self.spread_points(self.loss_test_per_epoch, batches_in_epoch_count)
        plt.clf()
        plt.plot(self.loss_train_per_batch, color='b', label='train batch')
        plt.plot(self.loss_train_per_epoch, color='g', label='train epoch')
        plt.plot(self.loss_test_per_epoch, color='r', label='test')
        plt.legend()
        plt.ylim(0, self.loss_train_per_epoch[0])
        plt.ylabel('Loss')
        plt.xlabel('Epoch')
        plt.savefig("loss.pdf")
        plt.close()
        self.loss_train_per_batch = []

    def save(self):
        """Write the model's state_dict to MODEL_PATH."""
        torch.save(self.model.state_dict(), MODEL_PATH)

    def load(self):
        """Load the weights in MODEL_PATH into the current model and switch it to eval mode."""
        # Load in place so self.optimizer still references the live parameters.
        # map_location lets a GPU checkpoint load on a CPU-only machine.
        state = torch.load(MODEL_PATH, map_location=self.device)
        self.model.load_state_dict(state)
        self.model = self.model.to(self.device)
        self.model.eval()

    def show_examples(self):
        """Save the first ``examples_amount`` test images and their reconstructions to decoded.pdf."""
        with torch.no_grad():
            for original_images, _ in self.loader_test:
                original_images = original_images.to(device=self.device)
                original_images = original_images[:self.examples_amount]
                _, _, _, decoded_images = self.model(original_images)
                original_images = original_images.to('cpu')
                decoded_images = decoded_images.to('cpu')
                images_rows = [original_images, decoded_images]
                self.plot_images("decoded.pdf", images_rows)
                # Only the first test batch is used.
                return

    def generate_examples(self):
        """Decode ``examples_amount`` random latent vectors and save them to generated.pdf."""
        with torch.no_grad():
            images = self.model.generate_images(self.examples_amount).cpu()
            images_rows = [images]
            self.plot_images("generated.pdf", images_rows)

    @staticmethod
    def plot_images(file_name, images_rows):
        """Save a grid of IMAGE_WIDTH x IMAGE_WIDTH grayscale images, one subplot row per entry.

        The number of columns is taken from the first row.
        """
        number_of_rows = len(images_rows)
        number_of_columns = len(images_rows[0])
        # squeeze=False always returns a 2-D axes array, even for 1 row or 1 column.
        fig, axes = plt.subplots(
            nrows=number_of_rows,
            ncols=number_of_columns,
            sharex=True,
            sharey=True,
            figsize=(20, 2.5),
            squeeze=False,
        )

        for row, images_row in enumerate(images_rows):
            for col in range(number_of_columns):
                # matplotlib cannot read GPU tensors or tensors that require grad.
                image = images_row[col].detach().cpu()
                axes[row][col].imshow(image.reshape((IMAGE_WIDTH, IMAGE_WIDTH)), cmap='binary')
        fig.savefig(file_name)
        plt.close(fig)

    @staticmethod
    def spread_points(points, spread_factor):
        """Return ``points`` with each element repeated ``spread_factor`` times in order."""
        return [point for point in points for _ in range(spread_factor)]


def set_seed():
    """Seed torch's CPU and CUDA RNGs.

    When CUDA is available, also request deterministic kernels.
    """
    import os

    random_seed = 123

    if torch.cuda.is_available():
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
        # torch.set_deterministic() was removed from PyTorch in favour of
        # use_deterministic_algorithms(). Deterministic cuBLAS also needs this
        # workspace setting before the first cuBLAS call.
        os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
        torch.use_deterministic_algorithms(True)

    torch.manual_seed(random_seed)
    torch.cuda.manual_seed(random_seed)


def main():
    """Train the VAE, save and reload it, then plot reconstructions and random samples."""
    trainer = Trainer()
    trainer.train()
    trainer.save()
    trainer.load()
    trainer.show_examples()
    trainer.generate_examples()


# Guard the entry point so importing this module does not start training.
if __name__ == "__main__":
    main()



AutoEncoder

 



In this post we present an AutoEncoder implementation in PyTorch.
This is based on the Convolutional Autoencoder code by Sebastian Raschka.


We use the MNIST image dataset and encode each grayscale 28x28 digit image using only 2 numbers!

At the top of this post we have an image of the original digits images and the related decoded images. It is not perfect, but very close. We can get better performance by additional training (but I don't have a GPU).


Something to notice is that shuffling the training data loader seems to make the results worse. This again might be because more training is required.

Also, removing the duplicate 64-channel convolution layers seems to give similar results.



import time

import matplotlib.pyplot as plt
import torch
import torchvision
from torch.utils.data import DataLoader, SubsetRandomSampler
from torchvision import datasets

MODEL_PATH = "auto-encoder.model"
IMAGE_WIDTH = 28


class Trim(torch.nn.Module):
    """Crop the third and fourth dimensions of a tensor to at most ``size`` entries each.

    ``size`` comes from the first positional constructor argument; any further
    positional arguments are accepted and ignored.
    """

    def __init__(self, *args):
        super().__init__()
        self.size = args[0]

    def forward(self, x):
        window = slice(None, self.size)
        return x[:, :, window, window]


class AutoEncoder(torch.nn.Module):
    """Convolutional autoencoder that compresses a 1x28x28 image into 2 numbers.

    The encoder ends in a Linear layer with 2 outputs. The decoder expands those
    2 numbers back into a 64x7x7 map, upsamples it with transposed convolutions,
    crops to IMAGE_WIDTH, and applies a sigmoid.
    """

    def __init__(self):
        super().__init__()

        slope = 0.01

        def conv(in_ch, out_ch, stride):
            # 3x3 convolution with padding 1.
            return torch.nn.Conv2d(
                in_channels=in_ch,
                out_channels=out_ch,
                stride=stride,
                kernel_size=3,
                padding=1,
            )

        def deconv(in_ch, out_ch, stride, pad):
            # 3x3 transposed convolution.
            return torch.nn.ConvTranspose2d(
                in_channels=in_ch,
                out_channels=out_ch,
                stride=stride,
                kernel_size=3,
                padding=pad,
            )

        # Layers are built in the same order as before, so the Sequential indices
        # (and therefore the state_dict keys) are unchanged.
        self.encoder = torch.nn.Sequential(
            conv(1, 32, 1),  # 32 x 28 x 28
            torch.nn.LeakyReLU(slope),
            conv(32, 64, 2),  # 64 x 14 x 14
            torch.nn.LeakyReLU(slope),
            conv(64, 64, 2),  # 64 x 7 x 7
            torch.nn.LeakyReLU(slope),
            conv(64, 64, 1),  # 64 x 7 x 7
            torch.nn.Flatten(),  # 3136 = 64 * 7 * 7
            torch.nn.Linear(in_features=3136, out_features=2),
        )

        self.decoder = torch.nn.Sequential(
            torch.nn.Linear(in_features=2, out_features=3136),
            torch.nn.Unflatten(dim=1, unflattened_size=(64, 7, 7)),
            deconv(64, 64, 1, 1),  # 64 x 7 x 7
            torch.nn.LeakyReLU(slope),
            deconv(64, 64, 2, 1),  # 64 x 13 x 13
            torch.nn.LeakyReLU(slope),
            deconv(64, 32, 2, 0),  # 32 x 27 x 27
            torch.nn.LeakyReLU(slope),
            deconv(32, 1, 1, 0),  # 1 x 29 x 29
            Trim(IMAGE_WIDTH),
            torch.nn.Sigmoid(),
        )

    def forward(self, x):
        """Encode ``x`` to its 2-number code and return the decoded reconstruction."""
        return self.decoder(self.encoder(x))


class Trainer:
    """Train an AutoEncoder on MNIST with MSE reconstruction loss, then plot losses and examples."""

    def __init__(self):
        self.number_of_epochs = 5
        batch_size = 32
        learning_rate = 0.0005

        self.loss_train_per_batch = []
        self.loss_train_per_epoch = []
        self.loss_test_per_epoch = []

        # Set limit_size to a small integer to train on a fixed subset (debugging).
        limit_size = None
        # limit_size = 13
        if limit_size is None:
            train_sampler = None
            shuffle = True
        else:
            train_indexes = torch.arange(limit_size)
            train_sampler = SubsetRandomSampler(train_indexes)
            # DataLoader does not accept shuffle=True together with a sampler.
            shuffle = False

        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        print('using device:', self.device)

        transform_train = torchvision.transforms.Compose(
            [
                torchvision.transforms.ToTensor(),
            ]
        )

        transform_test = torchvision.transforms.Compose(
            [
                torchvision.transforms.ToTensor(),
            ]
        )

        dataset_train = datasets.MNIST(root='local_cache_folder',
                                       train=True,
                                       transform=transform_train,
                                       download=True)

        dataset_test = datasets.MNIST(root='local_cache_folder',
                                      train=False,
                                      transform=transform_test)

        print('train samples', dataset_train.data.shape[0])
        print('test samples', dataset_test.data.shape[0])

        self.loader_train = DataLoader(dataset=dataset_train,
                                       batch_size=batch_size,
                                       shuffle=shuffle,
                                       sampler=train_sampler,
                                       )

        self.loader_test = DataLoader(dataset=dataset_test,
                                      batch_size=batch_size,
                                      shuffle=False,
                                      )

        # Peek at one batch only to report tensor shapes.
        for images, labels in self.loader_train:
            print('single batch dimensions:', images.shape)
            print('single batch label dimensions:', labels.shape)
            self.number_of_features = images.shape[2] * images.shape[3]
            print('number of features', self.number_of_features)
            break

        self.model = AutoEncoder()
        self.model = self.model.to(device=self.device)

        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=learning_rate)

        self.loss_function = torch.nn.functional.mse_loss

    def run_batches(self, data_loader, batch_callback=None, print_details=False):
        """Compute the reconstruction loss for every batch in ``data_loader``.

        Args:
            data_loader: iterable of (images, labels) batches.
            batch_callback: optional callable(reconstruction, batch_loss, batch_samples),
                e.g. an optimizer step.
            print_details: print progress at most every 5 seconds.

        Returns:
            (average_loss, number_of_batches). average_loss is the sum of per-batch
            losses divided by the total sample count, as a 0-d CPU tensor.
        """
        start_time = time.time()
        total_loss = 0.0
        total_samples = 0
        batches_in_epoch_count = 0
        for batch_index, (batch_x, _) in enumerate(data_loader):
            batches_in_epoch_count += 1
            batch_x = batch_x.to(device=self.device)
            batch_samples = batch_x.shape[0]
            total_samples += batch_samples
            logics = self.model(batch_x)
            batch_loss = self.loss_function(logics, batch_x)
            if batch_callback is not None:
                batch_callback(logics, batch_loss, batch_samples)
            # Accumulate a plain float. Adding the tensor would chain every batch's
            # autograd graph onto total_loss for the whole epoch.
            total_loss += batch_loss.item()

            if print_details:
                passed_seconds = time.time() - start_time
                if passed_seconds > 5:
                    start_time = time.time()
                    print('batch %05d loss %.5f' % (batch_index, batch_loss.item()))

        average_loss = total_loss / total_samples

        return torch.tensor(average_loss), batches_in_epoch_count

    def batch_callback_train(self, _, batch_loss, batch_samples):
        """Record the batch loss divided by the batch size, then take one optimizer step."""
        self.loss_train_per_batch.append(batch_loss.item() / batch_samples)

        self.optimizer.zero_grad()
        batch_loss.backward()
        self.optimizer.step()

    def train_epoch(self, epoch_index):
        """Train for one epoch, then re-evaluate the train and test losses without gradients.

        Returns:
            The number of batches in the training loader.
        """
        start_time = time.time()
        # The evaluation at the end of the previous epoch left the model in eval mode.
        self.model.train()
        self.run_batches(data_loader=self.loader_train, batch_callback=self.batch_callback_train,
                         print_details=True)

        with torch.no_grad():
            self.model.eval()
            epoch_loss_train, batches_in_epoch_count = self.run_batches(
                data_loader=self.loader_train)
            epoch_loss_test, _ = self.run_batches(data_loader=self.loader_test)

        self.loss_train_per_epoch.append(epoch_loss_train)
        self.loss_test_per_epoch.append(epoch_loss_test)

        passed_seconds = time.time() - start_time
        print('epoch %05d/%05d, batch duration %03.0f seconds, loss train %.5f, loss test %.5f' % (
            epoch_index + 1, self.number_of_epochs, passed_seconds, epoch_loss_train, epoch_loss_test))

        return batches_in_epoch_count

    def train(self):
        """Train for ``number_of_epochs`` and save the loss graph to loss.pdf."""
        self.loss_train_per_batch = []
        self.loss_train_per_epoch = []
        self.loss_test_per_epoch = []
        self.model.train()

        start_time = time.time()
        batches_in_epoch_count = 0
        for epoch_index in range(self.number_of_epochs):
            batches_in_epoch_count = self.train_epoch(epoch_index)

        passed_seconds = time.time() - start_time
        print('total training duration %03.0f seconds' % passed_seconds)

        self.model.eval()

        # Repeat each per-epoch value so it can share the per-batch x axis.
        self.loss_train_per_epoch = self.spread_points(self.loss_train_per_epoch, batches_in_epoch_count)
        self.loss_test_per_epoch = self.spread_points(self.loss_test_per_epoch, batches_in_epoch_count)
        plt.clf()
        plt.plot(self.loss_train_per_batch, color='b', label='train batch')
        plt.plot(self.loss_train_per_epoch, color='g', label='train epoch')
        plt.plot(self.loss_test_per_epoch, color='r', label='test')
        plt.legend()
        plt.ylim(0, 0.003)
        plt.ylabel('Loss')
        plt.xlabel('Epoch')
        plt.savefig("loss.pdf")
        plt.close()
        self.loss_train_per_batch = []

    def save(self):
        """Write the model's state_dict to MODEL_PATH."""
        torch.save(self.model.state_dict(), MODEL_PATH)

    def load(self):
        """Load the weights in MODEL_PATH into the current model and switch it to eval mode."""
        # Load in place so self.optimizer still references the live parameters.
        # map_location lets a GPU checkpoint load on a CPU-only machine.
        state = torch.load(MODEL_PATH, map_location=self.device)
        self.model.load_state_dict(state)
        self.model = self.model.to(self.device)
        self.model.eval()

    def show_examples(self):
        """Save the first 10 test images (top row) and their reconstructions (bottom row) to examples.pdf."""
        examples_amount = 10
        with torch.no_grad():
            for original_images, _ in self.loader_test:
                original_images = original_images.to(device=self.device)
                original_images = original_images[:examples_amount]
                decoded_images = self.model(original_images)
                batch_samples = original_images.shape[0]
                original_images = original_images.to('cpu')
                decoded_images = decoded_images.to('cpu')

                # squeeze=False keeps axes 2-D even if the batch has a single image.
                fig, axes = plt.subplots(nrows=2, ncols=batch_samples, sharex=True, sharey=True,
                                         figsize=(20, 2.5), squeeze=False)
                for i in range(batch_samples):
                    for ax, img in zip(axes, [original_images, decoded_images]):
                        curr_img = img[i].detach()
                        ax[i].imshow(curr_img.view((IMAGE_WIDTH, IMAGE_WIDTH)), cmap='binary')
                fig.savefig("examples.pdf")
                plt.close(fig)
                # Only the first test batch is used.
                return

    @staticmethod
    def spread_points(points, spread_factor):
        """Return ``points`` with each element repeated ``spread_factor`` times in order."""
        return [point for point in points for _ in range(spread_factor)]


def set_seed():
    """Seed torch's CPU and CUDA RNGs.

    When CUDA is available, also request deterministic kernels.
    """
    import os

    random_seed = 123

    if torch.cuda.is_available():
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
        # torch.set_deterministic() was removed from PyTorch in favour of
        # use_deterministic_algorithms(). Deterministic cuBLAS also needs this
        # workspace setting before the first cuBLAS call.
        os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
        torch.use_deterministic_algorithms(True)

    torch.manual_seed(random_seed)
    torch.cuda.manual_seed(random_seed)


def main():
    """Train the autoencoder, save and reload it, then plot reconstructions of test images."""
    trainer = Trainer()
    trainer.train()
    trainer.save()
    trainer.load()
    trainer.show_examples()


# Guard the entry point so importing this module does not start training.
if __name__ == "__main__":
    main()