# Source: Hugging Face upload "run_airbench.py" by cheesyFishes (revision 81bb72e, verified)
from typing import Any, Callable, Dict, List, Optional
from air_benchmark import AIRBench, Retriever
from llama_index.core import VectorStoreIndex
from llama_index.core.node_parser import SentenceSplitter
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.llms.openai import OpenAI
from llama_index.retrievers.bm25 import BM25Retriever
from llama_index.core.retrievers import QueryFusionRetriever
from llama_index.core.schema import Document, NodeWithScore
def create_retriever_fn(documents: List[Document], top_k: int = 1000) -> Callable[[str], List[NodeWithScore]]:
    """Build a hybrid (vector + BM25) fusion retriever over *documents*.

    Args:
        documents: Parent documents to index. IMPORTANT: if you don't use a
            llama-index node parser/splitter, you need to ensure that
            ``node.ref_doc_id`` points to the correct parent document id —
            it is used to line up the corpus document id for evaluation.
        top_k: Number of results each sub-retriever (and the fused result)
            returns per query. Defaults to 1000 to match
            ``LlamaRetriever.search_top_k``, so callers that omit it
            (e.g. ``create_retriever_fn(documents)``) still work.

    Returns:
        A callable mapping a query string to a list of scored nodes.
    """
    # Chunk documents; the splitter sets node.ref_doc_id automatically.
    nodes = SentenceSplitter(chunk_size=1024, chunk_overlap=128)(documents)

    # Dense retriever over OpenAI embeddings.
    vector_index = VectorStoreIndex(
        nodes=nodes,
        embed_model=OpenAIEmbedding(model_name="text-embedding-ada-002")
    )
    vector_retriever = vector_index.as_retriever(similarity_top_k=top_k)

    # Sparse lexical retriever over the same nodes.
    bm25_retriever = BM25Retriever.from_defaults(nodes=nodes, similarity_top_k=top_k)

    # Fuse both result lists; the LLM generates extra query variations
    # (num_queries=3 -> original query + 2 rewrites).
    retriever = QueryFusionRetriever(
        [vector_retriever, bm25_retriever],
        similarity_top_k=top_k,
        num_queries=3,
        mode="dist_based_score",
        llm=OpenAI(model="gpt-3.5-turbo", temperature=0.1)
    )

    def _retriever(query: str) -> List[NodeWithScore]:
        return retriever.retrieve(query)

    return _retriever
class LlamaRetriever(Retriever):
    """Adapter that exposes a llama-index retriever through AIR-Bench's
    ``Retriever`` callable interface.

    Args:
        name: Human-readable identifier reported to the benchmark.
        create_retriever_fn: Factory taking (documents, top_k) and returning
            a ``query -> List[NodeWithScore]`` callable.
        search_top_k: Number of results retrieved per query.
    """

    def __init__(
        self,
        name: str,
        create_retriever_fn: Callable[[List[Document], int], Callable[[str], List[NodeWithScore]]],
        search_top_k: int = 1000,
    ) -> None:
        self.name = name
        # BUG FIX: the original line was the bare expression `self.search_top_k`
        # (a no-op), so the attribute was never assigned.
        self.search_top_k = search_top_k
        self.create_retriever_fn = create_retriever_fn

    def __str__(self) -> str:
        return self.name

    def __call__(
        self,
        corpus: Dict[str, Dict[str, Any]],
        queries: Dict[str, str],
        **kwargs,
    ) -> Dict[str, Dict[str, float]]:
        """Retrieve relevant documents for each query.

        Returns a mapping ``query_id -> {corpus_doc_id: score}``.
        """
        # Convert the AIR-Bench corpus dicts into llama-index Documents,
        # preserving any extra fields as metadata.
        documents = []
        for doc_id, doc in corpus.items():
            text = doc.pop("text")
            # A corpus entry without text cannot be indexed.
            assert text is not None
            documents.append(Document(id_=doc_id, text=text, metadata={**doc}))

        # BUG FIX: the original call omitted the required top_k argument
        # (`self.create_retriever_fn(documents)`), which would raise TypeError
        # against the declared factory signature.
        retriever = self.create_retriever_fn(documents, self.search_top_k)

        query_ids = list(queries.keys())
        results = {qid: {} for qid in query_ids}
        for qid in query_ids:
            query = queries[qid]
            if isinstance(query, list):
                # taken from mteb:
                # https://github.com/embeddings-benchmark/mteb/blob/main/mteb/evaluation/evaluators/RetrievalEvaluator.py#L403
                query = "; ".join(query)
            nodes = retriever(query)
            for node in nodes:
                # ref_doc_id should point to corpus document id
                results[qid][node.node.ref_doc_id] = node.score
        return results
# Wire the hybrid-retriever factory into the AIR-Bench adapter and run
# the evaluation on the arxiv/long-doc/English slice of the benchmark.
retriever = LlamaRetriever("vector_bm25_fusion", create_retriever_fn)

evaluation = AIRBench(
    benchmark_version="AIR-Bench_24.04",
    # Each of the three filters below narrows the benchmark; delete a line
    # to evaluate across all tasks, domains, or languages respectively.
    task_types=["long-doc"],
    domains=["arxiv"],
    languages=["en"],
    # cache_dir="~/.air_bench/"  # benchmark data cache (**NEED ~52GB FOR FULL BENCHMARK**)
)

evaluation.run(
    retriever,
    output_dir="./llama_results",  # default output directory is "./search_results"
    overwrite=True,  # replace previously written results instead of skipping
)