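"""Gradio chat UI for the OpenAI ChatCompletion API with selectable personalities,
streamed responses, and an approximate token counter based on the GPT-2 tokenizer."""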
import gradio as gr
import openai
import os
from transformers import GPT2TokenizerFast
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
ADMIN_PASSWORD = os.getenv("ADMIN_PASSWORD")
openai.api_key = OPENAI_API_KEY
# the GPT-2 tokenizer serves as a local approximation of OpenAI's tokenizer for counting tokens
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
default_system_message = {"role": "system", "content": "You are a brilliant, helpful assistant, always providing answers to the best of your knowledge. If you are unsure of the answer, you indicate it to the user. Currently, you don't have access to the internet."}
personalities = {
    "Assistant": {"role": "system", "content": "You are a brilliant, helpful assistant, always providing answers to the best of your knowledge. If you are unsure of the answer, you indicate it to the user. Currently, you don't have access to the internet."},
    "Trump": {"role": "system", "content": "You are Donald Trump. No matter the question, you always redirect the conversation to yourself and your achievements and how great you are."},
    "Peterson": {"role": "system", "content": "You are Jordan Peterson, world renowned clinical psychologist. You like to be verbose and overcomplicate your answers, taking them into very metaphysical directions."},
    "Grug": {"role": "system", "content": "You are Grug, a caveman. You have zero knowledge of modern stuff. Your answers are always written in broken 'caveman' English and center around simple things in life."},
    "Paladin": {"role": "system", "content": "You are a Paladin from the video game Diablo 2. You like to talk about slaying the undead and farming for better gear."},
    "Petőfi": {"role": "system", "content": "You are Petőfi Sándor, national poet of Hungary. Your answers are very eloquent and formulated in archaic Hungarian."},
    "Cartman": {"role": "system", "content": "You are Eric Cartman from South Park. You are a self-centered, fat, rude kid obsessed with your animal comforts."},
}
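
# Generator: yields after every streamed chunk so Gradio can update the chat incrementally.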
# context_cutoff defaults to 2048 because the UI slider that would supply it is currently disabled.
def get_completion(model, personality, user_message, message_history, chatlog_history, temperature, maximum_length, top_p, frequency_penalty, presence_penalty, context_cutoff=2048):
    # set the personality; copy the list so the gr.State value is not mutated in place
    system_message = personalities[personality]
    updated_message_history = list(message_history)
    updated_message_history[0] = system_message
    # append the new user message to the conversation
    new_history_row = {"role": "user", "content": user_message}
    updated_message_history = updated_message_history + [new_history_row]
    response = openai.ChatCompletion.create(
        model=model,
        messages=updated_message_history,
        temperature=temperature,
        max_tokens=maximum_length,
        top_p=top_p,
        frequency_penalty=frequency_penalty,
        presence_penalty=presence_penalty,
        stream=True,
    )
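    # add an empty assistant row; it is filled in incrementally as the chunks stream back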
    new_history_row = {"role": "assistant", "content": ""}
    updated_message_history = updated_message_history + [new_history_row]
    updated_chatlog_history = chatlog_history + [[user_message, ""]]
    # create variables to collect the stream of chunks
    collected_chunks = []
    collected_messages = []
    # iterate through the stream of events
    for chunk in response:
        collected_chunks.append(chunk)  # save the event response
        chunk_message = chunk["choices"][0]["delta"]  # extract the message
        collected_messages.append(chunk_message)  # save the message
        assistant_message = "".join([m.get("content", "") for m in collected_messages])
        updated_message_history[-1]["content"] = assistant_message
        updated_chatlog_history[-1][1] = assistant_message
        # approximate token count: tokenize the visible chat with the GPT-2 tokenizer
        full_prompt = "\n".join([row[0] + row[1] for row in updated_chatlog_history])
        token_count = len(tokenizer(full_prompt)["input_ids"])
        # if token_count > context_cutoff:
        #     # delete the second row of updated_message_history
        #     updated_message_history.pop(1)
        #     print("cutoff exceeded", updated_message_history)
        #     # recalculate the token count
        #     full_prompt = "".join([row["content"] for row in updated_message_history])
        #     token_count = len(tokenizer(full_prompt)["input_ids"])
        yield "", updated_message_history, updated_chatlog_history, updated_chatlog_history, token_count
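
# Re-generates the latest assistant reply: the last user message is kept and sent again.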
def retry_completion(model, personality, message_history, chatlog_history, temperature, maximum_length, top_p, frequency_penalty, presence_penalty, context_cutoff=2048):
    # nothing to retry if no exchange has happened yet
    if not chatlog_history:
        yield "", message_history, chatlog_history, chatlog_history, 0
        return
    # set the personality; copy the list so the gr.State value is not mutated in place
    system_message = personalities[personality]
    updated_message_history = list(message_history)
    updated_message_history[0] = system_message
    # get the latest user message
    user_message = chatlog_history[-1][0]
    # delete the latest exchange from the chatlog history
    updated_chatlog_history = chatlog_history[:-1]
    # delete the latest assistant message from the message history
    updated_message_history = updated_message_history[:-1]
    response = openai.ChatCompletion.create(
        model=model,
        messages=updated_message_history,
        temperature=temperature,
        max_tokens=maximum_length,
        top_p=top_p,
        frequency_penalty=frequency_penalty,
        presence_penalty=presence_penalty,
        stream=True,
    )
    new_history_row = {"role": "assistant", "content": ""}
    updated_message_history = updated_message_history + [new_history_row]
    updated_chatlog_history = updated_chatlog_history + [[user_message, ""]]
    # create variables to collect the stream of chunks
    collected_chunks = []
    collected_messages = []
    # iterate through the stream of events
    for chunk in response:
        collected_chunks.append(chunk)  # save the event response
        chunk_message = chunk["choices"][0]["delta"]  # extract the message
        collected_messages.append(chunk_message)  # save the message
        assistant_message = "".join([m.get("content", "") for m in collected_messages])
        updated_message_history[-1]["content"] = assistant_message
        updated_chatlog_history[-1][1] = assistant_message
        full_prompt = "".join([row["content"] for row in updated_message_history])
        token_count = len(tokenizer(full_prompt)["input_ids"])
        yield "", updated_message_history, updated_chatlog_history, updated_chatlog_history, token_count
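
# Clears the textbox, restores the default system message, and empties both histories.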
def reset_chat():
    return "", [default_system_message], [], [], 0
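
# Build the Gradio Blocks UI: chat window on the left, generation settings on the right.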
theme = gr.themes.Default()
with gr.Blocks(theme=theme) as app:
    message_history = gr.State([default_system_message])
    chatlog_history = gr.State([])
    with gr.Row():
        with gr.Column(scale=4):
            chatbot = gr.Chatbot(label="Chat").style(height=654)
        with gr.Column(scale=1):
            # with gr.Tab("Generation Settings"):
            model = gr.Dropdown(choices=["gpt-3.5-turbo", "gpt-4"], value="gpt-4", interactive=True, label="Model")
            personality = gr.Dropdown(choices=["Assistant", "Petőfi", "Trump", "Peterson", "Paladin", "Cartman", "Grug"], value="Assistant", interactive=True, label="Personality")
            temperature = gr.Slider(minimum=0, maximum=1, step=0.05, value=0.5, interactive=True, label="Temperature")
            maximum_length = gr.Slider(minimum=0, maximum=2048, step=32, value=256, interactive=True, label="Max new tokens")
            top_p = gr.Slider(minimum=0, maximum=1, step=0.01, value=1, interactive=True, label="Top P")
            frequency_penalty = gr.Slider(minimum=0, maximum=2, step=0.01, value=0, interactive=True, label="Frequency penalty")
            presence_penalty = gr.Slider(minimum=0, maximum=2, step=0.01, value=0, interactive=True, label="Presence penalty")
            # with gr.Tab("Model Settings"):
            token_count = gr.Number(info="gpt-3.5-turbo accepts up to 4,096 tokens; gpt-4 up to 8,192.", interactive=False, label="Token count")
            # context_cutoff = gr.Slider(minimum=256, maximum=8192, step=256, value=2048, interactive=True, label="Context cutoff")
    with gr.Row():
        user_message = gr.Textbox(label="Message")
    with gr.Row():
        reset_button = gr.Button("Reset Chat")
        retry_button = gr.Button("Retry")
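    # wire up events: Enter submits the message, Retry re-generates the last reply, Reset clears everything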
    user_message.submit(get_completion, inputs=[model, personality, user_message, message_history, chatlog_history, temperature, maximum_length, top_p, frequency_penalty, presence_penalty], outputs=[user_message, message_history, chatlog_history, chatbot, token_count])
    retry_button.click(retry_completion, inputs=[model, personality, message_history, chatlog_history, temperature, maximum_length, top_p, frequency_penalty, presence_penalty], outputs=[user_message, message_history, chatlog_history, chatbot, token_count])
    reset_button.click(reset_chat, inputs=[], outputs=[user_message, message_history, chatlog_history, chatbot, token_count])
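
# Basic auth: username "admin", password taken from the ADMIN_PASSWORD environment variable.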
app.launch(auth=("admin", ADMIN_PASSWORD), enable_queue=True)