LOUIS SANNA committed on
Commit
cc2ce8c
1 Parent(s): 6e28a81

feat(loader)

Browse files
.DS_Store ADDED
Binary file (6.15 kB). View file
 
.gitattributes CHANGED
@@ -44,3 +44,6 @@ documents/climate_gpt_v2_only_giec.faiss filter=lfs diff=lfs merge=lfs -text
44
  documents/climate_gpt_v2.faiss filter=lfs diff=lfs merge=lfs -text
45
  climateqa_v3.db filter=lfs diff=lfs merge=lfs -text
46
  climateqa_v3.faiss filter=lfs diff=lfs merge=lfs -text
 
 
 
 
44
  documents/climate_gpt_v2.faiss filter=lfs diff=lfs merge=lfs -text
45
  climateqa_v3.db filter=lfs diff=lfs merge=lfs -text
46
  climateqa_v3.faiss filter=lfs diff=lfs merge=lfs -text
47
+ data filter=lfs diff=lfs merge=lfs -text
48
+ chroma_db filter=lfs diff=lfs merge=lfs -text
49
+ chroma_db/chroma.sqlite3 filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,5 +1,5 @@
1
  ---
2
- title: ClimateQ&A
3
  emoji: 🌍
4
  colorFrom: blue
5
  colorTo: red
@@ -9,4 +9,12 @@ app_file: app.py
9
  pinned: false
10
  ---
11
 
12
- # Climate Q&A
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: AnythingQ&A
3
  emoji: 🌍
4
  colorFrom: blue
5
  colorTo: red
 
9
  pinned: false
10
  ---
11
 
12
+ # Anything Q&A
13
+
14
+ A clone of the amazing https://huggingface.co/spaces/Ekimetrics/climate-question-answering.
15
+
16
+ ## Build vector index
17
+
18
+ ```bash
19
+ python -m climateqa.build_index
20
+ ```
app.py CHANGED
@@ -2,18 +2,18 @@ import gradio as gr
2
 
3
  from utils import create_user_id
4
 
5
-
6
  # Langchain
7
  from langchain.embeddings import HuggingFaceEmbeddings
8
  from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
9
 
10
  # ClimateQ&A imports
 
11
  from climateqa.llm import get_llm
12
- from climateqa.logging import log
13
  from climateqa.chains import load_qa_chain_with_text
14
  from climateqa.chains import load_reformulation_chain
15
- from climateqa.vectorstore import get_pinecone_vectorstore
16
- from climateqa.retriever import ClimateQARetriever
17
  from climateqa.prompts import audience_prompts
18
 
19
  # Load environment variables in local mode
@@ -113,13 +113,10 @@ class StreamingGradioCallbackHandler(BaseCallbackHandler):
113
 
114
 
115
  # Create embeddings function and LLM
116
- embeddings_function = HuggingFaceEmbeddings(
117
- model_name="sentence-transformers/multi-qa-mpnet-base-dot-v1"
118
- )
119
-
120
 
121
  # Create vectorstore and retriever
122
- vectorstore = get_pinecone_vectorstore(embeddings_function)
123
 
124
  # ---------------------------------------------------------------------------
125
  # ClimateQ&A Streaming
@@ -148,8 +145,8 @@ def fetch_sources(query, sources):
148
  llm_reformulation = get_llm(
149
  max_tokens=512, temperature=0.0, verbose=True, streaming=False
150
  )
151
- retriever = ClimateQARetriever(
152
- vectorstore=vectorstore, sources=sources, k_summary=3, k_total=10
153
  )
154
  reformulation_chain = load_reformulation_chain(llm_reformulation)
155
 
@@ -265,6 +262,11 @@ def answer_bot(query, history, docs, question, language, audience):
265
  def make_html_source(source, i):
266
  meta = source.metadata
267
  content = source.page_content.split(":", 1)[1].strip()
 
 
 
 
 
268
  return f"""
269
  <div class="card">
270
  <div class="card-content">
@@ -273,9 +275,7 @@ def make_html_source(source, i):
273
  </div>
274
  <div class="card-footer">
275
  <span>{meta['name']}</span>
276
- <a href="{meta['url']}#page={int(meta['page_number'])}" target="_blank" class="pdf-link">
277
- <span role="img" aria-label="Open PDF">🔗</span>
278
- </a>
279
  </div>
280
  </div>
281
  """
 
2
 
3
  from utils import create_user_id
4
 
 
5
  # Langchain
6
  from langchain.embeddings import HuggingFaceEmbeddings
7
  from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
8
 
9
  # ClimateQ&A imports
10
+ from climateqa.embeddings import EMBEDDING_MODEL_NAME
11
  from climateqa.llm import get_llm
12
+ from climateqa.qa_logging import log
13
  from climateqa.chains import load_qa_chain_with_text
14
  from climateqa.chains import load_reformulation_chain
15
+ from climateqa.vectorstore import get_vectorstore
16
+ from climateqa.retriever import QARetriever
17
  from climateqa.prompts import audience_prompts
18
 
19
  # Load environment variables in local mode
 
113
 
114
 
115
  # Create embeddings function and LLM
116
+ embeddings_function = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME)
 
 
 
117
 
118
  # Create vectorstore and retriever
119
+ vectorstore = get_vectorstore(embeddings_function)
120
 
121
  # ---------------------------------------------------------------------------
122
  # ClimateQ&A Streaming
 
145
  llm_reformulation = get_llm(
146
  max_tokens=512, temperature=0.0, verbose=True, streaming=False
147
  )
148
+ retriever = QARetriever(
149
+ vectorstore=vectorstore, sources=[], k_summary=0, k_total=10
150
  )
151
  reformulation_chain = load_reformulation_chain(llm_reformulation)
152
 
 
262
  def make_html_source(source, i):
263
  meta = source.metadata
264
  content = source.page_content.split(":", 1)[1].strip()
265
+ link = (
266
+ f'<a href="{meta["url"]}#page={int(meta["page_number"])}" target="_blank" class="pdf-link"><span role="img" aria-label="Open PDF">🔗</span></a>'
267
+ if "url" in meta
268
+ else ""
269
+ )
270
  return f"""
271
  <div class="card">
272
  <div class="card-content">
 
275
  </div>
276
  <div class="card-footer">
277
  <span>{meta['name']}</span>
278
+ {link}
 
 
279
  </div>
280
  </div>
281
  """
chroma_db/5fa47764-2449-49fb-ae2f-0fd1886dfa2d/data_level0.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a13e72541800c513c73dccea69f79e39cf4baef4fa23f7e117c0d6b0f5f99670
3
+ size 3212000
chroma_db/5fa47764-2449-49fb-ae2f-0fd1886dfa2d/header.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0ec6df10978b056a10062ed99efeef2702fa4a1301fad702b53dd2517103c746
3
+ size 100
chroma_db/5fa47764-2449-49fb-ae2f-0fd1886dfa2d/length.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fc19b1997119425765295aeab72d76faa6927d4f83985d328c26f20468d6cc76
3
+ size 4000
chroma_db/5fa47764-2449-49fb-ae2f-0fd1886dfa2d/link_lists.bin ADDED
File without changes
chroma_db/chroma.sqlite3 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:db081ece29301d223a01bac97e8b2905fada2e7c376cec96bf44fee0f5c95069
3
+ size 1843200
climateqa/build_index.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # import
2
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
3
+ from langchain.vectorstores import Chroma
4
+ from langchain.embeddings import HuggingFaceEmbeddings
5
+ from langchain.document_loaders import PyPDFLoader
6
+
7
+ from .embeddings import EMBEDDING_MODEL_NAME
8
+ from .vectorstore import get_vectorstore
9
+
10
+
11
+ def load_data():
12
+ docs = parse_data()
13
+ embedding_function = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME)
14
+ vectorstore = get_vectorstore(embedding_function)
15
+
16
+ assert isinstance(vectorstore, Chroma)
17
+ vectorstore.from_documents(
18
+ docs, embedding_function, persist_directory="./chroma_db"
19
+ )
20
+ return vectorstore
21
+
22
+
23
+ def parse_data():
24
+ loader = PyPDFLoader("data/daoism/tao-te-ching.pdf")
25
+ pages = loader.load_and_split()
26
+
27
+ # split it into chunks
28
+ text_splitter = RecursiveCharacterTextSplitter(chunk_size=250, chunk_overlap=0)
29
+ docs = text_splitter.split_documents(pages)
30
+ print(docs)
31
+ for doc in docs:
32
+ doc.metadata["name"] = parse_name(doc.metadata["source"])
33
+ doc.metadata["domain"] = parse_domain(doc.metadata["source"])
34
+ doc.metadata["page_number"] = doc.metadata["page"]
35
+ doc.metadata["short_name"] = doc.metadata["name"]
36
+ return docs
37
+
38
+
39
+ def parse_name(source: str) -> str:
40
+ return source.split("/")[-1].split(".")[0]
41
+
42
+
43
+ def parse_domain(source: str) -> str:
44
+ return source.split("/")[2]
45
+
46
+
47
+ if __name__ == "__main__":
48
+ db = load_data()
49
+ # query it
50
+ query = (
51
+ "He who can bear the misfortune of a nation is called the ruler of the world."
52
+ )
53
+ docs = db.similarity_search(query)
54
+ print(docs)
climateqa/chains.py CHANGED
@@ -3,7 +3,7 @@
3
  import json
4
 
5
  from langchain import PromptTemplate, LLMChain
6
- from langchain.chains import RetrievalQAWithSourcesChain, QAWithSourcesChain
7
  from langchain.chains import TransformChain, SequentialChain
8
  from langchain.chains.qa_with_sources import load_qa_with_sources_chain
9
 
@@ -21,6 +21,7 @@ def load_reformulation_chain(llm):
21
  # Parse the output
22
  def parse_output(output):
23
  query = output["query"]
 
24
  json_output = json.loads(output["json"])
25
  question = json_output.get("question", query)
26
  language = json_output.get("language", "English")
 
3
  import json
4
 
5
  from langchain import PromptTemplate, LLMChain
6
+ from langchain.chains import QAWithSourcesChain
7
  from langchain.chains import TransformChain, SequentialChain
8
  from langchain.chains.qa_with_sources import load_qa_with_sources_chain
9
 
 
21
  # Parse the output
22
  def parse_output(output):
23
  query = output["query"]
24
+ print("output", output)
25
  json_output = json.loads(output["json"])
26
  question = json_output.get("question", query)
27
  language = json_output.get("language", "English")
climateqa/chat.py DELETED
@@ -1,42 +0,0 @@
1
- # LANGCHAIN IMPORTS
2
- from langchain import PromptTemplate, LLMChain
3
- from langchain.embeddings import HuggingFaceEmbeddings
4
- from langchain.chains import RetrievalQAWithSourcesChain
5
- from langchain.chains.qa_with_sources import load_qa_with_sources_chain
6
-
7
-
8
- # CLIMATEQA
9
- from climateqa.retriever import ClimateQARetriever
10
- from climateqa.vectorstore import get_pinecone_vectorstore
11
- from climateqa.chains import load_climateqa_chain
12
-
13
-
14
- class ClimateQA:
15
- def __init__(
16
- self,
17
- hf_embedding_model="sentence-transformers/multi-qa-mpnet-base-dot-v1",
18
- show_progress_bar=False,
19
- batch_size=1,
20
- max_tokens=1024,
21
- **kwargs
22
- ):
23
- self.llm = self.get_llm(max_tokens=max_tokens, **kwargs)
24
- self.embeddings_function = HuggingFaceEmbeddings(
25
- model_name=hf_embedding_model,
26
- encode_kwargs={
27
- "show_progress_bar": show_progress_bar,
28
- "batch_size": batch_size,
29
- },
30
- )
31
-
32
- def get_vectorstore(self):
33
- pass
34
-
35
- def reformulate(self):
36
- pass
37
-
38
- def retrieve(self):
39
- pass
40
-
41
- def ask(self):
42
- pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
climateqa/custom_retrieval_chain.py CHANGED
@@ -1,35 +1,19 @@
1
  from __future__ import annotations
2
  import inspect
3
- from typing import Any, Dict, List, Optional
4
 
5
- from pydantic import Extra
6
-
7
- from langchain.schema.language_model import BaseLanguageModel
8
  from langchain.callbacks.manager import (
9
- AsyncCallbackManagerForChainRun,
10
  CallbackManagerForChainRun,
11
  )
12
- from langchain.chains.base import Chain
13
- from langchain.prompts.base import BasePromptTemplate
14
 
15
- from typing import Any, Dict, List
16
 
17
  from langchain.callbacks.manager import (
18
- AsyncCallbackManagerForChainRun,
19
  CallbackManagerForChainRun,
20
  )
21
- from langchain.chains.combine_documents.stuff import StuffDocumentsChain
22
- from langchain.chains.qa_with_sources.base import BaseQAWithSourcesChain
23
- from langchain.docstore.document import Document
24
- from langchain.pydantic_v1 import Field
25
- from langchain.schema import BaseRetriever
26
-
27
  from langchain.chains import RetrievalQAWithSourcesChain
28
 
29
 
30
- from langchain.chains.router.llm_router import LLMRouterChain
31
-
32
-
33
  class CustomRetrievalQAWithSourcesChain(RetrievalQAWithSourcesChain):
34
  fallback_answer: str = "No sources available to answer this question."
35
 
 
1
  from __future__ import annotations
2
  import inspect
3
+ from typing import Any, Dict
4
 
 
 
 
5
  from langchain.callbacks.manager import (
 
6
  CallbackManagerForChainRun,
7
  )
 
 
8
 
9
+ from typing import Any, Dict
10
 
11
  from langchain.callbacks.manager import (
 
12
  CallbackManagerForChainRun,
13
  )
 
 
 
 
 
 
14
  from langchain.chains import RetrievalQAWithSourcesChain
15
 
16
 
 
 
 
17
  class CustomRetrievalQAWithSourcesChain(RetrievalQAWithSourcesChain):
18
  fallback_answer: str = "No sources available to answer this question."
19
 
climateqa/embeddings.py ADDED
@@ -0,0 +1 @@
 
 
1
+ EMBEDDING_MODEL_NAME = "sentence-transformers/multi-qa-mpnet-base-dot-v1"
climateqa/llm.py CHANGED
@@ -1,7 +1,6 @@
1
- from langchain.chat_models import AzureChatOpenAI
2
  import os
3
 
4
- # LOAD ENVIRONMENT VARIABLES
5
  try:
6
  from dotenv import load_dotenv
7
 
@@ -11,16 +10,50 @@ except:
11
 
12
 
13
  def get_llm(max_tokens=1024, temperature=0.0, verbose=True, streaming=False, **kwargs):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  llm = AzureChatOpenAI(
15
  openai_api_base=os.environ["AZURE_OPENAI_API_BASE_URL"],
16
  openai_api_version=os.environ["AZURE_OPENAI_API_VERSION"],
17
  deployment_name=os.environ["AZURE_OPENAI_API_DEPLOYMENT_NAME"],
18
  openai_api_key=os.environ["AZURE_OPENAI_API_KEY"],
19
  openai_api_type="azure",
20
- max_tokens=max_tokens,
21
- temperature=temperature,
22
- verbose=verbose,
23
- streaming=streaming,
24
  **kwargs,
25
  )
26
  return llm
 
1
+ from langchain.chat_models import AzureChatOpenAI, ChatOpenAI
2
  import os
3
 
 
4
  try:
5
  from dotenv import load_dotenv
6
 
 
10
 
11
 
12
  def get_llm(max_tokens=1024, temperature=0.0, verbose=True, streaming=False, **kwargs):
13
+ if has_azure_openai_config():
14
+ return get_azure_llm(
15
+ max_tokens=max_tokens,
16
+ temperature=temperature,
17
+ verbose=verbose,
18
+ streaming=streaming,
19
+ **kwargs,
20
+ )
21
+ return get_open_ai_llm(
22
+ max_tokens=max_tokens,
23
+ temperature=temperature,
24
+ verbose=verbose,
25
+ streaming=streaming,
26
+ **kwargs,
27
+ )
28
+
29
+
30
+ def has_azure_openai_config():
31
+ """
32
+ Checks if the necessary environment variables for Azure OpenAI are set.
33
+ Returns True if they are set, False otherwise.
34
+ """
35
+ return all(
36
+ key in os.environ
37
+ for key in [
38
+ "AZURE_OPENAI_API_BASE_URL",
39
+ "AZURE_OPENAI_API_VERSION",
40
+ "AZURE_OPENAI_API_DEPLOYMENT_NAME",
41
+ "AZURE_OPENAI_API_KEY",
42
+ ]
43
+ )
44
+
45
+
46
+ def get_open_ai_llm(**kwargs):
47
+ return ChatOpenAI(**kwargs)
48
+
49
+
50
+ def get_azure_llm(**kwargs):
51
  llm = AzureChatOpenAI(
52
  openai_api_base=os.environ["AZURE_OPENAI_API_BASE_URL"],
53
  openai_api_version=os.environ["AZURE_OPENAI_API_VERSION"],
54
  deployment_name=os.environ["AZURE_OPENAI_API_DEPLOYMENT_NAME"],
55
  openai_api_key=os.environ["AZURE_OPENAI_API_KEY"],
56
  openai_api_type="azure",
 
 
 
 
57
  **kwargs,
58
  )
59
  return llm
climateqa/{logging.py → qa_logging.py} RENAMED
@@ -2,9 +2,6 @@ import datetime
2
  import json
3
  import os
4
 
5
- from azure.storage.fileshare import ShareServiceClient
6
-
7
-
8
  def log(question, history, docs, user_id):
9
  if has_blob_config():
10
  log_in_azure(question, history, docs, user_id)
@@ -49,6 +46,8 @@ def get_azure_blob_client():
49
  }
50
  account_url = os.environ["BLOB_ACCOUNT_URL"]
51
  file_share_name = "climategpt"
 
 
52
  service = ShareServiceClient(account_url=account_url, credential=credential)
53
  share_client = service.get_share_client(file_share_name)
54
  return share_client
 
2
  import json
3
  import os
4
 
 
 
 
5
  def log(question, history, docs, user_id):
6
  if has_blob_config():
7
  log_in_azure(question, history, docs, user_id)
 
46
  }
47
  account_url = os.environ["BLOB_ACCOUNT_URL"]
48
  file_share_name = "climategpt"
49
+ # Imported here rather than at module level: the top-level import causes an error when running build_index.py (root cause unknown).
50
+ from azure.storage.fileshare import ShareServiceClient
51
  service = ShareServiceClient(account_url=account_url, credential=credential)
52
  share_client = service.get_share_client(file_share_name)
53
  return share_client
climateqa/retriever.py CHANGED
@@ -1,56 +1,65 @@
1
  # https://github.com/langchain-ai/langchain/issues/8623
2
 
3
- import pandas as pd
4
 
5
  from langchain.schema.retriever import BaseRetriever, Document
6
- from langchain.vectorstores.base import VectorStoreRetriever
7
  from langchain.vectorstores import VectorStore
8
- from langchain.callbacks.manager import CallbackManagerForRetrieverRun
9
  from typing import List
10
- from pydantic import Field
11
 
12
 
13
- class ClimateQARetriever(BaseRetriever):
 
 
 
 
14
  vectorstore: VectorStore
15
- sources: list = ["IPCC", "IPBES"]
16
  threshold: float = 22
17
- k_summary: int = 3
18
  k_total: int = 10
19
  namespace: str = "vectors"
20
 
21
  def get_relevant_documents(self, query: str) -> List[Document]:
22
  # Check if all elements in the list are either IPCC or IPBES
23
  assert isinstance(self.sources, list)
24
- assert all([x in ["IPCC", "IPBES"] for x in self.sources])
25
  assert self.k_total > self.k_summary, "k_total should be greater than k_summary"
26
 
 
27
  # Prepare base search kwargs
28
- filters = {
29
- "source": {"$in": self.sources},
30
- }
31
-
32
- # Search for k_summary documents in the summaries dataset
33
- filters_summaries = {
34
- **filters,
35
- "report_type": {"$in": ["SPM", "TS"]},
36
- }
37
- docs_summaries = self.vectorstore.similarity_search_with_score(
38
- query=query,
39
- namespace=self.namespace,
40
- filter=filters_summaries,
41
- k=self.k_summary,
42
- )
43
- docs_summaries = [x for x in docs_summaries if x[1] > self.threshold]
 
 
 
 
44
 
45
  # Search for k_total - k_summary documents in the full reports dataset
46
- filters_full = {
47
- **filters,
48
- "report_type": {"$nin": ["SPM", "TS"]},
49
- }
50
  k_full = self.k_total - len(docs_summaries)
51
  docs_full = self.vectorstore.similarity_search_with_score(
52
- query=query, namespace=self.namespace, filter=filters_full, k=k_full
 
 
 
53
  )
 
54
 
55
  # Concatenate documents
56
  docs = docs_summaries + docs_full
@@ -71,81 +80,13 @@ class ClimateQARetriever(BaseRetriever):
71
 
72
  return results
73
 
74
-
75
- # def filter_summaries(df,k_summary = 3,k_total = 10):
76
- # # assert source in ["IPCC","IPBES","ALL"], "source arg should be in (IPCC,IPBES,ALL)"
77
-
78
- # # # Filter by source
79
- # # if source == "IPCC":
80
- # # df = df.loc[df["source"]=="IPCC"]
81
- # # elif source == "IPBES":
82
- # # df = df.loc[df["source"]=="IPBES"]
83
- # # else:
84
- # # pass
85
-
86
- # # Separate summaries and full reports
87
- # df_summaries = df.loc[df["report_type"].isin(["SPM","TS"])]
88
- # df_full = df.loc[~df["report_type"].isin(["SPM","TS"])]
89
-
90
- # # Find passages from summaries dataset
91
- # passages_summaries = df_summaries.head(k_summary)
92
-
93
- # # Find passages from full reports dataset
94
- # passages_fullreports = df_full.head(k_total - len(passages_summaries))
95
-
96
- # # Concatenate passages
97
- # passages = pd.concat([passages_summaries,passages_fullreports],axis = 0,ignore_index = True)
98
- # return passages
99
-
100
-
101
- # def retrieve_with_summaries(query,retriever,k_summary = 3,k_total = 10,sources = ["IPCC","IPBES"],max_k = 100,threshold = 0.555,as_dict = True,min_length = 300):
102
- # assert max_k > k_total
103
-
104
- # validated_sources = ["IPCC","IPBES"]
105
- # sources = [x for x in sources if x in validated_sources]
106
- # filters = {
107
- # "source": { "$in": sources },
108
- # }
109
- # print(filters)
110
-
111
- # # Retrieve documents
112
- # docs = retriever.retrieve(query,top_k = max_k,filters = filters)
113
-
114
- # # Filter by score
115
- # docs = [{**x.meta,"score":x.score,"content":x.content} for x in docs if x.score > threshold]
116
-
117
- # if len(docs) == 0:
118
- # return []
119
- # res = pd.DataFrame(docs)
120
- # passages_df = filter_summaries(res,k_summary,k_total)
121
- # if as_dict:
122
- # contents = passages_df["content"].tolist()
123
- # meta = passages_df.drop(columns = ["content"]).to_dict(orient = "records")
124
- # passages = []
125
- # for i in range(len(contents)):
126
- # passages.append({"content":contents[i],"meta":meta[i]})
127
- # return passages
128
- # else:
129
- # return passages_df
130
-
131
-
132
- # def retrieve(query,sources = ["IPCC"],threshold = 0.555,k = 10):
133
-
134
-
135
- # print("hellooooo")
136
-
137
- # # Reformulate queries
138
- # reformulated_query,language = reformulate(query)
139
-
140
- # print(reformulated_query)
141
-
142
- # # Retrieve documents
143
- # passages = retrieve_with_summaries(reformulated_query,retriever,k_total = k,k_summary = 3,as_dict = True,sources = sources,threshold = threshold)
144
- # response = {
145
- # "query":query,
146
- # "reformulated_query":reformulated_query,
147
- # "language":language,
148
- # "sources":passages,
149
- # "prompts":{"init_prompt":init_prompt,"sources_prompt":sources_prompt},
150
- # }
151
- # return response
 
1
  # https://github.com/langchain-ai/langchain/issues/8623
2
 
 
3
 
4
  from langchain.schema.retriever import BaseRetriever, Document
 
5
  from langchain.vectorstores import VectorStore
6
+ from langchain.vectorstores import Chroma
7
  from typing import List
 
8
 
9
 
10
+ ## Report types treated as summaries; the idea is that summary documents are easier to exploit
11
+ SUMMARY_TYPES = []
12
+
13
+
14
+ class QARetriever(BaseRetriever):
15
  vectorstore: VectorStore
16
+ sources: list = []
17
  threshold: float = 22
18
+ k_summary: int = 0
19
  k_total: int = 10
20
  namespace: str = "vectors"
21
 
22
  def get_relevant_documents(self, query: str) -> List[Document]:
23
  # Check that sources is a list
24
  assert isinstance(self.sources, list)
 
25
  assert self.k_total > self.k_summary, "k_total should be greater than k_summary"
26
 
27
+ query = "He who can bear the misfortune of a nation is called the ruler of the world."
28
  # Prepare base search kwargs
29
+ filters = {}
30
+ if len(self.sources):
31
+ filters["source"] = {"$in": self.sources}
32
+
33
+ if self.k_summary > 0:
34
+ # Search for k_summary documents in the summaries dataset
35
+ if len(SUMMARY_TYPES):
36
+ filters_summaries = {
37
+ **filters_summaries,
38
+ "report_type": {"$in": SUMMARY_TYPES},
39
+ }
40
+ docs_summaries = self.vectorstore.similarity_search_with_score(
41
+ query=query,
42
+ # namespace=self.namespace,
43
+ filter=self.format_filter(filters_summaries),
44
+ k=self.k_summary,
45
+ )
46
+ docs_summaries = [x for x in docs_summaries if x[1] > self.threshold]
47
+ else:
48
+ docs_summaries = []
49
 
50
  # Search for k_total - k_summary documents in the full reports dataset
51
+ filters_full = {}
52
+ if len(SUMMARY_TYPES):
53
+ filters_full = {**filters_full, "report_type": {"$nin": SUMMARY_TYPES}}
54
+
55
  k_full = self.k_total - len(docs_summaries)
56
  docs_full = self.vectorstore.similarity_search_with_score(
57
+ query=query,
58
+ # namespace=self.namespace,
59
+ filter=self.format_filter(filters_full),
60
+ k=k_full,
61
  )
62
+ print("docs_full", docs_full)
63
 
64
  # Concatenate documents
65
  docs = docs_summaries + docs_full
 
80
 
81
  return results
82
 
83
+ def format_filter(self, filters):
84
+ # https://docs.trychroma.com/usage-guide#using-logical-operators
85
+ if isinstance(self.vectorstore, Chroma):
86
+ if len(filters) <= 1:
87
+ return filters
88
+ and_filters = []
89
+ for field, condition in filters.items():
90
+ and_filters.append({field: condition})
91
+ return {"$and": and_filters}
92
+ return filters
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
climateqa/vectorstore.py CHANGED
@@ -3,9 +3,8 @@
3
  # And https://python.langchain.com/docs/integrations/vectorstores/pinecone
4
  import os
5
  import pinecone
6
- from langchain.vectorstores import Pinecone
7
 
8
- # LOAD ENVIRONMENT VARIABLES
9
  try:
10
  from dotenv import load_dotenv
11
 
@@ -14,7 +13,30 @@ except:
14
  pass
15
 
16
 
17
- def get_pinecone_vectorstore(embeddings, text_key="content"):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  # initialize pinecone
19
  pinecone.init(
20
  api_key=os.getenv("PINECONE_API_KEY"), # find at app.pinecone.io
@@ -23,6 +45,6 @@ def get_pinecone_vectorstore(embeddings, text_key="content"):
23
 
24
  index_name = os.getenv("PINECONE_API_INDEX")
25
  vectorstore = Pinecone.from_existing_index(
26
- index_name, embeddings, text_key=text_key
27
  )
28
  return vectorstore
 
3
  # And https://python.langchain.com/docs/integrations/vectorstores/pinecone
4
  import os
5
  import pinecone
6
+ from langchain.vectorstores import Chroma, Pinecone
7
 
 
8
  try:
9
  from dotenv import load_dotenv
10
 
 
13
  pass
14
 
15
 
16
+ def get_vectorstore(embeddings_function):
17
+ if has_pinecone_config():
18
+ return get_pinecone_vectorstore(embeddings_function)
19
+ return get_chroma_vectore_store(embeddings_function)
20
+
21
+
22
+ def get_chroma_vectore_store(embedding_function):
23
+ return Chroma(
24
+ persist_directory="./chroma_db", embedding_function=embedding_function
25
+ )
26
+
27
+
28
+ def has_pinecone_config():
29
+ return all(
30
+ key in os.environ
31
+ for key in [
32
+ "PINECONE_API_KEY",
33
+ "PINECONE_API_ENVIRONMENT",
34
+ "PINECONE_API_INDEX",
35
+ ]
36
+ )
37
+
38
+
39
+ def get_pinecone_vectorstore(embeddings_function, text_key="content"):
40
  # initialize pinecone
41
  pinecone.init(
42
  api_key=os.getenv("PINECONE_API_KEY"), # find at app.pinecone.io
 
45
 
46
  index_name = os.getenv("PINECONE_API_INDEX")
47
  vectorstore = Pinecone.from_existing_index(
48
+ index_name, embeddings_function, text_key=text_key
49
  )
50
  return vectorstore
constitution.pdf ADDED
Binary file (414 kB). View file
 
data/daoism/tao-te-ching.pdf ADDED
Binary file (174 kB). View file
 
data/us-founding/constitution.pdf ADDED
Binary file (414 kB). View file
 
data/us-founding/declaration-of-independance.pdf ADDED
Binary file (742 kB). View file
 
declaration-of-independance.pdf ADDED
Binary file (742 kB). View file
 
requirements.txt CHANGED
@@ -1,7 +1,9 @@
1
- gradio==3.47.1
2
- openai==0.27.0
3
  azure-storage-file-share==12.11.1
4
- python-dotenv==1.0.0
 
5
  langchain==0.0.295
 
6
  pinecone-client==2.2.1
 
 
7
  sentence-transformers==2.2.2
 
 
 
1
  azure-storage-file-share==12.11.1
2
+ chromadb==0.4.14
3
+ gradio==3.47.1
4
  langchain==0.0.295
5
+ openai==0.27.0
6
  pinecone-client==2.2.1
7
+ pypdf==3.16.4
8
+ python-dotenv==1.0.0
9
  sentence-transformers==2.2.2
utils.py CHANGED
@@ -1,6 +1,3 @@
1
- import numpy as np
2
- import random
3
- import string
4
  import uuid
5
 
6
 
 
 
 
 
1
  import uuid
2
 
3