FantasticGNU committed on
Commit
c2ca0ca
1 Parent(s): 22c2b9f

Update model/openllama.py

Browse files
Files changed (1) hide show
  1. model/openllama.py +4 -4
model/openllama.py CHANGED
@@ -41,7 +41,7 @@ for obj in objs:
41
  for s in prompted_state:
42
  for template in prompt_templates:
43
  prompted_sentence.append(template.format(s))
44
- prompted_sentence = data.load_and_transform_text(prompted_sentence, torch.device('cpu'))#torch.cuda.current_device())
45
  prompt_sentence_obj.append(prompted_sentence)
46
  prompt_sentences[obj] = prompt_sentence_obj
47
 
@@ -199,11 +199,11 @@ class OpenLLAMAPEFTModel(nn.Module):
199
  target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj']
200
  )
201
 
202
- self.llama_model = LlamaForCausalLM.from_pretrained(vicuna_ckpt_path, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16)
203
  self.llama_model = get_peft_model(self.llama_model, peft_config)
204
  self.llama_model.print_trainable_parameters()
205
 
206
- self.llama_tokenizer = LlamaTokenizer.from_pretrained(vicuna_ckpt_path, use_fast=False)
207
  self.llama_tokenizer.pad_token = self.llama_tokenizer.eos_token
208
  self.llama_tokenizer.padding_side = "right"
209
  print ('Language decoder initialized.')
@@ -213,7 +213,7 @@ class OpenLLAMAPEFTModel(nn.Module):
213
  )
214
 
215
  self.max_tgt_len = max_tgt_len
216
- self.device = torch.device('cpu')#torch.cuda.current_device()
217
 
218
 
219
  def rot90_img(self,x,k):
 
41
  for s in prompted_state:
42
  for template in prompt_templates:
43
  prompted_sentence.append(template.format(s))
44
+ prompted_sentence = data.load_and_transform_text(prompted_sentence, torch.cuda.current_device())#torch.cuda.current_device())
45
  prompt_sentence_obj.append(prompted_sentence)
46
  prompt_sentences[obj] = prompt_sentence_obj
47
 
 
199
  target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj']
200
  )
201
 
202
+ self.llama_model = LlamaForCausalLM.from_pretrained(vicuna_ckpt_path, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16, device_map='auto', offload_folder="offload", offload_state_dict = True)
203
  self.llama_model = get_peft_model(self.llama_model, peft_config)
204
  self.llama_model.print_trainable_parameters()
205
 
206
+ self.llama_tokenizer = LlamaTokenizer.from_pretrained(vicuna_ckpt_path, use_fast=False, torch_dtype=torch.bfloat16, device_map='auto', offload_folder="offload", offload_state_dict = True)
207
  self.llama_tokenizer.pad_token = self.llama_tokenizer.eos_token
208
  self.llama_tokenizer.padding_side = "right"
209
  print ('Language decoder initialized.')
 
213
  )
214
 
215
  self.max_tgt_len = max_tgt_len
216
+ self.device = torch.cuda.current_device()
217
 
218
 
219
  def rot90_img(self,x,k):