Commit cc2ce8c by LOUIS SANNA
Parent(s): 6e28a81
feat(loader)

Files changed:
- .DS_Store +0 -0
- .gitattributes +3 -0
- README.md +10 -2
- app.py +14 -14
- chroma_db/5fa47764-2449-49fb-ae2f-0fd1886dfa2d/data_level0.bin +3 -0
- chroma_db/5fa47764-2449-49fb-ae2f-0fd1886dfa2d/header.bin +3 -0
- chroma_db/5fa47764-2449-49fb-ae2f-0fd1886dfa2d/length.bin +3 -0
- chroma_db/5fa47764-2449-49fb-ae2f-0fd1886dfa2d/link_lists.bin +0 -0
- chroma_db/chroma.sqlite3 +3 -0
- climateqa/build_index.py +54 -0
- climateqa/chains.py +2 -1
- climateqa/chat.py +0 -42
- climateqa/custom_retrieval_chain.py +2 -18
- climateqa/embeddings.py +1 -0
- climateqa/llm.py +39 -6
- climateqa/{logging.py → qa_logging.py} +2 -3
- climateqa/retriever.py +48 -107
- climateqa/vectorstore.py +26 -4
- constitution.pdf +0 -0
- data/daoism/tao-te-ching.pdf +0 -0
- data/us-founding/constitution.pdf +0 -0
- data/us-founding/declaration-of-independance.pdf +0 -0
- declaration-of-independance.pdf +0 -0
- requirements.txt +5 -3
- utils.py +0 -3
.DS_Store
ADDED
Binary file (6.15 kB)
.gitattributes
CHANGED
@@ -44,3 +44,6 @@ documents/climate_gpt_v2_only_giec.faiss filter=lfs diff=lfs merge=lfs -text
 documents/climate_gpt_v2.faiss filter=lfs diff=lfs merge=lfs -text
 climateqa_v3.db filter=lfs diff=lfs merge=lfs -text
 climateqa_v3.faiss filter=lfs diff=lfs merge=lfs -text
+data filter=lfs diff=lfs merge=lfs -text
+chroma_db filter=lfs diff=lfs merge=lfs -text
+chroma_db/chroma.sqlite3 filter=lfs diff=lfs merge=lfs -text
README.md
CHANGED
@@ -1,5 +1,5 @@
 ---
-title:
+title: AnythingQ&A
 emoji: 🌍
 colorFrom: blue
 colorTo: red
@@ -9,4 +9,12 @@ app_file: app.py
 pinned: false
 ---
 
-#
+# Anything Q&A
+
+A clone of the amazing https://huggingface.co/spaces/Ekimetrics/climate-question-answering.
+
+## Build vector index
+
+```bash
+python -m climateqa.build_index
+```
app.py
CHANGED
@@ -2,18 +2,18 @@ import gradio as gr
 
 from utils import create_user_id
 
-
 # Langchain
 from langchain.embeddings import HuggingFaceEmbeddings
 from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
 
 # ClimateQ&A imports
+from climateqa.embeddings import EMBEDDING_MODEL_NAME
 from climateqa.llm import get_llm
-from climateqa.logging import log
+from climateqa.qa_logging import log
 from climateqa.chains import load_qa_chain_with_text
 from climateqa.chains import load_reformulation_chain
-from climateqa.vectorstore import get_pinecone_vectorstore
-from climateqa.retriever import ClimateQARetriever
+from climateqa.vectorstore import get_vectorstore
+from climateqa.retriever import QARetriever
 from climateqa.prompts import audience_prompts
 
 # Load environment variables in local mode
@@ -113,13 +113,10 @@ class StreamingGradioCallbackHandler(BaseCallbackHandler):
 
 
 # Create embeddings function and LLM
-embeddings_function = HuggingFaceEmbeddings(
-    model_name="sentence-transformers/multi-qa-mpnet-base-dot-v1"
-)
-
+embeddings_function = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME)
 
 # Create vectorstore and retriever
-vectorstore = get_pinecone_vectorstore(embeddings_function)
+vectorstore = get_vectorstore(embeddings_function)
 
 # ---------------------------------------------------------------------------
 # ClimateQA Streaming
@@ -148,8 +145,8 @@ def fetch_sources(query, sources):
 llm_reformulation = get_llm(
     max_tokens=512, temperature=0.0, verbose=True, streaming=False
 )
-retriever = ClimateQARetriever(
-    vectorstore=vectorstore, sources=…
+retriever = QARetriever(
+    vectorstore=vectorstore, sources=[], k_summary=0, k_total=10
 )
 reformulation_chain = load_reformulation_chain(llm_reformulation)
 
@@ -265,6 +262,11 @@ def answer_bot(query, history, docs, question, language, audience):
 def make_html_source(source, i):
     meta = source.metadata
     content = source.page_content.split(":", 1)[1].strip()
+    link = (
+        f'<a href="{meta["url"]}#page={int(meta["page_number"])}" target="_blank" class="pdf-link"><span role="img" aria-label="Open PDF">🔗</span></a>'
+        if "url" in meta
+        else ""
+    )
     return f"""
     <div class="card">
         <div class="card-content">
@@ -273,9 +275,7 @@ def make_html_source(source, i):
         </div>
         <div class="card-footer">
             <span>{meta['name']}</span>
-            <a …>
-                <span role="img" aria-label="Open PDF">🔗</span>
-            </a>
+            {link}
         </div>
     </div>
    """
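Taken together, the app.py changes swap the hardwired Pinecone setup for the new backend-agnostic helpers. A minimal sketch of the new wiring, using only names that appear in this diff (the query string is a made-up example):

```python
from langchain.embeddings import HuggingFaceEmbeddings

from climateqa.embeddings import EMBEDDING_MODEL_NAME
from climateqa.vectorstore import get_vectorstore
from climateqa.retriever import QARetriever

# The embeddings model name now lives in climateqa/embeddings.py.
embeddings_function = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME)

# get_vectorstore picks Pinecone or the local Chroma store (see vectorstore.py).
vectorstore = get_vectorstore(embeddings_function)

# sources=[] disables source filtering; k_summary=0 skips the summary pass.
retriever = QARetriever(vectorstore=vectorstore, sources=[], k_summary=0, k_total=10)
docs = retriever.get_relevant_documents("What does the text say about rulers?")
```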
chroma_db/5fa47764-2449-49fb-ae2f-0fd1886dfa2d/data_level0.bin
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a13e72541800c513c73dccea69f79e39cf4baef4fa23f7e117c0d6b0f5f99670
+size 3212000

chroma_db/5fa47764-2449-49fb-ae2f-0fd1886dfa2d/header.bin
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0ec6df10978b056a10062ed99efeef2702fa4a1301fad702b53dd2517103c746
+size 100

chroma_db/5fa47764-2449-49fb-ae2f-0fd1886dfa2d/length.bin
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fc19b1997119425765295aeab72d76faa6927d4f83985d328c26f20468d6cc76
+size 4000

chroma_db/5fa47764-2449-49fb-ae2f-0fd1886dfa2d/link_lists.bin
ADDED
File without changes

chroma_db/chroma.sqlite3
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:db081ece29301d223a01bac97e8b2905fada2e7c376cec96bf44fee0f5c95069
+size 1843200
climateqa/build_index.py
ADDED
@@ -0,0 +1,54 @@
+# import
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain.vectorstores import Chroma
+from langchain.embeddings import HuggingFaceEmbeddings
+from langchain.document_loaders import PyPDFLoader
+
+from .embeddings import EMBEDDING_MODEL_NAME
+from .vectorstore import get_vectorstore
+
+
+def load_data():
+    docs = parse_data()
+    embedding_function = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME)
+    vectorstore = get_vectorstore(embedding_function)
+
+    assert isinstance(vectorstore, Chroma)
+    vectorstore.from_documents(
+        docs, embedding_function, persist_directory="./chroma_db"
+    )
+    return vectorstore
+
+
+def parse_data():
+    loader = PyPDFLoader("data/daoism/tao-te-ching.pdf")
+    pages = loader.load_and_split()
+
+    # split it into chunks
+    text_splitter = RecursiveCharacterTextSplitter(chunk_size=250, chunk_overlap=0)
+    docs = text_splitter.split_documents(pages)
+    print(docs)
+    for doc in docs:
+        doc.metadata["name"] = parse_name(doc.metadata["source"])
+        doc.metadata["domain"] = parse_domain(doc.metadata["source"])
+        doc.metadata["page_number"] = doc.metadata["page"]
+        doc.metadata["short_name"] = doc.metadata["name"]
+    return docs
+
+
+def parse_name(source: str) -> str:
+    return source.split("/")[-1].split(".")[0]
+
+
+def parse_domain(source: str) -> str:
+    return source.split("/")[2]
+
+
+if __name__ == "__main__":
+    db = load_data()
+    # query it
+    query = (
+        "He who can bear the misfortune of a nation is called the ruler of the world."
+    )
+    docs = db.similarity_search(query)
+    print(docs)
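Two details of build_index.py are worth noting. `Chroma.from_documents` is a classmethod that builds and persists a new collection, so calling it on the instance returned by `get_vectorstore` creates a second store and discards the return value; and with the literal path `data/daoism/tao-te-ching.pdf`, `source.split("/")[2]` yields the file name rather than the folder. A corrected sketch of the indexing step under those assumptions:

```python
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma

from climateqa.embeddings import EMBEDDING_MODEL_NAME


def build_index(docs):
    # from_documents is a classmethod: it creates, embeds and persists a new
    # collection under persist_directory and returns it.
    embedding_function = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME)
    return Chroma.from_documents(
        docs, embedding_function, persist_directory="./chroma_db"
    )


def parse_domain(source: str) -> str:
    # "data/daoism/tao-te-ching.pdf" -> "daoism" (second path segment)
    return source.split("/")[1]
```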
climateqa/chains.py
CHANGED
@@ -3,7 +3,7 @@
 import json
 
 from langchain import PromptTemplate, LLMChain
-from langchain.chains import …
+from langchain.chains import QAWithSourcesChain
 from langchain.chains import TransformChain, SequentialChain
 from langchain.chains.qa_with_sources import load_qa_with_sources_chain
 
@@ -21,6 +21,7 @@ def load_reformulation_chain(llm):
     # Parse the output
     def parse_output(output):
         query = output["query"]
+        print("output", output)
         json_output = json.loads(output["json"])
         question = json_output.get("question", query)
         language = json_output.get("language", "English")
climateqa/chat.py
DELETED
@@ -1,42 +0,0 @@
-# LANGCHAIN IMPORTS
-from langchain import PromptTemplate, LLMChain
-from langchain.embeddings import HuggingFaceEmbeddings
-from langchain.chains import RetrievalQAWithSourcesChain
-from langchain.chains.qa_with_sources import load_qa_with_sources_chain
-
-
-# CLIMATEQA
-from climateqa.retriever import ClimateQARetriever
-from climateqa.vectorstore import get_pinecone_vectorstore
-from climateqa.chains import load_climateqa_chain
-
-
-class ClimateQA:
-    def __init__(
-        self,
-        hf_embedding_model="sentence-transformers/multi-qa-mpnet-base-dot-v1",
-        show_progress_bar=False,
-        batch_size=1,
-        max_tokens=1024,
-        **kwargs
-    ):
-        self.llm = self.get_llm(max_tokens=max_tokens, **kwargs)
-        self.embeddings_function = HuggingFaceEmbeddings(
-            model_name=hf_embedding_model,
-            encode_kwargs={
-                "show_progress_bar": show_progress_bar,
-                "batch_size": batch_size,
-            },
-        )
-
-    def get_vectorstore(self):
-        pass
-
-    def reformulate(self):
-        pass
-
-    def retrieve(self):
-        pass
-
-    def ask(self):
-        pass
climateqa/custom_retrieval_chain.py
CHANGED
@@ -1,35 +1,19 @@
 from __future__ import annotations
 import inspect
 from typing import Any, Dict
 
-from pydantic import Extra
-
-from langchain.schema.language_model import BaseLanguageModel
 from langchain.callbacks.manager import (
-    AsyncCallbackManagerForChainRun,
     CallbackManagerForChainRun,
 )
-from langchain.chains.base import Chain
-from langchain.prompts.base import BasePromptTemplate
 
 from typing import Any, Dict
 
 from langchain.callbacks.manager import (
-    AsyncCallbackManagerForChainRun,
     CallbackManagerForChainRun,
 )
-from langchain.chains.combine_documents.stuff import StuffDocumentsChain
-from langchain.chains.qa_with_sources.base import BaseQAWithSourcesChain
-from langchain.docstore.document import Document
-from langchain.pydantic_v1 import Field
-from langchain.schema import BaseRetriever
-
 from langchain.chains import RetrievalQAWithSourcesChain
 
 
-from langchain.chains.router.llm_router import LLMRouterChain
-
-
 class CustomRetrievalQAWithSourcesChain(RetrievalQAWithSourcesChain):
     fallback_answer: str = "No sources available to answer this question."
 
climateqa/embeddings.py
ADDED
@@ -0,0 +1 @@
+EMBEDDING_MODEL_NAME = "sentence-transformers/multi-qa-mpnet-base-dot-v1"
climateqa/llm.py
CHANGED
@@ -1,7 +1,6 @@
-from langchain.chat_models import AzureChatOpenAI
+from langchain.chat_models import AzureChatOpenAI, ChatOpenAI
 import os
 
-# LOAD ENVIRONMENT VARIABLES
 try:
     from dotenv import load_dotenv
 
@@ -11,16 +10,50 @@ except:
 
 
 def get_llm(max_tokens=1024, temperature=0.0, verbose=True, streaming=False, **kwargs):
+    if has_azure_openai_config():
+        return get_azure_llm(
+            max_tokens=max_tokens,
+            temperature=temperature,
+            verbose=verbose,
+            streaming=streaming,
+            **kwargs,
+        )
+    return get_open_ai_llm(
+        max_tokens=max_tokens,
+        temperature=temperature,
+        verbose=verbose,
+        streaming=streaming,
+        **kwargs,
+    )
+
+
+def has_azure_openai_config():
+    """
+    Checks if the necessary environment variables for Azure OpenAI are set.
+    Returns True if they are set, False otherwise.
+    """
+    return all(
+        key in os.environ
+        for key in [
+            "AZURE_OPENAI_API_BASE_URL",
+            "AZURE_OPENAI_API_VERSION",
+            "AZURE_OPENAI_API_DEPLOYMENT_NAME",
+            "AZURE_OPENAI_API_KEY",
+        ]
+    )
+
+
+def get_open_ai_llm(**kwargs):
+    return ChatOpenAI(**kwargs)
+
+
+def get_azure_llm(**kwargs):
     llm = AzureChatOpenAI(
         openai_api_base=os.environ["AZURE_OPENAI_API_BASE_URL"],
         openai_api_version=os.environ["AZURE_OPENAI_API_VERSION"],
         deployment_name=os.environ["AZURE_OPENAI_API_DEPLOYMENT_NAME"],
         openai_api_key=os.environ["AZURE_OPENAI_API_KEY"],
         openai_api_type="azure",
-        max_tokens=max_tokens,
-        temperature=temperature,
-        verbose=verbose,
-        streaming=streaming,
         **kwargs,
     )
     return llm
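With this change, callers keep using get_llm unchanged and the backend is selected purely from the environment. For instance (OPENAI_API_KEY is the variable ChatOpenAI reads by default; the key value here is a placeholder):

```python
import os

from climateqa.llm import get_llm, has_azure_openai_config

# Without all four AZURE_OPENAI_* variables set, get_llm falls back to
# ChatOpenAI, which reads OPENAI_API_KEY from the environment.
os.environ.setdefault("OPENAI_API_KEY", "sk-placeholder")
llm = get_llm(max_tokens=256, temperature=0.0)
print(has_azure_openai_config(), type(llm).__name__)  # False ChatOpenAI
```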
climateqa/{logging.py → qa_logging.py}
RENAMED
@@ -2,9 +2,6 @@ import datetime
 import json
 import os
 
-from azure.storage.fileshare import ShareServiceClient
-
-
 def log(question, history, docs, user_id):
     if has_blob_config():
         log_in_azure(question, history, docs, user_id)
@@ -49,6 +46,8 @@ def get_azure_blob_client():
     }
     account_url = os.environ["BLOB_ACCOUNT_URL"]
     file_share_name = "climategpt"
+    # I don't know why this is necessary, but it causes an error otherwise when running build_index.py
+    from azure.storage.fileshare import ShareServiceClient
     service = ShareServiceClient(account_url=account_url, credential=credential)
     share_client = service.get_share_client(file_share_name)
     return share_client
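Both edits in this file address import-time failures: the rename from logging.py presumably avoids the module colliding with the standard-library logging name, and the azure import now happens lazily, only when a blob client is actually requested. A generic sketch of the deferral pattern (argument values are illustrative):

```python
def get_share_client(account_url: str, credential):
    # Deferred import: the optional dependency is only loaded when this code
    # path runs, so importing the module itself (e.g. from build_index.py)
    # never fails because of it.
    from azure.storage.fileshare import ShareServiceClient

    service = ShareServiceClient(account_url=account_url, credential=credential)
    return service.get_share_client("climategpt")
```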
climateqa/retriever.py
CHANGED
@@ -1,56 +1,65 @@
 # https://github.com/langchain-ai/langchain/issues/8623
 
-import pandas as pd
 
 from langchain.schema.retriever import BaseRetriever, Document
-from langchain.vectorstores.base import VectorStoreRetriever
 from langchain.vectorstores import VectorStore
-from langchain.…
+from langchain.vectorstores import Chroma
 from typing import List
-from pydantic import Field
 
 
-class ClimateQARetriever(BaseRetriever):
+## The idea is that some documents are summaries, so they are easier to exploit
+SUMMARY_TYPES = []
+
+
+class QARetriever(BaseRetriever):
     vectorstore: VectorStore
-    sources: list = […
+    sources: list = []
     threshold: float = 22
-    k_summary: int = …
+    k_summary: int = 0
     k_total: int = 10
     namespace: str = "vectors"
 
     def get_relevant_documents(self, query: str) -> List[Document]:
         # Check if all elements in the list are either IPCC or IPBES
         assert isinstance(self.sources, list)
-        assert all([x in ["IPCC", "IPBES"] for x in self.sources])
         assert self.k_total > self.k_summary, "k_total should be greater than k_summary"
 
+        query = "He who can bear the misfortune of a nation is called the ruler of the world."
         # Prepare base search kwargs
-        filters = {
-            …
+        filters = {}
+        if len(self.sources):
+            filters["source"] = {"$in": self.sources}
+
+        if self.k_summary > 0:
+            # Search for k_summary documents in the summaries dataset
+            if len(SUMMARY_TYPES):
+                filters_summaries = {
+                    **filters_summaries,
+                    "report_type": {"$in": SUMMARY_TYPES},
+                }
+            docs_summaries = self.vectorstore.similarity_search_with_score(
+                query=query,
+                # namespace=self.namespace,
+                filter=self.format_filter(filters_summaries),
+                k=self.k_summary,
+            )
+            docs_summaries = [x for x in docs_summaries if x[1] > self.threshold]
+        else:
+            docs_summaries = []
 
         # Search for k_total - k_summary documents in the full reports dataset
-        filters_full = {
-            …
-            "report_type": {"$nin": …
-        }
+        filters_full = {}
+        if len(SUMMARY_TYPES):
+            filters_full = {**filters_full, "report_type": {"$nin": SUMMARY_TYPES}}
+
         k_full = self.k_total - len(docs_summaries)
         docs_full = self.vectorstore.similarity_search_with_score(
-            query=query,
-            …
+            query=query,
+            # namespace=self.namespace,
+            filter=self.format_filter(filters_full),
+            k=k_full,
         )
+        print("docs_full", docs_full)
 
         # Concatenate documents
         docs = docs_summaries + docs_full
@@ -71,81 +80,13 @@ class ClimateQARetriever(BaseRetriever):
 
         return results
 
-
-# …
-
-# # pass
-
-# # Separate summaries and full reports
-# df_summaries = df.loc[df["report_type"].isin(["SPM","TS"])]
-# df_full = df.loc[~df["report_type"].isin(["SPM","TS"])]
-
-# # Find passages from summaries dataset
-# passages_summaries = df_summaries.head(k_summary)
-
-# # Find passages from full reports dataset
-# passages_fullreports = df_full.head(k_total - len(passages_summaries))
-
-# # Concatenate passages
-# passages = pd.concat([passages_summaries,passages_fullreports],axis = 0,ignore_index = True)
-# return passages
-
-
-# def retrieve_with_summaries(query,retriever,k_summary = 3,k_total = 10,sources = ["IPCC","IPBES"],max_k = 100,threshold = 0.555,as_dict = True,min_length = 300):
-#     assert max_k > k_total
-
-#     validated_sources = ["IPCC","IPBES"]
-#     sources = [x for x in sources if x in validated_sources]
-#     filters = {
-#         "source": { "$in": sources },
-#     }
-#     print(filters)
-
-#     # Retrieve documents
-#     docs = retriever.retrieve(query,top_k = max_k,filters = filters)
-
-#     # Filter by score
-#     docs = [{**x.meta,"score":x.score,"content":x.content} for x in docs if x.score > threshold]
-
-#     if len(docs) == 0:
-#         return []
-#     res = pd.DataFrame(docs)
-#     passages_df = filter_summaries(res,k_summary,k_total)
-#     if as_dict:
-#         contents = passages_df["content"].tolist()
-#         meta = passages_df.drop(columns = ["content"]).to_dict(orient = "records")
-#         passages = []
-#         for i in range(len(contents)):
-#             passages.append({"content":contents[i],"meta":meta[i]})
-#         return passages
-#     else:
-#         return passages_df
-
-
-# def retrieve(query,sources = ["IPCC"],threshold = 0.555,k = 10):
-
-#     print("hellooooo")
-
-#     # Reformulate queries
-#     reformulated_query,language = reformulate(query)
-
-#     print(reformulated_query)
-
-#     # Retrieve documents
-#     passages = retrieve_with_summaries(reformulated_query,retriever,k_total = k,k_summary = 3,as_dict = True,sources = sources,threshold = threshold)
-#     response = {
-#         "query":query,
-#         "reformulated_query":reformulated_query,
-#         "language":language,
-#         "sources":passages,
-#         "prompts":{"init_prompt":init_prompt,"sources_prompt":sources_prompt},
-#     }
-#     return response
+    def format_filter(self, filters):
+        # https://docs.trychroma.com/usage-guide#using-logical-operators
+        if isinstance(self.vectorstore, Chroma):
+            if len(filters) <= 1:
+                return filters
+            and_filters = []
+            for field, condition in filters.items():
+                and_filters.append({field: condition})
+            return {"$and": and_filters}
+        return filters
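Two debug leftovers in the new retriever are worth flagging: the hardcoded `query = "He who can bear..."` line overwrites the caller's query, and the summary branch spreads `**filters_summaries` before that name is bound (harmless today only because `SUMMARY_TYPES` is empty and `k_summary` defaults to 0; it presumably should spread `**filters`). The new `format_filter` helper exists because Chroma rejects a `where` filter with more than one top-level key; conditions have to be combined with an explicit `$and`. A small illustration of the transformation it performs:

```python
# With two metadata conditions, Chroma needs an explicit $and wrapper
# (https://docs.trychroma.com/usage-guide#using-logical-operators).
filters = {"source": {"$in": ["IPCC"]}, "report_type": {"$nin": ["SPM", "TS"]}}
chroma_where = {"$and": [{field: cond} for field, cond in filters.items()]}
# {'$and': [{'source': {'$in': ['IPCC']}}, {'report_type': {'$nin': ['SPM', 'TS']}}]}

# A single condition passes through unchanged:
single = {"source": {"$in": ["IPCC"]}}  # already valid as-is
```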
climateqa/vectorstore.py
CHANGED
@@ -3,9 +3,8 @@
 # And https://python.langchain.com/docs/integrations/vectorstores/pinecone
 import os
 import pinecone
-from langchain.vectorstores import Pinecone
+from langchain.vectorstores import Chroma, Pinecone
 
-# LOAD ENVIRONMENT VARIABLES
 try:
     from dotenv import load_dotenv
 
@@ -14,7 +13,30 @@ except:
     pass
 
 
-def get_pinecone_vectorstore(embeddings, text_key="content"):
+def get_vectorstore(embeddings_function):
+    if has_pinecone_config():
+        return get_pinecone_vectorstore(embeddings_function)
+    return get_chroma_vectore_store(embeddings_function)
+
+
+def get_chroma_vectore_store(embedding_function):
+    return Chroma(
+        persist_directory="./chroma_db", embedding_function=embedding_function
+    )
+
+
+def has_pinecone_config():
+    return all(
+        key in os.environ
+        for key in [
+            "PINECONE_API_KEY",
+            "PINECONE_API_ENVIRONMENT",
+            "PINECONE_API_INDEX",
+        ]
+    )
+
+
+def get_pinecone_vectorstore(embeddings_function, text_key="content"):
     # initialize pinecone
     pinecone.init(
         api_key=os.getenv("PINECONE_API_KEY"),  # find at app.pinecone.io
@@ -23,6 +45,6 @@ def get_pinecone_vectorstore(embeddings, text_key="content"):
 
     index_name = os.getenv("PINECONE_API_INDEX")
     vectorstore = Pinecone.from_existing_index(
-        index_name,
+        index_name, embeddings_function, text_key=text_key
     )
     return vectorstore
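With this fallback, local development needs no Pinecone account: when the three PINECONE_* variables are absent, get_vectorstore opens the Chroma collection that build_index persisted under ./chroma_db. A quick smoke test under that assumption:

```python
from langchain.embeddings import HuggingFaceEmbeddings

from climateqa.embeddings import EMBEDDING_MODEL_NAME
from climateqa.vectorstore import get_vectorstore

embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME)
store = get_vectorstore(embeddings)  # Chroma unless PINECONE_* vars are set
print(store.similarity_search("ruler of the world", k=2))
```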
constitution.pdf
ADDED
Binary file (414 kB)

data/daoism/tao-te-ching.pdf
ADDED
Binary file (174 kB)

data/us-founding/constitution.pdf
ADDED
Binary file (414 kB)

data/us-founding/declaration-of-independance.pdf
ADDED
Binary file (742 kB)

declaration-of-independance.pdf
ADDED
Binary file (742 kB)
requirements.txt
CHANGED
@@ -1,7 +1,9 @@
-gradio==3.47.1
-openai==0.27.0
 azure-storage-file-share==12.11.1
-…
+chromadb==0.4.14
+gradio==3.47.1
 langchain==0.0.295
+openai==0.27.0
 pinecone-client==2.2.1
+pypdf==3.16.4
+python-dotenv==1.0.0
 sentence-transformers==2.2.2
utils.py
CHANGED
@@ -1,6 +1,3 @@
-import numpy as np
-import random
-import string
 import uuid
 
 