# Spaces: Runtime error
import argparse | |
import gradio as gr | |
from benchmark.app_image import ImageSwap | |
from benchmark.app_video import VideoSwap | |
from configs.train_config import TrainConfig | |
from models.model import HifiFace | |
class ConfigPath:
    """Runtime settings object handed to ``ImageSwap`` by ``main()``.

    The values below are class-level fallbacks; ``main()`` overwrites
    ``model_path``, ``model_idx``, ``ffmpeg_device`` and ``device`` on the
    instance from the command line, leaving ``face_detector_weights`` as-is.
    """

    # Path to the ONNX face-detector weights (an SCRFD model, per the filename).
    face_detector_weights: str = "/checkpoints/face_detector/face_detector_scrfd_10g_bnkps.onnx"
    # Checkpoint directory and checkpoint index of the standard model.
    model_path: str = ""
    model_idx: int = 80000
    # Device settings; note main()'s CLI defaults for both are "cpu", not "cuda".
    ffmpeg_device: str = "cuda"
    device: str = "cuda"
def _checkpoint_label(model_path, model_idx):
    """Return a ``"<checkpoint dir name>:<idx>"`` label for display.

    Trailing slashes are stripped first so ``"/a/b/"`` yields ``"b:..."``
    rather than an empty directory name.
    """
    return f"{model_path.rstrip('/').split('/')[-1]}:{model_idx}"


def _make_swap_fn(swapper):
    """Wrap ``swapper.inference`` as a 5-argument Gradio click callback.

    The ``iterations`` value is cast to ``int`` before being forwarded
    (presumably because the Gradio slider can deliver it as a float).
    """

    def swap(source_face, target_face, shape_rate, id_rate, iterations):
        return swapper.inference(source_face, target_face, shape_rate, id_rate, int(iterations))

    return swap


def _build_image_swap_tab(title, swap_fn):
    """Add one image-swap tab to the ``gr.Blocks`` context that is currently open.

    The tab contains source/target image inputs, "3d similarity",
    "id similarity" and "iters" sliders, an "image swap" button wired to
    ``swap_fn(source, target, shape_rate, id_rate, iterations)``, and a
    "Result" image that displays the callback's return value.
    """
    with gr.Tab(title):
        with gr.Row():
            # `shape` is intentionally not passed: None was already its default
            # in Gradio 3.x, and Gradio 4 removed the argument (TypeError there).
            source_image = gr.Image(label="source image")
            target_image = gr.Image(label="target image")
        with gr.Row():
            with gr.Column():
                structure_sim = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=1.0,
                    step=0.1,
                    label="3d similarity",
                )
                id_sim = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=1.0,
                    step=0.1,
                    label="id similarity",
                )
                iters = gr.Slider(
                    minimum=1,
                    maximum=10,
                    value=1,
                    step=1,
                    label="iters",
                )
                image_btn = gr.Button("image swap")
            output_image = gr.Image(label="Result")
        image_btn.click(
            fn=swap_fn,
            inputs=[source_image, target_image, structure_sim, id_sim, iters],
            outputs=output_image,
        )


def main(argv=None):
    """Load two HifiFace checkpoints and serve a Gradio face-swap demo.

    The first checkpoint ("standard model") is chosen on the command line; the
    second ("with_gaze_and_mouth") is fixed. Each gets its own image-swap tab.
    The demo is launched on 0.0.0.0:7860 and this call blocks while serving.

    Args:
        argv: Argument list to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, matching the previous behaviour.
    """
    cfg = ConfigPath()
    parser = argparse.ArgumentParser(
        prog="benchmark",
        description="Serve a Gradio demo that swaps faces with two HifiFace checkpoints.",
    )
    parser.add_argument(
        "-m",
        "--model_path",
        default="/checkpoints/hififace_pretrained/standard_model",
        help="checkpoint directory of the standard model",
    )
    # argparse applies `type` to string defaults too, so the default stays 320000.
    parser.add_argument(
        "-i",
        "--model_idx",
        type=int,
        default="320000",
        help="checkpoint index of the standard model",
    )
    parser.add_argument(
        "-f",
        "--ffmpeg_device",
        default="cpu",
        help="value stored as cfg.ffmpeg_device (e.g. cpu or cuda)",
    )
    parser.add_argument(
        "-d",
        "--device",
        default="cpu",
        help="device the HifiFace models are created on (e.g. cpu or cuda)",
    )
    args = parser.parse_args(argv)
    cfg.model_path = args.model_path
    cfg.model_idx = args.model_idx
    cfg.ffmpeg_device = args.ffmpeg_device
    cfg.device = args.device

    opt = TrainConfig()
    # Second, fixed checkpoint: model trained with the eye & mouth heatmap loss.
    # Index kept as an int, matching the type used for the standard checkpoint.
    gaze_mouth_model_path = "/checkpoints/hififace_pretrained/with_gaze_and_mouth"
    gaze_mouth_model_idx = 190000

    model = HifiFace(
        opt.identity_extractor_config,
        is_training=False,
        device=cfg.device,
        load_checkpoint=(cfg.model_path, cfg.model_idx),
    )
    model1 = HifiFace(
        opt.identity_extractor_config,
        is_training=False,
        device=cfg.device,
        load_checkpoint=(gaze_mouth_model_path, gaze_mouth_model_idx),
    )
    image_infer = ImageSwap(cfg, model)
    image_infer1 = ImageSwap(cfg, model1)

    model_name = _checkpoint_label(cfg.model_path, cfg.model_idx)
    model_name1 = _checkpoint_label(gaze_mouth_model_path, gaze_mouth_model_idx)

    with gr.Blocks(title="FaceSwap") as demo:
        gr.Markdown(
            f"### standard model: {model_name}\n\n"
            f"### model with eye and mouth hm loss: {model_name1}"
        )
        _build_image_swap_tab("Image swap with standard model", _make_swap_fn(image_infer))
        _build_image_swap_tab(
            "Image swap with eye&mouth hm loss model", _make_swap_fn(image_infer1)
        )
    demo.launch(server_name="0.0.0.0", server_port=7860)
if __name__ == "__main__":
    # Start the demo only when executed as a script, never on import.
    main()