manju_translate / app.py
taotao577's picture
Add application file
394df7b
raw
history blame
3.82 kB
import gradio as gr
from transformers import AutoTokenizer, AutoModelWithLMHead, AutoModelForSeq2SeqLM
import torch
max_input_length = 128  # source sentences are truncated to this many tokens
max_target_length = 128  # max_length passed to generate() for the output
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using {device} device')


def _load_translation_model(checkpoint, weights_path):
    """Build a tokenizer/model pair and load fine-tuned weights into the model.

    The tokenizer and model architecture come from the pretrained Hugging Face
    ``checkpoint``; the model's weights are then replaced by the state dict
    stored at ``weights_path``. The model is moved to ``device`` and switched
    to eval mode.

    Args:
        checkpoint: Hugging Face model id used for the tokenizer/architecture.
        weights_path: path to a saved ``state_dict``. A relative path is
            resolved against the current working directory, not this file.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
    model = model.to(device)
    # Without map_location, torch.load restores tensors onto the device they
    # were saved from, which fails on a CPU-only host for a CUDA checkpoint.
    state_dict = torch.load(weights_path, map_location=device)
    model.load_state_dict(state_dict)
    model.eval()  # inference only: disable dropout
    return tokenizer, model


# Model 1: Chinese -> Manchu (fine-tuned on top of the zh-en checkpoint).
model1_checkpoint = "Helsinki-NLP/opus-mt-zh-en"
tokenizer1, model1 = _load_translation_model(
    model1_checkpoint,
    '../model/chineseToManju/epoch_41_valid_bleu_100.00_model_weights.bin',
)

# Model 2: Manchu -> Chinese (fine-tuned on top of the en-zh checkpoint).
# NOTE(review): this checkpoint's filename records a validation BLEU of 0.00;
# confirm it is the intended weights file.
model2_checkpoint = "Helsinki-NLP/opus-mt-en-zh"
tokenizer2, model2 = _load_translation_model(
    model2_checkpoint,
    '../model/manjuToChinese/epoch_41_valid_bleu_0.00_model_weights.bin',
)
def chineseToManju(text):
    """Translate Chinese text to Manchu with ``model1``/``tokenizer1``.

    Args:
        text: a sentence, or a list of sentences, as accepted by
            ``tokenizer1``. Each input is padded/truncated to at most
            ``max_input_length`` tokens.

    Returns:
        A list of decoded translations, one per input sequence (a single
        string input yields a one-element list).
    """
    batch_data = tokenizer1(
        text,
        padding=True,
        max_length=max_input_length,
        truncation=True,
        return_tensors="pt",
    )
    # The tokenizer produces CPU tensors; they must live on the same device
    # as model1, otherwise generate() raises a device mismatch on CUDA.
    generated_tokens = model1.generate(
        batch_data["input_ids"].to(device),
        attention_mask=batch_data["attention_mask"].to(device),
        max_length=max_target_length,
    ).cpu().numpy()
    res = tokenizer1.batch_decode(generated_tokens, skip_special_tokens=True)
    return res
def manjuToChinese(text):
    """Translate Manchu text to Chinese with ``model2``/``tokenizer2``.

    Args:
        text: a sentence, or a list of sentences, as accepted by
            ``tokenizer2``. Each input is padded/truncated to at most
            ``max_input_length`` tokens.

    Returns:
        A list of decoded translations, one per input sequence (a single
        string input yields a one-element list).
    """
    batch_data = tokenizer2(
        text,
        padding=True,
        max_length=max_input_length,
        truncation=True,
        return_tensors="pt",
    )
    # The tokenizer produces CPU tensors; they must live on the same device
    # as model2, otherwise generate() raises a device mismatch on CUDA.
    generated_tokens = model2.generate(
        batch_data["input_ids"].to(device),
        attention_mask=batch_data["attention_mask"].to(device),
        max_length=max_target_length,
    ).cpu().numpy()
    res = tokenizer2.batch_decode(generated_tokens, skip_special_tokens=True)
    return res
def _manju_to_chinese_display(text):
    """Gradio handler: translate Manchu to Chinese and return one string."""
    # manjuToChinese returns a list of strings; gr.Textbox would otherwise
    # show the Python list repr (e.g. "['...']") instead of the translation.
    return "\n".join(manjuToChinese(text))


def _chinese_to_manju_display(text):
    """Gradio handler: translate Chinese to Manchu and return one string."""
    # chineseToManju returns a list of strings; join it for the Textbox.
    return "\n".join(chineseToManju(text))


with gr.Blocks() as demo:
    # Page heading, rendered as Markdown.
    gr.Markdown("## 满语翻译演示")
    # One tab per translation direction.
    with gr.Tab("满to中"):
        # gr.Column stacks its children vertically; vertical is already the
        # default layout in Blocks, so this wrapper is optional.
        with gr.Column():
            text_input1 = gr.Textbox(lines=2, placeholder="请输入满语", label="manju")
            text_button1 = gr.Button("翻译")
            text_output1 = gr.Textbox(lines=2, label="chinese")
    with gr.Tab("中to满"):
        # Vertical stack again (use gr.Row instead for a horizontal layout).
        with gr.Column():
            text_input2 = gr.Textbox(lines=2, placeholder="请输入中文", label="chinese")
            text_button2 = gr.Button("翻译")
            text_output2 = gr.Textbox(lines=2, label="manju")
    # Collapsible section: lab credit, usage notes and example sentence pairs.
    with gr.Accordion(""):
        gr.Markdown("## 东北师范大学信息科学与技术学院 满语智能处理实验室")
        gr.Markdown("#### 注意事项")
        gr.Markdown("最长语句不能超过128个词!")
        gr.Markdown("#### 以下是一些例子")
        gr.Markdown('"manju": "sakda amji,be yabume oho.", "chinese": "大爷,我们该走了。"')
        gr.Markdown('"manju": "ume ekxere,jai emu majige teki,majige muke be omiki.", "chinese": "忙什么呀,再坐一会儿,喝点水。"')
        gr.Markdown('"manju": "omirakv oho,ubade emgeri hontoha inenggi tehe,suwembe ambula jobohuha.", "chinese": "不了,在这呆了有半天了,打扰你们了。"')
        gr.Markdown('"manju": "ume yabure,ubade emu erin buda be jejfi jai yabuki.", "chinese": "别走了,在这吃了饭再走吧。"')
    # Event listeners are registered inside the Blocks context.
    text_button1.click(_manju_to_chinese_display, inputs=text_input1, outputs=text_output1)
    text_button2.click(_chinese_to_manju_display, inputs=text_input2, outputs=text_output2)

demo.launch(server_name="0.0.0.0", server_port=1234, share=True)