fish-agent / app.py
PoTaTo721's picture
Upload Fish-Agent Demo
4f6613a
raw
history blame
8.38 kB
import re
import gradio as gr
import numpy as np
import os
import threading
import subprocess
import sys
import time
from huggingface_hub import snapshot_download
from tools.fish_e2e import FishE2EAgent, FishE2EEventType
from tools.schema import ServeMessage, ServeTextPart, ServeVQPart
# Download Weights
os.makedirs("checkpoints", exist_ok=True)
snapshot_download(repo_id="fishaudio/fish-speech-1.4", local_dir="./checkpoints/fish-speech-1.4")
snapshot_download(repo_id="fishaudio/fish-agent-v0.1-3b", local_dir="./checkpoints/fish-agent-v0.1-3b")
class ChatState:
def __init__(self):
self.conversation = []
self.added_systext = False
self.added_sysaudio = False
def get_history(self):
results = []
for msg in self.conversation:
results.append({"role": msg.role, "content": self.repr_message(msg)})
# Process assistant messages to extract questions and update user messages
for i, msg in enumerate(results):
if msg["role"] == "assistant":
match = re.search(r"Question: (.*?)\n\nResponse:", msg["content"])
if match and i > 0 and results[i - 1]["role"] == "user":
# Update previous user message with extracted question
results[i - 1]["content"] += "\n" + match.group(1)
# Remove the Question/Answer format from assistant message
msg["content"] = msg["content"].split("\n\nResponse: ", 1)[1]
return results
def repr_message(self, msg: ServeMessage):
response = ""
for part in msg.parts:
if isinstance(part, ServeTextPart):
response += part.text
elif isinstance(part, ServeVQPart):
response += f"<audio {len(part.codes[0]) / 21:.2f}s>"
return response
def clear_fn():
return [], ChatState(), None, None, None
async def process_audio_input(
sys_audio_input, sys_text_input, audio_input, state: ChatState, text_input: str
):
if audio_input is None and not text_input:
raise gr.Error("No input provided")
agent = FishE2EAgent() # Create new agent instance for each request
# Convert audio input to numpy array
if isinstance(audio_input, tuple):
sr, audio_data = audio_input
elif text_input:
sr = 44100
audio_data = None
else:
raise gr.Error("Invalid audio format")
if isinstance(sys_audio_input, tuple):
sr, sys_audio_data = sys_audio_input
elif text_input:
sr = 44100
sys_audio_data = None
else:
raise gr.Error("Invalid audio format")
def append_to_chat_ctx(
part: ServeTextPart | ServeVQPart, role: str = "assistant"
) -> None:
if not state.conversation or state.conversation[-1].role != role:
state.conversation.append(ServeMessage(role=role, parts=[part]))
else:
state.conversation[-1].parts.append(part)
if state.added_systext is False and sys_text_input:
state.added_systext = True
append_to_chat_ctx(ServeTextPart(text=sys_text_input), role="system")
if text_input:
append_to_chat_ctx(ServeTextPart(text=text_input), role="user")
audio_data = None
result_audio = b""
async for event in agent.stream(
sys_audio_data,
audio_data,
sr,
1,
chat_ctx={
"messages": state.conversation,
"added_sysaudio": state.added_sysaudio,
},
):
if event.type == FishE2EEventType.USER_CODES:
append_to_chat_ctx(ServeVQPart(codes=event.vq_codes), role="user")
elif event.type == FishE2EEventType.SPEECH_SEGMENT:
result_audio += event.frame.data
np_audio = np.frombuffer(result_audio, dtype=np.int16)
append_to_chat_ctx(ServeVQPart(codes=event.vq_codes))
yield state.get_history(), (44100, np_audio), None, None
elif event.type == FishE2EEventType.TEXT_SEGMENT:
append_to_chat_ctx(ServeTextPart(text=event.text))
if result_audio:
np_audio = np.frombuffer(result_audio, dtype=np.int16)
yield state.get_history(), (44100, np_audio), None, None
else:
yield state.get_history(), None, None, None
np_audio = np.frombuffer(result_audio, dtype=np.int16)
yield state.get_history(), (44100, np_audio), None, None
async def process_text_input(
sys_audio_input, sys_text_input, state: ChatState, text_input: str
):
async for event in process_audio_input(
sys_audio_input, sys_text_input, None, state, text_input
):
yield event
def create_demo():
with gr.Blocks() as demo:
state = gr.State(ChatState())
with gr.Row():
# Left column (70%) for chatbot and notes
with gr.Column(scale=7):
chatbot = gr.Chatbot(
[],
elem_id="chatbot",
bubble_full_width=False,
height=600,
type="messages",
)
notes = gr.Markdown(
"""
# Fish Agent
1. 此Demo为Fish Audio自研端到端语言模型Fish Agent 3B版本.
2. 你可以在我们的官方仓库找到代码以及权重,但是相关内容全部基于 CC BY-NC-SA 4.0 许可证发布.
3. Demo为早期灰度测试版本,推理速度尚待优化.
# 特色
1. 该模型自动集成ASR与TTS部分,不需要外挂其它模型,即真正的端到端,而非三段式(ASR+LLM+TTS).
2. 模型可以使用reference audio控制说话音色.
3. 可以生成具有较强情感与韵律的音频.
"""
)
# Right column (30%) for controls
with gr.Column(scale=3):
sys_audio_input = gr.Audio(
sources=["upload"],
type="numpy",
label="Give a timbre for your assistant",
)
sys_text_input = gr.Textbox(
label="What is your assistant's role?",
value='您是由 Fish Audio 设计的语音助手,提供端到端的语音交互,实现无缝用户体验。首先转录用户的语音,然后使用以下格式回答:"Question: [用户语音]\n\nResponse: [你的回答]\n"。',
type="text",
)
audio_input = gr.Audio(
sources=["microphone"], type="numpy", label="Speak your message"
)
text_input = gr.Textbox(label="Or type your message", type="text")
output_audio = gr.Audio(label="Assistant's Voice", type="numpy")
send_button = gr.Button("Send", variant="primary")
clear_button = gr.Button("Clear")
# Event handlers
audio_input.stop_recording(
process_audio_input,
inputs=[sys_audio_input, sys_text_input, audio_input, state, text_input],
outputs=[chatbot, output_audio, audio_input, text_input],
show_progress=True,
)
send_button.click(
process_text_input,
inputs=[sys_audio_input, sys_text_input, state, text_input],
outputs=[chatbot, output_audio, audio_input, text_input],
show_progress=True,
)
text_input.submit(
process_text_input,
inputs=[sys_audio_input, sys_text_input, state, text_input],
outputs=[chatbot, output_audio, audio_input, text_input],
show_progress=True,
)
clear_button.click(
clear_fn,
inputs=[],
outputs=[chatbot, state, audio_input, output_audio, text_input],
)
return demo
def run_api():
subprocess.run([sys.executable, "-m", "tools.api"])
if __name__ == "__main__":
# 创建并启动 API 线程
api_thread = threading.Thread(target=run_api, daemon=True)
api_thread.start()
# 给 API 一些时间启动
time.sleep(60)
# 创建并启动 Gradio demo
demo = create_demo()
demo.launch(server_name="127.0.0.1", server_port=7860, share=True)