michael-guenther
commited on
Commit
•
71ef017
1
Parent(s):
807ba34
set use_flash_attn if not available
Browse files- mha.py +0 -2
- modeling_xlm_roberta.py +16 -3
mha.py
CHANGED
@@ -10,8 +10,6 @@ import torch
|
|
10 |
import torch.nn as nn
|
11 |
from einops import rearrange, repeat
|
12 |
|
13 |
-
from flash_attn.utils.distributed import get_dim_for_local_rank
|
14 |
-
|
15 |
try:
|
16 |
from flash_attn import (
|
17 |
flash_attn_kvpacked_func,
|
|
|
10 |
import torch.nn as nn
|
11 |
from einops import rearrange, repeat
|
12 |
|
|
|
|
|
13 |
try:
|
14 |
from flash_attn import (
|
15 |
flash_attn_kvpacked_func,
|
modeling_xlm_roberta.py
CHANGED
@@ -1,6 +1,5 @@
|
|
1 |
# This implementation was adopted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/models/bert.py
|
2 |
# Commit id: abbc1311731867310635f9edc2a9ec18317c8c48
|
3 |
-
|
4 |
# Copyright (c) 2022, Tri Dao.
|
5 |
# This BERT implementation is based on our MLPerf 2.0 and MLPerf 2.1 BERT implementation.
|
6 |
# https://github.com/mlcommons/training_results_v2.0/blob/main/HazyResearch/benchmarks/bert/implementations/pytorch/modeling.py
|
@@ -8,6 +7,7 @@
|
|
8 |
|
9 |
# Inspired by https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py
|
10 |
|
|
|
11 |
import logging
|
12 |
import re
|
13 |
from collections import OrderedDict
|
@@ -65,8 +65,21 @@ except ImportError:
|
|
65 |
logger = logging.getLogger(__name__)
|
66 |
|
67 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
68 |
def create_mixer_cls(config, cross_attn=False, return_residual=False):
|
69 |
-
use_flash_attn =
|
70 |
fused_bias_fc = getattr(config, "fused_bias_fc", False)
|
71 |
rotary_kwargs = {}
|
72 |
if config.position_embedding_type == "rotary":
|
@@ -169,7 +182,7 @@ def _init_weights(module, initializer_range=0.02):
|
|
169 |
class XLMRobertaEncoder(nn.Module):
|
170 |
def __init__(self, config: XLMRobertaFlashConfig):
|
171 |
super().__init__()
|
172 |
-
self.use_flash_attn =
|
173 |
self.layers = nn.ModuleList(
|
174 |
[create_block(config, layer_idx=i) for i in range(config.num_hidden_layers)]
|
175 |
)
|
|
|
1 |
# This implementation was adopted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/models/bert.py
|
2 |
# Commit id: abbc1311731867310635f9edc2a9ec18317c8c48
|
|
|
3 |
# Copyright (c) 2022, Tri Dao.
|
4 |
# This BERT implementation is based on our MLPerf 2.0 and MLPerf 2.1 BERT implementation.
|
5 |
# https://github.com/mlcommons/training_results_v2.0/blob/main/HazyResearch/benchmarks/bert/implementations/pytorch/modeling.py
|
|
|
7 |
|
8 |
# Inspired by https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py
|
9 |
|
10 |
+
import importlib.util
|
11 |
import logging
|
12 |
import re
|
13 |
from collections import OrderedDict
|
|
|
65 |
logger = logging.getLogger(__name__)
|
66 |
|
67 |
|
68 |
+
def get_use_flash_attn(config: XLMRobertaFlashConfig):
|
69 |
+
if not config.use_flash_attn:
|
70 |
+
return False
|
71 |
+
if not torch.cuda.is_available():
|
72 |
+
return False
|
73 |
+
if importlib.util.find_spec("flash_attn") is None:
|
74 |
+
logger.warning(
|
75 |
+
'flash_attn is not installed. Using PyTorch native attention implementation.'
|
76 |
+
)
|
77 |
+
return False
|
78 |
+
return True
|
79 |
+
|
80 |
+
|
81 |
def create_mixer_cls(config, cross_attn=False, return_residual=False):
|
82 |
+
use_flash_attn = get_use_flash_attn(config)
|
83 |
fused_bias_fc = getattr(config, "fused_bias_fc", False)
|
84 |
rotary_kwargs = {}
|
85 |
if config.position_embedding_type == "rotary":
|
|
|
182 |
class XLMRobertaEncoder(nn.Module):
|
183 |
def __init__(self, config: XLMRobertaFlashConfig):
|
184 |
super().__init__()
|
185 |
+
self.use_flash_attn = get_use_flash_attn(config)
|
186 |
self.layers = nn.ModuleList(
|
187 |
[create_block(config, layer_idx=i) for i in range(config.num_hidden_layers)]
|
188 |
)
|