mr / app /engine /vectorstore.py
JPBianchi's picture
secrets
d811058
raw
history blame
7.71 kB
import logging
import os
from typing import List, Any, Optional

import pandas as pd
from weaviate.classes.config import Property, DataType

from settings import parquet_file

from .weaviate_interface_v4 import WeaviateWCS, WeaviateIndexer
from .logger import logger
class VectorStore:
    """Thin wrapper around the project's Weaviate helper for the FinRag collection.

    On construction it connects to the Weaviate instance described by the
    ``FINRAG_WEAVIATE_API_KEY`` and ``FINRAG_WEAVIATE_ENDPOINT`` environment
    variables, checks that the instance is live and ready, and makes sure the
    default collection exists.  It then exposes indexing plus keyword, vector
    and hybrid search, each returning the ``content`` field of the hits.
    """

    # Properties requested from a search when the caller does not specify any.
    _DEFAULT_RETURN_PROPERTIES = ('filename', 'content')

    def __init__(self, model_path: str = 'sentence-transformers/all-mpnet-base-v2'):
        """Connect to Weaviate and ensure the default collection exists.

        Args:
            model_path: Name or path of the embedding model, forwarded to
                ``WeaviateWCS`` as ``model_name_or_path``.

        Raises:
            RuntimeError: If the Weaviate client cannot be created.
            AssertionError: If the Weaviate instance is not live or not ready.
        """
        # we can create several instances to test various models, especially if we finetune one
        self.finrag_properties = [
            Property(name='filename',
                     data_type=DataType.TEXT,
                     description='Name of the file',
                     index_filterable=True,
                     index_searchable=True),
            # Property(name='keywords',
            #          data_type=DataType.TEXT_ARRAY,
            #          description='Keywords associated with the file',
            #          index_filterable=True,
            #          index_searchable=True),
            Property(name='content',
                     data_type=DataType.TEXT,
                     description='Splits of the article',
                     index_filterable=True,
                     index_searchable=True),
        ]

        # NOTE: class_name / class_config are kept as public attributes but are
        # not used by any method of this class; collections are created from
        # `finrag_properties` in `create_collection`.
        self.class_name = "FinRag_all-mpnet-base-v2"
        self.class_config = {'classes': [
            {"class": self.class_name,
             "description": "Financial reports",
             "vectorIndexType": "hnsw",
             # Vector index specific settings for HNSW
             "vectorIndexConfig": {
                 "ef": 64,                # higher is better quality vs slower search
                 "efConstruction": 128,   # higher = better index but slower build
                 "maxConnections": 32,    # max conn per layer - higher = more memory
             },
             "vectorizer": "none",
             "properties": self.finrag_properties}
        ]}

        self.model_path = model_path
        self.api_key = os.environ.get('FINRAG_WEAVIATE_API_KEY')
        self.url = os.environ.get('FINRAG_WEAVIATE_ENDPOINT')
        try:
            self.client = WeaviateWCS(endpoint=self.url,
                                      api_key=self.api_key,
                                      model_name_or_path=self.model_path)
        except Exception as e:
            # Fail here with the real cause instead of swallowing it and
            # crashing below on a missing `self.client` attribute.
            raise RuntimeError("Could not create Weaviate client") from e

        # careful with accessing '_client' since the weaviate helper usually closes the connection every time
        assert self.client._client.is_live(), "Weaviate is not live"
        assert self.client._client.is_ready(), "Weaviate is not ready"

        self.indexer = None  # created lazily by index_data
        self.create_collection()

    @property
    def collections(self):
        """Return whatever the client reports from ``show_all_collections()``."""
        return self.client.show_all_collections()

    def create_collection(self, collection_name: str = 'Finrag', description: str = 'Financial reports'):
        """Create ``collection_name`` if it is missing and make it the active collection.

        ``self.collection_name`` is set in both cases, so subsequent searches
        target this collection.  If the collection already exists a warning is
        logged and nothing is created.

        Args:
            collection_name: Name of the collection to create / select.
            description: Description passed to the client on creation.
        """
        self.collection_name = collection_name
        if collection_name in self.collections:
            logging.warning(f"Collection {collection_name} already exists")
            return
        self.client.create_collection(collection_name=collection_name,
                                      properties=self.finrag_properties,
                                      description=description)

    def empty_collection(self, collection_name: str = 'Finrag') -> bool:
        """Empty a collection by deleting and recreating it.

        Args:
            collection_name: Name of the collection to empty.

        Returns:
            True if the collection existed and was recreated, False if it did
            not exist (a warning is logged in that case).
        """
        # not in the library yet, so I simply delete and recreate it
        if collection_name not in self.collections:
            logging.warning(f"Collection {collection_name} doesn't exist")
            return False
        self.client.delete_collection(collection_name=collection_name)
        # Recreate the SAME collection that was deleted, not the default one.
        self.create_collection(collection_name)
        return True

    def index_data(self,
                   data: Optional[List[dict]] = None,
                   collection_name: str = 'Finrag',
                   batch_size: int = 256):
        """Index records into ``collection_name``.

        Args:
            data: Records to index.  When None, the records are read from the
                parquet file configured in ``settings.parquet_file``.
            collection_name: Target collection.
            batch_size: Batch size forwarded to the indexer.

        The indexer's result is stored in ``self.status`` and unpacked into
        ``self.num_errors``, ``self.error_messages`` and ``self.doc_ids``.
        """
        if self.indexer is None:
            self.indexer = WeaviateIndexer(self.client)

        if data is None:
            # use the parquet file, otherwise use the data passed
            # the parquet file was created/incremented when a new article was uploaded
            # it is a dataframe with columns: file, content, content_embedding
            # and reflects exactly the data that we want to index at all times
            data = pd.read_parquet(parquet_file).to_dict('records')

        self.status = self.indexer.batch_index_data(data, collection_name, batch_size)
        self.num_errors, self.error_messages, self.doc_ids = self.status
        # in this case with few articles, we don't tolerate errors
        # batch_index_data already tests errors against a threshold
        # assert self.num_errors == 0, f"Errors: {self.num_errors}"

    @classmethod
    def _resolve_return_properties(cls, return_properties: Optional[List[str]]) -> List[str]:
        """Return a fresh property list that always includes ``'content'``.

        The search methods read ``res['content']`` from every hit, so that
        property has to be requested even if the caller left it out.
        """
        if return_properties is None:
            return list(cls._DEFAULT_RETURN_PROPERTIES)
        resolved = list(return_properties)
        if 'content' not in resolved:
            resolved.append('content')
        return resolved

    def keyword_search(self,
                       query: str,
                       limit: int = 5,
                       return_properties: Optional[List[str]] = None,
                       alpha=None  # dummy parameter to match the hybrid_search signature
                       ) -> List[str]:
        """Run a keyword search on the ``content`` property of the active collection.

        Args:
            query: Search text.
            limit: Maximum number of hits requested.
            return_properties: Properties to request; defaults to
                ``['filename', 'content']``.  ``'content'`` is always added.
            alpha: Ignored; present only for signature parity with hybrid_search.

        Returns:
            The ``content`` value of each hit.
        """
        response = self.client.keyword_search(
            request=query,
            collection_name=self.collection_name,
            query_properties=['content'],
            limit=limit,
            filter=None,
            return_properties=self._resolve_return_properties(return_properties),
            return_raw=False)
        return [res['content'] for res in response]

    def vector_search(self,
                      query: str,
                      limit: int = 5,
                      return_properties: Optional[List[str]] = None,
                      alpha=None  # dummy parameter to match the hybrid_search signature
                      ) -> List[str]:
        """Run a vector search on the active collection.

        Args:
            query: Search text.
            limit: Maximum number of hits requested.
            return_properties: Properties to request; defaults to
                ``['filename', 'content']``.  ``'content'`` is always added.
            alpha: Ignored; present only for signature parity with hybrid_search.

        Returns:
            The ``content`` value of each hit.
        """
        response = self.client.vector_search(
            request=query,
            collection_name=self.collection_name,
            limit=limit,
            filter=None,
            return_properties=self._resolve_return_properties(return_properties),
            return_raw=False)
        return [res['content'] for res in response]

    def hybrid_search(self,
                      query: str,
                      limit: int = 5,
                      alpha=0.5,  # higher = more vector search
                      return_properties: Optional[List[str]] = None
                      ) -> List[str]:
        """Run a hybrid (keyword + vector) search on the active collection.

        Args:
            query: Search text.
            limit: Maximum number of hits requested.
            alpha: Weighting forwarded to the client; higher = more vector search.
            return_properties: Properties to request; defaults to
                ``['filename', 'content']``.  ``'content'`` is always added.

        Returns:
            The ``content`` value of each hit.
        """
        response = self.client.hybrid_search(
            request=query,
            collection_name=self.collection_name,
            query_properties=['content'],
            alpha=alpha,
            limit=limit,
            filter=None,
            return_properties=self._resolve_return_properties(return_properties),
            return_raw=False)
        return [res['content'] for res in response]