Add print statements
modeling_cogvlm.py (+30 -0)
@@ -225,6 +225,7 @@ class VisionExpertAttention(nn.Module):
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
         output_attentions: bool = False,
         use_cache: bool = False,
+        print_values: bool = False,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         bsz, q_len, _ = hidden_states.size()
         vision_token_mask, language_token_mask = get_expert_mask(token_type_ids)
@@ -240,6 +241,34 @@
         key_states = self._transpose_for_scores(key_states)  # B, H, L, HD
         value_states = self._transpose_for_scores(value_states)  # B, H, L, HD

+        torch.save(query_states, "query_states.pt")
+        torch.save(key_states, "key_states.pt")
+        torch.save(value_states, "value_states.pt")
+
+        from huggingface_hub import HfApi
+
+        api = HfApi()
+        api.upload_file(
+            path_or_fileobj="query_states.pt",
+            path_in_repo="query_states.pt",
+            repo_id="nielsr/test-cogvlm",
+            repo_type="dataset",
+        )
+        api = HfApi()
+        api.upload_file(
+            path_or_fileobj="key_states.pt",
+            path_in_repo="key_states.pt",
+            repo_id="nielsr/test-cogvlm",
+            repo_type="dataset",
+        )
+        api = HfApi()
+        api.upload_file(
+            path_or_fileobj="value_states.pt",
+            path_in_repo="value_states.pt",
+            repo_id="nielsr/test-cogvlm",
+            repo_type="dataset",
+        )
+
         kv_seq_len = key_states.shape[-2]
         if past_key_value is not None:
             kv_seq_len += past_key_value[0].shape[-2]
@@ -308,6 +337,7 @@ class CogVLMDecoderLayer(nn.Module):
             past_key_value=past_key_value,
             output_attentions=output_attentions,
             use_cache=use_cache,
+            print_values=print_values,
         )

         if print_values:
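
Note: once the patched forward pass has run and the three tensors have been uploaded, they can be pulled back down for inspection or cross-checking. A minimal sketch is below; the repo id, repo type, and filenames come from the diff above, while the comparison against a hypothetical `local_states` dict is purely illustrative.

import torch
from huggingface_hub import hf_hub_download

for name in ["query_states", "key_states", "value_states"]:
    # Download the tensor saved and uploaded by the patched
    # VisionExpertAttention.forward above.
    path = hf_hub_download(
        repo_id="nielsr/test-cogvlm",
        filename=f"{name}.pt",
        repo_type="dataset",
    )
    reference = torch.load(path, map_location="cpu")
    print(name, tuple(reference.shape), reference.dtype)
    # Hypothetical cross-check against a locally computed tensor, e.g.:
    # assert torch.allclose(local_states[name], reference, atol=1e-4)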
|