gkrthk committed on
Commit 7f409ac
1 Parent(s): fb5c9b9
Files changed (1)
  1. confluence_qa.py +9 -12
confluence_qa.py CHANGED
@@ -1,6 +1,6 @@
 from langchain.document_loaders import ConfluenceLoader
 from langchain.text_splitter import CharacterTextSplitter, TokenTextSplitter,RecursiveCharacterTextSplitter,SentenceTransformersTokenTextSplitter
-from transformers import AutoTokenizer, AutoModelForSeq2SeqLM,pipeline,DistilBertTokenizer,DistilBertForQuestionAnswering
+from transformers import AutoTokenizer, AutoModelForSeq2SeqLM,pipeline,T5Tokenizer,T5ForConditionalGeneration
 from langchain.llms.huggingface_pipeline import HuggingFacePipeline
 from langchain.prompts import PromptTemplate
 from langchain.chains import RetrievalQA
@@ -12,13 +12,10 @@ class ConfluenceQA:
         self.embeddings = HuggingFaceEmbeddings(model_name="multi-qa-MiniLM-L6-cos-v1")
 
     def define_model(self) -> None:
-        tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased-distilled-squad')
-        model = DistilBertForQuestionAnswering.from_pretrained('distilbert-base-uncased-distilled-squad')
-
-        # tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-large")
-        # model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-large")
-        pipe = pipeline("question-answering", model=model, tokenizer=tokenizer)
-        self.llm = HuggingFacePipeline(pipeline = pipe,model_kwargs={"temperature": 0})
+        tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-large")
+        model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-large")
+        pipe = pipeline("text2text-generation", model=model, tokenizer=tokenizer)
+        self.llm = HuggingFacePipeline(pipeline = pipe,model_kwargs={"temperature": 0.5})
 
     def store_in_vector_db(self) -> None:
         persist_directory = self.config.get("persist_directory",None)
@@ -31,12 +28,11 @@ class ConfluenceQA:
             url=confluence_url, username=username, api_key=api_key
         )
         documents = loader.load(include_attachments=include_attachment, limit=100, space_key=space_key)
-        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1500, chunk_overlap=150)
+        text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=10)
        documents = text_splitter.split_documents(documents)
         self.db = Chroma.from_documents(documents, self.embeddings)
-        question = "How do I make a space public?"
-        searchDocs = self.db.similarity_search(question)
-        print(searchDocs[0].page_content)
+        # question = "How do I make a space public?"
+        # searchDocs = self.db.similarity_search(question)
 
     def retrieve_qa_chain(self) -> None:
         template = """Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.
@@ -58,4 +54,5 @@ class ConfluenceQA:
 
     def qa_bot(self, query:str):
         result = self.qa.run(query)
+        print(result)
         return result
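
The main change in define_model replaces the extractive DistilBERT question-answering pipeline with google/flan-t5-large behind a text2text-generation pipeline, which matches the text-in/text-out interface that LangChain's HuggingFacePipeline wrapper expects. A minimal standalone sketch of the new wiring, assuming transformers and langchain are installed (max_length is illustrative and not part of the commit):

```python
# Sketch of the define_model() change: flan-t5-large as a text2text-generation
# pipeline, wrapped for LangChain. max_length below is an assumed, illustrative
# generation setting, not something the commit adds.
from transformers import T5Tokenizer, T5ForConditionalGeneration, pipeline
from langchain.llms.huggingface_pipeline import HuggingFacePipeline

tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-large")
model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-large")
pipe = pipeline("text2text-generation", model=model, tokenizer=tokenizer, max_length=256)
llm = HuggingFacePipeline(pipeline=pipe, model_kwargs={"temperature": 0.5})

# The wrapper takes a prompt string and returns generated text.
print(llm("Answer in one sentence: what is a Confluence space?"))
```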
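The commit also shrinks the splitter settings from chunk_size=1500 / chunk_overlap=150 to 500 / 10, which produces more, smaller chunks for the Chroma index. A quick standalone comparison of the two settings (the sample string is made up for illustration):

```python
# Compare the old and new RecursiveCharacterTextSplitter settings on dummy text.
from langchain.text_splitter import RecursiveCharacterTextSplitter

sample = "Confluence spaces group related pages for a team or project. " * 100

old_splitter = RecursiveCharacterTextSplitter(chunk_size=1500, chunk_overlap=150)
new_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=10)

# The new settings yield noticeably more chunks, each covering a tighter passage.
print(len(old_splitter.split_text(sample)), len(new_splitter.split_text(sample)))
```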
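For context, the methods touched here are normally called in sequence. The sketch below is a hypothetical driver: the constructor signature and config keys are not shown in this diff, so they are assumptions inferred from the variables used in store_in_vector_db.

```python
from confluence_qa import ConfluenceQA

# Assumed config keys, inferred from store_in_vector_db; not confirmed by the diff.
config = {
    "confluence_url": "https://example.atlassian.net/wiki",
    "username": "user@example.com",
    "api_key": "<api-token>",
    "space_key": "DOC",
    "include_attachment": False,
    "persist_directory": None,
}

qa = ConfluenceQA(config=config)   # assumed constructor signature
qa.define_model()                  # builds the flan-t5-large pipeline (this commit)
qa.store_in_vector_db()            # loads Confluence pages, splits 500/10, embeds into Chroma
qa.retrieve_qa_chain()             # wires the RetrievalQA chain with the custom prompt
print(qa.qa_bot("How do I make a space public?"))
```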