# Source: mr/app/engine/weaviate_interface_v4.py
# Uploaded by JPBianchi in commit 10d6a86 ("big upload"); original file size 22.1 kB.
# Disclaimer: I didn't write this module
from weaviate.auth import AuthApiKey
from weaviate.collections.classes.internal import (MetadataReturn, QueryReturn,
MetadataQuery)
import weaviate
from weaviate.classes.config import Property
from weaviate.classes.query import Filter
from weaviate.config import ConnectionConfig
from openai import OpenAI
from sentence_transformers import SentenceTransformer
from typing import Any
from torch import cuda
from tqdm import tqdm
import time
import os
from dataclasses import dataclass
class WeaviateWCS:
    '''
    A python native Weaviate Client class that encapsulates Weaviate functionalities
    in one object. Several convenience methods are added for ease of use.

    Args
    ----
    endpoint: str=None
        The url endpoint for the Weaviate Cloud Service instance.
        Not used to connect when ``embedded`` is True.
    api_key: str=None
        The API key for the Weaviate Cloud Service (WCS) instance.
        https://console.weaviate.cloud/dashboard
        Not used when ``embedded`` is True.
    model_name_or_path: str='sentence-transformers/all-MiniLM-L6-v2'
        The name or path of the SentenceTransformer model to use for vector search.
        Will also support OpenAI text-embedding-ada-002 model. This param enables
        the use of most leading models on MTEB Leaderboard:
        https://huggingface.co/spaces/mteb/leaderboard
        If empty/None, no embedding model is loaded and vector/hybrid search
        cannot build query vectors.
    embedded: bool=False
        If True, connect with ``weaviate.connect_to_embedded(**kwargs)``
        instead of connecting to WCS.
    openai_api_key: str=None
        The API key for the OpenAI API. Only required if using OpenAI text-embedding-ada-002 model.
    skip_init_checks: bool=False
        Forwarded to ``weaviate.connect_to_wcs``.
    **kwargs
        Forwarded to ``weaviate.connect_to_embedded`` (embedded mode only).

    Raises
    ------
    ValueError
        If ``model_name_or_path`` is 'text-embedding-ada-002' and no
        ``openai_api_key`` is given.

    Notes
    -----
    The collection-management methods close the underlying connection when
    they finish; the public methods that talk to Weaviate reopen it on demand
    through ``_connect``.
    '''
    def __init__(self,
                 endpoint: str=None,
                 api_key: str=None,
                 model_name_or_path: str='sentence-transformers/all-MiniLM-L6-v2',
                 embedded: bool=False,
                 openai_api_key: str=None,
                 skip_init_checks: bool=False,
                 **kwargs
                 ):
        self.endpoint = endpoint
        self.model_name_or_path = model_name_or_path
        self._openai_model = self.model_name_or_path == 'text-embedding-ada-002'
        # Validate before opening a connection, so a missing key does not leave
        # a connected client behind that nobody can close.
        if self._openai_model and not openai_api_key:
            raise ValueError(f'OpenAI API key must be provided to use this model: {self.model_name_or_path}')
        if embedded:
            self._client = weaviate.connect_to_embedded(**kwargs)
        else:
            auth_config = AuthApiKey(api_key=api_key)
            self._client = weaviate.connect_to_wcs(cluster_url=endpoint,
                                                   auth_credentials=auth_config,
                                                   skip_init_checks=skip_init_checks)
        if self._openai_model:
            self.model = OpenAI(api_key=openai_api_key)
        else:
            self.model = SentenceTransformer(self.model_name_or_path) if self.model_name_or_path else None
        # Properties returned by the search methods when the caller passes no
        # ``return_properties``.
        self.return_properties = ['guest', 'title', 'summary', 'content', 'video_id', 'doc_id', 'episode_url', 'thumbnail_url']

    def _connect(self) -> None:
        '''
        Connects to Weaviate instance if the client is not already connected.
        '''
        if not self._client.is_connected():
            self._client.connect()

    def create_collection(self,
                          collection_name: str,
                          properties: list[Property],
                          description: str=None,
                          **kwargs
                          ) -> None:
        '''
        Creates a collection (index) on the Weaviate instance.
        Does nothing (beyond printing a message) if the collection already exists.
        Creation errors are printed, not raised. The connection is closed on
        every path.

        Args
        ----
        collection_name: str
            Name of the collection to create.
        properties: list[Property]
            List of properties to add to data objects in the collection.
        description: str=None
            User-defined description of the collection.
        **kwargs
            Forwarded to ``client.collections.create``.
        '''
        self._connect()
        try:
            if self._client.collections.exists(collection_name):
                print(f'Collection "{collection_name}" already exists')
                return
            try:
                self._client.collections.create(name=collection_name,
                                                properties=properties,
                                                description=description,
                                                **kwargs)
                print(f'Collection "{collection_name}" created')
            except Exception as e:
                # Best-effort by design: report the failure instead of raising.
                print(f'Error creating collection, due to: {e}')
        finally:
            self._client.close()

    def show_all_collections(self,
                             detailed: bool=False,
                             max_details: bool=False
                             ) -> list[str] | dict:
        '''
        Shows all available collections(indexes) on the Weaviate cluster.

        Args
        ----
        detailed: bool=False
            If True, return the dict produced by ``collections.list_all(simple=True)``
            (name -> simplified config) instead of only the names.
        max_details: bool=False
            If True, return the dict produced by ``collections.list_all(simple=False)``
            (name -> full config). Takes precedence over ``detailed``.

        Returns
        -------
        A list of collection names by default, otherwise the dict described above.
        '''
        self._connect()
        collections = self._client.collections.list_all(simple=not max_details)
        self._client.close()
        if not detailed and not max_details:
            return list(collections.keys())
        if not collections:
            print('No collections found on host')
        return collections

    def show_collection_config(self, collection_name: str) -> Any:
        '''
        Returns the full configuration of a specific collection, as found in
        ``collections.list_all(simple=False)``, or None (after printing a
        message) if the collection does not exist.
        '''
        self._connect()
        try:
            if not self._client.collections.exists(collection_name):
                print(f'Collection "{collection_name}" not found on host')
                return None
            return self.show_all_collections(max_details=True)[collection_name]
        finally:
            self._client.close()

    def show_collection_properties(self, collection_name: str) -> Any:
        '''
        Returns the ``properties`` attribute of a collection's full configuration,
        or None (after printing a message) if the collection does not exist.
        '''
        self._connect()
        try:
            if not self._client.collections.exists(collection_name):
                print(f'Collection "{collection_name}" not found on host')
                return None
            return self.show_all_collections(max_details=True)[collection_name].properties
        finally:
            self._client.close()

    def delete_collection(self, collection_name: str) -> None:
        '''
        Deletes a collection (index) on the Weaviate instance, if it exists.
        Deletion errors are printed, not raised. The connection is closed on
        every path.
        '''
        self._connect()
        try:
            if not self._client.collections.exists(collection_name):
                print(f'Collection "{collection_name}" not found on host')
                return
            try:
                self._client.collections.delete(collection_name)
                print(f'Collection "{collection_name}" deleted')
            except Exception as e:
                print(f'Error deleting collection, due to: {e}')
        finally:
            self._client.close()

    def get_doc_count(self, collection_name: str) -> int | None:
        '''
        Returns the number of documents in a collection (via an ``over_all``
        aggregate), or None (after printing a message) if it does not exist.
        '''
        self._connect()
        if self._client.collections.exists(collection_name):
            collection = self._client.collections.get(collection_name)
            aggregate = collection.aggregate.over_all()
            total_count = aggregate.total_count
            print(f'Found {total_count} documents in collection "{collection_name}"')
            return total_count
        print(f'Collection "{collection_name}" not found on host')
        return None

    def format_response(self,
                        response: QueryReturn,
                        ) -> list[dict]:
        '''
        Formats a Weaviate query response into a list of dictionaries.
        Each dict holds the object's properties merged with its populated
        metadata fields (e.g. score, distance). On a key clash, the metadata
        value wins.
        '''
        return [{**o.properties, **self._get_meta(o.metadata)} for o in response.objects]

    def _get_meta(self, metadata: MetadataReturn) -> dict:
        '''
        Extracts the populated fields of a MetadataReturn object.
        Fields that Weaviate did not return are None and are dropped. Falsy but
        real values, such as a distance or score of 0.0, are kept.
        '''
        return {k: v for k, v in vars(metadata).items() if v is not None}

    def keyword_search(self,
                       request: str,
                       collection_name: str,
                       query_properties: list[str] | tuple[str, ...]=('content',),
                       limit: int=10,
                       filter: Filter=None,
                       return_properties: list[str]=None,
                       return_raw: bool=False
                       ) -> dict | list[dict]:
        '''
        Executes Keyword (BM25) search.

        Args
        ----
        request: str
            User query.
        collection_name: str
            Collection (index) to search.
        query_properties: list[str] | tuple[str, ...]=('content',)
            Properties to search across. None is forwarded to Weaviate unchanged.
        limit: int=10
            Number of results to return.
        filter: Filter=None
            Property filter to apply to search results.
        return_properties: list[str]=None
            list of properties to return in response.
            If None, returns self.return_properties.
        return_raw: bool=False
            If True, returns raw response from Weaviate.
        '''
        self._connect()
        return_properties = return_properties if return_properties else self.return_properties
        # Copy into a fresh list: the default is an immutable tuple so it can't be
        # shared/mutated across calls.
        query_properties = list(query_properties) if query_properties is not None else None
        collection = self._client.collections.get(collection_name)
        response = collection.query.bm25(query=request,
                                         query_properties=query_properties,
                                         limit=limit,
                                         filters=filter,
                                         return_metadata=MetadataQuery(score=True),
                                         return_properties=return_properties)
        return response if return_raw else self.format_response(response)

    def vector_search(self,
                      request: str,
                      collection_name: str,
                      limit: int=10,
                      return_properties: list[str]=None,
                      filter: Filter=None,
                      return_raw: bool=False,
                      device: str='cuda:0' if cuda.is_available() else 'cpu'
                      ) -> dict | list[dict]:
        '''
        Executes vector search using embedding model defined on instantiation
        of WeaviateWCS instance.

        Args
        ----
        request: str
            User query.
        collection_name: str
            Collection (index) to search.
        limit: int=10
            Number of results to return.
        return_properties: list[str]=None
            list of properties to return in response.
            If None, returns self.return_properties.
        filter: Filter=None
            Property filter to apply to search results.
        return_raw: bool=False
            If True, returns raw response from Weaviate.
        device: str
            Device used by the SentenceTransformer model to encode the query.
            The default ('cuda:0' if CUDA is available, else 'cpu') is chosen
            once, when the class is defined.
        '''
        self._connect()
        return_properties = return_properties if return_properties else self.return_properties
        query_vector = self._create_query_vector(request, device=device)
        collection = self._client.collections.get(collection_name)
        response = collection.query.near_vector(near_vector=query_vector,
                                                limit=limit,
                                                filters=filter,
                                                return_metadata=MetadataQuery(distance=True),
                                                return_properties=return_properties)
        return response if return_raw else self.format_response(response)

    def _create_query_vector(self, query: str, device: str) -> list[float]:
        '''
        Creates embedding vector from text query, using OpenAI or the local
        SentenceTransformer model depending on how the instance was configured.
        '''
        if self._openai_model:
            return self.get_openai_embedding(query)
        return self.model.encode(query, device=device).tolist()

    def get_openai_embedding(self, query: str) -> list[float]:
        '''
        Gets embedding from OpenAI API (text-embedding-ada-002) for query.

        Raises
        ------
        ValueError
            If the API response contains no embedding data.
        '''
        response = self.model.embeddings.create(input=query, model='text-embedding-ada-002').model_dump()
        # model_dump() always returns a (non-empty) dict; check the payload itself.
        data = response.get('data') if response else None
        if data:
            return data[0]['embedding']
        raise ValueError(f'No embedding found for query: {query}')

    def hybrid_search(self,
                      request: str,
                      collection_name: str,
                      query_properties: list[str] | tuple[str, ...]=('content',),
                      alpha: float=0.5,
                      limit: int=10,
                      filter: Filter=None,
                      return_properties: list[str]=None,
                      return_raw: bool=False,
                      device: str='cuda:0' if cuda.is_available() else 'cpu'
                      ) -> dict | list[dict]:
        '''
        Executes Hybrid (Keyword + Vector) search.

        Args
        ----
        request: str
            User query.
        collection_name: str
            Collection (index) to search.
        query_properties: list[str] | tuple[str, ...]=('content',)
            Properties to search across (using BM25). None is forwarded to
            Weaviate unchanged.
        alpha: float=0.5
            Weighting factor for BM25 and Vector search.
            alpha can be any number from 0 to 1, defaulting to 0.5:
                alpha = 0 executes a pure keyword search method (BM25)
                alpha = 0.5 weighs the BM25 and vector methods evenly
                alpha = 1 executes a pure vector search method
        limit: int=10
            Number of results to return.
        filter: Filter=None
            Property filter to apply to search results.
        return_properties: list[str]=None
            list of properties to return in response.
            If None, returns self.return_properties.
        return_raw: bool=False
            If True, returns raw response from Weaviate.
        device: str
            Device used by the SentenceTransformer model to encode the query.
        '''
        self._connect()
        return_properties = return_properties if return_properties else self.return_properties
        query_properties = list(query_properties) if query_properties is not None else None
        query_vector = self._create_query_vector(request, device=device)
        collection = self._client.collections.get(collection_name)
        response = collection.query.hybrid(query=request,
                                           query_properties=query_properties,
                                           filters=filter,
                                           vector=query_vector,
                                           alpha=alpha,
                                           limit=limit,
                                           return_metadata=MetadataQuery(score=True, distance=True),
                                           return_properties=return_properties)
        return response if return_raw else self.format_response(response)
class WeaviateIndexer:
    '''
    Batch-indexes documents into Weaviate, reusing the underlying client of an
    existing WeaviateWCS instance.
    '''
    def __init__(self,
                 client: WeaviateWCS
                 ):
        '''
        Args
        ----
        client: WeaviateWCS
            Wrapper whose underlying Weaviate client (``client._client``) is
            shared by this indexer. Connecting/closing here affects that client too.
        '''
        self._client = client._client

    def _connect(self):
        '''
        Connects to Weaviate instance if the client is not already connected.
        '''
        if not self._client.is_connected():
            self._client.connect()

    def create_collection(self,
                          collection_name: str,
                          properties: list[Property],
                          description: str=None,
                          **kwargs
                          ) -> None:
        '''
        Creates a collection (index) on the Weaviate instance.
        Creation errors are printed, not raised. The connection is closed on
        every path once the name has been validated.

        Raises
        ------
        ValueError
            If ``collection_name`` contains a hyphen.
        '''
        if '-' in collection_name:
            raise ValueError('Collection name cannot contain hyphens')
        try:
            self._connect()
            self._client.collections.create(name=collection_name,
                                            description=description,
                                            properties=properties,
                                            **kwargs
                                            )
            if self._client.collections.exists(collection_name):
                print(f'Collection "{collection_name}" created')
            else:
                print('Collection not found at the moment, try again later')
        except Exception as e:
            # Best-effort by design: report the failure instead of raising.
            print(f'Error creating collection, due to: {e}')
        finally:
            self._client.close()

    def batch_index_data(self,
                         data: list[dict],
                         collection_name: str,
                         error_threshold: float=0.01,
                         vector_property: str='content_embedding',
                         unique_id_field: str='doc_id',
                         properties: list[Property]=None,
                         collection_description: str=None,
                         **kwargs
                         ) -> dict:
        '''
        Batch function for fast indexing of data onto Weaviate cluster.

        Args
        ----
        data: list[dict]
            List of dictionaries where each dictionary represents a document.
            Every document must contain ``vector_property``.
        collection_name: str
            Name of the collection to index data into.
        error_threshold: float=0.01
            Tolerated fraction of ``data`` that may fail (0.01 == 1%). Once the
            batch error count exceeds ``int(len(data) * error_threshold)``, the
            loop stops adding objects.
            NOTE(review): with dynamic batching, errors are only counted after a
            batch is flushed, so the abort may lag behind the actual failures.
        vector_property: str='content_embedding'
            Name of the document key holding the vector. It is sent as the
            object's vector and excluded from its properties.
        unique_id_field: str='doc_id'
            Property used to identify failed documents in the error report.
        properties: list[Property]=None
            List of properties to create the collection with. Required if collection does not exist.
        collection_description: str=None
            Description of the collection. Optional parameter.
        **kwargs
            Forwarded to ``create_collection`` when the collection is created.

        Returns
        -------
        dict
            Dictionary containing error information with the following keys:
            ['num_errors', 'error_messages', 'doc_ids']

        Raises
        ------
        ValueError
            If the collection does not exist and ``properties`` is None.
        '''
        self._connect()
        if not self._client.collections.exists(collection_name):
            print(f'Collection "{collection_name}" not found on host, creating Collection first...')
            if properties is None:
                # Don't leave the connection open when bailing out.
                self._client.close()
                raise ValueError(f'Tried to create Collection <{collection_name}> but no properties were provided.')
            self.create_collection(collection_name=collection_name,
                                   properties=properties,
                                   description=collection_description,
                                   **kwargs)
        # create_collection closes the connection; cycle it so the batch below
        # starts on a fresh, open connection.
        self._client.close()
        self._connect()

        error_threshold_size = int(len(data) * error_threshold)
        collection = self._client.collections.get(collection_name)
        start = time.perf_counter()
        completed_job = True
        with collection.batch.dynamic() as batch:
            for doc in tqdm(data):
                batch.add_object(properties={k: v for k, v in doc.items() if k != vector_property},
                                 vector=doc[vector_property])
                if batch.number_errors > error_threshold_size:
                    print('Upload errors exceed error_threshold...')
                    completed_job = False
                    break
        end = time.perf_counter() - start
        print(f'Processing finished in {round(end/60, 2)} minutes.')

        failed_objects = collection.batch.failed_objects
        error_messages = [obj.message for obj in failed_objects]
        doc_ids = [obj.object_.properties.get(unique_id_field, 'Not Found') for obj in failed_objects]
        error_object = {'num_errors': batch.number_errors,
                        'error_messages': error_messages,
                        'doc_ids': doc_ids}
        if not completed_job:
            print(f'Batch job failed. Review errors using these keys: {list(error_object.keys())}')
            return error_object
        if batch.number_errors > 0:
            print(f'Batch job completed with {batch.number_errors} errors. Review errors using these keys: {list(error_object.keys())}')
        else:
            print('Batch job completed with zero errors.')
        return error_object
@dataclass
class SearchFilter(Filter):
    '''
    Convenience wrapper that builds Weaviate property filters from a single
    (property, value) pair.

    Args
    ----
    property : str
        Name of the property to filter on.
    query_value : str
        Value the property is compared against.
    '''
    property: str
    query_value: str

    def exact_match(self):
        '''Return a filter requiring ``property`` to equal ``query_value``.'''
        target = self.by_property(self.property)
        return target.equal(self.query_value)

    def fuzzy_match(self):
        '''
        Return a filter matching ``property`` values that contain
        ``query_value`` anywhere, using a ``*value*`` LIKE pattern.
        '''
        target = self.by_property(self.property)
        pattern = f'*{self.query_value}*'
        return target.like(pattern)
def get_weaviate_client(endpoint: str=None,
                        api_key: str=None,
                        model_name_or_path: str='sentence-transformers/all-MiniLM-L6-v2',
                        embedded: bool=False,
                        openai_api_key: str=None,
                        skip_init_checks: bool=False,
                        **kwargs
                        ) -> WeaviateWCS:
    '''
    Factory returning a configured WeaviateWCS client.

    Args
    ----
    endpoint: str=None
        WCS endpoint. If None, read from the FINRAG_WEAVIATE_ENDPOINT
        environment variable at call time.
    api_key: str=None
        WCS API key. If None, read from the FINRAG_WEAVIATE_API_KEY
        environment variable at call time.
    model_name_or_path, embedded, openai_api_key, skip_init_checks, **kwargs
        Forwarded unchanged to WeaviateWCS.
    '''
    # Resolve the environment at call time. A default of os.getenv(...) in the
    # signature is evaluated once at import, so variables set afterwards (e.g.
    # by a dotenv loader) would be silently ignored.
    if endpoint is None:
        endpoint = os.getenv('FINRAG_WEAVIATE_ENDPOINT')
    if api_key is None:
        api_key = os.getenv('FINRAG_WEAVIATE_API_KEY')
    return WeaviateWCS(endpoint=endpoint,
                       api_key=api_key,
                       model_name_or_path=model_name_or_path,
                       embedded=embedded,
                       openai_api_key=openai_api_key,
                       skip_init_checks=skip_init_checks,
                       **kwargs)