File size: 1,436 Bytes
dd26848
 
 
 
 
 
 
 
 
 
1e0b033
 
dd26848
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import gradio as gr
from typing import Dict, List, Optional, TypedDict
from nlp4web_codebase.ir.data_loaders.sciq import load_sciq
from bm25 import BM25Index, BM25Retriever

# Load the SciQ data through the project loader; its corpus feeds the index below.
# Everything in this section runs at import time.
sciq = load_sciq()
# Build a BM25 index over the corpus documents.
# NOTE(review): ndocs is a hard-coded count -- presumably the size of the SciQ
# corpus; confirm it stays in sync with sciq.corpus.
bm25_index = BM25Index.build_from_documents(
    documents=iter(sciq.corpus),
    ndocs=12160,
    show_progress_bar=True,
    k1=0.8,  # Tuned on dev wrt. MAP@10
    b=0.6,  # Tuned on dev wrt. MAP@10
)
# Persist the index, then open a retriever on that same directory. The path is
# spelled out in both calls and must be kept identical in each.
bm25_index.save("output/bm25_sciq_index")
bm25_retriever = BM25Retriever(index_dir="output/bm25_sciq_index")


class Hit(TypedDict, total=True):
    """One ranked search result; all three keys are required."""

    cid: str  # id of the retrieved document
    score: float  # score the retriever assigned to the document
    text: str  # text of the retrieved document


demo: Optional[gr.Interface] = None  # Assign your gradio demo to this variable
# Alias for "list of Hit dicts"; it matches the return annotation of search().
# NOTE(review): not referenced anywhere else in this file -- presumably read by
# external tooling; confirm before renaming or removing.
return_type = List[Hit]

## YOUR_CODE_STARTS_HERE
# Lookup table from a document's collection_id to its text, built once at import
# time so that search() can attach the text to every hit it returns.
cid2doc: Dict[str, str] = {doc.collection_id: doc.text for doc in sciq.corpus}


def search(query: str) -> List[Hit]:
    """Retrieve documents for *query* and return them best-first.

    Args:
        query: Free-text query handed to the BM25 retriever.

    Returns:
        One ``Hit`` per retrieved document, ordered by descending score.
        Documents with equal scores keep the order the retriever produced.

    Raises:
        KeyError: If the retriever returns an id that is missing from
            ``cid2doc``.
    """
    scores_by_cid: Dict[str, float] = bm25_retriever.retrieve(query)
    # Highest score first; sorted() is stable, so ties stay in retriever order.
    best_first = sorted(
        scores_by_cid.items(),
        key=lambda pair: pair[1],
        reverse=True,
    )
    return [
        Hit(cid=doc_id, score=doc_score, text=cid2doc[doc_id])
        for doc_id, doc_score in best_first
    ]


# Gradio UI: a two-line query box wired to search(); the result is shown through
# the "text" output component.
demo = gr.Interface(
    title="BM25 Retriever Search",
    description="Search using a BM25 retriever and return ranked documents with scores.",
    fn=search,
    inputs=gr.Textbox(placeholder="Enter your query here...", lines=2),
    outputs="text",
)
## YOUR_CODE_ENDS_HERE
# Launch the Gradio interface. This runs unconditionally whenever the module is
# executed or imported.
demo.launch()