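# Gradio demo for bidirectional Manchu <-> Chinese translation.
# Two Helsinki-NLP MarianMT checkpoints provide the architecture and
# tokenizers; fine-tuned weights are loaded from local .bin files below.
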
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM



max_input_length = 128
max_target_length = 128
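# Inputs and outputs are capped at 128 tokens; anything longer is truncated
# (the usage note in the UI below repeats this limit).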

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using {device} device')

# Model 1: Chinese -> Manchu (fine-tuned from the zh-en Marian checkpoint,
# which supplies the architecture and tokenizer)
model1_checkpoint = "Helsinki-NLP/opus-mt-zh-en"
tokenizer1 = AutoTokenizer.from_pretrained(model1_checkpoint)
model1 = AutoModelForSeq2SeqLM.from_pretrained(model1_checkpoint)
model1 = model1.to(device)

# map_location keeps the load working on CPU-only machines
model1.load_state_dict(torch.load('epoch_41_valid_bleu_100.00_model_weights.bin', map_location=device))
model1.eval()
# Model 2: Manchu -> Chinese (fine-tuned from the en-zh Marian checkpoint)
model2_checkpoint = "Helsinki-NLP/opus-mt-en-zh"
tokenizer2 = AutoTokenizer.from_pretrained(model2_checkpoint)
model2 = AutoModelForSeq2SeqLM.from_pretrained(model2_checkpoint)
model2 = model2.to(device)

model2.load_state_dict(torch.load('epoch_41_valid_bleu_0.00_model_weights.bin', map_location=device))
model2.eval()

def chineseToManju(text):
    """Translate a Chinese sentence into Manchu."""
    # Tokenize, then move the tensors to the same device as the model
    batch_data = tokenizer1(
        text,
        padding=True,
        max_length=max_input_length,
        truncation=True,
        return_tensors="pt"
    ).to(device)
    with torch.no_grad():
        generated_tokens = model1.generate(
            batch_data["input_ids"],
            attention_mask=batch_data["attention_mask"],
            max_length=max_target_length,
        ).cpu().numpy()
    res = tokenizer1.batch_decode(generated_tokens, skip_special_tokens=True)
    return res[0]  # single input -> single decoded string

def manjuToChinese(text):
    """Translate a Manchu sentence into Chinese."""
    batch_data = tokenizer2(
        text,
        padding=True,
        max_length=max_input_length,
        truncation=True,
        return_tensors="pt"
    ).to(device)
    with torch.no_grad():
        generated_tokens = model2.generate(
            batch_data["input_ids"],
            attention_mask=batch_data["attention_mask"],
            max_length=max_target_length,
        ).cpu().numpy()
    res = tokenizer2.batch_decode(generated_tokens, skip_special_tokens=True)
    return res[0]

# Build the UI: one tab per translation direction, plus a collapsible section
# with usage notes and sample sentence pairs.
with gr.Blocks() as demo:
    # Page title, rendered with Markdown
    gr.Markdown("## 满语翻译演示")
    # One tab per translation direction
    with gr.Tab("满to中"):
        # gr.Column stacks child components vertically (the default layout,
        # so the wrapper is optional)
        with gr.Column():
            text_input1 = gr.Textbox(lines=2, placeholder="请输入满语",label="manju")
            text_button1 = gr.Button("翻译")
            text_output1 = gr.Textbox(lines=2, label="chinese")
    with gr.Tab("中to满"):
        # Same vertical layout for the reverse direction
        with gr.Column():
            text_input2 = gr.Textbox(lines=2, placeholder="请输入中文",label="chinese")
            text_button2 = gr.Button("翻译")
            text_output2 = gr.Textbox(lines=2, label="manju")
    # Collapsible section with notes and example sentence pairs
    with gr.Accordion(""):
        gr.Markdown("## 东北师范大学信息科学与技术学院  满语智能处理实验室")
        gr.Markdown("#### 注意事项")
        gr.Markdown("最长语句不能超过128个词!")
        gr.Markdown("#### 以下是一些例子")
        gr.Markdown('"manju": "sakda amji,be yabume oho.", "chinese": "大爷,我们该走了。"')
        gr.Markdown('"manju": "ume ekxere,jai emu majige teki,majige muke be omiki.", "chinese": "忙什么呀,再坐一会儿,喝点水。"')
        gr.Markdown('"manju": "omirakv oho,ubade emgeri hontoha inenggi tehe,suwembe ambula jobohuha.", "chinese": "不了,在这呆了有半天了,打扰你们了。"')
        gr.Markdown('"manju": "ume yabure,ubade emu erin buda be jejfi jai yabuki.", "chinese": "别走了,在这吃了饭再走吧。"')
    text_button1.click(manjuToChinese, inputs=text_input1, outputs=text_output1)
    text_button2.click(chineseToManju, inputs=text_input2, outputs=text_output2)

if __name__ == "__main__":
    demo.launch()
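
# A minimal sketch of how to run this locally, assuming the file is saved as
# app.py (the dependency list is an assumption; sentencepiece is required by
# the Marian tokenizers):
#   pip install gradio torch transformers sentencepiece
#   python app.py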