File size: 4,884 Bytes
7f49ac7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c7c9ff6
7f49ac7
 
 
 
 
c7c9ff6
7f49ac7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from torchvision import utils as vutils

import os
import random
import argparse
from tqdm import tqdm

from models import Generator


def load_params(model, new_param):
    """Copy the values in ``new_param`` into the parameters of ``model``.

    Parameters are paired positionally with ``zip``: the k-th item of
    ``new_param`` overwrites the k-th entry of ``model.parameters()`` in
    place. Pairing stops at the shorter of the two sequences.
    """
    pairs = zip(model.parameters(), new_param)
    for target, source in pairs:
        # Write through ``.data`` so the copy is not tracked by autograd.
        target.data.copy_(source)

def resize(img, size=256):
    """Resize ``img`` with ``F.interpolate`` using its default mode.

    Args:
        img: input tensor accepted by ``F.interpolate``
            (assumed to be a batched image tensor such as (N, C, H, W) —
            TODO confirm against callers).
        size: target output size forwarded to ``F.interpolate``. Defaults
            to 256, which is the value that used to be hard-coded.

    Returns:
        The interpolated tensor.
    """
    return F.interpolate(img, size=size)

def batch_generate(zs, netG, batch=8):
    """Run ``netG`` over ``zs`` in chunks of ``batch`` and concatenate.

    ``zs`` is cut into ``len(zs) // batch`` full chunks followed by one
    shorter chunk holding the remainder (if any). Each chunk is passed
    through ``netG`` with gradients disabled, the result is moved to the
    CPU, and all results are joined with ``torch.cat``.
    """
    with torch.no_grad():
        n_full, leftover = divmod(len(zs), batch)
        chunks = [zs[k * batch:(k + 1) * batch] for k in range(n_full)]
        if leftover > 0:
            # Trailing partial chunk: the last ``leftover`` entries.
            chunks.append(zs[-leftover:])
        outputs = [netG(chunk).cpu() for chunk in chunks]
    return torch.cat(outputs)

def batch_save(images, folder_name):
    """Save every image in ``images`` as ``<folder_name>/<index>.jpg``.

    The output folder (including any missing parent folders) is created
    when needed; an already existing folder is not an error.

    Args:
        images: iterable of image tensors. Each one is shifted and scaled
            with ``(x + 1) * 0.5`` before saving (presumably mapping a
            [-1, 1] range to [0, 1] — confirm against the generator).
        folder_name: destination directory.
    """
    # makedirs(exist_ok=True) handles nested paths and avoids the
    # check-then-create race of os.path.exists() + os.mkdir().
    os.makedirs(folder_name, exist_ok=True)
    for i, image in enumerate(images):
        vutils.save_image(image.add(1).mul(0.5), os.path.join(folder_name, '%d.jpg' % i))


if __name__ == "__main__":
    # Script entry point: load a generator checkpoint, sample latent noise,
    # and write the generated images and/or their masks into --dist.
    parser = argparse.ArgumentParser(
        description='generate images'
    )
    parser.add_argument('--ckpt', type=str, default="pre_trained_checkpoint_4ch/all_50000.pth")
    parser.add_argument('--artifacts', type=str, default=".", help='path to artifacts.')
    parser.add_argument('--cuda', type=int, default=0, help='index of gpu to use')
    parser.add_argument('--start_iter', type=int, default=6)
    parser.add_argument('--end_iter', type=int, default=10)

    parser.add_argument('--dist', type=str, default='test_out')
    parser.add_argument('--size', type=int, default=256)
    parser.add_argument('--batch', default=1, type=int, help='batch size')
    parser.add_argument('--n_sample', type=int, default=1000)
    parser.add_argument('--big', action='store_true')
    parser.add_argument('--im_size', type=int, default=256)
    parser.add_argument("--save_option", default="image_and_mask", help="Options to svae output, image_only, mask_only, image_and_mask", choices=["image_only","mask_only", "image_and_mask"])
    parser.set_defaults(big=False)
    args = parser.parse_args()

    noise_dim = 256
    # Use the requested GPU when CUDA is present; otherwise fall back to the
    # CPU so the script does not crash on machines without a GPU.
    if torch.cuda.is_available():
        device = torch.device('cuda:%d'%(args.cuda))
    else:
        device = torch.device('cpu')

    # nc=4: each generated sample has 4 channels; below, channels 0-2 are
    # saved as the image and the last channel as the mask.
    net_ig = Generator( ngf=64, nz=noise_dim, nc=4, im_size=args.im_size)#, big=args.big )
    net_ig.to(device)

    ckpt = args.ckpt
    # map_location remaps the stored tensors onto the selected device, so a
    # checkpoint saved on a different GPU index (or loaded on a CPU-only
    # host) can still be restored.
    # NOTE: torch.load unpickles the file — only load trusted checkpoints.
    checkpoint = torch.load(ckpt, map_location=device)
    # Remove the leading `module.` prefix only (presumably added by a
    # DataParallel wrapper at training time); other occurrences of the
    # substring inside a key are left untouched.
    prefix = 'module.'
    checkpoint['g'] = {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in checkpoint['g'].items()
    }
    net_ig.load_state_dict(checkpoint['g'])

    # NOTE(review): net_ig.eval() was disabled in the original script and is
    # still not called here — confirm that train-mode layers are intended.
    print("load checkpoint success")

    # Free the checkpoint dict; the weights now live in net_ig.
    del checkpoint

    dist = args.dist
    os.makedirs(dist, exist_ok=True)

    # argparse `choices` guarantees save_option is one of the three values.
    save_img = args.save_option in ("image_only", "image_and_mask")
    save_mask = args.save_option in ("mask_only", "image_and_mask")

    with torch.no_grad():
        # Step through sample indices batch by batch; the final batch is
        # shortened so exactly n_sample outputs are produced even when
        # n_sample is not a multiple of the batch size.
        for start in tqdm(range(0, args.n_sample, args.batch)):
            cur_batch = min(args.batch, args.n_sample - start)
            noise = torch.randn(cur_batch, noise_dim).to(device)
            # The first element of the generator output is used
            # (presumably the full-resolution batch — confirm in models.py).
            g_imgs = net_ig(noise)[0]
            g_imgs = F.interpolate(g_imgs, 512)

            for j, g_img in enumerate(g_imgs):
                idx = start + j
                # Shift/scale once with (x + 1) * 0.5, then split channels.
                scaled = g_img.add(1).mul(0.5)
                g_mask = scaled[-1, :, :].expand(3, -1, -1)
                g_img = scaled[0:3, :, :]

                # Clean generated data using clamping
                g_mask = torch.clamp(g_mask, min=0, max=1)
                g_img = torch.clamp(g_img, min=0, max=1)
                # Binarise the mask at 0.5 (float tensor of 0.0 / 1.0).
                g_mask = (g_mask > 0.5) * 1.0

                if save_img:
                    vutils.save_image(g_img,
                        os.path.join(dist, '%d_img.png'%idx))
                if save_mask:
                    vutils.save_image(g_mask,
                        os.path.join(dist, '%d_mask.png'%idx))