pseudotensor committed on
Commit 4694579 · 1 Parent(s): cba2f63

Update modelling_RW.py


Ensure the 40b file is used as the reference.

Files changed (1)
  1. modelling_RW.py +4 -4
modelling_RW.py CHANGED
@@ -52,10 +52,11 @@ class RotaryEmbedding(torch.nn.Module):
 
     def __init__(
         self,
-        head_dim: int,
+        config,
         base=10000,
-        use_cache=False,
     ):
+        head_dim = config.head_dim
+        self.use_cache = config.use_cache
         super().__init__()
         inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
         self.register_buffer("inv_freq", inv_freq, persistent=False)
@@ -64,7 +65,6 @@ class RotaryEmbedding(torch.nn.Module):
         self.batch_size_cached = None
         self.cos_cached: torch.Tensor | None = None
         self.sin_cached: torch.Tensor | None = None
-        self.use_cache = use_cache
 
     def cos_sin(
         self,
@@ -184,7 +184,7 @@ class Attention(nn.Module):
                 f" {self.num_heads})."
             )
 
-        self.maybe_rotary = RotaryEmbedding(config.head_dim) if config.rotary else lambda q, k: (q, k)
+        self.maybe_rotary = RotaryEmbedding(config) if config.rotary else lambda q, k: (q, k)
 
         # Layer-wise attention scaling
         self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
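
For reference, a minimal sketch of the construction path after this change: RotaryEmbedding now takes the model config instead of an explicit head_dim, so Attention can pass config straight through. The SimpleNamespace below is a hypothetical stand-in for the real model config, assuming only the head_dim, use_cache and rotary attributes that the diff reads; the actual config carries more fields.

# Sketch only: a trimmed RotaryEmbedding matching the lines touched by the diff,
# driven by a stand-in config rather than the real model config.
from types import SimpleNamespace

import torch


class RotaryEmbedding(torch.nn.Module):
    def __init__(self, config, base=10000):
        # head_dim and use_cache now come from the config instead of
        # being passed as separate constructor arguments.
        head_dim = config.head_dim
        self.use_cache = config.use_cache
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)


# Hypothetical stand-in config with illustrative values.
config = SimpleNamespace(head_dim=64, use_cache=True, rotary=True)

# Mirrors the updated call site in Attention.__init__().
maybe_rotary = RotaryEmbedding(config) if config.rotary else (lambda q, k: (q, k))
print(maybe_rotary.inv_freq.shape)  # torch.Size([32])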