In this post we present text generation with an LSTM.
This post is based on the character RNN by Sebastian Raschka.
The text generator uses an embedding per character, followed by an LSTM cell and a fully connected layer.
The training input is based on Shakespeare text.
The training actually runs pretty well, and we get nice text after a very short time.
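For a single character, the forward pass maps a character index to an embedding vector, feeds it through an LSTM cell together with the previous hidden and cell states, and projects the new hidden state to per-character scores. A minimal shape sketch of this flow (the dimensions here are only illustrative; the actual values appear in the implementation below):

import string
import torch

characters = string.printable                      # 100 printable characters
embedding = torch.nn.Embedding(len(characters), 100)
rnn = torch.nn.LSTMCell(input_size=100, hidden_size=100)
fully_connected = torch.nn.Linear(100, len(characters))

character = torch.tensor([characters.index('T')])  # shape (1,)
hidden_output = torch.zeros(1, 100)
cell_state = torch.zeros(1, 100)

embedded = embedding(character)                     # shape (1, 100)
hidden_output, cell_state = rnn(embedded, (hidden_output, cell_state))
scores = fully_connected(hidden_output)             # shape (1, 100): one score per character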
The loss graph is:
The implementation is:
import string
import time
import matplotlib.pyplot as plt
import torch
class CharacterRnn(torch.nn.Module):
    def __init__(self):
        super(CharacterRnn, self).__init__()
        self.characters = string.printable
        number_of_characters = len(self.characters)
        embedding_dimension = 100
        self.hidden_output_dimension = 100
        self.embedding = torch.nn.Embedding(
            num_embeddings=number_of_characters,
            embedding_dim=embedding_dimension,
        )
        self.rnn = torch.nn.LSTMCell(
            input_size=embedding_dimension,
            hidden_size=self.hidden_output_dimension,
        )
        self.fully_connected = torch.nn.Linear(
            in_features=self.hidden_output_dimension,
            out_features=number_of_characters,
        )

    def produce_start_state(self):
        hidden_output = torch.zeros(1, self.hidden_output_dimension)
        cell_state = torch.ones(1, self.hidden_output_dimension)
        return hidden_output, cell_state

    def forward(self, character, hidden_output, cell_state):
        embedding = self.embedding(character)
        (hidden_output, cell_state) = self.rnn(embedding, (hidden_output, cell_state))
        output = self.fully_connected(hidden_output)
        return output, hidden_output, cell_state

    def index_to_character(self, index):
        return self.characters[index]

    def text_to_tensor(self, text):
        characters_indexes_list = [self.characters.index(character) for character in text]
        characters_indexes_tensor = torch.tensor(characters_indexes_list)
        return characters_indexes_tensor

class Trainer:
    def __init__(self):
        # input is text from:
        # https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt
        with open('input.txt', 'r') as file:
            lines = [line.strip() for line in file if len(line.strip()) > 0]
        print('got {} lines'.format(len(lines)))
        # group every six non-empty lines into one training sample
        self.samples = []
        sample = []
        for line in lines:
            sample.append(line)
            if len(sample) > 5:
                self.samples.append('\n'.join(sample))
                sample = []
        print('got {} samples'.format(len(self.samples)))
        learning_rate = 0.005
        self.loss_per_sample = []
        self.number_of_epochs = 2
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        print('using device:', self.device)
        self.model = CharacterRnn()
        self.model = self.model.to(device=self.device)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=learning_rate)
        self.loss_function = torch.nn.functional.cross_entropy

    def generate_text(self, generate_length=200, text_prefix='Th', temperature=0.8):
        # move the initial state and the prefix to the same device as the model
        hidden_output, cell_state = self.model.produce_start_state()
        hidden_output = hidden_output.to(self.device)
        cell_state = cell_state.to(self.device)
        prefix_inputs = self.model.text_to_tensor(text_prefix).to(self.device)
        # feed the prefix characters (all but the last) to warm up the hidden state
        for character_index in range(len(text_prefix) - 1):
            character = prefix_inputs[character_index]
            character_2_dimension = character.unsqueeze(0)
            outputs, hidden_output, cell_state = self.model(character_2_dimension, hidden_output, cell_state)
        generated_text = text_prefix
        character = prefix_inputs[len(prefix_inputs) - 1]
        character_2_dimension = character.unsqueeze(0)
        for character_index in range(generate_length):
            outputs, hidden_output, cell_state = self.model(character_2_dimension, hidden_output, cell_state)
            # scale the logits by the temperature and sample the next character index
            outputs_distribution = (outputs / temperature).exp()
            character_2_dimension = torch.multinomial(outputs_distribution, 1)[0]
            next_character = self.model.index_to_character(character_2_dimension)
            generated_text += next_character
        return generated_text

    def train_epoch(self, epoch_index):
        epoch_start_time = time.time()
        last_log_time = time.time()
        self.model.train()
        for sample_index, sample_text in enumerate(self.samples):
            self.optimizer.zero_grad()
            sample_loss = 0
            # the labels are the sample text shifted by one character
            text_tensor = self.model.text_to_tensor(sample_text).to(self.device)
            inputs = text_tensor[:-1]
            labels = text_tensor[1:]
            # move the initial state to the same device as the model
            hidden_output, cell_state = self.model.produce_start_state()
            hidden_output = hidden_output.to(self.device)
            cell_state = cell_state.to(self.device)
            for character_index in range(inputs.shape[0]):
                character = inputs[character_index]
                character_2_dimension = character.unsqueeze(0)
                label = labels[character_index]
                label_2_dimensions = label.view(1)
                outputs, hidden_output, cell_state = self.model(character_2_dimension, hidden_output, cell_state)
                sample_loss += self.loss_function(outputs, label_2_dimensions)
            self.loss_per_sample.append(sample_loss.item())
            sample_loss /= len(inputs)
            sample_loss.backward()
            self.optimizer.step()
            passed_seconds = time.time() - last_log_time
            if passed_seconds > 30:
                # every 30 seconds log the loss and a generated text sample
                self.model.eval()
                last_log_time = time.time()
                separator = '=' * 50
                print('sample {:05d}/{:05d} loss {:.5f}, sample text:\n{}\n{}\n{}\n'.format(
                    sample_index,
                    len(self.samples),
                    sample_loss.item(),
                    separator,
                    self.generate_text(),
                    separator,
                ))
                self.model.train()
        with torch.no_grad():
            epoch_seconds = time.time() - epoch_start_time
            print(
                'epoch {:05d}/{:05d}, duration {:03.0f} seconds'.format(
                    epoch_index + 1,
                    self.number_of_epochs,
                    epoch_seconds,
                ))

    def plot_loss_graph(self):
        plt.clf()
        plt.plot(self.loss_per_sample, color='y', label='per sample')
        plt.legend()
        plt.ylabel('Loss')
        plt.xlabel('Sample')
        plt.savefig("loss.pdf")

    def train(self):
        self.loss_per_sample = []
        start_time = time.time()
        for epoch_index in range(self.number_of_epochs):
            self.train_epoch(epoch_index)
            self.plot_loss_graph()
        passed_seconds = time.time() - start_time
        print('total training duration %03.0f seconds' % passed_seconds)
        self.model.eval()

def set_seed():
    random_seed = 123
    if torch.cuda.is_available():
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
        torch.use_deterministic_algorithms(True)
    torch.manual_seed(random_seed)
    torch.cuda.manual_seed(random_seed)


def main():
    set_seed()
    trainer = Trainer()
    trainer.train()
    # trainer.save()
    # trainer.load()


main()
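The save and load calls in main are commented out and are not shown in the post. If you want them, a possible sketch of such Trainer methods (hypothetical helpers, with checkpoint.pt as a placeholder file name) could be:

    # hypothetical helpers, not part of the original implementation
    def save(self, path='checkpoint.pt'):
        torch.save(self.model.state_dict(), path)

    def load(self, path='checkpoint.pt'):
        self.model.load_state_dict(torch.load(path, map_location=self.device))
        self.model.eval()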
An example of the training output is:
==================================================
sample 03005/05462 loss 1.56047, sample text:
==================================================
ou king'st thou lord, their out astere,
And words.
QUEEN MARGARET:
A'Gly lord,
And become me, and set I have gries speak;
But think thou strong to thee his stoe that any to stothing traim.
CLIFFORD:
A
==================================================
sample 03212/05462 loss 1.58164, sample text:
==================================================
e what love, march.
KING EDWARD IV:
A from the hope rease our my pelarrious as the come words, Where bride holven our scale, a styull ongers too ourself the gracty, contague thou with their from you b
==================================================
sample 03395/05462 loss 1.58987, sample text:
==================================================
at enquear us us flight his none.
KING HENRY VI:
Why, thine no be pistressed.
SOMERSET:
Is this bocked Naint woil.
KING EDWARD IV:
Farewell,
And my not his wood now in all this his sten be'r my countr
==================================================
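The samples above were generated with the default temperature of 0.8. The temperature divides the logits before sampling, so lower values make the distribution more peaked (more conservative text) while higher values make it flatter (more random text). A quick illustration with made-up logits:

import torch

logits = torch.tensor([[2.0, 1.0, 0.1]])  # made-up scores, just to show the effect
for temperature in (0.5, 1.0, 2.0):
    probabilities = torch.softmax(logits / temperature, dim=1)
    print(temperature, probabilities)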