pseudotensor
committed on
Commit
•
236c227
1
Parent(s):
a62a230
Update modelling_RW.py
Browse files- modelling_RW.py +3 -1
modelling_RW.py
CHANGED
@@ -54,6 +54,7 @@ class RotaryEmbedding(torch.nn.Module):
|
|
54 |
self,
|
55 |
head_dim: int,
|
56 |
base=10000,
|
|
|
57 |
):
|
58 |
super().__init__()
|
59 |
inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
|
@@ -63,6 +64,7 @@ class RotaryEmbedding(torch.nn.Module):
|
|
63 |
self.batch_size_cached = None
|
64 |
self.cos_cached: torch.Tensor | None = None
|
65 |
self.sin_cached: torch.Tensor | None = None
|
|
|
66 |
|
67 |
def cos_sin(
|
68 |
self,
|
@@ -70,7 +72,7 @@ class RotaryEmbedding(torch.nn.Module):
|
|
70 |
device="cuda",
|
71 |
dtype=torch.bfloat16,
|
72 |
) -> torch.Tensor:
|
73 |
-
if seq_len != self.seq_len_cached:
|
74 |
self.seq_len_cached = seq_len
|
75 |
t = torch.arange(seq_len, device=device).type_as(self.inv_freq)
|
76 |
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
|
|
54 |
self,
|
55 |
head_dim: int,
|
56 |
base=10000,
|
57 |
+
use_cache=False,
|
58 |
):
|
59 |
super().__init__()
|
60 |
inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
|
|
|
64 |
self.batch_size_cached = None
|
65 |
self.cos_cached: torch.Tensor | None = None
|
66 |
self.sin_cached: torch.Tensor | None = None
|
67 |
+
self.use_cache = use_cache
|
68 |
|
69 |
def cos_sin(
|
70 |
self,
|
|
|
72 |
device="cuda",
|
73 |
dtype=torch.bfloat16,
|
74 |
) -> torch.Tensor:
|
75 |
+
if seq_len != self.seq_len_cached or not self.use_cache:
|
76 |
self.seq_len_cached = seq_len
|
77 |
t = torch.arange(seq_len, device=device).type_as(self.inv_freq)
|
78 |
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|