analogy / app.py
ritwikbiswas's picture
try diffusion fix
42cd0c7
raw
history blame
2.28 kB
import gradio as gr
import os
import openai
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
import torch
# Stable Diffusion 2.1 checkpoint id.
# NOTE(review): only the disabled setup below references this name; the active
# pipeline is loaded from the v1-5 checkpoint instead.
model_id = "stabilityai/stable-diffusion-2-1"
# Load the text-to-image pipeline once at import time; generate_prompt() reuses
# this module-level object for every request.
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
# Disabled alternative setup: load `model_id` in float16, swap in the
# DPM-Solver multistep scheduler, and move the pipeline to CUDA.
# pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
# pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
# pipe = pipe.to("cuda")
# Read the OpenAI key from the environment (os.getenv returns None when unset).
openai.api_key = os.getenv("OPENAI_API_KEY")
def _build_analogy_prompt(radio, word1, word2):
    """Return the completion prompt asking for an analogy between two things."""
    # "normal" requests a plain analogy; any other flavor is inserted verbatim.
    flavor = "an" if radio == "normal" else f"a {radio}"
    return f'Create {flavor} analogy for this phrase:\n\n{word1} is like {word2} in that:'


def generate_prompt(radio, word1, word2):
    """Generate an analogy linking two things plus a painting illustrating it.

    Args:
        radio: Flavor of the analogy. ``"normal"`` asks for a plain analogy;
            any other value is used as an adjective in the prompt.
        word1: The first thing to compare.
        word2: The second thing to compare.

    Returns:
        A ``(text, image)`` tuple: the analogy text returned by the completion
        model, and the first image the diffusion pipeline produced for the
        first sentence of that text.

    Raises:
        gr.Error: If either thing is missing or contains only whitespace.
    """
    word1 = (word1 or "").strip()
    word2 = (word2 or "").strip()
    if not word1 or not word2:
        # Fail fast: do not spend a completion call and a slow diffusion run
        # on a prompt that has nothing to compare.
        raise gr.Error("Please enter both Thing 1 and Thing 2.")

    # TODO: 50/50 in that/because
    # TODO: pluralize singular words
    completion = openai.Completion.create(
        model="text-davinci-003",
        prompt=_build_analogy_prompt(radio, word1, word2),
        temperature=0.5,
        max_tokens=60,
        top_p=1.0,
        frequency_penalty=0.0,
        presence_penalty=0.0,
    )
    raw_text = completion['choices'][0]['text']

    # Collapse newlines and whitespace runs into single spaces. Deleting the
    # newlines outright would glue together words separated only by a line
    # break.
    response_txt = " ".join(raw_text.split())

    # Only the first sentence is illustrated, to keep the image prompt short.
    first_sentence = response_txt.split(".")[0]
    diffusion_in = f'a dramatic painting of: {first_sentence}'
    image = pipe(diffusion_in).images[0]
    return response_txt, image
# Web UI: a flavor selector and two free-text fields feed generate_prompt(),
# whose (text, image) result is rendered as a text box and an image.
demo = gr.Interface(
    fn=generate_prompt,
    inputs=[
        gr.Radio(["normal", "very insulting"], value="normal", label="Flavor"),
        gr.Textbox(label="Thing 1"),
        gr.Textbox(label="Thing 2"),
    ],
    outputs=["text", "image"],
    allow_flagging="never",
    title="GPT-3 Analogy Lab 🧪",
    description="Enter two things you want to connect.",
    # Hide the default Gradio footer.
    css="footer {visibility: hidden}",
)

# Start the app server.
demo.launch()
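# openai.api_key = os.getenv("OPENAI_API_KEY")
# SECURITY: a literal OpenAI secret key was hard-coded on this line. It has been
# redacted here and must be revoked/rotated, since it remains in the commit history.
# openai.api_key = "<REDACTED>"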
# response = openai.Completion.create(
# model="text-davinci-003",
# prompt="Create an analogy for this phrase:\n\nQuestions are arrows in that:",
# temperature=0.5,
# max_tokens=60,
# top_p=1.0,
# frequency_penalty=0.0,
# presence_penalty=0.0
# )