Commit 37119a2
anonymous committed
1 Parent(s): 328dfe7

add consistency decoder

Files changed:
- app.py (+15 -8)
- examples/remove_censorship.yaml (+3 -0)
app.py CHANGED

@@ -8,7 +8,7 @@ import gradio as gr
 import torch
 import torchvision
 import safetensors
-from diffusers import AutoencoderKL
+from diffusers import AutoencoderKL, ConsistencyDecoderVAE
 from peft import get_peft_model, LoraConfig, set_peft_model_state_dict
 from huggingface_hub import snapshot_download
 
@@ -43,7 +43,10 @@ def prepare_model():
     set_peft_model_state_dict(vae, state_dict)
 
     print('Done.')
-    return vae
+    cd_vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder", torch_dtype=torch.float16)
+    vae = vae.to('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
+    cd_vae = cd_vae.to('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
+    return vae, cd_vae
 
 
 @spaces.GPU
@@ -102,15 +105,18 @@ def add_censorship(input_image, mode, pixelation_block_size, blur_kernel_size, s
 
 @spaces.GPU
 @torch.no_grad()
-def remove_censorship(input_image, x1, y1, x2, y2):
+def remove_censorship(input_image, use_cd, x1, y1, x2, y2):
     background, layers, _ = input_image.values()
     images = torch.from_numpy(background).permute(2, 0, 1)[None, :3] / 255
     mask = torch.from_numpy(layers[0]).permute(2, 0, 1)[None, -1:] / 255
     images = images * (1 - mask)
     images = images[..., y1:y2, x1:x2]
     latents = vae.encode((images * 2 - 1).to(vae.device)).latent_dist.mean
-    with vae.disable_adapter():
-        images = vae.decode(latents, return_dict=False)[0]
+    if use_cd:
+        images = cd_vae.decode(latents.to(cd_vae.dtype), return_dict=False)[0]
+    else:
+        with vae.disable_adapter():
+            images = vae.decode(latents, return_dict=False)[0]
     # denormalize
     images = images / 2 + 0.5
     images *= 255
@@ -119,7 +125,7 @@ def remove_censorship(input_image, x1, y1, x2, y2):
 
 # @@@@@@@ Start of the program @@@@@@@@
 
-vae = prepare_model()
+vae, cd_vae = prepare_model()
 
 css = '''
 .my-disabled {
@@ -177,6 +183,7 @@ with gr.Blocks(css=css) as demo:
     with gr.Row():
         with gr.Column():
             input_image = gr.ImageEditor()
+            use_cd = gr.Checkbox(label='Use Consistency Decoder (slower)')
             with gr.Accordion('Manual cropping', open=False):
                 with gr.Row():
                     with gr.Row():
@@ -197,13 +204,13 @@
 
     submit_btn.click(
         fn=remove_censorship,
-        inputs=[input_image, x1, y1, x2, y2],
+        inputs=[input_image, use_cd, x1, y1, x2, y2],
         outputs=output_image
     )
     gr.Examples(
         examples=remove_censor_examples,
         fn=remove_censorship,
-        inputs=[input_image, x1, y1, x2, y2],
+        inputs=[input_image, use_cd, x1, y1, x2, y2],
         outputs=output_image,
         cache_examples=False,
     )
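Read on its own, the decode path this commit adds is: encode with the SD-style AutoencoderKL, then hand the same latents to the consistency decoder. The sketch below mirrors that round trip outside the Space; it is only an illustration, not the Space's code. The 'stabilityai/sd-vae-ft-mse' checkpoint and the random input tensor are stand-ins for the Space's LoRA-patched VAE and the masked crop, while the 'openai/consistency-decoder' repo id, the dtype cast, and the denormalization come straight from the hunks above.

import torch
from diffusers import AutoencoderKL, ConsistencyDecoderVAE

device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'

# Stand-in for the Space's LoRA-patched VAE: any SD-compatible KL VAE
# produces latents in the layout the consistency decoder expects.
vae = AutoencoderKL.from_pretrained('stabilityai/sd-vae-ft-mse').to(device)
cd_vae = ConsistencyDecoderVAE.from_pretrained(
    'openai/consistency-decoder', torch_dtype=torch.float16
).to(device)

# Dummy batch in [-1, 1], standing in for the masked crop from the editor.
images = torch.rand(1, 3, 256, 256, device=device) * 2 - 1

with torch.no_grad():
    latents = vae.encode(images).latent_dist.mean
    # The encoder runs in fp32 and the decoder in fp16, so cast the
    # latents to the decoder's dtype, as the commit does.
    decoded = cd_vae.decode(latents.to(cd_vae.dtype), return_dict=False)[0]

# Denormalize the same way app.py does: [-1, 1] -> [0, 255].
decoded = (decoded / 2 + 0.5).clamp(0, 1) * 255

Loading the consistency decoder in fp16 roughly halves its memory footprint; the latents.to(cd_vae.dtype) cast is what keeps the fp32 encoder and the fp16 decoder compatible.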
examples/remove_censorship.yaml CHANGED

@@ -1,14 +1,17 @@
 - - examples/images/processed/car.png
+  - false
   - 0
   - 0
   - 10000
   - 10000
 - - examples/images/processed/obama.png
+  - false
   - 0
   - 0
   - 10000
   - 10000
 - - examples/images/processed/steam-clock.png
+  - false
   - 0
   - 0
   - 10000
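Each row of this file supplies one value per component in the widened inputs list, so the new use_cd flag slots in as a false right after the image path: [input_image, use_cd, x1, y1, x2, y2]. How app.py turns the file into remove_censor_examples is not part of this diff; the loader below is a hypothetical sketch of the usual pattern, assuming PyYAML and a plain yaml.safe_load.

import yaml

# Hypothetical loader (not shown in the commit): a YAML list-of-lists
# maps directly onto gr.Examples rows.
with open('examples/remove_censorship.yaml') as f:
    remove_censor_examples = yaml.safe_load(f)

# One row per example, one value per input component:
# [input_image path, use_cd flag, x1, y1, x2, y2]
print(remove_censor_examples[0])
# ['examples/images/processed/car.png', False, 0, 0, 10000, 10000]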