# ChatHaruhi / app.py
# Header scraped from the Hugging Face Space file viewer:
# last commit f723f88 by BlairLeng ("change key:"), file size 15.7 kB.
import os
# os.environ["CUDA_VISIBLE_DEVICES"] = "0" # 指定要使用的GPU设备编号
from transformers import pipeline
import argparse
import openai
import tiktoken
import torch
from scipy.spatial.distance import cosine
from transformers import AutoModel, AutoTokenizer
from argparse import Namespace
from langchain.chat_models import ChatOpenAI
import gradio as gr
import random
import time
from langchain.prompts.chat import (
ChatPromptTemplate,
SystemMessagePromptTemplate,
AIMessagePromptTemplate,
HumanMessagePromptTemplate,
)
from langchain.schema import (
AIMessage,
HumanMessage,
SystemMessage
)
from text import Text
def download_models():
    """Load the ``silk-road/luotuo-bert`` sentence-embedding model.

    ``AutoModel.from_pretrained`` fetches the weights automatically when they
    are not cached locally. ``trust_remote_code=True`` lets the repository's
    custom model class run; it receives the ``model_args`` namespace below
    (CLS pooling, temperature 0.05).
    """
    pretrained_args = Namespace(
        do_mlm=None,
        pooler_type="cls",
        temp=0.05,
        mlp_only_train=False,
        init_embeddings_model=None,
    )
    return AutoModel.from_pretrained(
        "silk-road/luotuo-bert",
        trust_remote_code=True,
        model_args=pretrained_args,
    )
# The OpenAI API key must come from the environment, never from source control.
# NOTE(review): the key that used to be hard-coded here (split into two string
# fragments) is exposed in the repository history and should be revoked.
openai.api_key = os.environ.get("OPENAI_API_KEY") or os.environ.get("OPENAI_API_KEY2")
if openai.api_key:
    # ChatOpenAI in Run.get_response is built without an explicit key, so it
    # presumably relies on this environment variable.
    os.environ["OPENAI_API_KEY"] = openai.api_key
else:
    # os.environ only accepts strings; assigning None would raise TypeError.
    print("warning: OPENAI_API_KEY is not set; OpenAI requests will fail")
# Directory (under the current working directory) where conversations are logged.
folder_name = "Suzumiya"
current_directory = os.getcwd()
new_directory = os.path.join(current_directory, folder_name)

# Cached pickles and resources consumed by the Text helper.
pkl_path = './pkl/texts.pkl'
text_image_pkl_path = './pkl/text_image.pkl'
dict_path = "characters/haruhi/text_image_dict.txt"
dict_text_pkl_path = './pkl/dict_text.pkl'
image_path = "characters/haruhi/images"

model = download_models()
text = Text(
    "characters/haruhi/texts",
    text_image_pkl_path=text_image_pkl_path,
    dict_text_pkl_path=dict_text_pkl_path,
    model=model,
    num_steps=50,
    pkl_path=pkl_path,
    dict_path=dict_path,
    image_path=image_path,
)

if os.path.exists(new_directory):
    print(f"文件夹 '{folder_name}' 已经存在。")
else:
    os.makedirs(new_directory)
    print(f"文件夹 '{folder_name}' 创建成功!")

# tiktoken encoder used to measure story/history lengths in tokens.
enc = tiktoken.get_encoding("cl100k_base")
class Run:
    """Retrieval-augmented role-play chat bot behind the ChatHaruhi demo.

    On construction it reads the system prompt and the story snippets from
    disk and embeds every paragraph of every snippet. For each user turn it
    retrieves the most similar snippets, packs them into the prompt under a
    token budget, trims the chat history to a token budget, and queries the
    OpenAI chat model through langchain.

    Keyword parameters:
        folder: directory containing ``*.txt`` story snippets.
        system_prompt: path of a text file holding the system prompt.
        max_len_story: token budget for the retrieved snippets.
        max_len_history: token budget for the kept chat history.
        save_path: directory where conversation logs are written.
    """

    def __init__(self, **params):
        # Original design notes (translated):
        # * command-line argument plumbing
        # * a folder of dialogue lines that records the lines
        # * the system prompt is stored as a txt file so it can be swapped
        # * configurable max_len_story and max_len_history
        # * configurable save_path
        # * a colab script that clones the converted project and runs it,
        #   so other users can try it easily
        self.folder = params['folder']
        with open(params['system_prompt'], 'r', encoding='utf-8') as f:
            self.system_prompt = f.read()
        self.max_len_story = params['max_len_story']
        self.max_len_history = params['max_len_history']
        self.save_path = params['save_path']
        self.titles, self.title_to_text = self.read_prompt_data()
        # Embed every snippet paragraph once; get_response reuses these.
        self.embeddings, self.embed_to_title = self.title_text_embedding(self.titles, self.title_to_text)

    def get_completion_from_messages(self, messages, model="gpt-3.5-turbo", temperature=0):
        """Send ``messages`` to the OpenAI chat-completion API and return the reply text.

        ``temperature`` controls how random the model's output is.
        """
        response = openai.ChatCompletion.create(
            model=model,
            messages=messages,
            temperature=temperature,
        )
        return response.choices[0].message["content"]

    def read_prompt_data(self):
        """Read the in-context-learning snippets from ``self.folder``.

        Returns:
            ``(titles, title_to_text)``: the ``.txt`` file names without the
            extension (in ``os.listdir`` order) and a dict mapping each title
            to its file's contents.
        """
        titles = []
        title_to_text = {}
        for file in os.listdir(self.folder):
            if file.endswith('.txt'):
                title_name = file[:-4]
                titles.append(title_name)
                # The snippets are Chinese text; do not rely on the locale's default encoding.
                with open(os.path.join(self.folder, file), 'r', encoding='utf-8') as f:
                    title_to_text[title_name] = f.read()
        return titles, title_to_text

    def get_embedding(self, text):
        """Return the luotuo-bert sentence embedding (``pooler_output[0]``) of ``text``.

        The tokenizer and model are loaded on first use and cached on the
        instance. They used to be reloaded on every call, and this method runs
        once per snippet paragraph at start-up and once per chat turn.
        """
        tokenizer = getattr(self, '_tokenizer', None)
        if tokenizer is None:
            tokenizer = self._tokenizer = AutoTokenizer.from_pretrained("silk-road/luotuo-bert")
        model = getattr(self, '_embed_model', None)
        if model is None:
            model = self._embed_model = download_models()
        text = text[:512]
        # NOTE(review): 512 is assumed to be the backbone's maximum sequence
        # length. Special tokens can push 512 characters past it, so the
        # tokenizer truncates as well.
        inputs = tokenizer([text], padding=True, truncation=True, max_length=512, return_tensors="pt")
        with torch.no_grad():
            embeddings = model(**inputs, output_hidden_states=True, return_dict=True, sent_emb=True).pooler_output
        return embeddings[0]

    def title_text_embedding(self, titles, title_to_text):
        """Embed each blank-line-separated paragraph of every title's text.

        Returns:
            ``(embeddings, embed_to_title)``: parallel lists where
            ``embed_to_title[i]`` is the title that ``embeddings[i]`` came from.
        """
        embeddings = []
        embed_to_title = []
        for title in titles:
            text = title_to_text[title]
            for divided_text in text.split('\n\n'):
                embeddings.append(self.get_embedding(divided_text))
                embed_to_title.append(title)
        return embeddings, embed_to_title

    def get_cosine_similarity(self, embed1, embed2):
        """Cosine similarity of two 1-D tensors, as a 0-dim tensor."""
        return torch.nn.functional.cosine_similarity(embed1, embed2, dim=0)

    def retrieve_title(self, query_embed, embeddings, embed_to_title, k):
        """Return up to ``k`` distinct titles whose paragraphs best match ``query_embed``.

        ``embeddings[i]`` belongs to title ``embed_to_title[i]``. Paragraphs are
        ranked by cosine similarity, highest first (ties keep their original
        order). Each title is taken at its best-ranked paragraph.
        """
        if k <= 0:
            return []
        similarities = [float(self.get_cosine_similarity(query_embed, embed)) for embed in embeddings]
        # Rank the indices directly. Looking scores back up with list.index()
        # was O(n^2) and always hit the first of several equal scores, which
        # hid the titles of the other tied paragraphs.
        ranked = sorted(range(len(similarities)), key=similarities.__getitem__, reverse=True)
        top_k_title = []
        for idx in ranked:
            title = embed_to_title[idx]
            if title not in top_k_title:
                top_k_title.append(title)
                if len(top_k_title) == k:
                    break
        return top_k_title

    def organize_story_with_maxlen(self, selected_sample):
        """Concatenate the selected snippets, in order, under ``max_len_story`` tokens.

        Stops at the first snippet that would push the tiktoken count past the
        budget. The header line is not counted.

        Returns:
            ``(story, final_selected)``: the prompt text and the titles actually used.
        """
        maxlen = self.max_len_story
        # Header: "Classic Haruhi Suzumiya scenes are as follows:"
        story = "凉宫春日的经典桥段如下:\n"
        count = 0
        final_selected = []
        print(selected_sample)
        for sample_topic in selected_sample:
            sample_story = self.title_to_text[sample_topic]
            sample_len = len(enc.encode(sample_story))
            if sample_len + count > maxlen:
                break
            story += sample_story
            story += '\n'
            count += sample_len
            final_selected.append(sample_topic)
        return story, final_selected

    def organize_message(self, story, history_chat, history_response, new_query):
        """Build an OpenAI-style message list.

        The list is: system prompt, story, alternating user/assistant history
        turns, then the new query. If the two history lists differ in length,
        the history is discarded.
        """
        messages = [{'role': 'system', 'content': self.system_prompt}, {'role': 'user', 'content': story}]
        n = len(history_chat)
        if n != len(history_response):
            print('warning, unmatched history_char length, clean and start new chat')
            history_chat = []
            history_response = []
            n = 0
        for i in range(n):
            messages.append({'role': 'user', 'content': history_chat[i]})
            # Past bot replies must be tagged as the assistant's, not the user's.
            messages.append({'role': 'assistant', 'content': history_response[i]})
        messages.append({'role': 'user', 'content': new_query})
        return messages

    def keep_tail(self, history_chat, history_response):
        """Keep the most recent turns that fit in ``max_len_history`` tokens.

        A turn's size is the token count of the user message plus the reply.
        The latest turn is always kept, even if it alone exceeds the budget.
        Empty or mismatched-length histories yield ``([], [])``.
        """
        max_len = self.max_len_history
        n = len(history_chat)
        if n == 0:
            return [], []
        if n != len(history_response):
            print('warning, unmatched history_char length, clean and start new chat')
            return [], []
        token_len = []
        for i in range(n):
            chat_len = len(enc.encode(history_chat[i]))
            res_len = len(enc.encode(history_response[i]))
            token_len.append(chat_len + res_len)
        keep_k = 1
        count = token_len[n - 1]
        for i in range(1, n):
            count += token_len[n - 1 - i]
            if count > max_len:
                break
            keep_k += 1
        return history_chat[-keep_k:], history_response[-keep_k:]

    def organize_message_langchain(self, story, history_chat, history_response, new_query):
        """Same as ``organize_message`` but returns langchain message objects."""
        messages = [
            SystemMessage(content=self.system_prompt),
            HumanMessage(content=story)
        ]
        n = len(history_chat)
        if n != len(history_response):
            print('warning, unmatched history_char length, clean and start new chat')
            history_chat = []
            history_response = []
            n = 0
        for i in range(n):
            messages.append(HumanMessage(content=history_chat[i]))
            messages.append(AIMessage(content=history_response[i]))
        messages.append(HumanMessage(content=new_query))
        return messages

    def get_response(self, user_message, chat_history_tuple):
        """Answer ``user_message`` in character and return the reply text.

        ``chat_history_tuple`` is an iterable of ``(user, bot)`` pairs, as kept
        by the Gradio Chatbot. The 7 best-matching snippets are retrieved and
        trimmed to the story budget before calling ChatOpenAI.
        """
        history_chat = []
        history_response = []
        for cha, res in chat_history_tuple:
            history_chat.append(cha)
            history_response.append(res)
        history_chat, history_response = self.keep_tail(history_chat, history_response)
        print('history done')
        new_query = user_message
        query_embed = self.get_embedding(new_query)
        print("2")
        selected_sample = self.retrieve_title(query_embed, self.embeddings, self.embed_to_title, 7)
        print("3")
        story, selected_sample = self.organize_story_with_maxlen(selected_sample)
        # TODO: visualize the selected samples later
        print('当前辅助sample:', selected_sample)
        messages = self.organize_message_langchain(story, history_chat, history_response, new_query)
        chat = ChatOpenAI(temperature=0)
        return_msg = chat(messages)
        return return_msg.content

    def save_response(self, chat_history_tuple):
        """Write the whole conversation to ``<save_path>/conversation_<unix time>.txt``.

        Each user message and each reply is followed by a separator line ``---``.
        """
        # A custom --save_path may not exist yet; this file only creates ./Suzumiya at start-up.
        os.makedirs(self.save_path, exist_ok=True)
        with open(f"{self.save_path}/conversation_{time.time()}.txt", "w", encoding='utf-8') as file:
            for cha, res in chat_history_tuple:
                file.write(cha)
                file.write("\n---\n")
                file.write(res)
                file.write("\n---\n")

    def create_gradio(self):
        """Build and launch the Gradio UI (chat box, role name, image button)."""
        with gr.Blocks() as demo:
            gr.Markdown(
                """
                ## Chat凉宫春日 ChatHaruhi
                项目地址 [https://github.com/LC1332/Chat-Haruhi-Suzumiya](https://github.com/LC1332/Chat-Haruhi-Suzumiya)
                骆驼项目地址 [https://github.com/LC1332/Luotuo-Chinese-LLM](https://github.com/LC1332/Luotuo-Chinese-LLM)
                此版本为图文版本,完整功能(+语音)的demo见项目
                角色名建议输入 阿虚 或者影视剧中有的人物。或者也可以是新学生或者老师。
                """
            )
            # Hidden box that receives the latest bot reply; the image button reads it.
            image_input = gr.Textbox(visible=False)
            with gr.Row():
                chatbot = gr.Chatbot()
                image_output = gr.Image()
            role_name = gr.Textbox(label="角色名", placeholder="输入角色名")
            msg = gr.Textbox(label="输入")
            with gr.Row():
                clear = gr.Button("Clear")
                sub = gr.Button("Submit")
                image_button = gr.Button("给我一个图")

            def respond(role_name, user_message, chat_history):
                # Format the turn as  role:「message」 , answer it, and log the conversation.
                input_message = role_name + ':「' + user_message + '」'
                bot_message = self.get_response(input_message, chat_history)
                chat_history.append((input_message, bot_message))
                self.save_response(chat_history)
                return "", chat_history, bot_message

            clear.click(lambda: None, None, chatbot, queue=False)
            msg.submit(respond, [role_name, msg, chatbot], [msg, chatbot, image_input])
            sub.click(fn=respond, inputs=[role_name, msg, chatbot], outputs=[msg, chatbot, image_input])
            image_button.click(text.text_to_image, inputs=image_input, outputs=image_output)
        demo.launch(debug=True)
if __name__ == '__main__':
    # Command-line entry point: parse options, build the bot, start the Gradio UI.
    parser = argparse.ArgumentParser(description="-----[Chat凉宫春日]-----")
    parser.add_argument("--folder", default="characters/haruhi/texts", help="text folder")
    parser.add_argument("--system_prompt", default="characters/haruhi/system_prompt.txt", help="store system_prompt")
    parser.add_argument("--max_len_story", default=1500, type=int)
    parser.add_argument("--max_len_history", default=1200, type=int)
    parser.add_argument("--save_path", default=f"{os.getcwd()}/Suzumiya")
    options = parser.parse_args()
    # Forward exactly these parsed options to Run as keyword arguments.
    params = {
        option_name: getattr(options, option_name)
        for option_name in ("folder", "system_prompt", "max_len_story", "max_len_history", "save_path")
    }
    run = Run(**params)
    run.create_gradio()