from argparse import Namespace
from glob import glob
import yaml
import os

# `spaces` stays ahead of the torch-related imports: the ZeroGPU `spaces`
# package may fail if CUDA was initialized before it was imported.
import spaces
import gradio as gr
import torch
import torchvision
# `import safetensors` alone does not load the `safetensors.torch` submodule.
import safetensors.torch
from diffusers import AutoencoderKL
from peft import get_peft_model, LoraConfig, set_peft_model_state_dict
from huggingface_hub import snapshot_download


# Longest image side (in pixels) processed by `add_censorship`; larger inputs
# are downscaled so that their longer side equals this value.
MAX_IMAGE_SIDE = 1024

pretrained_model_path = snapshot_download(repo_id="revp2024/revp-censorship")
with open(glob(os.path.join(pretrained_model_path, 'hparams.yml'), recursive=True)[0], encoding='utf-8') as f:
    # Training hyper-parameters; this script reads `rank`, `budget` and
    # `pretrained_model_name_or_path` from them.
    args = Namespace(**yaml.safe_load(f))


def prepare_model():
    """Build the LoRA-wrapped VAE and load the trained LoRA weights.

    The base VAE is loaded from `args.pretrained_model_name_or_path`. It is
    wrapped with a PEFT LoRA adapter whose rank and alpha both equal
    `args.rank`. The adapter weights are then read from
    `pytorch_lora_weights.safetensors` in the downloaded snapshot.

    Returns:
        The PEFT-wrapped VAE, moved to CUDA, MPS or CPU (first available).
    """
    print('Loading model ...')
    vae_lora_config = LoraConfig(
        r=args.rank,
        lora_alpha=args.rank,
        init_lora_weights="gaussian",
        target_modules=["conv", "conv1", "conv2", "to_q", "to_k", "to_v", "to_out.0"],
    )
    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
    vae = get_peft_model(vae, vae_lora_config)

    lora_weights_path = os.path.join(pretrained_model_path, "pytorch_lora_weights.safetensors")
    state_dict = safetensors.torch.load_file(lora_weights_path, device="cpu")
    set_peft_model_state_dict(vae, state_dict)
    print('Done.')

    if torch.cuda.is_available():
        device = 'cuda'
    elif torch.backends.mps.is_available():
        device = 'mps'
    else:
        device = 'cpu'
    return vae.to(device)


def _editor_to_tensors(editor_value):
    """Convert a `gr.ImageEditor` value into `(image, mask)` float tensors.

    Args:
        editor_value: the dict produced by `gr.ImageEditor` (keys
            'background', 'layers', 'composite'), or None.

    Returns:
        image: (1, 3, H, W) tensor in [0, 1], the first three channels of the
            background.
        mask: (1, 1, H, W) tensor in [0, 1], the last channel of the first
            layer, or all zeros when the editor has no layer.

    Raises:
        gr.Error: when no image was uploaded.
    """
    if editor_value is None or editor_value.get('background') is None:
        raise gr.Error('Please upload an image first.')
    background = editor_value['background']
    layers = editor_value.get('layers') or []

    image = torch.from_numpy(background).permute(2, 0, 1)[None, :3] / 255
    if layers:
        # Only the first layer's last channel (assumed alpha of the strokes)
        # is used as the mask.
        mask = torch.from_numpy(layers[0]).permute(2, 0, 1)[None, -1:] / 255
    else:
        mask = torch.zeros_like(image[:, :1])
    return image, mask


def _crop_bounds(lo, hi, size, axis):
    """Turn user-entered crop coordinates into a valid `[lo, hi)` int range.

    `None` means "from the start" for `lo` and "to the end" for `hi`. Both
    values are truncated to int and clamped to `[0, size]`, so negative
    numbers no longer wrap around as Python slices would.

    Raises:
        gr.Error: when the resulting range is empty.
    """
    lo = 0 if lo is None else int(lo)
    hi = size if hi is None else int(hi)
    lo = max(0, min(lo, size))
    hi = max(0, min(hi, size))
    if lo >= hi:
        raise gr.Error(
            f'Invalid crop: {axis}1 ({lo}) must be smaller than {axis}2 ({hi}) '
            f'after clamping to the image size ({size}).'
        )
    return lo, hi


@spaces.GPU
@torch.no_grad()
def add_censorship(input_image, mode, pixelation_block_size, blur_kernel_size, soft_edges, soft_edge_kernel_size):
    """Censor the painted region of an image and hide a bounded residual in it.

    Args:
        input_image: `gr.ImageEditor` value; the first painted layer is the mask.
        mode: 'Pixelation', 'Gaussian blur' or 'Black'.
        pixelation_block_size: average-pooling block size for 'Pixelation'.
        blur_kernel_size: Gaussian kernel size for 'Gaussian blur' (odd).
        soft_edges: whether to feather the mask with a Gaussian blur.
        soft_edge_kernel_size: Gaussian kernel size used to feather (odd).

    Returns:
        (H, W, C) uint8 numpy array. H and W are the (possibly downscaled)
        input size rounded down to multiples of 8.

    Raises:
        gr.Error: no image uploaded, or the image is smaller than 8 pixels.
        ValueError: unknown `mode`.
    """
    input_images, mask = _editor_to_tensors(input_image)

    # Downscale so the longer side is at most MAX_IMAGE_SIDE, keeping the
    # aspect ratio (both right-hand values use the original H and W).
    H, W = input_images.shape[-2:]
    if max(H, W) > MAX_IMAGE_SIDE:
        if H > W:
            H, W = MAX_IMAGE_SIDE, int(MAX_IMAGE_SIDE * W / H)
        else:
            H, W = int(MAX_IMAGE_SIDE * H / W), MAX_IMAGE_SIDE

    # Snap both sides down to a multiple of 8 (presumably the VAE's spatial
    # downsampling factor).
    H_q8 = (H // 8) * 8
    W_q8 = (W // 8) * 8
    if H_q8 == 0 or W_q8 == 0:
        raise gr.Error('Image is too small: both sides must be at least 8 pixels.')
    input_images = torch.nn.functional.interpolate(input_images, (H_q8, W_q8), mode='bilinear')
    mask = torch.nn.functional.interpolate(mask, (H_q8, W_q8))

    if soft_edges:
        # Feather the mask so the censored area blends into its surroundings.
        mask = torchvision.transforms.functional.gaussian_blur(mask, int(soft_edge_kernel_size))

    input_images = input_images.to(vae.device)
    if mode == 'Pixelation':
        # A block larger than the image would make avg_pool2d output empty.
        block = max(1, min(int(pixelation_block_size), H_q8, W_q8))
        censored = torch.nn.functional.avg_pool2d(input_images, block)
        censored = torch.nn.functional.interpolate(censored, input_images.shape[-2:])
    elif mode == 'Gaussian blur':
        censored = torchvision.transforms.functional.gaussian_blur(input_images, int(blur_kernel_size))
    elif mode == 'Black':
        censored = torch.zeros_like(input_images)
    else:
        raise ValueError(f"mode has to be one of 'Pixelation', 'Gaussian blur' or 'Black', got {mode!r}")

    # Original pixels outside the mask, censored pixels inside, scaled to 0-255.
    mask = mask.to(input_images.device)
    censored_images = (input_images * (1 - mask) + censored * mask) * 255

    # Reconstruct the image with the base VAE (LoRA adapter disabled); the VAE
    # works on inputs mapped from [0, 1] to [-1, 1].
    with vae.disable_adapter():
        latents = vae.encode(input_images * 2 - 1).latent_dist.mean
        images = vae.decode(latents, return_dict=False)[0]
    images = (images / 2 + 0.5) * 255

    # Add the reconstruction's difference from the censored image, clamped to
    # +/- args.budget (0-255 units), so the output stays close to the censored
    # image. Presumably this residual is what lets `remove_censorship` (which
    # encodes with the adapter enabled) recover the original.
    residuals = (images - censored_images).clamp(-args.budget, args.budget)
    images = (censored_images + residuals).clamp(0, 255).to(torch.uint8)

    gr.Info("Try to download/copy the censored image to the 'Remove censorship' tab")
    return images[0].permute(1, 2, 0).cpu().numpy()


@spaces.GPU
@torch.no_grad()
def remove_censorship(input_image, x1, y1, x2, y2):
    """Reconstruct the uncensored content of an image produced by `add_censorship`.

    Painted pixels are zeroed, then the image is cropped to
    `[y1:y2, x1:x2]`. The crop is encoded with the LoRA adapter enabled and
    decoded with it disabled.

    Args:
        input_image: `gr.ImageEditor` value; painted pixels are zeroed out.
        x1, y1, x2, y2: crop corners in pixels (None means the image edge);
            clamped to the image size.

    Returns:
        (h, w, C) uint8 numpy array of the decoded crop.

    Raises:
        gr.Error: no image uploaded or an empty crop.
    """
    images, mask = _editor_to_tensors(input_image)
    images = images * (1 - mask)

    height, width = images.shape[-2:]
    x1, x2 = _crop_bounds(x1, x2, width, 'x')
    y1, y2 = _crop_bounds(y1, y2, height, 'y')
    images = images[..., y1:y2, x1:x2]

    # Encode with the LoRA adapter active, decode with the base VAE.
    latents = vae.encode((images * 2 - 1).to(vae.device)).latent_dist.mean
    with vae.disable_adapter():
        images = vae.decode(latents, return_dict=False)[0]

    # Map from [-1, 1] back to 0-255 uint8.
    images = ((images / 2 + 0.5) * 255).clamp(0, 255).to(torch.uint8)
    return images[0].permute(1, 2, 0).cpu().numpy()


# @@@@@@@ Start of the program @@@@@@@@
vae = prepare_model()

# Grey background for the read-only mirror coordinate fields.
css = '''
.my-disabled {
    background-color: #eee;
}
.my-disabled input {
    background-color: #eee;
}
'''

with gr.Blocks(css=css) as demo:
    gr.Markdown('# ReVP: Reversible Visual Processing with Latent Models')

    with gr.Tab('Add censorship'):
        with gr.Row():
            with gr.Column():
                input_image = gr.ImageEditor(brush=gr.Brush(default_size=100))
                with gr.Accordion('Options', open=False) as options_accord:
                    mode = gr.Radio(label='Mode', choices=['Pixelation', 'Gaussian blur', 'Black'], value='Pixelation', interactive=True)
                    pixelation_block_size = gr.Slider(label='Block size', minimum=10, maximum=40, value=25, step=1, interactive=True)
                    blur_kernel_size = gr.Slider(label='Blur kernel size', minimum=21, maximum=151, value=85, step=2, interactive=True, visible=False)

                    def change_mode(mode):
                        """Show only the slider used by `mode` and keep the options panel open."""
                        if mode == 'Gaussian blur':
                            return gr.Slider(visible=False), gr.Slider(visible=True), gr.Accordion(open=True)
                        elif mode == 'Pixelation':
                            return gr.Slider(visible=True), gr.Slider(visible=False), gr.Accordion(open=True)
                        elif mode == 'Black':
                            return gr.Slider(visible=False), gr.Slider(visible=False), gr.Accordion(open=True)
                        else:
                            raise NotImplementedError

                    mode.select(change_mode, mode, [pixelation_block_size, blur_kernel_size, options_accord])

                    with gr.Row(variant='panel'):
                        soft_edges = gr.Checkbox(label='Soft edges', value=True, interactive=True, scale=1)
                        soft_edge_kernel_size = gr.Slider(label='Soft edge kernel size', minimum=21, maximum=49, value=35, step=2, interactive=True, visible=True, scale=2)

                    def change_soft_edges(soft_edges):
                        """Show the kernel-size slider only while soft edges are enabled."""
                        return gr.Slider(visible=bool(soft_edges)), gr.Accordion(open=True)

                    soft_edges.change(change_soft_edges, soft_edges, [soft_edge_kernel_size, options_accord])

                submit_btn = gr.Button('Submit')
            output_image = gr.Image(label='Censored', show_download_button=True)
            submit_btn.click(
                fn=add_censorship,
                inputs=[input_image, mode, pixelation_block_size, blur_kernel_size, soft_edges, soft_edge_kernel_size],
                outputs=output_image,
            )

    with gr.Tab('Remove censorship'):
        with gr.Row():
            with gr.Column():
                input_image = gr.ImageEditor()
                with gr.Accordion('Manual cropping', open=False):
                    # 2x2 grid of corners. Editable fields: top-left (x1, y1)
                    # and bottom-right (x2, y2). The other two corners are
                    # read-only mirrors.
                    with gr.Row():
                        with gr.Row():
                            x1 = gr.Number(value=0, label='x1', precision=0)
                            y1 = gr.Number(value=0, label='y1', precision=0)
                        with gr.Row():
                            x2_ = gr.Number(value=10000, label='x2', interactive=False, elem_classes='my-disabled')
                            y1_ = gr.Number(value=0, label='y1', interactive=False, elem_classes='my-disabled')
                    with gr.Row():
                        with gr.Row():
                            x1_ = gr.Number(value=0, label='x1', interactive=False, elem_classes='my-disabled')
                            y2_ = gr.Number(value=10000, label='y2', interactive=False, elem_classes='my-disabled')
                        with gr.Row():
                            x2 = gr.Number(value=10000, label='x2', precision=0)
                            y2 = gr.Number(value=10000, label='y2', precision=0)
                submit_btn = gr.Button('Submit')
            output_image = gr.Image(label='Uncensored')
            submit_btn.click(
                fn=remove_censorship,
                inputs=[input_image, x1, y1, x2, y2],
                outputs=output_image,
            )

        # Keep the read-only mirror fields in sync with the editable ones.
        x1.change(lambda x: x, x1, x1_)
        x2.change(lambda x: x, x2, x2_)
        y1.change(lambda x: x, y1, y1_)
        y2.change(lambda x: x, y2, y2_)

if __name__ == '__main__':
    # NOTE(review): in Gradio 4 the first positional argument of `queue()` is
    # `status_update_rate`, not a concurrency count. Confirm that 4 is intended
    # (perhaps `default_concurrency_limit=4` was meant).
    demo.queue(4)
    demo.launch()