# Spaces: Runtime error
import argparse
import glob
import math
import os
import sys
import time

import numpy as np
import torch
from omegaconf import OmegaConf
from PIL import Image
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
from tqdm import trange

from main import instantiate_from_config, DataModuleFromConfig
def save_image(x, path):
    """Write a 3-channel image tensor with values in [-1, 1] to ``path``.

    Args:
        x: tensor of shape ``(3, H, W)``. Values are mapped from [-1, 1] to
            [0, 255] via ``(x + 1) * 127.5`` and clipped, so out-of-range
            values saturate. The tensor is detached and copied to the CPU.
        path: destination file path, passed to ``PIL.Image.save``.

    Raises:
        ValueError: if ``x`` does not have exactly 3 channels.
    """
    c, _, _ = x.shape
    # Raise rather than assert so the check still runs under `python -O`.
    if c != 3:
        raise ValueError(
            f"save_image expects a (3, H, W) tensor, got shape {tuple(x.shape)}"
        )
    # CHW -> HWC, then [-1, 1] -> [0, 255] as uint8.
    x = ((x.detach().cpu().numpy().transpose(1, 2, 0) + 1.0) * 127.5).clip(0, 255).astype(np.uint8)
    Image.fromarray(x).save(path)
def _window_offset(pos, extent, half=8):
    """Return the position of ``pos`` inside a ``2 * half``-wide sliding window.

    The window over ``[0, extent)`` is centred on ``pos`` (offset ``half``)
    and clamped at the borders. Near the start the offset is ``pos`` itself;
    near the end it is ``2 * half - (extent - pos)``. So, when
    ``extent >= 2 * half``, the window
    ``[pos - offset, pos - offset + 2 * half)`` always lies in ``[0, extent)``.
    """
    if pos <= half:
        return pos
    if extent - pos < half:
        return 2 * half - (extent - pos)
    return half


def _save_batch(images, indices, outdir, subdir):
    """Save ``images[k]`` as ``outdir/subdir/<indices[k] zero-padded to 6>.png``."""
    for image, index in zip(images, indices):
        save_image(image, os.path.join(outdir, subdir, "{:06}.png".format(index)))


@torch.no_grad()  # pure inference: don't record autograd graphs per sampling step
def run_conditional(model, dsets, outdir, top_k, temperature, batch_size=1,
                    half_sample=False, sample=True):
    """Reconstruct and conditionally resample every example of one dataset split.

    For each example this writes three PNGs, named by the zero-padded dataset
    index (e.g. ``000042.png``), into sub-directories of ``outdir`` that must
    already exist:

    * ``originals/``: the input image;
    * ``reconstructions/``: the first-stage decode of its codes;
    * ``samples/``: an image whose codes were resampled by the transformer,
      conditioned on the example's ``model.cond_stage_key`` input.

    Codes are resampled one position at a time in raster order. At each
    position the transformer sees a 16x16 window of conditioning codes
    followed by the matching 16x16 window of image codes. The window slides
    with the current position and is clamped at the grid borders.
    NOTE(review): this assumes both code grids are at least 16x16 and
    spatially aligned; confirm for new configs.

    Args:
        model: conditional transformer exposing ``get_input``, ``encode_to_z``,
            ``encode_to_c``, ``transformer``, ``top_k_logits``,
            ``decode_to_img``, ``first_stage_model``, ``cond_stage_model``,
            ``cond_stage_key`` and ``device``.
        dsets: object whose ``datasets`` dict maps split names to datasets.
            The first split in sorted key order is used.
        outdir: output root directory.
        top_k: keep only the ``top_k`` most likely codes before sampling;
            ``None`` disables the filter.
        temperature: logits are divided by this before the softmax.
        batch_size: number of examples per forward pass. A final, smaller
            batch is processed as well.
        half_sample: if True, keep the first half (raster order) of the
            encoded codes and resample only the rest. Default: resample all.
        sample: if True draw each code with ``torch.multinomial``; otherwise
            take the most likely code.
    """
    window = 16          # side of the square code window fed to the transformer
    half = window // 2

    datasets = dsets.datasets
    dset = datasets[sorted(datasets.keys())[0]]
    print("Dataset: ", dset.__class__.__name__)

    for start_idx in trange(0, len(dset), batch_size):
        # Include a trailing partial batch instead of silently dropping it.
        indices = list(range(start_idx, min(start_idx + batch_size, len(dset))))
        example = default_collate([dset[i] for i in indices])

        x = model.get_input("image", example).to(model.device)
        _save_batch(x, indices, outdir, "originals")

        cond_key = model.cond_stage_key
        c = model.get_input(cond_key, example).to(model.device)
        quant_z, z_indices = model.encode_to_z(x)
        quant_c, c_indices = model.encode_to_c(c)
        cshape = quant_z.shape
        batch, _, height, width = cshape

        xrec = model.first_stage_model.decode(quant_z)
        _save_batch(xrec, indices, outdir, "reconstructions")

        if cond_key == "segmentation":
            # get image from segmentation mask
            # NOTE(review): this colorized mask is not used or saved below.
            num_classes = c.shape[1]
            c = torch.argmax(c, dim=1, keepdim=True)
            c = torch.nn.functional.one_hot(c, num_classes=num_classes)
            c = c.squeeze(1).permute(0, 3, 1, 2).float()
            c = model.cond_stage_model.to_rgb(c)

        # Codes from `start` on (raster order) are zeroed and resampled; the
        # ones before it are kept from the encoded input.
        idx = z_indices
        start = idx.shape[1] // 2 if half_sample else 0
        idx[:, start:] = 0
        idx = idx.reshape(batch, height, width)
        start_i, start_j = divmod(start, width)
        cidx = c_indices.reshape(quant_c.shape[0], quant_c.shape[2], quant_c.shape[3])

        for i in range(start_i, height):
            local_i = _window_offset(i, height, half)
            i_start = i - local_i
            # Only the first resampled row can start mid-row; later rows start at 0.
            for j in range(start_j if i == start_i else 0, width):
                local_j = _window_offset(j, width, half)
                j_start = j - local_j
                zwin = idx[:, i_start:i_start + window, j_start:j_start + window]
                cwin = cidx[:, i_start:i_start + window, j_start:j_start + window]
                # Conditioning codes first, then image codes. The last token is
                # dropped; presumably output t predicts token t+1, so the final
                # window*window outputs line up with the image-code positions.
                patch = torch.cat(
                    (cwin.reshape(cwin.shape[0], -1), zwin.reshape(zwin.shape[0], -1)),
                    dim=1,
                )
                logits, _ = model.transformer(patch[:, :-1])
                logits = logits[:, -window * window:, :]
                logits = logits.reshape(batch, window, window, -1)
                logits = logits[:, local_i, local_j, :] / temperature
                if top_k is not None:
                    logits = model.top_k_logits(logits, top_k)
                # apply softmax to convert to probabilities
                probs = torch.nn.functional.softmax(logits, dim=-1)
                # sample from the distribution or take the most likely
                if sample:
                    ix = torch.multinomial(probs, num_samples=1)
                else:
                    _, ix = torch.topk(probs, k=1, dim=-1)
                # ix is (batch, 1); squeeze so it also fits idx[:, i, j] when batch > 1.
                idx[:, i, j] = ix.squeeze(-1)

        xsample = model.decode_to_img(idx[:, :height, :width], cshape)
        _save_batch(xsample, indices, outdir, "samples")
def get_parser():
    """Build the command-line parser for the sampling script.

    Returns:
        argparse.ArgumentParser: parser with ``--resume``/``-r``,
        ``--base``/``-b``, ``--config``/``-c``, ``--ignore_base_data``,
        ``--outdir`` (required), ``--top_k`` (default 100) and
        ``--temperature`` (default 1.0). A bare ``--config`` yields ``True``;
        ``--config PATH`` yields the path string; omitted, it is ``""``.
    """
    base_help = (
        "paths to base configs. Loaded from left-to-right. "
        "Parameters can be overwritten or added with command-line options "
        "of the form `--key value`."
    )
    config_help = (
        "path to single config. If specified, base configs will be ignored "
        "(except for the last one if left unspecified)."
    )
    ignore_help = (
        "Ignore data specification from base configs. Useful if you want "
        "to specify a custom datasets on the command line."
    )

    # (flags, keyword arguments) in the order they appear in --help.
    arg_specs = [
        (("-r", "--resume"),
         dict(type=str, nargs="?", help="load from logdir or checkpoint in logdir")),
        (("-b", "--base"),
         dict(nargs="*", metavar="base_config.yaml", help=base_help, default=[])),
        (("-c", "--config"),
         dict(nargs="?", metavar="single_config.yaml", help=config_help,
              const=True, default="")),
        (("--ignore_base_data",),
         dict(action="store_true", help=ignore_help)),
        (("--outdir",),
         dict(required=True, type=str, help="Where to write outputs to.")),
        (("--top_k",),
         dict(type=int, default=100, help="Sample from among top-k predictions.")),
        (("--temperature",),
         dict(type=float, default=1.0, help="Sampling temperature.")),
    ]

    parser = argparse.ArgumentParser()
    for flags, kwargs in arg_specs:
        parser.add_argument(*flags, **kwargs)
    return parser
def load_model_from_config(config, sd, gpu=True, eval_mode=True):
    """Instantiate a model from ``config`` and optionally load ``sd`` into it.

    Before instantiation the config is edited in place:

    * ``params.ckpt_path`` is set to ``None``. The same is done for
      ``params.first_stage_config.params.ckpt_path`` and
      ``params.cond_stage_config.params.ckpt_path`` when present. The log
      messages call these "restore-ckpt" paths; presumably they would
      otherwise make the constructors restore weights themselves.
    * ``params.downsample_cond_size`` is set to -1 and
      ``params.downsample_cond_factor`` to 0.5.

    Args:
        config: model config node (OmegaConf) with a ``params`` section, as
            consumed by ``instantiate_from_config``. Modified in place.
        sd: state dict loaded with ``strict=False``, or ``None`` to keep the
            freshly instantiated weights. Missing/unexpected keys are printed.
        gpu: if True, move the model to CUDA with ``model.cuda()``.
        eval_mode: if True, call ``model.eval()``.

    Returns:
        dict: ``{"model": model}``.
    """
    if "ckpt_path" in config.params:
        print("Deleting the restore-ckpt path from the config...")
        config.params.ckpt_path = None
    if "downsample_cond_size" in config.params:
        print("Deleting downsample-cond-size from the config and setting factor=0.5 instead...")
        config.params.downsample_cond_size = -1
        config.params["downsample_cond_factor"] = 0.5

    # Handle each sub-stage independently, so that a missing (or params-less)
    # first-stage config no longer prevents the cond-stage path from being cleared.
    for stage_key, label in (("first_stage_config", "first-stage"),
                             ("cond_stage_config", "cond-stage")):
        try:
            stage_params = getattr(config.params, stage_key).params
            has_ckpt = "ckpt_path" in stage_params
        except (AttributeError, KeyError, TypeError):
            # Stage config absent, None, or without a params mapping: nothing to clear.
            continue
        if has_ckpt:
            stage_params.ckpt_path = None
            print(f"Deleting the {label} restore-ckpt path from the config...")

    model = instantiate_from_config(config)
    if sd is not None:
        missing, unexpected = model.load_state_dict(sd, strict=False)
        print(f"Missing Keys in State Dict: {missing}")
        print(f"Unexpected Keys in State Dict: {unexpected}")
    if gpu:
        model.cuda()
    if eval_mode:
        model.eval()
    return {"model": model}
def get_data(config):
    """Instantiate the data module described by ``config.data`` and set it up.

    Calls ``prepare_data()`` and then ``setup()`` on the new object before
    returning it.
    """
    datamodule = instantiate_from_config(config.data)
    datamodule.prepare_data()
    datamodule.setup()
    return datamodule
def load_model_and_dset(config, ckpt, gpu, eval_mode):
    """Build the data module and the model described by ``config``.

    Args:
        config: merged config with ``data`` and ``model`` sections.
        ckpt: path to a checkpoint file whose ``"global_step"`` and
            ``"state_dict"`` entries are used; falsy to skip loading weights.
        gpu: forwarded to ``load_model_from_config``.
        eval_mode: forwarded to ``load_model_from_config``.

    Returns:
        tuple: ``(dsets, model, global_step)``. ``global_step`` is ``None``
        when no checkpoint is given.
    """
    dsets = get_data(config)

    state_dict = None
    global_step = None
    if ckpt:
        # NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
        checkpoint = torch.load(ckpt, map_location="cpu")
        global_step = checkpoint["global_step"]
        state_dict = checkpoint["state_dict"]

    model = load_model_from_config(
        config.model, state_dict, gpu=gpu, eval_mode=eval_mode
    )["model"]
    return dsets, model, global_step
if __name__ == "__main__":
    sys.path.append(os.getcwd())
    parser = get_parser()
    # Arguments argparse does not recognise become OmegaConf dotlist overrides (`cli`).
    opt, unknown = parser.parse_known_args()

    ckpt = None
    if opt.resume:
        if not os.path.exists(opt.resume):
            raise ValueError("Cannot find {}".format(opt.resume))
        if os.path.isfile(opt.resume):
            paths = opt.resume.split("/")
            try:
                # logdir = path up to and including the component right after
                # the last "logs" entry.
                idx = len(paths) - paths[::-1].index("logs") + 1
            except ValueError:
                idx = -2  # take a guess: path/to/logdir/checkpoints/model.ckpt
            logdir = "/".join(paths[:idx])
            ckpt = opt.resume
        else:
            assert os.path.isdir(opt.resume), opt.resume
            logdir = opt.resume.rstrip("/")
            ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")
        print(f"logdir:{logdir}")
        # The run's saved project configs go first, so --base configs override them.
        base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*-project.yaml")))
        opt.base = base_configs + opt.base

    if opt.config:
        # `--config PATH` uses only that file; a bare `--config` (const=True)
        # keeps only the last base config.
        if isinstance(opt.config, str):
            opt.base = [opt.config]
        else:
            opt.base = [opt.base[-1]]

    configs = [OmegaConf.load(cfg) for cfg in opt.base]
    cli = OmegaConf.from_dotlist(unknown)
    if opt.ignore_base_data:
        for config in configs:
            if hasattr(config, "data"):
                del config["data"]
    # Later configs, and finally the command-line overrides, take precedence.
    config = OmegaConf.merge(*configs, cli)

    print(ckpt)
    # Fall back to CPU instead of crashing in `model.cuda()` without CUDA.
    gpu = torch.cuda.is_available()
    if not gpu:
        print("CUDA is not available; running on CPU.")
    eval_mode = True
    show_config = False
    if show_config:
        print(OmegaConf.to_container(config))

    dsets, model, global_step = load_model_and_dset(config, ckpt, gpu, eval_mode)
    print(f"Global step: {global_step}")

    # global_step is None when no checkpoint was loaded, and "{:06}" cannot format None.
    step_tag = "{:06}".format(global_step) if global_step is not None else "no_ckpt"
    outdir = os.path.join(opt.outdir, "{}_{}_{}".format(step_tag, opt.top_k, opt.temperature))
    os.makedirs(outdir, exist_ok=True)
    print("Writing samples to ", outdir)
    for k in ["originals", "reconstructions", "samples"]:
        os.makedirs(os.path.join(outdir, k), exist_ok=True)
    run_conditional(model, dsets, outdir, opt.top_k, opt.temperature)