# FoleyCrafter / app.py
# Hugging Face Space by fantaxy — last commit 4ba14c9 ("Update app.py", verified).
import torch
import torchvision
import os
import os.path as osp
import spaces
import random
from argparse import ArgumentParser
from datetime import datetime
import gradio as gr
from foleycrafter.utils.util import build_foleycrafter, read_frames_with_moviepy
from foleycrafter.pipelines.auffusion_pipeline import denormalize_spectrogram
from foleycrafter.pipelines.auffusion_pipeline import Generator
from foleycrafter.models.time_detector.model import VideoOnsetNet
from foleycrafter.models.specvqgan.onset_baseline.utils import torch_utils
from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
from huggingface_hub import snapshot_download
from diffusers import DDIMScheduler, EulerDiscreteScheduler, PNDMScheduler
import soundfile as sf
from moviepy.editor import AudioFileClip, VideoFileClip
# Store Gradio uploads / temporary files under ./tmp instead of the system temp dir.
os.environ['GRADIO_TEMP_DIR'] = './tmp'
sample_idx = 0
# UI label -> diffusers scheduler class offered in the "Sampling method" dropdown.
scheduler_dict = {
    "DDIM": DDIMScheduler,
    "Euler": EulerDiscreteScheduler,
    "PNDM": PNDMScheduler,
}
# Extra CSS for the Blocks app: hides the default Gradio footer.
css = """
footer {
visibility: hidden;
}
"""


def _str2bool(value):
    """Convert a command-line token such as "true", "False", "1" or "no" to a bool.

    ``argparse`` with ``type=bool`` calls ``bool("False")``, which is ``True``
    for every non-empty string, so ``--share False`` could never turn sharing
    off. This converter accepts the usual spellings explicitly.

    Raises:
        ValueError: if ``value`` is not a recognised boolean spelling; argparse
            reports this to the user as an invalid argument value.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ("1", "true", "t", "yes", "y", "on"):
        return True
    if token in ("0", "false", "f", "no", "n", "off"):
        return False
    raise ValueError(f"expected a boolean value, got {value!r}")


# Command-line options; `args` is read by FoleyController and by demo.launch().
parser = ArgumentParser()
parser.add_argument("--config", type=str, default="example/config/base.yaml")
parser.add_argument("--server-name", type=str, default="0.0.0.0")
parser.add_argument("--port", type=int, default=7860)
parser.add_argument("--share", type=_str2bool, default=True)
parser.add_argument("--save-path", default="samples")
args = parser.parse_args()
# Default text pre-filled in the "Negative prompt" box (empty: no negative prompt).
N_PROMPT = (
    ""
)
class FoleyController:
    """Load the FoleyCrafter models once and turn input videos into videos with generated audio.

    ``__init__`` creates a timestamped output directory under ``args.save_path``
    and loads every model via :meth:`load_model`. :meth:`foley` moves the
    models to CUDA and runs onset detection, CLIP encoding, spectrogram
    diffusion, vocoding and muxing for one request.
    """

    def __init__(self):
        # config dirs
        self.basedir = os.getcwd()
        self.model_dir = os.path.join(self.basedir, "models")
        self.savedir = os.path.join(self.basedir, args.save_path, datetime.now().strftime("Gradio-%Y-%m-%dT%H-%M-%S"))
        self.savedir_sample = os.path.join(self.savedir, "sample")
        os.makedirs(self.savedir, exist_ok=True)
        self.pipeline = None
        self.loaded = False
        self.load_model()

    def load_model(self):
        """Download checkpoints if needed and build every model used by :meth:`foley`.

        Sets ``vocoder``, ``time_detector``, ``pipeline`` (with the temporal
        adapter loaded into its ControlNet and the semantic IP-Adapter
        attached), ``image_processor`` and ``image_encoder`` on ``self``, then
        sets ``self.loaded = True``.

        Returns:
            The string ``"Load"``.
        """
        gr.Info("Start Load Models...")
        print("Start Load Models...")
        # Download checkpoints unless a local directory named like the repo id already exists.
        pretrained_model_name_or_path = 'auffusion/auffusion-full-no-adapter'
        if not os.path.isdir(pretrained_model_name_or_path):
            pretrained_model_name_or_path = snapshot_download(pretrained_model_name_or_path, local_dir='models/auffusion')
        fc_ckpt = 'ymzhang319/FoleyCrafter'
        if not os.path.isdir(fc_ckpt):
            fc_ckpt = snapshot_download(fc_ckpt, local_dir='models/')

        # FoleyCrafter checkpoints are expected directly under ./models (the download target above).
        temporal_ckpt_path = osp.join(self.model_dir, 'temporal_adapter.ckpt')
        time_detector_ckpt = osp.join(self.model_dir, 'timestamp_detector.pth.tar')

        # Vocoder: load from the Auffusion directory resolved above rather than a
        # hard-coded path, so a pre-existing local copy is used consistently.
        self.vocoder = Generator.from_pretrained(
            pretrained_model_name_or_path,
            subfolder="vocoder")

        # Onset (timestamp) detector used to build the temporal condition.
        time_detector = VideoOnsetNet(False)
        self.time_detector, _ = torch_utils.load_model(time_detector_ckpt, time_detector, strict=True)

        self.pipeline = build_foleycrafter()

        # Temporal adapter weights -> pipeline.controlnet. map_location keeps the
        # load working even if the checkpoint was saved from GPU tensors.
        ckpt = torch.load(temporal_ckpt_path, map_location="cpu")
        if 'state_dict' in ckpt:
            ckpt = ckpt['state_dict']
        # Strip a leading 'module.' (presumably left by a DataParallel/DDP wrapper at save time).
        load_gligen_ckpt = {
            (key[len('module.'):] if key.startswith('module.') else key): value
            for key, value in ckpt.items()
        }
        m, u = self.pipeline.controlnet.load_state_dict(load_gligen_ckpt, strict=False)
        print(f"### Control Net missing keys: {len(m)}; \n### unexpected keys: {len(u)};")

        # Semantic (visual content) conditioning: CLIP image encoder + IP-Adapter weights.
        self.image_processor = CLIPImageProcessor()
        self.image_encoder = CLIPVisionModelWithProjection.from_pretrained('h94/IP-Adapter', subfolder='models/image_encoder')
        self.pipeline.load_ip_adapter(fc_ckpt, subfolder='semantic', weight_name='semantic_adapter.bin', image_encoder_folder=None)
        gr.Info("Load Finish!")
        print("Load Finish!")
        self.loaded = True
        return "Load"

    @spaces.GPU
    def foley(
        self,
        input_video,
        prompt_textbox,
        negative_prompt_textbox,
        ip_adapter_scale,
        temporal_scale,
        sampler_dropdown,
        sample_step_slider,
        cfg_scale_slider,
        seed_textbox,
    ):
        """Generate audio for ``input_video`` and return the path of the re-muxed video.

        Args:
            input_video: Path of the uploaded video file.
            prompt_textbox: Text prompt for the diffusion pipeline.
            negative_prompt_textbox: Negative text prompt.
            ip_adapter_scale: Weight of the visual (IP-Adapter) conditioning.
            temporal_scale: ControlNet conditioning scale for the onset-based timing map.
            sampler_dropdown: Scheduler name from the UI.
                NOTE(review): currently ignored; the pipeline keeps its own scheduler.
            sample_step_slider: Number of denoising steps.
            cfg_scale_slider: CFG scale from the UI.
                NOTE(review): currently ignored; the pipeline's default guidance is used.
            seed_textbox: Integer seed as text; empty means unseeded.

        Returns:
            Path of ``output.mp4`` in ``self.savedir_sample``: the first
            ``min(video duration, 10)`` seconds of the input with the generated audio.

        Raises:
            gr.Error: if the models are not loaded, no video was given, or the
                seed is not an integer.
        """
        device = 'cuda'
        if not self.loaded:
            raise gr.Error("Error with loading model")
        if input_video is None:
            raise gr.Error("Please upload an input video.")

        # move to gpu (use this instance's models, not a module-level global)
        self.time_detector = self.time_detector.to(device)
        self.pipeline = self.pipeline.to(device)
        self.vocoder = self.vocoder.to(device)
        self.image_encoder = self.image_encoder.to(device)

        generator = torch.Generator()
        seed = "" if seed_textbox is None else str(seed_textbox).strip()
        if seed:
            try:
                seed_value = int(seed)
            except ValueError:
                raise gr.Error(f"Seed must be an integer, got {seed_textbox!r}") from None
            torch.manual_seed(seed_value)
            generator.manual_seed(seed_value)

        max_frame_nums = 150
        frames, duration = read_frames_with_moviepy(input_video, max_frame_nums=max_frame_nums)
        # The 1024-column spectrogram / 160000-sample vocoder output covers at most 10 s.
        duration = min(duration, 10)

        # Pure inference: don't build autograd graphs for the detector, CLIP or vocoder.
        with torch.no_grad():
            time_condition = self._time_condition(frames, duration, max_frame_nums, device)
            image_embeddings = self._image_embeddings(frames, device)
            self.pipeline.set_ip_adapter_scale(ip_adapter_scale)
            sample = self.pipeline(
                prompt=prompt_textbox,
                negative_prompt=negative_prompt_textbox,
                ip_adapter_image_embeds=image_embeddings,
                image=time_condition,
                controlnet_conditioning_scale=float(temporal_scale),
                num_inference_steps=int(sample_step_slider),
                height=256,
                width=1024,
                output_type="pt",
                generator=generator,
            )
            audio_img = sample.images[0]
            audio = denormalize_spectrogram(audio_img)
            # 160000 samples = 10 s at the 16 kHz rate used when the wav is written.
            audio = self.vocoder.inference(audio, lengths=160000)[0]

        return self._write_outputs(audio, input_video, duration, 'output')

    def _time_condition(self, frames, duration, max_frame_nums, device):
        """Build the (1, 1, 256, 1024) ±1 timing map from onset predictions on ``frames``."""
        video_transform = torchvision.transforms.Compose([
            torchvision.transforms.Resize((128, 128)),
            torchvision.transforms.CenterCrop((112, 112)),
            torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        time_frames = torch.FloatTensor(frames).permute(0, 3, 1, 2).to(device)
        time_frames = video_transform(time_frames)
        time_frames = {'frames': time_frames.unsqueeze(0).permute(0, 2, 1, 3, 4)}
        preds = torch.sigmoid(self.time_detector(time_frames))
        # Threshold once and copy to the host instead of syncing the GPU per column.
        below_threshold = (preds[0] < 0.5).cpu()
        # One value per spectrogram column spanning `duration` seconds; column i reads the
        # prediction at the same relative position (assumes preds[0] has max_frame_nums entries).
        num_columns = int(1024 / 10 * duration)
        time_condition = [
            -1 if below_threshold[int(i / (1024 / 10 * duration) * max_frame_nums)] else 1
            for i in range(num_columns)
        ]
        # Fill the rest of the 1024-column canvas with -1 (no onset).
        time_condition = time_condition + [-1] * (1024 - len(time_condition))
        # w -> b c h w
        return torch.FloatTensor(time_condition).unsqueeze(0).unsqueeze(0).unsqueeze(0).repeat(1, 1, 256, 1)

    def _image_embeddings(self, frames, device):
        """Return [zeros, mean CLIP embedding of every 10th frame] concatenated on dim 1."""
        # CLIP needs fewer frames than the onset detector.
        images = self.image_processor(images=frames[::10], return_tensors="pt").to(device)
        image_embeddings = self.image_encoder(**images).image_embeds
        image_embeddings = torch.mean(image_embeddings, dim=0, keepdim=True).unsqueeze(0).unsqueeze(0)
        # Zero embedding first — presumably the negative/unconditional branch the pipeline expects.
        neg_image_embeddings = torch.zeros_like(image_embeddings)
        return torch.cat([neg_image_embeddings, image_embeddings], dim=1)

    def _write_outputs(self, audio, input_video, duration, name):
        """Save ``audio`` as a 16 kHz wav, mux it onto the input clip, and return the mp4 path."""
        audio_save_path = osp.join(self.savedir_sample, 'audio')
        os.makedirs(audio_save_path, exist_ok=True)
        wav_path = osp.join(audio_save_path, f'{name}.wav')
        sf.write(wav_path, audio[:int(duration * 16000)], 16000)
        save_sample_path = osp.join(self.savedir_sample, f'{name}.mp4')
        # Context managers close the clips' readers (and their ffmpeg processes) even on error.
        with AudioFileClip(wav_path) as audio_clip, VideoFileClip(input_video) as video_clip:
            video_clip.audio = audio_clip.subclip(0, duration)
            video_clip.subclip(0, duration).write_videofile(save_sample_path)
        return save_sample_path
# Build all models once at startup; every Gradio event reuses this instance.
controller = FoleyController()
device = "cuda" if torch.cuda.is_available() else "cpu"
with gr.Blocks(theme="Nymbo/Nymbo_Theme", css=css) as demo:
    with gr.Row():
        with gr.Column(variant="panel"):
            with gr.Row(equal_height=False):
                # Left column: input video and generation controls.
                with gr.Column():
                    with gr.Row():
                        init_img = gr.Video(label="Input Video")
                    with gr.Row():
                        prompt_textbox = gr.Textbox(value='', label="Prompt", lines=1)
                    with gr.Row():
                        negative_prompt_textbox = gr.Textbox(value=N_PROMPT, label="Negative prompt", lines=1)
                    with gr.Row():
                        ip_adapter_scale = gr.Slider(label="Visual Content Scale", value=1.0, minimum=0, maximum=1)
                        temporal_scale = gr.Slider(label="Temporal Align Scale", value=0.2, minimum=0., maximum=1.0)
                    with gr.Accordion("Sampling Settings", open=False):
                        with gr.Row():
                            sampler_dropdown = gr.Dropdown(
                                label="Sampling method",
                                choices=list(scheduler_dict.keys()),
                                value=list(scheduler_dict.keys())[0],
                            )
                            sample_step_slider = gr.Slider(
                                label="Sampling steps", value=25, minimum=10, maximum=100, step=1
                            )
                        cfg_scale_slider = gr.Slider(label="CFG Scale", value=7.5, minimum=0, maximum=20)
                    with gr.Row():
                        seed_textbox = gr.Textbox(label="Seed", value=42)
                        seed_button = gr.Button(value="\U0001f3b2", elem_classes="toolbutton")
                    # The click has no inputs, so the callback must take no arguments, and
                    # randint needs int bounds (a float bound raises TypeError on Python 3.12+).
                    seed_button.click(fn=lambda: random.randint(1, 10**8), outputs=[seed_textbox], queue=False)
                    generate_button = gr.Button(value="Generate", variant="primary")
                # Right column: generated result.
                with gr.Column():
                    result_video = gr.Video(label="Generated Audio", interactive=False)
                    with gr.Row():
                        pass
            generate_button.click(
                fn=controller.foley,
                inputs=[
                    init_img,
                    prompt_textbox,
                    negative_prompt_textbox,
                    ip_adapter_scale,
                    temporal_scale,
                    sampler_dropdown,
                    sample_step_slider,
                    cfg_scale_slider,
                    seed_textbox,
                ],
                outputs=[result_video],
            )
            # Example rows follow the order of `inputs`; cached, so foley runs on them at startup.
            gr.Examples(
                examples=[
                    ['examples/input/case1.mp4', '', '', 1.0, 0.2, 'DDIM', 25, 7.5, 33817921],
                    ['examples/input/case3.mp4', '', '', 1.0, 0.2, 'DDIM', 25, 7.5, 94667578],
                    ['examples/input/case5.mp4', '', '', 0.75, 0.2, 'DDIM', 25, 7.5, 92890876],
                    ['examples/input/case6.mp4', '', '', 1.0, 0.2, 'DDIM', 25, 7.5, 77015909],
                ],
                inputs=[init_img, prompt_textbox, negative_prompt_textbox, ip_adapter_scale, temporal_scale, sampler_dropdown, sample_step_slider, cfg_scale_slider, seed_textbox],
                cache_examples=True,
                outputs=[result_video],
                fn=controller.foley,
            )
# NOTE(review): the meaning of queue()'s first positional argument differs between
# Gradio versions (concurrency_count in 3.x, status_update_rate in 4.x) — confirm intent.
demo.queue(10)
demo.launch(server_name=args.server_name, server_port=args.port, share=args.share, allowed_paths=["./foleycrafter.png"])