from PIL import Image | |
import numpy as np | |
def segment_image(image, segmentation_mask):
    """Return an RGB copy of *image* with every pixel outside the mask set to black.

    Args:
        image: A PIL image (or an array accepted by ``Image.fromarray``) whose
            array form has height and width as its first two dimensions.
        segmentation_mask: A 2-D mask with the same height and width as
            *image*. Any truthy entry marks a pixel to keep. The mask is
            coerced to ``bool``, so integer (0/1, 0/255) and float masks work
            as well as boolean ones.

    Returns:
        A new ``"RGB"`` PIL image. Pixels where the mask is set come from
        *image*; all other pixels are ``(0, 0, 0)``.

    Raises:
        IndexError: typically, if the mask's shape does not match the image's
            height and width (raised by NumPy boolean indexing).
    """
    # Coerce to bool. A non-boolean integer mask would otherwise be treated by
    # NumPy as fancy (row-index) indexing, selecting rows 0 and 1 instead of
    # the intended pixels.
    mask = np.asarray(segmentation_mask, dtype=bool)

    image_array = np.array(image)
    segmented_image_array = np.zeros_like(image_array)
    segmented_image_array[mask] = image_array[mask]
    segmented_image = Image.fromarray(segmented_image_array)

    # Take the size from the array-backed image, which is (width, height).
    # For a PIL input this equals ``image.size``; it also works for ndarray
    # inputs, whose ``.size`` is an element count rather than dimensions.
    black_image = Image.new("RGB", segmented_image.size, (0, 0, 0))

    # 8-bit paste mask: 255 where the segmentation mask is set, 0 elsewhere.
    # Pillow infers mode "L" for a 2-D uint8 array, so no explicit ``mode=``
    # is passed (that argument is deprecated in recent Pillow releases).
    transparency_mask = np.zeros(mask.shape, dtype=np.uint8)
    transparency_mask[mask] = 255
    transparency_mask_image = Image.fromarray(transparency_mask)

    # ``paste`` converts the source to the destination's mode ("RGB") if needed.
    black_image.paste(segmented_image, mask=transparency_mask_image)
    return black_image
def convert_box_xywh_to_xyxy(box):
    """Convert a box from ``[x, y, width, height]`` to ``[x1, y1, x2, y2]``.

    Only the first four entries of *box* are read, so any extra trailing
    entries are ignored.

    Returns:
        A new list ``[x, y, x + width, y + height]``.
    """
    left, top = box[0], box[1]
    width, height = box[2], box[3]
    right = left + width
    bottom = top + height
    return [left, top, right, bottom]
def get_indices_of_values_above_threshold(values, threshold):
    """Return the positions of all entries in *values* strictly greater than *threshold*.

    Positions are returned in ascending order as a list. Entries equal to
    *threshold* are excluded.
    """
    return [
        position
        for position, value in enumerate(values)
        if value > threshold
    ]