# OpenMusic / qa_mdt / tools.py
# Uploaded by jadechoghari ("add model", commit 9b9e0ee, verified; 18.7 kB).
# Author: Haohe Liu
# Email: [email protected]
# Date: 11 Feb 2023
import os
import json
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib
from scipy.io import wavfile
from matplotlib import pyplot as plt
matplotlib.use("Agg")
import hashlib
import os
import requests
from tqdm import tqdm
# Remote download location of each pretrained asset fetched by get_ckpt_path().
URL_MAP = {
    "vggishish_lpaps": "https://a3s.fi/swift/v1/AUTH_a235c0f452d648828f745589cde1219a/specvqgan_public/vggishish16.pt",
    "vggishish_mean_std_melspec_10s_22050hz": "https://a3s.fi/swift/v1/AUTH_a235c0f452d648828f745589cde1219a/specvqgan_public/train_means_stds_melspec_10s_22050hz.txt",
    "melception": "https://a3s.fi/swift/v1/AUTH_a235c0f452d648828f745589cde1219a/specvqgan_public/melception-21-05-10T09-28-40.pt",
}
# Local filename of each asset, joined onto the ``root`` given to get_ckpt_path().
CKPT_MAP = {
    "vggishish_lpaps": "vggishish16.pt",
    "vggishish_mean_std_melspec_10s_22050hz": "train_means_stds_melspec_10s_22050hz.txt",
    "melception": "melception-21-05-10T09-28-40.pt",
}
# Expected MD5 hex digest of each asset; get_ckpt_path() verifies downloads against it.
MD5_MAP = {
    "vggishish_lpaps": "197040c524a07ccacf7715d7080a80bd",
    "vggishish_mean_std_melspec_10s_22050hz": "f449c6fd0e248936c16f6d22492bb625",
    "melception": "a71a41041e945b457c7d3d814bbcf72d",
}
# Module-level default torch device: CUDA when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def read_list(fname):
    """Return every line of the text file *fname* with newline characters stripped.

    Args:
        fname: path of the file to read.

    Returns:
        list[str]: one entry per line, in file order (blank lines become "").
    """
    with open(fname, "r") as handle:
        return [line.strip("\n") for line in handle]
def build_dataset_json_from_list(list_path):
    """Build an in-memory dataset description from a plain-text list file.

    Each line of *list_path* is either ``"<wav path>|<caption>"`` or just
    ``"<caption>"``. In the latter case ``wav`` is the empty string. Only the
    first ``|`` separates the path from the caption, so captions may
    themselves contain ``|``.

    Args:
        list_path: path of the list file (read with ``read_list``).

    Returns:
        dict: ``{"data": [{"wav": str, "caption": str}, ...]}``.
    """
    data = []
    for line in read_list(list_path):
        if "|" in line:
            # maxsplit=1: a "|" inside the caption used to raise
            # "too many values to unpack".
            wav, caption = line.split("|", 1)
        else:
            wav, caption = "", line
        data.append({"wav": wav, "caption": caption})
    return {"data": data}
def load_json(fname):
    """Parse the JSON document stored at *fname* and return the decoded object."""
    with open(fname, "r") as fp:
        return json.load(fp)
def read_json(dataset_json_file):
    """Load a dataset JSON file and return its top-level ``"data"`` entry.

    The decoded document is expected to be a mapping with a ``"data"`` key;
    a missing key surfaces as ``KeyError``.
    """
    with open(dataset_json_file, "r") as fp:
        document = json.load(fp)
    return document["data"]
def copy_test_subset_data(metadata, testset_copy_target_path):
    """Mirror the test-set audio files listed in *metadata* into one directory.

    If *testset_copy_target_path* already holds exactly ``len(metadata)``
    entries, it is assumed to be up to date and nothing is done. Otherwise
    every entry in the directory is removed (best effort; failures are
    printed) and each ``item["wav"]`` is copied in under its own basename.

    Args:
        metadata: sized sequence of dicts, each with a ``"wav"`` source path.
        testset_copy_target_path: destination directory (created if missing).
    """
    import shutil  # local import keeps the module-level import block unchanged

    os.makedirs(testset_copy_target_path, exist_ok=True)
    if len(os.listdir(testset_copy_target_path)) == len(metadata):
        return

    # Clear out the stale copy before re-populating it.
    for fname in os.listdir(testset_copy_target_path):
        try:
            os.remove(os.path.join(testset_copy_target_path, fname))
        except OSError as e:
            print(e)

    print("Copying test subset data to {}".format(testset_copy_target_path))
    for each in tqdm(metadata):
        # shutil.copy instead of os.system("cp ..."): no shell is involved, so
        # paths containing spaces, quotes or shell metacharacters are handled
        # correctly and cannot inject commands.
        try:
            shutil.copy(each["wav"], testset_copy_target_path)
        except OSError as e:
            # A failing `cp` only printed an error and moved on; keep that
            # best-effort behaviour rather than aborting the whole copy.
            print(e)
def listdir_nohidden(path):
    """Yield the names of entries in directory *path*, skipping dot-prefixed ones.

    This is a generator; the directory is only listed once iteration starts.
    """
    yield from (name for name in os.listdir(path) if name[:1] != ".")
def get_restore_step(path):
    """Pick the checkpoint file in directory *path* to resume training from.

    Priority:

    1. ``final.ckpt`` if present -> ``("final.ckpt", 0)``.
    2. If there is no ``last.ckpt``: the file whose name carries the largest
       ``step=<N>`` value -> ``(filename, N)``. Files without ``step=`` in
       their name are ignored.
    3. Otherwise, the highest-numbered ``last-v<K>.ckpt`` if any exist, else
       ``last.ckpt`` -> ``(filename, 0)``.

    Returns:
        tuple[str, int]: checkpoint filename (relative to *path*) and step.

    Raises:
        ValueError: in case 2 when no ``step=`` checkpoint exists.
    """
    checkpoints = os.listdir(path)
    if os.path.exists(os.path.join(path, "final.ckpt")):
        return "final.ckpt", 0

    if not os.path.exists(os.path.join(path, "last.ckpt")):
        # Only consider "...step=<N>.ckpt" files; anything else in the
        # directory used to crash the parse with IndexError.
        stepped = [
            (int(name.split(".ckpt")[0].split("step=")[1]), name)
            for name in checkpoints
            if "step=" in name
        ]
        if not stepped:
            raise ValueError("No 'step=' checkpoint found in {}".format(path))
        # max() keeps the first of equal maxima, like np.argmax did.
        step, fname = max(stepped, key=lambda item: item[0])
        return fname, step

    fname = "last.ckpt"
    best_version = -1
    for name in checkpoints:
        if "last" not in name or "-v" not in name:
            continue
        version = int(name.split(".ckpt")[0].split("-v")[1])
        # Compare with the best version seen so far *before* recording it.
        # The old code appended first, so `version > max(steps)` was never
        # true and last-v<K>.ckpt could never be selected.
        if version > best_version:
            best_version = version
            fname = "last-v%s.ckpt" % version
    return fname, 0
def download(url, local_path, chunk_size=1024):
    """Stream *url* to *local_path* while showing a byte-count progress bar.

    Parent directories of *local_path* are created as needed.

    Args:
        url: HTTP(S) URL to fetch.
        local_path: destination file path (overwritten if it exists).
        chunk_size: read size in bytes for the streamed response body.

    Raises:
        requests.HTTPError: if the server answers with a 4xx/5xx status, so an
            error page is never written to disk as if it were the payload.
    """
    parent = os.path.dirname(local_path)
    if parent:
        # os.makedirs("") raises, so skip it for bare filenames.
        os.makedirs(parent, exist_ok=True)
    with requests.get(url, stream=True) as r:
        r.raise_for_status()
        total_size = int(r.headers.get("content-length", 0))
        with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
            with open(local_path, "wb") as f:
                for data in r.iter_content(chunk_size=chunk_size):
                    if data:
                        f.write(data)
                        # Count bytes actually received; the final chunk is
                        # usually shorter than chunk_size.
                        pbar.update(len(data))
def md5_hash(path, chunk_size=1 << 20):
    """Return the hexadecimal MD5 digest of the file at *path*.

    The file is hashed in *chunk_size*-byte pieces, so large checkpoints are
    not loaded into memory all at once.

    Args:
        path: file to hash.
        chunk_size: number of bytes read per step (default 1 MiB).

    Returns:
        str: 32-character lowercase hex digest.
    """
    digest = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()
def get_ckpt_path(name, root, check=False):
    """Return the local path of pretrained asset *name*, downloading it if needed.

    The file lives at ``os.path.join(root, CKPT_MAP[name])``. It is fetched
    from ``URL_MAP[name]`` when missing, or when *check* is true and its MD5
    differs from ``MD5_MAP[name]``. A fresh download is always verified
    against ``MD5_MAP[name]``.

    Raises:
        AssertionError: for an unknown *name*, or a checksum mismatch after
            download (these are plain ``assert`` statements, skipped under
            ``python -O``).
    """
    assert name in URL_MAP
    url = URL_MAP[name]
    path = os.path.join(root, CKPT_MAP[name])
    stale = not os.path.exists(path)
    if not stale and check:
        stale = md5_hash(path) != MD5_MAP[name]
    if stale:
        print("Downloading {} model from {} to {}".format(name, url, path))
        download(url, path)
        md5 = md5_hash(path)
        assert md5 == MD5_MAP[name], md5
    return path
class KeyNotFoundError(Exception):
    """Error raised by ``retrieve`` when a key path cannot be resolved.

    Attributes:
        cause: the underlying exception (or other object) describing the failure.
        keys: the full list of path components being looked up, if given.
        visited: the path components traversed before the failure, if given.
    """

    def __init__(self, cause, keys=None, visited=None):
        self.cause = cause
        self.keys = keys
        self.visited = visited
        lines = []
        if keys is not None:
            lines.append("Key not found: {}".format(keys))
        if visited is not None:
            lines.append("Visited: {}".format(visited))
        lines.append("Cause:\n{}".format(cause))
        super().__init__("\n".join(lines))
def retrieve(
    list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False
):
    """Given a nested list or dict return the desired value at key expanding
    callable nodes if necessary and :attr:`expand` is ``True``. The expansion
    is done in-place: an expanded node's result is written back into its
    parent container.

    Parameters
    ----------
    list_or_dict : list or dict
        Possibly nested list or dictionary.
    key : str
        key/to/value, path like string describing all keys necessary to
        consider to get to the desired value. List indices can also be
        passed here.
    splitval : str
        String that defines the delimiter between keys of the
        different depth levels in `key`.
    default : obj
        Value returned if :attr:`key` is not found. ``None`` means "no
        default", so ``None`` itself cannot serve as a fallback value.
    expand : bool
        Whether to expand callable nodes on the path or not.
    pass_success : bool
        If ``True``, return ``(value, success)`` where ``success`` is
        ``False`` when :attr:`default` was substituted.

    Returns
    -------
    The desired value or if :attr:`default` is not ``None`` and the
    :attr:`key` is not found returns ``default``.

    Raises
    ------
    KeyNotFoundError if the path cannot be resolved (missing dict key,
    out-of-range or non-integer list index, or a callable node reached with
    ``expand=False``) and :attr:`default` is ``None``.
    """
    keys = key.split(splitval)
    success = True
    try:
        visited = []
        parent = None
        last_key = None
        # Note: the loop variable rebinds the `key` argument to each component.
        for key in keys:
            if callable(list_or_dict):
                if not expand:
                    raise KeyNotFoundError(
                        ValueError(
                            "Trying to get past callable node with expand=False."
                        ),
                        keys=keys,
                        visited=visited,
                    )
                # Materialise the lazy node and store the result back in its
                # parent so later lookups see the expanded value.
                # NOTE(review): if the root `list_or_dict` is itself callable,
                # `parent` is still None here and this assignment raises
                # TypeError (not KeyNotFoundError).
                list_or_dict = list_or_dict()
                parent[last_key] = list_or_dict
            last_key = key
            parent = list_or_dict
            try:
                if isinstance(list_or_dict, dict):
                    list_or_dict = list_or_dict[key]
                else:
                    # Non-dict containers are indexed by integer position; a
                    # non-numeric component raises ValueError here.
                    list_or_dict = list_or_dict[int(key)]
            except (KeyError, IndexError, ValueError) as e:
                raise KeyNotFoundError(e, keys=keys, visited=visited)
            visited += [key]
        # final expansion of retrieved value
        if expand and callable(list_or_dict):
            list_or_dict = list_or_dict()
            parent[last_key] = list_or_dict
    except KeyNotFoundError as e:
        if default is None:
            raise e
        else:
            list_or_dict = default
            success = False
    if not pass_success:
        return list_or_dict
    else:
        return list_or_dict, success
def to_device(data, device):
    """Convert a collated batch of numpy arrays to torch tensors on *device*.

    Two batch layouts are recognised by their length:

    * 6 fields: ``(ids, raw_texts, speakers, texts, src_lens, max_src_len)``
    * 12 fields: the 6 above followed by
      ``(mels, mel_lens, max_mel_len, pitches, energies, durations)``

    ``ids``, ``raw_texts``, ``max_src_len`` and ``max_mel_len`` pass through
    untouched. The array fields go through ``torch.from_numpy`` and are moved
    to *device*: ``speakers``/``texts``/``durations`` are cast to long,
    ``mels``/``pitches`` to float, and the rest keep their numpy dtype.

    Returns:
        A tuple in the same field order, or ``None`` for any other length.
    """
    size = len(data)
    if size != 6 and size != 12:
        return None
    fields = tuple(data)

    def as_long(arr):
        return torch.from_numpy(arr).long().to(device)

    def as_float(arr):
        return torch.from_numpy(arr).float().to(device)

    def as_is(arr):
        return torch.from_numpy(arr).to(device)

    ids, raw_texts, speakers, texts, src_lens, max_src_len = fields[:6]
    converted = [
        ids,
        raw_texts,
        as_long(speakers),
        as_long(texts),
        as_is(src_lens),
        max_src_len,
    ]
    if size == 12:
        mels, mel_lens, max_mel_len, pitches, energies, durations = fields[6:]
        converted.extend(
            [
                as_float(mels),
                as_is(mel_lens),
                max_mel_len,
                as_float(pitches),
                as_is(energies),
                as_long(durations),
            ]
        )
    return tuple(converted)
def log(logger, step=None, fig=None, audio=None, sampling_rate=22050, tag=""):
    """Send a figure and/or an audio clip to *logger* under *tag*.

    Args:
        logger: object exposing ``add_figure(tag, fig)`` and
            ``add_audio(tag, audio, sample_rate=...)`` (presumably a
            TensorBoard ``SummaryWriter``; confirm with callers).
        step: accepted for call-site compatibility; currently unused.
        fig: figure to log, if not ``None``.
        audio: waveform (numpy array or torch tensor) to log, if not ``None``.
            It is peak-normalised to ``1 / 1.1`` of full scale first.
        sampling_rate: sample rate forwarded to ``add_audio``.
        tag: name under which the figure/audio is logged.
    """
    if fig is not None:
        logger.add_figure(tag, fig)
    if audio is not None:
        # `.max()` on the element-wise abs works for numpy arrays and torch
        # tensors of any shape, unlike Python's max() over the first axis.
        peak = abs(audio).max()
        # Leave silent clips untouched instead of computing 0 / 0 (NaNs).
        if peak > 0:
            audio = audio / (peak * 1.1)
        logger.add_audio(
            tag,
            audio,
            sample_rate=sampling_rate,
        )
def get_mask_from_lengths(lengths, max_len=None):
    """Build a boolean padding mask from per-sequence lengths.

    Args:
        lengths: 1-D integer tensor of sequence lengths, shape ``(batch,)``.
        max_len: mask width; defaults to the largest entry of *lengths*.

    Returns:
        Bool tensor of shape ``(batch, max_len)`` on ``lengths.device``, where
        ``mask[b, t]`` is ``True`` for padded positions (``t >= lengths[b]``).
    """
    batch_size = lengths.shape[0]
    if max_len is None:
        max_len = torch.max(lengths).item()
    # Build the position index on the same device as `lengths`. A fixed
    # module-level device breaks when `lengths` lives elsewhere (e.g. CPU
    # tensors on a machine where CUDA is available).
    ids = (
        torch.arange(0, max_len, device=lengths.device)
        .unsqueeze(0)
        .expand(batch_size, -1)
    )
    return ids >= lengths.unsqueeze(1)
def expand(values, durations):
    """Repeat each element of *values* as many times as its matching duration.

    Durations are truncated with ``int()`` and negative values count as zero.
    Surplus items in the longer argument are ignored (``zip`` semantics).

    Returns:
        np.ndarray of the repeated values (empty if nothing is repeated).
    """
    repeated = [
        value
        for value, count in zip(values, durations)
        for _ in range(max(0, int(count)))
    ]
    return np.array(repeated)
def synth_one_sample_val(
    targets, predictions, vocoder, model_config, preprocess_config
):
    """Pick one random item from a validation batch, plot its mels and vocode it.

    Args:
        targets: batch sequence; ``targets[0]`` holds basenames and
            ``targets[6]`` the ground-truth mels, indexed
            ``[item, frame, ...]`` (TODO confirm the exact layout).
        predictions: model outputs; ``predictions[0]`` raw mel,
            ``predictions[1]`` postnet mel, ``predictions[9]`` mel lengths.
        vocoder: vocoder handed to ``vocoder_infer``; ``None`` skips audio.
        model_config, preprocess_config: forwarded to ``vocoder_infer``.

    Returns:
        ``(fig, wav_reconstruction, wav_prediction, basename)``. Both wavs are
        ``None`` when *vocoder* is ``None``.
    """
    index = np.random.choice(targets[6].size(0))
    basename = targets[0][index]
    mel_len = predictions[9][index].item()
    mel_target = targets[6][index, :mel_len].detach().transpose(0, 1)
    mel_prediction = predictions[0][index, :mel_len].detach().transpose(0, 1)
    postnet_mel_prediction = predictions[1][index, :mel_len].detach().transpose(0, 1)

    # plot_mel(data, titles) takes a list of bare 2-D mel arrays. The previous
    # call passed (mel, pitch, energy) tuples plus an extra `stats` argument,
    # which raised TypeError before anything was plotted. The pitch/energy/
    # stats values it needed are therefore no longer computed.
    fig = plot_mel(
        [
            mel_prediction.cpu().numpy(),
            postnet_mel_prediction.cpu().numpy(),
            mel_target.cpu().numpy(),
        ],
        [
            "Raw mel spectrogram prediction",
            "Postnet mel prediction",
            "Ground-Truth Spectrogram",
        ],
    )
    if vocoder is not None:
        from .model_util import vocoder_infer

        wav_reconstruction = vocoder_infer(
            mel_target.unsqueeze(0),
            vocoder,
            model_config,
            preprocess_config,
        )[0]
        wav_prediction = vocoder_infer(
            postnet_mel_prediction.unsqueeze(0),
            vocoder,
            model_config,
            preprocess_config,
        )[0]
    else:
        wav_reconstruction = wav_prediction = None
    return fig, wav_reconstruction, wav_prediction, basename
def synth_one_sample(mel_input, mel_prediction, labels, vocoder):
    """Vocode an input mel batch and a predicted mel batch.

    Both mel tensors have their last two axes swapped (``permute(0, 2, 1)``)
    before being passed to ``vocoder_infer``. *labels* is accepted for
    call-site compatibility and is not used.

    Returns:
        ``(wav_reconstruction, wav_prediction)``, or ``(None, None)`` when
        *vocoder* is ``None``.
    """
    if vocoder is None:
        return None, None

    from .model_util import vocoder_infer

    wav_reconstruction = vocoder_infer(mel_input.permute(0, 2, 1), vocoder)
    wav_prediction = vocoder_infer(mel_prediction.permute(0, 2, 1), vocoder)
    return wav_reconstruction, wav_prediction
def synth_samples(targets, predictions, vocoder, model_config, preprocess_config, path):
    """Write a mel plot and a vocoded wav for every item of a batch into *path*.

    For each item ``i`` this saves ``<basename>_postnet_2.png``, showing the
    postnet mel ``predictions[1][i]`` cropped to ``predictions[9][i]`` frames.
    The whole batch is then vocoded with ``vocoder_infer`` (lengths scaled by
    the STFT hop length), and each waveform is written as ``<basename>.wav``
    at ``preprocess_config["preprocessing"]["audio"]["sampling_rate"]``.

    Args:
        targets: batch sequence whose element 0 holds the basenames.
        predictions: model outputs; ``predictions[1]`` postnet mels,
            ``predictions[9]`` mel lengths.
        vocoder, model_config, preprocess_config: forwarded to ``vocoder_infer``.
        path: output directory.
    """
    basenames = targets[0]
    for i in range(len(predictions[1])):
        basename = basenames[i]
        mel_len = predictions[9][i].item()
        mel_prediction = predictions[1][i, :mel_len].detach().transpose(0, 1)
        # plot_mel(data, titles) expects bare mel arrays. The previous call
        # passed a (mel, pitch, energy) tuple plus a stray `stats` argument
        # and raised TypeError, so the pitch/energy/stats work is dropped.
        fig = plot_mel(
            [mel_prediction.cpu().numpy()],
            ["Synthetized Spectrogram by PostNet"],
        )
        # Save/close this specific figure rather than whatever is "current".
        fig.savefig(os.path.join(path, "{}_postnet_2.png".format(basename)))
        plt.close(fig)

    from .model_util import vocoder_infer

    mel_predictions = predictions[1].transpose(1, 2)
    lengths = predictions[9] * preprocess_config["preprocessing"]["stft"]["hop_length"]
    wav_predictions = vocoder_infer(
        mel_predictions, vocoder, model_config, preprocess_config, lengths=lengths
    )
    sampling_rate = preprocess_config["preprocessing"]["audio"]["sampling_rate"]
    for wav, basename in zip(wav_predictions, basenames):
        wavfile.write(os.path.join(path, "{}.wav".format(basename)), sampling_rate, wav)
def plot_mel(data, titles=None):
    """Draw each spectrogram in *data* in its own row of a single figure.

    Args:
        data: sequence of 2-D arrays. Each is shown with ``imshow`` with its
            origin at the bottom, and the y-range is set to its first axis.
        titles: optional per-row titles; rows are untitled when omitted.

    Returns:
        The matplotlib ``Figure`` holding one subplot per array.
    """
    n_rows = len(data)
    fig, axes = plt.subplots(n_rows, 1, squeeze=False)
    if titles is None:
        titles = [None] * n_rows
    for row, mel in enumerate(data):
        ax = axes[row][0]
        ax.imshow(mel, origin="lower", aspect="auto")
        ax.set_aspect(2.5, adjustable="box")
        ax.set_ylim(0, mel.shape[0])
        ax.set_title(titles[row], fontsize="medium")
        ax.tick_params(labelsize="x-small", left=False, labelleft=False)
        ax.set_anchor("W")
    return fig
def pad_1D(inputs, PAD=0):
    """Right-pad 1-D arrays with *PAD* to a common length and stack them.

    Args:
        inputs: sequence of 1-D arrays.
        PAD: fill value for the padded tail.

    Returns:
        np.ndarray of shape ``(len(inputs), longest_length)``.
    """
    target = max(len(x) for x in inputs)
    rows = [
        np.pad(x, (0, target - x.shape[0]), mode="constant", constant_values=PAD)
        for x in inputs
    ]
    return np.stack(rows)
def pad_2D(inputs, maxlen=None):
    """Zero-pad arrays along their first axis to a common length and stack them.

    Args:
        inputs: sequence of arrays (typically 2-D ``(frames, features)``)
            whose axes other than 0 already agree.
        maxlen: target length for axis 0. If falsy, the longest input is used.

    Returns:
        np.ndarray of shape ``(len(inputs), target_len, *other_axes)``.

    Raises:
        ValueError: if an input is longer than *maxlen*.
    """

    def pad(x, max_len):
        PAD = 0
        length = np.shape(x)[0]
        if length > max_len:
            raise ValueError("not max_len")
        # Pad only axis 0. The old `(0, k)` spec padded *every* axis and then
        # cropped axis 1 back, which wasted work and left axes >= 2 padded.
        pad_width = [(0, max_len - length)] + [(0, 0)] * (np.ndim(x) - 1)
        return np.pad(x, pad_width, mode="constant", constant_values=PAD)

    if not maxlen:
        maxlen = max(np.shape(x)[0] for x in inputs)
    return np.stack([pad(x, maxlen) for x in inputs])
def pad(input_ele, mel_max_length=None):
    """Zero-pad tensors along dim 0 to a common length and stack them.

    Args:
        input_ele: sequence of tensors whose dims other than 0 already agree.
        mel_max_length: target length for dim 0. If falsy, the longest input
            is used.

    Returns:
        Tensor of shape ``(len(input_ele), target_len, *other_dims)``.
    """
    if mel_max_length:
        max_len = mel_max_length
    else:
        max_len = max(x.size(0) for x in input_ele)
    out_list = []
    for batch in input_ele:
        # F.pad lists (left, right) pairs starting from the LAST dim, so only
        # the final pair touches dim 0. Building the tuple from the rank
        # handles tensors of any dimensionality. Previously, >= 3-D inputs
        # left `one_batch_padded` unset or silently reused the prior item.
        pad_spec = (0, 0) * (batch.dim() - 1) + (0, max_len - batch.size(0))
        out_list.append(F.pad(batch, pad_spec, "constant", 0.0))
    return torch.stack(out_list)