# deepfake_gi_fastGAN/scripts/find_nearest_neighbor.py
# Author: vlbthambawita - commit 7f49ac7 ("First").
from eval import load_params
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from torchvision import utils as vutils
from torchvision import transforms
import os
import random
import argparse
from tqdm import tqdm
from models import Generator
from operation import load_params, InfiniteSamplerWrapper
# --- Generator setup and a quick preview sample ----------------------------
noise_dim = 256
device = torch.device('cuda:%d' % (0))
im_size = 512

net_ig = Generator(ngf=64, nz=noise_dim, nc=3, im_size=im_size)  # , big=args.big )
net_ig.to(device)

# Checkpoint saved at this training iteration.
epoch = 50000
ckpt = './models/all_%d.pth' % (epoch)
# `lambda storage, loc: storage` keeps every tensor where it was
# deserialised (CPU), so the file loads regardless of the GPU it came from.
checkpoint = torch.load(ckpt, map_location=lambda a, b: a)
net_ig.load_state_dict(checkpoint['g'])
# Then apply the 'g_ema' weights on top. load_params comes from `operation`
# (its definition is not visible here) - presumably it copies each stored
# tensor into the matching generator parameter; TODO confirm.
load_params(net_ig, checkpoint['g_ema'])

# Generate one batch and save it as a preview grid. This is pure inference,
# so no autograd graph is built; otherwise the graph would stay alive
# through the global `g_imgs` for the rest of the script.
batch = 8
with torch.no_grad():
    noise = torch.randn(batch, noise_dim).to(device)
    # Element 0 of the generator output is used as the image batch
    # (assumed to be the full-resolution output - TODO confirm).
    g_imgs = net_ig(noise)[0]
# (x + 1) * 0.5 maps the assumed [-1, 1] output range to [0, 1] for saving.
vutils.save_image(g_imgs.add(1).mul(0.5),
                  os.path.join('./', '%d.png' % (2)))
# --- Real-image dataset and perceptual metric ------------------------------
# Preprocess real images: resize to 256x256, convert to a tensor, then
# normalise every channel with mean = std = 0.5, which maps ToTensor's
# [0, 1] range onto [-1, 1].
transform_list = [transforms.Resize((256, 256)),
                  transforms.ToTensor(),
                  transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3)]
trans = transforms.Compose(transform_list)

# ImageFolder expects data_root to contain one sub-directory per class.
data_root = '/media/database/images/first_1k'
dataset = ImageFolder(root=data_root, transform=trans)

import lpips

# Perceptual distance used for the nearest-neighbour search. The
# PerceptualLoss entry point suggests a vendored copy of the LPIPS code
# rather than the pip `lpips` package - confirm which one is on the path.
percept = lpips.PerceptualLoss(model='net-lin', net='vgg', use_gpu=True)

# First preview sample with a leading batch dimension restored.
the_image = g_imgs[0].unsqueeze(0)
def find_closest(the_image):
    """Return the dataset image perceptually closest to ``the_image``.

    The query is resized to 256 pixels with ``F.interpolate`` (default
    nearest mode). It is then compared against every image of the
    module-level ``dataset`` with the module-level ``percept`` metric.

    Args:
        the_image: a 4-D image batch; assumed to be (1, C, H, W) on
            ``device`` and in the same [-1, 1] range as the normalised
            dataset images - TODO confirm.

    Returns:
        ``(close_image, small)``: the closest real image, with a leading
        batch dimension and on ``device``, and its summed perceptual
        distance (a tensor).

    Raises:
        ValueError: if no dataset image produced a comparable distance,
            e.g. because the dataset is empty.
    """
    # A nearest-neighbour search never needs gradients. Disabling them here
    # keeps the function cheap even when called outside torch.no_grad().
    with torch.no_grad():
        query = F.interpolate(the_image, size=256)
        # Start from +inf so any finite distance wins. The previous cap of
        # 100 could silently return (None, 100) and crash callers.
        best_dist = float('inf')
        best_image = None
        for idx in tqdm(range(len(dataset))):
            real_image = dataset[idx][0].unsqueeze(0).to(device)
            dist = percept(query, real_image).sum()
            if dist < best_dist:
                best_dist = dist
                best_image = real_image
    if best_image is None:
        raise ValueError('find_closest: no dataset image could be compared '
                         '(is the dataset empty?)')
    return best_image, best_dist
# --- Nearest-neighbour search over freshly generated samples ---------------
all_dist = []  # one 1-element distance tensor per generated image
batch = 8
num_batches = 8  # 8 batches x 8 samples = 64 generated queries
result_path = 'nn_track'
os.makedirs(result_path, exist_ok=True)

# Pure inference: no autograd graph is needed anywhere in this loop.
with torch.no_grad():
    for j in range(num_batches):
        noise = torch.randn(batch, noise_dim).to(device)
        g_imgs = net_ig(noise)[0]
        for n in range(batch):
            the_image = g_imgs[n].unsqueeze(0)
            close_0, dis = find_closest(the_image)
            # Save the generated image (resized to 256) followed by its
            # nearest real neighbour; (x + 1) * 0.5 maps [-1, 1] to [0, 1].
            pair = torch.cat([F.interpolate(the_image, 256), close_0])
            vutils.save_image(pair.add(1).mul(0.5),
                              os.path.join(result_path,
                                           'nn_%d.jpg' % (j * batch + n)))
            all_dist.append(dis.view(1))

# Every entry is already a 1-element tensor, so they concatenate directly.
print(torch.cat(all_dist).mean())