diff --git "a/Deep learning/.ipynb_checkpoints/Binary classification using deep learning-checkpoint.ipynb" "b/Deep learning/.ipynb_checkpoints/Binary classification using deep learning-checkpoint.ipynb" new file mode 100644--- /dev/null +++ "b/Deep learning/.ipynb_checkpoints/Binary classification using deep learning-checkpoint.ipynb" @@ -0,0 +1,836 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "import numpy as np\n", + "import pickle\n", + "\n", + "from sklearn.model_selection import train_test_split \n", + "\n", + "from keras.preprocessing.text import Tokenizer \n", + "from keras.preprocessing.sequence import pad_sequences" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
VideoIDEffectiveness
0pvuN_WvF1to0.0
1eRLJscAlk1M1.0
2VbiRNT_gWUQ0.0
35scez5dqtAc1.0
4JDcro7dPqpA0.0
.........
131IQpIVsxx0140.0
132JYZpxRy5Mfg1.0
1338DiWzvE52ZY0.0
134OwqIy8Ikv-c0.0
135lPgZfhnCAdI0.0
\n", + "

136 rows × 2 columns

\n", + "
" + ], + "text/plain": [ + " VideoID Effectiveness\n", + "0 pvuN_WvF1to 0.0\n", + "1 eRLJscAlk1M 1.0\n", + "2 VbiRNT_gWUQ 0.0\n", + "3 5scez5dqtAc 1.0\n", + "4 JDcro7dPqpA 0.0\n", + ".. ... ...\n", + "131 IQpIVsxx014 0.0\n", + "132 JYZpxRy5Mfg 1.0\n", + "133 8DiWzvE52ZY 0.0\n", + "134 OwqIy8Ikv-c 0.0\n", + "135 lPgZfhnCAdI 0.0\n", + "\n", + "[136 rows x 2 columns]" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df = pd.read_excel('../Resources/Climate_change_links_2.xlsx')\n", + "\n", + "## Custom encoder\n", + "def custom_encoder(ratings):\n", + "    \"\"\"Return the ratings binarised: 1-2 -> 0, 4-5 -> 1, anything else -> NaN.\n", + "\n", + "    A new Series is returned; the input is left untouched, so the result does\n", + "    not depend on whether a column selection is a view or a copy.\n", + "    \"\"\"\n", + "    # an existing 0 stays 0, exactly as the previous replace()-based encoder left it\n", + "    return ratings.map({0.0: 0.0, 1.0: 0.0, 2.0: 0.0, 4.0: 1.0, 5.0: 1.0})\n", + "\n", + "# clean data: keep only videos that end up with a binary label\n", + "# (neutral rating 3 and missing/other ratings become NaN and are dropped)\n", + "data = df[[\"VideoID\", \"Effectiveness\"]].copy()\n", + "data[\"Effectiveness\"] = custom_encoder(data[\"Effectiveness\"])\n", + "data = data.dropna(subset=[\"Effectiveness\"]).reset_index(drop=True)\n", + "data" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "## get documents (pre-processed comments)\n", + "# one list of tokens per video, in the same row order as `data`\n", + "documents = []\n", + "for VideoID in data[\"VideoID\"]:\n", + "    comment = pd.read_csv(\"../../NLP Preprocessing/03_Processed_Comments/\"+VideoID+\"/\"+VideoID+\"_all_words.csv\")\n", + "    documents.append(list(comment[\"0\"]))" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + 
"outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
VideoIDEffectivenesscleanedcleaned_string
0pvuN_WvF1to0.0[clean, version, video, child, love, northeast...clean version video child love northeast india...
1eRLJscAlk1M1.0[step, take, help, fight, climate, change, wel...step take help fight climate change well equal...
2VbiRNT_gWUQ0.0[country, disappear, video, year, old, world, ...country disappear video year old world map did...
35scez5dqtAc1.0[im, watch, trump, biden, ha, already, start, ...im watch trump biden ha already start process ...
4JDcro7dPqpA0.0[fun, fact, cow, belch, fart, adult, version, ...fun fact cow belch fart adult version bill nye...
...............
131IQpIVsxx0140.0[corporate, medium, fear, monger, earth, get, ...corporate medium fear monger earth get flat wr...
132JYZpxRy5Mfg1.0[usually, consumer_NEG, say_NEG, though_NEG, s...usually consumer_NEG say_NEG though_NEG suppor...
1338DiWzvE52ZY0.0[marios, leave, hand, doe, intro, impressive, ...marios leave hand doe intro impressive today p...
134OwqIy8Ikv-c0.0[lie, interseting, isnt, group_NEG, consist_NE...lie interseting isnt group_NEG consist_NEG com...
135lPgZfhnCAdI0.0[miss, man, wa, hero, didnt, cherish_NEG, enou...miss man wa hero didnt cherish_NEG enough_NEG ...
\n", + "

136 rows × 4 columns

\n", + "
" + ], + "text/plain": [ + " VideoID Effectiveness \\\n", + "0 pvuN_WvF1to 0.0 \n", + "1 eRLJscAlk1M 1.0 \n", + "2 VbiRNT_gWUQ 0.0 \n", + "3 5scez5dqtAc 1.0 \n", + "4 JDcro7dPqpA 0.0 \n", + ".. ... ... \n", + "131 IQpIVsxx014 0.0 \n", + "132 JYZpxRy5Mfg 1.0 \n", + "133 8DiWzvE52ZY 0.0 \n", + "134 OwqIy8Ikv-c 0.0 \n", + "135 lPgZfhnCAdI 0.0 \n", + "\n", + " cleaned \\\n", + "0 [clean, version, video, child, love, northeast... \n", + "1 [step, take, help, fight, climate, change, wel... \n", + "2 [country, disappear, video, year, old, world, ... \n", + "3 [im, watch, trump, biden, ha, already, start, ... \n", + "4 [fun, fact, cow, belch, fart, adult, version, ... \n", + ".. ... \n", + "131 [corporate, medium, fear, monger, earth, get, ... \n", + "132 [usually, consumer_NEG, say_NEG, though_NEG, s... \n", + "133 [marios, leave, hand, doe, intro, impressive, ... \n", + "134 [lie, interseting, isnt, group_NEG, consist_NE... \n", + "135 [miss, man, wa, hero, didnt, cherish_NEG, enou... \n", + "\n", + " cleaned_string \n", + "0 clean version video child love northeast india... \n", + "1 step take help fight climate change well equal... \n", + "2 country disappear video year old world map did... \n", + "3 im watch trump biden ha already start process ... \n", + "4 fun fact cow belch fart adult version bill nye... \n", + ".. ... \n", + "131 corporate medium fear monger earth get flat wr... \n", + "132 usually consumer_NEG say_NEG though_NEG suppor... \n", + "133 marios leave hand doe intro impressive today p... \n", + "134 lie interseting isnt group_NEG consist_NEG com... \n", + "135 miss man wa hero didnt cherish_NEG enough_NEG ... 
\n", + "\n", + "[136 rows x 4 columns]" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "## create two new columns of the pre-processed data in list and string form\n", + "data['cleaned'] = documents\n", + "data['cleaned_string'] = [' '.join(map(str, l)) for l in data['cleaned']]\n", + "data" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "# split training and test set\n", + "sentences = data['cleaned_string'].values\n", + "y = data['Effectiveness'].values\n", + "\n", + "sentences_train, sentences_test, y_train, y_test = train_test_split(sentences, y, test_size=0.25, random_state=1000)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "plt.style.use('ggplot')\n", + "\n", + "# plot to see overfitting of neural networks\n", + "def plot_history(history, save_fig=False, fig_path='History plot.png'):\n", + "    \"\"\"Plot training vs. validation accuracy (left) and loss (right) per epoch.\n", + "\n", + "    history  -- object returned by model.fit(); its .history dict must contain\n", + "                'accuracy', 'val_accuracy', 'loss' and 'val_loss'\n", + "    save_fig -- when True, also write the figure to fig_path\n", + "    fig_path -- destination file used when save_fig is True\n", + "    \"\"\"\n", + "    acc = history.history['accuracy']\n", + "    val_acc = history.history['val_accuracy']\n", + "    loss = history.history['loss']\n", + "    val_loss = history.history['val_loss']\n", + "    # epochs are numbered from 1 on the x axis\n", + "    x = range(1, len(acc) + 1)\n", + "\n", + "    plt.figure(figsize=(12, 5))\n", + "    plt.subplot(1, 2, 1)\n", + "    plt.plot(x, acc, 'b', label='Training acc')\n", + "    plt.plot(x, val_acc, 'r', label='Validation acc')\n", + "    plt.title('Training and validation accuracy')\n", + "    plt.xlabel('Epoch')\n", + "    plt.ylabel('Accuracy')\n", + "    plt.legend()\n", + "    plt.subplot(1, 2, 2)\n", + "    plt.plot(x, loss, 'b', label='Training loss')\n", + "    plt.plot(x, val_loss, 'r', label='Validation loss')\n", + "    plt.title('Training and validation loss')\n", + "    plt.xlabel('Epoch')\n", + "    plt.ylabel('Loss')\n", + "    plt.legend()\n", + "    if save_fig:\n", + "        plt.savefig(fig_path, bbox_inches = \"tight\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Convolutional Neural Network 
(CNN)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "# create tokenizer\n", + "tokenizer = Tokenizer(num_words=5000)\n", + "tokenizer.fit_on_texts(sentences_train)\n", + "\n", + "X_train = tokenizer.texts_to_sequences(sentences_train)\n", + "X_test = tokenizer.texts_to_sequences(sentences_test)\n", + "\n", + "# texts_to_sequences only emits indices below num_words, so the embedding never\n", + "# needs more rows than that; sizing it by the full word_index wastes parameters\n", + "vocab_size = min(tokenizer.num_words, len(tokenizer.word_index) + 1) # add 1 because of reserved 0 index\n", + "\n", + "# pad data so that it has the same length\n", + "maxlen = 100\n", + "\n", + "X_train = pad_sequences(X_train, padding='post', maxlen=maxlen)\n", + "X_test = pad_sequences(X_test, padding='post', maxlen=maxlen)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "import tensorflow as tf\n", + "from keras.models import Sequential\n", + "from keras import layers\n", + "from keras.backend import clear_session\n", + "\n", + "from numpy.random import seed\n", + "seed(1)\n", + "from tensorflow.random import set_seed \n", + "set_seed(2)\n", + "\n", + "clear_session()\n", + "\n", + "embedding_dim = 100\n", + "\n", + "## create CNN model\n", + "model = Sequential()\n", + "# add embedding layer\n", + "model.add(layers.Embedding(vocab_size, embedding_dim, input_length=maxlen))\n", + "# add convolutional layer\n", + "model.add(layers.Conv1D(128, 7, activation='relu'))\n", + "# add pooling layer\n", + "model.add(layers.GlobalMaxPooling1D())\n", + "# add dense layer with 10 neurons (relu, as in create_model; softmax would force\n", + "# the 10 hidden activations to sum to 1 and is meant for output layers)\n", + "model.add(layers.Dense(10, activation='relu'))\n", + "# add output layer\n", + "model.add(layers.Dense(1, activation='sigmoid'))\n", + "\n", + "model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy',\n", + " tf.keras.metrics.TruePositives(),\n", + " tf.keras.metrics.TrueNegatives(),\n", + " tf.keras.metrics.FalsePositives(),\n", + " tf.keras.metrics.FalseNegatives(), \n", + " tf.keras.metrics.Precision(class_id=None),\n", + " 
tf.keras.metrics.Recall()])\n", + "\n", + "history = model.fit(X_train, y_train,\n", + " epochs=120,\n", + " verbose=False,\n", + " validation_data=(X_test, y_test),\n", + " batch_size=28)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Model: \"sequential\"\n", + "_________________________________________________________________\n", + "Layer (type) Output Shape Param # \n", + "=================================================================\n", + "embedding (Embedding) (None, 100, 100) 41079600 \n", + "_________________________________________________________________\n", + "conv1d (Conv1D) (None, 94, 128) 89728 \n", + "_________________________________________________________________\n", + "global_max_pooling1d (Global (None, 128) 0 \n", + "_________________________________________________________________\n", + "dense (Dense) (None, 10) 1290 \n", + "_________________________________________________________________\n", + "dense_1 (Dense) (None, 1) 11 \n", + "=================================================================\n", + "Total params: 41,170,629\n", + "Trainable params: 41,170,629\n", + "Non-trainable params: 0\n", + "_________________________________________________________________\n" + ] + } + ], + "source": [ + "# model summary\n", + "model.summary()" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# plot accuracy and loss plots\n", + "plot_history(history)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Testing Accuracy: 82.3529\n" + ] + } + ], + "source": [ + "# Evaluate\n", + "# model.evaluate returns the loss followed by the metrics in the order given to compile()\n", + "loss, accuracy, tp, tn, fp, fn, precision, recall = model.evaluate(X_test, y_test, verbose=False)\n", + "print(\"Testing Accuracy: {:.4f}\".format(accuracy*100))" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Precision: 0.8888888955116272\n", + "Recall: 0.6153846383094788\n", + "Accuracy: 0.8235294222831726\n", + "F1-Score: 0.7272727454989408\n" + ] + } + ], + "source": [ + "print(\"Precision: \", precision)\n", + "print(\"Recall: \", recall)\n", + "print(\"Accuracy: \", accuracy)\n", + "# F1 is undefined when precision and recall are both 0; report 0.0 instead of dividing by zero\n", + "f1_score = (2*precision*recall)/(precision+recall) if (precision+recall) > 0 else 0.0\n", + "print(\"F1-Score: \", f1_score)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAWQAAAD8CAYAAABAWd66AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAAAjcUlEQVR4nO3de1xUdf7H8dfMoCIiCCPetQQzpTI1zA0tb6OZWpGZrbdSa8209VamqJubSppKoJuslmalW1tbatlum6Gpm2aS6GqaF3atzdQQBhENUJj5/cEvikAZcGDOTO+nj3nIzPmec75H8e2Xz/mec0xOp9OJiIh4nNnTHRARkSIKZBERg1Agi4gYhAJZRMQgFMgiIgahQBYRMQg/T3dARMSIMjIyWLZsGWfPnsVkMmGz2ejXrx/nz58nISGBM2fOEBYWxuTJkwkMDCy1/r59+1i9ejUOh4NevXoRExNT7j5NmocsIlJaVlYWWVlZhIeHk5uby/Tp05k6dSpbt24lMDCQmJgYNmzYwPnz5xk+fHiJdR0OBxMnTmTWrFlYrVZiY2OZOHEizZo1u+I+VbIQESlDSEgI4eHhANSuXZumTZtit9tJSUmhW7duAHTr1o2UlJRS66alpdGoUSMaNmyIn58f0dHRZbb7pSovWbz2xbdVvQvxQg+2b+7pLogB+bshkWp3eMLlthsXxZCcnFz83mazYbPZSrVLT0/n+PHjtGrViuzsbEJCQoCi0D537lyp9na7HavVWvzearVy7NixcvujGrKI+BaT6z/4Xy6Afy4vL4/4+HhGjhxJQECAS9stqxJsMpnKXU8lCxHxLSaT669yFBQUEB8fz+23307nzp0BCA4OJisrCyiqMwcFBZVaz2q1kpmZWfw+MzOzeFR9JQpkEfEtJrPrrytwOp0sX76cpk2bMmDAgOLPo6Ki2LZtGwDbtm2jU6dOpdaNiIjg1KlTpKenU1BQwM6dO4mKiiq36ypZiIhvcWHk64ojR46wfft2WrRowdSpUwEYMmQIMTExJCQksGXLFurXr8+UKVOAorrxihUriI2NxWKxMHr0aOLi4nA4HPTo0YPmzcs/b1Ll0950Uk/KopN6Uha3nNTrPNXltrmfL7r6HbqRRsgi4lsqcFLPaBTIIuJb3FSy8AQFsoj4Fo2QRUQMQiNkERGD0AhZRMQgzBZP96DSFMgi4ls0QhYRMQizasgiIsagEbKIiEFoloWIiEHopJ6IiEGoZCEiYhAqWYiIGIRGyCIiBqERsoiIQWiELCJiEJplISJiEBohi4gYhGrIIiIGoRGyiIhBuHGEnJSURGpqKsHBwcTHxwOQkJDAyZMnAfjhhx8ICAhg0aLSD0sdP348/v7+mM1mLBYLCxYsKHd/CmQR8S1uHCF3796dvn37smzZsuLPJk+eXPz166+/TkBAwGXXnz17NkFBQS7vz3vH9iIiZTCZzS6/yhMZGUlgYGCZy5xOJ5999hldunRxW981QhYRn2KqQMkiOTmZ5OTk4vc2mw2bzebSul999RXBwcE0btz4sm3i4uIA6N27t0vbVSCLiG+pQAm5IgH8Szt27Lji6Hju3LmEhoaSnZ3NvHnzaNKkCZGRkVfcpkoWIuJTTCaTy6/KKiwsZPfu3URHR1+2TWhoKADBwcF06tSJtLS0crerQBYRn1IdgXzgwAGaNGmC1Wotc3leXh65ubnFX+/fv58WLVqUu12VLETEp5hdOFnnqsTERA4dOkROTg5jx45l8ODB9OzZs8xyhd1uZ8WKFcTGxpKdnc3ixYuBotF0165dad++fbn7MzmdTqfbel+G1774tio3L17qwfbNPd0FMSB/NwwRg4eucblt9hsjrn6HbqQRsoj4lKspRXiaAllEfIoCWUTEIBTIIiIGoUAWETEIk1mBLCJiCBohi4gYhAJZRMQovDePFcgi4ls0QhYRMQgFspRr94fvsO+TD8FkokHzlgwYMxW/mjU93S3xoGdmxbJ921ZCQ62
se+8DT3fHZ7jzXhbVzXt77kVy7BmkfLSBUfOSGPP8ShyOQg599omnuyUedm/MQP68YqWnu+F7TBV4GYwCuZo4CgspuJhf9Ht+PoEhZd+2T349bonqRFBwsKe74XOq4/abVUUli2pQN7Q+nfs/wIsThuJXsxbhN91CeLsoT3dLxCcZMWhdpRFyNci9kMOxPTsZl7iWCS++xaX8PL78NLn8FUWkwrx5hKxArgZff5lKvbBG1Amqh8XPj+s7deXEsYOe7paITzKZTS6/jEaBXA2CrA34Lu0rLuXn4XQ6+frgXqxNyn+ci4hUnDePkFVDrgZNW7Wlza13sGrm45gtFhpd04oOPft7ulviYdOemsIXKbs5ezaL3j3v4PHxv2fg/Q94ultez4hB6yo9wkk8Qo9wkrK44xFOLSf93eW2xxONNTAq9/C/++47UlJSsNvtmEwmQkJCiIqKolmzZtXRPxGRinHjADkpKYnU1FSCg4OJj48H4O2332bz5s0EBQUBMGTIEDp27Fhq3X379rF69WocDge9evUiJiam3P1dMZA3bNhQ/HTVVq1aAUVPVl2yZAldunS57A6Sk5NJTi6aRdB20PhyOyEi4i7uLFl0796dvn37smzZshKf9+/fn3vuueey6zkcDlatWsWsWbOwWq3Exsa6NJC9YiB/8sknxMfH4+dXstmAAQOYMmXKZQPZZrNhs9kA3yhZnMtM5/0/P8+F7CxMJhPte/bn1r4D2fa31RzdsxOTyUydoHoMGDuVuiH1S62/bOIwavrXxmS2YLZYGD0vCeCy63975Ev+uXoJfjVqcO/4mYQ2akrehfOs/9NcfjttgVfXyHzV6VOnmBn7NJmZGZhMZgY9MJhhIx4u0eZcdjbP/GEGJ779HzVr1uLZec9x3XWti5cXFhYyZPD9NGjYkBeTVgCQEL+IHZ9u5/o2bYmbvxCAje9v4Fx2dqntSxGzG2dPREZGkp6eXuH10tLSaNSoEQ0bNgQgOjqalJSUqwtkk8lEVlYWYWFhJT7Pysr6VYWC2WzBNmwsjVpeR37uD6ye9Tgtb7yF3/QfTLcHRgGQ8s/1fLpuLXc9MqnMbQybFU9A3ZJXZV1u/c//8Q73T5pN9pnvSU3eiG34WD5dv5boe4f+qv7cvYnFz8JTT0+nbeQNXLhwnt8+cD+/ua0LEf//kyXAypeX06ZNWxKXLuP4f//Dc/Pm8PIrrxUv/8ua1wkPj+D8hfMA5OTk8O99e3ln/UZin36SY0eP0LzFNby/YT1JuuT6sqrj38hHH33E9u3bCQ8P56GHHiIwMLDEcrvdjtX609W4VquVY8eOlbvdKwbyyJEjmTNnDo0bNy7eeEZGBqdPn+aRRx6pzHF4pcAQa/GlzrVqB2Bt0oLzWRmENbumuM2l/NwK165qBdQpc32LxULBxYtcys/D4mch6/uT5GRlcE3bm6/6WKRqhIU1ICysAQB16gQSHh5Oevr3JQL5v//5D6MfHQNAy/AITp78jsyMDKz16/P96dP8a/tWHh0zljWvvwoUjfQuXbqE0+kkLz8fPz8/Xn1lJUOHj6BGjRrVfozeoiJ5/PPyKpT86f5y+vTpw6BBgwB46623eP311xk3blyJNmXNlXDlP4orBnL79u1ZsmQJaWlp2O12AEJDQ2nVqpVX31Hpapw9c5rvv0mjSUQbALa+/QoH/vUxtQLqMGzm4rJXMpl4c8E0TJjo0Ks/HXoOKF5U1vrR9wzhHytfwK9mLe55fDqb31hBtwdGVvWhiZt8990JDn/1FTe1K/kfaOvr27A5+WM63hLFgf37OXXyJN9/fxpr/fosXPAck5+cyoULF4rb16kTiK13Hx68P4Zbf3MbgXXrcvDLLxk77onqPiSvUpERsisB/Ev16tUr/rpXr148//zzpdpYrVYyMzOL32dmZhISElLutsudZWE2m2ndunV5zX4VLublsi7xWWwjxhWPbrsPHk33waPZ+d4b7Nn0HncMKl3Xe2h2InV
D6nMhO4s3F0zD2rgFLdq2u+z6Da9txcg5LwLwv6/2UzfEitMJ65fOxeznR69hYwkMLv8vV6rfDxcu8OSkCUydPqPUj7GjHx3D8/PjGDzwXlq1bk2bNm2xWPzYtvUTQkNDibzhRlJ2f15inVGP/I5Rj/wOgD8+M5Nxv5/Aunf+xmc7P+W61tczZmzJkZlUbIRcGVlZWcXhunv3bpo3Lz2FMyIiglOnTpGenk5oaCg7d+5kwoQJ5W771znMrYTCggLeTfwjN3TpRZtOt5dafkN0Lw6n/KvMdX880VcnOITWUV04+d/DLq3vdDrZseEvdLlvOJ+ue53bBz3MjV168cVH691wROJuly5dYsqkCfTrfze23n1KLQ8MDGRu3HzeXvcecfMXkpWVRdNmzdi3N5WtW7dwV++eTHtqCimf7yJ22lMl1v3qq0MAXHPNtWx8fwOLXlhCWtoxvvnm6+o4NK9iNptcfpUnMTGRWbNmcfLkScaOHcuWLVtYu3YtTz75JE899RQHDx7k4YeLBmF2u5358+cDRWXH0aNHExcXx+TJk7ntttvKDO5f0pV6LnA6nfz95cXUb3oNnfsNKv7cfvoEoY2KzpoeTd2JtXHpP/CLebk4nU5q1Q7gYl4uxw/soet9w11a/8D2TUR06EztOnW5dDEfk8mMyWTmUn5eVR2qVJLT6eSPz8wsOskzclSZbc6dO0dtf39q1KzJunf+RseoKAIDA5k4+UkmTn4SgJTdn/Paq68w//mS5a9lf1rCM3+cQ0FBAY7CQgDMJjN5ufpe+CV3zrKYNGlSqc969uxZZtvQ0FBiY2OL33fs2LHM+clXokB2wYmjX/Llp8mENW/JytjHAOj+4Gj+vfVDMk+dwGQyEVy/IXeNngRATlYG/3j5BR58+jkunMvi3YQ/AkX3RL4huicRN98KwCd/XVnm+gCX8vPY/69NDJleVJ+69a5BrEv8Ixa/Gtz7xMxqO3Zxzd7UPXzw/ntc17o1gwfeC8DvJ03h1KmTAAx+cAjH//sfZsVOw2wxEx7RimfnxLm07S2bk7nxxpto0KBoClW79h24P+ZuWrduzfVt2lTNAXkxb56IpEunxSN06bSUxR2XTrd7xvVb2+6fU7ETelVNI2QR8SnePFdfgSwiPsWL81iBLCK+xZ0n9aqbAllEfIpKFiIiBuHFeaxAFhHfohGyiIhBeHEeK5BFxLdohCwiYhCaZSEiYhBePEBWIIuIb1HJQkTEILw4jxXIIuJbNEIWETEIBbKIiEFoloWIiEF48QBZgSwivkUlCxERg/DiPFYgi4hvMbsxkZOSkkhNTSU4OJj4+HgA1qxZw549e/Dz86Nhw4aMGzeOOnXqlFp3/Pjx+Pv7YzabsVgsLFiwoNz9KZBFxKe486Re9+7d6du3L8uWLSv+rF27dgwdOhSLxcLatWtZv349w4cPL3P92bNnExQU5PL+zFfdYxERAzGbXH+VJzIyksDAwBKf3XzzzVgsFgBat26N3W53W981QhYRn1KRk3rJyckkJ//0lGqbzYbN5vqTqLds2UJ0dPRll8fFxQHQu3dvl7arQBYRn1KREnJFA/jn1q1bh8Vi4fbbby9z+dy5cwkNDSU7O5t58+bRpEkTIiMjr7hNlSxExKeYKvCrsrZu3cqePXuYMGHCZUfkoaGhAAQHB9OpUyfS0tLK3a4CWUR8ijtryGXZt28f7733HtOmTaNWrVpltsnLyyM3N7f46/3799OiRYtyt62ShYj4FHfOskhMTOTQoUPk5OQwduxYBg8ezPr16ykoKGDu3LkAXHfddYwZMwa73c6KFSuIjY0lOzubxYsXA1BYWEjXrl1p3759ufszOZ1Op9t6X4bXvvi2KjcvXurB9s093QUxIH83DBEHrtrjctt1j9xy9Tt0I42QRcSn6Eo9ERGD0L0sREQMwovzWIEsIr7F4sWJrEAWEZ+ikoWIiEF48QNDFMgi4ls0QhYRMQgvzmMFsoj4Fo2QRUQMwuL
FRWQFsoj4FO+NYwWyiPgYdz5Tr7opkEXEp3hxHiuQRcS36KSeiIhBeHEeK5BFxLdoloWIiEGoZHEF7cLqVfUuxAuFdHrC010QA8rd++JVb8ObHxSqEbKI+BSNkEVEDMKLS8gKZBHxLe48qZeUlERqairBwcHEx8cDcP78eRISEjhz5gxhYWFMnjyZwMDAUuvu27eP1atX43A46NWrFzExMeXuz5vLLSIipZhNrr/K0717d2bMmFHisw0bNnDTTTexdOlSbrrpJjZs2FBqPYfDwapVq5gxYwYJCQns2LGDEydOlN93Vw9SRMQbmEyuv8oTGRlZavSbkpJCt27dAOjWrRspKSml1ktLS6NRo0Y0bNgQPz8/oqOjy2z3SypZiIhPqci9LJKTk0lOTi5+b7PZsNlsV1wnOzubkJAQAEJCQjh37lypNna7HavVWvzearVy7NixcvujQBYRn1KRH/tdCeDKcDqdpT5zZfaHShYi4lPcWbIoS3BwMFlZWQBkZWURFBRUqo3VaiUzM7P4fWZmZvGo+koUyCLiUyxmk8uvyoiKimLbtm0AbNu2jU6dOpVqExERwalTp0hPT6egoICdO3cSFRVV7rZVshARn+LOeciJiYkcOnSInJwcxo4dy+DBg4mJiSEhIYEtW7ZQv359pkyZAhTVjVesWEFsbCwWi4XRo0cTFxeHw+GgR48eNG/evNz9mZxlFTvcaO83OVW5efFS0TGxnu6CGJA7Lp2e83Gay22f6d3qqvfnThohi4hP8eIrpxXIIuJbdOm0iIhBmLz4MacKZBHxKX5ePHdMgSwiPkW33xQRMQjVkEVEDMKLB8gKZBHxLRW5uZDRKJBFxKdYdFJPRMQYzJr2JiJiDF5csVAgi4hv0SwLERGD0Ek9ERGD8OI8ViCLiG+p7I3njUCBLCI+xYtnvSmQRcS36F4WIiIG4b1xrEAWER+jWRYiIgbhrjg+efIkCQkJxe/T09MZPHgw/fv3L/7s4MGDLFy4kAYNGgDQuXNnBg0aVOl9KpBFxKeY3TTLokmTJixatAgAh8PBY489xq233lqqXdu2bZk+fbpb9qlAFhGfUhWzLA4cOECjRo0ICwurgq3/RIEsIj6lIrMskpOTSU5OLn5vs9mw2Wyl2u3YsYMuXbqUuY2jR48ydepUQkJCGDFiBM2bN694p/+fAllEfEpFChaXC+CfKygoYM+ePQwdOrTUspYtW5KUlIS/vz+pqaksWrSIpUuXVrDHP/HmOdQiIqWYTCaXX67Yu3cvLVu2pF69eqWWBQQE4O/vD0DHjh0pLCzk3Llzle67Rsgi4lMsbp72dqVyxdmzZwkODsZkMpGWlobD4aBu3bqV3pcCWUR8ijvjOD8/n/379zNmzJjizzZt2gRAnz592LVrF5s2bcJisVCzZk0mTZp0VVcKmpxOp/Oqe30Fe7/JqcrNi5eKjon1dBfEgHL3vnjV23jvwGmX2957U6Or3p87aYQsIj5Fj3ASETEIL75yWoEsIr7FpBGyiIgxuHuWRXVSIFeTJ0bcTe3aAZjNFiwWC88tW+PpLkk1a9awHivnPkRDaxAOp5NX3t3Bsje3EhIUwJrnR3NNk1C+OWln+NOrOJuT6+nuei0vzmMFcnX6w6IVBAXX83Q3xEMKCh1Mf2Ed+w6fIDCgFjvfmMbmzw8z4u7ObN19hMWrP+apUb15alQfZi19z9Pd9VreHMi6Uk+kmpzOOMe+wycAOP9DPoePn6ZJWD0GdG/H2o2fA7B24+fc3aOdJ7vp9UwV+GU0GiFXExMmnosdjwkTvfoPxNZ/oKe7JB7UonEo7a9vRsqXX9PAWpfTGUWX257OOEdYaOWv9BLw4mecKpCry7OJqwi1hpGdZScudjxNm19L23YdPd0t8YA6tWvy5uJHmbr4XXIu5Hm6Oz7Hm58YopJFNQm1Ft1HNTgklE7R3Uk7ctDDPRJP8PMz8+bi3/HWh1/w3pZ/A5CemUOj+kEANKofxBm7rm6
9Gt5cslAgV4O83Fxyf7hQ/PX+1M9pfm2Eh3slnrB89jCOHD/N0rVbij/7+7YDDL+7MwDD7+7MB1v3e6p7PsFscv1lNCpZVIPss5nEPzsVAEdhIV163En7TtEe7pVUt+j24Qwb0JkDR79j11+LHvkz+8X3Wbz6Y9Y+P5qHY27j21NZDHt6lYd76t2MOPJ1lW4uJB6hmwtJWdxxc6FPj2W53LbrdSFXvT930ghZRHyK946PryKQP/nkE3r06FHmsp8/p+rBx2dWdhciIhX2q7x0+u23375sIP/8OVUqWYhItfLePL5yID/11FNlfu50OsnOzq6SDhnR8vhnSd31KUH1Qlj88tsllm382xr+8vISXvpbcqnLok9++zVL4mYUv08//R0PPPQY/QYOZdf2ZN5Z8xLf/e848/70GhGtIwE4cnAfq5YuwK9GTSbExtGoaXMunM9hSVwssc/96aqeRiDud/jvz5JzIZ9Ch4OCQgddhy2kXeum/Gnmb6lVqwYFhQ4mPfcWXxz8ptS6vx/Wg5H3ReN0OjmYdpIxs9eSf7GA5ybF0O+OG7l4qZDjJzIYM3st2edzue3mcJbMeJCLlwp4KHY1//02g+DA2qx5fjT3jF/mgaM3Jm8+qXfFQM7OzmbmzJnUqVOnxOdOp5M//OEPVdoxI+nW+27uvOdBli18psTnGemnOZD6OfUblP3UgSbNr+X55W8ARbMrHh/aj05din6qaH5tBFOeWcjLS54rsc4H7/yFyc8s5Mzpk3z8wTuMeGwy6/6ykpghoxTGBtV3zBIyz14ofh83KYa4lz5k045D3Nk1krhJMdz5uyUl1mkSFsy4Id3ocH8cefmXWPv8aB648xbWbvyczbsO84c/vU9hoYN5E+5l6uiie1tMHNGTIVNXck1jK2MeuJ3pL6wndkxfFr7yUXUfsqF58z+TK85D7tixI3l5eYSFhZV4NWjQgMjIyOrqo8e1bdeROnWDSn3++vIXGPboBJe+Aw7sTaFh46aENWwMQNMWLWnS/NpS7Sx+flzMz+difh4WPz9OnzyBPSOdyHa3XPVxSPVwOiGoTtGTiIMDa3PqTNk/TfpZLNSuVQOLxUxt/5rF7TbvOkxhoQOA3QeO07RhPQAuFRRSu1YNAmrX4FJBIS2b1adJg3p8uiet6g/Ki5gq8DKaK46QH3/88csumzhxots7402++GwbofUbcE1Ea5faf7btI6J73Fluu5jfjuTlxDhq1qrF+KfnsPalRAaPvPzfg3iW0+lkY9ITOJ1OVr27g1fW7WDq4nfYuGw88yffh9lsosfI+FLrnTyTTeLrmzn64Vxy8y+y+bPDbN51uFS7h+69jXc2pQKw6JVNLJs1hNz8Szwy63XmT7mPZ5M+qPJj9DpGTFoXadpbJeTn5bH+jVeYucC1ul3BpUvs+Ww7vx39RLltr424nnlLXwXgq/2phFjDcDqdJMbF4mfxY/hjk6gXYr2a7osb9RyVwKkz2YSFBPLB8ic48vVpBto68HT8OjZs3sf9vTvw59nD6D+25PzaenVrM6D7TbQdMJuzOT/wxsJH+G2/Tvz1HynFbZ5+5E4KCx3Fn+0/+h3dHi4K9y4dIzh1JhsTJtYsGMWlgkKmv7CedF127dZ7WYwfPx5/f3/MZjMWi4UFCxaUWO50Olm9ejV79+6lVq1ajBs3jvDw8ErvT5dOV8L3p05w5vRJnh47hCdG3I39TDqx44Zx1p5RZvt9KTu4tlWbCgWp0+lk3RurGDjsUd5d8zIPjHiMrr3u4p8b/uquwxA3+LHMcCbrPO9v2U+nG65l2IDObNi8D4B3P95L1A3XlFqvZ+c2fH0yk4ys8xQUONiw5d/85uaWxcuH3d2ZfnfcyMiZr5a53+mP9mX+Sx8y87G7mLv8H7z5jxTGDenu7sPzSu4uWcyePZtFixaVCmOAvXv3cvr0aZYuXcqYMWNYuXLlVfVdgVwJLVq24qW/fcyLazby4pqNhIY1YH7SX6gXWr/
M9js++YguLpQrfm7bxx/QoXNXAusGcTE/D5PZhMlkJj9PdwczigD/mgQG1Cr+2nZbGw7+5ySnzmRz+y3XAdD91tak/e9MqXW/PW3n1ptaUtu/BgA9br2eI8e/B6B3dFueHGlj0KQV5OZdKrXu8Ls7889/HeRsTi4B/jVxOJw4HE4C/n9bv3rVWET+4osvuOOOOzCZTLRu3ZoLFy6QleX6lYK/pJKFC5Y+N4ND+/eQk32WcUP7MWjEGHreFVNmW3vmGV56YS7T45YCReWNA6m7+d2kkhfI7P70E15NWsS57CwWzprENRGtmTH/xeJ1tn/8ATPmF5VE+t0/jIQ5T+PnV4Pfz4irugOVCmlgrctbL/wOKDpB99aHX/Dxzq8Y/8MbLJo6CD8/M/n5BTwx700AGocFk/TMUO77/Z9J+fIb1ifv5bM3plFQ6ODfh0+w6t0dACRMG0ytmn588OeiEtfuA18zIa7oJ6Pa/jUYfndnBowr+l5ZunYLby5+lIuXCng49tVq/hMwpopMe/v5RWxQ8hqKH8XFFf2b6927d6lldrud+vV/GohZrVbsdjshIZW7JFv3shCP0L0spCzuuJfFvv+5njntW1z5YQB2u53Q0FCys7OZN28eo0aNKjHDbP78+dx33320adMGgDlz5jB8+PBK15FVshARn2Iyuf4qT2hoKADBwcF06tSJtLSSUwytVisZGT+dO8rMzKz06BgUyCLiY9x1g/q8vDxyc3OLv96/fz8tWrQo0SYqKort27fjdDo5evQoAQEBVxXIqiGLiE9x16y37OxsFi9eDEBhYSFdu3alffv2bNq0CYA+ffrQoUMHUlNTmTBhAjVr1mTcuHFXtU/VkMUjVEOWsrijhvzlifMut72xWeBV78+dNEIWEd+iK/VERIzBZ+/2JiLibYz48FJXKZBFxLcokEVEjEElCxERg/DmG9QrkEXEp3hxHiuQRcTHeHEiK5BFxKe48wb11U2BLCI+xXvjWIEsIr7GixNZgSwiPkXT3kREDMKLS8gKZBHxLQpkERGDUMlCRMQgNEIWETEIL85jBbKI+BaNkEVEDMN7E1mBLCI+xV03qM/IyGDZsmWcPXsWk8mEzWajX79+JdocPHiQhQsX0qBBAwA6d+7MoEGDKr1PBbKI+BR3lSwsFgsjRowgPDyc3Nxcpk+fTrt27WjWrFmJdm3btmX69Olu2afZLVsRETEIUwV+XUlISAjh4eEA1K5dm6ZNm2K326u07xohi4hvqYIScnp6OsePH6dVq1allh09epSpU6cSEhLCiBEjaN68eaX3o0AWEZ9SkTxOTk4mOTm5+L3NZsNms5Vok5eXR3x8PCNHjiQgIKDEspYtW5KUlIS/vz+pqaksWrSIpUuXVrrvCmQR8SkVqSGXFcA/V1BQQHx8PLfffjudO3cutfznAd2xY0dWrVrFuXPnCAoKqlCff6Qasoj4FJPJ5PLrSpxOJ8uXL6dp06YMGDCgzDZnz57F6XQCkJaWhsPhoG7dupXuu0bIIuJT3FVCPnLkCNu3b6dFixZMnToVgCFDhpCRkQFAnz592LVrF5s2bcJisVCzZk0mTZpUbtBficn5Y7xXkb3f5FTl5sVLRcfEeroLYkC5e1+86m1kXihwua21jrHGpMbqjYjIVdLd3kREDEL3shARMQgFsoiIQahkISJiEBohi4gYhBfnsQJZRHyMFyeyAllEfIpqyCIiBuGuG9R7ggJZRHyLAllExBhUshARMQhvnvZW5TcXkp8kJydf8d6r8uuk7wv5ke6HXI1+/mQCkR/p+0J+pEAWETEIBbKIiEEokKuR6oRSFn1fyI90Uk9ExCA0QhYRMQgFsoiIQejCkGqyb98+Vq9ejcPhoFevXsTExHi6S+JhSUlJpKamEhwcTHx8vKe7IwagEXI1cDgcrFq1ihkzZpCQkMCOHTs4ceKEp7slHta9e3dmzJjh6W6IgSiQq0FaWhqNGjWiYcOG+Pn5ER0dTUpKiqe
7JR4WGRlJYGCgp7shBqJArgZ2ux2r1Vr83mq1YrfbPdgjETEiBXI1KGtmocmb74AiIlVCgVwNrFYrmZmZxe8zMzMJCQnxYI9ExIgUyNUgIiKCU6dOkZ6eTkFBATt37iQqKsrT3RIRg9GVetUkNTWV1157DYfDQY8ePRg4cKCnuyQelpiYyKFDh8jJySE4OJjBgwfTs2dPT3dLPEiBLCJiECpZiIgYhAJZRMQgFMgiIgahQBYRMQgFsoiIQSiQRUQMQoEsImIQ/weBC2wqwr7NJQAAAABJRU5ErkJggg==\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import matplotlib.pyplot as plt\n", + "import seaborn as sns\n", + "\n", + "# function to plot confusion matrix\n", + "def plot_conf_matrix(conf_matrix):\n", + "    \"\"\"Draw conf_matrix as a heatmap annotated with counts and share of total.\n", + "\n", + "    Rows are labelled as actual classes and columns as predicted classes,\n", + "    i.e. the expected layout for a binary problem is [[tn, fp], [fn, tp]].\n", + "    \"\"\"\n", + "    group_counts = [\"{0:0.0f}\".format(value) for value in\n", + "                    conf_matrix.flatten()]\n", + "    group_percentages = [\"{0:.2%}\".format(value) for value in\n", + "                         conf_matrix.flatten()/np.sum(conf_matrix)]\n", + "    labels = [f\"{v1}\\n{v2}\" for v1, v2 in\n", + "              zip(group_counts,group_percentages)]\n", + "    # follow the matrix shape instead of assuming 2x2\n", + "    labels = np.asarray(labels).reshape(conf_matrix.shape)\n", + "\n", + "    ax = sns.heatmap(conf_matrix, annot=labels, fmt='', cmap='Blues')\n", + "    # without axis titles the default 0/1 ticks do not say which axis is which\n", + "    ax.set_xlabel('Predicted label')\n", + "    ax.set_ylabel('Actual label')\n", + "\n", + "# Creating confusion matrix (rows = actual class 0/1, columns = predicted class 0/1)\n", + "conf_matrix = np.array( [[tn, fp], \n", + "                         [fn, tp]] )\n", + "\n", + "plot_conf_matrix(conf_matrix)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Using Bag of Words Vectorizer" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer\n", + "\n", + "vectorizer = CountVectorizer(min_df=3, ngram_range=(1, 1))\n", + "vectorizer.fit(sentences_train)\n", + "\n", + "X_train = vectorizer.transform(sentences_train)\n", + "X_test = vectorizer.transform(sentences_test)" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "C:\\Users\\Jared\\anaconda3\\lib\\site-packages\\tensorflow\\python\\framework\\indexed_slices.py:447: UserWarning: Converting sparse IndexedSlices(IndexedSlices(indices=Tensor(\"gradient_tape/sequential/dense/embedding_lookup_sparse/Reshape_1:0\", shape=(None,), dtype=int32), values=Tensor(\"gradient_tape/sequential/dense/embedding_lookup_sparse/Reshape:0\", shape=(None, 32), dtype=float32), dense_shape=Tensor(\"gradient_tape/sequential/dense/embedding_lookup_sparse/Cast:0\", 
shape=(2,), dtype=int32))) to a dense Tensor of unknown shape. This may consume a large amount of memory.\n", + " warnings.warn(\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Training Accuracy: 100.0000\n", + "Testing Accuracy: 73.5294\n" + ] + } + ], + "source": [ + "from keras.models import Sequential\n", + "from keras import layers\n", + "from keras.backend import clear_session\n", + "\n", + "from numpy.random import seed\n", + "seed(1)\n", + "from tensorflow.random import set_seed \n", + "set_seed(2)\n", + "\n", + "clear_session()\n", + "\n", + "# Number of features\n", + "input_dim = X_train.shape[1] \n", + "\n", + "# create model with one hidden layer\n", + "model = Sequential()\n", + "model.add(layers.Dense(32, input_dim=input_dim, activation='relu'))\n", + "model.add(layers.Dense(1, activation='sigmoid'))\n", + "\n", + "model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])\n", + "\n", + "history = model.fit(X_train, y_train,\n", + " epochs=30,\n", + " verbose=False,\n", + " validation_data=(X_test, y_test),\n", + " batch_size=32)\n", + "\n", + "loss, accuracy = model.evaluate(X_train, y_train, verbose=False)\n", + "print(\"Training Accuracy: {:.4f}\".format(accuracy*100))\n", + "loss, accuracy = model.evaluate(X_test, y_test, verbose=False)\n", + "print(\"Testing Accuracy: {:.4f}\".format(accuracy*100))" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# plot accuracy and loss plots\n", + "plot_history(history)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Hyperparameter Optimization" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "def create_model(num_filters, kernel_size, vocab_size, embedding_dim, maxlen):\n", + " model = Sequential()\n", + " model.add(layers.Embedding(vocab_size, embedding_dim, input_length=maxlen))\n", + " model.add(layers.Conv1D(num_filters, kernel_size, activation='relu'))\n", + " model.add(layers.GlobalMaxPooling1D())\n", + " model.add(layers.Dense(10, activation='relu'))\n", + " model.add(layers.Dense(1, activation='sigmoid'))\n", + " model.compile(optimizer='adam',\n", + " loss='binary_crossentropy',\n", + " metrics=['accuracy'])\n", + " return model" + ] + }, + { + "cell_type": "code", + "execution_count": 52, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Fitting 3 folds for each of 5 candidates, totalling 15 fits\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n", + "[Parallel(n_jobs=1)]: Done 15 out of 15 | elapsed: 7.8min finished\n" + ] + }, + { + "data": { + "text/plain": [ + "({'vocab_size': 410796,\n", + " 'num_filters': 128,\n", + " 'maxlen': 100,\n", + " 'kernel_size': 7,\n", + " 'embedding_dim': 50},)" + ] + }, + "execution_count": 52, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from keras.wrappers.scikit_learn import KerasClassifier\n", + "from sklearn.model_selection import RandomizedSearchCV\n", + "\n", + "# main settings\n", + "epochs = 20\n", + "embedding_dim = 50\n", + "maxlen = 100\n", + "\n", + "sentences = data['cleaned_string'].values\n", + "y = data['Effectiveness'].values\n", + "\n", + "# 
train-test split\n", + "sentences_train, sentences_test, y_train, y_test = train_test_split(sentences, y, test_size=0.25, random_state=1000)\n", + "\n", + "# tokenize words\n", + "tokenizer = Tokenizer(num_words=5000)\n", + "tokenizer.fit_on_texts(sentences_train)\n", + "X_train = tokenizer.texts_to_sequences(sentences_train)\n", + "X_test = tokenizer.texts_to_sequences(sentences_test)\n", + "\n", + "# add 1 because of reserved 0 index; capped at num_words because the tokenizer\n", + "# never emits an index at or above it, so larger embeddings are never used\n", + "vocab_size = min(tokenizer.num_words, len(tokenizer.word_index) + 1)\n", + "\n", + "# pad sequences with zeros\n", + "X_train = pad_sequences(X_train, padding='post', maxlen=maxlen)\n", + "X_test = pad_sequences(X_test, padding='post', maxlen=maxlen)\n", + "\n", + "# parameter distributions for the randomized search\n", + "param_grid = dict(num_filters=[32, 64, 128],\n", + " kernel_size=[3, 5, 7],\n", + " vocab_size=[vocab_size],\n", + " embedding_dim=[embedding_dim],\n", + " maxlen=[maxlen])\n", + "\n", + "model = KerasClassifier(build_fn=create_model,\n", + " epochs=epochs, batch_size=10,\n", + " verbose=False)\n", + "\n", + "# random_state fixes which n_iter candidates are sampled, so re-runs search the same ones\n", + "grid = RandomizedSearchCV(estimator=model, param_distributions=param_grid, cv=3, verbose=1, n_iter=5, random_state=1000)\n", + "\n", + "grid_result = grid.fit(X_train, y_train)\n", + "\n", + "# Evaluate testing set\n", + "test_accuracy = grid.score(X_test, y_test)\n", + "# only the last expression of a cell is displayed, so print the scores explicitly\n", + "print(\"Best CV accuracy: {:.4f}\".format(grid_result.best_score_))\n", + "print(\"Test accuracy: {:.4f}\".format(test_accuracy))\n", + "grid_result.best_params_" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.12" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +}