from argparse import ArgumentParser
import sys
import os

# Allow repo-local modules (imported below) to be found when this script is run directly.
sys.path.append('..')
sys.path.append('.')

import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler

from vit.vision_transformer import VisionTransformer as ViT
from vit.vit_triplane import ViTTriplane
from guided_diffusion import dist_util, logger

import click
import dnnlib

SEED = 42
BATCH_SIZE = 8
NUM_EPOCHS = 1
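
# This script expects LOCAL_RANK in the environment, as set by torchrun.
# Typical single-node launch (the script filename and NUM_GPUS are placeholders):
#   torchrun --nproc_per_node=NUM_GPUS train_vit_triplane_ddp.py --cfg ffhq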


class YourDataset(Dataset):
    """Placeholder dataset; replace with a real data source before training."""

    def __init__(self):
        pass
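

# Illustrative sketch (not used below): how a real dataset could be sharded
# across ranks with the DistributedSampler / DataLoader imports above; the
# num_workers value is an arbitrary placeholder.
def build_dataloader(dataset: Dataset, rank: int) -> DataLoader:
    sampler = DistributedSampler(dataset, rank=rank, shuffle=True, seed=SEED)
    return DataLoader(dataset,
                      batch_size=BATCH_SIZE,
                      sampler=sampler,
                      num_workers=4,
                      pin_memory=True,
                      drop_last=True)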


@click.command()
@click.option('--cfg', help='Base configuration', type=str, default='ffhq')
@click.option('--sr-module',
              help='Superresolution module override',
              metavar='STR',
              required=False,
              default=None)
@click.option('--density_reg',
              help='Density regularization strength.',
              metavar='FLOAT',
              type=click.FloatRange(min=0),
              default=0.25,
              required=False,
              show_default=True)
@click.option('--density_reg_every',
              help='Apply lazy density regularization every N iterations.',
              metavar='INT',
              type=click.FloatRange(min=1),
              default=4,
              required=False,
              show_default=True)
@click.option('--density_reg_p_dist',
              help='Perturbation distance for density regularization.',
              metavar='FLOAT',
              type=click.FloatRange(min=0),
              default=0.004,
              required=False,
              show_default=True)
@click.option('--reg_type',
              help='Type of regularization',
              metavar='STR',
              type=click.Choice([
                  'l1', 'l1-alt', 'monotonic-detach', 'monotonic-fixed',
                  'total-variation'
              ]),
              required=False,
              default='l1')
@click.option('--decoder_lr_mul',
              help='Decoder learning rate multiplier.',
              metavar='FLOAT',
              type=click.FloatRange(min=0),
              default=1,
              required=False,
              show_default=True)
@click.option('--c_scale',
              help='Scale factor for generator pose conditioning.',
              metavar='FLOAT',
              type=click.FloatRange(min=0),
              required=False,
              default=1)
def main(**kwargs):
    """Build rendering options from the CLI flags and run a DDP smoke-test forward pass of ViTTriplane."""
    opts = dnnlib.EasyDict(kwargs)  # command-line arguments
    c = dnnlib.EasyDict()  # main config dict

    rendering_options = {
        'image_resolution': 256,
        'disparity_space_sampling': False,
        'clamp_mode': 'softplus',
        'c_scale': opts.c_scale,  # scale factor for generator pose conditioning
        'density_reg': opts.density_reg,  # density regularization strength
        'density_reg_p_dist': opts.density_reg_p_dist,  # perturbation distance for density regularization
        'reg_type': opts.reg_type,  # density regularization variant
        'decoder_lr_mul': opts.decoder_lr_mul,  # decoder learning rate multiplier
        'sr_antialias': True,
        'return_triplane_features': True,
        'return_sampling_details_flag': True,
    }

    if opts.cfg == 'ffhq':
        rendering_options.update({
            'focal': 2985.29 / 700,
            'depth_resolution': 36,
            'depth_resolution_importance': 36,
            'ray_start': 2.25,
            'ray_end': 3.3,
            'box_warp': 1,
            'avg_camera_radius': 2.7,
            'avg_camera_pivot': [0, 0, 0.2],
        })
    elif opts.cfg == 'afhq':
        rendering_options.update({
            'focal': 4.2647,
            'depth_resolution': 48,
            'depth_resolution_importance': 48,
            'ray_start': 2.25,
            'ray_end': 3.3,
            'box_warp': 1,
            'avg_camera_radius': 2.7,
            'avg_camera_pivot': [0, 0, -0.06],
        })
    elif opts.cfg == 'shapenet':
        rendering_options.update({
            'depth_resolution': 64,
            'depth_resolution_importance': 64,
            'ray_start': 0.1,
            'ray_end': 3.3,
            'box_warp': 1.6,
            'white_back': True,
            'avg_camera_radius': 1.7,
            'avg_camera_pivot': [0, 0, 0],
        })
    else:
        raise ValueError(f'Unknown --cfg option: {opts.cfg}')

    c.rendering_kwargs = rendering_options

    args = opts

    # LOCAL_RANK is injected by the launcher; rank 0 acts as the master process.
    args.local_rank = int(os.environ["LOCAL_RANK"])
    args.is_master = args.local_rank == 0

    device = torch.device(f"cuda:{args.local_rank}")

    # Single-node assumption: the local rank doubles as the global rank and the
    # world size equals the number of visible GPUs on this machine.
    dist.init_process_group(backend='nccl',
                            init_method='env://',
                            rank=args.local_rank,
                            world_size=torch.cuda.device_count())
    print(f"{args.local_rank=} init complete")
    torch.cuda.set_device(args.local_rank)

    # Seed both CPU and CUDA RNGs so every rank starts from the same random state.
    torch.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)

    model = ViTTriplane(
        img_size=[224],
        patch_size=16,
        in_chans=384,
        num_classes=0,
        embed_dim=384,
        depth=2,
        num_heads=16,
        mlp_ratio=4.,
        qkv_bias=False,
        qk_scale=None,
        drop_rate=0.1,
        attn_drop_rate=0.,
        drop_path_rate=0.,
        norm_layer=nn.LayerNorm,
        out_chans=96,
        c_dim=25,
        img_resolution=128,
        img_channels=3,
        cls_token=False,
        rendering_kwargs=c.rendering_kwargs,
    )

    model = model.to(device)

    model = DDP(model,
                device_ids=[args.local_rank],
                output_device=args.local_rank)

    # Synchronize parameters across ranks so every replica starts from identical weights.
    dist_util.sync_params(model.named_parameters())
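
    # Hypothetical wiring (not exercised by the smoke test below): with the
    # build_dataloader() sketch defined above, a full run would construct the
    # loader and an optimizer here, e.g.
    #
    #     loader = build_dataloader(YourDataset(), args.local_rank)
    #     optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    #
    # and call loader.sampler.set_epoch(epoch) at the start of each epoch so
    # every rank reshuffles its shard.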
    for epoch in range(NUM_EPOCHS):
        model.train()

        dist.barrier()  # keep all ranks in step before the forward pass

        # Smoke test: one forward pass on random patch tokens (14 x 14 tokens of
        # dim 384, matching img_size=224 / patch_size=16 and embed_dim=384) with
        # a zero conditioning label of size c_dim=25.
        noise = torch.randn(1, 14 * 14, 384).to(device)
        img = model(noise, torch.zeros(1, 25).to(device))
        print(img['image'].shape)


if __name__ == '__main__':
    main()