# Source: multilingual-dokugpt / ggml-try.py
# Author: ffreemt — commit b140cfb ("Update gen_doc_chunks"), 3.69 kB
"""Adapted from https://github.com/imartinez/privateGPT/blob/main/privateGPT.py

Requirements: https://raw.githubusercontent.com/imartinez/privateGPT/main/requirements.txt

Model download (IPython/notebook cell):
from pathlib import Path
Path("models").mkdir(exist_ok=True)
!time wget -c https://gpt4all.io/models/ggml-gpt4all-j-v1.3-groovy.bin -O models/ggml-gpt4all-j-v1.3-groovy.bin"""
from dotenv import load_dotenv, dotenv_values
from langchain.chains import RetrievalQA
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.vectorstores import Chroma
from langchain.llms import GPT4All, LlamaCpp
import os
import argparse
import time
from types import SimpleNamespace
from chromadb.config import Settings
# Configuration is hard-coded here instead of being read from the environment
# / a .env file (load_dotenv + os.environ.get) as upstream privateGPT does.
# Values are kept as strings, mirroring environment variables.
settings = {
    'PERSIST_DIRECTORY': 'db1',  # directory of the persisted Chroma database
    'MODEL_TYPE': 'GPT4All',  # "GPT4All" or "LlamaCpp"
    'MODEL_PATH': 'models/ggml-gpt4all-j-v1.3-groovy.bin',
    'EMBEDDINGS_MODEL_NAME': 'all-MiniLM-L6-v2',
    'MODEL_N_CTX': '1000',
    'MODEL_N_BATCH': '8',
    'TARGET_SOURCE_CHUNKS': '4',
}
# models/ggml-gpt4all-j-v1.3-groovy.bin ~5G
persist_directory = settings.get('PERSIST_DIRECTORY')
model_type = settings.get('MODEL_TYPE')
model_path = settings.get('MODEL_PATH')
# Alternative (multilingual) embedding model: 'paraphrase-multilingual-mpnet-base-v2'
embeddings_model_name = settings.get('EMBEDDINGS_MODEL_NAME')
# Convert every numeric setting to int so the LLM constructors receive ints;
# MODEL_N_CTX was previously passed through as the string '1000'.
model_n_ctx = int(settings.get('MODEL_N_CTX', 1000))
model_n_batch = int(settings.get('MODEL_N_BATCH', 8))
target_source_chunks = int(settings.get('TARGET_SOURCE_CHUNKS', 4))
# Stand-in for command-line flags (argparse is imported but not used here).
args = SimpleNamespace(hide_source=False, mute_stream=False)

# Chroma client settings pointing at the persisted database directory.
# NOTE(review): chroma_db_impl is an older chromadb setting — confirm it
# matches the installed chromadb version.
CHROMA_SETTINGS = Settings(
    chroma_db_impl='duckdb+parquet',
    persist_directory=persist_directory,
    anonymized_telemetry=False,
)

# Open the persisted Chroma database with the configured embedding model and
# expose it as a retriever returning ``target_source_chunks`` documents.
embeddings = HuggingFaceEmbeddings(model_name=embeddings_model_name)
db = Chroma(
    persist_directory=persist_directory,
    embedding_function=embeddings,
    client_settings=CHROMA_SETTINGS,
)
retriever = db.as_retriever(search_kwargs={"k": target_source_chunks})

# Attach the streaming-stdout callback handler to the LLM unless muted.
if args.mute_stream:
    callbacks = []
else:
    callbacks = [StreamingStdOutCallbackHandler()]
# Prepare the LLM
match model_type:
case "LlamaCpp":
llm = LlamaCpp(model_path=model_path, n_ctx=model_n_ctx, n_batch=model_n_batch, callbacks=callbacks, verbose=False)
case "GPT4All":
llm = GPT4All(model=model_path, n_ctx=model_n_ctx, backend='gptj', n_batch=model_n_batch, callbacks=callbacks, verbose=False)
case _default:
# raise exception if model_type is not supported
raise Exception(f"Model type {model_type} is not supported. Please choose one of the following: LlamaCpp, GPT4All")
# need about 5G RAM
# Build the question-answering chain over the retriever; source documents
# are returned unless ``args.hide_source`` is set.
qa = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type="stuff",
    retriever=retriever,
    return_source_documents=not args.hide_source,
)

# Example query (Chinese). Time only the chain call, using a monotonic clock.
query = "共产党是什么"
start = time.perf_counter()
res = qa(query)
end = time.perf_counter()
answer = res['result']
docs = [] if args.hide_source else res['source_documents']

# Print the result
print("\n\n> Question:")
print(query)
print(f"\n> Answer (took {round(end - start, 2)} s.):")
print(answer)

# Print the source chunks behind the answer. They were previously collected
# but never shown, so ``hide_source`` had no visible effect.
for document in docs:
    print(f"\n> {document.metadata.get('source', '')}:")
    print(document.page_content)