|
""" |
|
builds a deep-hole-centered D4 codebook |
|
this is a codebook consisting of points on the lattice in R4 |
|
where each component is a half-integer |
|
and the components sum to an even number |
|
from this lattice, we select the points that have a norm-squared of at most 9 |
|
this results in a codebook of 256 points distributed as follows |
|
8 with sorted abs of [1/2, 1/2, 1/2, 1/2] |
|
8 [3/2, 3/2, 3/2, 3/2] |
|
4c2 * 8 = 48 [1/2, 1/2. 3/2, 3/2] |
|
4 * 8 = 32 [1/2, 1/2, 1/2, 3/2] |
|
4 * 8 = 32 [1/2, 3/2, 3/2, 3/2] |
|
4 * 8 = 32 [1/2, 1/2, 1/2, 5/2] |
|
4 * 3 * 8 = 96 [1/2, 1/2, 3/2, 5/2] |
|
""" |
|
|
|
import torch
from torch import nn

import quiptools_cuda

from lib.utils.matmul_had import matmul_hadU_cuda, matmul_hadUt_cuda

_D4_CODESZ = 4
|
|
def code3_signs(i3, x):
    # Apply one of 8 sign patterns, selected by bits 5-7 of the codeword
    # index. x[3]'s sign is not chosen freely: it is flipped whenever
    # needed to keep the component sum even (the D4 parity constraint).
    if i3 & (1 << 5):
        x[2] *= -1
    if i3 & (1 << 6):
        x[1] *= -1
    if sum(x) % 2 != 0:
        x[3] *= -1
    if i3 & (1 << 7):
        for j in range(_D4_CODESZ):
            x[j] *= -1
    assert sum(x) % 2 == 0
    return x
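
# Example: code3_signs(0, [0.5, 0.5, 0.5, 1.5]) finds an odd sum (3), so
# the parity rule flips the last entry, giving [0.5, 0.5, 0.5, -1.5]
# (sum 0). Setting bit 7 (i3 = 128) would then negate every component.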
|
|
|
|
|
def code8_to_d4(i8):
    # Decode an 8-bit index into a D4 codeword: the top 3 bits (i3)
    # select the sign pattern and the low 5 bits select the magnitude
    # pattern and the positions of its distinguished entries.
    assert (i8 >= 0) and (i8 < 256)
    i3 = i8 & (7 << 5)
    i8 = i8 & 31
    if i8 < 16:
        if i8 < 8:
            if i8 < 2:
                if i8 < 1:
                    # i8 == 0: [1/2, 1/2, 1/2, 1/2]
                    return code3_signs(i3, [0.5] * _D4_CODESZ)
                else:
                    # i8 == 1: [3/2, 3/2, 3/2, 3/2]
                    return code3_signs(i3, [1.5] * _D4_CODESZ)
            else:
                # i8 in [2, 8): two entries at 3/2 and two at 1/2. The six
                # position pairs are split by whether the pair of 3/2s
                # contains index 0 (odd i8) or not (even i8).
                ibx = i8 >> 1
                if i8 & 1:
                    x = [0.5] * _D4_CODESZ
                    x[0] = 1.5
                    x[ibx] = 1.5
                else:
                    x = [1.5] * _D4_CODESZ
                    x[0] = 0.5
                    x[ibx] = 0.5
                return code3_signs(i3, x)
        else:
            ibx = i8 & 3
            if i8 < 8 + 4:
                # i8 in [8, 12): one entry at 3/2, the rest at 1/2.
                x = [0.5] * _D4_CODESZ
                x[ibx] = 1.5
            else:
                # i8 in [12, 16): one entry at 1/2, the rest at 3/2.
                x = [1.5] * _D4_CODESZ
                x[ibx] = 0.5
            return code3_signs(i3, x)
    else:
        if i8 < 16 + 4:
            # i8 in [16, 20): one entry at 5/2, the rest at 1/2.
            ibx = i8 & 3
            x = [0.5] * _D4_CODESZ
            x[ibx] = 2.5
            return code3_signs(i3, x)
        else:
            # i8 in [20, 32): one entry at 5/2, one at 3/2, the rest at
            # 1/2. This covers the 4 * 3 = 12 position combinations.
            ibx = i8 - 20
            ib4 = ibx & 3
            ib3 = ibx >> 2
            x = [0.5] * _D4_CODESZ
            x[ib4] = 1.5
            if ib3 >= ib4:
                ib3 += 1
            x[ib3] = 2.5
            return code3_signs(i3, x)
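
# For example, code8_to_d4(0) gives [1/2, 1/2, 1/2, 1/2], while
# code8_to_d4(20) places the 3/2 at index 0 and the 5/2 at index 1,
# giving [3/2, 5/2, 1/2, 1/2] before sign correction (the parity rule
# then flips the last entry, since 3/2 + 5/2 + 1/2 + 1/2 = 5 is odd).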
|
|
|
|
|
def build_D4_CB():
    # Materialize all 256 codewords as a (256, 4) float tensor.
    CB = torch.zeros(256, _D4_CODESZ)
    for i in range(256):
        x = code8_to_d4(i)
        for j in range(_D4_CODESZ):
            CB[i, j] = x[j]
    return CB
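
# A minimal sanity check of the construction, not part of the original
# pipeline: it verifies the size, distinctness, parity, and norm-bound
# claims from the module docstring. The helper name is illustrative.
def _check_D4_CB():
    CB = build_D4_CB()
    assert CB.shape == (256, _D4_CODESZ)
    # All 256 codewords should be distinct.
    assert len({tuple(row.tolist()) for row in CB}) == 256
    # Components sum to an even integer and the squared norm is <= 9.
    assert (CB.sum(dim=1) % 2 == 0).all()
    assert (CB.square().sum(dim=1) <= 9).all()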
|
|
|
|
|
# Commented-out reference implementations (not executed at runtime):
'''
def quantize(X, CB):
    scale = X.square().mean().sqrt() / 1.21
    X = X / scale
    Xqidx = (2 * X @ CB.t() - (CB @ CB.t()).diag()).argmax(1)
    return (CB[Xqidx, :] * scale, scale, Xqidx.to(torch.uint8))


def quantize_noscale_a(X, CB, A):
    Xqidx = (2 * X @ A @ CB.t() - (CB @ A @ CB.t()).diag()).argmax(1)
    return (CB[Xqidx, :], Xqidx.to(torch.uint8))


def quantize_full_lattice(X):
    n = X.shape[0]
    Xround = (X + 0.5).round() - 0.5
    adjustParity = Xround.sum(1) % 2
    furthestEntry = (X - Xround).abs().argmax(1)
    furthestEntrySign = (X - Xround)[torch.arange(n), furthestEntry].sign()
    Xround[torch.arange(n), furthestEntry] += furthestEntrySign * adjustParity
    return Xround
'''
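
# The argmax expressions above (and in _quantize_noscale below) rely on
# the identity
#   argmin_c ||x - c||^2 = argmax_c (2 <x, c> - ||c||^2),
# which follows from expanding ||x - c||^2 = ||x||^2 - 2 <x, c> + ||c||^2
# and dropping ||x||^2, which is constant across codewords c. This lets
# nearest-codeword search run as a single matmul against the grid.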
|
|
|
|
|
class D4_codebook(nn.Module):

    def __init__(self, inference=False):
        super(D4_codebook, self).__init__()
        self.register_buffer("grid", build_D4_CB())
        if not inference:
            # ||c||^2 for each codeword, used by the argmax search below.
            self.register_buffer('grid_norm', (self.grid @ self.grid.T).diag())
        self.codesz = _D4_CODESZ
        self.opt_scale = 1.21
        self.idx_dtype = torch.uint8
        self.packsz = 1
        self.pack_out = False
        self.version = 0

    def _quantize_noscale(self, X, return_idx=True):
        # Nearest codeword via argmax(2 <x, c> - ||c||^2); see note above.
        Xqidx = (2 * X @ self.grid.T - self.grid_norm).argmax(1)
        if return_idx:
            return self.grid[Xqidx, :], Xqidx.to(self.idx_dtype)
        return self.grid[Xqidx, :]

    def quantize(self, X, return_idx=True):
        assert X.shape[-1] == self.codesz
        return self._quantize_noscale(X, return_idx=return_idx)

    def maybe_pack_idxs(self, idxs):
        # D4 indices already fit one uint8 per codeword; no packing needed.
        return idxs

    def by_idxs(self, idxs, **kwargs):
        return self.grid[idxs.int()]
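
# A minimal usage sketch (shapes are illustrative): quantize rows of an
# (N, 4) matrix and reconstruct them from the returned indices.
#
#   cb = D4_codebook()
#   X = torch.randn(1024, cb.codesz)
#   Xhat, idxs = cb.quantize(X)  # nearest codewords and uint8 indices
#   assert torch.equal(cb.by_idxs(idxs), Xhat)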
|
|
|
|
|
class QuantizedD4Linear(nn.Module):

    def __init__(self, device):
        super().__init__()
        self.codebook = D4_codebook(inference=True).to(torch.float16).to(device)

    def forward(self,
                input,
                Qidxs,
                SU,
                SV,
                Wscale,
                had_left,
                had_right,
                K_left,
                K_right,
                rank=-1,
                A=None,
                B=None,
                rescale_WH=False,
                scaleWH=None,
                **kwargs):
|
        (m, n) = Qidxs.shape

        x = input.view(-1, _D4_CODESZ * n).to(torch.float32)
        if rescale_WH:
            x /= scaleWH
        # Incoherence processing: apply the sign vector SU and a Hadamard
        # transform before multiplying by the quantized weight.
        x = matmul_hadUt_cuda(x * SU, had_left, K_left)

        if rank > 0:
            # Low-rank correction term, computed in float32.
            Bx = x @ B.t().to(torch.float32)
            ABx = Bx @ A.t().to(torch.float32)

        # Scale down before the float16 kernels to avoid overflow; the
        # factor is undone below via Wscale * 1024.
        x = (x / 1024).to(torch.float16)

        # Dispatch to a fused lookup-matmul kernel sized for the batch,
        # zero-padding the batch dimension up to the kernel's tile size.
        if x.shape[0] <= 8:
            if x.shape[0] == 8:
                x_padded = x.contiguous()
            else:
                x_padded = torch.zeros(8, n * _D4_CODESZ, dtype=torch.float16, device=x.device)
                x_padded[0:(x.shape[0]), :] = x
            z = torch.zeros(8, m, dtype=x.dtype, device=x.device)
            quiptools_cuda.lookupmatmul_d4_k8(x_padded, Qidxs, self.codebook.grid, z)
            z = z[0:(x.shape[0]), :]
        elif x.shape[0] <= 16:
            if x.shape[0] == 16:
                x_padded = x.contiguous()
            else:
                x_padded = torch.zeros(16, n * _D4_CODESZ, dtype=torch.float16, device=x.device)
                x_padded[0:(x.shape[0]), :] = x
            z = torch.zeros(16, m, dtype=x.dtype, device=x.device)
            quiptools_cuda.lookupmatmul_d4_k16(x_padded, Qidxs, self.codebook.grid, z)
            z = z[0:(x.shape[0]), :]
        elif x.shape[0] <= 32:
            if x.shape[0] == 32:
                x_padded = x.contiguous()
            else:
                x_padded = torch.zeros(32, n * _D4_CODESZ, dtype=torch.float16, device=x.device)
                x_padded[0:(x.shape[0]), :] = x
            z = torch.zeros(32, m, dtype=x.dtype, device=x.device)
            quiptools_cuda.lookupmatmul_d4_k32(x_padded, Qidxs, self.codebook.grid, z)
            z = z[0:(x.shape[0]), :]
        else:
            # Large batches: decompress the weight matrix once and use a
            # dense matmul instead of the fused lookup kernels.
            W_decompressed = torch.zeros(m, n * _D4_CODESZ, dtype=torch.float16, device=x.device)
            quiptools_cuda.decompress_d4(Qidxs, self.codebook.grid, W_decompressed)
            z = x @ W_decompressed.t()

        x = z.to(torch.float32) * (Wscale * 1024)
        if rank > 0:
            x = x + ABx.to(torch.float32)

        # Undo incoherence processing on the output side and restore the
        # original batch shape.
        return (matmul_hadU_cuda(x, had_right, K_right) * SV).view(*input.shape[:-1], m)
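
# A minimal CUDA-free sketch of the matmul core above, assuming the
# decompressed layout W[i] = concat(grid[Qidxs[i, j]] for j in range(n)),
# which is what the dense fallback implies; useful for testing against
# the kernels. The function name is illustrative, not part of the API.
def _d4_matmul_reference(x, Qidxs, grid):
    # x: (batch, n * _D4_CODESZ), Qidxs: (m, n), grid: (256, _D4_CODESZ)
    m, n = Qidxs.shape
    W = grid[Qidxs.int()].view(m, n * _D4_CODESZ)
    return x @ W.t()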
|
|