{ "cells": [ { "cell_type": "markdown", "id": "a9fd592d-0ad0-4f1b-94f4-0e5ae9dbeec7", "metadata": { "id": "a9fd592d-0ad0-4f1b-94f4-0e5ae9dbeec7" }, "source": [ "# Fine-Tune a Causal Language Model for Dialogue Summarization" ] }, { "cell_type": "markdown", "id": "976f3255-166f-4e6f-aaf0-ef6c4703b9f7", "metadata": { "id": "976f3255-166f-4e6f-aaf0-ef6c4703b9f7" }, "source": [ "Fine-tune Meta's Llama 2 base version for enhanced topic summarization creation of mutlitple choice question (MCQ). Llama 2 is a large language model (LLM) free for research and commercial use. It is one of the top-performing open-source LLM comparable to GPT-3.5 on several benchmarks.\n", "\n", "We will explore the use of Parameter Efficient Fine-Tuning (PEFT - lora) for fine-tuning, and evaluate the resulting model using ROUGE metrics." ] }, { "cell_type": "markdown", "id": "a3b6ce61-1b18-412e-8ea9-b1aa4b6ea33d", "metadata": { "id": "a3b6ce61-1b18-412e-8ea9-b1aa4b6ea33d" }, "source": [ "## Install the pre-requisites\n", "\n", "Uncomment the following if these python packages have not been installed" ] }, { "cell_type": "code", "execution_count": null, "id": "4da7e887-227f-4cce-983a-83ab3405de65", "metadata": { "tags": [], "id": "4da7e887-227f-4cce-983a-83ab3405de65", "outputId": "0e0d4037-fecc-497f-e670-c8ec02e4ec32", "colab": { "base_uri": "https://localhost:8080/" } }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Requirement already satisfied: transformers in /usr/local/lib/python3.10/dist-packages (4.35.2)\n", "Collecting datasets\n", " Downloading datasets-2.17.0-py3-none-any.whl (536 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m536.6/536.6 kB\u001b[0m \u001b[31m11.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hCollecting accelerate\n", " Downloading accelerate-0.27.2-py3-none-any.whl (279 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m280.0/280.0 kB\u001b[0m \u001b[31m22.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hRequirement already satisfied: sentencepiece in /usr/local/lib/python3.10/dist-packages (0.1.99)\n", "Requirement already satisfied: scipy in /usr/local/lib/python3.10/dist-packages (1.11.4)\n", "Collecting peft\n", " Downloading peft-0.8.2-py3-none-any.whl (183 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m183.4/183.4 kB\u001b[0m \u001b[31m24.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hCollecting bitsandbytes\n", " Downloading bitsandbytes-0.42.0-py3-none-any.whl (105.0 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m105.0/105.0 MB\u001b[0m \u001b[31m8.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hCollecting evaluate\n", " Downloading evaluate-0.4.1-py3-none-any.whl (84 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m84.1/84.1 kB\u001b[0m \u001b[31m13.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hCollecting rouge_score\n", " Downloading rouge_score-0.1.2.tar.gz (17 kB)\n", " Preparing metadata (setup.py) ... 
\u001b[?25l\u001b[?25hdone\n", "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from transformers) (3.13.1)\n", "Requirement already satisfied: huggingface-hub<1.0,>=0.16.4 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.20.3)\n", "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (1.25.2)\n", "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from transformers) (23.2)\n", "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (6.0.1)\n", "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (2023.12.25)\n", "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from transformers) (2.31.0)\n", "Requirement already satisfied: tokenizers<0.19,>=0.14 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.15.2)\n", "Requirement already satisfied: safetensors>=0.3.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.4.2)\n", "Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers) (4.66.2)\n", "Collecting pyarrow>=12.0.0 (from datasets)\n", " Downloading pyarrow-15.0.0-cp310-cp310-manylinux_2_28_x86_64.whl (38.3 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m38.3/38.3 MB\u001b[0m \u001b[31m38.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hRequirement already satisfied: pyarrow-hotfix in /usr/local/lib/python3.10/dist-packages (from datasets) (0.6)\n", "Collecting dill<0.3.9,>=0.3.0 (from datasets)\n", " Downloading dill-0.3.8-py3-none-any.whl (116 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m116.3/116.3 kB\u001b[0m \u001b[31m17.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hRequirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from datasets) (1.5.3)\n", "Requirement already satisfied: xxhash in /usr/local/lib/python3.10/dist-packages (from datasets) (3.4.1)\n", "Collecting multiprocess (from datasets)\n", " Downloading multiprocess-0.70.16-py310-none-any.whl (134 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m134.8/134.8 kB\u001b[0m \u001b[31m20.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hRequirement already satisfied: fsspec[http]<=2023.10.0,>=2023.1.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (2023.6.0)\n", "Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets) (3.9.3)\n", "Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from accelerate) (5.9.5)\n", "Requirement already satisfied: torch>=1.10.0 in /usr/local/lib/python3.10/dist-packages (from accelerate) (2.1.0+cu121)\n", "Collecting responses<0.19 (from evaluate)\n", " Downloading responses-0.18.0-py3-none-any.whl (38 kB)\n", "Requirement already satisfied: absl-py in /usr/local/lib/python3.10/dist-packages (from rouge_score) (1.4.0)\n", "Requirement already satisfied: nltk in /usr/local/lib/python3.10/dist-packages (from rouge_score) (3.8.1)\n", "Requirement already satisfied: six>=1.14.0 in /usr/local/lib/python3.10/dist-packages (from rouge_score) (1.16.0)\n", "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) 
(1.3.1)\n", "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (23.2.0)\n", "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.4.1)\n", "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (6.0.5)\n", "Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.9.4)\n", "Requirement already satisfied: async-timeout<5.0,>=4.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (4.0.3)\n", "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.16.4->transformers) (4.9.0)\n", "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.3.2)\n", "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.6)\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2.0.7)\n", "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2024.2.2)\n", "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (1.12)\n", "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (3.2.1)\n", "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (3.1.3)\n", "Requirement already satisfied: triton==2.1.0 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (2.1.0)\n", "Requirement already satisfied: click in /usr/local/lib/python3.10/dist-packages (from nltk->rouge_score) (8.1.7)\n", "Requirement already satisfied: joblib in /usr/local/lib/python3.10/dist-packages (from nltk->rouge_score) (1.3.2)\n", "Requirement already satisfied: python-dateutil>=2.8.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2.8.2)\n", "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2023.4)\n", "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=1.10.0->accelerate) (2.1.5)\n", "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=1.10.0->accelerate) (1.3.0)\n", "Building wheels for collected packages: rouge_score\n", " Building wheel for rouge_score (setup.py) ... \u001b[?25l\u001b[?25hdone\n", " Created wheel for rouge_score: filename=rouge_score-0.1.2-py3-none-any.whl size=24933 sha256=97db1dbbde83d29517ebbf637d3e9665143ece23701de0adadeb3b10706c43fe\n", " Stored in directory: /root/.cache/pip/wheels/5f/dd/89/461065a73be61a532ff8599a28e9beef17985c9e9c31e541b4\n", "Successfully built rouge_score\n", "Installing collected packages: pyarrow, dill, rouge_score, responses, multiprocess, bitsandbytes, accelerate, datasets, peft, evaluate\n", " Attempting uninstall: pyarrow\n", " Found existing installation: pyarrow 10.0.1\n", " Uninstalling pyarrow-10.0.1:\n", " Successfully uninstalled pyarrow-10.0.1\n", "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. 
This behaviour is the source of the following dependency conflicts.\n", "ibis-framework 7.1.0 requires pyarrow<15,>=2, but you have pyarrow 15.0.0 which is incompatible.\u001b[0m\u001b[31m\n", "\u001b[0mSuccessfully installed accelerate-0.27.2 bitsandbytes-0.42.0 datasets-2.17.0 dill-0.3.8 evaluate-0.4.1 multiprocess-0.70.16 peft-0.8.2 pyarrow-15.0.0 responses-0.18.0 rouge_score-0.1.2\n" ] } ], "source": [ "!pip install transformers datasets accelerate sentencepiece scipy peft bitsandbytes evaluate rouge_score" ] }, { "cell_type": "markdown", "id": "a39feac3-3c0f-4057-83e2-898a0329cbee", "metadata": { "id": "a39feac3-3c0f-4057-83e2-898a0329cbee" }, "source": [ "## Request access to Llama-2 weights\n", "\n", "You need to request for access to download the Llama 2 weights. You can either do so through this [link at Meta](https://ai.meta.com/resources/models-and-libraries/llama-downloads/) or through your huggingface account at this [link](https://huggingface.co/meta-llama/Llama-2-7b). Once your request is approved, you will receive an email from Meta with instruction to download the Llama 2 weights, or email from Hugging Face informing you access has been granted.\n", "\n", "If you download the weights from Meta directly, you need to run a conversion script to convert the weights to huggingface format for use with huggingface transformer library." ] }, { "cell_type": "code", "execution_count": null, "id": "4178ec51-ccf8-4d01-8224-e57c8b0683f1", "metadata": { "tags": [], "id": "4178ec51-ccf8-4d01-8224-e57c8b0683f1" }, "outputs": [], "source": [ "# %%bash\n", "# TRANSFORM=`python -c \"import transformers;print('/'.join(transformers.__file__.split('/')[:-1])+'/models/llama/convert_llama_weights_to_hf.py')\"`\n", "# python ${TRANSFORM} --input_dir models --model_size 7B --output_dir models_hf/7B" ] }, { "cell_type": "code", "execution_count": null, "id": "00149073-36b6-4317-818b-c8e0485cd8d2", "metadata": { "tags": [], "id": "00149073-36b6-4317-818b-c8e0485cd8d2", "outputId": "367401cb-334c-430e-efdf-d2f795f4d5ba", "colab": { "base_uri": "https://localhost:8080/", "height": 145, "referenced_widgets": [ "c438af202cdf45519dc13ffe6f22037d", "99b5ebef5fe54a82b20e1f9e196111d0", "e463243dece7453da77c8775a4fe5f94", "44344f5882fd4e3aadc05cc18b15d571", "eb26fc1a3e6047078528f3f92ccd27b9", "53927a52e58544b6bdbb71382a3f0d45", "1e74458632d2490685769a7a1f46dce9", "e4686d7a955b4ad9a65b323285430838", "dc3b4d57e6544e1082ac417062306808", "de3d1e34487d4763a2b19c8f7567c8f3", "07bc154a336b402f9156c30cba0ed696", "ac978db16d5f442d8930984fa4498717", "5e0f629e3e094748b2f03675e4370989", "498c2dac23ff4355b3475328c7d77deb", "96f915a83bb843b5913508c7cf7debe4", "315c11cdbd3b4fcf93881deae18b8a4b", "8e1c860e75bd42a2ae54047ba436ea01", "97cc7d3fd82646cf8a64985c4257ec47", "84b644b814bd4638803555c93102b329", "2c7386d8b79846d8a56821837d2faa82", "bb5c7c3283de41f385d4a394da65c3d6", "84bc91a6c29a4963996a93ecf3b9bcd4", "6ebbe9df3902460686665a4a9a10bd5f", "38905af71dc2456aab7ab6298a3de610", "2e579d901ca440c38ecd38e99b0fdd5d", "6672ac14178549349bfb740c1e5fe754", "8dd8fda1fa314f048022dd17e432c59b", "45331f2eed8f45ea95f25fea8faf2446", "bd6c352f84364c14a0f9c7a81d2f2ea6", "e287c06077904002925102bdfcb9e223", "0f6031eaf46b4d24b66d94ba5e8a49c0", "43a1107155bf475cb47376268954c2ec" ] } }, "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "VBox(children=(HTML(value='
"Step | Training Loss | Validation Loss\n",
"---|---|---\n",
"10 | 1.681300 | 1.631856\n",
"20 | 1.583000 | 1.498535\n",
"30 | 1.413800 | 1.393412\n",
"40 | 1.360300 | 1.308431\n",
"50 | 1.271100 | 1.237813\n",
"60 | 1.225500 | 1.179109\n",
"70 | 1.168600 | 1.139814\n",
"80 | 1.082900 | 1.112842\n",
"90 | 1.114000 | 1.087292\n",
"100 | 1.077200 | 1.062716\n",
"110 | 1.020000 | 1.039254\n",
"120 | 0.994900 | 1.012544\n",
"130 | 0.990900 | 0.976476\n",
"140 | 0.910700 | 0.923615\n",
"150 | 0.871000 | 0.896943\n",
"160 | 0.897800 | 0.883769\n",
"170 | 0.869700 | 0.875816\n",
"180 | 0.867600 | 0.870564\n",
"190 | 0.797800 | 0.867839\n",
"200 | 0.804900 | 0.867121\n"
]
},
"metadata": {}
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"/usr/local/lib/python3.10/dist-packages/torch/utils/checkpoint.py:429: UserWarning: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.\n",
" warnings.warn(\n",
"/usr/local/lib/python3.10/dist-packages/bitsandbytes/autograd/_functions.py:322: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization\n",
" warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n",
"/usr/local/lib/python3.10/dist-packages/torch/utils/checkpoint.py:429: UserWarning: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.\n",
" warnings.warn(\n",
"/usr/local/lib/python3.10/dist-packages/bitsandbytes/autograd/_functions.py:322: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization\n",
" warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n",
"/usr/local/lib/python3.10/dist-packages/torch/utils/checkpoint.py:429: UserWarning: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.\n",
" warnings.warn(\n",
"/usr/local/lib/python3.10/dist-packages/bitsandbytes/autograd/_functions.py:322: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization\n",
" warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n",
"/usr/local/lib/python3.10/dist-packages/torch/utils/checkpoint.py:429: UserWarning: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.\n",
" warnings.warn(\n",
"/usr/local/lib/python3.10/dist-packages/bitsandbytes/autograd/_functions.py:322: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization\n",
" warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n",
"/usr/local/lib/python3.10/dist-packages/torch/utils/checkpoint.py:429: UserWarning: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.\n",
" warnings.warn(\n",
"/usr/local/lib/python3.10/dist-packages/bitsandbytes/autograd/_functions.py:322: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization\n",
" warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n",
"/usr/local/lib/python3.10/dist-packages/torch/utils/checkpoint.py:429: UserWarning: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.\n",
" warnings.warn(\n",
"/usr/local/lib/python3.10/dist-packages/bitsandbytes/autograd/_functions.py:322: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization\n",
" warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n",
"/usr/local/lib/python3.10/dist-packages/torch/utils/checkpoint.py:429: UserWarning: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.\n",
" warnings.warn(\n",
"/usr/local/lib/python3.10/dist-packages/bitsandbytes/autograd/_functions.py:322: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization\n",
" warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n",
"/usr/local/lib/python3.10/dist-packages/torch/utils/checkpoint.py:429: UserWarning: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.\n",
" warnings.warn(\n",
"/usr/local/lib/python3.10/dist-packages/bitsandbytes/autograd/_functions.py:322: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization\n",
" warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n",
"/usr/local/lib/python3.10/dist-packages/torch/utils/checkpoint.py:429: UserWarning: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.\n",
" warnings.warn(\n",
"/usr/local/lib/python3.10/dist-packages/bitsandbytes/autograd/_functions.py:322: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization\n",
" warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n",
"/usr/local/lib/python3.10/dist-packages/torch/utils/checkpoint.py:429: UserWarning: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.\n",
" warnings.warn(\n",
"/usr/local/lib/python3.10/dist-packages/bitsandbytes/autograd/_functions.py:322: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization\n",
" warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n",
"/usr/local/lib/python3.10/dist-packages/torch/utils/checkpoint.py:429: UserWarning: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.\n",
" warnings.warn(\n",
"/usr/local/lib/python3.10/dist-packages/bitsandbytes/autograd/_functions.py:322: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization\n",
" warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n",
"/usr/local/lib/python3.10/dist-packages/torch/utils/checkpoint.py:429: UserWarning: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.\n",
" warnings.warn(\n",
"/usr/local/lib/python3.10/dist-packages/bitsandbytes/autograd/_functions.py:322: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization\n",
" warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n",
"/usr/local/lib/python3.10/dist-packages/torch/utils/checkpoint.py:429: UserWarning: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.\n",
" warnings.warn(\n",
"/usr/local/lib/python3.10/dist-packages/bitsandbytes/autograd/_functions.py:322: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization\n",
" warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n",
"/usr/local/lib/python3.10/dist-packages/torch/utils/checkpoint.py:429: UserWarning: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.\n",
" warnings.warn(\n",
"/usr/local/lib/python3.10/dist-packages/bitsandbytes/autograd/_functions.py:322: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization\n",
" warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n",
"/usr/local/lib/python3.10/dist-packages/torch/utils/checkpoint.py:429: UserWarning: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.\n",
" warnings.warn(\n",
"/usr/local/lib/python3.10/dist-packages/bitsandbytes/autograd/_functions.py:322: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization\n",
" warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n",
"/usr/local/lib/python3.10/dist-packages/torch/utils/checkpoint.py:429: UserWarning: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.\n",
" warnings.warn(\n",
"/usr/local/lib/python3.10/dist-packages/bitsandbytes/autograd/_functions.py:322: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization\n",
" warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n",
"/usr/local/lib/python3.10/dist-packages/torch/utils/checkpoint.py:429: UserWarning: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.\n",
" warnings.warn(\n",
"/usr/local/lib/python3.10/dist-packages/bitsandbytes/autograd/_functions.py:322: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization\n",
" warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n",
"/usr/local/lib/python3.10/dist-packages/torch/utils/checkpoint.py:429: UserWarning: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.\n",
" warnings.warn(\n",
"/usr/local/lib/python3.10/dist-packages/bitsandbytes/autograd/_functions.py:322: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization\n",
" warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n",
"/usr/local/lib/python3.10/dist-packages/torch/utils/checkpoint.py:429: UserWarning: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.\n",
" warnings.warn(\n",
"/usr/local/lib/python3.10/dist-packages/bitsandbytes/autograd/_functions.py:322: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization\n",
" warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n"
]
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"TrainOutput(global_step=200, training_loss=1.100144896507263, metrics={'train_runtime': 3605.5794, 'train_samples_per_second': 0.222, 'train_steps_per_second': 0.055, 'total_flos': 1.6248515592192e+16, 'train_loss': 1.100144896507263, 'epoch': 3.25})"
]
},
"metadata": {},
"execution_count": 33
}
],
"source": [
"# Start training\n",
"trainer.train()"
]
},
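  {
   "cell_type": "markdown",
   "id": "save-lora-adapter-md",
   "metadata": {},
   "source": [
    "Training converges steadily: validation loss falls from about 1.63 to 0.87 over the 200 steps above. Before moving on, it is worth persisting the trained LoRA adapter. The next cell is a minimal sketch, not part of the original run: `save_pretrained` on a PEFT model writes only the small adapter weights and `adapter_config.json`, not the full base model. The output directory name is illustrative, and `tokenizer` is assumed to be the tokenizer loaded earlier in the notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "save-lora-adapter-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Minimal sketch: persist the LoRA adapter trained by trainer.train().\n",
    "# The directory name is illustrative, not a path used elsewhere in this notebook.\n",
    "peft_model_dir = \"./llama2-dialogue-summary-adapter\"\n",
    "\n",
    "# PEFT models override save_pretrained() to write only the adapter\n",
    "# weights and adapter_config.json, keeping the checkpoint small.\n",
    "trainer.model.save_pretrained(peft_model_dir)\n",
    "\n",
    "# Assumption: `tokenizer` was created earlier in the notebook; saving it\n",
    "# alongside the adapter lets the pair be reloaded together.\n",
    "tokenizer.save_pretrained(peft_model_dir)"
   ]
  },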
{
"cell_type": "code",
"execution_count": null,
"id": "93524441-36a1-4f7e-8307-b855e68c5c14",
"metadata": {
"tags": [],
"id": "93524441-36a1-4f7e-8307-b855e68c5c14",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 178
},
"outputId": "2bba0ba1-3427-4b6e-8cf9-258363851fe5"
},
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"/usr/local/lib/python3.10/dist-packages/bitsandbytes/autograd/_functions.py:322: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization\n",
" warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n"
]
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"
"Copy a token from your Hugging Face tokens page and paste it below.\n",
"Immediately click login after copying your token or it might be stored in plain text in this notebook file."