from __future__ import annotations

import os
import io
import re
import time
import uuid
import torch
import cohere
import secrets
import requests
import fasttext
import replicate
import numpy as np
import gradio as gr
from PIL import Image
from groq import Groq
from TTS.api import TTS
from elevenlabs import save
from gradio.themes.base import Base
from elevenlabs.client import ElevenLabs
from huggingface_hub import hf_hub_download
from gradio.themes.utils import colors, fonts, sizes
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM

from prompt_examples import TEXT_CHAT_EXAMPLES, IMG_GEN_PROMPT_EXAMPLES, AUDIO_EXAMPLES, TEXT_CHAT_EXAMPLES_LABELS, IMG_GEN_PROMPT_EXAMPLES_LABELS, AUDIO_EXAMPLES_LABELS
from preambles import CHAT_PREAMBLE, AUDIO_RESPONSE_PREAMBLE, IMG_DESCRIPTION_PREAMBLE
from constants import LID_LANGUAGES, NEETS_AI_LANGID_MAP, AYA_MODEL_NAME, BATCH_SIZE, USE_ELVENLABS, USE_REPLICATE

# Third-party API keys are read from the environment; these must be set
# (e.g. as Space secrets) for the corresponding features to work.
HF_API_TOKEN = os.getenv("HF_API_KEY")
ELEVEN_LABS_KEY = os.getenv("ELEVEN_LABS_KEY")
NEETS_AI_API_KEY = os.getenv("NEETS_AI_API_KEY")
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
IMG_COHERE_API_KEY = os.getenv("IMG_COHERE_API_KEY")
AUDIO_COHERE_API_KEY = os.getenv("AUDIO_COHERE_API_KEY")
CHAT_COHERE_API_KEY = os.getenv("CHAT_COHERE_API_KEY")

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Initialize cohere clients (one per feature so usage can be tracked separately)
img_prompt_client = cohere.Client(
    api_key=IMG_COHERE_API_KEY, client_name="c4ai-aya-expanse-img"
)
chat_client = cohere.Client(
    api_key=CHAT_COHERE_API_KEY, client_name="c4ai-aya-expanse-chat"
)
audio_response_client = cohere.Client(
    api_key=AUDIO_COHERE_API_KEY, client_name="c4ai-aya-expanse-audio"
)

# Initialize the Groq client
groq_client = Groq(api_key=GROQ_API_KEY)

# Initialize the ElevenLabs client
eleven_labs_client = ElevenLabs(
    api_key=ELEVEN_LABS_KEY,
)

# Language identification
lid_model_path = hf_hub_download(repo_id="facebook/fasttext-language-identification", filename="model.bin")
LID_model = fasttext.load_model(lid_model_path)


def predict_language(text):
    # fastText expects single-line input, so collapse newlines before predicting
    text = re.sub("\n", " ", text)
    label, logit = LID_model.predict(text)
    label = label[0][len("__label__"):]
    print("predicted language:", label)
    return label


# Image Generation util functions
def get_hf_inference_api_response(payload, model_id):
    headers = {"Authorization": f"Bearer {HF_API_TOKEN}"}
    MODEL_API_URL = f"https://api-inference.huggingface.co/models/{model_id}"
    response = requests.post(MODEL_API_URL, headers=headers, json=payload)
    return response.content


def replicate_api_inference(input_prompt):
    input_params = {
        "prompt": input_prompt,
        "go_fast": True,
        "megapixels": "1",
        "num_outputs": 1,
        "aspect_ratio": "1:1",
        "output_format": "jpg",
        "output_quality": 80,
        "num_inference_steps": 4,
    }
    image = replicate.run("black-forest-labs/flux-schnell", input=input_params)
    image = Image.open(image[0])
    return image


def generate_image(input_prompt, model_id="black-forest-labs/FLUX.1-schnell"):
    if input_prompt != "":
        if input_prompt == "Image generation blocked for prompts that include humans, kids, or children.":
            return None
        else:
            if USE_REPLICATE:
                print("using replicate for image generation")
                image = replicate_api_inference(input_prompt)
            else:
                try:
                    print("using HF inference API for image generation")
                    image_bytes = get_hf_inference_api_response({"inputs": input_prompt}, model_id)
                    image = np.array(Image.open(io.BytesIO(image_bytes)))
                except Exception as e:
                    print("HF API error:", e)
                    # fall back to Replicate if the HF Inference API call fails
                    image = replicate_api_inference(input_prompt)
            return image
    else:
        return None
def generate_img_prompt(input_prompt):
    # clean prompt before doing language detection
    cleaned_prompt = clean_text(input_prompt, remove_bullets=True, remove_newline=True)
    text_lang_code = predict_language(cleaned_prompt)
    language = LID_LANGUAGES[text_lang_code]

    gr.Info("Generating Image", duration=2)

    if language != "english":
        text = f"""Translate the given input prompt to English.
        Input Prompt: {input_prompt}
        Once translated, use the English version of the prompt to create a detailed image description suitable for a text-to-image model.
        Ensure the description is concise, limited to 2-3 lines, and integrates key elements from the translated prompt.
        Add the prompt's English translation to the image description, and respond with that.
        """
    else:
        text = f"""Generate a detailed image description which can be used to generate an image using a text-to-image model based on the given input prompt:
        Input Prompt: {input_prompt}
        Do not use more than 3-4 lines for the description.
        """

    response = img_prompt_client.chat(message=text, preamble=IMG_DESCRIPTION_PREAMBLE, model=AYA_MODEL_NAME)
    output = response.text
    return output


# Chat with Aya util functions
def trigger_example(example):
    # Replay an example prompt through the chat generator with a fresh
    # conversation id and session token, returning the final chat state.
    chat, history = [], []
    for chat, history, _ in generate_aya_chat_response(example, str(uuid.uuid4()), secrets.token_hex(16)):
        pass
    return chat, history


def generate_aya_chat_response(user_message, cid, token, history=None):
    if not token:
        raise gr.Error("Error loading.")

    if history is None:
        history = []

    if cid == "" or cid is None:
        cid = str(uuid.uuid4())

    print(f"cid: {cid} prompt:{user_message}")

    history.append(user_message)

    stream = chat_client.chat_stream(
        message=user_message,
        preamble=CHAT_PREAMBLE,
        conversation_id=cid,
        model=AYA_MODEL_NAME,
        connectors=[],
        temperature=0.3,
    )

    output = ""
    for idx, response in enumerate(stream):
        if response.event_type == "text-generation":
            output += response.text

        if idx == 0:
            history.append(" " + output)
        else:
            history[-1] = output

        chat = [
            (history[i].strip(), history[i + 1].strip())
            for i in range(0, len(history) - 1, 2)
        ]

        yield chat, history, cid

    return chat, history, cid


def clear_chat():
    return [], [], str(uuid.uuid4())


# Audio Pipeline util functions
def transcribe_and_stream(inputs, show_info="no", model_name="openai/whisper-large-v3-turbo", language="english"):
    if inputs is not None and inputs != "":
        if show_info == "show_info":
            gr.Info("Processing Audio", duration=1)

        if model_name != "groq_whisper":
            print("DEVICE:", DEVICE)
            pipe = pipeline(
                task="automatic-speech-recognition",
                model=model_name,
                chunk_length_s=30,
                device=DEVICE,
            )
            text = pipe(inputs, batch_size=BATCH_SIZE, return_timestamps=True)["text"]
        else:
            text = groq_whisper_tts(inputs)

        # stream text output back to the UI in small chunks
        for i in range(len(text)):
            time.sleep(0.01)
            yield text[: i + 10]
    else:
        return ""


def aya_speech_text_response(text):
    if text is not None and text != "":
        stream = audio_response_client.chat_stream(message=text, preamble=AUDIO_RESPONSE_PREAMBLE, model=AYA_MODEL_NAME)
        output = ""
        for event in stream:
            if event:
                if event.event_type == "text-generation":
                    output += event.text
                    cleaned_output = clean_text(output)
                    yield cleaned_output
    else:
        return ""


def clean_text(text, remove_bullets=False, remove_newline=False):
    # Remove bold formatting
    cleaned_text = re.sub(r"\*\*", "", text)

    if remove_bullets:
        cleaned_text = re.sub(r"^- ", "", cleaned_text, flags=re.MULTILINE)

    if remove_newline:
        cleaned_text = re.sub(r"\n", " ", cleaned_text)

    return cleaned_text


def convert_text_to_speech(text, language="english"):
    # do language detection to determine the voice of the speech response
    if text is not None and text != "":
        # clean text before doing language detection
        cleaned_text = clean_text(text, remove_bullets=True, remove_newline=True)
        text_lang_code = predict_language(cleaned_text)
        language = LID_LANGUAGES[text_lang_code]

        if not USE_ELVENLABS:
            if language != "japanese":
                audio_path = neetsai_tts(text, language)
            else:
                print("DEVICE:", DEVICE)
                # if language is japanese then use XTTS for TTS since neets.ai doesn't support a japanese voice
                tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2").to(DEVICE)
                speaker_wav = "samples/ja-sample.wav"
                lang_code = "ja"
                audio_path = "./output.wav"
                tts.tts_to_file(text=text, speaker_wav=speaker_wav, language=lang_code, file_path=audio_path)
        else:
            # use elevenlabs for TTS
            audio_path = elevenlabs_generate_audio(text)
        return audio_path
    else:
        return None


def elevenlabs_generate_audio(text):
    audio = eleven_labs_client.generate(
        text=text,
        voice="River",
        model="eleven_turbo_v2_5",  # "eleven_multilingual_v2"
    )
    # save audio
    audio_path = "./audio.mp3"
    save(audio, audio_path)
    return audio_path


def neetsai_tts(input_text, language):
    lang_id = NEETS_AI_LANGID_MAP[language]
    neets_vits_voice_id = f"vits-{lang_id}"

    response = requests.request(
        method="POST",
        url="https://api.neets.ai/v1/tts",
        headers={
            "Content-Type": "application/json",
            "X-API-Key": NEETS_AI_API_KEY,
        },
        json={
            "text": input_text,
            "voice_id": neets_vits_voice_id,
            "params": {
                "model": "vits",
            },
        },
    )

    # save audio file
    audio_path = "neets_demo.mp3"
    with open(audio_path, "wb") as f:
        f.write(response.content)
    return audio_path


def groq_whisper_tts(filename):
    # Speech-to-text via Groq's hosted Whisper endpoint
    with open(filename, "rb") as file:
        transcriptions = groq_client.audio.transcriptions.create(
            file=(filename, file.read()),
            model="whisper-large-v3-turbo",
            response_format="json",
            temperature=0.0,
        )
        print("transcribed text:", transcriptions.text)
        print("********************************")
    return transcriptions.text


# setup gradio app theme
theme = gr.themes.Base(
    primary_hue=gr.themes.colors.teal,
    secondary_hue=gr.themes.colors.blue,
    neutral_hue=gr.themes.colors.gray,
    text_size=gr.themes.sizes.text_lg,
).set(
    # Primary Button Color
    button_primary_background_fill="#114A56",
    button_primary_background_fill_hover="#114A56",
    # Block Labels
    block_title_text_weight="600",
    block_label_text_weight="600",
    block_label_text_size="*text_md",
)

demo = gr.Blocks(theme=theme, analytics_enabled=False)

with demo:
    with gr.Row(variant="panel"):
        with gr.Column(scale=1):
            gr.Image("aya-expanse.png", elem_id="logo-img", show_label=False, show_share_button=False, show_download_button=False, show_fullscreen_button=False)
        with gr.Column(scale=30):
            gr.Markdown(
                """C4AI Aya Expanse is a state-of-the-art model with highly advanced capabilities to connect the world across languages.
You can use this space to chat, speak and visualize with Aya Expanse in 23 languages.

**Developed by**: [Cohere for AI](https://cohere.com/research) and [Cohere](https://cohere.com/)
"""
            )

    # Text Chat
    with gr.TabItem("Chat with Aya") as chat_with_aya:
        cid = gr.State("")
        token = gr.State(value=None)

        with gr.Column():
            with gr.Row():
                chatbot = gr.Chatbot(show_label=False, show_share_button=False, show_copy_button=True, height=300)
            with gr.Row():
                user_message = gr.Textbox(lines=1, placeholder="Ask anything in our 23 languages ...", label="Input", show_label=False)
            with gr.Row():
                submit_button = gr.Button("Submit", variant="primary")
                clear_button = gr.Button("Clear")

        history = gr.State([])

        user_message.submit(fn=generate_aya_chat_response, inputs=[user_message, cid, token, history], outputs=[chatbot, history, cid], concurrency_limit=32)
        submit_button.click(fn=generate_aya_chat_response, inputs=[user_message, cid, token, history], outputs=[chatbot, history, cid], concurrency_limit=32)
        clear_button.click(fn=clear_chat, inputs=None, outputs=[chatbot, history, cid], concurrency_limit=32)

        # clear the input textbox after each event; no inputs are passed, so the lambdas take no arguments
        user_message.submit(lambda: gr.update(value=""), None, [user_message], queue=False)
        submit_button.click(lambda: gr.update(value=""), None, [user_message], queue=False)
        clear_button.click(lambda: gr.update(value=""), None, [user_message], queue=False)

        with gr.Row():
            gr.Examples(
                examples=TEXT_CHAT_EXAMPLES,
                inputs=user_message,
                cache_examples=False,
                fn=trigger_example,
                outputs=[chatbot],
                examples_per_page=25,
                label="Load example prompt for:",
                example_labels=TEXT_CHAT_EXAMPLES_LABELS,
            )

    # Audio Pipeline
    with gr.TabItem("Speak with Aya") as speak_with_aya:
        with gr.Row():
            with gr.Column():
                e2e_audio_file = gr.Audio(sources="microphone", type="filepath", min_length=None)
                clear_button_microphone = gr.ClearButton()
                gr.Examples(
                    examples=AUDIO_EXAMPLES,
                    inputs=e2e_audio_file,
                    cache_examples=False,
                    examples_per_page=25,
                    label="Load example audio for:",
                    example_labels=AUDIO_EXAMPLES_LABELS,
                )
            with gr.Column():
                e2e_audio_file_trans = gr.Textbox(lines=3, label="Your Input", autoscroll=False, show_copy_button=True, interactive=False)
                e2e_audio_file_aya_response = gr.Textbox(lines=3, label="Aya's Response", show_copy_button=True, container=True, interactive=False)
                e2e_aya_audio_response = gr.Audio(type="filepath", label="Aya's Audio Response")

        show_info = gr.Textbox(value="show_info", visible=False)
        stt_model = gr.Textbox(value="groq_whisper", visible=False)

        with gr.Accordion("See Details", open=False):
            gr.Markdown("To enable voice interaction with Aya Expanse, this space uses [Whisper large-v3-turbo](https://huggingface.co/openai/whisper-large-v3-turbo) via [Groq](https://groq.com/) for STT and [neets.ai](http://neets.ai/) for TTS.")

    # Image Generation
    with gr.TabItem("Visualize with Aya") as visualize_with_aya:
        with gr.Row():
            with gr.Column():
                input_img_prompt = gr.Textbox(placeholder="Ask anything in our 23 languages ...", label="Describe an image", lines=3)
                # generated_img_desc = gr.Textbox(label="Image Description generated by Aya", interactive=False, lines=3, visible=False)
                submit_button_img = gr.Button(value="Submit", variant="primary")
                clear_button_img = gr.ClearButton()
            with gr.Column():
                generated_img = gr.Image(label="Generated Image", interactive=False)

        with gr.Row():
            gr.Examples(
                examples=IMG_GEN_PROMPT_EXAMPLES,
                inputs=input_img_prompt,
                cache_examples=False,
                examples_per_page=25,
                label="Load example prompt for:",
                example_labels=IMG_GEN_PROMPT_EXAMPLES_LABELS,
            )

        generated_img_desc = gr.Textbox(label="Image Description generated by Aya", interactive=False, lines=3, visible=False)
        # increase spacing between examples and Accordion components
        with gr.Row():
            pass
        with gr.Row():
            pass
        with gr.Row():
            pass

        with gr.Row():
            with gr.Accordion("See Details", open=False):
                gr.Markdown("This space uses Aya Expanse for translating multilingual prompts and generating detailed image descriptions, and [Flux Schnell](https://huggingface.co/black-forest-labs/FLUX.1-schnell) for image generation.")

    # Image Generation events
    clear_button_img.click(lambda: None, None, input_img_prompt)
    clear_button_img.click(lambda: None, None, generated_img_desc)
    clear_button_img.click(lambda: None, None, generated_img)

    submit_button_img.click(
        generate_img_prompt,
        inputs=[input_img_prompt],
        outputs=[generated_img_desc],
    )
    generated_img_desc.change(
        generate_image,  # run_flux
        inputs=[generated_img_desc],
        outputs=[generated_img],
        show_progress="hidden",
    )

    # Audio Pipeline events
    clear_button_microphone.click(lambda: None, None, e2e_audio_file)
    clear_button_microphone.click(lambda: None, None, e2e_audio_file_trans)
    clear_button_microphone.click(lambda: None, None, e2e_aya_audio_response)

    e2e_audio_file.change(
        transcribe_and_stream,
        inputs=[e2e_audio_file, show_info, stt_model],
        outputs=[e2e_audio_file_trans],
        show_progress="hidden",
    ).then(
        aya_speech_text_response,
        inputs=[e2e_audio_file_trans],
        outputs=[e2e_audio_file_aya_response],
        show_progress="minimal",
    ).then(
        convert_text_to_speech,
        inputs=[e2e_audio_file_aya_response],
        outputs=[e2e_aya_audio_response],
        show_progress="minimal",
    )

    demo.load(lambda: secrets.token_hex(16), None, token)

demo.queue(api_open=False, max_size=40).launch(show_api=False, allowed_paths=["/home/user/app"])