feat: choose flash attention heuristically if not set explicitly
modeling_bert.py  +2 -2
@@ -66,7 +66,7 @@ logger = logging.getLogger(__name__)
 
 
 def create_mixer_cls(config, cross_attn=False, return_residual=False):
-    use_flash_attn = config.use_flash_attn
+    use_flash_attn = config.use_flash_attn if config.use_flash_attn is not None else torch.cuda.is_available()
     use_qk_norm = config.use_qk_norm
     fused_bias_fc = config.fused_bias_fc
     window_size = config.window_size
@@ -161,7 +161,7 @@ def _init_weights(module, initializer_range=0.02):
 class BertEncoder(nn.Module):
     def __init__(self, config: JinaBertConfig):
         super().__init__()
-        self.use_flash_attn = config.use_flash_attn
+        self.use_flash_attn = config.use_flash_attn if config.use_flash_attn is not None else torch.cuda.is_available()
         self.layers = nn.ModuleList(
             [create_block(config, layer_idx=i) for i in range(config.num_hidden_layers)]
         )
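Both hunks apply the same fallback pattern. A minimal sketch of the heuristic, assuming only that config.use_flash_attn may now be None (the helper name below is hypothetical; the commit inlines the expression at both sites):

from typing import Optional

import torch


def resolve_use_flash_attn(configured: Optional[bool]) -> bool:
    """Hypothetical helper mirroring the commit's inline expression."""
    if configured is not None:
        # An explicit True/False in the config always wins.
        return configured
    # Not set explicitly: enable flash attention only when CUDA is
    # available, since the flash attention kernels need a CUDA GPU.
    return torch.cuda.is_available()

On a CPU-only machine resolve_use_flash_attn(None) returns False, so flash attention stays off by default, while an explicit use_flash_attn=True in the config is still honored unchanged.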