mmcls-retriever / app.py
huyingfan
add more models
2a534ec
raw
history blame contribute delete
No virus
6.13 kB
import itertools
import math
import os.path as osp
import numpy as np
import requests
import streamlit as st
from mmengine.dataset import Compose, default_collate
from mmengine.fileio import list_from_file
from mmengine.registry import init_default_scope
from PIL import Image
import mmengine
import logging
from mmengine.logging.logger import MMFormatter
from mmcls import list_models as list_models_
from mmcls.apis.model import ModelHub, init_model
import os
@st.cache()
def prepare_data():
    """Unpack the gallery images shipped as ``imagenet-val.zip``.

    ``unzip -n`` never overwrites files that already exist, so running this
    again after the archive was extracted leaves those files untouched.
    The exit status of ``unzip`` is not checked.
    """
    from subprocess import run

    archive = 'imagenet-val.zip'
    run(['unzip', '-n', archive])
@st.cache()
def load_demo_image():
    """Download the demo bird image from the mmclassification repository.

    Returns:
        PIL.Image.Image: The demo image converted to RGB.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.Timeout: If the download stalls longer than the timeout.
    """
    from io import BytesIO

    url = ('https://github.com/open-mmlab/mmclassification/blob/master/'
           'demo/bird.JPEG?raw=true')
    # A timeout keeps a stalled download from hanging the app forever, and
    # raise_for_status() reports HTTP failures clearly instead of letting
    # PIL fail later with "cannot identify image file".
    response = requests.get(url, timeout=30)
    response.raise_for_status()
    # `.content` has any Content-Encoding undone by requests (the raw socket
    # stream does not) and reading it fully releases the connection.
    return Image.open(BytesIO(response.content)).convert('RGB')
@st.cache()
def list_models(*args, **kwargs):
    """Return the names from ``mmcls.list_models`` in ascending order.

    All positional and keyword arguments are forwarded unchanged.
    """
    names = list(list_models_(*args, **kwargs))
    names.sort()
    return names
# Root directory holding the annotation file and the extracted `val/` images.
DATA_ROOT: str = '.'
# Gallery annotation list, relative to DATA_ROOT; the first whitespace-
# separated token of each line is an image path under `DATA_ROOT/val`.
ANNO_FILE: str = 'meta/val.txt'
# File that get_pred mirrors mmengine log records into.
LOG_FILE: str = 'demo.log'
# Directory of `<model_name>_prototype.pt` files; the sidebar model list
# in app() is built from its contents.
CACHED_PATH: str = 'cache'
def get_model(model_name, pretrained=True):
    """Build an ``ImageToImageRetriever`` from a registered mmcls model.

    The classifier config from ``ModelHub`` is rewritten so that its backbone
    (plus its neck, when the config defines one) becomes the retriever's
    image encoder. The model is built on CPU, its weights are initialized
    and its prototype is prepared before it is returned.

    Args:
        model_name (str): Name of a model registered in ``ModelHub``.
        pretrained (bool): Whether to initialize the backbone from the
            model's published checkpoint. Defaults to True.

    Returns:
        The initialized retriever model.

    Raises:
        ValueError: If ``pretrained`` is True but the model has no weights.
    """
    metainfo = ModelHub.get(model_name)
    cfg = metainfo.config

    if pretrained:
        if metainfo.weights is None:
            raise ValueError(
                f"The model {model_name} doesn't have pretrained weights.")
        # Load only the `backbone.*` part of the classifier checkpoint.
        cfg.model.backbone.init_cfg = dict(
            type='Pretrained', checkpoint=metainfo.weights, prefix='backbone')
    # Without pretrained weights there is no checkpoint to point at, so the
    # backbone keeps its own init_cfg rather than a 'Pretrained' init whose
    # checkpoint would be None.

    # Reuse the classifier's feature extractor as the retriever's encoder.
    backbone = cfg.model.backbone
    neck = cfg.model.get('neck')
    image_encoder = backbone if neck is None else [backbone, neck]

    cfg.model = dict(
        type='ImageToImageRetriever',
        image_encoder=image_encoder,
        # Per-model prototype file; app() builds its model list from the
        # files in CACHED_PATH.
        prototype=f'{CACHED_PATH}/{model_name}_prototype.pt',
    )

    model = init_model(cfg, None, device='cpu')
    with st.spinner(f'Loading model {model_name} on the server...This is '
                    'slow at the first time.'):
        model.init_weights()
    st.success('Model loaded!')
    with st.spinner('Preparing prototype for all image...This is '
                    'slow at the first time.'):
        model.prepare_prototype()
    return model
def get_pred(name, img):
    """Retrieve gallery images similar to ``img`` using model ``name``.

    While this runs (including model loading), mmengine log records are also
    written without colour codes to ``LOG_FILE``, which is truncated on each
    call.

    Args:
        name (str): Model name, forwarded to :func:`get_model`.
        img (str | np.ndarray | PIL.Image.Image): The query image, given as
            a file path, an already-decoded array, or a PIL image.

    Returns:
        list[tuple[str, Any]]: ``(image_path, score)`` for each predicted
        gallery item, in the order the model returns them. Paths are joined
        under ``DATA_ROOT/val``.

    Raises:
        TypeError: If ``img`` is not one of the supported types.
    """
    logger = mmengine.logging.MMLogger.get_current_instance()
    file_handler = logging.FileHandler(LOG_FILE, 'w')
    # Plain (uncoloured) formatting so the log file contains no escape codes.
    file_handler.setFormatter(
        MMFormatter(color=False, datefmt='%Y/%m/%d %H:%M:%S'))
    file_handler.setLevel('INFO')
    logger.addHandler(file_handler)
    try:
        init_default_scope('mmcls')
        model = get_model(name)
        cfg = model.cfg

        # Copy the pipeline list so adapting its first step to the input
        # type does not edit the model's config in place.
        test_pipeline_cfg = list(cfg.test_dataloader.dataset.pipeline)
        if isinstance(img, str):
            if test_pipeline_cfg[0]['type'] != 'LoadImageFromFile':
                test_pipeline_cfg.insert(0, dict(type='LoadImageFromFile'))
            data = dict(img_path=img)
        elif isinstance(img, np.ndarray):
            # Already decoded: drop the file-loading step.
            if test_pipeline_cfg[0]['type'] == 'LoadImageFromFile':
                test_pipeline_cfg.pop(0)
            data = dict(img=img)
        elif isinstance(img, Image.Image):
            # NOTE(review): ToNumpy keeps PIL's RGB channel order, while
            # LoadImageFromFile (used for the gallery) presumably yields BGR
            # (mmcv default) -- confirm the data preprocessor treats both
            # paths consistently.
            if test_pipeline_cfg[0]['type'] == 'LoadImageFromFile':
                test_pipeline_cfg[0] = dict(type='ToNumpy', keys=['img'])
            data = dict(img=img)
        else:
            raise TypeError(
                'img must be a file path, a numpy array or a PIL image, '
                f'got {type(img).__name__}')

        test_pipeline = Compose(test_pipeline_cfg)
        data = default_collate([test_pipeline(data)])
        # Run inference once; labels and scores come from the same output.
        pred_label = model.val_step(data)[0].pred_label
        labels = pred_label.label
        scores = pred_label.score[labels]
    finally:
        # Detach and close the handler so repeated queries do not stack
        # handlers on the shared logger (duplicated lines, leaked fds).
        logger.removeHandler(file_handler)
        file_handler.close()

    image_list = list_from_file(osp.join(DATA_ROOT, ANNO_FILE))
    data_root = osp.join(DATA_ROOT, 'val')
    # The first whitespace-separated token of an annotation line is the
    # image path relative to data_root.
    return [(osp.join(data_root, image_list[idx].split()[0]), score)
            for idx, score in zip(labels, scores)]
def _resize_to_width(img, width=360):
    """Return PIL image ``img`` scaled to ``width`` pixels wide.

    The aspect ratio is kept; both output dimensions are truncated to ints.
    """
    w, h = img.size
    scale = width / w
    return img.resize((int(w * scale), int(h * scale)))


def _list_cached_models():
    """Return, sorted, the model names that have a prototype in CACHED_PATH.

    Only files named ``<model_name>_prototype.pt`` are considered, so stray
    files in the directory do not show up as selectable models.
    """
    suffix = '_prototype.pt'
    return sorted(fname[:-len(suffix)] for fname in os.listdir(CACHED_PATH)
                  if fname.endswith(suffix))


def _show_results(result_list, topk):
    """Render the first ``topk`` ``(image_path, score)`` results as a grid."""
    # auto adjust number of images in a row but 5 at most.
    col = min(int(math.sqrt(topk)), 5)
    row = math.ceil(topk / col)
    rows = []
    for _ in range(row):
        with st.container():
            rows.append(st.columns(col))
    cells = list(itertools.chain.from_iterable(rows))[:topk]
    for cell, (image_path, score) in zip(cells, result_list[:topk]):
        # Show each hit at the same width as the query image; `with` closes
        # the underlying file handle.
        with Image.open(image_path) as opened:
            image = _resize_to_width(opened.convert('RGB'))
        cell.caption('Score: {:.4f}'.format(float(score)))
        cell.image(image)


def app():
    """Render the Streamlit image-to-image retrieval demo page.

    The query image is the user's upload, or the downloaded demo image when
    nothing is uploaded. Retrieval runs when the Search button is pressed or
    when Topk is set above 1, and the top results are shown in a grid.
    """
    prepare_data()
    model_name = st.sidebar.selectbox("Model:", _list_cached_models())
    st.markdown(
        "<h1>Image To Image Retrieval</h1>",
        unsafe_allow_html=True,
    )
    st.write(
        'This is a demo for image to image retrieval in around 3k images from '
        'ImageNet tiny val set using mmclassification apis. You can try more '
        'features on [mmclassification]'
        '(https://github.com/open-mmlab/mmclassification).')
    uploaded = st.file_uploader(
        'Please upload your own image or use the provided:')
    container1 = st.container()
    if uploaded:
        raw_img = Image.open(uploaded).convert('RGB')
    else:
        raw_img = load_demo_image()
    container1.header('Image')
    container1.image(_resize_to_width(raw_img), use_column_width='auto')
    button = container1.button('Search')

    st.header('Results')
    topk = st.sidebar.number_input('Topk(1-50)', min_value=1, max_value=50)
    # search on both selection of topk and button
    if button or topk > 1:
        result_list = get_pred(model_name, raw_img)
        _show_results(result_list, topk)


if __name__ == '__main__':
    app()