# stable_edit/model.py

import torch
from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation
from diffusers import StableDiffusionInpaintPipeline # , DiffusionPipeline
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Image segmentation
seg_processor = AutoImageProcessor.from_pretrained("facebook/mask2former-swin-base-coco-panoptic")
seg_model = Mask2FormerForUniversalSegmentation.from_pretrained("facebook/mask2former-swin-base-coco-panoptic")
def segment_image(image):
    """ Runs Mask2Former panoptic segmentation on a PIL image and returns the
        prediction dict along with a mapping of segment id -> class label. """
    inputs = seg_processor(image, return_tensors="pt")

    with torch.no_grad():
        seg_outputs = seg_model(**inputs)

    # get prediction dict (segmentation map + segments_info)
    seg_prediction = seg_processor.post_process_panoptic_segmentation(
        seg_outputs, target_sizes=[image.size[::-1]]
    )[0]

    # map each segment id to its human-readable class label
    segment_labels = {}
    for segment in seg_prediction['segments_info']:
        segment_id = segment['id']
        segment_label_id = segment['label_id']
        segment_label = seg_model.config.id2label[segment_label_id]
        segment_labels.update({segment_id : segment_label})

    return seg_prediction, segment_labels
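
# Illustrative sketch (not in the original file): a hypothetical helper that turns
# the panoptic output into a binary PIL mask suitable for inpainting. It assumes
# seg_prediction['segmentation'] is an (H, W) tensor of segment ids, which is what
# post_process_panoptic_segmentation returns.
def get_mask(seg_prediction, segment_id):
    import numpy as np
    from PIL import Image

    seg_map = seg_prediction['segmentation'].cpu().numpy()
    mask = (seg_map == segment_id).astype(np.uint8) * 255  # white = region to repaint
    return Image.fromarray(mask)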

# Image inpainting pipeline
# get Stable Diffusion model for image inpainting
if device == 'cuda':
    pipe = StableDiffusionInpaintPipeline.from_pretrained(
        "runwayml/stable-diffusion-inpainting",
        torch_dtype=torch.float16,
    ).to(device)
else:
    pipe = StableDiffusionInpaintPipeline.from_pretrained(
        "runwayml/stable-diffusion-inpainting",
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
# pipe = StableDiffusionInpaintPipeline.from_pretrained( # DiffusionPipeline.from_pretrained(
# "runwayml/stable-diffusion-inpainting",
# revision="fp16",
# torch_dtype=torch.bfloat16,
# # device_map="auto" # use for Hugging face spaces
# )
# pipe.to(device) # use for local environment
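# Optional tweak (assumption, not in the original file): attention slicing from the
# diffusers pipeline API can reduce peak GPU memory at a small speed cost.
# pipe.enable_attention_slicing()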

def inpaint(image, mask, W, H, prompt="", seed=0, guidance_scale=17.5, num_samples=3):
    """ Uses the Stable Diffusion inpainting pipeline to inpaint an image.
        Inputs:
            image - input image (PIL or torch tensor)
            mask - mask for inpainting, same size as image (PIL or torch tensor)
            W - width of the output image
            H - height of the output image
            prompt - text prompt describing what to paint into the masked region
            seed - random seed for reproducibility
            guidance_scale - classifier-free guidance strength
            num_samples - number of images to generate per prompt
        Outputs:
            images - list of generated PIL images
        """
    generator = torch.Generator(device=device).manual_seed(seed)

    images = pipe(
        prompt=prompt,
        image=image,
        mask_image=mask,  # ensure mask is same type as image
        height=H,
        width=W,
        guidance_scale=guidance_scale,
        generator=generator,
        num_images_per_prompt=num_samples,
    ).images

    return images
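
# Illustrative usage sketch (not in the original file). The image path, prompt, and
# 512x512 resize are placeholder assumptions; 512x512 matches the resolution the
# inpainting checkpoint was trained at.
if __name__ == "__main__":
    from PIL import Image

    image = Image.open("example.jpg").convert("RGB").resize((512, 512))  # placeholder path
    seg_prediction, segment_labels = segment_image(image)
    print(segment_labels)  # inspect which segment ids map to which labels

    # pick one segment to repaint (here simply the first one) and build its mask
    segment_id = next(iter(segment_labels))
    mask = get_mask(seg_prediction, segment_id)

    results = inpaint(image, mask, W=512, H=512, prompt="a red sofa")  # placeholder prompt
    results[0].save("inpainted.jpg")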