FantasticGNU committed
Commit 0d05e34
1 Parent(s): 93457f4

Update model/openllama.py

Files changed (1)
  1. model/openllama.py +7 -6
model/openllama.py CHANGED

@@ -170,15 +170,16 @@ class OpenLLAMAPEFTModel(nn.Module):
         print (f'Initializing visual encoder from {imagebind_ckpt_path} ...')
 
         self.visual_encoder, self.visual_hidden_size = imagebind_model.imagebind_huge(args)
-        imagebind_ckpt = torch.load(imagebind_ckpt_path, map_location=torch.device('cpu'))
-        self.visual_encoder.load_state_dict(imagebind_ckpt, strict=True)
         self.visual_encoder.to(self.device)
+        imagebind_ckpt = torch.load(imagebind_ckpt_path, map_location=self.device)
+        self.visual_encoder.load_state_dict(imagebind_ckpt, strict=True)
+
 
         self.iter = 0
 
-        self.image_decoder = LinearLayer(1280, 1024, 4)
+        self.image_decoder = LinearLayer(1280, 1024, 4).to(self.device)
 
-        self.prompt_learner = PromptLearner(1, 4096)
+        self.prompt_learner = PromptLearner(1, 4096).to(self.device)
 
         self.loss_focal = FocalLoss()
         self.loss_dice = BinaryDiceLoss()
@@ -202,11 +203,11 @@ class OpenLLAMAPEFTModel(nn.Module):
             target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj']
         )
 
-        self.llama_model = LlamaForCausalLM.from_pretrained(vicuna_ckpt_path, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16, device_map='auto', offload_folder="offload", offload_state_dict = True)
+        self.llama_model = LlamaForCausalLM.from_pretrained(vicuna_ckpt_path, torch_dtype=torch.float16, device_map='auto', offload_folder="offload", offload_state_dict = True)
         self.llama_model = get_peft_model(self.llama_model, peft_config)
         self.llama_model.print_trainable_parameters()
 
-        self.llama_tokenizer = LlamaTokenizer.from_pretrained(vicuna_ckpt_path, use_fast=False, torch_dtype=torch.bfloat16, device_map='auto', offload_folder="offload", offload_state_dict = True)
+        self.llama_tokenizer = LlamaTokenizer.from_pretrained(vicuna_ckpt_path, use_fast=False, torch_dtype=torch.float16, device_map='auto', offload_folder="offload", offload_state_dict = True)
         self.llama_tokenizer.pad_token = self.llama_tokenizer.eos_token
         self.llama_tokenizer.padding_side = "right"
         print ('Language decoder initialized.')
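
Note on the first hunk: the commit moves the ImageBind checkpoint load to after self.visual_encoder.to(self.device) and maps the tensors straight onto that device rather than onto the CPU. Below is a minimal, self-contained sketch of that pattern; the toy encoder and the encoder_ckpt.pt path are hypothetical stand-ins, not part of this repository.

import torch
import torch.nn as nn

# Hypothetical stand-in for the ImageBind visual encoder.
encoder = nn.Sequential(nn.Linear(1280, 1024), nn.GELU(), nn.Linear(1024, 1024))
torch.save(encoder.state_dict(), 'encoder_ckpt.pt')  # hypothetical checkpoint file

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Pattern used by the updated code: move the module to the target device first,
# then map the checkpoint tensors directly onto that device instead of staging
# them on the CPU.
encoder.to(device)
state_dict = torch.load('encoder_ckpt.pt', map_location=device)
encoder.load_state_dict(state_dict, strict=True)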
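
Note on the second hunk: the Vicuna weights are now loaded in float16 with device_map='auto' instead of bfloat16 with low_cpu_mem_usage=True. The sketch below shows that loading style under stated assumptions: 'gpt2' is a small public checkpoint standing in for vicuna_ckpt_path, and accelerate is assumed to be installed for device_map and offload support.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = 'gpt2'  # hypothetical stand-in for vicuna_ckpt_path

# Half-precision weights, automatically placed across available devices,
# with disk offload as a fallback, mirroring the updated from_pretrained call.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map='auto',
    offload_folder='offload',
)

tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = 'right'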