LokasNori commited on
Commit
898c6fa
1 Parent(s): 8e65456

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +62 -0
app.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torch.distributed as dist
4
+ import torch.multiprocessing as mp
5
+ import time
6
+ import json
7
+ import os
8
+
9
+ from src.video_crafter import VideoCrafterPipeline
10
+ from src.tools import DistController
11
+ from src.video_infinity.wrapper import DistWrapper
12
+
13
+ def init_pipeline(config):
14
+ pipe = VideoCrafterPipeline.from_pretrained(
15
+ 'adamdad/videocrafterv2_diffusers',
16
+ torch_dtype=torch.float16
17
+ )
18
+ pipe.enable_model_cpu_offload(
19
+ gpu_id=config["devices"][dist.get_rank() % len(config["devices"])],
20
+ )
21
+ pipe.enable_vae_slicing()
22
+ return pipe
23
+
24
+ def run_inference(prompt, config):
25
+ dist_controller = DistController(0, 1, config)
26
+ pipe = init_pipeline(config)
27
+ dist_pipe = DistWrapper(pipe, dist_controller, config)
28
+ pipe_configs = config['pipe_configs']
29
+ plugin_configs = config['plugin_configs']
30
+
31
+ start = time.time()
32
+ video_path = dist_pipe.inference(
33
+ prompt,
34
+ config,
35
+ pipe_configs,
36
+ plugin_configs,
37
+ additional_info={
38
+ "full_config": config,
39
+ }
40
+ )
41
+ print(f"Inference finished. Time: {time.time() - start}")
42
+ return video_path
43
+
44
+ def demo(input_text):
45
+ base_path = "./results"
46
+
47
+ if not os.path.exists(base_path):
48
+ os.makedirs(base_path)
49
+
50
+ config = {
51
+ "devices": [0], # Укажите индексы ваших GPU, например [0] для одной GPU или [0, 1] для двух
52
+ "base_path": base_path, # Указываем путь, где будут сохраняться видео
53
+ "pipe_configs": {
54
+ "prompts": [input_text]
55
+ },
56
+ "plugin_configs": {}
57
+ }
58
+ video_path = run_inference(input_text, config)
59
+ return video_path
60
+
61
+ iface = gr.Interface(fn=demo, inputs="text", outputs="video")
62
+ iface.launch()