Dataset reference: https://en.wikipedia.org/wiki/MNIST_database
This post walks through a multi-class classification example using a multi-layer neural network, trained on the MNIST database of handwritten digit images. The code follows an object-oriented design and aims to be readable and beginner-friendly.
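Before diving into the full program, here is a minimal sketch of the core idea in isolation (the sizes and names here are illustrative, not part of the program below): a linear layer maps input features to one score, a logit, per class, and cross-entropy loss compares those logits against integer class labels.

import torch

# a toy linear classifier: 4 input features, 3 classes (illustrative sizes)
toy_model = torch.nn.Linear(4, 3)
toy_inputs = torch.randn(5, 4)              # batch of 5 samples
toy_labels = torch.tensor([0, 2, 1, 2, 0])  # one integer class label per sample

logits = toy_model(toy_inputs)              # shape [5, 3]: one score per class
loss = torch.nn.functional.cross_entropy(logits, toy_labels)
predictions = torch.argmax(logits, dim=1)   # predicted class = highest logit
print(loss.item(), predictions)

The full program applies the same pattern, with a hidden layer in between and MNIST images flattened into feature vectors.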
import time
import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms

class MultiLayerNetwork(torch.nn.Module):
    """A two-layer fully connected network: input -> hidden (sigmoid) -> logits."""

    def __init__(self, number_of_features, number_of_outputs):
        super().__init__()
        self.number_of_features = number_of_features
        hidden_layer_width = 100
        self.network = torch.nn.Sequential(
            torch.nn.Linear(number_of_features, hidden_layer_width),
            torch.nn.Sigmoid(),
            torch.nn.Linear(hidden_layer_width, number_of_outputs),
        )

    def forward(self, x):
        return self.network(x)

class Trainer:
    def __init__(self):
        batch_size = 100
        learning_rate = 0.1
        limit_size = None
        # limit_size = 1000
        if limit_size is None:
            sampler = None
            shuffle = True
        else:
            # restrict each loader to the first limit_size samples (handy for quick experiments)
            sampler = torch.arange(limit_size)
            shuffle = False
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        print('using device:', self.device)
        dataset_train = datasets.MNIST(root='local_cache_folder',
                                       train=True,
                                       transform=transforms.ToTensor(),
                                       download=True)
        dataset_test = datasets.MNIST(root='local_cache_folder',
                                      train=False,
                                      transform=transforms.ToTensor())
        print('train samples', dataset_train.data.shape[0])
        print('test samples', dataset_test.data.shape[0])
        self.loader_train = DataLoader(dataset=dataset_train,
                                       batch_size=batch_size,
                                       shuffle=shuffle,
                                       sampler=sampler,
                                       )
        self.loader_test = DataLoader(dataset=dataset_test,
                                      batch_size=batch_size,
                                      shuffle=False,
                                      sampler=sampler,
                                      )
        # peek at one batch to derive the flattened feature count (28 * 28 = 784)
        for images, labels in self.loader_train:
            print('single batch dimensions:', images.shape)
            print('single batch label dimensions:', labels.shape)
            self.number_of_features = images.shape[2] * images.shape[3]
            print('number of features', self.number_of_features)
            break
        self.model = MultiLayerNetwork(number_of_features=self.number_of_features, number_of_outputs=10)
        self.model = self.model.to(device=self.device)
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=learning_rate)
        self.loss_function = torch.nn.functional.cross_entropy
    def run_batches(self, data_loader, batch_callback=None):
        total_loss = 0.0
        total_samples = 0
        correct_predictions = 0
        for batch_x, batch_labels in data_loader:
            batch_x = batch_x.to(device=self.device)
            batch_labels = batch_labels.to(device=self.device)
            batch_samples = batch_x.shape[0]
            total_samples += batch_samples
            # reshape from [100, 1, 28, 28] to [100, 28*28]
            batch_x = batch_x.view(-1, self.number_of_features)
            logits = self.model(batch_x)
            batch_loss = self.loss_function(logits, batch_labels)
            batch_predictions = torch.argmax(logits, dim=1)
            batch_correct = batch_predictions == batch_labels
            correct_predictions += batch_correct.sum().item()
            if batch_callback is not None:
                batch_callback(logits, batch_loss)
            # .item() detaches the scalar so the computation graph is not kept across batches
            total_loss += batch_loss.item()
        # note: batch_loss is already a per-batch mean, so this value is the mean
        # batch loss divided by the batch size (hence the small numbers printed below)
        average_loss = total_loss / total_samples
        accuracy = correct_predictions / total_samples
        return average_loss, accuracy

    def batch_callback_train(self, _, batch_loss):
        self.optimizer.zero_grad()
        batch_loss.backward()
        self.optimizer.step()
    def train(self, number_of_epochs):
        accuracy_train_per_epoch = []
        accuracy_test_per_epoch = []
        loss_train_per_epoch = []
        loss_test_per_epoch = []
        self.model.train(mode=True)
        for epoch_index in range(number_of_epochs):
            start_time = time.time()
            # one optimization pass over the training data
            self.run_batches(data_loader=self.loader_train, batch_callback=self.batch_callback_train)
            # evaluation passes over train and test data (no gradients needed)
            with torch.no_grad():
                epoch_loss_train, epoch_accuracy_train = self.run_batches(data_loader=self.loader_train)
                epoch_loss_test, epoch_accuracy_test = self.run_batches(data_loader=self.loader_test)
            loss_train_per_epoch.append(epoch_loss_train)
            loss_test_per_epoch.append(epoch_loss_test)
            accuracy_train_per_epoch.append(epoch_accuracy_train)
            accuracy_test_per_epoch.append(epoch_accuracy_test)
            passed_seconds = time.time() - start_time
            print(f'epoch {epoch_index},'
                  f'process seconds {passed_seconds},'
                  f'loss train {epoch_loss_train},'
                  f'loss test {epoch_loss_test},'
                  f'accuracy train {epoch_accuracy_train},'
                  f'accuracy test {epoch_accuracy_test}')
        self.model.train(mode=False)
        plt.clf()
        plt.plot(loss_train_per_epoch, color='b', label='train')
        plt.plot(loss_test_per_epoch, color='r', label='test')
        plt.legend()
        plt.ylabel('Loss')
        plt.xlabel('Epoch')
        plt.savefig("loss.pdf")
        plt.clf()
        plt.plot(accuracy_train_per_epoch, color='b', label='train')
        plt.plot(accuracy_test_per_epoch, color='r', label='test')
        plt.legend()
        plt.ylabel('Accuracy')
        plt.xlabel('Epoch')
        plt.savefig("accuracy.pdf")

def main():
    random_seed = 42
    number_of_epochs = 10
    torch.manual_seed(random_seed)
    trainer = Trainer()
    trainer.train(number_of_epochs)

if __name__ == '__main__':
    main()
The results are:
using device: cpu
train samples 60000
test samples 10000
single batch dimensions: torch.Size([100, 1, 28, 28])
single batch label dimensions: torch.Size([100])
number of features 784
epoch 0,loss train 0.0054548028856515884,loss test 0.005281297955662012,accuracy train 0.86655,accuracy test 0.8739
epoch 1,loss train 0.003863557009026408,loss test 0.0037198178470134735,accuracy train 0.8947833333333334,accuracy test 0.8988
epoch 2,loss train 0.0033619378227740526,loss test 0.0032330257818102837,accuracy train 0.9052166666666667,accuracy test 0.9099
epoch 3,loss train 0.003106057411059737,loss test 0.0030021234415471554,accuracy train 0.9111333333333334,accuracy test 0.9147
epoch 4,loss train 0.0029036011546850204,loss test 0.002816939726471901,accuracy train 0.9166333333333333,accuracy test 0.9199
epoch 5,loss train 0.002762931864708662,loss test 0.0026854947209358215,accuracy train 0.9205,accuracy test 0.9247
epoch 6,loss train 0.0026315872091799974,loss test 0.0025823484174907207,accuracy train 0.9245,accuracy test 0.9259
epoch 7,loss train 0.0025059168692678213,loss test 0.0024447101168334484,accuracy train 0.92775,accuracy test 0.9299
epoch 8,loss train 0.00239209970459342,loss test 0.0023570044431835413,accuracy train 0.9314,accuracy test 0.9323
epoch 9,loss train 0.0023001916706562042,loss test 0.0022823973558843136,accuracy train 0.93395,accuracy test 0.9349
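Once training finishes, the model can classify individual images. Here is a minimal sketch, assuming the Trainer class above is in scope; the flattening mirrors the reshape used in run_batches (a one-epoch run is used here purely for demonstration):

trainer = Trainer()
trainer.train(number_of_epochs=1)  # short training run, just for this demonstration

# classify one test image with the trained model
image, label = trainer.loader_test.dataset[0]  # ToTensor gives shape [1, 28, 28]
with torch.no_grad():
    batch = image.view(-1, trainer.number_of_features).to(trainer.device)  # flatten to [1, 784]
    logits = trainer.model(batch)
    prediction = torch.argmax(logits, dim=1).item()
print('predicted digit:', prediction, 'true label:', label)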