
Thursday, August 22, 2024

Resnet32 Full Implementation in PyTorch

 

In this post we show a from-scratch implementation of Resnet32 using PyTorch.

The implementation is based on the following architecture:


The Resnet32 architecture contains a first convolution layer and a fully connected last layer, with four stages of residual blocks in between.

The real trick of this network is the skip connections: each block adds its unchanged input to its output, so even when a block's weights end up with very low values during backpropagation, the signal and its gradient still flow through the identity path, which keeps the entire chain from collapsing.
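To make this concrete, here is a tiny toy example (not part of the network code below): with y = f(x) + x, the gradient dy/dx is f'(x) + 1, so even a block whose weights are almost zero still passes a gradient of roughly 1 backwards.

import torch

def f(t):  # a "block" whose weights have very small values
    return 1e-6 * t

x = torch.tensor(2.0, requires_grad=True)
y = f(x) + x  # the residual (skip) connection
y.backward()
print(x.grad)  # ~1.000001: the identity path contributes the 1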




Something I noticed only during the implementation is that multiple blocks handle the same input dimensions. For example, there are 6 blocks handling the 28X28 input size. This consumes a huge amount of parameters, even though this is one of the smaller common networks.



Image from https://www.researchgate.net/figure/Number-of-training-parameters-in-millionsM-for-VGG-ResNet-and-DenseNet-models_tbl1_338552250
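To see where the parameters go, we can count them per stage. The sketch below uses torchvision's reference resnet34 rather than the Resnet32 class defined further down, and assumes a recent torchvision where the weights argument is available, so the exact numbers are only indicative:

import torchvision

model = torchvision.models.resnet34(weights=None)
total = sum(p.numel() for p in model.parameters())
print(f'total parameters: {total / 1e6:.1f}M')

# per-stage parameter counts
for name, child in model.named_children():
    count = sum(p.numel() for p in child.parameters())
    print(f'{name}: {count / 1e6:.2f}M')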



As the input dataset, we use CIFAR-10.


The code below is an object-oriented implementation.


import time

import matplotlib.pyplot as plt
import torch
import torchvision
from torch.utils.data import DataLoader
from torchvision import datasets


class ResnetBlock(torch.nn.Module):
    def __init__(self, input_channels, output_channels, stride):
        super(ResnetBlock, self).__init__()

        # we could add Dropout2d(p=0.5) here to avoid overfitting

        self.convolution_layer = torch.nn.Sequential(
            torch.nn.Conv2d(input_channels, output_channels, kernel_size=3, stride=stride, padding=1, bias=False),
            torch.nn.BatchNorm2d(output_channels),
            torch.nn.ReLU(inplace=True),
            torch.nn.Conv2d(output_channels, output_channels, kernel_size=3, stride=1, padding=1, bias=False),
            torch.nn.BatchNorm2d(output_channels),
        )
        self.skip_layer = None
        if stride != 1 or input_channels != output_channels:
            self.skip_layer = torch.nn.Sequential(
                torch.nn.Conv2d(input_channels, output_channels, kernel_size=1, stride=stride, bias=False),
                torch.nn.BatchNorm2d(output_channels),
            )
        self.relu = torch.nn.ReLU(inplace=True)

    def forward(self, x):
        identity = x
        z = self.convolution_layer(x)

        if self.skip_layer is not None:
            identity = self.skip_layer(x)

        out = z + identity
        return self.relu(out)


class ResnetLayer(torch.nn.Module):
    def __init__(self, input_channels, output_channels, stride, blocks_count):
        super(ResnetLayer, self).__init__()
        layers = [
            ResnetBlock(input_channels, output_channels, stride),
        ]

        for _ in range(blocks_count - 1):
            layers.append(ResnetBlock(output_channels, output_channels, 1))

        self.layers = torch.nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)


class Resnet32(torch.nn.Module):

    def __init__(self, number_of_classes):
        super(Resnet32, self).__init__()

        # input: 224 X 224

        input_channels = 3
        output_channels = 64
        self.convolution1 = self.create_convolution1(input_channels, output_channels)

        # 56 X 56 (after the stride-2 convolution and the stride-2 max pooling)

        input_channels = output_channels
        output_channels = 128
        self.block1 = ResnetLayer(input_channels, output_channels, stride=1, blocks_count=3)

        # 56 X 56

        input_channels = output_channels
        output_channels = 256
        self.block2 = ResnetLayer(input_channels, output_channels, stride=2, blocks_count=4)

        # 28 X 28

        input_channels = output_channels
        output_channels = 512
        self.block3 = ResnetLayer(input_channels, output_channels, stride=2, blocks_count=6)

        # 14 X 14

        input_channels = output_channels
        output_channels = 1024
        self.block4 = ResnetLayer(input_channels, output_channels, stride=2, blocks_count=3)

        # 7 X 7

        self.classifier = self.create_classifier(output_channels, number_of_classes)

    @staticmethod
    def create_convolution1(input_channels, output_channels):
        return torch.nn.Sequential(
            torch.nn.Conv2d(input_channels, output_channels, kernel_size=7, stride=2, padding=3, bias=False),
            torch.nn.BatchNorm2d(output_channels),
            torch.nn.ReLU(inplace=True),
            torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )

    @staticmethod
    def create_classifier(number_of_channels, number_of_classes):
        return torch.nn.Sequential(
            torch.nn.AdaptiveAvgPool2d((1, 1)),
            torch.nn.Flatten(),
            torch.nn.Linear(number_of_channels, number_of_classes),
        )

    def forward(self, x):
        x = self.convolution1(x)
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        x = self.classifier(x)
        return x


class Trainer:

    def __init__(self):
        batch_size = 128
        learning_rate = 0.1
        learning_momentum = 0.9
        learning_rate_scheduler_factor = 0.1

        self.loss_train_per_batch = []

        # set limit_size to None to train on the full dataset;
        # a small value gives a quick debug run
        limit_size = None
        limit_size = 1000
        if limit_size is None:
            sampler = None
            shuffle = True
        else:
            sampler = torch.arange(limit_size)
            shuffle = False

        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        print('using device:', self.device)

        scale_up = (250, 250)
        crop_down = (224, 224)
        mean = (0.5, 0.5, 0.5)
        std = (0.5, 0.5, 0.5)

        transform_train = torchvision.transforms.Compose(
            [
                torchvision.transforms.Resize(scale_up),
                torchvision.transforms.RandomCrop(crop_down),
                torchvision.transforms.RandomRotation(20),
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize(mean, std),
            ]
        )
        transform_test = torchvision.transforms.Compose(
            [
                torchvision.transforms.Resize(scale_up),
                torchvision.transforms.CenterCrop(crop_down),
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize(mean, std),
            ]
        )

        dataset_train = datasets.CIFAR10(root='local_cache_folder',
                                         train=True,
                                         transform=transform_train,
                                         download=True)

        dataset_test = datasets.CIFAR10(root='local_cache_folder',
                                        train=False,
                                        transform=transform_test)

        print('train samples', dataset_train.data.shape[0])
        print('test samples', dataset_test.data.shape[0])

        self.loader_train = DataLoader(dataset=dataset_train,
                                       batch_size=batch_size,
                                       shuffle=shuffle,
                                       sampler=sampler,
                                       )

        self.loader_test = DataLoader(dataset=dataset_test,
                                      batch_size=batch_size,
                                      shuffle=False,
                                      sampler=sampler,
                                      )

        for images, labels in self.loader_train:
            print('single batch dimensions:', images.shape)
            print('single batch label dimensions:', labels.shape)
            self.number_of_features = images.shape[2] * images.shape[3]
            print('number of features', self.number_of_features)
            break

        self.model = Resnet32(number_of_classes=10)
        self.model = self.model.to(device=self.device)

        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=learning_rate, momentum=learning_momentum)

        # the scheduler monitors the training loss, so it must run in 'min' mode
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer=self.optimizer,
                                                                    factor=learning_rate_scheduler_factor,
                                                                    mode='min')

        self.loss_function = torch.nn.functional.cross_entropy

    def run_batches(self, data_loader, batch_callback=None):
        total_loss = 0
        total_samples = 0
        correct_predictions = 0
        batches_in_epoch_count = 0
        for _, (batch_x, batch_labels) in enumerate(data_loader):
            batches_in_epoch_count += 1
            batch_x = batch_x.to(device=self.device)
            batch_labels = batch_labels.to(device=self.device)
            batch_samples = batch_x.shape[0]
            total_samples += batch_samples
            logits = self.model(batch_x)
            batch_loss = self.loss_function(logits, batch_labels)
            batch_predictions = torch.argmax(logits, dim=1)
            batch_correct = batch_predictions == batch_labels
            correct_predictions += batch_correct.sum()
            if batch_callback is not None:
                batch_callback(logits, batch_loss, batch_samples)
            # cross_entropy returns the mean loss per sample,
            # so weight it by the batch size before accumulating
            total_loss += batch_loss.item() * batch_samples

        average_loss = total_loss / total_samples
        accuracy = float(correct_predictions) / total_samples

        return average_loss, accuracy, batches_in_epoch_count

    def batch_callback_train(self, _, batch_loss, batch_samples):
        # batch_loss is already an average per sample, no further division needed
        self.loss_train_per_batch.append(batch_loss.item())

        self.optimizer.zero_grad()
        batch_loss.backward()
        self.optimizer.step()

    def train(self, number_of_epochs):
        accuracy_train_per_epoch = []
        accuracy_test_per_epoch = []
        loss_train_per_epoch = []
        loss_test_per_epoch = []
        self.loss_train_per_batch = []
        self.model.train()

        for epoch_index in range(number_of_epochs):
            start_time = time.time()
            self.run_batches(data_loader=self.loader_train, batch_callback=self.batch_callback_train)

            # switch to evaluation mode so BatchNorm uses its running statistics
            self.model.eval()
            with torch.no_grad():
                epoch_loss_train, epoch_accuracy_train, batches_in_epoch_count = self.run_batches(
                    data_loader=self.loader_train)
                epoch_loss_test, epoch_accuracy_test, _ = self.run_batches(data_loader=self.loader_test)
            self.model.train()

            loss_train_per_epoch.append(epoch_loss_train)
            loss_test_per_epoch.append(epoch_loss_test)
            accuracy_train_per_epoch.append(epoch_accuracy_train)
            accuracy_test_per_epoch.append(epoch_accuracy_test)

            self.scheduler.step(epoch_loss_train)

            passed_seconds = time.time() - start_time
            print(f'epoch {epoch_index}, '
                  f'process seconds {passed_seconds}, '
                  f'loss train {epoch_loss_train}, '
                  f'loss test {epoch_loss_test}, '
                  f'accuracy train {epoch_accuracy_train}, '
                  f'accuracy test {epoch_accuracy_test}')

        self.model.eval()

        # repeat each per-epoch point once per batch so the per-epoch curves
        # align with the per-batch curve on the same x axis
        loss_train_per_epoch = self.spread_points(loss_train_per_epoch, batches_in_epoch_count)
        loss_test_per_epoch = self.spread_points(loss_test_per_epoch, batches_in_epoch_count)
        plt.clf()
        plt.plot(self.loss_train_per_batch, color='b', label='train batch')
        plt.plot(loss_train_per_epoch, color='g', label='train epoch')
        plt.plot(loss_test_per_epoch, color='r', label='test')
        plt.legend()
        plt.ylabel('Loss')
        plt.xlabel('Batch')
        plt.savefig("loss.pdf")

        plt.clf()
        plt.plot(accuracy_train_per_epoch, color='b', label='train')
        plt.plot(accuracy_test_per_epoch, color='r', label='test')
        plt.legend()
        plt.ylabel('Accuracy')
        plt.xlabel('Epoch')
        plt.savefig("accuracy.pdf")
        self.loss_train_per_batch = []

    @staticmethod
    def spread_points(points, spread_factor):
        result = []
        for point in points:
            for _ in range(spread_factor):
                result.append(point)
        return result


def main():
    random_seed = 42
    number_of_epochs = 10

    torch.manual_seed(random_seed)
    trainer = Trainer()
    trainer.train(number_of_epochs)


if __name__ == '__main__':
    main()




Training this network is very resource consuming. I ran it on a CPU with just 10 samples to check that it does not fail, and even that took several minutes.
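A cheaper sanity check before committing to a long run is to push a single random batch through the untrained network and verify the output shape. This is a small sketch using the Resnet32 class defined above:

import torch

model = Resnet32(number_of_classes=10)
model.eval()
with torch.no_grad():
    out = model(torch.randn(2, 3, 224, 224))
print(out.shape)  # expected: torch.Size([2, 10])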

Running this on Google Colab also took hours.
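On a GPU, one way to cut the training cost is automatic mixed precision. The sketch below is not part of the Trainer class above; it assumes a model, an optimizer and a data loader like the ones built there, and the helper name train_one_epoch_amp is mine:

import torch

def train_one_epoch_amp(model, optimizer, loader, device='cuda'):
    # run one training epoch in mixed precision using torch.cuda.amp
    scaler = torch.cuda.amp.GradScaler()
    model.train()
    for batch_x, batch_labels in loader:
        batch_x, batch_labels = batch_x.to(device), batch_labels.to(device)
        optimizer.zero_grad()
        with torch.cuda.amp.autocast():
            loss = torch.nn.functional.cross_entropy(model(batch_x), batch_labels)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()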

