FantasticGNU committed on
Commit
96a4737
1 Parent(s): f2de29b

Update model/openllama.py

Browse files
Files changed (1) hide show
  1. model/openllama.py +9 -9
model/openllama.py CHANGED
@@ -172,16 +172,16 @@ class OpenLLAMAPEFTModel(nn.Module):
172
  print (f'Initializing visual encoder from {imagebind_ckpt_path} ...')
173
 
174
  self.visual_encoder, self.visual_hidden_size = imagebind_model.imagebind_huge(args)
175
- self.visual_encoder.to(torch.bfloat16).to(self.device)
176
  imagebind_ckpt = torch.load(imagebind_ckpt_path, map_location=torch.device('cpu'))
177
  self.visual_encoder.load_state_dict(imagebind_ckpt, strict=True)
178
 
179
 
180
  self.iter = 0
181
 
182
- self.image_decoder = LinearLayer(1280, 1024, 4).to(torch.bfloat16).to(self.device)
183
 
184
- self.prompt_learner = PromptLearner(1, 4096).to(torch.bfloat16).to(self.device)
185
 
186
  self.loss_focal = FocalLoss()
187
  self.loss_dice = BinaryDiceLoss()
@@ -215,7 +215,7 @@ class OpenLLAMAPEFTModel(nn.Module):
215
  # # self.llama_model = load_checkpoint_and_dispatch(self.llama_model, vicuna_ckpt_path, device_map=device_map, offload_folder="offload", offload_state_dict = True)
216
  # # self.llama_model.to(torch.float16)
217
  # # try:
218
- self.llama_model = AutoModelForCausalLM.from_pretrained(vicuna_ckpt_path, torch_dtype=torch.bfloat16, device_map='auto', load_in_8bit=True, offload_folder="offload1")
219
  # # except:
220
  # pass
221
  # finally:
@@ -225,7 +225,7 @@ class OpenLLAMAPEFTModel(nn.Module):
225
  self.llama_model.load_state_dict(delta_ckpt, strict=False)
226
  self.llama_model.print_trainable_parameters()
227
 
228
- self.llama_tokenizer = LlamaTokenizer.from_pretrained(vicuna_ckpt_path, use_fast=False, torch_dtype=torch.bfloat16, device_map='auto', offload_folder="offload2")
229
  self.llama_tokenizer.pad_token = self.llama_tokenizer.eos_token
230
  self.llama_tokenizer.padding_side = "right"
231
  print ('Language decoder initialized.')
@@ -634,10 +634,10 @@ class OpenLLAMAPEFTModel(nn.Module):
634
  anomaly_map = (100.0 * patch_tokens[layer] @ feats_text_tensor.transpose(-2,-1))
635
  B, L, C = anomaly_map.shape
636
  H = int(np.sqrt(L))
637
- anomaly_map = anomaly_map.to(torch.float16)
638
  anomaly_map = F.interpolate(anomaly_map.permute(0, 2, 1).view(B, 2, H, H),
639
  size=224, mode='bilinear', align_corners=True)
640
- anomaly_map = anomaly_map.to(torch.bfloat16)
641
  anomaly_map = torch.softmax(anomaly_map, dim=1)
642
  anomaly_maps.append(anomaly_map[:,1,:,:])
643
 
@@ -661,9 +661,9 @@ class OpenLLAMAPEFTModel(nn.Module):
661
  sims.append(sim_max)
662
 
663
  sim = torch.mean(torch.stack(sims,dim=0), dim=0).reshape(1,1,16,16)
664
- anomaly_map = anomaly_map.to(torch.float16)
665
  sim = F.interpolate(sim,size=224, mode='bilinear', align_corners=True)
666
- anomaly_map = anomaly_map.to(torch.bfloat16)
667
  anomaly_map_ret = 1 - sim # (anomaly_map_ret + 1 - sim) / 2
668
 
669
 
 
172
  print (f'Initializing visual encoder from {imagebind_ckpt_path} ...')
173
 
174
  self.visual_encoder, self.visual_hidden_size = imagebind_model.imagebind_huge(args)
175
+ self.visual_encoder.to(torch.float16).to(self.device)
176
  imagebind_ckpt = torch.load(imagebind_ckpt_path, map_location=torch.device('cpu'))
177
  self.visual_encoder.load_state_dict(imagebind_ckpt, strict=True)
178
 
179
 
180
  self.iter = 0
181
 
182
+ self.image_decoder = LinearLayer(1280, 1024, 4).to(torch.float16).to(self.device)
183
 
184
+ self.prompt_learner = PromptLearner(1, 4096).to(torch.float16).to(self.device)
185
 
186
  self.loss_focal = FocalLoss()
187
  self.loss_dice = BinaryDiceLoss()
 
215
  # # self.llama_model = load_checkpoint_and_dispatch(self.llama_model, vicuna_ckpt_path, device_map=device_map, offload_folder="offload", offload_state_dict = True)
216
  # # self.llama_model.to(torch.float16)
217
  # # try:
218
+ self.llama_model = AutoModelForCausalLM.from_pretrained(vicuna_ckpt_path, torch_dtype=torch.float16, device_map='auto', load_in_8bit=True, offload_folder="offload1")
219
  # # except:
220
  # pass
221
  # finally:
 
225
  self.llama_model.load_state_dict(delta_ckpt, strict=False)
226
  self.llama_model.print_trainable_parameters()
227
 
228
+ self.llama_tokenizer = LlamaTokenizer.from_pretrained(vicuna_ckpt_path, use_fast=False, torch_dtype=torch.float16, device_map='auto', offload_folder="offload2")
229
  self.llama_tokenizer.pad_token = self.llama_tokenizer.eos_token
230
  self.llama_tokenizer.padding_side = "right"
231
  print ('Language decoder initialized.')
 
634
  anomaly_map = (100.0 * patch_tokens[layer] @ feats_text_tensor.transpose(-2,-1))
635
  B, L, C = anomaly_map.shape
636
  H = int(np.sqrt(L))
637
+ # anomaly_map = anomaly_map.to(torch.float16)
638
  anomaly_map = F.interpolate(anomaly_map.permute(0, 2, 1).view(B, 2, H, H),
639
  size=224, mode='bilinear', align_corners=True)
640
+ # anomaly_map = anomaly_map.to(torch.bfloat16)
641
  anomaly_map = torch.softmax(anomaly_map, dim=1)
642
  anomaly_maps.append(anomaly_map[:,1,:,:])
643
 
 
661
  sims.append(sim_max)
662
 
663
  sim = torch.mean(torch.stack(sims,dim=0), dim=0).reshape(1,1,16,16)
664
+ # anomaly_map = anomaly_map.to(torch.float16)
665
  sim = F.interpolate(sim,size=224, mode='bilinear', align_corners=True)
666
+ # anomaly_map = anomaly_map.to(torch.bfloat16)
667
  anomaly_map_ret = 1 - sim # (anomaly_map_ret + 1 - sim) / 2
668
 
669