# chat_generator / app.py
# Author: bol20162021
# Commit 8013bb6 ("Update app.py", verified)
import gradio as gr
import random
import openai
from openai import APIError, APIConnectionError, RateLimitError
import os
from PIL import Image # This is the corrected import
import io
import base64
import asyncio
from queue import Queue
from threading import Thread
import time
# Resolve asset locations relative to this script rather than the working directory.
current_dir = os.path.dirname(os.path.abspath(__file__))
avatars_dir = os.path.join(current_dir, "avatars")

# Built-in characters mapped to the avatar file (inside avatars_dir) shown for each.
character_avatars = {
    "Harry Potter": "harry.png",
    "Hermione Granger": "hermione.png",
    "poor Ph.D. student": "phd.png",
    "a super cute red panda": "red_panda.png",
}

# Optional fallback API keys from the environment; an unset variable yields None.
BACKUP_API_KEY_0 = os.getenv('BACKUP_API_KEY_0')
BACKUP_API_KEY_1 = os.getenv('BACKUP_API_KEY_1')
BACKUP_API_KEYS = [BACKUP_API_KEY_0, BACKUP_API_KEY_1]

# Dropdown choices: the built-in character names, in the same order as above.
predefined_characters = list(character_avatars)
def get_character(dropdown_value, custom_value):
    """Resolve the character name selected in the UI.

    Returns *custom_value* when the dropdown is set to ``"Custom"``; otherwise
    returns the dropdown selection unchanged (which may be ``None`` if nothing
    was selected).
    """
    if dropdown_value != "Custom":
        return dropdown_value
    return custom_value
def resize_image(image_path, size=(100, 100)):
    """Return a base64-encoded PNG thumbnail of the image at *image_path*.

    The image is shrunk with PIL's ``Image.thumbnail`` so that it fits within
    *size*, then re-encoded as PNG. Returns ``None`` if the file does not exist.
    """
    if os.path.exists(image_path):
        png_buffer = io.BytesIO()
        with Image.open(image_path) as picture:
            picture.thumbnail(size)
            picture.save(png_buffer, format="PNG")
        return base64.b64encode(png_buffer.getvalue()).decode()
    return None
# Pre-render each available avatar as a base64 PNG thumbnail keyed by character
# name; characters whose image file is missing simply get no entry.
resized_avatars = {}
for character, filename in character_avatars.items():
    avatar_path = os.path.join(avatars_dir, filename)
    if not os.path.exists(avatar_path):
        continue
    resized_avatars[character] = resize_image(avatar_path)
async def generate_response_stream(messages, user_api_key):
    """Stream a chat completion from SambaNova, falling back across API keys.

    The user's key is tried first, then each configured backup key. The next
    key is tried only when the provider answers with a rate-limit error.

    Args:
        messages: OpenAI-style chat messages (``{"role": ..., "content": ...}`` dicts).
        user_api_key: API key entered by the user.

    Yields:
        The cumulative response text each time a chunk adds new content.

    Raises:
        Exception: ``"Rate limit exceeded"`` (chained to the provider error)
            when every key is rate limited. Any other client error propagates
            unchanged.
    """
    # Unset backup keys arrive as None. Passing None to the client makes it fall
    # back to OPENAI_API_KEY (or raise outside the rate-limit handling below),
    # so only keys that are actually configured are used.
    api_keys = [user_api_key] + [key for key in BACKUP_API_KEYS if key]
    last_index = len(api_keys) - 1
    for idx, api_key in enumerate(api_keys):
        # The context manager closes the client's HTTP resources whether we
        # succeed, move on to another key, or the stream is abandoned.
        async with openai.AsyncOpenAI(
            api_key=api_key,
            base_url="https://api.sambanova.ai/v1",
        ) as client:
            try:
                response = await client.chat.completions.create(
                    model='Meta-Llama-3.1-405B-Instruct',
                    messages=messages,
                    temperature=0.7,
                    top_p=0.9,
                    stream=True
                )
                full_response = ""
                async for chunk in response:
                    # Some OpenAI-compatible servers may send chunks with an
                    # empty choices list; skip those instead of raising IndexError.
                    if chunk.choices and chunk.choices[0].delta.content:
                        full_response += chunk.choices[0].delta.content
                        yield full_response
                # Success: do not try any further keys.
                return
            except RateLimitError as err:
                if idx == last_index:
                    # No more API keys to try.
                    raise Exception("Rate limit exceeded") from err
                # Otherwise continue with the next key.
async def simulate_conversation_stream(character1, character2, initial_message, num_turns, api_key):
    """Simulate a two-character chat, yielding the transcript as HTML after each update.

    Character 1 "says" *initial_message*. The characters then alternate, and
    each reply is generated by the model while role-playing the speaker. Every
    turn contributes two messages, so up to ``2 * num_turns`` messages are
    shown, including the opening line.

    Each character keeps its own history. Its own lines are ``assistant``
    messages and the other character's lines are ``user`` messages, so the
    model always speaks as the current character.

    Args:
        character1: Name/description used in character 1's system prompt.
        character2: Name/description used in character 2's system prompt.
        initial_message: Opening line attributed to character 1.
        num_turns: Number of turns. Converted with ``int()`` because UI
            sliders may deliver a float, which ``range`` rejects.
        api_key: User's API key, forwarded to ``generate_response_stream``.

    Yields:
        HTML produced by ``format_conversation_as_html``. If generation fails,
        the in-progress message is replaced by a ``"System"`` error message and
        the simulation stops.
    """
    messages_character_1 = [
        {"role": "system", "content": f"Avoid overly verbose answer in your response. Act as {character1}."},
        {"role": "assistant", "content": initial_message},
    ]
    messages_character_2 = [
        {"role": "system", "content": f"Avoid overly verbose answer in your response. Act as {character2}."},
        {"role": "user", "content": initial_message},
    ]
    conversation = [{"character": character1, "content": initial_message}]
    yield format_conversation_as_html(conversation)

    # The opening line is already shown, so 2 * num_turns - 1 replies remain.
    total_replies = int(num_turns) * 2 - 1
    for turn_num in range(total_replies):
        # Character 2 answers on even turns, character 1 on odd turns.
        if turn_num % 2 == 0:
            current_character = character2
            speaker_history, listener_history = messages_character_2, messages_character_1
        else:
            current_character = character1
            speaker_history, listener_history = messages_character_1, messages_character_2

        # Placeholder message that is filled in as the response streams.
        conversation.append({"character": current_character, "content": ""})
        full_response = ""
        try:
            async for response in generate_response_stream(speaker_history, api_key):
                full_response = response
                conversation[-1]["content"] = full_response
                yield format_conversation_as_html(conversation)
        except Exception as e:
            # Replace the current message with the error and stop the conversation.
            conversation[-1]["character"] = "System"
            conversation[-1]["content"] = f"Error: {str(e)}"
            yield format_conversation_as_html(conversation)
            break

        # The reply is the speaker's own line and the listener's incoming line.
        speaker_history.append({"role": "assistant", "content": full_response})
        listener_history.append({"role": "user", "content": full_response})
def stream_conversation(character1, character2, initial_message, num_turns, api_key, queue):
    """Run the async conversation simulation in the current thread, feeding *queue*.

    This is meant as a ``threading.Thread`` target. Each HTML snapshot from
    ``simulate_conversation_stream`` is put on *queue*. If the simulation
    raises, an ``"Error: ..."`` string is put instead. ``None`` is always put
    last as the end-of-stream sentinel that the consumer waits for.
    """
    async def run_simulation():
        async for snapshot in simulate_conversation_stream(character1, character2, initial_message, num_turns, api_key):
            queue.put(snapshot)

    try:
        # asyncio.run creates a fresh event loop for this worker thread,
        # finalizes pending async generators, and closes the loop even on failure.
        asyncio.run(run_simulation())
    except Exception as e:
        queue.put(f"Error: {str(e)}")
    finally:
        # Always signal completion. Without the sentinel, the consumer would
        # block on queue.get() forever if anything escaped above.
        queue.put(None)
def validate_api_key(api_key):
    """Check that an API key was provided.

    Args:
        api_key: Raw text from the API-key textbox (may be ``None``).

    Returns:
        ``(True, "")`` when the key contains non-whitespace characters,
        otherwise ``(False, <user-facing error message>)``.
    """
    # Treat None the same as an empty string instead of crashing on .strip().
    if not api_key or not api_key.strip():
        return False, "API key is required. Please enter a valid API key."
    return True, ""
def update_api_key_status(api_key):
    """Build the status snippet shown below the API-key box.

    Returns an empty string for a valid key. Otherwise returns the validation
    message wrapped in a red ``<p>`` element.
    """
    valid, status_text = validate_api_key(api_key)
    return "" if valid else f"<p style='color: red;'>{status_text}</p>"
def chat_interface(character1_dropdown, character1_custom, character2_dropdown, character2_custom,
                   initial_message, num_turns, api_key):
    """Gradio handler that streams the generated conversation as HTML.

    The simulation runs on a worker thread (``stream_conversation``) that
    pushes HTML snapshots onto a queue. This generator re-yields each snapshot
    until the worker's ``None`` sentinel arrives, then joins the thread.
    """
    speaker_one = get_character(character1_dropdown, character1_custom)
    speaker_two = get_character(character2_dropdown, character2_custom)

    updates = Queue()
    worker = Thread(
        target=stream_conversation,
        args=(speaker_one, speaker_two, initial_message, num_turns, api_key, updates),
    )
    worker.start()

    # iter(callable, sentinel) keeps calling updates.get() until it returns None.
    for snapshot in iter(updates.get, None):
        yield snapshot
    worker.join()
def format_conversation_as_html(conversation):
    """Render a conversation as a styled HTML chat transcript.

    Args:
        conversation: List of ``{"character": ..., "content": ...}`` dicts in
            speaking order. Even-indexed messages are drawn on the left and
            odd-indexed ones on the right, with the character's avatar (from
            ``resized_avatars``) when one exists.

    Returns:
        A self-contained HTML string (inline ``<style>`` plus markup).

    Character names and message text are HTML-escaped because they come from
    user input (custom characters) and model output. Unescaped, they could
    inject markup or break the layout. ``format_chat_for_download`` reverses
    this escaping.
    """
    from html import escape

    parts = ["""
<style>
.chat-container {
    display: flex;
    flex-direction: column;
    gap: 10px;
    font-family: Arial, sans-serif;
}
.message {
    display: flex;
    padding: 10px;
    border-radius: 10px;
    max-width: 80%;
    align-items: flex-start;
}
.left {
    align-self: flex-start;
    background-color: #1565C0;
    color: #FFFFFF;
}
.right {
    align-self: flex-end;
    background-color: #2E7D32;
    color: #FFFFFF;
    flex-direction: row-reverse;
}
.avatar-container {
    flex-shrink: 0;
    width: 40px;
    height: 40px;
    margin: 0 10px;
}
.avatar {
    width: 100%;
    height: 100%;
    border-radius: 50%;
    object-fit: cover;
}
.message-content {
    display: flex;
    flex-direction: column;
    min-width: 150px;
    flex-grow: 1;
}
.character-name {
    font-weight: bold;
    margin-bottom: 5px;
}
.message-text {
    word-wrap: break-word;
    overflow-wrap: break-word;
    /* keep the line breaks that model replies contain */
    white-space: pre-wrap;
}
</style>
<div class="chat-container">
"""]
    for i, message in enumerate(conversation):
        align = "left" if i % 2 == 0 else "right"
        # str() preserves the previous rendering of a missing (None) character name.
        character = escape(str(message["character"]))
        content = escape(str(message["content"]))
        avatar_data = resized_avatars.get(message["character"])
        parts.append(f'<div class="message {align}">')
        if avatar_data:
            parts.append(
                '<div class="avatar-container">'
                f'<img src="data:image/png;base64,{avatar_data}" class="avatar" alt="{character} avatar">'
                '</div>'
            )
        # format_chat_for_download relies on these exact class attributes.
        parts.append(
            '<div class="message-content">'
            f'<div class="character-name">{character}</div>'
            f'<div class="message-text">{content}</div>'
            '</div>'
            '</div>'
        )
    parts.append("</div>")
    return "".join(parts)


def format_chat_for_download(html_chat):
    """Convert an HTML transcript back into plain ``Speaker: message`` lines.

    Args:
        html_chat: HTML as produced by ``format_conversation_as_html``, or
            ``None``/empty when nothing has been generated yet.

    Returns:
        One ``"Speaker: message"`` line per message, joined with newlines.
        Returns an empty string when there is nothing to extract.
    """
    import re
    from html import unescape

    if not html_chat:
        return ""
    pairs = re.findall(
        r'<div class="character-name">(.*?)</div>.*?<div class="message-text">(.*?)</div>',
        html_chat,
        re.DOTALL,
    )
    # Undo the HTML escaping applied when the transcript was rendered.
    return "\n".join(
        f"{unescape(speaker.strip())}: {unescape(message.strip())}" for speaker, message in pairs
    )
def save_chat_to_file(chat_content):
    """Write *chat_content* to a new UTF-8 text file under ``<cwd>/downloads``.

    Args:
        chat_content: Plain-text transcript to save.

    Returns:
        Path of the newly created ``chat_<timestamp>_<random>.txt`` file.
    """
    import datetime
    import tempfile

    # Create the downloads directory (relative to the working directory) if needed.
    downloads_dir = os.path.join(os.getcwd(), "downloads")
    os.makedirs(downloads_dir, exist_ok=True)
    # The timestamp keeps names readable. mkstemp adds a random suffix and
    # creates the file atomically, so two saves in the same second (e.g. two
    # users of a shared app) can no longer overwrite each other's transcript.
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    fd, file_path = tempfile.mkstemp(prefix=f"chat_{timestamp}_", suffix=".txt", dir=downloads_dir)
    with os.fdopen(fd, "w", encoding="utf-8") as f:
        f.write(chat_content)
    return file_path
with gr.Blocks() as app:
    gr.Markdown("# Character Chat Generator")
    gr.Markdown("Powered by [LLama3.1-405B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-405B-Instruct) on [SambaNova Cloud](https://cloud.sambanova.ai/apis)")
    api_key = gr.Textbox(label="Enter your SambaNova Cloud API Key\n(To get one, go to https://cloud.sambanova.ai/apis)", type="password")
    # Receives the HTML warning produced by update_api_key_status.
    api_key_status = gr.Markdown()

    with gr.Column():
        character1_dropdown = gr.Dropdown(choices=predefined_characters + ["Custom"], label="Select Character 1")
        character1_custom = gr.Textbox(label="Custom Character 1 (if selected above)", visible=False)
    with gr.Column():
        character2_dropdown = gr.Dropdown(choices=predefined_characters + ["Custom"], label="Select Character 2")
        character2_custom = gr.Textbox(label="Custom Character 2 (if selected above)", visible=False)

    initial_message = gr.Textbox(label="Initial message (for Character 1)")
    num_turns = gr.Slider(minimum=1, maximum=10, step=1, value=5, label="Number of conversation turns")
    generate_btn = gr.Button("Generate Conversation")
    output = gr.HTML(label="Generated Conversation")

    def show_custom_input(choice):
        """Show the custom-name textbox only while "Custom" is selected."""
        return gr.update(visible=choice == "Custom")

    character1_dropdown.change(show_custom_input, inputs=character1_dropdown, outputs=character1_custom)
    character2_dropdown.change(show_custom_input, inputs=character2_dropdown, outputs=character2_custom)
    api_key.change(update_api_key_status, inputs=[api_key], outputs=[api_key_status])
    generate_btn.click(
        chat_interface,
        inputs=[character1_dropdown, character1_custom, character2_dropdown,
                character2_custom, initial_message, num_turns, api_key],
        outputs=output,
    )

    gr.Markdown("## Download Chat History")
    download_btn = gr.Button("Download Conversation")
    download_output = gr.File(label="Download")

    def download_conversation(html_chat):
        """Save the displayed conversation as a .txt file for the File output.

        Returns the saved file's path. Returns None (no file) when there is no
        conversation text to export, instead of writing an empty file.
        """
        if not html_chat:
            return None
        chat_content = format_chat_for_download(html_chat)
        if not chat_content:
            return None
        return save_chat_to_file(chat_content)

    download_btn.click(
        download_conversation,
        inputs=output,
        outputs=download_output
    )
    # Populate the API-key warning on page load, before the user types anything.
    app.load(lambda: update_api_key_status(""), outputs=[api_key_status])

if __name__ == "__main__":
    app.launch()