Initial GPTQ model commit
train_vicuna.py
ADDED
@@ -0,0 +1,205 @@
# train_vicuna.py: fine-tune a LLaMA-architecture model on Vicuna-format chat
# data with FlashAttention, gradient checkpointing, and HuggingFace Accelerate.

from fastchat.train.llama_flash_attn_monkey_patch import (
    replace_llama_attn_with_flash_attn,
)

# Patch LLaMA attention with FlashAttention before the model is loaded.
replace_llama_attn_with_flash_attn()

import json

import numpy as np
import torch
from accelerate import Accelerator
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

IGNORE_TOKEN_ID = -100  # label value that the loss function ignores


class MixData(Dataset):
    """Mixes several corpora, resampling each one according to `ratio`."""

    def __init__(self, dataset, ratio, tokenizer):
        super().__init__()
        self.dataset = dataset
        self.data_size = [len(c) for c in self.dataset]
        # Non-integer ratio entries fall back to the corpus' natural size.
        ratio = [r if isinstance(r, int) else s for r, s in zip(ratio, self.data_size)]
        self.ratio = ratio
        self.tokenizer = tokenizer
        # Per-corpus sample quota, scaled so the first corpus keeps its size.
        self.sample_size = [int(self.data_size[0] / self.ratio[0] * r) for r in self.ratio]
        print(self.data_size, self.sample_size,
              [c1 / c2 for c1, c2 in zip(self.sample_size, self.data_size)])

    @staticmethod
    def rounder(number):
        # Stochastic rounding: round up with probability equal to the fractional part.
        rand = np.random.rand()
        if rand < number - int(number):
            return int(number) + 1
        else:
            return int(number)

    @staticmethod
    def choice_index(number, sample_size):
        # Map a flat index to (corpus id, offset inside that corpus' quota).
        for i in range(len(sample_size)):
            if number < sum(sample_size[:i + 1]):
                return i, number - sum(sample_size[:i])

    def __getitem__(self, index):
        corpus_id, index = self.choice_index(index, self.sample_size)
        rand = np.random.rand()
        # Stretch the quota offset back onto the full corpus, with random jitter.
        index = self.rounder((index + rand) / self.sample_size[corpus_id] * self.data_size[corpus_id])
        index = min(index, len(self.dataset[corpus_id]) - 1)
        return self.dataset[corpus_id][index]

    def __len__(self):
        return sum(self.sample_size)

    def set_ratio(self, ratio):
        self.ratio = ratio
        self.data_size = [len(c) for c in self.dataset]
        self.sample_size = [int(self.data_size[0] / self.ratio[0] * r) for r in self.ratio]
        print(self.data_size, self.sample_size,
              [c1 / c2 for c1, c2 in zip(self.sample_size, self.data_size)])

    def collate_fn(self, data):
        input_ids, labels = zip(*data)
        input_ids = pad_sequence(input_ids, batch_first=True,
                                 padding_value=self.tokenizer.pad_token_id)
        labels = pad_sequence(labels, batch_first=True, padding_value=IGNORE_TOKEN_ID)
        attention_mask = input_ids.ne(self.tokenizer.pad_token_id)
        features = {
            'input_ids': input_ids.long(),
            'labels': labels.long(),
            'attention_mask': attention_mask.long(),
        }
        return features


def last_index(lst, value):
    """Index of the last element of `lst` that differs from `value`, or -1."""
    return next((len(lst) - i - 1 for i, x in enumerate(lst[::-1]) if x != value), -1)


def safe_ids(ids, max_value, pad_id):
    """Replace out-of-vocabulary token ids with `pad_id`."""
    return [i if i < max_value else pad_id for i in ids]


dummy_message = [
    {"role": "user", "content": "Who are you?"},
    {"role": "assistant",
     "content": "I am vicuna, a language model trained by researchers from the open-source community."},
    {"role": "user", "content": "What can you do?"},
    {"role": "assistant", "content": "I can chat with you."},
]


def tokenize(messages, tokenizer):
    """Tokenize a Vicuna-format conversation; only assistant replies get labels."""
    roles = {"user": "USER", "assistant": "ASSISTANT"}
    input_ids = []
    labels = []
    system = "A chat between a curious user and an artificial intelligence assistant. " \
             "The assistant gives helpful, detailed, and polite answers to the user's questions."
    system_ids = tokenizer.encode(system, add_special_tokens=False)
    input_ids += system_ids
    labels += [IGNORE_TOKEN_ID] * len(system_ids)
    for turn in messages:
        role = roles.get(turn['role'], 'USER')
        content = turn['content'].strip()
        if role == 'ASSISTANT':
            content += '</s>'  # close each assistant reply with EOS
        role_ids = tokenizer.encode(role + ":", add_special_tokens=False)
        content_ids = tokenizer.encode(content, add_special_tokens=False, truncation=True,
                                       max_length=tokenizer.model_max_length)
        input_ids += role_ids + content_ids
        if role == 'ASSISTANT':
            # Mask the role prefix; train on the reply itself.
            labels += [IGNORE_TOKEN_ID] * len(role_ids) + content_ids
        else:
            labels += [IGNORE_TOKEN_ID] * (len(role_ids) + len(content_ids))

    if tokenizer.add_bos_token:
        input_ids = [tokenizer.bos_token_id] + input_ids
        labels = [IGNORE_TOKEN_ID] + labels

    input_ids = input_ids[:tokenizer.model_max_length]
    labels = labels[:tokenizer.model_max_length]

    # Drop any trailing span that contains no trainable labels.
    trunc_id = last_index(labels, IGNORE_TOKEN_ID) + 1
    input_ids = input_ids[:trunc_id]
    labels = labels[:trunc_id]
    if len(labels) == 0:
        # Nothing trainable survived truncation; fall back to the dummy dialogue.
        return tokenize(dummy_message, tokenizer)
    input_ids = safe_ids(input_ids, tokenizer.vocab_size, tokenizer.pad_token_id)
    labels = safe_ids(labels, tokenizer.vocab_size, IGNORE_TOKEN_ID)
    return input_ids, labels


class VicunaData(Dataset):
    """Thin Dataset wrapper that tokenizes conversations on the fly."""

    def __init__(self, data, tokenizer):
        self.data = data
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.data)

    def __getitem__(self, item):
        item = self.data[item]
        input_ids, labels = tokenize(item, self.tokenizer)
        return torch.tensor(input_ids), torch.tensor(labels)

    def collate_fn(self, data):
        input_ids, labels = zip(*data)
        input_ids = pad_sequence(input_ids, batch_first=True,
                                 padding_value=self.tokenizer.pad_token_id)
        labels = pad_sequence(labels, batch_first=True, padding_value=IGNORE_TOKEN_ID)
        attention_mask = input_ids.ne(self.tokenizer.pad_token_id)
        features = {
            'input_ids': input_ids.long(),
            'labels': labels.long(),
            'attention_mask': attention_mask.long(),
        }
        return features


def main():
    accelerator = Accelerator(gradient_accumulation_steps=4)
    batch_size = 4

    save_path = 'out/baichuan-vicuna-7b'
    model_name = 'fireballoon/baichuan-llama-7b'

    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False,
                                              padding_side="right", model_max_length=4096)
    tokenizer.pad_token = tokenizer.unk_token

    model = AutoModelForCausalLM.from_pretrained(model_name)
    model.config.use_cache = False  # the KV cache is useless under gradient checkpointing
    model.gradient_checkpointing_enable()

    dataset = VicunaData(
        json.load(open('data/new/share_gpt-90k.json')) +
        json.load(open('data/new/cot-75k.json')) +
        json.load(open('data/new/leet-9k.json')), tokenizer)

    print(len(dataset))

    data_loader = torch.utils.data.DataLoader(dataset, collate_fn=dataset.collate_fn,
                                              batch_size=batch_size, num_workers=0, shuffle=True)

    # torch.optim.AdamW replaces the deprecated transformers.AdamW import.
    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
    model, optimizer, data_loader = accelerator.prepare(model, optimizer, data_loader)

    for epoch in range(10):
        accelerator.print(f'Training {save_path} {epoch}')
        accelerator.wait_for_everyone()
        model.train()
        tk0 = tqdm(data_loader, total=len(data_loader))
        loss_report = []
        for batch in tk0:
            with accelerator.accumulate(model):
                out = model(**batch)
                loss = out.loss

                accelerator.backward(loss)
                accelerator.clip_grad_norm_(model.parameters(), 1.)
                optimizer.step()
                optimizer.zero_grad()

            loss_report.append(accelerator.gather(loss).mean().item())
            tk0.set_postfix(loss=sum(loss_report[-100:]) / len(loss_report[-100:]))
        accelerator.wait_for_everyone()
        # save_checkpoint comes from the DeepSpeed engine; without the DeepSpeed
        # plugin, accelerator.save_state would be used instead.
        model.save_checkpoint(f'{save_path}/{epoch}')


if __name__ == '__main__':
    main()
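
For reference, here is a minimal sketch of how the label masking produced by tokenize can be inspected. It is not part of the commit: it assumes the definitions above are already in scope (for example, pasted into a REPL, since importing the module would also trigger the FlashAttention patch) and that the tokenizer named in main() is available locally.

# Sketch, not part of the commit; assumes tokenize, dummy_message, and
# IGNORE_TOKEN_ID from train_vicuna.py are already defined in this session.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('fireballoon/baichuan-llama-7b',
                                          use_fast=False, model_max_length=4096)
tokenizer.pad_token = tokenizer.unk_token

input_ids, labels = tokenize(dummy_message, tokenizer)
# '*' marks tokens that contribute to the loss (assistant replies only).
for tok, lab in zip(input_ids, labels):
    marker = '*' if lab != IGNORE_TOKEN_ID else ' '
    print(marker, repr(tokenizer.decode([tok])))

Since the trainer is built on Accelerate, the script would presumably be started with "accelerate launch train_vicuna.py" (the commit itself does not include a launch command), and the model.save_checkpoint call at the end of each epoch suggests a DeepSpeed-backed configuration.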