Spaces:
Runtime error
Runtime error
File size: 2,484 Bytes
83d8d3c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 |
import os
from pathlib import Path
import cv2
import torch
from model import BiSeNet
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from tqdm import tqdm
# For BiSeNet and for official_224 SimSwap
class MaskDataset(Dataset):
    """Dataset of face images for which parsing masks are to be generated.

    Every file matching ``pattern`` (recursively, ``*.jpg`` by default) under
    ``img_root`` becomes one item. The mask output path for an image mirrors
    the image's path relative to ``img_root``, re-rooted at ``mask_root``
    (same sub-directories, file name and extension).

    Args:
        img_root: Directory searched for input images.
        mask_root: Directory under which the mask output paths are built.
        pattern: Glob pattern, relative to ``img_root``, selecting images.
            Defaults to ``"**/*.jpg"``, the originally hard-coded pattern.
    """

    def __init__(self, img_root, mask_root, pattern="**/*.jpg"):
        img_dir = Path(img_root)
        # Standard ImageNet mean/std; presumably the normalisation the BiSeNet
        # face-parsing checkpoint expects -- TODO confirm.
        self.to_tensor_normalize = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ]
        )
        # Sorted so item order is deterministic regardless of filesystem order.
        self.img_files = sorted(img_dir.glob(pattern))
        # NOTE(review): mask paths keep the image's extension (.jpg by default),
        # so an extension-driven writer such as cv2.imwrite will store masks
        # with lossy JPEG compression; consider .png if exact masks matter.
        self.mask_files = [
            os.path.join(mask_root, os.path.relpath(img_path, img_root))
            for img_path in self.img_files
        ]

    def __len__(self):
        """Return the number of images found under ``img_root``."""
        return len(self.mask_files)

    def __getitem__(self, index):
        """Load one image and pair it with its mask output path.

        Returns:
            dict with ``"img"``: the image converted to RGB, passed through
            ``ToTensor`` and ``Normalize``; and ``"mask_path"``: the path
            string where this image's mask should be written.
        """
        # Context manager closes the file handle deterministically instead of
        # relying on garbage collection of the lazily-opened image.
        with Image.open(self.img_files[index]) as img:
            rgb = img.convert("RGB")
        return {"img": self.to_tensor_normalize(rgb), "mask_path": self.mask_files[index]}
class MaskDataLoader:
    """Iterable wrapper around a ``DataLoader`` over ``MaskDataset``.

    Args:
        img_root: Root directory of the input face images.
        mask_root: Root directory under which mask paths are built.
        batch_size: Number of images per batch.
        num_workers: Number of ``DataLoader`` worker processes.

    All defaults reproduce the original hard-coded configuration.
    """

    def __init__(
        self,
        img_root="/data/dataset/face_1k/alignHQ",
        mask_root="/data/dataset/face_1k/mask",
        batch_size=8,
        num_workers=8,
    ):
        """Build the dataset and the underlying ``DataLoader``."""
        self.dataset = MaskDataset(img_root=img_root, mask_root=mask_root)
        # Each item carries its own "mask_path", so shuffling does not change
        # which mask is written where; it only changes processing order.
        self.dataloader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=num_workers,
            drop_last=False,
        )

    def __len__(self):
        """Return the number of batches yielded per pass.

        Delegates to the ``DataLoader`` so the result is an ``int`` (as
        ``len()`` requires) and counts the final partial batch, since
        ``drop_last=False``. The previous ``len(dataset) / 8`` returned a
        float, which made ``len()`` raise ``TypeError``.
        """
        return len(self.dataloader)

    def __iter__(self):
        """Yield batches from the underlying ``DataLoader``."""
        yield from self.dataloader
if __name__ == "__main__":
    # Generate face-parsing masks for every image in the dataset with BiSeNet.
    dataloader = MaskDataLoader()
    bisenet_path = "/data/useful_ckpt/face_parsing/parsing_model_79999_iter.pth"
    bisenet = BiSeNet(n_classes=19)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    bisenet.to(device)
    # NOTE: torch.load unpickles the checkpoint; only load trusted files.
    state_dict = torch.load(bisenet_path, map_location=device)
    bisenet.load_state_dict(state_dict)
    bisenet.eval()
    # Inference only: disable autograd so no computation graph is retained.
    with torch.no_grad():
        for data in tqdm(dataloader):
            # get_mask is defined in model.BiSeNet; presumably it returns masks
            # in [0, 1] shaped (B, 1, H, W) plus per-sample skip flags -- TODO confirm.
            mask, ignore_ids = bisenet.get_mask(data["img"].to(device), 256)
            # (B, C, H, W) -> (B, H, W, C), then repeat the channel axis 3x;
            # assumes C == 1 so the result is a 3-channel uint8 image.
            mask = (mask * 255).to(torch.uint8).cpu().numpy().transpose(0, 2, 3, 1).repeat(3, 3)
            for i in range(mask.shape[0]):
                if ignore_ids[i]:
                    continue
                path = data["mask_path"][i]
                dirname = os.path.dirname(path)
                # os.makedirs("") raises, so only create a directory when one exists.
                if dirname:
                    os.makedirs(dirname, exist_ok=True)
                # cv2.imwrite reports failure through its return value, not an exception.
                if not cv2.imwrite(path, mask[i]):
                    raise OSError(f"cv2.imwrite failed to write mask: {path}")
|