diff --git a/.gitignore b/.gitignore
new file mode 100755
index 0000000000000000000000000000000000000000..0110818d6b1e70612d770dd55726953ec8002c6f
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,10 @@
+img/
+logfile/
+__pycache__/
+*/__pycache__/
+models/
+plt/
+docs/
+exp/
+examples/mask/
+examples/mask_box/
\ No newline at end of file
diff --git a/.vscode/launch.json b/.vscode/launch.json
new file mode 100644
index 0000000000000000000000000000000000000000..224b59528850f7106228202b8d5901771b41c614
--- /dev/null
+++ b/.vscode/launch.json
@@ -0,0 +1,17 @@
+{
+ // Use IntelliSense to learn about possible attributes.
+ // Hover to view descriptions of existing attributes.
+ // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
+ "version": "0.2.0",
+ "configurations": [
+ {
+ "name": "Demo",
+ "type": "debugpy",
+ "request": "launch",
+ "program": "/home/jyr/demo/DesignEdit/app.py",
+ "console": "integratedTerminal",
+ "python": "/home/jyr/.conda/envs/new_design/bin/python"
+
+ }
+ ]
+}
\ No newline at end of file
diff --git a/README.md b/README.md
index 9504551be6c0d250f32f617faec2350c5c855aa1..ddb40e73375e68172b0c30cf147112ae756624ec 100644
--- a/README.md
+++ b/README.md
@@ -1,10 +1,10 @@
---
title: DesignEdit
-emoji: 🏆
-colorFrom: pink
-colorTo: yellow
+emoji: 🌿
+colorFrom: yellow
+colorTo: green
sdk: gradio
-sdk_version: 4.25.0
+sdk_version: 4.24.0
app_file: app.py
pinned: false
---
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..891e638d4689ebee66ef358de780a849ea5121fa
--- /dev/null
+++ b/app.py
@@ -0,0 +1,49 @@
+import gradio as gr
+import spaces
+import torch
+
+import os
+import subprocess
+import shlex
+from src.demo.model import DesignEdit
+
+os.makedirs('models', exist_ok=True)
+subprocess.run(shlex.split('wget https://huggingface.co/Adapter/DragonDiffusion/resolve/main/model/efficient_sam_vits.pt -O models/efficient_sam_vits.pt'))
+
+from src.demo.demo import *
+import shlex
+import cv2
+
+pretrained_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
+model = DesignEdit(pretrained_model_path=pretrained_model_path)
+DESCRIPTION_1 = """
+🌿DesignEdit🌿
+"""
+DESCRIPTION_2 = """
+Multi-Layered Latent Decomposition and Fusion for Unified & Accurate Image Editing
+"""
+DESCRIPTION_3 = """
+
+"""
+
+
+with gr.Blocks(css='style.css') as demo:
+ gr.HTML(DESCRIPTION_1)
+ gr.HTML(DESCRIPTION_2)
+ gr.HTML(DESCRIPTION_3)
+ with gr.Tabs():
+ with gr.TabItem('1️⃣ Object Removal'):
+ create_demo_remove(model.run_remove)
+ with gr.TabItem('2️⃣ Zooming Out'):
+ create_demo_zooming(model.run_zooming)
+ with gr.TabItem('3️⃣ Camera Panning'):
+ create_demo_panning(model.run_panning)
+ with gr.TabItem('4️⃣ Object Moving, Resizing and Flipping'):
+ create_demo_moving(model.run_moving)
+ with gr.TabItem('5️⃣ 🚩 Multi-Layered Editing 🚩'):
+ create_demo_layer(model.run_layer)
+ with gr.TabItem('🔧 Mask Preparation: Draw or Sketch'):
+ create_demo_mask_box(model.run_mask)
+demo.queue(max_size=20)
+demo.launch(max_threads=3, server_name="0.0.0.0")
+
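A minimal sketch of the same startup download for local runs, reusing the URL and checkpoint path from app.py above but skipping the fetch when models/efficient_sam_vits.pt already exists (wget is assumed to be available on PATH):

import os
import shlex
import subprocess

CKPT = "models/efficient_sam_vits.pt"
URL = "https://huggingface.co/Adapter/DragonDiffusion/resolve/main/model/efficient_sam_vits.pt"

os.makedirs("models", exist_ok=True)
if not os.path.exists(CKPT):
    # check=True makes a failed download raise instead of continuing silently
    subprocess.run(shlex.split(f"wget {URL} -O {CKPT}"), check=True)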
diff --git a/examples/layer/01_horse/00.jpg b/examples/layer/01_horse/00.jpg
new file mode 100755
index 0000000000000000000000000000000000000000..502be3f0bbada559c91718606b6b03c86021572b
Binary files /dev/null and b/examples/layer/01_horse/00.jpg differ
diff --git a/examples/layer/01_horse/mask0.jpg b/examples/layer/01_horse/mask0.jpg
new file mode 100755
index 0000000000000000000000000000000000000000..39377e481c21d9980df45c94c233f562809efbfa
Binary files /dev/null and b/examples/layer/01_horse/mask0.jpg differ
diff --git a/examples/layer/02_baby/00.jpg b/examples/layer/02_baby/00.jpg
new file mode 100755
index 0000000000000000000000000000000000000000..f4fbb272ec041a9fd0793f10cc6830c42f740047
Binary files /dev/null and b/examples/layer/02_baby/00.jpg differ
diff --git a/examples/layer/02_baby/mask0.jpg b/examples/layer/02_baby/mask0.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..f94b32a63688d6b2852e7f743dc4dffb742bdd53
Binary files /dev/null and b/examples/layer/02_baby/mask0.jpg differ
diff --git a/examples/layer/02_baby/mask1.jpg b/examples/layer/02_baby/mask1.jpg
new file mode 100755
index 0000000000000000000000000000000000000000..bbed5f61ca38f384ab3f2f87c206efd74174e1da
Binary files /dev/null and b/examples/layer/02_baby/mask1.jpg differ
diff --git a/examples/layer/02_baby/mask2.jpg b/examples/layer/02_baby/mask2.jpg
new file mode 100755
index 0000000000000000000000000000000000000000..c77e82d7eeea470bd3a787d018815eda5ee7f588
Binary files /dev/null and b/examples/layer/02_baby/mask2.jpg differ
diff --git a/examples/layer/03_text/00.jpg b/examples/layer/03_text/00.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..2ae11e62b15bbabcba72e46fd373ad0c718a8fb3
Binary files /dev/null and b/examples/layer/03_text/00.jpg differ
diff --git a/examples/layer/03_text/01.jpg b/examples/layer/03_text/01.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..654004437236e476431cb81d38434dec1d65e1d1
Binary files /dev/null and b/examples/layer/03_text/01.jpg differ
diff --git a/examples/layer/03_text/mask0.jpg b/examples/layer/03_text/mask0.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..bf0ce7ffc1799bd21967250f02ecd5f5159ae384
Binary files /dev/null and b/examples/layer/03_text/mask0.jpg differ
diff --git a/examples/layer/03_text/mask1.jpg b/examples/layer/03_text/mask1.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..a1d805e541613f2c53f767b0e775ab053707cb34
Binary files /dev/null and b/examples/layer/03_text/mask1.jpg differ
diff --git a/examples/layer/04_cross/0.jpg b/examples/layer/04_cross/0.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..4edde7252b283d38f8e123da726c00e8ecf418fb
Binary files /dev/null and b/examples/layer/04_cross/0.jpg differ
diff --git a/examples/layer/04_cross/1.jpg b/examples/layer/04_cross/1.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..a71e2c9477b3ae1911be802d6020df06c55554da
Binary files /dev/null and b/examples/layer/04_cross/1.jpg differ
diff --git a/examples/layer/04_cross/2.jpg b/examples/layer/04_cross/2.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..f40b0440819237e7e18eb8e8479cd370da837e2c
Binary files /dev/null and b/examples/layer/04_cross/2.jpg differ
diff --git a/examples/layer/04_cross/3.jpg b/examples/layer/04_cross/3.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..a2ded809ac249e22dd2808dd99ca386c56632adc
Binary files /dev/null and b/examples/layer/04_cross/3.jpg differ
diff --git a/examples/layer/04_cross/mask0.jpg b/examples/layer/04_cross/mask0.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..e07c313622a380fe141c9b799ab362ab54703345
Binary files /dev/null and b/examples/layer/04_cross/mask0.jpg differ
diff --git a/examples/layer/04_cross/mask1.jpg b/examples/layer/04_cross/mask1.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..3c90e633a455b762be1fc39a0ece2c62ea55bb5f
Binary files /dev/null and b/examples/layer/04_cross/mask1.jpg differ
diff --git a/examples/layer/04_cross/mask2.jpg b/examples/layer/04_cross/mask2.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..1d09eecc9bf6bf60321def094bc3ecb726963fee
Binary files /dev/null and b/examples/layer/04_cross/mask2.jpg differ
diff --git a/examples/layer/04_cross/mask3.jpg b/examples/layer/04_cross/mask3.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..19b58983cf1463edfd7263fa40296ba723f4d03a
Binary files /dev/null and b/examples/layer/04_cross/mask3.jpg differ
diff --git a/examples/moving/01_ball/0.jpg b/examples/moving/01_ball/0.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..f11738232d02b9ffcb60fe4878764a1081d7fb7a
Binary files /dev/null and b/examples/moving/01_ball/0.jpg differ
diff --git a/examples/moving/01_ball/mask0.jpg b/examples/moving/01_ball/mask0.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..49e57f3cdb98cd1ee80e70318bcb081fb0fbd930
Binary files /dev/null and b/examples/moving/01_ball/mask0.jpg differ
diff --git a/examples/moving/02_bell/0.jpg b/examples/moving/02_bell/0.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..1a3e1fba25f66f21fe42265f867fa8daaa3f9032
Binary files /dev/null and b/examples/moving/02_bell/0.jpg differ
diff --git a/examples/moving/02_bell/mask0.jpg b/examples/moving/02_bell/mask0.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..e4f89e4293acea880cd8a04bb0f7e2d69f7cbdb2
Binary files /dev/null and b/examples/moving/02_bell/mask0.jpg differ
diff --git a/examples/pan/01.jpg b/examples/pan/01.jpg
new file mode 100755
index 0000000000000000000000000000000000000000..502be3f0bbada559c91718606b6b03c86021572b
Binary files /dev/null and b/examples/pan/01.jpg differ
diff --git a/examples/pan/02.jpg b/examples/pan/02.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..dd1d5d4450d43f0e2812a4740f92b2eabd5d7d11
Binary files /dev/null and b/examples/pan/02.jpg differ
diff --git a/examples/pan/03.jpg b/examples/pan/03.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..d3a84ffebd40b0748deb1b2f1f8734f88b2d1ebc
Binary files /dev/null and b/examples/pan/03.jpg differ
diff --git a/examples/pan/04.jpg b/examples/pan/04.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..63c305e6fd77128f5b1664d89bf70bc0fad21a40
Binary files /dev/null and b/examples/pan/04.jpg differ
diff --git a/examples/pan/05.jpg b/examples/pan/05.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..5617fd79f1a039673e5de8e3ee865351f96ff1cf
Binary files /dev/null and b/examples/pan/05.jpg differ
diff --git a/examples/pan/06.jpg b/examples/pan/06.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..d5e6670f5491df5388f4d5181c60afb36cdcdfed
Binary files /dev/null and b/examples/pan/06.jpg differ
diff --git a/examples/remove/01_moto/0.jpg b/examples/remove/01_moto/0.jpg
new file mode 100755
index 0000000000000000000000000000000000000000..2164e166b5fb3e420a0c892b891ba3e3e68c6712
Binary files /dev/null and b/examples/remove/01_moto/0.jpg differ
diff --git a/examples/remove/01_moto/mask0.jpg b/examples/remove/01_moto/mask0.jpg
new file mode 100755
index 0000000000000000000000000000000000000000..3c602058c95f0165eb3e4b4b84dc6205de06cf66
Binary files /dev/null and b/examples/remove/01_moto/mask0.jpg differ
diff --git a/examples/remove/01_moto/mask1.jpg b/examples/remove/01_moto/mask1.jpg
new file mode 100755
index 0000000000000000000000000000000000000000..d773289fee31961900990eb745f7cbd8ac85735e
Binary files /dev/null and b/examples/remove/01_moto/mask1.jpg differ
diff --git a/examples/remove/02_ring/0.jpg b/examples/remove/02_ring/0.jpg
new file mode 100755
index 0000000000000000000000000000000000000000..897e98afa4ebe74083b154fee58f77332c072973
Binary files /dev/null and b/examples/remove/02_ring/0.jpg differ
diff --git a/examples/remove/02_ring/mask0.jpg b/examples/remove/02_ring/mask0.jpg
new file mode 100755
index 0000000000000000000000000000000000000000..437703be8b840ac9f7d84587f7871872d60b9b1c
Binary files /dev/null and b/examples/remove/02_ring/mask0.jpg differ
diff --git a/examples/remove/02_ring/mask1.jpg b/examples/remove/02_ring/mask1.jpg
new file mode 100755
index 0000000000000000000000000000000000000000..61001eff2dd2cb528a110d63897e7d48945f2cf5
Binary files /dev/null and b/examples/remove/02_ring/mask1.jpg differ
diff --git a/examples/remove/02_ring/mask2.jpg b/examples/remove/02_ring/mask2.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..faabcd04ce296a29eecbf23a9fbb5fcd93a82dd9
Binary files /dev/null and b/examples/remove/02_ring/mask2.jpg differ
diff --git a/examples/remove/03_ball/0.jpg b/examples/remove/03_ball/0.jpg
new file mode 100755
index 0000000000000000000000000000000000000000..a32933213dd03dedcd1483fa5bee7582c0c51236
Binary files /dev/null and b/examples/remove/03_ball/0.jpg differ
diff --git a/examples/remove/03_ball/mask0.jpg b/examples/remove/03_ball/mask0.jpg
new file mode 100755
index 0000000000000000000000000000000000000000..41d2fb7ffc5b8605d8f6622bb646ca3afb8093a1
Binary files /dev/null and b/examples/remove/03_ball/mask0.jpg differ
diff --git a/examples/remove/03_ball/mask1.jpg b/examples/remove/03_ball/mask1.jpg
new file mode 100755
index 0000000000000000000000000000000000000000..dbe79573f0b773655b22ea70003eb38e747d5394
Binary files /dev/null and b/examples/remove/03_ball/mask1.jpg differ
diff --git a/examples/remove/04_pikachu/0.jpg b/examples/remove/04_pikachu/0.jpg
new file mode 100755
index 0000000000000000000000000000000000000000..1204812f145e29311d15fd20433e7649954f8af2
Binary files /dev/null and b/examples/remove/04_pikachu/0.jpg differ
diff --git a/examples/remove/04_pikachu/mask0.jpg b/examples/remove/04_pikachu/mask0.jpg
new file mode 100755
index 0000000000000000000000000000000000000000..a197882e1b3f9c728096885ee6265d6875521905
Binary files /dev/null and b/examples/remove/04_pikachu/mask0.jpg differ
diff --git a/examples/remove/04_pikachu/mask1.jpg b/examples/remove/04_pikachu/mask1.jpg
new file mode 100755
index 0000000000000000000000000000000000000000..80777884fa0fb4ee59deedc1b99477e2f95594fc
Binary files /dev/null and b/examples/remove/04_pikachu/mask1.jpg differ
diff --git a/examples/remove/04_pikachu/mask2.jpg b/examples/remove/04_pikachu/mask2.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..2eb6c21723b01a5c74b50563a096cda0e21b87c0
Binary files /dev/null and b/examples/remove/04_pikachu/mask2.jpg differ
diff --git a/examples/remove/05_betty/0.jpg b/examples/remove/05_betty/0.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..39eceadc2330703743090990bdbff8aca97afb13
Binary files /dev/null and b/examples/remove/05_betty/0.jpg differ
diff --git a/examples/remove/05_betty/mask0.jpg b/examples/remove/05_betty/mask0.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..87c0784b1c402471436c89166f8114c36535c672
Binary files /dev/null and b/examples/remove/05_betty/mask0.jpg differ
diff --git a/examples/zoom/01.jpg b/examples/zoom/01.jpg
new file mode 100755
index 0000000000000000000000000000000000000000..502be3f0bbada559c91718606b6b03c86021572b
Binary files /dev/null and b/examples/zoom/01.jpg differ
diff --git a/examples/zoom/02.jpg b/examples/zoom/02.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..f5730a436e7a08e47a7f2e29e28617271d9d1c44
Binary files /dev/null and b/examples/zoom/02.jpg differ
diff --git a/examples/zoom/03.jpg b/examples/zoom/03.jpg
new file mode 100755
index 0000000000000000000000000000000000000000..f7a0cc7d115a9e8a59959c910d66774ab0d44af8
Binary files /dev/null and b/examples/zoom/03.jpg differ
diff --git a/examples/zoom/04.jpg b/examples/zoom/04.jpg
new file mode 100755
index 0000000000000000000000000000000000000000..e4341114794f8be5fe80c9d442693630db1c8ff6
Binary files /dev/null and b/examples/zoom/04.jpg differ
diff --git a/examples/zoom/05.jpg b/examples/zoom/05.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..f4a2f82d540c7f6731b8e4dff5dcdb74e18c4ab6
Binary files /dev/null and b/examples/zoom/05.jpg differ
diff --git a/examples/zoom/06.jpg b/examples/zoom/06.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..63c305e6fd77128f5b1664d89bf70bc0fad21a40
Binary files /dev/null and b/examples/zoom/06.jpg differ
diff --git a/examples/zoom/07.jpg b/examples/zoom/07.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..aaffb0d35dff979c342e5bc383e0c8d999607134
Binary files /dev/null and b/examples/zoom/07.jpg differ
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..d434bd9eaabf87af87056786d9479729abc8a266
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,13 @@
+diffusers==0.18.2
+torch==2.0.1
+torchvision==0.15.2
+matplotlib==3.7.2
+numpy==1.25.1
+opencv_python==4.8.0.74
+opencv_python_headless==4.8.0.74
+Pillow==10.1.0
+transformers==4.35.0
+gradio==4.0.0
+basicsr==1.4.2
+accelerate==0.21.0
+invisible-watermark
\ No newline at end of file
diff --git a/sam/efficient_sam/__init__.py b/sam/efficient_sam/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..22a2d29cee5d2a2df01944c90b6e01f879301f3f
--- /dev/null
+++ b/sam/efficient_sam/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+from .build_efficient_sam import (
+ build_efficient_sam_vitt,
+ build_efficient_sam_vits,
+)
diff --git a/sam/efficient_sam/build_efficient_sam.py b/sam/efficient_sam/build_efficient_sam.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d138e7335d10c8cbf43aa9ceafef12eda92a66e
--- /dev/null
+++ b/sam/efficient_sam/build_efficient_sam.py
@@ -0,0 +1,22 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .efficient_sam import build_efficient_sam
+
+def build_efficient_sam_vitt():
+ return build_efficient_sam(
+ encoder_patch_embed_dim=192,
+ encoder_num_heads=3,
+ checkpoint="models/efficient_sam_vitt.pt",
+ ).eval()
+
+
+def build_efficient_sam_vits():
+ return build_efficient_sam(
+ encoder_patch_embed_dim=384,
+ encoder_num_heads=6,
+ checkpoint="models/efficient_sam_vits.pt",
+ ).eval()
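A short usage sketch for the builders above, assuming the ViT-S checkpoint has already been downloaded to models/efficient_sam_vits.pt (app.py fetches it at startup):

from sam.efficient_sam import build_efficient_sam_vits

efficient_sam = build_efficient_sam_vits()  # EfficientSam module, returned in eval() mode
# build_efficient_sam_vitt() works the same way but expects models/efficient_sam_vitt.pt instead.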
diff --git a/sam/efficient_sam/efficient_sam.py b/sam/efficient_sam/efficient_sam.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a3ba4c328e0716a8e3166b7d267353757aa76d7
--- /dev/null
+++ b/sam/efficient_sam/efficient_sam.py
@@ -0,0 +1,310 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+from typing import Any, List, Tuple, Type
+
+import torch
+import torch.nn.functional as F
+
+from torch import nn, Tensor
+
+from .efficient_sam_decoder import MaskDecoder, PromptEncoder
+from .efficient_sam_encoder import ImageEncoderViT
+from .two_way_transformer import TwoWayAttentionBlock, TwoWayTransformer
+
+class EfficientSam(nn.Module):
+ mask_threshold: float = 0.0
+ image_format: str = "RGB"
+
+ def __init__(
+ self,
+ image_encoder: ImageEncoderViT,
+ prompt_encoder: PromptEncoder,
+ decoder_max_num_input_points: int,
+ mask_decoder: MaskDecoder,
+ pixel_mean: List[float] = [0.485, 0.456, 0.406],
+ pixel_std: List[float] = [0.229, 0.224, 0.225],
+ ) -> None:
+ """
+ SAM predicts object masks from an image and input prompts.
+
+ Arguments:
+ image_encoder (ImageEncoderViT): The backbone used to encode the
+ image into image embeddings that allow for efficient mask prediction.
+ prompt_encoder (PromptEncoder): Encodes various types of input prompts.
+ mask_decoder (MaskDecoder): Predicts masks from the image embeddings
+ and encoded prompts.
+ pixel_mean (list(float)): Mean values for normalizing pixels in the input image.
+ pixel_std (list(float)): Std values for normalizing pixels in the input image.
+ """
+ super().__init__()
+ self.image_encoder = image_encoder
+ self.prompt_encoder = prompt_encoder
+ self.decoder_max_num_input_points = decoder_max_num_input_points
+ self.mask_decoder = mask_decoder
+ self.register_buffer(
+ "pixel_mean", torch.Tensor(pixel_mean).view(1, 3, 1, 1), False
+ )
+ self.register_buffer(
+ "pixel_std", torch.Tensor(pixel_std).view(1, 3, 1, 1), False
+ )
+
+ @torch.jit.export
+ def predict_masks(
+ self,
+ image_embeddings: torch.Tensor,
+ batched_points: torch.Tensor,
+ batched_point_labels: torch.Tensor,
+ multimask_output: bool,
+ input_h: int,
+ input_w: int,
+ output_h: int = -1,
+ output_w: int = -1,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Predicts masks given image embeddings and prompts. This only runs the decoder.
+
+ Arguments:
+ image_embeddings: A tensor of shape [B, C, H, W] or [B*max_num_queries, C, H, W]
+ batched_points: A tensor of shape [B, max_num_queries, num_pts, 2]
+ batched_point_labels: A tensor of shape [B, max_num_queries, num_pts]
+ Returns:
+ A tuple of two tensors:
+ output_masks: A tensor of shape [B, max_num_queries, num_predictions, H, W] of predicted masks, where H and W are output_h and output_w when given and 256 otherwise
+ iou_predictions: A tensor of shape [B, max_num_queries, num_predictions] of estimated IoU scores
+ """
+
+ batch_size, max_num_queries, num_pts, _ = batched_points.shape
+ num_pts = batched_points.shape[2]
+ rescaled_batched_points = self.get_rescaled_pts(batched_points, input_h, input_w)
+
+ if num_pts > self.decoder_max_num_input_points:
+ rescaled_batched_points = rescaled_batched_points[
+ :, :, : self.decoder_max_num_input_points, :
+ ]
+ batched_point_labels = batched_point_labels[
+ :, :, : self.decoder_max_num_input_points
+ ]
+ elif num_pts < self.decoder_max_num_input_points:
+ rescaled_batched_points = F.pad(
+ rescaled_batched_points,
+ (0, 0, 0, self.decoder_max_num_input_points - num_pts),
+ value=-1.0,
+ )
+ batched_point_labels = F.pad(
+ batched_point_labels,
+ (0, self.decoder_max_num_input_points - num_pts),
+ value=-1.0,
+ )
+
+ sparse_embeddings = self.prompt_encoder(
+ rescaled_batched_points.reshape(
+ batch_size * max_num_queries, self.decoder_max_num_input_points, 2
+ ),
+ batched_point_labels.reshape(
+ batch_size * max_num_queries, self.decoder_max_num_input_points
+ ),
+ )
+ sparse_embeddings = sparse_embeddings.view(
+ batch_size,
+ max_num_queries,
+ sparse_embeddings.shape[1],
+ sparse_embeddings.shape[2],
+ )
+ low_res_masks, iou_predictions = self.mask_decoder(
+ image_embeddings,
+ self.prompt_encoder.get_dense_pe(),
+ sparse_prompt_embeddings=sparse_embeddings,
+ multimask_output=multimask_output,
+ )
+ _, num_predictions, low_res_size, _ = low_res_masks.shape
+
+ if output_w > 0 and output_h > 0:
+ output_masks = F.interpolate(
+ low_res_masks, (output_h, output_w), mode="bicubic"
+ )
+ output_masks = torch.reshape(
+ output_masks,
+ (batch_size, max_num_queries, num_predictions, output_h, output_w),
+ )
+ else:
+ output_masks = torch.reshape(
+ low_res_masks,
+ (
+ batch_size,
+ max_num_queries,
+ num_predictions,
+ low_res_size,
+ low_res_size,
+ ),
+ )
+ iou_predictions = torch.reshape(
+ iou_predictions, (batch_size, max_num_queries, num_predictions)
+ )
+ sorted_ids = torch.argsort(iou_predictions, dim=-1, descending=True)
+ iou_predictions = torch.take_along_dim(iou_predictions, sorted_ids, dim=2)
+ output_masks = torch.take_along_dim(
+ output_masks, sorted_ids[..., None, None], dim=2
+ )
+ return output_masks, iou_predictions
+
+ def get_rescaled_pts(self, batched_points: torch.Tensor, input_h: int, input_w: int):
+ return torch.stack(
+ [
+ torch.where(
+ batched_points[..., 0] >= 0,
+ batched_points[..., 0] * self.image_encoder.img_size / input_w,
+ -1.0,
+ ),
+ torch.where(
+ batched_points[..., 1] >= 0,
+ batched_points[..., 1] * self.image_encoder.img_size / input_h,
+ -1.0,
+ ),
+ ],
+ dim=-1,
+ )
+
+ @torch.jit.export
+ def get_image_embeddings(self, batched_images) -> torch.Tensor:
+ """
+ Computes image embeddings for the provided images by running
+ the image encoder on the preprocessed inputs. Prompts are not
+ needed here; they are consumed later by predict_masks.
+
+ Arguments:
+ batched_images: A tensor of shape [B, 3, H, W]
+ Returns:
+ Image embeddings of shape [B, C, H, W] produced by the
+ image encoder neck.
+ """
+ batched_images = self.preprocess(batched_images)
+ return self.image_encoder(batched_images)
+
+ def forward(
+ self,
+ batched_images: torch.Tensor,
+ batched_points: torch.Tensor,
+ batched_point_labels: torch.Tensor,
+ scale_to_original_image_size: bool = True,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Predicts masks end-to-end from provided images and prompts.
+ If prompts are not known in advance, using SamPredictor is
+ recommended over calling the model directly.
+
+ Arguments:
+ batched_images: A tensor of shape [B, 3, H, W]
+ batched_points: A tensor of shape [B, num_queries, max_num_pts, 2]
+ batched_point_labels: A tensor of shape [B, num_queries, max_num_pts]
+
+ Returns:
+ A tuple of two tensors:
+ masks: A tensor of shape [B, num_queries, num_predictions, H, W] of predicted masks
+ iou_predictions: A tensor of shape [B, num_queries, num_predictions] of estimated IoU scores
+ """
+ batch_size, _, input_h, input_w = batched_images.shape
+ image_embeddings = self.get_image_embeddings(batched_images)
+ return self.predict_masks(
+ image_embeddings,
+ batched_points,
+ batched_point_labels,
+ multimask_output=True,
+ input_h=input_h,
+ input_w=input_w,
+ output_h=input_h if scale_to_original_image_size else -1,
+ output_w=input_w if scale_to_original_image_size else -1,
+ )
+
+ def preprocess(self, x: torch.Tensor) -> torch.Tensor:
+ """Normalize pixel values and pad to a square input."""
+ if (
+ x.shape[2] != self.image_encoder.img_size
+ or x.shape[3] != self.image_encoder.img_size
+ ):
+ x = F.interpolate(
+ x,
+ (self.image_encoder.img_size, self.image_encoder.img_size),
+ mode="bilinear",
+ )
+ return (x - self.pixel_mean) / self.pixel_std
+
+
+def build_efficient_sam(encoder_patch_embed_dim, encoder_num_heads, checkpoint=None):
+ img_size = 1024
+ encoder_patch_size = 16
+ encoder_depth = 12
+ encoder_mlp_ratio = 4.0
+ encoder_neck_dims = [256, 256]
+ decoder_max_num_input_points = 6
+ decoder_transformer_depth = 2
+ decoder_transformer_mlp_dim = 2048
+ decoder_num_heads = 8
+ decoder_upscaling_layer_dims = [64, 32]
+ num_multimask_outputs = 3
+ iou_head_depth = 3
+ iou_head_hidden_dim = 256
+ activation = "gelu"
+ normalization_type = "layer_norm"
+ normalize_before_activation = False
+
+ assert activation == "relu" or activation == "gelu"
+ if activation == "relu":
+ activation_fn = nn.ReLU
+ else:
+ activation_fn = nn.GELU
+
+ image_encoder = ImageEncoderViT(
+ img_size=img_size,
+ patch_size=encoder_patch_size,
+ in_chans=3,
+ patch_embed_dim=encoder_patch_embed_dim,
+ normalization_type=normalization_type,
+ depth=encoder_depth,
+ num_heads=encoder_num_heads,
+ mlp_ratio=encoder_mlp_ratio,
+ neck_dims=encoder_neck_dims,
+ act_layer=activation_fn,
+ )
+
+ image_embedding_size = image_encoder.image_embedding_size
+ encoder_transformer_output_dim = image_encoder.transformer_output_dim
+
+ sam = EfficientSam(
+ image_encoder=image_encoder,
+ prompt_encoder=PromptEncoder(
+ embed_dim=encoder_transformer_output_dim,
+ image_embedding_size=(image_embedding_size, image_embedding_size),
+ input_image_size=(img_size, img_size),
+ ),
+ decoder_max_num_input_points=decoder_max_num_input_points,
+ mask_decoder=MaskDecoder(
+ transformer_dim=encoder_transformer_output_dim,
+ transformer=TwoWayTransformer(
+ depth=decoder_transformer_depth,
+ embedding_dim=encoder_transformer_output_dim,
+ num_heads=decoder_num_heads,
+ mlp_dim=decoder_transformer_mlp_dim,
+ activation=activation_fn,
+ normalize_before_activation=normalize_before_activation,
+ ),
+ num_multimask_outputs=num_multimask_outputs,
+ activation=activation_fn,
+ normalization_type=normalization_type,
+ normalize_before_activation=normalize_before_activation,
+ iou_head_depth=iou_head_depth - 1,
+ iou_head_hidden_dim=iou_head_hidden_dim,
+ upscaling_layer_dims=decoder_upscaling_layer_dims,
+ ),
+ pixel_mean=[0.485, 0.456, 0.406],
+ pixel_std=[0.229, 0.224, 0.225],
+ )
+ if checkpoint is not None:
+ with open(checkpoint, "rb") as f:
+ state_dict = torch.load(f, map_location="cpu")
+ sam.load_state_dict(state_dict["model"])
+ return sam
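A sketch of the end-to-end forward API defined above, using a single box prompt; the labels follow PromptEncoder._embed_points (1 = point, 2 = box top-left, 3 = box bottom-right, -1 = padding). The image tensor is random and only illustrates the expected shapes:

import torch
from sam.efficient_sam.build_efficient_sam import build_efficient_sam_vits

model = build_efficient_sam_vits()
image = torch.rand(1, 3, 1024, 1024)                         # [B, 3, H, W], values in [0, 1]
points = torch.tensor([[[[200.0, 200.0], [600.0, 600.0]]]])  # [B, num_queries, num_pts, 2]
labels = torch.tensor([[[2, 3]]])                            # 2 = box top-left, 3 = bottom-right
with torch.no_grad():
    masks, iou = model(image, points, labels)
# masks: [1, 1, 3, 1024, 1024], sorted by predicted IoU; iou: [1, 1, 3]
best_mask = masks[0, 0, 0] >= model.mask_threshold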
diff --git a/sam/efficient_sam/efficient_sam_decoder.py b/sam/efficient_sam/efficient_sam_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..380f41c1650f1ffb824d8d911f810eabedc66ddd
--- /dev/null
+++ b/sam/efficient_sam/efficient_sam_decoder.py
@@ -0,0 +1,315 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import List, Tuple, Type
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .mlp import MLPBlock
+
+
+class PromptEncoder(nn.Module):
+ def __init__(
+ self,
+ embed_dim: int,
+ image_embedding_size: Tuple[int, int],
+ input_image_size: Tuple[int, int],
+ ) -> None:
+ """
+ Encodes prompts for input to SAM's mask decoder.
+
+ Arguments:
+ embed_dim (int): The prompts' embedding dimension
+ image_embedding_size (tuple(int, int)): The spatial size of the
+ image embedding, as (H, W).
+ input_image_size (int): The padded size of the image as input
+ to the image encoder, as (H, W).
+ """
+ super().__init__()
+ self.embed_dim = embed_dim
+ self.input_image_size = input_image_size
+ self.image_embedding_size = image_embedding_size
+ self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
+ self.invalid_points = nn.Embedding(1, embed_dim)
+ self.point_embeddings = nn.Embedding(1, embed_dim)
+ self.bbox_top_left_embeddings = nn.Embedding(1, embed_dim)
+ self.bbox_bottom_right_embeddings = nn.Embedding(1, embed_dim)
+
+ def get_dense_pe(self) -> torch.Tensor:
+ """
+ Returns the positional encoding used to encode point prompts,
+ applied to a dense set of points the shape of the image encoding.
+
+ Returns:
+ torch.Tensor: Positional encoding with shape
+ 1x(embed_dim)x(embedding_h)x(embedding_w)
+ """
+ return self.pe_layer(self.image_embedding_size).unsqueeze(0)
+
+ def _embed_points(
+ self,
+ points: torch.Tensor,
+ labels: torch.Tensor,
+ ) -> torch.Tensor:
+ """Embeds point prompts."""
+ points = points + 0.5 # Shift to center of pixel
+ point_embedding = self.pe_layer.forward_with_coords(
+ points, self.input_image_size
+ )
+ invalid_label_ids = torch.eq(labels, -1)[:,:,None]
+ point_label_ids = torch.eq(labels, 1)[:,:,None]
+ topleft_label_ids = torch.eq(labels, 2)[:,:,None]
+ bottomright_label_ids = torch.eq(labels, 3)[:,:,None]
+ point_embedding = point_embedding + self.invalid_points.weight[:,None,:] * invalid_label_ids
+ point_embedding = point_embedding + self.point_embeddings.weight[:,None,:] * point_label_ids
+ point_embedding = point_embedding + self.bbox_top_left_embeddings.weight[:,None,:] * topleft_label_ids
+ point_embedding = point_embedding + self.bbox_bottom_right_embeddings.weight[:,None,:] * bottomright_label_ids
+ return point_embedding
+
+ def forward(
+ self,
+ coords,
+ labels,
+ ) -> torch.Tensor:
+ """
+ Embeds different types of prompts, returning both sparse and dense
+ embeddings.
+
+ Arguments:
+ points: A tensor of shape [B, 2]
+ labels: An integer tensor of shape [B] where each element is 1,2 or 3.
+
+ Returns:
+ torch.Tensor: sparse embeddings for the points and boxes, with shape
+ BxNx(embed_dim), where N is determined by the number of input points
+ and boxes.
+ """
+ return self._embed_points(coords, labels)
+
+
+class PositionEmbeddingRandom(nn.Module):
+ """
+ Positional encoding using random spatial frequencies.
+ """
+
+ def __init__(self, num_pos_feats: int) -> None:
+ super().__init__()
+ self.register_buffer(
+ "positional_encoding_gaussian_matrix", torch.randn((2, num_pos_feats))
+ )
+
+ def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
+ """Positionally encode points that are normalized to [0,1]."""
+ # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
+ coords = 2 * coords - 1
+ coords = coords @ self.positional_encoding_gaussian_matrix
+ coords = 2 * np.pi * coords
+ # outputs d_1 x ... x d_n x C shape
+ return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
+
+ def forward(self, size: Tuple[int, int]) -> torch.Tensor:
+ """Generate positional encoding for a grid of the specified size."""
+ h, w = size
+ device = self.positional_encoding_gaussian_matrix.device
+ grid = torch.ones([h, w], device=device, dtype=torch.float32)
+ y_embed = grid.cumsum(dim=0) - 0.5
+ x_embed = grid.cumsum(dim=1) - 0.5
+ y_embed = y_embed / h
+ x_embed = x_embed / w
+
+ pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
+ return pe.permute(2, 0, 1) # C x H x W
+
+ def forward_with_coords(
+ self, coords_input: torch.Tensor, image_size: Tuple[int, int]
+ ) -> torch.Tensor:
+ """Positionally encode points that are not normalized to [0,1]."""
+ coords = coords_input.clone()
+ coords[:, :, 0] = coords[:, :, 0] / image_size[1]
+ coords[:, :, 1] = coords[:, :, 1] / image_size[0]
+ return self._pe_encoding(coords.to(torch.float)) # B x N x C
+
+
+class MaskDecoder(nn.Module):
+ def __init__(
+ self,
+ *,
+ transformer_dim: int,
+ transformer: nn.Module,
+ num_multimask_outputs: int,
+ activation: Type[nn.Module],
+ normalization_type: str,
+ normalize_before_activation: bool,
+ iou_head_depth: int,
+ iou_head_hidden_dim: int,
+ upscaling_layer_dims: List[int],
+ ) -> None:
+ """
+ Predicts masks given an image and prompt embeddings, using a
+ transformer architecture.
+
+ Arguments:
+ transformer_dim (int): the channel dimension of the transformer
+ transformer (nn.Module): the transformer used to predict masks
+ num_multimask_outputs (int): the number of masks to predict
+ when disambiguating masks
+ activation (nn.Module): the type of activation to use when
+ upscaling masks
+ iou_head_depth (int): the depth of the MLP used to predict
+ mask quality
+ iou_head_hidden_dim (int): the hidden dimension of the MLP
+ used to predict mask quality
+ """
+ super().__init__()
+ self.transformer_dim = transformer_dim
+ self.transformer = transformer
+
+ self.num_multimask_outputs = num_multimask_outputs
+
+ self.iou_token = nn.Embedding(1, transformer_dim)
+ if num_multimask_outputs > 1:
+ self.num_mask_tokens = num_multimask_outputs + 1
+ else:
+ self.num_mask_tokens = 1
+ self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
+ output_dim_after_upscaling = transformer_dim
+
+ self.final_output_upscaling_layers = nn.ModuleList([])
+ for idx, layer_dims in enumerate(upscaling_layer_dims):
+ self.final_output_upscaling_layers.append(
+ nn.Sequential(
+ nn.ConvTranspose2d(
+ output_dim_after_upscaling,
+ layer_dims,
+ kernel_size=2,
+ stride=2,
+ ),
+ nn.GroupNorm(1, layer_dims)
+ if idx < len(upscaling_layer_dims) - 1
+ else nn.Identity(),
+ activation(),
+ )
+ )
+ output_dim_after_upscaling = layer_dims
+
+ self.output_hypernetworks_mlps = nn.ModuleList(
+ [
+ MLPBlock(
+ input_dim=transformer_dim,
+ hidden_dim=transformer_dim,
+ output_dim=output_dim_after_upscaling,
+ num_layers=2,
+ act=activation,
+ )
+ for i in range(self.num_mask_tokens)
+ ]
+ )
+
+ self.iou_prediction_head = MLPBlock(
+ input_dim=transformer_dim,
+ hidden_dim=iou_head_hidden_dim,
+ output_dim=self.num_mask_tokens,
+ num_layers=iou_head_depth,
+ act=activation,
+ )
+
+ def forward(
+ self,
+ image_embeddings: torch.Tensor,
+ image_pe: torch.Tensor,
+ sparse_prompt_embeddings: torch.Tensor,
+ multimask_output: bool,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Predict masks given image and prompt embeddings.
+
+ Arguments:
+ image_embeddings: A tensor of shape [B, C, H, W] or [B*max_num_queries, C, H, W]
+ image_pe (torch.Tensor): positional encoding with the shape of image_embeddings (the batch dimension is broadcastable).
+ sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
+ multimask_output (bool): Whether to return multiple masks or a single
+ mask.
+
+ Returns:
+ torch.Tensor: batched predicted masks
+ torch.Tensor: batched predictions of mask quality
+ """
+
+ (
+ batch_size,
+ max_num_queries,
+ sparse_embed_dim_1,
+ sparse_embed_dim_2,
+ ) = sparse_prompt_embeddings.shape
+
+ (
+ _,
+ image_embed_dim_c,
+ image_embed_dim_h,
+ image_embed_dim_w,
+ ) = image_embeddings.shape
+
+ # Tile the image embedding for all queries.
+ image_embeddings_tiled = torch.tile(
+ image_embeddings[:, None, :, :, :], [1, max_num_queries, 1, 1, 1]
+ ).view(
+ batch_size * max_num_queries,
+ image_embed_dim_c,
+ image_embed_dim_h,
+ image_embed_dim_w,
+ )
+ sparse_prompt_embeddings = sparse_prompt_embeddings.reshape(
+ batch_size * max_num_queries, sparse_embed_dim_1, sparse_embed_dim_2
+ )
+ masks, iou_pred = self.predict_masks(
+ image_embeddings=image_embeddings_tiled,
+ image_pe=image_pe,
+ sparse_prompt_embeddings=sparse_prompt_embeddings,
+ )
+ if multimask_output and self.num_multimask_outputs > 1:
+ return masks[:, 1:, :], iou_pred[:, 1:]
+ else:
+ return masks[:, :1, :], iou_pred[:, :1]
+
+ def predict_masks(
+ self,
+ image_embeddings: torch.Tensor,
+ image_pe: torch.Tensor,
+ sparse_prompt_embeddings: torch.Tensor,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Predicts masks. See 'forward' for more details."""
+ # Concatenate output tokens
+ output_tokens = torch.cat(
+ [self.iou_token.weight, self.mask_tokens.weight], dim=0
+ )
+ output_tokens = output_tokens.unsqueeze(0).expand(
+ sparse_prompt_embeddings.size(0), -1, -1
+ )
+ tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
+ # Expand per-image data in batch direction to be per-mask
+ pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
+ b, c, h, w = image_embeddings.shape
+ hs, src = self.transformer(image_embeddings, pos_src, tokens)
+ iou_token_out = hs[:, 0, :]
+ mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]
+
+ # Upscale mask embeddings and predict masks using the mask tokens
+ upscaled_embedding = src.transpose(1, 2).view(b, c, h, w)
+
+ for upscaling_layer in self.final_output_upscaling_layers:
+ upscaled_embedding = upscaling_layer(upscaled_embedding)
+ hyper_in_list: List[torch.Tensor] = []
+ for i, output_hypernetworks_mlp in enumerate(self.output_hypernetworks_mlps):
+ hyper_in_list.append(output_hypernetworks_mlp(mask_tokens_out[:, i, :]))
+ hyper_in = torch.stack(hyper_in_list, dim=1)
+ b, c, h, w = upscaled_embedding.shape
+ masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
+ # Generate mask quality predictions
+ iou_pred = self.iou_prediction_head(iou_token_out)
+ return masks, iou_pred
diff --git a/sam/efficient_sam/efficient_sam_encoder.py b/sam/efficient_sam/efficient_sam_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..73fd7ac470b42df738e5e6bcbbcb60b4f30fb46e
--- /dev/null
+++ b/sam/efficient_sam/efficient_sam_encoder.py
@@ -0,0 +1,257 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+from typing import List, Optional, Tuple, Type
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class LayerNorm2d(nn.Module):
+ def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(num_channels))
+ self.bias = nn.Parameter(torch.zeros(num_channels))
+ self.eps = eps
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ u = x.mean(1, keepdim=True)
+ s = (x - u).pow(2).mean(1, keepdim=True)
+ x = (x - u) / torch.sqrt(s + self.eps)
+ x = self.weight[:, None, None] * x + self.bias[:, None, None]
+ return x
+
+
+class PatchEmbed(nn.Module):
+ """2D Image to Patch Embedding"""
+
+ def __init__(
+ self,
+ img_size,
+ patch_size,
+ in_chans,
+ embed_dim,
+ ):
+ super().__init__()
+ self.proj = nn.Conv2d(
+ in_chans,
+ embed_dim,
+ kernel_size=(patch_size, patch_size),
+ stride=(patch_size, patch_size),
+ bias=True,
+ )
+
+ def forward(self, x):
+ B, C, H, W = x.shape
+ x = self.proj(x)
+ return x
+
+
+class Attention(nn.Module):
+ def __init__(
+ self,
+ dim,
+ num_heads,
+ qkv_bias,
+ qk_scale=None,
+ ):
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = qk_scale or head_dim**-0.5
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.proj = nn.Linear(dim, dim)
+
+ def forward(self, x):
+ B, N, C = x.shape
+ qkv = (
+ self.qkv(x)
+ .reshape(B, N, 3, self.num_heads, C // self.num_heads)
+ .permute(2, 0, 3, 1, 4)
+ )
+ q, k, v = (
+ qkv[0],
+ qkv[1],
+ qkv[2],
+ )
+ attn = (q @ k.transpose(-2, -1)) * self.scale
+ attn = attn.softmax(dim=-1)
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ return x
+
+
+class Mlp(nn.Module):
+ def __init__(
+ self,
+ in_features,
+ hidden_features=None,
+ out_features=None,
+ act_layer=nn.GELU,
+ ):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.fc2(x)
+ return x
+
+
+class Block(nn.Module):
+ def __init__(
+ self,
+ dim,
+ num_heads,
+ mlp_ratio=4.0,
+ qkv_bias=False,
+ qk_scale=None,
+ act_layer=nn.GELU,
+ ):
+ super().__init__()
+ self.norm1 = nn.LayerNorm(dim, eps=1e-6)
+ self.attn = Attention(
+ dim,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ )
+ self.norm2 = nn.LayerNorm(dim, eps=1e-6)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(
+ in_features=dim,
+ hidden_features=mlp_hidden_dim,
+ act_layer=act_layer,
+ )
+
+ def forward(self, x):
+ x = x + self.attn(self.norm1(x))
+ x = x + self.mlp(self.norm2(x))
+ return x
+
+
+@torch.jit.export
+def get_abs_pos(
+ abs_pos: torch.Tensor, has_cls_token: bool, hw: List[int]
+) -> torch.Tensor:
+ """
+ Calculate absolute positional embeddings. If needed, resize embeddings and remove cls_token
+ dimension for the original embeddings.
+ Args:
+ abs_pos (Tensor): absolute positional embeddings with (1, num_position, C).
+ has_cls_token (bool): If true, has 1 embedding in abs_pos for cls token.
+ hw (Tuple): size of input image tokens.
+
+ Returns:
+ Absolute positional embeddings after processing with shape (1, H, W, C)
+ """
+ h = hw[0]
+ w = hw[1]
+ if has_cls_token:
+ abs_pos = abs_pos[:, 1:]
+ xy_num = abs_pos.shape[1]
+ size = int(math.sqrt(xy_num))
+ assert size * size == xy_num
+
+ if size != h or size != w:
+ new_abs_pos = F.interpolate(
+ abs_pos.reshape(1, size, size, -1).permute(0, 3, 1, 2),
+ size=(h, w),
+ mode="bicubic",
+ align_corners=False,
+ )
+ return new_abs_pos.permute(0, 2, 3, 1)
+ else:
+ return abs_pos.reshape(1, h, w, -1)
+
+
+# Image encoder for efficient SAM.
+class ImageEncoderViT(nn.Module):
+ def __init__(
+ self,
+ img_size: int,
+ patch_size: int,
+ in_chans: int,
+ patch_embed_dim: int,
+ normalization_type: str,
+ depth: int,
+ num_heads: int,
+ mlp_ratio: float,
+ neck_dims: List[int],
+ act_layer: Type[nn.Module],
+ ) -> None:
+ """
+ Args:
+ img_size (int): Input image size.
+ patch_size (int): Patch size.
+ in_chans (int): Number of input image channels.
+ patch_embed_dim (int): Patch embedding dimension.
+ depth (int): Depth of ViT.
+ num_heads (int): Number of attention heads in each ViT block.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+ act_layer (nn.Module): Activation layer.
+ """
+ super().__init__()
+
+ self.img_size = img_size
+ self.image_embedding_size = img_size // ((patch_size if patch_size > 0 else 1))
+ self.transformer_output_dim = ([patch_embed_dim] + neck_dims)[-1]
+ self.pretrain_use_cls_token = True
+ pretrain_img_size = 224
+ self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, patch_embed_dim)
+ # Initialize absolute positional embedding with pretrain image size.
+ num_patches = (pretrain_img_size // patch_size) * (
+ pretrain_img_size // patch_size
+ )
+ num_positions = num_patches + 1
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_positions, patch_embed_dim))
+ self.blocks = nn.ModuleList()
+ for i in range(depth):
+ vit_block = Block(patch_embed_dim, num_heads, mlp_ratio, True)
+ self.blocks.append(vit_block)
+ self.neck = nn.Sequential(
+ nn.Conv2d(
+ patch_embed_dim,
+ neck_dims[0],
+ kernel_size=1,
+ bias=False,
+ ),
+ LayerNorm2d(neck_dims[0]),
+ nn.Conv2d(
+ neck_dims[0],
+ neck_dims[0],
+ kernel_size=3,
+ padding=1,
+ bias=False,
+ ),
+ LayerNorm2d(neck_dims[0]),
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ assert (
+ x.shape[2] == self.img_size and x.shape[3] == self.img_size
+ ), "input image size must match self.img_size"
+ x = self.patch_embed(x)
+ # B C H W -> B H W C
+ x = x.permute(0, 2, 3, 1)
+ x = x + get_abs_pos(
+ self.pos_embed, self.pretrain_use_cls_token, [x.shape[1], x.shape[2]]
+ )
+ num_patches = x.shape[1]
+ assert x.shape[2] == num_patches
+ x = x.reshape(x.shape[0], num_patches * num_patches, x.shape[3])
+ for blk in self.blocks:
+ x = blk(x)
+ x = x.reshape(x.shape[0], num_patches, num_patches, x.shape[2])
+ x = self.neck(x.permute(0, 3, 1, 2))
+ return x
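With the default build (1024-pixel input, 16-pixel patches, 256-channel neck), the encoder above produces a [B, 256, 64, 64] embedding. A sketch of the two-stage flow that EfficientSam exposes on top of it, computing the embedding once and then decoding a prompt against it (random image, single positive click at the center):

import torch
from sam.efficient_sam.build_efficient_sam import build_efficient_sam_vits

model = build_efficient_sam_vits()
image = torch.rand(1, 3, 1024, 1024)
with torch.no_grad():
    embeddings = model.get_image_embeddings(image)   # [1, 256, 64, 64]
    points = torch.tensor([[[[512.0, 512.0]]]])      # [B, num_queries, num_pts, 2]
    labels = torch.tensor([[[1]]])                   # label 1 = positive point
    masks, iou = model.predict_masks(
        embeddings, points, labels,
        multimask_output=True,
        input_h=1024, input_w=1024,
        output_h=1024, output_w=1024,
    )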
diff --git a/sam/efficient_sam/mlp.py b/sam/efficient_sam/mlp.py
new file mode 100644
index 0000000000000000000000000000000000000000..b3be8db49cbf6990ce55a467e7e62f60daf62c9d
--- /dev/null
+++ b/sam/efficient_sam/mlp.py
@@ -0,0 +1,29 @@
+from typing import Type
+
+from torch import nn
+
+
+# Lightly adapted from
+# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
+class MLPBlock(nn.Module):
+ def __init__(
+ self,
+ input_dim: int,
+ hidden_dim: int,
+ output_dim: int,
+ num_layers: int,
+ act: Type[nn.Module],
+ ) -> None:
+ super().__init__()
+ self.num_layers = num_layers
+ h = [hidden_dim] * (num_layers - 1)
+ self.layers = nn.ModuleList(
+ nn.Sequential(nn.Linear(n, k), act())
+ for n, k in zip([input_dim] + h, [hidden_dim] * num_layers)
+ )
+ self.fc = nn.Linear(hidden_dim, output_dim)
+
+ def forward(self, x):
+ for layer in self.layers:
+ x = layer(x)
+ return self.fc(x)
diff --git a/sam/efficient_sam/two_way_transformer.py b/sam/efficient_sam/two_way_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..881e76fd7efd07eeeef12999931fc2b74db406a9
--- /dev/null
+++ b/sam/efficient_sam/two_way_transformer.py
@@ -0,0 +1,264 @@
+import math
+from typing import Tuple, Type
+import torch
+from torch import nn, Tensor
+from .mlp import MLPBlock
+
+
+
+
+class TwoWayTransformer(nn.Module):
+ def __init__(
+ self,
+ depth: int,
+ embedding_dim: int,
+ num_heads: int,
+ mlp_dim: int,
+ activation: Type[nn.Module],
+ normalize_before_activation: bool,
+ attention_downsample_rate: int = 2,
+ ) -> None:
+ """
+ A transformer decoder that attends to an input image using
+ queries whose positional embedding is supplied.
+
+ Args:
+ depth (int): number of layers in the transformer
+ embedding_dim (int): the channel dimension for the input embeddings
+ num_heads (int): the number of heads for multihead attention. Must
+ divide embedding_dim
+ mlp_dim (int): the channel dimension internal to the MLP block
+ activation (nn.Module): the activation to use in the MLP block
+ """
+ super().__init__()
+ self.depth = depth
+ self.embedding_dim = embedding_dim
+ self.num_heads = num_heads
+ self.mlp_dim = mlp_dim
+ self.layers = nn.ModuleList()
+
+ for i in range(depth):
+ curr_layer = TwoWayAttentionBlock(
+ embedding_dim=embedding_dim,
+ num_heads=num_heads,
+ mlp_dim=mlp_dim,
+ activation=activation,
+ normalize_before_activation=normalize_before_activation,
+ attention_downsample_rate=attention_downsample_rate,
+ skip_first_layer_pe=(i == 0),
+ )
+ self.layers.append(curr_layer)
+
+ self.final_attn_token_to_image = AttentionForTwoWayAttentionBlock(
+ embedding_dim,
+ num_heads,
+ downsample_rate=attention_downsample_rate,
+ )
+ self.norm_final_attn = nn.LayerNorm(embedding_dim)
+
+ def forward(
+ self,
+ image_embedding: Tensor,
+ image_pe: Tensor,
+ point_embedding: Tensor,
+ ) -> Tuple[Tensor, Tensor]:
+ """
+ Args:
+ image_embedding (torch.Tensor): image to attend to. Should be shape
+ B x embedding_dim x h x w for any h and w.
+ image_pe (torch.Tensor): the positional encoding to add to the image. Must
+ have the same shape as image_embedding.
+ point_embedding (torch.Tensor): the embedding to add to the query points.
+ Must have shape B x N_points x embedding_dim for any N_points.
+
+ Returns:
+ torch.Tensor: the processed point_embedding
+ torch.Tensor: the processed image_embedding
+ """
+
+ # BxCxHxW -> BxHWxC == B x N_image_tokens x C
+ bs, c, h, w = image_embedding.shape
+ image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
+ image_pe = image_pe.flatten(2).permute(0, 2, 1)
+
+ # Prepare queries
+ queries = point_embedding
+ keys = image_embedding
+
+ # Apply transformer blocks and final layernorm
+ for idx, layer in enumerate(self.layers):
+ queries, keys = layer(
+ queries=queries,
+ keys=keys,
+ query_pe=point_embedding,
+ key_pe=image_pe,
+ )
+
+ # Apply the final attention layer from the points to the image
+ q = queries + point_embedding
+ k = keys + image_pe
+ attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
+ queries = queries + attn_out
+ queries = self.norm_final_attn(queries)
+ return queries, keys
+
+
+class TwoWayAttentionBlock(nn.Module):
+ def __init__(
+ self,
+ embedding_dim: int,
+ num_heads: int,
+ mlp_dim: int,
+ activation: Type[nn.Module],
+ normalize_before_activation: bool,
+ attention_downsample_rate: int = 2,
+ skip_first_layer_pe: bool = False,
+ ) -> None:
+ """
+ A transformer block with four layers: (1) self-attention of sparse
+ inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp
+ block on sparse inputs, and (4) cross attention of dense inputs to sparse
+ inputs.
+
+ Arguments:
+ embedding_dim (int): the channel dimension of the embeddings
+ num_heads (int): the number of heads in the attention layers
+ mlp_dim (int): the hidden dimension of the mlp block
+ activation (nn.Module): the activation of the mlp block
+ skip_first_layer_pe (bool): skip the PE on the first layer
+ """
+ super().__init__()
+ self.self_attn = AttentionForTwoWayAttentionBlock(embedding_dim, num_heads)
+ self.norm1 = nn.LayerNorm(embedding_dim)
+
+ self.cross_attn_token_to_image = AttentionForTwoWayAttentionBlock(
+ embedding_dim,
+ num_heads,
+ downsample_rate=attention_downsample_rate,
+ )
+ self.norm2 = nn.LayerNorm(embedding_dim)
+
+ self.mlp = MLPBlock(
+ embedding_dim,
+ mlp_dim,
+ embedding_dim,
+ 1,
+ activation,
+ )
+
+ self.norm3 = nn.LayerNorm(embedding_dim)
+
+ self.norm4 = nn.LayerNorm(embedding_dim)
+ self.cross_attn_image_to_token = AttentionForTwoWayAttentionBlock(
+ embedding_dim,
+ num_heads,
+ downsample_rate=attention_downsample_rate,
+ )
+
+ self.skip_first_layer_pe = skip_first_layer_pe
+
+ def forward(
+ self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
+ ) -> Tuple[Tensor, Tensor]:
+ # Self attention block
+ if not self.skip_first_layer_pe:
+ queries = queries + query_pe
+ attn_out = self.self_attn(q=queries, k=queries, v=queries)
+ queries = queries + attn_out
+ queries = self.norm1(queries)
+
+ # Cross attention block, tokens attending to image embedding
+ q = queries + query_pe
+ k = keys + key_pe
+ attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
+ queries = queries + attn_out
+ queries = self.norm2(queries)
+
+ # MLP block
+ mlp_out = self.mlp(queries)
+ queries = queries + mlp_out
+ queries = self.norm3(queries)
+
+ # Cross attention block, image embedding attending to tokens
+ q = queries + query_pe
+ k = keys + key_pe
+ attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
+ keys = keys + attn_out
+ keys = self.norm4(keys)
+
+ return queries, keys
+
+
+class AttentionForTwoWayAttentionBlock(nn.Module):
+ """
+ An attention layer that allows for downscaling the size of the embedding
+ after projection to queries, keys, and values.
+ """
+
+ def __init__(
+ self,
+ embedding_dim: int,
+ num_heads: int,
+ downsample_rate: int = 1,
+ ) -> None:
+ super().__init__()
+ self.embedding_dim = embedding_dim
+ self.internal_dim = embedding_dim // downsample_rate
+ self.num_heads = num_heads
+ assert (
+ self.internal_dim % num_heads == 0
+ ), "num_heads must divide embedding_dim."
+
+ self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
+ self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
+ self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
+ self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
+ self._reset_parameters()
+
+ def _reset_parameters(self) -> None:
+ # The fan_out is incorrect, but matches pytorch's initialization
+ # for which qkv is a single 3*embedding_dim x embedding_dim matrix
+ fan_in = self.embedding_dim
+ fan_out = 3 * self.internal_dim
+ # Xavier uniform with our custom fan_out
+ bnd = math.sqrt(6 / (fan_in + fan_out))
+ nn.init.uniform_(self.q_proj.weight, -bnd, bnd)
+ nn.init.uniform_(self.k_proj.weight, -bnd, bnd)
+ nn.init.uniform_(self.v_proj.weight, -bnd, bnd)
+ # out_proj.weight is left with default initialization, like pytorch attention
+ nn.init.zeros_(self.q_proj.bias)
+ nn.init.zeros_(self.k_proj.bias)
+ nn.init.zeros_(self.v_proj.bias)
+ nn.init.zeros_(self.out_proj.bias)
+
+ def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
+ b, n, c = x.shape
+ x = x.reshape(b, n, num_heads, c // num_heads)
+ return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head
+
+ def _recombine_heads(self, x: Tensor) -> Tensor:
+ b, n_heads, n_tokens, c_per_head = x.shape
+ x = x.transpose(1, 2)
+ return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C
+
+ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
+ # Input projections
+ q = self.q_proj(q)
+ k = self.k_proj(k)
+ v = self.v_proj(v)
+
+ # Separate into heads
+ q = self._separate_heads(q, self.num_heads)
+ k = self._separate_heads(k, self.num_heads)
+ v = self._separate_heads(v, self.num_heads)
+
+ # Attention
+ _, _, _, c_per_head = q.shape
+ attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens
+ attn = attn / math.sqrt(c_per_head)
+ attn = torch.softmax(attn, dim=-1)
+ # Get output
+ out = attn @ v
+ out = self._recombine_heads(out)
+ out = self.out_proj(out)
+ return out
diff --git a/src/demo/demo.py b/src/demo/demo.py
new file mode 100644
index 0000000000000000000000000000000000000000..166b2b03f28b71b0b3ea8e81e808ffaa9bc6f5da
--- /dev/null
+++ b/src/demo/demo.py
@@ -0,0 +1,738 @@
+import gradio as gr
+import numpy as np
+from src.demo.utils import get_point, store_img, get_point_move, store_img_move, clear_points, upload_image_move, segment_with_points, segment_with_points_paste, fun_clear, paste_with_mask_and_offset
+import spaces
+
+examples_remove = [
+ [
+ "examples/remove/02_ring/0.jpg", # original image 1
+ "examples/remove/02_ring/mask0.jpg", # mask 1
+ "examples/remove/02_ring/0.jpg", # original image 2
+ "examples/remove/02_ring/mask1.jpg", #mask 2
+ "examples/remove/02_ring/0.jpg", #Original image 3
+ "examples/remove/02_ring/mask2.jpg", #mask 3
+ None, #Original image 4
+ None, # refine mask
+ ], # 02
+ [
+ "examples/remove/01_moto/0.jpg", # original image 1
+ "examples/remove/01_moto/mask0.jpg", # mask 1
+ "examples/remove/01_moto/0.jpg", # original image 2
+ None, #mask 2
+ "examples/remove/01_moto/0.jpg", #Original image 3
+ None, #mask 3
+ "examples/remove/01_moto/0.jpg", #Original image 4
+ "examples/remove/01_moto/mask1.jpg", # refine mask
+ ], # 01
+ [
+ "examples/remove/03_ball/0.jpg", # original image 1
+ "examples/remove/03_ball/mask0.jpg", # mask 1
+ "examples/remove/03_ball/0.jpg", # original image 2
+ "examples/remove/03_ball/mask1.jpg", #mask 2
+ "examples/remove/03_ball/0.jpg", #Original image 3
+ None, #mask 3
+ None, #Original image 4
+ None, # refine mask
+ ], # 03
+ [
+ "examples/remove/04_pikachu/0.jpg", # original image 1
+ "examples/remove/04_pikachu/mask0.jpg", # mask 1
+ "examples/remove/04_pikachu/0.jpg", # original image 2
+ "examples/remove/04_pikachu/mask1.jpg", #mask 2
+ "examples/remove/04_pikachu/0.jpg", #Original image 3
+ "examples/remove/04_pikachu/mask2.jpg", #mask 3
+ None, #Original image 4
+ None, # refine mask
+ ], # 04
+ [
+ "examples/remove/05_betty/0.jpg", # original image 1
+ "examples/remove/05_betty/mask0.jpg", # mask 1
+ None, # original image 2
+ None, #mask 2
+ None, #Original image 3
+ None, #mask 3
+ None, #Original image 4
+ None, # refine mask
+ ], # 05
+]
+examples_zoom = [
+ ["examples/zoom/01.jpg"],
+ ["examples/zoom/02.jpg"],
+ ["examples/zoom/03.jpg"],
+ ["examples/zoom/04.jpg"],
+ ["examples/zoom/05.jpg"],
+ ["examples/zoom/06.jpg"],
+ ["examples/zoom/07.jpg"],
+]
+examples_pan = [
+ ["examples/pan/01.jpg"],
+ ["examples/pan/02.jpg"],
+ ["examples/pan/03.jpg"],
+ ["examples/pan/04.jpg"],
+ ["examples/pan/05.jpg"],
+ ["examples/pan/06.jpg"],
+]
+
+examples_moving = [
+ [
+ "examples/layer/01_horse/00.jpg", #bg
+ "examples/layer/01_horse/mask0.jpg", #bg_mask
+ 0, 0, 1.2, "None", "left/right", #l1_dx, l1_dy, l1_resize
+ ],
+ [
+ "examples/moving/01_ball/0.jpg", #bg
+ "examples/moving/01_ball/mask0.jpg", #bg_mask
+ -0.2, -0.1, 0.8, "None", "None", #l1_dx, l1_dy, l1_resize
+ ],
+ [
+ "examples/moving/02_bell/0.jpg", #bg
+ "examples/moving/02_bell/mask0.jpg", #bg_mask
+ 0, 0, 0.75, "None", "None", #l1_dx, l1_dy, l1_resize
+ ],
+]
+examples_layer = [
+ [
+ "examples/layer/01_horse/00.jpg", #bg
+ "examples/layer/01_horse/mask0.jpg", #bg_mask
+
+ "examples/layer/01_horse/00.jpg", #l1
+ "examples/layer/01_horse/mask0.jpg", #l1_mask
+ -0.2, 0, 1, "None", "None", #l1_dx, l1_dy, l1_resize
+
+ "examples/layer/01_horse/00.jpg", #l2
+ "examples/layer/01_horse/mask0.jpg", #l2_mask
+ 0.2, 0, 1, "None", "None", #l2_dx, l2_dy, l2_resize
+
+ None, #l3
+ None, #l3_mask
+ 0, 0, 1, "None", "None", #l3_dx, l3_dy, l3_resize
+ ],
+
+ [
+ "examples/layer/02_baby/00.jpg", #bg
+ "examples/layer/02_baby/mask0.jpg", #bg_mask
+
+ "examples/layer/02_baby/00.jpg", #l1
+ "examples/layer/02_baby/mask1.jpg", #l1_mask
+ -0.35, 0, 1,"left/right", "None", #l1_dx, l1_dy, l1_resize
+
+ "examples/layer/02_baby/00.jpg", #l2
+ "examples/layer/02_baby/mask2.jpg", #l2_mask
+ 0.35, 0, 1, "left/right", "None", #l2_dx, l2_dy, l2_resize
+
+ None, #l3
+ None, #l3_mask
+ 0, 0, 1,"None", "None", #l3_dx, l3_dy, l3_resize
+ ],
+
+ [
+ "examples/layer/03_text/00.jpg", #bg
+ "examples/layer/03_text/mask0.jpg", #bg_mask
+
+ "examples/layer/03_text/01.jpg", #l1
+ "examples/layer/03_text/mask1.jpg", #l1_mask
+ 0.1, -0.1, 0.5, "None", "None",#l1_dx, l1_dy, l1_resize
+
+ None, #l2
+ None, #l2_mask
+ 0, 0, 1, "None", "None",#l2_dx, l2_dy, l2_resize
+
+ None, #l3
+ None, #l3_mask
+ 0, 0, 1,"None", "None", #l3_dx, l3_dy, l3_resize
+ ],
+ [
+ "examples/layer/04_cross/0.jpg", #bg
+ "examples/layer/04_cross/mask0.jpg", #bg_mask
+
+ "examples/layer/04_cross/2.jpg", #l1
+ "examples/layer/04_cross/mask2.jpg", #l1_mask
+ -0.1, -0.25, 0.5, "None", "None",#l1_dx, l1_dy, l1_resize
+
+ "examples/layer/04_cross/1.jpg", #l2
+ "examples/layer/04_cross/mask1.jpg", #l2_mask
+ -0.1, -0.15, 0.7, "None", "None",#l2_dx, l2_dy, l2_resize
+
+ "examples/layer/04_cross/3.jpg", #l3
+ "examples/layer/04_cross/mask3.jpg", #l3_mask
+ -0.1, -0.55, 0.5, "None", "None",#l3_dx, l3_dy, l3_resize
+ ],
+]
+examples_mask_box = [
+ [
+ "examples/mask_box/image1.jpg", # original image 1
+ "examples/mask_box/image2.jpg", # original image 1
+ "examples/mask_box/mask01.jpg", # original image 1
+ "examples/mask_box/mask02.jpg", # original image 1
+ "examples/mask_box/mask00.jpg", # original image 1
+ ]
+]
+
+# 01
+def create_demo_remove(runner=None):
+ DESCRIPTION = """
+ # Object Removal
+
+ ## Usage:
+
+ - Upload a source image, then draw a box (two clicks: top left and bottom right) to generate the mask for the selected object.
+ - You can mask more than one object by using Mask 2 and Mask 3.
+ - If you encounter artifacts, try masking the regions that cause them with the Refine Mask.
+ - Refer to the motorcycle example to see how the Refine Mask is used.
+ - Please clear the output before running a new example!
+ - For more irregular or composite masks, refer to the last tab: Mask Preparation.
+"""
+
+ with gr.Blocks() as demo:
+ original_image = gr.State(value=None)
+ img_with_mask = gr.State(value=None)
+
+ selected_points = gr.State([])
+ global_points = gr.State([])
+ global_point_label = gr.State([])
+
+ gr.Markdown(DESCRIPTION)
+
+ with gr.Row():
+ with gr.Column():
+ with gr.Group():
+ gr.Markdown("# INPUT")
+ # mask 0
+ gr.Markdown("## Select two points for Mask 1:")
+ gr.Markdown("the top left and the bottom right")
+ original_image_1 = gr.Image(sources='upload', label="Original image (Mask 1)", interactive=True, type="numpy")
+ # mask 1
+ gr.Markdown("## Option: Select two points for Mask 2")
+ gr.Markdown("the top left and the bottom right")
+ original_image_2 = gr.Image(sources='upload', label="Original (Mask 2)", interactive=True, type="numpy")
+ # mask 2
+ gr.Markdown("## Option: Select two points for Mask 3")
+ gr.Markdown("the top left and the bottom right")
+ original_image_3 = gr.Image(label="Original image (Mask 3)", interactive=True, type="numpy")
+
+ gr.Markdown("## Option: Mask regions caused artifacts")
+ gr.Markdown("the top left and the bottom right")
+ original_image_4 = gr.Image(label="Original image (Refine Mask)", interactive=True, type="numpy")
+ with gr.Row():
+ run_button = gr.Button("Edit")
+ clear_button = gr.Button("Clear")
+
+
+ with gr.Column():
+ with gr.Group():
+ gr.Markdown("# Mask")
+
+ gr.Markdown("## Removal Mask 1")
+ mask_1 = gr.Image(sources='upload', label="Removal Mask 1", interactive=True, type="numpy")
+ gr.Markdown("## Option: Removal Mask 2")
+ mask_2 = gr.Image(sources='upload', label="Removal Mask 2", interactive=True, type="numpy")
+ gr.Markdown("## Option: Removal Mask 3")
+ mask_3 = gr.Image(sources='upload', label="Removal Mask 3", interactive=True, type="numpy")
+
+ gr.Markdown("## Option: Refine Mask to avoid artifacts")
+ refine_mask = gr.Image(sources='upload', label="Refine Mask", interactive=True, type="numpy")
+
+ with gr.Column():
+ with gr.Group():
+ gr.Markdown("# OUTPUT")
+ gr.Markdown("## Results")
+ output = gr.Gallery(columns=1, height='auto')
+
+
+ original_image_1.select(
+ segment_with_points,
+ inputs=[original_image_1, original_image, global_points, global_point_label],
+ outputs=[original_image_1, original_image, mask_1, global_points, global_point_label]
+ )
+ original_image_2.select(
+ segment_with_points,
+ inputs=[original_image_2, original_image, global_points, global_point_label],
+ outputs=[original_image_2, original_image, mask_2, global_points, global_point_label]
+ )
+ original_image_3.select(
+ segment_with_points,
+ inputs=[original_image_3, original_image, global_points, global_point_label],
+ outputs=[original_image_3, original_image, mask_3, global_points, global_point_label]
+ )
+ original_image_4.select(
+ segment_with_points,
+ inputs=[original_image_4, original_image, global_points, global_point_label],
+ outputs=[original_image_4, original_image, refine_mask, global_points, global_point_label]
+ )
+
+ with gr.Column():
+ gr.Markdown("Try some of the examples below ⬇️")
+ gr.Examples(
+ examples=examples_remove,
+ inputs=[
+ original_image_1, mask_1,
+ original_image_2, mask_2,
+ original_image_3, mask_3,
+ original_image_4, refine_mask]
+ )
+ run_button.click(fn=runner, inputs=[original_image, mask_1, mask_2, mask_3, refine_mask,
+ original_image_1, original_image_2, original_image_3], outputs=[output])
+ clear_button.click(
+ fn=fun_clear,
+ inputs=[original_image, img_with_mask, selected_points, global_points, global_point_label, original_image_1, original_image_2, original_image_3, original_image_4, mask_1, mask_2, mask_3, refine_mask],
+ outputs=[original_image, img_with_mask, selected_points, global_points, global_point_label, original_image_1, original_image_2, original_image_3, original_image_4, mask_1, mask_2, mask_3, refine_mask]
+ )
+ return demo
+
+
+# 02:
+def create_demo_zooming(runner=None):
+ DESCRIPTION = """
+ # Zooming Out
+
+ ## Usage:
+
+ - Upload a source image and choose the width and height scales to zoom out.
+ - The illustration of image adjustment and mask preparation is shown in the second column.
+ - We recommend setting the zooming scale between 0.75 and 1 for optimal results.
+ - Please clear the output before running a new example!
+ """
+
+ with gr.Blocks() as demo:
+
+ gr.Markdown(DESCRIPTION)
+
+ with gr.Row():
+ with gr.Column():
+ with gr.Group():
+ gr.Markdown("# INPUT")
+ # mask 0
+ gr.Markdown("## Original Image")
+ original_image = gr.Image(sources='upload', interactive=True, type="numpy")
+
+
+ gr.Markdown("## Scale:")
+ width_scale= gr.Slider(
+ label="Width scale",
+ minimum=0,
+ maximum=1,
+ step=0.05,
+ value=0.9,
+ interactive=True)
+ height_scale= gr.Slider(
+ label="Height scale",
+ minimum=0,
+ maximum=1,
+ step=0.05,
+ value=0.9,
+ interactive=True)
+ with gr.Row():
+ run_button = gr.Button("Edit")
+ clear_button = gr.Button("Clear")
+
+ with gr.Column():
+ with gr.Group():
+ gr.Markdown("# Preprocess")
+ gr.Markdown("## Image Adjustment:")
+ new_image = gr.Gallery(columns=1, height='auto')
+ gr.Markdown("## Mask Adjustment:")
+ new_mask = gr.Gallery(columns=1, height='auto')
+
+ with gr.Column():
+ with gr.Group():
+ gr.Markdown("# OUTPUT")
+ gr.Markdown("## Results")
+ output = gr.Gallery(columns=1, height='auto')
+
+ with gr.Column():
+ gr.Markdown("Try some of the examples below ⬇️")
+ gr.Examples(
+ examples=examples_zoom,
+ inputs=[original_image]
+ )
+ run_button.click(fn=runner, inputs=[original_image, width_scale, height_scale], outputs=[output, new_image, new_mask])
+ clear_button.click(fn=fun_clear, inputs=[original_image, width_scale, height_scale, output, new_image, new_mask],
+ outputs=[original_image, width_scale, height_scale, output, new_image, new_mask])
+ return demo
+# 03
+
+def create_demo_panning(runner=None):
+ DESCRIPTION = """
+ # Camera Panning
+
+ ## Usage:
+
+ - Upload a source image and choose the panning direction and scale for width and height.
+ - The illustration of image adjustment and mask preparation is shown in the second column.
+ - We recommend setting the panning scale between 0 and 0.25 for optimal results.
+ - Please clear the output before running a new example!
+ """
+
+ with gr.Blocks() as demo:
+ gr.Markdown(DESCRIPTION)
+
+ with gr.Row():
+ with gr.Column():
+ with gr.Group():
+ gr.Markdown("# INPUT")
+ # mask 0
+ gr.Markdown("## Original Image")
+ original_image = gr.Image(sources='upload', interactive=True, type="numpy")
+ w_direction = gr.Radio(["left", "right"], value="left", label="Width Direction")
+ w_scale = gr.Slider(
+ label="Width scale",
+ minimum=0,
+ maximum=1,
+ step=0.05,
+ value=0,
+ interactive=True)
+
+ h_direction = gr.Radio(["up", "down"], value="up", label="Height Direction")
+ h_scale = gr.Slider(
+ label="Height scale",
+ minimum=0,
+ maximum=1,
+ step=0.05,
+ value=0,
+ interactive=True)
+ with gr.Row():
+ run_button = gr.Button("Edit")
+ clear_button = gr.Button("Clear")
+
+ with gr.Column():
+ with gr.Group():
+ gr.Markdown("# Preprocess")
+ gr.Markdown("## Image Adjustment:")
+ new_image = gr.Gallery(columns=1, height='auto')
+ gr.Markdown("## Mask Adjustment:")
+ new_mask = gr.Gallery(columns=1, height='auto')
+
+ with gr.Column():
+ with gr.Group():
+ gr.Markdown("# OUTPUT")
+ gr.Markdown("## Results")
+ output = gr.Gallery(columns=1, height='auto')
+
+ with gr.Column():
+ gr.Markdown("Try some of the examples below ⬇️")
+ gr.Examples(
+ examples=examples_pan,
+ inputs=[original_image]
+ )
+ run_button.click(fn=runner, inputs=[original_image, w_direction, w_scale, h_direction, h_scale], outputs=[output, new_image, new_mask])
+ clear_button.click(fn=fun_clear, inputs=[original_image, w_direction, w_scale, h_direction, h_scale, new_image, new_mask, output],
+ outputs=[original_image, w_direction, w_scale, h_direction, h_scale, new_image, new_mask, output])
+ return demo
+# 04:
+def create_position_size(label=None):
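+ # Shared control group for a movable layer: an image upload, dx/dy sliders expressed as
+ # fractions of the image width/height, a resize slider, and horizontal/vertical flip radios.
+ # Used by both the moving tab and the multi-layered editing tab.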
+ image = gr.Image(sources='upload', label=label, interactive=True, type="numpy")
+ with gr.Row():
+ dx = gr.Slider(
+ label="Left-Right",
+ minimum=-1,
+ maximum=1,
+ step=0.05,
+ value=0,
+ interactive=True
+ )
+ dy = gr.Slider(
+ label="Down-Up",
+ minimum=-1,
+ maximum=1,
+ step=0.05,
+ value=0,
+ interactive=True
+ )
+ resize_scale = gr.Slider(
+ label="Resize",
+ minimum=0,
+ maximum=2,
+ step=0.05,
+ value=1,
+ interactive=True
+ )
+ with gr.Row():
+ w_flip = gr.Radio(["left/right","None"], value="None", label="Horizontal Flip")
+ h_flip = gr.Radio(["down/up", "None"], value="None", label="Vertical Flip")
+ return image, dx, dy, resize_scale, w_flip, h_flip
+# 05:
+def create_demo_layer(runner=None):
+ DESCRIPTION = """
+ # 🚩 Multi-Layered Editing 🚩
+
+ ## Usage:
+
+ - Note that all of the preceding operations can also be achieved in the multi-layered editing mode.
+ - In particular, you can accomplish multi-object editing, such as adding objects and cross-image composition, on this page.
+ - Try some interesting examples given below to understand the usage.
+ - Please clear the output before running a new example!
+ - We strongly recommend reading the [original paper](https://arxiv.org/abs/2403.14487) to explore more uses of multi-layered editing.
+ """
+ global_points = gr.State([])
+ global_point_label = gr.State([])
+ bg_ori = gr.State(value=None)
+ l1_ori = gr.State(value=None)
+ l2_ori = gr.State(value=None)
+ l3_ori = gr.State(value=None)
+ with gr.Blocks() as demo:
+ gr.Markdown(DESCRIPTION)
+ with gr.Row():
+ with gr.Column():
+ with gr.Group():
+ gr.Markdown("# INPUT")
+ gr.Markdown("## Background Image")
+ bg_img = gr.Image(sources='upload', label="Background", interactive=True, type="numpy")
+ gr.Markdown("## Layer-1")
+ l1_img, l1_dx, l1_dy, l1_resize, l1_w_flip, l1_h_flip = create_position_size(label="Layer-1")
+ gr.Markdown("## Layer-2")
+ l2_img, l2_dx, l2_dy, l2_resize, l2_w_flip, l2_h_flip = create_position_size(label="Layer-2")
+ gr.Markdown("## Layer-3")
+ l3_img, l3_dx, l3_dy, l3_resize, l3_w_flip, l3_h_flip = create_position_size(label="Layer-3")
+ with gr.Row():
+ run_button = gr.Button("Edit")
+ clear_button = gr.Button("Clear")
+
+ with gr.Column():
+ with gr.Group():
+ gr.Markdown("# Mask")
+ gr.Markdown("## Background Mask for Removal:")
+ bg_mask = gr.Image(sources='upload', label="BG Mask", interactive=True, type="numpy")
+ gr.Markdown("## Layer-1 Mask:")
+ l1_mask = gr.Image(sources='upload', label="L1 Mask", interactive=True, type="numpy")
+ gr.Markdown("## Layer-2 Mask:")
+ l2_mask = gr.Image(sources='upload', label="L2 Mask", interactive=True, type="numpy")
+ gr.Markdown("## Layer-3 Mask:")
+ l3_mask = gr.Image(sources='upload', label="L3 Mask", interactive=True, type="numpy")
+
+ with gr.Column():
+ with gr.Group():
+ gr.Markdown("# OUTPUT")
+ gr.Markdown("## Results")
+ output = gr.Gallery(columns=1, height='auto')
+
+ with gr.Column():
+ gr.Markdown("Try some of the examples below ⬇️")
+ gr.Examples(
+ examples=examples_layer,
+ inputs=[
+ bg_img, bg_mask,
+ l1_img, l1_mask, l1_dx, l1_dy, l1_resize, l1_w_flip, l1_h_flip,
+ l2_img, l2_mask, l2_dx, l2_dy, l2_resize, l2_w_flip, l2_h_flip,
+ l3_img, l3_mask, l3_dx, l3_dy, l3_resize, l3_w_flip, l3_h_flip,
+ ]
+ )
+ bg_img.select(
+ segment_with_points,
+ inputs=[bg_img, bg_ori, global_points, global_point_label],
+ outputs=[bg_img, bg_ori, bg_mask, global_points, global_point_label]
+ )
+ l1_img.select(
+ segment_with_points,
+ inputs=[l1_img, l1_ori, global_points, global_point_label],
+ outputs=[l1_img, l1_ori, l1_mask, global_points, global_point_label]
+ )
+ l2_img.select(
+ segment_with_points,
+ inputs=[l2_img, l2_ori, global_points, global_point_label],
+ outputs=[l2_img, l2_ori, l2_mask, global_points, global_point_label]
+ )
+ l3_img.select(
+ segment_with_points,
+ inputs=[l3_img, l3_ori, global_points, global_point_label],
+ outputs=[l3_img, l3_ori, l3_mask, global_points, global_point_label]
+ )
+
+ run_button.click(fn=runner, inputs=[
+ bg_img,
+ l1_img, l1_dx, l1_dy, l1_resize, l1_w_flip, l1_h_flip,
+ l2_img, l2_dx, l2_dy, l2_resize, l2_w_flip, l2_h_flip,
+ l3_img, l3_dx, l3_dy, l3_resize, l3_w_flip, l3_h_flip,
+ bg_mask, l1_mask, l2_mask, l3_mask,
+ bg_ori, l1_ori, l2_ori, l3_ori
+ ], outputs=[output])
+
+ clear_button.click(fn=fun_clear,
+ inputs=[bg_img, bg_ori,
+ l1_img, l1_ori, l1_dx, l1_dy, l1_resize, l1_w_flip, l1_h_flip,
+ l2_img, l2_ori, l2_dx, l2_dy, l2_resize, l2_w_flip, l2_h_flip,
+ l3_img, l3_ori, l3_dx, l3_dy, l3_resize, l3_w_flip, l3_h_flip,
+ bg_mask, l1_mask, l2_mask, l3_mask,
+ global_points, global_point_label, output],
+ outputs=[bg_img, bg_ori,
+ l1_img, l1_ori, l1_dx, l1_dy, l1_resize, l1_w_flip, l1_h_flip,
+ l2_img, l2_ori, l2_dx, l2_dy, l2_resize, l2_w_flip, l2_h_flip,
+ l3_img, l3_ori, l3_dx, l3_dy, l3_resize, l3_w_flip, l3_h_flip,
+ bg_mask, l1_mask, l2_mask, l3_mask,
+ global_points, global_point_label, output],
+ )
+ return demo
+
+# 06:
+def create_demo_mask_box(runner=None):
+ DESCRIPTION = """
+ # 🔧 Mask Preparation
+ ## Usage:
+ - This page is a tool for combining more than one mask.
+ - Draw a box around an object to obtain each of Masks 1-4.
+ - The merged mask is the union of Masks 1-4.
+ - Please clear the output before running a new example!
+ """
+
+ with gr.Blocks() as demo:
+ original_image = gr.State(value=None)
+ img_with_mask = gr.State(value=None)
+ selected_points = gr.State([])
+ global_points = gr.State([])
+ global_point_label = gr.State([])
+ gr.Markdown(DESCRIPTION)
+ with gr.Row():
+ with gr.Column():
+ with gr.Group():
+ gr.Markdown("# INPUT")
+ gr.Markdown("## 1. Select two points for Mask 1")
+ gr.Markdown("the top left and the bottom right")
+ img_draw_box_1 = gr.Image(sources='upload', label="Original Image", interactive=True, type="numpy")
+
+ gr.Markdown("## 2. Select two points for Mask 2")
+ gr.Markdown("the top left and the bottom right")
+ img_draw_box_2 = gr.Image(sources='upload', label="Original Image", interactive=True, type="numpy")
+
+ gr.Markdown("## 3. Select two points for Mask 3")
+ gr.Markdown("the top left and the bottom right")
+ img_draw_box_3 = gr.Image(sources='upload', label="Original Image", interactive=True, type="numpy")
+
+ gr.Markdown("## 4. Select two points for Mask 4")
+ gr.Markdown("the top left and the bottom right")
+ img_draw_box_4 = gr.Image(label="Original Image", interactive=True, type="numpy")
+
+ with gr.Row():
+ run_button = gr.Button("Edit")
+ clear_button = gr.Button("Clear")
+
+ with gr.Column():
+ with gr.Group():
+ gr.Markdown("# Mask")
+ gr.Markdown("## Mask 1")
+ mask_1 = gr.Image(sources='upload', label="Mask", interactive=True, type="numpy")
+ gr.Markdown("## Mask 2")
+ mask_2 = gr.Image(sources='upload', label="Mask", interactive=True, type="numpy")
+ gr.Markdown("## Mask 3")
+ mask_3 = gr.Image(sources='upload', label="Mask", interactive=True, type="numpy")
+ gr.Markdown("## Mask 4")
+ mask_4 = gr.Image(sources='upload', label="Mask", interactive=True, type="numpy")
+
+ with gr.Column():
+ with gr.Group():
+ gr.Markdown("# Merged Mask")
+ merged_mask = gr.Image(sources='upload', label="Mask of object", interactive=True, type="numpy")
+
+ with gr.Column():
+ gr.Markdown("Please see the example below. ⬇️")
+ gr.Examples(
+ examples=examples_mask_box,
+ inputs=[
+ img_draw_box_1, img_draw_box_2, mask_1, mask_2, merged_mask
+ ]
+ )
+ img_draw_box_1.select(
+ segment_with_points,
+ inputs=[img_draw_box_1, original_image, global_points, global_point_label],
+ outputs=[img_draw_box_1, original_image, mask_1, global_points, global_point_label]
+ )
+ img_draw_box_2.select(
+ segment_with_points,
+ inputs=[img_draw_box_2, original_image, global_points, global_point_label],
+ outputs=[img_draw_box_2, original_image, mask_2, global_points, global_point_label]
+ )
+ img_draw_box_3.select(
+ segment_with_points,
+ inputs=[img_draw_box_3, original_image, global_points, global_point_label],
+ outputs=[img_draw_box_3, original_image, mask_3, global_points, global_point_label]
+ )
+ img_draw_box_4.select(
+ segment_with_points,
+ inputs=[img_draw_box_4, original_image, global_points, global_point_label],
+ outputs=[img_draw_box_4, original_image, mask_4, global_points, global_point_label]
+ )
+
+ run_button.click(fn=runner, inputs=[mask_1, mask_2, mask_3, mask_4], outputs=[merged_mask])
+ clear_button.click(
+ fn=fun_clear,
+ inputs=[original_image, img_with_mask, selected_points, global_points, global_point_label, img_draw_box_1, img_draw_box_2, img_draw_box_3, img_draw_box_4, mask_1, mask_2, mask_3, mask_4],
+ outputs=[original_image, img_with_mask, selected_points, global_points, global_point_label, img_draw_box_1, img_draw_box_2, img_draw_box_3, img_draw_box_4, mask_1, mask_2, mask_3, mask_4, merged_mask]
+ )
+ return demo
+
+def create_demo_moving(runner=None):
+ DESCRIPTION = """
+ # Object Moving, Resizing, and Flipping
+
+ ## Usage:
+ - Upload an image and draw a box around the object to manipulate.
+ - Move the object vertically or horizontally using sliders or by drawing an arrow.
+ - You can resize the object with the slider and flip it with the flip options.
+ - Please clear the output before running a new example!
+ """
+
+ selected_points = gr.State([])
+ global_points = gr.State([])
+ global_point_label = gr.State([])
+ bg_ori = gr.State(value=None)
+ l1_ori = gr.State(value=None)
+ with gr.Blocks() as demo:
+ gr.Markdown(DESCRIPTION)
+ with gr.Row():
+ with gr.Column():
+ with gr.Group():
+ gr.Markdown("# INPUT")
+ gr.Markdown("## Draw box to mask target object")
+ bg_img = gr.Image(sources='upload', label="Background", interactive=True, type="numpy")
+ gr.Markdown("## Draw arrow to describe the movement")
+ l1_img, l1_dx, l1_dy, l1_resize, l1_w_flip, l1_h_flip = create_position_size(label="Layer-1")
+ with gr.Row():
+ run_button = gr.Button("Edit")
+ clear_button = gr.Button("Clear")
+
+ with gr.Column():
+ with gr.Group():
+ gr.Markdown("# Mask")
+ gr.Markdown("## Background Mask for Removal:")
+ bg_mask = gr.Image(sources='upload', label="Mask", interactive=True, type="numpy")
+
+ with gr.Column():
+ with gr.Group():
+ gr.Markdown("# OUTPUT")
+ gr.Markdown("## Results")
+ output = gr.Gallery(columns=1, height='auto')
+
+ with gr.Column():
+ gr.Markdown("Try some of the examples below ⬇️")
+ gr.Examples(
+ examples=examples_moving,
+ inputs=[
+ bg_img, bg_mask, l1_dx, l1_dy, l1_resize, l1_h_flip, l1_w_flip
+ ]
+ )
+ bg_img.select(
+ segment_with_points,
+ inputs=[bg_img, bg_ori, global_points, global_point_label],
+ outputs=[bg_img, bg_ori, bg_mask, global_points, global_point_label]
+ )
+ l1_img.select(
+ get_point_move,
+ [bg_ori, l1_img, selected_points],
+ [l1_img, bg_ori, selected_points, l1_dx, l1_dy],
+ )
+
+ run_button.click(fn=runner, inputs=[
+ bg_img, bg_ori,bg_mask,
+ l1_dx, l1_dy, l1_resize, l1_w_flip, l1_h_flip, selected_points
+ ], outputs=[output])
+
+ clear_button.click(fn=fun_clear,
+ inputs=[bg_img, bg_ori, bg_mask, l1_img, l1_ori, l1_dx, l1_dy, l1_resize, l1_w_flip, l1_h_flip,
+ global_points, global_point_label, selected_points, output],
+ outputs=[bg_img, bg_ori, bg_mask, l1_img, l1_ori, l1_dx, l1_dy, l1_resize, l1_w_flip, l1_h_flip,
+ global_points, global_point_label, selected_points, output],
+ )
+ return demo
diff --git a/src/demo/model.py b/src/demo/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..a92e3c92296a283680eb055c030b0f241efe4a2d
--- /dev/null
+++ b/src/demo/model.py
@@ -0,0 +1,517 @@
+import numpy as np
+import torch
+from diffusers import DDIMScheduler
+import cv2
+from utils.sdxl import sdxl
+from utils.inversion import Inversion
+import math
+import torch.nn.functional as F
+import utils.utils as utils
+import os
+import matplotlib.pyplot as plt
+from PIL import Image, ImageDraw, ImageFont
+import spaces
+
+MAX_NUM_WORDS = 77
+
+
+class LayerFusion:
+ def get_mask(self, maps, alpha, use_pool,x_t):
+ k = 1
+ maps = (maps * alpha).sum(-1).mean(1)
+ if use_pool:
+ maps = F.max_pool2d(maps, (k * 2 + 1, k * 2 + 1), (1, 1), padding=(k, k))
+ mask = F.interpolate(maps, size=(x_t.shape[2:])) #[2, 1, 128, 128]
+ mask = mask / mask.max(2, keepdims=True)[0].max(3, keepdims=True)[0]
+ mask = (mask - mask.min()) / (mask.max() - mask.min())
+ mask = mask.gt(self.mask_threshold)
+ self.mask=mask
+ mask = mask[:1] + mask
+ return mask
+
+ def get_one_mask(self, maps, use_pool, x_t, idx_lst, i=None, sav_img=False):
+ k=1
+ if sav_img is False:
+ mask_tot = 0
+ for obj in idx_lst:
+ mask = maps[0, :, :, :, obj].mean(0).reshape(1, 1, 32, 32)
+ if use_pool:
+ mask = F.max_pool2d(mask, (k * 2 + 1, k * 2 + 1), (1, 1), padding=(k, k))
+ mask = F.interpolate(mask, size=(x_t.shape[2:]))
+ mask = mask / mask.max(2, keepdims=True)[0].max(3, keepdims=True)[0]
+ mask = (mask - mask.min()) / (mask.max() - mask.min())
+ mask = mask.gt(self.mask_threshold[int(self.counter/10)])
+ mask_tot |= mask
+ mask = mask_tot
+ return mask
+ else:
+ for obj in idx_lst:
+ mask = maps[0, :, :, :, obj].mean(0).reshape(1, 1, 32, 32)
+ if use_pool:
+ mask = F.max_pool2d(mask, (k * 2 + 1, k * 2 + 1), (1, 1), padding=(k, k))
+ mask = F.interpolate(mask, size=(1024, 1024))#[1, 1, 1024, 1024]
+ mask = mask / mask.max(2, keepdims=True)[0].max(3, keepdims=True)[0]
+ mask = (mask - mask.min()) / (mask.max() - mask.min())
+ mask = mask.gt(0.6)
+ mask = np.array(mask[0][0].clone().cpu()).astype(np.uint8)*255
+ cv2.imwrite(f'./img/sam_mask/{self.blend_list[i][0]}_{self.counter}.jpg', mask)
+ return mask
+
+ def mv_op(self, mp, op, scale=0.2, ones=False, flip=None):
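+ # Shift a 4-D latent/mask map (e.g. 1xCxHxW) by a fraction `scale` of its size in the
+ # given direction, filling the vacated band with zeros (or ones when ones=True) and
+ # optionally flipping the result. A small sketch of the intent (illustrative values,
+ # not taken from the code): with W=8 and scale=0.25, op='right' copies columns 0..5
+ # into columns 2..7 and leaves columns 0..1 at the fill value.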
+ _, b, H, W = mp.shape
+ if ones == False:
+ new_mp = torch.zeros_like(mp)
+ else:
+ new_mp = torch.ones_like(mp)
+ K = int(scale*W)
+ if op == 'right':
+ new_mp[:, :, :, K:] = mp[:, :, :, 0:W-K]
+ elif op == 'left':
+ new_mp[:, :, :, 0:W-K] = mp[:, :, :, K:]
+ elif op == 'down':
+ new_mp[:, :, K:, :] = mp[:, :, 0:W-K, :]
+ elif op == 'up':
+ new_mp[:, :, 0:W-K, :] = mp[:, :, K:, :]
+ if flip is not None:
+ new_mp = torch.flip(new_mp, dims=flip)
+
+ return new_mp
+
+ def mv_layer(self, x_t, bg_id, fg_id, op_id):
+ bg_img = x_t[bg_id:(bg_id+1)].clone()
+ fg_img = x_t[fg_id:(fg_id+1)].clone()
+ fg_mask = self.fg_mask_list[fg_id-3]
+ op_list = self.op_list[fg_id-3]
+
+ for item in op_list:
+ op, scale = item[0], item[1]
+ if scale != 0:
+ fg_img = self.mv_op(fg_img, op=op, scale=scale)
+ fg_mask = self.mv_op(fg_mask, op=op, scale=scale)
+ x_t[op_id:(op_id+1)] = bg_img*(1-fg_mask) + fg_img*fg_mask
+
+ def __call__(self, x_t):
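+ # Called once per denoising step via Control.step_callback. While the step counter is
+ # inside blend_time, the background sample x_t[1] keeps the original latent x_t[0]
+ # outside the removal mask (latent-level inpainting). On the step right after
+ # blend_time ends (and when not in pure removal mode), each foreground layer x_t[3:]
+ # is shifted/flipped and composited onto the canvas sample x_t[2] in order.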
+ self.counter += 1
+ # inpainting
+ if self.blend_time[0] <= self.counter <= self.blend_time[1]:
+ x_t[1:2] = x_t[1:2]*self.remove_mask + x_t[0:1]*(1-self.remove_mask)
+
+ if self.counter == self.blend_time[1] + 1 and self.mode != "removal":
+ b = x_t.shape[0]
+ bg_id = 1 #bg_layer
+ op_id = 2 #canvas
+ for fg_id in range(3, b): #fg_layer
+ self.mv_layer(x_t, bg_id=bg_id, fg_id=fg_id, op_id=op_id)
+ bg_id = op_id
+
+ return x_t
+
+ def __init__(self, remove_mask, fg_mask_list, refine_mask=None,
+ blend_time=[0, 40],
+ mode="removal", op_list=None):
+ self.counter = 0
+ self.mode = mode
+ self.op_list = op_list
+ self.blend_time = blend_time
+
+ self.remove_mask = remove_mask
+ self.refine_mask = refine_mask
+ if self.refine_mask is not None:
+ self.new_mask = self.remove_mask + self.refine_mask
+ self.new_mask[self.new_mask>0] = 1
+ else:
+ self.new_mask = None
+ self.fg_mask_list = fg_mask_list
+
+
+class Control():
+ def step_callback(self, x_t):
+ if self.layer_fusion is not None:
+ x_t = self.layer_fusion(x_t)
+ return x_t
+ def __init__(self, layer_fusion):
+ self.layer_fusion = layer_fusion
+
+def register_attention_control(model, controller, mask_time=[0, 40], refine_time=[0, 25]):
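+# Replaces the forward pass of every Attention module in the UNet. For self-attention
+# layers, while the per-layer step counter is inside mask_time, the features of the
+# inpainted background sample (index batch/2 + 1) are zeroed inside the removal mask
+# (or the combined refine mask during refine_time) before the keys are computed, so
+# tokens in the removed region cannot be attended to. Cross-attention is left unchanged.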
+ def ca_forward(self, place_in_unet):
+ to_out = self.to_out
+ if type(to_out) is torch.nn.modules.container.ModuleList:
+ to_out = self.to_out[0]
+ else:
+ to_out = self.to_out
+ self.counter = 0 #time
+ def forward(hidden_states, encoder_hidden_states=None, attention_mask=None): #self_attention
+ x = hidden_states.clone()
+ context = encoder_hidden_states
+ is_cross = context is not None
+ if is_cross is False:
+ if controller.layer_fusion is not None and (mask_time[0] < self.counter < mask_time[1]):
+ b, i, j = x.shape
+ H = W = int(math.sqrt(i))
+ x_old = x.clone()
+ x = x.reshape(b, H, W, j)
+ new_mask = controller.layer_fusion.remove_mask
+ if new_mask is not None:
+ new_mask[new_mask>0] = 1
+ new_mask = F.interpolate(new_mask.to(dtype=torch.float32).clone(), size=(H, W), mode='bilinear').cuda()
+ new_mask = (1 - new_mask).reshape(1, H, W).unsqueeze(-1)
+ if (refine_time[0] < self.counter <= refine_time[1]) and controller.layer_fusion.refine_mask is not None:
+ new_mask = controller.layer_fusion.new_mask
+ new_mask = F.interpolate(new_mask.to(dtype=torch.float32).clone(), size=(H, W), mode='bilinear').cuda()
+ new_mask = (1 - new_mask).reshape(1, H, W).unsqueeze(-1)
+ idx = 1 # inpaint_idx: bg
+ x[int(b/2)+idx, :, :] = (x[int(b/2)+idx, :, :]*new_mask[0])
+ x = x.reshape(b, i, j)
+ if is_cross:
+ q = self.to_q(x)
+ k = self.to_k(context)
+ v = self.to_v(context)
+ else:
+ context = x
+ q = self.to_q(hidden_states)
+ k = self.to_k(x)
+ v = self.to_v(hidden_states)
+ q = self.head_to_batch_dim(q)
+ k = self.head_to_batch_dim(k)
+ v = self.head_to_batch_dim(v)
+
+ if hasattr(controller, 'count_layers'):
+ controller.count_layers(place_in_unet,is_cross)
+ sim = torch.einsum("b i d, b j d -> b i j", q.clone(), k.clone()) * self.scale
+
+ attn = sim.softmax(dim=-1)
+ out = torch.einsum("b i j, b j d -> b i d", attn, v)
+ out = self.batch_to_head_dim(out)
+ self.counter += 1
+ return to_out(out)
+
+ return forward
+
+ def register_recr(net_, count, place_in_unet):
+ if net_.__class__.__name__ == 'Attention':
+ net_.forward = ca_forward(net_, place_in_unet)
+ return count + 1
+ elif hasattr(net_, 'children'):
+ for net__ in net_.children():
+ count = register_recr(net__, count, place_in_unet)
+ return count
+
+ cross_att_count = 0
+ sub_nets = model.unet.named_children()
+ for net in sub_nets:
+ if "down" in net[0]:
+ cross_att_count += register_recr(net[1], 0, "down")
+ elif "up" in net[0]:
+ cross_att_count += register_recr(net[1], 0, "up")
+ elif "mid" in net[0]:
+ cross_att_count += register_recr(net[1], 0, "mid")
+
+ controller.num_att_layers = cross_att_count
+
+class DesignEdit():
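+ # High-level wrapper used by the Gradio demo: it loads an SDXL pipeline with a DDIM
+ # scheduler, inverts the input image(s), and runs the layer-fusion controller to
+ # produce the edited result. A minimal usage sketch (illustrative; the path and array
+ # shapes are assumptions, not part of this file):
+ # editor = DesignEdit(pretrained_model_path="<path-to-sdxl-base>")
+ # result = editor.run_remove(original_image=img_np, mask_1=mask_np) # HxWx3 / HxW uint8 arrays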
+ def __init__(self, pretrained_model_path="/home/jyr/model/stable-diffusion-xl-base-1.0"):
+ self.model_dtype = "fp16"
+ self.pretrained_model_path=pretrained_model_path
+ self.num_ddim_steps = 50
+ self.mask_time = [0, 40]
+ self.op_list = {}
+ self.attend_scale = {}
+ scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
+ if self.model_dtype == "fp16":
+ torch_dtype = torch.float16
+ elif self.model_dtype == "fp32":
+ torch_dtype = torch.float32
+ self.pipe = sdxl.from_pretrained(self.pretrained_model_path, torch_dtype=torch_dtype, use_safetensors=True, variant=self.model_dtype,scheduler=scheduler)
+
+ @spaces.GPU
+ def init_model(self, num_ddim_steps=50):
+ device = torch.device('cuda:0')
+ self.pipe.to(device)
+ inversion = Inversion(self.pipe,num_ddim_steps)
+ return self.pipe, inversion
+
+ @spaces.GPU(duration=120, enable_queue=True)
+ def run_remove(self, original_image=None, mask_1=None, mask_2=None, mask_3=None, refine_mask=None,
+ ori_1=None, ori_2=None, ori_3=None,
+ prompt="", save_dir="./tmp", mode='removal',):
+ # 01-1:
+ self.ldm_model, self.inversion= self.init_model(num_ddim_steps=self.num_ddim_steps)
+ if original_image is None:
+ original_image = ori_1 if ori_1 is not None else ori_2 if ori_2 is not None else ori_3
+ op_list = None
+ attend_scale = 20
+ sample_ref_match={0 : 0, 1 : 0}
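+ # sample_ref_match appears to map each generated sample index to the index of the
+ # inverted reference image it starts from: here both the reconstruction (0) and the
+ # inpainted background (1) are denoised from the single inverted input.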
+ ori_shape = original_image.shape
+
+ # 01-2: prepare: image_gt, remove_mask, fg_mask_list, refine_mask
+ image_gt = Image.fromarray(original_image).resize((1024, 1024))
+ image_gt = np.stack([np.array(image_gt)])
+ mask_list = [mask_1, mask_2, mask_3]
+ remove_mask = utils.attend_mask(utils.add_masks_resized(mask_list), attend_scale=attend_scale) # numpy to tensor
+ fg_mask_list = None
+ refine_mask = utils.attend_mask(utils.convert_and_resize_mask(refine_mask)) if refine_mask is not None else None
+
+ # 01-3: prepare: prompts, blend_time, refine_time
+ prompts = len(sample_ref_match)*[prompt] # 2
+ blend_time = [0, 41]
+ refine_time = [0, 25]
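+ # blend_time and refine_time are counted in denoising steps out of num_ddim_steps (50):
+ # latent blending runs for most of the trajectory, while the refine mask (when provided)
+ # is only applied during roughly the first half of the steps.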
+
+ # 02: invert
+ _, x_t, x_stars, prompt_embeds, pooled_prompt_embeds = self.inversion.invert(image_gt, prompts, inv_batch_size=1)
+
+ # 03: init layer_fusion and controller
+ lb = LayerFusion(remove_mask=remove_mask, fg_mask_list=fg_mask_list, refine_mask=refine_mask,
+ blend_time=blend_time, mode=mode, op_list=op_list)
+ controller = Control(layer_fusion=lb)
+ register_attention_control(model=self.ldm_model, controller=controller, mask_time=self.mask_time, refine_time=refine_time)
+
+ # 04: generate images
+ images = self.ldm_model(controller=controller, prompt=prompts,
+ latents=x_t, x_stars=x_stars,
+ negative_prompt_embeds=prompt_embeds,
+ negative_pooled_prompt_embeds=pooled_prompt_embeds,
+ sample_ref_match=sample_ref_match)
+ folder = None
+ utils.view_images(images, folder=folder)
+ return [cv2.resize(images[1], (ori_shape[1], ori_shape[0]))]
+
+ @spaces.GPU(duration=120, enable_queue=True)
+ def run_zooming(self, original_image, width_scale=1, height_scale=1, prompt="", save_dir="./tmp", mode='removal'):
+ self.ldm_model, self.inversion= self.init_model(num_ddim_steps=self.num_ddim_steps)
+ # 01-1:
+ op_list = {0: ['zooming', [height_scale, width_scale]]}
+ ori_shape = original_image.shape
+ attend_scale = 30
+ sample_ref_match = {0 : 0, 1 : 0}
+
+ # 01-2: prepare: image_gt, remove_mask, fg_mask_list, refine_mask
+ img_new, mask = utils.zooming(original_image, [height_scale, width_scale])
+ img_new_copy = img_new.copy()
+ mask_copy = mask.copy()
+
+ image_gt = Image.fromarray(img_new).resize((1024, 1024))
+ image_gt = np.stack([np.array(image_gt)])
+
+ remove_mask = utils.attend_mask(utils.convert_and_resize_mask(mask), attend_scale=attend_scale) # numpy to tensor
+ fg_mask_list = None
+ refine_mask = None
+
+ # 01-3: prepare: prompts, blend_time, refine_time
+ prompts = len(sample_ref_match)*[prompt] # 2
+ blend_time = [0, 41]
+ refine_time = [0, 25]
+
+ # 02: invert
+ _, x_t, x_stars, prompt_embeds, pooled_prompt_embeds = self.inversion.invert(image_gt, prompts, inv_batch_size=1)
+
+ # 03: init layer_fusion and controller
+ lb = LayerFusion(remove_mask=remove_mask, fg_mask_list=fg_mask_list, blend_time=blend_time,
+ mode=mode, op_list=op_list)
+ controller = Control(layer_fusion=lb)
+ register_attention_control(model=self.ldm_model, controller=controller, mask_time=self.mask_time, refine_time=refine_time)
+
+ # 04: generate images
+ images = self.ldm_model(controller=controller, prompt=prompts,
+ latents=x_t, x_stars=x_stars,
+ negative_prompt_embeds=prompt_embeds,
+ negative_pooled_prompt_embeds=pooled_prompt_embeds,
+ sample_ref_match=sample_ref_match)
+ folder = None
+ utils.view_images(images, folder=folder)
+ resized_img = cv2.resize(images[1], (ori_shape[1], ori_shape[0]))
+ return [resized_img], [img_new_copy], [mask_copy]
+
+ @spaces.GPU(duration=120, enable_queue=True)
+ def run_panning(self, original_image, w_direction, w_scale, h_direction, h_scale, prompt="", save_dir="./tmp", mode='removal'):
+ # 01-1: prepare: op_list, attend_scale, sample_ref_match
+ self.ldm_model, self.inversion= self.init_model(num_ddim_steps=self.num_ddim_steps)
+ ori_shape = original_image.shape
+ attend_scale = 30
+ sample_ref_match = {0 : 0, 1 : 0}
+
+ # 01-2: prepare: image_gt, remove_mask, fg_mask_list, refine_mask
+ op_list = [[w_direction, w_scale], [h_direction, h_scale]]
+ img_new, mask = utils.panning(original_image, op_list=op_list)
+ img_new_copy = img_new.copy()
+ mask_copy = mask.copy()
+
+ image_gt = Image.fromarray(img_new).resize((1024, 1024))
+ image_gt = np.stack([np.array(image_gt)])
+ remove_mask = utils.attend_mask(utils.convert_and_resize_mask(mask), attend_scale=attend_scale) # numpy to tensor
+
+ fg_mask_list = None
+ refine_mask = None
+
+ # 01-3: prepare: prompts, blend_time, refine_time
+ prompts = len(sample_ref_match)*[prompt] # 2
+ blend_time = [0, 41]
+ refine_time = [0, 25]
+
+ # 02: invert
+ _, x_t, x_stars, prompt_embeds, pooled_prompt_embeds = self.inversion.invert(image_gt, prompts, inv_batch_size=1)
+ # 03: init layer_fusion and controller
+ lb = LayerFusion(remove_mask=remove_mask, fg_mask_list=fg_mask_list, blend_time=blend_time,
+ mode=mode, op_list=op_list)
+ controller = Control(layer_fusion=lb)
+ register_attention_control(model=self.ldm_model, controller=controller, mask_time=self.mask_time, refine_time=refine_time)
+
+ # 04: generate images
+
+ images = self.ldm_model(controller=controller, prompt=prompts,
+ latents=x_t, x_stars=x_stars,
+ negative_prompt_embeds=prompt_embeds,
+ negative_pooled_prompt_embeds=pooled_prompt_embeds,
+ sample_ref_match=sample_ref_match)
+ folder = None
+ utils.view_images(images, folder=folder)
+ resized_img = cv2.resize(images[1], (ori_shape[1], ori_shape[0]))
+ return [resized_img], [img_new_copy], [mask_copy]
+
+ # layer-wise multi-object editing
+ def process_layer_states(self, layer_states):
+ self.ldm_model, self.inversion= self.init_model(num_ddim_steps=self.num_ddim_steps)
+ image_paths = []
+ mask_paths = []
+ op_list = []
+
+ for state in layer_states:
+ img, mask, dx, dy, resize, w_flip, h_flip = state
+ if img is not None:
+ img = cv2.resize(img, (1024, 1024))
+ mask = utils.convert_and_resize_mask(mask)
+ dx_command = ['right', dx] if dx > 0 else ['left', -dx]
+ dy_command = ['up', dy] if dy > 0 else ['down', -dy]
+ flip_code = None
+ if w_flip == "left/right" and h_flip == "down/up":
+ flip_code = -1
+ elif w_flip == "left/right":
+ flip_code = 1 # or another default value, set as needed
+ elif h_flip == "down/up":
+ flip_code = 0
+ op_list.append([dx_command, dy_command])
+ img, mask, _ = utils.resize_image_with_mask(img, mask, resize)
+ img, mask, _ = utils.flip_image_with_mask(img, mask, flip_code=flip_code)
+ image_paths.append(img)
+ mask_paths.append(utils.attend_mask(mask))
+ sample_ref_match = {0: 0, 1: 0, 2: 0, 3: 1, 4: 2, 5: 3}
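+ # Indices 0-2 (reconstruction, inpainted background, canvas) all reference the inverted
+ # background image; indices 3-5 reference the inverted images of up to three foreground
+ # layers. The dict is then truncated to 3 + the number of layers actually supplied.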
+ required_length = len(image_paths) + 3
+ truncated_sample_ref_match = {k: sample_ref_match[k] for k in sorted(sample_ref_match.keys())[:required_length]}
+ return image_paths, mask_paths, op_list, truncated_sample_ref_match
+
+ @spaces.GPU(duration=200)
+ def run_layer(self, bg_img, l1_img, l1_dx, l1_dy, l1_resize, l1_w_flip, l1_h_flip,
+ l2_img, l2_dx, l2_dy, l2_resize, l2_w_flip, l2_h_flip,
+ l3_img, l3_dx, l3_dy, l3_resize, l3_w_flip, l3_h_flip,
+ bg_mask, l1_mask, l2_mask, l3_mask,
+ bg_ori=None, l1_ori=None, l2_ori=None, l3_ori=None,
+ prompt="", save_dir="./tmp", mode='layerwise'):
+ self.ldm_model, self.inversion= self.init_model(num_ddim_steps=self.num_ddim_steps)
+ # 00: prepare: layer-wise states
+ bg_img = bg_ori if bg_ori is not None else bg_img
+ l1_img = l1_ori if l1_ori is not None else l1_img
+ l2_img = l2_ori if l2_ori is not None else l2_img
+ l3_img = l3_ori if l3_ori is not None else l3_img
+ # Normalize the masks explicitly (reassigning the loop variable would not update the originals).
+ bg_mask, l1_mask, l2_mask, l3_mask = [
+ np.zeros((1024, 1024), dtype=np.uint8) if mask is None else utils.convert_and_resize_mask(mask)
+ for mask in [bg_mask, l1_mask, l2_mask, l3_mask]
+ ]
+ l1_state = [l1_img, l1_mask, l1_dx, l1_dy, l1_resize, l1_w_flip, l1_h_flip]
+ l2_state = [l2_img, l2_mask, l2_dx, l2_dy, l2_resize, l2_w_flip, l2_h_flip]
+ l3_state = [l3_img, l3_mask, l3_dx, l3_dy, l3_resize, l3_w_flip, l3_h_flip]
+ ori_shape = bg_img.shape
+
+ image_paths, fg_mask_list, op_list, sample_ref_match = self.process_layer_states([l1_state, l2_state, l3_state])
+ if image_paths == []:
+ mode = "removal"
+ # 01-1: prepare: image_gt, remove_mask, fg_mask_list, refine_mask
+ attend_scale = 20
+ image_gt = [bg_img] + image_paths
+ image_gt = [Image.fromarray(img).resize((1024, 1024)) for img in image_gt]
+ image_gt = np.stack(image_gt)
+ remove_mask = utils.attend_mask(bg_mask, attend_scale=attend_scale)
+ refine_mask = None
+
+ # 01-2: prepare: prompts, blend_time, refine_time
+ prompts = len(sample_ref_match)*[prompt] # 2
+ blend_time = [0, 41]
+ refine_time = [0, 25]
+ attend_scale = []
+
+ # 02: invert
+ _, x_t, x_stars, prompt_embeds, pooled_prompt_embeds = self.inversion.invert(image_gt, prompts, inv_batch_size=len(image_gt))
+ # 03: init layer_fusion and controller
+ lb = LayerFusion(remove_mask=remove_mask, fg_mask_list=fg_mask_list, blend_time=blend_time, refine_mask=refine_mask,
+ mode=mode, op_list=op_list)
+ controller = Control(layer_fusion=lb)
+ register_attention_control(model=self.ldm_model, controller=controller, mask_time=self.mask_time, refine_time=refine_time)
+ # 04: generate images
+ images = self.ldm_model(controller=controller, prompt=prompts,
+ latents=x_t, x_stars=x_stars,
+ negative_prompt_embeds=prompt_embeds,
+ negative_pooled_prompt_embeds=pooled_prompt_embeds,
+ sample_ref_match=sample_ref_match)
+ folder = None
+ utils.view_images(images, folder=folder)
+ if mode == 'removal':
+ resized_img = cv2.resize(images[1], (ori_shape[1], ori_shape[0]))
+ else:
+ resized_img = cv2.resize(images[2], (ori_shape[1], ori_shape[0]))
+ return [resized_img]
+
+ @spaces.GPU(duration=120, enable_queue=True)
+ def run_moving(self, bg_img, bg_ori, bg_mask, l1_dx, l1_dy, l1_resize,
+ l1_w_flip=None, l1_h_flip=None, selected_points=None,
+ prompt="", save_dir="./tmp", mode='layerwise'):
+ self.ldm_model, self.inversion= self.init_model(num_ddim_steps=self.num_ddim_steps)
+ # 00: prepare: layer-wise states
+ bg_img = bg_ori if bg_ori is not None else bg_img
+ l1_img = bg_img
+ if bg_mask is None:
+ bg_mask = np.zeros((1024, 1024), dtype=np.uint8)
+ else:
+ bg_mask = utils.convert_and_resize_mask(bg_mask)
+ l1_mask = bg_mask
+ l1_state = [l1_img, l1_mask, l1_dx, l1_dy, l1_resize, l1_w_flip, l1_h_flip]
+ ori_shape = bg_img.shape
+
+ image_paths, fg_mask_list, op_list, sample_ref_match = self.process_layer_states([l1_state])
+
+ # 01-1: prepare: image_gt, remove_mask, fg_mask_list, refine_mask
+ attend_scale = 20
+ image_gt = [bg_img] + image_paths
+ image_gt = [Image.fromarray(img).resize((1024, 1024)) for img in image_gt]
+ image_gt = np.stack(image_gt)
+ remove_mask = utils.attend_mask(bg_mask, attend_scale=attend_scale)
+ refine_mask = None
+
+ # 01-2: prepare: prompts, blend_time, refine_time
+ prompts = len(sample_ref_match)*[prompt] # 2
+ blend_time = [0, 41]
+ refine_time = [0, 25]
+ attend_scale = []
+
+ # 02: invert
+ _, x_t, x_stars, prompt_embeds, pooled_prompt_embeds = self.inversion.invert(image_gt, prompts, inv_batch_size=len(image_gt))
+ # 03: init layer_fusion and controller
+ lb = LayerFusion(remove_mask=remove_mask, fg_mask_list=fg_mask_list, blend_time=blend_time, refine_mask=refine_mask,
+ mode=mode, op_list=op_list)
+ controller = Control(layer_fusion=lb)
+ register_attention_control(model=self.ldm_model, controller=controller, mask_time=self.mask_time, refine_time=refine_time)
+ # 04: generate images
+ images = self.ldm_model(controller=controller, prompt=prompts,
+ latents=x_t, x_stars=x_stars,
+ negative_prompt_embeds=prompt_embeds,
+ negative_pooled_prompt_embeds=pooled_prompt_embeds,
+ sample_ref_match=sample_ref_match)
+ folder = None
+ utils.view_images(images, folder=folder)
+ resized_img = cv2.resize(images[2], (ori_shape[1], ori_shape[0]))
+ return [resized_img]
+
+ # turn mask to 1024x1024 unit-8
+ def run_mask(self, mask_1, mask_2, mask_3, mask_4):
+ mask_list = [mask_1, mask_2, mask_3, mask_4]
+ final_mask = utils.add_masks_resized(mask_list)
+ return final_mask
\ No newline at end of file
diff --git a/src/demo/utils.py b/src/demo/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..0d0c789365ca94a89a9e71874999b73398f682f9
--- /dev/null
+++ b/src/demo/utils.py
@@ -0,0 +1,319 @@
+import numpy as np
+import gradio as gr
+import cv2
+from copy import deepcopy
+import torch
+from torchvision import transforms
+from PIL import Image, ImageDraw, ImageFont
+
+from sam.efficient_sam.build_efficient_sam import build_efficient_sam_vits
+from src.utils.utils import resize_numpy_image
+
+sam = build_efficient_sam_vits()
+
+def show_point_or_box(image, global_points):
+ # for point
+ if len(global_points) == 1:
+ image = cv2.circle(image, global_points[0], 10, (0, 0, 255), -1)
+ # for box
+ if len(global_points) == 2:
+ p1 = global_points[0]
+ p2 = global_points[1]
+ image = cv2.rectangle(image,(int(p1[0]),int(p1[1])),(int(p2[0]),int(p2[1])),(0,0,255),2)
+
+ return image
+
+def segment_with_points(
+ image,
+ original_image,
+ global_points,
+ global_point_label,
+ evt: gr.SelectData,
+ img_direction=None,
+ save_dir = "./tmp"
+):
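+ # Box-prompt segmentation: the first click stores one corner, the second click completes
+ # the box (corners are normalized to top-left / bottom-right), and EfficientSAM is run
+ # with the two points labeled 2 and 3, the box-corner labels in the SAM prompt convention.
+ # A third click starts a new box. Returns the annotated image, the cached original, the
+ # binary mask (or None until the box is complete), and the point state.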
+ if original_image is None:
+ original_image = image
+ else:
+ image = original_image
+ if img_direction is None:
+ img_direction = original_image
+ x, y = evt.index[0], evt.index[1]
+ image_path = None
+ mask_path = None
+ if len(global_points) == 0:
+ global_points.append([x, y])
+ global_point_label.append(2)
+ image_with_point= show_point_or_box(image.copy(), global_points)
+ return image_with_point, original_image, None, global_points, global_point_label
+ elif len(global_points) == 1:
+ global_points.append([x, y])
+ global_point_label.append(3)
+ x1, y1 = global_points[0]
+ x2, y2 = global_points[1]
+ if x1 < x2 and y1 >= y2:
+ global_points[0][0] = x1
+ global_points[0][1] = y2
+ global_points[1][0] = x2
+ global_points[1][1] = y1
+ elif x1 >= x2 and y1 < y2:
+ global_points[0][0] = x2
+ global_points[0][1] = y1
+ global_points[1][0] = x1
+ global_points[1][1] = y2
+ elif x1 >= x2 and y1 >= y2:
+ global_points[0][0] = x2
+ global_points[0][1] = y2
+ global_points[1][0] = x1
+ global_points[1][1] = y1
+ image_with_point = show_point_or_box(image.copy(), global_points)
+ # data process
+ input_point = np.array(global_points)
+ input_label = np.array(global_point_label)
+ pts_sampled = torch.reshape(torch.tensor(input_point), [1, 1, -1, 2])
+ pts_labels = torch.reshape(torch.tensor(input_label), [1, 1, -1])
+ img_tensor = transforms.ToTensor()(image)
+ # sam
+ predicted_logits, predicted_iou = sam(
+ img_tensor[None, ...],
+ pts_sampled,
+ pts_labels,
+ )
+ mask = torch.ge(predicted_logits[0, 0, 0, :, :], 0).float().cpu().detach().numpy()
+ mask_image = (mask*255.).astype(np.uint8)
+ return image_with_point, original_image, mask_image, global_points, global_point_label
+ else:
+ global_points=[[x, y]]
+ global_point_label=[2]
+ image_with_point= show_point_or_box(image.copy(), global_points)
+ return image_with_point, original_image, None, global_points, global_point_label
+
+
+def segment_with_points_paste(
+ image,
+ original_image,
+ global_points,
+ global_point_label,
+ image_b,
+ evt: gr.SelectData,
+ dx,
+ dy,
+ resize_scale
+
+):
+ if original_image is None:
+ original_image = image
+ else:
+ image = original_image
+ x, y = evt.index[0], evt.index[1]
+ if len(global_points) == 0:
+ global_points.append([x, y])
+ global_point_label.append(2)
+ image_with_point= show_point_or_box(image.copy(), global_points)
+ return image_with_point, original_image, None, global_points, global_point_label, None
+ elif len(global_points) == 1:
+ global_points.append([x, y])
+ global_point_label.append(3)
+ x1, y1 = global_points[0]
+ x2, y2 = global_points[1]
+ if x1 < x2 and y1 >= y2:
+ global_points[0][0] = x1
+ global_points[0][1] = y2
+ global_points[1][0] = x2
+ global_points[1][1] = y1
+ elif x1 >= x2 and y1 < y2:
+ global_points[0][0] = x2
+ global_points[0][1] = y1
+ global_points[1][0] = x1
+ global_points[1][1] = y2
+ elif x1 >= x2 and y1 >= y2:
+ global_points[0][0] = x2
+ global_points[0][1] = y2
+ global_points[1][0] = x1
+ global_points[1][1] = y1
+ image_with_point = show_point_or_box(image.copy(), global_points)
+ # data process
+ input_point = np.array(global_points)
+ input_label = np.array(global_point_label)
+ pts_sampled = torch.reshape(torch.tensor(input_point), [1, 1, -1, 2])
+ pts_labels = torch.reshape(torch.tensor(input_label), [1, 1, -1])
+ img_tensor = transforms.ToTensor()(image)
+ # sam
+ predicted_logits, predicted_iou = sam(
+ img_tensor[None, ...],
+ pts_sampled,
+ pts_labels,
+ )
+ mask = torch.ge(predicted_logits[0, 0, 0, :, :], 0).float().cpu().detach().numpy()
+ mask_uint8 = (mask*255.).astype(np.uint8)
+
+ return image_with_point, original_image, paste_with_mask_and_offset(image, image_b, mask_uint8, dx, dy, resize_scale), global_points, global_point_label, mask_uint8
+ else:
+ global_points=[[x, y]]
+ global_point_label=[2]
+ image_with_point= show_point_or_box(image.copy(), global_points)
+ return image_with_point, original_image, None, global_points, global_point_label, None
+
+def paste_with_mask_and_offset(image_a, image_b, mask, x_offset=0, y_offset=0, delta=1):
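+# Preview helper for cross-image pasting: crops image_a with the mask, rescales the crop
+# by `delta`, adjusts the offsets so the scaling stays roughly centered on the masked
+# object, and pastes it onto a semi-transparent copy of image_b. Any failure (e.g. an
+# empty mask) returns None instead of raising.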
+ try:
+ numpy_mask = np.array(mask)
+ y_coords, x_coords = np.nonzero(numpy_mask)
+ x_min = x_coords.min()
+ x_max = x_coords.max()
+ y_min = y_coords.min()
+ y_max = y_coords.max()
+ target_center_x = int((x_min + x_max) / 2)
+ target_center_y = int((y_min + y_max) / 2)
+
+ image_a = Image.fromarray(image_a)
+ image_b = Image.fromarray(image_b)
+ mask = Image.fromarray(mask)
+
+ if image_a.size != mask.size:
+ mask = mask.resize(image_a.size)
+
+ cropped_image = Image.composite(image_a, Image.new('RGBA', image_a.size, (0, 0, 0, 0)), mask)
+ x_b = int(target_center_x * (image_b.width / cropped_image.width))
+ y_b = int(target_center_y * (image_b.height / cropped_image.height))
+ x_offset = x_offset - int((delta - 1) * x_b)
+ y_offset = y_offset - int((delta - 1) * y_b)
+ cropped_image = cropped_image.resize(image_b.size)
+ new_size = (int(cropped_image.width * delta), int(cropped_image.height * delta))
+ cropped_image = cropped_image.resize(new_size)
+ image_b.putalpha(128)
+ result_image = Image.new('RGBA', image_b.size, (0, 0, 0, 0))
+ result_image.paste(image_b, (0, 0))
+ result_image.paste(cropped_image, (x_offset, y_offset), mask=cropped_image)
+
+ return result_image
+ except:
+ return None
+
+def upload_image_move(img, original_image):
+ if original_image is not None:
+ return original_image
+ else:
+ return img
+
+def fun_clear(*args):
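+# Generic reset used by every Clear button: list-valued states become empty lists and
+# everything else (images, masks, sliders, galleries) becomes None, returned in the same
+# order as the inputs so the same components can be wired as outputs.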
+ result = []
+ for arg in args:
+ if isinstance(arg, list):
+ result.append([])
+ else:
+ result.append(None)
+ return tuple(result)
+
+def clear_points(img):
+ image, mask = img["image"], np.float32(img["mask"][:, :, 0]) / 255.
+ if mask.sum() > 0:
+ mask = np.uint8(mask > 0)
+ masked_img = mask_image(image, 1 - mask, color=[0, 0, 0], alpha=0.3)
+ else:
+ masked_img = image.copy()
+
+ return [], masked_img
+
+def get_point(img, sel_pix, evt: gr.SelectData):
+ sel_pix.append(evt.index)
+ points = []
+ for idx, point in enumerate(sel_pix):
+ if idx % 2 == 0:
+ cv2.circle(img, tuple(point), 10, (0, 0, 255), -1)
+ else:
+ cv2.circle(img, tuple(point), 10, (255, 0, 0), -1)
+ points.append(tuple(point))
+ if len(points) == 2:
+ cv2.arrowedLine(img, points[0], points[1], (255, 255, 255), 4, tipLength=0.5)
+ points = []
+ return img if isinstance(img, np.ndarray) else np.array(img)
+
+def calculate_translation_percentage(ori_shape, selected_points):
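+# Convert a two-point drag into fractional offsets of the image size. For example
+# (illustrative numbers): dragging from (100, 200) to (300, 250) on an 800x1000 (HxW)
+# image gives dx = 200/1000 = 0.2 and dy = 50/800 = 0.0625.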
+ dx = selected_points[1][0] - selected_points[0][0]
+ dy = selected_points[1][1] - selected_points[0][1]
+ dx_percentage = dx / ori_shape[1]
+ dy_percentage = dy / ori_shape[0]
+
+ return dx_percentage, dy_percentage
+
+def get_point_move(original_image, img, sel_pix, evt: gr.SelectData):
+ if original_image is not None:
+ img = original_image.copy()
+ else:
+ original_image = img.copy()
+ if len(sel_pix)<2:
+ sel_pix.append(evt.index)
+ else:
+ sel_pix = [evt.index]
+ points = []
+ dx, dy = 0, 0
+ for idx, point in enumerate(sel_pix):
+ if idx % 2 == 0:
+ cv2.circle(img, tuple(point), 10, (0, 0, 255), -1)
+ else:
+ cv2.circle(img, tuple(point), 10, (255, 0, 0), -1)
+ points.append(tuple(point))
+ if len(points) == 2:
+ cv2.arrowedLine(img, points[0], points[1], (255, 255, 255), 4, tipLength=0.5)
+ ori_shape = original_image.shape
+ dx, dy = calculate_translation_percentage(original_image.shape, sel_pix)
+ points = []
+ img = np.array(img)
+
+ return img, original_image, sel_pix, dx, dy
+
+def store_img(img):
+ image, mask = img["image"], np.float32(img["mask"][:, :, 0]) / 255.
+ if mask.sum() > 0:
+ mask = np.uint8(mask > 0)
+ masked_img = mask_image(image, 1 - mask, color=[0, 0, 0], alpha=0.3)
+ else:
+ masked_img = image.copy()
+
+ return image, masked_img, mask
+# im["background"], im["layers"][0]
+def store_img_move(img, mask=None):
+ if mask is not None:
+ image = img["background"]
+ return image, None, mask
+ image, mask = img["background"], np.float32(["layers"][0][:, :, 0]) / 255.
+ if mask.sum() > 0:
+ mask = np.uint8(mask > 0)
+ masked_img = mask_image(image, 1 - mask, color=[0, 0, 0], alpha=0.3)
+ else:
+ masked_img = image.copy()
+
+ return image, masked_img, (mask*255.).astype(np.uint8)
+
+def store_img_move_old(img, mask=None):
+ if mask is not None:
+ image = img["image"]
+ return image, None, mask
+ image, mask = img["image"], np.float32(img["mask"][:, :, 0]) / 255.
+ if mask.sum() > 0:
+ mask = np.uint8(mask > 0)
+ masked_img = mask_image(image, 1 - mask, color=[0, 0, 0], alpha=0.3)
+ else:
+ masked_img = image.copy()
+
+ return image, masked_img, (mask*255.).astype(np.uint8)
+
+def mask_image(image, mask, color=[255,0,0], alpha=0.5, max_resolution=None):
+ """ Overlay mask on image for visualization purpose.
+ Args:
+ image (H, W, 3) or (H, W): input image
+ mask (H, W): mask to be overlaid
+ color: the color of overlaid mask
+ alpha: the transparency of the mask
+ """
+ if max_resolution is not None:
+ image, _ = resize_numpy_image(image, max_resolution*max_resolution)
+ mask = cv2.resize(mask, (image.shape[1], image.shape[0]),interpolation=cv2.INTER_NEAREST)
+
+ out = deepcopy(image)
+ img = deepcopy(image)
+ img[mask == 1] = color
+ out = cv2.addWeighted(img, alpha, out, 1-alpha, 0, out)
+ contours = cv2.findContours(np.uint8(deepcopy(mask)), cv2.RETR_TREE,
+ cv2.CHAIN_APPROX_SIMPLE)[-2:]
+ return out
\ No newline at end of file
diff --git a/src/utils/utils.py b/src/utils/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..b127a81c0f3f87a58bbab638cd032da25cdb5d5c
--- /dev/null
+++ b/src/utils/utils.py
@@ -0,0 +1,240 @@
+import numpy as np
+import cv2
+from basicsr.utils import img2tensor
+import torch
+import torch.nn.functional as F
+
+def resize_numpy_image(image, max_resolution=768 * 768, resize_short_edge=None):
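+# Resize so that either the pixel count approaches max_resolution or the short edge
+# matches resize_short_edge, with both sides rounded to multiples of 64. For example
+# (illustrative): a 1080x1920 input under the default 768*768 budget comes out as
+# 576x1024, and the returned scale is the ratio of the new width to the original width.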
+ h, w = image.shape[:2]
+ w_org = image.shape[1]
+ if resize_short_edge is not None:
+ k = resize_short_edge / min(h, w)
+ else:
+ k = max_resolution / (h * w)
+ k = k**0.5
+ h = int(np.round(h * k / 64)) * 64
+ w = int(np.round(w * k / 64)) * 64
+ image = cv2.resize(image, (w, h), interpolation=cv2.INTER_LANCZOS4)
+ scale = w/w_org
+ return image, scale
+
+def split_ldm(ldm):
+ x = []
+ y = []
+ for p in ldm:
+ x.append(p[0])
+ y.append(p[1])
+ return x,y
+
+def process_move(path_mask, h, w, dx, dy, scale, input_scale, resize_scale, up_scale, up_ft_index, w_edit, w_content, w_contrast, w_inpaint, precision, path_mask_ref=None):
+ dx, dy = dx*input_scale, dy*input_scale
+ if isinstance(path_mask, str):
+ mask_x0 = cv2.imread(path_mask)
+ else:
+ mask_x0 = path_mask
+ mask_x0 = cv2.resize(mask_x0, (h, w))
+ if path_mask_ref is not None:
+ if isinstance(path_mask_ref, str):
+ mask_x0_ref = cv2.imread(path_mask_ref)
+ else:
+ mask_x0_ref = path_mask_ref
+ mask_x0_ref = cv2.resize(mask_x0_ref, (h, w))
+ else:
+ mask_x0_ref=None
+
+ mask_x0 = img2tensor(mask_x0)[0]
+ mask_x0 = (mask_x0>0.5).float().to('cuda', dtype=precision)
+ if mask_x0_ref is not None:
+ mask_x0_ref = img2tensor(mask_x0_ref)[0]
+ mask_x0_ref = (mask_x0_ref>0.5).float().to('cuda', dtype=precision)
+ mask_org = F.interpolate(mask_x0[None,None], (int(mask_x0.shape[-2]//scale), int(mask_x0.shape[-1]//scale)))>0.5
+
+ mask_tar = F.interpolate(mask_x0[None,None], (int(mask_x0.shape[-2]//scale*resize_scale), int(mask_x0.shape[-1]//scale*resize_scale)))>0.5
+ mask_cur = torch.roll(mask_tar, (int(dy//scale*resize_scale), int(dx//scale*resize_scale)), (-2,-1))
+
+ pad_size_x = abs(mask_tar.shape[-1]-mask_org.shape[-1])//2
+ pad_size_y = abs(mask_tar.shape[-2]-mask_org.shape[-2])//2
+ if resize_scale>1:
+ sum_before = torch.sum(mask_cur)
+ mask_cur = mask_cur[:,:,pad_size_y:pad_size_y+mask_org.shape[-2],pad_size_x:pad_size_x+mask_org.shape[-1]]
+ sum_after = torch.sum(mask_cur)
+ if sum_after != sum_before:
+ raise ValueError('Resize out of bounds, exiting.')
+ else:
+ temp = torch.zeros(1,1,mask_org.shape[-2], mask_org.shape[-1]).to(mask_org.device)
+ temp[:,:,pad_size_y:pad_size_y+mask_cur.shape[-2],pad_size_x:pad_size_x+mask_cur.shape[-1]]=mask_cur
+ mask_cur =temp>0.5
+
+ mask_other = (1-((mask_cur+mask_org)>0.5).float())>0.5
+ mask_overlap = ((mask_cur.float()+mask_org.float())>1.5).float()
+ mask_non_overlap = (mask_org.float()-mask_overlap)>0.5
+
+ return {
+ "mask_x0":mask_x0,
+ "mask_x0_ref":mask_x0_ref,
+ "mask_tar":mask_tar,
+ "mask_cur":mask_cur,
+ "mask_other":mask_other,
+ "mask_overlap":mask_overlap,
+ "mask_non_overlap":mask_non_overlap,
+ "up_scale":up_scale,
+ "up_ft_index":up_ft_index,
+ "resize_scale":resize_scale,
+ "w_edit":w_edit,
+ "w_content":w_content,
+ "w_contrast":w_contrast,
+ "w_inpaint":w_inpaint,
+ }
+
+def process_drag_face(h, w, x, y, x_cur, y_cur, scale, input_scale, up_scale, up_ft_index, w_edit, w_inpaint, precision):
+ for i in range(len(x)):
+ x[i] = int(x[i]*input_scale)
+ y[i] = int(y[i]*input_scale)
+ x_cur[i] = int(x_cur[i]*input_scale)
+ y_cur[i] = int(y_cur[i]*input_scale)
+
+ mask_tar = []
+ for p_idx in range(len(x)):
+ mask_i = torch.zeros(int(h//scale), int(w//scale)).cuda()
+ y_clip = int(np.clip(y[p_idx]//scale, 1, mask_i.shape[0]-2))
+ x_clip = int(np.clip(x[p_idx]//scale, 1, mask_i.shape[1]-2))
+ mask_i[y_clip-1:y_clip+2,x_clip-1:x_clip+2]=1
+ mask_i = mask_i>0.5
+ mask_tar.append(mask_i)
+ mask_cur = []
+ for p_idx in range(len(x_cur)):
+ mask_i = torch.zeros(int(h//scale), int(w//scale)).cuda()
+ y_clip = int(np.clip(y_cur[p_idx]//scale, 1, mask_i.shape[0]-2))
+ x_clip = int(np.clip(x_cur[p_idx]//scale, 1, mask_i.shape[1]-2))
+ mask_i[y_clip-1:y_clip+2,x_clip-1:x_clip+2]=1
+ mask_i=mask_i>0.5
+ mask_cur.append(mask_i)
+
+ return {
+ "mask_tar":mask_tar,
+ "mask_cur":mask_cur,
+ "up_scale":up_scale,
+ "up_ft_index":up_ft_index,
+ "w_edit": w_edit,
+ "w_inpaint": w_inpaint,
+ }
+
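+# Like process_drag_face, but also returns the region outside the user mask
+# (mask_other) and initializes the target latent patch from the source patch.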
+def process_drag(path_mask, h, w, x, y, x_cur, y_cur, scale, input_scale, up_scale, up_ft_index, w_edit, w_inpaint, w_content, precision, latent_in):
+ if isinstance(path_mask, str):
+ mask_x0 = cv2.imread(path_mask)
+ else:
+ mask_x0 = path_mask
+    mask_x0 = cv2.resize(mask_x0, (w, h))  # cv2.resize expects (width, height)
+ mask_x0 = img2tensor(mask_x0)[0]
+ dict_mask = {}
+ dict_mask['base'] = mask_x0
+ mask_x0 = (mask_x0>0.5).float().to('cuda', dtype=precision)
+
+ mask_other = F.interpolate(mask_x0[None,None], (int(mask_x0.shape[-2]//scale), int(mask_x0.shape[-1]//scale)))<0.5
+ mask_tar = []
+ mask_cur = []
+ for p_idx in range(len(x)):
+ mask_tar_i = torch.zeros(int(mask_x0.shape[-2]//scale), int(mask_x0.shape[-1]//scale)).to('cuda', dtype=precision)
+ mask_cur_i = torch.zeros(int(mask_x0.shape[-2]//scale), int(mask_x0.shape[-1]//scale)).to('cuda', dtype=precision)
+        # clip y against the mask height (dim 0) and x against the width (dim 1)
+        y_tar_clip = int(np.clip(y[p_idx]//scale, 1, mask_tar_i.shape[0]-2))
+        x_tar_clip = int(np.clip(x[p_idx]//scale, 1, mask_tar_i.shape[1]-2))
+        y_cur_clip = int(np.clip(y_cur[p_idx]//scale, 1, mask_cur_i.shape[0]-2))
+        x_cur_clip = int(np.clip(x_cur[p_idx]//scale, 1, mask_cur_i.shape[1]-2))
+ mask_tar_i[y_tar_clip-1:y_tar_clip+2,x_tar_clip-1:x_tar_clip+2]=1
+ mask_cur_i[y_cur_clip-1:y_cur_clip+2,x_cur_clip-1:x_cur_clip+2]=1
+ mask_tar_i = mask_tar_i>0.5
+ mask_cur_i=mask_cur_i>0.5
+ mask_tar.append(mask_tar_i)
+ mask_cur.append(mask_cur_i)
+ latent_in[:,:,y_cur_clip//up_scale-1:y_cur_clip//up_scale+2, x_cur_clip//up_scale-1:x_cur_clip//up_scale+2] = latent_in[:,:, y_tar_clip//up_scale-1:y_tar_clip//up_scale+2, x_tar_clip//up_scale-1:x_tar_clip//up_scale+2]
+
+
+ return {
+ "dict_mask":dict_mask,
+ "mask_x0":mask_x0,
+ "mask_tar":mask_tar,
+ "mask_cur":mask_cur,
+ "mask_other":mask_other,
+ "up_scale":up_scale,
+ "up_ft_index":up_ft_index,
+ "w_edit": w_edit,
+ "w_inpaint": w_inpaint,
+ "w_content": w_content,
+ "latent_in":latent_in,
+ }
+
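+# Prepare base / replace masks at both image and latent resolution for
+# appearance (texture) transfer between two regions.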
+def process_appearance(path_mask, path_mask_replace, h, w, scale, input_scale, up_scale, up_ft_index, w_edit, w_content, precision):
+ if isinstance(path_mask, str):
+ mask_base = cv2.imread(path_mask)
+ else:
+ mask_base = path_mask
+    mask_base = cv2.resize(mask_base, (w, h))  # cv2.resize expects (width, height)
+ if isinstance(path_mask_replace, str):
+ mask_replace = cv2.imread(path_mask_replace)
+ else:
+ mask_replace = path_mask_replace
+    mask_replace = cv2.resize(mask_replace, (w, h))
+
+ dict_mask = {}
+ mask_base = img2tensor(mask_base)[0]
+ dict_mask['base'] = mask_base
+ mask_base = (mask_base>0.5).to('cuda', dtype=precision)
+ mask_replace = img2tensor(mask_replace)[0]
+ dict_mask['replace'] = mask_replace
+ mask_replace = (mask_replace>0.5).to('cuda', dtype=precision)
+
+ mask_base_cur = F.interpolate(mask_base[None,None], (int(mask_base.shape[-2]//scale), int(mask_base.shape[-1]//scale)))>0.5
+ mask_replace_cur = F.interpolate(mask_replace[None,None], (int(mask_replace.shape[-2]//scale), int(mask_replace.shape[-1]//scale)))>0.5
+
+ return {
+ "dict_mask":dict_mask,
+ "mask_base_cur":mask_base_cur,
+ "mask_replace_cur":mask_replace_cur,
+ "up_scale":up_scale,
+ "up_ft_index":up_ft_index,
+ "w_edit":w_edit,
+ "w_content":w_content,
+ }
+
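+# Prepare masks for cross-image pasting: optionally rescale the object mask,
+# then roll it by (dx, dy) to the paste location; returns the pasted-location
+# and source-location masks at both image and latent resolution.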
+def process_paste(path_mask, h, w, dx, dy, scale, input_scale, up_scale, up_ft_index, w_edit, w_content, precision, resize_scale=None):
+ dx, dy = dx*input_scale, dy*input_scale
+ if isinstance(path_mask, str):
+ mask_base = cv2.imread(path_mask)
+ else:
+ mask_base = path_mask
+    mask_base = cv2.resize(mask_base, (w, h))  # cv2.resize expects (width, height)
+
+ dict_mask = {}
+ mask_base = img2tensor(mask_base)[0][None, None]
+ mask_base = (mask_base>0.5).to('cuda', dtype=precision)
+ if resize_scale is not None and resize_scale!=1:
+ hi, wi = mask_base.shape[-2], mask_base.shape[-1]
+ mask_base = F.interpolate(mask_base, (int(hi*resize_scale), int(wi*resize_scale)))
+ pad_size_x = np.abs(mask_base.shape[-1]-wi)//2
+ pad_size_y = np.abs(mask_base.shape[-2]-hi)//2
+ if resize_scale>1:
+ mask_base = mask_base[:,:,pad_size_y:pad_size_y+hi,pad_size_x:pad_size_x+wi]
+ else:
+ temp = torch.zeros(1,1,hi, wi).to(mask_base.device)
+ temp[:,:,pad_size_y:pad_size_y+mask_base.shape[-2],pad_size_x:pad_size_x+mask_base.shape[-1]]=mask_base
+ mask_base = temp
+ mask_replace = mask_base.clone()
+ mask_base = torch.roll(mask_base, (int(dy), int(dx)), (-2,-1))
+ dict_mask['base'] = mask_base[0,0]
+ dict_mask['replace'] = mask_replace[0,0]
+ mask_replace = (mask_replace>0.5).to('cuda', dtype=precision)
+
+ mask_base_cur = F.interpolate(mask_base, (int(mask_base.shape[-2]//scale), int(mask_base.shape[-1]//scale)))>0.5
+ mask_replace_cur = torch.roll(mask_base_cur, (-int(dy/scale), -int(dx/scale)), (-2,-1))
+
+ return {
+ "dict_mask":dict_mask,
+ "mask_base_cur":mask_base_cur,
+ "mask_replace_cur":mask_replace_cur,
+ "up_scale":up_scale,
+ "up_ft_index":up_ft_index,
+ "w_edit":w_edit,
+ "w_content":w_content,
+ "w_edit":w_edit,
+ "w_content":w_content,
+ }
\ No newline at end of file
diff --git a/utils/inversion.py b/utils/inversion.py
new file mode 100755
index 0000000000000000000000000000000000000000..16ce1796d3a905a1596300b39cdfa940d49d0e15
--- /dev/null
+++ b/utils/inversion.py
@@ -0,0 +1,265 @@
+import torch
+import numpy as np
+from PIL import Image
+from typing import Optional, Union, Tuple, List
+from tqdm import tqdm
+import os
+from diffusers import DDIMInverseScheduler,DPMSolverMultistepInverseScheduler
+import spaces
+
+class Inversion:
+
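+    # next_step(): reverse DDIM update that maps x_t to the next (noisier) latent
+    # x_{t+1} from the predicted noise; this is the core step of DDIM inversion.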
+ def next_step(self, model_output: Union[torch.FloatTensor, np.ndarray], timestep: int,
+ sample: Union[torch.FloatTensor, np.ndarray]):
+ timestep, next_timestep = min(
+ timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps, 999), timestep
+ alpha_prod_t = self.scheduler.alphas_cumprod[timestep] if timestep >= 0 else self.scheduler.final_alpha_cumprod
+ alpha_prod_t_next = self.scheduler.alphas_cumprod[next_timestep]
+ beta_prod_t = 1 - alpha_prod_t
+ next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
+ next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output
+ next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction
+ return next_sample
+
+ @torch.no_grad()
+ def get_noise_pred_single(self, latents, t, context,cond=True,both=False):
+ added_cond_id=1 if cond else 0
+ do_classifier_free_guidance=False
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+ if both is False:
+ added_cond_kwargs = {"text_embeds": self.add_text_embeds[added_cond_id].unsqueeze(0).repeat(self.inv_batch_size,1), "time_ids": self.add_time_ids[added_cond_id].unsqueeze(0).repeat(self.inv_batch_size,1)}
+ else:
+ added_cond_kwargs = {"text_embeds": self.add_text_embeds, "time_ids": self.add_time_ids}
+ noise_pred = self.model.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=context,
+ cross_attention_kwargs=None,
+ added_cond_kwargs=added_cond_kwargs,
+ return_dict=False,
+ )[0]
+ return noise_pred
+
+ @torch.no_grad()
+ def latent2image(self, latents, return_type='np'):
+ latents = 1 / self.model.vae.config.scaling_factor * latents.detach()
+ self.model.vae.to(dtype=torch.float32)
+ image = self.model.vae.decode(latents)['sample']
+ if return_type == 'np':
+ image = (image / 2 + 0.5).clamp(0, 1)
+ image = image.cpu().permute(0, 2, 3, 1).numpy()
+ image = (image * 255).astype(np.uint8)
+ return image
+
+ @torch.no_grad()
+ @spaces.GPU
+ def image2latent(self, image):
+ with torch.no_grad():
+            if isinstance(image, Image.Image):
+                image = np.array(image)
+            if image.ndim == 3:
+                image = np.expand_dims(image, 0)
+ image = torch.from_numpy(image).float() / 127.5 - 1
+ image = image.permute(0, 3, 1, 2).to(self.device)
+ print(f"Running on device: {self.device}")
+ latents=[]
+ for i,_ in enumerate(image):
+ latent=self.model.vae.encode(image[i:i+1])['latent_dist'].mean
+ latents.append(latent)
+ latents=torch.stack(latents).squeeze(1)
+ latents = latents * self.model.vae.config.scaling_factor
+ return latents
+
+ @torch.no_grad()
+ def init_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ original_size: Optional[Tuple[int, int]] = None,
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
+ target_size: Optional[Tuple[int, int]] = None,
+ ):
+ original_size = original_size or (1024, 1024)
+ target_size = target_size or (1024, 1024)
+ # 3. Encode input prompt
+ do_classifier_free_guidance=True
+ (
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ ) = self.model.encode_prompt_not_zero_uncond(
+ prompt,
+ self.model.device,
+ 1,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ pooled_prompt_embeds=None,
+ negative_pooled_prompt_embeds=None,
+ lora_scale=None,
+ )
+ prompt_embeds=prompt_embeds[:self.inv_batch_size]
+ negative_prompt_embeds=negative_prompt_embeds[:self.inv_batch_size]
+ pooled_prompt_embeds=pooled_prompt_embeds[:self.inv_batch_size]
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds[:self.inv_batch_size]
+ # 7. Prepare added time ids & embeddings
+ add_text_embeds = pooled_prompt_embeds
+ add_time_ids = self.model._get_add_time_ids(
+ original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
+ )
+
+ if do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
+ add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
+
+ prompt_embeds = prompt_embeds.to(self.device)
+ self.add_text_embeds = add_text_embeds.to(self.device)
+ self.add_time_ids = add_time_ids.to(self.device).repeat(self.inv_batch_size * 1, 1)
+
+ self.prompt_embeds=prompt_embeds
+ self.negative_prompt_embeds=negative_prompt_embeds
+ self.pooled_prompt_embeds=pooled_prompt_embeds
+ self.negative_pooled_prompt_embeds=negative_pooled_prompt_embeds
+ self.prompt = prompt
+ self.context=prompt_embeds
+
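+    # ddim_loop(): walk the scheduler's timesteps in reverse, repeatedly applying
+    # next_step with the conditional noise prediction, and collect every inverted latent.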
+ @torch.no_grad()
+ @spaces.GPU
+ def ddim_loop(self, latent):
+ uncond_embeddings, cond_embeddings = self.context.chunk(2)
+ all_latent = [latent]
+ latent = latent.clone().detach()
+ extra_step_kwargs = self.model.prepare_extra_step_kwargs(self.generator, self.eta)
+ if isinstance(self.inverse_scheduler,DDIMInverseScheduler):
+ extra_step_kwargs.pop("generator")
+ for i in tqdm(range(self.num_ddim_steps)):
+ use_inv_sc=False
+ if use_inv_sc:
+ t = self.inverse_scheduler.timesteps[i]
+ noise_pred = self.get_noise_pred_single(latent, t, cond_embeddings,cond=True)
+ latent = self.inverse_scheduler.step(noise_pred, t, latent, **extra_step_kwargs, return_dict=False)[0]
+ else:
+ t = self.model.scheduler.timesteps[len(self.model.scheduler.timesteps) - i - 1]
+ noise_pred = self.get_noise_pred_single(latent, t, cond_embeddings,cond=True)
+ latent = self.next_step(noise_pred, t, latent)
+ all_latent.append(latent)
+ return all_latent
+
+ @property
+ def scheduler(self):
+ return self.model.scheduler
+
+ @torch.no_grad()
+ @spaces.GPU
+ def ddim_inversion(self, image):
+ latent = self.image2latent(image)
+ image_rec = self.latent2image(latent)
+ ddim_latents = self.ddim_loop(latent.to(self.model.unet.dtype))
+ return image_rec, ddim_latents
+
+    # invert(): encode the prompt, run DDIM inversion on image_gt, and return the
+    # (source, reconstruction) pair, the final noise latent, all intermediate
+    # latents (x_stars), and the prompt / pooled prompt embeddings for sampling.
+ @spaces.GPU
+ def invert(self, image_gt, prompt: Union[str, List[str]],
+ verbose=True, inv_output_pos=None, inv_batch_size=1):
+
+ self.inv_batch_size = inv_batch_size
+ self.init_prompt(prompt)
+ out_put_pos = 0 if inv_output_pos is None else inv_output_pos
+ self.out_put_pos = out_put_pos
+ if verbose:
+ print("DDIM inversion...")
+ image_rec, ddim_latents = self.ddim_inversion(image_gt)
+ if verbose:
+ print("Done.")
+ return (image_gt, image_rec), ddim_latents[-1], ddim_latents, self.prompt_embeds[self.prompt_embeds.shape[0]//2:], self.pooled_prompt_embeds
+
+ def __init__(self, model,num_ddim_steps,generator=None,scheduler_type="DDIM"):
+ self.model = model
+ self.tokenizer = self.model.tokenizer
+ self.num_ddim_steps=num_ddim_steps
+ if scheduler_type == "DDIM":
+ self.inverse_scheduler=DDIMInverseScheduler.from_config(self.model.scheduler.config)
+ self.inverse_scheduler.set_timesteps(num_ddim_steps)
+ elif scheduler_type=="DPMSolver":
+ self.inverse_scheduler=DPMSolverMultistepInverseScheduler.from_config(self.model.scheduler.config)
+ self.inverse_scheduler.set_timesteps(num_ddim_steps)
+ self.model.scheduler.set_timesteps(num_ddim_steps)
+ self.model.vae.to(dtype=torch.float32)
+ self.prompt = None
+ self.context = None
+ # self.device=self.model.unet.device
+ self.device = torch.device("cuda:0")
+ self.generator=generator
+ self.eta=0.0
+
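+# Image / mask loading helpers: crop by the given margins, center-crop to a
+# square, and resize to the resolution expected by the VAE (512 or 1024) or to
+# the latent-space mask size (default 128x128).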
+def load_512(image_path, left=0, right=0, top=0, bottom=0):
+ if type(image_path) is str:
+ image = np.array(Image.open(image_path))[:, :, :3]
+ else:
+ image = image_path
+ h, w, c = image.shape
+ left = min(left, w - 1)
+ right = min(right, w - left - 1)
+    top = min(top, h - 1)
+ bottom = min(bottom, h - top - 1)
+ image = image[top:h - bottom, left:w - right]
+ h, w, c = image.shape
+ if h < w:
+ offset = (w - h) // 2
+ image = image[:, offset:offset + h]
+ elif w < h:
+ offset = (h - w) // 2
+ image = image[offset:offset + w]
+ image = np.array(Image.fromarray(image).resize((512, 512)))
+ return image
+
+def load_1024_mask(image_path, left=0, right=0, top=0, bottom=0, target_H=128, target_W=128):
+ if type(image_path) is str:
+ image = np.array(Image.open(image_path))[:, :, np.newaxis]
+ else:
+ image = image_path
+ if len(image.shape) == 4:
+ image = image[:, :, :, 0]
+ h, w, c = image.shape
+ left = min(left, w - 1)
+ right = min(right, w - left - 1)
+    top = min(top, h - 1)
+ bottom = min(bottom, h - top - 1)
+ image = image[top:h - bottom, left:w - right]
+ h, w, c = image.shape
+ if h < w:
+ offset = (w - h) // 2
+ image = image[:, offset:offset + h]
+ elif w < h:
+ offset = (h - w) // 2
+ image = image[offset:offset + w]
+ image=image.squeeze()
+    image = np.array(Image.fromarray(image).resize((target_W, target_H)))  # PIL resize expects (width, height)
+ return image
+
+def load_1024(image_path, left=0, right=0, top=0, bottom=0):
+ if type(image_path) is str:
+ image = np.array(Image.open(image_path).resize((1024, 1024)))[:, :, :3]
+ else:
+ image = image_path
+ h, w, c = image.shape
+ left = min(left, w - 1)
+ right = min(right, w - left - 1)
+ top = min(top, h - left - 1)
+ bottom = min(bottom, h - top - 1)
+ image = image[top:h - bottom, left:w - right]
+ h, w, c = image.shape
+ if h < w:
+ offset = (w - h) // 2
+ image = image[:, offset:offset + h]
+ elif w < h:
+ offset = (h - w) // 2
+ image = image[offset:offset + w]
+ image = np.array(Image.fromarray(image).resize((1024, 1024)))
+ return image
\ No newline at end of file
diff --git a/utils/sdxl.py b/utils/sdxl.py
new file mode 100755
index 0000000000000000000000000000000000000000..30e3cee65b303fa704785adb6394a39eebb523a4
--- /dev/null
+++ b/utils/sdxl.py
@@ -0,0 +1,986 @@
+
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+# import seaborn as sns
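+# (seaborn is only used by the optional save_heatmap debug plots in exec_classifier_free_guidance;
+#  re-enable the import above before setting save_heatmap=True)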
+import matplotlib.pyplot as plt
+import torch
+from diffusers import StableDiffusionXLPipeline
+from typing import Optional, Union, Tuple, List, Callable, Dict
+import numpy as np
+import copy
+import torch.nn.functional as F
+from diffusers.loaders import LoraLoaderMixin, TextualInversionLoaderMixin
+from diffusers.models.attention_processor import ( AttnProcessor2_0, LoRAAttnProcessor2_0, LoRAXFormersAttnProcessor, XFormersAttnProcessor, )
+from diffusers.utils import ( logging, randn_tensor, replace_example_docstring, )
+from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput
+from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg
+import os
+logger = logging.get_logger(__name__)
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import StableDiffusionXLPipeline
+
+ >>> pipe = StableDiffusionXLPipeline.from_pretrained(
+ ... "stabilityai/stable-diffusion-xl-base-0.9", torch_dtype=torch.float16
+ ... )
+ >>> pipe = pipe.to("cuda")
+
+ >>> prompt = "a photo of an astronaut riding a horse on mars"
+ >>> image = pipe(prompt).images[0]
+ ```
+"""
+
+
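+# StableDiffusionXLPipeline subclass used by DesignEdit: it can share initial
+# latents across a batch (same_init / sample_ref_match), applies proximal
+# regularization to the classifier-free-guidance delta, pulls unedited regions
+# back toward the inverted latents (x_stars), and calls the attention
+# controller's step_callback after every denoising step.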
+class sdxl(StableDiffusionXLPipeline):
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ @torch.no_grad()
+ def __call__(
+ self,
+ controller=None,
+ prompt: Union[str, List[str]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ guidance_rescale: float = 0.0,
+ original_size: Optional[Tuple[int, int]] = None,
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
+ target_size: Optional[Tuple[int, int]] = None,
+ same_init=False,
+ x_stars=None,
+ prox_guidance=True,
+ masa_control=False,
+ masa_mask=False,
+ masa_start_step=40,
+ masa_start_layer=55,
+ mask_file=None,
+ query_mask_time=[0, 10],
+ **kwargs
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+                instead.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+ input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
+            guidance_rescale (`float`, *optional*, defaults to 0.0):
+ Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
+ [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+ Guidance rescale factor should fix overexposure when using zero terminal SNR.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
+ `tuple. When returning a tuple, the first element is a list with the generated images, and the second
+ element is a list of `bool`s denoting whether the corresponding generated image likely represents
+ "not-safe-for-work" (nsfw) content, according to the `safety_checker`.
+ """
+ # 0. Default height and width to unet
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+
+ original_size = original_size or (height, width)
+ target_size = target_size or (height, width)
+
+ inv_batch_size = len(latents) if latents is not None else 1
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ height,
+ width,
+ callback_steps,
+ negative_prompt,
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ )
+
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ text_encoder_lora_scale = (
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+ )
+ (
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ ) = self.encode_prompt(
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ lora_scale=text_encoder_lora_scale,
+ sample_ref_match=kwargs['sample_ref_match'] if 'sample_ref_match' in kwargs else None,
+ )
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ same_init=same_init, #ADD
+ sample_ref_match=kwargs['sample_ref_match'] if 'sample_ref_match' in kwargs else None,
+ )
+ # 6. Prepare extra step kwargs.
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 7. Prepare added time ids & embeddings
+ add_text_embeds = pooled_prompt_embeds
+ add_time_ids = self._get_add_time_ids(
+ original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
+ )
+
+ if do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
+ add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
+
+ prompt_embeds = prompt_embeds.to(device)
+ add_text_embeds = add_text_embeds.to(device)
+ add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
+
+ # 8. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ added_cond_kwargs=added_cond_kwargs,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ # CHANGE START
+ score_delta,mask_edit=self.prox_regularization(
+ noise_pred_uncond,
+ noise_pred_text,
+ i,
+ t,
+ prox_guidance=prox_guidance,
+ )
+ noise_pred = noise_pred_uncond + guidance_scale * score_delta
+ # CHANGE END
+
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ # ADD START
+ latents = self.proximal_guidance(
+ i,
+ t,
+ latents,
+ mask_edit,
+ prox_guidance=prox_guidance,
+ dtype=self.unet.dtype,
+ x_stars=x_stars,
+ controller=controller,
+ sample_ref_match=kwargs['sample_ref_match'] if 'sample_ref_match' in kwargs else None,
+ inv_batch_size=inv_batch_size,
+ only_inversion_align=kwargs['only_inversion_align'] if 'only_inversion_align' in kwargs else False,
+ )
+ # ADD END
+ if controller is not None:
+ latents = controller.step_callback(latents)
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
+
+ # make sure the VAE is in float32 mode, as it overflows in float16
+ self.vae.to(dtype=torch.float32)
+
+ use_torch_2_0_or_xformers = isinstance(
+ self.vae.decoder.mid_block.attentions[0].processor,
+ (
+ AttnProcessor2_0,
+ XFormersAttnProcessor,
+ LoRAXFormersAttnProcessor,
+ LoRAAttnProcessor2_0,
+ ),
+ )
+ # if xformers or torch_2_0 is used attention block does not need
+ # to be in float32 which can save lots of memory
+ if use_torch_2_0_or_xformers:
+ self.vae.post_quant_conv.to(latents.dtype)
+ self.vae.decoder.conv_in.to(latents.dtype)
+ self.vae.decoder.mid_block.to(latents.dtype)
+ else:
+ latents = latents.float()
+
+ if not output_type == "latent":
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+ else:
+ image = latents
+ return StableDiffusionXLPipelineOutput(images=image)
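+        # NOTE: the early return above bypasses the watermarking / post-processing
+        # code below, which is left unreachable in this modified pipeline.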
+
+ image = self.watermark.apply_watermark(image)
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload last model to CPU
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
+
+ return image
+
+    # Adapted from diffusers' StableDiffusionPipeline.prepare_latents, extended with same_init and sample_ref_match
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None,same_init=False,sample_ref_match=None):
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+ if sample_ref_match is not None:
+ new_latents=randn_tensor((batch_size,*shape[1:]), generator=generator, device=device, dtype=dtype)
+ for key,value in sample_ref_match.items():
+ new_latents[key]=latents[value].clone()
+ latents=new_latents
+ else:
+ if same_init is True:
+ if latents is None:
+ latents = randn_tensor((1,*shape[1:]), generator=generator, device=device, dtype=dtype).expand(shape).to(device)
+ else:
+ if batch_size>1 and latents.shape[0]==1:
+ latents=latents.expand(shape).to(device)
+ else:
+ latents = latents.to(device)
+ else:
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ def encode_prompt(
+ self,
+ prompt,
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ do_classifier_free_guidance: bool = True,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ lora_scale: Optional[float] = None,
+ sample_ref_match=None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+ input argument.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ """
+ device = device or self._execution_device
+
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # Define tokenizers and text encoders
+ tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
+ text_encoders = (
+ [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
+ )
+
+ if prompt_embeds is None:
+            # textual inversion: process multi-vector tokens if necessary
+ prompt_embeds_list = []
+ for tokenizer, text_encoder in zip(tokenizers, text_encoders):
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, tokenizer)
+
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = text_encoder(
+ text_input_ids.to(device),
+ output_hidden_states=True,
+ )
+
+ # We are only ALWAYS interested in the pooled output of the final text encoder
+ pooled_prompt_embeds = prompt_embeds[0]
+ prompt_embeds = prompt_embeds.hidden_states[-2]
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ prompt_embeds_list.append(prompt_embeds)
+
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+
+ # get unconditional embeddings for classifier free guidance
+ zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
+ if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
+ negative_prompt_embeds = torch.zeros_like(prompt_embeds)
+ negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
+ elif do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ uncond_tokens: List[str]
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+
+ negative_prompt_embeds_list = []
+ for tokenizer, text_encoder in zip(tokenizers, text_encoders):
+                # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ negative_prompt_embeds = text_encoder(
+ uncond_input.input_ids.to(device),
+ output_hidden_states=True,
+ )
+ # We are only ALWAYS interested in the pooled output of the final text encoder
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
+ negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=text_encoder.dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(
+ batch_size * num_images_per_prompt, seq_len, -1
+ )
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+
+ negative_prompt_embeds_list.append(negative_prompt_embeds)
+
+ negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
+
+ bs_embed = pooled_prompt_embeds.shape[0]
+ # ADD START
+ if sample_ref_match is not None:
+ new_negative_prompt_embeds=torch.zeros_like(prompt_embeds)
+ new_negative_pooled_prompt_embeds=torch.zeros_like(pooled_prompt_embeds)
+ for key,value in sample_ref_match.items():
+ new_negative_prompt_embeds[key]=negative_prompt_embeds[value].clone()
+ new_negative_pooled_prompt_embeds[key]=negative_pooled_prompt_embeds[value].clone()
+ negative_prompt_embeds=new_negative_prompt_embeds
+ negative_pooled_prompt_embeds=new_negative_pooled_prompt_embeds
+ else:
+ if negative_pooled_prompt_embeds.shape[0]==1 and bs_embed!=1:
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds.repeat(bs_embed,1)
+ if negative_prompt_embeds.shape[0]==1 and bs_embed!=1:
+ negative_prompt_embeds=negative_prompt_embeds.repeat(bs_embed,1,1)
+ # ADD END
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+ bs_embed * num_images_per_prompt, -1
+ )
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+ bs_embed * num_images_per_prompt, -1
+ )
+
+ return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
+
+ def encode_prompt_not_zero_uncond(
+ self,
+ prompt,
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ do_classifier_free_guidance: bool = True,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ lora_scale: Optional[float] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+ input argument.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ """
+ device = device or self._execution_device
+
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # Define tokenizers and text encoders
+ tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
+ text_encoders = (
+ [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
+ )
+
+ if prompt_embeds is None:
+            # textual inversion: process multi-vector tokens if necessary
+ prompt_embeds_list = []
+ for tokenizer, text_encoder in zip(tokenizers, text_encoders):
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, tokenizer)
+
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = text_encoder(text_input_ids.to(device),output_hidden_states=True)
+
+ # We are only ALWAYS interested in the pooled output of the final text encoder
+ pooled_prompt_embeds = prompt_embeds[0]
+ prompt_embeds = prompt_embeds.hidden_states[-2]
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ prompt_embeds_list.append(prompt_embeds)
+
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ uncond_tokens: List[str]
+            if prompt is not None and isinstance(prompt, list) and negative_prompt == "":
+ negative_prompt = ["" for i in range(len(prompt))]
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ negative_prompt_embeds_list = []
+ for tokenizer, text_encoder in zip(tokenizers, text_encoders):
+                # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ negative_prompt_embeds = text_encoder(
+ uncond_input.input_ids.to(device),
+ output_hidden_states=True,
+ )
+ # We are only ALWAYS interested in the pooled output of the final text encoder
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
+ negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=text_encoder.dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(
+ batch_size * num_images_per_prompt, seq_len, -1
+ )
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+
+ negative_prompt_embeds_list.append(negative_prompt_embeds)
+
+ negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
+
+ bs_embed = pooled_prompt_embeds.shape[0]
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+ bs_embed * num_images_per_prompt, -1
+ )
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+ bs_embed * num_images_per_prompt, -1
+ )
+
+ return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
+
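+    # prox_regularization(): clip the CFG score delta (noise_pred_text - noise_pred_uncond)
+    # at a per-step quantile threshold ('l1' soft shrinkage, 'l0' hard clipping, or no
+    # clipping by default) and, for timesteps below recon_t, build a dilated mask of
+    # where the edit is actually active.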
+ def prox_regularization(
+ self,
+ noise_pred_uncond,
+ noise_pred_text,
+ i,
+ t,
+ prox_guidance=False,
+ prox=None,
+ quantile=0.75,
+ recon_t=400,
+ dilate_radius=2,
+ ):
+ if prox_guidance is True:
+ mask_edit = None
+ if prox == 'l1':
+ score_delta = (noise_pred_text - noise_pred_uncond).float()
+ if quantile > 0:
+ threshold = score_delta.abs().quantile(quantile)
+ else:
+ threshold = -quantile # if quantile is negative, use it as a fixed threshold
+ score_delta -= score_delta.clamp(-threshold, threshold)
+ score_delta = torch.where(score_delta > 0, score_delta-threshold, score_delta)
+ score_delta = torch.where(score_delta < 0, score_delta+threshold, score_delta)
+ if (recon_t > 0 and t < recon_t) or (recon_t < 0 and t > -recon_t):
+ mask_edit = (score_delta.abs() > threshold).float()
+ if dilate_radius > 0:
+ radius = int(dilate_radius)
+ mask_edit = dilate(mask_edit.float(), kernel_size=2*radius+1, padding=radius)
+ elif prox == 'l0':
+ score_delta = (noise_pred_text - noise_pred_uncond).float()
+ if quantile > 0:
+ threshold = score_delta.abs().quantile(quantile)
+ else:
+ threshold = -quantile # if quantile is negative, use it as a fixed threshold
+ score_delta -= score_delta.clamp(-threshold, threshold)
+ if (recon_t > 0 and t < recon_t) or (recon_t < 0 and t > -recon_t):
+ mask_edit = (score_delta.abs() > threshold).float()
+ if dilate_radius > 0:
+ radius = int(dilate_radius)
+ mask_edit = dilate(mask_edit.float(), kernel_size=2*radius+1, padding=radius)
+            elif prox is None:
+ score_delta = (noise_pred_text - noise_pred_uncond).float()
+ if quantile > 0:
+ threshold = score_delta.abs().quantile(quantile)
+ else:
+ threshold = -quantile # if quantile is negative, use it as a fixed threshold
+ if (recon_t > 0 and t < recon_t) or (recon_t < 0 and t > -recon_t):
+ mask_edit = (score_delta.abs() > threshold).float()
+ if dilate_radius > 0:
+ radius = int(dilate_radius)
+ mask_edit = dilate(mask_edit.float(), kernel_size=2*radius+1, padding=radius)
+ else:
+ raise NotImplementedError
+ return score_delta,mask_edit
+ else:
+ return noise_pred_text - noise_pred_uncond,None
+
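+    # proximal_guidance(): outside the edit mask (merged with any layer-fusion remove
+    # mask), pull the current latent back toward the matched inversion latent from
+    # x_stars at rate recon_lr, so unedited regions stay faithful to the source image.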
+ def proximal_guidance(
+ self,
+ i,
+ t,
+ latents,
+ mask_edit,
+ dtype,
+ prox_guidance=False,
+ recon_t=400,
+ recon_end=0,
+ recon_lr=0.1,
+ x_stars=None,
+ controller=None,
+ sample_ref_match=None,
+ inv_batch_size=1,
+ only_inversion_align=False,
+ ):
+        if mask_edit is not None and prox_guidance and ((recon_t > recon_end and t < recon_t) or (recon_t < -recon_end and t > -recon_t)):
+ if controller.layer_fusion.remove_mask is not None:
+ fix_mask = copy.deepcopy(controller.layer_fusion.remove_mask)
+ mask_edit[1] = (mask_edit[1]+fix_mask).clamp(0,1)
+ if mask_edit.shape[0] > 2:
+ mask_edit[2].fill_(1)
+ recon_mask = 1 - mask_edit
+ target_latents=x_stars[len(x_stars)-i-2]
+ new_target_latents=torch.zeros_like(latents)
+ for key,value in sample_ref_match.items():
+ new_target_latents[key]=target_latents[value].clone()
+ latents = latents - recon_lr * (latents - new_target_latents) * recon_mask
+ return latents.to(dtype)
+
+def slerp(val, low, high):
+ """ taken from https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/4
+ """
+ low_norm = low/torch.norm(low, dim=1, keepdim=True)
+ high_norm = high/torch.norm(high, dim=1, keepdim=True)
+ omega = torch.acos((low_norm*high_norm).sum(1))
+ so = torch.sin(omega)
+ res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low + (torch.sin(val*omega)/so).unsqueeze(1) * high
+ return res
+
+
+def slerp_tensor(val, low, high):
+ shape = low.shape
+ res = slerp(val, low.flatten(1), high.flatten(1))
+ return res.reshape(shape)
+
+
+def dilate(image, kernel_size, stride=1, padding=0):
+ """
+ Perform dilation on a binary image using a square kernel.
+ """
+ # Ensure the image is binary
+ assert image.max() <= 1 and image.min() >= 0
+
+ # Get the maximum value in each neighborhood
+ dilated_image = F.max_pool2d(image, kernel_size, stride, padding)
+
+ return dilated_image
+
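+# Function-level variant of the guidance logic above: applies CFG with optional
+# proximal clipping, steps the scheduler, and can blend latents back toward the
+# inversion trajectory (x_stars) outside the edit / LocalBlend mask.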
+def exec_classifier_free_guidance(model,latents,controller,t,guidance_scale,
+ do_classifier_free_guidance,noise_pred,guidance_rescale,
+ prox=None, quantile=0.75,image_enc=None, recon_lr=0.1, recon_t=400,recon_end_t=0,
+ inversion_guidance=False, reconstruction_guidance=False,x_stars=None, i=0,
+ use_localblend_mask=False,
+ save_heatmap=False,**kwargs):
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ #noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+ if prox is None and inversion_guidance is True:
+ prox = 'l1'
+ step_kwargs = {
+ 'ref_image': None,
+ 'recon_lr': 0,
+ 'recon_mask': None,
+ }
+ mask_edit = None
+ if prox is not None:
+ if prox == 'l1':
+ score_delta = (noise_pred_text - noise_pred_uncond).float()
+ if quantile > 0:
+ threshold = score_delta.abs().quantile(quantile)
+ else:
+ threshold = -quantile # if quantile is negative, use it as a fixed threshold
+ score_delta -= score_delta.clamp(-threshold, threshold)
+ score_delta = torch.where(score_delta > 0, score_delta-threshold, score_delta)
+ score_delta = torch.where(score_delta < 0, score_delta+threshold, score_delta)
+ if (recon_t > 0 and t < recon_t) or (recon_t < 0 and t > -recon_t):
+ step_kwargs['ref_image'] = image_enc
+ step_kwargs['recon_lr'] = recon_lr
+ score_delta_norm=score_delta.abs()
+                    score_delta_norm = (score_delta_norm - score_delta_norm.min()) / (score_delta_norm.max() - score_delta_norm.min())
+ mask_edit = (score_delta.abs() > threshold).float()
+ if save_heatmap and i%10==0:
+ for kk in range(4):
+ sns.heatmap(mask_edit[1][kk].clone().cpu(), cmap='coolwarm')
+ plt.savefig(f'./vis/prox_inv/heatmap1_mask_{i}_{kk}.png')
+ plt.clf()
+ if kwargs.get('dilate_mask', 2) > 0:
+ radius = int(kwargs.get('dilate_mask', 2))
+ mask_edit = dilate(mask_edit.float(), kernel_size=2*radius+1, padding=radius)
+ if save_heatmap and i%10==0:
+ for kk in range(4):
+ sns.heatmap(mask_edit[1][kk].clone().cpu(), cmap='coolwarm')
+ plt.savefig(f'./vis/prox_inv/heatmap1_mask_dilate_{i}_{kk}.png')
+ plt.clf()
+ step_kwargs['recon_mask'] = 1 - mask_edit
+ elif prox == 'l0':
+ score_delta = (noise_pred_text - noise_pred_uncond).float()
+ if quantile > 0:
+ threshold = score_delta.abs().quantile(quantile)
+ else:
+ threshold = -quantile # if quantile is negative, use it as a fixed threshold
+ score_delta -= score_delta.clamp(-threshold, threshold)
+ if (recon_t > 0 and t < recon_t) or (recon_t < 0 and t > -recon_t):
+ step_kwargs['ref_image'] = image_enc
+ step_kwargs['recon_lr'] = recon_lr
+ mask_edit = (score_delta.abs() > threshold).float()
+ if kwargs.get('dilate_mask', 2) > 0:
+ radius = int(kwargs.get('dilate_mask', 2))
+ mask_edit = dilate(mask_edit.float(), kernel_size=2*radius+1, padding=radius)
+ step_kwargs['recon_mask'] = 1 - mask_edit
+ else:
+ raise NotImplementedError
+ noise_pred = (noise_pred_uncond + guidance_scale * score_delta).to(model.unet.dtype)
+ else:
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
+ if reconstruction_guidance:
+ kwargs.update(step_kwargs)
+ latents = model.scheduler.step(noise_pred, t, latents, **kwargs, return_dict=False)[0]
+    if mask_edit is not None and inversion_guidance and ((recon_t > recon_end_t and t < recon_t) or (recon_t < recon_end_t and t > -recon_t)):
+ if use_localblend_mask:
+ assert hasattr(controller,"layer_fusion")
+ if save_heatmap and i%10==0:
+ sns.heatmap(controller.layer_fusion.mask[0][0].clone().cpu(), cmap='coolwarm')
+ plt.savefig(f'./vis/prox_inv/heatmap0_localblendmask_{i}.png')
+ plt.clf()
+ sns.heatmap(controller.layer_fusion.mask[1][0].clone().cpu(), cmap='coolwarm')
+ plt.savefig(f'./vis/prox_inv/heatmap1_localblendmask_{i}.png')
+ plt.clf()
+ layer_fusion_mask=controller.layer_fusion.mask.float()
+ layer_fusion_mask[0]=layer_fusion_mask[1]
+ recon_mask=1-layer_fusion_mask.expand_as(latents)
+ else:
+ recon_mask = 1 - mask_edit
+ target_latents=x_stars[len(x_stars)-i-2].expand_as(latents)
+        # if target_latents is 4-D (has a batch dimension), take the first sample
+ if len(target_latents.shape)==4:
+ target_latents=target_latents[0]
+ latents = latents - recon_lr * (latents - target_latents) * recon_mask
+ # controller
+ if controller is not None:
+ latents = controller.step_callback(latents)
+ return latents.to(model.unet.dtype)
diff --git a/utils/utils.py b/utils/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd71b6a15f9b9376b0ef161d40d00a78df790726
--- /dev/null
+++ b/utils/utils.py
@@ -0,0 +1,363 @@
+import cv2
+from matplotlib import pyplot as plt
+import numpy as np
+import torch
+from PIL import Image, ImageDraw, ImageFont
+from datetime import datetime
+import os
+from typing import List, Dict
+
+def convert_and_resize_mask(mask):
+ if mask.ndim == 3:
+ mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
+ resized_mask = cv2.resize(mask, (1024, 1024))
+ return resized_mask
+
+def add_masks_resized(masks):
+ final_mask = np.zeros((1024, 1024), dtype=np.uint8)
+ for mask in masks:
+ if mask is not None:
+ resized_mask = convert_and_resize_mask(mask)
+ resized_mask = resized_mask.astype(np.uint8)
+ final_mask = cv2.add(final_mask, resized_mask)
+ return final_mask
+
+def attend_mask(mask_file, attend_scale=10, save=False):
+ if isinstance(mask_file, str):
+ if mask_file == '':
+ return torch.zeros([1, 1, 128, 128], dtype=torch.float32).cuda()
+ else:
+ image_with_mask = cv2.imread(mask_file, cv2.IMREAD_GRAYSCALE)
+ elif len(mask_file.shape) == 3: # convert a BGR array to grayscale
+ image_with_mask = cv2.cvtColor(mask_file, cv2.COLOR_BGR2GRAY)
+ else:
+ image_with_mask = mask_file
+
+ if attend_scale != 0:
+ kernel = np.ones((abs(attend_scale), abs(attend_scale)), np.uint8)
+ if attend_scale > 0:
+ image_with_mask = cv2.dilate(image_with_mask, kernel, iterations=1)
+ else:
+ image_with_mask = cv2.erode(image_with_mask, kernel, iterations=1)
+
+ if save and isinstance(mask_file, str):
+ new_mask_file_name = mask_file[:-4]+'_'+str(attend_scale)+'.jpg'
+ cv2.imwrite(new_mask_file_name, image_with_mask)
+ print("new_mask is saved in ", new_mask_file_name)
+
+ dilated_image = cv2.resize(image_with_mask, (128, 128), interpolation=cv2.INTER_NEAREST)
+ dilated_image = torch.from_numpy(dilated_image).to(torch.float32).unsqueeze(0).unsqueeze(0).cuda() / 255
+ return dilated_image
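+
+# Illustrative usage sketch: the mask path below is a hypothetical placeholder.
+# attend_mask dilates (positive attend_scale) or erodes (negative attend_scale)
+# the mask, then downsamples it to the 128x128 latent grid as a [1, 1, 128, 128]
+# CUDA float tensor with values in [0, 1].
+def _attend_mask_example(mask_path='examples/mask/example_mask.jpg'):
+ dilated = attend_mask(mask_path, attend_scale=10)
+ eroded = attend_mask(mask_path, attend_scale=-10)
+ return dilated.shape, eroded.shape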
+
+
+def panning(img_path=None, op_list=[['left', 0.2]], save=False, save_dir=None):
+ if isinstance(img_path, str):
+ img = cv2.imread(img_path)
+ else:
+ img = img_path
+ img_new = img.copy()
+ img_height, img_width, _ = img.shape
+ w_mask = 255 * np.ones((img_height, img_width), dtype=np.uint8)
+ h_mask = 255 * np.ones((img_height, img_width), dtype=np.uint8)
+
+ for op in op_list:
+ scale = op[1]
+ if op[0] in ['right', 'left']:
+ K = int(scale*img_width)
+ elif op[0] in ['up', 'down']:
+ K = int(scale*img_height)
+
+ if op[0] == 'right':
+ img_new[:, K:, :] = img[:, 0:img_width-K, :]
+ w_mask[:, K:] = 0
+ elif op[0] == 'left':
+ img_new[:, 0:img_width-K, :] = img[:, K:, :]
+ w_mask[:, 0:img_width-K] = 0
+ elif op[0] == 'down':
+ img_new[K:, :, :] = img[0:img_height-K, :, :]
+ h_mask[K:, :] = 0
+ elif op[0] == 'up':
+ img_new[0:img_height-K, :, :] = img[K:, :, :]
+ h_mask[0:img_height-K, :] = 0
+ img = img_new
+
+ mask = w_mask + h_mask
+ mask[mask>0] = 255
+
+ if save:
+ if save_dir is None:
+ base_dir = os.path.dirname(img_path)
+ save_dir = os.path.join(base_dir, 'preprocess')
+ if not os.path.exists(save_dir):
+ os.makedirs(save_dir)
+ resized_img_name = f"{save_dir}/resized_image.png"
+ resized_mask_name = f"{save_dir}/resized_mask.png"
+ cv2.imwrite(resized_img_name, img_new)
+ cv2.imwrite(resized_mask_name, mask)
+ return resized_img_name, resized_mask_name
+ else:
+ return img_new, mask
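+
+# Illustrative usage sketch (the input path is hypothetical): shift the content
+# 20% to the left and 10% upward; the returned mask is white (255) over the
+# newly exposed border that the model is expected to fill in.
+def _panning_example(img_path='examples/pan/input.jpg'):
+ img_new, mask = panning(img_path, op_list=[['left', 0.2], ['up', 0.1]], save=False)
+ return img_new, mask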
+
+def zooming(img_path=None, scale=[0.8, 0.8], save=False, save_dir=None):
+ if isinstance(img_path, str):
+ img = cv2.imread(img_path)
+ else:
+ img = img_path
+ img_new = img.copy()
+ img_height, img_width, _ = img.shape
+ mask = 255 * np.ones((img_height, img_width), dtype=np.uint8)
+
+ new_height = int(img_height*scale[0])
+ new_width = int(img_width*scale[1])
+ resized_img = cv2.resize(img, (new_width, new_height), interpolation=cv2.INTER_AREA)
+ x_offset = (img_width - new_width) // 2
+ y_offset = (img_height - new_height) // 2
+
+ img_new[y_offset:y_offset + new_height, x_offset:x_offset + new_width] = resized_img
+ mask[y_offset:y_offset + new_height, x_offset:x_offset + new_width] = 0
+
+ if save:
+ if save_dir is None:
+ base_dir = os.path.dirname(img_path)
+ save_dir = os.path.join(base_dir, 'preprocess')
+ if not os.path.exists(save_dir):
+ os.makedirs(save_dir)
+
+ resized_img_name = f"{save_dir}/resized_image.png"
+ resized_mask_name = f"{save_dir}/resized_mask.png"
+ cv2.imwrite(resized_img_name, img_new)
+ cv2.imwrite(resized_mask_name, mask)
+ return resized_img_name, resized_mask_name
+ else:
+ return img_new, mask
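+
+# Illustrative usage sketch (the input path is hypothetical): shrink the image
+# to 80% of its height and width, paste it centered on the original canvas, and
+# return a mask that is white on the uncovered border to be outpainted.
+def _zooming_example(img_path='examples/zoom/input.jpg'):
+ img_new, mask = zooming(img_path, scale=[0.8, 0.8], save=False)
+ return img_new, mask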
+
+def get_box(mask, bias = 2):
+ nonzero_indices = torch.nonzero(mask)
+ H, W = mask.shape[-2:]
+ min_x = max(min(nonzero_indices[:, 1]) - bias, 0)
+ min_y = max(min(nonzero_indices[:, 0]) - bias, 0)
+ max_x = min(max(nonzero_indices[:, 1]) + bias, W)
+ max_y = min(max(nonzero_indices[:, 0]) + bias, H)
+ return (min_x, min_y, max_x, max_y)
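+
+# Illustrative self-check: get_box expects a 2-D torch mask and returns an
+# (x_min, y_min, x_max, y_max) box padded by `bias` pixels and clipped to the
+# mask bounds.
+def _get_box_example():
+ mask = torch.zeros(128, 128)
+ mask[40:60, 30:50] = 1
+ return get_box(mask, bias=2) # (28, 38, 51, 61)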
+
+
+def draw_axis(img,grid_dict,x_len,y_len):
+ if grid_dict is not None and grid_dict is not False:
+ assert isinstance(grid_dict,Dict)
+ assert "x_title" in grid_dict
+ assert "y_title" in grid_dict
+ assert "x_text_list" in grid_dict
+ assert "y_text_list" in grid_dict
+ x_title=grid_dict["x_title"]
+ y_title=grid_dict["y_title"]
+ x_text_list=grid_dict['x_text_list']
+ y_text_list=grid_dict['y_text_list']
+ assert len(y_text_list)==y_len
+ assert len(x_text_list)==x_len
+ assert "font_size" in grid_dict
+ font_size=grid_dict["font_size"]
+ if "x_color" in grid_dict:
+ color_x=grid_dict['x_color']
+ else:
+ color_x="black"
+ if "y_color" in grid_dict:
+ color_y=grid_dict['y_color']
+ else:
+ color_y="black"
+ if "num_decimals" in grid_dict:
+ num_decimals=grid_dict['num_decimals']
+ else:
+ num_decimals=2
+ if "shift_x" in grid_dict:
+ shift_x_x,shift_x_y=grid_dict['shift_x']
+ else:
+ shift_x_x=shift_x_y=0
+ if "shift_y" in grid_dict:
+ shift_y_x,shift_y_y=grid_dict['shift_y']
+ else:
+ shift_y_x=shift_y_y=0
+ if "title" in grid_dict:
+ title=grid_dict['title']
+ if isinstance(title,List):
+ all_title=""
+ for s in title:
+ all_title=all_title+s+"\n"
+ title=all_title
+ else:
+ title=''
+ width, height = img.size
+ num_x=x_len
+ num_y=y_len
+
+ new_img = Image.new("RGB", (width + width // num_x+width // (num_x*2), height + height // num_y+height // (num_y*2)), color=(255, 255, 255))
+ width,height=(width + width // num_x, height + height // num_y)
+ num_x=num_x+1
+ num_y=num_y+1
+ new_img.paste(img, (width // num_x, height // num_y))
+
+ draw = ImageDraw.Draw(new_img)
+
+ font = ImageFont.truetype("DejaVuSansMono.ttf", font_size)
+ for i in range(2, num_x+1):
+ x = (i - 1) * width // num_x + width // (num_x * 2)-width *0.2// num_x+shift_x_x
+ y = height // (num_y * 2)+shift_x_y
+ k=i-1
+ if isinstance(x_text_list[i-2],str):
+ draw.text((x, y), x_text_list[i-2], font=font,fill=color_x,align="center")
+ else:
+ draw.text((x, y), "{:.{}f}".format(x_text_list[i-2],num_decimals), font=font,fill=color_x,align="center")
+
+ for i in range(2, num_y+1):
+ x = width // (num_x * 2)-width *0.1// num_x+shift_y_x
+ y = (i - 1) * height // num_y + height // (num_y * 2)-height*0.1//num_y+shift_y_y
+ k = i - 1
+ if isinstance(y_text_list[i-2],str):
+ draw.text((x, y), y_text_list[i-2], font=font,fill=color_y,align="center")
+ else:
+ draw.text((x, y), "{:.{}f}".format(y_text_list[i-2],num_decimals), font=font,fill=color_y,align="center")
+ i=1
+ x = (i - 1) * width // num_x + width // (num_x * 2)-height*0.1//num_y+shift_y_x
+ y = height // (num_y * 2)+width *0.2// num_x+shift_y_y
+ draw.text((x, y), y_title, font=font, fill=color_y,align="center")
+ x = width // (num_x * 2)+width *0.2// num_x+shift_x_x
+ y = (i - 1) * height // num_y + height // (num_y * 2)+shift_x_y
+ draw.text((x, y), x_title, font=font, fill=color_x,align="left")
+ x = width // 4
+ y = (i - 1) * height // num_y + height // (num_y * 10)
+ draw.text((x, y), title, font=font, fill='blue',align="left")
+ else:
+ new_img = img
+ return new_img
+
+def view_images(images, num_rows=1, offset_ratio=0.02, text="", folder=None, Notimestamp=False,
+ grid_dict=None, subfolder=None, verbose=True, output_dir=None, timestamp=None, **kwargs):
+ if type(images) is list:
+ num_empty = (num_rows - len(images) % num_rows) % num_rows
+ elif images.ndim == 4:
+ num_empty = (num_rows - images.shape[0] % num_rows) % num_rows
+ else:
+ images = [images]
+ num_empty = 0
+ origin_size=kwargs.get("origin_size",None)
+ images_copy=images.copy()
+ for i, per_image in enumerate(images_copy):
+ if isinstance(per_image, Image.Image) and origin_size is not None:
+ images[i] = np.array(per_image.resize((origin_size[1],origin_size[0])))
+ else:
+ images[i] = np.array(per_image)
+
+ empty_images = np.ones(images[0].shape, dtype=np.uint8) * 255
+ images = [image.astype(np.uint8) for image in images] + [empty_images] * num_empty
+ num_items = len(images)
+
+ h, w, c = images[0].shape
+ offset = int(h * offset_ratio)
+ num_cols = num_items // num_rows
+ image_ = np.ones((h * num_rows + offset * (num_rows - 1),
+ w * num_cols + offset * (num_cols - 1), 3), dtype=np.uint8) * 255
+ for i in range(num_rows):
+ for j in range(num_cols):
+ image_[i * (h + offset): i * (h + offset) + h:, j * (w + offset): j * (w + offset) + w] = images[
+ i * num_cols + j]
+
+ pil_img = Image.fromarray(image_)
+
+ pil_img_=draw_axis(pil_img,grid_dict,num_cols,num_rows)
+ if pil_img_.size[0] == pil_img_.size[1]:
+ pil_img_ = pil_img_.resize((2048, 2048))
+ else:
+ longer_side = max(pil_img.size)
+ ratio = 2048 / longer_side
+ new_size = tuple([int(x * ratio) for x in pil_img.size])
+ pil_img = pil_img.resize(new_size)
+
+ if verbose is False:
+ return pil_img
+ now = datetime.now()
+ if timestamp is None:
+ if Notimestamp is False:
+ timestamp = now.strftime("%Y-%m-%d_%H-%M-%S")
+ else:
+ timestamp=""
+ if output_dir is None:
+ if timestamp != "":
+ date, time = timestamp.split('_')
+ else:
+ date, time = "", ""
+ if folder is not None:
+ dirname = "./" + folder
+ filename = text + f"img_{timestamp}.jpg"
+ else:
+ if subfolder is not None:
+ dirname = os.path.join("./img", subfolder, date)
+ dirname = os.path.join(dirname, time)
+ filename = text + f"img_{timestamp}.jpg"
+ else:
+ dirname = os.path.join("./img", date)
+ dirname = os.path.join(dirname, time)
+ filename = text + f"img_{timestamp}.jpg"
+ else:
+ dirname = output_dir
+ filename = text + f"img_{timestamp}.jpg"
+ if not os.path.exists(dirname):
+ os.makedirs(dirname)
+ if verbose is True:
+ for i, img in enumerate(images):
+ im = Image.fromarray(img)
+ im.save(os.path.join(dirname,f"{i}.jpg"))
+ print(f"Output dir: {dirname}")
+ pil_img.save(os.path.join(dirname, filename))
+ if grid_dict is not None and grid_dict is not False:
+ if not os.path.exists(dirname):
+ os.makedirs(dirname)
+ pil_img_.save(os.path.join(dirname, filename[:-4]+"_2048x.jpg"))
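+
+# Illustrative usage sketch: arrange a list of equally sized HxWx3 uint8 arrays
+# into a single grid image. With verbose=False the composed PIL image is
+# returned directly instead of being written under ./img/<date>/<time>/.
+def _view_images_example(images):
+ grid = view_images(images, num_rows=1, verbose=False)
+ return grid # PIL.Image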
+
+def resize_image_with_mask(img, mask, scale):
+ if scale == 1:
+ return img, mask, None
+ img_blackboard = img.copy() # canvas
+ mask_blackboard = np.zeros_like(mask)
+
+ M = cv2.moments(mask)
+ if M["m00"] == 0: # empty mask: nothing to resize
+ return img, mask, None
+ cx = int(M["m10"] / M["m00"])
+ cy = int(M["m01"] / M["m00"])
+
+ scale_factor = [scale, scale]
+ resized_img = cv2.resize(img, None, fx=scale_factor[0], fy=scale_factor[1], interpolation=cv2.INTER_AREA)
+ resized_mask = cv2.resize(mask, None, fx=scale_factor[0], fy=scale_factor[1], interpolation=cv2.INTER_AREA)
+ new_cx, new_cy = cx * scale_factor[0], cy * scale_factor[1]
+
+ for y in range(resized_mask.shape[0]):
+ for x in range(resized_mask.shape[1]):
+ if 0 <= cy - (new_cy - y) < img.shape[0] and 0 <= cx - (new_cx - x) < img.shape[1]:
+ mask_blackboard[int(cy - (new_cy - y)), int(cx - (new_cx - x))] = resized_mask[y, x]
+ img_blackboard[int(cy - (new_cy - y)), int(cx - (new_cx - x))] = resized_img[y, x]
+ return img_blackboard, mask_blackboard, (cx, cy)
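+
+# Illustrative usage sketch: `img` is an HxWx3 uint8 array and `mask` an HxW
+# uint8 array (255 = object). The object is rescaled to 70% of its size and
+# re-anchored at the mask centroid on a copy of the original image.
+def _resize_object_example(img, mask):
+ new_img, new_mask, center = resize_image_with_mask(img, mask, scale=0.7)
+ return new_img, new_mask, center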
+
+def flip_image_with_mask(img, mask, flip_code=None):
+ if flip_code is None:
+ return img, mask, None
+ M = cv2.moments(mask)
+ if M["m00"] == 0:
+ return img, mask
+ cx = int(M["m10"] / M["m00"])
+ cy = int(M["m01"] / M["m00"])
+
+ h, w = img.shape[:2]
+ img_center = (w // 2, h // 2)
+
+ tx = img_center[0] - cx
+ ty = img_center[1] - cy
+
+ M_translate = np.float32([[1, 0, tx], [0, 1, ty]])
+ img_translated = cv2.warpAffine(img, M_translate, (w, h))
+ mask_translated = cv2.warpAffine(mask, M_translate, (w, h))
+ flipped_img = cv2.flip(img_translated, flip_code)
+ flipped_mask = cv2.flip(mask_translated, flip_code)
+ M_translate_back = np.float32([[1, 0, -tx], [0, 1, -ty]])
+ flipped_img_back = cv2.warpAffine(flipped_img, M_translate_back, (w, h))
+ flipped_mask_back = cv2.warpAffine(flipped_mask, M_translate_back, (w, h))
+
+ return flipped_img_back, flipped_mask_back, (cx, cy)
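+
+# Illustrative usage sketch: translate the image so the mask centroid sits at
+# the image center, flip with cv2's convention (1 = horizontal, 0 = vertical,
+# -1 = both), then translate back so the object is mirrored about its centroid.
+def _flip_object_example(img, mask):
+ flipped_img, flipped_mask, center = flip_image_with_mask(img, mask, flip_code=1)
+ return flipped_img, flipped_mask, center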