File size: 1,499 Bytes
dd26848 1e0b033 dd26848 816c9c6 dd26848 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 |
import gradio as gr
from typing import Dict, List, Optional, TypedDict
from nlp4web_codebase.ir.data_loaders.sciq import load_sciq
from bm25 import BM25Index, BM25Retriever
# Load the SciQ dataset; its corpus is what gets indexed and searched below.
sciq = load_sciq()

# Single source of truth for where the BM25 index is written and read back.
BM25_INDEX_DIR = "output/bm25_sciq_index"

# Build the BM25 index over the whole SciQ corpus.
bm25_index = BM25Index.build_from_documents(
    documents=iter(sciq.corpus),
    # Passed explicitly because an iterator has no len(); presumably this is
    # the SciQ corpus size -- TODO confirm it stays in sync with the dataset.
    ndocs=12160,
    show_progress_bar=True,
    # Both BM25 parameters were tuned on the dev split w.r.t. MAP@10.
    k1=0.8,
    b=0.6,
)

# Persist the index, then load it back through the retriever used by search.
bm25_index.save(BM25_INDEX_DIR)
bm25_retriever = BM25Retriever(index_dir=BM25_INDEX_DIR)
class Hit(TypedDict):
    """One ranked search result, as produced by ``search``.

    At runtime this is a plain ``dict``; the class only pins down the key
    names and value types for readers and type checkers.
    """

    # Collection id of the matched document in the SciQ corpus.
    cid: str
    # Retrieval score of the document (results are ordered highest first).
    score: float
    # Full text of the matched document.
    text: str
# Module-level handle for the Gradio app; it starts as None and is reassigned
# to the real Interface further down, once search() has been defined.
demo: Optional[gr.Interface] = None # Assign your gradio demo to this variable
# Type that search() is expected to return: a list of Hit records.
return_type = List[Hit]
## YOUR_CODE_STARTS_HERE
# Lookup table from collection id to document text, used to attach the text
# of each retrieved document to its hit.
cid2doc: Dict[str, str] = {
    document.collection_id: document.text for document in sciq.corpus
}
def search(query: str, topk: int = 10) -> List[Hit]:
    """Run a BM25 search over the SciQ corpus and return the best hits.

    Args:
        query: Free-text query entered by the user.
        topk: Maximum number of hits to return. Defaults to 10, matching the
            "top-10 ranked documents" promised in the demo description.
            Negative values are treated as 0.

    Returns:
        At most ``topk`` hits ordered by retrieval score, highest first.
        Each hit carries the document's collection id, its score and its
        text. A blank or whitespace-only query yields an empty list without
        querying the index.
    """
    # Gradio submits "" for an empty textbox; there is nothing to rank.
    if not query or not query.strip():
        return []

    ranking: Dict[str, float] = bm25_retriever.retrieve(query)

    # The cid -> score mapping is not assumed to be ordered, so sort it
    # explicitly. Python's sort is stable even with reverse=True, so equal
    # scores keep the retriever's order.
    ranked = sorted(ranking.items(), key=lambda item: item[1], reverse=True)

    # Enforce the cutoff here rather than relying on the retriever's default.
    return [
        Hit(cid=cid, score=score, text=cid2doc[cid])
        for cid, score in ranked[: max(topk, 0)]
    ]
# Two-line free-text box where the user types the query.
query_input = gr.Textbox(lines=2, placeholder="Enter your query here...")

# Wire search() into a simple Interface; the returned hits are shown as text.
demo = gr.Interface(
    title="BM25 Retriever Search",
    description="Search using a BM25 retriever on [SciQ](https://huggingface.co/datasets/allenai/sciq) and return top-10 ranked documents with scores.",
    fn=search,
    inputs=query_input,
    outputs="text",
)
## YOUR_CODE_ENDS_HERE
# Launch the web server only when this file is run as a script (e.g.
# `python app.py`). launch() blocks the main thread by default, so an
# unguarded call would hang any code that merely imports this module
# (for example to call search() or inspect `demo`).
if __name__ == "__main__":
    demo.launch()
|