# NOTE(review): the original lines here ("Spaces:" / "Sleeping" / "Sleeping")
# were a Hugging Face Spaces status banner captured during text extraction,
# not Python code; replaced with this comment so the module parses.
# Redirect HF/torch model caches to a local "weights" directory BEFORE any
# library import reads these variables — setting them later has no effect.
import os

os.environ["HF_HOME"] = "weights"
os.environ["TORCH_HOME"] = "weights"

from typing import List, Optional, Union

from langchain.callbacks import FileCallbackHandler
from langchain.retrievers import ContextualCompressionRetriever, ParentDocumentRetriever
from langchain.retrievers.document_compressors import EmbeddingsFilter
from langchain.storage import InMemoryStore
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import UnstructuredFileLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS, Chroma
from langchain_core.documents import Document
from loguru import logger
from rich import print
from sentence_transformers import CrossEncoder
from unstructured.cleaners.core import clean_extra_whitespace, group_broken_paragraphs

# Mirror loguru output into a file; enqueue=True routes writes through a
# queue so logging is safe across processes.
logfile = "log/output.log"
logger.add(logfile, colorize=True, enqueue=True)
handler = FileCallbackHandler(logfile)

# None => Chroma keeps the child-chunk index in memory only (no persistence).
persist_directory = None
class RAGException(Exception):
    """Raised when the retrieval pipeline cannot produce a usable result,
    e.g. no documents match the user's query.

    Note: the original ``__init__`` only forwarded to ``super().__init__``,
    which is what ``Exception`` does by default, so it was removed.
    """
def rerank_docs(reranker_model, query, retrieved_docs):
    """Score retrieved documents against ``query`` with a cross-encoder.

    Args:
        reranker_model: object exposing ``predict(pairs)`` where ``pairs`` is
            a list of ``(query, passage_text)`` tuples; returns one relevance
            score per pair.
        query: the user's question (str).
        retrieved_docs: documents exposing a ``page_content`` attribute.

    Returns:
        List of ``(doc, score)`` tuples sorted by score, best first.
    """
    pairs = [(query, doc.page_content) for doc in retrieved_docs]
    scores = reranker_model.predict(pairs)
    # sorted() accepts any iterable — no need to materialize the zip first.
    return sorted(zip(retrieved_docs, scores), key=lambda item: item[1], reverse=True)
def load_pdf(
    files: Union[str, List[str]] = "2401.08406v3.pdf"
) -> List[Document]:
    """Parse one or more PDF files into LangChain ``Document`` objects.

    Args:
        files: a single file path, or a list of file paths.

    Returns:
        All Documents produced by ``UnstructuredFileLoader``, with extra
        whitespace collapsed and broken paragraphs re-joined.
    """
    # Normalize the single-path case so both cases share one code path;
    # results are identical to loading each file separately.
    if isinstance(files, str):
        files = [files]
    docs: List[Document] = []
    for path in files:
        loader = UnstructuredFileLoader(
            path,
            post_processors=[clean_extra_whitespace, group_broken_paragraphs],
        )
        docs.extend(loader.load())
    return docs
def create_parent_retriever(
    docs: List[Document], embeddings_model: HuggingFaceEmbeddings
):
    """Build a ParentDocumentRetriever over ``docs``.

    Child chunks (~1000 chars) are embedded and indexed in Chroma for
    similarity search; each hit maps back to its larger parent chunk
    (~2000 chars) held in an in-memory docstore, so the caller gets more
    context than the matched chunk alone.

    Args:
        docs: documents to index.
        embeddings_model: embedding model used for the child-chunk index.
            (Fixed: the annotation previously *called* the class —
            ``HuggingFaceEmbeddings()`` — constructing a model as a side
            effect of evaluating the annotation.)

    Returns:
        A ``ParentDocumentRetriever`` with ``docs`` already added.
    """
    # Splitter for the large "parent" chunks returned to the caller.
    parent_splitter = RecursiveCharacterTextSplitter(
        separators=["\n\n\n", "\n\n"],
        chunk_size=2000,
        length_function=len,
        is_separator_regex=False,
    )
    # This text splitter is used to create the child documents
    child_splitter = RecursiveCharacterTextSplitter(
        separators=["\n\n\n", "\n\n"],
        chunk_size=1000,
        chunk_overlap=300,
        length_function=len,
        is_separator_regex=False,
    )
    # The vectorstore to use to index the child chunks
    # (persist_directory is None at module level => in-memory only).
    vectorstore = Chroma(
        collection_name="split_documents",
        embedding_function=embeddings_model,
        persist_directory=persist_directory,
    )
    # The storage layer for the parent documents
    store = InMemoryStore()
    retriever = ParentDocumentRetriever(
        vectorstore=vectorstore,
        docstore=store,
        child_splitter=child_splitter,
        parent_splitter=parent_splitter,
        k=10,
    )
    retriever.add_documents(docs)
    return retriever
def retrieve_context(query, retriever, reranker_model):
    """Retrieve documents for ``query`` and rerank them with a cross-encoder.

    Args:
        query: the user's question.
        retriever: LangChain retriever exposing ``get_relevant_documents``.
        reranker_model: cross-encoder passed through to ``rerank_docs``.

    Returns:
        List of ``(doc, score)`` tuples, best first (see ``rerank_docs``).

    Raises:
        RAGException: when the retriever returns no documents at all.
    """
    retrieved_docs = retriever.get_relevant_documents(query)
    if not retrieved_docs:
        raise RAGException(
            f"Couldn't retrieve any relevant document with the query `{query}`. Try modifying your question!"
        )
    return rerank_docs(
        query=query, retrieved_docs=retrieved_docs, reranker_model=reranker_model
    )
def load_embedding_model(
    model_name: str = "BAAI/bge-large-en-v1.5", device: str = "cuda"
) -> HuggingFaceEmbeddings:
    """Load a HuggingFace sentence-embedding model.

    Args:
        model_name: HF model id to load.
        device: torch device string; default assumes a GPU is available.

    Returns:
        A configured ``HuggingFaceEmbeddings`` instance.
    """
    model_kwargs = {"device": device}
    encode_kwargs = {
        "normalize_embeddings": True
    }  # set True to compute cosine similarity
    embedding_model = HuggingFaceEmbeddings(
        model_name=model_name,
        model_kwargs=model_kwargs,
        encode_kwargs=encode_kwargs,
    )
    return embedding_model
def load_reranker_model(
    reranker_model_name: str = "BAAI/bge-reranker-large", device: str = "cuda"
) -> CrossEncoder:
    """Load a sentence-transformers cross-encoder used for reranking.

    Args:
        reranker_model_name: HF model id of the cross-encoder.
        device: torch device string; default assumes a GPU is available.

    Returns:
        A ``CrossEncoder`` capped at 1024 input tokens per pair.
    """
    reranker_model = CrossEncoder(
        model_name=reranker_model_name, max_length=1024, device=device
    )
    return reranker_model
def main(
    file: str = "2401.08406v3.pdf",
    query: Optional[str] = None,
    llm_name="mistral",
):
    """End-to-end demo: load a PDF, retrieve + rerank context for ``query``,
    and print the single best (doc, score) pair.

    Args:
        file: path of the PDF to index.
        query: question to retrieve context for. NOTE(review): no default is
            applied when this is None — ``retrieve_context`` is called with
            ``None`` as the query; confirm intended behavior with callers.
        llm_name: currently unused in this function; kept for CLI
            compatibility (jsonargparse exposes it as a flag).
    """
    docs = load_pdf(files=file)
    embedding_model = load_embedding_model()
    retriever = create_parent_retriever(docs, embedding_model)
    reranker_model = load_reranker_model()
    # [0] takes the top-ranked (doc, score) pair after reranking.
    context = retrieve_context(
        query, retriever=retriever, reranker_model=reranker_model
    )[0]
    print("context:\n", context, "\n", "=" * 50, "\n")
if __name__ == "__main__":
    # jsonargparse builds a CLI from main()'s signature; imported lazily so
    # importing this module as a library does not require jsonargparse.
    from jsonargparse import CLI

    CLI(main)