teticio committed
Commit: 327bccf
Parent(s): e3b5a6d

move resolution specification to dataset generation

README.md CHANGED
@@ -57,7 +57,7 @@ pip install .
 
 ```bash
 python scripts/audio_to_images.py \
-  --resolution 64 \
+  --resolution 64,64 \
   --hop_length 1024 \
   --input_dir path-to-audio-files \
   --output_dir path-to-output-data
@@ -78,7 +78,6 @@ python scripts/audio_to_images.py \
 accelerate launch --config_file config/accelerate_local.yaml \
   scripts/train_unconditional.py \
   --dataset_name data/audio-diffusion-64 \
-  --resolution 64 \
   --hop_length 1024 \
   --output_dir models/ddpm-ema-audio-64 \
   --train_batch_size 16 \
@@ -94,7 +93,6 @@ accelerate launch --config_file config/accelerate_local.yaml \
 accelerate launch --config_file config/accelerate_local.yaml \
   scripts/train_unconditional.py \
   --dataset_name teticio/audio-diffusion-256 \
-  --resolution 256 \
   --output_dir models/audio-diffusion-256 \
   --num_epochs 100 \
   --train_batch_size 2 \
@@ -113,7 +111,6 @@ accelerate launch --config_file config/accelerate_local.yaml \
 accelerate launch --config_file config/accelerate_sagemaker.yaml \
   scripts/train_unconditional.py \
   --dataset_name teticio/audio-diffusion-256 \
-  --resolution 256 \
   --output_dir models/ddpm-ema-audio-256 \
   --train_batch_size 16 \
   --num_epochs 100 \
@@ -147,5 +144,4 @@ python scripts/train_vae.py \
 accelerate launch ...
   ...
   --vae models/autoencoder-kl
-  --latent_resoultion 32
 ```
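The training commands above no longer take `--resolution`: the spectrogram size is fixed once when the dataset is generated and read back from the data at training time. A minimal sketch (not part of the commit; the path and the `train` split are taken from the README and the training script) for checking the stored size before launching training:

```python
# Hypothetical check: inspect the spectrogram size stored in a generated
# dataset, since train_unconditional.py now infers the resolution from the
# data rather than from a --resolution flag.
from datasets import load_from_disk

dataset = load_from_disk("data/audio-diffusion-64")["train"]
image = dataset[0]["image"]           # PIL image of one Mel spectrogram slice
print(image.width, image.height)      # e.g. 64 64 when generated with --resolution 64,64
```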
audiodiffusion/__init__.py CHANGED
@@ -180,10 +180,12 @@ class AudioDiffusionPipeline(DiffusionPipeline):
         if steps is not None:
             self.scheduler.set_timesteps(steps)
         mask = None
-        images = noise = torch.randn(
-            (batch_size, self.unet.in_channels, self.unet.sample_size[0],
-             self.unet.sample_size[1]),
-            generator=generator)
+        # For backwards compatibility
+        sample_size = (self.unet.sample_size, self.unet.sample_size) if type(
+            self.unet.sample_size) == int else self.unet.sample_size
+        images = noise = torch.randn((batch_size, self.unet.in_channels) +
+                                     sample_size,
+                                     generator=generator)
 
         if audio_file is not None or raw_audio is not None:
             mel.load_audio(audio_file, raw_audio)
@@ -205,9 +207,8 @@ class AudioDiffusionPipeline(DiffusionPipeline):
                     torch.tensor(input_images[:, np.newaxis, np.newaxis, :]),
                     noise, torch.tensor(steps - start_step))
 
-            pixels_per_second = (mel.get_sample_rate() *
-                                 mel.y_res / mel.hop_length /
-                                 mel.x_res)
+            pixels_per_second = (mel.get_sample_rate() * sample_size[1] /
+                                 mel.hop_length / mel.x_res)
             mask_start = int(mask_start_secs * pixels_per_second)
             mask_end = int(mask_end_secs * pixels_per_second)
             mask = self.scheduler.add_noise(
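The pipeline change accounts for `UNet2DModel` checkpoints whose `sample_size` is a plain int (older, square models) as well as the new `(height, width)` tuples. A minimal standalone sketch of the same normalization; the helper name `noise_shape` is illustrative, not part of the pipeline:

```python
# Sketch of the backwards-compatibility logic: normalize sample_size to a
# (height, width) tuple whether the UNet was saved with an int or a tuple,
# then build the noise tensor shape from it.
import torch

def noise_shape(batch_size, in_channels, sample_size):
    if isinstance(sample_size, int):
        sample_size = (sample_size, sample_size)
    return (batch_size, in_channels) + tuple(sample_size)

noise = torch.randn(noise_shape(1, 1, 64))          # square (older) checkpoint
noise = torch.randn(noise_shape(1, 1, (64, 128)))   # (height, width) checkpoint
```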
audiodiffusion/mel.py CHANGED
@@ -106,7 +106,7 @@ class Mel:
         log_S = librosa.power_to_db(S, ref=np.max, top_db=self.top_db)
         bytedata = (((log_S + self.top_db) * 255 / self.top_db).clip(0, 255) +
                     0.5).astype(np.uint8)
-        image = Image.frombytes("L", log_S.shape, bytedata.tobytes())
+        image = Image.fromarray(bytedata)
         return image
 
     def image_to_audio(self, image: Image.Image) -> np.ndarray:
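The `Image.fromarray` change matters once spectrograms can be non-square: PIL sizes are `(width, height)` while NumPy shapes are `(height, width)`, so passing `log_S.shape` straight to `Image.frombytes` swapped the axes. A small standalone illustration (hypothetical data, not from the repo):

```python
# Illustration of the bug fixed above: frombytes takes size=(width, height),
# so feeding it an array's (rows, cols) shape transposes non-square images.
import numpy as np
from PIL import Image

bytedata = np.zeros((64, 128), dtype=np.uint8)   # 64 rows (height) x 128 cols (width)

wrong = Image.frombytes("L", bytedata.shape, bytedata.tobytes())
right = Image.fromarray(bytedata)

print(wrong.size)   # (64, 128): width 64, height 128 -> transposed
print(right.size)   # (128, 64): width 128, height 64 -> as intended
```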
scripts/audio_to_images.py CHANGED
@@ -16,9 +16,9 @@ logger = logging.getLogger('audio_to_images')
 
 
 def main(args):
-    mel = Mel(x_res=args.resolution,
-              y_res=args.resolution,
-              hop_length=args.hop_length)
+    mel = Mel(x_res=args.resolution[0],
+              y_res=args.resolution[1],
+              hop_length=args.hop_length)
     os.makedirs(args.output_dir, exist_ok=True)
     audio_files = [
         os.path.join(root, file) for root, _, files in os.walk(args.input_dir)
@@ -35,8 +35,8 @@ def main(args):
                 continue
             for slice in range(mel.get_number_of_slices()):
                 image = mel.audio_slice_to_image(slice)
-                assert (image.width == args.resolution
-                        and image.height == args.resolution)
+                assert (image.width == args.resolution[0] and image.height
+                        == args.resolution[1]), "Wrong resolution"
                 # skip completely silent slices
                 if all(np.frombuffer(image.tobytes(), dtype=np.uint8) == 255):
                     logger.warn('File %s slice %d is completely silent',
@@ -52,6 +52,8 @@ def main(args):
                     "audio_file": audio_file,
                     "slice": slice,
                 }])
+    except Exception as e:
+        print(e)
     finally:
         if len(examples) == 0:
             logger.warn('No valid audio files were found.')
@@ -76,12 +78,30 @@ if __name__ == "__main__":
         "Create dataset of Mel spectrograms from directory of audio files.")
     parser.add_argument("--input_dir", type=str)
     parser.add_argument("--output_dir", type=str, default="data")
-    parser.add_argument("--resolution", type=int, default=256)
+    parser.add_argument("--resolution",
+                        type=str,
+                        default="256",
+                        help="Either square resolution or width,height.")
    parser.add_argument("--hop_length", type=int, default=512)
     parser.add_argument("--push_to_hub", type=str, default=None)
     args = parser.parse_args()
+
     if args.input_dir is None:
         raise ValueError(
-            "You must specify an input directory for the audio files."
-        )
+            "You must specify an input directory for the audio files.")
+
+    # Handle the resolutions.
+    try:
+        args.resolution = (int(args.resolution), int(args.resolution))
+    except ValueError:
+        try:
+            args.resolution = tuple(int(x) for x in args.resolution.split(","))
+            if len(args.resolution) != 2:
+                raise ValueError
+        except ValueError:
+            raise ValueError(
+                "Resolution must be a tuple of two integers or a single integer."
+            )
+    assert isinstance(args.resolution, tuple)
+
     main(args)
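The new `--resolution` flag accepts either a single integer (square spectrograms) or `width,height`. A standalone sketch of the parsing behaviour; the helper name `parse_resolution` is illustrative and does not exist in the script:

```python
# Sketch of the --resolution parsing added above: "256" -> (256, 256),
# "64,128" -> (64, 128), anything else raises ValueError.
def parse_resolution(value: str) -> tuple:
    try:
        return (int(value), int(value))
    except ValueError:
        parts = tuple(int(x) for x in value.split(","))
        if len(parts) != 2:
            raise ValueError(
                "Resolution must be a tuple of two integers or a single integer.")
        return parts

assert parse_resolution("256") == (256, 256)
assert parse_resolution("64,128") == (64, 128)
```

The first element is passed to `Mel` as `x_res` and the second as `y_res`, matching the `width,height` order in the help text.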
scripts/train_unconditional.py CHANGED
@@ -26,9 +26,6 @@ import numpy as np
 from tqdm.auto import tqdm
 from librosa.util import normalize
 
-import sys
-sys.path.append('.')
-sys.path.append('..')
 from audiodiffusion.mel import Mel
 from audiodiffusion import LatentAudioDiffusionPipeline, AudioDiffusionPipeline
 
@@ -45,31 +42,68 @@ def main(args):
         logging_dir=logging_dir,
     )
 
-    # Handle the resolutions.
-    try:
-        args.resolution = (int(args.resolution), int(args.resolution))
-    except:
-        try :
-            args.resolution = tuple(int(x) for x in args.resolution.split(","))
-            if len(args.resolution) != 2:
-                raise ValueError("Resolution must be a tuple of two integers or a single integer.")
-        except:
-            raise ValueError("Resolution must be a tuple of two integers or a single integer.")
-    assert isinstance(args.resolution, tuple)
+    if args.dataset_name is not None:
+        if os.path.exists(args.dataset_name):
+            dataset = load_from_disk(args.dataset_name,
+                                     args.dataset_config_name)["train"]
+        else:
+            dataset = load_dataset(
+                args.dataset_name,
+                args.dataset_config_name,
+                cache_dir=args.cache_dir,
+                use_auth_token=True if args.use_auth_token else None,
+                split="train",
+            )
+    else:
+        dataset = load_dataset(
+            "imagefolder",
+            data_dir=args.train_data_dir,
+            cache_dir=args.cache_dir,
+            split="train",
+        )
+    # Determine image resolution
+    resolution = dataset[0]['image'].height, dataset[0]['image'].width
 
+    augmentations = Compose([
+        ToTensor(),
+        Normalize([0.5], [0.5]),
+    ])
+
+    def transforms(examples):
+        if args.vae is not None and vqvae.config['in_channels'] == 3:
+            images = [
+                augmentations(image.convert('RGB'))
+                for image in examples["image"]
+            ]
+        else:
+            images = [augmentations(image) for image in examples["image"]]
+        return {"input": images}
+
+    dataset.set_transform(transforms)
+    train_dataloader = torch.utils.data.DataLoader(
+        dataset, batch_size=args.train_batch_size, shuffle=True)
+
+    vqvae = None
     if args.vae is not None:
         vqvae = AutoencoderKL.from_pretrained(args.vae)
+        # Determine latent resolution
+        with torch.no_grad():
+            latent_resolution = vqvae.encode(
+                torch.zeros((1, 1) +
                            resolution)).latent_dist.sample().shape[2:]
 
     if args.from_pretrained is not None:
-        model = DiffusionPipeline.from_pretrained(args.from_pretrained).unet
+        pipeline = DiffusionPipeline.from_pretrained(args.from_pretrained)
+        model = pipeline.unet
+        if hasattr(pipeline, 'vqvae'):
+            vqvae = AutoencoderKL.from_pretrained(args.vae)
     else:
         model = UNet2DModel(
-            sample_size=args.resolution
-            if args.vae is None else args.latent_resolution,
+            sample_size=resolution if vqvae is None else latent_resolution,
             in_channels=1
-            if args.vae is None else vqvae.config['latent_channels'],
+            if vqvae is None else vqvae.config['latent_channels'],
             out_channels=1
-            if args.vae is None else vqvae.config['latent_channels'],
+            if vqvae is None else vqvae.config['latent_channels'],
             layers_per_block=2,
             block_out_channels=(128, 128, 256, 256, 512, 512),
             down_block_types=(
@@ -105,47 +139,6 @@ def main(args):
         eps=args.adam_epsilon,
     )
 
-    augmentations = Compose([
-        Resize(args.resolution, interpolation=InterpolationMode.BILINEAR),
-        CenterCrop(args.resolution),
-        ToTensor(),
-        Normalize([0.5], [0.5]),
-    ])
-
-    if args.dataset_name is not None:
-        if os.path.exists(args.dataset_name):
-            dataset = load_from_disk(args.dataset_name,
-                                     args.dataset_config_name)["train"]
-        else:
-            dataset = load_dataset(
-                args.dataset_name,
-                args.dataset_config_name,
-                cache_dir=args.cache_dir,
-                use_auth_token=True if args.use_auth_token else None,
-                split="train",
-            )
-    else:
-        dataset = load_dataset(
-            "imagefolder",
-            data_dir=args.train_data_dir,
-            cache_dir=args.cache_dir,
-            split="train",
-        )
-
-    def transforms(examples):
-        if args.vae is not None and vqvae.config['in_channels'] == 3:
-            images = [
-                augmentations(image.convert('RGB'))
-                for image in examples["image"]
-            ]
-        else:
-            images = [augmentations(image) for image in examples["image"]]
-        return {"input": images}
-
-    dataset.set_transform(transforms)
-    train_dataloader = torch.utils.data.DataLoader(
-        dataset, batch_size=args.train_batch_size, shuffle=True)
-
     lr_scheduler = get_scheduler(
         args.lr_scheduler,
         optimizer=optimizer,
@@ -171,9 +164,9 @@ def main(args):
         run = os.path.split(__file__)[-1].split(".")[0]
         accelerator.init_trackers(run)
 
-    mel = Mel(x_res=args.resolution[0],
-              y_res=args.resolution[1],
-              hop_length=args.hop_length)
+    mel = Mel(x_res=resolution[1],
+              y_res=resolution[0],
+              hop_length=args.hop_length)
 
     global_step = 0
     for epoch in range(args.num_epochs):
@@ -195,7 +188,7 @@ def main(args):
         for step, batch in enumerate(train_dataloader):
             clean_images = batch["input"]
 
-            if args.vae is not None:
+            if vqvae is not None:
                 vqvae.to(clean_images.device)
                 with torch.no_grad():
                     clean_images = vqvae.encode(
@@ -252,7 +245,7 @@ def main(args):
         # Generate sample images for visual inspection
        if accelerator.is_main_process:
            if epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1:
-                if args.vae is not None:
+                if vqvae is not None:
                    pipeline = LatentAudioDiffusionPipeline(
                        unet=accelerator.unwrap_model(
                            ema_model.averaged_model if args.use_ema else model
@@ -326,7 +319,6 @@ if __name__ == "__main__":
     parser.add_argument("--output_dir", type=str, default="ddpm-model-64")
     parser.add_argument("--overwrite_output_dir", type=bool, default=False)
     parser.add_argument("--cache_dir", type=str, default=None)
-    parser.add_argument("--resolution", type=str, default="256")
     parser.add_argument("--train_batch_size", type=int, default=16)
     parser.add_argument("--eval_batch_size", type=int, default=16)
     parser.add_argument("--num_epochs", type=int, default=100)
@@ -364,7 +356,6 @@ if __name__ == "__main__":
     parser.add_argument("--from_pretrained", type=str, default=None)
     parser.add_argument("--start_epoch", type=int, default=0)
     parser.add_argument("--num_train_steps", type=int, default=1000)
-    parser.add_argument("--latent_resolution", type=int, default=None)
     parser.add_argument("--scheduler",
                         type=str,
                         default="ddpm",
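With `--resolution` and `--latent_resolution` removed, the training script now derives both sizes from its inputs: the sample resolution from the first dataset image, and, in the latent-diffusion case, the UNet's `sample_size` from the spatial shape of a dummy tensor passed through the VAE encoder. A minimal sketch of that inference in isolation (paths are illustrative; it assumes a dataset written by `audio_to_images.py` and an `AutoencoderKL` checkpoint as in the README):

```python
# Sketch of the new resolution inference: image resolution from the data,
# latent resolution from encoding a dummy mono image with the VAE.
import torch
from datasets import load_from_disk
from diffusers import AutoencoderKL

dataset = load_from_disk("data/audio-diffusion-64")["train"]
resolution = dataset[0]["image"].height, dataset[0]["image"].width  # (H, W)

vqvae = AutoencoderKL.from_pretrained("models/autoencoder-kl")
with torch.no_grad():
    # Spatial size of the latents; this becomes the UNet's sample_size
    # when training in latent space.
    latent_resolution = vqvae.encode(
        torch.zeros((1, 1) + resolution)).latent_dist.sample().shape[2:]

print(resolution, tuple(latent_resolution))
```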
scripts/train_vae.py CHANGED
@@ -58,13 +58,10 @@ class AudioDiffusionDataModule(pl.LightningDataModule):
 
 class ImageLogger(Callback):
 
-    def __init__(self, every=1000, channels=3, resolution=256, hop_length=512):
+    def __init__(self, every=1000, hop_length=512):
         super().__init__()
-        self.mel = Mel(x_res=resolution,
-                       y_res=resolution,
-                       hop_length=hop_length)
         self.every = every
-        self.channels = channels
+        self.hop_length = hop_length
 
     @rank_zero_only
     def log_images_and_audios(self, pl_module, batch):
@@ -73,6 +70,12 @@ class ImageLogger(Callback):
         images = pl_module.log_images(batch, split='train')
         pl_module.train()
 
+        image_shape = next(iter(images.values())).shape
+        channels = image_shape[1]
+        mel = Mel(x_res=image_shape[2],
+                  y_res=image_shape[3],
+                  hop_length=self.hop_length)
+
         for k in images:
             images[k] = images[k].detach().cpu()
             images[k] = torch.clamp(images[k], -1., 1.)
@@ -86,14 +89,14 @@ class ImageLogger(Callback):
             images[k] = (images[k].numpy() *
                          255).round().astype("uint8").transpose(0, 2, 3, 1)
             for _, image in enumerate(images[k]):
-                audio = self.mel.image_to_audio(
-                    Image.fromarray(image, mode='RGB').convert('L') if self.
-                    channels == 3 else Image.fromarray(image[0]))
+                audio = mel.image_to_audio(
+                    Image.fromarray(image, mode='RGB').convert('L')
+                    if channels == 3 else Image.fromarray(image[0]))
                 pl_module.logger.experiment.add_audio(
                     tag + f"/{_}",
                     normalize(audio),
                     global_step=pl_module.global_step,
-                    sample_rate=self.mel.get_sample_rate())
+                    sample_rate=mel.get_sample_rate())
 
     def on_train_batch_end(self, trainer, pl_module, outputs, batch,
                            batch_idx):
@@ -139,7 +142,6 @@ if __name__ == "__main__":
         "--gradient_accumulation_steps",
         type=int,
         default=1)
-    parser.add_argument("--resolution", type=int, default=256)
     parser.add_argument("--hop_length", type=int, default=512)
     parser.add_argument("--save_images_batches", type=int, default=1000)
     args = parser.parse_args()
@@ -160,8 +162,6 @@ if __name__ == "__main__":
         resume_from_checkpoint=args.resume_from_checkpoint,
         callbacks=[
             ImageLogger(every=args.save_images_batches,
-                        channels=config.model.params.ddconfig.out_ch,
-                        resolution=args.resolution,
                         hop_length=args.hop_length),
             HFModelCheckpoint(ldm_config=config,
                               hf_checkpoint=args.hf_checkpoint_dir,
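`ImageLogger` no longer takes `channels` and `resolution` arguments; it reads both from the batch of logged images, whose tensors are laid out as `(batch, channels, height, width)`. A minimal sketch of that recovery with a stand-in tensor in place of the `pl_module.log_images(...)` output (the shape indexing mirrors the callback above):

```python
# Sketch: derive channel count and spectrogram size for Mel from a logged
# image batch instead of fixed constructor arguments.
import torch
from audiodiffusion.mel import Mel

images = torch.rand(4, 1, 64, 64)      # dummy logged batch (B, C, H, W)
channels = images.shape[1]
mel = Mel(x_res=images.shape[2],
          y_res=images.shape[3],
          hop_length=512)
print(channels, mel.x_res, mel.y_res)
```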