first commit
lib/{ddim.py → ddim_multi.py}  +16 -13
RENAMED
@@ -14,8 +14,7 @@ import torch
 import numpy as np
 from tqdm import tqdm
 
-from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, \
-    extract_into_tensor
+from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor
 
 
 class DDIMSampler(object):
@@ -33,7 +32,7 @@ class DDIMSampler(object):
 
     def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
         self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
-                                                  num_ddpm_timesteps=self.ddpm_num_timesteps, verbose=verbose)
+                                                  num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
         alphas_cumprod = self.model.alphas_cumprod
         assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
         to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
@@ -52,14 +51,14 @@ class DDIMSampler(object):
         # ddim sampling parameters
         ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
                                                                                    ddim_timesteps=self.ddim_timesteps,
-                                                                                   eta=ddim_eta, verbose=verbose)
+                                                                                   eta=ddim_eta,verbose=verbose)
         self.register_buffer('ddim_sigmas', ddim_sigmas)
         self.register_buffer('ddim_alphas', ddim_alphas)
         self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
         self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
         sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
             (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
-                    1 - self.alphas_cumprod / self.alphas_cumprod_prev))
+                        1 - self.alphas_cumprod / self.alphas_cumprod_prev))
         self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
 
     @torch.no_grad()
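Note: the hunk above only reflows the call sites; the quantity being computed is the DDIM noise scale of Song et al. (2021), Eq. (16),

    sigma_t = eta * sqrt((1 - abar_{t-1}) / (1 - abar_t)) * sqrt(1 - abar_t / abar_{t-1})

where abar_t denotes alphas_cumprod[t]. With eta = 0 the sampler is the deterministic DDIM; with eta = 1 it matches ancestral DDPM sampling.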
@@ -83,8 +82,7 @@ class DDIMSampler(object):
                x_T=None,
                log_every_t=100,
                unconditional_guidance_scale=1.,
-               unconditional_conditioning=None,
-               # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
+               unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
                dynamic_threshold=None,
                ucg_schedule=None,
                **kwargs
@@ -153,7 +151,7 @@ class DDIMSampler(object):
             timesteps = self.ddim_timesteps[:subset_end]
 
         intermediates = {'x_inter': [img], 'pred_x0': [img]}
-        time_range = reversed(range(0, timesteps)) if ddim_use_original_steps else np.flip(timesteps)
+        time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
         total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
         print(f"Running DDIM Sampling with {total_steps} timesteps")
 
@@ -196,6 +194,8 @@ class DDIMSampler(object):
                       dynamic_threshold=None):
         b, *_, device = *x.shape, x.device
 
+        task_name = c['task']
+
         if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
             model_output = self.model.apply_model(x, t, c)
         else:
@@ -205,14 +205,16 @@ class DDIMSampler(object):
                 assert isinstance(unconditional_conditioning, dict)
                 c_in = dict()
                 for k in c:
+                    if k == 'task':
+                        continue
                     if isinstance(c[k], list):
                         c_in[k] = [torch.cat([
                             unconditional_conditioning[k][i],
                             c[k][i]]) for i in range(len(c[k]))]
                     else:
                         c_in[k] = torch.cat([
-                            unconditional_conditioning[k],
-                            c[k]])
+                                unconditional_conditioning[k],
+                                c[k]])
             elif isinstance(c, list):
                 c_in = list()
                 assert isinstance(unconditional_conditioning, list)
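The two added lines are the core of the multi-task change: `task` is a plain label (presumably a string) used to route the request inside the model, so it cannot be batched with `torch.cat` like the tensor entries; the loop skips it here and re-attaches it after batching (next hunk). A minimal standalone sketch of the same batching logic — the function name and the `meta_keys` parameter are illustrative, not part of the commit:

    import torch

    def build_cfg_batch(c, uc, meta_keys=("task",)):
        # Concatenate unconditional and conditional dicts along the batch
        # axis for classifier-free guidance; metadata keys are passed
        # through untouched because torch.cat rejects non-tensors.
        c_in = {}
        for k in c:
            if k in meta_keys:
                c_in[k] = c[k]  # e.g. 'task' (hypothetical label, passed through)
            elif isinstance(c[k], list):
                c_in[k] = [torch.cat([uc[k][i], c[k][i]]) for i in range(len(c[k]))]
            else:
                c_in[k] = torch.cat([uc[k], c[k]])
        return c_in

Note that the re-attachment in the next hunk assumes dict-style conditioning; the list and plain-tensor branches would not accept a string key.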
@@ -220,6 +222,7 @@ class DDIMSampler(object):
                     c_in.append(torch.cat([unconditional_conditioning[i], c[i]]))
             else:
                 c_in = torch.cat([unconditional_conditioning, c])
+            c_in['task'] = task_name
             model_uncond, model_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
             model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond)
 
@@ -240,7 +243,7 @@ class DDIMSampler(object):
         a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
         a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
         sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
-        sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device)
+        sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
 
         # current prediction for x_0
         if self.model.parameterization != "v":
@@ -255,7 +258,7 @@ class DDIMSampler(object):
             raise NotImplementedError()
 
         # direction pointing to x_t
-        dir_xt = (1. - a_prev - sigma_t ** 2).sqrt() * e_t
+        dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
         noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
         if noise_dropout > 0.:
             noise = torch.nn.functional.dropout(noise, p=noise_dropout)
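For orientation, `dir_xt` is the deterministic drift term of the DDIM update (Song et al., 2021, Eq. (12)); the unchanged code that follows combines it with the x0 prediction and the stochastic term as

    x_prev = sqrt(a_prev) * pred_x0 + sqrt(1 - a_prev - sigma_t**2) * e_t + sigma_t * z,   z ~ N(0, I)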
@@ -345,4 +348,4 @@ class DDIMSampler(object):
                                           unconditional_guidance_scale=unconditional_guidance_scale,
                                           unconditional_conditioning=unconditional_conditioning)
             if callback: callback(i)
-        return x_dec
+        return x_dec
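Taken together, the commit renames the stock CompVis DDIM sampler to ddim_multi.py and threads a task label through classifier-free guidance. A hypothetical invocation, following the standard DDIMSampler.sample API — the model object, the token tensors, the conditioning key `c_crossattn`, and the task name "denoise" are assumptions for illustration, not taken from this repo:

    from lib.ddim_multi import DDIMSampler

    sampler = DDIMSampler(model)  # model: a LatentDiffusion-style wrapper (assumed available)

    c  = {"c_crossattn": [cond_tokens],   "task": "denoise"}   # cond_tokens: assumed encoder output
    uc = {"c_crossattn": [uncond_tokens]}                      # no 'task' key: it is skipped, then re-attached

    samples, intermediates = sampler.sample(
        S=50, batch_size=4, shape=(4, 64, 64),                 # shape = (C, H, W) of the latent
        conditioning=c,
        unconditional_guidance_scale=7.5,
        unconditional_conditioning=uc,
        eta=0.0)                                               # eta=0: deterministic DDIM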