'''
* Copyright (c) 2023 Salesforce, Inc.
* All rights reserved.
* SPDX-License-Identifier: Apache License 2.0
* For full license text, see LICENSE.txt file in the repo root or http://www.apache.org/licenses/
* By Can Qin
* Modified from ControlNet repo: https://github.com/lllyasviel/ControlNet
* Copyright (c) 2023 Lvmin Zhang and Maneesh Agrawala
* Modified from MMCV repo: https://github.com/open-mmlab/mmcv
* Copyright (c) OpenMMLab. All rights reserved.
'''
import torch | |
from torch import nn | |
from torch.nn import functional as F | |
class Encoding(nn.Module):
    """Encoding Layer: a learnable residual encoder.

    Each spatial feature vector is softly assigned to ``num_codes`` learnable
    codewords (softmax over scale-weighted squared L2 distances), and the
    assignment-weighted residuals (feature - codeword) are summed over all
    spatial positions.

    Input is of shape (batch_size, channels, height, width).
    Output is of shape (batch_size, num_codes, channels).

    Args:
        channels: dimension of the features or feature channels
        num_codes: number of code words
    """

    def __init__(self, channels, num_codes):
        super().__init__()
        # init codewords and smoothing factor
        self.channels, self.num_codes = channels, num_codes
        std = 1. / ((num_codes * channels) ** 0.5)
        # [num_codes, channels], initialised uniformly in [-std, std)
        self.codewords = nn.Parameter(
            torch.empty(num_codes, channels,
                        dtype=torch.float).uniform_(-std, std),
            requires_grad=True)
        # [num_codes], initialised in [-1, 0): with a negative scale, larger
        # distances produce smaller softmax logits (at initialisation).
        self.scale = nn.Parameter(
            torch.empty(num_codes, dtype=torch.float).uniform_(-1, 0),
            requires_grad=True)

    @staticmethod
    def scaled_l2(x, codewords, scale):
        """Return scale-weighted squared L2 distances to every codeword.

        Args:
            x: features of shape [batch_size, N, channels].
            codewords: tensor of shape [num_codes, channels].
            scale: tensor of shape [num_codes].

        Returns:
            Tensor of shape [batch_size, N, num_codes] whose entry
            ``[b, n, k]`` is ``scale[k] * ||x[b, n] - codewords[k]||^2``.
        """
        num_codes, channels = codewords.size()
        # Broadcasting [B, N, 1, C] - [1, 1, K, C] -> [B, N, K, C].
        residuals = x.unsqueeze(2) - codewords.view(1, 1, num_codes, channels)
        return scale.view(1, 1, num_codes) * residuals.pow(2).sum(dim=3)

    @staticmethod
    def aggregate(assignment_weights, x, codewords):
        """Sum assignment-weighted residuals over all positions.

        Args:
            assignment_weights: tensor of shape [batch_size, N, num_codes].
            x: features of shape [batch_size, N, channels].
            codewords: tensor of shape [num_codes, channels].

        Returns:
            Tensor of shape [batch_size, num_codes, channels].
        """
        num_codes, channels = codewords.size()
        # Broadcasting [B, N, 1, C] - [1, 1, K, C] -> [B, N, K, C].
        residuals = x.unsqueeze(2) - codewords.view(1, 1, num_codes, channels)
        # Weight each residual by its assignment, then sum over positions N.
        return (assignment_weights.unsqueeze(3) * residuals).sum(dim=1)

    def forward(self, x):
        """Encode a feature map into per-codeword residual aggregates.

        Args:
            x: tensor of shape [batch_size, channels, height, width].

        Returns:
            Tensor of shape [batch_size, num_codes, channels].
        """
        assert x.dim() == 4 and x.size(1) == self.channels, (
            f'expected input of shape (N, {self.channels}, H, W), '
            f'got {tuple(x.shape)}')
        batch_size = x.size(0)
        # [batch_size, height x width, channels]; reshape (unlike view) also
        # accepts inputs whose strides cannot be viewed as [B, C, H*W].
        x = x.reshape(batch_size, self.channels, -1)
        x = x.transpose(1, 2).contiguous()
        # assignment_weights: [batch_size, height x width, num_codes],
        # a soft assignment of every position over the codewords.
        assignment_weights = F.softmax(
            self.scaled_l2(x, self.codewords, self.scale), dim=2)
        # [batch_size, num_codes, channels]
        return self.aggregate(assignment_weights, x, self.codewords)

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(Nx{self.channels}xHxW =>Nx{self.num_codes}' \
                    f'x{self.channels})'
        return repr_str