nicholasKluge
commited on
Commit
•
8d1df0b
1
Parent(s):
4237769
Update app.py
Browse files
app.py
CHANGED
@@ -1,22 +1,31 @@
|
|
1 |
import time
|
2 |
import torch
|
3 |
import gradio as gr
|
4 |
-
from transformers import AutoTokenizer, AutoModelForCausalLM
|
5 |
|
6 |
model_id = "nicholasKluge/Aira-Instruct-PT-560M"
|
|
|
|
|
7 |
token = "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"
|
8 |
|
9 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
10 |
|
|
|
|
|
|
|
11 |
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
else:
|
16 |
-
model = AutoModelForCausalLM.from_pretrained(model_id, use_auth_token=token)
|
17 |
|
18 |
-
tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=token)
|
19 |
model.to(device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
|
21 |
intro = """
|
22 |
## O que é `Aira`?
|
@@ -25,11 +34,19 @@ intro = """
|
|
25 |
|
26 |
Desenvolvemos os nossos chatbots de conversação de domínio aberto através da geração de texto condicional/ajuste fino por instruções. Esta abordagem tem muitas limitações. Apesar de podermos criar um chatbot capaz de responder a perguntas sobre qualquer assunto, é difícil forçar o modelo a produzir respostas de boa qualidade. E por boa, queremos dizer texto **factual** e **não tóxico**. Isto leva-nos a dois dos problemas mais comuns quando lidando com modelos generativos utilizados em aplicações de conversação:
|
27 |
|
|
|
|
|
28 |
🤥 Modelos generativos podem perpetuar a geração de conteúdo pseudo-informativo, ou seja, informações falsas que podem parecer verdadeiras.
|
29 |
|
30 |
🤬 Em certos tipos de tarefas, modelos generativos podem produzir conteúdo prejudicial e discriminatório inspirado em estereótipos históricos.
|
31 |
-
|
|
|
|
|
32 |
`Aira` destina-se apenas à investigação académica. Para mais informações, visite o nosso [HuggingFace models](https://huggingface.co/nicholasKluge) para ver como desenvolvemos `Aira`.
|
|
|
|
|
|
|
|
|
33 |
"""
|
34 |
|
35 |
disclaimer = """
|
@@ -39,19 +56,20 @@ Se desejar apresentar uma reclamação sobre qualquer mensagem produzida por `Ai
|
|
39 |
"""
|
40 |
|
41 |
with gr.Blocks(theme='freddyaboulton/dracula_revamped') as demo:
|
42 |
-
|
43 |
-
gr.Markdown("""<h1><center>Aira Demo
|
44 |
gr.Markdown(intro)
|
45 |
-
|
46 |
chatbot = gr.Chatbot(label="Aira").style(height=500)
|
|
|
47 |
|
48 |
-
with gr.Accordion(label="Parâmetros ⚙️", open=
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
|
56 |
clear = gr.Button("Limpar Conversa 🧹")
|
57 |
gr.Markdown(disclaimer)
|
@@ -59,23 +77,66 @@ with gr.Blocks(theme='freddyaboulton/dracula_revamped') as demo:
|
|
59 |
def user(user_message, chat_history):
|
60 |
return gr.update(value=user_message, interactive=True), chat_history + [["👤 " + user_message, None]]
|
61 |
|
62 |
-
def generate_response(user_msg, top_p, temperature, top_k, max_length, chat_history):
|
63 |
|
64 |
-
inputs = tokenizer(tokenizer.bos_token + user_msg + tokenizer.eos_token, return_tensors="pt").to(device)
|
65 |
|
66 |
generated_response = model.generate(**inputs,
|
67 |
bos_token_id=tokenizer.bos_token_id,
|
68 |
pad_token_id=tokenizer.pad_token_id,
|
69 |
eos_token_id=tokenizer.eos_token_id,
|
|
|
70 |
do_sample=True,
|
71 |
-
early_stopping=True,
|
72 |
-
top_k=top_k,
|
73 |
max_length=max_length,
|
74 |
top_p=top_p,
|
75 |
-
temperature=temperature,
|
76 |
-
num_return_sequences=
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
77 |
|
78 |
-
|
|
|
79 |
|
80 |
chat_history[-1][1] = "🤖 "
|
81 |
for character in bot_message:
|
@@ -84,10 +145,10 @@ with gr.Blocks(theme='freddyaboulton/dracula_revamped') as demo:
|
|
84 |
yield chat_history
|
85 |
|
86 |
response = msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
|
87 |
-
generate_response, [msg, top_p, temperature, top_k, max_length, chatbot], chatbot
|
88 |
)
|
89 |
response.then(lambda: gr.update(interactive=True), None, [msg], queue=False)
|
90 |
-
msg.submit(lambda x: gr.update(value=''),
|
91 |
clear.click(lambda: None, None, chatbot, queue=False)
|
92 |
|
93 |
demo.queue()
|
|
|
1 |
import time
|
2 |
import torch
|
3 |
import gradio as gr
|
4 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSequenceClassification
|
5 |
|
6 |
model_id = "nicholasKluge/Aira-Instruct-PT-560M"
|
7 |
+
rewardmodel_id = "nicholasKluge/RewardModelPT"
|
8 |
+
toxicitymodel_id = "nicholasKluge/ToxicityModelPT"
|
9 |
token = "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"
|
10 |
|
11 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
12 |
|
13 |
+
model = AutoModelForCausalLM.from_pretrained(model_id, use_auth_token=token)
|
14 |
+
rewardModel = AutoModelForSequenceClassification.from_pretrained(rewardmodel_id, use_auth_token=token)
|
15 |
+
toxicityModel = AutoModelForSequenceClassification.from_pretrained(toxicitymodel_id, use_auth_token=token)
|
16 |
|
17 |
+
model.eval()
|
18 |
+
rewardModel.eval()
|
19 |
+
toxicityModel.eval()
|
|
|
|
|
20 |
|
|
|
21 |
model.to(device)
|
22 |
+
rewardModel.to(device)
|
23 |
+
toxicityModel.to(device)
|
24 |
+
|
25 |
+
tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=token)
|
26 |
+
rewardTokenizer = AutoTokenizer.from_pretrained(rewardmodel_id, use_auth_token=token)
|
27 |
+
toxiciyTokenizer = AutoTokenizer.from_pretrained(toxicitymodel_id, use_auth_token=token)
|
28 |
+
|
29 |
|
30 |
intro = """
|
31 |
## O que é `Aira`?
|
|
|
34 |
|
35 |
Desenvolvemos os nossos chatbots de conversação de domínio aberto através da geração de texto condicional/ajuste fino por instruções. Esta abordagem tem muitas limitações. Apesar de podermos criar um chatbot capaz de responder a perguntas sobre qualquer assunto, é difícil forçar o modelo a produzir respostas de boa qualidade. E por boa, queremos dizer texto **factual** e **não tóxico**. Isto leva-nos a dois dos problemas mais comuns quando lidando com modelos generativos utilizados em aplicações de conversação:
|
36 |
|
37 |
+
## Limitações
|
38 |
+
|
39 |
🤥 Modelos generativos podem perpetuar a geração de conteúdo pseudo-informativo, ou seja, informações falsas que podem parecer verdadeiras.
|
40 |
|
41 |
🤬 Em certos tipos de tarefas, modelos generativos podem produzir conteúdo prejudicial e discriminatório inspirado em estereótipos históricos.
|
42 |
+
|
43 |
+
## Uso Intendido
|
44 |
+
|
45 |
`Aira` destina-se apenas à investigação académica. Para mais informações, visite o nosso [HuggingFace models](https://huggingface.co/nicholasKluge) para ver como desenvolvemos `Aira`.
|
46 |
+
|
47 |
+
## Como essa demo funciona?
|
48 |
+
|
49 |
+
Esta demonstração utiliza um [`modelo de recompensa`](https://huggingface.co/nicholasKluge/RewardModel) e um [`modelo de toxicidade`](https://huggingface.co/nicholasKluge/ToxicityModel) para avaliar a pontuação de cada resposta candidata, considerando o seu alinhamento com a mensagem do utilizador e o seu nível de toxicidade. A função de geração organiza as respostas candidatas por ordem da sua pontuação de recompensa e elimina as respostas consideradas tóxicas ou nocivas. Posteriormente, a função de geração devolve a resposta candidata com a pontuação mais elevada que ultrapassa o limiar de segurança, ou uma mensagem pré-estabelecida se não forem identificados candidatos seguros.
|
50 |
"""
|
51 |
|
52 |
disclaimer = """
|
|
|
56 |
"""
|
57 |
|
58 |
with gr.Blocks(theme='freddyaboulton/dracula_revamped') as demo:
|
59 |
+
|
60 |
+
gr.Markdown("""<h1><center>Aira Demo 🤓💬</h1></center>""")
|
61 |
gr.Markdown(intro)
|
62 |
+
|
63 |
chatbot = gr.Chatbot(label="Aira").style(height=500)
|
64 |
+
msg = gr.Textbox(label="Write a question or comment to Aira ...", placeholder="Hi Aira, how are you?")
|
65 |
|
66 |
+
with gr.Accordion(label="Parâmetros ⚙️", open=True):
|
67 |
+
safety = gr.Radio(["On", "Off"], label="Proteção 🛡️", value="On", info="Ajuda a prevenir o modelo de gerar conteúdo tóxico.")
|
68 |
+
top_k = gr.Slider(minimum=10, maximum=100, value=50, step=5, interactive=True, label="Top-k", info="Controla o número de tokens de maior probabilidade a considerar em cada passo.")
|
69 |
+
top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.50, step=0.05, interactive=True, label="Top-p", info="Controla a probabilidade cumulativa dos tokens gerados.")
|
70 |
+
temperature = gr.Slider(minimum=0.1, maximum=2.0, value=0.1, step=0.1, interactive=True, label="Temperatura", info="Controla a aleatoriedade dos tokens gerados.")
|
71 |
+
max_length = gr.Slider(minimum=10, maximum=500, value=100, step=10, interactive=True, label="Comprimento Máximo", info="Controla o comprimento máximo do texto gerado.")
|
72 |
+
smaple_from = gr.Slider(minimum=2, maximum=10, value=2, step=1, interactive=True, label="Amostragem por Rejeição", info="Controla o número de gerações a partir das quais o modelo de recompensa irá selecionar.")
|
73 |
|
74 |
clear = gr.Button("Limpar Conversa 🧹")
|
75 |
gr.Markdown(disclaimer)
|
|
|
77 |
def user(user_message, chat_history):
|
78 |
return gr.update(value=user_message, interactive=True), chat_history + [["👤 " + user_message, None]]
|
79 |
|
80 |
+
def generate_response(user_msg, top_p, temperature, top_k, max_length, smaple_from, safety, chat_history):
|
81 |
|
82 |
+
inputs = tokenizer(tokenizer.bos_token + user_msg + tokenizer.eos_token, return_tensors="pt").to(model.device)
|
83 |
|
84 |
generated_response = model.generate(**inputs,
|
85 |
bos_token_id=tokenizer.bos_token_id,
|
86 |
pad_token_id=tokenizer.pad_token_id,
|
87 |
eos_token_id=tokenizer.eos_token_id,
|
88 |
+
repetition_penalty=1.8,
|
89 |
do_sample=True,
|
90 |
+
early_stopping=True,
|
91 |
+
top_k=top_k,
|
92 |
max_length=max_length,
|
93 |
top_p=top_p,
|
94 |
+
temperature=temperature,
|
95 |
+
num_return_sequences=smaple_from)
|
96 |
+
|
97 |
+
decoded_text = [tokenizer.decode(tokens, skip_special_tokens=True).replace(user_msg, "") for tokens in generated_response]
|
98 |
+
|
99 |
+
rewards = list()
|
100 |
+
toxicities = list()
|
101 |
+
|
102 |
+
for text in decoded_text:
|
103 |
+
reward_tokens = rewardTokenizer(user_msg, text,
|
104 |
+
truncation=True,
|
105 |
+
max_length=512,
|
106 |
+
return_token_type_ids=False,
|
107 |
+
return_tensors="pt",
|
108 |
+
return_attention_mask=True)
|
109 |
+
|
110 |
+
reward_tokens.to(rewardModel.device)
|
111 |
+
|
112 |
+
reward = rewardModel(**reward_tokens)[0].item()
|
113 |
+
|
114 |
+
toxicity_tokens = toxiciyTokenizer(user_msg + " " + text,
|
115 |
+
truncation=True,
|
116 |
+
max_length=512,
|
117 |
+
return_token_type_ids=False,
|
118 |
+
return_tensors="pt",
|
119 |
+
return_attention_mask=True)
|
120 |
+
|
121 |
+
toxicity_tokens.to(toxicityModel.device)
|
122 |
+
|
123 |
+
toxicity = toxicityModel(**toxicity_tokens)[0].item()
|
124 |
+
|
125 |
+
rewards.append(reward)
|
126 |
+
toxicities.append(toxicity)
|
127 |
+
|
128 |
+
toxicity_threshold = 5
|
129 |
+
|
130 |
+
ordered_generations = sorted(zip(decoded_text, rewards, toxicities), key=lambda x: x[1], reverse=True)
|
131 |
+
|
132 |
+
if safety == "On":
|
133 |
+
ordered_generations = [(x, y, z) for (x, y, z) in ordered_generations if z >= toxicity_threshold]
|
134 |
+
|
135 |
+
if len(ordered_generations) == 0:
|
136 |
+
bot_message = """Peço desculpa pelo incómodo, mas parece que não foi possível identificar respostas adequadas que cumpram as nossas normas de segurança. Infelizmente, isto indica que o conteúdo gerado pode conter elementos de toxicidade ou pode não ajudar a responder à sua mensagem. A sua opinião é valiosa para nós e esforçamo-nos por garantir uma conversa segura e construtiva. Não hesite em fornecer mais pormenores ou colocar quaisquer outras questões, e farei o meu melhor para o ajudar."""
|
137 |
|
138 |
+
else:
|
139 |
+
bot_message = ordered_generations[0][0]
|
140 |
|
141 |
chat_history[-1][1] = "🤖 "
|
142 |
for character in bot_message:
|
|
|
145 |
yield chat_history
|
146 |
|
147 |
response = msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
|
148 |
+
generate_response, [msg, top_p, temperature, top_k, max_length, smaple_from, safety, chatbot], chatbot
|
149 |
)
|
150 |
response.then(lambda: gr.update(interactive=True), None, [msg], queue=False)
|
151 |
+
msg.submit(lambda x: gr.update(value=''), None,[msg])
|
152 |
clear.click(lambda: None, None, chatbot, queue=False)
|
153 |
|
154 |
demo.queue()
|