"""GPTQ weight calibration helpers.

The module wraps the quantizable layers of selected blocks, accumulates a
Hessian estimate from the inputs each layer sees, and then rewrites the layer
weights with error-compensated, per-channel symmetric fake quantization.
"""
import contextlib
import logging
import math
from typing import List, Optional

import torch
import transformers
from torch import nn

LOGGER = logging.getLogger(__name__)
QUANT_LAYERS = [nn.Linear, nn.Conv2d, transformers.Conv1D]


def is_transformer_conv1d(layer):
    """Return True when *layer* is a ``transformers.Conv1D`` module."""
    return isinstance(layer, transformers.Conv1D)


def _weight_as_matrix(layer):
    """Return ``layer.weight.data`` as the 2-D matrix the quantizer works on.

    ``nn.Conv2d`` weights are flattened from dim 1 onwards and
    ``transformers.Conv1D`` weights are transposed; other layers are returned
    unchanged. The result may share storage with the layer's weight, so clone
    it before modifying it in place.
    """
    weight = layer.weight.data
    if isinstance(layer, nn.Conv2d):
        weight = weight.flatten(1)
    if is_transformer_conv1d(layer):
        weight = weight.t()
    return weight


# These two functions only work on per-channel symmetric quantization for weight
def get_weight_scale(weight, weight_bit_width):
    """Return the per-row symmetric quantization scale of *weight* in fp16.

    The scale of a row is its largest absolute value divided by
    ``2 ** (weight_bit_width - 1) - 1``.
    """
    max_level = (2 ** (weight_bit_width - 1)) - 1
    return (weight.abs().max(dim=-1).values / max_level).half()


def fake_quantize_weight(weight, weight_scale):
    """Round *weight* to the grid defined by the per-row *weight_scale*.

    Args:
        weight: 2-D tensor whose first dimension matches ``weight_scale``.
        weight_scale: 1-D tensor holding one scale per row of *weight*.

    Returns:
        ``round(weight / scale) * scale`` with the scale broadcast over the
        columns. Rows whose scale is zero are returned as zeros.
    """
    weight_scale = weight_scale[:, None]
    # A zero scale (all-zero row, or a tiny row whose scale underflowed in
    # fp16) would otherwise produce 0/0 or x/0 and therefore NaN. Dividing by
    # one instead is harmless because the result is multiplied by the zero
    # scale again.
    safe_scale = torch.where(weight_scale == 0, torch.ones_like(weight_scale), weight_scale)
    return torch.round(weight / safe_scale) * weight_scale


class GPTQLayerWrapper:
    """Accumulates input statistics for one layer and quantizes its weight."""

    def __init__(self, layer_name, layer, weight_bit_width):
        """Prepare an empty Hessian estimate for *layer*.

        Args:
            layer_name: Name used in log messages.
            layer: Module whose ``weight`` will be quantized.
            weight_bit_width: Bit width passed to ``get_weight_scale``.
        """
        self.layer_name = layer_name
        self.layer = layer
        self.device = layer.weight.device
        # The Hessian is square in the number of columns of the matrix that
        # ``quant_weight`` processes, so it has to be taken from the reshaped
        # weight: for Conv2d / transformers.Conv1D, ``weight.shape[1]`` is not
        # that number.
        columns = _weight_as_matrix(layer).shape[1]
        self.columns = columns
        self.H = torch.zeros((columns, columns), device=self.device)
        self.nsamples = 0
        self.is_record = True
        self.weight_bit_width = weight_bit_width
        self.weight_scale = None

    def record_h(self, x):
        """Fold the layer input *x* into the running Hessian estimate.

        Does nothing while ``is_record`` is False. After the update ``H``
        equals ``2 / nsamples`` times the sum of ``x @ x.T`` over everything
        recorded so far, where ``nsamples`` counts the size of the first
        dimension of each (at least 3-D) input.
        """
        if not self.is_record:
            return

        x = x.detach().clone()
        if len(x.shape) == 2:
            x = x.unsqueeze(0)
        batch = x.shape[0]

        if isinstance(self.layer, nn.Linear) or is_transformer_conv1d(self.layer):
            # Collapse every leading dimension so the transpose below is
            # always applied to a 2-D (samples, features) matrix.
            x = x.reshape((-1, x.shape[-1]))
            x = x.t()

        if isinstance(self.layer, nn.Conv2d):
            unfold = nn.Unfold(
                self.layer.kernel_size,
                dilation=self.layer.dilation,
                padding=self.layer.padding,
                stride=self.layer.stride,
            )
            x = unfold(x)
            x = x.permute([1, 0, 2])
            x = x.flatten(1)

        # Rescale the old estimate so that old and new samples are weighted
        # equally in the running average.
        self.H *= self.nsamples / (self.nsamples + batch)
        self.nsamples += batch
        x = math.sqrt(2 / self.nsamples) * x.float()
        self.H += x.matmul(x.t())

    def quant_weight(self, blocksize=128, percdamp=.01, groupsize=-1):
        """Replace the layer weight with its GPTQ fake-quantized version.

        Args:
            blocksize: Number of columns processed per outer iteration.
            percdamp: Fraction of the mean Hessian diagonal added to the
                diagonal before inversion.
            groupsize: Must be ``-1``; group quantization is not implemented.

        Raises:
            RuntimeError: If ``groupsize`` is not ``-1``.

        ``weight_scale`` is always stored on the wrapper. If the Hessian
        cannot be inverted, a warning is logged and the layer weight is left
        untouched. The accumulated Hessian is modified in place and released
        before returning.
        """
        if groupsize != -1:
            raise RuntimeError("Group quantization of gptq quantizer is not supported for now")

        weight = _weight_as_matrix(self.layer).clone().float()
        weight_scale = get_weight_scale(weight, self.weight_bit_width)
        # todo: use buffer to store scale
        self.weight_scale = weight_scale

        H = self.H
        # Columns that never received any input carry no information: pin
        # their diagonal to one and drop the corresponding weights.
        dead = torch.diag(H) == 0
        H[dead, dead] = 1
        weight[:, dead] = 0

        damp = percdamp * torch.mean(torch.diag(H))
        diag = torch.arange(self.columns, device=self.device)
        H[diag, diag] += damp
        try:
            H = torch.linalg.cholesky(H)
            H = torch.cholesky_inverse(H)
            hinv = torch.linalg.cholesky(H, upper=True)
        except Exception:
            hinv = None
        if hinv is None or hinv.isnan().any():
            LOGGER.warning(
                "Warning: cannot do compression on layer %s because of inverse error", self.layer_name
            )
            # Release the Hessian here as well; the success path does the same.
            del self.H
            return

        Q = torch.zeros_like(weight)
        for i1 in range(0, self.columns, blocksize):
            i2 = min(i1 + blocksize, self.columns)

            w1 = weight[:, i1:i2].clone()
            q1 = torch.zeros_like(w1)
            total_err = torch.zeros_like(w1)
            hinv1 = hinv[i1:i2, i1:i2]

            for i in range(i2 - i1):
                w = w1[:, i]
                d = hinv1[i, i]
                q = fake_quantize_weight(w.unsqueeze(1), weight_scale).flatten()
                q1[:, i] = q
                err = (w - q) / d
                # Push the rounding error of this column onto the columns of
                # the block that are not quantized yet.
                w1[:, i:] -= err.unsqueeze(1).matmul(hinv1[i, i:].unsqueeze(0))
                total_err[:, i] = err

            Q[:, i1:i2] = q1
            # Propagate the accumulated block error to all remaining columns.
            weight[:, i2:] -= total_err.matmul(hinv[i1:i2, i2:])

        if torch.cuda.is_available():
            torch.cuda.synchronize()

        if is_transformer_conv1d(self.layer):
            Q = Q.t()
        shape = self.layer.weight.shape
        dtype = self.layer.weight.data.dtype
        del self.layer.weight
        self.layer.weight = nn.Parameter(Q.reshape(shape).to(dtype), requires_grad=False)
        del self.H


class GPTQBlockWrapper:
    """Groups the layer wrappers of one block and owns their forward hooks."""

    def __init__(self, block_name: str, block: nn.Module, weight_bit_width=8):
        """Wrap every ``QUANT_LAYERS`` module found inside *block*.

        A forward pre-hook is registered on each wrapped layer; it feeds the
        first positional input of the layer to ``GPTQLayerWrapper.record_h``.
        """
        self.layer_wrappers = {}
        self.hook_handles = []
        # block order in the whole network
        self.order = 0
        self.block_name = block_name

        def get_hook(layer_name):
            def record_hook(_, x):
                self.layer_wrappers[layer_name].record_h(x[0])
            return record_hook

        for layer_name, layer in block.named_modules():
            if not isinstance(layer, tuple(QUANT_LAYERS)):
                continue
            # ``named_modules`` yields an empty name for the block itself.
            full_layer_name = f"{block_name}.{layer_name}" if layer_name else f"{block_name}"
            self.layer_wrappers[full_layer_name] = GPTQLayerWrapper(full_layer_name, layer, weight_bit_width)
            handle = layer.register_forward_pre_hook(get_hook(full_layer_name))
            self.hook_handles.append(handle)

    def quant_block(self):
        """Quantize every wrapped layer, then remove the recording hooks."""
        try:
            for wrapper in self.layer_wrappers.values():
                wrapper.quant_weight()
        finally:
            # Remove the hooks even when a layer fails, so they cannot fire
            # again on a wrapper whose Hessian has already been released.
            for handle in self.hook_handles:
                handle.remove()

    def set_order(self, idx):
        """Store the position of this block in the execution order."""
        self.order = idx

    def get_order(self):
        """Return the position previously stored with ``set_order``."""
        return self.order

    def enable(self):
        """Turn Hessian recording on for every wrapped layer."""
        for layer_wrapper in self.layer_wrappers.values():
            layer_wrapper.is_record = True

    def disable(self):
        """Turn Hessian recording off for every wrapped layer."""
        for layer_wrapper in self.layer_wrappers.values():
            layer_wrapper.is_record = False


class GPTQuantizer:
    """Drives block-by-block GPTQ calibration of a model."""

    def __init__(self, block_type: Optional[List[type]] = None):
        """Remember which module classes are calibrated as whole blocks."""
        self.gptq_block_wrappers = {}
        self.block_type = block_type

    def wrap_model(self, model: nn.Module, weight_bit_width=8):
        """Wrap every sub-module of *model* whose type is in ``block_type``.

        The module tree is searched recursively; matching modules are wrapped
        and not searched further. Wrappers are stored in
        ``gptq_block_wrappers`` under the dotted path of the block. Blocks
        without any quantizable layer are skipped. Returns *model* itself.
        """
        # ``block_type`` defaults to None; treat that as "no block types".
        block_types = tuple(self.block_type or ())

        def wrap_block(m, prefix=""):
            for name, child in m.named_children():
                child_prefix = f"{prefix}.{name}" if prefix else name
                if not isinstance(child, block_types):
                    wrap_block(child, child_prefix)
                    continue
                wrapper = GPTQBlockWrapper(child_prefix, child, weight_bit_width)
                if not wrapper.layer_wrappers:
                    # Nothing to calibrate; such a block would also break
                    # ``record_order``, which needs one layer to hook.
                    LOGGER.debug("Skip block %s: it has no quantizable layer", child_prefix)
                    continue
                # Key by the full path: the local name alone is not unique
                # when equally named blocks live under different parents.
                self.gptq_block_wrappers[child_prefix] = wrapper
                LOGGER.debug("Calibrate block %s as a whole block in GPTQ", child_prefix)

        wrap_block(model)
        return model

    @property
    def calibration_iters(self):
        """Number of calibration passes, one per wrapped block."""
        return len(self.gptq_block_wrappers)

    @contextlib.contextmanager
    def record_order(self):
        """Context manager that records the order in which blocks execute.

        While active, Hessian recording is disabled and a pre-hook on one
        layer of every block notes when the block is first reached. On exit
        the observed orders are stored on the block wrappers, the temporary
        hooks are removed and recording is enabled again. Exceptions raised
        inside the ``with`` body are logged as warnings and suppressed.
        """
        counter = 0
        record_handles = []
        orders = {}
        try:
            def get_record_order_hook(block_name):
                def record_hook(*args, **kwargs):
                    nonlocal counter
                    if block_name not in orders:
                        orders[block_name] = counter
                        counter += 1
                return record_hook

            for block_name, block_wrapper in self.gptq_block_wrappers.items():
                # disable the record
                block_wrapper.disable()

                one_layer_wrapper_in_block = list(block_wrapper.layer_wrappers.values())[0]
                handle = one_layer_wrapper_in_block.layer.register_forward_pre_hook(
                    get_record_order_hook(block_name)
                )
                record_handles.append(handle)
            yield
        except Exception as e:
            LOGGER.warning(e)
        finally:
            for block_name, order in orders.items():
                self.gptq_block_wrappers[block_name].set_order(order)

            for handle in record_handles:
                handle.remove()

            for block_wrapper in self.gptq_block_wrappers.values():
                # re-enable the record
                block_wrapper.enable()

    @contextlib.contextmanager
    def start_calib_iter(self, i):
        """Context manager for the calibration pass of the block with order *i*.

        Recording is enabled for that block and disabled for all others. On
        exit the block is quantized, including when the ``with`` body raised.

        Raises:
            RuntimeError: If no wrapped block has order *i*.
        """
        assert i < len(self.gptq_block_wrappers)
        target_block_wrapper = None
        for block_wrapper in self.gptq_block_wrappers.values():
            if block_wrapper.get_order() == i:
                block_wrapper.enable()
                target_block_wrapper = block_wrapper
            else:
                block_wrapper.disable()

        if target_block_wrapper is None:
            # Fail before the calibration data is run instead of failing in
            # ``finally`` and hiding whatever happened inside the body.
            raise RuntimeError(f"No GPTQ block wrapper has execution order {i}")

        try:
            yield
        finally:
            target_block_wrapper.quant_block()

    def release_reference(self):
        """Drop the wrappers' references to the layers and empty the CUDA cache."""
        # delete reference so that `torch.cuda.empty_cache()` can
        # release all the gpu memory cache used during calibration
        for block_wrapper in self.gptq_block_wrappers.values():
            for layer_wrapper in block_wrapper.layer_wrappers.values():
                if hasattr(layer_wrapper, "layer"):
                    del layer_wrapper.layer
        torch.cuda.empty_cache()


def locate_parent(root: nn.Module, full_path: str):
    """Resolve a dotted module path to ``(parent_module, attribute_name)``.

    Every component of *full_path* except the last is looked up with
    ``getattr`` starting from *root*; the last component is returned as is.
    """
    parent = root
    path = full_path.split('.')
    for p in path[:-1]:
        parent = getattr(parent, p)
    return parent, path[-1]


@torch.no_grad()
def gptq_quantize(model, tokenizer, weight_bit_width, calib_data):
    """Calibrate *model* with GPTQ and swap its layers for ``QuantizedLinear``.

    Args:
        model: Model containing ``GLMBlock`` modules; it is called through
            ``model.chat(tokenizer, prompt, history=[])``.
        tokenizer: Passed through to ``model.chat``.
        weight_bit_width: Bit width of the quantized weights.
        calib_data: Indexable collection of prompts. The first prompt is used
            to discover the block execution order, and all prompts are run
            once per block for calibration.

    The model is modified in place and nothing is returned.
    """
    from .modeling_chatglm import GLMBlock
    from .quantization import QuantizedLinear

    quantizer = GPTQuantizer([GLMBlock])
    calib_model = quantizer.wrap_model(model, weight_bit_width)
    with quantizer.record_order():
        calib_model.chat(tokenizer, calib_data[0], history=[])

    LOGGER.info("Start doing calibration using GPTQ ")
    for i in range(quantizer.calibration_iters):
        LOGGER.info("Process: %d/%d", i + 1, quantizer.calibration_iters)
        # todo: should add early return to speed up the calibration
        # todo: add cpu offload to reduce the gpu memory requirements.
        with quantizer.start_calib_iter(i):
            for prompt in calib_data:
                model.chat(tokenizer, prompt, history=[])

    # replace the fp16 linear with quantized linear
    for block_wrapper in quantizer.gptq_block_wrappers.values():
        for layer_name, layer_wrapper in block_wrapper.layer_wrappers.items():
            layer = layer_wrapper.layer
            parent, name_in_parent = locate_parent(model, layer_name)
            quantized_layer = QuantizedLinear(
                weight_bit_width=weight_bit_width,
                weight_tensor=layer.weight,
                bias_tensor=layer.bias,
                weight_scale=layer_wrapper.weight_scale,
                in_features=layer.in_features,
                out_features=layer.out_features,
                bias=True,
                dtype=torch.half,
                device=layer_wrapper.device,
                empty_init=False
            )
            parent.add_module(name_in_parent, quantized_layer)

    # release the memory cache used during calibration
    quantizer.release_reference()
    return