Spaces:
Build error
Build error
resolve conflict
Browse files
app.py
CHANGED
@@ -28,7 +28,7 @@ def handler(signum, frame):
|
|
28 |
if res == 'y':
|
29 |
gr.close_all()
|
30 |
exit(1)
|
31 |
-
|
32 |
signal.signal(signal.SIGINT, handler)
|
33 |
|
34 |
|
@@ -56,7 +56,7 @@ def check_name(model_name='FFHQ512'):
|
|
56 |
"""Gets model by name."""
|
57 |
if model_name == 'FFHQ512':
|
58 |
network_pkl = hf_hub_download(repo_id='thomagram/stylenerf-ffhq-config-basic', filename='ffhq_512.pkl')
|
59 |
-
|
60 |
# TODO: checkpoint to be updated!
|
61 |
# elif model_name == 'FFHQ512v2':
|
62 |
# network_pkl = "./pretrained/ffhq_512_eg3d.pkl"
|
@@ -109,10 +109,10 @@ def proc_seed(history, seed):
|
|
109 |
def f_synthesis(model_name, model_find, render_option, early, trunc, seed1, seed2, mix1, mix2, yaw, pitch, roll, fov, history):
|
110 |
history = history or {}
|
111 |
seeds = []
|
112 |
-
|
113 |
if model_find != "":
|
114 |
model_name = model_find
|
115 |
-
|
116 |
model_name = check_name(model_name)
|
117 |
if model_name != history.get("model_name", None):
|
118 |
model, res, imgs = get_model(model_name, render_option)
|
@@ -139,7 +139,7 @@ def f_synthesis(model_name, model_find, render_option, early, trunc, seed1, seed
|
|
139 |
ws = ws.detach().cpu().numpy()
|
140 |
img = img[0].permute(1,2,0).detach().cpu().numpy()
|
141 |
|
142 |
-
|
143 |
imgs[idx * res // 2: (1 + idx) * res // 2] = cv2.resize(
|
144 |
np.asarray(img).clip(-1, 1) * 0.5 + 0.5,
|
145 |
(res//2, res//2), cv2.INTER_AREA)
|
@@ -151,7 +151,7 @@ def f_synthesis(model_name, model_find, render_option, early, trunc, seed1, seed
|
|
151 |
history[f'seed{idx}'] = seed
|
152 |
history['trunc'] = trunc
|
153 |
history['model_name'] = model_name
|
154 |
-
|
155 |
set_random_seed(sum(seeds))
|
156 |
|
157 |
# style mixing (?)
|
@@ -159,18 +159,18 @@ def f_synthesis(model_name, model_find, render_option, early, trunc, seed1, seed
|
|
159 |
ws = ws1.clone()
|
160 |
ws[:, :8] = ws1[:, :8] * mix1 + ws2[:, :8] * (1 - mix1)
|
161 |
ws[:, 8:] = ws1[:, 8:] * mix2 + ws2[:, 8:] * (1 - mix2)
|
162 |
-
|
163 |
# set visualization for other types of inputs.
|
164 |
if early == 'Normal Map':
|
165 |
render_option += ',normal,early'
|
166 |
elif early == 'Gradient Map':
|
167 |
render_option += ',gradient,early'
|
168 |
-
|
169 |
start_t = time.time()
|
170 |
with torch.no_grad():
|
171 |
cam = get_camera_traj(model, pitch, yaw, fov, model_name=model_name)
|
172 |
image = model.get_final_output(
|
173 |
-
styles=ws, camera_matrices=cam,
|
174 |
theta=roll * np.pi,
|
175 |
render_option=render_option)
|
176 |
end_t = time.time()
|
@@ -184,7 +184,7 @@ def f_synthesis(model_name, model_find, render_option, early, trunc, seed1, seed
|
|
184 |
b = int(imgs.shape[1] / imgs.shape[0] * a)
|
185 |
print(f'resize {a} {b} {image.shape} {imgs.shape}')
|
186 |
image = np.concatenate([cv2.resize(imgs, (b, a), cv2.INTER_AREA), image], 1)
|
187 |
-
|
188 |
print(f'rendering time = {end_t-start_t:.4f}s')
|
189 |
image = (image * 255).astype('uint8')
|
190 |
return image, history
|
@@ -210,4 +210,4 @@ gr.Interface(fn=f_synthesis,
|
|
210 |
outputs=["image", "state"],
|
211 |
layout='unaligned',
|
212 |
css=css, theme='dark-huggingface',
|
213 |
-
live=True).launch(
|
|
|
28 |
if res == 'y':
|
29 |
gr.close_all()
|
30 |
exit(1)
|
31 |
+
|
32 |
signal.signal(signal.SIGINT, handler)
|
33 |
|
34 |
|
|
|
56 |
"""Gets model by name."""
|
57 |
if model_name == 'FFHQ512':
|
58 |
network_pkl = hf_hub_download(repo_id='thomagram/stylenerf-ffhq-config-basic', filename='ffhq_512.pkl')
|
59 |
+
|
60 |
# TODO: checkpoint to be updated!
|
61 |
# elif model_name == 'FFHQ512v2':
|
62 |
# network_pkl = "./pretrained/ffhq_512_eg3d.pkl"
|
|
|
109 |
def f_synthesis(model_name, model_find, render_option, early, trunc, seed1, seed2, mix1, mix2, yaw, pitch, roll, fov, history):
|
110 |
history = history or {}
|
111 |
seeds = []
|
112 |
+
|
113 |
if model_find != "":
|
114 |
model_name = model_find
|
115 |
+
|
116 |
model_name = check_name(model_name)
|
117 |
if model_name != history.get("model_name", None):
|
118 |
model, res, imgs = get_model(model_name, render_option)
|
|
|
139 |
ws = ws.detach().cpu().numpy()
|
140 |
img = img[0].permute(1,2,0).detach().cpu().numpy()
|
141 |
|
142 |
+
|
143 |
imgs[idx * res // 2: (1 + idx) * res // 2] = cv2.resize(
|
144 |
np.asarray(img).clip(-1, 1) * 0.5 + 0.5,
|
145 |
(res//2, res//2), cv2.INTER_AREA)
|
|
|
151 |
history[f'seed{idx}'] = seed
|
152 |
history['trunc'] = trunc
|
153 |
history['model_name'] = model_name
|
154 |
+
|
155 |
set_random_seed(sum(seeds))
|
156 |
|
157 |
# style mixing (?)
|
|
|
159 |
ws = ws1.clone()
|
160 |
ws[:, :8] = ws1[:, :8] * mix1 + ws2[:, :8] * (1 - mix1)
|
161 |
ws[:, 8:] = ws1[:, 8:] * mix2 + ws2[:, 8:] * (1 - mix2)
|
162 |
+
|
163 |
# set visualization for other types of inputs.
|
164 |
if early == 'Normal Map':
|
165 |
render_option += ',normal,early'
|
166 |
elif early == 'Gradient Map':
|
167 |
render_option += ',gradient,early'
|
168 |
+
|
169 |
start_t = time.time()
|
170 |
with torch.no_grad():
|
171 |
cam = get_camera_traj(model, pitch, yaw, fov, model_name=model_name)
|
172 |
image = model.get_final_output(
|
173 |
+
styles=ws, camera_matrices=cam,
|
174 |
theta=roll * np.pi,
|
175 |
render_option=render_option)
|
176 |
end_t = time.time()
|
|
|
184 |
b = int(imgs.shape[1] / imgs.shape[0] * a)
|
185 |
print(f'resize {a} {b} {image.shape} {imgs.shape}')
|
186 |
image = np.concatenate([cv2.resize(imgs, (b, a), cv2.INTER_AREA), image], 1)
|
187 |
+
|
188 |
print(f'rendering time = {end_t-start_t:.4f}s')
|
189 |
image = (image * 255).astype('uint8')
|
190 |
return image, history
|
|
|
210 |
outputs=["image", "state"],
|
211 |
layout='unaligned',
|
212 |
css=css, theme='dark-huggingface',
|
213 |
+
live=True).launch(enable_queue=True)
|