|
from urllib.parse import quote

import gradio as gr
import requests
from langchain.embeddings import SentenceTransformerEmbeddings
from langchain.vectorstores import FAISS
from langchain_community.chat_models.huggingface import ChatHuggingFace
from langchain.schema import SystemMessage, HumanMessage, AIMessage
from langchain_community.llms import HuggingFaceEndpoint
|
|
|
# Sentence-transformers model used to embed incoming queries.
# NOTE(review): presumably the same model that built the saved index — confirm,
# since a mismatched embedding model makes similarity search meaningless.
model_name = "sentence-transformers/all-mpnet-base-v2"
embedding_llm = SentenceTransformerEmbeddings(model_name=model_name)

# Load the prebuilt FAISS index from the local "faiss_index" directory.
# allow_dangerous_deserialization=True permits LangChain to unpickle the stored
# docstore; only ever point this at an index produced by a trusted source.
db = FAISS.load_local(
    folder_path="faiss_index",
    embeddings=embedding_llm,
    allow_dangerous_deserialization=True,
)
|
|
|
|
|
# Sampling/generation settings forwarded to the hosted text-generation endpoint.
_generation_settings = {
    "max_new_tokens": 4096,
    "temperature": 0.6,
    "top_p": 0.9,
    "top_k": 40,
    "repetition_penalty": 1.2,
    "do_sample": True,
}

# Remote StarChat2 model served through the Hugging Face inference endpoint.
# NOTE(review): presumably requires a Hugging Face API token in the
# environment — confirm the deployment provides one.
llm = HuggingFaceEndpoint(
    repo_id="HuggingFaceH4/starchat2-15b-v0.1",
    task="text-generation",
    **_generation_settings,
)

# Chat-model wrapper around `llm`; the handlers invoke it with lists of
# LangChain message objects.
chat_model = ChatHuggingFace(llm=llm)
|
|
|
# Instruction given to the assistant at the head of every conversation.
_SYSTEM_PROMPT = "You are a helpful assistant."

# Seed conversation: system prompt plus one canned greeting exchange.
# This is a module-level list that the chat handlers extend and trim, so it is
# shared by every request served by this process (not per user/session).
messages = [
    SystemMessage(content=_SYSTEM_PROMPT),
    HumanMessage(content="Hi AI, how are you today?"),
    AIMessage(content="I'm great thank you. How can I help you?"),
]
|
|
|
def handle_message(message: str, mode: str):
    """Dispatch one UI submission to the handler for the chosen mode.

    Args:
        message: Text typed by the user.
        mode: Radio selection; one of "Chat-Message", "Web-Search" or
            "Chart-Generator".

    Returns:
        A ``(text, image)`` pair for the two interface outputs. ``image`` is
        ``None`` unless the chart handler supplies one.
    """
    # Guard clause: whitespace-only input never reaches a handler.
    if not message.strip():
        return "Enter a valid message.", None

    if mode == "Chat-Message":
        return chat_message(message), None
    if mode == "Web-Search":
        return web_search(message), None
    if mode == "Chart-Generator":
        text, image = chart_generator(message)
        return text, image

    # Unrecognised or missing selection (e.g. nothing picked in the radio).
    return "Select a valid mode.", None
|
|
|
# Number of most recent conversational (non-system) messages retained in the
# shared history. Must be a positive even number: turns are stored in
# (user, assistant) pairs, so an even window always starts on a user turn.
_MAX_HISTORY_MESSAGES = 6

# Seconds to wait for quickchart.io before treating chart generation as failed.
_CHART_TIMEOUT_SECONDS = 10

# Text returned when the chart service rejects the request or is unreachable.
_CHART_ERROR = "Can't generate this image. Please provide valid chart details."


def _trim_history(history: list) -> list:
    """Return *history* bounded to the most recent conversational messages.

    Leading ``SystemMessage`` entries are always kept so the assistant's
    instructions survive trimming; only the messages after them are cut down
    to the last ``_MAX_HISTORY_MESSAGES``.
    """
    split = 0
    while split < len(history) and isinstance(history[split], SystemMessage):
        split += 1
    system, conversation = history[:split], history[split:]
    return system + conversation[-_MAX_HISTORY_MESSAGES:]


def _ask_with_history(content: str) -> str:
    """Send *content* to ``chat_model`` together with the shared history.

    On success, the user turn and the assistant reply are recorded in the
    module-level ``messages`` list (then trimmed). Nothing is recorded when
    the model call raises, so a failed request leaves no dangling user turn.

    Returns:
        The text content of the model's reply.
    """
    global messages

    prompt = HumanMessage(content=content)
    response = chat_model.invoke(messages + [prompt])
    # Store the reply as an AIMessage: LangChain coerces bare strings in a
    # message list into HumanMessage, which would replay the assistant's
    # answers as user turns on the next call.
    messages = _trim_history(
        messages + [prompt, AIMessage(content=response.content)]
    )
    return response.content


def chat_message(message: str):
    """Answer *message* as a plain chat turn using the shared history.

    Returns:
        The model reply prefixed with ``"IT-Assistant: "``.
    """
    return f"IT-Assistant: {_ask_with_history(message)}"


def web_search(message: str):
    """Answer *message* grounded in the top matches from the local FAISS index.

    Despite the mode name, no web request is made: the three most similar
    documents in ``db`` are placed in the prompt alongside the query, and the
    model is told to say 'No Answer Is Available' when they don't cover it.

    Returns:
        The model reply prefixed with ``"IT-Assistant: "``.
    """
    similar_docs = db.similarity_search(message, k=3)
    # Joining an empty result list yields "", so no special case is needed.
    source_knowledge = "\n".join(doc.page_content for doc in similar_docs)

    augmented_prompt = (
        "\n"
        "If the answer to the next query is not contained in the Search, "
        "say 'No Answer Is Available' and then just give guidance for the query.\n"
        f"Query: {message}\n"
        "Search:\n"
        f"{source_knowledge}\n"
    )
    return f"IT-Assistant: {_ask_with_history(augmented_prompt)}"


def chart_generator(message: str):
    """Render a chart for *message* via quickchart.io and ask the model about it.

    Returns:
        ``("IT-Assistant: <reply>", chart_url)`` on success, or
        ``(error_text, None)`` when the chart service returns a non-200
        status or cannot be reached within the timeout.
    """
    # The description is placed in the URL path, so it must be percent-encoded:
    # spaces, '/', '?' and '#' would otherwise corrupt or truncate the request.
    chart_url = f"https://quickchart.io/natural/{quote(message, safe='')}"

    try:
        chart_response = requests.get(chart_url, timeout=_CHART_TIMEOUT_SECONDS)
    except requests.RequestException:
        return _CHART_ERROR, None
    if chart_response.status_code != 200:
        return _CHART_ERROR, None

    # NOTE(review): the model only receives the user's text, not the rendered
    # image, so its "analysis" describes the request rather than the chart.
    description = _ask_with_history(
        f"Describe and analyse the content of this chart: {message}"
    )
    return f"IT-Assistant: {description}", chart_url
|
|
|
# Mode selector; its choice strings must match the branches in handle_message.
_mode_selector = gr.Radio(
    ["Chat-Message", "Web-Search", "Chart-Generator"],
    label="mode",
    info="Choose a mode and enter your message, then click submit to interact.",
)
# Output widgets, in the order of handle_message's (text, image) return value.
_response_box = gr.Textbox(label="Response")
_chart_view = gr.Image(label="Chart", type="filepath")

demo = gr.Interface(
    fn=handle_message,
    inputs=["text", _mode_selector],
    outputs=[_response_box, _chart_view],
    title="IT Assistant",
)

# Start the web UI as soon as the module runs.
demo.launch()