gkrthk committed on
Commit
b4493c7
1 Parent(s): f22b8d0
Files changed (1) hide show
  1. confluence_qa.py +12 -14
confluence_qa.py CHANGED
@@ -1,6 +1,6 @@
1
  from langchain.document_loaders import ConfluenceLoader
2
  from langchain.text_splitter import CharacterTextSplitter, TokenTextSplitter,RecursiveCharacterTextSplitter,SentenceTransformersTokenTextSplitter
3
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM,pipeline
4
  from langchain.llms.huggingface_pipeline import HuggingFacePipeline
5
  from langchain.prompts import PromptTemplate
6
  from langchain.chains import RetrievalQA
@@ -12,10 +12,13 @@ class ConfluenceQA:
12
  self.embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
13
 
14
  def define_model(self) -> None:
15
- tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-large")
16
- model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-large")
17
- pipe = pipeline("text2text-generation", model=model, tokenizer=tokenizer, max_new_tokens=1024, truncation=True)
18
- self.llm = HuggingFacePipeline(pipeline = pipe,model_kwargs={"temperature": 0, "max_length": 1024})
 
 
 
19
 
20
  def store_in_vector_db(self) -> None:
21
  persist_directory = self.config.get("persist_directory",None)
@@ -28,17 +31,12 @@ class ConfluenceQA:
28
  url=confluence_url, username=username, api_key=api_key
29
  )
30
  documents = loader.load(include_attachments=include_attachment, limit=100, space_key=space_key)
31
- # text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=10)
32
- # documents = text_splitter.split_documents(documents)
33
- # print(documents)
34
- text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=10)
35
- documents = text_splitter.split_documents(documents)
36
- # text_splitter = SentenceTransformersTokenTextSplitter(chunk_overlap=10)
37
- # documents = text_splitter.split_documents(documents)
38
- # print(documents)
39
- text_splitter = TokenTextSplitter(chunk_size=1000, chunk_overlap=10, encoding_name="cl100k_base") # This the encoding for text-embedding-ada-002
40
  documents = text_splitter.split_documents(documents)
41
  self.db = Chroma.from_documents(documents, self.embeddings)
 
 
 
42
 
43
 
44
  def retrieve_qa_chain(self) -> None:
 
1
  from langchain.document_loaders import ConfluenceLoader
2
  from langchain.text_splitter import CharacterTextSplitter, TokenTextSplitter,RecursiveCharacterTextSplitter,SentenceTransformersTokenTextSplitter
3
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM,pipeline,DistilBertTokenizer,DistilBertForQuestionAnswering
4
  from langchain.llms.huggingface_pipeline import HuggingFacePipeline
5
  from langchain.prompts import PromptTemplate
6
  from langchain.chains import RetrievalQA
 
12
  self.embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
13
 
14
  def define_model(self) -> None:
15
+ tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased-distilled-squad')
16
+ model = DistilBertForQuestionAnswering.from_pretrained('distilbert-base-uncased-distilled-squad')
17
+
18
+ # tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-large")
19
+ # model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-large")
20
+ pipe = pipeline("text2text-generation", model=model, tokenizer=tokenizer)
21
+ self.llm = HuggingFacePipeline(pipeline = pipe,model_kwargs={"temperature": 0})
22
 
23
  def store_in_vector_db(self) -> None:
24
  persist_directory = self.config.get("persist_directory",None)
 
31
  url=confluence_url, username=username, api_key=api_key
32
  )
33
  documents = loader.load(include_attachments=include_attachment, limit=100, space_key=space_key)
34
+ text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=10)
 
 
 
 
 
 
 
 
35
  documents = text_splitter.split_documents(documents)
36
  self.db = Chroma.from_documents(documents, self.embeddings)
37
+ question = "How do I make a space public?"
38
+ searchDocs = self.db.similarity_search(question)
39
+ print(searchDocs[0].page_content)
40
 
41
 
42
  def retrieve_qa_chain(self) -> None: