Spaces: Running on T4
FantasticGNU committed
Commit • c2ca0ca
1 Parent(s): 22c2b9f
Update model/openllama.py
model/openllama.py  CHANGED  (+4 -4)

@@ -41,7 +41,7 @@ for obj in objs:
         for s in prompted_state:
             for template in prompt_templates:
                 prompted_sentence.append(template.format(s))
-        prompted_sentence = data.load_and_transform_text(prompted_sentence, torch.
+        prompted_sentence = data.load_and_transform_text(prompted_sentence, torch.cuda.current_device())#torch.cuda.current_device())
         prompt_sentence_obj.append(prompted_sentence)
     prompt_sentences[obj] = prompt_sentence_obj

@@ -199,11 +199,11 @@ class OpenLLAMAPEFTModel(nn.Module):
             target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj']
         )

-        self.llama_model = LlamaForCausalLM.from_pretrained(vicuna_ckpt_path, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16)
+        self.llama_model = LlamaForCausalLM.from_pretrained(vicuna_ckpt_path, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16, device_map='auto', offload_folder="offload", offload_state_dict = True)
         self.llama_model = get_peft_model(self.llama_model, peft_config)
         self.llama_model.print_trainable_parameters()

-        self.llama_tokenizer = LlamaTokenizer.from_pretrained(vicuna_ckpt_path, use_fast=False)
+        self.llama_tokenizer = LlamaTokenizer.from_pretrained(vicuna_ckpt_path, use_fast=False, torch_dtype=torch.bfloat16, device_map='auto', offload_folder="offload", offload_state_dict = True)
         self.llama_tokenizer.pad_token = self.llama_tokenizer.eos_token
         self.llama_tokenizer.padding_side = "right"
         print ('Language decoder initialized.')

@@ -213,7 +213,7 @@ class OpenLLAMAPEFTModel(nn.Module):
         )

         self.max_tgt_len = max_tgt_len
-        self.device = torch.
+        self.device = torch.cuda.current_device()


    def rot90_img(self,x,k):
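
For readers unfamiliar with the flags added in this commit: device_map='auto', offload_folder, and offload_state_dict are Accelerate-backed options of from_pretrained that shard the bfloat16 weights across the available devices and spill whatever does not fit onto CPU/disk, which is what lets the checkpoint load on a single T4. The sketch below is a minimal, stand-alone illustration of that loading pattern, not the Space's actual code: vicuna_ckpt_path and the prompt string are placeholders, and the "offload" folder name simply mirrors the committed line. The same device/offload arguments are unnecessary on the tokenizer, since it holds no weights.

import torch
from transformers import LlamaForCausalLM, LlamaTokenizer

vicuna_ckpt_path = "path/to/vicuna-7b"  # placeholder checkpoint directory

# Shard the bfloat16 weights across available GPUs and offload the remainder
# to CPU/disk under ./offload (requires the `accelerate` package).
model = LlamaForCausalLM.from_pretrained(
    vicuna_ckpt_path,
    low_cpu_mem_usage=True,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    offload_folder="offload",
    offload_state_dict=True,
)

# The tokenizer has no tensors to place, so no device/offload options are needed.
tokenizer = LlamaTokenizer.from_pretrained(vicuna_ckpt_path, use_fast=False)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

# Inputs are still created on a concrete device, mirroring the commit's use of
# torch.cuda.current_device() for self.device.
device = torch.cuda.current_device() if torch.cuda.is_available() else "cpu"
inputs = tokenizer("hello", return_tensors="pt").to(device)
output_ids = model.generate(**inputs, max_new_tokens=8)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))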