diff --git a/.gitignore b/.gitignore new file mode 100755 index 0000000000000000000000000000000000000000..0110818d6b1e70612d770dd55726953ec8002c6f --- /dev/null +++ b/.gitignore @@ -0,0 +1,10 @@ +img/ +logfile/ +__pycache__/ +*/__pycache__/ +models/ +plt/ +docs/ +exp/ +examples/mask/ +examples/mask_box/ \ No newline at end of file diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 0000000000000000000000000000000000000000..224b59528850f7106228202b8d5901771b41c614 --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,17 @@ +{ + // Use IntelliSense to learn about possible attributes. + // Hover to view descriptions of existing attributes. + // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 + "version": "0.2.0", + "configurations": [ + { + "name": "Demo", + "type": "debugpy", + "request": "launch", + "program": "/home/jyr/demo/DesignEdit/app.py", + "console": "integratedTerminal", + "python": "/home/jyr/.conda/envs/new_design/bin/python", + + } + ] +} \ No newline at end of file diff --git a/README.md b/README.md index 9504551be6c0d250f32f617faec2350c5c855aa1..ddb40e73375e68172b0c30cf147112ae756624ec 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,10 @@ --- title: DesignEdit -emoji: 🏆 -colorFrom: pink -colorTo: yellow +emoji: 🌿 +colorFrom: yellow +colorTo: green sdk: gradio -sdk_version: 4.25.0 +sdk_version: 4.24.0 app_file: app.py pinned: false --- diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..891e638d4689ebee66ef358de780a849ea5121fa --- /dev/null +++ b/app.py @@ -0,0 +1,61 @@ +import gradio as gr +import spaces +import torch + +import os +import subprocess +import shlex +from src.demo.model import DesignEdit + +os.makedirs('models', exist_ok=True) +subprocess.run(shlex.split('wget https://huggingface.co/Adapter/DragonDiffusion/resolve/main/model/efficient_sam_vits.pt -O models/efficient_sam_vits.pt')) + +from src.demo.demo import * +import shlex +import cv2 + +pretrained_model_path = "stabilityai/stable-diffusion-xl-base-1.0" +model = DesignEdit(pretrained_model_path=pretrained_model_path) +DESCRIPTION_1 = """
+ + 🌿D + e + s + i + g + n + E + d + i + t🌿 + +
+ """ +DESCRIPTION_2 = """

Multi-Layered Latent Decomposition and Fusion for Unified & Accurate Image Editing

""" +DESCRIPTION_3 = """ +
+

Gradio demo for DesignEdit

+
+""" + + +with gr.Blocks(css='style.css') as demo: + gr.HTML(DESCRIPTION_1) + gr.HTML(DESCRIPTION_2) + gr.HTML(DESCRIPTION_3) + with gr.Tabs(): + with gr.TabItem('1️⃣ Object Removal'): + create_demo_remove(model.run_remove) + with gr.TabItem('2️⃣ Zooming Out'): + create_demo_zooming(model.run_zooming) + with gr.TabItem('3️⃣ Camera Panning'): + create_demo_panning(model.run_panning) + with gr.TabItem('4️⃣ Object Moving, Resizing and Flipping'): + create_demo_moving(model.run_moving) + with gr.TabItem('5️⃣ 🚩 Multi-Layered Editing 🚩'): + create_demo_layer(model.run_layer) + with gr.TabItem('🔧 Mask Preparation: Draw or Sketch'): + create_demo_mask_box(model.run_mask) +demo.queue(max_size=20) +demo.launch(max_threads=3, server_name="0.0.0.0") + diff --git a/examples/layer/01_horse/00.jpg b/examples/layer/01_horse/00.jpg new file mode 100755 index 0000000000000000000000000000000000000000..502be3f0bbada559c91718606b6b03c86021572b Binary files /dev/null and b/examples/layer/01_horse/00.jpg differ diff --git a/examples/layer/01_horse/mask0.jpg b/examples/layer/01_horse/mask0.jpg new file mode 100755 index 0000000000000000000000000000000000000000..39377e481c21d9980df45c94c233f562809efbfa Binary files /dev/null and b/examples/layer/01_horse/mask0.jpg differ diff --git a/examples/layer/02_baby/00.jpg b/examples/layer/02_baby/00.jpg new file mode 100755 index 0000000000000000000000000000000000000000..f4fbb272ec041a9fd0793f10cc6830c42f740047 Binary files /dev/null and b/examples/layer/02_baby/00.jpg differ diff --git a/examples/layer/02_baby/mask0.jpg b/examples/layer/02_baby/mask0.jpg new file mode 100644 index 0000000000000000000000000000000000000000..f94b32a63688d6b2852e7f743dc4dffb742bdd53 Binary files /dev/null and b/examples/layer/02_baby/mask0.jpg differ diff --git a/examples/layer/02_baby/mask1.jpg b/examples/layer/02_baby/mask1.jpg new file mode 100755 index 0000000000000000000000000000000000000000..bbed5f61ca38f384ab3f2f87c206efd74174e1da Binary files /dev/null and b/examples/layer/02_baby/mask1.jpg differ diff --git a/examples/layer/02_baby/mask2.jpg b/examples/layer/02_baby/mask2.jpg new file mode 100755 index 0000000000000000000000000000000000000000..c77e82d7eeea470bd3a787d018815eda5ee7f588 Binary files /dev/null and b/examples/layer/02_baby/mask2.jpg differ diff --git a/examples/layer/03_text/00.jpg b/examples/layer/03_text/00.jpg new file mode 100644 index 0000000000000000000000000000000000000000..2ae11e62b15bbabcba72e46fd373ad0c718a8fb3 Binary files /dev/null and b/examples/layer/03_text/00.jpg differ diff --git a/examples/layer/03_text/01.jpg b/examples/layer/03_text/01.jpg new file mode 100644 index 0000000000000000000000000000000000000000..654004437236e476431cb81d38434dec1d65e1d1 Binary files /dev/null and b/examples/layer/03_text/01.jpg differ diff --git a/examples/layer/03_text/mask0.jpg b/examples/layer/03_text/mask0.jpg new file mode 100644 index 0000000000000000000000000000000000000000..bf0ce7ffc1799bd21967250f02ecd5f5159ae384 Binary files /dev/null and b/examples/layer/03_text/mask0.jpg differ diff --git a/examples/layer/03_text/mask1.jpg b/examples/layer/03_text/mask1.jpg new file mode 100644 index 0000000000000000000000000000000000000000..a1d805e541613f2c53f767b0e775ab053707cb34 Binary files /dev/null and b/examples/layer/03_text/mask1.jpg differ diff --git a/examples/layer/04_cross/0.jpg b/examples/layer/04_cross/0.jpg new file mode 100644 index 0000000000000000000000000000000000000000..4edde7252b283d38f8e123da726c00e8ecf418fb Binary files /dev/null and 
b/examples/layer/04_cross/0.jpg differ diff --git a/examples/layer/04_cross/1.jpg b/examples/layer/04_cross/1.jpg new file mode 100644 index 0000000000000000000000000000000000000000..a71e2c9477b3ae1911be802d6020df06c55554da Binary files /dev/null and b/examples/layer/04_cross/1.jpg differ diff --git a/examples/layer/04_cross/2.jpg b/examples/layer/04_cross/2.jpg new file mode 100644 index 0000000000000000000000000000000000000000..f40b0440819237e7e18eb8e8479cd370da837e2c Binary files /dev/null and b/examples/layer/04_cross/2.jpg differ diff --git a/examples/layer/04_cross/3.jpg b/examples/layer/04_cross/3.jpg new file mode 100644 index 0000000000000000000000000000000000000000..a2ded809ac249e22dd2808dd99ca386c56632adc Binary files /dev/null and b/examples/layer/04_cross/3.jpg differ diff --git a/examples/layer/04_cross/mask0.jpg b/examples/layer/04_cross/mask0.jpg new file mode 100644 index 0000000000000000000000000000000000000000..e07c313622a380fe141c9b799ab362ab54703345 Binary files /dev/null and b/examples/layer/04_cross/mask0.jpg differ diff --git a/examples/layer/04_cross/mask1.jpg b/examples/layer/04_cross/mask1.jpg new file mode 100644 index 0000000000000000000000000000000000000000..3c90e633a455b762be1fc39a0ece2c62ea55bb5f Binary files /dev/null and b/examples/layer/04_cross/mask1.jpg differ diff --git a/examples/layer/04_cross/mask2.jpg b/examples/layer/04_cross/mask2.jpg new file mode 100644 index 0000000000000000000000000000000000000000..1d09eecc9bf6bf60321def094bc3ecb726963fee Binary files /dev/null and b/examples/layer/04_cross/mask2.jpg differ diff --git a/examples/layer/04_cross/mask3.jpg b/examples/layer/04_cross/mask3.jpg new file mode 100644 index 0000000000000000000000000000000000000000..19b58983cf1463edfd7263fa40296ba723f4d03a Binary files /dev/null and b/examples/layer/04_cross/mask3.jpg differ diff --git a/examples/moving/01_ball/0.jpg b/examples/moving/01_ball/0.jpg new file mode 100644 index 0000000000000000000000000000000000000000..f11738232d02b9ffcb60fe4878764a1081d7fb7a Binary files /dev/null and b/examples/moving/01_ball/0.jpg differ diff --git a/examples/moving/01_ball/mask0.jpg b/examples/moving/01_ball/mask0.jpg new file mode 100644 index 0000000000000000000000000000000000000000..49e57f3cdb98cd1ee80e70318bcb081fb0fbd930 Binary files /dev/null and b/examples/moving/01_ball/mask0.jpg differ diff --git a/examples/moving/02_bell/0.jpg b/examples/moving/02_bell/0.jpg new file mode 100644 index 0000000000000000000000000000000000000000..1a3e1fba25f66f21fe42265f867fa8daaa3f9032 Binary files /dev/null and b/examples/moving/02_bell/0.jpg differ diff --git a/examples/moving/02_bell/mask0.jpg b/examples/moving/02_bell/mask0.jpg new file mode 100644 index 0000000000000000000000000000000000000000..e4f89e4293acea880cd8a04bb0f7e2d69f7cbdb2 Binary files /dev/null and b/examples/moving/02_bell/mask0.jpg differ diff --git a/examples/pan/01.jpg b/examples/pan/01.jpg new file mode 100755 index 0000000000000000000000000000000000000000..502be3f0bbada559c91718606b6b03c86021572b Binary files /dev/null and b/examples/pan/01.jpg differ diff --git a/examples/pan/02.jpg b/examples/pan/02.jpg new file mode 100644 index 0000000000000000000000000000000000000000..dd1d5d4450d43f0e2812a4740f92b2eabd5d7d11 Binary files /dev/null and b/examples/pan/02.jpg differ diff --git a/examples/pan/03.jpg b/examples/pan/03.jpg new file mode 100644 index 0000000000000000000000000000000000000000..d3a84ffebd40b0748deb1b2f1f8734f88b2d1ebc Binary files /dev/null and b/examples/pan/03.jpg differ diff --git 
a/examples/pan/04.jpg b/examples/pan/04.jpg new file mode 100644 index 0000000000000000000000000000000000000000..63c305e6fd77128f5b1664d89bf70bc0fad21a40 Binary files /dev/null and b/examples/pan/04.jpg differ diff --git a/examples/pan/05.jpg b/examples/pan/05.jpg new file mode 100644 index 0000000000000000000000000000000000000000..5617fd79f1a039673e5de8e3ee865351f96ff1cf Binary files /dev/null and b/examples/pan/05.jpg differ diff --git a/examples/pan/06.jpg b/examples/pan/06.jpg new file mode 100644 index 0000000000000000000000000000000000000000..d5e6670f5491df5388f4d5181c60afb36cdcdfed Binary files /dev/null and b/examples/pan/06.jpg differ diff --git a/examples/remove/01_moto/0.jpg b/examples/remove/01_moto/0.jpg new file mode 100755 index 0000000000000000000000000000000000000000..2164e166b5fb3e420a0c892b891ba3e3e68c6712 Binary files /dev/null and b/examples/remove/01_moto/0.jpg differ diff --git a/examples/remove/01_moto/mask0.jpg b/examples/remove/01_moto/mask0.jpg new file mode 100755 index 0000000000000000000000000000000000000000..3c602058c95f0165eb3e4b4b84dc6205de06cf66 Binary files /dev/null and b/examples/remove/01_moto/mask0.jpg differ diff --git a/examples/remove/01_moto/mask1.jpg b/examples/remove/01_moto/mask1.jpg new file mode 100755 index 0000000000000000000000000000000000000000..d773289fee31961900990eb745f7cbd8ac85735e Binary files /dev/null and b/examples/remove/01_moto/mask1.jpg differ diff --git a/examples/remove/02_ring/0.jpg b/examples/remove/02_ring/0.jpg new file mode 100755 index 0000000000000000000000000000000000000000..897e98afa4ebe74083b154fee58f77332c072973 Binary files /dev/null and b/examples/remove/02_ring/0.jpg differ diff --git a/examples/remove/02_ring/mask0.jpg b/examples/remove/02_ring/mask0.jpg new file mode 100755 index 0000000000000000000000000000000000000000..437703be8b840ac9f7d84587f7871872d60b9b1c Binary files /dev/null and b/examples/remove/02_ring/mask0.jpg differ diff --git a/examples/remove/02_ring/mask1.jpg b/examples/remove/02_ring/mask1.jpg new file mode 100755 index 0000000000000000000000000000000000000000..61001eff2dd2cb528a110d63897e7d48945f2cf5 Binary files /dev/null and b/examples/remove/02_ring/mask1.jpg differ diff --git a/examples/remove/02_ring/mask2.jpg b/examples/remove/02_ring/mask2.jpg new file mode 100644 index 0000000000000000000000000000000000000000..faabcd04ce296a29eecbf23a9fbb5fcd93a82dd9 Binary files /dev/null and b/examples/remove/02_ring/mask2.jpg differ diff --git a/examples/remove/03_ball/0.jpg b/examples/remove/03_ball/0.jpg new file mode 100755 index 0000000000000000000000000000000000000000..a32933213dd03dedcd1483fa5bee7582c0c51236 Binary files /dev/null and b/examples/remove/03_ball/0.jpg differ diff --git a/examples/remove/03_ball/mask0.jpg b/examples/remove/03_ball/mask0.jpg new file mode 100755 index 0000000000000000000000000000000000000000..41d2fb7ffc5b8605d8f6622bb646ca3afb8093a1 Binary files /dev/null and b/examples/remove/03_ball/mask0.jpg differ diff --git a/examples/remove/03_ball/mask1.jpg b/examples/remove/03_ball/mask1.jpg new file mode 100755 index 0000000000000000000000000000000000000000..dbe79573f0b773655b22ea70003eb38e747d5394 Binary files /dev/null and b/examples/remove/03_ball/mask1.jpg differ diff --git a/examples/remove/04_pikachu/0.jpg b/examples/remove/04_pikachu/0.jpg new file mode 100755 index 0000000000000000000000000000000000000000..1204812f145e29311d15fd20433e7649954f8af2 Binary files /dev/null and b/examples/remove/04_pikachu/0.jpg differ diff --git 
a/examples/remove/04_pikachu/mask0.jpg b/examples/remove/04_pikachu/mask0.jpg new file mode 100755 index 0000000000000000000000000000000000000000..a197882e1b3f9c728096885ee6265d6875521905 Binary files /dev/null and b/examples/remove/04_pikachu/mask0.jpg differ diff --git a/examples/remove/04_pikachu/mask1.jpg b/examples/remove/04_pikachu/mask1.jpg new file mode 100755 index 0000000000000000000000000000000000000000..80777884fa0fb4ee59deedc1b99477e2f95594fc Binary files /dev/null and b/examples/remove/04_pikachu/mask1.jpg differ diff --git a/examples/remove/04_pikachu/mask2.jpg b/examples/remove/04_pikachu/mask2.jpg new file mode 100644 index 0000000000000000000000000000000000000000..2eb6c21723b01a5c74b50563a096cda0e21b87c0 Binary files /dev/null and b/examples/remove/04_pikachu/mask2.jpg differ diff --git a/examples/remove/05_betty/0.jpg b/examples/remove/05_betty/0.jpg new file mode 100644 index 0000000000000000000000000000000000000000..39eceadc2330703743090990bdbff8aca97afb13 Binary files /dev/null and b/examples/remove/05_betty/0.jpg differ diff --git a/examples/remove/05_betty/mask0.jpg b/examples/remove/05_betty/mask0.jpg new file mode 100644 index 0000000000000000000000000000000000000000..87c0784b1c402471436c89166f8114c36535c672 Binary files /dev/null and b/examples/remove/05_betty/mask0.jpg differ diff --git a/examples/zoom/01.jpg b/examples/zoom/01.jpg new file mode 100755 index 0000000000000000000000000000000000000000..502be3f0bbada559c91718606b6b03c86021572b Binary files /dev/null and b/examples/zoom/01.jpg differ diff --git a/examples/zoom/02.jpg b/examples/zoom/02.jpg new file mode 100644 index 0000000000000000000000000000000000000000..f5730a436e7a08e47a7f2e29e28617271d9d1c44 Binary files /dev/null and b/examples/zoom/02.jpg differ diff --git a/examples/zoom/03.jpg b/examples/zoom/03.jpg new file mode 100755 index 0000000000000000000000000000000000000000..f7a0cc7d115a9e8a59959c910d66774ab0d44af8 Binary files /dev/null and b/examples/zoom/03.jpg differ diff --git a/examples/zoom/04.jpg b/examples/zoom/04.jpg new file mode 100755 index 0000000000000000000000000000000000000000..e4341114794f8be5fe80c9d442693630db1c8ff6 Binary files /dev/null and b/examples/zoom/04.jpg differ diff --git a/examples/zoom/05.jpg b/examples/zoom/05.jpg new file mode 100644 index 0000000000000000000000000000000000000000..f4a2f82d540c7f6731b8e4dff5dcdb74e18c4ab6 Binary files /dev/null and b/examples/zoom/05.jpg differ diff --git a/examples/zoom/06.jpg b/examples/zoom/06.jpg new file mode 100644 index 0000000000000000000000000000000000000000..63c305e6fd77128f5b1664d89bf70bc0fad21a40 Binary files /dev/null and b/examples/zoom/06.jpg differ diff --git a/examples/zoom/07.jpg b/examples/zoom/07.jpg new file mode 100644 index 0000000000000000000000000000000000000000..aaffb0d35dff979c342e5bc383e0c8d999607134 Binary files /dev/null and b/examples/zoom/07.jpg differ diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..d434bd9eaabf87af87056786d9479729abc8a266 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,14 @@ +diffusers==0.18.2 +torch==2.0.1 +torchvision==0.15.2 +matplotlib==3.7.2 +numpy==1.25.1 +opencv_python==4.8.0.74 +opencv_python_headless==4.8.0.74 +Pillow==10.1.0 +Pillow==10.1.0 +transformers==4.35.0 +gradio==4.0.0 +basicsr==1.4.2 +accelerate==0.21.0 +invisible-watermark \ No newline at end of file diff --git a/sam/efficient_sam/__init__.py b/sam/efficient_sam/__init__.py new file mode 100644 index 
0000000000000000000000000000000000000000..22a2d29cee5d2a2df01944c90b6e01f879301f3f --- /dev/null +++ b/sam/efficient_sam/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +from .build_efficient_sam import ( + build_efficient_sam_vitt, + build_efficient_sam_vits, +) diff --git a/sam/efficient_sam/build_efficient_sam.py b/sam/efficient_sam/build_efficient_sam.py new file mode 100644 index 0000000000000000000000000000000000000000..1d138e7335d10c8cbf43aa9ceafef12eda92a66e --- /dev/null +++ b/sam/efficient_sam/build_efficient_sam.py @@ -0,0 +1,22 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from .efficient_sam import build_efficient_sam + +def build_efficient_sam_vitt(): + return build_efficient_sam( + encoder_patch_embed_dim=192, + encoder_num_heads=3, + checkpoint="models/efficient_sam_vitt.pt", + ).eval() + + +def build_efficient_sam_vits(): + return build_efficient_sam( + encoder_patch_embed_dim=384, + encoder_num_heads=6, + checkpoint="models/efficient_sam_vits.pt", + ).eval() diff --git a/sam/efficient_sam/efficient_sam.py b/sam/efficient_sam/efficient_sam.py new file mode 100644 index 0000000000000000000000000000000000000000..3a3ba4c328e0716a8e3166b7d267353757aa76d7 --- /dev/null +++ b/sam/efficient_sam/efficient_sam.py @@ -0,0 +1,310 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import math +from typing import Any, List, Tuple, Type + +import torch +import torch.nn.functional as F + +from torch import nn, Tensor + +from .efficient_sam_decoder import MaskDecoder, PromptEncoder +from .efficient_sam_encoder import ImageEncoderViT +from .two_way_transformer import TwoWayAttentionBlock, TwoWayTransformer + +class EfficientSam(nn.Module): + mask_threshold: float = 0.0 + image_format: str = "RGB" + + def __init__( + self, + image_encoder: ImageEncoderViT, + prompt_encoder: PromptEncoder, + decoder_max_num_input_points: int, + mask_decoder: MaskDecoder, + pixel_mean: List[float] = [0.485, 0.456, 0.406], + pixel_std: List[float] = [0.229, 0.224, 0.225], + ) -> None: + """ + SAM predicts object masks from an image and input prompts. + + Arguments: + image_encoder (ImageEncoderViT): The backbone used to encode the + image into image embeddings that allow for efficient mask prediction. + prompt_encoder (PromptEncoder): Encodes various types of input prompts. + mask_decoder (MaskDecoder): Predicts masks from the image embeddings + and encoded prompts. + pixel_mean (list(float)): Mean values for normalizing pixels in the input image. + pixel_std (list(float)): Std values for normalizing pixels in the input image. 
+ """ + super().__init__() + self.image_encoder = image_encoder + self.prompt_encoder = prompt_encoder + self.decoder_max_num_input_points = decoder_max_num_input_points + self.mask_decoder = mask_decoder + self.register_buffer( + "pixel_mean", torch.Tensor(pixel_mean).view(1, 3, 1, 1), False + ) + self.register_buffer( + "pixel_std", torch.Tensor(pixel_std).view(1, 3, 1, 1), False + ) + + @torch.jit.export + def predict_masks( + self, + image_embeddings: torch.Tensor, + batched_points: torch.Tensor, + batched_point_labels: torch.Tensor, + multimask_output: bool, + input_h: int, + input_w: int, + output_h: int = -1, + output_w: int = -1, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Predicts masks given image embeddings and prompts. This only runs the decoder. + + Arguments: + image_embeddings: A tensor of shape [B, C, H, W] or [B*max_num_queries, C, H, W] + batched_points: A tensor of shape [B, max_num_queries, num_pts, 2] + batched_point_labels: A tensor of shape [B, max_num_queries, num_pts] + Returns: + A tuple of two tensors: + low_res_mask: A tensor of shape [B, max_num_queries, 256, 256] of predicted masks + iou_predictions: A tensor of shape [B, max_num_queries] of estimated IOU scores + """ + + batch_size, max_num_queries, num_pts, _ = batched_points.shape + num_pts = batched_points.shape[2] + rescaled_batched_points = self.get_rescaled_pts(batched_points, input_h, input_w) + + if num_pts > self.decoder_max_num_input_points: + rescaled_batched_points = rescaled_batched_points[ + :, :, : self.decoder_max_num_input_points, : + ] + batched_point_labels = batched_point_labels[ + :, :, : self.decoder_max_num_input_points + ] + elif num_pts < self.decoder_max_num_input_points: + rescaled_batched_points = F.pad( + rescaled_batched_points, + (0, 0, 0, self.decoder_max_num_input_points - num_pts), + value=-1.0, + ) + batched_point_labels = F.pad( + batched_point_labels, + (0, self.decoder_max_num_input_points - num_pts), + value=-1.0, + ) + + sparse_embeddings = self.prompt_encoder( + rescaled_batched_points.reshape( + batch_size * max_num_queries, self.decoder_max_num_input_points, 2 + ), + batched_point_labels.reshape( + batch_size * max_num_queries, self.decoder_max_num_input_points + ), + ) + sparse_embeddings = sparse_embeddings.view( + batch_size, + max_num_queries, + sparse_embeddings.shape[1], + sparse_embeddings.shape[2], + ) + low_res_masks, iou_predictions = self.mask_decoder( + image_embeddings, + self.prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + multimask_output=multimask_output, + ) + _, num_predictions, low_res_size, _ = low_res_masks.shape + + if output_w > 0 and output_h > 0: + output_masks = F.interpolate( + low_res_masks, (output_h, output_w), mode="bicubic" + ) + output_masks = torch.reshape( + output_masks, + (batch_size, max_num_queries, num_predictions, output_h, output_w), + ) + else: + output_masks = torch.reshape( + low_res_masks, + ( + batch_size, + max_num_queries, + num_predictions, + low_res_size, + low_res_size, + ), + ) + iou_predictions = torch.reshape( + iou_predictions, (batch_size, max_num_queries, num_predictions) + ) + sorted_ids = torch.argsort(iou_predictions, dim=-1, descending=True) + iou_predictions = torch.take_along_dim(iou_predictions, sorted_ids, dim=2) + output_masks = torch.take_along_dim( + output_masks, sorted_ids[..., None, None], dim=2 + ) + return output_masks, iou_predictions + + def get_rescaled_pts(self, batched_points: torch.Tensor, input_h: int, input_w: int): + return torch.stack( + [ + 
torch.where( + batched_points[..., 0] >= 0, + batched_points[..., 0] * self.image_encoder.img_size / input_w, + -1.0, + ), + torch.where( + batched_points[..., 1] >= 0, + batched_points[..., 1] * self.image_encoder.img_size / input_h, + -1.0, + ), + ], + dim=-1, + ) + + @torch.jit.export + def get_image_embeddings(self, batched_images) -> torch.Tensor: + """ + Predicts masks end-to-end from provided images and prompts. + If prompts are not known in advance, using SamPredictor is + recommended over calling the model directly. + + Arguments: + batched_images: A tensor of shape [B, 3, H, W] + Returns: + List of image embeddings each of of shape [B, C(i), H(i), W(i)]. + The last embedding corresponds to the final layer. + """ + batched_images = self.preprocess(batched_images) + return self.image_encoder(batched_images) + + def forward( + self, + batched_images: torch.Tensor, + batched_points: torch.Tensor, + batched_point_labels: torch.Tensor, + scale_to_original_image_size: bool = True, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Predicts masks end-to-end from provided images and prompts. + If prompts are not known in advance, using SamPredictor is + recommended over calling the model directly. + + Arguments: + batched_images: A tensor of shape [B, 3, H, W] + batched_points: A tensor of shape [B, num_queries, max_num_pts, 2] + batched_point_labels: A tensor of shape [B, num_queries, max_num_pts] + + Returns: + A list tuples of two tensors where the ith element is by considering the first i+1 points. + low_res_mask: A tensor of shape [B, 256, 256] of predicted masks + iou_predictions: A tensor of shape [B, max_num_queries] of estimated IOU scores + """ + batch_size, _, input_h, input_w = batched_images.shape + image_embeddings = self.get_image_embeddings(batched_images) + return self.predict_masks( + image_embeddings, + batched_points, + batched_point_labels, + multimask_output=True, + input_h=input_h, + input_w=input_w, + output_h=input_h if scale_to_original_image_size else -1, + output_w=input_w if scale_to_original_image_size else -1, + ) + + def preprocess(self, x: torch.Tensor) -> torch.Tensor: + """Normalize pixel values and pad to a square input.""" + if ( + x.shape[2] != self.image_encoder.img_size + or x.shape[3] != self.image_encoder.img_size + ): + x = F.interpolate( + x, + (self.image_encoder.img_size, self.image_encoder.img_size), + mode="bilinear", + ) + return (x - self.pixel_mean) / self.pixel_std + + +def build_efficient_sam(encoder_patch_embed_dim, encoder_num_heads, checkpoint=None): + img_size = 1024 + encoder_patch_size = 16 + encoder_depth = 12 + encoder_mlp_ratio = 4.0 + encoder_neck_dims = [256, 256] + decoder_max_num_input_points = 6 + decoder_transformer_depth = 2 + decoder_transformer_mlp_dim = 2048 + decoder_num_heads = 8 + decoder_upscaling_layer_dims = [64, 32] + num_multimask_outputs = 3 + iou_head_depth = 3 + iou_head_hidden_dim = 256 + activation = "gelu" + normalization_type = "layer_norm" + normalize_before_activation = False + + assert activation == "relu" or activation == "gelu" + if activation == "relu": + activation_fn = nn.ReLU + else: + activation_fn = nn.GELU + + image_encoder = ImageEncoderViT( + img_size=img_size, + patch_size=encoder_patch_size, + in_chans=3, + patch_embed_dim=encoder_patch_embed_dim, + normalization_type=normalization_type, + depth=encoder_depth, + num_heads=encoder_num_heads, + mlp_ratio=encoder_mlp_ratio, + neck_dims=encoder_neck_dims, + act_layer=activation_fn, + ) + + image_embedding_size = 
image_encoder.image_embedding_size + encoder_transformer_output_dim = image_encoder.transformer_output_dim + + sam = EfficientSam( + image_encoder=image_encoder, + prompt_encoder=PromptEncoder( + embed_dim=encoder_transformer_output_dim, + image_embedding_size=(image_embedding_size, image_embedding_size), + input_image_size=(img_size, img_size), + ), + decoder_max_num_input_points=decoder_max_num_input_points, + mask_decoder=MaskDecoder( + transformer_dim=encoder_transformer_output_dim, + transformer=TwoWayTransformer( + depth=decoder_transformer_depth, + embedding_dim=encoder_transformer_output_dim, + num_heads=decoder_num_heads, + mlp_dim=decoder_transformer_mlp_dim, + activation=activation_fn, + normalize_before_activation=normalize_before_activation, + ), + num_multimask_outputs=num_multimask_outputs, + activation=activation_fn, + normalization_type=normalization_type, + normalize_before_activation=normalize_before_activation, + iou_head_depth=iou_head_depth - 1, + iou_head_hidden_dim=iou_head_hidden_dim, + upscaling_layer_dims=decoder_upscaling_layer_dims, + ), + pixel_mean=[0.485, 0.456, 0.406], + pixel_std=[0.229, 0.224, 0.225], + ) + if checkpoint is not None: + with open(checkpoint, "rb") as f: + state_dict = torch.load(f, map_location="cpu") + sam.load_state_dict(state_dict["model"]) + return sam diff --git a/sam/efficient_sam/efficient_sam_decoder.py b/sam/efficient_sam/efficient_sam_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..380f41c1650f1ffb824d8d911f810eabedc66ddd --- /dev/null +++ b/sam/efficient_sam/efficient_sam_decoder.py @@ -0,0 +1,315 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import List, Tuple, Type + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .mlp import MLPBlock + + +class PromptEncoder(nn.Module): + def __init__( + self, + embed_dim: int, + image_embedding_size: Tuple[int, int], + input_image_size: Tuple[int, int], + ) -> None: + """ + Encodes prompts for input to SAM's mask decoder. + + Arguments: + embed_dim (int): The prompts' embedding dimension + image_embedding_size (tuple(int, int)): The spatial size of the + image embedding, as (H, W). + input_image_size (int): The padded size of the image as input + to the image encoder, as (H, W). + """ + super().__init__() + self.embed_dim = embed_dim + self.input_image_size = input_image_size + self.image_embedding_size = image_embedding_size + self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) + self.invalid_points = nn.Embedding(1, embed_dim) + self.point_embeddings = nn.Embedding(1, embed_dim) + self.bbox_top_left_embeddings = nn.Embedding(1, embed_dim) + self.bbox_bottom_right_embeddings = nn.Embedding(1, embed_dim) + + def get_dense_pe(self) -> torch.Tensor: + """ + Returns the positional encoding used to encode point prompts, + applied to a dense set of points the shape of the image encoding. 
+ + Returns: + torch.Tensor: Positional encoding with shape + 1x(embed_dim)x(embedding_h)x(embedding_w) + """ + return self.pe_layer(self.image_embedding_size).unsqueeze(0) + + def _embed_points( + self, + points: torch.Tensor, + labels: torch.Tensor, + ) -> torch.Tensor: + """Embeds point prompts.""" + points = points + 0.5 # Shift to center of pixel + point_embedding = self.pe_layer.forward_with_coords( + points, self.input_image_size + ) + invalid_label_ids = torch.eq(labels, -1)[:,:,None] + point_label_ids = torch.eq(labels, 1)[:,:,None] + topleft_label_ids = torch.eq(labels, 2)[:,:,None] + bottomright_label_ids = torch.eq(labels, 3)[:,:,None] + point_embedding = point_embedding + self.invalid_points.weight[:,None,:] * invalid_label_ids + point_embedding = point_embedding + self.point_embeddings.weight[:,None,:] * point_label_ids + point_embedding = point_embedding + self.bbox_top_left_embeddings.weight[:,None,:] * topleft_label_ids + point_embedding = point_embedding + self.bbox_bottom_right_embeddings.weight[:,None,:] * bottomright_label_ids + return point_embedding + + def forward( + self, + coords, + labels, + ) -> torch.Tensor: + """ + Embeds different types of prompts, returning both sparse and dense + embeddings. + + Arguments: + points: A tensor of shape [B, 2] + labels: An integer tensor of shape [B] where each element is 1,2 or 3. + + Returns: + torch.Tensor: sparse embeddings for the points and boxes, with shape + BxNx(embed_dim), where N is determined by the number of input points + and boxes. + """ + return self._embed_points(coords, labels) + + +class PositionEmbeddingRandom(nn.Module): + """ + Positional encoding using random spatial frequencies. + """ + + def __init__(self, num_pos_feats: int) -> None: + super().__init__() + self.register_buffer( + "positional_encoding_gaussian_matrix", torch.randn((2, num_pos_feats)) + ) + + def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: + """Positionally encode points that are normalized to [0,1].""" + # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape + coords = 2 * coords - 1 + coords = coords @ self.positional_encoding_gaussian_matrix + coords = 2 * np.pi * coords + # outputs d_1 x ... 
x d_n x C shape + return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) + + def forward(self, size: Tuple[int, int]) -> torch.Tensor: + """Generate positional encoding for a grid of the specified size.""" + h, w = size + device = self.positional_encoding_gaussian_matrix.device + grid = torch.ones([h, w], device=device, dtype=torch.float32) + y_embed = grid.cumsum(dim=0) - 0.5 + x_embed = grid.cumsum(dim=1) - 0.5 + y_embed = y_embed / h + x_embed = x_embed / w + + pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1)) + return pe.permute(2, 0, 1) # C x H x W + + def forward_with_coords( + self, coords_input: torch.Tensor, image_size: Tuple[int, int] + ) -> torch.Tensor: + """Positionally encode points that are not normalized to [0,1].""" + coords = coords_input.clone() + coords[:, :, 0] = coords[:, :, 0] / image_size[1] + coords[:, :, 1] = coords[:, :, 1] / image_size[0] + return self._pe_encoding(coords.to(torch.float)) # B x N x C + + +class MaskDecoder(nn.Module): + def __init__( + self, + *, + transformer_dim: int, + transformer: nn.Module, + num_multimask_outputs: int, + activation: Type[nn.Module], + normalization_type: str, + normalize_before_activation: bool, + iou_head_depth: int, + iou_head_hidden_dim: int, + upscaling_layer_dims: List[int], + ) -> None: + """ + Predicts masks given an image and prompt embeddings, using a + transformer architecture. + + Arguments: + transformer_dim (int): the channel dimension of the transformer + transformer (nn.Module): the transformer used to predict masks + num_multimask_outputs (int): the number of masks to predict + when disambiguating masks + activation (nn.Module): the type of activation to use when + upscaling masks + iou_head_depth (int): the depth of the MLP used to predict + mask quality + iou_head_hidden_dim (int): the hidden dimension of the MLP + used to predict mask quality + """ + super().__init__() + self.transformer_dim = transformer_dim + self.transformer = transformer + + self.num_multimask_outputs = num_multimask_outputs + + self.iou_token = nn.Embedding(1, transformer_dim) + if num_multimask_outputs > 1: + self.num_mask_tokens = num_multimask_outputs + 1 + else: + self.num_mask_tokens = 1 + self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) + output_dim_after_upscaling = transformer_dim + + self.final_output_upscaling_layers = nn.ModuleList([]) + for idx, layer_dims in enumerate(upscaling_layer_dims): + self.final_output_upscaling_layers.append( + nn.Sequential( + nn.ConvTranspose2d( + output_dim_after_upscaling, + layer_dims, + kernel_size=2, + stride=2, + ), + nn.GroupNorm(1, layer_dims) + if idx < len(upscaling_layer_dims) - 1 + else nn.Identity(), + activation(), + ) + ) + output_dim_after_upscaling = layer_dims + + self.output_hypernetworks_mlps = nn.ModuleList( + [ + MLPBlock( + input_dim=transformer_dim, + hidden_dim=transformer_dim, + output_dim=output_dim_after_upscaling, + num_layers=2, + act=activation, + ) + for i in range(self.num_mask_tokens) + ] + ) + + self.iou_prediction_head = MLPBlock( + input_dim=transformer_dim, + hidden_dim=iou_head_hidden_dim, + output_dim=self.num_mask_tokens, + num_layers=iou_head_depth, + act=activation, + ) + + def forward( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + multimask_output: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Predict masks given image and prompt embeddings. 
+ + Arguments: + image_embeddings: A tensor of shape [B, C, H, W] or [B*max_num_queries, C, H, W] + image_pe (torch.Tensor): positional encoding with the shape of image_embeddings (the batch dimension is broadcastable). + sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes + multimask_output (bool): Whether to return multiple masks or a single + mask. + + Returns: + torch.Tensor: batched predicted masks + torch.Tensor: batched predictions of mask quality + """ + + ( + batch_size, + max_num_queries, + sparse_embed_dim_1, + sparse_embed_dim_2, + ) = sparse_prompt_embeddings.shape + + ( + _, + image_embed_dim_c, + image_embed_dim_h, + image_embed_dim_w, + ) = image_embeddings.shape + + # Tile the image embedding for all queries. + image_embeddings_tiled = torch.tile( + image_embeddings[:, None, :, :, :], [1, max_num_queries, 1, 1, 1] + ).view( + batch_size * max_num_queries, + image_embed_dim_c, + image_embed_dim_h, + image_embed_dim_w, + ) + sparse_prompt_embeddings = sparse_prompt_embeddings.reshape( + batch_size * max_num_queries, sparse_embed_dim_1, sparse_embed_dim_2 + ) + masks, iou_pred = self.predict_masks( + image_embeddings=image_embeddings_tiled, + image_pe=image_pe, + sparse_prompt_embeddings=sparse_prompt_embeddings, + ) + if multimask_output and self.num_multimask_outputs > 1: + return masks[:, 1:, :], iou_pred[:, 1:] + else: + return masks[:, :1, :], iou_pred[:, :1] + + def predict_masks( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Predicts masks. See 'forward' for more details.""" + # Concatenate output tokens + output_tokens = torch.cat( + [self.iou_token.weight, self.mask_tokens.weight], dim=0 + ) + output_tokens = output_tokens.unsqueeze(0).expand( + sparse_prompt_embeddings.size(0), -1, -1 + ) + tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) + # Expand per-image data in batch direction to be per-mask + pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) + b, c, h, w = image_embeddings.shape + hs, src = self.transformer(image_embeddings, pos_src, tokens) + iou_token_out = hs[:, 0, :] + mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :] + + # Upscale mask embeddings and predict masks using the mask tokens + upscaled_embedding = src.transpose(1, 2).view(b, c, h, w) + + for upscaling_layer in self.final_output_upscaling_layers: + upscaled_embedding = upscaling_layer(upscaled_embedding) + hyper_in_list: List[torch.Tensor] = [] + for i, output_hypernetworks_mlp in enumerate(self.output_hypernetworks_mlps): + hyper_in_list.append(output_hypernetworks_mlp(mask_tokens_out[:, i, :])) + hyper_in = torch.stack(hyper_in_list, dim=1) + b, c, h, w = upscaled_embedding.shape + masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) + # Generate mask quality predictions + iou_pred = self.iou_prediction_head(iou_token_out) + return masks, iou_pred diff --git a/sam/efficient_sam/efficient_sam_encoder.py b/sam/efficient_sam/efficient_sam_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..73fd7ac470b42df738e5e6bcbbcb60b4f30fb46e --- /dev/null +++ b/sam/efficient_sam/efficient_sam_encoder.py @@ -0,0 +1,257 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
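The EfficientSam wrapper and prompt encoder above are only exercised through the Gradio demo in this patch. As a reference, here is a minimal sketch of driving the model directly with a box prompt, assuming the `models/efficient_sam_vits.pt` checkpoint that `app.py` downloads is present; the image path and box coordinates are placeholders. Box corners follow the label convention in `PromptEncoder._embed_points` (1 = positive point, 2 = box top-left, 3 = box bottom-right, -1 = padding).

```python
# Sketch only: box-prompted segmentation with the EfficientSam model defined above.
import numpy as np
import torch
from PIL import Image

from sam.efficient_sam.build_efficient_sam import build_efficient_sam_vits

model = build_efficient_sam_vits()  # loads models/efficient_sam_vits.pt and sets eval()

image = np.asarray(Image.open("examples/remove/01_moto/0.jpg").convert("RGB"))
h, w = image.shape[:2]
# [B, 3, H, W] float in [0, 1]; EfficientSam.preprocess resizes to 1024 and normalizes.
batched_images = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).float() / 255.0

# One query: a box given as two corner points in original-image coordinates (placeholder values).
box_points = torch.tensor([[[[100.0, 150.0], [400.0, 500.0]]]])  # [B=1, queries=1, pts=2, 2]
box_labels = torch.tensor([[[2.0, 3.0]]])                        # 2 = top-left, 3 = bottom-right

with torch.no_grad():
    masks, iou_predictions = model(batched_images, box_points, box_labels)

# Outputs are sorted by predicted IoU, so index 0 holds the best of the multimask candidates.
best_mask = (masks[0, 0, 0] >= 0).cpu().numpy()  # mask_threshold is 0.0
print(best_mask.shape, float(iou_predictions[0, 0, 0]))
```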
+ +import math +from typing import List, Optional, Tuple, Type + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class LayerNorm2d(nn.Module): + def __init__(self, num_channels: int, eps: float = 1e-6) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(num_channels)) + self.bias = nn.Parameter(torch.zeros(num_channels)) + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x + + +class PatchEmbed(nn.Module): + """2D Image to Patch Embedding""" + + def __init__( + self, + img_size, + patch_size, + in_chans, + embed_dim, + ): + super().__init__() + self.proj = nn.Conv2d( + in_chans, + embed_dim, + kernel_size=(patch_size, patch_size), + stride=(patch_size, patch_size), + bias=True, + ) + + def forward(self, x): + B, C, H, W = x.shape + x = self.proj(x) + return x + + +class Attention(nn.Module): + def __init__( + self, + dim, + num_heads, + qkv_bias, + qk_scale=None, + ): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim**-0.5 + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.proj = nn.Linear(dim, dim) + + def forward(self, x): + B, N, C = x.shape + qkv = ( + self.qkv(x) + .reshape(B, N, 3, self.num_heads, C // self.num_heads) + .permute(2, 0, 3, 1, 4) + ) + q, k, v = ( + qkv[0], + qkv[1], + qkv[2], + ) + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + return x + + +class Mlp(nn.Module): + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.fc2(x) + return x + + +class Block(nn.Module): + def __init__( + self, + dim, + num_heads, + mlp_ratio=4.0, + qkv_bias=False, + qk_scale=None, + act_layer=nn.GELU, + ): + super().__init__() + self.norm1 = nn.LayerNorm(dim, eps=1e-6) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + ) + self.norm2 = nn.LayerNorm(dim, eps=1e-6) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + ) + + def forward(self, x): + x = x + self.attn(self.norm1(x)) + x = x + self.mlp(self.norm2(x)) + return x + + +@torch.jit.export +def get_abs_pos( + abs_pos: torch.Tensor, has_cls_token: bool, hw: List[int] +) -> torch.Tensor: + """ + Calculate absolute positional embeddings. If needed, resize embeddings and remove cls_token + dimension for the original embeddings. + Args: + abs_pos (Tensor): absolute positional embeddings with (1, num_position, C). + has_cls_token (bool): If true, has 1 embedding in abs_pos for cls token. + hw (Tuple): size of input image tokens. 
+ + Returns: + Absolute positional embeddings after processing with shape (1, H, W, C) + """ + h = hw[0] + w = hw[1] + if has_cls_token: + abs_pos = abs_pos[:, 1:] + xy_num = abs_pos.shape[1] + size = int(math.sqrt(xy_num)) + assert size * size == xy_num + + if size != h or size != w: + new_abs_pos = F.interpolate( + abs_pos.reshape(1, size, size, -1).permute(0, 3, 1, 2), + size=(h, w), + mode="bicubic", + align_corners=False, + ) + return new_abs_pos.permute(0, 2, 3, 1) + else: + return abs_pos.reshape(1, h, w, -1) + + +# Image encoder for efficient SAM. +class ImageEncoderViT(nn.Module): + def __init__( + self, + img_size: int, + patch_size: int, + in_chans: int, + patch_embed_dim: int, + normalization_type: str, + depth: int, + num_heads: int, + mlp_ratio: float, + neck_dims: List[int], + act_layer: Type[nn.Module], + ) -> None: + """ + Args: + img_size (int): Input image size. + patch_size (int): Patch size. + in_chans (int): Number of input image channels. + patch_embed_dim (int): Patch embedding dimension. + depth (int): Depth of ViT. + num_heads (int): Number of attention heads in each ViT block. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + act_layer (nn.Module): Activation layer. + """ + super().__init__() + + self.img_size = img_size + self.image_embedding_size = img_size // ((patch_size if patch_size > 0 else 1)) + self.transformer_output_dim = ([patch_embed_dim] + neck_dims)[-1] + self.pretrain_use_cls_token = True + pretrain_img_size = 224 + self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, patch_embed_dim) + # Initialize absolute positional embedding with pretrain image size. + num_patches = (pretrain_img_size // patch_size) * ( + pretrain_img_size // patch_size + ) + num_positions = num_patches + 1 + self.pos_embed = nn.Parameter(torch.zeros(1, num_positions, patch_embed_dim)) + self.blocks = nn.ModuleList() + for i in range(depth): + vit_block = Block(patch_embed_dim, num_heads, mlp_ratio, True) + self.blocks.append(vit_block) + self.neck = nn.Sequential( + nn.Conv2d( + patch_embed_dim, + neck_dims[0], + kernel_size=1, + bias=False, + ), + LayerNorm2d(neck_dims[0]), + nn.Conv2d( + neck_dims[0], + neck_dims[0], + kernel_size=3, + padding=1, + bias=False, + ), + LayerNorm2d(neck_dims[0]), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + assert ( + x.shape[2] == self.img_size and x.shape[3] == self.img_size + ), "input image size must match self.img_size" + x = self.patch_embed(x) + # B C H W -> B H W C + x = x.permute(0, 2, 3, 1) + x = x + get_abs_pos( + self.pos_embed, self.pretrain_use_cls_token, [x.shape[1], x.shape[2]] + ) + num_patches = x.shape[1] + assert x.shape[2] == num_patches + x = x.reshape(x.shape[0], num_patches * num_patches, x.shape[3]) + for blk in self.blocks: + x = blk(x) + x = x.reshape(x.shape[0], num_patches, num_patches, x.shape[2]) + x = self.neck(x.permute(0, 3, 1, 2)) + return x diff --git a/sam/efficient_sam/mlp.py b/sam/efficient_sam/mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..b3be8db49cbf6990ce55a467e7e62f60daf62c9d --- /dev/null +++ b/sam/efficient_sam/mlp.py @@ -0,0 +1,29 @@ +from typing import Type + +from torch import nn + + +# Lightly adapted from +# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa +class MLPBlock(nn.Module): + def __init__( + self, + input_dim: int, + hidden_dim: int, + output_dim: int, + num_layers: int, + act: Type[nn.Module], + ) -> None: + super().__init__() + 
self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList( + nn.Sequential(nn.Linear(n, k), act()) + for n, k in zip([input_dim] + h, [hidden_dim] * num_layers) + ) + self.fc = nn.Linear(hidden_dim, output_dim) + + def forward(self, x): + for layer in self.layers: + x = layer(x) + return self.fc(x) diff --git a/sam/efficient_sam/two_way_transformer.py b/sam/efficient_sam/two_way_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..881e76fd7efd07eeeef12999931fc2b74db406a9 --- /dev/null +++ b/sam/efficient_sam/two_way_transformer.py @@ -0,0 +1,264 @@ +import math +from typing import Tuple, Type +import torch +from torch import nn, Tensor +from .mlp import MLPBlock + + + + +class TwoWayTransformer(nn.Module): + def __init__( + self, + depth: int, + embedding_dim: int, + num_heads: int, + mlp_dim: int, + activation: Type[nn.Module], + normalize_before_activation: bool, + attention_downsample_rate: int = 2, + ) -> None: + """ + A transformer decoder that attends to an input image using + queries whose positional embedding is supplied. + + Args: + depth (int): number of layers in the transformer + embedding_dim (int): the channel dimension for the input embeddings + num_heads (int): the number of heads for multihead attention. Must + divide embedding_dim + mlp_dim (int): the channel dimension internal to the MLP block + activation (nn.Module): the activation to use in the MLP block + """ + super().__init__() + self.depth = depth + self.embedding_dim = embedding_dim + self.num_heads = num_heads + self.mlp_dim = mlp_dim + self.layers = nn.ModuleList() + + for i in range(depth): + curr_layer = TwoWayAttentionBlock( + embedding_dim=embedding_dim, + num_heads=num_heads, + mlp_dim=mlp_dim, + activation=activation, + normalize_before_activation=normalize_before_activation, + attention_downsample_rate=attention_downsample_rate, + skip_first_layer_pe=(i == 0), + ) + self.layers.append(curr_layer) + + self.final_attn_token_to_image = AttentionForTwoWayAttentionBlock( + embedding_dim, + num_heads, + downsample_rate=attention_downsample_rate, + ) + self.norm_final_attn = nn.LayerNorm(embedding_dim) + + def forward( + self, + image_embedding: Tensor, + image_pe: Tensor, + point_embedding: Tensor, + ) -> Tuple[Tensor, Tensor]: + """ + Args: + image_embedding (torch.Tensor): image to attend to. Should be shape + B x embedding_dim x h x w for any h and w. + image_pe (torch.Tensor): the positional encoding to add to the image. Must + have the same shape as image_embedding. + point_embedding (torch.Tensor): the embedding to add to the query points. + Must have shape B x N_points x embedding_dim for any N_points. 
+ + Returns: + torch.Tensor: the processed point_embedding + torch.Tensor: the processed image_embedding + """ + + # BxCxHxW -> BxHWxC == B x N_image_tokens x C + bs, c, h, w = image_embedding.shape + image_embedding = image_embedding.flatten(2).permute(0, 2, 1) + image_pe = image_pe.flatten(2).permute(0, 2, 1) + + # Prepare queries + queries = point_embedding + keys = image_embedding + + # Apply transformer blocks and final layernorm + for idx, layer in enumerate(self.layers): + queries, keys = layer( + queries=queries, + keys=keys, + query_pe=point_embedding, + key_pe=image_pe, + ) + + # Apply the final attention layer from the points to the image + q = queries + point_embedding + k = keys + image_pe + attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys) + queries = queries + attn_out + queries = self.norm_final_attn(queries) + return queries, keys + + +class TwoWayAttentionBlock(nn.Module): + def __init__( + self, + embedding_dim: int, + num_heads: int, + mlp_dim: int, + activation: Type[nn.Module], + normalize_before_activation: bool, + attention_downsample_rate: int = 2, + skip_first_layer_pe: bool = False, + ) -> None: + """ + A transformer block with four layers: (1) self-attention of sparse + inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp + block on sparse inputs, and (4) cross attention of dense inputs to sparse + inputs. + + Arguments: + embedding_dim (int): the channel dimension of the embeddings + num_heads (int): the number of heads in the attention layers + mlp_dim (int): the hidden dimension of the mlp block + activation (nn.Module): the activation of the mlp block + skip_first_layer_pe (bool): skip the PE on the first layer + """ + super().__init__() + self.self_attn = AttentionForTwoWayAttentionBlock(embedding_dim, num_heads) + self.norm1 = nn.LayerNorm(embedding_dim) + + self.cross_attn_token_to_image = AttentionForTwoWayAttentionBlock( + embedding_dim, + num_heads, + downsample_rate=attention_downsample_rate, + ) + self.norm2 = nn.LayerNorm(embedding_dim) + + self.mlp = MLPBlock( + embedding_dim, + mlp_dim, + embedding_dim, + 1, + activation, + ) + + self.norm3 = nn.LayerNorm(embedding_dim) + + self.norm4 = nn.LayerNorm(embedding_dim) + self.cross_attn_image_to_token = AttentionForTwoWayAttentionBlock( + embedding_dim, + num_heads, + downsample_rate=attention_downsample_rate, + ) + + self.skip_first_layer_pe = skip_first_layer_pe + + def forward( + self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor + ) -> Tuple[Tensor, Tensor]: + # Self attention block + if not self.skip_first_layer_pe: + queries = queries + query_pe + attn_out = self.self_attn(q=queries, k=queries, v=queries) + queries = queries + attn_out + queries = self.norm1(queries) + + # Cross attention block, tokens attending to image embedding + q = queries + query_pe + k = keys + key_pe + attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) + queries = queries + attn_out + queries = self.norm2(queries) + + # MLP block + mlp_out = self.mlp(queries) + queries = queries + mlp_out + queries = self.norm3(queries) + + # Cross attention block, image embedding attending to tokens + q = queries + query_pe + k = keys + key_pe + attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries) + keys = keys + attn_out + keys = self.norm4(keys) + + return queries, keys + + +class AttentionForTwoWayAttentionBlock(nn.Module): + """ + An attention layer that allows for downscaling the size of the embedding + after projection to queries, keys, and values. 
+ """ + + def __init__( + self, + embedding_dim: int, + num_heads: int, + downsample_rate: int = 1, + ) -> None: + super().__init__() + self.embedding_dim = embedding_dim + self.internal_dim = embedding_dim // downsample_rate + self.num_heads = num_heads + assert ( + self.internal_dim % num_heads == 0 + ), "num_heads must divide embedding_dim." + + self.q_proj = nn.Linear(embedding_dim, self.internal_dim) + self.k_proj = nn.Linear(embedding_dim, self.internal_dim) + self.v_proj = nn.Linear(embedding_dim, self.internal_dim) + self.out_proj = nn.Linear(self.internal_dim, embedding_dim) + self._reset_parameters() + + def _reset_parameters(self) -> None: + # The fan_out is incorrect, but matches pytorch's initialization + # for which qkv is a single 3*embedding_dim x embedding_dim matrix + fan_in = self.embedding_dim + fan_out = 3 * self.internal_dim + # Xavier uniform with our custom fan_out + bnd = math.sqrt(6 / (fan_in + fan_out)) + nn.init.uniform_(self.q_proj.weight, -bnd, bnd) + nn.init.uniform_(self.k_proj.weight, -bnd, bnd) + nn.init.uniform_(self.v_proj.weight, -bnd, bnd) + # out_proj.weight is left with default initialization, like pytorch attention + nn.init.zeros_(self.q_proj.bias) + nn.init.zeros_(self.k_proj.bias) + nn.init.zeros_(self.v_proj.bias) + nn.init.zeros_(self.out_proj.bias) + + def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: + b, n, c = x.shape + x = x.reshape(b, n, num_heads, c // num_heads) + return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head + + def _recombine_heads(self, x: Tensor) -> Tensor: + b, n_heads, n_tokens, c_per_head = x.shape + x = x.transpose(1, 2) + return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C + + def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: + # Input projections + q = self.q_proj(q) + k = self.k_proj(k) + v = self.v_proj(v) + + # Separate into heads + q = self._separate_heads(q, self.num_heads) + k = self._separate_heads(k, self.num_heads) + v = self._separate_heads(v, self.num_heads) + + # Attention + _, _, _, c_per_head = q.shape + attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens + attn = attn / math.sqrt(c_per_head) + attn = torch.softmax(attn, dim=-1) + # Get output + out = attn @ v + out = self._recombine_heads(out) + out = self.out_proj(out) + return out diff --git a/src/demo/demo.py b/src/demo/demo.py new file mode 100644 index 0000000000000000000000000000000000000000..166b2b03f28b71b0b3ea8e81e808ffaa9bc6f5da --- /dev/null +++ b/src/demo/demo.py @@ -0,0 +1,738 @@ +import gradio as gr +import numpy as np +from src.demo.utils import get_point, store_img, get_point_move, store_img_move, clear_points, upload_image_move, segment_with_points, segment_with_points_paste, fun_clear, paste_with_mask_and_offset +import spaces + +examples_remove = [ + [ + "examples/remove/02_ring/0.jpg", # original image 1 + "examples/remove/02_ring/mask0.jpg", # mask 1 + "examples/remove/02_ring/0.jpg", # original image 2 + "examples/remove/02_ring/mask1.jpg", #mask 2 + "examples/remove/02_ring/0.jpg", #Original image 3 + "examples/remove/02_ring/mask2.jpg", #mask 3 + None, #Original image 4 + None, # refine mask + ], # 02 + [ + "examples/remove/01_moto/0.jpg", # original image 1 + "examples/remove/01_moto/mask0.jpg", # mask 1 + "examples/remove/01_moto/0.jpg", # original image 2 + None, #mask 2 + "examples/remove/01_moto/0.jpg", #Original image 3 + None, #mask 3 + "examples/remove/01_moto/0.jpg", #Original image 4 + "examples/remove/01_moto/mask1.jpg", # refine mask + ], # 01 + [ + 
"examples/remove/03_ball/0.jpg", # original image 1 + "examples/remove/03_ball/mask0.jpg", # mask 1 + "examples/remove/03_ball/0.jpg", # original image 2 + "examples/remove/03_ball/mask1.jpg", #mask 2 + "examples/remove/03_ball/0.jpg", #Original image 3 + None, #mask 3 + None, #Original image 4 + None, # refine mask + ], # 03 + [ + "examples/remove/04_pikachu/0.jpg", # original image 1 + "examples/remove/04_pikachu/mask0.jpg", # mask 1 + "examples/remove/04_pikachu/0.jpg", # original image 2 + "examples/remove/04_pikachu/mask1.jpg", #mask 2 + "examples/remove/04_pikachu/0.jpg", #Original image 3 + "examples/remove/04_pikachu/mask2.jpg", #mask 3 + None, #Original image 4 + None, # refine mask + ], # 04 + [ + "examples/remove/05_betty/0.jpg", # original image 1 + "examples/remove/05_betty/mask0.jpg", # mask 1 + None, # original image 2 + None, #mask 2 + None, #Original image 3 + None, #mask 3 + None, #Original image 4 + None, # refine mask + ], # 05 +] +examples_zoom = [ + ["examples/zoom/01.jpg"], + ["examples/zoom/02.jpg"], + ["examples/zoom/03.jpg"], + ["examples/zoom/04.jpg"], + ["examples/zoom/05.jpg"], + ["examples/zoom/06.jpg"], + ["examples/zoom/07.jpg"], +] +examples_pan = [ + ["examples/pan/01.jpg"], + ["examples/pan/02.jpg"], + ["examples/pan/03.jpg"], + ["examples/pan/04.jpg"], + ["examples/pan/05.jpg"], + ["examples/pan/06.jpg"], +] + +examples_moving = [ + [ + "examples/layer/01_horse/00.jpg", #bg + "examples/layer/01_horse/mask0.jpg", #bg_mask + 0, 0, 1.2, "None", "left/right", #l1_dx, l1_dy, l1_resize + ], + [ + "examples/moving/01_ball/0.jpg", #bg + "examples/moving/01_ball/mask0.jpg", #bg_mask + -0.2, -0.1, 0.8, "None", "None", #l1_dx, l1_dy, l1_resize + ], + [ + "examples/moving/02_bell/0.jpg", #bg + "examples/moving/02_bell/mask0.jpg", #bg_mask + 0, 0, 0.75, "None", "None", #l1_dx, l1_dy, l1_resize + ], +] +examples_layer = [ + [ + "examples/layer/01_horse/00.jpg", #bg + "examples/layer/01_horse/mask0.jpg", #bg_mask + + "examples/layer/01_horse/00.jpg", #l1 + "examples/layer/01_horse/mask0.jpg", #l1_mask + -0.2, 0, 1, "None", "None", #l1_dx, l1_dy, l1_resize + + "examples/layer/01_horse/00.jpg", #l2 + "examples/layer/01_horse/mask0.jpg", #l2_mask + 0.2, 0, 1, "None", "None", #l2_dx, l2_dy, l2_resize + + None, #l3 + None, #l3_mask + 0, 0, 1, "None", "None", #l3_dx, l3_dy, l3_resize + + "examples/layer/01_horse/00.jpg", #bg_ori + "examples/layer/01_horse/00.jpg", #l1_ori + "examples/layer/01_horse/00.jpg", #l2_ori + None, "None", "None", #l3_ori + ], + + [ + "examples/layer/02_baby/00.jpg", #bg + "examples/layer/02_baby/mask0.jpg", #bg_mask + + "examples/layer/02_baby/00.jpg", #l1 + "examples/layer/02_baby/mask1.jpg", #l1_mask + -0.35, 0, 1,"left/right", "None", #l1_dx, l1_dy, l1_resize + + "examples/layer/02_baby/00.jpg", #l2 + "examples/layer/02_baby/mask2.jpg", #l2_mask + 0.35, 0, 1, "left/right", "None", #l2_dx, l2_dy, l2_resize + + None, #l3 + None, #l3_mask + 0, 0, 1,"None", "None", #l3_dx, l3_dy, l3_resize + ], + + [ + "examples/layer/03_text/00.jpg", #bg + "examples/layer/03_text/mask0.jpg", #bg_mask + + "examples/layer/03_text/01.jpg", #l1 + "examples/layer/03_text/mask1.jpg", #l1_mask + 0.1, -0.1, 0.5, "None", "None",#l1_dx, l1_dy, l1_resize + + None, #l2 + None, #l2_mask + 0, 0, 1, "None", "None",#l2_dx, l2_dy, l2_resize + + None, #l3 + None, #l3_mask + 0, 0, 1,"None", "None", #l3_dx, l3_dy, l3_resize + ], + [ + "examples/layer/04_cross/0.jpg", #bg + "examples/layer/04_cross/mask0.jpg", #bg_mask + + "examples/layer/04_cross/2.jpg", #l1 + 
"examples/layer/04_cross/mask2.jpg", #l1_mask + -0.1, -0.25, 0.5, "None", "None",#l1_dx, l1_dy, l1_resize + + "examples/layer/04_cross/1.jpg", #l2 + "examples/layer/04_cross/mask1.jpg", #l2_mask + -0.1, -0.15, 0.7, "None", "None",#l2_dx, l2_dy, l2_resize + + "examples/layer/04_cross/3.jpg", #l3 + "examples/layer/04_cross/mask3.jpg", #l3_mask + -0.1, -0.55, 0.5, "None", "None",#l3_dx, l3_dy, l3_resize + ], +] +examples_mask_box = [ + [ + "examples/mask_box/image1.jpg", # original image 1 + "examples/mask_box/image2.jpg", # original image 1 + "examples/mask_box/mask01.jpg", # original image 1 + "examples/mask_box/mask02.jpg", # original image 1 + "examples/mask_box/mask00.jpg", # original image 1 + ] +] + +# 01 +def create_demo_remove(runner=None): + DESCRIPTION = """ + # Object Removal + + ## Usage: + + - Upload a sources image, and then draw a box to generate the mask corresponding to the selecting object. + - You can choose to mask more than one object by using Mask2 and Mask3. + - If you encounter artifacts, try to sketch the regions that caused the artifacts. + - You can refer to the first motorcycle example to understand the usage of the Refined Mask. + - Please clear the output before running a new example! + - For more irregular composition masks, refer to the last page: Mask Preparation. +""" + + with gr.Blocks() as demo: + original_image = gr.State(value=None) + img_with_mask = gr.State(value=None) + + selected_points = gr.State([]) + global_points = gr.State([]) + global_point_label = gr.State([]) + + gr.Markdown(DESCRIPTION) + + with gr.Row(): + with gr.Column(): + with gr.Group(): + gr.Markdown("# INPUT") + # mask 0 + gr.Markdown("## Select two points for Mask 1:") + gr.Markdown("the top left and the bottom right") + original_image_1 = gr.Image(sources='upload', label="Original image (Mask 1)", interactive=True, type="numpy") + # mask 1 + gr.Markdown("## Option: Select two points for Mask 2") + gr.Markdown("the top left and the bottom right") + original_image_2 = gr.Image(sources='upload', label="Original (Mask 2)", interactive=True, type="numpy") + # mask 2 + gr.Markdown("## Option: Select two points for Mask 3") + gr.Markdown("the top left and the bottom right") + original_image_3 = gr.Image(label="Original image (Mask 3)", interactive=True, type="numpy") + + gr.Markdown("## Option: Mask regions caused artifacts") + gr.Markdown("the top left and the bottom right") + original_image_4 = gr.Image(label="Original image (Refine Mask)", interactive=True, type="numpy") + with gr.Row(): + run_button = gr.Button("Edit") + clear_button = gr.Button("Clear") + + + with gr.Column(): + with gr.Group(): + gr.Markdown("# Mask") + + gr.Markdown("## Removal Mask 1") + mask_1 = gr.Image(sources='upload', label="Removal Mask 1", interactive=True, type="numpy") + gr.Markdown("## Option: Removal Mask 2") + mask_2 = gr.Image(sources='upload', label="Removal Mask 2", interactive=True, type="numpy") + gr.Markdown("## Option: Removal Mask 3") + mask_3 = gr.Image(sources='upload', label="Removal Mask 3", interactive=True, type="numpy") + + gr.Markdown("## Option: Refine Mask to avoid artifacts") + refine_mask = gr.Image(sources='upload', label="Refine Mask", interactive=True, type="numpy") + + with gr.Column(): + with gr.Group(): + gr.Markdown("# OUTPUT") + gr.Markdown("## Results") + output = gr.Gallery(columns=1, height='auto') + + + original_image_1.select( + segment_with_points, + inputs=[original_image_1, original_image, global_points, global_point_label], + outputs=[original_image_1, 
original_image, mask_1, global_points, global_point_label] + ) + original_image_2.select( + segment_with_points, + inputs=[original_image_2, original_image, global_points, global_point_label], + outputs=[original_image_2, original_image, mask_2, global_points, global_point_label] + ) + original_image_3.select( + segment_with_points, + inputs=[original_image_3, original_image, global_points, global_point_label], + outputs=[original_image_3, original_image, mask_3, global_points, global_point_label] + ) + original_image_4.select( + segment_with_points, + inputs=[original_image_4, original_image, global_points, global_point_label], + outputs=[original_image_4, original_image, refine_mask, global_points, global_point_label] + ) + + with gr.Column(): + gr.Markdown("Try some of the examples below ⬇️") + gr.Examples( + examples=examples_remove, + inputs=[ + original_image_1, mask_1, + original_image_2, mask_2, + original_image_3, mask_3, + original_image_4, refine_mask] + ) + run_button.click(fn=runner, inputs=[original_image, mask_1, mask_2, mask_3, refine_mask, + original_image_1, original_image_2, original_image_3], outputs=[output]) + clear_button.click( + fn=fun_clear, + inputs=[original_image, img_with_mask, selected_points, global_points, global_point_label, original_image_1, original_image_2, original_image_3, original_image_4, mask_1, mask_2, mask_3, refine_mask], + outputs=[original_image, img_with_mask, selected_points, global_points, global_point_label, original_image_1, original_image_2, original_image_3, original_image_4, mask_1, mask_2, mask_3, refine_mask] + ) + return demo + + +# 02: +def create_demo_zooming(runner=None): + DESCRIPTION = """ + # Zooming Out + + ## Usage: + + - Upload a sources image and choose the width and height zooming scale to zoom out. + - The illustration of image adjustment and mask preparation is shown in the second column. + - We recommend setting the zooming scale between 0.75 and 1 for optimal results. + - Please clear the output before running a new example! 
+ """ + + with gr.Blocks() as demo: + + gr.Markdown(DESCRIPTION) + + with gr.Row(): + with gr.Column(): + with gr.Group(): + gr.Markdown("# INPUT") + # mask 0 + gr.Markdown("## Original Image") + original_image = gr.Image(sources='upload', interactive=True, type="numpy") + + + gr.Markdown("## Scale:") + width_scale= gr.Slider( + label="Width scale", + minimum=0, + maximum=1, + step=0.05, + value=0.9, + interactive=True) + height_scale= gr.Slider( + label="Height scale", + minimum=0, + maximum=1, + step=0.05, + value=0.9, + interactive=True) + with gr.Row(): + run_button = gr.Button("Edit") + clear_button = gr.Button("Clear") + + with gr.Column(): + with gr.Group(): + gr.Markdown("# Preprocess") + gr.Markdown("## Image Adjustment:") + new_image = gr.Gallery(columns=1, height='auto') + gr.Markdown("## Mask Adjustment:") + new_mask = gr.Gallery(columns=1, height='auto') + + with gr.Column(): + with gr.Group(): + gr.Markdown("# OUTPUT") + gr.Markdown("## Results") + output = gr.Gallery(columns=1, height='auto') + + with gr.Column(): + gr.Markdown("Try some of the examples below ⬇️") + gr.Examples( + examples=examples_zoom, + inputs=[original_image] + ) + run_button.click(fn=runner, inputs=[original_image, width_scale, height_scale], outputs=[output, new_image, new_mask]) + clear_button.click(fn=fun_clear, inputs=[original_image, width_scale, height_scale, output, new_image, new_mask], + outputs=[original_image, width_scale, height_scale, output, new_image, new_mask]) + return demo +# 03 + +def create_demo_panning(runner=None): + DESCRIPTION = """ + # Camera Panning + + ## Usage: + + - Upload a sources image and choose the width and height panning scale. + - The illustration of image adjustment and mask preparation is shown in the second column. + - We recommend setting the panning scale between 0 and 0.25 for optimal results. + - Please clear the output before running a new example! 
+ """ + + with gr.Blocks() as demo: + gr.Markdown(DESCRIPTION) + + with gr.Row(): + with gr.Column(): + with gr.Group(): + gr.Markdown("# INPUT") + # mask 0 + gr.Markdown("## Original Image") + original_image = gr.Image(sources='upload', interactive=True, type="numpy") + w_direction = gr.Radio(["left", "right"], value="left", label="Width Direction") + w_scale = gr.Slider( + label="Width scale", + minimum=0, + maximum=1, + step=0.05, + value=0, + interactive=True) + + h_direction = gr.Radio(["up", "down"], value="up", label="Height Direction") + h_scale = gr.Slider( + label="Height scale", + minimum=0, + maximum=1, + step=0.05, + value=0, + interactive=True) + with gr.Row(): + run_button = gr.Button("Edit") + clear_button = gr.Button("Clear") + + with gr.Column(): + with gr.Group(): + gr.Markdown("# Preprocess") + gr.Markdown("## Image Adjustment:") + new_image = gr.Gallery(columns=1, height='auto') + gr.Markdown("## Mask Adjustment:") + new_mask = gr.Gallery(columns=1, height='auto') + + with gr.Column(): + with gr.Group(): + gr.Markdown("# OUTPUT") + gr.Markdown("## Results") + output = gr.Gallery(columns=1, height='auto') + + with gr.Column(): + gr.Markdown("Try some of the examples below ⬇️") + gr.Examples( + examples=examples_pan, + inputs=[original_image] + ) + run_button.click(fn=runner, inputs=[original_image, w_direction, w_scale, h_direction, h_scale], outputs=[output, new_image, new_mask]) + clear_button.click(fn=fun_clear, inputs=[original_image, w_direction, w_scale, h_direction, h_scale, new_image, new_mask, output], + outputs=[original_image, w_direction, w_scale, h_direction, h_scale, new_image, new_mask, output]) + return demo +# 04: +def create_position_size(label=None): + image = gr.Image(sources='upload', label=label, interactive=True, type="numpy") + with gr.Row(): + dx = gr.Slider( + label="Left-Right", + minimum=-1, + maximum=1, + step=0.05, + value=0, + interactive=True + ) + dy = gr.Slider( + label="Down-Up", + minimum=-1, + maximum=1, + step=0.05, + value=0, + interactive=True + ) + resize_scale = gr.Slider( + label="Resize", + minimum=0, + maximum=2, + step=0.05, + value=1, + interactive=True + ) + with gr.Row(): + w_flip = gr.Radio(["left/right","None"], value="None", label="Horizontal Flip") + h_flip = gr.Radio(["down/up", "None"], value="None", label="Vertical Flip") + return image, dx, dy, resize_scale, w_flip, h_flip +# 05: +def create_demo_layer(runner=None): + DESCRIPTION = """ + # 🚩 Multi-Layered selecting 🚩 + + ## Usage: + + - Notice that all operations can be achieved using the multi-layered selecting mode. + - In particular, you can accomplish multi-object selecting such as adding objects and cross-image composition on this page. + - Try some interesting examples given below to understand the usage. + - Please clear the output before running a new example! + - We strongly recommend you to read the [original paper](https://arxiv.org/abs/2403.14487) to further explore more uses of multi-layered selecting. 
+ """ + global_points = gr.State([]) + global_point_label = gr.State([]) + bg_ori = gr.State(value=None) + l1_ori = gr.State(value=None) + l2_ori = gr.State(value=None) + l3_ori = gr.State(value=None) + with gr.Blocks() as demo: + gr.Markdown(DESCRIPTION) + with gr.Row(): + with gr.Column(): + with gr.Group(): + gr.Markdown("# INPUT") + gr.Markdown("## Background Image") + bg_img = gr.Image(sources='upload', label="Background", interactive=True, type="numpy") + gr.Markdown("## Layer-1") + l1_img, l1_dx, l1_dy, l1_resize, l1_w_flip, l1_h_flip = create_position_size(label="Layer-1") + gr.Markdown("## Layer-2") + l2_img, l2_dx, l2_dy, l2_resize, l2_w_flip, l2_h_flip = create_position_size(label="Layer-2") + gr.Markdown("## Layer-3") + l3_img, l3_dx, l3_dy, l3_resize, l3_w_flip, l3_h_flip = create_position_size(label="Layer-3") + with gr.Row(): + run_button = gr.Button("Edit") + clear_button = gr.Button("Clear") + + with gr.Column(): + with gr.Group(): + gr.Markdown("# Mask") + gr.Markdown("## Background Mask for Removal:") + bg_mask = gr.Image(sources='upload', label="BG Mask", interactive=True, type="numpy") + gr.Markdown("## Layer-1 Mask:") + l1_mask = gr.Image(sources='upload', label="L1 Mask", interactive=True, type="numpy") + gr.Markdown("## Layer-2 Mask:") + l2_mask = gr.Image(sources='upload', label="L2 Mask", interactive=True, type="numpy") + gr.Markdown("## Layer-3 Mask:") + l3_mask = gr.Image(sources='upload', label="L3 Mask", interactive=True, type="numpy") + + with gr.Column(): + with gr.Group(): + gr.Markdown("# OUTPUT") + gr.Markdown("## Results") + output = gr.Gallery(columns=1, height='auto') + + with gr.Column(): + gr.Markdown("Try some of the examples below ⬇️") + gr.Examples( + examples=examples_layer, + inputs=[ + bg_img, bg_mask, + l1_img, l1_mask, l1_dx, l1_dy, l1_resize, l1_w_flip, l1_h_flip, + l2_img, l2_mask, l2_dx, l2_dy, l2_resize, l2_w_flip, l2_h_flip, + l3_img, l3_mask, l3_dx, l3_dy, l3_resize, l3_w_flip, l3_h_flip, + ] + ) + bg_img.select( + segment_with_points, + inputs=[bg_img, bg_ori, global_points, global_point_label], + outputs=[bg_img, bg_ori, bg_mask, global_points, global_point_label] + ) + l1_img.select( + segment_with_points, + inputs=[l1_img, l1_ori, global_points, global_point_label], + outputs=[l1_img, l1_ori, l1_mask, global_points, global_point_label] + ) + l2_img.select( + segment_with_points, + inputs=[l2_img, l2_ori, global_points, global_point_label], + outputs=[l2_img, l2_ori, l2_mask, global_points, global_point_label] + ) + l3_img.select( + segment_with_points, + inputs=[l3_img, l3_ori, global_points, global_point_label], + outputs=[l3_img, l3_ori, l3_mask, global_points, global_point_label] + ) + + run_button.click(fn=runner, inputs=[ + bg_img, + l1_img, l1_dx, l1_dy, l1_resize, l1_w_flip, l1_h_flip, + l2_img, l2_dx, l2_dy, l2_resize, l2_w_flip, l2_h_flip, + l3_img, l3_dx, l3_dy, l3_resize, l3_w_flip, l3_h_flip, + bg_mask, l1_mask, l2_mask, l3_mask, + bg_ori, l1_ori, l2_ori, l3_ori + ], outputs=[output]) + + clear_button.click(fn=fun_clear, + inputs=[bg_img, bg_ori, + l1_img, l1_ori, l1_dx, l1_dy, l1_resize, l1_w_flip, l1_h_flip, + l2_img, l2_ori, l2_dx, l2_dy, l2_resize, l2_w_flip, l2_h_flip, + l3_img, l3_ori, l3_dx, l3_dy, l3_resize, l3_w_flip, l3_h_flip, + bg_mask, l1_mask, l2_mask, l3_mask, + global_points, global_point_label, output], + outputs=[bg_img, bg_ori, + l1_img, l1_ori, l1_dx, l1_dy, l1_resize, l1_w_flip, l1_h_flip, + l2_img, l2_ori, l2_dx, l2_dy, l2_resize, l2_w_flip, l2_h_flip, + l3_img, l3_ori, l3_dx, l3_dy, l3_resize, 
l3_w_flip, l3_h_flip, + bg_mask, l1_mask, l2_mask, l3_mask, + global_points, global_point_label, output], + ) + return demo + +# 06: +def create_demo_mask_box(runner=None): + DESCRIPTION = """ + # 🔧 Mask Preparation + ## Usage: + - This page is a tool for you to combine more than one mask. + - You can draw a box to mask an object to obtain Masks 1-4. + - The merged mask is the union of Masks 1-4. + - Please clear the output before running a new example! + """ + + with gr.Blocks() as demo: + original_image = gr.State(value=None) + img_with_mask = gr.State(value=None) + selected_points = gr.State([]) + global_points = gr.State([]) + global_point_label = gr.State([]) + gr.Markdown(DESCRIPTION) + with gr.Row(): + with gr.Column(): + with gr.Group(): + gr.Markdown("# INPUT") + gr.Markdown("## 1. Select two points for Mask 1") + gr.Markdown("the top left and the bottom right") + img_draw_box_1 = gr.Image(sources='upload', label="Original Image", interactive=True, type="numpy") + + gr.Markdown("## 2. Select two points for Mask 2") + gr.Markdown("the top left and the bottom right") + img_draw_box_2 = gr.Image(sources='upload', label="Original Image", interactive=True, type="numpy") + + gr.Markdown("## 3. Select two points for Mask 3") + gr.Markdown("the top left and the bottom right") + img_draw_box_3 = gr.Image(sources='upload', label="Original Image", interactive=True, type="numpy") + + gr.Markdown("## 4. Select two points for Mask 4") + gr.Markdown("the top left and the bottom right") + img_draw_box_4 = gr.Image(label="Original Image", interactive=True, type="numpy") + + with gr.Row(): + run_button = gr.Button("Edit") + clear_button = gr.Button("Clear") + + with gr.Column(): + with gr.Group(): + gr.Markdown("# Mask") + gr.Markdown("## Mask 1") + mask_1 = gr.Image(sources='upload', label="Mask", interactive=True, type="numpy") + gr.Markdown("## Mask 2") + mask_2 = gr.Image(sources='upload', label="Mask", interactive=True, type="numpy") + gr.Markdown("## Mask 3") + mask_3 = gr.Image(sources='upload', label="Mask", interactive=True, type="numpy") + gr.Markdown("## Mask 4") + mask_4 = gr.Image(sources='upload', label="Mask", interactive=True, type="numpy") + + with gr.Column(): + with gr.Group(): + gr.Markdown("# Merged Mask") + merged_mask = gr.Image(sources='upload', label="Mask of object", interactive=True, type="numpy") + + with gr.Column(): + gr.Markdown("Please see the example below. 
⬇️") + gr.Examples( + examples=examples_mask_box, + inputs=[ + img_draw_box_1, img_draw_box_2, mask_1, mask_2, merged_mask + ] + ) + img_draw_box_1.select( + segment_with_points, + inputs=[img_draw_box_1, original_image, global_points, global_point_label], + outputs=[img_draw_box_1, original_image, mask_1, global_points, global_point_label] + ) + img_draw_box_2.select( + segment_with_points, + inputs=[img_draw_box_2, original_image, global_points, global_point_label], + outputs=[img_draw_box_2, original_image, mask_2, global_points, global_point_label] + ) + img_draw_box_3.select( + segment_with_points, + inputs=[img_draw_box_3, original_image, global_points, global_point_label], + outputs=[img_draw_box_3, original_image, mask_3, global_points, global_point_label] + ) + img_draw_box_4.select( + segment_with_points, + inputs=[img_draw_box_4, original_image, global_points, global_point_label], + outputs=[img_draw_box_4, original_image, mask_4, global_points, global_point_label] + ) + + run_button.click(fn=runner, inputs=[mask_1, mask_2, mask_3, mask_4], outputs=[merged_mask]) + clear_button.click( + fn=fun_clear, + inputs=[original_image, img_with_mask, selected_points, global_points, global_point_label, img_draw_box_1, img_draw_box_2, img_draw_box_3, img_draw_box_4, mask_1, mask_2, mask_3, mask_4], + outputs=[original_image, img_with_mask, selected_points, global_points, global_point_label, img_draw_box_1, img_draw_box_2, img_draw_box_3, img_draw_box_4, mask_1, mask_2, mask_3, mask_4, merged_mask] + ) + return demo + +def create_demo_moving(runner=None): + DESCRIPTION = """ + # Object Moving, Resizing, and Flipping + + ## Usage: + - Upload an image and draw a box around the object to manipulate. + - Move the object vertically or horizontally using sliders or by drawing an arrow. + - You can select options for moving and flipping the object from a menu. + - Please clear the output before running a new example! 
+ """ + + selected_points = gr.State([]) + global_points = gr.State([]) + global_point_label = gr.State([]) + bg_ori = gr.State(value=None) + l1_ori = gr.State(value=None) + with gr.Blocks() as demo: + gr.Markdown(DESCRIPTION) + with gr.Row(): + with gr.Column(): + with gr.Group(): + gr.Markdown("# INPUT") + gr.Markdown("## Draw box to mask target object") + bg_img = gr.Image(sources='upload', label="Background", interactive=True, type="numpy") + gr.Markdown("## Draw arrow to describe the movement") + l1_img, l1_dx, l1_dy, l1_resize, l1_w_flip, l1_h_flip = create_position_size(label="Layer-1") + with gr.Row(): + run_button = gr.Button("Edit") + clear_button = gr.Button("Clear") + + with gr.Column(): + with gr.Group(): + gr.Markdown("# Mask") + gr.Markdown("## Background Mask for Removal:") + bg_mask = gr.Image(sources='upload', label="Mask", interactive=True, type="numpy") + + with gr.Column(): + with gr.Group(): + gr.Markdown("# OUTPUT") + gr.Markdown("## Results") + output = gr.Gallery(columns=1, height='auto') + + with gr.Column(): + gr.Markdown("Try some of the examples below ⬇️") + gr.Examples( + examples=examples_moving, + inputs=[ + bg_img, bg_mask, l1_dx, l1_dy, l1_resize, l1_h_flip, l1_w_flip + ] + ) + bg_img.select( + segment_with_points, + inputs=[bg_img, bg_ori, global_points, global_point_label], + outputs=[bg_img, bg_ori, bg_mask, global_points, global_point_label] + ) + l1_img.select( + get_point_move, + [bg_ori, l1_img, selected_points], + [l1_img, bg_ori, selected_points, l1_dx, l1_dy], + ) + + run_button.click(fn=runner, inputs=[ + bg_img, bg_ori,bg_mask, + l1_dx, l1_dy, l1_resize, l1_w_flip, l1_h_flip, selected_points + ], outputs=[output]) + + clear_button.click(fn=fun_clear, + inputs=[bg_img, bg_ori, bg_mask, l1_img, l1_ori, l1_dx, l1_dy, l1_resize, l1_w_flip, l1_h_flip, + global_points, global_point_label, selected_points, output], + outputs=[bg_img, bg_ori, bg_mask, l1_img, l1_ori, l1_dx, l1_dy, l1_resize, l1_w_flip, l1_h_flip, + global_points, global_point_label, selected_points, output], + ) + return demo diff --git a/src/demo/model.py b/src/demo/model.py new file mode 100644 index 0000000000000000000000000000000000000000..a92e3c92296a283680eb055c030b0f241efe4a2d --- /dev/null +++ b/src/demo/model.py @@ -0,0 +1,517 @@ +import numpy as np +import torch +from diffusers import DDIMScheduler +import cv2 +from utils.sdxl import sdxl +from utils.inversion import Inversion +import math +import torch.nn.functional as F +import utils.utils as utils +import os +import matplotlib.pyplot as plt +from PIL import Image, ImageDraw, ImageFont +import spaces + +MAX_NUM_WORDS = 77 + + +class LayerFusion: + def get_mask(self, maps, alpha, use_pool,x_t): + k = 1 + maps = (maps * alpha).sum(-1).mean(1) + if use_pool: + maps = F.max_pool2d(maps, (k * 2 + 1, k * 2 + 1), (1, 1), padding=(k, k)) + mask = F.interpolate(maps, size=(x_t.shape[2:])) #[2, 1, 128, 128] + mask = mask / mask.max(2, keepdims=True)[0].max(3, keepdims=True)[0] + mask=(mask - mask.min ()) / (mask.max () - mask.min ()) + mask = mask.gt(self.mask_threshold) + self.mask=mask + mask = mask[:1] + mask + return mask + + def get_one_mask(self, maps, use_pool, x_t, idx_lst, i=None, sav_img=False): + k=1 + if sav_img is False: + mask_tot = 0 + for obj in idx_lst: + mask = maps[0, :, :, :, obj].mean(0).reshape(1, 1, 32, 32) + if use_pool: + mask = F.max_pool2d(mask, (k * 2 + 1, k * 2 + 1), (1, 1), padding=(k, k)) + mask = F.interpolate(mask, size=(x_t.shape[2:])) + mask = mask / mask.max(2, keepdims=True)[0].max(3, 
keepdims=True)[0] + mask=(mask - mask.min ()) / (mask.max () - mask.min ()) + mask = mask.gt(self.mask_threshold[int(self.counter/10)]) + mask_tot |= mask + mask = mask_tot + return mask + else: + for obj in idx_lst: + mask = maps[0, :, :, :, obj].mean(0).reshape(1, 1, 32, 32) + if use_pool: + mask = F.max_pool2d(mask, (k * 2 + 1, k * 2 + 1), (1, 1), padding=(k, k)) + mask = F.interpolate(mask, size=(1024, 1024))#[1, 1, 1024, 1024] + mask = mask / mask.max(2, keepdims=True)[0].max(3, keepdims=True)[0] + mask=(mask - mask.min ()) / (mask.max () - mask.min ()) + mask = mask.gt(0.6) + mask = np.array(mask[0][0].clone().cpu()).astype(np.uint8)*255 + cv2.imwrite(f'./img/sam_mask/{self.blend_list[i][0]}_{self.counter}.jpg', mask) + return mask + + def mv_op(self, mp, op, scale=0.2, ones=False, flip=None): + _, b, H, W = mp.shape + if ones == False: + new_mp = torch.zeros_like(mp) + else: + new_mp = torch.ones_like(mp) + K = int(scale*W) + if op == 'right': + new_mp[:, :, :, K:] = mp[:, :, :, 0:W-K] + elif op == 'left': + new_mp[:, :, :, 0:W-K] = mp[:, :, :, K:] + elif op == 'down': + new_mp[:, :, K:, :] = mp[:, :, 0:W-K, :] + elif op == 'up': + new_mp[:, :, 0:W-K, :] = mp[:, :, K:, :] + if flip is not None: + new_mp = torch.flip(new_mp, dims=flip) + + return new_mp + + def mv_layer(self, x_t, bg_id, fg_id, op_id): + bg_img = x_t[bg_id:(bg_id+1)].clone() + fg_img = x_t[fg_id:(fg_id+1)].clone() + fg_mask = self.fg_mask_list[fg_id-3] + op_list = self.op_list[fg_id-3] + + for item in op_list: + op, scale = item[0], item[1] + if scale != 0: + fg_img = self.mv_op(fg_img, op=op, scale=scale) + fg_mask = self.mv_op(fg_mask, op=op, scale=scale) + x_t[op_id:(op_id+1)] = bg_img*(1-fg_mask) + fg_img*fg_mask + + def __call__(self, x_t): + self.counter += 1 + # inpainting + if self.blend_time[0] <= self.counter <= self.blend_time[1]: + x_t[1:2] = x_t[1:2]*self.remove_mask + x_t[0:1]*(1-self.remove_mask) + + if self.counter == self.blend_time[1] + 1 and self.mode != "removal": + b = x_t.shape[0] + bg_id = 1 #bg_layer + op_id = 2 #canvas + for fg_id in range(3, b): #fg_layer + self.mv_layer(x_t, bg_id=bg_id, fg_id=fg_id, op_id=op_id) + bg_id = op_id + + return x_t + + def __init__(self, remove_mask, fg_mask_list, refine_mask=None, + blend_time=[0, 40], + mode="removal", op_list=None): + self.counter = 0 + self.mode = mode + self.op_list = op_list + self.blend_time = blend_time + + self.remove_mask = remove_mask + self.refine_mask = refine_mask + if self.refine_mask is not None: + self.new_mask = self.remove_mask + self.refine_mask + self.new_mask[self.new_mask>0] = 1 + else: + self.new_mask = None + self.fg_mask_list = fg_mask_list + + +class Control(): + def step_callback(self, x_t): + if self.layer_fusion is not None: + x_t = self.layer_fusion(x_t) + return x_t + def __init__(self, layer_fusion): + self.layer_fusion = layer_fusion + +def register_attention_control(model, controller, mask_time=[0, 40], refine_time=[0, 25]): + def ca_forward(self, place_in_unet): + to_out = self.to_out + if type(to_out) is torch.nn.modules.container.ModuleList: + to_out = self.to_out[0] + else: + to_out = self.to_out + self.counter = 0 #time + def forward(hidden_states, encoder_hidden_states=None, attention_mask=None): #self_attention + x = hidden_states.clone() + context = encoder_hidden_states + is_cross = context is not None + if is_cross is False: + if controller.layer_fusion is not None and (mask_time[0] < self.counter < mask_time[1]): + b, i, j = x.shape + H = W = int(math.sqrt(i)) + x_old = x.clone() + x = x.reshape(b, 
H, W, j) + new_mask = controller.layer_fusion.remove_mask + if new_mask is not None: + new_mask[new_mask>0] = 1 + new_mask = F.interpolate(new_mask.to(dtype=torch.float32).clone(), size=(H, W), mode='bilinear').cuda() + new_mask = (1 - new_mask).reshape(1, H, W).unsqueeze(-1) + if (refine_time[0] < self.counter <= refine_time[1]) and controller.layer_fusion.refine_mask is not None: + new_mask = controller.layer_fusion.new_mask + new_mask = F.interpolate(new_mask.to(dtype=torch.float32).clone(), size=(H, W), mode='bilinear').cuda() + new_mask = (1 - new_mask).reshape(1, H, W).unsqueeze(-1) + idx = 1 #inpaiint_idx:bg + x[int(b/2)+idx, :, :] = (x[int(b/2)+idx, :, :]*new_mask[0]) + x = x.reshape(b, i, j) + if is_cross: + q = self.to_q(x) + k = self.to_k(context) + v = self.to_v(context) + else: + context = x + q = self.to_q(hidden_states) + k = self.to_k(x) + v = self.to_v(hidden_states) + q = self.head_to_batch_dim(q) + k = self.head_to_batch_dim(k) + v = self.head_to_batch_dim(v) + + if hasattr(controller, 'count_layers'): + controller.count_layers(place_in_unet,is_cross) + sim = torch.einsum("b i d, b j d -> b i j", q.clone(), k.clone()) * self.scale + + attn = sim.softmax(dim=-1) + out = torch.einsum("b i j, b j d -> b i d", attn, v) + out = self.batch_to_head_dim(out) + global global_cnt + self.counter += 1 + return to_out(out) + + return forward + + def register_recr(net_, count, place_in_unet): + if net_.__class__.__name__ == 'Attention': + net_.forward = ca_forward(net_, place_in_unet) + return count + 1 + elif hasattr(net_, 'children'): + for net__ in net_.children(): + count = register_recr(net__, count, place_in_unet) + return count + + cross_att_count = 0 + sub_nets = model.unet.named_children() + for net in sub_nets: + if "down" in net[0]: + cross_att_count += register_recr(net[1], 0, "down") + elif "up" in net[0]: + cross_att_count += register_recr(net[1], 0, "up") + elif "mid" in net[0]: + cross_att_count += register_recr(net[1], 0, "mid") + + controller.num_att_layers = cross_att_count + +class DesignEdit(): + def __init__(self, pretrained_model_path="/home/jyr/model/stable-diffusion-xl-base-1.0"): + self.model_dtype = "fp16" + self.pretrained_model_path=pretrained_model_path + self.num_ddim_steps = 50 + self.mask_time = [0, 40] + self.op_list = {} + self.attend_scale = {} + scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False) + if self.model_dtype == "fp16": + torch_dtype = torch.float16 + elif self.model_dtype == "fp32": + torch_dtype = torch.float32 + self.pipe = sdxl.from_pretrained(self.pretrained_model_path, torch_dtype=torch_dtype, use_safetensors=True, variant=self.model_dtype,scheduler=scheduler) + + @spaces.GPU + def init_model(self, num_ddim_steps=50): + device = torch.device('cuda:0') + self.pipe.to(device) + inversion = Inversion(self.pipe,num_ddim_steps) + return self.pipe, inversion + + @spaces.GPU(duration=120, enable_queue=True) + def run_remove(self, original_image=None, mask_1=None, mask_2=None, mask_3=None, refine_mask=None, + ori_1=None, ori_2=None, ori_3=None, + prompt="", save_dir="./tmp", mode='removal',): + # 01-1: + self.ldm_model, self.inversion= self.init_model(num_ddim_steps=self.num_ddim_steps) + if original_image is None: + original_image = ori_1 if ori_1 is not None else ori_2 if ori_2 is not None else ori_3 + op_list = None + attend_scale = 20 + sample_ref_match={0 : 0, 1 : 0} + ori_shape = original_image.shape + + # 01-2: prepare: image_gt, remove_mask, 
fg_mask_list, refine_mask + image_gt = Image.fromarray(original_image).resize((1024, 1024)) + image_gt = np.stack([np.array(image_gt)]) + mask_list = [mask_1, mask_2, mask_3] + remove_mask = utils.attend_mask(utils.add_masks_resized(mask_list), attend_scale=attend_scale) # numpy to tensor + fg_mask_list = None + refine_mask = utils.attend_mask(utils.convert_and_resize_mask(refine_mask)) if refine_mask is not None else None + + # 01-3: prepare: prompts, blend_time, refine_time + prompts = len(sample_ref_match)*[prompt] # 2 + blend_time = [0, 41] + refine_time = [0, 25] + + # 02: invert + _, x_t, x_stars, prompt_embeds, pooled_prompt_embeds = self.inversion.invert(image_gt, prompts, inv_batch_size=1) + + # 03: init layer_fusion and controller + lb = LayerFusion(remove_mask=remove_mask, fg_mask_list=fg_mask_list, refine_mask=refine_mask, + blend_time=blend_time, mode=mode, op_list=op_list) + controller = Control(layer_fusion=lb) + register_attention_control(model=self.ldm_model, controller=controller, mask_time=self.mask_time, refine_time=refine_time) + + # 04: generate images + images = self.ldm_model(controller=controller, prompt=prompts, + latents=x_t, x_stars=x_stars, + negative_prompt_embeds=prompt_embeds, + negative_pooled_prompt_embeds=pooled_prompt_embeds, + sample_ref_match=sample_ref_match) + folder = None + utils.view_images(images, folder=folder) + return [cv2.resize(images[1], (ori_shape[1], ori_shape[0]))] + + @spaces.GPU(duration=120, enable_queue=True) + def run_zooming(self, original_image, width_scale=1, height_scale=1, prompt="", save_dir="./tmp", mode='removal'): + self.ldm_model, self.inversion= self.init_model(num_ddim_steps=self.num_ddim_steps) + # 01-1: + op_list = {0: ['zooming', [height_scale, width_scale]]} + ori_shape = original_image.shape + attend_scale = 30 + sample_ref_match = {0 : 0, 1 : 0} + + # 01-2: prepare: image_gt, remove_mask, fg_mask_list, refine_mask + img_new, mask = utils.zooming(original_image, [height_scale, width_scale]) + img_new_copy = img_new.copy() + mask_copy = mask.copy() + + image_gt = Image.fromarray(img_new).resize((1024, 1024)) + image_gt = np.stack([np.array(image_gt)]) + + remove_mask = utils.attend_mask(utils.convert_and_resize_mask(mask), attend_scale=attend_scale) # numpy to tensor + fg_mask_list = None + refine_mask = None + + # 01-3: prepare: prompts, blend_time, refine_time + prompts = len(sample_ref_match)*[prompt] # 2 + blend_time = [0, 41] + refine_time = [0, 25] + + # 02: invert + _, x_t, x_stars, prompt_embeds, pooled_prompt_embeds = self.inversion.invert(image_gt, prompts, inv_batch_size=1) + + # 03: init layer_fusion and controller + lb = LayerFusion(remove_mask=remove_mask, fg_mask_list=fg_mask_list, blend_time=blend_time, + mode=mode, op_list=op_list) + controller = Control(layer_fusion=lb) + register_attention_control(model=self.ldm_model, controller=controller, mask_time=self.mask_time, refine_time=refine_time) + + # 04: generate images + images = self.ldm_model(controller=controller, prompt=prompts, + latents=x_t, x_stars=x_stars, + negative_prompt_embeds=prompt_embeds, + negative_pooled_prompt_embeds=pooled_prompt_embeds, + sample_ref_match=sample_ref_match) + folder = None + utils.view_images(images, folder=folder) + resized_img = cv2.resize(images[1], (ori_shape[1], ori_shape[0])) + return [resized_img], [img_new_copy], [mask_copy] + + @spaces.GPU(duration=120, enable_queue=True) + def run_panning(self, original_image, w_direction, w_scale, h_direction, h_scale, prompt="", save_dir="./tmp", mode='removal'): + # 
01-1: prepare: op_list, attend_scale, sample_ref_match + self.ldm_model, self.inversion= self.init_model(num_ddim_steps=self.num_ddim_steps) + ori_shape = original_image.shape + attend_scale = 30 + sample_ref_match = {0 : 0, 1 : 0} + + # 01-2: prepare: image_gt, remove_mask, fg_mask_list, refine_mask + op_list = [[w_direction, w_scale], [h_direction, h_scale]] + img_new, mask = utils.panning(original_image, op_list=op_list) + img_new_copy = img_new.copy() + mask_copy = mask.copy() + + image_gt = Image.fromarray(img_new).resize((1024, 1024)) + image_gt = np.stack([np.array(image_gt)]) + remove_mask = utils.attend_mask(utils.convert_and_resize_mask(mask), attend_scale=attend_scale) # numpy to tensor + + fg_mask_list = None + refine_mask = None + + # 01-3: prepare: prompts, blend_time, refine_time + prompts = len(sample_ref_match)*[prompt] # 2 + blend_time = [0, 41] + refine_time = [0, 25] + + # 02: invert + _, x_t, x_stars, prompt_embeds, pooled_prompt_embeds = self.inversion.invert(image_gt, prompts, inv_batch_size=1) + # 03: init layer_fusion and controller + lb = LayerFusion(remove_mask=remove_mask, fg_mask_list=fg_mask_list, blend_time=blend_time, + mode=mode, op_list=op_list) + controller = Control(layer_fusion=lb) + register_attention_control(model=self.ldm_model, controller=controller, mask_time=self.mask_time, refine_time=refine_time) + + # 04: generate images + + images = self.ldm_model(controller=controller, prompt=prompts, + latents=x_t, x_stars=x_stars, + negative_prompt_embeds=prompt_embeds, + negative_pooled_prompt_embeds=pooled_prompt_embeds, + sample_ref_match=sample_ref_match) + folder = None + utils.view_images(images, folder=folder) + resized_img = cv2.resize(images[1], (ori_shape[1], ori_shape[0])) + return [resized_img], [img_new_copy], [mask_copy] + + # layer-wise multi-object editing + def process_layer_states(self, layer_states): + self.ldm_model, self.inversion= self.init_model(num_ddim_steps=self.num_ddim_steps) + image_paths = [] + mask_paths = [] + op_list = [] + + for state in layer_states: + img, mask, dx, dy, resize, w_flip, h_flip = state + if img is not None: + img = cv2.resize(img, (1024, 1024)) + mask = utils.convert_and_resize_mask(mask) + dx_command = ['right', dx] if dx > 0 else ['left', -dx] + dy_command = ['up', dy] if dy > 0 else ['down', -dy] + flip_code = None + if w_flip == "left/right" and h_flip == "down/up": + flip_code = -1 + elif w_flip == "left/right": + flip_code = 1 # 或者其他默认值,根据您的需要设置 + elif h_flip == "down/up": + flip_code = 0 + op_list.append([dx_command, dy_command]) + img, mask, _ = utils.resize_image_with_mask(img, mask, resize) + img, mask, _ = utils.flip_image_with_mask(img, mask, flip_code=flip_code) + image_paths.append(img) + mask_paths.append(utils.attend_mask(mask)) + sample_ref_match = {0: 0, 1: 0, 2: 0, 3: 1, 4: 2, 5: 3} + required_length = len(image_paths) + 3 + truncated_sample_ref_match = {k: sample_ref_match[k] for k in sorted(sample_ref_match.keys())[:required_length]} + return image_paths, mask_paths, op_list, truncated_sample_ref_match + + @spaces.GPU(duration=200) + def run_layer(self, bg_img, l1_img, l1_dx, l1_dy, l1_resize, l1_w_flip, l1_h_flip, + l2_img, l2_dx, l2_dy, l2_resize, l2_w_flip, l2_h_flip, + l3_img, l3_dx, l3_dy, l3_resize, l3_w_flip, l3_h_flip, + bg_mask, l1_mask, l2_mask, l3_mask, + bg_ori=None, l1_ori=None, l2_ori=None, l3_ori=None, + prompt="", save_dir="./tmp", mode='layerwise'): + self.ldm_model, self.inversion= self.init_model(num_ddim_steps=self.num_ddim_steps) + # 00: prepare: layer-wise states + 
bg_img = bg_ori if bg_ori is not None else bg_img + l1_img = l1_ori if l1_ori is not None else l1_img + l2_img = l2_ori if l2_ori is not None else l2_img + l3_img = l3_ori if l3_ori is not None else l3_img + for mask in [bg_mask, l1_mask, l2_mask, l3_mask]: + if mask is None: + mask = np.zeros((1024, 1024), dtype=np.uint8) + else: + mask = utils.convert_and_resize_mask(mask) + l1_state = [l1_img, l1_mask, l1_dx, l1_dy, l1_resize, l1_w_flip, l1_h_flip] + l2_state = [l2_img, l2_mask, l2_dx, l2_dy, l2_resize, l2_w_flip, l2_h_flip] + l3_state = [l3_img, l3_mask, l3_dx, l3_dy, l3_resize, l3_w_flip, l3_h_flip] + ori_shape = bg_img.shape + + image_paths, fg_mask_list, op_list, sample_ref_match = self.process_layer_states([l1_state, l2_state, l3_state]) + if image_paths == []: + mode = "removal" + # 01-1: prepare: image_gt, remove_mask, fg_mask_list, refine_mask + attend_scale = 20 + image_gt = [bg_img] + image_paths + image_gt = [Image.fromarray(img).resize((1024, 1024)) for img in image_gt] + image_gt = np.stack(image_gt) + remove_mask = utils.attend_mask(bg_mask, attend_scale=attend_scale) + refine_mask = None + + # 01-2: prepare: promptrun_masks, blend_time, refine_time + prompts = len(sample_ref_match)*[prompt] # 2 + blend_time = [0, 41] + refine_time = [0, 25] + attend_scale = [] + + # 02: invert + _, x_t, x_stars, prompt_embeds, pooled_prompt_embeds = self.inversion.invert(image_gt, prompts, inv_batch_size=len(image_gt)) + # 03: init layer_fusion and controller + lb = LayerFusion(remove_mask=remove_mask, fg_mask_list=fg_mask_list, blend_time=blend_time, refine_mask=refine_mask, + mode=mode, op_list=op_list) + controller = Control(layer_fusion=lb) + register_attention_control(model=self.ldm_model, controller=controller, mask_time=self.mask_time, refine_time=refine_time) + # 04: generate images + images = self.ldm_model(controller=controller, prompt=prompts, + latents=x_t, x_stars=x_stars, + negative_prompt_embeds=prompt_embeds, + negative_pooled_prompt_embeds=pooled_prompt_embeds, + sample_ref_match=sample_ref_match) + folder = None + utils.view_images(images, folder=folder) + if mode == 'removal': + resized_img = cv2.resize(images[1], (ori_shape[1], ori_shape[0])) + else: + resized_img = cv2.resize(images[2], (ori_shape[1], ori_shape[0])) + return [resized_img] + + @spaces.GPU(duration=120, enable_queue=True) + def run_moving(self, bg_img, bg_ori, bg_mask, l1_dx, l1_dy, l1_resize, + l1_w_flip=None, l1_h_flip=None, selected_points=None, + prompt="", save_dir="./tmp", mode='layerwise'): + self.ldm_model, self.inversion= self.init_model(num_ddim_steps=self.num_ddim_steps) + # 00: prepare: layer-wise states + bg_img = bg_ori if bg_ori is not None else bg_img + l1_img = bg_img + if bg_mask is None: + bg_mask = np.zeros((1024, 1024), dtype=np.uint8) + else: + bg_mask = utils.convert_and_resize_mask(bg_mask) + l1_mask = bg_mask + l1_state = [l1_img, l1_mask, l1_dx, l1_dy, l1_resize, l1_w_flip, l1_h_flip] + ori_shape = bg_img.shape + + image_paths, fg_mask_list, op_list, sample_ref_match = self.process_layer_states([l1_state]) + + # 01-1: prepare: image_gt, remove_mask, fg_mask_list, refine_mask + attend_scale = 20 + image_gt = [bg_img] + image_paths + image_gt = [Image.fromarray(img).resize((1024, 1024)) for img in image_gt] + image_gt = np.stack(image_gt) + remove_mask = utils.attend_mask(bg_mask, attend_scale=attend_scale) + refine_mask = None + + # 01-2: prepare: promptrun_masks, blend_time, refine_time + prompts = len(sample_ref_match)*[prompt] # 2 + blend_time = [0, 41] + refine_time = [0, 
25] + attend_scale = [] + + # 02: invert + _, x_t, x_stars, prompt_embeds, pooled_prompt_embeds = self.inversion.invert(image_gt, prompts, inv_batch_size=len(image_gt)) + # 03: init layer_fusion and controller + lb = LayerFusion(remove_mask=remove_mask, fg_mask_list=fg_mask_list, blend_time=blend_time, refine_mask=refine_mask, + mode=mode, op_list=op_list) + controller = Control(layer_fusion=lb) + register_attention_control(model=self.ldm_model, controller=controller, mask_time=self.mask_time, refine_time=refine_time) + # 04: generate images + images = self.ldm_model(controller=controller, prompt=prompts, + latents=x_t, x_stars=x_stars, + negative_prompt_embeds=prompt_embeds, + negative_pooled_prompt_embeds=pooled_prompt_embeds, + sample_ref_match=sample_ref_match) + folder = None + utils.view_images(images, folder=folder) + resized_img = cv2.resize(images[2], (ori_shape[1], ori_shape[0])) + return [resized_img] + + # turn mask to 1024x1024 unit-8 + def run_mask(self, mask_1, mask_2, mask_3, mask_4): + mask_list = [mask_1, mask_2, mask_3, mask_4] + final_mask = utils.add_masks_resized(mask_list) + return final_mask \ No newline at end of file diff --git a/src/demo/utils.py b/src/demo/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0d0c789365ca94a89a9e71874999b73398f682f9 --- /dev/null +++ b/src/demo/utils.py @@ -0,0 +1,319 @@ +import numpy as np +import gradio as gr +import cv2 +from copy import deepcopy +import torch +from torchvision import transforms +from PIL import Image, ImageDraw, ImageFont + +from sam.efficient_sam.build_efficient_sam import build_efficient_sam_vits +from src.utils.utils import resize_numpy_image + +sam = build_efficient_sam_vits() + +def show_point_or_box(image, global_points): + # for point + if len(global_points) == 1: + image = cv2.circle(image, global_points[0], 10, (0, 0, 255), -1) + # for box + if len(global_points) == 2: + p1 = global_points[0] + p2 = global_points[1] + image = cv2.rectangle(image,(int(p1[0]),int(p1[1])),(int(p2[0]),int(p2[1])),(0,0,255),2) + + return image + +def segment_with_points( + image, + original_image, + global_points, + global_point_label, + evt: gr.SelectData, + img_direction, + save_dir = "./tmp" +): + if original_image is None: + original_image = image + else: + image = original_image + if img_direction is None: + img_direction = original_image + x, y = evt.index[0], evt.index[1] + image_path = None + mask_path = None + if len(global_points) == 0: + global_points.append([x, y]) + global_point_label.append(2) + image_with_point= show_point_or_box(image.copy(), global_points) + return image_with_point, original_image, None, global_points, global_point_label + elif len(global_points) == 1: + global_points.append([x, y]) + global_point_label.append(3) + x1, y1 = global_points[0] + x2, y2 = global_points[1] + if x1 < x2 and y1 >= y2: + global_points[0][0] = x1 + global_points[0][1] = y2 + global_points[1][0] = x2 + global_points[1][1] = y1 + elif x1 >= x2 and y1 < y2: + global_points[0][0] = x2 + global_points[0][1] = y1 + global_points[1][0] = x1 + global_points[1][1] = y2 + elif x1 >= x2 and y1 >= y2: + global_points[0][0] = x2 + global_points[0][1] = y2 + global_points[1][0] = x1 + global_points[1][1] = y1 + image_with_point = show_point_or_box(image.copy(), global_points) + # data process + input_point = np.array(global_points) + input_label = np.array(global_point_label) + pts_sampled = torch.reshape(torch.tensor(input_point), [1, 1, -1, 2]) + pts_labels = torch.reshape(torch.tensor(input_label), 
[1, 1, -1]) + img_tensor = transforms.ToTensor()(image) + # sam + predicted_logits, predicted_iou = sam( + img_tensor[None, ...], + pts_sampled, + pts_labels, + ) + mask = torch.ge(predicted_logits[0, 0, 0, :, :], 0).float().cpu().detach().numpy() + mask_image = (mask*255.).astype(np.uint8) + return image_with_point, original_image, mask_image, global_points, global_point_label + else: + global_points=[[x, y]] + global_point_label=[2] + image_with_point= show_point_or_box(image.copy(), global_points) + return image_with_point, original_image, None, global_points, global_point_label + + +def segment_with_points_paste( + image, + original_image, + global_points, + global_point_label, + image_b, + evt: gr.SelectData, + dx, + dy, + resize_scale + +): + if original_image is None: + original_image = image + else: + image = original_image + x, y = evt.index[0], evt.index[1] + if len(global_points) == 0: + global_points.append([x, y]) + global_point_label.append(2) + image_with_point= show_point_or_box(image.copy(), global_points) + return image_with_point, original_image, None, global_points, global_point_label, None + elif len(global_points) == 1: + global_points.append([x, y]) + global_point_label.append(3) + x1, y1 = global_points[0] + x2, y2 = global_points[1] + if x1 < x2 and y1 >= y2: + global_points[0][0] = x1 + global_points[0][1] = y2 + global_points[1][0] = x2 + global_points[1][1] = y1 + elif x1 >= x2 and y1 < y2: + global_points[0][0] = x2 + global_points[0][1] = y1 + global_points[1][0] = x1 + global_points[1][1] = y2 + elif x1 >= x2 and y1 >= y2: + global_points[0][0] = x2 + global_points[0][1] = y2 + global_points[1][0] = x1 + global_points[1][1] = y1 + image_with_point = show_point_or_box(image.copy(), global_points) + # data process + input_point = np.array(global_points) + input_label = np.array(global_point_label) + pts_sampled = torch.reshape(torch.tensor(input_point), [1, 1, -1, 2]) + pts_labels = torch.reshape(torch.tensor(input_label), [1, 1, -1]) + img_tensor = transforms.ToTensor()(image) + # sam + predicted_logits, predicted_iou = sam( + img_tensor[None, ...], + pts_sampled, + pts_labels, + ) + mask = torch.ge(predicted_logits[0, 0, 0, :, :], 0).float().cpu().detach().numpy() + mask_uint8 = (mask*255.).astype(np.uint8) + + return image_with_point, original_image, paste_with_mask_and_offset(image, image_b, mask_uint8, dx, dy, resize_scale), global_points, global_point_label, mask_uint8 + else: + global_points=[[x, y]] + global_point_label=[2] + image_with_point= show_point_or_box(image.copy(), global_points) + return image_with_point, original_image, None, global_points, global_point_label, None + +def paste_with_mask_and_offset(image_a, image_b, mask, x_offset=0, y_offset=0, delta=1): + try: + numpy_mask = np.array(mask) + y_coords, x_coords = np.nonzero(numpy_mask) + x_min = x_coords.min() + x_max = x_coords.max() + y_min = y_coords.min() + y_max = y_coords.max() + target_center_x = int((x_min + x_max) / 2) + target_center_y = int((y_min + y_max) / 2) + + image_a = Image.fromarray(image_a) + image_b = Image.fromarray(image_b) + mask = Image.fromarray(mask) + + if image_a.size != mask.size: + mask = mask.resize(image_a.size) + + cropped_image = Image.composite(image_a, Image.new('RGBA', image_a.size, (0, 0, 0, 0)), mask) + x_b = int(target_center_x * (image_b.width / cropped_image.width)) + y_b = int(target_center_y * (image_b.height / cropped_image.height)) + x_offset = x_offset - int((delta - 1) * x_b) + y_offset = y_offset - int((delta - 1) * y_b) + cropped_image = 
cropped_image.resize(image_b.size) + new_size = (int(cropped_image.width * delta), int(cropped_image.height * delta)) + cropped_image = cropped_image.resize(new_size) + image_b.putalpha(128) + result_image = Image.new('RGBA', image_b.size, (0, 0, 0, 0)) + result_image.paste(image_b, (0, 0)) + result_image.paste(cropped_image, (x_offset, y_offset), mask=cropped_image) + + return result_image + except: + return None + +def upload_image_move(img, original_image): + if original_image is not None: + return original_image + else: + return img + +def fun_clear(*args): + result = [] + for arg in args: + if isinstance(arg, list): + result.append([]) + else: + result.append(None) + return tuple(result) + +def clear_points(img): + image, mask = img["image"], np.float32(img["mask"][:, :, 0]) / 255. + if mask.sum() > 0: + mask = np.uint8(mask > 0) + masked_img = mask_image(image, 1 - mask, color=[0, 0, 0], alpha=0.3) + else: + masked_img = image.copy() + + return [], masked_img + +def get_point(img, sel_pix, evt: gr.SelectData): + sel_pix.append(evt.index) + points = [] + for idx, point in enumerate(sel_pix): + if idx % 2 == 0: + cv2.circle(img, tuple(point), 10, (0, 0, 255), -1) + else: + cv2.circle(img, tuple(point), 10, (255, 0, 0), -1) + points.append(tuple(point)) + if len(points) == 2: + cv2.arrowedLine(img, points[0], points[1], (255, 255, 255), 4, tipLength=0.5) + points = [] + return img if isinstance(img, np.ndarray) else np.array(img) + +def calculate_translation_percentage(ori_shape, selected_points): + dx = selected_points[1][0] - selected_points[0][0] + dy = selected_points[1][1] - selected_points[0][1] + dx_percentage = dx / ori_shape[1] + dy_percentage = dy / ori_shape[0] + + return dx_percentage, dy_percentage + +def get_point_move(original_image, img, sel_pix, evt: gr.SelectData): + if original_image is not None: + img = original_image.copy() + else: + original_image = img.copy() + if len(sel_pix)<2: + sel_pix.append(evt.index) + else: + sel_pix = [evt.index] + points = [] + dx, dy = 0, 0 + for idx, point in enumerate(sel_pix): + if idx % 2 == 0: + cv2.circle(img, tuple(point), 10, (0, 0, 255), -1) + else: + cv2.circle(img, tuple(point), 10, (255, 0, 0), -1) + points.append(tuple(point)) + if len(points) == 2: + cv2.arrowedLine(img, points[0], points[1], (255, 255, 255), 4, tipLength=0.5) + ori_shape = original_image.shape + dx, dy = calculate_translation_percentage(original_image.shape, sel_pix) + points = [] + img = np.array(img) + + return img, original_image, sel_pix, dx, dy + +def store_img(img): + image, mask = img["image"], np.float32(img["mask"][:, :, 0]) / 255. + if mask.sum() > 0: + mask = np.uint8(mask > 0) + masked_img = mask_image(image, 1 - mask, color=[0, 0, 0], alpha=0.3) + else: + masked_img = image.copy() + + return image, masked_img, mask +# im["background"], im["layers"][0] +def store_img_move(img, mask=None): + if mask is not None: + image = img["background"] + return image, None, mask + image, mask = img["background"], np.float32(["layers"][0][:, :, 0]) / 255. + if mask.sum() > 0: + mask = np.uint8(mask > 0) + masked_img = mask_image(image, 1 - mask, color=[0, 0, 0], alpha=0.3) + else: + masked_img = image.copy() + + return image, masked_img, (mask*255.).astype(np.uint8) + +def store_img_move_old(img, mask=None): + if mask is not None: + image = img["image"] + return image, None, mask + image, mask = img["image"], np.float32(img["mask"][:, :, 0]) / 255. 
+ if mask.sum() > 0: + mask = np.uint8(mask > 0) + masked_img = mask_image(image, 1 - mask, color=[0, 0, 0], alpha=0.3) + else: + masked_img = image.copy() + + return image, masked_img, (mask*255.).astype(np.uint8) + +def mask_image(image, mask, color=[255,0,0], alpha=0.5, max_resolution=None): + """ Overlay mask on image for visualization purpose. + Args: + image (H, W, 3) or (H, W): input image + mask (H, W): mask to be overlaid + color: the color of overlaid mask + alpha: the transparency of the mask + """ + if max_resolution is not None: + image, _ = resize_numpy_image(image, max_resolution*max_resolution) + mask = cv2.resize(mask, (image.shape[1], image.shape[0]),interpolation=cv2.INTER_NEAREST) + + out = deepcopy(image) + img = deepcopy(image) + img[mask == 1] = color + out = cv2.addWeighted(img, alpha, out, 1-alpha, 0, out) + contours = cv2.findContours(np.uint8(deepcopy(mask)), cv2.RETR_TREE, + cv2.CHAIN_APPROX_SIMPLE)[-2:] + return out \ No newline at end of file diff --git a/src/utils/utils.py b/src/utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b127a81c0f3f87a58bbab638cd032da25cdb5d5c --- /dev/null +++ b/src/utils/utils.py @@ -0,0 +1,240 @@ +import numpy as np +import cv2 +from basicsr.utils import img2tensor +import torch +import torch.nn.functional as F + +def resize_numpy_image(image, max_resolution=768 * 768, resize_short_edge=None): + h, w = image.shape[:2] + w_org = image.shape[1] + if resize_short_edge is not None: + k = resize_short_edge / min(h, w) + else: + k = max_resolution / (h * w) + k = k**0.5 + h = int(np.round(h * k / 64)) * 64 + w = int(np.round(w * k / 64)) * 64 + image = cv2.resize(image, (w, h), interpolation=cv2.INTER_LANCZOS4) + scale = w/w_org + return image, scale + +def split_ldm(ldm): + x = [] + y = [] + for p in ldm: + x.append(p[0]) + y.append(p[1]) + return x,y + +def process_move(path_mask, h, w, dx, dy, scale, input_scale, resize_scale, up_scale, up_ft_index, w_edit, w_content, w_contrast, w_inpaint, precision, path_mask_ref=None): + dx, dy = dx*input_scale, dy*input_scale + if isinstance(path_mask, str): + mask_x0 = cv2.imread(path_mask) + else: + mask_x0 = path_mask + mask_x0 = cv2.resize(mask_x0, (h, w)) + if path_mask_ref is not None: + if isinstance(path_mask_ref, str): + mask_x0_ref = cv2.imread(path_mask_ref) + else: + mask_x0_ref = path_mask_ref + mask_x0_ref = cv2.resize(mask_x0_ref, (h, w)) + else: + mask_x0_ref=None + + mask_x0 = img2tensor(mask_x0)[0] + mask_x0 = (mask_x0>0.5).float().to('cuda', dtype=precision) + if mask_x0_ref is not None: + mask_x0_ref = img2tensor(mask_x0_ref)[0] + mask_x0_ref = (mask_x0_ref>0.5).float().to('cuda', dtype=precision) + mask_org = F.interpolate(mask_x0[None,None], (int(mask_x0.shape[-2]//scale), int(mask_x0.shape[-1]//scale)))>0.5 + + mask_tar = F.interpolate(mask_x0[None,None], (int(mask_x0.shape[-2]//scale*resize_scale), int(mask_x0.shape[-1]//scale*resize_scale)))>0.5 + mask_cur = torch.roll(mask_tar, (int(dy//scale*resize_scale), int(dx//scale*resize_scale)), (-2,-1)) + + pad_size_x = abs(mask_tar.shape[-1]-mask_org.shape[-1])//2 + pad_size_y = abs(mask_tar.shape[-2]-mask_org.shape[-2])//2 + if resize_scale>1: + sum_before = torch.sum(mask_cur) + mask_cur = mask_cur[:,:,pad_size_y:pad_size_y+mask_org.shape[-2],pad_size_x:pad_size_x+mask_org.shape[-1]] + sum_after = torch.sum(mask_cur) + if sum_after != sum_before: + raise ValueError('Resize out of bounds, exiting.') + else: + temp = torch.zeros(1,1,mask_org.shape[-2], mask_org.shape[-1]).to(mask_org.device) + 
temp[:,:,pad_size_y:pad_size_y+mask_cur.shape[-2],pad_size_x:pad_size_x+mask_cur.shape[-1]]=mask_cur + mask_cur =temp>0.5 + + mask_other = (1-((mask_cur+mask_org)>0.5).float())>0.5 + mask_overlap = ((mask_cur.float()+mask_org.float())>1.5).float() + mask_non_overlap = (mask_org.float()-mask_overlap)>0.5 + + return { + "mask_x0":mask_x0, + "mask_x0_ref":mask_x0_ref, + "mask_tar":mask_tar, + "mask_cur":mask_cur, + "mask_other":mask_other, + "mask_overlap":mask_overlap, + "mask_non_overlap":mask_non_overlap, + "up_scale":up_scale, + "up_ft_index":up_ft_index, + "resize_scale":resize_scale, + "w_edit":w_edit, + "w_content":w_content, + "w_contrast":w_contrast, + "w_inpaint":w_inpaint, + } + +def process_drag_face(h, w, x, y, x_cur, y_cur, scale, input_scale, up_scale, up_ft_index, w_edit, w_inpaint, precision): + for i in range(len(x)): + x[i] = int(x[i]*input_scale) + y[i] = int(y[i]*input_scale) + x_cur[i] = int(x_cur[i]*input_scale) + y_cur[i] = int(y_cur[i]*input_scale) + + mask_tar = [] + for p_idx in range(len(x)): + mask_i = torch.zeros(int(h//scale), int(w//scale)).cuda() + y_clip = int(np.clip(y[p_idx]//scale, 1, mask_i.shape[0]-2)) + x_clip = int(np.clip(x[p_idx]//scale, 1, mask_i.shape[1]-2)) + mask_i[y_clip-1:y_clip+2,x_clip-1:x_clip+2]=1 + mask_i = mask_i>0.5 + mask_tar.append(mask_i) + mask_cur = [] + for p_idx in range(len(x_cur)): + mask_i = torch.zeros(int(h//scale), int(w//scale)).cuda() + y_clip = int(np.clip(y_cur[p_idx]//scale, 1, mask_i.shape[0]-2)) + x_clip = int(np.clip(x_cur[p_idx]//scale, 1, mask_i.shape[1]-2)) + mask_i[y_clip-1:y_clip+2,x_clip-1:x_clip+2]=1 + mask_i=mask_i>0.5 + mask_cur.append(mask_i) + + return { + "mask_tar":mask_tar, + "mask_cur":mask_cur, + "up_scale":up_scale, + "up_ft_index":up_ft_index, + "w_edit": w_edit, + "w_inpaint": w_inpaint, + } + +def process_drag(path_mask, h, w, x, y, x_cur, y_cur, scale, input_scale, up_scale, up_ft_index, w_edit, w_inpaint, w_content, precision, latent_in): + if isinstance(path_mask, str): + mask_x0 = cv2.imread(path_mask) + else: + mask_x0 = path_mask + mask_x0 = cv2.resize(mask_x0, (h, w)) + mask_x0 = img2tensor(mask_x0)[0] + dict_mask = {} + dict_mask['base'] = mask_x0 + mask_x0 = (mask_x0>0.5).float().to('cuda', dtype=precision) + + mask_other = F.interpolate(mask_x0[None,None], (int(mask_x0.shape[-2]//scale), int(mask_x0.shape[-1]//scale)))<0.5 + mask_tar = [] + mask_cur = [] + for p_idx in range(len(x)): + mask_tar_i = torch.zeros(int(mask_x0.shape[-2]//scale), int(mask_x0.shape[-1]//scale)).to('cuda', dtype=precision) + mask_cur_i = torch.zeros(int(mask_x0.shape[-2]//scale), int(mask_x0.shape[-1]//scale)).to('cuda', dtype=precision) + y_tar_clip = int(np.clip(y[p_idx]//scale, 1, mask_tar_i.shape[0]-2)) + x_tar_clip = int(np.clip(x[p_idx]//scale, 1, mask_tar_i.shape[0]-2)) + y_cur_clip = int(np.clip(y_cur[p_idx]//scale, 1, mask_cur_i.shape[0]-2)) + x_cur_clip = int(np.clip(x_cur[p_idx]//scale, 1, mask_cur_i.shape[0]-2)) + mask_tar_i[y_tar_clip-1:y_tar_clip+2,x_tar_clip-1:x_tar_clip+2]=1 + mask_cur_i[y_cur_clip-1:y_cur_clip+2,x_cur_clip-1:x_cur_clip+2]=1 + mask_tar_i = mask_tar_i>0.5 + mask_cur_i=mask_cur_i>0.5 + mask_tar.append(mask_tar_i) + mask_cur.append(mask_cur_i) + latent_in[:,:,y_cur_clip//up_scale-1:y_cur_clip//up_scale+2, x_cur_clip//up_scale-1:x_cur_clip//up_scale+2] = latent_in[:,:, y_tar_clip//up_scale-1:y_tar_clip//up_scale+2, x_tar_clip//up_scale-1:x_tar_clip//up_scale+2] + + + return { + "dict_mask":dict_mask, + "mask_x0":mask_x0, + "mask_tar":mask_tar, + "mask_cur":mask_cur, + 
"mask_other":mask_other, + "up_scale":up_scale, + "up_ft_index":up_ft_index, + "w_edit": w_edit, + "w_inpaint": w_inpaint, + "w_content": w_content, + "latent_in":latent_in, + } + +def process_appearance(path_mask, path_mask_replace, h, w, scale, input_scale, up_scale, up_ft_index, w_edit, w_content, precision): + if isinstance(path_mask, str): + mask_base = cv2.imread(path_mask) + else: + mask_base = path_mask + mask_base = cv2.resize(mask_base, (h, w)) + if isinstance(path_mask_replace, str): + mask_replace = cv2.imread(path_mask_replace) + else: + mask_replace = path_mask_replace + mask_replace = cv2.resize(mask_replace, (h, w)) + + dict_mask = {} + mask_base = img2tensor(mask_base)[0] + dict_mask['base'] = mask_base + mask_base = (mask_base>0.5).to('cuda', dtype=precision) + mask_replace = img2tensor(mask_replace)[0] + dict_mask['replace'] = mask_replace + mask_replace = (mask_replace>0.5).to('cuda', dtype=precision) + + mask_base_cur = F.interpolate(mask_base[None,None], (int(mask_base.shape[-2]//scale), int(mask_base.shape[-1]//scale)))>0.5 + mask_replace_cur = F.interpolate(mask_replace[None,None], (int(mask_replace.shape[-2]//scale), int(mask_replace.shape[-1]//scale)))>0.5 + + return { + "dict_mask":dict_mask, + "mask_base_cur":mask_base_cur, + "mask_replace_cur":mask_replace_cur, + "up_scale":up_scale, + "up_ft_index":up_ft_index, + "w_edit":w_edit, + "w_content":w_content, + } + +def process_paste(path_mask, h, w, dx, dy, scale, input_scale, up_scale, up_ft_index, w_edit, w_content, precision, resize_scale=None): + dx, dy = dx*input_scale, dy*input_scale + if isinstance(path_mask, str): + mask_base = cv2.imread(path_mask) + else: + mask_base = path_mask + mask_base = cv2.resize(mask_base, (h, w)) + + dict_mask = {} + mask_base = img2tensor(mask_base)[0][None, None] + mask_base = (mask_base>0.5).to('cuda', dtype=precision) + if resize_scale is not None and resize_scale!=1: + hi, wi = mask_base.shape[-2], mask_base.shape[-1] + mask_base = F.interpolate(mask_base, (int(hi*resize_scale), int(wi*resize_scale))) + pad_size_x = np.abs(mask_base.shape[-1]-wi)//2 + pad_size_y = np.abs(mask_base.shape[-2]-hi)//2 + if resize_scale>1: + mask_base = mask_base[:,:,pad_size_y:pad_size_y+hi,pad_size_x:pad_size_x+wi] + else: + temp = torch.zeros(1,1,hi, wi).to(mask_base.device) + temp[:,:,pad_size_y:pad_size_y+mask_base.shape[-2],pad_size_x:pad_size_x+mask_base.shape[-1]]=mask_base + mask_base = temp + mask_replace = mask_base.clone() + mask_base = torch.roll(mask_base, (int(dy), int(dx)), (-2,-1)) + dict_mask['base'] = mask_base[0,0] + dict_mask['replace'] = mask_replace[0,0] + mask_replace = (mask_replace>0.5).to('cuda', dtype=precision) + + mask_base_cur = F.interpolate(mask_base, (int(mask_base.shape[-2]//scale), int(mask_base.shape[-1]//scale)))>0.5 + mask_replace_cur = torch.roll(mask_base_cur, (-int(dy/scale), -int(dx/scale)), (-2,-1)) + + return { + "dict_mask":dict_mask, + "mask_base_cur":mask_base_cur, + "mask_replace_cur":mask_replace_cur, + "up_scale":up_scale, + "up_ft_index":up_ft_index, + "w_edit":w_edit, + "w_content":w_content, + "w_edit":w_edit, + "w_content":w_content, + } \ No newline at end of file diff --git a/utils/inversion.py b/utils/inversion.py new file mode 100755 index 0000000000000000000000000000000000000000..16ce1796d3a905a1596300b39cdfa940d49d0e15 --- /dev/null +++ b/utils/inversion.py @@ -0,0 +1,265 @@ +import torch +import numpy as np +from PIL import Image +from typing import Optional, Union, Tuple, List +from tqdm import tqdm +import os +from diffusers import 
DDIMInverseScheduler,DPMSolverMultistepInverseScheduler +import spaces + +class Inversion: + + def next_step(self, model_output: Union[torch.FloatTensor, np.ndarray], timestep: int, + sample: Union[torch.FloatTensor, np.ndarray]): + timestep, next_timestep = min( + timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps, 999), timestep + alpha_prod_t = self.scheduler.alphas_cumprod[timestep] if timestep >= 0 else self.scheduler.final_alpha_cumprod + alpha_prod_t_next = self.scheduler.alphas_cumprod[next_timestep] + beta_prod_t = 1 - alpha_prod_t + next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5 + next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output + next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction + return next_sample + + @torch.no_grad() + def get_noise_pred_single(self, latents, t, context,cond=True,both=False): + added_cond_id=1 if cond else 0 + do_classifier_free_guidance=False + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + if both is False: + added_cond_kwargs = {"text_embeds": self.add_text_embeds[added_cond_id].unsqueeze(0).repeat(self.inv_batch_size,1), "time_ids": self.add_time_ids[added_cond_id].unsqueeze(0).repeat(self.inv_batch_size,1)} + else: + added_cond_kwargs = {"text_embeds": self.add_text_embeds, "time_ids": self.add_time_ids} + noise_pred = self.model.unet( + latent_model_input, + t, + encoder_hidden_states=context, + cross_attention_kwargs=None, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + return noise_pred + + @torch.no_grad() + def latent2image(self, latents, return_type='np'): + latents = 1 / self.model.vae.config.scaling_factor * latents.detach() + self.model.vae.to(dtype=torch.float32) + image = self.model.vae.decode(latents)['sample'] + if return_type == 'np': + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy() + image = (image * 255).astype(np.uint8) + return image + + @torch.no_grad() + @spaces.GPU + def image2latent(self, image): + with torch.no_grad(): + if type(image) is Image: + image = np.array(image) + else: + if image.ndim==3: + image=np.expand_dims(image,0) + image = torch.from_numpy(image).float() / 127.5 - 1 + image = image.permute(0, 3, 1, 2).to(self.device) + print(f"Running on device: {self.device}") + latents=[] + for i,_ in enumerate(image): + latent=self.model.vae.encode(image[i:i+1])['latent_dist'].mean + latents.append(latent) + latents=torch.stack(latents).squeeze(1) + latents = latents * self.model.vae.config.scaling_factor + return latents + + @torch.no_grad() + def init_prompt( + self, + prompt: Union[str, List[str]], + original_size: Optional[Tuple[int, int]] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Optional[Tuple[int, int]] = None, + ): + original_size = original_size or (1024, 1024) + target_size = target_size or (1024, 1024) + # 3. 
Encode input prompt + do_classifier_free_guidance=True + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.model.encode_prompt_not_zero_uncond( + prompt, + self.model.device, + 1, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + lora_scale=None, + ) + prompt_embeds=prompt_embeds[:self.inv_batch_size] + negative_prompt_embeds=negative_prompt_embeds[:self.inv_batch_size] + pooled_prompt_embeds=pooled_prompt_embeds[:self.inv_batch_size] + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds[:self.inv_batch_size] + # 7. Prepare added time ids & embeddings + add_text_embeds = pooled_prompt_embeds + add_time_ids = self.model._get_add_time_ids( + original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype + ) + + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(self.device) + self.add_text_embeds = add_text_embeds.to(self.device) + self.add_time_ids = add_time_ids.to(self.device).repeat(self.inv_batch_size * 1, 1) + + self.prompt_embeds=prompt_embeds + self.negative_prompt_embeds=negative_prompt_embeds + self.pooled_prompt_embeds=pooled_prompt_embeds + self.negative_pooled_prompt_embeds=negative_pooled_prompt_embeds + self.prompt = prompt + self.context=prompt_embeds + + @torch.no_grad() + @spaces.GPU + def ddim_loop(self, latent): + uncond_embeddings, cond_embeddings = self.context.chunk(2) + all_latent = [latent] + latent = latent.clone().detach() + extra_step_kwargs = self.model.prepare_extra_step_kwargs(self.generator, self.eta) + if isinstance(self.inverse_scheduler,DDIMInverseScheduler): + extra_step_kwargs.pop("generator") + for i in tqdm(range(self.num_ddim_steps)): + use_inv_sc=False + if use_inv_sc: + t = self.inverse_scheduler.timesteps[i] + noise_pred = self.get_noise_pred_single(latent, t, cond_embeddings,cond=True) + latent = self.inverse_scheduler.step(noise_pred, t, latent, **extra_step_kwargs, return_dict=False)[0] + else: + t = self.model.scheduler.timesteps[len(self.model.scheduler.timesteps) - i - 1] + noise_pred = self.get_noise_pred_single(latent, t, cond_embeddings,cond=True) + latent = self.next_step(noise_pred, t, latent) + all_latent.append(latent) + return all_latent + + @property + def scheduler(self): + return self.model.scheduler + + @torch.no_grad() + @spaces.GPU + def ddim_inversion(self, image): + latent = self.image2latent(image) + image_rec = self.latent2image(latent) + ddim_latents = self.ddim_loop(latent.to(self.model.unet.dtype)) + return image_rec, ddim_latents + + from typing import Union, List, Dict + import numpy as np + + @spaces.GPU + def invert(self, image_gt, prompt: Union[str, List[str]], + verbose=True, inv_output_pos=None, inv_batch_size=1): + + self.inv_batch_size = inv_batch_size + self.init_prompt(prompt) + out_put_pos = 0 if inv_output_pos is None else inv_output_pos + self.out_put_pos = out_put_pos + if verbose: + print("DDIM inversion...") + image_rec, ddim_latents = self.ddim_inversion(image_gt) + if verbose: + print("Done.") + return (image_gt, image_rec), ddim_latents[-1], ddim_latents, self.prompt_embeds[self.prompt_embeds.shape[0]//2:], self.pooled_prompt_embeds + + def __init__(self, 
model,num_ddim_steps,generator=None,scheduler_type="DDIM"): + self.model = model + self.tokenizer = self.model.tokenizer + self.num_ddim_steps=num_ddim_steps + if scheduler_type == "DDIM": + self.inverse_scheduler=DDIMInverseScheduler.from_config(self.model.scheduler.config) + self.inverse_scheduler.set_timesteps(num_ddim_steps) + elif scheduler_type=="DPMSolver": + self.inverse_scheduler=DPMSolverMultistepInverseScheduler.from_config(self.model.scheduler.config) + self.inverse_scheduler.set_timesteps(num_ddim_steps) + self.model.scheduler.set_timesteps(num_ddim_steps) + self.model.vae.to(dtype=torch.float32) + self.prompt = None + self.context = None + # self.device=self.model.unet.device + self.device = torch.device("cuda:0") + self.generator=generator + self.eta=0.0 + +def load_512(image_path, left=0, right=0, top=0, bottom=0): + if type(image_path) is str: + image = np.array(Image.open(image_path))[:, :, :3] + else: + image = image_path + h, w, c = image.shape + left = min(left, w - 1) + right = min(right, w - left - 1) + top = min(top, h - left - 1) + bottom = min(bottom, h - top - 1) + image = image[top:h - bottom, left:w - right] + h, w, c = image.shape + if h < w: + offset = (w - h) // 2 + image = image[:, offset:offset + h] + elif w < h: + offset = (h - w) // 2 + image = image[offset:offset + w] + image = np.array(Image.fromarray(image).resize((512, 512))) + return image + +def load_1024_mask(image_path, left=0, right=0, top=0, bottom=0,target_H=128,target_W=128): + if type(image_path) is str: + image = np.array(Image.open(image_path))[:, :, np.newaxis] + else: + image = image_path + if len(image.shape) == 4: + image = image[:, :, :, 0] + h, w, c = image.shape + left = min(left, w - 1) + right = min(right, w - left - 1) + top = min(top, h - left - 1) + bottom = min(bottom, h - top - 1) + image = image[top:h - bottom, left:w - right] + h, w, c = image.shape + if h < w: + offset = (w - h) // 2 + image = image[:, offset:offset + h] + elif w < h: + offset = (h - w) // 2 + image = image[offset:offset + w] + image=image.squeeze() + image = np.array(Image.fromarray(image).resize((target_H, target_W))) + return image + +def load_1024(image_path, left=0, right=0, top=0, bottom=0): + if type(image_path) is str: + image = np.array(Image.open(image_path).resize((1024, 1024)))[:, :, :3] + else: + image = image_path + h, w, c = image.shape + left = min(left, w - 1) + right = min(right, w - left - 1) + top = min(top, h - left - 1) + bottom = min(bottom, h - top - 1) + image = image[top:h - bottom, left:w - right] + h, w, c = image.shape + if h < w: + offset = (w - h) // 2 + image = image[:, offset:offset + h] + elif w < h: + offset = (h - w) // 2 + image = image[offset:offset + w] + image = np.array(Image.fromarray(image).resize((1024, 1024))) + return image \ No newline at end of file diff --git a/utils/sdxl.py b/utils/sdxl.py new file mode 100755 index 0000000000000000000000000000000000000000..30e3cee65b303fa704785adb6394a39eebb523a4 --- /dev/null +++ b/utils/sdxl.py @@ -0,0 +1,986 @@ + +from typing import Any, Callable, Dict, List, Optional, Tuple, Union +# import seaborn as sns +import matplotlib.pyplot as plt +import torch +from diffusers import StableDiffusionXLPipeline +from typing import Optional, Union, Tuple, List, Callable, Dict +import numpy as np +import copy +import torch.nn.functional as F +from diffusers.loaders import LoraLoaderMixin, TextualInversionLoaderMixin +from diffusers.models.attention_processor import ( AttnProcessor2_0, LoRAAttnProcessor2_0, LoRAXFormersAttnProcessor, 
XFormersAttnProcessor, ) +from diffusers.utils import ( logging, randn_tensor, replace_example_docstring, ) +from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput +from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg +import os +logger = logging.get_logger(__name__) + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import StableDiffusionXLPipeline + + >>> pipe = StableDiffusionXLPipeline.from_pretrained( + ... "stabilityai/stable-diffusion-xl-base-0.9", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> prompt = "a photo of an astronaut riding a horse on mars" + >>> image = pipe(prompt).images[0] + ``` +""" + + +class sdxl(StableDiffusionXLPipeline): + @replace_example_docstring(EXAMPLE_DOC_STRING) + @torch.no_grad() + def __call__( + self, + controller=None, + prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + original_size: Optional[Tuple[int, int]] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Optional[Tuple[int, int]] = None, + same_init=False, + x_stars=None, + prox_guidance=True, + masa_control=False, + masa_mask=False, + masa_start_step=40, + masa_start_layer=55, + mask_file=None, + query_mask_time=[0, 10], + **kwargs + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. 
Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + guidance_rescale (`float`, *optional*, defaults to 0.7): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). + Guidance rescale factor should fix overexposure when using zero terminal SNR. 
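+            same_init (`bool`, *optional*, defaults to `False`):
+                If `True`, a single initial noise latent is shared across the batch (see `prepare_latents`).
+            x_stars (`List[torch.FloatTensor]`, *optional*):
+                Latents recorded during DDIM inversion, used as reconstruction targets by `proximal_guidance`.
+            prox_guidance (`bool`, *optional*, defaults to `True`):
+                Whether to apply the proximal regularization of the guidance delta and the masked latent
+                reconstruction step (`prox_regularization` / `proximal_guidance`).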
+ + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a + `tuple. When returning a tuple, the first element is a list with the generated images, and the second + element is a list of `bool`s denoting whether the corresponding generated image likely represents + "not-safe-for-work" (nsfw) content, according to the `safety_checker`. + """ + # 0. Default height and width to unet + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + inv_batch_size = len(latents) if latents is not None else 1 + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) + + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + lora_scale=text_encoder_lora_scale, + sample_ref_match=kwargs['sample_ref_match'] if 'sample_ref_match' in kwargs else None, + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + same_init=same_init, #ADD + sample_ref_match=kwargs['sample_ref_match'] if 'sample_ref_match' in kwargs else None, + ) + # 6. Prepare extra step kwargs. + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. 
Prepare added time ids & embeddings + add_text_embeds = pooled_prompt_embeds + add_time_ids = self._get_add_time_ids( + original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype + ) + + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + # CHANGE START + score_delta,mask_edit=self.prox_regularization( + noise_pred_uncond, + noise_pred_text, + i, + t, + prox_guidance=prox_guidance, + ) + if mask_edit is not None: + a = 1 + noise_pred = noise_pred_uncond + guidance_scale * score_delta + # CHANGE END + + if do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. 
in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # ADD START + latents = self.proximal_guidance( + i, + t, + latents, + mask_edit, + prox_guidance=prox_guidance, + dtype=self.unet.dtype, + x_stars=x_stars, + controller=controller, + sample_ref_match=kwargs['sample_ref_match'] if 'sample_ref_match' in kwargs else None, + inv_batch_size=inv_batch_size, + only_inversion_align=kwargs['only_inversion_align'] if 'only_inversion_align' in kwargs else False, + ) + # ADD END + if controller is not None: + latents = controller.step_callback(latents) + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # make sure the VAE is in float32 mode, as it overflows in float16 + self.vae.to(dtype=torch.float32) + + use_torch_2_0_or_xformers = isinstance( + self.vae.decoder.mid_block.attentions[0].processor, + ( + AttnProcessor2_0, + XFormersAttnProcessor, + LoRAXFormersAttnProcessor, + LoRAAttnProcessor2_0, + ), + ) + # if xformers or torch_2_0 is used attention block does not need + # to be in float32 which can save lots of memory + if use_torch_2_0_or_xformers: + self.vae.post_quant_conv.to(latents.dtype) + self.vae.decoder.conv_in.to(latents.dtype) + self.vae.decoder.mid_block.to(latents.dtype) + else: + latents = latents.float() + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + else: + image = latents + return StableDiffusionXLPipelineOutput(images=image) + + image = self.watermark.apply_watermark(image) + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None,same_init=False,sample_ref_match=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + if sample_ref_match is not None: + new_latents=randn_tensor((batch_size,*shape[1:]), generator=generator, device=device, dtype=dtype) + for key,value in sample_ref_match.items(): + new_latents[key]=latents[value].clone() + latents=new_latents + else: + if same_init is True: + if latents is None: + latents = randn_tensor((1,*shape[1:]), generator=generator, device=device, dtype=dtype).expand(shape).to(device) + else: + if batch_size>1 and latents.shape[0]==1: + latents=latents.expand(shape).to(device) + else: + latents = latents.to(device) + else: + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def encode_prompt( + self, + prompt, + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + sample_ref_match=None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. 
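+            sample_ref_match (`Dict[int, int]`, *optional*):
+                Maps an output batch index to the reference batch index whose negative prompt embeddings
+                are reused for it; `prepare_latents` applies the same mapping to the initial latents.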
+ """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + text_encoders = ( + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + ) + + if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + prompt_embeds_list = [] + for tokenizer, text_encoder in zip(tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, tokenizer) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder( + text_input_ids.to(device), + output_hidden_states=True, + ) + + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + prompt_embeds = prompt_embeds.hidden_states[-2] + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + uncond_tokens: List[str] + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." 
+ ) + else: + uncond_tokens = negative_prompt + + + negative_prompt_embeds_list = [] + for tokenizer, text_encoder in zip(tokenizers, text_encoders): + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + output_hidden_states=True, + ) + # We are only ALWAYS interested in the pooled output of the final text encoder + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view( + batch_size * num_images_per_prompt, seq_len, -1 + ) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + bs_embed = pooled_prompt_embeds.shape[0] + # ADD START + if sample_ref_match is not None: + new_negative_prompt_embeds=torch.zeros_like(prompt_embeds) + new_negative_pooled_prompt_embeds=torch.zeros_like(pooled_prompt_embeds) + for key,value in sample_ref_match.items(): + new_negative_prompt_embeds[key]=negative_prompt_embeds[value].clone() + new_negative_pooled_prompt_embeds[key]=negative_pooled_prompt_embeds[value].clone() + negative_prompt_embeds=new_negative_prompt_embeds + negative_pooled_prompt_embeds=new_negative_pooled_prompt_embeds + else: + if negative_pooled_prompt_embeds.shape[0]==1 and bs_embed!=1: + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds.repeat(bs_embed,1) + if negative_prompt_embeds.shape[0]==1 and bs_embed!=1: + negative_prompt_embeds=negative_prompt_embeds.repeat(bs_embed,1,1) + # ADD END + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + def encode_prompt_not_zero_uncond( + self, + prompt, + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. 
+ + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + text_encoders = ( + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + ) + + if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + prompt_embeds_list = [] + for tokenizer, text_encoder in zip(tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, tokenizer) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder(text_input_ids.to(device),output_hidden_states=True) + + # We are only ALWAYS 
interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + prompt_embeds = prompt_embeds.hidden_states[-2] + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + uncond_tokens: List[str] + if prompt is not None and isinstance(prompt,List) and negative_prompt == "": + negative_prompt = ["" for i in range(len(prompt))] + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + negative_prompt_embeds_list = [] + for tokenizer, text_encoder in zip(tokenizers, text_encoders): + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + output_hidden_states=True, + ) + # We are only ALWAYS interested in the pooled output of the final text encoder + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view( + batch_size * num_images_per_prompt, seq_len, -1 + ) + + # For classifier free guidance, we need to do two forward passes. 
+ # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + bs_embed = pooled_prompt_embeds.shape[0] + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + def prox_regularization( + self, + noise_pred_uncond, + noise_pred_text, + i, + t, + prox_guidance=False, + prox=None, + quantile=0.75, + recon_t=400, + dilate_radius=2, + ): + if prox_guidance is True: + mask_edit = None + if prox == 'l1': + score_delta = (noise_pred_text - noise_pred_uncond).float() + if quantile > 0: + threshold = score_delta.abs().quantile(quantile) + else: + threshold = -quantile # if quantile is negative, use it as a fixed threshold + score_delta -= score_delta.clamp(-threshold, threshold) + score_delta = torch.where(score_delta > 0, score_delta-threshold, score_delta) + score_delta = torch.where(score_delta < 0, score_delta+threshold, score_delta) + if (recon_t > 0 and t < recon_t) or (recon_t < 0 and t > -recon_t): + mask_edit = (score_delta.abs() > threshold).float() + if dilate_radius > 0: + radius = int(dilate_radius) + mask_edit = dilate(mask_edit.float(), kernel_size=2*radius+1, padding=radius) + elif prox == 'l0': + score_delta = (noise_pred_text - noise_pred_uncond).float() + if quantile > 0: + threshold = score_delta.abs().quantile(quantile) + else: + threshold = -quantile # if quantile is negative, use it as a fixed threshold + score_delta -= score_delta.clamp(-threshold, threshold) + if (recon_t > 0 and t < recon_t) or (recon_t < 0 and t > -recon_t): + mask_edit = (score_delta.abs() > threshold).float() + if dilate_radius > 0: + radius = int(dilate_radius) + mask_edit = dilate(mask_edit.float(), kernel_size=2*radius+1, padding=radius) + elif prox==None: + score_delta = (noise_pred_text - noise_pred_uncond).float() + if quantile > 0: + threshold = score_delta.abs().quantile(quantile) + else: + threshold = -quantile # if quantile is negative, use it as a fixed threshold + if (recon_t > 0 and t < recon_t) or (recon_t < 0 and t > -recon_t): + mask_edit = (score_delta.abs() > threshold).float() + if dilate_radius > 0: + radius = int(dilate_radius) + mask_edit = dilate(mask_edit.float(), kernel_size=2*radius+1, padding=radius) + else: + raise NotImplementedError + return score_delta,mask_edit + else: + return noise_pred_text - noise_pred_uncond,None + + def proximal_guidance( + self, + i, + t, + latents, + mask_edit, + dtype, + prox_guidance=False, + recon_t=400, + recon_end=0, + recon_lr=0.1, + x_stars=None, + controller=None, + sample_ref_match=None, + inv_batch_size=1, + only_inversion_align=False, + ): + if mask_edit is not None and prox_guidance and (recon_t > recon_end and t < recon_t) or (recon_t < -recon_end and t > -recon_t): + if controller.layer_fusion.remove_mask is not None: + fix_mask = copy.deepcopy(controller.layer_fusion.remove_mask) + mask_edit[1] = (mask_edit[1]+fix_mask).clamp(0,1) + if mask_edit.shape[0] > 2: + mask_edit[2].fill_(1) + recon_mask = 1 - mask_edit + target_latents=x_stars[len(x_stars)-i-2] + new_target_latents=torch.zeros_like(latents) + for key,value 
in sample_ref_match.items(): + new_target_latents[key]=target_latents[value].clone() + latents = latents - recon_lr * (latents - new_target_latents) * recon_mask + return latents.to(dtype) + +def slerp(val, low, high): + """ taken from https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/4 + """ + low_norm = low/torch.norm(low, dim=1, keepdim=True) + high_norm = high/torch.norm(high, dim=1, keepdim=True) + omega = torch.acos((low_norm*high_norm).sum(1)) + so = torch.sin(omega) + res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low + (torch.sin(val*omega)/so).unsqueeze(1) * high + return res + + +def slerp_tensor(val, low, high): + shape = low.shape + res = slerp(val, low.flatten(1), high.flatten(1)) + return res.reshape(shape) + + +def dilate(image, kernel_size, stride=1, padding=0): + """ + Perform dilation on a binary image using a square kernel. + """ + # Ensure the image is binary + assert image.max() <= 1 and image.min() >= 0 + + # Get the maximum value in each neighborhood + dilated_image = F.max_pool2d(image, kernel_size, stride, padding) + + return dilated_image + +def exec_classifier_free_guidance(model,latents,controller,t,guidance_scale, + do_classifier_free_guidance,noise_pred,guidance_rescale, + prox=None, quantile=0.75,image_enc=None, recon_lr=0.1, recon_t=400,recon_end_t=0, + inversion_guidance=False, reconstruction_guidance=False,x_stars=None, i=0, + use_localblend_mask=False, + save_heatmap=False,**kwargs): + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + #noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + if prox is None and inversion_guidance is True: + prox = 'l1' + step_kwargs = { + 'ref_image': None, + 'recon_lr': 0, + 'recon_mask': None, + } + mask_edit = None + if prox is not None: + if prox == 'l1': + score_delta = (noise_pred_text - noise_pred_uncond).float() + if quantile > 0: + threshold = score_delta.abs().quantile(quantile) + else: + threshold = -quantile # if quantile is negative, use it as a fixed threshold + score_delta -= score_delta.clamp(-threshold, threshold) + score_delta = torch.where(score_delta > 0, score_delta-threshold, score_delta) + score_delta = torch.where(score_delta < 0, score_delta+threshold, score_delta) + if (recon_t > 0 and t < recon_t) or (recon_t < 0 and t > -recon_t): + step_kwargs['ref_image'] = image_enc + step_kwargs['recon_lr'] = recon_lr + score_delta_norm=score_delta.abs() + score_delta_norm=(score_delta_norm - score_delta_norm.min ()) / (score_delta_norm.max () - score_delta_norm.min ()) + mask_edit = (score_delta.abs() > threshold).float() + if save_heatmap and i%10==0: + for kk in range(4): + sns.heatmap(mask_edit[1][kk].clone().cpu(), cmap='coolwarm') + plt.savefig(f'./vis/prox_inv/heatmap1_mask_{i}_{kk}.png') + plt.clf() + if kwargs.get('dilate_mask', 2) > 0: + radius = int(kwargs.get('dilate_mask', 2)) + mask_edit = dilate(mask_edit.float(), kernel_size=2*radius+1, padding=radius) + if save_heatmap and i%10==0: + for kk in range(4): + sns.heatmap(mask_edit[1][kk].clone().cpu(), cmap='coolwarm') + plt.savefig(f'./vis/prox_inv/heatmap1_mask_dilate_{i}_{kk}.png') + plt.clf() + step_kwargs['recon_mask'] = 1 - mask_edit + elif prox == 'l0': + score_delta = (noise_pred_text - noise_pred_uncond).float() + if quantile > 0: + threshold = score_delta.abs().quantile(quantile) + else: + threshold = -quantile # if quantile is negative, use it as a fixed threshold + score_delta -= 
score_delta.clamp(-threshold, threshold)
+                if (recon_t > 0 and t < recon_t) or (recon_t < 0 and t > -recon_t):
+                    step_kwargs['ref_image'] = image_enc
+                    step_kwargs['recon_lr'] = recon_lr
+                    mask_edit = (score_delta.abs() > threshold).float()
+                    if kwargs.get('dilate_mask', 2) > 0:
+                        radius = int(kwargs.get('dilate_mask', 2))
+                        mask_edit = dilate(mask_edit.float(), kernel_size=2*radius+1, padding=radius)
+                    step_kwargs['recon_mask'] = 1 - mask_edit
+            else:
+                raise NotImplementedError
+            noise_pred = (noise_pred_uncond + guidance_scale * score_delta).to(model.unet.dtype)
+        else:
+            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+    if do_classifier_free_guidance and guidance_rescale > 0.0:
+        # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+        noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
+    if reconstruction_guidance:
+        kwargs.update(step_kwargs)
+    latents = model.scheduler.step(noise_pred, t, latents, **kwargs, return_dict=False)[0]
+    # only run the masked latent reconstruction when an edit mask exists, inversion guidance is
+    # enabled, and t lies inside the reconstruction window
+    if mask_edit is not None and inversion_guidance and ((recon_t > recon_end_t and t < recon_t) or (recon_t < recon_end_t and t > -recon_t)):
+        if use_localblend_mask:
+            assert hasattr(controller, "layer_fusion")
+            if save_heatmap and i%10==0:
+                import seaborn as sns  # needed only for these debug heatmaps; the module-level seaborn import is commented out
+                sns.heatmap(controller.layer_fusion.mask[0][0].clone().cpu(), cmap='coolwarm')
+                plt.savefig(f'./vis/prox_inv/heatmap0_localblendmask_{i}.png')
+                plt.clf()
+                sns.heatmap(controller.layer_fusion.mask[1][0].clone().cpu(), cmap='coolwarm')
+                plt.savefig(f'./vis/prox_inv/heatmap1_localblendmask_{i}.png')
+                plt.clf()
+            layer_fusion_mask = controller.layer_fusion.mask.float()
+            layer_fusion_mask[0] = layer_fusion_mask[1]
+            recon_mask = 1 - layer_fusion_mask.expand_as(latents)
+        else:
+            recon_mask = 1 - mask_edit
+        target_latents = x_stars[len(x_stars)-i-2].expand_as(latents)
+        # if target_latents has four dimensions, keep only the first (batch) entry
+        if len(target_latents.shape)==4:
+            target_latents = target_latents[0]
+        latents = latents - recon_lr * (latents - target_latents) * recon_mask
+    # let the attention controller post-process the latents (e.g., layer fusion / local blending)
+    if controller is not None:
+        latents = controller.step_callback(latents)
+    return latents.to(model.unet.dtype)
diff --git a/utils/utils.py b/utils/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd71b6a15f9b9376b0ef161d40d00a78df790726
--- /dev/null
+++ b/utils/utils.py
@@ -0,0 +1,363 @@
+import cv2
+from matplotlib import pyplot as plt
+import numpy as np
+import torch
+from PIL import Image, ImageDraw, ImageFont
+from datetime import datetime
+import os
+from typing import List, Dict
+
+def convert_and_resize_mask(mask):
+    if mask.ndim == 3:
+        mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
+    resized_mask = cv2.resize(mask, (1024, 1024))
+    return resized_mask
+
+def add_masks_resized(masks):
+    final_mask = np.zeros((1024, 1024), dtype=np.uint8)
+    for mask in masks:
+        if mask is not None:
+            resized_mask = convert_and_resize_mask(mask)
+            resized_mask = resized_mask.astype(np.uint8)
+            final_mask = cv2.add(final_mask, resized_mask)
+    return final_mask
+
+def attend_mask(mask_file, attend_scale=10, save=False):
+    if isinstance(mask_file, str):
+        if mask_file == '':
+            return torch.zeros([1, 1, 128, 128], dtype=torch.float32).cuda()
+        else:
+            image_with_mask = cv2.imread(mask_file, cv2.IMREAD_GRAYSCALE)
+    elif len(mask_file.shape) == 3:  # convert RGB to gray
+        image_with_mask = cv2.cvtColor(mask_file, cv2.COLOR_BGR2GRAY)
+    else:
+        image_with_mask = mask_file
+
+    if attend_scale != 0:
+        kernel = np.ones((abs(attend_scale), abs(attend_scale)), np.uint8)
+        if attend_scale > 0:
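+            # attend_scale > 0 grows the mask by dilation with an |attend_scale| x |attend_scale|
+            # kernel; attend_scale < 0 shrinks it by erosion (handled in the else branch below)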
image_with_mask = cv2.dilate(image_with_mask, kernel, iterations=1) + else: + image_with_mask = cv2.erode(image_with_mask, kernel, iterations=1) + + if save and isinstance(mask_file, str): + new_mask_file_name = mask_file[:-4]+'_'+str(attend_scale)+'.jpg' + cv2.imwrite(new_mask_file_name, image_with_mask) + print("new_mask is saved in ", new_mask_file_name) + + dilated_image= cv2.resize(image_with_mask, (128, 128), interpolation=cv2.INTER_NEAREST) + dilated_image = torch.from_numpy(dilated_image).to(torch.float32).unsqueeze(0).unsqueeze(0).cuda() / 255 + return dilated_image + + +def panning(img_path=None, op_list=[['left', 0.2]], save=False, save_dir=None): + if isinstance(img_path, str): + img = cv2.imread(img_path) + else: + img = img_path + img_new = img.copy() + img_height, img_width, _ = img.shape + w_mask = 255 * np.ones((img_height, img_width), dtype=np.uint8) + h_mask = 255 * np.ones((img_height, img_width), dtype=np.uint8) + + for op in op_list: + scale = op[1] + if op[0] in ['right', 'left']: + K = int(scale*img_width) + elif op[0] in ['up', 'down']: + K = int(scale*img_height) + + if op[0] == 'right': + img_new[:, K:, :] = img[:, 0:img_width-K, :] + w_mask[:, K:] = 0 + elif op[0] == 'left': + img_new[:, 0:img_width-K, :] = img[:, K:, :] + w_mask[:, 0:img_width-K] = 0 + elif op[0] == 'down': + img_new[K:, :, :] = img[0:img_height-K, :, :] + h_mask[K:, :] = 0 + elif op[0] == 'up': + img_new[0:img_height-K, :, :] = img[K:, :, :] + h_mask[0:img_height-K, :] = 0 + img = img_new + + mask = w_mask + h_mask + mask[mask>0] = 255 + + if save: + if save_dir is None: + base_dir = os.path.dirname(img_path) + save_dir = os.path.join(base_dir, 'preprocess') + elif not os.path.exists(save_dir): + os.makedirs(save_dir) + resized_img_name = f"{save_dir}/resized_image.png" + resized_mask_name = f"{save_dir}/resized_mask.png" + cv2.imwrite(resized_img_name, img_new) + cv2.imwrite(resized_mask_name, mask) + return resized_img_name, resized_mask_name + else: + return img_new, mask + +def zooming(img_path=None, scale=[0.8, 0.8], save=False, save_dir=None): + if isinstance(img_path, str): + img = cv2.imread(img_path) + else: + img = img_path + img_new = img.copy() + img_height, img_width, _ = img.shape + mask = 255 * np.ones((img_height, img_width), dtype=np.uint8) + + new_height = int(img_height*scale[0]) + new_width = int(img_width*scale[1]) + resized_img = cv2.resize(img, (new_width, new_height), interpolation=cv2.INTER_AREA) + x_offset = (img_width - new_width) // 2 + y_offset = (img_height - new_height) // 2 + + img_new[y_offset:y_offset + new_height, x_offset:x_offset + new_width] = resized_img + mask[y_offset:y_offset + new_height, x_offset:x_offset + new_width] = 0 + + if save: + if save_dir is None: + base_dir = os.path.dirname(img_path) + save_dir = os.path.join(base_dir, 'preprocess') + elif not os.path.exists(save_dir): + os.makedirs(save_dir) + + resized_img_name = f"{save_dir}/resized_image.png" + resized_mask_name = f"{save_dir}/resized_mask.png" + cv2.imwrite(resized_img_name, img_new) + cv2.imwrite(resized_mask_name, mask) + return resized_img_name, resized_mask_name + else: + return img_new, mask + +def get_box(mask, bias = 2): + nonzero_indices = torch.nonzero(mask) + H, W = mask.shape[-2:] + min_x = max(min(nonzero_indices[:, 1]) - bias, 0) + min_y = max(min(nonzero_indices[:, 0]) - bias, 0) + max_x = min(max(nonzero_indices[:, 1]) + bias, W) + max_y = min(max(nonzero_indices[:, 0]) + bias, H) + return (min_x, min_y, max_x, max_y) + + +def draw_axis(img,grid_dict,x_len,y_len): + if 
grid_dict is not None and grid_dict is not False: + assert isinstance(grid_dict,Dict) + assert "x_title" in grid_dict + assert "y_title" in grid_dict + assert "x_text_list" in grid_dict + assert "y_text_list" in grid_dict + x_title=grid_dict["x_title"] + y_title=grid_dict["y_title"] + x_text_list=grid_dict['x_text_list'] + y_text_list=grid_dict['y_text_list'] + assert len(y_text_list)==y_len + assert len(x_text_list)==x_len + assert "font_size" in grid_dict + font_size=grid_dict["font_size"] + if "x_color" in grid_dict: + color_x=grid_dict['x_color'] + else: + color_x="black" + if "y_color" in grid_dict: + color_y=grid_dict['y_color'] + else: + color_y="black" + if "num_decimals" in grid_dict: + num_decimals=grid_dict['num_decimals'] + else: + num_decimals=2 + if "shift_x" in grid_dict: + shift_x_x,shift_x_y=grid_dict['shift_x'] + else: + shift_x_x=shift_x_y=0 + if "shift_y" in grid_dict: + shift_y_x,shift_y_y=grid_dict['shift_y'] + else: + shift_y_x=shift_y_y=0 + if "title" in grid_dict: + title=grid_dict['title'] + if isinstance(title,List): + all_title="" + for s in title: + all_title=all_title+s+"\n" + title=all_title + else: + title='' + width, height = img.size + num_x=x_len + num_y=y_len + + new_img = Image.new("RGB", (width + width // num_x+width // (num_x*2), height + height // num_y+height // (num_y*2)), color=(255, 255, 255)) + width,height=(width + width // num_x, height + height // num_y) + num_x=num_x+1 + num_y=num_y+1 + new_img.paste(img, (width // num_x, height // num_y)) + + draw = ImageDraw.Draw(new_img) + + font = ImageFont.truetype("DejaVuSansMono.ttf", font_size) + for i in range(2, num_x+1): + x = (i - 1) * width // num_x + width // (num_x * 2)-width *0.2// num_x+shift_x_x + y = height // (num_y * 2)+shift_x_y + k=i-1 + if isinstance(x_text_list[i-2],str): + draw.text((x, y), x_text_list[i-2], font=font,fill=color_x,align="center") + else: + draw.text((x, y), "{:.{}f}".format(x_text_list[i-2],num_decimals), font=font,fill=color_x,align="center") + + for i in range(2, num_y+1): + x = width // (num_x * 2)-width *0.1// num_x+shift_y_x + y = (i - 1) * height // num_y + height // (num_y * 2)-height*0.1//num_y+shift_y_y + k = i - 1 + if isinstance(y_text_list[i-2],str): + draw.text((x, y), y_text_list[i-2], font=font,fill=color_y,align="center") + else: + draw.text((x, y), "{:.{}f}".format(y_text_list[i-2],num_decimals), font=font,fill=color_y,align="center") + i=1 + x = (i - 1) * width // num_x + width // (num_x * 2)-height*0.1//num_y+shift_y_x + y = height // (num_y * 2)+width *0.2// num_x+shift_y_y + draw.text((x, y), y_title, font=font, fill=color_y,align="center") + x = width // (num_x * 2)+width *0.2// num_x+shift_x_x + y = (i - 1) * height // num_y + height // (num_y * 2)+shift_x_y + draw.text((x, y), x_title, font=font, fill=color_x,align="left") + x = width // 4 + y = (i - 1) * height // num_y + height // (num_y * 10) + draw.text((x, y), title, font=font, fill='blue',align="left") + else: + + new_img=img + return new_img + +def view_images(images, num_rows=1, offset_ratio=0.02,text="",folder=None,Notimestamp=False, +grid_dict=None,subfolder=None,verbose=True,output_dir=None,timestamp=None,**kwargs): + if type(images) is list: + num_empty = len(images) % num_rows + elif images.ndim == 4: + num_empty = images.shape[0] % num_rows + else: + images = [images] + num_empty = 0 + origin_size=kwargs.get("origin_size",None) + images_copy=images.copy() + for i, per_image in enumerate(images_copy): + if isinstance(per_image, Image.Image) and origin_size is not None: + 
images[i] = np.array(per_image.resize((origin_size[1],origin_size[0]))) + else: + images[i] = np.array(per_image) + + empty_images = np.ones(images[0].shape, dtype=np.uint8) * 255 + images = [image.astype(np.uint8) for image in images] + [empty_images] * num_empty + num_items = len(images) + + h, w, c = images[0].shape + offset = int(h * offset_ratio) + num_cols = num_items // num_rows + image_ = np.ones((h * num_rows + offset * (num_rows - 1), + w * num_cols + offset * (num_cols - 1), 3), dtype=np.uint8) * 255 + for i in range(num_rows): + for j in range(num_cols): + image_[i * (h + offset): i * (h + offset) + h:, j * (w + offset): j * (w + offset) + w] = images[ + i * num_cols + j] + + pil_img = Image.fromarray(image_) + + pil_img_=draw_axis(pil_img,grid_dict,num_cols,num_rows) + if pil_img_.size[0]==pil_img_.size[1]: + pil_img_.resize((2048,2048)) + else: + longer_side = max(pil_img.size) + ratio = 2048/longer_side + new_size = tuple([int(x*ratio) for x in pil_img.size]) + pil_img = pil_img.resize(new_size) + + if verbose is False: + return pil_img + now = datetime.now() + if timestamp is None: + if Notimestamp is False: + timestamp = now.strftime("%Y-%m-%d_%H-%M-%S") + else: + timestamp="" + if output_dir is None: + if timestamp != "": + date, time = timestamp.split('_') + else: + date, time = "","" + if folder is not None: + dirname="./"+folder + filename = text+f"img_{timestamp}.jpg" + else: + if subfolder is not None: + dirname=os.path.join("./img", subfolder,date) + dirname=os.path.join(dirname,time) + filename =text+f"img_{timestamp}.jpg" + else: + dirname=os.path.join("./img",date) + dirname=os.path.join(dirname,time) + filename =text+f"img_{timestamp}.jpg" + else: + dirname=output_dir + filename =text+f"img_{timestamp}.jpg" + if not os.path.exists(dirname): + os.makedirs(dirname) + if verbose is True: + for i, img in enumerate(images): + im = Image.fromarray(img) + im.save(os.path.join(dirname,f"{i}.jpg")) + print(f"Output dir: {dirname}") + pil_img.save(os.path.join(dirname, filename)) + if grid_dict is not None and grid_dict is not False: + if not os.path.exists(dirname): + os.makedirs(dirname) + pil_img_.save(os.path.join(dirname, filename[:-4]+"_2048x.jpg")) + +def resize_image_with_mask(img, mask, scale): + if scale == 1: + return img, mask, None + img_blackboard = img.copy() # canvas + mask_blackboard = np.zeros_like(mask) + + M = cv2.moments(mask) + cx = int(M["m10"] / M["m00"]) + cy = int(M["m01"] / M["m00"]) + + scale_factor = [scale, scale] + resized_img = cv2.resize(img, None, fx=scale_factor[0], fy=scale_factor[1], interpolation=cv2.INTER_AREA) + resized_mask = cv2.resize(mask, None, fx=scale_factor[0], fy=scale_factor[1], interpolation=cv2.INTER_AREA) + new_cx, new_cy = cx * scale_factor[0], cy * scale_factor[1] + + for y in range(resized_mask.shape[0]): + for x in range(resized_mask.shape[1]): + if 0 <= cy - (new_cy - y) < img.shape[0] and 0 <= cx - (new_cx - x) < img.shape[1]: + mask_blackboard[int(cy - (new_cy - y)), int(cx - (new_cx - x))] = resized_mask[y, x] + img_blackboard[int(cy - (new_cy - y)), int(cx - (new_cx - x))] = resized_img[y, x] + return img_blackboard, mask_blackboard, (cx, cy) + +def flip_image_with_mask(img, mask, flip_code=None): + if flip_code is None: + return img, mask, None + M = cv2.moments(mask) + if M["m00"] == 0: + return img, mask + cx = int(M["m10"] / M["m00"]) + cy = int(M["m01"] / M["m00"]) + + h, w = img.shape[:2] + img_center = (w // 2, h // 2) + + tx = img_center[0] - cx + ty = img_center[1] - cy + + M_translate = 
np.float32([[1, 0, tx], [0, 1, ty]]) + img_translated = cv2.warpAffine(img, M_translate, (w, h)) + mask_translated = cv2.warpAffine(mask, M_translate, (w, h)) + flipped_img = cv2.flip(img_translated, flip_code) + flipped_mask = cv2.flip(mask_translated, flip_code) + M_translate_back = np.float32([[1, 0, -tx], [0, 1, -ty]]) + flipped_img_back = cv2.warpAffine(flipped_img, M_translate_back, (w, h)) + flipped_mask_back = cv2.warpAffine(flipped_mask, M_translate_back, (w, h)) + + return flipped_img_back, flipped_mask_back, (cx, cy)
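The hunk above opens inside a mask-preparation helper whose signature falls outside this hunk; the visible body dilates or erodes a mask by `attend_scale` pixels and downsamples it to the 128x128 SDXL latent grid. A minimal standalone sketch of that step follows; the name `prepare_latent_mask` and the `attend_scale` default are assumptions, not names taken from the diff.

import cv2
import numpy as np
import torch

def prepare_latent_mask(mask_path, attend_scale=10):  # hypothetical name/signature
    # Load the mask in grayscale, grow it (attend_scale > 0) or shrink it
    # (attend_scale < 0), then downsample to the 128x128 latent resolution.
    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    kernel = np.ones((abs(attend_scale), abs(attend_scale)), np.uint8)
    if attend_scale > 0:
        mask = cv2.dilate(mask, kernel, iterations=1)
    else:
        mask = cv2.erode(mask, kernel, iterations=1)
    mask = cv2.resize(mask, (128, 128), interpolation=cv2.INTER_NEAREST)
    # [1, 1, 128, 128] float tensor in [0, 1]; the code in the diff additionally moves it to CUDA.
    return torch.from_numpy(mask).to(torch.float32).unsqueeze(0).unsqueeze(0) / 255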
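For the `panning` and `zooming` helpers, an illustrative usage sketch (not part of the diff; the example image path is one of the files added elsewhere in this diff). With `save=False` both return the shifted or shrunken image together with a binary mask in which white marks the newly exposed region to fill; note that with `save=True` and `save_dir=None` they write into a `preprocess/` folder next to the input and only call `os.makedirs` when an explicit `save_dir` is given.

import cv2

img = cv2.imread("examples/layer/01_horse/00.jpg")

# Pan the camera 20% to the left and 10% upwards.
panned_img, panned_mask = panning(img, op_list=[["left", 0.2], ["up", 0.1]], save=False)

# Zoom out to 80% of the original size, centring the shrunken image inside a white border.
zoomed_img, zoomed_mask = zooming(img, scale=[0.8, 0.8], save=False)

cv2.imwrite("panned.png", panned_img)
cv2.imwrite("panned_mask.png", panned_mask)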
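`resize_image_with_mask` and `flip_image_with_mask` keep the masked object's centroid fixed while scaling or mirroring it, and normally return `(edited_image, edited_mask, centroid)`. A usage sketch (illustration only; `flip_code=1` is OpenCV's horizontal flip, and the mask path is another file added in this diff):

import cv2

img = cv2.imread("examples/layer/01_horse/00.jpg")
mask = cv2.imread("examples/layer/01_horse/mask0.jpg", cv2.IMREAD_GRAYSCALE)

# Shrink the masked object to 80% of its size around its centroid.
small_img, small_mask, center = resize_image_with_mask(img, mask, scale=0.8)

# Mirror the object horizontally in place (flip_code=1). Note the early returns above:
# scale == 1 or flip_code is None return the inputs with None as the third value, and
# an empty mask (zero moments) makes flip_image_with_mask return only (img, mask).
flipped_img, flipped_mask, center = flip_image_with_mask(img, mask, flip_code=1)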
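`view_images` tiles a list of images into a grid, and when `grid_dict` is supplied, `draw_axis` asserts a specific contract: `x_title`, `y_title`, `font_size`, and one tick label per column and per row are required, while colors, `num_decimals`, label shifts, and an overall `title` are optional. A minimal sketch of a call that satisfies those assertions, using dummy images (illustration only; drawing the axis also requires the `DejaVuSansMono.ttf` font to be resolvable by Pillow):

from PIL import Image

# Six dummy 512x512 images arranged as 2 rows x 3 columns.
images = [Image.new("RGB", (512, 512), "white") for _ in range(6)]

grid_dict = {
    "x_title": "guidance scale",
    "y_title": "denoising step",
    "x_text_list": [5.0, 7.5, 10.0],        # one label per column
    "y_text_list": ["step 10", "step 30"],  # one label per row
    "font_size": 28,
    "num_decimals": 1,                      # optional: decimals for numeric labels
    "title": "layout ablation",             # optional: blue caption above the grid
}

# Saves the individual tiles, the composed grid, and the annotated grid under ./plt.
view_images(images, num_rows=2, grid_dict=grid_dict, output_dir="./plt", Notimestamp=True)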