{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "provenance": [], "machine_shape": "hm", "gpuType": "A100" }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" }, "widgets": { "application/vnd.jupyter.widget-state+json": { "7c57de542a934cb0b91fb2b0e8c88e7b": { "model_module": "@jupyter-widgets/controls", "model_name": "VBoxModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "VBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "VBoxView", "box_style": "", "children": [ "IPY_MODEL_50a020332d8e4afbbaf5c86c5a316a8e", "IPY_MODEL_f388c680bd57423297e98246a4bd4eb0", "IPY_MODEL_8acc093a26e34ccc9e4e7c0252f4716b", "IPY_MODEL_07018e8f53474ca390b9922cd7f71661" ], "layout": "IPY_MODEL_c913025b0e76407e961e89c335ab1753" } }, "e09a6eac62cf4aa99168d0e50fec5314": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_56eef714b9c5492885577f21e9d62733", "placeholder": "", "style": "IPY_MODEL_1c5a50faaf8a49f49be022b5c2876d22", "value": "
"Step  Training Loss\n",
"   1  2.035200\n",
"   2  2.440300\n",
"   3  1.987100\n",
"   4  2.004400\n",
"   5  2.109200\n",
"   6  1.612300\n",
"   7  1.610300\n",
"   8  2.008700\n",
"   9  1.333500\n",
"  10  1.556800\n",
"  11  1.365500\n",
"  12  1.463800\n",
"  13  1.172800\n",
"  14  1.170800\n",
"  15  1.415200\n",
"  16  1.206500\n",
"  17  1.114500\n",
"  18  1.058100\n",
"  19  1.007300\n",
"  20  1.226100\n",
"  21  0.810900\n",
"  22  1.011300\n",
"  23  1.131600\n",
"  24  0.953600\n",
"  25  0.862700\n",
"  26  0.854000\n",
"  27  1.255600\n",
"  28  0.990600\n",
"  29  1.103300\n",
"  30  1.091000\n",
"  31  1.018300\n",
"  32  0.840600\n",
"  33  1.081000\n",
"  34  1.113600\n",
"  35  1.003300\n",
"  36  1.325900\n",
"  37  0.866900\n",
"  38  0.912600\n",
"  39  1.007300\n",
"  40  0.761800\n",
"  41  1.147900\n",
"  42  0.762900\n",
"  43  0.962800\n",
"  44  1.122300\n",
"  45  0.941600\n",
"  46  0.985300\n",
"  47  0.903500\n",
"  48  0.889100\n",
"  49  0.983100\n",
"  50  0.814300\n",
"  51  1.043200\n",
"  52  0.753100\n",
"  53  0.761000\n",
"  54  0.817300\n",
"  55  1.039500\n",
"  56  0.811700\n",
"  57  0.842200\n",
"  58  0.892900\n",
"  59  0.863500\n",
"  60  0.874400\n",
"  61  0.670500\n",
"  62  1.125400\n",
"  63  1.007000\n",
"  64  0.959700\n",
"  65  0.860100\n",
"  66  0.868600\n",
"  67  0.687900\n",
"  68  0.855600\n",
"  69  0.996800\n",
"  70  1.227800\n",
"  71  0.788800\n",
"  72  1.131100\n",
"  73  0.939900\n",
"  74  0.848600\n",
"  75  1.160700\n",
"  76  0.847100\n",
"  77  0.987500\n",
"  78  0.857900\n",
"  79  0.818400\n",
"  80  0.981400\n",
"  81  1.127600\n",
"  82  0.990600\n",
"  83  0.886300\n",
"  84  0.772400\n",
"  85  1.013000\n",
"  86  1.049300\n",
"  87  1.035500\n",
"  88  0.812300\n",
"  89  0.888700\n",
"  90  0.808700\n",
"  91  1.126400\n",
"  92  0.720200\n",
"  93  0.835700\n",
"  94  0.985800\n",
"  95  0.938100\n",
"  96  0.824300\n",
"  97  0.872600\n",
"  98  1.139100\n",
"  99  0.944100\n",
" 100  0.819500\n",
" 101  0.664200\n",
" 102  0.694100\n",
" 103  0.850700\n",
" 104  0.677200\n",
" 105  1.015500\n",
" 106  0.979900\n",
" 107  0.680900\n",
" 108  0.778400\n",
" 109  0.862600\n",
" 110  0.802300\n",
" 111  0.677100\n",
" 112  0.982300\n",
" 113  1.114900\n",
" 114  0.908700\n",
" 115  0.741500\n",
" 116  0.653900\n",
" 117  0.755000\n",
" 118  1.240100\n",
" 119  0.914800\n",
" 120  0.885100\n",
" 121  0.847100\n",
" 122  0.726500\n",
" 123  0.991600\n",
" 124  0.718400\n",
" 125  0.754300\n",
" 126  0.768900\n",
" 127  0.882800\n",
" 128  0.836900\n",
" 129  1.157700"
]
},
"metadata": {}
}
]
},
{
"cell_type": "markdown",
"source": [
"# Model Training on Dataset 2"
],
"metadata": {
"id": "veNHf3ltLuRM"
}
},
{
"cell_type": "code",
"source": [
"# Define trainer arguments\n",
"trainer_2_args = TrainingArguments(\n",
" per_device_train_batch_size=2,\n",
" gradient_accumulation_steps=2,\n",
" num_train_epochs=3,\n",
" learning_rate=2e-4,\n",
" warmup_ratio=0.03,\n",
" fp16=True,\n",
" logging_steps=5,\n",
" output_dir=\"outputs\")\n",
"\n",
"\n",
"# Define trainer\n",
"trainer_2 = Trainer(\n",
" model=model,\n",
" args=trainer_2_args,\n",
" train_dataset=tokenized_dataset_2[\"text\"],\n",
" data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False)\n",
")\n",
"\n",
"\n",
"# Train model\n",
"model.config.use_cache = False # Supress Warnings, re-enable for inference later\n",
"trainer_2.train()\n",
"\n",
"\n",
"# Save the fine-tuned model\n",
"trainer_2.save_model(\"finetuned_model_2\")"
],
"metadata": {
"id": "w6OeNB7Vzf70",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
},
"outputId": "12609eb1-af81-40a4-ad68-32fe0749d2ac"
},
"execution_count": 10,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
" "
]
},
"metadata": {}
}
]
},
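{
"cell_type": "markdown",
"source": [
"The table above is the Trainer's step-by-step loss log. As a hedged sketch (assuming `trainer_2` is still in memory and `matplotlib` is available), the same numbers can be read back from `trainer_2.state.log_history` and plotted:"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Sketch: plot the training-loss curve from the trainer's log history.\n",
"# Assumes trainer_2 from the cell above; entries containing a \"loss\" key\n",
"# are the per-logging-step training losses (every 5 steps here).\n",
"import matplotlib.pyplot as plt\n",
"\n",
"logs = [e for e in trainer_2.state.log_history if \"loss\" in e]\n",
"steps = [e[\"step\"] for e in logs]\n",
"losses = [e[\"loss\"] for e in logs]\n",
"\n",
"plt.plot(steps, losses)\n",
"plt.xlabel(\"Step\")\n",
"plt.ylabel(\"Training loss\")\n",
"plt.title(\"Dataset 2 fine-tuning loss\")\n",
"plt.show()"
],
"metadata": {},
"execution_count": null,
"outputs": []
},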
{
"cell_type": "markdown",
"source": [
"# Upload To HuggingFace Hub"
],
"metadata": {
"id": "ANKGOAzX7adc"
}
},
{
"cell_type": "code",
"source": [
"model.push_to_hub(\"John4Blues/Llama-3-8B-Therapy\", use_auth_token=True, commit_message=\"Just A Basic Trained Model\")"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 84,
"referenced_widgets": [
"1a741f5ef12546a693be8960e60674cb",
"32416b54fe03455bb8312d9885665a17",
"5920aec36c114a78ba1ec41c8755fa2b",
"c8c21ce1b1364a9f87be2cb78d428ecf",
"ca5130da04d24ed4ae95fe18158a5a62",
"9d18488e5bae46e29f73dd4d7fdddcf2",
"9d073b0f97094a1f93ea891814da353b",
"2a1f7704665b4c2a84ca643e940f2ae2",
"2a5a98922beb47abaca3274f0250aadd",
"c8996659a61c45c389ba6d251276d61e",
"654206dbf9f14fc294f51ddd5c2b7cf6"
]
},
"id": "6g9s3XK97xtY",
"outputId": "0bff6491-e420-4491-93f1-6edaf2bac8a0"
},
"execution_count": 12,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"adapter_model.safetensors: 0%| | 0.00/27.3M [00:00, ?B/s]"
]
},
"metadata": {}
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"CommitInfo(commit_url='https://huggingface.co/John4Blues/Llama-3-8B-Therapy/commit/af765571aaebac3fae1dea710f8f306651de60f9', commit_message='Just A Basic Trained Model', commit_description='', oid='af765571aaebac3fae1dea710f8f306651de60f9', pr_url=None, pr_revision=None, pr_num=None)"
],
"application/vnd.google.colaboratory.intrinsic+json": {
"type": "string"
}
},
"metadata": {},
"execution_count": 12
}
]
},
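{
"cell_type": "markdown",
"source": [
"Only the ~27 MB `adapter_model.safetensors` is uploaded, i.e. the repo holds LoRA adapter weights rather than a full checkpoint. A minimal reload sketch, assuming the base model is `meta-llama/Meta-Llama-3-8B-Instruct` (linked in the next section) and that `peft` is installed:"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Sketch: reload the pushed adapter on top of its base model.\n",
"# Assumption: the base checkpoint is meta-llama/Meta-Llama-3-8B-Instruct.\n",
"import torch\n",
"from transformers import AutoModelForCausalLM, AutoTokenizer\n",
"from peft import PeftModel\n",
"\n",
"base_id = \"meta-llama/Meta-Llama-3-8B-Instruct\"\n",
"base = AutoModelForCausalLM.from_pretrained(\n",
"    base_id, torch_dtype=torch.float16, device_map=\"auto\")\n",
"tokenizer = AutoTokenizer.from_pretrained(base_id)\n",
"\n",
"# Attach the fine-tuned LoRA weights from the Hub\n",
"model = PeftModel.from_pretrained(base, \"John4Blues/Llama-3-8B-Therapy\")"
],
"metadata": {},
"execution_count": null,
"outputs": []
},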
{
"cell_type": "markdown",
"source": [
"# Inferencing\n",
"\n",
"https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct"
],
"metadata": {
"id": "mN0SQ4y_jC29"
}
}
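,
{
"cell_type": "markdown",
"source": [
"A minimal generation sketch, assuming `model` and `tokenizer` from the training cells are still loaded; the prompt below is only an illustrative placeholder:"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Sketch: chat-style inference with the fine-tuned model.\n",
"# Re-enable the KV cache that was disabled for training.\n",
"model.config.use_cache = True\n",
"\n",
"messages = [\n",
"    {\"role\": \"user\", \"content\": \"I've been feeling anxious lately.\"}  # placeholder prompt\n",
"]\n",
"input_ids = tokenizer.apply_chat_template(\n",
"    messages, add_generation_prompt=True, return_tensors=\"pt\"\n",
").to(model.device)\n",
"\n",
"output_ids = model.generate(\n",
"    input_ids, max_new_tokens=256, do_sample=True, temperature=0.7)\n",
"\n",
"# Decode only the newly generated tokens\n",
"print(tokenizer.decode(output_ids[0][input_ids.shape[-1]:], skip_special_tokens=True))"
],
"metadata": {},
"execution_count": null,
"outputs": []
}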
]
}\n",
" \n",
"
\n",
" \n",
" \n",
" \n",
" Step \n",
" Training Loss \n",
" \n",
" \n",
" 5 \n",
" 1.937900 \n",
" \n",
" \n",
" 10 \n",
" 1.928900 \n",
" \n",
" \n",
" 15 \n",
" 1.850000 \n",
" \n",
" \n",
" 20 \n",
" 1.922100 \n",
" \n",
" \n",
" 25 \n",
" 1.673800 \n",
" \n",
" \n",
" 30 \n",
" 1.782600 \n",
" \n",
" \n",
" 35 \n",
" 1.923600 \n",
" \n",
" \n",
" 40 \n",
" 1.805500 \n",
" \n",
" \n",
" 45 \n",
" 1.884300 \n",
" \n",
" \n",
" 50 \n",
" 1.886500 \n",
" \n",
" \n",
" 55 \n",
" 1.782400 \n",
" \n",
" \n",
" 60 \n",
" 1.816200 \n",
" \n",
" \n",
" 65 \n",
" 1.919600 \n",
" \n",
" \n",
" 70 \n",
" 1.710600 \n",
" \n",
" \n",
" 75 \n",
" 1.896300 \n",
" \n",
" \n",
" 80 \n",
" 1.755300 \n",
" \n",
" \n",
" 85 \n",
" 1.819700 \n",
" \n",
" \n",
" 90 \n",
" 1.793800 \n",
" \n",
" \n",
" 95 \n",
" 1.757700 \n",
" \n",
" \n",
" 100 \n",
" 1.777400 \n",
" \n",
" \n",
" 105 \n",
" 1.692000 \n",
" \n",
" \n",
" 110 \n",
" 1.876600 \n",
" \n",
" \n",
" 115 \n",
" 1.770800 \n",
" \n",
" \n",
" 120 \n",
" 1.846000 \n",
" \n",
" \n",
" 125 \n",
" 1.906200 \n",
" \n",
" \n",
" 130 \n",
" 1.767400 \n",
" \n",
" \n",
" 135 \n",
" 1.749600 \n",
" \n",
" \n",
" 140 \n",
" 1.797900 \n",
" \n",
" \n",
" 145 \n",
" 1.762100 \n",
" \n",
" \n",
" 150 \n",
" 1.796600 \n",
" \n",
" \n",
" 155 \n",
" 1.796800 \n",
" \n",
" \n",
" 160 \n",
" 1.762500 \n",
" \n",
" \n",
" 165 \n",
" 1.832900 \n",
" \n",
" \n",
" 170 \n",
" 1.823500 \n",
" \n",
" \n",
" 175 \n",
" 1.885300 \n",
" \n",
" \n",
" 180 \n",
" 1.826200 \n",
" \n",
" \n",
" 185 \n",
" 1.799100 \n",
" \n",
" \n",
" 190 \n",
" 1.739100 \n",
" \n",
" \n",
" 195 \n",
" 1.867600 \n",
" \n",
" \n",
" 200 \n",
" 1.809800 \n",
" \n",
" \n",
" 205 \n",
" 1.800100 \n",
" \n",
" \n",
" 210 \n",
" 1.798900 \n",
" \n",
" \n",
" 215 \n",
" 1.835800 \n",
" \n",
" \n",
" 220 \n",
" 1.751300 \n",
" \n",
" \n",
" 225 \n",
" 1.710000 \n",
" \n",
" \n",
" 230 \n",
" 1.881700 \n",
" \n",
" \n",
" 235 \n",
" 1.793300 \n",
" \n",
" \n",
" 240 \n",
" 1.806900 \n",
" \n",
" \n",
" 245 \n",
" 1.770700 \n",
" \n",
" \n",
" 250 \n",
" 1.796700 \n",
" \n",
" \n",
" 255 \n",
" 1.769900 \n",
" \n",
" \n",
" 260 \n",
" 1.784300 \n",
" \n",
" \n",
" 265 \n",
" 1.811600 \n",
" \n",
" \n",
" 270 \n",
" 1.732000 \n",
" \n",
" \n",
" 275 \n",
" 1.666400 \n",
" \n",
" \n",
" 280 \n",
" 1.677400 \n",
" \n",
" \n",
" 285 \n",
" 1.820700 \n",
" \n",
" \n",
" 290 \n",
" 1.659500 \n",
" \n",
" \n",
" 295 \n",
" 1.667800 \n",
" \n",
" \n",
" 300 \n",
" 1.765100 \n",
" \n",
" \n",
" 305 \n",
" 1.719200 \n",
" \n",
" \n",
" 310 \n",
" 1.828000 \n",
" \n",
" \n",
" 315 \n",
" 1.805600 \n",
" \n",
" \n",
" 320 \n",
" 1.781000 \n",
" \n",
" \n",
" 325 \n",
" 1.662300 \n",
" \n",
" \n",
" 330 \n",
" 1.742200 \n",
" \n",
" \n",
" 335 \n",
" 1.714500 \n",
" \n",
" \n",
" 340 \n",
" 1.693700 \n",
" \n",
" \n",
" 345 \n",
" 1.608100 \n",
" \n",
" \n",
" 350 \n",
" 1.780700 \n",
" \n",
" \n",
" 355 \n",
" 1.694400 \n",
" \n",
" \n",
" 360 \n",
" 1.559900 \n",
" \n",
" \n",
" 365 \n",
" 1.641600 \n",
" \n",
" \n",
" 370 \n",
" 1.655600 \n",
" \n",
" \n",
" 375 \n",
" 1.719200 \n",
" \n",
" \n",
" 380 \n",
" 1.747800 \n",
" \n",
" \n",
" 385 \n",
" 1.653700 \n",
" \n",
" \n",
" 390 \n",
" 1.739900 \n",
" \n",
" \n",
" 395 \n",
" 1.651900 \n",
" \n",
" \n",
" 400 \n",
" 1.826100 \n",
" \n",
" \n",
" 405 \n",
" 1.788700 \n",
" \n",
" \n",
" 410 \n",
" 1.623900 \n",
" \n",
" \n",
" 415 \n",
" 1.672400 \n",
" \n",
" \n",
" 420 \n",
" 1.672100 \n",
" \n",
" \n",
" 425 \n",
" 1.791100 \n",
" \n",
" \n",
" 430 \n",
" 1.687000 \n",
" \n",
" \n",
" 435 \n",
" 1.698900 \n",
" \n",
" \n",
" 440 \n",
" 1.616600 \n",
" \n",
" \n",
" 445 \n",
" 1.539200 \n",
" \n",
" \n",
" 450 \n",
" 1.643000 \n",
" \n",
" \n",
" 455 \n",
" 1.748800 \n",
" \n",
" \n",
" 460 \n",
" 1.870800 \n",
" \n",
" \n",
" 465 \n",
" 1.726900 \n",
" \n",
" \n",
" 470 \n",
" 1.741500 \n",
" \n",
" \n",
" 475 \n",
" 1.761000 \n",
" \n",
" \n",
" 480 \n",
" 1.647400 \n",
" \n",
" \n",
" 485 \n",
" 1.606400 \n",
" \n",
" \n",
" 490 \n",
" 1.589900 \n",
" \n",
" \n",
" 495 \n",
" 1.634000 \n",
" \n",
" \n",
" 500 \n",
" 1.655500 \n",
" \n",
" \n",
" 505 \n",
" 1.813400 \n",
" \n",
" \n",
" 510 \n",
" 1.580000 \n",
" \n",
" \n",
" 515 \n",
" 1.584000 \n",
" \n",
" \n",
" 520 \n",
" 1.540400 \n",
" \n",
" \n",
" 525 \n",
" 1.585400 \n",
" \n",
" \n",
" 530 \n",
" 1.706400 \n",
" \n",
" \n",
" 535 \n",
" 1.712000 \n",
" \n",
" \n",
" 540 \n",
" 1.627300 \n",
" \n",
" \n",
" 545 \n",
" 1.625000 \n",
" \n",
" \n",
" 550 \n",
" 1.693900 \n",
" \n",
" \n",
" 555 \n",
" 1.672300 \n",
" \n",
" \n",
" 560 \n",
" 1.662200 \n",
" \n",
" \n",
" 565 \n",
" 1.644700 \n",
" \n",
" \n",
" 570 \n",
" 1.647400 \n",
" \n",
" \n",
" 575 \n",
" 1.651400 \n",
" \n",
" \n",
" 580 \n",
" 1.624000 \n",
" \n",
" \n",
" 585 \n",
" 1.666000 \n",
" \n",
" \n",
" 590 \n",
" 1.493200 \n",
" \n",
" \n",
" 595 \n",
" 1.655900 \n",
" \n",
" \n",
" 600 \n",
" 1.695700 \n",
" \n",
" \n",
" 605 \n",
" 1.711100 \n",
" \n",
" \n",
" 610 \n",
" 1.691600 \n",
" \n",
" \n",
" 615 \n",
" 1.628000 \n",
" \n",
" \n",
" 620 \n",
" 1.612300 \n",
" \n",
" \n",
" 625 \n",
" 1.544400 \n",
" \n",
" \n",
" 630 \n",
" 1.629700 \n",
" \n",
" \n",
" 635 \n",
" 1.757900 \n",
" \n",
" \n",
" 640 \n",
" 1.642900 \n",
" \n",
" \n",
" 645 \n",
" 1.578700 \n",
" \n",
" \n",
" 650 \n",
" 1.623900 \n",
" \n",
" \n",
" 655 \n",
" 1.693600 \n",
" \n",
" \n",
" 660 \n",
" 1.648000 \n",
" \n",
" \n",
" 665 \n",
" 1.645900 \n",
" \n",
" \n",
" 670 \n",
" 1.769000 \n",
" \n",
" \n",
" 675 \n",
" 1.613400 \n",
" \n",
" \n",
" 680 \n",
" 1.569900 \n",
" \n",
" \n",
" 685 \n",
" 1.792000 \n",
" \n",
" \n",
" 690 \n",
" 1.600800 \n",
" \n",
" \n",
" 695 \n",
" 1.557700 \n",
" \n",
" \n",
" 700 \n",
" 1.594300 \n",
" \n",
" \n",
" \n",
"705 \n",
" 1.680800 \n",
"