chrisc36 committed
Commit f708cc2
1 Parent(s): 2ce8d82

Upload modeling_molmo.py with huggingface_hub

Files changed (1)
  1. modeling_molmo.py +9 -2
modeling_molmo.py CHANGED
@@ -762,7 +762,6 @@ class ViTMLP(nn.Module):
         return x
 
 
-
 class ResidualAttentionBlock(nn.Module):
 
     def __init__(self, config: FullMolmoConfig):
@@ -819,6 +818,14 @@ class BlockCollection(nn.Module):
         return hidden_states
 
 
+class LayerNormFp32(nn.LayerNorm):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        orig_type = x.dtype
+        x = F.layer_norm(x.to(torch.float32), self.normalized_shape, self.weight.to(torch.float32),
+                         self.bias.to(torch.float32), self.eps)
+        return x.to(orig_type)
+
+
 class VisionTransformer(nn.Module):
 
     def __init__(self, config: FullMolmoConfig):
@@ -844,7 +851,7 @@ class VisionTransformer(nn.Module):
             device=config.init_device,
         )
 
-        self.pre_ln = nn.LayerNorm(
+        self.pre_ln = LayerNormFp32(
             v_cfg.image_emb_dim,
             eps=v_cfg.image_norm_eps,
         )
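For context, this commit swaps the VisionTransformer's pre_ln from a plain nn.LayerNorm to a variant that always computes the normalization in float32 and casts the result back to the input dtype, which keeps the statistics stable when the model runs in float16/bfloat16. Below is a minimal, self-contained sketch of that pattern; the embedding size and demo tensor are illustrative assumptions, not values taken from modeling_molmo.py.

import torch
import torch.nn as nn
import torch.nn.functional as F

class LayerNormFp32(nn.LayerNorm):
    """LayerNorm that computes in float32 regardless of the input dtype."""
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        orig_type = x.dtype
        # Upcast the input and parameters so mean/variance are computed in
        # float32, then cast the normalized output back to the input dtype.
        x = F.layer_norm(
            x.to(torch.float32),
            self.normalized_shape,
            self.weight.to(torch.float32),
            self.bias.to(torch.float32),
            self.eps,
        )
        return x.to(orig_type)

# Illustrative usage (dimensions are assumptions, not Molmo's config values):
ln = LayerNormFp32(1024, eps=1e-5).to(torch.bfloat16)
x = torch.randn(2, 577, 1024, dtype=torch.bfloat16)
y = ln(x)
print(y.dtype)  # torch.bfloat16: the output dtype matches the input

Because the result is cast back to orig_type, the module stays a drop-in replacement for nn.LayerNorm in a half-precision forward pass while the normalization itself runs in full precision.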