#!/usr/bin/env python from __future__ import annotations import argparse import os import pathlib import subprocess if os.getenv('SYSTEM') == 'spaces': import mim mim.uninstall('mmcv-full', confirm_yes=True) mim.install('mmcv-full==1.5.2', is_yes=True) subprocess.call('pip uninstall -y opencv-python'.split()) subprocess.call('pip uninstall -y opencv-python-headless'.split()) subprocess.call('pip install opencv-python-headless==4.5.5.64'.split()) subprocess.call('pip install pycocotools'.split()) subprocess.call('pip install detectron2==0.5 -f https://dl.fbaipublicfiles.com/detectron2/wheels/cu101/torch1.7/index.html'.split()) import cv2 import gradio as gr import numpy as np from mmdet.apis import init_detector, inference_detector from utils import show_result from mmcv import Config DESCRIPTION = '''# OpenPSG This is an official demo for [OpenPSG](https://github.com/Jingkang50/OpenPSG). overview ''' FOOTER = 'visitor badge' def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument('--device', type=str, default='cpu') parser.add_argument('--theme', type=str) parser.add_argument('--share', action='store_true') parser.add_argument('--port', type=int) parser.add_argument('--disable-queue', dest='enable_queue', action='store_false') return parser.parse_args() def update_input_image(image: np.ndarray) -> dict: if image is None: return gr.Image.update(value=None) scale = 1500 / max(image.shape[:2]) if scale < 1: image = cv2.resize(image, None, fx=scale, fy=scale) return gr.Image.update(value=image) def set_example_image(example: list) -> dict: return gr.Image.update(value=example[0]) class Model: def __init__(self, model_name, device='cpu'): model_ckt ='OpenPSG/checkpoints/epoch_60.pth' cfg = Config.fromfile('OpenPSG/configs/psgtr/psgtr_r50_psg_inference.py') self.model = init_detector(cfg, model_ckt, device=device) def infer(self, input_image, num_rel): result = inference_detector(self.model, input_image) return show_result(input_image, result, is_one_stage=True, num_rel=num_rel, show=True ) def main(): args = parse_args() with gr.Blocks(theme=args.theme, css='style.css') as demo: model = Model('psgtr', device=args.device) gr.Markdown(DESCRIPTION) with gr.Row(): with gr.Column(): with gr.Row(): input_image = gr.Image(label='Input Image', type='numpy') with gr.Group(): with gr.Row(): num_rel = gr.Slider( 5, 100, step=5, value=20, label='Number of Relations') with gr.Row(): run_button = gr.Button(value='Run') with gr.Column(): with gr.Row(): result = gr.Gallery(label='Result', type='numpy') with gr.Row(): paths = sorted(pathlib.Path('images').rglob('*.jpg')) example_images = gr.Dataset(components=[input_image], samples=[[path.as_posix()] for path in paths]) gr.Markdown(FOOTER) input_image.change(fn=update_input_image, inputs=input_image, outputs=input_image) run_button.click(fn=model.infer, inputs=[ input_image, num_rel ], outputs=result) example_images.click(fn=set_example_image, inputs=example_images, outputs=input_image) demo.launch( enable_queue=args.enable_queue, server_port=args.port, share=args.share, ) if __name__ == '__main__': main()