kuaizhirui
committed on
Commit
•
1682dd0
1
Parent(s):
f9d4d8d
fix NormHead eval
Browse files
I encountered this problem when using Baichuan2-7B-Base with DeepSpeed stage 3 for SFT. A similar problem has also been reported elsewhere, for example at https://github.com/baichuan-inc/Baichuan2/issues/39#issuecomment-1710146497
I found that Baichuan2-13B-Chat has already solved this problem, so I synced its code here.
- modeling_baichuan.py +2 -1
modeling_baichuan.py
CHANGED
@@ -502,9 +502,10 @@ class NormHead(nn.Module):
|
|
502 |
def forward(self, hidden_states):
|
503 |
if self.training:
|
504 |
norm_weight = nn.functional.normalize(self.weight)
|
|
|
505 |
elif self.first_flag:
|
506 |
self.first_flag = False
|
507 |
-
self.weight = nn.
|
508 |
norm_weight = self.weight
|
509 |
else:
|
510 |
norm_weight = self.weight
|
|
|
502 |
def forward(self, hidden_states):
|
503 |
if self.training:
|
504 |
norm_weight = nn.functional.normalize(self.weight)
|
505 |
+
self.first_flag = False
|
506 |
elif self.first_flag:
|
507 |
self.first_flag = False
|
508 |
+
self.weight.data = nn.functional.normalize(self.weight)
|
509 |
norm_weight = self.weight
|
510 |
else:
|
511 |
norm_weight = self.weight
|