import torch
import numpy as np
import gradio as gr
from PIL import Image
import matplotlib
from omegaconf import OmegaConf
from einops import repeat
import librosa
from ldm.models.diffusion.ddim import DDIMSampler
from vocoder.bigvgan.models import VocoderBigVGAN
from ldm.util import instantiate_from_config
from ldm.data.extract_mel_spectrogram import TRANSFORMS_16000

SAMPLE_RATE = 16000

cmap_transform = matplotlib.cm.viridis

torch.set_grad_enabled(False)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def initialize_model(config, ckpt):
    config = OmegaConf.load(config)
    model = instantiate_from_config(config.model)
    model.load_state_dict(torch.load(ckpt, map_location='cpu')["state_dict"], strict=False)

    model = model.to(device)
    print(model.device, device, model.cond_stage_model.device)
    sampler = DDIMSampler(model)

    return sampler


def make_batch_sd(
        mel,
        mask,
        device,
        num_samples=1):

    mel = torch.from_numpy(mel)[None, None, ...].to(dtype=torch.float32)
    mask = torch.from_numpy(mask)[None, None, ...].to(dtype=torch.float32)
    # Zero out the region to be inpainted (where mask == 1).
    masked_mel = (1 - mask) * mel

    # Rescale everything from [0, 1] to the model's [-1, 1] range.
    mel = mel * 2 - 1
    mask = mask * 2 - 1
    masked_mel = masked_mel * 2 - 1

    batch = {
        "mel": repeat(mel.to(device=device), "1 ... -> n ...", n=num_samples),
        "mask": repeat(mask.to(device=device), "1 ... -> n ...", n=num_samples),
        "masked_mel": repeat(masked_mel.to(device=device), "1 ... -> n ...", n=num_samples),
    }
    return batch


def gen_mel(input_audio):
    sr, ori_wav = input_audio
    print(sr, ori_wav.shape, ori_wav)

    # Gradio returns int16 PCM; convert it to float32 in [-1, 1).
    # order='C' only requests C-contiguous storage and can be ignored.
    ori_wav = ori_wav.astype(np.float32, order='C') / 32768.0
    if len(ori_wav.shape) == 2:  # stereo
        # Gradio may load the wav as (wav_len, 2), but librosa expects (2, wav_len).
        ori_wav = librosa.to_mono(ori_wav.T)
    print(sr, ori_wav.shape, ori_wav)
    ori_wav = librosa.resample(ori_wav, orig_sr=sr, target_sr=SAMPLE_RATE)

    # The model takes a fixed-length mel of mel_len frames, i.e. mel_len * hop_size samples.
    mel_len, hop_size = 848, 256
    input_len = mel_len * hop_size
    if len(ori_wav) < input_len:
        # Zero-pad only the missing tail so both branches yield input_len samples.
        input_wav = np.pad(ori_wav, (0, input_len - len(ori_wav)), constant_values=0)
    else:
        input_wav = ori_wav[:input_len]

    mel = TRANSFORMS_16000(input_wav)
    return mel


def show_mel_fn(input_audio):
    # Only the first crop_len frames are shown: the full mel cannot be displayed
    # because of a Gradio Image bug when using tool='sketch'.
    crop_len = 500
    crop_mel = gen_mel(input_audio)[:, :crop_len]
    color_mel = cmap_transform(crop_mel)
    return Image.fromarray((color_mel * 255).astype(np.uint8))


def inpaint(sampler, batch, seed, ddim_steps, num_samples=1, W=512, H=512):
    model = sampler.model

    # Note: start_code is computed but never passed to sampler.sample(),
    # so `seed` does not currently determine the initial noise.
    prng = np.random.RandomState(seed)
    start_code = prng.randn(num_samples, model.first_stage_model.embed_dim, H // 8, W // 8)
    start_code = torch.from_numpy(start_code).to(device=device, dtype=torch.float32)

    # Encode the masked mel to the latent space, then append the mask (resized to the
    # latent resolution) as an extra conditioning channel.
    c = model.get_first_stage_encoding(model.encode_first_stage(batch["masked_mel"]))
    cc = torch.nn.functional.interpolate(batch["mask"], size=c.shape[-2:])
    c = torch.cat((c, cc), dim=1)  # (b, c+1, h, w); the extra channel is the mask

    shape = (c.shape[1] - 1,) + c.shape[2:]
    samples_ddim, _ = sampler.sample(S=ddim_steps,
                                     conditioning=c,
                                     batch_size=c.shape[0],
                                     shape=shape,
                                     verbose=False)
    x_samples_ddim = model.decode_first_stage(samples_ddim)

    # batch["mel"], batch["mask"] and the decoder output are in [-1, 1]; map them back to [0, 1].
    mel = torch.clamp((batch["mel"] + 1.0) / 2.0, min=0.0, max=1.0)
    mask = torch.clamp((batch["mask"] + 1.0) / 2.0, min=0.0, max=1.0)
    predicted_mel = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
    # Keep the original mel outside the mask and use the prediction inside it.
    inpainted = (1 - mask) * mel + mask * predicted_mel
    inpainted = inpainted.cpu().numpy().squeeze()
    inpaint_wav = vocoder.vocode(inpainted)

    return inpainted, inpaint_wav


def predict(input_audio, mel_and_mask, ddim_steps, seed):
    show_mel = np.array(mel_and_mask['image'].convert("L")) / 255
    mask = np.array(mel_and_mask["mask"].convert("L")) / 255

    mel_bins, mel_len = 80, 848

    # Only part of the mel was displayed, so regenerate the full mel from the audio.
    input_mel = gen_mel(input_audio)[:, :mel_len]
    # Zero-pad the mask on the right up to the full mel length.
    mask = np.pad(mask, ((0, 0), (0, mel_len - mask.shape[1])), mode='constant', constant_values=0)
    print(mask.shape, input_mel.shape)
    with torch.no_grad():
        batch = make_batch_sd(input_mel, mask, device, num_samples=1)
        inpainted, gen_wav = inpaint(
            sampler=sampler,
            batch=batch,
            seed=seed,
            ddim_steps=ddim_steps,
            num_samples=1,
            H=mel_bins, W=mel_len
        )
    inpainted = inpainted[:, :show_mel.shape[1]]
    color_mel = cmap_transform(inpainted)
    # Trim the output to the duration of the input clip (in 16 kHz samples).
    input_len = int(input_audio[1].shape[0] * SAMPLE_RATE / input_audio[0])
    # Clip before casting so that full-scale samples do not wrap around in int16.
    gen_wav = np.clip(gen_wav * 32768, -32768, 32767).astype(np.int16)[:input_len]
    return Image.fromarray((color_mel * 255).astype(np.uint8)), (SAMPLE_RATE, gen_wav)


sampler = initialize_model('./configs/inpaint/txt2audio_args.yaml', './useful_ckpts/inpaint7_epoch00047.ckpt')
vocoder = VocoderBigVGAN('./vocoder/logs/bigv16k53w', device=device)

block = gr.Blocks().queue()
with block:
    with gr.Row():
        gr.Markdown("## Make-An-Audio Inpainting")

    with gr.Row():
        with gr.Column():
            input_audio = gr.inputs.Audio()
            show_button = gr.Button("Show Mel")
            run_button = gr.Button("Predict Masked Place")
            with gr.Accordion("Advanced options", open=False):
                ddim_steps = gr.Slider(label="Steps", minimum=1,
                                       maximum=150, value=100, step=1)
                seed = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=2147483647,
                    step=1,
                    randomize=True,
                )
        with gr.Column():
            show_inpainted = gr.Image(type="pil").style(width=848, height=80)
            outaudio = gr.Audio()
    # Adding .style(width=848, height=80) here prevents the full image from being displayed.
    show_mel = gr.Image(type="pil", tool='sketch')
    show_button.click(fn=show_mel_fn, inputs=[input_audio], outputs=show_mel)
    run_button.click(fn=predict,
                     inputs=[input_audio, show_mel, ddim_steps, seed],
                     outputs=[show_inpainted, outaudio])

block.launch()
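

# Illustrative sketch (not part of the original demo and never called above).
# gen_mel() feeds the model a fixed-length window of mel_len frames with a hop of
# hop_size samples, i.e. 848 * 256 = 217088 samples (about 13.57 s at 16 kHz).
# The hypothetical helper `fit_to_length` isolates that pad-or-crop step, assuming
# the same 848-frame / 256-sample-hop settings: shorter clips are zero-padded by
# only the missing tail, and longer clips are truncated.
import numpy as np


def fit_to_length(wav: np.ndarray, mel_len: int = 848, hop_size: int = 256) -> np.ndarray:
    """Zero-pad or crop a 1-D waveform to exactly mel_len * hop_size samples."""
    input_len = mel_len * hop_size
    if len(wav) < input_len:
        # Append just enough trailing zeros to reach input_len.
        return np.pad(wav, (0, input_len - len(wav)), constant_values=0)
    # Keep only the first input_len samples.
    return wav[:input_len]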