add support for flash attn 2
modeling_qwen.py  +40 -31
CHANGED
@@ -36,10 +36,6 @@ SUPPORT_CUDA = torch.cuda.is_available()
 SUPPORT_BF16 = SUPPORT_CUDA and torch.cuda.is_bf16_supported()
 SUPPORT_FP16 = SUPPORT_CUDA and torch.cuda.get_device_capability(0)[0] >= 7
 
-apply_rotary_emb_func = None
-rms_norm = None
-flash_attn_unpadded_func = None
-
 from .configuration_qwen import QWenConfig
 from .qwen_generation_utils import (
     HistoryType,
@@ -57,6 +53,45 @@ _CONFIG_FOR_DOC = "QWenConfig"
 
 QWen_PRETRAINED_MODEL_ARCHIVE_LIST = ["qwen-7b"]
 
+apply_rotary_emb_func = None
+rms_norm = None
+flash_attn_unpadded_func = None
+
+
+def _import_flash_attn():
+    global apply_rotary_emb_func, rms_norm, flash_attn_unpadded_func
+    try:
+        from flash_attn.layers.rotary import apply_rotary_emb_func as __apply_rotary_emb_func
+        apply_rotary_emb_func = __apply_rotary_emb_func
+    except ImportError:
+        logger.warn(
+            "Warning: import flash_attn rotary fail, please install FlashAttention rotary to get higher efficiency "
+            "https://github.com/Dao-AILab/flash-attention/tree/main/csrc/rotary"
+        )
+
+    try:
+        from flash_attn.ops.rms_norm import rms_norm as __rms_norm
+        rms_norm = __rms_norm
+    except ImportError:
+        logger.warn(
+            "Warning: import flash_attn rms_norm fail, please install FlashAttention layer_norm to get higher efficiency "
+            "https://github.com/Dao-AILab/flash-attention/tree/main/csrc/layer_norm"
+        )
+
+    try:
+        import flash_attn
+        if int(flash_attn.__version__.split(".")[0]) >= 2:
+            from flash_attn.flash_attn_interface import flash_attn_varlen_func as __flash_attn_unpadded_func
+        else:
+            from flash_attn.flash_attn_interface import flash_attn_unpadded_func as __flash_attn_unpadded_func
+        flash_attn_unpadded_func = __flash_attn_unpadded_func
+    except ImportError:
+        logger.warn(
+            "Warning: import flash_attn fail, please install FlashAttention to get higher efficiency "
+            "https://github.com/Dao-AILab/flash-attention"
+        )
+
+
 class FlashSelfAttention(torch.nn.Module):
     def __init__(
         self,
@@ -794,33 +829,7 @@ class QWenLMHeadModel(QWenPreTrainedModel):
             logger.warn("Flash attention will be disabled because it does NOT support fp32.")
 
         if config.use_flash_attn:
-
-            try:
-                from flash_attn.layers.rotary import apply_rotary_emb_func as __apply_rotary_emb_func
-                apply_rotary_emb_func = __apply_rotary_emb_func
-            except ImportError:
-                logger.warn(
-                    "Warning: import flash_attn rotary fail, please install FlashAttention rotary to get higher efficiency "
-                    "https://github.com/Dao-AILab/flash-attention/tree/main/csrc/rotary"
-                )
-
-            try:
-                from flash_attn.ops.rms_norm import rms_norm as __rms_norm
-                rms_norm = __rms_norm
-            except ImportError:
-                logger.warn(
-                    "Warning: import flash_attn rms_norm fail, please install FlashAttention layer_norm to get higher efficiency "
-                    "https://github.com/Dao-AILab/flash-attention/tree/main/csrc/layer_norm"
-                )
-
-            try:
-                from flash_attn.flash_attn_interface import flash_attn_unpadded_func as __flash_attn_unpadded_func
-                flash_attn_unpadded_func = __flash_attn_unpadded_func
-            except ImportError:
-                logger.warn(
-                    "Warning: import flash_attn fail, please install FlashAttention to get higher efficiency "
-                    "https://github.com/Dao-AILab/flash-attention"
-                )
+            _import_flash_attn()
 
         self.transformer = QWenModel(config)
         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
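With this change the FlashAttention imports run lazily: QWenLMHeadModel.__init__ calls _import_flash_attn() only when config.use_flash_attn is enabled, and the helper binds flash_attn_varlen_func on flash-attn >= 2 (its leading positional arguments line up with the 1.x flash_attn_unpadded_func, which is presumably why the simple rename suffices) or falls back to the 1.x symbol otherwise. A minimal loading sketch follows; it is not part of the diff, and the checkpoint name plus the bf16/use_flash_attn keyword arguments are assumptions based on the usual Qwen-7B remote-code loading pattern:

# Sketch only: assumes the "Qwen/Qwen-7B" hub checkpoint and that from_pretrained
# forwards bf16/use_flash_attn into QWenConfig; flash-attn 1.x or 2.x must be
# installed for _import_flash_attn() to find the fused kernels.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-7B", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen-7B",
    device_map="auto",
    trust_remote_code=True,   # loads this modeling_qwen.py from the repo
    bf16=True,                # flash attention stays disabled under fp32
    use_flash_attn=True,      # QWenLMHeadModel.__init__ then calls _import_flash_attn()
).eval()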