In this post we show a from-scratch implementation of ResNet32 using PyTorch.
The implementation is based on the following:
The ResNet32 architecture contains a convolutional first layer and a fully connected last layer.
The real trick of this network is the skip connections, which let the signal bypass a layer whose weights received too-low values during backpropagation, and hence avoid collapsing the entire chain.
One thing I noticed only during the implementation is that we have multiple layers handling the same input dimension. For example, we have 6 blocks handling the 28X28 input size. This consumes a huge number of parameters, even though this is one of the smaller common networks.
image from https://www.researchgate.net/figure/Number-of-training-parameters-in-millionsM-for-VGG-ResNet-and-DenseNet-models_tbl1_338552250
As the input dataset, we use CIFAR-10.
The code below is an object oriented based implementation.
import time
import matplotlib.pyplot as plt
import torch
import torchvision
from torch.utils.data import DataLoader
from torchvision import datasets
class ResnetBlock(torch.nn.Module):
    """Residual block: two 3x3 convolutions added to a shortcut of the input.

    The main path is conv -> batch-norm -> ReLU -> conv -> batch-norm.
    The shortcut is the input itself, or a 1x1 convolution followed by
    batch-norm when the main path changes the channel count or the stride.
    A final ReLU is applied to the sum of both paths.
    """

    def __init__(self, input_channels, output_channels, stride):
        super().__init__()

        # A Dropout2d(p=0.5) could be inserted in this path to reduce
        # overfitting.
        first_conv = torch.nn.Conv2d(
            input_channels, output_channels,
            kernel_size=3, stride=stride, padding=1, bias=False,
        )
        first_norm = torch.nn.BatchNorm2d(output_channels)
        second_conv = torch.nn.Conv2d(
            output_channels, output_channels,
            kernel_size=3, stride=1, padding=1, bias=False,
        )
        second_norm = torch.nn.BatchNorm2d(output_channels)
        self.convolution_layer = torch.nn.Sequential(
            first_conv,
            first_norm,
            torch.nn.ReLU(inplace=True),
            second_conv,
            second_norm,
        )

        # A projection is only needed when the main path's output no longer
        # matches the input in channel count or spatial stride.
        needs_projection = not (stride == 1 and input_channels == output_channels)
        self.skip_layer = None
        if needs_projection:
            self.skip_layer = torch.nn.Sequential(
                torch.nn.Conv2d(
                    input_channels, output_channels,
                    kernel_size=1, stride=stride, bias=False,
                ),
                torch.nn.BatchNorm2d(output_channels),
            )

        self.relu = torch.nn.ReLU(inplace=True)

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        residual = self.convolution_layer(x)
        shortcut = x if self.skip_layer is None else self.skip_layer(x)
        return self.relu(residual + shortcut)
class ResnetLayer(torch.nn.Module):
    """A stack of ``blocks_count`` residual blocks applied in sequence.

    Only the first block receives ``stride`` and performs the
    ``input_channels`` -> ``output_channels`` change; every following block
    keeps ``output_channels`` and uses stride 1.
    """

    def __init__(self, input_channels, output_channels, stride, blocks_count):
        super().__init__()
        entry_block = ResnetBlock(input_channels, output_channels, stride)
        following_blocks = [
            ResnetBlock(output_channels, output_channels, 1)
            for _ in range(blocks_count - 1)
        ]
        self.layers = torch.nn.Sequential(entry_block, *following_blocks)

    def forward(self, x):
        """Run ``x`` through all blocks in order."""
        return self.layers(x)
class Resnet32(torch.nn.Module):
    """Residual network: a convolutional stem, four residual stages, a classifier.

    The spatial sizes in the comments below assume a 224 X 224 input.
    """

    def __init__(self, number_of_classes):
        super().__init__()

        # Input: 3 channels, 224 X 224.
        # The stem's stride-2 convolution gives 112 X 112 and its stride-2
        # max-pool gives 56 X 56.
        self.convolution1 = self.create_convolution1(3, 64)

        # 56 X 56 in, 56 X 56 out (stride 1).
        self.block1 = ResnetLayer(64, 128, stride=1, blocks_count=3)

        # 56 X 56 in, 28 X 28 out.
        self.block2 = ResnetLayer(128, 256, stride=2, blocks_count=4)

        # 28 X 28 in, 14 X 14 out.
        self.block3 = ResnetLayer(256, 512, stride=2, blocks_count=6)

        # 14 X 14 in, 7 X 7 out.
        self.block4 = ResnetLayer(512, 1024, stride=2, blocks_count=3)

        # 7 X 7 feature maps are pooled to 1 X 1 and mapped to class scores.
        self.classifier = self.create_classifier(1024, number_of_classes)

    @staticmethod
    def create_convolution1(input_channels, output_channels):
        """Build the stem: 7x7 stride-2 conv, batch-norm, ReLU, 3x3 stride-2 max-pool."""
        stem_modules = [
            torch.nn.Conv2d(input_channels, output_channels, kernel_size=7, stride=2, padding=3, bias=False),
            torch.nn.BatchNorm2d(output_channels),
            torch.nn.ReLU(inplace=True),
            torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        ]
        return torch.nn.Sequential(*stem_modules)

    @staticmethod
    def create_classifier(number_of_channels, number_of_classes):
        """Build the head: global average pool, flatten, linear layer."""
        head_modules = [
            torch.nn.AdaptiveAvgPool2d((1, 1)),
            torch.nn.Flatten(),
            torch.nn.Linear(number_of_channels, number_of_classes),
        ]
        return torch.nn.Sequential(*head_modules)

    def forward(self, x):
        """Return the class scores produced by the classifier for the batch ``x``."""
        stages = (
            self.convolution1,
            self.block1,
            self.block2,
            self.block3,
            self.block4,
            self.classifier,
        )
        for stage in stages:
            x = stage(x)
        return x
class Trainer:
    """Train ``Resnet32`` on CIFAR-10, evaluate it per epoch, and plot the curves.

    Construction downloads/loads the dataset, builds the data loaders, the
    model, the SGD optimizer and the learning-rate scheduler.  ``train`` runs
    the epochs and writes ``loss.pdf`` and ``accuracy.pdf``.
    """

    def __init__(self, limit_size=1000):
        """Prepare data, model and optimizer.

        Args:
            limit_size: Maximum number of samples taken (in dataset order,
                without shuffling) from each of the train and test sets.
                ``None`` uses the full datasets and shuffles the train set.
                The default keeps the previous hard-coded quick-run value.
        """
        batch_size = 128
        learning_rate = 0.1
        learning_momentum = 0.9
        learning_rate_scheduler_factor = 0.1

        self.loss_train_per_batch = []

        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        print('using device:', self.device)

        scale_up = (250, 250)
        crop_down = (224, 224)
        mean = (0.5, 0.5, 0.5)
        std = (0.5, 0.5, 0.5)

        transform_train = torchvision.transforms.Compose(
            [
                torchvision.transforms.Resize(scale_up),
                torchvision.transforms.RandomCrop(crop_down),
                torchvision.transforms.RandomRotation(20),
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize(mean, std),
            ]
        )

        transform_test = torchvision.transforms.Compose(
            [
                torchvision.transforms.Resize(scale_up),
                torchvision.transforms.CenterCrop(crop_down),
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize(mean, std),
            ]
        )

        dataset_train = datasets.CIFAR10(root='local_cache_folder',
                                         train=True,
                                         transform=transform_train,
                                         download=True)

        dataset_test = datasets.CIFAR10(root='local_cache_folder',
                                        train=False,
                                        transform=transform_test)

        print('train samples', dataset_train.data.shape[0])
        # Report the test set's own size (this used to print the train size).
        print('test samples', dataset_test.data.shape[0])

        if limit_size is None:
            sampler_train = None
            sampler_test = None
            shuffle = True
        else:
            # Each dataset gets its own index range, clamped to its length so
            # a limit larger than the dataset cannot index out of bounds.
            sampler_train = range(min(limit_size, len(dataset_train)))
            sampler_test = range(min(limit_size, len(dataset_test)))
            # DataLoader does not allow shuffle=True together with a sampler.
            shuffle = False

        self.loader_train = DataLoader(dataset=dataset_train,
                                       batch_size=batch_size,
                                       shuffle=shuffle,
                                       sampler=sampler_train,
                                       )
        self.loader_test = DataLoader(dataset=dataset_test,
                                      batch_size=batch_size,
                                      shuffle=False,
                                      sampler=sampler_test,
                                      )

        # Peek at one batch only to report its dimensions.
        first_batch = next(iter(self.loader_train), None)
        if first_batch is not None:
            images, labels = first_batch
            print('single batch dimensions:', images.shape)
            print('single batch label dimensions:', labels.shape)
            self.number_of_features = images.shape[2] * images.shape[3]
            print('number of features', self.number_of_features)

        self.model = Resnet32(number_of_classes=10)
        self.model = self.model.to(device=self.device)

        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=learning_rate, momentum=learning_momentum)
        # The scheduler is stepped with the train loss in ``train``, which
        # should decrease, so the plateau detection must use mode='min'.
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer=self.optimizer,
                                                                    factor=learning_rate_scheduler_factor,
                                                                    mode='min')
        self.loss_function = torch.nn.functional.cross_entropy

    def run_batches(self, data_loader, batch_callback=None):
        """Run the model over every batch of ``data_loader``.

        Args:
            data_loader: Iterable of ``(inputs, labels)`` batches.
            batch_callback: Optional callable invoked per batch as
                ``batch_callback(logits, batch_loss, batch_samples)``;
                training passes the optimizer step through it.

        Returns:
            Tuple ``(average_loss, accuracy, batches_count)`` where
            ``average_loss`` is a CPU tensor holding the summed batch losses
            divided by the number of samples, ``accuracy`` is the fraction of
            correct predictions, and ``batches_count`` is the number of
            batches seen.  An empty loader yields a zero loss and accuracy 0.0.
        """
        total_loss = torch.zeros((), device=self.device)
        total_samples = 0
        correct_predictions = 0
        batches_in_epoch_count = 0

        for batch_x, batch_labels in data_loader:
            batches_in_epoch_count += 1
            batch_x = batch_x.to(device=self.device)
            batch_labels = batch_labels.to(device=self.device)
            batch_samples = batch_x.shape[0]
            total_samples += batch_samples

            logics = self.model(batch_x)
            batch_loss = self.loss_function(logics, batch_labels)
            batch_predictions = torch.argmax(logics, dim=1)
            correct_predictions += (batch_predictions == batch_labels).sum()

            if batch_callback is not None:
                batch_callback(logics, batch_loss, batch_samples)

            # Detach so the running total does not keep every batch's
            # autograd graph alive for the whole epoch.
            total_loss = total_loss + batch_loss.detach()

        if total_samples == 0:
            return total_loss.cpu(), 0.0, batches_in_epoch_count

        # NOTE(review): cross_entropy's default reduction presumably already
        # averages over the batch, so this is roughly mean loss / batch size.
        # Kept unchanged because the loss plot's y-limits rely on this scale;
        # confirm whether it is intended.
        average_loss = total_loss / total_samples
        accuracy = float(correct_predictions) / total_samples
        return average_loss.cpu(), accuracy, batches_in_epoch_count

    def batch_callback_train(self, _, batch_loss, batch_samples):
        """Record the batch loss and perform one optimizer step on it."""
        self.loss_train_per_batch.append(batch_loss.item() / batch_samples)
        self.optimizer.zero_grad()
        batch_loss.backward()
        self.optimizer.step()

    def train(self, number_of_epochs):
        """Train for ``number_of_epochs`` epochs and save the loss/accuracy plots.

        Each epoch performs one optimization pass over the train loader,
        then measures loss and accuracy on the train and test loaders without
        gradients, and steps the learning-rate scheduler with the train loss.
        The plots are written to ``loss.pdf`` and ``accuracy.pdf``.
        """
        accuracy_train_per_epoch = []
        accuracy_test_per_epoch = []
        loss_train_per_epoch = []
        loss_test_per_epoch = []
        self.loss_train_per_batch = []
        # Defined up front so the plotting below also works for zero epochs.
        batches_in_epoch_count = 0

        for epoch_index in range(number_of_epochs):
            start_time = time.time()

            # Training mode for the optimization pass ...
            self.model.train()
            self.run_batches(data_loader=self.loader_train, batch_callback=self.batch_callback_train)

            # ... and evaluation mode for the measurements, so that layers
            # which behave differently in training (batch-norm here) do not
            # change state while the metrics are computed.
            self.model.eval()
            with torch.no_grad():
                epoch_loss_train, epoch_accuracy_train, batches_in_epoch_count = self.run_batches(
                    data_loader=self.loader_train)
                epoch_loss_test, epoch_accuracy_test, _ = self.run_batches(data_loader=self.loader_test)

            loss_train_per_epoch.append(epoch_loss_train)
            loss_test_per_epoch.append(epoch_loss_test)
            accuracy_train_per_epoch.append(epoch_accuracy_train)
            accuracy_test_per_epoch.append(epoch_accuracy_test)

            self.scheduler.step(float(epoch_loss_train))

            passed_seconds = time.time() - start_time
            print(f'epoch {epoch_index},'
                  f'process seconds {passed_seconds},'
                  f'loss train {epoch_loss_train},'
                  f'loss test {epoch_loss_test},'
                  f'accuracy train {epoch_accuracy_train},'
                  f'accuracy test {epoch_accuracy_test}')

        self.model.eval()

        # Repeat each per-epoch value once per batch so the per-epoch curves
        # share the x-axis of the per-batch curve.
        loss_train_per_epoch = self.spread_points(loss_train_per_epoch, batches_in_epoch_count)
        loss_test_per_epoch = self.spread_points(loss_test_per_epoch, batches_in_epoch_count)

        plt.clf()
        plt.plot(self.loss_train_per_batch, color='b', label='train batch')
        plt.plot(loss_train_per_epoch, color='g', label='train epoch')
        plt.plot(loss_test_per_epoch, color='r', label='test')
        plt.legend()
        plt.ylim(0, 0.01)
        plt.ylabel('Loss')
        plt.xlabel('Epoch')
        plt.savefig("loss.pdf")

        plt.clf()
        plt.plot(accuracy_train_per_epoch, color='b', label='train')
        plt.plot(accuracy_test_per_epoch, color='r', label='test')
        plt.legend()
        plt.ylabel('Accuracy')
        plt.xlabel('Epoch')
        plt.savefig("accuracy.pdf")

        self.loss_train_per_batch = []

    @staticmethod
    def spread_points(points, spread_factor):
        """Return a list with every element of ``points`` repeated ``spread_factor`` times."""
        return [point for point in points for _ in range(spread_factor)]
def main():
    """Seed torch's random generator, build a ``Trainer`` and train for 10 epochs."""
    random_seed = 42
    number_of_epochs = 10
    torch.manual_seed(random_seed)
    trainer = Trainer()
    trainer.train(number_of_epochs)


# Guarded so that importing this module does not start a training run.
if __name__ == "__main__":
    main()
Training this network is very resource-consuming. I ran it on a CPU with just 10 samples to check that it does not fail, and it took several minutes.
Running it on Google Colab also took hours.
No comments:
Post a Comment