# ypatel / run_localGPT.py
# Commit 80be8a0 (root): Add application file
# Size: 3.34 kB
from langchain.chains import RetrievalQA
# from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.vectorstores import Chroma
from langchain.embeddings import HuggingFaceInstructEmbeddings
from langchain.llms import HuggingFacePipeline
from constants import CHROMA_SETTINGS, PERSIST_DIRECTORY
from transformers import LlamaTokenizer, LlamaForCausalLM, pipeline
import click
from constants import CHROMA_SETTINGS
def load_model(model_id="TheBloke/vicuna-7B-1.1-HF"):
    """Build the local LLM used to generate natural-language answers.

    Loads a Llama tokenizer and causal-LM checkpoint, wraps them in a
    ``text-generation`` pipeline, and exposes that pipeline to LangChain
    through ``HuggingFacePipeline``.

    If you are running this for the first time, the model will be downloaded
    for you; subsequent runs will use the model from the disk.

    Args:
        model_id: Hugging Face model identifier passed to both
            ``from_pretrained`` calls. Defaults to the Vicuna 7B checkpoint
            this script was originally written for.

    Returns:
        The ``HuggingFacePipeline`` wrapping the configured pipeline.
    """
    tokenizer = LlamaTokenizer.from_pretrained(model_id)

    # Set these from_pretrained options if your GPU supports them:
    #   load_in_8bit=True, device_map='auto', torch_dtype=torch.float16,
    #   low_cpu_mem_usage=True
    model = LlamaForCausalLM.from_pretrained(model_id)

    # NOTE(review): do_sample is not set here; if the pipeline therefore
    # decodes greedily, temperature and top_p may have no effect - confirm.
    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_length=2048,
        temperature=0,
        top_p=0.95,
        repetition_penalty=1.15,
    )

    return HuggingFacePipeline(pipeline=pipe)
@click.command()
@click.option('--device_type', default='gpu', help='device to run on, select gpu or cpu')
def main(device_type):
    """Run an interactive question-answering session over the persisted store.

    Loads the instructor embeddings on the selected device, opens the Chroma
    vector store at ``PERSIST_DIRECTORY``, builds a "stuff" ``RetrievalQA``
    chain around the LLM returned by ``load_model()``, and then answers
    queries read from stdin until the user types ``exit`` or input ends.
    For each query the answer and the retrieved source documents are printed.

    Args:
        device_type: Value of the ``--device_type`` option. ``cpu`` (in any
            letter case, surrounding whitespace ignored) selects the CPU;
            every other value selects CUDA.
    """
    # Compare case-insensitively so "Cpu" or " cpu " is not silently sent to CUDA.
    requested = str(device_type).strip().lower()
    device = 'cpu' if requested == 'cpu' else 'cuda'
    print(f"Running on: {device}")

    # Load the instructor embeddings.
    embeddings = HuggingFaceInstructEmbeddings(
        model_name="hkunlp/instructor-xl",
        model_kwargs={"device": device},
    )

    # Load the vector store and expose it as a retriever.
    db = Chroma(
        persist_directory=PERSIST_DIRECTORY,
        embedding_function=embeddings,
        client_settings=CHROMA_SETTINGS,
    )
    retriever = db.as_retriever()

    # Load the LLM for generating natural-language responses.
    llm = load_model()
    qa = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=True,
    )

    separator = "----------------------------------SOURCE DOCUMENTS---------------------------"

    # Interactive questions and answers.
    while True:
        try:
            query = input("\nEnter a query: ")
        except EOFError:
            # End of input (Ctrl-D, or piped stdin running out) ends the
            # session the same way "exit" does.
            break

        query = query.strip()
        if query == "exit":
            break
        if not query:
            # Nothing to ask; avoid a full retrieval + generation pass.
            continue

        # Get the answer from the chain.
        res = qa(query)
        answer, docs = res['result'], res['source_documents']

        # Print the result.
        print("\n\n> Question:")
        print(query)
        print("\n> Answer:")
        print(answer)

        # Print the relevant sources used for the answer.
        print(separator)
        for document in docs:
            # A document without a "source" entry must not abort the session.
            source = document.metadata.get("source", "unknown")
            print(f"\n> {source}:")
            print(document.page_content)
        print(separator)
# Script entry point: invoke the click command only when this file is run
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()