import gc
import sys
import time
from pathlib import Path
from typing import Optional

import torch
from datasets import load_dataset

wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))

from lit_llama import LLaMA, Tokenizer
from lit_llama.quantization import GPTQQuantizer
from lit_llama.utils import EmptyInitOnDevice, llama_model_lookup


def get_sample_data():
    """Load one shard of the English C4 dataset and concatenate 1000 randomly chosen
    documents into a single calibration string."""
    traindata = load_dataset(
        "allenai/c4",
        "allenai--c4",
        data_files={"train": "en/c4-train.00000-of-01024.json.gz"},
        split="train",
    )
    txt = "\n".join(
        traindata[i]["text"] for i in torch.randperm(len(traindata))[:1000].tolist()
    )
    return txt


@torch.no_grad()
def llama_blockwise_quantization(
    model, sample_inputs, working_device, *, bits=4, groupsize=-1
):
    """
    This is the classic post-training quantization of all linear layers.
    We quantize the layers in order, i.e. when collecting a layer's input statistics,
    we use the outputs of the previously quantized layers rather than quantizing
    everything at once.
    """
    print(model)
    print(model.config)

    print("Getting inputs for first block")
    model.transformer.wte.to(working_device)
    sample_inputs = sample_inputs.to(working_device)
    inps = model.transformer.wte(sample_inputs)
    model.transformer.wte.to("cpu")
    torch.cuda.empty_cache()

    rope_cache = model.build_rope_cache(sample_inputs)
    mask_cache = model.build_mask_cache(sample_inputs)
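    # precomputed rotary-embedding and causal-mask caches, reused for every block below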

    print("Starting to quantize blocks")
    outs = torch.zeros_like(inps)
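
    # the linear layers quantized inside each transformer block: the fused attention
    # qkv projection, the attention output projection, and the three linear layers of
    # the (SwiGLU-style) MLP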
    submodules_to_process = [
        "attn.c_attn",
        "attn.c_proj",
        "mlp.c_fc1",
        "mlp.c_fc2",
        "mlp.c_proj",
    ]

    for i, block in enumerate(model.transformer.h):
        block.to(working_device)

        for name in submodules_to_process:
            print(i, name, end=" ")
            t0 = time.perf_counter()
            print("collecting stats", end=" ")
            sys.stdout.flush()
            module = block.get_submodule(name)

            gptq = GPTQQuantizer(
                module,
                bits=bits,
                groupsize=groupsize,
                actorder=(groupsize == -1),
            )
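            # with no grouping (groupsize == -1), activation ordering is enabled, i.e.
            # weight columns are quantized in order of the collected input statistics
            # rather than strictly left to right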
            handle = module.register_forward_hook(gptq.collect_input_stats)
            # run the block over all calibration samples so the hook above can record
            # the inputs that reach this submodule
            for j in range(inps.size(0)):
                outs[j : j + 1], _ = block(
                    inps[j : j + 1],
                    rope=rope_cache,
                    mask=mask_cache,
                    max_seq_length=model.config.block_size,
                )
            handle.remove()

            print("quantizing", end=" ")
            sys.stdout.flush()
            q_module, error = gptq.quantize()
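            # `error` is the per-layer quantization error reported by the quantizer;
            # it is printed below as a rough signal of how lossy this layer was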

            # swap the original linear module for its quantized replacement
            pname, dname = name.rsplit(".", 1)
            setattr(block.get_submodule(pname), dname, q_module)

            # free memory before moving on to the next submodule
            del gptq
            gc.collect()
            torch.cuda.empty_cache()
            t1 = time.perf_counter()
            print(f"time {int(t1 - t0 + 0.5)}s quantization error {error:.1f}")

        # run the now fully quantized block to produce the next block's inputs
        for j in range(inps.size(0)):
            outs[j : j + 1], _ = block(
                inps[j : j + 1],
                rope=rope_cache,
                mask=mask_cache,
                max_seq_length=model.config.block_size,
            )

        block.cpu()
        gc.collect()
        torch.cuda.empty_cache()

        # this block's outputs are the next block's inputs; reuse the old input buffer
        inps, outs = outs, inps

    # apply the final layer norm to obtain the inputs of the lm_head
    model.transformer.ln_f.to(working_device)
    for j in range(inps.size(0)):
        outs[j : j + 1] = model.transformer.ln_f(inps[j : j + 1])
    model.transformer.ln_f.to("cpu")
    inps, outs = outs, inps

    model.lm_head.to(working_device)
    gptq = GPTQQuantizer(
        model.lm_head,
        bits=bits,
        groupsize=groupsize,
        actorder=(groupsize == -1),
    )
    handle = model.lm_head.register_forward_hook(gptq.collect_input_stats)
    for j in range(inps.size(0)):
        model.lm_head(inps[j : j + 1])
    handle.remove()
    q_module, error = gptq.quantize()
    model.lm_head = q_module
    model.lm_head.to("cpu")
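

# A minimal sketch of calling the function above directly, assuming `my_model` is a
# loaded LLaMA instance and `calibration_tokens` a (n_samples, block_size) tensor of
# token ids (both names are hypothetical; `main()` below does this end to end):
#
#     llama_blockwise_quantization(my_model, calibration_tokens, "cuda", bits=4)
#     torch.save(my_model.state_dict(), "llama-gptq.4bit.pth")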


def main(
    *,
    checkpoint_path: Path = Path("checkpoints/lit-llama/7B/lit-llama.pth"),
    output_path: Optional[Path] = None,
    tokenizer_path: Path = Path("checkpoints/lit-llama/tokenizer.model"),
    n_samples: int = 128,
    dtype: str = "float32",
    quantize: Optional[str] = None,
) -> None:
    """Quantizes a pre-trained LLaMA model with GPTQ and saves the resulting state dict.

    Args:
        checkpoint_path: The checkpoint path to load.
        output_path: Path to write the quantized model's state dict to.
        tokenizer_path: The tokenizer path to load.
        n_samples: Number of example inputs to use for collecting statistics (default: 128).
        dtype: The dtype to use to load the model.
        quantize: Mode to quantize the model to:
            ``"gptq.int4"``: GPTQ 4-bit mode.
            ``"gptq.int8"``: GPTQ 8-bit mode.
            Note that ``"llm.int8"`` does not need a quantization step.
    """
    assert checkpoint_path.is_file()
    assert tokenizer_path.is_file()
    if output_path is None:
        output_path = checkpoint_path.parent / "llama-gptq.4bit.pth"
    assert output_path.parent.is_dir() and (not output_path.exists() or output_path.is_file())

    device = "cuda"

    dt = getattr(torch, dtype, None)
    if not isinstance(dt, torch.dtype):
        raise ValueError(f"{dtype} is not a valid dtype.")
    dtype = dt

    if quantize == "gptq.int4":
        bits = 4
    elif quantize == "gptq.int8":
        bits = 8
    else:
        raise RuntimeError(f"unknown/unsupported quantization mode {quantize}")
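
    # EmptyInitOnDevice creates the model parameters without running the usual weight
    # initialization; they only get real values when the checkpoint is loaded below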
    with EmptyInitOnDevice(
        device="cpu",
        dtype=dtype,
    ):
        print("Loading model ...", file=sys.stderr)
        t0 = time.time()
        checkpoint = torch.load(checkpoint_path)
        name = llama_model_lookup(checkpoint)
        model = LLaMA.from_name(name)
        model.load_state_dict(checkpoint)
        print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr)

    model.eval()

    tokenizer = Tokenizer(tokenizer_path)

    test_string = get_sample_data()
    encoded_text = tokenizer.encode(
        test_string,
        bos=True,
        eos=False,
    )
    block_size = 2048
    encoded_text = encoded_text[: n_samples * block_size].reshape(n_samples, block_size)
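    # with the defaults this keeps 128 * 2048 = 262,144 tokens of calibration data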

    t0 = time.perf_counter()
    llama_blockwise_quantization(model, encoded_text, device, bits=bits)
    t = time.perf_counter() - t0

    print(
        f"\n\nTime for quantization: {t:.02f} sec total",
        file=sys.stderr,
    )
    print(
        f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB",
        file=sys.stderr,
    )

    torch.save(model.state_dict(), output_path)
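

# Example invocation, assuming this script lives at quantize/gptq.py inside a
# lit-llama checkout (adjust the path to wherever the file is located):
#
#     python quantize/gptq.py --checkpoint_path checkpoints/lit-llama/7B/lit-llama.pth \
#         --tokenizer_path checkpoints/lit-llama/tokenizer.model --quantize gptq.int4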


if __name__ == "__main__":
    from jsonargparse import CLI

    torch.set_float32_matmul_precision("high")
    CLI(main)