Spaces:
Running
Running
import os | |
from app.service.api import baseURL | |
from qdrant_client import QdrantClient, models | |
# Qdrant credential taken from the QDRANT_API_KEY environment variable;
# the value is None when the variable is not set.
api_key = os.getenv("QDRANT_API_KEY")
class QdrantConnectionDb:
    """
    Singleton holder for the application's QdrantClient.

    The first instantiation creates the client, makes sure the configured
    collection exists, and registers the dense and sparse embedding models on
    the client. Every later instantiation returns the same instance without
    repeating that setup. The client is stored on the class, so it is shared
    by every reference to the singleton.
    """

    # Shared QdrantClient; stays None until the first successful instantiation.
    client = None
    # The single instance handed out by __new__.
    _instance = None
    # Name of the collection that is created/used by this connection.
    _collection_name = "docuRAG"
    # Size passed to VectorParams for the "text-dense" vector.
    _vector_size = 384
    # Model names handed to client.set_model / client.set_sparse_model.
    dense_model = "sentence-transformers/all-MiniLM-L6-v2"
    sparse_model = "prithivida/Splade_PP_en_v1"

    def __new__(cls, *args, **kwargs):
        """
        Return the singleton, building and configuring it on first use.

        Positional and keyword arguments are accepted but ignored.
        """
        if cls._instance is None:
            instance = super().__new__(cls)
            cls.client = QdrantClient(url=baseURL, api_key=api_key)
            cls._initialize_collection(
                cls.client,
                cls._collection_name,
                cls._vector_size,
            )
            cls._set_models(cls.dense_model, cls.sparse_model)
            # Publish the instance only after setup has finished, so a failure
            # above is retried on the next call instead of leaving a
            # half-configured singleton behind.
            cls._instance = instance
        return cls._instance

    @classmethod
    def _initialize_collection(
        cls, client: "QdrantClient", collection_name: str, _vector_size: int
    ):
        """
        Initialize collection if it does not exist

        The collection is created with one named dense vector ("text-dense",
        cosine distance) and one named sparse vector ("text-sparse", index not
        stored on disk). Errors are reported on stdout and not re-raised, so
        client construction continues even if this step fails.

        :param client: QdrantClient used to list and create collections
        :param collection_name: name of the collection to look up / create
        :param _vector_size: size of the "text-dense" vector
        :return: None
        """
        try:
            existing_names = {c.name for c in client.get_collections().collections}
            if collection_name not in existing_names:
                client.create_collection(
                    collection_name=collection_name,
                    vectors_config={
                        "text-dense": models.VectorParams(
                            size=_vector_size,
                            distance=models.Distance.COSINE,
                        )
                    },
                    sparse_vectors_config={
                        "text-sparse": models.SparseVectorParams(
                            index=models.SparseIndexParams(
                                on_disk=False,
                            )
                        )
                    },
                )
            print(f"Collection {collection_name} initialized successfully")
        except Exception as e:
            # Deliberately best-effort: report the problem and carry on.
            print(f"Error while initializing collection: {e}")

    def get_client(self) -> "QdrantClient":
        """
        Get the QdrantClient instance
        """
        return self.client

    @classmethod
    def _set_models(cls, model_name: str, sparse_model_name: str):
        """
        Set the model and sparse model for the client

        :param model_name: name passed to client.set_model
        :param sparse_model_name: name passed to client.set_sparse_model
        """
        cls.client.set_model(model_name)
        cls.client.set_sparse_model(sparse_model_name)

    @classmethod
    def get_collection_name(cls) -> str:
        """
        Get the current collection name
        """
        return cls._collection_name