|
import logging |
|
from typing import List, Tuple, Dict |
|
|
|
import streamlit as st |
|
import torch |
|
import gc |
|
import numpy as np |
|
from PIL import Image |
|
|
|
from transformers import AutoImageProcessor, UperNetForSemanticSegmentation |
|
|
|
from palette import ade_palette |
|
|
|
# Module-scoped logger, named after this module.
LOGGING: logging.Logger = logging.getLogger(name=__name__)
|
|
|
|
|
def flush():
    """Free unreachable Python objects, then drop PyTorch's cached CUDA memory.

    The garbage-collection pass runs first so that tensors which are no longer
    referenced are released before the CUDA cache is emptied.
    """
    for release in (gc.collect, torch.cuda.empty_cache):
        release()
|
|
|
# Streamlit deprecated ``experimental_singleton`` in favour of ``cache_resource``
# (same semantics); use the new name when available, the old one otherwise.
_cache_resource = getattr(st, "cache_resource", None) or st.experimental_singleton


@_cache_resource(max_entries=5)
def get_segmentation_pipeline(
        model_name: str = "openmmlab/upernet-convnext-small",
) -> Tuple[AutoImageProcessor, UperNetForSemanticSegmentation]:
    """Load (and cache) the image processor and UperNet segmentation model.

    Args:
        model_name (str): Hugging Face checkpoint from which both the processor
            and the model are loaded. Defaults to
            ``"openmmlab/upernet-convnext-small"``, the checkpoint this function
            always loaded before the argument existed.

    Returns:
        Tuple[AutoImageProcessor, UperNetForSemanticSegmentation]: the processor
        and the model, in that order. Results are cached by Streamlit per
        argument value (at most 5 entries), so repeated calls reuse the
        already-loaded objects instead of reloading the checkpoint.
    """
    image_processor = AutoImageProcessor.from_pretrained(model_name)
    image_segmentor = UperNetForSemanticSegmentation.from_pretrained(model_name)
    return image_processor, image_segmentor
|
|
|
|
|
def _colorize_segmentation(seg: np.ndarray, palette: np.ndarray) -> np.ndarray:
    """Map an (H, W) array of integer class ids to an (H, W, 3) uint8 colour image.

    Class id ``i`` gets ``palette[i]``; ids with no palette entry are black.
    """
    palette = np.asarray(palette).reshape(-1, 3)
    n_colors = palette.shape[0]
    n_ids = int(seg.max()) + 1 if seg.size else 0
    # Lookup table large enough for every id present; rows past the palette stay 0 (black).
    lut = np.zeros((max(n_colors, n_ids), 3), dtype=np.uint8)
    lut[:n_colors] = palette
    # A single fancy-indexing pass instead of one full-image mask per class.
    return lut[seg]


@torch.inference_mode()
@torch.autocast('cuda')
def segment_image(image: Image.Image) -> Image.Image:
    """Segment an image and render the class map with the ADE palette.

    Args:
        image (Image.Image): input image. Images not already in ``"RGB"`` mode
            (e.g. RGBA or greyscale uploads) are converted to RGB first.

    Returns:
        Image.Image: RGB image with the same size as ``image`` in which every
        pixel is coloured by its predicted class; classes without a palette
        entry are rendered black.
    """
    image_processor, image_segmentor = get_segmentation_pipeline()

    # NOTE(review): the processor is fed 3-channel images; RGBA/greyscale input
    # is normalised to RGB here rather than relying on the processor to cope.
    if image.mode != "RGB":
        image = image.convert("RGB")

    pixel_values = image_processor(image, return_tensors="pt").pixel_values
    # Follow the model's device (CPU unless the model was moved elsewhere).
    # Autograd is already disabled by @torch.inference_mode(), so no extra
    # torch.no_grad() context is needed.
    outputs = image_segmentor(pixel_values.to(image_segmentor.device))

    # target_sizes expects (height, width); PIL's Image.size is (width, height).
    seg = image_processor.post_process_semantic_segmentation(
        outputs, target_sizes=[image.size[::-1]])[0]

    color_seg = _colorize_segmentation(seg.cpu().numpy(), np.array(ade_palette()))
    return Image.fromarray(color_seg).convert('RGB')