import os
import uuid
import random

import imageio
import numpy as np
import torch
import torchvision
import gradio as gr
import spaces
from omegaconf import OmegaConf
from fastapi import FastAPI
from fastapi.responses import FileResponse
from gradio.components import Textbox, Video
from huggingface_hub import hf_hub_download

from utils.common_utils import load_model_checkpoint
from utils.utils import instantiate_from_config
from scheduler.t2v_turbo_scheduler import T2VTurboScheduler
from pipeline.t2v_turbo_vc2_pipeline import T2VTurboVC2Pipeline

# Keep all your original constants and DESCRIPTION
MAX_SEED = np.iinfo(np.int32).max

app = FastAPI()


def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    return seed


def save_video(video_array, video_save_path, fps: int = 16):
    video = video_array.detach().cpu()
    video = torch.clamp(video.float(), -1.0, 1.0)
    video = video.permute(1, 0, 2, 3)  # t, c, h, w
    video = (video + 1.0) / 2.0
    video = (video * 255).to(torch.uint8).permute(0, 2, 3, 1)
    torchvision.io.write_video(
        video_save_path, video, fps=fps, video_codec="h264", options={"crf": "10"}
    )


# Keep your original example_txt and examples


@spaces.GPU(duration=120)
@torch.inference_mode()
def generate(
    prompt: str,
    guidance_scale: float = 7.5,
    percentage: float = 0.5,
    num_inference_steps: int = 4,
    num_frames: int = 16,
    seed: int = 0,
    randomize_seed: bool = False,
    param_dtype="bf16",
    motion_gs: float = 0.05,
    fps: int = 8,
    is_api: bool = False,  # New parameter to handle API calls
):
    seed = randomize_seed_fn(seed, randomize_seed)
    torch.manual_seed(seed)

    if param_dtype == "bf16":
        dtype = torch.bfloat16
        unet.dtype = torch.bfloat16
    elif param_dtype == "fp16":
        dtype = torch.float16
        unet.dtype = torch.float16
    elif param_dtype == "fp32":
        dtype = torch.float32
        unet.dtype = torch.float32
    else:
        raise ValueError(f"Unknown dtype: {param_dtype}")

    pipeline.unet.to(device, dtype)
    pipeline.text_encoder.to(device, dtype)
    pipeline.vae.to(device, dtype)
    pipeline.to(device, dtype)

    result = pipeline(
        prompt=prompt,
        frames=num_frames,
        fps=fps,
        guidance_scale=guidance_scale,
        motion_gs=motion_gs,
        use_motion_cond=True,
        percentage=percentage,
        num_inference_steps=num_inference_steps,
        lcm_origin_steps=200,
        num_videos_per_prompt=1,
    )
    torch.cuda.empty_cache()

    # Generate unique filename for API calls
    if is_api:
        video_filename = f"{uuid.uuid4()}.mp4"
    else:
        video_filename = "tmp.mp4"

    root_path = "./videos/"
    os.makedirs(root_path, exist_ok=True)
    video_save_path = os.path.join(root_path, video_filename)
    save_video(result[0], video_save_path, fps=fps)

    display_model_info = (
        f"Video size: {num_frames}x320x512, "
        f"Sampling Step: {num_inference_steps}, "
        f"Guidance Scale: {guidance_scale}"
    )

    if is_api:
        return {
            "video_path": video_save_path,
            "prompt": prompt,
            "model_info": display_model_info,
            "seed": seed,
        }
    return video_save_path, prompt, display_model_info, seed


# API endpoint
@app.post("/generate")
async def generate_api(
    prompt: str,
    guidance_scale: float = 7.5,
    percentage: float = 0.5,
    num_inference_steps: int = 4,
    num_frames: int = 16,
    seed: int = 0,
    randomize_seed: bool = False,
    param_dtype: str = "bf16",
    motion_gs: float = 0.05,
    fps: int = 8,
):
    result = generate(
        prompt=prompt,
        guidance_scale=guidance_scale,
        percentage=percentage,
        num_inference_steps=num_inference_steps,
        num_frames=num_frames,
        seed=seed,
        randomize_seed=randomize_seed,
        param_dtype=param_dtype,
        motion_gs=motion_gs,
        fps=fps,
        is_api=True,
    )
    return FileResponse(
        result["video_path"],
        media_type="video/mp4",
        headers={
            "X-Model-Info": result["model_info"],
            "X-Seed": str(result["seed"]),
        },
    )


if __name__ == "__main__":
    device = torch.device("cuda:0")

    # Keep all your original model initialization code
    config = OmegaConf.load("configs/inference_t2v_512_v2.0.yaml")
    model_config = config.pop("model", OmegaConf.create())
    pretrained_t2v = instantiate_from_config(model_config)
    pretrained_path = hf_hub_download(
        "VideoCrafter/VideoCrafter2", filename="model.ckpt"
    )
    pretrained_t2v = load_model_checkpoint(pretrained_t2v, pretrained_path)

    unet_config = model_config["params"]["unet_config"]
    unet_config["params"]["use_checkpoint"] = False
    unet_config["params"]["time_cond_proj_dim"] = 256
    unet_config["params"]["motion_cond_proj_dim"] = 256

    unet = instantiate_from_config(unet_config)
    unet_path = hf_hub_download(
        repo_id="jiachenli-ucsb/T2V-Turbo-v2", filename="unet_mg.pt"
    )
    unet.load_state_dict(torch.load(unet_path, map_location=device))
    unet.eval()
    pretrained_t2v.model.diffusion_model = unet

    scheduler = T2VTurboScheduler(
        linear_start=model_config["params"]["linear_start"],
        linear_end=model_config["params"]["linear_end"],
    )
    pipeline = T2VTurboVC2Pipeline(pretrained_t2v, scheduler, model_config)
    pipeline.to(device)

    # Mount both Gradio and FastAPI
    demo = gr.Interface(
        fn=lambda *args: generate(*args, is_api=False),
        inputs=[
            Textbox(label="", placeholder="Please enter your prompt. \n"),
            gr.Slider(label="Guidance scale", minimum=2, maximum=14, step=0.1, value=7.5),
            gr.Slider(
                label="Percentage of steps to apply motion guidance",
                minimum=0.0,
                maximum=0.5,
                step=0.05,
                value=0.5,
            ),
            gr.Slider(label="Number of inference steps", minimum=4, maximum=50, step=1, value=16),
            gr.Slider(label="Number of Video Frames", minimum=16, maximum=48, step=8, value=16),
            gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0, randomize=True),
            gr.Checkbox(label="Randomize seed", value=True),
            gr.Radio(["bf16", "fp16", "fp32"], label="torch.dtype", value="bf16", interactive=True),
        ],
        outputs=[
            gr.Video(
                label="Generated Video",
                width=512,
                height=320,
                interactive=False,
                autoplay=True,
            ),
            Textbox(label="input prompt"),
            Textbox(label="model info"),
            gr.Slider(label="seed"),
        ],
        # description=DESCRIPTION,
        # theme=gr.themes.Default(),
        # css=block_css,
        # examples=examples,
        # cache_examples=False,
        concurrency_limit=10,
    )

    app = gr.mount_gradio_app(app, demo, path="/")

    # Run both servers
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)