big upload
Browse files- Dockerfile +17 -0
- README.md +8 -1
- app/api/__init__.py +0 -0
- app/api/routers/__init__.py +0 -0
- app/engine/__init__.py +0 -0
- app/engine/chunk_embed.py +90 -0
- app/engine/llm.py +0 -0
- app/engine/loaders/__init__.py +0 -0
- app/engine/loaders/file.py +105 -0
- app/engine/logger.py +10 -0
- app/engine/processing.py +48 -0
- app/engine/vectorstore.py +178 -0
- app/engine/weaviate_interface_v4.py +526 -0
- app/main.py +231 -0
- app/notebooks/chunking_indexing.ipynb +0 -0
- app/notebooks/lite_lll.ipynb +158 -0
- app/notebooks/pdf_readers.ipynb +0 -0
- app/notebooks/upload_index.ipynb +0 -0
- app/notebooks/weaviate.ipynb +0 -0
- app/rag/__init__.py +0 -0
- app/rag/llm.py +149 -0
- app/rag/rag.py +56 -0
- app/requirements.txt +20 -0
- app/settings.py +4 -0
Dockerfile
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
FROM python:3.10-slim
|
2 |
+
|
3 |
+
ENV PYTHONDONTWRITEBYTECODE 1
|
4 |
+
# ^ saves space by not writing .pyc files
|
5 |
+
ENV PYTHONUNBUFFERED 1
|
6 |
+
# ^ ensures that the output from the Python app is sent straight to the terminal without being buffered -> real time monitoring
|
7 |
+
|
8 |
+
ENV ENVIRONMENT=dev
|
9 |
+
|
10 |
+
COPY ./app /app
|
11 |
+
WORKDIR /app
|
12 |
+
RUN mkdir /data
|
13 |
+
|
14 |
+
RUN pip install --no-cache-dir --upgrade -r requirements.txt
|
15 |
+
# ^ no caching of the packages to save space
|
16 |
+
|
17 |
+
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
|
README.md
CHANGED
@@ -1,4 +1,11 @@
|
|
1 |
---
|
2 |
-
license: mit
|
3 |
title: FinRAG
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
---
|
|
|
|
|
|
1 |
---
|
|
|
2 |
title: FinRAG
|
3 |
+
emoji: 🐢
|
4 |
+
colorFrom: blue
|
5 |
+
colorTo: blue
|
6 |
+
sdk: docker
|
7 |
+
pinned: false
|
8 |
+
license: mit
|
9 |
---
|
10 |
+
|
11 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app/api/__init__.py
ADDED
File without changes
|
app/api/routers/__init__.py
ADDED
File without changes
|
app/engine/__init__.py
ADDED
File without changes
|
app/engine/chunk_embed.py
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
|
3 |
+
import os
|
4 |
+
import pandas as pd
|
5 |
+
import torch
|
6 |
+
|
7 |
+
from settings import parquet_file
|
8 |
+
|
9 |
+
import tiktoken # tokenizer library for use with OpenAI LLMs
|
10 |
+
from llama_index.legacy.text_splitter import SentenceSplitter
|
11 |
+
from sentence_transformers import SentenceTransformer
|
12 |
+
|
13 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
14 |
+
|
15 |
+
# create tensors on GPU if available
|
16 |
+
if torch.cuda.is_available():
|
17 |
+
torch.set_default_tensor_type('torch.cuda.FloatTensor')
|
18 |
+
|
19 |
+
|
20 |
+
def chunk_vectorize(doc_content: dict = None,
|
21 |
+
chunk_size: int = 256, # limit for 'all-mpnet-base-v2'
|
22 |
+
chunk_overlap: int = 20, # some overlap to link the chunks
|
23 |
+
encoder: str = 'gpt-3.5-turbo-0613',
|
24 |
+
model_name: str = 'sentence-transformers/all-mpnet-base-v2'): # can try all-MiniLM-L6-v2
|
25 |
+
# see tests in chunking_indexing.ipynb for more details
|
26 |
+
|
27 |
+
encoding = tiktoken.encoding_for_model(encoder)
|
28 |
+
|
29 |
+
splitter = SentenceSplitter(chunk_size=chunk_size,
|
30 |
+
tokenizer=encoding.encode,
|
31 |
+
chunk_overlap=chunk_overlap)
|
32 |
+
|
33 |
+
# let's create the splits for every document
|
34 |
+
contents_splits = {}
|
35 |
+
for fname, content in doc_content.items():
|
36 |
+
splits = [splitter.split_text(page) for page in content]
|
37 |
+
contents_splits[fname] = [split for sublist in splits for split in sublist]
|
38 |
+
|
39 |
+
model = SentenceTransformer(model_name)
|
40 |
+
|
41 |
+
content_emb = {}
|
42 |
+
for fname, splits in contents_splits.items():
|
43 |
+
content_emb[fname] = [(split, model.encode(split)) for split in splits]
|
44 |
+
|
45 |
+
# save fname since it carries information, and could be used as a property in Weaviate
|
46 |
+
text_vector_tuples = [(fname, split, emb.tolist()) for fname, splits_emb in content_emb.items() for split, emb in splits_emb]
|
47 |
+
|
48 |
+
new_df = pd.DataFrame(
|
49 |
+
text_vector_tuples,
|
50 |
+
columns=['file', 'content', 'content_embedding']
|
51 |
+
)
|
52 |
+
|
53 |
+
# load the existing parquet file if it exists and update it
|
54 |
+
if os.path.exists(parquet_file):
|
55 |
+
new_df = pd.concat([pd.read_parquet(parquet_file), new_df])
|
56 |
+
|
57 |
+
# no optimization here (zipping etc) since the data is small
|
58 |
+
new_df.to_parquet(parquet_file, index=False)
|
59 |
+
|
60 |
+
return
|
61 |
+
|
62 |
+
# TODO
|
63 |
+
# import unittest
|
64 |
+
# from unitesting_utils import load_impact_theory_data
|
65 |
+
|
66 |
+
# class TestSplitContents(unittest.TestCase):
|
67 |
+
# '''
|
68 |
+
# Unit test to ensure proper functionality of split_contents function
|
69 |
+
# '''
|
70 |
+
|
71 |
+
# def test_split_contents(self):
|
72 |
+
# import tiktoken
|
73 |
+
# from llama_index.text_splitter import SentenceSplitter
|
74 |
+
|
75 |
+
# data = load_impact_theory_data()
|
76 |
+
|
77 |
+
# subset = data[:3]
|
78 |
+
# chunk_size = 256
|
79 |
+
# chunk_overlap = 0
|
80 |
+
# encoding = tiktoken.encoding_for_model('gpt-3.5-turbo-0613')
|
81 |
+
# gpt35_txt_splitter = SentenceSplitter(chunk_size=chunk_size, tokenizer=encoding.encode, chunk_overlap=chunk_overlap)
|
82 |
+
# results = split_contents(subset, gpt35_txt_splitter)
|
83 |
+
# self.assertEqual(len(results), 3)
|
84 |
+
# self.assertEqual(len(results[0]), 83)
|
85 |
+
# self.assertEqual(len(results[1]), 178)
|
86 |
+
# self.assertEqual(len(results[2]), 144)
|
87 |
+
# self.assertTrue(isinstance(results, list))
|
88 |
+
# self.assertTrue(isinstance(results[0], list))
|
89 |
+
# self.assertTrue(isinstance(results[0][0], str))
|
90 |
+
# unittest.TextTestRunner().run(unittest.TestLoader().loadTestsFromTestCase(TestSplitContents))
|
app/engine/llm.py
ADDED
File without changes
|
app/engine/loaders/__init__.py
ADDED
File without changes
|
app/engine/loaders/file.py
ADDED
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
# from langchain.document_loaders import PyPDFLoader # deprecated
|
4 |
+
from langchain_community.document_loaders import PyPDFLoader
|
5 |
+
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
6 |
+
from llama_parse import LlamaParse
|
7 |
+
|
8 |
+
from typing import Union, List, Dict
|
9 |
+
|
10 |
+
from abc import ABC, abstractmethod
|
11 |
+
|
12 |
+
class PDFExtractor(ABC):
|
13 |
+
|
14 |
+
def __init__(self, file_or_list: Union[str, List[str]], num_workers: int = 1, verbose: bool = False):
|
15 |
+
""" We can provide a list of files or a single file """
|
16 |
+
if isinstance(file_or_list, str):
|
17 |
+
self.filelist = [file_or_list]
|
18 |
+
else:
|
19 |
+
self.filelist = file_or_list
|
20 |
+
self.num_workers = num_workers
|
21 |
+
self.verbose = verbose
|
22 |
+
super().__init__()
|
23 |
+
|
24 |
+
@abstractmethod
|
25 |
+
def extract_text(self) -> Dict[str, List[str]]:
|
26 |
+
""" Extracts text from the PDF, no processing.
|
27 |
+
Return a dictionary, key = filename, value = list of strings, one for each page.
|
28 |
+
"""
|
29 |
+
pass
|
30 |
+
|
31 |
+
@abstractmethod
|
32 |
+
def extract_images(self):
|
33 |
+
"""Extracts images from the PDF, no processing."""
|
34 |
+
pass
|
35 |
+
|
36 |
+
@abstractmethod
|
37 |
+
def extract_tables(self):
|
38 |
+
""" Extracts tables from the PDF, no processing.
|
39 |
+
Return in json format
|
40 |
+
"""
|
41 |
+
pass
|
42 |
+
|
43 |
+
class _PyPDFLoader(PDFExtractor):
|
44 |
+
|
45 |
+
def extract_text(self):
|
46 |
+
output_dict = {}
|
47 |
+
for fpath in self.filelist:
|
48 |
+
fname = fpath.split('/')[-1]
|
49 |
+
output_dict[fname] = [p.page_content for p in PyPDFLoader(fpath).load()]
|
50 |
+
return output_dict
|
51 |
+
|
52 |
+
def extract_images(self):
|
53 |
+
raise NotImplementedError("Not implemented or PyPDFLoader does not support image extraction")
|
54 |
+
return
|
55 |
+
|
56 |
+
def extract_tables(self):
|
57 |
+
raise NotImplementedError("Not implemented or PyPDFLoader does not support table extraction")
|
58 |
+
return
|
59 |
+
|
60 |
+
|
61 |
+
class _LlamaParse(PDFExtractor):
|
62 |
+
|
63 |
+
def extract_text(self):
|
64 |
+
# https://github.com/run-llama/llama_parse
|
65 |
+
if os.getenv("LLAMA_PARSE_API_KEY") is None:
|
66 |
+
raise ValueError("LLAMA_PARSE_API_KEY is not set.")
|
67 |
+
|
68 |
+
parser = LlamaParse(
|
69 |
+
api_key = os.getenv("LLAMA_PARSE_API_KEY"),
|
70 |
+
num_workers=self.num_workers,
|
71 |
+
verbose=self.verbose,
|
72 |
+
language="en",
|
73 |
+
result_type="text" # or "markdown"
|
74 |
+
)
|
75 |
+
output_dict = {}
|
76 |
+
for fpath in self.filelist:
|
77 |
+
# https://github.com/run-llama/llama_parse/blob/main/examples/demo_json.ipynb
|
78 |
+
docs = parser.get_json_result(fpath)
|
79 |
+
docs[0]['pages'][0]['text']
|
80 |
+
output_dict[fpath] = None
|
81 |
+
return output_dict
|
82 |
+
|
83 |
+
def extract_images(self):
|
84 |
+
raise NotImplementedError("Not implemented or LlamaParse does not support image extraction")
|
85 |
+
return
|
86 |
+
|
87 |
+
def extract_tables(self):
|
88 |
+
raise NotImplementedError("Not implemented or LlamaParse does not support table extraction")
|
89 |
+
return
|
90 |
+
|
91 |
+
|
92 |
+
def pdf_extractor(extractor_type: str, *args, **kwargs) -> PDFExtractor:
|
93 |
+
""" Factory function to return the appropriate PDF extractor instance, properly initialized """
|
94 |
+
|
95 |
+
if extractor_type == 'PyPDFLoader':
|
96 |
+
return _PyPDFLoader(*args, **kwargs)
|
97 |
+
|
98 |
+
elif extractor_type == 'LlamaParse':
|
99 |
+
return _LlamaParse(*args, **kwargs)
|
100 |
+
else:
|
101 |
+
raise ValueError(f"Unsupported PDF extractor type: {extractor_type}")
|
102 |
+
|
103 |
+
|
104 |
+
|
105 |
+
|
app/engine/logger.py
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, logging
|
2 |
+
|
3 |
+
environment = os.getenv("ENVIRONMENT", "dev") # TODO put the logger creation in its own file
|
4 |
+
if environment == "dev":
|
5 |
+
logger = logging.getLogger("uvicorn")
|
6 |
+
else:
|
7 |
+
logger = lambda x: _
|
8 |
+
# we should log also in production TODO
|
9 |
+
# check how it works on HuggingFace, if possible
|
10 |
+
# because we don't have access to the container's file system
|
app/engine/processing.py
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, pickle
|
2 |
+
from typing import List
|
3 |
+
from engine.loaders.file import pdf_extractor
|
4 |
+
from engine.chunk_embed import chunk_vectorize
|
5 |
+
from settings import parquet_file
|
6 |
+
from .logger import logger
|
7 |
+
from .vectorstore import VectorStore
|
8 |
+
# I allow relative imports inside the engine package
|
9 |
+
# I could have created a module but things are still changing
|
10 |
+
|
11 |
+
finrag_vectorstore = VectorStore(model_path='sentence-transformers/all-mpnet-base-v2')
|
12 |
+
|
13 |
+
|
14 |
+
def empty_collection():
|
15 |
+
""" Deletes the Finrag collection if it exists """
|
16 |
+
status = finrag_vectorstore.empty_collection()
|
17 |
+
return status
|
18 |
+
|
19 |
+
|
20 |
+
def index_data():
|
21 |
+
|
22 |
+
if not os.path.exists(parquet_file):
|
23 |
+
logger.info(f"Parquet file {parquet_file} does not exists")
|
24 |
+
return 'no data to index'
|
25 |
+
|
26 |
+
# load the parquet file into the vectorstore
|
27 |
+
finrag_vectorstore.index_data()
|
28 |
+
os.remove(parquet_file)
|
29 |
+
# delete the files so we can load several files and index them when we want
|
30 |
+
# without having to keep track of those that have been indexed already
|
31 |
+
# this is a simple solution for now, but we can do better
|
32 |
+
|
33 |
+
return "Index creation successful"
|
34 |
+
|
35 |
+
|
36 |
+
def process_pdf(filepath:str) -> dict:
|
37 |
+
|
38 |
+
new_content = pdf_extractor('PyPDFLoader', filepath).extract_text()
|
39 |
+
logger.info(f"Successfully extracted text from PDF")
|
40 |
+
|
41 |
+
chunk_vectorize(new_content)
|
42 |
+
logger.info(f"Successfully vectorized PDF content")
|
43 |
+
return new_content
|
44 |
+
|
45 |
+
def process_question(question:str) -> List[str]:
|
46 |
+
|
47 |
+
ans = finrag_vectorstore.hybrid_search(query=question, limit=3, alpha=0.8)
|
48 |
+
return ans
|
app/engine/vectorstore.py
ADDED
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, logging
|
2 |
+
from typing import List, Any
|
3 |
+
import pandas as pd
|
4 |
+
from weaviate.classes.config import Property, DataType
|
5 |
+
|
6 |
+
from .weaviate_interface_v4 import WeaviateWCS, WeaviateIndexer
|
7 |
+
from .logger import logger
|
8 |
+
|
9 |
+
from settings import parquet_file
|
10 |
+
|
11 |
+
class VectorStore:
|
12 |
+
def __init__(self, model_path:str = 'sentence-transformers/all-mpnet-base-v2'):
|
13 |
+
# we can create several instances to test various models, especially if we finetune one
|
14 |
+
|
15 |
+
self.finrag_properties = [
|
16 |
+
Property(name='filename',
|
17 |
+
data_type=DataType.TEXT,
|
18 |
+
description='Name of the file',
|
19 |
+
index_filterable=True,
|
20 |
+
index_searchable=True),
|
21 |
+
# Property(name='keywords',
|
22 |
+
# data_type=DataType.TEXT_ARRAY,
|
23 |
+
# description='Keywords associated with the file',
|
24 |
+
# index_filterable=True,
|
25 |
+
# index_searchable=True),
|
26 |
+
Property(name='content',
|
27 |
+
data_type=DataType.TEXT,
|
28 |
+
description='Splits of the article',
|
29 |
+
index_filterable=True,
|
30 |
+
index_searchable=True),
|
31 |
+
]
|
32 |
+
|
33 |
+
self.class_name = "FinRag_all-mpnet-base-v2"
|
34 |
+
|
35 |
+
self.class_config = {'classes': [
|
36 |
+
|
37 |
+
{"class": self.class_name,
|
38 |
+
|
39 |
+
"description": "Financial reports",
|
40 |
+
|
41 |
+
"vectorIndexType": "hnsw",
|
42 |
+
|
43 |
+
# Vector index specific settings for HSNW
|
44 |
+
"vectorIndexConfig": {
|
45 |
+
|
46 |
+
"ef": 64, # higher is better quality vs slower search
|
47 |
+
"efConstruction": 128, # higher = better index but slower build
|
48 |
+
"maxConnections": 32, # max conn per layer - higher = more memory
|
49 |
+
},
|
50 |
+
|
51 |
+
"vectorizer": "none",
|
52 |
+
|
53 |
+
"properties": self.finrag_properties }
|
54 |
+
]
|
55 |
+
}
|
56 |
+
|
57 |
+
self.model_path = model_path
|
58 |
+
try:
|
59 |
+
self.api_key = os.environ['FINRAG_WEAVIATE_API_KEY']
|
60 |
+
self.url = os.environ['FINRAG_WEAVIATE_ENDPOINT']
|
61 |
+
self.client = WeaviateWCS(endpoint=self.url,
|
62 |
+
api_key=self.api_key,
|
63 |
+
model_name_or_path=self.model_path)
|
64 |
+
except Exception as e:
|
65 |
+
# raise Exception(f"Could not create Weaviate client: {e}")
|
66 |
+
pass
|
67 |
+
|
68 |
+
assert self.client._client.is_live(), "Weaviate is not live"
|
69 |
+
assert self.client._client.is_ready(), "Weaviate is not ready"
|
70 |
+
# careful with accessing '_client' since the weaviate helper usually closes the connection every time
|
71 |
+
|
72 |
+
self.indexer = None
|
73 |
+
|
74 |
+
self.create_collection()
|
75 |
+
|
76 |
+
@property
|
77 |
+
def collections(self):
|
78 |
+
|
79 |
+
return self.client.show_all_collections()
|
80 |
+
|
81 |
+
def create_collection(self, collection_name: str='Finrag', description: str='Financial reports'):
|
82 |
+
|
83 |
+
self.collection_name = collection_name
|
84 |
+
if collection_name not in self.collections:
|
85 |
+
self.client.create_collection(collection_name=collection_name,
|
86 |
+
properties=self.finrag_properties,
|
87 |
+
description=description)
|
88 |
+
self.collection_name = collection_name
|
89 |
+
else:
|
90 |
+
logging.warning(f"Collection {collection_name} already exists")
|
91 |
+
|
92 |
+
|
93 |
+
def empty_collection(self, collection_name: str='Finrag') -> bool:
|
94 |
+
|
95 |
+
# not in the library yet, so I simply delete and recreate it
|
96 |
+
if collection_name in self.collections:
|
97 |
+
self.client.delete_collection(collection_name=collection_name)
|
98 |
+
self.create_collection()
|
99 |
+
return True
|
100 |
+
else:
|
101 |
+
logging.warning(f"Collection {collection_name} doesn't exist")
|
102 |
+
return False
|
103 |
+
|
104 |
+
|
105 |
+
def index_data(self, data: List[dict]= None, collection_name: str='Finrag'):
|
106 |
+
|
107 |
+
if self.indexer is None:
|
108 |
+
self.indexer = WeaviateIndexer(self.client)
|
109 |
+
|
110 |
+
if data is None:
|
111 |
+
# use the parquet file, otherwise use the data passed
|
112 |
+
data = pd.read_parquet(parquet_file).to_dict('records')
|
113 |
+
# the parquet file was created/incremented when a new article was uploaded
|
114 |
+
# it is a dataframe with columns: file, content, content_embedding
|
115 |
+
# and reflects exactly the data that we want to index at all times
|
116 |
+
self.status = self.indexer.batch_index_data(data, collection_name, 256)
|
117 |
+
|
118 |
+
self.num_errors, self.error_messages, self.doc_ids = self.status
|
119 |
+
|
120 |
+
# in this case with few articles, we don't tolerate errors
|
121 |
+
# batch_index_data already tests errors against a threshold
|
122 |
+
# assert self.num_errors == 0, f"Errors: {self.num_errors}"
|
123 |
+
|
124 |
+
|
125 |
+
def keyword_search(self,
|
126 |
+
query: str,
|
127 |
+
limit: int=5,
|
128 |
+
return_properties: List[str]=['filename', 'content'],
|
129 |
+
alpha=None # dummy parameter to match the hybrid_search signature
|
130 |
+
) -> List[str]:
|
131 |
+
response = self.client.keyword_search(
|
132 |
+
request=query,
|
133 |
+
collection_name=self.collection_name,
|
134 |
+
query_properties=['content'],
|
135 |
+
limit=limit,
|
136 |
+
filter=None,
|
137 |
+
return_properties=return_properties,
|
138 |
+
return_raw=False)
|
139 |
+
|
140 |
+
return [res['content'] for res in response]
|
141 |
+
|
142 |
+
|
143 |
+
def vector_search(self,
|
144 |
+
query: str,
|
145 |
+
limit: int=5,
|
146 |
+
return_properties: List[str]=['filename', 'content'],
|
147 |
+
alpha=None # dummy parameter to match the hybrid_search signature
|
148 |
+
) -> List[str]:
|
149 |
+
|
150 |
+
response = self.client.vector_search(
|
151 |
+
request=query,
|
152 |
+
collection_name=self.collection_name,
|
153 |
+
limit=limit,
|
154 |
+
filter=None,
|
155 |
+
return_properties=return_properties,
|
156 |
+
return_raw=False)
|
157 |
+
|
158 |
+
return [res['content'] for res in response]
|
159 |
+
|
160 |
+
|
161 |
+
def hybrid_search(self,
|
162 |
+
query: str,
|
163 |
+
limit: int=5,
|
164 |
+
alpha=0.5, # higher = more vector search
|
165 |
+
return_properties: List[str]=['filename', 'content']
|
166 |
+
) -> List[str]:
|
167 |
+
|
168 |
+
response = self.client.hybrid_search(
|
169 |
+
request=query,
|
170 |
+
collection_name=self.collection_name,
|
171 |
+
query_properties=['content'],
|
172 |
+
alpha=alpha,
|
173 |
+
limit=limit,
|
174 |
+
filter=None,
|
175 |
+
return_properties=return_properties,
|
176 |
+
return_raw=False)
|
177 |
+
|
178 |
+
return [res['content'] for res in response]
|
app/engine/weaviate_interface_v4.py
ADDED
@@ -0,0 +1,526 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Disclaimer: I didn't write this module
|
2 |
+
|
3 |
+
from weaviate.auth import AuthApiKey
|
4 |
+
from weaviate.collections.classes.internal import (MetadataReturn, QueryReturn,
|
5 |
+
MetadataQuery)
|
6 |
+
import weaviate
|
7 |
+
from weaviate.classes.config import Property
|
8 |
+
from weaviate.classes.query import Filter
|
9 |
+
from weaviate.config import ConnectionConfig
|
10 |
+
from openai import OpenAI
|
11 |
+
from sentence_transformers import SentenceTransformer
|
12 |
+
from typing import Any
|
13 |
+
from torch import cuda
|
14 |
+
from tqdm import tqdm
|
15 |
+
import time
|
16 |
+
import os
|
17 |
+
from dataclasses import dataclass
|
18 |
+
|
19 |
+
class WeaviateWCS:
|
20 |
+
'''
|
21 |
+
A python native Weaviate Client class that encapsulates Weaviate functionalities
|
22 |
+
in one object. Several convenience methods are added for ease of use.
|
23 |
+
|
24 |
+
Args
|
25 |
+
----
|
26 |
+
api_key: str
|
27 |
+
The API key for the Weaviate Cloud Service (WCS) instance.
|
28 |
+
https://console.weaviate.cloud/dashboard
|
29 |
+
|
30 |
+
endpoint: str
|
31 |
+
The url endpoint for the Weaviate Cloud Service instance.
|
32 |
+
|
33 |
+
model_name_or_path: str='sentence-transformers/all-MiniLM-L6-v2'
|
34 |
+
The name or path of the SentenceTransformer model to use for vector search.
|
35 |
+
Will also support OpenAI text-embedding-ada-002 model. This param enables
|
36 |
+
the use of most leading models on MTEB Leaderboard:
|
37 |
+
https://huggingface.co/spaces/mteb/leaderboard
|
38 |
+
openai_api_key: str=None
|
39 |
+
The API key for the OpenAI API. Only required if using OpenAI text-embedding-ada-002 model.
|
40 |
+
'''
|
41 |
+
def __init__(self,
|
42 |
+
endpoint: str=None,
|
43 |
+
api_key: str=None,
|
44 |
+
model_name_or_path: str='sentence-transformers/all-MiniLM-L6-v2',
|
45 |
+
embedded: bool=False,
|
46 |
+
openai_api_key: str=None,
|
47 |
+
skip_init_checks: bool=False,
|
48 |
+
**kwargs
|
49 |
+
):
|
50 |
+
|
51 |
+
self.endpoint = endpoint
|
52 |
+
if embedded:
|
53 |
+
self._client = weaviate.connect_to_embedded(**kwargs)
|
54 |
+
else:
|
55 |
+
auth_config = AuthApiKey(api_key=api_key)
|
56 |
+
self._client = weaviate.connect_to_wcs(cluster_url=endpoint,
|
57 |
+
auth_credentials=auth_config,
|
58 |
+
skip_init_checks=skip_init_checks)
|
59 |
+
self.model_name_or_path = model_name_or_path
|
60 |
+
self._openai_model = False
|
61 |
+
if self.model_name_or_path == 'text-embedding-ada-002':
|
62 |
+
if not openai_api_key:
|
63 |
+
raise ValueError(f'OpenAI API key must be provided to use this model: {self.model_name_or_path}')
|
64 |
+
self.model = OpenAI(api_key=openai_api_key)
|
65 |
+
self._openai_model = True
|
66 |
+
else:
|
67 |
+
self.model = SentenceTransformer(self.model_name_or_path) if self.model_name_or_path else None
|
68 |
+
|
69 |
+
self.return_properties = ['guest', 'title', 'summary', 'content', 'video_id', 'doc_id', 'episode_url', 'thumbnail_url']
|
70 |
+
|
71 |
+
def _connect(self) -> None:
|
72 |
+
'''
|
73 |
+
Connects to Weaviate instance.
|
74 |
+
'''
|
75 |
+
if not self._client.is_connected():
|
76 |
+
self._client.connect()
|
77 |
+
|
78 |
+
def create_collection(self,
|
79 |
+
collection_name: str,
|
80 |
+
properties: list[Property],
|
81 |
+
description: str=None,
|
82 |
+
**kwargs
|
83 |
+
) -> None:
|
84 |
+
'''
|
85 |
+
Creates a collection (index) on the Weaviate instance.
|
86 |
+
|
87 |
+
Args
|
88 |
+
----
|
89 |
+
collection_name: str
|
90 |
+
Name of the collection to create.
|
91 |
+
properties: list[Property]
|
92 |
+
List of properties to add to data objects in the collection.
|
93 |
+
description: str=None
|
94 |
+
User-defined description of the collection.
|
95 |
+
'''
|
96 |
+
|
97 |
+
self._connect()
|
98 |
+
if self._client.collections.exists(collection_name):
|
99 |
+
print(f'Collection "{collection_name}" already exists')
|
100 |
+
return
|
101 |
+
else:
|
102 |
+
try:
|
103 |
+
self._client.collections.create(name=collection_name,
|
104 |
+
properties=properties,
|
105 |
+
description=description,
|
106 |
+
**kwargs)
|
107 |
+
print(f'Collection "{collection_name}" created')
|
108 |
+
except Exception as e:
|
109 |
+
print(f'Error creating collection, due to: {e}')
|
110 |
+
self._client.close()
|
111 |
+
return
|
112 |
+
|
113 |
+
def show_all_collections(self,
|
114 |
+
detailed: bool=False,
|
115 |
+
max_details: bool=False
|
116 |
+
) -> list[str] | dict:
|
117 |
+
'''
|
118 |
+
Shows all available collections(indexes) on the Weaviate cluster.
|
119 |
+
By default will only return list of collection names.
|
120 |
+
Otherwise, increasing details about each collection can be returned.
|
121 |
+
'''
|
122 |
+
self._connect()
|
123 |
+
collections = self._client.collections.list_all(simple=not max_details)
|
124 |
+
self._client.close()
|
125 |
+
if not detailed and not max_details:
|
126 |
+
return list(collections.keys())
|
127 |
+
else:
|
128 |
+
if not any(collections):
|
129 |
+
print('No collections found on host')
|
130 |
+
return collections
|
131 |
+
|
132 |
+
def show_collection_config(self, collection_name: str) -> ConnectionConfig:
|
133 |
+
'''
|
134 |
+
Shows all information of a specific collection.
|
135 |
+
'''
|
136 |
+
self._connect()
|
137 |
+
if self._client.collections.exists(collection_name):
|
138 |
+
collection = self.show_all_collections(max_details=True)[collection_name]
|
139 |
+
self._client.close()
|
140 |
+
return collection
|
141 |
+
else:
|
142 |
+
print(f'Collection "{collection_name}" not found on host')
|
143 |
+
|
144 |
+
def show_collection_properties(self, collection_name: str) -> dict | str:
|
145 |
+
'''
|
146 |
+
Shows all properties of a collection (index) on the Weaviate instance.
|
147 |
+
'''
|
148 |
+
self._connect()
|
149 |
+
if self._client.collections.exists(collection_name):
|
150 |
+
collection = self.show_all_collections(max_details=True)[collection_name]
|
151 |
+
self._client.close()
|
152 |
+
return collection.properties
|
153 |
+
else:
|
154 |
+
print(f'Collection "{collection_name}" not found on host')
|
155 |
+
|
156 |
+
def delete_collection(self, collection_name: str) -> str:
|
157 |
+
'''
|
158 |
+
Deletes a collection (index) on the Weaviate instance, if it exists.
|
159 |
+
'''
|
160 |
+
self._connect()
|
161 |
+
if self._client.collections.exists(collection_name):
|
162 |
+
try:
|
163 |
+
self._client.collections.delete(collection_name)
|
164 |
+
self._client.close()
|
165 |
+
print(f'Collection "{collection_name}" deleted')
|
166 |
+
except Exception as e:
|
167 |
+
print(f'Error deleting collection, due to: {e}')
|
168 |
+
else:
|
169 |
+
print(f'Collection "{collection_name}" not found on host')
|
170 |
+
|
171 |
+
def get_doc_count(self, collection_name: str) -> str:
|
172 |
+
'''
|
173 |
+
Returns the number of documents in a collection.
|
174 |
+
'''
|
175 |
+
self._connect()
|
176 |
+
if self._client.collections.exists(collection_name):
|
177 |
+
collection = self._client.collections.get(collection_name)
|
178 |
+
aggregate = collection.aggregate.over_all()
|
179 |
+
total_count = aggregate.total_count
|
180 |
+
print(f'Found {total_count} documents in collection "{collection_name}"')
|
181 |
+
return total_count
|
182 |
+
else:
|
183 |
+
print(f'Collection "{collection_name}" not found on host')
|
184 |
+
|
185 |
+
def format_response(self,
|
186 |
+
response: QueryReturn,
|
187 |
+
) -> list[dict]:
|
188 |
+
'''
|
189 |
+
Formats json response from Weaviate into a list of dictionaries.
|
190 |
+
Expands _additional fields if present into top-level dictionary.
|
191 |
+
'''
|
192 |
+
results = [{**o.properties, **self._get_meta(o.metadata)} for o in response.objects]
|
193 |
+
return results
|
194 |
+
|
195 |
+
def _get_meta(self, metadata: MetadataReturn):
|
196 |
+
'''
|
197 |
+
Extracts metadata from MetadataQuery object if meta exists.
|
198 |
+
'''
|
199 |
+
temp_dict = metadata.__dict__
|
200 |
+
return {k:v for k,v in temp_dict.items() if v}
|
201 |
+
|
202 |
+
def keyword_search(self,
|
203 |
+
request: str,
|
204 |
+
collection_name: str,
|
205 |
+
query_properties: list[str]=['content'],
|
206 |
+
limit: int=10,
|
207 |
+
filter: Filter=None,
|
208 |
+
return_properties: list[str]=None,
|
209 |
+
return_raw: bool=False
|
210 |
+
) -> dict | list[dict]:
|
211 |
+
'''
|
212 |
+
Executes Keyword (BM25) search.
|
213 |
+
|
214 |
+
Args
|
215 |
+
----
|
216 |
+
request: str
|
217 |
+
User query.
|
218 |
+
collection_name: str
|
219 |
+
Collection (index) to search.
|
220 |
+
query_properties: list[str]
|
221 |
+
list of properties to search across.
|
222 |
+
limit: int=10
|
223 |
+
Number of results to return.
|
224 |
+
where_filter: dict=None
|
225 |
+
Property filter to apply to search results.
|
226 |
+
return_properties: list[str]=None
|
227 |
+
list of properties to return in response.
|
228 |
+
If None, returns self.return_properties.
|
229 |
+
return_raw: bool=False
|
230 |
+
If True, returns raw response from Weaviate.
|
231 |
+
'''
|
232 |
+
self._connect()
|
233 |
+
return_properties = return_properties if return_properties else self.return_properties
|
234 |
+
collection = self._client.collections.get(collection_name)
|
235 |
+
response = collection.query.bm25(query=request,
|
236 |
+
query_properties=query_properties,
|
237 |
+
limit=limit,
|
238 |
+
filters=filter,
|
239 |
+
return_metadata=MetadataQuery(score=True),
|
240 |
+
return_properties=return_properties)
|
241 |
+
# response = response.with_where(where_filter).do() if where_filter else response.do()
|
242 |
+
if return_raw:
|
243 |
+
return response
|
244 |
+
else:
|
245 |
+
return self.format_response(response)
|
246 |
+
|
247 |
+
def vector_search(self,
|
248 |
+
request: str,
|
249 |
+
collection_name: str,
|
250 |
+
limit: int=10,
|
251 |
+
return_properties: list[str]=None,
|
252 |
+
filter: Filter=None,
|
253 |
+
return_raw: bool=False,
|
254 |
+
device: str='cuda:0' if cuda.is_available() else 'cpu'
|
255 |
+
) -> dict | list[dict]:
|
256 |
+
'''
|
257 |
+
Executes vector search using embedding model defined on instantiation
|
258 |
+
of WeaviateClient instance.
|
259 |
+
|
260 |
+
Args
|
261 |
+
----
|
262 |
+
request: str
|
263 |
+
User query.
|
264 |
+
collection_name: str
|
265 |
+
Collection (index) to search.
|
266 |
+
limit: int=10
|
267 |
+
Number of results to return.
|
268 |
+
return_properties: list[str]=None
|
269 |
+
list of properties to return in response.
|
270 |
+
If None, returns all properties.
|
271 |
+
return_raw: bool=False
|
272 |
+
If True, returns raw response from Weaviate.
|
273 |
+
device: str
|
274 |
+
Device to use for encoding query.
|
275 |
+
'''
|
276 |
+
self._connect()
|
277 |
+
return_properties = return_properties if return_properties else self.return_properties
|
278 |
+
query_vector = self._create_query_vector(request, device=device)
|
279 |
+
collection = self._client.collections.get(collection_name)
|
280 |
+
response = collection.query.near_vector(near_vector=query_vector,
|
281 |
+
limit=limit,
|
282 |
+
filters=filter,
|
283 |
+
return_metadata=MetadataQuery(distance=True),
|
284 |
+
return_properties=return_properties)
|
285 |
+
# response = response.with_where(where_filter).do() if where_filter else response.do()
|
286 |
+
if return_raw:
|
287 |
+
return response
|
288 |
+
else:
|
289 |
+
return self.format_response(response)
|
290 |
+
|
291 |
+
def _create_query_vector(self, query: str, device: str) -> list[float]:
|
292 |
+
'''
|
293 |
+
Creates embedding vector from text query.
|
294 |
+
'''
|
295 |
+
return self.get_openai_embedding(query) if self._openai_model else self.model.encode(query, device=device).tolist()
|
296 |
+
|
297 |
+
def get_openai_embedding(self, query: str) -> list[float]:
|
298 |
+
'''
|
299 |
+
Gets embedding from OpenAI API for query.
|
300 |
+
'''
|
301 |
+
embedding = self.model.embeddings.create(input=query, model='text-embedding-ada-002').model_dump()
|
302 |
+
if embedding:
|
303 |
+
return embedding['data'][0]['embedding']
|
304 |
+
else:
|
305 |
+
raise ValueError(f'No embedding found for query: {query}')
|
306 |
+
|
307 |
+
def hybrid_search(self,
|
308 |
+
request: str,
|
309 |
+
collection_name: str,
|
310 |
+
query_properties: list[str]=['content'],
|
311 |
+
alpha: float=0.5,
|
312 |
+
limit: int=10,
|
313 |
+
filter: Filter=None,
|
314 |
+
return_properties: list[str]=None,
|
315 |
+
return_raw: bool=False,
|
316 |
+
device: str='cuda:0' if cuda.is_available() else 'cpu'
|
317 |
+
) -> dict | list[dict]:
|
318 |
+
'''
|
319 |
+
Executes Hybrid (Keyword + Vector) search.
|
320 |
+
|
321 |
+
Args
|
322 |
+
----
|
323 |
+
request: str
|
324 |
+
User query.
|
325 |
+
collection_name: str
|
326 |
+
Collection (index) to search.
|
327 |
+
query_properties: list[str]
|
328 |
+
list of properties to search across (using BM25)
|
329 |
+
alpha: float=0.5
|
330 |
+
Weighting factor for BM25 and Vector search.
|
331 |
+
alpha can be any number from 0 to 1, defaulting to 0.5:
|
332 |
+
alpha = 0 executes a pure keyword search method (BM25)
|
333 |
+
alpha = 0.5 weighs the BM25 and vector methods evenly
|
334 |
+
alpha = 1 executes a pure vector search method
|
335 |
+
limit: int=10
|
336 |
+
Number of results to return.
|
337 |
+
filter: Filter=None
|
338 |
+
Property filter to apply to search results.
|
339 |
+
return_properties: list[str]=None
|
340 |
+
list of properties to return in response.
|
341 |
+
If None, returns all properties.
|
342 |
+
return_raw: bool=False
|
343 |
+
If True, returns raw response from Weaviate.
|
344 |
+
'''
|
345 |
+
self._connect()
|
346 |
+
return_properties = return_properties if return_properties else self.return_properties
|
347 |
+
query_vector = self._create_query_vector(request, device=device)
|
348 |
+
collection = self._client.collections.get(collection_name)
|
349 |
+
response = collection.query.hybrid(query=request,
|
350 |
+
query_properties=query_properties,
|
351 |
+
filters=filter,
|
352 |
+
vector=query_vector,
|
353 |
+
alpha=alpha,
|
354 |
+
limit=limit,
|
355 |
+
return_metadata=MetadataQuery(score=True, distance=True),
|
356 |
+
return_properties=return_properties)
|
357 |
+
if return_raw:
|
358 |
+
return response
|
359 |
+
else:
|
360 |
+
return self.format_response(response)
|
361 |
+
|
362 |
+
|
363 |
+
class WeaviateIndexer:
|
364 |
+
|
365 |
+
def __init__(self,
|
366 |
+
client: WeaviateWCS
|
367 |
+
):
|
368 |
+
'''
|
369 |
+
Class designed to batch index documents into Weaviate. Instantiating
|
370 |
+
this class will automatically configure the Weaviate batch client.
|
371 |
+
'''
|
372 |
+
|
373 |
+
self._client = client._client
|
374 |
+
|
375 |
+
def _connect(self):
|
376 |
+
'''
|
377 |
+
Connects to Weaviate instance.
|
378 |
+
'''
|
379 |
+
if not self._client.is_connected():
|
380 |
+
self._client.connect()
|
381 |
+
|
382 |
+
def create_collection(self,
|
383 |
+
collection_name: str,
|
384 |
+
properties: list[Property],
|
385 |
+
description: str=None,
|
386 |
+
**kwargs
|
387 |
+
) -> str:
|
388 |
+
'''
|
389 |
+
Creates a collection (index) on the Weaviate instance.
|
390 |
+
'''
|
391 |
+
if collection_name.find('-') != -1:
|
392 |
+
raise ValueError('Collection name cannot contain hyphens')
|
393 |
+
try:
|
394 |
+
self._connect()
|
395 |
+
self._client.collections.create(name=collection_name,
|
396 |
+
description=description,
|
397 |
+
properties=properties,
|
398 |
+
**kwargs
|
399 |
+
)
|
400 |
+
if self._client.collections.exists(collection_name):
|
401 |
+
print(f'Collection "{collection_name}" created')
|
402 |
+
else:
|
403 |
+
print(f'Collection not found at the moment, try again later')
|
404 |
+
self._client.close()
|
405 |
+
except Exception as e:
|
406 |
+
print(f'Error creating collection, due to: {e}')
|
407 |
+
|
408 |
+
def batch_index_data(self,
|
409 |
+
data: list[dict],
|
410 |
+
collection_name: str,
|
411 |
+
error_threshold: float=0.01,
|
412 |
+
vector_property: str='content_embedding',
|
413 |
+
unique_id_field: str='doc_id',
|
414 |
+
properties: list[Property]=None,
|
415 |
+
collection_description: str=None,
|
416 |
+
**kwargs
|
417 |
+
) -> dict:
|
418 |
+
'''
|
419 |
+
Batch function for fast indexing of data onto Weaviate cluster.
|
420 |
+
|
421 |
+
Args
|
422 |
+
----
|
423 |
+
data: list[dict]
|
424 |
+
List of dictionaries where each dictionary represents a document.
|
425 |
+
collection_name: str
|
426 |
+
Name of the collection to index data into.
|
427 |
+
error_threshold: float=0.01
|
428 |
+
Threshold for error rate during batch upload. This value is a percentage of the total data
|
429 |
+
that the end user is willing to tolerate as errors. If the error rate exceeds this threshold,
|
430 |
+
the batch job will be aborted.
|
431 |
+
vector_property: str='content_embedding'
|
432 |
+
Name of the property that contains the vector representation of the document.
|
433 |
+
unique_id_field: str='doc_id'
|
434 |
+
Name of the unique identifier field in the document.
|
435 |
+
properties: list[Property]=None
|
436 |
+
List of properties to create the collection with. Required if collection does not exist.
|
437 |
+
collection_description: str=None
|
438 |
+
Description of the collection. Optional parameter.
|
439 |
+
|
440 |
+
Returns
|
441 |
+
-------
|
442 |
+
dict
|
443 |
+
Dictionary containing error information if any with the following keys:
|
444 |
+
['num_errors', 'error_messages', 'doc_ids']
|
445 |
+
'''
|
446 |
+
self._connect()
|
447 |
+
if not self._client.collections.exists(collection_name):
|
448 |
+
print(f'Collection "{collection_name}" not found on host, creating Collection first...')
|
449 |
+
if properties is None:
|
450 |
+
raise ValueError(f'Tried to create Collection <{collection_name}> but no properties were provided.')
|
451 |
+
self.create_collection(collection_name=collection_name,
|
452 |
+
properties=properties,
|
453 |
+
description=collection_description,
|
454 |
+
**kwargs)
|
455 |
+
self._client.close()
|
456 |
+
|
457 |
+
self._connect()
|
458 |
+
error_threshold_size = int(len(data) * error_threshold)
|
459 |
+
collection = self._client.collections.get(collection_name)
|
460 |
+
|
461 |
+
start = time.perf_counter()
|
462 |
+
completed_job = True
|
463 |
+
|
464 |
+
with collection.batch.dynamic() as batch:
|
465 |
+
for doc in tqdm(data):
|
466 |
+
batch.add_object(properties={k:v for k,v in doc.items() if k != vector_property},
|
467 |
+
vector=doc[vector_property])
|
468 |
+
if batch.number_errors > error_threshold_size:
|
469 |
+
print('Upload errors exceed error_threshold...')
|
470 |
+
completed_job = False
|
471 |
+
break
|
472 |
+
end = time.perf_counter() - start
|
473 |
+
print(f'Processing finished in {round(end/60, 2)} minutes.')
|
474 |
+
|
475 |
+
failed_objects = collection.batch.failed_objects
|
476 |
+
if any(failed_objects):
|
477 |
+
error_messages = [obj.message for obj in failed_objects]
|
478 |
+
doc_ids = [obj.object_.properties.get(unique_id_field, 'Not Found') for obj in failed_objects]
|
479 |
+
else:
|
480 |
+
error_messages, doc_ids = [], []
|
481 |
+
error_object = {'num_errors':batch.number_errors,
|
482 |
+
'error_messages': error_messages,
|
483 |
+
'doc_ids': doc_ids}
|
484 |
+
if not completed_job:
|
485 |
+
print(f'Batch job failed. Review errors using these keys: {list(error_object.keys())}')
|
486 |
+
return error_object
|
487 |
+
if batch.number_errors > 0:
|
488 |
+
print(f'Batch job completed with {batch.number_errors} errors. Review errors using these keys: {list(error_object.keys())}')
|
489 |
+
else:
|
490 |
+
print('Batch job completed with zero errors.')
|
491 |
+
return error_object
|
492 |
+
|
493 |
+
|
494 |
+
@dataclass
|
495 |
+
class SearchFilter(Filter):
|
496 |
+
|
497 |
+
'''
|
498 |
+
Simplified interface for constructing a Filter object.
|
499 |
+
|
500 |
+
Args
|
501 |
+
----
|
502 |
+
property : str
|
503 |
+
Property to filter on.
|
504 |
+
query_value : str
|
505 |
+
Query value to filter on.
|
506 |
+
'''
|
507 |
+
property: str
|
508 |
+
query_value: str
|
509 |
+
|
510 |
+
def exact_match(self):
|
511 |
+
return self.by_property(self.property).equal(self.query_value)
|
512 |
+
|
513 |
+
def fuzzy_match(self):
|
514 |
+
return self.by_property(self.property).like(f'*{self.query_value}*')
|
515 |
+
|
516 |
+
|
517 |
+
|
518 |
+
def get_weaviate_client(endpoint: str=os.getenv('FINRAG_WEAVIATE_ENDPOINT'),
|
519 |
+
api_key: str=os.getenv('FINRAG_WEAVIATE_API_KEY'),
|
520 |
+
model_name_or_path: str='sentence-transformers/all-MiniLM-L6-v2',
|
521 |
+
embedded: bool=False,
|
522 |
+
openai_api_key: str=None,
|
523 |
+
skip_init_checks: bool=False,
|
524 |
+
**kwargs
|
525 |
+
) -> WeaviateWCS:
|
526 |
+
return WeaviateWCS(endpoint, api_key, model_name_or_path, embedded, openai_api_key, skip_init_checks, **kwargs)
|
app/main.py
ADDED
@@ -0,0 +1,231 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import os, random, logging, pickle, shutil
|
3 |
+
from dotenv import load_dotenv, find_dotenv
|
4 |
+
from typing import Optional
|
5 |
+
from pydantic import BaseModel, Field
|
6 |
+
|
7 |
+
from fastapi import FastAPI, HTTPException, File, UploadFile, status
|
8 |
+
from fastapi.responses import HTMLResponse
|
9 |
+
from fastapi.middleware.cors import CORSMiddleware
|
10 |
+
|
11 |
+
from engine.processing import process_pdf, index_data, empty_collection, vector_search
|
12 |
+
from rag.rag import rag_it
|
13 |
+
|
14 |
+
from engine.logger import logger
|
15 |
+
|
16 |
+
from settings import datadir
|
17 |
+
|
18 |
+
os.makedirs(datadir, exist_ok=True)
|
19 |
+
|
20 |
+
app = FastAPI()
|
21 |
+
|
22 |
+
environment = os.getenv("ENVIRONMENT", "dev") # created by dockerfile
|
23 |
+
|
24 |
+
if environment == "dev":
|
25 |
+
logger.warning("Running in development mode - allowing CORS for all origins")
|
26 |
+
app.add_middleware(
|
27 |
+
CORSMiddleware,
|
28 |
+
allow_origins=["*"],
|
29 |
+
allow_credentials=True,
|
30 |
+
allow_methods=["*"],
|
31 |
+
allow_headers=["*"],
|
32 |
+
)
|
33 |
+
|
34 |
+
try:
|
35 |
+
# will not work on HuggingFace
|
36 |
+
# and Liquidity dont' have the env anyway
|
37 |
+
load_dotenv(find_dotenv('env'))
|
38 |
+
|
39 |
+
except Exception as e:
|
40 |
+
pass
|
41 |
+
|
42 |
+
|
43 |
+
@app.get("/", response_class=HTMLResponse)
|
44 |
+
def read_root():
|
45 |
+
logger.info("Title displayed on home page")
|
46 |
+
return """
|
47 |
+
<html>
|
48 |
+
<body>
|
49 |
+
<h1>Welcome to FinExpert, a RAG system designed by JP Bianchi!</h1>
|
50 |
+
</body>
|
51 |
+
</html>
|
52 |
+
"""
|
53 |
+
|
54 |
+
|
55 |
+
@app.get("/ping/")
|
56 |
+
def ping():
|
57 |
+
""" Testing """
|
58 |
+
logger.info("Someone is pinging the server")
|
59 |
+
return {"answer": str(random.random() * 100)}
|
60 |
+
|
61 |
+
|
62 |
+
@app.delete("/erase_data/")
|
63 |
+
def erase_data():
|
64 |
+
""" Erase all files in the data directory, but not the vector store """
|
65 |
+
if len(os.listdir(datadir)) == 0:
|
66 |
+
logger.info("No data to erase")
|
67 |
+
return {"message": "No data to erase"}
|
68 |
+
|
69 |
+
shutil.rmtree(datadir, ignore_errors=True)
|
70 |
+
os.mkdir(datadir)
|
71 |
+
logger.warning("All data has been erased")
|
72 |
+
return {"message": "All data has been erased"}
|
73 |
+
|
74 |
+
|
75 |
+
@app.delete("/empty_collection/")
|
76 |
+
def delete_vectors():
|
77 |
+
""" Empty the collection in the vector store """
|
78 |
+
try:
|
79 |
+
status = empty_collection()
|
80 |
+
return {f"""message": "Collection{'' if status else ' NOT'} erased!"""}
|
81 |
+
except Exception as e:
|
82 |
+
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e))
|
83 |
+
|
84 |
+
@app.get("/list_files/")
|
85 |
+
def list_files():
|
86 |
+
""" List all files in the data directory """
|
87 |
+
files = os.listdir(datadir)
|
88 |
+
logger.info(f"Files in data directory: {files}")
|
89 |
+
return {"files": files}
|
90 |
+
|
91 |
+
|
92 |
+
@app.post("/upload/")
|
93 |
+
# @limiter.limit("5/minute") see 'slowapi' for rate limiting
|
94 |
+
async def upload_file(file: UploadFile = File(...)):
|
95 |
+
""" Uploads a file in data directory, for later indexing """
|
96 |
+
try:
|
97 |
+
filepath = os.path.join(datadir, file.filename)
|
98 |
+
logger.info(f"Fiename detected: {file.filename}")
|
99 |
+
if os.path.exists(filepath):
|
100 |
+
logger.warning(f"File {file.filename} already exists: no processing done")
|
101 |
+
return {"message": f"File {file.filename} already exists: no processing done"}
|
102 |
+
|
103 |
+
else:
|
104 |
+
logger.info(f"Receiving file: {file.filename}")
|
105 |
+
contents = await file.read()
|
106 |
+
logger.info(f"File reception complete!")
|
107 |
+
|
108 |
+
except Exception as e:
|
109 |
+
logger.error(f"Error during file upload: {str(e)}")
|
110 |
+
return {"message": f"Error during file upload: {str(e)}"}
|
111 |
+
|
112 |
+
if file.filename.endswith('.pdf'):
|
113 |
+
|
114 |
+
# let's save the file in /data even if it's temp storage on HF
|
115 |
+
with open(filepath, 'wb') as f:
|
116 |
+
f.write(contents)
|
117 |
+
|
118 |
+
try:
|
119 |
+
logger.info(f"Starting to process {file.filename}")
|
120 |
+
new_content = process_pdf(filepath)
|
121 |
+
success = {"message": f"Successfully uploaded {file.filename}"}
|
122 |
+
success.update(new_content)
|
123 |
+
return success
|
124 |
+
|
125 |
+
except Exception as e:
|
126 |
+
return {"message": f"Failed to extract text from PDF: {str(e)}"}
|
127 |
+
else:
|
128 |
+
return {"message": "Only PDF files are accepted"}
|
129 |
+
|
130 |
+
|
131 |
+
@app.post("/create_index/")
|
132 |
+
async def create_index():
|
133 |
+
""" Create an index for the uploaded files """
|
134 |
+
|
135 |
+
logger.info("Creating index for uploaded files")
|
136 |
+
try:
|
137 |
+
msg = index_data()
|
138 |
+
return {"message": msg}
|
139 |
+
except Exception as e:
|
140 |
+
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e))
|
141 |
+
|
142 |
+
|
143 |
+
class Question(BaseModel):
|
144 |
+
question: str
|
145 |
+
|
146 |
+
@app.post("/ask/")
|
147 |
+
async def hybrid_search(question: Question):
|
148 |
+
logger.info(f"Processing question: {question.question}")
|
149 |
+
try:
|
150 |
+
search_results = vector_search(question.question)
|
151 |
+
logger.info(f"Answer: {search_results}")
|
152 |
+
return {"answer": search_results}
|
153 |
+
except Exception as e:
|
154 |
+
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e))
|
155 |
+
|
156 |
+
|
157 |
+
@app.post("/ragit/")
|
158 |
+
async def ragit(question: Question):
|
159 |
+
logger.info(f"Processing question: {question.question}")
|
160 |
+
try:
|
161 |
+
search_results = vector_search(question.question)
|
162 |
+
logger.info(f"Search results generated: {search_results}")
|
163 |
+
|
164 |
+
answer = rag_it(question.question, search_results)
|
165 |
+
|
166 |
+
logger.info(f"Answer: {answer}")
|
167 |
+
return {"answer": answer}
|
168 |
+
except Exception as e:
|
169 |
+
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e))
|
170 |
+
|
171 |
+
|
172 |
+
# TODO
|
173 |
+
# rejects searches with a search score below a threshold
|
174 |
+
# scrape the tables (and find a way to reject them from the text search -> LLamaparse)
|
175 |
+
# see why the filename in search results is always empty
|
176 |
+
# -> add it to the search results to avoid confusion Google-Amazon for instance
|
177 |
+
# add python scripts to create index, rag etc
|
178 |
+
|
179 |
+
if __name__ == '__main__':
|
180 |
+
import uvicorn
|
181 |
+
from os import getenv
|
182 |
+
port = int(getenv("PORT", 80))
|
183 |
+
print(f"Starting server on port {port}")
|
184 |
+
reload = True if environment == "dev" else False
|
185 |
+
uvicorn.run("main:app", host="0.0.0.0", port=port, reload=reload)
|
186 |
+
|
187 |
+
|
188 |
+
|
189 |
+
# Examples:
|
190 |
+
# curl -X POST "http://localhost:80/upload" -F "[email protected]"
|
191 |
+
# curl -X DELETE "http://localhost:80/erase_data/"
|
192 |
+
# curl -X GET "http://localhost:80/list_files/"
|
193 |
+
|
194 |
+
# hf space is at https://jpbianchi-finrag.hf.space/
|
195 |
+
# code given by https://jpbianchi-finrag.hf.space/docs
|
196 |
+
# Space must be public
|
197 |
+
# curl -X POST "https://jpbianchi-finrag.hf.space/upload/" -F "[email protected]"
|
198 |
+
|
199 |
+
# curl -X POST http://localhost:80/ask/ -H "Content-Type: application/json" -d '{"question": "what is Amazon loss"}'
|
200 |
+
# curl -X POST http://localhost:80/ragit/ -H "Content-Type: application/json" -d '{"question": "Does ATT have postpaid phone customers?"}'
|
201 |
+
|
202 |
+
|
203 |
+
# TODO
|
204 |
+
# import unittest
|
205 |
+
# from unitesting_utils import load_impact_theory_data
|
206 |
+
|
207 |
+
# class TestSplitContents(unittest.TestCase):
|
208 |
+
# '''
|
209 |
+
# Unit test to ensure proper functionality of split_contents function
|
210 |
+
# '''
|
211 |
+
|
212 |
+
# def test_split_contents(self):
|
213 |
+
# import tiktoken
|
214 |
+
# from llama_index.text_splitter import SentenceSplitter
|
215 |
+
|
216 |
+
# data = load_impact_theory_data()
|
217 |
+
|
218 |
+
# subset = data[:3]
|
219 |
+
# chunk_size = 256
|
220 |
+
# chunk_overlap = 0
|
221 |
+
# encoding = tiktoken.encoding_for_model('gpt-3.5-turbo-0613')
|
222 |
+
# gpt35_txt_splitter = SentenceSplitter(chunk_size=chunk_size, tokenizer=encoding.encode, chunk_overlap=chunk_overlap)
|
223 |
+
# results = split_contents(subset, gpt35_txt_splitter)
|
224 |
+
# self.assertEqual(len(results), 3)
|
225 |
+
# self.assertEqual(len(results[0]), 83)
|
226 |
+
# self.assertEqual(len(results[1]), 178)
|
227 |
+
# self.assertEqual(len(results[2]), 144)
|
228 |
+
# self.assertTrue(isinstance(results, list))
|
229 |
+
# self.assertTrue(isinstance(results[0], list))
|
230 |
+
# self.assertTrue(isinstance(results[0][0], str))
|
231 |
+
# unittest.TextTestRunner().run(unittest.TestLoader().loadTestsFromTestCase(TestSplitContents))
|
app/notebooks/chunking_indexing.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
app/notebooks/lite_lll.ipynb
ADDED
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 4,
|
6 |
+
"metadata": {},
|
7 |
+
"outputs": [
|
8 |
+
{
|
9 |
+
"name": "stdout",
|
10 |
+
"output_type": "stream",
|
11 |
+
"text": [
|
12 |
+
"The autoreload extension is already loaded. To reload it, use:\n",
|
13 |
+
" %reload_ext autoreload\n"
|
14 |
+
]
|
15 |
+
}
|
16 |
+
],
|
17 |
+
"source": [
|
18 |
+
"%load_ext autoreload\n",
|
19 |
+
"%autoreload 2\n",
|
20 |
+
"\n",
|
21 |
+
"import sys\n",
|
22 |
+
"sys.path.append('../')\n",
|
23 |
+
"\n",
|
24 |
+
"from dotenv import load_dotenv, find_dotenv\n",
|
25 |
+
"envs = load_dotenv(find_dotenv('env'), override=True)\n",
|
26 |
+
"\n",
|
27 |
+
"from warnings import filterwarnings\n",
|
28 |
+
"filterwarnings('ignore')\n",
|
29 |
+
"\n",
|
30 |
+
"from llm.llm import LLM\n",
|
31 |
+
"\n",
|
32 |
+
"from litellm import ModelResponse\n",
|
33 |
+
"\n",
|
34 |
+
"from typing import Literal\n",
|
35 |
+
"from rich import print\n",
|
36 |
+
"import os"
|
37 |
+
]
|
38 |
+
},
|
39 |
+
{
|
40 |
+
"cell_type": "code",
|
41 |
+
"execution_count": 5,
|
42 |
+
"metadata": {},
|
43 |
+
"outputs": [],
|
44 |
+
"source": [
|
45 |
+
"#instantiate the LLM Class\n",
|
46 |
+
"turbo = 'gpt-3.5-turbo-0125'\n",
|
47 |
+
"#the LLM Class will use the OPENAI_API_KEY env var as the default api_key \n",
|
48 |
+
"llm = LLM(turbo)\n",
|
49 |
+
"\n",
|
50 |
+
"# use the gpt3.5 model that is free - recent"
|
51 |
+
]
|
52 |
+
},
|
53 |
+
{
|
54 |
+
"cell_type": "code",
|
55 |
+
"execution_count": 11,
|
56 |
+
"metadata": {},
|
57 |
+
"outputs": [],
|
58 |
+
"source": [
|
59 |
+
"vs = [\"Lastly, during the quarter, we increased our reserves for general product and automobile self-\\ninsurance liabilities, driven by changes in our estimates about the cost of asserted and unasserted \\nclaims, resulting in additional expense of $1.3 billion. This impact is primarily recorded in cost of \\nsales on our income statement. As our business has grown quickly over the last several years, \\nparticularly as we've built out our fulfillment and transportation network, and claim amounts have \\nseen industry-wide inflation, we've continued to evaluate and adjust this reserve for both asserted \\nclaims, as well as our estimate for unasserted claims.\\nWe reported overall net income of $278 million in the fourth quarter. While we primarily focus our \\ncomments on operating income, I'd point out that this net income includes a pre-tax valuation loss \\nof $2.3 billion included in non-operating income from our common stock investment in Rivian \\nAutomotive. As we've noted in recent quarters, this activity is not related to Amazon's ongoing \\noperations, but rather the quarter-to-quarter fluctuations in Rivian's stock price. As we head into \\nthe New Year, we remain heads down focused on driving a better customer experience.\",\"tenet of we want to find a way to meaningfully streamline our costs in all of our businesses, not \\njust their existing large businesses, but also in some of the investments we're making, we want to \\nactually do a pretty good thorough look about what we're investing and how much we think we \\nneed to. But doing so, without having to give up our ability to invest in the key long-term strategic \\ninvestments that we think could change broad customer experiences, and change Amazon over \\ntime.\\nAnd you saw that process led to us choosing to pause on incremental headcount, as we tried to \\nassess what was happening in the economy, and we eliminated some programs, Fabric.com, and \\nAmazon Care, and Amazon Glow, and Amazon Explore, and we decided to go slower on some -- \\non the physical store expansion and the grocery space until we had a format that we really \\nbelieved in rolling out, and we went a little bit slower on some devices. Until we made the very \\nhard decision that Brian talked about earlier, which was the hardest decision I think we've all been \\na part of, which was to reduce or eliminate 18,000 roles.\",\"operating income. This operating income was negatively impacted by three large items, which \\nadded approximately $2.7 billion of costs in the quarter. This is related to employee severance, \\nimpairments of property and equipment and operating leases, and changes in estimates related to \\nself-insurance liabilities. These costs primarily impacted our North America segment. If we had not \\nincurred these charges in Q4, our operating income would have been approximately $5.4 billion. \\nWe are encouraged with the progress we continue to make in streamlining the costs in our \\nAmazon Stores business. We entered the quarter with labor more appropriately matched to \\ndemand across our operations network, compared to Q4 of last year, allowing us to have the right \\nlabor, in the right place, at the right time, and drive productivity gains. We also saw continued \\nefficiencies across our transportation network, where process and tech improvements resulted in \\nhigher Amazon Logistics productivity and improved line haul fill rates. While transportation \\noverperformed expectations in the quarter, we also saw productivity improvements across our \\nfulfillment centers, in line with our plan. We also saw good leverage driven by strong holiday \\nvolumes.\"]"
|
60 |
+
]
|
61 |
+
},
|
62 |
+
{
|
63 |
+
"cell_type": "code",
|
64 |
+
"execution_count": 10,
|
65 |
+
"metadata": {},
|
66 |
+
"outputs": [
|
67 |
+
{
|
68 |
+
"data": {
|
69 |
+
"text/plain": [
|
70 |
+
"'Search result 0: The 2022 decline reflects the separation of U.S. videoand lower personnel costs associated with ongoingtransformation initiatives, partially offset by higher baddebt expense, the elimination of Connect America FundPhase II (CAF II) government credits and increasedwholesale network access charges. Wireless equipmentcosts were up slightly, with higher sales volumes and thesale of higher-priced smartphones largely offset by lower3G shutdown costs in 2022. In the first quarter of 2022, weupdated the expected economic lives of customerrelationships, which extended the amortization period ofdeferred acquisition and fulfillment costs and reducedexpenses approximately $395, with $150 recorded toMobility, $115 to Business Wireline and $130 to ConsumerWireline.\\nThe 2021 decline reflects our 2021 business divestitures,\\nlower bad debt expense and lower personnel costsassociated with our transformation initiatives. Declineswere mostly offset by increased domestic wirelessequipment expense from higher volumes.\\nAsset impairments and abandonments and\\nrestructuring increased in 2022 and decreased in 2021.\\nSearch result 1: The 2022 decline reflects the separation of U.S. videoand lower personnel costs associated with ongoingtransformation initiatives, partially offset by higher baddebt expense, the elimination of Connect America FundPhase II (CAF II) government credits and increasedwholesale network access charges. Wireless equipmentcosts were up slightly, with higher sales volumes and thesale of higher-priced smartphones largely offset by lower3G shutdown costs in 2022. In the first quarter of 2022, weupdated the expected economic lives of customerrelationships, which extended the amortization period ofdeferred acquisition and fulfillment costs and reducedexpenses approximately $395, with $150 recorded toMobility, $115 to Business Wireline and $130 to ConsumerWireline.\\nThe 2021 decline reflects our 2021 business divestitures,\\nlower bad debt expense and lower personnel costsassociated with our transformation initiatives. Declineswere mostly offset by increased domestic wirelessequipment expense from higher volumes.\\nAsset impairments and abandonments and\\nrestructuring increased in 2022 and decreased in 2021.\\nSearch result 2: Credit Losses As of January 1, 2020, we adopted,\\nthrough modified retrospective application, ASU No.2016-13, “Financial Instruments—Credit Losses (Topic 326):Measurement of Credit Losses on Financial Instruments,”or Accounting Standards Codification (ASC) 326 (ASC 326),which replaces the incurred loss impairment methodologyunder prior GAAP with an expected credit loss model. ASC326 affects trade receivables, loans, contract assets,certain beneficial interests, off-balance-sheet creditexposures not accounted for as insurance and otherfinancial assets that are not subject to fair value throughnet income, as defined by the standard. Under theexpected credit loss model, we are required to considerfuture economic trends to estimate expected creditlosses over the lifetime of the asset. Upon adoption onJanuary 1, 2020, we recorded a $293 reduction to“Retained earnings,” $395 increase to “Allowances forcredit losses” applicable to our trade and loan receivables,$10 reduction of contract assets, $105 reduction of netdeferred income tax liability and $7 reduction of“Noncontrolling interest.” Our adoption of ASC 326 did nothave a material impact on our financial statements.'"
|
71 |
+
]
|
72 |
+
},
|
73 |
+
"execution_count": 10,
|
74 |
+
"metadata": {},
|
75 |
+
"output_type": "execute_result"
|
76 |
+
}
|
77 |
+
],
|
78 |
+
"source": [
|
79 |
+
"searches"
|
80 |
+
]
|
81 |
+
},
|
82 |
+
{
|
83 |
+
"cell_type": "code",
|
84 |
+
"execution_count": 15,
|
85 |
+
"metadata": {},
|
86 |
+
"outputs": [
|
87 |
+
{
|
88 |
+
"data": {
|
89 |
+
"text/plain": [
|
90 |
+
"\"Amazon's loss includes a pre-tax valuation loss of $2.3 billion included in non-operating income from their common stock investment in Rivian Automotive. This loss is not related to Amazon's ongoing operations but rather reflects quarter-to-quarter fluctuations in Rivian's stock price.\""
|
91 |
+
]
|
92 |
+
},
|
93 |
+
"execution_count": 15,
|
94 |
+
"metadata": {},
|
95 |
+
"output_type": "execute_result"
|
96 |
+
}
|
97 |
+
],
|
98 |
+
"source": [
|
99 |
+
"system_message = \"\"\"\n",
|
100 |
+
"You are a financial analyst, with a deep expertise in financial reports.\n",
|
101 |
+
"You are able to quickly understand a series of paragraphs, or quips even, extracted from financial reports by a vector search system. \n",
|
102 |
+
"\"\"\" \n",
|
103 |
+
"searches = \"\\n\".join([f\"Search result {i}: {v}\" for i,v in enumerate(vs,1)])\n",
|
104 |
+
"\n",
|
105 |
+
"question = \"What is Amazon's loss?\"\n",
|
106 |
+
"\n",
|
107 |
+
"user_prompt = f\"\"\"\n",
|
108 |
+
"Use the below context enclosed in triple back ticks to answer the question. \\n\n",
|
109 |
+
"The context is given by a vector search into a vector database of financial reports, so you can assume the context is accurate.\n",
|
110 |
+
"They search results are given in order of relevance (most relevant first). \\n\n",
|
111 |
+
"```\n",
|
112 |
+
"Context:\n",
|
113 |
+
"```\n",
|
114 |
+
"{searches}\n",
|
115 |
+
"```\n",
|
116 |
+
"Question:\\n\n",
|
117 |
+
"{question}\\n\n",
|
118 |
+
"------------------------\n",
|
119 |
+
"1. If the context does not provide enough information to answer the question, then\n",
|
120 |
+
"state that you cannot answer the question with the provided context.\n",
|
121 |
+
"2. Do not use any external knowledge or resources to answer the question.\n",
|
122 |
+
"3. Answer the question directly and with as much detail as possible, within the limits of the context.\n",
|
123 |
+
"------------------------\n",
|
124 |
+
"Answer:\\n\n",
|
125 |
+
"\"\"\".format(searches=searches, question=question)\n",
|
126 |
+
"\n",
|
127 |
+
"\n",
|
128 |
+
"response = llm.chat_completion(system_message=system_message,\n",
|
129 |
+
" user_message=user_prompt,\n",
|
130 |
+
" temperature=0.01,\n",
|
131 |
+
" stream=False,\n",
|
132 |
+
" raw_response=False)\n",
|
133 |
+
"response\n"
|
134 |
+
]
|
135 |
+
}
|
136 |
+
],
|
137 |
+
"metadata": {
|
138 |
+
"kernelspec": {
|
139 |
+
"display_name": "venv",
|
140 |
+
"language": "python",
|
141 |
+
"name": "python3"
|
142 |
+
},
|
143 |
+
"language_info": {
|
144 |
+
"codemirror_mode": {
|
145 |
+
"name": "ipython",
|
146 |
+
"version": 3
|
147 |
+
},
|
148 |
+
"file_extension": ".py",
|
149 |
+
"mimetype": "text/x-python",
|
150 |
+
"name": "python",
|
151 |
+
"nbconvert_exporter": "python",
|
152 |
+
"pygments_lexer": "ipython3",
|
153 |
+
"version": "3.10.14"
|
154 |
+
}
|
155 |
+
},
|
156 |
+
"nbformat": 4,
|
157 |
+
"nbformat_minor": 2
|
158 |
+
}
|
app/notebooks/pdf_readers.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
app/notebooks/upload_index.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
app/notebooks/weaviate.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
app/rag/__init__.py
ADDED
File without changes
|
app/rag/llm.py
ADDED
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# I didn't write this code
|
2 |
+
|
3 |
+
from litellm import completion, acompletion
|
4 |
+
from litellm.utils import CustomStreamWrapper, ModelResponse
|
5 |
+
import os
|
6 |
+
|
7 |
+
class LLM:
|
8 |
+
'''
|
9 |
+
Creates primary Class instance for interacting with various LLM model APIs.
|
10 |
+
Primary APIs supported are OpenAI and Anthropic.
|
11 |
+
'''
|
12 |
+
# non-exhaustive list of supported models
|
13 |
+
# these models are known to work
|
14 |
+
valid_models = {'openai': [
|
15 |
+
"gpt-4-turbo-preview",
|
16 |
+
"gpt-4-0125-preview",
|
17 |
+
"gpt-4-1106-preview",
|
18 |
+
"gpt-3.5-turbo",
|
19 |
+
"gpt-3.5-turbo-1106",
|
20 |
+
"gpt-3.5-turbo-0125",
|
21 |
+
],
|
22 |
+
'anthropic': [ 'claude-3-haiku-20240307',
|
23 |
+
'claude-3-sonnet-2024022',
|
24 |
+
'claude-3-opus-20240229'
|
25 |
+
],
|
26 |
+
'cohere': ['command-r',
|
27 |
+
'command-r-plus'
|
28 |
+
]
|
29 |
+
}
|
30 |
+
|
31 |
+
def __init__(self,
|
32 |
+
model_name: str='gpt-3.5-turbo-0125',
|
33 |
+
api_key: str=None,
|
34 |
+
api_version: str=None,
|
35 |
+
api_base: str=None
|
36 |
+
):
|
37 |
+
|
38 |
+
self.model_name = model_name
|
39 |
+
if not api_key:
|
40 |
+
try:
|
41 |
+
self._api_key = os.environ['OPENAI_API_KEY']
|
42 |
+
except KeyError:
|
43 |
+
raise ValueError('Default api_key expects OPENAI_API_KEY environment variable. Check that you have this variable or pass in another api_key.')
|
44 |
+
else:
|
45 |
+
self._api_key = api_key
|
46 |
+
self.api_version = api_version
|
47 |
+
self.api_base = api_base
|
48 |
+
|
49 |
+
|
50 |
+
def chat_completion(self,
|
51 |
+
system_message: str,
|
52 |
+
user_message: str='',
|
53 |
+
temperature: int=0,
|
54 |
+
max_tokens: int=500,
|
55 |
+
stream: bool=False,
|
56 |
+
raw_response: bool=False,
|
57 |
+
**kwargs
|
58 |
+
) -> str | CustomStreamWrapper | ModelResponse:
|
59 |
+
'''
|
60 |
+
Generative text completion method.
|
61 |
+
|
62 |
+
Args:
|
63 |
+
-----
|
64 |
+
system_message: str
|
65 |
+
The system message to be sent to the model.
|
66 |
+
user_message: str
|
67 |
+
The user message to be sent to the model.
|
68 |
+
temperature: int
|
69 |
+
The temperature parameter for the model.
|
70 |
+
max_tokens: int
|
71 |
+
The maximum tokens to be generated.
|
72 |
+
stream: bool
|
73 |
+
Whether to stream the response.
|
74 |
+
raw_response: bool
|
75 |
+
If True, returns the raw model response.
|
76 |
+
'''
|
77 |
+
#reformat roles for claude models
|
78 |
+
initial_role = 'user' if self.model_name.startswith('claude') else 'system'
|
79 |
+
secondary_role = 'assistant' if self.model_name.startswith('claude') else 'user'
|
80 |
+
|
81 |
+
#handle temperature for claude models
|
82 |
+
if self.model_name.startswith('claude'):
|
83 |
+
temperature = temperature/2
|
84 |
+
|
85 |
+
messages = [
|
86 |
+
{'role': initial_role, 'content': system_message},
|
87 |
+
{'role': secondary_role, 'content': user_message}
|
88 |
+
]
|
89 |
+
|
90 |
+
response = completion(model=self.model_name,
|
91 |
+
messages=messages,
|
92 |
+
temperature=temperature,
|
93 |
+
max_tokens=max_tokens,
|
94 |
+
stream=stream,
|
95 |
+
api_key=self._api_key,
|
96 |
+
api_base=self.api_base,
|
97 |
+
api_version=self.api_version,
|
98 |
+
**kwargs)
|
99 |
+
|
100 |
+
if raw_response or stream:
|
101 |
+
return response
|
102 |
+
return response.choices[0].message.content
|
103 |
+
|
104 |
+
async def achat_completion(self,
|
105 |
+
system_message: str,
|
106 |
+
user_message: str=None,
|
107 |
+
temperature: int=0,
|
108 |
+
max_tokens: int=500,
|
109 |
+
stream: bool=False,
|
110 |
+
raw_response: bool=False,
|
111 |
+
**kwargs
|
112 |
+
) -> str | CustomStreamWrapper | ModelResponse:
|
113 |
+
'''
|
114 |
+
Asynchronous generative text completion method.
|
115 |
+
|
116 |
+
Args:
|
117 |
+
-----
|
118 |
+
system_message: str
|
119 |
+
The system message to be sent to the model.
|
120 |
+
user_message: str
|
121 |
+
The user message to be sent to the model.
|
122 |
+
temperature: int
|
123 |
+
The temperature parameter for the model.
|
124 |
+
max_tokens: int
|
125 |
+
The maximum tokens to be generated.
|
126 |
+
stream: bool
|
127 |
+
Whether to stream the response.
|
128 |
+
raw_response: bool
|
129 |
+
If True, returns the raw model response.
|
130 |
+
'''
|
131 |
+
initial_role = 'user' if self.model_name.startswith('claude') else 'system'
|
132 |
+
if self.model_name.startswith('claude'):
|
133 |
+
temperature = temperature/2
|
134 |
+
messages = [
|
135 |
+
{'role': initial_role, 'content': system_message},
|
136 |
+
{'role': 'user', 'content': user_message}
|
137 |
+
]
|
138 |
+
response = await acompletion(model=self.model_name,
|
139 |
+
messages=messages,
|
140 |
+
temperature=temperature,
|
141 |
+
max_tokens=max_tokens,
|
142 |
+
stream=stream,
|
143 |
+
api_key=self._api_key,
|
144 |
+
api_base=self.api_base,
|
145 |
+
api_version=self.api_version,
|
146 |
+
**kwargs)
|
147 |
+
if raw_response or stream:
|
148 |
+
return response
|
149 |
+
return response.choices[0].message.content
|
app/rag/rag.py
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
from typing import List
|
3 |
+
|
4 |
+
from .llm import LLM
|
5 |
+
#the LLM Class uses the OPENAI_API_KEY env var as the default api_key
|
6 |
+
|
7 |
+
|
8 |
+
def rag_it(question: str,
|
9 |
+
search_results: List[str],
|
10 |
+
model: str = 'gpt-3.5-turbo-0125',
|
11 |
+
) -> str:
|
12 |
+
|
13 |
+
# TODO turn this into a class if time allows
|
14 |
+
llm = LLM(model)
|
15 |
+
|
16 |
+
system_message = """
|
17 |
+
You are a financial analyst, with a deep expertise in financial reports.
|
18 |
+
You are able to quickly understand a series of paragraphs, or quips even, extracted
|
19 |
+
from financial reports by a vector search system.
|
20 |
+
"""
|
21 |
+
|
22 |
+
searches = "\n".join([f"Search result {i}: {v}" for i,v in enumerate(search_results,1)])
|
23 |
+
|
24 |
+
user_prompt = f"""
|
25 |
+
Use the below context enclosed in triple back ticks to answer the question. \n
|
26 |
+
The context is given by a vector search into a vector database of financial reports,
|
27 |
+
so you can assume the context is accurate.
|
28 |
+
They search results are given in order of relevance (most relevant first). \n
|
29 |
+
```
|
30 |
+
Context:
|
31 |
+
```
|
32 |
+
{searches}
|
33 |
+
```
|
34 |
+
Question:\n
|
35 |
+
{question}\n
|
36 |
+
------------------------
|
37 |
+
1. If the context does not provide enough information to answer the question, then
|
38 |
+
state that you cannot answer the question with the provided context.
|
39 |
+
Pay great attention to making sure your answer is relevant to the question
|
40 |
+
(for instance, never answer a question about a topic or company that are not explicitely mentioned in the context)
|
41 |
+
2. Do not use any external knowledge or resources to answer the question.
|
42 |
+
3. Answer the question directly and with as much detail as possible, within the limits of the context.
|
43 |
+
4. Avoid mentioning 'search results' in the answer.
|
44 |
+
Instead, incorporate the information from the search results into the answer.
|
45 |
+
5. Create a clean answer, without backticks, or starting with a new line for instance.
|
46 |
+
------------------------
|
47 |
+
Answer:\n
|
48 |
+
""".format(searches=searches, question=question)
|
49 |
+
|
50 |
+
|
51 |
+
response = llm.chat_completion(system_message=system_message,
|
52 |
+
user_message=user_prompt,
|
53 |
+
temperature=0.01, # let's not allow the model to be creative
|
54 |
+
stream=False,
|
55 |
+
raw_response=False)
|
56 |
+
return response
|
app/requirements.txt
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
requests==2.31.0
|
2 |
+
pydantic==2.7.1
|
3 |
+
pydantic_core==2.18.2
|
4 |
+
fastapi==0.111.0
|
5 |
+
uvicorn[standard]
|
6 |
+
pdfplumber==0.11.0
|
7 |
+
weaviate-client==4.5.4
|
8 |
+
PyPDF2==3.0.1
|
9 |
+
PyMuPDF==1.24.3
|
10 |
+
llama-parse==0.4.2
|
11 |
+
llama-index-readers-file==0.1.22
|
12 |
+
nest_asyncio==1.6.0
|
13 |
+
llama-index==0.10.37
|
14 |
+
sentence-transformers==2.7.0
|
15 |
+
fastparquet==2024.2.0
|
16 |
+
litellm==1.37.12
|
17 |
+
langchain==0.1.20
|
18 |
+
langchain-community==0.0.38
|
19 |
+
langchain-core==0.1.52
|
20 |
+
langchain-text-splitters==0.0.1
|
app/settings.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
datadir = '../data' # will be used in main.py
|
4 |
+
parquet_file = os.path.join(datadir, 'text_vectors.parquet') # used by the files in 'engine'
|