"""Gradio demo for EVF-SAM2: text-prompted segmentation of images and videos.

Pinned runtime dependencies are installed at start-up, then the image and
video EVF-SAM2 models are loaded once onto the GPU and exposed through the
``inference_image`` / ``inference_video`` handlers.
"""
import os
import shutil
import subprocess
import sys

import spaces
from pip._internal import main

# os.system('python model/segment_anything_2/setup.py build_ext --inplace')

# Pin dependencies before the libraries below are imported.
# NOTE(review): ``pip._internal.main`` is not a supported pip API; pip's docs
# recommend running ``[sys.executable, "-m", "pip", "install", ...]`` in a
# subprocess instead. Installs stay best-effort (a failure does not stop the
# app), but failures are now reported instead of being silently ignored.
for _requirement in ("timm==1.0.8", "samv2", "torch==2.1.2", "numpy==1.21.6"):
    if main(["install", _requirement]) != 0:
        print(f"warning: 'pip install {_requirement}' failed", file=sys.stderr)

import timm

print("installed", timm.__version__)

import gradio as gr
from inference import sam_preprocess, beit3_preprocess
from model.evf_sam2 import EvfSam2Model
from model.evf_sam2_video import EvfSam2Model as EvfSam2VideoModel
from transformers import AutoTokenizer
import torch
import cv2
import numpy as np
import tqdm

version = "YxZhang/evf-sam2"
model_type = "sam2"

tokenizer = AutoTokenizer.from_pretrained(
    version,
    padding_side="right",
    use_fast=False,
)

kwargs = {
    "torch_dtype": torch.half,
}

image_model = EvfSam2Model.from_pretrained(version, low_cpu_mem_usage=True, **kwargs)
# The image model does not use SAM2's memory modules (presumably only needed
# for propagating masks across video frames), so free them before moving the
# model to the GPU.
del image_model.visual_model.memory_encoder
del image_model.visual_model.memory_attention
image_model = image_model.eval()
image_model.to('cuda')

video_model = EvfSam2VideoModel.from_pretrained(version, low_cpu_mem_usage=True, **kwargs)
video_model = video_model.eval()
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
video_model.to('cuda')

# Scratch locations for the video pipeline. They are shared by every video
# request and wiped at the start of each one, so overlapping video requests
# would clobber each other's files.
_VIDEO_WORK_DIR = "demo_temp"
_VIDEO_FRAMES_DIR = os.path.join(_VIDEO_WORK_DIR, "input_frames")
_VIDEO_OUTPUT_PATH = os.path.join(_VIDEO_WORK_DIR, "out.mp4")
# Used when OpenCV cannot report the source video's frame rate.
_DEFAULT_FPS = 30


@spaces.GPU
@torch.no_grad()
def inference_image(image_np, prompt):
    """Segment the object described by ``prompt`` in a single image.

    Args:
        image_np: image from the ``gr.Image(type='numpy')`` input, or ``None``
            when nothing was uploaded. Assumed to be an RGB uint8 array of
            shape (H, W, 3) — the color blend below relies on 3 channels.
        prompt: English phrase or sentence describing the target object.

    Returns:
        A float array with the same shape as ``image_np`` (values in [0, 1]
        for uint8 input) where pixels inside the predicted mask are blended
        50/50 with the color (50, 120, 220); other pixels are unchanged.

    Raises:
        gr.Error: if no image was provided.
    """
    if image_np is None:
        raise gr.Error("Please upload an image.")

    original_size_list = [image_np.shape[:2]]

    image_beit = beit3_preprocess(image_np, 224).to(dtype=image_model.dtype, device=image_model.device)

    image_sam, resize_shape = sam_preprocess(image_np, model_type=model_type)
    image_sam = image_sam.to(dtype=image_model.dtype, device=image_model.device)

    input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"].to(device=image_model.device)

    # infer
    pred_mask = image_model.inference(
        image_sam.unsqueeze(0),
        image_beit.unsqueeze(0),
        input_ids,
        resize_list=[resize_shape],
        original_size_list=original_size_list,
    )
    pred_mask = pred_mask.detach().cpu().numpy()[0]
    pred_mask = pred_mask > 0

    # Tint only the masked pixels; the blend result is cast back into the
    # input array's dtype on assignment.
    visualization = image_np.copy()
    visualization[pred_mask] = (image_np * 0.5 + pred_mask[:, :, None].astype(np.uint8) * np.array([50, 120, 220]) * 0.5)[pred_mask]

    return visualization / 255.0


def _extract_frames(video_path, frames_dir):
    """Decode ``video_path`` into ``frames_dir`` as 00000.jpg, 00001.jpg, ...

    Returns the sorted list of extracted frame file names.

    Raises:
        gr.Error: if ffmpeg produced no frames.
    """
    # An argument list (no shell) passes paths with spaces or shell
    # metacharacters through verbatim and cannot inject commands.
    result = subprocess.run([
        "ffmpeg", "-i", video_path,
        "-q:v", "2",
        "-start_number", "0",
        os.path.join(frames_dir, "%05d.jpg"),
    ])
    frame_names = sorted(os.listdir(frames_dir))
    if not frame_names:
        raise gr.Error(f"Could not extract any frames from the video (ffmpeg exit code {result.returncode}).")
    return frame_names


def _source_fps(video_path):
    """Return the frame rate OpenCV reports for ``video_path``, or ``_DEFAULT_FPS``."""
    capture = cv2.VideoCapture(video_path)
    try:
        fps = capture.get(cv2.CAP_PROP_FPS)
    finally:
        capture.release()
    # OpenCV reports 0 (or NaN) when the rate is unknown; NaN > 0 is False.
    return fps if fps > 0 else _DEFAULT_FPS


def _write_mask_overlay(frames_dir, frame_names, output, size, fps):
    """Render every frame with its predicted mask tinted red into ``_VIDEO_OUTPUT_PATH``.

    Args:
        frames_dir: directory holding the extracted frames.
        frame_names: frame file names in playback order.
        output: result of ``video_model.inference``; ``output[i][1]`` is used
            as the mask of frame ``i`` — presumably shaped (1, H, W); TODO
            confirm against ``EvfSam2VideoModel.inference``.
        size: ``(width, height)`` of the output video.
        fps: output frame rate.
    """
    video_writer = cv2.VideoWriter(_VIDEO_OUTPUT_PATH, fourcc, fps, size)
    try:
        for i, file in enumerate(tqdm.tqdm(frame_names, desc="generating video: ")):
            # cv2.imread yields BGR, so +128 on channel 2 tints the mask red.
            img = cv2.imread(os.path.join(frames_dir, file))
            vis = img + np.array([0, 0, 128]) * output[i][1].transpose(1, 2, 0)
            vis = np.clip(vis, 0, 255)
            vis = np.uint8(vis)
            video_writer.write(vis)
    finally:
        video_writer.release()


@spaces.GPU
@torch.no_grad()
@torch.autocast(device_type="cuda", dtype=torch.float16)
def inference_video(video_path, prompt):
    """Segment the object described by ``prompt`` in every frame of a video.

    Args:
        video_path: path of the uploaded video, or ``None`` when nothing was
            uploaded.
        prompt: English phrase or sentence describing the target object.

    Returns:
        Path of the rendered video (``demo_temp/out.mp4``) with the predicted
        mask tinted red, written at the source video's frame rate.

    Raises:
        gr.Error: if no video was provided or no frames could be extracted.
    """
    if video_path is None:
        raise gr.Error("Please upload a video.")

    shutil.rmtree(_VIDEO_WORK_DIR, ignore_errors=True)
    os.makedirs(_VIDEO_FRAMES_DIR, exist_ok=True)
    frame_names = _extract_frames(video_path, _VIDEO_FRAMES_DIR)

    # The text query is encoded together with the first frame only; the video
    # model receives the whole frame directory.
    image_np = cv2.imread(os.path.join(_VIDEO_FRAMES_DIR, frame_names[0]))
    image_np = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB)
    height, width = image_np.shape[:2]

    image_beit = beit3_preprocess(image_np, 224).to(dtype=video_model.dtype, device=video_model.device)
    input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"].to(device=video_model.device)

    # infer
    output = video_model.inference(
        _VIDEO_FRAMES_DIR,
        image_beit.unsqueeze(0),
        input_ids,
    )

    # save visualization
    _write_mask_overlay(_VIDEO_FRAMES_DIR, frame_names, output, (width, height), _source_fps(video_path))
    return _VIDEO_OUTPUT_PATH


# Markdown header shown at the top of the demo page.
desc = """

EVF-SAM: Early Vision-Language Fusion for Text-Prompted Segment Anything Model

EVF-SAM extends SAM's capabilities with text-prompted segmentation, achieving high accuracy in Referring Expression Segmentation.

"""

# Alternative header strings, currently unused:
# desc_title_str = 'Early Vision-Language Fusion for Text-Prompted Segment Anything Model'
# desc_link_str = '[![arxiv paper](https://img.shields.io/badge/arXiv-Paper-red)](https://arxiv.org/abs/2406.20076)'

# Help text shared by the image and video prompt boxes.
_PROMPT_INFO = (
    "Use a phrase or sentence to describe the object you want to segment. "
    "Currently we only support English"
)


def _prompt_controls():
    """Create a prompt textbox and a submit button in the current Gradio layout context.

    Returns:
        ``(textbox, button)`` in creation order.
    """
    textbox = gr.Textbox(label="Prompt", info=_PROMPT_INFO)
    button = gr.Button(value='Submit', scale=1, variant='primary')
    return textbox, button


with gr.Blocks() as demo:
    gr.Markdown(desc)

    # Image tab: input/output previews side by side, prompt row underneath.
    with gr.Tab(label="EVF-SAM-2-Image"):
        with gr.Row():
            input_image = gr.Image(type='numpy', label='Input Image', image_mode='RGB')
            output_image = gr.Image(type='numpy', label='Output Image')
        with gr.Row():
            image_prompt, submit_image = _prompt_controls()

    # Video tab: same layout as the image tab.
    with gr.Tab(label="EVF-SAM-2-Video"):
        with gr.Row():
            input_video = gr.Video(label='Input Video')
            output_video = gr.Video(label='Output Video')
        with gr.Row():
            video_prompt, submit_video = _prompt_controls()

    # Event handlers are attached inside the Blocks context.
    submit_image.click(
        fn=inference_image,
        inputs=[input_image, image_prompt],
        outputs=output_image,
    )
    submit_video.click(
        fn=inference_video,
        inputs=[input_video, video_prompt],
        outputs=output_video,
    )

demo.launch(show_error=True)