import argparse
from collections.abc import Mapping
from pathlib import Path

import torch
def prune_checkpoint(checkpoint, drop_keys=('optimizer_states',)):
    """Return a shallow copy of *checkpoint* without the keys in *drop_keys*.

    Args:
        checkpoint: The top-level mapping loaded from the checkpoint file.
        drop_keys: Top-level keys to omit. Defaults to ``('optimizer_states',)``.

    Returns:
        A new ``dict`` holding every other top-level entry in its original
        order. Values are shared with *checkpoint*, not copied.

    Raises:
        TypeError: If *checkpoint* is not a mapping, so it has no top-level
            keys to prune.
    """
    if not isinstance(checkpoint, Mapping):
        raise TypeError(
            f"expected a mapping checkpoint, got {type(checkpoint).__name__}"
        )
    return {k: v for k, v in checkpoint.items() if k not in drop_keys}


def pruned_path(path):
    """Return *path* with its file name prefixed by ``pruned-``.

    Only the final component is prefixed, so the result stays next to the
    input (``runs/a.ckpt`` -> ``runs/pruned-a.ckpt``) instead of prefixing the
    whole path (which would yield a nonexistent ``pruned-runs/`` directory).
    """
    p = Path(path)
    return p.with_name(f"pruned-{p.name}")


def main(argv=None):
    """Command-line entry point: prune the ``--input`` checkpoint and save it.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        The ``Path`` the pruned checkpoint was written to.
    """
    parser = argparse.ArgumentParser(
        description="Remove the 'optimizer_states' entry from a checkpoint."
    )
    parser.add_argument('--input', '-I', type=str, help='Input file to prune', required=True)
    parser.add_argument(
        '--output', '-O', type=str, default=None,
        help='Output file (default: pruned-<name> next to the input)',
    )
    args = parser.parse_args(argv)

    # NOTE: torch.load unpickles the file, which can execute arbitrary code;
    # only run this on checkpoints from a trusted source.
    # map_location='cpu' lets checkpoints saved on a GPU be pruned on machines
    # without that device, and avoids allocating device memory for a file copy.
    checkpoint = torch.load(args.input, map_location='cpu')
    new_sd = prune_checkpoint(checkpoint)

    out_path = Path(args.output) if args.output else pruned_path(args.input)
    torch.save(new_sd, out_path)
    print(f"Saved pruned checkpoint to {out_path}")
    return out_path


if __name__ == "__main__":
    main()