"""Collect per-layer statistics for the linear layers of a causal LM.

Calibration sequences are embedded once and then pushed through the decoder
one transformer layer at a time. Each layer is copied to every available GPU
(one ``forward_layer`` worker process per GPU). Hooks from
``lib.utils.register_H_hook`` gather statistics on ``q_proj``, ``o_proj``,
``up_proj`` and ``down_proj``. An ``accumulate`` process merges the per-GPU
results and saves them as ``{layer}_{key}.pt``. The workers write the layer's
outputs back into ``dev_emb`` in place, so those outputs become the next
layer's inputs.
"""
import os
import datetime
import random
import argparse
import subprocess
from copy import deepcopy
from tqdm import tqdm

# Configure the CUDA caching allocator before torch is imported.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"

import numpy
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerFast
from datasets import load_dataset
import torch.multiprocessing as mp

# import data_utils
from lib import utils

parser = argparse.ArgumentParser()
parser.add_argument('--seed', default=0, type=int,
                    help='seed for torch, random and numpy')
parser.add_argument('--batch_size', default=2, type=int,
                    help='sequences per layer forward call')
parser.add_argument('--devset_size', default=256, type=int,
                    help='number of calibration sequences (passed to utils.sample_devset)')
parser.add_argument('--ctx_size', default=4096, type=int,
                    help='tokens per calibration sequence')
parser.add_argument('--base_model', default='meta-llama/Llama-2-70b-hf', type=str,
                    help='Hugging Face model id or local path')
parser.add_argument('--save_path', default='hessians/llama2_70b', type=str,
                    help='output directory for statistics and activation snapshots')
parser.add_argument('--scratch_path', default=None, type=str,
                    help='if set, write files here first and rsync them to save_path in the background')
parser.add_argument('--chunk_size', default=256, type=int,
                    help='sequences per work item handed to a GPU worker')
parser.add_argument('--async_copy_speed', default=-1, type=int,
                    help='rsync --bwlimit value for background copies; <= 0 means unlimited')
parser.add_argument('--act_save_rate', default=4, type=int,
                    help='with --save_activations, snapshot activations every N layers')
parser.add_argument('--save_activations', action='store_true',
                    help='periodically save dev_emb so an interrupted run can resume')
parser.add_argument('--sample_proc', default=4, type=int,
                    help='nproc passed to utils.sample_devset')


def move_fn(in_q, async_copy_speed):
    """Background mover: copy ``(src, tgt)`` pairs with rsync, then delete ``src``.

    Runs until it receives ``None``. This keeps slow writes to the final
    destination off the main loop.

    Args:
        in_q: queue of ``(src, tgt)`` path tuples, terminated by ``None``.
        async_copy_speed: passed to ``rsync --bwlimit`` when > 0.

    The source is removed only if rsync reports success. A failed copy leaves
    the file in scratch instead of silently losing it.
    """
    while True:
        item = in_q.get()
        if item is None:
            return
        src, tgt = item
        cmd = ['rsync']
        if async_copy_speed > 0:
            cmd.append(f'--bwlimit={async_copy_speed}')
        cmd += [src, tgt]
        try:
            # List form, no shell: paths containing spaces/metacharacters are safe.
            ok = subprocess.run(cmd).returncode == 0
        except OSError as err:  # e.g. rsync is not installed
            print(f'could not run rsync for {src}: {err}')
            ok = False
        if not ok:
            print(f'failed to move {src} to {tgt}; keeping {src}')
            continue
        os.remove(src)
        print(f'moved {src} to {tgt}')


def forward_layer(layer, position_ids, attention_mask, bs, device, in_q, out_q):
    """GPU worker: run one decoder layer over chunks of activations.

    Args:
        layer: the decoder layer to run (moved to ``device``).
        position_ids, attention_mask: layer inputs shared by every batch.
        bs: batch size; every chunk length must be a multiple of it.
        device: CUDA device index for this worker.
        in_q: queue of activation chunks, terminated by ``None``.
        out_q: receives one dict ``{'qkv', 'o', 'up', 'down'}`` of hook results.

    Each chunk is overwritten in place with the layer's outputs. The chunks are
    slices of the parent's ``share_memory_()`` tensor, so this is presumably
    how outputs flow back to the parent (confirm for your start method/queue).
    """
    # Spawned processes do not inherit the parent's grad mode.
    torch.set_grad_enabled(False)
    layer = layer.to(device)
    position_ids = position_ids.to(device)
    attention_mask = attention_mask.to(device)
    # NOTE(review): 'qkv' hooks only q_proj, presumably because q/k/v share
    # the same input; confirm against the model definition.
    done_qkv = utils.register_H_hook(layer.self_attn.q_proj, device)
    done_o = utils.register_H_hook(layer.self_attn.o_proj, device)
    done_up = utils.register_H_hook(layer.mlp.up_proj, device)
    done_down = utils.register_H_hook(layer.mlp.down_proj, device)
    while True:
        dev_emb = in_q.get()
        if dev_emb is None:
            # Release GPU memory before reporting the collected statistics.
            layer = layer.cpu()
            position_ids = position_ids.cpu()
            attention_mask = attention_mask.cpu()
            out_q.put({'qkv': done_qkv(), 'o': done_o(), 'up': done_up(), 'down': done_down()})
            return
        assert len(dev_emb) % bs == 0
        for i in range(len(dev_emb) // bs):
            dev_emb[i * bs:(i + 1) * bs] = layer(dev_emb[i * bs:(i + 1) * bs].to(device),
                                                 position_ids=position_ids,
                                                 attention_mask=attention_mask,
                                                 use_cache=False,
                                                 output_attentions=False)[0].cpu()


def accumulate(in_q, move_q, ngpus, args, transformer_layer_index):
    """Merge hook results from ``ngpus`` workers and save one file per hook key.

    Each worker result maps key -> ``(H_sum, mu_sum, count)``, as indexed below.
    The sums are added together, divided by the total count, and then
    ``mu mu^T`` is subtracted from ``H``. The saved dict holds ``flatH``
    (``utils.sym_to_flat`` of H), ``mu``, ``n`` (H's size) and ``ct`` (the
    count). If ``args.scratch_path`` is set, files are written there and queued
    on ``move_q`` for transfer to ``args.save_path``.
    """
    Hs = {}
    mus = {}
    cts = {}
    for i in range(ngpus):
        out = in_q.get()
        if i == 0:
            for key in out:
                Hs[key] = torch.zeros(out[key][0].shape, dtype=out[key][0].dtype)
                mus[key] = torch.zeros(out[key][1].shape, dtype=out[key][1].dtype)
                cts[key] = 0
        for key in out:
            Hs[key].add_(out[key][0])
            mus[key].add_(out[key][1])
            cts[key] += out[key][2]

    for key in Hs:
        mus[key].div_(cts[key])
        Hs[key].div_(cts[key])
        # H <- H - mu mu^T
        Hs[key].addmm_(-mus[key].unsqueeze(-1), mus[key].unsqueeze(0))

        filename = f"{transformer_layer_index}_{key}.pt"
        final_path = f"{args.save_path}/{filename}"
        write_path = final_path if args.scratch_path is None else f"{args.scratch_path}/{filename}"
        torch.save(
            {
                'flatH': utils.sym_to_flat(Hs[key].to(torch.float32)),
                'mu': mus[key].to(torch.float32),
                'n': Hs[key].shape[0],
                'ct': cts[key]
            }, write_path)
        if args.scratch_path is not None:
            move_q.put((write_path, final_path))


def main(args):
    """Embed the calibration set and collect statistics layer by layer.

    Resumes from ``{save_path}/dev_activations.pt`` when present (skipping
    layers ``<= after_layer``). Otherwise it samples a fresh devset and embeds
    it. With ``--save_activations``, ``dev_emb`` is snapshotted every
    ``act_save_rate`` layers and after the last layer.
    """
    print("loading model...")
    model = AutoModelForCausalLM.from_pretrained(args.base_model, torch_dtype="auto", low_cpu_mem_usage=True)
    print("loaded model!")
    tokenizer = AutoTokenizer.from_pretrained(args.base_model, use_fast=True)
    tokenizer.pad_token = tokenizer.eos_token

    cache_file = f"{args.save_path}/dev_activations.pt"
    if os.path.isfile(cache_file):
        print("loading cached dataset...")
        loaded_dev_activations = torch.load(cache_file)
        after_layer = loaded_dev_activations['after_layer']
        dev_emb = loaded_dev_activations['dev_emb']
        print(f"loaded cached dataset from {loaded_dev_activations['timestamp']}")
    else:
        print("loading dataset...")
        dataset = load_dataset("togethercomputer/RedPajama-Data-1T-Sample", split="train")
        devset = utils.sample_devset(dataset, tokenizer, args.devset_size, args.ctx_size, nproc=args.sample_proc)
        dev_emb = model.model.embed_tokens(devset)
        after_layer = -1
        print("loaded dataset!")

    print(f"dev_emb dtype: {dev_emb.dtype}")
    # Workers write layer outputs into slices of this tensor.
    dev_emb.share_memory_()

    # (batch_size, ctx_size) position ids, identical for every row.
    position_ids = torch.arange(args.ctx_size, dtype=torch.int64)[None, :] + \
        torch.zeros(args.batch_size, args.ctx_size, dtype=torch.int64)
    mask_kwargs = {}
    if hasattr(model.config, 'sliding_window'):  # mistral models
        mask_kwargs['sliding_window'] = model.config.sliding_window
    attention_mask = model.model._prepare_decoder_attention_mask(
        torch.ones(args.batch_size, args.ctx_size, dtype=torch.bool),
        (args.batch_size, args.ctx_size), dev_emb[0:args.batch_size, :, :], 0, **mask_kwargs)

    num_layers = len(model.model.layers)
    if args.scratch_path is not None:
        move_q = mp.Queue()
        move_p = mp.Process(target=move_fn, args=(move_q, args.async_copy_speed))
        move_p.start()
    else:
        move_q = None

    try:
        for transformer_layer_index in range(num_layers):
            if transformer_layer_index <= after_layer:
                print(
                    f"skipping layer {transformer_layer_index} because it is before cached activations at layer {after_layer}"
                )
                continue

            transformer_layer = model.model.layers[transformer_layer_index]
            # Expect seven Linear modules per block (presumably q/k/v/o in
            # attention plus gate/up/down in the MLP).
            assert (len([m for m in transformer_layer.modules() if isinstance(m, torch.nn.Linear)]) == 7)

            chunk_size = min(args.chunk_size, len(dev_emb))
            if chunk_size <= 0:
                raise ValueError(
                    f"chunk_size ({args.chunk_size}) and devset size ({len(dev_emb)}) must be positive")
            ngpus = min(torch.cuda.device_count(), len(dev_emb) // chunk_size)
            if ngpus == 0:
                raise RuntimeError("no CUDA devices found; at least one GPU is required")
            # Validate before spawning workers so a failure cannot leave them blocked.
            assert len(dev_emb) % args.batch_size == 0 and chunk_size % args.batch_size == 0

            # The context manager shuts the manager's server process down;
            # otherwise one process would leak per layer.
            with mp.get_context('spawn').Manager() as manager:
                in_q = manager.Queue()
                out_q = manager.Queue()
                accumulate_proc = mp.Process(target=accumulate,
                                             args=(out_q, move_q, ngpus, args, transformer_layer_index))
                accumulate_proc.start()

                forward_procs = []
                for device in range(ngpus):
                    p = mp.Process(target=forward_layer,
                                   args=(transformer_layer, position_ids, attention_mask, args.batch_size,
                                         device, in_q, out_q))
                    p.start()
                    forward_procs.append(p)

                for start in range(0, len(dev_emb), chunk_size):
                    in_q.put(dev_emb[start:start + chunk_size])
                for _ in range(ngpus):
                    in_q.put(None)  # one stop sentinel per worker
                for p in forward_procs:
                    p.join()
                accumulate_proc.join()

            transformer_layer.cpu()
            model.model.layers[transformer_layer_index] = None
            utils.clean()

            if args.save_activations and (transformer_layer_index % args.act_save_rate == 0 or
                                          transformer_layer_index == num_layers - 1):
                snapshot = {
                    'dev_emb': dev_emb,
                    'after_layer': transformer_layer_index,
                    'timestamp': str(datetime.datetime.now())
                }
                if args.scratch_path is not None:
                    scratch_file = f'{args.scratch_path}/dev_activations.pt'
                    # The mover deletes the scratch copy once it is transferred;
                    # if it is still present, the previous copy is not done yet.
                    if os.path.exists(scratch_file):
                        print('not saving layer since disk is too slow')
                    else:
                        torch.save(snapshot, scratch_file)
                        move_q.put((scratch_file, f'{args.save_path}/dev_activations.pt'))
                else:
                    torch.save(snapshot, f'{args.save_path}/dev_activations.pt')

            print(f"done processing layer {transformer_layer_index}")
    finally:
        # Always stop the mover, even on error, so the interpreter can exit.
        if move_q is not None:
            move_q.put(None)
            move_p.join()


if __name__ == "__main__":
    mp.set_start_method('spawn')
    torch.set_grad_enabled(False)
    args = parser.parse_args()
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    numpy.random.seed(args.seed)
    os.makedirs(args.save_path, exist_ok=True)
    if args.scratch_path is not None:
        os.makedirs(args.scratch_path, exist_ok=True)
    main(args)