# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

import torch
import torch.nn as nn
import torch.nn.functional as F


class MLP(nn.Module):
    """Stack of kernel-size-1 ``Conv1d`` layers with optional skip inputs.

    Each entry of ``filter_channels`` is a channel width; consecutive pairs
    define one ``Conv1d(in, out, 1)`` layer. Layers whose index appears in
    ``res_layers`` additionally receive the original network input,
    concatenated along the channel dimension.

    Args:
        filter_channels: sequence of channel widths; ``len - 1`` conv layers
            are created.
        merge_layer: index of the layer whose output is also returned as
            ``phi``. Values ``<= 0`` select ``len(filter_channels) // 2``.
        res_layers: iterable of layer indices that get the input feature
            concatenated to their input. ``None`` (default) means no such
            layers.
        norm: ``'group'`` or ``'batch'`` to normalize after every layer but
            the last; any other value disables normalization.
        last_op: optional callable applied to the final output.
    """

    def __init__(self, filter_channels, merge_layer=0, res_layers=None,
                 norm='group', last_op=None):
        super(MLP, self).__init__()

        self.filters = nn.ModuleList()
        self.norms = nn.ModuleList()
        # A non-positive merge_layer falls back to the middle of the stack.
        self.merge_layer = merge_layer if merge_layer > 0 else len(filter_channels) // 2
        # Own copy per instance: a shared default list (or an aliased caller
        # list) would let one instance's mutation leak into others.
        self.res_layers = [] if res_layers is None else list(res_layers)
        self.norm = norm
        self.last_op = last_op

        num_layers = len(filter_channels) - 1
        for idx in range(num_layers):
            in_channels = filter_channels[idx]
            if idx in self.res_layers:
                # Skip layers also consume the raw input feature.
                in_channels += filter_channels[0]
            self.filters.append(
                nn.Conv1d(in_channels, filter_channels[idx + 1], 1))

            # No normalization after the final layer.
            if idx != num_layers - 1:
                if norm == 'group':
                    # 32 groups: the layer width must be divisible by 32.
                    self.norms.append(nn.GroupNorm(32, filter_channels[idx + 1]))
                elif norm == 'batch':
                    self.norms.append(nn.BatchNorm1d(filter_channels[idx + 1]))

    def forward(self, feature):
        '''
        feature may include multiple view inputs
        args:
            feature: [B, C_in, N]
        return:
            (y, phi) where
            y: [B, C_out, N] prediction (after last_op, if one was given)
            phi: clone of the activation right after layer `merge_layer`,
                taken before last_op; None if no layer has that index
        '''
        y = feature
        skip_input = feature
        phi = None
        # Loop-invariant: norm layers only exist for these two settings.
        use_norm = self.norm in ('batch', 'group')
        last_idx = len(self.filters) - 1

        for i, conv in enumerate(self.filters):
            if i in self.res_layers:
                y = torch.cat([y, skip_input], 1)
            y = conv(y)

            # The final layer stays linear (no norm, no activation).
            if i != last_idx:
                if use_norm:
                    y = self.norms[i](y)
                y = F.leaky_relu(y)

            if i == self.merge_layer:
                phi = y.clone()

        if self.last_op is not None:
            y = self.last_op(y)

        return y, phi