wuzhiying commited on
Commit
183f998
1 Parent(s): 2ce8919

NormHead-forward nn.Parameter

Browse files
Files changed (1) hide show
  1. modeling_baichuan.py +1 -1
modeling_baichuan.py CHANGED
@@ -504,7 +504,7 @@ class NormHead(nn.Module):
504
  norm_weight = nn.functional.normalize(self.weight)
505
  elif self.first_flag:
506
  self.first_flag = False
507
- self.weight = nn.Parameter(nn.functional.normalize(self.weight))
508
  norm_weight = self.weight
509
  else:
510
  norm_weight = self.weight
 
504
  norm_weight = nn.functional.normalize(self.weight)
505
  elif self.first_flag:
506
  self.first_flag = False
507
+ self.weight.data = nn.functional.normalize(self.weight)
508
  norm_weight = self.weight
509
  else:
510
  norm_weight = self.weight