Commit: "update modeling_qwen.py" — diff view of modeling_qwen.py (+4 lines, −3 lines)
File: modeling_qwen.py (changed)
@@ -175,6 +175,7 @@ class FlashSelfAttention(torch.nn.Module):
|
|
175 |
assert all((i.is_cuda for i in (q, k, v)))
|
176 |
batch_size, seqlen_q = q.shape[0], q.shape[1]
|
177 |
seqlen_k = k.shape[1]
|
|
|
178 |
|
179 |
q, k, v = [rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]]
|
180 |
cu_seqlens_q = torch.arange(
|
@@ -187,11 +188,11 @@ class FlashSelfAttention(torch.nn.Module):
|
|
187 |
|
188 |
if attention_mask is not None:
|
189 |
k, indices_k, cu_seqlens_k, seqlen_k = self.unpad_input(k, attention_mask)
|
190 |
-
|
191 |
-
if self.training or q.size(0) == k.size(0):
|
192 |
q = q[indices_k]
|
193 |
cu_seqlens_q = cu_seqlens_k
|
194 |
seqlen_q = seqlen_k
|
|
|
195 |
else:
|
196 |
cu_seqlens_k = torch.arange(
|
197 |
0,
|
@@ -222,7 +223,7 @@ class FlashSelfAttention(torch.nn.Module):
|
|
222 |
causal=is_causal,
|
223 |
)
|
224 |
if attention_mask is not None and seqlen_q == seqlen_k:
|
225 |
-
output = self.pad_input(output, indices_k, batch_size, seqlen_q)
|
226 |
else:
|
227 |
new_shape = (batch_size, output.shape[0] // batch_size) + output.shape[1:]
|
228 |
output = output.view(new_shape)
|
|
|
175 |
assert all((i.is_cuda for i in (q, k, v)))
|
176 |
batch_size, seqlen_q = q.shape[0], q.shape[1]
|
177 |
seqlen_k = k.shape[1]
|
178 |
+
seqlen_out = seqlen_q
|
179 |
|
180 |
q, k, v = [rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]]
|
181 |
cu_seqlens_q = torch.arange(
|
|
|
188 |
|
189 |
if attention_mask is not None:
|
190 |
k, indices_k, cu_seqlens_k, seqlen_k = self.unpad_input(k, attention_mask)
|
191 |
+
if q.size(0) == v.size(0):
|
|
|
192 |
q = q[indices_k]
|
193 |
cu_seqlens_q = cu_seqlens_k
|
194 |
seqlen_q = seqlen_k
|
195 |
+
v = v[indices_k]
|
196 |
else:
|
197 |
cu_seqlens_k = torch.arange(
|
198 |
0,
|
|
|
223 |
causal=is_causal,
|
224 |
)
|
225 |
if attention_mask is not None and seqlen_q == seqlen_k:
|
226 |
+
output = self.pad_input(output, indices_k, batch_size, seqlen_out)
|
227 |
else:
|
228 |
new_shape = (batch_size, output.shape[0] // batch_size) + output.shape[1:]
|
229 |
output = output.view(new_shape)
|