|
from typing import Any, Callable, Dict, List, Optional |
|
|
|
from air_benchmark import AIRBench, Retriever |
|
from llama_index.core import VectorStoreIndex |
|
from llama_index.core.node_parser import SentenceSplitter |
|
from llama_index.embeddings.openai import OpenAIEmbedding |
|
from llama_index.llms.openai import OpenAI |
|
from llama_index.retrievers.bm25 import BM25Retriever |
|
from llama_index.core.retrievers import QueryFusionRetriever |
|
from llama_index.core.schema import Document, NodeWithScore |
|
|
|
|
|
def create_retriever_fn(documents: List[Document], top_k: int) -> Callable[[str], List[NodeWithScore]]:
    """Build a hybrid (dense vector + BM25) fusion retriever over *documents*.

    Args:
        documents: Corpus to chunk and index.
        top_k: Number of hits each sub-retriever returns and the fused
            result is truncated to.

    Returns:
        A callable mapping a query string to a ranked list of NodeWithScore.
    """
    # Chunk the documents once; both indexes share the same nodes.
    splitter = SentenceSplitter(chunk_size=1024, chunk_overlap=128)
    chunks = splitter(documents)

    dense_index = VectorStoreIndex(
        nodes=chunks,
        embed_model=OpenAIEmbedding(model_name="text-embedding-ada-002"),
    )
    dense_retriever = dense_index.as_retriever(similarity_top_k=top_k)
    keyword_retriever = BM25Retriever.from_defaults(
        nodes=chunks, similarity_top_k=top_k
    )

    # Fuse dense and keyword results; the LLM generates query variations
    # (num_queries=3) and scores are merged with distribution-based fusion.
    fused = QueryFusionRetriever(
        [dense_retriever, keyword_retriever],
        similarity_top_k=top_k,
        num_queries=3,
        mode="dist_based_score",
        llm=OpenAI(model="gpt-3.5-turbo", temperature=0.1),
    )

    def _retriever(query: str) -> List[NodeWithScore]:
        return fused.retrieve(query)

    return _retriever
|
|
|
|
|
|
|
class LlamaRetriever(Retriever):
    """Adapter exposing a llama-index retriever factory to the AIR-Bench harness.

    AIR-Bench calls the instance with a corpus and a set of queries; this
    class builds a retriever over the corpus via ``create_retriever_fn`` and
    returns per-query document scores.
    """

    def __init__(
        self,
        name: str,
        create_retriever_fn: Callable[[List[Document], int], Callable[[str], List[NodeWithScore]]],
        search_top_k: int = 1000,
    ) -> None:
        """Store the retriever factory and search depth.

        Args:
            name: Human-readable identifier for this retriever.
            create_retriever_fn: Factory taking (documents, top_k) and
                returning a query -> List[NodeWithScore] callable.
            search_top_k: Number of candidates retrieved per query.
        """
        self.name = name
        # BUG FIX: the original line was a bare attribute access
        # (`self.search_top_k`) that never assigned the constructor argument,
        # so any later use of self.search_top_k raised AttributeError.
        self.search_top_k = search_top_k
        self.create_retriever_fn = create_retriever_fn

    def __str__(self) -> str:
        return self.name

    def __call__(
        self,
        corpus: Dict[str, Dict[str, Any]],
        queries: Dict[str, str],
        **kwargs,
    ) -> Dict[str, Dict[str, float]]:
        """Retrieve relevant documents for each query.

        Args:
            corpus: Mapping of doc_id -> document dict; each dict must
                contain a "text" key, remaining keys become metadata.
            queries: Mapping of query_id -> query string (some benchmark
                tasks supply a list of strings instead).

        Returns:
            Mapping of query_id -> {doc_id: score} for the retrieved docs.

        Raises:
            ValueError: If a corpus entry has no "text" field.
        """
        documents: List[Document] = []
        for doc_id, doc in corpus.items():
            # Read without mutating the caller's corpus (the original used
            # doc.pop("text"), which destroyed the field for any re-use).
            text = doc.get("text")
            if text is None:
                # Raise instead of assert: asserts are stripped under -O.
                raise ValueError(f"corpus entry {doc_id!r} is missing 'text'")
            metadata = {k: v for k, v in doc.items() if k != "text"}
            documents.append(Document(id_=doc_id, text=text, metadata=metadata))

        # BUG FIX: create_retriever_fn is declared as (documents, top_k);
        # the original call omitted the required second argument.
        retriever = self.create_retriever_fn(documents, self.search_top_k)

        results: Dict[str, Dict[str, float]] = {qid: {} for qid in queries}
        for qid, query in queries.items():
            if isinstance(query, list):
                # Multi-part queries are flattened into a single string.
                query = "; ".join(query)

            for node in retriever(query):
                # ref_doc_id maps the chunk back to its source document;
                # a later chunk's score overwrites an earlier one's.
                results[qid][node.node.ref_doc_id] = node.score

        return results
|
|
|
|
|
|
|
# Wire the fusion-retriever factory into the AIR-Bench harness and run the
# English arXiv long-document track, writing run files to ./llama_results.
retriever = LlamaRetriever(
    name="vector_bm25_fusion",
    create_retriever_fn=create_retriever_fn,
)

evaluation = AIRBench(
    benchmark_version="AIR-Bench_24.04",
    task_types=["long-doc"],
    domains=["arxiv"],
    languages=["en"],
)

evaluation.run(retriever, output_dir="./llama_results", overwrite=True)
|
|