File size: 2,585 Bytes
9e02690
 
 
 
9bd55b7
9e02690
cdbedd8
 
7a75ce5
9e02690
 
 
 
 
 
 
 
 
 
e2a77ad
cdbedd8
e2a77ad
 
 
9e02690
 
 
 
 
 
 
 
 
 
 
 
7e3a12b
7b1042c
9e02690
 
7b1042c
 
 
 
7e3a12b
9e02690
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e2a77ad
9e02690
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import gradio as gr
from typing import List, Optional
from transformers import BertTokenizer, BartForConditionalGeneration

# Heading and introductory text passed to the Gradio interface below.
title = "HIT-TMG/dialogue-bart-large-chinese-DuSinc"
# NOTE: the literal "\n" escapes are part of the displayed text; keep them.
description = """
This is a fine-tuned version of HIT-TMG/dialogue-bart-large-chinese on the DuSinc dataset.
But it only has chit-chat ability without knowledge since we haven't introduced knowledge retrieval interface yet.\n
See some details of model card at https://huggingface.co/HIT-TMG/dialogue-bart-large-chinese-DuSinc . \n\n
Besides starting the conversation from scratch, you can also input the whole dialogue history utterance by utterance separated by '[SEP]'. \n
"""


# Hub checkpoint id shared by the tokenizer and the generation model.
_CHECKPOINT = "HIT-TMG/dialogue-bart-large-chinese-DuSinc"

# Upper bound, in tokens, applied when the dialogue history is encoded.
max_length = 512

# The tokenizer is loaded through BertTokenizer while the model itself is a
# BART conditional-generation model.
tokenizer = BertTokenizer.from_pretrained(_CHECKPOINT)
# Truncate from the left so that an over-long history loses its oldest part
# and the most recent utterances are kept.
tokenizer.truncation_side = 'left'

model = BartForConditionalGeneration.from_pretrained(_CHECKPOINT)

# Sample inputs for the demo: a single opening utterance, and a multi-turn
# history whose utterances are joined with the '[SEP]' separator.
examples = [
    ["你有什么爱好吗"],
    ["[SEP]".join((
        "你好。",
        "嘿嘿你好,请问你最近在忙什么呢?",
        "我最近养了一只狗狗,我在训练它呢。",
    ))],
]


def chat_func(input_utterance: str, history: Optional[List[str]] = None):
    """Generate the next reply and return the updated conversation.

    Args:
        input_utterance: Text entered by the user. It is split on the
            tokenizer's separator token, so several utterances can be
            supplied in one string.
        history: Utterances from earlier calls, oldest first, or ``None``
            when the conversation is just starting. The passed list is not
            modified.

    Returns:
        A pair ``(display_utterances, history)``. ``display_utterances`` is a
        list of 2-tuples of utterances for the chatbot output, arranged so
        that the generated response is the second element of the last tuple.
        ``history`` is a new list holding the previous utterances, the new
        input utterance(s) and the generated response.
    """
    # Work on a copy: if tokenization or generation raises, the caller's
    # stored state must not be left holding an utterance without a reply,
    # which would shift the pairing of every later turn.
    history = list(history) if history is not None else []
    history.extend(input_utterance.split(tokenizer.sep_token))

    history_str = "[history] " + tokenizer.sep_token.join(history)

    input_ids = tokenizer(history_str,
                          return_tensors='pt',
                          truncation=True,
                          max_length=max_length,
                          ).input_ids

    # generate() returns a batch; only one sequence was submitted.
    output_ids = model.generate(input_ids,
                                max_new_tokens=30,
                                top_p=0.95,
                                do_sample=True,
                                num_beams=4)[0]

    response = tokenizer.decode(output_ids, skip_special_tokens=True)
    history.append(response)

    # Pair utterances so the response (always last) lands in the second slot.
    # With an odd count the first utterance has no partner and is shown on
    # its own in the second slot of a leading pair.
    offset = len(history) % 2
    display_utterances = [("", history[0])] if offset else []
    display_utterances.extend(zip(history[offset::2], history[offset + 1::2]))

    return display_utterances, history


# Expose chat_func through a Gradio interface. The "state" input/output pair
# carries the history list returned by one call into the next call.
demo = gr.Interface(
    fn=chat_func,
    inputs=[
        gr.Textbox(lines=1, placeholder="Input current utterance"),
        "state",
    ],
    outputs=["chatbot", "state"],
    examples=examples,
    title=title,
    description=description,
)


# Start the web app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()