"""Adopted from https://github.com/imartinez/privateGPT/blob/main/privateGPT.py | |
https://raw.githubusercontent.com/imartinez/privateGPT/main/requirements.txt | |
from pathlib import Path | |
Path("models").mkdir(exit_ok=True) | |
!time wget -c https://gpt4all.io/models/ggml-gpt4all-j-v1.3-groovy.bin -O models/ggml-gpt4all-j-v1.3-groovy.bin""" | |
import os
import time
from types import SimpleNamespace

from chromadb.config import Settings
from dotenv import load_dotenv  # only needed for the commented-out .env flow below
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.chains import RetrievalQA
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.llms import GPT4All, LlamaCpp
from langchain.vectorstores import Chroma
# Original .env-based configuration, kept for reference (load_dotenv() must run
# before any os.environ.get() call):
# load_dotenv()
# embeddings_model_name = os.environ.get('EMBEDDINGS_MODEL_NAME')
# persist_directory = os.environ.get('PERSIST_DIRECTORY')
# model_type = os.environ.get('MODEL_TYPE')
# model_path = os.environ.get('MODEL_PATH')
# model_n_ctx = os.environ.get('MODEL_N_CTX')
# model_n_batch = int(os.environ.get('MODEL_N_BATCH', 8))
# target_source_chunks = int(os.environ.get('TARGET_SOURCE_CHUNKS', 4))
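# Hard-coded settings replacing the .env flow above; the keys mirror those in
# privateGPT's example.env: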
settings = dict([('PERSIST_DIRECTORY', 'db1'), | |
('MODEL_TYPE', 'GPT4All'), | |
('MODEL_PATH', 'models/ggml-gpt4all-j-v1.3-groovy.bin'), | |
('EMBEDDINGS_MODEL_NAME', 'all-MiniLM-L6-v2'), | |
('MODEL_N_CTX', '1000'), | |
('MODEL_N_BATCH', '8'), | |
('TARGET_SOURCE_CHUNKS', '4')]) | |
# models/ggml-gpt4all-j-v1.3-groovy.bin ~5G | |
persist_directory = settings.get('PERSIST_DIRECTORY')
model_type = settings.get('MODEL_TYPE')
model_path = settings.get('MODEL_PATH')
embeddings_model_name = settings.get('EMBEDDINGS_MODEL_NAME')
# embeddings_model_name = 'paraphrase-multilingual-mpnet-base-v2'  # multilingual alternative
model_n_ctx = int(settings.get('MODEL_N_CTX', 1000))
model_n_batch = int(settings.get('MODEL_N_BATCH', 8))
target_source_chunks = int(settings.get('TARGET_SOURCE_CHUNKS', 4))
# Chroma settings: DuckDB+Parquet backend, persisted under db1, telemetry off
CHROMA_SETTINGS = Settings(
    chroma_db_impl='duckdb+parquet',
    persist_directory=persist_directory,
    anonymized_telemetry=False
)
# Stand-in for the original script's argparse CLI flags
args = SimpleNamespace(hide_source=False, mute_stream=False)
# Load the persisted Chroma database from db1
embeddings = HuggingFaceEmbeddings(model_name=embeddings_model_name)
db = Chroma(persist_directory=persist_directory, embedding_function=embeddings, client_settings=CHROMA_SETTINGS)
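# NOTE: this script only queries an existing index; the db1 directory must
# already hold a vector store built beforehand (privateGPT builds it with
# ingest.py), using the same embeddings model and Chroma settings.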
retriever = db.as_retriever(search_kwargs={"k": target_source_chunks})
# Activate/deactivate the streaming StdOut callback for LLMs
callbacks = [] if args.mute_stream else [StreamingStdOutCallbackHandler()]
# Prepare the LLM
match model_type:
    case "LlamaCpp":
        llm = LlamaCpp(model_path=model_path, n_ctx=model_n_ctx, n_batch=model_n_batch, callbacks=callbacks, verbose=False)
    case "GPT4All":
        llm = GPT4All(model=model_path, n_ctx=model_n_ctx, backend='gptj', n_batch=model_n_batch, callbacks=callbacks, verbose=False)
    case _:
        raise ValueError(f"Model type {model_type} is not supported. Please choose one of the following: LlamaCpp, GPT4All")
# Loading the model and chain needs about 5 GB of RAM
qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever, return_source_documents=not args.hide_source)
# Get the answer from the chain
query = "共产党是什么"  # "What is the Communist Party?"
start = time.time()
res = qa(query)
answer, docs = res['result'], [] if args.hide_source else res['source_documents']
end = time.time()

# Print the result
print("\n\n> Question:")
print(query)
print(f"\n> Answer (took {round(end - start, 2)} s.):")
print(answer)
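
# Print the retrieved source chunks, as in the original privateGPT script
# (docs is empty when hide_source is set)
for document in docs:
    print("\n> " + document.metadata["source"] + ":")
    print(document.page_content)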