File size: 7,499 Bytes
48a66db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2c02837
48a66db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
import queries, setup

import os
import time
import logging
import json

import pinecone
import openai

from langchain_community.vectorstores import Pinecone
from langchain_community.vectorstores import Chroma

from langchain_openai import OpenAIEmbeddings
from langchain_community.embeddings import VoyageEmbeddings

from langchain_openai import OpenAI, ChatOpenAI
from langchain_community.llms import HuggingFaceHub

from ragatouille import RAGPretrainedModel

import streamlit as st

# Set up the page, enable logging
from dotenv import load_dotenv,find_dotenv
load_dotenv(find_dotenv(),override=True)  # .env values override any pre-set environment variables
# Log everything (DEBUG and up) to a file; filemode='w' truncates it on each app start.
logging.basicConfig(filename='app_1_chatbot_ams_modular.log', filemode='w', format='%(name)s - %(levelname)s - %(message)s', level=logging.DEBUG)

# Set the page title
st.set_page_config(
    page_title='Aerospace Chatbot: Modular',
)
st.title('Aerospace Mechanisms Chatbot')
# Collapsible "about" panel: data source link plus example questions.
with st.expander('''What's under the hood?'''):
    st.markdown('''
    This chatbot will look up from all Aerospace Mechanism Symposia in the following location: https://huggingface.co/spaces/ai-aerospace/aerospace_chatbots/tree/main/data/AMS
    Example questions:
    * What are examples of latch failures which have occurred due to improper fitup?
    * What are examples of lubricants which should be avoided for space mechanism applications?
    ''')
# When checked, retrieval is filtered to the sources returned for the previous answer
# (passed as filter_arg to update_model further below).
filter_toggle=st.checkbox('Filter response with last received sources?')

# Build the sidebar widgets from the repo config files; `sb` collects every
# user selection (vector DB, embeddings, RAG type, index, LLM, model options, keys).
sb=setup.load_sidebar(config_file='../config/config.json',
                      index_data_file='../config/index_data.json',
                      vector_databases=True,
                      embeddings=True,
                      rag_type=True,
                      index_name=True,
                      llm=True,
                      model_options=True,
                      secret_keys=True)

secrets=setup.set_secrets(sb) # Take secrets from .env file first, otherwise from sidebar

# Set up chat history
# Seed the per-session defaults once, then replay the stored transcript so the
# conversation survives Streamlit's script reruns.
for _key, _default in (('qa_model_obj', []), ('message_id', 0), ('messages', [])):
    if _key not in st.session_state:
        st.session_state[_key] = _default
for _msg in st.session_state.messages:
    with st.chat_message(_msg['role']):
        st.markdown(_msg['content'])

def _build_llm(sb, secrets, out_token):
    """Construct the chat LLM from the sidebar settings.

    Args:
        sb: Sidebar selections dict produced by setup.load_sidebar.
        secrets: API-key dict produced by setup.set_secrets.
        out_token: Maximum number of output tokens for the response.

    Returns:
        A ChatOpenAI or HuggingFaceHub instance.

    Raises:
        ValueError: If sb['llm_source'] is not a recognized provider.
            (The original inline code left `llm` unbound in that case and
            crashed later with NameError; fail fast with a clear message.)
    """
    if sb['llm_source'] == 'OpenAI':
        return ChatOpenAI(model_name=sb['llm_model'],
                          temperature=sb['model_options']['temperature'],
                          openai_api_key=secrets['OPENAI_API_KEY'],
                          max_tokens=out_token)
    elif sb['llm_source'] == 'Hugging Face':
        return HuggingFaceHub(repo_id=sb['llm_model'],
                              model_kwargs={"temperature": sb['model_options']['temperature'], "max_length": out_token})
    raise ValueError('Unknown llm_source: ' + str(sb['llm_source']))

# Define chat
if prompt := st.chat_input('Prompt here'):
    # User prompt: persist it, then render it in the transcript.
    st.session_state.messages.append({'role': 'user', 'content': prompt})
    with st.chat_message('user'):
        st.markdown(prompt)
    # Assistant response
    with st.chat_message('assistant'):
        message_placeholder = st.empty()

        with st.status('Generating response...') as status:
            t_start=time.time()

            st.session_state.message_id += 1
            # (typo fix: 'reponse' -> 'response' in both messages below)
            st.write('Starting response generation for message: '+str(st.session_state.message_id))
            logging.info('Starting response generation for message: '+str(st.session_state.message_id))

            # Output token budget from the sidebar verbosity setting.
            if sb['model_options']['output_level'] == 'Concise':
                out_token = 50
            else:
                out_token = 516  # NOTE(review): 516 looks like a typo for 512 — confirm intent
            logging.info('Output tokens: '+str(out_token))

            if st.session_state.message_id == 1:
                # First message of the session: build embeddings, LLM, and QA model.
                # NOTE(review): this chain tests sb['query_model'] twice and then
                # sb['index_type'] — confirm the key mix is intentional.
                if sb['query_model']=='Openai':
                    query_model=OpenAIEmbeddings(model=sb['embedding_name'],openai_api_key=secrets['OPENAI_API_KEY'])
                elif sb['query_model']=='Voyage':
                    query_model=VoyageEmbeddings(model=sb['embedding_name'],voyage_api_key=secrets['VOYAGE_API_KEY'])
                elif sb['index_type']=='RAGatouille':
                    query_model=RAGPretrainedModel.from_index('../db/.ragatouille/colbert/indexes/'+sb['index_name'])
                logging.info('Query model set: '+str(query_model))

                llm = _build_llm(sb, secrets, out_token)
                logging.info('LLM model set: '+str(llm))

                # Initialize QA model object; search_type is optional in the sidebar config.
                search_type = sb['model_options'].get('search_type')
                st.session_state.qa_model_obj=queries.QA_Model(sb['index_type'],
                                                               sb['index_name'],
                                                               query_model,
                                                               llm,
                                                               k=sb['model_options']['k'],
                                                               search_type=search_type,
                                                               filter_arg=False)
                logging.info('QA model object set: '+str(st.session_state.qa_model_obj))
            else:
                # Later messages: refresh the LLM and retrieval options only; the
                # vector store / embeddings from the first message are reused.
                logging.info('Updating model with sidebar settings...')
                llm = _build_llm(sb, secrets, out_token)
                logging.info('LLM model set: '+str(llm))

                st.session_state.qa_model_obj.update_model(llm,
                                                           k=sb['model_options']['k'],
                                                           search_type=sb['model_options']['search_type'],
                                                           filter_arg=filter_toggle)
                logging.info('QA model object updated: '+str(st.session_state.qa_model_obj))

            st.write('Searching vector database, generating prompt...')
            logging.info('Searching vector database, generating prompt...')
            st.session_state.qa_model_obj.query_docs(prompt)
            ai_response=st.session_state.qa_model_obj.result['answer'].content
            message_placeholder.markdown(ai_response)
            t_delta=time.time() - t_start
            status.update(label='Prompt generated in '+"{:10.3f}".format(t_delta)+' seconds', state='complete', expanded=False)

        st.session_state.messages.append({'role': 'assistant', 'content': ai_response})
        logging.info(f'Messaging complete for {st.session_state.message_id}.')

# Add reset button
if st.button('Restart session'):
    # Clear the transcript and force full model re-initialization on the next prompt.
    for _key, _reset_value in (('qa_model_obj', []), ('message_id', 0), ('messages', [])):
        st.session_state[_key] = _reset_value