Spaces:
Sleeping
Sleeping
FantasticGNU
commited on
Commit
•
4b9278e
1
Parent(s):
c2ca0ca
Update model/openllama.py
Browse files- model/openllama.py +7 -5
model/openllama.py
CHANGED
@@ -165,6 +165,8 @@ class OpenLLAMAPEFTModel(nn.Module):
|
|
165 |
max_tgt_len = args['max_tgt_len']
|
166 |
stage = args['stage']
|
167 |
|
|
|
|
|
168 |
print (f'Initializing visual encoder from {imagebind_ckpt_path} ...')
|
169 |
|
170 |
self.visual_encoder, self.visual_hidden_size = imagebind_model.imagebind_huge(args)
|
@@ -173,12 +175,12 @@ class OpenLLAMAPEFTModel(nn.Module):
|
|
173 |
|
174 |
self.iter = 0
|
175 |
|
176 |
-
self.image_decoder = LinearLayer(1280, 1024, 4)
|
177 |
|
178 |
-
self.prompt_learner = PromptLearner(1, 4096)
|
179 |
|
180 |
-
self.loss_focal = FocalLoss()
|
181 |
-
self.loss_dice = BinaryDiceLoss()
|
182 |
|
183 |
|
184 |
# free vision encoder
|
@@ -213,7 +215,7 @@ class OpenLLAMAPEFTModel(nn.Module):
|
|
213 |
)
|
214 |
|
215 |
self.max_tgt_len = max_tgt_len
|
216 |
-
|
217 |
|
218 |
|
219 |
def rot90_img(self,x,k):
|
|
|
165 |
max_tgt_len = args['max_tgt_len']
|
166 |
stage = args['stage']
|
167 |
|
168 |
+
self.device = torch.cuda.current_device()
|
169 |
+
|
170 |
print (f'Initializing visual encoder from {imagebind_ckpt_path} ...')
|
171 |
|
172 |
self.visual_encoder, self.visual_hidden_size = imagebind_model.imagebind_huge(args)
|
|
|
175 |
|
176 |
self.iter = 0
|
177 |
|
178 |
+
self.image_decoder = LinearLayer(1280, 1024, 4).to(self.device)
|
179 |
|
180 |
+
self.prompt_learner = PromptLearner(1, 4096).to(self.device)
|
181 |
|
182 |
+
self.loss_focal = FocalLoss().to(self.device)
|
183 |
+
self.loss_dice = BinaryDiceLoss().to(self.device)
|
184 |
|
185 |
|
186 |
# free vision encoder
|
|
|
215 |
)
|
216 |
|
217 |
self.max_tgt_len = max_tgt_len
|
218 |
+
|
219 |
|
220 |
|
221 |
def rot90_img(self,x,k):
|