Add flash attention
#21
by yuanzhoulvpi · opened
- modeling_baichuan.py +28 -12
modeling_baichuan.py CHANGED
@@ -138,20 +138,36 @@ class BaichuanAttention(torch.nn.Module):
 
         past_key_value = (key_states, value_states) if use_cache else None
 
-        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
-
-        if attention_mask is not None:
-            if q_len == 1: # inference with cache
-                if len(attention_mask.size()) == 4:
-                    attention_mask = attention_mask[:, :, -1:, :]
-                else:
-                    attention_mask = attention_mask[:, -1:, :]
-            attn_weights = attn_weights + attention_mask
-            attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
-
-        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
-
-        attn_output = torch.matmul(attn_weights, value_states)
+        pytorch_major_version = int(torch.__version__.split('.')[0])
+        if pytorch_major_version >= 2:
+            if attention_mask is not None:
+                if q_len == 1: # inference with cache
+                    if len(attention_mask.size()) == 4:
+                        attention_mask = attention_mask[:, :, -1:, :]
+                    else:
+                        attention_mask = attention_mask[:, -1:, :]
+
+            attn_output = torch.nn.functional.scaled_dot_product_attention(query_states,
+                                                                           key_states,
+                                                                           value_states,
+                                                                           dropout_p=0.0,
+                                                                           attn_mask=attention_mask)
+        else:
+
+            attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+            if attention_mask is not None:
+                if q_len == 1: # inference with cache
+                    if len(attention_mask.size()) == 4:
+                        attention_mask = attention_mask[:, :, -1:, :]
+                    else:
+                        attention_mask = attention_mask[:, -1:, :]
+                attn_weights = attn_weights + attention_mask
+                attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
+
+            attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
+
+            attn_output = torch.matmul(attn_weights, value_states)
 
         attn_output = attn_output.transpose(1, 2)
         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
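
Why this works: on PyTorch 2.x, torch.nn.functional.scaled_dot_product_attention computes softmax(QK^T / sqrt(d) + mask) V in a single fused call and can dispatch to a flash-attention or memory-efficient kernel, so the new branch should produce the same output as the manual matmul + softmax path that the else branch keeps. Below is a minimal sketch of that equivalence; the shapes, dtype, and the causal mask are illustrative assumptions for the sketch, not values taken from modeling_baichuan.py.

# Sketch (not part of the PR): compare the PyTorch >= 2 fused path with the
# manual attention path kept in the else branch. Shapes/dtype are illustrative.
import math
import torch

bsz, num_heads, q_len, head_dim = 2, 4, 16, 32
query_states = torch.randn(bsz, num_heads, q_len, head_dim)
key_states = torch.randn(bsz, num_heads, q_len, head_dim)
value_states = torch.randn(bsz, num_heads, q_len, head_dim)

# Additive causal mask: 0 where a position may attend, a large negative value elsewhere.
attention_mask = torch.triu(
    torch.full((q_len, q_len), torch.finfo(torch.float32).min), diagonal=1
)[None, None, :, :]

# New branch: fused scaled_dot_product_attention (may use a flash/memory-efficient kernel).
sdpa_out = torch.nn.functional.scaled_dot_product_attention(
    query_states, key_states, value_states, dropout_p=0.0, attn_mask=attention_mask
)

# Old/fallback branch: explicit matmul + softmax, as in the else branch of the diff.
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(head_dim)
attn_weights = attn_weights + attention_mask
attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
manual_out = torch.matmul(attn_weights, value_states)

print(torch.allclose(sdpa_out, manual_out, atol=1e-5))  # expected: True, up to numerics

Which kernel scaled_dot_product_attention actually selects depends on device, dtype, and the mask: in early PyTorch 2.x releases the flash backend does not accept an arbitrary attn_mask and generally requires fp16/bf16 on CUDA, so with a float mask like the one above the call may fall back to the memory-efficient or math backend. The torch.backends.cuda.sdp_kernel context manager can be used to check or restrict which backend is allowed.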