jupyterjazz committed on
Commit
77a17f7
1 Parent(s): 8b64fa8

feat: default dim

Browse files

Signed-off-by: jupyterjazz <[email protected]>

Files changed (1) hide show
  1. modeling_xlm_roberta.py +1 -1
modeling_xlm_roberta.py CHANGED
@@ -91,7 +91,7 @@ def create_mixer_cls(config, cross_attn=False, return_residual=False):
91
  rotary_kwargs = {}
92
  if config.position_embedding_type == "rotary":
93
  rotary_kwargs["rotary_emb_dim"] = getattr(
94
- config, "rotary_emb_dim", config.hidden_size / 12
95
  )
96
  rotary_kwargs["rotary_emb_base"] = getattr(config, "rotary_emb_base", 10000.0)
97
  rotary_kwargs["rotary_emb_scale_base"] = getattr(
 
91
  rotary_kwargs = {}
92
  if config.position_embedding_type == "rotary":
93
  rotary_kwargs["rotary_emb_dim"] = getattr(
94
+ config, "rotary_emb_dim", config.hidden_size / config.num_attention_heads
95
  )
96
  rotary_kwargs["rotary_emb_base"] = getattr(config, "rotary_emb_base", 10000.0)
97
  rotary_kwargs["rotary_emb_scale_base"] = getattr(