# PSG / app.py
# Last commit: e2d5627 "Update app.py" by liangyuch (file size 6.58 kB).
from __future__ import annotations

import argparse
import os
import pathlib
import subprocess
import sys
def _pip(*args: str) -> int:
    """Run pip with *args* in the interpreter executing this script.

    Using ``sys.executable -m pip`` (instead of whichever ``pip`` binary is
    first on PATH) guarantees that packages are installed into, and removed
    from, the same environment this app later imports them from.

    Returns:
        pip's exit code. Failures are returned, not raised, so the setup
        below stays best-effort exactly as before.
    """
    return subprocess.call([sys.executable, '-m', 'pip', *args])


# Only when SYSTEM == 'spaces' (presumably set by the Hugging Face Spaces
# runtime -- confirm), pin and install the runtime dependencies that the
# imports following this block rely on.
if os.getenv('SYSTEM') == 'spaces':
    import mim

    mim.uninstall('mmcv-full', confirm_yes=True)
    mim.install('mmcv-full==1.4.3', is_yes=True)
    # Swap whatever OpenCV packages are present for a pinned headless build.
    _pip('uninstall', '-y', 'opencv-python')
    _pip('uninstall', '-y', 'opencv-python-headless')
    _pip('install', 'opencv-python-headless==4.5.5.64')
    _pip('install', 'pycocotools')
    _pip('install', 'detectron2', '-f',
         'https://dl.fbaipublicfiles.com/detectron2/wheels/cu102/torch1.9/index.html')
    _pip('install', 'git+https://github.com/Jingkang50/OpenPSG.git@hugging_face_demo')
    _pip('install', 'git+https://github.com/cocodataset/panopticapi.git')
import cv2
import gradio as gr
import numpy as np
from mmdet.apis import init_detector, inference_detector
from utils import make_gif, show_result
from mmcv import Config
import openpsg
# Markdown/HTML header rendered at the top of the demo page (via gr.Markdown in
# main()). Paragraphs are separated by blank lines so Markdown does not merge
# them, and the HTML block is set off by blank lines so it is parsed as HTML.
DESCRIPTION = '''# ECCV'22 | Panoptic Scene Graph Generation

🚀 🚀 🚀 This is an official demo for our ECCV'22 paper: [Panoptic Scene Graph Generation](https://psgdataset.org/). Please star our [codebase](https://github.com/Jingkang50/OpenPSG) if you find it useful / interesting.

📒 📒 📒 **News:** The PSG Challenge (prize pool 🤑 **US$150K** 🤑) is now available on [International Algorithm Case Competition](https://www.cvmart.net/race/10349/base?organic_url=https%3A%2F%2Fhf.space%2F) and [ECCV'22 SenseHuman Workshop](https://sense-human.github.io/)!

🔍 🔍 🔍 Check out the [news section](https://github.com/Jingkang50/OpenPSG#updates) in our [GitHub repo](https://github.com/Jingkang50/OpenPSG) for more details. Everyone around the world is welcome to participate and explore comprehensive scene understanding!

🎯 🎯 🎯 The PSG Development Team is currently focusing on **(1) 🧙‍♂️ Next-Generation PSG Models**, **(2) 🕵️‍♀️ Relation-Aware Visual Reasoning from PSG Models**, and **(3) 🎨 Relation-Aware Image Generation from Scene Graph and Caption**. If you are interested in related research, please reach out and contact us!

<div class="row">
<div class="column">
<img id="logo" src="https://camo.githubusercontent.com/880346b66831a8212074787ba9a2301b4d700bd8f765ca11e4845ac0ab34c230/68747470733a2f2f6c6976652e737461746963666c69636b722e636f6d2f36353533352f35323139333837393637375f373531613465306237395f6b2e6a7067" alt="logo" style="width:60%">
</div>
<div class="column">
<img id="visualzation" src="https://github.com/Jingkang50/OpenPSG/blob/main/assets/psgtr_long.gif?raw=true" alt="visualization" style="width:60%">
</div>
</div>

Inference takes 10-30 seconds per image. The model is PSGTR (60 epochs).
'''
# HTML rendered below the examples: a visitor-counter badge image.
FOOTER = ('<img id="visitor-badge" '
          'src="https://visitor-badge.glitch.me/badge?page_id=c-liangyu.openpsg" '
          'alt="visitor badge" />')
def parse_args() -> argparse.Namespace:
    """Parse the demo's command-line options from ``sys.argv``.

    Options:
        --device         device string for the model (default ``'cpu'``).
        --theme          Gradio theme name (default ``None``).
        --share          create a public share link.
        --port           server port (default ``None``).
        --disable-queue  store ``False`` into ``enable_queue`` (default ``True``).
    """
    cli = argparse.ArgumentParser()
    cli.add_argument('--device', type=str, default='cpu')
    cli.add_argument('--theme', type=str)
    cli.add_argument('--share', action='store_true')
    cli.add_argument('--port', type=int)
    cli.add_argument('--disable-queue', dest='enable_queue', action='store_false')
    namespace = cli.parse_args()
    return namespace
def update_input_image(image: np.ndarray) -> dict:
    """Downscale a newly set input image so its longer side is at most 800 px.

    Args:
        image: The input component's value as a numpy array (the component is
            created with ``type='numpy'``), or ``None`` when it was cleared.

    Returns:
        A ``gr.Image.update`` carrying the (possibly resized) image, or ``None``
        when there is no image. Images already within the limit are returned
        unchanged.
    """
    if image is None:
        return gr.Image.update(value=None)
    # shape[:2] holds the two spatial dimensions; scale by the larger one.
    scale = 800 / max(image.shape[:2])
    if scale < 1:
        # INTER_AREA is OpenCV's recommended filter for shrinking; the default
        # INTER_LINEAR produces aliasing/moire on large reductions.
        image = cv2.resize(image, None, fx=scale, fy=scale,
                           interpolation=cv2.INTER_AREA)
    return gr.Image.update(value=image)
def set_example_image(example: list) -> dict:
    """Load a clicked example into the input image component.

    Each example row holds a single entry, the image's path (see the
    ``gr.Dataset`` samples built in ``main``); that path becomes the new value.
    """
    chosen_path = example[0]
    return gr.Image.update(value=chosen_path)
class Model:
    """Wrapper around the OpenPSG PSGTR detector used by the demo's Run button.

    The detector is built once from an mmcv config and a checkpoint; ``infer``
    runs it on a single image and visualises the result with ``show_result``.
    """

    # Upper bound on how many rendered frames are fed into the summary GIF.
    MAX_GIF_FRAMES = 10

    def __init__(self, model_name, device='cpu', *,
                 config_path='OpenPSG/configs/psgtr/psgtr_r50_psg_inference.py',
                 checkpoint_path='OpenPSG/checkpoints/epoch_60.pth'):
        """Load the detector.

        Args:
            model_name: Currently unused; kept so existing call sites
                (``Model('psgtr', ...)``) keep working.
            device: Device string forwarded to ``init_detector``.
            config_path: mmcv config file describing the model.
            checkpoint_path: Checkpoint whose weights are loaded.
        """
        cfg = Config.fromfile(config_path)
        self.model = init_detector(cfg, checkpoint_path, device=device)

    def infer(self, input_image, num_rel):
        """Run scene-graph inference on one image.

        Args:
            input_image: Image array from the input component.
            num_rel: Number of relations to render (from the UI slider).

        Returns:
            ``(gif, displays)``: the GIF built by ``make_gif`` from at most
            ``MAX_GIF_FRAMES`` leading entries of ``displays``, and the full
            ``displays`` sequence returned by ``show_result``.

        Raises:
            ValueError: if no image was provided.
        """
        if input_image is None:
            # Fail with a clear message instead of deep inside mmdet.
            raise ValueError('Please upload an input image before clicking Run.')
        result = inference_detector(self.model, input_image)
        # NOTE(review): the slider value may arrive as a number that is not an
        # int; presumably show_result uses it as a count, so normalise it.
        displays = show_result(input_image,
                               result,
                               is_one_stage=True,
                               num_rel=int(num_rel),
                               show=True)
        # Slicing already copes with sequences shorter than the limit.
        gif = make_gif(displays[:self.MAX_GIF_FRAMES])
        return gif, displays
def main():
    """Build the PSG Gradio demo, wire up its events, and launch the server.

    The page shows the description, then an input column (image, relation
    count slider, Run button) next to an output column (``Top Relations``
    image and ``PSGTR Result`` gallery), clickable examples taken from every
    ``*.jpg`` under ``images/``, and finally the footer.
    """
    opts = parse_args()
    psg_model = Model('psgtr', device=opts.device)
    example_paths = sorted(pathlib.Path('images').rglob('*.jpg'))
    example_samples = [[p.as_posix()] for p in example_paths]

    with gr.Blocks(theme=opts.theme, css='style.css') as demo:
        gr.Markdown(DESCRIPTION)

        with gr.Row():
            # Left column: inputs.
            with gr.Column():
                with gr.Row():
                    input_image = gr.Image(label='Input Image', type='numpy')
                with gr.Group():
                    with gr.Row():
                        num_rel = gr.Slider(5, 100, step=5, value=20,
                                            label='Number of Relations')
                with gr.Row():
                    run_button = gr.Button(value='Run')
            # Right column: outputs.
            with gr.Column():
                with gr.Row():
                    gif = gr.Image(label='Top Relations')
                with gr.Row():
                    displays = gr.Gallery(label='PSGTR Result', type='numpy')

        with gr.Row():
            example_images = gr.Dataset(components=[input_image],
                                        samples=example_samples)

        gr.Markdown(FOOTER)

        # Keep uploaded images within the size limit.
        input_image.change(fn=update_input_image,
                           inputs=input_image,
                           outputs=input_image)
        run_button.click(fn=psg_model.infer,
                         inputs=[input_image, num_rel],
                         outputs=[gif, displays])
        example_images.click(fn=set_example_image,
                             inputs=example_images,
                             outputs=input_image)

    demo.launch(enable_queue=opts.enable_queue,
                server_port=opts.port,
                share=opts.share)
if __name__ == '__main__':
    # Start the demo only when run as a script, not when imported.
    main()