Full Blog TOC

Full Blog Table Of Content with Keywords Available HERE

Thursday, August 29, 2024

Using GAN to Generate Digits




In this post we present a simple GAN implementation to generate digits.

This post is based on the GAN explanation video by Sebastian Raschka.


GAN is an abbreviation for Generative Adversarial Network. It includes 2 networks: the generator and the discriminator. The generator generates new images, and the discriminator receives a real image and a generated image.

The discriminator's goal is to classify whether the received image is real or generated, while the generator's goal is to fool the discriminator into classifying the generated images as real ones.

This implementation uses the MNIST database as input, and fully connected networks with a single hidden layer.

We can see the results of the generated images at the top of this post. These are not perfect, but they appear to be on the way to being so. More training, and possibly using convolutional networks, would likely have produced better results.

The training was run for 50 epochs, and the loss graph is:





The related code implementation is:



import os
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
from torch.utils.data import DataLoader, SubsetRandomSampler
from torchvision import datasets

# File path used by Trainer.save() / Trainer.load() for model weights.
MODEL_PATH = "gan.model"


class GenerativeAdversarialNetwork(torch.nn.Module):
    """A minimal GAN for 28x28 MNIST digits.

    Both networks are fully connected with a single hidden layer:
    the generator maps a noise vector to a flat image, and the
    discriminator maps a flat image to a single real-vs-generated logit.
    """

    def __init__(self):
        super().__init__()

        self.image_width = 28
        self.noise_width = 100
        hidden_size_generator = 128
        hidden_size_discriminator = 128
        pixels_per_image = self.image_width * self.image_width

        # Noise -> flat image; the final Sigmoid keeps pixels in [0, 1].
        self.generator = torch.nn.Sequential(
            torch.nn.Flatten(),
            torch.nn.Linear(
                in_features=self.noise_width,
                out_features=hidden_size_generator,
            ),
            torch.nn.LeakyReLU(inplace=True),
            torch.nn.Linear(
                in_features=hidden_size_generator,
                out_features=pixels_per_image,
            ),
            torch.nn.Sigmoid(),
        )

        # Flat image -> raw logit. No Sigmoid here: the trainer uses
        # binary_cross_entropy_with_logits, which expects logits.
        self.discriminator = torch.nn.Sequential(
            torch.nn.Flatten(start_dim=1),
            torch.nn.Linear(
                in_features=pixels_per_image,
                out_features=hidden_size_discriminator,
            ),
            torch.nn.LeakyReLU(inplace=True),
            torch.nn.Linear(
                in_features=hidden_size_discriminator,
                out_features=1,
            ),
        )

    def forward_generator(self, x):
        """Generate images from noise `x`; returns shape (N, 1, W, W)."""
        flat_images = self.generator(x)
        return flat_images.view(
            flat_images.size(0),  # number of samples
            1,                    # color channel
            self.image_width,
            self.image_width,
        )

    def forward_discriminator(self, x):
        """Return the real-vs-generated logit for each image in `x`."""
        return self.discriminator(x)


class Trainer:
    """Trains the GenerativeAdversarialNetwork on MNIST.

    Builds the data loader, runs the adversarial training loop, and writes
    a loss plot plus a grid of generated sample images after every epoch.
    """

    def __init__(self):
        self.number_of_epochs = 50
        batch_size = 32
        learning_rate = 0.0002

        # Per-batch and per-epoch loss histories for both networks.
        self.loss_generator_per_batch = []
        self.loss_generator_per_epoch = []
        self.loss_discriminator_per_batch = []
        self.loss_discriminator_per_epoch = []

        # Optionally restrict training to the first `limit_size` samples
        # (handy for quick debugging runs).
        limit_size = None
        # limit_size = 13
        if limit_size is None:
            train_sampler = None
            shuffle = True
        else:
            train_indexes = torch.arange(limit_size)
            train_sampler = SubsetRandomSampler(train_indexes)
            shuffle = False

        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        print('using device:', self.device)

        transform_train = torchvision.transforms.Compose(
            [
                torchvision.transforms.ToTensor(),
            ]
        )

        dataset_train = datasets.MNIST(root='local_cache_folder',
                                       train=True,
                                       transform=transform_train,
                                       download=True)

        print('train samples', dataset_train.data.shape[0])

        self.loader_train = DataLoader(dataset=dataset_train,
                                       batch_size=batch_size,
                                       shuffle=shuffle,
                                       sampler=train_sampler,
                                       )

        # Inspect a single batch to derive the flattened feature count.
        for images, labels in self.loader_train:
            print('single batch dimensions:', images.shape)
            print('single batch label dimensions:', labels.shape)
            self.number_of_features = images.shape[2] * images.shape[3]
            print('number of features', self.number_of_features)
            break

        self.model = GenerativeAdversarialNetwork()
        self.model = self.model.to(device=self.device)

        # Separate optimizers: each sub-network is updated against its own loss.
        self.optimizer_generator = torch.optim.Adam(self.model.generator.parameters(), lr=learning_rate)
        self.optimizer_discriminator = torch.optim.Adam(self.model.discriminator.parameters(), lr=learning_rate)

        # The discriminator outputs raw logits, hence the with-logits loss.
        self.loss_function = torch.nn.functional.binary_cross_entropy_with_logits

    def train_single_batch(self, batch_images):
        """Run one adversarial step: update discriminator, then generator.

        Returns (discriminator_loss, generator_loss) as plain floats.
        """
        batch_images = batch_images.to(device=self.device)

        batch_number_of_samples = batch_images.shape[0]

        # BUG FIX: create targets and noise directly on the training device;
        # previously they were CPU tensors, which crashes under CUDA.
        ones = torch.ones(batch_number_of_samples, device=self.device)
        zeros = torch.zeros(batch_number_of_samples, device=self.device)
        random_noise = torch.randn(batch_number_of_samples, self.model.noise_width, device=self.device)
        generated_images = self.model.forward_generator(random_noise)

        # --- train discriminator: real images -> 1, generated images -> 0 ---
        self.optimizer_discriminator.zero_grad()

        discriminator_predictions_for_real_images = self.model.forward_discriminator(batch_images)
        discriminator_predictions_for_real_images = discriminator_predictions_for_real_images.squeeze()
        discriminator_batch_loss_real_images = self.loss_function(discriminator_predictions_for_real_images, ones)

        # using detach() since we want to use the results for 2 optimizers
        discriminator_predictions_for_generated_images = self.model.forward_discriminator(generated_images.detach())
        discriminator_predictions_for_generated_images = discriminator_predictions_for_generated_images.squeeze()
        discriminator_batch_loss_generated_images = self.loss_function(
            discriminator_predictions_for_generated_images, zeros)

        discriminator_batch_loss = discriminator_batch_loss_real_images + discriminator_batch_loss_generated_images
        discriminator_batch_loss = discriminator_batch_loss / 2
        self.loss_discriminator_per_batch.append(discriminator_batch_loss.item())

        discriminator_batch_loss.backward()
        self.optimizer_discriminator.step()

        # --- train generator: fool the discriminator into predicting 1 ---
        self.optimizer_generator.zero_grad()
        discriminator_predictions_for_generated_images = self.model.forward_discriminator(generated_images)
        discriminator_predictions_for_generated_images = discriminator_predictions_for_generated_images.squeeze()
        generator_batch_loss = self.loss_function(discriminator_predictions_for_generated_images, ones)
        self.loss_generator_per_batch.append(generator_batch_loss.item())
        generator_batch_loss.backward()
        self.optimizer_generator.step()

        return discriminator_batch_loss.item(), generator_batch_loss.item()

    def train_epoch(self, epoch_index):
        """Train one epoch; returns the number of batches processed."""
        epoch_start_time = time.time()
        last_log_time = time.time()
        discriminator_epoch_loss = 0
        generator_epoch_loss = 0
        number_of_batches_in_epoch = 0
        number_of_batches = len(self.loader_train)
        for batch_index, (batch_images, _) in enumerate(self.loader_train):
            number_of_batches_in_epoch += 1
            discriminator_batch_loss, generator_batch_loss = self.train_single_batch(batch_images)
            discriminator_epoch_loss += discriminator_batch_loss
            generator_epoch_loss += generator_batch_loss

            passed_seconds = time.time() - last_log_time
            if passed_seconds > 5:
                last_log_time = time.time()
                # BUG FIX: train_single_batch already returns floats, so the
                # previous .item() calls here raised AttributeError.
                print('batch {:05d}/{:05d} discriminator loss {:.5f} generator loss {:.5f}'.format(
                    batch_index,
                    number_of_batches,
                    discriminator_batch_loss,
                    generator_batch_loss,
                ))

        discriminator_epoch_loss = discriminator_epoch_loss / number_of_batches_in_epoch
        self.loss_discriminator_per_epoch.append(discriminator_epoch_loss)
        generator_epoch_loss = generator_epoch_loss / number_of_batches_in_epoch
        self.loss_generator_per_epoch.append(generator_epoch_loss)

        epoch_seconds = time.time() - epoch_start_time
        print(
            'epoch {:05d}/{:05d}, duration {:03.0f} seconds, discriminator loss {:.5f} generator loss {:.5f}'.format(
                epoch_index + 1,
                self.number_of_epochs,
                epoch_seconds,
                discriminator_epoch_loss,
                generator_epoch_loss,
            ))

        return number_of_batches_in_epoch

    def show_examples(self, noise, file_name):
        """Generate images from `noise` and save them as a grid to `file_name`."""
        generated_images = self.model.forward_generator(noise)
        combined_image = torchvision.utils.make_grid(generated_images, padding=2, nrow=10)

        # BUG FIX: detach and move to CPU before handing the tensor to numpy;
        # np.transpose on a CUDA tensor raises. (tensor CHW -> plt HWC)
        combined_transpose = np.transpose(combined_image.detach().cpu(), (1, 2, 0))

        plt.clf()
        plt.figure(figsize=(8, 8))
        plt.axis('off')
        plt.imshow(combined_transpose)
        plt.savefig(file_name)
        # BUG FIX: close figures; one figure per epoch was leaked before.
        plt.close('all')

    def plot_loss_graph(self, number_of_batches_in_epoch):
        """Plot per-batch and (batch-aligned) per-epoch losses to loss.pdf."""
        loss_discriminator_per_epoch = self.spread_points(self.loss_discriminator_per_epoch, number_of_batches_in_epoch)
        loss_generator_per_epoch = self.spread_points(self.loss_generator_per_epoch, number_of_batches_in_epoch)

        plt.clf()
        plt.plot(self.loss_generator_per_batch, color='y', label='generator batch')
        plt.plot(loss_generator_per_epoch, color='r', label='generator epoch')
        plt.plot(self.loss_discriminator_per_batch, color='c', label='discriminator batch')
        plt.plot(loss_discriminator_per_epoch, color='b', label='discriminator epoch')
        plt.legend()
        plt.ylabel('Loss')
        # BUG FIX: the x-axis is the batch index (per-batch series are plotted
        # directly), so it was mislabeled 'Epoch'.
        plt.xlabel('Batch')
        plt.savefig("loss.pdf")

    def train(self):
        """Run the full training loop for `number_of_epochs` epochs."""
        self.loss_generator_per_batch = []
        self.loss_generator_per_epoch = []
        self.loss_discriminator_per_batch = []
        self.loss_discriminator_per_epoch = []
        self.model.train()

        # BUG FIX: ensure the sample-image directory exists before savefig.
        os.makedirs('output', exist_ok=True)

        start_time = time.time()
        # Fixed noise (on the training device) so per-epoch sample grids are
        # comparable across epochs.
        static_noise = torch.randn(100, self.model.noise_width, device=self.device)

        for epoch_index in range(self.number_of_epochs):
            number_of_batches_in_epoch = self.train_epoch(epoch_index)
            self.plot_loss_graph(number_of_batches_in_epoch)
            with torch.no_grad():
                file_name = 'output/generated_epoch_{:03d}.pdf'.format(epoch_index)
                self.show_examples(static_noise, file_name)

        passed_seconds = time.time() - start_time
        print('total training duration %03.0f seconds' % passed_seconds)

        self.model.eval()

    def save(self):
        """Save model weights to MODEL_PATH."""
        torch.save(self.model.state_dict(), MODEL_PATH)

    def load(self):
        """Load model weights from MODEL_PATH into a fresh model."""
        self.model = GenerativeAdversarialNetwork()
        self.model.load_state_dict(torch.load(MODEL_PATH))
        self.model = self.model.to(self.device)
        self.model.eval()

    @staticmethod
    def spread_points(points, spread_factor):
        """Repeat each point `spread_factor` times so a per-epoch series
        aligns with the per-batch x-axis of the loss plot."""
        return [point for point in points for _ in range(spread_factor)]


def set_seed():
    """Seed the torch RNGs (CPU and CUDA) for reproducible runs.

    On CUDA machines, also pins cuDNN to its deterministic, non-benchmarking
    code paths and enables torch's deterministic-algorithms mode.
    """
    seed = 123

    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    if torch.cuda.is_available():
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
        torch.use_deterministic_algorithms(True)


def main():
    """Entry point: seed the RNGs, then build a Trainer and run training.

    Persisting / restoring weights is available via Trainer.save() and
    Trainer.load() when needed.
    """
    set_seed()
    Trainer().train()


if __name__ == "__main__":
    # Guard the entry point so importing this module does not start training.
    main()

 




No comments:

Post a Comment