Spaces:
Running
Running
import re | |
import gradio as gr | |
import numpy as np | |
import os | |
import io | |
import wave | |
import threading | |
import subprocess | |
import sys | |
import time | |
from huggingface_hub import snapshot_download | |
from tools.fish_e2e import FishE2EAgent, FishE2EEventType | |
from tools.schema import ServeMessage, ServeTextPart, ServeVQPart | |
# Fetch the model checkpoints into ./checkpoints when this module is loaded.
os.makedirs("checkpoints", exist_ok=True)
for _repo_id, _local_dir in (
    ("fishaudio/fish-speech-1.4", "./checkpoints/fish-speech-1.4"),
    ("fishaudio/fish-agent-v0.1-3b", "./checkpoints/fish-agent-v0.1-3b"),
):
    snapshot_download(repo_id=_repo_id, local_dir=_local_dir)

# Default system prompt; it asks the model to answer in a
# "Question: ... / Response: ..." layout.
SYSTEM_PROMPT = 'You are a voice assistant created by Fish Audio, offering end-to-end voice interaction for a seamless user experience. You are required to first transcribe the user\'s speech, then answer it in the following format: "Question: [USER_SPEECH]\n\nResponse: [YOUR_RESPONSE]\n"。You are required to use the following voice in this conversation.'
class ChatState:
    """Conversation state kept between requests of one chat session.

    Fields:
        conversation: list of message objects (each with ``role`` and
            ``parts``) exchanged so far.
        added_systext: True once the system text prompt has been appended.
        added_sysaudio: flag handed to the agent inside the chat context.
    """

    def __init__(self):
        self.conversation = []
        self.added_systext = False
        self.added_sysaudio = False

    def get_history(self):
        """Return the conversation as a list of ``{"role", "content"}`` dicts.

        Assistant messages of the form
        ``"Question: <q>\\n\\nResponse: <answer>"`` are post-processed when
        they directly follow a user message: ``<q>`` is appended (on a new
        line) to that user message and the assistant content is reduced to
        ``<answer>``.

        The assistant text may be incomplete while it is still being
        streamed; a message that ends right after ``"Response:"`` (before
        the following space has arrived) yields an empty answer instead of
        raising ``IndexError``.
        """
        results = [
            {"role": msg.role, "content": self.repr_message(msg)}
            for msg in self.conversation
        ]

        for i, msg in enumerate(results):
            if msg["role"] != "assistant":
                continue

            match = re.search(r"Question: (.*?)\n\nResponse:", msg["content"])
            if not (match and i > 0 and results[i - 1]["role"] == "user"):
                continue

            # Move the transcribed question onto the preceding user message.
            results[i - 1]["content"] += "\n" + match.group(1)

            # Strip the "Question/Response" scaffolding from the answer.
            # Fall back to whatever follows the matched "Response:" marker
            # when the "Response: " separator (with its trailing space) is
            # not present yet.
            _, separator, answer = msg["content"].partition("\n\nResponse: ")
            if separator:
                msg["content"] = answer
            else:
                msg["content"] = msg["content"][match.end():]

        return results

    def repr_message(self, msg: "ServeMessage"):
        """Render one message as text.

        Text parts contribute their text verbatim; VQ parts are shown as an
        ``<audio N.NNs>`` placeholder. Parts of any other type are skipped.
        """
        pieces = []
        for part in msg.parts:
            if isinstance(part, ServeTextPart):
                pieces.append(part.text)
            elif isinstance(part, ServeVQPart):
                # Duration estimate: length of the first code row divided by
                # 21 (presumably ~21 code frames per second -- TODO confirm).
                pieces.append(f"<audio {len(part.codes[0]) / 21:.2f}s>")
        return "".join(pieces)
def clear_fn():
    """Return reset values: an empty chat history, a brand-new ChatState,
    and ``None`` for each of the three remaining output components."""
    empty_history = []
    fresh_state = ChatState()
    return empty_history, fresh_state, None, None, None
def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
    """Return the bytes of a WAV header describing an empty PCM stream.

    The header is produced by opening a ``wave`` writer on an in-memory
    buffer, configuring it, and closing it without writing any frames.

    Args:
        sample_rate: frames per second.
        bit_depth: bits per sample; converted to a whole number of bytes.
        channels: number of audio channels.
    """
    with io.BytesIO() as sink:
        with wave.open(sink, "wb") as writer:
            writer.setnchannels(channels)
            writer.setsampwidth(bit_depth // 8)
            writer.setframerate(sample_rate)
        # The writer emits the header on close; the buffer itself stays
        # open until the outer ``with`` exits, so it can still be read here.
        return sink.getvalue()
async def process_audio_input(
    sys_audio_input, sys_text_input, audio_input, state: ChatState, text_input: str
):
    """Send one user turn (audio or text) to the agent and stream the reply.

    Args:
        sys_audio_input: optional ``(sample_rate, samples)`` tuple with the
            reference voice, or any non-tuple value when absent.
        sys_text_input: system prompt text; appended to the conversation
            only once per ``state``.
        audio_input: ``(sample_rate, samples)`` tuple for the user's
            recording, or ``None``.
        state: conversation state, updated in place as events arrive.
        text_input: typed user message. When non-empty it is used as the
            user turn and the recorded audio is not sent to the agent.

    Yields:
        4-tuples ``(history, audio_bytes_or_None, None, None)``; the audio
        element is a WAV header followed by the frame data of one speech
        segment.

    Raises:
        gr.Error: if neither audio nor text was supplied, or the audio is
            not a tuple and there is no text to fall back on.
    """
    if audio_input is None and not text_input:
        raise gr.Error("No input provided")

    agent = FishE2EAgent()  # Create new agent instance for each request

    # Unpack the user's recording (if any) and pick the sample rate.
    if isinstance(audio_input, tuple):
        sr, audio_data = audio_input
    elif text_input:
        sr = 44100
        audio_data = None
    else:
        raise gr.Error("Invalid audio format")

    if isinstance(sys_audio_input, tuple):
        # NOTE(review): agent.stream receives a single sample rate, so the
        # system audio's rate replaces the user's here (original behavior).
        # If the two rates can differ, one clip needs resampling -- confirm.
        sr, sys_audio_data = sys_audio_input
    else:
        # No system audio: keep the sample rate chosen above instead of
        # overwriting the user recording's rate with a hard-coded 44100.
        sys_audio_data = None

    def append_to_chat_ctx(
        part: ServeTextPart | ServeVQPart, role: str = "assistant"
    ) -> None:
        # Start a new message when the role changes; otherwise extend the
        # last message with another part.
        if not state.conversation or state.conversation[-1].role != role:
            state.conversation.append(ServeMessage(role=role, parts=[part]))
        else:
            state.conversation[-1].parts.append(part)

    if state.added_systext is False and sys_text_input:
        state.added_systext = True
        append_to_chat_ctx(ServeTextPart(text=sys_text_input), role="system")

    if text_input:
        append_to_chat_ctx(ServeTextPart(text=text_input), role="user")
        # Typed text takes priority over any recorded audio.
        audio_data = None

    # Keep a handle on the context dict so the flag can be read back below.
    chat_ctx = {
        "messages": state.conversation,
        "added_sysaudio": state.added_sysaudio,
    }

    async for event in agent.stream(
        sys_audio_data,
        audio_data,
        sr,
        1,
        chat_ctx=chat_ctx,
    ):
        if event.type == FishE2EEventType.USER_CODES:
            append_to_chat_ctx(ServeVQPart(codes=event.vq_codes), role="user")
        elif event.type == FishE2EEventType.SPEECH_SEGMENT:
            append_to_chat_ctx(ServeVQPart(codes=event.vq_codes))
            yield state.get_history(), wav_chunk_header() + event.frame.data, None, None
        elif event.type == FishE2EEventType.TEXT_SEGMENT:
            append_to_chat_ctx(ServeTextPart(text=event.text))
            yield state.get_history(), None, None, None

    # Persist the flag across turns. Presumably the agent flips
    # "added_sysaudio" in the context once it has consumed the system
    # audio -- TODO confirm against FishE2EAgent.stream. If it never touches
    # the key, this assignment is a no-op.
    state.added_sysaudio = chat_ctx.get("added_sysaudio", state.added_sysaudio)

    yield state.get_history(), None, None, None
async def process_text_input(
    sys_audio_input, sys_text_input, state: ChatState, text_input: str
):
    """Handle a typed message by delegating to ``process_audio_input``
    with no recorded audio, re-yielding every update it produces."""
    updates = process_audio_input(
        sys_audio_input, sys_text_input, None, state, text_input
    )
    async for update in updates:
        yield update
def create_demo():
    """Build and return the Gradio Blocks UI.

    Layout: a wide left column with the chatbot and a notes panel, and a
    narrower right column with the voice/role inputs, the message inputs,
    the streamed audio output and the Send/Clear buttons. Event handlers
    are registered for finished recordings, the Send button, text
    submission and the Clear button.
    """
    with gr.Blocks() as demo:
        state = gr.State(ChatState())

        with gr.Row():
            # Left column (70%): conversation and notes
            with gr.Column(scale=7):
                chatbot = gr.Chatbot(
                    [],
                    elem_id="chatbot",
                    bubble_full_width=False,
                    height=600,
                    type="messages",
                )

                gr.Markdown(
                    """
# Fish Agent
1. This demo is the Fish Audio self-developed end-to-end language model Fish Agent 3B version.
2. You can find the code and weights in our official repository, but all related content is released under the CC BY-NC-SA 4.0 license.
3. The demo is an early beta version, and inference speed is yet to be optimized.
# Features
1. This model automatically integrates ASR and TTS components, requiring no external models, making it truly end-to-end rather than a three-stage process (ASR+LLM+TTS).
2. The model can use reference audio to control speaking voice.
3. It can generate audio with strong emotions and prosody.
"""
                )

            # Right column (30%): inputs and controls
            with gr.Column(scale=3):
                sys_audio_input = gr.Audio(
                    sources=["upload"],
                    type="numpy",
                    label="Give a timbre for your assistant",
                )
                sys_text_input = gr.Textbox(
                    label="What is your assistant's role?",
                    value=SYSTEM_PROMPT,
                    type="text",
                )
                audio_input = gr.Audio(
                    sources=["microphone"], type="numpy", label="Speak your message"
                )
                text_input = gr.Textbox(
                    label="Or type your message",
                    type="text",
                    value="Can you give a brief introduction of yourself?",
                )
                output_audio = gr.Audio(
                    label="Assistant's Voice",
                    streaming=True,
                    autoplay=True,
                    interactive=False,
                )
                send_button = gr.Button("Send", variant="primary")
                clear_button = gr.Button("Clear")

        # A finished microphone recording triggers the audio handler.
        audio_input.stop_recording(
            process_audio_input,
            inputs=[sys_audio_input, sys_text_input, audio_input, state, text_input],
            outputs=[chatbot, output_audio, audio_input, text_input],
            show_progress=True,
        )

        # The Send button and pressing Enter in the textbox share the same
        # text handler wiring (registered in that order).
        for register in (send_button.click, text_input.submit):
            register(
                process_text_input,
                inputs=[sys_audio_input, sys_text_input, state, text_input],
                outputs=[chatbot, output_audio, audio_input, text_input],
                show_progress=True,
            )

        clear_button.click(
            clear_fn,
            inputs=[],
            outputs=[chatbot, state, audio_input, output_audio, text_input],
        )

    return demo
def run_api():
    """Run the ``tools.api`` module with the current Python interpreter in a
    child process, blocking until that process exits."""
    api_command = [sys.executable, "-m", "tools.api"]
    subprocess.run(api_command)
if __name__ == "__main__":
    # Launch the API in a background daemon thread.
    api_thread = threading.Thread(target=run_api, daemon=True)
    api_thread.start()

    # Give the API some time to start before the UI comes up.
    api_startup_wait_seconds = 90
    time.sleep(api_startup_wait_seconds)

    # Build and launch the Gradio demo.
    demo = create_demo()
    demo.launch(share=True)