File size: 4,202 Bytes
483de47 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 |
import data
import cv2
import torch
import numpy as np
from PIL import Image, ImageDraw
from tqdm import tqdm
from models import imagebind_model
from models.imagebind_model import ModalityType
from segment_anything import build_sam, SamAutomaticMaskGenerator
from utils import (
segment_image,
convert_box_xywh_to_xyxy,
get_indices_of_values_above_threshold,
)
# Run on the GPU when torch reports one is available; otherwise use the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
"""
Step 1: Instantiate model
"""
# Segment Anything: build the model from the local checkpoint, move it to
# `device`, and wrap it in the automatic mask generator (points_per_side=16).
sam_model = build_sam(checkpoint=".checkpoints/sam_vit_h_4b8939.pth").to(device)
mask_generator = SamAutomaticMaskGenerator(sam_model, points_per_side=16)
# ImageBind: load the pretrained weights, switch to eval mode, and move the
# model to the same device as SAM.
bind_model = imagebind_model.imagebind_huge(pretrained=True)
bind_model.eval()
bind_model.to(device)
"""
Step 2: Generate auto masks with SAM
"""
image_path = ".assets/car_image.jpg"
# cv2.imread does not raise on a missing or unreadable file; it returns None.
# Fail here with a clear message instead of a confusing error inside cvtColor.
image = cv2.imread(image_path)
if image is None:
    raise FileNotFoundError(f"Could not read image: {image_path}")
# The array is converted from BGR to RGB channel order before it is handed
# to the SAM mask generator.
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# Each entry of `masks` is later read through its "segmentation" and "bbox" keys.
masks = mask_generator.generate(image)
"""
Step 3: Get cropped images based on mask and box
"""
# Re-open the picture with PIL; from here on `image` is the PIL image, not the
# OpenCV array used for mask generation.
image = Image.open(image_path)
# One crop per SAM mask: apply the segmentation to the image, then crop to the
# mask's bounding box (converted from xywh to xyxy). tqdm shows progress.
cropped_boxes = [
    segment_image(image, mask["segmentation"]).crop(
        convert_box_xywh_to_xyxy(mask["bbox"])
    )
    for mask in tqdm(masks)
]
"""
Step 4: Run ImageBind model to get similarity between cropped image and different modalities
"""
def retriev_vision_and_text(elements, text_list):
    """Score every vision element against every text prompt with ImageBind.

    Both inputs are loaded onto ``device`` and embedded by ``bind_model``
    without gradient tracking. The vision embeddings are multiplied by the
    transposed text embeddings, and a softmax is taken along dim 0, i.e.
    across the vision elements for each prompt (assuming the embeddings are
    2-D with one row per item -- TODO confirm).

    Args:
        elements: Forwarded to
            ``data.load_and_transform_vision_data_from_pil_image``
            (presumably a list of PIL images -- verify against that helper).
        text_list: Forwarded to ``data.load_and_transform_text``.

    Returns:
        A one-element tuple ``(similarity,)``. Callers index ``[0]`` to get
        the score tensor (shape [113, 1] in the demo run, per the original
        comment).
    """
    model_inputs = {
        ModalityType.VISION: data.load_and_transform_vision_data_from_pil_image(elements, device),
        ModalityType.TEXT: data.load_and_transform_text(text_list, device),
    }
    with torch.no_grad():
        embedded = bind_model(model_inputs)
    vision_emb = embedded[ModalityType.VISION]
    text_emb = embedded[ModalityType.TEXT]
    similarity = torch.softmax(vision_emb @ text_emb.T, dim=0)
    # Deliberately a 1-tuple: existing callers read the tensor via result[0].
    return (similarity,)
def retriev_vision_and_audio(elements, audio_list):
    """Score every vision element against every audio clip with ImageBind.

    Both inputs are loaded onto ``device`` and embedded by ``bind_model``
    without gradient tracking. The vision embeddings are multiplied by the
    transposed audio embeddings, and a softmax is taken along dim 0, i.e.
    across the vision elements for each clip (assuming the embeddings are
    2-D with one row per item -- TODO confirm).

    Args:
        elements: Forwarded to
            ``data.load_and_transform_vision_data_from_pil_image``
            (presumably a list of PIL images -- verify against that helper).
        audio_list: Forwarded to ``data.load_and_transform_audio_data``
            (the demo passes a list of audio file paths).

    Returns:
        A one-element tuple ``(similarity,)``. Callers index ``[0]`` to get
        the score tensor.
    """
    model_inputs = {
        ModalityType.VISION: data.load_and_transform_vision_data_from_pil_image(elements, device),
        ModalityType.AUDIO: data.load_and_transform_audio_data(audio_list, device),
    }
    with torch.no_grad():
        embedded = bind_model(model_inputs)
    vision_emb = embedded[ModalityType.VISION]
    audio_emb = embedded[ModalityType.AUDIO]
    similarity = torch.softmax(vision_emb @ audio_emb.T, dim=0)
    # Deliberately a 1-tuple: existing callers read the tensor via result[0].
    return (similarity,)
# Score every crop against the reference audio clip and against the text
# prompt. Each call returns a 1-tuple whose first item holds the scores.
audio_queries = [".assets/car_audio.wav"]
text_queries = ["A car"]
vision_audio_result = retriev_vision_and_audio(cropped_boxes, audio_queries)
vision_text_result = retriev_vision_and_text(cropped_boxes, text_queries)
"""
Step 5: Merge the top similarity masks to get the final mask and save the merged mask
This is the audio retrieval result
"""
def _save_merged_mask(similarity, threshold, sam_masks, canvas_size, output_path):
    """Union the SAM masks whose similarity beats *threshold* and save them.

    The result is a black image with every selected mask painted white.

    Args:
        similarity: Scores passed to ``get_indices_of_values_above_threshold``
            (shape [113, 1] in the demo run, per the original comment).
        threshold: Cut-off forwarded to the same helper.
        sam_masks: Mask records indexed by the selected indices; each
            record's ``"segmentation"`` array is painted onto the canvas.
        canvas_size: ``(width, height)`` of the output image.
        output_path: Destination file for the merged mask.
    """
    selected = get_indices_of_values_above_threshold(similarity, threshold)
    # Opaque black background; selected regions are painted on top of it.
    canvas = Image.new('RGBA', canvas_size, (0, 0, 0, 255))
    painter = ImageDraw.Draw(canvas)
    for seg_idx in selected:
        # Scale the mask to 0/255 so it can act as a bitmap stencil
        # (assumes a boolean / 0-1 array -- TODO confirm).
        stencil = Image.fromarray(sam_masks[seg_idx]["segmentation"].astype('uint8') * 255)
        painter.bitmap((0, 0), stencil, fill=(255, 255, 255, 0))
    # Flatten to RGB before saving; the output files use the .jpg extension.
    canvas.convert("RGB").save(output_path)


# Audio-driven mask. The retrieval functions return a 1-tuple, hence the [0].
_save_merged_mask(vision_audio_result[0], 0.025, masks, image.size, "./audio_sam_merged_mask.jpg")
"""
Image / Text mask
"""
# Text-driven mask, using a higher threshold than the audio one.
_save_merged_mask(vision_text_result[0], 0.05, masks, image.size, "./text_sam_merged_mask.jpg")
|