File size: 6,128 Bytes
7e02c9f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2a534ec
7e02c9f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2a534ec
7e02c9f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2a534ec
7e02c9f
 
21bd41f
7e02c9f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2a534ec
 
 
7e02c9f
 
21bd41f
7e02c9f
 
21bd41f
 
 
 
 
7e02c9f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
import itertools
import math
import os.path as osp

import numpy as np
import requests
import streamlit as st
from mmengine.dataset import Compose, default_collate
from mmengine.fileio import list_from_file
from mmengine.registry import init_default_scope
from PIL import Image
import mmengine
import logging
from mmengine.logging.logger import MMFormatter
from mmcls import list_models as list_models_
from mmcls.apis.model import ModelHub, init_model
import os


@st.cache()
def prepare_data():
    """Extract ``imagenet-val.zip`` into the current working directory.

    ``unzip -n`` skips files that already exist, so repeated runs do not
    overwrite anything. The exit status of ``unzip`` is not checked.
    Streamlit caches the call so extraction is attempted once per session.
    """
    from subprocess import run

    unzip_cmd = ['unzip', '-n', 'imagenet-val.zip']
    run(unzip_cmd)


@st.cache()
def load_demo_image():
    """Download the mmclassification demo bird image and return it as RGB.

    Returns:
        PIL.Image.Image: The demo image converted to RGB mode.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.Timeout: If the server does not respond within the timeout.
    """
    import io

    response = requests.get(
        'https://github.com/open-mmlab/mmclassification/blob/master/demo/bird.JPEG?raw=true',  # noqa
        timeout=30)
    # Fail loudly on 4xx/5xx instead of handing an HTML error page to PIL.
    response.raise_for_status()
    # ``.content`` has transfer encodings (e.g. gzip) already decoded,
    # unlike the undecoded ``.raw`` stream.
    img = Image.open(io.BytesIO(response.content)).convert('RGB')
    return img


@st.cache()
def list_models(*args, **kwargs):
    """Return the names from ``mmcls.list_models`` in sorted order.

    All positional and keyword arguments are forwarded unchanged. Streamlit
    caches the result.
    """
    names = list(list_models_(*args, **kwargs))
    names.sort()
    return names


# Root directory of the extracted evaluation data; gallery images are
# resolved under ``<DATA_ROOT>/val``.
DATA_ROOT: str = '.'
# Annotation list relative to DATA_ROOT. Predicted labels index its lines,
# and the first whitespace-separated field of a line is the image path.
ANNO_FILE: str = 'meta/val.txt'
# File that ``get_pred`` mirrors its logger output into.
LOG_FILE: str = 'demo.log'
# Directory holding one ``<model_name>_prototype.pt`` file per model.
CACHED_PATH: str = 'cache'


def get_model(model_name, pretrained=True):
    """Build an ``ImageToImageRetriever`` from a model registered in mmcls.

    The classifier registered under ``model_name`` is reduced to its feature
    extractor: the backbone, plus the neck when the config defines one. That
    extractor becomes the retriever's image encoder. The retriever's
    prototype file is ``{CACHED_PATH}/{model_name}_prototype.pt``.

    Args:
        model_name (str): Name of the model in ``ModelHub``.
        pretrained (bool): If True, initialize the backbone from the model's
            released weights. If False, the backbone keeps the init_cfg from
            its config. Defaults to True.

    Returns:
        The retriever on CPU, with weights initialized and
        ``prepare_prototype()`` already called.

    Raises:
        ValueError: If ``pretrained`` is True but the model has no
            pretrained weights.
    """
    metainfo = ModelHub.get(model_name)
    cfg = metainfo.config

    if pretrained:
        if metainfo.weights is None:
            raise ValueError(
                f"The model {model_name} doesn't have pretrained weights.")
        # Only set a Pretrained init when a checkpoint exists. A Pretrained
        # init with ``checkpoint=None`` would fail in ``init_weights``.
        cfg.model.backbone.init_cfg = dict(
            type='Pretrained', checkpoint=metainfo.weights, prefix='backbone')

    neck = cfg.model.get('neck')
    if neck is not None:
        image_encoder = [cfg.model.backbone, neck]
    else:
        image_encoder = cfg.model.backbone

    cfg.model = dict(
        type='ImageToImageRetriever',
        image_encoder=image_encoder,
        # Prototype file for this model, stored under CACHED_PATH.
        prototype=f'{CACHED_PATH}/{model_name}_prototype.pt',  # noqa
    )

    model = init_model(cfg, None, device='cpu')
    with st.spinner(f'Loading model {model_name} on the server...This is '
                    'slow at the first time.'):
        model.init_weights()
    st.success('Model loaded!')

    with st.spinner('Preparing prototype for all image...This is '
                    'slow at the first time.'):
        model.prepare_prototype()

    return model


def _prepare_pipeline(pipeline_cfg, img):
    """Return ``(pipeline, data)`` adapted to the type of ``img``.

    The returned pipeline is a new list, so the caller's config is never
    mutated.

    Args:
        pipeline_cfg (list[dict]): Test pipeline config of the model.
        img (str | np.ndarray | PIL.Image.Image): The query image.

    Returns:
        tuple[list, dict]: The pipeline config to compose and the input
        dict to feed it.

    Raises:
        TypeError: If ``img`` is not a path, ndarray or PIL image.
    """
    pipeline_cfg = list(pipeline_cfg)
    loads_from_file = (
        len(pipeline_cfg) > 0
        and pipeline_cfg[0]['type'] == 'LoadImageFromFile')

    if isinstance(img, str):
        # A path needs the pipeline to start by reading the file.
        if not loads_from_file:
            pipeline_cfg.insert(0, dict(type='LoadImageFromFile'))
        return pipeline_cfg, dict(img_path=img)
    if isinstance(img, np.ndarray):
        # Already decoded: skip the file-loading step.
        if loads_from_file:
            pipeline_cfg.pop(0)
        return pipeline_cfg, dict(img=img)
    if isinstance(img, Image.Image):
        # PIL image: replace the loader with a conversion to numpy.
        if loads_from_file:
            pipeline_cfg[0] = dict(type='ToNumpy', keys=['img'])
        return pipeline_cfg, dict(img=img)
    raise TypeError(
        f'Unsupported image type {type(img).__name__!r}; expected a file '
        'path, a numpy array or a PIL image.')


def get_pred(name, img):
    """Retrieve the gallery images most similar to ``img``.

    While the model is loaded and run, the current MMLogger output is also
    written (without colour codes) to LOG_FILE. The file is truncated on
    every call.

    Args:
        name (str): Model name, forwarded to ``get_model``.
        img (str | np.ndarray | PIL.Image.Image): Query image.

    Returns:
        list[tuple]: ``(image_path, score)`` pairs, one per predicted label.
        Each label is a line index into ANNO_FILE, and each path lies under
        ``DATA_ROOT/val``.

    Raises:
        TypeError: If ``img`` is of an unsupported type.
    """
    logger = mmengine.logging.MMLogger.get_current_instance()
    # `StreamHandler` record year, month, day hour, minute,
    # and second timestamp. file_handler will only record logs
    # without color to avoid garbled code saved in files.
    file_handler = logging.FileHandler(LOG_FILE, 'w')
    file_handler.setFormatter(
        MMFormatter(color=False, datefmt='%Y/%m/%d %H:%M:%S'))
    file_handler.setLevel('INFO')
    logger.addHandler(file_handler)

    try:
        init_default_scope('mmcls')

        model = get_model(name)

        # Build the data pipeline from a copy of the model's test pipeline.
        pipeline_cfg, data = _prepare_pipeline(
            model.cfg.test_dataloader.dataset.pipeline, img)
        data = Compose(pipeline_cfg)(data)
        data = default_collate([data])

        # Run inference a single time and read labels and scores from it.
        pred_label = model.val_step(data)[0].pred_label
        labels = pred_label.label
        scores = pred_label.score[labels]

        image_list = list_from_file(osp.join(DATA_ROOT, ANNO_FILE))
        data_root = osp.join(DATA_ROOT, 'val')
        return [(osp.join(data_root, image_list[idx].rsplit()[0]), score)
                for idx, score in zip(labels, scores)]
    finally:
        # Detach and close the handler so repeated searches neither pile up
        # duplicate handlers nor leak open file descriptors.
        logger.removeHandler(file_handler)
        file_handler.close()


def _list_prototype_models(cache_dir):
    """Return sorted model names that have ``<name>_prototype.pt`` in *cache_dir*.

    Files without that suffix are ignored. A missing directory yields ``[]``.
    """
    suffix = '_prototype.pt'
    if not osp.isdir(cache_dir):
        return []
    return sorted(fname[:-len(suffix)] for fname in os.listdir(cache_dir)
                  if fname.endswith(suffix))


def _resize_to_width(img, width=360):
    """Return *img* scaled proportionally so its width is *width* pixels."""
    w, h = img.size
    scaling_factor = width / w
    return img.resize((int(w * scaling_factor), int(h * scaling_factor)))


def app():
    """Render the Streamlit image-to-image retrieval demo page.

    The sidebar selects a model that has a cached prototype and the number
    of results. The main page shows the query image (uploaded, or the demo
    image) and a grid of the top-k retrieved gallery images with scores.
    """
    prepare_data()

    model_name = st.sidebar.selectbox(
        "Model:", _list_prototype_models(CACHED_PATH))

    st.markdown(
        "<h1>Image To Image Retrieval</h1>",
        unsafe_allow_html=True,
    )
    st.write(
        'This is a demo for image to image retrieval in around 3k images from '
        'ImageNet tiny val set using mmclassification apis. You can try more '
        'features on [mmclassification]'
        '(https://github.com/open-mmlab/mmclassification).')

    file = st.file_uploader(
        'Please upload your own image or use the provided:')

    container1 = st.container()
    if file:
        raw_img = Image.open(file).convert('RGB')
    else:
        raw_img = load_demo_image()

    container1.header('Image')
    container1.image(_resize_to_width(raw_img), use_column_width='auto')
    button = container1.button('Search')

    st.header('Results')

    topk = st.sidebar.number_input('Topk(1-50)', min_value=1, max_value=50)

    # search on both selection of topk and button
    if not (button or topk > 1):
        return
    if model_name is None:
        # selectbox yields None when no prototype files were found.
        st.error(f'No cached prototypes (*_prototype.pt) found in '
                 f'"{CACHED_PATH}".')
        return

    result_list = get_pred(model_name, raw_img)
    # auto adjust number of images in a row but 5 at most.
    col = min(int(math.sqrt(topk)), 5)
    row = math.ceil(topk / col)

    grid = []
    for _ in range(row):
        with st.container():
            grid.append(st.columns(col))

    grid = list(itertools.chain.from_iterable(grid))[:topk]

    for cell, (image_path, score) in zip(grid, result_list[:topk]):
        image = Image.open(image_path).convert('RGB')
        cell.caption('Score: {:.4f}'.format(float(score)))
        cell.image(image)


# Entry point: render the demo when this file runs as the main script
# (``streamlit run`` executes the script as ``__main__``).
if __name__ == '__main__':
    app()