"""This file contains methods for inference and image generation."""
import itertools
import logging
import threading
import time
from contextlib import contextmanager
from time import perf_counter
from typing import Callable, Dict, Iterator, List, Tuple

import numpy as np
import streamlit as st
import torch
from PIL import Image, ImageFilter
from scipy.signal import fftconvolve
from transformers import AutoImageProcessor, UperNetForSemanticSegmentation
from diffusers import ControlNetModel, UniPCMultistepScheduler
from diffusers import StableDiffusionInpaintPipeline
from compel import Compel

from config import WIDTH, HEIGHT
from palette import ade_palette
from stable_diffusion_controlnet_inpaint_img2img import StableDiffusionControlNetInpaintImg2ImgPipeline

LOGGING = logging.getLogger(__name__)


class ControlNetPipeline:
    """Serialise calls to a shared ControlNet inpaint/img2img pipeline.

    Each call takes a ticket number, appends it to ``waiting_queue`` and polls
    (every 0.5 s) until its ticket is at the front of the queue. It then runs
    the wrapped pipeline. The ticket is always removed afterwards, including
    when the pipeline raises, so the next caller can proceed.
    """

    def __init__(self):
        # NOTE(review): ``in_use`` is not read by any code shown here.
        self.in_use = False
        self.controlnet = ControlNetModel.from_pretrained(
            "BertChristiaens/controlnet-seg-room", torch_dtype=torch.float16)
        self.pipe = StableDiffusionControlNetInpaintImg2ImgPipeline.from_pretrained(
            "runwayml/stable-diffusion-inpainting",
            controlnet=self.controlnet,
            safety_checker=None,
            torch_dtype=torch.float16,
        )
        self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
        self.pipe.enable_xformers_memory_efficient_attention()
        self.pipe = self.pipe.to("cuda")
        self.waiting_queue = []
        self.count = 0
        # Guards ticket allocation and queue removal. Without it, two
        # concurrent callers could read the same ``count`` and get equal tickets.
        self._ticket_lock = threading.Lock()

    def __call__(self, **kwargs):
        """Wait for this caller's turn, then run the pipeline with ``kwargs``.

        Returns:
            Whatever the wrapped pipeline returns.
        """
        with self._ticket_lock:
            self.count += 1
            number = self.count
            self.waiting_queue.append(number)
        try:
            # Wait until our ticket is the first one in the queue.
            while self.waiting_queue[0] != number:
                print(f"Wait for your turn {number} in queue {self.waiting_queue}")
                time.sleep(0.5)
            print("It's the turn of", number)
            return self.pipe(**kwargs)
        finally:
            # Release our slot even on error or interruption. Otherwise every
            # later caller would spin forever waiting for this ticket to leave.
            # remove() (not pop(0)) is correct even if we never reached the front.
            with self._ticket_lock:
                self.waiting_queue.remove(number)


@contextmanager
def catchtime(message: str) -> Iterator[Callable[[], float]]:
    """Measure and log the wall-clock time spent inside a ``with`` block.

    Args:
        message (str): label prefixed to the logged duration

    Yields:
        Callable[[], float]: function returning the seconds elapsed so far

    The total duration is logged at INFO level on exit, including when the
    block raises.
    """
    start = perf_counter()
    try:
        yield lambda: perf_counter() - start
    finally:
        LOGGING.info('%s: %.3f seconds', message, perf_counter() - start)


def convolution(mask: Image.Image, size=9) -> Image.Image:
    """Blur a mask with a ``size`` x ``size`` box (mean) filter.

    Args:
        mask (Image.Image): masking image; converted to grayscale ("L")
        size (int, optional): side length of the box filter, and the width of
            the border kept unblurred. Defaults to 9.

    Returns:
        Image.Image: blurred grayscale mask
    """
    mask = np.array(mask.convert("L"))
    kernel = np.ones((size, size)) / size**2
    mask_blended = fftconvolve(mask, kernel, 'same')
    # FFT convolution leaves small floating-point errors (e.g. 254.9999 for a
    # fully white area). Round and clip before the uint8 cast so values are
    # neither truncated down nor wrapped around.
    mask_blended = np.clip(np.rint(mask_blended), 0, 255).astype(np.uint8)

    # 'same' mode treats pixels outside the image as zero, which pulls edge
    # values towards 0. Restore the original values in a band of `size` pixels.
    border = size
    mask_blended[:border, :] = mask[:border, :]
    mask_blended[-border:, :] = mask[-border:, :]
    mask_blended[:, :border] = mask[:, :border]
    mask_blended[:, -border:] = mask[:, -border:]

    return Image.fromarray(mask_blended).convert("L")


def postprocess_image_masking(inpainted: Image.Image, image: Image.Image,
                              mask: Image.Image) -> Image.Image:
    """Blend the inpainted image over the original image using ``mask``.

    Where the mask is white the inpainted pixels are used; where it is black
    the original pixels are kept; intermediate values blend the two.

    Args:
        inpainted (Image.Image): inpainted image
        image (Image.Image): original image
        mask (Image.Image): blending mask, same size as both images

    Returns:
        Image.Image: composited RGB image
    """
    final_inpainted = Image.composite(inpainted.convert("RGBA"), image.convert("RGBA"), mask)
    return final_inpainted.convert("RGB")


@st.experimental_singleton(max_entries=5)
def get_controlnet() -> ControlNetPipeline:
    """Load the ControlNet pipeline wrapper (cached as a Streamlit singleton).

    Returns:
        ControlNetPipeline: queued wrapper around the controlnet pipeline
    """
    pipe = ControlNetPipeline()
    return pipe


@st.experimental_singleton(max_entries=5)
def get_segmentation_pipeline() -> Tuple[AutoImageProcessor, UperNetForSemanticSegmentation]:
    """Load the segmentation pipeline (cached as a Streamlit singleton).

    Returns:
        Tuple[AutoImageProcessor, UperNetForSemanticSegmentation]: image
            processor and segmentation model
    """
    image_processor = AutoImageProcessor.from_pretrained("openmmlab/upernet-convnext-small")
    image_segmentor = UperNetForSemanticSegmentation.from_pretrained(
        "openmmlab/upernet-convnext-small")
    return image_processor, image_segmentor


@st.experimental_singleton(max_entries=5)
def get_inpainting_pipeline() -> StableDiffusionInpaintPipeline:
    """Load the fp16 inpainting pipeline on CUDA (cached as a Streamlit singleton).

    Returns:
        StableDiffusionInpaintPipeline: inpainting pipeline
    """
    pipe = StableDiffusionInpaintPipeline.from_pretrained(
        "stabilityai/stable-diffusion-2-inpainting",
        torch_dtype=torch.float16,
        safety_checker=None,
    )

    pipe.enable_xformers_memory_efficient_attention()
    pipe = pipe.to("cuda")
    return pipe


def make_grid_parameters(grid_search: Dict, params: Dict) -> List[Dict]:
    """Expand a grid search into one keyword-argument dict per combination.

    Combinations are ordered with ``generator`` as the outermost loop, then
    ``strength``, then ``guidance_scale``. Keys in ``params`` are merged last
    and therefore override grid values with the same name.

    Args:
        grid_search (Dict): lists under 'generator', 'strength' and
            'guidance_scale'
        params (Dict): fixed parameters shared by every combination

    Returns:
        List[Dict]: grid parameters
    """
    return [
        {'strength': strength,
         'guidance_scale': guidance_scale,
         'generator': generator,
         **params,
         }
        for generator, strength, guidance_scale in itertools.product(
            grid_search['generator'],
            grid_search['strength'],
            grid_search['guidance_scale'],
        )
    ]


def make_captions(options: List[Dict]) -> List[str]:
    """Build one human-readable caption per grid option.

    Args:
        options (List[Dict]): grid parameters; each needs 'strength',
            'guidance_scale' and 'num_inference_steps'

    Returns:
        List[str]: captions
    """
    return [
        f"strength {option['strength']}, guidance {option['guidance_scale']}, steps {option['num_inference_steps']}"
        for option in options
    ]


@torch.inference_mode()
def make_image_controlnet(image: np.ndarray,
                          mask_image: np.ndarray,
                          controlnet_conditioning_image: np.ndarray,
                          positive_prompt: str, negative_prompt: str,
                          seed: int = 2356132) -> List[Image.Image]:
    """Generate images with the ControlNet inpaint pipeline.

    Args:
        image (np.ndarray): input image
        mask_image (np.ndarray): mask image; multiplied by 255 before the uint8
            cast, so presumably values in [0, 1] — TODO confirm with callers
        controlnet_conditioning_image (np.ndarray): conditioning image
            (Gaussian-blurred with radius 9 before use)
        positive_prompt (str): positive prompt string
        negative_prompt (str): negative prompt string
        seed (int, optional): seed. Defaults to 2356132.

    Returns:
        List[Image.Image]: list of generated images, each composited onto
            the input image through a blurred copy of the mask
    """
    with catchtime("get controlnet"):
        pipe = get_controlnet()
        torch.cuda.empty_cache()

    images = []
    common_parameters = {'prompt': positive_prompt,
                         'negative_prompt': negative_prompt,
                         'num_inference_steps': 30,
                         'controlnet_conditioning_scale': 1.1,
                         'controlnet_conditioning_scale_decay': 0.96,
                         'controlnet_steps': 28,
                         }

    grid_search = {'strength': [1.00, ],
                   'guidance_scale': [7.0],
                   'generator': [[torch.Generator(device="cuda").manual_seed(seed + i)]
                                 for i in range(1)],
                   }
    prompt_settings = make_grid_parameters(grid_search, common_parameters)

    mask_image = Image.fromarray((mask_image * 255).astype(np.uint8)).convert("RGB")
    image = Image.fromarray(image).convert("RGB")
    controlnet_conditioning_image = Image.fromarray(controlnet_conditioning_image).convert(
        "RGB").filter(ImageFilter.GaussianBlur(radius=9))
    # Blurred mask used only for the final composite, to soften the seam
    # between generated and original pixels.
    mask_image_postproc = convolution(mask_image)

    with catchtime("Controlnet generation total"):
        for setting in prompt_settings:
            with catchtime("Controlnet generation"):
                generated_image = pipe(
                    **setting,
                    image=image,
                    mask_image=mask_image,
                    controlnet_conditioning_image=controlnet_conditioning_image,
                ).images[0]
            generated_image = postprocess_image_masking(
                generated_image, image, mask_image_postproc)
            images.append(generated_image)

    return images


@torch.inference_mode()
def make_inpainting(positive_prompt: str,
                    image: Image.Image,
                    mask_image: np.ndarray,
                    negative_prompt: str = "") -> List[Image.Image]:
    """Generate images with the plain inpainting pipeline.

    Args:
        positive_prompt (str): positive prompt string
        image (Image.Image): input image
        mask_image (np.ndarray): mask image; multiplied by 255 before the uint8
            cast, so presumably values in [0, 1] — TODO confirm with callers
        negative_prompt (str, optional): negative prompt string. Defaults to "".

    Returns:
        List[Image.Image]: list of generated images (HEIGHT x WIDTH)
    """
    with catchtime("Get inpainting pipeline"):
        pipe = get_inpainting_pipeline()
    common_parameters = {'prompt': positive_prompt,
                         'negative_prompt': negative_prompt,
                         'num_inference_steps': 20,
                         }
    torch.cuda.empty_cache()

    # The mask does not change between generations; convert it once.
    mask_pil = Image.fromarray((mask_image * 255).astype(np.uint8))

    images = []
    for _ in range(1):
        with catchtime("Inpainting generation"):
            image_ = pipe(image=image,
                          mask_image=mask_pil,
                          height=HEIGHT,
                          width=WIDTH,
                          **common_parameters
                          ).images[0]
        images.append(image_)
    return images


@torch.inference_mode()
@torch.autocast('cuda')
def segment_image(image: Image.Image) -> Image.Image:
    """Segment an image and colour each label with the ADE palette.

    Args:
        image (Image.Image): input image

    Returns:
        Image.Image: RGB segmentation map, same size as the input; pixels whose
            label has no palette entry stay black
    """
    # NOTE(review): get_segmentation_pipeline never moves the model to CUDA,
    # so the 'cuda' autocast above likely has no effect — confirm.
    image_processor, image_segmentor = get_segmentation_pipeline()
    pixel_values = image_processor(image, return_tensors="pt").pixel_values
    # Gradients are already disabled by @torch.inference_mode().
    outputs = image_segmentor(pixel_values)

    # PIL size is (width, height); target_sizes expects (height, width).
    seg = image_processor.post_process_semantic_segmentation(
        outputs, target_sizes=[image.size[::-1]])[0]
    color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)  # height, width, 3
    palette = np.array(ade_palette())
    for label, color in enumerate(palette):
        color_seg[seg == label, :] = color
    seg_image = Image.fromarray(color_seg).convert('RGB')
    return seg_image