import numpy as np |
import torchvision |
import cv2 |
import tqdm |
import torchvision.transforms.functional as F |
from PIL import Image |
from torchvision.utils import save_image |
import time |
import os |
import pathlib |
from torch.utils.data import DataLoader |
from kornia.filters.median import MedianBlur |
# 15x15 median filter used by collage_from_frames to smooth the aggregated flow.
median_filter = MedianBlur(kernel_size=(15,15))
from moments_dataset import MomentsDataset
try:
    from processing_utils import aggregate_frames
    import processing_utils
except Exception as e:
    print(e)
    print('process failed')
    # `exit()` is injected by the `site` module (absent under `python -S`) and
    # exits with status 0, so a failed import looked like success to callers.
    # Exit with a non-zero status instead.
    raise SystemExit(1) from e
import torch |
def load_image(img_path, resize_size=None, crop_size=None):
    """Load an image file as a tensor with a leading batch dimension of 1.

    Args:
        img_path: path to the image file.
        resize_size: optional size forwarded to torchvision's ``resize``;
            the resize is skipped when this is falsy.
        crop_size: optional size forwarded to torchvision's ``center_crop``;
            the crop is skipped when this is falsy.

    Returns:
        Tensor of shape (1, C, H, W). ``pil_to_tensor`` does no rescaling, so
        the dtype follows the image (e.g. uint8 for 8-bit images).
    """
    # Image.open keeps the file open for lazy decoding; the context manager
    # guarantees the handle is released once the pixels are copied out.
    with Image.open(img_path) as img_pil:
        frames = torchvision.transforms.functional.pil_to_tensor(img_pil)
    if resize_size:
        frames = torchvision.transforms.functional.resize(frames, resize_size)
    if crop_size:
        frames = torchvision.transforms.functional.center_crop(frames, crop_size)
    return torch.unsqueeze(frames, dim=0)
def get_grid(size):
    """Return an integer array of shape (size, size, 2) with ``out[i, j] == (i, j)``.

    Channel 0 holds the row index and channel 1 the column index.
    """
    row_idx, col_idx = np.meshgrid(np.arange(size), np.arange(size), indexing='ij')
    return np.stack([row_idx, col_idx], -1)
def collage_from_frames(frames_t):
    """Build one training sample (source/target frames plus warped composites) from a clip.

    The clip is randomly reversed. Flow from frame 0 to later frames is then
    accumulated with ``aggregate_frames`` and median-filtered, to pick a target
    frame whose motion is large but not extreme. Frame 0 is warped towards that
    target twice: by splatting with the dense flow
    (``processing_utils.forward_warp``) and as a collage of per-segment affine
    warps (``collage_warp``). The two results are composited.

    Args:
        frames_t: clip tensor where ``frames_t[i]`` is frame ``i`` in
            channel-first layout (frames are ``moveaxis(0, -1)``-ed to HWC
            below). Values are presumably in [0, 1], since they are scaled by
            255 before a uint8 cast. TODO confirm against the dataset.
            Frames are assumed square: the coordinate grid is built from the
            width only.

    Returns:
        ``None`` when the clip is rejected (fewer than two frames, too little
        motion, or less than 60% of frame 0 surviving the flow warp).
        Otherwise a tuple
        ``(src_frame, tgt_frame, splatted, composite, flow_mask, collage_mask)``:
        the two frame tensors, the HWC splatted array, the HWC composite
        array, and two boolean numpy masks.
    """
    # A source plus at least one candidate target is required; with a single
    # frame np.random.randint(1, 1) below would raise ValueError.
    if len(frames_t) < 2:
        return None
    # Randomly reverse the clip so both motion directions are sampled.
    if np.random.randint(0, 2) == 0:
        frames_t = frames_t.flip(0)
    tgt_idx_guess = np.random.randint(1, min(len(frames_t), 20))
    tgt_idx = 1
    pairwise_flows = []
    flow = None
    init_time = time.time()
    unsmoothed_agg = None
    for cur_idx in range(1, tgt_idx_guess + 1):
        # Reuse previously computed pairwise flows / aggregate for the next step.
        cur_flow, pairwise_flows = aggregate_frames(frames_t[:cur_idx + 1], pairwise_flows, unsmoothed_agg)
        unsmoothed_agg = cur_flow.clone()
        agg_cur_flow = median_filter(cur_flow)
        flow_norm = torch.norm(agg_cur_flow.squeeze(), dim=0).flatten()
        # One host transfer and one call for both percentiles (the 90th was
        # previously computed twice).
        flow_90, flow_95 = np.percentile(flow_norm.cpu().numpy(), [90, 95])
        # Early rejection of clips that still barely move.
        if cur_idx == 5 and flow_95 < 20.0:
            print('flow is tiny :(')
            return None
        if cur_idx == tgt_idx_guess - 1 and flow_95 < 50.0:
            print('flow is tiny :(')
            return None
        if flow is None:
            # First iteration: reject near-static clips outright.
            if flow_90 < 1.0:
                return None
            flow = agg_cur_flow
        # Advance the target while the 90th-percentile flow magnitude stays
        # <= 300; stop at the first frame that exceeds it.
        if flow_90 <= 300:
            tgt_idx = cur_idx
            flow = agg_cur_flow
        else:
            break
    final_time = time.time()
    print('time guessing idx', final_time - init_time)

    # Reject pairs where too little of frame 0 survives a dense forward warp.
    _, flow_warping_mask = processing_utils.forward_warp(frames_t[0], frames_t[tgt_idx], flow, grid=None, alpha_mask=None)
    # .cpu() matches the handling of the other forward_warp outputs below:
    # .numpy() fails on non-CPU tensors and .cpu() is a no-op on CPU ones.
    flow_warping_mask = flow_warping_mask.squeeze().cpu().numpy() > 0.5
    if np.mean(flow_warping_mask) < 0.6:
        return None

    src_array = frames_t[0].moveaxis(0, -1).cpu().numpy() * 1.0
    init_time = time.time()
    depth = get_depth_from_array(frames_t[0])
    finish_time = time.time()
    print('time getting depth', finish_time - init_time)
    src_array_uint = (src_array * 255.0).astype(np.uint8)
    segments = processing_utils.mask_generator.generate(src_array_uint)

    # Normalized (row, col) grid warped alongside the RGB channels, presumably
    # so the outputs record where each pixel originated.
    size = src_array.shape[1]
    grid_np = get_grid(size).astype(np.float16) / size
    grid_t = torch.tensor(grid_np).moveaxis(-1, 0)

    collage, canvas_alpha, lost_alpha = collage_warp(src_array, flow.squeeze(), depth, segments, grid_array=grid_np)
    # Only splat source pixels that no collage segment has already used.
    lost_alpha_t = torch.tensor(lost_alpha).squeeze().unsqueeze(0)
    warping_alpha = (lost_alpha_t < 0.5).float()
    rgb_grid_splatted, actual_warped_mask = processing_utils.forward_warp(frames_t[0], frames_t[tgt_idx], flow, grid=grid_t, alpha_mask=warping_alpha)
    warped_src = (rgb_grid_splatted * actual_warped_mask).moveaxis(0, -1).cpu().numpy()
    # Fill collage holes (canvas_alpha == 0) with the splatted pixels.
    canvas_alpha_mask = canvas_alpha == 0.0
    collage_mask = canvas_alpha.squeeze() + actual_warped_mask.squeeze().cpu().numpy()
    collage_mask = collage_mask > 0.5
    composite_grid = warped_src * canvas_alpha_mask + collage
    rgb_grid_splatted_np = rgb_grid_splatted.moveaxis(0, -1).cpu().numpy()
    return frames_t[0], frames_t[tgt_idx], rgb_grid_splatted_np, composite_grid, flow_warping_mask, collage_mask
def collage_warp(rgb_array, flow, depth, segments, grid_array):
    """Warp each image segment with its own 2D similarity transform and composite them.

    For every segment with ``area > 300``, 50 random pixels (sampled with
    replacement) are displaced by ``flow``. ``cv2.estimateAffinePartial2D`` then
    fits a rotation + uniform scale + translation to those correspondences.
    Segments are pasted onto a canvas in ascending order of
    ``median(1 / depth)``, i.e. segments with larger ``depth`` values first.
    A z-buffer on the warped ``depth`` (larger value wins) resolves overlaps.

    Args:
        rgb_array: HWC float image.
        flow: tensor indexed as ``flow[c, y, x]``; channel 0 is dx and channel 1 is dy.
        depth: HxW tensor (e.g. from ``get_depth_from_array``).
        segments: iterable of dicts with keys ``'area'`` and ``'segmentation'``
            (an HxW mask usable as a boolean index).
        grid_array: HxWx2 coordinate grid, concatenated to ``rgb_array`` so it
            is warped together with the colours.

    Returns:
        ``(canvas, canvas_alpha, lost_regions)``:
        - ``canvas``: HxW(C+2) composited warped RGB + grid channels.
        - ``canvas_alpha``: HxWx1 sum of the 3x3-eroded masks of pasted pixels.
        - ``lost_regions``: HxWx1 sum of the source-pixel masks consumed by pasted segments.
    """
    src_array = np.concatenate([rgb_array, grid_array], axis=-1)
    height, width = src_array.shape[:2]
    canvas = np.zeros_like(src_array)
    canvas_alpha = np.zeros_like(canvas[..., -1:]).astype(float)
    lost_regions = np.zeros_like(canvas[..., -1:]).astype(float)
    z_buffer = np.ones_like(depth)[..., None] * -1.0
    unsqueezed_depth = depth[..., None]
    erode_kernel = np.ones((3, 3), float)  # loop-invariant

    # (sort key, 2x3 affine matrix, segment mask) for every usable segment.
    layers = []
    for segment in segments:
        if segment['area'] <= 300:
            continue
        seg_mask = segment['segmentation']
        avg_depth = torch.median(1.0 / (depth[seg_mask] + 1e-6))
        all_y, all_x = np.nonzero(seg_mask)
        picks = np.random.randint(0, len(all_y), size=50)
        src_x, src_y = all_x[picks], all_y[picks]
        # Vectorized flow lookup at the sampled pixels: shape (2, 50) -> rows dx, dy.
        dxdy = flow[:, torch.as_tensor(src_y), torch.as_tensor(src_x)].cpu().numpy().astype(np.float64)
        src_xy = np.stack([src_x, src_y], axis=-1)
        src_pts = src_xy.astype(np.float32)
        tgt_pts = (src_xy + dxdy.T).astype(np.float32)
        affine_trans, _ = cv2.estimateAffinePartial2D(src_pts, tgt_pts)
        if affine_trans is None:
            # OpenCV yields no matrix when estimation fails; warpAffine would
            # raise on it, so this segment is dropped from the collage.
            continue
        layers.append((avg_depth, affine_trans, seg_mask))

    # Stable sort: equal keys keep the original segment order.
    layers.sort(key=lambda layer: layer[0])
    for _, affine_trans, seg_mask in layers:
        # Source pixels already consumed by an earlier layer are not reused.
        alpha_mask = seg_mask[..., None] * (lost_regions < 0.5).astype(float)
        # Warp colours/grid, alpha and depth together in one call.
        src_rgba = np.concatenate([src_array, alpha_mask, unsqueezed_depth], axis=-1)
        warp_dst = cv2.warpAffine(src_rgba, affine_trans, (width, height))
        warped_mask = warp_dst[..., -2:-1]
        warped_depth = warp_dst[..., -1:]
        warped_rgb = warp_dst[..., :-2]
        good_z_region = warped_depth > z_buffer
        warped_mask = np.logical_and(warped_mask > 0.5, good_z_region).astype(float)
        canvas_alpha += cv2.erode(warped_mask, erode_kernel)[..., None]
        lost_regions += alpha_mask
        canvas = canvas * (1.0 - warped_mask) + warped_mask * warped_rgb
        z_buffer = z_buffer * (1.0 - warped_mask) + warped_mask * warped_depth
    return canvas, canvas_alpha, lost_regions
def get_depth_from_array(img_t):
    """Predict a per-pixel depth map for one channel-first frame.

    The frame is turned into an HWC uint8 array. Values are multiplied by
    255, so a [0, 1] float input is presumably expected (TODO confirm). It is
    then fed through ``processing_utils.depth_transform`` and
    ``processing_utils.midas`` on the GPU. The prediction is bicubically
    resized back to the frame's height and width.

    Returns:
        CPU tensor with all singleton dims squeezed (presumably H x W).
    """
    scaled = img_t.moveaxis(0, -1).cpu().numpy() * 1.0 * 255.0
    frame_uint8 = scaled.astype(np.uint8)
    model_input = processing_utils.depth_transform(frame_uint8).cuda()
    with torch.no_grad():
        raw_depth = processing_utils.midas(model_input)
    resized = torch.nn.functional.interpolate(
        raw_depth.unsqueeze(1),
        size=frame_uint8.shape[:2],
        mode="bicubic",
        align_corners=False,
    )
    return resized.squeeze().cpu()
def main():
    """Script entry point: turn the clips in ./example_videos into samples in ./processed_data."""
    print('starting main')
    output_dir = pathlib.Path('./processed_data')
    process_video_folder('./example_videos', output_dir)
def _save_sample(save_dir, sample_idx, out):
    """Write one ``collage_from_frames`` result to ``save_dir`` under the 8-digit id ``sample_idx``."""
    src_image, tgt_image, splatted, collage, flow_mask, collage_mask = out
    # Channels 0-2 are RGB; the remaining channels are the warped coordinate grid.
    splatted_rgb = splatted[..., :3]
    splatted_grid = splatted[..., 3:].astype(np.float16)
    collage_rgb = collage[..., :3]
    collage_grid = collage[..., 3:].astype(np.float16)

    id_str = f'{sample_idx:08d}'
    save_image(src_image, str(save_dir / f'src_{id_str}.png'))
    save_image(tgt_image, str(save_dir / f'tgt_{id_str}.png'))
    Image.fromarray((collage_rgb * 255).astype(np.uint8)).save(str(save_dir / f'composite_{id_str}.png'))
    Image.fromarray((splatted_rgb * 255).astype(np.uint8)).save(str(save_dir / f'flow_warped_{id_str}.png'))
    Image.fromarray((flow_mask.astype(float) * 255).astype(np.uint8)).save(str(save_dir / f'flow_mask_{id_str}.png'))
    Image.fromarray((collage_mask.astype(float) * 255).astype(np.uint8)).save(str(save_dir / f'composite_mask_{id_str}.png'))

    # Grids are stored channel-first and downsized to 64x64.
    splatted_grid_resized = torchvision.transforms.functional.resize(torch.tensor(splatted_grid).moveaxis(-1, 0), (64, 64))
    collage_grid_resized = torchvision.transforms.functional.resize(torch.tensor(collage_grid).moveaxis(-1, 0), (64, 64))
    np.save(str(save_dir / f'flow_warped_grid_{id_str}.npy'), splatted_grid_resized.cpu().numpy())
    np.save(str(save_dir / f'composite_grid_{id_str}.npy'), collage_grid_resized.cpu().numpy())


def process_video_folder(video_folder, save_dir):
    """Generate collage training samples from every clip in ``video_folder``.

    Clips come from ``MomentsDataset(num_frames=20, samples_per_video=5)`` in
    shuffled batches of 4. Each clip accepted by ``collage_from_frames`` is
    written to ``save_dir`` as six PNGs (source, target, flow-warped,
    composite, flow mask, composite mask) plus two 64x64 ``.npy`` grids, all
    sharing a sequential 8-digit id.

    Args:
        video_folder: folder passed to ``MomentsDataset``.
        save_dir: output directory (``str`` or ``pathlib.Path``), created if missing.
    """
    save_dir = pathlib.Path(save_dir)
    all_counter = 0
    success_counter = 0
    os.makedirs(save_dir, exist_ok=True)
    dataset = MomentsDataset(videos_folder=video_folder, num_frames=20, samples_per_video=5)
    batch_size = 4
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    with torch.no_grad():
        # len(dataloader) includes the final partial batch; len(dataset)//batch_size did not.
        for batch in tqdm.tqdm(dataloader, total=len(dataloader)):
            frames_to_visualize = batch["frames"]
            for j in range(frames_to_visualize.shape[0]):
                collage_init_time = time.time()
                out = collage_from_frames(frames_to_visualize[j])
                collage_finish_time = time.time()
                print('collage processing time', collage_finish_time - collage_init_time)
                all_counter += 1
                if out is None:
                    continue
                success_counter += 1
                _save_sample(save_dir, success_counter, out)
                # Release the large per-sample arrays before the next clip is processed.
                del out
            del frames_to_visualize
    print(f'saved {success_counter} of {all_counter} samples')
if __name__ == '__main__':
    import traceback

    try:
        main()
    except Exception as e:
        print(e)
        print('process failed')
        # print(e) alone drops the traceback, and falling through exits with
        # status 0; keep the traceback and report failure via the exit code.
        traceback.print_exc()
        raise SystemExit(1) from e