#%%
from tiktoken import get_encoding, encoding_for_model
from weaviate_interface import WeaviateClient, WhereFilter
from sentence_transformers import SentenceTransformer
from prompt_templates import question_answering_prompt_series, question_answering_system
from openai_interface import GPT_Turbo
from app_features import (convert_seconds, generate_prompt_series, search_result,
validate_token_threshold, load_content_cache, load_data,
expand_content)
from retrieval_evaluation import execute_evaluation, calc_hit_rate_scores
from llama_index.finetuning import EmbeddingQAFinetuneDataset
from weaviate_interface import WeaviateClient
from openai import BadRequestError
from reranker import ReRanker
from loguru import logger
import streamlit as st
from streamlit_option_menu import option_menu
import hydralit_components as hc
import sys
import json
import os, time, requests, re
from datetime import timedelta
import pathlib
import gdown
import tempfile
import base64
import shutil
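# Helper: read a local binary file (e.g. an image under assets/) and return it as a
# base64 string so it can be embedded in an HTML/CSS data URI (see add_bg_from_local below).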
def get_base64_of_bin_file(bin_file):
with open(bin_file, 'rb') as file:
data = file.read()
return base64.b64encode(data).decode()
from dotenv import load_dotenv, find_dotenv
load_dotenv(find_dotenv('env'), override=True)
# I use a key that I increment each time I want to change a text_input
if 'key' not in st.session_state:
st.session_state.key = 0
# key = st.session_state['key']
if not pathlib.Path('models').exists():
os.mkdir('models')
# I should cache these things but no time left
# I put a file local.txt in my desktop models folder to find out if it's running online
we_are_online = not pathlib.Path("models/local.txt").exists()
we_are_not_online = not we_are_online
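# golden set of question/source pairs used by execute_evaluation() below to compute
# retrieval metrics (hit rate, MRR)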
golden_dataset = EmbeddingQAFinetuneDataset.from_json("data/golden_100.json")
# shutil.rmtree("models/models") # remove it - I wanted to clear the space on streamlit online
## PAGE CONFIGURATION
st.set_page_config(page_title="Ask Impact Theory",
page_icon="assets/impact-theory-logo-only.png",
layout="wide",
initial_sidebar_state="collapsed",
menu_items={'Report a bug': "https://www.extremelycoolapp.com/bug"})
image = "https://is2-ssl.mzstatic.com/image/thumb/Music122/v4/bd/34/82/bd348260-314c-5898-26c0-bef2e0388ebe/source/1200x1200bb.png"
def add_bg_from_local(image_file):
bin_str = get_base64_of_bin_file(image_file)
page_bg_img = f'''
<style>
.stApp {{
background-image: url("data:image/png;base64,{bin_str}");
background-size: 100% auto;
background-repeat: no-repeat;
background-attachment: fixed;
}}
</style>
'''
st.markdown(page_bg_img, unsafe_allow_html=True)
# COMMENT: I tried to create a dropdown menu but it's harder than it looks, so I gave up
# https://discuss.streamlit.io/t/streamlit-option-menu-is-a-simple-streamlit-component-that-allows-users-to-select-a-single-item-from-a-list-of-options-in-a-menu/20514
# not great, but it works
# selected = option_menu("About", ["Improvements","This"], #"Main Menu", ["Home", 'Settings'],
# icons=['house', 'gear'],
# menu_icon="cast",
# default_index=1)
# # Custom HTML/CSS for the banner
# base64_img = get_base64_of_bin_file("assets/it_tom_bilyeu.png")
# banner_menu_html = f"""
# <div class="banner">
# <img src= "data:image/png;base64,{base64_img}" alt="Banner Image">
# </div>
# <style>
# .banner {{
# width: 100%;
# height: auto;
# overflow: hidden;
# display: flex;
# justify-content: center;
# }}
# .banner img {{
# width: 130%;
# height: auto;
# object-fit: contain;
# }}
# </style>
# """
# st.components.v1.html(banner_menu_html)
# specify the primary menu definition
# it gives a vertical menu inside a navigation bar !!!
# menu_data = [
# {'icon': "far fa-copy", 'label':"Left End"},
# {'id':'Copy','icon':"🐙",'label':"Copy"},
# {'icon': "far fa-chart-bar", 'label':"Chart"},#no tooltip message
# {'icon': "far fa-address-book", 'label':"Book"},
# {'id':' Crazy return value 💀','icon': "💀", 'label':"Calendar"},
# {'icon': "far fa-clone", 'label':"Component"},
# {'icon': "fas fa-tachometer-alt", 'label':"Dashboard",'ttip':"I'm the Dashboard tooltip!"}, #can add a tooltip message
# {'icon': "far fa-copy", 'label':"Right End"},
# ]
# # we can override any part of the primary colors of the menu
# over_theme = {'txc_inactive': '#FFFFFF','menu_background':'red','txc_active':'yellow','option_active':'blue'}
# # over_theme = {'txc_inactive': '#FFFFFF'}
# menu_id = hc.nav_bar(menu_definition=menu_data,
# home_name='Home',
# override_theme=over_theme)
#get the id of the menu item clicked
# st.info(f"{menu_id=}")
## RERANKER
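# cross-encoder that re-scores (query, passage) pairs after the first-stage hybrid search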
reranker = ReRanker('cross-encoder/ms-marco-MiniLM-L-6-v2')
## ENCODING --> tiktoken library
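# the tokenizer is only used to count prompt tokens in validate_token_threshold below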
model_ids = ['gpt-3.5-turbo-16k', 'gpt-3.5-turbo-0613']
model_nameGPT = model_ids[1]
encoding = encoding_for_model(model_nameGPT)
# equivalent to get_encoding('cl100k_base'), the encoding used by the gpt-3.5-turbo family
##############
data_path = './data/impact_theory_data.json'
cache_path = 'data/impact_theory_cache.parquet'
data = load_data(data_path)
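# the parquet content cache is disabled for now; expand_content() is called with cache=None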
cache = None # load_content_cache(cache_path)
if 'secrets' in st.secrets:
# st.write("Loading secrets from [secrets] section")
# for streamlit online or local, which uses a [secrets] section
Wapi_key = st.secrets['secrets']['WEAVIATE_API_KEY']
url = st.secrets['secrets']['WEAVIATE_ENDPOINT']
openai_api_key = st.secrets['secrets']['OPENAI_API_KEY']
hf_token = st.secrets['secrets']['LLAMA2_ENDPOINT_HF_TOKEN_chris']
hf_endpoint = st.secrets['secrets']['LLAMA2_ENDPOINT_UPLIMIT']
else:
# st.write("Loading secrets for Huggingface")
# for Huggingface (no [secrets] section)
Wapi_key = st.secrets['WEAVIATE_API_KEY']
url = st.secrets['WEAVIATE_ENDPOINT']
openai_api_key = st.secrets['OPENAI_API_KEY']
hf_token = st.secrets['LLAMA2_ENDPOINT_HF_TOKEN_chris']
hf_endpoint = st.secrets['LLAMA2_ENDPOINT_UPLIMIT']
# else:
# # if we want to use env file
# st.write("Loading secrets from environment variables")
# api_key = os.environ['WEAVIATE_API_KEY']
# url = os.environ['WEAVIATE_ENDPOINT']
# openai_api_key = os.environ['OPENAI_API_KEY']
# hf_token = os.environ['LLAMA2_ENDPOINT_HF_TOKEN_chris']
# hf_endpoint = os.environ['LLAMA2_ENDPOINT_UPLIMIT']
#%%
# model_default = 'sentence-transformers/all-mpnet-base-v2'
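# offline (local dev): prefer the finetuned checkpoint; online: fall back to the base HF model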
model_default = 'models/finetuned-all-mpnet-base-v2-300' if we_are_not_online \
else 'sentence-transformers/all-mpnet-base-v2'
available_models = ['sentence-transformers/all-mpnet-base-v2',
'sentence-transformers/all-MiniLM-L6-v2',
'models/finetuned-all-mpnet-base-v2-300',
'sentence-transformers/all-MiniLM-L12-v2']
#%%
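# Google Drive folder for the finetuned checkpoint; needed online because git LFS
# does not work on forked repos (see download_model_from_Gdrive below)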
models_urls = {'models/finetuned-all-mpnet-base-v2-300': "https://drive.google.com/drive/folders/1asJ37-AUv5nytLtH6hp6_bVV3_cZOXfj"}
def download_model_from_Gdrive(model_name_or_path, model_local_path):
st.write("Downloading model from Google Drive")
assert model_name_or_path in models_urls, f"Model {model_name_or_path} not found in models_urls"
url = models_urls[model_name_or_path]
gdown.download_folder(url, output=model_local_path, quiet=False, use_cookies=False)
print(f"Model downloaded from Gdrive and saved to {model_local_path} folder")
# st.write("Model downloaded")
def download_model(model_name_or_path, model_local_path):
if model_name_or_path.startswith("models/"):
download_model_from_Gdrive(model_name_or_path, model_local_path)
elif model_name_or_path.startswith("sentence-transformers/"):
st.sidebar.write(f"Downloading {model_name_or_path}")
model = SentenceTransformer(model_name_or_path)
st.sidebar.write(f"Model {model_name_or_path} downloaded")
models_urls[model_name_or_path] = model_local_path
model.save(model_local_path)
# st.sidebar.write(f"Model {model_name_or_path} saved to {model_new_path}")
#%%
# for streamlit online, we must download the model from google drive
# because github LFS doesn't work on forked repos
def check_model(model_name_or_path):
model_name = model_name_or_path.split('/')[-1] # keep only the model name, dropping the 'sentence-transformers/' or 'models/' prefix
model_local_path = str(pathlib.Path("models") / model_name) # local path under the models/ folder
if pathlib.Path(model_local_path).exists():
# let's use the model that's already there
print(f"Model {model_local_path} already exists")
else:
# let's download the model, HF is not limited in space like Streamlit.io
download_model(model_name_or_path, model_local_path)
return model_local_path
#%% instantiate Weaviate client
def get_weaviate_client(api_key, url, model_name_or_path, openai_api_key):
client = WeaviateClient(api_key, url,
model_name_or_path=model_name_or_path,
openai_api_key=openai_api_key)
client.display_properties.append('summary')
available_classes = sorted(client.show_classes())
# st.write(f"Available classes: {available_classes}")
# st.write(f"Available classes type: {type(available_classes)}")
logger.info(available_classes)
return client, available_classes
##############
# data = load_data(data_path)
# guests list for sidebar
guest_list = sorted({d['guest'] for d in data})
def main():
with st.sidebar:
# moved it to main area
# guest = st.selectbox('Select Guest',
# options=guest_list,
# index=None,
# placeholder='Select Guest')
_, center, _ = st.columns([3, 5, 3])
with center:
st.text("Search Lab")
_, center, _ = st.columns([2, 5, 3])
with center:
if we_are_online:
st.text("Running ONLINE")
# st.text("(UNSTABLE)")
else:
st.text("Running OFFLINE")
st.write("----------")
hybrid_search = st.toggle('Hybrid Search', True)
if hybrid_search:
alpha_input = st.slider(label='Alpha',min_value=0.00, max_value=1.00, value=0.40, step=0.05)
retrieval_limit = st.slider(label='Hybrid Search Results', min_value=10, max_value=300, value=10, step=10)
hybrid_filter = st.toggle('Filter Search using Guest name', True) # i.e. restrict the search to the selected guest's episodes
rerank = st.toggle('Rerank', True)
if rerank:
reranker_topk = st.slider(label='Reranker Top K',min_value=1, max_value=5, value=3, step=1)
else:
# needed to not fill the LLM with too many responses (> context size)
# we could make it dependent on the model
reranker_topk = 3
rag_it = st.toggle('RAG it', True)
if rag_it:
st.write(f"Using LLM '{model_nameGPT}'")
llm_temperature = st.slider(label='LLM T˚', min_value=0.0, max_value=2.0, value=0.01, step=0.10 )
model_name_or_path = st.selectbox(label='Model Name:', options=available_models,
index=available_models.index(model_default),
placeholder='Select Model')
st.write("Experimental and time limited 2'")
c1,c2 = st.columns([8,1])
with c1:
finetune_model = st.toggle('Finetune on Modal A100 GPU', False)
if we_are_not_online or we_are_online: # always True for now: finetuning stays enabled both offline and online
if finetune_model:
from finetune_backend import finetune
if 'finetuned' in model_name_or_path:
st.write("Model already finetuned")
elif "models/" in model_name_or_path:
st.write("sentence-transformers models only!")
else:
try:
if 'finetuned' in model_name_or_path:
st.write("Model already finetuned")
else:
with c2:
with st.spinner(''):
model_path = finetune(model_name_or_path, savemodel=True, outpath='models')
with c1:
# st.write(f"model_path returned = {model_path}")
if model_path is not None:
if model_name_or_path.split('/')[-1] not in model_path:
st.sidebar.write(model_path) # a warning from finetuning in this case
# elif model_path not in available_models:
# finetuning generated a model, let's add it
# no because it's not in Weaviate, so we can't use it
# available_models.append(model_path)
# st.write(f"{model_path.split('/')[-1]} added to list!")
except Exception:
st.write("Model not found on HF or error")
else:
st.write("Finetuning not available on Streamlit online because of space limitations")
model_name_or_path = check_model(model_name_or_path)
try:
client, available_classes = get_weaviate_client(Wapi_key, url, model_name_or_path, openai_api_key)
except Exception as e:
# Weaviate doesn't know this model, maybe we're just finetuning a model
st.sidebar.write(f"Model unknown to Weaviate")
st.stop()
start_class = 'Impact_theory_all_mpnet_base_v2_finetuned'
class_name = st.selectbox(
label='Class Name:',
options=available_classes,
index=available_classes.index(start_class),
placeholder='Select Class Name'
)
st.write("----------")
c1,c2 = st.columns([8,1])
with c1:
show_metrics = st.toggle('Show Metrics on Golden set', False)
if show_metrics:
# _, center, _ = st.columns([3, 5, 3])
# with center:
# st.text("Metrics")
with c2:
with st.spinner(''):
metrics = execute_evaluation(golden_dataset, class_name, client, alpha=alpha_input)
if show_metrics:
kw_hit_rate = metrics['kw_hit_rate']
kw_mrr = metrics['kw_mrr']
hybrid_hit_rate = metrics['hybrid_hit_rate']
vector_hit_rate = metrics['vector_hit_rate']
vector_mrr = metrics['vector_mrr']
total_misses = metrics['total_misses']
st.text(f"KW hit rate: {kw_hit_rate}")
st.text(f"Vector hit rate: {vector_hit_rate}")
st.text(f"Hybrid hit rate: {hybrid_hit_rate}")
st.text(f"Hybrid MRR: {vector_mrr}")
st.text(f"Total misses: {total_misses}")
st.write("----------")
st.title("Chat with the Impact Theory podcasts!")
# st.image('./assets/impact-theory-logo.png', width=400)
st.image('assets/it_tom_bilyeu.png', use_column_width=True)
# st.subheader(f"Chat with the Impact Theory podcast: ")
st.write('\n')
# st.stop()
st.write("\u21D0 Open the sidebar to change Search settings \n ") # https://home.unicode.org also 21E0, 21B0 B2 D0
if not hybrid_search:
st.stop()
col1, _ = st.columns([3,7])
with col1:
guest = st.selectbox('Select A Guest',
options=guest_list,
index=None,
placeholder='Select Guest')
col1, col2 = st.columns([7,3])
with col1:
if guest is None:
msg = 'Select a guest before asking your question:'
else:
msg = f'Enter your question about {guest}:'
textbox = st.empty()
# incrementing the widget key is the best way I found to programmatically change the text inside a text_input afterwards
query = textbox.text_input(msg,
value="",
placeholder="You can refer to the guest with PRONOUNS",
key=st.session_state.key)
# st.write(f"Guest = {guest}")
# st.write(f"key = {st.session_state.key}")
st.write('\n\n\n\n\n')
reworded_query = {'changed': False, 'status': 'error'} # at start, the query is empty
valid_response = [] # at start, the query is empty, so prevent the search
if query:
if guest is None:
st.session_state.key += 1
query = textbox.text_input(msg,
value="",
placeholder="YOU MUST SELECT A GUEST BEFORE ASKING A QUESTION",
key=st.session_state.key)
# st.write(f"key = {st.session_state.key}")
st.stop()
else:
# st.write(f'It looks like you selected {guest} as a filter (It is ignored for now).')
with col2:
# let's add a nice pulse bar while generating the response
with hc.HyLoader('', hc.Loaders.pulse_bars, primary_color= 'red', height=50): #"#0e404d" for image green
# with st.spinner('Generating Response...'):
with col1:
use_reworded_query = st.toggle('Rewrite query with LLM', True)
if use_reworded_query:
# let's use Llama2, and fall back on GPT3.5 if it fails
reworded_query = reword_query(query, guest,
model_name='llama2-13b-chat')
new_query = reworded_query['rewritten_question']
if reworded_query['status'] != 'error': # or reworded_query['changed']:
guest_lastname = guest.split(' ')[-1] # last name, robust to middle names
if guest_lastname not in new_query:
# if the guest name is not in the rewritten question, we add it
new_query = f"About {guest}, " + new_query
query = new_query
st.write(f"Rewritten query: {query}")
# we can arrive here only if a guest was selected
where_filter = WhereFilter(path=['guest'], operator='Equal', valueText=guest).todict() \
if hybrid_filter else None
hybrid_response = client.hybrid_search(query,
class_name,
# properties=['content'], #['title', 'summary', 'content'],
alpha=alpha_input,
display_properties=client.display_properties,
where_filter=where_filter,
limit=retrieval_limit)
response = hybrid_response
if rerank:
# rerank results with cross encoder
ranked_response = reranker.rerank(response, query,
apply_sigmoid=True, # score between 0 and 1
top_k=reranker_topk)
logger.info(ranked_response)
expanded_response = expand_content(ranked_response, cache,
content_key='doc_id',
create_new_list=True)
response = expanded_response
# make sure token count < threshold
token_threshold = 8000 if model_nameGPT == model_ids[0] else 3500 # larger budget for the 16k-context model
valid_response = validate_token_threshold(response,
question_answering_prompt_series,
query=query,
tokenizer=encoding, # tiktoken encoding defined in the ENCODING section above
token_threshold=token_threshold,
verbose=True)
# st.write(f"Number of results: {len(valid_response)}")
# I jump out of col1 to get all page width, so need to retest query
if query:
# creates container for LLM response to position it above search results
chat_container, response_box = [], st.empty()
# # RAG time !! execute chat call to LLM
if rag_it:
# st.subheader("Response from Impact Theory (context)")
# will appear under the answer, moved it into the response box
# generate LLM prompt
prompt = generate_prompt_series(query=query, results=valid_response)
GPTllm = GPT_Turbo(model=model_nameGPT,
api_key=openai_api_key)
try:
# inserts chat stream from LLM
for resp in GPTllm.get_chat_completion(prompt=prompt,
temperature=llm_temperature,
max_tokens=350,
show_response=True,
stream=True):
with response_box:
content = resp.choices[0].delta.content
if content:
chat_container.append(content)
result = "".join(chat_container).strip()
response_box.markdown(f"### Response from Impact Theory (RAG):\n\n{result}")
except BadRequestError as e:
logger.info('Making request with smaller context')
valid_response = validate_token_threshold(response,
question_answering_prompt_series,
query=query,
tokenizer=encoding,
token_threshold=3500,
verbose=True)
# if reranker is off, we may receive a LOT of responses
# so we must reduce the context size manually
if not rerank:
valid_response = valid_response[:reranker_topk]
prompt = generate_prompt_series(query=query, results=valid_response)
for resp in GPTllm.get_chat_completion(prompt=prompt,
temperature=llm_temperature,
max_tokens=350, # expand for more verbose answers
show_response=True,
stream=True):
try:
# inserts chat stream from LLM
with response_box:
content = resp.choices[0].delta.content
if content:
chat_container.append(content)
result = "".join(chat_container).strip()
response_box.markdown(f"### Response from Impact Theory (RAG):\n\n{result}")
except Exception as e:
print(e)
st.markdown("----")
st.subheader("Search Results")
for i, hit in enumerate(valid_response):
col1, col2 = st.columns([7, 3], gap='large')
image = hit['thumbnail_url'] # get thumbnail_url
episode_url = hit['episode_url'] # get episode_url
title = hit["title"] # get title
show_length = hit["length"] # get length
time_string = str(timedelta(seconds=show_length)) # convert show_length to readable time string
with col1:
st.write(search_result(i=i,
url=episode_url,
guest=hit['guest'],
title=title,
content='',
length=time_string),
unsafe_allow_html=True)
st.write('\n\n')
with col2:
#st.write(f"<a href={episode_url} <img src={image} width='200'></a>",
# unsafe_allow_html=True)
#st.markdown(f"[![{title}]({image})]({episode_url})")
# st.markdown(f'<a href="{episode_url}">'
# f'<img src={image} '
# f'caption={title.split("|")[0]} width=200, use_column_width=False />'
# f'</a>',
# unsafe_allow_html=True)
st.image(image, caption=title.split('|')[0], width=200, use_column_width=False)
# let's use all width for the content
st.write(hit['content'])
def get_answer(query, valid_response, GPTllm):
# generate LLM prompt
prompt = generate_prompt_series(query=query,
results=valid_response)
return GPTllm.get_chat_completion(prompt=prompt,
system_message='answer this question based on the podcast material',
temperature=0,
max_tokens=500,
stream=False,
show_response=False)
def reword_query(query, guest, model_name='llama2-13b-chat', response_processing=True):
""" Asks the LLM to rewrite the query when the guest name is missing.
Args:
query (str): user query
guest (str): guest name
model_name (str, optional): name of the LLM model to use
response_processing (bool, optional): if True, parse the raw Llama2 output into a dict
"""
# tags = {'llama2-13b-chat': {'start': '<s>', 'end': '</s>', 'instruction': '[INST]', 'system': '[SYS]'},
# 'gpt-3.5-turbo-0613': {'start': '<|startoftext|>', 'end': '', 'instruction': "```", 'system': ```}}
prompt_fields = {
"you_are":f"You are an expert in linguistics and semantics, analyzing the question asked by a user to a vector search system, \
and making sure that the question is well formulated and that the system can understand it.",
"your_task":f"Your task is to detect if the name of the guest ({guest}) is mentioned in the user's question, \
and if that is not the case, rewrite the question using the guest name, \
without changing the meaning of the question. \
Most of the time, the user will have used a pronoun to designate the guest, in which case, \
simply replace the pronoun with the guest name.",
"question":f"If the user mentions the guest name, ie {query}, just return his question as is. \
If the user does not mention the guest name, rewrite the question using the guest name.",
"final_instruction":f"Only regerate the requested rewritten question or the original, WITHOUT ANY COMMENT OR REPHRASING. \
Your answer must be as close as possible to the original question, \
and exactly identical, word for word, if the user mentions the guest name, i.e. {guest}.",
}
# prompt created by chatGPT :-)
# and Llama still outputs the original question and precedes the answer with 'rewritten question'
prompt_fields2 = {
"you_are": (
"You are an expert in linguistics and semantics. Your role is to analyze questions asked to a vector search system."
),
"your_task": (
f"Detect if the guest's FULL name, {guest}, is mentioned in the user's question. "
"If not, rewrite the question by replacing pronouns or indirect references with the guest's name." \
"If yes, return the original question as is, without any change at all, not even punctuation,"
"except a question mark that you MUST add if it's missing."
),
"question": (
f"Original question: '{query}'. "
"Rewrite this question to include the guest's FULL name if it's not already mentioned."
"The Only thing you can and MUST add is a question mark if it's missing."
),
"final_instruction": (
"Create a rewritten question or keep the original question as is. "
"Do not include any labels, titles, or additional text before or after the question."
"The Only thing you can and MUST add is a question mark if it's missing."
"Return a json object, with the key 'original_question' for the original question, \
and 'rewritten_question' for the rewritten question \
and 'changed' being True if you changed the answer, otherwise False."
),
}
if model_name == 'llama2-13b-chat':
# special tags are used:
# `<s>` - start prompt tag
# `[INST], [/INST]` - Opening and closing model instruction tags
# `<<<SYS>>>, <</SYS>>` - Opening and closing system prompt tags
llama_prompt = """
<s>[INST] <<SYS>>
{you_are}
<</SYS>>
{your_task}\n
```
\n\n
Question: {question}\n
{final_instruction} [/INST]
Answer:
"""
prompt = llama_prompt.format(**prompt_fields2)
headers = {"Authorization": f"Bearer {hf_token}",
"Content-Type": "application/json",}
json_body = {
"inputs": prompt,
"parameters": {"max_new_tokens":400,
"repetition_penalty": 1.0,
"temperature":0.01}
}
response = requests.request("POST", hf_endpoint, headers=headers, data=json.dumps(json_body))
response = json.loads(response.content.decode("utf-8"))
# json.loads only parses the API envelope; the generated text itself may be badly formatted, so we extract the fields ourselves below
if isinstance(response, dict) and 'error' in response:
print("Found error")
print(response)
# return {'error': response['error'], 'rewritten_question': query, 'changed': False, 'status': 'error'}
# I test this here, otherwise the message ends up in col 2 or 1, which are too narrow
# if reworded_query['status'] == 'error':
# st.write(f"Error in LLM response: 'error':{reworded_query['error']}")
# st.write("The LLM could not connect to the server. Please try again later.")
# st.stop()
return reword_query(query, guest, model_name='gpt-3.5-turbo-0613')
if response_processing:
if isinstance(response, list) and isinstance(response[0], dict) and 'generated_text' in response[0]:
print("Found generated text")
response0 = response[0]['generated_text']
pattern = r'\"(\w+)\":\s*(\".*?\"|\w+)' # grab "key": value pairs from the loosely formatted JSON the model returns
matches = re.findall(pattern, response0)
# let's build a dictionary
result = {key: json.loads(value) if value.startswith("\"") else value for key, value in matches}
return result | {'status': 'success'}
else:
print("Found no answer")
return reword_query(query, guest, model_name='gpt-3.5-turbo-0613')
# return {'original_question': query, 'rewritten_question': query, 'changed': False, 'status': 'no properly formatted answer' }
else:
return response
# return response
# assert 'error' not in response, f"Error in LLM response: {response['error']}"
# assert 'generated_text' in response[0], f"Error in LLM response: {response}, no 'generated_text' field"
# # let's extract the rewritten question
# return response[0]['generated_text'] .split("Rewritten question: '")[-1][:-1]
else:
# assume openai
model_ids = ['gpt-3.5-turbo-16k', 'gpt-3.5-turbo-0613']
model_name = model_ids[1]
GPTllm = GPT_Turbo(model=model_name,
api_key=openai_api_key)
openai_prompt = """
{your_task}\n
```
\n\n
Question: {question}\n
{final_instruction}
Answer:
"""
prompt = openai_prompt.format(**prompt_fields)
try:
resp = GPTllm.get_chat_completion(prompt=prompt, # use the formatted prompt, not the raw template
system_message=prompt_fields['you_are'],
temperature=0.01,
max_tokens=1500, # it's a question...
show_response=True,
stream=False)
return {'rewritten_question': resp.choices[0].message.content, # non-streaming response
'changed': True, 'status': 'success'}
except Exception:
return {'rewritten_question': query, 'changed': False, 'status': 'not success'}
if __name__ == '__main__':
main()
# %%