import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
import gradio as gr
from unimatch.unimatch import UniMatch
from utils.flow_viz import flow_to_image
from dataloader.stereo import transforms
from utils.visualization import vis_disparity
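
# ImageNet normalization statistics, used by the stereo preprocessing transform below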
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
@torch.no_grad()
def inference(image1, image2, task='flow'):
"""Inference on an image pair for optical flow or stereo disparity prediction"""
model = UniMatch(feature_channels=128,
num_scales=2,
upsample_factor=4,
ffn_dim_expansion=4,
num_transformer_layers=6,
reg_refine=True,
task=task)
model.eval()
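
    # Load the task-specific pretrained weights (GMFlow for optical flow, GMStereo for stereo).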
    if task == 'flow':
        checkpoint_path = 'pretrained/gmflow-scale2-regrefine6-mixdata-train320x576-4e7b215d.pth'
    else:
        checkpoint_path = 'pretrained/gmstereo-scale2-regrefine3-resumeflowthings-mixdata-train320x640-ft640x960-e4e291fd.pth'

    checkpoint = torch.load(checkpoint_path, map_location='cpu')  # load on CPU so the demo also works without a GPU
    model.load_state_dict(checkpoint['model'], strict=True)
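
    # Inference settings for the released 2-scale models with regression refinement.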
    padding_factor = 32
    attn_type = 'swin' if task == 'flow' else 'self_swin2d_cross_swin1d'
    attn_splits_list = [2, 8]
    corr_radius_list = [-1, 4]
    prop_radius_list = [-1, 1]
    num_reg_refine = 6 if task == 'flow' else 3

    # smaller inference size for faster speed
    max_inference_size = [384, 768] if task == 'flow' else [640, 960]

    transpose_img = False
    image1 = np.array(image1).astype(np.float32)
    image2 = np.array(image2).astype(np.float32)

    if len(image1.shape) == 2:  # gray image
        image1 = np.tile(image1[..., None], (1, 1, 3))
        image2 = np.tile(image2[..., None], (1, 1, 3))
    else:
        image1 = image1[..., :3]
        image2 = image2[..., :3]
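
    # Flow uses raw RGB tensors; stereo applies ImageNet normalization via the stereo transforms.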
    if task == 'flow':
        image1 = torch.from_numpy(image1).permute(2, 0, 1).float().unsqueeze(0)
        image2 = torch.from_numpy(image2).permute(2, 0, 1).float().unsqueeze(0)
    else:
        val_transform_list = [transforms.ToTensor(),
                              transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
                              ]
        val_transform = transforms.Compose(val_transform_list)
        sample = {'left': image1, 'right': image2}
        sample = val_transform(sample)
        image1 = sample['left'].unsqueeze(0)  # [1, 3, H, W]
        image2 = sample['right'].unsqueeze(0)  # [1, 3, H, W]
    # the model is trained with size: width > height
    if task == 'flow' and image1.size(-2) > image1.size(-1):
        image1 = torch.transpose(image1, -2, -1)
        image2 = torch.transpose(image2, -2, -1)
        transpose_img = True
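
    # Round the input size up to a multiple of the padding factor, then cap it at the maximum inference size.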
    nearest_size = [int(np.ceil(image1.size(-2) / padding_factor)) * padding_factor,
                    int(np.ceil(image1.size(-1) / padding_factor)) * padding_factor]
    inference_size = [min(max_inference_size[0], nearest_size[0]), min(max_inference_size[1], nearest_size[1])]
    assert isinstance(inference_size, (list, tuple))

    ori_size = image1.shape[-2:]
    # resize before inference
    if inference_size[0] != ori_size[0] or inference_size[1] != ori_size[1]:
        image1 = F.interpolate(image1, size=inference_size, mode='bilinear',
                               align_corners=True)
        image2 = F.interpolate(image2, size=inference_size, mode='bilinear',
                               align_corners=True)
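
    # Run UniMatch; the last entry of 'flow_preds' is the final refined prediction.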
    results_dict = model(image1, image2,
                         attn_type=attn_type,
                         attn_splits_list=attn_splits_list,
                         corr_radius_list=corr_radius_list,
                         prop_radius_list=prop_radius_list,
                         num_reg_refine=num_reg_refine,
                         task=task,
                         )

    flow_pr = results_dict['flow_preds'][-1]  # [1, 2, H, W] or [1, H, W]
    # resize back to the original resolution; flow values are rescaled per axis, disparity by the width ratio
    if task == 'flow':
        if inference_size[0] != ori_size[0] or inference_size[1] != ori_size[1]:
            flow_pr = F.interpolate(flow_pr, size=ori_size, mode='bilinear',
                                    align_corners=True)
            flow_pr[:, 0] = flow_pr[:, 0] * ori_size[-1] / inference_size[-1]
            flow_pr[:, 1] = flow_pr[:, 1] * ori_size[-2] / inference_size[-2]
    else:
        pred_disp = flow_pr  # [1, H, W]
        if inference_size[0] != ori_size[0] or inference_size[1] != ori_size[1]:
            pred_disp = F.interpolate(pred_disp.unsqueeze(1), size=ori_size,
                                      mode='bilinear',
                                      align_corners=True).squeeze(1)  # [1, H, W]
            pred_disp = pred_disp * ori_size[-1] / float(inference_size[-1])
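
    # Convert the prediction to a color visualization (flow color wheel / disparity colormap).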
    if task == 'flow':
        if transpose_img:
            flow_pr = torch.transpose(flow_pr, -2, -1)
        flow = flow_pr[0].permute(1, 2, 0).cpu().numpy()  # [H, W, 2]
        output = flow_to_image(flow)  # [H, W, 3]
    else:
        disp = pred_disp[0].cpu().numpy()
        output = vis_disparity(disp, return_rgb=True)

    return Image.fromarray(output)
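

# The inference function can also be called directly, without the Gradio UI.
# Sketch (uses the bundled demo images; the output filename is arbitrary):
#   result = inference(Image.open('demo/flow_sintel_cave_3_frame_0049.png'),
#                      Image.open('demo/flow_sintel_cave_3_frame_0050.png'),
#                      task='flow')
#   result.save('flow_vis.png')
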
title = "UniMatch"
description = "<p style='text-align: center'>Optical flow and stereo matching demo for <a href='https://haofeixu.github.io/unimatch/' target='_blank'>Unifying Flow, Stereo and Depth Estimation</a> | <a href='https://arxiv.org/abs/2211.05783' target='_blank'>Paper</a> | <a href='https://github.com/autonomousvision/unimatch' target='_blank'>Code</a> | <a href='https://colab.research.google.com/drive/1r5m-xVy3Kw60U-m5VB-aQ98oqqg_6cab?usp=sharing' target='_blank'>Colab</a><br>Simply upload your images or click one of the provided examples.<br>The <strong>first three</strong> examples are video frames for <strong>flow</strong> task, and the <strong>last three</strong> are stereo pairs for <strong>stereo</strong> task.<br><strong>Select the task type according to your input images</strong>.</p>"
examples = [
    ['demo/flow_kitti_test_000197_10.png', 'demo/flow_kitti_test_000197_11.png'],
    ['demo/flow_sintel_cave_3_frame_0049.png', 'demo/flow_sintel_cave_3_frame_0050.png'],
    ['demo/flow_davis_skate-jump_00059.jpg', 'demo/flow_davis_skate-jump_00060.jpg'],
    ['demo/stereo_drivingstereo_test_2018-07-11-14-48-52_2018-07-11-14-58-34-673_left.jpg',
     'demo/stereo_drivingstereo_test_2018-07-11-14-48-52_2018-07-11-14-58-34-673_right.jpg'],
    ['demo/stereo_middlebury_plants_im0.png', 'demo/stereo_middlebury_plants_im1.png'],
    ['demo/stereo_holopix_left.png', 'demo/stereo_holopix_right.png'],
]
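
# Gradio demo: two image inputs plus a task selector, one color-coded flow/disparity output.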
gr.Interface(
    inference,
    [gr.Image(type="pil", label="Image1"),
     gr.Image(type="pil", label="Image2"),
     gr.Radio(choices=['flow', 'stereo'], value='flow', label='Task')],
    gr.Image(type="pil", label="Flow/Disparity"),
    title=title,
    description=description,
    examples=examples,
    thumbnail="https://haofeixu.github.io/unimatch/resources/teaser.svg",
    allow_flagging="auto",
).launch(debug=True)