Spaces:
Runtime error
gkrthk committed
Commit 33fe60d • Parent(s): 52f0eee
initial commit
- app.py +74 -0
- confluence_qa.py +52 -0
- requirements.txt +12 -0
app.py
ADDED
@@ -0,0 +1,74 @@
+import streamlit as st
+import os
+import json
+import time
+
+# Import the ConfluenceQA class
+from confluence_qa import ConfluenceQA
+
+st.set_page_config(
+    page_title='Q&A Bot for Confluence Page',
+    page_icon='⚡',
+    layout='wide',
+    initial_sidebar_state='auto',
+)
+if "config" not in st.session_state:
+    st.session_state["config"] = {}
+if "confluence_qa" not in st.session_state:
+    st.session_state["confluence_qa"] = None
+
+@st.cache_resource
+def load_confluence(config):
+    # st.write("loading the confluence page")
+    confluence_qa = ConfluenceQA(config=config)
+    confluence_qa.init_embeddings()
+    confluence_qa.define_model()
+    confluence_qa.store_in_vector_db()
+    confluence_qa.retrieve_qa_chain()
+    return confluence_qa
+
+with st.sidebar.form(key='Form1'):
+    st.markdown('## Add your configs')
+    confluence_url = st.text_input("paste the confluence URL", "https://templates.atlassian.net/wiki/")
+    username = st.text_input(label="confluence username",
+                             help="leave blank if confluence page is public",
+                             type="password")
+    space_key = st.text_input(label="confluence space",
+                              help="Space of Confluence",
+                              value="RD")
+    api_key = st.text_input(label="confluence api key",
+                            help="leave blank if confluence page is public",
+                            type="password")
+    submitted1 = st.form_submit_button(label='Submit')
+
+# if submitted1 and confluence_url and space_key:
+#     st.session_state["config"] = {
+#         "persist_directory": None,
+#         "confluence_url": confluence_url,
+#         "username": username if username != "" else None,
+#         "api_key": api_key if api_key != "" else None,
+#         "space_key": space_key,
+#     }
+#     with st.spinner(text="Ingesting Confluence..."):
+#         ### Hardcoding for https://templates.atlassian.net/wiki/ and space RD to avoid multiple OpenAI calls.
+#         config = st.session_state["config"]
+#         if config["confluence_url"] == "https://templates.atlassian.net/wiki/" and config["space_key"] == "RD":
+#             config["persist_directory"] = "chroma_db"
+#             st.session_state["config"] = config
+#
+#         st.session_state["confluence_qa"] = load_confluence(st.session_state["config"])
+#         st.write("Confluence Space Ingested")
+
+st.title("Confluence Q&A Demo")
+
+question = st.text_input('Ask a question', "How do I make a space public?")
+
+if st.button('Get Answer', key='button2'):
+    with st.spinner(text="Asking LLM..."):
+        confluence_qa = st.session_state.get("confluence_qa")
+        if confluence_qa is not None:
+            # Originally called confluence_qa.answer_confluence(question), but
+            # ConfluenceQA defines no such method; qa_bot() is the one that exists.
+            result = confluence_qa.qa_bot(question)
+            st.write(result)
+        else:
+            st.write("Please load Confluence page first.")
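Note: for reference, a minimal sketch of the config dict contract between app.py and ConfluenceQA, with keys taken from the commented-out submit handler above. The values shown are the form defaults; persist_directory only matters to the commented-out chroma_db shortcut and is unused by the FAISS-backed ConfluenceQA below.

config = {
    "persist_directory": None,  # only read by the commented-out chroma_db shortcut
    "confluence_url": "https://templates.atlassian.net/wiki/",
    "username": None,           # None when the Confluence page is public
    "api_key": None,            # None when the Confluence page is public
    "space_key": "RD",
}
st.session_state["confluence_qa"] = load_confluence(config)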
confluence_qa.py
ADDED
@@ -0,0 +1,52 @@
+from langchain.document_loaders import ConfluenceLoader
+from langchain.text_splitter import RecursiveCharacterTextSplitter, TokenTextSplitter
+from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
+from langchain import HuggingFacePipeline
+from langchain.prompts import PromptTemplate
+from langchain.chains import RetrievalQA
+from langchain.embeddings import HuggingFaceEmbeddings
+from langchain.vectorstores import FAISS
+
+class ConfluenceQA:
+    def __init__(self, config) -> None:
+        self.config = config
+        self.embeddings = None
+        self.llm = None
+        self.db = None
+        self.qa = None
+
+    def init_embeddings(self) -> None:
+        self.embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
+
+    def define_model(self) -> None:
+        tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-large")
+        model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-large")
+        pipe = pipeline("text2text-generation", model=model, tokenizer=tokenizer, max_new_tokens=1024)
+        self.llm = HuggingFacePipeline(pipeline=pipe, model_kwargs={"temperature": 0, "max_length": 1024})
+
+    def store_in_vector_db(self) -> None:
+        # config is a plain dict built in app.py; the original used attribute access
+        # (config.url, config.apiKey, config.includeAttachements), which raises
+        # AttributeError on a dict, and referenced keys app.py never sets.
+        config = self.config
+        loader = ConfluenceLoader(
+            url=config.get("confluence_url"),
+            username=config.get("username"),
+            api_key=config.get("api_key"),
+        )
+        documents = loader.load(space_key=config.get("space_key"), include_attachments=False, limit=50)
+        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=150)
+        documents = text_splitter.split_documents(documents)
+        # text_splitter = TokenTextSplitter(chunk_size=1000, chunk_overlap=10)  # This is the encoding for text-embedding-ada-002
+        # texts = text_splitter.split_documents(texts)
+        self.db = FAISS.from_documents(documents, self.embeddings)
+
+    def retrieve_qa_chain(self) -> None:
+        template = """Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer. Keep the answer as concise as possible.
+{context}
+Question: {question}
+Helpful Answer:"""
+        QA_CHAIN_PROMPT = PromptTemplate(
+            template=template, input_variables=["context", "question"]
+        )
+        chain_type_kwargs = {"prompt": QA_CHAIN_PROMPT}
+        self.qa = RetrievalQA.from_chain_type(
+            llm=self.llm,
+            chain_type="stuff",
+            retriever=self.db.as_retriever(),
+            chain_type_kwargs=chain_type_kwargs,
+        )
+
+    def qa_bot(self, query: str):
+        result = self.qa.run(query)
+        return result
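Note: a minimal sketch of driving ConfluenceQA directly, outside Streamlit, against a public space. The call order mirrors load_confluence() in app.py; the URL, space key, and question are the demo defaults from the form above.

from confluence_qa import ConfluenceQA

config = {
    "confluence_url": "https://templates.atlassian.net/wiki/",
    "username": None,   # public space, no auth needed
    "api_key": None,
    "space_key": "RD",
}

qa = ConfluenceQA(config=config)
qa.init_embeddings()      # sentence-transformers all-MiniLM-L6-v2
qa.define_model()         # google/flan-t5-large text2text pipeline
qa.store_in_vector_db()   # load Confluence pages, chunk, index into FAISS
qa.retrieve_qa_chain()    # "stuff" RetrievalQA chain with the custom prompt
print(qa.qa_bot("How do I make a space public?"))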
requirements.txt
ADDED
@@ -0,0 +1,12 @@
+langchain
+torch
+transformers
+faiss-cpu
+pypdf
+sentence-transformers
+atlassian-python-api
+pytesseract
+reportlab
+Pillow
+svglib
+tiktoken
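Note: the imports in confluence_qa.py use legacy paths (langchain.document_loaders, langchain.embeddings, langchain.vectorstores) that were removed in langchain 0.2, so an unpinned install can fail at import time; that is one plausible cause of this Space's runtime error. A hedged pin for requirements.txt, offered as an assumption rather than a tested fix:

# assumption: pre-0.2 langchain still ships the legacy import paths used above
langchain<0.2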