# NOTE(review): import order is kept exactly as in the original file. `spaces`
# (HF ZeroGPU helper) is imported before torch/diffusers here; do not reorder
# without checking whether that matters for the deployment.
import gradio as gr
from examples.story_examples import get_examples
import spaces
import numpy as np
import torch
import random
import os
import torch.nn.functional as F
from diffusers import StableDiffusionXLPipeline, DDIMScheduler
import copy
from huggingface_hub import hf_hub_download
from diffusers.utils import load_image
from storyDiffusion.utils.gradio_utils import AttnProcessor2_0 as AttnProcessor, cal_attn_mask_xl
from storyDiffusion.utils import PhotoMakerStableDiffusionXLPipeline
from storyDiffusion.utils.utils import get_comic
from storyDiffusion.utils.style_template import styles

# Constants
image_encoder_path = "./data/models/ip_adapter/sdxl_models/image_encoder"
ip_ckpt = "./data/models/ip_adapter/sdxl_models/ip-adapter_sdxl_vit-h.bin"
os.environ["no_proxy"] = "localhost,127.0.0.1,::1"
STYLE_NAMES = list(styles.keys())
DEFAULT_STYLE_NAME = "Japanese Anime"
MAX_SEED = np.iinfo(np.int32).max

# Module-level state shared between `process_generation` and the paired
# attention processors (step counters, write/read phase, masks, image size).
models_dict = {
    "RealVision": "SG161222/RealVisXL_V4.0",
    "Unstable": "stablediffusionapi/sdxl-unstable-diffusers-y",
}
use_va = True
photomaker_path = hf_hub_download(
    repo_id="TencentARC/PhotoMaker", filename="photomaker-v1.bin", repo_type="model")
device = "cuda"


def setup_seed(seed):
    """Seed torch (CPU and CUDA), NumPy and `random`, and make cuDNN deterministic.

    Args:
        seed: Non-negative integer seed (NumPy rejects negative seeds).
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True


def set_text_unfinished():
    """Show the "generation in progress" banner in the status Markdown."""
    return gr.update(
        visible=True,
        value="<h3>(Not Finished) Generating ··· The intermediate results will be shown.</h3>")


def set_text_finished():
    """Show the "generation finished" banner in the status Markdown."""
    return gr.update(visible=True, value="<h3>Generation Finished</h3>")


class SpatialAttnProcessor2_0(torch.nn.Module):
    r"""
    Paired (consistent) self-attention processor for PyTorch 2.0.

    While the global ``write`` flag is True, the processor caches, per denoising
    step, the hidden states of the identity batch in ``id_bank`` (split at
    ``id_length``). While ``write`` is False, the cached identity states are
    concatenated with the current frame's hidden states and used as
    ``encoder_hidden_states``, so the frame can attend to the identity images.

    Args:
        hidden_size (`int`, *optional*): Stored only; not used by this processor.
        cross_attention_dim (`int`, *optional*): Stored only; not used by this processor.
        id_length (`int`, defaults to 4): Number of identity images in the write phase.
        device (`str`, defaults to "cuda"): Device cached identity states are moved to
            and on which attention masks are built.
        dtype (`torch.dtype`, defaults to `torch.float16`): dtype of the attention masks.
    """

    def __init__(self, hidden_size=None, cross_attention_dim=None, id_length=4,
                 device="cuda", dtype=torch.float16):
        super().__init__()
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
        self.device = device
        self.dtype = dtype
        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.total_length = id_length + 1
        self.id_length = id_length
        self.id_bank = {}

    def __call__(self, attn, hidden_states, encoder_hidden_states=None,
                 attention_mask=None, temb=None):
        """Run paired attention for one layer and advance the shared step counter.

        Reads/writes module globals: ``cur_step``, ``attn_count`` and
        ``total_count`` track progress through the UNet; ``mask1024`` /
        ``mask4096`` are rebuilt once every registered processor has run for
        the current step.
        """
        global total_count, attn_count, cur_step, mask1024, mask4096
        global sa32, sa64
        global write
        global height, width
        global num_steps

        if write:
            # Cache the two halves of the identity batch for this step.
            self.id_bank[cur_step] = [
                hidden_states[:self.id_length], hidden_states[self.id_length:]]
        else:
            # Prepend the cached identity states to each half of the current batch.
            encoder_hidden_states = torch.cat((
                self.id_bank[cur_step][0].to(self.device),
                hidden_states[:1],
                self.id_bank[cur_step][1].to(self.device),
                hidden_states[1:],
            ))

        if cur_step <= 1:
            # The first two steps use plain self-attention.
            hidden_states = self.__call2__(
                attn, hidden_states, None, attention_mask, temb)
        else:
            # Apply paired attention stochastically; more often early in sampling.
            random_number = random.random()
            rand_num = 0.3 if cur_step < 0.4 * num_steps else 0.1
            if random_number > rand_num:
                # Pick the mask matching this layer's token count (H/32*W/32 vs. larger).
                if hidden_states.shape[1] == (height // 32) * (width // 32):
                    mask = mask1024
                else:
                    mask = mask4096
                split = mask.shape[0] // self.total_length * self.id_length
                if not write:
                    attention_mask = mask[split:]
                else:
                    attention_mask = mask[:split, :split]
                hidden_states = self.__call1__(
                    attn, hidden_states, encoder_hidden_states, attention_mask, temb)
            else:
                hidden_states = self.__call2__(
                    attn, hidden_states, None, attention_mask, temb)

        attn_count += 1
        if attn_count == total_count:
            # Every paired processor has run once: move to the next denoising step.
            attn_count = 0
            cur_step += 1
            mask1024, mask4096 = cal_attn_mask_xl(
                self.total_length, self.id_length, sa32, sa64, height, width,
                device=self.device, dtype=self.dtype)

        return hidden_states

    def __call1__(self, attn, hidden_states, encoder_hidden_states=None,
                  attention_mask=None, temb=None):
        """Masked attention where all images of one batch half share one token sequence."""
        residual = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)
        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            total_batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(
                total_batch_size, channel, height * width).transpose(1, 2)
        total_batch_size, nums_token, channel = hidden_states.shape
        # Merge the images of each batch half into one long token sequence.
        img_nums = total_batch_size // 2
        hidden_states = hidden_states.view(-1, img_nums, nums_token, channel).reshape(
            -1, img_nums * nums_token, channel)

        batch_size, sequence_length, _ = hidden_states.shape

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(
                hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        else:
            encoder_hidden_states = encoder_hidden_states.view(
                -1, self.id_length + 1, nums_token, channel).reshape(
                -1, (self.id_length + 1) * nums_token, channel)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # Output of SDPA is (batch, num_heads, seq_len, head_dim).
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(
            total_batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(
                total_batch_size, channel, height, width)
        if attn.residual_connection:
            hidden_states = hidden_states + residual
        hidden_states = hidden_states / attn.rescale_output_factor
        return hidden_states

    def __call2__(self, attn, hidden_states, encoder_hidden_states=None,
                  attention_mask=None, temb=None):
        """Standard per-image scaled-dot-product self-attention."""
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(
                batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, channel = hidden_states.shape

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(
                attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(
                batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(
                hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        else:
            encoder_hidden_states = encoder_hidden_states.view(
                -1, self.id_length + 1, sequence_length, channel).reshape(
                -1, (self.id_length + 1) * sequence_length, channel)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # Output of SDPA is (batch, num_heads, seq_len, head_dim).
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(
            batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(
                batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


def set_attention_processor(unet, id_length, is_ipadapter=False):
    """Install paired self-attention on the UNet's up-block self-attention layers.

    Every ``attn1`` (self-attention) processor under ``up_blocks`` becomes a
    ``SpatialAttnProcessor2_0``; all other processors become ``AttnProcessor``.
    Updates the global ``total_count`` to the number of paired processors.

    Args:
        unet: UNet whose ``attn_processors`` are replaced.
        id_length: Number of identity images passed to each paired processor.
        is_ipadapter: Accepted for API compatibility; currently unused.
    """
    global total_count
    total_count = 0
    attn_procs = {}
    for name in unet.attn_processors.keys():
        is_self_attention = name.endswith("attn1.processor")
        if is_self_attention and name.startswith("up_blocks"):
            attn_procs[name] = SpatialAttnProcessor2_0(id_length=id_length)
            total_count += 1
        else:
            attn_procs[name] = AttnProcessor()
    unet.set_attn_processor(copy.deepcopy(attn_procs))
    print("Successfully loaded paired self-attention")
    print(f"Number of processors: {total_count}")


attn_count = 0
total_count = 0
cur_step = 0
id_length = 4
total_length = 5
cur_model_type = ""
device = "cuda"
attn_procs = {}
write = False
sa32 = 0.5
sa64 = 0.5
height = 768
width = 768
num_steps = 50  # overwritten per request by process_generation
mask1024 = None  # rebuilt by SpatialAttnProcessor2_0 at the end of every step
mask4096 = None


def swap_to_gallery(images):
    """Show the uploaded images in the gallery and hide the file picker."""
    return gr.update(value=images, visible=True), gr.update(visible=True), gr.update(visible=False)


def upload_example_to_gallery(images, prompt, style, negative_prompt):
    """Like `swap_to_gallery`, with the extra example fields ignored."""
    return gr.update(value=images, visible=True), gr.update(visible=True), gr.update(visible=False)


def remove_back_to_files():
    """Hide the gallery / clear button and show the file picker again."""
    return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)


def remove_tips():
    """Hide a component."""
    return gr.update(visible=False)


def apply_style_positive(style_name: str, positive: str):
    """Insert `positive` into the positive template of `style_name`."""
    p, _ = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
    return p.replace("{prompt}", positive)


def apply_style(style_name: str, positives: list, negative: str = ""):
    """Apply a style to several prompts and append its negative prompt to `negative`."""
    p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
    return [p.replace("{prompt}", positive) for positive in positives], n + ' ' + negative


def change_visiale_by_model_type(_model_type):
    """Toggle the reference-image controls for the selected model type.

    Raises:
        ValueError: If `_model_type` is not one of the two radio choices.
    """
    if _model_type == "Only Using Textual Description":
        return gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
    elif _model_type == "Using Ref Images":
        return gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)
    else:
        raise ValueError("Invalid model type", _model_type)


@spaces.GPU(duration=120)
def process_generation(_sd_type, _model_type, _upload_images, _num_steps, style_name,
                       _Ip_Adapter_Strength, _style_strength_ratio, guidance_scale, seed_,
                       sa32_, sa64_, id_length_, general_prompt, negative_prompt,
                       prompt_array, G_height, G_width, _comic_type):
    """Generate a consistent-character story; yields the growing result list.

    The first ``id_length_`` prompts are generated together as identity images
    (write phase). Each remaining prompt is then generated alone while
    attending to the cached identity states (read phase). Optionally, the
    frames are finally laid out as a comic.

    Raises:
        gr.Error: On missing trigger word / images, or too many / too few prompts.
    """
    global sa32, sa64, id_length, total_length, attn_procs, unet, cur_model_type, device, num_steps, write, cur_step, attn_count, height, width, pipe2, pipe4, sd_model_path, models_dict

    # Gradio sliders may deliver floats; slicing, seeding and sizes need ints.
    _num_steps = int(_num_steps)
    id_length_ = int(id_length_)
    G_height = int(G_height)
    G_width = int(G_width)
    seed_ = int(seed_)
    if seed_ < 0:
        # The seed slider allows -1, but np.random.seed rejects negative seeds;
        # treat a negative seed as "pick a random one".
        seed_ = random.randint(0, MAX_SEED)

    _model_type = "Photomaker" if _model_type == "Using Ref Images" else "original"
    if _model_type == "Photomaker" and "img" not in general_prompt:
        raise gr.Error(
            "Please add the trigger word 'img' behind the class word you want to customize, such as: man img or woman img")
    if not _upload_images and _model_type != "original":
        raise gr.Error("Cannot find any input face image!")
    prompt_lines = prompt_array.splitlines()
    if len(prompt_lines) > 10:
        raise gr.Error(
            f"No more than 10 prompts in Hugging Face demo for speed! But found {len(prompt_lines)} prompts!")
    if len(prompt_lines) < id_length_:
        # The paired attention reshapes assume exactly id_length_ identity images.
        raise gr.Error(
            f"Please provide at least {id_length_} prompts (one per line): the first {id_length_} frames are used as identity images. But found {len(prompt_lines)} prompts!")

    height = G_height
    width = G_width
    sd_model_path = models_dict[_sd_type]
    num_steps = _num_steps
    if style_name == "(No style)":
        sd_model_path = models_dict["RealVision"]

    if _model_type == "original":
        pipe = StableDiffusionXLPipeline.from_pretrained(
            sd_model_path, torch_dtype=torch.float16)
        pipe = pipe.to(device)
        pipe.enable_freeu(s1=0.6, s2=0.4, b1=1.1, b2=1.2)
        set_attention_processor(pipe.unet, id_length_, is_ipadapter=False)
    elif _model_type == "Photomaker":
        # Reuse a preloaded PhotoMaker pipeline (kept on CPU between requests).
        if _sd_type != "RealVision" and style_name != "(No style)":
            pipe = pipe2.to(device)
        else:
            pipe = pipe4.to(device)
        pipe.id_encoder.to(device)
        set_attention_processor(pipe.unet, id_length_, is_ipadapter=False)
    else:
        raise NotImplementedError(
            "You should choose between original and Photomaker!", f"But you chose {_model_type}")

    pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
    pipe.enable_freeu(s1=0.6, s2=0.4, b1=1.1, b2=1.2)
    cur_model_type = _sd_type + "-" + _model_type + str(id_length_)

    if _model_type != "original":
        input_id_images = [load_image(img) for img in _upload_images]

    start_merge_step = min(int(float(_style_strength_ratio) / 100 * _num_steps), 30)
    print(f"start_merge_step: {start_merge_step}")

    generator = torch.Generator(device="cuda").manual_seed(seed_)
    sa32, sa64 = sa32_, sa64_
    id_length = id_length_

    # "[NC]" marks a frame without the character; text after "#" is a caption only.
    prompts = [general_prompt + "," + prompt if "[NC]" not in prompt else prompt.replace("[NC]", "")
               for prompt in prompt_lines]
    prompts = [prompt.rpartition('#')[0] if "#" in prompt else prompt for prompt in prompts]
    print(prompts)
    id_prompts = prompts[:id_length]
    real_prompts = prompts[id_length:]

    torch.cuda.empty_cache()
    write = True
    cur_step = 0
    attn_count = 0
    id_prompts, negative_prompt = apply_style(style_name, id_prompts, negative_prompt)
    setup_seed(seed_)
    total_results = []

    if _model_type == "original":
        id_images = pipe(id_prompts, num_inference_steps=_num_steps, guidance_scale=guidance_scale,
                         height=height, width=width, negative_prompt=negative_prompt,
                         generator=generator).images
    elif _model_type == "Photomaker":
        id_images = pipe(id_prompts, input_id_images=input_id_images, num_inference_steps=_num_steps,
                         guidance_scale=guidance_scale, start_merge_step=start_merge_step,
                         height=height, width=width, negative_prompt=negative_prompt,
                         generator=generator).images
    else:
        raise NotImplementedError(
            "You should choose between original and Photomaker!", f"But you chose {_model_type}")
    total_results = id_images + total_results
    yield total_results

    real_images = []
    write = False
    for real_prompt in real_prompts:
        setup_seed(seed_)
        # Reset both counters so every frame starts from a clean step state.
        cur_step = 0
        attn_count = 0
        real_prompt = apply_style_positive(style_name, real_prompt)
        if _model_type == "original":
            real_images.append(pipe(real_prompt, num_inference_steps=_num_steps,
                                    guidance_scale=guidance_scale, height=height, width=width,
                                    negative_prompt=negative_prompt,
                                    generator=generator).images[0])
        elif _model_type == "Photomaker":
            real_images.append(pipe(real_prompt, input_id_images=input_id_images,
                                    num_inference_steps=_num_steps, guidance_scale=guidance_scale,
                                    start_merge_step=start_merge_step, height=height, width=width,
                                    negative_prompt=negative_prompt,
                                    generator=generator).images[0])
        else:
            raise NotImplementedError(
                "You should choose between original and Photomaker!", f"But you chose {_model_type}")
        total_results = [real_images[-1]] + total_results
        yield total_results

    if _comic_type != "No typesetting (default)":
        from PIL import ImageFont
        captions = [caption.replace("[NC]", "") for caption in prompt_array.splitlines()]
        captions = [caption.split('#')[-1] if "#" in caption else caption for caption in captions]
        font = ImageFont.truetype("./storyDiffusion/fonts/Inkfree.ttf", 45)
        total_results = get_comic(id_images + real_images, _comic_type,
                                  captions=captions, font=font) + total_results

    if _model_type == "Photomaker":
        # Move whichever preloaded pipeline was used (pipe2 or pipe4) back to CPU.
        pipe = pipe.to("cpu")
        pipe.id_encoder.to("cpu")
        set_attention_processor(pipe.unet, id_length_, is_ipadapter=False)
    yield total_results


# Initialize pipelines (kept on CPU; moved to the GPU per request).
pipe2 = PhotoMakerStableDiffusionXLPipeline.from_pretrained(
    models_dict["Unstable"], torch_dtype=torch.float16, use_safetensors=False)
pipe2 = pipe2.to("cpu")
pipe2.load_photomaker_adapter(
    os.path.dirname(photomaker_path),
    subfolder="",
    weight_name=os.path.basename(photomaker_path),
    trigger_word="img",
)
pipe2 = pipe2.to("cpu")
pipe2.enable_freeu(s1=0.6, s2=0.4, b1=1.1, b2=1.2)
pipe2.fuse_lora()

pipe4 = PhotoMakerStableDiffusionXLPipeline.from_pretrained(
    models_dict["RealVision"], torch_dtype=torch.float16, use_safetensors=True)
pipe4 = pipe4.to("cpu")
pipe4.load_photomaker_adapter(
    os.path.dirname(photomaker_path),
    subfolder="",
    weight_name=os.path.basename(photomaker_path),
    trigger_word="img",
)
pipe4 = pipe4.to("cpu")
pipe4.enable_freeu(s1=0.6, s2=0.4, b1=1.1, b2=1.2)
pipe4.fuse_lora()


def story_generation_ui():
    """Build the story-generation UI and wire its events.

    Must be called inside an enclosing ``gr.Blocks`` context.
    """
    with gr.Row():
        with gr.Group(elem_id="main-image"):
            with gr.Column(visible=True) as gen_prompt_vis:
                sd_type = gr.Dropdown(choices=list(models_dict.keys()), value="Unstable",
                                      label="sd_type", info="Select pretrained model")
                model_type = gr.Radio(["Only Using Textual Description", "Using Ref Images"],
                                      label="model_type", value="Only Using Textual Description",
                                      info="Control type of the Character")
                with gr.Group(visible=False) as control_image_input:
                    files = gr.Files(
                        label="Drag (Select) 1 or more photos of your face",
                        file_types=["image"],
                    )
                    uploaded_files = gr.Gallery(
                        label="Your images", visible=False, columns=5, rows=1, height=200)
                    with gr.Column(visible=False) as clear_button:
                        remove_and_reupload = gr.ClearButton(
                            value="Remove and upload new ones", components=files, size="sm")
                general_prompt = gr.Textbox(
                    value='', label="(1) Textual Description for Character", interactive=True)
                negative_prompt = gr.Textbox(
                    value='', label="(2) Negative_prompt", interactive=True)
                style = gr.Dropdown(
                    label="Style template", choices=STYLE_NAMES, value=DEFAULT_STYLE_NAME)
                prompt_array = gr.Textbox(
                    lines=3, value='',
                    label="(3) Comic Description (each line corresponds to a frame).",
                    interactive=True)
                with gr.Accordion("(4) Tune the hyperparameters", open=False):
                    sa32_ = gr.Slider(label="(The degree of Paired Attention at 32 x 32 self-attention layers)",
                                      minimum=0, maximum=1., value=0.7, step=0.1)
                    sa64_ = gr.Slider(label="(The degree of Paired Attention at 64 x 64 self-attention layers)",
                                      minimum=0, maximum=1., value=0.7, step=0.1)
                    id_length_ = gr.Slider(
                        label="Number of id images in total images", minimum=2, maximum=4, value=3, step=1)
                    seed_ = gr.Slider(label="Seed", minimum=-1, maximum=MAX_SEED, value=0, step=1)
                    num_steps = gr.Slider(
                        label="Number of sample steps", minimum=25, maximum=50, step=1, value=50)
                    G_height = gr.Slider(
                        label="height", minimum=256, maximum=1024, step=32, value=1024)
                    G_width = gr.Slider(
                        label="width", minimum=256, maximum=1024, step=32, value=1024)
                    comic_type = gr.Radio(["No typesetting (default)", "Four Pannel", "Classic Comic Style"],
                                          value="Classic Comic Style", label="Typesetting Style",
                                          info="Select the typesetting style ")
                    guidance_scale = gr.Slider(
                        label="Guidance scale", minimum=0.1, maximum=10.0, step=0.1, value=5)
                    style_strength_ratio = gr.Slider(
                        label="Style strength of Ref Image (%)",
                        minimum=15, maximum=50, step=1, value=20, visible=False)
                    Ip_Adapter_Strength = gr.Slider(
                        label="Ip_Adapter_Strength",
                        minimum=0, maximum=1, step=0.1, value=0.5, visible=False)
                # Outside the (collapsed) accordion so the button is always visible.
                final_run_btn = gr.Button("Generate ! 😺")

        with gr.Column():
            out_image = gr.Gallery(label="Result", columns=2, height='auto')
            generated_information = gr.Markdown(
                label="Generation Details", value="", visible=False)

    model_type.change(fn=change_visiale_by_model_type, inputs=model_type,
                      outputs=[control_image_input, style_strength_ratio, Ip_Adapter_Strength])
    files.upload(fn=swap_to_gallery, inputs=files,
                 outputs=[uploaded_files, clear_button, files])
    remove_and_reupload.click(fn=remove_back_to_files,
                              outputs=[uploaded_files, clear_button, files])

    final_run_btn.click(fn=set_text_unfinished, outputs=generated_information
                        ).then(process_generation,
                               inputs=[sd_type, model_type, files, num_steps, style,
                                       Ip_Adapter_Strength, style_strength_ratio, guidance_scale,
                                       seed_, sa32_, sa64_, id_length_, general_prompt,
                                       negative_prompt, prompt_array, G_height, G_width, comic_type],
                               outputs=out_image
                               ).then(fn=set_text_finished, outputs=generated_information)

    gr.Examples(
        examples=get_examples(),
        inputs=[seed_, sa32_, sa64_, id_length_, general_prompt, negative_prompt,
                prompt_array, style, model_type, files, G_height, G_width],
        label='😺 Examples 😺',
    )