# Text-To-Image / app.py
# LDY — "Update app.py" (commit 7ea818e)
import string
import subprocess
import unicodedata

import gradio
import torch, torch.backends.cudnn, torch.backends.cuda
from emoji import demojize
from min_dalle import MinDalle
from PIL import Image
def filename_from_text(text: str) -> str:
    """Turn a free-form prompt into a short, filesystem-safe file stem.

    Processing steps:
      1. Emoji are replaced by their textual names (``demojize`` with empty
         delimiters).
      2. The text is NFKD-normalised so accented/compatibility characters
         decompose into an ASCII base letter plus marks (e.g. 'é' -> 'e').
      3. It is lowercased and every non-ASCII code point is discarded.
      4. Only lowercase ASCII letters and spaces are kept (digits and
         punctuation are dropped).
      5. The result is cut to 64 characters and whitespace runs are collapsed
         into single hyphens.

    Args:
        text: Arbitrary user prompt.

    Returns:
        The file stem, or ``'blank'`` when nothing usable remains.
    """
    text = demojize(text, delimiters=['', ''])
    # Without this step the ASCII encode below would drop accented letters
    # entirely ("Café" -> "caf"); after NFKD only the combining mark is lost.
    text = unicodedata.normalize('NFKD', text)
    text = text.lower().encode('ascii', errors='ignore').decode()
    allowed_chars = frozenset(string.ascii_lowercase + ' ')
    text = ''.join(ch for ch in text if ch in allowed_chars)
    # The 64-character cut happens before whitespace is collapsed, so the
    # final stem is at most 64 characters long.
    text = '-'.join(text[:64].split())
    return text or 'blank'
def log_gpu_memory():
    """Print the output of ``nvidia-smi`` to stdout as a diagnostic.

    This is best-effort: if the tool is missing (e.g. on a CPU-only host),
    exits with an error, or does not answer within 30 seconds, a short notice
    is printed instead of raising. Callers such as request handlers therefore
    never fail just because GPU logging is unavailable.

    Returns:
        None.
    """
    try:
        report = subprocess.check_output(['nvidia-smi'], timeout=30)
    except (OSError, subprocess.SubprocessError) as err:
        # OSError covers FileNotFoundError (binary absent); SubprocessError
        # covers CalledProcessError and TimeoutExpired.
        print('nvidia-smi unavailable:', err)
        return
    print(report.decode('utf-8', errors='replace'))
# A single MinDalle instance is built at import time and shared by every
# run_model call (which invokes model.generate_image).
# NOTE(review): no dtype is passed, so min_dalle's default applies; an explicit
# torch.float32 override used to sit here commented out.
model = MinDalle(
    device='cpu',
    is_mega=True,
    is_reusable=True,  # presumably keeps sub-models loaded between calls; confirm in min_dalle
)
def run_model(
text: str,
grid_size: int,
is_seamless: bool,
save_as_png: bool,
temperature: float,
supercondition: str,
top_k: str
) -> str:
torch.set_grad_enabled(False)
torch.backends.cudnn.enabled = True
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
print('text:', text)
print('grid_size:', grid_size)
print('is_seamless:', is_seamless)
print('temperature:', temperature)
print('supercondition:', supercondition)
print('top_k:', top_k)
try:
temperature = float(temperature)
assert(temperature > 1e-6)
except:
raise Exception('Temperature must be a positive nonzero number')
try:
grid_size = int(grid_size)
assert(grid_size <= 5)
assert(grid_size >= 1)
except:
raise Exception('Grid size must be between 1 and 5')
try:
top_k = int(top_k)
assert(top_k <= 16384)
assert(top_k >= 1)
except:
raise Exception('Top k must be between 1 and 16384')
with torch.no_grad():
image = model.generate_image(
text = text,
seed = -1,
grid_size = grid_size,
is_seamless = bool(is_seamless),
temperature = temperature,
supercondition_factor = float(supercondition),
top_k = top_k,
is_verbose = True
)
log_gpu_memory()
ext = 'png' if bool(save_as_png) else 'jpg'
filename = filename_from_text(text)
image_path = '{}.{}'.format(filename, ext)
image.save(image_path)
return image_path
# Build the Gradio UI:
#   - left column: prompt, generate button and result image;
#   - right column: generation settings and help text;
#   - clickable example prompts and the button -> run_model wiring.
# Component creation order and nesting inside the context managers define the
# on-screen layout.
# NOTE(review): analytics_enabled=True presumably opts in to Gradio's usage
# analytics; confirm this is intended.
demo = gradio.Blocks(analytics_enabled=True)
with demo:
    with gradio.Row():
        # Left column: prompt input, generate button and output image.
        with gradio.Column():
            input_text = gradio.Textbox(
                label='Input Text',
                value='Rusty Iron Man suit found abandoned in the woods being reclaimed by nature',
                lines=3
            )
            # NOTE(review): Button.style(...) is deprecated/removed in newer
            # Gradio releases; this relies on the pinned Gradio version.
            run_button = gradio.Button(value='Generate Image').style(full_width=True)
            # Shows a bundled example image until the first generation;
            # afterwards it displays the file path returned by run_model.
            # NOTE(review): type='file' is deprecated in later Gradio 3.x in
            # favour of 'filepath'; confirm against the pinned version.
            output_image = gradio.Image(
                value='example/8k dog.png',
                label='Output Image',
                type='file',
                interactive=False
            )
        # Right column: generation settings.
        with gradio.Column():
            gradio.Markdown('## Settings')
            with gradio.Row():
                # Same 1-5 range that run_model validates.
                grid_size = gradio.Slider(
                    label='Grid Size',
                    value=3,
                    minimum=1,
                    maximum=5,
                    step=1
                )
                save_as_png = gradio.Checkbox(
                    label='Output PNG',
                    value=False
                )
                is_seamless = gradio.Checkbox(
                    label='Seamless',
                    value=False
                )
            gradio.Markdown('#### Advanced')
            with gradio.Row():
                temperature = gradio.Number(
                    label='Temperature',
                    value=1
                )
                # Choices are the strings '1', '2', '4', ..., '16384' (powers
                # of two); run_model converts the selection to int.
                top_k = gradio.Dropdown(
                    label='Top-k',
                    choices=[str(2 ** i) for i in range(15)],
                    value='128'
                )
                # Choices are '4', '8', '16', '32', '64'; run_model converts
                # the selection to float.
                supercondition = gradio.Dropdown(
                    label='Super Condition',
                    choices=[str(2 ** i) for i in range(2, 7)],
                    value='16'
                )
            # Help text for the settings above.
            # NOTE(review): the "3x3 takes about 15 seconds" figure was
            # presumably measured on a GPU; the model is created with
            # device='cpu', so confirm the timing still holds.
            gradio.Markdown(
                """
#### Parameter
- **Input Text**: For long prompts, only the first 64 text tokens will be used to generate the image.
- **Grid Size**: Size of the image grid. 3x3 takes about 15 seconds.
- **Seamless**: Tile images in image token space instead of pixel space.
- **Temperature**: High temperature increases the probability of sampling low scoring image tokens.
- **Top-k**: Each image token is sampled from the top-k scoring tokens.
- **Super Condition**: Higher values can result in better agreement with the text.
####
"""
            )
    # Each example row supplies (prompt, grid size, pre-rendered output image).
    # No fn is given, so selecting an example presumably only fills these
    # components and does not run the model.
    gradio.Examples(
        examples=[
            ['Portrait of a basset hound, 8k, photograph', 3, 'example/8k dog.png'],
            ['A diorama of Puppy cloud ,8k, photograph', 3, 'example/puppy.png'],
            ['A dragon that looks like a cream', 3, 'example/cream.png'],
            ['A photo of a sleeping orange tabby cat', 3, 'example/tabby.png'],
            ['A diorama of a bunny family sitting around the table having dinner ,8k, photograph', 3, 'example/table.png'],
            ['A white cat with golden sunglasses on, pink background, studio lighting, 4k, award winning photography', 2, 'example/cat.png'],
            ['an astronaut dancing on the moon’s surface, close-up photo', 2, 'example/astronaut.png'],
            ['A photo of a Samoyed dog with its tongue out hugging a white Siamese cat', 5, 'example/dog.png'],
            ['Dragons of Earth, Wind, Fire, powering up a huge sphere of compressed energy, digital art', 2, 'example/dragon.png'],
            ['A snowboarder jumping in the air while coming down a ski mountain, concept art, artstation, unreal engine, 3d render, HD, Bokeh', 3, 'example/snow.png'],
            ['Antique photo of a dragon fire', 3, 'example/fire.png'],
            ['A space parrot flying through the cosmos, digital art', 3, 'example/parrot.png'],
        ],
        inputs=[
            input_text,
            grid_size,
            output_image
        ],
        examples_per_page=20
    )
    # Component values are passed positionally to run_model. This list must
    # stay in the same order as run_model's parameters: text, grid_size,
    # is_seamless, save_as_png, temperature, supercondition, top_k.
    run_button.click(
        fn=run_model,
        inputs=[
            input_text,
            grid_size,
            is_seamless,
            save_as_png,
            temperature,
            supercondition,
            top_k
        ],
        outputs=[
            output_image
        ]
    )
# Starts the web server when the module is executed.
demo.launch()