Spaces:
Build error
Build error
import asyncio | |
import os | |
import uuid | |
from fastapi import APIRouter, Depends, HTTPException, Path, WebSocket, WebSocketDisconnect, Query | |
from firebase_admin import auth | |
from firebase_admin.exceptions import FirebaseError | |
from requests import Session | |
from realtime_ai_character.audio.speech_to_text import (SpeechToText, | |
get_speech_to_text) | |
from realtime_ai_character.audio.text_to_speech import (TextToSpeech, | |
get_text_to_speech) | |
from realtime_ai_character.character_catalog.catalog_manager import ( | |
CatalogManager, get_catalog_manager) | |
from realtime_ai_character.database.connection import get_db | |
from realtime_ai_character.llm import (AsyncCallbackAudioHandler, | |
AsyncCallbackTextHandler, get_llm, LLM) | |
from realtime_ai_character.logger import get_logger | |
from realtime_ai_character.models.interaction import Interaction | |
from realtime_ai_character.utils import (ConversationHistory, build_history, | |
get_connection_manager) | |
# Module-level logger, named after this module.
logger = get_logger(__name__)
# Router object for this module's endpoints (no route registration is visible
# in this part of the file).
router = APIRouter()
# Connection manager used below to connect/disconnect websockets and to send
# or broadcast messages.
manager = get_connection_manager()
# Greeting sent to the client (as text and as synthesized speech) once a
# character has been selected.
GREETING_TXT = 'Hi, my friend, what brings you here today?'
async def get_current_user(token: str):
    """Helper function for auth with Firebase.

    Args:
        token: Firebase ID token supplied by the client; may be empty or None.

    Returns:
        The ``uid`` field of the decoded token, or an empty string when no
        token was supplied.

    Raises:
        HTTPException: with status 401 when Firebase rejects the token.
    """
    if not token:
        return ""
    try:
        decoded_token = auth.verify_id_token(token)
    except FirebaseError as e:
        # Never log the token itself: it is a bearer credential and would be
        # readable by anyone with access to the logs.
        logger.info(f'Received invalid token, verification failed with error {e}')
        raise HTTPException(
            status_code=401,
            detail="Invalid authentication credentials") from e
    return decoded_token['uid']
async def websocket_endpoint(websocket: WebSocket,
                             client_id: int = Path(...),
                             api_key: str = Query(None),
                             llm_model: str = Query(default=os.getenv(
                                 'LLM_MODEL_USE', 'gpt-3.5-turbo-16k')),
                             token: str = Query(None),
                             db: Session = Depends(get_db),
                             catalog_manager=Depends(get_catalog_manager),
                             speech_to_text=Depends(get_speech_to_text),
                             text_to_speech=Depends(get_text_to_speech)):
    """Authorize a websocket client and run its conversation loop.

    Args:
        websocket: The incoming websocket connection.
        client_id: Client identifier from the URL path; also the default
            user id.
        api_key: Accepted as a query parameter but not used in this function.
        llm_model: Model name passed to ``get_llm``; defaults to the
            ``LLM_MODEL_USE`` environment variable or ``gpt-3.5-turbo-16k``.
        token: Optional auth token, verified only when the ``USE_AUTH``
            environment variable is set to a non-empty value.
        db: Database session handed to ``handle_receive``.
        catalog_manager: Character catalog handed to ``handle_receive``.
        speech_to_text: Speech-to-text engine handed to ``handle_receive``.
        text_to_speech: Text-to-speech engine handed to ``handle_receive``.

    The socket is closed with code 1008 when auth is enabled and either an
    anonymous client requests a model other than ``gpt-3.5-turbo-16k`` or the
    supplied token is rejected.
    """
    # Default user_id to client_id. If auth is enabled and token is provided,
    # use the user_id from the token.
    user_id = str(client_id)
    if os.getenv('USE_AUTH', ''):
        # Do not allow anonymous users to use non-GPT3.5 model.
        if not token and llm_model != 'gpt-3.5-turbo-16k':
            await websocket.close(code=1008, reason="Unauthorized")
            return
        # Only override the default when a token was actually supplied;
        # get_current_user returns "" for a missing token, which would
        # otherwise blank out the client_id-derived user_id.
        if token:
            try:
                user_id = await get_current_user(token)
            except HTTPException:
                await websocket.close(code=1008, reason="Unauthorized")
                return
    llm = get_llm(model=llm_model)
    await manager.connect(websocket)
    try:
        # Awaited directly: wrapping a single coroutine in a task plus
        # gather() added nothing.
        await handle_receive(websocket, client_id, db, llm, catalog_manager,
                             speech_to_text, text_to_speech)
    except WebSocketDisconnect:
        await manager.disconnect(websocket)
        await manager.broadcast_message(f"User #{user_id} left the chat")
async def handle_receive(websocket: WebSocket, client_id: int, db: Session,
                         llm: LLM, catalog_manager: CatalogManager,
                         speech_to_text: SpeechToText,
                         text_to_speech: TextToSpeech):
    """Run one client's conversation over an already-connected websocket.

    Sequence, as implemented below:
      0. The text of the first frame is taken as the client platform.
      1. The client is prompted until it sends a valid 1-based character
         number; non-numeric or out-of-range input is answered with an
         "Invalid selection" message and the prompt is repeated.
      2. The greeting is sent as text, streamed through ``text_to_speech``,
         and followed by an '[end]' marker.
      3. Frames are then processed in a loop:
         - text starting with '[&]' is an intermediate transcript; it is
           only acted on when ``EXPERIMENT_CONVERSATION_UTTERANCE`` is set;
         - any other text is sent to the LLM, the reply is awaited, the
           history is updated and an ``Interaction`` row is saved;
         - binary frames are transcribed, any running reply is stopped, and
           a new LLM task is started whose completion callback records the
           exchange.

    Args:
        websocket: The client connection.
        client_id: Client identifier; also used as the user id for now.
        db: Database session passed to ``Interaction.save``.
        llm: Chat model providing ``achat`` / ``achat_utterances``.
        catalog_manager: Source of the selectable characters.
        speech_to_text: Transcriber for binary (audio) frames.
        text_to_speech: Speech synthesizer for the greeting and replies.

    Returns:
        None. The function returns after a ``WebSocketDisconnect`` has been
        handled (logged, and the socket unregistered from ``manager``).
    """
    try:
        conversation_history = ConversationHistory()
        # TODO: clean up client_id once migration is done.
        user_id = str(client_id)
        session_id = str(uuid.uuid4().hex)

        # asyncio keeps only weak references to tasks, so fire-and-forget
        # tasks are parked here until they finish to keep them from being
        # garbage-collected mid-execution.
        background_tasks = set()

        def spawn_background(coro):
            task = asyncio.create_task(coro)
            background_tasks.add(task)
            task.add_done_callback(background_tasks.discard)
            return task

        # 0. Receive client platform info (web, mobile, terminal)
        data = await websocket.receive()
        if data['type'] != 'websocket.receive':
            raise WebSocketDisconnect('disconnected')
        platform = data['text']
        logger.info(f"User #{user_id}:{platform} connected to server with "
                    f"session_id {session_id}")

        # 1. User selected a character
        character = None
        character_list = list(catalog_manager.characters.keys())
        user_input_template = 'Context:{context}\n User:{query}'
        while not character:
            character_message = "\n".join([
                f"{index} - {name}"
                for index, name in enumerate(character_list, start=1)
            ])
            await manager.send_message(
                message=
                f"Select your character by entering the corresponding number:\n"
                f"{character_message}\n",
                websocket=websocket)
            data = await websocket.receive()
            if data['type'] != 'websocket.receive':
                raise WebSocketDisconnect('disconnected')

            if 'text' not in data:
                # Not a text frame: prompt again.
                continue
            try:
                selection = int(data['text'])
            except (TypeError, ValueError):
                # Non-numeric input is treated like an out-of-range number
                # instead of crashing the whole session.
                selection = 0
            if selection > len(character_list) or selection < 1:
                await manager.send_message(
                    message=
                    f"Invalid selection. Select your character ["
                    f"{', '.join(catalog_manager.characters.keys())}]\n",
                    websocket=websocket)
                continue
            character = catalog_manager.get_character(
                character_list[selection - 1])
            conversation_history.system_prompt = character.llm_system_prompt
            user_input_template = character.llm_user_prompt
            logger.info(
                f"User #{user_id} selected character: {character.name}")

        tts_event = asyncio.Event()
        tts_task = None
        previous_transcript = None
        token_buffer = []

        # Greet the user
        await manager.send_message(message=GREETING_TXT, websocket=websocket)
        tts_task = asyncio.create_task(
            text_to_speech.stream(
                text=GREETING_TXT,
                websocket=websocket,
                tts_event=tts_event,
                # NOTE: keyword spelling kept as in the original call; the
                # callee's signature is not visible here.
                characater_name=character.name,
                first_sentence=True,
            ))
        # Send end of the greeting so the client knows when to start listening
        await manager.send_message(message='[end]\n', websocket=websocket)

        async def on_new_token(token):
            # Forward each streamed LLM token to the client.
            return await manager.send_message(message=token,
                                              websocket=websocket)

        async def stop_audio():
            # Cancel the in-flight speech/LLM task, if any, and keep what was
            # produced so far in the conversation history.
            if tts_task and not tts_task.done():
                tts_event.set()
                tts_task.cancel()
                if previous_transcript:
                    conversation_history.user.append(previous_transcript)
                    conversation_history.ai.append(' '.join(token_buffer))
                    token_buffer.clear()
                try:
                    await tts_task
                except asyncio.CancelledError:
                    pass
                tts_event.clear()

        while True:
            data = await websocket.receive()
            if data['type'] != 'websocket.receive':
                raise WebSocketDisconnect('disconnected')

            # handle text message
            if 'text' in data:
                msg_data = data['text']
                # 0. intermediate transcript starts with [&]
                if msg_data.startswith('[&]'):
                    logger.info(f'intermediate transcript: {msg_data}')
                    if not os.getenv('EXPERIMENT_CONVERSATION_UTTERANCE', ''):
                        continue
                    spawn_background(stop_audio())
                    spawn_background(
                        llm.achat_utterances(
                            history=build_history(conversation_history),
                            user_input=msg_data,
                            callback=AsyncCallbackTextHandler(
                                on_new_token, []),
                            audioCallback=AsyncCallbackAudioHandler(
                                text_to_speech, websocket, tts_event,
                                character.name)))
                    continue

                # 1. Send message to LLM
                logger.info(
                    f'response = await llm.achat, user_input {msg_data}')
                response = await llm.achat(
                    history=build_history(conversation_history),
                    user_input=msg_data,
                    user_input_template=user_input_template,
                    callback=AsyncCallbackTextHandler(on_new_token,
                                                      token_buffer),
                    audioCallback=AsyncCallbackAudioHandler(
                        text_to_speech, websocket, tts_event, character.name),
                    character=character)

                # 2. Send response to client
                await manager.send_message(message='[end]\n',
                                           websocket=websocket)

                # 3. Update conversation history
                conversation_history.user.append(msg_data)
                conversation_history.ai.append(response)
                token_buffer.clear()

                # 4. Persist interaction in the database
                Interaction(client_id=client_id,
                            user_id=user_id,
                            session_id=session_id,
                            client_message_unicode=msg_data,
                            server_message_unicode=response,
                            platform=platform,
                            action_type='text').save(db)

            # handle binary message(audio)
            elif 'bytes' in data:
                binary_data = data['bytes']
                # 1. Transcribe audio
                transcript: str = speech_to_text.transcribe(
                    binary_data, platform=platform,
                    prompt=character.name).strip()

                # ignore audio that picks up background noise
                if (not transcript or len(transcript) < 2):
                    continue

                # 2. Send transcript to client
                await manager.send_message(
                    message=f'[+]You said: {transcript}', websocket=websocket)

                # 3. stop the previous audio stream, if new transcript is received
                await stop_audio()

                previous_transcript = transcript

                # `transcript` is bound as a default argument so the callback
                # records the utterance that started this reply, even if a
                # later audio frame rebinds the loop variable before the
                # reply finishes.
                async def tts_task_done_call_back(response,
                                                  transcript=transcript):
                    # Send response to client, [=] indicates the response is done
                    await manager.send_message(message='[=]',
                                               websocket=websocket)
                    # Update conversation history
                    conversation_history.user.append(transcript)
                    conversation_history.ai.append(response)
                    token_buffer.clear()
                    # Persist interaction in the database
                    Interaction(client_id=client_id,
                                user_id=user_id,
                                session_id=session_id,
                                client_message_unicode=transcript,
                                server_message_unicode=response,
                                platform=platform,
                                action_type='audio').save(db)

                # 4. Send message to LLM
                tts_task = asyncio.create_task(
                    llm.achat(history=build_history(conversation_history),
                              user_input=transcript,
                              user_input_template=user_input_template,
                              callback=AsyncCallbackTextHandler(
                                  on_new_token, token_buffer,
                                  tts_task_done_call_back),
                              audioCallback=AsyncCallbackAudioHandler(
                                  text_to_speech, websocket, tts_event,
                                  character.name),
                              character=character))

    except WebSocketDisconnect:
        logger.info(f"User #{user_id} closed the connection")
        await manager.disconnect(websocket)
        return