import itertools
import logging
import math
import os
import os.path as osp

import mmengine
import numpy as np
import requests
import streamlit as st
from mmengine.dataset import Compose, default_collate
from mmengine.fileio import list_from_file
from mmengine.logging.logger import MMFormatter
from mmengine.registry import init_default_scope
from PIL import Image

from mmcls import list_models as list_models_
from mmcls.apis.model import ModelHub, init_model


@st.cache()
def prepare_data():
    # Unzip the ImageNet validation images (`imagenet-val.zip` is expected in
    # the working directory); `-n` skips files that are already extracted.
    import subprocess
    subprocess.run(['unzip', '-n', 'imagenet-val.zip'])


@st.cache()
def load_demo_image():
    # Download the default demo image from the mmclassification repository.
    response = requests.get(
        'https://github.com/open-mmlab/mmclassification/blob/master/demo/bird.JPEG?raw=true',  # noqa
        stream=True).raw
    img = Image.open(response).convert('RGB')
    return img


@st.cache()
def list_models(*args, **kwargs):
    return sorted(list_models_(*args, **kwargs))


DATA_ROOT = '.'
ANNO_FILE = 'meta/val.txt'
LOG_FILE = 'demo.log'
CACHED_PATH = 'cache'
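# Expected data layout relative to DATA_ROOT (an assumption inferred from the
# paths used below, not stated in the original script):
#   meta/val.txt  - annotation list, typically one "<image path> <label>" line
#                   per sample; only the first token (the path) is used here
#   val/          - the extracted validation images
#   cache/        - precomputed "<model_name>_prototype.pt" gallery features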


def get_model(model_name, pretrained=True):
    metainfo = ModelHub.get(model_name)
    if pretrained:
        if metainfo.weights is None:
            raise ValueError(
                f"The model {model_name} doesn't have pretrained weights.")
        ckpt = metainfo.weights
    else:
        ckpt = None
    cfg = metainfo.config
    cfg.model.backbone.init_cfg = dict(
        type='Pretrained', checkpoint=ckpt, prefix='backbone')
    # Rebuild the model config as an ImageToImageRetriever that reuses the
    # classifier's backbone (and neck, if any) as the image encoder.
    new_model_cfg = dict()
    new_model_cfg['type'] = 'ImageToImageRetriever'
    if hasattr(cfg.model, 'neck') and cfg.model.neck is not None:
        new_model_cfg['image_encoder'] = [cfg.model.backbone, cfg.model.neck]
    else:
        new_model_cfg['image_encoder'] = cfg.model.backbone
    cfg.model = new_model_cfg
    # Point the retriever at the cached prototype (gallery feature) file.
    cached_path = f'{CACHED_PATH}/{model_name}_prototype.pt'  # noqa
    cfg.model.prototype = cached_path
    model = init_model(cfg, None, device='cpu')
    with st.spinner(f'Loading model {model_name} on the server... This is '
                    'slow the first time.'):
        model.init_weights()
    st.success('Model loaded!')
    with st.spinner('Preparing prototypes for all images... This is '
                    'slow the first time.'):
        model.prepare_prototype()
    return model


def get_pred(name, img):
    logger = mmengine.logging.MMLogger.get_current_instance()
    file_handler = logging.FileHandler(LOG_FILE, 'w')
    # The file handler records a full timestamp (year, month, day, hour,
    # minute and second) and disables colors so that no escape sequences end
    # up garbled in the log file.
    file_handler.setFormatter(
        MMFormatter(color=False, datefmt='%Y/%m/%d %H:%M:%S'))
    file_handler.setLevel('INFO')
    logger.handlers.append(file_handler)
    init_default_scope('mmcls')
    model = get_model(name)
    cfg = model.cfg
    # Build the data pipeline, adapting the first transform to the input type.
    test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline
    if isinstance(img, str):
        if test_pipeline_cfg[0]['type'] != 'LoadImageFromFile':
            test_pipeline_cfg.insert(0, dict(type='LoadImageFromFile'))
        data = dict(img_path=img)
    elif isinstance(img, np.ndarray):
        if test_pipeline_cfg[0]['type'] == 'LoadImageFromFile':
            test_pipeline_cfg.pop(0)
        data = dict(img=img)
    elif isinstance(img, Image.Image):
        if test_pipeline_cfg[0]['type'] == 'LoadImageFromFile':
            test_pipeline_cfg[0] = dict(type='ToNumpy', keys=['img'])
        data = dict(img=img)
    test_pipeline = Compose(test_pipeline_cfg)
    data = test_pipeline(data)
    data = default_collate([data])
    # Run inference once and reuse the prediction for both labels and scores.
    pred = model.val_step(data)[0].pred_label
    labels = pred.label
    scores = pred.score[labels]
    image_list = list_from_file(osp.join(DATA_ROOT, ANNO_FILE))
    data_root = osp.join(DATA_ROOT, 'val')
    result_list = [(osp.join(data_root, image_list[idx].rsplit()[0]), score)
                   for idx, score in zip(labels, scores)]
    return result_list
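

# Usage sketch (hypothetical model name; it assumes a matching cached
# prototype file exists under CACHED_PATH):
#     results = get_pred('resnet50_8xb32_in1k', load_demo_image())
#     top_path, top_score = results[0]
# Each entry pairs the path of a retrieved gallery image with its similarity
# score, in the order predicted by the retriever.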


def app():
    prepare_data()
    # Every cached prototype file corresponds to one selectable model.
    model_name = st.sidebar.selectbox(
        "Model:",
        [m.split('_prototype.pt')[0] for m in os.listdir(CACHED_PATH)])
    st.markdown(
        "<h1>Image To Image Retrieval</h1>",
        unsafe_allow_html=True,
    )
    st.write(
        'This is a demo of image-to-image retrieval over around 3k images '
        'from a tiny ImageNet validation set, built with the '
        'mmclassification APIs. You can try more features in '
        '[mmclassification]'
        '(https://github.com/open-mmlab/mmclassification).')
    file = st.file_uploader(
        'Please upload your own image or use the provided one:')
container1 = st.container()
if file:
raw_img = Image.open(file).convert('RGB')
else:
raw_img = load_demo_image()
container1.header('Image')
w, h = raw_img.size
scaling_factor = 360 / w
resized_image = raw_img.resize(
(int(w * scaling_factor), int(h * scaling_factor)))
container1.image(resized_image, use_column_width='auto')
button = container1.button('Search')
st.header('Results')
    topk = st.sidebar.number_input('Top-k (1-50)', min_value=1, max_value=50)
    # Search when the button is clicked or when a top-k value is selected.
    if button or topk > 1:
        result_list = get_pred(model_name, raw_img)
        # Automatically choose the number of images per row, at most 5.
        col = min(int(math.sqrt(topk)), 5)
        row = math.ceil(topk / col)
        grid = []
        for i in range(row):
            with st.container():
                grid.append(st.columns(col))
        grid = list(itertools.chain.from_iterable(grid))[:topk]
        for cell, (image_path, score) in zip(grid, result_list[:topk]):
            image = Image.open(image_path).convert('RGB')
            # Resize the retrieved image (not the query) to a 360px width
            # before displaying it in its grid cell.
            w, h = image.size
            scaling_factor = 360 / w
            resized_image = image.resize(
                (int(w * scaling_factor), int(h * scaling_factor)))
            cell.caption('Score: {:.4f}'.format(float(score)))
            cell.image(resized_image)


if __name__ == '__main__':
    app()