FantasticGNU commited on
Commit
ab9dae6
1 Parent(s): 4b9278e

Update model/openllama.py

Browse files
Files changed (1) hide show
  1. model/openllama.py +5 -4
model/openllama.py CHANGED
@@ -172,15 +172,16 @@ class OpenLLAMAPEFTModel(nn.Module):
172
  self.visual_encoder, self.visual_hidden_size = imagebind_model.imagebind_huge(args)
173
  imagebind_ckpt = torch.load(imagebind_ckpt_path, map_location=torch.device('cpu'))
174
  self.visual_encoder.load_state_dict(imagebind_ckpt, strict=True)
 
175
 
176
  self.iter = 0
177
 
178
- self.image_decoder = LinearLayer(1280, 1024, 4).to(self.device)
179
 
180
- self.prompt_learner = PromptLearner(1, 4096).to(self.device)
181
 
182
- self.loss_focal = FocalLoss().to(self.device)
183
- self.loss_dice = BinaryDiceLoss().to(self.device)
184
 
185
 
186
  # free vision encoder
 
172
  self.visual_encoder, self.visual_hidden_size = imagebind_model.imagebind_huge(args)
173
  imagebind_ckpt = torch.load(imagebind_ckpt_path, map_location=torch.device('cpu'))
174
  self.visual_encoder.load_state_dict(imagebind_ckpt, strict=True)
175
+ self.visual_encoder.to(self.device)
176
 
177
  self.iter = 0
178
 
179
+ self.image_decoder = LinearLayer(1280, 1024, 4)
180
 
181
+ self.prompt_learner = PromptLearner(1, 4096)
182
 
183
+ self.loss_focal = FocalLoss()
184
+ self.loss_dice = BinaryDiceLoss()
185
 
186
 
187
  # free vision encoder