|
import random |
|
|
|
import cv2 |
|
import einops |
|
import numpy as np |
|
import torch |
|
from pytorch_lightning import seed_everything |
|
|
|
from utils.data import HWC3, apply_color, resize_image |
|
from utils.ddim import DDIMSampler |
|
from utils.model import create_model, load_state_dict |
|
|
|
# Paths of the model definition and of the trained weights to restore.
_MODEL_CONFIG_PATH = './models/cldm_v21.yaml'
_CHECKPOINT_PATH = 'lightning_logs/version_6/checkpoints/colorizenet-sd21.ckpt'

# Instantiate the model from its YAML config on the CPU, load the checkpoint
# weights into it, and only then move the populated model to the GPU.
model = create_model(_MODEL_CONFIG_PATH).cpu()
# The checkpoint is passed inline (not bound to a name) so the loaded state
# dict can be released as soon as the weights have been copied into the model.
# NOTE(review): `location='cuda'` presumably maps checkpoint tensors to the
# GPU while loading -- confirm against utils.model.load_state_dict.
model.load_state_dict(load_state_dict(_CHECKPOINT_PATH, location='cuda'))
model = model.cuda()

# Sampler object used further down to run the DDIM sampling loop.
ddim_sampler = DDIMSampler(model)
|
|
|
|
|
# Read the image to colorize and turn it into the control tensor that
# conditions the sampler.
_INPUT_IMAGE_PATH = "sample_data/sample1_bw.jpg"

input_image = cv2.imread(_INPUT_IMAGE_PATH)
if input_image is None:
    # cv2.imread reports a missing/unreadable file by returning None instead
    # of raising; fail here with a clear message rather than deep inside the
    # preprocessing helpers.
    raise FileNotFoundError(f"Could not read input image: {_INPUT_IMAGE_PATH}")

# NOTE(review): HWC3 presumably normalises the array to height x width x 3
# channels -- confirm against utils.data.
input_image = HWC3(input_image)
img = resize_image(input_image, resolution=512)
H, W, C = img.shape

num_samples = 1

# Convert to a float tensor on the GPU and divide by 255
# (assumes 8-bit pixel values, giving a [0, 1] range -- TODO confirm).
control = torch.from_numpy(img.copy()).float().cuda() / 255.0
# Repeat the same image once per requested sample along a new batch axis.
control = torch.stack([control] * num_samples, dim=0)
# Reorder from channels-last (b h w c) to channels-first (b c h w).
control = einops.rearrange(control, 'b h w c -> b c h w').clone()
|
|
|
|
|
|
|
# --- Sampling configuration -------------------------------------------------

# Text prompts: `prompt` feeds the conditional branch, `n_prompt` the
# unconditional one (see the conditioning dicts built below).
prompt, n_prompt = "Colorize this image", ""

# When True, the unconditional branch gets no control image and the control
# scales decay geometrically instead of being uniform.
guess_mode = False
# Multiplier applied to every entry of model.control_scales.
strength = 1.0

# Number of DDIM steps and the `eta` value handed to the sampler.
ddim_steps, eta = 20, 0.0
# Passed to the sampler as `unconditional_guidance_scale`.
scale = 9.0

# Fix the random seed before any sampling takes place.
seed = 1294574436
seed_everything(seed)
|
|
|
# Build the conditional / unconditional inputs and run the DDIM sampler.
# This is inference only, so autograd is disabled for the whole section to
# avoid recording a graph for the prompt encodings and the sampling loop.
with torch.no_grad():
    # Conditional branch: control image plus the encoded positive prompt.
    cond = {
        "c_concat": [control],
        "c_crossattn": [
            model.get_learned_conditioning([prompt] * num_samples)],
    }
    # Unconditional branch: encoded negative prompt; the control image is
    # dropped only in guess mode.
    un_cond = {
        "c_concat": None if guess_mode else [control],
        "c_crossattn": [
            model.get_learned_conditioning([n_prompt] * num_samples)],
    }
    # Per-sample shape requested from the sampler: 4 channels at 1/8 of the
    # image height and width (presumably the latent resolution -- confirm).
    shape = (4, H // 8, W // 8)

    # 13 control scales: geometrically increasing up to `strength` in guess
    # mode, otherwise the same `strength` for every entry.
    if guess_mode:
        model.control_scales = [
            strength * (0.825 ** float(12 - i)) for i in range(13)]
    else:
        model.control_scales = [strength] * 13

    samples, intermediates = ddim_sampler.sample(
        ddim_steps, num_samples, shape, cond,
        verbose=False, eta=eta,
        unconditional_guidance_scale=scale,
        unconditional_conditioning=un_cond,
    )
|
|
|
# Decode the sampled tensors to images, recolor the input with them, and
# write one JPEG per sample.
# Decoding is inference only; disabling autograd also guarantees the result
# can be converted with .numpy() below.
with torch.no_grad():
    x_samples = model.decode_first_stage(samples)

# Channels-first -> channels-last, then map to 8-bit pixel values
# (x * 127.5 + 127.5, clipped to [0, 255]; assumes the decoder output is
# roughly in [-1, 1] -- TODO confirm).
x_samples = einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5
x_samples = x_samples.cpu().numpy().clip(0, 255).astype(np.uint8)

results = [x_samples[i] for i in range(num_samples)]
colored_results = [apply_color(img, result) for result in results]

# Plain loop (not a comprehension) because this is done purely for its side
# effect; cv2.imwrite returns False on failure, which must not go unnoticed.
for i, result in enumerate(colored_results):
    output_path = f"colorized_{i}.jpg"
    # The arrays are treated as RGB here; OpenCV writes BGR, hence the swap.
    if not cv2.imwrite(output_path, cv2.cvtColor(result, cv2.COLOR_RGB2BGR)):
        raise OSError(f"Failed to write output image: {output_path}")