# Copyright 2024 Adobe. All rights reserved.
"""Generate collage / flow-warped training samples from video clips.

``process_video_folder`` samples clips with ``MomentsDataset``. For each clip,
``collage_from_frames`` warps the first frame towards a later frame in two
ways:

- a dense splat driven by optical flow (``processing_utils.forward_warp``);
- a collage of per-segment affine warps (``collage_warp``).

The results are written to disk as PNGs plus downsampled coordinate grids
(``.npy``).
"""

# %%
import os
import pathlib
import sys
import time
import traceback

import cv2
import numpy as np
import torch
import torchvision
import torchvision.transforms.functional as F
import tqdm
from kornia.filters.median import MedianBlur
from PIL import Image
from torch.utils.data import DataLoader
from torchvision.utils import save_image

from moments_dataset import MomentsDataset

try:
    from processing_utils import aggregate_frames
    import processing_utils
except Exception as e:
    print(e)
    print('process failed')
    traceback.print_exc()
    # sys.exit (not the interactive-only ``exit`` builtin) so that a failed
    # setup is reported to the shell with a non-zero status.
    sys.exit(1)

# Median filter applied to the aggregated optical flow before it is
# thresholded and used for warping.
median_filter = MedianBlur(kernel_size=(15, 15))


# %%
def load_image(img_path, resize_size=None, crop_size=None):
    """Load an image file as a one-element tensor batch.

    Args:
        img_path: path of the image to open with PIL.
        resize_size: optional size passed to ``torchvision`` ``resize``.
        crop_size: optional size passed to ``torchvision`` ``center_crop``
            (applied after the resize).

    Returns:
        Tensor of shape (1, C, H, W) with the dtype produced by
        ``pil_to_tensor`` (uint8 for ordinary 8-bit images).
    """
    # Use the context manager so the file handle PIL keeps open is released
    # once the pixels have been copied into a tensor.
    with Image.open(img_path) as img_pil:
        img_frames = F.pil_to_tensor(img_pil)

    if resize_size:
        img_frames = F.resize(img_frames, resize_size)
    if crop_size:
        img_frames = F.center_crop(img_frames, crop_size)

    return torch.unsqueeze(img_frames, dim=0)


def get_grid(size):
    """Return a (size, size, 2) integer array of pixel coordinates.

    Channel 0 holds the row (y) index and channel 1 the column (x) index, so
    ``grid[i, j] == (i, j)``.
    """
    rows, cols = np.meshgrid(np.arange(size), np.arange(size), indexing='ij')
    return np.stack([rows, cols], axis=-1)


def _select_target_index(frames_t):
    """Choose how far ahead the target frame is; return ``(tgt_idx, flow)``.

    Optical flow from frame 0 is aggregated one frame at a time, up to a
    random depth. The target advances while the 90th-percentile flow
    magnitude stays <= 300. Returns None when the clip shows too little
    motion.
    """
    # decide how deep you would go
    tgt_idx_guess = np.random.randint(1, min(len(frames_t), 20))
    tgt_idx = 1
    flow = None
    pairwise_flows = []
    unsmoothed_agg = None
    start_time = time.time()

    for cur_idx in range(1, tgt_idx_guess + 1):
        # Passing the previous pairwise flows / unsmoothed aggregate back in
        # lets aggregate_frames reuse work from earlier iterations.
        cur_flow, pairwise_flows = aggregate_frames(
            frames_t[:cur_idx + 1], pairwise_flows, unsmoothed_agg)
        unsmoothed_agg = cur_flow.clone()
        agg_cur_flow = median_filter(cur_flow)

        flow_norm = torch.norm(agg_cur_flow.squeeze(), dim=0).flatten()
        flow_90, flow_95 = np.percentile(flow_norm.cpu().numpy(), [90, 95])

        # Drop clips whose motion is still tiny after a few frames.
        if cur_idx == 5 and flow_95 < 20.0:
            print('flow is tiny :(')
            return None
        if cur_idx == tgt_idx_guess - 1 and flow_95 < 50.0:
            print('flow is tiny :(')
            return None

        if flow is None:  # first iteration
            if flow_90 < 1.0:
                # no motion in the frame. skip
                return None
            flow = agg_cur_flow

        # NOTE: threshold could possibly be raised.
        if flow_90 > 300:
            # Motion became too large: keep the previous target and its flow.
            break
        tgt_idx = cur_idx
        flow = agg_cur_flow

    print('time guessing idx', time.time() - start_time)
    return tgt_idx, flow


def collage_from_frames(frames_t):
    """Build one collage training sample from a clip of frames.

    The clip is reversed with probability 1/2, and a target frame is chosen
    by ``_select_target_index``. Frame 0 is then warped towards the target
    twice:

    - densely, with ``processing_utils.forward_warp``;
    - as a collage of per-segment affine warps, with ``collage_warp``.

    Args:
        frames_t: clip tensor indexed by frame along dim 0. Each frame is
            moved to channels-last with ``moveaxis(0, -1)``, so this is
            presumably (T, C, H, W) with values in [0, 1] (TODO confirm
            against MomentsDataset). Frames must be square (H == W) because
            the coordinate grid is built from the width only.

    Returns:
        None when the clip is rejected. That happens when there are fewer
        than 2 frames, too little motion, or the flow warp covers < 60% of
        the image.

        Otherwise a tuple ``(src_frame, tgt_frame, splatted, composite,
        flow_warping_mask, collage_mask)``:

        - ``splatted`` and ``composite`` are channels-last numpy arrays.
          Their first 3 channels are RGB and the remaining 2 are the warped
          coordinate grid.
        - The two masks are boolean numpy arrays.
    """
    if len(frames_t) < 2:
        # No later frame to warp towards (randint(1, 1) would raise).
        return None

    # decide forward or backward
    if np.random.randint(0, 2) == 0:
        frames_t = frames_t.flip(0)

    selection = _select_target_index(frames_t)
    if selection is None:
        return None
    tgt_idx, flow = selection
    src_frame = frames_t[0]
    tgt_frame = frames_t[tgt_idx]

    _, flow_warping_mask = processing_utils.forward_warp(
        src_frame, tgt_frame, flow, grid=None, alpha_mask=None)
    flow_warping_mask = flow_warping_mask.squeeze().cpu().numpy() > 0.5
    if np.mean(flow_warping_mask) < 0.6:
        return None

    # `* 1.0` copies, so later arithmetic never touches the tensor's memory.
    src_array = src_frame.moveaxis(0, -1).cpu().numpy() * 1.0

    depth_start = time.time()
    depth = get_depth_from_array(src_frame)
    print('time getting depth', time.time() - depth_start)

    src_array_uint = (src_array * 255.0).astype(np.uint8)
    segments = processing_utils.mask_generator.generate(src_array_uint)

    # Coordinate grid normalised by the width; shape (size, size, 2).
    size = src_array.shape[1]
    grid_np = get_grid(size).astype(np.float16) / size
    grid_t = torch.tensor(grid_np).moveaxis(-1, 0)

    collage, canvas_alpha, lost_alpha = collage_warp(
        src_array, flow.squeeze(), depth, segments, grid_array=grid_np)
    lost_alpha_t = torch.tensor(lost_alpha).squeeze().unsqueeze(0)
    # Only splat source pixels that the collage did not already move.
    warping_alpha = (lost_alpha_t < 0.5).float()

    rgb_grid_splatted, actual_warped_mask = processing_utils.forward_warp(
        src_frame, tgt_frame, flow, grid=grid_t, alpha_mask=warping_alpha)

    # basic blending: the flow splat fills wherever the collage left no alpha
    warped_src = (rgb_grid_splatted * actual_warped_mask).moveaxis(0, -1).cpu().numpy()
    canvas_alpha_mask = canvas_alpha == 0.0
    collage_mask = (canvas_alpha.squeeze()
                    + actual_warped_mask.squeeze().cpu().numpy()) > 0.5

    composite_grid = warped_src * canvas_alpha_mask + collage
    rgb_grid_splatted_np = rgb_grid_splatted.moveaxis(0, -1).cpu().numpy()

    return (src_frame, tgt_frame, rgb_grid_splatted_np, composite_grid,
            flow_warping_mask, collage_mask)


def collage_warp(rgb_array, flow, depth, segments, grid_array):
    """Compose a collage by moving image segments with per-segment affine warps.

    Segments with area <= 300 are ignored. For each remaining segment, a
    transform is fit with ``cv2.estimateAffinePartial2D`` (rotation, uniform
    scale, translation). The fit uses 50 randomly sampled
    ``(pixel, pixel + flow)`` correspondences.

    Segments are painted onto an empty canvas in ascending order of their
    median ``1 / depth``. A z-buffer test on the warped ``depth`` values
    means a canvas pixel is only overwritten by a larger ``depth`` value.
    Segments whose transform cannot be estimated are skipped.

    Args:
        rgb_array: (H, W, 3) float array of the source frame.
        flow: (2, H, W) torch tensor; channel 0 is dx, channel 1 is dy.
        depth: (H, W) CPU torch tensor from ``get_depth_from_array``. The
            original code treats it as disparity (larger = closer).
        segments: iterable of dicts with an ``'area'`` and a boolean (H, W)
            ``'segmentation'`` mask (presumably from
            ``processing_utils.mask_generator.generate``).
        grid_array: (H, W, 2) coordinate grid, concatenated to the RGB
            channels so it is warped alongside them.

    Returns:
        A tuple ``(canvas, canvas_alpha, lost_regions)``:

        - ``canvas``: the warped RGB + grid canvas, (H, W, 5).
        - ``canvas_alpha``: an (H, W, 1) accumulation of eroded warped masks.
        - ``lost_regions``: an (H, W, 1) map of source pixels claimed by a
          painted segment.
    """
    src_array = np.concatenate([rgb_array, grid_array], axis=-1)
    height, width = src_array.shape[:2]
    canvas = np.zeros_like(src_array)
    canvas_alpha = np.zeros_like(canvas[..., -1:]).astype(float)
    lost_regions = np.zeros_like(canvas[..., -1:]).astype(float)
    z_buffer = np.ones_like(depth)[..., None] * -1.0
    unsqueezed_depth = depth[..., None]
    erode_kernel = np.ones((3, 3), float)

    filtered_segments = [segment for segment in segments if segment['area'] > 300]

    avg_depths = []
    affine_transforms = []
    for segment in filtered_segments:
        seg_mask = segment['segmentation']
        # median depth (conversion from disparity)
        avg_depths.append(torch.median(1.0 / (depth[seg_mask] + 1e-6)))

        all_y, all_x = np.nonzero(seg_mask)
        rand_indices = np.random.randint(0, len(all_y), size=50)
        rand_x = all_x[rand_indices]
        rand_y = all_y[rand_indices]

        # Vectorised correspondence sampling: target = source + flow.
        src_pts = np.stack([rand_x, rand_y], axis=1).astype(np.float32)
        sampled_flow = flow[:, rand_y, rand_x].detach().cpu().numpy().T
        tgt_pts = src_pts + sampled_flow.astype(np.float32)

        affine_trans, _ = cv2.estimateAffinePartial2D(src_pts, tgt_pts)
        affine_transforms.append(affine_trans)

    for idx in sorted(range(len(avg_depths)), key=lambda i: avg_depths[i]):
        affine_trans = affine_transforms[idx]
        if affine_trans is None:
            # Estimation failed; leave this segment to the dense flow warp
            # (it is not marked in lost_regions).
            continue

        # Source pixels already taken by an earlier segment are not reused.
        alpha_mask = (filtered_segments[idx]['segmentation'][..., None]
                      * (lost_regions < 0.5).astype(float))
        src_rgba = np.concatenate([src_array, alpha_mask, unsqueezed_depth], axis=-1)
        warp_dst = cv2.warpAffine(src_rgba, affine_trans, (width, height))
        warped_rgb = warp_dst[..., :-2]
        warped_alpha = warp_dst[..., -2:-1]
        warped_depth = warp_dst[..., -1:]

        # Keep pixels that are opaque and pass the z-buffer test.
        warped_mask = np.logical_and(
            warped_alpha > 0.5, warped_depth > z_buffer).astype(float)

        canvas_alpha += cv2.erode(warped_mask, erode_kernel)[..., None]
        lost_regions += alpha_mask
        # TODO check if need to dilate here
        canvas = canvas * (1.0 - warped_mask) + warped_mask * warped_rgb
        z_buffer = z_buffer * (1.0 - warped_mask) + warped_mask * warped_depth

    return canvas, canvas_alpha, lost_regions


def get_depth_from_array(img_t):
    """Run the ``processing_utils.midas`` depth model on one frame.

    Args:
        img_t: channels-first frame tensor, presumably (3, H, W) with values
            in [0, 1] (TODO confirm).

    Returns:
        CPU tensor with the model prediction resized (bicubic) to (H, W).
        Requires CUDA: the input batch is moved with ``.cuda()``.
    """
    # Non in-place scaling, so the tensor's shared memory is never modified.
    img_arr = (img_t.moveaxis(0, -1).cpu().numpy() * 255.0).astype(np.uint8)
    input_batch = processing_utils.depth_transform(img_arr).cuda()

    with torch.no_grad():
        prediction = processing_utils.midas(input_batch)
        prediction = torch.nn.functional.interpolate(
            prediction.unsqueeze(1),
            size=img_arr.shape[:2],
            mode="bicubic",
            align_corners=False,
        ).squeeze()

    return prediction.cpu()


# %%
def main():
    """Process ``./example_videos`` into ``./processed_data``."""
    print('starting main')
    video_folder = './example_videos'
    save_dir = pathlib.Path('./processed_data')
    process_video_folder(video_folder, save_dir)


def _save_unit_image(array, path):
    """Save a [0, 1]-valued array as an 8-bit image.

    Values are clipped to [0, 255] after scaling so out-of-range pixels
    saturate instead of wrapping around in the uint8 cast.
    """
    pixels = np.clip(array * 255, 0, 255).astype(np.uint8)
    Image.fromarray(pixels).save(str(path))


def _resize_grid(grid, size=(64, 64)):
    """Resize a channels-last (H, W, C) grid to a channels-first (C, *size) array."""
    grid_t = torch.tensor(grid).moveaxis(-1, 0)
    return F.resize(grid_t, size).cpu().numpy()


def _save_sample(save_dir, sample_id, sample):
    """Write one ``collage_from_frames`` result to *save_dir* under a zero-padded id."""
    src_image, tgt_image, splatted, collage, flow_mask, collage_mask = sample
    id_str = f'{sample_id:08d}'

    save_image(src_image, str(save_dir / f'src_{id_str}.png'))
    save_image(tgt_image, str(save_dir / f'tgt_{id_str}.png'))
    # RGB lives in the first three channels; the warped coordinate grid follows.
    _save_unit_image(collage[..., :3], save_dir / f'composite_{id_str}.png')
    _save_unit_image(splatted[..., :3], save_dir / f'flow_warped_{id_str}.png')
    _save_unit_image(flow_mask.astype(float), save_dir / f'flow_mask_{id_str}.png')
    _save_unit_image(collage_mask.astype(float), save_dir / f'composite_mask_{id_str}.png')

    np.save(str(save_dir / f'flow_warped_grid_{id_str}.npy'),
            _resize_grid(splatted[..., 3:].astype(np.float16)))
    np.save(str(save_dir / f'composite_grid_{id_str}.npy'),
            _resize_grid(collage[..., 3:].astype(np.float16)))


def process_video_folder(video_folder, save_dir):
    """Generate collage samples for clips sampled from *video_folder*.

    Args:
        video_folder: directory passed to ``MomentsDataset(videos_folder=...)``.
        save_dir: output directory (str or ``pathlib.Path``), created if
            missing. Successful samples are numbered consecutively from 1.
    """
    save_dir = pathlib.Path(save_dir)
    os.makedirs(save_dir, exist_ok=True)

    dataset = MomentsDataset(videos_folder=video_folder, num_frames=20, samples_per_video=5)
    batch_size = 4
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    all_counter = 0
    success_counter = 0
    with torch.no_grad():
        # len(dataloader) also counts the final partial batch, unlike
        # len(dataset) // batch_size.
        for batch in tqdm.tqdm(dataloader, total=len(dataloader)):
            for frames in batch["frames"]:
                collage_init_time = time.time()
                out = collage_from_frames(frames)
                print('collage processing time', time.time() - collage_init_time)
                all_counter += 1
                if out is None:
                    continue
                success_counter += 1
                # Per-sample arrays are locals of the helper and are freed on
                # return, so no manual `del` bookkeeping is needed.
                _save_sample(save_dir, success_counter, out)

    print('saved', success_counter, 'of', all_counter, 'samples')


# %%
if __name__ == '__main__':
    try:
        main()
    except Exception as e:
        print(e)
        print('process failed')
        traceback.print_exc()
        sys.exit(1)