Spaces:
Runtime error
Runtime error
from src import streamlit_utils | |
from src.prompts import AGENT_SYSTEM_PROMPT, AGENT_USER_PROMPT, RAG_USER_PROMPT, TRAVERSIALAI_USER_PROMPT | |
from src.retriever import Retriever | |
import streamlit as st | |
from langchain_openai import ChatOpenAI | |
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder | |
from langchain.memory import ChatMessageHistory | |
import re | |
import requests | |
import os | |
from qdrant_client import QdrantClient | |
# Name of the Qdrant collection queried by the sidebar and by the retriever.
collection_name = "hotels"

# Page configuration, title and credits.
st.set_page_config(page_icon="⭐", page_title="Hotels search chatbot")
st.header("Hotels search chatbot")
st.write(
    "[![view source code and description]"
    "(https://img.shields.io/badge/view_source_code-gray?logo=github)]"
    "(https://github.com/Maksimov-Dmitry/traversaal-ai-hackathon)"
)
st.write(
    "Developed by [Dmitry Maksimov](https://www.linkedin.com/in/maksimov-dmitry/), "
    "[email protected] and "
    "[Ilya Dudnik](https://www.linkedin.com/in/ilia-dudnik-5b8018271/), "
    "[email protected]"
)

# Sidebar: number of hotels requested from the retriever (used as its top_k).
st.sidebar.header("Choose your preferences")
n_hotels = st.sidebar.number_input("Number of hotels", min_value=1, max_value=10, value=3)
@st.cache_resource
def get_db_client(path='data/db'):
    """Open the on-disk Qdrant database and return its client.

    Streamlit re-executes the whole script on every user interaction. The
    client is therefore cached as a shared resource, so the storage folder is
    opened once per process instead of once per rerun.

    Args:
        path (str): folder that holds the local Qdrant storage
    Returns:
        QdrantClient: client bound to the storage at ``path``
    """
    return QdrantClient(path=path)
def add_new_info(chat_history, queries):
    """Record changed search preferences in the conversation.

    Every entry of ``queries`` is stored as a user message and is immediately
    followed by a fixed acknowledgement stored as an AI message, so the Agent
    sees the preference change as an already completed exchange.

    Args:
        chat_history: history of the chat (must provide ``add_user_message``
            and ``add_ai_message``)
        queries (list): messages describing the parameters the user changed
    """
    acknowledgement = 'Ok, got it!'
    for preference_note in queries:
        chat_history.add_user_message(preference_note)
        chat_history.add_ai_message(acknowledgement)
def check_params(params):
    """Translate sidebar filter changes into messages for the Agent.

    The filters seen on the previous call are kept in
    ``st.session_state.prev_params`` (seeded with '<BLANK>' placeholders the
    first time) and compared with ``params``. Afterwards the stored snapshot is
    replaced by ``params``.

    Args:
        params (dict): current filters with the keys 'city', 'price', 'rating'
    Returns:
        list: one message for each filter whose value differs from the
            previous call
    """
    state = st.session_state
    if 'prev_params' not in state:
        state.prev_params = {'city': '<BLANK>', 'price': '<BLANK>', 'rating': '<BLANK>'}
    previous = state.prev_params

    city = params['city']
    price = params['price']
    rating = params['rating']

    notices = []
    if previous['city'] != city:
        if city:
            notices.append(f'I want to find hotels in {city}')
        else:
            notices.append('I want to find hotels in any city')
    if previous['price'] != price:
        if price:
            notices.append(f'I want to find hotels in price range {price}')
        else:
            notices.append('I want to find hotels in any price range')
    if previous['rating'] != rating:
        notices.append(f'I want to find hotels with rating greater than {rating}')

    # Remember the current filters for the comparison on the next call.
    state.prev_params = params
    return notices
def get_parameters(db_client):
    """Render the sidebar filters (city, price, min rating) and return the choices.

    The selectable cities and prices are the distinct values found in the
    payloads of the hotels collection. The returned dictionary is later passed
    to the retriever as its filter parameters.

    Args:
        db_client: Qdrant client used to read the payloads of the collection
    Returns:
        dict: ``{'city': ..., 'price': ..., 'rating': ...}`` where city/price
            are ``None`` when the user did not restrict them
    """
    any_option = 'Does not matter'
    points, _ = db_client.scroll(
        collection_name=collection_name,
        # "Everything": the limit is a count, so it is passed as an integer.
        limit=1_000_000_000,
        with_payload=True,
        with_vectors=False,
    )
    payloads = [point.payload or {} for point in points]

    def choose(label, field):
        # Distinct values of the field; records without the field are skipped.
        values = {payload.get(field) for payload in payloads}
        values.discard(None)
        # Sorted so the option order does not change between runs.
        options = [any_option] + sorted(values, key=str)
        selected = st.sidebar.selectbox(label, options, index=0)
        return None if selected == any_option else selected

    city = choose('City', 'city')
    price = choose('Price', 'price')
    rating = st.sidebar.slider('Min hotel rating', min_value=.0, max_value=5.0, value=4.5, step=.5)
    return dict(city=city, price=price, rating=rating)
class HotelsSearchChatbot:
    """
    This is the Agent class. It is responsible for the decision-making during conversation with the user.
    Based on the user's query, the Agent decides which action to take and how to present result to the user.
    """

    # Seconds to wait for the Traversaal Ares API before giving up.
    ares_timeout = 60
    # st.session_state key that holds the conversation between script reruns.
    history_state_key = 'agent_chat_history'

    def __init__(self, db_client):
        """Configure API keys and remember the models and clients to use.

        Args:
            db_client: Qdrant client handed to the Retriever
        """
        streamlit_utils.configure_api_keys()
        self.llm_model = "gpt-4-1106-preview"
        self.temperature = 0.6
        self.embeedings_model = "text-embedding-3-large"
        self.rerank_model = 'rerank-multilingual-v2.0'
        self.ares_api_key = os.environ.get("ARES_API_KEY")
        self.db_client = db_client

    def _traversialai(self, query):
        """Acquiring information from the internet using the Traversaal.ai.

        Args:
            query (str): search query
        Returns:
            str: information from the internet based on the query, or None when
                the request fails or the answer has an unexpected shape
        """
        url = "https://api-ares.traversaal.ai/live/predict"
        payload = {"query": [query]}
        headers = {
            "x-api-key": self.ares_api_key,
            "content-type": "application/json"
        }
        try:
            response = requests.post(url, json=payload, headers=headers, timeout=self.ares_timeout)
            return response.json()['data']['response_text']
        except (requests.RequestException, ValueError, KeyError, TypeError):
            # Best effort: network errors, non-JSON bodies and missing fields
            # all mean "no internet context available".
            return None

    def _get_action(self, text):
        """Parse (read) the action and the action input from the response of the Agent
        (after he made a decision what to do).
        'action' and 'action_input' indicate whether we need to query additional tools
        (vector DB, Traversaal AI) and how.

        The action is the rest of the line after 'Action:'; the action input is
        everything after 'Action Input:' up to the end of the text, so answers
        spanning several lines are kept whole. Both values are stripped of
        surrounding whitespace; a missing or empty value is returned as None.

        Args:
            text (str): response of the Agent, which contains the action and the action input
        Returns:
            tuple: action, action input
        """
        action_match = re.search(r"Action:[ \t]*(.*?)[ \t\r]*$", text, re.MULTILINE)
        action_input_match = re.search(r"Action Input:[ \t]*(.*)", text, re.DOTALL)
        action = (action_match.group(1).strip() or None) if action_match else None
        action_input = (action_input_match.group(1).strip() or None) if action_input_match else None
        return action, action_input

    def _invoke_with_temporary_message(self, chain, chat_history, message, config=None):
        """Invoke the chain with `message` appended to the history for this call only."""
        chat_history.add_user_message(message)
        try:
            return chain.invoke({"messages": chat_history.messages}, config)
        finally:
            # The helper prompt must not stay in the conversation, even on errors.
            chat_history.messages.pop()

    def _make_action(self, action, action_input, retriever, chain, chat_history, config, retriever_params):
        """Take the action corresponding to 'action' and 'action input'. The 'action' can be one of the following:
        'nothing' - Agent is capable of dealing on its own without use of additional tools,
        'hotels_data_base' - Agent decides to get the information from the hotels vector DB,
        'ares_api' - Agent requires additional information from the internet using the Traversaal.ai.

        Args:
            action (str): action to make
            action_input (str): action input (formulated by Agent search query)
            retriever (Retriever): Retriever object
            chain (Chain): Chain object
            chat_history (ChatMessageHistory): history of the chat
            config (dict): handlers for a LangChain invoke method
            retriever_params (dict): parameters for the Retriever
        Returns:
            str: text of the answer, or None when the action is unknown or the
                Agent response contained no usable action input
        """
        if action is None or action_input is None:
            # Nothing usable was parsed; let the caller show its fallback message.
            return None
        if action == 'nothing':
            st.markdown(action_input)
            return action_input
        if action == 'hotels_data_base':
            context = retriever(action_input, top_k=n_hotels, **retriever_params)
            prompt_template = RAG_USER_PROMPT
        elif action == 'ares_api':
            context = self._traversialai(action_input)
            prompt_template = TRAVERSIALAI_USER_PROMPT
        else:
            return None
        response = self._invoke_with_temporary_message(
            chain, chat_history, prompt_template.format(context=context, query=action_input), config)
        return response.content

    def setup_chain(_self):
        """Build the LLM chain and the retriever and fetch the conversation history.

        The history is kept in ``st.session_state`` so that it survives the
        script reruns Streamlit performs on every interaction; the chain and
        the retriever are rebuilt on each call.

        Returns:
            tuple: chain, chat history, retriever
        """
        retriever = Retriever(embedding_model=_self.embeedings_model, llm_model=_self.llm_model,
                              rerank_model=_self.rerank_model, db_client=_self.db_client,
                              db_collection=collection_name)
        if _self.history_state_key not in st.session_state:
            st.session_state[_self.history_state_key] = ChatMessageHistory()
        chat_history = st.session_state[_self.history_state_key]
        prompt = ChatPromptTemplate.from_messages(
            [
                (
                    "system",
                    AGENT_SYSTEM_PROMPT,
                ),
                MessagesPlaceholder(variable_name="messages"),
            ]
        )
        chat = ChatOpenAI(model=_self.llm_model, temperature=_self.temperature, streaming=True)
        chain = prompt | chat
        return chain, chat_history, retriever

    def main(self, params):
        """Handle one chat turn: read the user query, pick an action, show the answer.

        Args:
            params (dict): current sidebar filters (city, price, rating); also
                passed to the Retriever as its parameters
        """
        chain, chat_history, retriever = self.setup_chain()
        user_query = st.chat_input(placeholder="Ask me anything!")
        if not user_query:
            return
        streamlit_utils.display_msg(user_query, 'user')
        # add new info to the chat history
        queries = check_params(params)
        add_new_info(chat_history, queries)
        # get the action and the action input based on the user's query
        action_response = self._invoke_with_temporary_message(
            chain, chat_history, AGENT_USER_PROMPT.format(input=user_query))
        action, action_input = self._get_action(action_response.content)
        with st.chat_message("assistant"):
            st_cb = streamlit_utils.StreamHandler(st.empty())
            # create response on the user's query
            response = self._make_action(action, action_input,
                                         retriever, chain, chat_history, {"callbacks": [st_cb]}, params)
            chat_history.add_user_message(user_query)
            if response is None:
                response = 'Sorry, I cannot help you with it. Could you rephrase your question?'
                st.markdown(response)
            chat_history.add_ai_message(response)
            # Make sure the UI message log exists before appending to it.
            if 'messages' not in st.session_state:
                st.session_state.messages = []
            st.session_state.messages.append({"role": "assistant", "content": response})
if __name__ == "__main__":
    # Wire the app together: database client -> sidebar filters -> chat agent.
    qdrant = get_db_client()
    sidebar_filters = get_parameters(qdrant)
    chatbot = HotelsSearchChatbot(qdrant)
    chatbot.main(sidebar_filters)