mambazjp's picture
Upload 58 files
355b5d6
raw
history blame
2.18 kB
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
class MLP(nn.Module):
    """Stack of 1x1 ``Conv1d`` layers applied independently at every point.

    Every layer except the last is followed by an optional normalization
    layer and a leaky ReLU.  Layers listed in ``res_layers`` see the original
    input concatenated (on the channel axis) to their regular input.  The
    activation of layer ``merge_layer`` is returned alongside the output.

    Args:
        filter_channels: channel widths ``[C_in, C_1, ..., C_out]``; one
            ``Conv1d(kernel_size=1)`` is built per consecutive pair.
        merge_layer: index of the layer whose activation is returned as
            ``phi``.  Values ``<= 0`` select ``len(filter_channels) // 2``.
        res_layers: indices of layers that also receive the original input
            features.  ``None`` (the default) means no such layers.
        norm: ``'group'`` -> ``GroupNorm(32, C)`` (so hidden widths must be
            divisible by 32), ``'batch'`` -> ``BatchNorm1d(C)``; any other
            value disables normalization.
        last_op: optional callable applied to the final output.
    """

    def __init__(self,
                 filter_channels,
                 merge_layer=0,
                 res_layers=None,
                 norm='group',
                 last_op=None):
        super().__init__()

        self.filters = nn.ModuleList()
        self.norms = nn.ModuleList()
        self.merge_layer = merge_layer if merge_layer > 0 else len(filter_channels) // 2
        # Fresh per-instance list: a `[]` default would be one object shared
        # by every MLP instance.
        self.res_layers = list(res_layers) if res_layers is not None else []
        self.norm = norm
        self.last_op = last_op

        last_idx = len(filter_channels) - 2  # index of the final conv layer
        for idx in range(len(filter_channels) - 1):
            in_ch = filter_channels[idx]
            if idx in self.res_layers:
                # This layer's input is concatenated with the original input.
                in_ch += filter_channels[0]
            self.filters.append(nn.Conv1d(in_ch, filter_channels[idx + 1], 1))
            # The final layer gets no norm, so norms[i] pairs with filters[i].
            if idx != last_idx:
                if norm == 'group':
                    self.norms.append(nn.GroupNorm(32, filter_channels[idx + 1]))
                elif norm == 'batch':
                    self.norms.append(nn.BatchNorm1d(filter_channels[idx + 1]))

    def forward(self, feature):
        """Apply the MLP to ``feature``.

        Args:
            feature: tensor of shape [B, C_in, N] (may include inputs from
                multiple views).

        Returns:
            ``(y, phi)``: ``y`` is the [B, C_out, N] prediction, after
            ``last_op`` when one is set.  ``phi`` is a copy of the activation
            produced by layer ``merge_layer``, or ``None`` when
            ``merge_layer`` is not a valid layer index.
        """
        y = feature
        tmpy = feature  # original input, reused by the res_layers
        phi = None
        last_idx = len(self.filters) - 1
        use_norm = self.norm in ('batch', 'group')
        for i, f in enumerate(self.filters):
            y = f(torch.cat([y, tmpy], 1) if i in self.res_layers else y)
            if i != last_idx:
                y = F.leaky_relu(self.norms[i](y) if use_norm else y)
            if i == self.merge_layer:
                # Clone so later ops cannot alias the returned intermediate.
                phi = y.clone()

        if self.last_op is not None:
            y = self.last_op(y)

        return y, phi