from transformers import PretrainedConfig
from PIL import Image
import torch
import numpy as np
import PIL
import os
from tqdm.auto import tqdm
from diffusers.models.attention_processor import (
    AttnProcessor2_0,
    LoRAAttnProcessor2_0,
    LoRAXFormersAttnProcessor,
    XFormersAttnProcessor,
)

device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

def myroll2d(a, delta_x, delta_y):
    # Shift a 2D mask by (delta_x, delta_y) without wrap-around;
    # cells exposed by the shift are filled with zeros.
    h, w = a.shape[0], a.shape[1]
    delta_x = -delta_x
    delta_y = -delta_y
    if isinstance(a, np.ndarray):
        b = np.zeros([h, w]).astype(np.uint8)
    elif isinstance(a, torch.Tensor):
        b = torch.zeros([h, w]).to(torch.uint8)
    else:
        raise TypeError("myroll2d expects a numpy array or a torch tensor")
    if delta_x > 0:
        left_a = delta_x
        right_a = w
        left_b = 0
        right_b = w - delta_x
    else:
        left_a = 0
        right_a = w + delta_x
        left_b = -delta_x
        right_b = w
    if delta_y > 0:
        top_a = delta_y
        bot_a = h
        top_b = 0
        bot_b = h - delta_y
    else:
        top_a = 0
        bot_a = h + delta_y
        top_b = -delta_y
        bot_b = h
    # Rows are indexed by the y (height) bounds, columns by the x (width) bounds.
    b[top_b:bot_b, left_b:right_b] = a[top_a:bot_a, left_a:right_a]
    return b
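
# Usage sketch (not part of the original module; the shape and offsets below are
# purely illustrative): myroll2d shifts a binary mask without wrap-around, so the
# cells exposed by the shift stay zero.
def _demo_myroll2d():
    demo = np.zeros([8, 8], dtype=np.uint8)
    demo[2:4, 2:4] = 1            # small square of ones
    return myroll2d(demo, 2, 1)   # shifted copy; vacated border cells remain 0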

def import_model_class_from_model_name_or_path(
    pretrained_model_name_or_path: str, revision=None, subfolder: str = "text_encoder"
):
    text_encoder_config = PretrainedConfig.from_pretrained(
        pretrained_model_name_or_path, subfolder=subfolder, revision=revision
    )
    model_class = text_encoder_config.architectures[0]
    if model_class == "CLIPTextModel":
        from transformers import CLIPTextModel
        return CLIPTextModel
    elif model_class == "CLIPTextModelWithProjection":
        from transformers import CLIPTextModelWithProjection
        return CLIPTextModelWithProjection
    else:
        raise ValueError(f"{model_class} is not supported.")

def image2latent(image, vae=None, dtype=None):
    # Encode a PIL image / HxWx3 array into VAE latents; a 4D tensor is assumed
    # to already be a latent and is returned unchanged.
    with torch.no_grad():
        if isinstance(image, Image.Image):
            image = np.array(image)
        if type(image) is torch.Tensor and image.dim() == 4:
            latents = image
        else:
            image = torch.from_numpy(image).float() / 127.5 - 1
            image = image.permute(2, 0, 1).unsqueeze(0).to(device, dtype=dtype)
            latents = vae.encode(image).latent_dist.sample()
            latents = latents * vae.config.scaling_factor
    return latents

def latent2image(latents, return_type='np', vae=None):
    # needs_upcasting = vae.dtype == torch.float16 and vae.config.force_upcast
    needs_upcasting = True
    if needs_upcasting:
        upcast_vae(vae)
        latents = latents.to(next(iter(vae.post_quant_conv.parameters())).dtype)
    image = vae.decode(latents / vae.config.scaling_factor, return_dict=False)[0]
    if return_type == 'np':
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).numpy()  # [0]
        image = (image * 255).astype(np.uint8)
    if needs_upcasting:
        vae.to(dtype=torch.float16)
    return image
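
# Usage sketch (not part of the original module): a typical encode/decode round
# trip with a diffusers AutoencoderKL. The `vae` argument is assumed to be an
# already-loaded model on `device`; the helper name is hypothetical.
def _demo_vae_roundtrip(pil_image, vae):
    latents = image2latent(pil_image, vae=vae, dtype=vae.dtype)   # [1, 4, H/8, W/8]
    recon = latent2image(latents, return_type='np', vae=vae)      # uint8, [1, H, W, 3]
    return recon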

def upcast_vae(vae):
    dtype = vae.dtype
    vae.to(dtype=torch.float32)
    use_torch_2_0_or_xformers = isinstance(
        vae.decoder.mid_block.attentions[0].processor,
        (
            AttnProcessor2_0,
            XFormersAttnProcessor,
            LoRAXFormersAttnProcessor,
            LoRAAttnProcessor2_0,
        ),
    )
    # if xformers or torch_2_0 is used, the attention block does not need
    # to be in float32, which can save lots of memory
    if use_torch_2_0_or_xformers:
        vae.post_quant_conv.to(dtype)
        vae.decoder.conv_in.to(dtype)
        vae.decoder.mid_block.to(dtype)

def prompt_to_emb_length_sdxl(prompt, tokenizer, text_encoder, length=None):
    text_input = tokenizer(
        [prompt],
        padding="max_length",
        max_length=length,
        truncation=True,
        return_tensors="pt",
    )
    prompt_embeds = text_encoder(text_input.input_ids.to(device), output_hidden_states=True)
    pooled_prompt_embeds = prompt_embeds[0]
    prompt_embeds = prompt_embeds.hidden_states[-2]
    bs_embed, seq_len, _ = prompt_embeds.shape
    prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
    pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
    return {"prompt_embeds": prompt_embeds, "pooled_prompt_embeds": pooled_prompt_embeds}

def prompt_to_emb_length_sd(prompt, tokenizer, text_encoder, length=None):
    text_input = tokenizer(
        [prompt],
        padding="max_length",
        max_length=length,
        truncation=True,
        return_tensors="pt",
    )
    emb = text_encoder(text_input.input_ids.to(device))[0]
    return emb

def sdxl_prepare_input_decom(
    set_string_list,
    tokenizer,
    tokenizer_2,
    text_encoder_1,
    text_encoder_2,
    length=20,
    bsz=1,
    weight_dtype=torch.float32,
    resolution=1024,
    normal_token_id_list=[]
):
    encoder_hidden_states_list = []
    pooled_prompt_embeds = 0
    for m_idx in range(len(set_string_list)):
        prompt_embeds_list = []
        if ("#" in set_string_list[m_idx] or "$" in set_string_list[m_idx]) and m_idx not in normal_token_id_list:  ###
            out = prompt_to_emb_length_sdxl(
                set_string_list[m_idx], tokenizer, text_encoder_1, length=length
            )
        else:
            out = prompt_to_emb_length_sdxl(
                set_string_list[m_idx], tokenizer, text_encoder_1, length=77
            )
            print(m_idx, set_string_list[m_idx])
        prompt_embeds, _ = out["prompt_embeds"].to(dtype=weight_dtype), out["pooled_prompt_embeds"].to(dtype=weight_dtype)
        prompt_embeds = prompt_embeds.repeat(bsz, 1, 1)
        prompt_embeds_list.append(prompt_embeds)
        if ("#" in set_string_list[m_idx] or "$" in set_string_list[m_idx]) and m_idx not in normal_token_id_list:
            out = prompt_to_emb_length_sdxl(
                set_string_list[m_idx], tokenizer_2, text_encoder_2, length=length
            )
        else:
            out = prompt_to_emb_length_sdxl(
                set_string_list[m_idx], tokenizer_2, text_encoder_2, length=77
            )
            print(m_idx, set_string_list[m_idx])
        prompt_embeds = out["prompt_embeds"].to(dtype=weight_dtype)
        pooled_prompt_embeds += out["pooled_prompt_embeds"].to(dtype=weight_dtype)
        prompt_embeds = prompt_embeds.repeat(bsz, 1, 1)
        prompt_embeds_list.append(prompt_embeds)
        encoder_hidden_states_list.append(torch.concat(prompt_embeds_list, dim=-1))
    add_text_embeds = pooled_prompt_embeds / len(set_string_list)
    target_size, original_size, crops_coords_top_left = (resolution, resolution), (resolution, resolution), (0, 0)
    add_time_ids = list(original_size + crops_coords_top_left + target_size)
    add_time_ids = torch.tensor([add_time_ids], dtype=prompt_embeds.dtype, device=pooled_prompt_embeds.device)  # [B, 6]
    return encoder_hidden_states_list, add_text_embeds, add_time_ids
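
# Usage sketch (not part of the original module; the prompts and token budget are
# illustrative). Segments containing "#" or "$" (and not listed in
# normal_token_id_list) are encoded with the short `length`, all others with the
# full 77-token budget, matching the branches above.
def _demo_sdxl_prepare(tokenizer, tokenizer_2, text_encoder_1, text_encoder_2):
    set_string_list = ["photo of a # dog", "a grassy field"]
    return sdxl_prepare_input_decom(
        set_string_list, tokenizer, tokenizer_2, text_encoder_1, text_encoder_2,
        length=20, bsz=1, weight_dtype=torch.float16, resolution=1024,
    )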

def sd_prepare_input_decom(
    set_string_list,
    tokenizer,
    text_encoder_1,
    length=20,
    bsz=1,
    weight_dtype=torch.float32,
    normal_token_id_list=[]
):
    encoder_hidden_states_list = []
    for m_idx in range(len(set_string_list)):
        if ("#" in set_string_list[m_idx] or "$" in set_string_list[m_idx]) and m_idx not in normal_token_id_list:  ###
            encoder_hidden_states = prompt_to_emb_length_sd(
                set_string_list[m_idx], tokenizer, text_encoder_1, length=length
            )
        else:
            encoder_hidden_states = prompt_to_emb_length_sd(
                set_string_list[m_idx], tokenizer, text_encoder_1, length=77
            )
            print(m_idx, set_string_list[m_idx])
        encoder_hidden_states = encoder_hidden_states.repeat(bsz, 1, 1)
        encoder_hidden_states_list.append(encoder_hidden_states.to(dtype=weight_dtype))
    return encoder_hidden_states_list

def load_mask(input_folder):
    np_mask_dtype = 'uint8'
    mask_np_list = []
    mask_label_list = []
    files = [
        file_name for file_name in os.listdir(input_folder)
        if "mask" in file_name and ".npy" in file_name
        and "_" in file_name and "Edited" not in file_name
    ]
    files = sorted(files, key=lambda x: int(x.split("_")[0][4:]))
    for idx, file_name in enumerate(files):
        mask_np = np.load(os.path.join(input_folder, file_name)).astype(np_mask_dtype)
        mask_np_list.append(mask_np)
        mask_label = file_name.split("_")[1][:-4]
        mask_label_list.append(mask_label)
    mask_list = []
    for mask_np in mask_np_list:
        mask = torch.from_numpy(mask_np)
        mask_list.append(mask)
    try:
        # The masks should partition the image: every pixel belongs to exactly one mask.
        assert torch.all(sum(mask_list) == 1)
    except AssertionError:
        print("please check mask")
    # plt.imsave("out_mask.png", mask_list_edit[0])
    return mask_list, mask_label_list
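
# Usage sketch (not part of the original module; the folder name is illustrative).
# load_mask expects files named like "mask0_background.npy", "mask1_dog.npy": the
# integer after "mask" fixes the ordering, the text after "_" becomes the label,
# and the binary masks are expected to partition the image (sum to exactly 1).
def _demo_load_mask(input_folder="./example_input"):
    mask_list, mask_label_list = load_mask(input_folder)
    return dict(zip(mask_label_list, [m.shape for m in mask_list]))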

def load_image(image_path, left=0, right=0, top=0, bottom=0, size=512):
    if type(image_path) is str:
        image = np.array(Image.open(image_path))[:, :, :3]
    else:
        image = image_path
    h, w, c = image.shape
    # Clamp the crop margins so they stay inside the image.
    left = min(left, w - 1)
    right = min(right, w - left - 1)
    top = min(top, h - 1)
    bottom = min(bottom, h - top - 1)
    image = image[top:h - bottom, left:w - right]
    h, w, c = image.shape
    # Center-crop to a square, then resize to (size, size).
    if h < w:
        offset = (w - h) // 2
        image = image[:, offset:offset + h]
    elif w < h:
        offset = (h - w) // 2
        image = image[offset:offset + w]
    image = np.array(Image.fromarray(image).resize((size, size)))
    return image
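
# Usage sketch (not part of the original module; the path is illustrative).
# load_image crops the requested margins, center-crops to a square, and resizes
# to (size, size), returning an HxWx3 uint8 array.
def _demo_load_image(path="./example_input/img.png"):
    return load_image(path, size=512).shape   # expected: (512, 512, 3)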

def mask_union_torch(*masks):
    # Pixel-wise union of binary masks.
    masks = [m.to(torch.float) for m in masks]
    res = sum(masks) > 0
    return res

def load_mask_edit(input_folder):
    np_mask_dtype = 'uint8'
    mask_np_list = []
    mask_label_list = []
    files = [
        file_name for file_name in os.listdir(input_folder)
        if "mask" in file_name and ".npy" in file_name and "_" in file_name
        and "Edited" in file_name and "-1" not in file_name
    ]
    files = sorted(files, key=lambda x: int(x.split("_")[0][10:]))
    for idx, file_name in enumerate(files):
        mask_np = np.load(os.path.join(input_folder, file_name)).astype(np_mask_dtype)
        mask_np_list.append(mask_np)
        mask_label = file_name.split("_")[1][:-4]
        # mask_label = mask_label.split("-")[0]
        mask_label_list.append(mask_label)
    mask_list = []
    for mask_np in mask_np_list:
        mask = torch.from_numpy(mask_np)
        mask_list.append(mask)
    try:
        # The edited masks should still partition the image.
        assert torch.all(sum(mask_list) == 1)
    except AssertionError:
        print("Make sure maskEdited is in the folder; if not, generate it using the UI")
        import pdb; pdb.set_trace()
    return mask_list, mask_label_list

def save_images(images, filename, num_rows=1, offset_ratio=0.02):
    # Save each image in `images` as "<filename stem>_<i>.<ext>" next to `filename`.
    if type(images) is list:
        num_empty = len(images) % num_rows
    elif images.ndim == 4:
        num_empty = images.shape[0] % num_rows
    else:
        images = [images]
        num_empty = 0
    empty_images = np.ones(images[0].shape, dtype=np.uint8) * 255
    images = [image.astype(np.uint8) for image in images] + [empty_images] * num_empty
    num_items = len(images)
    folder = os.path.dirname(filename)
    base, ext = os.path.splitext(os.path.basename(filename))
    for i, image in enumerate(images):
        pil_img = Image.fromarray(image)
        name = "{}_{}{}".format(base, i, ext)
        pil_img.save(os.path.join(folder, name))
        print("saved to ", os.path.join(folder, name))