helloollel committed
Commit 3044532
1 Parent(s): ff50e49

Update README.md

Files changed (1): README.md (+238, -3)

README.md CHANGED
@@ -1,3 +1,238 @@
- ---
- license: apache-2.0
- ---

# vicuna-7b

This README provides a step-by-step guide to setting up and running FastChat with the required dependencies and the `vicuna-7b` model weights.

## Prerequisites

Before you proceed, ensure that `git`, Python 3 with `pip`, and the CUDA toolkit (version 11.8 is assumed by the `bitsandbytes` build step below) are installed on your system. The `git-lfs` step below assumes a Debian/Ubuntu system with `apt-get`.
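
A quick way to confirm these from a shell (an optional check, not part of the original steps):

```bash
git --version              # git, used to clone the repositories below
python3 -m pip --version   # Python 3 with pip
nvcc --version             # CUDA toolkit, needed to build bitsandbytes from source
```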

## Installation

Follow the steps below to install the required packages and set up the environment.

1. Upgrade `pip`:

```bash
python3 -m pip install --upgrade pip
```

2. Install `accelerate`:

```bash
python3 -m pip install accelerate
```

3. Clone the `bitsandbytes` repository and install it:

```bash
git clone https://github.com/TimDettmers/bitsandbytes.git
cd bitsandbytes
CUDA_VERSION=118 make cuda11x
python3 -m pip install .
cd ..
```
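
Optionally, you can sanity-check that the freshly built package imports without CUDA setup errors (a minimal check, not part of the original steps). If your CUDA version is not 11.8, adjust `CUDA_VERSION` and the make target as described in the `bitsandbytes` README.

```bash
python3 -c "import bitsandbytes"
```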

4. Clone the `FastChat` repository and install it:

```bash
git clone https://github.com/lm-sys/FastChat.git
cd FastChat
python3 -m pip install -e .
cd ..
```

5. Install `git-lfs`:

```bash
curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.deb.sh | sudo bash
sudo apt-get install git-lfs
git lfs install
```

6. Clone the `vicuna-7b` model:

```bash
git clone https://huggingface.co/helloollel/vicuna-7b
```
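
If the clone completes but the weight files show up as small LFS pointer files, pulling them explicitly usually fixes it (an optional troubleshooting step, not part of the original instructions):

```bash
cd vicuna-7b
git lfs pull
cd ..
```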

## Running FastChat

After completing the installation, you can run FastChat with the following command:

```bash
python3 -m fastchat.serve.cli --model-name ./vicuna-7b
```

This starts FastChat's interactive command-line chat using the `vicuna-7b` model.
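
The CLI also exposes options for device placement and quantization that mirror the parameters used in the notebook code below. Exact flag names can differ between FastChat versions, so treat the variations here as a sketch and rely on `--help`:

```bash
# List the flags supported by your FastChat version.
python3 -m fastchat.serve.cli --help

# Hypothetical variations (flag names are assumptions, verify with --help):
python3 -m fastchat.serve.cli --model-name ./vicuna-7b --device cpu   # run on CPU
python3 -m fastchat.serve.cli --model-name ./vicuna-7b --load-8bit    # 8-bit weights to reduce GPU memory
```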

## Running in a Notebook

The same model can be driven from a notebook. The snippet below loads the model and tokenizer with `load_model`, streams tokens with `generate_stream` (adapted from FastChat's model worker), and wraps both in a `chat()` helper that keeps the running conversation in `conv`:

```python
import argparse
import time

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaTokenizer

from fastchat.conversation import conv_templates, SeparatorStyle
from fastchat.serve.monkey_patch_non_inplace import replace_llama_attn_with_non_inplace_operations


def load_model(model_name, device, num_gpus, load_8bit=False):
    if device == "cpu":
        kwargs = {}
    elif device == "cuda":
        kwargs = {"torch_dtype": torch.float16}
        if load_8bit:
            if num_gpus != "auto" and int(num_gpus) != 1:
                print("8-bit weights are not supported on multiple GPUs. Reverting to a single GPU.")
            kwargs.update({"load_in_8bit": True, "device_map": "auto"})
        else:
            if num_gpus == "auto":
                kwargs["device_map"] = "auto"
            else:
                num_gpus = int(num_gpus)
                if num_gpus != 1:
                    kwargs.update({
                        "device_map": "auto",
                        "max_memory": {i: "13GiB" for i in range(num_gpus)},
                    })
    elif device == "mps":
        # Avoid bugs in the mps backend by not using in-place operations.
        kwargs = {"torch_dtype": torch.float16}
        replace_llama_attn_with_non_inplace_operations()
    else:
        raise ValueError(f"Invalid device: {device}")

    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
    model = AutoModelForCausalLM.from_pretrained(
        model_name, low_cpu_mem_usage=True, **kwargs)

    # Calling model.cuda() would mess up the weights when loading 8-bit weights.
    if device == "cuda" and num_gpus == 1 and not load_8bit:
        model.to("cuda")
    elif device == "mps":
        model.to("mps")

    return model, tokenizer


@torch.inference_mode()
def generate_stream(tokenizer, model, params, device,
                    context_len=2048, stream_interval=2):
    """Adapted from fastchat/serve/model_worker.py::generate_stream"""

    prompt = params["prompt"]
    l_prompt = len(prompt)
    temperature = float(params.get("temperature", 1.0))
    max_new_tokens = int(params.get("max_new_tokens", 256))
    stop_str = params.get("stop", None)

    input_ids = tokenizer(prompt).input_ids
    output_ids = list(input_ids)

    # Truncate the prompt so it fits in the context window with the new tokens.
    max_src_len = context_len - max_new_tokens - 8
    input_ids = input_ids[-max_src_len:]

    for i in range(max_new_tokens):
        if i == 0:
            # First step: feed the whole prompt and cache the key/value states.
            out = model(
                torch.as_tensor([input_ids], device=device), use_cache=True)
            logits = out.logits
            past_key_values = out.past_key_values
        else:
            # Later steps: feed only the last token and reuse the cached states.
            attention_mask = torch.ones(
                1, past_key_values[0][0].shape[-2] + 1, device=device)
            out = model(input_ids=torch.as_tensor([[token]], device=device),
                        use_cache=True,
                        attention_mask=attention_mask,
                        past_key_values=past_key_values)
            logits = out.logits
            past_key_values = out.past_key_values

        last_token_logits = logits[0][-1]

        if device == "mps":
            # Switch to CPU to avoid some bugs in the mps backend.
            last_token_logits = last_token_logits.float().to("cpu")

        if temperature < 1e-4:
            token = int(torch.argmax(last_token_logits))
        else:
            probs = torch.softmax(last_token_logits / temperature, dim=-1)
            token = int(torch.multinomial(probs, num_samples=1))

        output_ids.append(token)

        if token == tokenizer.eos_token_id:
            stopped = True
        else:
            stopped = False

        if i % stream_interval == 0 or i == max_new_tokens - 1 or stopped:
            output = tokenizer.decode(output_ids, skip_special_tokens=True)
            pos = output.rfind(stop_str, l_prompt)
            if pos != -1:
                output = output[:pos]
                stopped = True
            yield output

        if stopped:
            break

    del past_key_values


args = dict(
    model_name='./vicuna-7b',
    device='cuda',
    num_gpus='1',
    load_8bit=True,
    conv_template='v1',
    temperature=0.7,
    max_new_tokens=512,
    debug=False
)

args = argparse.Namespace(**args)

model_name = args.model_name

# Model
model, tokenizer = load_model(args.model_name, args.device,
                              args.num_gpus, args.load_8bit)

# Chat
conv = conv_templates[args.conv_template].copy()


def chat(inp):
    conv.append_message(conv.roles[0], inp)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()

    params = {
        "model": model_name,
        "prompt": prompt,
        "temperature": args.temperature,
        "max_new_tokens": args.max_new_tokens,
        "stop": conv.sep if conv.sep_style == SeparatorStyle.SINGLE else conv.sep2,
    }

    print(f"{conv.roles[1]}: ", end="", flush=True)
    pre = 0
    for outputs in generate_stream(tokenizer, model, params, args.device):
        outputs = outputs[len(prompt) + 1:].strip()
        outputs = outputs.split(" ")
        now = len(outputs)
        if now - 1 > pre:
            print(" ".join(outputs[pre:now - 1]), end=" ", flush=True)
            pre = now - 1
    print(" ".join(outputs[pre:]), flush=True)

    conv.messages[-1][-1] = " ".join(outputs)
```

Then call `chat()` to talk to the model:

```python
chat("what's the meaning of life?")
```
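
Because `conv` keeps the conversation history across calls, repeated calls to `chat()` continue the same dialogue, for example:

```python
chat("Can you summarize that in a single sentence?")
```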