Commit 37119a2
anonymous committed
Parent: 328dfe7

add consistency decoder
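
This adds an optional OpenAI consistency decoder alongside the LoRA-adapted AutoencoderKL: prepare_model() now also loads openai/consistency-decoder and returns (vae, cd_vae), remove_censorship() decodes the latents with whichever decoder a new 'Use Consistency Decoder (slower)' checkbox selects, and each example row gains a matching false entry for that checkbox.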

Files changed (2)
  1. app.py +15 -8
  2. examples/remove_censorship.yaml +3 -0
app.py CHANGED

@@ -8,7 +8,7 @@ import gradio as gr
 import torch
 import torchvision
 import safetensors
-from diffusers import AutoencoderKL
+from diffusers import AutoencoderKL, ConsistencyDecoderVAE
 from peft import get_peft_model, LoraConfig, set_peft_model_state_dict
 from huggingface_hub import snapshot_download
 
@@ -43,7 +43,10 @@ def prepare_model():
     set_peft_model_state_dict(vae, state_dict)
 
     print('Done.')
-    return vae.to('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
+    cd_vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder", torch_dtype=torch.float16)
+    vae = vae.to('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
+    cd_vae = cd_vae.to('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
+    return vae, cd_vae
 
 
 @spaces.GPU
@@ -102,15 +105,18 @@ def add_censorship(input_image, mode, pixelation_block_size, blur_kernel_size, s
 
 @spaces.GPU
 @torch.no_grad()
-def remove_censorship(input_image, x1, y1, x2, y2):
+def remove_censorship(input_image, use_cd, x1, y1, x2, y2):
     background, layers, _ = input_image.values()
     images = torch.from_numpy(background).permute(2, 0, 1)[None, :3] / 255
     mask = torch.from_numpy(layers[0]).permute(2, 0, 1)[None, -1:] / 255
     images = images * (1 - mask)
     images = images[..., y1:y2, x1:x2]
     latents = vae.encode((images * 2 - 1).to(vae.device)).latent_dist.mean
-    with vae.disable_adapter():
-        images = vae.decode(latents, return_dict=False)[0]
+    if use_cd:
+        images = cd_vae.decode(latents.to(cd_vae.dtype), return_dict=False)[0]
+    else:
+        with vae.disable_adapter():
+            images = vae.decode(latents, return_dict=False)[0]
     # denormalize
     images = images / 2 + 0.5
     images *= 255
@@ -119,7 +125,7 @@ def remove_censorship(input_image, x1, y1, x2, y2):
 
 # @@@@@@@ Start of the program @@@@@@@@
 
-vae = prepare_model()
+vae, cd_vae = prepare_model()
 
 css = '''
 .my-disabled {
@@ -177,6 +183,7 @@ with gr.Blocks(css=css) as demo:
     with gr.Row():
         with gr.Column():
             input_image = gr.ImageEditor()
+            use_cd = gr.Checkbox(label='Use Consistency Decoder (slower)')
             with gr.Accordion('Manual cropping', open=False):
                 with gr.Row():
                     with gr.Row():
@@ -197,13 +204,13 @@
 
     submit_btn.click(
         fn=remove_censorship,
-        inputs=[input_image, x1, y1, x2, y2],
+        inputs=[input_image, use_cd, x1, y1, x2, y2],
         outputs=output_image
     )
     gr.Examples(
         examples=remove_censor_examples,
         fn=remove_censorship,
-        inputs=[input_image, x1, y1, x2, y2],
+        inputs=[input_image, use_cd, x1, y1, x2, y2],
         outputs=output_image,
         cache_examples=False,
     )
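
For reference, a minimal standalone sketch of the decode path this commit introduces. Assumptions are flagged in the comments: the public stabilityai/sd-vae-ft-mse checkpoint stands in for the app's LoRA-adapted VAE and a random tensor stands in for the masked, cropped input; only openai/consistency-decoder and the decode calls come from the commit itself.

```python
# Sketch of the new decode path. Assumptions: "stabilityai/sd-vae-ft-mse"
# replaces the app's LoRA-adapted VAE; a random tensor replaces the real input.
import torch
from diffusers import AutoencoderKL, ConsistencyDecoderVAE

device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'

vae = AutoencoderKL.from_pretrained('stabilityai/sd-vae-ft-mse').to(device)
# fp16 as in the commit; drop torch_dtype when running on CPU.
cd_vae = ConsistencyDecoderVAE.from_pretrained(
    'openai/consistency-decoder', torch_dtype=torch.float16
).to(device)

@torch.no_grad()
def decode(latents, use_cd):
    # Both branches receive the raw (unscaled) latents from vae.encode(...).
    if use_cd:
        # The consistency decoder was loaded in fp16, so cast the latents to match.
        return cd_vae.decode(latents.to(cd_vae.dtype), return_dict=False)[0]
    return vae.decode(latents, return_dict=False)[0]

# Encode an image batch normalized to [-1, 1], then decode either way.
images = torch.rand(1, 3, 256, 256, device=device) * 2 - 1
latents = vae.encode(images).latent_dist.mean
recon = decode(latents, use_cd=True)  # shape (1, 3, 256, 256), roughly in [-1, 1]
```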
examples/remove_censorship.yaml CHANGED

@@ -1,14 +1,17 @@
 - - examples/images/processed/car.png
+  - false
   - 0
   - 0
   - 10000
   - 10000
 - - examples/images/processed/obama.png
+  - false
   - 0
   - 0
   - 10000
   - 10000
 - - examples/images/processed/steam-clock.png
+  - false
   - 0
   - 0
   - 10000
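
The example rows are passed to gr.Examples positionally, so the new false slot lines up with the use_cd checkbox between the image path and the crop coordinates. A sketch of the loading side, assuming the app reads this file into remove_censor_examples with PyYAML (the loader is not part of this commit):

```python
# Hypothetical loader: the commit only shows remove_censor_examples being
# used, so reading the yaml with PyYAML is an assumption.
import yaml

with open('examples/remove_censorship.yaml') as f:
    remove_censor_examples = yaml.safe_load(f)

# Each row now matches inputs=[input_image, use_cd, x1, y1, x2, y2], e.g.
# ['examples/images/processed/car.png', False, 0, 0, 10000, 10000]
print(remove_censor_examples[0])
```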