"""This file contains methods for inference and image generation."""

import gc
import logging
import time
from contextlib import contextmanager
from time import perf_counter
from typing import Callable, Dict, Iterator, List, Tuple

import numpy as np
import streamlit as st
import torch
from compel import Compel
from diffusers import ControlNetModel, StableDiffusionInpaintPipeline, UniPCMultistepScheduler
from PIL import Image, ImageFilter
from scipy.signal import fftconvolve
from transformers import AutoImageProcessor, UperNetForSemanticSegmentation

from config import WIDTH, HEIGHT
from palette import ade_palette
from stable_diffusion_controlnet_inpaint_img2img import StableDiffusionControlNetInpaintImg2ImgPipeline

LOGGING = logging.getLogger(__name__)


def flush():
    """Free CPU and GPU memory between generations."""
    gc.collect()
    torch.cuda.empty_cache()


class ControlNetPipeline:
    """Wrapper around the ControlNet inpainting pipeline that serializes
    concurrent calls through a simple FIFO ticket queue."""

    def __init__(self):
        self.in_use = False
        self.controlnet = ControlNetModel.from_pretrained(
            "BertChristiaens/controlnet-seg-room", torch_dtype=torch.float32)

        self.pipe = StableDiffusionControlNetInpaintImg2ImgPipeline.from_pretrained(
            "runwayml/stable-diffusion-inpainting",
            controlnet=self.controlnet,
            safety_checker=None,
            torch_dtype=torch.float32,
        )

        self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
        self.pipe.enable_xformers_memory_efficient_attention()
        self.pipe.enable_attention_slicing("max")
        self.pipe = self.pipe.to("cuda")

        self.waiting_queue = []
        self.count = 0

    def __call__(self, **kwargs):
        # Take a ticket and join the queue.
        self.count += 1
        number = self.count
        self.waiting_queue.append(number)

        # Busy-wait until this ticket is at the front of the queue.
        while self.waiting_queue[0] != number:
            LOGGING.info("Ticket %s waiting in queue %s", number, self.waiting_queue)
            time.sleep(0.5)

        LOGGING.info("Running generation for ticket %s", number)
        results = self.pipe(**kwargs)
        self.waiting_queue.pop(0)
        flush()
        return results


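# Usage sketch (illustrative, not executed at import time): the wrapper is obtained via the
# cached `get_controlnet()` below and is called with the same keyword arguments as the
# underlying diffusers pipeline; concurrent Streamlit sessions are served in ticket order.
#
#     pipe = get_controlnet()
#     result = pipe(prompt="a scandinavian living room",
#                   negative_prompt="low quality",
#                   image=init_image,                 # PIL.Image, RGB
#                   mask_image=mask,                  # PIL.Image, white = repaint
#                   controlnet_conditioning_image=seg_map,
#                   num_inference_steps=30)
#     generated = result.images[0]
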
@contextmanager
def catchtime(message: str) -> Iterator[Callable[[], float]]:
    """Context manager that measures elapsed time and logs it on exit.

    Args:
        message (str): message to log together with the elapsed time

    Yields:
        Callable[[], float]: callable returning the seconds elapsed since entering the block
    """
    start = perf_counter()
    yield lambda: perf_counter() - start
    LOGGING.info('%s: %.3f seconds', message, perf_counter() - start)


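# Usage sketch (illustrative): the yielded callable can be ignored when only the log line on
# exit is needed, or captured to read the intermediate elapsed time inside the block.
#
#     with catchtime("segmentation") as elapsed:
#         seg = segment_image(input_image)
#         LOGGING.info("partial: %.3f seconds", elapsed())
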
def convolution(mask: Image.Image, size: int = 9) -> Image.Image:
    """Blur a mask with a normalized box filter.

    Args:
        mask (Image.Image): mask image
        size (int, optional): size of the box filter. Defaults to 9.

    Returns:
        Image.Image: blurred mask
    """
    mask = np.array(mask.convert("L"))
    conv = np.ones((size, size)) / size**2
    mask_blended = fftconvolve(mask, conv, 'same')
    mask_blended = mask_blended.astype(np.uint8).copy()

    border = size

    # Keep the original (unblurred) values along the image border so the blur
    # does not bleed the mask past the edges of the frame.
    mask_blended[:border, :] = mask[:border, :]
    mask_blended[-border:, :] = mask[-border:, :]
    mask_blended[:, :border] = mask[:, :border]
    mask_blended[:, -border:] = mask[:, -border:]

    return Image.fromarray(mask_blended).convert("L")


def postprocess_image_masking(inpainted: Image.Image, image: Image.Image, mask: Image.Image) -> Image.Image:
    """Blend the inpainted image back into the original image using the mask.

    Args:
        inpainted (Image.Image): inpainted image
        image (Image.Image): original image
        mask (Image.Image): blending mask (white keeps the inpainted pixels)

    Returns:
        Image.Image: blended image
    """
    final_inpainted = Image.composite(inpainted.convert("RGBA"), image.convert("RGBA"), mask)
    return final_inpainted.convert("RGB")


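# Usage sketch (illustrative): blur the binary mask and use it to composite a generated image
# back onto the original, so the transition at the mask boundary is feathered.
#
#     soft_mask = convolution(mask_image)   # mask_image: PIL "L"/"RGB" mask, white = repaint
#     blended = postprocess_image_masking(generated_image, original_image, soft_mask)
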
@st.experimental_singleton(max_entries=5)
def get_controlnet() -> ControlNetPipeline:
    """Method to load the ControlNet pipeline wrapper (cached as a Streamlit singleton)

    Returns:
        ControlNetPipeline: controlnet pipeline wrapper
    """
    pipe = ControlNetPipeline()
    return pipe


@st.experimental_singleton(max_entries=5)
def get_segmentation_pipeline() -> Tuple[AutoImageProcessor, UperNetForSemanticSegmentation]:
    """Method to load the segmentation pipeline

    Returns:
        Tuple[AutoImageProcessor, UperNetForSemanticSegmentation]: segmentation pipeline
    """
    image_processor = AutoImageProcessor.from_pretrained("openmmlab/upernet-convnext-small")
    image_segmentor = UperNetForSemanticSegmentation.from_pretrained(
        "openmmlab/upernet-convnext-small")
    return image_processor, image_segmentor


@st.experimental_singleton(max_entries=5)
def get_inpainting_pipeline() -> StableDiffusionInpaintPipeline:
    """Method to load the inpainting pipeline

    Returns:
        StableDiffusionInpaintPipeline: inpainting pipeline
    """
    pipe = StableDiffusionInpaintPipeline.from_pretrained(
        "stabilityai/stable-diffusion-2-inpainting",
        torch_dtype=torch.float16,
        safety_checker=None,
    )

    pipe.enable_xformers_memory_efficient_attention()
    pipe = pipe.to("cuda")

    return pipe


def make_grid_parameters(grid_search: Dict, params: Dict) -> List[Dict]:
    """Method to make grid parameters

    Args:
        grid_search (Dict): grid search parameters
        params (Dict): fixed parameters

    Returns:
        List[Dict]: grid parameters
    """
    options = []
    for generator in grid_search['generator']:
        for strength in grid_search['strength']:
            for guidance_scale in grid_search['guidance_scale']:
                options.append({'strength': strength,
                                'guidance_scale': guidance_scale,
                                'generator': generator,
                                **params
                                })
    return options


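# Example (illustrative): a grid with two strengths and one guidance scale yields two option
# dicts, each carrying the fixed parameters; `make_captions` below turns them into labels.
#
#     grid = {'strength': [0.9, 1.0],
#             'guidance_scale': [7.0],
#             'generator': [torch.Generator(device="cuda").manual_seed(42)]}
#     fixed = {'prompt': 'a cozy bedroom', 'negative_prompt': '', 'num_inference_steps': 30}
#     options = make_grid_parameters(grid, fixed)
#     captions = make_captions(options)
#     # captions -> ['strength 0.9, guidance 7.0, steps 30',
#     #              'strength 1.0, guidance 7.0, steps 30']
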
def make_captions(options: List[Dict]) -> List[str]:
    """Method to make captions

    Args:
        options (List[Dict]): grid parameters

    Returns:
        List[str]: captions
    """
    captions = []
    for option in options:
        captions.append(
            f"strength {option['strength']}, guidance {option['guidance_scale']}, "
            f"steps {option['num_inference_steps']}")
    return captions


@torch.inference_mode()
def make_image_controlnet(image: np.ndarray,
                          mask_image: np.ndarray,
                          controlnet_conditioning_image: np.ndarray,
                          positive_prompt: str, negative_prompt: str,
                          seed: int = 2356132) -> List[Image.Image]:
    """Method to make image using controlnet

    Args:
        image (np.ndarray): input image
        mask_image (np.ndarray): mask image
        controlnet_conditioning_image (np.ndarray): conditioning image
        positive_prompt (str): positive prompt string
        negative_prompt (str): negative prompt string
        seed (int, optional): seed. Defaults to 2356132.

    Returns:
        List[Image.Image]: list of generated images
    """

    with catchtime("get controlnet"):
        pipe = get_controlnet()

    torch.cuda.empty_cache()
    images = []

    common_parameters = {'prompt': positive_prompt,
                         'negative_prompt': negative_prompt,
                         'num_inference_steps': 30,
                         'controlnet_conditioning_scale': 1.1,
                         'controlnet_conditioning_scale_decay': 0.96,
                         'controlnet_steps': 28,
                         }

    grid_search = {'strength': [1.00],
                   'guidance_scale': [7.0],
                   'generator': [[torch.Generator(device="cuda").manual_seed(seed + i)] for i in range(1)],
                   }

    prompt_settings = make_grid_parameters(grid_search, common_parameters)

    mask_image = Image.fromarray((mask_image * 255).astype(np.uint8)).convert("RGB")
    image = Image.fromarray(image).convert("RGB")
    controlnet_conditioning_image = Image.fromarray(controlnet_conditioning_image).convert("RGB").filter(
        ImageFilter.GaussianBlur(radius=9))

    mask_image_postproc = convolution(mask_image)

    with catchtime("Controlnet generation total"):
        for setting in prompt_settings:
            with catchtime("Controlnet generation"):
                generated_image = pipe(
                    **setting,
                    image=image,
                    mask_image=mask_image,
                    controlnet_conditioning_image=controlnet_conditioning_image,
                ).images[0]
            generated_image = postprocess_image_masking(
                generated_image, image, mask_image_postproc)
            images.append(generated_image)

    return images


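# Input format notes (derived from the conversions above): `image` and
# `controlnet_conditioning_image` are uint8 RGB arrays of shape (H, W, 3), while `mask_image`
# is a 0/1 array where 1 marks the region to repaint (it is scaled to 0-255 before being
# handed to the pipeline). Illustrative call with placeholder variables:
#
#     images = make_image_controlnet(image=np.array(pil_image),
#                                    mask_image=mask_array,            # 0/1, shape (H, W)
#                                    controlnet_conditioning_image=np.array(seg_map),
#                                    positive_prompt="a modern kitchen",
#                                    negative_prompt="low quality")
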
@torch.inference_mode()
def make_inpainting(positive_prompt: str,
                    image: Image.Image,
                    mask_image: np.ndarray,
                    negative_prompt: str = "") -> List[Image.Image]:
    """Method to make inpainting

    Args:
        positive_prompt (str): positive prompt string
        image (Image.Image): input image
        mask_image (np.ndarray): mask image
        negative_prompt (str, optional): negative prompt string. Defaults to "".

    Returns:
        List[Image.Image]: list of generated images
    """

    with catchtime("Get inpainting pipeline"):
        pipe = get_inpainting_pipeline()

    common_parameters = {'prompt': positive_prompt,
                         'negative_prompt': negative_prompt,
                         'num_inference_steps': 20,
                         }

    torch.cuda.empty_cache()
    images = []
    for _ in range(1):
        with catchtime("Inpainting generation"):
            image_ = pipe(image=image,
                          mask_image=Image.fromarray((mask_image * 255).astype(np.uint8)),
                          height=HEIGHT,
                          width=WIDTH,
                          **common_parameters
                          ).images[0]
        images.append(image_)
    return images


@torch.inference_mode()
@torch.autocast('cuda')
def segment_image(image: Image.Image) -> Image.Image:
    """Method to segment image

    Args:
        image (Image.Image): input image

    Returns:
        Image.Image: segmented image
    """
    image_processor, image_segmentor = get_segmentation_pipeline()
    pixel_values = image_processor(image, return_tensors="pt").pixel_values
    with torch.no_grad():
        outputs = image_segmentor(pixel_values)

    seg = image_processor.post_process_semantic_segmentation(
        outputs, target_sizes=[image.size[::-1]])
    seg = seg[0]
    color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
    palette = np.array(ade_palette())
    # Map every class label to its ADE20K color to build the conditioning image.
    for label, color in enumerate(palette):
        color_seg[seg == label, :] = color
    color_seg = color_seg.astype(np.uint8)
    seg_image = Image.fromarray(color_seg).convert('RGB')
    return seg_image
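

# Illustrative end-to-end sketch (an assumption about how the pieces fit together, not part
# of the Streamlit app): segment a room photo, then regenerate the whole frame with
# ControlNet. Requires a CUDA GPU and downloads the model weights on first run; the image
# paths are hypothetical.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    input_image = Image.open("example_room.jpg").convert("RGB").resize((WIDTH, HEIGHT))
    seg_map = segment_image(input_image)

    # Regenerate the entire image: a mask of ones marks every pixel for repainting.
    full_mask = np.ones((HEIGHT, WIDTH), dtype=np.uint8)
    results = make_image_controlnet(image=np.array(input_image),
                                    mask_image=full_mask,
                                    controlnet_conditioning_image=np.array(seg_map),
                                    positive_prompt="a bright scandinavian living room",
                                    negative_prompt="low quality, blurry")
    results[0].save("generated_room.png")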