# ChatHaruhi / text.py
# (HuggingFace page header: commit 2e33d6a, "changes", by BlairLeng)
import collections
import os
import pickle
from argparse import Namespace
import numpy as np
import torch
from PIL import Image
from torch import cosine_similarity
from transformers import AutoTokenizer, AutoModel
def download_models():
    """Fetch the luotuo-bert model from the HuggingFace hub and return it.

    The ``model_args`` Namespace mirrors the SimCSE-style configuration that
    the checkpoint's remote code expects; downloading is handled by the hub.
    """
    luotuo_args = Namespace(
        do_mlm=None,
        pooler_type="cls",
        temp=0.05,
        mlp_only_train=False,
        init_embeddings_model=None,
    )
    return AutoModel.from_pretrained(
        "silk-road/luotuo-bert",
        trust_remote_code=True,
        model_args=luotuo_args,
    )
class Text:
def __init__(self, text_dir, model, num_steps, text_image_pkl_path=None, dict_text_pkl_path=None, pkl_path=None, dict_path=None, image_path=None, maps_path=None):
self.dict_text_pkl_path = dict_text_pkl_path
self.text_image_pkl_path = text_image_pkl_path
self.text_dir = text_dir
self.model = model
self.num_steps = num_steps
self.pkl_path = pkl_path
self.dict_path = dict_path
self.image_path = image_path
self.maps_path = maps_path
def get_embedding(self, texts):
tokenizer = AutoTokenizer.from_pretrained("silk-road/luotuo-bert")
model = download_models()
# 截断
# str or strList
texts = texts if isinstance(texts, list) else [texts]
for i in range(len(texts)):
if len(texts[i]) > self.num_steps:
texts[i] = texts[i][:self.num_steps]
# Tokenize the texts
inputs = tokenizer(texts, padding=True, truncation=False, return_tensors="pt")
# Extract the embeddings
# Get the embeddings
with torch.no_grad():
embeddings = model(**inputs, output_hidden_states=True, return_dict=True, sent_emb=True).pooler_output
return embeddings[0] if len(texts) == 1 else embeddings
def read_text(self, save_embeddings=False, save_maps=False):
"""抽取、预存"""
text_embeddings = collections.defaultdict()
text_keys = []
dirs = os.listdir(self.text_dir)
data = []
texts = []
id = 0
for dir in dirs:
with open(self.text_dir + '/' + dir, 'r') as fr:
for line in fr.readlines():
category = collections.defaultdict(str)
ch = ':' if ':' in line else ':'
if '旁白' in line:
text = line.strip().split(ch)[1].strip()
else:
text = ''.join(list(line.strip().split(ch)[1])[1:-1]) # 提取「」内的文本
if text in text_keys: # 避免重复的text,导致embeds 和 maps形状不一致
continue
text_keys.append(text)
if save_maps:
category["titles"] = dir.split('.')[0]
category["id"] = str(id)
category["text"] = text
id = id + 1
data.append(dict(category))
texts.append(text)
embeddings = self.get_embedding(texts)
if save_embeddings:
for text, embed in zip(texts, embeddings):
text_embeddings[text] = self.get_embedding(text)
if save_embeddings:
self.store(self.pkl_path, text_embeddings)
if save_maps:
self.store(self.maps_path, data)
return text_embeddings, data
def load(self, load_pkl=False, load_maps=False, load_dict_text=False, load_text_image=False):
if self.pkl_path and load_pkl:
with open(self.pkl_path, 'rb') as f:
return pickle.load(f)
elif self.maps_path and load_maps:
with open(self.maps_path, 'rb') as f:
return pickle.load(f)
elif self.dict_text_pkl_path and load_dict_text:
with open(self.dict_text_pkl_path, 'rb') as f:
return pickle.load(f)
elif self.text_image_pkl_path and load_text_image:
with open(self.text_image_pkl_path, 'rb') as f:
return pickle.load(f)
else:
print("No pkl_path")
def get_cosine_similarity(self, texts, get_image=False, get_texts=False):
"""
计算文本列表的相似度避免重复计算query_similarity
texts[0] = query
"""
if get_image:
pkl = self.load(load_dict_text=True)
elif get_texts:
pkl = self.load(load_pkl=True)
else:
pkl = {}
embeddings = self.get_embedding(texts[1:]).reshape(-1, 1536)
for text, embed in zip(texts, embeddings):
pkl[text] = embed
query_embedding = self.get_embedding(texts[0]).reshape(1, -1)
texts_embeddings = np.array([value.numpy().reshape(-1, 1536) for value in pkl.values()]).squeeze(1)
return cosine_similarity(query_embedding, torch.from_numpy(texts_embeddings))
def store(self, path, data):
with open(path, 'wb+') as f:
pickle.dump(data, f)
def text_to_image(self, text, save_dict_text=False):
"""
给定文本出图片
计算query 和 texts 的相似度,取最高的作为new_query 查询image
到text_image_dict 读取图片名
然后到images里面加载该图片然后返回
"""
if save_dict_text:
text_image = {}
with open(self.dict_path, 'r') as f:
data = f.readlines()
for sub_text, image in zip(data[::2], data[1::2]):
text_image[sub_text.strip()] = image.strip()
self.store(self.text_image_pkl_path, text_image)
keys_embeddings = {}
embeddings = self.get_embedding(list(text_image.keys()))
for key, embed in zip(text_image.keys(), embeddings):
keys_embeddings[key] = embed
self.store(self.dict_text_pkl_path, keys_embeddings)
if self.dict_path and self.image_path:
# 加载 text-imageName
text_image = self.load(load_text_image=True)
keys = list(text_image.keys())
keys.insert(0, text)
query_similarity = self.get_cosine_similarity(keys, get_image=True)
key_index = query_similarity.argmax(dim=0)
text = list(text_image.keys())[key_index]
image = text_image[text] + '.jpg'
if image in os.listdir(self.image_path):
res = Image.open(self.image_path + '/' + image)
# res.show()
return res
else:
print("Image doesn't exist")
else:
print("No path")
def text_to_text(self, text):
pkl = self.load(load_pkl=True)
texts = list(pkl.keys())
texts.insert(0, text)
texts_similarity = self.get_cosine_similarity(texts, get_texts=True)
key_index = texts_similarity.argmax(dim=0).item()
value = list(pkl.keys())[key_index]
return value
# if __name__ == '__main__':
# pkl_path = './pkl/texts.pkl'
# maps_path = './pkl/maps.pkl'
# text_image_pkl_path='./pkl/text_image.pkl'
# dict_path = "../characters/haruhi/text_image_dict.txt"
# dict_text_pkl_path = './pkl/dict_text.pkl'
# image_path = "../characters/haruhi/images"
# text_dir = "../characters/haruhi/texts"
# model = download_models()
# text = Text(text_dir, text_image_pkl_path=text_image_pkl_path, maps_path=maps_path,
# dict_text_pkl_path=dict_text_pkl_path, model=model, num_steps=50, pkl_path=pkl_path,
# dict_path=dict_path, image_path=image_path)
# text.read_text(save_maps=True, save_embeddings=True)
# data = text.load(load_pkl=True)
# sub_text = "你好!"
# image = text.text_to_image(sub_text)
# print(image)
# sub_texts = ["hello", "你好"]
# print(text.get_cosine_similarity(sub_texts))
# value = text.text_to_text(sub_text)
# print(value)