x54-729 committed on
Commit 9c9e595
1 Parent(s): 5966e12

fix import error

Files changed (2):
  1. config.json +3 -2
  2. modeling_internlm2.py +18 -5
config.json CHANGED
@@ -21,10 +21,11 @@
   "num_key_value_heads": 8,
   "pad_token_id": 2,
   "rms_norm_eps": 1e-05,
-  "rotary": {
-    "base": 1000000,
+  "rope_scaling": {
+    "factor": 1.0,
     "type": "dynamic"
   },
+  "rope_theta": 1000000,
   "tie_word_embeddings": false,
   "torch_dtype": "bfloat16",
   "transformers_version": "4.33.0",
modeling_internlm2.py CHANGED
@@ -51,6 +51,19 @@ logger = logging.get_logger(__name__)
 
 _CONFIG_FOR_DOC = "InternLM2Config"
 
+flash_attn_func, flash_attn_varlen_func = None, None
+pad_input, index_first_axis, unpad_input = None, None, None
+def _import_flash_attn():
+    global flash_attn_func, flash_attn_varlen_func
+    global pad_input, index_first_axis, unpad_input
+    try:
+        from flash_attn import flash_attn_func as _flash_attn_func, flash_attn_varlen_func as _flash_attn_varlen_func
+        from flash_attn.bert_padding import pad_input as _pad_input, index_first_axis as _index_first_axis, unpad_input as _unpad_input
+        flash_attn_func, flash_attn_varlen_func = _flash_attn_func, _flash_attn_varlen_func
+        pad_input, index_first_axis, unpad_input = _pad_input, _index_first_axis, _unpad_input
+    except ImportError:
+        raise ImportError("flash_attn is not installed.")
+
 # Copied from transformers.models.llama.modeling_llama._get_unpad_data
 def _get_unpad_data(attention_mask):
     seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
@@ -502,13 +515,11 @@ class InternLM2FlashAttention2(InternLM2Attention):
             softmax_scale (`float`, *optional*):
                 The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
         """
-        from flash_attn import flash_attn_func, flash_attn_varlen_func
-        from flash_attn.bert_padding import pad_input
         # Contains at least one padding token in the sequence
         causal = self.is_causal and query_length != 1
         if attention_mask is not None:
             batch_size = query_states.shape[0]
-            query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
+            query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._unpad_input(
                 query_states, key_states, value_states, attention_mask, query_length
             )
 
@@ -536,8 +547,7 @@
 
         return attn_output
 
-    def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
-        from flash_attn.bert_padding import index_first_axis, unpad_input
+    def _unpad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
         indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
         batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
 
@@ -842,6 +852,9 @@ class InternLM2Model(InternLM2PreTrainedModel):
 
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
+        if self.config.attn_implementation == "flash_attention_2":
+            _import_flash_attn()
+
         # retrieve input_ids and inputs_embeds
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
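Note on the fix: the flash_attn imports previously executed inside _flash_attention_forward and _upad_input are replaced by module-level names that _import_flash_attn() fills in on demand, so merely importing modeling_internlm2.py no longer fails when flash_attn is absent; the import is attempted only when the model actually runs with attn_implementation == "flash_attention_2" (last hunk above). A minimal, self-contained sketch of the same lazy-import pattern, using a made-up optional dependency some_optional_dep in place of flash_attn:

# Sketch of the lazy-import pattern introduced above, with a made-up optional
# dependency "some_optional_dep" standing in for flash_attn.
fast_attn_func = None  # module-level placeholder, populated on first use


def _import_fast_attn():
    global fast_attn_func
    try:
        from some_optional_dep import fast_attn_func as _fast_attn_func  # hypothetical package
        fast_attn_func = _fast_attn_func
    except ImportError:
        raise ImportError("some_optional_dep is not installed.")


class TinyModel:
    def __init__(self, attn_implementation="eager"):
        self.attn_implementation = attn_implementation

    def forward(self, hidden_states):
        # The optional dependency is only touched when this branch runs,
        # mirroring the check added to InternLM2Model.forward.
        if self.attn_implementation == "flash_attention_2":
            _import_fast_attn()
            return fast_attn_func(hidden_states)
        return hidden_states  # eager fallback, no extra imports needed


TinyModel().forward([1, 2, 3])  # works even though some_optional_dep is not installed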
 
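The renamed _unpad_input helper still relies on _get_unpad_data (copied from transformers' Llama modeling code, per the comment in the first hunk) to build the variable-length metadata that flash_attn_varlen_func expects. The sketch below mirrors that contract from the context shown in the diff; it is not copied from this file, and the function name is illustrative:

# Illustrative sketch: turn a padding mask into the flattened token indices,
# cumulative sequence lengths (cu_seqlens) and max length that varlen
# flash-attention kernels consume.
import torch
import torch.nn.functional as F


def get_unpad_data_sketch(attention_mask):
    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)  # real tokens per sequence
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = int(seqlens_in_batch.max())
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return indices, cu_seqlens, max_seqlen_in_batch


mask = torch.tensor([[1, 1, 1, 0],
                     [1, 1, 0, 0]])
indices, cu_seqlens, max_len = get_unpad_data_sketch(mask)
# indices    -> tensor([0, 1, 2, 4, 5])   positions of real tokens in the flattened batch
# cu_seqlens -> tensor([0, 3, 5], dtype=torch.int32)
# max_len    -> 3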