Spaces:
Runtime error
Runtime error
from transformers.tools.base import Tool, get_default_device | |
from transformers.utils import is_accelerate_available | |
import torch | |
from diffusers import StableDiffusionInpaintPipeline | |
# Natural-language description of what the tool does and which inputs it
# takes; used as ``InpaintingTool.description``.
INPAINTING_DESCRIPTION = (
    "This is a tool that inpaints some parts of an image using StableDiffusionInpaintPipeline according to a prompt."
    " It takes three inputs: `image`, which should be the original image which will be inpainted,"
    " `mask_image`, which should be used to determine which parts of the original image"
    " (stored in the `image` variable) should be inpainted,"
    " and `prompt`, which should be the prompt to use to guide the inpainting process. It returns the"
    " inpainted image."
)
class InpaintingTool(Tool):
    """Tool that inpaints the masked parts of an image, guided by a text prompt.

    Wraps a ``StableDiffusionInpaintPipeline`` loaded from
    ``default_checkpoint``. The pipeline is loaded lazily by ``setup`` on the
    first call rather than in ``__init__``.
    """

    default_checkpoint = "stabilityai/stable-diffusion-2-inpainting"
    description = INPAINTING_DESCRIPTION
    name = "image_inpainter"
    # Input order matches ``__call__``: image, mask_image, prompt.
    inputs = ['image', 'image', 'text']
    outputs = ['image']

    def __init__(self, device=None, **hub_kwargs) -> None:
        """Store configuration; the pipeline itself is not loaded here.

        Args:
            device: Device to run on, as a ``torch.device`` or any value
                ``torch.device()`` accepts (e.g. ``"cuda"``). ``None`` means
                ``get_default_device()`` is used at setup time.
            **hub_kwargs: Extra keyword arguments forwarded to
                ``StableDiffusionInpaintPipeline.from_pretrained``.

        Raises:
            ImportError: If ``accelerate`` is not installed.
        """
        if not is_accelerate_available():
            raise ImportError("Accelerate should be installed in order to use tools.")
        super().__init__()
        self.device = device
        self.pipeline = None
        self.hub_kwargs = hub_kwargs

    def setup(self):
        """Resolve the target device and load the inpainting pipeline onto it."""
        if self.device is None:
            self.device = get_default_device()
        # Normalise so that string devices such as "cuda:0" also expose `.type`.
        self.device = torch.device(self.device)

        load_kwargs = dict(self.hub_kwargs)
        if self.device.type == "cuda":
            # Load weights in half precision on GPU. A caller-supplied
            # `torch_dtype` hub kwarg takes precedence.
            load_kwargs.setdefault("torch_dtype", torch.float16)

        self.pipeline = StableDiffusionInpaintPipeline.from_pretrained(
            self.default_checkpoint, **load_kwargs
        ).to(self.device)
        self.is_initialized = True

    def __call__(self, image, mask_image, prompt):
        """Inpaint ``image`` where ``mask_image`` indicates, guided by ``prompt``.

        Args:
            image: Source image. Presumably a PIL image (it must support
                ``.resize`` and ``.size``); TODO confirm against callers.
            mask_image: Mask selecting the regions to inpaint; same
                assumption as ``image``.
            prompt: Text prompt guiding the inpainting.

        Returns:
            The first image produced by the pipeline, resized back to
            ``image``'s original size.
        """
        if not self.is_initialized:
            self.setup()
        # Both inputs are run at a fixed 512x512 (presumably the checkpoint's
        # working resolution; TODO confirm). The output is scaled back at the end.
        resized_image = image.resize((512, 512))
        resized_mask = mask_image.resize((512, 512))
        inpainted_image = self.pipeline(
            prompt=prompt,
            image=resized_image,
            mask_image=resized_mask,
        ).images[0]
        return inpainted_image.resize(image.size)