v2ray committed
Commit 7159da5 (parent: cf0f89d)

Updated default config.

Files changed (2):
  1. configuration_grok.py +3 -3
  2. modeling_grok.py +17 -2
configuration_grok.py CHANGED
@@ -96,8 +96,8 @@ class GrokConfig(PretrainedConfig):
         num_hidden_layers=64,
         num_attention_heads=48,
         num_key_value_heads=8,
-        hidden_act="silu",
-        max_position_embeddings=4096,
+        hidden_act="gelu_new",
+        max_position_embeddings=8192,
         initializer_range=0.02,
         rms_norm_eps=1e-5,
         use_cache=True,
@@ -105,7 +105,7 @@ class GrokConfig(PretrainedConfig):
         bos_token_id=1,
         eos_token_id=2,
         tie_word_embeddings=True,
-        rope_theta=1e5,
+        rope_theta=1e4,
         attention_dropout=0.0,
         num_experts_per_tok=2,
         num_local_experts=8,
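
A minimal sketch (not part of the commit) of what the updated defaults look like when a GrokConfig is instantiated without arguments; it assumes configuration_grok.py is importable from the working directory.

# Hypothetical check of the updated defaults; assumes configuration_grok.py is on the path.
from configuration_grok import GrokConfig

config = GrokConfig()
print(config.hidden_act)               # "gelu_new" (previously "silu")
print(config.max_position_embeddings)  # 8192      (previously 4096)
print(config.rope_theta)               # 10000.0   (previously 1e5)
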
modeling_grok.py CHANGED
@@ -338,8 +338,11 @@ class GrokDecoderLayer(nn.Module):
         self.top_k = config.num_experts_per_tok

         self.multi_head_attention = GrokAttention(config, layer_idx)
-        self.router = nn.Linear(self.hidden_size, self.num_experts, dtype=torch.float32, bias=False)
-        self.moe = nn.ModuleList([GrokBlockSparseTop2MLP(config) for _ in range(self.num_experts)])
+        if self.num_experts > 1:
+            self.router = nn.Linear(self.hidden_size, self.num_experts, dtype=torch.float32, bias=False)
+            self.moe = nn.ModuleList([GrokBlockSparseTop2MLP(config) for _ in range(self.num_experts)])
+        else:
+            self.mlp = GrokBlockSparseTop2MLP(config)

         self.rms_norm = GrokRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.rms_norm_1 = GrokRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -397,6 +400,18 @@ class GrokDecoderLayer(nn.Module):
         residual = hidden_states
         hidden_states = self.rms_norm_2(hidden_states)

+        if self.num_experts <= 1:
+            hidden_states = self.mlp(hidden_states)
+            hidden_states = residual + self.rms_norm_3(hidden_states)
+            outputs = (hidden_states,)
+            if output_attentions:
+                outputs += (self_attn_weights,)
+            if use_cache:
+                outputs += (present_key_value,)
+            if output_router_logits:
+                outputs += (None,)
+            return outputs
+
         batch_size, sequence_length, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
         # router_logits: (batch * sequence_length, n_experts)
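
A hypothetical usage sketch of the new single-expert path (not part of the commit): when there is only one expert, the layer builds a plain self.mlp instead of a router plus expert list, and forward() returns early with None in place of router logits. The GrokDecoderLayer(config, layer_idx) constructor signature and the mapping of num_local_experts onto self.num_experts are assumptions based on the surrounding code.

# Hypothetical sketch; assumes GrokDecoderLayer takes (config, layer_idx) and that
# self.num_experts is derived from config.num_local_experts, as in similar MoE layers.
from configuration_grok import GrokConfig
from modeling_grok import GrokDecoderLayer

dense_config = GrokConfig(num_local_experts=1, num_experts_per_tok=1)
layer = GrokDecoderLayer(dense_config, layer_idx=0)

# With a single expert the layer keeps a dense MLP and no router.
assert hasattr(layer, "mlp")
assert not hasattr(layer, "router")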