audio-diffusion / src /audio_to_images.py
teticio's picture
add push_to_hub
e97d748
raw
history blame
2.47 kB
import os
import re
import io
import argparse
import pandas as pd
from tqdm.auto import tqdm
from datasets import Dataset, DatasetDict, Features, Image, Value
from mel import Mel
def _save_dataset(examples, args):
    """Write the collected examples to disk as a one-split dataset.

    Args:
        examples: list of dicts with keys "image" ({"bytes": PNG data}),
            "audio_file" (source path) and "slice" (slice index).
        args: parsed command-line namespace; reads ``output_dir`` and
            ``push_to_hub``.
    """
    if not examples:
        # An empty DataFrame has no columns to match the declared features.
        # This helper runs from a ``finally`` clause, so an error raised here
        # would mask whatever exception ended the conversion loop.
        print("No spectrogram slices were generated; nothing to save.")
        return
    ds = Dataset.from_pandas(
        pd.DataFrame(examples),
        features=Features(
            {
                "image": Image(),
                "audio_file": Value(dtype="string"),
                "slice": Value(dtype="int16"),
            }
        ),
    )
    dsd = DatasetDict({"train": ds})
    dsd.save_to_disk(args.output_dir)
    if args.push_to_hub:
        dsd.push_to_hub(args.push_to_hub)


def main(args):
    """Turn every audio file under ``args.input_dir`` into spectrogram images.

    The input directory is walked recursively for files ending in .mp3, .wav
    or .m4a (case-insensitive). Each file is loaded with ``Mel`` and every
    slice it reports is rendered to an image, PNG-encoded and collected
    together with the source path and slice index. The collected rows are
    saved to ``args.output_dir`` as a ``DatasetDict`` with a single "train"
    split and, if ``args.push_to_hub`` is set, pushed to that hub repository.

    Saving happens in a ``finally`` clause, so whatever was converted before
    an interruption (e.g. Ctrl-C) or an error is still written out.

    Args:
        args: namespace providing ``input_dir``, ``output_dir``,
            ``resolution``, ``hop_length`` and ``push_to_hub``.
    """
    mel = Mel(x_res=args.resolution, y_res=args.resolution, hop_length=args.hop_length)
    os.makedirs(args.output_dir, exist_ok=True)

    # Raw string so that ``\.`` reaches the regex engine as an escaped dot.
    audio_pattern = re.compile(r"\.(mp3|wav|m4a)$", re.IGNORECASE)
    audio_files = [
        os.path.join(root, file)
        for root, _, files in os.walk(args.input_dir)
        for file in files
        if audio_pattern.search(file)
    ]

    examples = []
    try:
        for audio_file in tqdm(audio_files):
            try:
                mel.load_audio(audio_file)
            except Exception as err:
                # Best effort: a file that cannot be loaded is skipped, but
                # reported. KeyboardInterrupt/SystemExit are not Exception
                # subclasses and therefore still propagate.
                tqdm.write(f"Skipping {audio_file}: {err}")
                continue
            for slice_index in range(mel.get_number_of_slices()):
                image = mel.audio_slice_to_image(slice_index)
                # Sanity check on the converter's output size.
                assert (
                    image.width == args.resolution and image.height == args.resolution
                )
                with io.BytesIO() as output:
                    image.save(output, format="PNG")
                    png_bytes = output.getvalue()
                examples.append(
                    {
                        "image": {"bytes": png_bytes},
                        "audio_file": audio_file,
                        "slice": slice_index,
                    }
                )
    finally:
        _save_dataset(examples, args)
if __name__ == "__main__":
    # Command-line entry point: parse the options and run the conversion.
    parser = argparse.ArgumentParser(
        description="Create dataset of Mel spectrograms from directory of audio files."
    )
    # Required: without it main() would hand None to os.walk and fail with an
    # opaque TypeError after the output directory has already been created.
    parser.add_argument(
        "--input_dir",
        type=str,
        required=True,
        help="directory searched recursively for mp3/wav/m4a files",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="data",
        help="directory the dataset is saved to",
    )
    parser.add_argument(
        "--resolution",
        type=int,
        default=256,
        help="width and height of the generated images",
    )
    parser.add_argument(
        "--hop_length",
        type=int,
        default=512,
        help="hop length passed to the Mel converter",
    )
    parser.add_argument(
        "--push_to_hub",
        type=str,
        default=None,
        help="if given, repository name the dataset is pushed to",
    )
    args = parser.parse_args()
    main(args)