import torch
import torch.nn as nn
import quiptools_cuda
from lib.utils import dtype_from_str, get_hadK
from lib import codebook


class QuantizedLinear(nn.Module):
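    """Linear layer whose weight is stored as packed codebook indices.

    The effective weight is reconstructed at forward time by the codebook
    kernel from the index buffer ``Qidxs``, the vectors ``SU``/``SV``, the
    scalar ``Wscale``, and Hadamard factors from ``get_hadK``, optionally
    combined with a low-rank ``A @ B`` correction, outlier channel
    splitting, and per-channel ``scaleWH`` rescaling.
    """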

    def __init__(self,
                 in_features,
                 out_features,
                 codesz,
                 packsz,
                 pack_out,
                 idx_dtype,
                 codebook_version,
                 outlier_channel_split=False,
                 rank=-1,
                 rescale_WH=False,
                 bias=False):
        super().__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.outlier_channel_split = outlier_channel_split
        self.rank = rank
        self.rescale_WH = rescale_WH

        self.has_bias = bias
        if self.has_bias:
            self.register_buffer('bias', torch.ones(out_features))
        
        if self.outlier_channel_split:
            self.register_buffer('ocs_dupe_inds', torch.arange(in_features))

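        # Optional low-rank correction factors: A is (out_features, rank) and
        # B is (rank, in_features); the codebook kernel applies A @ B on top
        # of the dequantized weight.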
        if self.rank > 0:
            self.register_buffer('A', torch.zeros(out_features, rank))
            self.register_buffer('B', torch.zeros(rank, in_features))
        else:
            self.A = None
            self.B = None

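        # Optional per-input-channel scale, used when the weights were
        # rescaled before quantization (rescale_WH).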
        if self.rescale_WH:
            self.register_buffer("scaleWH", torch.ones(in_features))
        else:
            self.scaleWH = None

        # Packing direction: code indices always span the input dimension.
        # With pack_out=True, packsz output rows are packed into each stored
        # index; otherwise packsz codes are packed along the input dimension.
        if pack_out:
            self.register_buffer(
                "Qidxs",
                torch.zeros(out_features // packsz,
                            in_features // codesz,
                            dtype=dtype_from_str(idx_dtype)))
        else:
            self.register_buffer(
                "Qidxs",
                torch.zeros(out_features,
                            in_features // (codesz * packsz),
                            dtype=dtype_from_str(idx_dtype)))
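        # E.g. with in_features=4096, out_features=4096, codesz=8, packsz=4:
        #   pack_out=True  -> Qidxs has shape (1024, 512)
        #   pack_out=False -> Qidxs has shape (4096, 128)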

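        # Quantization metadata: codebook id, per-channel vectors SU/SV, and
        # a scalar weight scale.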
        self.register_buffer("codebook_id", torch.tensor(0))
        self.register_buffer("SU", torch.ones(in_features))
        self.register_buffer("SV", torch.ones(out_features))
        self.register_buffer("Wscale", torch.ones(()))

        self.built_codebook_class = False
        self.built_graph = False
        self.codebook_version = codebook_version

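        # Hadamard transform factors for the input and output dimensions
        # (see get_hadK); not persisted, since they are recomputed here.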
        had_left, K_left = get_hadK(in_features)
        had_right, K_right = get_hadK(out_features)
        self.register_buffer('had_left', had_left, persistent=False)
        self.register_buffer('had_right', had_right, persistent=False)
        self.K_left = K_left
        self.K_right = K_right
        self.packed = (packsz != 1)

    def forward(self, input):
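        # Lazily instantiate the codebook kernel class on the first call, so
        # it is created on the same device as the loaded Qidxs buffer.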
        if not self.built_codebook_class:
            self.codebook_class = codebook.get_quantized_class(self.codebook_id.item())(
                self.Qidxs.device)
            if self.codebook_class.codebook.version != self.codebook_version:
                raise ValueError(
                    f"Saved weights version ({self.codebook_version}) does not match the "
                    f"codebook version ({self.codebook_class.codebook.version}). "
                    "Please download the latest weights from https://huggingface.co/relaxml")
            self.built_codebook_class = True

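        # Outlier channel splitting: duplicate selected input channels using
        # the precomputed gather indices.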
        if self.outlier_channel_split:
            input = input[..., self.ocs_dupe_inds]

        result = self.codebook_class(input,
                                     self.Qidxs,
                                     self.SU,
                                     self.SV,
                                     self.Wscale,
                                     self.had_left,
                                     self.had_right,
                                     self.K_left,
                                     self.K_right,
                                     rank=self.rank,
                                     A=self.A,
                                     B=self.B,
                                     rescale_WH=self.rescale_WH,
                                     scaleWH=self.scaleWH,
                                     packed=self.packed)
        if self.has_bias:
            return result + self.bias
        return result
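

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the library). Shows how the
# layer is constructed and what normally has to happen before forward() is
# useful. The shape and codebook parameters below are hypothetical; real
# values for codesz, packsz, idx_dtype, and codebook_version come from the
# config shipped with the quantized checkpoints at
# https://huggingface.co/relaxml.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    layer = QuantizedLinear(
        in_features=4096,
        out_features=4096,
        codesz=8,                  # hypothetical code vector length
        packsz=1,                  # no index packing in this sketch
        pack_out=False,
        idx_dtype='torch.int16',   # must be a string dtype_from_str accepts (assumed)
        codebook_version=0,        # must equal codebook_class.codebook.version
        rank=64,                   # allocates the low-rank A @ B buffers
        bias=True,
    )
    # Qidxs is (out_features, in_features // (codesz * packsz)) here: (4096, 512).
    print(layer.Qidxs.shape)

    # In practice the quantized buffers (Qidxs, SU, SV, Wscale, A, B, ...) are
    # loaded from a released checkpoint before the layer is called; with the
    # freshly initialized buffers above, running
    #     y = layer(torch.randn(1, 4096))
    # only exercises the shape plumbing (and requires the CUDA codebook kernels).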