import numpy as np
from gensim.corpora import Dictionary
from gensim.models import TfidfModel
from gensim.similarities import SparseMatrixSimilarity
from tqdm.auto import tqdm

from text_utils import preprocess

# Register tqdm's progress bars with pandas (enables DataFrame.progress_apply).
tqdm.pandas()


class BM25Gensim:
    """Two-stage BM25 retriever built on gensim Dictionary / TfidfModel /
    SparseMatrixSimilarity checkpoints."""

    def __init__(self, checkpoint_path, entity_dict, title2idx):
        # The checkpoint directory is assumed to hold the token dictionary, the
        # TF-IDF/BM25 weighting model, and the prebuilt similarity index.
        self.dictionary = Dictionary.load(checkpoint_path + "/dict")
        self.tfidf_model = TfidfModel.load(checkpoint_path + "/tfidf")
        self.bm25_index = SparseMatrixSimilarity.load(checkpoint_path + "/bm25_index")
        self.title2idx = title2idx      # passage title -> passage index
        self.entity_dict = entity_dict  # raw answer string -> "wiki/<Title>" entity link

    def get_topk_stage1(self, query, topk=100):
        """Rank all indexed passages against a whitespace-tokenized query and
        return the top-k passage indices with their scores."""
        tokenized_query = query.split()
        tfidf_query = self.tfidf_model[self.dictionary.doc2bow(tokenized_query)]
        scores = self.bm25_index[tfidf_query]
        top_n = np.argsort(scores)[::-1][:topk]
        return top_n, scores[top_n]

    def get_topk_stage2(self, x, raw_answer=None, topk=50):
        """Preprocess a raw question, score it against the index, and optionally
        force the passage matching a linked answer entity into the candidates."""
        x = str(x)
        query = preprocess(x, max_length=128).lower().split()
        tfidf_query = self.tfidf_model[self.dictionary.doc2bow(query)]
        scores = self.bm25_index[tfidf_query]
        top_n = list(np.argsort(scores)[::-1][:topk])
        if raw_answer is not None:
            raw_answer = raw_answer.strip()
            if raw_answer in self.entity_dict:
                # Map the answer entity to a wiki title, then to its passage index,
                # and make sure that passage is included in the candidate set.
                title = self.entity_dict[raw_answer].replace("wiki/", "").replace("_", " ")
                extra_id = self.title2idx.get(title, -1)
                if extra_id != -1 and extra_id not in top_n:
                    top_n.append(extra_id)
        scores = scores[top_n]
        return np.array(top_n), np.array(scores)
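

# Minimal usage sketch. The checkpoint path and the two lookup dicts below are
# hypothetical placeholders for artifacts produced elsewhere in the pipeline
# (a saved gensim Dictionary / TF-IDF model / similarity index, an answer ->
# "wiki/<Title>" entity map, and a title -> passage-index map).
if __name__ == "__main__":
    retriever = BM25Gensim(
        checkpoint_path="outputs/bm25_checkpoint",   # hypothetical path
        entity_dict={},                              # e.g. {"some answer": "wiki/Some_Title"}
        title2idx={},                                # e.g. {"Some Title": 123}
    )
    doc_ids, doc_scores = retriever.get_topk_stage1("sample query tokens", topk=5)
    print(list(zip(doc_ids.tolist(), doc_scores.tolist())))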