jupyterjazz
committed on
Commit
•
071760a
1
Parent(s):
90873c4
Update rotary.py
Browse files
rotary.py
CHANGED
@@ -534,7 +534,7 @@ class RotaryEmbedding(torch.nn.Module):
|
|
534 |
# And the output of arange can be quite large, so bf16 would lose a lot of precision.
|
535 |
# However, for compatibility reasons, we add an option to use the dtype of self.inv_freq.
|
536 |
if rotary_base_changed:
|
537 |
-
self.inv_freq = self._compute_inv_freq(device=
|
538 |
if self.pos_idx_in_fp32:
|
539 |
t = torch.arange(seqlen, device=device, dtype=torch.float32)
|
540 |
# We want fp32 here as well since inv_freq will be multiplied with t, and the output
|
|
|
534 |
# And the output of arange can be quite large, so bf16 would lose a lot of precision.
|
535 |
# However, for compatibility reasons, we add an option to use the dtype of self.inv_freq.
|
536 |
if rotary_base_changed:
|
537 |
+
self.inv_freq = self._compute_inv_freq(device=device)
|
538 |
if self.pos_idx_in_fp32:
|
539 |
t = torch.arange(seqlen, device=device, dtype=torch.float32)
|
540 |
# We want fp32 here as well since inv_freq will be multiplied with t, and the output
|