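# Character Chat Generator: a Gradio app that streams a conversation between two
# user-selected characters, generated by Meta-Llama-3.1-405B-Instruct via the
# SambaNova Cloud API, and lets the user download the resulting transcript.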
import gradio as gr
import random
import openai
from openai import APIError, APIConnectionError, RateLimitError
import os
from PIL import Image
import io
import base64
import asyncio
from queue import Queue
from threading import Thread
import time
# Get the current script's directory
current_dir = os.path.dirname(os.path.abspath(__file__))
avatars_dir = os.path.join(current_dir, "avatars")

# Dictionary mapping characters to their avatar image filenames
character_avatars = {
    "Harry Potter": "harry.png",
    "Hermione Granger": "hermione.png",
    "poor Ph.D. student": "phd.png",
    "a super cute red panda": "red_panda.png"
}

# Backup API keys; skip any that are not set in the environment
BACKUP_API_KEY_0 = os.environ.get('BACKUP_API_KEY_0')
BACKUP_API_KEY_1 = os.environ.get('BACKUP_API_KEY_1')
BACKUP_API_KEYS = [key for key in (BACKUP_API_KEY_0, BACKUP_API_KEY_1) if key]

predefined_characters = ["Harry Potter", "Hermione Granger", "poor Ph.D. student", "a super cute red panda"]
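# Resolve the character name: use the free-text value when "Custom" is selected,
# otherwise use the dropdown choice directly.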
def get_character(dropdown_value, custom_value):
    return custom_value if dropdown_value == "Custom" else dropdown_value
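# Shrink an avatar image to a thumbnail and return it as a base64-encoded PNG
# string, suitable for inline embedding in the generated chat HTML.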
def resize_image(image_path, size=(100, 100)):
    if not os.path.exists(image_path):
        return None
    with Image.open(image_path) as img:
        img.thumbnail(size)
        buffered = io.BytesIO()
        img.save(buffered, format="PNG")
        return base64.b64encode(buffered.getvalue()).decode()

# Pre-compute the avatar thumbnails once at startup
resized_avatars = {}
for character, filename in character_avatars.items():
    full_path = os.path.join(avatars_dir, filename)
    if os.path.exists(full_path):
        resized_avatars[character] = resize_image(full_path)
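# Stream a chat completion from SambaNova Cloud. The user's API key is tried first;
# on a rate limit, each backup key is tried in turn. Yields the accumulated
# response text after every received chunk.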
async def generate_response_stream(messages, user_api_key):
    # Combine the user's API key with the internal backup keys
    api_keys = [user_api_key] + BACKUP_API_KEYS
    for idx, api_key in enumerate(api_keys):
        client = openai.AsyncOpenAI(
            api_key=api_key,
            base_url="https://api.sambanova.ai/v1",
        )
        try:
            response = await client.chat.completions.create(
                model='Meta-Llama-3.1-405B-Instruct',
                messages=messages,
                temperature=0.7,
                top_p=0.9,
                stream=True
            )
            full_response = ""
            async for chunk in response:
                if chunk.choices[0].delta.content:
                    full_response += chunk.choices[0].delta.content
                    yield full_response
            # If successful, exit the generator
            return
        except RateLimitError:
            if idx == len(api_keys) - 1:
                # No more API keys to try
                raise Exception("Rate limit exceeded")
            else:
                # Try the next API key
                continue
        except Exception as e:
            # For any other exception, re-raise the error
            raise e
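# Alternate between the two characters, keeping a separate message history for
# each one (a speaker sees its own lines as "assistant" and the other
# character's lines as "user"), and yield the rendered HTML after every chunk.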
async def simulate_conversation_stream(character1, character2, initial_message, num_turns, api_key):
    messages_character_1 = [
        {"role": "system", "content": f"Avoid overly verbose answers in your responses. Act as {character1}."},
        {"role": "assistant", "content": initial_message}
    ]
    messages_character_2 = [
        {"role": "system", "content": f"Avoid overly verbose answers in your responses. Act as {character2}."},
        {"role": "user", "content": initial_message}
    ]
    conversation = [
        {"character": character1, "content": initial_message},
        # New messages are appended as the loop progresses
    ]
    yield format_conversation_as_html(conversation)

    # Each turn is one exchange (two messages); the initial message counts as the first
    num_turns *= 2
    for turn_num in range(num_turns - 1):
        current_character = character2 if turn_num % 2 == 0 else character1
        messages = messages_character_2 if turn_num % 2 == 0 else messages_character_1
        # Add a new empty message for the current character
        conversation.append({"character": current_character, "content": ""})
        full_response = ""
        try:
            async for response in generate_response_stream(messages, api_key):
                full_response = response
                conversation[-1]["content"] = full_response
                yield format_conversation_as_html(conversation)
            # After a successful response, update both message histories
            if turn_num % 2 == 0:
                messages_character_1.append({"role": "user", "content": full_response})
                messages_character_2.append({"role": "assistant", "content": full_response})
            else:
                messages_character_2.append({"role": "user", "content": full_response})
                messages_character_1.append({"role": "assistant", "content": full_response})
        except Exception as e:
            # Replace the current message with the error message and stop the conversation
            error_message = f"Error: {str(e)}"
            conversation[-1]["character"] = "System"
            conversation[-1]["content"] = error_message
            yield format_conversation_as_html(conversation)
            break
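# Bridge the async generator to a synchronous caller: run the simulation in a
# dedicated event loop and push each HTML frame into a thread-safe queue, with
# None as the completion sentinel.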
def stream_conversation(character1, character2, initial_message, num_turns, api_key, queue):
    async def run_simulation():
        try:
            async for html in simulate_conversation_stream(character1, character2, initial_message, num_turns, api_key):
                queue.put(html)
            queue.put(None)  # Signal that the conversation is complete
        except Exception as e:
            # Handle exceptions by putting the error message in the queue
            error_message = f"Error: {str(e)}"
            queue.put(error_message)
            queue.put(None)  # Signal that the conversation is complete

    # Create a new event loop for this thread
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    loop.run_until_complete(run_simulation())
    loop.close()
def validate_api_key(api_key):
    if not api_key.strip():
        return False, "API key is required. Please enter a valid API key."
    return True, ""

def update_api_key_status(api_key):
    is_valid, message = validate_api_key(api_key)
    if not is_valid:
        return f"<p style='color: red;'>{message}</p>"
    return ""
def chat_interface(character1_dropdown, character1_custom, character2_dropdown, character2_custom,
                   initial_message, num_turns, api_key):
    character1 = get_character(character1_dropdown, character1_custom)
    character2 = get_character(character2_dropdown, character2_custom)

    queue = Queue()
    thread = Thread(target=stream_conversation, args=(character1, character2, initial_message, num_turns, api_key, queue))
    thread.start()

    while True:
        result = queue.get()
        if result is None:
            break
        yield result

    thread.join()
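# Render the conversation as a self-contained block of HTML/CSS: messages from
# the two speakers alternate left/right, each with an inline base64 avatar when
# one is available.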
def format_conversation_as_html(conversation):
    html_output = """
    <style>
        .chat-container {
            display: flex;
            flex-direction: column;
            gap: 10px;
            font-family: Arial, sans-serif;
        }
        .message {
            display: flex;
            padding: 10px;
            border-radius: 10px;
            max-width: 80%;
            align-items: flex-start;
        }
        .left {
            align-self: flex-start;
            background-color: #1565C0;
            color: #FFFFFF;
        }
        .right {
            align-self: flex-end;
            background-color: #2E7D32;
            color: #FFFFFF;
            flex-direction: row-reverse;
        }
        .avatar-container {
            flex-shrink: 0;
            width: 40px;
            height: 40px;
            margin: 0 10px;
        }
        .avatar {
            width: 100%;
            height: 100%;
            border-radius: 50%;
            object-fit: cover;
        }
        .message-content {
            display: flex;
            flex-direction: column;
            min-width: 150px;
            flex-grow: 1;
        }
        .character-name {
            font-weight: bold;
            margin-bottom: 5px;
        }
        .message-text {
            word-wrap: break-word;
            overflow-wrap: break-word;
        }
    </style>
    <div class="chat-container">
    """
    for i, message in enumerate(conversation):
        align = "left" if i % 2 == 0 else "right"
        avatar_data = resized_avatars.get(message["character"])
        html_output += f'<div class="message {align}">'
        if avatar_data:
            html_output += f'''
            <div class="avatar-container">
                <img src="data:image/png;base64,{avatar_data}" class="avatar" alt="{message["character"]} avatar">
            </div>
            '''
        html_output += f'''
            <div class="message-content">
                <div class="character-name">{message["character"]}</div>
                <div class="message-text">{message["content"]}</div>
            </div>
        </div>
        '''
    html_output += "</div>"
    return html_output
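# Convert the rendered HTML back into plain "Speaker: message" lines for the
# downloadable transcript.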
def format_chat_for_download(html_chat):
    # Extract the speaker/message pairs from the HTML
    import re
    chat_text = re.findall(r'<div class="character-name">(.*?)</div>.*?<div class="message-text">(.*?)</div>', html_chat, re.DOTALL)
    return "\n".join([f"{speaker.strip()}: {message.strip()}" for speaker, message in chat_text])
def save_chat_to_file(chat_content):
    # Create a downloads directory if it doesn't exist
    downloads_dir = os.path.join(os.getcwd(), "downloads")
    os.makedirs(downloads_dir, exist_ok=True)

    # Generate a unique, timestamped filename
    import datetime
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    filename = f"chat_{timestamp}.txt"
    file_path = os.path.join(downloads_dir, filename)

    # Save the chat content to the file
    with open(file_path, "w", encoding="utf-8") as f:
        f.write(chat_content)

    return file_path
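# Build the Gradio UI: API-key input with live validation, character pickers with
# optional custom names, a streamed conversation view, and a download button.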
with gr.Blocks() as app:
    gr.Markdown("# Character Chat Generator")
    gr.Markdown("Powered by [Llama-3.1-405B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-405B-Instruct) on [SambaNova Cloud](https://cloud.sambanova.ai/apis)")

    api_key = gr.Textbox(label="Enter your SambaNova Cloud API Key\n(To get one, go to https://cloud.sambanova.ai/apis)", type="password")
    api_key_status = gr.Markdown()

    with gr.Column():
        character1_dropdown = gr.Dropdown(choices=predefined_characters + ["Custom"], label="Select Character 1")
        character1_custom = gr.Textbox(label="Custom Character 1 (if selected above)", visible=False)
    with gr.Column():
        character2_dropdown = gr.Dropdown(choices=predefined_characters + ["Custom"], label="Select Character 2")
        character2_custom = gr.Textbox(label="Custom Character 2 (if selected above)", visible=False)

    initial_message = gr.Textbox(label="Initial message (for Character 1)")
    num_turns = gr.Slider(minimum=1, maximum=10, step=1, value=5, label="Number of conversation turns")

    generate_btn = gr.Button("Generate Conversation")
    output = gr.HTML(label="Generated Conversation")

    def show_custom_input(choice):
        return gr.update(visible=choice == "Custom")

    character1_dropdown.change(show_custom_input, inputs=character1_dropdown, outputs=character1_custom)
    character2_dropdown.change(show_custom_input, inputs=character2_dropdown, outputs=character2_custom)
    api_key.change(update_api_key_status, inputs=[api_key], outputs=[api_key_status])

    generate_btn.click(
        chat_interface,
        inputs=[character1_dropdown, character1_custom, character2_dropdown,
                character2_custom, initial_message, num_turns, api_key],
        outputs=output,
    )

    gr.Markdown("## Download Chat History")
    download_btn = gr.Button("Download Conversation")
    download_output = gr.File(label="Download")

    def download_conversation(html_chat):
        chat_content = format_chat_for_download(html_chat)
        file_path = save_chat_to_file(chat_content)
        return file_path

    download_btn.click(
        download_conversation,
        inputs=output,
        outputs=download_output
    )

    app.load(lambda: update_api_key_status(""), outputs=[api_key_status])

if __name__ == "__main__":
    app.launch()