import torch
from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation
from diffusers import StableDiffusionInpaintPipeline  # , DiffusionPipeline

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Image segmentation
# Mask2Former panoptic model: predicts a segmentation map plus per-segment labels
seg_processor = AutoImageProcessor.from_pretrained("facebook/mask2former-swin-base-coco-panoptic")
seg_model = Mask2FormerForUniversalSegmentation.from_pretrained("facebook/mask2former-swin-base-coco-panoptic")

def segment_image(image):
    """
    Runs panoptic segmentation on an image.
    Inputs:
        image - input PIL image
    Outputs:
        seg_prediction - dict with the 'segmentation' map and 'segments_info'
        segment_labels - dict mapping segment id -> class label
    """
    inputs = seg_processor(image, return_tensors="pt")
    with torch.no_grad():
        seg_outputs = seg_model(**inputs)
    # get prediction dict; target_sizes expects (height, width), PIL size is (width, height)
    seg_prediction = seg_processor.post_process_panoptic_segmentation(
        seg_outputs, target_sizes=[image.size[::-1]]
    )[0]
    # get segment labels dict
    segment_labels = {}
    for segment in seg_prediction['segments_info']:
        segment_id = segment['id']
        segment_label_id = segment['label_id']
        segment_label = seg_model.config.id2label[segment_label_id]
        segment_labels[segment_id] = segment_label
    return seg_prediction, segment_labels

# Image inpainting pipeline
# get Stable Diffusion model for image inpainting
if device == 'cuda':
    pipe = StableDiffusionInpaintPipeline.from_pretrained(
        "runwayml/stable-diffusion-inpainting",
        torch_dtype=torch.float16,
    ).to(device)
else:
    pipe = StableDiffusionInpaintPipeline.from_pretrained(
        "runwayml/stable-diffusion-inpainting",
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
# Alternative loading path, kept for reference:
# pipe = DiffusionPipeline.from_pretrained(
#     "runwayml/stable-diffusion-inpainting",
#     revision="fp16",
#     torch_dtype=torch.bfloat16,
#     # device_map="auto"  # use for Hugging Face Spaces
# )
# pipe.to(device)  # use for local environment

def inpaint(image, mask, W, H, prompt="", seed=0, guidance_scale=17.5, num_samples=3):
    """
    Uses the Stable Diffusion inpainting model to repaint the masked region of an image.
    Inputs:
        image - input image (PIL or torch tensor)
        mask - mask for inpainting, same size as image (PIL or torch tensor);
               white/nonzero pixels mark the region to repaint
        W - output image width in pixels
        H - output image height in pixels
        prompt - text prompt guiding the inpainting
        seed - random seed for reproducible generation
        guidance_scale - classifier-free guidance strength
        num_samples - number of images to generate per prompt
    Outputs:
        images - list of output PIL images
    """
    generator = torch.Generator(device=device).manual_seed(seed)
    images = pipe(
        prompt=prompt,
        image=image,
        mask_image=mask,  # ensure mask is same type as image
        height=H,
        width=W,
        guidance_scale=guidance_scale,
        generator=generator,
        num_images_per_prompt=num_samples,
    ).images
    return images
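
# --- Usage sketch (not part of the original script) ---
# A minimal example of chaining the two stages: segment an image, build a
# binary mask for one predicted segment, then inpaint that region. The file
# names ("example.jpg", "inpainted.png"), the chosen segment, and the prompt
# are placeholders for illustration only.
if __name__ == "__main__":
    import numpy as np
    from PIL import Image

    image = Image.open("example.jpg").convert("RGB").resize((512, 512))  # hypothetical input file
    seg_prediction, segment_labels = segment_image(image)
    print(segment_labels)  # e.g. {1: 'sky', 2: 'person', ...}

    # Pick the first segment id and turn its region into a white-on-black PIL mask
    target_id = next(iter(segment_labels))
    mask_array = (seg_prediction['segmentation'].cpu().numpy() == target_id).astype(np.uint8) * 255
    mask = Image.fromarray(mask_array, mode="L")

    results = inpaint(image, mask, W=512, H=512, prompt="a grassy field", num_samples=1)
    results[0].save("inpainted.png")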