Update visual.py
visual.py CHANGED
@@ -25,13 +25,11 @@ def sliding_window(matrix, window_size, stride):
     window_cols = (width - window_size[1]) // stride + 1
     images_448 = F.interpolate(matrix, size=window_size, mode='bicubic')
     windows = []
-    # pdb.set_trace()
     for i in range(window_rows):
         windows_col = []
         for j in range(window_cols):
             window = matrix[:,:, i*stride:i*stride+window_size[0], j*stride:j*stride+window_size[1]]
             windows.append(window)
-        # windows.append(windows_col)
     windows.append(images_448)
     images = torch.cat(windows,dim=1)
     images = images.reshape(b*5,c,window_size[0], window_size[0])
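
For context, this hunk only strips a leftover pdb breakpoint and a stale comment; the cropping logic is unchanged. A minimal, self-contained sketch of that logic, assuming b, c, height, width are unpacked from matrix.shape earlier in the function (as the final reshape implies):

import torch
import torch.nn.functional as F

def sliding_window_sketch(matrix, window_size=(448, 448), stride=448):
    # Assumed setup: the real function derives these from matrix.shape.
    b, c, height, width = matrix.shape
    window_rows = (height - window_size[0]) // stride + 1
    window_cols = (width - window_size[1]) // stride + 1
    # Global view: the full image resized down to a single window.
    images_448 = F.interpolate(matrix, size=window_size, mode='bicubic')
    windows = []
    for i in range(window_rows):
        for j in range(window_cols):
            windows.append(matrix[:, :, i*stride:i*stride+window_size[0],
                                        j*stride:j*stride+window_size[1]])
    windows.append(images_448)                     # 4 local crops + 1 global view
    images = torch.cat(windows, dim=1)             # (b, 5*c, 448, 448) for 896x896 input
    return images.reshape(b * 5, c, *window_size)  # fold the 5 views into the batch

print(sliding_window_sketch(torch.randn(2, 3, 896, 896)).shape)  # [10, 3, 448, 448]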
@@ -145,12 +143,9 @@ class Resampler(nn.Module):
         self.ln_kv = norm_layer(embed_dim)

         self.apply(self._init_weights)
-
-        #self.load_state_dict(torch.load('/cfs/cfs-lugcocyb/mingdali/code/qWen-VL/vl-chat/attn_params.pth'))
+

     def _init_weights(self, m):
-        # self.load_state_dict(torch.load('/cfs/cfs-lugcocyb/mingdali/code/qWen-VL/vl-chat/attn_params.pth'))
-        #pdb.set_trace()
         if isinstance(m, nn.Linear):
             trunc_normal_(m.weight, std=.02)
             if isinstance(m, nn.Linear) and m.bias is not None:
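
This hunk drops hard-coded checkpoint loads from the constructor and the per-module init hook. What survives is the usual Module.apply recursion; a minimal sketch of that pattern, assuming trunc_normal_ behaves like torch.nn.init.trunc_normal_ (the LayerNorm branch mirrors the nn.init.constant_(m.weight, 1.0) line visible in the next hunk):

import torch.nn as nn
from torch.nn.init import trunc_normal_  # assumed equivalent of the import in visual.py

class Tiny(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 8)
        self.ln = nn.LayerNorm(8)
        self.apply(self._init_weights)  # recursively visits self, fc, and ln

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)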
@@ -160,7 +155,6 @@ class Resampler(nn.Module):
             nn.init.constant_(m.weight, 1.0)

     def forward(self, x, attn_mask=None):
-        #pdb.set_trace()
         pos_embed = get_abs_pos(self.pos_embed, x.size(1))

         x = self.kv_proj(x)
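
Only a breakpoint is removed here. For orientation, a hypothetical toy version of the resampler's shape flow; every size is assumed, and the positional embedding from get_abs_pos is omitted, so this is a sketch of the query/key-value layout rather than the module itself:

import torch
import torch.nn as nn

class TinyResampler(nn.Module):
    # Hypothetical sizes; the learned-query cross-attention layout is the point.
    def __init__(self, n_queries=256, embed_dim=4096, kv_dim=1664, num_heads=32):
        super().__init__()
        self.query = nn.Parameter(torch.zeros(n_queries, embed_dim))
        self.kv_proj = nn.Linear(kv_dim, embed_dim, bias=False)
        self.ln_q = nn.LayerNorm(embed_dim)
        self.ln_kv = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads)

    def forward(self, x, attn_mask=None):
        x = self.ln_kv(self.kv_proj(x)).permute(1, 0, 2)  # (N, L, D) -> (L, N, D)
        q = self.ln_q(self.query)                         # (Q, D) learned queries
        q = q.unsqueeze(1).expand(-1, x.size(1), -1)      # (Q, N, D)
        out, _ = self.attn(q, x, x, attn_mask=attn_mask)  # queries attend over features
        return out.permute(1, 0, 2)                       # (N, Q, D): fixed-length output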
@@ -401,7 +395,6 @@ class VisionTransformer(nn.Module):
             act_layer=act_layer,
             norm_layer=norm_layer,
         )
-        # pdb.set_trace()
         self.attn_pool = Resampler(
             grid_size=int(math.sqrt(n_queries)),
             embed_dim=output_dim,
@@ -418,14 +411,10 @@ class VisionTransformer(nn.Module):
         )
         self.ln_post = norm_layer(output_dim)
         self.proj = nn.Parameter((output_dim** -0.5) * torch.randn(output_dim, output_dim))
-        # self.attn_pool2.load_state_dict(torch.load('/cfs/cfs-lugcocyb/mingdali/code/qWen-VL/vl-chat/attn_params.pth'))

-    # def initialize_vision_modules(self,lpath):
-    #     self.attn_pool2[0].load_state_dict(torch.load(lpath))

     def forward(self, x: torch.Tensor):
-
-        #torch.save(self.attn_pool.state_dict(), '/cfs/cfs-lugcocyb/mingdali/code/qWen-VL/vl-chat/attn_params.pth')
+
         x = x.to(
             dtype=self.transformer.get_cast_dtype(),
             device=self.transformer.get_cast_device(),
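
A note on the surviving cast: the get_cast_dtype/get_cast_device accessors report the transformer's compute dtype and device so the raw pixel batch can follow the module. A generic stand-in for the same idea (helper name hypothetical, not part of visual.py):

import torch
import torch.nn as nn

def cast_like(module: nn.Module, x: torch.Tensor) -> torch.Tensor:
    # Move inputs to the dtype/device of the module's parameters,
    # mirroring what the get_cast_* accessors provide here.
    p = next(module.parameters())
    return x.to(dtype=p.dtype, device=p.device)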
@@ -442,7 +431,6 @@ class VisionTransformer(nn.Module):
         x = x.permute(1, 0, 2)  # NLD -> LND
         x = self.transformer(x)
         x = x.permute(1, 0, 2)  # LND -> NLD
-        # pdb.set_trace()
         src_size = int(math.sqrt(x.shape[1]))
         x = x.reshape(x.shape[0]//5,5,-1, x.shape[-1])
         x1 = x[:,4,:,:]
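
For context on the //5: sliding_window flattened each image into 5 views (4 local crops followed by the global resize), so the forward folds them back and splits index 4 off as the global stream. A shape-only sketch with assumed sizes:

import torch

x = torch.randn(10, 1024, 1664)                     # (b*5, L, D), here b = 2
x = x.reshape(x.shape[0] // 5, 5, -1, x.shape[-1])  # regroup the 5 views per image
x1 = x[:, 4, :, :]                                  # global view (appended last)
crops = x[:, :4, :, :]                              # the four 448x448 local crops
print(x1.shape, crops.shape)  # [2, 1024, 1664] [2, 4, 1024, 1664]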
@@ -454,7 +442,6 @@ class VisionTransformer(nn.Module):
         x1 = self.attn_pool(x1)
         x = self.post_pro(x)
         x1 = self.post_pro(x1)
-        # return x1
         return torch.cat([x,x1],dim=1)

     def post_pro(self, x):
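
The removed `# return x1` would have emitted only the global-view tokens; the live code keeps both streams. Assuming attn_pool compresses each stream to n_queries tokens of width output_dim, the return amounts to:

import torch

x  = torch.randn(2, 256, 4096)   # pooled local-crop tokens (sizes assumed)
x1 = torch.randn(2, 256, 4096)   # pooled global-view tokens
out = torch.cat([x, x1], dim=1)  # (2, 512, 4096): both streams in one sequence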
@@ -465,7 +452,7 @@ class VisionTransformer(nn.Module):

     def encode(self, image_paths: List[str]):
         images = []
-
+
         for image_path in image_paths:
             try:
                 if image_path.startswith("http://") or image_path.startswith("https://"):
@@ -474,7 +461,6 @@ class VisionTransformer(nn.Module):
                 image = self.image_transform(Image.open(image_path).convert("RGB"))
             except:
                 image = torch.zeros((3, 448*2, 448*2))
-            # pdb.set_trace()
             images.append(image)
         images = torch.stack(images, dim=0)
         windows = sliding_window(images,window_size=(448,448),stride=448)
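
Putting the two encode hunks together: each path is loaded (the http(s) branch body is elided in this diff), transformed to 896x896 (2x448, matching the zero-filled fallback), and the stacked batch goes through sliding_window. A hedged sketch of one iteration; the remote-fetch line is only a guess at the elided branch:

import torch
from PIL import Image
import requests  # assumption: the elided http(s) branch fetches with requests

def load_one(image_path, image_transform):
    try:
        if image_path.startswith(("http://", "https://")):
            img = Image.open(requests.get(image_path, stream=True).raw)  # hypothetical
        else:
            img = Image.open(image_path)
        return image_transform(img.convert("RGB"))
    except Exception:
        # Same fallback as the real code: a black 896x896 placeholder.
        return torch.zeros((3, 448 * 2, 448 * 2))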