Preechanon Chatthai commited on
Commit
2b76585
1 Parent(s): 64fda11

Upload 2 files

Browse files
Files changed (2) hide show
  1. Finetune.ipynb +263 -0
  2. requirements.txt +11 -0
Finetune.ipynb ADDED
@@ -0,0 +1,263 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "0e7385a4",
7
+ "metadata": {},
8
+ "outputs": [
9
+ {
10
+ "name": "stderr",
11
+ "output_type": "stream",
12
+ "text": [
13
+ "C:\\Users\\preec\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.12_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python312\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
14
+ " from .autonotebook import tqdm as notebook_tqdm\n"
15
+ ]
16
+ },
17
+ {
18
+ "data": {
19
+ "text/plain": [
20
+ "DatasetDict({\n",
21
+ " train: Dataset({\n",
22
+ " features: ['title', 'body', 'summary', 'type', 'tags', 'url'],\n",
23
+ " num_rows: 358868\n",
24
+ " })\n",
25
+ " validation: Dataset({\n",
26
+ " features: ['title', 'body', 'summary', 'type', 'tags', 'url'],\n",
27
+ " num_rows: 11000\n",
28
+ " })\n",
29
+ " test: Dataset({\n",
30
+ " features: ['title', 'body', 'summary', 'type', 'tags', 'url'],\n",
31
+ " num_rows: 11000\n",
32
+ " })\n",
33
+ "})"
34
+ ]
35
+ },
36
+ "execution_count": 1,
37
+ "metadata": {},
38
+ "output_type": "execute_result"
39
+ }
40
+ ],
41
+ "source": [
42
+ "from datasets import load_dataset\n",
43
+ "\n",
44
+ "ds = load_dataset(\"thaisum\")\n",
45
+ "ds"
46
+ ]
47
+ },
48
+ {
49
+ "cell_type": "code",
50
+ "execution_count": null,
51
+ "id": "337b3bc6",
52
+ "metadata": {},
53
+ "outputs": [],
54
+ "source": [
55
+ "from datasets import load_dataset\n",
56
+ "from datasets import DatasetDict \n",
57
+ "\n",
58
+ "dataset = load_dataset('csv', data_files='thaisum.csv')\n",
59
+ "ds_train_devtest = dataset['train'].train_test_split(test_size=0.05, seed=42)\n",
60
+ "ds_devtest = ds_train_devtest['test'].train_test_split(test_size=0.5, seed=42)\n",
61
+ "\n",
62
+ "\n",
63
+ "ds_thai_news = DatasetDict({\n",
64
+ " 'train': ds_train_devtest['train'],\n",
65
+ " 'valid': ds_devtest['train'],\n",
66
+ " 'test': ds_devtest['test']\n",
67
+ "})\n",
68
+ "ds_thai_news"
69
+ ]
70
+ },
71
+ {
72
+ "cell_type": "code",
73
+ "execution_count": 2,
74
+ "id": "286cbb13-5fff-4291-bdd7-3e4ddf972228",
75
+ "metadata": {},
76
+ "outputs": [
77
+ {
78
+ "name": "stderr",
79
+ "output_type": "stream",
80
+ "text": [
81
+ "C:\\Users\\preec\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.12_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python312\\site-packages\\transformers\\utils\\generic.py:441: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.\n",
82
+ " _torch_pytree._register_pytree_node(\n",
83
+ "C:\\Users\\preec\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.12_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python312\\site-packages\\transformers\\utils\\generic.py:309: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.\n",
84
+ " _torch_pytree._register_pytree_node(\n",
85
+ "C:\\Users\\preec\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.12_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python312\\site-packages\\transformers\\utils\\generic.py:309: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.\n",
86
+ " _torch_pytree._register_pytree_node(\n"
87
+ ]
88
+ }
89
+ ],
90
+ "source": [
91
+ "from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoConfig\n",
92
+ "import torch\n",
93
+ "\n",
94
+ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
95
+ "\n",
96
+ "mt5_config = AutoConfig.from_pretrained(\n",
97
+ " \"../mt5-base-thaisum-text-summarization\",\n",
98
+ " local_files_only=True,\n",
99
+ " max_length=140,\n",
100
+ " min_length=40,\n",
101
+ " length_penalty=1.2,\n",
102
+ " no_repeat_ngram_size=2,\n",
103
+ " num_beams=15,\n",
104
+ ")\n",
105
+ "\n",
106
+ "tokenizer = AutoTokenizer.from_pretrained(\"../mt5-base-thaisum-text-summarization\", local_files_only=True)\n",
107
+ "model = AutoModelForSeq2SeqLM.from_pretrained(\"../mt5-base-thaisum-text-summarization\", local_files_only=True).to(device)"
108
+ ]
109
+ },
110
+ {
111
+ "cell_type": "code",
112
+ "execution_count": 3,
113
+ "id": "ebfdf213",
114
+ "metadata": {},
115
+ "outputs": [
116
+ {
117
+ "name": "stderr",
118
+ "output_type": "stream",
119
+ "text": [
120
+ "Map: 100%|██████████| 11000/11000 [00:17<00:00, 622.18 examples/s]\n"
121
+ ]
122
+ }
123
+ ],
124
+ "source": [
125
+ "from transformers import DataCollatorForSeq2Seq\n",
126
+ "data_collator = DataCollatorForSeq2Seq(\n",
127
+ " tokenizer,\n",
128
+ " model=model,\n",
129
+ " return_tensors=\"pt\")\n",
130
+ "\n",
131
+ "def tokenize_data(data):\n",
132
+ "\n",
133
+ " input_feature = tokenizer(data[\"body\"], truncation=True, max_length=512)\n",
134
+ " label = tokenizer(data[\"summary\"], truncation=True, max_length=140)\n",
135
+ " return {\n",
136
+ " \"input_ids\": input_feature[\"input_ids\"],\n",
137
+ " \"attention_mask\": input_feature[\"attention_mask\"],\n",
138
+ " \"labels\": label[\"input_ids\"],\n",
139
+ " }\n",
140
+ "\n",
141
+ "token_ds_thai_news = ds.map(\n",
142
+ " tokenize_data,\n",
143
+ " remove_columns=['title', 'body', 'summary', 'type', 'tags', 'url'],\n",
144
+ " batched=True,\n",
145
+ " batch_size=64)"
146
+ ]
147
+ },
148
+ {
149
+ "cell_type": "code",
150
+ "execution_count": 4,
151
+ "id": "a01f4771",
152
+ "metadata": {},
153
+ "outputs": [],
154
+ "source": [
155
+ "import evaluate\n",
156
+ "import numpy as np\n",
157
+ "def tokenize_sentence(arg):\n",
158
+ " encoded_arg = tokenizer(arg)\n",
159
+ " return tokenizer.convert_ids_to_tokens(encoded_arg.input_ids)\n",
160
+ "\n",
161
+ "def metrics_func(eval_arg):\n",
162
+ " preds, labels = eval_arg\n",
163
+ " labels = np.where(labels != -100, labels, tokenizer.pad_token_id)\n",
164
+ " text_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)\n",
165
+ " text_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)\n",
166
+ "\n",
167
+ " return rouge_metric.compute(\n",
168
+ " predictions=text_preds,\n",
169
+ " references=text_labels,\n",
170
+ " tokenizer=tokenize_sentence\n",
171
+ " )\n",
172
+ "rouge_metric = evaluate.load(\"rouge\")"
173
+ ]
174
+ },
175
+ {
176
+ "cell_type": "code",
177
+ "execution_count": 5,
178
+ "id": "5d0f286b",
179
+ "metadata": {},
180
+ "outputs": [],
181
+ "source": [
182
+ "from transformers import Seq2SeqTrainingArguments\n",
183
+ "\n",
184
+ "training_args = Seq2SeqTrainingArguments(\n",
185
+ " output_dir = \"..\",\n",
186
+ " log_level = \"error\",\n",
187
+ " num_train_epochs = 6,\n",
188
+ " learning_rate = 5e-4,\n",
189
+ " warmup_steps = 5000,\n",
190
+ " weight_decay=0.01,\n",
191
+ " per_device_train_batch_size = 8,\n",
192
+ " per_device_eval_batch_size = 1,\n",
193
+ " gradient_accumulation_steps = 4,\n",
194
+ " evaluation_strategy = \"steps\",\n",
195
+ " eval_steps = 100,\n",
196
+ " predict_with_generate=True,\n",
197
+ " generation_max_length = 140,\n",
198
+ " save_steps = 3000,\n",
199
+ " logging_steps = 10,\n",
200
+ " push_to_hub = False,\n",
201
+ " remove_unused_columns=False\n",
202
+ ")\n"
203
+ ]
204
+ },
205
+ {
206
+ "cell_type": "code",
207
+ "execution_count": null,
208
+ "id": "33e02416",
209
+ "metadata": {},
210
+ "outputs": [],
211
+ "source": [
212
+ "from transformers import Seq2SeqTrainer\n",
213
+ "trainer = Seq2SeqTrainer(\n",
214
+ " model = model,\n",
215
+ " args = training_args,\n",
216
+ " data_collator = data_collator,\n",
217
+ " compute_metrics = metrics_func,\n",
218
+ " train_dataset = token_ds_thai_news[\"train\"],\n",
219
+ " eval_dataset = token_ds_thai_news[\"valid\"].select(range(30)),\n",
220
+ " tokenizer = tokenizer,\n",
221
+ ")"
222
+ ]
223
+ },
224
+ {
225
+ "cell_type": "code",
226
+ "execution_count": null,
227
+ "id": "1048d26c",
228
+ "metadata": {},
229
+ "outputs": [],
230
+ "source": [
231
+ "import os\n",
232
+ "from transformers import AutoModelForSeq2SeqLM\n",
233
+ "\n",
234
+ "os.makedirs(\"./trained_for_summarization\", exist_ok=True)\n",
235
+ "if hasattr(trainer.model, \"module\"):\n",
236
+ " trainer.model.module.save_pretrained(\"./trained_for_summarization\")\n",
237
+ "else:\n",
238
+ " trainer.model.save_pretrained(\"./trained_for_summarization\")"
239
+ ]
240
+ }
241
+ ],
242
+ "metadata": {
243
+ "kernelspec": {
244
+ "display_name": "Python 3 (ipykernel)",
245
+ "language": "python",
246
+ "name": "python3"
247
+ },
248
+ "language_info": {
249
+ "codemirror_mode": {
250
+ "name": "ipython",
251
+ "version": 3
252
+ },
253
+ "file_extension": ".py",
254
+ "mimetype": "text/x-python",
255
+ "name": "python",
256
+ "nbconvert_exporter": "python",
257
+ "pygments_lexer": "ipython3",
258
+ "version": "3.12.2"
259
+ }
260
+ },
261
+ "nbformat": 4,
262
+ "nbformat_minor": 5
263
+ }
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ transformers==4.36.1
2
+ numpy
3
+ datasets
4
+ nltk
5
+ pythainlp
6
+ rouge_score
7
+ evaluate
8
+ --index-url https://download.pytorch.org/whl/cu118
9
+ torch
10
+ torchvision
11
+ torchaudio --index-url https://download.pytorch.org/whl/cu118