TheBloke commited on
Commit
0209490
1 Parent(s): 785c802

Initial GPTQ model commit

Browse files
Files changed (1) hide show
  1. train_vicuna.py +205 -0
train_vicuna.py ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastchat.train.llama_flash_attn_monkey_patch import (
2
+ replace_llama_attn_with_flash_attn,
3
+ )
4
+
5
+ replace_llama_attn_with_flash_attn()
6
+
7
+ import json
8
+ from torch.utils.data import Dataset
9
+ from accelerate import Accelerator
10
+ from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer, AdamW
11
+ import torch
12
+ from torch.nn.utils.rnn import pad_sequence
13
+ from tqdm import tqdm
14
+ import numpy as np
15
+
16
+
17
+ IGNORE_TOKEN_ID = -100
18
+
19
+
20
+ class MixData(Dataset):
21
+ def __init__(self, dataset, ratio, tokenizer):
22
+ super(Dataset, self).__init__()
23
+ self.dataset = dataset
24
+ self.data_size = [len(c) for c in self.dataset]
25
+ ratio = [r if isinstance(r, int) else s for r, s in zip(ratio, self.data_size)]
26
+ self.ratio = ratio
27
+ self.tokenizer = tokenizer
28
+ self.sample_size = [int(self.data_size[0] / self.ratio[0] * r) for r in self.ratio]
29
+ print(self.data_size, self.sample_size, [c1 / c2 for c1, c2 in zip(self.sample_size, self.data_size)])
30
+
31
+ @staticmethod
32
+ def rounder(number):
33
+ rand = np.random.rand()
34
+ if rand < number - int(number):
35
+ return int(number) + 1
36
+ else:
37
+ return int(number)
38
+
39
+ @staticmethod
40
+ def choice_index(number, sample_size):
41
+ for i in range(len(sample_size)):
42
+ if number < sum(sample_size[:i + 1]):
43
+ return i, number - sum(sample_size[:i])
44
+
45
+ def __getitem__(self, index):
46
+ corpus_id, index = self.choice_index(index, self.sample_size)
47
+ rand = np.random.rand()
48
+ index = self.rounder((index + rand) / self.sample_size[corpus_id] * self.data_size[corpus_id])
49
+ index = min(index, len(self.dataset[corpus_id]) - 1)
50
+ return self.dataset[corpus_id][index]
51
+
52
+ def __len__(self):
53
+ return sum(self.sample_size)
54
+
55
+ def set_ratio(self, ratio):
56
+ self.ratio = ratio
57
+ self.data_size = [len(c) for c in self.dataset]
58
+ self.sample_size = [int(self.data_size[0] / self.ratio[0] * r) for r in self.ratio]
59
+ print(self.data_size, self.sample_size, [c1 / c2 for c1, c2 in zip(self.sample_size, self.data_size)])
60
+
61
+ def collate_fn(self, data):
62
+ input_ids, labels = zip(*data)
63
+ input_ids = pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)
64
+ labels = pad_sequence(labels, batch_first=True, padding_value=-100)
65
+ attention_mask = input_ids.ne(self.tokenizer.pad_token_id)
66
+ features = {
67
+ 'input_ids': input_ids.long(),
68
+ 'labels': labels.long(),
69
+ 'attention_mask': attention_mask.long(),
70
+ }
71
+ return features
72
+
73
+
74
+ def last_index(lst, value):
75
+ return next((len(lst) - i - 1 for i, x in enumerate(lst[::-1]) if x != value), -1)
76
+
77
+
78
+ def safe_ids(ids, max_value, pad_id):
79
+ return [i if i < max_value else pad_id for i in ids]
80
+
81
+
82
+ dummy_message = [{"role": "user", "content": "Who are you?"},
83
+ {"role": "assistant", "content": "I am vicuna, a language model trained by researchers from open-source community."},
84
+ {"role": "user", "content": "What can you do?"},
85
+ {"role": "assistant", "content": "I can chat with you."}]
86
+
87
+
88
+ def tokenize(messages, tokenizer):
89
+ roles = {"user": "USER", "assistant": "ASSISTANT"}
90
+ input_ids = []
91
+ labels = []
92
+ system = "A chat between a curious user and an artificial intelligence assistant. " \
93
+ "The assistant gives helpful, detailed, and polite answers to the user's questions."
94
+ system_ids = tokenizer.encode(system, add_special_tokens=False)
95
+ input_ids += system_ids
96
+ labels += [IGNORE_TOKEN_ID] * len(system_ids)
97
+ for i, turn in enumerate(messages):
98
+ role = roles.get(turn['role'], 'USER')
99
+ content = turn['content']
100
+ content = content.strip()
101
+ if role == 'ASSISTANT':
102
+ content += '</s>'
103
+ role_ids = tokenizer.encode(role + ":", add_special_tokens=False)
104
+ content_ids = tokenizer.encode(content, add_special_tokens=False, truncation=True,
105
+ max_length=tokenizer.model_max_length)
106
+ input_ids += role_ids + content_ids
107
+ if role == 'ASSISTANT':
108
+ labels += [IGNORE_TOKEN_ID] * len(role_ids) + content_ids
109
+ else:
110
+ labels += [IGNORE_TOKEN_ID] * (len(role_ids) + len(content_ids))
111
+
112
+ if tokenizer.add_bos_token:
113
+ input_ids = [tokenizer.bos_token_id] + input_ids
114
+ labels = [IGNORE_TOKEN_ID] + labels
115
+
116
+ input_ids = input_ids[:tokenizer.model_max_length]
117
+ labels = labels[:tokenizer.model_max_length]
118
+
119
+ trunc_id = last_index(labels, IGNORE_TOKEN_ID) + 1
120
+ input_ids = input_ids[:trunc_id]
121
+ labels = labels[:trunc_id]
122
+ if len(labels) == 0:
123
+ return tokenize(dummy_message, tokenizer)
124
+ input_ids = safe_ids(input_ids, tokenizer.vocab_size, tokenizer.pad_token_id)
125
+ labels = safe_ids(labels, tokenizer.vocab_size, IGNORE_TOKEN_ID)
126
+ return input_ids, labels
127
+
128
+
129
+ class VicunaData(Dataset):
130
+ def __init__(self, data, tokenizer):
131
+ self.data = data
132
+ self.tokenizer = tokenizer
133
+
134
+ def __len__(self):
135
+ return len(self.data)
136
+
137
+ def __getitem__(self, item):
138
+ item = self.data[item]
139
+ input_ids, labels = tokenize(item, self.tokenizer)
140
+ return torch.tensor(input_ids), torch.tensor(labels)
141
+
142
+ def collate_fn(self, data):
143
+ input_ids, labels = zip(*data)
144
+ input_ids = pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)
145
+ labels = pad_sequence(labels, batch_first=True, padding_value=IGNORE_TOKEN_ID)
146
+ attention_mask = input_ids.ne(self.tokenizer.pad_token_id)
147
+ features = {
148
+ 'input_ids': input_ids.long(),
149
+ 'labels': labels.long(),
150
+ 'attention_mask': attention_mask.long(),
151
+ }
152
+ return features
153
+
154
+
155
+ def main():
156
+ accelerator = Accelerator(gradient_accumulation_steps=4)
157
+ batch_size = 4
158
+
159
+ save_path = 'out/baichuan-vicuna-7b'
160
+ model_name = 'fireballoon/baichuan-llama-7b'
161
+
162
+ tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False, padding_side="right", model_max_length=4096)
163
+ tokenizer.pad_token = tokenizer.unk_token
164
+
165
+ model = AutoModelForCausalLM.from_pretrained(model_name)
166
+ model.config.use_cache = False
167
+ model.gradient_checkpointing_enable()
168
+
169
+ dataset = VicunaData(
170
+ json.load(open('data/new/share_gpt-90k.json')) +
171
+ json.load(open('data/new/cot-75k.json')) +
172
+ json.load(open('data/new/leet-9k.json')), tokenizer)
173
+
174
+ print(len(dataset))
175
+
176
+ data_loader = torch.utils.data.DataLoader(dataset, collate_fn=dataset.collate_fn,
177
+ batch_size=batch_size, num_workers=0, shuffle=True)
178
+
179
+ optimizer = AdamW(model.parameters(), 2e-5)
180
+ model, optimizer, data_loader = accelerator.prepare(model, optimizer, data_loader)
181
+
182
+ for epoch in range(10):
183
+ accelerator.print(f'Training {save_path} {epoch}')
184
+ accelerator.wait_for_everyone()
185
+ model.train()
186
+ tk0 = tqdm(data_loader, total=len(data_loader))
187
+ loss_report = []
188
+ for batch in tk0:
189
+ with accelerator.accumulate(model):
190
+ out = model(**batch)
191
+ loss = out.loss
192
+
193
+ accelerator.backward(loss)
194
+ accelerator.clip_grad_norm_(model.parameters(), 1.)
195
+ optimizer.step()
196
+ optimizer.zero_grad()
197
+
198
+ loss_report.append(accelerator.gather(loss).mean().item())
199
+ tk0.set_postfix(loss=sum(loss_report[-100:]) / len(loss_report[-100:]))
200
+ accelerator.wait_for_everyone()
201
+ model.save_checkpoint(f'{save_path}/{epoch}')
202
+
203
+
204
+ if __name__ == '__main__':
205
+ main()