# Page residue from the Hugging Face Spaces listing this file was copied from
# (Space status at capture time: "Runtime error").
from argparse import Namespace | |
from glob import glob | |
import yaml | |
import os | |
import spaces | |
import gradio as gr | |
import torch | |
import torchvision | |
import safetensors | |
from diffusers import AutoencoderKL | |
from peft import get_peft_model, LoraConfig, set_peft_model_state_dict | |
from huggingface_hub import snapshot_download | |
# Download the trained checkpoint; prepare_model() reads the LoRA weights from
# this snapshot directory and this block reads the training hyper-parameters.
pretrained_model_path = snapshot_download(repo_id="revp2024/revp-censorship")

# hparams.yml is expected at the snapshot root. Opening it directly (instead of
# `glob(..., recursive=True)[0]`, which without `**` only checks that this exact
# path exists) makes a missing file raise FileNotFoundError naming the path
# rather than a bare IndexError.
with open(os.path.join(pretrained_model_path, 'hparams.yml'), encoding='utf-8') as f:
    # yaml.safe_load: the file comes from a remote repo, never use yaml.load here.
    args = Namespace(**yaml.safe_load(f))

# Example rows for the two gr.Examples galleries in the UI.
with open('examples/add_censorship.yaml', encoding='utf-8') as f:
    add_censor_examples = yaml.safe_load(f)
with open('examples/remove_censorship.yaml', encoding='utf-8') as f:
    remove_censor_examples = yaml.safe_load(f)
def prepare_model(): | |
print('Loading model ...') | |
vae_lora_config = LoraConfig( | |
r=args.rank, | |
lora_alpha=args.rank, | |
init_lora_weights="gaussian", | |
target_modules=["conv", "conv1", "conv2", | |
"to_q", "to_k", "to_v", "to_out.0"], | |
) | |
vae = AutoencoderKL.from_pretrained( | |
args.pretrained_model_name_or_path, subfolder="vae" | |
) | |
vae = get_peft_model(vae, vae_lora_config) | |
lora_weights_path = os.path.join(pretrained_model_path, f"pytorch_lora_weights.safetensors") | |
state_dict = {} | |
with safetensors.torch.safe_open(lora_weights_path, framework="pt", device="cpu") as f: | |
for key in f.keys(): | |
state_dict[key] = f.get_tensor(key) | |
set_peft_model_state_dict(vae, state_dict) | |
print('Done.') | |
return vae.to('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu') | |
def add_censorship(input_image, mode, pixelation_block_size, blur_kernel_size, soft_edges, soft_edge_kernel_size): | |
background, layers, _ = input_image.values() | |
input_images = torch.from_numpy(background).permute(2, 0, 1)[None, :3] / 255 | |
mask = torch.from_numpy(layers[0]).permute(2, 0, 1)[None, -1:] / 255 | |
H, W = input_images.shape[-2:] | |
if H > 1024 or W > 1024: | |
H_t, W_t = H, W | |
if H > W: | |
H, W = 1024, int(1024 * W_t / H_t) | |
else: | |
H, W = int(1024 * H_t / W_t), 1024 | |
H_q8 = (H // 8) * 8 | |
W_q8 = (W // 8) * 8 | |
input_images = torch.nn.functional.interpolate(input_images, (H_q8, W_q8), mode='bilinear') | |
mask = torch.nn.functional.interpolate(mask, (H_q8, W_q8)) | |
if soft_edges: | |
mask = torchvision.transforms.functional.gaussian_blur(mask, soft_edge_kernel_size)[0][0] | |
input_images = input_images.to(vae.device) | |
if mode == 'Pixelation': | |
censored = torch.nn.functional.avg_pool2d( | |
input_images, pixelation_block_size) | |
censored = torch.nn.functional.interpolate(censored, input_images.shape[-2:]) | |
elif mode == 'Gaussian blur': | |
censored = torchvision.transforms.functional.gaussian_blur( | |
input_images, blur_kernel_size) | |
elif mode == 'Black': | |
censored = torch.zeros_like(input_images) | |
else: | |
raise ValueError("censor_mode has to be either `pixelation' or `gaussian_blur'") | |
mask = mask.to(input_images.device) | |
censored_images = input_images * (1 - mask) + censored * mask | |
censored_images *= 255 | |
input_images = input_images * 2 - 1 | |
with vae.disable_adapter(): | |
latents = vae.encode(input_images).latent_dist.mean | |
images = vae.decode(latents, return_dict=False)[0] | |
# denormalize | |
images = images / 2 + 0.5 | |
images *= 255 | |
residuals = (images - censored_images).clamp(-args.budget, args.budget) | |
images = (censored_images + residuals).clamp(0, 255).to(torch.uint8) | |
gr.Info("Try to donwload/copy the censored image to the `Remove censorsip' tab") | |
return images[0].permute(1, 2, 0).cpu().numpy() | |
def remove_censorship(input_image, x1, y1, x2, y2): | |
background, layers, _ = input_image.values() | |
images = torch.from_numpy(background).permute(2, 0, 1)[None, :3] / 255 | |
mask = torch.from_numpy(layers[0]).permute(2, 0, 1)[None, -1:] / 255 | |
images = images * (1 - mask) | |
images = images[..., y1:y2, x1:x2] | |
latents = vae.encode((images * 2 - 1).to(vae.device)).latent_dist.mean | |
with vae.disable_adapter(): | |
images = vae.decode(latents, return_dict=False)[0] | |
# denormalize | |
images = images / 2 + 0.5 | |
images *= 255 | |
images = images.clamp(0, 255).to(torch.uint8) | |
return images[0].permute(1, 2, 0).cpu().numpy() | |
# @@@@@@@ Start of the program @@@@@@@@
# Load the VAE once at import time; add_censorship/remove_censorship use it
# through the module-level name `vae`.
vae = prepare_model()

# Grey background for components tagged elem_classes='my-disabled' (the
# mirrored coordinate fields in the "Manual cropping" accordion).
css = '''
.my-disabled {
background-color: #eee;
}
.my-disabled input {
background-color: #eee;
}
'''
# Two-tab Gradio app. "Add censorship" paints a mask and censors it;
# "Remove censorship" recovers an image from a censored one.
# NOTE(review): the original nesting was lost when this file was scraped; the
# layout below (editor + options in a column, output image beside it) is a
# reconstruction — confirm against the deployed Space.
with gr.Blocks(css=css) as demo:
    gr.Markdown('# ReVP: Reversible Visual Processing with Latent Models')
    gr.Markdown('### Check out our project page for more info: https://revp2024.github.io')

    with gr.Tab('Add censorship'):
        with gr.Row():
            with gr.Column():
                input_image = gr.ImageEditor(brush=gr.Brush(default_size=100))
                with gr.Accordion('Options', open=False) as options_accord:
                    mode = gr.Radio(label='Mode', choices=['Pixelation', 'Gaussian blur', 'Black'],
                                    value='Pixelation', interactive=True)
                    pixelation_block_size = gr.Slider(label='Block size', minimum=10, maximum=40, value=25,
                                                      step=1, interactive=True)
                    blur_kernel_size = gr.Slider(label='Blur kernel size', minimum=21, maximum=151, value=85,
                                                 step=2, interactive=True, visible=False)

                    def change_mode(mode):
                        """Show only the slider used by `mode` and keep the accordion open."""
                        if mode not in ('Pixelation', 'Gaussian blur', 'Black'):
                            raise NotImplementedError
                        return (gr.Slider(visible=mode == 'Pixelation'),
                                gr.Slider(visible=mode == 'Gaussian blur'),
                                gr.Accordion(open=True))

                    # `.change` rather than `.select`: `.select` fires only on
                    # user clicks, so loading an example with another mode left
                    # the wrong slider visible. This also matches how
                    # `soft_edges` is wired below.
                    mode.change(change_mode, mode, [pixelation_block_size, blur_kernel_size, options_accord])

                    with gr.Row(variant='panel'):
                        soft_edges = gr.Checkbox(label='Soft edges', value=True, interactive=True, scale=1)
                        soft_edge_kernel_size = gr.Slider(label='Soft edge kernel size', minimum=21, maximum=49,
                                                          value=35, step=2, interactive=True, visible=True, scale=2)

                    def change_soft_edges(soft_edges):
                        """Show the feathering-kernel slider only while soft edges are on."""
                        return gr.Slider(visible=bool(soft_edges)), gr.Accordion(open=True)

                    soft_edges.change(change_soft_edges, soft_edges, [soft_edge_kernel_size, options_accord])
                submit_btn = gr.Button('Submit')
            output_image = gr.Image(label='Censored', show_download_button=True)
        submit_btn.click(
            fn=add_censorship,
            inputs=[input_image, mode, pixelation_block_size, blur_kernel_size, soft_edges, soft_edge_kernel_size],
            outputs=output_image
        )
        gr.Examples(
            examples=add_censor_examples,
            fn=add_censorship,
            inputs=[input_image, mode, pixelation_block_size, blur_kernel_size, soft_edges, soft_edge_kernel_size],
            outputs=output_image
        )

    with gr.Tab('Remove censorship'):
        with gr.Row():
            with gr.Column():
                input_image = gr.ImageEditor()
                with gr.Accordion('Manual cropping', open=False):
                    # (x1, y1) and (x2, y2) are editable; the other two pairs are
                    # read-only mirrors kept in sync below. precision=0 makes
                    # every field deliver ints, which are usable as slice bounds.
                    with gr.Row():
                        with gr.Row():
                            x1 = gr.Number(value=0, label='x1', precision=0)
                            y1 = gr.Number(value=0, label='y1', precision=0)
                        with gr.Row():
                            x2_ = gr.Number(value=10000, label='x2', precision=0, interactive=False,
                                            elem_classes='my-disabled')
                            y1_ = gr.Number(value=0, label='y1', precision=0, interactive=False,
                                            elem_classes='my-disabled')
                    with gr.Row():
                        with gr.Row():
                            x1_ = gr.Number(value=0, label='x1', precision=0, interactive=False,
                                            elem_classes='my-disabled')
                            y2_ = gr.Number(value=10000, label='y2', precision=0, interactive=False,
                                            elem_classes='my-disabled')
                        with gr.Row():
                            x2 = gr.Number(value=10000, label='x2', precision=0)
                            y2 = gr.Number(value=10000, label='y2', precision=0)
                submit_btn = gr.Button('Submit')
            output_image = gr.Image(label='Uncensored')
        submit_btn.click(
            fn=remove_censorship,
            inputs=[input_image, x1, y1, x2, y2],
            outputs=output_image
        )
        gr.Examples(
            examples=remove_censor_examples,
            fn=remove_censorship,
            inputs=[input_image, x1, y1, x2, y2],
            outputs=output_image
        )

    # Sync the read-only mirror fields whenever an editable coordinate changes.
    # The identity lambda does not reference the loop variables, so there is
    # no late-binding issue.
    for src, dst in ((x1, x1_), (x2, x2_), (y1, y1_), (y2, y2_)):
        src.change(lambda x: x, src, dst)
if __name__ == '__main__':
    # NOTE(review): in Gradio 4+ (required here for gr.ImageEditor), the first
    # positional parameter of Blocks.queue() is `status_update_rate`, not the
    # Gradio 3 `concurrency_count`. The original `queue(4)` is spelled out
    # below with unchanged behavior. If 4 concurrent jobs were intended, use
    # `default_concurrency_limit=4` instead.
    demo.queue(status_update_rate=4).launch()