|
import torch |
|
from torch import Tensor |
|
from torch import nn |
|
from torch.nn import functional as F |
|
from typing import Optional, List |
|
|
|
from .mobilenetv3 import MobileNetV3LargeEncoder |
|
from .resnet import ResNet50Encoder |
|
from .lraspp import LRASPP |
|
from .decoder import RecurrentDecoder, Projection |
|
from .fast_guided_filter import FastGuidedFilterRefiner |
|
from .deep_guided_filter import DeepGuidedFilterRefiner |
|
|
|
class MattingNetwork(nn.Module):
    """Matting model built from an encoder backbone, an LR-ASPP module, a
    recurrent decoder, two projection heads and a guided-filter refiner.

    Args:
        variant: Backbone to use, either ``'mobilenetv3'`` or ``'resnet50'``.
        refiner: Refiner to use, either ``'deep_guided_filter'`` or
            ``'fast_guided_filter'``.
        pretrained_backbone: Forwarded unchanged to the backbone encoder's
            constructor.

    Raises:
        ValueError: If ``variant`` or ``refiner`` is not one of the supported
            names.
    """

    def __init__(self,
                 variant: str = 'mobilenetv3',
                 refiner: str = 'deep_guided_filter',
                 pretrained_backbone: bool = False):
        super().__init__()

        # Validate with real exceptions rather than ``assert``: assertions are
        # removed under ``python -O``, in which case an unknown name would
        # silently fall through to the ``else`` branches below and build a
        # model the caller did not ask for.
        if variant not in ('mobilenetv3', 'resnet50'):
            raise ValueError(
                f"Unknown variant {variant!r}; "
                "expected 'mobilenetv3' or 'resnet50'.")
        if refiner not in ('fast_guided_filter', 'deep_guided_filter'):
            raise ValueError(
                f"Unknown refiner {refiner!r}; "
                "expected 'fast_guided_filter' or 'deep_guided_filter'.")

        if variant == 'mobilenetv3':
            self.backbone = MobileNetV3LargeEncoder(pretrained_backbone)
            self.aspp = LRASPP(960, 128)
            self.decoder = RecurrentDecoder([16, 24, 40, 128], [80, 40, 32, 16])
        else:
            self.backbone = ResNet50Encoder(pretrained_backbone)
            self.aspp = LRASPP(2048, 256)
            self.decoder = RecurrentDecoder([64, 256, 512, 256], [128, 64, 32, 16])

        # Both heads read the decoder's final hidden feature: the matting head
        # emits 4 channels (split 3 + 1 in ``forward``), the segmentation head 1.
        self.project_mat = Projection(16, 4)
        self.project_seg = Projection(16, 1)

        if refiner == 'deep_guided_filter':
            self.refiner = DeepGuidedFilterRefiner()
        else:
            self.refiner = FastGuidedFilterRefiner()

    def forward(self,
                src: Tensor,
                r1: Optional[Tensor] = None,
                r2: Optional[Tensor] = None,
                r3: Optional[Tensor] = None,
                r4: Optional[Tensor] = None,
                downsample_ratio: float = 1,
                segmentation_pass: bool = False):
        """Run the network on ``src``.

        Args:
            src: Input tensor. ``_interpolate`` accepts it as either 5-D or
                non-5-D (see that method).
            r1, r2, r3, r4: Optional tensors passed straight to the decoder
                (presumably the recurrent states returned by a previous
                call -- confirm against ``RecurrentDecoder``).
            downsample_ratio: When not equal to 1, ``src`` is bilinearly
                rescaled by this factor before the backbone, and the matting
                output is passed through the refiner together with the
                full-resolution ``src``.
            segmentation_pass: When true, return the segmentation head output
                instead of the matting outputs.

        Returns:
            ``[fgr, pha, *rec]`` for a matting pass, where ``fgr`` is the
            predicted residual added to ``src`` and both ``fgr`` and ``pha``
            are clamped to ``[0, 1]``; or ``[seg, *rec]`` for a segmentation
            pass. ``rec`` holds every decoder output after the first.
        """
        if downsample_ratio != 1:
            src_sm = self._interpolate(src, scale_factor=downsample_ratio)
        else:
            src_sm = src

        f1, f2, f3, f4 = self.backbone(src_sm)
        f4 = self.aspp(f4)
        # First decoder output is the hidden feature fed to the projection
        # heads; the remaining outputs are handed back to the caller.
        hid, *rec = self.decoder(src_sm, f1, f2, f3, f4, r1, r2, r3, r4)

        if not segmentation_pass:
            # dim=-3 is the channel axis whether or not a leading extra
            # dimension is present.
            fgr_residual, pha = self.project_mat(hid).split([3, 1], dim=-3)
            if downsample_ratio != 1:
                # Refinement is only needed when the network ran on a
                # rescaled copy of the input.
                fgr_residual, pha = self.refiner(src, src_sm, fgr_residual, pha, hid)
            fgr = fgr_residual + src
            fgr = fgr.clamp(0., 1.)
            pha = pha.clamp(0., 1.)
            return [fgr, pha, *rec]
        else:
            seg = self.project_seg(hid)
            return [seg, *rec]

    def _interpolate(self, x: Tensor, scale_factor: float):
        """Bilinearly rescale ``x`` by ``scale_factor``.

        A 5-D input has its two leading dimensions (named ``B`` and ``T``
        here, presumably batch and time) merged before the call to
        ``F.interpolate`` and restored afterwards; any other input is passed
        to ``F.interpolate`` directly.
        """
        if x.ndim == 5:
            B, T = x.shape[:2]
            x = F.interpolate(x.flatten(0, 1), scale_factor=scale_factor,
                              mode='bilinear', align_corners=False,
                              recompute_scale_factor=False)
            x = x.unflatten(0, (B, T))
        else:
            x = F.interpolate(x, scale_factor=scale_factor,
                              mode='bilinear', align_corners=False,
                              recompute_scale_factor=False)
        return x
|
|