import torch
import torch.nn as nn
import quiptools_cuda
from lib.utils import dtype_from_str, get_hadK
from lib import codebook
import time


class QuantizedLinear(nn.Module):

    def __init__(self,
                 in_features,
                 out_features,
                 codesz,
                 packsz,
                 pack_out,
                 idx_dtype,
                 codebook_version,
                 outlier_channel_split=False,
                 rank=-1,
                 rescale_WH=False,
                 bias=False):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.outlier_channel_split = outlier_channel_split
        self.rank = rank
        self.rescale_WH = rescale_WH
        self.has_bias = bias
        if self.has_bias:
            self.register_buffer('bias', torch.ones(out_features))

        if self.outlier_channel_split:
            # maps each (possibly duplicated) input channel back to its source
            self.register_buffer('ocs_dupe_inds', torch.arange(in_features))

        if self.rank > 0:
            # optional low-rank correction factors applied in the forward pass
            self.register_buffer('A', torch.zeros(out_features, rank))
            self.register_buffer('B', torch.zeros(rank, in_features))
        else:
            self.A = None
            self.B = None

        if self.rescale_WH:
            self.register_buffer("scaleWH", torch.ones(in_features))
        else:
            self.scaleWH = None
        # Packing direction: the codebook dimension always lies along the
        # input (in_features) dimension; packsz codes can additionally be
        # packed into each index word, along either the out or in dimension.
        if pack_out:
            self.register_buffer(
                "Qidxs",
                torch.zeros(out_features // packsz,
                            in_features // codesz,
                            dtype=dtype_from_str(idx_dtype)))
        else:
            self.register_buffer(
                "Qidxs",
                torch.zeros(out_features,
                            in_features // (codesz * packsz),
                            dtype=dtype_from_str(idx_dtype)))
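        # Illustrative shapes (example values, not from the original source):
        # with in_features=4096, out_features=4096, codesz=8, packsz=4,
        # in-dimension packing gives Qidxs of shape (4096, 4096 // 32) = (4096, 128),
        # while pack_out gives (4096 // 4, 4096 // 8) = (1024, 512).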
self.register_buffer("codebook_id", torch.tensor(0))
self.register_buffer("SU", torch.ones(in_features))
self.register_buffer("SV", torch.ones(out_features))
self.register_buffer("Wscale", torch.ones(()))
self.built_codebook_class = False
self.built_graph = False
self.codebook_version = codebook_version
had_left, K_left = get_hadK(in_features)
had_right, K_right = get_hadK(out_features)
self.register_buffer('had_left', had_left, persistent=False)
self.register_buffer('had_right', had_right, persistent=False)
self.K_left = K_left
self.K_right = K_right
self.packed = (packsz != 1)

    def forward(self, input):
        if not self.built_codebook_class:
            # Lazily resolve the codebook class on first use, so that buffers
            # (including codebook_id) can be loaded from a checkpoint first.
            self.codebook_class = codebook.get_quantized_class(self.codebook_id.item())(
                self.Qidxs.device)
            if self.codebook_class.codebook.version != self.codebook_version:
                raise Exception(
                    f"Saved weights version ({self.codebook_version}) does not match the "
                    f"codebook version ({self.codebook_class.codebook.version}). "
                    "Please download the latest weights from https://huggingface.co/relaxml")
            self.built_codebook_class = True

        if self.outlier_channel_split:
            # duplicate outlier input channels before the quantized matmul
            input = input[..., self.ocs_dupe_inds]

        result = self.codebook_class(input,
                                     self.Qidxs,
                                     self.SU,
                                     self.SV,
                                     self.Wscale,
                                     self.had_left,
                                     self.had_right,
                                     self.K_left,
                                     self.K_right,
                                     rank=self.rank,
                                     A=self.A,
                                     B=self.B,
                                     rescale_WH=self.rescale_WH,
                                     scaleWH=self.scaleWH,
                                     packed=self.packed)
        if self.has_bias:
            return result + self.bias
        return result
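

if __name__ == "__main__":
    # Minimal usage sketch (added for illustration; not in the original file).
    # The hyperparameters below are hypothetical: they assume a codebook with
    # id 0 is registered in lib.codebook, that its version matches
    # codebook_version=0, and that lib.utils.dtype_from_str accepts a string
    # of the form 'torch.int16'.
    layer = QuantizedLinear(in_features=4096,
                            out_features=4096,
                            codesz=8,
                            packsz=1,
                            pack_out=False,
                            idx_dtype='torch.int16',
                            codebook_version=0)
    x = torch.randn(1, 4096)
    y = layer(x)  # first call builds the codebook class, then dispatches to it
    print(y.shape)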