# NOTE: source recovered from a Hugging Face Spaces page scrape
# (the "Spaces:" / "Runtime error" status banners were page chrome, not code).
import torch
from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation
from diffusers import StableDiffusionInpaintPipeline  # , DiffusionPipeline

# Run on GPU when one is available, otherwise fall back to CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Image segmentation: Mask2Former (Swin-base backbone) trained for COCO panoptic
# segmentation. The processor handles pre/post-processing for the model.
seg_processor = AutoImageProcessor.from_pretrained("facebook/mask2former-swin-base-coco-panoptic")
seg_model = Mask2FormerForUniversalSegmentation.from_pretrained("facebook/mask2former-swin-base-coco-panoptic")
def segment_image(image):
    """Run panoptic segmentation on an input image.

    Args:
        image: input PIL image.

    Returns:
        seg_prediction: dict returned by ``post_process_panoptic_segmentation``
            (contains the segmentation map and the ``'segments_info'`` list).
        segment_labels: dict mapping each segment id to its human-readable
            class label.
    """
    inputs = seg_processor(image, return_tensors="pt")
    # Inference only — no gradients needed.
    with torch.no_grad():
        seg_outputs = seg_model(**inputs)
    # target_sizes expects (height, width); PIL's .size is (width, height),
    # hence the [::-1] reversal. [0] takes the single image's prediction.
    seg_prediction = seg_processor.post_process_panoptic_segmentation(
        seg_outputs, target_sizes=[image.size[::-1]]
    )[0]
    # Build segment-id -> class-label lookup from the model's label map.
    segment_labels = {}
    for segment in seg_prediction['segments_info']:
        segment_labels[segment['id']] = seg_model.config.id2label[segment['label_id']]
    return seg_prediction, segment_labels
# Image inpainting pipeline:
# load the Stable Diffusion inpainting model, configured per device.
if device == 'cuda':
    # GPU: half precision saves memory; move the whole pipeline onto the GPU.
    pipe = StableDiffusionInpaintPipeline.from_pretrained(
        "runwayml/stable-diffusion-inpainting",
        torch_dtype=torch.float16,
    ).to(device)
else:
    # CPU (e.g. Hugging Face Spaces): bfloat16 with accelerate's automatic
    # device placement instead of an explicit .to(device).
    pipe = StableDiffusionInpaintPipeline.from_pretrained(
        "runwayml/stable-diffusion-inpainting",
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
def inpaint(image, mask, W, H, prompt="", seed=0, guidance_scale=17.5, num_samples=3):
    """Inpaint the masked region of an image with Stable Diffusion.

    Args:
        image: input image (PIL or torch tensor).
        mask: inpainting mask, same size as image (PIL or torch tensor);
            per diffusers convention, white pixels are repainted and black
            pixels are preserved.
        W: width of the output image in pixels.
        H: height of the output image in pixels.
        prompt: text prompt guiding the inpainting.
        seed: random seed for reproducible generation.
        guidance_scale: classifier-free guidance strength.
        num_samples: number of images to generate per prompt.

    Returns:
        images: list of generated output images.
    """
    # Seed a generator on the pipeline's device for reproducibility.
    generator = torch.Generator(device=device).manual_seed(seed)
    images = pipe(
        prompt=prompt,
        image=image,
        mask_image=mask,  # ensure mask is same type as image
        height=H,
        width=W,
        guidance_scale=guidance_scale,
        generator=generator,
        num_images_per_prompt=num_samples,
    ).images
    return images