#%%
from tiktoken import get_encoding, encoding_for_model
from weaviate_interface import WeaviateClient, WhereFilter
from sentence_transformers import SentenceTransformer
from prompt_templates import question_answering_prompt_series, question_answering_system
from openai_interface import GPT_Turbo
from app_features import (convert_seconds, generate_prompt_series, search_result,
validate_token_threshold, load_content_cache, load_data,
expand_content)
from retrieval_evaluation import execute_evaluation, calc_hit_rate_scores
from llama_index.finetuning import EmbeddingQAFinetuneDataset
from openai import BadRequestError
from reranker import ReRanker
from loguru import logger
import streamlit as st
from streamlit_option_menu import option_menu
import hydralit_components as hc
import sys
import json
import os, time, requests, re
from datetime import timedelta
import pathlib
import gdown
import tempfile
import base64
import shutil
def get_base64_of_bin_file(bin_file):
with open(bin_file, 'rb') as file:
data = file.read()
return base64.b64encode(data).decode()
from dotenv import load_dotenv, find_dotenv
load_dotenv(find_dotenv('env'), override=True)
# I use a key that I increment each time I want to change a text_input
if 'key' not in st.session_state:
st.session_state.key = 0
# key = st.session_state['key']
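# How the trick works (illustrative sketch): Streamlit identifies a widget by its key, so
# bumping the key makes the next rerun create a brand-new, empty text_input, e.g.:
#   st.session_state.key += 1
#   query = textbox.text_input(msg, value="", key=st.session_state.key)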
if not pathlib.Path('models').exists():
os.mkdir('models')
# TODO: cache these things (no time left)
# a local.txt file in the local models folder (present only on my machine) tells the app it is NOT running online
we_are_online = not pathlib.Path("models/local.txt").exists()
we_are_not_online = not we_are_online
golden_dataset = EmbeddingQAFinetuneDataset.from_json("data/golden_100.json")
## PAGE CONFIGURATION
st.set_page_config(page_title="Ask Impact Theory",
page_icon="assets/impact-theory-logo-only.png",
layout="wide",
initial_sidebar_state="collapsed",
menu_items={'Report a bug': "https://www.extremelycoolapp.com/bug"})
image = "https://is2-ssl.mzstatic.com/image/thumb/Music122/v4/bd/34/82/bd348260-314c-5898-26c0-bef2e0388ebe/source/1200x1200bb.png"
def add_bg_from_local(image_file):
bin_str = get_base64_of_bin_file(image_file)
page_bg_img = f'''
<style>
.stApp {{
background-image: url("data:image/png;base64,{bin_str}");
background-size: 100% auto;
background-repeat: no-repeat;
background-attachment: fixed;
}}
</style>
'''
st.markdown(page_bg_img, unsafe_allow_html=True)
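# Minimal usage sketch (assuming the asset exists locally):
#   add_bg_from_local('assets/impact-theory-logo-only.png')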
## RERANKER
reranker = ReRanker('cross-encoder/ms-marco-MiniLM-L-6-v2')
## ENCODING --> tiktoken library
model_ids = ['gpt-3.5-turbo-16k', 'gpt-3.5-turbo-0613']
model_nameGPT = model_ids[1]
encoding = encoding_for_model(model_nameGPT)
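# The encoding is later used to count prompt tokens, e.g. (illustrative):
#   n_tokens = len(encoding.encode("What did the guest say about discipline?"))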
## DATA
data_path = './data/impact_theory_data.json'
cache_path = 'data/impact_theory_cache.parquet'
data = load_data(data_path)
cache = None # load_content_cache(cache_path)
guest_list = sorted(list(set([d['guest'] for d in data])))
if 'secrets' in st.secrets:
# st.write("Loading secrets from [secrets] section")
# for streamlit online or local, which uses a [secrets] section
Wapi_key = st.secrets['secrets']['WEAVIATE_API_KEY']
url = st.secrets['secrets']['WEAVIATE_ENDPOINT']
openai_api_key = st.secrets['secrets']['OPENAI_API_KEY']
hf_token = st.secrets['secrets']['LLAMA2_ENDPOINT_HF_TOKEN_chris']
hf_endpoint = st.secrets['secrets']['LLAMA2_ENDPOINT_UPLIMIT']
else:
# st.write("Loading secrets for Huggingface")
# for Huggingface (no [secrets] section)
Wapi_key = st.secrets['WEAVIATE_API_KEY']
url = st.secrets['WEAVIATE_ENDPOINT']
openai_api_key = st.secrets['OPENAI_API_KEY']
hf_token = st.secrets['LLAMA2_ENDPOINT_HF_TOKEN_chris']
hf_endpoint = st.secrets['LLAMA2_ENDPOINT_UPLIMIT']
#%%
# model_default = 'sentence-transformers/all-mpnet-base-v2'
model_default = 'models/finetuned-all-mpnet-base-v2-300' if we_are_not_online \
else 'sentence-transformers/all-mpnet-base-v2'
available_models = ['sentence-transformers/all-mpnet-base-v2',
'sentence-transformers/all-MiniLM-L6-v2',
'models/finetuned-all-mpnet-base-v2-300',
'sentence-transformers/all-MiniLM-L12-v2']
#%%
models_urls = {'models/finetuned-all-mpnet-base-v2-300': "https://drive.google.com/drive/folders/1asJ37-AUv5nytLtH6hp6_bVV3_cZOXfj"}
def download_model_from_Gdrive(model_name_or_path, model_local_path):
st.write("Downloading model from Google Drive")
assert model_name_or_path in models_urls, f"Model {model_name_or_path} not found in models_urls"
url = models_urls[model_name_or_path]
gdown.download_folder(url, output=model_local_path, quiet=False, use_cookies=False)
print(f"Model downloaded from Gdrive and saved to {model_local_path} folder")
# st.write("Model downloaded")
def download_model(model_name_or_path, model_local_path):
if model_name_or_path.startswith("models/"):
download_model_from_Gdrive(model_name_or_path, model_local_path)
elif model_name_or_path.startswith("sentence-transformers/"):
st.sidebar.write(f"Downloading {model_name_or_path}")
model = SentenceTransformer(model_name_or_path)
st.sidebar.write(f"Model {model_name_or_path} downloaded")
models_urls[model_name_or_path] = model_local_path
model.save(model_local_path)
# st.sidebar.write(f"Model {model_name_or_path} saved to {model_new_path}")
#%%
# for streamlit online, we must download the model from google drive
# because github LFS doesn't work on forked repos
def check_model(model_name_or_path):
model_name = model_name_or_path.split('/')[-1] # strip the org prefix, e.g. 'sentence-transformers/'
model_local_path = str(pathlib.Path("models") / model_name) # local path under the models/ folder
if pathlib.Path(model_local_path).exists():
# let's use the model that's already there
print(f"Model {model_local_path} already exists")
else:
# let's download the model, HF is not limited in space like Streamlit.io
download_model(model_name_or_path, model_local_path)
#%% instantiate Weaviate client
def get_weaviate_client(api_key, url, model_name_or_path, openai_api_key):
try:
client = WeaviateClient(api_key, url,
model_name_or_path=model_name_or_path,
openai_api_key=openai_api_key)
except Exception:
# client not available, wrong key, expired free sandbox etc
return None, None
try:
client.display_properties.append('summary')
# available_classes = sorted(client.show_classes()) # doesn't work anymore
# print(available_classes)
available_classes = sorted([c['class'] for c in client.schema.get()['classes']])
# print(available_classes)
# st.write(f"Available classes: {available_classes}")
# st.write(f"Available classes type: {type(available_classes)}")
logger.info(available_classes)
return client, available_classes
except Exception:
return client, []
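# Usage sketch (assuming the credentials are valid):
#   client, classes = get_weaviate_client(Wapi_key, url, model_default, openai_api_key)
#   client is None -> the sandbox is unreachable; classes == [] -> the schema could not be read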
##############
def main():
with st.sidebar:
_, center, _ = st.columns([3, 5, 3])
with center:
st.text("Search Lab")
_, center, _ = st.columns([2, 5, 3])
with center:
if we_are_online:
st.text("Running ONLINE")
# st.text("(UNSTABLE)")
else:
st.text("Running OFFLINE")
st.write("----------")
hybrid_search = st.toggle('Hybrid Search', True)
if hybrid_search:
alpha_input = st.slider(label='Alpha',min_value=0.00, max_value=1.00, value=0.40, step=0.05, key=1)
retrieval_limit = st.slider(label='Hybrid Search Results', min_value=10, max_value=300, value=10, step=10)
hybrid_filter = st.toggle('Filter Search using Guest name', True) # i.e. look only at guests' data
rerank = st.toggle('Rerank', True)
if rerank:
reranker_topk = st.slider(label='Reranker Top K',min_value=1, max_value=5, value=3, step=1)
else:
# needed to not fill the LLM with too many responses (> context size)
# we could make it dependent on the model
reranker_topk = 3
rag_it = st.toggle(f"RAG it with '{model_nameGPT}'", True)
if rag_it:
# st.write(f"Using LLM '{model_nameGPT}'")
llm_temperature = st.slider(label='LLM T˚', min_value=0.0, max_value=2.0, value=0.01, step=0.10 )
model_name_or_path = st.selectbox(label='Model Name:', options=available_models,
index=available_models.index(model_default),
placeholder='Select Model')
delete_models = st.button('Delete models')
if delete_models:
# model_path = os.path.join("models", model_name_or_path.split('/')[-1])
# if os.path.isdir(model_path):
# shutil.rmtree(model_path)
for model in os.listdir("models"):
model_path = os.path.join("models", model)
if os.path.isdir(model_path) and 'finetuned-all-mpnet-base-v2-300' not in model_path:
shutil.rmtree(model_path)
st.write("Models deleted")
if we_are_not_online:
st.write("Experimental and time limited 2'")
c1,c2 = st.columns([8,1])
with c1:
finetune_model = st.button('Finetune on Modal A100 GPU')
if finetune_model:
from finetune_backend import finetune
if 'finetuned' in model_name_or_path:
st.write("Model already finetuned")
elif "models/" in model_name_or_path:
st.write("sentence-transformers models only!")
else:
try:
if 'finetuned' in model_name_or_path:
st.write("Model already finetuned")
else:
with c2:
with st.spinner(''):
model_path = finetune(model_name_or_path, savemodel=True, outpath='models')
with c1:
if model_path is not None:
if model_name_or_path.split('/')[-1] not in model_path:
st.sidebar.write(model_path) # a warning from finetuning in this case
# TODO: add model to Weaviate and to model list
except Exception:
st.write("Model not found on HF or error")
else:
st.write("Finetuning not available on Streamlit online because of space limitations")
check_model(model_name_or_path)
client, available_classes = get_weaviate_client(Wapi_key, url, model_name_or_path, openai_api_key)
print("Available classes:", available_classes)
if client is None:
# maybe the free sandbox has expired, or the api key is wrong
st.sidebar.write(f"Weaviate sandbox not accessible or expired")
# st.stop()
elif available_classes:
start_class = 'Impact_theory_all_mpnet_base_v2_finetuned'
class_name = st.selectbox(
label='Class Name:',
options=available_classes,
index=available_classes.index(start_class),
placeholder='Select Class Name'
)
st.write("----------")
if we_are_not_online:
c1,c2 = st.columns([8,1])
with c1:
show_metrics = st.button('Show Metrics on Golden set')
if show_metrics:
# we must add it because the hybrid search toggle could hide it
alpha_input2 = st.slider(label='Alpha',min_value=0.00, max_value=1.00, value=0.40, step=0.05, key=2)
# _, center, _ = st.columns([3, 5, 3])
# with center:
# st.text("Metrics")
with c2:
with st.spinner(''):
metrics = execute_evaluation(golden_dataset, class_name, client, alpha=alpha_input2)
with c1:
kw_hit_rate = metrics['kw_hit_rate']
kw_mrr = metrics['kw_mrr']
hybrid_hit_rate = metrics['hybrid_hit_rate']
vector_hit_rate = metrics['vector_hit_rate']
vector_mrr = metrics['vector_mrr']
total_misses = metrics['total_misses']
st.text(f"KW hit rate: {kw_hit_rate}")
st.text(f"Vector hit rate: {vector_hit_rate}")
st.text(f"Hybrid hit rate: {hybrid_hit_rate}")
st.text(f"Hybrid MRR: {vector_mrr}")
st.text(f"Total misses: {total_misses}")
st.write("----------")
else:
# Weaviate doesn't know this model, maybe we're just finetuning a model
st.sidebar.write(f"Model Unknown to Weaviate")
st.title("Chat with the Impact Theory podcasts!")
# st.image('./assets/impact-theory-logo.png', width=400)
st.image('assets/it_tom_bilyeu.png', use_column_width=True)
# st.subheader(f"Chat with the Impact Theory podcast: ")
st.write('\n')
# st.stop()
st.write("\u21D0 Open the sidebar to change Search settings \n ") # https://home.unicode.org also 21E0, 21B0 B2 D0
if client is None:
st.write("Weaviate sandbox not accessible or expired!!! Stopping execution!")
st.stop()
elif not available_classes:
# we have to stop here, to exit the 'with st.sidebar' block and display the banner at least
st.stop()
if not hybrid_search:
st.stop()
col1, _ = st.columns([3,7])
with col1:
guest = st.selectbox('Select A Guest',
options=guest_list,
index=None,
placeholder='Select Guest')
col1, col2 = st.columns([7,3])
with col1:
if guest is None:
msg = 'Select a guest before asking your question:'
else:
msg = f'Enter your question about {guest}:'
textbox = st.empty()
# best solution I found to be able to change the text inside a text_input box afterwards, using a key
query = textbox.text_input(msg,
value="",
placeholder="You can refer to the guest with PRONOUNS",
key=st.session_state.key)
# st.write(f"Guest = {guest}")
# st.write(f"key = {st.session_state.key}")
st.write('\n\n\n\n\n')
reworded_query = {'changed': False, 'status': 'error'} # at start, the query is empty
valid_response = [] # at start, the query is empty, so prevent the search
if query:
if guest is None:
st.session_state.key += 1
query = textbox.text_input(msg,
value="",
placeholder="YOU MUST SELECT A GUEST BEFORE ASKING A QUESTION",
key=st.session_state.key)
# st.write(f"key = {st.session_state.key}")
st.stop()
else:
# st.write(f'It looks like you selected {guest} as a filter (It is ignored for now).')
with col2:
# let's add a nice pulse bar while generating the response
with hc.HyLoader('', hc.Loaders.pulse_bars, primary_color= 'red', height=50): #"#0e404d" for image green
with col1:
if st.toggle('Rewrite query with LLM', True):
# let's use Llama2, and fall back on GPT3.5 if it fails
reworded_query = reword_query(query, guest,
model_name='gpt-3.5-turbo-0125')
new_query = reworded_query['rewritten_question']
if reworded_query['status'] != 'error': # or reworded_query['changed']:
guest_lastname = guest.split(' ')[-1] # take the last token as the guest's last name
if guest_lastname not in new_query:
# if the guest name is not in the rewritten question, we add it
new_query = f"About {guest}, " + new_query
query = new_query
st.write(f"New query: {query}")
# we can arrive here only if a guest was selected
where_filter = WhereFilter(path=['guest'], operator='Equal', valueText=guest).todict() \
if hybrid_filter else None
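# The serialized filter presumably looks like (illustrative guest name):
#   {'path': ['guest'], 'operator': 'Equal', 'valueText': 'Tom Bilyeu'}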
hybrid_response = client.hybrid_search(query,
class_name,
# properties=['content'], #['title', 'summary', 'content'],
alpha=alpha_input,
display_properties=client.display_properties,
where_filter=where_filter,
limit=retrieval_limit)
response = hybrid_response
if rerank:
# rerank results with cross encoder
ranked_response = reranker.rerank(response, query,
apply_sigmoid=True, # score between 0 and 1
top_k=reranker_topk)
logger.info(ranked_response)
expanded_response = expand_content(ranked_response, cache,
content_key='doc_id',
create_new_list=True)
response = expanded_response
# make sure token count < threshold
token_threshold = 8000 if model_nameGPT == model_ids[0] else 3500
valid_response = validate_token_threshold(response,
question_answering_prompt_series,
query=query,
tokenizer=encoding,  # tiktoken encoding defined above
token_threshold=token_threshold,
verbose=True)
# st.write(f"Number of results: {len(valid_response)}")
# we left the col1 block to use the full page width, so we need to re-test query here
if query:
# creates container for LLM response to position it above search results
chat_container, response_box = [], st.empty()
# # RAG time !! execute chat call to LLM
if rag_it:
# st.subheader("Response from Impact Theory (context)")
# will appear under the answer, moved it into the response box
# generate LLM prompt
prompt = generate_prompt_series(query=query, results=valid_response)
GPTllm = GPT_Turbo(model=model_nameGPT,
api_key=openai_api_key)
try:
# inserts chat stream from LLM
for resp in GPTllm.get_chat_completion(prompt=prompt,
temperature=llm_temperature,
max_tokens=350,
show_response=True,
stream=True):
with response_box:
content = resp.choices[0].delta.content
if content:
chat_container.append(content)
result = "".join(chat_container).strip()
response_box.markdown(f"### Response from Impact Theory (RAG):\n\n{result}")
except BadRequestError as e:
logger.info('Making request with smaller context')
valid_response = validate_token_threshold(response,
question_answering_prompt_series,
query=query,
tokenizer=encoding,
token_threshold=3500,
verbose=True)
# if reranker is off, we may receive a LOT of responses
# so we must reduce the context size manually
if not rerank:
valid_response = valid_response[:reranker_topk]
prompt = generate_prompt_series(query=query, results=valid_response)
for resp in GPTllm.get_chat_completion(prompt=prompt,
temperature=llm_temperature,
max_tokens=350, # expand for more verbose answers
show_response=True,
stream=True):
try:
# inserts chat stream from LLM
with response_box:
content = resp.choices[0].delta.content
if content:
chat_container.append(content)
result = "".join(chat_container).strip()
response_box.markdown(f"### Response from Impact Theory (RAG):\n\n{result}")
except Exception as e:
print(e)
st.markdown("----")
st.subheader("Search Results")
for i, hit in enumerate(valid_response):
col1, col2 = st.columns([7, 3], gap='large')
image = hit['thumbnail_url'] # get thumbnail_url
episode_url = hit['episode_url'] # get episode_url
title = hit["title"] # get title
show_length = hit["length"] # get length
time_string = str(timedelta(seconds=show_length)) # convert show_length to readable time string
with col1:
st.write(search_result(i=i,
url=episode_url,
guest=hit['guest'],
title=title,
content='',
length=time_string),
unsafe_allow_html=True)
st.write('\n\n')
with col2:
#st.write(f"<a href={episode_url} <img src={image} width='200'></a>",
# unsafe_allow_html=True)
#st.markdown(f"[![{title}]({image})]({episode_url})")
# st.markdown(f'<a href="{episode_url}">'
# f'<img src={image} '
# f'caption={title.split("|")[0]} width=200, use_column_width=False />'
# f'</a>',
# unsafe_allow_html=True)
st.image(image, caption=title.split('|')[0], width=200, use_column_width=False)
# let's use all width for the content
st.write(hit['content'])
def get_answer(query, valid_response, GPTllm):
# generate LLM prompt
prompt = generate_prompt_series(query=query,
results=valid_response)
return GPTllm.get_chat_completion(prompt=prompt,
system_message='answer this question based on the podcast material',
temperature=0,
max_tokens=500,
stream=False,
show_response=False)
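# Usage sketch (this helper is not called in the main flow; arguments are illustrative):
#   GPTllm = GPT_Turbo(model=model_nameGPT, api_key=openai_api_key)
#   answer = get_answer(query, valid_response, GPTllm)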
def reword_query(query, guest, model_name='llama2-13b-chat', response_processing=True):
""" Asks LLM to rewrite the query when the guest name is missing.
Args:
query (str): user query
guest (str): guest name
model_name (str, optional): name of a LLM model to be used
"""
# tags = {'llama2-13b-chat': {'start': '<s>', 'end': '</s>', 'instruction': '[INST]', 'system': '[SYS]'},
# 'gpt-3.5-turbo-0613': {'start': '<|startoftext|>', 'end': '', 'instruction': "```", 'system': ```}}
prompt_fields = {
"you_are":f"You are an expert in linguistics and semantics, analyzing the question asked by a user to a vector search system, \
and making sure that the question is well formulated and that the system can understand it.",
"your_task":f"Your task is to detect if the name of the guest ({guest}) is mentioned in the user's question, \
and if that is not the case, rewrite the question using the guest name, \
without changing the meaning of the question. \
Most of the time, the user will have used a pronoun to designate the guest, in which case, \
simply replace the pronoun with the guest name.",
"question":f"If the user mentions the guest name ({guest}) in the following question '{query}', just return his question as is. \
If the user does not mention the guest name, rewrite the question using the guest name.",
"final_instruction":f"Only regenerate the requested rewritten question or the original, WITHOUT ANY COMMENT OR REPHRASING. \
Your answer must be as close as possible to the original question, \
and exactly identical, word for word, if the user mentions the guest name, i.e. {guest}.",
}
# prompt created by chatGPT :-)
# and Llama still outputs the original question and precedes the answer with 'rewritten question'
prompt_fields2 = {
"you_are": (
"You are an expert in linguistics and semantics. Your role is to analyze questions asked to a vector search system."
),
"your_task": (
f"Detect if the guest's FULL name, {guest}, is mentioned in the user's question. "
"If not, rewrite the question by replacing pronouns or indirect references with the guest's name." \
"If yes, return the original question as is, without any change at all, not even punctuation,"
"except a question mark that you MUST add if it's missing."
),
"question": (
f"Original question: '{query}'. "
"Rewrite this question to include the guest's FULL name if it's not already mentioned."
"Add a question mark if it's missing, nothing else."
),
"final_instruction": (
"Create a rewritten question or keep the original question as is. "
"Do not include any labels, titles, or additional text before or after the question."
"The Only thing you can and MUST add is a question mark if it's missing."
"Return a json object, with the key 'original_question' for the original question, \
and 'rewritten_question' for the rewritten question \
and 'changed' being True if you changed the answer, otherwise False."
),
}
if model_name == 'llama2-13b-chat':
# special tags are used:
# `<s>` - start prompt tag
# `[INST], [/INST]` - Opening and closing model instruction tags
# `<<<SYS>>>, <</SYS>>` - Opening and closing system prompt tags
llama_prompt = """
<s>[INST] <<SYS>>
{you_are}
<</SYS>>
{your_task}\n
```
\n\n
Question: {question}\n
{final_instruction} [/INST]
Answer:
"""
prompt = llama_prompt.format(**prompt_fields2)
headers = {"Authorization": f"Bearer {hf_token}",
"Content-Type": "application/json",}
json_body = {
"inputs": prompt,
"parameters": {"max_new_tokens":400,
"repetition_penalty": 1.0,
"temperature":0.01}
}
response = requests.request("POST", hf_endpoint, headers=headers, data=json.dumps(json_body))
response = json.loads(response.content.decode("utf-8"))
# ^ json.loads only decodes the HTTP body; it cannot fix badly formatted generated text, so we parse that ourselves below
if isinstance(response, dict) and 'error' in response:
print("Found error")
print(response)
# return {'error': response['error'], 'rewritten_question': query, 'changed': False, 'status': 'error'}
# I test the error here, otherwise the message ends up inside col 1 or col 2, which are too narrow
# if reworded_query['status'] == 'error':
# st.write(f"Error in LLM response: 'error':{reworded_query['error']}")
# st.write("The LLM could not connect to the server. Please try again later.")
# st.stop()
return reword_query(query, guest, model_name='gpt-3.5-turbo-0125')
if response_processing:
if isinstance(response, list) and isinstance(response[0], dict) and 'generated_text' in response[0]:
print("Found generated text")
response0 = response[0]['generated_text']
pattern = r'\"(\w+)\":\s*(\".*?\"|\w+)'
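# Illustrative example of what the regex captures from the generated text:
#   '"rewritten_question": "What did he say?", "changed": true'
#   -> [('rewritten_question', '"What did he say?"'), ('changed', 'true')]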
matches = re.findall(pattern, response0)
# let's build a dictionary
result = {key: json.loads(value) if value.startswith("\"") else value for key, value in matches}
return result | {'status': 'success'}
else:
print("Found no answer")
return reword_query(query, guest, model_name='gpt-3.5-turbo-0125')
# return {'original_question': query, 'rewritten_question': query, 'changed': False, 'status': 'no properly formatted answer' }
else:
return response
# return response
# assert 'error' not in response, f"Error in LLM response: {response['error']}"
# assert 'generated_text' in response[0], f"Error in LLM response: {response}, no 'generated_text' field"
# # let's extract the rewritten question
# return response[0]['generated_text'] .split("Rewritten question: '")[-1][:-1]
else:
# we assume / force openai
model_ids = ['gpt-3.5-turbo-0125', 'gpt-3.5-turbo-16k', 'gpt-3.5-turbo-0613']
if model_name not in model_ids:
model_name = model_ids[0]
GPTllm = GPT_Turbo(model=model_name, api_key=openai_api_key)
openai_prompt = """
{your_task} \n
{final_instruction} \n
```
\n\n
Question: {question}\n
Answer:
"""
prompt = openai_prompt.format(**prompt_fields)
openai_prompt2 = """
{your_task}\n
```
\n\n
{final_instruction}
"""
prompt2 = openai_prompt2.format(**{'your_task':prompt_fields['your_task'],
'final_instruction':prompt_fields['final_instruction']})
try:
# https://platform.openai.com/docs/guides/text-generation/chat-completions-api
resp = GPTllm.get_chat_completion(prompt=prompt,
system_message=prompt_fields['you_are'],
user_message = None, #prompt_fields['question'],
temperature=0.01,
max_tokens=1500, # it's a long question...
show_response=True,
stream=False)
if resp.choices[0].finish_reason == 'stop':
return {'rewritten_question': resp.choices[0].message.content,
'changed': True, 'status': 'success'}
else:
raise Exception("LLM did not stop") # to go to the except block
except Exception:
return {'rewritten_question': query, 'changed': False, 'status': 'not success'}
if __name__ == '__main__':
main()
# streamlit run app.py --server.allowRunOnSave True