import torch


class QLinear(torch.nn.Module):
    """Weight-only int8 quantized linear layer.

    The float weight matrix of shape ``(out_features, in_features)`` is
    quantized at construction time with one symmetric per-row (per output
    channel) scale:

        scale[o]  = max(|W[o, :]|) / 127
        Wq[o, i]  = round(W[o, i] / scale[o])   (stored as int8)

    ``forward`` dequantizes on the fly: ``y = (x @ Wq.T) * scale (+ bias)``.

    Args:
        bits: quantization bit-width; only 8 is supported.
        weight: float weight tensor of shape ``(out_features, in_features)``.
        bias: optional float bias tensor of shape ``(out_features,)``.

    Raises:
        ValueError: if ``bits != 8``.
    """

    def __init__(self, bits: int, weight: torch.Tensor, bias=None):
        super().__init__()
        self.quant_bits = bits
        if self.quant_bits != 8:
            raise ValueError('Only support int8 quant in current version')
        # Symmetric per-row scale: map the row's max magnitude to 127
        # (= 2**(bits-1) - 1). Clamp away from zero so an all-zero row
        # does not produce NaNs in the division below (its quantized row
        # is still exactly 0).
        qmax = (2 ** (bits - 1)) - 1
        scale = weight.abs().max(dim=-1).values / qmax
        self.scale = scale.clamp_min(torch.finfo(weight.dtype).tiny)
        # NOTE: torch.round uses round-half-to-even (banker's rounding).
        self.weight = torch.round(weight / self.scale[:, None]).to(torch.int8)
        # Pre-transpose to (in_features, out_features) so forward is a
        # plain input @ weight matmul.
        self.weight = self.weight.T
        # Fix: the original discarded the bias argument (self.bias = None),
        # making the bias branch in forward dead code.
        self.bias = bias

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Compute ``input @ Wq * scale (+ bias)``.

        Args:
            input: float tensor of shape ``(..., in_features)``.

        Returns:
            Tensor of shape ``(..., out_features)`` in ``input``'s dtype.
        """
        # Lazily migrate the quantized state to the input's device; plain
        # attributes (not registered buffers) do not follow Module.to().
        if self.weight.device != input.device:
            self.weight = self.weight.to(input.device)
            self.scale = self.scale.to(input.device)
        # Dequantize via the per-output-channel scale. A bare 1-D scale
        # broadcasts over the last dim for ANY input rank; the original
        # scale[None, None, :] wrongly forced 2-D inputs up to 3-D output.
        output = torch.matmul(input, self.weight.to(input.dtype))
        output = output * self.scale.to(input.dtype)
        if self.bias is not None:
            output = output + self.bias
        return output