################################################################################
# Copyright (C) 2023 Xingqian Xu - All Rights Reserved                         #
#                                                                              #
# Please visit Prompt-Free-Diffusion's arXiv paper for more details, link at   #
# arxiv.org/abs/2305.16223                                                     #
#                                                                              #
################################################################################

import gradio as gr
import os.path as osp
from PIL import Image
import numpy as np
import time

import torch
import torchvision.transforms as tvtrans
from lib.cfg_helper import model_cfg_bank
from lib.model_zoo import get_model
from collections import OrderedDict
from lib.model_zoo.ddim import DDIMSampler

n_sample_image = 1

controlnet_path = OrderedDict([
    ['canny'              , ('canny'   , 'pretrained/controlnet/control_sd15_canny_slimmed.safetensors')],
    ['canny_v11p'         , ('canny'   , 'pretrained/controlnet/control_v11p_sd15_canny_slimmed.safetensors')],
    ['depth'              , ('depth'   , 'pretrained/controlnet/control_sd15_depth_slimmed.safetensors')],
    ['hed'                , ('hed'     , 'pretrained/controlnet/control_sd15_hed_slimmed.safetensors')],
    ['mlsd'               , ('mlsd'    , 'pretrained/controlnet/control_sd15_mlsd_slimmed.safetensors')],
    ['mlsd_v11p'          , ('mlsd'    , 'pretrained/controlnet/control_v11p_sd15_mlsd_slimmed.safetensors')],
    ['normal'             , ('normal'  , 'pretrained/controlnet/control_sd15_normal_slimmed.safetensors')],
    ['openpose'           , ('openpose', 'pretrained/controlnet/control_sd15_openpose_slimmed.safetensors')],
    ['openpose_v11p'      , ('openpose', 'pretrained/controlnet/control_v11p_sd15_openpose_slimmed.safetensors')],
    ['scribble'           , ('scribble', 'pretrained/controlnet/control_sd15_scribble_slimmed.safetensors')],
    ['softedge_v11p'      , ('scribble', 'pretrained/controlnet/control_v11p_sd15_softedge_slimmed.safetensors')],
    ['seg'                , ('none'    , 'pretrained/controlnet/control_sd15_seg_slimmed.safetensors')],
    ['lineart_v11p'       , ('none'    , 'pretrained/controlnet/control_v11p_sd15_lineart_slimmed.safetensors')],
    ['lineart_anime_v11p' , ('none'    , 'pretrained/controlnet/control_v11p_sd15s2_lineart_anime_slimmed.safetensors')],
])
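# Each entry maps a UI tag to (default preprocess type, checkpoint path); the
# preprocess type drives action_autoset_method() below, so picking a ControlNet
# in the UI also selects a matching preprocessor.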
preprocess_method = [
    'canny'                ,
    'depth'                ,
    'hed'                  ,
    'mlsd'                 ,
    'normal'               ,
    'openpose'             ,
    'openpose_withface'    ,
    'openpose_withfacehand',
    'scribble'             ,
    'none'                 ,
]

diffuser_path = OrderedDict([
    ['SD-v1.5'             , 'pretrained/pfd/diffuser/SD-v1-5.safetensors'],
    ['OpenJouney-v4'       , 'pretrained/pfd/diffuser/OpenJouney-v4.safetensors'],
    ['Deliberate-v2.0'     , 'pretrained/pfd/diffuser/Deliberate-v2-0.safetensors'],
    ['RealisticVision-v2.0', 'pretrained/pfd/diffuser/RealisticVision-v2-0.safetensors'],
    ['Anything-v4'         , 'pretrained/pfd/diffuser/Anything-v4.safetensors'],
    ['Oam-v3'              , 'pretrained/pfd/diffuser/AbyssOrangeMix-v3.safetensors'],
    ['Oam-v2'              , 'pretrained/pfd/diffuser/AbyssOrangeMix-v2.safetensors'],
])

ctxencoder_path = OrderedDict([
    ['SeeCoder'      , 'pretrained/pfd/seecoder/seecoder-v1-0.safetensors'],
    ['SeeCoder-PA'   , 'pretrained/pfd/seecoder/seecoder-pa-v1-0.safetensors'],
    ['SeeCoder-Anime', 'pretrained/pfd/seecoder/seecoder-anime-v1-0.safetensors'],
])
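# Optional sanity check (an illustrative sketch, not part of the original app):
# every checkpoint above is loaded lazily, so a missing file would otherwise
# only fail mid-session. Call this before building the model if desired.
def warn_missing_checkpoints():
    paths = [p for _, p in controlnet_path.values()] \
        + list(diffuser_path.values()) + list(ctxencoder_path.values())
    for p in paths:
        if not osp.isfile(p):
            print('[warning] missing checkpoint: {}'.format(p))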
##########
# helper #
##########

def highlight_print(info):
    print('')
    print(''.join(['#']*(len(info)+4)))
    print('# '+info+' #')
    print(''.join(['#']*(len(info)+4)))
    print('')

def load_sd_from_file(target):
    if osp.splitext(target)[-1] == '.ckpt':
        sd = torch.load(target, map_location='cpu')['state_dict']
    elif osp.splitext(target)[-1] == '.pth':
        sd = torch.load(target, map_location='cpu')
    elif osp.splitext(target)[-1] == '.safetensors':
        from safetensors.torch import load_file as stload
        sd = OrderedDict(stload(target, device='cpu'))
    else:
        assert False, "File type must be .ckpt or .pth or .safetensors"
    return sd
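# Example (using a path from diffuser_path above): every branch returns a flat
# name-to-tensor mapping on CPU, e.g.
#   sd = load_sd_from_file('pretrained/pfd/diffuser/SD-v1-5.safetensors')
#   print(len(sd), 'tensors loaded')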
########
# main #
########

class prompt_free_diffusion(object):
    def __init__(self,
                 fp16=False,
                 tag_ctx=None,
                 tag_diffuser=None,
                 tag_ctl=None,):
        self.tag_ctx = tag_ctx
        self.tag_diffuser = tag_diffuser
        self.tag_ctl = tag_ctl
        self.strict_sd = True

        cfgm = model_cfg_bank()('pfd_seecoder_with_controlnet')
        self.net = get_model()(cfgm)

        # Set dtype and device flags before the action_load_* calls below;
        # action_load_ctx('SeeCoder-PA') reads self.dtype and self.use_cuda
        # when building its positional-encoding layer.
        self.dtype = torch.float16 if fp16 else torch.float32
        self.use_cuda = torch.cuda.is_available()

        self.action_load_ctx(tag_ctx)
        self.action_load_diffuser(tag_diffuser)
        self.action_load_ctl(tag_ctl)

        if fp16:
            highlight_print('Running in FP16')
            self.net.ctx['image'].fp16 = True
            self.net = self.net.half()

        if self.use_cuda:
            self.net.to('cuda')
        self.net.eval()
        self.sampler = DDIMSampler(self.net)

        self.n_sample_image = n_sample_image
        self.ddim_steps = 50
        self.ddim_eta = 0.0
        self.image_latent_dim = 4
    def load_ctx(self, pretrained):
        sd = load_sd_from_file(pretrained)
        sd_extra = [(ki, vi) for ki, vi in self.net.state_dict().items()
                    if ki.find('ctx.') != 0]
        sd.update(OrderedDict(sd_extra))
        self.net.load_state_dict(sd, strict=True)
        print('Load context encoder from [{}] strict [{}].'.format(pretrained, True))

    def load_diffuser(self, pretrained):
        sd = load_sd_from_file(pretrained)
        if len([ki for ki in sd.keys() if ki.find('diffuser.image.context_blocks.') == 0]) == 0:
            sd = [(ki.replace('diffuser.text.context_blocks.',
                              'diffuser.image.context_blocks.'), vi)
                  for ki, vi in sd.items()]
            sd = OrderedDict(sd)
        sd_extra = [(ki, vi) for ki, vi in self.net.state_dict().items()
                    if ki.find('diffuser.') != 0]
        sd.update(OrderedDict(sd_extra))
        self.net.load_state_dict(sd, strict=True)
        print('Load diffuser from [{}] strict [{}].'.format(pretrained, True))

    def load_ctl(self, pretrained):
        sd = load_sd_from_file(pretrained)
        self.net.ctl.load_state_dict(sd, strict=True)
        print('Load controlnet from [{}] strict [{}].'.format(pretrained, True))

    def action_load_ctx(self, tag):
        pretrained = ctxencoder_path[tag]
        if tag == 'SeeCoder-PA':
            from lib.model_zoo.seecoder import PPE_MLP
            pe_layer = PPE_MLP(freq_num=20, freq_max=None, out_channel=768, mlp_layer=3)
            if self.dtype == torch.float16:
                pe_layer = pe_layer.half()
            if self.use_cuda:
                pe_layer.to('cuda')
            pe_layer.eval()
            self.net.ctx['image'].qtransformer.pe_layer = pe_layer
        else:
            self.net.ctx['image'].qtransformer.pe_layer = None
        if pretrained is not None:
            self.load_ctx(pretrained)
        self.tag_ctx = tag
        return tag

    def action_load_diffuser(self, tag):
        pretrained = diffuser_path[tag]
        if pretrained is not None:
            self.load_diffuser(pretrained)
        self.tag_diffuser = tag
        return tag

    def action_load_ctl(self, tag):
        pretrained = controlnet_path[tag][1]
        if pretrained is not None:
            self.load_ctl(pretrained)
        self.tag_ctl = tag
        return tag

    def action_autoset_hw(self, imctl):
        if imctl is None:
            return 512, 512
        w, h = imctl.size
        w = w//64 * 64
        h = h//64 * 64
        w = min(max(w, 512), 1536)
        h = min(max(h, 512), 1536)
        return h, w
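    # For example, a 1000x700 control image snaps down to multiples of 64 and
    # is clamped to [512, 1536], so action_autoset_hw above returns (672, 960)
    # as (height, width) for the two sliders.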
    def action_autoset_method(self, tag):
        return controlnet_path[tag][0]

    def action_inference(
            self, im, imctl, ctl_method, do_preprocess,
            h, w, ugscale, seed,
            tag_ctx, tag_diffuser, tag_ctl,):
        if tag_ctx != self.tag_ctx:
            self.action_load_ctx(tag_ctx)
        if tag_diffuser != self.tag_diffuser:
            self.action_load_diffuser(tag_diffuser)
        if tag_ctl != self.tag_ctl:
            self.action_load_ctl(tag_ctl)

        n_samples = self.n_sample_image
        sampler = self.sampler
        device = self.net.device

        w = w//64 * 64
        h = h//64 * 64
        if imctl is not None:
            imctl = imctl.resize([w, h], Image.Resampling.BICUBIC)

        craw = tvtrans.ToTensor()(im)[None].to(device).to(self.dtype)
        c = self.net.ctx_encode(craw, which='image').repeat(n_samples, 1, 1)
        u = torch.zeros_like(c)

        if tag_ctx in ["SeeCoder-Anime"]:
            u = torch.load('assets/anime_ug.pth')[None].to(device).to(self.dtype)
            pad = c.size(1) - u.size(1)
            u = torch.cat([u, torch.zeros_like(u[:, 0:1].repeat(1, pad, 1))], axis=1)

        if tag_ctl != 'none':
            ccraw = tvtrans.ToTensor()(imctl)[None].to(device).to(self.dtype)
            if do_preprocess:
                cc = self.net.ctl.preprocess(ccraw, type=ctl_method, size=[h, w])
                cc = cc.to(self.dtype)
            else:
                cc = ccraw
        else:
            cc = None

        shape = [n_samples, self.image_latent_dim, h//8, w//8]

        if seed < 0:
            np.random.seed(int(time.time()))
            torch.manual_seed(-seed + 100)
        else:
            np.random.seed(seed + 100)
            torch.manual_seed(seed)
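        # Negative seeds draw a fresh time-based NumPy seed while torch is
        # seeded deterministically from -seed + 100; non-negative seeds make
        # both generators reproducible across runs.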
        x, _ = sampler.sample(
            steps=self.ddim_steps,
            x_info={'type': 'image'},
            c_info={'type': 'image', 'conditioning': c,
                    'unconditional_conditioning': u,
                    'unconditional_guidance_scale': ugscale,
                    'control': cc},
            shape=shape,
            verbose=False,
            eta=self.ddim_eta)

        ccout = [tvtrans.ToPILImage()(i) for i in cc] if cc is not None else []
        imout = self.net.vae_decode(x, which='image')
        imout = [tvtrans.ToPILImage()(i) for i in imout]
        return imout + ccout
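# Headless usage sketch (illustrative, assumes the checkpoints above are in
# place; the image paths are placeholders, not shipped assets):
#   pfd = prompt_free_diffusion(fp16=True, tag_ctx='SeeCoder',
#                               tag_diffuser='Deliberate-v2.0', tag_ctl='canny')
#   outs = pfd.action_inference(Image.open('ref.jpg'), Image.open('edges.png'),
#                               'canny', False, 512, 512, 2.0, 20,
#                               'SeeCoder', 'Deliberate-v2.0', 'canny')
#   outs[0].save('result.png')
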
pfd_inference = prompt_free_diffusion(
    fp16=True, tag_ctx='SeeCoder', tag_diffuser='Deliberate-v2.0', tag_ctl='canny',)
#################
# sub interface #
#################

cache_examples = True

def get_example():
    case = [
        [
            'assets/examples/ghibli-input.jpg',
            'assets/examples/ghibli-canny.png',
            'canny', False,
            768, 1024, 1.8, 23,
            'SeeCoder', 'Deliberate-v2.0', 'canny', ],
        [
            'assets/examples/astronautridinghouse-input.jpg',
            'assets/examples/astronautridinghouse-canny.png',
            'canny', False,
            512, 768, 2.0, 21,
            'SeeCoder', 'Deliberate-v2.0', 'canny', ],
        [
            'assets/examples/grassland-input.jpg',
            'assets/examples/grassland-scribble.png',
            'scribble', False,
            768, 512, 2.0, 41,
            'SeeCoder', 'Deliberate-v2.0', 'scribble', ],
        [
            'assets/examples/jeep-input.jpg',
            'assets/examples/jeep-depth.png',
            'depth', False,
            512, 768, 2.0, 30,
            'SeeCoder', 'Deliberate-v2.0', 'depth', ],
        [
            'assets/examples/bedroom-input.jpg',
            'assets/examples/bedroom-mlsd.png',
            'mlsd', False,
            512, 512, 2.0, 31,
            'SeeCoder', 'Deliberate-v2.0', 'mlsd', ],
        [
            'assets/examples/nightstreet-input.jpg',
            'assets/examples/nightstreet-canny.png',
            'canny', False,
            768, 512, 2.3, 20,
            'SeeCoder', 'Deliberate-v2.0', 'canny', ],
        [
            'assets/examples/woodcar-input.jpg',
            'assets/examples/woodcar-depth.png',
            'depth', False,
            768, 512, 2.0, 20,
            'SeeCoder', 'Deliberate-v2.0', 'depth', ],
        [
            'assets/examples-anime/miku.jpg',
            'assets/examples-anime/miku-canny.png',
            'canny', False,
            768, 576, 1.5, 22,
            'SeeCoder-Anime', 'Anything-v4', 'canny', ],
        [
            'assets/examples-anime/random0.jpg',
            'assets/examples-anime/pose.png',
            'openpose', False,
            768, 1536, 2.0, 41,
            'SeeCoder-Anime', 'Oam-v2', 'openpose_v11p', ],
        [
            'assets/examples-anime/random1.jpg',
            'assets/examples-anime/pose.png',
            'openpose', False,
            768, 1536, 2.5, 28,
            'SeeCoder-Anime', 'Oam-v2', 'openpose_v11p', ],
        [
            'assets/examples-anime/camping.jpg',
            'assets/examples-anime/pose.png',
            'openpose', False,
            768, 1536, 2.0, 35,
            'SeeCoder-Anime', 'Anything-v4', 'openpose_v11p', ],
        [
            'assets/examples-anime/hanfu_girl.jpg',
            'assets/examples-anime/pose.png',
            'openpose', False,
            768, 1536, 2.0, 20,
            'SeeCoder-Anime', 'Anything-v4', 'openpose_v11p', ],
    ]
    return case
def interface():
    with gr.Row():
        with gr.Column():
            img_input = gr.Image(label='Image Input', type='pil', elem_id='customized_imbox')
            with gr.Row():
                out_width  = gr.Slider(label="Width" , minimum=512, maximum=1536, value=512, step=64, visible=True)
                out_height = gr.Slider(label="Height", minimum=512, maximum=1536, value=512, step=64, visible=True)
            with gr.Row():
                scl_lvl = gr.Slider(label="CFGScale", minimum=0, maximum=10, value=2, step=0.01, visible=True)
                seed = gr.Number(20, label="Seed", precision=0)
            with gr.Row():
                tag_ctx = gr.Dropdown(label='Context Encoder', choices=[pi for pi in ctxencoder_path.keys()], value='SeeCoder')
                tag_diffuser = gr.Dropdown(label='Diffuser', choices=[pi for pi in diffuser_path.keys()], value='Deliberate-v2.0')
            button = gr.Button("Run")
        with gr.Column():
            ctl_input = gr.Image(label='Control Input', type='pil', elem_id='customized_imbox')
            do_preprocess = gr.Checkbox(label='Preprocess', value=False)
            with gr.Row():
                ctl_method = gr.Dropdown(label='Preprocess Type', choices=preprocess_method, value='canny')
                tag_ctl = gr.Dropdown(label='ControlNet', choices=[pi for pi in controlnet_path.keys()], value='canny')
        with gr.Column():
            img_output = gr.Gallery(label="Image Result", elem_id='customized_imbox').style(grid=n_sample_image+1)

    tag_ctl.change(
        pfd_inference.action_autoset_method,
        inputs=[tag_ctl],
        outputs=[ctl_method],)
    ctl_input.change(
        pfd_inference.action_autoset_hw,
        inputs=[ctl_input],
        outputs=[out_height, out_width],)
    # tag_ctx.change(
    #     pfd_inference.action_load_ctx,
    #     inputs=[tag_ctx],
    #     outputs=[tag_ctx],)
    # tag_diffuser.change(
    #     pfd_inference.action_load_diffuser,
    #     inputs=[tag_diffuser],
    #     outputs=[tag_diffuser],)
    # tag_ctl.change(
    #     pfd_inference.action_load_ctl,
    #     inputs=[tag_ctl],
    #     outputs=[tag_ctl],)
    button.click(
        pfd_inference.action_inference,
        inputs=[img_input, ctl_input, ctl_method, do_preprocess,
                out_height, out_width, scl_lvl, seed,
                tag_ctx, tag_diffuser, tag_ctl, ],
        outputs=[img_output])
    gr.Examples(
        label='Examples',
        examples=get_example(),
        fn=pfd_inference.action_inference,
        inputs=[img_input, ctl_input, ctl_method, do_preprocess,
                out_height, out_width, scl_lvl, seed,
                tag_ctx, tag_diffuser, tag_ctl, ],
        outputs=[img_output],
        cache_examples=cache_examples,)
#############
# Interface #
#############

css = """
#customized_imbox {
    min-height: 450px;
}
#customized_imbox>div[data-testid="image"] {
    min-height: 450px;
}
#customized_imbox>div[data-testid="image"]>div {
    min-height: 450px;
}
#customized_imbox>div[data-testid="image"]>iframe {
    min-height: 450px;
}
#customized_imbox>div.unpadded_box {
    min-height: 450px;
}
#myinst {
    font-size: 0.8rem;
    margin: 0rem;
    color: #6B7280;
}
#maskinst {
    text-align: justify;
    min-width: 1200px;
}
#maskinst>img {
    min-width: 399px;
    max-width: 450px;
    vertical-align: top;
    display: inline-block;
}
#maskinst:after {
    content: "";
    width: 100%;
    display: inline-block;
}
"""

with gr.Blocks(css=css) as demo:
    gr.HTML(
        """
        <div style="text-align: center; max-width: 1200px; margin: 20px auto;">
        <h1 style="font-weight: 900; font-size: 3rem; margin: 0rem">
            Prompt-Free Diffusion
        </h1>
        </div>
        """)
    interface()
    # gr.HTML(
    #     """
    #     <div style="text-align: justify; max-width: 1200px; margin: 20px auto;">
    #     <h3 style="font-weight: 450; font-size: 0.8rem; margin: 0rem">
    #         <b>Version</b>: {}
    #     </h3>
    #     </div>
    #     """.format(' '+str(pfd_inference.pretrained)))

# demo.launch(server_name="0.0.0.0", server_port=7992)
# demo.launch()
demo.launch(debug=True)
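# demo.launch(debug=True) blocks and prints tracebacks to the console; for a
# shareable public link, Gradio also supports demo.launch(share=True).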