################################################################################ # Copyright (C) 2023 Xingqian Xu - All Rights Reserved # # # # Please visit Prompt-Free-Diffusion's arXiv paper for more details, link at # # arxiv.org/abs/2305.16223 # # # ################################################################################ import gradio as gr import os.path as osp from PIL import Image import numpy as np import time import torch import torchvision.transforms as tvtrans from lib.cfg_helper import model_cfg_bank from lib.model_zoo import get_model from collections import OrderedDict from lib.model_zoo.ddim import DDIMSampler n_sample_image = 1 controlnet_path = OrderedDict([ ['canny' , ('canny' , 'pretrained/controlnet/control_sd15_canny_slimmed.safetensors')], ['canny_v11p' , ('canny' , 'pretrained/controlnet/control_v11p_sd15_canny_slimmed.safetensors')], ['depth' , ('depth' , 'pretrained/controlnet/control_sd15_depth_slimmed.safetensors')], ['hed' , ('hed' , 'pretrained/controlnet/control_sd15_hed_slimmed.safetensors')], ['mlsd' , ('mlsd' , 'pretrained/controlnet/control_sd15_mlsd_slimmed.safetensors')], ['mlsd_v11p' , ('mlsd' , 'pretrained/controlnet/control_v11p_sd15_mlsd_slimmed.safetensors')], ['normal' , ('normal' , 'pretrained/controlnet/control_sd15_normal_slimmed.safetensors')], ['openpose' , ('openpose', 'pretrained/controlnet/control_sd15_openpose_slimmed.safetensors')], ['openpose_v11p' , ('openpose', 'pretrained/controlnet/control_v11p_sd15_openpose_slimmed.safetensors')], ['scribble' , ('scribble', 'pretrained/controlnet/control_sd15_scribble_slimmed.safetensors')], ['softedge_v11p' , ('scribble', 'pretrained/controlnet/control_v11p_sd15_softedge_slimmed.safetensors')], ['seg' , ('none' , 'pretrained/controlnet/control_sd15_seg_slimmed.safetensors')], ['lineart_v11p' , ('none' , 'pretrained/controlnet/control_v11p_sd15_lineart_slimmed.safetensors')], ['lineart_anime_v11p', ('none' , 'pretrained/controlnet/control_v11p_sd15s2_lineart_anime_slimmed.safetensors')], ]) preprocess_method = [ 'canny' , 'depth' , 'hed' , 'mlsd' , 'normal' , 'openpose' , 'openpose_withface' , 'openpose_withfacehand', 'scribble' , 'none' , ] diffuser_path = OrderedDict([ ['SD-v1.5' , 'pretrained/pfd/diffuser/SD-v1-5.safetensors'], ['OpenJouney-v4' , 'pretrained/pfd/diffuser/OpenJouney-v4.safetensors'], ['Deliberate-v2.0' , 'pretrained/pfd/diffuser/Deliberate-v2-0.safetensors'], ['RealisticVision-v2.0', 'pretrained/pfd/diffuser/RealisticVision-v2-0.safetensors'], ['Anything-v4' , 'pretrained/pfd/diffuser/Anything-v4.safetensors'], ['Oam-v3' , 'pretrained/pfd/diffuser/AbyssOrangeMix-v3.safetensors'], ['Oam-v2' , 'pretrained/pfd/diffuser/AbyssOrangeMix-v2.safetensors'], ]) ctxencoder_path = OrderedDict([ ['SeeCoder' , 'pretrained/pfd/seecoder/seecoder-v1-0.safetensors'], ['SeeCoder-PA' , 'pretrained/pfd/seecoder/seecoder-pa-v1-0.safetensors'], ['SeeCoder-Anime', 'pretrained/pfd/seecoder/seecoder-anime-v1-0.safetensors'], ]) ########## # helper # ########## def highlight_print(info): print('') print(''.join(['#']*(len(info)+4))) print('# '+info+' #') print(''.join(['#']*(len(info)+4))) print('') def load_sd_from_file(target): if osp.splitext(target)[-1] == '.ckpt': sd = torch.load(target, map_location='cpu')['state_dict'] elif osp.splitext(target)[-1] == '.pth': sd = torch.load(target, map_location='cpu') elif osp.splitext(target)[-1] == '.safetensors': from safetensors.torch import load_file as stload sd = OrderedDict(stload(target, device='cpu')) else: assert False, "File type must be .ckpt or .pth or .safetensors" return sd ######## # main # ######## class prompt_free_diffusion(object): def __init__(self, fp16=False, tag_ctx=None, tag_diffuser=None, tag_ctl=None,): self.tag_ctx = tag_ctx self.tag_diffuser = tag_diffuser self.tag_ctl = tag_ctl self.strict_sd = True cfgm = model_cfg_bank()('pfd_seecoder_with_controlnet') self.net = get_model()(cfgm) self.action_load_ctx(tag_ctx) self.action_load_diffuser(tag_diffuser) self.action_load_ctl(tag_ctl) if fp16: highlight_print('Running in FP16') self.net.ctx['image'].fp16 = True self.net = self.net.half() self.dtype = torch.float16 else: self.dtype = torch.float32 self.use_cuda = torch.cuda.is_available() if self.use_cuda: self.net.to('cuda') self.net.eval() self.sampler = DDIMSampler(self.net) self.n_sample_image = n_sample_image self.ddim_steps = 50 self.ddim_eta = 0.0 self.image_latent_dim = 4 def load_ctx(self, pretrained): sd = load_sd_from_file(pretrained) sd_extra = [(ki, vi) for ki, vi in self.net.state_dict().items() \ if ki.find('ctx.')!=0] sd.update(OrderedDict(sd_extra)) self.net.load_state_dict(sd, strict=True) print('Load context encoder from [{}] strict [{}].'.format(pretrained, True)) def load_diffuser(self, pretrained): sd = load_sd_from_file(pretrained) if len([ki for ki in sd.keys() if ki.find('diffuser.image.context_blocks.')==0]) == 0: sd = [( ki.replace('diffuser.text.context_blocks.', 'diffuser.image.context_blocks.'), vi) for ki, vi in sd.items()] sd = OrderedDict(sd) sd_extra = [(ki, vi) for ki, vi in self.net.state_dict().items() \ if ki.find('diffuser.')!=0] sd.update(OrderedDict(sd_extra)) self.net.load_state_dict(sd, strict=True) print('Load diffuser from [{}] strict [{}].'.format(pretrained, True)) def load_ctl(self, pretrained): sd = load_sd_from_file(pretrained) self.net.ctl.load_state_dict(sd, strict=True) print('Load controlnet from [{}] strict [{}].'.format(pretrained, True)) def action_load_ctx(self, tag): pretrained = ctxencoder_path[tag] if tag == 'SeeCoder-PA': from lib.model_zoo.seecoder import PPE_MLP pe_layer = \ PPE_MLP(freq_num=20, freq_max=None, out_channel=768, mlp_layer=3) if self.dtype == torch.float16: pe_layer = pe_layer.half() if self.use_cuda: pe_layer.to('cuda') pe_layer.eval() self.net.ctx['image'].qtransformer.pe_layer = pe_layer else: self.net.ctx['image'].qtransformer.pe_layer = None if pretrained is not None: self.load_ctx(pretrained) self.tag_ctx = tag return tag def action_load_diffuser(self, tag): pretrained = diffuser_path[tag] if pretrained is not None: self.load_diffuser(pretrained) self.tag_diffuser = tag return tag def action_load_ctl(self, tag): pretrained = controlnet_path[tag][1] if pretrained is not None: self.load_ctl(pretrained) self.tag_ctl = tag return tag def action_autoset_hw(self, imctl): if imctl is None: return 512, 512 w, h = imctl.size w = w//64 * 64 h = h//64 * 64 w = w if w >=512 else 512 w = w if w <=1536 else 1536 h = h if h >=512 else 512 h = h if h <=1536 else 1536 return h, w def action_autoset_method(self, tag): return controlnet_path[tag][0] def action_inference( self, im, imctl, ctl_method, do_preprocess, h, w, ugscale, seed, tag_ctx, tag_diffuser, tag_ctl,): if tag_ctx != self.tag_ctx: self.action_load_ctx(tag_ctx) if tag_diffuser != self.tag_diffuser: self.action_load_diffuser(tag_diffuser) if tag_ctl != self.tag_ctl: self.action_load_ctl(tag_ctl) n_samples = self.n_sample_image sampler = self.sampler device = self.net.device w = w//64 * 64 h = h//64 * 64 if imctl is not None: imctl = imctl.resize([w, h], Image.Resampling.BICUBIC) craw = tvtrans.ToTensor()(im)[None].to(device).to(self.dtype) c = self.net.ctx_encode(craw, which='image').repeat(n_samples, 1, 1) u = torch.zeros_like(c) if tag_ctx in ["SeeCoder-Anime"]: u = torch.load('assets/anime_ug.pth')[None].to(device).to(self.dtype) pad = c.size(1) - u.size(1) u = torch.cat([u, torch.zeros_like(u[:, 0:1].repeat(1, pad, 1))], axis=1) if tag_ctl != 'none': ccraw = tvtrans.ToTensor()(imctl)[None].to(device).to(self.dtype) if do_preprocess: cc = self.net.ctl.preprocess(ccraw, type=ctl_method, size=[h, w]) cc = cc.to(self.dtype) else: cc = ccraw else: cc = None shape = [n_samples, self.image_latent_dim, h//8, w//8] if seed < 0: np.random.seed(int(time.time())) torch.manual_seed(-seed + 100) else: np.random.seed(seed + 100) torch.manual_seed(seed) x, _ = sampler.sample( steps=self.ddim_steps, x_info={'type':'image',}, c_info={'type':'image', 'conditioning':c, 'unconditional_conditioning':u, 'unconditional_guidance_scale':ugscale, 'control':cc,}, shape=shape, verbose=False, eta=self.ddim_eta) ccout = [tvtrans.ToPILImage()(i) for i in cc] if cc is not None else [] imout = self.net.vae_decode(x, which='image') imout = [tvtrans.ToPILImage()(i) for i in imout] return imout + ccout pfd_inference = prompt_free_diffusion( fp16=True, tag_ctx = 'SeeCoder', tag_diffuser = 'Deliberate-v2.0', tag_ctl = 'canny',) ################# # sub interface # ################# cache_examples = True def get_example(): case = [ [ 'assets/examples/ghibli-input.jpg', 'assets/examples/ghibli-canny.png', 'canny', False, 768, 1024, 1.8, 23, 'SeeCoder', 'Deliberate-v2.0', 'canny', ], [ 'assets/examples/astronautridinghouse-input.jpg', 'assets/examples/astronautridinghouse-canny.png', 'canny', False, 512, 768, 2.0, 21, 'SeeCoder', 'Deliberate-v2.0', 'canny', ], [ 'assets/examples/grassland-input.jpg', 'assets/examples/grassland-scribble.png', 'scribble', False, 768, 512, 2.0, 41, 'SeeCoder', 'Deliberate-v2.0', 'scribble', ], [ 'assets/examples/jeep-input.jpg', 'assets/examples/jeep-depth.png', 'depth', False, 512, 768, 2.0, 30, 'SeeCoder', 'Deliberate-v2.0', 'depth', ], [ 'assets/examples/bedroom-input.jpg', 'assets/examples/bedroom-mlsd.png', 'mlsd', False, 512, 512, 2.0, 31, 'SeeCoder', 'Deliberate-v2.0', 'mlsd', ], [ 'assets/examples/nightstreet-input.jpg', 'assets/examples/nightstreet-canny.png', 'canny', False, 768, 512, 2.3, 20, 'SeeCoder', 'Deliberate-v2.0', 'canny', ], [ 'assets/examples/woodcar-input.jpg', 'assets/examples/woodcar-depth.png', 'depth', False, 768, 512, 2.0, 20, 'SeeCoder', 'Deliberate-v2.0', 'depth', ], [ 'assets/examples-anime/miku.jpg', 'assets/examples-anime/miku-canny.png', 'canny', False, 768, 576, 1.5, 22, 'SeeCoder-Anime', 'Anything-v4', 'canny', ], [ 'assets/examples-anime/random0.jpg', 'assets/examples-anime/pose.png', 'openpose', False, 768, 1536, 2.0, 41, 'SeeCoder-Anime', 'Oam-v2', 'openpose_v11p', ], [ 'assets/examples-anime/random1.jpg', 'assets/examples-anime/pose.png', 'openpose', False, 768, 1536, 2.5, 28, 'SeeCoder-Anime', 'Oam-v2', 'openpose_v11p', ], [ 'assets/examples-anime/camping.jpg', 'assets/examples-anime/pose.png', 'openpose', False, 768, 1536, 2.0, 35, 'SeeCoder-Anime', 'Anything-v4', 'openpose_v11p', ], [ 'assets/examples-anime/hanfu_girl.jpg', 'assets/examples-anime/pose.png', 'openpose', False, 768, 1536, 2.0, 20, 'SeeCoder-Anime', 'Anything-v4', 'openpose_v11p', ], ] return case def interface(): with gr.Row(): with gr.Column(): img_input = gr.Image(label='Image Input', type='pil', elem_id='customized_imbox') with gr.Row(): out_width = gr.Slider(label="Width" , minimum=512, maximum=1536, value=512, step=64, visible=True) out_height = gr.Slider(label="Height", minimum=512, maximum=1536, value=512, step=64, visible=True) with gr.Row(): scl_lvl = gr.Slider(label="CFGScale", minimum=0, maximum=10, value=2, step=0.01, visible=True) seed = gr.Number(20, label="Seed", precision=0) with gr.Row(): tag_ctx = gr.Dropdown(label='Context Encoder', choices=[pi for pi in ctxencoder_path.keys()], value='SeeCoder') tag_diffuser = gr.Dropdown(label='Diffuser', choices=[pi for pi in diffuser_path.keys()], value='Deliberate-v2.0') button = gr.Button("Run") with gr.Column(): ctl_input = gr.Image(label='Control Input', type='pil', elem_id='customized_imbox') do_preprocess = gr.Checkbox(label='Preprocess', value=False) with gr.Row(): ctl_method = gr.Dropdown(label='Preprocess Type', choices=preprocess_method, value='canny') tag_ctl = gr.Dropdown(label='ControlNet', choices=[pi for pi in controlnet_path.keys()], value='canny') with gr.Column(): img_output = gr.Gallery(label="Image Result", elem_id='customized_imbox').style(grid=n_sample_image+1) tag_ctl.change( pfd_inference.action_autoset_method, inputs = [tag_ctl], outputs = [ctl_method],) ctl_input.change( pfd_inference.action_autoset_hw, inputs = [ctl_input], outputs = [out_height, out_width],) # tag_ctx.change( # pfd_inference.action_load_ctx, # inputs = [tag_ctx], # outputs = [tag_ctx],) # tag_diffuser.change( # pfd_inference.action_load_diffuser, # inputs = [tag_diffuser], # outputs = [tag_diffuser],) # tag_ctl.change( # pfd_inference.action_load_ctl, # inputs = [tag_ctl], # outputs = [tag_ctl],) button.click( pfd_inference.action_inference, inputs=[img_input, ctl_input, ctl_method, do_preprocess, out_height, out_width, scl_lvl, seed, tag_ctx, tag_diffuser, tag_ctl, ], outputs=[img_output]) gr.Examples( label='Examples', examples=get_example(), fn=pfd_inference.action_inference, inputs=[img_input, ctl_input, ctl_method, do_preprocess, out_height, out_width, scl_lvl, seed, tag_ctx, tag_diffuser, tag_ctl, ], outputs=[img_output], cache_examples=cache_examples,) ############# # Interface # ############# css = """ #customized_imbox { min-height: 450px; } #customized_imbox>div[data-testid="image"] { min-height: 450px; } #customized_imbox>div[data-testid="image"]>div { min-height: 450px; } #customized_imbox>div[data-testid="image"]>iframe { min-height: 450px; } #customized_imbox>div.unpadded_box { min-height: 450px; } #myinst { font-size: 0.8rem; margin: 0rem; color: #6B7280; } #maskinst { text-align: justify; min-width: 1200px; } #maskinst>img { min-width:399px; max-width:450px; vertical-align: top; display: inline-block; } #maskinst:after { content: ""; width: 100%; display: inline-block; } """ if True: with gr.Blocks(css=css) as demo: gr.HTML( """

Prompt-Free Diffusion

""") interface() # gr.HTML( # """ #
#

# Version: {} #

#
# """.format(' '+str(pfd_inference.pretrained))) # demo.launch(server_name="0.0.0.0", server_port=7992) # demo.launch() demo.launch(debug=True)