Add print statements
modeling_cogvlm.py (+7 -2)
@@ -117,7 +117,8 @@ def attention_fn(
         attention_mask: "torch.tensor(B, H, L, HD)",
         *,
         scaling_attention_score: bool = True,
-        attention_dropout: nn.Module = None
+        attention_dropout: nn.Module = None,
+        print_values: bool = False,
 ):
     attention_mask_bool = (attention_mask == 0)
     is_low_triangle = (attention_mask_bool == torch.ones_like(attention_mask_bool, dtype=torch.float).tril()).all()
@@ -126,6 +127,10 @@ def attention_fn(
         warnings.warn("It's recommended to use torch2.0 or higher.")
     if int(torch.__version__.split('.')[0]) >= 2 and scaling_attention_score and (is_full or is_low_triangle):
         dropout_p = 0. if attention_dropout is None or not attention_dropout.training else attention_dropout.p
+
+        if print_values:
+            print("Is_causal:", not is_full)
+
         return torch.nn.functional.scaled_dot_product_attention(
             query_layer, key_layer, value_layer,
             attn_mask=None,
@@ -302,7 +307,7 @@ class VisionExpertAttention(nn.Module):
 
         context_layer = attention_fn(
             query_layer=query_states, key_layer=key_states, value_layer=value_states, attention_mask=attention_mask,
-            scaling_attention_score=True, attention_dropout=None)
+            scaling_attention_score=True, attention_dropout=None, print_values=print_values)
 
         if print_values:
             print("Shape of context_layer:", context_layer.shape)