DocuRAG / Api /app /modules /denseEmbeddings /denseEmbeddings.py
abadesalex's picture
Update to Qdrant db
47b5f0c
raw
history blame
1.99 kB
import torch
from qdrant_client import models
from qdrant_client.models import NamedVector
from transformers import AutoModel, AutoModelForMaskedLM, AutoTokenizer
class DenseEmbeddings:
    """
    Build dense and sparse vectors for a piece of text.

    Despite the class name, two model/tokenizer pairs are held: a dense
    encoder used by ``get_dense_vector`` and a masked-language model used by
    ``get_sparse_vector``.
    """

    def __init__(
        self,
        dense_model: AutoModel,
        dense_tokenizer: AutoTokenizer,
        sparse_model: AutoModelForMaskedLM,
        sparse_tokenizer: AutoTokenizer,
    ):
        """
        Store the models and tokenizers used to build vectors.

        :param dense_model: encoder whose ``last_hidden_state`` is mean-pooled
        :param dense_tokenizer: tokenizer matching ``dense_model``
        :param sparse_model: masked-LM whose ``logits`` produce sparse weights
        :param sparse_tokenizer: tokenizer matching ``sparse_model``
        """
        self.dense_model = dense_model
        self.dense_tokenizer = dense_tokenizer
        self.sparse_model = sparse_model
        self.sparse_tokenizer = sparse_tokenizer

    def get_dense_vector(self, text: str) -> NamedVector:
        """
        Get dense vector from the dense model.

        The vector is the mean of the model's last hidden states over the
        token positions of ``text`` and is returned under the name
        ``"text-dense"``.

        :param text: str
        :return: NamedVector
        """
        inputs = self.dense_tokenizer(
            text, return_tensors="pt", padding=True, truncation=True
        )
        with torch.no_grad():
            outputs = self.dense_model(**inputs)
        # A single string is tokenized, so no padding positions enter the mean.
        pooled = torch.mean(outputs.last_hidden_state, dim=1)[0]
        # NamedVector.vector expects a list of Python floats; a numpy float32
        # array is not accepted, and .tolist() also works for GPU tensors.
        return NamedVector(name="text-dense", vector=pooled.tolist())

    def get_sparse_vector(self, text: str) -> models.SparseVector:
        """
        Get sparse vector from the sparse (masked-LM) model.

        Uses SPLADE-style weighting: each vocabulary id gets the maximum, over
        the non-padding token positions, of ``log(1 + ReLU(logit))``. Only
        vocabulary ids with a positive weight are kept.

        :param text: str
        :return: SparseVector whose indices are vocabulary ids and whose
            values are their weights
        """
        inputs = self.sparse_tokenizer(
            text, return_tensors="pt", padding=True, truncation=True
        )
        with torch.no_grad():
            outputs = self.sparse_model(**inputs)
        # log1p(relu(.)) makes every weight non-negative and damps large logits.
        weights = torch.log1p(torch.relu(outputs.logits))
        attention_mask = inputs.get("attention_mask")
        if attention_mask is not None:
            # Zero out padding positions so they cannot win the max-pool.
            weights = weights * attention_mask.unsqueeze(-1).to(weights.dtype)
        # Max-pool over the sequence axis -> one weight per vocabulary id.
        # (Scores are indexed by vocabulary id, NOT by input position.)
        vocab_weights = weights.max(dim=1).values[0]
        indices = (vocab_weights > 0).nonzero(as_tuple=True)[0]
        return models.SparseVector(
            indices=indices.tolist(),
            values=vocab_weights[indices].tolist(),
        )