Spaces:
Running
Running
import torch | |
import argparse | |
import selfies as sf | |
from tqdm import tqdm | |
from transformers import set_seed | |
from src.scripts.mytokenizers import Tokenizer | |
from src.improved_diffusion import gaussian_diffusion as gd | |
from src.improved_diffusion import dist_util, logger | |
from src.improved_diffusion.respace import SpacedDiffusion | |
from src.improved_diffusion.transformer_model import TransformerNetModel | |
from src.improved_diffusion.script_util import ( | |
model_and_diffusion_defaults, | |
add_dict_to_argparser, | |
) | |
from src.scripts.mydatasets import Lang2molDataset_eval | |
def main(): | |
set_seed(42) | |
args = create_argparser().parse_args() | |
# dist_util.setup_dist() | |
logger.configure() | |
args.sigma_small = True | |
# args.diffusion_steps = 200 #500 # DEBUG | |
if args.experiment == "random1": | |
args.experiment = "random" | |
logger.log("creating model and diffusion...") | |
tokenizer = Tokenizer() | |
model = TransformerNetModel( | |
in_channels=args.model_in_channels, | |
model_channels=args.model_model_channels, | |
dropout=args.model_dropout, | |
vocab_size=len(tokenizer), | |
hidden_size=args.model_hidden_size, | |
num_attention_heads=args.model_num_attention_heads, | |
num_hidden_layers=args.model_num_hidden_layers, | |
) | |
model.eval() | |
diffusion = SpacedDiffusion( | |
use_timesteps=[i for i in range(0, args.diffusion_steps, 10)], | |
betas=gd.get_named_beta_schedule("sqrt", args.diffusion_steps), | |
model_mean_type=(gd.ModelMeanType.START_X), | |
model_var_type=((gd.ModelVarType.FIXED_LARGE)), | |
loss_type=gd.LossType.E2E_MSE, | |
rescale_timesteps=True, | |
model_arch="transformer", | |
training_mode="e2e", | |
) | |
model.load_state_dict( | |
dist_util.load_state_dict(args.model_path, map_location="cpu") | |
) | |
pytorch_total_params = sum(p.numel() for p in model.parameters()) | |
logger.log(f"the parameter count is {pytorch_total_params}") | |
model.to(dist_util.dev()) | |
model.eval() | |
logger.log("sampling...") | |
print("--" * 30) | |
print(f"Loading {args.split} set") | |
print("--" * 30) | |
validation_dataset = Lang2molDataset_eval( | |
dir=args.dataset_path, | |
tokenizer=tokenizer, | |
split=args.split, | |
corrupt_prob=0.0, | |
token_max_length=args.token_max_length, | |
dataset_name=args.dataset_name, | |
) | |
print("-------------------- DATASET INFO --------------------") | |
print(f"Size: {len(validation_dataset)} samples") | |
print(f'Sample shape: {validation_dataset[0]["caption_state"].shape}') | |
print(f"Use DDIM: {args.use_ddim}") | |
sample_fn = ( | |
diffusion.p_sample_loop if not args.use_ddim else diffusion.ddim_sample_loop | |
) | |
print(f"Batch size: {args.batch_size}") | |
next_batch_start = args.start | |
next_batch_end = next_batch_start + args.batch_size | |
all_outputs = [] | |
all_caption = [] | |
all_smiles = [] | |
pbar = tqdm( | |
total=len(validation_dataset) // args.batch_size + 1 | |
if len(validation_dataset) % args.batch_size != 0 | |
else len(validation_dataset) // args.batch_size | |
) | |
while True: | |
sample = [ | |
( | |
validation_dataset[i]["caption_state"], | |
validation_dataset[i]["caption_mask"], | |
validation_dataset[i]["caption"], | |
validation_dataset[i]["smiles"], | |
) | |
for i in range(next_batch_start, next_batch_end) | |
] | |
caption_state = torch.concat([i[0] for i in sample], dim=0) | |
caption_mask = torch.concat([i[1] for i in sample], dim=0) | |
caption = [i[2] for i in sample] | |
smiles = [i[3] for i in sample] | |
outputs = sample_fn( | |
model, | |
(args.batch_size, 256, model.in_channels), | |
clip_denoised=args.clip_denoised, | |
denoised_fn=None, | |
model_kwargs={}, | |
top_p=args.top_p, | |
progress=True, | |
caption=(caption_state, caption_mask), | |
) | |
logits = model.get_logits(torch.tensor(outputs).cuda()) | |
cands = torch.topk(logits, k=1, dim=-1) | |
outputs = cands.indices | |
outputs = outputs.squeeze(-1) | |
outputs = tokenizer.decode(outputs) | |
with open(args.outputdir, "a") as f: | |
for i, x in enumerate(outputs): | |
f.write( | |
caption[i] | |
+ "\t" | |
+ smiles[i] | |
+ "\t" | |
+ sf.decoder(x.replace("<pad>", "").replace("</s>", "")) | |
+ "\n" | |
) | |
all_outputs += outputs | |
all_caption += caption | |
all_smiles += smiles | |
next_batch_start = next_batch_end | |
next_batch_end = min(next_batch_end + args.batch_size, len(validation_dataset)) | |
pbar.update(1) | |
if next_batch_start == len(validation_dataset): | |
break | |
with open(args.outputdir.replace(".txt", "_final.txt"), "w") as f: | |
for i, x in enumerate(all_outputs): | |
f.write( | |
all_caption[i] | |
+ "\t" | |
+ all_smiles[i] | |
+ "\t" | |
+ sf.decoder(x.replace("<pad>", "").replace("</s>", "")) | |
+ "\n" | |
) | |
def create_argparser(): | |
defaults = dict( | |
clip_denoised=False, | |
mbr_sample=1, | |
model_path="", | |
model_arch="conv-unet", | |
verbose="yes", | |
) | |
text_defaults = dict( | |
modality="text", | |
dataset_name="", | |
dataset_config_name="wikitext-2-raw-v1", | |
dataset_path="dataset", | |
experiment="gpt2_pre_compress", | |
model_arch="trans-unet", | |
model_in_channels=32, | |
model_model_channels=128, | |
model_dropout=0.1, | |
model_hidden_size=1024, | |
model_num_attention_heads=16, | |
model_num_hidden_layers=12, | |
preprocessing_num_workers=1, | |
emb_scale_factor=1.0, | |
clamp="clamp", | |
split="validation", | |
model_path="", | |
use_ddim=False, | |
batch_size=16, | |
top_p=1.0, | |
outputdir="output.txt", | |
diffusion_steps=2000, | |
token_max_length=256, | |
start=0, | |
) | |
defaults.update(model_and_diffusion_defaults()) | |
defaults.update(text_defaults) | |
parser = argparse.ArgumentParser() | |
add_dict_to_argparser(parser, defaults) | |
return parser | |
if __name__ == "__main__": | |
main() | |