"""Streamlit demo: image-to-image retrieval over a small ImageNet val subset.

Uses mmcls/mmengine APIs to wrap a ModelHub classifier backbone in an
``ImageToImageRetriever`` and search a local gallery for similar images.
"""
import itertools
import logging
import math
import os
import os.path as osp
import subprocess

import mmengine
import numpy as np
import requests
import streamlit as st
from mmengine.dataset import Compose, default_collate
from mmengine.fileio import list_from_file
from mmengine.logging.logger import MMFormatter
from mmengine.registry import init_default_scope
from PIL import Image

from mmcls import list_models as list_models_
from mmcls.apis.model import ModelHub, init_model

DATA_ROOT = '.'
ANNO_FILE = 'meta/val.txt'
LOG_FILE = 'demo.log'
CACHED_PATH = 'cache'

# Suffix of the cached prototype files in CACHED_PATH; the model name is the
# part of the file name before it.
_PROTOTYPE_SUFFIX = '_prototype.pt'
_DEMO_IMAGE_URL = (
    'https://github.com/open-mmlab/mmclassification/blob/master/demo/'
    'bird.JPEG?raw=true')
# Width (pixels) the query image is resized to before display.
_DISPLAY_WIDTH = 360
# Seconds to wait for the demo image download before giving up.
_DOWNLOAD_TIMEOUT = 30


@st.cache()
def prepare_data():
    """Extract ``imagenet-val.zip`` into the working directory.

    ``unzip -n`` never overwrites existing files, so repeated runs are cheap;
    ``st.cache`` additionally skips the call after the first run.
    """
    subprocess.run(['unzip', '-n', 'imagenet-val.zip'])


@st.cache()
def load_demo_image():
    """Download the default demo image and return it as an RGB PIL image.

    Raises:
        requests.HTTPError: if the server returns an error status.
        requests.Timeout: if the download does not respond in time.
    """
    response = requests.get(
        _DEMO_IMAGE_URL, stream=True, timeout=_DOWNLOAD_TIMEOUT)
    # Fail with a clear HTTP error instead of letting PIL choke on an
    # HTML error page.
    response.raise_for_status()
    return Image.open(response.raw).convert('RGB')


@st.cache()
def list_models(*args, **kwargs):
    """Return the mmcls model names (forwarding all arguments), sorted."""
    return sorted(list_models_(*args, **kwargs))


def get_model(model_name, pretrained=True):
    """Build an ``ImageToImageRetriever`` from a ModelHub model's config.

    The classifier's backbone (plus neck, if it has one) becomes the image
    encoder, and ``cache/<model_name>_prototype.pt`` is used as the
    prototype path. Presumably that file holds (or receives) the gallery
    features; confirm against ``ImageToImageRetriever.prepare_prototype``.

    Args:
        model_name: Name registered in ``ModelHub``.
        pretrained: Whether to initialize the backbone from the model's
            released weights.

    Returns:
        The initialized retriever, on CPU, with its prototype prepared.

    Raises:
        ValueError: if ``pretrained`` is True but the model has no weights.
    """
    metainfo = ModelHub.get(model_name)
    if pretrained:
        if metainfo.weights is None:
            raise ValueError(
                f"The model {model_name} doesn't have pretrained weights.")
        ckpt = metainfo.weights
    else:
        ckpt = None

    cfg = metainfo.config
    # Only install a `Pretrained` initializer when there is a checkpoint;
    # a `Pretrained` init with `checkpoint=None` cannot be loaded.
    if ckpt is not None:
        cfg.model.backbone.init_cfg = dict(
            type='Pretrained', checkpoint=ckpt, prefix='backbone')

    if getattr(cfg.model, 'neck', None) is not None:
        image_encoder = [cfg.model.backbone, cfg.model.neck]
    else:
        image_encoder = cfg.model.backbone
    cfg.model = dict(
        type='ImageToImageRetriever',
        image_encoder=image_encoder,
        prototype=f'{CACHED_PATH}/{model_name}{_PROTOTYPE_SUFFIX}',
    )

    model = init_model(cfg, None, device='cpu')
    with st.spinner(f'Loading model {model_name} on the server...This is '
                    'slow at the first time.'):
        model.init_weights()
        st.success('Model loaded!')
    with st.spinner('Preparing prototype for all image...This is '
                    'slow at the first time.'):
        model.prepare_prototype()

    return model


def _prepare_inputs(test_pipeline_cfg, img):
    """Adapt ``test_pipeline_cfg`` (in place) to ``img``'s type.

    Returns the raw input dict to feed to the adapted pipeline.

    Raises:
        TypeError: if ``img`` is not a path, ndarray or PIL image.
    """
    if isinstance(img, str):
        # A path: the pipeline must start by loading the file.
        if test_pipeline_cfg[0]['type'] != 'LoadImageFromFile':
            test_pipeline_cfg.insert(0, dict(type='LoadImageFromFile'))
        return dict(img_path=img)
    if isinstance(img, np.ndarray):
        # Already decoded: drop the file-loading step.
        if test_pipeline_cfg[0]['type'] == 'LoadImageFromFile':
            test_pipeline_cfg.pop(0)
        return dict(img=img)
    if isinstance(img, Image.Image):
        # A PIL image: replace file loading with a conversion to numpy.
        if test_pipeline_cfg[0]['type'] == 'LoadImageFromFile':
            test_pipeline_cfg[0] = dict(type='ToNumpy', keys=['img'])
        return dict(img=img)
    raise TypeError(
        f'Unsupported image type {type(img).__name__}; expected a file '
        'path, a numpy array or a PIL image.')


def _retrieve(name, img):
    """Run retrieval for ``img`` with model ``name``.

    Returns a list of ``(image_path, score)`` pairs.
    """
    init_default_scope('mmcls')
    model = get_model(name)
    cfg = model.cfg

    # Build the data pipeline.
    test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline
    data = _prepare_inputs(test_pipeline_cfg, img)
    test_pipeline = Compose(test_pipeline_cfg)
    data = default_collate([test_pipeline(data)])

    # Run inference once and read both labels and scores from it.
    pred_label = model.val_step(data)[0].pred_label
    labels = pred_label.label
    scores = pred_label.score[labels]

    # Each label is an index into the annotation file; its first
    # whitespace-separated field is the image path relative to `val/`.
    image_list = list_from_file(osp.join(DATA_ROOT, ANNO_FILE))
    data_root = osp.join(DATA_ROOT, 'val')
    return [(osp.join(data_root, image_list[idx].rsplit()[0]), score)
            for idx, score in zip(labels, scores)]


def get_pred(name, img):
    """Retrieve gallery images similar to ``img`` using model ``name``.

    Logs emitted during the call are also written (uncolored) to
    ``LOG_FILE``. The file is truncated at the start of each call. The file
    handler is removed and closed afterwards, so repeated Streamlit reruns
    do not accumulate handlers or open file descriptors.

    Args:
        name: ModelHub model name.
        img: Query image given as a file path, a numpy array or a PIL image.

    Returns:
        List of ``(image_path, score)`` pairs for the retrieved images.
    """
    logger = mmengine.logging.MMLogger.get_current_instance()
    file_handler = logging.FileHandler(LOG_FILE, 'w')
    # The file handler records a full date/time stamp and no color codes,
    # which would show up as garbage in a plain-text file.
    file_handler.setFormatter(
        MMFormatter(color=False, datefmt='%Y/%m/%d %H:%M:%S'))
    file_handler.setLevel('INFO')
    logger.addHandler(file_handler)
    try:
        return _retrieve(name, img)
    finally:
        logger.removeHandler(file_handler)
        file_handler.close()


def _resize_to_width(img, width=_DISPLAY_WIDTH):
    """Return ``img`` scaled (keeping aspect ratio) to ``width`` pixels."""
    w, h = img.size
    scaling_factor = width / w
    return img.resize((int(w * scaling_factor), int(h * scaling_factor)))


def app():
    """Render the Streamlit retrieval page."""
    prepare_data()

    # Only files following the `<model>_prototype.pt` naming are models;
    # sort them so the option order does not depend on `os.listdir`.
    model_names = sorted(
        m[:-len(_PROTOTYPE_SUFFIX)] for m in os.listdir(CACHED_PATH)
        if m.endswith(_PROTOTYPE_SUFFIX))
    model_name = st.sidebar.selectbox("Model:", model_names)

    st.markdown(
        "<h1 style='text-align: center;'>Image To Image Retrieval</h1>",
        unsafe_allow_html=True,
    )
    st.write(
        'This is a demo for image to image retrieval in around 3k images from '
        'ImageNet tiny val set using mmclassification apis. You can try more '
        'features on [mmclassification]'
        '(https://github.com/open-mmlab/mmclassification).')

    file = st.file_uploader(
        'Please upload your own image or use the provided:')

    container1 = st.container()
    if file:
        raw_img = Image.open(file).convert('RGB')
    else:
        raw_img = load_demo_image()

    container1.header('Image')
    container1.image(_resize_to_width(raw_img), use_column_width='auto')
    button = container1.button('Search')

    st.header('Results')

    topk = st.sidebar.number_input('Topk(1-50)', min_value=1, max_value=50)

    # Search when the button is pressed or when top-k has been changed
    # from its default of 1.
    if button or topk > 1:
        result_list = get_pred(model_name, raw_img)
        # Auto-adjust the number of images per row, but 5 at most.
        col = min(int(math.sqrt(topk)), 5)
        row = math.ceil(topk / col)

        grid = []
        for _ in range(row):
            with st.container():
                grid.append(st.columns(col))

        grid = list(itertools.chain.from_iterable(grid))[:topk]

        for cell, (image_path, score) in zip(grid, result_list[:topk]):
            image = Image.open(image_path).convert('RGB')
            cell.caption('Score: {:.4f}'.format(float(score)))
            cell.image(image)


if __name__ == '__main__':
    app()