Baichuan-13B-Base / quantizer.py
GuoPD's picture
add: add remote code
db8a935
raw
history blame
929 Bytes
import torch
class QLinear(torch.nn.Module):
    """Int8 weight-only quantized linear layer.

    The floating-point weight matrix is quantized once at construction time
    using symmetric per-output-channel (per-row) scaling, and dequantized on
    the fly inside ``forward``.
    """

    def __init__(self, bits: int, weight: torch.Tensor, bias=None):
        """
        Args:
            bits: Quantization bit width; only 8 is supported.
            weight: Float weight of shape (out_features, in_features).
            bias: Optional bias tensor added to the output, or None.

        Raises:
            ValueError: If ``bits`` is not 8.
        """
        super().__init__()
        self.quant_bits = bits
        if self.quant_bits != 8:
            raise ValueError(
                'Only support int8 quant in current version'
            )
        # Symmetric per-row scale: the max |w| of each row maps to the int8
        # extreme 2**(bits-1) - 1 == 127.
        # NOTE(review): an all-zero weight row yields scale 0 and NaN quantized
        # weights — assumed not to occur in practice.
        self.scale = weight.abs().max(dim=-1).values / ((2 ** (bits - 1)) - 1)
        self.weight = torch.round(weight / self.scale[:, None]).to(torch.int8)
        # Store transposed so forward is a plain ``input @ weight`` matmul.
        self.weight = self.weight.T
        # BUG FIX: the original assigned ``self.bias = None`` unconditionally,
        # silently discarding the ``bias`` argument and making the bias branch
        # in forward() dead code.
        self.bias = bias

    def forward(self, input):
        # Lazily migrate quantized tensors to the input's device: they are
        # plain attributes (not registered buffers/parameters), so moving the
        # module with .to(device) does not move them.
        if self.weight.device != input.device:
            self.weight = self.weight.to(input.device)
            self.scale = self.scale.to(input.device)
        # Dequantize during the matmul: (x @ W_int8) scaled per output channel.
        # NOTE(review): the [None, None, :] broadcast assumes a 3-D input
        # (batch, seq, in_features) — confirm against callers.
        output = torch.matmul(input, self.weight.to(input.dtype)) * self.scale.to(input.dtype)[None, None, :]
        if self.bias is not None:
            output = output + self.bias
        return output