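"""Streamlit RAG app: retrieves context from a persisted Chroma vector store
(BAAI/bge-small-en embeddings) and answers questions with the
llmware/dragon-mistral-7b-v0 model served via the Hugging Face Hub."""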
from langchain.llms import HuggingFaceHub
from langchain.embeddings import HuggingFaceBgeEmbeddings
from langchain.vectorstores import Chroma
from langchain.chains import RetrievalQA
import textwrap
import torch
import os
import streamlit as st

# Run the embedding model on GPU when available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Load the persisted Chroma vector store and expose it as a retriever.
def load_vector_store():
    model_name = "BAAI/bge-small-en"
    model_kwargs = {"device": device}
    encode_kwargs = {"normalize_embeddings": True}
    embeddings = HuggingFaceBgeEmbeddings(
        model_name=model_name, model_kwargs=model_kwargs, encode_kwargs=encode_kwargs
    )
    print('Embeddings loaded!')
    vector_store = Chroma(persist_directory='vector stores/textdb', embedding_function=embeddings)
    print('Vector store loaded!')
    retriever = vector_store.as_retriever(
        search_kwargs={"k": 10},
    )
    return retriever
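# The app assumes the Chroma store at 'vector stores/textdb' was built ahead of
# time. A minimal, hypothetical ingestion sketch (the source file path, loader
# and chunk sizes below are assumptions, not part of this app) could look like:
#
#   from langchain.document_loaders import TextLoader
#   from langchain.text_splitter import RecursiveCharacterTextSplitter
#
#   embeddings = HuggingFaceBgeEmbeddings(model_name="BAAI/bge-small-en")
#   docs = TextLoader('data/my_document.txt').load()
#   chunks = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50).split_documents(docs)
#   Chroma.from_documents(chunks, embeddings, persist_directory='vector stores/textdb').persist()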
# Load the LLM from the Hugging Face Hub (requires HUGGINGFACEHUB_API_TOKEN).
def load_model():
    repo_id = 'llmware/dragon-mistral-7b-v0'
    llm = HuggingFaceHub(
        repo_id=repo_id,
        model_kwargs={'max_new_tokens': 100}
    )
    print(llm('HI!'))  # quick smoke test to confirm the endpoint responds
    return llm
# Build the RetrievalQA chain: retrieved chunks are "stuffed" into the prompt.
def qa_chain():
    retriever = load_vector_store()
    llm = load_model()
    qa = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type='stuff',
        retriever=retriever,
        return_source_documents=True,
        verbose=True
    )
    return qa
def wrap_text_preserve_newlines(text, width=110):
    # Split the input text into lines based on newline characters
    lines = text.split('\n')
    # Wrap each line individually
    wrapped_lines = [textwrap.fill(line, width=width) for line in lines]
    # Join the wrapped lines back together using newline characters
    wrapped_text = '\n'.join(wrapped_lines)
    return wrapped_text
# Print the wrapped answer followed by the source documents it was drawn from.
def process_llm_response(llm_response):
    print(wrap_text_preserve_newlines(llm_response['result']))
    print('\n\nSources:')
    for source in llm_response["source_documents"]:
        print(source.metadata['source'])
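# process_llm_response is not wired into the Streamlit UI below; it is handy
# for a quick command-line check, e.g. (hypothetical query):
#
#   qa = qa_chain()
#   process_llm_response(qa("<human>: What is this document about?\n<bot>:"))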
def main():
    qa = qa_chain()
    st.title('DOCUMENT-GPT')
    text_query = st.text_area('Ask any question from your documents!')
    generate_response_btn = st.button('Run RAG')
    st.subheader('Response')
    # st.text_area returns an empty string when nothing is typed, so check for
    # a non-empty query rather than None.
    if generate_response_btn and text_query:
        with st.spinner('Generating Response. Please wait...'):
            # dragon-mistral uses the "<human>: ... <bot>:" prompt format.
            text_response = qa("<human>: " + text_query + "\n" + "<bot>:")
            if text_response:
                st.write(text_response["result"])
            else:
                st.error('Failed to get response')
if __name__ == "__main__":
    # Ask for the Hugging Face token up front so HuggingFaceHub can authenticate.
    hf_token = st.text_input("Paste Huggingface read api key", type="password")
    if hf_token:
        os.environ["HUGGINGFACEHUB_API_TOKEN"] = hf_token
        main()
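# To try this out locally (assuming the file is saved as app.py and the Chroma
# store already exists at 'vector stores/textdb'):
#
#   pip install streamlit langchain chromadb sentence-transformers huggingface_hub torch
#   streamlit run app.py
#
# Then paste a Hugging Face read token into the text box and ask a question.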