''' * Copyright (c) 2023 Salesforce, Inc. * All rights reserved. * SPDX-License-Identifier: Apache License 2.0 * For full license text, see LICENSE.txt file in the repo root or http://www.apache.org/licenses/ * By Can Qin * Modified from ControlNet repo: https://github.com/lllyasviel/ControlNet * Copyright (c) 2023 Lvmin Zhang and Maneesh Agrawala ''' import config import cv2 import einops import gradio as gr import numpy as np import torch import random import os from pytorch_lightning import seed_everything from annotator.util import resize_image, HWC3 from annotator.uniformer_base import UniformerDetector from annotator.hed import HEDdetector from annotator.canny import CannyDetector from annotator.midas import MidasDetector from annotator.outpainting import Outpainter from annotator.openpose import OpenposeDetector from annotator.inpainting import Inpainter from annotator.grayscale import GrayscaleConverter from annotator.blur import Blurrer import cvlib as cv from utils import create_model, load_state_dict from lib.ddim_hacked import DDIMSampler from safetensors.torch import load_file as stload from collections import OrderedDict apply_uniformer = UniformerDetector() apply_midas = MidasDetector() apply_canny = CannyDetector() apply_hed = HEDdetector() model_outpainting = Outpainter() apply_openpose = OpenposeDetector() model_grayscale = GrayscaleConverter() model_blur = Blurrer() model_inpainting = Inpainter() def midas(img, res): img = resize_image(HWC3(img), res) results = apply_midas(img) return results def outpainting(img, res, rand_h, rand_w): img = resize_image(HWC3(img), res) result = model_outpainting(img, rand_h, rand_w) return result def grayscale(img, res): img = resize_image(HWC3(img), res) result = model_grayscale(img) return result def blur(img, res, ksize): img = resize_image(HWC3(img), res) result = model_blur(img, ksize) return result def inpainting(img, res, rand_h, rand_h_1, rand_w, rand_w_1): img = resize_image(HWC3(img), res) result = model_inpainting(img, rand_h, rand_h_1, rand_w, rand_w_1) return result model = create_model('./models/cldm_v15_unicontrol.yaml').cpu() # model_url = 'https://huggingface.co/Robert001/UniControl-Model/resolve/main/unicontrol_v1.1.ckpt' model_url = 'https://huggingface.co/Robert001/UniControl-Model/resolve/main/unicontrol_v1.1.st' ckpts_path='./' # model_path = os.path.join(ckpts_path, "unicontrol_v1.1.ckpt") model_path = os.path.join(ckpts_path, "unicontrol_v1.1.st") if not os.path.exists(model_path): from basicsr.utils.download_util import load_file_from_url load_file_from_url(model_url, model_dir=ckpts_path) model_dict = OrderedDict(stload(model_path, device='cpu')) model.load_state_dict(model_dict, strict=False) # model.load_state_dict(load_state_dict(model_path, location='cuda'), strict=False) model = model.cuda() ddim_sampler = DDIMSampler(model) task_to_name = {'hed': 'control_hed', 'canny': 'control_canny', 'seg': 'control_seg', 'segbase': 'control_seg', 'depth': 'control_depth', 'normal': 'control_normal', 'openpose': 'control_openpose', 'bbox': 'control_bbox', 'grayscale': 'control_grayscale', 'outpainting': 'control_outpainting', 'hedsketch': 'control_hedsketch', 'inpainting': 'control_inpainting', 'blur': 'control_blur', 'grayscale': 'control_grayscale'} name_to_instruction = {"control_hed": "hed edge to image", "control_canny": "canny edge to image", "control_seg": "segmentation map to image", "control_depth": "depth map to image", "control_normal": "normal surface map to image", "control_img": "image editing", "control_openpose": "human pose skeleton to image", "control_hedsketch": "sketch to image", "control_bbox": "bounding box to image", "control_outpainting": "image outpainting", "control_grayscale": "gray image to color image", "control_blur": "deblur image to clean image", "control_inpainting": "image inpainting"} def process_canny(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, guess_mode, strength, scale, seed, eta, low_threshold, high_threshold, condition_mode): with torch.no_grad(): img = resize_image(HWC3(input_image), image_resolution) H, W, C = img.shape if condition_mode == True: detected_map = apply_canny(img, low_threshold, high_threshold) detected_map = HWC3(detected_map) else: detected_map = 255 - img control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0 control = torch.stack([control for _ in range(num_samples)], dim=0) control = einops.rearrange(control, 'b h w c -> b c h w').clone() if seed == -1: seed = random.randint(0, 65535) seed_everything(seed) if config.save_memory: model.low_vram_shift(is_diffusing=False) task = 'canny' task_dic = {} task_dic['name'] = task_to_name[task] task_instruction = name_to_instruction[task_dic['name']] task_dic['feature'] = model.get_learned_conditioning(task_instruction)[:, :1, :] cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)], "task": task_dic} un_cond = {"c_concat": None if guess_mode else [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]} shape = (4, H // 8, W // 8) if config.save_memory: model.low_vram_shift(is_diffusing=True) model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ([strength] * 13) samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples, shape, cond, verbose=False, eta=eta, unconditional_guidance_scale=scale, unconditional_conditioning=un_cond) if config.save_memory: model.low_vram_shift(is_diffusing=False) x_samples = model.decode_first_stage(samples) x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype( np.uint8) results = [x_samples[i] for i in range(num_samples)] return [255 - detected_map] + results def process_hed(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta, condition_mode): with torch.no_grad(): input_image = HWC3(input_image) img = resize_image(input_image, image_resolution) H, W, C = img.shape if condition_mode == True: detected_map = apply_hed(resize_image(input_image, detect_resolution)) detected_map = HWC3(detected_map) else: detected_map = img detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR) control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0 control = torch.stack([control for _ in range(num_samples)], dim=0) control = einops.rearrange(control, 'b h w c -> b c h w').clone() if seed == -1: seed = random.randint(0, 65535) seed_everything(seed) if config.save_memory: model.low_vram_shift(is_diffusing=False) task = 'hed' task_dic = {} task_dic['name'] = task_to_name[task] task_instruction = name_to_instruction[task_dic['name']] task_dic['feature'] = model.get_learned_conditioning(task_instruction)[:, :1, :] cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)], "task": task_dic} un_cond = {"c_concat": None if guess_mode else [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]} shape = (4, H // 8, W // 8) if config.save_memory: model.low_vram_shift(is_diffusing=True) model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ([strength] * 13) samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples, shape, cond, verbose=False, eta=eta, unconditional_guidance_scale=scale, unconditional_conditioning=un_cond) if config.save_memory: model.low_vram_shift(is_diffusing=False) x_samples = model.decode_first_stage(samples) x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype( np.uint8) results = [x_samples[i] for i in range(num_samples)] return [detected_map] + results def process_depth(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta, condition_mode): with torch.no_grad(): input_image = HWC3(input_image) img = resize_image(input_image, image_resolution) H, W, C = img.shape if condition_mode == True: detected_map, _ = apply_midas(resize_image(input_image, detect_resolution)) detected_map = HWC3(detected_map) else: detected_map = img detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR) control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0 control = torch.stack([control for _ in range(num_samples)], dim=0) control = einops.rearrange(control, 'b h w c -> b c h w').clone() if seed == -1: seed = random.randint(0, 65535) seed_everything(seed) if config.save_memory: model.low_vram_shift(is_diffusing=False) task = 'depth' task_dic = {} task_dic['name'] = task_to_name[task] task_instruction = name_to_instruction[task_dic['name']] task_dic['feature'] = model.get_learned_conditioning(task_instruction)[:, :1, :] cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)], "task": task_dic} un_cond = {"c_concat": None if guess_mode else [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]} shape = (4, H // 8, W // 8) if config.save_memory: model.low_vram_shift(is_diffusing=True) model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ( [strength] * 13) samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples, shape, cond, verbose=False, eta=eta, unconditional_guidance_scale=scale, unconditional_conditioning=un_cond) if config.save_memory: model.low_vram_shift(is_diffusing=False) x_samples = model.decode_first_stage(samples) x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype( np.uint8) results = [x_samples[i] for i in range(num_samples)] return [detected_map] + results def process_normal(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta, condition_mode): with torch.no_grad(): input_image = HWC3(input_image) img = resize_image(input_image, image_resolution) H, W, C = img.shape if condition_mode == True: _, detected_map = apply_midas(resize_image(input_image, detect_resolution)) detected_map = HWC3(detected_map) else: detected_map = img detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR) control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0 control = torch.stack([control for _ in range(num_samples)], dim=0) control = einops.rearrange(control, 'b h w c -> b c h w').clone() if seed == -1: seed = random.randint(0, 65535) seed_everything(seed) if config.save_memory: model.low_vram_shift(is_diffusing=False) task = 'normal' task_dic = {} task_dic['name'] = task_to_name[task] task_instruction = name_to_instruction[task_dic['name']] task_dic['feature'] = model.get_learned_conditioning(task_instruction)[:, :1, :] cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)], "task": task_dic} un_cond = {"c_concat": None if guess_mode else [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]} shape = (4, H // 8, W // 8) if config.save_memory: model.low_vram_shift(is_diffusing=True) model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ( [strength] * 13) samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples, shape, cond, verbose=False, eta=eta, unconditional_guidance_scale=scale, unconditional_conditioning=un_cond) if config.save_memory: model.low_vram_shift(is_diffusing=False) x_samples = model.decode_first_stage(samples) x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype( np.uint8) results = [x_samples[i] for i in range(num_samples)] return [detected_map] + results def process_pose(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta, condition_mode): with torch.no_grad(): input_image = HWC3(input_image) img = resize_image(input_image, image_resolution) H, W, C = img.shape if condition_mode == True: detected_map, _ = apply_openpose(resize_image(input_image, detect_resolution)) detected_map = HWC3(detected_map) else: detected_map = img detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_NEAREST) control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0 control = torch.stack([control for _ in range(num_samples)], dim=0) control = einops.rearrange(control, 'b h w c -> b c h w').clone() if seed == -1: seed = random.randint(0, 65535) seed_everything(seed) if config.save_memory: model.low_vram_shift(is_diffusing=False) task = 'openpose' task_dic = {} task_dic['name'] = task_to_name[task] task_instruction = name_to_instruction[task_dic['name']] task_dic['feature'] = model.get_learned_conditioning(task_instruction)[:, :1, :] cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)], "task": task_dic} un_cond = {"c_concat": None if guess_mode else [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]} shape = (4, H // 8, W // 8) if config.save_memory: model.low_vram_shift(is_diffusing=True) model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ( [strength] * 13) samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples, shape, cond, verbose=False, eta=eta, unconditional_guidance_scale=scale, unconditional_conditioning=un_cond) if config.save_memory: model.low_vram_shift(is_diffusing=False) x_samples = model.decode_first_stage(samples) x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype( np.uint8) results = [x_samples[i] for i in range(num_samples)] return [detected_map] + results def process_seg(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta, condition_mode): with torch.no_grad(): input_image = HWC3(input_image) img = resize_image(input_image, image_resolution) H, W, C = img.shape if condition_mode == True: detected_map = apply_uniformer(resize_image(input_image, detect_resolution)) else: detected_map = img detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_NEAREST) control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0 control = torch.stack([control for _ in range(num_samples)], dim=0) control = einops.rearrange(control, 'b h w c -> b c h w').clone() if seed == -1: seed = random.randint(0, 65535) seed_everything(seed) if config.save_memory: model.low_vram_shift(is_diffusing=False) task = 'seg' task_dic = {} task_dic['name'] = task_to_name[task] task_instruction = name_to_instruction[task_dic['name']] task_dic['feature'] = model.get_learned_conditioning(task_instruction)[:, :1, :] cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)], "task": task_dic} un_cond = {"c_concat": None if guess_mode else [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]} shape = (4, H // 8, W // 8) if config.save_memory: model.low_vram_shift(is_diffusing=True) model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ( [strength] * 13) samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples, shape, cond, verbose=False, eta=eta, unconditional_guidance_scale=scale, unconditional_conditioning=un_cond) if config.save_memory: model.low_vram_shift(is_diffusing=False) x_samples = model.decode_first_stage(samples) x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype( np.uint8) results = [x_samples[i] for i in range(num_samples)] return [detected_map] + results color_dict = { 'background': (0, 0, 100), 'person': (255, 0, 0), 'bicycle': (0, 255, 0), 'car': (0, 0, 255), 'motorcycle': (255, 255, 0), 'airplane': (255, 0, 255), 'bus': (0, 255, 255), 'train': (128, 128, 0), 'truck': (128, 0, 128), 'boat': (0, 128, 128), 'traffic light': (128, 128, 128), 'fire hydrant': (64, 0, 0), 'stop sign': (0, 64, 0), 'parking meter': (0, 0, 64), 'bench': (64, 64, 0), 'bird': (64, 0, 64), 'cat': (0, 64, 64), 'dog': (192, 192, 192), 'horse': (32, 32, 32), 'sheep': (96, 96, 96), 'cow': (160, 160, 160), 'elephant': (224, 224, 224), 'bear': (32, 0, 0), 'zebra': (0, 32, 0), 'giraffe': (0, 0, 32), 'backpack': (32, 32, 0), 'umbrella': (32, 0, 32), 'handbag': (0, 32, 32), 'tie': (96, 0, 0), 'suitcase': (0, 96, 0), 'frisbee': (0, 0, 96), 'skis': (96, 96, 0), 'snowboard': (96, 0, 96), 'sports ball': (0, 96, 96), 'kite': (160, 0, 0), 'baseball bat': (0, 160, 0), 'baseball glove': (0, 0, 160), 'skateboard': (160, 160, 0), 'surfboard': (160, 0, 160), 'tennis racket': (0, 160, 160), 'bottle': (224, 0, 0), 'wine glass': (0, 224, 0), 'cup': (0, 0, 224), 'fork': (224, 224, 0), 'knife': (224, 0, 224), 'spoon': (0, 224, 224), 'bowl': (64, 64, 64), 'banana': (128, 64, 64), 'apple': (64, 128, 64), 'sandwich': (64, 64, 128), 'orange': (128, 128, 64), 'broccoli': (128, 64, 128), 'carrot': (64, 128, 128), 'hot dog': (192, 64, 64), 'pizza': (64, 192, 64), 'donut': (64, 64, 192), 'cake': (192, 192, 64), 'chair': (192, 64, 192), 'couch': (64, 192, 192), 'potted plant': (96, 32, 32), 'bed': (32, 96, 32), 'dining table': (32, 32, 96), 'toilet': (96, 96, 32), 'tv': (96, 32, 96), 'laptop': (32, 96, 96), 'mouse': (160, 32, 32), 'remote': (32, 160, 32), 'keyboard': (32, 32, 160), 'cell phone': (160, 160, 32), 'microwave': (160, 32, 160), 'oven': (32, 160, 160), 'toaster': (224, 32, 32), 'sink': (32, 224, 32), 'refrigerator': (32, 32, 224), 'book': (224, 224, 32), 'clock': (224, 32, 224), 'vase': (32, 224, 224), 'scissors': (64, 96, 96), 'teddy bear': (96, 64, 96), 'hair drier': (96, 96, 64), 'toothbrush': (160, 96, 96) } def process_bbox(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, guess_mode, strength, scale, seed, eta, confidence, nms_thresh, condition_mode): with torch.no_grad(): input_image = HWC3(input_image) img = resize_image(input_image, image_resolution) H, W, C = img.shape if condition_mode == True: bbox, label, conf = cv.detect_common_objects(input_image, confidence=confidence, nms_thresh=nms_thresh) mask = np.zeros((input_image.shape), np.uint8) if len(bbox) > 0: order_area = np.zeros(len(bbox)) # order_final = np.arange(len(bbox)) area_all = 0 for idx_mask, box in enumerate(bbox): x_1, y_1, x_2, y_2 = box x_1 = 0 if x_1 < 0 else x_1 y_1 = 0 if y_1 < 0 else y_1 x_2 = input_image.shape[1] if x_2 < 0 else x_2 y_2 = input_image.shape[0] if y_2 < 0 else y_2 area = (x_2 - x_1) * (y_2 - y_1) order_area[idx_mask] = area area_all += area ordered_area = np.argsort(-order_area) for idx_mask in ordered_area: box = bbox[idx_mask] x_1, y_1, x_2, y_2 = box x_1 = 0 if x_1 < 0 else x_1 y_1 = 0 if y_1 < 0 else y_1 x_2 = input_image.shape[1] if x_2 < 0 else x_2 y_2 = input_image.shape[0] if y_2 < 0 else y_2 mask[y_1:y_2, x_1:x_2, :] = color_dict[label[idx_mask]] detected_map = mask else: detected_map = img detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR) control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0 control = torch.stack([control for _ in range(num_samples)], dim=0) control = einops.rearrange(control, 'b h w c -> b c h w').clone() if seed == -1: seed = random.randint(0, 65535) seed_everything(seed) if config.save_memory: model.low_vram_shift(is_diffusing=False) task = 'bbox' task_dic = {} task_dic['name'] = task_to_name[task] task_instruction = name_to_instruction[task_dic['name']] task_dic['feature'] = model.get_learned_conditioning(task_instruction)[:, :1, :] cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)], "task": task_dic} un_cond = {"c_concat": None if guess_mode else [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]} shape = (4, H // 8, W // 8) if config.save_memory: model.low_vram_shift(is_diffusing=True) model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ( [strength] * 13) samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples, shape, cond, verbose=False, eta=eta, unconditional_guidance_scale=scale, unconditional_conditioning=un_cond) if config.save_memory: model.low_vram_shift(is_diffusing=False) x_samples = model.decode_first_stage(samples) x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype( np.uint8) results = [x_samples[i] for i in range(num_samples)] return [detected_map] + results def process_outpainting(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, guess_mode, strength, scale, seed, eta, h_ratio, w_ratio, condition_mode): with torch.no_grad(): input_image = HWC3(input_image) img = resize_image(input_image, image_resolution) H, W, C = img.shape if condition_mode == True: detected_map = outpainting(input_image, image_resolution, h_ratio, w_ratio) else: detected_map = img detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR) control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0 control = torch.stack([control for _ in range(num_samples)], dim=0) control = einops.rearrange(control, 'b h w c -> b c h w').clone() if seed == -1: seed = random.randint(0, 65535) seed_everything(seed) if config.save_memory: model.low_vram_shift(is_diffusing=False) task = 'outpainting' task_dic = {} task_dic['name'] = task_to_name[task] task_instruction = name_to_instruction[task_dic['name']] task_dic['feature'] = model.get_learned_conditioning(task_instruction)[:, :1, :] cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)], "task": task_dic} un_cond = {"c_concat": None if guess_mode else [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]} shape = (4, H // 8, W // 8) if config.save_memory: model.low_vram_shift(is_diffusing=True) model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ( [strength] * 13) samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples, shape, cond, verbose=False, eta=eta, unconditional_guidance_scale=scale, unconditional_conditioning=un_cond) if config.save_memory: model.low_vram_shift(is_diffusing=False) x_samples = model.decode_first_stage(samples) x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype( np.uint8) results = [x_samples[i] for i in range(num_samples)] return [detected_map] + results def process_sketch(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta, condition_mode): with torch.no_grad(): input_image = HWC3(input_image) img = resize_image(input_image, image_resolution) H, W, C = img.shape if condition_mode == True: detected_map = apply_hed(resize_image(input_image, detect_resolution)) detected_map = HWC3(detected_map) # sketch the hed image retry = 0 cnt = 0 while retry == 0: threshold_value = np.random.randint(110, 160) kernel_size = 3 alpha = 1.5 beta = 50 binary_image = cv2.threshold(detected_map, threshold_value, 255, cv2.THRESH_BINARY)[1] inverted_image = cv2.bitwise_not(binary_image) smoothed_image = cv2.GaussianBlur(inverted_image, (kernel_size, kernel_size), 0) sketch_image = cv2.convertScaleAbs(smoothed_image, alpha=alpha, beta=beta) if np.sum(sketch_image < 5) > 0.005 * sketch_image.shape[0] * sketch_image.shape[1] or cnt == 5: retry = 1 else: cnt += 1 detected_map = sketch_image else: detected_map = img detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR) control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0 control = torch.stack([control for _ in range(num_samples)], dim=0) control = einops.rearrange(control, 'b h w c -> b c h w').clone() if seed == -1: seed = random.randint(0, 65535) seed_everything(seed) if config.save_memory: model.low_vram_shift(is_diffusing=False) task = 'hedsketch' task_dic = {} task_dic['name'] = task_to_name[task] task_instruction = name_to_instruction[task_dic['name']] task_dic['feature'] = model.get_learned_conditioning(task_instruction)[:, :1, :] cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)], "task": task_dic} un_cond = {"c_concat": None if guess_mode else [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]} shape = (4, H // 8, W // 8) if config.save_memory: model.low_vram_shift(is_diffusing=True) model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ( [strength] * 13) samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples, shape, cond, verbose=False, eta=eta, unconditional_guidance_scale=scale, unconditional_conditioning=un_cond) if config.save_memory: model.low_vram_shift(is_diffusing=False) x_samples = model.decode_first_stage(samples) x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype( np.uint8) results = [x_samples[i] for i in range(num_samples)] return [detected_map] + results def process_colorization(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, guess_mode, strength, scale, seed, eta, condition_mode): with torch.no_grad(): input_image = HWC3(input_image) img = resize_image(input_image, image_resolution) H, W, C = img.shape if condition_mode == True: detected_map = grayscale(input_image, image_resolution) detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR) detected_map = detected_map[:, :, np.newaxis] detected_map = detected_map.repeat(3, axis=2) else: detected_map = img control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0 control = torch.stack([control for _ in range(num_samples)], dim=0) control = einops.rearrange(control, 'b h w c -> b c h w').clone() if seed == -1: seed = random.randint(0, 65535) seed_everything(seed) if config.save_memory: model.low_vram_shift(is_diffusing=False) task = 'grayscale' task_dic = {} task_dic['name'] = task_to_name[task] task_instruction = name_to_instruction[task_dic['name']] task_dic['feature'] = model.get_learned_conditioning(task_instruction)[:, :1, :] cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)], "task": task_dic} un_cond = {"c_concat": None if guess_mode else [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]} shape = (4, H // 8, W // 8) if config.save_memory: model.low_vram_shift(is_diffusing=True) model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ( [strength] * 13) samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples, shape, cond, verbose=False, eta=eta, unconditional_guidance_scale=scale, unconditional_conditioning=un_cond) if config.save_memory: model.low_vram_shift(is_diffusing=False) x_samples = model.decode_first_stage(samples) x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype( np.uint8) results = [x_samples[i] for i in range(num_samples)] return [detected_map] + results def process_deblur(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, guess_mode, strength, scale, seed, eta, ksize, condition_mode): with torch.no_grad(): input_image = HWC3(input_image) img = resize_image(input_image, image_resolution) H, W, C = img.shape if condition_mode == True: detected_map = blur(input_image, image_resolution, ksize) else: detected_map = img detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR) control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0 control = torch.stack([control for _ in range(num_samples)], dim=0) control = einops.rearrange(control, 'b h w c -> b c h w').clone() if seed == -1: seed = random.randint(0, 65535) seed_everything(seed) if config.save_memory: model.low_vram_shift(is_diffusing=False) task = 'blur' task_dic = {} task_dic['name'] = task_to_name[task] task_instruction = name_to_instruction[task_dic['name']] task_dic['feature'] = model.get_learned_conditioning(task_instruction)[:, :1, :] cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)], "task": task_dic} un_cond = {"c_concat": None if guess_mode else [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]} shape = (4, H // 8, W // 8) if config.save_memory: model.low_vram_shift(is_diffusing=True) model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ( [strength] * 13) samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples, shape, cond, verbose=False, eta=eta, unconditional_guidance_scale=scale, unconditional_conditioning=un_cond) if config.save_memory: model.low_vram_shift(is_diffusing=False) x_samples = model.decode_first_stage(samples) x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype( np.uint8) results = [x_samples[i] for i in range(num_samples)] return [detected_map] + results def process_inpainting(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, guess_mode, strength, scale, seed, eta, h_ratio_t, h_ratio_d, w_ratio_l, w_ratio_r, condition_mode): with torch.no_grad(): input_image = HWC3(input_image) img = resize_image(input_image, image_resolution) H, W, C = img.shape if condition_mode == True: detected_map = inpainting(input_image, image_resolution, h_ratio_t, h_ratio_d, w_ratio_l, w_ratio_r) else: detected_map = img detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR) control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0 control = torch.stack([control for _ in range(num_samples)], dim=0) control = einops.rearrange(control, 'b h w c -> b c h w').clone() if seed == -1: seed = random.randint(0, 65535) seed_everything(seed) if config.save_memory: model.low_vram_shift(is_diffusing=False) task = 'inpainting' task_dic = {} task_dic['name'] = task_to_name[task] task_instruction = name_to_instruction[task_dic['name']] task_dic['feature'] = model.get_learned_conditioning(task_instruction)[:, :1, :] cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)], "task": task_dic} un_cond = {"c_concat": None if guess_mode else [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]} shape = (4, H // 8, W // 8) if config.save_memory: model.low_vram_shift(is_diffusing=True) model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ( [strength] * 13) samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples, shape, cond, verbose=False, eta=eta, unconditional_guidance_scale=scale, unconditional_conditioning=un_cond) if config.save_memory: model.low_vram_shift(is_diffusing=False) x_samples = model.decode_first_stage(samples) x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype( np.uint8) results = [x_samples[i] for i in range(num_samples)] return [detected_map] + results ############################################################################################################ demo = gr.Blocks() with demo: #gr.Markdown("UniControl Stable Diffusion Demo") gr.HTML( """
Can Qin 1,2, Shu Zhang1, Ning Yu 1, Yihao Feng1, Xinyi Yang1, Yingbo Zhou 1, Huan Wang 1, Juan Carlos Niebles1, Caiming Xiong 1, Silvio Savarese 1, Stefano Ermon 3, Yun Fu 2, Ran Xu 1
1 Salesforce AI 2 Northeastern University 3 Stanford University
Work done when Can Qin was an intern at Salesforce AI Research.
ONE model for ALL the condition-to-image generation! [Github] [Website] [arXiv]