NormHead 中的分支判断
#3
by
JaheimLee
- opened
您好,请教NormHead的forward中为什么采用三个分支来生成norm_weight啊,直接norm_weight = nn.functional.normalize(self.weight)会有什么问题吗?另外,forward中存在nn.Parameter会使deepspeed报错,可以避免这个问题吗?感谢!
您好,请教NormHead的forward中为什么采用三个分支来生成norm_weight啊,直接norm_weight = nn.functional.normalize(self.weight)会有什么问题吗?另外,forward中存在nn.Parameter会使deepspeed报错,可以避免这个问题吗?感谢!
训练的时候直接norm_weight = nn.functional.normalize(self.weight)是可以的,这么做主要是为了减少计算,提高性能。如果是训练,你可以改成直接normalize的方式也行。
This comment has been hidden