# Spaces: Paused
import os
import random
import uuid

import gradio as gr
import imageio
import numpy as np
import spaces
import torch
import torchvision
from fastapi import FastAPI
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import FileResponse
from gradio.components import Textbox, Video
from huggingface_hub import hf_hub_download
from omegaconf import OmegaConf

from pipeline.t2v_turbo_vc2_pipeline import T2VTurboVC2Pipeline
from scheduler.t2v_turbo_scheduler import T2VTurboScheduler
from utils.common_utils import load_model_checkpoint
from utils.utils import instantiate_from_config
# Keep all your original constants and DESCRIPTION
# Upper bound for user-selectable and randomly drawn seeds: the largest value
# representable by a signed 32-bit integer.
MAX_SEED = np.iinfo(np.int32).max

# FastAPI application object; the Gradio interface is mounted onto it in the
# ``__main__`` section of this file.
app = FastAPI()
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    """Return the seed to use for a generation request.

    Args:
        seed: Seed supplied by the caller.
        randomize_seed: When true, ``seed`` is ignored and a fresh value is
            drawn uniformly from ``[0, MAX_SEED]`` (both ends inclusive).

    Returns:
        ``seed`` unchanged, or the newly drawn random seed.
    """
    if not randomize_seed:
        return seed
    return random.randint(0, MAX_SEED)
def save_video(video_array, video_save_path, fps: int = 16):
    """Encode a video tensor as an H.264 file at ``video_save_path``.

    Args:
        video_array: Tensor holding one video. The permutes below treat it as
            (c, t, h, w) with values nominally in [-1, 1]; values outside
            that range are clamped.
        video_save_path: Destination path handed to
            ``torchvision.io.write_video``.
        fps: Frame rate written into the output file.
    """
    # Move off the autograd graph / accelerator and bound the value range.
    frames = video_array.detach().cpu().float().clamp(-1.0, 1.0)
    frames = frames.permute(1, 0, 2, 3)  # t,c,h,w
    # Map [-1, 1] -> [0, 1] -> [0, 255]; the uint8 cast truncates.
    frames = (frames + 1.0) / 2.0
    frames = (frames * 255).to(torch.uint8)
    frames = frames.permute(0, 2, 3, 1)  # t,h,w,c

    torchvision.io.write_video(
        video_save_path,
        frames,
        fps=fps,
        video_codec="h264",
        options={"crf": "10"},
    )
# Keep your original example_txt and examples
def generate(
    prompt: str,
    guidance_scale: float = 7.5,
    percentage: float = 0.5,
    num_inference_steps: int = 4,
    num_frames: int = 16,
    seed: int = 0,
    randomize_seed: bool = False,
    param_dtype="bf16",
    motion_gs: float = 0.05,
    fps: int = 8,
    is_api: bool = False,  # New parameter to handle API calls
):
    """Run the text-to-video pipeline for ``prompt`` and save the result.

    Relies on the module-level ``pipeline``, ``unet`` and ``device`` objects
    that are created in the ``__main__`` section of this file.

    Args:
        prompt: Text prompt forwarded to the pipeline.
        guidance_scale: Forwarded to the pipeline as ``guidance_scale``.
        percentage: Forwarded to the pipeline as ``percentage``.
        num_inference_steps: Forwarded as ``num_inference_steps``.
        num_frames: Forwarded to the pipeline as ``frames``.
        seed: Seed passed to ``torch.manual_seed`` (unless randomized).
        randomize_seed: Replace ``seed`` with a random value first.
        param_dtype: One of ``"bf16"``, ``"fp16"`` or ``"fp32"``.
        motion_gs: Forwarded to the pipeline as ``motion_gs``.
        fps: Forwarded to the pipeline and used when encoding the file.
        is_api: Selects the return shape (dict instead of tuple).

    Returns:
        When ``is_api`` is true, a dict with keys ``video_path``, ``prompt``,
        ``model_info`` and ``seed``; otherwise the same four values as a
        tuple ``(video_path, prompt, model_info, seed)``.

    Raises:
        ValueError: If ``param_dtype`` is not a supported name.
    """
    seed = randomize_seed_fn(seed, randomize_seed)
    torch.manual_seed(seed)

    dtype_by_name = {
        "bf16": torch.bfloat16,
        "fp16": torch.float16,
        "fp32": torch.float32,
    }
    try:
        dtype = dtype_by_name[param_dtype]
    except (KeyError, TypeError):
        raise ValueError(f"Unknown dtype: {param_dtype}") from None
    unet.dtype = dtype

    # NOTE(review): these casts mutate the shared pipeline, so concurrent
    # requests asking for different dtypes can interfere with each other.
    pipeline.unet.to(device, dtype)
    pipeline.text_encoder.to(device, dtype)
    pipeline.vae.to(device, dtype)
    pipeline.to(device, dtype)

    result = pipeline(
        prompt=prompt,
        frames=num_frames,
        fps=fps,
        guidance_scale=guidance_scale,
        motion_gs=motion_gs,
        use_motion_cond=True,
        percentage=percentage,
        num_inference_steps=num_inference_steps,
        lcm_origin_steps=200,
        num_videos_per_prompt=1,
    )
    torch.cuda.empty_cache()

    # Every call gets its own output file. A single shared "tmp.mp4" would be
    # overwritten when several requests run at once (the Gradio interface in
    # this file allows up to 10 concurrent calls), so a caller could receive
    # another request's video or a half-written file.
    video_filename = f"{uuid.uuid4()}.mp4"
    root_path = "./videos/"
    os.makedirs(root_path, exist_ok=True)
    video_save_path = os.path.join(root_path, video_filename)
    save_video(result[0], video_save_path, fps=fps)

    display_model_info = f"Video size: {num_frames}x320x512, Sampling Step: {num_inference_steps}, Guidance Scale: {guidance_scale}"

    if is_api:
        return {
            "video_path": video_save_path,
            "prompt": prompt,
            "model_info": display_model_info,
            "seed": seed,
        }
    return video_save_path, prompt, display_model_info, seed
# API endpoint
async def generate_api(
    prompt: str,
    guidance_scale: float = 7.5,
    percentage: float = 0.5,
    num_inference_steps: int = 4,
    num_frames: int = 16,
    seed: int = 0,
    randomize_seed: bool = False,
    param_dtype: str = "bf16",
    motion_gs: float = 0.05,
    fps: int = 8,
):
    """Generate a video and return it as an MP4 file response.

    All arguments are forwarded unchanged to ``generate`` with
    ``is_api=True``.

    Returns:
        A ``FileResponse`` streaming the generated file with media type
        ``video/mp4``. The model-info string and the seed actually used are
        exposed through the ``X-Model-Info`` and ``X-Seed`` headers.

    Raises:
        ValueError: Propagated from ``generate`` for an unknown
            ``param_dtype``.
    """
    # NOTE(review): no route decorator (e.g. ``@app.post(...)``) is visible
    # on this function, so as written it is not registered on ``app`` --
    # confirm how it is meant to be exposed.
    #
    # ``generate`` is synchronous and long-running. Calling it directly inside
    # a coroutine would block the event loop for the whole inference, freezing
    # every other request served by the same app, so run it in a worker
    # thread instead.
    result = await run_in_threadpool(
        generate,
        prompt=prompt,
        guidance_scale=guidance_scale,
        percentage=percentage,
        num_inference_steps=num_inference_steps,
        num_frames=num_frames,
        seed=seed,
        randomize_seed=randomize_seed,
        param_dtype=param_dtype,
        motion_gs=motion_gs,
        fps=fps,
        is_api=True,
    )
    # NOTE(review): the generated file is never deleted after being served;
    # consider a cleanup strategy if disk usage matters.
    return FileResponse(
        result["video_path"],
        media_type="video/mp4",
        headers={
            "X-Model-Info": result["model_info"],
            "X-Seed": str(result["seed"]),
        },
    )
if __name__ == "__main__":
    # Script entry point: build the model and pipeline, wire up the Gradio UI
    # on top of the FastAPI app, and start the server. ``device``, ``unet``
    # and ``pipeline`` bound here are the module globals read by generate().
    device = torch.device("cuda:0")

    # Keep all your original model initialization code
    config = OmegaConf.load("configs/inference_t2v_512_v2.0.yaml")
    model_config = config.pop("model", OmegaConf.create())

    # Base model: instantiate from the config, then load the downloaded
    # checkpoint into it.
    pretrained_t2v = instantiate_from_config(model_config)
    pretrained_path = hf_hub_download(
        "VideoCrafter/VideoCrafter2", filename="model.ckpt"
    )
    pretrained_t2v = load_model_checkpoint(pretrained_t2v, pretrained_path)

    # UNet: override three config entries before instantiating it.
    unet_config = model_config["params"]["unet_config"]
    unet_overrides = (
        ("use_checkpoint", False),
        ("time_cond_proj_dim", 256),
        ("motion_cond_proj_dim", 256),
    )
    for key, value in unet_overrides:
        unet_config["params"][key] = value
    unet = instantiate_from_config(unet_config)

    unet_path = hf_hub_download(
        repo_id="jiachenli-ucsb/T2V-Turbo-v2", filename="unet_mg.pt"
    )
    # NOTE(review): torch.load unpickles a downloaded file; consider
    # ``weights_only=True`` if the checkpoint is a plain state dict.
    unet.load_state_dict(torch.load(unet_path, map_location=device))
    unet.eval()
    # Swap the freshly loaded UNet into the base model.
    pretrained_t2v.model.diffusion_model = unet

    scheduler = T2VTurboScheduler(
        linear_start=model_config["params"]["linear_start"],
        linear_end=model_config["params"]["linear_end"],
    )
    pipeline = T2VTurboVC2Pipeline(pretrained_t2v, scheduler, model_config)
    pipeline.to(device)

    # Mount both Gradio and FastAPI
    # The input components are listed in the positional order of generate()'s
    # first eight parameters.
    ui_inputs = [
        Textbox(label="", placeholder="Please enter your prompt. \n"),
        gr.Slider(
            label="Guidance scale",
            minimum=2,
            maximum=14,
            step=0.1,
            value=7.5,
        ),
        gr.Slider(
            label="Percentage of steps to apply motion guidance",
            minimum=0.0,
            maximum=0.5,
            step=0.05,
            value=0.5,
        ),
        gr.Slider(
            label="Number of inference steps",
            minimum=4,
            maximum=50,
            step=1,
            value=16,
        ),
        gr.Slider(
            label="Number of Video Frames",
            minimum=16,
            maximum=48,
            step=8,
            value=16,
        ),
        gr.Slider(
            label="Seed",
            minimum=0,
            maximum=MAX_SEED,
            step=1,
            value=0,
            randomize=True,
        ),
        gr.Checkbox(label="Randomize seed", value=True),
        gr.Radio(
            ["bf16", "fp16", "fp32"],
            label="torch.dtype",
            value="bf16",
            interactive=True,
        ),
    ]
    # The output components match the tuple returned by generate() when
    # is_api is False: (video path, prompt, model info, seed).
    ui_outputs = [
        gr.Video(
            label="Generated Video",
            width=512,
            height=320,
            interactive=False,
            autoplay=True,
        ),
        Textbox(label="input prompt"),
        Textbox(label="model info"),
        gr.Slider(label="seed"),
    ]
    demo = gr.Interface(
        fn=lambda *args: generate(*args, is_api=False),
        inputs=ui_inputs,
        outputs=ui_outputs,
        #description=DESCRIPTION,
        #theme=gr.themes.Default(),
        #css=block_css,
        #examples=examples,
        #cache_examples=False,
        concurrency_limit=10,
    )
    app = gr.mount_gradio_app(app, demo, path="/")

    # Run both servers
    # (a single uvicorn server hosts the FastAPI app with Gradio mounted at "/")
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)