File size: 2,484 Bytes
83d8d3c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import os
from pathlib import Path

import cv2
import torch
from model import BiSeNet
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from tqdm import tqdm

# For BiSeNet and for official_224 SimSwap


class MaskDataset(Dataset):
    """Pair every ``.jpg`` image under a root with its mask output path.

    Each item is a dict holding the normalized image tensor (``"img"``) and
    the path (``"mask_path"``) under ``mask_root`` that mirrors the image's
    location relative to ``img_root``.
    """

    def __init__(self, img_root, mask_root):
        """Index the images and derive the matching mask paths.

        Args:
            img_root: Directory searched recursively for ``.jpg`` files.
            mask_root: Directory under which the mirrored mask paths are built.
        """
        # Sorted so the index -> file mapping is stable between runs.
        self.img_files = sorted(Path(img_root).glob("**/*.jpg"))

        # Same relative layout as the images, re-rooted at mask_root.
        self.mask_files = []
        for img_path in self.img_files:
            relative_path = os.path.relpath(img_path, img_root)
            self.mask_files.append(os.path.join(mask_root, relative_path))

        # PIL image -> float tensor, then per-channel normalization with the
        # standard ImageNet mean / std values.
        to_tensor = transforms.ToTensor()
        normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        self.to_tensor_normalize = transforms.Compose([to_tensor, normalize])

    def __len__(self):
        """Return the number of indexed images (one mask path per image)."""
        return len(self.mask_files)

    def __getitem__(self, index):
        """Load image ``index`` as RGB and return it with its mask path."""
        rgb_image = Image.open(self.img_files[index]).convert("RGB")
        sample = {
            "img": self.to_tensor_normalize(rgb_image),
            "mask_path": self.mask_files[index],
        }
        return sample


class MaskDataLoader:
    """Batched, iterable wrapper around a ``MaskDataset``.

    Iterating yields the batches produced by the underlying
    ``torch.utils.data.DataLoader``; ``len()`` reports how many batches one
    full pass produces.
    """

    def __init__(
        self,
        img_root="/data/dataset/face_1k/alignHQ",
        mask_root="/data/dataset/face_1k/mask",
        batch_size=8,
        num_workers=8,
    ):
        """Build the dataset and its DataLoader.

        Args:
            img_root: Directory searched for input images.
            mask_root: Directory under which mask paths are generated.
            batch_size: Number of samples per batch.
            num_workers: Number of DataLoader worker processes.
        """
        self.dataset = MaskDataset(img_root=img_root, mask_root=mask_root)

        self.dataloader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=num_workers,
            drop_last=False,
        )

    def __len__(self):
        """Return the number of batches in one pass over the dataset."""
        # len() must return an int. Delegating to the DataLoader gives
        # ceil(len(dataset) / batch_size) because drop_last is False, and
        # keeps the count in sync with the configured batch size.
        return len(self.dataloader)

    def __iter__(self):
        """Yield batches from the underlying DataLoader."""
        yield from self.dataloader


if __name__ == "__main__":
    # Generate a face-parsing mask image for every input image and write it
    # to the mirrored path supplied by the dataset.
    dataloader = MaskDataLoader()
    bisenet_path = "/data/useful_ckpt/face_parsing/parsing_model_79999_iter.pth"
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Build the parsing network and restore its pretrained weights.
    bisenet = BiSeNet(n_classes=19)
    bisenet.to(device)
    state_dict = torch.load(bisenet_path, map_location=device)
    bisenet.load_state_dict(state_dict)
    bisenet.eval()

    # Inference only: disabling autograd avoids building a graph and holding
    # activations for every batch.
    with torch.no_grad():
        for data in tqdm(dataloader):
            mask, ignore_ids = bisenet.get_mask(data["img"].to(device), 256)
            # NOTE(review): assumes get_mask returns a (N, 1, H, W) mask in
            # [0, 1] -- confirm against BiSeNet.get_mask. The mask is scaled
            # to uint8, moved to channel-last layout, and its last axis is
            # repeated 3 times so it can be saved as a 3-channel image.
            mask = (mask * 255).to(torch.uint8).cpu().numpy().transpose(0, 2, 3, 1).repeat(3, 3)

            for i in range(mask.shape[0]):
                # Samples flagged by get_mask are skipped; no file is written.
                if ignore_ids[i]:
                    continue
                path = data["mask_path"][i]
                os.makedirs(os.path.dirname(path), exist_ok=True)
                # cv2.imwrite signals failure by returning False, not raising.
                if not cv2.imwrite(path, mask[i]):
                    tqdm.write(f"Failed to write mask: {path}")