|
|
|
|
|
from weaviate.auth import AuthApiKey |
|
from weaviate.collections.classes.internal import (MetadataReturn, QueryReturn, |
|
MetadataQuery) |
|
import weaviate |
|
from weaviate.classes.config import Property |
|
from weaviate.classes.query import Filter |
|
from weaviate.config import ConnectionConfig |
|
from openai import OpenAI |
|
from sentence_transformers import SentenceTransformer |
|
from typing import Any |
|
from torch import cuda |
|
from tqdm import tqdm |
|
import time |
|
import os |
|
from dataclasses import dataclass |
|
|
|
class WeaviateWCS:

    '''
    A python native Weaviate Client class that encapsulates Weaviate functionalities
    in one object. Several convenience methods are added for ease of use.

    Args
    ----
    endpoint: str=None
        The url endpoint for the Weaviate Cloud Service instance.

    api_key: str=None
        The API key for the Weaviate Cloud Service (WCS) instance.
        https://console.weaviate.cloud/dashboard

    model_name_or_path: str='sentence-transformers/all-MiniLM-L6-v2'
        The name or path of the SentenceTransformer model to use for vector search.
        Will also support OpenAI text-embedding-ada-002 model. This param enables
        the use of most leading models on MTEB Leaderboard:
        https://huggingface.co/spaces/mteb/leaderboard

    embedded: bool=False
        If True, connects to an embedded Weaviate instance via
        weaviate.connect_to_embedded (endpoint/api_key are not used).

    openai_api_key: str=None
        The API key for the OpenAI API. Only required if using OpenAI text-embedding-ada-002 model.

    skip_init_checks: bool=False
        Forwarded to weaviate.connect_to_wcs to skip initial connection checks.
    '''

    def __init__(self,
                 endpoint: str=None,
                 api_key: str=None,
                 model_name_or_path: str='sentence-transformers/all-MiniLM-L6-v2',
                 embedded: bool=False,
                 openai_api_key: str=None,
                 skip_init_checks: bool=False,
                 **kwargs
                 ):

        self.endpoint = endpoint
        if embedded:
            # kwargs are forwarded only to the embedded connection helper
            self._client = weaviate.connect_to_embedded(**kwargs)
        else:
            auth_config = AuthApiKey(api_key=api_key)
            self._client = weaviate.connect_to_wcs(cluster_url=endpoint,
                                                   auth_credentials=auth_config,
                                                   skip_init_checks=skip_init_checks)
        self.model_name_or_path = model_name_or_path
        self._openai_model = False
        if self.model_name_or_path == 'text-embedding-ada-002':
            if not openai_api_key:
                raise ValueError(f'OpenAI API key must be provided to use this model: {self.model_name_or_path}')
            self.model = OpenAI(api_key=openai_api_key)
            self._openai_model = True
        else:
            # any other (truthy) model name is treated as a SentenceTransformer model
            self.model = SentenceTransformer(self.model_name_or_path) if self.model_name_or_path else None

        # default properties returned by the search methods when the caller
        # does not pass return_properties explicitly
        self.return_properties = ['guest', 'title', 'summary', 'content', 'video_id', 'doc_id', 'episode_url', 'thumbnail_url']

    def _connect(self) -> None:
        '''
        Connects to Weaviate instance (no-op if the client is already connected).
        '''
        if not self._client.is_connected():
            self._client.connect()

    def create_collection(self,
                          collection_name: str,
                          properties: list[Property],
                          description: str=None,
                          **kwargs
                          ) -> None:
        '''
        Creates a collection (index) on the Weaviate instance.
        If the collection already exists, prints a notice and does nothing.

        Args
        ----
        collection_name: str
            Name of the collection to create.
        properties: list[Property]
            List of properties to add to data objects in the collection.
        description: str=None
            User-defined description of the collection.
        '''
        self._connect()
        if self._client.collections.exists(collection_name):
            print(f'Collection "{collection_name}" already exists')
            return
        try:
            self._client.collections.create(name=collection_name,
                                            properties=properties,
                                            description=description,
                                            **kwargs)
            print(f'Collection "{collection_name}" created')
        except Exception as e:
            print(f'Error creating collection, due to: {e}')
        # connection is closed whether creation succeeded or failed
        self._client.close()

    def show_all_collections(self,
                             detailed: bool=False,
                             max_details: bool=False
                             ) -> list[str] | dict:
        '''
        Shows all available collections(indexes) on the Weaviate cluster.
        By default will only return list of collection names.
        Otherwise, increasing details about each collection can be returned.
        '''
        self._connect()
        collections = self._client.collections.list_all(simple=not max_details)
        self._client.close()
        if not detailed and not max_details:
            return list(collections.keys())
        else:
            if not any(collections):
                print('No collections found on host')
            return collections

    def show_collection_config(self, collection_name: str) -> Any:
        '''
        Shows all information of a specific collection.
        Returns the collection's full config object, or None (with a printed
        notice) if the collection does not exist.
        '''
        self._connect()
        if self._client.collections.exists(collection_name):
            collection = self.show_all_collections(max_details=True)[collection_name]
            self._client.close()
            return collection
        else:
            print(f'Collection "{collection_name}" not found on host')

    def show_collection_properties(self, collection_name: str) -> Any:
        '''
        Shows all properties of a collection (index) on the Weaviate instance.
        Returns None (with a printed notice) if the collection does not exist.
        '''
        self._connect()
        if self._client.collections.exists(collection_name):
            collection = self.show_all_collections(max_details=True)[collection_name]
            self._client.close()
            return collection.properties
        else:
            print(f'Collection "{collection_name}" not found on host')

    def delete_collection(self, collection_name: str) -> None:
        '''
        Deletes a collection (index) on the Weaviate instance, if it exists.
        '''
        self._connect()
        if self._client.collections.exists(collection_name):
            try:
                self._client.collections.delete(collection_name)
                self._client.close()
                print(f'Collection "{collection_name}" deleted')
            except Exception as e:
                print(f'Error deleting collection, due to: {e}')
        else:
            print(f'Collection "{collection_name}" not found on host')

    def get_doc_count(self, collection_name: str) -> int | None:
        '''
        Returns the number of documents in a collection, or None (with a
        printed notice) if the collection does not exist.
        '''
        self._connect()
        if self._client.collections.exists(collection_name):
            collection = self._client.collections.get(collection_name)
            aggregate = collection.aggregate.over_all()
            total_count = aggregate.total_count
            # close connection for consistency with the other read methods
            self._client.close()
            print(f'Found {total_count} documents in collection "{collection_name}"')
            return total_count
        else:
            print(f'Collection "{collection_name}" not found on host')

    def format_response(self,
                        response: QueryReturn,
                        ) -> list[dict]:
        '''
        Formats json response from Weaviate into a list of dictionaries.
        Expands _additional fields if present into top-level dictionary.
        '''
        results = [{**o.properties, **self._get_meta(o.metadata)} for o in response.objects]
        return results

    def _get_meta(self, metadata: MetadataReturn) -> dict:
        '''
        Extracts metadata from MetadataQuery object if meta exists.
        Only truthy metadata fields (e.g. score, distance) are kept.
        '''
        temp_dict = metadata.__dict__
        return {k:v for k,v in temp_dict.items() if v}

    def keyword_search(self,
                       request: str,
                       collection_name: str,
                       query_properties: list[str]=None,
                       limit: int=10,
                       filter: Filter=None,
                       return_properties: list[str]=None,
                       return_raw: bool=False
                       ) -> dict | list[dict]:
        '''
        Executes Keyword (BM25) search.

        Args
        ----
        request: str
            User query.
        collection_name: str
            Collection (index) to search.
        query_properties: list[str]=None
            list of properties to search across. Defaults to ['content'].
        limit: int=10
            Number of results to return.
        filter: Filter=None
            Property filter to apply to search results.
        return_properties: list[str]=None
            list of properties to return in response.
            If None, returns self.return_properties.
        return_raw: bool=False
            If True, returns raw response from Weaviate.
        '''
        self._connect()
        # avoid mutable default argument; None means "search content only"
        if query_properties is None:
            query_properties = ['content']
        return_properties = return_properties if return_properties else self.return_properties
        collection = self._client.collections.get(collection_name)
        response = collection.query.bm25(query=request,
                                         query_properties=query_properties,
                                         limit=limit,
                                         filters=filter,
                                         return_metadata=MetadataQuery(score=True),
                                         return_properties=return_properties)

        if return_raw:
            return response
        else:
            return self.format_response(response)

    def vector_search(self,
                      request: str,
                      collection_name: str,
                      limit: int=10,
                      return_properties: list[str]=None,
                      filter: Filter=None,
                      return_raw: bool=False,
                      device: str='cuda:0' if cuda.is_available() else 'cpu'
                      ) -> dict | list[dict]:
        '''
        Executes vector search using embedding model defined on instantiation
        of WeaviateClient instance.

        Args
        ----
        request: str
            User query.
        collection_name: str
            Collection (index) to search.
        limit: int=10
            Number of results to return.
        return_properties: list[str]=None
            list of properties to return in response.
            If None, returns self.return_properties.
        filter: Filter=None
            Property filter to apply to search results.
        return_raw: bool=False
            If True, returns raw response from Weaviate.
        device: str
            Device to use for encoding query.
        '''
        self._connect()
        return_properties = return_properties if return_properties else self.return_properties
        query_vector = self._create_query_vector(request, device=device)
        collection = self._client.collections.get(collection_name)
        response = collection.query.near_vector(near_vector=query_vector,
                                                limit=limit,
                                                filters=filter,
                                                return_metadata=MetadataQuery(distance=True),
                                                return_properties=return_properties)

        if return_raw:
            return response
        else:
            return self.format_response(response)

    def _create_query_vector(self, query: str, device: str) -> list[float]:
        '''
        Creates embedding vector from text query.
        Dispatches to OpenAI when the instance was configured with
        text-embedding-ada-002, otherwise encodes locally.
        '''
        return self.get_openai_embedding(query) if self._openai_model else self.model.encode(query, device=device).tolist()

    def get_openai_embedding(self, query: str) -> list[float]:
        '''
        Gets embedding from OpenAI API for query.

        Raises
        ------
        ValueError
            If the API returns an empty response.
        '''
        embedding = self.model.embeddings.create(input=query, model='text-embedding-ada-002').model_dump()
        if embedding:
            return embedding['data'][0]['embedding']
        else:
            raise ValueError(f'No embedding found for query: {query}')

    def hybrid_search(self,
                      request: str,
                      collection_name: str,
                      query_properties: list[str]=None,
                      alpha: float=0.5,
                      limit: int=10,
                      filter: Filter=None,
                      return_properties: list[str]=None,
                      return_raw: bool=False,
                      device: str='cuda:0' if cuda.is_available() else 'cpu'
                      ) -> dict | list[dict]:
        '''
        Executes Hybrid (Keyword + Vector) search.

        Args
        ----
        request: str
            User query.
        collection_name: str
            Collection (index) to search.
        query_properties: list[str]=None
            list of properties to search across (using BM25).
            Defaults to ['content'].
        alpha: float=0.5
            Weighting factor for BM25 and Vector search.
            alpha can be any number from 0 to 1, defaulting to 0.5:
            alpha = 0 executes a pure keyword search method (BM25)
            alpha = 0.5 weighs the BM25 and vector methods evenly
            alpha = 1 executes a pure vector search method
        limit: int=10
            Number of results to return.
        filter: Filter=None
            Property filter to apply to search results.
        return_properties: list[str]=None
            list of properties to return in response.
            If None, returns self.return_properties.
        return_raw: bool=False
            If True, returns raw response from Weaviate.
        device: str
            Device to use for encoding query.
        '''
        self._connect()
        # avoid mutable default argument; None means "search content only"
        if query_properties is None:
            query_properties = ['content']
        return_properties = return_properties if return_properties else self.return_properties
        query_vector = self._create_query_vector(request, device=device)
        collection = self._client.collections.get(collection_name)
        response = collection.query.hybrid(query=request,
                                           query_properties=query_properties,
                                           filters=filter,
                                           vector=query_vector,
                                           alpha=alpha,
                                           limit=limit,
                                           return_metadata=MetadataQuery(score=True, distance=True),
                                           return_properties=return_properties)
        if return_raw:
            return response
        else:
            return self.format_response(response)
|
|
|
|
|
class WeaviateIndexer:

    def __init__(self,
                 client: WeaviateWCS
                 ):
        '''
        Class designed to batch index documents into Weaviate. Instantiating
        this class will automatically configure the Weaviate batch client.
        '''
        # reuse the underlying weaviate client owned by the WeaviateWCS wrapper
        self._client = client._client

    def _connect(self):
        '''
        Connects to Weaviate instance (no-op if the client is already connected).
        '''
        if not self._client.is_connected():
            self._client.connect()

    def create_collection(self,
                          collection_name: str,
                          properties: list[Property],
                          description: str=None,
                          **kwargs
                          ) -> None:
        '''
        Creates a collection (index) on the Weaviate instance.

        Args
        ----
        collection_name: str
            Name of the collection to create. Must not contain hyphens.
        properties: list[Property]
            List of properties to add to data objects in the collection.
        description: str=None
            User-defined description of the collection.

        Raises
        ------
        ValueError
            If the collection name contains a hyphen.
        '''
        # hyphens are rejected up front rather than failing server-side
        if '-' in collection_name:
            raise ValueError('Collection name cannot contain hyphens')
        try:
            self._connect()
            self._client.collections.create(name=collection_name,
                                            description=description,
                                            properties=properties,
                                            **kwargs
                                            )
            if self._client.collections.exists(collection_name):
                print(f'Collection "{collection_name}" created')
            else:
                print('Collection not found at the moment, try again later')
            self._client.close()
        except Exception as e:
            print(f'Error creating collection, due to: {e}')

    def batch_index_data(self,
                         data: list[dict],
                         collection_name: str,
                         error_threshold: float=0.01,
                         vector_property: str='content_embedding',
                         unique_id_field: str='doc_id',
                         properties: list[Property]=None,
                         collection_description: str=None,
                         **kwargs
                         ) -> dict:
        '''
        Batch function for fast indexing of data onto Weaviate cluster.

        Args
        ----
        data: list[dict]
            List of dictionaries where each dictionary represents a document.
        collection_name: str
            Name of the collection to index data into.
        error_threshold: float=0.01
            Threshold for error rate during batch upload. This value is a percentage of the total data
            that the end user is willing to tolerate as errors. If the error rate exceeds this threshold,
            the batch job will be aborted.
        vector_property: str='content_embedding'
            Name of the property that contains the vector representation of the document.
        unique_id_field: str='doc_id'
            Name of the unique identifier field in the document.
        properties: list[Property]=None
            List of properties to create the collection with. Required if collection does not exist.
        collection_description: str=None
            Description of the collection. Optional parameter.

        Returns
        -------
        dict
            Dictionary containing error information if any with the following keys:
            ['num_errors', 'error_messages', 'doc_ids']

        Raises
        ------
        ValueError
            If the collection does not exist and no properties were provided
            to create it.
        '''
        self._connect()
        if not self._client.collections.exists(collection_name):
            print(f'Collection "{collection_name}" not found on host, creating Collection first...')
            if properties is None:
                raise ValueError(f'Tried to create Collection <{collection_name}> but no properties were provided.')
            self.create_collection(collection_name=collection_name,
                                   properties=properties,
                                   description=collection_description,
                                   **kwargs)
            self._client.close()

        # reconnect because create_collection closes the client on success
        self._connect()
        # absolute number of tolerated errors derived from the percentage threshold
        error_threshold_size = int(len(data) * error_threshold)
        collection = self._client.collections.get(collection_name)

        start = time.perf_counter()
        completed_job = True

        with collection.batch.dynamic() as batch:
            for doc in tqdm(data):
                # the embedding is passed as the object vector, not as a property
                batch.add_object(properties={k:v for k,v in doc.items() if k != vector_property},
                                 vector=doc[vector_property])
                if batch.number_errors > error_threshold_size:
                    print('Upload errors exceed error_threshold...')
                    completed_job = False
                    break
        end = time.perf_counter() - start
        print(f'Processing finished in {round(end/60, 2)} minutes.')

        failed_objects = collection.batch.failed_objects
        # plain truthiness is the correct emptiness test here (any() would
        # also skip a non-empty list whose objects happen to be falsy)
        if failed_objects:
            error_messages = [obj.message for obj in failed_objects]
            doc_ids = [obj.object_.properties.get(unique_id_field, 'Not Found') for obj in failed_objects]
        else:
            error_messages, doc_ids = [], []
        error_object = {'num_errors':batch.number_errors,
                        'error_messages': error_messages,
                        'doc_ids': doc_ids}
        if not completed_job:
            print(f'Batch job failed. Review errors using these keys: {list(error_object.keys())}')
            return error_object
        if batch.number_errors > 0:
            print(f'Batch job completed with {batch.number_errors} errors. Review errors using these keys: {list(error_object.keys())}')
        else:
            print('Batch job completed with zero errors.')
        return error_object
|
|
|
|
|
@dataclass
class SearchFilter(Filter):

    '''
    Simplified interface for constructing a Filter object.

    NOTE: the field name `property` shadows the `property` builtin; it is
    part of the public interface, so it is kept as-is.

    Args
    ----
    property : str
        Property to filter on.
    query_value : str
        Query value to filter on.
    '''
    property: str
    query_value: str

    def exact_match(self):
        # Builds an equality filter on `property` for `query_value`.
        # NOTE(review): relies on `by_property` inherited from weaviate's
        # Filter class — presumably returns a chainable filter builder.
        return self.by_property(self.property).equal(self.query_value)

    def fuzzy_match(self):
        # Builds a wildcard LIKE filter; the query value is wrapped in `*`
        # so it matches anywhere within the property's value.
        return self.by_property(self.property).like(f'*{self.query_value}*')
|
|
|
|
|
|
|
def get_weaviate_client(endpoint: str=os.getenv('FINRAG_WEAVIATE_ENDPOINT'),
                        api_key: str=os.getenv('FINRAG_WEAVIATE_API_KEY'),
                        model_name_or_path: str='sentence-transformers/all-MiniLM-L6-v2',
                        embedded: bool=False,
                        openai_api_key: str=None,
                        skip_init_checks: bool=False,
                        **kwargs
                        ) -> WeaviateWCS:
    '''
    Convenience factory for a ready-to-use WeaviateWCS client.

    The endpoint and api key default to the FINRAG_WEAVIATE_ENDPOINT and
    FINRAG_WEAVIATE_API_KEY environment variables (read once, at function
    definition time). All arguments are forwarded to WeaviateWCS unchanged.
    '''
    return WeaviateWCS(endpoint=endpoint,
                       api_key=api_key,
                       model_name_or_path=model_name_or_path,
                       embedded=embedded,
                       openai_api_key=openai_api_key,
                       skip_init_checks=skip_init_checks,
                       **kwargs)