Monday, November 18, 2024

Go ReadLine for Long Lines


 


In this post we present the correct method of reading lines from a long text in Go.


The naive method of reading lines in Go is to use bufio.Scanner:

scanner := bufio.NewScanner(file)
for scanner.Scan() {
	fmt.Println(scanner.Text())
}


This does not work for long lines: once a line exceeds the scanner's buffer (64 KB by default), scanning stops with bufio.ErrTooLong. Instead, we should use bufio.Reader. However, assembling long lines from its ReadLine fragments is cumbersome, and a KISS wrapper is missing from the Go standard library, hence I've created one myself.
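
To see the failure concretely, below is a minimal sketch (not from the original post) that feeds a single 100 KB line to bufio.Scanner; with the default buffer the scan stops and scanner.Err() reports bufio.ErrTooLong:

package main

import (
	"bufio"
	"fmt"
	"strings"
)

func main() {
	// a single line longer than bufio.MaxScanTokenSize (64 KiB)
	longLine := strings.Repeat("x", 100000)

	scanner := bufio.NewScanner(strings.NewReader(longLine))
	for scanner.Scan() {
		fmt.Println(len(scanner.Text()))
	}
	// prints: bufio.Scanner: token too long
	fmt.Println(scanner.Err())
}

The wrapper below hides the ReadLine fragment assembly behind a simple GetLine/IsEof API.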


import (
	"bufio"
	"io"
	"strings"
)

type LineReader struct {
	reader *bufio.Reader
	isEof  bool
}

func ProduceLineReader(
	text string,
) *LineReader {
	reader := bufio.NewReader(strings.NewReader(text))

	return &LineReader{
		reader: reader,
		isEof:  false,
	}
}

func (r *LineReader) GetLine() string {
	var longLineBuffer *strings.Builder
	multiLines := false
	for {
		line, isPrefix, err := r.reader.ReadLine()
		if err == io.EOF {
			r.isEof = true
			return ""
		}

		if isPrefix {
			multiLines = true
		}

		if !multiLines {
			// simple single line
			return string(line)
		}

		if longLineBuffer == nil {
			// create only if needed - better performance
			longLineBuffer = &strings.Builder{}
		}

		longLineBuffer.Write(line)
		if !isPrefix {
			// end of long line
			return longLineBuffer.String()
		}
	}
}

func (r *LineReader) IsEof() bool {
	return r.isEof
}


An example of usage is:

reader := kitstring.ProduceLineReader(text)
for {
	line := reader.GetLine()
	if reader.IsEof() {
		break
	}
	t.Logf("read line: %v", line)
}
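
If the text comes from a file, here is a minimal sketch (the file name big.txt is hypothetical) that loads the whole file into memory and iterates its lines:

// hypothetical example - read a whole file and iterate its lines
data, err := os.ReadFile("big.txt")
if err != nil {
	panic(err)
}

reader := kitstring.ProduceLineReader(string(data))
for {
	line := reader.GetLine()
	if reader.IsEof() {
		break
	}
	fmt.Println(line)
}

For files too large to hold in memory, the same approach applies, but the bufio.Reader would be constructed directly from the os.File instead of from a string.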



Monday, November 11, 2024

OpenAPI Schema In Go



 

In this post we present a utility to produce an OpenAPI file using Go code. This is useful when the application detects the schema at runtime and is required to supply an OpenAPI file describing the detected schema.



package openapi

import (
	"gopkg.in/yaml.v3"
	"strings"
)

const applicationJson = "application/json"

type InLocation int

const (
	InLocationQuery InLocation = iota + 1
	InLocationHeader
	InLocationCookie
	InLocationPath
)

type SchemaType int

const (
	SchemaTypeString SchemaType = iota + 1
	SchemaTypeInt
	SchemaTypeObject
	SchemaTypeArray
)

type Info struct {
	Title       string `yaml:"title,omitempty"`
	Description string `yaml:"description,omitempty"`
	Version     string `yaml:"version,omitempty"`
}

type Schema struct {
	Type       string             `yaml:"type,omitempty"`
	Enum       []string           `yaml:"enum,omitempty"`
	Properties map[string]*Schema `yaml:"properties,omitempty"`
	Items      *Schema            `yaml:"items,omitempty"`
}

type Parameter struct {
	In          string  `yaml:"in,omitempty"`
	Name        string  `yaml:"name,omitempty"`
	Description string  `yaml:"description,omitempty"`
	Required    bool    `yaml:"required,omitempty"`
	Schema      *Schema `yaml:"schema,omitempty"`
}

type Content map[string]*Schema

type RequestBody struct {
	Description string   `yaml:"description,omitempty"`
	Required    bool     `yaml:"required,omitempty"`
	Content     *Content `yaml:"content,omitempty"`
}

type Response struct {
	Description string   `yaml:"description,omitempty"`
	Content     *Content `yaml:"content,omitempty"`
}

type Method struct {
	Summary     string               `yaml:"summary,omitempty"`
	Description string               `yaml:"description,omitempty"`
	Deprecated  string               `yaml:"deprecated,omitempty"`
	Parameters  []*Parameter         `yaml:"parameters,omitempty"`
	RequestBody *RequestBody         `yaml:"requestBody,omitempty"`
	Responses   map[string]*Response `yaml:"responses,omitempty"`
}

type Path map[string]*Method

type OpenApi struct {
	OpenApi string           `yaml:"openapi,omitempty"`
	Info    *Info            `yaml:"info,omitempty"`
	Paths   map[string]*Path `yaml:"paths,omitempty"`
}

func produceSchema() *Schema {
	return &Schema{
		Properties: make(map[string]*Schema),
	}
}

func ProduceOpenApi() *OpenApi {
	return &OpenApi{
		OpenApi: "3.0.0",
		Paths:   make(map[string]*Path),
	}
}

func (o *OpenApi) CreateYamlBytes() []byte {
	bytes, err := yaml.Marshal(o)
	kiterr.RaiseIfError(err)
	return bytes
}

func (o *OpenApi) CreateYamlString() string {
	return string(o.CreateYamlBytes())
}

func (o *OpenApi) SetPath(
	path string,
) *Path {
	for pathUrl, pathObject := range o.Paths {
		if pathUrl == path {
			return pathObject
		}
	}
	pathObject := make(Path)
	o.Paths[path] = &pathObject
	return &pathObject
}

func (p *Path) SetMethod(
	method string,
) *Method {
	method = strings.ToLower(method)

	pathObject := *p
	existingMethod := pathObject[method]
	if existingMethod != nil {
		return existingMethod
	}
	methodObject := Method{
		Responses: make(map[string]*Response),
	}
	pathObject[method] = &methodObject
	return &methodObject
}

func (m *Method) SetParameter(
	name string,
) *Parameter {
	for _, parameter := range m.Parameters {
		if parameter.Name == name {
			return parameter
		}
	}

	parameter := Parameter{
		Name: name,
	}
	m.Parameters = append(m.Parameters, &parameter)
	return &parameter
}

func (p *Parameter) SetInLocation(
	in InLocation,
) *Parameter {
	switch in {
	case InLocationQuery:
		p.In = "query"
	case InLocationCookie:
		p.In = "cookie"
	case InLocationHeader:
		p.In = "header"
	case InLocationPath:
		p.In = "path"
	}
	return p
}

func (p *Parameter) SetSchema(
	schemaType SchemaType,
) *Parameter {
	schema := p.Schema
	if schema == nil {
		schema = produceSchema()
		p.Schema = schema
	}

	schema.SetType(schemaType)
	return p
}

func (s *Schema) SetType(
	schemaType SchemaType,
) *Schema {
	switch schemaType {
	case SchemaTypeString:
		s.Type = "string"
	case SchemaTypeInt:
		s.Type = "integer"
	case SchemaTypeObject:
		s.Type = "object"
	case SchemaTypeArray:
		s.Type = "array"
	}
	return s
}

func (s *Schema) SetProperty(
	name string,
	schemaType SchemaType,
) *Schema {
	property := s.Properties[name]
	if property == nil {
		property = produceSchema()
		s.Properties[name] = property
	}

	property.SetType(schemaType)
	return property
}

func (s *Schema) SetPropertyArray(
	name string,
) *Schema {
	array := s.SetProperty(name, SchemaTypeArray)
	array.Items = produceSchema()
	return array.Items
}

func (m *Method) SetRequestContent(
	contentType string,
) *Schema {
	body := m.RequestBody
	if body == nil {
		body = &RequestBody{}
		m.RequestBody = body
	}

	content := body.Content
	if content == nil {
		content = &Content{}
		body.Content = content
	}

	contentObject := *content
	schema := contentObject[contentType]
	if schema == nil {
		schema = produceSchema()
		contentObject[contentType] = schema
	}

	return schema
}

func (m *Method) SetContentApplicationJson() *Schema {
	return m.SetRequestContent(applicationJson)
}

func (m *Method) SetResponseContent(
	responseCode string,
	contentType string,
) *Schema {
	response := m.Responses[responseCode]
	if response == nil {
		response = &Response{}
		m.Responses[responseCode] = response
	}

	content := response.Content
	if content == nil {
		content = &Content{}
		response.Content = content
	}

	contentObject := *content
	schema := contentObject[contentType]
	if schema == nil {
		schema = produceSchema()
		contentObject[contentType] = schema
	}

	return schema
}

func (m *Method) SetResponseSuccessContentApplicationJson() *Schema {
	return m.SetResponseContent("200", applicationJson)
}



The sample usage below creates two endpoints: list-items and add-item.


api := openapi.ProduceOpenApi()

method := api.SetPath("/api/list-items").SetMethod("GET")
method.Description = "list all store items"

method.SetParameter("store-id").
SetInLocation(openapi.InLocationQuery).
SetSchema(openapi.SchemaTypeInt)

listStoreSchema := method.SetResponseSuccessContentApplicationJson()
existingItemSchema := listStoreSchema.SetPropertyArray("items")
existingItemSchema.SetProperty("id", openapi.SchemaTypeInt)
existingItemSchema.SetProperty("name", openapi.SchemaTypeString)
existingItemSchema.SetProperty("price", openapi.SchemaTypeInt)

method = api.SetPath("/api/add-item").SetMethod("POST")
method.Description = "add item to store"
addItemSchema := method.SetContentApplicationJson()
addItemSchema.SetType(openapi.SchemaTypeObject)
newItemSchema := addItemSchema.SetProperty("item", openapi.SchemaTypeObject)
newItemSchema.SetProperty("name", openapi.SchemaTypeString)
newItemSchema.SetProperty("price", openapi.SchemaTypeInt)

t.Log("Schema is:\n\n%v", api.CreateYamlString())


The resulting OpenAPI file is:


openapi: 3.0.0
paths:
  /api/add-item:
    post:
      description: add item to store
      requestBody:
        content:
          application/json:
            type: object
            properties:
              item:
                type: object
                properties:
                  name:
                    type: string
                  price:
                    type: integer
  /api/list-items:
    get:
      description: list all store items
      parameters:
        - in: query
          name: store-id
          schema:
            type: integer
      responses:
        "200":
          content:
            application/json:
              properties:
                items:
                  type: array
                  items:
                    properties:
                      id:
                        type: integer
                      name:
                        type: string
                      price:
                        type: integer



Tuesday, November 5, 2024

Streamlit


 


In this post we present a simple example of a Streamlit-based application. Streamlit is a framework providing simple and fast GUI development for internal use. The nice thing is that it is very simple: the coding is pure Python, without any JavaScript and without multiple processes. Notice that the GUI is not intended for the end user, but mostly for internal developer use, since due to its simplicity it is also limited.



The following code handles configuration for a long-running process, rerunning the process if and only if the configuration changes. Some screenshots are below:








The simple python code is below.


import random
import time
from datetime import datetime

import pandas as pd
import streamlit as st


def display_bars(amount):
    data = []
    for i in range(amount):
        data.append(['person {}'.format(i), random.randint(10, 50)])
    df = pd.DataFrame(data, columns=['Name', 'Age'])
    st.bar_chart(df, x="Name", y="Age")


@st.cache_data
def run_long_processing(compute_time):
    progress = st.empty()

    with progress.container():
        st.text('starting with {} times'.format(compute_time))
        for i in range(compute_time):
            time.sleep(1)
            st.text('iteration {}'.format(i))
        time.sleep(1)
    progress.empty()

    st.text('task with {} times is complete'.format(compute_time))
    display_bars(compute_time)


@st.fragment(run_every='1s')
def display_time():
    st.write(datetime.now())


def main():
    st.title('Demo Application')
    display_time()

    if st.button('clear cache'):
        st.cache_data.clear()

    tab1, tab2 = st.tabs(['config', 'results'])

    with tab1:
        st.session_state['compute_time'] = st.slider('select compute time', 1, 10, 3)
    with tab2:
        compute_time = st.session_state['compute_time']
        run_long_processing(compute_time)


main()




Tuesday, October 29, 2024

Grafana K6


 

In this post we review Grafana K6, a tool for load testing applications using JavaScript-like scripts.

We will use the following test file as our test:

test.js

import http from 'k6/http';
import {check, sleep} from 'k6';

export const options = {
    vus: 2,
    duration: '5s',
};
export default function () {
    const url = 'http://test.k6.io/login';
    const payload = JSON.stringify({
        email: 'aaa',
        password: 'bbb',
    });

    const params = {
        headers: {
            'Content-Type': 'application/json',
        },
    };

    const result = http.post(url, payload, params);
    check(result, {
        'is status 200': (r) => r.status === 200,
    });
    sleep(1);
}



We will review two methods of running the tool: Docker and Kubernetes.


Running in Docker

To run in Docker, we use the following Dockerfile:

Dockerfile

FROM grafana/k6
WORKDIR /usr/src/app
COPY . .
CMD [ "k6", "run", "test.js" ]


Next we build and run the Docker image:

#!/usr/bin/env bash

set -e
cd "$(dirname "$0")"

docker build . -t my-k6
docker run --rm my-k6 run test.js


The output is:

[+] Building 1.1s (8/8) FINISHED                                                                                                                                  docker:default
=> [internal] load build definition from Dockerfile 0.0s
=> => transferring dockerfile: 114B 0.0s
=> [internal] load metadata for docker.io/grafana/k6:latest 0.7s
=> [internal] load .dockerignore 0.0s
=> => transferring context: 2B 0.0s
=> [1/3] FROM docker.io/grafana/k6:latest@sha256:d39047ea6c5981ac0abacec2ea32389f22a7aa68bc8902c08b356cc5dd74aac9 0.0s
=> [internal] load build context 0.0s
=> => transferring context: 904B 0.0s
=> CACHED [2/3] WORKDIR /usr/src/app 0.0s
=> [3/3] COPY . . 0.1s
=> exporting to image 0.1s
=> => exporting layers 0.0s
=> => writing image sha256:dd14bff333a7714cf09ff4c60a6ae820174bb575404ed4c63acf871a27148878 0.0s
=> => naming to docker.io/library/my-k6 0.0s

/\ Grafana /‾‾/
/\ / \ |\ __ / /
/ \/ \ | |/ / / ‾‾\
/ \ | ( | (‾) |
/ __________ \ |_|\_\ \_____/

execution: local
script: test.js
output: -

scenarios: (100.00%) 1 scenario, 2 max VUs, 35s max duration (incl. graceful stop):
* default: 2 looping VUs for 5s (gracefulStop: 30s)


running (01.0s), 2/2 VUs, 0 complete and 0 interrupted iterations
default [ 20% ] 2 VUs 1.0s/5s

running (02.0s), 2/2 VUs, 0 complete and 0 interrupted iterations
default [ 40% ] 2 VUs 2.0s/5s

running (03.0s), 2/2 VUs, 2 complete and 0 interrupted iterations
default [ 60% ] 2 VUs 3.0s/5s

running (04.0s), 2/2 VUs, 4 complete and 0 interrupted iterations
default [ 80% ] 2 VUs 4.0s/5s

running (05.0s), 2/2 VUs, 6 complete and 0 interrupted iterations
default [ 100% ] 2 VUs 5s

running (06.0s), 2/2 VUs, 6 complete and 0 interrupted iterations
default ↓ [ 100% ] 2 VUs 5s

✗ is status 200
0% — ✓ 0 / ✗ 8

checks.........................: 0.00% 0 out of 8
data_received..................: 16 kB 2.6 kB/s
data_sent......................: 3.8 kB 605 B/s
http_req_blocked...............: avg=75.92ms min=4.21µs med=17.25µs max=384.4ms p(90)=302.09ms p(95)=379.63ms
http_req_connecting............: avg=176.26µs min=0s med=0s max=717.99µs p(90)=703.73µs p(95)=707.57µs
http_req_duration..............: avg=191.36ms min=150.99ms med=156.8ms max=354.78ms p(90)=252.73ms p(95)=278.68ms
{ expected_response:true }...: avg=202.91ms min=153.28ms med=167.17ms max=354.78ms p(90)=283.75ms p(95)=319.27ms
http_req_failed................: 50.00% 8 out of 16
http_req_receiving.............: avg=155.79µs min=77.92µs med=154.46µs max=292.35µs p(90)=237.78µs p(95)=269.18µs
http_req_sending...............: avg=147.67µs min=27.02µs med=81.32µs max=533.87µs p(90)=296.64µs p(95)=377.56µs
http_req_tls_handshaking.......: avg=47.53ms min=0s med=0s max=383.46ms p(90)=188.55ms p(95)=378.7ms
http_req_waiting...............: avg=191.06ms min=150.59ms med=156.55ms max=353.98ms p(90)=252.5ms p(95)=278.36ms
http_reqs......................: 16 2.55197/s
iteration_duration.............: avg=1.53s min=1.3s med=1.33s max=2.21s p(90)=2.1s p(95)=2.16s
iterations.....................: 8 1.275985/s
vus............................: 2 min=2 max=2
vus_max........................: 2 min=2 max=2


running (06.3s), 0/2 VUs, 8 complete and 0 interrupted iterations
default ✓ [ 100% ] 2 VUs 5s


Running in Kubernetes

To run in Kubernetes, we deploy the k6 operator, which creates jobs for test runs.


curl https://raw.githubusercontent.com/grafana/k6-operator/main/bundle.yaml | kubectl apply -f -


The outputs from the test run in this case are sent to Prometheus, so we create a dedicated image that handles the remote write to Prometheus.

Dockerfile

FROM golang:1.20 AS builder

RUN go install go.k6.io/xk6/cmd/xk6@latest

RUN xk6 build \
--with github.com/grafana/xk6-output-prometheus-remote@latest \
--output /k6

FROM grafana/k6:latest
COPY --from=builder /k6 /usr/bin/k6


Build the test runner image:

docker build -t k6-extended:local .
kind load docker-image k6-extended:local


We use the TestRun CRD to configure the test run.


k8s_testrun.yml

apiVersion: k6.io/v1alpha1
kind: TestRun
metadata:
  name: my-testrun
spec:
  cleanup: post
  parallelism: 2
  script:
    configMap:
      name: my-test
      file: test.js
  runner:
    image: k6-extended:local
    env:
      - name: K6_PROMETHEUS_RW_SERVER_URL
        value: http://prometheus/api/v1/write


And we run the test:


kubectl delete configmap my-test --ignore-not-found=true
kubectl create configmap my-test --from-file ./test.js
kubectl apply -f ./k8s_testrun.yml



The results can be viewed in Grafana using a predefined dashboard.




Monday, October 14, 2024

Split Load Using NATS Partitions



 

In this post we will review NATS partitioning and how to use it to split load among multiple pods.

We've already reviewed the steps to set up a NATS cluster in Kubernetes in this post. As part of the NATS StatefulSet template we have a ConfigMap, which is mounted into the NATS server container.



- name: nats
  args:
    - --config
    - /etc/nats-config/nats.conf


  volumeMounts:
    - mountPath: /etc/nats-config
      name: config


volumes:
  - configMap:
      name: nats-config



The ConfigMap is as follows:


apiVersion: v1
kind: ConfigMap
metadata:
  name: nats-config
data:
  nats.conf: |
    {
      "cluster": {
        "name": "nats",
        "no_advertise": true,
        "port": 6222,
        "routes": [
          "nats://nats-0.nats-headless:6222",
          "nats://nats-1.nats-headless:6222"
        ]
      },
      "http_port": 8222,
      "lame_duck_duration": "30s",
      "lame_duck_grace_period": "10s",
      "pid_file": "/var/run/nats/nats.pid",
      "port": 4222,
      "server_name": $SERVER_NAME,
      "mappings": {
        "application.*": "application.{{partition(3,1)}}.{{wildcard(1)}}"
      }
    }



The NATS partitioning is configured by the mappings section in the config. In this case the producer publishes messages to the NATS subject "application.<APPLICATION_ID>".
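
As an illustration only (not part of the original setup), here is a minimal producer sketch using the nats.go client; the URL assumes the in-cluster nats service, and applicationId and the payload are placeholders:

// hypothetical producer sketch
// import nats "github.com/nats-io/nats.go"
nc, err := nats.Connect("nats://nats:4222")
if err != nil {
	panic(err)
}
defer nc.Close()

applicationId := "42"
if err := nc.Publish("application."+applicationId, []byte("hello")); err != nil {
	panic(err)
}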

We want to split the load among 3 pods, and hence we configure the following:

"application.*": "application.{{partition(3,1)}}.{{wildcard(1)}}"


This provides the following instruction to NATS:

Take the value matched by the first wildcard of "application.*", and use it to deterministically select one of 3 partitions: 0, 1, or 2. Then republish the message on the subject "application.<PARTITION_ID>.<APPLICATION_ID>".


Now we can have 3 pods, each subscribing to the prefix "application.<PARTITION_ID>.*", and the messages are split accordingly. Notice that a specific APPLICATION_ID is always assigned to the same PARTITION_ID.
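
To illustrate the consumer side (again a sketch, not part of the original post), each pod subscribes to its own partition prefix; here the partition number is hard-coded to 1:

// hypothetical consumer sketch for partition 1
// import nats "github.com/nats-io/nats.go"
nc, err := nats.Connect("nats://nats:4222")
if err != nil {
	panic(err)
}
defer nc.Close()

_, err = nc.Subscribe("application.1.*", func(msg *nats.Msg) {
	fmt.Printf("partition 1 received %s on subject %s\n", msg.Data, msg.Subject)
})
if err != nil {
	panic(err)
}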

Also notice that this is a static assignment, regardless of the load of each application. In case of a highly unbalanced application load, this might not be a suitable method.






Wednesday, September 18, 2024

Visualize and Cluster Embedding Vectors

 


In this post we will present how to cluster embedding vectors, and how to visualize the results.

We start with simple random tensors that represent the embeddings, and plot them using t-SNE, which reduces the 768 features to a two-dimensional representation.

To make the t-SNE output look nice, we insert a predetermined variance into the data.


import random

import matplotlib.pyplot as plt
import numpy as np
import torch
from sklearn.cluster import KMeans
from sklearn.manifold import TSNE

random.seed(42)
torch.manual_seed(42)

embeddings = None
for i in range(100):
    embeddings_part = torch.rand(20, 768) + random.uniform(0, 0.5)
    if embeddings is None:
        embeddings = embeddings_part
    else:
        embeddings = torch.cat((embeddings, embeddings_part))

tsne = TSNE(2)
tsne_result = tsne.fit_transform(embeddings)
x, y = tsne_result[:, 0], tsne_result[:, 1]
plt.clf()

plt.scatter(x, y, s=2)
plt.legend()
plt.savefig('embeddings.pdf')


The plotted graph is:





Next, we use k-means to cluster the vectors into 10 clusters, and display the resulting clustering.


cluster_amount = 10
kmeans_model = KMeans(n_clusters=cluster_amount, random_state=0)
corpus_clusters = kmeans_model.fit_predict(embeddings).tolist()

plt.clf()

for cluster in range(cluster_amount):
    cluster_points = None
    for sample_index, sample_cluster in enumerate(corpus_clusters):
        if sample_cluster == cluster:
            tsne_point = tsne_result[sample_index]
            tsne_point = np.expand_dims(tsne_point, axis=0)
            if cluster_points is None:
                cluster_points = tsne_point
            else:
                cluster_points = np.concatenate((cluster_points, tsne_point))

    if cluster_points is not None:
        x = cluster_points[:, 0]
        y = cluster_points[:, 1]
        plt.scatter(x, y, s=2, label=cluster)

plt.legend()
plt.savefig("clustered.pdf")



And the plotted clustering is:






Thursday, September 12, 2024

Question Answered


 


In the previous post, Search On My Own, we located a relevant section in a book to match a question. In this post we extend this to actually answer the question using the relevant text extracted from the book.

This is done using a pre-trained model; no additional fine-tuning is required. The actual change is the addition of the QuestionAnswer class that uses the model.


Output example:


question: who broke in uncontrollable sobbings?
possible answer #0: Madame, the Marchioness of Schwedt
possible answer #1: princess
possible answer #2: Madam Sonsfeld
question: what is the weather?
possible answer #0: cold, gloomy, December
possible answer #1: long, cold, and dreary
possible answer #2: snow-tempests, sleet, frost
question: what is the useful knowledge?
possible answer #0: the study of philosophy, history, and languages
possible answer #1: most points
possible answer #2: Useful discourse
question: how many men did Fredrick command?
possible answer #0: twenty-four thousand
possible answer #1: thirty-five thousand
possible answer #2: four hundred



The full code is below.


import os.path
import pickle
import time
import urllib.request

import spacy
from sentence_transformers import SentenceTransformer, util, CrossEncoder
from transformers import pipeline


def get_book_text(book_url):
    print('loading book')
    book_file_path = 'book.txt'
    if not os.path.isfile(book_file_path):
        with urllib.request.urlopen(book_url) as response:
            data = response.read()
        with open(book_file_path, 'wb') as file:
            file.write(data)

    with open(book_file_path, 'r') as file:
        return file.read()


class TimedPrinter:
    def __init__(self):
        self.last_print = time.time()

    def print(self, message):
        passed = time.time() - self.last_print
        if passed > 5:
            self.last_print = time.time()
            print(message)


class TextSplitter:
    def __init__(self):
        self.items = []
        self.printer = TimedPrinter()
        self.nlp = spacy.load('en_core_web_sm')
        self.current_bulk = []
        self.current_bulk_spaces = 0
        self.min_words_in_section = 5

    def flush_bulk(self):
        if len(self.current_bulk) > 0:
            self.items.append('\n'.join(self.current_bulk))

        self.current_bulk = []
        self.current_bulk_spaces = 0

    def get_text_items(self, full_text, print_results=False):
        items_file_path = 'items.pkl'
        if not os.path.isfile(items_file_path):
            self.split_text(full_text)
            with open(items_file_path, 'wb') as file:
                pickle.dump(self.items, file)

        with open(items_file_path, 'rb') as file:
            print('loading embedding')
            self.items = pickle.load(file)

        if print_results:
            print('\n===\n'.join(self.items))

        print('final split size is {}'.format(len(self.items)))
        return self.items

    def split_text(self, full_text):
        print('breaking text')
        sections = full_text.split('\n\n')
        print('text length {}'.format(len(full_text)))
        print('text split to {} sections'.format(len(sections)))

        for section_index, section_text in enumerate(sections):
            self.printer.print('section {}/{}'.format(section_index + 1, len(sections)))
            self.scan_section(section_text)
        self.flush_bulk()

    def scan_section(self, section_text):
        section_text = section_text.strip()
        if section_text.count(' ') < self.min_words_in_section:
            return
        doc = self.nlp(section_text)

        for sentence in doc.sents:
            self.scan_sentence(sentence.text)

    def scan_sentence(self, sentence_text):
        sentence_text = sentence_text.strip()
        if len(sentence_text) == 0:
            return

        spaces = sentence_text.count(' ')
        if spaces + self.current_bulk_spaces > 128:
            self.flush_bulk()

        self.current_bulk.append(sentence_text)
        self.current_bulk_spaces += spaces


class SemanticSearch:
    def __init__(self, corpus):
        self.embedder = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1')
        self.embedder.max_seq_length = 256
        self.corpus = corpus

        embeddings_file_path = 'embeddings.pkl'
        if not os.path.isfile(embeddings_file_path):
            print('embedding corpus')

            corpus_embeddings = self.embedder.encode(corpus, convert_to_tensor=True, show_progress_bar=True)

            with open(embeddings_file_path, 'wb') as file:
                pickle.dump(corpus_embeddings, file)

        with open(embeddings_file_path, 'rb') as file:
            print('loading embedding')
            self.embeddings = pickle.load(file)

    def query(self, query_text, print_results=False, top_k=100):
        query_embedding = self.embedder.encode(query_text, convert_to_tensor=True)
        hits = util.semantic_search(query_embedding, self.embeddings, top_k=top_k)
        hits = hits[0]
        results = []
        for hit_index, hit in enumerate(hits):
            result_text = self.corpus[hit['corpus_id']]
            results.append(result_text)
            if print_results:
                print('=' * 100)
                print('Query: {}'.format(query_text))
                print('Search hit {} score {}'.format(hit_index, hit['score']))
                print(result_text)

        return results


class RelevantCheck:
    def __init__(self):
        self.cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')

    def rank(self, query, texts, top_k=3, print_results=False):
        encoder_input = [[query, text] for text in texts]
        scores = self.cross_encoder.predict(encoder_input)
        texts_scores = []
        for i in range(len(texts)):
            text = texts[i]
            score = scores[i]
            item = text, score
            texts_scores.append(item)

        texts_scores = sorted(texts_scores, key=lambda x: x[1], reverse=True)
        if len(texts_scores) > top_k:
            texts_scores = texts_scores[:top_k]

        results = []
        for i, text_score in enumerate(texts_scores):
            text, score = text_score
            results.append(text)
            if print_results:
                print('=' * 100)
                print('Query: {}'.format(query))
                print('result {} related score {}'.format(i, score))
                print(text)

        return results


class QuestionAnswer:
    def __init__(self):
        model_name = "deepset/roberta-base-squad2"
        self.pipeline = pipeline('question-answering', model=model_name, tokenizer=model_name)

    def answer(self, question, context):
        question_input = {
            'question': question,
            'context': context,
        }
        answer_info = self.pipeline(question_input)
        answer = answer_info['answer']
        answer = answer.replace('\n', ' ')
        answer = answer.replace('  ', ' ')
        return answer


def main():
    print('starting')

    book_url = 'https://www.gutenberg.org/cache/epub/56928/pg56928.txt'
    text = get_book_text(book_url)
    # we don't want answers from the table of contents
    text_after_toc = text[text.index('\nCHAPTER I\n'):]

    corpus = TextSplitter().get_text_items(text_after_toc)
    search = SemanticSearch(corpus)
    relevant = RelevantCheck()
    question_answer = QuestionAnswer()

    queries = [
        'who broke in uncontrollable sobbings?',
        'what is the weather?',
        'what is the useful knowledge?',
        'how many men did Fredrick command?',
    ]

    for query in queries:
        all_results = search.query(query, print_results=False)
        relevant_results = relevant.rank(query, all_results, print_results=False)
        print('question: {}'.format(query))
        for result_index, result in enumerate(relevant_results):
            answer = question_answer.answer(query, result)
            print('possible answer #{}: {}'.format(result_index, answer))


main()