Spaces:
Running
on
L40S
Running
on
L40S
""" | |
input: json file with video, audio, motion paths | |
output: igraph object with nodes containing video, audio, motion, position, velocity, axis_angle, previous, next, frame, fps | |
preprocess: | |
1. assume you have a video for one speaker in folder, listed in | |
-- video_a.mp4 | |
-- video_b.mp4 | |
run process_video.py to extract frames and audio | |
""" | |
import os | |
import smplx | |
import torch | |
import numpy as np | |
import cv2 | |
import librosa | |
import igraph | |
import json | |
import utils.rotation_conversions as rc | |
from moviepy.editor import VideoClip, AudioFileClip | |
from tqdm import tqdm | |
import imageio | |
import tempfile | |
import argparse | |
def get_motion_reps_tensor(motion_tensor, smplx_model, pose_fps=30, device='cuda'): | |
bs, n, _ = motion_tensor.shape | |
motion_tensor = motion_tensor.float().to(device) | |
motion_tensor_reshaped = motion_tensor.reshape(bs * n, 165) | |
output = smplx_model( | |
betas=torch.zeros(bs * n, 300, device=device), | |
transl=torch.zeros(bs * n, 3, device=device), | |
expression=torch.zeros(bs * n, 100, device=device), | |
jaw_pose=torch.zeros(bs * n, 3, device=device), | |
global_orient=torch.zeros(bs * n, 3, device=device), | |
body_pose=motion_tensor_reshaped[:, 3:21 * 3 + 3], | |
left_hand_pose=motion_tensor_reshaped[:, 25 * 3:40 * 3], | |
right_hand_pose=motion_tensor_reshaped[:, 40 * 3:55 * 3], | |
return_joints=True, | |
leye_pose=torch.zeros(bs * n, 3, device=device), | |
reye_pose=torch.zeros(bs * n, 3, device=device), | |
) | |
joints = output['joints'].reshape(bs, n, 127, 3)[:, :, :55, :] | |
dt = 1 / pose_fps | |
init_vel = (joints[:, 1:2] - joints[:, 0:1]) / dt | |
middle_vel = (joints[:, 2:] - joints[:, :-2]) / (2 * dt) | |
final_vel = (joints[:, -1:] - joints[:, -2:-1]) / dt | |
vel = torch.cat([init_vel, middle_vel, final_vel], dim=1) | |
position = joints | |
rot_matrices = rc.axis_angle_to_matrix(motion_tensor.reshape(bs, n, 55, 3)) | |
rot6d = rc.matrix_to_rotation_6d(rot_matrices).reshape(bs, n, 55, 6) | |
init_vel_ang = (motion_tensor[:, 1:2] - motion_tensor[:, 0:1]) / dt | |
middle_vel_ang = (motion_tensor[:, 2:] - motion_tensor[:, :-2]) / (2 * dt) | |
final_vel_ang = (motion_tensor[:, -1:] - motion_tensor[:, -2:-1]) / dt | |
angular_velocity = torch.cat([init_vel_ang, middle_vel_ang, final_vel_ang], dim=1).reshape(bs, n, 55, 3) | |
rep15d = torch.cat([position, vel, rot6d, angular_velocity], dim=3).reshape(bs, n, 55 * 15) | |
return { | |
"position": position, | |
"velocity": vel, | |
"rotation": rot6d, | |
"axis_angle": motion_tensor, | |
"angular_velocity": angular_velocity, | |
"rep15d": rep15d, | |
} | |
def get_motion_reps(motion, smplx_model, pose_fps=30): | |
gt_motion_tensor = motion["poses"] | |
n = gt_motion_tensor.shape[0] | |
bs = 1 | |
gt_motion_tensor = torch.from_numpy(gt_motion_tensor).float().to(device).unsqueeze(0) | |
gt_motion_tensor_reshaped = gt_motion_tensor.reshape(bs * n, 165) | |
output = smplx_model( | |
betas=torch.zeros(bs * n, 300).to(device), | |
transl=torch.zeros(bs * n, 3).to(device), | |
expression=torch.zeros(bs * n, 100).to(device), | |
jaw_pose=torch.zeros(bs * n, 3).to(device), | |
global_orient=torch.zeros(bs * n, 3).to(device), | |
body_pose=gt_motion_tensor_reshaped[:, 3:21 * 3 + 3], | |
left_hand_pose=gt_motion_tensor_reshaped[:, 25 * 3:40 * 3], | |
right_hand_pose=gt_motion_tensor_reshaped[:, 40 * 3:55 * 3], | |
return_joints=True, | |
leye_pose=torch.zeros(bs * n, 3).to(device), | |
reye_pose=torch.zeros(bs * n, 3).to(device), | |
) | |
joints = output["joints"].detach().cpu().numpy().reshape(n, 127, 3)[:, :55, :] | |
dt = 1 / pose_fps | |
init_vel = (joints[1:2] - joints[0:1]) / dt | |
middle_vel = (joints[2:] - joints[:-2]) / (2 * dt) | |
final_vel = (joints[-1:] - joints[-2:-1]) / dt | |
vel = np.concatenate([init_vel, middle_vel, final_vel], axis=0) | |
position = joints | |
rot_matrices = rc.axis_angle_to_matrix(gt_motion_tensor.reshape(1, n, 55, 3))[0] | |
rot6d = rc.matrix_to_rotation_6d(rot_matrices).reshape(n, 55, 6).cpu().numpy() | |
init_vel = (motion["poses"][1:2] - motion["poses"][0:1]) / dt | |
middle_vel = (motion["poses"][2:] - motion["poses"][:-2]) / (2 * dt) | |
final_vel = (motion["poses"][-1:] - motion["poses"][-2:-1]) / dt | |
angular_velocity = np.concatenate([init_vel, middle_vel, final_vel], axis=0).reshape(n, 55, 3) | |
rep15d = np.concatenate([ | |
position, | |
vel, | |
rot6d, | |
angular_velocity], | |
axis=2 | |
).reshape(n, 55*15) | |
return { | |
"position": position, | |
"velocity": vel, | |
"rotation": rot6d, | |
"axis_angle": motion["poses"], | |
"angular_velocity": angular_velocity, | |
"rep15d": rep15d, | |
"trans": motion["trans"] | |
} | |
def create_graph(json_path, smplx_model): | |
fps = 30 | |
data_meta = json.load(open(json_path, "r")) | |
graph = igraph.Graph(directed=True) | |
global_i = 0 | |
for data_item in data_meta: | |
video_path = os.path.join(data_item['video_path'], data_item['video_id'] + ".mp4") | |
# audio_path = os.path.join(data_item['audio_path'], data_item['video_id'] + ".wav") | |
motion_path = os.path.join(data_item['motion_path'], data_item['video_id'] + ".npz") | |
video_id = data_item.get("video_id", "") | |
motion = np.load(motion_path, allow_pickle=True) | |
motion_reps = get_motion_reps(motion, smplx_model) | |
position = motion_reps['position'] | |
velocity = motion_reps['velocity'] | |
trans = motion_reps['trans'] | |
axis_angle = motion_reps['axis_angle'] | |
# audio, sr = librosa.load(audio_path, sr=None) | |
# audio = librosa.resample(audio, orig_sr=sr, target_sr=16000) | |
all_frames = [] | |
reader = imageio.get_reader(video_path) | |
all_frames = [] | |
for frame in reader: | |
all_frames.append(frame) | |
video_frames = np.array(all_frames) | |
min_frames = min(len(video_frames), position.shape[0]) | |
position = position[:min_frames] | |
velocity = velocity[:min_frames] | |
video_frames = video_frames[:min_frames] | |
# print(min_frames) | |
for i in tqdm(range(min_frames)): | |
if i == 0: | |
previous = -1 | |
next_node = global_i + 1 | |
elif i == min_frames - 1: | |
previous = global_i - 1 | |
next_node = -1 | |
else: | |
previous = global_i - 1 | |
next_node = global_i + 1 | |
graph.add_vertex( | |
idx=global_i, | |
name=video_id, | |
motion=motion_reps, | |
position=position[i], | |
velocity=velocity[i], | |
axis_angle=axis_angle[i], | |
trans=trans[i], | |
# audio=audio[], | |
video=video_frames[i], | |
previous=previous, | |
next=next_node, | |
frame=i, | |
fps=fps, | |
) | |
global_i += 1 | |
return graph | |
def create_edges(graph): | |
adaptive_length = [-4, -3, -2, -1, 1, 2, 3, 4] | |
# print() | |
for i, node in enumerate(graph.vs): | |
current_position = node['position'] | |
current_velocity = node['velocity'] | |
current_trans = node['trans'] | |
# print(current_position.shape, current_velocity.shape) | |
avg_position = np.zeros(current_position.shape[0]) | |
avg_velocity = np.zeros(current_position.shape[0]) | |
avg_trans = 0 | |
count = 0 | |
for node_offset in adaptive_length: | |
idx = i + node_offset | |
if idx < 0 or idx >= len(graph.vs): | |
continue | |
if node_offset < 0: | |
if graph.vs[idx]['next'] == -1:continue | |
else: | |
if graph.vs[idx]['previous'] == -1:continue | |
# add check | |
other_node = graph.vs[idx] | |
other_position = other_node['position'] | |
other_velocity = other_node['velocity'] | |
other_trans = other_node['trans'] | |
# print(other_position.shape, other_velocity.shape) | |
avg_position += np.linalg.norm(current_position - other_position, axis=1) | |
avg_velocity += np.linalg.norm(current_velocity - other_velocity, axis=1) | |
avg_trans += np.linalg.norm(current_trans - other_trans, axis=0) | |
count += 1 | |
if count == 0: | |
continue | |
threshold_position = avg_position / count | |
threshold_velocity = avg_velocity / count | |
threshold_trans = avg_trans / count | |
# print(threshold_position, threshold_velocity, threshold_trans) | |
for j, other_node in enumerate(graph.vs): | |
if i == j: | |
continue | |
if j == node['previous'] or j == node['next']: | |
graph.add_edge(i, j, is_continue=1) | |
continue | |
other_position = other_node['position'] | |
other_velocity = other_node['velocity'] | |
other_trans = other_node['trans'] | |
position_similarity = np.linalg.norm(current_position - other_position, axis=1) | |
velocity_similarity = np.linalg.norm(current_velocity - other_velocity, axis=1) | |
trans_similarity = np.linalg.norm(current_trans - other_trans, axis=0) | |
if trans_similarity < threshold_trans: | |
if np.sum(position_similarity < threshold_position) >= 45 and np.sum(velocity_similarity < threshold_velocity) >= 45: | |
graph.add_edge(i, j, is_continue=0) | |
print(f"nodes: {len(graph.vs)}, edges: {len(graph.es)}") | |
in_degrees = graph.indegree() | |
out_degrees = graph.outdegree() | |
avg_in_degree = sum(in_degrees) / len(in_degrees) | |
avg_out_degree = sum(out_degrees) / len(out_degrees) | |
print(f"Average In-degree: {avg_in_degree}") | |
print(f"Average Out-degree: {avg_out_degree}") | |
print(f"max in degree: {max(in_degrees)}, max out degree: {max(out_degrees)}") | |
print(f"min in degree: {min(in_degrees)}, min out degree: {min(out_degrees)}") | |
# igraph.plot(graph, target="/content/test.png", bbox=(1000, 1000), vertex_size=10) | |
return graph | |
def random_walk(graph, walk_length, start_node=None): | |
if start_node is None: | |
start_node = np.random.choice(graph.vs) | |
walk = [start_node] | |
is_continue = [1] | |
for _ in range(walk_length): | |
current_node = walk[-1] | |
neighbor_indices = graph.neighbors(current_node.index, mode='OUT') | |
if not neighbor_indices: | |
break | |
next_idx = np.random.choice(neighbor_indices) | |
edge_id = graph.get_eid(current_node.index, next_idx) | |
is_cont = graph.es[edge_id]['is_continue'] | |
walk.append(graph.vs[next_idx]) | |
is_continue.append(is_cont) | |
return walk, is_continue | |
def path_visualization(graph, path, is_continue, save_path, verbose_continue=False, audio_path=None, return_motion=False): | |
all_frames = [node['video'] for node in path] | |
average_dis_continue = 1 - sum(is_continue) / len(is_continue) | |
if verbose_continue: | |
print("average_dis_continue:", average_dis_continue) | |
duration = len(all_frames) / graph.vs[0]['fps'] | |
def make_frame(t): | |
idx = min(int(t * graph.vs[0]['fps']), len(all_frames) - 1) | |
return all_frames[idx] | |
video_clip = VideoClip(make_frame, duration=duration) | |
if audio_path is not None: | |
audio_clip = AudioFileClip(audio_path) | |
video_clip = video_clip.set_audio(audio_clip) | |
video_clip.write_videofile(save_path, codec='libx264', fps=graph.vs[0]['fps'], audio_codec='aac') | |
if return_motion: | |
all_motion = [node['axis_angle'] for node in path] | |
all_motion = np.stack(all_motion, 0) | |
return all_motion | |
def generate_transition_video(frame_start_path, frame_end_path, output_video_path): | |
import subprocess | |
import os | |
# Define the path to your model and inference script | |
model_path = "./frame-interpolation-pytorch/film_net_fp32.pt" | |
inference_script = "./frame-interpolation-pytorch/inference.py" | |
# Build the command to run the inference script | |
command = [ | |
"python", | |
inference_script, | |
model_path, | |
frame_start_path, | |
frame_end_path, | |
"--save_path", output_video_path, | |
"--gpu", | |
"--frames", "3", | |
"--fps", "30" | |
] | |
# Run the command | |
try: | |
subprocess.run(command, check=True) | |
print(f"Generated transition video saved at {output_video_path}") | |
except subprocess.CalledProcessError as e: | |
print(f"Error occurred while generating transition video: {e}") | |
def path_visualization_v2(graph, path, is_continue, save_path, verbose_continue=False, audio_path=None, return_motion=False): | |
''' | |
this is for hugging face demo for fast interpolation. our paper use a diffusion based interpolation method | |
''' | |
all_frames = [node['video'] for node in path] | |
average_dis_continue = 1 - sum(is_continue) / len(is_continue) | |
if verbose_continue: | |
print("average_dis_continue:", average_dis_continue) | |
duration = len(all_frames) / graph.vs[0]['fps'] | |
# First loop: Confirm where blending is needed | |
discontinuity_indices = [] | |
for i, cont in enumerate(is_continue): | |
if cont == 0: | |
discontinuity_indices.append(i) | |
# Identify blending positions without overlapping | |
blend_positions = [] | |
processed_frames = set() | |
for i in discontinuity_indices: | |
# Define the frames for blending: i-2 to i+2 | |
start_idx = i - 2 | |
end_idx = i + 2 | |
# Check index boundaries | |
if start_idx < 0 or end_idx >= len(all_frames): | |
continue # Skip if indices are out of bounds | |
# Check for overlapping frames | |
overlap = any(idx in processed_frames for idx in range(i - 1, i + 2)) | |
if overlap: | |
continue # Skip if frames have been processed | |
# Mark frames as processed | |
processed_frames.update(range(i - 1, i + 2)) | |
blend_positions.append(i) | |
# Second loop: Perform blending | |
temp_dir = tempfile.mkdtemp(prefix='blending_frames_') | |
for i in tqdm(blend_positions): | |
start_frame_idx = i - 2 | |
end_frame_idx = i + 2 | |
frame_start = all_frames[start_frame_idx] | |
frame_end = all_frames[end_frame_idx] | |
frame_start_path = os.path.join(temp_dir, f'frame_{start_frame_idx}.png') | |
frame_end_path = os.path.join(temp_dir, f'frame_{end_frame_idx}.png') | |
# Save the start and end frames as images | |
imageio.imwrite(frame_start_path, frame_start) | |
imageio.imwrite(frame_end_path, frame_end) | |
# Call FiLM API to generate video | |
generated_video_path = os.path.join(temp_dir, f'generated_{start_frame_idx}_{end_frame_idx}.mp4') | |
generate_transition_video(frame_start_path, frame_end_path, generated_video_path) | |
# Read the generated video frames | |
reader = imageio.get_reader(generated_video_path) | |
generated_frames = [frame for frame in reader] | |
reader.close() | |
# Replace the middle three frames (i-1, i, i+1) in all_frames | |
total_generated_frames = len(generated_frames) | |
if total_generated_frames < 5: | |
print(f"Generated video has insufficient frames ({total_generated_frames}). Skipping blending at position {i}.") | |
continue | |
middle_start = 1 # Start index for middle 3 frames | |
middle_frames = generated_frames[middle_start:middle_start+3] | |
for idx, frame_idx in enumerate(range(i - 1, i + 2)): | |
all_frames[frame_idx] = middle_frames[idx] | |
# Create the video clip | |
def make_frame(t): | |
idx = min(int(t * graph.vs[0]['fps']), len(all_frames) - 1) | |
return all_frames[idx] | |
video_clip = VideoClip(make_frame, duration=duration) | |
if audio_path is not None: | |
audio_clip = AudioFileClip(audio_path) | |
video_clip = video_clip.set_audio(audio_clip) | |
video_clip.write_videofile(save_path, codec='libx264', fps=graph.vs[0]['fps'], audio_codec='aac') | |
if return_motion: | |
all_motion = [node['axis_angle'] for node in path] | |
all_motion = np.stack(all_motion, 0) | |
return all_motion | |
def graph_pruning(graph): | |
ascc = graph.clusters(mode="STRONG") | |
lascc = ascc.giant() | |
print(f"before nodes: {len(graph.vs)}, edges: {len(graph.es)}") | |
print(f"after nodes: {len(lascc.vs)}, edges: {len(lascc.es)}") | |
in_degrees = lascc.indegree() | |
out_degrees = lascc.outdegree() | |
avg_in_degree = sum(in_degrees) / len(in_degrees) | |
avg_out_degree = sum(out_degrees) / len(out_degrees) | |
print(f"Average In-degree: {avg_in_degree}") | |
print(f"Average Out-degree: {avg_out_degree}") | |
print(f"max in degree: {max(in_degrees)}, max out degree: {max(out_degrees)}") | |
print(f"min in degree: {min(in_degrees)}, min out degree: {min(out_degrees)}") | |
return lascc | |
if __name__ == '__main__': | |
parser = argparse.ArgumentParser() | |
parser.add_argument("--json_save_path", type=str, default="") | |
parser.add_argument("--graph_save_path", type=str, default="") | |
args = parser.parse_args() | |
json_path = args.json_save_path | |
graph_path = args.graph_save_path | |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') | |
smplx_model = smplx.create( | |
"./emage/smplx_models/", | |
model_type='smplx', | |
gender='NEUTRAL_2020', | |
use_face_contour=False, | |
num_betas=300, | |
num_expression_coeffs=100, | |
ext='npz', | |
use_pca=False, | |
).to(device).eval() | |
# single_test | |
# graph = create_graph('/content/drive/MyDrive/003_Codes/TANGO/datasets/data_json/show_oliver_test/Abortion_Laws_-_Last_Week_Tonight_with_John_Oliver_HBO-DRauXXz6t0Y.webm.json') | |
graph = create_graph(json_path, smplx_model) | |
graph = create_edges(graph) | |
# pool_path = "/content/drive/MyDrive/003_Codes/TANGO-JointEmbedding/datasets/oliver_test/show-oliver-test.pkl" | |
# graph = igraph.Graph.Read_Pickle(fname=pool_path) | |
# graph = igraph.Graph.Read_Pickle(fname="/content/drive/MyDrive/003_Codes/TANGO-JointEmbedding/datasets/oliver_test/test.pkl") | |
walk, is_continue = random_walk(graph, 100) | |
motion = path_visualization(graph, walk, is_continue, "./test.mp4", audio_path=None, verbose_continue=True, return_motion=True) | |
# print(motion.shape) | |
save_graph = graph.write_pickle(fname=graph_path) | |
graph = graph_pruning(graph) | |
# show-oliver | |
# json_path = "/content/drive/MyDrive/003_Codes/TANGO/datasets/data_json/show_oliver_test/" | |
# pre_node_path = "/content/drive/MyDrive/003_Codes/TANGO/datasets/cached_graph/show_oliver_test/" | |
# for json_file in tqdm(os.listdir(json_path)): | |
# graph = create_graph(os.path.join(json_path, json_file)) | |
# graph = create_edges(graph) | |
# if not len(graph.vs) >= 1500: | |
# print(f"skip: {len(graph.vs)}", json_file) | |
# graph.write_pickle(fname=os.path.join(pre_node_path, json_file.split(".")[0] + ".pkl")) | |
# print(f"Graph saved at {json_file.split('.')[0]}.pkl") |