{ "cells": [ { "cell_type": "code", "source": [], "metadata": { "id": "b3_hlnrYh30E" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "Acknowledgement:\n", "Thanks to @RajKKapadia
\n", "Link: https://github.com/RajKKapadia/Transformers-Text-Classification-BERT-Blog" ], "metadata": { "id": "zK4VsKufh4gZ" } }, { "cell_type": "markdown", "metadata": { "id": "XLhB2j_Hemio" }, "source": [ "## Read the dataset csv file" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "hgYEtrYgemir", "outputId": "d3ddedc7-8bd7-4ba9-c82e-68e4eb1309c3" }, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
Unnamed: 0Texttarget
00.0polis tangkapNonCyberbully
11.0kenapa lokasi kebakaran terlalu spesifikNonCyberbully
22.0menyesal tanya nak for birthdayNonCyberbully
33.0meriah tahNonCyberbully
44.0asal bs kelar kerja jam sik kl baru diajak mee...NonCyberbully
\n", "
" ], "text/plain": [ " Unnamed: 0 Text \\\n", "0 0.0 polis tangkap \n", "1 1.0 kenapa lokasi kebakaran terlalu spesifik \n", "2 2.0 menyesal tanya nak for birthday \n", "3 3.0 meriah tah \n", "4 4.0 asal bs kelar kerja jam sik kl baru diajak mee... \n", "\n", " target \n", "0 NonCyberbully \n", "1 NonCyberbully \n", "2 NonCyberbully \n", "3 NonCyberbully \n", "4 NonCyberbully " ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import pandas as pd\n", "df = pd.read_csv('C:/Users/user/Documents/PSM/BERT_Ver2/Transformers-Text-Classification-BERT-Blog-main/input/Tagged_MixedNew.csv')\n", "df.head()" ] }, { "cell_type": "markdown", "metadata": { "id": "fGUtFkVfemit" }, "source": [ "## Process the data" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "7C3uRWECemiu", "outputId": "8e764d84-010d-4e42-987a-af7162627f6e", "colab": { "referenced_widgets": [ "042c8b0b8dcf42eb84660c93778d8ea7", "4ab6074437a849f79be038b043025283", "9aed4d88c18e4e28a1efbbed94331228" ] } }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "042c8b0b8dcf42eb84660c93778d8ea7", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Downloading (…)okenizer_config.json: 0%| | 0.00/380 [00:00\n", " \n", " \n", " [417/417 56:36, Epoch 3/3]\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
EpochTraining LossValidation LossAccuracyPrecisionRecallF1
1No log0.4938760.7797830.6573430.8867920.755020
2No log0.5423670.8700360.8500000.8018870.825243
3No log0.7256690.8483750.8200000.7735850.796117

" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "\n", "\n" ] }, { "data": { "text/plain": [ "TrainOutput(global_step=417, training_loss=0.2771467213436282, metrics={'train_runtime': 3405.0836, 'train_samples_per_second': 0.974, 'train_steps_per_second': 0.122, 'total_flos': 218053287129600.0, 'train_loss': 0.2771467213436282, 'epoch': 3.0})" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "trainer.train()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "fZYGhNyremi4", "outputId": "5119c379-d7e9-48f7-9137-d788f99a3731" }, "outputs": [ { "data": { "text/html": [ "\n", "

\n", " \n", " \n", " [35/35 00:43]\n", "
\n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "text/plain": [ "{'eval_loss': 0.7256694436073303,\n", " 'eval_accuracy': 0.8483754512635379,\n", " 'eval_precision': 0.82,\n", " 'eval_recall': 0.7735849056603774,\n", " 'eval_f1': 0.796116504854369,\n", " 'eval_runtime': 44.9419,\n", " 'eval_samples_per_second': 6.164,\n", " 'eval_steps_per_second': 0.779,\n", " 'epoch': 3.0}" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "trainer.evaluate()" ] }, { "cell_type": "markdown", "metadata": { "id": "tlw24Ccdemi5" }, "source": [ "## Save the model" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "69n4eVBHemi6" }, "outputs": [], "source": [ "model.save_pretrained('./model/')" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "gC9qDoERemi6", "outputId": "a5514df7-d322-48b9-df27-c799dca6d884" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Looking in indexes: https://download.pytorch.org/whl/cu117\n", "Requirement already satisfied: torch in c:\\users\\user\\anaconda3\\lib\\site-packages (2.0.1+cu118)\n", "Requirement already satisfied: torchvision in c:\\users\\user\\anaconda3\\lib\\site-packages (0.15.2+cu117)\n", "Requirement already satisfied: torchaudio in c:\\users\\user\\anaconda3\\lib\\site-packages (2.0.2+cu117)\n", "Requirement already satisfied: sympy in c:\\users\\user\\anaconda3\\lib\\site-packages (from torch) (1.11.1)\n", "Requirement already satisfied: jinja2 in c:\\users\\user\\anaconda3\\lib\\site-packages (from torch) (3.1.2)\n", "Requirement already satisfied: filelock in c:\\users\\user\\anaconda3\\lib\\site-packages (from torch) (3.9.0)\n", "Requirement already satisfied: networkx in c:\\users\\user\\anaconda3\\lib\\site-packages (from torch) (2.5.1)\n", "Requirement already satisfied: typing-extensions in 
c:\\users\\user\\anaconda3\\lib\\site-packages (from torch) (4.4.0)\n", "Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in c:\\users\\user\\anaconda3\\lib\\site-packages (from torchvision) (9.4.0)\n", "Requirement already satisfied: numpy in c:\\users\\user\\anaconda3\\lib\\site-packages (from torchvision) (1.23.5)\n", "Requirement already satisfied: requests in c:\\users\\user\\anaconda3\\lib\\site-packages (from torchvision) (2.28.1)\n", "Requirement already satisfied: MarkupSafe>=2.0 in c:\\users\\user\\anaconda3\\lib\\site-packages (from jinja2->torch) (2.1.1)\n", "Requirement already satisfied: decorator<5,>=4.3 in c:\\users\\user\\anaconda3\\lib\\site-packages (from networkx->torch) (4.4.2)\n", "Requirement already satisfied: charset-normalizer<3,>=2 in c:\\users\\user\\anaconda3\\lib\\site-packages (from requests->torchvision) (2.0.4)\n", "Requirement already satisfied: urllib3<1.27,>=1.21.1 in c:\\users\\user\\anaconda3\\lib\\site-packages (from requests->torchvision) (1.26.14)\n", "Requirement already satisfied: idna<4,>=2.5 in c:\\users\\user\\anaconda3\\lib\\site-packages (from requests->torchvision) (2.10)\n", "Requirement already satisfied: certifi>=2017.4.17 in c:\\users\\user\\anaconda3\\lib\\site-packages (from requests->torchvision) (2022.12.7)\n", "Requirement already satisfied: mpmath>=0.19 in c:\\users\\user\\anaconda3\\lib\\site-packages (from sympy->torch) (1.2.1)\n" ] } ], "source": [ "!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu117" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "3NBugUKAemi7" }, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "-W3_K_Kjemi7" }, "outputs": [], "source": [] }, { "cell_type": "markdown", "metadata": { "id": "yMiT54Ddemi7" }, "source": [ "## Load the model" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "mEFnUaM3emi7" }, "outputs": [], "source": [ "import 
torch\n", "from transformers import AutoModelForSequenceClassification\n", "\n", "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", "\n", "new_model = AutoModelForSequenceClassification.from_pretrained('./model/').to(device)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "zkDeulcTemi8", "outputId": "2500b324-398b-471b-9c08-48fa79ea9de3" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "ERROR: torch-1.0.1-cp36-cp36m-win_amd64.whl is not a supported wheel on this platform.\n", "\n", "[notice] A new release of pip is available: 23.0.1 -> 23.1.2\n", "[notice] To update, run: python.exe -m pip install --upgrade pip\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: torchvision in c:\\users\\user\\anaconda3\\lib\\site-packages (0.14.0)\n", "Requirement already satisfied: typing-extensions in c:\\users\\user\\anaconda3\\lib\\site-packages (from torchvision) (4.1.1)\n", "Requirement already satisfied: requests in c:\\users\\user\\anaconda3\\lib\\site-packages (from torchvision) (2.27.1)\n", "Requirement already satisfied: torch==1.13.0 in c:\\users\\user\\anaconda3\\lib\\site-packages (from torchvision) (1.13.0)\n", "Requirement already satisfied: numpy in c:\\users\\user\\anaconda3\\lib\\site-packages (from torchvision) (1.24.2)\n", "Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in c:\\users\\user\\anaconda3\\lib\\site-packages (from torchvision) (9.0.1)\n", "Requirement already satisfied: idna<4,>=2.5 in c:\\users\\user\\anaconda3\\lib\\site-packages (from requests->torchvision) (3.3)\n", "Requirement already satisfied: charset-normalizer~=2.0.0 in c:\\users\\user\\anaconda3\\lib\\site-packages (from requests->torchvision) (2.0.4)\n", "Requirement already satisfied: certifi>=2017.4.17 in c:\\users\\user\\anaconda3\\lib\\site-packages (from requests->torchvision) (2022.9.24)\n", "Requirement already satisfied: urllib3<1.27,>=1.21.1 in 
c:\\users\\user\\anaconda3\\lib\\site-packages (from requests->torchvision) (1.26.9)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n", "[notice] A new release of pip is available: 23.0.1 -> 23.1.2\n", "[notice] To update, run: python.exe -m pip install --upgrade pip\n" ] } ], "source": [ "!pip install https://download.pytorch.org/whl/cpu/torch-1.0.1-cp36-cp36m-win_amd64.whl\n", "!pip install torchvision" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "WtI-WDBhemi8" }, "outputs": [], "source": [ "from transformers import AutoTokenizer\n", "\n", "new_tokenizer = AutoTokenizer.from_pretrained('mesolitica/bert-base-standard-bahasa-cased')" ] }, { "cell_type": "markdown", "metadata": { "id": "S2X_uPYJemi9" }, "source": [ "## Get predictions" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "qXKQEiWxemi9" }, "outputs": [], "source": [ "import torch\n", "import numpy as np\n", "\n", "def get_prediction(text):\n", "    \"\"\"Classify `text` with `new_tokenizer` and `new_model`.\n", "\n", "    Returns a dict with 'Target' ('Cyberbully' or 'Not Cyberbully') and\n", "    'probability' (the sigmoid score of the predicted class).\n", "    \"\"\"\n", "    encoding = new_tokenizer(text, return_tensors=\"pt\", padding=\"max_length\", truncation=True, max_length=128)\n", "    # Use new_model's own device rather than the training-time `trainer`,\n", "    # which is undefined when only the saved model is loaded in a fresh kernel.\n", "    encoding = {k: v.to(new_model.device) for k, v in encoding.items()}\n", "\n", "    # Inference only, so skip autograd bookkeeping.\n", "    with torch.no_grad():\n", "        outputs = new_model(**encoding)\n", "\n", "    logits = outputs.logits\n", "    # NOTE(review): sigmoid scores each logit independently, so the two values\n", "    # need not sum to 1; softmax may be what is intended here - confirm.\n", "    probs = torch.sigmoid(logits.squeeze().cpu()).detach().numpy()\n", "    label = np.argmax(probs, axis=-1)\n", "\n", "    if label == 1:\n", "        return {\n", "            'Target': 'Cyberbully',\n", "            'probability': probs[1]\n", "        }\n", "    else:\n", "        return {\n", "            'Target': 'Not Cyberbully',\n", "            'probability': probs[0]\n", "        }" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "NcYq4vmVemi9" }, "outputs": [], "source": [ "# dir()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "CS_2FfAeemi_", "outputId": "106776a5-fced-4329-aa1f-5970a4a71386" }, 
"outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Sigmoid()\n" ] }, { "data": { "text/plain": [ "{'Target': 'Cyberbully', 'probability': 0.9651532}" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "get_prediction('Aku malas kerja dengan orang macam ni menyusahkan orang je')" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.9" }, "vscode": { "interpreter": { "hash": "173fe52379437b78f95c8980b8ee9f2930fd7b56889ab31a72735475ddc10c81" } }, "colab": { "provenance": [] } }, "nbformat": 4, "nbformat_minor": 0 }