File size: 3,104 Bytes
8bbe3aa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
from transformers import DPRContextEncoder, DPRContextEncoderTokenizer, \
                         DPRQuestionEncoder, DPRQuestionEncoderTokenizer
from datasets import load_dataset
import torch


class Retriever:
    """Retrieve relevant documents for a query using DPR encoders and FAISS.

    Based on https://huggingface.co/docs/datasets/faiss_es#faiss.

    The context encoder embeds every row of the dataset once at construction
    time; the question encoder embeds each query at retrieval time. A FAISS
    index over the 'embeddings' column is used for the nearest-neighbour
    lookup.
    """

    def __init__(self, dataset: str = "GroNLP/ik-nlp-22_slp") -> None:
        """Load the DPR encoders/tokenizers and build the indexed dataset.

        Args:
            dataset (str, optional): Name of the HuggingFace dataset to index
            and search. Assumes the information is stored in a column named
            'text'. Defaults to "GroNLP/ik-nlp-22_slp".
        """
        # NOTE: this is a process-wide switch. It is kept for backward
        # compatibility with callers that rely on it, but the embedding code
        # below no longer depends on it (see `_embed`).
        torch.set_grad_enabled(False)

        # Context encoding and tokenization
        self.ctx_encoder = DPRContextEncoder.from_pretrained(
            "facebook/dpr-ctx_encoder-single-nq-base")
        self.ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained(
            "facebook/dpr-ctx_encoder-single-nq-base")

        # Question encoding and tokenization
        self.q_encoder = DPRQuestionEncoder.from_pretrained(
            "facebook/dpr-question_encoder-single-nq-base")
        self.q_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
            "facebook/dpr-question_encoder-single-nq-base")

        # Dataset building
        self.dataset = self.__init_dataset(dataset)

    @staticmethod
    def _embed(tokenizer, encoder, text):
        """Tokenize `text`, run it through `encoder`, return a numpy array.

        Args:
            tokenizer: Callable returning a mapping of encoder keyword
            arguments when called with
            `(text, return_tensors='pt', truncation=True)`.
            encoder: Callable accepting those keyword arguments; the result
            is indexed with `[0][0]`.
            text (str): The text to embed.

        Returns:
            The `[0][0]` element of the encoder output converted with
            `.numpy()` (presumably the embedding of the single input text).
        """
        tokens = tokenizer(text, return_tensors='pt', truncation=True)
        # Run under a local no_grad so `.numpy()` works even when autograd
        # has been re-enabled globally after this object was constructed;
        # `.numpy()` raises on tensors that require grad.
        with torch.no_grad():
            return encoder(**tokens)[0][0].numpy()

    def __init_dataset(self, dataset: str):
        """Load the dataset, embed every row and add a FAISS index.

        Args:
            dataset (str): A HuggingFace dataset name.

        Returns:
            Dataset: The 'train' split of the 'paragraphs' configuration with
            a new column 'embeddings', indexed with FAISS.
        """
        # TODO: save ds w/ embeddings to disk and retrieve it if it already exists

        # Load dataset
        ds = load_dataset(dataset, name='paragraphs')['train']

        def embed(row):
            # Inline helper: embed the 'text' column of a single row.
            return {
                'embeddings': self._embed(
                    self.ctx_tokenizer, self.ctx_encoder, row['text'])
            }

        # Compute one embedding per row.
        ds_with_embeddings = ds.map(embed)

        # TODO: this has been reported to throw a weird error.
        ds_with_embeddings.add_faiss_index(column='embeddings')
        return ds_with_embeddings

    def retrieve(self, query: str, k: int = 5):
        """Retrieve the top k matches for a search query.

        Args:
            query (str): A search query
            k (int, optional): The number of documents to retrieve. Defaults to
            5.

        Returns:
            tuple: The scores and results returned by the dataset's
            `get_nearest_examples` for the 'embeddings' index.
        """
        question_embedding = self._embed(
            self.q_tokenizer, self.q_encoder, query)
        scores, results = self.dataset.get_nearest_examples(
            'embeddings', question_embedding, k=k)
        return scores, results