# awinml's picture
# Upload 2 files
# a64e1a1
# raw
# history blame
# No virus
# 4.39 kB
import streamlit as st
import pandas as pd
import pandas as pd
from tqdm import tqdm
import torch
from sentence_transformers import SentenceTransformer
from transformers import (
pipeline,
AutoTokenizer,
AutoModelForCausalLM,
AutoModelForSeq2SeqLM,
)
import streamlit_scrollable_textbox as stx
@st.experimental_singleton
def get_data():
    """Load the sentence-level earnings-call table from disk.

    Reads ``earnings_calls_sentencewise.csv`` (path relative to the current
    working directory) with ``pandas.read_csv``. The function is wrapped in
    ``st.experimental_singleton`` so Streamlit can reuse the loaded frame
    instead of re-reading the file.

    Returns:
        The parsed CSV as a ``pandas.DataFrame``.
    """
    return pd.read_csv("earnings_calls_sentencewise.csv")
# Initialize models from HuggingFace
@st.experimental_singleton
def get_t5_model():
    """Build the ``t5-small`` summarization pipeline.

    The same checkpoint name is used for both the model and its tokenizer.

    Returns:
        The object returned by ``transformers.pipeline("summarization", ...)``.
    """
    checkpoint = "t5-small"
    return pipeline("summarization", model=checkpoint, tokenizer=checkpoint)
@st.experimental_singleton
def get_flan_t5_model():
    """Build the ``google/flan-t5-small`` summarization pipeline.

    The same checkpoint name is used for both the model and its tokenizer.

    Returns:
        The object returned by ``transformers.pipeline("summarization", ...)``.
    """
    checkpoint = "google/flan-t5-small"
    return pipeline("summarization", model=checkpoint, tokenizer=checkpoint)
@st.experimental_singleton
def get_mpnet_embedding_model():
    """Load the ``all-mpnet-base-v2`` sentence-embedding model.

    The model is placed on CUDA when ``torch.cuda.is_available()`` reports a
    GPU and on the CPU otherwise. Its ``max_seq_length`` is set to 512.

    Returns:
        The configured ``SentenceTransformer`` instance.
    """
    encoder = SentenceTransformer(
        "sentence-transformers/all-mpnet-base-v2",
        device="cuda" if torch.cuda.is_available() else "cpu",
    )
    encoder.max_seq_length = 512
    return encoder
@st.experimental_singleton
def get_sgpt_embedding_model():
    """Load the ``SGPT-125M-weightedmean-nli-bitfit`` sentence-embedding model.

    The model is placed on CUDA when ``torch.cuda.is_available()`` reports a
    GPU and on the CPU otherwise. Its ``max_seq_length`` is set to 512.

    Returns:
        The configured ``SentenceTransformer`` instance.
    """
    if torch.cuda.is_available():
        target_device = "cuda"
    else:
        target_device = "cpu"
    encoder = SentenceTransformer(
        "Muennighoff/SGPT-125M-weightedmean-nli-bitfit", device=target_device
    )
    encoder.max_seq_length = 512
    return encoder
@st.experimental_memo
def save_key(api_key):
    """Return ``api_key`` unchanged.

    The function body is an identity; any retention of the value comes from
    the ``st.experimental_memo`` decorator applied to it.

    Args:
        api_key: The key value to pass through.

    Returns:
        The same ``api_key`` object that was passed in.
    """
    return api_key
def query_pinecone(query, top_k, model, index, year, quarter, ticker, threshold=0.5):
    """Embed ``query``, search ``index`` and drop low-scoring matches.

    Args:
        query: Text to embed and search for.
        top_k: Number of matches requested from the index.
        model: Object whose ``encode([query])`` result exposes ``tolist()``.
        index: Object exposing ``query(vector, top_k=..., filter=...,
            include_metadata=...)`` and returning a mapping with a
            ``"matches"`` entry.
        year: Year to filter on; converted with ``int``.
        quarter: Quarter value matched with ``$eq``.
        ticker: Ticker value matched with ``$eq``.
        threshold: Minimum ``score`` a match needs to be kept (inclusive).

    Returns:
        The index response, with ``"matches"`` replaced in place by the
        matches whose score is at least ``threshold``.
    """
    # Embed the query text.
    query_vector = model.encode([query]).tolist()

    # Restrict the search to a single year / quarter / ticker.
    metadata_filter = {
        "Year": int(year),
        "Quarter": {"$eq": quarter},
        "Ticker": {"$eq": ticker},
    }
    response = index.query(
        query_vector,
        top_k=top_k,
        filter=metadata_filter,
        include_metadata=True,
    )

    # Keep only the passages that clear the score threshold.
    response["matches"] = [
        hit for hit in response["matches"] if hit["score"] >= threshold
    ]
    return response
def format_query(query_results):
    """Collect the passage text of every match in a search result.

    Args:
        query_results: Mapping with a ``"matches"`` list; each match carries
            a ``"metadata"`` mapping with a ``"Text"`` entry.

    Returns:
        A list with the ``Text`` of each match, in match order.
    """
    passages = []
    for match in query_results["matches"]:
        passages.append(match["metadata"]["Text"])
    return passages
def sentence_id_combine(data, query_results, lag=2):
    """Expand each matched sentence into a window of neighbouring sentences.

    For every match, the sentence positions ``Sentence_id - lag`` through
    ``Sentence_id + lag`` are collected. The union of those positions is
    de-duplicated, sorted, restricted to rows that exist in ``data`` and then
    cut into consecutive groups of ``lag * 2 + 1`` positions. The text of each
    group is joined with ``". "``.

    Positions outside ``[0, len(data))`` are dropped. Without this, a negative
    position would be read by ``.iloc`` as counting from the end of the frame
    (pulling in unrelated sentences), and a position past the last row would
    raise ``IndexError``.

    Args:
        data: DataFrame with a ``Text`` column, indexed positionally by
            sentence id.
        query_results: Mapping with a ``"matches"`` list; each match carries
            ``metadata["Sentence_id"]``.
        lag: Number of sentences to include on each side of a match.

    Returns:
        A list of context strings, one per group of positions.
    """
    window = lag * 2 + 1
    row_count = len(data)

    # int(): positions are used with .iloc, which needs integers.
    # NOTE(review): Pinecone may hand numeric metadata back as floats - confirm.
    matched_ids = [
        int(match["metadata"]["Sentence_id"]) for match in query_results["matches"]
    ]

    # Every position within `lag` of a match, without duplicates.
    candidate_ids = {
        sentence_id + offset
        for sentence_id in matched_ids
        for offset in range(-lag, lag + 1)
    }

    # Keep only positions that actually exist in the frame.
    valid_ids = sorted(pos for pos in candidate_ids if 0 <= pos < row_count)

    # Cut the sorted positions into consecutive groups of `window` entries.
    lookup_ids = [
        valid_ids[start : start + window]
        for start in range(0, len(valid_ids), window)
    ]

    return [". ".join(data.Text.iloc[group].to_list()) for group in lookup_ids]
def text_lookup(data, sentence_ids):
    """Join the entries of ``data`` at the given positions into one string.

    Args:
        data: Object supporting ``.iloc[...]`` whose selection exposes
            ``to_list()`` and yields strings.
            NOTE(review): presumably a pandas Series of sentences - confirm
            against callers.
        sentence_ids: Positional indices to select.

    Returns:
        The selected entries joined with ``". "``.
    """
    selected = data.iloc[sentence_ids]
    sentences = selected.to_list()
    return ". ".join(sentences)
def gpt3(query, result):
    """Ask the ``text-davinci-003`` completion model to answer ``query``.

    The prompt embeds ``result`` as context and asks for a detailed answer
    formatted in points.

    Args:
        query: The question to answer.
        result: Context passage(s) interpolated into the prompt.

    Returns:
        The ``text`` of the first completion choice.
    """
    # The module's import block does not import `openai`, so the name was
    # unresolved at call time (NameError). Import it here, where it is used.
    import openai

    # NOTE(review): the prompt contains literal double-quote characters at the
    # line boundaries. It is kept byte-for-byte because changing it would
    # change what is sent to the model.
    # NOTE(review): `openai.Completion` is the legacy completions interface -
    # confirm it matches the installed openai package version.
    response = openai.Completion.create(
        model="text-davinci-003",
        prompt=f"""Context information is below. \n"
"---------------------\n"
"{result}"
"\n---------------------\n"
"Given the context information and prior knowledge, answer this question: {query}. \n"
"Try to include as many key details as possible and format the answer in points. \n" """,
        temperature=0.1,
        max_tokens=512,
        top_p=1.0,
        frequency_penalty=0.0,
        presence_penalty=1,
    )
    return response.choices[0].text
# Transcript Retrieval
def retrieve_transcript(data, year, quarter, ticker):
    """Read the transcript file for one earnings call.

    The first distinct ``(Year, Month, Date, Ticker)`` row of ``data`` that
    matches the given year, quarter and ticker is used to build the file name
    ``Transcripts/<ticker>/<Year>-<Month>-<Date>-<Ticker>.txt``.

    Args:
        data: DataFrame with ``Year``, ``Quarter``, ``Ticker``, ``Month`` and
            ``Date`` columns.
        year: Year to match; converted with ``int``.
        quarter: Quarter value to match.
        ticker: Ticker value to match; also the transcript sub-directory.

    Returns:
        The full text of the transcript file.

    Raises:
        IndexError: If no row of ``data`` matches the filters.
        OSError: If the transcript file cannot be opened.
    """
    matches_call = (
        (data.Year == int(year))
        & (data.Quarter == quarter)
        & (data.Ticker == ticker)
    )
    row = (
        data.loc[matches_call, ["Year", "Month", "Date", "Ticker"]]
        .drop_duplicates()
        .iloc[0]
    )

    # Convert the row to strings and join the values with "-".
    row_str = "-".join(row.astype(str)) + ".txt"

    # Use a context manager so the file handle is closed after reading;
    # the previous version left it open.
    with open(f"Transcripts/{ticker}/{row_str}", "r") as transcript_file:
        return transcript_file.read()