Delete lit-llama
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- lit-llama/.github/CODEOWNERS +0 -1
- lit-llama/.github/azure-gpu-tests.yml +0 -57
- lit-llama/.github/workflows/cpu-tests.yml +0 -52
- lit-llama/.gitignore +0 -16
- lit-llama/LICENSE +0 -201
- lit-llama/README.md +0 -193
- lit-llama/evaluate/adapter.py +0 -164
- lit-llama/evaluate/adapter_v2.py +0 -161
- lit-llama/evaluate/full.py +0 -147
- lit-llama/evaluate/lora.py +0 -172
- lit-llama/finetune/adapter.py +0 -262
- lit-llama/finetune/adapter_v2.py +0 -266
- lit-llama/finetune/full.py +0 -224
- lit-llama/finetune/lora.py +0 -218
- lit-llama/generate.py +0 -170
- lit-llama/generate/adapter.py +0 -106
- lit-llama/generate/adapter_v2.py +0 -108
- lit-llama/generate/full.py +0 -103
- lit-llama/generate/lora.py +0 -118
- lit-llama/howto/convert_lora_weights.md +0 -19
- lit-llama/howto/customize_paths.md +0 -33
- lit-llama/howto/download_weights.md +0 -130
- lit-llama/howto/finetune_adapter.md +0 -109
- lit-llama/howto/finetune_adapter_v2.md +0 -114
- lit-llama/howto/finetune_full.md +0 -106
- lit-llama/howto/finetune_lora.md +0 -90
- lit-llama/howto/inference.md +0 -43
- lit-llama/howto/tpus.md +0 -51
- lit-llama/howto/train_redpajama.md +0 -133
- lit-llama/howto/unstructured_dataset.md +0 -18
- lit-llama/lit_llama/__init__.py +0 -2
- lit-llama/lit_llama/adapter.py +0 -313
- lit-llama/lit_llama/adapter_v2.py +0 -45
- lit-llama/lit_llama/lora.py +0 -476
- lit-llama/lit_llama/model.py +0 -321
- lit-llama/lit_llama/packed_dataset.py +0 -260
- lit-llama/lit_llama/quantization.py +0 -614
- lit-llama/lit_llama/tokenizer.py +0 -49
- lit-llama/lit_llama/utils.py +0 -496
- lit-llama/pretrain/redpajama.py +0 -321
- lit-llama/pretrain/shakespeare.py +0 -166
- lit-llama/quantize/gptq.py +0 -238
- lit-llama/requirements.txt +0 -9
- lit-llama/scripts/convert_checkpoint.py +0 -141
- lit-llama/scripts/convert_hf_checkpoint.py +0 -167
- lit-llama/scripts/convert_lora_weights.py +0 -95
- lit-llama/scripts/download.py +0 -34
- lit-llama/scripts/prepare_alpaca.py +0 -131
- lit-llama/scripts/prepare_any_text.py +0 -97
- lit-llama/scripts/prepare_dolly.py +0 -133
lit-llama/.github/CODEOWNERS
DELETED
@@ -1 +0,0 @@
|
|
1 |
-
* @awaelchli @carmocca @lantiga
|
|
|
|
lit-llama/.github/azure-gpu-tests.yml
DELETED
@@ -1,57 +0,0 @@
|
|
1 |
-
# Python package
|
2 |
-
# Create and test a Python package on multiple Python versions.
|
3 |
-
# Add steps that analyze code, save the dist with the build record, publish to a PyPI-compatible index, and more:
|
4 |
-
# https://docs.microsoft.com/azure/devops/pipelines/languages/python
|
5 |
-
|
6 |
-
trigger:
|
7 |
-
tags:
|
8 |
-
include:
|
9 |
-
- '*'
|
10 |
-
branches:
|
11 |
-
include:
|
12 |
-
- "main"
|
13 |
-
- "refs/tags/*"
|
14 |
-
|
15 |
-
pr:
|
16 |
-
branches:
|
17 |
-
include:
|
18 |
-
- "main"
|
19 |
-
|
20 |
-
jobs:
|
21 |
-
- job: testing
|
22 |
-
# how long to run the job before automatically cancelling
|
23 |
-
timeoutInMinutes: "20"
|
24 |
-
# how much time to give 'run always even if cancelled tasks' before stopping them
|
25 |
-
cancelTimeoutInMinutes: "2"
|
26 |
-
pool: "lit-rtx-3090"
|
27 |
-
variables:
|
28 |
-
DEVICES: $( python -c 'print("$(Agent.Name)".split("_")[-1])' )
|
29 |
-
container:
|
30 |
-
image: "pytorchlightning/pytorch_lightning:base-cuda-py3.10-torch2.0-cuda11.7.1"
|
31 |
-
options: "--gpus=all --shm-size=8gb"
|
32 |
-
workspace:
|
33 |
-
clean: all
|
34 |
-
steps:
|
35 |
-
|
36 |
-
- bash: |
|
37 |
-
echo "##vso[task.setvariable variable=CUDA_VISIBLE_DEVICES]$(DEVICES)"
|
38 |
-
cuda_ver=$(python -c "import torch ; print(''.join(map(str, torch.version.cuda.split('.')[:2])))")
|
39 |
-
echo "##vso[task.setvariable variable=CUDA_VERSION_MM]$cuda_ver"
|
40 |
-
displayName: 'set env. vars'
|
41 |
-
|
42 |
-
- bash: |
|
43 |
-
echo $CUDA_VISIBLE_DEVICES
|
44 |
-
echo $CUDA_VERSION_MM
|
45 |
-
lspci | egrep 'VGA|3D'
|
46 |
-
whereis nvidia
|
47 |
-
nvidia-smi
|
48 |
-
which python && which pip
|
49 |
-
python --version && pip --version && pip list
|
50 |
-
python -c "import torch ; mgpu = torch.cuda.device_count() ; assert mgpu == 2, f'GPU: {mgpu}'"
|
51 |
-
displayName: 'Image info & NVIDIA'
|
52 |
-
|
53 |
-
- script: pip install pytest -r requirements.txt
|
54 |
-
displayName: 'Install dependencies'
|
55 |
-
|
56 |
-
- bash: pytest -v --durations=10 --disable-pytest-warnings --strict-markers --color=yes
|
57 |
-
displayName: 'Testing'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lit-llama/.github/workflows/cpu-tests.yml
DELETED
@@ -1,52 +0,0 @@
|
|
1 |
-
name: CPU tests
|
2 |
-
|
3 |
-
on:
|
4 |
-
push:
|
5 |
-
branches: [main]
|
6 |
-
pull_request:
|
7 |
-
branches: [main]
|
8 |
-
|
9 |
-
concurrency:
|
10 |
-
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref }}
|
11 |
-
cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
|
12 |
-
|
13 |
-
defaults:
|
14 |
-
run:
|
15 |
-
shell: bash
|
16 |
-
|
17 |
-
jobs:
|
18 |
-
cpu-tests:
|
19 |
-
runs-on: ${{ matrix.os }}
|
20 |
-
strategy:
|
21 |
-
fail-fast: false
|
22 |
-
matrix:
|
23 |
-
include:
|
24 |
-
- {os: "macOS-11", python-version: "3.10"}
|
25 |
-
- {os: "ubuntu-20.04", python-version: "3.10"}
|
26 |
-
- {os: "windows-2022", python-version: "3.10"}
|
27 |
-
timeout-minutes: 15
|
28 |
-
|
29 |
-
steps:
|
30 |
-
- uses: actions/checkout@v3
|
31 |
-
|
32 |
-
- name: Set up Python ${{ matrix.python-version }}
|
33 |
-
uses: actions/setup-python@v4
|
34 |
-
with:
|
35 |
-
python-version: ${{ matrix.python-version }}
|
36 |
-
cache: 'pip'
|
37 |
-
cache-dependency-path: |
|
38 |
-
requirements.txt
|
39 |
-
setup.py
|
40 |
-
|
41 |
-
- name: Run tests without the package installed
|
42 |
-
run: |
|
43 |
-
pip install pytest -r requirements.txt
|
44 |
-
pip list
|
45 |
-
|
46 |
-
pytest --disable-pytest-warnings --strict-markers --color=yes
|
47 |
-
|
48 |
-
- name: Run tests
|
49 |
-
run: |
|
50 |
-
pip install . --no-deps
|
51 |
-
|
52 |
-
pytest -v --durations=10 --disable-pytest-warnings --strict-markers --color=yes
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lit-llama/.gitignore
DELETED
@@ -1,16 +0,0 @@
|
|
1 |
-
__pycache__
|
2 |
-
.idea
|
3 |
-
.DS_Store
|
4 |
-
*.egg-info
|
5 |
-
build
|
6 |
-
|
7 |
-
# data
|
8 |
-
data
|
9 |
-
checkpoints
|
10 |
-
out
|
11 |
-
!data/shakespeare/prepare.py
|
12 |
-
wandb
|
13 |
-
|
14 |
-
# downloaded by our tests
|
15 |
-
original_model.py
|
16 |
-
original_adapter.py
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lit-llama/LICENSE
DELETED
@@ -1,201 +0,0 @@
|
|
1 |
-
Apache License
|
2 |
-
Version 2.0, January 2004
|
3 |
-
http://www.apache.org/licenses/
|
4 |
-
|
5 |
-
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
6 |
-
|
7 |
-
1. Definitions.
|
8 |
-
|
9 |
-
"License" shall mean the terms and conditions for use, reproduction,
|
10 |
-
and distribution as defined by Sections 1 through 9 of this document.
|
11 |
-
|
12 |
-
"Licensor" shall mean the copyright owner or entity authorized by
|
13 |
-
the copyright owner that is granting the License.
|
14 |
-
|
15 |
-
"Legal Entity" shall mean the union of the acting entity and all
|
16 |
-
other entities that control, are controlled by, or are under common
|
17 |
-
control with that entity. For the purposes of this definition,
|
18 |
-
"control" means (i) the power, direct or indirect, to cause the
|
19 |
-
direction or management of such entity, whether by contract or
|
20 |
-
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
21 |
-
outstanding shares, or (iii) beneficial ownership of such entity.
|
22 |
-
|
23 |
-
"You" (or "Your") shall mean an individual or Legal Entity
|
24 |
-
exercising permissions granted by this License.
|
25 |
-
|
26 |
-
"Source" form shall mean the preferred form for making modifications,
|
27 |
-
including but not limited to software source code, documentation
|
28 |
-
source, and configuration files.
|
29 |
-
|
30 |
-
"Object" form shall mean any form resulting from mechanical
|
31 |
-
transformation or translation of a Source form, including but
|
32 |
-
not limited to compiled object code, generated documentation,
|
33 |
-
and conversions to other media types.
|
34 |
-
|
35 |
-
"Work" shall mean the work of authorship, whether in Source or
|
36 |
-
Object form, made available under the License, as indicated by a
|
37 |
-
copyright notice that is included in or attached to the work
|
38 |
-
(an example is provided in the Appendix below).
|
39 |
-
|
40 |
-
"Derivative Works" shall mean any work, whether in Source or Object
|
41 |
-
form, that is based on (or derived from) the Work and for which the
|
42 |
-
editorial revisions, annotations, elaborations, or other modifications
|
43 |
-
represent, as a whole, an original work of authorship. For the purposes
|
44 |
-
of this License, Derivative Works shall not include works that remain
|
45 |
-
separable from, or merely link (or bind by name) to the interfaces of,
|
46 |
-
the Work and Derivative Works thereof.
|
47 |
-
|
48 |
-
"Contribution" shall mean any work of authorship, including
|
49 |
-
the original version of the Work and any modifications or additions
|
50 |
-
to that Work or Derivative Works thereof, that is intentionally
|
51 |
-
submitted to Licensor for inclusion in the Work by the copyright owner
|
52 |
-
or by an individual or Legal Entity authorized to submit on behalf of
|
53 |
-
the copyright owner. For the purposes of this definition, "submitted"
|
54 |
-
means any form of electronic, verbal, or written communication sent
|
55 |
-
to the Licensor or its representatives, including but not limited to
|
56 |
-
communication on electronic mailing lists, source code control systems,
|
57 |
-
and issue tracking systems that are managed by, or on behalf of, the
|
58 |
-
Licensor for the purpose of discussing and improving the Work, but
|
59 |
-
excluding communication that is conspicuously marked or otherwise
|
60 |
-
designated in writing by the copyright owner as "Not a Contribution."
|
61 |
-
|
62 |
-
"Contributor" shall mean Licensor and any individual or Legal Entity
|
63 |
-
on behalf of whom a Contribution has been received by Licensor and
|
64 |
-
subsequently incorporated within the Work.
|
65 |
-
|
66 |
-
2. Grant of Copyright License. Subject to the terms and conditions of
|
67 |
-
this License, each Contributor hereby grants to You a perpetual,
|
68 |
-
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
69 |
-
copyright license to reproduce, prepare Derivative Works of,
|
70 |
-
publicly display, publicly perform, sublicense, and distribute the
|
71 |
-
Work and such Derivative Works in Source or Object form.
|
72 |
-
|
73 |
-
3. Grant of Patent License. Subject to the terms and conditions of
|
74 |
-
this License, each Contributor hereby grants to You a perpetual,
|
75 |
-
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
76 |
-
(except as stated in this section) patent license to make, have made,
|
77 |
-
use, offer to sell, sell, import, and otherwise transfer the Work,
|
78 |
-
where such license applies only to those patent claims licensable
|
79 |
-
by such Contributor that are necessarily infringed by their
|
80 |
-
Contribution(s) alone or by combination of their Contribution(s)
|
81 |
-
with the Work to which such Contribution(s) was submitted. If You
|
82 |
-
institute patent litigation against any entity (including a
|
83 |
-
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
84 |
-
or a Contribution incorporated within the Work constitutes direct
|
85 |
-
or contributory patent infringement, then any patent licenses
|
86 |
-
granted to You under this License for that Work shall terminate
|
87 |
-
as of the date such litigation is filed.
|
88 |
-
|
89 |
-
4. Redistribution. You may reproduce and distribute copies of the
|
90 |
-
Work or Derivative Works thereof in any medium, with or without
|
91 |
-
modifications, and in Source or Object form, provided that You
|
92 |
-
meet the following conditions:
|
93 |
-
|
94 |
-
(a) You must give any other recipients of the Work or
|
95 |
-
Derivative Works a copy of this License; and
|
96 |
-
|
97 |
-
(b) You must cause any modified files to carry prominent notices
|
98 |
-
stating that You changed the files; and
|
99 |
-
|
100 |
-
(c) You must retain, in the Source form of any Derivative Works
|
101 |
-
that You distribute, all copyright, patent, trademark, and
|
102 |
-
attribution notices from the Source form of the Work,
|
103 |
-
excluding those notices that do not pertain to any part of
|
104 |
-
the Derivative Works; and
|
105 |
-
|
106 |
-
(d) If the Work includes a "NOTICE" text file as part of its
|
107 |
-
distribution, then any Derivative Works that You distribute must
|
108 |
-
include a readable copy of the attribution notices contained
|
109 |
-
within such NOTICE file, excluding those notices that do not
|
110 |
-
pertain to any part of the Derivative Works, in at least one
|
111 |
-
of the following places: within a NOTICE text file distributed
|
112 |
-
as part of the Derivative Works; within the Source form or
|
113 |
-
documentation, if provided along with the Derivative Works; or,
|
114 |
-
within a display generated by the Derivative Works, if and
|
115 |
-
wherever such third-party notices normally appear. The contents
|
116 |
-
of the NOTICE file are for informational purposes only and
|
117 |
-
do not modify the License. You may add Your own attribution
|
118 |
-
notices within Derivative Works that You distribute, alongside
|
119 |
-
or as an addendum to the NOTICE text from the Work, provided
|
120 |
-
that such additional attribution notices cannot be construed
|
121 |
-
as modifying the License.
|
122 |
-
|
123 |
-
You may add Your own copyright statement to Your modifications and
|
124 |
-
may provide additional or different license terms and conditions
|
125 |
-
for use, reproduction, or distribution of Your modifications, or
|
126 |
-
for any such Derivative Works as a whole, provided Your use,
|
127 |
-
reproduction, and distribution of the Work otherwise complies with
|
128 |
-
the conditions stated in this License.
|
129 |
-
|
130 |
-
5. Submission of Contributions. Unless You explicitly state otherwise,
|
131 |
-
any Contribution intentionally submitted for inclusion in the Work
|
132 |
-
by You to the Licensor shall be under the terms and conditions of
|
133 |
-
this License, without any additional terms or conditions.
|
134 |
-
Notwithstanding the above, nothing herein shall supersede or modify
|
135 |
-
the terms of any separate license agreement you may have executed
|
136 |
-
with Licensor regarding such Contributions.
|
137 |
-
|
138 |
-
6. Trademarks. This License does not grant permission to use the trade
|
139 |
-
names, trademarks, service marks, or product names of the Licensor,
|
140 |
-
except as required for reasonable and customary use in describing the
|
141 |
-
origin of the Work and reproducing the content of the NOTICE file.
|
142 |
-
|
143 |
-
7. Disclaimer of Warranty. Unless required by applicable law or
|
144 |
-
agreed to in writing, Licensor provides the Work (and each
|
145 |
-
Contributor provides its Contributions) on an "AS IS" BASIS,
|
146 |
-
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
147 |
-
implied, including, without limitation, any warranties or conditions
|
148 |
-
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
149 |
-
PARTICULAR PURPOSE. You are solely responsible for determining the
|
150 |
-
appropriateness of using or redistributing the Work and assume any
|
151 |
-
risks associated with Your exercise of permissions under this License.
|
152 |
-
|
153 |
-
8. Limitation of Liability. In no event and under no legal theory,
|
154 |
-
whether in tort (including negligence), contract, or otherwise,
|
155 |
-
unless required by applicable law (such as deliberate and grossly
|
156 |
-
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
-
liable to You for damages, including any direct, indirect, special,
|
158 |
-
incidental, or consequential damages of any character arising as a
|
159 |
-
result of this License or out of the use or inability to use the
|
160 |
-
Work (including but not limited to damages for loss of goodwill,
|
161 |
-
work stoppage, computer failure or malfunction, or any and all
|
162 |
-
other commercial damages or losses), even if such Contributor
|
163 |
-
has been advised of the possibility of such damages.
|
164 |
-
|
165 |
-
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
-
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
-
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
-
or other liability obligations and/or rights consistent with this
|
169 |
-
License. However, in accepting such obligations, You may act only
|
170 |
-
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
-
of any other Contributor, and only if You agree to indemnify,
|
172 |
-
defend, and hold each Contributor harmless for any liability
|
173 |
-
incurred by, or claims asserted against, such Contributor by reason
|
174 |
-
of your accepting any such warranty or additional liability.
|
175 |
-
|
176 |
-
END OF TERMS AND CONDITIONS
|
177 |
-
|
178 |
-
APPENDIX: How to apply the Apache License to your work.
|
179 |
-
|
180 |
-
To apply the Apache License to your work, attach the following
|
181 |
-
boilerplate notice, with the fields enclosed by brackets "[]"
|
182 |
-
replaced with your own identifying information. (Don't include
|
183 |
-
the brackets!) The text should be enclosed in the appropriate
|
184 |
-
comment syntax for the file format. We also recommend that a
|
185 |
-
file or class name and description of purpose be included on the
|
186 |
-
same "printed page" as the copyright notice for easier
|
187 |
-
identification within third-party archives.
|
188 |
-
|
189 |
-
Copyright [2023] Lightning AI
|
190 |
-
|
191 |
-
Licensed under the Apache License, Version 2.0 (the "License");
|
192 |
-
you may not use this file except in compliance with the License.
|
193 |
-
You may obtain a copy of the License at
|
194 |
-
|
195 |
-
http://www.apache.org/licenses/LICENSE-2.0
|
196 |
-
|
197 |
-
Unless required by applicable law or agreed to in writing, software
|
198 |
-
distributed under the License is distributed on an "AS IS" BASIS,
|
199 |
-
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
200 |
-
See the License for the specific language governing permissions and
|
201 |
-
limitations under the License.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lit-llama/README.md
DELETED
@@ -1,193 +0,0 @@
|
|
1 |
-
<div align="center">
|
2 |
-
<img src="https://pl-public-data.s3.amazonaws.com/assets_lightning/Lit_LLaMA_Badge3x.png" alt="Lit-LLaMA" width="128"/>
|
3 |
-
|
4 |
-
# ⚡ Lit-LLaMA ️
|
5 |
-
|
6 |
-
<!--
|
7 |
-
<p align="center">
|
8 |
-
<a href="https://www.lightning.ai/">Lightning.ai</a> •
|
9 |
-
<a href="https://lightning.ai/docs/pytorch/stable/">PyTorch Lightning</a> •
|
10 |
-
<a href="https://lightning.ai/docs/fabric/stable/">Fabric</a>
|
11 |
-
</p>
|
12 |
-
-->
|
13 |
-
|
14 |
-
![cpu-tests](https://github.com/lightning-AI/lit-llama/actions/workflows/cpu-tests.yml/badge.svg) [![Build Status](https://dev.azure.com/Lightning-AI/lit%20Models/_apis/build/status%2FLightning-AI.lit-LLaMA?branchName=main)](https://dev.azure.com/Lightning-AI/lit%20Models/_build/latest?definitionId=49&branchName=main) [![license](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://github.com/Lightning-AI/lit-llama/blob/master/LICENSE) [![Discord](https://img.shields.io/discord/1077906959069626439?style=plastic)](https://discord.gg/VptPCZkGNa)
|
15 |
-
|
16 |
-
<img src="https://pl-public-data.s3.amazonaws.com/assets_lightning/Llama_pineapple.gif" alt="Lit-LLaMA and pineapple pizza" width="500px"/>
|
17 |
-
|
18 |
-
</div>
|
19 |
-
|
20 |
-
# ⚡ Lit-LLaMA ️
|
21 |
-
Independent implementation of [LLaMA](<https://github.com/facebookresearch/llama>) pretraining, finetuning, and inference code that is fully open source under the **Apache 2.0 license.**
|
22 |
-
|
23 |
-
This implementation builds on [nanoGPT](<https://github.com/karpathy/nanoGPT>).
|
24 |
-
|
25 |
-
The open-source code in this repository works with the original LLaMA weights that are distributed by Meta under a [research-only license](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md#model-details).
|
26 |
-
|
27 |
-
New Apache 2.0 licensed weights are being released as part of the [Open LLaMA project](https://github.com/openlm-research/open_llama). Both the original research-only weights by Meta and the Open LLaMA weights can be [loaded in Lit-LLaMA](howto/download_weights.md).
|
28 |
-
|
29 |
-
## Why?
|
30 |
-
|
31 |
-
We believe that AI should be fully open source and part of the collective knowledge.
|
32 |
-
|
33 |
-
The original [LLaMA code](https://github.com/facebookresearch/llama) is [GPL licensed](https://github.com/facebookresearch/llama/blob/main/LICENSE) which means any project using it must also be released under GPL.
|
34 |
-
|
35 |
-
This "taints" any other code and prevents integration with the rest of the ecosystem.
|
36 |
-
|
37 |
-
**Lit-LLaMA solves that for good.**
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
## Design principles
|
42 |
-
**Lit-LLaMA** is:
|
43 |
-
|
44 |
-
- **Simple:** Single-file implementation without boilerplate.
|
45 |
-
- **Correct:** Numerically equivalent to the original model.
|
46 |
-
- **Optimized:** Runs on consumer hardware or at scale.
|
47 |
-
- **Open-source:** No strings attached.
|
48 |
-
|
49 |
-
## Get involved!
|
50 |
-
[Join our Discord](https://discord.gg/VptPCZkGNa) to build high-performance, truly open-source models for the common benefit of the community.
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
## Setup
|
55 |
-
|
56 |
-
Clone the repo
|
57 |
-
|
58 |
-
```bash
|
59 |
-
git clone https://github.com/Lightning-AI/lit-llama
|
60 |
-
cd lit-llama
|
61 |
-
```
|
62 |
-
|
63 |
-
install dependencies
|
64 |
-
|
65 |
-
```bash
|
66 |
-
pip install -r requirements.txt
|
67 |
-
```
|
68 |
-
|
69 |
-
You are all set! 🎉
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
## Use the model
|
74 |
-
|
75 |
-
To generate text predictions, you need to download the model weights. **If you don't have them, check out our [guide](howto/download_weights.md).**
|
76 |
-
|
77 |
-
Run inference:
|
78 |
-
|
79 |
-
```bash
|
80 |
-
python generate.py --prompt "Hello, my name is"
|
81 |
-
```
|
82 |
-
|
83 |
-
This will run the 7B model and require ~26 GB of GPU memory (A100 GPU).
|
84 |
-
|
85 |
-
[Full guide for generating samples from the model](howto/inference.md).
|
86 |
-
|
87 |
-
### Run Lit-LLaMA on consumer devices
|
88 |
-
|
89 |
-
On GPUs with `bfloat16` support, the `generate.py` script will automatically convert the weights and consume about ~14 GB.
|
90 |
-
For GPUs with less memory, or ones that don't support `bfloat16`, enable quantization (`--quantize llm.int8`):
|
91 |
-
|
92 |
-
```bash
|
93 |
-
python generate.py --quantize llm.int8 --prompt "Hello, my name is"
|
94 |
-
```
|
95 |
-
|
96 |
-
See `python generate.py --help` for more options.
|
97 |
-
|
98 |
-
You can also use GPTQ-style int4 quantization, but this needs conversions of the weights first:
|
99 |
-
|
100 |
-
```bash
|
101 |
-
python quantize/gptq.py --output_path checkpoints/lit-llama/7B/llama-gptq.4bit.pth --dtype bfloat16 --quantize gptq.int4
|
102 |
-
```
|
103 |
-
|
104 |
-
GPTQ-style int4 quantization brings GPU usage down to about ~5GB. As only the weights of the Linear layers are quantized, it is useful to also use `--dtype bfloat16` even with the quantization enabled.
|
105 |
-
|
106 |
-
With the generated quantized checkpoint generation quantization then works as usual with `--quantize gptq.int4` and the newly generated checkpoint file:
|
107 |
-
|
108 |
-
```bash
|
109 |
-
python generate.py --quantize gptq.int4 --checkpoint_path checkpoints/lit-llama/7B/llama-gptq.4bit.pth
|
110 |
-
```
|
111 |
-
|
112 |
-
[Full guide for generating samples from the model](howto/inference.md).
|
113 |
-
|
114 |
-
## Finetune the model
|
115 |
-
|
116 |
-
We provide a simple training scripts in `finetune/lora.py` and `finetune/adapter.py` that instruction-tunes a pretrained model on the [Alpaca](https://github.com/tatsu-lab/stanford_alpaca) dataset using the techniques of [LoRA](https://arxiv.org/abs/2106.09685) and [Adapter](https://arxiv.org/abs/2303.16199).
|
117 |
-
|
118 |
-
1. Download the data and generate a instruction tuning dataset:
|
119 |
-
|
120 |
-
```bash
|
121 |
-
python scripts/prepare_alpaca.py
|
122 |
-
```
|
123 |
-
|
124 |
-
2. Run the finetuning script
|
125 |
-
|
126 |
-
```bash
|
127 |
-
python finetune/lora.py
|
128 |
-
```
|
129 |
-
or
|
130 |
-
```bash
|
131 |
-
python finetune/adapter.py
|
132 |
-
```
|
133 |
-
|
134 |
-
It is expected that you have downloaded the pretrained weights as described above.
|
135 |
-
The finetuning requires at least one GPU with ~24 GB memory (RTX 3090). Follow the instructions in the script to efficiently fit your GPU memory.
|
136 |
-
Note: For some GPU models you might need to set `torch.backends.cuda.enable_flash_sdp(False)` (see comments at the top of the script).
|
137 |
-
|
138 |
-
More details about each finetuning method and how you can apply it to your own data can be found in our technical how-to guides.
|
139 |
-
|
140 |
-
### Finetuning How-To Guides
|
141 |
-
|
142 |
-
These technical tutorials illustrate how to run the finetuning code.
|
143 |
-
|
144 |
-
- [Finetune with LoRA](howto/finetune_lora.md)
|
145 |
-
- [Finetune with Adapters](howto/finetune_adapter.md)
|
146 |
-
|
147 |
-
### Understanding Finetuning -- Conceptual Tutorials
|
148 |
-
|
149 |
-
Looking for conceptual tutorials and explanations? We have some additional articles below:
|
150 |
-
|
151 |
-
- [Understanding Parameter-Efficient Finetuning of Large Language Models: From Prefix Tuning to LLaMA-Adapters](https://lightning.ai/pages/community/article/understanding-llama-adapters/)
|
152 |
-
|
153 |
-
## Pre-training
|
154 |
-
|
155 |
-
We provide a simple training script based on Fabric if you want to venture into pre-training on RedPajama, a reproduction of the original LLaMA dataset.
|
156 |
-
Conversion scripts for our optimized streaming `PackedDataset` are included.
|
157 |
-
|
158 |
-
Follow this guide to start pre-training on the RedPajama dataset:
|
159 |
-
|
160 |
-
- [Pretrain on RedPajama](howto/train_redpajama.md)
|
161 |
-
|
162 |
-
## Get involved!
|
163 |
-
|
164 |
-
We are on a quest towards fully open source AI.
|
165 |
-
|
166 |
-
<img align="right" src="https://pl-public-data.s3.amazonaws.com/assets_lightning/Lit_LLaMA_Illustration3x.png" alt="Lit-LLaMA" width="128"/>
|
167 |
-
|
168 |
-
Join us and start contributing, especially on the following areas:
|
169 |
-
|
170 |
-
- [ ] [Pre-training](https://github.com/Lightning-AI/lit-llama/labels/pre-training)
|
171 |
-
- [ ] [Fine-tuning (full and LoRA)](https://github.com/Lightning-AI/lit-llama/labels/fine-tuning)
|
172 |
-
- [ ] [Quantization](https://github.com/Lightning-AI/lit-llama/labels/quantization)
|
173 |
-
- [ ] [Sparsification](https://github.com/Lightning-AI/lit-llama/labels/sparsification)
|
174 |
-
|
175 |
-
Look at `train.py` for a starting point towards pre-training / fine-tuning using [Lightning Fabric](https://lightning.ai/docs/fabric/stable/).
|
176 |
-
|
177 |
-
We welcome all individual contributors, regardless of their level of experience or hardware. Your contributions are valuable, and we are excited to see what you can accomplish in this collaborative and supportive environment.
|
178 |
-
|
179 |
-
Unsure about contributing? Check out our [Contributing to Lit-LLaMA: A Hitchhiker’s Guide to the Quest for Fully Open-Source AI](https://lightning.ai/pages/community/tutorial/contributing-to-lit-llama-a-hitchhikers-guide-to-the-quest-for-fully-open-source-ai/) guide.
|
180 |
-
|
181 |
-
Don't forget to [join our Discord](https://discord.gg/VptPCZkGNa)!
|
182 |
-
|
183 |
-
## Acknowledgements
|
184 |
-
|
185 |
-
- [@karpathy](https://github.com/karpathy) for [nanoGPT](https://github.com/karpathy/nanoGPT)
|
186 |
-
- [@FacebookResearch](https://github.com/facebookresearch) for the original [LLaMA implementation](https://github.com/facebookresearch/llama)
|
187 |
-
- [@TimDettmers](https://github.com/TimDettmers) for [bitsandbytes](https://github.com/TimDettmers/bitsandbytes)
|
188 |
-
- [@Microsoft](https://github.com/microsoft) for [LoRA](https://github.com/microsoft/LoRA)
|
189 |
-
- [@IST-DASLab](https://github.com/IST-DASLab) for [GPTQ](https://github.com/IST-DASLab/gptq)
|
190 |
-
|
191 |
-
## License
|
192 |
-
|
193 |
-
Lit-LLaMA is released under the [Apache 2.0](https://github.com/Lightning-AI/lightning-llama/blob/main/LICENSE) license.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lit-llama/evaluate/adapter.py
DELETED
@@ -1,164 +0,0 @@
|
|
1 |
-
# This mimics GPTQ's evaluation metrics: https://github.com/IST-DASLab/gptq/
|
2 |
-
# Thanks to E. Frantar et al GPTQ: Accurate Post-training Compression for GPT, arXiv:2210.17323
|
3 |
-
import math
|
4 |
-
import sys
|
5 |
-
import time
|
6 |
-
from pathlib import Path
|
7 |
-
from typing import Optional
|
8 |
-
|
9 |
-
import lightning as L
|
10 |
-
import torch
|
11 |
-
import tqdm
|
12 |
-
|
13 |
-
# support running without installing as a package
|
14 |
-
wd = Path(__file__).parent.parent.resolve()
|
15 |
-
sys.path.append(str(wd))
|
16 |
-
|
17 |
-
from lit_llama import Tokenizer
|
18 |
-
from lit_llama.adapter import LLaMA
|
19 |
-
from lit_llama.utils import EmptyInitOnDevice, lazy_load, llama_model_lookup
|
20 |
-
from scripts.prepare_alpaca import generate_prompt
|
21 |
-
|
22 |
-
from datasets import load_dataset
|
23 |
-
|
24 |
-
instruction_tuning = True
|
25 |
-
|
26 |
-
|
27 |
-
def load_eval_data(dataset_name: str) -> str:
|
28 |
-
# this mimics gptq datautils
|
29 |
-
if dataset_name == "wikitext":
|
30 |
-
# traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train')
|
31 |
-
testdata = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
|
32 |
-
testdata = "\n\n".join(testdata["text"])
|
33 |
-
elif dataset_name == "ptb":
|
34 |
-
testdata = load_dataset("ptb_text_only", "penn_treebank", split="test")
|
35 |
-
testdata = "\n\n".join(testdata["sentence"])
|
36 |
-
elif dataset_name == "c4":
|
37 |
-
testdata = load_dataset(
|
38 |
-
"allenai/c4",
|
39 |
-
"allenai--c4",
|
40 |
-
data_files={"validation": "en/c4-validation.00000-of-00008.json.gz"},
|
41 |
-
split="validation",
|
42 |
-
)
|
43 |
-
testdata = " ".join(testdata[:1100]["text"])
|
44 |
-
|
45 |
-
else:
|
46 |
-
raise ValueError("invalid dataset name (wikitext, ptb, c4 are allowed)")
|
47 |
-
return testdata
|
48 |
-
|
49 |
-
|
50 |
-
@torch.inference_mode()
|
51 |
-
def main(
|
52 |
-
datasets: str = "wikitext,ptb,c4",
|
53 |
-
*,
|
54 |
-
# compilation fails as it does not support torch.complex64 for RoPE
|
55 |
-
# compile: bool = False,
|
56 |
-
accelerator: str = "auto",
|
57 |
-
adapter_path: Path = Path("out/adapter/alpaca/lit-llama-adapter-finetuned.pth"),
|
58 |
-
checkpoint_path: Path = Path("checkpoints/lit-llama/7B/lit-llama.pth"),
|
59 |
-
tokenizer_path: Path = Path("checkpoints/lit-llama/tokenizer.model"),
|
60 |
-
dtype: str = "float32",
|
61 |
-
quantize: Optional[str] = None,
|
62 |
-
) -> None:
|
63 |
-
"""Generates text samples based on a pre-trained LLaMA model and tokenizer.
|
64 |
-
|
65 |
-
Args:
|
66 |
-
datasets: The datasets to use as a comma separated string
|
67 |
-
# compile: Whether to compile the model.
|
68 |
-
accelerator: The hardware to run on. Possible choices are:
|
69 |
-
``"cpu"``, ``"cuda"``, ``"mps"``, ``"gpu"``, ``"tpu"``, ``"auto"``.
|
70 |
-
adapter_path: Path to the checkpoint with trained adapter weights, which are the output of
|
71 |
-
`finetune_adapter.py`.
|
72 |
-
checkpoint_path: The checkpoint path to load.
|
73 |
-
tokenizer_path: The tokenizer path to load.
|
74 |
-
dtype: The tensor dtype for choosing the floating-point precision
|
75 |
-
quantize: Whether to quantize the model and using which method:
|
76 |
-
``"llm.int8"``: LLM.int8() mode,
|
77 |
-
``"gptq.int4"``: GPTQ 4-bit mode.
|
78 |
-
"""
|
79 |
-
assert adapter_path.is_file()
|
80 |
-
assert checkpoint_path.is_file()
|
81 |
-
assert tokenizer_path.is_file()
|
82 |
-
|
83 |
-
fabric = L.Fabric(accelerator=accelerator, devices=1)
|
84 |
-
|
85 |
-
dt = getattr(torch, dtype, None)
|
86 |
-
if not isinstance(dt, torch.dtype):
|
87 |
-
raise ValueError(f"{dtype} is not a valid dtype.")
|
88 |
-
dtype = dt
|
89 |
-
|
90 |
-
print("Loading model ...", file=sys.stderr)
|
91 |
-
t0 = time.time()
|
92 |
-
with lazy_load(checkpoint_path) as pretrained_checkpoint, lazy_load(adapter_path) as adapter_checkpoint:
|
93 |
-
name = llama_model_lookup(pretrained_checkpoint)
|
94 |
-
|
95 |
-
with EmptyInitOnDevice(
|
96 |
-
device=fabric.device, dtype=dtype, quantization_mode=quantize
|
97 |
-
):
|
98 |
-
model = LLaMA.from_name(name)
|
99 |
-
|
100 |
-
# 1. Load the pretrained weights
|
101 |
-
model.load_state_dict(pretrained_checkpoint, strict=False)
|
102 |
-
# 2. Load the fine-tuned adapter weights
|
103 |
-
model.load_state_dict(adapter_checkpoint, strict=False)
|
104 |
-
|
105 |
-
print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr)
|
106 |
-
|
107 |
-
model.eval()
|
108 |
-
|
109 |
-
# if compile:
|
110 |
-
# model = torch.compile(model)
|
111 |
-
|
112 |
-
total_toks = 0
|
113 |
-
model = fabric.setup_module(model)
|
114 |
-
|
115 |
-
tokenizer = Tokenizer(tokenizer_path)
|
116 |
-
|
117 |
-
for dsname in datasets.split(","):
|
118 |
-
test_string = load_eval_data(dsname)
|
119 |
-
|
120 |
-
if instruction_tuning:
|
121 |
-
sample = {"instruction": test_string, "input": input}
|
122 |
-
test_string = generate_prompt(sample)
|
123 |
-
|
124 |
-
encoded_text = tokenizer.encode(
|
125 |
-
test_string, bos=True, eos=False, device=fabric.device
|
126 |
-
)
|
127 |
-
encoded_text = encoded_text[
|
128 |
-
None, : 256 * model.config.block_size
|
129 |
-
] # add batch dimension, trim like gptq implementation
|
130 |
-
t0 = time.perf_counter()
|
131 |
-
|
132 |
-
nlls = 0
|
133 |
-
toks = 0
|
134 |
-
block_size = 2048 # this is for compat with gptq, and indeed we get much worse beyond this (https://github.com/facebookresearch/llama/blob/57b0eb62de0636e75af471e49e2f1862d908d9d8/llama/model.py#L30)
|
135 |
-
for i in tqdm.tqdm(range(0, encoded_text.shape[1], block_size)):
|
136 |
-
inp = encoded_text[:, i : i + block_size]
|
137 |
-
logits = model(inp)[0]
|
138 |
-
nll = torch.nn.functional.cross_entropy(
|
139 |
-
logits[:-1], inp[0, 1:].to(dtype=torch.long), reduction="sum"
|
140 |
-
)
|
141 |
-
toks += inp.size(1) - 1
|
142 |
-
nlls += nll.item()
|
143 |
-
|
144 |
-
print(encoded_text.shape, logits.shape)
|
145 |
-
ppl = math.exp(nlls / toks)
|
146 |
-
print(f"Perplexity on {dsname}: {ppl:.2f}")
|
147 |
-
total_toks += toks
|
148 |
-
|
149 |
-
t = time.perf_counter() - t0
|
150 |
-
print(
|
151 |
-
f"\n\nTime for inference: {t:.02f} sec total, {total_toks / t:.02f} tokens/sec",
|
152 |
-
file=sys.stderr,
|
153 |
-
)
|
154 |
-
print(
|
155 |
-
f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB",
|
156 |
-
file=sys.stderr,
|
157 |
-
)
|
158 |
-
|
159 |
-
|
160 |
-
if __name__ == "__main__":
|
161 |
-
from jsonargparse import CLI
|
162 |
-
|
163 |
-
torch.set_float32_matmul_precision("high")
|
164 |
-
CLI(main)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lit-llama/evaluate/adapter_v2.py
DELETED
@@ -1,161 +0,0 @@
|
|
1 |
-
# This mimics GPTQ's evaluation metrics: https://github.com/IST-DASLab/gptq/
|
2 |
-
# Thanks to E. Frantar et al GPTQ: Accurate Post-training Compression for GPT, arXiv:2210.17323
|
3 |
-
import math
|
4 |
-
import sys
|
5 |
-
import time
|
6 |
-
from pathlib import Path
|
7 |
-
from typing import Optional
|
8 |
-
|
9 |
-
import lightning as L
|
10 |
-
import torch
|
11 |
-
import tqdm
|
12 |
-
|
13 |
-
# support running without installing as a package
|
14 |
-
wd = Path(__file__).parent.parent.resolve()
|
15 |
-
sys.path.append(str(wd))
|
16 |
-
|
17 |
-
from lit_llama import Tokenizer
|
18 |
-
from lit_llama.adapter import LLaMA
|
19 |
-
from lit_llama.utils import EmptyInitOnDevice, lazy_load, llama_model_lookup
|
20 |
-
from lit_llama.adapter_v2 import add_adapter_v2_parameters_to_linear_layers
|
21 |
-
from scripts.prepare_alpaca import generate_prompt
|
22 |
-
|
23 |
-
from datasets import load_dataset
|
24 |
-
|
25 |
-
|
26 |
-
def load_eval_data(dataset_name: str) -> str:
|
27 |
-
# this mimics gptq datautils
|
28 |
-
if dataset_name == "wikitext":
|
29 |
-
# traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train')
|
30 |
-
testdata = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
|
31 |
-
testdata = "\n\n".join(testdata["text"])
|
32 |
-
elif dataset_name == "ptb":
|
33 |
-
testdata = load_dataset("ptb_text_only", "penn_treebank", split="test")
|
34 |
-
testdata = "\n\n".join(testdata["sentence"])
|
35 |
-
elif dataset_name == "c4":
|
36 |
-
testdata = load_dataset(
|
37 |
-
"allenai/c4",
|
38 |
-
"allenai--c4",
|
39 |
-
data_files={"validation": "en/c4-validation.00000-of-00008.json.gz"},
|
40 |
-
split="validation",
|
41 |
-
)
|
42 |
-
testdata = " ".join(testdata[:1100]["text"])
|
43 |
-
|
44 |
-
else:
|
45 |
-
raise ValueError("invalid dataset name (wikitext, ptb, c4 are allowed)")
|
46 |
-
return testdata
|
47 |
-
|
48 |
-
|
49 |
-
@torch.inference_mode()
|
50 |
-
def main(
|
51 |
-
datasets: str = "wikitext,ptb,c4",
|
52 |
-
*,
|
53 |
-
accelerator: str = "auto",
|
54 |
-
adapter_path: Path = Path("out/adapter_v2/alpaca/lit-llama-adapter-finetuned.pth"),
|
55 |
-
checkpoint_path: Path = Path("checkpoints/lit-llama/7B/lit-llama.pth"),
|
56 |
-
tokenizer_path: Path = Path("checkpoints/lit-llama/tokenizer.model"),
|
57 |
-
dtype: str = "float32",
|
58 |
-
quantize: Optional[str] = None,
|
59 |
-
) -> None:
|
60 |
-
"""Generates text samples based on a pre-trained LLaMA model and tokenizer.
|
61 |
-
|
62 |
-
Args:
|
63 |
-
datasets: The datasets to use as a comma separated string
|
64 |
-
accelerator: The hardware to run on. Possible choices are:
|
65 |
-
``"cpu"``, ``"cuda"``, ``"mps"``, ``"gpu"``, ``"tpu"``, ``"auto"``.
|
66 |
-
adapter_path: Path to the checkpoint with trained adapter weights, which are the output of
|
67 |
-
`finetune_adapter_v2.py`.
|
68 |
-
checkpoint_path: The checkpoint path to load.
|
69 |
-
tokenizer_path: The tokenizer path to load.
|
70 |
-
dtype: The tensor dtype for choosing the floating-point precision
|
71 |
-
quantize: Whether to quantize the model and using which method:
|
72 |
-
``"llm.int8"``: LLM.int8() mode,
|
73 |
-
``"gptq.int4"``: GPTQ 4-bit mode.
|
74 |
-
"""
|
75 |
-
assert adapter_path.is_file()
|
76 |
-
assert checkpoint_path.is_file()
|
77 |
-
assert tokenizer_path.is_file()
|
78 |
-
|
79 |
-
fabric = L.Fabric(accelerator=accelerator, devices=1)
|
80 |
-
|
81 |
-
dt = getattr(torch, dtype, None)
|
82 |
-
if not isinstance(dt, torch.dtype):
|
83 |
-
raise ValueError(f"{dtype} is not a valid dtype.")
|
84 |
-
dtype = dt
|
85 |
-
|
86 |
-
print("Loading model ...", file=sys.stderr)
|
87 |
-
t0 = time.time()
|
88 |
-
with lazy_load(checkpoint_path) as pretrained_checkpoint, lazy_load(adapter_path) as adapter_checkpoint:
|
89 |
-
name = llama_model_lookup(pretrained_checkpoint)
|
90 |
-
|
91 |
-
with EmptyInitOnDevice(
|
92 |
-
device=fabric.device, dtype=dtype, quantization_mode=quantize
|
93 |
-
):
|
94 |
-
model = LLaMA.from_name(name)
|
95 |
-
add_adapter_v2_parameters_to_linear_layers(model)
|
96 |
-
|
97 |
-
# 1. Load the pretrained weights
|
98 |
-
model.load_state_dict(pretrained_checkpoint, strict=False)
|
99 |
-
# 2. Load the fine-tuned adapter weights
|
100 |
-
model.load_state_dict(adapter_checkpoint, strict=False)
|
101 |
-
|
102 |
-
print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr)
|
103 |
-
|
104 |
-
model.eval()
|
105 |
-
|
106 |
-
# if compile:
|
107 |
-
# model = torch.compile(model)
|
108 |
-
|
109 |
-
total_toks = 0
|
110 |
-
model = fabric.setup_module(model)
|
111 |
-
|
112 |
-
tokenizer = Tokenizer(tokenizer_path)
|
113 |
-
|
114 |
-
for dsname in datasets.split(","):
|
115 |
-
test_string = load_eval_data(dsname)
|
116 |
-
|
117 |
-
sample = {"instruction": test_string, "input": input}
|
118 |
-
test_string = generate_prompt(sample)
|
119 |
-
|
120 |
-
encoded_text = tokenizer.encode(
|
121 |
-
test_string, bos=True, eos=False, device=fabric.device
|
122 |
-
)
|
123 |
-
encoded_text = encoded_text[
|
124 |
-
None, : 256 * model.config.block_size
|
125 |
-
] # add batch dimension, trim like gptq implementation
|
126 |
-
t0 = time.perf_counter()
|
127 |
-
|
128 |
-
nlls = 0
|
129 |
-
toks = 0
|
130 |
-
|
131 |
-
block_size = 2048 # this is for compat with gptq, and indeed we get much worse beyond this (https://github.com/facebookresearch/llama/blob/57b0eb62de0636e75af471e49e2f1862d908d9d8/llama/model.py#L30)
|
132 |
-
for i in tqdm.tqdm(range(0, encoded_text.shape[1], block_size)):
|
133 |
-
inp = encoded_text[:, i : i + block_size]
|
134 |
-
logits = model(inp)[0]
|
135 |
-
nll = torch.nn.functional.cross_entropy(
|
136 |
-
logits[:-1], inp[0, 1:].to(dtype=torch.long), reduction="sum"
|
137 |
-
)
|
138 |
-
toks += inp.size(1) - 1
|
139 |
-
nlls += nll.item()
|
140 |
-
|
141 |
-
print(encoded_text.shape, logits.shape)
|
142 |
-
ppl = math.exp(nlls / toks)
|
143 |
-
print(f"Perplexity on {dsname}: {ppl:.2f}")
|
144 |
-
total_toks += toks
|
145 |
-
|
146 |
-
t = time.perf_counter() - t0
|
147 |
-
print(
|
148 |
-
f"\n\nTime for inference: {t:.02f} sec total, {total_toks / t:.02f} tokens/sec",
|
149 |
-
file=sys.stderr,
|
150 |
-
)
|
151 |
-
print(
|
152 |
-
f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB",
|
153 |
-
file=sys.stderr,
|
154 |
-
)
|
155 |
-
|
156 |
-
|
157 |
-
if __name__ == "__main__":
|
158 |
-
from jsonargparse import CLI
|
159 |
-
|
160 |
-
torch.set_float32_matmul_precision("high")
|
161 |
-
CLI(main)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lit-llama/evaluate/full.py
DELETED
@@ -1,147 +0,0 @@
|
|
1 |
-
# This mimics GPTQ's evaluation metrics: https://github.com/IST-DASLab/gptq/
|
2 |
-
# Thanks to E. Frantar et al GPTQ: Accurate Post-training Compression for GPT, arXiv:2210.17323
|
3 |
-
import math
|
4 |
-
import sys
|
5 |
-
import time
|
6 |
-
from pathlib import Path
|
7 |
-
from typing import Optional
|
8 |
-
|
9 |
-
import lightning as L
|
10 |
-
import torch
|
11 |
-
import tqdm
|
12 |
-
|
13 |
-
# support running without installing as a package
|
14 |
-
wd = Path(__file__).parent.parent.resolve()
|
15 |
-
sys.path.append(str(wd))
|
16 |
-
|
17 |
-
from lit_llama import LLaMA, Tokenizer
|
18 |
-
from lit_llama.utils import EmptyInitOnDevice
|
19 |
-
|
20 |
-
from datasets import load_dataset
|
21 |
-
|
22 |
-
|
23 |
-
def load_eval_data(dataset_name: str) -> str:
|
24 |
-
# this mimics gptq datautils
|
25 |
-
if dataset_name == "wikitext":
|
26 |
-
# traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train')
|
27 |
-
testdata = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
|
28 |
-
testdata = "\n\n".join(testdata["text"])
|
29 |
-
elif dataset_name == "ptb":
|
30 |
-
testdata = load_dataset("ptb_text_only", "penn_treebank", split="test")
|
31 |
-
testdata = "\n\n".join(testdata["sentence"])
|
32 |
-
elif dataset_name == "c4":
|
33 |
-
testdata = load_dataset(
|
34 |
-
"allenai/c4",
|
35 |
-
"allenai--c4",
|
36 |
-
data_files={"validation": "en/c4-validation.00000-of-00008.json.gz"},
|
37 |
-
split="validation",
|
38 |
-
)
|
39 |
-
testdata = " ".join(testdata[:1100]["text"])
|
40 |
-
|
41 |
-
else:
|
42 |
-
raise ValueError("invalid dataset name (wikitext, ptb, c4 are allowed)")
|
43 |
-
return testdata
|
44 |
-
|
45 |
-
|
46 |
-
def main(
|
47 |
-
datasets: str = "wikitext,ptb,c4",
|
48 |
-
*,
|
49 |
-
# compilation fails as it does not support torch.complex64 for RoPE
|
50 |
-
# compile: bool = False,
|
51 |
-
accelerator: str = "auto",
|
52 |
-
checkpoint_path: Optional[Path] = None,
|
53 |
-
tokenizer_path: Path = Path("checkpoints/lit-llama/tokenizer.model"),
|
54 |
-
model_size: str = "7B",
|
55 |
-
dtype: str = "float32",
|
56 |
-
quantize: Optional[str] = None,
|
57 |
-
) -> None:
|
58 |
-
"""Generates text samples based on a pre-trained LLaMA model and tokenizer.
|
59 |
-
|
60 |
-
Args:
|
61 |
-
datasets: The datasets to use as a comma separated string
|
62 |
-
# compile: Whether to compile the model.
|
63 |
-
accelerator: The hardware to run on. Possible choices are:
|
64 |
-
``"cpu"``, ``"cuda"``, ``"mps"``, ``"gpu"``, ``"tpu"``, ``"auto"``.
|
65 |
-
checkpoint_path: The checkpoint path to load.
|
66 |
-
tokenizer_path: The tokenizer path to load.
|
67 |
-
dtype: The tensor dtype for choosing the floating-point precision
|
68 |
-
quantize: Whether to quantize the model and using which method:
|
69 |
-
``"llm.int8"``: LLM.int8() mode,
|
70 |
-
``"gptq.int4"``: GPTQ 4-bit mode.
|
71 |
-
"""
|
72 |
-
if not checkpoint_path:
|
73 |
-
checkpoint_path = Path(f"checkpoints/lit-llama/{model_size}/lit-llama.pth")
|
74 |
-
assert checkpoint_path.is_file()
|
75 |
-
assert tokenizer_path.is_file()
|
76 |
-
|
77 |
-
fabric = L.Fabric(accelerator=accelerator, devices=1)
|
78 |
-
|
79 |
-
dt = getattr(torch, dtype, None)
|
80 |
-
if not isinstance(dt, torch.dtype):
|
81 |
-
raise ValueError(f"{dtype} is not a valid dtype.")
|
82 |
-
dtype = dt
|
83 |
-
|
84 |
-
with EmptyInitOnDevice(
|
85 |
-
device=fabric.device, dtype=dtype, quantization_mode=quantize
|
86 |
-
):
|
87 |
-
print("Loading model ...", file=sys.stderr)
|
88 |
-
t0 = time.time()
|
89 |
-
model = LLaMA.from_name(model_size)
|
90 |
-
checkpoint = torch.load(checkpoint_path)
|
91 |
-
model.load_state_dict(checkpoint)
|
92 |
-
print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr)
|
93 |
-
|
94 |
-
model.eval()
|
95 |
-
|
96 |
-
# if compile:
|
97 |
-
# model = torch.compile(model)
|
98 |
-
|
99 |
-
total_toks = 0
|
100 |
-
model = fabric.setup_module(model)
|
101 |
-
|
102 |
-
tokenizer = Tokenizer(tokenizer_path)
|
103 |
-
|
104 |
-
for dsname in datasets.split(","):
|
105 |
-
test_string = load_eval_data(dsname)
|
106 |
-
encoded_text = tokenizer.encode(
|
107 |
-
test_string, bos=True, eos=False, device=fabric.device
|
108 |
-
)
|
109 |
-
encoded_text = encoded_text[
|
110 |
-
None, : 256 * model.config.block_size
|
111 |
-
] # add batch dimension, trim like gptq implementation
|
112 |
-
t0 = time.perf_counter()
|
113 |
-
|
114 |
-
nlls = 0
|
115 |
-
toks = 0
|
116 |
-
with torch.inference_mode():
|
117 |
-
block_size = 2048 # this is for compat with gptq, and indeed we get much worse beyond this (https://github.com/facebookresearch/llama/blob/57b0eb62de0636e75af471e49e2f1862d908d9d8/llama/model.py#L30)
|
118 |
-
for i in tqdm.tqdm(range(0, encoded_text.shape[1], block_size)):
|
119 |
-
inp = encoded_text[:, i : i + block_size]
|
120 |
-
logits = model(inp)[0]
|
121 |
-
nll = torch.nn.functional.cross_entropy(
|
122 |
-
logits[:-1], inp[0, 1:].to(dtype=torch.long), reduction="sum"
|
123 |
-
)
|
124 |
-
toks += inp.size(1) - 1
|
125 |
-
nlls += nll.item()
|
126 |
-
|
127 |
-
print(encoded_text.shape, logits.shape)
|
128 |
-
ppl = math.exp(nlls / toks)
|
129 |
-
print(f"Perplexity on {dsname}: {ppl:.2f}")
|
130 |
-
total_toks += toks
|
131 |
-
|
132 |
-
t = time.perf_counter() - t0
|
133 |
-
print(
|
134 |
-
f"\n\nTime for inference: {t:.02f} sec total, {total_toks / t:.02f} tokens/sec",
|
135 |
-
file=sys.stderr,
|
136 |
-
)
|
137 |
-
print(
|
138 |
-
f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB",
|
139 |
-
file=sys.stderr,
|
140 |
-
)
|
141 |
-
|
142 |
-
|
143 |
-
if __name__ == "__main__":
|
144 |
-
from jsonargparse import CLI
|
145 |
-
|
146 |
-
torch.set_float32_matmul_precision("high")
|
147 |
-
CLI(main)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lit-llama/evaluate/lora.py
DELETED
@@ -1,172 +0,0 @@
|
|
1 |
-
# This mimics GPTQ's evaluation metrics: https://github.com/IST-DASLab/gptq/
|
2 |
-
# Thanks to E. Frantar et al GPTQ: Accurate Post-training Compression for GPT, arXiv:2210.17323
|
3 |
-
import math
|
4 |
-
import sys
|
5 |
-
import time
|
6 |
-
from pathlib import Path
|
7 |
-
from typing import Optional
|
8 |
-
|
9 |
-
import lightning as L
|
10 |
-
import torch
|
11 |
-
import tqdm
|
12 |
-
|
13 |
-
# support running without installing as a package
|
14 |
-
wd = Path(__file__).parent.parent.resolve()
|
15 |
-
sys.path.append(str(wd))
|
16 |
-
|
17 |
-
from lit_llama import LLaMA, Tokenizer
|
18 |
-
from lit_llama.utils import EmptyInitOnDevice, lazy_load, llama_model_lookup
|
19 |
-
from lit_llama.lora import lora
|
20 |
-
from scripts.prepare_alpaca import generate_prompt
|
21 |
-
|
22 |
-
from datasets import load_dataset
|
23 |
-
|
24 |
-
instruction_tuning = True
|
25 |
-
lora_r = 8
|
26 |
-
lora_alpha = 16
|
27 |
-
lora_dropout = 0.05
|
28 |
-
|
29 |
-
|
30 |
-
def load_eval_data(dataset_name: str) -> str:
|
31 |
-
# this mimics gptq datautils
|
32 |
-
if dataset_name == "wikitext":
|
33 |
-
# traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train')
|
34 |
-
testdata = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
|
35 |
-
testdata = "\n\n".join(testdata["text"])
|
36 |
-
elif dataset_name == "ptb":
|
37 |
-
testdata = load_dataset("ptb_text_only", "penn_treebank", split="test")
|
38 |
-
testdata = "\n\n".join(testdata["sentence"])
|
39 |
-
elif dataset_name == "c4":
|
40 |
-
testdata = load_dataset(
|
41 |
-
"allenai/c4",
|
42 |
-
"allenai--c4",
|
43 |
-
data_files={"validation": "en/c4-validation.00000-of-00008.json.gz"},
|
44 |
-
split="validation",
|
45 |
-
)
|
46 |
-
testdata = " ".join(testdata[:1100]["text"])
|
47 |
-
|
48 |
-
else:
|
49 |
-
raise ValueError("invalid dataset name (wikitext, ptb, c4 are allowed)")
|
50 |
-
return testdata
|
51 |
-
|
52 |
-
|
53 |
-
def main(
|
54 |
-
datasets: str = "wikitext,ptb,c4",
|
55 |
-
*,
|
56 |
-
# compilation fails as it does not support torch.complex64 for RoPE
|
57 |
-
# compile: bool = False,
|
58 |
-
accelerator: str = "auto",
|
59 |
-
lora_path: Path = Path("out/lora/alpaca/lit-llama-lora-finetuned.pth"),
|
60 |
-
checkpoint_path: Path = Path("checkpoints/lit-llama/7B/lit-llama.pth"),
|
61 |
-
tokenizer_path: Path = Path("checkpoints/lit-llama/tokenizer.model"),
|
62 |
-
dtype: str = "float32",
|
63 |
-
quantize: Optional[str] = None,
|
64 |
-
) -> None:
|
65 |
-
"""Generates text samples based on a pre-trained LLaMA model and tokenizer
|
66 |
-
finetuned with LoRA.
|
67 |
-
|
68 |
-
Args:
|
69 |
-
datasets: The datasets to use as a comma separated string
|
70 |
-
# compile: Whether to compile the model.
|
71 |
-
accelerator: The hardware to run on. Possible choices are:
|
72 |
-
``"cpu"``, ``"cuda"``, ``"mps"``, ``"gpu"``, ``"tpu"``, ``"auto"``.
|
73 |
-
lora_path: Path to the checkpoint with trained LoRA weights, which are the output of
|
74 |
-
`finetune_lora.py`.
|
75 |
-
checkpoint_path: The checkpoint path to load.
|
76 |
-
tokenizer_path: The tokenizer path to load.
|
77 |
-
dtype: The tensor dtype for choosing the floating-point precision
|
78 |
-
quantize: Whether to quantize the model and using which method:
|
79 |
-
``"llm.int8"``: LLM.int8() mode,
|
80 |
-
``"gptq.int4"``: GPTQ 4-bit mode.
|
81 |
-
"""
|
82 |
-
assert lora_path.is_file()
|
83 |
-
assert checkpoint_path.is_file()
|
84 |
-
assert tokenizer_path.is_file()
|
85 |
-
|
86 |
-
if quantize is not None:
|
87 |
-
raise NotImplementedError("Quantization in LoRA is not supported yet")
|
88 |
-
|
89 |
-
fabric = L.Fabric(accelerator=accelerator, devices=1)
|
90 |
-
|
91 |
-
dt = getattr(torch, dtype, None)
|
92 |
-
if not isinstance(dt, torch.dtype):
|
93 |
-
raise ValueError(f"{dtype} is not a valid dtype.")
|
94 |
-
dtype = dt
|
95 |
-
|
96 |
-
print("Loading model ...", file=sys.stderr)
|
97 |
-
t0 = time.time()
|
98 |
-
|
99 |
-
with lazy_load(checkpoint_path) as pretrained_checkpoint, lazy_load(lora_path) as lora_checkpoint:
|
100 |
-
name = llama_model_lookup(pretrained_checkpoint)
|
101 |
-
|
102 |
-
with EmptyInitOnDevice(
|
103 |
-
device=fabric.device, dtype=dtype, quantization_mode=quantize
|
104 |
-
), lora(r=lora_r, alpha=lora_alpha, dropout=lora_dropout, enabled=True):
|
105 |
-
model = LLaMA.from_name(name)
|
106 |
-
|
107 |
-
# 1. Load the pretrained weights
|
108 |
-
model.load_state_dict(pretrained_checkpoint, strict=False)
|
109 |
-
# 2. Load the fine-tuned lora weights
|
110 |
-
model.load_state_dict(lora_checkpoint, strict=False)
|
111 |
-
|
112 |
-
print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr)
|
113 |
-
|
114 |
-
model.eval()
|
115 |
-
|
116 |
-
# if compile:
|
117 |
-
# model = torch.compile(model)
|
118 |
-
|
119 |
-
total_toks = 0
|
120 |
-
model = fabric.setup_module(model)
|
121 |
-
|
122 |
-
tokenizer = Tokenizer(tokenizer_path)
|
123 |
-
|
124 |
-
for dsname in datasets.split(","):
|
125 |
-
test_string = load_eval_data(dsname)
|
126 |
-
|
127 |
-
if instruction_tuning:
|
128 |
-
sample = {"instruction": test_string, "input": input}
|
129 |
-
test_string = generate_prompt(sample)
|
130 |
-
|
131 |
-
encoded_text = tokenizer.encode(
|
132 |
-
test_string, bos=True, eos=False, device=fabric.device
|
133 |
-
)
|
134 |
-
encoded_text = encoded_text[
|
135 |
-
None, : 256 * model.config.block_size
|
136 |
-
] # add batch dimension, trim like gptq implementation
|
137 |
-
t0 = time.perf_counter()
|
138 |
-
|
139 |
-
nlls = 0
|
140 |
-
toks = 0
|
141 |
-
with torch.inference_mode():
|
142 |
-
block_size = 2048 # this is for compat with gptq, and indeed we get much worse beyond this (https://github.com/facebookresearch/llama/blob/57b0eb62de0636e75af471e49e2f1862d908d9d8/llama/model.py#L30)
|
143 |
-
for i in tqdm.tqdm(range(0, encoded_text.shape[1], block_size)):
|
144 |
-
inp = encoded_text[:, i : i + block_size]
|
145 |
-
logits = model(inp)[0]
|
146 |
-
nll = torch.nn.functional.cross_entropy(
|
147 |
-
logits[:-1], inp[0, 1:].to(dtype=torch.long), reduction="sum"
|
148 |
-
)
|
149 |
-
toks += inp.size(1) - 1
|
150 |
-
nlls += nll.item()
|
151 |
-
|
152 |
-
print(encoded_text.shape, logits.shape)
|
153 |
-
ppl = math.exp(nlls / toks)
|
154 |
-
print(f"Perplexity on {dsname}: {ppl:.2f}")
|
155 |
-
total_toks += toks
|
156 |
-
|
157 |
-
t = time.perf_counter() - t0
|
158 |
-
print(
|
159 |
-
f"\n\nTime for inference: {t:.02f} sec total, {total_toks / t:.02f} tokens/sec",
|
160 |
-
file=sys.stderr,
|
161 |
-
)
|
162 |
-
print(
|
163 |
-
f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB",
|
164 |
-
file=sys.stderr,
|
165 |
-
)
|
166 |
-
|
167 |
-
|
168 |
-
if __name__ == "__main__":
|
169 |
-
from jsonargparse import CLI
|
170 |
-
|
171 |
-
torch.set_float32_matmul_precision("high")
|
172 |
-
CLI(main)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lit-llama/finetune/adapter.py
DELETED
@@ -1,262 +0,0 @@
|
|
1 |
-
"""
|
2 |
-
Instruction-tuning with LLaMA-Adapter on the Alpaca dataset following the paper
|
3 |
-
|
4 |
-
LLaMA-Adapter: Efficient Fine-tuning of Language Models with Zero-init Attention
|
5 |
-
https://arxiv.org/abs/2303.16199
|
6 |
-
|
7 |
-
This script runs on a single GPU by default. You can adjust the `micro_batch_size` to fit your GPU memory.
|
8 |
-
You can finetune within 1 hour as done in the original paper using DeepSpeed Zero-2 on 8 A100 GPUs by setting the
|
9 |
-
devices variable to `devices = 8` and `micro_batch_size = 8` (or higher).
|
10 |
-
|
11 |
-
Note: If you run into a CUDA error "Expected is_sm80 to be true, but got false", uncomment the line
|
12 |
-
`torch.backends.cuda.enable_flash_sdp(False)` in the script below (see https://github.com/Lightning-AI/lit-llama/issues/101).
|
13 |
-
"""
|
14 |
-
import os
|
15 |
-
import sys
|
16 |
-
import time
|
17 |
-
from pathlib import Path
|
18 |
-
import shutil
|
19 |
-
|
20 |
-
import lightning as L
|
21 |
-
import numpy as np
|
22 |
-
import torch
|
23 |
-
|
24 |
-
# support running without installing as a package
|
25 |
-
wd = Path(__file__).parent.parent.resolve()
|
26 |
-
sys.path.append(str(wd))
|
27 |
-
|
28 |
-
from generate import generate
|
29 |
-
from lit_llama.adapter import LLaMA, LLaMAConfig, mark_only_adapter_as_trainable, adapter_state_from_state_dict
|
30 |
-
from lit_llama.tokenizer import Tokenizer
|
31 |
-
from scripts.prepare_alpaca import generate_prompt
|
32 |
-
from lightning.fabric.strategies import DeepSpeedStrategy
|
33 |
-
|
34 |
-
|
35 |
-
instruction_tuning = True
|
36 |
-
eval_interval = 600
|
37 |
-
save_interval = 1000
|
38 |
-
eval_iters = 100
|
39 |
-
log_interval = 1
|
40 |
-
devices = 1
|
41 |
-
|
42 |
-
# Hyperparameters
|
43 |
-
learning_rate = 9e-3
|
44 |
-
batch_size = 64 / devices
|
45 |
-
micro_batch_size = 4
|
46 |
-
gradient_accumulation_iters = batch_size // micro_batch_size
|
47 |
-
assert gradient_accumulation_iters > 0
|
48 |
-
epoch_size = 50000 # train dataset size
|
49 |
-
num_epochs = 5
|
50 |
-
max_iters = num_epochs * (epoch_size // micro_batch_size) // devices
|
51 |
-
weight_decay = 0.02
|
52 |
-
max_seq_length = 256 # see scripts/prepare_alpaca.py
|
53 |
-
warmup_iters = 2 * (epoch_size // micro_batch_size) // devices # 2 epochs
|
54 |
-
|
55 |
-
ds_config = {
|
56 |
-
"train_micro_batch_size_per_gpu": micro_batch_size,
|
57 |
-
"gradient_accumulation_steps": gradient_accumulation_iters,
|
58 |
-
"zero_optimization": {"stage": 2},
|
59 |
-
}
|
60 |
-
|
61 |
-
|
62 |
-
def main(
|
63 |
-
data_dir: str = "data/alpaca",
|
64 |
-
pretrained_path: str = "checkpoints/lit-llama/7B/lit-llama.pth",
|
65 |
-
out_dir: str = "out/adapter/alpaca",
|
66 |
-
):
|
67 |
-
|
68 |
-
fabric = L.Fabric(
|
69 |
-
accelerator="cuda",
|
70 |
-
devices=devices,
|
71 |
-
strategy=(DeepSpeedStrategy(config=ds_config) if devices > 1 else "auto"),
|
72 |
-
precision="bf16-true",
|
73 |
-
)
|
74 |
-
fabric.launch()
|
75 |
-
fabric.seed_everything(1337 + fabric.global_rank)
|
76 |
-
|
77 |
-
if fabric.global_rank == 0:
|
78 |
-
os.makedirs(out_dir, exist_ok=True)
|
79 |
-
|
80 |
-
train_data, val_data = load_datasets(data_dir=data_dir)
|
81 |
-
|
82 |
-
config = LLaMAConfig(block_size=max_seq_length)
|
83 |
-
|
84 |
-
if not os.path.isfile(pretrained_path):
|
85 |
-
raise FileNotFoundError(
|
86 |
-
f"Can't find the pretrained weights at {pretrained_path}."
|
87 |
-
" Please follow the instructions in the README to download them."
|
88 |
-
)
|
89 |
-
checkpoint = torch.load(pretrained_path)
|
90 |
-
|
91 |
-
with fabric.init_module():
|
92 |
-
model = LLaMA(config)
|
93 |
-
# strict=False because missing keys due to adapter weights not containted in state dict
|
94 |
-
model.load_state_dict(checkpoint, strict=False)
|
95 |
-
|
96 |
-
mark_only_adapter_as_trainable(model)
|
97 |
-
|
98 |
-
num_params = sum([p.numel() for p in model.parameters() if p.requires_grad])
|
99 |
-
print(f"Number of trainable parameters: {num_params}")
|
100 |
-
|
101 |
-
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
|
102 |
-
model, optimizer = fabric.setup(model, optimizer)
|
103 |
-
train(fabric, model, optimizer, train_data, val_data, out_dir)
|
104 |
-
|
105 |
-
# Save the final checkpoint at the end of training
|
106 |
-
save_model_checkpoint(fabric, model, os.path.join(out_dir, "lit-llama-adapter-finetuned.pth"))
|
107 |
-
|
108 |
-
|
109 |
-
def train(
|
110 |
-
fabric: L.Fabric,
|
111 |
-
model: torch.nn.Module,
|
112 |
-
optimizer: torch.optim.Optimizer,
|
113 |
-
train_data: np.ndarray,
|
114 |
-
val_data: np.ndarray,
|
115 |
-
out_dir: str,
|
116 |
-
) -> None:
|
117 |
-
"""The training loop.
|
118 |
-
|
119 |
-
Loosely based on the nanoGPT implementation: https://github.com/karpathy/nanoGPT.
|
120 |
-
"""
|
121 |
-
step_count = 0
|
122 |
-
|
123 |
-
for iter_num in range(max_iters):
|
124 |
-
|
125 |
-
if step_count <= warmup_iters:
|
126 |
-
# linear warmup
|
127 |
-
lr = learning_rate * step_count / warmup_iters
|
128 |
-
for param_group in optimizer.param_groups:
|
129 |
-
param_group['lr'] = lr
|
130 |
-
|
131 |
-
t0 = time.time()
|
132 |
-
|
133 |
-
input_ids, targets = get_batch(fabric, train_data)
|
134 |
-
with fabric.no_backward_sync(model, enabled=((iter_num + 1) % gradient_accumulation_iters != 0)):
|
135 |
-
logits = model(input_ids)
|
136 |
-
loss = loss_fn(logits, targets)
|
137 |
-
fabric.backward(loss / gradient_accumulation_iters)
|
138 |
-
|
139 |
-
if (iter_num + 1) % gradient_accumulation_iters == 0:
|
140 |
-
optimizer.step()
|
141 |
-
optimizer.zero_grad()
|
142 |
-
step_count += 1
|
143 |
-
|
144 |
-
if step_count % eval_interval == 0:
|
145 |
-
val_loss = validate(fabric, model, val_data)
|
146 |
-
fabric.print(f"step {iter_num}: val loss {val_loss:.4f}")
|
147 |
-
fabric.barrier()
|
148 |
-
|
149 |
-
if step_count % save_interval == 0:
|
150 |
-
print(f"Saving adapter weights to {out_dir}")
|
151 |
-
# TODO: Provide a function/script to merge the adapter weights with pretrained weights
|
152 |
-
save_model_checkpoint(fabric, model, os.path.join(out_dir, f"iter-{iter_num:06d}.pth"))
|
153 |
-
|
154 |
-
dt = time.time() - t0
|
155 |
-
if iter_num % log_interval == 0:
|
156 |
-
fabric.print(f"iter {iter_num}: loss {loss.item():.4f}, time: {dt*1000:.2f}ms")
|
157 |
-
|
158 |
-
|
159 |
-
def generate_response(model, instruction, input=""):
|
160 |
-
tokenizer = Tokenizer("checkpoints/lit-llama/tokenizer.model")
|
161 |
-
sample = {"instruction": instruction, "input": input}
|
162 |
-
prompt = instruction
|
163 |
-
if instruction_tuning:
|
164 |
-
prompt = generate_prompt(sample)
|
165 |
-
encoded = tokenizer.encode(prompt, bos=True, eos=False, device=model.device)
|
166 |
-
|
167 |
-
output = generate(
|
168 |
-
model,
|
169 |
-
idx=encoded,
|
170 |
-
max_seq_length=max_seq_length,
|
171 |
-
max_new_tokens=100,
|
172 |
-
temperature=0.8,
|
173 |
-
)
|
174 |
-
output = tokenizer.decode(output)
|
175 |
-
return output # output.split("### Response:")[1].strip()
|
176 |
-
|
177 |
-
|
178 |
-
@torch.no_grad()
|
179 |
-
def validate(fabric: L.Fabric, model: torch.nn.Module, val_data: np.ndarray) -> torch.Tensor:
|
180 |
-
fabric.print("Validating ...")
|
181 |
-
model.eval()
|
182 |
-
losses = torch.zeros(eval_iters)
|
183 |
-
for k in range(eval_iters):
|
184 |
-
input_ids, targets = get_batch(fabric, val_data)
|
185 |
-
logits = model(input_ids)
|
186 |
-
loss = loss_fn(logits, targets)
|
187 |
-
losses[k] = loss.item()
|
188 |
-
val_loss = losses.mean()
|
189 |
-
|
190 |
-
# produce an example:
|
191 |
-
instruction = "Recommend a movie for me to watch during the weekend and explain the reason."
|
192 |
-
output = generate_response(model, instruction)
|
193 |
-
fabric.print(instruction)
|
194 |
-
fabric.print(output)
|
195 |
-
|
196 |
-
model.train()
|
197 |
-
return val_loss.item()
|
198 |
-
|
199 |
-
def loss_fn(logits, targets):
|
200 |
-
# shift the targets such that output n predicts token n+1
|
201 |
-
logits = logits[..., :-1, :].contiguous()
|
202 |
-
targets = targets[..., 1:].contiguous()
|
203 |
-
loss = torch.nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
|
204 |
-
return loss
|
205 |
-
|
206 |
-
|
207 |
-
def get_batch(fabric: L.Fabric, data: list):
|
208 |
-
ix = torch.randint(len(data), (micro_batch_size,))
|
209 |
-
|
210 |
-
input_ids = [data[i]["input_ids"].type(torch.int64) for i in ix]
|
211 |
-
labels = [data[i]["labels"].type(torch.int64) for i in ix]
|
212 |
-
|
213 |
-
max_len = max(len(s) for s in input_ids)
|
214 |
-
|
215 |
-
def pad_right(x, pad_id):
|
216 |
-
# pad right based on the longest sequence
|
217 |
-
n = max_len - len(x)
|
218 |
-
return torch.cat((x, torch.full((n,), pad_id, dtype=x.dtype)))
|
219 |
-
|
220 |
-
x = torch.stack([pad_right(x, pad_id=0) for x in input_ids])
|
221 |
-
y = torch.stack([pad_right(x, pad_id=-1) for x in labels])
|
222 |
-
x, y = fabric.to_device((x.pin_memory(), y.pin_memory()))
|
223 |
-
return x, y
|
224 |
-
|
225 |
-
|
226 |
-
def load_datasets(data_dir):
|
227 |
-
train_data = torch.load(os.path.join(data_dir, "train.pt"))
|
228 |
-
val_data = torch.load(os.path.join(data_dir, "test.pt"))
|
229 |
-
return train_data, val_data
|
230 |
-
|
231 |
-
|
232 |
-
def save_model_checkpoint(fabric, model, file_path):
|
233 |
-
file_path = Path(file_path)
|
234 |
-
|
235 |
-
if isinstance(fabric.strategy, DeepSpeedStrategy):
|
236 |
-
from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
|
237 |
-
|
238 |
-
tmp_path = file_path.with_suffix(".tmp")
|
239 |
-
fabric.save(tmp_path, {"model": model})
|
240 |
-
fabric.barrier()
|
241 |
-
if fabric.global_rank == 0:
|
242 |
-
# Create a consolidated checkpoint with the same name next to the deepspeed checkpoint
|
243 |
-
# and only keep the adapter weights
|
244 |
-
state_dict = get_fp32_state_dict_from_zero_checkpoint(tmp_path)
|
245 |
-
state_dict = adapter_state_from_state_dict(state_dict)
|
246 |
-
torch.save(state_dict, file_path)
|
247 |
-
shutil.rmtree(tmp_path)
|
248 |
-
else:
|
249 |
-
state_dict = adapter_state_from_state_dict(model.state_dict())
|
250 |
-
if fabric.global_rank == 0:
|
251 |
-
torch.save(state_dict, file_path)
|
252 |
-
fabric.barrier()
|
253 |
-
|
254 |
-
|
255 |
-
if __name__ == "__main__":
|
256 |
-
# Uncomment this line if you see an error: "Expected is_sm80 to be true, but got false"
|
257 |
-
# torch.backends.cuda.enable_flash_sdp(False)
|
258 |
-
torch.set_float32_matmul_precision("high")
|
259 |
-
|
260 |
-
from jsonargparse.cli import CLI
|
261 |
-
|
262 |
-
CLI(main)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lit-llama/finetune/adapter_v2.py
DELETED
@@ -1,266 +0,0 @@
|
|
1 |
-
"""
|
2 |
-
Instruction-tuning with LLaMA-Adapter v2 on the Alpaca dataset following the paper
|
3 |
-
|
4 |
-
LLaMA-Adapter V2: Parameter-Efficient Visual Instruction Model
|
5 |
-
https://arxiv.org/abs/2304.15010
|
6 |
-
|
7 |
-
This script runs on a single GPU by default. You can adjust the `micro_batch_size` to fit your GPU memory.
|
8 |
-
You can finetune within 1 hour as done in the original paper using DeepSpeed Zero-2 on 8 A100 GPUs by setting the
|
9 |
-
devices variable to `devices = 8` and `micro_batch_size = 8` (or higher).
|
10 |
-
|
11 |
-
Note: If you run into a CUDA error "Expected is_sm80 to be true, but got false", uncomment the line
|
12 |
-
`torch.backends.cuda.enable_flash_sdp(False)` in the script below (see https://github.com/Lightning-AI/lit-llama/issues/101).
|
13 |
-
"""
|
14 |
-
import os
|
15 |
-
import sys
|
16 |
-
import time
|
17 |
-
from pathlib import Path
|
18 |
-
import shutil
|
19 |
-
|
20 |
-
import lightning as L
|
21 |
-
import numpy as np
|
22 |
-
import torch
|
23 |
-
import torch.nn as nn
|
24 |
-
|
25 |
-
# support running without installing as a package
|
26 |
-
wd = Path(__file__).parent.parent.resolve()
|
27 |
-
sys.path.append(str(wd))
|
28 |
-
|
29 |
-
from generate import generate
|
30 |
-
from lit_llama.adapter import LLaMA, LLaMAConfig
|
31 |
-
from lit_llama.adapter_v2 import (
|
32 |
-
mark_only_adapter_v2_as_trainable,
|
33 |
-
add_adapter_v2_parameters_to_linear_layers,
|
34 |
-
adapter_v2_state_from_state_dict
|
35 |
-
)
|
36 |
-
from lit_llama.tokenizer import Tokenizer
|
37 |
-
from scripts.prepare_alpaca import generate_prompt
|
38 |
-
from lightning.fabric.strategies import DeepSpeedStrategy
|
39 |
-
|
40 |
-
|
41 |
-
eval_interval = 600
|
42 |
-
save_interval = 1000
|
43 |
-
eval_iters = 100
|
44 |
-
log_interval = 1
|
45 |
-
devices = 1
|
46 |
-
|
47 |
-
# Hyperparameters
|
48 |
-
learning_rate = 9e-3
|
49 |
-
batch_size = 64 / devices
|
50 |
-
micro_batch_size = 4
|
51 |
-
gradient_accumulation_iters = batch_size // micro_batch_size
|
52 |
-
assert gradient_accumulation_iters > 0
|
53 |
-
epoch_size = 50000 # train dataset size
|
54 |
-
num_epochs = 5
|
55 |
-
max_iters = num_epochs * (epoch_size // micro_batch_size) // devices
|
56 |
-
weight_decay = 0.02
|
57 |
-
max_seq_length = 256 # see scripts/prepare_alpaca.py
|
58 |
-
warmup_iters = 2 * (epoch_size // micro_batch_size) // devices # 2 epoch
|
59 |
-
|
60 |
-
ds_config = {
|
61 |
-
"train_micro_batch_size_per_gpu": micro_batch_size,
|
62 |
-
"gradient_accumulation_steps": gradient_accumulation_iters,
|
63 |
-
"zero_optimization": {"stage": 2},
|
64 |
-
}
|
65 |
-
|
66 |
-
|
67 |
-
def main(
|
68 |
-
data_dir: str = "data/alpaca",
|
69 |
-
pretrained_path: str = "checkpoints/lit-llama/7B/lit-llama.pth",
|
70 |
-
out_dir: str = "out/adapter_v2/alpaca",
|
71 |
-
):
|
72 |
-
|
73 |
-
fabric = L.Fabric(
|
74 |
-
accelerator="cuda",
|
75 |
-
devices=1,
|
76 |
-
strategy=(DeepSpeedStrategy(config=ds_config) if devices > 1 else "auto"),
|
77 |
-
precision="bf16-true",
|
78 |
-
)
|
79 |
-
fabric.launch()
|
80 |
-
fabric.seed_everything(1337 + fabric.global_rank)
|
81 |
-
|
82 |
-
if fabric.global_rank == 0:
|
83 |
-
os.makedirs(out_dir, exist_ok=True)
|
84 |
-
|
85 |
-
train_data, val_data = load_datasets(data_dir=data_dir)
|
86 |
-
|
87 |
-
config = LLaMAConfig(block_size=max_seq_length)
|
88 |
-
|
89 |
-
if not os.path.isfile(pretrained_path):
|
90 |
-
raise FileNotFoundError(
|
91 |
-
f"Can't find the pretrained weights at {pretrained_path}."
|
92 |
-
" Please follow the instructions in the README to download them."
|
93 |
-
)
|
94 |
-
checkpoint = torch.load(pretrained_path)
|
95 |
-
|
96 |
-
with fabric.init_module():
|
97 |
-
model = LLaMA(config)
|
98 |
-
# strict=False because missing keys due to adapter weights not contained in state dict
|
99 |
-
model.load_state_dict(checkpoint, strict=False)
|
100 |
-
|
101 |
-
add_adapter_v2_parameters_to_linear_layers(model)
|
102 |
-
mark_only_adapter_v2_as_trainable(model)
|
103 |
-
|
104 |
-
num_params = sum([p.numel() for p in model.parameters() if p.requires_grad])
|
105 |
-
print(f"Number of trainable parameters: {num_params}")
|
106 |
-
|
107 |
-
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
|
108 |
-
model, optimizer = fabric.setup(model, optimizer)
|
109 |
-
train(fabric, model, optimizer, train_data, val_data, out_dir)
|
110 |
-
|
111 |
-
# Save the final checkpoint at the end of training
|
112 |
-
save_model_checkpoint(fabric, model, os.path.join(out_dir, "lit-llama-adapter-finetuned.pth"))
|
113 |
-
|
114 |
-
|
115 |
-
def train(
|
116 |
-
fabric: L.Fabric,
|
117 |
-
model: torch.nn.Module,
|
118 |
-
optimizer: torch.optim.Optimizer,
|
119 |
-
train_data: np.ndarray,
|
120 |
-
val_data: np.ndarray,
|
121 |
-
out_dir: str,
|
122 |
-
) -> None:
|
123 |
-
"""The training loop.
|
124 |
-
|
125 |
-
Loosely based on the nanoGPT implementation: https://github.com/karpathy/nanoGPT.
|
126 |
-
"""
|
127 |
-
step_count = 0
|
128 |
-
|
129 |
-
for iter_num in range(max_iters):
|
130 |
-
|
131 |
-
if step_count <= warmup_iters:
|
132 |
-
# linear warmup
|
133 |
-
lr = learning_rate * step_count / warmup_iters
|
134 |
-
for param_group in optimizer.param_groups:
|
135 |
-
param_group['lr'] = lr
|
136 |
-
|
137 |
-
t0 = time.time()
|
138 |
-
|
139 |
-
input_ids, targets = get_batch(fabric, train_data)
|
140 |
-
with fabric.no_backward_sync(model, enabled=((iter_num + 1) % gradient_accumulation_iters != 0)):
|
141 |
-
logits = model(input_ids)
|
142 |
-
loss = loss_fn(logits, targets)
|
143 |
-
fabric.backward(loss / gradient_accumulation_iters)
|
144 |
-
|
145 |
-
if (iter_num + 1) % gradient_accumulation_iters == 0:
|
146 |
-
optimizer.step()
|
147 |
-
optimizer.zero_grad()
|
148 |
-
step_count += 1
|
149 |
-
|
150 |
-
if step_count % eval_interval == 0:
|
151 |
-
val_loss = validate(fabric, model, val_data)
|
152 |
-
fabric.print(f"step {iter_num}: val loss {val_loss:.4f}")
|
153 |
-
fabric.barrier()
|
154 |
-
|
155 |
-
if step_count % save_interval == 0:
|
156 |
-
print(f"Saving adapter weights to {out_dir}")
|
157 |
-
# TODO: Provide a function/script to merge the adapter weights with pretrained weights
|
158 |
-
save_model_checkpoint(fabric, model, os.path.join(out_dir, f"iter-{iter_num:06d}.pth"))
|
159 |
-
|
160 |
-
dt = time.time() - t0
|
161 |
-
if iter_num % log_interval == 0:
|
162 |
-
fabric.print(f"iter {iter_num}: loss {loss.item():.4f}, time: {dt*1000:.2f}ms")
|
163 |
-
|
164 |
-
|
165 |
-
def generate_response(model, instruction, input=""):
|
166 |
-
tokenizer = Tokenizer("checkpoints/lit-llama/tokenizer.model")
|
167 |
-
sample = {"instruction": instruction, "input": input}
|
168 |
-
prompt = generate_prompt(sample)
|
169 |
-
encoded = tokenizer.encode(prompt, bos=True, eos=False, device=model.device)
|
170 |
-
|
171 |
-
output = generate(
|
172 |
-
model,
|
173 |
-
idx=encoded,
|
174 |
-
max_seq_length=max_seq_length,
|
175 |
-
max_new_tokens=100,
|
176 |
-
temperature=0.8,
|
177 |
-
)
|
178 |
-
output = tokenizer.decode(output)
|
179 |
-
return output # output.split("### Response:")[1].strip()
|
180 |
-
|
181 |
-
|
182 |
-
@torch.no_grad()
|
183 |
-
def validate(fabric: L.Fabric, model: torch.nn.Module, val_data: np.ndarray) -> torch.Tensor:
|
184 |
-
fabric.print("Validating ...")
|
185 |
-
model.eval()
|
186 |
-
losses = torch.zeros(eval_iters)
|
187 |
-
for k in range(eval_iters):
|
188 |
-
input_ids, targets = get_batch(fabric, val_data)
|
189 |
-
logits = model(input_ids)
|
190 |
-
loss = loss_fn(logits, targets)
|
191 |
-
losses[k] = loss.item()
|
192 |
-
val_loss = losses.mean()
|
193 |
-
|
194 |
-
# produce an example:
|
195 |
-
instruction = "Recommend a movie for me to watch during the weekend and explain the reason."
|
196 |
-
output = generate_response(model, instruction)
|
197 |
-
fabric.print(instruction)
|
198 |
-
fabric.print(output)
|
199 |
-
|
200 |
-
model.train()
|
201 |
-
return val_loss.item()
|
202 |
-
|
203 |
-
def loss_fn(logits, targets):
|
204 |
-
# shift the targets such that output n predicts token n+1
|
205 |
-
logits = logits[..., :-1, :].contiguous()
|
206 |
-
targets = targets[..., 1:].contiguous()
|
207 |
-
loss = torch.nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
|
208 |
-
return loss
|
209 |
-
|
210 |
-
|
211 |
-
def get_batch(fabric: L.Fabric, data: list):
|
212 |
-
ix = torch.randint(len(data), (micro_batch_size,))
|
213 |
-
|
214 |
-
input_ids = [data[i]["input_ids"].type(torch.int64) for i in ix]
|
215 |
-
labels = [data[i]["labels"].type(torch.int64) for i in ix]
|
216 |
-
|
217 |
-
max_len = max(len(s) for s in input_ids)
|
218 |
-
|
219 |
-
def pad_right(x, pad_id):
|
220 |
-
# pad right based on the longest sequence
|
221 |
-
n = max_len - len(x)
|
222 |
-
return torch.cat((x, torch.full((n,), pad_id, dtype=x.dtype)))
|
223 |
-
|
224 |
-
x = torch.stack([pad_right(x, pad_id=0) for x in input_ids])
|
225 |
-
y = torch.stack([pad_right(x, pad_id=-1) for x in labels])
|
226 |
-
x, y = fabric.to_device((x.pin_memory(), y.pin_memory()))
|
227 |
-
return x, y
|
228 |
-
|
229 |
-
|
230 |
-
def load_datasets(data_dir):
|
231 |
-
train_data = torch.load(os.path.join(data_dir, "train.pt"))
|
232 |
-
val_data = torch.load(os.path.join(data_dir, "test.pt"))
|
233 |
-
return train_data, val_data
|
234 |
-
|
235 |
-
|
236 |
-
def save_model_checkpoint(fabric, model, file_path):
|
237 |
-
file_path = Path(file_path)
|
238 |
-
|
239 |
-
if isinstance(fabric.strategy, DeepSpeedStrategy):
|
240 |
-
from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
|
241 |
-
|
242 |
-
tmp_path = file_path.with_suffix(".tmp")
|
243 |
-
fabric.save(tmp_path, {"model": model})
|
244 |
-
fabric.barrier()
|
245 |
-
if fabric.global_rank == 0:
|
246 |
-
# Create a consolidated checkpoint with the same name next to the deepspeed checkpoint
|
247 |
-
# and only keep the adapter weights
|
248 |
-
state_dict = get_fp32_state_dict_from_zero_checkpoint(tmp_path)
|
249 |
-
state_dict = adapter_v2_state_from_state_dict(state_dict)
|
250 |
-
torch.save(state_dict, file_path)
|
251 |
-
shutil.rmtree(tmp_path)
|
252 |
-
else:
|
253 |
-
state_dict = adapter_v2_state_from_state_dict(model.state_dict())
|
254 |
-
if fabric.global_rank == 0:
|
255 |
-
torch.save(state_dict, file_path)
|
256 |
-
fabric.barrier()
|
257 |
-
|
258 |
-
|
259 |
-
if __name__ == "__main__":
|
260 |
-
# Uncomment this line if you see an error: "Expected is_sm80 to be true, but got false"
|
261 |
-
# torch.backends.cuda.enable_flash_sdp(False)
|
262 |
-
torch.set_float32_matmul_precision("high")
|
263 |
-
|
264 |
-
from jsonargparse.cli import CLI
|
265 |
-
|
266 |
-
CLI(main)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lit-llama/finetune/full.py
DELETED
@@ -1,224 +0,0 @@
|
|
1 |
-
"""
|
2 |
-
Instruction-tuning on the Alpaca dataset using a regular finetuning procedure (updating all layers).
|
3 |
-
|
4 |
-
Note: If you run into a CUDA error "Expected is_sm80 to be true, but got false", uncomment the line
|
5 |
-
`torch.backends.cuda.enable_flash_sdp(False)` in the script below (see https://github.com/Lightning-AI/lit-llama/issues/101).
|
6 |
-
"""
|
7 |
-
import sys
|
8 |
-
from pathlib import Path
|
9 |
-
import os
|
10 |
-
import time
|
11 |
-
from functools import partial
|
12 |
-
|
13 |
-
import lightning as L
|
14 |
-
from lightning.fabric.strategies import FSDPStrategy
|
15 |
-
import numpy as np
|
16 |
-
import torch
|
17 |
-
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
|
18 |
-
|
19 |
-
# support running without installing as a package
|
20 |
-
wd = Path(__file__).parent.parent.resolve()
|
21 |
-
sys.path.append(str(wd))
|
22 |
-
|
23 |
-
from generate import generate
|
24 |
-
from lit_llama.model import Block, LLaMA, LLaMAConfig
|
25 |
-
from lit_llama.tokenizer import Tokenizer
|
26 |
-
from lit_llama.utils import save_model_checkpoint
|
27 |
-
from scripts.prepare_alpaca import generate_prompt
|
28 |
-
|
29 |
-
|
30 |
-
instruction_tuning = True
|
31 |
-
eval_interval = 1000
|
32 |
-
save_interval = 1000
|
33 |
-
eval_iters = 100
|
34 |
-
log_interval = 100
|
35 |
-
devices = 4
|
36 |
-
|
37 |
-
# Hyperparameters
|
38 |
-
learning_rate = 3e-5
|
39 |
-
batch_size = 128 / devices
|
40 |
-
micro_batch_size = 4
|
41 |
-
gradient_accumulation_iters = batch_size // micro_batch_size
|
42 |
-
assert gradient_accumulation_iters > 0
|
43 |
-
epoch_size = 50000 # train dataset size
|
44 |
-
num_epochs = 5
|
45 |
-
max_iters = num_epochs * (epoch_size // micro_batch_size) // devices
|
46 |
-
weight_decay = 0.0
|
47 |
-
block_size = 512
|
48 |
-
warmup_iters = 100
|
49 |
-
|
50 |
-
|
51 |
-
def main(
|
52 |
-
data_dir: str = "data/alpaca",
|
53 |
-
pretrained_path: str = "checkpoints/lit-llama/7B/lit-llama.pth",
|
54 |
-
out_dir: str = "out/full/alpaca",
|
55 |
-
):
|
56 |
-
|
57 |
-
auto_wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={Block})
|
58 |
-
strategy = FSDPStrategy(auto_wrap_policy=auto_wrap_policy, activation_checkpointing=Block, limit_all_gathers=True)
|
59 |
-
|
60 |
-
fabric = L.Fabric(accelerator="cuda", devices=devices, precision="bf16-mixed", strategy=strategy)
|
61 |
-
fabric.launch()
|
62 |
-
fabric.seed_everything(1337 + fabric.global_rank)
|
63 |
-
|
64 |
-
if fabric.global_rank == 0:
|
65 |
-
os.makedirs(out_dir, exist_ok=True)
|
66 |
-
|
67 |
-
train_data, val_data = load_datasets(data_dir=data_dir)
|
68 |
-
|
69 |
-
config = LLaMAConfig.from_name("7B")
|
70 |
-
config.block_size = block_size
|
71 |
-
|
72 |
-
checkpoint = torch.load(pretrained_path)
|
73 |
-
|
74 |
-
with fabric.device:
|
75 |
-
torch.set_default_tensor_type(torch.HalfTensor)
|
76 |
-
model = LLaMA(config).bfloat16()
|
77 |
-
torch.set_default_tensor_type(torch.FloatTensor)
|
78 |
-
model.load_state_dict(checkpoint, strict=False)
|
79 |
-
|
80 |
-
model = fabric.setup_module(model)
|
81 |
-
|
82 |
-
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, foreach=False)
|
83 |
-
optimizer = fabric.setup_optimizers(optimizer)
|
84 |
-
|
85 |
-
train(fabric, model, optimizer, train_data, val_data, out_dir)
|
86 |
-
|
87 |
-
# Save the final checkpoint at the end of training
|
88 |
-
save_model_checkpoint(fabric, model, os.path.join(out_dir, "lit-llama-full-finetuned.pth"))
|
89 |
-
|
90 |
-
|
91 |
-
def train(
|
92 |
-
fabric: L.Fabric,
|
93 |
-
model: torch.nn.Module,
|
94 |
-
optimizer: torch.optim.Optimizer,
|
95 |
-
train_data: np.ndarray,
|
96 |
-
val_data: np.ndarray,
|
97 |
-
out_dir: str,
|
98 |
-
) -> None:
|
99 |
-
"""The training loop.
|
100 |
-
|
101 |
-
Loosely based on the nanoGPT implementation: https://github.com/karpathy/nanoGPT.
|
102 |
-
"""
|
103 |
-
step_count = 0
|
104 |
-
model.train()
|
105 |
-
|
106 |
-
for iter_num in range(max_iters):
|
107 |
-
|
108 |
-
is_accumulating = (iter_num + 1) % gradient_accumulation_iters != 0
|
109 |
-
|
110 |
-
if step_count <= warmup_iters:
|
111 |
-
# linear warmup
|
112 |
-
lr = learning_rate * step_count / warmup_iters
|
113 |
-
for param_group in optimizer.param_groups:
|
114 |
-
param_group['lr'] = lr
|
115 |
-
|
116 |
-
t0 = time.time()
|
117 |
-
|
118 |
-
input_ids, targets = get_batch(fabric, train_data)
|
119 |
-
with fabric.no_backward_sync(model, enabled=is_accumulating):
|
120 |
-
logits = model(input_ids)
|
121 |
-
loss = loss_fn(logits, targets)
|
122 |
-
fabric.backward(loss / gradient_accumulation_iters)
|
123 |
-
|
124 |
-
if not is_accumulating:
|
125 |
-
optimizer.step()
|
126 |
-
optimizer.zero_grad()
|
127 |
-
step_count += 1
|
128 |
-
|
129 |
-
if step_count % eval_interval == 0:
|
130 |
-
val_loss = validate(fabric, model, val_data)
|
131 |
-
fabric.print(f"step {iter_num}: val loss {val_loss:.4f}")
|
132 |
-
fabric.barrier()
|
133 |
-
|
134 |
-
if step_count % save_interval == 0:
|
135 |
-
print(f"Saving weights to {out_dir}")
|
136 |
-
save_model_checkpoint(fabric, model, os.path.join(out_dir, f"iter-{iter_num:06d}-ckpt.pth"))
|
137 |
-
|
138 |
-
dt = time.time() - t0
|
139 |
-
if iter_num % log_interval == 0:
|
140 |
-
fabric.print(f"iter {iter_num}: loss {loss.item():.4f}, time: {dt*1000:.2f}ms")
|
141 |
-
|
142 |
-
|
143 |
-
def generate_response(model, instruction):
|
144 |
-
tokenizer = Tokenizer("checkpoints/lit-llama/tokenizer.model")
|
145 |
-
sample = {"instruction": instruction, "input": ""}
|
146 |
-
prompt = instruction
|
147 |
-
if instruction_tuning:
|
148 |
-
prompt = generate_prompt(sample)
|
149 |
-
encoded = tokenizer.encode(prompt, bos=True, eos=False, device=model.device)
|
150 |
-
|
151 |
-
output = generate(
|
152 |
-
model,
|
153 |
-
idx=encoded,
|
154 |
-
max_seq_length=block_size,
|
155 |
-
max_new_tokens=100,
|
156 |
-
)
|
157 |
-
output = tokenizer.decode(output)
|
158 |
-
return output # output.split("### Response:")[1].strip()
|
159 |
-
|
160 |
-
|
161 |
-
@torch.no_grad()
|
162 |
-
def validate(fabric: L.Fabric, model: torch.nn.Module, val_data: np.ndarray) -> torch.Tensor:
|
163 |
-
fabric.print("Validating ...")
|
164 |
-
model.eval()
|
165 |
-
losses = torch.zeros(eval_iters)
|
166 |
-
for k in range(eval_iters):
|
167 |
-
input_ids, targets = get_batch(fabric, val_data)
|
168 |
-
logits = model(input_ids)
|
169 |
-
loss = loss_fn(logits, targets)
|
170 |
-
losses[k] = loss.item()
|
171 |
-
out = losses.mean()
|
172 |
-
|
173 |
-
# produce an example:
|
174 |
-
instruction = "Recommend a movie for me to watch during the weekend and explain the reason."
|
175 |
-
|
176 |
-
output = generate_response(model, instruction)
|
177 |
-
fabric.print(instruction)
|
178 |
-
fabric.print(output)
|
179 |
-
|
180 |
-
model.train()
|
181 |
-
return out.item()
|
182 |
-
|
183 |
-
|
184 |
-
def loss_fn(logits, targets):
|
185 |
-
# shift the targets such that output n predicts token n+1
|
186 |
-
logits = logits[..., :-1, :].contiguous()
|
187 |
-
targets = targets[..., 1:].contiguous()
|
188 |
-
loss = torch.nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
|
189 |
-
return loss
|
190 |
-
|
191 |
-
|
192 |
-
def get_batch(fabric: L.Fabric, data: list):
|
193 |
-
ix = torch.randint(len(data), (micro_batch_size,))
|
194 |
-
|
195 |
-
input_ids = [data[i]["input_ids"].type(torch.int64) for i in ix]
|
196 |
-
labels = [data[i]["labels"].type(torch.int64) for i in ix]
|
197 |
-
|
198 |
-
max_len = max(len(s) for s in input_ids)
|
199 |
-
|
200 |
-
def pad_right(x, pad_id):
|
201 |
-
# pad right based on the longest sequence
|
202 |
-
n = max_len - len(x)
|
203 |
-
return torch.cat((x, torch.full((n,), pad_id, dtype=x.dtype)))
|
204 |
-
|
205 |
-
x = torch.stack([pad_right(x, pad_id=0) for x in input_ids])
|
206 |
-
y = torch.stack([pad_right(x, pad_id=-1) for x in labels])
|
207 |
-
x, y = fabric.to_device((x.pin_memory(), y.pin_memory()))
|
208 |
-
return x, y
|
209 |
-
|
210 |
-
|
211 |
-
def load_datasets(data_dir):
|
212 |
-
train_data = torch.load(os.path.join(data_dir, "train.pt"))
|
213 |
-
val_data = torch.load(os.path.join(data_dir, "test.pt"))
|
214 |
-
return train_data, val_data
|
215 |
-
|
216 |
-
|
217 |
-
if __name__ == "__main__":
|
218 |
-
# Uncomment this line if you see an error: "Expected is_sm80 to be true, but got false"
|
219 |
-
# torch.backends.cuda.enable_flash_sdp(False)
|
220 |
-
torch.set_float32_matmul_precision("high")
|
221 |
-
|
222 |
-
from jsonargparse.cli import CLI
|
223 |
-
|
224 |
-
CLI(main)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lit-llama/finetune/lora.py
DELETED
@@ -1,218 +0,0 @@
|
|
1 |
-
"""
|
2 |
-
Instruction-tuning with LoRA on the Alpaca dataset.
|
3 |
-
|
4 |
-
Note: If you run into a CUDA error "Expected is_sm80 to be true, but got false", uncomment the line
|
5 |
-
`torch.backends.cuda.enable_flash_sdp(False)` in the script below (see https://github.com/Lightning-AI/lit-llama/issues/101).
|
6 |
-
"""
|
7 |
-
import sys
|
8 |
-
from pathlib import Path
|
9 |
-
import os
|
10 |
-
import time
|
11 |
-
|
12 |
-
import lightning as L
|
13 |
-
import numpy as np
|
14 |
-
import torch
|
15 |
-
|
16 |
-
# support running without installing as a package
|
17 |
-
wd = Path(__file__).parent.parent.resolve()
|
18 |
-
sys.path.append(str(wd))
|
19 |
-
|
20 |
-
from generate import generate
|
21 |
-
from lit_llama.lora import mark_only_lora_as_trainable, lora, lora_state_dict
|
22 |
-
from lit_llama.model import LLaMA, LLaMAConfig
|
23 |
-
from lit_llama.tokenizer import Tokenizer
|
24 |
-
from scripts.prepare_alpaca import generate_prompt
|
25 |
-
|
26 |
-
|
27 |
-
instruction_tuning = True
|
28 |
-
eval_interval = 100
|
29 |
-
save_interval = 100
|
30 |
-
eval_iters = 100
|
31 |
-
log_interval = 1
|
32 |
-
|
33 |
-
# Hyperparameters
|
34 |
-
learning_rate = 3e-4
|
35 |
-
batch_size = 128
|
36 |
-
micro_batch_size = 4
|
37 |
-
gradient_accumulation_iters = batch_size // micro_batch_size
|
38 |
-
assert gradient_accumulation_iters > 0
|
39 |
-
max_iters = 50000 * 3 // micro_batch_size
|
40 |
-
weight_decay = 0.0
|
41 |
-
max_seq_length = 256 # see scripts/prepare_alpaca.py
|
42 |
-
lora_r = 8
|
43 |
-
lora_alpha = 16
|
44 |
-
lora_dropout = 0.05
|
45 |
-
warmup_iters = 100
|
46 |
-
|
47 |
-
|
48 |
-
def main(
|
49 |
-
data_dir: str = "data/alpaca",
|
50 |
-
pretrained_path: str = "checkpoints/lit-llama/7B/lit-llama.pth",
|
51 |
-
tokenizer_path: str = "checkpoints/lit-llama/tokenizer.model",
|
52 |
-
out_dir: str = "out/lora/alpaca",
|
53 |
-
):
|
54 |
-
|
55 |
-
fabric = L.Fabric(accelerator="cuda", devices=1, precision="bf16-true")
|
56 |
-
fabric.launch()
|
57 |
-
fabric.seed_everything(1337 + fabric.global_rank)
|
58 |
-
|
59 |
-
if fabric.global_rank == 0:
|
60 |
-
os.makedirs(out_dir, exist_ok=True)
|
61 |
-
|
62 |
-
train_data, val_data = load_datasets(data_dir=data_dir)
|
63 |
-
|
64 |
-
config = LLaMAConfig.from_name("7B")
|
65 |
-
config.block_size = max_seq_length
|
66 |
-
|
67 |
-
checkpoint = torch.load(pretrained_path)
|
68 |
-
|
69 |
-
with fabric.init_module(), lora(r=lora_r, alpha=lora_alpha, dropout=lora_dropout, enabled=True):
|
70 |
-
model = LLaMA(config)
|
71 |
-
# strict=False because missing keys due to LoRA weights not contained in checkpoint state
|
72 |
-
model.load_state_dict(checkpoint, strict=False)
|
73 |
-
|
74 |
-
mark_only_lora_as_trainable(model)
|
75 |
-
|
76 |
-
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
|
77 |
-
model, optimizer = fabric.setup(model, optimizer)
|
78 |
-
train(fabric, model, optimizer, train_data, val_data, tokenizer_path, out_dir)
|
79 |
-
|
80 |
-
# Save the final LoRA checkpoint at the end of training
|
81 |
-
checkpoint = lora_state_dict(model)
|
82 |
-
fabric.save(os.path.join(out_dir, "lit-llama-lora-finetuned.pth"), checkpoint)
|
83 |
-
|
84 |
-
|
85 |
-
def train(
|
86 |
-
fabric: L.Fabric,
|
87 |
-
model: torch.nn.Module,
|
88 |
-
optimizer: torch.optim.Optimizer,
|
89 |
-
train_data: np.ndarray,
|
90 |
-
val_data: np.ndarray,
|
91 |
-
tokenizer_path: str,
|
92 |
-
out_dir: str,
|
93 |
-
) -> None:
|
94 |
-
"""The training loop.
|
95 |
-
|
96 |
-
Loosely based on the nanoGPT implementation: https://github.com/karpathy/nanoGPT.
|
97 |
-
"""
|
98 |
-
step_count = 0
|
99 |
-
|
100 |
-
for iter_num in range(max_iters):
|
101 |
-
|
102 |
-
if step_count <= warmup_iters:
|
103 |
-
# linear warmup
|
104 |
-
lr = learning_rate * step_count / warmup_iters
|
105 |
-
for param_group in optimizer.param_groups:
|
106 |
-
param_group['lr'] = lr
|
107 |
-
|
108 |
-
t0 = time.time()
|
109 |
-
|
110 |
-
input_ids, targets = get_batch(fabric, train_data)
|
111 |
-
with fabric.no_backward_sync(model, enabled=((iter_num + 1) % gradient_accumulation_iters != 0)):
|
112 |
-
logits = model(input_ids)
|
113 |
-
loss = loss_fn(logits, targets)
|
114 |
-
fabric.backward(loss / gradient_accumulation_iters)
|
115 |
-
|
116 |
-
if (iter_num + 1) % gradient_accumulation_iters == 0:
|
117 |
-
optimizer.step()
|
118 |
-
optimizer.zero_grad()
|
119 |
-
step_count += 1
|
120 |
-
|
121 |
-
if step_count % eval_interval == 0:
|
122 |
-
val_loss = validate(fabric, model, val_data, tokenizer_path)
|
123 |
-
fabric.print(f"step {iter_num}: val loss {val_loss:.4f}")
|
124 |
-
fabric.barrier()
|
125 |
-
|
126 |
-
if step_count % save_interval == 0:
|
127 |
-
print(f"Saving LoRA weights to {out_dir}")
|
128 |
-
# We are only saving the LoRA weights
|
129 |
-
# TODO: Provide a function/script to merge the LoRA weights with pretrained weights
|
130 |
-
checkpoint = lora_state_dict(model)
|
131 |
-
fabric.save(os.path.join(out_dir, f"iter-{iter_num:06d}-ckpt.pth"), checkpoint)
|
132 |
-
|
133 |
-
dt = time.time() - t0
|
134 |
-
if iter_num % log_interval == 0:
|
135 |
-
fabric.print(f"iter {iter_num}: loss {loss.item():.4f}, time: {dt*1000:.2f}ms")
|
136 |
-
|
137 |
-
|
138 |
-
def generate_response(model, instruction, tokenizer_path):
|
139 |
-
tokenizer = Tokenizer(tokenizer_path)
|
140 |
-
sample = {"instruction": instruction, "input": ""}
|
141 |
-
prompt = instruction
|
142 |
-
if instruction_tuning:
|
143 |
-
prompt = generate_prompt(sample)
|
144 |
-
encoded = tokenizer.encode(prompt, bos=True, eos=False, device=model.device)
|
145 |
-
|
146 |
-
output = generate(
|
147 |
-
model,
|
148 |
-
idx=encoded,
|
149 |
-
max_seq_length=max_seq_length,
|
150 |
-
max_new_tokens=100,
|
151 |
-
)
|
152 |
-
output = tokenizer.decode(output)
|
153 |
-
return output # output.split("### Response:")[1].strip()
|
154 |
-
|
155 |
-
|
156 |
-
@torch.no_grad()
|
157 |
-
def validate(fabric: L.Fabric, model: torch.nn.Module, val_data: np.ndarray, tokenizer_path: str) -> torch.Tensor:
|
158 |
-
fabric.print("Validating ...")
|
159 |
-
model.eval()
|
160 |
-
losses = torch.zeros(eval_iters)
|
161 |
-
for k in range(eval_iters):
|
162 |
-
input_ids, targets = get_batch(fabric, val_data)
|
163 |
-
logits = model(input_ids)
|
164 |
-
loss = loss_fn(logits, targets)
|
165 |
-
losses[k] = loss.item()
|
166 |
-
out = losses.mean()
|
167 |
-
|
168 |
-
# produce an example:
|
169 |
-
instruction = "Recommend a movie for me to watch during the weekend and explain the reason."
|
170 |
-
|
171 |
-
output = generate_response(model, instruction, tokenizer_path)
|
172 |
-
fabric.print(instruction)
|
173 |
-
fabric.print(output)
|
174 |
-
|
175 |
-
model.train()
|
176 |
-
return out.item()
|
177 |
-
|
178 |
-
def loss_fn(logits, targets):
|
179 |
-
# shift the targets such that output n predicts token n+1
|
180 |
-
logits = logits[..., :-1, :].contiguous()
|
181 |
-
targets = targets[..., 1:].contiguous()
|
182 |
-
loss = torch.nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
|
183 |
-
return loss
|
184 |
-
|
185 |
-
|
186 |
-
def get_batch(fabric: L.Fabric, data: list):
|
187 |
-
ix = torch.randint(len(data), (micro_batch_size,))
|
188 |
-
|
189 |
-
input_ids = [data[i]["input_ids"].type(torch.int64) for i in ix]
|
190 |
-
labels = [data[i]["labels"].type(torch.int64) for i in ix]
|
191 |
-
|
192 |
-
max_len = max(len(s) for s in input_ids)
|
193 |
-
|
194 |
-
def pad_right(x, pad_id):
|
195 |
-
# pad right based on the longest sequence
|
196 |
-
n = max_len - len(x)
|
197 |
-
return torch.cat((x, torch.full((n,), pad_id, dtype=x.dtype)))
|
198 |
-
|
199 |
-
x = torch.stack([pad_right(x, pad_id=0) for x in input_ids])
|
200 |
-
y = torch.stack([pad_right(x, pad_id=-1) for x in labels])
|
201 |
-
x, y = fabric.to_device((x.pin_memory(), y.pin_memory()))
|
202 |
-
return x, y
|
203 |
-
|
204 |
-
|
205 |
-
def load_datasets(data_dir):
|
206 |
-
train_data = torch.load(os.path.join(data_dir, "train.pt"))
|
207 |
-
val_data = torch.load(os.path.join(data_dir, "test.pt"))
|
208 |
-
return train_data, val_data
|
209 |
-
|
210 |
-
|
211 |
-
if __name__ == "__main__":
|
212 |
-
# Uncomment this line if you see an error: "Expected is_sm80 to be true, but got false"
|
213 |
-
# torch.backends.cuda.enable_flash_sdp(False)
|
214 |
-
torch.set_float32_matmul_precision("high")
|
215 |
-
|
216 |
-
from jsonargparse.cli import CLI
|
217 |
-
|
218 |
-
CLI(main)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lit-llama/generate.py
DELETED
@@ -1,170 +0,0 @@
|
|
1 |
-
import sys
|
2 |
-
import time
|
3 |
-
import warnings
|
4 |
-
from pathlib import Path
|
5 |
-
from typing import Optional
|
6 |
-
|
7 |
-
import lightning as L
|
8 |
-
import torch
|
9 |
-
|
10 |
-
# support running without installing as a package
|
11 |
-
wd = Path(__file__).parent.parent.resolve()
|
12 |
-
sys.path.append(str(wd))
|
13 |
-
|
14 |
-
from lit_llama import LLaMA, Tokenizer
|
15 |
-
from lit_llama.utils import lazy_load, llama_model_lookup, quantization
|
16 |
-
|
17 |
-
|
18 |
-
@torch.no_grad()
|
19 |
-
def generate(
|
20 |
-
model: LLaMA,
|
21 |
-
idx: torch.Tensor,
|
22 |
-
max_new_tokens: int,
|
23 |
-
*,
|
24 |
-
max_seq_length: Optional[int] = None,
|
25 |
-
temperature: float = 1.0,
|
26 |
-
top_k: Optional[int] = None,
|
27 |
-
eos_id: Optional[int] = None,
|
28 |
-
) -> torch.Tensor:
|
29 |
-
"""Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
|
30 |
-
|
31 |
-
The implementation of this function is modified from A. Karpathy's nanoGPT.
|
32 |
-
|
33 |
-
Args:
|
34 |
-
model: The model to use.
|
35 |
-
idx: Tensor of shape (T) with indices of the prompt sequence.
|
36 |
-
max_new_tokens: The number of new tokens to generate.
|
37 |
-
max_seq_length: The maximum sequence length allowed.
|
38 |
-
temperature: Scales the predicted logits by 1 / temperature
|
39 |
-
top_k: If specified, only sample among the tokens with the k highest probabilities
|
40 |
-
eos_id: If specified, stop generating any more token once the <eos> token is triggered
|
41 |
-
"""
|
42 |
-
# create an empty tensor of the expected final shape and fill in the current tokens
|
43 |
-
T = idx.size(0)
|
44 |
-
T_new = T + max_new_tokens
|
45 |
-
if max_seq_length is None:
|
46 |
-
max_seq_length = min(T_new, model.config.block_size)
|
47 |
-
|
48 |
-
device, dtype = idx.device, idx.dtype
|
49 |
-
# create an empty tensor of the expected final shape and fill in the current tokens
|
50 |
-
empty = torch.empty(T_new, dtype=dtype, device=device)
|
51 |
-
empty[:T] = idx
|
52 |
-
idx = empty
|
53 |
-
input_pos = torch.arange(0, T, device=device)
|
54 |
-
|
55 |
-
if idx.device.type == "xla":
|
56 |
-
import torch_xla.core.xla_model as xm
|
57 |
-
|
58 |
-
xm.mark_step()
|
59 |
-
|
60 |
-
# generate max_new_tokens tokens
|
61 |
-
for _ in range(max_new_tokens):
|
62 |
-
x = idx.index_select(0, input_pos).view(1, -1)
|
63 |
-
|
64 |
-
# forward
|
65 |
-
logits = model(x, max_seq_length, input_pos)
|
66 |
-
logits = logits[0, -1] / temperature
|
67 |
-
|
68 |
-
# optionally crop the logits to only the top k options
|
69 |
-
if top_k is not None:
|
70 |
-
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
|
71 |
-
logits = torch.where(logits < v[[-1]], -float("Inf"), logits)
|
72 |
-
|
73 |
-
probs = torch.nn.functional.softmax(logits, dim=-1)
|
74 |
-
idx_next = torch.multinomial(probs, num_samples=1).to(dtype=dtype)
|
75 |
-
|
76 |
-
# advance
|
77 |
-
input_pos = input_pos[-1:] + 1
|
78 |
-
|
79 |
-
if idx.device.type == "xla":
|
80 |
-
xm.mark_step()
|
81 |
-
|
82 |
-
# concatenate the new generation
|
83 |
-
idx = idx.index_copy(0, input_pos, idx_next)
|
84 |
-
|
85 |
-
# if <eos> token is triggered, return the output (stop generation)
|
86 |
-
if idx_next == eos_id:
|
87 |
-
return idx[:input_pos] # include the EOS token
|
88 |
-
|
89 |
-
return idx
|
90 |
-
|
91 |
-
|
92 |
-
def main(
|
93 |
-
prompt: str = "Hello, my name is",
|
94 |
-
*,
|
95 |
-
num_samples: int = 1,
|
96 |
-
max_new_tokens: int = 50,
|
97 |
-
top_k: int = 200,
|
98 |
-
temperature: float = 0.8,
|
99 |
-
checkpoint_path: Path = Path("checkpoints/lit-llama/7B/lit-llama.pth"),
|
100 |
-
tokenizer_path: Path = Path("checkpoints/lit-llama/tokenizer.model"),
|
101 |
-
quantize: Optional[str] = None,
|
102 |
-
) -> None:
|
103 |
-
"""Generates text samples based on a pre-trained LLaMA model and tokenizer.
|
104 |
-
|
105 |
-
Args:
|
106 |
-
prompt: The prompt string to use for generating the samples.
|
107 |
-
num_samples: The number of text samples to generate.
|
108 |
-
max_new_tokens: The number of generation steps to take.
|
109 |
-
top_k: The number of top most probable tokens to consider in the sampling process.
|
110 |
-
temperature: A value controlling the randomness of the sampling process. Higher values result in more random
|
111 |
-
samples.
|
112 |
-
checkpoint_path: The checkpoint path to load.
|
113 |
-
tokenizer_path: The tokenizer path to load.
|
114 |
-
quantize: Whether to quantize the model and using which method:
|
115 |
-
``"llm.int8"``: LLM.int8() mode,
|
116 |
-
``"gptq.int4"``: GPTQ 4-bit mode.
|
117 |
-
"""
|
118 |
-
assert checkpoint_path.is_file(), checkpoint_path
|
119 |
-
assert tokenizer_path.is_file(), tokenizer_path
|
120 |
-
|
121 |
-
precision = "bf16-true" if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else "32-true"
|
122 |
-
fabric = L.Fabric(devices=1, precision=precision)
|
123 |
-
|
124 |
-
print("Loading model ...", file=sys.stderr)
|
125 |
-
t0 = time.time()
|
126 |
-
with lazy_load(checkpoint_path) as checkpoint:
|
127 |
-
name = llama_model_lookup(checkpoint)
|
128 |
-
|
129 |
-
with fabric.init_module(empty_init=True), quantization(mode=quantize):
|
130 |
-
model = LLaMA.from_name(name)
|
131 |
-
|
132 |
-
model.load_state_dict(checkpoint)
|
133 |
-
print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr)
|
134 |
-
|
135 |
-
model.eval()
|
136 |
-
model = fabric.setup(model)
|
137 |
-
|
138 |
-
tokenizer = Tokenizer(tokenizer_path)
|
139 |
-
encoded = tokenizer.encode(prompt, bos=True, eos=False, device=fabric.device)
|
140 |
-
prompt_length = encoded.size(0)
|
141 |
-
|
142 |
-
L.seed_everything(1234)
|
143 |
-
for i in range(num_samples):
|
144 |
-
t0 = time.perf_counter()
|
145 |
-
y = generate(model, encoded, max_new_tokens, temperature=temperature, top_k=top_k)
|
146 |
-
t = time.perf_counter() - t0
|
147 |
-
|
148 |
-
model.reset_cache()
|
149 |
-
print(tokenizer.decode(y))
|
150 |
-
tokens_generated = y.size(0) - prompt_length
|
151 |
-
print(f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec", file=sys.stderr)
|
152 |
-
if fabric.device.type == "cuda":
|
153 |
-
print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB", file=sys.stderr)
|
154 |
-
|
155 |
-
|
156 |
-
if __name__ == "__main__":
|
157 |
-
from jsonargparse import CLI
|
158 |
-
|
159 |
-
torch.set_float32_matmul_precision("high")
|
160 |
-
warnings.filterwarnings(
|
161 |
-
# Triggered internally at ../aten/src/ATen/EmptyTensor.cpp:31
|
162 |
-
"ignore",
|
163 |
-
message="ComplexHalf support is experimental and many operators don't support it yet"
|
164 |
-
)
|
165 |
-
warnings.filterwarnings(
|
166 |
-
# Triggered in bitsandbytes/autograd/_functions.py:298
|
167 |
-
"ignore",
|
168 |
-
message="MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization",
|
169 |
-
)
|
170 |
-
CLI(main)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lit-llama/generate/adapter.py
DELETED
@@ -1,106 +0,0 @@
|
|
1 |
-
import sys
|
2 |
-
import time
|
3 |
-
import warnings
|
4 |
-
from pathlib import Path
|
5 |
-
from typing import Optional
|
6 |
-
|
7 |
-
import lightning as L
|
8 |
-
import torch
|
9 |
-
|
10 |
-
# support running without installing as a package
|
11 |
-
wd = Path(__file__).parent.parent.resolve()
|
12 |
-
sys.path.append(str(wd))
|
13 |
-
|
14 |
-
from generate import generate
|
15 |
-
from lit_llama import Tokenizer
|
16 |
-
from lit_llama.adapter import LLaMA
|
17 |
-
from lit_llama.utils import lazy_load, llama_model_lookup, quantization
|
18 |
-
from scripts.prepare_alpaca import generate_prompt
|
19 |
-
|
20 |
-
|
21 |
-
def main(
|
22 |
-
prompt: str = "What food do lamas eat?",
|
23 |
-
input: str = "",
|
24 |
-
adapter_path: Path = Path("out/adapter/alpaca/lit-llama-adapter-finetuned.pth"),
|
25 |
-
pretrained_path: Path = Path("checkpoints/lit-llama/7B/lit-llama.pth"),
|
26 |
-
tokenizer_path: Path = Path("checkpoints/lit-llama/tokenizer.model"),
|
27 |
-
quantize: Optional[str] = None,
|
28 |
-
max_new_tokens: int = 100,
|
29 |
-
top_k: int = 200,
|
30 |
-
temperature: float = 0.8,
|
31 |
-
) -> None:
|
32 |
-
"""Generates a response based on a given instruction and an optional input.
|
33 |
-
This script will only work with checkpoints from the instruction-tuned LLaMA-Adapter model.
|
34 |
-
See `finetune_adapter.py`.
|
35 |
-
|
36 |
-
Args:
|
37 |
-
prompt: The prompt/instruction (Alpaca style).
|
38 |
-
adapter_path: Path to the checkpoint with trained adapter weights, which are the output of
|
39 |
-
`finetune_adapter.py`.
|
40 |
-
input: Optional input (Alpaca style).
|
41 |
-
pretrained_path: The path to the checkpoint with pretrained LLaMA weights.
|
42 |
-
tokenizer_path: The tokenizer path to load.
|
43 |
-
quantize: Whether to quantize the model and using which method:
|
44 |
-
``"llm.int8"``: LLM.int8() mode,
|
45 |
-
``"gptq.int4"``: GPTQ 4-bit mode.
|
46 |
-
max_new_tokens: The number of generation steps to take.
|
47 |
-
top_k: The number of top most probable tokens to consider in the sampling process.
|
48 |
-
temperature: A value controlling the randomness of the sampling process. Higher values result in more random
|
49 |
-
samples.
|
50 |
-
"""
|
51 |
-
assert adapter_path.is_file()
|
52 |
-
assert pretrained_path.is_file()
|
53 |
-
assert tokenizer_path.is_file()
|
54 |
-
|
55 |
-
precision = "bf16-true" if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else "32-true"
|
56 |
-
fabric = L.Fabric(devices=1, precision=precision)
|
57 |
-
|
58 |
-
print("Loading model ...", file=sys.stderr)
|
59 |
-
t0 = time.time()
|
60 |
-
with lazy_load(pretrained_path) as pretrained_checkpoint, lazy_load(adapter_path) as adapter_checkpoint:
|
61 |
-
name = llama_model_lookup(pretrained_checkpoint)
|
62 |
-
|
63 |
-
with fabric.init_module(empty_init=True), quantization(mode=quantize):
|
64 |
-
model = LLaMA.from_name(name)
|
65 |
-
|
66 |
-
# 1. Load the pretrained weights
|
67 |
-
model.load_state_dict(pretrained_checkpoint, strict=False)
|
68 |
-
# 2. Load the fine-tuned adapter weights
|
69 |
-
model.load_state_dict(adapter_checkpoint, strict=False)
|
70 |
-
|
71 |
-
print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr)
|
72 |
-
|
73 |
-
model.eval()
|
74 |
-
model = fabric.setup(model)
|
75 |
-
|
76 |
-
tokenizer = Tokenizer(tokenizer_path)
|
77 |
-
sample = {"instruction": prompt, "input": input}
|
78 |
-
prompt = generate_prompt(sample)
|
79 |
-
encoded = tokenizer.encode(prompt, bos=True, eos=False, device=model.device)
|
80 |
-
prompt_length = encoded.size(0)
|
81 |
-
|
82 |
-
t0 = time.perf_counter()
|
83 |
-
y = generate(model, encoded, max_new_tokens, temperature=temperature, top_k=top_k, eos_id=tokenizer.eos_id)
|
84 |
-
t = time.perf_counter() - t0
|
85 |
-
|
86 |
-
model.reset_cache()
|
87 |
-
output = tokenizer.decode(y)
|
88 |
-
output = output.split("### Response:")[1].strip()
|
89 |
-
print(output)
|
90 |
-
|
91 |
-
tokens_generated = y.size(0) - prompt_length
|
92 |
-
print(f"\n\nTime for inference: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec", file=sys.stderr)
|
93 |
-
if fabric.device.type == "cuda":
|
94 |
-
print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB", file=sys.stderr)
|
95 |
-
|
96 |
-
|
97 |
-
if __name__ == "__main__":
|
98 |
-
from jsonargparse import CLI
|
99 |
-
|
100 |
-
torch.set_float32_matmul_precision("high")
|
101 |
-
warnings.filterwarnings(
|
102 |
-
# Triggered internally at ../aten/src/ATen/EmptyTensor.cpp:31
|
103 |
-
"ignore",
|
104 |
-
message="ComplexHalf support is experimental and many operators don't support it yet"
|
105 |
-
)
|
106 |
-
CLI(main)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lit-llama/generate/adapter_v2.py
DELETED
@@ -1,108 +0,0 @@
|
|
1 |
-
import sys
|
2 |
-
import time
|
3 |
-
import warnings
|
4 |
-
from pathlib import Path
|
5 |
-
from typing import Optional
|
6 |
-
|
7 |
-
import lightning as L
|
8 |
-
import torch
|
9 |
-
|
10 |
-
# support running without installing as a package
|
11 |
-
wd = Path(__file__).parent.parent.resolve()
|
12 |
-
sys.path.append(str(wd))
|
13 |
-
|
14 |
-
from generate import generate
|
15 |
-
from lit_llama import Tokenizer
|
16 |
-
from lit_llama.adapter import LLaMA
|
17 |
-
from lit_llama.utils import lazy_load, llama_model_lookup, quantization
|
18 |
-
from lit_llama.adapter_v2 import add_adapter_v2_parameters_to_linear_layers
|
19 |
-
from scripts.prepare_alpaca import generate_prompt
|
20 |
-
|
21 |
-
|
22 |
-
def main(
|
23 |
-
prompt: str = "What food do lamas eat?",
|
24 |
-
input: str = "",
|
25 |
-
adapter_path: Path = Path("out/adapter_v2/alpaca/lit-llama-adapter-finetuned.pth"),
|
26 |
-
pretrained_path: Path = Path("checkpoints/lit-llama/7B/lit-llama.pth"),
|
27 |
-
tokenizer_path: Path = Path("checkpoints/lit-llama/tokenizer.model"),
|
28 |
-
quantize: Optional[str] = None,
|
29 |
-
max_new_tokens: int = 100,
|
30 |
-
top_k: int = 200,
|
31 |
-
temperature: float = 0.8,
|
32 |
-
) -> None:
|
33 |
-
"""Generates a response based on a given instruction and an optional input.
|
34 |
-
This script will only work with checkpoints from the instruction-tuned LLaMA-Adapter model.
|
35 |
-
See `finetune_adapter_v2.py`.
|
36 |
-
|
37 |
-
Args:
|
38 |
-
prompt: The prompt/instruction (Alpaca style).
|
39 |
-
adapter_path: Path to the checkpoint with trained adapter weights, which are the output of
|
40 |
-
`finetune_adapter_v2.py`.
|
41 |
-
input: Optional input (Alpaca style).
|
42 |
-
pretrained_path: The path to the checkpoint with pretrained LLaMA weights.
|
43 |
-
tokenizer_path: The tokenizer path to load.
|
44 |
-
quantize: Whether to quantize the model and using which method:
|
45 |
-
``"llm.int8"``: LLM.int8() mode,
|
46 |
-
``"gptq.int4"``: GPTQ 4-bit mode.
|
47 |
-
max_new_tokens: The number of generation steps to take.
|
48 |
-
top_k: The number of top most probable tokens to consider in the sampling process.
|
49 |
-
temperature: A value controlling the randomness of the sampling process. Higher values result in more random
|
50 |
-
samples.
|
51 |
-
"""
|
52 |
-
assert adapter_path.is_file()
|
53 |
-
assert pretrained_path.is_file()
|
54 |
-
assert tokenizer_path.is_file()
|
55 |
-
|
56 |
-
precision = "bf16-true" if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else "32-true"
|
57 |
-
fabric = L.Fabric(devices=1, precision=precision)
|
58 |
-
|
59 |
-
print("Loading model ...", file=sys.stderr)
|
60 |
-
t0 = time.time()
|
61 |
-
with lazy_load(pretrained_path) as pretrained_checkpoint, lazy_load(adapter_path) as adapter_checkpoint:
|
62 |
-
name = llama_model_lookup(pretrained_checkpoint)
|
63 |
-
|
64 |
-
with fabric.init_module(empty_init=True), quantization(mode=quantize):
|
65 |
-
model = LLaMA.from_name(name)
|
66 |
-
add_adapter_v2_parameters_to_linear_layers(model)
|
67 |
-
|
68 |
-
# 1. Load the pretrained weights
|
69 |
-
model.load_state_dict(pretrained_checkpoint, strict=False)
|
70 |
-
# 2. Load the fine-tuned adapter weights
|
71 |
-
model.load_state_dict(adapter_checkpoint, strict=False)
|
72 |
-
|
73 |
-
print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr)
|
74 |
-
|
75 |
-
model.eval()
|
76 |
-
model = fabric.setup(model)
|
77 |
-
|
78 |
-
tokenizer = Tokenizer(tokenizer_path)
|
79 |
-
sample = {"instruction": prompt, "input": input}
|
80 |
-
prompt = generate_prompt(sample)
|
81 |
-
encoded = tokenizer.encode(prompt, bos=True, eos=False, device=model.device)
|
82 |
-
prompt_length = encoded.size(0)
|
83 |
-
|
84 |
-
t0 = time.perf_counter()
|
85 |
-
y = generate(model, encoded, max_new_tokens, temperature=temperature, top_k=top_k, eos_id=tokenizer.eos_id)
|
86 |
-
t = time.perf_counter() - t0
|
87 |
-
|
88 |
-
model.reset_cache()
|
89 |
-
output = tokenizer.decode(y)
|
90 |
-
output = output.split("### Response:")[1].strip()
|
91 |
-
print(output)
|
92 |
-
|
93 |
-
tokens_generated = y.size(0) - prompt_length
|
94 |
-
print(f"\n\nTime for inference: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec", file=sys.stderr)
|
95 |
-
if fabric.device.type == "cuda":
|
96 |
-
print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB", file=sys.stderr)
|
97 |
-
|
98 |
-
|
99 |
-
if __name__ == "__main__":
|
100 |
-
from jsonargparse import CLI
|
101 |
-
|
102 |
-
torch.set_float32_matmul_precision("high")
|
103 |
-
warnings.filterwarnings(
|
104 |
-
# Triggered internally at ../aten/src/ATen/EmptyTensor.cpp:31
|
105 |
-
"ignore",
|
106 |
-
message="ComplexHalf support is experimental and many operators don't support it yet"
|
107 |
-
)
|
108 |
-
CLI(main)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lit-llama/generate/full.py
DELETED
@@ -1,103 +0,0 @@
|
|
1 |
-
import sys
|
2 |
-
import time
|
3 |
-
import warnings
|
4 |
-
from pathlib import Path
|
5 |
-
from typing import Optional
|
6 |
-
|
7 |
-
import lightning as L
|
8 |
-
import torch
|
9 |
-
|
10 |
-
# support running without installing as a package
|
11 |
-
wd = Path(__file__).absolute().parent.parent
|
12 |
-
sys.path.append(str(wd))
|
13 |
-
|
14 |
-
from lit_llama import LLaMA, Tokenizer
|
15 |
-
from lit_llama.utils import quantization
|
16 |
-
from scripts.prepare_alpaca import generate_prompt
|
17 |
-
from generate import generate
|
18 |
-
|
19 |
-
|
20 |
-
def main(
|
21 |
-
prompt: str = "Hello, my name is",
|
22 |
-
*,
|
23 |
-
num_samples: int = 1,
|
24 |
-
max_new_tokens: int = 50,
|
25 |
-
top_k: int = 200,
|
26 |
-
temperature: float = 0.8,
|
27 |
-
checkpoint_path: Optional[Path] = None,
|
28 |
-
tokenizer_path: Path = Path("checkpoints/lit-llama/tokenizer.model"),
|
29 |
-
model_size: str = "7B",
|
30 |
-
quantize: Optional[str] = None,
|
31 |
-
) -> None:
|
32 |
-
"""Generates text samples based on a pre-trained LLaMA model and tokenizer.
|
33 |
-
|
34 |
-
Args:
|
35 |
-
prompt: The prompt string to use for generating the samples.
|
36 |
-
num_samples: The number of text samples to generate.
|
37 |
-
max_new_tokens: The number of generation steps to take.
|
38 |
-
top_k: The number of top most probable tokens to consider in the sampling process.
|
39 |
-
temperature: A value controlling the randomness of the sampling process. Higher values result in more random
|
40 |
-
samples.
|
41 |
-
checkpoint_path: The checkpoint path to load.
|
42 |
-
tokenizer_path: The tokenizer path to load.
|
43 |
-
model_size: The model size to load.
|
44 |
-
quantize: Whether to quantize the model and using which method:
|
45 |
-
``"llm.int8"``: LLM.int8() mode,
|
46 |
-
``"gptq.int4"``: GPTQ 4-bit mode.
|
47 |
-
"""
|
48 |
-
if not checkpoint_path:
|
49 |
-
checkpoint_path = Path(f"checkpoints/lit-llama/{model_size}/lit-llama.pth")
|
50 |
-
assert checkpoint_path.is_file(), checkpoint_path
|
51 |
-
assert tokenizer_path.is_file(), tokenizer_path
|
52 |
-
|
53 |
-
precision = "bf16-true" if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else "32-true"
|
54 |
-
fabric = L.Fabric(devices=1, precision=precision)
|
55 |
-
|
56 |
-
print("Loading model ...", file=sys.stderr)
|
57 |
-
t0 = time.time()
|
58 |
-
|
59 |
-
with fabric.init_module(empty_init=True), quantization(mode=quantize):
|
60 |
-
model = LLaMA.from_name(model_size)
|
61 |
-
|
62 |
-
checkpoint = torch.load(checkpoint_path)
|
63 |
-
model.load_state_dict(checkpoint)
|
64 |
-
print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr)
|
65 |
-
|
66 |
-
model.eval()
|
67 |
-
model = fabric.setup(model)
|
68 |
-
|
69 |
-
tokenizer = Tokenizer(tokenizer_path)
|
70 |
-
sample = {"instruction": prompt, "input": input}
|
71 |
-
prompt = generate_prompt(sample)
|
72 |
-
encoded = tokenizer.encode(prompt, bos=True, eos=False, device=fabric.device)
|
73 |
-
prompt_length = encoded.size(0)
|
74 |
-
|
75 |
-
L.seed_everything(1234)
|
76 |
-
for i in range(num_samples):
|
77 |
-
t0 = time.perf_counter()
|
78 |
-
y = generate(model, encoded, max_new_tokens, temperature=temperature, top_k=top_k)
|
79 |
-
t = time.perf_counter() - t0
|
80 |
-
|
81 |
-
model.reset_cache()
|
82 |
-
print(tokenizer.decode(y))
|
83 |
-
tokens_generated = y.size(0) - prompt_length
|
84 |
-
print(f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec", file=sys.stderr)
|
85 |
-
if fabric.device.type == "cuda":
|
86 |
-
print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB", file=sys.stderr)
|
87 |
-
|
88 |
-
|
89 |
-
if __name__ == "__main__":
|
90 |
-
from jsonargparse import CLI
|
91 |
-
|
92 |
-
torch.set_float32_matmul_precision("high")
|
93 |
-
warnings.filterwarnings(
|
94 |
-
# Triggered internally at ../aten/src/ATen/EmptyTensor.cpp:31
|
95 |
-
"ignore",
|
96 |
-
message="ComplexHalf support is experimental and many operators don't support it yet"
|
97 |
-
)
|
98 |
-
warnings.filterwarnings(
|
99 |
-
# Triggered in bitsandbytes/autograd/_functions.py:298
|
100 |
-
"ignore",
|
101 |
-
message="MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization",
|
102 |
-
)
|
103 |
-
CLI(main)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lit-llama/generate/lora.py
DELETED
@@ -1,118 +0,0 @@
|
|
1 |
-
import sys
|
2 |
-
import time
|
3 |
-
import warnings
|
4 |
-
from pathlib import Path
|
5 |
-
from typing import Optional
|
6 |
-
|
7 |
-
import lightning as L
|
8 |
-
import torch
|
9 |
-
|
10 |
-
# support running without installing as a package
|
11 |
-
wd = Path(__file__).parent.parent.resolve()
|
12 |
-
sys.path.append(str(wd))
|
13 |
-
|
14 |
-
from generate import generate
|
15 |
-
from lit_llama import Tokenizer, LLaMA
|
16 |
-
from lit_llama.lora import lora
|
17 |
-
from lit_llama.utils import lazy_load, llama_model_lookup
|
18 |
-
from scripts.prepare_alpaca import generate_prompt
|
19 |
-
|
20 |
-
lora_r = 8
|
21 |
-
lora_alpha = 16
|
22 |
-
lora_dropout = 0.05
|
23 |
-
|
24 |
-
|
25 |
-
def main(
|
26 |
-
prompt: str = "What food do lamas eat?",
|
27 |
-
input: str = "",
|
28 |
-
lora_path: Path = Path("out/lora/alpaca/lit-llama-lora-finetuned.pth"),
|
29 |
-
pretrained_path: Path = Path("checkpoints/lit-llama/7B/lit-llama.pth"),
|
30 |
-
tokenizer_path: Path = Path("checkpoints/lit-llama/tokenizer.model"),
|
31 |
-
quantize: Optional[str] = None,
|
32 |
-
max_new_tokens: int = 100,
|
33 |
-
top_k: int = 200,
|
34 |
-
temperature: float = 0.8,
|
35 |
-
) -> None:
|
36 |
-
"""Generates a response based on a given instruction and an optional input.
|
37 |
-
This script will only work with checkpoints from the instruction-tuned LoRA model.
|
38 |
-
See `finetune_lora.py`.
|
39 |
-
|
40 |
-
Args:
|
41 |
-
prompt: The prompt/instruction (Alpaca style).
|
42 |
-
lora_path: Path to the checkpoint with trained LoRA weights, which are the output of
|
43 |
-
`finetune_lora.py`.
|
44 |
-
input: Optional input (Alpaca style).
|
45 |
-
pretrained_path: The path to the checkpoint with pretrained LLaMA weights.
|
46 |
-
tokenizer_path: The tokenizer path to load.
|
47 |
-
quantize: Whether to quantize the model and using which method:
|
48 |
-
``"llm.int8"``: LLM.int8() mode,
|
49 |
-
``"gptq.int4"``: GPTQ 4-bit mode.
|
50 |
-
max_new_tokens: The number of generation steps to take.
|
51 |
-
top_k: The number of top most probable tokens to consider in the sampling process.
|
52 |
-
temperature: A value controlling the randomness of the sampling process. Higher values result in more random
|
53 |
-
samples.
|
54 |
-
"""
|
55 |
-
assert lora_path.is_file()
|
56 |
-
assert pretrained_path.is_file()
|
57 |
-
assert tokenizer_path.is_file()
|
58 |
-
|
59 |
-
if quantize is not None:
|
60 |
-
raise NotImplementedError("Quantization in LoRA is not supported yet")
|
61 |
-
|
62 |
-
precision = "bf16-true" if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else "32-true"
|
63 |
-
fabric = L.Fabric(devices=1, precision=precision)
|
64 |
-
|
65 |
-
print("Loading model ...", file=sys.stderr)
|
66 |
-
t0 = time.time()
|
67 |
-
|
68 |
-
with lazy_load(pretrained_path) as pretrained_checkpoint, lazy_load(lora_path) as lora_checkpoint:
|
69 |
-
name = llama_model_lookup(pretrained_checkpoint)
|
70 |
-
|
71 |
-
with fabric.init_module(empty_init=True), lora(r=lora_r, alpha=lora_alpha, dropout=lora_dropout, enabled=True):
|
72 |
-
model = LLaMA.from_name(name)
|
73 |
-
|
74 |
-
# 1. Load the pretrained weights
|
75 |
-
model.load_state_dict(pretrained_checkpoint, strict=False)
|
76 |
-
# 2. Load the fine-tuned lora weights
|
77 |
-
model.load_state_dict(lora_checkpoint, strict=False)
|
78 |
-
|
79 |
-
print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr)
|
80 |
-
|
81 |
-
model.eval()
|
82 |
-
model = fabric.setup(model)
|
83 |
-
|
84 |
-
tokenizer = Tokenizer(tokenizer_path)
|
85 |
-
sample = {"instruction": prompt, "input": input}
|
86 |
-
prompt = generate_prompt(sample)
|
87 |
-
encoded = tokenizer.encode(prompt, bos=True, eos=False, device=model.device)
|
88 |
-
|
89 |
-
t0 = time.perf_counter()
|
90 |
-
output = generate(
|
91 |
-
model,
|
92 |
-
idx=encoded,
|
93 |
-
max_new_tokens=max_new_tokens,
|
94 |
-
temperature=temperature,
|
95 |
-
top_k=top_k,
|
96 |
-
eos_id=tokenizer.eos_id
|
97 |
-
)
|
98 |
-
t = time.perf_counter() - t0
|
99 |
-
|
100 |
-
output = tokenizer.decode(output)
|
101 |
-
output = output.split("### Response:")[1].strip()
|
102 |
-
print(output)
|
103 |
-
|
104 |
-
print(f"\n\nTime for inference: {t:.02f} sec total, {max_new_tokens / t:.02f} tokens/sec", file=sys.stderr)
|
105 |
-
if fabric.device.type == "cuda":
|
106 |
-
print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB", file=sys.stderr)
|
107 |
-
|
108 |
-
|
109 |
-
if __name__ == "__main__":
|
110 |
-
from jsonargparse import CLI
|
111 |
-
|
112 |
-
torch.set_float32_matmul_precision("high")
|
113 |
-
warnings.filterwarnings(
|
114 |
-
# Triggered internally at ../aten/src/ATen/EmptyTensor.cpp:31
|
115 |
-
"ignore",
|
116 |
-
message="ComplexHalf support is experimental and many operators don't support it yet"
|
117 |
-
)
|
118 |
-
CLI(main)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lit-llama/howto/convert_lora_weights.md
DELETED
@@ -1,19 +0,0 @@
|
|
1 |
-
# Merging LoRA weights into base model weights
|
2 |
-
|
3 |
-
Purpose: By merging our selected LoRA weights into the base model weights, we can benefit from all base model optimisation such as quantisation (available in this repo), pruning, caching, etc.
|
4 |
-
|
5 |
-
|
6 |
-
## How to run?
|
7 |
-
|
8 |
-
After you have finish finetuning using LoRA, select your weight and run the converter script:
|
9 |
-
|
10 |
-
```bash
|
11 |
-
python scripts/convert_lora_weights.py --lora_path out/lora/your-folder/your-weight-name.pth
|
12 |
-
```
|
13 |
-
|
14 |
-
The converted base weight file will be saved into the same folder with the name `{your-weight-name}-lora-merged-weights.pth`. Now you can run `generate.py` with the merged weights and apply quantisation:
|
15 |
-
|
16 |
-
```bash
|
17 |
-
python generate.py --checkpoint_path out/lora/your-folder/your-weight-name-lora-merged-weights.pth --quantize llm.int8
|
18 |
-
```
|
19 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lit-llama/howto/customize_paths.md
DELETED
@@ -1,33 +0,0 @@
|
|
1 |
-
## Customize paths
|
2 |
-
|
3 |
-
The project is setup to use specific paths to read the original weights and save checkpoints etc.
|
4 |
-
|
5 |
-
For all scripts, you can run
|
6 |
-
|
7 |
-
```shell
|
8 |
-
python script.py -h
|
9 |
-
```
|
10 |
-
|
11 |
-
to get a list of available options. For instance, here's how you would modify the checkpoint dir:
|
12 |
-
|
13 |
-
```shell
|
14 |
-
python scripts/convert_checkpoint.py --checkpoint_dir "data/checkpoints/foo"
|
15 |
-
```
|
16 |
-
|
17 |
-
Note that this change will need to be passed along to subsequent steps, for example:
|
18 |
-
|
19 |
-
```shell
|
20 |
-
python generate.py \
|
21 |
-
--checkpoint_path "data/checkpoints/foo/7B/lit-llama.pth" \
|
22 |
-
--tokenizer_path "data/checkpoints/foo/tokenizer.model"
|
23 |
-
```
|
24 |
-
|
25 |
-
and
|
26 |
-
|
27 |
-
```shell
|
28 |
-
python quantize/gptq.py \
|
29 |
-
--checkpoint_path "data/checkpoints/foo/7B/lit-llama.pth" \
|
30 |
-
--tokenizer_path "data/checkpoints/foo/tokenizer.model"
|
31 |
-
```
|
32 |
-
|
33 |
-
To avoid this, you can use symbolic links to create shortcuts and avoid passing different paths.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lit-llama/howto/download_weights.md
DELETED
@@ -1,130 +0,0 @@
|
|
1 |
-
## Downloading pretrained weights
|
2 |
-
|
3 |
-
Except for when you are training from scratch, you will need the pretrained weights from Meta.
|
4 |
-
|
5 |
-
### Original Meta weights
|
6 |
-
|
7 |
-
Download the model weights following the instructions on the official [LLaMA repository](https://github.com/facebookresearch/llama).
|
8 |
-
|
9 |
-
Once downloaded, you should have a folder like this:
|
10 |
-
|
11 |
-
```text
|
12 |
-
checkpoints/llama
|
13 |
-
├── 7B
|
14 |
-
│ ├── ...
|
15 |
-
│ └── consolidated.00.pth
|
16 |
-
├── 13B
|
17 |
-
│ ...
|
18 |
-
└── tokenizer.model
|
19 |
-
```
|
20 |
-
|
21 |
-
Convert the weights to the Lit-LLaMA format:
|
22 |
-
|
23 |
-
```bash
|
24 |
-
python scripts/convert_checkpoint.py --model_size 7B
|
25 |
-
```
|
26 |
-
|
27 |
-
> **Note**
|
28 |
-
> All scripts support argument [customization](customize_paths.md)
|
29 |
-
|
30 |
-
### OpenLLaMA
|
31 |
-
|
32 |
-
OpenLM Research has released **Apache 2.0 licensed** weights obtained by training LLaMA on the 1.2 trillion token open-source [RedPajama](https://github.com/togethercomputer/RedPajama-Data) dataset.
|
33 |
-
|
34 |
-
Weights were released in preview on intermediate number of tokens (1T at the time of writing). In order to get them do:
|
35 |
-
|
36 |
-
```bash
|
37 |
-
# Make sure you have git-lfs installed (https://git-lfs.com): git lfs install
|
38 |
-
git clone https://huggingface.co/openlm-research/open_llama_7b checkpoints/open-llama/7B
|
39 |
-
```
|
40 |
-
|
41 |
-
Or if you don't have `git-lfs` installed:
|
42 |
-
|
43 |
-
```bash
|
44 |
-
python scripts/download.py --repo_id openlm-research/open_llama_7b --local_dir checkpoints/open-llama/7B
|
45 |
-
```
|
46 |
-
|
47 |
-
Once downloaded, you should have a folder like this:
|
48 |
-
|
49 |
-
```text
|
50 |
-
checkpoints/open-llama/
|
51 |
-
└── 7B
|
52 |
-
├── ...
|
53 |
-
├── pytorch_model-00001-of-00002.bin
|
54 |
-
├── pytorch_model-00002-of-00002.bin
|
55 |
-
├── pytorch_model.bin.index.json
|
56 |
-
└── tokenizer.model
|
57 |
-
```
|
58 |
-
|
59 |
-
Convert the weights to the Lit-LLaMA format:
|
60 |
-
|
61 |
-
```bash
|
62 |
-
python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/open-llama/7B --model_size 7B
|
63 |
-
```
|
64 |
-
|
65 |
-
> **Note**
|
66 |
-
> All scripts support argument [customization](customize_paths.md)
|
67 |
-
|
68 |
-
Once converted, you should have a folder like this:
|
69 |
-
|
70 |
-
```text
|
71 |
-
checkpoints/lit-llama/
|
72 |
-
├── 7B
|
73 |
-
│ └── lit-llama.pth
|
74 |
-
└── tokenizer.model
|
75 |
-
```
|
76 |
-
|
77 |
-
You are all set. Now you can continue with inference or finetuning.
|
78 |
-
|
79 |
-
Try running [`generate.py` to test the imported weights](inference.md).
|
80 |
-
|
81 |
-
|
82 |
-
### Alternative sources
|
83 |
-
|
84 |
-
You might find LLaMA weights hosted online in the HuggingFace hub. Beware that this infringes the original weight's license.
|
85 |
-
You could try downloading them by running the following command with a specific repo id:
|
86 |
-
|
87 |
-
```bash
|
88 |
-
# Make sure you have git-lfs installed (https://git-lfs.com): git lfs install
|
89 |
-
git clone REPO_ID checkpoints/hf-llama/7B
|
90 |
-
```
|
91 |
-
|
92 |
-
Or if you don't have `git-lfs` installed:
|
93 |
-
|
94 |
-
```bash
|
95 |
-
python scripts/download.py --repo_id REPO_ID --local_dir checkpoints/hf-llama/7B
|
96 |
-
```
|
97 |
-
|
98 |
-
Once downloaded, you should have a folder like this:
|
99 |
-
|
100 |
-
```text
|
101 |
-
checkpoints/hf-llama/
|
102 |
-
└── 7B
|
103 |
-
├── ...
|
104 |
-
├── pytorch_model-00001-of-00002.bin
|
105 |
-
├── pytorch_model-00002-of-00002.bin
|
106 |
-
├── pytorch_model.bin.index.json
|
107 |
-
└── tokenizer.model
|
108 |
-
```
|
109 |
-
|
110 |
-
Convert the weights to the Lit-LLaMA format:
|
111 |
-
|
112 |
-
```bash
|
113 |
-
python scripts/convert_hf_checkpoint.py --model_size 7B
|
114 |
-
```
|
115 |
-
|
116 |
-
> **Note**
|
117 |
-
> All scripts support argument [customization](customize_paths.md)
|
118 |
-
|
119 |
-
Once converted, you should have a folder like this:
|
120 |
-
|
121 |
-
```text
|
122 |
-
checkpoints/lit-llama/
|
123 |
-
├── 7B
|
124 |
-
│ └── lit-llama.pth
|
125 |
-
└── tokenizer.model
|
126 |
-
```
|
127 |
-
|
128 |
-
You are all set. Now you can continue with inference or finetuning.
|
129 |
-
|
130 |
-
Try running [`generate.py` to test the imported weights](inference.md).
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lit-llama/howto/finetune_adapter.md
DELETED
@@ -1,109 +0,0 @@
|
|
1 |
-
# Finetuning with Adapter
|
2 |
-
|
3 |
-
[LLaMA-Adapter](https://arxiv.org/abs/2303.16199) is a form of prefix-tuning that prepends a learnable adaption-prompt to the inputs of the attention blocks in LLaMA. In total, there are only 1.2M parameters to update during finetuning, which significantly reduces the memory footprint and speeds up training.
|
4 |
-
|
5 |
-
We are able to demonstrate instruction-finetuning Lit-LLaMA 7B on the [Alpaca](https://github.com/tatsu-lab/stanford_alpaca) dataset on a **single RTX 3090 (24GB) GPU**. If using 8 GPUs, finetuning can be completed in under 1 hour.
|
6 |
-
|
7 |
-
If you are new to LLaMA-Adapter and are interested to learn more about how it works before proceeding with the finetuning guide below, you might find our article [Understanding Parameter-Efficient Finetuning of Large Language Models: From Prefix Tuning to LLaMA-Adapters](https://lightning.ai/pages/community/article/understanding-llama-adapters/) helpful.
|
8 |
-
|
9 |
-
## LLaMA-Adapter v2
|
10 |
-
|
11 |
-
The LLaMA-Adapter authors developed a newer adapter method called LLaMA-Adapter v2, which is related to this LLaMA-Adapter method but includes more trainable parameters. LLaMA-Adapter v2 is also available via Lit-LLaMA; you can read more about it in [the related how-to doc here](./finetune_adapter_v2.md).
|
12 |
-
|
13 |
-
## Preparation
|
14 |
-
|
15 |
-
The steps here only need to be done once:
|
16 |
-
|
17 |
-
1. Follow the instructions in the [README](README.md) to install the dependencies.
|
18 |
-
2. Download and convert the weights and save them in the `./checkpoints` folder as described [here](download_weights.md).
|
19 |
-
3. If you want to utilize more than one GPU, you should `pip install deepspeed`.
|
20 |
-
4. Download the data and generate the Alpaca instruction tuning dataset:
|
21 |
-
|
22 |
-
```bash
|
23 |
-
python scripts/prepare_alpaca.py
|
24 |
-
```
|
25 |
-
|
26 |
-
or [prepare your own dataset](#tune-on-your-dataset).
|
27 |
-
|
28 |
-
See also: [Finetuning on an unstructured dataset](unstructured_dataset.md)
|
29 |
-
|
30 |
-
## Running the finetuning
|
31 |
-
|
32 |
-
```bash
|
33 |
-
python finetune/adapter.py
|
34 |
-
```
|
35 |
-
|
36 |
-
The finetuning requires at least one GPU with ~24 GB memory (RTX 3090).
|
37 |
-
You can speed up training by setting the `devices` variable in the script to utilize more GPUs if available.
|
38 |
-
Depending on the available GPU memory, you can also tune the `micro_batch_size` parameter to utilize the GPU efficiently.
|
39 |
-
|
40 |
-
For example, the following settings will let you finetune the model in under 1 hour using DeepSpeed Zero-2:
|
41 |
-
|
42 |
-
```python
|
43 |
-
devices = 8
|
44 |
-
micro_batch_size = 8
|
45 |
-
```
|
46 |
-
|
47 |
-
This script will save checkpoints periodically to the folder `out/`.
|
48 |
-
|
49 |
-
> **Note**
|
50 |
-
> All scripts support argument [customization](customize_paths.md)
|
51 |
-
|
52 |
-
## Test the model
|
53 |
-
|
54 |
-
You can test the finetuned model with your own instructions by running:
|
55 |
-
|
56 |
-
```bash
|
57 |
-
python generate/adapter.py \
|
58 |
-
--prompt "Recommend a movie to watch on the weekend." \
|
59 |
-
--quantize llm.int8
|
60 |
-
```
|
61 |
-
Output:
|
62 |
-
```
|
63 |
-
A good movie to watch on the weekend would be The Lion King, since it's a classic family film that everyone can enjoy...
|
64 |
-
```
|
65 |
-
If your GPU supports `bfloat16`, the script will automatically use it. Together with `--quantize llm.int8`, this brings the memory consumption down to ~8 GB.
|
66 |
-
|
67 |
-
## Tune on your dataset
|
68 |
-
|
69 |
-
With only a few modifications, you can prepare and train on your own instruction dataset.
|
70 |
-
|
71 |
-
1. Create a json file in which each row holds one instruction-response pair.
|
72 |
-
A row has an entry for 'instruction', 'input', and 'output', where 'input' is optional an can be
|
73 |
-
the empty string if the instruction doesn't require a context. Below is an example json file:
|
74 |
-
|
75 |
-
```
|
76 |
-
[
|
77 |
-
{
|
78 |
-
"instruction": "Arrange the given numbers in ascending order.",
|
79 |
-
"input": "2, 4, 0, 8, 3",
|
80 |
-
"output": "0, 2, 3, 4, 8"
|
81 |
-
},
|
82 |
-
...
|
83 |
-
]
|
84 |
-
```
|
85 |
-
|
86 |
-
2. Make a copy of `scripts/prepare_alpaca.py` and name it what you want:
|
87 |
-
|
88 |
-
```bash
|
89 |
-
cp scripts/prepare_alpaca.py scripts/prepare_mydata.py
|
90 |
-
```
|
91 |
-
|
92 |
-
3. Modify `scripts/prepare_mydata.py` to read the json data file.
|
93 |
-
4. Run the script to generate the preprocessed, tokenized train-val split:
|
94 |
-
|
95 |
-
```bash
|
96 |
-
python scripts/prepare_mydata.py --destination_path data/mydata/
|
97 |
-
```
|
98 |
-
|
99 |
-
5. Run `finetune/adapter.py` by passing in the location of your data (and optionally other parameters):
|
100 |
-
|
101 |
-
```bash
|
102 |
-
python finetune/adapter.py --data_dir data/mydata/ --out_dir out/myexperiment
|
103 |
-
```
|
104 |
-
|
105 |
-
|
106 |
-
## Troubleshooting
|
107 |
-
|
108 |
-
If you run into a CUDA error "Expected is_sm80 to be true, but got false", uncomment the line
|
109 |
-
`torch.backends.cuda.enable_flash_sdp(False)` in the script below (see https://github.com/Lightning-AI/lit-llama/issues/101).
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lit-llama/howto/finetune_adapter_v2.md
DELETED
@@ -1,114 +0,0 @@
|
|
1 |
-
# Finetuning with Adapter v2
|
2 |
-
|
3 |
-
[LLaMA-Adapter v2](https://arxiv.org/abs/2304.15010) is a form of prefix-tuning that prepends a learnable adaption-prompt to the inputs of the attention blocks in LLaMA. In total, there are only ~4 M parameters to update during finetuning, which significantly reduces the memory footprint and speeds up training.
|
4 |
-
|
5 |
-
We are able to demonstrate instruction-finetuning Lit-LLaMA 7B on the [Alpaca](https://github.com/tatsu-lab/stanford_alpaca) dataset on a **single RTX 3090 (24GB) GPU**. If using 8 GPUs, finetuning can be completed in under 1 hour.
|
6 |
-
|
7 |
-
If you are new to LLaMA-Adapter and are interested to learn more about how it works before proceeding with the finetuning guide below, you might find our article [Understanding Parameter-Efficient Finetuning of Large Language Models: From Prefix Tuning to LLaMA-Adapters](https://lightning.ai/pages/community/article/understanding-llama-adapters/) helpful.
|
8 |
-
|
9 |
-
## LLaMA-Adapter v1 versus LLaMA-Adapter v2
|
10 |
-
|
11 |
-
LLaMA-Adapter v2 extends the original LLaMA-Adapter idea by adding trainable bias and scale parameters to each linear layer in the transformer. Furthermore, LLaMA-Adapter v2 makes the normalization layers trainable. Where the 7B LLaMA model has 1.2M trainable parameters with LLaMA v1, LLaMA-Adapter v2 adds 2.8 M trainable parameters for the bias and scale parameters and ~300k trainable parameters for the normalization layers. So, adapter v2 has ~4.3 M trainable parameters in total.
|
12 |
-
|
13 |
-
If you are interested in using the more lightweight LLaMA-Adapter v1 approach, see [the related LLaMA Adapter how-to doc here](./finetune_adapter.md).
|
14 |
-
|
15 |
-
While LLaMA-Adapter v2 increases the number of trainable parameters from 1.2 M (from LLaMA-Apdapter v1) to 4.3 M, the inference cost is not significantly impacted. This is because the additional bias and scale parameters are cheap to compute in the forward pass, and the RMSNorm parameters are already included in the base model. In LLaMA-Adapter v1, the RMSNorm parameters are not trainable.
|
16 |
-
|
17 |
-
|
18 |
-
## Preparation
|
19 |
-
|
20 |
-
The steps here only need to be done once:
|
21 |
-
|
22 |
-
1. Follow the instructions in the [README](README.md) to install the dependencies.
|
23 |
-
2. Download and convert the weights and save them in the `./checkpoints` folder as described [here](download_weights.md).
|
24 |
-
3. If you want to utilize more than one GPU, you should `pip install deepspeed`.
|
25 |
-
4. Download the data and generate the Alpaca instruction tuning dataset:
|
26 |
-
|
27 |
-
```bash
|
28 |
-
python scripts/prepare_alpaca.py
|
29 |
-
```
|
30 |
-
|
31 |
-
or [prepare your own dataset](#tune-on-your-dataset).
|
32 |
-
|
33 |
-
See also: [Finetuning on an unstructured dataset](unstructured_dataset.md)
|
34 |
-
|
35 |
-
## Running the finetuning
|
36 |
-
|
37 |
-
```bash
|
38 |
-
python finetune/adapter_v2.py
|
39 |
-
```
|
40 |
-
|
41 |
-
The finetuning requires at least one GPU with ~24 GB memory (RTX 3090).
|
42 |
-
You can speed up training by setting the `devices` variable in the script to utilize more GPUs if available.
|
43 |
-
Depending on the available GPU memory, you can also tune the `micro_batch_size` parameter to utilize the GPU efficiently.
|
44 |
-
|
45 |
-
For example, the following settings will let you finetune the model in under 1 hour using DeepSpeed Zero-2:
|
46 |
-
|
47 |
-
```python
|
48 |
-
devices = 8
|
49 |
-
micro_batch_size = 8
|
50 |
-
```
|
51 |
-
|
52 |
-
This script will save checkpoints periodically to the folder `out/`.
|
53 |
-
|
54 |
-
> **Note**
|
55 |
-
> All scripts support argument [customization](customize_paths.md)
|
56 |
-
|
57 |
-
## Test the model
|
58 |
-
|
59 |
-
You can test the finetuned model with your own instructions by running:
|
60 |
-
|
61 |
-
```bash
|
62 |
-
python generate/adapter_v2.py \
|
63 |
-
--prompt "Recommend a movie to watch on the weekend." \
|
64 |
-
--quantize llm.int8
|
65 |
-
```
|
66 |
-
Output:
|
67 |
-
```
|
68 |
-
A good movie to watch on the weekend would be The Lion King, since it's a classic family film that everyone can enjoy...
|
69 |
-
```
|
70 |
-
If your GPU supports `bfloat16`, the script will automatically use it. Together with `--quantize llm.int8`, this brings the memory consumption down to ~8 GB.
|
71 |
-
|
72 |
-
## Tune on your dataset
|
73 |
-
|
74 |
-
With only a few modifications, you can prepare and train on your own instruction dataset.
|
75 |
-
|
76 |
-
1. Create a json file in which each row holds one instruction-response pair.
|
77 |
-
A row has an entry for 'instruction', 'input', and 'output', where 'input' is optional an can be
|
78 |
-
the empty string if the instruction doesn't require a context. Below is an example json file:
|
79 |
-
|
80 |
-
```
|
81 |
-
[
|
82 |
-
{
|
83 |
-
"instruction": "Arrange the given numbers in ascending order.",
|
84 |
-
"input": "2, 4, 0, 8, 3",
|
85 |
-
"output": "0, 2, 3, 4, 8"
|
86 |
-
},
|
87 |
-
...
|
88 |
-
]
|
89 |
-
```
|
90 |
-
|
91 |
-
2. Make a copy of `scripts/prepare_alpaca.py` and name it what you want:
|
92 |
-
|
93 |
-
```bash
|
94 |
-
cp scripts/prepare_alpaca.py scripts/prepare_mydata.py
|
95 |
-
```
|
96 |
-
|
97 |
-
3. Modify `scripts/prepare_mydata.py` to read the json data file.
|
98 |
-
4. Run the script to generate the preprocessed, tokenized train-val split:
|
99 |
-
|
100 |
-
```bash
|
101 |
-
python scripts/prepare_mydata.py --destination_path data/mydata/
|
102 |
-
```
|
103 |
-
|
104 |
-
5. Run `finetune/adapter_v2.py` by passing in the location of your data (and optionally other parameters):
|
105 |
-
|
106 |
-
```bash
|
107 |
-
python finetune/adapter_v2.py --data_dir data/mydata/ --out_dir out/myexperiment
|
108 |
-
```
|
109 |
-
|
110 |
-
|
111 |
-
## Troubleshooting
|
112 |
-
|
113 |
-
If you run into a CUDA error "Expected is_sm80 to be true, but got false", uncomment the line
|
114 |
-
`torch.backends.cuda.enable_flash_sdp(False)` in the script below (see https://github.com/Lightning-AI/lit-llama/issues/101).
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lit-llama/howto/finetune_full.md
DELETED
@@ -1,106 +0,0 @@
|
|
1 |
-
# Full Finetuning
|
2 |
-
|
3 |
-
Full finetuning updates all layers in the pretrained LLaMA model. This *regular* finetuning procedure is typically considered as the baseline for parameter-efficient alternatives such as Low-Rank Adaptation (LoRA) or LLaMA-Adapter.
|
4 |
-
|
5 |
-
The current [finetune/full.py](../finetune/full.py) we provide uses 4 A100 GPUs with a fully-sharded data parallel strategy to finetune Lit-LLaMA 7B on [Alpaca](https://github.com/tatsu-lab/stanford_alpaca) dataset. The A100 GPUs have 40 GB each, but it may require less memory to finetune this model.
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
## Preparation
|
10 |
-
|
11 |
-
The steps here only need to be done once:
|
12 |
-
|
13 |
-
1. Follow the instructions in the [README](README.md) to install the dependencies.
|
14 |
-
|
15 |
-
2. Download and convert the weights and save them in the `./checkpoints` folder as described [here](download_weights.md).
|
16 |
-
|
17 |
-
4. Download the data and generate the Alpaca instruction tuning dataset:
|
18 |
-
|
19 |
-
```bash
|
20 |
-
python scripts/prepare_alpaca.py
|
21 |
-
```
|
22 |
-
|
23 |
-
or [prepare your own dataset](#tune-on-your-own-dataset).
|
24 |
-
|
25 |
-
See also: [Finetuning on an unstructured dataset](unstructured_dataset.md)
|
26 |
-
|
27 |
-
## Running the finetuning
|
28 |
-
|
29 |
-
```bash
|
30 |
-
python finetune/full.py
|
31 |
-
```
|
32 |
-
|
33 |
-
|
34 |
-
You can speed up training by setting the `devices` variable in the script to utilize more GPUs if available or increase the `batch_size`.
|
35 |
-
Depending on the available GPU memory, you can also tune the `micro_batch_size` parameter to utilize the GPU efficiently.
|
36 |
-
|
37 |
-
For example, the following settings will let you finetune the model in 32 hours using a fully-sharded data parallel strategy:
|
38 |
-
```python
|
39 |
-
devices = 4
|
40 |
-
batch_size = 128 // devices
|
41 |
-
micro_batch_size = 4
|
42 |
-
```
|
43 |
-
|
44 |
-
This script will save checkpoints periodically to the folder `out/`.
|
45 |
-
|
46 |
-
> **Note**
|
47 |
-
> All scripts support argument [customization](customize_paths.md)
|
48 |
-
|
49 |
-
## Test the model
|
50 |
-
|
51 |
-
You can test the finetuned model with your own instructions by running:
|
52 |
-
|
53 |
-
```bash
|
54 |
-
python generate/full.py \
|
55 |
-
--prompt "Recommend a movie to watch on the weekend." \
|
56 |
-
--quantize llm.int8
|
57 |
-
```
|
58 |
-
Output:
|
59 |
-
```
|
60 |
-
A good movie to watch on the weekend would be The Lion King, since it's a classic family film that everyone can enjoy...
|
61 |
-
```
|
62 |
-
If your GPU supports `bfloat16`, the script will automatically use it. Together with `--quantize llm.int8`, this brings the memory consumption down to ~8 GB.
|
63 |
-
|
64 |
-
## Tune on your dataset
|
65 |
-
|
66 |
-
With only a few modifications, you can prepare and train on your own instruction dataset.
|
67 |
-
|
68 |
-
1. Create a json file in which each row holds one instruction-response pair.
|
69 |
-
A row has an entry for 'instruction', 'input', and 'output', where 'input' is optional an can be
|
70 |
-
the empty string if the instruction doesn't require a context. Below is an example json file:
|
71 |
-
|
72 |
-
```
|
73 |
-
[
|
74 |
-
{
|
75 |
-
"instruction": "Arrange the given numbers in ascending order.",
|
76 |
-
"input": "2, 4, 0, 8, 3",
|
77 |
-
"output": "0, 2, 3, 4, 8"
|
78 |
-
},
|
79 |
-
...
|
80 |
-
]
|
81 |
-
```
|
82 |
-
|
83 |
-
2. Make a copy of `scripts/prepare_alpaca.py` and name it what you want:
|
84 |
-
|
85 |
-
```bash
|
86 |
-
cp scripts/prepare_alpaca.py scripts/prepare_mydata.py
|
87 |
-
```
|
88 |
-
|
89 |
-
3. Modify `scripts/prepare_mydata.py` to read the json data file.
|
90 |
-
4. Run the script to generate the preprocessed, tokenized train-val split:
|
91 |
-
|
92 |
-
```bash
|
93 |
-
python scripts/prepare_mydata.py --destination_path data/mydata/
|
94 |
-
```
|
95 |
-
|
96 |
-
5. Run `finetune/full.py` by passing in the location of your data (and optionally other parameters):
|
97 |
-
|
98 |
-
```bash
|
99 |
-
python finetune/full.py --data_dir data/mydata/ --out_dir out/myexperiment
|
100 |
-
```
|
101 |
-
|
102 |
-
|
103 |
-
## Troubleshooting
|
104 |
-
|
105 |
-
If you run into a CUDA error "Expected is_sm80 to be true, but got false", uncomment the line
|
106 |
-
`torch.backends.cuda.enable_flash_sdp(False)` in the script below (see https://github.com/Lightning-AI/lit-llama/issues/101).
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lit-llama/howto/finetune_lora.md
DELETED
@@ -1,90 +0,0 @@
|
|
1 |
-
# Finetuning with LoRA
|
2 |
-
|
3 |
-
[Low-rank adaption (LoRA)](https://arxiv.org/abs/2106.09685) is a technique to approximate the update to the linear layers in a LLM with a low-rank matrix factorization. This significantly reduces the number of trainable parameters and speeds up training with little impact on the final performance of the model.
|
4 |
-
We demonstrate this method by instruction-finetuning LLaMA 7B on the [Alpaca](https://github.com/tatsu-lab/stanford_alpaca) dataset on a **single RTX 3090 (24GB) GPU**.
|
5 |
-
|
6 |
-
## Preparation
|
7 |
-
|
8 |
-
The steps here only need to be done once:
|
9 |
-
|
10 |
-
1. Follow the instructions in the [README](../README.md) to install the dependencies.
|
11 |
-
2. Download and convert the weights and save them in the `./checkpoints` folder as described [here](download_weights.md).
|
12 |
-
3. Download the data and generate the instruction tuning dataset:
|
13 |
-
|
14 |
-
```bash
|
15 |
-
python scripts/prepare_alpaca.py
|
16 |
-
```
|
17 |
-
|
18 |
-
See also: [Finetuning on an unstructured dataset](unstructured_dataset.md)
|
19 |
-
|
20 |
-
## Running the finetuning
|
21 |
-
|
22 |
-
```bash
|
23 |
-
python finetune/lora.py
|
24 |
-
```
|
25 |
-
|
26 |
-
The finetuning requires at least one GPU with ~24 GB memory (RTX 3090).
|
27 |
-
|
28 |
-
This script will save checkpoints periodically to the folder `out/`.
|
29 |
-
|
30 |
-
> **Note**
|
31 |
-
> All scripts support argument [customization](customize_paths.md)
|
32 |
-
|
33 |
-
|
34 |
-
## Test the model
|
35 |
-
|
36 |
-
You can test the finetuned model with your own instructions by running:
|
37 |
-
|
38 |
-
```bash
|
39 |
-
python generate/lora.py --prompt "Recommend a movie to watch on the weekend."
|
40 |
-
```
|
41 |
-
Output:
|
42 |
-
```
|
43 |
-
I would recommend the movie The Martian (2015). It is a sci-fi movie starring Matt Damon that follows the story of...
|
44 |
-
```
|
45 |
-
|
46 |
-
If your GPU supports `bfloat16`, you can additionally pass `--dtype bfloat16` to bring the memory consumption down to ~14 GB.
|
47 |
-
|
48 |
-
## Tune on your dataset
|
49 |
-
|
50 |
-
With only a few modifications, you can prepare and train on your own instruction dataset.
|
51 |
-
|
52 |
-
1. Create a json file in which each row holds one instruction-response pair.
|
53 |
-
A row has an entry for 'instruction', 'input', and 'output', where 'input' is optional an can be
|
54 |
-
the empty string if the instruction doesn't require a context. Below is an example json file:
|
55 |
-
|
56 |
-
```
|
57 |
-
[
|
58 |
-
{
|
59 |
-
"instruction": "Arrange the given numbers in ascending order.",
|
60 |
-
"input": "2, 4, 0, 8, 3",
|
61 |
-
"output": "0, 2, 3, 4, 8"
|
62 |
-
},
|
63 |
-
...
|
64 |
-
]
|
65 |
-
```
|
66 |
-
|
67 |
-
2. Make a copy of `scripts/prepare_alpaca.py` and name it what you want:
|
68 |
-
|
69 |
-
```bash
|
70 |
-
cp scripts/prepare_alpaca.py scripts/prepare_mydata.py
|
71 |
-
```
|
72 |
-
|
73 |
-
3. Modify `scripts/prepare_mydata.py` to read the json data file.
|
74 |
-
4. Run the script to generate the preprocessed, tokenized train-val split:
|
75 |
-
|
76 |
-
```bash
|
77 |
-
python scripts/prepare_mydata.py --destination_path data/mydata/
|
78 |
-
```
|
79 |
-
|
80 |
-
5. Run `finetune/lora.py` by passing in the location of your data (and optionally other parameters):
|
81 |
-
|
82 |
-
```bash
|
83 |
-
python finetune/lora.py --data_dir data/mydata/ --out_dir out/myexperiment
|
84 |
-
```
|
85 |
-
|
86 |
-
|
87 |
-
## Troubleshooting
|
88 |
-
|
89 |
-
If you run into a CUDA error "Expected is_sm80 to be true, but got false", uncomment the line
|
90 |
-
`torch.backends.cuda.enable_flash_sdp(False)` in the script below (see https://github.com/Lightning-AI/lit-llama/issues/101).
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lit-llama/howto/inference.md
DELETED
@@ -1,43 +0,0 @@
|
|
1 |
-
# Inference
|
2 |
-
|
3 |
-
We demonstrate how to run inference (next token prediction) with the LLaMA base model in the [`generate.py`](generate.py) script:
|
4 |
-
|
5 |
-
```bash
|
6 |
-
python generate.py --prompt "Hello, my name is"
|
7 |
-
```
|
8 |
-
Output:
|
9 |
-
```
|
10 |
-
Hello my name is TJ. I have a passion for the outdoors, love hiking and exploring. I also enjoy traveling and learning new things. I especially enjoy long walks, good conversation and a friendly smile.
|
11 |
-
```
|
12 |
-
|
13 |
-
The script assumes you have downloaded and converted the weights and saved them in the `./checkpoints` folder as described [here](download_weights.md).
|
14 |
-
|
15 |
-
> **Note**
|
16 |
-
> All scripts support argument [customization](customize_paths.md)
|
17 |
-
|
18 |
-
With the default settings, this will run the 7B model and require ~26 GB of GPU memory (A100 GPU).
|
19 |
-
|
20 |
-
## Run Lit-LLaMA on consumer devices
|
21 |
-
|
22 |
-
On GPUs with `bfloat16` support, the `generate.py` script will automatically convert the weights and consume about ~14 GB.
|
23 |
-
For GPUs with less memory, or ones that don't support `bfloat16`, enable quantization (`--quantize llm.int8`):
|
24 |
-
|
25 |
-
```bash
|
26 |
-
python generate.py --quantize llm.int8 --prompt "Hello, my name is"
|
27 |
-
```
|
28 |
-
This will consume about ~10 GB of GPU memory or ~8 GB if also using `bfloat16`.
|
29 |
-
See `python generate.py --help` for more options.
|
30 |
-
|
31 |
-
You can also use GPTQ-style int4 quantization, but this needs conversions of the weights first:
|
32 |
-
|
33 |
-
```bash
|
34 |
-
python quantize/gptq.py --output_path checkpoints/lit-llama/7B/llama-gptq.4bit.pth --dtype bfloat16 --quantize gptq.int4
|
35 |
-
```
|
36 |
-
|
37 |
-
GPTQ-style int4 quantization brings GPU usage down to about ~5GB. As only the weights of the Linear layers are quantized, it is useful to also use `--dtype bfloat16` even with the quantization enabled.
|
38 |
-
|
39 |
-
With the generated quantized checkpoint generation quantization then works as usual with `--quantize gptq.int4` and the newly generated checkpoint file:
|
40 |
-
|
41 |
-
```bash
|
42 |
-
python generate.py --quantize gptq.int4 --checkpoint_path checkpoints/lit-llama/7B/llama-gptq.4bit.pth
|
43 |
-
```
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lit-llama/howto/tpus.md
DELETED
@@ -1,51 +0,0 @@
|
|
1 |
-
# TPU support
|
2 |
-
|
3 |
-
Lit-LLaMA used `lightning.Fabric` under the hood, which itself supports TPUs (via [PyTorch XLA](https://github.com/pytorch/xla)).
|
4 |
-
|
5 |
-
The following commands will allow you to set up a `Google Cloud` instance with a [TPU v4](https://cloud.google.com/tpu/docs/system-architecture-tpu-vm) VM:
|
6 |
-
|
7 |
-
```shell
|
8 |
-
gcloud compute tpus tpu-vm create lit-llama --version=tpu-vm-v4-pt-2.0 --accelerator-type=v4-8 --zone=us-central2-b
|
9 |
-
gcloud compute tpus tpu-vm ssh lit-llama --zone=us-central2-b
|
10 |
-
```
|
11 |
-
|
12 |
-
Now that you are in the machine, let's clone the repository and install the dependencies
|
13 |
-
|
14 |
-
```shell
|
15 |
-
git clone https://github.com/Lightning-AI/lit-llama
|
16 |
-
cd lit-llama
|
17 |
-
pip install -r requirements.txt
|
18 |
-
```
|
19 |
-
|
20 |
-
By default, computations will run using the new (and experimental) PjRT runtime. Still, it's recommended that you set the following environment variables
|
21 |
-
|
22 |
-
```shell
|
23 |
-
export PJRT_DEVICE=TPU
|
24 |
-
export ALLOW_MULTIPLE_LIBTPU_LOAD=1
|
25 |
-
```
|
26 |
-
|
27 |
-
> **Note**
|
28 |
-
> You can find an extensive guide on how to get set-up and all the available options [here](https://cloud.google.com/tpu/docs/v4-users-guide).
|
29 |
-
|
30 |
-
Since you created a new machine, you'll probably need to download the weights. You could scp them into the machine with `gcloud compute tpus tpu-vm scp` or you can follow the steps described in our [downloading guide](download_weights.md).
|
31 |
-
|
32 |
-
## Inference
|
33 |
-
|
34 |
-
Generation works out-of-the-box with TPUs:
|
35 |
-
|
36 |
-
```shell
|
37 |
-
python3 generate.py --prompt "Hello, my name is" --num_samples 3
|
38 |
-
```
|
39 |
-
|
40 |
-
This command will take take ~20s for the first generation time as XLA needs to compile the graph.
|
41 |
-
You'll notice that afterwards, generation times drop to ~5s.
|
42 |
-
|
43 |
-
## Finetuning
|
44 |
-
|
45 |
-
Coming soon.
|
46 |
-
|
47 |
-
> **Warning**
|
48 |
-
> When you are done, remember to delete your instance
|
49 |
-
> ```shell
|
50 |
-
> gcloud compute tpus tpu-vm delete lit-llama --zone=us-central2-b
|
51 |
-
> ```
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lit-llama/howto/train_redpajama.md
DELETED
@@ -1,133 +0,0 @@
|
|
1 |
-
# Pre-train LLaMA on RedPajama
|
2 |
-
|
3 |
-
This howto will walk you through setting up the RedPajama dataset and launching the pre-training script.
|
4 |
-
|
5 |
-
## What's RedPajama
|
6 |
-
|
7 |
-
[RedPajama](https://github.com/togethercomputer/RedPajama-Data) is an open-source reproduction of the original LLaMA training dataset.
|
8 |
-
|
9 |
-
It contains a total of 1.2 trillion tokens, divided into
|
10 |
-
|
11 |
-
```text
|
12 |
-
Commoncrawl 878B
|
13 |
-
C4 175B
|
14 |
-
GitHub 59B
|
15 |
-
Books 26B
|
16 |
-
ArXiv 28B
|
17 |
-
Wikipedia 24B
|
18 |
-
StackExchange 20B
|
19 |
-
```
|
20 |
-
|
21 |
-
The [RedPajama repo](https://github.com/togethercomputer/RedPajama-Data) contains the source code for collecting and preparing
|
22 |
-
the dataset, and it is Apache 2.0 licensed.
|
23 |
-
|
24 |
-
The data itself is licensed according to the original licenses with which its invidivdual parts were released.
|
25 |
-
The GitHub datasets are limited to MIT, BSD, or Apache 2.0 repositories.
|
26 |
-
|
27 |
-
Along with the full [RedPajama-1T dataset](https://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T),
|
28 |
-
the [RedPajama-1T-Sample](https://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T-Sample) 1B sample dataset
|
29 |
-
is also available for development.
|
30 |
-
|
31 |
-
You can download the data using git lfs:
|
32 |
-
|
33 |
-
```bash
|
34 |
-
# Make sure you have git-lfs installed (https://git-lfs.com): git lfs install
|
35 |
-
git clone https://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T data/RedPajama-Data-1T
|
36 |
-
```
|
37 |
-
|
38 |
-
```bash
|
39 |
-
# Make sure you have git-lfs installed (https://git-lfs.com): git lfs install
|
40 |
-
git clone https://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T-Sample data/RedPajama-Data-1T-Sample
|
41 |
-
```
|
42 |
-
|
43 |
-
## Prepare RedPajama for training
|
44 |
-
|
45 |
-
The dataset consists of 2084 `jsonl` files (the sample dataset contains 11). In order to start pre-training lit-llama
|
46 |
-
on it, you need to read, tokenize, and write the data in binary chunks. This will leverage the `PackedDataset`
|
47 |
-
streaming dataset that comes with lit-llama.
|
48 |
-
|
49 |
-
Do to so, run
|
50 |
-
|
51 |
-
```bash
|
52 |
-
python scripts/prepare_redpajama.py --source_path data/RedPajama-Data-1T --tokenizer_path checkpoints/lit-llama/tokenizer.model --destination_path data/lit-redpajama
|
53 |
-
```
|
54 |
-
|
55 |
-
or
|
56 |
-
|
57 |
-
```bash
|
58 |
-
python scripts/prepare_redpajama.py --source_path data/RedPajama-Data-1T-Sample --tokenizer_path checkpoints/lit-llama/tokenizer.model --destination_path data/lit-redpajama-sample --sample True
|
59 |
-
```
|
60 |
-
|
61 |
-
for the sample dataset.
|
62 |
-
|
63 |
-
In the above we are assuming that you will be using the same tokenizer as used in LLaMA, but any trained [SentencePiece](https://github.com/google/sentencepiece) tokenizer with a 32000 vocabulary size will do here.
|
64 |
-
|
65 |
-
The script will take a while to run, so time for :tea:
|
66 |
-
|
67 |
-
## Pre-training
|
68 |
-
|
69 |
-
Running the pre-training script requires at least 4 GPUs with 40GB+ each (A100).
|
70 |
-
|
71 |
-
```bash
|
72 |
-
python pretrain/redpajama.py --devices 4 --train_data_dir data/lit-redpajama
|
73 |
-
```
|
74 |
-
|
75 |
-
For running on the sample dataset:
|
76 |
-
|
77 |
-
```bash
|
78 |
-
python pretrain/redpajama.py --devices 4 --train_data_dir data/lit-redpajama-sample
|
79 |
-
```
|
80 |
-
|
81 |
-
The script will save checkpoints periodically to the folder `out/`.
|
82 |
-
|
83 |
-
The `train_redpajama.py` script will pre-train the LLaMA 7B model with FSDP in
|
84 |
-
`bfloat16` precision and gradient accumulation.
|
85 |
-
|
86 |
-
You can easily change the size of the model by passing a different string to
|
87 |
-
|
88 |
-
```python
|
89 |
-
config = LLaMAConfig.from_name("7B")
|
90 |
-
```
|
91 |
-
|
92 |
-
in the `main` function.
|
93 |
-
|
94 |
-
Keep in mind that the original LLaMA training for the 7B model required 83k A100 80GB
|
95 |
-
hours, so you'll need access to a cluster.
|
96 |
-
|
97 |
-
Once you're in a cluster, you can follow [these instructions](https://lightning.ai/docs/fabric/stable/guide/multi_node/other.html)
|
98 |
-
to launch the script across machines:
|
99 |
-
|
100 |
-
- [SLURM cluster](https://lightning.ai/docs/fabric/stable/guide/multi_node/slurm.html)
|
101 |
-
- [Barebones cluster](https://lightning.ai/docs/fabric/stable/guide/multi_node/barebones.html)
|
102 |
-
- [MPI](https://lightning.ai/docs/fabric/stable/guide/multi_node/other.html)
|
103 |
-
|
104 |
-
The script contains several configurations and hyperparameters you can tweak:
|
105 |
-
|
106 |
-
```python
|
107 |
-
out_dir = "out/training"
|
108 |
-
save_interval = 1000
|
109 |
-
eval_interval = 1000
|
110 |
-
eval_iters = 100
|
111 |
-
log_interval = 1
|
112 |
-
|
113 |
-
# Hyperparameters
|
114 |
-
learning_rate = 6e-4
|
115 |
-
batch_size = 125
|
116 |
-
micro_batch_size = 5
|
117 |
-
max_iters = 600000 # num_epochs * (epoch_size // micro_batch_size) // devices
|
118 |
-
weight_decay = 1e-1
|
119 |
-
beta1 = 0.9
|
120 |
-
beta2 = 0.95
|
121 |
-
grad_clip = 1.0
|
122 |
-
decay_lr = True
|
123 |
-
warmup_iters = 2000
|
124 |
-
lr_decay_iters = max_iters
|
125 |
-
min_lr = 6e-5
|
126 |
-
```
|
127 |
-
|
128 |
-
In particular, `micro_batch_size` should be adjusted so the process will use the available
|
129 |
-
GPU memory.
|
130 |
-
|
131 |
-
Last, logging is kept minimal in the script. In order to use a particular logger
|
132 |
-
please refer to <https://lightning.ai/docs/fabric/stable/api/loggers.html> or
|
133 |
-
call a logging client library like `wandb` directly.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lit-llama/howto/unstructured_dataset.md
DELETED
@@ -1,18 +0,0 @@
|
|
1 |
-
# Finetuning on an unstructured dataset
|
2 |
-
|
3 |
-
While most scripts were made to finetune on instruction datasets, it is possible to finetune on any dataset. This is useful for experimentation while not being as expensive as training a full model.
|
4 |
-
|
5 |
-
This guide is only to prepare the finetuning, as either LoRA or Adapter-v1 methods support this dataset type!
|
6 |
-
|
7 |
-
## Preparation
|
8 |
-
|
9 |
-
1. Gather your text into an input file named `input.txt`
|
10 |
-
2. Divide the data into training and validation sets using the following script:
|
11 |
-
|
12 |
-
```bash
|
13 |
-
python scripts/prepare_any_text.py
|
14 |
-
```
|
15 |
-
|
16 |
-
3. Modify relevant scripts for your finetuning method under `finetune/` and `evaluate/`, setting the `instruction_tuning` variable to `False`
|
17 |
-
|
18 |
-
And then you're set! Proceed to run the [LoRA guide](./finetune_lora.md) or [Adapter v1 guide](./finetune_adapter.md).
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lit-llama/lit_llama/__init__.py
DELETED
@@ -1,2 +0,0 @@
|
|
1 |
-
from lit_llama.model import LLaMAConfig, LLaMA, RMSNorm, build_rope_cache, apply_rope
|
2 |
-
from lit_llama.tokenizer import Tokenizer
|
|
|
|
|
|
lit-llama/lit_llama/adapter.py
DELETED
@@ -1,313 +0,0 @@
|
|
1 |
-
"""Implementation of the paper:
|
2 |
-
|
3 |
-
LLaMA-Adapter: Efficient Fine-tuning of Language Models with Zero-init Attention
|
4 |
-
https://arxiv.org/abs/2303.16199
|
5 |
-
|
6 |
-
| Prefix cross-attention
|
7 |
-
|
|
8 |
-
┌─────────────────┐ | ┌──────────────────┐
|
9 |
-
┆ x ┆ | ┆ prefix ┆
|
10 |
-
└─────────────────┘ | └──────────────────┘
|
11 |
-
| | |
|
12 |
-
▼ | ▼
|
13 |
-
┌──────────────────┐ | ┌─────────────────────┐
|
14 |
-
┆ self-attention ┆ --------------------------------------------------------------┐ ┆ linear projection ┆
|
15 |
-
└──────────────────┘ | ┆ └─────────────────────┘
|
16 |
-
| | ┆ | \
|
17 |
-
▼ | ▼ ▼ ▼
|
18 |
-
╭───╮ ┌────────────────┐ ╭───╮ ┌──────────────────────────┐ | ┌─────────┐ ┌──────────────┐ ┌────────────────┐
|
19 |
-
┆ + ┆ ◀── ┆ gating factor ┆-┆ x ┆-┆ prefix cross-attention ┆ | ┆ query ┆ ┆ prefix key ┆ ┆ prefix value ┆
|
20 |
-
╰───╯ └────────────────┘ ╰───╯ └──────────────────────────┘ | └─────────┘ └──────────────┘ └────────────────┘
|
21 |
-
| | \ | /
|
22 |
-
▼ | ▼ ▼ ▼
|
23 |
-
| ┌────────────────────────────────┐
|
24 |
-
| ┆ scaled dot-product attention ┆
|
25 |
-
| └────────────────────────────────┘
|
26 |
-
|
27 |
-
|
28 |
-
In order to inject learnable information from the prefix to pretrained weights we need to sum outputs from
|
29 |
-
self-attention and prefix cross-attention (times gating factor). For prefix cross-attention we need `query` (from
|
30 |
-
self-attention as a result of linear projection), `prefix key` and `prefix value` (from cross-attention as a result of
|
31 |
-
linear projection).
|
32 |
-
The output of prefix cross-attention is multiplied by gating factor, which is a learnable parameter that is needed to
|
33 |
-
avoid potential disruption of pretrained weights caused by incorporating randomly initialized tensors. This factor is
|
34 |
-
initialized with zeros to avoid noise from the adaption prompts at the early training stage.
|
35 |
-
More about it: https://lightning.ai/pages/community/article/understanding-llama-adapters/
|
36 |
-
|
37 |
-
Notes about implementation: as per paper adapter's prefix is concatenated with the input, while here outputs of
|
38 |
-
self-attention and prefix cross-attention are summed. Both variants are mathematically equivalent:
|
39 |
-
https://github.com/ZrrSkywalker/LLaMA-Adapter/issues/47
|
40 |
-
"""
|
41 |
-
# mypy: ignore-errors
|
42 |
-
from dataclasses import dataclass
|
43 |
-
from typing import Optional, Tuple, List, Union
|
44 |
-
|
45 |
-
import torch
|
46 |
-
import torch.nn as nn
|
47 |
-
from torch.nn import functional as F
|
48 |
-
|
49 |
-
import lit_llama.model as llama
|
50 |
-
from lit_llama.model import build_rope_cache, apply_rope, RMSNorm, MLP, KVCache, RoPECache
|
51 |
-
|
52 |
-
|
53 |
-
@dataclass
|
54 |
-
class LLaMAConfig(llama.LLaMAConfig):
|
55 |
-
adapter_prompt_length: int = 10
|
56 |
-
adapter_start_layer: int = 2
|
57 |
-
|
58 |
-
|
59 |
-
class CausalSelfAttention(nn.Module):
|
60 |
-
"""A modification of `lit_llama.model.CausalSelfAttention` that adds the attention
|
61 |
-
over the adaption prompt."""
|
62 |
-
|
63 |
-
def __init__(self, config: LLaMAConfig, block_idx: int) -> None:
|
64 |
-
super().__init__()
|
65 |
-
assert config.n_embd % config.n_head == 0
|
66 |
-
|
67 |
-
# key, query, value projections for all heads, but in a batch
|
68 |
-
self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=False)
|
69 |
-
# output projection
|
70 |
-
self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
|
71 |
-
|
72 |
-
if block_idx >= config.adapter_start_layer:
|
73 |
-
# adapter embedding layer
|
74 |
-
self.adapter_wte = nn.Embedding(config.adapter_prompt_length, config.n_embd)
|
75 |
-
# a learnable gating factor (to avoid potential disruption of pretrained weights) initialized with zeros (to
|
76 |
-
# avoid noise from adaption prompts at the early training stage)
|
77 |
-
self.gating_factor = torch.nn.Parameter(torch.zeros(1, config.n_head, 1, 1))
|
78 |
-
|
79 |
-
self.n_head = config.n_head
|
80 |
-
self.n_embd = config.n_embd
|
81 |
-
self.block_size = config.block_size
|
82 |
-
self.block_idx = block_idx
|
83 |
-
self.adapter_prompt_length = config.adapter_prompt_length
|
84 |
-
self.adapter_start_layer = config.adapter_start_layer
|
85 |
-
|
86 |
-
def forward(
|
87 |
-
self,
|
88 |
-
x: torch.Tensor,
|
89 |
-
rope: RoPECache,
|
90 |
-
mask: torch.Tensor,
|
91 |
-
max_seq_length: int,
|
92 |
-
input_pos: Optional[torch.Tensor] = None,
|
93 |
-
kv_cache: Optional[KVCache] = None,
|
94 |
-
adapter_kv_cache: Optional[KVCache] = None,
|
95 |
-
) -> Tuple[torch.Tensor, Optional[KVCache], Optional[KVCache]]:
|
96 |
-
# notation:
|
97 |
-
# - B | batch
|
98 |
-
# - T | time-step (sequence length)
|
99 |
-
# - C | embeddings size (n_embd) = head size * num heads
|
100 |
-
# - hs | head size
|
101 |
-
# - nh | number of heads
|
102 |
-
|
103 |
-
B, T, C = x.size()
|
104 |
-
|
105 |
-
# instead of calculating `query`, `key` and `value` by separately multiplying input `x` with corresponding
|
106 |
-
# weight matrices do it (for all heads) in a single multiplication with a matrix of 3x size (concatenated
|
107 |
-
# weights for q, k, v) and then split the result along `embedding size` dimension
|
108 |
-
q, k, v = self.c_attn(x).split(self.n_embd, dim=2) # (B, T, 3 * C) --> 3 * (B, T, C)
|
109 |
-
|
110 |
-
# in order to move head_size (hs) dimension right after batch (B) dimension, we need to first split
|
111 |
-
# embedding size (C) dimension into num_heads (nh) and head_size (hs)
|
112 |
-
head_size = C // self.n_head
|
113 |
-
k = k.view(B, T, self.n_head, head_size)
|
114 |
-
q = q.view(B, T, self.n_head, head_size)
|
115 |
-
v = v.view(B, T, self.n_head, head_size)
|
116 |
-
|
117 |
-
# "Unlike standard positional embeddings rotary embeddings must be applied at every layer"
|
118 |
-
q = apply_rope(q, rope) # (B, T, nh, hs)
|
119 |
-
k = apply_rope(k, rope) # (B, T, nh, hs)
|
120 |
-
|
121 |
-
# now `key`, 'query` and `value` tensors are correctly represented: for each element in a batch (B)
|
122 |
-
# there is a number of heads (nh) and for each head there is a sequence of elements (T), each of them is
|
123 |
-
# represented by a vector of size `hs`
|
124 |
-
k = k.transpose(1, 2) # (B, nh, T, hs)
|
125 |
-
q = q.transpose(1, 2) # (B, nh, T, hs)
|
126 |
-
v = v.transpose(1, 2) # (B, nh, T, hs)
|
127 |
-
|
128 |
-
if kv_cache is not None:
|
129 |
-
cache_k, cache_v = kv_cache # 2 * (B, nh, max_seq_length, hs)
|
130 |
-
# check if reached token limit
|
131 |
-
if input_pos[-1] >= max_seq_length:
|
132 |
-
# if we reached token limit and thus there is no space to put newly calculated `key` and `value`
|
133 |
-
# right next to cached ones, we need to rotate cache tensor along `max_seq_length` dimension by one
|
134 |
-
# element to the left: this will free up space for new `key` and `value`
|
135 |
-
input_pos = torch.tensor(max_seq_length - 1, device=input_pos.device)
|
136 |
-
# shift 1 position to the left
|
137 |
-
cache_k = torch.roll(cache_k, -1, dims=2)
|
138 |
-
cache_v = torch.roll(cache_v, -1, dims=2)
|
139 |
-
k = cache_k.index_copy(2, input_pos, k) # (B, nh, max_seq_length, hs)
|
140 |
-
v = cache_v.index_copy(2, input_pos, v) # (B, nh, max_seq_length, hs)
|
141 |
-
kv_cache = k, v
|
142 |
-
|
143 |
-
# efficient attention using Flash Attention CUDA kernels
|
144 |
-
# ↓ (B, nh, T, hs) @ (B, nh, T, hs).mT --> (B, nh, T, T) @ (B, nh, T, hs) --> (B, nh, T, hs)
|
145 |
-
y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) # (B, nh, T, hs)
|
146 |
-
|
147 |
-
# "Adapters are applied to the topmost layers to better tune the language
|
148 |
-
# representations with higher-level semantics".
|
149 |
-
if self.block_idx >= self.adapter_start_layer:
|
150 |
-
if adapter_kv_cache is not None:
|
151 |
-
ak, av = adapter_kv_cache # 2 * (B, nh, aT, hs)
|
152 |
-
else:
|
153 |
-
prefix = self.adapter_wte.weight.reshape(1, self.adapter_prompt_length, self.n_embd)
|
154 |
-
aT = prefix.size(1)
|
155 |
-
_, ak, av = self.c_attn(prefix).split(self.n_embd, dim=2) # (1, aT, 3 * C) --> 3 * (1, aT, C)
|
156 |
-
ak = ak.view(1, aT, self.n_head, head_size).repeat(B, 1, 1, 1).transpose(1, 2) # (B, nh, aT, hs)
|
157 |
-
av = av.view(1, aT, self.n_head, head_size).repeat(B, 1, 1, 1).transpose(1, 2) # (B, nh, aT, hs)
|
158 |
-
adapter_kv_cache = (ak, av)
|
159 |
-
|
160 |
-
# Apply cross-attention with `query`, `adapter_key`, `adapter_value` and sum the output with the output
|
161 |
-
# obtained from self-attention step. This is mathematically equivalent to concatenation of prefix and input as per paper.
|
162 |
-
amask = torch.ones(q.shape[-2], ak.shape[-2], dtype=torch.bool, device=x.device) # (T, aT)
|
163 |
-
# ↓ (B, nh, T, hs) @ (B, nh, aT, hs).mT --> (B, nh, T, aT) @ (B, nh, aT, hs) --> (B, nh, T, hs)
|
164 |
-
ay = F.scaled_dot_product_attention(q, ak, av, attn_mask=amask, dropout_p=0.0, is_causal=False) # (B, nh, T, hs)
|
165 |
-
y = y + self.gating_factor * ay
|
166 |
-
|
167 |
-
y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
|
168 |
-
|
169 |
-
# output projection
|
170 |
-
y = self.c_proj(y) # (B, T, C)
|
171 |
-
|
172 |
-
return y, kv_cache, adapter_kv_cache
|
173 |
-
|
174 |
-
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
|
175 |
-
"""For backward compatibility with old checkpoints that have a single gating value for all heads."""
|
176 |
-
name = prefix + "gating_factor"
|
177 |
-
if name in state_dict:
|
178 |
-
tensor = state_dict[name]
|
179 |
-
# in case we are loading with `utils.lazy_load()`
|
180 |
-
tensor = tensor._load_tensor() if hasattr(tensor, "_load_tensor") else tensor
|
181 |
-
|
182 |
-
if len(tensor.shape) < 4:
|
183 |
-
# For old checkpoints with unified gating value
|
184 |
-
state_dict[name] = tensor.reshape(1, 1, 1, 1).repeat(1, self.n_head, 1, 1)
|
185 |
-
else:
|
186 |
-
state_dict[name] = tensor
|
187 |
-
|
188 |
-
return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
|
189 |
-
|
190 |
-
|
191 |
-
class Block(nn.Module):
|
192 |
-
"""The implementation is identical to `lit_llama.model.Block` with the exception that
|
193 |
-
we replace the attention layer where adaption is implemented."""
|
194 |
-
|
195 |
-
def __init__(self, config: LLaMAConfig, block_idx: int) -> None:
|
196 |
-
super().__init__()
|
197 |
-
self.rms_1 = RMSNorm(config.n_embd)
|
198 |
-
self.attn = CausalSelfAttention(config, block_idx)
|
199 |
-
self.rms_2 = RMSNorm(config.n_embd)
|
200 |
-
self.mlp = MLP(config)
|
201 |
-
|
202 |
-
def forward(
|
203 |
-
self,
|
204 |
-
x: torch.Tensor,
|
205 |
-
rope: RoPECache,
|
206 |
-
mask: torch.Tensor,
|
207 |
-
max_seq_length: int,
|
208 |
-
input_pos: Optional[torch.Tensor] = None,
|
209 |
-
kv_cache: Optional[KVCache] = None,
|
210 |
-
adapter_kv_cache: Optional[KVCache] = None,
|
211 |
-
) -> Tuple[torch.Tensor, Optional[KVCache], Optional[KVCache]]:
|
212 |
-
h, new_kv_cache, new_adapter_kv_cache = self.attn(
|
213 |
-
self.rms_1(x), rope, mask, max_seq_length, input_pos, kv_cache, adapter_kv_cache
|
214 |
-
)
|
215 |
-
x = x + h
|
216 |
-
x = x + self.mlp(self.rms_2(x))
|
217 |
-
return x, new_kv_cache, new_adapter_kv_cache
|
218 |
-
|
219 |
-
|
220 |
-
class LLaMA(llama.LLaMA):
|
221 |
-
"""The implementation is identical to `lit_llama.model.LLaMA` with the exception that
|
222 |
-
the `Block` saves the layer index and passes it down to the attention layer."""
|
223 |
-
|
224 |
-
def __init__(self, config: LLaMAConfig) -> None:
|
225 |
-
nn.Module.__init__(self)
|
226 |
-
assert config.vocab_size is not None
|
227 |
-
assert config.block_size is not None
|
228 |
-
self.config = config
|
229 |
-
|
230 |
-
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
|
231 |
-
self.transformer = nn.ModuleDict(
|
232 |
-
dict(
|
233 |
-
wte=nn.Embedding(config.vocab_size, config.n_embd),
|
234 |
-
h=nn.ModuleList(Block(config, i) for i in range(config.n_layer)),
|
235 |
-
ln_f=RMSNorm(config.n_embd),
|
236 |
-
)
|
237 |
-
)
|
238 |
-
|
239 |
-
self.rope_cache: Optional[RoPECache] = None
|
240 |
-
self.mask_cache: Optional[torch.Tensor] = None
|
241 |
-
self.kv_caches: List[KVCache] = []
|
242 |
-
self.adapter_kv_caches: List[KVCache] = []
|
243 |
-
|
244 |
-
@classmethod
|
245 |
-
def from_name(cls, name: str):
|
246 |
-
return cls(LLaMAConfig.from_name(name))
|
247 |
-
|
248 |
-
def reset_cache(self) -> None:
|
249 |
-
super().reset_cache()
|
250 |
-
self.adapter_kv_caches.clear()
|
251 |
-
|
252 |
-
def forward(
|
253 |
-
self, idx: torch.Tensor, max_seq_length: Optional[int] = None, input_pos: Optional[torch.Tensor] = None
|
254 |
-
) -> Union[torch.Tensor, Tuple[torch.Tensor, List[KVCache]]]:
|
255 |
-
B, T = idx.size()
|
256 |
-
|
257 |
-
block_size = self.config.block_size
|
258 |
-
if max_seq_length is None:
|
259 |
-
max_seq_length = block_size
|
260 |
-
assert T <= max_seq_length, f"Cannot forward sequence of length {T}, max seq length is only {max_seq_length}"
|
261 |
-
assert max_seq_length <= block_size, f"Cannot attend to {max_seq_length}, block size is only {block_size}"
|
262 |
-
assert T <= block_size, f"Cannot forward sequence of length {T}, block size is only {block_size}"
|
263 |
-
|
264 |
-
if self.rope_cache is None:
|
265 |
-
self.rope_cache = self.build_rope_cache(idx) # (block_size, head_size / 2, 2)
|
266 |
-
if self.mask_cache is None:
|
267 |
-
self.mask_cache = self.build_mask_cache(idx) # (1, 1, block_size, block_size)
|
268 |
-
|
269 |
-
if input_pos is not None:
|
270 |
-
rope = self.rope_cache.index_select(0, input_pos)
|
271 |
-
mask = self.mask_cache.index_select(2, input_pos)
|
272 |
-
mask = mask[:, :, :, :max_seq_length]
|
273 |
-
else:
|
274 |
-
rope = self.rope_cache[:T]
|
275 |
-
mask = self.mask_cache[:, :, :T, :T]
|
276 |
-
|
277 |
-
# forward the model itself
|
278 |
-
x = self.transformer.wte(idx) # token embeddings of shape (B, T, n_embd)
|
279 |
-
|
280 |
-
if input_pos is None: # proxy for use_cache=False
|
281 |
-
for block in self.transformer.h:
|
282 |
-
x, *_ = block(x, rope, mask, max_seq_length)
|
283 |
-
else:
|
284 |
-
if not self.kv_caches:
|
285 |
-
head_size = self.config.n_embd // self.config.n_head
|
286 |
-
cache_shape = (B, self.config.n_head, max_seq_length, head_size)
|
287 |
-
self.kv_caches = [
|
288 |
-
(torch.zeros(cache_shape, device=x.device, dtype=x.dtype), torch.zeros(cache_shape, device=x.device, dtype=x.dtype))
|
289 |
-
for _ in range(self.config.n_layer)
|
290 |
-
]
|
291 |
-
if not self.adapter_kv_caches:
|
292 |
-
self.adapter_kv_caches = [None for _ in range(self.config.n_layer)]
|
293 |
-
for i, block in enumerate(self.transformer.h):
|
294 |
-
x, self.kv_caches[i], self.adapter_kv_caches[i] = block(
|
295 |
-
x, rope, mask, max_seq_length, input_pos, self.kv_caches[i], self.adapter_kv_caches[i]
|
296 |
-
)
|
297 |
-
|
298 |
-
x = self.transformer.ln_f(x) # (B, T, n_embd)
|
299 |
-
|
300 |
-
logits = self.lm_head(x) # (B, T, vocab_size)
|
301 |
-
|
302 |
-
return logits
|
303 |
-
|
304 |
-
|
305 |
-
def mark_only_adapter_as_trainable(model: LLaMA) -> None:
|
306 |
-
"""Sets `requires_grad=False` for all non-adapter weights."""
|
307 |
-
for name, param in model.named_parameters():
|
308 |
-
param.requires_grad = "adapter_wte" in name or "gating_factor" in name
|
309 |
-
|
310 |
-
|
311 |
-
def adapter_state_from_state_dict(state_dict: dict) -> dict:
|
312 |
-
"""Returns the model state dict with only the adapter weights for saving."""
|
313 |
-
return {name: param for name, param in state_dict.items() if "adapter_wte" in name or "gating_factor" in name}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lit-llama/lit_llama/adapter_v2.py
DELETED
@@ -1,45 +0,0 @@
|
|
1 |
-
import torch
|
2 |
-
from torch import Tensor
|
3 |
-
import torch.nn as nn
|
4 |
-
from torch.nn import functional as F
|
5 |
-
|
6 |
-
from lit_llama.adapter import LLaMA
|
7 |
-
|
8 |
-
|
9 |
-
def get_adapter_substrings():
|
10 |
-
substrings = ["adapter_wte", "gating_factor"] # regular adapter v1 parameters
|
11 |
-
substrings.extend(["adapter_scale", "adapter_bias"]) # adapter v2: new bias and scale used in Linear
|
12 |
-
substrings.extend(["rms_1", "rms_2", "ln_f"]) # adapter v2: RMSNorm parameters are now trainable
|
13 |
-
return substrings
|
14 |
-
|
15 |
-
|
16 |
-
def mark_only_adapter_v2_as_trainable(model: LLaMA) -> None:
|
17 |
-
"""Sets `requires_grad=False` for all non-adapter weights."""
|
18 |
-
for name, param in model.named_parameters():
|
19 |
-
param.requires_grad = any(s in name for s in get_adapter_substrings())
|
20 |
-
|
21 |
-
|
22 |
-
def adapter_v2_state_from_state_dict(state_dict: dict) -> dict:
|
23 |
-
"""Returns the model state dict with only the adapter weights for saving."""
|
24 |
-
return {name: param for name, param in state_dict.items()
|
25 |
-
if any(s in name for s in get_adapter_substrings())}
|
26 |
-
|
27 |
-
|
28 |
-
def adapter_v2_new_forward(self, input: Tensor) -> Tensor:
|
29 |
-
return self.adapter_scale * (
|
30 |
-
F.linear(input, self.weight, self.bias) + self.adapter_bias
|
31 |
-
)
|
32 |
-
|
33 |
-
|
34 |
-
def adapter_v2_linear_with_bias_and_scale(layer):
|
35 |
-
layer.adapter_bias = torch.nn.Parameter(torch.zeros(layer.weight.shape[0]), requires_grad=True)
|
36 |
-
layer.adapter_scale = torch.nn.Parameter(torch.ones(layer.weight.shape[0]), requires_grad=True)
|
37 |
-
bound_method = adapter_v2_new_forward.__get__(layer, layer.__class__)
|
38 |
-
setattr(layer, 'forward', bound_method)
|
39 |
-
return layer
|
40 |
-
|
41 |
-
|
42 |
-
def add_adapter_v2_parameters_to_linear_layers(model):
|
43 |
-
for module in model.modules():
|
44 |
-
if isinstance(module, nn.Linear):
|
45 |
-
adapter_v2_linear_with_bias_and_scale(module)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lit-llama/lit_llama/lora.py
DELETED
@@ -1,476 +0,0 @@
|
|
1 |
-
# Derived from https://github.com/microsoft/LoRA
|
2 |
-
# ------------------------------------------------------------------------------------------
|
3 |
-
# Copyright (c) Microsoft Corporation. All rights reserved.
|
4 |
-
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
5 |
-
# ------------------------------------------------------------------------------------------
|
6 |
-
|
7 |
-
r"""
|
8 |
-
Low Ranking Adaptation for LLMs scheme.
|
9 |
-
|
10 |
-
┌───────────────────┐
|
11 |
-
┆ h ┆
|
12 |
-
└───────────────────┘
|
13 |
-
▲
|
14 |
-
|
|
15 |
-
+
|
16 |
-
/ \
|
17 |
-
┌─────────────────┐ ╭───────────────╮ Matrix initialization:
|
18 |
-
┆ ┆ \ B / B = 0
|
19 |
-
┆ pretrained ┆ \ r*d / A = N(0, sigma^2)
|
20 |
-
┆ weights ┆ ╰─────────╯
|
21 |
-
┆ ┆ | r | r - rank
|
22 |
-
┆ W e R^(d*d) ┆ | ◀─────▶ |
|
23 |
-
┆ ┆ ╭─────────╮
|
24 |
-
└─────────────────┘ / A \
|
25 |
-
▲ / d*r \
|
26 |
-
\ ╰───────────────╯
|
27 |
-
\ ▲
|
28 |
-
\ /
|
29 |
-
\ /
|
30 |
-
┌───────────────────┐
|
31 |
-
┆ x ┆
|
32 |
-
└───────────────────┘
|
33 |
-
|
34 |
-
With LoRA (Low Ranking Adaptation: https://arxiv.org/abs/2106.09685) instead of learning weights of size d*d,
|
35 |
-
we can freeze the pretrained weights and instead learn two matrices of size d*r and r*d (they will store weight updates
|
36 |
-
for the pretrained weights): the number of parameters in this case will be reduced drastically (depending on the rank of
|
37 |
-
course) yet after multiplication of matrices d*r and r*d we will get a matrix d*d which we can sum with frozen
|
38 |
-
pretrained weights and thus fine-tune the model.
|
39 |
-
|
40 |
-
The goal of this approach is to move weight updates into a separate matrix which is decomposed with
|
41 |
-
two matrices of a lower rank.
|
42 |
-
"""
|
43 |
-
|
44 |
-
import torch
|
45 |
-
import torch.nn as nn
|
46 |
-
import torch.nn.functional as F
|
47 |
-
|
48 |
-
import math
|
49 |
-
from typing import Dict, List
|
50 |
-
|
51 |
-
import lit_llama.model as llama
|
52 |
-
|
53 |
-
from contextlib import contextmanager
|
54 |
-
from dataclasses import dataclass
|
55 |
-
|
56 |
-
|
57 |
-
class LoRALayer():
|
58 |
-
def __init__(
|
59 |
-
self,
|
60 |
-
r: int,
|
61 |
-
lora_alpha: int,
|
62 |
-
lora_dropout: float,
|
63 |
-
merge_weights: bool,
|
64 |
-
):
|
65 |
-
"""Store LoRA specific attributes in a class.
|
66 |
-
|
67 |
-
Args:
|
68 |
-
r: rank of the weight update matrices. To make sense of using LoRA the rank should be smaller than the rank of
|
69 |
-
the weights of the model. The rank can be as low as 1: https://arxiv.org/pdf/2106.09685.pdf (section 7.2)
|
70 |
-
lora_alpha: alpha is needed for scaling updates as alpha/r
|
71 |
-
"This scaling helps to reduce the need to retune hyperparameters when we vary r"
|
72 |
-
https://arxiv.org/pdf/2106.09685.pdf (section 4.1)
|
73 |
-
lora_dropout: dropout that is applied on the input in the LoRA branch (before multiplying by matrix A)
|
74 |
-
merge_weights: whether we want to merge pretrained weights and LoRA weight updates. This is useful if one wants to use
|
75 |
-
fine-tuned model as a standalone one (without storing LoRA weights separately) plus it helps to reduce
|
76 |
-
overhead during inference.
|
77 |
-
"""
|
78 |
-
self.r = r
|
79 |
-
self.lora_alpha = lora_alpha
|
80 |
-
# Optional dropout
|
81 |
-
if lora_dropout > 0.:
|
82 |
-
self.lora_dropout = nn.Dropout(p=lora_dropout)
|
83 |
-
else:
|
84 |
-
self.lora_dropout = lambda x: x
|
85 |
-
# Mark the weight as unmerged
|
86 |
-
self.merged = False
|
87 |
-
self.merge_weights = merge_weights
|
88 |
-
|
89 |
-
|
90 |
-
class MergedLinear(nn.Linear, LoRALayer):
|
91 |
-
# LoRA implemented in a dense layer
|
92 |
-
def __init__(
|
93 |
-
self,
|
94 |
-
# ↓ this part is for pretrained weights
|
95 |
-
in_features: int,
|
96 |
-
out_features: int,
|
97 |
-
# ↓ the remaining part is for LoRA
|
98 |
-
r: int = 0,
|
99 |
-
lora_alpha: int = 1,
|
100 |
-
lora_dropout: float = 0.,
|
101 |
-
enable_lora: List[bool] = [False],
|
102 |
-
fan_in_fan_out: bool = False,
|
103 |
-
merge_weights: bool = True,
|
104 |
-
**kwargs
|
105 |
-
):
|
106 |
-
"""LoRA wrapper around linear class that is used for calculation of q, k and v matrices.
|
107 |
-
|
108 |
-
This class has three weight matrices:
|
109 |
-
1. Pretrained weights are stored as `self.weight` (because of the nn.Linear inheritance)
|
110 |
-
2. LoRA A matrix as `self.lora_A`
|
111 |
-
3. LoRA B matrix as `self.lora_B`
|
112 |
-
Only LoRA's A and B matrices are updated, pretrained weights stay frozen.
|
113 |
-
|
114 |
-
Args:
|
115 |
-
in_features: number of input features of the pretrained weights
|
116 |
-
out_features: number of output features of the pretrained weights
|
117 |
-
r: rank of the weight update matrices. To make sense of using LoRA the rank should be smaller than the rank of
|
118 |
-
the weights of the model. The rank can be as low as 1: https://arxiv.org/pdf/2106.09685.pdf (section 7.2)
|
119 |
-
lora_alpha: alpha is needed for scaling updates as alpha/r
|
120 |
-
"This scaling helps to reduce the need to retune hyperparameters when we vary r"
|
121 |
-
https://arxiv.org/pdf/2106.09685.pdf (section 4.1)
|
122 |
-
lora_dropout: dropout that is applied on the input in the LoRA branch (before multiplying by matrix A)
|
123 |
-
enable_lora: MergeLinear class is for attention mechanism where qkv are calculated with a single weight matrix. If we
|
124 |
-
don't want to apply LoRA for all three (query, key and value) we can set it as False. For example if we want
|
125 |
-
to apply LoRA only to `query` and `value` but keep `key` without weight updates we should pass `[True,
|
126 |
-
False, True]`
|
127 |
-
fan_in_fan_out: set this to True if the layer to replace stores weight like (fan_in, fan_out). For example, gpt-2 uses
|
128 |
-
`Conv1D` which stores weights like (fan_in, fan_out) and hence this should be set to `True`
|
129 |
-
https://github.com/huggingface/peft/blob/main/src/peft/tuners/lora.py#LL53C9-L53C112
|
130 |
-
merge_weights: whether we want to merge pretrained weights and LoRA weight updates. This is useful if one wants to use
|
131 |
-
fine-tuned model as a standalone one (without storing LoRA weight separately) plus it helps to reduce
|
132 |
-
overhead during inference.
|
133 |
-
"""
|
134 |
-
nn.Linear.__init__(self, in_features, out_features, **kwargs)
|
135 |
-
LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
|
136 |
-
merge_weights=merge_weights)
|
137 |
-
assert out_features % len(enable_lora) == 0, \
|
138 |
-
'The length of enable_lora must divide out_features'
|
139 |
-
self.enable_lora = enable_lora
|
140 |
-
self.fan_in_fan_out = fan_in_fan_out
|
141 |
-
|
142 |
-
# Actual trainable parameters
|
143 |
-
# To better understand initialization let's imagine that we have such parameters:
|
144 |
-
# ⚬ in_features: 128 (embeddings_size)
|
145 |
-
# ⚬ out_features: 384 (3 * embedding_size)
|
146 |
-
# ⚬ r: 2
|
147 |
-
# ⚬ enable_lora: [True, False, True]
|
148 |
-
if r > 0 and any(enable_lora):
|
149 |
-
self.lora_A = nn.Parameter(
|
150 |
-
self.weight.new_zeros((r * sum(enable_lora), in_features))) # (4, 128)
|
151 |
-
self.lora_B = nn.Parameter(
|
152 |
-
self.weight.new_zeros((out_features // len(enable_lora) * sum(enable_lora), r)) # (256, 2)
|
153 |
-
) # weights for Conv1D with groups=sum(enable_lora)
|
154 |
-
# Notes about shapes above
|
155 |
-
# - self.lora_A has shape (4, 128): 4 because rank is 2 and LoRA is applied only to two matrices;
|
156 |
-
# 128 is the input size of the x (embedding size). (4, 128) and not (128, 4) because later on in
|
157 |
-
# F.linear function weights are automatically transposed. In addition conv1d requires channels to
|
158 |
-
# be before seq length
|
159 |
-
# - self.lora_B has shape (256, 2): 256 because LoRA is applied only to two matrices, so the output is
|
160 |
-
# 128*2; 2 tells to have two channels per group for group convolution
|
161 |
-
|
162 |
-
# Scaling:
|
163 |
-
# This balances the pretrained model`s knowledge and the new task-specific adaptation
|
164 |
-
# https://lightning.ai/pages/community/tutorial/lora-llm/
|
165 |
-
# So, set alpha to 1.0 to fully add LoRA. If the LoRA seems to have too much effect (i.e., overfitted), set
|
166 |
-
# alpha to lower value. If the LoRA seems to have too little effect, set alpha to higher than 1.0. You can
|
167 |
-
# tune these values to your needs. This value can be even slightly greater than 1.0!
|
168 |
-
# https://github.com/cloneofsimo/lora
|
169 |
-
self.scaling = self.lora_alpha / self.r
|
170 |
-
|
171 |
-
# Freezing the pre-trained weight matrix
|
172 |
-
self.weight.requires_grad = False # (384, 128)
|
173 |
-
|
174 |
-
# Compute the indices
|
175 |
-
# Indices are needed to properly pad weight updates with zeros. If we want to fine-tune queries and values,
|
176 |
-
# but not keys, then the weights update should be:
|
177 |
-
#
|
178 |
-
# [[ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW,],
|
179 |
-
# [....................................],
|
180 |
-
# [ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW,]]
|
181 |
-
# ↑ ↑ ↑
|
182 |
-
# ________________________________________
|
183 |
-
# | query | key | value |
|
184 |
-
# ----------------------------------------
|
185 |
-
self.lora_ind = self.weight.new_zeros(
|
186 |
-
(out_features, ), dtype=torch.bool
|
187 |
-
).view(len(enable_lora), -1) # (3, 128)
|
188 |
-
self.lora_ind[enable_lora, :] = True # (3, 128)
|
189 |
-
self.lora_ind = self.lora_ind.view(-1) # (384,)
|
190 |
-
self.reset_parameters()
|
191 |
-
if fan_in_fan_out:
|
192 |
-
self.weight.data = self.weight.data.T
|
193 |
-
|
194 |
-
def reset_parameters(self):
|
195 |
-
"""Reset all the weights, even including pretrained ones."""
|
196 |
-
nn.Linear.reset_parameters(self)
|
197 |
-
if hasattr(self, 'lora_A'):
|
198 |
-
# initialize A the same way as the default for nn.Linear and B to zero
|
199 |
-
# Wondering why 'a' is equal to math.sqrt(5)?: https://github.com/pytorch/pytorch/issues/15314
|
200 |
-
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
|
201 |
-
nn.init.zeros_(self.lora_B)
|
202 |
-
|
203 |
-
def zero_pad(self, x: torch.Tensor) -> torch.Tensor:
|
204 |
-
"""Properly pad weight updates with zeros.
|
205 |
-
|
206 |
-
If, based on `self.enable_lora`, we want to fine-tune queries and values, but not keys,
|
207 |
-
then the weights update should be:
|
208 |
-
|
209 |
-
[[ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW,],
|
210 |
-
[....................................],
|
211 |
-
[ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW,]]
|
212 |
-
↑ ↑ ↑
|
213 |
-
________________________________________
|
214 |
-
| query | key | value |
|
215 |
-
----------------------------------------
|
216 |
-
|
217 |
-
Args:
|
218 |
-
x: tensor with weights update that will be padded with zeros if necessary
|
219 |
-
|
220 |
-
Returns:
|
221 |
-
A tensor with weight updates and zeros for deselected q, k or v
|
222 |
-
"""
|
223 |
-
# Let's image that:
|
224 |
-
# ⚬ input x has shape (64, 64, 256): (batch_size, sequence_length, embeddings_size)
|
225 |
-
# ⚬ embeddings_size: 128
|
226 |
-
# ⚬ self.out_features: 384 (3 * embeddings_size)
|
227 |
-
# ⚬ enable_lora: [True, False, True]
|
228 |
-
# Then x has embeddings_size of 256 (2 * 128 as enable_lora only for query and value, not keys) and expected
|
229 |
-
# embeddings_size is 384 (self.out_features), so that means that we need to pad from 256 to 384 with zeros, but
|
230 |
-
# only for key updates (this is where self.lora_ind comes in handy)
|
231 |
-
# Note: double transpose (in the beginning and in the end) is basically a guard for two-dimensional tensors
|
232 |
-
# for example when we want to merge/unmerge LoRA weights and pretrained weights
|
233 |
-
x = x.transpose(0, 1)
|
234 |
-
result = x.new_zeros((*x.shape[:-1], self.out_features)) # (64, 64, 384)
|
235 |
-
result = result.view(-1, self.out_features) # (4096, 384)
|
236 |
-
result[:, self.lora_ind] = x.reshape(
|
237 |
-
-1, self.out_features // len(self.enable_lora) * sum(self.enable_lora)
|
238 |
-
) # (4096, 256)
|
239 |
-
return result.view((*x.shape[:-1], self.out_features)).transpose(0, 1) # (64, 64, 384)
|
240 |
-
|
241 |
-
def train(self, mode: bool = True):
|
242 |
-
"""Set the module into train or eval mode if `mode` is True of False respectively.
|
243 |
-
|
244 |
-
For train mode (train(True)) if weights are merged we need to subtract weights updates (LoRA_A @ LoRA_B) from
|
245 |
-
pretrained weights so we can continue training LoRA's matrices A and B and keep pretrained weights frozen.
|
246 |
-
|
247 |
-
For eval mode (train(False)) if weights are not merged we need to add weight updates to pretrained weights in
|
248 |
-
order to reduce computational overhead during inference.
|
249 |
-
|
250 |
-
Args:
|
251 |
-
mode: if True the module will be set into train mode (affects Dropout and BatchNorm), if False - eval mode.
|
252 |
-
|
253 |
-
"""
|
254 |
-
def T(w):
|
255 |
-
return w.T if self.fan_in_fan_out else w
|
256 |
-
# despite being called from nn.Linear this method will put all layers into train mode, including nn.Dropout
|
257 |
-
# of course except parameters (such as self.lora_A, self.lora_B)
|
258 |
-
nn.Linear.train(self, mode)
|
259 |
-
|
260 |
-
# if train(True) -> unmerge unless we already have them unmerged
|
261 |
-
# if train(False) -> merge unless we already have them merged
|
262 |
-
should = self.merged if mode else not self.merged
|
263 |
-
|
264 |
-
# Let's assume that:
|
265 |
-
# ⚬ self.weight.data: (384, 128) or (3 * embedding_size, embedding_size)
|
266 |
-
# ⚬ self.lora_A.data: (4, 128)
|
267 |
-
# ⚬ self.lora_B.data: (256, 2)
|
268 |
-
if self.merge_weights and should:
|
269 |
-
if self.r > 0 and any(self.enable_lora):
|
270 |
-
delta_w = F.conv1d(
|
271 |
-
self.lora_A.data.unsqueeze(0), # (4, 128) -> (1, 4, 128)
|
272 |
-
self.lora_B.data.unsqueeze(-1), # (256, 2) -> (256, 2, 1)
|
273 |
-
groups=sum(self.enable_lora)
|
274 |
-
).squeeze(0) # (1, 4, 128) @ (256, 2, 1) -> (1, 256, 128) -> (256, 128)
|
275 |
-
# -1: W = W - delta_W (unmerge), +1: W = W + delta_W (merge)
|
276 |
-
sign = -1 if mode else 1
|
277 |
-
self.weight.data += sign * self.zero_pad(T(delta_w * self.scaling)) # (256, 128) after zero_pad (384, 128)
|
278 |
-
self.merged = not mode
|
279 |
-
|
280 |
-
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
281 |
-
"""Do the forward pass.
|
282 |
-
|
283 |
-
If LoRA's weights are merged with pretrained ones then it's a simple matrix multiplication.
|
284 |
-
If not, then multiply pretrained weights with input, apply LoRA on input and do summation.
|
285 |
-
|
286 |
-
Args:
|
287 |
-
x: input tensor of shape (batch_size, context_length, embedding_size)
|
288 |
-
|
289 |
-
Returns:
|
290 |
-
Output tensor of shape (batch_size, context_length, 3 * embedding_size)
|
291 |
-
"""
|
292 |
-
def T(w):
|
293 |
-
return w.T if self.fan_in_fan_out else w
|
294 |
-
|
295 |
-
# Let's assume that:
|
296 |
-
# ⚬ x: (64, 64, 128) or (batch_size, context_length, embedding_size)
|
297 |
-
# ⚬ self.weight: (384, 128) or (3 * embedding_size, embedding_size)
|
298 |
-
# ⚬ self.lora_A.data: (4, 128)
|
299 |
-
# ⚬ self.lora_B.data: (256, 2)
|
300 |
-
|
301 |
-
# the logic here is that the weights are merged only during inference
|
302 |
-
# so if they are merged we don't need to do anything with LoRA's A and B matrices
|
303 |
-
# but if the weights are not merged that means that the forward method is called during
|
304 |
-
# training and we need to forward pass input through pretrained weights, LoRA A and B matrices
|
305 |
-
# and do the summation (as per scheme at the top of the file)
|
306 |
-
if self.merged:
|
307 |
-
return F.linear(x, T(self.weight), bias=self.bias)
|
308 |
-
else:
|
309 |
-
# `F.linear` automatically transposes the second argument (T(self.weight) in our case)
|
310 |
-
result = F.linear(x, T(self.weight), bias=self.bias) # (64, 64, 128) @ (384, 128) -> (64, 64, 384)
|
311 |
-
if self.r > 0:
|
312 |
-
after_A = F.linear(self.lora_dropout(x), self.lora_A) # (64, 64, 128) @ (4, 128) -> (64, 64, 4)
|
313 |
-
# For F.conv1d:
|
314 |
-
# ⚬ input: input tensor of shape (mini-batch, in_channels, iW)
|
315 |
-
# ⚬ weight: filters of shape (out_channels, in_channels/groups, kW)
|
316 |
-
# ⚬ groups: split input into groups, in_channels should be divisible by the number of groups. Default: 1
|
317 |
-
# presumably iW - sequence width/length, kW - kernel width
|
318 |
-
after_B = F.conv1d(
|
319 |
-
after_A.transpose(-2, -1), # (64, 64, 4) -> (64, 4, 64)
|
320 |
-
self.lora_B.unsqueeze(-1), # (256, 2) -> (256, 2, 1)
|
321 |
-
groups=sum(self.enable_lora)
|
322 |
-
).transpose(-2, -1) # (64, 4, 64) @ (256, 2, 1) -> (64, 256, 64) -> (64, 64, 256)
|
323 |
-
result += self.zero_pad(after_B) * self.scaling # (64, 64, 256) after zero_pad (64, 64, 384)
|
324 |
-
return result
|
325 |
-
|
326 |
-
|
327 |
-
def mark_only_lora_as_trainable(model: nn.Module, bias: str = 'none') -> None:
|
328 |
-
"""Freeze all modules except LoRA's and depending on 'bias' value unfreezes bias weights.
|
329 |
-
|
330 |
-
Args:
|
331 |
-
model: model with LoRA layers
|
332 |
-
bias:
|
333 |
-
``"none"``: all bias weights will be frozen,
|
334 |
-
``"lora_only"``: only bias weight for LoRA layers will be unfrozen,
|
335 |
-
``"all"``: all bias weights will be unfrozen.
|
336 |
-
|
337 |
-
Raises:
|
338 |
-
NotImplementedError: if `bias` not in ["none", "lora_only", "all"]
|
339 |
-
"""
|
340 |
-
# freeze all layers except LoRA's
|
341 |
-
for n, p in model.named_parameters():
|
342 |
-
if 'lora_' not in n:
|
343 |
-
p.requires_grad = False
|
344 |
-
|
345 |
-
# depending on the `bias` value unfreeze bias weights
|
346 |
-
if bias == 'none':
|
347 |
-
return
|
348 |
-
elif bias == 'all':
|
349 |
-
for n, p in model.named_parameters():
|
350 |
-
if 'bias' in n:
|
351 |
-
p.requires_grad = True
|
352 |
-
elif bias == 'lora_only':
|
353 |
-
for m in model.modules():
|
354 |
-
if isinstance(m, LoRALayer) and \
|
355 |
-
hasattr(m, 'bias') and \
|
356 |
-
m.bias is not None:
|
357 |
-
m.bias.requires_grad = True
|
358 |
-
else:
|
359 |
-
raise NotImplementedError
|
360 |
-
|
361 |
-
|
362 |
-
def lora_state_dict(model: nn.Module, bias: str = 'none') -> Dict[str, torch.Tensor]:
|
363 |
-
"""Return state_dict with weights of LoRA's A and B matrices and with biases depending on the `bias` value.
|
364 |
-
|
365 |
-
Args:
|
366 |
-
model: model with LoRA layers
|
367 |
-
bias:
|
368 |
-
``"none"``: state dict will not store bias weights,
|
369 |
-
``"lora_only"``: state dict will store bias weights only from LoRA layers,
|
370 |
-
``"all"``: state dict will store all bias weights.
|
371 |
-
|
372 |
-
Returns:
|
373 |
-
Weights and biases of LoRA layers
|
374 |
-
|
375 |
-
Raises:
|
376 |
-
NotImplementedError: if `bias` not in ["none", "lora_only", "all"]
|
377 |
-
"""
|
378 |
-
my_state_dict = model.state_dict()
|
379 |
-
if bias == 'none':
|
380 |
-
return {k: my_state_dict[k] for k in my_state_dict if 'lora_' in k}
|
381 |
-
elif bias == 'all':
|
382 |
-
return {k: my_state_dict[k] for k in my_state_dict if 'lora_' in k or 'bias' in k}
|
383 |
-
elif bias == 'lora_only':
|
384 |
-
to_return = {}
|
385 |
-
for k in my_state_dict:
|
386 |
-
if 'lora_' in k:
|
387 |
-
to_return[k] = my_state_dict[k]
|
388 |
-
bias_name = k.split('lora_')[0]+'bias'
|
389 |
-
if bias_name in my_state_dict:
|
390 |
-
to_return[bias_name] = my_state_dict[bias_name]
|
391 |
-
return to_return
|
392 |
-
else:
|
393 |
-
raise NotImplementedError
|
394 |
-
|
395 |
-
|
396 |
-
@dataclass
|
397 |
-
class LoRAConfig:
|
398 |
-
r: float = 0.0
|
399 |
-
alpha: float = 1.0
|
400 |
-
dropout: float = 0.0
|
401 |
-
|
402 |
-
|
403 |
-
class CausalSelfAttention(llama.CausalSelfAttention):
|
404 |
-
lora_config = None
|
405 |
-
|
406 |
-
def __init__(self, config: llama.LLaMAConfig) -> None:
|
407 |
-
"""Causal self-attention with calculating qkv matrices with a single matrix* and Low Ranking Adaptation for
|
408 |
-
parameter-efficient fine-tuning.
|
409 |
-
|
410 |
-
*Instead of creating multiple heads and concatenating the result (in addition to creating separate matrices for
|
411 |
-
query, key and value for each head) we can do this in a single pass with a single weight matrix.
|
412 |
-
|
413 |
-
Args:
|
414 |
-
config:
|
415 |
-
``"block_size"``: size of the context of the model,
|
416 |
-
``"vocab_size"``: number of unique tokens,
|
417 |
-
``"padded_vocab_size"``: padded size of the vocabulary to the nearest multiple of 64 (leads to a greater performance),
|
418 |
-
``"n_layer"``: number of transformer blocks (self-attention + MLP),
|
419 |
-
``"n_head"``: number of heads in multi-head attention mechanism,
|
420 |
-
``"n_embd"``: size of the embedding: vector representation of each token.
|
421 |
-
"""
|
422 |
-
# Skip the parent class __init__ altogether and replace it to avoid
|
423 |
-
# useless allocations
|
424 |
-
nn.Module.__init__(self)
|
425 |
-
assert config.n_embd % config.n_head == 0
|
426 |
-
|
427 |
-
# key, query, value projections for all heads, but in a batch
|
428 |
-
self.c_attn = MergedLinear(
|
429 |
-
in_features=config.n_embd,
|
430 |
-
out_features=3 * config.n_embd,
|
431 |
-
r=self.lora_config.r,
|
432 |
-
lora_alpha=self.lora_config.alpha,
|
433 |
-
lora_dropout=self.lora_config.dropout,
|
434 |
-
enable_lora=[True, False, True],
|
435 |
-
fan_in_fan_out = False,
|
436 |
-
merge_weights=True,
|
437 |
-
bias=False)
|
438 |
-
# output projection
|
439 |
-
self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
|
440 |
-
# regularization
|
441 |
-
self.n_head = config.n_head
|
442 |
-
self.n_embd = config.n_embd
|
443 |
-
self.block_size = config.block_size
|
444 |
-
self.rope_cache = None
|
445 |
-
|
446 |
-
|
447 |
-
@contextmanager
|
448 |
-
def lora(r, alpha, dropout, enabled: bool = True):
|
449 |
-
"""Apply context manager under which you can instantiate the model with LoRA.
|
450 |
-
|
451 |
-
In a nutshell the code inside this function forces to use LoRA variant of causal self-attention
|
452 |
-
instead of the original one (without LoRA).
|
453 |
-
|
454 |
-
Args:
|
455 |
-
r: rank of the weight update matrices. To make sense of using LoRA the rank should be smaller than the rank of
|
456 |
-
the weights of the model. The rank can be as low as 1: https://arxiv.org/pdf/2106.09685.pdf (section 7.2)
|
457 |
-
alpha: alpha is needed for scaling updates as alpha/r
|
458 |
-
"This scaling helps to reduce the need to retune hyperparameters when we vary r"
|
459 |
-
https://arxiv.org/pdf/2106.09685.pdf (section 4.1)
|
460 |
-
dropout: dropout that is applied on the input in the LoRA branch (before multiplying by matrix A)
|
461 |
-
enabled: enables/disables LoRA
|
462 |
-
"""
|
463 |
-
if not enabled:
|
464 |
-
yield
|
465 |
-
return
|
466 |
-
|
467 |
-
CausalSelfAttention.lora_config = LoRAConfig(r=r, alpha=alpha, dropout=dropout)
|
468 |
-
# when entering context manager replace link to causal self-attention class from original
|
469 |
-
# to a variant with LoRA
|
470 |
-
causal_self_attention = llama.CausalSelfAttention
|
471 |
-
llama.CausalSelfAttention = CausalSelfAttention
|
472 |
-
yield
|
473 |
-
# when exiting context manager - restore link to original causal self-attention class
|
474 |
-
llama.CausalSelfAttention = causal_self_attention
|
475 |
-
|
476 |
-
CausalSelfAttention.lora_config = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lit-llama/lit_llama/model.py
DELETED
@@ -1,321 +0,0 @@
|
|
1 |
-
"""Full definition of a LLaMA Language Model, all of it in this single file.
|
2 |
-
|
3 |
-
Based on the nanoGPT implementation: https://github.com/karpathy/nanoGPT.
|
4 |
-
"""
|
5 |
-
# mypy: ignore-errors
|
6 |
-
import math
|
7 |
-
from dataclasses import dataclass
|
8 |
-
from typing import List, Optional, Tuple, Union
|
9 |
-
|
10 |
-
import torch
|
11 |
-
import torch.nn as nn
|
12 |
-
from torch.nn import functional as F
|
13 |
-
from typing_extensions import Self
|
14 |
-
|
15 |
-
from lit_llama.utils import find_multiple
|
16 |
-
|
17 |
-
|
18 |
-
MaskCache = torch.Tensor
|
19 |
-
RoPECache = torch.Tensor
|
20 |
-
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
21 |
-
|
22 |
-
|
23 |
-
@dataclass
|
24 |
-
class LLaMAConfig:
|
25 |
-
block_size: int = 2048
|
26 |
-
vocab_size: int = 32000
|
27 |
-
padded_vocab_size: Optional[int] = None
|
28 |
-
n_layer: int = 32
|
29 |
-
n_head: int = 32
|
30 |
-
n_embd: int = 4096
|
31 |
-
|
32 |
-
def __post_init__(self):
|
33 |
-
if self.padded_vocab_size is None:
|
34 |
-
self.padded_vocab_size = find_multiple(self.vocab_size, 64)
|
35 |
-
|
36 |
-
@classmethod
|
37 |
-
def from_name(cls, name: str) -> Self:
|
38 |
-
return cls(**llama_configs[name])
|
39 |
-
|
40 |
-
|
41 |
-
llama_configs = {
|
42 |
-
"7B": dict(n_layer=32, n_head=32, n_embd=4096),
|
43 |
-
"13B": dict(n_layer=40, n_head=40, n_embd=5120),
|
44 |
-
"30B": dict(n_layer=60, n_head=52, n_embd=6656),
|
45 |
-
"65B": dict(n_layer=80, n_head=64, n_embd=8192),
|
46 |
-
}
|
47 |
-
|
48 |
-
|
49 |
-
class LLaMA(nn.Module):
|
50 |
-
def __init__(self, config: LLaMAConfig) -> None:
|
51 |
-
super().__init__()
|
52 |
-
assert config.padded_vocab_size is not None
|
53 |
-
self.config = config
|
54 |
-
|
55 |
-
self.lm_head = nn.Linear(config.n_embd, config.padded_vocab_size, bias=False)
|
56 |
-
self.transformer = nn.ModuleDict(
|
57 |
-
dict(
|
58 |
-
wte=nn.Embedding(config.padded_vocab_size, config.n_embd),
|
59 |
-
h=nn.ModuleList(Block(config) for _ in range(config.n_layer)),
|
60 |
-
ln_f=RMSNorm(config.n_embd),
|
61 |
-
)
|
62 |
-
)
|
63 |
-
|
64 |
-
self.rope_cache: Optional[RoPECache] = None
|
65 |
-
self.mask_cache: Optional[MaskCache] = None
|
66 |
-
self.kv_caches: List[KVCache] = []
|
67 |
-
|
68 |
-
def _init_weights(self, module: nn.Module) -> None:
|
69 |
-
if isinstance(module, nn.Linear):
|
70 |
-
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02 / math.sqrt(2 * self.config.n_layer))
|
71 |
-
elif isinstance(module, nn.Embedding):
|
72 |
-
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02 / math.sqrt(2 * self.config.n_layer))
|
73 |
-
|
74 |
-
def forward(
|
75 |
-
self, idx: torch.Tensor, max_seq_length: Optional[int] = None, input_pos: Optional[torch.Tensor] = None
|
76 |
-
) -> Union[torch.Tensor, Tuple[torch.Tensor, List[KVCache]]]:
|
77 |
-
B, T = idx.size()
|
78 |
-
|
79 |
-
block_size = self.config.block_size
|
80 |
-
if max_seq_length is None:
|
81 |
-
max_seq_length = block_size
|
82 |
-
assert T <= max_seq_length, f"Cannot forward sequence of length {T}, max seq length is only {max_seq_length}"
|
83 |
-
assert max_seq_length <= block_size, f"Cannot attend to {max_seq_length}, block size is only {block_size}"
|
84 |
-
assert T <= block_size, f"Cannot forward sequence of length {T}, block size is only {block_size}"
|
85 |
-
|
86 |
-
if self.rope_cache is None:
|
87 |
-
self.rope_cache = self.build_rope_cache(idx)
|
88 |
-
if self.mask_cache is None:
|
89 |
-
self.mask_cache = self.build_mask_cache(idx)
|
90 |
-
|
91 |
-
if input_pos is not None:
|
92 |
-
rope = self.rope_cache.index_select(0, input_pos)
|
93 |
-
mask = self.mask_cache.index_select(2, input_pos)
|
94 |
-
mask = mask[:, :, :, :max_seq_length]
|
95 |
-
else:
|
96 |
-
rope = self.rope_cache[:T]
|
97 |
-
mask = self.mask_cache[:, :, :T, :T]
|
98 |
-
|
99 |
-
# forward the model itself
|
100 |
-
x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
|
101 |
-
|
102 |
-
if input_pos is None: # proxy for use_cache=False
|
103 |
-
for block in self.transformer.h:
|
104 |
-
x, _ = block(x, rope, mask, max_seq_length)
|
105 |
-
else:
|
106 |
-
if not self.kv_caches:
|
107 |
-
head_size = self.config.n_embd // self.config.n_head
|
108 |
-
cache_shape = (B, self.config.n_head, max_seq_length, head_size)
|
109 |
-
self.kv_caches = [
|
110 |
-
(torch.zeros(cache_shape, device=x.device, dtype=x.dtype), torch.zeros(cache_shape, device=x.device, dtype=x.dtype))
|
111 |
-
for _ in range(self.config.n_layer)
|
112 |
-
]
|
113 |
-
for i, block in enumerate(self.transformer.h):
|
114 |
-
x, self.kv_caches[i] = block(x, rope, mask, max_seq_length, input_pos, self.kv_caches[i])
|
115 |
-
|
116 |
-
x = self.transformer.ln_f(x)
|
117 |
-
|
118 |
-
logits = self.lm_head(x) # (b, t, vocab_size)
|
119 |
-
|
120 |
-
return logits
|
121 |
-
|
122 |
-
@classmethod
|
123 |
-
def from_name(cls, name: str) -> Self:
|
124 |
-
return cls(LLaMAConfig.from_name(name))
|
125 |
-
|
126 |
-
def build_rope_cache(self, idx: torch.Tensor) -> RoPECache:
|
127 |
-
return build_rope_cache(
|
128 |
-
seq_len=self.config.block_size,
|
129 |
-
n_elem=self.config.n_embd // self.config.n_head,
|
130 |
-
dtype=idx.dtype,
|
131 |
-
device=idx.device,
|
132 |
-
)
|
133 |
-
|
134 |
-
def build_mask_cache(self, idx: torch.Tensor) -> MaskCache:
|
135 |
-
ones = torch.ones((self.config.block_size, self.config.block_size), device=idx.device, dtype=torch.bool)
|
136 |
-
return torch.tril(ones).unsqueeze(0).unsqueeze(0)
|
137 |
-
|
138 |
-
def reset_cache(self) -> None:
|
139 |
-
self.kv_caches.clear()
|
140 |
-
if self.mask_cache.device.type == "xla":
|
141 |
-
# https://github.com/Lightning-AI/lit-parrot/pull/83#issuecomment-1558150179
|
142 |
-
self.rope_cache = None
|
143 |
-
self.mask_cache = None
|
144 |
-
|
145 |
-
|
146 |
-
class Block(nn.Module):
|
147 |
-
def __init__(self, config: LLaMAConfig) -> None:
|
148 |
-
super().__init__()
|
149 |
-
self.rms_1 = RMSNorm(config.n_embd)
|
150 |
-
self.attn = CausalSelfAttention(config)
|
151 |
-
self.rms_2 = RMSNorm(config.n_embd)
|
152 |
-
self.mlp = MLP(config)
|
153 |
-
|
154 |
-
def forward(
|
155 |
-
self,
|
156 |
-
x: torch.Tensor,
|
157 |
-
rope: RoPECache,
|
158 |
-
mask: MaskCache,
|
159 |
-
max_seq_length: int,
|
160 |
-
input_pos: Optional[torch.Tensor] = None,
|
161 |
-
kv_cache: Optional[KVCache] = None,
|
162 |
-
) -> Tuple[torch.Tensor, Optional[KVCache]]:
|
163 |
-
h, new_kv_cache = self.attn(self.rms_1(x), rope, mask, max_seq_length, input_pos, kv_cache)
|
164 |
-
x = x + h
|
165 |
-
x = x + self.mlp(self.rms_2(x))
|
166 |
-
return x, new_kv_cache
|
167 |
-
|
168 |
-
|
169 |
-
class CausalSelfAttention(nn.Module):
|
170 |
-
def __init__(self, config: LLaMAConfig) -> None:
|
171 |
-
super().__init__()
|
172 |
-
assert config.n_embd % config.n_head == 0
|
173 |
-
|
174 |
-
# key, query, value projections for all heads, but in a batch
|
175 |
-
self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=False)
|
176 |
-
# output projection
|
177 |
-
self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
|
178 |
-
|
179 |
-
self.n_head = config.n_head
|
180 |
-
self.n_embd = config.n_embd
|
181 |
-
self.block_size = config.block_size
|
182 |
-
|
183 |
-
def forward(
|
184 |
-
self,
|
185 |
-
x: torch.Tensor,
|
186 |
-
rope: RoPECache,
|
187 |
-
mask: MaskCache,
|
188 |
-
max_seq_length: int,
|
189 |
-
input_pos: Optional[torch.Tensor] = None,
|
190 |
-
kv_cache: Optional[KVCache] = None,
|
191 |
-
) -> Tuple[torch.Tensor, Optional[KVCache]]:
|
192 |
-
B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
|
193 |
-
|
194 |
-
# calculate query, key, values for all heads in batch and move head forward to be the batch dim
|
195 |
-
q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
|
196 |
-
|
197 |
-
head_size = C // self.n_head
|
198 |
-
k = k.view(B, T, self.n_head, head_size)
|
199 |
-
q = q.view(B, T, self.n_head, head_size)
|
200 |
-
v = v.view(B, T, self.n_head, head_size)
|
201 |
-
|
202 |
-
q = apply_rope(q, rope)
|
203 |
-
k = apply_rope(k, rope)
|
204 |
-
|
205 |
-
k = k.transpose(1, 2) # (B, nh, T, hs)
|
206 |
-
q = q.transpose(1, 2) # (B, nh, T, hs)
|
207 |
-
v = v.transpose(1, 2) # (B, nh, T, hs)
|
208 |
-
|
209 |
-
if kv_cache is not None:
|
210 |
-
cache_k, cache_v = kv_cache
|
211 |
-
# check if reached token limit
|
212 |
-
if input_pos[-1] >= max_seq_length:
|
213 |
-
input_pos = torch.tensor(max_seq_length - 1, device=input_pos.device)
|
214 |
-
# shift 1 position to the left
|
215 |
-
cache_k = torch.roll(cache_k, -1, dims=2)
|
216 |
-
cache_v = torch.roll(cache_v, -1, dims=2)
|
217 |
-
k = cache_k.index_copy(2, input_pos, k)
|
218 |
-
v = cache_v.index_copy(2, input_pos, v)
|
219 |
-
kv_cache = k, v
|
220 |
-
|
221 |
-
# causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
|
222 |
-
# att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
|
223 |
-
# att = att.masked_fill(mask[:,:,:T,:T] == 0, float('-inf'))
|
224 |
-
# att = F.softmax(att, dim=-1)
|
225 |
-
# y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
|
226 |
-
|
227 |
-
# efficient attention using Flash Attention CUDA kernels
|
228 |
-
y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
|
229 |
-
|
230 |
-
y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
|
231 |
-
|
232 |
-
# output projection
|
233 |
-
y = self.c_proj(y)
|
234 |
-
|
235 |
-
return y, kv_cache
|
236 |
-
|
237 |
-
|
238 |
-
class MLP(nn.Module):
|
239 |
-
def __init__(self, config: LLaMAConfig) -> None:
|
240 |
-
super().__init__()
|
241 |
-
hidden_dim = 4 * config.n_embd
|
242 |
-
n_hidden = int(2 * hidden_dim / 3)
|
243 |
-
n_hidden = find_multiple(n_hidden, 256)
|
244 |
-
|
245 |
-
self.c_fc1 = nn.Linear(config.n_embd, n_hidden, bias=False)
|
246 |
-
self.c_fc2 = nn.Linear(config.n_embd, n_hidden, bias=False)
|
247 |
-
self.c_proj = nn.Linear(n_hidden, config.n_embd, bias=False)
|
248 |
-
|
249 |
-
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
250 |
-
x = F.silu(self.c_fc1(x)) * self.c_fc2(x)
|
251 |
-
x = self.c_proj(x)
|
252 |
-
return x
|
253 |
-
|
254 |
-
|
255 |
-
class RMSNorm(nn.Module):
|
256 |
-
"""Root Mean Square Layer Normalization.
|
257 |
-
|
258 |
-
Derived from https://github.com/bzhangGo/rmsnorm/blob/master/rmsnorm_torch.py. BSD 3-Clause License:
|
259 |
-
https://github.com/bzhangGo/rmsnorm/blob/master/LICENSE.
|
260 |
-
"""
|
261 |
-
|
262 |
-
def __init__(self, size: int, dim: int = -1, eps: float = 1e-5) -> None:
|
263 |
-
super().__init__()
|
264 |
-
self.scale = nn.Parameter(torch.ones(size))
|
265 |
-
self.eps = eps
|
266 |
-
self.dim = dim
|
267 |
-
|
268 |
-
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
269 |
-
# NOTE: the original RMSNorm paper implementation is not equivalent
|
270 |
-
# norm_x = x.norm(2, dim=self.dim, keepdim=True)
|
271 |
-
# rms_x = norm_x * d_x ** (-1. / 2)
|
272 |
-
# x_normed = x / (rms_x + self.eps)
|
273 |
-
norm_x = torch.mean(x * x, dim=self.dim, keepdim=True)
|
274 |
-
x_normed = x * torch.rsqrt(norm_x + self.eps)
|
275 |
-
return self.scale * x_normed
|
276 |
-
|
277 |
-
|
278 |
-
def build_rope_cache(
|
279 |
-
seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000
|
280 |
-
) -> RoPECache:
|
281 |
-
"""Enhanced Transformer with Rotary Position Embedding.
|
282 |
-
|
283 |
-
Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
|
284 |
-
transformers/rope/__init__.py. MIT License:
|
285 |
-
https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
|
286 |
-
"""
|
287 |
-
# $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
|
288 |
-
theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=dtype, device=device) / n_elem))
|
289 |
-
|
290 |
-
# Create position indexes `[0, 1, ..., seq_len - 1]`
|
291 |
-
seq_idx = torch.arange(seq_len, dtype=dtype, device=device)
|
292 |
-
|
293 |
-
# Calculate the product of position index and $\theta_i$
|
294 |
-
idx_theta = torch.outer(seq_idx, theta).float()
|
295 |
-
|
296 |
-
cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1)
|
297 |
-
|
298 |
-
# this is to mimic the behaviour of complex32, else we will get different results
|
299 |
-
if dtype in (torch.float16, torch.bfloat16, torch.int8):
|
300 |
-
cache = cache.half()
|
301 |
-
return cache
|
302 |
-
|
303 |
-
|
304 |
-
def apply_rope(x: torch.Tensor, rope_cache: RoPECache) -> torch.Tensor:
|
305 |
-
# truncate to support variable sizes
|
306 |
-
T = x.size(1)
|
307 |
-
rope_cache = rope_cache[:T]
|
308 |
-
|
309 |
-
# cast because the reference does
|
310 |
-
xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
|
311 |
-
rope_cache = rope_cache.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
|
312 |
-
x_out2 = torch.stack(
|
313 |
-
[
|
314 |
-
xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1],
|
315 |
-
xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1],
|
316 |
-
],
|
317 |
-
-1,
|
318 |
-
)
|
319 |
-
|
320 |
-
x_out2 = x_out2.flatten(3)
|
321 |
-
return x_out2.type_as(x)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lit-llama/lit_llama/packed_dataset.py
DELETED
@@ -1,260 +0,0 @@
|
|
1 |
-
# Very loosely inspired by indexed_dataset in Fairseq, Megatron
|
2 |
-
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/data/indexed_dataset.py
|
3 |
-
|
4 |
-
|
5 |
-
import os
|
6 |
-
import struct
|
7 |
-
import random
|
8 |
-
|
9 |
-
import numpy as np
|
10 |
-
import torch
|
11 |
-
from torch.utils.data import IterableDataset, get_worker_info
|
12 |
-
|
13 |
-
|
14 |
-
dtypes = {
|
15 |
-
1: np.uint8,
|
16 |
-
2: np.int8,
|
17 |
-
3: np.int16,
|
18 |
-
4: np.int32,
|
19 |
-
5: np.int64,
|
20 |
-
6: np.float32,
|
21 |
-
7: np.float64,
|
22 |
-
8: np.uint16,
|
23 |
-
}
|
24 |
-
|
25 |
-
|
26 |
-
def code(dtype):
|
27 |
-
for k in dtypes.keys():
|
28 |
-
if dtypes[k] == dtype:
|
29 |
-
return k
|
30 |
-
raise ValueError(dtype)
|
31 |
-
|
32 |
-
|
33 |
-
HDR_MAGIC = b"LITPKDS"
|
34 |
-
HDR_SIZE = 24 # bytes
|
35 |
-
|
36 |
-
|
37 |
-
class PackedDataset(IterableDataset):
|
38 |
-
def __init__(self, filenames, n_chunks, block_size, seed=12345, shuffle=True, wrap=False, num_processes=1, process_rank=0):
|
39 |
-
self._filenames = filenames
|
40 |
-
self._n_chunks = n_chunks
|
41 |
-
self._block_size = block_size
|
42 |
-
self._seed = seed
|
43 |
-
self._shuffle = shuffle
|
44 |
-
self._wrap = wrap
|
45 |
-
self._num_processes = num_processes
|
46 |
-
self._process_rank = process_rank
|
47 |
-
|
48 |
-
def __iter__(self):
|
49 |
-
worker_info = get_worker_info()
|
50 |
-
num_workers = worker_info.num_workers if worker_info is not None else 1
|
51 |
-
worker_id = worker_info.id if worker_info is not None else 0
|
52 |
-
num_shards = num_workers * self._num_processes
|
53 |
-
shard_id = self._process_rank * num_workers + worker_id
|
54 |
-
|
55 |
-
max_num_files = len(self._filenames) // num_shards * num_shards
|
56 |
-
filenames = self._filenames[shard_id : max_num_files : num_shards]
|
57 |
-
|
58 |
-
return PackedDatasetIterator(
|
59 |
-
filenames=filenames,
|
60 |
-
n_chunks=self._n_chunks,
|
61 |
-
block_size=self._block_size,
|
62 |
-
seed=self._seed,
|
63 |
-
shuffle=self._shuffle,
|
64 |
-
wrap=self._wrap,
|
65 |
-
)
|
66 |
-
|
67 |
-
|
68 |
-
class PackedDatasetBuilder(object):
|
69 |
-
def __init__(
|
70 |
-
self,
|
71 |
-
outdir,
|
72 |
-
prefix,
|
73 |
-
chunk_size,
|
74 |
-
sep_token,
|
75 |
-
dtype="auto",
|
76 |
-
vocab_size=None,
|
77 |
-
):
|
78 |
-
if dtype == "auto":
|
79 |
-
if vocab_size is None:
|
80 |
-
raise ValueError("vocab_size cannot be None when dtype='auto'")
|
81 |
-
if vocab_size is not None and vocab_size < 65500:
|
82 |
-
self._dtype = np.uint16
|
83 |
-
else:
|
84 |
-
self._dtype = np.int32
|
85 |
-
else:
|
86 |
-
self._dtype = dtype
|
87 |
-
self._counter = 0
|
88 |
-
self._chunk_size = chunk_size
|
89 |
-
self._outdir = outdir
|
90 |
-
self._prefix = prefix
|
91 |
-
self._sep_token = sep_token
|
92 |
-
self._arr = np.zeros(self._chunk_size, dtype=self._dtype)
|
93 |
-
self._arr.fill(self._sep_token)
|
94 |
-
self._idx = 0
|
95 |
-
self._version = 1
|
96 |
-
self._filenames = []
|
97 |
-
|
98 |
-
def _write_chunk(self):
|
99 |
-
filename = f"{self._prefix}_{self._counter:010d}.bin"
|
100 |
-
filename = os.path.join(self._outdir, filename)
|
101 |
-
|
102 |
-
with open(filename, "wb") as f:
|
103 |
-
f.write(HDR_MAGIC)
|
104 |
-
f.write(struct.pack("<Q", self._version))
|
105 |
-
f.write(struct.pack("<B", code(self._dtype)))
|
106 |
-
f.write(struct.pack("<Q", self._chunk_size))
|
107 |
-
f.write(self._arr.tobytes(order="C"))
|
108 |
-
|
109 |
-
self._filenames.append(filename)
|
110 |
-
self._counter += 1
|
111 |
-
self._arr.fill(self._sep_token)
|
112 |
-
self._idx = 0
|
113 |
-
|
114 |
-
@property
|
115 |
-
def dtype(self):
|
116 |
-
return self._dtype
|
117 |
-
|
118 |
-
@property
|
119 |
-
def filenames(self):
|
120 |
-
return self._filenames.copy()
|
121 |
-
|
122 |
-
def add_array(self, arr):
|
123 |
-
while self._idx + arr.shape[0] > self._chunk_size:
|
124 |
-
part_len = self._chunk_size - self._idx
|
125 |
-
self._arr[self._idx : self._idx + part_len] = arr[:part_len]
|
126 |
-
self._write_chunk()
|
127 |
-
arr = arr[part_len:]
|
128 |
-
|
129 |
-
arr_len = arr.shape[0]
|
130 |
-
self._arr[self._idx : self._idx + arr_len] = arr
|
131 |
-
self._idx += arr_len
|
132 |
-
|
133 |
-
def write_reminder(self):
|
134 |
-
self._write_chunk()
|
135 |
-
|
136 |
-
|
137 |
-
class PackedDatasetIterator:
|
138 |
-
def __init__(self, filenames, n_chunks, block_size, seed, shuffle, wrap):
|
139 |
-
self._seed = seed
|
140 |
-
self._shuffle = shuffle
|
141 |
-
self._rng = np.random.default_rng(seed) if shuffle else None
|
142 |
-
self._block_idxs = None
|
143 |
-
|
144 |
-
self._wrap = wrap
|
145 |
-
|
146 |
-
# TODO: instead of filenames, we could have a single text stream
|
147 |
-
# (or text file) with the sequence of all files to be
|
148 |
-
# fetched/loaded.
|
149 |
-
self._filenames = filenames
|
150 |
-
self._file_idx = 0
|
151 |
-
|
152 |
-
self._n_chunks = n_chunks
|
153 |
-
|
154 |
-
self._dtype = None
|
155 |
-
self._block_size = block_size
|
156 |
-
self._n_blocks = None
|
157 |
-
|
158 |
-
self._mmaps = []
|
159 |
-
self._buffers = []
|
160 |
-
|
161 |
-
self._block_idxs = []
|
162 |
-
self._curr_idx = 0
|
163 |
-
|
164 |
-
self._load_n_chunks()
|
165 |
-
|
166 |
-
def _read_header(self, path):
|
167 |
-
with open(path, "rb") as f:
|
168 |
-
magic = f.read(len(HDR_MAGIC))
|
169 |
-
assert magic == HDR_MAGIC, "File doesn't match expected format."
|
170 |
-
version = struct.unpack("<Q", f.read(8))
|
171 |
-
assert (1,) == version
|
172 |
-
(dtype_code,) = struct.unpack("<B", f.read(1))
|
173 |
-
dtype = dtypes[dtype_code]
|
174 |
-
(chunk_size,) = struct.unpack("<Q", f.read(8))
|
175 |
-
return dtype, chunk_size
|
176 |
-
|
177 |
-
def _close_mmaps(self):
|
178 |
-
for mmap in self._mmaps:
|
179 |
-
mmap._mmap.close()
|
180 |
-
|
181 |
-
def _load_n_chunks(self):
|
182 |
-
self._close_mmaps()
|
183 |
-
self._mmaps = []
|
184 |
-
self._buffers = []
|
185 |
-
|
186 |
-
if self._n_chunks > len(self._filenames[self._file_idx:]):
|
187 |
-
if not self._wrap:
|
188 |
-
raise StopIteration
|
189 |
-
else:
|
190 |
-
self._file_idx = 0
|
191 |
-
|
192 |
-
for i in range(self._n_chunks):
|
193 |
-
filename = self._filenames[self._file_idx + i]
|
194 |
-
if self._dtype is None:
|
195 |
-
self._dtype, self._chunk_size = self._read_header(
|
196 |
-
filename
|
197 |
-
)
|
198 |
-
self._n_blocks = self._chunk_size // self._block_size
|
199 |
-
# TODO: check header matches with previous files
|
200 |
-
mmap = np.memmap(filename, mode="r", order="C", offset=HDR_SIZE)
|
201 |
-
self._mmaps.append(mmap)
|
202 |
-
self._buffers.append(memoryview(mmap))
|
203 |
-
|
204 |
-
self._file_idx += self._n_chunks
|
205 |
-
n_all_blocks = self._n_chunks * self._n_blocks
|
206 |
-
|
207 |
-
self._block_idxs = (
|
208 |
-
self._rng.permutation(n_all_blocks)
|
209 |
-
if self._shuffle
|
210 |
-
else range(n_all_blocks)
|
211 |
-
)
|
212 |
-
|
213 |
-
self._curr_idx = 0
|
214 |
-
|
215 |
-
def __del__(self):
|
216 |
-
self._close_mmaps()
|
217 |
-
del self._mmaps
|
218 |
-
del self._buffers
|
219 |
-
|
220 |
-
def __iter__(self):
|
221 |
-
return self
|
222 |
-
|
223 |
-
def __next__(self):
|
224 |
-
if self._curr_idx >= len(self._block_idxs):
|
225 |
-
self._load_n_chunks()
|
226 |
-
# TODO: trigger fetching next next n_chunks if remote
|
227 |
-
block_idx = self._block_idxs[self._curr_idx]
|
228 |
-
chunk_id = block_idx // self._n_blocks
|
229 |
-
buffer = self._buffers[chunk_id]
|
230 |
-
elem_id = (block_idx % self._n_blocks) * self._block_size
|
231 |
-
offset = np.dtype(self._dtype).itemsize * elem_id
|
232 |
-
arr = np.frombuffer(
|
233 |
-
buffer, dtype=self._dtype, count=self._block_size, offset=offset
|
234 |
-
)
|
235 |
-
self._curr_idx += 1
|
236 |
-
return torch.from_numpy(arr.astype(np.int64))
|
237 |
-
|
238 |
-
|
239 |
-
class CombinedDataset(IterableDataset):
|
240 |
-
def __init__(self, datasets, seed, weights=None):
|
241 |
-
self._seed = seed
|
242 |
-
self._datasets = datasets
|
243 |
-
self._weights = weights
|
244 |
-
n_datasets = len(datasets)
|
245 |
-
if weights is None:
|
246 |
-
self._weights = [1 / n_datasets] * n_datasets
|
247 |
-
|
248 |
-
def __iter__(self):
|
249 |
-
return CombinedDatasetIterator(self._datasets, self._seed, self._weights)
|
250 |
-
|
251 |
-
|
252 |
-
class CombinedDatasetIterator:
|
253 |
-
def __init__(self, datasets, seed, weights):
|
254 |
-
self._datasets = [iter(el) for el in datasets]
|
255 |
-
self._weights = weights
|
256 |
-
self._rng = random.Random(seed)
|
257 |
-
|
258 |
-
def __next__(self):
|
259 |
-
dataset, = self._rng.choices(self._datasets, weights=self._weights, k=1)
|
260 |
-
return next(dataset)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lit-llama/lit_llama/quantization.py
DELETED
@@ -1,614 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
from contextlib import contextmanager
|
3 |
-
import warnings
|
4 |
-
import math
|
5 |
-
|
6 |
-
import torch
|
7 |
-
|
8 |
-
# configuration for bitsandbytes before import
|
9 |
-
os.environ["BITSANDBYTES_NOWELCOME"] = "1"
|
10 |
-
warnings.filterwarnings(
|
11 |
-
"ignore",
|
12 |
-
message="MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization",
|
13 |
-
)
|
14 |
-
warnings.filterwarnings(
|
15 |
-
"ignore",
|
16 |
-
message="MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization",
|
17 |
-
)
|
18 |
-
warnings.filterwarnings(
|
19 |
-
"ignore",
|
20 |
-
message="The installed version of bitsandbytes was compiled without GPU support. 8-bit optimizers and GPU quantization are unavailable.",
|
21 |
-
)
|
22 |
-
|
23 |
-
try:
|
24 |
-
import bitsandbytes as bnb # noqa: E402
|
25 |
-
except:
|
26 |
-
bnb = None
|
27 |
-
|
28 |
-
try:
|
29 |
-
import triton # noqa: E402
|
30 |
-
import triton.language as tl # noqa: E402
|
31 |
-
except:
|
32 |
-
triton = None
|
33 |
-
|
34 |
-
if bnb is not None:
|
35 |
-
|
36 |
-
class Linear8bitLt(bnb.nn.Linear8bitLt):
|
37 |
-
"""Wraps `bnb.nn.Linear8bitLt` and enables instantiation directly on the device and
|
38 |
-
re-quantizaton when loading the state dict.
|
39 |
-
|
40 |
-
|
41 |
-
This should only be used for inference. For training, use `bnb.nn.Linear8bitLt` directly.
|
42 |
-
"""
|
43 |
-
|
44 |
-
def __init__(self, *args, **kwargs):
|
45 |
-
super().__init__(*args, **kwargs, has_fp16_weights=False, threshold=6.0)
|
46 |
-
# We quantize the initial weight here so we don't end up filling the device
|
47 |
-
# memory with float32 weights which could lead to OOM.
|
48 |
-
self._quantize_weight(self.weight.data)
|
49 |
-
|
50 |
-
def _load_from_state_dict(self, local_state_dict, *args, **kwargs):
|
51 |
-
# There is only one key that ends with `*.weight`, the other one is the bias
|
52 |
-
weight_key = next(
|
53 |
-
(name for name in local_state_dict.keys() if name.endswith("weight")),
|
54 |
-
None,
|
55 |
-
)
|
56 |
-
if weight_key is None:
|
57 |
-
return
|
58 |
-
|
59 |
-
# Load the weight from the state dict and re-quantize it
|
60 |
-
weight = local_state_dict.pop(weight_key)
|
61 |
-
self._quantize_weight(weight)
|
62 |
-
|
63 |
-
# If there is a bias, let nn.Module load it
|
64 |
-
if local_state_dict:
|
65 |
-
super()._load_from_state_dict(local_state_dict, *args, **kwargs)
|
66 |
-
|
67 |
-
def _quantize_weight(self, weight: torch.Tensor) -> None:
|
68 |
-
# This code is taken and adapted from `bnb.nn.Int8Params.cuda()`
|
69 |
-
B = weight.contiguous().half().cuda()
|
70 |
-
CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B)
|
71 |
-
del CBt
|
72 |
-
del SCBt
|
73 |
-
self.weight.data = CB
|
74 |
-
setattr(self.weight, "CB", CB)
|
75 |
-
setattr(self.weight, "SCB", SCB)
|
76 |
-
|
77 |
-
|
78 |
-
if triton is not None:
|
79 |
-
# This is adapted from the OpenAI Triton matmul example.
|
80 |
-
@triton.autotune(
|
81 |
-
configs=[
|
82 |
-
triton.Config(
|
83 |
-
{
|
84 |
-
"BLOCK_SIZE_M": 128,
|
85 |
-
"BLOCK_SIZE_N": 256,
|
86 |
-
"BLOCK_SIZE_K": 32,
|
87 |
-
"GROUP_SIZE_M": 8,
|
88 |
-
},
|
89 |
-
num_stages=3,
|
90 |
-
num_warps=8,
|
91 |
-
),
|
92 |
-
triton.Config(
|
93 |
-
{
|
94 |
-
"BLOCK_SIZE_M": 256,
|
95 |
-
"BLOCK_SIZE_N": 128,
|
96 |
-
"BLOCK_SIZE_K": 32,
|
97 |
-
"GROUP_SIZE_M": 8,
|
98 |
-
},
|
99 |
-
num_stages=3,
|
100 |
-
num_warps=8,
|
101 |
-
),
|
102 |
-
triton.Config(
|
103 |
-
{
|
104 |
-
"BLOCK_SIZE_M": 256,
|
105 |
-
"BLOCK_SIZE_N": 64,
|
106 |
-
"BLOCK_SIZE_K": 32,
|
107 |
-
"GROUP_SIZE_M": 8,
|
108 |
-
},
|
109 |
-
num_stages=4,
|
110 |
-
num_warps=4,
|
111 |
-
),
|
112 |
-
triton.Config(
|
113 |
-
{
|
114 |
-
"BLOCK_SIZE_M": 64,
|
115 |
-
"BLOCK_SIZE_N": 256,
|
116 |
-
"BLOCK_SIZE_K": 32,
|
117 |
-
"GROUP_SIZE_M": 8,
|
118 |
-
},
|
119 |
-
num_stages=4,
|
120 |
-
num_warps=4,
|
121 |
-
),
|
122 |
-
triton.Config(
|
123 |
-
{
|
124 |
-
"BLOCK_SIZE_M": 128,
|
125 |
-
"BLOCK_SIZE_N": 128,
|
126 |
-
"BLOCK_SIZE_K": 32,
|
127 |
-
"GROUP_SIZE_M": 8,
|
128 |
-
},
|
129 |
-
num_stages=4,
|
130 |
-
num_warps=4,
|
131 |
-
),
|
132 |
-
triton.Config(
|
133 |
-
{
|
134 |
-
"BLOCK_SIZE_M": 128,
|
135 |
-
"BLOCK_SIZE_N": 64,
|
136 |
-
"BLOCK_SIZE_K": 32,
|
137 |
-
"GROUP_SIZE_M": 8,
|
138 |
-
},
|
139 |
-
num_stages=4,
|
140 |
-
num_warps=4,
|
141 |
-
),
|
142 |
-
triton.Config(
|
143 |
-
{
|
144 |
-
"BLOCK_SIZE_M": 64,
|
145 |
-
"BLOCK_SIZE_N": 128,
|
146 |
-
"BLOCK_SIZE_K": 32,
|
147 |
-
"GROUP_SIZE_M": 8,
|
148 |
-
},
|
149 |
-
num_stages=4,
|
150 |
-
num_warps=4,
|
151 |
-
),
|
152 |
-
triton.Config(
|
153 |
-
{
|
154 |
-
"BLOCK_SIZE_M": 128,
|
155 |
-
"BLOCK_SIZE_N": 32,
|
156 |
-
"BLOCK_SIZE_K": 32,
|
157 |
-
"GROUP_SIZE_M": 8,
|
158 |
-
},
|
159 |
-
num_stages=4,
|
160 |
-
num_warps=4,
|
161 |
-
),
|
162 |
-
triton.Config(
|
163 |
-
{
|
164 |
-
"BLOCK_SIZE_M": 64,
|
165 |
-
"BLOCK_SIZE_N": 32,
|
166 |
-
"BLOCK_SIZE_K": 32,
|
167 |
-
"GROUP_SIZE_M": 8,
|
168 |
-
},
|
169 |
-
num_stages=5,
|
170 |
-
num_warps=2,
|
171 |
-
),
|
172 |
-
triton.Config(
|
173 |
-
{
|
174 |
-
"BLOCK_SIZE_M": 32,
|
175 |
-
"BLOCK_SIZE_N": 64,
|
176 |
-
"BLOCK_SIZE_K": 32,
|
177 |
-
"GROUP_SIZE_M": 8,
|
178 |
-
},
|
179 |
-
num_stages=5,
|
180 |
-
num_warps=2,
|
181 |
-
),
|
182 |
-
],
|
183 |
-
key=["M", "N", "K"],
|
184 |
-
)
|
185 |
-
@triton.jit
|
186 |
-
def linear_kernel_4bit_weight(
|
187 |
-
# Pointers to matrices
|
188 |
-
a_ptr,
|
189 |
-
b_ptr,
|
190 |
-
c_ptr,
|
191 |
-
bscales_ptr,
|
192 |
-
bzeros_ptr,
|
193 |
-
# bdequant,
|
194 |
-
# Matrix dimensions
|
195 |
-
M,
|
196 |
-
N,
|
197 |
-
K,
|
198 |
-
# The stride variables represent how much to increase the ptr by when moving by 1
|
199 |
-
# element in a particular dimension. E.g. stride_am is how much to increase a_ptr
|
200 |
-
# by to get the element one row down (A has M rows)
|
201 |
-
stride_am,
|
202 |
-
stride_ak,
|
203 |
-
stride_bk,
|
204 |
-
stride_bn,
|
205 |
-
stride_cm,
|
206 |
-
stride_cn,
|
207 |
-
# Meta-parameters
|
208 |
-
BLOCK_SIZE_M: tl.constexpr,
|
209 |
-
BLOCK_SIZE_N: tl.constexpr,
|
210 |
-
BLOCK_SIZE_K: tl.constexpr,
|
211 |
-
GROUP_SIZE_M: tl.constexpr,
|
212 |
-
):
|
213 |
-
"""Kernel for computing the matmul C = A x B.T.
|
214 |
-
A has shape (M, K), B has shape (N, K) and C has shape (M, N)
|
215 |
-
"""
|
216 |
-
# -----------------------------------------------------------
|
217 |
-
# Map program ids `pid` to the block of C it should compute.
|
218 |
-
# This is done in a grouped ordering to promote L2 data reuse
|
219 |
-
# See above `L2 Cache Optimizations` section for details
|
220 |
-
pid = tl.program_id(axis=0)
|
221 |
-
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
|
222 |
-
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
|
223 |
-
num_pid_in_group = GROUP_SIZE_M * num_pid_n
|
224 |
-
group_id = pid // num_pid_in_group
|
225 |
-
first_pid_m = group_id * GROUP_SIZE_M
|
226 |
-
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
|
227 |
-
pid_m = first_pid_m + (pid % group_size_m)
|
228 |
-
pid_n = (pid % num_pid_in_group) // group_size_m
|
229 |
-
|
230 |
-
# ----------------------------------------------------------
|
231 |
-
# Create pointers for the first blocks of A and B.
|
232 |
-
# We will advance this pointer as we move in the K direction
|
233 |
-
# and accumulate
|
234 |
-
# a_ptrs is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
|
235 |
-
# b_ptrs is a block of [BLOCK_SIZE_K, BLOCK_SIZE_n] pointers
|
236 |
-
# see above `Pointer Arithmetics` section for details
|
237 |
-
offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
238 |
-
offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
239 |
-
a_mask = offs_am[:, None] < M
|
240 |
-
b_mask = offs_bn[None, :] < N
|
241 |
-
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
242 |
-
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
|
243 |
-
b_ptrs = b_ptr + (
|
244 |
-
(offs_k[:, None] // 2) * stride_bk + offs_bn[None, :] * stride_bn
|
245 |
-
)
|
246 |
-
|
247 |
-
bscales_ptrs = bscales_ptr + offs_bn[None, :]
|
248 |
-
bzeros_ptrs = bzeros_ptr + offs_bn[None, :]
|
249 |
-
|
250 |
-
scale = tl.load(bscales_ptrs)
|
251 |
-
zero = tl.load(bzeros_ptrs)
|
252 |
-
# -----------------------------------------------------------
|
253 |
-
# Iterate to compute a block of the C matrix
|
254 |
-
# We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
|
255 |
-
# of fp32 values for higher accuracy.
|
256 |
-
# `accumulator` will be converted back to fp16 after the loop
|
257 |
-
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
258 |
-
for k in range(0, K, BLOCK_SIZE_K):
|
259 |
-
# wasteful as it is to load everything twice, my attempts at avoiding it lead to slower code
|
260 |
-
b12 = tl.load(b_ptrs, mask=b_mask)
|
261 |
-
# Note that for simplicity, we don't apply a mask in K here.
|
262 |
-
a = tl.load(a_ptrs, mask=a_mask).to(tl.float32)
|
263 |
-
b = (
|
264 |
-
((b12.to(tl.uint8) >> ((offs_k[:, None] % 2) * 4)) & 0xF).to(tl.float32)
|
265 |
-
- zero
|
266 |
-
) * scale
|
267 |
-
accumulator += tl.dot(a, b)
|
268 |
-
|
269 |
-
# Advance the ptrs to the next K block
|
270 |
-
a_ptrs += BLOCK_SIZE_K * stride_ak
|
271 |
-
b_ptrs += (BLOCK_SIZE_K // 2) * stride_bk
|
272 |
-
c = accumulator
|
273 |
-
|
274 |
-
# -----------------------------------------------------------
|
275 |
-
# Write back the block of the output matrix C
|
276 |
-
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
277 |
-
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
278 |
-
c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
|
279 |
-
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
|
280 |
-
tl.store(c_ptrs, c, mask=c_mask)
|
281 |
-
|
282 |
-
def qlinear_4bit_weight(inp, weight, scales, zeros):
|
283 |
-
weight = weight.t().contiguous()
|
284 |
-
c_shape = inp.shape[:-1] + weight.shape[-1:]
|
285 |
-
inp = inp.reshape(-1, inp.shape[-1]).contiguous()
|
286 |
-
# we pad the input to amortize triton compilation cost better
|
287 |
-
PAD_TO = 256
|
288 |
-
if inp.shape[0] % PAD_TO != 0:
|
289 |
-
c_crop = inp.shape[0]
|
290 |
-
new_inp_shape0 = inp.shape[0] + PAD_TO - inp.shape[0] % PAD_TO
|
291 |
-
inp2 = inp.new_empty((new_inp_shape0, inp.shape[1]))
|
292 |
-
inp2[: inp.shape[0]] = inp
|
293 |
-
inp2[inp.shape[0] :].zero_()
|
294 |
-
inp = inp2
|
295 |
-
else:
|
296 |
-
c_crop = None
|
297 |
-
|
298 |
-
assert inp.shape[1] == weight.shape[0] * 2, "incompatible dimensions"
|
299 |
-
|
300 |
-
assert scales.shape == (weight.shape[1], 1)
|
301 |
-
assert zeros.shape == (weight.shape[1], 1)
|
302 |
-
scales = scales.contiguous()
|
303 |
-
zeros = zeros.contiguous()
|
304 |
-
K, N = weight.shape
|
305 |
-
M, K = inp.shape
|
306 |
-
assert (
|
307 |
-
K % 32 == 0
|
308 |
-
), "We don't check memory-out-of-bounds with K so K must be divisible by BLOCK_SIZE_K"
|
309 |
-
# allocates output
|
310 |
-
c = torch.empty((M, N), device=inp.device, dtype=inp.dtype)
|
311 |
-
# 1D launch kernel where each block gets its own program.
|
312 |
-
grid = lambda META: (
|
313 |
-
triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
|
314 |
-
)
|
315 |
-
linear_kernel_4bit_weight[grid](
|
316 |
-
inp,
|
317 |
-
weight,
|
318 |
-
c,
|
319 |
-
scales,
|
320 |
-
zeros,
|
321 |
-
M,
|
322 |
-
N,
|
323 |
-
K,
|
324 |
-
inp.stride(0),
|
325 |
-
inp.stride(1),
|
326 |
-
weight.stride(0),
|
327 |
-
weight.stride(1),
|
328 |
-
c.stride(0),
|
329 |
-
c.stride(1),
|
330 |
-
)
|
331 |
-
return c[:c_crop].reshape(c_shape)
|
332 |
-
|
333 |
-
else:
|
334 |
-
qlinear_4bit_weight = None
|
335 |
-
|
336 |
-
|
337 |
-
# for correctness but with terrible perf
|
338 |
-
class ColBlockQuantizedLinear(torch.nn.Module):
|
339 |
-
def __init__(self, in_features, out_features, bias: bool, *, bits, tile_cols):
|
340 |
-
super().__init__()
|
341 |
-
self.in_features = in_features
|
342 |
-
self.out_features = out_features
|
343 |
-
self.tile_cols = tile_cols if tile_cols != -1 else self.in_features
|
344 |
-
self.bits = bits
|
345 |
-
self.entries_per_byte = 8 // bits
|
346 |
-
assert self.entries_per_byte > 0 and self.entries_per_byte * self.bits == 8
|
347 |
-
assert in_features % self.entries_per_byte == 0
|
348 |
-
self.register_buffer(
|
349 |
-
"quant_weight",
|
350 |
-
torch.empty(
|
351 |
-
(self.out_features, self.in_features // self.entries_per_byte),
|
352 |
-
dtype=torch.uint8,
|
353 |
-
)
|
354 |
-
.t()
|
355 |
-
.contiguous()
|
356 |
-
.t(),
|
357 |
-
)
|
358 |
-
self.register_buffer(
|
359 |
-
"scales",
|
360 |
-
torch.empty(
|
361 |
-
(
|
362 |
-
self.out_features,
|
363 |
-
(self.in_features + self.tile_cols - 1) // self.tile_cols,
|
364 |
-
)
|
365 |
-
),
|
366 |
-
)
|
367 |
-
self.register_buffer("zeros", torch.empty_like(self.scales))
|
368 |
-
assert isinstance(bias, bool)
|
369 |
-
if bias:
|
370 |
-
self.register_buffer("bias", torch.empty((self.out_features,)))
|
371 |
-
else:
|
372 |
-
self.register_buffer("bias", None)
|
373 |
-
|
374 |
-
def pack_weight(self, weight):
|
375 |
-
weight = weight.to(device=self.quant_weight.device, copy=True)
|
376 |
-
for j in range(self.scales.size(1)):
|
377 |
-
weight[:, j * self.tile_cols : (j + 1) * self.tile_cols] /= self.scales[
|
378 |
-
:, j : j + 1
|
379 |
-
]
|
380 |
-
weight[:, j * self.tile_cols : (j + 1) * self.tile_cols] += self.zeros[
|
381 |
-
:, j : j + 1
|
382 |
-
]
|
383 |
-
weight = weight.clamp_(min=0, max=2**self.bits - 1).to(dtype=torch.uint8)
|
384 |
-
self.quant_weight.zero_()
|
385 |
-
for nr in range(self.entries_per_byte):
|
386 |
-
self.quant_weight += weight[:, nr :: self.entries_per_byte] << (
|
387 |
-
nr * self.bits
|
388 |
-
)
|
389 |
-
|
390 |
-
def get_weight(self, dtype=torch.float):
|
391 |
-
weight = torch.empty(
|
392 |
-
(self.out_features, self.in_features),
|
393 |
-
device=self.quant_weight.device,
|
394 |
-
dtype=dtype,
|
395 |
-
)
|
396 |
-
mask = (1 << self.bits) - 1
|
397 |
-
for nr in range(self.entries_per_byte):
|
398 |
-
weight[:, nr :: self.entries_per_byte] = (
|
399 |
-
(self.quant_weight >> (nr * self.bits)) & mask
|
400 |
-
).float()
|
401 |
-
self.quant_weight.to(dtype)
|
402 |
-
for j in range(self.scales.size(1)):
|
403 |
-
weight[:, j * self.tile_cols : (j + 1) * self.tile_cols] -= self.zeros[
|
404 |
-
:, j : j + 1
|
405 |
-
]
|
406 |
-
weight[:, j * self.tile_cols : (j + 1) * self.tile_cols] *= self.scales[
|
407 |
-
:, j : j + 1
|
408 |
-
]
|
409 |
-
return weight
|
410 |
-
|
411 |
-
def forward(self, inp):
|
412 |
-
if (
|
413 |
-
triton is not None
|
414 |
-
and self.bits == 4
|
415 |
-
and self.quant_weight.device.type == "cuda"
|
416 |
-
and self.zeros.shape[1] == 1
|
417 |
-
and self.quant_weight.shape[1] % 32 == 0
|
418 |
-
):
|
419 |
-
return qlinear_4bit_weight(inp, self.quant_weight, self.scales, self.zeros)
|
420 |
-
weight = self.get_weight(dtype=inp.dtype)
|
421 |
-
return torch.nn.functional.linear(inp, weight, self.bias)
|
422 |
-
|
423 |
-
|
424 |
-
class GPTQQuantizer:
|
425 |
-
# The algorithm and code has been taken from https://github.com/IST-DASLab/gptq/
|
426 |
-
# E. Frantar et al GPTQ: Accurate Post-training Compression for GPT, arXiv:2210.17323
|
427 |
-
# portions copyright by the authors licensed under the Apache License 2.0
|
428 |
-
# All errors are our own.
|
429 |
-
|
430 |
-
def __init__(
|
431 |
-
self,
|
432 |
-
linear_module,
|
433 |
-
*,
|
434 |
-
bits,
|
435 |
-
perchannel=True,
|
436 |
-
sym=False,
|
437 |
-
blocksize=128,
|
438 |
-
percdamp=0.01,
|
439 |
-
groupsize=-1,
|
440 |
-
actorder=False
|
441 |
-
):
|
442 |
-
assert isinstance(linear_module, torch.nn.Linear)
|
443 |
-
|
444 |
-
self.linear_module = linear_module
|
445 |
-
self.dev = self.linear_module.weight.device
|
446 |
-
self.rows = linear_module.weight.shape[0]
|
447 |
-
self.columns = linear_module.weight.shape[1]
|
448 |
-
self.H = torch.zeros((self.columns, self.columns), device=self.dev)
|
449 |
-
self.nsamples = 0
|
450 |
-
self.bits = bits
|
451 |
-
self.maxq = 2**bits - 1
|
452 |
-
self.perchannel = perchannel
|
453 |
-
self.sym = sym
|
454 |
-
self.blocksize = blocksize
|
455 |
-
self.percdamp = percdamp
|
456 |
-
self.groupsize = groupsize
|
457 |
-
self.actorder = actorder
|
458 |
-
self.tile_cols = self.columns if groupsize == -1 else groupsize
|
459 |
-
self.scales = torch.zeros(
|
460 |
-
(self.rows, (self.columns + self.tile_cols - 1) // self.tile_cols),
|
461 |
-
dtype=self.linear_module.weight.dtype,
|
462 |
-
device=self.dev,
|
463 |
-
)
|
464 |
-
self.zeros = torch.zeros_like(self.scales)
|
465 |
-
assert not (
|
466 |
-
self.actorder and self.groupsize != -1
|
467 |
-
), "The permutation trick does not work for grouped quantization"
|
468 |
-
|
469 |
-
@staticmethod
|
470 |
-
def quantize_weight(x, scale, zero, maxq):
|
471 |
-
q = torch.clamp(torch.round(x / scale) + zero, 0, maxq)
|
472 |
-
x_rec = scale * (q - zero)
|
473 |
-
return x_rec
|
474 |
-
|
475 |
-
def find_params_weight(self, x):
|
476 |
-
dev = x.device
|
477 |
-
|
478 |
-
shape = x.shape
|
479 |
-
if self.perchannel:
|
480 |
-
x = x.flatten(1)
|
481 |
-
else:
|
482 |
-
x = x.flatten().unsqueeze(0)
|
483 |
-
|
484 |
-
tmp = torch.zeros(x.shape[0], device=dev)
|
485 |
-
xmin = torch.minimum(x.min(1)[0], tmp)
|
486 |
-
xmax = torch.maximum(x.max(1)[0], tmp)
|
487 |
-
|
488 |
-
if self.sym:
|
489 |
-
xmax = torch.maximum(torch.abs(xmin), xmax)
|
490 |
-
tmp = xmin < 0
|
491 |
-
if torch.any(tmp):
|
492 |
-
xmin[tmp] = -xmax[tmp]
|
493 |
-
tmp = (xmin == 0) & (xmax == 0)
|
494 |
-
xmin[tmp] = -1
|
495 |
-
xmax[tmp] = +1
|
496 |
-
|
497 |
-
scale = (xmax - xmin) / self.maxq
|
498 |
-
if self.sym:
|
499 |
-
zero = torch.full_like(scale, (self.maxq + 1) / 2)
|
500 |
-
else:
|
501 |
-
zero = torch.round(-xmin / scale)
|
502 |
-
|
503 |
-
if not self.perchannel:
|
504 |
-
tmp = shape[0]
|
505 |
-
scale = scale.repeat(tmp)
|
506 |
-
zero = zero.repeat(tmp)
|
507 |
-
|
508 |
-
shape = [-1] + [1] * (len(shape) - 1)
|
509 |
-
scale = scale.reshape(shape)
|
510 |
-
zero = zero.reshape(shape)
|
511 |
-
return scale, zero
|
512 |
-
|
513 |
-
def collect_input_stats(self, _1, inp, _2):
|
514 |
-
inp = inp[0].detach()
|
515 |
-
self.last_inp = inp
|
516 |
-
if len(inp.shape) == 2:
|
517 |
-
inp = inp.unsqueeze(0)
|
518 |
-
tmp = inp.shape[0]
|
519 |
-
if len(inp.shape) == 3:
|
520 |
-
inp = inp.reshape((-1, inp.shape[-1]))
|
521 |
-
inp = inp.t()
|
522 |
-
self.H *= self.nsamples / (self.nsamples + tmp)
|
523 |
-
self.nsamples += tmp
|
524 |
-
# inp = inp.float()
|
525 |
-
inp = math.sqrt(2 / self.nsamples) * inp.float()
|
526 |
-
# self.H += 2 / self.nsamples * inp.matmul(inp.t())
|
527 |
-
self.H += inp.matmul(inp.t())
|
528 |
-
|
529 |
-
def quantize(self):
|
530 |
-
W = self.linear_module.weight.detach().to(dtype=torch.float, copy=True)
|
531 |
-
|
532 |
-
scale, zero = self.find_params_weight(W)
|
533 |
-
self.scales[:] = scale
|
534 |
-
self.zeros[:] = zero
|
535 |
-
|
536 |
-
H = self.H
|
537 |
-
del self.H
|
538 |
-
dead = torch.diag(H) == 0
|
539 |
-
H[dead, dead] = 1
|
540 |
-
W[:, dead] = 0
|
541 |
-
if self.actorder:
|
542 |
-
perm = torch.argsort(torch.diag(H), descending=True)
|
543 |
-
W = W[:, perm]
|
544 |
-
H = H[perm][:, perm]
|
545 |
-
|
546 |
-
Losses = torch.zeros_like(W)
|
547 |
-
Q = torch.zeros_like(W)
|
548 |
-
|
549 |
-
damp = self.percdamp * torch.mean(torch.diag(H))
|
550 |
-
diag = torch.arange(self.columns, device=self.dev)
|
551 |
-
H[diag, diag] += damp
|
552 |
-
H = torch.linalg.cholesky(H)
|
553 |
-
H = torch.cholesky_inverse(H)
|
554 |
-
H = torch.linalg.cholesky(H, upper=True)
|
555 |
-
Hinv = H
|
556 |
-
|
557 |
-
for i1 in range(0, self.columns, self.blocksize):
|
558 |
-
i2 = min(i1 + self.blocksize, self.columns)
|
559 |
-
count = i2 - i1
|
560 |
-
|
561 |
-
W1 = W[:, i1:i2].clone()
|
562 |
-
Q1 = torch.zeros_like(W1)
|
563 |
-
Err1 = torch.zeros_like(W1)
|
564 |
-
Losses1 = torch.zeros_like(W1)
|
565 |
-
Hinv1 = Hinv[i1:i2, i1:i2]
|
566 |
-
|
567 |
-
for i in range(count):
|
568 |
-
w = W1[:, i]
|
569 |
-
d = Hinv1[i, i]
|
570 |
-
|
571 |
-
if self.groupsize != -1:
|
572 |
-
if (i1 + i) % self.groupsize == 0:
|
573 |
-
scale, zero = self.find_params_weight(
|
574 |
-
W[:, (i1 + i) : (i1 + i + self.groupsize)]
|
575 |
-
)
|
576 |
-
self.scales[:, (i1 + i) // self.groupsize] = scale
|
577 |
-
self.zeros[:, (i1 + i) // self.groupsize] = zero
|
578 |
-
|
579 |
-
q = self.quantize_weight(w.unsqueeze(1), scale, zero, self.maxq)
|
580 |
-
q = q.squeeze(1)
|
581 |
-
assert q.dim() == 1
|
582 |
-
Q1[:, i] = q
|
583 |
-
Losses1[:, i] = (w - q) ** 2 / d**2
|
584 |
-
|
585 |
-
err1 = (w - q) / d
|
586 |
-
W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))
|
587 |
-
Err1[:, i] = err1
|
588 |
-
|
589 |
-
Q[:, i1:i2] = Q1
|
590 |
-
Losses[:, i1:i2] = Losses1 / 2
|
591 |
-
|
592 |
-
W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])
|
593 |
-
|
594 |
-
if self.actorder:
|
595 |
-
invperm = torch.argsort(perm)
|
596 |
-
Q = Q[:, invperm]
|
597 |
-
|
598 |
-
weight = Q.reshape(self.linear_module.weight.shape).to(
|
599 |
-
self.linear_module.weight.data.dtype
|
600 |
-
)
|
601 |
-
error = torch.sum(Losses).item()
|
602 |
-
|
603 |
-
q_module = ColBlockQuantizedLinear(
|
604 |
-
self.linear_module.in_features,
|
605 |
-
self.linear_module.out_features,
|
606 |
-
self.linear_module.bias is not None,
|
607 |
-
bits=self.bits,
|
608 |
-
tile_cols=self.groupsize,
|
609 |
-
).to(self.dev)
|
610 |
-
q_module.scales = self.scales
|
611 |
-
q_module.zeros = self.zeros
|
612 |
-
q_module.pack_weight(weight)
|
613 |
-
q_module.bias = self.linear_module.bias
|
614 |
-
return q_module, error
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lit-llama/lit_llama/tokenizer.py
DELETED
@@ -1,49 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
from pathlib import Path
|
3 |
-
from typing import Optional
|
4 |
-
|
5 |
-
import torch
|
6 |
-
from sentencepiece import SentencePieceProcessor, SentencePieceTrainer
|
7 |
-
|
8 |
-
|
9 |
-
class Tokenizer:
|
10 |
-
"""Tokenizer for LLaMA."""
|
11 |
-
|
12 |
-
def __init__(self, model_path: Path) -> None:
|
13 |
-
self.processor = SentencePieceProcessor(model_file=str(model_path))
|
14 |
-
self.bos_id = self.processor.bos_id()
|
15 |
-
self.eos_id = self.processor.eos_id()
|
16 |
-
self.pad_id = self.processor.pad_id()
|
17 |
-
|
18 |
-
@property
|
19 |
-
def vocab_size(self) -> int:
|
20 |
-
return self.processor.vocab_size()
|
21 |
-
|
22 |
-
def encode(
|
23 |
-
self,
|
24 |
-
string: str,
|
25 |
-
bos: bool = True,
|
26 |
-
eos: bool = False,
|
27 |
-
max_length: int = -1,
|
28 |
-
pad: bool = False,
|
29 |
-
device: Optional[torch.device] = None
|
30 |
-
) -> torch.Tensor:
|
31 |
-
tokens = self.processor.encode(string)
|
32 |
-
if bos:
|
33 |
-
tokens = [self.bos_id] + tokens
|
34 |
-
if eos:
|
35 |
-
tokens = tokens + [self.eos_id]
|
36 |
-
if max_length > 0:
|
37 |
-
tokens = tokens[:max_length]
|
38 |
-
if pad and len(tokens) < max_length:
|
39 |
-
tokens += [self.pad_id] * (max_length - len(tokens))
|
40 |
-
|
41 |
-
return torch.tensor(tokens, dtype=torch.int, device=device)
|
42 |
-
|
43 |
-
def decode(self, tokens: torch.Tensor) -> str:
|
44 |
-
return self.processor.decode(tokens.tolist())
|
45 |
-
|
46 |
-
@staticmethod
|
47 |
-
def train(input: str, destination: str, vocab_size=32000) -> None:
|
48 |
-
model_prefix = os.path.join(destination, "tokenizer")
|
49 |
-
SentencePieceTrainer.Train(input=input, model_prefix=model_prefix, vocab_size=vocab_size)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lit-llama/lit_llama/utils.py
DELETED
@@ -1,496 +0,0 @@
|
|
1 |
-
"""Utility functions for training and inference."""
|
2 |
-
|
3 |
-
import functools
|
4 |
-
import pickle
|
5 |
-
import warnings
|
6 |
-
from io import BytesIO
|
7 |
-
from pathlib import Path
|
8 |
-
from contextlib import contextmanager
|
9 |
-
|
10 |
-
import torch
|
11 |
-
import torch.utils._device
|
12 |
-
from lightning.fabric.strategies import DeepSpeedStrategy, FSDPStrategy
|
13 |
-
from torch.distributed.fsdp import FullStateDictConfig
|
14 |
-
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
15 |
-
from torch.distributed.fsdp import StateDictType
|
16 |
-
from torch.serialization import normalize_storage_type
|
17 |
-
|
18 |
-
llama_model_sizes = {
|
19 |
-
4096: "7B", # 7B n_embd=4096
|
20 |
-
5120: "13B", # 13B n_embd=5120
|
21 |
-
6656: "30B", # 30B n_embd=6656
|
22 |
-
8192: "65B", # 65B n_embd=8192
|
23 |
-
}
|
24 |
-
|
25 |
-
|
26 |
-
def llama_model_lookup(checkpoint: dict) -> str:
|
27 |
-
"""Returns the LLaMA model name from the checkpoint.
|
28 |
-
|
29 |
-
Checks the width of the lm_head.weight matrix, as these uniquely identify the model.
|
30 |
-
"""
|
31 |
-
embedding_size = checkpoint['transformer.wte.weight'].shape[1]
|
32 |
-
return llama_model_sizes[embedding_size]
|
33 |
-
|
34 |
-
|
35 |
-
def find_multiple(n: int, k: int) -> int:
|
36 |
-
if n % k == 0:
|
37 |
-
return n
|
38 |
-
return n + k - (n % k)
|
39 |
-
|
40 |
-
|
41 |
-
def save_model_checkpoint(fabric, model, file_path):
|
42 |
-
"""Handles boilerplate logic for retrieving and saving the state_dict.
|
43 |
-
|
44 |
-
This will be upstreamed to Fabric soon.
|
45 |
-
"""
|
46 |
-
file_path = Path(file_path)
|
47 |
-
|
48 |
-
if isinstance(fabric.strategy, DeepSpeedStrategy):
|
49 |
-
from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
|
50 |
-
|
51 |
-
fabric.save(file_path, {"model": model})
|
52 |
-
fabric.barrier()
|
53 |
-
if fabric.global_rank == 0:
|
54 |
-
# Create a consolidated checkpoint with the same name next to the deepspeed checkpoint
|
55 |
-
convert_zero_checkpoint_to_fp32_state_dict(file_path, file_path.with_suffix(".pth"))
|
56 |
-
return
|
57 |
-
|
58 |
-
if isinstance(fabric.strategy, FSDPStrategy):
|
59 |
-
save_policy = FullStateDictConfig(offload_to_cpu=(fabric.world_size > 1), rank0_only=True)
|
60 |
-
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
|
61 |
-
state_dict = model._forward_module.state_dict()
|
62 |
-
else:
|
63 |
-
state_dict = model.state_dict()
|
64 |
-
|
65 |
-
if fabric.global_rank == 0:
|
66 |
-
torch.save(state_dict, file_path)
|
67 |
-
fabric.barrier()
|
68 |
-
|
69 |
-
|
70 |
-
class EmptyInitOnDevice(torch.overrides.TorchFunctionMode):
|
71 |
-
def __init__(self, device=None, dtype=None, quantization_mode=None):
|
72 |
-
"""
|
73 |
-
Create tensors with given device and dtype and don't run initialization
|
74 |
-
(but instead use "empty tensors", i.e. uninitialized memory).
|
75 |
-
|
76 |
-
device: `torch.device` to work with
|
77 |
-
dtype: `torch.dtype` to work with
|
78 |
-
quantization_mode: optional string, quantization mode to work with, default `None`.
|
79 |
-
Available modes: `llm.int8` bitsnbytes LLM.int8 quantization (only on GPU)
|
80 |
-
`gptq.int4`, `gptq.int8`: GPTQ pre-quantized models
|
81 |
-
|
82 |
-
Example::
|
83 |
-
with EmptyInitOnDevice("cuda", dtype=torch.bfloat16):
|
84 |
-
model = LLaMA.from_name('7B')
|
85 |
-
model.load_state_dict(torch.load('llama-lit/7B/lit-llama.pth'))"""
|
86 |
-
|
87 |
-
self.quantization_mode = quantization_mode
|
88 |
-
self.quantized_linear_cls = None
|
89 |
-
if self.quantization_mode == 'llm.int8':
|
90 |
-
if device.type != "cuda":
|
91 |
-
raise ValueError("Quantization is only supported on the GPU.")
|
92 |
-
from .quantization import Linear8bitLt
|
93 |
-
self.quantized_linear_cls = Linear8bitLt
|
94 |
-
elif self.quantization_mode == 'gptq.int4':
|
95 |
-
from .quantization import ColBlockQuantizedLinear
|
96 |
-
self.quantized_linear_cls = functools.partial(ColBlockQuantizedLinear, bits=4, tile_cols=-1)
|
97 |
-
elif self.quantization_mode == 'gptq.int8':
|
98 |
-
from .quantization import ColBlockQuantizedLinear
|
99 |
-
self.quantized_linear_cls = functools.partial(ColBlockQuantizedLinear, bits=8, tile_cols=-1)
|
100 |
-
elif self.quantization_mode is not None:
|
101 |
-
raise RuntimeError(f"unknown quantization mode {self.quantization_mode}")
|
102 |
-
self.device = device
|
103 |
-
self.dtype = dtype
|
104 |
-
|
105 |
-
def __enter__(self):
|
106 |
-
if self.quantized_linear_cls != None:
|
107 |
-
self.torch_linear_cls = torch.nn.Linear
|
108 |
-
torch.nn.Linear = self.quantized_linear_cls
|
109 |
-
return super().__enter__()
|
110 |
-
|
111 |
-
def __exit__(self, exc_type, exc_val, exc_tb):
|
112 |
-
if self.quantized_linear_cls != None:
|
113 |
-
torch.nn.Linear = self.torch_linear_cls
|
114 |
-
return super().__exit__(exc_type, exc_val, exc_tb)
|
115 |
-
|
116 |
-
def __torch_function__(self, func, types, args=(), kwargs=None):
|
117 |
-
kwargs = kwargs or {}
|
118 |
-
if getattr(func, "__module__", None) == "torch.nn.init":
|
119 |
-
if "tensor" in kwargs:
|
120 |
-
return kwargs["tensor"]
|
121 |
-
else:
|
122 |
-
return args[0]
|
123 |
-
if (
|
124 |
-
self.device is not None
|
125 |
-
and func in torch.utils._device._device_constructors()
|
126 |
-
and kwargs.get("device") is None
|
127 |
-
):
|
128 |
-
kwargs["device"] = self.device
|
129 |
-
if (
|
130 |
-
self.dtype is not None
|
131 |
-
and func in torch.utils._device._device_constructors()
|
132 |
-
and kwargs.get("dtype") is None
|
133 |
-
):
|
134 |
-
kwargs["dtype"] = self.dtype
|
135 |
-
return func(*args, **kwargs)
|
136 |
-
|
137 |
-
|
138 |
-
@contextmanager
|
139 |
-
def quantization(mode: str = None):
|
140 |
-
quantized_linear_cls = None
|
141 |
-
if mode == 'llm.int8':
|
142 |
-
from .quantization import Linear8bitLt
|
143 |
-
quantized_linear_cls = Linear8bitLt
|
144 |
-
elif mode == 'gptq.int4':
|
145 |
-
from .quantization import ColBlockQuantizedLinear
|
146 |
-
quantized_linear_cls = functools.partial(ColBlockQuantizedLinear, bits=4, tile_cols=-1)
|
147 |
-
elif mode == 'gptq.int8':
|
148 |
-
from .quantization import ColBlockQuantizedLinear
|
149 |
-
quantized_linear_cls = functools.partial(ColBlockQuantizedLinear, bits=8, tile_cols=-1)
|
150 |
-
elif mode is not None:
|
151 |
-
raise ValueError(f"Unknown quantization mode: {mode}")
|
152 |
-
|
153 |
-
enabled = mode is not None
|
154 |
-
torch_linear_cls = torch.nn.Linear
|
155 |
-
if enabled:
|
156 |
-
torch.nn.Linear = quantized_linear_cls
|
157 |
-
yield
|
158 |
-
if enabled:
|
159 |
-
torch.nn.Linear = torch_linear_cls
|
160 |
-
|
161 |
-
|
162 |
-
# this is taken from torchhacks https://github.com/lernapparat/torchhacks
|
163 |
-
|
164 |
-
|
165 |
-
class NotYetLoadedTensor:
|
166 |
-
def __init__(self, metatensor, archiveinfo, storageinfo, rebuild_args):
|
167 |
-
self.metatensor = metatensor
|
168 |
-
self.archiveinfo = archiveinfo
|
169 |
-
self.storageinfo = storageinfo
|
170 |
-
self.rebuild_args = rebuild_args
|
171 |
-
|
172 |
-
@classmethod
|
173 |
-
def rebuild_from_type_v2(cls, func, new_type, args, state, *, archiveinfo=None):
|
174 |
-
ret = func(*args)
|
175 |
-
if isinstance(ret, NotYetLoadedTensor):
|
176 |
-
old_lt = ret._load_tensor
|
177 |
-
|
178 |
-
def _load_tensor():
|
179 |
-
t = old_lt()
|
180 |
-
return torch._tensor._rebuild_from_type_v2(
|
181 |
-
lambda: t, new_type, (), state
|
182 |
-
)
|
183 |
-
|
184 |
-
ret._load_tensor = _load_tensor
|
185 |
-
return ret
|
186 |
-
return torch._tensor._rebuild_from_type_v2(func, new_type, args, state)
|
187 |
-
|
188 |
-
@classmethod
|
189 |
-
def rebuild_parameter(
|
190 |
-
cls, data, requires_grad, backward_hooks, *, archiveinfo=None
|
191 |
-
):
|
192 |
-
if isinstance(data, NotYetLoadedTensor):
|
193 |
-
old_lt = data._load_tensor
|
194 |
-
|
195 |
-
def _load_tensor():
|
196 |
-
t = old_lt()
|
197 |
-
return torch._utils._rebuild_parameter(t, requires_grad, backward_hooks)
|
198 |
-
|
199 |
-
data._load_tensor = _load_tensor
|
200 |
-
return data
|
201 |
-
return torch._utils._rebuild_parameter(data, requires_grad, backward_hooks)
|
202 |
-
|
203 |
-
@classmethod
|
204 |
-
def rebuild_tensor_v2(
|
205 |
-
cls,
|
206 |
-
storage,
|
207 |
-
storage_offset,
|
208 |
-
size,
|
209 |
-
stride,
|
210 |
-
requires_grad,
|
211 |
-
backward_hooks,
|
212 |
-
metadata=None,
|
213 |
-
*,
|
214 |
-
archiveinfo=None,
|
215 |
-
):
|
216 |
-
rebuild_args = (
|
217 |
-
storage_offset,
|
218 |
-
size,
|
219 |
-
stride,
|
220 |
-
requires_grad,
|
221 |
-
backward_hooks,
|
222 |
-
metadata,
|
223 |
-
)
|
224 |
-
metatensor = torch._utils._rebuild_tensor_v2(
|
225 |
-
storage,
|
226 |
-
storage_offset,
|
227 |
-
size,
|
228 |
-
stride,
|
229 |
-
requires_grad,
|
230 |
-
backward_hooks,
|
231 |
-
metadata,
|
232 |
-
)
|
233 |
-
storageinfo = storage.archiveinfo
|
234 |
-
return NotYetLoadedTensor(metatensor, archiveinfo, storageinfo, rebuild_args)
|
235 |
-
|
236 |
-
def _load_tensor(self):
|
237 |
-
name, storage_cls, fn, device, size = self.storageinfo
|
238 |
-
dtype = self.metatensor.dtype
|
239 |
-
|
240 |
-
uts = (
|
241 |
-
self.archiveinfo.zipfile_context.zf.get_storage_from_record(
|
242 |
-
f"data/{fn}",
|
243 |
-
size * torch._utils._element_size(dtype),
|
244 |
-
torch.UntypedStorage,
|
245 |
-
)
|
246 |
-
._typed_storage()
|
247 |
-
._untyped_storage
|
248 |
-
)
|
249 |
-
with warnings.catch_warnings():
|
250 |
-
warnings.simplefilter("ignore")
|
251 |
-
storage = torch.storage.TypedStorage(
|
252 |
-
wrap_storage=uts, dtype=self.metatensor.dtype, _internal=True
|
253 |
-
)
|
254 |
-
tensor = torch._utils._rebuild_tensor_v2(storage, *self.rebuild_args)
|
255 |
-
return tensor
|
256 |
-
|
257 |
-
@classmethod
|
258 |
-
def __torch_function__(cls, func, types, args=(), kwargs=None):
|
259 |
-
if kwargs is None:
|
260 |
-
kwargs = {}
|
261 |
-
loaded_args = [
|
262 |
-
(a._load_tensor() if isinstance(a, NotYetLoadedTensor) else a) for a in args
|
263 |
-
]
|
264 |
-
res = func(*loaded_args, **kwargs)
|
265 |
-
# gc.collect would be costly here, maybe do it optionally
|
266 |
-
return res
|
267 |
-
|
268 |
-
def __getattr__(self, name):
|
269 |
-
# properties
|
270 |
-
## TODO: device, is_...??
|
271 |
-
## TODO: mH, mT, H, T, data, imag, real
|
272 |
-
## name ???
|
273 |
-
if name in {
|
274 |
-
"dtype",
|
275 |
-
"grad",
|
276 |
-
"grad_fn",
|
277 |
-
"layout",
|
278 |
-
"names",
|
279 |
-
"ndim",
|
280 |
-
"output_nr",
|
281 |
-
"requires_grad",
|
282 |
-
"retains_grad",
|
283 |
-
"shape",
|
284 |
-
"volatile",
|
285 |
-
}:
|
286 |
-
return getattr(self.metatensor, name)
|
287 |
-
if name in {"size"}:
|
288 |
-
return getattr(self.metatensor, name)
|
289 |
-
# materializing with contiguous is needed for quantization
|
290 |
-
if name in {"contiguous"}:
|
291 |
-
return getattr(self._load_tensor(), name)
|
292 |
-
|
293 |
-
raise AttributeError(f"{type(self)} does not have {name}")
|
294 |
-
|
295 |
-
def __repr__(self):
|
296 |
-
return f"NotYetLoadedTensor({repr(self.metatensor)})"
|
297 |
-
|
298 |
-
|
299 |
-
class LazyLoadingUnpickler(pickle.Unpickler):
|
300 |
-
def __init__(self, file, zipfile_context):
|
301 |
-
super().__init__(file)
|
302 |
-
self.zipfile_context = zipfile_context
|
303 |
-
|
304 |
-
def find_class(self, module, name):
|
305 |
-
res = super().find_class(module, name)
|
306 |
-
if module == "torch._utils" and name == "_rebuild_tensor_v2":
|
307 |
-
return functools.partial(
|
308 |
-
NotYetLoadedTensor.rebuild_tensor_v2, archiveinfo=self
|
309 |
-
)
|
310 |
-
elif module == "torch._tensor" and name == "_rebuild_from_type_v2":
|
311 |
-
return functools.partial(
|
312 |
-
NotYetLoadedTensor.rebuild_from_type_v2, archiveinfo=self
|
313 |
-
)
|
314 |
-
elif module == "torch._utils" and name == "_rebuild_parameter":
|
315 |
-
return functools.partial(
|
316 |
-
NotYetLoadedTensor.rebuild_parameter, archiveinfo=self
|
317 |
-
)
|
318 |
-
return res
|
319 |
-
|
320 |
-
def persistent_load(self, pid):
|
321 |
-
name, cls, fn, device, size = pid
|
322 |
-
with warnings.catch_warnings():
|
323 |
-
warnings.simplefilter("ignore")
|
324 |
-
s = torch.storage.TypedStorage(dtype=cls().dtype, device="meta")
|
325 |
-
s.archiveinfo = pid
|
326 |
-
return s
|
327 |
-
|
328 |
-
|
329 |
-
class lazy_load:
|
330 |
-
def __init__(self, fn):
|
331 |
-
self.zf = torch._C.PyTorchFileReader(str(fn))
|
332 |
-
with BytesIO(self.zf.get_record("data.pkl")) as pkl:
|
333 |
-
mup = LazyLoadingUnpickler(pkl, self)
|
334 |
-
self.sd = mup.load()
|
335 |
-
|
336 |
-
def __enter__(self):
|
337 |
-
return self.sd
|
338 |
-
|
339 |
-
def __exit__(self, exc_type, exc_val, exc_tb):
|
340 |
-
del self.zf # I don't think there is a way to force closing...
|
341 |
-
self.zf = None
|
342 |
-
|
343 |
-
|
344 |
-
class SavingProxyForStorage:
|
345 |
-
def __init__(self, obj, saver, protocol_version=5):
|
346 |
-
self.protocol_version = protocol_version
|
347 |
-
self.saver = saver
|
348 |
-
if not (isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj)):
|
349 |
-
raise TypeError(f"expected storage, not {type(obj)}")
|
350 |
-
|
351 |
-
# this logic is taken from PyTorch 2.0+ torch/serialization.py
|
352 |
-
if isinstance(obj, torch.storage.TypedStorage):
|
353 |
-
# PT upstream wants to deprecate this eventually...
|
354 |
-
storage = obj._untyped_storage
|
355 |
-
storage_type_str = obj._pickle_storage_type()
|
356 |
-
storage_type = getattr(torch, storage_type_str)
|
357 |
-
storage_numel = obj._size()
|
358 |
-
else:
|
359 |
-
storage = obj
|
360 |
-
storage_type = normalize_storage_type(type(obj))
|
361 |
-
storage_numel = storage.nbytes()
|
362 |
-
|
363 |
-
storage_key = saver._write_storage_and_return_key(storage)
|
364 |
-
location = torch.serialization.location_tag(storage)
|
365 |
-
|
366 |
-
self.storage_info = (
|
367 |
-
"storage",
|
368 |
-
storage_type,
|
369 |
-
storage_key,
|
370 |
-
location,
|
371 |
-
storage_numel,
|
372 |
-
)
|
373 |
-
|
374 |
-
def __reduce_ex__(self, protocol_version):
|
375 |
-
assert False, "this should be handled with out of band"
|
376 |
-
|
377 |
-
|
378 |
-
class SavingProxyForTensor:
|
379 |
-
def __init__(self, tensor, saver, protocol_version=5):
|
380 |
-
self.protocol_version = protocol_version
|
381 |
-
self.reduce_ret_fn, (storage, *other_reduce_args) = tensor.__reduce_ex__(
|
382 |
-
protocol_version
|
383 |
-
)
|
384 |
-
assert isinstance(
|
385 |
-
storage, torch.storage.TypedStorage
|
386 |
-
), "Please check for updates"
|
387 |
-
storage_proxy = SavingProxyForStorage(
|
388 |
-
storage, saver, protocol_version=protocol_version
|
389 |
-
)
|
390 |
-
self.reduce_args = (storage_proxy, *other_reduce_args)
|
391 |
-
|
392 |
-
def __reduce_ex__(self, protocol_version):
|
393 |
-
if protocol_version != self.protocol_version:
|
394 |
-
raise RuntimeError(
|
395 |
-
f"Unexpected protocol version: expected {self.protocol_version}, got {protocol_version}"
|
396 |
-
)
|
397 |
-
return self.reduce_ret_fn, self.reduce_args
|
398 |
-
|
399 |
-
|
400 |
-
class IncrementalPyTorchPickler(pickle.Pickler):
|
401 |
-
def __init__(self, saver, *args, **kwargs):
|
402 |
-
super().__init__(*args, **kwargs)
|
403 |
-
self.storage_dtypes = {}
|
404 |
-
self.saver = saver
|
405 |
-
self.id_map = {}
|
406 |
-
|
407 |
-
# this logic is taken from PyTorch 2.0+ torch/serialization.py
|
408 |
-
def persistent_id(self, obj):
|
409 |
-
# FIXME: the docs say that persistent_id should only return a string
|
410 |
-
# but torch store returns tuples. This works only in the binary protocol
|
411 |
-
# see
|
412 |
-
# https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects
|
413 |
-
# https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537
|
414 |
-
if isinstance(obj, SavingProxyForStorage):
|
415 |
-
return obj.storage_info
|
416 |
-
|
417 |
-
if isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj):
|
418 |
-
if isinstance(obj, torch.storage.TypedStorage):
|
419 |
-
# TODO: Once we decide to break serialization FC, this case
|
420 |
-
# can be deleted
|
421 |
-
storage = obj._untyped_storage
|
422 |
-
storage_dtype = obj.dtype
|
423 |
-
storage_type_str = obj._pickle_storage_type()
|
424 |
-
storage_type = getattr(torch, storage_type_str)
|
425 |
-
storage_numel = obj._size()
|
426 |
-
|
427 |
-
else:
|
428 |
-
storage = obj
|
429 |
-
storage_dtype = torch.uint8
|
430 |
-
storage_type = normalize_storage_type(type(obj))
|
431 |
-
storage_numel = storage.nbytes()
|
432 |
-
|
433 |
-
# If storage is allocated, ensure that any other saved storages
|
434 |
-
# pointing to the same data all have the same dtype. If storage is
|
435 |
-
# not allocated, don't perform this check
|
436 |
-
if storage.data_ptr() != 0:
|
437 |
-
if storage.data_ptr() in self.storage_dtypes:
|
438 |
-
if storage_dtype != self.storage_dtypes[storage.data_ptr()]:
|
439 |
-
raise RuntimeError(
|
440 |
-
"Cannot save multiple tensors or storages that "
|
441 |
-
"view the same data as different types"
|
442 |
-
)
|
443 |
-
else:
|
444 |
-
self.storage_dtypes[storage.data_ptr()] = storage_dtype
|
445 |
-
|
446 |
-
storage_key = self.id_map.get(storage._cdata)
|
447 |
-
if storage_key is None:
|
448 |
-
storage_key = self.saver._write_storage_and_return_key(storage)
|
449 |
-
self.id_map[storage._cdata] = storage_key
|
450 |
-
location = torch.serialization.location_tag(storage)
|
451 |
-
|
452 |
-
return ("storage", storage_type, storage_key, location, storage_numel)
|
453 |
-
|
454 |
-
return None
|
455 |
-
|
456 |
-
|
457 |
-
class incremental_save:
|
458 |
-
def __init__(self, name):
|
459 |
-
self.name = name
|
460 |
-
self.zipfile = torch._C.PyTorchFileWriter(str(name))
|
461 |
-
self.has_saved = False
|
462 |
-
self.next_key = 0
|
463 |
-
|
464 |
-
def __enter__(self):
|
465 |
-
return self
|
466 |
-
|
467 |
-
def store_early(self, tensor):
|
468 |
-
if isinstance(tensor, torch.Tensor):
|
469 |
-
return SavingProxyForTensor(tensor, self)
|
470 |
-
raise TypeError(f"can only store tensors early, not {type(tensor)}")
|
471 |
-
|
472 |
-
def save(self, obj):
|
473 |
-
if self.has_saved:
|
474 |
-
raise RuntimeError("have already saved")
|
475 |
-
# Write the pickle data for `obj`
|
476 |
-
data_buf = BytesIO()
|
477 |
-
pickler = IncrementalPyTorchPickler(self, data_buf, protocol=5)
|
478 |
-
pickler.dump(obj)
|
479 |
-
data_value = data_buf.getvalue()
|
480 |
-
self.zipfile.write_record("data.pkl", data_value, len(data_value))
|
481 |
-
self.has_saved = True
|
482 |
-
|
483 |
-
def _write_storage_and_return_key(self, storage):
|
484 |
-
if self.has_saved:
|
485 |
-
raise RuntimeError("have already saved")
|
486 |
-
key = self.next_key
|
487 |
-
self.next_key += 1
|
488 |
-
name = f"data/{key}"
|
489 |
-
if storage.device.type != "cpu":
|
490 |
-
storage = storage.cpu()
|
491 |
-
num_bytes = storage.nbytes()
|
492 |
-
self.zipfile.write_record(name, storage.data_ptr(), num_bytes)
|
493 |
-
return key
|
494 |
-
|
495 |
-
def __exit__(self, type, value, traceback):
|
496 |
-
self.zipfile.write_end_of_file()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lit-llama/pretrain/redpajama.py
DELETED
@@ -1,321 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
import sys
|
3 |
-
import math
|
4 |
-
import glob
|
5 |
-
import time
|
6 |
-
from functools import partial
|
7 |
-
from pathlib import Path
|
8 |
-
from typing import Tuple, Optional
|
9 |
-
|
10 |
-
import lightning as L
|
11 |
-
from lightning.fabric.strategies import FSDPStrategy
|
12 |
-
|
13 |
-
import torch
|
14 |
-
from torch.utils.data import DataLoader
|
15 |
-
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
|
16 |
-
|
17 |
-
import numpy as np
|
18 |
-
|
19 |
-
# support running without installing as a package
|
20 |
-
wd = Path(__file__).parent.parent.resolve()
|
21 |
-
sys.path.append(str(wd))
|
22 |
-
|
23 |
-
from lit_llama.model import Block, LLaMA, LLaMAConfig
|
24 |
-
from lit_llama.packed_dataset import PackedDataset, CombinedDataset
|
25 |
-
from lit_llama.utils import save_model_checkpoint
|
26 |
-
|
27 |
-
|
28 |
-
out_dir = "out/training"
|
29 |
-
save_interval = 1000
|
30 |
-
eval_interval = 1000
|
31 |
-
eval_iters = 100
|
32 |
-
log_interval = 1
|
33 |
-
|
34 |
-
# compile = False
|
35 |
-
|
36 |
-
# Hyperparameters
|
37 |
-
learning_rate = 6e-4
|
38 |
-
batch_size = 125
|
39 |
-
micro_batch_size = 5
|
40 |
-
max_iters = 600000 # num_epochs * (epoch_size // micro_batch_size) // devices
|
41 |
-
weight_decay = 1e-1
|
42 |
-
beta1 = 0.9
|
43 |
-
beta2 = 0.95
|
44 |
-
grad_clip = 1.0
|
45 |
-
decay_lr = True
|
46 |
-
warmup_iters = 2000
|
47 |
-
lr_decay_iters = max_iters
|
48 |
-
min_lr = 6e-5
|
49 |
-
|
50 |
-
|
51 |
-
# Data proportions from https://arxiv.org/pdf/2302.13971.pdf Table 1
|
52 |
-
data_config = [
|
53 |
-
("arxiv", 2.5),
|
54 |
-
("book", 4.5),
|
55 |
-
("c4", 15.0),
|
56 |
-
("cc", 67.0),
|
57 |
-
("github", 4.5),
|
58 |
-
("stackexchange", 2.0),
|
59 |
-
("wikipedia", 4.5),
|
60 |
-
]
|
61 |
-
|
62 |
-
|
63 |
-
def main(
|
64 |
-
devices: int = 4,
|
65 |
-
train_data_dir: Path = "data/lit-redpajama",
|
66 |
-
val_data_dir: Optional[Path] = None,
|
67 |
-
) -> None:
|
68 |
-
auto_wrap_policy = partial(
|
69 |
-
transformer_auto_wrap_policy, transformer_layer_cls={Block}
|
70 |
-
)
|
71 |
-
strategy = FSDPStrategy(
|
72 |
-
auto_wrap_policy=auto_wrap_policy, activation_checkpointing=Block, limit_all_gathers=True
|
73 |
-
)
|
74 |
-
|
75 |
-
fabric = L.Fabric(
|
76 |
-
accelerator="cuda", devices=devices, precision="bf16-mixed", strategy=strategy
|
77 |
-
)
|
78 |
-
fabric.launch()
|
79 |
-
fabric.seed_everything(1337)
|
80 |
-
|
81 |
-
if fabric.global_rank == 0:
|
82 |
-
os.makedirs(out_dir, exist_ok=True)
|
83 |
-
|
84 |
-
config = LLaMAConfig.from_name("7B")
|
85 |
-
|
86 |
-
train_dataloader, val_dataloader = create_dataloaders(
|
87 |
-
batch_size=micro_batch_size,
|
88 |
-
block_size=config.block_size,
|
89 |
-
fabric=fabric,
|
90 |
-
train_data_dir=train_data_dir,
|
91 |
-
val_data_dir=val_data_dir,
|
92 |
-
seed=1338,
|
93 |
-
)
|
94 |
-
if val_dataloader is None:
|
95 |
-
train_dataloader = fabric.setup_dataloaders(train_dataloader)
|
96 |
-
else:
|
97 |
-
train_dataloader, val_dataloader = fabric.setup_dataloaders(train_dataloader, val_dataloader)
|
98 |
-
|
99 |
-
with fabric.device:
|
100 |
-
torch.set_default_dtype(torch.bfloat16)
|
101 |
-
model = LLaMA(config)
|
102 |
-
model.apply(model._init_weights)
|
103 |
-
torch.set_default_dtype(torch.float32)
|
104 |
-
|
105 |
-
# if compile:
|
106 |
-
# model = torch.compile(model)
|
107 |
-
|
108 |
-
optimizer = torch.optim.AdamW(
|
109 |
-
model.parameters(),
|
110 |
-
lr=learning_rate,
|
111 |
-
weight_decay=weight_decay,
|
112 |
-
betas=(beta1, beta2),
|
113 |
-
foreach=False,
|
114 |
-
)
|
115 |
-
|
116 |
-
model, optimizer = fabric.setup(model, optimizer)
|
117 |
-
|
118 |
-
process_batch_size = batch_size // devices
|
119 |
-
gradient_accumulation_iters = process_batch_size // micro_batch_size
|
120 |
-
|
121 |
-
train(fabric, model, optimizer, train_dataloader, val_dataloader, gradient_accumulation_iters, devices)
|
122 |
-
|
123 |
-
|
124 |
-
def train(
|
125 |
-
fabric: L.Fabric,
|
126 |
-
model: torch.nn.Module,
|
127 |
-
optimizer: torch.optim.Optimizer,
|
128 |
-
train_dataloader: DataLoader,
|
129 |
-
val_dataloader: Optional[DataLoader],
|
130 |
-
grad_accum_steps: int,
|
131 |
-
devices: int,
|
132 |
-
) -> None:
|
133 |
-
"""The training loop.
|
134 |
-
|
135 |
-
Loosely based on the nanoGPT implementation: https://github.com/karpathy/nanoGPT.
|
136 |
-
"""
|
137 |
-
|
138 |
-
step_count = 0
|
139 |
-
|
140 |
-
step_time = 0.0
|
141 |
-
tokens = 0
|
142 |
-
tokens_sec = 0.0
|
143 |
-
prev_t1 = time.time()
|
144 |
-
|
145 |
-
for iter_num, train_data in enumerate(train_dataloader):
|
146 |
-
t0 = time.time()
|
147 |
-
|
148 |
-
# determine and set the learning rate for this iteration
|
149 |
-
lr = get_lr(iter_num) if decay_lr else learning_rate
|
150 |
-
for param_group in optimizer.param_groups:
|
151 |
-
param_group["lr"] = lr
|
152 |
-
|
153 |
-
|
154 |
-
input_ids = train_data[:, 0 : model.config.block_size].contiguous()
|
155 |
-
targets = train_data[:, 1 : model.config.block_size + 1].contiguous()
|
156 |
-
|
157 |
-
is_accumulating = (iter_num + 1) % grad_accum_steps != 0
|
158 |
-
|
159 |
-
with fabric.no_backward_sync(model, enabled=is_accumulating):
|
160 |
-
logits = model(input_ids)
|
161 |
-
loss = torch.nn.functional.cross_entropy(
|
162 |
-
logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1
|
163 |
-
)
|
164 |
-
fabric.backward(loss / grad_accum_steps)
|
165 |
-
|
166 |
-
t1 = time.time()
|
167 |
-
|
168 |
-
if not is_accumulating:
|
169 |
-
fabric.clip_gradients(model, optimizer, max_norm=grad_clip)
|
170 |
-
|
171 |
-
optimizer.step()
|
172 |
-
optimizer.zero_grad()
|
173 |
-
step_count += 1
|
174 |
-
|
175 |
-
t1 = time.time()
|
176 |
-
|
177 |
-
if val_dataloader is not None and step_count % eval_interval == 0:
|
178 |
-
val_loss = validate(fabric, model, val_dataloader)
|
179 |
-
fabric.print(f"step {iter_num}: val loss {val_loss:.4f}")
|
180 |
-
fabric.barrier()
|
181 |
-
fabric.log_dict(
|
182 |
-
{"iter": iter_num, "val_loss": val_loss, "step": step_count, "lr": lr}
|
183 |
-
)
|
184 |
-
|
185 |
-
if step_count % save_interval == 0:
|
186 |
-
fabric.print(f"Saving checkpoint to {out_dir}")
|
187 |
-
save_model_checkpoint(
|
188 |
-
fabric, model, os.path.join(out_dir, f"iter-{iter_num:06d}-ckpt.pth")
|
189 |
-
)
|
190 |
-
|
191 |
-
dt = t1 - t0
|
192 |
-
|
193 |
-
tokens += micro_batch_size * model.config.block_size
|
194 |
-
step_time += t1 - prev_t1
|
195 |
-
prev_t1 = t1
|
196 |
-
|
197 |
-
if iter_num % log_interval == 0:
|
198 |
-
tokens_sec_str = f"{tokens / step_time:.0f}" if not is_accumulating else "-"
|
199 |
-
|
200 |
-
fabric.log_dict(
|
201 |
-
{"iter": iter_num, "train_loss": loss, "step": step_count, "lr": lr}
|
202 |
-
)
|
203 |
-
fabric.print(
|
204 |
-
f"iter {iter_num}: loss {loss.item():.4f}, time: {dt*1000:.2f}ms, speed: {tokens_sec_str} toks/s/device"
|
205 |
-
)
|
206 |
-
|
207 |
-
if not is_accumulating:
|
208 |
-
tokens = 0
|
209 |
-
step_time = 0.0
|
210 |
-
|
211 |
-
if iter_num > max_iters:
|
212 |
-
break
|
213 |
-
|
214 |
-
|
215 |
-
@torch.no_grad()
|
216 |
-
def validate(
|
217 |
-
fabric: L.Fabric, model: torch.nn.Module, val_dataloader: DataLoader
|
218 |
-
) -> torch.Tensor:
|
219 |
-
fabric.print("Validating ...")
|
220 |
-
model.eval()
|
221 |
-
losses = torch.zeros(eval_iters)
|
222 |
-
for k, val_data in enumerate(val_dataloader):
|
223 |
-
input_ids = val_data[:, 0 : model.config.block_size].contiguous()
|
224 |
-
targets = val_data[:, 1 : model.config.block_size + 1].contiguous()
|
225 |
-
logits = model(input_ids)
|
226 |
-
loss = torch.nn.functional.cross_entropy(
|
227 |
-
logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1
|
228 |
-
)
|
229 |
-
losses[k] = loss.item()
|
230 |
-
out = losses.mean()
|
231 |
-
model.train()
|
232 |
-
return out
|
233 |
-
|
234 |
-
|
235 |
-
def create_dataloader(
|
236 |
-
batch_size: int,
|
237 |
-
block_size: int,
|
238 |
-
data_dir: str,
|
239 |
-
fabric,
|
240 |
-
shuffle: bool = True,
|
241 |
-
seed: int = 12345,
|
242 |
-
) -> DataLoader:
|
243 |
-
datasets = []
|
244 |
-
for prefix, _ in data_config:
|
245 |
-
filenames = glob.glob(os.path.join(data_dir, prefix + "*"))
|
246 |
-
dataset = PackedDataset(
|
247 |
-
filenames, n_chunks=4, block_size=block_size, shuffle=shuffle, seed=seed,
|
248 |
-
num_processes=fabric.world_size, process_rank=fabric.global_rank,
|
249 |
-
)
|
250 |
-
datasets.append(dataset)
|
251 |
-
|
252 |
-
if not datasets:
|
253 |
-
raise RuntimeError(
|
254 |
-
f"No data found at {data_dir}. Make sure you ran prepare_redpajama.py to create the dataset."
|
255 |
-
)
|
256 |
-
|
257 |
-
weights = [weight for _, weight in data_config]
|
258 |
-
sum_weights = sum(weights)
|
259 |
-
weights = [el / sum_weights for el in weights]
|
260 |
-
|
261 |
-
combined_dataset = CombinedDataset(datasets=datasets, seed=seed, weights=weights)
|
262 |
-
|
263 |
-
return DataLoader(combined_dataset, batch_size=batch_size, shuffle=False, pin_memory=True)
|
264 |
-
|
265 |
-
|
266 |
-
def create_dataloaders(
|
267 |
-
batch_size: int,
|
268 |
-
block_size: int,
|
269 |
-
fabric,
|
270 |
-
train_data_dir: str = "data/lit-redpajama",
|
271 |
-
val_data_dir: Optional[str] = None,
|
272 |
-
seed: int = 12345,
|
273 |
-
) -> Tuple[DataLoader, DataLoader]:
|
274 |
-
# Increase by one because we need the next word as well
|
275 |
-
effective_block_size = block_size + 1
|
276 |
-
train_dataloader = create_dataloader(
|
277 |
-
batch_size=batch_size,
|
278 |
-
block_size=effective_block_size,
|
279 |
-
fabric=fabric,
|
280 |
-
data_dir=train_data_dir,
|
281 |
-
shuffle=True,
|
282 |
-
seed=seed,
|
283 |
-
)
|
284 |
-
val_dataloader = (
|
285 |
-
create_dataloader(
|
286 |
-
batch_size=batch_size,
|
287 |
-
block_size=effective_block_size,
|
288 |
-
fabric=fabric,
|
289 |
-
data_dir=val_data_dir,
|
290 |
-
shuffle=False,
|
291 |
-
seed=seed,
|
292 |
-
)
|
293 |
-
if val_data_dir
|
294 |
-
else None
|
295 |
-
)
|
296 |
-
return train_dataloader, val_dataloader
|
297 |
-
|
298 |
-
|
299 |
-
# learning rate decay scheduler (cosine with warmup)
|
300 |
-
def get_lr(it):
|
301 |
-
# 1) linear warmup for warmup_iters steps
|
302 |
-
if it < warmup_iters:
|
303 |
-
return learning_rate * it / warmup_iters
|
304 |
-
# 2) if it > lr_decay_iters, return min learning rate
|
305 |
-
if it > lr_decay_iters:
|
306 |
-
return min_lr
|
307 |
-
# 3) in between, use cosine decay down to min learning rate
|
308 |
-
decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
|
309 |
-
assert 0 <= decay_ratio <= 1
|
310 |
-
coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1
|
311 |
-
return min_lr + coeff * (learning_rate - min_lr)
|
312 |
-
|
313 |
-
|
314 |
-
if __name__ == "__main__":
|
315 |
-
# Uncomment this line if you see an error: "Expected is_sm80 to be true, but got false"
|
316 |
-
# torch.backends.cuda.enable_flash_sdp(False)
|
317 |
-
torch.set_float32_matmul_precision("high")
|
318 |
-
|
319 |
-
from jsonargparse.cli import CLI
|
320 |
-
|
321 |
-
CLI(main)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lit-llama/pretrain/shakespeare.py
DELETED
@@ -1,166 +0,0 @@
|
|
1 |
-
"""
|
2 |
-
This script is a placeholder for training LLaMA from scratch.
|
3 |
-
Currently, it just trains on the Shakespeare dataset.
|
4 |
-
"""
|
5 |
-
from pathlib import Path
|
6 |
-
import sys
|
7 |
-
import os
|
8 |
-
import time
|
9 |
-
from functools import partial
|
10 |
-
from typing import Tuple
|
11 |
-
|
12 |
-
import lightning as L
|
13 |
-
from lightning.fabric.strategies import FSDPStrategy
|
14 |
-
|
15 |
-
import torch
|
16 |
-
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
|
17 |
-
|
18 |
-
import numpy as np
|
19 |
-
|
20 |
-
# support running without installing as a package
|
21 |
-
wd = Path(__file__).parent.parent.resolve()
|
22 |
-
sys.path.append(str(wd))
|
23 |
-
|
24 |
-
from lit_llama.model import Block, LLaMA, LLaMAConfig
|
25 |
-
from lit_llama.utils import save_model_checkpoint
|
26 |
-
|
27 |
-
|
28 |
-
out_dir = "out/training"
|
29 |
-
eval_interval = 2000
|
30 |
-
eval_iters = 200
|
31 |
-
log_interval = 1
|
32 |
-
# compilation fails as it does not support torch.complex64 for RoPE
|
33 |
-
# compile = False
|
34 |
-
|
35 |
-
# Hyperparameters
|
36 |
-
learning_rate = 6e-4
|
37 |
-
batch_size = 2
|
38 |
-
max_iters = 600000
|
39 |
-
weight_decay = 1e-1
|
40 |
-
beta1 = 0.9
|
41 |
-
beta2 = 0.95
|
42 |
-
grad_clip = 1.0
|
43 |
-
|
44 |
-
# For shakespeare, choose smaller block size than vanilla LLaMA
|
45 |
-
block_size = 1024
|
46 |
-
|
47 |
-
|
48 |
-
def main() -> None:
|
49 |
-
auto_wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={Block})
|
50 |
-
strategy = FSDPStrategy(auto_wrap_policy=auto_wrap_policy, activation_checkpointing=Block, limit_all_gathers=True)
|
51 |
-
|
52 |
-
fabric = L.Fabric(accelerator="cuda", devices=4, precision="bf16-mixed", strategy=strategy)
|
53 |
-
fabric.launch()
|
54 |
-
fabric.seed_everything(1337 + fabric.global_rank)
|
55 |
-
|
56 |
-
if fabric.global_rank == 0:
|
57 |
-
os.makedirs(out_dir, exist_ok=True)
|
58 |
-
|
59 |
-
train_data, val_data = load_datasets()
|
60 |
-
|
61 |
-
config = LLaMAConfig.from_name("7B")
|
62 |
-
config.block_size = block_size
|
63 |
-
config.vocab_size = 100 # from prepare_shakespeare.py
|
64 |
-
|
65 |
-
with fabric.device:
|
66 |
-
model = LLaMA(config)
|
67 |
-
|
68 |
-
# if compile:
|
69 |
-
# model = torch.compile(model)
|
70 |
-
|
71 |
-
model = fabric.setup_module(model)
|
72 |
-
|
73 |
-
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2), foreach=False)
|
74 |
-
optimizer = fabric.setup_optimizers(optimizer)
|
75 |
-
|
76 |
-
train(fabric, model, optimizer, train_data, val_data)
|
77 |
-
|
78 |
-
|
79 |
-
def train(
|
80 |
-
fabric: L.Fabric,
|
81 |
-
model: torch.nn.Module,
|
82 |
-
optimizer: torch.optim.Optimizer,
|
83 |
-
train_data: np.ndarray,
|
84 |
-
val_data: np.ndarray,
|
85 |
-
) -> None:
|
86 |
-
"""The training loop.
|
87 |
-
|
88 |
-
Loosely based on the nanoGPT implementation: https://github.com/karpathy/nanoGPT.
|
89 |
-
"""
|
90 |
-
|
91 |
-
iter_num = 0
|
92 |
-
|
93 |
-
while True:
|
94 |
-
# TODO: add learning rate scheduling
|
95 |
-
|
96 |
-
# evaluate the loss on train/val sets and write checkpoints
|
97 |
-
if iter_num > 0 and iter_num % eval_interval == 0:
|
98 |
-
val_loss = validate(fabric, model, val_data)
|
99 |
-
fabric.print(f"step {iter_num}: val loss {val_loss:.4f}")
|
100 |
-
fabric.print(f"Saving checkpoint to {out_dir}")
|
101 |
-
save_model_checkpoint(fabric, model, os.path.join(out_dir, f"iter-{iter_num:06d}-ckpt.pth"))
|
102 |
-
|
103 |
-
t0 = time.time()
|
104 |
-
|
105 |
-
input_ids, targets = get_batch(
|
106 |
-
fabric,
|
107 |
-
train_data,
|
108 |
-
block_size=model.config.block_size, # type: ignore[union-attr,arg-type]
|
109 |
-
)
|
110 |
-
logits = model(input_ids)
|
111 |
-
loss = torch.nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
|
112 |
-
|
113 |
-
fabric.backward(loss)
|
114 |
-
|
115 |
-
# TODO: Gradient clipping
|
116 |
-
# if grad_clip != 0.0:
|
117 |
-
# fabric.clip_gradients(model, optimizer, max_norm=grad_clip)
|
118 |
-
|
119 |
-
optimizer.step()
|
120 |
-
optimizer.zero_grad()
|
121 |
-
|
122 |
-
dt = time.time() - t0
|
123 |
-
if iter_num % log_interval == 0:
|
124 |
-
fabric.print(f"iter {iter_num}: loss {loss.item():.4f}, time: {dt*1000:.2f}ms")
|
125 |
-
iter_num += 1
|
126 |
-
|
127 |
-
if iter_num > max_iters:
|
128 |
-
break
|
129 |
-
|
130 |
-
|
131 |
-
@torch.no_grad()
|
132 |
-
def validate(fabric: L.Fabric, model: torch.nn.Module, val_data: np.ndarray) -> torch.Tensor:
|
133 |
-
fabric.print("Validating ...")
|
134 |
-
model.eval()
|
135 |
-
losses = torch.zeros(eval_iters)
|
136 |
-
for k in range(eval_iters):
|
137 |
-
input_ids, targets = get_batch(
|
138 |
-
fabric,
|
139 |
-
val_data,
|
140 |
-
block_size=model.config.block_size, # type: ignore[union-attr,arg-type]
|
141 |
-
)
|
142 |
-
logits = model(input_ids)
|
143 |
-
loss = torch.nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
|
144 |
-
losses[k] = loss.item()
|
145 |
-
out = losses.mean()
|
146 |
-
model.train()
|
147 |
-
return out
|
148 |
-
|
149 |
-
|
150 |
-
def get_batch(fabric: L.Fabric, data: np.ndarray, block_size: int) -> Tuple[torch.Tensor, torch.Tensor]:
|
151 |
-
ix = torch.randint(len(data) - block_size, (batch_size,))
|
152 |
-
x = torch.stack([torch.from_numpy((data[i : i + block_size]).astype(np.int64)) for i in ix])
|
153 |
-
y = torch.stack([torch.from_numpy((data[i + 1 : i + 1 + block_size]).astype(np.int64)) for i in ix])
|
154 |
-
x, y = fabric.to_device((x.pin_memory(), y.pin_memory()))
|
155 |
-
return x, y
|
156 |
-
|
157 |
-
|
158 |
-
def load_datasets(data_dir: str = "data/shakespeare") -> Tuple[np.ndarray, np.ndarray]:
|
159 |
-
train_data = np.memmap(os.path.join(data_dir, "train.bin"), dtype=np.uint16, mode="r")
|
160 |
-
val_data = np.memmap(os.path.join(data_dir, "val.bin"), dtype=np.uint16, mode="r")
|
161 |
-
return train_data, val_data
|
162 |
-
|
163 |
-
|
164 |
-
if __name__ == "__main__":
|
165 |
-
torch.set_float32_matmul_precision("high")
|
166 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lit-llama/quantize/gptq.py
DELETED
@@ -1,238 +0,0 @@
|
|
1 |
-
# This adapts GPTQ's quantization process: https://github.com/IST-DASLab/gptq/
|
2 |
-
# E. Frantar et al GPTQ: Accurate Post-training Compression for GPT, arXiv:2210.17323
|
3 |
-
# portions copyright by the authors licensed under the Apache License 2.0
|
4 |
-
import gc
|
5 |
-
import sys
|
6 |
-
import time
|
7 |
-
from pathlib import Path
|
8 |
-
from typing import Optional
|
9 |
-
|
10 |
-
import torch
|
11 |
-
from datasets import load_dataset
|
12 |
-
|
13 |
-
# support running without installing as a package
|
14 |
-
wd = Path(__file__).parent.parent.resolve()
|
15 |
-
sys.path.append(str(wd))
|
16 |
-
|
17 |
-
from lit_llama import LLaMA, Tokenizer
|
18 |
-
from lit_llama.quantization import GPTQQuantizer
|
19 |
-
from lit_llama.utils import EmptyInitOnDevice, llama_model_lookup
|
20 |
-
|
21 |
-
|
22 |
-
def get_sample_data():
|
23 |
-
traindata = load_dataset(
|
24 |
-
"allenai/c4",
|
25 |
-
"allenai--c4",
|
26 |
-
data_files={"train": "en/c4-train.00000-of-01024.json.gz"},
|
27 |
-
split="train",
|
28 |
-
)
|
29 |
-
# heuristic for the data size?
|
30 |
-
txt = "\n".join(
|
31 |
-
traindata[i]["text"] for i in torch.randperm(len(traindata))[:1000].tolist()
|
32 |
-
)
|
33 |
-
return txt
|
34 |
-
|
35 |
-
|
36 |
-
@torch.no_grad()
|
37 |
-
def llama_blockwise_quantization(
|
38 |
-
model, sample_inputs, working_device, *, bits=4, groupsize=-1
|
39 |
-
):
|
40 |
-
"""
|
41 |
-
This is the classic post-training quantization of all linear layers.
|
42 |
-
We quantize in order, i.e. when observing the inputs, we use the outputs of the previously quantized layers rather
|
43 |
-
than doing them all at once.
|
44 |
-
"""
|
45 |
-
print(model)
|
46 |
-
print(model.config)
|
47 |
-
|
48 |
-
print("Getting inputs for first block")
|
49 |
-
model.transformer.wte.to(working_device)
|
50 |
-
sample_inputs = sample_inputs.to(working_device)
|
51 |
-
inps = model.transformer.wte(sample_inputs)
|
52 |
-
model.transformer.wte.to("cpu")
|
53 |
-
torch.cuda.empty_cache()
|
54 |
-
|
55 |
-
rope_cache = model.build_rope_cache(sample_inputs)
|
56 |
-
mask_cache = model.build_mask_cache(sample_inputs)
|
57 |
-
|
58 |
-
print("Starting to quantize blocks")
|
59 |
-
outs = torch.zeros_like(inps)
|
60 |
-
|
61 |
-
# better than relying on enumeration? originally the code bundled
|
62 |
-
# the two mlp fc layers
|
63 |
-
# we could automate this with a lot of hooks and another iteration
|
64 |
-
submodules_to_process = [
|
65 |
-
"attn.c_attn",
|
66 |
-
"attn.c_proj",
|
67 |
-
"mlp.c_fc1",
|
68 |
-
"mlp.c_fc2",
|
69 |
-
"mlp.c_proj",
|
70 |
-
]
|
71 |
-
|
72 |
-
for i, block in enumerate(model.transformer.h):
|
73 |
-
block.to(working_device)
|
74 |
-
|
75 |
-
for name in submodules_to_process:
|
76 |
-
print(i, name, end=" ")
|
77 |
-
t0 = time.perf_counter()
|
78 |
-
print("collecting stats", end=" ")
|
79 |
-
sys.stdout.flush()
|
80 |
-
module = block.get_submodule(name)
|
81 |
-
|
82 |
-
gptq = GPTQQuantizer(
|
83 |
-
module,
|
84 |
-
bits=bits,
|
85 |
-
groupsize=groupsize,
|
86 |
-
actorder=(groupsize == -1),
|
87 |
-
)
|
88 |
-
handle = module.register_forward_hook(gptq.collect_input_stats)
|
89 |
-
for j in range(inps.size(0)):
|
90 |
-
outs[j : j + 1], _ = block(
|
91 |
-
inps[j : j + 1],
|
92 |
-
rope=rope_cache,
|
93 |
-
mask=mask_cache,
|
94 |
-
max_seq_length=model.config.block_size
|
95 |
-
)
|
96 |
-
|
97 |
-
handle.remove()
|
98 |
-
|
99 |
-
print("quantizing", end=" ")
|
100 |
-
sys.stdout.flush()
|
101 |
-
q_module, error = gptq.quantize()
|
102 |
-
|
103 |
-
# replace the linear module with the quantized module
|
104 |
-
pname, dname = name.rsplit(".", 1)
|
105 |
-
setattr(block.get_submodule(pname), dname, q_module)
|
106 |
-
|
107 |
-
# cleanup in an attempt to not run out of memory
|
108 |
-
del gptq
|
109 |
-
gc.collect()
|
110 |
-
torch.cuda.empty_cache()
|
111 |
-
t1 = time.perf_counter()
|
112 |
-
print(f"time {int(t1 - t0 + 0.5)}s quantization error {error:.1f}")
|
113 |
-
|
114 |
-
for j in range(inps.size(0)):
|
115 |
-
outs[j : j + 1], _ = block(
|
116 |
-
inps[j : j + 1],
|
117 |
-
rope=rope_cache,
|
118 |
-
mask=mask_cache,
|
119 |
-
max_seq_length=model.config.block_size
|
120 |
-
)
|
121 |
-
|
122 |
-
block.cpu()
|
123 |
-
gc.collect()
|
124 |
-
torch.cuda.empty_cache()
|
125 |
-
|
126 |
-
# the outputs are the next block's inputs and we'll reuse the old inputs
|
127 |
-
inps, outs = outs, inps
|
128 |
-
|
129 |
-
model.transformer.ln_f.to(working_device)
|
130 |
-
for j in range(inps.size(0)):
|
131 |
-
outs[j : j + 1] = model.transformer.ln_f(inps[j : j + 1])
|
132 |
-
model.transformer.ln_f.to("cpu")
|
133 |
-
inps, outs = outs, inps
|
134 |
-
|
135 |
-
model.lm_head.to(working_device)
|
136 |
-
gptq = GPTQQuantizer(
|
137 |
-
model.lm_head,
|
138 |
-
bits=bits,
|
139 |
-
groupsize=groupsize,
|
140 |
-
actorder=(groupsize == -1),
|
141 |
-
)
|
142 |
-
handle = model.lm_head.register_forward_hook(gptq.collect_input_stats)
|
143 |
-
for j in range(inps.size(0)):
|
144 |
-
model.lm_head(inps[j : j + 1])
|
145 |
-
handle.remove()
|
146 |
-
q_module, error = gptq.quantize()
|
147 |
-
model.lm_head = q_module
|
148 |
-
model.lm_head.to("cpu")
|
149 |
-
|
150 |
-
|
151 |
-
def main(
|
152 |
-
*,
|
153 |
-
checkpoint_path: Path = Path("checkpoints/lit-llama/7B/lit-llama.pth"),
|
154 |
-
output_path: Optional[Path] = None,
|
155 |
-
tokenizer_path: Path = Path("checkpoints/lit-llama/tokenizer.model"),
|
156 |
-
n_samples: int = 128,
|
157 |
-
dtype: str = "float32",
|
158 |
-
quantize: Optional[str] = None,
|
159 |
-
) -> None:
|
160 |
-
"""Generates text samples based on a pre-trained LLaMA model and tokenizer.
|
161 |
-
|
162 |
-
Args:
|
163 |
-
checkpoint_path: The checkpoint path to load.
|
164 |
-
output_path: Path to write the quantized model's state dict to.
|
165 |
-
tokenizer_path: The tokenizer path to load.
|
166 |
-
n_samples: Number of example inputs to use for statistics (default: 128)
|
167 |
-
dtype: The dtype to use to load the model.
|
168 |
-
quantize: Mode to quantize the model to:
|
169 |
-
``"gptq.int4"``: GPTQ 4-bit mode.
|
170 |
-
Note that ``"llm.int8"```does not need a quantization step.
|
171 |
-
"""
|
172 |
-
assert checkpoint_path.is_file()
|
173 |
-
assert tokenizer_path.is_file()
|
174 |
-
if output_path is None:
|
175 |
-
output_path = checkpoint_path.parent / "llama-gptq.4bit.pth"
|
176 |
-
assert output_path.parent.is_dir() and (not output_path.exists() or output_path.is_file())
|
177 |
-
|
178 |
-
device = "cuda"
|
179 |
-
|
180 |
-
dt = getattr(torch, dtype, None)
|
181 |
-
if not isinstance(dt, torch.dtype):
|
182 |
-
raise ValueError(f"{dtype} is not a valid dtype.")
|
183 |
-
dtype = dt
|
184 |
-
|
185 |
-
if quantize == "gptq.int4":
|
186 |
-
bits = 4
|
187 |
-
elif quantize == "gptq.int8":
|
188 |
-
bits = 8
|
189 |
-
else:
|
190 |
-
raise RuntimeError(f"unknown/unsupported quantization mode {quantize}")
|
191 |
-
|
192 |
-
# we avoid loading the entire model on the GPU and do this block by block
|
193 |
-
with EmptyInitOnDevice(
|
194 |
-
device="cpu",
|
195 |
-
dtype=dtype,
|
196 |
-
):
|
197 |
-
print("Loading model ...", file=sys.stderr)
|
198 |
-
t0 = time.time()
|
199 |
-
checkpoint = torch.load(checkpoint_path)
|
200 |
-
name = llama_model_lookup(checkpoint)
|
201 |
-
model = LLaMA.from_name(name)
|
202 |
-
model.load_state_dict(checkpoint)
|
203 |
-
print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr)
|
204 |
-
|
205 |
-
model.eval()
|
206 |
-
|
207 |
-
tokenizer = Tokenizer(tokenizer_path)
|
208 |
-
|
209 |
-
test_string = get_sample_data()
|
210 |
-
encoded_text = tokenizer.encode(
|
211 |
-
test_string,
|
212 |
-
bos=True,
|
213 |
-
eos=False,
|
214 |
-
)
|
215 |
-
block_size = 2048 # this is for compat with gptq, and indeed we get much worse beyond this (https://github.com/facebookresearch/llama/blob/57b0eb62de0636e75af471e49e2f1862d908d9d8/llama/model.py#L30)
|
216 |
-
encoded_text = encoded_text[: n_samples * block_size].reshape(n_samples, block_size)
|
217 |
-
|
218 |
-
t0 = time.perf_counter()
|
219 |
-
llama_blockwise_quantization(model, encoded_text, device, bits=bits)
|
220 |
-
t = time.perf_counter() - t0
|
221 |
-
|
222 |
-
print(
|
223 |
-
f"\n\nTime for quantization: {t:.02f} sec total",
|
224 |
-
file=sys.stderr,
|
225 |
-
)
|
226 |
-
print(
|
227 |
-
f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB",
|
228 |
-
file=sys.stderr,
|
229 |
-
)
|
230 |
-
|
231 |
-
torch.save(model.state_dict(), output_path)
|
232 |
-
|
233 |
-
|
234 |
-
if __name__ == "__main__":
|
235 |
-
from jsonargparse import CLI
|
236 |
-
|
237 |
-
torch.set_float32_matmul_precision("high")
|
238 |
-
CLI(main)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lit-llama/requirements.txt
DELETED
@@ -1,9 +0,0 @@
|
|
1 |
-
torch>=2.0.0
|
2 |
-
lightning @ git+https://github.com/Lightning-AI/lightning@master
|
3 |
-
sentencepiece
|
4 |
-
tqdm # convert_checkpoint.py
|
5 |
-
numpy # train.py dataset memmap
|
6 |
-
jsonargparse[signatures] # generate.py, convert_checkpoint.py CLI
|
7 |
-
bitsandbytes # quantization.py
|
8 |
-
datasets # evaluate.py
|
9 |
-
zstandard # prepare_redpajama.py
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lit-llama/scripts/convert_checkpoint.py
DELETED
@@ -1,141 +0,0 @@
|
|
1 |
-
import gc
|
2 |
-
import shutil
|
3 |
-
from pathlib import Path
|
4 |
-
from typing import Dict
|
5 |
-
|
6 |
-
import torch
|
7 |
-
from tqdm import tqdm
|
8 |
-
|
9 |
-
"""
|
10 |
-
Sample usage:
|
11 |
-
|
12 |
-
```bash
|
13 |
-
python -m scripts.convert_checkpoint -h
|
14 |
-
|
15 |
-
python -m scripts.convert_checkpoint converted
|
16 |
-
```
|
17 |
-
"""
|
18 |
-
|
19 |
-
|
20 |
-
def convert_state_dict(state_dict: Dict[str, torch.Tensor], dtype: torch.dtype = torch.float32) -> Dict[str, torch.Tensor]:
|
21 |
-
converted = {}
|
22 |
-
converted["transformer.wte.weight"] = state_dict["tok_embeddings.weight"].to(dtype)
|
23 |
-
converted["lm_head.weight"] = state_dict["output.weight"].to(dtype)
|
24 |
-
converted["transformer.ln_f.scale"] = state_dict["norm.weight"].to(dtype)
|
25 |
-
|
26 |
-
for layer_idx in sorted(set([k.split(".")[1] for k in state_dict if k.startswith("layers")])):
|
27 |
-
# attention
|
28 |
-
# the wq, wk, wv from the FB model are stacked in our model as c_attn
|
29 |
-
converted[f"transformer.h.{layer_idx}.attn.c_attn.weight"] = torch.cat(
|
30 |
-
(
|
31 |
-
state_dict[f"layers.{layer_idx}.attention.wq.weight"].to(dtype),
|
32 |
-
state_dict[f"layers.{layer_idx}.attention.wk.weight"].to(dtype),
|
33 |
-
state_dict[f"layers.{layer_idx}.attention.wv.weight"].to(dtype),
|
34 |
-
)
|
35 |
-
)
|
36 |
-
converted[f"transformer.h.{layer_idx}.attn.c_proj.weight"] = state_dict[
|
37 |
-
f"layers.{layer_idx}.attention.wo.weight"
|
38 |
-
].to(dtype)
|
39 |
-
# mlp
|
40 |
-
converted[f"transformer.h.{layer_idx}.mlp.c_fc1.weight"] = state_dict[
|
41 |
-
f"layers.{layer_idx}.feed_forward.w1.weight"
|
42 |
-
].to(dtype)
|
43 |
-
converted[f"transformer.h.{layer_idx}.mlp.c_proj.weight"] = state_dict[
|
44 |
-
f"layers.{layer_idx}.feed_forward.w2.weight"
|
45 |
-
].to(dtype)
|
46 |
-
converted[f"transformer.h.{layer_idx}.mlp.c_fc2.weight"] = state_dict[
|
47 |
-
f"layers.{layer_idx}.feed_forward.w3.weight"
|
48 |
-
].to(dtype)
|
49 |
-
# rms norm
|
50 |
-
converted[f"transformer.h.{layer_idx}.rms_1.scale"] = state_dict[f"layers.{layer_idx}.attention_norm.weight"].to(dtype)
|
51 |
-
converted[f"transformer.h.{layer_idx}.rms_2.scale"] = state_dict[f"layers.{layer_idx}.ffn_norm.weight"].to(dtype)
|
52 |
-
return converted
|
53 |
-
|
54 |
-
|
55 |
-
shard_dims = {
|
56 |
-
"lm_head.weight": 0,
|
57 |
-
"wte.weight": 1,
|
58 |
-
"attn.c_attn.weight": 0,
|
59 |
-
"attn.c_proj.weight": 1,
|
60 |
-
"mlp.c_fc1.weight": 0,
|
61 |
-
"mlp.c_fc2.weight": 0,
|
62 |
-
"mlp.c_proj.weight": 1
|
63 |
-
}
|
64 |
-
|
65 |
-
|
66 |
-
def meta_weights_for_nano_model(
|
67 |
-
*,
|
68 |
-
output_dir: Path = Path("checkpoints/lit-llama"),
|
69 |
-
checkpoint_dir: Path = Path("checkpoints/llama/"),
|
70 |
-
model_size: str = "7B",
|
71 |
-
dtype: str = "float32",
|
72 |
-
) -> None:
|
73 |
-
output_dir = output_dir / model_size
|
74 |
-
checkpoint_dir = checkpoint_dir / model_size
|
75 |
-
output_dir.mkdir(parents=True, exist_ok=True)
|
76 |
-
|
77 |
-
# the tokenizer is the same for all model sizes, so we store it in the parent dir
|
78 |
-
shutil.copy(checkpoint_dir.parent / "tokenizer.model", output_dir.parent)
|
79 |
-
|
80 |
-
dt = getattr(torch, dtype, None)
|
81 |
-
if not isinstance(dt, torch.dtype):
|
82 |
-
raise ValueError(f"{dtype} is not a valid dtype.")
|
83 |
-
dtype = dt
|
84 |
-
|
85 |
-
checkpoint_files = sorted(checkpoint_dir.glob("*.pth"))
|
86 |
-
checkpoint_files.sort()
|
87 |
-
n_checkpoints = len(checkpoint_files)
|
88 |
-
|
89 |
-
if n_checkpoints == 0:
|
90 |
-
raise RuntimeError(f"No checkpoints were found at checkpoint_dir {checkpoint_dir}. `consolidated.0*.pth` files expected at that location.")
|
91 |
-
|
92 |
-
# for the bigger models, there are multiple model-parallel checkpoints
|
93 |
-
# and we combine them into one single file
|
94 |
-
combined = None
|
95 |
-
for file in tqdm(checkpoint_files, total=n_checkpoints):
|
96 |
-
checkpoint = torch.load(file, map_location="cpu")
|
97 |
-
converted = convert_state_dict(checkpoint, dtype=dtype)
|
98 |
-
if combined is None:
|
99 |
-
combined = converted
|
100 |
-
continue
|
101 |
-
for name, param in converted.items():
|
102 |
-
dim = None
|
103 |
-
for k, d in shard_dims.items():
|
104 |
-
if k in name:
|
105 |
-
dim = d
|
106 |
-
break
|
107 |
-
if dim is None:
|
108 |
-
# Extra check: assert that tensors are the same if not sharded
|
109 |
-
# assert torch.allclose(combined[name], param)
|
110 |
-
continue
|
111 |
-
combined[name] = torch.cat((combined[name], param), dim=dim)
|
112 |
-
|
113 |
-
del checkpoint
|
114 |
-
del converted
|
115 |
-
gc.collect()
|
116 |
-
|
117 |
-
for name, param in combined.items():
|
118 |
-
if "c_attn" not in name:
|
119 |
-
continue
|
120 |
-
|
121 |
-
# Turn [Q1, K1, V1, Q2, K2, V2, ...] into [Q1, Q2, ..., K1, K2, .., V1, V2, ...]
|
122 |
-
|
123 |
-
src_chunk_len = param.shape[0] // n_checkpoints
|
124 |
-
mat_len = src_chunk_len // 3
|
125 |
-
dst_chunk_len = mat_len * n_checkpoints
|
126 |
-
attn = torch.clone(param)
|
127 |
-
for i in range(n_checkpoints):
|
128 |
-
for j in range(3):
|
129 |
-
param[j * dst_chunk_len + i * mat_len: j * dst_chunk_len + (i+1) * mat_len] = \
|
130 |
-
attn[i * src_chunk_len + j * mat_len: i * src_chunk_len + (j+1) * mat_len]
|
131 |
-
|
132 |
-
del attn
|
133 |
-
gc.collect()
|
134 |
-
|
135 |
-
torch.save(combined, output_dir / "lit-llama.pth")
|
136 |
-
|
137 |
-
|
138 |
-
if __name__ == "__main__":
|
139 |
-
from jsonargparse import CLI
|
140 |
-
|
141 |
-
CLI(meta_weights_for_nano_model)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lit-llama/scripts/convert_hf_checkpoint.py
DELETED
@@ -1,167 +0,0 @@
|
|
1 |
-
import collections
|
2 |
-
import contextlib
|
3 |
-
import gc
|
4 |
-
import json
|
5 |
-
import shutil
|
6 |
-
import sys
|
7 |
-
from pathlib import Path
|
8 |
-
|
9 |
-
import torch
|
10 |
-
|
11 |
-
# support running without installing as a package
|
12 |
-
wd = Path(__file__).parent.parent.resolve()
|
13 |
-
sys.path.append(str(wd))
|
14 |
-
|
15 |
-
from lit_llama.model import LLaMA, LLaMAConfig
|
16 |
-
from lit_llama.utils import EmptyInitOnDevice, lazy_load, incremental_save
|
17 |
-
|
18 |
-
|
19 |
-
@torch.no_grad()
|
20 |
-
def convert_hf_checkpoint(
|
21 |
-
*,
|
22 |
-
output_dir: Path = Path("checkpoints/lit-llama/7B"),
|
23 |
-
checkpoint_dir: Path = Path("checkpoints/hf-llama/7B"),
|
24 |
-
model_size: str = "7B",
|
25 |
-
dtype: str = "float32",
|
26 |
-
verify: bool = False,
|
27 |
-
) -> None:
|
28 |
-
"""
|
29 |
-
Perform the reverse operation of: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/convert_llama_weights_to_hf.py
|
30 |
-
"""
|
31 |
-
output_dir.mkdir(parents=True, exist_ok=True)
|
32 |
-
|
33 |
-
# the tokenizer is the same for all model sizes, so we store it in the parent dir
|
34 |
-
shutil.copy(checkpoint_dir / "tokenizer.model", output_dir.parent)
|
35 |
-
|
36 |
-
dt = getattr(torch, dtype, None)
|
37 |
-
if not isinstance(dt, torch.dtype):
|
38 |
-
raise ValueError(f"{dtype} is not a valid dtype.")
|
39 |
-
dtype = dt
|
40 |
-
|
41 |
-
print("Initializing lit-llama")
|
42 |
-
config = LLaMAConfig.from_name(model_size)
|
43 |
-
|
44 |
-
with EmptyInitOnDevice(device="meta", dtype=dtype):
|
45 |
-
model = LLaMA(config)
|
46 |
-
|
47 |
-
qkv_size = model.transformer.h[0].attn.c_attn.weight.shape[0] // 3
|
48 |
-
|
49 |
-
# initialize a new empty state dict to hold our new weights
|
50 |
-
sd_meta = model.state_dict()
|
51 |
-
sd = {}
|
52 |
-
|
53 |
-
# Load the json file containing weight mapping
|
54 |
-
pytorch_bin_map_json_path = checkpoint_dir / "pytorch_model.bin.index.json"
|
55 |
-
with open(pytorch_bin_map_json_path) as json_map:
|
56 |
-
bin_index = json.load(json_map)
|
57 |
-
bin_files = set(checkpoint_dir / bin for bin in bin_index["weight_map"].values())
|
58 |
-
if not bin_files:
|
59 |
-
raise ValueError(f"Expected {str(checkpoint_dir)!r} to contain .bin files")
|
60 |
-
|
61 |
-
def permute(w):
|
62 |
-
dim = config.n_embd
|
63 |
-
w = w._load_tensor().to(dtype)
|
64 |
-
return (
|
65 |
-
w.view(config.n_head, 2, dim // config.n_head // 2, dim)
|
66 |
-
.transpose(1, 2)
|
67 |
-
.reshape(dim, dim)
|
68 |
-
)
|
69 |
-
|
70 |
-
weight_map = {
|
71 |
-
"self_attn.o_proj.weight": "attn.c_proj.weight",
|
72 |
-
"self_attn.q_proj.weight": "attn.c_attn.weight",
|
73 |
-
"self_attn.k_proj.weight": "attn.c_attn.weight",
|
74 |
-
"self_attn.v_proj.weight": "attn.c_attn.weight",
|
75 |
-
"mlp.gate_proj.weight": "mlp.c_fc1.weight",
|
76 |
-
"mlp.up_proj.weight": "mlp.c_fc2.weight",
|
77 |
-
"mlp.down_proj.weight": "mlp.c_proj.weight",
|
78 |
-
"input_layernorm.weight": "rms_1.scale",
|
79 |
-
"post_attention_layernorm.weight": "rms_2.scale",
|
80 |
-
"model.embed_tokens.weight": "transformer.wte.weight",
|
81 |
-
"model.norm.weight": "transformer.ln_f.scale",
|
82 |
-
"lm_head.weight": "lm_head.weight",
|
83 |
-
}
|
84 |
-
|
85 |
-
print(f"Saving to disk at {output_dir}")
|
86 |
-
unprocessed_weights = collections.defaultdict(dict)
|
87 |
-
|
88 |
-
with incremental_save(output_dir / "lit-llama.pth") as saver:
|
89 |
-
# for checkpoints that split the QKV across several files, we need to keep all the bin files
|
90 |
-
# open, so we use `ExitStack` to close them all together at the end
|
91 |
-
with contextlib.ExitStack() as stack:
|
92 |
-
for bin_file in bin_files:
|
93 |
-
print("Processing", bin_file)
|
94 |
-
hf_weights = stack.enter_context(lazy_load(bin_file))
|
95 |
-
for name, param in hf_weights.items():
|
96 |
-
skip = False
|
97 |
-
if "rotary_emb.inv_freq" in name:
|
98 |
-
continue
|
99 |
-
if "model.layers" in name:
|
100 |
-
block_id = int(name.split(".")[2])
|
101 |
-
from_name = ".".join(name.split(".")[3:])
|
102 |
-
to_name = weight_map[from_name]
|
103 |
-
sd_key = f"transformer.h.{block_id}.{to_name}"
|
104 |
-
|
105 |
-
if "q_proj" in name:
|
106 |
-
unprocessed_weights[sd_key]["q_proj"] = param
|
107 |
-
skip = True
|
108 |
-
elif "k_proj" in name:
|
109 |
-
unprocessed_weights[sd_key]["k_proj"] = param
|
110 |
-
skip = True
|
111 |
-
elif "v_proj" in name:
|
112 |
-
unprocessed_weights[sd_key]["v_proj"] = param
|
113 |
-
skip = True
|
114 |
-
if skip and len(unprocessed_weights[sd_key]) == 3:
|
115 |
-
w = torch.empty(
|
116 |
-
sd_meta[sd_key].shape, dtype=sd_meta[sd_key].dtype
|
117 |
-
)
|
118 |
-
w[:qkv_size] = permute(unprocessed_weights[sd_key]["q_proj"])
|
119 |
-
w[qkv_size:-qkv_size] = permute(
|
120 |
-
unprocessed_weights[sd_key]["k_proj"]
|
121 |
-
)
|
122 |
-
w[-qkv_size:] = (
|
123 |
-
unprocessed_weights[sd_key]["v_proj"]
|
124 |
-
._load_tensor()
|
125 |
-
.to(dtype)
|
126 |
-
)
|
127 |
-
sd[sd_key] = w
|
128 |
-
del unprocessed_weights[sd_key]
|
129 |
-
skip = False
|
130 |
-
else:
|
131 |
-
sd[sd_key] = param._load_tensor().to(dtype)
|
132 |
-
else:
|
133 |
-
sd_key = weight_map[name]
|
134 |
-
sd[sd_key] = param._load_tensor().to(dtype)
|
135 |
-
if not skip:
|
136 |
-
sd[sd_key] = saver.store_early(sd[sd_key])
|
137 |
-
gc.collect()
|
138 |
-
saver.save(sd)
|
139 |
-
|
140 |
-
assert len(unprocessed_weights) == 0, f"unexpected partial weights {list(unprocessed_weights)}"
|
141 |
-
if verify:
|
142 |
-
try:
|
143 |
-
from transformers import LlamaForCausalLM
|
144 |
-
except ImportError:
|
145 |
-
raise ImportError("verify=True requires transformers to be installed, please `pip install transformers`")
|
146 |
-
print("Verifying...")
|
147 |
-
|
148 |
-
token_sample = torch.randint(0, config.vocab_size, size=(1, config.block_size), dtype=torch.int64)
|
149 |
-
out = model(token_sample)
|
150 |
-
del model
|
151 |
-
gc.collect()
|
152 |
-
|
153 |
-
print("Loading original model for comparison")
|
154 |
-
model_hf = LlamaForCausalLM.from_pretrained(checkpoint_dir)
|
155 |
-
out_hf = model_hf(token_sample)["logits"]
|
156 |
-
|
157 |
-
print("Comparing outputs")
|
158 |
-
assert out.device.type == out_hf.device.type
|
159 |
-
assert out.dtype == out_hf.dtype
|
160 |
-
assert torch.testing.assert_close(out, out_hf)
|
161 |
-
|
162 |
-
|
163 |
-
if __name__ == "__main__":
|
164 |
-
from jsonargparse import CLI
|
165 |
-
|
166 |
-
CLI(convert_hf_checkpoint)
|
167 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lit-llama/scripts/convert_lora_weights.py
DELETED
@@ -1,95 +0,0 @@
|
|
1 |
-
import sys
|
2 |
-
import time
|
3 |
-
from pathlib import Path
|
4 |
-
from typing import Optional
|
5 |
-
|
6 |
-
import lightning as L
|
7 |
-
import torch
|
8 |
-
import torch.nn as nn
|
9 |
-
|
10 |
-
# support running without installing as a package
|
11 |
-
wd = Path(__file__).parent.parent.resolve()
|
12 |
-
sys.path.append(str(wd))
|
13 |
-
|
14 |
-
from lit_llama import LLaMA
|
15 |
-
from lit_llama.utils import EmptyInitOnDevice, lazy_load, llama_model_lookup
|
16 |
-
from lit_llama.lora import lora
|
17 |
-
|
18 |
-
def del_lora_state_dict(model: nn.Module):
|
19 |
-
base_model_dict = model.state_dict()
|
20 |
-
key_to_delete = [k for k in base_model_dict if "lora_" in k]
|
21 |
-
for del_key in key_to_delete:
|
22 |
-
del base_model_dict[del_key]
|
23 |
-
return base_model_dict
|
24 |
-
|
25 |
-
|
26 |
-
def lora_model_lookup(checkpoint: dict) -> int:
|
27 |
-
"""Returns the LoRA rank from the adapter checkpoint.
|
28 |
-
|
29 |
-
"""
|
30 |
-
return checkpoint["transformer.h.0.attn.c_attn.lora_B"].shape[1]
|
31 |
-
|
32 |
-
|
33 |
-
def main(
|
34 |
-
accelerator: str = "auto",
|
35 |
-
lora_path: Optional[Path] = None,
|
36 |
-
checkpoint_path: Optional[Path] = None,
|
37 |
-
dtype: str = "bfloat16",
|
38 |
-
) -> None:
|
39 |
-
"""Merges lora weights to base model.
|
40 |
-
|
41 |
-
Args:
|
42 |
-
accelerator: The hardware to run on. Possible choices are:
|
43 |
-
``"cpu"``, ``"cuda"``, ``"mps"``, ``"gpu"``, ``"tpu"``, ``"auto"``.
|
44 |
-
lora_path: Path to the checkpoint with trained LoRA weights, which are the output of
|
45 |
-
`finetune_lora.py`.
|
46 |
-
checkpoint_path: The checkpoint path to load.
|
47 |
-
dtype: `torch.dtype` to work with
|
48 |
-
"""
|
49 |
-
if not lora_path:
|
50 |
-
lora_path = Path("out/lora/alpaca/lit-llama-lora-finetuned.pth")
|
51 |
-
if not checkpoint_path:
|
52 |
-
checkpoint_path = Path(f"./checkpoints/lit-llama/7B/lit-llama.pth")
|
53 |
-
|
54 |
-
assert lora_path.is_file()
|
55 |
-
assert checkpoint_path.is_file()
|
56 |
-
|
57 |
-
fabric = L.Fabric(accelerator=accelerator, devices=1)
|
58 |
-
|
59 |
-
dt = getattr(torch, dtype, None)
|
60 |
-
if not isinstance(dt, torch.dtype):
|
61 |
-
raise ValueError(f"{dtype} is not a valid dtype.")
|
62 |
-
dtype = dt
|
63 |
-
|
64 |
-
print("Loading model ...", file=sys.stderr)
|
65 |
-
t0 = time.time()
|
66 |
-
|
67 |
-
with (lazy_load(checkpoint_path) as pretrained_checkpoint,
|
68 |
-
lazy_load(lora_path) as lora_checkpoint):
|
69 |
-
name = llama_model_lookup(pretrained_checkpoint)
|
70 |
-
rank = lora_model_lookup(lora_checkpoint)
|
71 |
-
|
72 |
-
with EmptyInitOnDevice(
|
73 |
-
device=fabric.device, dtype=dtype
|
74 |
-
), lora(r=rank, alpha=16, dropout=0.05, enabled=True):
|
75 |
-
model = LLaMA.from_name(name)
|
76 |
-
|
77 |
-
# 1. Load the pretrained weights
|
78 |
-
model.load_state_dict(pretrained_checkpoint, strict=False)
|
79 |
-
# 2. Load the fine-tuned lora weights
|
80 |
-
model.load_state_dict(lora_checkpoint, strict=False)
|
81 |
-
|
82 |
-
print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr)
|
83 |
-
|
84 |
-
model.eval()
|
85 |
-
base_model_dict = del_lora_state_dict(model)
|
86 |
-
save_path = lora_path.with_stem(f"{lora_path.stem}-lora-merged-weights")
|
87 |
-
print("Saving LoRA to base model weights ...")
|
88 |
-
torch.save(base_model_dict, save_path)
|
89 |
-
print(f"Model saved at {save_path}")
|
90 |
-
|
91 |
-
|
92 |
-
if __name__ == "__main__":
|
93 |
-
from jsonargparse import CLI
|
94 |
-
|
95 |
-
CLI(main)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lit-llama/scripts/download.py
DELETED
@@ -1,34 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
from typing import Optional
|
3 |
-
from urllib.request import urlretrieve
|
4 |
-
|
5 |
-
files = {
|
6 |
-
"original_model.py": "https://gist.githubusercontent.com/lantiga/fd36849fb1c498da949a0af635318a7b/raw/7dd20f51c2a1ff2886387f0e25c1750a485a08e1/llama_model.py",
|
7 |
-
"original_adapter.py": "https://gist.githubusercontent.com/awaelchli/546f33fcdb84cc9f1b661ca1ca18418d/raw/e81d8f35fb1fec53af1099349b0c455fc8c9fb01/original_adapter.py",
|
8 |
-
}
|
9 |
-
|
10 |
-
|
11 |
-
def download_original(wd: str) -> None:
|
12 |
-
for file, url in files.items():
|
13 |
-
filepath = os.path.join(wd, file)
|
14 |
-
if not os.path.isfile(filepath):
|
15 |
-
print(f"Downloading original implementation to {filepath!r}")
|
16 |
-
urlretrieve(url=url, filename=file)
|
17 |
-
print("Done")
|
18 |
-
else:
|
19 |
-
print("Original implementation found. Skipping download.")
|
20 |
-
|
21 |
-
|
22 |
-
def download_from_hub(repo_id: Optional[str] = None, local_dir: str = "checkpoints/hf-llama/7B") -> None:
|
23 |
-
if repo_id is None:
|
24 |
-
raise ValueError("Please pass `--repo_id=...`. You can try googling 'huggingface hub llama' for options.")
|
25 |
-
|
26 |
-
from huggingface_hub import snapshot_download
|
27 |
-
|
28 |
-
snapshot_download(repo_id, local_dir=local_dir, local_dir_use_symlinks=False)
|
29 |
-
|
30 |
-
|
31 |
-
if __name__ == "__main__":
|
32 |
-
from jsonargparse import CLI
|
33 |
-
|
34 |
-
CLI(download_from_hub)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lit-llama/scripts/prepare_alpaca.py
DELETED
@@ -1,131 +0,0 @@
|
|
1 |
-
"""Implementation derived from https://github.com/tloen/alpaca-lora"""
|
2 |
-
import sys
|
3 |
-
from pathlib import Path
|
4 |
-
|
5 |
-
# support running without installing as a package
|
6 |
-
wd = Path(__file__).parent.parent.resolve()
|
7 |
-
sys.path.append(str(wd))
|
8 |
-
|
9 |
-
import torch
|
10 |
-
import requests
|
11 |
-
import json
|
12 |
-
from torch.utils.data import random_split
|
13 |
-
from lit_llama.tokenizer import Tokenizer
|
14 |
-
from tqdm import tqdm
|
15 |
-
|
16 |
-
|
17 |
-
DATA_FILE = "https://raw.githubusercontent.com/tloen/alpaca-lora/main/alpaca_data_cleaned_archive.json"
|
18 |
-
DATA_FILE_NAME = "alpaca_data_cleaned_archive.json"
|
19 |
-
IGNORE_INDEX = -1
|
20 |
-
|
21 |
-
|
22 |
-
def prepare(
|
23 |
-
destination_path: Path = Path("data/alpaca"),
|
24 |
-
tokenizer_path: Path = Path("checkpoints/lit-llama/tokenizer.model"),
|
25 |
-
test_split_size: int = 2000,
|
26 |
-
max_seq_length: int = 256,
|
27 |
-
seed: int = 42,
|
28 |
-
mask_inputs: bool = False, # as in alpaca-lora
|
29 |
-
data_file_name: str = DATA_FILE_NAME
|
30 |
-
) -> None:
|
31 |
-
"""Prepare the Alpaca dataset for instruction tuning.
|
32 |
-
|
33 |
-
The output is a training and validation dataset saved as `train.pt` and `val.pt`,
|
34 |
-
which stores the preprocessed and tokenized prompts and labels.
|
35 |
-
"""
|
36 |
-
|
37 |
-
destination_path.mkdir(parents=True, exist_ok=True)
|
38 |
-
file_path = destination_path / data_file_name
|
39 |
-
download(file_path)
|
40 |
-
|
41 |
-
# TODO: If we don't have the Meta weights, where do we get the tokenizer from?
|
42 |
-
tokenizer = Tokenizer(tokenizer_path)
|
43 |
-
|
44 |
-
with open(file_path, "r") as file:
|
45 |
-
data = json.load(file)
|
46 |
-
|
47 |
-
# Partition the dataset into train and test
|
48 |
-
train_split_size = len(data) - test_split_size
|
49 |
-
train_set, test_set = random_split(
|
50 |
-
data,
|
51 |
-
lengths=(train_split_size, test_split_size),
|
52 |
-
generator=torch.Generator().manual_seed(seed),
|
53 |
-
)
|
54 |
-
train_set, test_set = list(train_set), list(test_set)
|
55 |
-
|
56 |
-
print(f"train has {len(train_set):,} samples")
|
57 |
-
print(f"val has {len(test_set):,} samples")
|
58 |
-
|
59 |
-
print("Processing train split ...")
|
60 |
-
train_set = [prepare_sample(sample, tokenizer, max_seq_length, mask_inputs) for sample in tqdm(train_set)]
|
61 |
-
torch.save(train_set, file_path.parent / "train.pt")
|
62 |
-
|
63 |
-
print("Processing test split ...")
|
64 |
-
test_set = [prepare_sample(sample, tokenizer, max_seq_length, mask_inputs) for sample in tqdm(test_set)]
|
65 |
-
torch.save(test_set, file_path.parent / "test.pt")
|
66 |
-
|
67 |
-
|
68 |
-
def download(file_path: Path):
|
69 |
-
"""Downloads the raw json data file and saves it in the given destination."""
|
70 |
-
if file_path.exists():
|
71 |
-
return
|
72 |
-
with open(file_path, "w") as f:
|
73 |
-
f.write(requests.get(DATA_FILE).text)
|
74 |
-
|
75 |
-
|
76 |
-
def prepare_sample(example: dict, tokenizer: Tokenizer, max_length: int, mask_inputs: bool = True):
|
77 |
-
"""Processes a single sample.
|
78 |
-
|
79 |
-
Each sample in the dataset consists of:
|
80 |
-
- instruction: A string describing the task
|
81 |
-
- input: A string holding a special input value for the instruction.
|
82 |
-
This only applies to some samples, and in others this is empty.
|
83 |
-
- output: The response string
|
84 |
-
|
85 |
-
This function processes this data to produce a prompt text and a label for
|
86 |
-
supervised training. The input text is formed as a single message including all
|
87 |
-
the instruction, the input (optional) and the response.
|
88 |
-
The label/target is the same message but can optionally have the instruction + input text
|
89 |
-
masked out (mask_inputs=True).
|
90 |
-
|
91 |
-
Finally, both the prompt and the label get tokenized. If desired, all tokens
|
92 |
-
in the label that correspond to the original input prompt get masked out (default).
|
93 |
-
"""
|
94 |
-
full_prompt = generate_prompt(example)
|
95 |
-
full_prompt_and_response = full_prompt + example["output"]
|
96 |
-
encoded_full_prompt = tokenize(tokenizer, full_prompt, max_length=max_length, eos=False)
|
97 |
-
encoded_full_prompt_and_response = tokenize(tokenizer, full_prompt_and_response, eos=True, max_length=max_length)
|
98 |
-
|
99 |
-
# The labels are the full prompt with response, but with the prompt masked out
|
100 |
-
labels = encoded_full_prompt_and_response.clone()
|
101 |
-
if mask_inputs:
|
102 |
-
labels[:len(encoded_full_prompt)] = IGNORE_INDEX
|
103 |
-
|
104 |
-
return {**example, "input_ids": encoded_full_prompt_and_response, "input_ids_no_response": encoded_full_prompt, "labels": labels}
|
105 |
-
|
106 |
-
|
107 |
-
def tokenize(tokenizer: Tokenizer, string: str, max_length: int, eos=True) -> torch.Tensor:
|
108 |
-
return tokenizer.encode(string, bos=True, eos=eos, max_length=max_length)
|
109 |
-
|
110 |
-
|
111 |
-
def generate_prompt(example):
|
112 |
-
"""Generates a standardized message to prompt the model with an instruction, optional input and a
|
113 |
-
'response' field."""
|
114 |
-
|
115 |
-
if example["input"]:
|
116 |
-
return (
|
117 |
-
"Below is an instruction that describes a task, paired with an input that provides further context. "
|
118 |
-
"Write a response that appropriately completes the request.\n\n"
|
119 |
-
f"### Instruction:\n{example['instruction']}\n\n### Input:\n{example['input']}\n\n### Response:"
|
120 |
-
)
|
121 |
-
return (
|
122 |
-
"Below is an instruction that describes a task. "
|
123 |
-
"Write a response that appropriately completes the request.\n\n"
|
124 |
-
f"### Instruction:\n{example['instruction']}\n\n### Response:"
|
125 |
-
)
|
126 |
-
|
127 |
-
|
128 |
-
if __name__ == "__main__":
|
129 |
-
from jsonargparse import CLI
|
130 |
-
|
131 |
-
CLI(prepare)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lit-llama/scripts/prepare_any_text.py
DELETED
@@ -1,97 +0,0 @@
|
|
1 |
-
"""Implementation derived from https://github.com/tloen/alpaca-lora"""
|
2 |
-
import sys
|
3 |
-
from pathlib import Path
|
4 |
-
|
5 |
-
# support running without installing as a package
|
6 |
-
wd = Path(__file__).parent.parent.resolve()
|
7 |
-
sys.path.append(str(wd))
|
8 |
-
|
9 |
-
import torch
|
10 |
-
import requests
|
11 |
-
import json
|
12 |
-
from torch.utils.data import random_split
|
13 |
-
from lit_llama.tokenizer import Tokenizer
|
14 |
-
from tqdm import tqdm
|
15 |
-
|
16 |
-
|
17 |
-
IGNORE_INDEX = -1
|
18 |
-
|
19 |
-
DATA_FILE_NAME = "input.txt"
|
20 |
-
|
21 |
-
|
22 |
-
def prepare(
|
23 |
-
destination_path: Path = Path("data/any"),
|
24 |
-
tokenizer_path: Path = Path("checkpoints/lit-llama/tokenizer.model"),
|
25 |
-
test_split_ratio: float = 0.9, # default 90% train, 10% validation
|
26 |
-
max_seq_length: int = 256,
|
27 |
-
seed: int = 42,
|
28 |
-
data_file_name: str = DATA_FILE_NAME,
|
29 |
-
) -> None:
|
30 |
-
"""Prepare any dataset for finetuning (akin to Shakespheare full tuning).
|
31 |
-
|
32 |
-
The output is a training and validation dataset saved as `train.pt` and `val.pt`,
|
33 |
-
which stores the preprocessed and tokenized prompts and labels.
|
34 |
-
"""
|
35 |
-
|
36 |
-
destination_path.mkdir(parents=True, exist_ok=True)
|
37 |
-
file_path = destination_path / data_file_name
|
38 |
-
if not file_path.exists():
|
39 |
-
raise AssertionError(f"{data_file_name} is provided by the user")
|
40 |
-
|
41 |
-
# TODO: If we don't have the Meta weights, where do we get the tokenizer from?
|
42 |
-
tokenizer = Tokenizer(tokenizer_path)
|
43 |
-
|
44 |
-
data = []
|
45 |
-
|
46 |
-
with open(file_path, "r") as input_file:
|
47 |
-
for line in input_file.readlines():
|
48 |
-
data.append(line)
|
49 |
-
|
50 |
-
# Partition the dataset into train and test
|
51 |
-
train_split_size = int(len(data) * test_split_ratio)
|
52 |
-
test_split_size = len(data) - train_split_size
|
53 |
-
train_set, test_set = random_split(
|
54 |
-
data,
|
55 |
-
lengths=(train_split_size, test_split_size),
|
56 |
-
generator=torch.Generator().manual_seed(seed),
|
57 |
-
)
|
58 |
-
train_set, test_set = list(train_set), list(test_set)
|
59 |
-
|
60 |
-
print(f"train has {len(train_set):,} samples")
|
61 |
-
print(f"val has {len(test_set):,} samples")
|
62 |
-
|
63 |
-
print("Processing train split ...")
|
64 |
-
train_set = [
|
65 |
-
prepare_line(line, tokenizer, max_seq_length) for line in tqdm(train_set)
|
66 |
-
]
|
67 |
-
torch.save(train_set, file_path.parent / "train.pt")
|
68 |
-
|
69 |
-
print("Processing test split ...")
|
70 |
-
test_set = [
|
71 |
-
prepare_line(line, tokenizer, max_seq_length) for line in tqdm(test_set)
|
72 |
-
]
|
73 |
-
torch.save(test_set, file_path.parent / "test.pt")
|
74 |
-
|
75 |
-
|
76 |
-
def prepare_line(line: str, tokenizer: Tokenizer, max_length: int):
|
77 |
-
"""Processes a single sample.
|
78 |
-
|
79 |
-
This function processes the line to produce the tokenized version of it.
|
80 |
-
"""
|
81 |
-
encoded_full_prompt = tokenize(tokenizer, line, max_length=max_length, eos=False)
|
82 |
-
return {
|
83 |
-
"input_ids": encoded_full_prompt,
|
84 |
-
"labels": encoded_full_prompt,
|
85 |
-
}
|
86 |
-
|
87 |
-
|
88 |
-
def tokenize(
|
89 |
-
tokenizer: Tokenizer, string: str, max_length: int, eos=True
|
90 |
-
) -> torch.Tensor:
|
91 |
-
return tokenizer.encode(string, bos=True, eos=eos, max_length=max_length)
|
92 |
-
|
93 |
-
|
94 |
-
if __name__ == "__main__":
|
95 |
-
from jsonargparse import CLI
|
96 |
-
|
97 |
-
CLI(prepare)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lit-llama/scripts/prepare_dolly.py
DELETED
@@ -1,133 +0,0 @@
|
|
1 |
-
"""Implementation derived from https://github.com/tloen/alpaca-lora"""
|
2 |
-
import sys
|
3 |
-
from pathlib import Path
|
4 |
-
|
5 |
-
# support running without installing as a package
|
6 |
-
wd = Path(__file__).parent.parent.resolve()
|
7 |
-
sys.path.append(str(wd))
|
8 |
-
|
9 |
-
import torch
|
10 |
-
import requests
|
11 |
-
import json
|
12 |
-
from torch.utils.data import random_split
|
13 |
-
from lit_llama.tokenizer import Tokenizer
|
14 |
-
from tqdm import tqdm
|
15 |
-
|
16 |
-
|
17 |
-
DATA_FILE = "https://huggingface.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.jsonl"
|
18 |
-
DATA_FILE_NAME = "dolly_data_cleaned.json"
|
19 |
-
IGNORE_INDEX = -1
|
20 |
-
|
21 |
-
|
22 |
-
def prepare(
|
23 |
-
destination_path: Path = Path("data/dolly"),
|
24 |
-
tokenizer_path: Path = Path("checkpoints/lit-llama/tokenizer.model"),
|
25 |
-
test_split_size: int = 2000,
|
26 |
-
max_seq_length: int = 1024,
|
27 |
-
seed: int = 42,
|
28 |
-
mask_inputs: bool = False, # as in alpaca-lora
|
29 |
-
) -> None:
|
30 |
-
"""Prepare the Dolly dataset for instruction tuning.
|
31 |
-
|
32 |
-
The output is a training and validation dataset saved as `train.pt` and `val.pt`,
|
33 |
-
which stores the preprocessed and tokenized prompts and labels.
|
34 |
-
"""
|
35 |
-
|
36 |
-
destination_path.mkdir(parents=True, exist_ok=True)
|
37 |
-
file_path = destination_path / DATA_FILE_NAME
|
38 |
-
download(file_path)
|
39 |
-
|
40 |
-
# TODO: If we don't have the Meta weights, where do we get the tokenizer from?
|
41 |
-
tokenizer = Tokenizer(tokenizer_path)
|
42 |
-
|
43 |
-
with open(file_path, "r") as file:
|
44 |
-
data = file.readlines()
|
45 |
-
data = [json.loads(line) for line in data]
|
46 |
-
for item in data:
|
47 |
-
item["input"] = item.pop("context")
|
48 |
-
item["output"] = item.pop("response")
|
49 |
-
|
50 |
-
# Partition the dataset into train and test
|
51 |
-
train_split_size = len(data) - test_split_size
|
52 |
-
train_set, test_set = random_split(
|
53 |
-
data,
|
54 |
-
lengths=(train_split_size, test_split_size),
|
55 |
-
generator=torch.Generator().manual_seed(seed),
|
56 |
-
)
|
57 |
-
train_set, test_set = list(train_set), list(test_set)
|
58 |
-
|
59 |
-
print(f"train has {len(train_set):,} samples")
|
60 |
-
print(f"val has {len(test_set):,} samples")
|
61 |
-
|
62 |
-
print("Processing train split ...")
|
63 |
-
train_set = [prepare_sample(sample, tokenizer, max_seq_length, mask_inputs) for sample in tqdm(train_set)]
|
64 |
-
torch.save(train_set, file_path.parent / "train.pt")
|
65 |
-
|
66 |
-
print("Processing test split ...")
|
67 |
-
test_set = [prepare_sample(sample, tokenizer, max_seq_length, mask_inputs) for sample in tqdm(test_set)]
|
68 |
-
torch.save(test_set, file_path.parent / "test.pt")
|
69 |
-
|
70 |
-
|
71 |
-
def download(file_path: Path):
|
72 |
-
"""Downloads the raw json data file and saves it in the given destination."""
|
73 |
-
if file_path.exists():
|
74 |
-
return
|
75 |
-
with open(file_path, "w") as f:
|
76 |
-
f.write(requests.get(DATA_FILE).text)
|
77 |
-
|
78 |
-
|
79 |
-
def prepare_sample(example: dict, tokenizer: Tokenizer, max_length: int, mask_inputs: bool = True):
|
80 |
-
"""Processes a single sample.
|
81 |
-
|
82 |
-
Each sample in the dataset consists of:
|
83 |
-
- instruction: A string describing the task
|
84 |
-
- input: A string holding a special input value for the instruction.
|
85 |
-
This only applies to some samples, and in others this is empty.
|
86 |
-
- output: The response string
|
87 |
-
|
88 |
-
This function processes this data to produce a prompt text and a label for
|
89 |
-
supervised training. The prompt text is formed as a single message including both
|
90 |
-
the instruction and the input. The label/target is the same message but with the
|
91 |
-
response attached.
|
92 |
-
|
93 |
-
Finally, both the prompt and the label get tokenized. If desired, all tokens
|
94 |
-
in the label that correspond to the original input prompt get masked out (default).
|
95 |
-
"""
|
96 |
-
full_prompt = generate_prompt(example)
|
97 |
-
full_prompt_and_response = full_prompt + example["output"]
|
98 |
-
encoded_full_prompt = tokenize(tokenizer, full_prompt, max_length=max_length, eos=False)
|
99 |
-
encoded_full_prompt_and_response = tokenize(tokenizer, full_prompt_and_response, eos=True, max_length=max_length)
|
100 |
-
|
101 |
-
# The labels are the full prompt with response, but with the prompt masked out
|
102 |
-
labels = encoded_full_prompt_and_response.clone()
|
103 |
-
if mask_inputs:
|
104 |
-
labels[:len(encoded_full_prompt)] = IGNORE_INDEX
|
105 |
-
|
106 |
-
return {**example, "input_ids": encoded_full_prompt_and_response, "input_ids_no_response": encoded_full_prompt, "labels": labels}
|
107 |
-
|
108 |
-
|
109 |
-
def tokenize(tokenizer: Tokenizer, string: str, max_length: int, eos=True) -> torch.Tensor:
|
110 |
-
return tokenizer.encode(string, bos=True, eos=eos, max_length=max_length)
|
111 |
-
|
112 |
-
|
113 |
-
def generate_prompt(example):
|
114 |
-
"""Generates a standardized message to prompt the model with an instruction, optional input and a
|
115 |
-
'response' field."""
|
116 |
-
|
117 |
-
if example["input"]:
|
118 |
-
return (
|
119 |
-
f"Below is an instruction that describes a task, paired with an input that provides further context. "
|
120 |
-
"Write a response that appropriately completes the request.\n\n"
|
121 |
-
f"### Instruction:\n{example['instruction']}\n\n### Input:\n{example['input']}\n\n### Response:"
|
122 |
-
)
|
123 |
-
return (
|
124 |
-
f"Below is an instruction that describes a task. "
|
125 |
-
"Write a response that appropriately completes the request.\n\n"
|
126 |
-
f"### Instruction:\n{example['instruction']}\n\n### Response:"
|
127 |
-
)
|
128 |
-
|
129 |
-
|
130 |
-
if __name__ == "__main__":
|
131 |
-
from jsonargparse import CLI
|
132 |
-
|
133 |
-
CLI(prepare)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|