mca183's picture
add everything
3dcda4d
raw
history blame contribute delete
No virus
2.04 kB
from langchain_community.document_loaders import HuggingFaceDatasetLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.text_splitter import RecursiveCharacterTextSplitter
from transformers import pipeline
import gradio as gr
import os
from dotenv import load_dotenv, find_dotenv
_ = load_dotenv(find_dotenv())  # read a local .env file, if one can be found
# Hugging Face access token used when loading the dataset below.
# Read with .get() so a missing HF_TOKEN yields None (anonymous access) rather
# than an unexplained KeyError at import time.
# NOTE(review): assumes the dataset does not require authentication when no
# token is set — confirm; a gated dataset would then fail inside the loader.
hf_api_key = os.environ.get("HF_TOKEN")
# Load the data: the "completion" column of each dataset row becomes the
# page content of a LangChain Document.
loader = HuggingFaceDatasetLoader(
    path="HuggingFaceH4/CodeAlpaca_20K",
    page_content_column="completion",
    use_auth_token=hf_api_key,
)
data = loader.load()

# Document Transformers: split the loaded documents into chunks of size 1000
# with an overlap of 150 between consecutive chunks.
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=1000,
    chunk_overlap=150,
)
docs = text_splitter.split_documents(documents=data)
# Text Embedding: a sentence-transformers model run on the CPU, with
# embedding normalization explicitly turned off.
embeddings = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-MiniLM-l6-v2",
    model_kwargs=dict(device="cpu"),
    encode_kwargs=dict(normalize_embeddings=False),
)

# Set up Vector Stores: index every chunk produced above in FAISS.
db = FAISS.from_documents(documents=docs, embedding=embeddings)

# Set up retrievers: each query asks the index for its 4 nearest chunks.
retriever = db.as_retriever(search_kwargs=dict(k=4))

# Seq2seq generator that turns a question plus retrieved context into an answer.
txt2txt_gen = pipeline(task="text2text-generation", model="Salesforce/codet5-base")
def generate(question):
    """Answer *question* with the text2text pipeline, using retrieved context.

    The first document returned by the module-level ``retriever`` supplies the
    context. The prompt has the form ``"question: <q> context: <ctx>"`` and is
    passed to the module-level ``txt2txt_gen`` pipeline.

    Args:
        question: The user's question text.

    Returns:
        The ``generated_text`` field of the pipeline's first output.
    """
    # Named `relevant_docs` so it does not shadow the module-level `docs` list.
    relevant_docs = retriever.get_relevant_documents(question)
    # Only the first retrieved chunk is used as context. An empty result falls
    # back to a context-free prompt instead of raising IndexError.
    context = relevant_docs[0].page_content if relevant_docs else ""
    # Named `prompt` rather than `input` to avoid shadowing the builtin.
    prompt = f"question: {question} context: {context}"
    output = txt2txt_gen(prompt)
    return output[0]["generated_text"]
def respond(message, chat_history):
    """Gradio callback: answer *message* and record the exchange.

    Args:
        message: Text submitted from the input textbox.
        chat_history: Current value of the Chatbot component, a list of
            ``(user_message, bot_message)`` pairs. ``None`` is treated as an
            empty history (NOTE(review): presumably what Gradio sends after
            the chat is cleared — confirm for the installed Gradio version).

    Returns:
        A ``("", history)`` tuple. The empty string clears the textbox and
        ``history`` is the new chat log for the Chatbot. The caller's list is
        copied, not mutated.
    """
    history = list(chat_history or [])
    # Skip blank submissions rather than running retrieval and generation on
    # them and appending an empty user turn.
    if not message or not message.strip():
        return "", history
    bot_message = generate(message)
    history.append((message, bot_message))
    return "", history
# Set up the chat interface: a chat log, an input box, a submit button and a
# button that empties both the input box and the chat log.
with gr.Blocks() as demo:
    chatbot = gr.Chatbot(height=240)  # small height, just to fit the notebook
    msg = gr.Textbox(label="Ask away")
    btn = gr.Button("Submit")
    clear = gr.ClearButton(components=[msg, chatbot], value="Clear console")

    # Clicking Submit and pressing Enter in the textbox run the same callback,
    # which reads (message, history) and writes back (cleared textbox, history).
    for trigger in (btn.click, msg.submit):
        trigger(respond, inputs=[msg, chatbot], outputs=[msg, chatbot])

demo.queue().launch()