# MultiMAE / app.py
# Last change by roman-bachmann: "Update app.py" (commit 265aea4)
import os
import subprocess
import sys

import torch

TORCH_VERSION = ".".join(torch.__version__.split(".")[:2])


def _detect_cuda_tag():
    """Return the CUDA tag used in detectron2 wheel-index URLs (e.g. ``cu113`` or ``cpu``).

    Builds with a local version label such as ``1.10.0+cu113`` carry the tag after
    the ``+``. Builds without one (e.g. ``1.10.0``) would otherwise yield the whole
    version string, so fall back to ``torch.version.cuda`` in that case.
    """
    if "+" in torch.__version__:
        return torch.__version__.split("+")[-1]
    if torch.version.cuda:
        return "cu" + torch.version.cuda.replace(".", "")
    return "cpu"


CUDA_VERSION = _detect_cuda_tag()
print("torch: ", TORCH_VERSION, "; cuda: ", CUDA_VERSION)


def _pip_install(*args):
    """Best-effort ``pip install`` into the interpreter running this script.

    Uses ``sys.executable -m pip`` so packages land in the same environment, and an
    argument list (no shell). A failure only prints a warning; a missing package will
    surface as an ImportError further down.
    """
    result = subprocess.run([sys.executable, "-m", "pip", "install", *args], check=False)
    if result.returncode != 0:
        print(f"warning: pip install {' '.join(args)} exited with code {result.returncode}")


# Install detectron2 that matches the above pytorch version
# See https://detectron2.readthedocs.io/tutorials/install.html for instructions
_pip_install(
    "detectron2",
    "-f",
    f"https://dl.fbaipublicfiles.com/detectron2/wheels/{CUDA_VERSION}/torch{TORCH_VERSION}/index.html",
)
_pip_install("jinja2")
_pip_install("git+https://github.com/cocodataset/panopticapi.git")
# Imports
import gradio as gr
import detectron2
from detectron2.utils.logger import setup_logger
import numpy as np
import cv2
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from torchvision import datasets, transforms
from einops import rearrange
from PIL import Image
import imutils
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid
from tqdm import tqdm
import random
from functools import partial
import time
# import some common detectron2 utilities
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer, ColorMode
from detectron2.data import MetadataCatalog
from detectron2.projects.deeplab import add_deeplab_config
coco_metadata = MetadataCatalog.get("coco_2017_val_panoptic")
# Import Mask2Former
from mask2former import add_maskformer2_config
# DPT dependencies for depth pseudo labeling
from dpt.models import DPTDepthModel
from multimae.input_adapters import PatchedInputAdapter, SemSegInputAdapter
from multimae.output_adapters import SpatialOutputAdapter
from multimae.multimae import pretrain_multimae_base
from utils.data_constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
# Inference-only app: turn autograd off globally.
torch.set_grad_enabled(False)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'device: {device}')

# Initialize COCO Mask2Former (Swin-S backbone, COCO panoptic config).
# NOTE(review): the predictor is pinned to CPU regardless of `device`; presumably
# deliberate - confirm before moving it to the GPU.
cfg = get_cfg()
cfg.MODEL.DEVICE = 'cpu'
add_deeplab_config(cfg)
add_maskformer2_config(cfg)
cfg.merge_from_file(
    "mask2former/configs/coco/panoptic-segmentation/swin/maskformer2_swin_small_bs16_50ep.yaml"
)
cfg.MODEL.WEIGHTS = (
    'https://dl.fbaipublicfiles.com/maskformer/mask2former/coco/panoptic/'
    'maskformer2_swin_small_bs16_50ep/model_final_a407fd.pkl'
)
# Switch on all three Mask2Former test-time outputs.
for _head_flag in ('SEMANTIC_ON', 'INSTANCE_ON', 'PANOPTIC_ON'):
    setattr(cfg.MODEL.MASK_FORMER.TEST, _head_flag, True)
semseg_model = DefaultPredictor(cfg)
def predict_semseg(img):
    """Return per-pixel class ids predicted by the Mask2Former `semseg_model`.

    `img` is presumably a CHW float tensor in [0, 1] on the CPU (TODO confirm for
    new callers): it is turned into an HWC numpy array and scaled by 255 before
    being handed to the predictor. The result is the argmax over the first
    (class) dimension of the predictor's 'sem_seg' output.
    """
    hwc_image = img.permute(1, 2, 0).numpy()
    class_scores = semseg_model(255 * hwc_image)['sem_seg']
    return class_scores.argmax(0)
def plot_semseg(img, semseg, ax):
    """Draw the segmentation `semseg` over a grayscale version of `img` on axis `ax`.

    Args:
        img: CHW image tensor, assumed to be in [0, 1] like the input of
            `predict_semseg` (TODO confirm for any new caller).
        semseg: HxW tensor of class ids, e.g. the output of `predict_semseg`.
        ax: matplotlib axis to draw on.
    """
    # detectron2's Visualizer expects an HWC RGB image with values in [0, 255],
    # so scale the [0, 1] tensor as plot_semseg_gt does; an unscaled image
    # renders (almost) black. `.cpu()` lets the Visualizer convert it to numpy.
    img_viz = 255 * img.detach().cpu().permute(1, 2, 0)
    v = Visualizer(img_viz, coco_metadata, scale=1.2, instance_mode=ColorMode.IMAGE_BW)
    semantic_result = v.draw_sem_seg(semseg.cpu()).get_image()
    ax.imshow(semantic_result)
# Initialize Omnidata depth model
omnidata_ckpt_path = './pretrained_models/omnidata_rgb2depth_dpt_hybrid.pth'
if not os.path.exists(omnidata_ckpt_path):
    # Download only when missing, so restarts reuse the cached checkpoint instead of
    # fetching it again and piling up duplicate copies next to it.
    os.makedirs(os.path.dirname(omnidata_ckpt_path), exist_ok=True)
    torch.hub.download_url_to_file(
        'https://datasets.epfl.ch/vilab/iccv21/weights/omnidata_rgb2depth_dpt_hybrid.pth',
        omnidata_ckpt_path,
    )
# NOTE: torch.load unpickles the file; it is only ever fetched from the fixed URL above.
omnidata_ckpt = torch.load(omnidata_ckpt_path, map_location='cpu')
depth_model = DPTDepthModel()
depth_model.load_state_dict(omnidata_ckpt)
depth_model = depth_model.to(device).eval()
def predict_depth(img):
    """Run the Omnidata DPT `depth_model` on a single CHW image tensor.

    The image is given a batch dimension and rescaled with (x - 0.5) / 0.5
    (mapping [0, 1] to [-1, 1], assuming `img` is in [0, 1]) before being moved
    to `device`. Returns the raw model output.
    """
    batch = img.unsqueeze(0)
    model_input = (batch - 0.5) / 0.5
    return depth_model(model_input.to(device))
# MultiMAE model setup
# Per-domain adapter factories. Arguments shared by all domains (patch size,
# token dimensions, ...) are supplied when the adapters are instantiated below.
DOMAIN_CONF = dict(
    rgb=dict(
        input_adapter=partial(PatchedInputAdapter, num_channels=3, stride_level=1),
        output_adapter=partial(SpatialOutputAdapter, num_channels=3, stride_level=1),
    ),
    depth=dict(
        input_adapter=partial(PatchedInputAdapter, num_channels=1, stride_level=1),
        output_adapter=partial(SpatialOutputAdapter, num_channels=1, stride_level=1),
    ),
    # Semantic maps enter as class ids at stride 4 (inference() downsamples the
    # pseudo labels by 4). 133 is presumably the number of COCO panoptic
    # categories predicted by Mask2Former - TODO confirm.
    semseg=dict(
        input_adapter=partial(
            SemSegInputAdapter,
            num_classes=133,
            dim_class_emb=64,
            interpolate_class_emb=False,
            stride_level=4,
        ),
        output_adapter=partial(SpatialOutputAdapter, num_channels=133, stride_level=4),
    ),
)
# Domain names in DOMAIN_CONF order: ['rgb', 'depth', 'semseg'].
DOMAINS = list(DOMAIN_CONF)
# Instantiate the adapters: all input adapters first, then all output adapters.
# Each output adapter receives the full DOMAINS list as `context_tasks` and its
# own domain name as `task`.
_SHARED_OUTPUT_ADAPTER_KWARGS = dict(
    patch_size_full=16,
    dim_tokens=256,
    use_task_queries=True,
    depth=2,
    context_tasks=DOMAINS,
)
input_adapters = {
    name: conf['input_adapter'](patch_size_full=16)
    for name, conf in DOMAIN_CONF.items()
}
output_adapters = {
    name: conf['output_adapter'](task=name, **_SHARED_OUTPUT_ADAPTER_KWARGS)
    for name, conf in DOMAIN_CONF.items()
}
multimae = pretrain_multimae_base(
    input_adapters=input_adapters,
    output_adapters=output_adapters,
)
CKPT_URL = 'https://github.com/EPFL-VILAB/MultiMAE/releases/download/pretrained-weights/multimae-b_98_rgb+-depth-semseg_1600e_multivit-afff3f8c.pth'
ckpt = torch.hub.load_state_dict_from_url(CKPT_URL, map_location='cpu')
# strict=False tolerates parameter-name mismatches between checkpoint and model.
# Report them instead of discarding the result: missing parameters keep their
# initial values, so an incompatible checkpoint should be visible in the logs.
load_result = multimae.load_state_dict(ckpt['model'], strict=False)
if load_result.missing_keys:
    print(f'MultiMAE checkpoint is missing {len(load_result.missing_keys)} keys: {load_result.missing_keys}')
if load_result.unexpected_keys:
    print(f'MultiMAE checkpoint has {len(load_result.unexpected_keys)} unexpected keys: {load_result.unexpected_keys}')
multimae = multimae.to(device).eval()
# Plotting
def get_masked_image(img, mask, image_size=224, patch_size=16, mask_value=0.0):
    """Fill every masked patch of `img` with `mask_value`.

    Args:
        img: (B, C, H, W) tensor with H == W == image_size.
        mask: (B, num_patches) tensor; patches where it is non-zero are filled.
        image_size: side length of the square image.
        patch_size: side length of one square patch.
        mask_value: value written into the masked patches.

    Returns:
        A (B, C, H, W) CPU tensor with the masked patches replaced.
    """
    side = image_size // patch_size
    axis_sizes = {'ph': patch_size, 'pw': patch_size, 'nh': side, 'nw': side}
    # One row per patch, so a (B, num_patches) boolean mask can select whole patches.
    patches = rearrange(img.detach().cpu(), 'b c (nh ph) (nw pw) -> b (nh nw) (c ph pw)', **axis_sizes)
    patches[mask.detach().cpu() != 0] = mask_value
    # Fold the patch rows back into image layout.
    return rearrange(patches, 'b (nh nw) (c ph pw) -> b c (nh ph) (nw pw)', **axis_sizes)
def denormalize(img, mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD):
    """Undo a per-channel `TF.normalize(img, mean, std)`, i.e. compute img * std + mean.

    Implemented as a second normalization with mean' = -mean / std and
    std' = 1 / std, since (img - mean') / std' = img * std + mean.
    Returns a new tensor; `img` itself is left unchanged.
    """
    # TF.normalize with the default inplace=False already works on a copy, so an
    # explicit img.clone() beforehand only duplicated the tensor a second time.
    return TF.normalize(
        img,
        mean=[-m / s for m, s in zip(mean, std)],
        std=[1 / s for s in std],
    )
def plot_semseg_gt(input_dict, ax=None, image_size=224):
    """Render the semantic pseudo labels of `input_dict` over its RGB image.

    Draws onto `ax` when one is given (and returns None); otherwise returns the
    rendered image.
    """
    metadata = MetadataCatalog.get("coco_2017_val_panoptic")
    rgb = denormalize(input_dict['rgb'].detach().cpu())[0].permute(1, 2, 0)
    # Upsample the label map to `image_size`; nearest keeps class ids intact.
    labels = input_dict['semseg'].unsqueeze(0).cpu().float()
    labels = F.interpolate(labels, size=image_size, mode='nearest').long()[0, 0]
    visualizer = Visualizer(255 * rgb, metadata, instance_mode=ColorMode.IMAGE, scale=1)
    visualizer.draw_sem_seg(labels)
    rendered = visualizer.get_output().get_image()
    if ax is None:
        return rendered
    ax.imshow(rendered)
def plot_semseg_gt_masked(input_dict, mask, ax=None, mask_value=1.0, image_size=224):
    """Render the semantic pseudo labels and fill their masked patches with `mask_value`.

    Draws onto `ax` when one is given (and returns None); otherwise returns the
    masked HxWxC image tensor with values in [0, 1].
    """
    rendered = plot_semseg_gt(input_dict, image_size=image_size)
    # HxWxC 0-255 image -> 1xCxHxW float tensor in [0, 1] for patch masking.
    chw_batch = torch.LongTensor(rendered).permute(2, 0, 1).unsqueeze(0)
    masked = get_masked_image(
        chw_batch.float() / 255.0,
        mask,
        image_size=image_size,
        patch_size=16,
        mask_value=mask_value,
    )[0].permute(1, 2, 0)
    if ax is None:
        return masked
    ax.imshow(masked)
def get_pred_with_input(gt, pred, mask, image_size=224, patch_size=16):
    """Merge `pred` and `gt` patch-wise: visible patches come from `gt`, masked ones from `pred`.

    Args:
        gt: (B, C, H, W) reference tensor with H == W == image_size.
        pred: (B, C, H, W) prediction tensor of the same shape.
        mask: (B, num_patches) tensor; 0 marks a visible patch copied from `gt`.
        image_size: side length of the square inputs.
        patch_size: side length of one square patch.

    Returns:
        A (B, C, H, W) CPU tensor combining both inputs.
    """
    side = image_size // patch_size
    axis_sizes = dict(ph=patch_size, pw=patch_size, nh=side, nw=side)
    to_patches = 'b c (nh ph) (nw pw) -> b (nh nw) (c ph pw)'
    gt_patches = rearrange(gt.detach().cpu(), to_patches, **axis_sizes)
    pred_patches = rearrange(pred.detach().cpu(), to_patches, **axis_sizes)
    visible = mask.detach().cpu() == 0
    pred_patches[visible] = gt_patches[visible]
    return rearrange(pred_patches, 'b (nh nw) (c ph pw) -> b c (nh ph) (nw pw)', **axis_sizes)
def plot_semseg_pred_masked(rgb, semseg_preds, semseg_gt, mask, ax=None, image_size=224):
    """Render predicted semantic classes, keeping the pseudo labels on visible patches.

    The class map works at 1/4 of `image_size` with 4x4 patches and is upsampled
    (nearest) to `image_size` for display. Draws onto `ax` when one is given (and
    returns None); otherwise returns the rendered image.
    """
    metadata = MetadataCatalog.get("coco_2017_val_panoptic")
    background = 255 * denormalize(rgb.detach().cpu())[0].permute(1, 2, 0)
    predicted_ids = semseg_preds.argmax(1).unsqueeze(1)
    merged = get_pred_with_input(
        semseg_gt.unsqueeze(1),
        predicted_ids,
        mask,
        image_size=image_size // 4,
        patch_size=4,
    )
    merged = F.interpolate(merged.float(), size=image_size, mode='nearest')[0, 0].long()
    vis = Visualizer(background, metadata, instance_mode=ColorMode.IMAGE, scale=1)
    vis.draw_sem_seg(merged)
    rendered = vis.get_output().get_image()
    if ax is None:
        return rendered
    ax.imshow(rendered)
def plot_predictions(input_dict, preds, masks, image_size=224, save_path='./output.png'):
    """Plot masked inputs, MultiMAE predictions and references as a 3x3 grid and save it.

    Rows are RGB, depth and semantic segmentation. Columns are the masked input,
    the prediction (visible patches copied from the input) and the unmasked
    reference.

    Args:
        input_dict: dict with batched 'rgb' (ImageNet-normalized), 'depth' and
            'semseg' tensors, as built by `inference`.
        preds: dict of per-domain MultiMAE predictions with the same keys.
        masks: dict of per-domain token masks; non-zero entries mark masked tokens.
        image_size: side length of the square RGB/depth images.
        save_path: file the figure is written to.
    """
    masked_rgb = get_masked_image(
        denormalize(input_dict['rgb']),
        masks['rgb'],
        image_size=image_size,
        mask_value=1.0
    )[0].permute(1, 2, 0).detach().cpu()
    # Masked depth patches are set to NaN (presumably so imshow leaves them blank).
    masked_depth = get_masked_image(
        input_dict['depth'],
        masks['depth'],
        image_size=image_size,
        mask_value=np.nan
    )[0, 0].detach().cpu()
    pred_rgb = get_pred_with_input(
        denormalize(input_dict['rgb']),
        denormalize(preds['rgb']).clamp(0, 1),
        masks['rgb'],
        image_size=image_size
    )[0].permute(1, 2, 0).detach().cpu()
    pred_depth = get_pred_with_input(
        input_dict['depth'],
        preds['depth'],
        masks['depth'],
        image_size=image_size
    )[0, 0].detach().cpu()

    fig = plt.figure(figsize=(10, 10))
    try:
        grid = ImageGrid(fig, 111, nrows_ncols=(3, 3), axes_pad=0)
        grid[0].imshow(masked_rgb)
        grid[1].imshow(pred_rgb)
        grid[2].imshow(denormalize(input_dict['rgb'])[0].permute(1, 2, 0).detach().cpu())
        grid[3].imshow(masked_depth)
        grid[4].imshow(pred_depth)
        grid[5].imshow(input_dict['depth'][0, 0].detach().cpu())
        plot_semseg_gt_masked(input_dict, masks['semseg'], grid[6], mask_value=1.0, image_size=image_size)
        plot_semseg_pred_masked(input_dict['rgb'], preds['semseg'], input_dict['semseg'], masks['semseg'], grid[7], image_size=image_size)
        plot_semseg_gt(input_dict, grid[8], image_size=image_size)
        for ax in grid:
            ax.set_xticks([])
            ax.set_yticks([])
        fontsize = 16
        grid[0].set_title('Masked inputs', fontsize=fontsize)
        grid[1].set_title('MultiMAE predictions', fontsize=fontsize)
        grid[2].set_title('Original Reference', fontsize=fontsize)
        grid[0].set_ylabel('RGB', fontsize=fontsize)
        grid[3].set_ylabel('Depth', fontsize=fontsize)
        grid[6].set_ylabel('Semantic', fontsize=fontsize)
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
    finally:
        # Always release the figure: pyplot keeps every open figure alive, which
        # would leak memory in this long-running server if plotting raised.
        plt.close(fig)
def _load_rgb(img_path, image_size):
    """Open `img_path` as RGB, center-crop it to a square and resize it to `image_size`.

    Returns a CHW float tensor produced by TF.to_tensor.
    """
    # Force 3 channels: RGBA, palette or grayscale uploads would otherwise give
    # 4- or 1-channel tensors that break the ImageNet normalization and the
    # pseudo-labelers. The `with` block closes the uploaded file; convert()
    # returns a converted copy that stays usable afterwards.
    with Image.open(img_path) as pil_img:
        rgb = pil_img.convert('RGB')
    img = TF.center_crop(TF.to_tensor(rgb), min(rgb.size))
    return TF.resize(img, image_size, interpolation=TF.InterpolationMode.BICUBIC)


def _normalize_depth(depth):
    """Standardize `depth` using the mean/variance of its central 80% of values.

    The lowest and highest 10% (by sorted position) are excluded from the
    statistics so outliers do not dominate them.
    """
    sorted_vals = torch.sort(depth.flatten())[0]
    n = sorted_vals.shape[0]
    trunc_depth = sorted_vals[int(0.1 * n): int(0.9 * n)]
    # The 1e-6 term avoids division by zero for (near-)constant depth maps.
    return (depth - trunc_depth.mean()[None, None, None]) / torch.sqrt(trunc_depth.var()[None, None, None] + 1e-6)


def _sample_task_masks(num_rgb, num_depth, num_semseg, num_patches):
    """Build per-domain token masks (1 = masked, 0 = visible) on `device`.

    For each domain, `num_*` randomly chosen token positions (via torch.randperm,
    drawn in rgb, depth, semseg order) are made visible.
    """
    task_masks = {domain: torch.ones(1, num_patches).long().to(device) for domain in DOMAINS}
    for domain, num_visible in (('rgb', num_rgb), ('depth', num_depth), ('semseg', num_semseg)):
        selected_idxs = torch.randperm(num_patches)[:num_visible]
        task_masks[domain][:, selected_idxs] = 0
    return task_masks


def inference(img, num_tokens, manual_mode, num_rgb, num_depth, num_semseg, seed):
    """Pseudo-label an uploaded image, run MultiMAE on a subset of its tokens and plot the result.

    Args:
        img: path to the input image (the Gradio Image input uses type='filepath').
        num_tokens: percentage (0-100) of all tokens, over RGB, depth and semseg,
            kept visible in random mode.
        manual_mode: if true, sample masks from the per-modality percentages and
            `seed` instead of `num_tokens`.
        num_rgb, num_depth, num_semseg: per-modality percentages (0-100) of
            visible tokens; manual mode only.
        seed: random seed for the manual-mode mask sampling.

    Returns:
        Path of the saved prediction figure ('output.png').
    """
    image_size = 224  # Train resolution
    patch_size = 16
    num_patches = (image_size // patch_size) ** 2  # 196 tokens per domain

    # Sliders give percentages; convert them to token counts (588 tokens in total).
    num_tokens = int(len(DOMAINS) * num_patches * num_tokens / 100.0)
    num_rgb = int(num_patches * num_rgb / 100.0)
    num_depth = int(num_patches * num_depth / 100.0)
    num_semseg = int(num_patches * num_semseg / 100.0)

    # Center crop and resize RGB, then predict depth and semseg pseudo labels
    img = _load_rgb(img, image_size)
    depth = predict_depth(img)
    semseg = predict_semseg(img)

    # Pre-process RGB, depth and semseg to the MultiMAE input format
    input_dict = {}
    input_dict['rgb'] = TF.normalize(img, mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD).unsqueeze(0)
    input_dict['depth'] = _normalize_depth(depth).unsqueeze(0)
    # Downsample semantic segmentation to the semseg adapter's stride_level of 4
    stride = 4
    input_dict['semseg'] = TF.resize(
        semseg.unsqueeze(0),
        (semseg.shape[0] // stride, semseg.shape[1] // stride),
        interpolation=TF.InterpolationMode.NEAREST,
    )
    input_dict = {k: v.to(device) for k, v in input_dict.items()}

    if not manual_mode:
        # Randomly sample masks
        torch.manual_seed(int(time.time()))  # Random mode is random
        preds, masks = multimae.forward(
            input_dict,
            mask_inputs=True,  # True if forward pass should sample random masks
            num_encoded_tokens=num_tokens,
            alphas=1.0
        )
    else:
        # Randomly sample masks using the specified number of tokens per modality
        torch.manual_seed(int(seed))  # change seed to resample new mask
        task_masks = _sample_task_masks(num_rgb, num_depth, num_semseg, num_patches)
        preds, masks = multimae.forward(
            input_dict,
            mask_inputs=True,
            task_masks=task_masks
        )

    preds = {domain: pred.detach().cpu() for domain, pred in preds.items()}
    masks = {domain: mask.detach().cpu() for domain, mask in masks.items()}
    plot_predictions(input_dict, preds, masks, image_size=image_size)
    return 'output.png'
# Texts and styling for the Gradio interface.
title = "MultiMAE"
description = (
    "Gradio demo for MultiMAE: Multi-modal Multi-task Masked Autoencoders. "
    "Upload your own images or try one of the examples below to explore the multi-modal masked reconstruction of a pre-trained MultiMAE model. "
    "Uploaded images are pseudo labeled using a DPT trained on Omnidata depth, and a Mask2Former trained on COCO. "
    "Choose the percentage of visible tokens using the sliders below and see how MultiMAE reconstructs the modalities!"
)
article = (
    "<p style='text-align: center'>"
    "<a href='https://arxiv.org/abs/2204.01678' target='_blank'>MultiMAE: Multi-modal Multi-task Masked Autoencoders</a>"
    " | "
    "<a href='https://github.com/EPFL-VILAB/MultiMAE' target='_blank'>Github Repo</a>"
    "</p>"
)
# Fix the height of the output image element.
css = '.output-image{height: 713px !important}'
# Example images, presumably shipped alongside this script (they were originally
# downloaded from https://i.imgur.com/<file name>). Each row lists the inference()
# arguments: image, % tokens, manual mode, % RGB, % depth, % semseg, seed.
_EXAMPLE_FILES = ('c9ObJdK.jpg', 'KTKgYKi.jpg', 'lWYuRI7.jpg')
examples = [[file_name, 15, False, 15, 15, 15, 0] for file_name in _EXAMPLE_FILES]
# Build the Gradio UI and launch it with queuing enabled. The input components
# are passed positionally to inference(img, num_tokens, manual_mode, num_rgb,
# num_depth, num_semseg, seed), so their order must match that signature.
# NOTE(review): gr.inputs / gr.outputs and launch(enable_queue=...) are legacy
# Gradio APIs that were removed in Gradio 4 - check the pinned Gradio version
# before upgrading.
_percent_slider = partial(gr.inputs.Slider, default=15, step=0.1, minimum=0, maximum=100)
demo_inputs = [
    gr.inputs.Image(label='RGB input image', type='filepath'),
    _percent_slider(label='Percentage of input tokens'),
    gr.inputs.Checkbox(label='Manual mode: Check this to manually set the number of input tokens per modality using the sliders below', default=False),
    _percent_slider(label='Percentage of RGB input tokens (for manual mode only)'),
    _percent_slider(label='Percentage of depth input tokens (for manual mode only)'),
    _percent_slider(label='Percentage of semantic input tokens (for manual mode only)'),
    gr.inputs.Number(label='Random seed: Change this to sample different masks (for manual mode only)', default=0),
]
demo_outputs = [
    gr.outputs.Image(label='MultiMAE predictions', type='filepath'),
]
demo = gr.Interface(
    fn=inference,
    inputs=demo_inputs,
    outputs=demo_outputs,
    css=css,
    title=title,
    description=description,
    article=article,
    examples=examples,
)
demo.launch(enable_queue=True)