FantasticGNU committed
Commit
33e8867
1 Parent(s): 46cc5f0

Update model/openllama.py

Files changed (1)
  1. model/openllama.py +4 -4
model/openllama.py CHANGED
@@ -215,17 +215,17 @@ class OpenLLAMAPEFTModel(nn.Module):
         # # self.llama_model = load_checkpoint_and_dispatch(self.llama_model, vicuna_ckpt_path, device_map=device_map, offload_folder="offload", offload_state_dict = True)
         # # self.llama_model.to(torch.float16)
         # # try:
-        self.llama_model = AutoModelForCausalLM.from_pretrained(vicuna_ckpt_path, torch_dtype=torch.float16, device_map='auto', load_in_8bit=True, offload_folder="offload1")
+        self.llama_model = AutoModelForCausalLM.from_pretrained(vicuna_ckpt_path, torch_dtype=torch.float16, device_map='auto', load_in_8bit=True)
         # # except:
         #     pass
         # finally:
         #     print(self.llama_model.hf_device_map)
         self.llama_model = get_peft_model(self.llama_model, peft_config)
-        delta_ckpt = torch.load(args['delta_ckpt_path'], map_location=torch.device('cpu'))
-        self.llama_model.load_state_dict(delta_ckpt, strict=False)
+        # delta_ckpt = torch.load(args['delta_ckpt_path'], map_location=torch.device('cpu'))
+        # self.llama_model.load_state_dict(delta_ckpt, strict=False)
         self.llama_model.print_trainable_parameters()
 
-        self.llama_tokenizer = LlamaTokenizer.from_pretrained(vicuna_ckpt_path, use_fast=False, torch_dtype=torch.float16, device_map='auto', offload_folder="offload2")
+        self.llama_tokenizer = LlamaTokenizer.from_pretrained(vicuna_ckpt_path, use_fast=False, torch_dtype=torch.float16)
         self.llama_tokenizer.pad_token = self.llama_tokenizer.eos_token
         self.llama_tokenizer.padding_side = "right"
         print ('Language decoder initialized.')
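
For reference, a minimal sketch of the initialization path after this commit: the base Vicuna/LLaMA model is loaded in 8-bit with automatic device placement (no offload folder), wrapped with PEFT/LoRA adapters, and the delta-checkpoint loading is skipped. The LoraConfig hyperparameters and checkpoint path below are placeholders, not values taken from this repository.

import torch
from transformers import AutoModelForCausalLM, LlamaTokenizer
from peft import LoraConfig, TaskType, get_peft_model

vicuna_ckpt_path = "./pretrained_ckpt/vicuna_ckpt/7b_v0"  # placeholder path
peft_config = LoraConfig(  # placeholder hyperparameters
    task_type=TaskType.CAUSAL_LM, r=32, lora_alpha=32, lora_dropout=0.1
)

# Base model: 8-bit weights, fp16 compute dtype, sharded automatically across devices.
llama_model = AutoModelForCausalLM.from_pretrained(
    vicuna_ckpt_path,
    torch_dtype=torch.float16,
    device_map="auto",
    load_in_8bit=True,
)

# Wrap with LoRA adapters; only the adapter weights are trainable.
llama_model = get_peft_model(llama_model, peft_config)
llama_model.print_trainable_parameters()

# Tokenizer setup mirrors the diff: slow tokenizer, EOS reused as the pad token, right padding.
llama_tokenizer = LlamaTokenizer.from_pretrained(vicuna_ckpt_path, use_fast=False)
llama_tokenizer.pad_token = llama_tokenizer.eos_token
llama_tokenizer.padding_side = "right"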