Spaces:
Runtime error
Runtime error
File size: 10,363 Bytes
0fac726 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 |
import logging
import os
import re
import streamlit as st
from streamlit.logger import get_logger
from knowledgebase import Knowledgebase
from utils.constants import (
AssistantType,
OPENAI_KNOWLEDGEBASE_KEY,
HUGGINGFACEHUB_API_TOKEN_KEY,
HF_KNOWLEDGEBASE_KEY,
SOURCES_TAG,
ANSWER_TAG,
NONE_TAG,
EMPTY_TAG,
MESSAGE_HISTORY_TAG,
TEXT_TAG,
USER_TAG,
ASSISTANT_TAG,
FROM_TAG,
IN_PROGRESS_TAG,
QUERY_INPUT_TAG,
VALID_TOKEN_TAG,
StNotificationType,
API_KEY_TAG,
ASSISTANT_TYPE_TAG,
ASSISTANT_AVATAR,
USER_AVATAR,
EmbeddingType,
APIKeyType,
)
from utils.llm import validate_api_token
# Module-level logger for this page, created with Streamlit's `get_logger`
# helper (imported from `streamlit.logger`).
logger = get_logger(__name__)
def retrieve_answer(query: str) -> str:
    """Answer *query* from the configured knowledgebase and format it for chat.

    The assistant type, API keys and knowledgebase name are read from
    ``st.session_state`` (populated by ``app`` and ``verify_token``).

    Args:
        query: The user's question.

    Returns:
        A markdown string with the answer, the sources (only when the
        knowledgebase returned usable ones) and the query cost. On any
        failure a generic, user-facing error message is returned instead and
        the exception is logged with its traceback.
    """
    try:
        knowledgebase = Knowledgebase(
            assistant_type=st.session_state.selected_assistant_type,
            # Embeddings are always computed with Hugging Face models,
            # regardless of which assistant answers the question.
            embedding_type=EmbeddingType.HUGGINGFACE,
            assistant_api_key=st.session_state.verified_api_key,
            embedding_api_key=st.session_state.embedding_api_key,
            knowledgebase_name=st.session_state.knowledgebase_name,
        )
        answer, metadata = knowledgebase.query_knowledgebase(query=query)

        # Missing/empty cost metadata is displayed as zero.
        cost = metadata or "$0.00"

        # Drop a dangling trailing "SOURCES:" label left in the answer text;
        # the actual sources (if any) are appended separately below.
        final_answer = re.sub(
            r"\bSOURCES:\s*$", "", str(answer[ANSWER_TAG]).strip()
        ).strip()
        logger.info("final answer: %s", final_answer)

        sources = answer.get(SOURCES_TAG, None)
        if sources not in (None, NONE_TAG, EMPTY_TAG):
            return f"{final_answer}\n\nSources:\n{sources}\n\nCost (USD):\n`{cost}`"
        return f"{final_answer}\n\nCost (USD):\n`{cost}`"
    except Exception:
        # UI boundary: the failure may be a bad key, a rate limit, a network
        # error or a malformed answer, so log the real cause and show a
        # generic message rather than crashing the page.
        logger.exception("Failed to retrieve an answer from the knowledgebase.")
        return (
            "Could not retrieve the answer. This could be due to "
            "various reasons such as Invalid API Tokens or hitting "
            "the Rate limit enforced by LLM vendors."
        )
def show_chat_ui():
    """Render the chat input, process a new query and show the conversation.

    A submitted query is appended to the message history kept in
    ``st.session_state`` together with the answer from ``retrieve_answer``,
    and then the whole history is rendered as chat bubbles.
    """
    history = st.session_state.get(MESSAGE_HISTORY_TAG, None)

    # Intro banners are only shown before the first message.
    if (
        st.session_state.selected_assistant_type == AssistantType.HUGGINGFACE
        and not history
    ):
        show_notification_banner_ui(
            notification_type=StNotificationType.WARNING,
            notification="π€π€π½ HuggingFace assistant is not always guaranteed "
            "to return a valid response and often exceeds the "
            "maximum token limit. Use the OpenAI assistant for "
            "more reliable responses.",
        )
    if not history:
        st.subheader("Let's start chatting, shall we?")

    query = st.chat_input(
        "Ask me about ShoutOUT AI stuff",
        key=QUERY_INPUT_TAG,
        disabled=bool(st.session_state.get(IN_PROGRESS_TAG, False)),
    )

    if query:
        # Read and write the same session keys (via the shared constants) so
        # the flag and the history cannot drift apart.
        st.session_state[IN_PROGRESS_TAG] = True
        try:
            messages = st.session_state.get(MESSAGE_HISTORY_TAG) or []
            messages.append({TEXT_TAG: query, FROM_TAG: USER_TAG})
            st.session_state[MESSAGE_HISTORY_TAG] = messages
            answer = retrieve_answer(query=query)
            messages.append({TEXT_TAG: answer, FROM_TAG: ASSISTANT_TAG})
            st.session_state[MESSAGE_HISTORY_TAG] = messages
        finally:
            # Always release the input, even if answering fails unexpectedly.
            st.session_state[IN_PROGRESS_TAG] = False

    avatars = {USER_TAG: USER_AVATAR, ASSISTANT_TAG: ASSISTANT_AVATAR}
    for message in st.session_state.get(MESSAGE_HISTORY_TAG) or []:
        sender = message.get(FROM_TAG)
        if sender in avatars:
            with st.chat_message(sender, avatar=avatars[sender]):
                st.write(message.get(TEXT_TAG))
def show_hf_chat_ui():
    """Validate the Hugging Face Hub token and, if it is valid, show the chat UI.

    The script run is halted with ``st.stop()`` when no validation result is
    available or when validation failed.
    """
    st.sidebar.info(
        "π€ You are using the Hugging Face Hub models for the QA task and "
        "performance might not be as good as proprietary LLMs."
    )
    verify_token()

    validated_token = st.session_state.get(VALID_TOKEN_TAG, None)
    if validated_token is None:
        # Defensive: no validation result is available, nothing to show.
        st.stop()
    if not validated_token:
        st.sidebar.error("❌ Failed to get connected to the HuggingFace Hub")
        show_notification_banner_ui(
            notification_type=StNotificationType.INFO,
            notification="Failed to get connected to the HuggingFace Hub",
        )
        st.stop()

    st.sidebar.success("✅ Connected to the HF Hub")
    show_chat_ui()
def show_openai_chat_ui():
    """Collect and validate an OpenAI API key, then show the chat UI.

    The key is entered in a password-type sidebar text input and validated
    with ``verify_token`` once non-empty. The script run is halted with
    ``st.stop()`` when no validation result is available (e.g. no key entered
    yet) or when the key is rejected.
    """
    st.sidebar.info(
        "π To get started, enter your OpenAI API key. Once that's done, "
        "you can start asking questions. Oh! one more thing, we take "
        "security seriously and we are NOT storing the API keys in any manner, "
        "so you're safe. Just revoke it after usage to make sure nothing "
        "unexpected happens."
    )
    api_key = st.sidebar.text_input(
        "Enter the OpenAI API Key",
        key=API_KEY_TAG,
        label_visibility="hidden",
        placeholder="OpenAI API Key",
        type="password",
    )
    # Only hit the validation endpoint once the user has typed something.
    if api_key:
        verify_token()

    validated_token = st.session_state.get(VALID_TOKEN_TAG, None)
    if validated_token is None:
        st.sidebar.info("ποΈ Provide the API Key")
        st.stop()
    if not validated_token:
        st.sidebar.error("❌ API Key you provided is invalid")
        show_notification_banner_ui(
            notification_type=StNotificationType.INFO,
            notification="Please provide a valid OpenAI API Key",
        )
        st.stop()

    st.sidebar.success("✅ Token Validated!")
    show_chat_ui()
def show_notification_banner_ui(
    notification_type: StNotificationType, notification: str
):
    """Display *notification* in the main area styled by *notification_type*.

    INFO is rendered with ``st.info``, WARNING with ``st.warning`` and ERROR
    with ``st.error``. Any other value renders nothing.
    """
    # Ordered (type, renderer) pairs; the first type comparing equal wins.
    renderers = (
        (StNotificationType.INFO, st.info),
        (StNotificationType.WARNING, st.warning),
        (StNotificationType.ERROR, st.error),
    )
    for candidate, render in renderers:
        if notification_type == candidate:
            render(notification)
            break
def _mask_secret(secret):
    """Return a log-safe form of *secret* that reveals at most its last 4 chars."""
    if not secret:
        return "<not set>"
    if len(secret) <= 8:
        return "***"
    return f"***{secret[-4:]}"


def verify_token():
    """Validate the assistant and embedding API keys for the selected assistant.

    The assistant key comes from the sidebar input (``API_KEY_TAG``) for the
    OpenAI assistant, or from the ``HUGGINGFACEHUB_API_TOKEN_KEY`` environment
    variable otherwise. The embedding key is always that same environment
    variable. A ``.env`` file is loaded first via ``python-dotenv``.

    Results are written to ``st.session_state``:
        * ``valid_token``: True when both keys validate, False otherwise.
        * on success: ``verified_api_key``, ``embedding_api_key`` and
          ``knowledgebase_name`` are stored for later queries.
        * on failure: ``token_err`` holds the validation error message(s).
    """
    from dotenv import load_dotenv

    load_dotenv()
    embedding_api_key = os.getenv(HUGGINGFACEHUB_API_TOKEN_KEY, None)

    if st.session_state.selected_assistant_type == AssistantType.OPENAI:
        assistant_api_key = st.session_state.get(API_KEY_TAG, None)
        assistant_api_key_type = APIKeyType.OPENAI
        knowledgebase_name = os.environ.get(OPENAI_KNOWLEDGEBASE_KEY, None)
    else:
        assistant_api_key = os.getenv(HUGGINGFACEHUB_API_TOKEN_KEY, None)
        assistant_api_key_type = APIKeyType.HUGGINGFACE
        knowledgebase_name = os.environ.get(HF_KNOWLEDGEBASE_KEY, None)

    # Never write the raw (user-supplied) key to the logs.
    logger.info(
        "The API key for the current st session: %s\n"
        "The Knowledgebase for the current st session: %s",
        _mask_secret(assistant_api_key),
        knowledgebase_name,
    )

    assistant_valid, assistant_err = validate_api_token(
        api_key_type=assistant_api_key_type,
        api_key=assistant_api_key,
    )
    embedding_valid, embedding_err = validate_api_token(
        api_key_type=APIKeyType.HUGGINGFACE,
        api_key=embedding_api_key,
    )

    if assistant_valid and embedding_valid:
        st.session_state.valid_token = True
        st.session_state.verified_api_key = assistant_api_key
        st.session_state.embedding_api_key = embedding_api_key
        st.session_state.knowledgebase_name = knowledgebase_name
        return

    # At least one key is invalid: report whichever error(s) occurred.
    st.session_state.valid_token = False
    if not assistant_valid and not embedding_valid:
        st.session_state.token_err = f"{assistant_err}\n{embedding_err}"
    elif not assistant_valid:
        st.session_state.token_err = assistant_err
    else:
        st.session_state.token_err = embedding_err
def app():
    """Render the page: sidebar assistant selector, header and the chat UI.

    On every script run the selected assistant type is derived from the
    sidebar selectbox and any previously validated key / knowledgebase name
    is cleared, so credentials are checked again by the chat UI functions.
    """
    # sidebar
    st.sidebar.image(
        "https://thisisishara.com/res/images/favicon/android-chrome-192x192.png",
        width=80,
    )
    selected = st.sidebar.selectbox(
        "Assistant Type",
        ["OpenAI", "Hugging Face"],
        key=ASSISTANT_TYPE_TAG,
        placeholder="Select Assistant Type",
    )
    if selected:
        # Use the widget's return value rather than a hard-coded session key
        # so this keeps working whatever ASSISTANT_TYPE_TAG is set to.
        if str(selected).lower() == AssistantType.OPENAI.value:
            st.session_state.selected_assistant_type = AssistantType.OPENAI
        else:
            st.session_state.selected_assistant_type = AssistantType.HUGGINGFACE
        st.session_state.valid_token = None
        st.session_state.verified_api_key = None
        st.session_state.knowledgebase_name = None

    # main section
    st.header("LLM Website QA Demo")
    st.caption("β‘ Powered by :blue[LangChain], :green[OpenAI] & :green[Hugging Face]")

    # .get() so an unset selection falls through to the hint banner instead
    # of raising AttributeError.
    assistant_type = st.session_state.get("selected_assistant_type", None)
    if assistant_type == AssistantType.OPENAI:
        show_openai_chat_ui()
    elif assistant_type == AssistantType.HUGGINGFACE:
        show_hf_chat_ui()
    else:
        show_notification_banner_ui(
            notification_type=StNotificationType.INFO,
            notification="Please select an assistant type to get started!",
        )
if __name__ == "__main__":
    st.set_page_config(
        page_title="Website QA powered by LangChain & LLMs",
        page_icon="https://thisisishara.com/res/images/favicon/android-chrome-192x192.png",
        layout="wide",
        initial_sidebar_state="expanded",
    )
    # Page CSS: only the top decoration bar is restyled. The menu/footer
    # hiding rules were disabled with "#", which is not a CSS comment (the
    # rules were only dropped as invalid); they are now proper /* */ comments.
    hide_streamlit_style = """
    <style>
    /* #MainMenu {visibility: hidden;} */
    /* footer {visibility: hidden;} */
    [data-testid="stDecoration"] {background: linear-gradient(to right, #9EE51A, #208BBC) !important;}
    </style>
    """
    st.markdown(hide_streamlit_style, unsafe_allow_html=True)
    # run the app
    app()
|