File size: 3,091 Bytes
8b79d57 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 |
import torch
from torch import Tensor
from torch import nn
from torch.nn import functional as F
from typing import Optional, List
from .mobilenetv3 import MobileNetV3LargeEncoder
from .resnet import ResNet50Encoder
from .lraspp import LRASPP
from .decoder import RecurrentDecoder, Projection
from .fast_guided_filter import FastGuidedFilterRefiner
from .deep_guided_filter import DeepGuidedFilterRefiner
class MattingNetwork(nn.Module):
    """Recurrent matting network.

    ``forward`` runs: optional bilinear rescaling of the input, a backbone
    encoder producing four feature maps, an LR-ASPP module applied to the
    deepest of them, a recurrent decoder, and then either a matting head
    (foreground residual + alpha, optionally refined by a guided-filter
    refiner) or a single-channel segmentation head.

    Args:
        variant: Backbone to build, ``'mobilenetv3'`` or ``'resnet50'``.
        refiner: Refiner applied in ``forward`` when ``downsample_ratio != 1``,
            ``'deep_guided_filter'`` or ``'fast_guided_filter'``.
        pretrained_backbone: Forwarded unchanged to the backbone encoder's
            constructor.

    Raises:
        ValueError: If ``variant`` or ``refiner`` is not a supported option.
    """

    def __init__(self,
                 variant: str = 'mobilenetv3',
                 refiner: str = 'deep_guided_filter',
                 pretrained_backbone: bool = False):
        super().__init__()
        # Explicit checks instead of ``assert``: asserts are removed under
        # ``python -O``, and an unrecognised string would then silently fall
        # through to the ResNet50 / fast-guided-filter ``else`` branches below.
        if variant not in ('mobilenetv3', 'resnet50'):
            raise ValueError(
                f"variant must be 'mobilenetv3' or 'resnet50', got {variant!r}")
        if refiner not in ('fast_guided_filter', 'deep_guided_filter'):
            raise ValueError(
                "refiner must be 'fast_guided_filter' or "
                f"'deep_guided_filter', got {refiner!r}")

        # The last decoder input width (128 / 256) equals the second LRASPP
        # argument, since forward() replaces f4 with aspp(f4) before decoding.
        if variant == 'mobilenetv3':
            self.backbone = MobileNetV3LargeEncoder(pretrained_backbone)
            self.aspp = LRASPP(960, 128)
            self.decoder = RecurrentDecoder([16, 24, 40, 128], [80, 40, 32, 16])
        else:
            self.backbone = ResNet50Encoder(pretrained_backbone)
            self.aspp = LRASPP(2048, 256)
            self.decoder = RecurrentDecoder([64, 256, 512, 256], [128, 64, 32, 16])

        # Matting head: 4 channels, split in forward() into a 3-channel
        # foreground residual and a 1-channel alpha.
        self.project_mat = Projection(16, 4)
        # Segmentation head: 1 channel.
        self.project_seg = Projection(16, 1)

        if refiner == 'deep_guided_filter':
            self.refiner = DeepGuidedFilterRefiner()
        else:
            self.refiner = FastGuidedFilterRefiner()

    def forward(self,
                src: Tensor,
                r1: Optional[Tensor] = None,
                r2: Optional[Tensor] = None,
                r3: Optional[Tensor] = None,
                r4: Optional[Tensor] = None,
                downsample_ratio: float = 1,
                segmentation_pass: bool = False):
        """Run one forward pass of the network.

        Args:
            src: Input frames with channels on dim -3 (the 3-channel
                foreground residual is added to it directly). 5-D input is
                handled by ``_interpolate`` as (B, T, ...). Presumably
                (B, 3, H, W) or (B, T, 3, H, W) RGB in [0, 1] -- TODO confirm.
            r1, r2, r3, r4: Recurrent states passed straight to the decoder;
                may be ``None`` (the default).
            downsample_ratio: If not 1, ``src`` is bilinearly rescaled by this
                factor before the backbone, and the matting outputs are then
                passed through ``self.refiner`` together with the original
                ``src``.
            segmentation_pass: If True, return the segmentation head output
                instead of the matting outputs (the refiner is not used).

        Returns:
            ``[fgr, pha, *rec]`` for a matting pass, with ``fgr`` and ``pha``
            clamped to [0, 1]; ``[seg, *rec]`` for a segmentation pass.
            ``rec`` holds every decoder output after the first (presumably the
            updated r1..r4 to feed into the next call -- confirm against
            RecurrentDecoder).
        """
        if downsample_ratio != 1:
            src_sm = self._interpolate(src, scale_factor=downsample_ratio)
        else:
            src_sm = src

        f1, f2, f3, f4 = self.backbone(src_sm)
        f4 = self.aspp(f4)
        hid, *rec = self.decoder(src_sm, f1, f2, f3, f4, r1, r2, r3, r4)

        if not segmentation_pass:
            fgr_residual, pha = self.project_mat(hid).split([3, 1], dim=-3)
            if downsample_ratio != 1:
                # The head ran at src_sm's resolution; the refiner also gets the
                # original src, and its outputs are added to src below, so they
                # must match src's spatial size.
                fgr_residual, pha = self.refiner(src, src_sm, fgr_residual, pha, hid)
            fgr = fgr_residual + src
            fgr = fgr.clamp(0., 1.)
            pha = pha.clamp(0., 1.)
            return [fgr, pha, *rec]
        else:
            seg = self.project_seg(hid)
            return [seg, *rec]

    def _interpolate(self, x: Tensor, scale_factor: float):
        """Bilinearly rescale the last two dims of ``x`` by ``scale_factor``.

        5-D input is treated as (B, T, ...): the first two dims are merged
        for ``F.interpolate`` and split back afterwards. Any other input is
        passed to ``F.interpolate`` as-is (bilinear mode requires it to be
        4-D).
        """
        if x.ndim == 5:
            B, T = x.shape[:2]
            x = F.interpolate(x.flatten(0, 1), scale_factor=scale_factor,
                              mode='bilinear', align_corners=False,
                              recompute_scale_factor=False)
            x = x.unflatten(0, (B, T))
        else:
            x = F.interpolate(x, scale_factor=scale_factor,
                              mode='bilinear', align_corners=False,
                              recompute_scale_factor=False)
        return x
|