taotao577 commited on
Commit
394df7b
1 Parent(s): dd13a45

Add application file

Browse files
Files changed (1) hide show
  1. app.py +90 -0
app.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import gradio as gr
from transformers import AutoTokenizer, AutoModelWithLMHead, AutoModelForSeq2SeqLM
import torch


# Token-length caps for the tokenizer (input) and generate() (output).
max_input_length = 128
max_target_length = 128

# Prefer GPU when available; everything below must agree with this choice.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using {device} device')

# Model 1: Chinese -> Manchu.
# Fine-tuned from a Marian zh-en checkpoint; the fine-tuned weights are
# loaded from a local state-dict file below.
model1_checkpoint = "Helsinki-NLP/opus-mt-zh-en"
tokenizer1 = AutoTokenizer.from_pretrained(model1_checkpoint)
model1 = AutoModelForSeq2SeqLM.from_pretrained(model1_checkpoint)
model1 = model1.to(device)

# map_location keeps this loadable on a CPU-only host even if the
# checkpoint was saved from a GPU run.
model1.load_state_dict(torch.load(
    '../model/chineseToManju/epoch_41_valid_bleu_100.00_model_weights.bin',
    map_location=device))
model1.eval()

# Model 2: Manchu -> Chinese (fine-tuned from a Marian en-zh checkpoint).
model2_checkpoint = "Helsinki-NLP/opus-mt-en-zh"
tokenizer2 = AutoTokenizer.from_pretrained(model2_checkpoint)
model2 = AutoModelForSeq2SeqLM.from_pretrained(model2_checkpoint)
model2 = model2.to(device)

model2.load_state_dict(torch.load(
    '../model/manjuToChinese/epoch_41_valid_bleu_0.00_model_weights.bin',
    map_location=device))
model2.eval()
28
+
29
def chineseToManju(text):
    """Translate a Chinese sentence into Manchu.

    Args:
        text: Chinese source sentence; tokenization truncates it to
            ``max_input_length`` tokens.

    Returns:
        The Manchu translation as a plain string (suitable for a Textbox).
    """
    # Move the encoded batch to the model's device — model1 lives on
    # `device`, and generate() fails if inputs stay on the CPU.
    batch_data = tokenizer1(
        text,
        padding=True,
        max_length=max_input_length,
        truncation=True,
        return_tensors="pt"
    ).to(device)
    # Inference only: no_grad skips autograd bookkeeping.
    with torch.no_grad():
        generated_tokens = model1.generate(
            batch_data["input_ids"],
            attention_mask=batch_data["attention_mask"],
            max_length=max_target_length,
        ).cpu().numpy()
    res = tokenizer1.batch_decode(generated_tokens, skip_special_tokens=True)
    # Single input sentence -> single decoded string, not a 1-element list.
    return res[0]
44
+
45
def manjuToChinese(text):
    """Translate a Manchu sentence into Chinese.

    Args:
        text: Manchu source sentence; tokenization truncates it to
            ``max_input_length`` tokens.

    Returns:
        The Chinese translation as a plain string (suitable for a Textbox).
    """
    # Move the encoded batch to the model's device — model2 lives on
    # `device`, and generate() fails if inputs stay on the CPU.
    batch_data = tokenizer2(
        text,
        padding=True,
        max_length=max_input_length,
        truncation=True,
        return_tensors="pt"
    ).to(device)
    # Inference only: no_grad skips autograd bookkeeping.
    with torch.no_grad():
        generated_tokens = model2.generate(
            batch_data["input_ids"],
            attention_mask=batch_data["attention_mask"],
            max_length=max_target_length,
        ).cpu().numpy()
    res = tokenizer2.batch_decode(generated_tokens, skip_special_tokens=True)
    # Single input sentence -> single decoded string, not a 1-element list.
    return res[0]
60
+
61
# Assemble the demo UI: one tab per translation direction, an info
# accordion with usage notes and examples, then wire buttons to the
# translation functions and launch the server.
with gr.Blocks() as demo:
    gr.Markdown("## 满语翻译演示")

    # Tab 1: Manchu -> Chinese.
    with gr.Tab("满to中"):
        with gr.Column():  # stack the widgets vertically
            manju_src = gr.Textbox(lines=2, placeholder="请输入满语", label="manju")
            m2c_btn = gr.Button("翻译")
            chinese_out = gr.Textbox(lines=2, label="chinese")

    # Tab 2: Chinese -> Manchu.
    with gr.Tab("中to满"):
        with gr.Column():
            chinese_src = gr.Textbox(lines=2, placeholder="请输入中文", label="chinese")
            c2m_btn = gr.Button("翻译")
            manju_out = gr.Textbox(lines=2, label="manju")

    # Collapsible section: lab credit, usage limit, and sample sentences.
    with gr.Accordion(""):
        gr.Markdown("## 东北师范大学信息科学与技术学院 满语智能处理实验室")
        gr.Markdown("#### 注意事项")
        gr.Markdown("最长语句不能超过128个词!")
        gr.Markdown("#### 以下是一些例子")
        gr.Markdown('"manju": "sakda amji,be yabume oho.", "chinese": "大爷,我们该走了。"')
        gr.Markdown('"manju": "ume ekxere,jai emu majige teki,majige muke be omiki.", "chinese": "忙什么呀,再坐一会儿,喝点水。"')
        gr.Markdown('"manju": "omirakv oho,ubade emgeri hontoha inenggi tehe,suwembe ambula jobohuha.", "chinese": "不了,在这呆了有半天了,打扰你们了。"')
        gr.Markdown('"manju": "ume yabure,ubade emu erin buda be jejfi jai yabuki.", "chinese": "别走了,在这吃了饭再走吧。"')

    # Hook each button up to its translation function.
    m2c_btn.click(manjuToChinese, inputs=manju_src, outputs=chinese_out)
    c2m_btn.click(chineseToManju, inputs=chinese_src, outputs=manju_out)

# Bind to all interfaces on port 1234 and also request a public share link.
demo.launch(server_name="0.0.0.0", server_port=1234, share=True)