Robert001 committed on
Commit
6d16592
•
1 Parent(s): eb1085f

first commit

Browse files
Files changed (1) hide show
  1. lib/{ddim.py → ddim_multi.py} +16 -13
lib/{ddim.py → ddim_multi.py} RENAMED
@@ -14,8 +14,7 @@ import torch
14
  import numpy as np
15
  from tqdm import tqdm
16
 
17
- from lib.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, \
18
- extract_into_tensor
19
 
20
 
21
  class DDIMSampler(object):
@@ -33,7 +32,7 @@ class DDIMSampler(object):
33
 
34
  def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
35
  self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
36
- num_ddpm_timesteps=self.ddpm_num_timesteps, verbose=verbose)
37
  alphas_cumprod = self.model.alphas_cumprod
38
  assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
39
  to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
@@ -52,14 +51,14 @@ class DDIMSampler(object):
52
  # ddim sampling parameters
53
  ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
54
  ddim_timesteps=self.ddim_timesteps,
55
- eta=ddim_eta, verbose=verbose)
56
  self.register_buffer('ddim_sigmas', ddim_sigmas)
57
  self.register_buffer('ddim_alphas', ddim_alphas)
58
  self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
59
  self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
60
  sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
61
  (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
62
- 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
63
  self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
64
 
65
  @torch.no_grad()
@@ -83,8 +82,7 @@ class DDIMSampler(object):
83
  x_T=None,
84
  log_every_t=100,
85
  unconditional_guidance_scale=1.,
86
- unconditional_conditioning=None,
87
- # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
88
  dynamic_threshold=None,
89
  ucg_schedule=None,
90
  **kwargs
@@ -153,7 +151,7 @@ class DDIMSampler(object):
153
  timesteps = self.ddim_timesteps[:subset_end]
154
 
155
  intermediates = {'x_inter': [img], 'pred_x0': [img]}
156
- time_range = reversed(range(0, timesteps)) if ddim_use_original_steps else np.flip(timesteps)
157
  total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
158
  print(f"Running DDIM Sampling with {total_steps} timesteps")
159
 
@@ -196,6 +194,8 @@ class DDIMSampler(object):
196
  dynamic_threshold=None):
197
  b, *_, device = *x.shape, x.device
198
 
 
 
199
  if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
200
  model_output = self.model.apply_model(x, t, c)
201
  else:
@@ -205,14 +205,16 @@ class DDIMSampler(object):
205
  assert isinstance(unconditional_conditioning, dict)
206
  c_in = dict()
207
  for k in c:
 
 
208
  if isinstance(c[k], list):
209
  c_in[k] = [torch.cat([
210
  unconditional_conditioning[k][i],
211
  c[k][i]]) for i in range(len(c[k]))]
212
  else:
213
  c_in[k] = torch.cat([
214
- unconditional_conditioning[k],
215
- c[k]])
216
  elif isinstance(c, list):
217
  c_in = list()
218
  assert isinstance(unconditional_conditioning, list)
@@ -220,6 +222,7 @@ class DDIMSampler(object):
220
  c_in.append(torch.cat([unconditional_conditioning[i], c[i]]))
221
  else:
222
  c_in = torch.cat([unconditional_conditioning, c])
 
223
  model_uncond, model_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
224
  model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond)
225
 
@@ -240,7 +243,7 @@ class DDIMSampler(object):
240
  a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
241
  a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
242
  sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
243
- sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device)
244
 
245
  # current prediction for x_0
246
  if self.model.parameterization != "v":
@@ -255,7 +258,7 @@ class DDIMSampler(object):
255
  raise NotImplementedError()
256
 
257
  # direction pointing to x_t
258
- dir_xt = (1. - a_prev - sigma_t ** 2).sqrt() * e_t
259
  noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
260
  if noise_dropout > 0.:
261
  noise = torch.nn.functional.dropout(noise, p=noise_dropout)
@@ -345,4 +348,4 @@ class DDIMSampler(object):
345
  unconditional_guidance_scale=unconditional_guidance_scale,
346
  unconditional_conditioning=unconditional_conditioning)
347
  if callback: callback(i)
348
- return x_dec
 
14
  import numpy as np
15
  from tqdm import tqdm
16
 
17
+ from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor
 
18
 
19
 
20
  class DDIMSampler(object):
 
32
 
33
  def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
34
  self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
35
+ num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
36
  alphas_cumprod = self.model.alphas_cumprod
37
  assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
38
  to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
 
51
  # ddim sampling parameters
52
  ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
53
  ddim_timesteps=self.ddim_timesteps,
54
+ eta=ddim_eta,verbose=verbose)
55
  self.register_buffer('ddim_sigmas', ddim_sigmas)
56
  self.register_buffer('ddim_alphas', ddim_alphas)
57
  self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
58
  self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
59
  sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
60
  (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
61
+ 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
62
  self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
63
 
64
  @torch.no_grad()
 
82
  x_T=None,
83
  log_every_t=100,
84
  unconditional_guidance_scale=1.,
85
+ unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
 
86
  dynamic_threshold=None,
87
  ucg_schedule=None,
88
  **kwargs
 
151
  timesteps = self.ddim_timesteps[:subset_end]
152
 
153
  intermediates = {'x_inter': [img], 'pred_x0': [img]}
154
+ time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
155
  total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
156
  print(f"Running DDIM Sampling with {total_steps} timesteps")
157
 
 
194
  dynamic_threshold=None):
195
  b, *_, device = *x.shape, x.device
196
 
197
+ task_name = c['task']
198
+
199
  if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
200
  model_output = self.model.apply_model(x, t, c)
201
  else:
 
205
  assert isinstance(unconditional_conditioning, dict)
206
  c_in = dict()
207
  for k in c:
208
+ if k == 'task':
209
+ continue
210
  if isinstance(c[k], list):
211
  c_in[k] = [torch.cat([
212
  unconditional_conditioning[k][i],
213
  c[k][i]]) for i in range(len(c[k]))]
214
  else:
215
  c_in[k] = torch.cat([
216
+ unconditional_conditioning[k],
217
+ c[k]])
218
  elif isinstance(c, list):
219
  c_in = list()
220
  assert isinstance(unconditional_conditioning, list)
 
222
  c_in.append(torch.cat([unconditional_conditioning[i], c[i]]))
223
  else:
224
  c_in = torch.cat([unconditional_conditioning, c])
225
+ c_in['task'] = task_name
226
  model_uncond, model_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
227
  model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond)
228
 
 
243
  a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
244
  a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
245
  sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
246
+ sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
247
 
248
  # current prediction for x_0
249
  if self.model.parameterization != "v":
 
258
  raise NotImplementedError()
259
 
260
  # direction pointing to x_t
261
+ dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
262
  noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
263
  if noise_dropout > 0.:
264
  noise = torch.nn.functional.dropout(noise, p=noise_dropout)
 
348
  unconditional_guidance_scale=unconditional_guidance_scale,
349
  unconditional_conditioning=unconditional_conditioning)
350
  if callback: callback(i)
351
+ return x_dec