import torch


class QLinear(torch.nn.Module):
    """Linear layer with weight-only int8 quantization using symmetric per-row scales."""

    def __init__(self, bits: int, weight: torch.Tensor, bias=None):
        super().__init__()
        self.quant_bits = bits
        if self.quant_bits != 8:
            raise ValueError('Only int8 quantization is supported in the current version')
        # Symmetric per-output-channel scale: the largest |w| in each row maps to 127.
        self.scale = weight.abs().max(dim=-1).values / ((2 ** (bits - 1)) - 1)
        # Guard against all-zero rows, which would otherwise divide by zero below.
        self.scale = self.scale.clamp(min=1e-8)
        self.weight = torch.round(weight / self.scale[:, None]).to(torch.int8)
        # Store the weight transposed so forward can compute input @ weight directly.
        self.weight = self.weight.T
        self.bias = bias

    def forward(self, input):
        # The tensors are plain attributes rather than registered buffers, so
        # move them to the input's device lazily on first use.
        if self.weight.device != input.device:
            self.weight = self.weight.to(input.device)
            self.scale = self.scale.to(input.device)
        # Dequantize on the fly: matmul in the input's dtype, then rescale each
        # output channel. Broadcasting over the last dim supports any batch
        # shape (the original [None, None, :] indexing assumed a 3-D input).
        output = torch.matmul(input, self.weight.to(input.dtype)) * self.scale.to(input.dtype)
        if self.bias is not None:
            output = output + self.bias
        return output
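

# A minimal usage sketch under one assumption: `weight` is laid out as
# [out_features, in_features], the same convention nn.Linear uses. It
# quantizes a random layer and compares against the full-precision matmul
# to gauge the int8 rounding error.
if __name__ == '__main__':
    torch.manual_seed(0)
    w = torch.randn(64, 32)          # [out_features, in_features]
    b = torch.randn(64)
    qlin = QLinear(8, w, bias=b)

    x = torch.randn(4, 16, 32)       # [batch, seq, in_features]
    ref = torch.matmul(x, w.T) + b   # full-precision reference
    out = qlin(x)
    print('max abs error:', (out - ref).abs().max().item())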