multimodalart HF staff commited on
Commit
e2cb170
1 Parent(s): 4ebfd89

CogVideoXDPMScheduler and 24 steps

Browse files
Files changed (1) hide show
  1. app.py +4 -3
app.py CHANGED
@@ -6,7 +6,7 @@ import time
6
  import gradio as gr
7
  import numpy as np
8
  import torch
9
- from diffusers import CogVideoXPipeline
10
  from datetime import datetime, timedelta
11
  from openai import OpenAI
12
  import spaces
@@ -18,6 +18,7 @@ import PIL
18
  dtype = torch.float16
19
  device = "cuda" if torch.cuda.is_available() else "cpu"
20
  pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=dtype).to(device)
 
21
 
22
  sys_prompt = """You are part of a team of bots that creates videos. You work with an assistant bot that will draw anything you say in square brackets.
23
 
@@ -88,7 +89,7 @@ def convert_prompt(prompt: str, retry_times: int = 3) -> str:
88
  return prompt
89
 
90
 
91
- @spaces.GPU()
92
  def infer(
93
  prompt: str,
94
  num_inference_steps: int,
@@ -174,7 +175,7 @@ with gr.Blocks() as demo:
174
  "Turn Inference Steps larger if you want more detailed video, but it will be slower.<br>"
175
  "50 steps are recommended for most cases. will cause 120 seconds for inference.<br>")
176
  with gr.Row():
177
- num_inference_steps = gr.Number(label="Inference Steps", value=50)
178
  guidance_scale = gr.Number(label="Guidance Scale", value=6.0)
179
  generate_button = gr.Button("🎬 Generate Video")
180
 
 
6
  import gradio as gr
7
  import numpy as np
8
  import torch
9
+ from diffusers import CogVideoXPipeline, CogVideoXDPMScheduler
10
  from datetime import datetime, timedelta
11
  from openai import OpenAI
12
  import spaces
 
18
  dtype = torch.float16
19
  device = "cuda" if torch.cuda.is_available() else "cpu"
20
  pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=dtype).to(device)
21
+ pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config)
22
 
23
  sys_prompt = """You are part of a team of bots that creates videos. You work with an assistant bot that will draw anything you say in square brackets.
24
 
 
89
  return prompt
90
 
91
 
92
+ @spaces.GPU(duration=120)
93
  def infer(
94
  prompt: str,
95
  num_inference_steps: int,
 
175
  "Turn Inference Steps larger if you want more detailed video, but it will be slower.<br>"
176
  "50 steps are recommended for most cases. will cause 120 seconds for inference.<br>")
177
  with gr.Row():
178
+ num_inference_steps = gr.Slider(label="Inference Steps", value=24, minimum=1, maximum=24)
179
  guidance_scale = gr.Number(label="Guidance Scale", value=6.0)
180
  generate_button = gr.Button("🎬 Generate Video")
181