
Thursday, September 12, 2024

Question Answered

In the previous post, Search On My Own, we located relevant sections in a book that match a question. In this post we extend that to actually answer the question, using the relevant text extracted from the book.

This is done using a pre-trained model; no additional fine-tuning is required. The only change is the addition of a QuestionAnswer class that uses the model.
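
At its core, the new step is Hugging Face's question-answering pipeline applied to each candidate passage: given a question and a context, the model extracts an answer span from the context rather than generating free text. Below is a minimal sketch of that step on its own, using the same deepset/roberta-base-squad2 model as the full program; the question and context strings here are made up for illustration.

from transformers import pipeline

model_name = 'deepset/roberta-base-squad2'
qa = pipeline('question-answering', model=model_name, tokenizer=model_name)

# hypothetical context, standing in for a passage retrieved from the book
context = ('The king reviewed his troops at dawn; '
           'twenty-four thousand men stood waiting in the frost.')

result = qa(question='how many men did the king command?', context=context)
print(result['answer'])  # a span copied from the context, e.g. 'twenty-four thousand'
print(result['score'])   # the model's confidence in that span

Because the model can only point at text that already appears in the context, the retrieval and re-ranking steps from the previous post determine how good the answers can possibly be.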


Output example:


question: who broke in uncontrollable sobbings?
possible answer #0: Madame, the Marchioness of Schwedt
possible answer #1: princess
possible answer #2: Madam Sonsfeld
question: what is the weather?
possible answer #0: cold, gloomy, December
possible answer #1: long, cold, and dreary
possible answer #2: snow-tempests, sleet, frost
question: what is the useful knowledge?
possible answer #0: the study of philosophy, history, and languages
possible answer #1: most points
possible answer #2: Useful discourse
question: how many men did Fredrick command?
possible answer #0: twenty-four thousand
possible answer #1: thirty-five thousand
possible answer #2: four hundred



The full code is below.


import os.path
import pickle
import time
import urllib.request

import spacy
from sentence_transformers import SentenceTransformer, util, CrossEncoder
from transformers import pipeline


def get_book_text(book_url):
    # download the book once and cache it locally as book.txt
    print('loading book')
    book_file_path = 'book.txt'
    if not os.path.isfile(book_file_path):
        with urllib.request.urlopen(book_url) as response:
            data = response.read()
        with open(book_file_path, 'wb') as file:
            file.write(data)

    with open(book_file_path, 'r') as file:
        return file.read()


class TimedPrinter:
    # prints progress messages, but at most once every 5 seconds
    def __init__(self):
        self.last_print = time.time()

    def print(self, message):
        passed = time.time() - self.last_print
        if passed > 5:
            self.last_print = time.time()
            print(message)


class TextSplitter:
    # splits the book into chunks of roughly 128 words each, keeping whole
    # sentences together, and caches the result in items.pkl
    def __init__(self):
        self.items = []
        self.printer = TimedPrinter()
        self.nlp = spacy.load('en_core_web_sm')
        self.current_bulk = []
        self.current_bulk_spaces = 0
        self.min_words_in_section = 5

    def flush_bulk(self):
        if len(self.current_bulk) > 0:
            self.items.append('\n'.join(self.current_bulk))

        self.current_bulk = []
        self.current_bulk_spaces = 0

    def get_text_items(self, full_text, print_results=False):
        items_file_path = 'items.pkl'
        if not os.path.isfile(items_file_path):
            self.split_text(full_text)
            with open(items_file_path, 'wb') as file:
                pickle.dump(self.items, file)

        with open(items_file_path, 'rb') as file:
            print('loading split items')
            self.items = pickle.load(file)

        if print_results:
            print('\n===\n'.join(self.items))

        print('final split size is {}'.format(len(self.items)))
        return self.items

    def split_text(self, full_text):
        print('breaking text')
        sections = full_text.split('\n\n')
        print('text length {}'.format(len(full_text)))
        print('text split to {} sections'.format(len(sections)))

        for section_index, section_text in enumerate(sections):
            self.printer.print('section {}/{}'.format(section_index + 1, len(sections)))
            self.scan_section(section_text)
        self.flush_bulk()

    def scan_section(self, section_text):
        section_text = section_text.strip()
        if section_text.count(' ') < self.min_words_in_section:
            return
        doc = self.nlp(section_text)

        for sentence in doc.sents:
            self.scan_sentence(sentence.text)

    def scan_sentence(self, sentence_text):
        sentence_text = sentence_text.strip()
        if len(sentence_text) == 0:
            return

        spaces = sentence_text.count(' ')
        if spaces + self.current_bulk_spaces > 128:
            self.flush_bulk()

        self.current_bulk.append(sentence_text)
        self.current_bulk_spaces += spaces


class SemanticSearch:
    # embeds the corpus with a sentence-transformers bi-encoder and retrieves
    # the top_k most similar chunks for a query; embeddings are cached in embeddings.pkl
    def __init__(self, corpus):
        self.embedder = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1')
        self.embedder.max_seq_length = 256
        self.corpus = corpus

        embeddings_file_path = 'embeddings.pkl'
        if not os.path.isfile(embeddings_file_path):
            print('embedding corpus')

            corpus_embeddings = self.embedder.encode(corpus, convert_to_tensor=True, show_progress_bar=True)

            with open(embeddings_file_path, 'wb') as file:
                pickle.dump(corpus_embeddings, file)

        with open(embeddings_file_path, 'rb') as file:
            print('loading embedding')
            self.embeddings = pickle.load(file)

    def query(self, query_text, print_results=False, top_k=100):
        query_embedding = self.embedder.encode(query_text, convert_to_tensor=True)
        hits = util.semantic_search(query_embedding, self.embeddings, top_k=top_k)
        hits = hits[0]
        results = []
        for hit_index, hit in enumerate(hits):
            result_text = self.corpus[hit['corpus_id']]
            results.append(result_text)
            if print_results:
                print('=' * 100)
                print('Query: {}'.format(query_text))
                print('Search hit {} score {}'.format(hit_index, hit['score']))
                print(result_text)

        return results


class RelevantCheck:
    # re-ranks the retrieved chunks with a cross-encoder and keeps the top_k most relevant
    def __init__(self):
        self.cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')

    def rank(self, query, texts, top_k=3, print_results=False):
        encoder_input = [[query, text] for text in texts]
        scores = self.cross_encoder.predict(encoder_input)
        texts_scores = []
        for i in range(len(texts)):
            text = texts[i]
            score = scores[i]
            item = text, score
            texts_scores.append(item)

        texts_scores = sorted(texts_scores, key=lambda x: x[1], reverse=True)
        if len(texts_scores) > top_k:
            texts_scores = texts_scores[:top_k]

        results = []
        for i, text_score in enumerate(texts_scores):
            text, score = text_score
            results.append(text)
            if print_results:
                print('=' * 100)
                print('Query: {}'.format(query))
                print('result {} related score {}'.format(i, score))
                print(text)

        return results


class QuestionAnswer:
    # extracts an answer span from a context passage using an extractive QA model
    def __init__(self):
        model_name = "deepset/roberta-base-squad2"
        self.pipeline = pipeline('question-answering', model=model_name, tokenizer=model_name)

    def answer(self, question, context):
        question_input = {
            'question': question,
            'context': context,
        }
        answer_info = self.pipeline(question_input)
        answer = answer_info['answer']
        # normalize whitespace in the extracted span
        answer = answer.replace('\n', ' ')
        answer = answer.replace('  ', ' ')
        return answer


def main():
    print('starting')

    book_url = 'https://www.gutenberg.org/cache/epub/56928/pg56928.txt'
    text = get_book_text(book_url)
    # we don't want answers from the table of contents
    text_after_toc = text[text.index('\nCHAPTER I\n'):]

    corpus = TextSplitter().get_text_items(text_after_toc)
    search = SemanticSearch(corpus)
    relevant = RelevantCheck()
    question_answer = QuestionAnswer()

    queries = [
        'who broke in uncontrollable sobbings?',
        'what is the weather?',
        'what is the useful knowledge?',
        'how many men did Fredrick command?',
    ]

    for query in queries:
        all_results = search.query(query, print_results=False)
        relevant_results = relevant.rank(query, all_results, print_results=False)
        print('question: {}'.format(query))
        for result_index, result in enumerate(relevant_results):
            answer = question_answer.answer(query, result)
            print('possible answer #{}: {}'.format(result_index, answer))


main()

