Spaces:
Runtime error
Runtime error
import string
import subprocess
import unicodedata

import gradio
import torch
import torch.backends.cuda
import torch.backends.cudnn
from emoji import demojize
from min_dalle import MinDalle
from PIL import Image
def filename_from_text(text: str) -> str:
    """Turn a prompt into a short, filesystem-safe base name (no extension).

    Emoji are spelled out with ``demojize``, accented letters are reduced to
    their unaccented ASCII letter, and the result is lower-cased. Runs of
    whitespace or underscores become single hyphens; every other character
    outside ``a-z`` is dropped. The text is cut to 64 characters before the
    hyphens are inserted, and an empty result becomes ``'blank'``.

    >>> filename_from_text('Pokémon  in\\nthe snow!')
    'pokemon-in-the-snow'
    """
    text = demojize(text, delimiters=['', ''])
    # NFKD splits e.g. 'é' into 'e' + a combining accent, so the ASCII filter
    # below drops only the accent instead of deleting the whole letter.
    text = unicodedata.normalize('NFKD', text)
    text = text.lower().encode('ascii', errors='ignore').decode()
    # Any whitespace (e.g. newlines in multi-line prompts) and the underscores
    # inside demojized names such as 'thumbs_up' separate words; deleting them
    # outright would glue neighbouring words together.
    text = ''.join(' ' if ch.isspace() or ch == '_' else ch for ch in text)
    allowed_chars = set(string.ascii_lowercase + ' ')
    text = ''.join(ch for ch in text if ch in allowed_chars)
    text = text[:64]
    # str.split() with no argument also strips leading/trailing blanks.
    text = '-'.join(text.split())
    return text or 'blank'
def log_gpu_memory():
    """Print the current ``nvidia-smi`` report to stdout.

    This is diagnostic only, so a missing or failing ``nvidia-smi`` (no
    NVIDIA driver, CPU-only host) is reported and ignored rather than raised.
    The helper runs at startup and after each generation, where an exception
    would abort the app or discard an image that was already produced.
    """
    try:
        report = subprocess.check_output(['nvidia-smi'])
    except (OSError, subprocess.CalledProcessError) as err:
        # OSError covers FileNotFoundError when the binary is absent;
        # CalledProcessError covers a non-zero exit (e.g. no driver loaded).
        print('nvidia-smi unavailable:', err)
        return
    print(report.decode('utf-8', errors='replace'))
# Use the GPU when one is visible; otherwise fall back to the CPU so the app
# can still start (slowly) instead of failing while loading the model.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Memory snapshots before and after loading show the model's footprint.
log_gpu_memory()
model = MinDalle(
    is_mega=True,
    is_reusable=True,
    device=DEVICE,
    dtype=torch.float32,
)
log_gpu_memory()
def _parse_int_in_range(value, low: int, high: int, message: str) -> int:
    """Return ``int(value)`` if it lies in the closed range ``[low, high]``.

    Raises:
        ValueError: carrying *message* when ``value`` cannot be converted to
            an int or the converted value falls outside the range.
    """
    try:
        parsed = int(value)
    except (TypeError, ValueError, OverflowError) as err:
        raise ValueError(message) from err
    if not low <= parsed <= high:
        raise ValueError(message)
    return parsed


def run_model(
    text: str,
    grid_size: int,
    is_seamless: bool,
    save_as_png: bool,
    temperature: float,
    supercondition: str,
    top_k: str
) -> str:
    """Generate an image grid for *text* and save it to the working directory.

    Args:
        text: Prompt passed to the model; also used to derive the file name.
        grid_size: Side length of the image grid, 1-5 (the UI slider may
            deliver it as a float such as ``3.0``).
        is_seamless: Forwarded to ``model.generate_image`` as a bool.
        save_as_png: Save as ``.png`` when truthy, otherwise ``.jpg``.
        temperature: Sampling temperature; must convert to a float > 1e-6.
        supercondition: Super-condition factor as a numeric string.
        top_k: Top-k cutoff as a numeric string, 1-16384.

    Returns:
        Relative path (file name only) of the saved image.

    Raises:
        ValueError: if any numeric setting is missing, malformed or out of
            range. All inputs are checked before global torch settings are
            changed or the model is invoked.
    """
    print('text:', text)
    print('grid_size:', grid_size)
    print('is_seamless:', is_seamless)
    print('temperature:', temperature)
    print('supercondition:', supercondition)
    print('top_k:', top_k)

    # Explicit checks instead of ``assert`` (stripped under ``python -O``) and
    # narrow excepts instead of a bare ``except`` (which would also swallow
    # KeyboardInterrupt/SystemExit).
    temperature_error = 'Temperature must be a positive nonzero number'
    try:
        temperature = float(temperature)
    except (TypeError, ValueError, OverflowError) as err:
        raise ValueError(temperature_error) from err
    # Written as ``not x > eps`` rather than ``x <= eps`` so NaN is rejected.
    if not temperature > 1e-6:
        raise ValueError(temperature_error)
    grid_size = _parse_int_in_range(
        grid_size, 1, 5, 'Grid size must be between 1 and 5'
    )
    top_k = _parse_int_in_range(
        top_k, 1, 16384, 'Top k must be between 1 and 16384'
    )
    try:
        supercondition_factor = float(supercondition)
    except (TypeError, ValueError, OverflowError) as err:
        raise ValueError('Super condition must be a number') from err

    # Process-wide inference settings: no autograd, cuDNN autotuning, and
    # reduced-precision matmuls traded for speed.
    torch.set_grad_enabled(False)
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True

    with torch.no_grad():
        image = model.generate_image(
            text=text,
            seed=-1,  # presumably "pick a random seed" in min_dalle — confirm
            grid_size=grid_size,
            is_seamless=bool(is_seamless),
            temperature=temperature,
            supercondition_factor=supercondition_factor,
            top_k=top_k,
            is_verbose=True,
        )

    log_gpu_memory()
    ext = 'png' if save_as_png else 'jpg'
    image_path = '{}.{}'.format(filename_from_text(text), ext)
    image.save(image_path)
    return image_path
demo = gradio.Blocks(analytics_enabled=True)

# Pre-rendered showcase rows: (prompt, grid size, stored result image).
showcase = [
    ['Rusty Iron Man suit found abandoned in the woods being reclaimed by nature', 3, 'examples/rusty-iron-man.jpg'],
    ['Moai statue giving a TED Talk', 5, 'examples/moai-statue.jpg'],
    ['Court sketch of Godzilla on trial', 5, 'examples/godzilla-trial.jpg'],
    ['lofi nuclear war to relax and study to', 5, 'examples/lofi-nuclear-war.jpg'],
    ['Karl Marx slimed at Kids Choice Awards', 4, 'examples/marx-slimed.jpg'],
    ['Scientists trying to rhyme orange with banana', 4, 'examples/scientists-rhyme.jpg'],
    ['Jesus turning water into wine on Americas Got Talent', 5, 'examples/jesus-talent.jpg'],
    ['Elmo in a street riot throwing a Molotov cocktail, hyperrealistic', 5, 'examples/elmo-riot.jpg'],
    ['Trail cam footage of gollum eating watermelon', 4, 'examples/gollum.jpg'],
    ['Funeral at Whole Foods', 4, 'examples/funeral-whole-foods.jpg'],
    ['Singularity, hyperrealism', 5, 'examples/singularity.jpg'],
    ['Astronaut riding a horse hyperrealistic', 5, 'examples/astronaut-horse.jpg'],
    ['An astronaut walking on Mars next to a Starship rocket, realistic', 5, 'examples/astronaut-mars.jpg'],
    ['Nuclear explosion broccoli', 4, 'examples/nuclear-broccoli.jpg'],
    ['Dali painting of WALL·E', 5, 'examples/dali-walle.jpg'],
    ['Cleopatra checking her iPhone', 4, 'examples/cleopatra-iphone.jpg'],
]

# Dropdown options are powers of two as strings; run_model converts them back.
top_k_options = [str(1 << power) for power in range(15)]            # '1' .. '16384'
supercondition_options = [str(1 << power) for power in range(2, 7)]  # '4' .. '64'

with demo:
    with gradio.Row():
        # Left column: prompt, trigger button and the result.
        with gradio.Column():
            input_text = gradio.Textbox(
                lines=3,
                label='Input Text',
                value='Rusty Iron Man suit found abandoned in the woods being reclaimed by nature',
            )
            run_button = gradio.Button(value='Generate Image').style(full_width=True)
            output_image = gradio.Image(
                label='Output Image',
                value='examples/rusty-iron-man.jpg',
                interactive=False,
                type='file',
            )

        # Right column: settings, help text and the example gallery.
        with gradio.Column():
            gradio.Markdown('## Settings')
            with gradio.Row():
                grid_size = gradio.Slider(label='Grid Size', minimum=1, maximum=5, step=1, value=3)
                save_as_png = gradio.Checkbox(label='Output PNG', value=False)
                is_seamless = gradio.Checkbox(label='Seamless', value=False)

            gradio.Markdown('#### Advanced')
            with gradio.Row():
                temperature = gradio.Number(label='Temperature', value=1)
                top_k = gradio.Dropdown(label='Top-k', choices=top_k_options, value='128')
                supercondition = gradio.Dropdown(
                    label='Super Condition', choices=supercondition_options, value='16'
                )

            gradio.Markdown(
                """
####
- **Input Text**: For long prompts, only the first 64 text tokens will be used to generate the image.
- **Grid Size**: Size of the image grid. 3x3 takes about 15 seconds.
- **Seamless**: Tile images in image token space instead of pixel space.
- **Temperature**: High temperature increases the probability of sampling low scoring image tokens.
- **Top-k**: Each image token is sampled from the top-k scoring tokens.
- **Super Condition**: Higher values can result in better agreement with the text.
"""
            )

            # No fn is given, so choosing an example only fills these
            # components, including output_image with the stored file.
            gradio.Examples(
                examples=showcase,
                inputs=[input_text, grid_size, output_image],
                examples_per_page=20,
            )

    # Input order must match run_model's positional parameters.
    generation_inputs = [
        input_text,
        grid_size,
        is_seamless,
        save_as_png,
        temperature,
        supercondition,
        top_k,
    ]
    run_button.click(fn=run_model, inputs=generation_inputs, outputs=[output_image])

    gradio.Markdown(
        """
### **[❤️ Sponsor](https://github.com/sponsors/kuprel)**
"""
    )

demo.launch()