from operator import itemgetter
from typing import Dict, List, Optional, TypedDict

import gradio as gr
from nlp4web_codebase.ir.data_loaders.sciq import load_sciq

from bm25 import BM25Index, BM25Retriever
# Load the SciQ dataset; its corpus is indexed here and also used later in
# this file to look up document text.
sciq = load_sciq()

# Single source of truth for where the index is written and read back.
INDEX_DIR = "output/bm25_sciq_index"

# Build the BM25 index over the corpus with BM25 hyperparameters k1 and b.
bm25_index = BM25Index.build_from_documents(
    documents=iter(sciq.corpus),
    # NOTE(review): hard-coded document count, presumably the size of
    # sciq.corpus -- confirm it stays in sync with the dataset.
    ndocs=12160,
    show_progress_bar=True,
    k1=0.8,
    b=0.6,
)

# Persist the index, then open a retriever on the saved directory.
bm25_index.save(INDEX_DIR)
bm25_retriever = BM25Retriever(index_dir=INDEX_DIR)
class Hit(TypedDict):
    """A single search result as returned by ``search``."""

    # Collection id of the retrieved document (key into ``cid2doc``).
    cid: str
    # Score the retriever assigned to the document for the query.
    score: float
    # Text of the document.
    text: str
# Placeholder declaration; the actual interface is assigned further down.
demo: Optional[gr.Interface] = None

# Alias for the value search() returns.
return_type = List[Hit]

# Collection id -> document text, used to attach the text to each hit.
cid2doc = {doc.collection_id: doc.text for doc in sciq.corpus}
def search(query: str) -> List[Hit]:
    """Retrieve documents for ``query`` and return them best-first.

    Args:
        query: Free-text query, passed unchanged to
            ``bm25_retriever.retrieve``.

    Returns:
        One ``Hit`` per retrieved document (collection id, score and the
        document text looked up in ``cid2doc``), ordered by descending score.

    Raises:
        KeyError: If the retriever returns a collection id that is missing
            from ``cid2doc``.
    """
    ranking: Dict[str, float] = bm25_retriever.retrieve(query)
    # Sort explicitly rather than relying on the order of the returned mapping.
    ranked = sorted(ranking.items(), key=itemgetter(1), reverse=True)
    return [
        Hit(cid=cid, score=score, text=cid2doc[cid])
        for cid, score in ranked
    ]
# Gradio UI: one free-text query box wired to search(); the returned hits are
# shown through the plain "text" output component.
# NOTE(review): the description promises top-10 results, but search() does not
# truncate its output; presumably the retriever limits the ranking itself --
# confirm.
demo = gr.Interface(
    fn=search,
    inputs=gr.Textbox(lines=2, placeholder="Enter your query here..."),
    outputs="text",
    title="BM25 Retriever Search",
    description="Search using a BM25 retriever on [SciQ](https://huggingface.co/datasets/allenai/sciq) and return top-10 ranked documents with scores.",
)

# Start the app as soon as this module is executed.
demo.launch()