import torch
import torch.nn as nn


class T5LayerNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        Construct a layernorm module in the T5 style. No bias and no subtraction of mean.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # T5-style layer norm is RMSNorm: scale only, no mean subtraction and no bias.
        # Accumulate the variance in fp32 so that half-precision inputs stay numerically stable.
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)

        # Cast back to the weight's half precision (fp16/bf16) if needed.
        if self.weight.dtype in [torch.float16, torch.bfloat16]:
            hidden_states = hidden_states.to(self.weight.dtype)

        return self.weight * hidden_states

    @staticmethod
    def from_native_module(module, *args, **kwargs):
        assert module.__class__.__name__ == "FusedRMSNorm", (
            "Recovering T5LayerNorm requires the original layer to be apex's Fused RMS Norm. "
            "Apex's fused norm is automatically used by Hugging Face Transformers: "
            "https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py#L265C5-L265C48"
        )

        # Rebuild a plain T5LayerNorm with the same shape and epsilon, then copy the
        # learned scale and move it to the source module's device.
        layer_norm = T5LayerNorm(module.normalized_shape, eps=module.eps)
        layer_norm.weight.data.copy_(module.weight.data)
        layer_norm = layer_norm.to(module.weight.device)
        return layer_norm
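

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). It checks that
# T5LayerNorm.forward matches a hand-written RMS normalization,
#   y = w * x / sqrt(mean(x^2) + eps),
# and shows how `from_native_module` could be called. The names below
# (`hidden_size`, `x`) are illustrative, and the apex import assumes apex
# exposes FusedRMSNorm under `apex.normalization`; the demo is skipped if
# apex is not installed.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    hidden_size = 8
    norm = T5LayerNorm(hidden_size, eps=1e-6)
    x = torch.randn(2, 4, hidden_size)

    # Reference RMS normalization computed directly from the formula above.
    expected = norm.weight * (x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6))
    assert torch.allclose(norm(x), expected, atol=1e-6)

    try:
        from apex.normalization import FusedRMSNorm  # optional dependency (assumption)

        fused = FusedRMSNorm(hidden_size, eps=1e-6)
        recovered = T5LayerNorm.from_native_module(fused)
        # Same weight (ones) and eps, so it must agree with the module built above.
        assert torch.allclose(recovered(x), norm(x))
    except ImportError:
        pass  # apex not installed; skip the recovery demo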