wuzhiying2023
commited on
Commit
•
cb7fc74
1
Parent(s):
57c398d
fix NormHead eval bug
Browse files- modeling_baichuan.py +1 -0
modeling_baichuan.py
CHANGED
@@ -502,6 +502,7 @@ class NormHead(nn.Module):
|
|
502 |
def forward(self, hidden_states):
|
503 |
if self.training:
|
504 |
norm_weight = nn.functional.normalize(self.weight)
|
|
|
505 |
elif self.first_flag:
|
506 |
self.first_flag = False
|
507 |
self.weight = nn.Parameter(nn.functional.normalize(self.weight))
|
|
|
502 |
def forward(self, hidden_states):
|
503 |
if self.training:
|
504 |
norm_weight = nn.functional.normalize(self.weight)
|
505 |
+
self.first_flag = True
|
506 |
elif self.first_flag:
|
507 |
self.first_flag = False
|
508 |
self.weight = nn.Parameter(nn.functional.normalize(self.weight))
|