import os
import time
from dataclasses import dataclass
from typing import Optional, Tuple

from configs.mode import FaceSwapMode
from configs.singleton import Singleton


@Singleton
@dataclass
class TrainConfig:
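    """Global training configuration for the HiFiFace model.

    The @Singleton wrapper is assumed to cache the first instance, so every
    ``TrainConfig()`` call throughout the codebase sees the same settings.
    """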
    mode: FaceSwapMode = FaceSwapMode.MANY_TO_MANY
    source_name: str = ""

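    # Paths to the pickled dataset index and the dataset root directory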
    dataset_index: str = "/data/dataset/faceswap/full.pkl"
    dataset_root: str = "/data/dataset/faceswap"

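    # Loader settings and optimization hyperparameters; same_rate is presumably
    # the fraction of sampled pairs where source and target share the same identity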
    batch_size: int = 8
    num_threads: int = 8
    same_rate: float = 0.5
    lr: float = 5e-5
    grad_clip: float = 1000.0

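    # Multi-GPU training via DistributedDataParallel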
    use_ddp: bool = True

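    # Region-specific options; the *_hm_loss flags presumably toggle
    # heatmap-based losses around the eyes and mouth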
    mouth_mask: bool = True
    eye_hm_loss: bool = False
    mouth_hm_loss: bool = False

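    # Optional (checkpoint path, iteration) pair to resume from; None starts from scratch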
    load_checkpoint: Optional[Tuple[str, int]] = None  # ("/data/checkpoints/hififace/rebuilt_discriminator_SFF_c256_1683367464544", 400000)

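    # Pretrained weights used by the identity/shape extractor: Deep3DFaceRecon
    # (3D face reconstruction), ArcFace (identity embedding), BFM (3D morphable
    # model data), and an HRNet 98-landmark detector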
    identity_extractor_config = {
        "f_3d_checkpoint_path": "/checkpoints/Deep3DFaceRecon/epoch_20_new.pth",
        "f_id_checkpoint_path": "/checkpoints/arcface/ms1mv3_arcface_r100_fp16_backbone.pth",
        "bfm_folder": "/checkpoints/useful_ckpt/BFM",
        "hrnet_path": "/checkpoints/useful_ckpt/face_98lmks/HR18-WFLW.pth",
    }

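    # Iteration schedule for visualization, loss plotting, checkpointing, and total training length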
    visualize_interval: int = 100
    plot_interval: int = 100
    max_iters: int = 1000000
    checkpoint_interval: int = 40000

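    # Experiment name and base directories for logs and checkpoints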
    exp_name: str = "exp_base"
    log_basedir: str = "/data/logs/hififace/"
    checkpoint_basedir: str = "/data/checkpoints/hififace"

    def __post_init__(self):
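        # Derive per-run output directories, tagged with a millisecond timestamp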
        time_stamp = int(time.time() * 1000)
        self.log_dir = os.path.join(self.log_basedir, f"{self.exp_name}_{time_stamp}")
        self.checkpoint_dir = os.path.join(self.checkpoint_basedir, f"{self.exp_name}_{time_stamp}")


if __name__ == "__main__":
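    # Minimal sanity check: build the config and print its derived log directory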
    tc = TrainConfig()
    print(tc.log_dir)