import os
import logging
import re

from dotenv import load_dotenv, find_dotenv

import openai
import pinecone
import chromadb

from langchain_community.vectorstores import Pinecone
from langchain_community.vectorstores import Chroma

from langchain.memory import ConversationBufferMemory

from operator import itemgetter
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
from langchain.schema import format_document
from langchain_core.messages import get_buffer_string

from prompts import CONDENSE_QUESTION_PROMPT, QA_PROMPT, DEFAULT_DOCUMENT_PROMPT, TEST_QUERY_PROMPT

# Load secrets from environment file (must happen before os.getenv is called)
load_dotenv(find_dotenv(), override=True)
OPENAI_API_KEY=os.getenv('OPENAI_API_KEY')
VOYAGE_API_KEY=os.getenv('VOYAGE_API_KEY')
PINECONE_API_KEY=os.getenv('PINECONE_API_KEY')
HUGGINGFACEHUB_API_TOKEN=os.getenv('HUGGINGFACEHUB_API_TOKEN')

# Class and functions
class QA_Model:
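    '''
    Conversational retrieval-augmented QA over a prebuilt vector index.

    Wraps a Pinecone, ChromaDB, or RAGatouille index in a LangChain retriever,
    smoke-tests the index with TEST_QUERY_PROMPT at startup, and assembles a
    question-condensing conversational QA chain with buffer memory.
    '''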
    def __init__(self, 
                 index_type,
                 index_name,
                 query_model,
                 llm,
                 k=6,
                 search_type='similarity',
                 fetch_k=50,
                 temperature=0,
                 chain_type='stuff',
                 filter_arg=False):
        
        self.index_type=index_type
        self.index_name=index_name
        self.query_model=query_model
        self.llm=llm
        self.k=k
        self.search_type=search_type
        self.fetch_k=fetch_k
        self.temperature=temperature
        self.chain_type=chain_type
        self.filter_arg=filter_arg
        self.sources=[]

        load_dotenv(find_dotenv(),override=True)

        # Define retriever search parameters
        search_kwargs = _process_retriever_args(self.filter_arg,
                                                self.sources,
                                                self.search_type,
                                                self.k,
                                                self.fetch_k)

        # Read in from the vector database
        if index_type=='Pinecone':
            pinecone.init(
                api_key=PINECONE_API_KEY
            )
            logging.info('Chat pinecone index name: '+str(index_name))
            logging.info('Chat query model: '+str(query_model))
            index = pinecone.Index(index_name)
            self.vectorstore = Pinecone(index,query_model,'page_content')
            logging.info('Chat vectorstore: '+str(self.vectorstore))

            # Test query
            test_query = self.vectorstore.similarity_search(TEST_QUERY_PROMPT)
            logging.info('Test query: '+str(test_query))
            if not test_query:
                raise ValueError("Pinecone vector database is not configured properly. Test query failed.")
            else:
                logging.info('Test query succeeded!')
            
            self.retriever=self.vectorstore.as_retriever(search_type=search_type,
                                                         search_kwargs=search_kwargs)
            logging.info('Chat retriever: '+str(self.retriever))
        elif index_type=='ChromaDB':
            logging.info('Chat chroma index name: '+str(index_name))
            logging.info('Chat query model: '+str(query_model))
            persistent_client = chromadb.PersistentClient(path='../db/chromadb')            
            self.vectorstore = Chroma(client=persistent_client,
                                      collection_name=index_name,
                                      embedding_function=query_model)
            logging.info('Chat vectorstore: '+str(self.vectorstore))

            # Test query
            test_query = self.vectorstore.similarity_search(TEST_QUERY_PROMPT)
            logging.info('Test query: '+str(test_query))
            if not test_query:
                raise ValueError("Chroma vector database is not configured properly. Test query failed.")
            else:
                logging.info('Test query succeeded!')
            
            self.retriever=self.vectorstore.as_retriever(search_type=search_type,
                                                         search_kwargs=search_kwargs)
            logging.info('Chat retriever: '+str(self.retriever))
        elif index_type=='RAGatouille':
            # Easy because the index is picked up directly.
            self.vectorstore=query_model
            logging.info('Chat query model: '+str(query_model))

            # Test query
            test_query = self.vectorstore.search(TEST_QUERY_PROMPT)
            logging.info('Test query: '+str(test_query))
            if not test_query:
                raise ValueError("RAGatouille index is not configured properly. Test query failed.")
            else:
                logging.info('Test query succeeded!')
            
            self.retriever=self.vectorstore.as_langchain_retriever()
            logging.info('Chat retriever: '+str(self.retriever))
        else:
            raise NotImplementedError('Index type '+str(index_type)+' is not supported.')

        # Initialize memory
        self.memory = ConversationBufferMemory(
                        return_messages=True, output_key='answer', input_key='question')
        logging.info('Memory: '+str(self.memory))

        # Assemble main chain
        self.conversational_qa_chain=_define_qa_chain(self.llm,
                                                      self.retriever,
                                                      self.memory,
                                                      self.search_type,
                                                      search_kwargs)

    def query_docs(self,query):
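        '''
        Run one conversational QA turn: condense the query with chat history,
        retrieve documents, generate an answer, and append source metadata.
        Stores the full chain output in self.result.
        '''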
        self.memory.load_memory_variables({})
        logging.info('Memory content before qa result: '+str(self.memory))

        logging.info('Query: '+str(query))
        self.result = self.conversational_qa_chain.invoke({'question': query})
        logging.info('QA result: '+str(self.result))

        if self.index_type!='RAGatouille':
            self.sources = '\n'.join(str(data.metadata) for data in self.result['references'])
            self.result['answer'].content += '\nSources: \n'+self.sources
            logging.info('Sources: '+str(self.sources))
            logging.info('Response with sources: '+str(self.result['answer'].content))
        else:
            # RAGatouille doesn't have metadata, need to extract from context first.
            extracted_metadata = []
            pattern = r'\{([^}]*)\}(?=[^{}]*$)' # Regular expression pattern to match the last curly braces

            for ref in self.result['references']:
                match = re.search(pattern, ref.page_content)
                if match:
                    extracted_metadata.append("{"+match.group(1)+"}")
            self.sources = '\n'.join(extracted_metadata)
            self.result['answer'].content += '\nSources: \n'+self.sources
            logging.info('Sources: '+str(self.sources))
            logging.info('Response with sources: '+str(self.result['answer'].content))

        self.memory.save_context({'question': query}, {'answer': self.result['answer'].content})
        logging.info('Memory content after qa result: '+str(self.memory))

    def update_model(self,
                     llm,
                     k=6,
                     search_type='similarity',
                     fetch_k=50,
                     filter_arg=False):
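        '''
        Update the LLM and retrieval parameters, then rebuild the
        conversational QA chain with the existing retriever and memory.
        '''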

        self.llm=llm
        self.k=k
        self.search_type=search_type
        self.fetch_k=fetch_k
        self.filter_arg=filter_arg
        
        # Define retriever search parameters
        search_kwargs = _process_retriever_args(self.filter_arg,
                                                self.sources,
                                                self.search_type,
                                                self.k,
                                                self.fetch_k)
        # Update conversational retrieval chain
        self.conversational_qa_chain=_define_qa_chain(self.llm,
                                                      self.retriever,
                                                      self.memory,
                                                      self.search_type,
                                                      search_kwargs)
        logging.info('Updated qa chain: '+str(self.conversational_qa_chain))

# Internal functions
def _combine_documents(docs, 
                        document_prompt=DEFAULT_DOCUMENT_PROMPT, 
                        document_separator='\n\n'):
    '''
    Combine a list of documents into a single string.
    '''
    # TODO: this would be where stuff, map reduce, etc. would go
    doc_strings = [format_document(doc, document_prompt) for doc in docs]
    return document_separator.join(doc_strings)

def _define_qa_chain(llm,
                     retriever,
                     memory,
                     search_type,
                     search_kwargs):
    '''
    Define the conversational QA chain.
    '''
    # This adds a 'memory' key to the input object
    loaded_memory = RunnablePassthrough.assign(
                        chat_history=RunnableLambda(memory.load_memory_variables) 
                        | itemgetter('history'))  
    logging.info('Loaded memory: '+str(loaded_memory))
    
    # Assemble main chain
    standalone_question = {
        'standalone_question': {
            'question': lambda x: x['question'],
            'chat_history': lambda x: get_buffer_string(x['chat_history'])}
        | CONDENSE_QUESTION_PROMPT
        | llm
        | StrOutputParser()}
    logging.info('Condense inputs as a standalone question: '+str(standalone_question))
    retrieved_documents = {
        'source_documents': itemgetter('standalone_question') 
                            | retriever,
        'question': lambda x: x['standalone_question']}
    logging.info('Retrieved documents: '+str(retrieved_documents))
    # Now we construct the inputs for the final prompt
    final_inputs = {
        'context': lambda x: _combine_documents(x['source_documents']),
        'question': itemgetter('question')}
    logging.info('Combined documents: '+str(final_inputs))
    # And finally, we do the part that returns the answers
    answer = {
        'answer': final_inputs 
                    | QA_PROMPT 
                    | llm,
        'references': itemgetter('source_documents')}
    conversational_qa_chain = loaded_memory | standalone_question | retrieved_documents | answer
    logging.info('Conversational QA chain: '+str(conversational_qa_chain))
    return conversational_qa_chain

def _process_retriever_args(filter_arg,
                            sources,
                            search_type,
                            k,
                            fetch_k):
    '''
    Process arguments for retriever.
    '''
    # Implement filter (expects the last entry of sources to be a list of
    # metadata dicts with a 'source' key); skip when no sources are available
    if filter_arg and sources:
        filter_list = list(set(item['source'] for item in sources[-1]))
        filter_items=[]
        for item in filter_list:
            filter_item={'source': item}
            filter_items.append(filter_item)
        filter_dict={'$or':filter_items}
    else:
        filter_dict=None

    # Implement filtering and number of documents to return
    if search_type=='mmr':
        search_kwargs={'k':k,'fetch_k':fetch_k,'filter':filter_dict} # See as_retriever docs for parameters
    else:
        search_kwargs={'k':k,'filter':filter_dict} # See as_retriever docs for parameters
    
    return search_kwargs
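

# Example usage: a minimal sketch, not part of the module API. The index name,
# embedding model, and LLM below are assumptions; they must match a vector
# store you have already built and the API keys present in your .env file.
if __name__ == '__main__':
    from langchain_community.embeddings import OpenAIEmbeddings
    from langchain_community.chat_models import ChatOpenAI

    logging.basicConfig(level=logging.INFO)

    qa = QA_Model(index_type='ChromaDB',
                  index_name='my-docs',            # hypothetical collection name
                  query_model=OpenAIEmbeddings(),  # must match the index's embedding model
                  llm=ChatOpenAI(temperature=0))
    qa.query_docs('What topics does this corpus cover?')
    print(qa.result['answer'].content)

    # Reuse the same index and memory with fewer retrieved documents.
    qa.update_model(llm=ChatOpenAI(temperature=0), k=4)
    qa.query_docs('Summarize the most relevant document.')
    print(qa.result['answer'].content)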