DeFamy committed on
Commit e3650fc
1 parent: d2cc008

Upload train_model.ipynb

Files changed (1)
  1. train_model.ipynb +928 -0
train_model.ipynb ADDED
@@ -0,0 +1,928 @@
+ {
+ "cells": [
+ {
+ "cell_type": "code",
+ "source": [],
+ "metadata": {
+ "id": "b3_hlnrYh30E"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Acknowledgement:\n",
+ "Thanks to @RajKKapadia <br>\n",
+ "Link: https://github.com/RajKKapadia/Transformers-Text-Classification-BERT-Blog"
+ ],
+ "metadata": {
+ "id": "zK4VsKufh4gZ"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "XLhB2j_Hemio"
+ },
+ "source": [
+ "## Read the dataset csv file"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "hgYEtrYgemir",
+ "outputId": "d3ddedc7-8bd7-4ba9-c82e-68e4eb1309c3"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "<div>\n",
+ "<style scoped>\n",
+ " .dataframe tbody tr th:only-of-type {\n",
+ " vertical-align: middle;\n",
+ " }\n",
+ "\n",
+ " .dataframe tbody tr th {\n",
+ " vertical-align: top;\n",
+ " }\n",
+ "\n",
+ " .dataframe thead th {\n",
+ " text-align: right;\n",
+ " }\n",
+ "</style>\n",
+ "<table border=\"1\" class=\"dataframe\">\n",
+ " <thead>\n",
+ " <tr style=\"text-align: right;\">\n",
+ " <th></th>\n",
+ " <th>Unnamed: 0</th>\n",
+ " <th>Text</th>\n",
+ " <th>target</th>\n",
+ " </tr>\n",
+ " </thead>\n",
+ " <tbody>\n",
+ " <tr>\n",
+ " <th>0</th>\n",
+ " <td>0.0</td>\n",
+ " <td>polis tangkap</td>\n",
+ " <td>NonCyberbully</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>1</th>\n",
+ " <td>1.0</td>\n",
+ " <td>kenapa lokasi kebakaran terlalu spesifik</td>\n",
+ " <td>NonCyberbully</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>2</th>\n",
+ " <td>2.0</td>\n",
+ " <td>menyesal tanya nak for birthday</td>\n",
+ " <td>NonCyberbully</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>3</th>\n",
+ " <td>3.0</td>\n",
+ " <td>meriah tah</td>\n",
+ " <td>NonCyberbully</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>4</th>\n",
+ " <td>4.0</td>\n",
+ " <td>asal bs kelar kerja jam sik kl baru diajak mee...</td>\n",
+ " <td>NonCyberbully</td>\n",
+ " </tr>\n",
+ " </tbody>\n",
+ "</table>\n",
+ "</div>"
+ ],
+ "text/plain": [
+ " Unnamed: 0 Text \\\n",
+ "0 0.0 polis tangkap \n",
+ "1 1.0 kenapa lokasi kebakaran terlalu spesifik \n",
+ "2 2.0 menyesal tanya nak for birthday \n",
+ "3 3.0 meriah tah \n",
+ "4 4.0 asal bs kelar kerja jam sik kl baru diajak mee... \n",
+ "\n",
+ " target \n",
+ "0 NonCyberbully \n",
+ "1 NonCyberbully \n",
+ "2 NonCyberbully \n",
+ "3 NonCyberbully \n",
+ "4 NonCyberbully "
+ ]
+ },
+ "execution_count": 3,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "import pandas as pd\n",
+ "df = pd.read_csv('C:/Users/user/Documents/PSM/BERT_Ver2/Transformers-Text-Classification-BERT-Blog-main/input/Tagged_MixedNew.csv')\n",
+ "df.head()"
+ ]
+ },
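Note on the cell above: it reads from an absolute Windows path, and the `Unnamed: 0` column in the preview is a stray index written by an earlier `to_csv`. A minimal, more portable loading sketch (the relative path is a placeholder, not from this commit):

```python
import pandas as pd

# Placeholder relative path; point this at wherever Tagged_MixedNew.csv lives.
df = pd.read_csv('input/Tagged_MixedNew.csv', index_col=0)  # index_col=0 absorbs "Unnamed: 0"
df = df.dropna(subset=['Text', 'target'])                   # guard against empty rows
print(df['target'].value_counts())
```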
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "fGUtFkVfemit"
+ },
+ "source": [
+ "## Process the data"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "7C3uRWECemiu",
+ "outputId": "8e764d84-010d-4e42-987a-af7162627f6e",
+ "colab": {
+ "referenced_widgets": [
+ "042c8b0b8dcf42eb84660c93778d8ea7",
+ "4ab6074437a849f79be038b043025283",
+ "9aed4d88c18e4e28a1efbbed94331228"
+ ]
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "042c8b0b8dcf42eb84660c93778d8ea7",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Downloading (…)okenizer_config.json: 0%| | 0.00/380 [00:00<?, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "C:\\Users\\user\\anaconda3\\lib\\site-packages\\huggingface_hub\\file_download.py:133: UserWarning: `huggingface_hub` cache-system uses symlinks by default to efficiently store duplicated files but your machine does not support them in C:\\Users\\user\\.cache\\huggingface\\hub. Caching files will still work but in a degraded version that might require more space on your disk. This warning can be disabled by setting the `HF_HUB_DISABLE_SYMLINKS_WARNING` environment variable. For more details, see https://huggingface.co/docs/huggingface_hub/how-to-cache#limitations.\n",
+ "To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to see activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development\n",
+ " warnings.warn(message)\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "4ab6074437a849f79be038b043025283",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Downloading (…)solve/main/vocab.txt: 0%| | 0.00/233k [00:00<?, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "9aed4d88c18e4e28a1efbbed94331228",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Downloading (…)cial_tokens_map.json: 0%| | 0.00/125 [00:00<?, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "#from transformers import BertTokenizer\n",
+ "#tokenizer = BertTokenizer.from_pretrained('malay-huggingface/bert-tiny-bahasa-cased')\n",
+ "\n",
+ "from transformers import AutoTokenizer\n",
+ "tokenizer = AutoTokenizer.from_pretrained('mesolitica/bert-base-standard-bahasa-cased')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "Ks3XobW0emiu"
+ },
+ "outputs": [],
+ "source": [
+ "import numpy as np\n",
+ "from sklearn.model_selection import train_test_split\n",
+ "from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score\n",
+ "import torch\n",
+ "from transformers import TrainingArguments, Trainer\n",
+ "from transformers import BertTokenizer, BertForSequenceClassification"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "0ZZx6mUdemiv"
+ },
+ "outputs": [],
+ "source": [
+ "def process_data(row):\n",
+ "\n",
+ " text = row['Text']\n",
+ " text = str(text)\n",
+ " text = ' '.join(text.split())\n",
+ "\n",
+ " encodings = tokenizer(text, padding=\"max_length\", truncation=True, max_length=128)\n",
+ "\n",
+ " label = 0\n",
+ " if row['target'] == 'Cyberbully':\n",
+ " label += 1\n",
+ "\n",
+ " encodings['label'] = label\n",
+ " encodings['Text'] = text\n",
+ "\n",
+ " return encodings"
+ ]
+ },
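`process_data` normalizes whitespace, tokenizes to a fixed length of 128, and maps `Cyberbully` to label 1 and everything else to 0. A hedged sketch of the same mapping done in batch with `datasets.map`, reusing the notebook's `tokenizer` and `df` (an alternative phrasing, not part of the committed notebook):

```python
from datasets import Dataset

def encode_batch(batch):
    # Same whitespace cleanup and label rule as process_data, applied per batch.
    texts = [' '.join(str(t).split()) for t in batch['Text']]
    enc = tokenizer(texts, padding='max_length', truncation=True, max_length=128)
    enc['label'] = [1 if t == 'Cyberbully' else 0 for t in batch['target']]
    return enc

ds = Dataset.from_pandas(df).map(encode_batch, batched=True)
```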
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "MaFmqSc-emiv",
+ "outputId": "03eb6491-b646-45dd-ef3d-318c81313430"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'input_ids': [2, 2039, 3058, 9857, 1606, 1164, 2161, 8062, 1219, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'label': 0, 'Text': 'Saya suka masakan beliau dan cara penyampaiannya'}\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(process_data({\n",
+ " 'Text': 'Saya suka masakan beliau dan cara penyampaiannya',\n",
+ " 'target': 'NonCyberbully'\n",
+ "}))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "Lel-2lqKemiw"
+ },
+ "outputs": [],
+ "source": [
+ "processed_data = []\n",
+ "\n",
+ "for i in range(len(df[:1383])):\n",
+ " processed_data.append(process_data(df.iloc[i]))"
+ ]
+ },
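`len(df[:1383])` just evaluates to at most 1383, so the loop encodes the first 1383 rows. An equivalent, more direct phrasing (sketch only, not the committed code):

```python
# Same behavior as the loop above: encode the first 1383 rows.
processed_data = [process_data(row) for _, row in df.head(1383).iterrows()]
```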
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "x_DGsKzHemiw"
+ },
+ "source": [
+ "## Generate the dataset"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "oc_NsbnXemiw"
+ },
+ "outputs": [],
+ "source": [
+ "from sklearn.model_selection import train_test_split\n",
+ "\n",
+ "new_df = pd.DataFrame(processed_data)\n",
+ "\n",
+ "train_df, valid_df = train_test_split(\n",
+ " new_df,\n",
+ " test_size=0.2,\n",
+ " random_state=2022\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "4qSci5CRemix"
+ },
+ "outputs": [],
+ "source": [
+ "import pyarrow as pa\n",
+ "from datasets import Dataset\n",
+ "\n",
+ "train_hg = Dataset(pa.Table.from_pandas(train_df))\n",
+ "valid_hg = Dataset(pa.Table.from_pandas(valid_df))"
+ ]
+ },
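`Dataset.from_pandas` wraps the same `pa.Table.from_pandas` call, which is also where the `__index_level_0__` feature seen in the next output comes from. A sketch of that shorthand, plus an optional stratified split that keeps the Cyberbully/NonCyberbully ratio comparable across train and validation (neither change is in the commit):

```python
from datasets import Dataset
from sklearn.model_selection import train_test_split

# Optional: stratify on the label so both splits see the same class balance.
train_df, valid_df = train_test_split(
    new_df, test_size=0.2, random_state=2022, stratify=new_df['label']
)

# Shorthand for Dataset(pa.Table.from_pandas(...)).
train_hg = Dataset.from_pandas(train_df)
valid_hg = Dataset.from_pandas(valid_df)
```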
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "xDgnim7iemix",
+ "outputId": "59858161-59a4-4731-fbfc-7e30a1246eed"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "Dataset({\n",
+ " features: ['Text', 'attention_mask', 'input_ids', 'label', 'token_type_ids', '__index_level_0__'],\n",
+ " num_rows: 277\n",
+ "})"
+ ]
+ },
+ "execution_count": 12,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "valid_hg"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "8Uqq0cKKemiy"
+ },
+ "source": [
+ "## Create a model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "QQkDAXmRemiz",
+ "outputId": "e00faff0-c7d7-456d-dab2-73d9839c0274",
+ "colab": {
+ "referenced_widgets": [
+ "b9faad28a43547029c8b13ab639f8d05",
+ "6175ea4206304020823d86e0bbc23298"
+ ]
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "b9faad28a43547029c8b13ab639f8d05",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Downloading (…)lve/main/config.json: 0%| | 0.00/697 [00:00<?, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "6175ea4206304020823d86e0bbc23298",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Downloading pytorch_model.bin: 0%| | 0.00/443M [00:00<?, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Some weights of the model checkpoint at mesolitica/bert-base-standard-bahasa-cased were not used when initializing BertForSequenceClassification: ['cls.predictions.decoder.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias']\n",
+ "- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
+ "- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
+ "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at mesolitica/bert-base-standard-bahasa-cased and are newly initialized: ['classifier.bias', 'bert.pooler.dense.bias', 'classifier.weight', 'bert.pooler.dense.weight']\n",
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
+ ]
+ }
+ ],
+ "source": [
+ "#from transformers import BertForSequenceClassification\n",
+ "\n",
+ "#model = BertForSequenceClassification.from_pretrained(\n",
+ "# 'malay-huggingface/bert-tiny-bahasa-cased',\n",
+ "# num_labels=2\n",
+ "#)\n",
+ "\n",
+ "\n",
+ "from transformers import AutoModelForSequenceClassification\n",
+ "\n",
+ "model = AutoModelForSequenceClassification.from_pretrained(\n",
+ " 'mesolitica/bert-base-standard-bahasa-cased',\n",
+ " num_labels=2\n",
+ ")"
+ ]
+ },
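The `cls.*` warning above is expected: the checkpoint is a masked-LM model, so the new sequence-classification head starts untrained. One optional addition (not in the commit) is to record the label names on the config so they travel with `save_pretrained`:

```python
# Hypothetical addition: readable label names persist in config.json on save.
model.config.id2label = {0: 'NonCyberbully', 1: 'Cyberbully'}
model.config.label2id = {'NonCyberbully': 0, 'Cyberbully': 1}
```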
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "ifvtnwBMemi1"
+ },
+ "outputs": [],
+ "source": [
+ "def compute_metrics(p):\n",
+ " print(type(p))\n",
+ " pred, labels = p\n",
+ " pred = np.argmax(pred, axis=1)\n",
+ "\n",
+ " accuracy = accuracy_score(y_true=labels, y_pred=pred)\n",
+ " recall = recall_score(y_true=labels, y_pred=pred)\n",
+ " precision = precision_score(y_true=labels, y_pred=pred)\n",
+ " f1 = f1_score(y_true=labels, y_pred=pred)\n",
+ "\n",
+ " return {\"accuracy\": accuracy, \"precision\": precision, \"recall\": recall, \"f1\": f1}\n"
+ ]
+ },
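`print(type(p))` is leftover debugging (it produces the `EvalPrediction` lines in the training output below). Note that sklearn's binary defaults score the positive class, which here is label 1, i.e. Cyberbully. A tiny self-contained check of that convention (illustrative only):

```python
import numpy as np
from sklearn.metrics import f1_score, precision_score, recall_score

labels = np.array([0, 1, 1, 0, 1])
pred = np.array([0, 1, 0, 0, 1])
print(precision_score(labels, pred, pos_label=1))  # 1.0   (no false positives)
print(recall_score(labels, pred, pos_label=1))     # 0.667 (one positive missed)
print(f1_score(labels, pred, pos_label=1))         # 0.8
```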
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "50Xy9P7Remi2"
+ },
+ "outputs": [],
+ "source": [
+ "from transformers import TrainingArguments, Trainer\n",
+ "\n",
+ "training_args = TrainingArguments(output_dir=\"./result\", evaluation_strategy=\"epoch\")\n",
+ "\n",
+ "trainer = Trainer(\n",
+ " model=model,\n",
+ " args=training_args,\n",
+ " train_dataset=train_hg,\n",
+ " eval_dataset=valid_hg,\n",
+ " tokenizer=tokenizer,\n",
+ " compute_metrics=compute_metrics\n",
+ ")"
+ ]
+ },
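Only `output_dir` and `evaluation_strategy` are set, so Transformers defaults fill in the rest; the 417 training steps reported below are consistent with 3 epochs at a per-device batch size of 8 over 1,106 training rows (ceil(1106/8) = 139 steps per epoch, times 3). A sketch with those implicit values written out (assumed defaults, not part of the commit):

```python
from transformers import TrainingArguments

# Defaults made explicit; values inferred from the run below, so treat as assumptions.
training_args = TrainingArguments(
    output_dir='./result',
    evaluation_strategy='epoch',
    num_train_epochs=3,             # Transformers default
    per_device_train_batch_size=8,  # Transformers default
    learning_rate=5e-5,             # Transformers default
)
```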
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "myIstfgJemi3"
+ },
+ "source": [
+ "## Train and Evaluate the model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "-UtAkNHUemi4",
+ "outputId": "5af038f3-a77c-41eb-e48d-747a8e776e38"
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "C:\\Users\\user\\anaconda3\\lib\\site-packages\\transformers\\optimization.py:411: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
+ " warnings.warn(\n",
+ "You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " <div>\n",
+ " \n",
+ " <progress value='417' max='417' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
+ " [417/417 56:36, Epoch 3/3]\n",
+ " </div>\n",
+ " <table border=\"1\" class=\"dataframe\">\n",
+ " <thead>\n",
+ " <tr style=\"text-align: left;\">\n",
+ " <th>Epoch</th>\n",
+ " <th>Training Loss</th>\n",
+ " <th>Validation Loss</th>\n",
+ " <th>Accuracy</th>\n",
+ " <th>Precision</th>\n",
+ " <th>Recall</th>\n",
+ " <th>F1</th>\n",
+ " </tr>\n",
+ " </thead>\n",
+ " <tbody>\n",
+ " <tr>\n",
+ " <td>1</td>\n",
+ " <td>No log</td>\n",
+ " <td>0.493876</td>\n",
+ " <td>0.779783</td>\n",
+ " <td>0.657343</td>\n",
+ " <td>0.886792</td>\n",
+ " <td>0.755020</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>2</td>\n",
+ " <td>No log</td>\n",
+ " <td>0.542367</td>\n",
+ " <td>0.870036</td>\n",
+ " <td>0.850000</td>\n",
+ " <td>0.801887</td>\n",
+ " <td>0.825243</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>3</td>\n",
+ " <td>No log</td>\n",
+ " <td>0.725669</td>\n",
+ " <td>0.848375</td>\n",
+ " <td>0.820000</td>\n",
+ " <td>0.773585</td>\n",
+ " <td>0.796117</td>\n",
+ " </tr>\n",
+ " </tbody>\n",
+ "</table><p>"
+ ],
+ "text/plain": [
+ "<IPython.core.display.HTML object>"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "<class 'transformers.trainer_utils.EvalPrediction'>\n",
+ "<class 'transformers.trainer_utils.EvalPrediction'>\n",
+ "<class 'transformers.trainer_utils.EvalPrediction'>\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "TrainOutput(global_step=417, training_loss=0.2771467213436282, metrics={'train_runtime': 3405.0836, 'train_samples_per_second': 0.974, 'train_steps_per_second': 0.122, 'total_flos': 218053287129600.0, 'train_loss': 0.2771467213436282, 'epoch': 3.0})"
+ ]
+ },
+ "execution_count": 16,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "trainer.train()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "fZYGhNyremi4",
+ "outputId": "5119c379-d7e9-48f7-9137-d788f99a3731"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " <div>\n",
+ " \n",
+ " <progress value='35' max='35' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
+ " [35/35 00:43]\n",
+ " </div>\n",
+ " "
+ ],
+ "text/plain": [
+ "<IPython.core.display.HTML object>"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "<class 'transformers.trainer_utils.EvalPrediction'>\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "{'eval_loss': 0.7256694436073303,\n",
+ " 'eval_accuracy': 0.8483754512635379,\n",
+ " 'eval_precision': 0.82,\n",
+ " 'eval_recall': 0.7735849056603774,\n",
+ " 'eval_f1': 0.796116504854369,\n",
+ " 'eval_runtime': 44.9419,\n",
+ " 'eval_samples_per_second': 6.164,\n",
+ " 'eval_steps_per_second': 0.779,\n",
+ " 'epoch': 3.0}"
+ ]
+ },
+ "execution_count": 17,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "trainer.evaluate()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "tlw24Ccdemi5"
+ },
+ "source": [
+ "## Save the model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "69n4eVBHemi6"
+ },
+ "outputs": [],
+ "source": [
+ "model.save_pretrained('./model/')"
+ ]
+ },
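`save_pretrained` here writes only the model weights and config; the tokenizer is re-downloaded from the Hub later. Saving it alongside the model would make `./model/` self-contained (optional sketch, not in the commit):

```python
# Optional: keep model and tokenizer together so './model/' loads offline.
tokenizer.save_pretrained('./model/')
```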
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "gC9qDoERemi6",
+ "outputId": "a5514df7-d322-48b9-df27-c799dca6d884"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Looking in indexes: https://download.pytorch.org/whl/cu117\n",
+ "Requirement already satisfied: torch in c:\\users\\user\\anaconda3\\lib\\site-packages (2.0.1+cu118)\n",
+ "Requirement already satisfied: torchvision in c:\\users\\user\\anaconda3\\lib\\site-packages (0.15.2+cu117)\n",
+ "Requirement already satisfied: torchaudio in c:\\users\\user\\anaconda3\\lib\\site-packages (2.0.2+cu117)\n",
+ "Requirement already satisfied: sympy in c:\\users\\user\\anaconda3\\lib\\site-packages (from torch) (1.11.1)\n",
+ "Requirement already satisfied: jinja2 in c:\\users\\user\\anaconda3\\lib\\site-packages (from torch) (3.1.2)\n",
+ "Requirement already satisfied: filelock in c:\\users\\user\\anaconda3\\lib\\site-packages (from torch) (3.9.0)\n",
+ "Requirement already satisfied: networkx in c:\\users\\user\\anaconda3\\lib\\site-packages (from torch) (2.5.1)\n",
+ "Requirement already satisfied: typing-extensions in c:\\users\\user\\anaconda3\\lib\\site-packages (from torch) (4.4.0)\n",
+ "Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in c:\\users\\user\\anaconda3\\lib\\site-packages (from torchvision) (9.4.0)\n",
+ "Requirement already satisfied: numpy in c:\\users\\user\\anaconda3\\lib\\site-packages (from torchvision) (1.23.5)\n",
+ "Requirement already satisfied: requests in c:\\users\\user\\anaconda3\\lib\\site-packages (from torchvision) (2.28.1)\n",
+ "Requirement already satisfied: MarkupSafe>=2.0 in c:\\users\\user\\anaconda3\\lib\\site-packages (from jinja2->torch) (2.1.1)\n",
+ "Requirement already satisfied: decorator<5,>=4.3 in c:\\users\\user\\anaconda3\\lib\\site-packages (from networkx->torch) (4.4.2)\n",
+ "Requirement already satisfied: charset-normalizer<3,>=2 in c:\\users\\user\\anaconda3\\lib\\site-packages (from requests->torchvision) (2.0.4)\n",
+ "Requirement already satisfied: urllib3<1.27,>=1.21.1 in c:\\users\\user\\anaconda3\\lib\\site-packages (from requests->torchvision) (1.26.14)\n",
+ "Requirement already satisfied: idna<4,>=2.5 in c:\\users\\user\\anaconda3\\lib\\site-packages (from requests->torchvision) (2.10)\n",
+ "Requirement already satisfied: certifi>=2017.4.17 in c:\\users\\user\\anaconda3\\lib\\site-packages (from requests->torchvision) (2022.12.7)\n",
+ "Requirement already satisfied: mpmath>=0.19 in c:\\users\\user\\anaconda3\\lib\\site-packages (from sympy->torch) (1.2.1)\n"
+ ]
+ }
+ ],
+ "source": [
+ "!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu117"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "3NBugUKAemi7"
+ },
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "-W3_K_Kjemi7"
+ },
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "yMiT54Ddemi7"
+ },
+ "source": [
+ "## Load the model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "mEFnUaM3emi7"
+ },
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "from transformers import AutoModelForSequenceClassification\n",
+ "\n",
+ "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
+ "\n",
+ "new_model = AutoModelForSequenceClassification.from_pretrained('./model/').to(device)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "zkDeulcTemi8",
+ "outputId": "2500b324-398b-471b-9c08-48fa79ea9de3"
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "ERROR: torch-1.0.1-cp36-cp36m-win_amd64.whl is not a supported wheel on this platform.\n",
+ "\n",
+ "[notice] A new release of pip is available: 23.0.1 -> 23.1.2\n",
+ "[notice] To update, run: python.exe -m pip install --upgrade pip\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Requirement already satisfied: torchvision in c:\\users\\user\\anaconda3\\lib\\site-packages (0.14.0)\n",
+ "Requirement already satisfied: typing-extensions in c:\\users\\user\\anaconda3\\lib\\site-packages (from torchvision) (4.1.1)\n",
+ "Requirement already satisfied: requests in c:\\users\\user\\anaconda3\\lib\\site-packages (from torchvision) (2.27.1)\n",
+ "Requirement already satisfied: torch==1.13.0 in c:\\users\\user\\anaconda3\\lib\\site-packages (from torchvision) (1.13.0)\n",
+ "Requirement already satisfied: numpy in c:\\users\\user\\anaconda3\\lib\\site-packages (from torchvision) (1.24.2)\n",
+ "Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in c:\\users\\user\\anaconda3\\lib\\site-packages (from torchvision) (9.0.1)\n",
+ "Requirement already satisfied: idna<4,>=2.5 in c:\\users\\user\\anaconda3\\lib\\site-packages (from requests->torchvision) (3.3)\n",
+ "Requirement already satisfied: charset-normalizer~=2.0.0 in c:\\users\\user\\anaconda3\\lib\\site-packages (from requests->torchvision) (2.0.4)\n",
+ "Requirement already satisfied: certifi>=2017.4.17 in c:\\users\\user\\anaconda3\\lib\\site-packages (from requests->torchvision) (2022.9.24)\n",
+ "Requirement already satisfied: urllib3<1.27,>=1.21.1 in c:\\users\\user\\anaconda3\\lib\\site-packages (from requests->torchvision) (1.26.9)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "[notice] A new release of pip is available: 23.0.1 -> 23.1.2\n",
+ "[notice] To update, run: python.exe -m pip install --upgrade pip\n"
+ ]
+ }
+ ],
+ "source": [
+ "!pip install https://download.pytorch.org/whl/cpu/torch-1.0.1-cp36-cp36m-win_amd64.whl\n",
+ "!pip install torchvision"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "WtI-WDBhemi8"
+ },
+ "outputs": [],
+ "source": [
+ "from transformers import AutoTokenizer\n",
+ "\n",
+ "new_tokenizer = AutoTokenizer.from_pretrained('mesolitica/bert-base-standard-bahasa-cased')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "S2X_uPYJemi9"
+ },
+ "source": [
+ "## Get predictions"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "qXKQEiWxemi9"
+ },
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "import numpy as np\n",
+ "\n",
+ "def get_prediction(text):\n",
+ " encoding = new_tokenizer(text, return_tensors=\"pt\", padding=\"max_length\", truncation=True, max_length=128)\n",
+ " encoding = {k: v.to(trainer.model.device) for k,v in encoding.items()}\n",
+ "\n",
+ " outputs = new_model(**encoding)\n",
+ "\n",
+ " logits = outputs.logits\n",
+ " device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
+ " sigmoid = torch.nn.Sigmoid()\n",
+ " print(sigmoid)\n",
+ " probs = sigmoid(logits.squeeze().cpu())\n",
+ " probs = probs.detach().numpy()\n",
+ " label = np.argmax(probs, axis=-1)\n",
+ "\n",
+ " if label == 1:\n",
+ " return {\n",
+ " 'Target': 'Cyberbully',\n",
+ " 'probability': probs[1]\n",
+ " }\n",
+ " else:\n",
+ " return {\n",
+ " 'Target': 'Not Cyberbully',\n",
+ " 'probability': probs[0]\n",
+ " }"
+ ]
+ },
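Two small observations on `get_prediction`: the local `device` variable and the `print(sigmoid)` call are unused leftovers, and applying a sigmoid to two-class logits gives per-class scores that need not sum to 1 (the argmax label is unaffected). A hedged variant using softmax so `probability` is a proper class probability, reusing the notebook's `new_model` and `new_tokenizer` (sketch, not the committed code):

```python
import torch

def get_prediction_softmax(text):
    # Tokenize and move tensors to the model's device.
    encoding = new_tokenizer(text, return_tensors='pt', truncation=True, max_length=128)
    encoding = {k: v.to(new_model.device) for k, v in encoding.items()}

    with torch.no_grad():  # no gradients needed at inference time
        logits = new_model(**encoding).logits

    probs = torch.softmax(logits, dim=-1).squeeze()  # class probabilities sum to 1
    label = int(probs.argmax())
    return {
        'Target': 'Cyberbully' if label == 1 else 'Not Cyberbully',
        'probability': float(probs[label]),
    }
```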
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "NcYq4vmVemi9"
+ },
+ "outputs": [],
+ "source": [
+ "# dir()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "CS_2FfAeemi_",
+ "outputId": "106776a5-fced-4329-aa1f-5970a4a71386"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Sigmoid()\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "{'Target': 'Cyberbully', 'probability': 0.9651532}"
+ ]
+ },
+ "execution_count": 24,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "get_prediction('Aku malas kerja dengan orang macam ni menyusahkan orang je')"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.9"
+ },
+ "vscode": {
+ "interpreter": {
+ "hash": "173fe52379437b78f95c8980b8ee9f2930fd7b56889ab31a72735475ddc10c81"
+ }
+ },
+ "colab": {
+ "provenance": []
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+ }