Text Generation
Transformers
Safetensors
openelm
custom_code
mahyar-najibi committed on
Commit c853cdb
1 Parent(s): f93c421

Add the generate module.

Files changed (1)
  1. generate_openelm.py +235 -0
generate_openelm.py ADDED
@@ -0,0 +1,235 @@
+"""Module to generate OpenELM output given a model and an input prompt."""
+import os
+import logging
+import time
+import argparse
+from typing import Optional, Tuple, Union
+import torch
+
+from transformers import AutoTokenizer, AutoModelForCausalLM
+
+
+def generate(
+    prompt: str,
+    model: Union[str, AutoModelForCausalLM],
+    hf_access_token: Optional[str] = None,
+    tokenizer: Union[str, AutoTokenizer] = 'meta-llama/Llama-2-7b-hf',
+    device: Optional[str] = None,
+    max_length: int = 1024,
+    assistant_model: Optional[Union[str, AutoModelForCausalLM]] = None,
+    generate_kwargs: Optional[dict] = None,
+) -> Tuple[str, float]:
+    """Generates output given a prompt.
+
+    Args:
+        prompt: The string prompt.
+        model: The LLM model. If a string is passed, it should be the path to
+            the hf converted checkpoint.
+        hf_access_token: Hugging Face access token.
+        tokenizer: Tokenizer instance. If model is set as a string path,
+            the tokenizer will be loaded from the checkpoint.
+        device: String representation of the device to run the model on. If
+            None and CUDA is available, it is set to cuda:0, else cpu.
+        max_length: Maximum length of tokens, input prompt + generated tokens.
+        assistant_model: If set, this model will be used for
+            speculative generation. If a string is passed, it should be the
+            path to the hf converted checkpoint.
+        generate_kwargs: Extra kwargs passed to the hf generate function.
+
+    Returns:
+        output_text: Output generated as a string.
+        generation_time: Generation time in seconds.
+
+    Raises:
+        ValueError: If device is set to CUDA but no CUDA device is detected.
+        ValueError: If tokenizer is not set.
+        ValueError: If hf_access_token is not specified.
+    """
+    if not device:
+        if torch.cuda.is_available() and torch.cuda.device_count():
+            device = "cuda:0"
+            logging.warning(
+                'inference device is not set, using cuda:0, %s',
+                torch.cuda.get_device_name(0)
+            )
+        else:
+            device = 'cpu'
+            logging.warning(
+                (
+                    'No CUDA device detected, using cpu, '
+                    'expect slower speeds.'
+                )
+            )
+
+    if 'cuda' in device and not torch.cuda.is_available():
+        raise ValueError('CUDA device requested but no CUDA device detected.')
+
+    if not tokenizer:
+        raise ValueError('Tokenizer is not set in the generate function.')
+
+    if not hf_access_token:
+        raise ValueError((
+            'Hugging Face access token needs to be specified. '
+            'Please refer to https://huggingface.co/docs/hub/security-tokens'
+            ' to obtain one.'
+            )
+        )
+
+    if isinstance(model, str):
+        checkpoint_path = model
+        model = AutoModelForCausalLM.from_pretrained(
+            checkpoint_path,
+            trust_remote_code=True
+        )
+    model.to(device).eval()
+    if isinstance(tokenizer, str):
+        tokenizer = AutoTokenizer.from_pretrained(
+            tokenizer,
+            token=hf_access_token,
+        )
+
+    # Speculative mode
+    draft_model = None
+    if assistant_model:
+        draft_model = assistant_model
+        if isinstance(assistant_model, str):
+            draft_model = AutoModelForCausalLM.from_pretrained(
+                assistant_model,
+                trust_remote_code=True
+            )
+        draft_model.to(device).eval()
+
+    # Prepare the prompt
+    tokenized_prompt = tokenizer(prompt)
+    tokenized_prompt = torch.tensor(
+        tokenized_prompt['input_ids'],
+        device=device
+    )
+
+    tokenized_prompt = tokenized_prompt.unsqueeze(0)
+
+    # Generate
+    stime = time.time()
+    output_ids = model.generate(
+        tokenized_prompt,
+        max_length=max_length,
+        pad_token_id=0,
+        assistant_model=draft_model,
+        **(generate_kwargs if generate_kwargs else {}),
+    )
+    generation_time = time.time() - stime
+
+    output_text = tokenizer.decode(
+        output_ids[0].tolist(),
+        skip_special_tokens=True
+    )
+
+    return output_text, generation_time
+
+
+def openelm_generate_parser():
+    """Argument Parser"""
+
+    class KwargsParser(argparse.Action):
+        """Parser action class to parse kwargs of form key=value"""
+        def __call__(self, parser, namespace, values, option_string=None):
+            setattr(namespace, self.dest, dict())
+            for val in values:
+                if '=' not in val:
+                    raise ValueError(
+                        (
+                            'Argument parsing error, kwargs are expected in'
+                            ' the form of key=value.'
+                        )
+                    )
+                kwarg_k, kwarg_v = val.split('=', 1)  # split on the first '=' only
+                try:
+                    converted_v = int(kwarg_v)
+                except ValueError:
+                    try:
+                        converted_v = float(kwarg_v)
+                    except ValueError:
+                        converted_v = kwarg_v
+                getattr(namespace, self.dest)[kwarg_k] = converted_v
+
+    parser = argparse.ArgumentParser('OpenELM Generate Module')
+    parser.add_argument(
+        '--model',
+        dest='model',
+        help='Path to the hf converted model.',
+        required=True,
+        type=str,
+    )
+    parser.add_argument(
+        '--hf_access_token',
+        dest='hf_access_token',
+        help='Hugging Face access token, starting with "hf_".',
+        type=str,
+    )
+    parser.add_argument(
+        '--prompt',
+        dest='prompt',
+        help='Prompt for LLM call.',
+        default='',
+        type=str,
+    )
+    parser.add_argument(
+        '--device',
+        dest='device',
+        help='Device used for inference.',
+        type=str,
+    )
+    parser.add_argument(
+        '--max_length',
+        dest='max_length',
+        help='Maximum length of tokens.',
+        default=256,
+        type=int,
+    )
+    parser.add_argument(
+        '--assistant_model',
+        dest='assistant_model',
+        help=(
+            (
+                'If set, this is used as a draft model '
+                'for assisted speculative generation.'
+            )
+        ),
+        type=str,
+    )
+    parser.add_argument(
+        '--generate_kwargs',
+        dest='generate_kwargs',
+        help='Additional kwargs passed to the HF generate function.',
+        type=str,
+        nargs='*',
+        action=KwargsParser,
+    )
+    return parser.parse_args()
+
+
+if __name__ == '__main__':
+    args = openelm_generate_parser()
+    prompt = args.prompt
+
+    output_text, generation_time = generate(
+        prompt=prompt,
+        model=args.model,
+        device=args.device,
+        max_length=args.max_length,
+        assistant_model=args.assistant_model,
+        generate_kwargs=args.generate_kwargs,
+        hf_access_token=args.hf_access_token,
+    )
+
+    print_txt = (
+        f'\r\n{"=" * os.get_terminal_size().columns}\r\n'
+        '\033[1m Prompt + Generated Output\033[0m\r\n'
+        f'{"-" * os.get_terminal_size().columns}\r\n'
+        f'{output_text}\r\n'
+        f'{"-" * os.get_terminal_size().columns}\r\n'
+        '\r\nGeneration took'
+        f'\033[1m\033[92m {round(generation_time, 2)} \033[0m'
+        'seconds.\r\n'
+    )
+    print(print_txt)
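
For reference, a minimal usage sketch (not part of this commit) of calling the generate function above from Python. The model id apple/OpenELM-270M, the hf_... token string, and the generate_kwargs values are illustrative assumptions only; the default meta-llama/Llama-2-7b-hf tokenizer is gated, so the supplied token must have access to it.

from generate_openelm import generate

# Illustrative placeholders; substitute a real checkpoint path/id and token.
MODEL = 'apple/OpenELM-270M'   # assumed model id, for illustration only
HF_TOKEN = 'hf_...'            # your Hugging Face access token

output_text, generation_time = generate(
    prompt='Once upon a time there was',
    model=MODEL,
    hf_access_token=HF_TOKEN,
    max_length=256,
    generate_kwargs={'repetition_penalty': 1.2},
)
print(output_text)
print(f'Generated in {generation_time:.2f} seconds.')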
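The same call is available from the command line through the argument parser; below is a hedged example invocation (placeholder model id and token) showing the key=value form that KwargsParser expects for --generate_kwargs:

python generate_openelm.py \
    --model apple/OpenELM-270M \
    --hf_access_token hf_... \
    --prompt 'Once upon a time there was' \
    --max_length 256 \
    --generate_kwargs repetition_penalty=1.2

Each key=value pair is parsed by KwargsParser, which tries int() first, then float(), and otherwise keeps the value as a string, so repetition_penalty reaches model.generate as the float 1.2.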