Spaces:
Paused
Paused
import torch | |
import torchvision | |
import os | |
import os.path as osp | |
import spaces | |
import random | |
from argparse import ArgumentParser | |
from datetime import datetime | |
import gradio as gr | |
from foleycrafter.utils.util import build_foleycrafter, read_frames_with_moviepy | |
from foleycrafter.pipelines.auffusion_pipeline import denormalize_spectrogram | |
from foleycrafter.pipelines.auffusion_pipeline import Generator | |
from foleycrafter.models.time_detector.model import VideoOnsetNet | |
from foleycrafter.models.specvqgan.onset_baseline.utils import torch_utils | |
from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor | |
from huggingface_hub import snapshot_download | |
from diffusers import DDIMScheduler, EulerDiscreteScheduler, PNDMScheduler | |
import soundfile as sf | |
from moviepy.editor import AudioFileClip, VideoFileClip | |
# Gradio reads this when components/Blocks are created, so setting it before
# the UI is built is enough.
os.environ['GRADIO_TEMP_DIR'] = './tmp'
sample_idx = 0
# UI "Sampling method" choice -> diffusers scheduler class.
scheduler_dict = {
    "DDIM": DDIMScheduler,
    "Euler": EulerDiscreteScheduler,
    "PNDM": PNDMScheduler,
}
# Styling for the small square dice (random seed) button.
# `margin-bottom` takes a single length; the previous misspelled
# `margin-buttom` with four values was dropped by the browser.
css = """
.toolbutton {
    margin-bottom: 0em;
    max-width: 2.5em;
    min-width: 2.5em !important;
    height: 2.5em;
}
"""
parser = ArgumentParser()
parser.add_argument("--config", type=str, default="example/config/base.yaml")
parser.add_argument("--server-name", type=str, default="0.0.0.0")
parser.add_argument("--port", type=int, default=7860)
parser.add_argument("--share", action="store_true")
parser.add_argument("--save-path", default="samples")
args = parser.parse_args()
# Default text for the "Negative prompt" box (empty by default).
N_PROMPT = (
    ""
)
class FoleyController:
    """Own the FoleyCrafter models and add generated sound to a silent video.

    The constructor creates a timestamped output directory and eagerly calls
    :meth:`load_model`; :meth:`foley` is the Gradio callback that runs one
    generation and returns the path of the resulting video.
    """

    def __init__(self):
        # config dirs
        self.basedir = os.getcwd()
        self.model_dir = os.path.join(self.basedir, "models")
        # one timestamped output directory per app launch
        self.savedir = os.path.join(self.basedir, args.save_path, datetime.now().strftime("Gradio-%Y-%m-%dT%H-%M-%S"))
        self.savedir_sample = os.path.join(self.savedir, "sample")
        os.makedirs(self.savedir, exist_ok=True)
        self.pipeline = None
        self.loaded = False
        self.load_model()

    @staticmethod
    def _strip_module_prefix(state_dict):
        """Return a copy of *state_dict* with a leading ``module.`` removed from each key.

        That prefix is what ``torch.nn.DataParallel``/DDP-wrapped models
        typically save with; the bare ControlNet expects unprefixed keys.
        On a key collision the later entry wins, as with a plain dict update.
        """
        prefix = 'module.'
        return {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }

    def load_model(self):
        """Download any missing checkpoints and load every model :meth:`foley` uses.

        Sets ``self.vocoder``, ``self.time_detector``, ``self.pipeline``,
        ``self.image_processor`` and ``self.image_encoder`` and marks the
        controller as loaded.

        Returns:
            The string ``"Load"``.
        """
        gr.Info("Start Load Models...")
        print("Start Load Models...")
        # download ckpt (an existing local directory named like the repo id wins)
        pretrained_model_name_or_path = 'auffusion/auffusion-full-no-adapter'
        if not os.path.isdir(pretrained_model_name_or_path):
            pretrained_model_name_or_path = snapshot_download(pretrained_model_name_or_path, local_dir='models/auffusion')
        fc_ckpt = 'ymzhang319/FoleyCrafter'
        if not os.path.isdir(fc_ckpt):
            fc_ckpt = snapshot_download(fc_ckpt, local_dir='models/')
        # Resolve all FoleyCrafter checkpoints from the same directory the
        # semantic adapter is loaded from below.
        temporal_ckpt_path = osp.join(fc_ckpt, 'temporal_adapter.ckpt')
        time_detector_ckpt = osp.join(fc_ckpt, 'timestamp_detector.pth.tar')

        # load vocoder from the Auffusion checkout resolved above
        self.vocoder = Generator.from_pretrained(pretrained_model_name_or_path, subfolder="vocoder")

        # load time detector
        time_detector = VideoOnsetNet(False)
        self.time_detector, _ = torch_utils.load_model(time_detector_ckpt, time_detector, strict=True)

        self.pipeline = build_foleycrafter()

        # Load the temporal adapter into the ControlNet. map_location lets a
        # GPU-saved checkpoint load on CPU-only hosts. NOTE: torch.load
        # unpickles, so it must only ever point at trusted checkpoints.
        ckpt = torch.load(temporal_ckpt_path, map_location="cpu")
        if 'state_dict' in ckpt:
            ckpt = ckpt['state_dict']
        missing, unexpected = self.pipeline.controlnet.load_state_dict(self._strip_module_prefix(ckpt), strict=False)
        print(f"### Control Net missing keys: {len(missing)}; \n### unexpected keys: {len(unexpected)};")

        # visual-content (IP-Adapter) branch
        self.image_processor = CLIPImageProcessor()
        self.image_encoder = CLIPVisionModelWithProjection.from_pretrained('h94/IP-Adapter', subfolder='models/image_encoder')
        self.pipeline.load_ip_adapter(fc_ckpt, subfolder='semantic', weight_name='semantic_adapter.bin', image_encoder_folder=None)
        gr.Info("Load Finish!")
        print("Load Finish!")
        self.loaded = True
        return "Load"

    @staticmethod
    def _make_generator(seed):
        """Return a CPU ``torch.Generator``, seeded (along with torch's global RNG) if *seed* is given.

        A ``None``/blank seed leaves both RNGs unseeded.

        Raises:
            gr.Error: If *seed* is non-blank but not an integer.
        """
        generator = torch.Generator()
        seed = "" if seed is None else str(seed).strip()
        if seed:
            try:
                seed_value = int(seed)
            except ValueError as err:
                raise gr.Error(f"Seed must be an integer, got {seed!r}") from err
            torch.manual_seed(seed_value)
            generator.manual_seed(seed_value)
        return generator

    def _detect_onsets(self, frames, device):
        """Run the timestamp detector on *frames* and return sigmoid probabilities."""
        video_transform = torchvision.transforms.Compose([
            torchvision.transforms.Resize((128, 128)),
            torchvision.transforms.CenterCrop((112, 112)),
            torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        # assumes frames is (T, H, W, C) — the permute makes it (T, C, H, W)
        time_frames = torch.FloatTensor(frames).permute(0, 3, 1, 2).to(device)
        time_frames = video_transform(time_frames)
        # detector input: add a batch dim and move channels before time
        time_frames = {'frames': time_frames.unsqueeze(0).permute(0, 2, 1, 3, 4)}
        return torch.sigmoid(self.time_detector(time_frames))

    @staticmethod
    def _build_time_condition(preds_row, duration, max_frame_nums, length=1024, height=256, max_seconds=10):
        """Turn per-frame onset probabilities into the ControlNet condition map.

        Each of the first ``length / max_seconds * duration`` columns is set to
        ``1`` when the probability of its (linearly mapped) frame is >= 0.5 and
        ``-1`` otherwise; remaining columns up to *length* are ``-1``. The row
        is repeated *height* times.

        Args:
            preds_row: Probabilities indexable by frame (list or tensor).
            duration: Clip duration in seconds (caller caps it at *max_seconds*).
            max_frame_nums: Number of frames the columns are spread over.

        Returns:
            A float32 tensor of shape ``(1, 1, height, length)``.
        """
        if isinstance(preds_row, torch.Tensor):
            # one host transfer instead of a device sync per column
            preds_row = preds_row.detach().reshape(-1).cpu().tolist()
        span = length / max_seconds * duration
        active = [
            -1 if preds_row[int(i / span * max_frame_nums)] < 0.5 else 1
            for i in range(int(span))
        ]
        condition = active + [-1] * (length - len(active))
        # w -> b c h w
        return torch.tensor(condition, dtype=torch.float32).reshape(1, 1, 1, -1).repeat(1, 1, height, 1)

    def _encode_visual_content(self, frames, device):
        """Return CLIP image embeddings of *frames*, averaged over frames, paired with zeros.

        The zero embedding is concatenated first along dim 1 — presumably the
        negative half the pipeline uses for classifier-free guidance; confirm
        against the pipeline's ``ip_adapter_image_embeds`` handling.
        """
        images = self.image_processor(images=frames, return_tensors="pt").to(device)
        image_embeddings = self.image_encoder(**images).image_embeds
        image_embeddings = torch.mean(image_embeddings, dim=0, keepdim=True).unsqueeze(0).unsqueeze(0)
        neg_image_embeddings = torch.zeros_like(image_embeddings)
        return torch.cat([neg_image_embeddings, image_embeddings], dim=1)

    def _configure_scheduler(self, sampler_name):
        """Switch the pipeline to the scheduler class named *sampler_name*, reusing the current config.

        Unknown names and an already-matching scheduler leave the pipeline unchanged.
        """
        scheduler_cls = scheduler_dict.get(sampler_name)
        if scheduler_cls is None or isinstance(self.pipeline.scheduler, scheduler_cls):
            return
        self.pipeline.scheduler = scheduler_cls.from_config(self.pipeline.scheduler.config)

    @staticmethod
    def _mux_audio(video_path, wav_path, duration, out_path):
        """Write *out_path*: the first *duration* seconds of *video_path* with *wav_path* as audio."""
        audio = AudioFileClip(wav_path)
        video = VideoFileClip(video_path)
        try:
            video.audio = audio.subclip(0, duration)
            video.subclip(0, duration).write_videofile(out_path)
        finally:
            # release the file readers moviepy keeps open for each clip
            audio.close()
            video.close()

    def foley(
        self,
        input_video,
        prompt_textbox,
        negative_prompt_textbox,
        ip_adapter_scale,
        temporal_scale,
        sampler_dropdown,
        sample_step_slider,
        cfg_scale_slider,
        seed_textbox,
    ):
        """Generate a soundtrack for *input_video* and return the path of the muxed video.

        Args:
            input_video: Path of the uploaded video file.
            prompt_textbox: Text prompt.
            negative_prompt_textbox: Negative text prompt.
            ip_adapter_scale: Scale of the visual-content (IP-Adapter) conditioning.
            temporal_scale: ControlNet conditioning scale for onset alignment.
            sampler_dropdown: Key of ``scheduler_dict`` selecting the scheduler.
            sample_step_slider: Number of inference steps.
            cfg_scale_slider: Guidance scale passed to the pipeline.
            seed_textbox: Integer seed as text; blank means unseeded.

        Returns:
            Path of ``output.mp4`` in ``self.savedir_sample`` (overwritten on
            every call).

        Raises:
            gr.Error: If the seed is not an integer.
        """
        device = "cuda" if torch.cuda.is_available() else "cpu"
        # move models to the inference device
        self.time_detector = self.time_detector.to(device)
        self.pipeline = self.pipeline.to(device)
        self.vocoder = self.vocoder.to(device)
        self.image_encoder = self.image_encoder.to(device)

        generator = self._make_generator(seed_textbox)

        max_frame_nums = 30
        sample_rate = 16000
        frames, duration = read_frames_with_moviepy(input_video, max_frame_nums=max_frame_nums)
        # the vocoder below produces 160000 samples, i.e. 10 s at 16 kHz
        duration = min(duration, 10)

        with torch.no_grad():
            preds = self._detect_onsets(frames, device)
            time_condition = self._build_time_condition(preds[0], duration, max_frame_nums)
            image_embeddings = self._encode_visual_content(frames, device)

            self._configure_scheduler(sampler_dropdown)
            self.pipeline.set_ip_adapter_scale(ip_adapter_scale)
            sample = self.pipeline(
                prompt=prompt_textbox,
                negative_prompt=negative_prompt_textbox,
                ip_adapter_image_embeds=image_embeddings,
                image=time_condition,
                controlnet_conditioning_scale=float(temporal_scale),
                num_inference_steps=int(sample_step_slider),
                guidance_scale=float(cfg_scale_slider),
                height=256,
                width=1024,
                output_type="pt",
                generator=generator,
            )
            audio = denormalize_spectrogram(sample.images[0])
            audio = self.vocoder.inference(audio, lengths=160000)[0]

        name = 'output'
        audio_save_path = osp.join(self.savedir_sample, 'audio')
        os.makedirs(audio_save_path, exist_ok=True)
        wav_path = osp.join(audio_save_path, f'{name}.wav')
        sf.write(wav_path, audio[:int(duration * sample_rate)], sample_rate)

        save_sample_path = osp.join(self.savedir_sample, f'{name}.mp4')
        self._mux_audio(input_video, wav_path, duration, save_sample_path)
        return save_sample_path
# Load all models once at startup; the UI callbacks reuse this instance.
controller = FoleyController()
device = "cuda" if torch.cuda.is_available() else "cpu"
with gr.Blocks(css=css) as demo:
    gr.HTML(
            '<h1 style="height: 136px; display: flex; align-items: center; justify-content: space-around;"><span style="height: 100%; width:136px;"><img src="file/foleycrafter.png" alt="logo" style="height: 100%; width:auto; object-fit: contain; margin: 0px 0px; padding: 0px 0px;"></span><strong style="font-size: 40px;">FoleyCrafter: Bring Silent Videos to Life with Lifelike and Synchronized Sounds</strong></h1>'
        )
    with gr.Row():
        # TODO: replace the placeholder arXiv id in the "Paper" link.
        gr.Markdown(
            "<div align='center'><font size='5'><a href='https://foleycrafter.github.io/'>Project Page</a>  " # noqa
            "<a href='https://arxiv.org/abs/xxxx.xxxxx/'>Paper</a>  "
            "<a href='https://github.com/open-mmlab/foleycrafter'>Code</a>  "
            "<a href='https://huggingface.co/spaces/ymzhang319/FoleyCrafter'>Demo</a> </font></div>"
        )
    with gr.Column(variant="panel"):
        with gr.Row(equal_height=False):
            with gr.Column():
                with gr.Row():
                    init_img = gr.Video(label="Input Video")
                with gr.Row():
                    prompt_textbox = gr.Textbox(value='', label="Prompt", lines=1)
                with gr.Row():
                    negative_prompt_textbox = gr.Textbox(value=N_PROMPT, label="Negative prompt", lines=1)
                with gr.Row():
                    sampler_dropdown = gr.Dropdown(
                        label="Sampling method",
                        choices=list(scheduler_dict.keys()),
                        value=list(scheduler_dict.keys())[0],
                    )
                    sample_step_slider = gr.Slider(
                        label="Sampling steps", value=25, minimum=10, maximum=100, step=1
                    )
                cfg_scale_slider = gr.Slider(label="CFG Scale", value=7.5, minimum=0, maximum=20)
                ip_adapter_scale = gr.Slider(label="Visual Content Scale", value=1.0, minimum=0, maximum=1)
                temporal_scale = gr.Slider(label="Temporal Align Scale", value=0., minimum=0., maximum=1.0)
                with gr.Row():
                    seed_textbox = gr.Textbox(label="Seed", value=42)
                    seed_button = gr.Button(value="\U0001f3b2", elem_classes="toolbutton")
                    # No inputs are wired, so the callback must take no
                    # arguments; randint also needs integer bounds (a float
                    # bound raises TypeError on Python >= 3.12).
                    seed_button.click(fn=lambda: random.randint(1, 10**8), outputs=[seed_textbox], queue=False)
                generate_button = gr.Button(value="Generate", variant="primary")
            result_video = gr.Video(label="Generated Audio", interactive=False)
        generate_button.click(
            fn=controller.foley,
            inputs=[
                init_img,
                prompt_textbox,
                negative_prompt_textbox,
                ip_adapter_scale,
                temporal_scale,
                sampler_dropdown,
                sample_step_slider,
                cfg_scale_slider,
                seed_textbox,
            ],
            outputs=[result_video],
        )
        gr.Examples(
            examples= [
                ['examples/videos/1.mp4', '', '', 1.0, 0.0, 'DDIM', 25, 7.5, 93493458],
                ['examples/videos/2.mp4', '', '', 1.0, 0.0, 'DDIM', 25, 7.5, 51972214],
                ['examples/videos/3.mp4', '', '', 1.0, 0.0, 'DDIM', 25, 7.5, 92530687],
            ],
            inputs=[init_img,prompt_textbox,negative_prompt_textbox,ip_adapter_scale,temporal_scale,sampler_dropdown,sample_step_slider,cfg_scale_slider,seed_textbox],
        )
demo.queue(10)
demo.launch(server_name=args.server_name, server_port=args.port, share=args.share, allowed_paths=["./foleycrafter.png"])