import itertools import math import os.path as osp import numpy as np import requests import streamlit as st from mmengine.dataset import Compose, default_collate from mmengine.fileio import list_from_file from mmengine.registry import init_default_scope from PIL import Image import mmengine import logging from mmengine.logging.logger import MMFormatter from mmcls import list_models as list_models_ from mmcls.apis.model import ModelHub, init_model import os @st.cache() def prepare_data(): import subprocess subprocess.run(['unzip', '-n', 'imagenet-val.zip']) @st.cache() def load_demo_image(): response = requests.get( 'https://github.com/open-mmlab/mmclassification/blob/master/demo/bird.JPEG?raw=true', # noqa stream=True).raw img = Image.open(response).convert('RGB') return img @st.cache() def list_models(*args, **kwargs): return sorted(list_models_(*args, **kwargs)) DATA_ROOT = '.' ANNO_FILE = 'meta/val.txt' LOG_FILE = 'demo.log' CACHED_PATH = 'cache' def get_model(model_name, pretrained=True): metainfo = ModelHub.get(model_name) if pretrained: if metainfo.weights is None: raise ValueError( f"The model {model_name} doesn't have pretrained weights.") ckpt = metainfo.weights else: ckpt = None cfg = metainfo.config cfg.model.backbone.init_cfg = dict( type='Pretrained', checkpoint=ckpt, prefix='backbone') new_model_cfg = dict() new_model_cfg['type'] = 'ImageToImageRetriever' if hasattr(cfg.model, 'neck') and cfg.model.neck is not None: new_model_cfg['image_encoder'] = [cfg.model.backbone, cfg.model.neck] else: new_model_cfg['image_encoder'] = cfg.model.backbone cfg.model = new_model_cfg # prepare prototype cached_path = f'{CACHED_PATH}/{model_name}_prototype.pt' # noqa cfg.model.prototype = cached_path model = init_model(metainfo.config, None, device='cpu') with st.spinner(f'Loading model {model_name} on the server...This is ' 'slow at the first time.'): model.init_weights() st.success('Model loaded!') with st.spinner('Preparing prototype for all image...This is ' 'slow at the first time.'): model.prepare_prototype() return model def get_pred(name, img): logger = mmengine.logging.MMLogger.get_current_instance() file_handler = logging.FileHandler(LOG_FILE, 'w') # `StreamHandler` record year, month, day hour, minute, # and second timestamp. file_handler will only record logs # without color to avoid garbled code saved in files. file_handler.setFormatter( MMFormatter(color=False, datefmt='%Y/%m/%d %H:%M:%S')) file_handler.setLevel('INFO') logger.handlers.append(file_handler) init_default_scope('mmcls') model = get_model(name) cfg = model.cfg # build the data pipeline test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline if isinstance(img, str): if test_pipeline_cfg[0]['type'] != 'LoadImageFromFile': test_pipeline_cfg.insert(0, dict(type='LoadImageFromFile')) data = dict(img_path=img) elif isinstance(img, np.ndarray): if test_pipeline_cfg[0]['type'] == 'LoadImageFromFile': test_pipeline_cfg.pop(0) data = dict(img=img) elif isinstance(img, Image.Image): if test_pipeline_cfg[0]['type'] == 'LoadImageFromFile': test_pipeline_cfg[0] = dict(type='ToNumpy', keys=['img']) data = dict(img=img) test_pipeline = Compose(test_pipeline_cfg) data = test_pipeline(data) data = default_collate([data]) labels = model.val_step(data)[0].pred_label.label scores = model.val_step(data)[0].pred_label.score[labels] image_list = list_from_file(osp.join(DATA_ROOT, ANNO_FILE)) data_root = osp.join(DATA_ROOT, 'val') result_list = [(osp.join(data_root, image_list[idx].rsplit()[0]), score) for idx, score in zip(labels, scores)] return result_list def app(): prepare_data() model_name = st.sidebar.selectbox( "Model:", [m.split('_prototype.pt')[0] for m in os.listdir(CACHED_PATH)]) st.markdown( "