pseudotensor committed on
Commit
236c227
1 Parent(s): a62a230

Update modelling_RW.py

Browse files
Files changed (1) hide show
  1. modelling_RW.py +3 -1
modelling_RW.py CHANGED
@@ -54,6 +54,7 @@ class RotaryEmbedding(torch.nn.Module):
54
  self,
55
  head_dim: int,
56
  base=10000,
 
57
  ):
58
  super().__init__()
59
  inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
@@ -63,6 +64,7 @@ class RotaryEmbedding(torch.nn.Module):
63
  self.batch_size_cached = None
64
  self.cos_cached: torch.Tensor | None = None
65
  self.sin_cached: torch.Tensor | None = None
 
66
 
67
  def cos_sin(
68
  self,
@@ -70,7 +72,7 @@ class RotaryEmbedding(torch.nn.Module):
70
  device="cuda",
71
  dtype=torch.bfloat16,
72
  ) -> torch.Tensor:
73
- if seq_len != self.seq_len_cached:
74
  self.seq_len_cached = seq_len
75
  t = torch.arange(seq_len, device=device).type_as(self.inv_freq)
76
  freqs = torch.einsum("i,j->ij", t, self.inv_freq)
 
54
  self,
55
  head_dim: int,
56
  base=10000,
57
+ use_cache=False,
58
  ):
59
  super().__init__()
60
  inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
 
64
  self.batch_size_cached = None
65
  self.cos_cached: torch.Tensor | None = None
66
  self.sin_cached: torch.Tensor | None = None
67
+ self.use_cache = use_cache
68
 
69
  def cos_sin(
70
  self,
 
72
  device="cuda",
73
  dtype=torch.bfloat16,
74
  ) -> torch.Tensor:
75
+ if seq_len != self.seq_len_cached or not self.use_cache:
76
  self.seq_len_cached = seq_len
77
  t = torch.arange(seq_len, device=device).type_as(self.inv_freq)
78
  freqs = torch.einsum("i,j->ij", t, self.inv_freq)