import gradio as gr
from examples.story_examples import get_examples
import spaces
import numpy as np
import torch
import random
import os
import torch.nn.functional as F
from diffusers import StableDiffusionXLPipeline, DDIMScheduler
import copy
from huggingface_hub import hf_hub_download
from diffusers.utils import load_image
from storyDiffusion.utils.gradio_utils import AttnProcessor2_0 as AttnProcessor, cal_attn_mask_xl
from storyDiffusion.utils import PhotoMakerStableDiffusionXLPipeline
from storyDiffusion.utils.utils import get_comic
from storyDiffusion.utils.style_template import styles
# Constants
image_encoder_path = "./data/models/ip_adapter/sdxl_models/image_encoder"
ip_ckpt = "./data/models/ip_adapter/sdxl_models/ip-adapter_sdxl_vit-h.bin"
os.environ["no_proxy"] = "localhost,127.0.0.1,::1"
STYLE_NAMES = list(styles.keys())
DEFAULT_STYLE_NAME = "Japanese Anime"
MAX_SEED = np.iinfo(np.int32).max
# Global variables
global models_dict, use_va, photomaker_path, pipe2, pipe4, attn_count, total_count, id_length, total_length, cur_step, cur_model_type, write, sa32, sa64, height, width, attn_procs, unet, num_steps
models_dict = {
"RealVision": "SG161222/RealVisXL_V4.0",
"Unstable": "stablediffusionapi/sdxl-unstable-diffusers-y"
}
use_va = True
photomaker_path = hf_hub_download(
repo_id="TencentARC/PhotoMaker", filename="photomaker-v1.bin", repo_type="model")
device = "cuda"
# Functions
def setup_seed(seed):
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)
torch.backends.cudnn.deterministic = True
def set_text_unfinished():
return gr.update(visible=True, value="
(Not Finished) Generating ··· The intermediate results will be shown.
")
def set_text_finished():
return gr.update(visible=True, value="Generation Finished
")
class SpatialAttnProcessor2_0(torch.nn.Module):
r"""
Attention processor for IP-Adapater for PyTorch 2.0.
Args:
hidden_size (`int`):
The hidden size of the attention layer.
cross_attention_dim (`int`):
The number of channels in the `encoder_hidden_states`.
text_context_len (`int`, defaults to 77):
The context length of the text features.
scale (`float`, defaults to 1.0):
the weight scale of image prompt.
"""
def __init__(self, hidden_size=None, cross_attention_dim=None, id_length=4, device="cuda", dtype=torch.float16):
super().__init__()
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
self.device = device
self.dtype = dtype
self.hidden_size = hidden_size
self.cross_attention_dim = cross_attention_dim
self.total_length = id_length + 1
self.id_length = id_length
self.id_bank = {}
def __call__(
self,
attn,
hidden_states,
encoder_hidden_states=None,
attention_mask=None,
temb=None):
# un_cond_hidden_states, cond_hidden_states = hidden_states.chunk(2)
# un_cond_hidden_states = self.__call2__(attn, un_cond_hidden_states,encoder_hidden_states,attention_mask,temb)
# 生成一个0到1之间的随机数
global total_count, attn_count, cur_step, mask1024, mask4096
global sa32, sa64
global write
global height, width
global num_steps
if write:
# print(f"white:{cur_step}")
self.id_bank[cur_step] = [
hidden_states[:self.id_length], hidden_states[self.id_length:]]
else:
encoder_hidden_states = torch.cat((self.id_bank[cur_step][0].to(
self.device), hidden_states[:1], self.id_bank[cur_step][1].to(self.device), hidden_states[1:]))
# 判断随机数是否大于0.5
if cur_step <= 1:
hidden_states = self.__call2__(
attn, hidden_states, None, attention_mask, temb)
else: # 256 1024 4096
random_number = random.random()
if cur_step < 0.4 * num_steps:
rand_num = 0.3
else:
rand_num = 0.1
# print(f"hidden state shape {hidden_states.shape[1]}")
if random_number > rand_num:
# print("mask shape",mask1024.shape,mask4096.shape)
if not write:
if hidden_states.shape[1] == (height//32) * (width//32):
attention_mask = mask1024[mask1024.shape[0] //
self.total_length * self.id_length:]
else:
attention_mask = mask4096[mask4096.shape[0] //
self.total_length * self.id_length:]
else:
# print(self.total_length,self.id_length,hidden_states.shape,(height//32) * (width//32))
if hidden_states.shape[1] == (height//32) * (width//32):
attention_mask = mask1024[:mask1024.shape[0] // self.total_length *
self.id_length, :mask1024.shape[0] // self.total_length * self.id_length]
else:
attention_mask = mask4096[:mask4096.shape[0] // self.total_length *
self.id_length, :mask4096.shape[0] // self.total_length * self.id_length]
# print(attention_mask.shape)
# print("before attention",hidden_states.shape,attention_mask.shape,encoder_hidden_states.shape if encoder_hidden_states is not None else "None")
hidden_states = self.__call1__(
attn, hidden_states, encoder_hidden_states, attention_mask, temb)
else:
hidden_states = self.__call2__(
attn, hidden_states, None, attention_mask, temb)
attn_count += 1
if attn_count == total_count:
attn_count = 0
cur_step += 1
mask1024, mask4096 = cal_attn_mask_xl(
self.total_length, self.id_length, sa32, sa64, height, width, device=self.device, dtype=self.dtype)
return hidden_states
def __call1__(
self,
attn,
hidden_states,
encoder_hidden_states=None,
attention_mask=None,
temb=None,
):
# print("hidden state shape",hidden_states.shape,self.id_length)
residual = hidden_states
# if encoder_hidden_states is not None:
# raise Exception("not implement")
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
total_batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(
total_batch_size, channel, height * width).transpose(1, 2)
total_batch_size, nums_token, channel = hidden_states.shape
img_nums = total_batch_size//2
hidden_states = hidden_states.view(-1, img_nums, nums_token,
channel).reshape(-1, img_nums * nums_token, channel)
batch_size, sequence_length, _ = hidden_states.shape
if attn.group_norm is not None:
hidden_states = attn.group_norm(
hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states # B, N, C
else:
encoder_hidden_states = encoder_hidden_states.view(
-1, self.id_length+1, nums_token, channel).reshape(-1, (self.id_length+1) * nums_token, channel)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads,
head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads,
head_dim).transpose(1, 2)
# print(key.shape,value.shape,query.shape,attention_mask.shape)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
# print(query.shape,key.shape,value.shape,attention_mask.shape)
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(
total_batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
# if input_ndim == 4:
# tile_hidden_states = tile_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
# if attn.residual_connection:
# tile_hidden_states = tile_hidden_states + residual
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(
total_batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
# print(hidden_states.shape)
return hidden_states
def __call2__(
self,
attn,
hidden_states,
encoder_hidden_states=None,
attention_mask=None,
temb=None):
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(
batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, channel = (
hidden_states.shape
)
# print(hidden_states.shape)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(
attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(
batch_size, attn.heads, -1, attention_mask.shape[-1])
if attn.group_norm is not None:
hidden_states = attn.group_norm(
hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states # B, N, C
else:
encoder_hidden_states = encoder_hidden_states.view(
-1, self.id_length+1, sequence_length, channel).reshape(-1, (self.id_length+1) * sequence_length, channel)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads,
head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads,
head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(
batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(
-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
def set_attention_processor(unet, id_length, is_ipadapter=False):
global total_count
total_count = 0
attn_procs = {}
for name in unet.attn_processors.keys():
cross_attention_dim = None if name.endswith(
"attn1.processor") else unet.config.cross_attention_dim
if cross_attention_dim is None:
if name.startswith("up_blocks"):
attn_procs[name] = SpatialAttnProcessor2_0(id_length=id_length)
total_count += 1
else:
attn_procs[name] = AttnProcessor()
else:
attn_procs[name] = AttnProcessor()
unet.set_attn_processor(copy.deepcopy(attn_procs))
print("Successfully loaded paired self-attention")
print(f"Number of processors: {total_count}")
attn_count = 0
total_count = 0
cur_step = 0
id_length = 4
total_length = 5
cur_model_type = ""
device = "cuda"
attn_procs = {}
write = False
sa32 = 0.5
sa64 = 0.5
height = 768
width = 768
def swap_to_gallery(images):
return gr.update(value=images, visible=True), gr.update(visible=True), gr.update(visible=False)
def upload_example_to_gallery(images, prompt, style, negative_prompt):
return gr.update(value=images, visible=True), gr.update(visible=True), gr.update(visible=False)
def remove_back_to_files():
return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)
def remove_tips():
return gr.update(visible=False)
def apply_style_positive(style_name: str, positive: str):
p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
return p.replace("{prompt}", positive)
def apply_style(style_name: str, positives: list, negative: str = ""):
p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
return [p.replace("{prompt}", positive) for positive in positives], n + ' ' + negative
def change_visiale_by_model_type(_model_type):
if _model_type == "Only Using Textual Description":
return gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
elif _model_type == "Using Ref Images":
return gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)
else:
raise ValueError("Invalid model type", _model_type)
@spaces.GPU(duration=120)
def process_generation(_sd_type, _model_type, _upload_images, _num_steps, style_name, _Ip_Adapter_Strength, _style_strength_ratio, guidance_scale, seed_, sa32_, sa64_, id_length_, general_prompt, negative_prompt, prompt_array, G_height, G_width, _comic_type):
global sa32, sa64, id_length, total_length, attn_procs, unet, cur_model_type, device, num_steps, write, cur_step, attn_count, height, width, pipe2, pipe4, sd_model_path, models_dict
_model_type = "Photomaker" if _model_type == "Using Ref Images" else "original"
if _model_type == "Photomaker" and "img" not in general_prompt:
raise gr.Error(
"Please add the trigger word 'img' behind the class word you want to customize, such as: man img or woman img")
if _upload_images is None and _model_type != "original":
raise gr.Error("Cannot find any input face image!")
if len(prompt_array.splitlines()) > 10:
raise gr.Error(
f"No more than 10 prompts in Hugging Face demo for speed! But found {len(prompt_array.splitlines())} prompts!")
height = G_height
width = G_width
sd_model_path = models_dict[_sd_type]
num_steps = _num_steps
if style_name == "(No style)":
sd_model_path = models_dict["RealVision"]
if _model_type == "original":
pipe = StableDiffusionXLPipeline.from_pretrained(
sd_model_path, torch_dtype=torch.float16)
pipe = pipe.to(device)
pipe.enable_freeu(s1=0.6, s2=0.4, b1=1.1, b2=1.2)
set_attention_processor(pipe.unet, id_length_, is_ipadapter=False)
elif _model_type == "Photomaker":
if _sd_type != "RealVision" and style_name != "(No style)":
pipe = pipe2.to(device)
pipe.id_encoder.to(device)
set_attention_processor(pipe.unet, id_length_, is_ipadapter=False)
else:
pipe = pipe4.to(device)
pipe.id_encoder.to(device)
set_attention_processor(pipe.unet, id_length_, is_ipadapter=False)
else:
raise NotImplementedError(
"You should choose between original and Photomaker!", f"But you chose {_model_type}")
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
pipe.enable_freeu(s1=0.6, s2=0.4, b1=1.1, b2=1.2)
cur_model_type = _sd_type + "-" + _model_type + str(id_length_)
if _model_type != "original":
input_id_images = [load_image(img) for img in _upload_images]
prompts = prompt_array.splitlines()
start_merge_step = int(float(_style_strength_ratio) / 100 * _num_steps)
if start_merge_step > 30:
start_merge_step = 30
print(f"start_merge_step: {start_merge_step}")
generator = torch.Generator(device="cuda").manual_seed(seed_)
sa32, sa64 = sa32_, sa64_
id_length = id_length_
clipped_prompts = prompts[:]
prompts = [general_prompt + "," + prompt if "[NC]" not in prompt else prompt.replace(
"[NC]", "") for prompt in clipped_prompts]
prompts = [prompt.rpartition(
'#')[0] if "#" in prompt else prompt for prompt in prompts]
print(prompts)
id_prompts = prompts[:id_length]
real_prompts = prompts[id_length:]
torch.cuda.empty_cache()
write = True
cur_step = 0
attn_count = 0
id_prompts, negative_prompt = apply_style(
style_name, id_prompts, negative_prompt)
setup_seed(seed_)
total_results = []
if _model_type == "original":
id_images = pipe(id_prompts, num_inference_steps=_num_steps, guidance_scale=guidance_scale,
height=height, width=width, negative_prompt=negative_prompt, generator=generator).images
elif _model_type == "Photomaker":
id_images = pipe(id_prompts, input_id_images=input_id_images, num_inference_steps=_num_steps, guidance_scale=guidance_scale,
start_merge_step=start_merge_step, height=height, width=width, negative_prompt=negative_prompt, generator=generator).images
else:
raise NotImplementedError(
"You should choose between original and Photomaker!", f"But you chose {_model_type}")
total_results = id_images + total_results
yield total_results
real_images = []
write = False
for real_prompt in real_prompts:
setup_seed(seed_)
cur_step = 0
real_prompt = apply_style_positive(style_name, real_prompt)
if _model_type == "original":
real_images.append(pipe(real_prompt, num_inference_steps=_num_steps, guidance_scale=guidance_scale,
height=height, width=width, negative_prompt=negative_prompt, generator=generator).images[0])
elif _model_type == "Photomaker":
real_images.append(pipe(real_prompt, input_id_images=input_id_images, num_inference_steps=_num_steps, guidance_scale=guidance_scale,
start_merge_step=start_merge_step, height=height, width=width, negative_prompt=negative_prompt, generator=generator).images[0])
else:
raise NotImplementedError(
"You should choose between original and Photomaker!", f"But you chose {_model_type}")
total_results = [real_images[-1]] + total_results
yield total_results
if _comic_type != "No typesetting (default)":
from PIL import ImageFont
captions = prompt_array.splitlines()
captions = [caption.replace("[NC]", "") for caption in captions]
captions = [caption.split(
'#')[-1] if "#" in caption else caption for caption in captions]
total_results = get_comic(id_images + real_images, _comic_type, captions=captions,
font=ImageFont.truetype("./storyDiffusion/fonts/Inkfree.ttf", int(45))) + total_results
if _model_type == "Photomaker":
pipe = pipe2.to("cpu")
pipe.id_encoder.to("cpu")
set_attention_processor(pipe.unet, id_length_, is_ipadapter=False)
yield total_results
# Initialize pipelines
pipe2 = PhotoMakerStableDiffusionXLPipeline.from_pretrained(
models_dict["Unstable"], torch_dtype=torch.float16, use_safetensors=False)
pipe2 = pipe2.to("cpu")
pipe2.load_photomaker_adapter(
os.path.dirname(photomaker_path),
subfolder="",
weight_name=os.path.basename(photomaker_path),
trigger_word="img"
)
pipe2 = pipe2.to("cpu")
pipe2.enable_freeu(s1=0.6, s2=0.4, b1=1.1, b2=1.2)
pipe2.fuse_lora()
pipe4 = PhotoMakerStableDiffusionXLPipeline.from_pretrained(
models_dict["RealVision"], torch_dtype=torch.float16, use_safetensors=True)
pipe4 = pipe4.to("cpu")
pipe4.load_photomaker_adapter(
os.path.dirname(photomaker_path),
subfolder="",
weight_name=os.path.basename(photomaker_path),
trigger_word="img"
)
pipe4 = pipe4.to("cpu")
pipe4.enable_freeu(s1=0.6, s2=0.4, b1=1.1, b2=1.2)
pipe4.fuse_lora()
def story_generation_ui():
with gr.Row():
with gr.Group(elem_id="main-image"):
prompts = []
colors = []
with gr.Column(visible=True) as gen_prompt_vis:
sd_type = gr.Dropdown(choices=list(models_dict.keys(
)), value="Unstable", label="sd_type", info="Select pretrained model")
model_type = gr.Radio(["Only Using Textual Description", "Using Ref Images"], label="model_type",
value="Only Using Textual Description", info="Control type of the Character")
with gr.Group(visible=False) as control_image_input:
files = gr.Files(
label="Drag (Select) 1 or more photos of your face",
file_types=["image"],
)
uploaded_files = gr.Gallery(
label="Your images", visible=False, columns=5, rows=1, height=200)
with gr.Column(visible=False) as clear_button:
remove_and_reupload = gr.ClearButton(
value="Remove and upload new ones", components=files, size="sm")
general_prompt = gr.Textbox(
value='', label="(1) Textual Description for Character", interactive=True)
negative_prompt = gr.Textbox(
value='', label="(2) Negative_prompt", interactive=True)
style = gr.Dropdown(
label="Style template", choices=STYLE_NAMES, value=DEFAULT_STYLE_NAME)
prompt_array = gr.Textbox(
lines=3, value='', label="(3) Comic Description (each line corresponds to a frame).", interactive=True)
with gr.Accordion("(4) Tune the hyperparameters", open=False):
sa32_ = gr.Slider(label="(The degree of Paired Attention at 32 x 32 self-attention layers)",
minimum=0, maximum=1., value=0.7, step=0.1)
sa64_ = gr.Slider(label="(The degree of Paired Attention at 64 x 64 self-attention layers)",
minimum=0, maximum=1., value=0.7, step=0.1)
id_length_ = gr.Slider(
label="Number of id images in total images", minimum=2, maximum=4, value=3, step=1)
seed_ = gr.Slider(label="Seed", minimum=-1,
maximum=MAX_SEED, value=0, step=1)
num_steps = gr.Slider(
label="Number of sample steps",
minimum=25,
maximum=50,
step=1,
value=50,
)
G_height = gr.Slider(
label="height",
minimum=256,
maximum=1024,
step=32,
value=1024,
)
G_width = gr.Slider(
label="width",
minimum=256,
maximum=1024,
step=32,
value=1024,
)
comic_type = gr.Radio(["No typesetting (default)", "Four Pannel", "Classic Comic Style"],
value="Classic Comic Style", label="Typesetting Style", info="Select the typesetting style ")
guidance_scale = gr.Slider(
label="Guidance scale",
minimum=0.1,
maximum=10.0,
step=0.1,
value=5,
)
style_strength_ratio = gr.Slider(
label="Style strength of Ref Image (%)",
minimum=15,
maximum=50,
step=1,
value=20,
visible=False
)
Ip_Adapter_Strength = gr.Slider(
label="Ip_Adapter_Strength",
minimum=0,
maximum=1,
step=0.1,
value=0.5,
visible=False
)
final_run_btn = gr.Button("Generate ! 😺")
with gr.Column():
out_image = gr.Gallery(label="Result", columns=2, height='auto')
generated_information = gr.Markdown(
label="Generation Details", value="", visible=False)
model_type.change(fn=change_visiale_by_model_type, inputs=model_type, outputs=[
control_image_input, style_strength_ratio, Ip_Adapter_Strength])
files.upload(fn=swap_to_gallery, inputs=files, outputs=[
uploaded_files, clear_button, files])
remove_and_reupload.click(fn=remove_back_to_files, outputs=[
uploaded_files, clear_button, files])
final_run_btn.click(fn=set_text_unfinished, outputs=generated_information
).then(process_generation, inputs=[sd_type, model_type, files, num_steps, style, Ip_Adapter_Strength, style_strength_ratio, guidance_scale, seed_, sa32_, sa64_, id_length_, general_prompt, negative_prompt, prompt_array, G_height, G_width, comic_type], outputs=out_image
).then(fn=set_text_finished, outputs=generated_information)
gr.Examples(
examples=get_examples(),
inputs=[seed_, sa32_, sa64_, id_length_, general_prompt, negative_prompt,
prompt_array, style, model_type, files, G_height, G_width],
label='😺 Examples 😺',
)