"""
builds a deep-hole-centered D4 codebook
this is a codebook consisting of points on the lattice in R4
where each component is a half-integer
and the components sum to an even number
from this lattice, we select the points that have a norm-squared of at most 9
this results in a codebook of 256 points distributed as follows
8 with sorted abs of [1/2, 1/2, 1/2, 1/2]
8 [3/2, 3/2, 3/2, 3/2]
4c2 * 8 = 48 [1/2, 1/2. 3/2, 3/2]
4 * 8 = 32 [1/2, 1/2, 1/2, 3/2]
4 * 8 = 32 [1/2, 3/2, 3/2, 3/2]
4 * 8 = 32 [1/2, 1/2, 1/2, 5/2]
4 * 3 * 8 = 96 [1/2, 1/2, 3/2, 5/2]
"""
import torch
from torch import nn
import quiptools_cuda
from lib.utils.matmul_had import matmul_hadU_cuda, matmul_hadUt_cuda
_D4_CODESZ = 4  # number of components per D4 codeword
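
# Index layout (8 bits per codeword): the low 5 bits select one of the 32
# magnitude patterns enumerated in code8_to_d4 below, and the top 3 bits
# select one of the 8 sign patterns applied in code3_signs; 32 * 8 = 256.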
def code3_signs(i3, x):
    # Apply the sign pattern encoded in the top 3 bits of the index to the
    # magnitude pattern x, preserving the even-sum (D4) constraint.
    if i3 & (1 << 5):  # bit 5 flips x[2]
        x[2] *= -1
    if i3 & (1 << 6):  # bit 6 flips x[1]
        x[1] *= -1
    if sum(x) % 2 != 0:  # flip x[3] to restore even parity
        x[3] *= -1
    if i3 & (1 << 7):  # bit 7 negates the whole codeword
        for j in range(_D4_CODESZ):
            x[j] *= -1
    assert sum(x) % 2 == 0
    return x
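
# Worked example (illustrative): with i3 = 1 << 5, code3_signs flips x[2] of
# [0.5, 0.5, 0.5, 0.5], leaving an odd sum of 1.0, so the parity fix also
# flips x[3]; the result is [0.5, 0.5, -0.5, -0.5] with even sum 0.0.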
def code8_to_d4(i8):
    # Decode an 8-bit index into a D4 codeword: the low 5 bits select one of
    # the 32 magnitude patterns below, the top 3 bits select the signs.
    assert (i8 >= 0) and (i8 < 256)
    i3 = i8 & (7 << 5)  # top 3 bits: sign pattern (consumed by code3_signs)
    i8 = i8 & 31        # low 5 bits: magnitude pattern
    if i8 < 16:
        if i8 < 8:
            if i8 < 2:
                if i8 < 1:
                    # index 0: [1/2, 1/2, 1/2, 1/2]
                    return code3_signs(i3, [0.5] * _D4_CODESZ)
                else:
                    # index 1: [3/2, 3/2, 3/2, 3/2]
                    return code3_signs(i3, [1.5] * _D4_CODESZ)
            else:
                # indices 2-7: the six arrangements of [1/2, 1/2, 3/2, 3/2]
                ibx = i8 >> 1
                if i8 & 1:
                    x = [0.5] * _D4_CODESZ
                    x[0] = 1.5
                    x[ibx] = 1.5
                else:
                    x = [1.5] * _D4_CODESZ
                    x[0] = 0.5
                    x[ibx] = 0.5
                return code3_signs(i3, x)
        else:
            ibx = i8 & 3
            if i8 < 8 + 4:
                # indices 8-11: [1/2, 1/2, 1/2, 3/2] with the 3/2 at ibx
                x = [0.5] * _D4_CODESZ
                x[ibx] = 1.5
            else:
                # indices 12-15: [1/2, 3/2, 3/2, 3/2] with the 1/2 at ibx
                x = [1.5] * _D4_CODESZ
                x[ibx] = 0.5
            return code3_signs(i3, x)
    else:
        if i8 < 16 + 4:
            # indices 16-19: [1/2, 1/2, 1/2, 5/2] with the 5/2 at ibx
            ibx = i8 & 3
            x = [0.5] * _D4_CODESZ
            x[ibx] = 2.5
            return code3_signs(i3, x)
        else:
            # indices 20-31: the twelve arrangements of [1/2, 1/2, 3/2, 5/2];
            # ib4 places the 3/2 and ib3 (skipping ib4's slot) places the 5/2
            ibx = i8 - 20
            ib4 = ibx & 3
            ib3 = ibx >> 2
            x = [0.5] * _D4_CODESZ
            x[ib4] = 1.5
            if ib3 >= ib4:
                ib3 += 1
            x[ib3] = 2.5
            return code3_signs(i3, x)
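
# Decode examples (illustrative): code8_to_d4(0) -> [0.5, 0.5, 0.5, 0.5] and
# code8_to_d4(1) -> [1.5, 1.5, 1.5, 1.5]; code8_to_d4(20) builds
# [1.5, 2.5, 0.5, 0.5], and the parity fix in code3_signs flips the last
# component, returning [1.5, 2.5, 0.5, -0.5] (sum 4.0, even).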
def build_D4_CB():
    # Enumerate all 256 codewords into a (256, _D4_CODESZ) float tensor.
    return torch.tensor([code8_to_d4(i) for i in range(256)])
'''
Legacy reference implementations, kept for documentation:

def quantize(X, CB):
    scale = X.square().mean().sqrt() / 1.21
    X = X / scale
    Xqidx = (2 * X @ CB.t() - (CB @ CB.t()).diag()).argmax(1)
    return (CB[Xqidx, :] * scale, scale, Xqidx.to(torch.uint8))

def quantize_noscale_a(X, CB, A):
    Xqidx = (2 * X @ A @ CB.t() - (CB @ A @ CB.t()).diag()).argmax(1)
    return (CB[Xqidx, :], Xqidx.to(torch.uint8))

def quantize_full_lattice(X):
    n = X.shape[0]
    Xround = (X + 0.5).round() - 0.5
    adjustParity = Xround.sum(1) % 2
    furthestEntry = (X - Xround).abs().argmax(1)
    furthestEntrySign = (X - Xround)[torch.arange(n), furthestEntry].sign()
    Xround[torch.arange(n), furthestEntry] += furthestEntrySign * adjustParity
    return Xround
'''
class D4_codebook(nn.Module):

    def __init__(self, inference=False):
        super().__init__()
        self.register_buffer("grid", build_D4_CB())
        if not inference:
            # Precomputed squared norms ||c||^2 of every codeword, used by
            # the nearest-codeword search in _quantize_noscale.
            self.register_buffer('grid_norm', (self.grid @ self.grid.T).diag())
        self.codesz = _D4_CODESZ
        self.opt_scale = 1.21  # input scale used during quantization
        self.idx_dtype = torch.uint8
        self.packsz = 1
        self.pack_out = False
        self.version = 0
    def _quantize_noscale(self, X, return_idx=True):
        # Nearest-codeword search: maximizing 2<x, c> - ||c||^2 over codewords
        # c is equivalent to minimizing ||x - c||^2.
        Xqidx = (2 * X @ self.grid.T - self.grid_norm).argmax(1)
        if return_idx:
            return self.grid[Xqidx, :], Xqidx.to(self.idx_dtype)
        return self.grid[Xqidx, :]

    def quantize(self, X, return_idx=True):
        assert X.shape[-1] == self.codesz
        return self._quantize_noscale(X, return_idx=return_idx)
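
    # Usage sketch (illustrative; shapes and data are made up):
    #
    #   cb = D4_codebook()
    #   X = torch.randn(1024, cb.codesz)
    #   Xq, idxs = cb.quantize(X)  # Xq: (1024, 4) codewords, idxs: uint8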
    def maybe_pack_idxs(self, idxs):
        # D4 indices already fit in one byte each; no packing needed.
        return idxs

    def by_idxs(self, idxs, **kwargs):
        # Look up codewords for a tensor of stored indices.
        return self.grid[idxs.int()]
class QuantizedD4Linear(nn.Module):
def __init__(self, device):
super().__init__()
self.codebook = D4_codebook(inference=True).to(torch.float16).to(device)
def forward(self,
input,
Qidxs,
SU,
SV,
Wscale,
had_left,
had_right,
K_left,
K_right,
rank=-1,
A=None,
B=None,
rescale_WH=False,
scaleWH=None,
**kwargs):
        (m, n) = Qidxs.shape
        x = input.view(-1, _D4_CODESZ * n).to(torch.float32)
        if rescale_WH:
            x /= scaleWH
        # Apply the input sign vector SU and the transposed Hadamard transform.
        x = matmul_hadUt_cuda(x * SU, had_left, K_left)
        if rank > 0:
            # Low-rank correction: compute x @ (A @ B)^T, added to the output
            # after the quantized matmul.
            Bx = x @ B.t().to(torch.float32)
            ABx = Bx @ A.t().to(torch.float32)
        # Scale down before fp16 to avoid overflow; undone by Wscale * 1024 below.
        x = (x / 1024).to(torch.float16)
        if x.shape[0] <= 32:
            # Pick the smallest lookup-matmul kernel tile that fits the batch,
            # zero-padding the batch dimension up to the tile size.
            for tile, kernel in ((8, quiptools_cuda.lookupmatmul_d4_k8),
                                 (16, quiptools_cuda.lookupmatmul_d4_k16),
                                 (32, quiptools_cuda.lookupmatmul_d4_k32)):
                if x.shape[0] <= tile:
                    if x.shape[0] == tile:
                        x_padded = x.contiguous()
                    else:
                        x_padded = torch.zeros(tile, n * _D4_CODESZ, dtype=torch.float16, device=x.device)
                        x_padded[0:x.shape[0], :] = x
                    z = torch.zeros(tile, m, dtype=x.dtype, device=x.device)
                    kernel(x_padded, Qidxs, self.codebook.grid, z)
                    z = z[0:x.shape[0], :]
                    break
        else:
            # Large batch: manifest the dequantized weight matrix and fall
            # back to a dense matmul.
            W_decompressed = torch.zeros(m, n * _D4_CODESZ, dtype=torch.float16, device=x.device)
            quiptools_cuda.decompress_d4(Qidxs, self.codebook.grid, W_decompressed)
            z = x @ W_decompressed.t()
        # Undo the 1/1024 downscale and apply the per-matrix weight scale.
        x = z.to(torch.float32) * (Wscale * 1024)
        if rank > 0:
            x = x + ABx.to(torch.float32)
        # Apply the output-side Hadamard transform and sign vector SV, then
        # restore the original batch shape.
        return (matmul_hadU_cuda(x, had_right, K_right) * SV).view(*input.shape[:-1], m)
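

# Self-check sketch (not part of the original module; assumes the imports at
# the top resolve, and exercises only the codebook construction, not the CUDA
# kernels). It verifies the distributional claims in the module docstring.
if __name__ == "__main__":
    cb = build_D4_CB()
    assert cb.shape == (256, _D4_CODESZ)
    # Every component is a half-integer (an odd multiple of 1/2).
    assert ((2 * cb) % 2 == 1).all()
    # The components of every codeword sum to an even number.
    assert (cb.sum(dim=1) % 2 == 0).all()
    # Every codeword has squared norm at most 9.
    assert (cb.square().sum(dim=1) <= 9).all()
    # All 256 codewords are distinct.
    assert torch.unique(cb, dim=0).shape[0] == 256
    print("D4 codebook sanity checks passed")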