LeroyDyer committed
Commit: a88942a
1 Parent(s): 1f1e352

Update modeling_mistral.py

Files changed (1)
  1. modeling_mistral.py +171 -2
modeling_mistral.py CHANGED
@@ -166,7 +166,6 @@ class MistralRMSNorm(nn.Module):
         hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
         return hidden_states.to(input_dtype) * self.weight.to(hidden_states.device)
 
-
 # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Mistral
 class MistralRotaryEmbedding(nn.Module):
     def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
@@ -187,7 +186,7 @@ class MistralRotaryEmbedding(nn.Module):
         self.max_seq_len_cached = seq_len
         t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
 
-        freqs = torch.outer(t, self.inv_freq)
+        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1)
         self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
@@ -203,6 +202,176 @@ class MistralRotaryEmbedding(nn.Module):
             self.sin_cached[:seq_len].to(dtype=x.dtype),
         )
 
+class MistralLinearScalingRotaryEmbedding(MistralRotaryEmbedding):
+    """MistralRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
+
+    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
+        self.scaling_factor = scaling_factor
+        super().__init__(dim, max_position_embeddings, base, device)
+
+    def _set_cos_sin_cache(self, seq_len, device, dtype):
+        self.max_seq_len_cached = seq_len
+        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
+        t = t / self.scaling_factor
+
+        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+        # Different from paper, but it uses a different permutation in order to obtain the same calculation
+        emb = torch.cat((freqs, freqs), dim=-1)
+        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
+        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
+
+
+class MistralDynamicNTKScalingRotaryEmbedding(MistralRotaryEmbedding):
+    """MistralRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
+
+    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
+        self.scaling_factor = scaling_factor
+        super().__init__(dim, max_position_embeddings, base, device)
+
+    def _set_cos_sin_cache(self, seq_len, device, dtype):
+        self.max_seq_len_cached = seq_len
+
+        if seq_len > self.max_position_embeddings:
+            base = self.base * (
+                (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
+            ) ** (self.dim / (self.dim - 2))
+            inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
+            self.register_buffer("inv_freq", inv_freq, persistent=False)
+
+        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
+
+        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+        # Different from paper, but it uses a different permutation in order to obtain the same calculation
+        emb = torch.cat((freqs, freqs), dim=-1)
+        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
+        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
+
+
+class MistralYaRNScaledRotaryEmbedding(torch.nn.Module):
+    """MistralRotaryEmbedding extended with YaRN. See: https://arxiv.org/abs/2309.00071"""
+
+    def __init__(self, dim, max_position_embeddings=2048, base=10000, scale=1, original_max_position_embeddings=2048,
+                 extrapolation_factor=1, attn_factor=1, beta_fast=128, beta_slow=2, finetuned=False, device=None):
+        super().__init__()
+
+        self.dim = dim
+        self.max_position_embeddings = max_position_embeddings
+        self.base = base
+        self.scale = scale
+        self.original_max_position_embeddings = original_max_position_embeddings
+        self.extrapolation_factor = extrapolation_factor
+        self.attn_factor = attn_factor
+        self.beta_fast = beta_fast
+        self.beta_slow = beta_slow
+
+        self.yarn(device)
+
+        # Build here to make `torch.jit.trace` work.
+        self.max_seq_len_cached = max_position_embeddings
+        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
+        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+        # Different from paper, but it uses a different permutation in order to obtain the same calculation
+        emb = torch.cat((freqs, freqs), dim=-1)
+        dtype = torch.get_default_dtype()
+
+        self.register_buffer("cos_cached", (emb.cos() * self.mscale).to(dtype), persistent=False)
+        self.register_buffer("sin_cached", (emb.sin() * self.mscale).to(dtype), persistent=False)
+
+    def forward(self, x, seq_len=None):
+        # x: [bs, num_attention_heads, seq_len, head_size]
+        # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
+        if seq_len > self.max_seq_len_cached:
+            self.max_seq_len_cached = seq_len
+
+            t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
+            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+            # Different from paper, but it uses a different permutation in order to obtain the same calculation
+            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
+
+            self.register_buffer("cos_cached", (emb.cos() * self.mscale).to(x.dtype), persistent=False)
+            self.register_buffer("sin_cached", (emb.sin() * self.mscale).to(x.dtype), persistent=False)
+        return (
+            self.cos_cached[:seq_len].to(dtype=x.dtype),
+            self.sin_cached[:seq_len].to(dtype=x.dtype),
+        )
+
+    def yarn(self, device):
+        pos_freqs = self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
+        inv_freq_extrapolation = 1.0 / pos_freqs
+        inv_freq_interpolation = 1.0 / (self.scale * pos_freqs)
+
+        low, high = _yarn_find_correction_range(self.beta_fast, self.beta_slow, self.dim, self.base, self.original_max_position_embeddings)
+        inv_freq_mask = (1 - _yarn_linear_ramp_mask(low, high, self.dim // 2).float().to(device)) * self.extrapolation_factor  # Get n-d rotational scaling corrected for extrapolation
+        inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
+
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+        self.mscale = float(_yarn_get_mscale(self.scale) * self.attn_factor)  # Get n-d magnitude scaling corrected for interpolation
+
+
+class MistralDynamicYaRNScaledRotaryEmbedding(torch.nn.Module):
+    """MistralRotaryEmbedding extended with Dynamic YaRN. See: https://arxiv.org/abs/2309.00071"""
+
+    def __init__(self, dim, max_position_embeddings=2048, base=10000, original_max_position_embeddings=2048,
+                 extrapolation_factor=1, attn_factor=1, beta_fast=128, beta_slow=2, finetuned=False, device=None):
+        super().__init__()
+
+        self.dim = dim
+        self.max_position_embeddings = max_position_embeddings
+        self.base = base
+        self.original_max_position_embeddings = original_max_position_embeddings
+        self.extrapolation_factor = extrapolation_factor
+        self.attn_factor = attn_factor
+        self.beta_fast = beta_fast
+        self.beta_slow = beta_slow
+
+        if finetuned:
+            self.yarn(self.max_position_embeddings / self.original_max_position_embeddings, device)
+        else:
+            inv_freq = 1.0 / \
+                (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
+            self.register_buffer("inv_freq", inv_freq, persistent=False)
+            self.mscale = 1
+
+        # Build here to make `torch.jit.trace` work.
+        self.max_seq_len_cached = max_position_embeddings
+        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=torch.float32)
+        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+        # Different from paper, but it uses a different permutation in order to obtain the same calculation
+        emb = torch.cat((freqs, freqs), dim=-1)
+        dtype = torch.get_default_dtype()
+
+        self.register_buffer("cos_cached", (emb.cos() * self.mscale).to(dtype), persistent=False)
+        self.register_buffer("sin_cached", (emb.sin() * self.mscale).to(dtype), persistent=False)
+
+    def forward(self, x, seq_len=None):
+        # x: [bs, num_attention_heads, seq_len, head_size]
+        # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
+        if seq_len > self.max_seq_len_cached:
+            self.max_seq_len_cached = seq_len
+
+            self.yarn(seq_len / self.max_position_embeddings, x.device)
+
+            t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
+            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+            # Different from paper, but it uses a different permutation in order to obtain the same calculation
+            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
+
+            self.register_buffer("cos_cached", (emb.cos() * self.mscale).to(x.dtype), persistent=False)
+            self.register_buffer("sin_cached", (emb.sin() * self.mscale).to(x.dtype), persistent=False)
+        return (
+            self.cos_cached[:seq_len].to(dtype=x.dtype),
+            self.sin_cached[:seq_len].to(dtype=x.dtype),
+        )
+
+    def yarn(self, scale, device):
+        pos_freqs = self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
+        inv_freq_extrapolation = 1.0 / pos_freqs
+        inv_freq_interpolation = 1.0 / (scale * pos_freqs)
+
+        low, high = _yarn_find_correction_range(self.beta_fast, self.beta_slow, self.dim, self.base, self.original_max_position_embeddings)
+        inv_freq_mask = (1 - _yarn_linear_ramp_mask(low, high, self.dim // 2).float().to(device)) * self.extrapolation_factor  # Get n-d rotational scaling corrected for extrapolation
+        inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
+
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+        self.mscale = float(_yarn_get_mscale(scale) * self.attn_factor)  # Get n-d magnitude scaling corrected for interpolation
 
 # Copied from transformers.models.llama.modeling_llama.rotate_half
 def rotate_half(x):
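
The YaRN classes in this diff call `_yarn_find_correction_range`, `_yarn_linear_ramp_mask`, and `_yarn_get_mscale`, which are not part of this hunk and must already be defined elsewhere in modeling_mistral.py. For reference only, a minimal sketch of what these helpers are commonly defined as, following the YaRN paper's reference implementation (https://arxiv.org/abs/2309.00071); the exact definitions in this repository may differ.

import math
import torch


def _yarn_find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048):
    # Inverse dimension formula: which rotary dimension completes `num_rotations`
    # full rotations over the original context length.
    return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))


def _yarn_find_correction_range(low_rot, high_rot, dim, base=10000, max_position_embeddings=2048):
    # Dimension range over which interpolation and extrapolation are blended.
    low = math.floor(_yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings))
    high = math.ceil(_yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings))
    return max(low, 0), min(high, dim - 1)  # clamp to valid dimensions


def _yarn_linear_ramp_mask(low, high, dim):
    # Linear ramp from 0 to 1 between `low` and `high`, used to mix the two inv_freq sets.
    if low == high:
        high += 0.001  # avoid division by zero
    ramp = (torch.arange(dim, dtype=torch.float32) - low) / (high - low)
    return torch.clamp(ramp, 0, 1)


def _yarn_get_mscale(scale=1.0):
    # Attention magnitude correction applied to the cos/sin caches.
    if scale <= 1:
        return 1.0
    return 0.1 * math.log(scale) + 1.0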
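
The hunk context shows that `MistralRotaryEmbedding.forward(x, seq_len)` returns the `(cos_cached, sin_cached)` pair, and the new subclasses keep that interface. Assuming the base class behaves like the standard transformers rotary embedding (cache built in `__init__`, rebuilt via `_set_cos_sin_cache` when a longer sequence arrives), a hypothetical usage sketch with purely illustrative sizes:

import torch

# Illustrative values only; not taken from the commit.
head_dim = 128
rope = MistralDynamicNTKScalingRotaryEmbedding(
    dim=head_dim,
    max_position_embeddings=4096,
    base=10000,
    scaling_factor=2.0,
)
q = torch.randn(1, 32, 8192, head_dim)  # [bs, num_heads, seq_len, head_dim]
cos, sin = rope(q, seq_len=8192)        # a seq_len beyond 4096 would exercise the NTK re-scaling branch
print(cos.shape, sin.shape)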