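"""Add retrieved entity candidates to a windowed dataset.

This script loads a GoldenRetriever question encoder and a document index,
retrieves the top-k (k=100) candidates for every window in the input JSONL
file, attaches the candidate titles and scores to each sample, and prints
the recall of the candidates against the gold window labels.
"""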
import argparse
import json
import logging
import os
import time
from pathlib import Path
from typing import Union

import torch
import tqdm

from relik.common.log import get_logger
from relik.retriever import GoldenRetriever
from relik.retriever.common.model_inputs import ModelInputs
from relik.retriever.data.base.datasets import BaseDataset
from relik.retriever.indexers.base import BaseDocumentIndex
from relik.retriever.indexers.faiss import FaissDocumentIndex

logger = get_logger(level=logging.INFO)


def compute_retriever_stats(dataset) -> None:
    """Print the recall of the retrieved candidates against the gold window labels."""
    correct, total = 0, 0
    for sample in dataset:
        window_candidates = sample["window_candidates"]
        window_candidates = [c.replace("_", " ").lower() for c in window_candidates]
        for ss, se, label in sample["window_labels"]:
            # --NME-- marks mentions with no gold entity; skip them
            if label == "--NME--":
                continue
            if label.replace("_", " ").lower() in window_candidates:
                correct += 1
            total += 1
    recall = correct / total if total > 0 else 0.0
    print("Recall:", recall)


@torch.no_grad()
def add_candidates(
    retriever_name_or_path: Union[str, os.PathLike],
    document_index_name_or_path: Union[str, os.PathLike],
    input_path: Union[str, os.PathLike],
    batch_size: int = 128,
    num_workers: int = 4,
    index_type: str = "Flat",
    nprobe: int = 1,
    device: str = "cpu",
    precision: str = "fp32",
    topics: bool = False,
):
    # load the document index; the commented-out config_kwargs show how to
    # switch to a FAISS-backed index (index_type/nprobe are only used there)
    document_index = BaseDocumentIndex.from_pretrained(
        document_index_name_or_path,
        # config_kwargs={
        #     "_target_": "relik.retriever.indexers.faiss.FaissDocumentIndex",
        #     "index_type": index_type,
        #     "nprobe": nprobe,
        # },
        device=device,
        precision=precision,
    )

    retriever = GoldenRetriever(
        question_encoder=retriever_name_or_path,
        document_index=document_index,
        device=device,
        precision=precision,
        index_device=device,
        index_precision=precision,
    )
    retriever.eval()

    logger.info(f"Loading from {input_path}")
    with open(input_path) as f:
        samples = [json.loads(line) for line in f]

    # use the document topic as a text pair only if requested and available
    topics = topics and "doc_topic" in samples[0]

    # tokenize the window text (and optionally its topic) into model inputs
    tokenizer = retriever.question_tokenizer
    collate_fn = lambda batch: ModelInputs(
        tokenizer(
            [b["text"] for b in batch],
            text_pair=[b["doc_topic"] for b in batch] if topics else None,
            padding=True,
            return_tensors="pt",
            truncation=True,
        )
    )
    logger.info(f"Creating dataloader with batch size {batch_size}")
    dataloader = torch.utils.data.DataLoader(
        BaseDataset(name="passage", data=samples),
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=False,
        collate_fn=collate_fn,
    )

    # candidates could also be dumped to a file periodically; the write is
    # currently disabled below
    retrieved_accumulator = []
    with torch.inference_mode():
        start = time.time()
        num_completed_docs = 0

        for documents_batch in tqdm.tqdm(dataloader):
            retrieve_kwargs = {
                **documents_batch,
                "k": 100,
                "precision": precision,
            }
            batch_out = retriever.retrieve(**retrieve_kwargs)
            retrieved_accumulator.extend(batch_out)

        end = time.time()

        output_data = []
        # match retrieved candidates back to the original samples: the
        # dataloader is not shuffled, so we can just count the number of
        # documents we have seen so far
        for sample, retrieved in zip(
            samples[
                num_completed_docs : num_completed_docs + len(retrieved_accumulator)
            ],
            retrieved_accumulator,
        ):
            # candidate labels look like "Title <def> definition"; keep the title
            candidate_titles = [c.label.split(" <def>", 1)[0] for c in retrieved]
            sample["window_candidates"] = candidate_titles
            sample["window_candidates_scores"] = [c.score for c in retrieved]
            output_data.append(sample)

        # for sample in output_data:
        #     f_out.write(json.dumps(sample) + "\n")

        num_completed_docs += len(retrieved_accumulator)
        retrieved_accumulator = []

    compute_retriever_stats(output_data)
    print(f"Retrieval took {end - start:.2f} seconds")


if __name__ == "__main__":
    # CLI version, currently disabled. Note that --output_path and
    # --index_device have no matching parameters in add_candidates, and
    # topics is not exposed, so the block would need adjusting before
    # being re-enabled.
    # arg_parser = argparse.ArgumentParser()
    # arg_parser.add_argument("--retriever_name_or_path", type=str, required=True)
    # arg_parser.add_argument("--document_index_name_or_path", type=str, required=True)
    # arg_parser.add_argument("--input_path", type=str, required=True)
    # arg_parser.add_argument("--output_path", type=str, required=True)
    # arg_parser.add_argument("--batch_size", type=int, default=128)
    # arg_parser.add_argument("--device", type=str, default="cuda")
    # arg_parser.add_argument("--index_device", type=str, default="cpu")
    # arg_parser.add_argument("--precision", type=str, default="fp32")
    # add_candidates(**vars(arg_parser.parse_args()))

    add_candidates(
        "/root/relik-spaces/models/relik-retriever-small-aida-blink-pretrain-omniencoder/question_encoder",
        "/root/relik-spaces/models/relik-retriever-small-aida-blink-pretrain-omniencoder/document_index_filtered",
        "/root/relik-spaces/data/reader/aida/testa_windowed.jsonl",
        # index_type="HNSW32",
        # index_type="IVF1024,PQ8",
        # nprobe=1,
        topics=True,
        device="cuda",
    )
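
# A sketch of the equivalent command-line invocation, assuming the argparse
# block above is re-enabled and fixed (the script file name is an assumption):
#
#   python add_candidates.py \
#       --retriever_name_or_path /root/relik-spaces/models/relik-retriever-small-aida-blink-pretrain-omniencoder/question_encoder \
#       --document_index_name_or_path /root/relik-spaces/models/relik-retriever-small-aida-blink-pretrain-omniencoder/document_index_filtered \
#       --input_path /root/relik-spaces/data/reader/aida/testa_windowed.jsonl \
#       --device cuda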