import streamlit as st
import pandas as pd

from langchain.embeddings import HuggingFaceEmbeddings, HuggingFaceInstructEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
from langchain.chat_models import ChatOpenAI
from langchain import VectorDBQA
from langchain.prompts.chat import (
    ChatPromptTemplate,
    SystemMessagePromptTemplate,
    AIMessagePromptTemplate,
    HumanMessagePromptTemplate,
)
from langchain.schema import (
    AIMessage,
    HumanMessage,
    SystemMessage
)

from datetime import datetime as dt

st.set_page_config(page_title="Tweets Question/Answering with Langchain and OpenAI", page_icon="πŸ”Ž")

system_template = """Use the following pieces of context to answer the user's question.
If you don't know the answer, just say that you don't know; don't try to make up an answer.
ALWAYS return a "SOURCES" part in your answer.
The "SOURCES" part should be a reference to the source of the document from which you got your answer.

Example of your response should be:

```
The answer is foo
SOURCES: xyz
```

Begin!
----------------
{context}"""
messages = [
    SystemMessagePromptTemplate.from_template(system_template),
    HumanMessagePromptTemplate.from_template("{question}")
]
prompt = ChatPromptTemplate.from_messages(messages)

current_time = dt.strftime(dt.today(),'%d_%m_%Y_%H_%M')

st.markdown("## Financial Tweets GPT Search")

twitter_link = """
[![](https://img.shields.io/twitter/follow/nickmuchi?label=@nickmuchi&style=social)](https://twitter.com/nickmuchi)
"""

st.markdown(twitter_link)

bi_enc_dict = {
    'mpnet-base-v2': 'sentence-transformers/all-mpnet-base-v2',
    'instructor-base': 'hkunlp/instructor-base'
}

search_input = st.text_input(
    label='Enter Your Search Query', value="What are the most topical risks?", key='search')

sbert_model_name = st.sidebar.selectbox("Embedding Model", options=list(bi_enc_dict.keys()), key='sbox')

with open('tweets.txt') as f:
    tweets = f.read()

@st.experimental_singleton(suppress_st_warning=True)
def process_tweets(file,embed_model,query):
    '''Process file with latest tweets'''

    # Split the tweets into overlapping-free 1000-character chunks
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
    texts = text_splitter.split_text(file)

    model = bi_enc_dict[embed_model]

    if model == "hkunlp/instructor-large":
        emb = HuggingFaceInstructEmbeddings(model_name=model,
                                            query_instruction='Represent the Financial question for retrieving supporting documents: ',
                                            embed_instruction='Represent the Financial document for retrieval: ')
        
    elif model == "sentence-transformers/all-mpnet-base-v2":
        emb = HuggingFaceEmbeddings(model_name=model)

    # Index the chunk embeddings in FAISS for similarity search
    docsearch = FAISS.from_texts(texts, emb)

    # Answer the query with a "stuff" chain over the retrieved chunks,
    # using the sources-citing prompt defined above
    chain_type_kwargs = {"prompt": prompt}
    chain = VectorDBQA.from_chain_type(
        ChatOpenAI(temperature=0),
        chain_type="stuff",
        vectorstore=docsearch,
        chain_type_kwargs=chain_type_kwargs
    )

    result = chain({"query": query}, return_only_outputs=True)

    return result
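
# --- Usage sketch (assumed, not part of the recovered file) ---
# The file as recovered ends at the process_tweets definition and never
# invokes it. The lines below are a hedged reconstruction of how the cached
# function could be wired to the widgets defined above; with
# return_only_outputs=True, VectorDBQA exposes the answer under the
# "result" key.
if search_input:
    with st.spinner(text='Embedding tweets and generating an answer...'):
        results = process_tweets(tweets, sbert_model_name, search_input)
    st.markdown(results['result'])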