File size: 3,785 Bytes
83d8d3c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
from typing import Tuple

import numpy as np
import torch
import torch.nn.functional as F


class SoftErosion(torch.nn.Module):
    """Soften (feather) a mask by smoothing it with a cone-shaped kernel.

    The kernel weight at each tap is ``max_distance - distance_from_centre``,
    normalised to sum to 1, so weights fall off linearly from the centre and
    the corner taps are 0.

    After smoothing, values ``>= threshold`` are set to exactly 1.0. All other
    values are divided by their maximum (plus a small epsilon), producing a
    soft edge in ``[0, 1)``.

    Args:
        kernel_size: Side length of the square kernel; padding is
            ``kernel_size // 2``.
        threshold: Smoothed values at or above this become 1.0.
        iterations: Number of smoothing passes. The first ``iterations - 1``
            passes keep the element-wise minimum of the current tensor and its
            smoothed version. One final plain smoothing pass always runs.
    """

    def __init__(self, kernel_size: int = 15, threshold: float = 0.6, iterations: int = 1):
        super().__init__()
        r = kernel_size // 2
        self.padding = r
        self.iterations = iterations
        self.threshold = threshold

        # Distance of every kernel tap from the centre. Broadcasting a row
        # vector against a column vector gives dist[i, j] = sqrt((j-r)^2 + (i-r)^2),
        # exactly what the ij-indexed meshgrid produced, without the
        # `torch.meshgrid` "indexing" deprecation warning.
        coords = torch.arange(0.0, kernel_size) - r
        dist = torch.sqrt(coords.view(1, -1) ** 2 + coords.view(-1, 1) ** 2)
        kernel = dist.max() - dist
        kernel /= kernel.sum()
        kernel = kernel.view(1, 1, *kernel.shape)
        # The buffer name "weight" is kept so previously saved state dicts still load.
        self.register_buffer("weight", kernel)

    def _smooth(self, x: torch.Tensor) -> torch.Tensor:
        """Convolve each channel of ``x`` independently with the kernel."""
        channels = x.shape[1]
        # Depthwise conv needs one (1, k, k) filter per channel: (C, 1, k, k).
        weight = self.weight.repeat(channels, 1, 1, 1)
        return F.conv2d(x, weight=weight, groups=channels, padding=self.padding)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Smooth ``x`` and return ``(soft_mask, hard_mask)``.

        Args:
            x: Float tensor indexed as (N, C, H, W); dim 1 is treated as the
                channel dimension.

        Returns:
            A tuple of the softened tensor and the boolean tensor marking the
            positions whose smoothed value was ``>= threshold``.
        """
        for _ in range(self.iterations - 1):
            x = torch.min(x, self._smooth(x))
        x = self._smooth(x)

        mask = x >= self.threshold
        x[mask] = 1.0

        below = ~mask
        # If every value reached the threshold there is nothing left to
        # normalise, and `.max()` of an empty tensor would raise.
        if below.any():
            # The epsilon avoids a division by zero (NaNs) when all remaining values are 0.
            x[below] /= x[below].max() + 1e-7

        return x, mask


def encode_segmentation_rgb(segmentation: np.ndarray, no_neck: bool = True) -> np.ndarray:
    """Encode a face-parsing label map as a two-channel face/mouth mask.

    Label ids follow the convention of
    https://github.com/zllrunning/face-parsing.PyTorch/blob/master/prepropess_data.py

    Args:
        segmentation: 2-D array of integer class labels, indexed (H, W).
        no_neck: If True, the face channel uses ids
            [1, 2, 3, 4, 5, 6, 10, 12, 13]; if False it additionally includes
            ids 7, 8 and 14.

    Returns:
        float64 array of shape (H, W, 2). Channel 0 is 255 where the label is
        a face-part id and channel 1 is 255 where the label is the mouth id
        (11). Every other entry is 0.
    """
    parse = segmentation
    face_part_ids = [1, 2, 3, 4, 5, 6, 10, 12, 13] if no_neck else [1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 13, 14]
    mouth_id = 11

    face_map = np.zeros(parse.shape[:2])
    mouth_map = np.zeros(parse.shape[:2])
    # One vectorised membership test instead of one np.where pass per id.
    face_map[np.isin(parse, face_part_ids)] = 255
    mouth_map[parse == mouth_id] = 255
    return np.stack([face_map, mouth_map], axis=2)


def encode_segmentation_rgb_batch(segmentation: torch.Tensor, no_neck: bool = True) -> torch.Tensor:
    """Torch/batched variant of the face/mouth label encoding.

    Label ids follow the convention of
    https://github.com/zllrunning/face-parsing.PyTorch/blob/master/prepropess_data.py

    Args:
        segmentation: Label tensor; it is cast to int32 before comparison.
            NOTE(review): the two maps are concatenated along dim 1, so this
            presumably expects (N, 1, H, W) — confirm against callers.
        no_neck: If True, the face map uses ids [1, 2, 3, 4, 5, 6, 10, 12, 13];
            if False it additionally includes ids 7, 8 and 14.

    Returns:
        int32 tensor: the face map followed by the mouth map (id 11),
        concatenated along dim 1, with 255 inside each region and 0 elsewhere.
    """
    labels = segmentation.int()
    if no_neck:
        part_ids = [1, 2, 3, 4, 5, 6, 10, 12, 13]
    else:
        part_ids = [1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 13, 14]

    # Accumulate "pixel belongs to any face part" as a boolean mask.
    in_face = torch.zeros_like(labels, dtype=torch.bool)
    for part_id in part_ids:
        in_face |= labels == part_id
    in_mouth = labels == 11

    # bool -> int32, then scale to 0/255 (same dtype as the cast labels).
    face_channel = in_face.int() * 255
    mouth_channel = in_mouth.int() * 255
    return torch.cat([face_channel, mouth_channel], dim=1)


def postprocess(
    swapped_face: np.ndarray,
    target: np.ndarray,
    target_mask: np.ndarray,
    smooth_mask: torch.nn.Module,
) -> np.ndarray:
    """Blend ``swapped_face`` into ``target`` through a smoothed face mask.

    Args:
        swapped_face: Image used where the soft mask is 1. Presumably
            (H, W, C), same size as ``target`` — the mask is broadcast over
            the last axis.
        target: Image used where the soft mask is 0.
        target_mask: Array indexed (H, W, K) with K >= 2 and values in
            [0, 255]. Channels 0 and 1 are scaled by 1/255 and summed to form
            the face mask.
        smooth_mask: Module called with a (1, 1, H, W) float tensor that
            returns a ``(soft_mask, _)`` tuple.

    Returns:
        ``swapped_face * m + target * (1 - m)`` with its last axis reversed.
        The reversal presumably converts RGB <-> BGR — confirm with callers.
        The result is a negative-stride view.
    """
    # Run on the smoothing module's own device instead of hard-coding CUDA,
    # so CPU-only setups work. Modules without parameters or buffers fall
    # back to CUDA when available (the previous behaviour), else CPU.
    reference = next(smooth_mask.buffers(), None)
    if reference is None:
        reference = next(smooth_mask.parameters(), None)
    if reference is not None:
        device = reference.device
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    mask_tensor = torch.from_numpy(target_mask.copy().transpose((2, 0, 1))).float().mul_(1 / 255.0).to(device)
    face_mask_tensor = mask_tensor[0] + mask_tensor[1]

    # Inference only: no autograd graph, so `.numpy()` below cannot fail.
    with torch.no_grad():
        soft_face_mask_tensor, _ = smooth_mask(face_mask_tensor.unsqueeze(0).unsqueeze(0))

    # Index out the batch/channel dims explicitly. `squeeze_()` would also
    # drop H or W when either of them is 1.
    soft_face_mask = soft_face_mask_tensor[0, 0].cpu().numpy()
    soft_face_mask = soft_face_mask[:, :, np.newaxis]

    result = swapped_face * soft_face_mask + target * (1 - soft_face_mask)
    return result[:, :, ::-1]