Full Blog TOC

Full Blog Table of Contents with Keywords Available HERE

Thursday, August 22, 2024

Transfer Learning

 

In this post we present an example of transfer learning, a.k.a. fine-tuning of a pre-trained model.

The code below is simply a cleaned-up, class-based version of the example in the PyTorch "Transfer Learning for Computer Vision" tutorial.




import os
import time
from tempfile import TemporaryDirectory

import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.optim as optim
import torchvision
from torch.optim import lr_scheduler
from torchvision import datasets, transforms


class TransferLearning:
    """Fine-tune an ImageNet-pretrained ResNet-18 on an ImageFolder dataset.

    Every pretrained parameter is frozen; only a freshly created final
    fully connected layer (one output per class) is trained.

    The data directory must contain ``train`` and ``val`` sub-folders laid
    out the way ``torchvision.datasets.ImageFolder`` expects
    (``<split>/<class_name>/<image files>``).
    """

    def __init__(self, data_dir='hymenoptera_data', batch_size=4, num_workers=4):
        """Build the data pipeline, model, optimizer and learning-rate schedule.

        Args:
            data_dir: root folder holding the ``train`` and ``val`` splits.
                The default is the tutorial dataset, which must be downloaded
                and extracted manually from
                https://download.pytorch.org/tutorial/hymenoptera_data.zip
            batch_size: number of images per mini-batch (both splits).
            num_workers: number of DataLoader worker processes per split.
        """
        cudnn.benchmark = True
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Normalization statistics the pretrained ImageNet weights were trained with.
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
        image_size = 224
        data_transforms = {
            # Random crop + horizontal flip augment the training images.
            'train': transforms.Compose([
                transforms.RandomResizedCrop(image_size),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(mean, std),
            ]),
            # Deterministic resize + center crop for evaluation.
            'val': transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(image_size),
                transforms.ToTensor(),
                transforms.Normalize(mean, std),
            ]),
        }

        phases = ('train', 'val')
        image_datasets = {
            phase: datasets.ImageFolder(os.path.join(data_dir, phase), data_transforms[phase])
            for phase in phases
        }
        self.dataloaders = {
            phase: torch.utils.data.DataLoader(
                image_datasets[phase],
                batch_size=batch_size,
                # Shuffling only matters for training; validation accuracy
                # does not depend on sample order.
                shuffle=(phase == 'train'),
                num_workers=num_workers,
            )
            for phase in phases
        }
        self.dataset_sizes = {phase: len(image_datasets[phase]) for phase in phases}
        self.class_names = image_datasets['train'].classes
        print('class names are:', self.class_names)

        model_conv = torchvision.models.resnet18(weights='IMAGENET1K_V1')

        print('original model')
        print(model_conv)
        print()

        # Freeze the whole pretrained network.
        for param in model_conv.parameters():
            param.requires_grad = False

        # Replace the last fully connected classifier. The new nn.Linear is
        # created after the freeze loop, so its parameters still require
        # gradients: it is the only layer that will be trained.
        model_conv.fc = nn.Linear(model_conv.fc.in_features, len(self.class_names))

        self.model = model_conv.to(self.device)

        print('transform model')
        print(self.model)
        print()

        # Optimize only the new head; StepLR multiplies the learning rate by
        # 0.1 every 7 calls to scheduler.step(), i.e. every 7 training epochs
        # because run_epoch steps it once per training epoch.
        self.optimizer = optim.SGD(self.model.fc.parameters(), lr=0.001, momentum=0.9)
        self.learning_rate_scheduler = lr_scheduler.StepLR(self.optimizer, step_size=7, gamma=0.1)

    def train_model(self, num_epochs=3):
        """Train for ``num_epochs`` epochs, then restore the best validation weights.

        Each epoch runs one training pass followed by one validation pass.
        Whenever validation accuracy improves, the model's ``state_dict`` is
        saved to a temporary file; after the final epoch that checkpoint is
        loaded back into ``self.model``.

        Args:
            num_epochs: number of train/validation epochs to run.
        """
        start_time = time.time()

        # The temporary directory (and the checkpoint in it) is removed on exit.
        with TemporaryDirectory() as tempdir:
            best_model_params_path = os.path.join(tempdir, 'best_model_params.pt')

            # Save the starting weights so a checkpoint always exists to
            # restore, even if validation accuracy never exceeds 0.
            torch.save(self.model.state_dict(), best_model_params_path)
            best_acc = 0.0

            for epoch in range(num_epochs):
                print(f'Epoch {epoch}/{num_epochs - 1}')
                print('-' * 10)

                self.model.train()
                self.run_epoch('train')

                self.model.eval()
                epoch_acc = self.run_epoch('val')

                if epoch_acc > best_acc:
                    best_acc = epoch_acc
                    torch.save(self.model.state_dict(), best_model_params_path)

                print()

            time_elapsed = time.time() - start_time
            print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
            # '.4f' = four decimal places (the old '4f' was a minimum width of 4).
            print(f'Best validation Acc: {best_acc:.4f}')

            # Load the best weights before the temporary directory is deleted.
            self.model.load_state_dict(torch.load(best_model_params_path))

    def run_epoch(self, phase):
        """Run one pass over the ``phase`` split and return its accuracy.

        In the ``'train'`` phase gradients are enabled, the optimizer updates
        the parameters after every batch, and the learning-rate scheduler is
        stepped once after the whole epoch. In any other phase gradients are
        disabled. Putting the model in train/eval mode is the caller's job.

        Args:
            phase: ``'train'`` or ``'val'``; selects the dataloader and the
                dataset size used for averaging.

        Returns:
            float: fraction of correctly classified samples in the split.
        """
        is_training = phase == 'train'
        running_loss = 0.0
        running_corrects = 0

        for inputs, labels in self.dataloaders[phase]:
            inputs = inputs.to(self.device)
            labels = labels.to(self.device)

            # Clear gradients left over from the previous batch.
            self.optimizer.zero_grad()

            with torch.set_grad_enabled(is_training):
                outputs = self.model(inputs)
                predictions = outputs.argmax(dim=1)
                loss = nn.functional.cross_entropy(outputs, labels)

                if is_training:
                    loss.backward()
                    self.optimizer.step()

            # cross_entropy returns the batch mean; weight it by batch size so
            # the epoch loss is a per-sample average over the whole split.
            running_loss += loss.item() * inputs.size(0)
            running_corrects += (predictions == labels).sum().item()

        # StepLR counts epochs: step it once per training epoch, not per batch.
        if is_training:
            self.learning_rate_scheduler.step()

        epoch_loss = running_loss / self.dataset_sizes[phase]
        epoch_acc = running_corrects / self.dataset_sizes[phase]

        print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')
        return epoch_acc


def main():
    """Fine-tune the model on the default dataset and report the results."""
    transfer = TransferLearning()
    transfer.train_model()


# The guard is required because the DataLoaders use worker processes. Under
# the "spawn" start method (the default on Windows and macOS), each worker
# re-imports this module. An unguarded main() call would re-run training
# inside every worker, and multiprocessing rejects that with a RuntimeError.
if __name__ == '__main__':
    main()

No comments:

Post a Comment