File size: 3,358 Bytes
8b79d57 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 |
import torch
from torch.nn import functional as F
# --------------------------------------------------------------------------------- Train Loss
def matting_loss(pred_fgr, pred_pha, true_fgr, true_pha):
    """Compute the alpha and foreground training losses for a clip batch.

    Args:
        pred_fgr: Shape(B, T, 3, H, W)
        pred_pha: Shape(B, T, 1, H, W)
        true_fgr: Shape(B, T, 3, H, W)
        true_pha: Shape(B, T, 1, H, W)

    Returns:
        dict with keys 'pha_l1', 'pha_laplacian', 'pha_coherence', 'fgr_l1',
        'fgr_coherence' and 'total'. 'total' is the plain sum of the other
        five entries; the two coherence entries already include a x5 factor.
        With a single frame (T == 1) the coherence entries are 0 instead of
        NaN, because there is no frame-to-frame difference to compare.
    """
    loss = dict()
    # Alpha losses
    loss['pha_l1'] = F.l1_loss(pred_pha, true_pha)
    # The Laplacian loss operates on 4-D images, so fold time into the batch dim.
    loss['pha_laplacian'] = laplacian_loss(pred_pha.flatten(0, 1), true_pha.flatten(0, 1))
    loss['pha_coherence'] = _temporal_coherence_loss(pred_pha, true_pha) * 5
    # Foreground losses: only pixels where the ground-truth alpha is > 0 count.
    # The (B, T, 1, H, W) mask broadcasts across the 3 colour channels.
    true_msk = true_pha.gt(0)
    pred_fgr = pred_fgr * true_msk
    true_fgr = true_fgr * true_msk
    loss['fgr_l1'] = F.l1_loss(pred_fgr, true_fgr)
    loss['fgr_coherence'] = _temporal_coherence_loss(pred_fgr, true_fgr) * 5
    # Total
    loss['total'] = loss['pha_l1'] + loss['pha_coherence'] + loss['pha_laplacian'] \
                    + loss['fgr_l1'] + loss['fgr_coherence']
    return loss


def _temporal_coherence_loss(pred, true):
    """MSE between consecutive-frame differences of ``pred`` and ``true`` (time on dim 1).

    Returns a zero scalar when fewer than two frames are present. In that
    case the frame differences are empty tensors, and ``F.mse_loss`` with
    mean reduction would return NaN and poison the total loss.
    """
    if pred.size(1) < 2:
        return pred.new_zeros(())
    return F.mse_loss(pred[:, 1:] - pred[:, :-1],
                      true[:, 1:] - true[:, :-1])
def segmentation_loss(pred_seg, true_seg):
    """Binary cross-entropy between segmentation logits and targets.

    Args:
        pred_seg: Shape(B, T, 1, H, W), raw (pre-sigmoid) logits.
        true_seg: Shape(B, T, 1, H, W), target probabilities.

    Returns:
        Scalar tensor: the element-wise BCE averaged over all elements.
    """
    bce = F.binary_cross_entropy_with_logits(pred_seg, true_seg, reduction='mean')
    return bce
# ----------------------------------------------------------------------------- Laplacian Loss
def laplacian_loss(pred, true, max_levels=5):
    """Multi-scale L1 distance between the Laplacian pyramids of two images.

    Level ``i`` (level 0 is the full-resolution band) contributes its L1
    distance weighted by ``2 ** i``. The weighted sum is divided by
    ``max_levels``.

    Args:
        pred: 4-D tensor (B, C, H, W).
        true: 4-D tensor, expected to have the same shape as ``pred``.
        max_levels: number of pyramid levels to compare.
    """
    kernel = gauss_kernel(device=pred.device, dtype=pred.dtype)
    band_pairs = zip(laplacian_pyramid(pred, kernel, max_levels),
                     laplacian_pyramid(true, kernel, max_levels))
    total = 0
    for level, (pred_band, true_band) in enumerate(band_pairs):
        total += (2 ** level) * F.l1_loss(pred_band, true_band)
    return total / max_levels
def laplacian_pyramid(img, kernel, max_levels):
    """Decompose a (B, C, H, W) image into ``max_levels`` band-pass images.

    At each level the current image is cropped to even height/width, blurred
    and decimated by 2, then re-expanded. The residual (image minus its
    re-expansion) is recorded, and the decimated image feeds the next level.
    The final low-pass image is not included in the returned list.
    """
    bands = []
    level_img = img
    for _ in range(max_levels):
        level_img = crop_to_even_size(level_img)
        coarse = downsample(level_img, kernel)
        bands.append(level_img - upsample(coarse, kernel))
        level_img = coarse
    return bands
def gauss_kernel(device='cpu', dtype=torch.float32):
    """Return the normalized 5x5 binomial blur kernel, shaped (1, 1, 5, 5).

    The kernel is the outer product of the binomial row [1, 4, 6, 4, 1] with
    itself. It is divided by 256 (= 16 * 16), so its entries sum to 1. The
    leading singleton dims make it usable directly as a single-channel
    ``F.conv2d`` weight.
    """
    row = torch.tensor([1, 4, 6, 4, 1], device=device, dtype=dtype)
    kernel = row[:, None] * row[None, :]
    kernel /= 256
    return kernel.reshape(1, 1, 5, 5)
def gauss_convolution(img, kernel):
    """Filter every channel of a (B, C, H, W) batch independently with ``kernel``.

    Args:
        img: Tensor of shape (B, C, H, W).
        kernel: Tensor of shape (1, 1, kH, kW) with odd kH and kW, e.g. the
            5x5 kernel returned by ``gauss_kernel``.

    Returns:
        Tensor of shape (B, C, H, W). Borders use reflect padding of half
        the kernel size on each side, so the spatial size is preserved.

    Raises:
        ValueError: if the kernel height or width is even; the output could
            not then keep the input size.
    """
    B, C, H, W = img.shape
    kh, kw = kernel.shape[-2:]
    if kh % 2 == 0 or kw % 2 == 0:
        raise ValueError(f'kernel size must be odd, got {kh}x{kw}')
    pad_h, pad_w = kh // 2, kw // 2
    # Fold channels into the batch dim so a single-channel kernel filters
    # every channel separately.
    img = img.reshape(B * C, 1, H, W)
    # F.pad takes (left, right, top, bottom) for the last two dims.
    img = F.pad(img, (pad_w, pad_w, pad_h, pad_h), mode='reflect')
    img = F.conv2d(img, kernel)
    return img.reshape(B, C, H, W)
def downsample(img, kernel):
    """Blur a (B, C, H, W) image with ``kernel``, then keep every second row and column."""
    blurred = gauss_convolution(img, kernel)
    return blurred[:, :, ::2, ::2]
def upsample(img, kernel):
    """Double the height and width of a (B, C, H, W) image by zero-insertion plus blur.

    Input pixels are written to the even (row, col) positions of a zero
    canvas twice as tall and wide. They are multiplied by 4 because only one
    canvas pixel in four is non-zero. This roughly preserves mean intensity
    after blurring with a kernel whose entries sum to 1, such as the one
    from ``gauss_kernel``.
    """
    batch, channels, height, width = img.shape
    canvas = img.new_zeros((batch, channels, 2 * height, 2 * width))
    canvas[:, :, ::2, ::2] = 4 * img
    return gauss_convolution(canvas, kernel)
def crop_to_even_size(img):
    """Drop a trailing row and/or column so H and W of a (B, C, H, W) tensor are even."""
    height, width = img.shape[2:]
    return img[:, :, : (height // 2) * 2, : (width // 2) * 2]
|