import logging
import os
from typing import Any, List, Optional

import pandas as pd
from weaviate.classes.config import Property, DataType

from settings import parquet_file

from .logger import logger
from .weaviate_interface_v4 import WeaviateWCS, WeaviateIndexer


class VectorStore:
    """Wrapper around a Weaviate collection of financial-report text splits.

    Constructing an instance connects to Weaviate using the
    ``FINRAG_WEAVIATE_API_KEY`` and ``FINRAG_WEAVIATE_ENDPOINT`` environment
    variables, verifies that the server reports live and ready, and makes
    sure the default collection exists.
    """

    # Properties requested from the client when a search caller does not
    # specify any. Stored as a tuple so no mutable object is shared between
    # calls; each call receives its own fresh list.
    _DEFAULT_RETURN_PROPERTIES = ('filename', 'content')

    def __init__(self, model_path: str = 'sentence-transformers/all-mpnet-base-v2'):
        """Connect to Weaviate and ensure the default collection exists.

        Args:
            model_path: Model name or path handed to ``WeaviateWCS`` as
                ``model_name_or_path``.

        Raises:
            RuntimeError: If either required environment variable is unset.
            AssertionError: If the Weaviate server is not live or not ready.
            Exception: Anything raised while constructing ``WeaviateWCS`` is
                propagated unchanged instead of being swallowed.
        """
        # Schema of every object stored in the collection.
        self.finrag_properties = [
            Property(name='filename',
                     data_type=DataType.TEXT,
                     description='Name of the file',
                     index_filterable=True,
                     index_searchable=True),
            Property(name='content',
                     data_type=DataType.TEXT,
                     description='Splits of the article',
                     index_filterable=True,
                     index_searchable=True),
        ]

        self.class_name = "FinRag_all-mpnet-base-v2"

        # Dict-style class definition kept as a public attribute; nothing in
        # this class reads it.
        self.class_config = {'classes': [
            {"class": self.class_name,
             "description": "Financial reports",
             "vectorIndexType": "hnsw",
             "vectorIndexConfig": {
                 "ef": 64,
                 "efConstruction": 128,
                 "maxConnections": 32,
             },
             "vectorizer": "none",
             "properties": self.finrag_properties}
        ]
        }

        self.model_path = model_path

        # Fail fast with a message naming the missing variable. Previously
        # every error here was silently swallowed and only surfaced later as
        # an unrelated AttributeError on ``self.client``.
        try:
            self.api_key = os.environ['FINRAG_WEAVIATE_API_KEY']
            self.url = os.environ['FINRAG_WEAVIATE_ENDPOINT']
        except KeyError as err:
            raise RuntimeError(
                f"Missing required environment variable {err}"
            ) from err

        self.client = WeaviateWCS(endpoint=self.url,
                                  api_key=self.api_key,
                                  model_name_or_path=self.model_path)

        # Raised explicitly (rather than via ``assert``) so the checks still
        # run under ``python -O``; the exception type is kept for callers.
        if not self.client._client.is_live():
            raise AssertionError("Weaviate is not live")
        if not self.client._client.is_ready():
            raise AssertionError("Weaviate is not ready")

        # Created lazily by ``index_data``.
        self.indexer = None

        self.create_collection()

    @property
    def collections(self):
        """Return the result of the client's ``show_all_collections()``.

        The other methods use it for ``name in self.collections`` checks.
        """
        return self.client.show_all_collections()

    def create_collection(self, collection_name: str = 'Finrag', description: str = 'Financial reports'):
        """Create ``collection_name`` if it is missing and make it current.

        ``self.collection_name`` is set to ``collection_name`` whether or not
        the collection already existed; the search methods query that
        collection.

        Args:
            collection_name: Name of the collection to create or select.
            description: Description passed to the client on creation.
        """
        self.collection_name = collection_name
        if collection_name in self.collections:
            logging.warning(f"Collection {collection_name} already exists")
            return
        self.client.create_collection(collection_name=collection_name,
                                      properties=self.finrag_properties,
                                      description=description)

    def empty_collection(self, collection_name: str = 'Finrag') -> bool:
        """Delete ``collection_name`` and recreate it empty.

        Args:
            collection_name: Name of the collection to reset.

        Returns:
            True if the collection existed and was reset, False (after
            logging a warning) if it did not exist.
        """
        if collection_name not in self.collections:
            logging.warning(f"Collection {collection_name} doesn't exist")
            return False
        self.client.delete_collection(collection_name=collection_name)
        # Recreate the collection that was just deleted, not the default one.
        self.create_collection(collection_name)
        return True

    def index_data(self, data: Optional[List[dict]] = None, collection_name: str = 'Finrag'):
        """Index records into ``collection_name``.

        The indexer's result is stored on ``self.status`` and unpacked into
        ``self.num_errors``, ``self.error_messages`` and ``self.doc_ids``.

        Args:
            data: Records to index. When None, the records are loaded from
                the parquet file named by ``settings.parquet_file``.
            collection_name: Target collection. Note the default is the
                literal ``'Finrag'``, independent of ``self.collection_name``.
        """
        if self.indexer is None:
            self.indexer = WeaviateIndexer(self.client)

        if data is None:
            data = pd.read_parquet(parquet_file).to_dict('records')

        # NOTE(review): the meaning of the third positional argument (256) is
        # defined by WeaviateIndexer.batch_index_data, which is not visible
        # here; it is passed through unchanged.
        self.status = self.indexer.batch_index_data(data, collection_name, 256)
        self.num_errors, self.error_messages, self.doc_ids = self.status

    @classmethod
    def _resolve_return_properties(cls, return_properties):
        """Return the caller's property list, or a fresh default list."""
        if return_properties is None:
            return list(cls._DEFAULT_RETURN_PROPERTIES)
        return return_properties

    @staticmethod
    def _contents(response) -> List[str]:
        """Extract the ``'content'`` value of every result in ``response``."""
        return [hit['content'] for hit in response]

    def keyword_search(self,
                       query: str,
                       limit: int = 5,
                       return_properties: Optional[List[str]] = None,
                       alpha=None
                       ) -> List[str]:
        """Run the client's keyword search over the ``content`` property.

        Args:
            query: Search text.
            limit: Maximum number of results requested from the client.
            return_properties: Properties to request; defaults to
                ``['filename', 'content']``. Each result is indexed by
                ``'content'``, so that key must be present in the results.
            alpha: Ignored; accepted so all search methods share a signature.

        Returns:
            The ``'content'`` value of each result.
        """
        response = self.client.keyword_search(
            request=query,
            collection_name=self.collection_name,
            query_properties=['content'],
            limit=limit,
            filter=None,
            return_properties=self._resolve_return_properties(return_properties),
            return_raw=False)
        return self._contents(response)

    def vector_search(self,
                      query: str,
                      limit: int = 5,
                      return_properties: Optional[List[str]] = None,
                      alpha=None
                      ) -> List[str]:
        """Run the client's vector search for ``query``.

        Args:
            query: Search text.
            limit: Maximum number of results requested from the client.
            return_properties: Properties to request; defaults to
                ``['filename', 'content']``. Each result is indexed by
                ``'content'``, so that key must be present in the results.
            alpha: Ignored; accepted so all search methods share a signature.

        Returns:
            The ``'content'`` value of each result.
        """
        response = self.client.vector_search(
            request=query,
            collection_name=self.collection_name,
            limit=limit,
            filter=None,
            return_properties=self._resolve_return_properties(return_properties),
            return_raw=False)
        return self._contents(response)

    def hybrid_search(self,
                      query: str,
                      limit: int = 5,
                      alpha=0.5,
                      return_properties: Optional[List[str]] = None
                      ) -> List[str]:
        """Run the client's hybrid search over the ``content`` property.

        Args:
            query: Search text.
            limit: Maximum number of results requested from the client.
            alpha: Passed through to the client's ``hybrid_search``.
            return_properties: Properties to request; defaults to
                ``['filename', 'content']``. Each result is indexed by
                ``'content'``, so that key must be present in the results.

        Returns:
            The ``'content'`` value of each result.
        """
        response = self.client.hybrid_search(
            request=query,
            collection_name=self.collection_name,
            query_properties=['content'],
            alpha=alpha,
            limit=limit,
            filter=None,
            return_properties=self._resolve_return_properties(return_properties),
            return_raw=False)
        return self._contents(response)