wuzhiying2023 commited on
Commit
f7c3fb3
1 Parent(s): f9d4cc5

fix NormHead eval

Browse files
Files changed (1) hide show
  1. modeling_baichuan.py +1 -0
modeling_baichuan.py CHANGED
@@ -511,6 +511,7 @@ class NormHead(nn.Module):
511
  def forward(self, hidden_states):
512
  if self.training:
513
  norm_weight = nn.functional.normalize(self.weight)
 
514
  elif self.first_flag:
515
  self.first_flag = False
516
  self.weight = nn.Parameter(nn.functional.normalize(self.weight))
 
511
  def forward(self, hidden_states):
512
  if self.training:
513
  norm_weight = nn.functional.normalize(self.weight)
514
+ self.first_flag = True
515
  elif self.first_flag:
516
  self.first_flag = False
517
  self.weight = nn.Parameter(nn.functional.normalize(self.weight))