# ChatHaruhi / app_with_text_preload.py
# BlairLeng's picture
# changes
# 2e33d6a
# raw
# history blame
# 24.1 kB
import json
import os
import numpy as np
# os.environ['http_proxy'] = "http://127.0.0.1:1450"
# os.environ['https_proxy'] = "http://127.0.0.1:1450"
import argparse
import openai
import tiktoken
import torch
from scipy.spatial.distance import cosine
from langchain.chat_models import ChatOpenAI
import gradio as gr
import random
import time
import collections
import pickle
from argparse import Namespace
import torch
from PIL import Image
from torch import cosine_similarity
from transformers import AutoTokenizer, AutoModel
from langchain.prompts.chat import (
ChatPromptTemplate,
SystemMessagePromptTemplate,
AIMessagePromptTemplate,
HumanMessagePromptTemplate,
)
from langchain.schema import (
AIMessage,
HumanMessage,
SystemMessage
)
# --- Module-level configuration -------------------------------------------
# SECURITY: the API key is read from the environment instead of being
# committed to source control. Any key that was previously hard-coded in this
# file must be considered leaked and should be revoked.
# The proxy keeps its historical default but can be overridden with the
# OPENAI_PROXY environment variable.
openai.proxy = os.environ.get("OPENAI_PROXY", "http://127.0.0.1:7890")
openai.api_key = os.environ.get("OPENAI_API_KEY")
if not openai.api_key:
    print("warning: OPENAI_API_KEY is not set; OpenAI requests will fail")

# Default folder (inside the current working directory) for saved chats.
folder_name = "Suzumiya"
current_directory = os.getcwd()
new_directory = os.path.join(current_directory, folder_name)

# Embedding inference runs on CPU. To use a GPU when available, switch to
# torch.device("cuda" if torch.cuda.is_available() else "cpu").
device = torch.device("cpu")

if not os.path.exists(new_directory):
    os.makedirs(new_directory)
    print(f"文件夹 '{folder_name}' 创建成功!")
else:
    print(f"文件夹 '{folder_name}' 已经存在。")

# Tokenizer used only for counting tokens when trimming story / history.
enc = tiktoken.get_encoding("cl100k_base")
class Run:
    """Retrieval-augmented role-play chatbot with a Gradio front end.

    The class pre-computes sentence embeddings for every script line found in
    ``folder``, retrieves the stories closest to a user query, assembles a
    prompt within the configured token budgets and asks the chat model for an
    in-character reply. It can also pick an image whose caption is closest to
    a given text.

    NOTE: several artifacts are stored with ``pickle``. Only load pickle files
    that this program produced itself; unpickling untrusted files can execute
    arbitrary code.
    """

    def __init__(self, **params):
        """Store the configuration.

        Required keys: ``title_to_text_pkl_path``, ``text_image_pkl_path``,
        ``dict_text_pkl_path``, ``num_steps``, ``texts_pkl_path``,
        ``embeds_path``, ``embeds2_path``, ``dict_path``, ``image_path``,
        ``maps_pkl_path``, ``folder``, ``system_prompt``, ``max_len_story``,
        ``max_len_history`` and ``save_path``.

        ``system_prompt`` may be either the prompt text itself or the path of
        a text file holding it; an existing file is read (UTF-8).

        Original design notes (translated):
          * accept command-line parameters
          * a folder of script lines
          * system prompt stored as a txt file so it can be switched
          * configurable max_len_story and max_len_history
          * configurable save_path
          * a colab script so other users can clone and run the project
        """
        self.title_to_text_pkl_path = params['title_to_text_pkl_path']
        self.text_image_pkl_path = params['text_image_pkl_path']
        self.dict_text_pkl_path = params['dict_text_pkl_path']
        self.num_steps = params['num_steps']
        self.texts_pkl_path = params['texts_pkl_path']
        self.embeds_path = params['embeds_path']
        self.embeds2_path = params['embeds2_path']
        self.dict_path = params['dict_path']
        self.image_path = params['image_path']
        self.maps_pkl_path = params['maps_pkl_path']
        self.folder = params['folder']
        self.system_prompt = self._resolve_system_prompt(params['system_prompt'])
        self.max_len_story = params['max_len_story']
        self.max_len_history = params['max_len_history']
        self.save_path = params['save_path']
        # Lazily created by _embedding_backend(); loading them is expensive.
        self._tokenizer = None
        self._model = None

    # ------------------------------------------------------------------
    # small private helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _resolve_system_prompt(value):
        """Return the prompt text: file content if *value* is a file, else *value*."""
        if isinstance(value, str) and os.path.isfile(value):
            with open(value, 'r', encoding='utf-8') as f:
                return f.read()
        return value

    @staticmethod
    def _ensure_parent_dir(path):
        """Create the directory that will contain *path* if it is missing."""
        parent = os.path.dirname(path)
        if parent:
            os.makedirs(parent, exist_ok=True)

    @staticmethod
    def _read_jsonl(path):
        """Yield one decoded JSON object per non-blank line of *path*."""
        with open(path, 'r', encoding='utf-8') as f:
            for line in f:
                if line.strip():
                    yield json.loads(line)

    @staticmethod
    def _parse_line(line):
        """Extract the spoken/narrated text from one ``speaker:text`` line.

        Both the ASCII colon and the full-width colon are accepted and the
        line is split at the first one only, so colons inside the text are
        kept. Narration lines (containing '旁白') keep the whole text; other
        lines drop the first and last character (the 「」 brackets).
        Returns ``None`` when the line has no delimiter (e.g. a blank line).
        """
        line = line.strip()
        positions = [p for p in (line.find(c) for c in (':', '\uff1a')) if p != -1]
        if not positions:
            return None
        body = line[min(positions) + 1:].strip()
        if '旁白' in line:
            return body
        return body[1:-1]  # text inside the surrounding brackets

    @staticmethod
    def _rank_titles(similarities, titles, k):
        """Return up to *k* distinct titles ordered by descending similarity.

        ``similarities[i]`` belongs to ``titles[i]``. Ranking is done on
        indices (stable for ties), so equal scores still map to their own
        entries instead of all resolving to the first matching score.
        """
        order = sorted(range(len(similarities)),
                       key=similarities.__getitem__, reverse=True)
        top_titles = []
        for idx in order:
            title = titles[idx]
            if title not in top_titles:
                top_titles.append(title)
                if len(top_titles) == k:
                    break
        return top_titles

    def _embedding_backend(self):
        """Return ``(tokenizer, model)``, loading each only once per instance."""
        if getattr(self, '_tokenizer', None) is None:
            self._tokenizer = AutoTokenizer.from_pretrained("silk-road/luotuo-bert")
        if getattr(self, '_model', None) is None:
            self._model = self.download_models()
        return self._tokenizer, self._model

    # ------------------------------------------------------------------
    # corpus preparation and persistence
    # ------------------------------------------------------------------
    def read_text(self):
        """Extract every script line, embed it and persist the results.

        Writes:
          * ``texts_pkl_path``  - JSONL, one ``{"<title>_<text>": index}`` per line
          * ``embeds_path`` / ``embeds2_path`` - JSONL, first / second half of
            ``{index: embedding}`` records
          * ``title_to_text_pkl_path`` - pickle of ``{title: full file text}``
          * ``maps_pkl_path`` - pickle of the returned ``data`` list

        Returns:
            ``(text_embeddings, data)`` where ``text_embeddings`` is an empty
            list (kept for backward compatibility) and ``data`` is a list of
            ``{"titles", "id", "text"}`` dicts.
        """
        text_embeddings = []
        title_to_text = {}
        texts = []
        seen = set()  # O(1) duplicate detection
        data = []
        for file in sorted(os.listdir(self.folder)):
            if not file.endswith('.txt'):
                continue
            title_name = file[:-4]
            with open(os.path.join(self.folder, file), 'r', encoding='utf-8') as fr:
                title_to_text[title_name] = fr.read()
            for line in title_to_text[title_name].strip().split('\n'):
                text = self._parse_line(line)
                if text is None:
                    continue
                key = title_name + "_" + text
                # Skip duplicates so texts and embeddings stay index-aligned.
                if key in seen:
                    continue
                seen.add(key)
                texts.append(key)
                data.append({
                    "titles": file.split('.')[0],
                    "id": str(len(data)),
                    "text": text,
                })

        if texts:
            # get_embedding returns a 1-D tensor for a single text; normalise
            # to one row per text.
            embeddings = self.get_embedding(texts).reshape(len(texts), -1)
        else:
            embeddings = []

        for path in (self.texts_pkl_path, self.embeds_path, self.embeds2_path):
            self._ensure_parent_dir(path)

        with open(self.texts_pkl_path, 'w', encoding='utf-8') as f1:
            for i, text in enumerate(texts):
                json.dump({text: i}, f1, ensure_ascii=False)
                f1.write('\n')

        # The embeddings are split over two files (first half / second half).
        half = len(embeddings) / 2
        with open(self.embeds_path, 'w', encoding='utf-8') as f2, \
                open(self.embeds2_path, 'w', encoding='utf-8') as f3:
            for i, embed in enumerate(embeddings):
                target = f2 if i < half else f3
                json.dump({i: embed.numpy().tolist()}, target, ensure_ascii=False)
                target.write('\n')

        self.store(self.title_to_text_pkl_path, title_to_text)
        self.store(self.maps_pkl_path, data)
        return text_embeddings, data

    def store(self, path, data):
        """Pickle *data* to *path*, creating the parent directory if needed."""
        self._ensure_parent_dir(path)
        with open(path, 'wb') as f:
            pickle.dump(data, f)

    def load(self, load_texts=False, load_maps=False, load_dict_text=False,
             load_text_image=False, load_title_to_text=False):
        """Load one persisted artifact, selected by the first true flag.

        ``load_texts`` returns ``{"<title>_<text>": [embedding]}`` rebuilt
        from the three JSONL files; the other flags return the unpickled
        object. Returns ``None`` (after printing a message) when the
        corresponding path is not configured or no flag is set.
        """
        if load_texts:
            if not self.texts_pkl_path:
                print("No texts_pkl_path")
                return None
            texts = [list(record)[0] for record in self._read_jsonl(self.texts_pkl_path)]
            embeds = [
                list(record.values())
                for path in (self.embeds_path, self.embeds2_path)
                for record in self._read_jsonl(path)
            ]
            return dict(zip(texts, embeds))

        pickled_sources = (
            (load_maps, 'maps_pkl_path'),
            (load_dict_text, 'dict_text_pkl_path'),
            (load_text_image, 'text_image_pkl_path'),
            (load_title_to_text, 'title_to_text_pkl_path'),
        )
        for requested, attr in pickled_sources:
            if not requested:
                continue
            path = getattr(self, attr)
            if not path:
                print(f"No {attr}")
                return None
            # NOTE: pickle is only safe for files produced by this program.
            with open(path, 'rb') as f:
                return pickle.load(f)

        print("Please specify the loading file!")
        return None

    # ------------------------------------------------------------------
    # retrieval
    # ------------------------------------------------------------------
    def text_to_image(self, text, save_dict_text=False):
        """Return the image whose caption is most similar to *text*.

        With ``save_dict_text=True`` the caption -> image-name mapping is
        first rebuilt from ``dict_path`` (alternating caption / image-name
        lines) and stored together with the caption embeddings.

        Returns a ``PIL.Image`` or ``None`` when paths are not configured or
        the image file does not exist.
        """
        if save_dict_text:
            text_image = {}
            with open(self.dict_path, 'r', encoding='utf-8') as f:
                lines = f.readlines()
            for sub_text, image in zip(lines[::2], lines[1::2]):
                text_image[sub_text.strip()] = image.strip()
            self.store(self.text_image_pkl_path, text_image)
            keys_embeddings = {key: self.get_embedding(key) for key in text_image}
            self.store(self.dict_text_pkl_path, keys_embeddings)

        if not (self.dict_path and self.image_path):
            print("No path")
            return None

        # caption -> image name
        text_image = self.load(load_text_image=True)
        captions = list(text_image.keys())
        query_similarity = self.get_cosine_similarity([text] + captions, get_image=True)
        best_caption = captions[int(query_similarity.argmax(dim=0))]
        image_file = os.path.join(self.image_path, text_image[best_caption] + '.jpg')
        if os.path.isfile(image_file):
            return Image.open(image_file)
        print("Image doesn't exist")
        return None

    def text_to_text(self, text):
        """Return the stored ``"<title>_<text>"`` key most similar to *text*."""
        pkl = self.load(load_texts=True)
        # Only the query is needed: candidates come from the stored embeddings.
        texts_similarity = self.get_cosine_similarity([text], get_texts=True)
        return list(pkl.keys())[int(texts_similarity.argmax(dim=0))]

    # Thin wrapper around the OpenAI chat endpoint: takes messages, returns the reply text.
    def get_completion_from_messages(self, messages, model="gpt-3.5-turbo", temperature=0):
        """Send *messages* to the OpenAI chat API and return the reply content."""
        response = openai.ChatCompletion.create(
            model=model,
            messages=messages,
            temperature=temperature,  # controls randomness of the output
        )
        return response.choices[0].message["content"]

    def download_models(self):
        """Load the ``silk-road/luotuo-bert`` embedding model onto ``device``."""
        # The hub package downloads the weights automatically when missing.
        model_args = Namespace(do_mlm=None, pooler_type="cls", temp=0.05,
                               mlp_only_train=False, init_embeddings_model=None)
        model = AutoModel.from_pretrained(
            "silk-road/luotuo-bert", trust_remote_code=True, model_args=model_args
        ).to(device)
        return model

    def get_embedding(self, texts):
        """Embed a string or a list of strings.

        Each text is truncated to ``num_steps`` characters. The caller's list
        is not modified. Returns a 1-D tensor for a single text, otherwise a
        2-D tensor with one row per text.
        """
        tokenizer, model = self._embedding_backend()
        batch = texts if isinstance(texts, list) else [texts]
        # Truncate on a copy so the caller's list keeps the full texts.
        batch = [item[:self.num_steps] for item in batch]
        inputs = tokenizer(batch, padding=True, truncation=False, return_tensors="pt")
        inputs = inputs.to(device)
        with torch.no_grad():
            embeddings = model(**inputs, output_hidden_states=True,
                               return_dict=True, sent_emb=True).pooler_output
        return embeddings[0] if len(batch) == 1 else embeddings

    def get_cosine_similarity(self, texts, get_image=False, get_texts=False):
        """Cosine similarity between ``texts[0]`` (the query) and candidates.

        Candidates are the stored caption embeddings (``get_image``), the
        stored script-line embeddings (``get_texts``) or, when neither flag is
        set, freshly computed embeddings of ``texts[1:]``. Returns a 1-D
        tensor with one score per candidate, in candidate order.
        """
        if get_image:
            stored = list(self.load(load_dict_text=True).values())
        elif get_texts:
            stored = list(self.load(load_texts=True).values())
        else:
            candidates = texts[1:]
            # One row per candidate, aligned with texts[1:].
            stored = list(self.get_embedding(candidates).reshape(len(candidates), -1))
        query_embedding = self.get_embedding(texts[0]).reshape(1, -1)
        texts_embeddings = np.array([np.array(value).reshape(-1) for value in stored])
        return cosine_similarity(query_embedding, torch.from_numpy(texts_embeddings))

    def retrieve_title(self, query_text, k):
        """Return up to *k* distinct story titles most similar to *query_text*."""
        texts_pkl = self.load(load_texts=True)
        # Keys look like "<title>_<text>".
        titles = [title_text.split('_')[0] for title_text in texts_pkl.keys()]
        similarities = self.get_cosine_similarity(
            [query_text], get_texts=True).numpy().tolist()
        return self._rank_titles(similarities, titles, k)

    # ------------------------------------------------------------------
    # prompt assembly
    # ------------------------------------------------------------------
    def organize_story_with_maxlen(self, selected_sample):
        """Concatenate stories, in order, until ``max_len_story`` tokens is reached.

        Returns ``(story_text, titles_actually_used)``.
        """
        maxlen = self.max_len_story
        title_to_text = self.load(load_title_to_text=True)
        story = "凉宫春日的经典桥段如下:\n"
        count = 0
        final_selected = []
        print(selected_sample)
        for sample_topic in selected_sample:
            sample_story = title_to_text[sample_topic]
            sample_len = len(enc.encode(sample_story))
            if sample_len + count > maxlen:
                break
            story += sample_story + '\n'
            count += sample_len
            final_selected.append(sample_topic)
        return story, final_selected

    def organize_message(self, story, history_chat, history_response, new_query):
        """Build an OpenAI-style message list (dicts with role/content).

        History is dropped, with a warning, when questions and answers are
        not paired one-to-one. Previous bot replies use the ``assistant``
        role so the model can tell its own turns from the user's.
        """
        messages = [{'role': 'system', 'content': self.system_prompt},
                    {'role': 'user', 'content': story}]
        if len(history_chat) != len(history_response):
            print('warning, unmatched history_char length, clean and start new chat')
            history_chat, history_response = [], []
        for question, answer in zip(history_chat, history_response):
            messages.append({'role': 'user', 'content': question})
            messages.append({'role': 'assistant', 'content': answer})
        messages.append({'role': 'user', 'content': new_query})
        return messages

    def keep_tail(self, history_chat, history_response):
        """Keep the most recent turns that fit in ``max_len_history`` tokens.

        The latest turn is always kept, even if it alone exceeds the budget.
        Returns two empty lists for empty or mismatched histories.
        """
        n = len(history_chat)
        if n == 0:
            return [], []
        if n != len(history_response):
            print('warning, unmatched history_char length, clean and start new chat')
            return [], []
        token_len = [
            len(enc.encode(question)) + len(enc.encode(answer))
            for question, answer in zip(history_chat, history_response)
        ]
        keep_k = 1
        count = token_len[-1]
        # Walk backwards from the second-most-recent turn.
        for turn_len in reversed(token_len[:-1]):
            count += turn_len
            if count > self.max_len_history:
                break
            keep_k += 1
        return history_chat[-keep_k:], history_response[-keep_k:]

    def organize_message_langchain(self, story, history_chat, history_response, new_query):
        """Build the same conversation as ``organize_message`` using langchain messages."""
        messages = [
            SystemMessage(content=self.system_prompt),
            HumanMessage(content=story),
        ]
        if len(history_chat) != len(history_response):
            print('warning, unmatched history_char length, clean and start new chat')
            history_chat, history_response = [], []
        for question, answer in zip(history_chat, history_response):
            messages.append(HumanMessage(content=question))
            messages.append(AIMessage(content=answer))
        messages.append(HumanMessage(content=new_query))
        return messages

    # ------------------------------------------------------------------
    # chat
    # ------------------------------------------------------------------
    def get_response(self, user_message, chat_history_tuple):
        """Produce the bot reply for *user_message* given the prior ``(user, bot)`` turns."""
        history_chat = []
        history_response = []
        for cha, res in chat_history_tuple or []:
            history_chat.append(cha)
            history_response.append(res)
        history_chat, history_response = self.keep_tail(history_chat, history_response)
        print('history done')
        new_query = user_message
        selected_sample = self.retrieve_title(new_query, 7)
        print("备选辅助:", selected_sample)
        story, selected_sample = self.organize_story_with_maxlen(selected_sample)
        # TODO: visualize selected sample later
        print('当前辅助sample:', selected_sample)
        messages = self.organize_message_langchain(
            story, history_chat, history_response, new_query)
        print(f"messages:{messages}")
        chat = ChatOpenAI(temperature=0)
        return_msg = chat(messages)
        return return_msg.content

    def save_response(self, chat_history_tuple):
        """Write the conversation to a new timestamped UTF-8 file in ``save_path``."""
        os.makedirs(self.save_path, exist_ok=True)
        target = os.path.join(self.save_path, f"conversation_{time.time()}.txt")
        with open(target, "w", encoding="utf-8") as file:
            for cha, res in chat_history_tuple:
                file.write(cha)
                file.write("\n---\n")
                file.write(res)
                file.write("\n---\n")

    def create_gradio(self):
        """Build and launch the Gradio UI (blocks until the server stops)."""
        with gr.Blocks() as demo:
            gr.Markdown(
                "\n"
                "## Chat凉宫春日 ChatHaruhi\n"
                "项目地址 [https://github.com/LC1332/Chat-Haruhi-Suzumiya](https://github.com/LC1332/Chat-Haruhi-Suzumiya)\n"
                "骆驼项目地址 [https://github.com/LC1332/Luotuo-Chinese-LLM](https://github.com/LC1332/Luotuo-Chinese-LLM)\n"
                "此版本为图文版本,非最终版本,将上线更多功能,敬请期待\n"
            )
            # Hidden box carrying the latest bot reply to the image lookup.
            image_input = gr.Textbox(visible=False)
            with gr.Row():
                chatbot = gr.Chatbot()
                image_output = gr.Image()
            role_name = gr.Textbox(label="角色名", placeholder="输入角色名")
            msg = gr.Textbox(label="输入")
            with gr.Row():
                clear = gr.Button("Clear")
                sub = gr.Button("Submit")
                image_button = gr.Button("给我一个图")

            def respond(role_name, user_message, chat_history):
                """Sanitise the inputs, query the bot and extend the history."""
                # The chatbot value can be None after "Clear".
                chat_history = chat_history or []
                role_name = "阿虚" if role_name in ['', ' '] else role_name
                role_name = role_name[:10]
                user_message = user_message[:200]
                # Characters that would break the "name:「text」" line format.
                special_chars = (':', '\uff1a', '「', '」', '\n')
                for char in special_chars:
                    role_name = role_name.replace(char, 'x')
                    user_message = user_message.replace(char, ' ')
                input_message = role_name + ':「' + user_message + '」'
                print(f"chat_history:{chat_history}")
                bot_message = self.get_response(input_message, chat_history)
                chat_history.append((input_message, bot_message))
                self.save_response(chat_history)
                return "", chat_history, bot_message

            msg.submit(respond, [role_name, msg, chatbot], [msg, chatbot, image_input])
            clear.click(lambda: None, None, chatbot, queue=False)
            sub.click(fn=respond, inputs=[role_name, msg, chatbot],
                      outputs=[msg, chatbot, image_input])
            image_button.click(self.text_to_image, inputs=image_input, outputs=image_output)
        demo.launch(debug=True, share=True)
if __name__ == '__main__':
    # Command-line entry point: parse options, (re)build the text index,
    # then start the Gradio app.
    parser = argparse.ArgumentParser(description="-----[Chat凉宫春日]-----")

    # (flag, keyword arguments) for every supported option.
    cli_options = (
        ("--folder", {"default": "../characters/haruhi/texts", "help": "text folder"}),
        ("--system_prompt", {"default": "../characters/haruhi/system_prompt.txt",
                             "help": "store system_prompt"}),
        ("--max_len_story", {"default": 1500, "type": int}),
        ("--max_len_history", {"default": 1200, "type": int}),
        ("--save_path", {"default": os.getcwd() + "/Suzumiya"}),
        ("--texts_pkl_path", {"default": "./pkl/texts.jsonl"}),
        ("--embeds_path", {"default": "./pkl/embeds.jsonl"}),
        ("--embeds2_path", {"default": "./pkl/embeds2.jsonl"}),
        ("--maps_pkl_path", {"default": "./pkl/maps.pkl"}),
        ("--title_to_text_pkl_path", {"default": './pkl/title_to_text.pkl'}),
        ("--dict_text_pkl_path", {"default": "./pkl/dict_text.pkl"}),
        ("--text_image_pkl_path", {"default": "./pkl/text_image.pkl"}),
        ("--dict_path", {"default": "../characters/haruhi/text_image_dict.txt"}),
        ("--image_path", {"default": "../characters/haruhi/images"}),
        ("--num_steps", {"default": 510, "type": int}),
    )
    for flag, kwargs in cli_options:
        parser.add_argument(flag, **kwargs)

    options = parser.parse_args()
    # Every parsed option maps one-to-one onto a Run parameter of the same name.
    run = Run(**vars(options))

    # Rebuild the text/embedding files on start-up.
    run.read_text()
    # To (re)build the caption/image index as well, run once:
    # run.text_to_image("hello", save_dict_text=True)
    run.create_gradio()
# a = run.load(load_texts=True)
# print(len(a))
# for item in a:
# print(item)
# print(len(a))
# a = run.load(load_dict_text=True)
# print(a)
# print(len(a))
# a = run.load(load_text_image=True)
# print(a)
# print(len(a))
# a = run.load(load_title_to_text=True)
# print(a)
# print(len(a))
# b = run.load(load_maps=True)
# print(len(b))
# print(run.load(load_title_to_text)
# history_chat = []
# history_response = []
# chat_timer = 5
# new_query = '鲁鲁:你好我是新同学鲁鲁'
#
#
# selected_sample = run.retrieve_title(new_query, 7)
#
# print('限制长度之前:', selected_sample)
#
# story, selected_sample = run.organize_story_with_maxlen(selected_sample)
#
# print('当前辅助sample:', selected_sample)
#
# messages = run.organize_message(story, history_chat, history_response, new_query)
#
# response = run.get_completion_from_messages(messages)
#
# print(response)
#
# history_chat.append(new_query)
# history_response.append(response)
#
# history_chat, history_response = run.keep_tail(history_chat, history_response)
# print(history_chat, history_response)