Spaces:
Runtime error
anonymous
committed on
Commit • 19d207d
1 Parent(s): ce9d640
update
- app.py +197 -0
- requirements.txt +12 -0
app.py
ADDED
@@ -0,0 +1,197 @@
from argparse import Namespace
from glob import glob
import yaml
import os

import gradio as gr
import torch
import torchvision
import safetensors
from diffusers import AutoencoderKL
from peft import get_peft_model, LoraConfig, set_peft_model_state_dict
from huggingface_hub import snapshot_download

# Download the fine-tuned LoRA weights and the training hyper-parameters from the Hub.
pretrained_model_path = snapshot_download(repo_id="revp2024/revp-censorship")
with open(glob(os.path.join(pretrained_model_path, 'hparams.yml'), recursive=True)[0]) as f:
    args = Namespace(**yaml.safe_load(f))


def prepare_model():
    """Wrap the pretrained Stable Diffusion VAE with LoRA adapters and load the fine-tuned weights."""
    print('Loading model ...')
    vae_lora_config = LoraConfig(
        r=args.rank,
        lora_alpha=args.rank,
        init_lora_weights="gaussian",
        target_modules=["conv", "conv1", "conv2",
                        "to_q", "to_k", "to_v", "to_out.0"],
    )
    vae = AutoencoderKL.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="vae"
    )
    vae = get_peft_model(vae, vae_lora_config)
    lora_weights_path = os.path.join(pretrained_model_path, "pytorch_lora_weights.safetensors")
    state_dict = {}
    with safetensors.torch.safe_open(lora_weights_path, framework="pt", device="cpu") as f:
        for key in f.keys():
            state_dict[key] = f.get_tensor(key)

    set_peft_model_state_dict(vae, state_dict)

    print('Done.')
    return vae.to('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')


@torch.no_grad()
def add_censorship(input_image, mode, pixelation_block_size, blur_kernel_size, soft_edges, soft_edge_kernel_size):
    # The ImageEditor value is a dict of {background, layers, composite};
    # the alpha channel of the first layer is the user-drawn censoring mask.
    background, layers, _ = input_image.values()
    input_images = torch.from_numpy(background).permute(2, 0, 1)[None, :3] / 255
    mask = torch.from_numpy(layers[0]).permute(2, 0, 1)[None, -1:] / 255

    # Resize so the longer side is at most 1024 px and both sides are multiples of 8 (VAE requirement).
    H, W = input_images.shape[-2:]
    if H > 1024 or W > 1024:
        H_t, W_t = H, W
        if H > W:
            H, W = 1024, int(1024 * W_t / H_t)
        else:
            H, W = int(1024 * H_t / W_t), 1024
    H_q8 = (H // 8) * 8
    W_q8 = (W // 8) * 8
    input_images = torch.nn.functional.interpolate(input_images, (H_q8, W_q8), mode='bilinear')
    mask = torch.nn.functional.interpolate(mask, (H_q8, W_q8))
    if soft_edges:
        mask = torchvision.transforms.functional.gaussian_blur(mask, soft_edge_kernel_size)[0][0]

    input_images = input_images.to(vae.device)

    # Build the censored version of the masked region.
    if mode == 'Pixelation':
        censored = torch.nn.functional.avg_pool2d(
            input_images, pixelation_block_size)
        censored = torch.nn.functional.interpolate(censored, input_images.shape[-2:])
    elif mode == 'Gaussian blur':
        censored = torchvision.transforms.functional.gaussian_blur(
            input_images, blur_kernel_size)
    elif mode == 'Black':
        censored = torch.zeros_like(input_images)
    else:
        raise ValueError("mode has to be `Pixelation', `Gaussian blur', or `Black'")

    mask = mask.to(input_images.device)
    censored_images = input_images * (1 - mask) + censored * mask
    censored_images *= 255

    # VAE round trip of the original image (adapter disabled for encoding).
    input_images = input_images * 2 - 1
    with vae.disable_adapter():
        latents = vae.encode(input_images).latent_dist.mean
    images = vae.decode(latents, return_dict=False)[0]

    # denormalize
    images = images / 2 + 0.5
    images *= 255

    # Keep only a small (+/- args.budget) residual of the decoded image on top of the censored image.
    residuals = (images - censored_images).clamp(-args.budget, args.budget)
    images = (censored_images + residuals).clamp(0, 255).to(torch.uint8)

    gr.Info("Try to download/copy the censored image to the `Remove censorship' tab")
    return images[0].permute(1, 2, 0).cpu().numpy()


@torch.no_grad()
def remove_censorship(input_image, x1, y1, x2, y2):
    background, layers, _ = input_image.values()
    images = torch.from_numpy(background).permute(2, 0, 1)[None, :3] / 255
    mask = torch.from_numpy(layers[0]).permute(2, 0, 1)[None, -1:] / 255
    images = images * (1 - mask)
    # gr.Number returns floats, so cast the crop coordinates before slicing.
    images = images[..., int(y1):int(y2), int(x1):int(x2)]
    # Encode with the LoRA adapter enabled, decode with the base VAE.
    latents = vae.encode((images * 2 - 1).to(vae.device)).latent_dist.mean
    with vae.disable_adapter():
        images = vae.decode(latents, return_dict=False)[0]
    # denormalize
    images = images / 2 + 0.5
    images *= 255
    images = images.clamp(0, 255).to(torch.uint8)
    return images[0].permute(1, 2, 0).cpu().numpy()


# @@@@@@@ Start of the program @@@@@@@@

vae = prepare_model()

css = '''
.my-disabled {
    background-color: #eee;
}
.my-disabled input {
    background-color: #eee;
}
'''
with gr.Blocks(css=css) as demo:
    gr.Markdown('# ReVP: Reversible Visual Processing with Latent Models')
    with gr.Tab('Add censorship'):
        with gr.Row():
            with gr.Column():
                input_image = gr.ImageEditor(brush=gr.Brush(default_size=100))
                with gr.Accordion('Options', open=False) as options_accord:
                    mode = gr.Radio(label='Mode', choices=['Pixelation', 'Gaussian blur', 'Black'],
                                    value='Pixelation', interactive=True)
                    pixelation_block_size = gr.Slider(label='Block size', minimum=10, maximum=40, value=25, step=1, interactive=True)
                    blur_kernel_size = gr.Slider(label='Blur kernel size', minimum=21, maximum=151, value=85, step=2, interactive=True, visible=False)

                    # Show only the slider that belongs to the selected censoring mode.
                    def change_mode(mode):
                        if mode == 'Gaussian blur':
                            return gr.Slider(visible=False), gr.Slider(visible=True), gr.Accordion(open=True)
                        elif mode == 'Pixelation':
                            return gr.Slider(visible=True), gr.Slider(visible=False), gr.Accordion(open=True)
                        elif mode == 'Black':
                            return gr.Slider(visible=False), gr.Slider(visible=False), gr.Accordion(open=True)
                        else:
                            raise NotImplementedError
                    mode.select(change_mode, mode, [pixelation_block_size, blur_kernel_size, options_accord])

                    with gr.Row(variant='panel'):
                        soft_edges = gr.Checkbox(label='Soft edges', value=True, interactive=True, scale=1)
                        soft_edge_kernel_size = gr.Slider(label='Soft edge kernel size', minimum=21, maximum=49, value=35, step=2, interactive=True, visible=True, scale=2)

                    def change_soft_edges(soft_edges):
                        return gr.Slider(visible=True if soft_edges else False), gr.Accordion(open=True)
                    soft_edges.change(change_soft_edges, soft_edges, [soft_edge_kernel_size, options_accord])
                submit_btn = gr.Button('Submit')
            output_image = gr.Image(label='Censored', show_download_button=True)

        submit_btn.click(
            fn=add_censorship,
            inputs=[input_image, mode, pixelation_block_size, blur_kernel_size, soft_edges, soft_edge_kernel_size],
            outputs=output_image
        )

    with gr.Tab('Remove censorship'):
        with gr.Row():
            with gr.Column():
                input_image = gr.ImageEditor()
                with gr.Accordion('Manual cropping', open=False):
                    with gr.Row():
                        with gr.Row():
                            x1 = gr.Number(value=0, label='x1')
                            y1 = gr.Number(value=0, label='y1')
                        with gr.Row():
                            x2_ = gr.Number(value=10000, label='x2', interactive=False, elem_classes='my-disabled')
                            y1_ = gr.Number(value=0, label='y1', interactive=False, elem_classes='my-disabled')
                    with gr.Row():
                        with gr.Row():
                            x1_ = gr.Number(value=0, label='x1', elem_classes='my-disabled')
                            y2_ = gr.Number(value=10000, label='y2', elem_classes='my-disabled')
                        with gr.Row():
                            x2 = gr.Number(value=10000, label='x2')
                            y2 = gr.Number(value=10000, label='y2')
                submit_btn = gr.Button('Submit')
            output_image = gr.Image(label='Uncensored')

        submit_btn.click(
            fn=remove_censorship,
            inputs=[input_image, x1, y1, x2, y2],
            outputs=output_image
        )

        # sync the read-only coordinate fields when the editable ones change
        x1.change(lambda x: x, x1, x1_)
        x2.change(lambda x: x, x2, x2_)
        y1.change(lambda x: x, y1, y1_)
        y2.change(lambda x: x, y2, y2_)

if __name__ == '__main__':
    demo.queue(4)
    demo.launch()
requirements.txt
ADDED
@@ -0,0 +1,12 @@
torch
torchvision
huggingface_hub
accelerate
transformers
datasets
diffusers
peft
safetensors
gradio==4.26.0
pyyaml
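For reference, a minimal driver sketch (hypothetical, not part of this commit) that exercises the two functions from app.py without the Gradio UI. It assumes the packages above are installed, app.py is importable from the working directory, and the revp2024/revp-censorship weights can be downloaded; importing app loads the model and builds the Blocks interface but does not launch the server. All variable names below are illustrative only.

# Hypothetical local driver for app.py (not included in the commit).
import numpy as np
import app  # downloads the LoRA weights and builds the demo on import

h, w = 512, 512
rng = np.random.default_rng(0)
background = rng.integers(0, 256, size=(h, w, 4), dtype=np.uint8)  # stand-in RGBA "photo"
background[..., 3] = 255

layer = np.zeros((h, w, 4), dtype=np.uint8)
layer[192:320, 192:320, 3] = 255  # alpha channel marks the region to censor

# gr.ImageEditor passes a dict with background, layers, and composite entries.
editor_value = {"background": background, "layers": [layer], "composite": background}

censored = app.add_censorship(editor_value, "Pixelation",
                              pixelation_block_size=25, blur_kernel_size=85,
                              soft_edges=True, soft_edge_kernel_size=35)

# Feed the censored image back in; an all-zero layer means nothing is masked out.
h2, w2 = censored.shape[:2]
censored_rgba = np.dstack([censored, np.full((h2, w2), 255, dtype=np.uint8)])
empty_layer = np.zeros((h2, w2, 4), dtype=np.uint8)
restored = app.remove_censorship(
    {"background": censored_rgba, "layers": [empty_layer], "composite": censored_rgba},
    x1=0, y1=0, x2=w2, y2=h2)
print(censored.shape, restored.shape)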