import numpy as np
import torch
import torch.nn as nn

from .skeleton import (ResidualBlock, SkeletonResidual, residual_ratio,
                       SkeletonConv, SkeletonPool, find_neighbor,
                       build_edge_topology)

class LocalEncoder(nn.Module):
    """Skeleton-aware encoder: every layer halves the temporal resolution with a
    stride-2 SkeletonConv and pools the kinematic tree to a coarser topology."""

    def __init__(self, args, topology):
        super(LocalEncoder, self).__init__()
        # Encoder hyper-parameters are fixed here, overriding whatever is in args.
        args.channel_base = 6
        args.activation = "tanh"
        args.use_residual_blocks = True
        args.z_dim = 1024
        args.temporal_scale = 8
        args.kernel_size = 4
        args.num_layers = args.vae_layer
        args.skeleton_dist = 2
        args.extra_conv = 0
        # TODO: check how to do reflect padding in 1D
        args.padding_mode = "constant"
        args.skeleton_pool = "mean"
        args.upsampling = "linear"

        self.topologies = [topology]
        self.channel_base = [args.channel_base]
        self.channel_list = []
        self.edge_num = [len(topology)]
        self.pooling_list = []
        self.layers = nn.ModuleList()
        self.args = args
        # self.convs = []

        kernel_size = args.kernel_size
        kernel_even = kernel_size % 2 == 0
        padding = (kernel_size - 1) // 2
        bias = True
        self.grow = args.vae_grow
        for i in range(args.num_layers):
            self.channel_base.append(self.channel_base[-1] * self.grow[i])

        for i in range(args.num_layers):
            seq = []
            neighbour_list = find_neighbor(self.topologies[i], args.skeleton_dist)
            in_channels = self.channel_base[i] * self.edge_num[i]
            out_channels = self.channel_base[i + 1] * self.edge_num[i]
            if i == 0:
                self.channel_list.append(in_channels)
            self.channel_list.append(out_channels)
            last_pool = (i == args.num_layers - 1)

            # (T, J, D) => (T, J', D)
            pool = SkeletonPool(edges=self.topologies[i], pooling_mode=args.skeleton_pool,
                                channels_per_edge=out_channels // len(neighbour_list),
                                last_pool=last_pool)

            if args.use_residual_blocks:
                # (T, J, D) => (T/2, J', 2D); pooling happens inside the residual block
                seq.append(SkeletonResidual(self.topologies[i], neighbour_list, joint_num=self.edge_num[i],
                                            in_channels=in_channels, out_channels=out_channels,
                                            kernel_size=kernel_size, stride=2, padding=padding,
                                            padding_mode=args.padding_mode, bias=bias,
                                            extra_conv=args.extra_conv, pooling_mode=args.skeleton_pool,
                                            activation=args.activation, last_pool=last_pool))
            else:
                for _ in range(args.extra_conv):
                    # (T, J, D) => (T, J, D)
                    seq.append(SkeletonConv(neighbour_list, in_channels=in_channels, out_channels=in_channels,
                                            joint_num=self.edge_num[i],
                                            kernel_size=kernel_size - 1 if kernel_even else kernel_size,
                                            stride=1, padding=padding,
                                            padding_mode=args.padding_mode, bias=bias))
                    seq.append(nn.PReLU() if args.activation == 'relu' else nn.Tanh())
                # (T, J, D) => (T/2, J, 2D)
                seq.append(SkeletonConv(neighbour_list, in_channels=in_channels, out_channels=out_channels,
                                        joint_num=self.edge_num[i], kernel_size=kernel_size, stride=2,
                                        padding=padding, padding_mode=args.padding_mode, bias=bias,
                                        add_offset=False,
                                        in_offset_channel=3 * self.channel_base[i] // self.channel_base[0]))
                # self.convs.append(seq[-1])
                seq.append(pool)
                seq.append(nn.PReLU() if args.activation == 'relu' else nn.Tanh())
            self.layers.append(nn.Sequential(*seq))

            self.topologies.append(pool.new_edges)
            self.pooling_list.append(pool.pooling_list)
            self.edge_num.append(len(self.topologies[-1]))

        # in_features = self.channel_base[-1] * len(self.pooling_list[-1])
        # in_features *= int(args.temporal_scale / 2)
        # self.reduce = nn.Linear(in_features, args.z_dim)
        # self.mu = nn.Linear(in_features, args.z_dim)
        # self.logvar = nn.Linear(in_features, args.z_dim)

    def forward(self, inputs):
        # (B, T, C) => (B, C, T) so Conv1d runs over the temporal axis
        output = inputs.permute(0, 2, 1)
        for layer in self.layers:
            output = layer(output)
        # back to (B, T', C')
        output = output.permute(0, 2, 1)
        return output
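
# Shape sketch for LocalEncoder (illustrative; the joint counts depend on the
# topology and on what SkeletonPool merges): an input of (B, T, 6 * num_edges)
# is permuted to (B, C, T); each layer halves T with its stride-2 convolution,
# scales the per-edge channels by vae_grow[i], and pools the skeleton, so the
# output is (B, T / 2**vae_layer, channel_base[-1] * pooled_edge_num).
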
class ResBlock(nn.Module):
    """Two 3-tap 1D convolutions with an identity residual connection."""

    def __init__(self, channel):
        super(ResBlock, self).__init__()
        self.model = nn.Sequential(
            nn.Conv1d(channel, channel, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv1d(channel, channel, kernel_size=3, stride=1, padding=1),
        )

    def forward(self, x):
        residual = x
        out = self.model(x)
        out += residual
        return out
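
# Sanity sketch (hypothetical sizes): ResBlock is shape-preserving on (B, C, T):
#   x = torch.randn(2, 240, 32)
#   assert ResBlock(240)(x).shape == x.shape
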
class VQDecoderV3(nn.Module):
    """Convolutional decoder: a short residual stack followed by vae_layer
    upsampling stages, each doubling the temporal resolution."""

    def __init__(self, args):
        super(VQDecoderV3, self).__init__()
        n_up = args.vae_layer
        # n_up + 1 channel counts: vae_length throughout, vae_test_dim at the output
        channels = [args.vae_length] * n_up + [args.vae_test_dim]
        input_size = args.vae_length
        n_resblk = 2
        assert len(channels) == n_up + 1

        if input_size == channels[0]:
            layers = []
        else:
            layers = [nn.Conv1d(input_size, channels[0], kernel_size=3, stride=1, padding=1)]
        for i in range(n_resblk):
            layers += [ResBlock(channels[0])]
        for i in range(n_up):
            layers += [
                nn.Upsample(scale_factor=2, mode='nearest'),
                nn.Conv1d(channels[i], channels[i + 1], kernel_size=3, stride=1, padding=1),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        layers += [nn.Conv1d(channels[-1], channels[-1], kernel_size=3, stride=1, padding=1)]
        self.main = nn.Sequential(*layers)
        # self.main.apply(init_weight)

    def forward(self, inputs):
        # (B, T, C) => (B, C, T) for Conv1d, then back to (B, T', C')
        inputs = inputs.permute(0, 2, 1)
        outputs = self.main(inputs).permute(0, 2, 1)
        return outputs
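
# Each upsampling stage doubles the temporal length, mirroring the encoder's
# stride-2 halving: a latent of length T decodes to length T * 2**vae_layer
# (e.g. 16 -> 64 for vae_layer=2).
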
def reparameterize(mu, logvar):
    """VAE reparameterization trick: z = mu + sigma * eps with eps ~ N(0, I),
    where sigma = exp(0.5 * logvar)."""
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    return mu + eps * std

class VAEConv(nn.Module):
    """Base VAE wrapper. Subclasses are expected to assign self.encoder and
    self.decoder (see VAESKConv below)."""

    def __init__(self, args):
        super(VAEConv, self).__init__()
        # self.encoder = VQEncoderV3(args)
        # self.decoder = VQDecoderV3(args)
        self.fc_mu = nn.Linear(args.vae_length, args.vae_length)
        self.fc_logvar = nn.Linear(args.vae_length, args.vae_length)
        self.variational = args.variational

    def forward(self, inputs):
        pre_latent = self.encoder(inputs)
        mu, logvar = None, None
        if self.variational:
            mu = self.fc_mu(pre_latent)
            logvar = self.fc_logvar(pre_latent)
            pre_latent = reparameterize(mu, logvar)
        rec_pose = self.decoder(pre_latent)
        return {
            "poses_feat": pre_latent,
            "rec_pose": rec_pose,
            "pose_mu": mu,
            "pose_logvar": logvar,
        }
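
    # Training note (standard VAE objective, assumed rather than shown here):
    # the caller pairs a reconstruction loss on "rec_pose" with the closed-form
    # KL term computed from "pose_mu" / "pose_logvar":
    #   kl = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())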
    def map2latent(self, inputs):
        pre_latent = self.encoder(inputs)
        if self.variational:
            mu = self.fc_mu(pre_latent)
            logvar = self.fc_logvar(pre_latent)
            pre_latent = reparameterize(mu, logvar)
        return pre_latent

    def decode(self, pre_latent):
        rec_pose = self.decoder(pre_latent)
        return rec_pose

class VAESKConv(VAEConv):
    """Skeleton-aware VAE: a LocalEncoder over the SMPL-X kinematic tree paired
    with a VQDecoderV3."""

    def __init__(self, args, model_save_path="./emage/"):
        super(VAESKConv, self).__init__(args)
        # Build the skeleton edge topology from the SMPL-X kinematic tree.
        smpl_fname = model_save_path + "smplx_models/smplx/SMPLX_NEUTRAL_2020.npz"
        smpl_data = np.load(smpl_fname, encoding="latin1")
        parents = smpl_data["kintree_table"][0].astype(np.int32)
        edges = build_edge_topology(parents)
        self.encoder = LocalEncoder(args, edges)
        self.decoder = VQDecoderV3(args)
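

# Decoder-only smoke test (sketch): the argument values below are illustrative
# assumptions, not a training configuration. VAESKConv itself is left out
# because it needs the SMPL-X model file on disk.
if __name__ == "__main__":
    from types import SimpleNamespace

    args = SimpleNamespace(vae_layer=2, vae_length=240, vae_test_dim=330)
    decoder = VQDecoderV3(args)
    latents = torch.randn(2, 16, args.vae_length)  # (batch, time, channels)
    rec = decoder(latents)
    print(rec.shape)  # torch.Size([2, 64, 330]): 16 * 2**vae_layer frames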