FantasticGNU commited on
Commit
04a232e
1 Parent(s): c37f381

Update model/openllama.py

Browse files
Files changed (1) hide show
  1. model/openllama.py +2 -0
model/openllama.py CHANGED
@@ -221,6 +221,8 @@ class OpenLLAMAPEFTModel(nn.Module):
221
  # finally:
222
  # print(self.llama_model.hf_device_map)
223
  self.llama_model = get_peft_model(self.llama_model, peft_config)
 
 
224
  self.llama_model.print_trainable_parameters()
225
 
226
  self.llama_tokenizer = LlamaTokenizer.from_pretrained(vicuna_ckpt_path, use_fast=False, torch_dtype=torch.bfloat16, device_map='auto', offload_folder="offload2", offload_state_dict = True)
 
221
  # finally:
222
  # print(self.llama_model.hf_device_map)
223
  self.llama_model = get_peft_model(self.llama_model, peft_config)
224
+ delta_ckpt = torch.load(args['delta_ckpt_path'])
225
+ self.llama_model.load_state_dict(delta_ckpt, strict=False)
226
  self.llama_model.print_trainable_parameters()
227
 
228
  self.llama_tokenizer = LlamaTokenizer.from_pretrained(vicuna_ckpt_path, use_fast=False, torch_dtype=torch.bfloat16, device_map='auto', offload_folder="offload2", offload_state_dict = True)