pseudotensor committed
Commit
cba2f63
1 Parent(s): 11599d0

Update modelling_RW.py

Files changed (1)
  1. modelling_RW.py +49 -49
modelling_RW.py CHANGED
@@ -52,11 +52,10 @@ class RotaryEmbedding(torch.nn.Module):
 
     def __init__(
         self,
-        config,
+        head_dim: int,
         base=10000,
+        use_cache=False,
     ):
-        head_dim = config.head_dim
-        self.use_cache = config.use_cache
         super().__init__()
         inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
         self.register_buffer("inv_freq", inv_freq, persistent=False)
@@ -65,6 +64,7 @@ class RotaryEmbedding(torch.nn.Module):
         self.batch_size_cached = None
         self.cos_cached: torch.Tensor | None = None
         self.sin_cached: torch.Tensor | None = None
+        self.use_cache = use_cache
 
     def cos_sin(
         self,
@@ -107,10 +107,7 @@ class RotaryEmbedding(torch.nn.Module):
     def forward(self, q, k):
         batch, seq_len, head_dim = q.shape
         cos, sin = self.cos_sin(seq_len, q.device, q.dtype)
-        try:
-            return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
-        except Exception as e:
-            raise
+        return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
 
 
 def _make_causal_mask(
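Note: the rotary math itself is unchanged here; forward now simply returns the rotated tensors instead of wrapping the return in a no-op try/except. A minimal standalone sketch of that return expression, assuming the usual GPT-NeoX rotate_half/cos_sin convention (helper names and sizes below are illustrative, not taken from this file):

import torch

def rotate_half(x):
    # standard GPT-NeoX convention: split the last dim in half and rotate
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

head_dim, seq_len = 64, 5
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
freqs = torch.einsum("i,j->ij", torch.arange(seq_len).float(), inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)[None, :, :]     # [1, seq_len, head_dim]
cos, sin = emb.cos(), emb.sin()

q = torch.randn(2 * 8, seq_len, head_dim)               # [batch * num_heads, seq_len, head_dim]
k = torch.randn_like(q)

# the expression forward now returns directly
q_rot, k_rot = (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
print(q_rot.shape, k_rot.shape)                         # torch.Size([16, 5, 64]) twice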
@@ -187,7 +184,7 @@ class Attention(nn.Module):
                 f" {self.num_heads})."
             )
 
-        self.maybe_rotary = RotaryEmbedding(config) if config.rotary else lambda q, k: (q, k)
+        self.maybe_rotary = RotaryEmbedding(config.head_dim) if config.rotary else lambda q, k: (q, k)
 
         # Layer-wise attention scaling
         self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
@@ -195,34 +192,44 @@ class Attention(nn.Module):
 
         self.query_key_value = Linear(
             self.hidden_size,
-            3 * self.hidden_size if not config.multi_query else (self.hidden_size + 2 * self.head_dim),
+            (config.n_head_kv * 2 + config.n_head) * self.head_dim,
             bias=config.bias,
         )
-        self.multi_query = config.multi_query
         self.dense = Linear(self.hidden_size, self.hidden_size, bias=config.bias)
         self.attention_dropout = nn.Dropout(config.attention_dropout)
-        self.num_kv = config.n_head if not self.multi_query else 1
+        self.num_kv = config.n_head_kv
 
     def _split_heads(self, fused_qkv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """
-        Split the last dimension into (num_heads, head_dim) without making any copies, results share same memory
+        Split the last dimension into (num_heads, head_dim), results share same memory
         storage as `fused_qkv`
 
         Args:
            fused_qkv (`torch.tensor`, *required*): [batch_size, seq_length, num_heads * 3 * head_dim]
 
        Returns:
-            query: [batch_size, seq_length, num_heads, head_dim] key: [batch_size, seq_length, num_heads, head_dim]
+            query: [batch_size, seq_length, num_heads, head_dim]
+            key: [batch_size, seq_length, num_heads, head_dim]
            value: [batch_size, seq_length, num_heads, head_dim]
        """
-        if not self.multi_query:
-            batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
-            fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads, 3, self.head_dim)
-            return fused_qkv[..., 0, :], fused_qkv[..., 1, :], fused_qkv[..., 2, :]
-        else:
-            batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
-            fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads + 2, self.head_dim)
-            return fused_qkv[..., :-2, :], fused_qkv[..., [-2], :], fused_qkv[..., [-1], :]
+        batch, seq_len, _ = fused_qkv.shape
+        qkv = fused_qkv.view(batch, seq_len, -1, self.num_heads // self.num_kv + 2, 64)
+        q = qkv[:, :, :, :-2]
+        k = qkv[:, :, :, [-2]]
+        v = qkv[:, :, :, [-1]]
+        k = torch.broadcast_to(k, q.shape)
+        v = torch.broadcast_to(v, q.shape)
+
+        q, k, v = [
+            rearrange(
+                x,
+                "batch seq_len group num_heads head_dim ->\
+                batch seq_len (group num_heads) head_dim",
+                head_dim=self.head_dim,
+            )
+            for x in [q, k, v]
+        ]
+        return q, k, v
 
     def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
         """
@@ -268,11 +275,11 @@ class Attention(nn.Module):
 
         query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)
         key_layer = key_layer.transpose(1, 2).reshape(
-            batch_size * self.num_kv,
+            batch_size * self.num_heads,
             q_length,
             self.head_dim,
         )
-        value_layer = value_layer.transpose(1, 2).reshape(batch_size * self.num_kv, q_length, self.head_dim)
+        value_layer = value_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)
 
         query_layer, key_layer = self.maybe_rotary(query_layer, key_layer)
 
@@ -293,15 +300,12 @@ class Attention(nn.Module):
 
         if alibi is None:
             query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
-            key_layer_ = key_layer.reshape(batch_size, self.num_kv, -1, self.head_dim)
-            value_layer_ = value_layer.reshape(batch_size, self.num_kv, -1, self.head_dim)
+            key_layer_ = key_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
+            value_layer_ = value_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
 
-            try:
-                attn_output = F.scaled_dot_product_attention(
-                    query_layer_, key_layer_, value_layer_, None, 0.0, is_causal=True
-                )
-            except Exception as e:
-                raise
+            attn_output = F.scaled_dot_product_attention(
+                query_layer_, key_layer_, value_layer_, None, 0.0, is_causal=True
+            )
 
             x = attn_output.view(batch_size, self.num_heads, q_length, self.head_dim)
             x = x.permute(0, 2, 1, 3)
@@ -326,7 +330,8 @@ class Attention(nn.Module):
             attention_scores = attention_scores.to(torch.float32)
             # attn_weights = torch.masked_fill(attention_scores, attention_mask, torch.finfo(attention_scores.dtype).min)
             attention_probs = F.softmax(
-                (attention_scores + alibi.view(batch_size, self.num_heads, 1, -1)) * self.inv_norm_factor + attention_mask_float,
+                (attention_scores + alibi.view(batch_size, self.num_heads, 1, -1)) * self.inv_norm_factor
+                + attention_mask_float,
                 dim=-1,
                 dtype=hidden_states.dtype,
             )
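Note: only the no-op try/except around the attention call is dropped; the call keeps the same argument pattern (no explicit mask, zero dropout, causal masking). For reference, a self-contained call with illustrative shapes:

import torch
import torch.nn.functional as F

batch_size, num_heads, q_length, head_dim = 2, 8, 5, 64
q = torch.randn(batch_size, num_heads, q_length, head_dim)
k, v = torch.randn_like(q), torch.randn_like(q)

# no explicit attention mask, zero dropout, causal masking handled by the kernel
attn_output = F.scaled_dot_product_attention(q, k, v, None, 0.0, is_causal=True)
print(attn_output.shape)  # torch.Size([2, 8, 5, 64])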
@@ -375,14 +380,12 @@ class DecoderLayer(nn.Module):
         super().__init__()
         hidden_size = config.hidden_size
 
-        self.input_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+        self.ln_attn = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+        self.ln_mlp = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+
         self.num_heads = config.n_head
         self.self_attention = Attention(config)
 
-        if not config.parallel_attn:
-            # unused if parallel attn
-            self.post_attention_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
-
         self.mlp = MLP(config)
 
         self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm
@@ -401,12 +404,14 @@ class DecoderLayer(nn.Module):
         output_attentions: bool = False,
     ):
 
-        layernorm_output = self.input_layernorm(hidden_states)
+        ln_attn = self.ln_attn(hidden_states)
+        ln_mlp = self.ln_mlp(hidden_states)
+
         residual = hidden_states
 
         # Self attention.
         attn_outputs = self.self_attention(
-            layernorm_output,
+            ln_attn,
             layer_past=layer_past,
             attention_mask=attention_mask,
             alibi=alibi,
@@ -417,19 +422,14 @@ class DecoderLayer(nn.Module):
 
         attention_output = attn_outputs[0]
 
-        if not self.config.parallel_attn:
-            residual = dropout_add(attention_output, residual, self.config.attention_dropout, training=self.training)
-            layernorm_output = self.post_attention_layernorm(residual)
-
         outputs = attn_outputs[1:]
 
         # MLP.
-        mlp_output = self.mlp(layernorm_output)
-
-        if self.config.parallel_attn:
-            mlp_output += attention_output
+        mlp_output = self.mlp(ln_mlp)
 
-        output = dropout_add(mlp_output, residual, self.config.hidden_dropout, training=self.training)
+        output = dropout_add(
+            mlp_output + attention_output, residual, self.config.hidden_dropout, training=self.training
+        )
 
         if use_cache:
             outputs = (output,) + outputs
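Note: with ln_attn/ln_mlp feeding both branches from the same hidden_states, the block is now unconditionally the parallel attention+MLP layout, and the sum of the two branches is added back to the residual. A minimal sketch of that dataflow, with stand-in linear layers in place of the real Attention/MLP modules and dropout omitted:

import torch
import torch.nn as nn

class ParallelBlockSketch(nn.Module):
    def __init__(self, hidden_size=32, eps=1e-5):
        super().__init__()
        self.ln_attn = nn.LayerNorm(hidden_size, eps=eps)  # feeds the attention branch
        self.ln_mlp = nn.LayerNorm(hidden_size, eps=eps)   # feeds the MLP branch, same input
        self.attn = nn.Linear(hidden_size, hidden_size)    # stand-in for Attention(config)
        self.mlp = nn.Linear(hidden_size, hidden_size)     # stand-in for MLP(config)

    def forward(self, hidden_states):
        residual = hidden_states
        attention_output = self.attn(self.ln_attn(hidden_states))
        mlp_output = self.mlp(self.ln_mlp(hidden_states))
        # both branches read the same residual stream; their sum is added back to it
        return mlp_output + attention_output + residual

x = torch.randn(2, 5, 32)
print(ParallelBlockSketch()(x).shape)  # torch.Size([2, 5, 32])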
@@ -1120,4 +1120,4 @@ class RWForQuestionAnswering(RWPreTrainedModel):
             end_logits=end_logits,
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
-        )
+        )
 