|
|
|
"""pip install torch pillow requests diffusers imageio gradio==3.4 httpx==0.23.2 transformers accelerate""" |
|
|
|
import gradio as gr |
|
import imageio |
|
import torch |
|
import numpy as np |
|
from PIL import Image |
|
|
|
from diffusers import StableDiffusionInpaintPipeline |
|
|
|
def perform_inpainting(prompt):
    """Inpaint the saved base image according to *prompt*.

    Reads "Original Image.png" and "Mask Image.png" from the current working
    directory (the file names written by ``Mask``), converts both to RGB,
    resizes them to 512x512, runs the inpainting pipeline on the CUDA device
    and saves the result as "Inpainted_img.png".

    Args:
        prompt: Text prompt forwarded to the pipeline.

    Returns:
        The generated image, or ``None`` when the input files cannot be
        loaded or when loading/running the model fails.
    """
    img_path = "Original Image.png"
    mask_path = "Mask Image.png"
    device = "cuda"
    model_name = "runwayml/stable-diffusion-v1-5"
    target_size = (512, 512)

    # Validate the inputs before touching the model: loading the pipeline is
    # by far the most expensive step and is pointless if the files are absent.
    try:
        # ``with`` closes the underlying files right away; ``Mask`` rewrites
        # these same paths on the next run.
        with Image.open(img_path) as source:
            init_image = source.convert("RGB").resize(target_size)
        with Image.open(mask_path) as source:
            mask_image = source.convert("RGB").resize(target_size)
    except FileNotFoundError:
        print(f"Error: Image files '{img_path}' or '{mask_path}' not found.")
        return None
    except OSError as e:
        # Covers unreadable or corrupt image files.
        print(f"Error: could not read the input images: {e}")
        return None

    print("Processing the image...")

    # Model loading and the device transfer are inside the guarded region so
    # that a failed load (or a missing GPU) is reported the same way as a
    # failed inference instead of escaping to the caller.
    try:
        pipeline = create_inpaint_pipeline(model_name).to(device)
        image = pipeline(
            prompt=prompt,
            image=init_image,
            mask_image=mask_image,
        ).images[0]
        image.save("Inpainted_img.png")
        return image
    except Exception as e:
        print(f"Error during inpainting: {e}")
        return None
|
|
|
def create_inpaint_pipeline(model_name):
    """Load a ``StableDiffusionInpaintPipeline`` with half-precision weights.

    Args:
        model_name: Model identifier handed to ``from_pretrained``.

    Returns:
        The pipeline object returned by ``from_pretrained``; it is not moved
        to any device here.
    """
    load_options = {"torch_dtype": torch.float16}
    return StableDiffusionInpaintPipeline.from_pretrained(model_name, **load_options)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def Mask(img):
    """Save the sketched image and its mask to disk and return both.

    The two arrays are written to "Original Image.png" and "Mask Image.png"
    in the current working directory.

    Args:
        img (dict): Dictionary with the base image under ``"image"`` and the
            mask under ``"mask"``.

    Returns:
        tuple: ``(image, mask)`` on success, or ``(None, None)`` when the
        input is missing or malformed or the files cannot be written. The
        error is printed rather than returned, because both outputs feed
        image components and a message string is not a valid image value.
    """
    if img is None:
        print("An error occurred: no image was provided.")
        return None, None

    try:
        # Read both entries before writing anything so a missing key cannot
        # leave only one of the two files updated.
        base_image = img["image"]
        mask_image = img["mask"]
        imageio.imwrite("Original Image.png", base_image)
        imageio.imwrite("Mask Image.png", mask_image)
    except KeyError as e:
        print(f"Key error: {e}")
        return None, None
    except Exception as e:
        print(f"An error occurred: {e}")
        return None, None

    return base_image, mask_image
|
|
|
|
|
def main():
    """Assemble the Gradio interface and launch it.

    The first row holds a sketch-enabled image input plus two previews that
    ``Mask`` fills with the base image and the mask. The second row holds the
    prompt box, its button and the image produced by ``perform_inpainting``.
    """
    # NOTE(review): the nesting below was reconstructed from the statement
    # grouping of the original, whose indentation was lost - confirm layout.
    with gr.Blocks() as demo:
        with gr.Row():
            sketch_input = gr.Image(tool="sketch", label="Paint Image", show_label=True)
            original_preview = gr.Image(label="Original Image")
            mask_preview = gr.Image(label="Mask Image", show_label=True)

        extract_button = gr.Button()
        extract_button.click(
            fn=Mask,
            inputs=sketch_input,
            outputs=[original_preview, mask_preview],
        )

        with gr.Row():
            prompt_box = gr.Textbox(label="Enter the prompt")
            generate_button = gr.Button("Click")
            result_view = gr.Image(label="Generated Image")

        generate_button.click(
            fn=perform_inpainting,
            inputs=prompt_box,
            outputs=result_view,
        )

    demo.launch()
|
|
|
|
|
# Start the demo only when this file is executed as a script.
if __name__ == "__main__":
    main()
|
|