FantasticGNU commited on
Commit
4b9278e
1 Parent(s): c2ca0ca

Update model/openllama.py

Browse files
Files changed (1) hide show
  1. model/openllama.py +7 -5
model/openllama.py CHANGED
@@ -165,6 +165,8 @@ class OpenLLAMAPEFTModel(nn.Module):
165
  max_tgt_len = args['max_tgt_len']
166
  stage = args['stage']
167
 
 
 
168
  print (f'Initializing visual encoder from {imagebind_ckpt_path} ...')
169
 
170
  self.visual_encoder, self.visual_hidden_size = imagebind_model.imagebind_huge(args)
@@ -173,12 +175,12 @@ class OpenLLAMAPEFTModel(nn.Module):
173
 
174
  self.iter = 0
175
 
176
- self.image_decoder = LinearLayer(1280, 1024, 4)
177
 
178
- self.prompt_learner = PromptLearner(1, 4096)
179
 
180
- self.loss_focal = FocalLoss()
181
- self.loss_dice = BinaryDiceLoss()
182
 
183
 
184
  # free vision encoder
@@ -213,7 +215,7 @@ class OpenLLAMAPEFTModel(nn.Module):
213
  )
214
 
215
  self.max_tgt_len = max_tgt_len
216
- self.device = torch.cuda.current_device()
217
 
218
 
219
  def rot90_img(self,x,k):
 
165
  max_tgt_len = args['max_tgt_len']
166
  stage = args['stage']
167
 
168
+ self.device = torch.cuda.current_device()
169
+
170
  print (f'Initializing visual encoder from {imagebind_ckpt_path} ...')
171
 
172
  self.visual_encoder, self.visual_hidden_size = imagebind_model.imagebind_huge(args)
 
175
 
176
  self.iter = 0
177
 
178
+ self.image_decoder = LinearLayer(1280, 1024, 4).to(self.device)
179
 
180
+ self.prompt_learner = PromptLearner(1, 4096).to(self.device)
181
 
182
+ self.loss_focal = FocalLoss().to(self.device)
183
+ self.loss_dice = BinaryDiceLoss().to(self.device)
184
 
185
 
186
  # free vision encoder
 
215
  )
216
 
217
  self.max_tgt_len = max_tgt_len
218
+
219
 
220
 
221
  def rot90_img(self,x,k):