# Hugging Face file-page residue (not code), kept as a comment:
# uploaded by Cpuritan — "Upload 2 files", commit 795cb79, 5.18 kB.
import functools
import io
import os

import gradio as gr
import librosa
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import PIL
import torch
from keras.models import load_model
from transformers import BertConfig, BertTokenizer, XLMRobertaForSequenceClassification
@functools.lru_cache(maxsize=1)
def _load_text_model():
    """Build the text tokenizer and classifier once and reuse them.

    The original code rebuilt both on every request; caching makes each
    click pay only for inference.

    Returns:
        (tokenizer, model) tuple.
    """
    tokenizer = BertTokenizer("vocab.txt")  # vocabulary file
    # NOTE(review): a BertConfig is handed to an XLM-RoBERTa model class, as in
    # the original code; confirm this matches the checkpoint's architecture.
    config = BertConfig.from_pretrained("nanaaaa/emotion_chinese_english")
    model = XLMRobertaForSequenceClassification.from_pretrained(
        "nanaaaa/emotion_chinese_english", config=config
    )
    model.eval()  # inference only
    return tokenizer, model


def text_clf(text):
    """Classify the emotion expressed in ``text``.

    Args:
        text: Input string from the Gradio textbox.

    Returns:
        dict mapping each label (fear, happy, surprise, sad, angry — the
        Chinese UI strings below) to its softmax probability, the shape
        Gradio's Label component displays.
    """
    tokenizer, model = _load_text_model()
    inputs = tokenizer(text, return_tensors="pt")
    # No gradients are needed for inference; skip building the autograd graph.
    with torch.no_grad():
        outputs = model(**inputs)
    probs = torch.nn.functional.softmax(outputs.logits, dim=1)
    # Assumes the model's logit order matches this list — TODO confirm.
    labels = ["害怕", "高兴喵", "惊喜", "伤心", "生气"]
    probabilities = probs.cpu().numpy()[0].tolist()
    return {label: float(probabilities[i]) for i, label in enumerate(labels)}
# Fixed clip length (in samples) fed to the MFCC front end. An older comment
# mentioned "2 s / 32000 points"; the value actually used is 35000.
_VOICE_LEN = 35000

# Candidate FFT sizes (powers of two) for the MFCC frame.
_NFFT_CANDIDATES = (32, 64, 128, 256, 512, 1024)


@functools.lru_cache(maxsize=1)
def _load_audio_model():
    """Load the Keras speech-emotion model once and reuse it across calls."""
    return load_model('speech_mfcc_model.h5')


def _normalize_voice_len(y, normalized_len):
    """Zero-pad or truncate a 1-D signal to exactly ``normalized_len`` samples.

    The result is always floating point because librosa rejects integer
    audio. The dtype is promoted the same way the earlier zero-padding did:
    int16 -> float32, float64 stays float64. Previously only padded clips
    were converted, so truncated (long) clips crashed inside librosa.
    """
    y = np.asarray(y)
    y = y.astype(np.result_type(y.dtype, np.float32), copy=False)
    if len(y) < normalized_len:
        return np.pad(y, (0, normalized_len - len(y)))
    return y[:normalized_len]


def _get_nearest_len(framelength, sr):
    """Return the candidate FFT size closest to ``framelength * sr`` samples.

    Ties go to the smaller candidate, i.e. the first in list order.
    """
    framesize = framelength * sr
    return min(_NFFT_CANDIDATES, key=lambda n: abs(framesize - n))


def audio_clf(aud):
    """Classify the emotion of a recorded/uploaded audio clip.

    Args:
        aud: Gradio numpy Audio value ``(sample_rate, samples)``, or None when
            nothing was recorded. Samples are presumably int16 with shape (n,)
            or (n, channels) — TODO confirm for the Gradio version in use.

    Returns:
        dict mapping the six English emotion labels to the model's output
        scores, or None when no audio was provided.
    """
    if aud is None:
        return None
    sr, y = aud
    y = np.asarray(y)
    if y.ndim > 1:
        # Down-mix multi-channel audio; the pipeline expects a 1-D signal.
        y = y.mean(axis=1)

    n_fft = _get_nearest_len(0.5, sr)
    y = _normalize_voice_len(y, _VOICE_LEN)
    # NOTE(review): the clip's native sample rate is used without resampling;
    # confirm this matches the rate the model was trained on.
    mfcc_data = librosa.feature.mfcc(
        y=y, sr=sr, n_mfcc=13, n_fft=n_fft, hop_length=int(n_fft / 4)
    )
    # Average over the coefficient axis -> one value per frame (float64).
    feature = np.mean(mfcc_data, axis=0).astype(np.float64)

    # Per-clip standardisation; a constant feature (e.g. silence) would give
    # 0/0 = NaN, so fall back to a unit divisor in that case.
    std = np.std(feature)
    data = (feature - np.mean(feature)) / (std if std > 0 else 1.0)
    data = data[np.newaxis, :, np.newaxis]  # (batch=1, frames, channels=1)

    pred = _load_audio_model().predict(data)
    labels1 = ["angry", "fear", "joy", "neutral", "sadness", "surprise"]
    probabilities1 = pred[0].tolist()
    return {label: float(probabilities1[i]) for i, label in enumerate(labels1)}
# Open angle intervals (radians) and the colour class assigned to each.
# The intervals are disjoint. Angles in none of them return None, which is
# plotted as a NaN colour. This covers [-π, π], every boundary value and
# [1.04π, 1.1π].
_COLOR_BANDS = (
    (-2 * np.pi, -1.5 * np.pi, 5),
    (-1.5 * np.pi, -1.1 * np.pi, 2),
    (-1.1 * np.pi, -1 * np.pi, 3),
    (1 * np.pi, 1.04 * np.pi, 3),
    (1.1 * np.pi, 1.375 * np.pi, 4),
    (1.375 * np.pi, 1.625 * np.pi, 1),
    (1.625 * np.pi, 2 * np.pi, 0),
)


def _clf_col(x):
    """Map an angle ``x`` (radians) to its colour-class index, or None."""
    for low, high, cls in _COLOR_BANDS:
        if low < x < high:
            return cls
    return None


@functools.lru_cache(maxsize=1)
def _load_circle_data():
    """Read ./df_4.csv (GBK-encoded) once and reuse it; it is never mutated."""
    return pd.read_csv(r'./df_4.csv', encoding="gbk")


def _fig_to_array(fig):
    """Render ``fig`` and return its pixels as an (H, W, 4) uint8 RGBA array.

    Uses the Agg canvas's RGBA buffer. This replaces ``tostring_argb``
    (deprecated), ``np.fromstring`` (deprecated) and ``ndarray.tostring``
    (removed in NumPy 2). It also avoids the old (w, h) reshape, which was
    transposed for non-square figures. ``np.array`` copies, so the result
    stays valid after the figure is closed.
    """
    fig.canvas.draw()
    return np.array(fig.canvas.buffer_rgba())


def cir_clf(L, R):
    """Plot rows ``[L, R)`` of the dataset as a coloured polar scatter.

    Args:
        L: Start row (slider value; converted with ``int``).
        R: End row, exclusive (slider value; converted with ``int``).

    Returns:
        RGBA image of the 10x10-inch polar plot as a uint8 numpy array.
    """
    df_4 = _load_circle_data()
    start, stop = int(L), int(R)
    r = df_4["R_nor"].iloc[start:stop]
    # Assumes Theta_nor is a fraction of a full turn — TODO confirm.
    theta = (2 * np.pi * df_4["Theta_nor"]).iloc[start:stop]
    colors = theta.apply(_clf_col)

    # Create the polar axes directly. The old code first made Cartesian axes
    # and then a second polar subplot on top. Depending on the Matplotlib
    # version, the blank Cartesian axes could be left underneath.
    fig, ax = plt.subplots(subplot_kw={"projection": "polar"})
    try:
        fig.set_size_inches(10, 10)
        ax.scatter(theta, r, c=colors, cmap="hsv", alpha=0.6)
        return _fig_to_array(fig)
    finally:
        # pyplot keeps every figure alive until closed; without this each
        # click leaked a 10x10-inch figure.
        plt.close(fig)
# Three-tab UI: text emotion, audio emotion, and the polar "emotion circle".
# Output components come from the top-level namespace (gr.Label / gr.Image).
# The legacy gr.outputs.* namespace is deprecated and was removed in Gradio 4.
with gr.Blocks() as demo:
    with gr.Tab("Flip Text"):
        text = gr.Textbox(label="文本哟")
        text_output = gr.Label(label="情感呢")
        text_button = gr.Button("确认")
    with gr.Tab("Flip Audio"):
        audio = gr.Audio(label="音频捏")
        audio_output = gr.Label(label="情感哟")
        audio_button = gr.Button("确认")
    with gr.Tab("Flip Circle"):
        # Row range [start, end) of the dataset to plot.
        cir_l = gr.Slider(0, 30000, step=1)
        cir_r = gr.Slider(0, 30000, step=1)
        cir_output = gr.Image(type='numpy', label="情感圈")
        cir_button = gr.Button("确认")

    # Event listeners are registered inside the Blocks context.
    text_button.click(fn=text_clf, inputs=text, outputs=text_output)
    audio_button.click(fn=audio_clf, inputs=audio, outputs=audio_output)
    cir_button.click(fn=cir_clf, inputs=[cir_l, cir_r], outputs=cir_output)

# Launch only when run as a script, so importing the module (e.g. Gradio's
# reload mode, tests) does not start a server.
if __name__ == "__main__":
    demo.launch(share=True)