Update modeling_mistral.py
modeling_mistral.py  CHANGED  (+171 -2)
@@ -166,7 +166,6 @@ class MistralRMSNorm(nn.Module):
         hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
         return hidden_states.to(input_dtype) * self.weight.to(hidden_states.device)
 
-
 # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Mistral
 class MistralRotaryEmbedding(nn.Module):
     def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
@@ -187,7 +186,7 @@ class MistralRotaryEmbedding(nn.Module):
         self.max_seq_len_cached = seq_len
         t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
 
-        freqs = torch.
+        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1)
         self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
@@ -203,6 +202,176 @@ class MistralRotaryEmbedding(nn.Module):
             self.sin_cached[:seq_len].to(dtype=x.dtype),
         )
 
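Note (not part of the diff): every `_set_cos_sin_cache` in this patch builds its angle table with the same einsum, which is simply the outer product of the position vector and the inverse frequencies. A minimal standalone check, using only torch:

# Sanity sketch: the einsum used throughout this patch is an outer product.
import torch

t = torch.arange(8, dtype=torch.float32)                           # positions 0..7
inv_freq = 1.0 / (10000 ** (torch.arange(0, 64, 2).float() / 64))  # head_dim = 64 (illustrative)
freqs = torch.einsum("i,j->ij", t, inv_freq)                       # shape (8, 32)
assert torch.allclose(freqs, torch.outer(t, inv_freq))             # same as an outer product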
+class MistralLinearScalingRotaryEmbedding(MistralRotaryEmbedding):
+    """MistralRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
+
+    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
+        self.scaling_factor = scaling_factor
+        super().__init__(dim, max_position_embeddings, base, device)
+
+    def _set_cos_sin_cache(self, seq_len, device, dtype):
+        self.max_seq_len_cached = seq_len
+        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
+        t = t / self.scaling_factor
+
+        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+        # Different from paper, but it uses a different permutation in order to obtain the same calculation
+        emb = torch.cat((freqs, freqs), dim=-1)
+        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
+        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
+
+
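Note (not part of the diff): linear scaling only divides the positions by `scaling_factor`, so the scaled cache at position p should equal the unscaled cache at position p / scaling_factor. A quick sanity sketch under that assumption, reusing the classes above and assuming the base class builds its cache in `__init__` and indexes the cached tables by position in their first dimension, as `forward()` suggests:

import torch

base_rope = MistralRotaryEmbedding(dim=64, max_position_embeddings=32)
lin_rope = MistralLinearScalingRotaryEmbedding(dim=64, max_position_embeddings=32, scaling_factor=4.0)
x = torch.zeros(1, 1, 32, 64)            # [bs, num_heads, seq_len, head_dim], values unused
cos_base, _ = base_rope(x, seq_len=32)
cos_lin, _ = lin_rope(x, seq_len=32)
# Position 8 with scaling_factor=4 uses the same angles as unscaled position 2.
assert torch.allclose(cos_lin[8], cos_base[2])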
+class MistralDynamicNTKScalingRotaryEmbedding(MistralRotaryEmbedding):
+    """MistralRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
+
+    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
+        self.scaling_factor = scaling_factor
+        super().__init__(dim, max_position_embeddings, base, device)
+
+    def _set_cos_sin_cache(self, seq_len, device, dtype):
+        self.max_seq_len_cached = seq_len
+
+        if seq_len > self.max_position_embeddings:
+            base = self.base * (
+                (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
+            ) ** (self.dim / (self.dim - 2))
+            inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
+            self.register_buffer("inv_freq", inv_freq, persistent=False)
+
+        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
+
+        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+        # Different from paper, but it uses a different permutation in order to obtain the same calculation
+        emb = torch.cat((freqs, freqs), dim=-1)
+        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
+        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
+
+
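Note (not part of the diff): with dynamic NTK scaling, the rotary base above is recomputed once the sequence outgrows `max_position_embeddings`. A small numeric illustration with made-up values (dim=128, base=10000, scaling_factor=1.0, max_position_embeddings=4096):

# How the dynamic-NTK base adjustment grows with sequence length (illustrative values).
dim, base, scaling_factor, max_pos = 128, 10000.0, 1.0, 4096
for seq_len in (4096, 8192, 16384):
    if seq_len > max_pos:
        new_base = base * ((scaling_factor * seq_len / max_pos) - (scaling_factor - 1)) ** (dim / (dim - 2))
    else:
        new_base = base
    print(seq_len, round(new_base))  # approx: 4096 -> 10000, 8192 -> ~20221, 16384 -> ~40890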
+class MistralYaRNScaledRotaryEmbedding(torch.nn.Module):
+    """MistralRotaryEmbedding extended with YaRN. See: https://arxiv.org/abs/2309.00071"""
+    def __init__(self, dim, max_position_embeddings=2048, base=10000, scale=1, original_max_position_embeddings=2048,
+                 extrapolation_factor=1, attn_factor=1, beta_fast=128, beta_slow=2, finetuned=False, device=None):
+        super().__init__()
+
+        self.dim = dim
+        self.max_position_embeddings = max_position_embeddings
+        self.base = base
+        self.scale = scale
+        self.original_max_position_embeddings = original_max_position_embeddings
+        self.extrapolation_factor = extrapolation_factor
+        self.attn_factor = attn_factor
+        self.beta_fast = beta_fast
+        self.beta_slow = beta_slow
+
+        self.yarn(device)
+
+        # Build here to make `torch.jit.trace` work.
+        self.max_seq_len_cached = max_position_embeddings
+        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
+        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+        # Different from paper, but it uses a different permutation in order to obtain the same calculation
+        emb = torch.cat((freqs, freqs), dim=-1)
+        dtype = torch.get_default_dtype()
+
+        self.register_buffer("cos_cached", (emb.cos() * self.mscale).to(dtype), persistent=False)
+        self.register_buffer("sin_cached", (emb.sin() * self.mscale).to(dtype), persistent=False)
+
+    def forward(self, x, seq_len=None):
+        # x: [bs, num_attention_heads, seq_len, head_size]
+        # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
+        if seq_len > self.max_seq_len_cached:
+            self.max_seq_len_cached = seq_len
+
+            t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
+            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+            # Different from paper, but it uses a different permutation in order to obtain the same calculation
+            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
+
+            self.register_buffer("cos_cached", (emb.cos() * self.mscale).to(x.dtype), persistent=False)
+            self.register_buffer("sin_cached", (emb.sin() * self.mscale).to(x.dtype), persistent=False)
+        return (
+            self.cos_cached[:seq_len].to(dtype=x.dtype),
+            self.sin_cached[:seq_len].to(dtype=x.dtype),
+        )
+
+    def yarn(self, device):
+        pos_freqs = self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
+        inv_freq_extrapolation = 1.0 / pos_freqs
+        inv_freq_interpolation = 1.0 / (self.scale * pos_freqs)
+
+        low, high = _yarn_find_correction_range(self.beta_fast, self.beta_slow, self.dim, self.base, self.original_max_position_embeddings)
+        inv_freq_mask = (1 - _yarn_linear_ramp_mask(low, high, self.dim // 2).float().to(device)) * self.extrapolation_factor  # Get n-d rotational scaling corrected for extrapolation
+        inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
+
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+        self.mscale = float(_yarn_get_mscale(self.scale) * self.attn_factor)  # Get n-d magnitude scaling corrected for interpolation
+
+
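Note (not part of the diff): `yarn()` relies on `_yarn_find_correction_range`, `_yarn_linear_ramp_mask`, and `_yarn_get_mscale`, which this hunk does not define; they are presumably added elsewhere in the patch. A sketch consistent with the YaRN reference implementation (an assumption, not the author's code) would be:

import math
import torch

# Assumed helper sketches (not shown in this hunk), modeled on the YaRN reference code.
def _yarn_find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048):
    # Dimension index whose wavelength completes `num_rotations` over the original context.
    return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))

def _yarn_find_correction_range(low_rot, high_rot, dim, base=10000, max_position_embeddings=2048):
    low = math.floor(_yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings))
    high = math.ceil(_yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings))
    return max(low, 0), min(high, dim - 1)  # clamp to valid dimension indices

def _yarn_linear_ramp_mask(low, high, dim):
    if low == high:
        high += 0.001  # avoid division by zero
    ramp = (torch.arange(dim, dtype=torch.float32) - low) / (high - low)
    return torch.clamp(ramp, 0, 1)

def _yarn_get_mscale(scale=1):
    # Attention magnitude correction; 1.0 when no scaling is applied.
    if scale <= 1:
        return 1.0
    return 0.1 * math.log(scale) + 1.0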
+class MistralDynamicYaRNScaledRotaryEmbedding(torch.nn.Module):
+    """MistralRotaryEmbedding extended with Dynamic YaRN. See: https://arxiv.org/abs/2309.00071"""
+    def __init__(self, dim, max_position_embeddings=2048, base=10000, original_max_position_embeddings=2048,
+                 extrapolation_factor=1, attn_factor=1, beta_fast=128, beta_slow=2, finetuned=False, device=None):
+        super().__init__()
+
+        self.dim = dim
+        self.max_position_embeddings = max_position_embeddings
+        self.base = base
+        self.original_max_position_embeddings = original_max_position_embeddings
+        self.extrapolation_factor = extrapolation_factor
+        self.attn_factor = attn_factor
+        self.beta_fast = beta_fast
+        self.beta_slow = beta_slow
+
+        if finetuned:
+            self.yarn(self.max_position_embeddings / self.original_max_position_embeddings, device)
+        else:
+            inv_freq = 1.0 / \
+                (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
+            self.register_buffer("inv_freq", inv_freq, persistent=False)
+            self.mscale = 1
+
+        # Build here to make `torch.jit.trace` work.
+        self.max_seq_len_cached = max_position_embeddings
+        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=torch.float32)
+        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+        # Different from paper, but it uses a different permutation in order to obtain the same calculation
+        emb = torch.cat((freqs, freqs), dim=-1)
+        dtype = torch.get_default_dtype()
+
+        self.register_buffer("cos_cached", (emb.cos() * self.mscale).to(dtype), persistent=False)
+        self.register_buffer("sin_cached", (emb.sin() * self.mscale).to(dtype), persistent=False)
+
+    def forward(self, x, seq_len=None):
+        # x: [bs, num_attention_heads, seq_len, head_size]
+        # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
+        if seq_len > self.max_seq_len_cached:
+            self.max_seq_len_cached = seq_len
+
+            self.yarn(seq_len / self.max_position_embeddings, x.device)
+
+            t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
+            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+            # Different from paper, but it uses a different permutation in order to obtain the same calculation
+            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
+
+            self.register_buffer("cos_cached", (emb.cos() * self.mscale).to(x.dtype), persistent=False)
+            self.register_buffer("sin_cached", (emb.sin() * self.mscale).to(x.dtype), persistent=False)
+        return (
+            self.cos_cached[:seq_len].to(dtype=x.dtype),
+            self.sin_cached[:seq_len].to(dtype=x.dtype),
+        )
+
+    def yarn(self, scale, device):
+        pos_freqs = self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
+        inv_freq_extrapolation = 1.0 / pos_freqs
+        inv_freq_interpolation = 1.0 / (scale * pos_freqs)
+
+        low, high = _yarn_find_correction_range(self.beta_fast, self.beta_slow, self.dim, self.base, self.original_max_position_embeddings)
+        inv_freq_mask = (1 - _yarn_linear_ramp_mask(low, high, self.dim // 2).float().to(device)) * self.extrapolation_factor  # Get n-d rotational scaling corrected for extrapolation
+        inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
+
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+        self.mscale = float(_yarn_get_mscale(scale) * self.attn_factor)  # Get n-d magnitude scaling corrected for interpolation
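Note (not part of the diff): a hypothetical standalone exercise of the new embeddings, with illustrative sizes and assuming the `_yarn_*` helpers sketched earlier are available; the real model would pick one of these classes from its rope-scaling configuration, which this hunk does not touch.

import torch

head_dim = 128
rope = MistralYaRNScaledRotaryEmbedding(
    dim=head_dim,
    max_position_embeddings=16384,           # extended context (illustrative)
    original_max_position_embeddings=4096,   # pretraining context (illustrative)
    scale=4.0,                               # 16384 / 4096
)
q = torch.randn(1, 32, 16384, head_dim)      # [bs, num_heads, seq_len, head_dim]
cos, sin = rope(q, seq_len=16384)
print(cos.shape, sin.shape)                  # expected: (16384, 128) each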
 
 # Copied from transformers.models.llama.modeling_llama.rotate_half
 def rotate_half(x):