Spaces:
Runtime error
Runtime error
File size: 9,864 Bytes
19d207d c9f0162 19d207d 37119a2 19d207d 385c778 19d207d 37119a2 19d207d c9f0162 19d207d c9f0162 19d207d 37119a2 19d207d 37119a2 19d207d 37119a2 19d207d 385c778 19d207d 385c778 328dfe7 385c778 19d207d 37119a2 19d207d 37119a2 19d207d 385c778 37119a2 328dfe7 385c778 19d207d 385c778 19d207d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 |
from argparse import Namespace
from glob import glob
import yaml
import os
import spaces
import gradio as gr
import torch
import torchvision
import safetensors
from diffusers import AutoencoderKL, ConsistencyDecoderVAE
from peft import get_peft_model, LoraConfig, set_peft_model_state_dict
from huggingface_hub import snapshot_download
# Download (or reuse the locally cached copy of) the released checkpoint repository.
pretrained_model_path = snapshot_download(repo_id="revp2024/revp-censorship")

# Hyper-parameters saved with the checkpoint; this app reads args.rank, args.budget
# and args.pretrained_model_name_or_path from it. Prefer the top-level hparams.yml and
# only fall back to a recursive search (the old `recursive=True` without `**` was a no-op).
_hparams_candidates = (
    glob(os.path.join(pretrained_model_path, 'hparams.yml'))
    or sorted(glob(os.path.join(pretrained_model_path, '**', 'hparams.yml'), recursive=True))
)
if not _hparams_candidates:
    raise FileNotFoundError(f"hparams.yml not found under {pretrained_model_path}")
with open(_hparams_candidates[0], encoding='utf-8') as f:
    args = Namespace(**yaml.safe_load(f))

# Example rows for the two tabs' gr.Examples widgets (paths relative to the working directory).
with open('examples/add_censorship.yaml', encoding='utf-8') as f:
    add_censor_examples = yaml.safe_load(f)
with open('examples/remove_censorship.yaml', encoding='utf-8') as f:
    remove_censor_examples = yaml.safe_load(f)
def prepare_model():
    """Build the LoRA-adapted VAE and the consistency decoder on a single device.

    The base ``AutoencoderKL`` comes from ``args.pretrained_model_name_or_path``
    (subfolder ``vae``). It is wrapped with a LoRA adapter of rank ``args.rank``,
    and the adapter weights are loaded from ``pytorch_lora_weights.safetensors``
    inside ``pretrained_model_path``.

    Returns:
        tuple: ``(vae, cd_vae)`` — the PEFT-wrapped VAE and the
        ``ConsistencyDecoderVAE`` from ``openai/consistency-decoder``, both moved
        to CUDA, MPS or CPU (first available).
    """
    # A bare `import safetensors` does not import the `safetensors.torch`
    # submodule, so import it explicitly instead of relying on another library
    # having done so.
    from safetensors.torch import load_file

    device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'

    print('Loading model ...')
    vae_lora_config = LoraConfig(
        r=args.rank,
        lora_alpha=args.rank,
        init_lora_weights="gaussian",
        target_modules=["conv", "conv1", "conv2",
                        "to_q", "to_k", "to_v", "to_out.0"],
    )
    vae = AutoencoderKL.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="vae"
    )
    vae = get_peft_model(vae, vae_lora_config)
    lora_weights_path = os.path.join(pretrained_model_path, "pytorch_lora_weights.safetensors")
    set_peft_model_state_dict(vae, load_file(lora_weights_path, device="cpu"))

    # Half precision only on accelerators: float16 kernels may be missing or very
    # slow on CPU.
    cd_dtype = torch.float32 if device == 'cpu' else torch.float16
    cd_vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder", torch_dtype=cd_dtype)
    print('Done.')

    return vae.to(device), cd_vae.to(device)
def _censor(images, mode, pixelation_block_size, blur_kernel_size):
    """Return a fully censored copy of ``images`` (N, C, H, W, values in [0, 1]).

    ``mode`` is one of the UI choices 'Pixelation', 'Gaussian blur' or 'Black'.
    Raises ValueError for any other mode.
    """
    if mode == 'Pixelation':
        # avg_pool2d needs an int kernel that fits inside the image.
        block = max(1, min(int(pixelation_block_size), *images.shape[-2:]))
        pooled = torch.nn.functional.avg_pool2d(images, block)
        # Nearest-neighbour upsampling back to full size gives the blocky look.
        return torch.nn.functional.interpolate(pooled, images.shape[-2:])
    if mode == 'Gaussian blur':
        return torchvision.transforms.functional.gaussian_blur(images, int(blur_kernel_size))
    if mode == 'Black':
        return torch.zeros_like(images)
    raise ValueError(
        f"censor mode has to be one of 'Pixelation', 'Gaussian blur' or 'Black', got {mode!r}"
    )


@spaces.GPU
@torch.no_grad()
def add_censorship(input_image, mode, pixelation_block_size, blur_kernel_size, soft_edges, soft_edge_kernel_size):
    """Censor the painted region of an image and embed a bounded VAE residual.

    Args:
        input_image: ``gr.ImageEditor`` value. ``'background'`` is the uint8 image
            (only the first 3 channels are used), and the alpha channel of the
            ``'layers'`` marks the region to censor. All layers are combined.
        mode: 'Pixelation', 'Gaussian blur' or 'Black'.
        pixelation_block_size: block size used by 'Pixelation'.
        blur_kernel_size: odd kernel size used by 'Gaussian blur'.
        soft_edges: whether to feather the mask with a Gaussian blur.
        soft_edge_kernel_size: odd kernel size used to feather the mask.

    Returns:
        (H, W, 3) uint8 numpy array. The longer side is capped at 1024 and both
        sides are floored to a multiple of 8.

    Raises:
        gr.Error: if no image was provided.
        ValueError: if ``mode`` is unknown.
    """
    F = torch.nn.functional

    background = input_image.get('background') if input_image is not None else None
    if background is None:
        raise gr.Error('Please upload an image first.')
    layers = input_image.get('layers') or []

    images = torch.from_numpy(background).permute(2, 0, 1)[None, :3] / 255
    H, W = images.shape[-2:]
    if layers:
        # Union of the alpha channels of every painted layer, shape (1, 1, H, W).
        alpha = torch.stack([torch.from_numpy(layer)[..., -1] for layer in layers]).amax(0)
        mask = alpha[None, None] / 255
    else:
        mask = torch.zeros(1, 1, H, W)

    # Cap the longer side at 1024 while keeping the aspect ratio.
    if H > 1024 or W > 1024:
        if H > W:
            H, W = 1024, int(1024 * W / H)
        else:
            H, W = int(1024 * H / W), 1024
    # The VAE downsamples by 8, so snap both sides to a (non-zero) multiple of 8.
    H, W = max(8, H // 8 * 8), max(8, W // 8 * 8)
    images = F.interpolate(images, (H, W), mode='bilinear')
    mask = F.interpolate(mask, (H, W))
    if soft_edges:
        mask = torchvision.transforms.functional.gaussian_blur(mask, int(soft_edge_kernel_size))

    images = images.to(vae.device)
    censored = _censor(images, mode, pixelation_block_size, blur_kernel_size)
    mask = mask.to(images.device)
    # Blend censored pixels into the masked region; work in the [0, 255] range.
    censored_images = (images * (1 - mask) + censored * mask) * 255

    # Plain (adapter-disabled) VAE reconstruction of the uncensored image.
    with vae.disable_adapter():
        latents = vae.encode(images * 2 - 1).latent_dist.mean
        decoded = vae.decode(latents, return_dict=False)[0]
    decoded = (decoded / 2 + 0.5) * 255  # denormalize [-1, 1] -> [0, 255]

    # Hide the reconstruction's difference from the censored image inside a
    # +/- args.budget perturbation of the censored image.
    residuals = (decoded - censored_images).clamp(-args.budget, args.budget)
    output = (censored_images + residuals).clamp(0, 255).to(torch.uint8)
    gr.Info("Try to download/copy the censored image to the `Remove censorship' tab")
    return output[0].permute(1, 2, 0).cpu().numpy()
@spaces.GPU
@torch.no_grad()
def remove_censorship(input_image, use_cd, x1, y1, x2, y2):
    """Recover the original content from a censored image.

    Args:
        input_image: ``gr.ImageEditor`` value. ``'background'`` is the uint8
            censored image (first 3 channels used). Pixels painted in any of the
            ``'layers'`` (alpha channel) are zeroed before encoding.
        use_cd: decode with the consistency decoder instead of the plain VAE decoder.
        x1, y1, x2, y2: crop box in pixels, ``image[y1:y2, x1:x2]``. Floats are
            truncated to ints, negatives are clamped to 0, and ``None`` means
            "image edge".

    Returns:
        (H, W, 3) uint8 numpy array of the decoded crop.

    Raises:
        gr.Error: if no image was provided or the crop box is empty.
    """
    background = input_image.get('background') if input_image is not None else None
    if background is None:
        raise gr.Error('Please upload a censored image first.')
    layers = input_image.get('layers') or []

    images = torch.from_numpy(background).permute(2, 0, 1)[None, :3] / 255
    if layers:
        # Union of the alpha channels of every painted layer.
        alpha = torch.stack([torch.from_numpy(layer)[..., -1] for layer in layers]).amax(0)
        images = images * (1 - alpha[None, None] / 255)

    def _pixel(value):
        # gr.Number yields floats by default (None when cleared), and tensor
        # slicing needs ints. Clamp negatives so they don't wrap to the far edge.
        return None if value is None else max(0, int(value))

    images = images[..., _pixel(y1):_pixel(y2), _pixel(x1):_pixel(x2)]
    if images.shape[-1] == 0 or images.shape[-2] == 0:
        raise gr.Error('The crop region is empty; make sure x1 < x2 and y1 < y2.')

    # Encode with the LoRA adapter active, then decode without it.
    latents = vae.encode((images * 2 - 1).to(vae.device)).latent_dist.mean
    if use_cd:
        decoded = cd_vae.decode(latents.to(cd_vae.dtype), return_dict=False)[0]
    else:
        with vae.disable_adapter():
            decoded = vae.decode(latents, return_dict=False)[0]
    decoded = (decoded / 2 + 0.5) * 255  # denormalize [-1, 1] -> [0, 255]
    output = decoded.clamp(0, 255).to(torch.uint8)
    return output[0].permute(1, 2, 0).cpu().numpy()
# ---------------------------------------------------------------------------
# Program entry: load the models once at import time. The Gradio callbacks
# read them through the module-level names `vae` and `cd_vae`.
# ---------------------------------------------------------------------------
(vae, cd_vae) = prepare_model()

# Styling for the greyed-out coordinate boxes (elem_classes='my-disabled').
css = (
    '\n'
    '.my-disabled {\n'
    'background-color: #eee;\n'
    '}\n'
    '.my-disabled input {\n'
    'background-color: #eee;\n'
    '}\n'
)
with gr.Blocks(css=css) as demo:
    gr.Markdown('# ReVP: Reversible Visual Processing with Latent Models')
    gr.Markdown('### Check out our project page for more info: https://revp2024.github.io')

    # Tab 1: paint a mask on an image and censor the painted region.
    with gr.Tab('Add censorship'):
        with gr.Row():
            with gr.Column():
                input_image = gr.ImageEditor(brush=gr.Brush(default_size=100))
                with gr.Accordion('Options', open=False) as options_accord:
                    mode = gr.Radio(label='Mode', choices=['Pixelation', 'Gaussian blur', 'Black'],
                                    value='Pixelation', interactive=True)
                    pixelation_block_size = gr.Slider(label='Block size', minimum=10, maximum=40, value=25, step=1, interactive=True)
                    blur_kernel_size = gr.Slider(label='Blur kernel size', minimum=21, maximum=151, value=85, step=2, interactive=True, visible=False)

                    def change_mode(mode):
                        """Show only the slider used by ``mode`` and keep the options accordion open."""
                        if mode == 'Gaussian blur':
                            return gr.Slider(visible=False), gr.Slider(visible=True), gr.Accordion(open=True)
                        elif mode == 'Pixelation':
                            return gr.Slider(visible=True), gr.Slider(visible=False), gr.Accordion(open=True)
                        elif mode == 'Black':
                            return gr.Slider(visible=False), gr.Slider(visible=False), gr.Accordion(open=True)
                        else:
                            raise NotImplementedError

                    # `.change` also fires when gr.Examples fills in the mode (`.select`
                    # only fires on user clicks), so slider visibility follows examples too.
                    mode.change(change_mode, mode, [pixelation_block_size, blur_kernel_size, options_accord])

                    with gr.Row(variant='panel'):
                        soft_edges = gr.Checkbox(label='Soft edges', value=True, interactive=True, scale=1)
                        soft_edge_kernel_size = gr.Slider(label='Soft edge kernel size', minimum=21, maximum=49, value=35, step=2, interactive=True, visible=True, scale=2)

                    def change_soft_edges(soft_edges):
                        """Show the feathering kernel slider only while soft edges are enabled."""
                        return gr.Slider(visible=bool(soft_edges)), gr.Accordion(open=True)

                    soft_edges.change(change_soft_edges, soft_edges, [soft_edge_kernel_size, options_accord])
                submit_btn = gr.Button('Submit')
            output_image = gr.Image(label='Censored', show_download_button=True)
        submit_btn.click(
            fn=add_censorship,
            inputs=[input_image, mode, pixelation_block_size, blur_kernel_size, soft_edges, soft_edge_kernel_size],
            outputs=output_image
        )
        gr.Examples(
            examples=add_censor_examples,
            fn=add_censorship,
            inputs=[input_image, mode, pixelation_block_size, blur_kernel_size, soft_edges, soft_edge_kernel_size],
            outputs=output_image,
            cache_examples=False,
        )

    # Tab 2: recover the original content from a censored image.
    with gr.Tab('Remove censorship'):
        with gr.Row():
            with gr.Column():
                input_image = gr.ImageEditor()
                use_cd = gr.Checkbox(label='Use Consistency Decoder (slower)')
                # Crop box given by its top-left (x1, y1) and bottom-right (x2, y2)
                # corners. The greyed-out boxes only mirror those values so all four
                # corners are shown. precision=0 makes gr.Number return ints, which
                # tensor slicing requires.
                with gr.Accordion('Manual cropping', open=False):
                    with gr.Row():
                        with gr.Row():
                            x1 = gr.Number(value=0, label='x1', precision=0)
                            y1 = gr.Number(value=0, label='y1', precision=0)
                        with gr.Row():
                            x2_ = gr.Number(value=10000, label='x2', precision=0, interactive=False, elem_classes='my-disabled')
                            y1_ = gr.Number(value=0, label='y1', precision=0, interactive=False, elem_classes='my-disabled')
                    with gr.Row():
                        with gr.Row():
                            x1_ = gr.Number(value=0, label='x1', precision=0, interactive=False, elem_classes='my-disabled')
                            y2_ = gr.Number(value=10000, label='y2', precision=0, interactive=False, elem_classes='my-disabled')
                        with gr.Row():
                            x2 = gr.Number(value=10000, label='x2', precision=0)
                            y2 = gr.Number(value=10000, label='y2', precision=0)
                submit_btn = gr.Button('Submit')
            output_image = gr.Image(label='Uncensored')
        submit_btn.click(
            fn=remove_censorship,
            inputs=[input_image, use_cd, x1, y1, x2, y2],
            outputs=output_image
        )
        gr.Examples(
            examples=remove_censor_examples,
            fn=remove_censorship,
            inputs=[input_image, use_cd, x1, y1, x2, y2],
            outputs=output_image,
            cache_examples=False,
        )

        # Copy each editable corner coordinate into its read-only mirror box.
        x1.change(lambda x: x, x1, x1_)
        x2.change(lambda x: x, x2, x2_)
        y1.change(lambda x: x, y1, y1_)
        y2.change(lambda x: x, y2, y2_)
if __name__ == '__main__':
    # In Gradio >= 4 (required by gr.ImageEditor / gr.Brush), the first positional
    # argument of `queue()` is `status_update_rate`, so the former `queue(4)` never
    # set concurrency. Keep one concurrent run per event on purpose:
    # `vae.disable_adapter()` toggles the adapter on the shared model, and
    # `remove_censorship` encodes with the adapter enabled, so concurrent requests
    # could interfere with each other.
    demo.queue(default_concurrency_limit=1)
    demo.launch()
|