# rvc-tts-yutou / app.py
# Hev832's picture
# Update app.py
# 7986b8e verified
# raw
# history blame
# 15.7 kB
import asyncio
import datetime
import logging
import os
import time
import traceback
import shutil
import urllib.request
import zipfile
import gdown
from argparse import ArgumentParser
# Fetch the pretrained checkpoints this script loads later (hubert_base.pt,
# rmvpe.pt) plus the bundled "yoimiya" voice model and its index, using aria2c.
# Each entry is (source URL, destination directory, destination file name).
_ARIA2C_PREFIX = "aria2c --console-log-level=error -c -x 16 -s 16 -k 1M"
_PRETRAINED_ASSETS = (
    (
        "https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/hubert_base.pt",
        ".",
        "hubert_base.pt",
    ),
    (
        "https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/rmvpe.pt",
        ".",
        "rmvpe.pt",
    ),
    (
        "https://huggingface.co/sail-rvc/yoimiya-jp/resolve/main/model.pth",
        "./weights/yoimiya",
        "yoimiya.pth",
    ),
    (
        "https://huggingface.co/sail-rvc/yoimiya-jp/resolve/main/model.index",
        "./weights/yoimiya",
        "yoimiya.index",
    ),
)
for _asset_url, _asset_dir, _asset_name in _PRETRAINED_ASSETS:
    os.system(f"{_ARIA2C_PREFIX} {_asset_url} -d {_asset_dir} -o {_asset_name}")
# Directory that contains this script. The previous double ``dirname`` resolved
# to the *parent* of the script directory, so models fetched through the
# "Download model" tab were written to a different ``weights`` folder than the
# one the startup downloads fill (``./weights``) and ``model_root`` scans.
# NOTE(review): this assumes the app is started with the script directory as
# the working directory (``model_root`` is a relative path) - confirm.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# Destination root for voice models downloaded at runtime.
rvc_models_dir = os.path.join(BASE_DIR, 'weights')
import edge_tts
import gradio as gr
import librosa
import torch
from fairseq import checkpoint_utils
from config import Config
from lib.infer_pack.models import (
SynthesizerTrnMs256NSFsid,
SynthesizerTrnMs256NSFsid_nono,
SynthesizerTrnMs768NSFsid,
SynthesizerTrnMs768NSFsid_nono,
)
from rmvpe import RMVPE
from vc_infer_pipeline import VC
# Only show warnings and above from these chatty third-party loggers.
for _noisy_logger in ("fairseq", "numba", "markdown_it", "urllib3", "matplotlib"):
    logging.getLogger(_noisy_logger).setLevel(logging.WARNING)

# True when the SYSTEM environment variable equals "spaces"; tts() then
# enforces its text-length and audio-duration limits.
limitation = os.getenv("SYSTEM") == "spaces"
config = Config()
# File that tts() asks edge-tts to write the synthesized speech to.
edge_output_filename = "edge_output.mp3"

# asyncio.run() creates and closes its own event loop. Calling
# asyncio.get_event_loop() when no loop is running is deprecated, and
# asyncio.run() matches how tts() already drives edge-tts.
tts_voice_list = asyncio.run(edge_tts.list_voices())
tts_voices = [f"{v['ShortName']}-{v['Gender']}" for v in tts_voice_list]

# Every sub-directory of `weights` is offered as a selectable voice model.
model_root = "weights"
models = sorted(
    d for d in os.listdir(model_root) if os.path.isdir(os.path.join(model_root, d))
)
if not models:
    raise ValueError("No model found in `weights` folder")
def model_data(model_name):
    """Load the voice model stored in ``<model_root>/<model_name>``.

    The first ``.pth`` file listed in that directory is loaded as the
    checkpoint, and the first ``.index`` file (if any) is reported as the
    index file.

    Returns:
        A tuple ``(tgt_sr, net_g, vc, version, index_file, if_f0)`` where
        ``tgt_sr`` is the last entry of the checkpoint config, ``net_g`` is the
        synthesizer with weights loaded, ``vc`` is a ``VC`` built from
        ``tgt_sr`` and the global config, ``version`` is ``"v1"`` or ``"v2"``,
        ``index_file`` is a path or ``""`` and ``if_f0`` is the checkpoint's
        ``f0`` flag (default 1).

    Raises:
        ValueError: if the directory has no ``.pth`` file or the checkpoint
            reports a version other than ``"v1"``/``"v2"``.
    """
    model_dir = os.path.join(model_root, model_name)

    checkpoint_paths = []
    for entry in os.listdir(model_dir):
        if entry.endswith(".pth"):
            checkpoint_paths.append(os.path.join(model_root, model_name, entry))
    if not checkpoint_paths:
        raise ValueError(f"No pth file found in {model_root}/{model_name}")

    pth_path = checkpoint_paths[0]
    print(f"Loading {pth_path}")
    # NOTE(review): torch.load unpickles the file, and checkpoints may come
    # from user-supplied download links - consider restricting what is loaded.
    cpt = torch.load(pth_path, map_location="cpu")
    tgt_sr = cpt["config"][-1]
    # Overwrite the third-from-last config entry with the number of rows of
    # the speaker embedding table (n_spk).
    cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0]
    if_f0 = cpt.get("f0", 1)
    version = cpt.get("version", "v1")

    uses_f0 = if_f0 == 1
    if version == "v1":
        synth_cls = (
            SynthesizerTrnMs256NSFsid if uses_f0 else SynthesizerTrnMs256NSFsid_nono
        )
    elif version == "v2":
        synth_cls = (
            SynthesizerTrnMs768NSFsid if uses_f0 else SynthesizerTrnMs768NSFsid_nono
        )
    else:
        raise ValueError("Unknown version")

    # Only the f0 variants are constructed with the half-precision flag.
    if uses_f0:
        net_g = synth_cls(*cpt["config"], is_half=config.is_half)
    else:
        net_g = synth_cls(*cpt["config"])

    # enc_q is dropped before loading; strict=False lets the state dict load
    # without it.
    del net_g.enc_q
    net_g.load_state_dict(cpt["weight"], strict=False)
    print("Model loaded")
    net_g.eval().to(config.device)
    net_g = net_g.half() if config.is_half else net_g.float()
    vc = VC(tgt_sr, config)

    index_candidates = []
    for entry in os.listdir(model_dir):
        if entry.endswith(".index"):
            index_candidates.append(os.path.join(model_root, model_name, entry))
    if index_candidates:
        index_file = index_candidates[0]
        print(f"Index file found: {index_file}")
    else:
        print("No index file found")
        index_file = ""

    return tgt_sr, net_g, vc, version, index_file, if_f0
def load_hubert():
    """Load ``hubert_base.pt`` and store it in the global ``hubert_model``.

    The first model of the loaded ensemble is moved to ``config.device`` and
    cast to half or full precision according to ``config.is_half``.

    Returns:
        The model after ``eval()`` has been called on it.
    """
    global hubert_model
    ensemble, _saved_cfg, _task = checkpoint_utils.load_model_ensemble_and_task(
        ["hubert_base.pt"],
        suffix="",
    )
    hubert_model = ensemble[0].to(config.device)
    hubert_model = hubert_model.half() if config.is_half else hubert_model.float()
    return hubert_model.eval()
# Load the models shared by every request once, at import time.
print("Loading hubert model...")
# load_hubert() also assigns the global itself; the return value is the same
# model after eval().
hubert_model = load_hubert()
print("Hubert model loaded.")
print("Loading rmvpe model...")
# tts() attaches this instance to the VC object when f0_method == "rmvpe".
rmvpe_model = RMVPE("rmvpe.pt", config.is_half, config.device)
print("rmvpe model loaded.")
def tts(
    model_name,
    speed,
    tts_text,
    tts_voice,
    f0_up_key,
    f0_method,
    index_rate,
    protect,
    filter_radius=3,
    resample_sr=0,
    rms_mix_rate=0.25,
):
    """Synthesize ``tts_text`` with edge-tts, then convert it with an RVC model.

    Args:
        model_name: Sub-directory of the weights folder holding the voice model.
        speed: Speech-rate change in percent (negative slows down).
        tts_text: Text to speak.
        tts_voice: Entry of ``tts_voices`` (``<ShortName>-<Gender>``); the
            trailing ``-<Gender>`` part is stripped before calling edge-tts.
        f0_up_key: Transpose value; converted with ``int()``.
        f0_method: Pitch extraction method; ``"rmvpe"`` uses the preloaded
            rmvpe model.
        index_rate, protect, filter_radius, resample_sr, rms_mix_rate:
            Forwarded unchanged to ``vc.pipeline``.

    Returns:
        A 3-tuple ``(info, edge_audio_path, converted)``. On success ``info``
        is a timing summary, ``edge_audio_path`` is the edge-tts mp3 and
        ``converted`` is ``(sample_rate, audio)``. On failure ``info`` holds
        the error text and the remaining items are ``None`` (the edge-tts file
        is still returned when only the duration limit was exceeded).
    """
    print("------------------")
    print(datetime.datetime.now())
    print("tts_text:")
    print(tts_text)
    print(f"tts_voice: {tts_voice}")
    print(f"Model name: {model_name}")
    print(f"F0: {f0_method}, Key: {f0_up_key}, Index: {index_rate}, Protect: {protect}")
    try:
        if limitation and len(tts_text) > 280:
            print("Error: Text too long")
            return (
                f"Text characters should be at most 280 in this huggingface space, but got {len(tts_text)} characters.",
                None,
                None,
            )
        tgt_sr, net_g, vc, version, index_file, if_f0 = model_data(model_name)
        t0 = time.time()
        # edge-tts expects a signed integer percentage such as "+10%" or
        # "-20%"; a float coming from the slider would render as "+10.0%".
        speed_str = f"{int(speed):+d}%"
        asyncio.run(
            edge_tts.Communicate(
                tts_text, "-".join(tts_voice.split("-")[:-1]), rate=speed_str
            ).save(edge_output_filename)
        )
        t1 = time.time()
        edge_time = t1 - t0
        audio, sr = librosa.load(edge_output_filename, sr=16000, mono=True)
        duration = len(audio) / sr
        print(f"Audio duration: {duration}s")
        if limitation and duration >= 20:
            print("Error: Audio too long")
            return (
                f"Audio should be less than 20 seconds in this huggingface space, but got {duration}s.",
                edge_output_filename,
                None,
            )
        f0_up_key = int(f0_up_key)
        if not hubert_model:
            load_hubert()
        if f0_method == "rmvpe":
            vc.model_rmvpe = rmvpe_model
        # Filled in by vc.pipeline; reported below as npy / f0 / infer times.
        times = [0, 0, 0]
        audio_opt = vc.pipeline(
            hubert_model,
            net_g,
            0,
            audio,
            edge_output_filename,
            times,
            f0_up_key,
            f0_method,
            index_file,
            # file_big_npy,
            index_rate,
            if_f0,
            filter_radius,
            tgt_sr,
            resample_sr,
            rms_mix_rate,
            version,
            protect,
            None,
        )
        # Report the resample rate as the output rate when resampling to a
        # different rate of at least 16 kHz was requested.
        if tgt_sr != resample_sr >= 16000:
            tgt_sr = resample_sr
        info = f"Success. Time: edge-tts: {edge_time}s, npy: {times[0]}s, f0: {times[1]}s, infer: {times[2]}s"
        print(info)
        return (
            info,
            edge_output_filename,
            (tgt_sr, audio_opt),
        )
    except EOFError:
        info = (
            "It seems that the edge-tts output is not valid. "
            "This may occur when the input text and the speaker do not match. "
            "For example, maybe you entered Japanese (without alphabets) text but chose non-Japanese speaker?"
        )
        print(info)
        return info, None, None
    except Exception:
        # UI boundary: show the traceback to the user instead of crashing the
        # handler, but let KeyboardInterrupt/SystemExit propagate.
        info = traceback.format_exc()
        print(info)
        return info, None, None
def extract_zip(
    extraction_folder,
    zip_name,
    min_index_size=1024 * 100,
    min_model_size=1024 * 1024 * 40,
):
    """Unpack a voice-model archive and flatten it into ``extraction_folder``.

    The archive is extracted, the zip file is deleted, and the tree is searched
    for a ``.pth`` model and an optional ``.index`` file. Both are moved to the
    top of ``extraction_folder`` and every sub-directory is removed afterwards.

    Args:
        extraction_folder: Directory to create and extract into. It must not
            exist yet (``os.makedirs`` raises otherwise).
        zip_name: Path of the archive; removed after extraction.
        min_index_size: A ``.index`` file is only accepted when it is larger
            than this many bytes.
        min_model_size: A ``.pth`` file is only accepted when it is larger than
            this many bytes.

    Raises:
        gr.Error: if no acceptable ``.pth`` file is found.
    """
    os.makedirs(extraction_folder)
    with zipfile.ZipFile(zip_name, 'r') as zip_ref:
        zip_ref.extractall(extraction_folder)
    os.remove(zip_name)

    index_filepath, model_filepath = None, None
    for root, _dirs, files in os.walk(extraction_folder):
        for name in files:
            candidate = os.path.join(root, name)
            # When several files qualify, the last one visited wins.
            if name.endswith('.index') and os.path.getsize(candidate) > min_index_size:
                index_filepath = candidate
            if name.endswith('.pth') and os.path.getsize(candidate) > min_model_size:
                model_filepath = candidate

    if not model_filepath:
        raise gr.Error(f'No .pth model file was found in the extracted zip. Please check {extraction_folder}.')

    # move model and index file to extraction folder
    os.rename(model_filepath, os.path.join(extraction_folder, os.path.basename(model_filepath)))
    if index_filepath:
        os.rename(index_filepath, os.path.join(extraction_folder, os.path.basename(index_filepath)))

    # remove any unnecessary nested folders
    for entry in os.listdir(extraction_folder):
        entry_path = os.path.join(extraction_folder, entry)
        if os.path.isdir(entry_path):
            shutil.rmtree(entry_path)
def download_online_model(url, dir_name, progress=gr.Progress()):
    """Download a zipped voice model and unpack it under ``rvc_models_dir``.

    Supported sources are HuggingFace (direct file URL), Pixeldrain (the last
    path segment of the link is used as the file id) and Google Drive (the
    second-to-last path segment is used as the file id). The archive is saved
    as ``<dir_name>.zip`` in the working directory and handed to
    ``extract_zip``.

    Args:
        url: Link to the zip archive.
        dir_name: Name of the model directory to create.
        progress: Gradio progress tracker.

    Returns:
        A success message naming the downloaded model.

    Raises:
        gr.Error: on any failure (blank input, existing model directory,
            unsupported host, download or extraction error).
    """
    try:
        if not url or not dir_name:
            raise gr.Error('Please provide both a download link and a model name.')

        progress(0, desc=f'[~] Downloading voice model with name {dir_name}...')
        # One local archive name for every source; it also avoids putting a
        # "?download=true" query string into the file name.
        zip_name = f'{dir_name}.zip'
        extraction_folder = os.path.join(rvc_models_dir, dir_name)
        if os.path.exists(extraction_folder):
            raise gr.Error(f'Voice model directory {dir_name} already exists! Choose a different name for your voice model.')

        if 'huggingface.co' in url:
            urllib.request.urlretrieve(url, zip_name)
        elif 'pixeldrain.com' in url:
            # Take the file id from the link itself. Previously the local zip
            # name was substituted into the API URL, so the wrong resource was
            # requested.
            file_id = url.rstrip('/').split('/')[-1]
            urllib.request.urlretrieve(f'https://pixeldrain.com/api/file/{file_id}', zip_name)
        elif 'drive.google.com' in url:
            # Extract the Google Drive file ID
            file_id = url.split('/')[-2]
            gdown.download(id=file_id, output=zip_name, quiet=False)
        else:
            # Fail before extract_zip creates the model directory for an
            # archive that was never downloaded.
            raise gr.Error('Unsupported download link. Use a HuggingFace, Pixeldrain or Google Drive URL.')

        progress(0.5, desc='[~] Extracting zip...')
        extract_zip(extraction_folder, zip_name)
        return f'[+] {dir_name} Model successfully downloaded!'
    except Exception as e:
        raise gr.Error(str(e)) from e
# Markdown shown at the top of the web UI (rendered by gr.Markdown below).
initial_md = """
# RVC text-to-speech webui
[![open in clab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Blane187/rvc-tts/blob/main/rvc_tts.ipynb)
This is a text-to-speech webui of RVC models.
Input text ➡[(edge-tts)](https://github.com/rany2/edge-tts)➡ Speech mp3 file ➡[(RVC)](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI)➡ Final output
"""
# Gradio UI: one tab for text-to-speech inference, one for downloading models.
# NOTE(review): the source listing carried no indentation, so the nesting of
# rows/columns below is reconstructed from the statement order - confirm the
# layout against the running app.
app = gr.Blocks(theme="Hev832/emerald", title="RVC-TTS")
with app:
    gr.Markdown(initial_md)
    with gr.Tab('inference tts'):
        with gr.Row():
            with gr.Column():
                model_name = gr.Dropdown(label="Model", choices=models, value=models[0])
                f0_key_up = gr.Number(
                    label="Transpose (the best value depends on the models and speakers)",
                    value=0,
                )
            with gr.Column():
                f0_method = gr.Radio(
                    label="Pitch extraction method (Rmvpe is default)",
                    choices=["rmvpe", "crepe"],  # harvest is too slow
                    value="rmvpe",
                    interactive=True,
                )
                index_rate = gr.Slider(
                    minimum=0,
                    maximum=1,
                    label="Index rate",
                    value=1,
                    interactive=True,
                )
                protect0 = gr.Slider(
                    minimum=0,
                    maximum=0.5,
                    label="Protect",
                    value=0.33,
                    step=0.01,
                    interactive=True,
                )
        with gr.Row():
            with gr.Column():
                tts_voice = gr.Dropdown(
                    label="Edge-tts speaker (format: language-Country-Name-Gender)",
                    choices=tts_voices,
                    allow_custom_value=False,
                    value="ja-JP-NanamiNeural-Female",
                )
                speed = gr.Slider(
                    minimum=-100,
                    maximum=100,
                    label="Speech speed (%)",
                    value=0,
                    step=10,
                    interactive=True,
                )
                tts_text = gr.Textbox(label="Input Text", value="これは日本語テキストから音声への変換デモです。")
            with gr.Column():
                but0 = gr.Button("Convert", variant="primary")
                info_text = gr.Textbox(label="Output info")
            with gr.Column():
                edge_tts_output = gr.Audio(label="Edge Voice", type="filepath")
                tts_output = gr.Audio(label="Result")
            # Inputs are passed positionally to tts(); the remaining tts()
            # parameters keep their defaults.
            but0.click(
                tts,
                [
                    model_name,
                    speed,
                    tts_text,
                    tts_voice,
                    f0_key_up,
                    f0_method,
                    index_rate,
                    protect0,
                ],
                [info_text, edge_tts_output, tts_output],
            )
        with gr.Row():
            examples = gr.Examples(
                examples_per_page=100,
                examples=[
                    ["これは日本語テキストから音声への変換デモです。", "ja-JP-NanamiNeural-Female"],
                    [
                        "This is an English text to speech conversation demo.",
                        "en-US-AriaNeural-Female",
                    ],
                ],
                inputs=[tts_text, tts_voice],
            )
    with gr.Tab('Download model'):
        with gr.Accordion('From HuggingFace/Pixeldrain URL', open=True):
            with gr.Row():
                model_zip_link = gr.Text(label='Download link to model', info='Should be a zip file containing a .pth model file and an optional .index file.')
                # Separate name so the inference tab's `model_name` dropdown
                # variable is not rebound.
                dl_model_name = gr.Text(label='Name your model', info='Give your new model a unique name from your other voice models.')
            with gr.Row():
                download_btn = gr.Button('Download', variant='primary', scale=19)
                dl_output_message = gr.Text(label='Output Message', interactive=False, scale=20)
            download_btn.click(download_online_model, inputs=[model_zip_link, dl_model_name], outputs=dl_output_message)
            gr.Markdown('## Input Examples',)
            gr.Examples(
                [
                    ['https://huggingface.co/phant0m4r/LiSA/resolve/main/LiSA.zip', 'Lisa'],
                    ['https://huggingface.co/Hev832/rvc/resolve/main/Sonic.zip?download=true', 'Sonic'],
                    ['https://huggingface.co/jkhgf/SLWooly/resolve/main/Jax.zip', 'Jax']
                ],
                [model_zip_link, dl_model_name],
                [],
                download_online_model,
            )
        with gr.Accordion('From Public Index', open=False):
            gr.Markdown('## How to use')
            gr.Markdown('- Click Initialize public models table')
            gr.Markdown('- Filter models using tags or search bar')
            gr.Markdown('- Select a row to autofill the download link and model name')
            gr.Markdown('- Click Download')
            with gr.Row():
                pub_zip_link = gr.Text(label='Download link to model')
                pub_model_name = gr.Text(label='Model name')
            with gr.Row():
                download_pub_btn = gr.Button('Download', variant='primary', scale=19)
                pub_dl_output_message = gr.Text(label='Output Message', interactive=False, scale=20)
            # This button previously had no handler; it now uses the same
            # download routine as the URL accordion.
            download_pub_btn.click(download_online_model, inputs=[pub_zip_link, pub_model_name], outputs=pub_dl_output_message)

app.launch()