Fix precision error
modeling_chatglm.py  CHANGED  (+9 -7)
@@ -3,9 +3,7 @@
 import math
 import copy
 import warnings
-import re
 import sys
-
 import torch
 import torch.utils.checkpoint
 import torch.nn.functional as F

@@ -183,9 +181,14 @@ class RMSNorm(torch.nn.Module):
         self.eps = eps

     def forward(self, hidden_states: torch.Tensor):
-        input_dtype = hidden_states.dtype
-        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
+        if hidden_states.dtype == torch.bfloat16:
+            norm_x = torch.mean(hidden_states * hidden_states, dim=-1, keepdim=True)
+            x_normed = hidden_states * torch.rsqrt(norm_x + self.eps)
+            return self.weight * x_normed
+        else:
+            input_dtype = hidden_states.dtype
+            variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+            hidden_states = hidden_states * torch.rsqrt(variance + self.eps)

         return (self.weight * hidden_states).to(input_dtype)

@@ -517,8 +520,7 @@ class GLMBlock(torch.nn.Module):

         LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
         # Layernorm on the input data.
-        self.input_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
-                                             dtype=config.torch_dtype)
+        self.input_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device, dtype=config.torch_dtype)

         # Self attention.
         self.self_attention = SelfAttention(config, layer_number, device=device)
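For context, the behavior this patch introduces can be exercised outside the repository. The sketch below is a minimal, self-contained RMSNorm with the same forward logic as the patched module: bfloat16 inputs are normalized directly in bfloat16 (its float32-sized exponent range makes the mean of squares safe to accumulate), while other dtypes are upcast to float32 for the variance and cast back afterwards. The normalized_shape argument, the default eps, and the smoke test at the bottom are illustrative assumptions, not part of the commit.

# Standalone sketch of the patched RMSNorm, for illustration only.
# Assumes only PyTorch; normalized_shape, eps default, and the smoke test
# below are hypothetical choices, not taken from the repository.
import torch


class RMSNorm(torch.nn.Module):
    def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None):
        super().__init__()
        # Elementwise scale, as in the original module.
        self.weight = torch.nn.Parameter(
            torch.ones(normalized_shape, device=device, dtype=dtype))
        self.eps = eps

    def forward(self, hidden_states: torch.Tensor):
        if hidden_states.dtype == torch.bfloat16:
            # bfloat16 keeps float32's exponent range, so the mean of squares
            # can be accumulated in bfloat16 without an upcast.
            norm_x = torch.mean(hidden_states * hidden_states, dim=-1, keepdim=True)
            x_normed = hidden_states * torch.rsqrt(norm_x + self.eps)
            return self.weight * x_normed
        else:
            # Other dtypes (e.g. float16) compute the variance in float32
            # and cast the result back to the input dtype.
            input_dtype = hidden_states.dtype
            variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
            hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
            return (self.weight * hidden_states).to(input_dtype)


# Illustrative check: both branches return tensors in the input dtype.
if __name__ == "__main__":
    for dtype in (torch.bfloat16, torch.float16, torch.float32):
        norm = RMSNorm(8, dtype=dtype)
        out = norm(torch.randn(2, 4, 8).to(dtype))
        print(dtype, out.dtype)  # output dtype matches the input dtype

Note that float16 inputs still take the upcast branch; presumably the float32 accumulation is kept there because float16's narrower exponent range overflows easily when activations are squared.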