Spaces:
Sleeping
Sleeping
import argparse | |
import json | |
import logging | |
import os | |
from pathlib import Path | |
import time | |
from typing import Union | |
import torch | |
import tqdm | |
from relik.retriever import GoldenRetriever | |
from relik.common.log import get_logger | |
from relik.retriever.common.model_inputs import ModelInputs | |
from relik.retriever.data.base.datasets import BaseDataset | |
from relik.retriever.indexers.base import BaseDocumentIndex | |
from relik.retriever.indexers.faiss import FaissDocumentIndex | |
logger = get_logger(level=logging.INFO) | |
def compute_retriever_stats(dataset) -> None: | |
correct, total = 0, 0 | |
for sample in dataset: | |
window_candidates = sample["window_candidates"] | |
window_candidates = [c.replace("_", " ").lower() for c in window_candidates] | |
for ss, se, label in sample["window_labels"]: | |
if label == "--NME--": | |
continue | |
if label.replace("_", " ").lower() in window_candidates: | |
correct += 1 | |
total += 1 | |
recall = correct / total | |
print("Recall:", recall) | |
def add_candidates( | |
retriever_name_or_path: Union[str, os.PathLike], | |
document_index_name_or_path: Union[str, os.PathLike], | |
input_path: Union[str, os.PathLike], | |
batch_size: int = 128, | |
num_workers: int = 4, | |
index_type: str = "Flat", | |
nprobe: int = 1, | |
device: str = "cpu", | |
precision: str = "fp32", | |
topics: bool = False, | |
): | |
document_index = BaseDocumentIndex.from_pretrained( | |
document_index_name_or_path, | |
# config_kwargs={ | |
# "_target_": "relik.retriever.indexers.faiss.FaissDocumentIndex", | |
# "index_type": index_type, | |
# "nprobe": nprobe, | |
# }, | |
device=device, | |
precision=precision, | |
) | |
retriever = GoldenRetriever( | |
question_encoder=retriever_name_or_path, | |
document_index=document_index, | |
device=device, | |
precision=precision, | |
index_device=device, | |
index_precision=precision, | |
) | |
retriever.eval() | |
logger.info(f"Loading from {input_path}") | |
with open(input_path) as f: | |
samples = [json.loads(line) for line in f.readlines()] | |
topics = topics and "doc_topic" in samples[0] | |
# get tokenizer | |
tokenizer = retriever.question_tokenizer | |
collate_fn = lambda batch: ModelInputs( | |
tokenizer( | |
[b["text"] for b in batch], | |
text_pair=[b["doc_topic"] for b in batch] if topics else None, | |
padding=True, | |
return_tensors="pt", | |
truncation=True, | |
) | |
) | |
logger.info(f"Creating dataloader with batch size {batch_size}") | |
dataloader = torch.utils.data.DataLoader( | |
BaseDataset(name="passage", data=samples), | |
batch_size=batch_size, | |
shuffle=False, | |
num_workers=num_workers, | |
pin_memory=False, | |
collate_fn=collate_fn, | |
) | |
# we also dump the candidates to a file after a while | |
retrieved_accumulator = [] | |
with torch.inference_mode(): | |
start = time.time() | |
num_completed_docs = 0 | |
for documents_batch in tqdm.tqdm(dataloader): | |
retrieve_kwargs = { | |
**documents_batch, | |
"k": 100, | |
"precision": precision, | |
} | |
batch_out = retriever.retrieve(**retrieve_kwargs) | |
retrieved_accumulator.extend(batch_out) | |
end = time.time() | |
output_data = [] | |
# get the correct document from the original dataset | |
# the dataloader is not shuffled, so we can just count the number of | |
# documents we have seen so far | |
for sample, retrieved in zip( | |
samples[ | |
num_completed_docs : num_completed_docs + len(retrieved_accumulator) | |
], | |
retrieved_accumulator, | |
): | |
candidate_titles = [c.label.split(" <def>", 1)[0] for c in retrieved] | |
sample["window_candidates"] = candidate_titles | |
sample["window_candidates_scores"] = [c.score for c in retrieved] | |
output_data.append(sample) | |
# for sample in output_data: | |
# f_out.write(json.dumps(sample) + "\n") | |
num_completed_docs += len(retrieved_accumulator) | |
retrieved_accumulator = [] | |
compute_retriever_stats(output_data) | |
print(f"Retrieval took {end - start:.2f} seconds") | |
if __name__ == "__main__": | |
# arg_parser = argparse.ArgumentParser() | |
# arg_parser.add_argument("--retriever_name_or_path", type=str, required=True) | |
# arg_parser.add_argument("--document_index_name_or_path", type=str, required=True) | |
# arg_parser.add_argument("--input_path", type=str, required=True) | |
# arg_parser.add_argument("--output_path", type=str, required=True) | |
# arg_parser.add_argument("--batch_size", type=int, default=128) | |
# arg_parser.add_argument("--device", type=str, default="cuda") | |
# arg_parser.add_argument("--index_device", type=str, default="cpu") | |
# arg_parser.add_argument("--precision", type=str, default="fp32") | |
# add_candidates(**vars(arg_parser.parse_args())) | |
add_candidates( | |
"/root/relik-spaces/models/relik-retriever-small-aida-blink-pretrain-omniencoder/question_encoder", | |
"/root/relik-spaces/models/relik-retriever-small-aida-blink-pretrain-omniencoder/document_index_filtered", | |
"/root/relik-spaces/data/reader/aida/testa_windowed.jsonl", | |
# index_type="HNSW32", | |
# index_type="IVF1024,PQ8", | |
# nprobe=1, | |
topics=True, | |
device="cuda", | |
) | |