degm-stts2 / app.py
mrfakename's picture
Update app.py
6c019c7
raw
history blame
3.07 kB
# StyleTTS 2 HTTP Streaming API by @fakerybakery - Copyright (c) 2023 mrfakename. All rights reserved.
# Docs: API_DOCS.md
# To-Do:
# * Support voice cloning
# * Implement authentication, user "credits" system w/ SQLite3
import io
import markdown
from tortoise.utils.text import split_and_recombine_text
from flask import Flask, Response, request, jsonify
import numpy as np
import ljinference
import torch
import hashlib
from scipy.io.wavfile import read, write
from flask_cors import CORS
import os
import torchaudio
def genHeader(sampleRate, bitsPerSample, channels):
    """Build a 44-byte RIFF/WAVE header for PCM audio.

    The RIFF and data chunk sizes are not derived from real audio: they are
    filled from a fixed placeholder of 2000 * 10**6 bytes (presumably so the
    header can be emitted before the total length is known — confirm with
    callers).

    Args:
        sampleRate: Samples per second.
        bitsPerSample: Bits per sample.
        channels: Number of audio channels.

    Returns:
        The header as ``bytes``, with every integer field little-endian.
    """
    placeholder_size = 2000 * 10**6

    def le(value, width):
        # Encode one integer header field as `width` little-endian bytes.
        return value.to_bytes(width, "little")

    fields = (
        bytes("RIFF", "ascii"),
        le(placeholder_size + 36, 4),  # RIFF chunk size: data size + 36
        bytes("WAVE", "ascii"),
        bytes("fmt ", "ascii"),
        le(16, 4),  # size of the fmt chunk body
        le(1, 2),  # audio format tag 1 (PCM)
        le(channels, 2),
        le(sampleRate, 4),
        le(sampleRate * channels * bitsPerSample // 8, 4),  # bytes per second
        le(channels * bitsPerSample // 8, 2),  # bytes per sample frame
        le(bitsPerSample, 2),
        bytes("data", "ascii"),
        le(placeholder_size, 4),
    )
    return b"".join(fields)
import phonemizer
# Module-level eSpeak phonemizer backend configured for US English, keeping
# punctuation and stress marks in its output.
# NOTE(review): nothing in the visible part of this file references it —
# confirm whether another module relies on this name before removing it.
global_phonemizer = phonemizer.backend.EspeakBackend(language='en-us', preserve_punctuation=True, with_stress=True)
print("Starting Flask app")
# Flask application object; the route handlers in this file register on it.
app = Flask(__name__)
# Apply flask_cors to the app; no options are customized here.
cors = CORS(app)
@app.route("/")
def index():
    """Serve the API documentation page.

    Reads ``API_DOCS.md`` (path relative to the process's working directory)
    and converts it to HTML with ``markdown.markdown``.

    Returns:
        The rendered HTML string.

    Raises:
        OSError: If ``API_DOCS.md`` cannot be opened.
    """
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's default locale encoding.
    with open('API_DOCS.md', 'r', encoding='utf-8') as f:
        return markdown.markdown(f.read())
# Directory (relative to the working directory) where synthesized text chunks
# are stored as WAV files so repeated chunks can be served without inference.
cache_dir = 'cache'
# exist_ok=True already makes this a no-op when the directory exists, so a
# separate os.path.exists() check (and its check-then-create race) is not needed.
os.makedirs(cache_dir, exist_ok=True)
@app.route("/api", methods=['GET', 'POST'])
def serve_wav():
    """Synthesize the requested text and return it as one WAV response.

    The ``text`` field is taken from the query string for GET requests or the
    form body for POST requests; if it is absent there, a JSON body is
    consulted. The text is split with ``split_and_recombine_text`` and each
    chunk is either loaded from ``cache_dir`` (keyed by the SHA-256 of the
    lower-cased chunk) or synthesized with ``ljinference.inference`` and then
    written to the cache. All chunks are concatenated and written as a
    24000 Hz WAV.

    Returns:
        A ``Response`` with ``Content-Type: audio/wav`` on success, or a
        ``(json, 400)`` pair when ``text`` is missing, not a string, or empty.
    """
    sample_rate = 24000

    # Pick the parameter source without mutating the request object.
    fields = request.args if request.method == 'GET' else request.form
    if 'text' in fields:
        text = fields['text']
    else:
        # silent=True yields None instead of raising when the body is not
        # JSON, so a request without "text" gets the 400 below rather than an
        # unrelated 415/500 error.
        payload = request.get_json(silent=True)
        if not isinstance(payload, dict) or 'text' not in payload:
            error_response = {'error': 'Missing required fields. Please include "text" in your request.'}
            return jsonify(error_response), 400
        text = payload['text']

    if not isinstance(text, str):
        # JSON allows non-string values; reject them instead of crashing on .strip().
        error_response = {'error': 'Invalid text. Please ensure "text" is a string.'}
        return jsonify(error_response), 400

    # Strip regardless of transport so identical text gets identical cache keys.
    text = text.strip()
    if not text:
        error_response = {'error': 'Empty text. Please ensure "text" in not empty.'}
        return jsonify(error_response), 400

    texts = split_and_recombine_text(text)
    if not texts:
        # np.concatenate raises on an empty sequence; treat "nothing to
        # synthesize" as an empty-text request instead of a server error.
        error_response = {'error': 'Empty text. Please ensure "text" in not empty.'}
        return jsonify(error_response), 400

    audios = []
    # One noise tensor is drawn per request and shared by all of its chunks.
    noise = torch.randn(1, 1, 256).to('cuda' if torch.cuda.is_available() else 'cpu')
    for chunk in texts:
        # Cache key: SHA-256 of the lower-cased chunk text.
        digest = hashlib.sha256(chunk.lower().encode()).hexdigest()
        cache_path = os.path.join(cache_dir, digest + '.wav')
        if os.path.exists(cache_path):
            # scipy's read() returns (rate, data); only the samples are needed.
            audios.append(read(cache_path)[1])
        else:
            aud = ljinference.inference(chunk, noise, diffusion_steps=7, embedding_scale=1)
            write(cache_path, sample_rate, aud)
            audios.append(aud)

    output_buffer = io.BytesIO()
    write(output_buffer, sample_rate, np.concatenate(audios))
    response = Response(output_buffer.getvalue())
    response.headers["Content-Type"] = "audio/wav"
    return response
# When executed as a script, start Flask's built-in server on all interfaces.
if __name__ == "__main__":
    app.run(host="0.0.0.0", port=7860)