import copy
import os
from pathlib import Path
from typing import Any, Tuple, Union

from grobid_client.grobid_client import GrobidClient
from langchain.chains import create_extraction_chain
from langchain.chains.question_answering import load_qa_chain
from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
from langchain.retrievers import MultiQueryRetriever
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
from tqdm import tqdm

from grobid_processors import GrobidProcessor


class DocumentQAEngine:
    """Question answering over PDF documents backed by one Chroma vector store per document.

    PDFs are parsed through a ``GrobidProcessor``, split into passages or token-sized
    chunks, embedded with ``embedding_function`` and stored in Chroma. Questions are
    answered by retrieving the most relevant chunks of a document and running a
    LangChain QA chain (built from ``llm``) over them.
    """

    llm = None
    qa_chain_type = None
    embedding_function = None

    def __init__(self, llm, embedding_function, qa_chain_type="stuff", embeddings_root_path=None, grobid_url=None):
        """
        :param llm: LangChain model used by the QA chain, JSON post-processing and multi-query retrieval.
        :param embedding_function: embedding function handed to every Chroma store.
        :param qa_chain_type: chain type passed to ``load_qa_chain`` (default ``"stuff"``).
        :param embeddings_root_path: directory holding one persisted Chroma store per document.
            It is created when missing; otherwise every store already in it is loaded.
        :param grobid_url: URL of the Grobid server; required by every method that processes PDFs.
        """
        self.llm = llm
        self.embedding_function = embedding_function
        self.qa_chain_type = qa_chain_type
        self.chain = load_qa_chain(llm, chain_type=qa_chain_type)

        # Per-instance state. These used to be class attributes, which made every
        # engine instance share (and mutate) the very same dictionaries.
        self.embeddings_dict = {}
        self.embeddings_map_from_md5 = {}
        self.embeddings_map_to_md5 = {}

        # Always define the attribute so later methods can check it explicitly.
        self.embeddings_root_path = embeddings_root_path
        if embeddings_root_path is not None:
            if not os.path.exists(embeddings_root_path):
                os.makedirs(embeddings_root_path)
            else:
                self.load_embeddings(embeddings_root_path)

        self.grobid_url = grobid_url
        self.grobid_processor = None
        if grobid_url:
            grobid_client = GrobidClient(
                grobid_server=grobid_url,
                batch_size=1000,
                coordinates=["p"],
                sleep_time=5,
                timeout=60,
                check_server=True
            )
            self.grobid_processor = GrobidProcessor(grobid_client)

    def load_embeddings(self, embeddings_root_path: Union[str, Path]) -> None:
        """Load every persisted document store found under ``embeddings_root_path``.

        Each immediate subdirectory is opened as one Chroma store and registered in
        ``embeddings_dict`` under the subdirectory name (``create_embeddings`` names
        these directories after the document MD5). When the subdirectory contains a
        ``<filename>.storage_filename`` marker file, the filename <-> id mapping is
        registered as well.

        :param embeddings_root_path: root directory containing one store per subdirectory.
        """
        with os.scandir(embeddings_root_path) as entries:
            embeddings_directories = [entry for entry in entries if entry.is_dir()]

        if not embeddings_directories:
            print("No available embeddings")
            return

        for document_dir in embeddings_directories:
            self.embeddings_dict[document_dir.name] = Chroma(
                persist_directory=document_dir.path,
                embedding_function=self.embedding_function
            )
            marker_files = list(Path(document_dir.path).glob('*.storage_filename'))
            if marker_files:
                filename = marker_files[0].name.replace(".storage_filename", "")
                self.embeddings_map_from_md5[document_dir.name] = filename
                self.embeddings_map_to_md5[filename] = document_dir.name

        print("Embedding loaded: ", len(self.embeddings_dict))

    def get_loaded_embeddings_ids(self):
        """Return the ids (store keys) of all currently loaded document stores."""
        return list(self.embeddings_dict.keys())

    def get_md5_from_filename(self, filename):
        """Return the store id registered for ``filename``; raises ``KeyError`` if unknown."""
        return self.embeddings_map_to_md5[filename]

    def get_filename_from_md5(self, md5):
        """Return the filename registered for store id ``md5``; raises ``KeyError`` if unknown."""
        return self.embeddings_map_from_md5[md5]

    def query_document(self, query: str, doc_id, output_parser=None, context_size=4, extraction_schema=None,
                       verbose=False) -> Tuple[Any, str]:
        """Answer ``query`` using the ``context_size`` most relevant chunks of document ``doc_id``.

        :param output_parser: optional LangChain output parser; the answer is converted to
            JSON through ``_parse_json`` and parsed with it.
        :param extraction_schema: optional schema for ``create_extraction_chain``; used only
            when no ``output_parser`` is given.
        :return: ``(parsed, response)`` where ``response`` is the raw answer and ``parsed`` is
            the structured result, or ``None`` when no parser was requested or parsing failed.
        """
        if verbose:
            print(query)

        response = self._run_query(doc_id, query, context_size=context_size)
        # ``chain.run`` normally returns a plain string, for which ``'output_text' in response``
        # would be a substring test; only unwrap an actual dict result.
        if isinstance(response, dict) and 'output_text' in response:
            response = response['output_text']

        if verbose:
            print(doc_id, "->", response)

        if output_parser:
            # Best effort: a parsing failure still returns the raw answer.
            try:
                return self._parse_json(response, output_parser), response
            except Exception as oe:
                print("Failing to parse the response", oe)
                return None, response
        elif extraction_schema:
            try:
                chain = create_extraction_chain(extraction_schema, self.llm)
                parsed = chain.run(response)
                return parsed, response
            except Exception as oe:
                print("Failing to parse the response", oe)
                return None, response
        else:
            return None, response

    def query_storage(self, query: str, doc_id, context_size=4):
        """Return the text of the ``context_size`` chunks of ``doc_id`` most relevant to ``query``."""
        documents = self._get_context(doc_id, query, context_size)
        return [doc.page_content for doc in documents]

    def _parse_json(self, response, output_parser):
        """Ask the LLM to rewrite ``response`` as JSON following the parser's format instructions, then parse it."""
        system_message = "You are an useful assistant expert in materials science, physics, and chemistry " \
                         "that can process text and transform it to JSON."
        human_message = """Transform the text between three double quotes in JSON.\n\n\n\n {format_instructions}\n\nText: \"\"\"{text}\"\"\""""

        system_message_prompt = SystemMessagePromptTemplate.from_template(system_message)
        human_message_prompt = HumanMessagePromptTemplate.from_template(human_message)

        prompt_template = ChatPromptTemplate.from_messages([system_message_prompt, human_message_prompt])

        results = self.llm(
            prompt_template.format_prompt(
                text=response,
                format_instructions=output_parser.get_format_instructions()
            ).to_messages()
        )
        parsed_output = output_parser.parse(results.content)

        return parsed_output

    def _run_query(self, doc_id, query, context_size=4):
        """Retrieve the relevant chunks of ``doc_id`` and run the QA chain on them."""
        relevant_documents = self._get_context(doc_id, query, context_size)
        return self.chain.run(input_documents=relevant_documents, question=query)

    def _get_context(self, doc_id, query, context_size=4):
        """Return the ``context_size`` documents of store ``doc_id`` most similar to ``query``."""
        db = self.embeddings_dict[doc_id]
        retriever = db.as_retriever(search_kwargs={"k": context_size})
        return retriever.get_relevant_documents(query)

    def get_all_context_by_document(self, doc_id):
        """Return all stored texts of document store ``doc_id``."""
        db = self.embeddings_dict[doc_id]
        docs = db.get()
        return docs['documents']

    def _get_context_multiquery(self, doc_id, query, context_size=4):
        """Like ``_get_context`` but uses an LLM-driven ``MultiQueryRetriever`` on top of the store."""
        retriever = self.embeddings_dict[doc_id].as_retriever(search_kwargs={"k": context_size})
        multi_query_retriever = MultiQueryRetriever.from_llm(retriever=retriever, llm=self.llm)
        return multi_query_retriever.get_relevant_documents(query)

    def get_text_from_document(self, pdf_file_path, chunk_size=-1, perc_overlap=0.1, verbose=False):
        """Extract texts, per-text metadata and ids from a PDF through Grobid.

        :param chunk_size: when negative, one entry per non-blank Grobid passage (ids are the
            passage ids); otherwise the concatenated text is split into chunks of this many
            tokens (ids are 0..n-1).
        :param perc_overlap: fraction of ``chunk_size`` overlapping between consecutive chunks.
        :return: ``(texts, metadatas, ids)`` lists of equal length; each metadata is a copy of
            the document ``biblio`` (plus passage type/section info in passage mode).
        :raises ValueError: if the engine was created without a Grobid URL.
        """
        if verbose:
            print("File", pdf_file_path)
        if self.grobid_processor is None:
            raise ValueError("A Grobid URL is required to process PDF documents (grobid_url was not set)")

        filename = Path(pdf_file_path).stem
        structure = self.grobid_processor.process_structure(pdf_file_path)

        biblio = structure['biblio']
        biblio['filename'] = filename.replace(" ", "_")

        if verbose:
            # Previously printed the builtin ``hash`` function instead of the document hash.
            print("Generating embeddings for:", biblio.get('hash'), ", filename: ", filename)

        texts = []
        metadatas = []
        ids = []

        if chunk_size < 0:
            for passage in structure['passages']:
                if not passage['text'].strip():
                    continue
                passage_metadata = copy.copy(biblio)
                passage_metadata['type'] = passage['type']
                passage_metadata['section'] = passage['section']
                passage_metadata['subSection'] = passage['subSection']
                texts.append(passage['text'])
                metadatas.append(passage_metadata)
                ids.append(passage['passage_id'])
        else:
            document_text = " ".join(passage['text'] for passage in structure['passages'])
            text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
                chunk_size=chunk_size,
                # The splitter expects an integer token count.
                chunk_overlap=int(chunk_size * perc_overlap)
            )
            texts = text_splitter.split_text(document_text)
            # One independent metadata dict per chunk instead of the same object repeated.
            metadatas = [copy.copy(biblio) for _ in texts]
            ids = list(range(len(texts)))

        return texts, metadatas, ids

    def create_memory_embeddings(self, pdf_path, doc_id=None):
        """Process ``pdf_path`` and index it in an in-memory (non-persisted) Chroma store.

        :param doc_id: key for the new store; defaults to the ``hash`` entry of the document metadata.
        :return: the key under which the store was registered in ``embeddings_dict``.
        :raises ValueError: if no text could be extracted from the document.
        """
        texts, metadata, ids = self.get_text_from_document(pdf_path, chunk_size=500, perc_overlap=0.1)
        if not texts:
            raise ValueError(f"No text could be extracted from {pdf_path}")

        doc_hash = doc_id if doc_id else metadata[0]['hash']

        self.embeddings_dict[doc_hash] = Chroma.from_texts(texts,
                                                           embedding=self.embedding_function,
                                                           metadatas=metadata)
        # Original behaviour kept: once an in-memory store is created the engine drops its
        # persistence root, so create_embeddings() will refuse to run afterwards.
        self.embeddings_root_path = None

        return doc_hash

    def create_embeddings(self, pdfs_dir_path: Path):
        """Build and persist one Chroma store per PDF found (recursively) under ``pdfs_dir_path``.

        Each store is written to ``<embeddings_root_path>/<MD5 of the PDF>``; PDFs whose store
        directory already exists, or that yield no text, are skipped. A
        ``<filename>.storage_filename`` marker is written inside each store directory so that
        ``load_embeddings`` can map filenames to store ids. The new stores are not added to
        ``embeddings_dict``.

        :raises ValueError: if the engine has no ``embeddings_root_path``.
        """
        if self.embeddings_root_path is None:
            raise ValueError("embeddings_root_path is not set: cannot persist embeddings")

        input_files = []
        for root, _, files in os.walk(pdfs_dir_path, followlinks=False):
            for file_ in files:
                if file_.lower().endswith(".pdf"):
                    input_files.append(os.path.join(root, file_))

        for input_file in tqdm(input_files, total=len(input_files), unit='document',
                               desc="Grobid + embeddings processing"):

            md5 = self.calculate_md5(input_file)
            data_path = os.path.join(self.embeddings_root_path, md5)

            if os.path.exists(data_path):
                print(data_path, "exists. Skipping it ")
                continue

            texts, metadata, ids = self.get_text_from_document(input_file, chunk_size=500, perc_overlap=0.1)
            if not texts:
                print(input_file, "produced no text. Skipping it ")
                continue
            filename = metadata[0]['filename']

            vector_db_document = Chroma.from_texts(texts,
                                                   metadatas=metadata,
                                                   embedding=self.embedding_function,
                                                   persist_directory=data_path)
            vector_db_document.persist()

            with open(os.path.join(data_path, filename + ".storage_filename"), 'w') as fo:
                fo.write("")

    @staticmethod
    def calculate_md5(input_file: Union[Path, str]):
        """Return the upper-case hexadecimal MD5 digest of ``input_file``'s contents.

        ``create_embeddings`` uses it as the document id / store directory name.
        """
        import hashlib
        md5_hash = hashlib.md5()
        with open(input_file, 'rb') as fi:
            # Stream in 1 MiB chunks so large PDFs are not loaded into memory at once.
            for chunk in iter(lambda: fi.read(1024 * 1024), b""):
                md5_hash.update(chunk)
        return md5_hash.hexdigest().upper()