bol20162021 commited on
Commit
077e61f
1 Parent(s): ec7e5c9

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +290 -0
app.py ADDED
@@ -0,0 +1,290 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import random
3
+ import openai
4
+ from openai import APIError, APIConnectionError, RateLimitError
5
+ import os
6
+ from PIL import Image # This is the corrected import
7
+ import io
8
+ import base64
9
+ import asyncio
10
+ from queue import Queue
11
+ from threading import Thread
12
+ import time
13
+
14
+ # Get the current script's directory
15
+ current_dir = os.path.dirname(os.path.abspath(__file__))
16
+ avatars_dir = os.path.join(current_dir, "avatars")
17
+
18
+ # Dictionary mapping characters to their avatar image filenames
19
+ character_avatars = {
20
+ "Harry Potter": "harry.png",
21
+ "Hermione Granger": "hermione.png",
22
+ "poor Ph.D. student": "phd.png",
23
+ "Donald Trump": "trump.png",
24
+ "a super cute red panda": "red_panda.png"
25
+ }
26
+
27
+
28
+ predefined_characters = ["Harry Potter", "Hermione Granger", "poor Ph.D. student", "Donald Trump", "a super cute red panda"]
29
+
30
+ def get_character(dropdown_value, custom_value):
31
+ return custom_value if dropdown_value == "Custom" else dropdown_value
32
+
33
+ def resize_image(image_path, size=(100, 100)):
34
+ if not os.path.exists(image_path):
35
+ return None
36
+ with Image.open(image_path) as img:
37
+ img.thumbnail(size)
38
+ buffered = io.BytesIO()
39
+ img.save(buffered, format="PNG")
40
+ return base64.b64encode(buffered.getvalue()).decode()
41
+
42
+ resized_avatars = {}
43
+ for character, filename in character_avatars.items():
44
+ full_path = os.path.join(avatars_dir, filename)
45
+ if os.path.exists(full_path):
46
+ resized_avatars[character] = resize_image(full_path)
47
+ else:
48
+ pass
49
+
50
+ async def generate_response_stream(messages, api_key):
51
+ client = openai.AsyncOpenAI(
52
+ api_key=api_key,
53
+ base_url="https://api.sambanova.ai/v1",
54
+ )
55
+ try:
56
+ if len(messages) >= 10:
57
+ # avoid hitting rate limit
58
+ time.sleep(0.5)
59
+ response = await client.chat.completions.create(
60
+ model='Meta-Llama-3.1-405B-Instruct',
61
+ messages=messages,
62
+ temperature=0.7,
63
+ top_p=0.9,
64
+ stream=True
65
+ )
66
+ full_response = ""
67
+ async for chunk in response:
68
+ if chunk.choices[0].delta.content is not None:
69
+ full_response += chunk.choices[0].delta.content
70
+ yield full_response
71
+ except Exception as e:
72
+ yield f"Error: {str(e)}"
73
+
74
+ async def simulate_conversation_stream(character1, character2, initial_message, num_turns, api_key):
75
+ messages_character_1 = [{"role": "system", "content": f"Avoid overly verbose answer in your response. Act as {character1}."},
76
+ {"role": "assistant", "content": initial_message}]
77
+ messages_character_2 = [{"role": "system", "content": f"Avoid overly verbose answer in your response. Act as {character2}."},
78
+ {"role": "user", "content": initial_message}]
79
+
80
+ conversation = [
81
+ {"character": character1, "content": initial_message},
82
+ {"character": character2, "content": ""} # Initialize with an empty response for character2
83
+ ]
84
+ yield format_conversation_as_html(conversation)
85
+ num_turns *= 2
86
+ for turn_num in range(num_turns - 1):
87
+ current_character = character2 if turn_num % 2 == 0 else character1
88
+ messages = messages_character_2 if turn_num % 2 == 0 else messages_character_1
89
+
90
+ full_response = ""
91
+ async for response in generate_response_stream(messages, api_key):
92
+ full_response = response
93
+ conversation[-1]["content"] = full_response
94
+ yield format_conversation_as_html(conversation)
95
+
96
+ if turn_num % 2 == 0:
97
+ messages_character_1.append({"role": "user", "content": full_response})
98
+ messages_character_2.append({"role": "assistant", "content": full_response})
99
+ else:
100
+ messages_character_2.append({"role": "user", "content": full_response})
101
+ messages_character_1.append({"role": "assistant", "content": full_response})
102
+
103
+ # Add a new empty message for the next turn, if it's not the last turn
104
+ if turn_num < num_turns - 2:
105
+ next_character = character1 if turn_num % 2 == 0 else character2
106
+ conversation.append({"character": next_character, "content": ""})
107
+
108
+ def stream_conversation(character1, character2, initial_message, num_turns, api_key, queue):
109
+ async def run_simulation():
110
+ async for html in simulate_conversation_stream(character1, character2, initial_message, num_turns, api_key):
111
+ queue.put(html)
112
+ queue.put(None) # Signal that the conversation is complete
113
+
114
+ asyncio.run(run_simulation())
115
+
116
+ def chat_interface(character1_dropdown, character1_custom, character2_dropdown, character2_custom,
117
+ initial_message, num_turns, api_key):
118
+
119
+ character1 = get_character(character1_dropdown, character1_custom)
120
+ character2 = get_character(character2_dropdown, character2_custom)
121
+
122
+ queue = Queue()
123
+ thread = Thread(target=stream_conversation, args=(character1, character2, initial_message, num_turns, api_key, queue))
124
+ thread.start()
125
+
126
+ while True:
127
+ result = queue.get()
128
+ if result is None:
129
+ break
130
+ yield result
131
+
132
+ thread.join()
133
+
134
+ def format_conversation_as_html(conversation):
135
+ html_output = """
136
+ <style>
137
+ .chat-container {
138
+ display: flex;
139
+ flex-direction: column;
140
+ gap: 10px;
141
+ font-family: Arial, sans-serif;
142
+ }
143
+ .message {
144
+ display: flex;
145
+ padding: 10px;
146
+ border-radius: 10px;
147
+ max-width: 80%;
148
+ align-items: flex-start;
149
+ }
150
+ .left {
151
+ align-self: flex-start;
152
+ background-color: #1565C0;
153
+ color: #FFFFFF;
154
+ }
155
+ .right {
156
+ align-self: flex-end;
157
+ background-color: #2E7D32;
158
+ color: #FFFFFF;
159
+ flex-direction: row-reverse;
160
+ }
161
+ .avatar-container {
162
+ flex-shrink: 0;
163
+ width: 40px;
164
+ height: 40px;
165
+ margin: 0 10px;
166
+ }
167
+ .avatar {
168
+ width: 100%;
169
+ height: 100%;
170
+ border-radius: 50%;
171
+ object-fit: cover;
172
+ }
173
+ .message-content {
174
+ display: flex;
175
+ flex-direction: column;
176
+ min-width: 150px;
177
+ flex-grow: 1;
178
+ }
179
+ .character-name {
180
+ font-weight: bold;
181
+ margin-bottom: 5px;
182
+ }
183
+ .message-text {
184
+ word-wrap: break-word;
185
+ overflow-wrap: break-word;
186
+ }
187
+ </style>
188
+ <div class="chat-container">
189
+ """
190
+
191
+ for i, message in enumerate(conversation):
192
+ align = "left" if i % 2 == 0 else "right"
193
+ avatar_data = resized_avatars.get(message["character"])
194
+
195
+ html_output += f'<div class="message {align}">'
196
+
197
+ if avatar_data:
198
+ html_output += f'''
199
+ <div class="avatar-container">
200
+ <img src="data:image/png;base64,{avatar_data}" class="avatar" alt="{message["character"]} avatar">
201
+ </div>
202
+ '''
203
+
204
+ html_output += f'''
205
+ <div class="message-content">
206
+ <div class="character-name">{message["character"]}</div>
207
+ <div class="message-text">{message["content"]}</div>
208
+ </div>
209
+ </div>
210
+ '''
211
+
212
+ html_output += "</div>"
213
+ return html_output
214
+
215
+
216
+ def format_chat_for_download(html_chat):
217
+ # Extract text content from HTML
218
+ import re
219
+ chat_text = re.findall(r'<div class="character-name">(.*?)</div>.*?<div class="message-text">(.*?)</div>', html_chat, re.DOTALL)
220
+ return "\n".join([f"{speaker.strip()}: {message.strip()}" for speaker, message in chat_text])
221
+
222
+ def save_chat_to_file(chat_content):
223
+ # Create a downloads directory if it doesn't exist
224
+ downloads_dir = os.path.join(os.getcwd(), "downloads")
225
+ os.makedirs(downloads_dir, exist_ok=True)
226
+
227
+ # Generate a unique filename
228
+ import datetime
229
+ timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
230
+ filename = f"chat_{timestamp}.txt"
231
+ file_path = os.path.join(downloads_dir, filename)
232
+
233
+ # Save the chat content to the file
234
+ with open(file_path, "w", encoding="utf-8") as f:
235
+ f.write(chat_content)
236
+
237
+ return file_path
238
+
239
+
240
+ with gr.Blocks() as app:
241
+ gr.Markdown("# Character Chat Generator")
242
+
243
+ gr.Markdown("Powerd by [LLama3.1-405B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-405B-Instruct) on [SambaNova Cloud](https://cloud.sambanova.ai/apis)")
244
+ api_key = gr.Textbox(label="Enter your Sambanova Cloud API Key\n(To get one, go to https://cloud.sambanova.ai/apis)", type="password")
245
+
246
+ with gr.Column():
247
+ character1_dropdown = gr.Dropdown(choices=predefined_characters + ["Custom"], label="Select Character 1")
248
+ character1_custom = gr.Textbox(label="Custom Character 1 (if selected above)", visible=False)
249
+ with gr.Column():
250
+ character2_dropdown = gr.Dropdown(choices=predefined_characters + ["Custom"], label="Select Character 2")
251
+ character2_custom = gr.Textbox(label="Custom Character 2 (if selected above)", visible=False)
252
+
253
+ initial_message = gr.Textbox(label="Initial message (for Character 1)")
254
+ num_turns = gr.Slider(minimum=1, maximum=10, step=1, value=5, label="Number of conversation turns")
255
+
256
+ generate_btn = gr.Button("Generate Conversation")
257
+ output = gr.HTML(label="Generated Conversation")
258
+
259
+
260
+ def show_custom_input(choice):
261
+ return gr.update(visible=choice == "Custom")
262
+
263
+ character1_dropdown.change(show_custom_input, inputs=character1_dropdown, outputs=character1_custom)
264
+ character2_dropdown.change(show_custom_input, inputs=character2_dropdown, outputs=character2_custom)
265
+
266
+ generate_btn.click(
267
+ chat_interface,
268
+ inputs=[character1_dropdown, character1_custom, character2_dropdown,
269
+ character2_custom, initial_message, num_turns, api_key],
270
+ outputs=output,
271
+ )
272
+
273
+ gr.Markdown("## Download Chat History")
274
+
275
+ download_btn = gr.Button("Download Conversation")
276
+ download_output = gr.File(label="Download")
277
+
278
+ def download_conversation(html_chat):
279
+ chat_content = format_chat_for_download(html_chat)
280
+ file_path = save_chat_to_file(chat_content)
281
+ return file_path
282
+
283
+ download_btn.click(
284
+ download_conversation,
285
+ inputs=output,
286
+ outputs=download_output
287
+ )
288
+
289
+ if __name__ == "__main__":
290
+ app.launch()