fix flash-attention usage

- README.md: +2 -2
- config.json: +1 -1
- modeling_qwen.py: +19 -12
README.md
CHANGED
@@ -16,7 +16,7 @@ inference: false
 <br>
 
 <p align="center">
-Qwen-7B <a href="https://modelscope.cn/models/qwen/Qwen-7B/summary">🤖
+Qwen-7B <a href="https://modelscope.cn/models/qwen/Qwen-7B/summary">🤖 </a> | <a href="https://huggingface.co/Qwen/Qwen-7B">🤗</a>  | Qwen-7B-Chat <a href="https://modelscope.cn/models/qwen/Qwen-7B-Chat/summary">🤖 </a>| <a href="https://huggingface.co/Qwen/Qwen-7B-Chat">🤗</a>  |  <a href="https://modelscope.cn/studios/qwen/Qwen-7B-Chat-Demo/summary">Demo</a>  |  <a href="https://github.com/QwenLM/Qwen-7B/blob/main/tech_memo.md">Report</a>
 </p>
 <br>
 
@@ -350,4 +350,4 @@ Our code and checkpoints are open to research purpose, and they are allowed for
 
 如果你想给我们的研发团队和产品团队留言,请通过邮件([email protected])联系我们。
 
-If you are interested to leave a message to either our research team or product team, feel free to send an email to [email protected].
+If you are interested to leave a message to either our research team or product team, feel free to send an email to [email protected].
config.json
CHANGED
@@ -38,7 +38,7 @@
   "tokenizer_type": "QWenTokenizer",
   "transformers_version": "4.31.0",
   "use_cache": true,
-  "use_flash_attn":
+  "use_flash_attn": true,
   "vocab_size": 151936,
   "use_dynamic_ntk": false,
   "use_logn_attn": false
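Note: with `use_flash_attn` now defaulting to `true`, machines without FlashAttention rely on the guards added in modeling_qwen.py below. A minimal loading sketch, assuming the standard `transformers` remote-code flow and `Qwen/Qwen-7B` as the repo id; the flag can also be overridden explicitly at load time:

```python
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

# Sketch: force the vanilla attention path even though the config defaults to flash-attn.
config = AutoConfig.from_pretrained("Qwen/Qwen-7B", trust_remote_code=True)
config.use_flash_attn = False  # set True only if flash-attn is installed

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-7B", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen-7B", config=config, trust_remote_code=True
).eval()
```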
modeling_qwen.py
CHANGED
@@ -36,18 +36,17 @@ try:
     from einops import rearrange
 
     use_flash_rotary = True
-    print("use flash_attn rotary")
 except ImportError:
     use_flash_rotary = False
-    print("import flash_attn rotary fail")
+    print("Warning: import flash_attn rotary fail, please install FlashAttention rotary to get better performance "
+          "https://github.com/Dao-AILab/flash-attention/tree/main/csrc/rotary")
 
 try:
     from flash_attn.ops.rms_norm import rms_norm
-
-    print("use flash_attn rms_norm")
 except ImportError:
     rms_norm = None
-    print("import flash_attn rms_norm fail")
+    print("Warning: import flash_attn rms_norm fail, please install FlashAttention layer_norm to get better performance "
+          "https://github.com/Dao-AILab/flash-attention/tree/main/csrc/layer_norm")
 
 from .configuration_qwen import QWenConfig
 from .qwen_generation_utils import (
@@ -70,6 +69,8 @@ try:
     from flash_attn.flash_attn_interface import flash_attn_unpadded_func
 except ImportError:
     flash_attn_unpadded_func = None
+    print("Warning: import flash_attn fail, please install FlashAttention "
+          "https://github.com/Dao-AILab/flash-attention")
 
 
 class FlashSelfAttention(torch.nn.Module):
@@ -176,7 +177,7 @@ class QWenAttention(nn.Module):
             config.hidden_size, self.projection_size, bias=not config.no_bias
         )
 
-        if self.use_flash_attn:
+        if self.use_flash_attn and flash_attn_unpadded_func is not None:
             self.core_attention_flash = FlashSelfAttention(
                 causal=True, attention_dropout=config.attn_pdrop
             )
@@ -333,7 +334,7 @@ class QWenAttention(nn.Module):
         if layer_past:
             # layer past[0] shape: bs * seq_len * head_num * dim
             kv_seq_len += layer_past[0].shape[1]
-        if self.use_dynamic_ntk and kv_seq_len == hidden_states.size()[1]:
+        if self.use_dynamic_ntk and kv_seq_len == hidden_states.size()[1] and not self.training:
             context_value = math.log(kv_seq_len / self.seq_length, 2) + 1
             ntk_alpha = 2 ** math.ceil(context_value) - 1
             ntk_alpha = max(ntk_alpha, 1)
@@ -367,7 +368,7 @@ class QWenAttention(nn.Module):
         else:
             present = None
 
-        if self.use_logn_attn:
+        if self.use_logn_attn and not self.training:
             if self.logn_tensor.device != query.device:
                 self.logn_tensor = self.logn_tensor.to(query.device).type_as(query)
             seq_start = key.size(0) - query.size(0)
@@ -375,7 +376,7 @@ class QWenAttention(nn.Module):
             logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :]
             query = query * logn_tensor.expand_as(query)
 
-        if self.use_flash_attn:
+        if self.use_flash_attn and flash_attn_unpadded_func is not None:
             q, k, v = query, key, value
             context_layer = self.core_attention_flash(q, k, v)
 
@@ -396,7 +397,7 @@ class QWenAttention(nn.Module):
         attn_output = self.c_proj(context_layer)
         outputs = (attn_output, present)
         if output_attentions:
-            if self.use_flash_attn:
+            if self.use_flash_attn and flash_attn_unpadded_func is not None:
                 raise ValueError("Cannot output attentions while using flash-attn")
             else:
                 outputs += (attn_weight,)
@@ -748,6 +749,13 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         super().__init__(config)
         self.transformer = QWenModel(config)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+        assert not(config.bf16 and config.fp16), ("In config, bf16 and fp16 cannot both be true")
+        if config.bf16:
+            self.transformer.bfloat16()
+            self.lm_head.bfloat16()
+        if config.fp16:
+            self.transformer.half()
+            self.lm_head.half()
         self.post_init()
 
     def get_output_embeddings(self):
@@ -957,8 +965,7 @@ class RotaryEmbedding(torch.nn.Module):
         super().__init__()
         self.dim = dim
         self.base = base
-        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
-        self.register_buffer("inv_freq", inv_freq)
+        self.inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
         if importlib.util.find_spec("einops") is None:
             raise RuntimeError("einops is required for Rotary Embedding")
 
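Note: the recurring pattern in the hunks above is an optional-dependency guard: the import is wrapped in try/except, a warning points to the install location, and every call site checks `flash_attn_unpadded_func is not None` rather than trusting `use_flash_attn` alone. A self-contained sketch of that pattern (illustrative names, not the module's own code):

```python
import torch

# Optional import: a failed import leaves a None sentinel and a warning
# instead of crashing at import time.
try:
    from flash_attn.flash_attn_interface import flash_attn_unpadded_func
except ImportError:
    flash_attn_unpadded_func = None
    print("Warning: import flash_attn fail, please install FlashAttention "
          "https://github.com/Dao-AILab/flash-attention")


def attention(q, k, v, use_flash_attn=True):
    # The config flag alone is not trusted: the call site re-checks that the
    # import actually succeeded, mirroring the guards added above.
    if use_flash_attn and flash_attn_unpadded_func is not None:
        print("flash-attn available; modeling_qwen.py dispatches to FlashSelfAttention here")
    # Vanilla attention as the always-available path of this sketch.
    scores = torch.matmul(q, k.transpose(-2, -1)) / (q.size(-1) ** 0.5)
    return torch.matmul(torch.softmax(scores, dim=-1), v)


q = k = v = torch.randn(1, 8, 16, 64)  # (batch, heads, seq_len, head_dim)
out = attention(q, k, v)               # works with or without flash-attn installed
```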
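Note: the `and not self.training` conditions make the dynamic-NTK and log-n scaling inference-only; they rely on `torch.nn.Module.training`, which `train()`/`eval()` toggle. A hypothetical module (not Qwen's code) showing how such a guard behaves:

```python
import math
import torch


class LengthScaledAttention(torch.nn.Module):
    """Hypothetical module mirroring the guard above: a length-dependent
    scaling branch that is skipped while the module is in training mode."""

    def __init__(self, use_logn_attn=True, seq_length=2048):
        super().__init__()
        self.use_logn_attn = use_logn_attn
        self.seq_length = seq_length

    def forward(self, query):
        # query: (batch, seq_len, heads, head_dim) in this sketch
        kv_seq_len = query.size(1)
        if self.use_logn_attn and not self.training:
            # Illustrative scaling, applied only at inference and only once
            # the context exceeds the trained sequence length.
            scale = max(math.log(kv_seq_len / self.seq_length, 2) + 1.0, 1.0)
            query = query * scale
        return query


m = LengthScaledAttention()
x = torch.randn(1, 4096, 32, 128)

m.train()                        # self.training is True: branch skipped
assert torch.equal(m(x), x)

m.eval()                         # self.training is False: scaling applied
assert not torch.equal(m(x), x)
```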