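# Provenance (captured from the Hugging Face file page this script was copied from):
#   uploaded by lmzjms, commit 5f898a2 "Upload 5 files", 6.57 kB, page marked "No virus".
#   (Page controls "raw / history / blame / contribute / delete" are not part of the code.)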
import torch
import numpy as np
import gradio as gr
from PIL import Image
import matplotlib
from omegaconf import OmegaConf
from einops import repeat
import librosa
from ldm.models.diffusion.ddim import DDIMSampler
from vocoder.bigvgan.models import VocoderBigVGAN
from ldm.util import instantiate_from_config
from ldm.data.extract_mel_spectrogram import TRANSFORMS_16000
SAMPLE_RATE = 16000  # Hz; all input audio is resampled to this rate before mel extraction
cmap_transform = matplotlib.cm.viridis  # colormap used to render mel spectrograms as images

# This app only runs inference, so autograd is switched off globally.
torch.set_grad_enabled(False)

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def initialize_model(config, ckpt):
    """Build the diffusion model from a config, load its weights and wrap it in a DDIM sampler.

    Args:
        config: Path to an OmegaConf YAML file whose ``model`` node is passed to
            ``instantiate_from_config``.
        ckpt: Path to a checkpoint file containing a ``"state_dict"`` entry.

    Returns:
        A ``DDIMSampler`` wrapping the model, which has been moved to the
        module-level ``device`` and switched to eval mode.
    """
    config = OmegaConf.load(config)
    model = instantiate_from_config(config.model)
    # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
    state_dict = torch.load(ckpt, map_location='cpu')["state_dict"]
    # strict=False tolerates key mismatches; report them so that a wrong or
    # partial checkpoint does not go unnoticed.
    incompatible = model.load_state_dict(state_dict, strict=False)
    missing, unexpected = incompatible.missing_keys, incompatible.unexpected_keys
    if missing or unexpected:
        print(f"[initialize_model] {ckpt}: {len(missing)} missing keys, "
              f"{len(unexpected)} unexpected keys (strict=False)")
    model = model.to(device)
    # Inference only: put dropout/normalization layers into evaluation mode.
    model.eval()
    print(model.device, device, model.cond_stage_model.device)
    sampler = DDIMSampler(model)
    return sampler
def make_batch_sd(
        mel,
        mask,
        device,
        num_samples=1):
    """Pack a mel spectrogram and its inpainting mask into a model-ready batch.

    Each numpy input gets two leading singleton axes and is cast to float32
    (presumably ``mel``/``mask`` are 2-D ``(bins, frames)`` arrays in [0, 1];
    TODO confirm). ``masked_mel`` is ``mel`` with the region where ``mask`` is 1
    zeroed, computed *before* the [0, 1] -> [-1, 1] rescale that is then
    applied to all three tensors.

    Returns:
        dict with keys "mel", "mask" and "masked_mel"; each value is moved to
        ``device`` and repeated ``num_samples`` times along the first axis.
    """
    def as_tensor(arr):
        # Add batch and channel axes, then cast to float32.
        return torch.from_numpy(arr)[None, None, ...].to(dtype=torch.float32)

    def to_signed(t):
        # Affine map [0, 1] -> [-1, 1].
        return t * 2 - 1

    def tile(t):
        return repeat(t.to(device=device), "1 ... -> n ...", n=num_samples)

    mel_t = as_tensor(mel)
    mask_t = as_tensor(mask)
    unscaled = {
        "mel": mel_t,
        "mask": mask_t,
        "masked_mel": (1 - mask_t) * mel_t,  # zero out the masked region
    }
    return {name: tile(to_signed(t)) for name, t in unscaled.items()}
def gen_mel(input_audio, mel_len=848, hop_size=256):
    """Convert a Gradio ``(sample_rate, waveform)`` tuple into a fixed-length mel spectrogram.

    Steps:
    1. Convert the waveform to float32.
    2. Down-mix it to mono if it is 2-D.
    3. Resample it to ``SAMPLE_RATE``.
    4. Zero-pad or truncate it to exactly ``mel_len * hop_size`` samples.
    5. Pass it through ``TRANSFORMS_16000``.

    Args:
        input_audio: ``(sr, wav)`` tuple as delivered by the Gradio Audio
            component. ``wav`` is ``(num_samples,)`` or
            ``(num_samples, channels)``.
            - Signed-integer PCM (e.g. int16) is scaled by its full range to [-1, 1).
            - Unsigned-integer PCM is re-centred on zero first.
            - Floating-point data is used unchanged.
        mel_len: Target number of mel frames; sizes the waveform (default 848).
        hop_size: Hop length in samples assumed for the mel transform
            (default 256); only used to size the waveform.

    Returns:
        The output of ``TRANSFORMS_16000`` for the fixed-length waveform.
        Callers in this file index it as ``[bins, frames]``.

    Raises:
        ValueError: If ``input_audio`` is None (no audio uploaded yet).
    """
    if input_audio is None:
        raise ValueError("No input audio: please upload or record audio first.")
    sr, ori_wav = input_audio
    ori_wav = np.asarray(ori_wav)

    # order='C' just requests a C-contiguous copy.
    if np.issubdtype(ori_wav.dtype, np.signedinteger):
        # Integer PCM: divide by the dtype's full range (32768.0 for int16).
        scale = float(-np.iinfo(ori_wav.dtype).min)
        ori_wav = ori_wav.astype(np.float32, order='C') / scale
    elif np.issubdtype(ori_wav.dtype, np.unsignedinteger):
        # Unsigned PCM (e.g. 8-bit WAV) is offset-binary: re-centre on zero.
        mid = (float(np.iinfo(ori_wav.dtype).max) + 1.0) / 2.0
        ori_wav = (ori_wav.astype(np.float32, order='C') - mid) / mid
    else:
        # Already floating point; do not rescale.
        ori_wav = ori_wav.astype(np.float32, order='C')

    if ori_wav.ndim == 2:
        # Gradio yields (num_samples, channels); librosa.to_mono expects (channels, num_samples).
        ori_wav = librosa.to_mono(ori_wav.T)
    ori_wav = librosa.resample(ori_wav, orig_sr=sr, target_sr=SAMPLE_RATE)

    input_len = mel_len * hop_size
    if len(ori_wav) < input_len:
        # Pad only the shortfall so both branches yield exactly input_len samples.
        input_wav = np.pad(ori_wav, (0, input_len - len(ori_wav)), constant_values=0)
    else:
        input_wav = ori_wav[:input_len]
    mel = TRANSFORMS_16000(input_wav)
    return mel
def show_mel_fn(input_audio):
    """Render the leading frames of the input audio's mel spectrogram as an image.

    Only the first 500 frames are shown. Per the original author, the full mel
    could not be displayed because of a Gradio Image bug with ``tool='sketch'``.
    """
    visible_frames = 500
    visible = gen_mel(input_audio)[:, :visible_frames]
    # Colormap output is float in [0, 1]; convert to 8-bit pixels.
    pixels = (cmap_transform(visible) * 255).astype(np.uint8)
    return Image.fromarray(pixels)
def inpaint(sampler, batch, seed, ddim_steps, num_samples=1, W=512, H=512):
    """Fill the masked region of a mel spectrogram via DDIM sampling, then vocode it.

    Args:
        sampler: DDIM sampler whose ``.model`` is the latent-diffusion model.
        batch: dict from ``make_batch_sd`` with "mel", "mask" and "masked_mel"
            tensors scaled to [-1, 1].
        seed: Seed for the initial latent noise ``x_T`` handed to the sampler.
        ddim_steps: Number of DDIM sampling steps.
        num_samples, W, H: Kept for backward compatibility only. The noise shape
            now comes from the encoded latent, so these no longer affect the result.

    Returns:
        ``(inpainted, wav)``:
        - ``inpainted``: the squeezed numpy mel in [0, 1]. It keeps the original
          values outside the mask and the model prediction inside it.
        - ``wav``: the output of the module-level ``vocoder`` for that mel.
    """
    model = sampler.model

    # Conditioning = latent of the masked mel plus the mask resized to latent resolution.
    c = model.get_first_stage_encoding(model.encode_first_stage(batch["masked_mel"]))
    cc = torch.nn.functional.interpolate(batch["mask"], size=c.shape[-2:])
    c = torch.cat((c, cc), dim=1)  # (b, c+1, h, w); the extra channel is the mask
    shape = (c.shape[1] - 1,) + tuple(c.shape[2:])

    # Seeded starting noise. It must be passed as x_T, otherwise the sampler
    # draws its own noise and the seed has no effect. Shaped from the actual
    # latent so it always matches what the sampler expects.
    prng = np.random.RandomState(int(seed))
    start_code = prng.randn(c.shape[0], *shape)
    start_code = torch.from_numpy(start_code).to(device=device, dtype=torch.float32)

    samples_ddim, _ = sampler.sample(S=ddim_steps,
                                     conditioning=c,
                                     batch_size=c.shape[0],
                                     shape=shape,
                                     verbose=False,
                                     x_T=start_code)
    x_samples_ddim = model.decode_first_stage(samples_ddim)

    # Map everything from [-1, 1] back to [0, 1].
    mel = torch.clamp((batch["mel"] + 1.0) / 2.0, min=0.0, max=1.0)
    mask = torch.clamp((batch["mask"] + 1.0) / 2.0, min=0.0, max=1.0)
    predicted_mel = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)

    # Keep the original mel outside the mask and take the prediction inside it.
    inpainted = (1 - mask) * mel + mask * predicted_mel
    inpainted = inpainted.cpu().numpy().squeeze()
    inpainted_wav = vocoder.vocode(inpainted)
    return inpainted, inpainted_wav
def predict(input_audio, mel_and_mask, ddim_steps, seed):
    """Gradio callback: inpaint the painted region of the mel and return image and audio.

    Args:
        input_audio: ``(sr, wav)`` tuple from the Gradio Audio input.
        mel_and_mask: dict from the sketch-enabled Image component. "image" is
            the displayed (cropped) mel and "mask" holds the user's strokes,
            both as PIL images.
        ddim_steps: Number of DDIM sampling steps.
        seed: Seed for the initial latent noise.

    Returns:
        ``(PIL.Image, (SAMPLE_RATE, int16 waveform))``:
        - the inpainted mel cropped to the width that was displayed;
        - the vocoded audio trimmed to the input's duration.

    Raises:
        ValueError: If no audio was given or no mel/mask is available yet.
    """
    if input_audio is None:
        raise ValueError("No input audio: please upload or record audio first.")
    if mel_and_mask is None or mel_and_mask.get("mask") is None:
        raise ValueError('Click "Show Mel" and paint over the region to inpaint first.')

    mel_bins, mel_len = 80, 848
    # Width (in frames) of the cropped mel the user actually saw and painted on.
    shown_width = mel_and_mask["image"].size[0]
    mask = np.array(mel_and_mask["mask"].convert("L")) / 255
    # The displayed mel is only a crop, so regenerate the full mel from the audio.
    input_mel = gen_mel(input_audio)[:, :mel_len]
    # Zero-pad the mask (0 = keep original) out to the full mel width; crop
    # first so the pad width can never be negative.
    mask = mask[:, :mel_len]
    mask = np.pad(mask, ((0, 0), (0, mel_len - mask.shape[1])), mode='constant', constant_values=0)
    print(mask.shape, input_mel.shape)

    with torch.no_grad():
        batch = make_batch_sd(input_mel, mask, device, num_samples=1)
        inpainted, gen_wav = inpaint(
            sampler=sampler,
            batch=batch,
            seed=seed,
            ddim_steps=ddim_steps,
            num_samples=1,
            H=mel_bins, W=mel_len
        )

    inpainted = inpainted[:, :shown_width]
    color_mel = cmap_transform(inpainted)
    # Length of the original input expressed in samples at SAMPLE_RATE.
    input_len = int(input_audio[1].shape[0] * SAMPLE_RATE / input_audio[0])
    # Scale to int16 PCM, clipping first so a +1.0 sample (-> 32768) cannot wrap to -32768.
    gen_wav = np.clip(gen_wav * 32768, -32768, 32767).astype(np.int16)[:input_len]
    return Image.fromarray((color_mel * 255).astype(np.uint8)), (SAMPLE_RATE, gen_wav)
# Load the diffusion model (wrapped in a DDIM sampler) and the BigVGAN vocoder
# once at start-up. `predict` reads `sampler` and `inpaint` reads `vocoder` as globals.
INPAINT_CONFIG_PATH = './configs/inpaint/txt2audio_args.yaml'
INPAINT_CKPT_PATH = './useful_ckpts/inpaint7_epoch00047.ckpt'
VOCODER_DIR = './vocoder/logs/bigv16k53w'

sampler = initialize_model(INPAINT_CONFIG_PATH, INPAINT_CKPT_PATH)
vocoder = VocoderBigVGAN(VOCODER_DIR, device=device)
# Build the Gradio UI. `.queue()` enables Gradio's request queue for the slow inference calls.
# NOTE(review): `gr.inputs.*`, `.style()` and `tool='sketch'` are Gradio 3.x APIs;
# confirm the pinned Gradio version before upgrading.
block = gr.Blocks().queue()
with block:
    with gr.Row():
        gr.Markdown("## Make-An-Audio Inpainting")
    with gr.Row():
        # Left column: audio input, action buttons and sampling options.
        with gr.Column():
            input_audio = gr.inputs.Audio()
            show_button = gr.Button("Show Mel")
            run_button = gr.Button("Predict Masked Place")
            with gr.Accordion("Advanced options", open=False):
                # Number of DDIM sampling steps forwarded to `predict`.
                ddim_steps = gr.Slider(label="Steps", minimum=1,
                                       maximum=150, value=100, step=1)
                # Seed for the initial noise; randomize=True picks a random initial value.
                seed = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=2147483647,
                    step=1,
                    randomize=True,
                )
        # Right column: inpainted mel image and the generated audio.
        with gr.Column():
            show_inpainted = gr.Image(type="pil").style(width=848,height=80)
            outaudio = gr.Audio()
    # Sketchable mel image. `predict` reads its "image" and "mask" entries.
    show_mel = gr.Image(type="pil",tool='sketch')#.style(width=848,height=80) # with this style applied the image cannot be shown in full
    # "Show Mel": render the (cropped) mel of the uploaded audio for the user to paint on.
    show_button.click(fn=show_mel_fn, inputs=[input_audio], outputs=show_mel)
    # "Predict Masked Place": inpaint the painted region and return image + audio.
    run_button.click(fn=predict, inputs=[input_audio,show_mel,ddim_steps,seed], outputs=[show_inpainted,outaudio])
block.launch()