pseudotensor committed
Commit 636ee83
1 Parent(s): 236c227

Update modelling_RW.py


Avoid the cache, to avoid the issue tracked in https://github.com/h2oai/h2ogpt/pull/297

Files changed (1):
  1. modelling_RW.py  +18 -3
modelling_RW.py CHANGED
@@ -71,8 +71,23 @@ class RotaryEmbedding(torch.nn.Module):
         seq_len: int,
         device="cuda",
         dtype=torch.bfloat16,
-    ) -> torch.Tensor:
-        if seq_len != self.seq_len_cached or not self.use_cache:
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        if not self.use_cache:
+            t = torch.arange(seq_len, device=device).type_as(self.inv_freq)
+            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+            emb = torch.cat((freqs, freqs), dim=-1).to(device)
+
+            if dtype in [torch.float16, torch.bfloat16]:
+                emb = emb.float()
+
+            cos_cached = emb.cos()[None, :, :]
+            sin_cached = emb.sin()[None, :, :]
+
+            cos_cached = cos_cached.type(dtype)
+            sin_cached = sin_cached.type(dtype)
+
+            return cos_cached, sin_cached
+        elif seq_len != self.seq_len_cached or not self.use_cache:
             self.seq_len_cached = seq_len
             t = torch.arange(seq_len, device=device).type_as(self.inv_freq)
             freqs = torch.einsum("i,j->ij", t, self.inv_freq)
@@ -87,7 +102,7 @@ class RotaryEmbedding(torch.nn.Module):
             self.cos_cached = self.cos_cached.type(dtype)
             self.sin_cached = self.sin_cached.type(dtype)
 
-        return self.cos_cached, self.sin_cached
+            return self.cos_cached, self.sin_cached
 
     def forward(self, q, k):
         batch, seq_len, head_dim = q.shape
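
For context, below is a minimal, self-contained sketch of the two cos_sin code paths this diff produces. The class name RotaryEmbeddingSketch, the head_dim argument, and the base=10000 default are illustrative assumptions, not the model's actual RotaryEmbedding; only the use_cache branching mirrors the change above. With use_cache=False the cos/sin tensors are rebuilt on every call and the cached buffers are never written, which is how the commit sidesteps the cache issue referenced in the linked PR.

# Minimal sketch of the two cos_sin paths; the class name, head_dim, and
# base are illustrative assumptions, not the model's actual RotaryEmbedding.
import torch


class RotaryEmbeddingSketch(torch.nn.Module):
    def __init__(self, head_dim: int, base=10000, use_cache=False):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.use_cache = use_cache  # False -> the new, cache-free path
        self.seq_len_cached = None
        self.cos_cached = None
        self.sin_cached = None

    def cos_sin(self, seq_len, device="cpu", dtype=torch.bfloat16):
        if not self.use_cache:
            # New path from the diff: build cos/sin locally and return them
            # without ever writing the cached buffers.
            t = torch.arange(seq_len, device=device).type_as(self.inv_freq)
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1).to(device)
            if dtype in [torch.float16, torch.bfloat16]:
                emb = emb.float()
            return emb.cos()[None, :, :].type(dtype), emb.sin()[None, :, :].type(dtype)
        elif seq_len != self.seq_len_cached:
            # Old path: recompute only when seq_len changes, then reuse the cache.
            self.seq_len_cached = seq_len
            t = torch.arange(seq_len, device=device).type_as(self.inv_freq)
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1).to(device)
            if dtype in [torch.float16, torch.bfloat16]:
                emb = emb.float()
            self.cos_cached = emb.cos()[None, :, :].type(dtype)
            self.sin_cached = emb.sin()[None, :, :].type(dtype)
        return self.cos_cached, self.sin_cached


rot = RotaryEmbeddingSketch(head_dim=64, use_cache=False)
cos, sin = rot.cos_sin(seq_len=16, device="cpu", dtype=torch.float32)
print(cos.shape, sin.shape)  # torch.Size([1, 16, 64]) torch.Size([1, 16, 64])

The trade-off is that the rotary embeddings are recomputed on every call instead of being reused from mutable state stored on the module; that recompute is small next to the attention computation itself, while the shared cache is removed as a source of cross-call interference.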