import os
import time
from operator import itemgetter

import streamlit as st
from pinecone import Pinecone
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_groq import ChatGroq
from langchain_anthropic import ChatAnthropic
from langchain_pinecone.vectorstores import Pinecone as PineconeVectorStore
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableParallel, RunnablePassthrough, RunnableLambda
from langchain_core.messages import get_buffer_string
from langchain.memory import ConversationBufferMemory

# Streamlit App Configuration
st.set_page_config(page_title="Docu-Help")

# Dropdown for namespace selection
namespace_name = st.sidebar.selectbox("Select Website:", ('crawlee', ''), key='namespace_name')

# Read API keys from environment variables
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
PINE_API_KEY = os.getenv("PINE_API_KEY")
LANGCHAIN_API_KEY = os.getenv("LANGCHAIN_API_KEY")

# LangSmith tracing is configured through environment variables, so set them in
# os.environ rather than as plain Python variables (which LangChain would never see)
os.environ["LANGCHAIN_TRACING_V2"] = "true"
os.environ["LANGCHAIN_ENDPOINT"] = "https://api.smith.langchain.com"
os.environ["LANGCHAIN_PROJECT"] = "docu-help"

# Sidebar for model selection and API keys; the Pinecone index name comes from the environment
st.sidebar.title("Sidebar")
model_name = st.sidebar.radio("Choose a model:", ("gpt-3.5-turbo-1106", "gpt-4-0125-preview", "Claude-Sonnet", "mixtral-groq"))
openai_api_key2 = st.sidebar.text_input("Enter OpenAI Key: ")
groq_api_key = st.sidebar.text_input("Groq API Key: ")
anthropic_api_key = st.sidebar.text_input("Claude API Key: ")
pinecone_index_name = os.getenv("pinecone_index_name")

# Initialize session state variables if they don't exist
if 'generated' not in st.session_state:
    st.session_state['generated'] = []

if 'past' not in st.session_state:
    st.session_state['past'] = []

if 'messages' not in st.session_state:
    st.session_state['messages'] = [{"role": "system", "content": "You are a helpful assistant."}]

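# total_cost is kept in session state for cost tracking, but nothing in this script updates it yet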
if 'total_cost' not in st.session_state:
    st.session_state['total_cost'] = 0.0
        
def refresh_text():
    # Re-render the stored chat history (user prompts and assistant replies) in the response container
    with response_container:
        for i in range(len(st.session_state['past'])):
            try:
                user_message_content = st.session_state["past"][i]
                message = st.chat_message("user")
                message.write(user_message_content)
            except IndexError:
                print(f"Past error: no user message at index {i}")

            try:
                ai_message_content = st.session_state["generated"][i]
                message = st.chat_message("assistant")
                message.write(ai_message_content)
            except IndexError:
                print(f"Generated error: no assistant reply at index {i}")

# Generate a response: embed the question, retrieve matching docs from Pinecone, and stream a RAG answer
def generate_response(prompt):
    st.session_state['messages'].append({"role": "user", "content": prompt})
    embed = OpenAIEmbeddings(model="text-embedding-3-small", openai_api_key=OPENAI_API_KEY)

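    # Connect to the existing Pinecone index; describe_index_stats() is a lightweight call that confirms the index is reachable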
    pc = Pinecone(api_key=PINE_API_KEY)
    index = pc.Index(pinecone_index_name)
    time.sleep(1)  # Brief pause before the first call against the index
    index.describe_index_stats()

    vectorstore = PineconeVectorStore(index, embed, "text", namespace=namespace_name)
    retriever = vectorstore.as_retriever()

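    # Prompt template: retrieved documents fill {context} and the running conversation fills {chat_history}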
    template = """You are an expert software developer who specializes in APIs. Answer the user's question based only on the following context:
                {context}

                Chat History:
                {chat_history}

                Question: {question}
                """
    prompt_template = ChatPromptTemplate.from_template(template)

    if model_name == "Claude-Sonnet":
        chat_model = ChatAnthropic(temperature=0, model="claude-3-sonnet-20240229", anthropic_api_key=anthropic_api_key)
    elif model_name == "mixtral-groq":
        chat_model = ChatGroq(temperature=0, groq_api_key=groq_api_key, model_name="mixtral-8x7b-32768")
    else:
        chat_model = ChatOpenAI(temperature=0, model=model_name, openai_api_key=openai_api_key2)

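    # Per-request conversation memory; input_key/output_key name the fields that hold the question and the answer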
    memory = ConversationBufferMemory(
        return_messages=True, output_key="answer", input_key="question"
    )

    # Load the previous chat turns into memory
    for i in range(len(st.session_state['generated'])):
        # Strip the "Answer: " prefix so the model does not learn to prepend it to its own replies
        memory.save_context({"question": st.session_state["past"][i]}, {"answer": st.session_state["generated"][i].replace("Answer: ", "")})

    # Prints the memory that the model will be using
    print(f"Memory: {memory.load_memory_variables({})}")

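    # Core RAG chain: flatten the chat history messages into a single string, fill the prompt template,
    # call the selected chat model, and parse its output to plain text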
    rag_chain = (
        RunnablePassthrough.assign(context=(lambda x: x["context"]), chat_history=lambda x: get_buffer_string(x["chat_history"]))
        | prompt_template
        | chat_model
        | StrOutputParser()
    )
    
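    # For each question, run the retriever (context), pass the question through, and load the chat history
    # in parallel, then attach the generated answer to the same output dict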
    rag_chain_with_source = RunnableParallel(
        {"context": retriever, "question": RunnablePassthrough(), "chat_history": RunnableLambda(memory.load_memory_variables) | itemgetter("history")}
    ).assign(answer=rag_chain)

    # Generator that streams the model's answer chunk by chunk, accumulates it in session state,
    # and appends the source URLs once the retrieved context arrives
    def make_stream():
        sources = []
        st.session_state['generated'].append("Answer: ")
        yield st.session_state['generated'][-1]

        for chunk in rag_chain_with_source.stream(prompt):
            if 'answer' in chunk:
                st.session_state['generated'][-1] += chunk['answer']
                yield chunk['answer']
            elif 'context' in chunk:
                sources = [doc.metadata['source'] for doc in chunk['context']]

        sources_txt = "\n\nSources:\n" + "\n".join(sources)
        st.session_state['generated'][-1] += sources_txt
        yield sources_txt

    # Stream the assistant's reply into the chat UI using the generator above
    print("Running the response streamer...")
    with response_container:
        message = st.chat_message("assistant")
        my_generator = make_stream()
        message.write_stream(my_generator)

    formatted_response = st.session_state['generated'][-1]

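    # Non-streaming alternative kept for reference: invoke the chain once and format the answer and sources manually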
    #response = rag_chain_with_source.invoke(prompt)

    #sources = [doc.metadata['source'] for doc in response['context']]

    #answer = response['answer']  # Extracting the 'answer' part

    #formatted_response = f"Answer: {answer}\n\nSources:\n" + "\n".join(sources)

    st.session_state['messages'].append({"role": "assistant", "content": formatted_response})

    return formatted_response

# Container that holds the rendered chat history (the input box is pinned separately by st.chat_input)
response_container = st.container()

# Use st.chat_input rather than a form so the input box stays pinned to the bottom of the page
if prompt := st.chat_input("Ask a question..."):
    # Record the user message and re-render the history before generating the reply
    st.session_state['past'].append(prompt)
    refresh_text()

    response = generate_response(prompt)