pseudotensor
committed on
Commit
•
636ee83
1
Parent(s):
236c227
Update modelling_RW.py
Browse files
avoid cache to avoid https://github.com/h2oai/h2ogpt/pull/297
- modelling_RW.py +18 -3
modelling_RW.py
CHANGED
@@ -71,8 +71,23 @@ class RotaryEmbedding(torch.nn.Module):
|
|
71 |
seq_len: int,
|
72 |
device="cuda",
|
73 |
dtype=torch.bfloat16,
|
74 |
-
) -> torch.Tensor:
|
75 |
-
if
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
76 |
self.seq_len_cached = seq_len
|
77 |
t = torch.arange(seq_len, device=device).type_as(self.inv_freq)
|
78 |
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
@@ -87,7 +102,7 @@ class RotaryEmbedding(torch.nn.Module):
|
|
87 |
self.cos_cached = self.cos_cached.type(dtype)
|
88 |
self.sin_cached = self.sin_cached.type(dtype)
|
89 |
|
90 |
-
|
91 |
|
92 |
def forward(self, q, k):
|
93 |
batch, seq_len, head_dim = q.shape
|
|
|
71 |
seq_len: int,
|
72 |
device="cuda",
|
73 |
dtype=torch.bfloat16,
|
74 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
75 |
+
if not self.use_cache:
|
76 |
+
t = torch.arange(seq_len, device=device).type_as(self.inv_freq)
|
77 |
+
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
78 |
+
emb = torch.cat((freqs, freqs), dim=-1).to(device)
|
79 |
+
|
80 |
+
if dtype in [torch.float16, torch.bfloat16]:
|
81 |
+
emb = emb.float()
|
82 |
+
|
83 |
+
cos_cached = emb.cos()[None, :, :]
|
84 |
+
sin_cached = emb.sin()[None, :, :]
|
85 |
+
|
86 |
+
cos_cached = cos_cached.type(dtype)
|
87 |
+
sin_cached = sin_cached.type(dtype)
|
88 |
+
|
89 |
+
return cos_cached, sin_cached
|
90 |
+
elif seq_len != self.seq_len_cached or not self.use_cache:
|
91 |
self.seq_len_cached = seq_len
|
92 |
t = torch.arange(seq_len, device=device).type_as(self.inv_freq)
|
93 |
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
|
|
102 |
self.cos_cached = self.cos_cached.type(dtype)
|
103 |
self.sin_cached = self.sin_cached.type(dtype)
|
104 |
|
105 |
+
return self.cos_cached, self.sin_cached
|
106 |
|
107 |
def forward(self, q, k):
|
108 |
batch, seq_len, head_dim = q.shape
|