from __future__ import annotations

import os
import io
import re
import time
import uuid
import torch
import cohere
import secrets
import requests
import fasttext
import replicate
import numpy as np
import gradio as gr
from PIL import Image
from groq import Groq
from TTS.api import TTS
from elevenlabs import save
from gradio.themes.base import Base
from elevenlabs.client import ElevenLabs
from huggingface_hub import hf_hub_download
from gradio.themes.utils import colors, fonts, sizes
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM

from prompt_examples import (
    TEXT_CHAT_EXAMPLES,
    IMG_GEN_PROMPT_EXAMPLES,
    AUDIO_EXAMPLES,
    TEXT_CHAT_EXAMPLES_LABELS,
    IMG_GEN_PROMPT_EXAMPLES_LABELS,
    AUDIO_EXAMPLES_LABELS,
)
from preambles import CHAT_PREAMBLE, AUDIO_RESPONSE_PREAMBLE, IMG_DESCRIPTION_PREAMBLE
from constants import (
    LID_LANGUAGES,
    NEETS_AI_LANGID_MAP,
    AYA_MODEL_NAME,
    BATCH_SIZE,
    USE_ELVENLABS,
    USE_REPLICATE,
)

HF_API_TOKEN = os.getenv("HF_API_KEY")
ELEVEN_LABS_KEY = os.getenv("ELEVEN_LABS_KEY")
NEETS_AI_API_KEY = os.getenv("NEETS_AI_API_KEY")
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
IMG_COHERE_API_KEY = os.getenv("IMG_COHERE_API_KEY")
AUDIO_COHERE_API_KEY = os.getenv("AUDIO_COHERE_API_KEY")
CHAT_COHERE_API_KEY = os.getenv("CHAT_COHERE_API_KEY")

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Initialize Cohere clients
img_prompt_client = cohere.Client(
    api_key=IMG_COHERE_API_KEY,
    client_name="c4ai-aya-expanse-img",
)
chat_client = cohere.Client(
    api_key=CHAT_COHERE_API_KEY,
    client_name="c4ai-aya-expanse-chat",
)
audio_response_client = cohere.Client(
    api_key=AUDIO_COHERE_API_KEY,
    client_name="c4ai-aya-expanse-audio",
)

# Initialize the Groq client
groq_client = Groq(api_key=GROQ_API_KEY)

# Initialize the ElevenLabs client
eleven_labs_client = ElevenLabs(
    api_key=ELEVEN_LABS_KEY,
)

# Language identification
lid_model_path = hf_hub_download(repo_id="facebook/fasttext-language-identification", filename="model.bin")
LID_model = fasttext.load_model(lid_model_path)


def predict_language(text):
    # fastText predicts on a single line of text, so replace newlines with spaces
    text = re.sub("\n", " ", text)
    label, logit = LID_model.predict(text)
    # labels look like "__label__eng_Latn"; strip the prefix to keep only the language code
    label = label[0][len("__label__") :]
    print("predicted language:", label)
    return label


# Image Generation util functions
def get_hf_inference_api_response(payload, model_id):
    headers = {"Authorization": f"Bearer {HF_API_TOKEN}"}
    MODEL_API_URL = f"https://api-inference.huggingface.co/models/{model_id}"
    response = requests.post(MODEL_API_URL, headers=headers, json=payload)
    return response.content


def replicate_api_inference(input_prompt):
    input_params = {
        "prompt": input_prompt,
        "go_fast": True,
        "megapixels": "1",
        "num_outputs": 1,
        "aspect_ratio": "1:1",
        "output_format": "jpg",
        "output_quality": 80,
        "enable_safety_checker": True,
        "safety_tolerance": 1,
        "num_inference_steps": 4,
    }
    image = replicate.run("black-forest-labs/flux-schnell", input=input_params)
    image = Image.open(image[0])
    return image


def generate_image(input_prompt, model_id="black-forest-labs/FLUX.1-schnell"):
    if input_prompt:
        if USE_REPLICATE:
            print("using replicate for image generation")
            image = replicate_api_inference(input_prompt)
        else:
            try:
                print("using HF inference API for image generation")
                image_bytes = get_hf_inference_api_response({"inputs": input_prompt}, model_id)
                image = np.array(Image.open(io.BytesIO(image_bytes)))
            except Exception as e:
                print("HF API error:", e)
                # fall back to Replicate if the HF Inference API call fails
                image = replicate_api_inference(input_prompt)
        return image
    else:
        return None


def generate_img_prompt(input_prompt):
    if input_prompt:
        # clean prompt before doing language detection
        cleaned_prompt = clean_text(input_prompt, remove_bullets=True, remove_newline=True)
        text_lang_code = predict_language(cleaned_prompt)

        gr.Info("Generating Image", duration=2)

        if text_lang_code != "eng_Latn":
            text = f"""
            Translate the given input prompt to English.
            Input Prompt: {input_prompt}
            Then based on the English translation of the prompt, generate a detailed image description which can be used to generate an image using a text-to-image model.
            Do not use more than 3-4 lines for the image description. Respond with only the image description.
            """
        else:
            text = f"""Generate a detailed image description which can be used to generate an image using a text-to-image model based on the given input prompt:
            Input Prompt: {input_prompt}
            Do not use more than 3-4 lines for the description.
            """

        response = img_prompt_client.chat(message=text, preamble=IMG_DESCRIPTION_PREAMBLE, model=AYA_MODEL_NAME)
        output = response.text

        return output
    else:
        return None


# Chat with Aya util functions
def trigger_example(example):
    # generate_aya_chat_response is a generator that also needs a conversation id
    # and a token; drain it and return the final Chatbot value (the only output
    # wired to gr.Examples below).
    chat = []
    for chat, _history, _cid in generate_aya_chat_response(example, "", None):
        pass
    return chat


def generate_aya_chat_response(user_message, cid, token, history=None):
    if not token:
        print("no token")
        # raise gr.Error("Error loading.")

    if history is None:
        history = []
    if not cid:
        cid = str(uuid.uuid4())

    print(f"cid: {cid} prompt:{user_message}")

    history.append(user_message)

    stream = chat_client.chat_stream(
        message=user_message,
        preamble=CHAT_PREAMBLE,
        conversation_id=cid,
        model=AYA_MODEL_NAME,
        connectors=[],
        temperature=0.3,
    )

    output = ""
    chat = []

    for idx, response in enumerate(stream):
        if response.event_type == "text-generation":
            output += response.text
        if idx == 0:
            # first streamed event: open a reply slot next to the user message
            history.append(" " + output)
        else:
            # later events: overwrite the reply slot with the text streamed so far
            history[-1] = output
        chat = [
            (history[i].strip(), history[i + 1].strip())
            for i in range(0, len(history) - 1, 2)
        ]
        yield chat, history, cid

    return chat, history, cid


def clear_chat():
    return [], [], str(uuid.uuid4())


# Audio Pipeline util functions
def transcribe_and_stream(inputs, model_name="groq_whisper", show_info="show_info", language="english"):
    if inputs:
        if show_info == "show_info":
            gr.Info("Processing Audio", duration=1)
        if model_name != "groq_whisper":
            print("DEVICE:", DEVICE)
            pipe = pipeline(
                task="automatic-speech-recognition",
                model=model_name,
                chunk_length_s=30,
                device=DEVICE,
            )
            text = pipe(inputs, batch_size=BATCH_SIZE, return_timestamps=True)["text"]
        else:
            text = groq_whisper_tts(inputs)

        # stream text output
        for i in range(len(text)):
            time.sleep(0.01)
            yield text[: i + 10]
    else:
        # this is a generator, so yield (not return) the empty value to clear the output
        yield ""


def aya_speech_text_response(text):
    if text:
        stream = audio_response_client.chat_stream(
            message=text,
            preamble=AUDIO_RESPONSE_PREAMBLE,
            model=AYA_MODEL_NAME,
        )
        output = ""

        for event in stream:
            if event:
                if event.event_type == "text-generation":
                    output += event.text
                    cleaned_output = clean_text(output)
                    yield cleaned_output
    else:
        # this is a generator, so yield (not return) the empty value to clear the output
        yield ""


def clean_text(text, remove_bullets=False, remove_newline=False):
    # Remove bold formatting
    cleaned_text = re.sub(r"\*\*", "", text)

    if remove_bullets:
        cleaned_text = re.sub(r"^- ", "", cleaned_text, flags=re.MULTILINE)

    if remove_newline:
        cleaned_text = re.sub(r"\n", " ", cleaned_text)

    return cleaned_text


def convert_text_to_speech(text, language="english"):
    # do language detection to determine the voice of the speech response
    if text:
        # clean text before doing language detection
        cleaned_text = clean_text(text, remove_bullets=True, remove_newline=True)
        text_lang_code = predict_language(cleaned_text)

        if not USE_ELVENLABS:
            if text_lang_code != "jpn_Jpan":
                audio_path = neetsai_tts(text, text_lang_code)
            else:
                print("DEVICE:", DEVICE)
                # if the language is Japanese, use XTTS for TTS since neets.ai doesn't support a Japanese voice
                tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2").to(DEVICE)
                speaker_wav = "samples/ja-sample.wav"
                lang_code = "ja"
                audio_path = "./output.wav"
                tts.tts_to_file(text=text, speaker_wav=speaker_wav, language=lang_code, file_path=audio_path)
        else:
            # use ElevenLabs for TTS
            audio_path = elevenlabs_generate_audio(text)

        return audio_path
    else:
        return None


def elevenlabs_generate_audio(text):
    audio = eleven_labs_client.generate(
        text=text,
        voice="River",
        model="eleven_turbo_v2_5",  # "eleven_multilingual_v2"
    )
    # save audio
    audio_path = "./audio.mp3"
    save(audio, audio_path)
    return audio_path


def neetsai_tts(input_text, text_lang_code):
    if text_lang_code in LID_LANGUAGES.keys():
        language = LID_LANGUAGES[text_lang_code]
    else:
        # use the English voice as default for languages outside the 23 languages of Aya
        language = "english"

    neets_lang_id = NEETS_AI_LANGID_MAP[language]
    neets_vits_voice_id = f"vits-{neets_lang_id}"

    response = requests.request(
        method="POST",
        url="https://api.neets.ai/v1/tts",
        headers={
            "Content-Type": "application/json",
            "X-API-Key": NEETS_AI_API_KEY,
        },
        json={
            "text": input_text,
            "voice_id": neets_vits_voice_id,
            "params": {
                "model": "vits"
            },
        },
    )

    # save audio file
    audio_path = "neets_demo.mp3"
    with open(audio_path, "wb") as f:
        f.write(response.content)
    return audio_path


def groq_whisper_tts(filename):
    # speech-to-text: transcribe the recorded audio file with Whisper on Groq
    with open(filename, "rb") as file:
        transcriptions = groq_client.audio.transcriptions.create(
            file=(filename, file.read()),
            model="whisper-large-v3-turbo",
            response_format="json",
            temperature=0.0,
        )
        print("transcribed text:", transcriptions.text)
        print("********************************")
        return transcriptions.text


# setup gradio app theme
theme = gr.themes.Base(
    primary_hue=gr.themes.colors.teal,
    secondary_hue=gr.themes.colors.blue,
    neutral_hue=gr.themes.colors.gray,
    text_size=gr.themes.sizes.text_lg,
).set(
    # Primary Button Color
    button_primary_background_fill="#114A56",
    button_primary_background_fill_hover="#114A56",
    # Block Labels
    block_title_text_weight="600",
    block_label_text_weight="600",
    block_label_text_size="*text_md",
)

demo = gr.Blocks(theme=theme, analytics_enabled=False)

with demo:
    with gr.Row(variant="panel"):
        with gr.Column(scale=1):
            gr.Image(
                "AyaExpanse.png",
                elem_id="logo-img",
                show_label=False,
                show_share_button=False,
                show_download_button=False,
                show_fullscreen_button=False,
            )
        with gr.Column(scale=30):
            gr.Markdown("""C4AI Aya Expanse is a state-of-the-art model with highly advanced capabilities to connect the world across languages.
You can use this space to chat, speak and visualize with Aya Expanse in 23 languages.
**Model**: [aya-expanse-32B](https://huggingface.co/CohereForAI/aya-expanse-32b)
**Developed by**: [Cohere for AI](https://cohere.com/research) and [Cohere](https://cohere.com/)
**License**: [CC-BY-NC](https://cohere.com/c4ai-cc-by-nc-license), also requires adhering to [C4AI's Acceptable Use Policy](https://docs.cohere.com/docs/c4ai-acceptable-use-policy)"""
            )

    with gr.TabItem("Chat with Aya") as chat_with_aya:
        cid = gr.State("")
        token = gr.State(value=None)

        with gr.Column():
            with gr.Row():
                chatbot = gr.Chatbot(show_label=False, show_share_button=False, show_copy_button=True, height=300)

            with gr.Row():
                user_message = gr.Textbox(lines=1, placeholder="Ask anything in our 23 languages ...", label="Input", show_label=False)

            with gr.Row():
                submit_button = gr.Button("Submit", variant="primary")
                clear_button = gr.Button("Clear")

            history = gr.State([])

            user_message.submit(fn=generate_aya_chat_response, inputs=[user_message, cid, token, history], outputs=[chatbot, history, cid], concurrency_limit=32)
            submit_button.click(fn=generate_aya_chat_response, inputs=[user_message, cid, token, history], outputs=[chatbot, history, cid], concurrency_limit=32)
            clear_button.click(fn=clear_chat, inputs=None, outputs=[chatbot, history, cid], concurrency_limit=32)

            # inputs=None means these callbacks receive no arguments
            user_message.submit(lambda: gr.update(value=""), None, [user_message], queue=False)
            submit_button.click(lambda: gr.update(value=""), None, [user_message], queue=False)
            clear_button.click(lambda: gr.update(value=""), None, [user_message], queue=False)

            with gr.Row():
                gr.Examples(
                    examples=TEXT_CHAT_EXAMPLES,
                    inputs=user_message,
                    cache_examples=False,
                    fn=trigger_example,
                    outputs=[chatbot],
                    examples_per_page=25,
                    label="Load example prompt for:",
                    example_labels=TEXT_CHAT_EXAMPLES_LABELS,
                )

    # End-to-end pipeline for Speak with Aya
    with gr.TabItem("Speak with Aya") as speak_with_aya:
        with gr.Row():
            with gr.Column():
                e2e_audio_file = gr.Audio(sources="microphone", type="filepath", min_length=None)
                e2_audio_submit_button = gr.Button(value="Get Aya's Response", variant="primary")
                clear_button_microphone = gr.ClearButton()
                gr.Examples(
                    examples=AUDIO_EXAMPLES,
                    inputs=e2e_audio_file,
                    cache_examples=False,
                    examples_per_page=25,
                    label="Load example audio for:",
                    example_labels=AUDIO_EXAMPLES_LABELS,
                )

            with gr.Column():
                e2e_audio_file_trans = gr.Textbox(lines=3, label="Your Input", autoscroll=False, show_copy_button=True, interactive=False)
                e2e_audio_file_aya_response = gr.Textbox(lines=3, label="Aya's Response", show_copy_button=True, container=True, interactive=False)
                e2e_aya_audio_response = gr.Audio(type="filepath", label="Aya's Audio Response")
                # show_info = gr.Textbox(value="show_info", visible=False)
                # stt_model = gr.Textbox(value="groq_whisper", visible=False)

        with gr.Accordion("See Details", open=False):
            gr.Markdown("To enable voice interaction with Aya Expanse, this space uses [Whisper large-v3-turbo](https://huggingface.co/openai/whisper-large-v3-turbo) and [Groq](https://groq.com/) for STT and [neets.ai](http://neets.ai/) for TTS.")

    # Generate Images
    with gr.TabItem("Visualize with Aya") as visualize_with_aya:
        with gr.Row():
            with gr.Column():
                input_img_prompt = gr.Textbox(placeholder="Ask anything in our 23 languages ...", label="Describe an image", lines=3)
                # generated_img_desc = gr.Textbox(label="Image Description generated by Aya", interactive=False, lines=3, visible=False)
                submit_button_img = gr.Button(value="Submit", variant="primary")
                clear_button_img = gr.ClearButton()

            with gr.Column():
                generated_img = gr.Image(label="Generated Image", interactive=False)

        with gr.Row():
            gr.Examples(
                examples=IMG_GEN_PROMPT_EXAMPLES,
                inputs=input_img_prompt,
                cache_examples=False,
                examples_per_page=25,
                label="Load example prompt for:",
                example_labels=IMG_GEN_PROMPT_EXAMPLES_LABELS,
            )

        generated_img_desc = gr.Textbox(label="Image Description generated by Aya", interactive=False, lines=3, visible=False)

        # increase spacing between examples and Accordion components
        with gr.Row():
            pass
        with gr.Row():
            pass
        with gr.Row():
            pass

        with gr.Row():
            with gr.Accordion("See Details", open=False):
                gr.Markdown("This space uses Aya Expanse for translating multilingual prompts and generating detailed image descriptions and [Flux Schnell](https://huggingface.co/black-forest-labs/FLUX.1-schnell) for image generation.")

    # Image Generation
    clear_button_img.click(lambda: None, None, input_img_prompt)
    clear_button_img.click(lambda: None, None, generated_img_desc)
    clear_button_img.click(lambda: None, None, generated_img)

    submit_button_img.click(
        generate_img_prompt,
        inputs=[input_img_prompt],
        outputs=[generated_img_desc],
    )

    generated_img_desc.change(
        generate_image,  # run_flux,
        inputs=[generated_img_desc],
        outputs=[generated_img],
        show_progress="full",
    )

    # Audio Pipeline
    clear_button_microphone.click(lambda: None, None, e2e_audio_file)
    clear_button_microphone.click(lambda: None, None, e2e_aya_audio_response)
    clear_button_microphone.click(lambda: None, None, e2e_audio_file_aya_response)
    clear_button_microphone.click(lambda: None, None, e2e_audio_file_trans)

    # e2e_audio_file.change(
    e2_audio_submit_button.click(
        transcribe_and_stream,
        inputs=[e2e_audio_file],
        outputs=[e2e_audio_file_trans],
        show_progress="full",
    ).then(
        aya_speech_text_response,
        inputs=[e2e_audio_file_trans],
        outputs=[e2e_audio_file_aya_response],
        show_progress="full",
    ).then(
        convert_text_to_speech,
        inputs=[e2e_audio_file_aya_response],
        outputs=[e2e_aya_audio_response],
        show_progress="full",
    )

    demo.load(lambda: secrets.token_hex(16), None, token)

demo.queue(api_open=False, max_size=20, default_concurrency_limit=4).launch(show_api=False, allowed_paths=['/home/user/app'])
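

# --- Illustrative sketch, not used by the app above ---
# generate_aya_chat_response keeps the conversation in `history` as one flat
# list that alternates user message, Aya reply, user message, Aya reply, ...
# and rebuilds the gr.Chatbot value as (user, reply) tuples after every
# streamed event. The hypothetical helper below isolates that pairing step so
# it can be checked without calling the Cohere API. It is a minimal sketch
# that mirrors the list comprehension in that function, assuming the same
# flat-list layout; it is not part of Gradio or the Cohere SDK.
def history_to_chat_pairs(history: list[str]) -> list[tuple[str, str]]:
    """Pair a flat [user, reply, user, reply, ...] list into Chatbot tuples."""
    # Stepping by 2 and stopping at len(history) - 1 skips a trailing user
    # message whose reply slot has not been appended yet.
    return [
        (history[i].strip(), history[i + 1].strip())
        for i in range(0, len(history) - 1, 2)
    ]


# Usage: the reply slot opened on the first streamed event is " " + "", which
# strips to an empty reply, so history_to_chat_pairs(["Hola", " "]) gives
# [("Hola", "")] until text-generation events fill it in.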