Anupam251272 committed on
Commit 3cc2780
1 Parent(s): 4cf5f2a

Create app.py

Files changed (1)
app.py +505 -0
app.py ADDED
@@ -0,0 +1,505 @@
+ from __future__ import annotations
+ import os
+ import io
+ import re
+ import time
+ import uuid
+ import torch
+ import cohere
+ import secrets
+ import requests
+ import fasttext
+ import replicate
+ import numpy as np
+ import gradio as gr
+ from PIL import Image
+ from groq import Groq
+ from TTS.api import TTS
+ from elevenlabs import save
+ from gradio.themes.base import Base
+ from elevenlabs.client import ElevenLabs
+ from huggingface_hub import hf_hub_download
+ from gradio.themes.utils import colors, fonts, sizes
+ from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
+ from prompt_examples import TEXT_CHAT_EXAMPLES, IMG_GEN_PROMPT_EXAMPLES, AUDIO_EXAMPLES, TEXT_CHAT_EXAMPLES_LABELS, IMG_GEN_PROMPT_EXAMPLES_LABELS, AUDIO_EXAMPLES_LABELS
+ from preambles import CHAT_PREAMBLE, AUDIO_RESPONSE_PREAMBLE, IMG_DESCRIPTION_PREAMBLE
+ from constants import LID_LANGUAGES, NEETS_AI_LANGID_MAP, AYA_MODEL_NAME, BATCH_SIZE, USE_ELVENLABS, USE_REPLICATE
+
+
+ HF_API_TOKEN = os.getenv("HF_API_KEY")
+ ELEVEN_LABS_KEY = os.getenv("ELEVEN_LABS_KEY")
+ NEETS_AI_API_KEY = os.getenv("NEETS_AI_API_KEY")
+ GROQ_API_KEY = os.getenv("GROQ_API_KEY")
+ IMG_COHERE_API_KEY = os.getenv("IMG_COHERE_API_KEY")
+ AUDIO_COHERE_API_KEY = os.getenv("AUDIO_COHERE_API_KEY")
+ CHAT_COHERE_API_KEY = os.getenv("CHAT_COHERE_API_KEY")
+
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+
+ # Initialize cohere clients
+ img_prompt_client = cohere.Client(
+     api_key=IMG_COHERE_API_KEY,
+     client_name="c4ai-aya-expanse-img"
+ )
+ chat_client = cohere.Client(
+     api_key=CHAT_COHERE_API_KEY,
+     client_name="c4ai-aya-expanse-chat"
+ )
+ audio_response_client = cohere.Client(
+     api_key=AUDIO_COHERE_API_KEY,
+     client_name="c4ai-aya-expanse-audio"
+ )
+
+ # Initialize the Groq client
+ groq_client = Groq(api_key=GROQ_API_KEY)
+
+ # Initialize the ElevenLabs client
+ eleven_labs_client = ElevenLabs(
+     api_key=ELEVEN_LABS_KEY,
+ )
+
+ # Language identification
+ lid_model_path = hf_hub_download(repo_id="facebook/fasttext-language-identification", filename="model.bin")
+ LID_model = fasttext.load_model(lid_model_path)
+
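+ # predict_language() returns a fastText language label such as "eng_Latn"; the label is used
+ # below to decide whether an image prompt needs translation and which TTS voice to use.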
+ def predict_language(text):
+     text = re.sub("\n", " ", text)
+     label, logit = LID_model.predict(text)
+     label = label[0][len("__label__") :]
+     print("predicted language:", label)
+     return label
+
+ # Image Generation util functions
+ def get_hf_inference_api_response(payload, model_id):
+     headers = {"Authorization": f"Bearer {HF_API_TOKEN}"}
+     MODEL_API_URL = f"https://api-inference.huggingface.co/models/{model_id}"
+     response = requests.post(MODEL_API_URL, headers=headers, json=payload)
+     return response.content
+
+ def replicate_api_inference(input_prompt):
+     input_params = {
+         "prompt": input_prompt,
+         "go_fast": True,
+         "megapixels": "1",
+         "num_outputs": 1,
+         "aspect_ratio": "1:1",
+         "output_format": "jpg",
+         "output_quality": 80,
+         "enable_safety_checker": True,
+         "safety_tolerance": 1,
+         "num_inference_steps": 4
+     }
+     image = replicate.run("black-forest-labs/flux-schnell", input=input_params)
+     image = Image.open(image[0])
+     return image
+
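+ # generate_image() uses Replicate directly when USE_REPLICATE is set; otherwise it calls the
+ # Hugging Face Inference API and falls back to Replicate if that call fails.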
+ def generate_image(input_prompt, model_id="black-forest-labs/FLUX.1-schnell"):
+     if input_prompt:
+         if USE_REPLICATE:
+             print("using replicate for image generation")
+             image = replicate_api_inference(input_prompt)
+         else:
+             try:
+                 print("using HF inference API for image generation")
+                 image_bytes = get_hf_inference_api_response({"inputs": input_prompt}, model_id)
+                 image = np.array(Image.open(io.BytesIO(image_bytes)))
+             except Exception as e:
+                 print("HF API error:", e)
+                 # fall back to Replicate if the HF Inference API call fails
+                 image = replicate_api_inference(input_prompt)
+         return image
+     else:
+         return None
+
+ def generate_img_prompt(input_prompt):
+     if input_prompt:
+         # clean prompt before doing language detection
+         cleaned_prompt = clean_text(input_prompt, remove_bullets=True, remove_newline=True)
+         text_lang_code = predict_language(cleaned_prompt)
+
+         gr.Info("Generating Image", duration=2)
+
+         if text_lang_code != "eng_Latn":
+             text = f"""
+             Translate the given input prompt to English.
+             Input Prompt: {input_prompt}
+             Then based on the English translation of the prompt, generate a detailed image description which can be used to generate an image using a text-to-image model.
+             Do not use more than 3-4 lines for the image description. Respond with only the image description.
+             """
+         else:
+             text = f"""Generate a detailed image description which can be used to generate an image using a text-to-image model based on the given input prompt:
+             Input Prompt: {input_prompt}
+             Do not use more than 3-4 lines for the description.
+             """
+
+         response = img_prompt_client.chat(message=text, preamble=IMG_DESCRIPTION_PREAMBLE, model=AYA_MODEL_NAME)
+         output = response.text
+
+         return output
+     else:
+         return None
+
+
+ # Chat with Aya util functions
+
+ def trigger_example(example):
+     chat, updated_history = generate_aya_chat_response(example)
+     return chat, updated_history
+
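+ # generate_aya_chat_response() streams tokens from the Cohere chat endpoint. The flat `history`
+ # list alternates user and model turns; on each streamed chunk it is re-paired into
+ # (user, assistant) tuples for the gr.Chatbot component and yielded with the conversation id.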
+ def generate_aya_chat_response(user_message, cid, token, history=None):
+     if not token:
+         print("no token")
+         # raise gr.Error("Error loading.")
+
+     if history is None:
+         history = []
+     if cid == "" or cid is None:
+         cid = str(uuid.uuid4())
+
+     print(f"cid: {cid} prompt:{user_message}")
+
+     history.append(user_message)
+
+     stream = chat_client.chat_stream(message=user_message, preamble=CHAT_PREAMBLE, conversation_id=cid, model=AYA_MODEL_NAME, connectors=[], temperature=0.3)
+     output = ""
+
+     for idx, response in enumerate(stream):
+         if response.event_type == "text-generation":
+             output += response.text
+         if idx == 0:
+             history.append(" " + output)
+         else:
+             history[-1] = output
+         chat = [
+             (history[i].strip(), history[i + 1].strip())
+             for i in range(0, len(history) - 1, 2)
+         ]
+         yield chat, history, cid
+
+     return chat, history, cid
+
+
+ def clear_chat():
+     return [], [], str(uuid.uuid4())
+
+ # Audio Pipeline util functions
+
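+ # transcribe_and_stream() turns recorded audio into text, either with Groq's hosted Whisper
+ # (the default) or a local transformers ASR pipeline, then streams the transcript into the
+ # textbox a few characters at a time.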
+ def transcribe_and_stream(inputs, model_name="groq_whisper", show_info="show_info", language="english"):
+     if inputs:
+         if show_info == "show_info":
+             gr.Info("Processing Audio", duration=1)
+         if model_name != "groq_whisper":
+             print("DEVICE:", DEVICE)
+             pipe = pipeline(
+                 task="automatic-speech-recognition",
+                 model=model_name,
+                 chunk_length_s=30,
+                 device=DEVICE)
+             text = pipe(inputs, batch_size=BATCH_SIZE, return_timestamps=True)["text"]
+         else:
+             text = groq_whisper_tts(inputs)
+
+         # stream text output
+         for i in range(len(text)):
+             time.sleep(0.01)
+             yield text[: i + 10]
+     else:
+         return ""
+
+
+ def aya_speech_text_response(text):
+     if text:
+         stream = audio_response_client.chat_stream(message=text, preamble=AUDIO_RESPONSE_PREAMBLE, model=AYA_MODEL_NAME)
+         output = ""
+
+         for event in stream:
+             if event:
+                 if event.event_type == "text-generation":
+                     output += event.text
+                     cleaned_output = clean_text(output)
+                     yield cleaned_output
+     else:
+         return ""
+
+ def clean_text(text, remove_bullets=False, remove_newline=False):
+     # Remove bold formatting
+     cleaned_text = re.sub(r"\*\*", "", text)
+
+     if remove_bullets:
+         cleaned_text = re.sub(r"^- ", "", cleaned_text, flags=re.MULTILINE)
+
+     if remove_newline:
+         cleaned_text = re.sub(r"\n", " ", cleaned_text)
+
+     return cleaned_text
+
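+ # convert_text_to_speech() picks a voice based on the detected language of Aya's reply:
+ # neets.ai VITS voices for supported languages, local XTTS v2 for Japanese (not covered by
+ # neets.ai), or ElevenLabs when USE_ELVENLABS is enabled.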
+ def convert_text_to_speech(text, language="english"):
+
+     # do language detection to determine the voice of the speech response
+     if text:
+         # clean text before doing language detection
+         cleaned_text = clean_text(text, remove_bullets=True, remove_newline=True)
+         text_lang_code = predict_language(cleaned_text)
+
+         if not USE_ELVENLABS:
+             if text_lang_code != "jpn_Jpan":
+                 audio_path = neetsai_tts(text, text_lang_code)
+             else:
+                 print("DEVICE:", DEVICE)
+                 # if the language is Japanese, use XTTS for TTS since neets.ai doesn't support a Japanese voice
+                 tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2").to(DEVICE)
+                 speaker_wav = "samples/ja-sample.wav"
+                 lang_code = "ja"
+                 audio_path = "./output.wav"
+                 tts.tts_to_file(text=text, speaker_wav=speaker_wav, language=lang_code, file_path=audio_path)
+         else:
+             # use ElevenLabs for TTS
+             audio_path = elevenlabs_generate_audio(text)
+
+         return audio_path
+     else:
+         return None
+
+ def elevenlabs_generate_audio(text):
+     audio = eleven_labs_client.generate(
+         text=text,
+         voice="River",
+         model="eleven_turbo_v2_5",  # "eleven_multilingual_v2"
+     )
+     # save audio
+     audio_path = "./audio.mp3"
+     save(audio, audio_path)
+     return audio_path
+
+ def neetsai_tts(input_text, text_lang_code):
+
+     if text_lang_code in LID_LANGUAGES.keys():
+         language = LID_LANGUAGES[text_lang_code]
+     else:
+         # use the English voice as the default for languages outside Aya's 23 languages
+         language = "english"
+
+     neets_lang_id = NEETS_AI_LANGID_MAP[language]
+     neets_vits_voice_id = f"vits-{neets_lang_id}"
+
+     response = requests.request(
+         method="POST",
+         url="https://api.neets.ai/v1/tts",
+         headers={
+             "Content-Type": "application/json",
+             "X-API-Key": NEETS_AI_API_KEY
+         },
+         json={
+             "text": input_text,
+             "voice_id": neets_vits_voice_id,
+             "params": {
+                 "model": "vits"
+             }
+         }
+     )
+     # save audio file
+     audio_path = "neets_demo.mp3"
+     with open(audio_path, "wb") as f:
+         f.write(response.content)
+     return audio_path
+
+ def groq_whisper_tts(filename):
+     with open(filename, "rb") as file:
+         transcriptions = groq_client.audio.transcriptions.create(
+             file=(filename, file.read()),
+             model="whisper-large-v3-turbo",
+             response_format="json",
+             temperature=0.0
+         )
+     print("transcribed text:", transcriptions.text)
+     print("********************************")
+     return transcriptions.text
+
+
+ # setup gradio app theme
+ theme = gr.themes.Base(
+     primary_hue=gr.themes.colors.teal,
+     secondary_hue=gr.themes.colors.blue,
+     neutral_hue=gr.themes.colors.gray,
+     text_size=gr.themes.sizes.text_lg,
+ ).set(
+     # Primary Button Color
+     button_primary_background_fill="#114A56",
+     button_primary_background_fill_hover="#114A56",
+     # Block Labels
+     block_title_text_weight="600",
+     block_label_text_weight="600",
+     block_label_text_size="*text_md",
+ )
+
+
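+ # UI layout: a header row (logo and model description) followed by three tabs:
+ # "Chat with Aya" (streaming text chat), "Speak with Aya" (speech-to-text, Aya reply, text-to-speech),
+ # and "Visualize with Aya" (Aya-generated image description fed into Flux image generation).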
+ demo = gr.Blocks(theme=theme, analytics_enabled=False)
+
+ with demo:
+     with gr.Row(variant="panel"):
+         with gr.Column(scale=1):
+             gr.Image("AyaExpanse.png", elem_id="logo-img", show_label=False, show_share_button=False, show_download_button=False, show_fullscreen_button=False)
+         with gr.Column(scale=30):
+             gr.Markdown("""C4AI Aya Expanse is a state-of-the-art model with highly advanced capabilities to connect the world across languages.
+             <br/>
+             You can use this space to chat, speak and visualize with Aya Expanse in 23 languages.
+             <br/>
+             **Model**: [aya-expanse-32B](https://huggingface.co/CohereForAI/aya-expanse-32b)
+             <br/>
+             **Developed by**: [Cohere for AI](https://cohere.com/research) and [Cohere](https://cohere.com/)
+             <br/>
+             **License**: [CC-BY-NC](https://cohere.com/c4ai-cc-by-nc-license), requires also adhering to [C4AI's Acceptable Use Policy](https://docs.cohere.com/docs/c4ai-acceptable-use-policy)
+             """
+             )
+
+     with gr.TabItem("Chat with Aya") as chat_with_aya:
+         cid = gr.State("")
+         token = gr.State(value=None)
+
+         with gr.Column():
+             with gr.Row():
+                 chatbot = gr.Chatbot(show_label=False, show_share_button=False, show_copy_button=True, height=300)
+
+             with gr.Row():
+                 user_message = gr.Textbox(lines=1, placeholder="Ask anything in our 23 languages ...", label="Input", show_label=False)
+
+             with gr.Row():
+                 submit_button = gr.Button("Submit", variant="primary")
+                 clear_button = gr.Button("Clear")
+
+             history = gr.State([])
+
+             user_message.submit(fn=generate_aya_chat_response, inputs=[user_message, cid, token, history], outputs=[chatbot, history, cid], concurrency_limit=32)
+             submit_button.click(fn=generate_aya_chat_response, inputs=[user_message, cid, token, history], outputs=[chatbot, history, cid], concurrency_limit=32)
+
+             clear_button.click(fn=clear_chat, inputs=None, outputs=[chatbot, history, cid], concurrency_limit=32)
+
+             user_message.submit(lambda: gr.update(value=""), None, [user_message], queue=False)
+             submit_button.click(lambda: gr.update(value=""), None, [user_message], queue=False)
+             clear_button.click(lambda: gr.update(value=""), None, [user_message], queue=False)
+
+             with gr.Row():
+                 gr.Examples(
+                     examples=TEXT_CHAT_EXAMPLES,
+                     inputs=user_message,
+                     cache_examples=False,
+                     fn=trigger_example,
+                     outputs=[chatbot],
+                     examples_per_page=25,
+                     label="Load example prompt for:",
+                     example_labels=TEXT_CHAT_EXAMPLES_LABELS,
+                 )
+
+     # End-to-end pipeline for Speak with Aya
+     with gr.TabItem("Speak with Aya") as speak_with_aya:
+
+         with gr.Row():
+             with gr.Column():
+                 e2e_audio_file = gr.Audio(sources="microphone", type="filepath", min_length=None)
+                 e2_audio_submit_button = gr.Button(value="Get Aya's Response", variant="primary")
+
+                 clear_button_microphone = gr.ClearButton()
+                 gr.Examples(
+                     examples=AUDIO_EXAMPLES,
+                     inputs=e2e_audio_file,
+                     cache_examples=False,
+                     examples_per_page=25,
+                     label="Load example audio for:",
+                     example_labels=AUDIO_EXAMPLES_LABELS,
+                 )
+
+             with gr.Column():
+                 e2e_audio_file_trans = gr.Textbox(lines=3, label="Your Input", autoscroll=False, show_copy_button=True, interactive=False)
+                 e2e_audio_file_aya_response = gr.Textbox(lines=3, label="Aya's Response", show_copy_button=True, container=True, interactive=False)
+                 e2e_aya_audio_response = gr.Audio(type="filepath", label="Aya's Audio Response")
+
+         # show_info = gr.Textbox(value="show_info", visible=False)
+         # stt_model = gr.Textbox(value="groq_whisper", visible=False)
+
+         with gr.Accordion("See Details", open=False):
+             gr.Markdown("To enable voice interaction with Aya Expanse, this space uses [Whisper large-v3-turbo](https://huggingface.co/openai/whisper-large-v3-turbo) and [Groq](https://groq.com/) for STT and [neets.ai](http://neets.ai/) for TTS.")
+
+
+     # Generate Images
+     with gr.TabItem("Visualize with Aya") as visualize_with_aya:
+         with gr.Row():
+             with gr.Column():
+                 input_img_prompt = gr.Textbox(placeholder="Ask anything in our 23 languages ...", label="Describe an image", lines=3)
+                 # generated_img_desc = gr.Textbox(label="Image Description generated by Aya", interactive=False, lines=3, visible=False)
+                 submit_button_img = gr.Button(value="Submit", variant="primary")
+                 clear_button_img = gr.ClearButton()
+
+             with gr.Column():
+                 generated_img = gr.Image(label="Generated Image", interactive=False)
+
+         with gr.Row():
+             gr.Examples(
+                 examples=IMG_GEN_PROMPT_EXAMPLES,
+                 inputs=input_img_prompt,
+                 cache_examples=False,
+                 examples_per_page=25,
+                 label="Load example prompt for:",
+                 example_labels=IMG_GEN_PROMPT_EXAMPLES_LABELS
+             )
+             generated_img_desc = gr.Textbox(label="Image Description generated by Aya", interactive=False, lines=3, visible=False)
+
+         # increase spacing between examples and Accordion components
+         with gr.Row():
+             pass
+         with gr.Row():
+             pass
+         with gr.Row():
+             pass
+
+         with gr.Row():
+             with gr.Accordion("See Details", open=False):
+                 gr.Markdown("This space uses Aya Expanse for translating multilingual prompts and generating detailed image descriptions, and [Flux Schnell](https://huggingface.co/black-forest-labs/FLUX.1-schnell) for image generation.")
+
+     # Image Generation
+     clear_button_img.click(lambda: None, None, input_img_prompt)
+     clear_button_img.click(lambda: None, None, generated_img_desc)
+     clear_button_img.click(lambda: None, None, generated_img)
+
+     submit_button_img.click(
+         generate_img_prompt,
+         inputs=[input_img_prompt],
+         outputs=[generated_img_desc],
+     )
+
+     generated_img_desc.change(
+         generate_image,  # run_flux,
+         inputs=[generated_img_desc],
+         outputs=[generated_img],
+         show_progress="full",
+     )
+
+     # Audio Pipeline
+     clear_button_microphone.click(lambda: None, None, e2e_audio_file)
+     clear_button_microphone.click(lambda: None, None, e2e_aya_audio_response)
+     clear_button_microphone.click(lambda: None, None, e2e_audio_file_aya_response)
+     clear_button_microphone.click(lambda: None, None, e2e_audio_file_trans)
+
+     # e2e_audio_file.change(
+     e2_audio_submit_button.click(
+         transcribe_and_stream,
+         inputs=[e2e_audio_file],
+         outputs=[e2e_audio_file_trans],
+         show_progress="full",
+     ).then(
+         aya_speech_text_response,
+         inputs=[e2e_audio_file_trans],
+         outputs=[e2e_audio_file_aya_response],
+         show_progress="full",
+     ).then(
+         convert_text_to_speech,
+         inputs=[e2e_audio_file_aya_response],
+         outputs=[e2e_aya_audio_response],
+         show_progress="full",
+     )
+
+     demo.load(lambda: secrets.token_hex(16), None, token)
+
+ demo.queue(api_open=False, max_size=20, default_concurrency_limit=4).launch(show_api=False, allowed_paths=['/home/user/app'])