"""Adopted from https://github.com/imartinez/privateGPT/blob/main/privateGPT.py https://raw.githubusercontent.com/imartinez/privateGPT/main/requirements.txt from pathlib import Path Path("models").mkdir(exit_ok=True) !time wget -c https://gpt4all.io/models/ggml-gpt4all-j-v1.3-groovy.bin -O models/ggml-gpt4all-j-v1.3-groovy.bin""" from dotenv import load_dotenv, dotenv_values from langchain.chains import RetrievalQA from langchain.embeddings import HuggingFaceEmbeddings from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler from langchain.vectorstores import Chroma from langchain.llms import GPT4All, LlamaCpp import os import argparse import time from types import SimpleNamespace from chromadb.config import Settings # embeddings_model_name = os.environ.get("EMBEDDINGS_MODEL_NAME") # persist_directory = os.environ.get('PERSIST_DIRECTORY') # load_dotenv() # model_type = os.environ.get('MODEL_TYPE') # model_path = os.environ.get('MODEL_PATH') # model_n_ctx = os.environ.get('MODEL_N_CTX') # model_n_batch = int(os.environ.get('MODEL_N_BATCH',8)) # target_source_chunks = int(os.environ.get('TARGET_SOURCE_CHUNKS',4)) settings = dict([('PERSIST_DIRECTORY', 'db1'), ('MODEL_TYPE', 'GPT4All'), ('MODEL_PATH', 'models/ggml-gpt4all-j-v1.3-groovy.bin'), ('EMBEDDINGS_MODEL_NAME', 'all-MiniLM-L6-v2'), ('MODEL_N_CTX', '1000'), ('MODEL_N_BATCH', '8'), ('TARGET_SOURCE_CHUNKS', '4')]) # models/ggml-gpt4all-j-v1.3-groovy.bin ~5G persist_directory = settings.get('PERSIST_DIRECTORY') model_type = settings.get('MODEL_TYPE') model_path = settings.get('MODEL_PATH') embeddings_model_name = settings.get("EMBEDDINGS_MODEL_NAME") # embeddings_model_name = 'all-MiniLM-L6-v2' # embeddings_model_name = 'paraphrase-multilingual-mpnet-base-v2' model_n_ctx = settings.get('MODEL_N_CTX') model_n_batch = int(settings.get('MODEL_N_BATCH',8)) target_source_chunks = int(settings.get('TARGET_SOURCE_CHUNKS',4)) # Define the Chroma settings CHROMA_SETTINGS = Settings( chroma_db_impl='duckdb+parquet', persist_directory=persist_directory, anonymized_telemetry=False ) args = SimpleNamespace(hide_source=False, mute_stream=False) # load chroma database from db1 embeddings = HuggingFaceEmbeddings(model_name=embeddings_model_name) db = Chroma(persist_directory=persist_directory, embedding_function=embeddings, client_settings=CHROMA_SETTINGS) retriever = db.as_retriever(search_kwargs={"k": target_source_chunks}) # activate/deactivate the streaming StdOut callback for LLMs callbacks = [] if args.mute_stream else [StreamingStdOutCallbackHandler()] # Prepare the LLM match model_type: case "LlamaCpp": llm = LlamaCpp(model_path=model_path, n_ctx=model_n_ctx, n_batch=model_n_batch, callbacks=callbacks, verbose=False) case "GPT4All": llm = GPT4All(model=model_path, n_ctx=model_n_ctx, backend='gptj', n_batch=model_n_batch, callbacks=callbacks, verbose=False) case _default: # raise exception if model_type is not supported raise Exception(f"Model type {model_type} is not supported. Please choose one of the following: LlamaCpp, GPT4All") # need about 5G RAM qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever, return_source_documents= not args.hide_source) # Get the answer from the chain query = "共产党是什么" start = time.time() res = qa(query) answer, docs = res['result'], [] if args.hide_source else res['source_documents'] end = time.time() # Print the result print("\n\n> Question:") print(query) print(f"\n> Answer (took {round(end - start, 2)} s.):") print(answer)