anonymous committed on
Commit
19d207d
1 Parent(s): ce9d640
Files changed (2)
  1. app.py +197 -0
  2. requirements.txt +12 -0
app.py ADDED
@@ -0,0 +1,197 @@
+ from argparse import Namespace
+ from glob import glob
+ import yaml
+ import os
+
+ import gradio as gr
+ import torch
+ import torchvision
+ import safetensors.torch
+ from diffusers import AutoencoderKL
+ from peft import get_peft_model, LoraConfig, set_peft_model_state_dict
+ from huggingface_hub import snapshot_download
+
+ pretrained_model_path = snapshot_download(repo_id="revp2024/revp-censorship")
+ with open(glob(os.path.join(pretrained_model_path, 'hparams.yml'), recursive=True)[0]) as f:
+     args = Namespace(**yaml.safe_load(f))
+
+ def prepare_model():
+     print('Loading model ...')
+     vae_lora_config = LoraConfig(
+         r=args.rank,
+         lora_alpha=args.rank,
+         init_lora_weights="gaussian",
+         target_modules=["conv", "conv1", "conv2",
+                         "to_q", "to_k", "to_v", "to_out.0"],
+     )
+     vae = AutoencoderKL.from_pretrained(
+         args.pretrained_model_name_or_path, subfolder="vae"
+     )
+     vae = get_peft_model(vae, vae_lora_config)
+     lora_weights_path = os.path.join(pretrained_model_path, "pytorch_lora_weights.safetensors")
+     state_dict = {}
+     with safetensors.torch.safe_open(lora_weights_path, framework="pt", device="cpu") as f:
+         for key in f.keys():
+             state_dict[key] = f.get_tensor(key)
+
+     set_peft_model_state_dict(vae, state_dict)
+
+     print('Done.')
+     return vae.to('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
+
+
+ @torch.no_grad()
+ def add_censorship(input_image, mode, pixelation_block_size, blur_kernel_size, soft_edges, soft_edge_kernel_size):
+     background, layers, _ = input_image.values()
+     input_images = torch.from_numpy(background).permute(2, 0, 1)[None, :3] / 255
+     mask = torch.from_numpy(layers[0]).permute(2, 0, 1)[None, -1:] / 255
+
+     H, W = input_images.shape[-2:]
+     if H > 1024 or W > 1024:
+         H_t, W_t = H, W
+         if H > W:
+             H, W = 1024, int(1024 * W_t / H_t)
+         else:
+             H, W = int(1024 * H_t / W_t), 1024
+     H_q8 = (H // 8) * 8
+     W_q8 = (W // 8) * 8
+     input_images = torch.nn.functional.interpolate(input_images, (H_q8, W_q8), mode='bilinear')
+     mask = torch.nn.functional.interpolate(mask, (H_q8, W_q8))
+     if soft_edges:
+         mask = torchvision.transforms.functional.gaussian_blur(mask, soft_edge_kernel_size)[0][0]
+
+     input_images = input_images.to(vae.device)
+
+     if mode == 'Pixelation':
+         censored = torch.nn.functional.avg_pool2d(
+             input_images, pixelation_block_size)
+         censored = torch.nn.functional.interpolate(censored, input_images.shape[-2:])
+     elif mode == 'Gaussian blur':
+         censored = torchvision.transforms.functional.gaussian_blur(
+             input_images, blur_kernel_size)
+     elif mode == 'Black':
+         censored = torch.zeros_like(input_images)
+     else:
+         raise ValueError("mode has to be one of `Pixelation', `Gaussian blur', or `Black'")
+
+     mask = mask.to(input_images.device)
+     censored_images = input_images * (1 - mask) + censored * mask
+     censored_images *= 255
+
+     input_images = input_images * 2 - 1
+     with vae.disable_adapter():
+         latents = vae.encode(input_images).latent_dist.mean
+         images = vae.decode(latents, return_dict=False)[0]
+
+     # denormalize
+     images = images / 2 + 0.5
+     images *= 255
+
+     residuals = (images - censored_images).clamp(-args.budget, args.budget)
+     images = (censored_images + residuals).clamp(0, 255).to(torch.uint8)
+
+     gr.Info("Download or copy the censored image to the `Remove censorship' tab")
+     return images[0].permute(1, 2, 0).cpu().numpy()
+
+ @torch.no_grad()
+ def remove_censorship(input_image, x1, y1, x2, y2):
+     background, layers, _ = input_image.values()
+     images = torch.from_numpy(background).permute(2, 0, 1)[None, :3] / 255
+     mask = torch.from_numpy(layers[0]).permute(2, 0, 1)[None, -1:] / 255
+     images = images * (1 - mask)
+     images = images[..., int(y1):int(y2), int(x1):int(x2)]
+     latents = vae.encode((images * 2 - 1).to(vae.device)).latent_dist.mean
+     with vae.disable_adapter():
+         images = vae.decode(latents, return_dict=False)[0]
+     # denormalize
+     images = images / 2 + 0.5
+     images *= 255
+     images = images.clamp(0, 255).to(torch.uint8)
+     return images[0].permute(1, 2, 0).cpu().numpy()
+
+ # @@@@@@@ Start of the program @@@@@@@@
+
+ vae = prepare_model()
+
+ css = '''
+ .my-disabled {
+     background-color: #eee;
+ }
+ .my-disabled input {
+     background-color: #eee;
+ }
+ '''
+ with gr.Blocks(css=css) as demo:
+     gr.Markdown('# ReVP: Reversible Visual Processing with Latent Models')
+     with gr.Tab('Add censorship'):
+         with gr.Row():
+             with gr.Column():
+                 input_image = gr.ImageEditor(brush=gr.Brush(default_size=100))
+                 with gr.Accordion('Options', open=False) as options_accord:
+                     mode = gr.Radio(label='Mode', choices=['Pixelation', 'Gaussian blur', 'Black'],
+                                     value='Pixelation', interactive=True)
+                     pixelation_block_size = gr.Slider(label='Block size', minimum=10, maximum=40, value=25, step=1, interactive=True)
+                     blur_kernel_size = gr.Slider(label='Blur kernel size', minimum=21, maximum=151, value=85, step=2, interactive=True, visible=False)
+                     def change_mode(mode):
+                         if mode == 'Gaussian blur':
+                             return gr.Slider(visible=False), gr.Slider(visible=True), gr.Accordion(open=True)
+                         elif mode == 'Pixelation':
+                             return gr.Slider(visible=True), gr.Slider(visible=False), gr.Accordion(open=True)
+                         elif mode == 'Black':
+                             return gr.Slider(visible=False), gr.Slider(visible=False), gr.Accordion(open=True)
+                         else:
+                             raise NotImplementedError
+                     mode.select(change_mode, mode, [pixelation_block_size, blur_kernel_size, options_accord])
+                     with gr.Row(variant='panel'):
+                         soft_edges = gr.Checkbox(label='Soft edges', value=True, interactive=True, scale=1)
+                         soft_edge_kernel_size = gr.Slider(label='Soft edge kernel size', minimum=21, maximum=49, value=35, step=2, interactive=True, visible=True, scale=2)
+                     def change_soft_edges(soft_edges):
+                         return gr.Slider(visible=True if soft_edges else False), gr.Accordion(open=True)
+                     soft_edges.change(change_soft_edges, soft_edges, [soft_edge_kernel_size, options_accord])
+                 submit_btn = gr.Button('Submit')
+             output_image = gr.Image(label='Censored', show_download_button=True)
+
+         submit_btn.click(
+             fn=add_censorship,
+             inputs=[input_image, mode, pixelation_block_size, blur_kernel_size, soft_edges, soft_edge_kernel_size],
+             outputs=output_image
+         )
+
+     with gr.Tab('Remove censorship'):
+         with gr.Row():
+             with gr.Column():
+                 input_image = gr.ImageEditor()
+                 with gr.Accordion('Manual cropping', open=False):
+                     with gr.Row():
+                         with gr.Row():
+                             x1 = gr.Number(value=0, label='x1')
+                             y1 = gr.Number(value=0, label='y1')
+                         with gr.Row():
+                             x2_ = gr.Number(value=10000, label='x2', interactive=False, elem_classes='my-disabled')
+                             y1_ = gr.Number(value=0, label='y1', interactive=False, elem_classes='my-disabled')
+                     with gr.Row():
+                         with gr.Row():
+                             x1_ = gr.Number(value=0, label='x1', elem_classes='my-disabled')
+                             y2_ = gr.Number(value=10000, label='y2', elem_classes='my-disabled')
+                         with gr.Row():
+                             x2 = gr.Number(value=10000, label='x2')
+                             y2 = gr.Number(value=10000, label='y2')
+                 submit_btn = gr.Button('Submit')
+             output_image = gr.Image(label='Uncensored')
+
+         submit_btn.click(
+             fn=remove_censorship,
+             inputs=[input_image, x1, y1, x2, y2],
+             outputs=output_image
+         )
+
+         # sync the mirrored coordinates on change
+         x1.change(lambda x: x, x1, x1_)
+         x2.change(lambda x: x, x2, x2_)
+         y1.change(lambda x: x, y1, y1_)
+         y2.change(lambda x: x, y2, y2_)
+
+ if __name__ == '__main__':
+     demo.queue(4)
+     demo.launch()
+
requirements.txt ADDED
@@ -0,0 +1,12 @@
+ torch
+ torchvision
+ huggingface_hub
+ accelerate
+ transformers
+ datasets
+ diffusers
+ peft
+ safetensors
+ gradio==4.26.0
+ pyyaml
+
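Note on the reversibility trick in app.py: `add_censorship` hides a budget-clamped residual of the plain VAE reconstruction inside the censored pixels, and `remove_censorship` encodes with the LoRA adapter active but decodes with it disabled, which is how the original is meant to be recovered. A minimal, self-contained sketch of just the residual-embedding step, with dummy tensors and an assumed budget of 8 (the real budget comes from hparams.yml; this is not the committed code):

import torch

budget = 8  # assumed illustrative value; app.py reads args.budget from hparams.yml
reconstruction = torch.rand(1, 3, 64, 64) * 255  # stand-in for the VAE reconstruction, in [0, 255]
censored = torch.rand(1, 3, 64, 64) * 255        # stand-in for the censored image, in [0, 255]

# keep the hidden signal within +/- budget so the published image still looks censored
residual = (reconstruction - censored).clamp(-budget, budget)
published = (censored + residual).clamp(0, 255).to(torch.uint8)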