In this post we present an autoencoder implementation in PyTorch.
It is based on the convolutional autoencoder code by Sebastian Raschka.
We use the MNIST dataset: each gray-scale 28x28 digit image is encoded into just 2 numbers!
At the top of this post there is an image of the original digit images and the corresponding decoded images. The reconstruction is not perfect, but it is very close; additional training would likely improve it (but I don't have a GPU).
One thing to notice is that adding random augmentation to the training dataloader seems to make the results worse. This, again, might just mean that more training is required.
Also, removing the two duplicate 64-channel convolution layers seems to give similar results.
import time
import matplotlib.pyplot as plt
import torch
import torchvision
from torch.utils.data import DataLoader, SubsetRandomSampler
from torchvision import datasets
MODEL_PATH = "auto-encoder.model"
IMAGE_WIDTH = 28
class Trim(torch.nn.Module):
    # crop feature maps down to size x size: the decoder's last
    # ConvTranspose2d produces 29x29 maps, one row/column too large
    def __init__(self, *args):
        super().__init__()
        self.size = args[0]

    def forward(self, x):
        return x[:, :, :self.size, :self.size]
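# For example (an added note, not in the original post), Trim(28) applied to a
# batch of 29x29 feature maps keeps only the top-left 28x28 crop:
#
#   trim = Trim(IMAGE_WIDTH)
#   print(trim(torch.zeros(1, 1, 29, 29)).shape)  # torch.Size([1, 1, 28, 28])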
class AutoEncoder(torch.nn.Module):
    def __init__(self):
        super(AutoEncoder, self).__init__()
        relu_slope = 0.01
        self.encoder = torch.nn.Sequential(
            torch.nn.Conv2d(
                in_channels=1,
                out_channels=32,
                stride=1,
                kernel_size=3,
                padding=1,
            ),
            torch.nn.LeakyReLU(relu_slope),
            torch.nn.Conv2d(
                in_channels=32,
                out_channels=64,
                stride=2,
                kernel_size=3,
                padding=1,
            ),
            torch.nn.LeakyReLU(relu_slope),
            torch.nn.Conv2d(
                in_channels=64,
                out_channels=64,
                stride=2,
                kernel_size=3,
                padding=1,
            ),
            torch.nn.LeakyReLU(relu_slope),
            torch.nn.Conv2d(
                in_channels=64,
                out_channels=64,
                stride=1,
                kernel_size=3,
                padding=1,
            ),
            torch.nn.Flatten(),
            torch.nn.Linear(
                in_features=3136,
                out_features=2,
            ),
        )
        self.decoder = torch.nn.Sequential(
            torch.nn.Linear(
                in_features=2,
                out_features=3136,
            ),
            torch.nn.Unflatten(
                dim=1,
                unflattened_size=(
                    64,
                    7,
                    7,
                ),
            ),
            torch.nn.ConvTranspose2d(
                in_channels=64,
                out_channels=64,
                stride=1,
                kernel_size=3,
                padding=1,
            ),
            torch.nn.LeakyReLU(relu_slope),
            torch.nn.ConvTranspose2d(
                in_channels=64,
                out_channels=64,
                stride=2,
                kernel_size=3,
                padding=1,
            ),
            torch.nn.LeakyReLU(relu_slope),
            torch.nn.ConvTranspose2d(
                in_channels=64,
                out_channels=32,
                stride=2,
                kernel_size=3,
                padding=0,
            ),
            torch.nn.LeakyReLU(relu_slope),
            torch.nn.ConvTranspose2d(
                in_channels=32,
                out_channels=1,
                stride=1,
                kernel_size=3,
                padding=0,
            ),
            Trim(IMAGE_WIDTH),
            torch.nn.Sigmoid(),
        )

    def forward(self, x):
        z = self.encoder(x)
        y = self.decoder(z)
        return y
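# Shape arithmetic (an added note, not in the original post). Using the
# standard output-size formulas
#   Conv2d:          out = floor((in + 2 * padding - kernel_size) / stride) + 1
#   ConvTranspose2d: out = (in - 1) * stride - 2 * padding + kernel_size
# the encoder maps the spatial size 28 -> 28 -> 14 -> 7 -> 7, so flattening
# gives 64 * 7 * 7 = 3136 features feeding the 2-unit bottleneck, and the
# decoder maps 7 -> 7 -> 13 -> 27 -> 29, which is why Trim(28) crops the
# final 29x29 maps back to 28x28. A quick sanity check:
#
#   model = AutoEncoder()
#   x = torch.zeros(1, 1, IMAGE_WIDTH, IMAGE_WIDTH)
#   print(model.encoder(x).shape)  # torch.Size([1, 2])
#   print(model(x).shape)          # torch.Size([1, 1, 28, 28])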
class Trainer:
    def __init__(self):
        self.number_of_epochs = 5
        batch_size = 32
        learning_rate = 0.0005
        self.loss_train_per_batch = []
        self.loss_train_per_epoch = []
        self.loss_test_per_epoch = []
        limit_size = None
        # limit_size = 13
        if limit_size is None:
            train_sampler = None
            shuffle = True
        else:
            train_indexes = torch.arange(limit_size)
            train_sampler = SubsetRandomSampler(train_indexes)
            shuffle = False
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        print('using device:', self.device)
        transform_train = torchvision.transforms.Compose(
            [
                # torchvision.transforms.Resize(size=(IMAGE_WIDTH + 4, IMAGE_WIDTH + 4)),
                # torchvision.transforms.RandomCrop(size=(IMAGE_WIDTH, IMAGE_WIDTH)),
                # torchvision.transforms.RandomRotation(degrees=20),
                torchvision.transforms.ToTensor(),
            ]
        )
        transform_test = torchvision.transforms.Compose(
            [
                torchvision.transforms.ToTensor(),
            ]
        )
        dataset_train = datasets.MNIST(root='local_cache_folder',
                                       train=True,
                                       transform=transform_train,
                                       download=True)
        dataset_test = datasets.MNIST(root='local_cache_folder',
                                      train=False,
                                      transform=transform_test)
        print('train samples', dataset_train.data.shape[0])
        print('test samples', dataset_test.data.shape[0])
        self.loader_train = DataLoader(dataset=dataset_train,
                                       batch_size=batch_size,
                                       shuffle=shuffle,
                                       sampler=train_sampler,
                                       )
        self.loader_test = DataLoader(dataset=dataset_test,
                                      batch_size=batch_size,
                                      shuffle=False,
                                      )
        for images, labels in self.loader_train:
            print('single batch dimensions:', images.shape)
            print('single batch label dimensions:', labels.shape)
            self.number_of_features = images.shape[2] * images.shape[3]
            print('number of features', self.number_of_features)
            break
        self.model = AutoEncoder()
        self.model = self.model.to(device=self.device)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=learning_rate)
        self.loss_function = torch.nn.functional.mse_loss
    def run_batches(self, data_loader, batch_callback=None, print_details=False):
        start_time = time.time()
        total_loss = 0
        total_samples = 0
        batches_in_epoch_count = 0
        for batch_index, (batch_x, _) in enumerate(data_loader):
            batches_in_epoch_count += 1
            batch_x = batch_x.to(device=self.device)
            batch_samples = batch_x.shape[0]
            total_samples += batch_samples
            reconstructions = self.model(batch_x)
            batch_loss = self.loss_function(reconstructions, batch_x)
            if batch_callback is not None:
                batch_callback(reconstructions, batch_loss, batch_samples)
            # detach so the running sum does not keep the autograd graph alive
            total_loss += batch_loss.detach()
            if print_details:
                passed_seconds = time.time() - start_time
                if passed_seconds > 5:
                    start_time = time.time()
                    print('batch %05d loss %.5f' % (batch_index, batch_loss.item()))
        average_loss = total_loss / total_samples
        return average_loss.cpu(), batches_in_epoch_count
    def batch_callback_train(self, _, batch_loss, batch_samples):
        self.loss_train_per_batch.append(batch_loss.item() / batch_samples)
        self.optimizer.zero_grad()
        batch_loss.backward()
        self.optimizer.step()
    def train_epoch(self, epoch_index):
        start_time = time.time()
        # re-enable training mode; eval() is set below for the loss measurements
        self.model.train()
        self.run_batches(data_loader=self.loader_train, batch_callback=self.batch_callback_train,
                         print_details=True)
        with torch.no_grad():
            self.model.eval()
            epoch_loss_train, batches_in_epoch_count = self.run_batches(
                data_loader=self.loader_train)
            epoch_loss_test, _ = self.run_batches(data_loader=self.loader_test)
        self.loss_train_per_epoch.append(epoch_loss_train)
        self.loss_test_per_epoch.append(epoch_loss_test)
        passed_seconds = time.time() - start_time
        print('epoch %05d/%05d, epoch duration %03.0f seconds, loss train %.5f, loss test %.5f' % (
            epoch_index + 1, self.number_of_epochs, passed_seconds, epoch_loss_train, epoch_loss_test))
        return batches_in_epoch_count
    def train(self):
        self.loss_train_per_batch = []
        self.loss_train_per_epoch = []
        self.loss_test_per_epoch = []
        self.model.train()
        start_time = time.time()
        batches_in_epoch_count = 0
        for epoch_index in range(self.number_of_epochs):
            batches_in_epoch_count = self.train_epoch(epoch_index)
        passed_seconds = time.time() - start_time
        print('total training duration %03.0f seconds' % passed_seconds)
        self.model.eval()
        # repeat each per-epoch loss value once per batch so the per-epoch
        # curves align with the per-batch curve on the same x axis
        self.loss_train_per_epoch = self.spread_points(self.loss_train_per_epoch, batches_in_epoch_count)
        self.loss_test_per_epoch = self.spread_points(self.loss_test_per_epoch, batches_in_epoch_count)
        plt.clf()
        plt.plot(self.loss_train_per_batch, color='b', label='train batch')
        plt.plot(self.loss_train_per_epoch, color='g', label='train epoch')
        plt.plot(self.loss_test_per_epoch, color='r', label='test')
        plt.legend()
        plt.ylim(0, 0.003)
        plt.ylabel('Loss')
        plt.xlabel('Batch')
        plt.savefig("loss.pdf")
        self.loss_train_per_batch = []
    def save(self):
        torch.save(self.model.state_dict(), MODEL_PATH)

    def load(self):
        self.model = AutoEncoder()
        self.model.load_state_dict(torch.load(MODEL_PATH, map_location=self.device))
        self.model = self.model.to(self.device)
        self.model.eval()
    def show_examples(self):
        examples_amount = 10
        with torch.no_grad():
            for original_images, _ in self.loader_test:
                original_images = original_images.to(device=self.device)
                original_images = original_images[:examples_amount]
                decoded_images = self.model(original_images)
                batch_samples = original_images.shape[0]
                original_images = original_images.to('cpu')
                decoded_images = decoded_images.to('cpu')
                plt.clf()
                fig, axes = plt.subplots(nrows=2, ncols=batch_samples, sharex=True, sharey=True, figsize=(20, 2.5))
                for i in range(batch_samples):
                    for ax, img in zip(axes, [original_images, decoded_images]):
                        curr_img = img[i].detach()
                        ax[i].imshow(curr_img.view((IMAGE_WIDTH, IMAGE_WIDTH)), cmap='binary')
                plt.savefig("examples.pdf")
                return  # only the first test batch is plotted
    @staticmethod
    def spread_points(points, spread_factor):
        # repeat every point spread_factor times (used to stretch the
        # per-epoch loss series to the per-batch x axis)
        result = []
        for point in points:
            for _ in range(spread_factor):
                result.append(point)
        return result
def set_seed():
    random_seed = 123
    if torch.cuda.is_available():
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
    torch.use_deterministic_algorithms(True)  # replaces the deprecated torch.set_deterministic
    torch.manual_seed(random_seed)
    torch.cuda.manual_seed(random_seed)
def main():
    # call set_seed() here first for reproducible runs (it is otherwise never invoked)
    trainer = Trainer()
    trainer.train()
    trainer.save()
    trainer.load()
    trainer.show_examples()

if __name__ == '__main__':
    main()
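Since every image is squeezed down to just 2 numbers, the latent space itself is easy to inspect. The sketch below is my addition, not part of the original code; it assumes the AutoEncoder class, MODEL_PATH constant, and imports defined above, and scatter-plots the 2-D code of every test digit, colored by its label:

def plot_latent_space():
    model = AutoEncoder()
    model.load_state_dict(torch.load(MODEL_PATH, map_location='cpu'))
    model.eval()
    loader = DataLoader(dataset=datasets.MNIST(root='local_cache_folder',
                                               train=False,
                                               transform=torchvision.transforms.ToTensor()),
                        batch_size=256)
    codes = []
    labels = []
    with torch.no_grad():
        for images, batch_labels in loader:
            codes.append(model.encoder(images))  # each image becomes 2 numbers
            labels.append(batch_labels)
    codes = torch.cat(codes)
    labels = torch.cat(labels)
    plt.clf()
    plt.scatter(codes[:, 0], codes[:, 1], c=labels, s=2, cmap='tab10')
    plt.colorbar()
    plt.savefig("latent.pdf")

If the digit clusters come out reasonably separated, the decoder can also act as a small generator: decoding an arbitrary point, for example model.decoder(torch.tensor([[0.5, -1.0]])), produces a brand-new 28x28 digit-like image.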