Update modeling_mistral.py
modeling_mistral.py  (CHANGED, +32 -4)
@@ -475,14 +475,42 @@ class MistralAttention(nn.Module):
         self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
         self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
 
-        self.rotary_emb = MistralRotaryEmbedding(
-            self.head_dim,
-            max_position_embeddings=self.max_position_embeddings,
-            base=self.rope_theta,
+        self._init_rope()
         )
 
     def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
         return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+    def _init_rope(self):
+        if self.config.rope_scaling is None:
+            self.rotary_emb = MistralRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings, base=self.rope_theta)
+        else:
+            scaling_type = self.config.rope_scaling["type"]
+            scaling_factor = self.config.rope_scaling["factor"]
+            if scaling_type == "linear":
+                self.rotary_emb = MistralLinearScalingRotaryEmbedding(
+                    self.head_dim, max_position_embeddings=self.max_position_embeddings,
+                    scaling_factor=scaling_factor, base=self.rope_theta,
+                )
+            elif scaling_type == "dynamic":
+                self.rotary_emb = MistralDynamicNTKScalingRotaryEmbedding(
+                    self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor,
+                    base=self.rope_theta,
+                )
+            elif scaling_type == "yarn":
+                original_max_position_embeddings = self.config.rope_scaling["original_max_position_embeddings"]
+                self.rotary_emb = MistralYaRNScaledRotaryEmbedding(
+                    self.head_dim, max_position_embeddings=self.max_position_embeddings, scale=scaling_factor,
+                    original_max_position_embeddings=original_max_position_embeddings, base=self.rope_theta,
+                )
+            elif scaling_type == "dynamic-yarn":
+                original_max_position_embeddings = self.config.rope_scaling["original_max_position_embeddings"]
+                self.rotary_emb = MistralDynamicYaRNScaledRotaryEmbedding(
+                    self.head_dim, max_position_embeddings=self.max_position_embeddings,
+                    original_max_position_embeddings=original_max_position_embeddings, base=self.rope_theta,
+                )
+            else:
+                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
+
 
     def forward(
         self,
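The new `_init_rope` dispatch is driven entirely by `config.rope_scaling`. As a rough guide, the sketch below lists `rope_scaling` values that would select each branch; the key names (`type`, `factor`, `original_max_position_embeddings`) are the ones the method reads, while the numeric values are purely illustrative.

```python
# Illustrative rope_scaling values for each branch of _init_rope().
# Key names match what the method reads; the numbers are examples only.
rope_scaling_examples = {
    "plain": None,  # -> MistralRotaryEmbedding
    "linear": {"type": "linear", "factor": 4.0},  # -> MistralLinearScalingRotaryEmbedding
    "dynamic": {"type": "dynamic", "factor": 4.0},  # -> MistralDynamicNTKScalingRotaryEmbedding
    "yarn": {  # -> MistralYaRNScaledRotaryEmbedding (factor is passed on as `scale`)
        "type": "yarn",
        "factor": 16.0,
        "original_max_position_embeddings": 8192,
    },
    "dynamic-yarn": {  # -> MistralDynamicYaRNScaledRotaryEmbedding
        "type": "dynamic-yarn",
        "factor": 16.0,  # still required: read before the type dispatch, though this branch does not forward it
        "original_max_position_embeddings": 8192,
    },
}
```

Note that every non-`None` configuration needs both `type` and `factor`, since `factor` is read before the branch on `type`; the two YaRN branches additionally require `original_max_position_embeddings`.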
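For context, here is a hedged sketch of how such a configuration would reach `self.config.rope_scaling` when this file ships as custom modeling code in a model repository. The repo id is a placeholder, and this assumes the repository also provides a configuration class that accepts these extra keys; none of this is taken from the commit itself.

```python
import torch
from transformers import AutoConfig, AutoModelForCausalLM

# Placeholder repo id; substitute the repository that ships this modeling_mistral.py.
model_id = "some-org/some-yarn-mistral-model"

config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
config.rope_scaling = {
    "type": "yarn",
    "factor": 16.0,
    "original_max_position_embeddings": 8192,
}

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    config=config,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,  # required so the hub's modeling_mistral.py (and its _init_rope) is used
)
```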
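The scaled rotary-embedding classes referenced in the diff are defined elsewhere in the file and are not shown here. As a rough sketch of what two of the branches do, assuming they follow the same recipes as the analogous LLaMA embeddings in transformers (an assumption, not something this commit shows): "linear" divides the position indices by the scaling factor before building the sin/cos tables, while "dynamic" grows the RoPE base once the sequence length exceeds the trained window.

```python
import torch


def linear_scaled_positions(seq_len: int, scaling_factor: float) -> torch.Tensor:
    """'linear' (position interpolation): compress positions into the trained range
    by dividing them by scaling_factor before the frequencies are computed."""
    return torch.arange(seq_len, dtype=torch.float32) / scaling_factor


def dynamic_ntk_base(base: float, dim: int, seq_len: int,
                     max_position_embeddings: int, scaling_factor: float) -> float:
    """'dynamic' (dynamic NTK): keep positions as-is but enlarge the RoPE base when
    seq_len exceeds the trained window, using the formula from the LLaMA
    implementation in transformers (assumed, not confirmed, for this file)."""
    if seq_len <= max_position_embeddings:
        return base
    return base * (
        (scaling_factor * seq_len / max_position_embeddings) - (scaling_factor - 1)
    ) ** (dim / (dim - 2))
```

The YaRN variants apply this kind of rescaling per frequency band rather than uniformly, which is why those two branches also need `original_max_position_embeddings`, as the diff shows.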