Updated default config.
- configuration_grok.py +3 -3
- modeling_grok.py +17 -2
configuration_grok.py
@@ -96,8 +96,8 @@ class GrokConfig(PretrainedConfig):
         num_hidden_layers=64,
         num_attention_heads=48,
         num_key_value_heads=8,
-        hidden_act="
-        max_position_embeddings=
+        hidden_act="gelu_new",
+        max_position_embeddings=8192,
         initializer_range=0.02,
         rms_norm_eps=1e-5,
         use_cache=True,
@@ -105,7 +105,7 @@ class GrokConfig(PretrainedConfig):
         bos_token_id=1,
         eos_token_id=2,
         tie_word_embeddings=True,
-        rope_theta=
+        rope_theta=1e4,
         attention_dropout=0.0,
         num_experts_per_tok=2,
         num_local_experts=8,
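For a quick check, here is a minimal sketch of how the new defaults surface once the file is imported. The bare import path is an assumption for illustration; on the Hub the class would normally load via AutoConfig with trust_remote_code=True.

# Minimal sketch, assuming configuration_grok.py sits on the Python path.
from configuration_grok import GrokConfig

config = GrokConfig()
assert config.hidden_act == "gelu_new"
assert config.max_position_embeddings == 8192
assert config.rope_theta == 1e4

# Defaults remain overridable per instance, e.g. for a longer context window:
long_ctx = GrokConfig(max_position_embeddings=16384)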
modeling_grok.py
@@ -338,8 +338,11 @@ class GrokDecoderLayer(nn.Module):
         self.top_k = config.num_experts_per_tok

         self.multi_head_attention = GrokAttention(config, layer_idx)
-
-
+        if self.num_experts > 1:
+            self.router = nn.Linear(self.hidden_size, self.num_experts, dtype=torch.float32, bias=False)
+            self.moe = nn.ModuleList([GrokBlockSparseTop2MLP(config) for _ in range(self.num_experts)])
+        else:
+            self.mlp = GrokBlockSparseTop2MLP(config)

         self.rms_norm = GrokRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.rms_norm_1 = GrokRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -397,6 +400,18 @@ class GrokDecoderLayer(nn.Module):
         residual = hidden_states
         hidden_states = self.rms_norm_2(hidden_states)

+        if self.num_experts <= 1:
+            hidden_states = self.mlp(hidden_states)
+            hidden_states = residual + self.rms_norm_3(hidden_states)
+            outputs = (hidden_states,)
+            if output_attentions:
+                outputs += (self_attn_weights,)
+            if use_cache:
+                outputs += (present_key_value,)
+            if output_router_logits:
+                outputs += (None,)
+            return outputs
+
         batch_size, sequence_length, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
         # router_logits: (batch * sequence_length, n_experts)
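Taken together, the two hunks let GrokDecoderLayer degrade to a plain dense MLP when num_local_experts is 1, while the default of 8 experts keeps the sparse top-2 path. Below is a rough sketch of the resulting feed-forward control flow; the module names (router, moe, mlp, rms_norm_2, rms_norm_3) come from the diff, but the top-2 dispatch body is an assumption modeled on the conventional Mixtral-style pattern, not a copy of this file.

import torch
import torch.nn.functional as F

# Hedged sketch of the feed-forward half of GrokDecoderLayer.forward after
# this change; assumes `self` carries the modules defined in the diff above.
def feed_forward(self, hidden_states, residual):
    hidden_states = self.rms_norm_2(hidden_states)

    if self.num_experts <= 1:          # new dense fallback from the diff
        return residual + self.rms_norm_3(self.mlp(hidden_states))

    batch_size, sequence_length, hidden_dim = hidden_states.shape
    flat = hidden_states.view(-1, hidden_dim)
    # router_logits: (batch * sequence_length, n_experts); the router was
    # constructed with dtype=torch.float32, so cast its input accordingly.
    router_logits = self.router(flat.float())
    routing_weights = F.softmax(router_logits, dim=-1)
    routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
    routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)

    out = torch.zeros_like(flat)
    for expert_idx in range(self.num_experts):
        # tokens (and which of their top-k slots) routed to this expert
        token_idx, slot_idx = torch.where(selected_experts == expert_idx)
        if token_idx.numel() == 0:
            continue
        expert_out = self.moe[expert_idx](flat[token_idx])
        out[token_idx] += routing_weights[token_idx, slot_idx, None].to(out.dtype) * expert_out

    out = out.view(batch_size, sequence_length, hidden_dim)
    return residual + self.rms_norm_3(out)

Pinning the router to float32 in the constructor presumably keeps the routing softmax numerically stable under fp16/bf16 compute; the rest of the layer can stay in the model's compute dtype.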