import io
import glob
import os.path as osp
import pickle
import random
import sys

import numpy as np
import torch
import torch.distributed as dist
import torchaudio
import torchvision.transforms as transforms
from decord import VideoReader
from petrel_client.client import Client
from torch.utils.data.dataset import Dataset

sys.path.append('./')
from foleycrafter.utils.util import get_full_indices, get_video_frames

def zero_rank_print(s):
    # Print only from rank 0 under torch.distributed; always print otherwise.
    if (not dist.is_initialized()) or dist.get_rank() == 0:
        print("### " + s, flush=True)

def get_mel(audio_data, audio_cfg):
    # mel shape: (n_mels, T)
    mel = torchaudio.transforms.MelSpectrogram(
        sample_rate=audio_cfg["sample_rate"],
        n_fft=audio_cfg["window_size"],
        win_length=audio_cfg["window_size"],
        hop_length=audio_cfg["hop_size"],
        center=True,
        pad_mode="reflect",
        power=2.0,
        norm=None,
        onesided=True,
        n_mels=64,
        f_min=audio_cfg["fmin"],
        f_max=audio_cfg["fmax"],
    ).to(audio_data.device)
    mel = mel(audio_data)
    # we use the log mel spectrogram as input
    mel = torchaudio.transforms.AmplitudeToDB(top_db=None)(mel)
    return mel  # (n_mels, T)
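
# A minimal usage sketch; the config values below are hypothetical examples,
# not taken from this repo:
#   audio_cfg = {"sample_rate": 16000, "window_size": 1024,
#                "hop_size": 160, "fmin": 50, "fmax": 8000}
#   wav = torch.randn(16000 * 10)        # 10 s of mono audio
#   log_mel = get_mel(wav, audio_cfg)    # -> (64, 1001) since center=True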

def dynamic_range_compression(x, normalize_fun=torch.log, C=1, clip_val=1e-5):
    """
    PARAMS
    ------
    C: compression factor
    clip_val: lower bound applied before the log to avoid log(0)
    """
    return normalize_fun(torch.clamp(x, min=clip_val) * C)
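
# Worked example: silent bins are clamped before the log, e.g.
#   dynamic_range_compression(torch.tensor([0.0, 1.0]))
#   -> log(clamp([0., 1.], min=1e-5)) ≈ [-11.5129, 0.0]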

class CPU_Unpickler(pickle.Unpickler):
    # Unpickler that remaps pickled CUDA tensor storages to CPU, so features
    # dumped on a GPU machine can be loaded on CPU-only workers.
    def find_class(self, module, name):
        if module == 'torch.storage' and name == '_load_from_bytes':
            return lambda b: torch.load(io.BytesIO(b), map_location='cpu')
        else:
            return super().find_class(module, name)
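
# Example: load a GPU-pickled feature file on a CPU-only machine
# ('feats.pkl' is a hypothetical path):
#   with open('feats.pkl', 'rb') as f:
#       feats = CPU_Unpickler(f).load()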

class AudioSetStrong(Dataset):
    # read feature and audio
    def __init__(self):
        super().__init__()
        # The original snippet used self._client without creating it; a default
        # petrel_client Client() is assumed here.
        self._client = Client()
        self.data_path = 'data/AudioSetStrong/train/feature'
        self.data_list = list(self._client.list(self.data_path))
        self.length = len(self.data_list)
        # get video feature
        self.video_path = 'data/AudioSetStrong/train/video'
        vision_transform_list = [
            transforms.Resize((128, 128)),
            transforms.CenterCrop((112, 112)),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
        self.video_transform = transforms.Compose(vision_transform_list)

    def get_batch(self, idx):
        # Feature files are pickled dicts; CPU_Unpickler keeps GPU-pickled
        # tensors loadable on CPU-only workers. (The original indexed
        # self.data_list directly, which only holds file names.)
        data_bytes = self._client.Get(osp.join(self.data_path, self.data_list[idx]))
        embeds = CPU_Unpickler(io.BytesIO(data_bytes)).load()
        mel = embeds['mel']
        save_bsz = mel.shape[0]
        audio_info = embeds['audio_info']
        text_embeds = embeds['text_embeds']

        # Join each sample's event labels into one comma-separated prompt.
        audio_info_array = np.array(audio_info['label_list'])
        prompts = []
        for i in range(save_bsz):
            prompts.append(', '.join(audio_info_array[i, :audio_info['event_num'][i]].tolist()))

        # read videos
        videos = []
        for video_name in audio_info['audio_name']:
            video_bytes = self._client.Get(osp.join(self.video_path, video_name + '.mp4'))
            video_reader = VideoReader(io.BytesIO(video_bytes))
            video = video_reader.get_batch(get_full_indices(video_reader)).asnumpy()
            video = get_video_frames(video, 150)
            video = torch.from_numpy(video).permute(0, 3, 1, 2).contiguous().float()
            video = self.video_transform(video)
            videos.append(video)
        assert len(videos) > 0, 'no video read'
        videos = torch.stack(videos, dim=0)

        return mel, audio_info, text_embeds, prompts, videos

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        # On a failed read, resample a random index so one bad file
        # cannot stall training.
        while True:
            try:
                mel, audio_info, text_embeds, prompts, videos = self.get_batch(idx)
                break
            except Exception as e:
                zero_rank_print(f' >>> load error: {e} <<<')
                idx = random.randint(0, self.length - 1)
        sample = dict(mel=mel, audio_info=audio_info, text_embeds=text_embeds,
                      prompts=prompts, videos=videos)
        return sample
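
# A minimal loading sketch (assumes the AudioSetStrong layout above exists and
# petrel_client is configured); shapes are indicative, not guaranteed:
#   dataset = AudioSetStrong()
#   sample = dataset[0]
#   sample['videos']  # roughly (save_bsz, 150, 3, 112, 112)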

class VGGSound(Dataset):
    # read feature and audio
    def __init__(self):
        super().__init__()
        self.data_path = 'data/VGGSound/train/video'
        self.visual_data_path = 'data/VGGSound/train/feature'
        # Sort both lists so the i-th embedding and the i-th visual feature
        # refer to the same clip (glob order is not deterministic).
        self.embeds_list = sorted(glob.glob(f'{self.data_path}/*.pt'))
        self.visual_list = sorted(glob.glob(f'{self.visual_data_path}/*.pt'))
        self.length = len(self.embeds_list)

    def get_batch(self, idx):
        embeds = torch.load(self.embeds_list[idx], map_location='cpu')
        visual_embeds = torch.load(self.visual_list[idx], map_location='cpu')['visual_embeds']
        text = embeds['text']
        audio = embeds['mel']  # the mel spectrogram serves as the audio target
        return visual_embeds, audio, text

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        # Same retry-on-error strategy as AudioSetStrong.
        while True:
            try:
                visual_embeds, audio, text = self.get_batch(idx)
                break
            except Exception as e:
                zero_rank_print(f'load error: {e}')
                idx = random.randint(0, self.length - 1)
        sample = dict(visual_embeds=visual_embeds, audio=audio, text=text)
        return sample
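
if __name__ == '__main__':
    # Minimal smoke test, assuming the VGGSound feature directories above
    # exist locally; batch size and worker count are arbitrary choices.
    from torch.utils.data import DataLoader

    dataset = VGGSound()
    loader = DataLoader(dataset, batch_size=2, num_workers=0)
    batch = next(iter(loader))
    print(batch['visual_embeds'].shape, batch['audio'].shape, batch['text'])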