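# Page-extraction residue (not part of the script): "File size: 2,377 Bytes"
# Page-extraction residue (not part of the script): "2cd560a" (presumably a revision hash)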
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
# !pip install diffusers transformers

import requests
import torch
import numpy as np
from PIL import Image
from io import BytesIO
from diffusers import DiffusionPipeline

from segment_anything import sam_model_registry, SamPredictor


"""
Step 1: Download and preprocess example demo images
"""
def download_image(url, timeout=30):
    """Fetch an image over HTTP and return it as an RGB PIL image.

    Args:
        url: Address of the image to download.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout ``requests.get`` can block indefinitely on a stalled
            connection.

    Returns:
        The decoded image, converted to ``"RGB"`` mode.

    Raises:
        requests.HTTPError: If the server responds with a 4xx/5xx status.
        requests.Timeout: If the server does not answer within ``timeout``.
    """
    response = requests.get(url, timeout=timeout)
    # Fail here on an HTTP error rather than handing an error page to PIL,
    # which would surface as a confusing "cannot identify image file" error.
    response.raise_for_status()
    return Image.open(BytesIO(response.content)).convert("RGB")


# Image that will be edited (passed to the pipeline as `image` further down).
img_url = "https://github.com/IDEA-Research/detrex-storage/blob/main/assets/grounded_sam/paint_by_example/input_image.png?raw=true"
# Reference image whose content is painted into the masked region.
# Alternative references are kept below; uncomment one to try it instead.
# example_url = "https://github.com/IDEA-Research/detrex-storage/blob/main/assets/grounded_sam/paint_by_example/pomeranian_example.jpg?raw=True"
# example_url = "https://github.com/IDEA-Research/detrex-storage/blob/main/assets/grounded_sam/paint_by_example/example_image.jpg?raw=true"
example_url = "https://github.com/IDEA-Research/detrex-storage/blob/main/assets/grounded_sam/paint_by_example/labrador_example.jpg?raw=true"

# Download both images (input first, then the example) and resize each to
# 512x512, matching the mask size noted in Step 3.
init_image, example_image = (
    download_image(url).resize((512, 512)) for url in (img_url, example_url)
)


"""
Step 2: Initialize SAM and PaintByExample models
"""

# Device selection: keep the original choice (second GPU) when it exists, but
# fall back to the first GPU or the CPU so the demo also runs on machines
# with fewer than two GPUs instead of failing on an invalid device ordinal.
if torch.cuda.device_count() > 1:
    DEVICE = "cuda:1"
elif torch.cuda.is_available():
    DEVICE = "cuda:0"
else:
    DEVICE = "cpu"

# Half precision is only used on GPU; on CPU fall back to float32.
TORCH_DTYPE = torch.float16 if DEVICE.startswith("cuda") else torch.float32

# SAM
SAM_ENCODER_VERSION = "vit_h"
# NOTE: machine-specific absolute path; point this at your local SAM checkpoint.
SAM_CHECKPOINT_PATH = "/comp_robot/rentianhe/code/Grounded-Segment-Anything/sam_vit_h_4b8939.pth"
sam = sam_model_registry[SAM_ENCODER_VERSION](checkpoint=SAM_CHECKPOINT_PATH).to(device=DEVICE)
sam_predictor = SamPredictor(sam)
# Hand the input image to the predictor (as an array) before prompting it.
sam_predictor.set_image(np.array(init_image))

# PaintByExample Pipeline
# NOTE: machine-specific absolute path used as the download cache for weights.
CACHE_DIR = "/comp_robot/rentianhe/weights/diffusers/"
pipe = DiffusionPipeline.from_pretrained(
    "Fantasy-Studio/Paint-by-Example",
    torch_dtype=TORCH_DTYPE,
    cache_dir=CACHE_DIR,
)
pipe = pipe.to(DEVICE)


"""
Step 3: Get masks with SAM by prompt (box or point) and inpaint the mask region by example image.
"""

# A single point prompt at pixel (350, 256) of the input image.
input_point = np.array([[350, 256]])
input_label = np.array([1])  # 1 = positive label for that point

# predict() yields three values; only the masks are used here. With
# multimask_output=False a single mask is requested.
masks, _, _ = sam_predictor.predict(
    point_coords=input_point, point_labels=input_label, multimask_output=False
)
mask = masks[0]  # drop the leading axis: [1, 512, 512] -> [512, 512] np.ndarray
mask_pil = Image.fromarray(mask)

# Keep a copy of the mask on disk for inspection.
mask_pil.save("./mask.jpg")

# Collect the pipeline arguments first so the call itself stays short.
# NOTE(review): 500 inference steps looks unusually high — confirm it is intended.
paint_kwargs = dict(
    image=init_image,
    mask_image=mask_pil,
    example_image=example_image,
    num_inference_steps=500,
    guidance_scale=9.0,
)
result = pipe(**paint_kwargs)
image = result.images[0]

image.save("./paint_by_example_demo.jpg")