#%%
from tiktoken import get_encoding, encoding_for_model
from weaviate_interface import WeaviateClient, WhereFilter
from sentence_transformers import SentenceTransformer
from prompt_templates import question_answering_prompt_series, question_answering_system
from openai_interface import GPT_Turbo
from app_features import (convert_seconds, generate_prompt_series, search_result,
validate_token_threshold, load_content_cache, load_data,
expand_content)
from retrieval_evaluation import execute_evaluation, calc_hit_rate_scores
from llama_index.finetuning import EmbeddingQAFinetuneDataset
from weaviate_interface import WeaviateClient
from openai import BadRequestError
from reranker import ReRanker
from loguru import logger
import streamlit as st
from streamlit_option_menu import option_menu
import hydralit_components as hc
import sys
import json
import os, time, requests, re
from datetime import timedelta
import pathlib
import gdown
import tempfile
import base64
import shutil
def get_base64_of_bin_file(bin_file):
with open(bin_file, 'rb') as file:
data = file.read()
return base64.b64encode(data).decode()
from dotenv import load_dotenv, find_dotenv
load_dotenv(find_dotenv('env'), override=True)
# I use a key that I increment each time I want to change a text_input
if 'key' not in st.session_state:
st.session_state.key = 0
# key = st.session_state['key']
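# Minimal sketch of the pattern used further down (for reference only): bumping the key
# forces Streamlit to build a brand-new text_input widget, which resets the displayed text:
#     st.session_state.key += 1
#     query = st.text_input(msg, value="", key=st.session_state.key)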
if not pathlib.Path('models').exists():
os.mkdir('models')
# I should cache these things but no time left
# I put a file local.txt in my local models folder: if it is present, the app is running locally (offline), otherwise online
we_are_online = not pathlib.Path("models/local.txt").exists()
we_are_not_online = not we_are_online
golden_dataset = EmbeddingQAFinetuneDataset.from_json("data/golden_100.json")
# shutil.rmtree("models/models") # remove it - I wanted to clear the space on streamlit online
## PAGE CONFIGURATION
st.set_page_config(page_title="Ask Impact Theory",
page_icon="assets/impact-theory-logo-only.png",
layout="wide",
initial_sidebar_state="collapsed",
menu_items={'Report a bug': "https://www.extremelycoolapp.com/bug"})
image = "https://is2-ssl.mzstatic.com/image/thumb/Music122/v4/bd/34/82/bd348260-314c-5898-26c0-bef2e0388ebe/source/1200x1200bb.png"
def add_bg_from_local(image_file):
bin_str = get_base64_of_bin_file(image_file)
page_bg_img = f'''
<style>
.stApp {{
background-image: url("data:image/png;base64,{bin_str}");
background-size: 100% auto;
background-repeat: no-repeat;
background-attachment: fixed;
}}
</style>
'''
st.markdown(page_bg_img, unsafe_allow_html=True)
# COMMENT: I tried to create a dropdown menu but it's harder than it looks, so I gave up
# https://discuss.streamlit.io/t/streamlit-option-menu-is-a-simple-streamlit-component-that-allows-users-to-select-a-single-item-from-a-list-of-options-in-a-menu/20514
# not great, but it works
# selected = option_menu("About", ["Improvements","This"], #"Main Menu", ["Home", 'Settings'],
# icons=['house', 'gear'],
# menu_icon="cast",
# default_index=1)
# # Custom HTML/CSS for the banner
# base64_img = get_base64_of_bin_file("assets/it_tom_bilyeu.png")
# banner_menu_html = f"""
# <div class="banner">
# <img src= "data:image/png;base64,{base64_img}" alt="Banner Image">
# </div>
# <style>
# .banner {{
# width: 100%;
# height: auto;
# overflow: hidden;
# display: flex;
# justify-content: center;
# }}
# .banner img {{
# width: 130%;
# height: auto;
# object-fit: contain;
# }}
# </style>
# """
# st.components.v1.html(banner_menu_html)
# specify the primary menu definition
# it gives a vertical menu inside a navigation bar !!!
# menu_data = [
# {'icon': "far fa-copy", 'label':"Left End"},
# {'id':'Copy','icon':"🐙",'label':"Copy"},
# {'icon': "far fa-chart-bar", 'label':"Chart"},#no tooltip message
# {'icon': "far fa-address-book", 'label':"Book"},
# {'id':' Crazy return value 💀','icon': "💀", 'label':"Calendar"},
# {'icon': "far fa-clone", 'label':"Component"},
# {'icon': "fas fa-tachometer-alt", 'label':"Dashboard",'ttip':"I'm the Dashboard tooltip!"}, #can add a tooltip message
# {'icon': "far fa-copy", 'label':"Right End"},
# ]
# # we can override any part of the primary colors of the menu
# over_theme = {'txc_inactive': '#FFFFFF','menu_background':'red','txc_active':'yellow','option_active':'blue'}
# # over_theme = {'txc_inactive': '#FFFFFF'}
# menu_id = hc.nav_bar(menu_definition=menu_data,
# home_name='Home',
# override_theme=over_theme)
#get the id of the menu item clicked
# st.info(f"{menu_id=}")
## RERANKER
reranker = ReRanker('cross-encoder/ms-marco-MiniLM-L-6-v2')
## ENCODING --> tiktoken library
model_ids = ['gpt-3.5-turbo-16k', 'gpt-3.5-turbo-0613']
model_nameGPT = model_ids[1]
encoding = encoding_for_model(model_nameGPT)
# equivalently: encoding = get_encoding('cl100k_base')
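# Illustrative use of the encoding for token counting (not called here directly):
#     n_tokens = len(encoding.encode("some prompt text"))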
##############
data_path = './data/impact_theory_data.json'
cache_path = 'data/impact_theory_cache.parquet'
data = load_data(data_path)
cache = None # load_content_cache(cache_path)
if 'secrets' in st.secrets:
# st.write("Loading secrets from [secrets] section")
# for streamlit online or local, which uses a [secrets] section
Wapi_key = st.secrets['secrets']['WEAVIATE_API_KEY']
url = st.secrets['secrets']['WEAVIATE_ENDPOINT']
openai_api_key = st.secrets['secrets']['OPENAI_API_KEY']
hf_token = st.secrets['secrets']['LLAMA2_ENDPOINT_HF_TOKEN_chris']
hf_endpoint = st.secrets['secrets']['LLAMA2_ENDPOINT_UPLIMIT']
else:
# st.write("Loading secrets for Huggingface")
# for Huggingface (no [secrets] section)
Wapi_key = st.secrets['WEAVIATE_API_KEY']
url = st.secrets['WEAVIATE_ENDPOINT']
openai_api_key = st.secrets['OPENAI_API_KEY']
hf_token = st.secrets['LLAMA2_ENDPOINT_HF_TOKEN_chris']
hf_endpoint = st.secrets['LLAMA2_ENDPOINT_UPLIMIT']
# else:
# # if we want to use env file
# st.write("Loading secrets from environment variables")
# api_key = os.environ['WEAVIATE_API_KEY']
# url = os.environ['WEAVIATE_ENDPOINT']
# openai_api_key = os.environ['OPENAI_API_KEY']
# hf_token = os.environ['LLAMA2_ENDPOINT_HF_TOKEN_chris']
# hf_endpoint = os.environ['LLAMA2_ENDPOINT_UPLIMIT']
#%%
# model_default = 'sentence-transformers/all-mpnet-base-v2'
model_default = 'models/finetuned-all-mpnet-base-v2-300' if we_are_not_online \
else 'sentence-transformers/all-mpnet-base-v2'
available_models = ['sentence-transformers/all-mpnet-base-v2',
'sentence-transformers/all-MiniLM-L6-v2',
'models/finetuned-all-mpnet-base-v2-300',
'sentence-transformers/all-MiniLM-L12-v2']
#%%
models_urls = {'models/finetuned-all-mpnet-base-v2-300': "https://drive.google.com/drive/folders/1asJ37-AUv5nytLtH6hp6_bVV3_cZOXfj"}
def download_model_from_Gdrive(model_name_or_path, model_full_path):
print("Downloading model from Google Drive")
st.write("Downloading model from Google Drive")
assert model_name_or_path in models_urls, f"Model {model_name_or_path} not found in models_urls"
url = models_urls[model_name_or_path]
gdown.download_folder(url, output=model_full_path, quiet=False, use_cookies=False)
print("Model downloaded and saved to models folder")
# st.write("Model downloaded")
def download_model(model_name_or_path, model_full_path):
if model_name_or_path.startswith("models/"):
download_model_from_Gdrive(model_name_or_path, model_full_path)
print(f"Model {model_full_path} downloaded")
models_urls[model_name_or_path] = model_full_path
# st.sidebar.write(f"Model {model_full_path} downloaded")
elif model_name_or_path.startswith("sentence-transformers/"):
st.sidebar.write(f"Downloading Sentence Transformer model {model_name_or_path}")
model = SentenceTransformer(model_name_or_path) # HF looks into its own models folder/path
models_urls[model_name_or_path] = model_full_path
# st.sidebar.write(f"Model {model_name_or_path} downloaded")
model.save(model_full_path)
# st.sidebar.write(f"Model {model_name_or_path} saved to {model_full_path}")
# if 'modelspath' not in st.session_state:
# st.session_state['modelspath'] = None
# if st.session_state.modelspath is None:
# # let's create a temp folder on the first run
# persistent_dir = pathlib.Path("path/to/persistent_dir")
# persistent_dir.mkdir(parents=True, exist_ok=True)
# with tempfile.TemporaryDirectory() as temp_dir:
# st.session_state.modelspath = temp_dir
# print(f"Temporary directory created at {temp_dir}")
# # the temp folder disappears with the context, but not the one we've created manually
# else:
# temp_dir = st.session_state.modelspath
# print(f"Temporary directory already exists at {temp_dir}")
# # st.write(os.listdir(temp_dir))
#%%
# for streamlit online, we must download the model from google drive
# because github LFS doesn't work on forked repos
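# check_model(model_name_or_path) returns the local path under models/ :
# it reuses the model if it is already on disk, otherwise downloads it
# (Google Drive for finetuned models, Hugging Face for sentence-transformers models),
# and prunes other cached models when running online to save disk space.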
def check_model(model_name_or_path):
model_path = pathlib.Path(model_name_or_path)
model_full_path = str(pathlib.Path("models") / model_path) # this creates a models folder inside /models
model_full_path = model_full_path.replace("sentence-transformers/", "models/") # all are saved in models folder
if pathlib.Path(model_full_path).exists():
# let's use the model that's already there
print(f"Model {model_full_path} already exists")
# but delete everything else if we are online because
# streamlit online has limited space (and will shut down the app if it's full)
if we_are_online:
# st.sidebar.write(f"Model {model_full_path} already exists")
# st.sidebar.write(f"Deleting other models")
dirs = os.listdir("models/models")
# we get only the folder name, not the full path
dirs.remove(model_full_path.split('/')[-1])
for p in dirs:
dirpath = pathlib.Path("models/models") / p
if dirpath.is_dir():
shutil.rmtree(dirpath)
else:
if we_are_online:
# space issues on streamlit online, let's not leave anything behind
# and redownload the model every time
print("Deleting models/models folder")
if pathlib.Path('models/models').exists():
shutil.rmtree("models/models") # make room, if other models are there
# st.sidebar.write(f"models/models folder deleted")
download_model(model_name_or_path, model_full_path)
return model_full_path
#%% instantiate Weaviate client
def get_weaviate_client(api_key, url, model_name_or_path, openai_api_key):
client = WeaviateClient(api_key, url,
model_name_or_path=model_name_or_path,
openai_api_key=openai_api_key)
client.display_properties.append('summary')
available_classes = sorted(client.show_classes())
# st.write(f"Available classes: {available_classes}")
# st.write(f"Available classes type: {type(available_classes)}")
logger.info(available_classes)
return client, available_classes
##############
# data = load_data(data_path)
# guests list for sidebar
guest_list = sorted(list(set([d['guest'] for d in data])))
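# main(): Streamlit entry point - sidebar search settings (hybrid search, reranker, RAG, optional
# finetuning and evaluation), then query rewriting, retrieval and answer/result display.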
def main():
with st.sidebar:
# moved it to main area
# guest = st.selectbox('Select Guest',
# options=guest_list,
# index=None,
# placeholder='Select Guest')
_, center, _ = st.columns([3, 5, 3])
with center:
st.text("Search Lab")
_, center, _ = st.columns([2, 5, 3])
with center:
if we_are_online:
st.text("Running ONLINE")
# st.text("(UNSTABLE)")
else:
st.text("Running OFFLINE")
st.write("----------")
hybrid_search = st.toggle('Hybrid Search', True)
if hybrid_search:
alpha_input = st.slider(label='Alpha', min_value=0.00, max_value=1.00, value=0.40, step=0.05)
retrieval_limit = st.slider(label='Hybrid Search Results', min_value=10, max_value=300, value=10, step=10)
hybrid_filter = st.toggle('Filter Guest', True) # i.e. look only at guests' data
rerank = st.toggle('Use Reranker', True)
if rerank:
reranker_topk = st.slider(label='Reranker Top K',min_value=1, max_value=5, value=3, step=1)
else:
# needed so we don't overfill the LLM context with too many results (> context size)
# we could make it dependent on the model
reranker_topk = 3
rag_it = st.toggle('RAG it', True)
if rag_it:
st.write(f"Using LLM '{model_nameGPT}'")
llm_temperature = st.slider(label='LLM T˚', min_value=0.0, max_value=2.0, value=0.01, step=0.10 )
model_name_or_path = st.selectbox(label='Model Name:', options=available_models,
index=available_models.index(model_default),
placeholder='Select Model')
st.write("Experimental and time limited 2'")
finetune_model = st.toggle('Finetune on Modal A100 GPU', False)
if we_are_not_online or we_are_online:  # NB: always True, so the finetuning path is currently reachable both offline and online
if finetune_model:
from finetune_backend import finetune
if 'finetuned' in model_name_or_path:
st.write("Model already finetuned")
elif "models/" in model_name_or_path:
st.write("sentence-transformers models only!")
else:
# try:
if 'finetuned' in model_name_or_path:
st.write("Model already finetuned")
else:
model_path = finetune(model_name_or_path, savemodel=True, outpath='models')
st.write(f"model_path returned = {model_path}")
if model_path is not None:
if model_name_or_path.split('/')[-1] not in model_path:
st.write(model_path) # a warning from finetuning in this case
elif model_path not in available_models:
# finetuning generated a model, let's add it
available_models.append(model_path)
st.write("Model saved!")
# except Exception:
# st.write("Model not found on HF or error")
else:
st.write("Finetuning not available on Streamlit online because of space limitations")
model_name_or_path = check_model(model_name_or_path)
try:
client, available_classes = get_weaviate_client(Wapi_key, url, model_name_or_path, openai_api_key)
except Exception as e:
# Weaviate doesn't know this model, maybe we're just finetuning a model
st.sidebar.write(f"Model unknown to Weaviate")
st.stop()
start_class = 'Impact_theory_all_mpnet_base_v2_finetuned'
class_name = st.selectbox(
label='Class Name:',
options=available_classes,
index=available_classes.index(start_class),
placeholder='Select Class Name'
)
st.write("----------")
c1,c2 = st.columns([8,1])
with c1:
show_metrics = st.toggle('Show Metrics on Golden set', False)
if show_metrics:
# _, center, _ = st.columns([3, 5, 3])
# with center:
# st.text("Metrics")
with c2:
with st.spinner(''):
metrics = execute_evaluation(golden_dataset, class_name, client, alpha=alpha_input)
if show_metrics:
kw_hit_rate = metrics['kw_hit_rate']
kw_mrr = metrics['kw_mrr']
hybrid_hit_rate = metrics['hybrid_hit_rate']
vector_hit_rate = metrics['vector_hit_rate']
vector_mrr = metrics['vector_mrr']
total_misses = metrics['total_misses']
st.text(f"KW hit rate: {kw_hit_rate}")
st.text(f"Vector hit rate: {vector_hit_rate}")
st.text(f"Hybrid hit rate: {hybrid_hit_rate}")
st.text(f"Hybrid MRR: {vector_mrr}")
st.text(f"Total misses: {total_misses}")
st.write("----------")
st.title("Chat with the Impact Theory podcasts!")
# st.image('./assets/impact-theory-logo.png', width=400)
st.image('assets/it_tom_bilyeu.png', use_column_width=True)
# st.subheader(f"Chat with the Impact Theory podcast: ")
st.write('\n')
# st.stop()
st.write("\u21D0 Open the sidebar to change Search settings \n ") # https://home.unicode.org also 21E0, 21B0 B2 D0
if not hybrid_search:
st.stop()
col1, _ = st.columns([3,7])
with col1:
guest = st.selectbox('Select A Guest',
options=guest_list,
index=None,
placeholder='Select Guest')
col1, col2 = st.columns([7,3])
with col1:
if guest is None:
msg = f'Select a guest before asking your question:'
else:
msg = f'Enter your question about {guest}:'
textbox = st.empty()
# best solution I found to be able to change the text inside a text_input box afterwards, using a key
query = textbox.text_input(msg,
value="",
placeholder="You can refer to the guest with PRONOUNS",
key=st.session_state.key)
# st.write(f"Guest = {guest}")
# st.write(f"key = {st.session_state.key}")
st.write('\n\n\n\n\n')
reworded_query = {'changed': False, 'status': 'error'} # at start, the query is empty
valid_response = [] # at start, the query is empty, so prevent the search
if query:
if guest is None:
st.session_state.key += 1
query = textbox.text_input(msg,
value="",
placeholder="YOU MUST SELECT A GUEST BEFORE ASKING A QUESTION",
key=st.session_state.key)
# st.write(f"key = {st.session_state.key}")
st.stop()
else:
# st.write(f'It looks like you selected {guest} as a filter (It is ignored for now).')
with col2:
# let's add a nice pulse bar while generating the response
with hc.HyLoader('', hc.Loaders.pulse_bars, primary_color= 'red', height=50): #"#0e404d" for image green
# with st.spinner('Generating Response...'):
with col1:
# let's use Llama2 here
reworded_query = reword_query(query, guest,
model_name='llama2-13b-chat')
new_query = reworded_query['rewritten_question']
if not any(part in new_query for part in guest.split()):  # also works for single-word guest names
# if the guest name is not in the rewritten question, we add it
new_query = f"About {guest}, " + new_query
query = new_query
# we can arrive here only if a guest was selected
where_filter = WhereFilter(path=['guest'], operator='Equal', valueText=guest).todict() \
if hybrid_filter else None
hybrid_response = client.hybrid_search(query,
class_name,
# properties=['content'], #['title', 'summary', 'content'],
alpha=alpha_input,
display_properties=client.display_properties,
where_filter=where_filter,
limit=retrieval_limit)
response = hybrid_response
if rerank:
# rerank results with cross encoder
ranked_response = reranker.rerank(response, query,
apply_sigmoid=True, # score between 0 and 1
top_k=reranker_topk)
logger.info(ranked_response)
expanded_response = expand_content(ranked_response, cache,
content_key='doc_id',
create_new_list=True)
response = expanded_response
# make sure token count < threshold
token_threshold = 8000 if model_nameGPT == model_ids[0] else 3500
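# (assumed rationale: ~8k prompt tokens fit the 16k-context model, ~3.5k the 4k-context
#  gpt-3.5-turbo-0613, leaving room for the generated answer)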
valid_response = validate_token_threshold(response,
question_answering_prompt_series,
query=query,
tokenizer=encoding,  # tiktoken encoding from the ENCODING section above,
token_threshold=token_threshold,
verbose=True)
# st.write(f"Number of results: {len(valid_response)}")
# we leave col1 here to use the full page width, so the query needs to be re-tested
if query is not None and reworded_query['status'] != 'error':
show_query = st.toggle('Show rewritten query', False)
if show_query: # or reworded_query['changed']:
st.write(f"Rewritten query: {query}")
# creates container for LLM response to position it above search results
chat_container, response_box = [], st.empty()
# # RAG time !! execute chat call to LLM
if rag_it:
# st.subheader("Response from Impact Theory (context)")
# will appear under the answer, moved it into the response box
# generate LLM prompt
prompt = generate_prompt_series(query=query, results=valid_response)
GPTllm = GPT_Turbo(model=model_nameGPT,
api_key=openai_api_key)
try:
# inserts chat stream from LLM
for resp in GPTllm.get_chat_completion(prompt=prompt,
temperature=llm_temperature,
max_tokens=350,
show_response=True,
stream=True):
with response_box:
content = resp.choices[0].delta.content
if content:
chat_container.append(content)
result = "".join(chat_container).strip()
response_box.markdown(f"### Response from Impact Theory (RAG):\n\n{result}")
except BadRequestError as e:
logger.info('Making request with smaller context')
valid_response = validate_token_threshold(response,
question_answering_prompt_series,
query=query,
tokenizer=encoding,
token_threshold=3500,
verbose=True)
# if reranker is off, we may receive a LOT of responses
# so we must reduce the context size manually
if not rerank:
valid_response = valid_response[:reranker_topk]
prompt = generate_prompt_series(query=query, results=valid_response)
for resp in GPTllm.get_chat_completion(prompt=prompt,
temperature=llm_temperature,
max_tokens=350, # expand for more verbose answers
show_response=True,
stream=True):
try:
# inserts chat stream from LLM
with response_box:
content = resp.choices[0].delta.content
if content:
chat_container.append(content)
result = "".join(chat_container).strip()
response_box.markdown(f"### Response from Impact Theory (RAG):\n\n{result}")
except Exception as e:
print(e)
st.markdown("----")
st.subheader("Search Results")
for i, hit in enumerate(valid_response):
col1, col2 = st.columns([7, 3], gap='large')
image = hit['thumbnail_url'] # get thumbnail_url
episode_url = hit['episode_url'] # get episode_url
title = hit["title"] # get title
show_length = hit["length"] # get length
time_string = str(timedelta(seconds=show_length)) # convert show_length to readable time string
with col1:
st.write(search_result(i=i,
url=episode_url,
guest=hit['guest'],
title=title,
content='',
length=time_string),
unsafe_allow_html=True)
st.write('\n\n')
with col2:
#st.write(f"<a href={episode_url} <img src={image} width='200'></a>",
# unsafe_allow_html=True)
#st.markdown(f"[![{title}]({image})]({episode_url})")
# st.markdown(f'<a href="{episode_url}">'
# f'<img src={image} '
# f'caption={title.split("|")[0]} width=200, use_column_width=False />'
# f'</a>',
# unsafe_allow_html=True)
st.image(image, caption=title.split('|')[0], width=200, use_column_width=False)
# let's use all width for the content
st.write(hit['content'])
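# get_answer(): single-shot (non-streaming) RAG call; not called elsewhere in this file, kept for reference.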
def get_answer(query, valid_response, GPTllm):
# generate LLM prompt
prompt = generate_prompt_series(query=query,
results=valid_response)
return GPTllm.get_chat_completion(prompt=prompt,
system_message='answer this question based on the podcast material',
temperature=0,
max_tokens=500,
stream=False,
show_response=False)
def reword_query(query, guest, model_name='llama2-13b-chat', response_processing=True):
""" Asks LLM to rewrite the query when the guest name is missing.
Args:
query (str): user query
guest (str): guest name
model_name (str, optional): name of a LLM model to be used
"""
# tags = {'llama2-13b-chat': {'start': '<s>', 'end': '</s>', 'instruction': '[INST]', 'system': '[SYS]'},
# 'gpt-3.5-turbo-0613': {'start': '<|startoftext|>', 'end': '', 'instruction': "```", 'system': ```}}
prompt_fields = {
"you_are":f"You are an expert in linguistics and semantics, analyzing the question asked by a user to a vector search system, \
and making sure that the question is well formulated and that the system can understand it.",
"your_task":f"Your task is to detect if the name of the guest ({guest}) is mentioned in the user's question, \
and if that is not the case, rewrite the question using the guest name, \
without changing the meaning of the question. \
Most of the time, the user will have used a pronoun to designate the guest, in which case, \
simply replace the pronoun with the guest name.",
"question":f"If the user mentions the guest name, ie {query}, just return his question as is. \
If the user does not mention the guest name, rewrite the question using the guest name.",
"final_instruction":f"Only regerate the requested rewritten question or the original, WITHOUT ANY COMMENT OR REPHRASING. \
Your answer must be as close as possible to the original question, \
and exactly identical, word for word, if the user mentions the guest name, i.e. {guest}.",
}
# prompt created by chatGPT :-)
# and Llama still outputs the original question and precedes the answer with 'rewritten question'
prompt_fields2 = {
"you_are": (
"You are an expert in linguistics and semantics. Your role is to analyze questions asked to a vector search system."
),
"your_task": (
f"Detect if the guest's FULL name, {guest}, is mentioned in the user's question. "
"If not, rewrite the question by replacing pronouns or indirect references with the guest's name." \
"If yes, return the original question as is, without any change at all, not even punctuation,"
"except a question mark that you MUST add if it's missing."
),
"question": (
f"Original question: '{query}'. "
"Rewrite this question to include the guest's FULL name if it's not already mentioned."
"The Only thing you can and MUST add is a question mark if it's missing."
),
"final_instruction": (
"Create a rewritten question or keep the original question as is. "
"Do not include any labels, titles, or additional text before or after the question."
"The Only thing you can and MUST add is a question mark if it's missing."
"Return a json object, with the key 'original_question' for the original question, \
and 'rewritten_question' for the rewritten question \
and 'changed' being True if you changed the answer, otherwise False."
),
}
if model_name == 'llama2-13b-chat':
# special tags are used:
# `<s>` - start prompt tag
# `[INST], [/INST]` - Opening and closing model instruction tags
# `<<SYS>>, <</SYS>>` - Opening and closing system prompt tags
llama_prompt = """
<s>[INST] <<SYS>>
{you_are}
<</SYS>>
{your_task}\n
```
\n\n
Question: {question}\n
{final_instruction} [/INST]
Answer:
"""
prompt = llama_prompt.format(**prompt_fields2)
headers = {"Authorization": f"Bearer {hf_token}",
"Content-Type": "application/json",}
json_body = {
"inputs": prompt,
"parameters": {"max_new_tokens":400,
"repetition_penalty": 1.0,
"temperature":0.01}
}
response = requests.request("POST", hf_endpoint, headers=headers, data=json.dumps(json_body))
response = json.loads(response.content.decode("utf-8"))
# ^ this only decodes the endpoint's JSON envelope; the generated text itself may be badly formatted, so we parse it ourselves below
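# typical response shapes handled below: {"error": "..."} on failure,
# [{"generated_text": "..."}] on success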
if isinstance(response, dict) and 'error' in response:
print("Found error")
print(response)
# return {'error': response['error'], 'rewritten_question': query, 'changed': False, 'status': 'error'}
# I test this here, otherwise the error message ends up in col 1 or col 2, which are too narrow
# if reworded_query['status'] == 'error':
# st.write(f"Error in LLM response: 'error':{reworded_query['error']}")
# st.write("The LLM could not connect to the server. Please try again later.")
# st.stop()
return reword_query(query, guest, model_name='gpt-3.5-turbo-0613')
if response_processing:
if isinstance(response, list) and isinstance(response[0], dict) and 'generated_text' in response[0]:
print("Found generated text")
response0 = response[0]['generated_text']
pattern = r'\"(\w+)\":\s*(\".*?\"|\w+)'
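# e.g. (hypothetical LLM output) '{"original_question": "Who is he?", "rewritten_question": "Who is <guest>?", "changed": true}'
# findall would yield [('original_question', '"Who is he?"'), ('rewritten_question', '"Who is <guest>?"'), ('changed', 'true')]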
matches = re.findall(pattern, response0)
# let's build a dictionary
result = {key: json.loads(value) if value.startswith("\"") else value for key, value in matches}
return result | {'status': 'success'}
else:
print("Found no answer")
return reword_query(query, guest, model_name='gpt-3.5-turbo-0613')
# return {'original_question': query, 'rewritten_question': query, 'changed': False, 'status': 'no properly formatted answer' }
else:
return response
# return response
# assert 'error' not in response, f"Error in LLM response: {response['error']}"
# assert 'generated_text' in response[0], f"Error in LLM response: {response}, no 'generated_text' field"
# # let's extract the rewritten question
# return response[0]['generated_text'] .split("Rewritten question: '")[-1][:-1]
else:
# assume openai
model_ids = ['gpt-3.5-turbo-16k', 'gpt-3.5-turbo-0613']
model_name = model_ids[1]
GPTllm = GPT_Turbo(model=model_name,
api_key=openai_api_key)
openai_prompt = """
{your_task}\n
```
\n\n
Question: {question}\n
{final_instruction}
Answer:
"""
prompt = openai_prompt.format(**prompt_fields)
try:
resp = GPTllm.get_chat_completion(prompt=prompt,  # the formatted prompt, not the raw template
system_message=prompt_fields['you_are'],
temperature=0.01,
max_tokens=1500, # it's a question...
show_response=True,
stream=False)
# non-streaming call: the content is in message, not delta
return {'rewritten_question': resp.choices[0].message.content,
'changed': True, 'status': 'success'}
except Exception:
return {'rewritten_question': query, 'changed': False, 'status': 'not success'}
if __name__ == '__main__':
main()
# %%
|