import os
import time
from operator import itemgetter

import streamlit as st
from streamlit_chat import message
from pinecone import Pinecone
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_pinecone.vectorstores import Pinecone as PineconeVectorStore
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableParallel, RunnablePassthrough, RunnableLambda
from langchain_core.messages import AIMessage, HumanMessage, get_buffer_string
from langchain_groq import ChatGroq
from langchain_anthropic import ChatAnthropic
from langchain.memory import ConversationBufferMemory
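# These imports assume the following packages are installed: streamlit,
# streamlit-chat, langchain, langchain-core, langchain-openai,
# langchain-pinecone, langchain-groq, langchain-anthropic, and the Pinecone
# client (pinecone-client or the newer pinecone package).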
# Streamlit app configuration
st.set_page_config(page_title="Docu-Help")

# Dropdown for namespace selection
namespace_name = st.sidebar.selectbox("Select Website:", ('crawlee', ''), key='namespace_name')
# Read API keys from environment variables
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
PINE_API_KEY = os.getenv("PINE_API_KEY")
LANGCHAIN_API_KEY = os.getenv("LANGCHAIN_API_KEY")

# LangSmith tracing configuration. LangChain reads these settings from the
# process environment, so they are set via os.environ rather than plain locals.
os.environ["LANGCHAIN_TRACING_V2"] = "true"
os.environ["LANGCHAIN_ENDPOINT"] = "https://api.smith.langchain.com"
os.environ["LANGCHAIN_PROJECT"] = "docu-help"
# Sidebar for model selection and API key input
st.sidebar.title("Sidebar")
model_name = st.sidebar.radio("Choose a model:", ("gpt-3.5-turbo-1106", "gpt-4-0125-preview", "Claude-Sonnet", "mixtral-groq"))
openai_api_key2 = st.sidebar.text_input("Enter OpenAI Key: ")
groq_api_key = st.sidebar.text_input("Groq API Key: ")
anthropic_api_key = st.sidebar.text_input("Claude API Key: ")
pinecone_index_name = os.getenv("pinecone_index_name")
namespace_name = "crawlee"  # NOTE: this hard-coded value overrides the sidebar selection above
# Initialize session state variables if they don't exist
if 'generated' not in st.session_state:
    st.session_state['generated'] = []
if 'past' not in st.session_state:
    st.session_state['past'] = []
if 'messages' not in st.session_state:
    st.session_state['messages'] = [{"role": "system", "content": "You are a helpful assistant."}]
if 'total_cost' not in st.session_state:
    st.session_state['total_cost'] = 0.0
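# 'past' stores the user's prompts and 'generated' stores the assistant's
# replies (with sources appended); the two lists are kept in parallel so
# refresh_text() below can re-render the conversation before a new answer is streamed.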
def refresh_text():
    """Re-render the conversation so far inside the response container."""
    with response_container:
        for i in range(len(st.session_state['past'])):
            try:
                user_message_content = st.session_state["past"][i]
                message = st.chat_message("user")
                message.write(user_message_content)
            except Exception:
                print("Error rendering past (user) message")
            try:
                ai_message_content = st.session_state["generated"][i]
                message = st.chat_message("assistant")
                message.write(ai_message_content)
            except Exception:
                print("Error rendering generated (assistant) message")
# Function to generate a response (App 2's functionality): retrieval-augmented generation over the Pinecone index
def generate_response(prompt):
    st.session_state['messages'].append({"role": "user", "content": prompt})

    # Embeddings and a Pinecone vector store scoped to the selected namespace
    embed = OpenAIEmbeddings(model="text-embedding-3-small", openai_api_key=OPENAI_API_KEY)
    pc = Pinecone(api_key=PINE_API_KEY)
    index = pc.Index(pinecone_index_name)
    time.sleep(1)  # Give the index a moment to be ready
    index.describe_index_stats()  # Connectivity check; the result is not used
    vectorstore = PineconeVectorStore(index, embed, "text", namespace=namespace_name)
    retriever = vectorstore.as_retriever()
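    # Note: as_retriever() uses LangChain's defaults (similarity search). If
    # needed, retrieval could be tuned, e.g. (illustrative, not in the original):
    #     retriever = vectorstore.as_retriever(search_kwargs={"k": 4})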
template = """You are an expert software developer who specializes in APIs. Answer the user's question based only on the following context: | |
{context} | |
Chat History: | |
{chat_history} | |
Question: {question} | |
""" | |
prompt_template = ChatPromptTemplate.from_template(template) | |
    # Select the chat model based on the sidebar choice
    if model_name == "Claude-Sonnet":
        chat_model = ChatAnthropic(temperature=0, model="claude-3-sonnet-20240229", anthropic_api_key=anthropic_api_key)
    elif model_name == "mixtral-groq":
        chat_model = ChatGroq(temperature=0, groq_api_key=groq_api_key, model_name="mixtral-8x7b-32768")
    else:
        chat_model = ChatOpenAI(temperature=0, model=model_name, openai_api_key=openai_api_key2)

    memory = ConversationBufferMemory(
        return_messages=True, output_key="answer", input_key="question"
    )
    # Load the previous chat messages into memory. The leading "Answer: " is
    # stripped so the model does not learn to prepend it to its own replies.
    for i in range(len(st.session_state['generated'])):
        memory.save_context({"question": st.session_state["past"][i]}, {"answer": st.session_state["generated"][i].replace("Answer: ", "")})
    # Print the memory that the model will be using
    print(f"Memory: {memory.load_memory_variables({})}")
    rag_chain = (
        RunnablePassthrough.assign(context=(lambda x: x["context"]), chat_history=lambda x: get_buffer_string(x["chat_history"]))
        | prompt_template
        | chat_model
        | StrOutputParser()
    )
    rag_chain_with_source = RunnableParallel(
        {"context": retriever, "question": RunnablePassthrough(), "chat_history": RunnableLambda(memory.load_memory_variables) | itemgetter("history")}
    ).assign(answer=rag_chain)
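    # When streamed, rag_chain_with_source emits small dicts keyed by
    # "context", "question", or "answer", which the generator below unpacks.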
    # Generator that yields the model's output tokens (and then the sources) as a stream
    def make_stream():
        sources = []
        st.session_state['generated'].append("Answer: ")
        yield st.session_state['generated'][-1]
        for chunk in rag_chain_with_source.stream(prompt):
            if list(chunk.keys())[0] == 'answer':
                st.session_state['generated'][-1] += chunk['answer']
                yield chunk['answer']
            elif list(chunk.keys())[0] == 'context':
                # sources = chunk['context']
                sources = [doc.metadata['source'] for doc in chunk['context']]
                sources_txt = "\n\nSources:\n" + "\n".join(sources)
                st.session_state['generated'][-1] += sources_txt
                yield sources_txt
    # Send the message as a stream using the generator above
    print("Running the response streamer...")
    with response_container:
        message = st.chat_message("assistant")
        my_generator = make_stream()
        message.write_stream(my_generator)
    formatted_response = st.session_state['generated'][-1]
    # Non-streaming alternative:
    # response = rag_chain_with_source.invoke(prompt)
    # sources = [doc.metadata['source'] for doc in response['context']]
    # answer = response['answer']  # Extracting the 'answer' part
    # formatted_response = f"Answer: {answer}\n\nSources:\n" + "\n".join(sources)
    st.session_state['messages'].append({"role": "assistant", "content": formatted_response})
    return formatted_response
# Containers for the chat history and the input box
response_container = st.container()
container = st.container()

# chat_input is used instead of a form because it stays pinned to the bottom of the page
if prompt := st.chat_input("Ask a question..."):
    # Generate the response inside this block; a separate check on the user input later caused an error
    st.session_state['past'].append(prompt)
    refresh_text()
    response = generate_response(prompt)
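# To run locally (assuming this file is saved as app.py and the environment
# variables above are set):
#     streamlit run app.py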