import os
import os.path

import torch
from datasets import DatasetDict
from transformers import (
    DPRContextEncoder,
    DPRContextEncoderTokenizer,
    DPRQuestionEncoder,
    DPRQuestionEncoderTokenizer,
)

from src.retrievers.base_retriever import Retriever
from src.utils.log import get_logger
from src.utils.preprocessing import remove_formulas

# Hacky fix for FAISS error on macOS
# See https://stackoverflow.com/a/63374568/4545692
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"


logger = get_logger()


class FaissRetriever(Retriever):
    """A class used to retrieve relevant documents based on some query.
    based on https://huggingface.co/docs/datasets/faiss_es#faiss.
    """

    def __init__(
        self,
        dataset: DatasetDict,
        embedding_path: str = "./src/models/paragraphs_embedding.faiss",
    ) -> None:
        # The encoders are used for inference only, so disable gradients.
        torch.set_grad_enabled(False)

        # Context encoding and tokenization
        self.ctx_encoder = DPRContextEncoder.from_pretrained(
            "facebook/dpr-ctx_encoder-single-nq-base"
        )
        self.ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained(
            "facebook/dpr-ctx_encoder-single-nq-base"
        )

        # Question encoding and tokenization
        self.q_encoder = DPRQuestionEncoder.from_pretrained(
            "facebook/dpr-question_encoder-single-nq-base"
        )
        self.q_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
            "facebook/dpr-question_encoder-single-nq-base"
        )

        self.dataset = dataset
        self.embedding_path = embedding_path

        self.index = self._init_index()

    def _init_index(self, force_new_embedding: bool = False):
        """Load the FAISS index from disk if present, else build and save it."""
        ds = self.dataset["train"]
        ds = ds.map(remove_formulas)

        if not force_new_embedding and os.path.exists(self.embedding_path):
            ds.load_faiss_index(
                'embeddings', self.embedding_path)  # type: ignore
            return ds
        else:
            def embed(row):
                # Embed one paragraph with the DPR context encoder;
                # [0][0] takes the pooled output of the single batch item.
                p = row["text"]
                tok = self.ctx_tokenizer(
                    p, return_tensors="pt", truncation=True)
                enc = self.ctx_encoder(**tok)[0][0].numpy()
                return {"embeddings": enc}

            # Compute an embedding for every row, then build a FAISS index
            # over the new "embeddings" column.
            index = ds.map(embed)  # type: ignore
            index.add_faiss_index(column="embeddings")

            # Save the index so later runs can load it from disk instead
            # of re-embedding the corpus.
            os.makedirs(
                os.path.dirname(self.embedding_path) or ".", exist_ok=True)
            index.save_faiss_index(
                "embeddings", self.embedding_path)

            return index

    def retrieve(self, query: str, k: int = 50):
        """Return the scores and the k nearest examples for the given query."""
        def embed(q):
            # Embed the query with the DPR question encoder.
            tok = self.q_tokenizer(q, return_tensors="pt", truncation=True)
            return self.q_encoder(**tok)[0][0].numpy()

        question_embedding = embed(query)
        scores, results = self.index.get_nearest_examples(
            "embeddings", question_embedding, k=k
        )

        return scores, results
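

# Minimal usage sketch (illustrative only, not part of the module). It
# assumes a DatasetDict with a "train" split whose rows have a "text"
# column; the dataset identifier below is a hypothetical placeholder.
#
#     from datasets import load_dataset
#
#     dataset = load_dataset("user/paragraph-corpus")  # hypothetical id
#     retriever = FaissRetriever(dataset)
#     scores, results = retriever.retrieve("What is dense retrieval?", k=5)
#     for score, text in zip(scores, results["text"]):
#         print(f"{score:.2f}  {text[:80]}")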