teticio commited on
Commit
1ef9d1c
1 Parent(s): 021deca

handle steps correctly

Browse files
audiodiffusion/__init__.py CHANGED
@@ -177,11 +177,8 @@ class AudioDiffusionPipeline(DiffusionPipeline):
177
  (float, List[np.ndarray]): sample rate and raw audios
178
  """
179
 
180
- if steps is None:
181
- steps = self.scheduler.num_train_timesteps
182
- # Unfortunately, the schedule is set up in the constructor
183
- scheduler = self.scheduler.__class__(num_train_timesteps=steps)
184
- scheduler.set_timesteps(steps)
185
  mask = None
186
  images = noise = torch.randn(
187
  (batch_size, self.unet.in_channels, self.unet.sample_size,
@@ -204,7 +201,7 @@ class AudioDiffusionPipeline(DiffusionPipeline):
204
  input_images = 0.18215 * input_images
205
 
206
  if start_step > 0:
207
- images[0, 0] = scheduler.add_noise(
208
  torch.tensor(input_images[:, np.newaxis, np.newaxis, :]),
209
  noise, torch.tensor(steps - start_step))
210
 
@@ -213,18 +210,18 @@ class AudioDiffusionPipeline(DiffusionPipeline):
213
  mel.x_res)
214
  mask_start = int(mask_start_secs * pixels_per_second)
215
  mask_end = int(mask_end_secs * pixels_per_second)
216
- mask = scheduler.add_noise(
217
  torch.tensor(input_images[:, np.newaxis, :]), noise,
218
- torch.tensor(scheduler.timesteps[start_step:]))
219
 
220
  images = images.to(self.device)
221
  for step, t in enumerate(
222
- self.progress_bar(scheduler.timesteps[start_step:])):
223
  model_output = self.unet(images, t)['sample']
224
- images = scheduler.step(model_output,
225
- t,
226
- images,
227
- generator=generator)['prev_sample']
228
 
229
  if mask is not None:
230
  if mask_start > 0:
 
177
  (float, List[np.ndarray]): sample rate and raw audios
178
  """
179
 
180
+ if steps is not None:
181
+ self.scheduler.set_timesteps(steps)
 
 
 
182
  mask = None
183
  images = noise = torch.randn(
184
  (batch_size, self.unet.in_channels, self.unet.sample_size,
 
201
  input_images = 0.18215 * input_images
202
 
203
  if start_step > 0:
204
+ images[0, 0] = self.scheduler.add_noise(
205
  torch.tensor(input_images[:, np.newaxis, np.newaxis, :]),
206
  noise, torch.tensor(steps - start_step))
207
 
 
210
  mel.x_res)
211
  mask_start = int(mask_start_secs * pixels_per_second)
212
  mask_end = int(mask_end_secs * pixels_per_second)
213
+ mask = self.scheduler.add_noise(
214
  torch.tensor(input_images[:, np.newaxis, :]), noise,
215
+ torch.tensor(self.scheduler.timesteps[start_step:]))
216
 
217
  images = images.to(self.device)
218
  for step, t in enumerate(
219
+ self.progress_bar(self.scheduler.timesteps[start_step:])):
220
  model_output = self.unet(images, t)['sample']
221
+ images = self.scheduler.step(model_output,
222
+ t,
223
+ images,
224
+ generator=generator)['prev_sample']
225
 
226
  if mask is not None:
227
  if mask_start > 0:
scripts/train_unconditional.py CHANGED
@@ -274,7 +274,6 @@ def main(args):
274
  mel=mel,
275
  generator=generator,
276
  batch_size=args.eval_batch_size,
277
- steps=args.num_train_steps,
278
  )
279
 
280
  # denormalize the images and save to tensorboard
 
274
  mel=mel,
275
  generator=generator,
276
  batch_size=args.eval_batch_size,
 
277
  )
278
 
279
  # denormalize the images and save to tensorboard