from typing import (
    Any,
    Generator,
    Iterable,
    List,
    Optional,
    Tuple,
    cast,
)

import torch
from langchain_community.retrievers import QdrantSparseVectorRetriever
from langchain_community.vectorstores.qdrant import Qdrant
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.pydantic_v1 import Field
from langchain.schema import Document

try:
    from qdrant_client import QdrantClient, models
except ImportError:
    raise ImportError(
        "qdrant-client is required for this retriever. "
        "Install it with `pip install qdrant-client`."
    )


def batchify(_list: List, batch_size: int) -> Generator[List, None, None]:
    """Yield successive slices of ``_list`` with at most ``batch_size`` items each."""
    for i in range(0, len(_list), batch_size):
        yield _list[i : i + batch_size]


class MyQdrantSparseVectorRetriever(QdrantSparseVectorRetriever):
    """Sparse-vector retriever that encodes documents and queries with separate
    SPLADE document/query models instead of a ``sparse_encoder`` callback."""

    splade_doc_tokenizer: Any = Field(repr=False)
    splade_doc_model: Any = Field(repr=False)
    splade_query_tokenizer: Any = Field(repr=False)
    splade_query_model: Any = Field(repr=False)
    device: Any = Field(repr=False)
    batch_size: int = Field(repr=False)
    sparse_encoder: Optional[Any] = Field(default=None, repr=False)

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    def compute_document_vectors(
        self, texts: List[str], batch_size: int
    ) -> Tuple[List[List[int]], List[List[float]]]:
        """Encode ``texts`` with the SPLADE document model, returning the
        non-zero indices and values of each text's sparse vector."""
        indices = []
        values = []
        for text_batch in batchify(texts, batch_size):
            with torch.no_grad():
                tokens = self.splade_doc_tokenizer(
                    text_batch, truncation=True, padding=True, return_tensors="pt"
                ).to(self.device)
                output = self.splade_doc_model(**tokens)
                logits, attention_mask = output.logits, tokens.attention_mask
                # SPLADE pooling: log(1 + ReLU(logits)), zeroed at padding
                # positions, then max-pooled over the sequence dimension.
                relu_log = torch.log(1 + torch.relu(logits))
                weighted_log = relu_log * attention_mask.unsqueeze(-1)
                tvecs, _ = torch.max(weighted_log, dim=1)

            # Keep only the non-zero entries of each document vector.
            for vec in tvecs.cpu():
                nonzero = vec.nonzero(as_tuple=True)[0]
                indices.append(nonzero.numpy())
                values.append(vec[nonzero].numpy())

        return indices, values

    def compute_query_vector(self, text: str):
        """Encode ``text`` with the SPLADE query model, returning the non-zero
        indices and values of the resulting sparse query vector."""
        with torch.no_grad():
            tokens = self.splade_query_tokenizer(text, return_tensors="pt").to(
                self.device
            )
            output = self.splade_query_model(**tokens)
            logits, attention_mask = output.logits, tokens.attention_mask
            # Same pooling as for documents: log(1 + ReLU), masked, max-pooled.
            relu_log = torch.log(1 + torch.relu(logits))
            weighted_log = relu_log * attention_mask.unsqueeze(-1)
            max_val, _ = torch.max(weighted_log, dim=1)
            query_vec = max_val.squeeze().cpu()

        query_indices = query_vec.nonzero().numpy().flatten()
        query_values = query_vec.detach().numpy()[query_indices]

        return query_indices, query_values

    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        **kwargs: Any,
    ):
        client = cast(QdrantClient, self.client)

        # Materialize the iterable: it is indexed and measured below.
        texts = list(texts)
        indices, values = self.compute_document_vectors(texts, self.batch_size)

        # Note: sequential ids mean a second call overwrites existing points.
        points = [
            models.PointStruct(
                id=i + 1,
                vector={
                    self.sparse_vector_name: models.SparseVector(
                        indices=indices[i],
                        values=values[i],
                    )
                },
                payload={
                    self.content_payload_key: texts[i],
                    self.metadata_payload_key: metadatas[i] if metadatas else None,
                },
            )
            for i in range(len(texts))
        ]
        client.upsert(self.collection_name, points=points, **kwargs)
        # Release cached GPU memory once the batch has been upserted.
        if self.device == "cuda":
            torch.cuda.empty_cache()

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        client = cast(QdrantClient, self.client)
        query_indices, query_values = self.compute_query_vector(query)

        results = client.search(
            self.collection_name,
            query_filter=self.filter,
            query_vector=models.NamedSparseVector(
                name=self.sparse_vector_name,
                vector=models.SparseVector(
                    indices=query_indices,
                    values=query_values,
                ),
            ),
            limit=self.k,
            with_vectors=False,
            **self.search_options,
        )

        return [
            Qdrant._document_from_scored_point(
                point,
                self.collection_name,
                self.content_payload_key,
                self.metadata_payload_key,
            )
            for point in results
        ]
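

# ---------------------------------------------------------------------------
# Usage sketch. Everything below is illustrative, not part of the retriever:
# the SPLADE checkpoint names, the "demo" collection, and the in-memory Qdrant
# instance are assumptions; substitute the document/query SPLADE models and
# Qdrant deployment you actually use.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoModelForMaskedLM, AutoTokenizer

    device = "cuda" if torch.cuda.is_available() else "cpu"
    doc_name = "naver/efficient-splade-VI-BT-large-doc"      # assumed checkpoint
    query_name = "naver/efficient-splade-VI-BT-large-query"  # assumed checkpoint

    client = QdrantClient(location=":memory:")
    # The collection needs a sparse vector slot matching sparse_vector_name.
    client.create_collection(
        collection_name="demo",
        vectors_config={},
        sparse_vectors_config={"text": models.SparseVectorParams()},
    )

    retriever = MyQdrantSparseVectorRetriever(
        client=client,
        collection_name="demo",
        sparse_vector_name="text",
        splade_doc_tokenizer=AutoTokenizer.from_pretrained(doc_name),
        splade_doc_model=AutoModelForMaskedLM.from_pretrained(doc_name).to(device),
        splade_query_tokenizer=AutoTokenizer.from_pretrained(query_name),
        splade_query_model=AutoModelForMaskedLM.from_pretrained(query_name).to(device),
        device=device,
        batch_size=8,
    )
    retriever.add_texts(
        ["Qdrant supports sparse vectors.", "SPLADE expands text into sparse lexical vectors."]
    )
    print(retriever.invoke("sparse lexical search"))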