import logging

import click
import torch
from auto_gptq import AutoGPTQForCausalLM
from huggingface_hub import hf_hub_download
from langchain.chains import RetrievalQA
from langchain.embeddings import HuggingFaceInstructEmbeddings
from langchain.llms import HuggingFacePipeline, LlamaCpp
from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate
from langchain.vectorstores import Chroma
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    GenerationConfig,
    LlamaForCausalLM,
    LlamaTokenizer,
    pipeline,
)

from constants import EMBEDDING_MODEL_NAME, PERSIST_DIRECTORY, MODEL_ID, MODEL_BASENAME


def load_model(device_type, model_id, model_basename=None):
    """
    Select a model for text generation using the HuggingFace library.
    If you are running this for the first time, it will download a model for you.
    Subsequent runs will use the model from disk.

    Args:
        device_type (str): Type of device to use, e.g., "cuda" for GPU or "cpu" for CPU.
        model_id (str): Identifier of the model to load from HuggingFace's model hub.
        model_basename (str, optional): Basename of the model if using quantized models.
            Defaults to None.

    Returns:
        HuggingFacePipeline: A pipeline object for text generation using the loaded model.

    Raises:
        ValueError: If an unsupported model or device type is provided.
    """
    logging.info(f"Loading Model: {model_id}, on: {device_type}")
    logging.info("This action can take a few minutes!")

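    # Three loading paths, chosen from the model files:
    #   * a ".ggml" basename is served through llama.cpp via LlamaCpp,
    #   * any other basename is treated as a GPTQ checkpoint via AutoGPTQForCausalLM,
    #   * no basename means a full-precision HF model (AutoModelForCausalLM on CUDA,
    #     LlamaForCausalLM/LlamaTokenizer otherwise).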
    if model_basename is not None:
        if ".ggml" in model_basename:
            logging.info("Using LlamaCpp for GGML quantized models")
            model_path = hf_hub_download(repo_id=model_id, filename=model_basename, resume_download=True)
            max_ctx_size = 2048
            kwargs = {
                "model_path": model_path,
                "n_ctx": max_ctx_size,
                "max_tokens": max_ctx_size,
            }
            # n_gpu_layers=1000 is far more layers than these models have, i.e. offload all layers to the GPU
            if device_type.lower() == "mps":
                kwargs["n_gpu_layers"] = 1000
            if device_type.lower() == "cuda":
                kwargs["n_gpu_layers"] = 1000
                kwargs["n_batch"] = max_ctx_size
            return LlamaCpp(**kwargs)

        else:
            logging.info("Using AutoGPTQForCausalLM for quantized models")

            if ".safetensors" in model_basename:
                # AutoGPTQ resolves the file from the bare basename, so strip the ".safetensors" extension
                model_basename = model_basename.replace(".safetensors", "")

            tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
            logging.info("Tokenizer loaded")

            model = AutoGPTQForCausalLM.from_quantized(
                model_id,
                model_basename=model_basename,
                use_safetensors=True,
                trust_remote_code=True,
                device="cuda:0",
                use_triton=False,
                quantize_config=None,
            )
    elif device_type.lower() == "cuda":
        logging.info("Using AutoModelForCausalLM for full models")
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        logging.info("Tokenizer loaded")

        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            device_map="auto",
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
            trust_remote_code=True,
        )
        model.tie_weights()
    else:
        logging.info("Using LlamaTokenizer")
        tokenizer = LlamaTokenizer.from_pretrained(model_id)
        model = LlamaForCausalLM.from_pretrained(model_id)

    # Load the model's generation configuration (sampling defaults, special tokens, etc.)
    generation_config = GenerationConfig.from_pretrained(model_id)

    # Wrap the model in a HF text-generation pipeline, then in a LangChain LLM
    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_length=2048,
        temperature=0,
        top_p=0.95,
        repetition_penalty=1.15,
        generation_config=generation_config,
    )

    local_llm = HuggingFacePipeline(pipeline=pipe)
    logging.info("Local LLM Loaded")

    return local_llm
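
# A minimal usage sketch (not executed here): load_model() can also be called directly,
# e.g. from a notebook, with the same constants the CLI below uses. The returned
# HuggingFacePipeline is a LangChain LLM, so calling it with a plain prompt string should work:
#
#   llm = load_model("cuda" if torch.cuda.is_available() else "cpu",
#                    model_id=MODEL_ID, model_basename=MODEL_BASENAME)
#   print(llm("What is retrieval-augmented generation?"))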


@click.command()
@click.option(
    "--device_type",
    default="cuda" if torch.cuda.is_available() else "cpu",
    type=click.Choice(
        [
            "cpu",
            "cuda",
            "ipu",
            "xpu",
            "mkldnn",
            "opengl",
            "opencl",
            "ideep",
            "hip",
            "ve",
            "fpga",
            "ort",
            "xla",
            "lazy",
            "vulkan",
            "mps",
            "meta",
            "hpu",
            "mtia",
        ],
    ),
    help="Device to run on. (Default is cuda if available, else cpu)",
)
@click.option(
    "--show_sources",
    "-s",
    is_flag=True,
    help="Show sources along with answers (Default is False)",
)
def main(device_type, show_sources):
    """
    Implements the information retrieval task.

    1. Loads an embedding model; this can be HuggingFaceInstructEmbeddings or HuggingFaceEmbeddings.
    2. Loads the existing vectorstore that was created by ingest.py.
    3. Loads the local LLM using the load_model function - you can now set different LLMs.
    4. Sets up the question-answering retrieval chain.
    5. Answers questions in an interactive loop.
    """
    logging.info(f"Running on: {device_type}")
    logging.info(f"Display Source Documents set to: {show_sources}")

    # The embedding model must match the one used to build the vectorstore in ingest.py
    embeddings = HuggingFaceInstructEmbeddings(model_name=EMBEDDING_MODEL_NAME, model_kwargs={"device": device_type})

    db = Chroma(
        persist_directory=PERSIST_DIRECTORY,
        embedding_function=embeddings,
    )
    retriever = db.as_retriever()

    template = """Use the following pieces of context to answer the question at the end. If you don't know the answer, \
just say that you don't know, don't try to make up an answer.

{context}

{history}
Question: {question}
Helpful Answer:"""

    prompt = PromptTemplate(input_variables=["history", "context", "question"], template=template)
    memory = ConversationBufferMemory(input_key="question", memory_key="history")
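    # At query time the "stuff" chain fills {context} with the retrieved chunks,
    # ConversationBufferMemory fills {history} with prior turns, and {question} is the user's input.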

    llm = load_model(device_type, model_id=MODEL_ID, model_basename=MODEL_BASENAME)

    qa = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=True,
        chain_type_kwargs={"prompt": prompt, "memory": memory},
    )
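    # The chain can also be invoked programmatically: res = qa("some question") returns a dict with
    # "result" and "source_documents", which is exactly what the loop below unpacks.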

    # Interactive question-answering loop; type "exit" to quit
    while True:
        query = input("\nEnter a query: ")
        if query == "exit":
            break

        # Get the answer (and its supporting documents) from the chain
        res = qa(query)
        answer, docs = res["result"], res["source_documents"]

        # Print the result
        print("\n\n> Question:")
        print(query)
        print("\n> Answer:")
        print(answer)

        if show_sources:
            # Print the source documents that were retrieved for this answer
            print("----------------------------------SOURCE DOCUMENTS---------------------------")
            for document in docs:
                print("\n> " + document.metadata["source"] + ":")
                print(document.page_content)
            print("----------------------------------SOURCE DOCUMENTS---------------------------")


if __name__ == "__main__":
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(filename)s:%(lineno)s - %(message)s", level=logging.INFO
    )
    main()