FantasticGNU committed on
Commit 79ec1bc
Parent: 4223fbb

Update model/openllama.py

Files changed (1): model/openllama.py +6 -1
model/openllama.py CHANGED
@@ -209,7 +209,12 @@ class OpenLLAMAPEFTModel(nn.Module):
         # with init_empty_weights():
         # self.llama_model = AutoModelForCausalLM.from_config(config)
         # self.llama_model = load_checkpoint_and_dispatch(self.llama_model, vicuna_ckpt_path, device_map="auto", no_split_module_classes=["OPTDecoderLayer"], offload_folder="offload", offload_state_dict = True)
-        self.llama_model = AutoModelForCausalLM.from_pretrained(vicuna_ckpt_path, torch_dtype=torch.float16, device_map='auto', offload_folder="offload", offload_state_dict = True)
+        try:
+            self.llama_model = AutoModelForCausalLM.from_pretrained(vicuna_ckpt_path, torch_dtype=torch.float16, device_map='auto', offload_folder="offload", offload_state_dict = True)
+        except:
+            pass
+        finally:
+            print(self.llama_model.hf_device_map)
         self.llama_model = get_peft_model(self.llama_model, peft_config)
         self.llama_model.print_trainable_parameters()
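
The commit wraps the fp16 Vicuna load in a try/except/finally so that the device placement chosen by device_map='auto' (the model's hf_device_map) is printed before the LoRA adapters are attached. Below is a minimal, self-contained sketch of that loading path; the checkpoint path, LoRA hyperparameters, and target modules are illustrative assumptions, not values taken from the repository.

import torch
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

vicuna_ckpt_path = "./vicuna-7b"  # hypothetical local checkpoint directory

# Illustrative LoRA settings; the repository defines its own peft_config.
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],
)

# Load the fp16 checkpoint, letting accelerate shard it across available
# devices and offload anything that does not fit to the "offload" folder.
# (The commit wraps this call in try/except/finally; here we simply load.)
llama_model = AutoModelForCausalLM.from_pretrained(
    vicuna_ckpt_path,
    torch_dtype=torch.float16,
    device_map="auto",
    offload_folder="offload",
    offload_state_dict=True,
)

# hf_device_map records which device (GPU index, "cpu", or "disk") each
# module was placed on; the commit prints it right after loading.
print(llama_model.hf_device_map)

# Attach LoRA adapters and report how many parameters remain trainable.
llama_model = get_peft_model(llama_model, peft_config)
llama_model.print_trainable_parameters()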