pseudotensor committed
Commit 11599d0
1 Parent(s): 636ee83

Update modelling_RW.py


Better version: config.use_cache = False can now be set on the top-level config during model load and is propagated down to the bottom level. For https://github.com/h2oai/h2ogpt/pull/297
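For context, a minimal usage sketch of what this enables (not part of the commit): use_cache is set on the top-level config before loading, and the constructor change to RotaryEmbedding(config) lets that setting reach the lowest level. The repo id below is a placeholder, not necessarily the checkpoint this file ships with.

from transformers import AutoConfig, AutoModelForCausalLM

model_id = "h2oai/h2ogpt-oasst1-falcon-40b"  # placeholder RW-architecture checkpoint
config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
config.use_cache = False  # top-level setting, now propagated down to RotaryEmbedding

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    config=config,
    trust_remote_code=True,
)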

  Files changed (1)
  1. modelling_RW.py +48 -48
modelling_RW.py CHANGED
@@ -52,10 +52,11 @@ class RotaryEmbedding(torch.nn.Module):
 
     def __init__(
         self,
-        head_dim: int,
+        config,
         base=10000,
-        use_cache=False,
     ):
+        head_dim = config.head_dim
+        self.use_cache = config.use_cache
         super().__init__()
         inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
         self.register_buffer("inv_freq", inv_freq, persistent=False)
@@ -64,7 +65,6 @@ class RotaryEmbedding(torch.nn.Module):
         self.batch_size_cached = None
         self.cos_cached: torch.Tensor | None = None
         self.sin_cached: torch.Tensor | None = None
-        self.use_cache = use_cache
 
     def cos_sin(
         self,
@@ -107,7 +107,10 @@ class RotaryEmbedding(torch.nn.Module):
     def forward(self, q, k):
         batch, seq_len, head_dim = q.shape
         cos, sin = self.cos_sin(seq_len, q.device, q.dtype)
-        return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
+        try:
+            return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
+        except Exception as e:
+            raise
 
 
 def _make_causal_mask(
@@ -184,7 +187,7 @@ class Attention(nn.Module):
             f" {self.num_heads})."
         )
 
-        self.maybe_rotary = RotaryEmbedding(config.head_dim) if config.rotary else lambda q, k: (q, k)
+        self.maybe_rotary = RotaryEmbedding(config) if config.rotary else lambda q, k: (q, k)
 
         # Layer-wise attention scaling
         self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
@@ -192,44 +195,34 @@ class Attention(nn.Module):
 
         self.query_key_value = Linear(
             self.hidden_size,
-            (config.n_head_kv * 2 + config.n_head) * self.head_dim,
+            3 * self.hidden_size if not config.multi_query else (self.hidden_size + 2 * self.head_dim),
             bias=config.bias,
         )
+        self.multi_query = config.multi_query
         self.dense = Linear(self.hidden_size, self.hidden_size, bias=config.bias)
         self.attention_dropout = nn.Dropout(config.attention_dropout)
-        self.num_kv = config.n_head_kv
+        self.num_kv = config.n_head if not self.multi_query else 1
 
     def _split_heads(self, fused_qkv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """
-        Split the last dimension into (num_heads, head_dim), results share same memory
+        Split the last dimension into (num_heads, head_dim) without making any copies, results share same memory
         storage as `fused_qkv`
 
         Args:
            fused_qkv (`torch.tensor`, *required*): [batch_size, seq_length, num_heads * 3 * head_dim]
 
         Returns:
-            query: [batch_size, seq_length, num_heads, head_dim]
-            key: [batch_size, seq_length, num_heads, head_dim]
+            query: [batch_size, seq_length, num_heads, head_dim] key: [batch_size, seq_length, num_heads, head_dim]
             value: [batch_size, seq_length, num_heads, head_dim]
         """
-        batch, seq_len, _ = fused_qkv.shape
-        qkv = fused_qkv.view(batch, seq_len, -1, self.num_heads // self.num_kv + 2, 64)
-        q = qkv[:, :, :, :-2]
-        k = qkv[:, :, :, [-2]]
-        v = qkv[:, :, :, [-1]]
-        k = torch.broadcast_to(k, q.shape)
-        v = torch.broadcast_to(v, q.shape)
-
-        q, k, v = [
-            rearrange(
-                x,
-                "batch seq_len group num_heads head_dim ->\
-                batch seq_len (group num_heads) head_dim",
-                head_dim=self.head_dim,
-            )
-            for x in [q, k, v]
-        ]
-        return q, k, v
+        if not self.multi_query:
+            batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
+            fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads, 3, self.head_dim)
+            return fused_qkv[..., 0, :], fused_qkv[..., 1, :], fused_qkv[..., 2, :]
+        else:
+            batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
+            fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads + 2, self.head_dim)
+            return fused_qkv[..., :-2, :], fused_qkv[..., [-2], :], fused_qkv[..., [-1], :]
 
     def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
         """
@@ -275,11 +268,11 @@ class Attention(nn.Module):
 
         query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)
         key_layer = key_layer.transpose(1, 2).reshape(
-            batch_size * self.num_heads,
+            batch_size * self.num_kv,
             q_length,
             self.head_dim,
         )
-        value_layer = value_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)
+        value_layer = value_layer.transpose(1, 2).reshape(batch_size * self.num_kv, q_length, self.head_dim)
 
         query_layer, key_layer = self.maybe_rotary(query_layer, key_layer)
 
@@ -300,12 +293,15 @@ class Attention(nn.Module):
 
         if alibi is None:
             query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
-            key_layer_ = key_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
-            value_layer_ = value_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
+            key_layer_ = key_layer.reshape(batch_size, self.num_kv, -1, self.head_dim)
+            value_layer_ = value_layer.reshape(batch_size, self.num_kv, -1, self.head_dim)
 
-            attn_output = F.scaled_dot_product_attention(
-                query_layer_, key_layer_, value_layer_, None, 0.0, is_causal=True
-            )
+            try:
+                attn_output = F.scaled_dot_product_attention(
+                    query_layer_, key_layer_, value_layer_, None, 0.0, is_causal=True
+                )
+            except Exception as e:
+                raise
 
             x = attn_output.view(batch_size, self.num_heads, q_length, self.head_dim)
             x = x.permute(0, 2, 1, 3)
@@ -330,8 +326,7 @@ class Attention(nn.Module):
         attention_scores = attention_scores.to(torch.float32)
         # attn_weights = torch.masked_fill(attention_scores, attention_mask, torch.finfo(attention_scores.dtype).min)
         attention_probs = F.softmax(
-            (attention_scores + alibi.view(batch_size, self.num_heads, 1, -1)) * self.inv_norm_factor
-            + attention_mask_float,
+            (attention_scores + alibi.view(batch_size, self.num_heads, 1, -1)) * self.inv_norm_factor + attention_mask_float,
             dim=-1,
             dtype=hidden_states.dtype,
         )
@@ -380,12 +375,14 @@ class DecoderLayer(nn.Module):
         super().__init__()
         hidden_size = config.hidden_size
 
-        self.ln_attn = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
-        self.ln_mlp = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
-
+        self.input_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
         self.num_heads = config.n_head
         self.self_attention = Attention(config)
 
+        if not config.parallel_attn:
+            # unused if parallel attn
+            self.post_attention_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+
         self.mlp = MLP(config)
 
         self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm
@@ -404,14 +401,12 @@ class DecoderLayer(nn.Module):
         output_attentions: bool = False,
     ):
 
-        ln_attn = self.ln_attn(hidden_states)
-        ln_mlp = self.ln_mlp(hidden_states)
-
+        layernorm_output = self.input_layernorm(hidden_states)
         residual = hidden_states
 
         # Self attention.
         attn_outputs = self.self_attention(
-            ln_attn,
+            layernorm_output,
            layer_past=layer_past,
            attention_mask=attention_mask,
            alibi=alibi,
@@ -422,14 +417,19 @@ class DecoderLayer(nn.Module):
 
        attention_output = attn_outputs[0]
 
+        if not self.config.parallel_attn:
+            residual = dropout_add(attention_output, residual, self.config.attention_dropout, training=self.training)
+            layernorm_output = self.post_attention_layernorm(residual)
+
        outputs = attn_outputs[1:]
 
        # MLP.
-        mlp_output = self.mlp(ln_mlp)
+        mlp_output = self.mlp(layernorm_output)
 
-        output = dropout_add(
-            mlp_output + attention_output, residual, self.config.hidden_dropout, training=self.training
-        )
+        if self.config.parallel_attn:
+            mlp_output += attention_output
+
+        output = dropout_add(mlp_output, residual, self.config.hidden_dropout, training=self.training)
 
        if use_cache:
            outputs = (output,) + outputs
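
As a reading aid for the new multi-query _split_heads branch above, a small standalone shape sketch (toy dimensions, not code from the file): the fused projection yields num_heads query heads plus one shared key head and one shared value head.

import torch

batch_size, seq_length, num_heads, head_dim = 2, 5, 4, 8
# fused projection width is (num_heads + 2) * head_dim in the multi-query case
fused_qkv = torch.randn(batch_size, seq_length, (num_heads + 2) * head_dim)

fused_qkv = fused_qkv.view(batch_size, seq_length, num_heads + 2, head_dim)
query = fused_qkv[..., :-2, :]   # torch.Size([2, 5, 4, 8]) -> all query heads
key   = fused_qkv[..., [-2], :]  # torch.Size([2, 5, 1, 8]) -> single shared key head
value = fused_qkv[..., [-1], :]  # torch.Size([2, 5, 1, 8]) -> single shared value head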