Ji4chenLi
commited on
Commit
•
e13087b
1
Parent(s):
a6dd76e
update app.py to link space and models
Browse files- app.py +7 -2
- checkpoints/VideoCrafter2_model.ckpt +0 -3
- checkpoints/unet_mg.pt +0 -3
app.py
CHANGED
@@ -10,7 +10,9 @@ import torch
|
|
10 |
import torchvision
|
11 |
import gradio as gr
|
12 |
import numpy as np
|
|
|
13 |
from gradio.components import Textbox, Video
|
|
|
14 |
|
15 |
from utils.common_utils import load_model_checkpoint
|
16 |
from utils.utils import instantiate_from_config
|
@@ -144,7 +146,9 @@ if __name__ == "__main__":
|
|
144 |
config = OmegaConf.load("configs/inference_t2v_512_v2.0.yaml")
|
145 |
model_config = config.pop("model", OmegaConf.create())
|
146 |
pretrained_t2v = instantiate_from_config(model_config)
|
147 |
-
|
|
|
|
|
148 |
|
149 |
unet_config = model_config["params"]["unet_config"]
|
150 |
unet_config["params"]["use_checkpoint"] = False
|
@@ -153,7 +157,8 @@ if __name__ == "__main__":
|
|
153 |
|
154 |
unet = instantiate_from_config(unet_config)
|
155 |
|
156 |
-
|
|
|
157 |
unet.eval()
|
158 |
|
159 |
pretrained_t2v.model.diffusion_model = unet
|
|
|
10 |
import torchvision
|
11 |
import gradio as gr
|
12 |
import numpy as np
|
13 |
+
|
14 |
from gradio.components import Textbox, Video
|
15 |
+
from huggingface_hub import hf_hub_download
|
16 |
|
17 |
from utils.common_utils import load_model_checkpoint
|
18 |
from utils.utils import instantiate_from_config
|
|
|
146 |
config = OmegaConf.load("configs/inference_t2v_512_v2.0.yaml")
|
147 |
model_config = config.pop("model", OmegaConf.create())
|
148 |
pretrained_t2v = instantiate_from_config(model_config)
|
149 |
+
|
150 |
+
pretrained_path = hf_hub_download("VideoCrafter/VideoCrafter2", filename="model.ckpt")
|
151 |
+
pretrained_t2v = load_model_checkpoint(pretrained_t2v, pretrained_path)
|
152 |
|
153 |
unet_config = model_config["params"]["unet_config"]
|
154 |
unet_config["params"]["use_checkpoint"] = False
|
|
|
157 |
|
158 |
unet = instantiate_from_config(unet_config)
|
159 |
|
160 |
+
unet_path = hf_hub_download(repo_id="jiachenli-ucsb/T2V-Turbo-v2", filename="unet_mg.pt")
|
161 |
+
unet.load_state_dict(torch.load(unet_path, map_location=device))
|
162 |
unet.eval()
|
163 |
|
164 |
pretrained_t2v.model.diffusion_model = unet
|
checkpoints/VideoCrafter2_model.ckpt
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:1edf769ece3308e977228943eeeed39286806aba9da17350449a3fbf4324ccfb
|
3 |
-
size 7404653244
|
|
|
|
|
|
|
|
checkpoints/unet_mg.pt
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:92c8767b40a5b2737dd3c69f5f13dae222ead5bd4befbbf894ca870231db13bc
|
3 |
-
size 5655143958
|
|
|
|
|
|
|
|