File size: 10,363 Bytes
0fac726
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
import logging
import os
import re

import streamlit as st
from streamlit.logger import get_logger

from knowledgebase import Knowledgebase
from utils.constants import (
    AssistantType,
    OPENAI_KNOWLEDGEBASE_KEY,
    HUGGINGFACEHUB_API_TOKEN_KEY,
    HF_KNOWLEDGEBASE_KEY,
    SOURCES_TAG,
    ANSWER_TAG,
    NONE_TAG,
    EMPTY_TAG,
    MESSAGE_HISTORY_TAG,
    TEXT_TAG,
    USER_TAG,
    ASSISTANT_TAG,
    FROM_TAG,
    IN_PROGRESS_TAG,
    QUERY_INPUT_TAG,
    VALID_TOKEN_TAG,
    StNotificationType,
    API_KEY_TAG,
    ASSISTANT_TYPE_TAG,
    ASSISTANT_AVATAR,
    USER_AVATAR,
    EmbeddingType,
    APIKeyType,
)
from utils.llm import validate_api_token

# initialize a logger
logger = get_logger(__name__)


def retrieve_answer(query: str) -> str:
    """Query the knowledgebase and format the reply shown in the chat.

    The assistant type, both API keys and the knowledgebase name are read
    from ``st.session_state``; a ``Knowledgebase`` is built from them and
    asked ``query``.

    Args:
        query: The question typed by the user.

    Returns:
        The answer text, followed by a "Sources" section when the answer
        carries usable sources, followed by a "Cost (USD)" section. If
        anything raises along the way, a generic failure message is returned
        instead (the exception is logged, never propagated).
    """
    try:
        assistant_type = st.session_state.selected_assistant_type
        assistant_api_key = st.session_state.verified_api_key
        embedding_api_key = st.session_state.embedding_api_key
        knowledgebase_name = st.session_state.knowledgebase_name

        knowledgebase = Knowledgebase(
            assistant_type=assistant_type,
            # Embeddings always use the Hugging Face backend, whichever
            # assistant answers the question.
            embedding_type=EmbeddingType.HUGGINGFACE,
            assistant_api_key=assistant_api_key,
            embedding_api_key=embedding_api_key,
            knowledgebase_name=knowledgebase_name,
        )
        answer, metadata = knowledgebase.query_knowledgebase(query=query)
        if not metadata:
            metadata = "$0.00"

        # Drop a dangling "SOURCES:" marker left at the very end of the answer.
        final_answer = re.sub(
            r"\bSOURCES:[\n\s]*$", "", str(answer[ANSWER_TAG]).strip()
        ).strip()
        logger.info(f"final answer: {final_answer}")

        sections = [final_answer]
        sources = answer.get(SOURCES_TAG, None)
        if sources not in (None, NONE_TAG, EMPTY_TAG):
            sections.append(f"Sources:\n{sources}")
        # Same cost label whether or not sources are present.
        sections.append(f"Cost (USD):\n`{metadata}`")
        return "\n\n".join(sections)
    except Exception as e:
        # Any failure lands here (missing session values, vendor errors,
        # malformed answers), so the log message must not blame the API key.
        logger.exception(f"Could not retrieve the answer. {e}")
        return (
            "Could not retrieve the answer. This could be due to "
            "various reasons such as Invalid API Tokens or hitting "
            "the Rate limit enforced by LLM vendors."
        )


def show_chat_ui():
    """Render the chat input, answer a new question and replay the history.

    While the message history is empty, a subheader invites the user to start
    chatting; it is preceded by a reliability warning when the Hugging Face
    assistant is selected. A submitted question and the reply produced by
    ``retrieve_answer`` are appended to the history kept in
    ``st.session_state``, after which the whole history is drawn as chat
    messages.
    """
    state = st.session_state
    no_history = not state.get(MESSAGE_HISTORY_TAG, None)

    if state.selected_assistant_type == AssistantType.HUGGINGFACE and no_history:
        show_notification_banner_ui(
            notification_type=StNotificationType.WARNING,
            notification="🤗🤝🏽 HuggingFace assistant is not always guaranteed "
            "to return a valid response and often exceeds the "
            "maximum token limit. Use the OpenAI assistant for "
            "more reliable responses.",
        )

    if no_history:
        st.subheader("Let's start chatting, shall we?")

    # One widget call; it is only disabled while a query is being answered.
    query = st.chat_input(
        "Ask me about ShoutOUT AI stuff",
        key=QUERY_INPUT_TAG,
        disabled=bool(state.get(IN_PROGRESS_TAG, False)),
    )

    if query:
        # NOTE(review): the flags are written via attribute names but read via
        # IN_PROGRESS_TAG / MESSAGE_HISTORY_TAG - presumably those constants
        # equal "in_progress" / "message_history"; confirm in utils.constants.
        state.in_progress = True
        try:
            current_messages = state.get(MESSAGE_HISTORY_TAG, [])
            current_messages.append({TEXT_TAG: query, FROM_TAG: USER_TAG})
            state.message_history = current_messages
            answer = retrieve_answer(query=query)
            current_messages.append({TEXT_TAG: answer, FROM_TAG: ASSISTANT_TAG})
            state.message_history = current_messages
        finally:
            # Always release the flag so an interrupted run cannot leave the
            # chat input disabled for the rest of the session.
            state.in_progress = False

    messages = state.get(MESSAGE_HISTORY_TAG, None)
    if messages:
        roles = ((USER_TAG, USER_AVATAR), (ASSISTANT_TAG, ASSISTANT_AVATAR))
        for message in messages:
            sender = message.get(FROM_TAG)
            for role, avatar in roles:
                if sender == role:
                    with st.chat_message(role, avatar=avatar):
                        st.write(message.get(TEXT_TAG))


def show_hf_chat_ui():
    """Render the Hugging Face flavour of the app.

    Shows a sidebar hint, runs ``verify_token`` and inspects the
    ``VALID_TOKEN_TAG`` entry of ``st.session_state``: the script is stopped
    when the entry is missing/``None`` or falsy (with an error shown in the
    latter case); otherwise a success note is shown and the chat UI is drawn.
    """
    st.sidebar.info(
        "🤗 You are using the Hugging Face Hub models for the QA task and "
        "performance might not be as good as proprietary LLMs."
    )

    verify_token()
    validated_token = st.session_state.get(VALID_TOKEN_TAG, None)
    if validated_token is None:
        # No validation result is available; halt without a message.
        st.stop()
    if not validated_token:
        st.sidebar.error("❌ Failed to get connected to the HuggingFace Hub")
        show_notification_banner_ui(
            notification_type=StNotificationType.INFO,
            notification="Failed to get connected to the HuggingFace Hub",
        )
        st.stop()

    st.sidebar.success("✅ Connected to the HF Hub")
    show_chat_ui()


def show_openai_chat_ui():
    """Render the OpenAI flavour of the app.

    Shows a sidebar hint and a password field for the OpenAI API key. When
    the field holds a value, ``verify_token`` is run. The ``VALID_TOKEN_TAG``
    entry of ``st.session_state`` then decides what happens: the script is
    stopped with a prompt when it is missing/``None``, stopped with an error
    when it is falsy, and otherwise the chat UI is drawn.
    """
    st.sidebar.info(
        "🚀 To get started, enter your OpenAI API key. Once that's done, "
        "you can start asking questions. Oh! one more thing, we take "
        "security seriously and we are NOT storing the API keys in any manner, "
        "so you're safe. Just revoke it after usage to make sure nothing "
        "unexpected happens."
    )
    api_key_entered = st.sidebar.text_input(
        "Enter the OpenAI API Key",
        key=API_KEY_TAG,
        label_visibility="hidden",
        placeholder="OpenAI API Key",
        type="password",
    )
    if api_key_entered:
        verify_token()

    validated_token = st.session_state.get(VALID_TOKEN_TAG, None)
    if validated_token is None:
        st.sidebar.info("🗝️ Provide the API Key")
        st.stop()
    if not validated_token:
        st.sidebar.error("❌ API Key you provided is invalid")
        show_notification_banner_ui(
            notification_type=StNotificationType.INFO,
            notification="Please provide a valid OpenAI API Key",
        )
        st.stop()

    st.sidebar.success("✅ Token Validated!")
    show_chat_ui()


def show_notification_banner_ui(
    notification_type: StNotificationType, notification: str
):
    """Show ``notification`` in the main area as an info, warning or error box.

    Args:
        notification_type: Selects which Streamlit banner is used.
        notification: The text to display.

    Nothing is rendered when ``notification_type`` matches none of the three
    handled members.
    """
    banners = (
        (StNotificationType.INFO, st.info),
        (StNotificationType.WARNING, st.warning),
        (StNotificationType.ERROR, st.error),
    )
    for banner_type, render in banners:
        if notification_type == banner_type:
            render(notification)
            return


def verify_token():
    """Validate the assistant and embedding API keys and record the outcome.

    Environment variables are loaded via ``load_dotenv``. The embedding key is
    always the Hugging Face Hub token from the environment. The assistant key
    is the value entered in the UI for the OpenAI assistant, and the same
    Hugging Face Hub token otherwise. Both keys go through
    ``validate_api_token``.

    Results are written to ``st.session_state``:
        * both valid: ``valid_token`` is ``True`` and ``verified_api_key``,
          ``embedding_api_key`` and ``knowledgebase_name`` are stored;
        * otherwise: ``valid_token`` is ``False`` and ``token_err`` holds the
          error of the failing key (both errors joined by a newline when both
          keys fail).
    """
    from dotenv import load_dotenv

    load_dotenv()

    embedding_api_key = os.getenv(HUGGINGFACEHUB_API_TOKEN_KEY, None)
    if st.session_state.selected_assistant_type == AssistantType.OPENAI:
        assistant_api_key = st.session_state.get(API_KEY_TAG, None)
        assistant_api_key_type = APIKeyType.OPENAI
        knowledgebase_name = os.getenv(OPENAI_KNOWLEDGEBASE_KEY, None)
    else:
        assistant_api_key = embedding_api_key
        assistant_api_key_type = APIKeyType.HUGGINGFACE
        knowledgebase_name = os.getenv(HF_KNOWLEDGEBASE_KEY, None)

    # Never write the secret itself to the log; only note whether one exists.
    logger.info(
        "An API key is set for the current st session: %s\n"
        "The Knowledgebase for the current st session: %s",
        bool(assistant_api_key),
        knowledgebase_name,
    )

    assistant_valid, assistant_err = validate_api_token(
        api_key_type=assistant_api_key_type,
        api_key=assistant_api_key,
    )
    embedding_valid, embedding_err = validate_api_token(
        api_key_type=APIKeyType.HUGGINGFACE,
        api_key=embedding_api_key,
    )

    if assistant_valid and embedding_valid:
        st.session_state.valid_token = True
        st.session_state.verified_api_key = assistant_api_key
        st.session_state.embedding_api_key = embedding_api_key
        st.session_state.knowledgebase_name = knowledgebase_name
        return

    # At least one key failed; report whichever error(s) apply.
    st.session_state.valid_token = False
    if assistant_valid:
        token_err = embedding_err
    elif embedding_valid:
        token_err = assistant_err
    else:
        token_err = f"{assistant_err}\n{embedding_err}"
    st.session_state.token_err = token_err


def app():
    """Draw the sidebar and the main page, then hand over to the chat UI.

    The sidebar shows a logo and an assistant-type selector. The selection is
    mapped to an ``AssistantType`` stored in
    ``st.session_state.selected_assistant_type``, and the token-related
    session values are reset. The main section then renders the OpenAI or
    Hugging Face UI depending on that type.
    """
    # sidebar
    st.sidebar.image(
        "https://thisisishara.com/res/images/favicon/android-chrome-192x192.png",
        width=80,
    )
    # Use the widget's return value rather than reading it back from the
    # session state under a hard-coded attribute name.
    chosen_assistant = st.sidebar.selectbox(
        "Assistant Type",
        ["OpenAI", "Hugging Face"],
        key=ASSISTANT_TYPE_TAG,
        placeholder="Select Assistant Type",
    )
    if chosen_assistant:
        if str(chosen_assistant).lower() == AssistantType.OPENAI.value:
            st.session_state.selected_assistant_type = AssistantType.OPENAI
        else:
            st.session_state.selected_assistant_type = AssistantType.HUGGINGFACE
        # Validation state is cleared here and rebuilt by the per-assistant UI.
        st.session_state.valid_token = None
        st.session_state.verified_api_key = None
        st.session_state.knowledgebase_name = None

    # NOTE(review): this prints the selected enum on the page; it looks like
    # leftover debug output - confirm before removing.
    st.write(st.session_state.selected_assistant_type)

    # main section
    st.header("LLM Website QA Demo")
    st.caption("⚡ Powered by :blue[LangChain], :green[OpenAI] & :green[Hugging Face]")

    assistant_type = st.session_state.selected_assistant_type
    if assistant_type == AssistantType.OPENAI:
        show_openai_chat_ui()
    elif assistant_type == AssistantType.HUGGINGFACE:
        show_hf_chat_ui()
    else:
        show_notification_banner_ui(
            notification_type=StNotificationType.INFO,
            notification="Please select an assistant type to get started!",
        )


if __name__ == "__main__":
    # Page-level settings have to be applied before anything else is drawn.
    page_settings = dict(
        page_title="Website QA powered by LangChain & LLMs",
        page_icon="https://thisisishara.com/res/images/favicon/android-chrome-192x192.png",
        layout="wide",
        initial_sidebar_state="expanded",
    )
    st.set_page_config(**page_settings)

    # Custom CSS injected as raw HTML: recolours the top decoration bar.
    hide_streamlit_style = """
                <style>
                # #MainMenu {visibility: hidden;}
                # footer {visibility: hidden;}
                [data-testid="stDecoration"] {background: linear-gradient(to right, #9EE51A, #208BBC) !important;}
                </style>
                """
    st.markdown(body=hide_streamlit_style, unsafe_allow_html=True)

    # Start the application.
    app()