tree3po committed
Commit 64e9db8
1 Parent(s): e1993f3

Update open_oasis_master/generate.py

Files changed (1)
  open_oasis_master/generate.py +108 -107
open_oasis_master/generate.py CHANGED
@@ -13,110 +13,111 @@ from torch import autocast
 import os
 #assert torch.cuda.is_available()
 #device = "cuda:0"
-device = "cpu"
-
-# load DiT checkpoint
-ckpt = torch.load("oasis500m.pt", map_location=torch.device('cpu'))
-model = DiT_models["DiT-S/2"]()
-model.load_state_dict(ckpt, strict=False)
-model = model.to(device).eval()
-
-# load VAE checkpoint
-vae_ckpt = torch.load("vit-l-20.pt", map_location=torch.device('cpu'))
-vae = VAE_models["vit-l-20-shallow-encoder"]()
-vae.load_state_dict(vae_ckpt)
-vae = vae.to(device).eval()
-
-# sampling params
-B = 1
-total_frames = 32
-max_noise_level = 1000
-ddim_noise_steps = 100
-noise_range = torch.linspace(-1, max_noise_level - 1, ddim_noise_steps + 1)
-noise_abs_max = 20
-ctx_max_noise_idx = ddim_noise_steps // 10 * 3
-
-# get input video
-print(os.getcwd())
-video_id = "snippy-chartreuse-mastiff-f79998db196d-20220401-224517.chunk_001"
-mp4_path = f"{os.getcwd()}/open_oasis_master/sample_data/{video_id}.mp4"
-actions_path = f"{os.getcwd()}/open_oasis_master/sample_data/{video_id}.actions.pt"
-video = read_video(mp4_path, pts_unit="sec")[0].float() / 255
-actions = one_hot_actions(torch.load(actions_path, map_location=torch.device('cpu')))
-offset = 100
-video = video[offset:offset+total_frames].unsqueeze(0)
-actions = actions[offset:offset+total_frames].unsqueeze(0)
-
-# sampling inputs
-n_prompt_frames = 1
-x = video[:, :n_prompt_frames]
-x = x.to(device)
-actions = actions.to(device)
-
-# vae encoding
-scaling_factor = 0.07843137255
-x = rearrange(x, "b t h w c -> (b t) c h w")
-H, W = x.shape[-2:]
-with torch.no_grad():
-    x = vae.encode(x * 2 - 1).mean * scaling_factor
-x = rearrange(x, "(b t) (h w) c -> b t c h w", t=n_prompt_frames, h=H//vae.patch_size, w=W//vae.patch_size)
-
-# get alphas
-betas = sigmoid_beta_schedule(max_noise_level).to(device)
-alphas = 1.0 - betas
-alphas_cumprod = torch.cumprod(alphas, dim=0)
-alphas_cumprod = rearrange(alphas_cumprod, "T -> T 1 1 1")
-
-# sampling loop
-for i in tqdm(range(n_prompt_frames, total_frames)):
-    chunk = torch.randn((B, 1, *x.shape[-3:]), device=device)
-    chunk = torch.clamp(chunk, -noise_abs_max, +noise_abs_max)
-    x = torch.cat([x, chunk], dim=1)
-    start_frame = max(0, i + 1 - model.max_frames)
-
-    for noise_idx in reversed(range(1, ddim_noise_steps + 1)):
-        # set up noise values
-        ctx_noise_idx = min(noise_idx, ctx_max_noise_idx)
-        t_ctx = torch.full((B, i), noise_range[ctx_noise_idx], dtype=torch.long, device=device)
-        t = torch.full((B, 1), noise_range[noise_idx], dtype=torch.long, device=device)
-        t_next = torch.full((B, 1), noise_range[noise_idx - 1], dtype=torch.long, device=device)
-        t_next = torch.where(t_next < 0, t, t_next)
-        t = torch.cat([t_ctx, t], dim=1)
-        t_next = torch.cat([t_ctx, t_next], dim=1)
-
-        # sliding window
-        x_curr = x.clone()
-        x_curr = x_curr[:, start_frame:]
-        t = t[:, start_frame:]
-        t_next = t_next[:, start_frame:]
-
-        # add some noise to the context
-        ctx_noise = torch.randn_like(x_curr[:, :-1])
-        ctx_noise = torch.clamp(ctx_noise, -noise_abs_max, +noise_abs_max)
-        x_curr[:, :-1] = alphas_cumprod[t[:, :-1]].sqrt() * x_curr[:, :-1] + (1 - alphas_cumprod[t[:, :-1]]).sqrt() * ctx_noise
-
-        # get model predictions
-        with torch.no_grad():
-            with autocast("cpu", dtype=torch.half):
-                v = model(x_curr, t, actions[:, start_frame : i + 1])
-
-        x_start = alphas_cumprod[t].sqrt() * x_curr - (1 - alphas_cumprod[t]).sqrt() * v
-        x_noise = ((1 / alphas_cumprod[t]).sqrt() * x_curr - x_start) \
-            / (1 / alphas_cumprod[t] - 1).sqrt()
-
-        # get frame prediction
-        x_pred = alphas_cumprod[t_next].sqrt() * x_start + x_noise * (1 - alphas_cumprod[t_next]).sqrt()
-        x[:, -1:] = x_pred[:, -1:]
-
-# vae decoding
-x = rearrange(x, "b t c h w -> (b t) (h w) c")
-with torch.no_grad():
-    x = (vae.decode(x / scaling_factor) + 1) / 2
-x = rearrange(x, "(b t) c h w -> b t h w c", t=total_frames)
-
-# save video
-x = torch.clamp(x, 0, 1)
-x = (x * 255).byte()
-write_video("video.mp4", x[0], fps=20)
-print("generation saved to video.mp4.")
-
+def run_mod():
+    device = "cpu"
+
+    # load DiT checkpoint
+    ckpt = torch.load("oasis500m.pt", map_location=torch.device('cpu'))
+    model = DiT_models["DiT-S/2"]()
+    model.load_state_dict(ckpt, strict=False)
+    model = model.to(device).eval()
+
+    # load VAE checkpoint
+    vae_ckpt = torch.load("vit-l-20.pt", map_location=torch.device('cpu'))
+    vae = VAE_models["vit-l-20-shallow-encoder"]()
+    vae.load_state_dict(vae_ckpt)
+    vae = vae.to(device).eval()
+
+    # sampling params
+    B = 1
+    total_frames = 32
+    max_noise_level = 1000
+    ddim_noise_steps = 100
+    noise_range = torch.linspace(-1, max_noise_level - 1, ddim_noise_steps + 1)
+    noise_abs_max = 20
+    ctx_max_noise_idx = ddim_noise_steps // 10 * 3
+
+    # get input video
+    print(os.getcwd())
+    video_id = "snippy-chartreuse-mastiff-f79998db196d-20220401-224517.chunk_001"
+    mp4_path = f"{os.getcwd()}/open_oasis_master/sample_data/{video_id}.mp4"
+    actions_path = f"{os.getcwd()}/open_oasis_master/sample_data/{video_id}.actions.pt"
+    video = read_video(mp4_path, pts_unit="sec")[0].float() / 255
+    actions = one_hot_actions(torch.load(actions_path, map_location=torch.device('cpu')))
+    offset = 100
+    video = video[offset:offset+total_frames].unsqueeze(0)
+    actions = actions[offset:offset+total_frames].unsqueeze(0)
+
+    # sampling inputs
+    n_prompt_frames = 1
+    x = video[:, :n_prompt_frames]
+    x = x.to(device)
+    actions = actions.to(device)
+
+    # vae encoding
+    scaling_factor = 0.07843137255
+    x = rearrange(x, "b t h w c -> (b t) c h w")
+    H, W = x.shape[-2:]
+    with torch.no_grad():
+        x = vae.encode(x * 2 - 1).mean * scaling_factor
+    x = rearrange(x, "(b t) (h w) c -> b t c h w", t=n_prompt_frames, h=H//vae.patch_size, w=W//vae.patch_size)
+
+    # get alphas
+    betas = sigmoid_beta_schedule(max_noise_level).to(device)
+    alphas = 1.0 - betas
+    alphas_cumprod = torch.cumprod(alphas, dim=0)
+    alphas_cumprod = rearrange(alphas_cumprod, "T -> T 1 1 1")
+
+    # sampling loop
+    for i in tqdm(range(n_prompt_frames, total_frames)):
+        chunk = torch.randn((B, 1, *x.shape[-3:]), device=device)
+        chunk = torch.clamp(chunk, -noise_abs_max, +noise_abs_max)
+        x = torch.cat([x, chunk], dim=1)
+        start_frame = max(0, i + 1 - model.max_frames)
+
+        for noise_idx in reversed(range(1, ddim_noise_steps + 1)):
+            # set up noise values
+            ctx_noise_idx = min(noise_idx, ctx_max_noise_idx)
+            t_ctx = torch.full((B, i), noise_range[ctx_noise_idx], dtype=torch.long, device=device)
+            t = torch.full((B, 1), noise_range[noise_idx], dtype=torch.long, device=device)
+            t_next = torch.full((B, 1), noise_range[noise_idx - 1], dtype=torch.long, device=device)
+            t_next = torch.where(t_next < 0, t, t_next)
+            t = torch.cat([t_ctx, t], dim=1)
+            t_next = torch.cat([t_ctx, t_next], dim=1)
+
+            # sliding window
+            x_curr = x.clone()
+            x_curr = x_curr[:, start_frame:]
+            t = t[:, start_frame:]
+            t_next = t_next[:, start_frame:]
+
+            # add some noise to the context
+            ctx_noise = torch.randn_like(x_curr[:, :-1])
+            ctx_noise = torch.clamp(ctx_noise, -noise_abs_max, +noise_abs_max)
+            x_curr[:, :-1] = alphas_cumprod[t[:, :-1]].sqrt() * x_curr[:, :-1] + (1 - alphas_cumprod[t[:, :-1]]).sqrt() * ctx_noise
+
+            # get model predictions
+            with torch.no_grad():
+                with autocast("cpu", dtype=torch.half):
+                    v = model(x_curr, t, actions[:, start_frame : i + 1])
+
+            x_start = alphas_cumprod[t].sqrt() * x_curr - (1 - alphas_cumprod[t]).sqrt() * v
+            x_noise = ((1 / alphas_cumprod[t]).sqrt() * x_curr - x_start) \
+                / (1 / alphas_cumprod[t] - 1).sqrt()
+
+            # get frame prediction
+            x_pred = alphas_cumprod[t_next].sqrt() * x_start + x_noise * (1 - alphas_cumprod[t_next]).sqrt()
+            x[:, -1:] = x_pred[:, -1:]
+
+    # vae decoding
+    x = rearrange(x, "b t c h w -> (b t) (h w) c")
+    with torch.no_grad():
+        x = (vae.decode(x / scaling_factor) + 1) / 2
+    x = rearrange(x, "(b t) c h w -> b t h w c", t=total_frames)
+
+    # save video
+    x = torch.clamp(x, 0, 1)
+    x = (x * 255).byte()
+    write_video("video.mp4", x[0], fps=20)
+    print("generation saved to video.mp4.")
+    return "video.mp4"
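
With the whole pipeline wrapped in run_mod(), the generation code no longer executes at import time; it runs only when called and returns the path of the rendered clip, so it can serve as an app entry point. A minimal sketch of such a caller, assuming a Gradio app; the gradio wiring and the app file are illustrative assumptions, only run_mod() comes from this commit:

    # Hypothetical app.py; only run_mod() is defined by this commit.
    import gradio as gr

    from open_oasis_master.generate import run_mod

    # run_mod() takes no arguments and returns "video.mp4", the path of the
    # rendered clip, so its return value can feed a video component directly.
    demo = gr.Interface(fn=run_mod, inputs=None, outputs=gr.Video())

    if __name__ == "__main__":
        demo.launch()

Note that run_mod() still hard-codes the checkpoint paths, the sample clip, and the CPU device, so each call runs the full 100-step DDIM loop for each of the 31 generated frames before a video is returned.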