|
import os, logging |
|
from app.engine.logger import logger |
|
|
|
from typing import List, Any |
|
import pandas as pd |
|
from weaviate.classes.config import Property, DataType |
|
|
|
from .weaviate_interface_v4 import WeaviateWCS, WeaviateIndexer |
|
|
|
from ..settings import parquet_file |
|
from weaviate.classes.query import Filter |
|
from torch import cuda |
|
|
|
# A '.we_are_local' marker file in the current working directory selects the
# local collection; without it the default collection is used.
COLLECTION = 'MultiRAG_local_mr' if os.path.exists('.we_are_local') else 'MultiRAG'
|
|
|
class dummyWeaviate:
    """No-op stand-in used when the real Weaviate client cannot be created.

    Created to get past the client-creation failure on HF; temporary
    solution.  ``VectorStore`` falls back to an instance of this class when
    building ``WeaviateWCS`` raises, so the application can still start.

    Every method accepts its arguments and ignores them.  The search methods
    return an empty list (no hits) rather than ``None`` because
    ``VectorStore`` iterates over search results; returning ``None`` made
    every search on the fallback client raise ``TypeError``.
    """

    def __init__(self,
                 endpoint: str=None,
                 api_key: str=None,
                 model_name_or_path: str='sentence-transformers/all-MiniLM-L6-v2',
                 embedded: bool=False,
                 openai_api_key: str=None,
                 skip_init_checks: bool=False,
                 **kwargs
                 ):
        """Accept the client constructor arguments and discard them."""
        # The stub holds no connection and no state.
        return

    def _connect(self) -> None:
        """Do nothing; there is no server to connect to."""
        return

    def _client(self):
        """Return ``None``; the stub has no underlying client object."""
        return

    def create_collection(self,
                          collection_name: str,
                          properties: list[Property],
                          description: str=None,
                          **kwargs
                          ) -> None:
        """Pretend to create a collection; nothing is stored."""
        return

    def show_all_collections(self,
                             detailed: bool=False,
                             max_details: bool=False
                             ) -> list[str] | dict:
        """Return a fixed list of placeholder collection names."""
        return ['abc', 'def']

    def show_collection_config(self, collection_name: str):
        """Return ``None``; the stub has no collection configuration."""
        return

    def show_collection_properties(self, collection_name: str):
        """Return ``None``; the stub has no collection properties."""
        return

    def delete_collection(self, collection_name: str):
        """Pretend to delete a collection; nothing happens."""
        return

    def get_doc_count(self, collection_name: str):
        """Return ``None``; the stub holds no documents to count."""
        return

    def keyword_search(self,
                       request: str,
                       collection_name: str,
                       query_properties: list[str]=['content'],
                       limit: int=10,
                       filter: Filter=None,
                       return_properties: list[str]=None,
                       return_raw: bool=False
                       ):
        """Return an empty result list for any keyword query."""
        # The list default above is never read or mutated by the stub.
        return []

    def vector_search(self,
                      request: str,
                      collection_name: str,
                      limit: int=10,
                      return_properties: list[str]=None,
                      filter: Filter=None,
                      return_raw: bool=False,
                      device: str='cuda:0' if cuda.is_available() else 'cpu'
                      ):
        """Return an empty result list for any vector query."""
        return []

    def hybrid_search(self,
                      request: str,
                      collection_name: str,
                      query_properties: list[str]=['content'],
                      alpha: float=0.5,
                      limit: int=10,
                      filter: Filter=None,
                      return_properties: list[str]=None,
                      return_raw: bool=False,
                      device: str='cuda:0' if cuda.is_available() else 'cpu'
                      ):
        """Return an empty result list for any hybrid query."""
        return []
|
|
|
class VectorStore:
    """Wrapper around a Weaviate collection of document splits.

    On construction it tries to build a ``WeaviateWCS`` client from the
    ``FINRAG_WEAVIATE_API_KEY`` and ``FINRAG_WEAVIATE_ENDPOINT`` environment
    variables.  If that fails for any reason it falls back to a
    ``dummyWeaviate`` stub so that the application can still start.  It then
    makes sure the default collection exists.
    """

    # Properties requested from the client when the caller does not specify
    # any.  Kept immutable; each call receives its own fresh list.
    _DEFAULT_RETURN_PROPERTIES = ('file', 'content')

    def __init__(self, model_path: str = 'sentence-transformers/all-mpnet-base-v2'):
        """Set up the schema description, the client and the default collection.

        Args:
            model_path: Passed to ``WeaviateWCS`` as ``model_name_or_path``.
        """
        # Schema of one stored object: the source file name and one text split.
        self.MultiRAG_properties = [
            Property(name='file',
                     data_type=DataType.TEXT,
                     description='Name of the file',
                     index_filterable=True,
                     index_searchable=True),
            Property(name='content',
                     data_type=DataType.TEXT,
                     description='Splits of the article',
                     index_filterable=True,
                     index_searchable=True),
        ]

        self.class_name = "MultiRAG_all-mpnet-base-v2"

        # Class-style configuration.  It is only stored as an attribute here;
        # no method of this class reads it.
        self.class_config = {'classes': [
            {"class": self.class_name,
             "description": "multiple types of docs",
             "vectorIndexType": "hnsw",
             "vectorIndexConfig": {
                 "ef": 64,
                 "efConstruction": 128,
                 "maxConnections": 32,
             },
             "vectorizer": "none",
             "properties": self.MultiRAG_properties}
        ]}

        self.model_path = model_path

        # Read both settings up front so the attributes always exist, even
        # when client creation fails below.
        self.api_key = os.environ.get('FINRAG_WEAVIATE_API_KEY')
        self.url = os.environ.get('FINRAG_WEAVIATE_ENDPOINT')

        try:
            # Fail with an explicit message instead of the opaque
            # "'NoneType' object is not subscriptable" from slicing None.
            if not self.api_key:
                raise ValueError("FINRAG_WEAVIATE_API_KEY is not set")
            # Only a short prefix of the key is logged, for debugging.
            logger(f"API key: {self.api_key[:5]}")
            if not self.url:
                raise ValueError("FINRAG_WEAVIATE_ENDPOINT is not set")
            # Logs characters 8..14 of the URL, presumably to skip a
            # leading "https://" -- TODO confirm.
            logger(f"URL: {self.url[8:15]}")
            self.client = WeaviateWCS(
                endpoint=self.url,
                api_key=self.api_key,
                model_name_or_path=self.model_path,
            )
            # Explicit checks instead of `assert`, which is removed when
            # Python runs with -O.
            if not self.client._client.is_live():
                raise ConnectionError("Weaviate is not live")
            if not self.client._client.is_ready():
                raise ConnectionError("Weaviate is not ready")
            logger("Weaviate client created")
        except Exception as e:
            # Deliberate best-effort: any failure switches to the no-op stub
            # so the rest of the application keeps working.
            self.client = dummyWeaviate()
            logger(f"Could not create Weaviate client: {e}")

        # Created lazily by index_data().
        self.indexer = None

        self.create_collection()

    @property
    def collections(self):
        """Collections currently reported by the client (queried each time)."""
        return self.client.show_all_collections()

    def create_collection(self,
                          collection_name: str=COLLECTION,
                          description: str='Documents'):
        """Select ``collection_name`` and create it if the client lacks it.

        The name is stored in ``self.collection_name`` and is used by the
        search methods.

        Args:
            collection_name: Name of the collection to use.
            description: Description passed to the client on creation.
        """
        self.collection_name = collection_name
        if collection_name in self.collections:
            logger(f"Collection {collection_name} already exists")
            return
        self.client.create_collection(collection_name=collection_name,
                                      properties=self.MultiRAG_properties,
                                      description=description)

    def empty_collection(self, collection_name: str=COLLECTION) -> bool:
        """Delete ``collection_name`` and recreate it empty.

        Args:
            collection_name: Name of the collection to empty.

        Returns:
            True if the collection existed and was recreated, False if the
            client did not report it.
        """
        if collection_name not in self.collections:
            logger(f"Collection {collection_name} doesn't exist")
            return False
        self.client.delete_collection(collection_name=collection_name)
        # Recreate the collection that was just deleted.  Calling
        # create_collection() without arguments would recreate the default
        # collection instead of this one.
        self.create_collection(collection_name)
        return True

    def index_data(self, data: List[dict]= None, collection_name: str=COLLECTION):
        """Index records into ``collection_name``.

        The result of the batch call is stored in ``self.status`` and also
        unpacked into ``self.num_errors``, ``self.error_messages`` and
        ``self.doc_ids``.

        Args:
            data: Records to index.  When None, they are loaded from the
                parquet file named by ``parquet_file`` in the settings.
            collection_name: Target collection.
        """
        if self.indexer is None:
            self.indexer = WeaviateIndexer(self.client)

        if data is None:
            data = pd.read_parquet(parquet_file).to_dict('records')

        # The third argument is presumably the batch size -- TODO confirm
        # against WeaviateIndexer.batch_index_data.
        self.status = self.indexer.batch_index_data(data, collection_name, 256)

        self.num_errors, self.error_messages, self.doc_ids = self.status

    @staticmethod
    def _as_tuples(response) -> List[tuple]:
        """Turn client hits into ``(file, content, score)`` tuples.

        A ``None`` response (as returned by a stub client) is treated as no
        results instead of raising ``TypeError``.
        """
        return [(hit['file'], hit['content'], hit['score'])
                for hit in (response or [])]

    def keyword_search(self,
                       query: str,
                       limit: int=5,
                       return_properties: List[str] | None=None,
                       alpha=None
                       ) -> List[tuple]:
        """Run a keyword search on the current collection.

        Args:
            query: Search text.
            limit: Maximum number of hits requested from the client.
            return_properties: Properties to request; defaults to
                ``['file', 'content']``.  Both are needed to build the
                returned tuples.
            alpha: Unused; kept so all search methods accept the same
                arguments.

        Returns:
            A list of ``(file, content, score)`` tuples.
        """
        if return_properties is None:
            return_properties = list(self._DEFAULT_RETURN_PROPERTIES)
        response = self.client.keyword_search(
            request=query,
            collection_name=self.collection_name,
            query_properties=['file', 'content'],
            limit=limit,
            filter=None,
            return_properties=return_properties,
            return_raw=False)
        return self._as_tuples(response)

    def vector_search(self,
                      query: str,
                      limit: int=5,
                      return_properties: List[str] | None=None,
                      alpha=None
                      ) -> List[tuple]:
        """Run a vector search on the current collection.

        Args:
            query: Search text.
            limit: Maximum number of hits requested from the client.
            return_properties: Properties to request; defaults to
                ``['file', 'content']``.  Both are needed to build the
                returned tuples.
            alpha: Unused; kept so all search methods accept the same
                arguments.

        Returns:
            A list of ``(file, content, score)`` tuples.
        """
        if return_properties is None:
            return_properties = list(self._DEFAULT_RETURN_PROPERTIES)
        response = self.client.vector_search(
            request=query,
            collection_name=self.collection_name,
            limit=limit,
            filter=None,
            return_properties=return_properties,
            return_raw=False)
        return self._as_tuples(response)

    def hybrid_search(self,
                      query: str,
                      limit: int=10,
                      alpha=0.5,
                      return_properties: List[str] | None=None
                      ) -> List[tuple]:
        """Run a hybrid search on the current collection.

        Args:
            query: Search text.
            limit: Maximum number of hits requested from the client.
            alpha: Forwarded unchanged to the client's ``hybrid_search``.
            return_properties: Properties to request; defaults to
                ``['file', 'content']``.  Both are needed to build the
                returned tuples.

        Returns:
            A list of ``(file, content, score)`` tuples.
        """
        if return_properties is None:
            return_properties = list(self._DEFAULT_RETURN_PROPERTIES)
        response = self.client.hybrid_search(
            request=query,
            collection_name=self.collection_name,
            query_properties=['file', 'content'],
            alpha=alpha,
            limit=limit,
            filter=None,
            return_properties=return_properties,
            return_raw=False)
        return self._as_tuples(response)