Spaces:
Sleeping
Sleeping
import re | |
from ast import literal_eval | |
from nltk.stem import PorterStemmer, WordNetLemmatizer | |
# Entity extraction: pull company names and the quarter/year range out of the query.
def generate_ner_docs_prompt(query):
    """Build the few-shot prompt that asks for companies and a time range.

    Args:
        query: The user's question; it is inserted after the worked examples.

    Returns:
        The prompt string: the instruction and examples, one blank line,
        then the ``###Input:`` line and the ``ASSISTANT:`` cue.
    """
    # Instruction plus worked examples, kept as a single literal so the text
    # sent to the model is exactly what is written here.
    instructions = """USER: Extract the company names and time duration mentioned in the question. The entities should be extracted in the following format: {"companies": list of companies mentioned in the question,"start-duration": ("start-quarter", "start-year"), "end-duration": ("end-quarter", "end-year")}. Return {"companies": None, "start-duration": (None, None), "end-duration": (None, None)} if the entities are not found.
Examples:
What is Intel's update on the server chip roadmap and strategy for Q1 2019?
{"companies": ["Intel"], "start-duration": ("Q1", "2019"), "end-duration": ("Q1", "2019")}
What are the opportunities and challenges in the Indian market for Amazon in 2016?
{"companies": ["Amazon"], "start-duration": ("Q1", "2016"), "end-duration": ("Q4", "2016")}
What did analysts ask about the Cisco's Webex?
{"companies": ["Cisco"], "start-duration": (None, None), "end-duration": (None, None)}
What is the comparative performance analysis between Intel and AMD in key overlapping segments such as PC, Gaming, and Data Centers in Q2 to Q3 2018?
{"companies": ["Intel", "AMD"], "start-duration": ("Q2", "2018"), "end-duration": ("Q3", "2018")}
How did Microsoft and Amazon perform in terms of reliability and scalability of cloud for the years 2016 and 2017?
{"companies": ["Microsoft", "Amazon"], "start-duration": ("Q1", "2016"), "end-duration": ("Q4", "2017")}"""
    # The question comes last so the completion starts right after "ASSISTANT:".
    return f"{instructions}\n\n###Input: {query}\nASSISTANT:"
def extract_entities_docs(query, model):
    """Ask the model which companies and time range ``query`` mentions.

    The model reply is parsed with ``literal_eval`` and is expected to be a
    dict literal of the form:
    {"companies": list of companies mentioned in the question,"start-duration": ("start-quarter", "start-year"), "end-duration": ("end-quarter", "end-year")}

    Args:
        query: The user's question.
        model: Object exposing ``predict(prompt, api_name=...)`` that returns
            the reply as a string.

    Returns:
        The tuple ``(companies, start_quarter, start_year, end_quarter,
        end_year)`` taken from the parsed reply. The same tuple is printed.
    """
    reply = model.predict(generate_ner_docs_prompt(query), api_name="/predict")
    parsed = literal_eval(reply)
    start_quarter, start_year = parsed["start-duration"]
    end_quarter, end_year = parsed["end-duration"]
    extracted = (
        parsed["companies"],
        start_quarter,
        start_year,
        end_quarter,
        end_year,
    )
    print(extracted)
    return extracted
def year_quarter_range(start_quarter, start_year, end_quarter, end_year):
    """List every (quarter, year) pair from the start to the end, inclusive.

    Example:
        year_quarter_range("Q2", "2020", "Q3", "2021")
        [('Q2', '2020'), ('Q3', '2020'), ('Q4', '2020'), ('Q1', '2021'), ('Q2', '2021'), ('Q3', '2021')]

    Args:
        start_quarter: One of "Q1".."Q4", or None.
        start_year: Year accepted by ``int()``, or None.
        end_quarter: One of "Q1".."Q4", or None.
        end_year: Year accepted by ``int()``, or None.

    Returns:
        A list of ``(quarter, year)`` tuples with the year as a string.
        Empty if any argument is None or the end precedes the start.

    Raises:
        ValueError: If a quarter is not one of "Q1".."Q4" or a year cannot
            be converted by ``int()``.
    """
    bounds = (start_quarter, start_year, end_quarter, end_year)
    if any(bound is None for bound in bounds):
        return []

    quarter_names = ["Q1", "Q2", "Q3", "Q4"]
    per_year = len(quarter_names)
    first_quarter = quarter_names.index(start_quarter)
    last_quarter = quarter_names.index(end_quarter)

    # Number quarters consecutively across years so the whole span becomes a
    # single contiguous integer range.
    first = int(start_year) * per_year + first_quarter
    last = int(end_year) * per_year + last_quarter
    return [
        (quarter_names[position % per_year], str(position // per_year))
        for position in range(first, last + 1)
    ]
def clean_companies(company_list):
    """Map company names or ticker symbols to upper-case ticker symbols.

    Matching is case-insensitive and ignores surrounding whitespace. Entries
    that are neither a known company name nor a known ticker are dropped.

    Args:
        company_list: Iterable of company names and/or tickers, or None.
            None is accepted because the NER prompt tells the model to
            return ``"companies": None`` when no company is found.

    Returns:
        A list of tickers in input order (duplicates kept); empty when
        ``company_list`` is None or empty.
    """
    if not company_list:
        return []

    company_ticker_map = {
        "apple": "AAPL",
        "amd": "AMD",
        "amazon": "AMZN",
        "cisco": "CSCO",
        "google": "GOOGL",
        "microsoft": "MSFT",
        "nvidia": "NVDA",
        "asml": "ASML",
        "intel": "INTC",
        "micron": "MU",
    }
    # Lower-cased tickers, so a ticker given directly ("msft", "MSFT") is
    # recognised with the same normalised key used for company names.
    known_tickers = {ticker.lower() for ticker in company_ticker_map.values()}

    ticker_list = []
    for company in company_list:
        key = company.strip().lower()
        if key in company_ticker_map:
            ticker_list.append(company_ticker_map[key])
        elif key in known_tickers:
            ticker_list.append(key.upper())
    return ticker_list
def ticker_year_quarter_tuples_creator(ticker_list, year_quarter_range_list):
    """Pair every ticker with every (quarter, year) in the range.

    Args:
        ticker_list: Iterable of ticker symbols.
        year_quarter_range_list: Iterable of ``(quarter, year)`` pairs.

    Returns:
        A list of ``(ticker, quarter, year)`` tuples, grouped by ticker in
        input order; empty if either input is empty.
    """
    return [
        (ticker, quarter, year)
        for ticker in ticker_list
        for quarter, year in year_quarter_range_list
    ]
# Keyword extraction: pull the topic keywords out of the query and expand them.
def generate_ner_keywords_prompt(query):
    """Build the few-shot prompt that asks for the question's topic keywords.

    Args:
        query: The user's question; it is inserted after the worked examples.

    Returns:
        The prompt string: the instruction and examples followed directly by
        the ``###Input:`` line and the ``ASSISTANT:`` cue.
    """
    # Instruction plus worked examples, kept as a single literal so the text
    # sent to the model is exactly what is written here.
    instructions = """USER: Extract the entities which describe the key theme and topics being asked in the question. Extract the entities in the following format: {"entities":["keywords"]}.
Examples:
What is Intel's update on the server chip roadmap and strategy for Q1 2019?
{"entities":["server"]}
What are the opportunities and challenges in the Indian market for Amazon from Q1 to Q3 in 2016?
{"entities":["indian"]}
What is the comparative performance analysis between Intel and AMD in key overlapping segments such as PC, Gaming, and Data Centers in Q1 2016?
{"entities":["PC","Gaming","Data Centers"]}
What was Google's and Microsoft's capex spend for the last 2 years?
{"entities":["capex"]}
What did analysts ask about the cloud during Microsoft's earnings call in Q1 2018?
{"entities":["cloud"]}
What was the growth in Apple services revenue for 2017 Q3?
{"entities":["services"]}"""
    # Only a single newline separates the examples from the input here.
    return f"{instructions}\n###Input: {query}\nASSISTANT:"
def extract_entities_keywords(query, model):
    """Ask the model for the topic keywords of ``query``.

    The model reply is parsed with ``literal_eval`` and is expected to be a
    dict literal of the form:
    {"entities":["keywords"]}

    Args:
        query: The user's question.
        model: Object exposing ``predict(prompt, api_name=...)`` that returns
            the reply as a string.

    Returns:
        The value stored under ``"entities"`` in the parsed reply.
    """
    reply = model.predict(
        generate_ner_keywords_prompt(query), api_name="/predict"
    )
    return literal_eval(reply)["entities"]
def expand_list_of_lists(list_of_lists):
    """Flatten one level of nesting.

    Args:
        list_of_lists: A list of lists of strings.

    Returns:
        A single list holding every inner element, in original order.
    """
    return [item for inner in list_of_lists for item in inner]
def all_keywords_combs(list_of_cleaned_keywords):
    """Expand keywords with their lower-cased, stemmed and lemmatized forms.

    The input is left unmodified (the previous implementation extended the
    caller's list in place as a side effect).

    Args:
        list_of_cleaned_keywords: Iterable of keyword strings.

    Returns:
        A sorted list without duplicates containing each original keyword
        plus its ``str.lower()``, ``PorterStemmer.stem`` and
        ``WordNetLemmatizer.lemmatize`` variants. Sorting keeps the result
        stable between runs, unlike bare set iteration order.
    """
    # Copy first so generators are consumed once and the caller's list is
    # never mutated.
    keywords = list(list_of_cleaned_keywords)
    stemmer = PorterStemmer()
    lemmatizer = WordNetLemmatizer()

    variants = set(keywords)
    for text in keywords:
        variants.add(text.lower())
        variants.add(stemmer.stem(text))
        variants.add(lemmatizer.lemmatize(text))
    return sorted(variants)
def create_incorrect_entities_list():
    """Return the words that must not be kept as topic keywords.

    The base vocabulary covers quarter labels, years, company names and
    generic question terms; it is expanded through ``all_keywords_combs`` so
    the variants of each word are included as well.

    Returns:
        The expanded list of words to remove.
    """
    time_words = [
        "q1",
        "q2",
        "q3",
        "q4",
        "2016",
        "2017",
        "2018",
        "2019",
        "2020",
    ]
    company_words = [
        "apple",
        "amd",
        "amazon",
        "cisco",
        "google",
        "microsoft",
        "nvidia",
        "asml",
        "intel",
        "micron",
    ]
    generic_words = [
        "strategy",
        "roadmap",
        "impact",
        "opportunities",
        "challenges",
        "growth",
        "performance",
        "analysis",
        "segments",
        "comparative",
        "overlapping",
        "acquisition",
        "revenue",
    ]
    return all_keywords_combs(time_words + company_words + generic_words)
def clean_keywords_all_combs(keywords_list):
    """Tokenize, filter and expand the keywords extracted from a query.

    Each keyword phrase is split on whitespace and lower-cased; tokens found
    in the stop list from ``create_incorrect_entities_list`` are dropped, and
    the survivors are expanded with ``all_keywords_combs``.

    Args:
        keywords_list: Iterable of keyword strings, possibly multi-word.

    Returns:
        The expanded list of cleaned keywords.
    """
    # A set gives constant-time membership tests for the stop list.
    words_to_remove = set(create_incorrect_entities_list())
    # split() with no argument collapses runs of whitespace, so repeated
    # spaces cannot produce empty-string tokens that would survive filtering
    # and end up as keywords.
    tokens = [
        token.lower()
        for keyword in keywords_list
        for token in keyword.split()
    ]
    cleaned_keywords = [
        token for token in tokens if token not in words_to_remove
    ]
    return all_keywords_combs(cleaned_keywords)