# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import os, sys
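# Install dependencies at startup (a common workaround for Hugging Face Spaces).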
os.system('pip install -r requirements.txt')
import gradio as gr
import numpy as np
import dnnlib
import time
import legacy
import torch
import glob
import cv2
from torch_utils import misc
from renderer import Renderer
from training.networks import Generator
from huggingface_hub import hf_hub_download
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
port = int(sys.argv[1]) if len(sys.argv) > 1 else 21111  # optional CLI port (note: not passed to launch() below)
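# Pretrained StyleNeRF checkpoints hosted on the Hugging Face Hub.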
model_lists = {
'ffhq-512x512-basic': dict(repo_id='facebook/stylenerf-ffhq-config-basic', filename='ffhq_512.pkl'),
'ffhq-256x256-basic': dict(repo_id='facebook/stylenerf-ffhq-config-basic', filename='ffhq_256.pkl'),
'ffhq-1024x1024-basic': dict(repo_id='facebook/stylenerf-ffhq-config-basic', filename='ffhq_1024.pkl'),
}
model_names = list(model_lists)
def set_random_seed(seed):
torch.manual_seed(seed)
np.random.seed(seed)
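# Map the yaw/pitch sliders onto the generator's normalized (u, v) camera
# coordinates and build the corresponding camera matrices.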
def get_camera_traj(model, pitch, yaw, fov=12, batch_size=1, model_name=None):
gen = model.synthesis
range_u, range_v = gen.C.range_u, gen.C.range_v
if not (('car' in model_name) or ('Car' in model_name)): # TODO: hack, better option?
yaw, pitch = 0.5 * yaw, 0.3 * pitch
pitch = pitch + np.pi/2
u = (yaw - range_u[0]) / (range_u[1] - range_u[0])
v = (pitch - range_v[0]) / (range_v[1] - range_v[0])
else:
u = (yaw + 1) / 2
v = (pitch + 1) / 2
cam = gen.get_camera(batch_size=batch_size, mode=[u, v, 0.5], device=device, fov=fov)
return cam
def check_name(model_name):
"""Gets model by name."""
if model_name in model_lists:
network_pkl = hf_hub_download(**model_lists[model_name])
else:
if os.path.isdir(model_name):
network_pkl = sorted(glob.glob(model_name + '/*.pkl'))[-1]
else:
network_pkl = model_name
return network_pkl
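# Load a network pickle, copy its weights into a fresh Generator, and run one
# dummy forward pass to determine the output resolution and warm up the model.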
def get_model(network_pkl, render_option=None):
print('Loading networks from "%s"...' % network_pkl)
with dnnlib.util.open_url(network_pkl) as f:
network = legacy.load_network_pkl(f)
G = network['G_ema'].to(device) # type: ignore
with torch.no_grad():
G2 = Generator(*G.init_args, **G.init_kwargs).to(device)
misc.copy_params_and_buffers(G, G2, require_all=False)
    print('Compiling and running an initial forward pass...')
G2 = G2.eval()
    init_z = torch.from_numpy(np.random.RandomState(0).rand(1, G2.z_dim).astype('float32')).to(device)  # cast to float32 to match the model weights
init_cam = get_camera_traj(G2, 0, 0, model_name=network_pkl)
dummy = G2(z=init_z, c=None, camera_matrices=init_cam, render_option=render_option, theta=0)
res = dummy['img'].shape[-1]
imgs = np.zeros((res, res//2, 3))
return G2, res, imgs
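# Warm-start with the first checkpoint: [model, resolution, preview image buffer].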
global_states = list(get_model(check_name(model_names[0])))
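# Cache of the mapped latent codes (w) for the two seed slots.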
wss = [None, None]
def proc_seed(history, seed):
    if isinstance(seed, str):
        seed = 0
    else:
        seed = int(seed)
    return seed
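# Main Gradio callback: synthesizes a frame from two (optionally mixed) seeds
# and the camera sliders; `history` is per-session state used for caching.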
def f_synthesis(model_name, model_find, render_option, early, trunc, seed1, seed2, mix1, mix2, yaw, pitch, roll, fov, history):
history = history or {}
seeds = []
trunc = trunc / 100
if model_find != "":
model_name = model_find
model_name = check_name(model_name)
if model_name != history.get("model_name", None):
model, res, imgs = get_model(model_name, render_option)
global_states[0] = model
global_states[1] = res
global_states[2] = imgs
model, res, imgs = global_states
    for idx, seed in enumerate([seed1, seed2]):
        seed = proc_seed(history, seed)
if (seed != history.get(f'seed{idx}', -1)) or \
(model_name != history.get("model_name", None)) or \
(trunc != history.get("trunc", 0.7)) or \
(wss[idx] is None):
print(f'use seed {seed}')
set_random_seed(seed)
z = torch.from_numpy(np.random.RandomState(int(seed)).randn(1, model.z_dim).astype('float32')).to(device)
ws = model.mapping(z=z, c=None, truncation_psi=trunc)
img = model.get_final_output(styles=ws, camera_matrices=get_camera_traj(model, 0, 0, model_name=model_name), render_option=render_option)
ws = ws.detach().cpu().numpy()
img = img[0].permute(1,2,0).detach().cpu().numpy()
            imgs[idx * res // 2: (1 + idx) * res // 2] = cv2.resize(
                np.asarray(img).clip(-1, 1) * 0.5 + 0.5,
                (res // 2, res // 2), interpolation=cv2.INTER_AREA)
wss[idx] = ws
else:
seed = history[f'seed{idx}']
seeds += [seed]
history[f'seed{idx}'] = seed
history['trunc'] = trunc
history['model_name'] = model_name
set_random_seed(sum(seeds))
    # Style mixing: layers < 8 blend geometry (mix1), layers >= 8 blend appearance (mix2).
ws1, ws2 = [torch.from_numpy(ws).to(device) for ws in wss]
ws = ws1.clone()
ws[:, :8] = ws1[:, :8] * mix1 + ws2[:, :8] * (1 - mix1)
ws[:, 8:] = ws1[:, 8:] * mix2 + ws2[:, 8:] * (1 - mix2)
    # Optionally render an intermediate output (normal or gradient map) instead of the RGB image.
if early == 'Normal Map':
render_option += ',normal,early'
elif early == 'Gradient Map':
render_option += ',gradient,early'
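    # Render the final frame for the current camera pose, timing the forward pass.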
start_t = time.time()
with torch.no_grad():
cam = get_camera_traj(model, pitch, yaw, fov, model_name=model_name)
image = model.get_final_output(
styles=ws, camera_matrices=cam,
theta=roll * np.pi,
render_option=render_option)
end_t = time.time()
image = image[0].permute(1,2,0).detach().cpu().numpy().clip(-1, 1) * 0.5 + 0.5
if imgs.shape[0] == image.shape[0]:
image = np.concatenate([imgs, image], 1)
else:
a = image.shape[0]
b = int(imgs.shape[1] / imgs.shape[0] * a)
print(f'resize {a} {b} {image.shape} {imgs.shape}')
        image = np.concatenate([cv2.resize(imgs, (b, a), interpolation=cv2.INTER_AREA), image], 1)
print(f'rendering time = {end_t-start_t:.4f}s')
image = (image * 255).astype('uint8')
return image, history
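# Gradio input widgets (legacy `gr.inputs` API).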
model_name = gr.inputs.Dropdown(model_names)
model_find = gr.inputs.Textbox(label="Checkpoint path (folder or .pkl file)", default="")
render_option = gr.inputs.Textbox(label="Additional rendering options", default='freeze_bg,steps:50')
trunc = gr.inputs.Slider(default=70, maximum=100, minimum=0, label='Truncation trick (%)')
seed1 = gr.inputs.Number(default=1, label="Random seed1")
seed2 = gr.inputs.Number(default=9, label="Random seed2")
mix1 = gr.inputs.Slider(minimum=0, maximum=1, default=0, label="Linear mixing ratio (geometry)")
mix2 = gr.inputs.Slider(minimum=0, maximum=1, default=0, label="Linear mixing ratio (appearance)")
early = gr.inputs.Radio(['None', 'Normal Map', 'Gradient Map'], default='None', label='Intermediate output')
yaw = gr.inputs.Slider(minimum=-1, maximum=1, default=0, label="Yaw")
pitch = gr.inputs.Slider(minimum=-1, maximum=1, default=0, label="Pitch")
roll = gr.inputs.Slider(minimum=-1, maximum=1, default=0, label="Roll (optional; not recommended for the basic config)")
fov = gr.inputs.Slider(minimum=10, maximum=14, default=12, label="FoV")
css = ".output-image, .input-image, .image-preview {height: 600px !important} "
gr.Interface(fn=f_synthesis,
inputs=[model_name, model_find, render_option, early, trunc, seed1, seed2, mix1, mix2, yaw, pitch, roll, fov, "state"],
title="Interactive Web Demo for StyleNeRF (ICLR 2022)",
description="StyleNeRF: A Style-based 3D-Aware Generator for High-resolution Image Synthesis. Currently the demo runs on CPU only.",
outputs=["image", "state"],
layout='unaligned',
css=css, theme='dark-seafoam',
live=True).launch(enable_queue=True)