from transformers import PretrainedConfig
from PIL import Image
import torch
import numpy as np
import PIL
import os
from tqdm.auto import tqdm
from diffusers.models.attention_processor import (
    AttnProcessor2_0,
    LoRAAttnProcessor2_0,
    LoRAXFormersAttnProcessor,
    XFormersAttnProcessor,
)

device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')


def myroll2d(a, delta_x, delta_y):
    """Shift a 2D array by (delta_x, delta_y), filling vacated cells with zeros.

    Unlike np.roll / torch.roll, values shifted past the border are dropped.
    The index bookkeeping assumes a square array (h == w).
    """
    h, w = a.shape[0], a.shape[1]
    delta_x = -delta_x
    delta_y = -delta_y
    if isinstance(a, np.ndarray):
        b = np.zeros([h, w]).astype(np.uint8)
    elif isinstance(a, torch.Tensor):
        b = torch.zeros([h, w]).to(torch.uint8)
    if delta_x > 0:
        left_a = delta_x
        right_a = w
        left_b = 0
        right_b = w - delta_x
    else:
        left_a = 0
        right_a = w + delta_x
        left_b = -delta_x
        right_b = w
    if delta_y > 0:
        top_a = delta_y
        bot_a = h
        top_b = 0
        bot_b = h - delta_y
    else:
        top_a = 0
        bot_a = h + delta_y
        top_b = -delta_y
        bot_b = h
    b[left_b:right_b, top_b:bot_b] = a[left_a:right_a, top_a:bot_a]
    return b


def import_model_class_from_model_name_or_path(
    pretrained_model_name_or_path: str, revision=None, subfolder: str = "text_encoder"
):
    """Return the text-encoder class declared in a checkpoint's config."""
    text_encoder_config = PretrainedConfig.from_pretrained(
        pretrained_model_name_or_path, subfolder=subfolder, revision=revision
    )
    model_class = text_encoder_config.architectures[0]

    if model_class == "CLIPTextModel":
        from transformers import CLIPTextModel

        return CLIPTextModel
    elif model_class == "CLIPTextModelWithProjection":
        from transformers import CLIPTextModelWithProjection

        return CLIPTextModelWithProjection
    else:
        raise ValueError(f"{model_class} is not supported.")


@torch.no_grad()
def image2latent(image, vae=None, dtype=None):
    """Encode a PIL image or uint8 HxWx3 array into scaled VAE latents."""
    if isinstance(image, Image.Image):
        image = np.array(image)
    if isinstance(image, torch.Tensor) and image.dim() == 4:
        # Already a latent batch; pass it through unchanged.
        latents = image
    else:
        image = torch.from_numpy(image).float() / 127.5 - 1
        image = image.permute(2, 0, 1).unsqueeze(0).to(device, dtype=dtype)
        latents = vae.encode(image).latent_dist.sample()
        latents = latents * vae.config.scaling_factor
    return latents


@torch.no_grad()
def latent2image(latents, return_type='np', vae=None):
    """Decode VAE latents back to an image (uint8 numpy array or tensor)."""
    # Upcasting is forced here; the original dtype-based check is kept for reference:
    # needs_upcasting = vae.dtype == torch.float16 and vae.config.force_upcast
    needs_upcasting = True
    if needs_upcasting:
        upcast_vae(vae)
        latents = latents.to(next(iter(vae.post_quant_conv.parameters())).dtype)
    image = vae.decode(latents / vae.config.scaling_factor, return_dict=False)[0]
    if return_type == 'np':
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).numpy()
        image = (image * 255).astype(np.uint8)
    if needs_upcasting:
        vae.to(dtype=torch.float16)
    return image


def upcast_vae(vae):
    """Upcast the VAE to float32 for decoding, keeping selected blocks in the
    original dtype when a memory-efficient attention processor is in use."""
    dtype = vae.dtype
    vae.to(dtype=torch.float32)
    use_torch_2_0_or_xformers = isinstance(
        vae.decoder.mid_block.attentions[0].processor,
        (
            AttnProcessor2_0,
            XFormersAttnProcessor,
            LoRAXFormersAttnProcessor,
            LoRAAttnProcessor2_0,
        ),
    )
    # If xformers or torch 2.0 attention is used, the attention block does not
    # need to be in float32, which saves a lot of memory.
    if use_torch_2_0_or_xformers:
        vae.post_quant_conv.to(dtype)
        vae.decoder.conv_in.to(dtype)
        vae.decoder.mid_block.to(dtype)


def prompt_to_emb_length_sdxl(prompt, tokenizer, text_encoder, length=None):
    """Encode a prompt with one SDXL text encoder, padded/truncated to `length` tokens."""
    text_input = tokenizer(
        [prompt],
        padding="max_length",
        max_length=length,
        truncation=True,
        return_tensors="pt",
    )
    prompt_embeds = text_encoder(text_input.input_ids.to(device), output_hidden_states=True)
    pooled_prompt_embeds = prompt_embeds[0]
    # SDXL uses the penultimate hidden state as the prompt embedding.
    prompt_embeds = prompt_embeds.hidden_states[-2]
    bs_embed, seq_len, _ = prompt_embeds.shape
    prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
    pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
    return {"prompt_embeds": prompt_embeds, "pooled_prompt_embeds": pooled_prompt_embeds}
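
# Illustrative sketch only (not used by the pipeline): how the VAE helpers above
# (image2latent / latent2image) are typically combined for an encode-decode
# round trip. The AutoencoderKL checkpoint name below is an assumption; use the
# VAE that matches the diffusion model you actually load.
def _demo_vae_roundtrip(image_path, vae=None):
    """Encode an image to latents and decode it back to a uint8 array."""
    if vae is None:
        from diffusers import AutoencoderKL

        # Hypothetical default checkpoint; swap in your pipeline's VAE.
        vae = AutoencoderKL.from_pretrained(
            "stabilityai/sdxl-vae", torch_dtype=torch.float16
        ).to(device)
    image = load_image(image_path, size=1024)                 # HxWx3 uint8
    latents = image2latent(image, vae=vae, dtype=vae.dtype)   # 1x4xhxw latents
    return latent2image(latents, return_type='np', vae=vae)   # 1xHxWx3 uint8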

def prompt_to_emb_length_sd(prompt, tokenizer, text_encoder, length=None):
    """Encode a prompt with the SD text encoder, padded/truncated to `length` tokens."""
    text_input = tokenizer(
        [prompt],
        padding="max_length",
        max_length=length,
        truncation=True,
        return_tensors="pt",
    )
    emb = text_encoder(text_input.input_ids.to(device))[0]
    return emb


def sdxl_prepare_input_decom(
    set_string_list,
    tokenizer,
    tokenizer_2,
    text_encoder_1,
    text_encoder_2,
    length=20,
    bsz=1,
    weight_dtype=torch.float32,
    resolution=1024,
    normal_token_id_list=[],
):
    """Encode each (possibly decomposed) prompt in `set_string_list` with both SDXL
    text encoders and prepare the added conditioning (pooled text + time ids).

    Prompts containing "#" or "$" placeholder tokens (and not listed in
    `normal_token_id_list`) are padded to `length`; all other prompts use the
    full 77-token context.
    """
    encoder_hidden_states_list = []
    pooled_prompt_embeds = 0
    for m_idx in range(len(set_string_list)):
        prompt_embeds_list = []
        # First text encoder.
        if ("#" in set_string_list[m_idx] or "$" in set_string_list[m_idx]) and m_idx not in normal_token_id_list:
            out = prompt_to_emb_length_sdxl(
                set_string_list[m_idx], tokenizer, text_encoder_1, length=length
            )
        else:
            out = prompt_to_emb_length_sdxl(
                set_string_list[m_idx], tokenizer, text_encoder_1, length=77
            )
            print(m_idx, set_string_list[m_idx])
        prompt_embeds, _ = (
            out["prompt_embeds"].to(dtype=weight_dtype),
            out["pooled_prompt_embeds"].to(dtype=weight_dtype),
        )
        prompt_embeds = prompt_embeds.repeat(bsz, 1, 1)
        prompt_embeds_list.append(prompt_embeds)

        # Second text encoder; its pooled output is accumulated and averaged below.
        if ("#" in set_string_list[m_idx] or "$" in set_string_list[m_idx]) and m_idx not in normal_token_id_list:
            out = prompt_to_emb_length_sdxl(
                set_string_list[m_idx], tokenizer_2, text_encoder_2, length=length
            )
        else:
            out = prompt_to_emb_length_sdxl(
                set_string_list[m_idx], tokenizer_2, text_encoder_2, length=77
            )
            print(m_idx, set_string_list[m_idx])
        prompt_embeds = out["prompt_embeds"].to(dtype=weight_dtype)
        pooled_prompt_embeds += out["pooled_prompt_embeds"].to(dtype=weight_dtype)
        prompt_embeds = prompt_embeds.repeat(bsz, 1, 1)
        prompt_embeds_list.append(prompt_embeds)

        # Concatenate the two encoders' embeddings along the channel dimension.
        encoder_hidden_states_list.append(torch.concat(prompt_embeds_list, dim=-1))

    add_text_embeds = pooled_prompt_embeds / len(set_string_list)
    target_size, original_size, crops_coords_top_left = (resolution, resolution), (resolution, resolution), (0, 0)
    add_time_ids = list(original_size + crops_coords_top_left + target_size)
    add_time_ids = torch.tensor(
        [add_time_ids], dtype=prompt_embeds.dtype, device=pooled_prompt_embeds.device
    )  # [B, 6]
    return encoder_hidden_states_list, add_text_embeds, add_time_ids


def sd_prepare_input_decom(
    set_string_list,
    tokenizer,
    text_encoder_1,
    length=20,
    bsz=1,
    weight_dtype=torch.float32,
    normal_token_id_list=[],
):
    """Encode each (possibly decomposed) prompt in `set_string_list` with the SD text encoder."""
    encoder_hidden_states_list = []
    for m_idx in range(len(set_string_list)):
        if ("#" in set_string_list[m_idx] or "$" in set_string_list[m_idx]) and m_idx not in normal_token_id_list:
            encoder_hidden_states = prompt_to_emb_length_sd(
                set_string_list[m_idx], tokenizer, text_encoder_1, length=length
            )
        else:
            encoder_hidden_states = prompt_to_emb_length_sd(
                set_string_list[m_idx], tokenizer, text_encoder_1, length=77
            )
            print(m_idx, set_string_list[m_idx])
        encoder_hidden_states = encoder_hidden_states.repeat(bsz, 1, 1)
        encoder_hidden_states_list.append(encoder_hidden_states.to(dtype=weight_dtype))
    return encoder_hidden_states_list
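
# Illustrative sketch only: one way sd_prepare_input_decom above could be called.
# The tokenizer / text encoder are assumed to come from a Stable Diffusion
# checkpoint loaded elsewhere, and "#0" is a hypothetical placeholder-style
# prompt (it contains "#", so it is padded to the shorter `length`), while the
# plain prompt uses the full 77-token context.
def _demo_sd_decomposed_prompts(tokenizer, text_encoder):
    """Return one hidden-state tensor per prompt in the decomposed list."""
    set_string_list = ["a photo of a living room", "#0"]
    return sd_prepare_input_decom(
        set_string_list,
        tokenizer,
        text_encoder,
        length=20,
        bsz=1,
        weight_dtype=torch.float16,
    )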

def load_mask(input_folder):
    """Load the "mask{i}_{label}.npy" files (excluding edited ones) from a folder.

    Returns a list of uint8 torch masks plus their labels. The masks are expected
    to partition the image, i.e. sum to exactly 1 everywhere.
    """
    np_mask_dtype = 'uint8'
    mask_np_list = []
    mask_label_list = []
    files = [
        file_name for file_name in os.listdir(input_folder)
        if "mask" in file_name and ".npy" in file_name
        and "_" in file_name and "Edited" not in file_name
    ]
    # Sort by the mask index that follows the "mask" prefix.
    files = sorted(files, key=lambda x: int(x.split("_")[0][4:]))
    for idx, file_name in enumerate(files):
        mask_np = np.load(os.path.join(input_folder, file_name)).astype(np_mask_dtype)
        mask_np_list.append(mask_np)
        mask_label = file_name.split("_")[1][:-4]
        mask_label_list.append(mask_label)

    mask_list = [torch.from_numpy(mask_np) for mask_np in mask_np_list]
    try:
        assert torch.all(sum(mask_list) == 1)
    except AssertionError:
        print("please check mask")
        # plt.imsave("out_mask.png", mask_list_edit[0])
    return mask_list, mask_label_list


def load_image(image_path, left=0, right=0, top=0, bottom=0, size=512):
    """Load an image, crop the given borders, center-crop to a square, and resize."""
    if isinstance(image_path, str):
        image = np.array(Image.open(image_path))[:, :, :3]
    else:
        image = image_path
    h, w, c = image.shape
    left = min(left, w - 1)
    right = min(right, w - left - 1)
    top = min(top, h - 1)
    bottom = min(bottom, h - top - 1)
    image = image[top:h - bottom, left:w - right]
    h, w, c = image.shape
    if h < w:
        offset = (w - h) // 2
        image = image[:, offset:offset + h]
    elif w < h:
        offset = (h - w) // 2
        image = image[offset:offset + w]
    image = np.array(Image.fromarray(image).resize((size, size)))
    return image


def mask_union_torch(*masks):
    """Return the boolean union of an arbitrary number of masks."""
    masks = [m.to(torch.float) for m in masks]
    res = sum(masks) > 0
    return res


def load_mask_edit(input_folder):
    """Load the user-edited "maskEdited{i}_{label}.npy" masks from a folder."""
    np_mask_dtype = 'uint8'
    mask_np_list = []
    mask_label_list = []
    files = [
        file_name for file_name in os.listdir(input_folder)
        if "mask" in file_name and ".npy" in file_name and "_" in file_name
        and "Edited" in file_name and "-1" not in file_name
    ]
    # Sort by the mask index that follows the "maskEdited" prefix.
    files = sorted(files, key=lambda x: int(x.split("_")[0][10:]))
    for idx, file_name in enumerate(files):
        mask_np = np.load(os.path.join(input_folder, file_name)).astype(np_mask_dtype)
        mask_np_list.append(mask_np)
        mask_label = file_name.split("_")[1][:-4]
        # mask_label = mask_label.split("-")[0]
        mask_label_list.append(mask_label)

    mask_list = [torch.from_numpy(mask_np) for mask_np in mask_np_list]
    try:
        assert torch.all(sum(mask_list) == 1)
    except AssertionError:
        print("Make sure maskEdited is in the folder; if not, generate it using the UI")
        import pdb; pdb.set_trace()
    return mask_list, mask_label_list


def save_images(images, filename, num_rows=1, offset_ratio=0.02):
    """Save a (list of) uint8 image array(s) next to `filename`, one file per image,
    appending the image index to the base name."""
    if isinstance(images, list):
        num_empty = len(images) % num_rows
    elif images.ndim == 4:
        num_empty = images.shape[0] % num_rows
    else:
        images = [images]
        num_empty = 0

    empty_images = np.ones(images[0].shape, dtype=np.uint8) * 255
    images = [image.astype(np.uint8) for image in images] + [empty_images] * num_empty
    num_items = len(images)
    folder = os.path.dirname(filename)
    for i, image in enumerate(images):
        pil_img = Image.fromarray(image)
        name = filename.split("/")[-1]
        name = name.split(".")[-2] + "_{}".format(i) + "." + filename.split(".")[-1]
        pil_img.save(os.path.join(folder, name))
        print("saved to ", os.path.join(folder, name))
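
# Illustrative sketch only: loading the segmentation masks produced by the UI
# and writing quick grayscale previews with save_images. The folder path is a
# placeholder; point it at a directory containing "mask{i}_{label}.npy" files.
def _demo_save_mask_previews(input_folder="example_workspace"):
    """Save each loaded 0/1 mask as an RGB PNG preview inside `input_folder`."""
    mask_list, mask_label_list = load_mask(input_folder)
    previews = [(mask.numpy() * 255).astype(np.uint8) for mask in mask_list]
    # save_images expects HxWx3 arrays, so replicate the single channel.
    previews = [np.stack([p, p, p], axis=-1) for p in previews]
    save_images(previews, os.path.join(input_folder, "masks.png"))
    return mask_label_list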