In this post we present a simple GAN implementation to generate digits.
This post is based on the GAN explanation video by Sebastian Raschka.
GAN is an abbreviation for Generative Adversarial Network. It includes 2 networks: the generator and the discriminator. The generator generates new images, and the discriminator receives both real images and generated ones.
The discriminator's goal is to classify whether the received image is real or generated, while the generator's goal is to fool the discriminator into classifying the generated images as real ones.
This implementation uses the MNIST database as input, and fully connected networks with a single hidden layer.
We can see the results of the generated images at the top of this post. These are not perfect, but they appear to be on the way to being so. More training, and possibly using convolutional networks, would likely yield better results.
The training was run for 50 epochs, and the loss graph is:
The related code implementation is:
import time
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
from torch.utils.data import DataLoader, SubsetRandomSampler
from torchvision import datasets
# File path used by Trainer.save() / Trainer.load() for the model weights.
MODEL_PATH = "gan.model"
class GenerativeAdversarialNetwork(torch.nn.Module):
    """A minimal GAN for 28x28 digit images.

    Both sub-networks are fully connected with a single hidden layer:
    the generator maps a 100-dimensional noise vector to a flat image,
    and the discriminator maps a flat image to one raw real/fake logit
    (no final sigmoid — a logits-based loss is expected downstream).
    """

    def __init__(self):
        super(GenerativeAdversarialNetwork, self).__init__()
        self.image_width = 28
        self.noise_width = 100
        hidden_units_generator = 128
        hidden_units_discriminator = 128
        pixels_per_image = self.image_width * self.image_width

        # Noise (N, noise_width) -> flat image (N, pixels) in [0, 1].
        self.generator = torch.nn.Sequential(
            torch.nn.Flatten(),
            torch.nn.Linear(
                in_features=self.noise_width,
                out_features=hidden_units_generator,
            ),
            torch.nn.LeakyReLU(inplace=True),
            torch.nn.Linear(
                in_features=hidden_units_generator,
                out_features=pixels_per_image,
            ),
            torch.nn.Sigmoid(),
        )

        # Image (N, 1, W, W) -> single logit (N, 1).
        self.discriminator = torch.nn.Sequential(
            torch.nn.Flatten(start_dim=1),
            torch.nn.Linear(
                in_features=pixels_per_image,
                out_features=hidden_units_discriminator,
            ),
            torch.nn.LeakyReLU(inplace=True),
            torch.nn.Linear(
                in_features=hidden_units_discriminator,
                out_features=1,
            ),
        )

    def forward_generator(self, x):
        """Map a noise batch (N, noise_width) to images of shape (N, 1, W, W)."""
        flat_images = self.generator(x)
        samples = flat_images.size(0)
        return flat_images.view(samples, 1, self.image_width, self.image_width)

    def forward_discriminator(self, x):
        """Return one real/fake logit per input image."""
        return self.discriminator(x)
class Trainer:
    """Trains the GAN on MNIST, tracking per-batch and per-epoch losses.

    The discriminator and generator are updated alternately with separate
    Adam optimizers, using BCE-with-logits losses (non-saturating generator
    objective: generated images are labelled as real).
    """

    def __init__(self):
        self.number_of_epochs = 50
        batch_size = 32
        learning_rate = 0.0002
        # Loss histories; (re)filled during train().
        self.loss_generator_per_batch = []
        self.loss_generator_per_epoch = []
        self.loss_discriminator_per_batch = []
        self.loss_discriminator_per_epoch = []
        # Optionally restrict training to the first `limit_size` samples
        # (handy for quick smoke tests).
        limit_size = None
        # limit_size = 13
        if limit_size is None:
            train_sampler = None
            shuffle = True
        else:
            train_indexes = torch.arange(limit_size)
            train_sampler = SubsetRandomSampler(train_indexes)
            shuffle = False  # DataLoader forbids shuffle together with a sampler
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        print('using device:', self.device)
        transform_train = torchvision.transforms.Compose(
            [
                torchvision.transforms.ToTensor(),
            ]
        )
        dataset_train = datasets.MNIST(root='local_cache_folder',
                                       train=True,
                                       transform=transform_train,
                                       download=True)
        print('train samples', dataset_train.data.shape[0])
        self.loader_train = DataLoader(dataset=dataset_train,
                                       batch_size=batch_size,
                                       shuffle=shuffle,
                                       sampler=train_sampler,
                                       )
        # Peek at one batch to report tensor shapes.
        for images, labels in self.loader_train:
            print('single batch dimensions:', images.shape)
            print('single batch label dimensions:', labels.shape)
            self.number_of_features = images.shape[2] * images.shape[3]
            print('number of features', self.number_of_features)
            break
        self.model = GenerativeAdversarialNetwork()
        self.model = self.model.to(device=self.device)
        # Separate optimizers so each step updates only one sub-network.
        self.optimizer_generator = torch.optim.Adam(self.model.generator.parameters(), lr=learning_rate)
        self.optimizer_discriminator = torch.optim.Adam(self.model.discriminator.parameters(), lr=learning_rate)
        self.loss_function = torch.nn.functional.binary_cross_entropy_with_logits

    def train_single_batch(self, batch_images):
        """Run one adversarial update: discriminator first, then generator.

        batch_images: a batch of real images, shape (N, 1, H, W).
        Returns (discriminator_loss, generator_loss) as Python floats.
        """
        batch_images = batch_images.to(device=self.device)
        batch_number_of_samples = batch_images.shape[0]
        # Targets and noise must live on the same device as the model;
        # previously they were created on the CPU and crashed CUDA runs.
        ones = torch.ones(batch_number_of_samples, device=self.device)
        zeros = torch.zeros(batch_number_of_samples, device=self.device)
        random_noise = torch.randn(batch_number_of_samples, self.model.noise_width,
                                   device=self.device)
        generated_images = self.model.forward_generator(random_noise)
        # train discriminator
        self.optimizer_discriminator.zero_grad()
        discriminator_predictions_for_real_images = self.model.forward_discriminator(batch_images)
        discriminator_predictions_for_real_images = discriminator_predictions_for_real_images.squeeze()
        discriminator_batch_loss_real_images = self.loss_function(discriminator_predictions_for_real_images, ones)
        # using detach() since we want to use the results for 2 optimizers
        discriminator_predictions_for_generated_images = self.model.forward_discriminator(generated_images.detach())
        discriminator_predictions_for_generated_images = discriminator_predictions_for_generated_images.squeeze()
        discriminator_batch_loss_generated_images = self.loss_function(
            discriminator_predictions_for_generated_images, zeros)
        discriminator_batch_loss = discriminator_batch_loss_real_images + discriminator_batch_loss_generated_images
        discriminator_batch_loss = discriminator_batch_loss / 2
        self.loss_discriminator_per_batch.append(discriminator_batch_loss.item())
        discriminator_batch_loss.backward()
        self.optimizer_discriminator.step()
        # train generator: label generated images as real to fool the discriminator
        self.optimizer_generator.zero_grad()
        discriminator_predictions_for_generated_images = self.model.forward_discriminator(generated_images)
        discriminator_predictions_for_generated_images = discriminator_predictions_for_generated_images.squeeze()
        generator_batch_loss = self.loss_function(discriminator_predictions_for_generated_images, ones)
        self.loss_generator_per_batch.append(generator_batch_loss.item())
        generator_batch_loss.backward()
        self.optimizer_generator.step()
        return discriminator_batch_loss.item(), generator_batch_loss.item()

    def train_epoch(self, epoch_index):
        """Train on every batch once; record and print the epoch-mean losses.

        epoch_index: zero-based epoch number (used only for logging).
        Returns the number of batches processed in this epoch.
        """
        epoch_start_time = time.time()
        last_log_time = time.time()
        discriminator_epoch_loss = 0
        generator_epoch_loss = 0
        number_of_batches_in_epoch = 0
        number_of_batches = len(self.loader_train)
        for batch_index, (batch_images, _) in enumerate(self.loader_train):
            number_of_batches_in_epoch += 1
            discriminator_batch_loss, generator_batch_loss = self.train_single_batch(batch_images)
            discriminator_epoch_loss += discriminator_batch_loss
            generator_epoch_loss += generator_batch_loss
            passed_seconds = time.time() - last_log_time
            if passed_seconds > 5:
                last_log_time = time.time()
                # The batch losses are already Python floats (train_single_batch
                # calls .item()), so they are formatted directly — calling
                # .item() here raised AttributeError.
                print('batch {:05d}/{:05d} discriminator loss {:.5f} generator loss {:.5f}'.format(
                    batch_index,
                    number_of_batches,
                    discriminator_batch_loss,
                    generator_batch_loss,
                ))
        discriminator_epoch_loss = discriminator_epoch_loss / number_of_batches_in_epoch
        self.loss_discriminator_per_epoch.append(discriminator_epoch_loss)
        generator_epoch_loss = generator_epoch_loss / number_of_batches_in_epoch
        self.loss_generator_per_epoch.append(generator_epoch_loss)
        epoch_seconds = time.time() - epoch_start_time
        print(
            'epoch {:05d}/{:05d}, duration {:03.0f} seconds, discriminator loss {:.5f} generator loss {:.5f}'.format(
                epoch_index + 1,
                self.number_of_epochs,
                epoch_seconds,
                discriminator_epoch_loss,
                generator_epoch_loss,
            ))
        return number_of_batches_in_epoch

    def show_examples(self, noise, file_name):
        """Generate images from `noise` and save them to `file_name` as a grid."""
        noise = noise.to(device=self.device)
        generated_images = self.model.forward_generator(noise)
        # detach + cpu: make_grid output must be a CPU tensor for matplotlib,
        # and this method may be called outside a no_grad() context.
        combined_image = torchvision.utils.make_grid(
            generated_images.detach().cpu(), padding=2, nrow=10)
        # tensor CHW -> plt HWC
        combined_transpose = np.transpose(combined_image, (1, 2, 0))
        plt.clf()
        figure = plt.figure(figsize=(8, 8))
        plt.axis('off')
        plt.imshow(combined_transpose)
        plt.savefig(file_name)
        # Close the figure — previously one figure leaked per epoch.
        plt.close(figure)

    def plot_loss_graph(self, number_of_batches_in_epoch):
        """Save a plot of batch-level and epoch-level losses to loss.pdf.

        Epoch-level series are stretched by the batch count so they share
        the batch-level x-axis.
        """
        loss_discriminator_per_epoch = self.spread_points(self.loss_discriminator_per_epoch, number_of_batches_in_epoch)
        loss_generator_per_epoch = self.spread_points(self.loss_generator_per_epoch, number_of_batches_in_epoch)
        plt.clf()
        plt.plot(self.loss_generator_per_batch, color='y', label='generator batch')
        plt.plot(loss_generator_per_epoch, color='r', label='generator epoch')
        plt.plot(self.loss_discriminator_per_batch, color='c', label='discriminator batch')
        plt.plot(loss_discriminator_per_epoch, color='b', label='discriminator epoch')
        plt.legend()
        plt.ylabel('Loss')
        plt.xlabel('Epoch')
        plt.savefig("loss.pdf")

    def train(self):
        """Run the full training loop, saving loss plots and sample grids."""
        self.loss_generator_per_batch = []
        self.loss_generator_per_epoch = []
        self.loss_discriminator_per_batch = []
        self.loss_discriminator_per_epoch = []
        self.model.train()
        start_time = time.time()
        # Fixed noise so the per-epoch sample grids are visually comparable.
        static_noise = torch.randn(100, self.model.noise_width, device=self.device)
        for epoch_index in range(self.number_of_epochs):
            number_of_batches_in_epoch = self.train_epoch(epoch_index)
            self.plot_loss_graph(number_of_batches_in_epoch)
            with torch.no_grad():
                file_name = 'output/generated_epoch_{:03d}.pdf'.format(epoch_index)
                self.show_examples(static_noise, file_name)
        passed_seconds = time.time() - start_time
        print('total training duration %03.0f seconds' % passed_seconds)
        self.model.eval()

    def save(self):
        """Persist the model weights to MODEL_PATH."""
        torch.save(self.model.state_dict(), MODEL_PATH)

    def load(self):
        """Load weights from MODEL_PATH into a fresh model on self.device."""
        self.model = GenerativeAdversarialNetwork()
        # map_location lets GPU-trained weights load on a CPU-only host.
        self.model.load_state_dict(torch.load(MODEL_PATH, map_location=self.device))
        self.model = self.model.to(self.device)
        self.model.eval()

    @staticmethod
    def spread_points(points, spread_factor):
        """Repeat each point `spread_factor` times (aligns epoch series
        with the batch-level x-axis in the loss plot)."""
        return [point for point in points for _ in range(spread_factor)]
def set_seed(random_seed=123):
    """Seed torch (and CUDA, when present) for reproducible runs.

    random_seed: seed value; defaults to 123, matching the original runs.
    """
    if torch.cuda.is_available():
        # cuDNN autotuning and non-deterministic kernels break reproducibility.
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
        torch.use_deterministic_algorithms(True)
    torch.manual_seed(random_seed)
    torch.cuda.manual_seed(random_seed)  # no-op when CUDA is unavailable
def main():
    """Seed RNGs, then train (and optionally save/load) the GAN."""
    set_seed()
    trainer = Trainer()
    trainer.train()
    # trainer.save()
    # trainer.load()


if __name__ == "__main__":
    # Guard so importing this module does not immediately start training.
    main()
No comments:
Post a Comment