# Lang2mol-Diff / inference.py
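# Inference script: samples SELFIES token embeddings from a text-conditioned
# diffusion model for each caption in the evaluation split, rounds them back to
# discrete tokens, and writes (caption, ground-truth SMILES, generated SMILES)
# triples to a tab-separated output file.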
import argparse

import torch
import selfies as sf
from tqdm import tqdm
from transformers import set_seed
from src.scripts.mytokenizers import Tokenizer
from src.improved_diffusion import gaussian_diffusion as gd
from src.improved_diffusion import dist_util, logger
from src.improved_diffusion.respace import SpacedDiffusion
from src.improved_diffusion.transformer_model import TransformerNetModel
from src.improved_diffusion.script_util import (
model_and_diffusion_defaults,
add_dict_to_argparser,
)
from src.scripts.mydatasets import Lang2molDataset_eval


def main():
set_seed(42)
args = create_argparser().parse_args()
# dist_util.setup_dist()
logger.configure()
args.sigma_small = True
if args.experiment == "random1":
args.experiment = "random"
logger.log("creating model and diffusion...")
tokenizer = Tokenizer()
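    # Build the transformer denoising network; the vocabulary size comes from the tokenizer.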
model = TransformerNetModel(
in_channels=args.model_in_channels,
model_channels=args.model_model_channels,
dropout=args.model_dropout,
vocab_size=len(tokenizer),
hidden_size=args.model_hidden_size,
num_attention_heads=args.model_num_attention_heads,
num_hidden_layers=args.model_num_hidden_layers,
)
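    # Respace the diffusion process: keep every 10th timestep so sampling runs
    # in a tenth of the trained schedule's steps.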
diffusion = SpacedDiffusion(
        use_timesteps=list(range(0, args.diffusion_steps, 10)),
betas=gd.get_named_beta_schedule("sqrt", args.diffusion_steps),
        model_mean_type=gd.ModelMeanType.START_X,
        model_var_type=gd.ModelVarType.FIXED_LARGE,
loss_type=gd.LossType.E2E_MSE,
rescale_timesteps=True,
model_arch="transformer",
training_mode="e2e",
)
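    # Load the trained weights on CPU first, then move the model to the sampling device.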
model.load_state_dict(
dist_util.load_state_dict(args.model_path, map_location="cpu")
)
pytorch_total_params = sum(p.numel() for p in model.parameters())
logger.log(f"the parameter count is {pytorch_total_params}")
model.to(dist_util.dev())
model.eval()
logger.log("sampling...")
print("--" * 30)
print(f"Loading {args.split} set")
print("--" * 30)
validation_dataset = Lang2molDataset_eval(
dir=args.dataset_path,
tokenizer=tokenizer,
split=args.split,
corrupt_prob=0.0,
token_max_length=args.token_max_length,
dataset_name=args.dataset_name,
)
print("-------------------- DATASET INFO --------------------")
print(f"Size: {len(validation_dataset)} samples")
print(f'Sample shape: {validation_dataset[0]["caption_state"].shape}')
print(f"Use DDIM: {args.use_ddim}")
    sample_fn = (
        diffusion.ddim_sample_loop if args.use_ddim else diffusion.p_sample_loop
    )
print(f"Batch size: {args.batch_size}")
    next_batch_start = args.start
    # Clamp the first window so a dataset smaller than one batch does not raise an IndexError.
    next_batch_end = min(next_batch_start + args.batch_size, len(validation_dataset))
all_outputs = []
all_caption = []
all_smiles = []
    # Ceiling division over the samples remaining after args.start.
    num_batches = (
        len(validation_dataset) - args.start + args.batch_size - 1
    ) // args.batch_size
    pbar = tqdm(total=num_batches)
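    # Walk the eval set in windows of batch_size, starting at args.start.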
while True:
sample = [
(
validation_dataset[i]["caption_state"],
validation_dataset[i]["caption_mask"],
validation_dataset[i]["caption"],
validation_dataset[i]["smiles"],
)
for i in range(next_batch_start, next_batch_end)
]
caption_state = torch.concat([i[0] for i in sample], dim=0)
caption_mask = torch.concat([i[1] for i in sample], dim=0)
caption = [i[2] for i in sample]
smiles = [i[3] for i in sample]
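        # Run the reverse diffusion loop, conditioned on the encoded caption and its mask.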
outputs = sample_fn(
model,
            # Actual window size (the last batch may be short), sequence length, embedding width.
            (len(sample), args.token_max_length, model.in_channels),
clip_denoised=args.clip_denoised,
denoised_fn=None,
model_kwargs={},
top_p=args.top_p,
progress=True,
caption=(caption_state, caption_mask),
)
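        # Round the denoised embeddings back to discrete tokens via argmax (top-1) over the LM-head logits.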
        logits = model.get_logits(torch.as_tensor(outputs).to(dist_util.dev()))
cands = torch.topk(logits, k=1, dim=-1)
outputs = cands.indices
outputs = outputs.squeeze(-1)
outputs = tokenizer.decode(outputs)
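        # Append this batch's (caption, ground-truth SMILES, generated SMILES) rows;
        # sf.decoder converts the cleaned SELFIES string to SMILES.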
with open(args.outputdir, "a") as f:
for i, x in enumerate(outputs):
f.write(
caption[i]
+ "\t"
+ smiles[i]
+ "\t"
+ sf.decoder(x.replace("<pad>", "").replace("</s>", ""))
+ "\n"
)
all_outputs += outputs
all_caption += caption
all_smiles += smiles
next_batch_start = next_batch_end
next_batch_end = min(next_batch_end + args.batch_size, len(validation_dataset))
pbar.update(1)
if next_batch_start == len(validation_dataset):
break
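    pbar.close()
    # Consolidate every batch's results into a single final file.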
with open(args.outputdir.replace(".txt", "_final.txt"), "w") as f:
for i, x in enumerate(all_outputs):
f.write(
all_caption[i]
+ "\t"
+ all_smiles[i]
+ "\t"
+ sf.decoder(x.replace("<pad>", "").replace("</s>", ""))
+ "\n"
)


def create_argparser():
defaults = dict(
clip_denoised=False,
mbr_sample=1,
model_path="",
model_arch="conv-unet",
verbose="yes",
)
text_defaults = dict(
modality="text",
dataset_name="",
dataset_config_name="wikitext-2-raw-v1",
dataset_path="dataset",
experiment="gpt2_pre_compress",
model_arch="trans-unet",
model_in_channels=32,
model_model_channels=128,
model_dropout=0.1,
model_hidden_size=1024,
model_num_attention_heads=16,
model_num_hidden_layers=12,
preprocessing_num_workers=1,
emb_scale_factor=1.0,
clamp="clamp",
split="validation",
model_path="",
use_ddim=False,
batch_size=16,
top_p=1.0,
outputdir="output.txt",
diffusion_steps=2000,
token_max_length=256,
start=0,
)
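    # Later updates win: model/diffusion defaults override the base dict,
    # and text_defaults override both.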
defaults.update(model_and_diffusion_defaults())
defaults.update(text_defaults)
parser = argparse.ArgumentParser()
add_dict_to_argparser(parser, defaults)
return parser
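

# Example invocation (paths and file names below are illustrative, not from the repo):
#   python inference.py --model_path checkpoints/model.pt --dataset_path dataset \
#       --split validation --batch_size 16 --token_max_length 256 --outputdir output.txt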
if __name__ == "__main__":
main()