|
import argparse |
|
import logging |
|
import os |
|
import pdb |
|
|
|
from peft import LoraConfig, get_peft_model |
|
import torch |
|
from safetensors.torch import load_model, save_model |
|
from marigold.marigold_inpaint_pipeline import MarigoldInpaintPipeline |
|
from marigold.duplicate_unet import DoubleUNet2DConditionModel |
|
import json |
|
from depth_anything_v2.dpt import DepthAnythingV2 |
|
from torchvision.transforms.functional import pil_to_tensor |
|
from PIL import Image |
|
import random |
|
import numpy as np |
|
from pycocotools import mask as coco_mask |
|
from diffusers.schedulers import DDIMScheduler, PNDMScheduler |
|
from torchvision.transforms import InterpolationMode, Resize, CenterCrop |
|
import torchvision.transforms as transforms |
|
|
|
# Build the inpainting pipeline from the Stable Diffusion 2 weights, then
# replace its UNet with a DoubleUNet2DConditionModel constructed from the
# UNet config JSON of the locally cached snapshot.
model = MarigoldInpaintPipeline.from_pretrained('stabilityai/stable-diffusion-2')
unet_config_path = '/home/aiops/wangzh/.cache/huggingface/hub/models--stabilityai--stable-diffusion-2/snapshots/1e128c8891e52218b74cde8f26dbfc701cb99d79/unet/config.json'

# Read the config inside a context manager so the file handle is closed
# deterministically instead of being left for the garbage collector.
with open(unet_config_path) as unet_config_file:
    unet_config = json.load(unet_config_file)
model.unet = DoubleUNet2DConditionModel(**unet_config)

# The original statement order is preserved: the config is switched to 13
# input channels first, then the UNet helper methods run in sequence.
# NOTE(review): what duplicate_model / inpaint_*_conv_in actually do is
# defined in marigold.duplicate_unet and is not visible from this script.
model.unet.config["in_channels"] = 13
model.unet.duplicate_model()
model.unet.inpaint_rgb_conv_in()
model.unet.inpaint_depth_conv_in()
|
|
|
# Wrap the UNet with LoRA adapters (rank 128, alpha 128, gaussian init) on
# every module whose name matches one of the listed targets.  The wrapped
# model replaces model.unet, so later state-dict loading sees the PEFT
# parameter names.
unet_lora_config = LoraConfig(
    target_modules=['to_k', 'to_q', 'to_v', 'to_out.0'],
    init_lora_weights="gaussian",
    lora_alpha=128,
    r=128,
)
model.unet = get_peft_model(model.unet, unet_lora_config)
|
|
|
# Restore the fine-tuned UNet weights on the CPU first, then move the whole
# pipeline to the GPU.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code; only point this at trusted checkpoints (consider weights_only=True
# if the installed torch version supports it).
sd2inpaint_ckpt_path = '/home/aiops/wangzh/marigold/output/512-inpaint-0.5-128-vitl-partition/checkpoint/latest/pytorch_model.bin'
sd2inpaint_ckpt = torch.load(sd2inpaint_ckpt_path, map_location='cpu')
model.unet.load_state_dict(sd2inpaint_ckpt)
model.to('cuda')
|
|
|
# Constructor keyword arguments for each DepthAnythingV2 encoder variant,
# keyed by encoder name.  Each row is (encoder, features, out_channels).
model_configs = {
    encoder: {'encoder': encoder, 'features': features, 'out_channels': out_channels}
    for encoder, features, out_channels in (
        ('vits', 64, [48, 96, 192, 384]),
        ('vitb', 128, [96, 192, 384, 768]),
        ('vitl', 256, [256, 512, 1024, 1024]),
        ('vitg', 384, [1536, 1536, 1536, 1536]),
    )
}
|
|
|
# Attach one DDIM scheduler per branch.  Each attribute gets its own freshly
# loaded instance (RGB first, then depth), so the two never share state.
for scheduler_attr in ('rgb_scheduler', 'depth_scheduler'):
    setattr(
        model,
        scheduler_attr,
        DDIMScheduler.from_pretrained("stabilityai/stable-diffusion-2", subfolder="scheduler"),
    )
|
|
|
# Instantiate the ViT-L variant of DepthAnythingV2, load its checkpoint on
# the CPU, then move it to the GPU and switch it to eval mode.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code; only point this at trusted checkpoints (consider weights_only=True
# if the installed torch version supports it).
depth_model = DepthAnythingV2(**model_configs['vitl'])
depth_checkpoint = torch.load(
    '/home/aiops/wangzh/Depth-Anything-V2/checkpoints/depth_anything_v2_vitl.pth',
    map_location='cpu',
)
depth_model.load_state_dict(depth_checkpoint)
depth_model = depth_model.to('cuda')
depth_model = depth_model.eval()
|
|
|
# Demo inputs: each image is paired with the text prompt used for it.
# image_path and prompt are parallel lists derived from the pairs, so the
# i-th prompt always belongs to the i-th image.
_inpaint_samples = (
    ('/dataset/~sa-1b/data/sa_001000/sa_10000335.jpg',
     'A white car is parked in front of the factory'),
    ('/dataset/~sa-1b/data/sa_000357/sa_3572319.jpg',
     'church with cemetery next to it'),
    ('/dataset/~sa-1b/data/sa_000045/sa_457934.jpg',
     'A house with a red brick roof'),
)
image_path = [sample_path for sample_path, _ in _inpaint_samples]
prompt = [sample_prompt for _, sample_prompt in _inpaint_samples]
|
|
|
# Load every input image as a tensor via torchvision's pil_to_tensor.
imgs = [pil_to_tensor(Image.open(p)) for p in image_path]

# Predict one depth map per image.  This is pure inference, so autograd is
# disabled: otherwise each forward pass keeps its computation graph alive on
# the GPU for the full-resolution inputs.
# NOTE(review): the model is called with a numpy array (batched via
# unsqueeze(0)); confirm that the local depth_anything_v2 forward() accepts
# numpy input rather than a tensor already on the model's device.
with torch.no_grad():
    depth_imgs = [depth_model(img.unsqueeze(0).cpu().numpy()) for img in imgs]
|
|
|
# Build one inpainting mask per image by summing a random selection of the
# segmentation annotations stored next to the image (same path, .json
# extension).
masks = []
for rgb_path in image_path:
    # Context manager closes the annotation file as soon as it is parsed.
    with open(rgb_path.replace('.jpg', '.json')) as anno_file:
        anno = json.load(anno_file)['annotations']

    # Shuffle first, then draw the object count, so the random-call order
    # (and therefore any seeded result) matches the original script.
    random.shuffle(anno)
    object_num = random.randint(5, 10)

    # NOTE(review): anno[0] seeds the mask and is also the first element of
    # the slice below, so it is accumulated twice; an empty annotation list
    # raises IndexError here.  Both behaviours are kept from the original --
    # confirm whether a zero-initialised mask was intended.
    mask = np.array(coco_mask.decode(anno[0]['segmentation']), dtype=np.uint8)

    # Slicing already clamps to the list length, so at most object_num
    # annotations (or all of them when fewer exist) are accumulated.
    for single_anno in anno[:object_num]:
        mask += np.array(coco_mask.decode(single_anno['segmentation']), dtype=np.uint8)

    # Scale the accumulated counts by 3 and add a leading dimension.
    mask = torch.stack([torch.tensor(mask) * 3], dim=0)
    masks.append(mask)
|
|
|
|
|
|
|
|
|
|
|
# Bring images, depth maps and masks to a common 512x512 resolution using
# nearest-exact interpolation.  Depth maps and masks receive an extra
# leading dimension before they are resized; images are resized as loaded.
resize_transform = Resize(
    interpolation=InterpolationMode.NEAREST_EXACT,
    size=[512, 512],
)
imgs = list(map(resize_transform, imgs))
depth_imgs = [resize_transform(depth_map.unsqueeze(0)) for depth_map in depth_imgs]
masks = [resize_transform(region_mask.unsqueeze(0)) for region_mask in masks]
|
|
|
|
|
# Run joint RGB-D inpainting for every sample and save each result as
# ./joint-<index>.jpg.  The four per-sample lists are walked in lockstep
# instead of being indexed by position.
for i, (rgb_input, depth_input, mask_input, text_prompt) in enumerate(
    zip(imgs, depth_imgs, masks, prompt)
):
    output_image = model._rgbd_inpaint(
        rgb_input,
        depth_input.unsqueeze(0),  # extra leading dimension, as in the original call
        mask_input,
        [text_prompt],  # the prompt is passed as a one-element list
        processing_res=512,
        guidance_scale=3,
        mode='joint_inpaint',
    )
    output_image.save(f'./joint-{i}.jpg')