mrbeliever's picture
Update app.py
dcd919f verified
raw
history blame
6.82 kB
import os
import time
import uuid
from typing import List, Tuple, Optional, Dict, Union
import google.generativeai as genai
import gradio as gr
from PIL import Image
# Report the installed SDK version once, at import time.
print("google-generativeai:", genai.__version__)

# API key taken from the environment; None when the variable is not set.
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")

# HTML snippets passed to gr.HTML at the top of the page.
TITLE = " "
SUBTITLE = " "
DUPLICATE = ""

# (user avatar, model avatar) handed to the chatbot component.
AVATAR_IMAGES = (None, "https://media.roboflow.com/spaces/gemini-icon.png")

# Where resized uploads are written, and the width they are resized to.
IMAGE_CACHE_DIRECTORY = "/tmp"
IMAGE_WIDTH = 512

# One chat turn is (user entry, model reply); a user entry is either plain
# text or a 1-tuple holding the path of an uploaded image.
_USER_ENTRY = Optional[Union[Tuple[str], str]]
_MODEL_REPLY = Optional[str]
CHAT_HISTORY = List[Tuple[_USER_ENTRY, _MODEL_REPLY]]
def preprocess_stop_sequences(stop_sequences: str) -> Optional[List[str]]:
    """Parse a comma-separated string into a list of stop sequences.

    Every entry is stripped of surrounding whitespace. Entries that are empty
    after stripping (caused by a trailing comma, doubled commas, or input made
    only of whitespace and commas) are dropped, so the result never contains
    an empty stop sequence.

    Args:
        stop_sequences: Raw comma-separated text, e.g. ``"STOP, END"``.

    Returns:
        The non-empty, stripped entries in their original order, or ``None``
        when the input is empty or holds no usable entry.
    """
    if not stop_sequences:
        return None
    stripped = (part.strip() for part in stop_sequences.split(","))
    sequences = [part for part in stripped if part]
    # Collapse "nothing usable" to None, the same value returned for empty input.
    return sequences or None
def preprocess_image(image: Image.Image) -> Optional[Image.Image]:
    """Resize an image to ``IMAGE_WIDTH`` pixels wide, keeping its aspect ratio.

    Args:
        image: The image to resize.

    Returns:
        A resized copy whose width is ``IMAGE_WIDTH`` and whose height is
        scaled by the same factor (truncated to an integer, at least 1).
    """
    scaled_height = int(image.height * IMAGE_WIDTH / image.width)
    # For extremely wide images the truncated height can reach 0, which is not
    # a valid target size for resize(); keep at least one pixel row.
    image_height = max(1, scaled_height)
    return image.resize((IMAGE_WIDTH, image_height))
def cache_pil_image(image: Image.Image) -> str:
    """Write ``image`` to the cache directory as a JPEG and return its path.

    The file name is a freshly generated UUID with a ``.jpeg`` extension, and
    the cache directory is created first if it does not exist yet.
    """
    os.makedirs(IMAGE_CACHE_DIRECTORY, exist_ok=True)
    destination = os.path.join(IMAGE_CACHE_DIRECTORY, f"{uuid.uuid4()}.jpeg")
    image.save(destination, "JPEG")
    return destination
def preprocess_chat_history(
    history: CHAT_HISTORY
) -> List[Dict[str, Union[str, List[str]]]]:
    """Flatten the chat history into a list of role/parts message dicts.

    For each (user entry, model reply) pair, a ``'user'`` message is emitted
    when the user entry is present and is not a tuple, followed by a
    ``'model'`` message when the reply is present. Tuple user entries (the
    form used for uploaded image paths) and ``None`` values produce nothing.
    """
    def _as_message(role, content):
        return {'role': role, 'parts': [content]}

    messages = []
    for user_turn, model_turn in history:
        is_text_turn = user_turn is not None and not isinstance(user_turn, tuple)
        if is_text_turn:
            messages.append(_as_message('user', user_turn))
        if model_turn is None:
            continue
        messages.append(_as_message('model', model_turn))
    return messages
def upload(files: Optional[List[str]], chatbot: CHAT_HISTORY) -> CHAT_HISTORY:
    """Add uploaded images to the chat history.

    Each file is opened, converted to RGB, resized with ``preprocess_image``,
    written to the cache with ``cache_pil_image``, and appended to ``chatbot``
    as a ``((image_path,), None)`` entry.

    Args:
        files: Paths of the uploaded files; ``None`` or an empty list is
            treated as "nothing to add".
        chatbot: Chat history, mutated in place.

    Returns:
        The same ``chatbot`` list, with one new entry per uploaded file.
    """
    # `files` is declared Optional, so guard against None instead of letting
    # the loop raise TypeError.
    for file in files or []:
        # convert() returns an independent image, so the source file can be
        # closed as soon as the conversion is done.
        with Image.open(file) as source_image:
            image = source_image.convert('RGB')
        image = preprocess_image(image)
        image_path = cache_pil_image(image)
        chatbot.append(((image_path,), None))
    return chatbot
def user(text_prompt: str, chatbot: CHAT_HISTORY):
    """Record the user's prompt in the chat history.

    A non-empty prompt is prefixed with a fixed instruction and appended to
    ``chatbot`` as ``(full_prompt, None)``; an empty prompt leaves the history
    untouched.

    Returns:
        A pair ``("", chatbot)``: an empty string for the first output
        component and the (possibly extended) history for the second.
    """
    if not text_prompt:
        return "", chatbot

    # Fixed instruction placed in front of whatever the user typed.
    instruction = "You are a specialized Prompt Generator focused on improving the original text while maintaining its essence. Keep the prompt length under 50 words never exceed this limit"
    chatbot.append((f"{instruction} {text_prompt}", None))
    return "", chatbot
def bot(
    google_key: str,
    files: Optional[List[str]],
    temperature: float,
    max_output_tokens: int,
    stop_sequences: str,
    top_k: int,
    top_p: float,
    chatbot: CHAT_HISTORY
):
    """Generate the model's reply for the current chat state.

    With uploaded ``files`` the request goes to ``gemini-pro-vision`` and is
    made of the text of the last chat entry (if it is a non-empty string)
    followed by the images. Without files the whole history, converted by
    ``preprocess_chat_history``, goes to ``gemini-pro``. The response is
    streamed and its chunks are concatenated.

    Args:
        google_key: API key entered by the user; falls back to the
            module-level ``GOOGLE_API_KEY`` when empty.
        files: Paths of uploaded images, or ``None``/empty for text-only chat.
        temperature: Passed to the generation config.
        max_output_tokens: Passed to the generation config.
        stop_sequences: Comma-separated stop sequences, parsed with
            ``preprocess_stop_sequences``.
        top_k: Passed to the generation config.
        top_p: Passed to the generation config.
        chatbot: Chat history.

    Returns:
        The full generated text, or ``''`` when the history is empty.

    Raises:
        ValueError: If neither ``google_key`` nor ``GOOGLE_API_KEY`` is set.
    """
    if not chatbot:
        return ''

    google_key = google_key or GOOGLE_API_KEY
    if not google_key:
        raise ValueError(
            "GOOGLE_API_KEY is not set. "
            "Please follow the instructions in the README to set it up.")

    genai.configure(api_key=google_key)
    generation_config = genai.types.GenerationConfig(
        temperature=temperature,
        max_output_tokens=max_output_tokens,
        stop_sequences=preprocess_stop_sequences(stop_sequences=stop_sequences),
        top_k=top_k,
        top_p=top_p)

    if files:
        last_user_entry = chatbot[-1][0]
        # Only a plain-text entry is usable as the prompt; image tuples and
        # None contribute no text.
        text_prompt = [last_user_entry] \
            if last_user_entry and isinstance(last_user_entry, str) \
            else []
        image_prompt = []
        for file in files:
            # convert() yields an independent image, so each source file is
            # closed right away instead of staying open until collected.
            with Image.open(file) as source_image:
                image_prompt.append(source_image.convert('RGB'))
        model = genai.GenerativeModel('gemini-pro-vision')
        response = model.generate_content(
            text_prompt + image_prompt,
            stream=True,
            generation_config=generation_config)
    else:
        messages = preprocess_chat_history(chatbot)
        model = genai.GenerativeModel('gemini-pro')
        response = model.generate_content(
            messages,
            stream=True,
            generation_config=generation_config)

    # The response is streamed; collect every chunk into one string.
    return ''.join(chunk.text for chunk in response)
# Text box that shows the generated text.
# elem_id gives it the DOM id 'output-text', which is the id the copy icon's
# JavaScript looks up; without it that lookup can never find this component.
output_text_component = gr.Textbox(
    label="Generated Text",
    value="",
    placeholder="Generated text will appear here",
    scale=8,
    multiline=True,
    elem_id="output-text",
)
def copy_text():
    """Call ``copy()`` on the module-level ``output_text_component``.

    This function is not referenced anywhere else in the visible source.
    NOTE(review): what ``gr.Textbox.copy`` does (or whether it exists) depends
    on the installed Gradio version - confirm before wiring this up.
    """
    output_text_component.copy()
# JavaScript run when the copy icon is clicked. It is attached as an inline
# handler rather than defined in a <script> element: browsers do not execute
# <script> elements that are added to the page through innerHTML.
# NOTE(review): assumes gr.HTML injects its value that way - confirm for the
# installed Gradio version.
# The element with id 'output-text' may be the input itself or a wrapper
# around it, so both cases are handled; a missing element is a silent no-op.
_COPY_ONCLICK_JS = (
    "var root = document.getElementById('output-text');"
    "var box = root ? (root.matches('textarea, input') ? root "
    ": root.querySelector('textarea, input')) : null;"
    "if (!box) { return; }"
    "box.select();"
    "document.execCommand('copy');"
    "alert('Copied to clipboard!');"
)

# Clickable "copy" icon (inline SVG) that copies the generated text.
output_text_component_copy = gr.HTML(
    "<svg xmlns='http://www.w3.org/2000/svg' width='24' height='24' viewBox='0 0 24 24' "
    "fill='none' stroke='currentColor' stroke-width='2' stroke-linecap='round' "
    "stroke-linejoin='round' class='feather feather-copy' "
    'onclick="' + _COPY_ONCLICK_JS + '">'
    "<rect x='9' y='9' width='13' height='13' rx='2' ry='2'></rect>"
    "<path d='M9 15h4'></path><path d='M15 9v6'></path></svg>"
)
# UI components are created at module level and rendered later inside the
# gr.Blocks layout.

# The input lists and layout reference the key field, upload button, run
# button and sampling controls, but the visible source never defines them, so
# the module would raise NameError at import. They are defined here.
# NOTE(review): labels, ranges and default values below are assumptions -
# confirm them against the intended product behaviour.
google_key_component = gr.Textbox(
    label="GOOGLE API KEY",
    value="",
    type="password",
    # Only ask for a key when none was supplied through the environment.
    visible=GOOGLE_API_KEY is None,
)
text_prompt_component = gr.Textbox(
    placeholder="Hi there! [press Enter]",
    show_label=False,
    autofocus=True,
    scale=8,
)
chatbot_component = gr.Chatbot(
    label='Gemini',
    bubble_full_width=False,
    avatar_images=AVATAR_IMAGES,
    scale=2,
    height=400
)
upload_button_component = gr.UploadButton(
    label="Upload Images",
    file_count="multiple",
    file_types=["image"],
    scale=1,
)
run_button_component = gr.Button(value="Run", variant="primary", scale=1)
temperature_component = gr.Slider(
    minimum=0,
    maximum=1.0,
    value=0.4,
    step=0.05,
    label="Temperature",
)
max_output_tokens_component = gr.Slider(
    minimum=1,
    maximum=2048,
    value=1024,
    step=1,
    label="Token limit",
)
stop_sequences_component = gr.Textbox(
    label="Add stop sequence",
    value="",
    type="text",
    placeholder="STOP, END",
)
top_k_component = gr.Slider(
    minimum=1,
    maximum=40,
    value=32,
    step=1,
    label="Top-K",
)
top_p_component = gr.Slider(
    minimum=0,
    maximum=1.0,
    value=1,
    step=0.01,
    label="Top-P",
)

# Inputs for the `user` handler, in its parameter order.
user_inputs = [
    text_prompt_component,
    chatbot_component
]
# Inputs for the `bot` handler, in its parameter order.
bot_inputs = [
    google_key_component,
    upload_button_component,
    temperature_component,
    max_output_tokens_component,
    stop_sequences_component,
    top_k_component,
    top_p_component,
    chatbot_component
]
# Page layout and event wiring.
with gr.Blocks() as demo:
    gr.HTML(TITLE)
    gr.HTML(SUBTITLE)
    gr.HTML(DUPLICATE)
    with gr.Column():
        # Every component used as an event input or output has to be rendered
        # inside this Blocks context, including the key field and the output
        # text box with its copy icon.
        google_key_component.render()
        chatbot_component.render()
        with gr.Row():
            output_text_component.render()
            output_text_component_copy.render()
        with gr.Row():
            text_prompt_component.render()
            upload_button_component.render()
            run_button_component.render()
        with gr.Accordion("Parameters", open=False):
            temperature_component.render()
            max_output_tokens_component.render()
            stop_sequences_component.render()
            with gr.Accordion("Advanced", open=False):
                top_k_component.render()
                top_p_component.render()

    # The Run button and pressing Enter in the prompt box do the same thing:
    # `user` records the prompt (and clears the output box), then `bot`
    # writes the generated text into the output box. The text must target the
    # Textbox, not the HTML copy icon, or the icon markup gets overwritten.
    for register in (run_button_component.click, text_prompt_component.submit):
        register(
            fn=user,
            inputs=user_inputs,
            outputs=[output_text_component, chatbot_component],
            queue=False
        ).then(
            fn=bot, inputs=bot_inputs, outputs=[output_text_component],
        )

    # `upload` returns only the chat history, so it has exactly one output.
    upload_button_component.upload(
        fn=upload,
        inputs=[upload_button_component, chatbot_component],
        outputs=[chatbot_component],
        queue=False
    )

demo.queue(max_size=99).launch(debug=False, show_error=True)