Upload 4 files
Browse files- Fine-Tuning DialoGPT-Medium on Daily Dialog Dataset.ipynb +439 -0
- Test_saved_model.ipynb +101 -0
- dataset_format.xlsx +0 -0
- requirements.txt +3 -0
Fine-Tuning DialoGPT-Medium on Daily Dialog Dataset.ipynb
ADDED
@@ -0,0 +1,439 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 1,
|
6 |
+
"metadata": {},
|
7 |
+
"outputs": [],
|
8 |
+
"source": [
|
9 |
+
"import numpy as np\n",
|
10 |
+
"from datasets import load_dataset\n",
|
11 |
+
"from transformers import GPT2Tokenizer, GPT2LMHeadModel, TrainingArguments, Trainer"
|
12 |
+
]
|
13 |
+
},
|
14 |
+
{
|
15 |
+
"cell_type": "code",
|
16 |
+
"execution_count": 2,
|
17 |
+
"metadata": {},
|
18 |
+
"outputs": [],
|
19 |
+
"source": [
|
20 |
+
"# Load the DailyDialog dataset\n",
|
21 |
+
"dataset = load_dataset('daily_dialog')\n",
|
22 |
+
"\n",
|
23 |
+
"# Concatenate all utterances within a dialogue and map to 'dialog' key\n",
|
24 |
+
"def concatenate_utterances(example):\n",
|
25 |
+
" example['dialog'] = \" \".join(example['dialog'])\n",
|
26 |
+
" return example\n",
|
27 |
+
"\n",
|
28 |
+
"# Apply the function to all examples in the dataset\n",
|
29 |
+
"dataset = dataset.map(concatenate_utterances)"
|
30 |
+
]
|
31 |
+
},
|
32 |
+
{
|
33 |
+
"cell_type": "code",
|
34 |
+
"execution_count": 3,
|
35 |
+
"metadata": {},
|
36 |
+
"outputs": [],
|
37 |
+
"source": [
|
38 |
+
"# Load the tokenizer and model\n",
|
39 |
+
"tokenizer = GPT2Tokenizer.from_pretrained('microsoft/DialoGPT-medium')\n",
|
40 |
+
"tokenizer.pad_token = tokenizer.eos_token\n",
|
41 |
+
"model = GPT2LMHeadModel.from_pretrained('microsoft/DialoGPT-medium')"
|
42 |
+
]
|
43 |
+
},
|
44 |
+
{
|
45 |
+
"cell_type": "code",
|
46 |
+
"execution_count": 4,
|
47 |
+
"metadata": {},
|
48 |
+
"outputs": [
|
49 |
+
{
|
50 |
+
"data": {
|
51 |
+
"application/vnd.jupyter.widget-view+json": {
|
52 |
+
"model_id": "5d576321ac974a118f75b83cd8437256",
|
53 |
+
"version_major": 2,
|
54 |
+
"version_minor": 0
|
55 |
+
},
|
56 |
+
"text/plain": [
|
57 |
+
"Map: 0%| | 0/1000 [00:00<?, ? examples/s]"
|
58 |
+
]
|
59 |
+
},
|
60 |
+
"metadata": {},
|
61 |
+
"output_type": "display_data"
|
62 |
+
},
|
63 |
+
{
|
64 |
+
"data": {
|
65 |
+
"application/vnd.jupyter.widget-view+json": {
|
66 |
+
"model_id": "8e7254605abe41dbad8d6b2321d904c5",
|
67 |
+
"version_major": 2,
|
68 |
+
"version_minor": 0
|
69 |
+
},
|
70 |
+
"text/plain": [
|
71 |
+
"Map: 0%| | 0/1000 [00:00<?, ? examples/s]"
|
72 |
+
]
|
73 |
+
},
|
74 |
+
"metadata": {},
|
75 |
+
"output_type": "display_data"
|
76 |
+
}
|
77 |
+
],
|
78 |
+
"source": [
|
79 |
+
"# Encode the dataset\n",
|
80 |
+
"def encode(examples):\n",
|
81 |
+
" encoded = tokenizer(examples['dialog'], truncation=True, padding='max_length', max_length=128)\n",
|
82 |
+
" encoded['labels'] = encoded['input_ids'][:]\n",
|
83 |
+
" return encoded\n",
|
84 |
+
"\n",
|
85 |
+
"encoded_dataset = dataset.map(encode, batched=True)"
|
86 |
+
]
|
87 |
+
},
|
88 |
+
{
|
89 |
+
"cell_type": "code",
|
90 |
+
"execution_count": 5,
|
91 |
+
"metadata": {},
|
92 |
+
"outputs": [],
|
93 |
+
"source": [
|
94 |
+
"# Define training arguments\n",
|
95 |
+
"training_args = TrainingArguments(\n",
|
96 |
+
" output_dir='model', # output directory\n",
|
97 |
+
" num_train_epochs=2, # total number of training epochs\n",
|
98 |
+
" per_device_train_batch_size=64, # batch size per device during training\n",
|
99 |
+
" per_device_eval_batch_size=64, # batch size for evaluation\n",
|
100 |
+
" warmup_steps=500, # number of warmup steps for learning rate scheduler\n",
|
101 |
+
" weight_decay=0.01, # strength of weight decay\n",
|
102 |
+
" logging_dir=None, # directory for storing logs\n",
|
103 |
+
")\n",
|
104 |
+
"\n",
|
105 |
+
"# Create Trainer\n",
|
106 |
+
"trainer = Trainer(\n",
|
107 |
+
" model=model,\n",
|
108 |
+
" args=training_args,\n",
|
109 |
+
" train_dataset=encoded_dataset['train'],\n",
|
110 |
+
" eval_dataset=encoded_dataset['validation']\n",
|
111 |
+
")"
|
112 |
+
]
|
113 |
+
},
|
114 |
+
{
|
115 |
+
"cell_type": "code",
|
116 |
+
"execution_count": 6,
|
117 |
+
"metadata": {},
|
118 |
+
"outputs": [
|
119 |
+
{
|
120 |
+
"data": {
|
121 |
+
"application/vnd.jupyter.widget-view+json": {
|
122 |
+
"model_id": "2a46f1165bba4ef8b597eace733e9eaf",
|
123 |
+
"version_major": 2,
|
124 |
+
"version_minor": 0
|
125 |
+
},
|
126 |
+
"text/plain": [
|
127 |
+
" 0%| | 0/16 [00:00<?, ?it/s]"
|
128 |
+
]
|
129 |
+
},
|
130 |
+
"metadata": {},
|
131 |
+
"output_type": "display_data"
|
132 |
+
},
|
133 |
+
{
|
134 |
+
"data": {
|
135 |
+
"application/vnd.jupyter.widget-view+json": {
|
136 |
+
"model_id": "1214b0347e7149c88d02d31dbc53f523",
|
137 |
+
"version_major": 2,
|
138 |
+
"version_minor": 0
|
139 |
+
},
|
140 |
+
"text/plain": [
|
141 |
+
" 0%| | 0/1 [00:00<?, ?it/s]"
|
142 |
+
]
|
143 |
+
},
|
144 |
+
"metadata": {},
|
145 |
+
"output_type": "display_data"
|
146 |
+
}
|
147 |
+
],
|
148 |
+
"source": [
|
149 |
+
"# Evaluate before fine-tuning\n",
|
150 |
+
"pre_eval_results = trainer.evaluate(encoded_dataset['validation'])\n",
|
151 |
+
"\n",
|
152 |
+
"# Get predictions for validation set before fine tuning for 10 samples\n",
|
153 |
+
"pre_val_predictions = trainer.predict(encoded_dataset['validation'].select(range(10)))"
|
154 |
+
]
|
155 |
+
},
|
156 |
+
{
|
157 |
+
"cell_type": "code",
|
158 |
+
"execution_count": 7,
|
159 |
+
"metadata": {},
|
160 |
+
"outputs": [
|
161 |
+
{
|
162 |
+
"data": {
|
163 |
+
"application/vnd.jupyter.widget-view+json": {
|
164 |
+
"model_id": "3f2c5fb0a67449a39c76d842dc4e6ed5",
|
165 |
+
"version_major": 2,
|
166 |
+
"version_minor": 0
|
167 |
+
},
|
168 |
+
"text/plain": [
|
169 |
+
" 0%| | 0/348 [00:00<?, ?it/s]"
|
170 |
+
]
|
171 |
+
},
|
172 |
+
"metadata": {},
|
173 |
+
"output_type": "display_data"
|
174 |
+
},
|
175 |
+
{
|
176 |
+
"name": "stdout",
|
177 |
+
"output_type": "stream",
|
178 |
+
"text": [
|
179 |
+
"{'train_runtime': 25354.0984, 'train_samples_per_second': 0.877, 'train_steps_per_second': 0.014, 'train_loss': 2.2603482651984557, 'epoch': 2.0}\n"
|
180 |
+
]
|
181 |
+
},
|
182 |
+
{
|
183 |
+
"data": {
|
184 |
+
"text/plain": [
|
185 |
+
"TrainOutput(global_step=348, training_loss=2.2603482651984557, metrics={'train_runtime': 25354.0984, 'train_samples_per_second': 0.877, 'train_steps_per_second': 0.014, 'train_loss': 2.2603482651984557, 'epoch': 2.0})"
|
186 |
+
]
|
187 |
+
},
|
188 |
+
"execution_count": 7,
|
189 |
+
"metadata": {},
|
190 |
+
"output_type": "execute_result"
|
191 |
+
}
|
192 |
+
],
|
193 |
+
"source": [
|
194 |
+
"# Fine-tune the model\n",
|
195 |
+
"trainer.train()"
|
196 |
+
]
|
197 |
+
},
|
198 |
+
{
|
199 |
+
"cell_type": "code",
|
200 |
+
"execution_count": 8,
|
201 |
+
"metadata": {},
|
202 |
+
"outputs": [
|
203 |
+
{
|
204 |
+
"data": {
|
205 |
+
"application/vnd.jupyter.widget-view+json": {
|
206 |
+
"model_id": "6a3612bcb25b4efcbb39d2c15d048b07",
|
207 |
+
"version_major": 2,
|
208 |
+
"version_minor": 0
|
209 |
+
},
|
210 |
+
"text/plain": [
|
211 |
+
" 0%| | 0/1 [00:00<?, ?it/s]"
|
212 |
+
]
|
213 |
+
},
|
214 |
+
"metadata": {},
|
215 |
+
"output_type": "display_data"
|
216 |
+
},
|
217 |
+
{
|
218 |
+
"data": {
|
219 |
+
"application/vnd.jupyter.widget-view+json": {
|
220 |
+
"model_id": "0e147b6f77194fb58a7f1869dbd46be0",
|
221 |
+
"version_major": 2,
|
222 |
+
"version_minor": 0
|
223 |
+
},
|
224 |
+
"text/plain": [
|
225 |
+
" 0%| | 0/16 [00:00<?, ?it/s]"
|
226 |
+
]
|
227 |
+
},
|
228 |
+
"metadata": {},
|
229 |
+
"output_type": "display_data"
|
230 |
+
},
|
231 |
+
{
|
232 |
+
"name": "stdout",
|
233 |
+
"output_type": "stream",
|
234 |
+
"text": [
|
235 |
+
"Evaluation Results before fine-tuning : 4.766543388366699\n",
|
236 |
+
"Evaluation Results after fine-tuning : 1.8690917491912842\n"
|
237 |
+
]
|
238 |
+
},
|
239 |
+
{
|
240 |
+
"data": {
|
241 |
+
"application/vnd.jupyter.widget-view+json": {
|
242 |
+
"model_id": "fbc53f3e0f2348928c5321f296dfb472",
|
243 |
+
"version_major": 2,
|
244 |
+
"version_minor": 0
|
245 |
+
},
|
246 |
+
"text/plain": [
|
247 |
+
" 0%| | 0/1 [00:00<?, ?it/s]"
|
248 |
+
]
|
249 |
+
},
|
250 |
+
"metadata": {},
|
251 |
+
"output_type": "display_data"
|
252 |
+
}
|
253 |
+
],
|
254 |
+
"source": [
|
255 |
+
"# Get predictions for validation set before fine tuning for 10 samples\n",
|
256 |
+
"pre_val_predictions = trainer.predict(encoded_dataset['validation'].select(range(10)))\n",
|
257 |
+
"\n",
|
258 |
+
"# Evaluate after fine-tuning\n",
|
259 |
+
"post_eval_results = trainer.evaluate(encoded_dataset['validation'])\n",
|
260 |
+
"\n",
|
261 |
+
"# Print the evaluation losses before and after fine-tuning\n",
|
262 |
+
"print('Evaluation Results before fine-tuning :', pre_eval_results['eval_loss'])\n",
|
263 |
+
"print('Evaluation Results after fine-tuning :', post_eval_results['eval_loss'])\n",
|
264 |
+
"\n",
|
265 |
+
"# Get predictions for validation set before fine tuning for 10 samples\n",
|
266 |
+
"post_val_predictions = trainer.predict(encoded_dataset['validation'].select(range(10)))\n",
|
267 |
+
"\n",
|
268 |
+
"# Zip the pre and post tuning predictions\n",
|
269 |
+
"predictions = zip(pre_val_predictions.predictions, post_val_predictions.predictions)"
|
270 |
+
]
|
271 |
+
},
|
272 |
+
{
|
273 |
+
"cell_type": "code",
|
274 |
+
"execution_count": 9,
|
275 |
+
"metadata": {},
|
276 |
+
"outputs": [
|
277 |
+
{
|
278 |
+
"name": "stdout",
|
279 |
+
"output_type": "stream",
|
280 |
+
"text": [
|
281 |
+
"Ground truth \n",
|
282 |
+
"Good morning , sir . Is there a bank near here ? There is one . 5 blocks away from here ? Well , that's too far.Can you change some money for me ? Surely , of course . What kind of currency have you got ? RIB . How much would you like to change ? 1000 Yuan.Here you are . \n",
|
283 |
+
"\n",
|
284 |
+
"Pre-prediction \n",
|
285 |
+
" and, sir. there anything problem here here? Yes is. in, away. here. Yes, I's a far.How you tell the money for me? Sure. sir course. Here's of money do you got? IIB. Here much is you like to exchange? I R.How you are. \n",
|
286 |
+
"\n",
|
287 |
+
"Post-prediction \n",
|
288 |
+
" and, sir. there anything problem here here? Yes is. in, away. here. Yes, I's a far.How you tell the money for me? Sure. sir course. Here's of money do you got? IIB. Here much is you like to exchange? I R.How you are. \n",
|
289 |
+
"\n",
|
290 |
+
"----------------------------------------------------------------------------------------------------------------------\n",
|
291 |
+
"\n",
|
292 |
+
"Ground truth \n",
|
293 |
+
"Good afternoon . This is Michelle Li speaking , calling on behalf of IBA . Is Mr Meng available at all ? This is Mr Meng speaking , Michelle . Oh , hello ! Sorry about that . I'm just calling to say that we've received your new Corporate Credit Card from HQ . That was quick ! I wasn't expecting it until later this week . Yes , our application procedures have speeded up since we started using the new fast-track system . Shall I come in and collect it ? Or we can send it to you . But if you would like to use it at the ATM , you'll need to wait for your PIN number . Mmmm ... if I come in and collect it this afternoon , is there any way I could use it today ? Petty cash is getting low , so I need to draw some money . As long as you bring your ID , etc , we can serve you over the counter . But you won't be able to use the ATM until your new PIN number arrives . I see . Yes , that's fine . I'll be there at around 2:30 pm . See you later , and thanks . \n",
|
294 |
+
"\n",
|
295 |
+
"Pre-prediction \n",
|
296 |
+
" and, I is Mr... the for you of theKE. there. here? the? Yes is Mr Meng.. calling Li I I, I. I, the. I'm sorry a to see that I have got your letter contract Account Card. China. Oh's fast. Thank'm't expecting that to now. afternoon. I, I new was are beened up. then received working it new card cardtrack system. I we call in now check it? Yes you can go it to you. you you don like to come it,\n",
|
297 |
+
"\n",
|
298 |
+
"Post-prediction \n",
|
299 |
+
" and, I is Mr... the for you of theKE. there. here? the? Yes is Mr Meng.. calling Li I I, I. I, the. I'm sorry a to see that I have got your letter contract Account Card. China. Oh's fast. Thank'm't expecting that to now. afternoon. I, I new was are beened up. then received working it new card cardtrack system. I we call in now check it? Yes you can go it to you. you you don like to come it,\n",
|
300 |
+
"\n",
|
301 |
+
"----------------------------------------------------------------------------------------------------------------------\n",
|
302 |
+
"\n",
|
303 |
+
"Ground truth \n",
|
304 |
+
"What qualifications should a reporter have ? As a reporter , he must have acute insight and language skills . At the same time , he must have good judgment , the respect for his job and tactical cooperation with others . Can you work under pressure ? You know , people working here are all busy everyday since we're daily newspaper . I think I've got used to work under pressure . I will adjust myself to the step of your newspaper quickly . \n",
|
305 |
+
"\n",
|
306 |
+
"Pre-prediction \n",
|
307 |
+
" and do I person be? I a reporter, I should be a knowledge and knowledge skills. least same time, he must be a knowledge. and ability of the colleagues, the ability. the. What you tell for pressure? must, the are under are always very.. the have in news. Yes know so can got the to it under pressure. I'm try my. the new of pressure question.. \n",
|
308 |
+
"\n",
|
309 |
+
"Post-prediction \n",
|
310 |
+
" and do I person be? I a reporter, I should be a knowledge and knowledge skills. least same time, he must be a knowledge. and ability of the colleagues, the ability. the. What you tell for pressure? must, the are under are always very.. the have in news. Yes know so can got the to it under pressure. I'm try my. the new of pressure question.. \n",
|
311 |
+
"\n",
|
312 |
+
"----------------------------------------------------------------------------------------------------------------------\n",
|
313 |
+
"\n",
|
314 |
+
"Ground truth \n",
|
315 |
+
"Hi , good morning , Miss ? what can I help you with ? Good morning I'd like to mail this box of books to Taiwan . OK , please put it on this scale.Airmail or by sea ? How long does it take to reach Taiwan by sea ? Usually about two month . That's too long.How long does it take to reach Taiwan by airmail ? About ten days . Then how much is that by airmail ? Let me see.It ' s 57 dollars , 20 cents , including tax . That's a little bit expensive . Although it's expensive to send it by airmail , it's quicker and safer than by sea . I guess I have to send it by airmail . Do you want to ensure the contents , Miss ? Yes , please . Please fill out this form , also please write the value of the items in this space . OK . \n",
|
316 |
+
"\n",
|
317 |
+
"Pre-prediction \n",
|
318 |
+
" and I morning. I. can I do you with? I morning,'m like to buy some letter to envelop to my. I, I send it in the shelf.How fewail. express air. By about will it take to ship China? sea? About it ten weeks. OK's not long.How much does it take to get China by air?? About three days. That I long is it? seamail? About me see.It's s about yuan. including dollars per and postage. That's fine little \n",
|
319 |
+
"\n",
|
320 |
+
"Post-prediction \n",
|
321 |
+
" and I morning. I. can I do you with? I morning,'m like to buy some letter to envelop to my. I, I send it in the shelf.How fewail. express air. By about will it take to ship China? sea? About it ten weeks. OK's not long.How much does it take to get China by air?? About three days. That I long is it? seamail? About me see.It's s about yuan. including dollars per and postage. That's fine little \n",
|
322 |
+
"\n",
|
323 |
+
"----------------------------------------------------------------------------------------------------------------------\n",
|
324 |
+
"\n",
|
325 |
+
"Ground truth \n",
|
326 |
+
"Excuse me , ma'am . Can you tell me where the nearest postoffice is ? Of course . Go straight ahead . Turn right at the next street . You'll see a tall , yellow building.The post office is on the first floor . Do you mean that I go that way for one block , then turn right ? Yes , you are right . Is it far ? No , It's only about five minutes ' walk . Thank you very much . It's my pleasure . \n",
|
327 |
+
"\n",
|
328 |
+
"Pre-prediction \n",
|
329 |
+
" and me, sir'am. I I tell me where the bus train office is? Yes course. It to to. right at the first intersection. 'll see it post building white building.It post office is on the right floor. Thank you know the the can straight way? the post? or turn right at Yes. that go right. The Thank there a from No, it's not a a blocks walk walk. I you. much. You's a pleasure. \n",
|
330 |
+
"\n",
|
331 |
+
"Post-prediction \n",
|
332 |
+
" and me, sir'am. I I tell me where the bus train office is? Yes course. It to to. right at the first intersection. 'll see it post building white building.It post office is on the right floor. Thank you know the the can straight way? the post? or turn right at Yes. that go right. The Thank there a from No, it's not a a blocks walk walk. I you. much. You's a pleasure. \n",
|
333 |
+
"\n",
|
334 |
+
"----------------------------------------------------------------------------------------------------------------------\n",
|
335 |
+
"\n",
|
336 |
+
"Ground truth \n",
|
337 |
+
"Could you give me some advice on how to bring up my son properly ? He's a bright boy , isn't he ? But he always wimps out of difficulty . Don't worry , he'll make good progress step by step . \n",
|
338 |
+
"\n",
|
339 |
+
"Pre-prediction \n",
|
340 |
+
" and tell me a advice? how to get my my resume?? You's a good young. isn't he? Yes he's getsagsps on of bed. I't worry, he'll get up use. by step. \n",
|
341 |
+
"\n",
|
342 |
+
"Post-prediction \n",
|
343 |
+
" and tell me a advice? how to get my my resume?? You's a good young. isn't he? Yes he's getsagsps on of bed. I't worry, he'll get up use. by step. \n",
|
344 |
+
"\n",
|
345 |
+
"----------------------------------------------------------------------------------------------------------------------\n",
|
346 |
+
"\n",
|
347 |
+
"Ground truth \n",
|
348 |
+
"I'm in 507 . I have a few problems with my room . What is that problem , sir ? There are cockroaches in my room . Are you sure , sir ? Flies I could believe , but cockroaches ? I've counted nine different cockroaches , and I accidentally stepped on another one . Sir , we run a spotless and cockroach-less hotel . You dare to doubt me ? I'm sorry , sir . Let me transfer you to my supervisor . \n",
|
349 |
+
"\n",
|
350 |
+
"Pre-prediction \n",
|
351 |
+
" and sorry the8.'m a friend questions. my car. What are the?? Mr? I's aroaches in my room. What you sure you sir? ies are see see. but cockroaches? Yes'm never them cock cockroaches in sir I've counted on one one. I, I have a bug check room cleanroach freefree room. What mean to step the? I'm not, sir. I me check you to the room. \n",
|
352 |
+
"\n",
|
353 |
+
"Post-prediction \n",
|
354 |
+
" and sorry the8.'m a friend questions. my car. What are the?? Mr? I's aroaches in my room. What you sure you sir? ies are see see. but cockroaches? Yes'm never them cock cockroaches in sir I've counted on one one. I, I have a bug check room cleanroach freefree room. What mean to step the? I'm not, sir. I me check you to the room. \n",
|
355 |
+
"\n",
|
356 |
+
"----------------------------------------------------------------------------------------------------------------------\n",
|
357 |
+
"\n",
|
358 |
+
"Ground truth \n",
|
359 |
+
"Excuse me , sir , I'm afraid you can't park your car here . Why not ? It's my parking space . I'm afraid not , sir . Oh ? That's a surprise . Let me see ... D 0411 Our dog's birthday . Yes , I'm sure this my parking space ! But I saw a red car always parking here before . Oh , we've just repainted our car . It was red . Maybe . But the car of this space has a broken rearview mirror on the left . Yeah . It used to . We got that fixed yesterday too . Could you wait for a minute , sir ? I'd like to have a check . Sure , go ahead . Sorry , sir , my mistake . This is your parking space . That's all right . It's not your fault . \n",
|
360 |
+
"\n",
|
361 |
+
"Pre-prediction \n",
|
362 |
+
" and me, sir. can'm afraid I have't have here car here. I not?'s a car space. I'm sorry I. sir. I, What's a shame. me see. ang., car is name is , it'm afraid you is car space. Oh I'm a car car parked parked there.. I, I have got gotainted it car.'s a before Oh you But it car is the car is a red window window mirror. the left side Oh, I's to be've it\n",
|
363 |
+
"\n",
|
364 |
+
"Post-prediction \n",
|
365 |
+
" and me, sir. can'm afraid I have't have here car here. I not?'s a car space. I'm sorry I. sir. I, What's a shame. me see. ang., car is name is , it'm afraid you is car space. Oh I'm a car car parked parked there.. I, I have got gotainted it car.'s a before Oh you But it car is the car is a red window window mirror. the left side Oh, I's to be've it\n",
|
366 |
+
"\n",
|
367 |
+
"----------------------------------------------------------------------------------------------------------------------\n",
|
368 |
+
"\n",
|
369 |
+
"Ground truth \n",
|
370 |
+
"What can I do for you today ? I need to buy a new refrigerator today . Were you looking at a particular refrigerator ? I like that Kenmore refrigerator . This particular refrigerator is a very good choice . Tell me about it . Not only is it affordable , but it comes with all the appliances . What are the appliances . It has an ice maker , water dispenser , and plenty of room on the inside . I'd like to see it for myself . Go right ahead . I like what I see . \n",
|
371 |
+
"\n",
|
372 |
+
"Pre-prediction \n",
|
373 |
+
" and I do for you?? I'd to get a new car.. What you able at the new brand? Yes was the onemore refrigerator. What one refrigerator is a Ken popular refrigerator. I me more it. It only is it a, it it's with a the appliances. I appliances the appliances? They comes a electric maker, a heaterer, and a of other for the bottom. What like like to buy it. a. It ahead ahead. I'll it I see. \n",
|
374 |
+
"\n",
|
375 |
+
"Post-prediction \n",
|
376 |
+
" and I do for you?? I'd to get a new car.. What you able at the new brand? Yes was the onemore refrigerator. What one refrigerator is a Ken popular refrigerator. I me more it. It only is it a, it it's with a the appliances. I appliances the appliances? They comes a electric maker, a heaterer, and a of other for the bottom. What like like to buy it. a. It ahead ahead. I'll it I see. \n",
|
377 |
+
"\n",
|
378 |
+
"----------------------------------------------------------------------------------------------------------------------\n",
|
379 |
+
"\n",
|
380 |
+
"Ground truth \n",
|
381 |
+
"Oh , well . It was fun to be the winner . But ... it's too big . I must be an extra small in the States . So what about the tennis racket ? Look ! It's amazing . I can't wait to try it out ! How much did that end up costing you ? Oh ... around twenty bucks . A bargain if you ask me . Look at the picture of her playing with it ! Hey , two for one . That's a super deal . And here's her signature ! \n",
|
382 |
+
"\n",
|
383 |
+
"Pre-prediction \n",
|
384 |
+
" and I, I's nice. meet there first. I I was not late. I'm go the old.. the company. You,? the other court? It, It's a! It can't believe to use it.. I much is you cost up costing?? It, it ten dollars. little! you ask me. , the racket. the.. it. Oh, I dollars twenty!'s a bargain bargain! I you I another purse. \n",
|
385 |
+
"\n",
|
386 |
+
"Post-prediction \n",
|
387 |
+
" and I, I's nice. meet there first. I I was not late. I'm go the old.. the company. You,? the other court? It, It's a! It can't believe to use it.. I much is you cost up costing?? It, it ten dollars. little! you ask me. , the racket. the.. it. Oh, I dollars twenty!'s a bargain bargain! I you I another purse. \n",
|
388 |
+
"\n",
|
389 |
+
"----------------------------------------------------------------------------------------------------------------------\n",
|
390 |
+
"\n"
|
391 |
+
]
|
392 |
+
}
|
393 |
+
],
|
394 |
+
"source": [
|
395 |
+
"for idx, (pre, post) in enumerate(predictions):\n",
|
396 |
+
" pre_pred = tokenizer.decode(np.argmax(pre, axis=-1), skip_special_tokens=True)\n",
|
397 |
+
" post_pred = tokenizer.decode(np.argmax(post, axis=-1), skip_special_tokens=True)\n",
|
398 |
+
" ground_truth = encoded_dataset['validation'][idx][\"dialog\"]\n",
|
399 |
+
" \n",
|
400 |
+
" print('Ground truth \\n' + ground_truth + '\\n')\n",
|
401 |
+
" print('Pre-prediction \\n' + pre_pred + '\\n')\n",
|
402 |
+
" print('Post-prediction \\n'+ post_pred + '\\n')\n",
|
403 |
+
" print('----------------------------------------------------------------------------------------------------------------------\\n')"
|
404 |
+
]
|
405 |
+
},
|
406 |
+
{
|
407 |
+
"cell_type": "code",
|
408 |
+
"execution_count": 13,
|
409 |
+
"metadata": {},
|
410 |
+
"outputs": [],
|
411 |
+
"source": [
|
412 |
+
"tokenizer.save_pretrained(\"saved_model\")\n",
|
413 |
+
"model.save_pretrained(\"saved_model\")"
|
414 |
+
]
|
415 |
+
}
|
416 |
+
],
|
417 |
+
"metadata": {
|
418 |
+
"kernelspec": {
|
419 |
+
"display_name": ".venv",
|
420 |
+
"language": "python",
|
421 |
+
"name": "python3"
|
422 |
+
},
|
423 |
+
"language_info": {
|
424 |
+
"codemirror_mode": {
|
425 |
+
"name": "ipython",
|
426 |
+
"version": 3
|
427 |
+
},
|
428 |
+
"file_extension": ".py",
|
429 |
+
"mimetype": "text/x-python",
|
430 |
+
"name": "python",
|
431 |
+
"nbconvert_exporter": "python",
|
432 |
+
"pygments_lexer": "ipython3",
|
433 |
+
"version": "3.11.5"
|
434 |
+
},
|
435 |
+
"orig_nbformat": 4
|
436 |
+
},
|
437 |
+
"nbformat": 4,
|
438 |
+
"nbformat_minor": 2
|
439 |
+
}
|
Test_saved_model.ipynb
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 1,
|
6 |
+
"metadata": {},
|
7 |
+
"outputs": [],
|
8 |
+
"source": [
|
9 |
+
"from transformers import pipeline, Conversation\n",
|
10 |
+
"from transformers import AutoTokenizer, AutoModelForCausalLM"
|
11 |
+
]
|
12 |
+
},
|
13 |
+
{
|
14 |
+
"cell_type": "code",
|
15 |
+
"execution_count": 2,
|
16 |
+
"metadata": {},
|
17 |
+
"outputs": [],
|
18 |
+
"source": [
|
19 |
+
"# Loading\n",
|
20 |
+
"tok = AutoTokenizer.from_pretrained(\"saved_model\")\n",
|
21 |
+
"mod = AutoModelForCausalLM.from_pretrained(\"saved_model\")\n",
|
22 |
+
"\n",
|
23 |
+
"chatbot = pipeline(\"conversational\", model = mod, tokenizer = tok)"
|
24 |
+
]
|
25 |
+
},
|
26 |
+
{
|
27 |
+
"cell_type": "code",
|
28 |
+
"execution_count": 3,
|
29 |
+
"metadata": {},
|
30 |
+
"outputs": [
|
31 |
+
{
|
32 |
+
"name": "stderr",
|
33 |
+
"output_type": "stream",
|
34 |
+
"text": [
|
35 |
+
"A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
|
36 |
+
"A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
|
37 |
+
"A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
|
38 |
+
"A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n"
|
39 |
+
]
|
40 |
+
},
|
41 |
+
{
|
42 |
+
"data": {
|
43 |
+
"text/plain": [
|
44 |
+
"Conversation id: 20e0c3eb-e549-4c61-96d5-831eb3af1933 \n",
|
45 |
+
"user >> Hello \n",
|
46 |
+
"bot >> Hi, I'm here to talk to you. \n",
|
47 |
+
"user >> How are you? \n",
|
48 |
+
"bot >> I'm fine. How are you? \n",
|
49 |
+
"user >> I'm good, do you want to watch a movie today? \n",
|
50 |
+
"bot >> Sure, I'll watch it. What movie? \n",
|
51 |
+
"user >> What about Lalaland? \n",
|
52 |
+
"bot >> That's a good one. I'll watch it. "
|
53 |
+
]
|
54 |
+
},
|
55 |
+
"execution_count": 3,
|
56 |
+
"metadata": {},
|
57 |
+
"output_type": "execute_result"
|
58 |
+
}
|
59 |
+
],
|
60 |
+
"source": [
|
61 |
+
"user_input = \"Hello\"\n",
|
62 |
+
"conversation = Conversation(user_input)\n",
|
63 |
+
"conversation = chatbot(conversation, pad_token_id=chatbot.tokenizer.eos_token_id)\n",
|
64 |
+
"reply = conversation.generated_responses\n",
|
65 |
+
"reply = reply[0].split(\" \")[0]\n",
|
66 |
+
"conversation.generated_responses = [reply]\n",
|
67 |
+
"\n",
|
68 |
+
"conversation.add_user_input(\"How are you?\")\n",
|
69 |
+
"conversation = chatbot(conversation, pad_token_id=chatbot.tokenizer.eos_token_id)\n",
|
70 |
+
"conversation.add_user_input(\"I'm good, do you want to watch a movie today?\")\n",
|
71 |
+
"conversation = chatbot(conversation, pad_token_id=chatbot.tokenizer.eos_token_id)\n",
|
72 |
+
"conversation.add_user_input(\"What about Lalaland?\")\n",
|
73 |
+
"conversation = chatbot(conversation, pad_token_id=chatbot.tokenizer.eos_token_id)\n",
|
74 |
+
"\n",
|
75 |
+
"conversation"
|
76 |
+
]
|
77 |
+
}
|
78 |
+
],
|
79 |
+
"metadata": {
|
80 |
+
"kernelspec": {
|
81 |
+
"display_name": ".venv",
|
82 |
+
"language": "python",
|
83 |
+
"name": "python3"
|
84 |
+
},
|
85 |
+
"language_info": {
|
86 |
+
"codemirror_mode": {
|
87 |
+
"name": "ipython",
|
88 |
+
"version": 3
|
89 |
+
},
|
90 |
+
"file_extension": ".py",
|
91 |
+
"mimetype": "text/x-python",
|
92 |
+
"name": "python",
|
93 |
+
"nbconvert_exporter": "python",
|
94 |
+
"pygments_lexer": "ipython3",
|
95 |
+
"version": "3.11.5"
|
96 |
+
},
|
97 |
+
"orig_nbformat": 4
|
98 |
+
},
|
99 |
+
"nbformat": 4,
|
100 |
+
"nbformat_minor": 2
|
101 |
+
}
|
dataset_format.xlsx
ADDED
Binary file (10.3 kB). View file
|
|
requirements.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
numpy == 1.25.2
|
2 |
+
datasets == 2.14.5
|
3 |
+
transformers == 4.33.1
|