File size: 1,758 Bytes
b334e29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11b57f1
 
b334e29
 
 
a929430
 
b334e29
 
a929430
 
 
 
 
 
b334e29
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
'''
 * Copyright (c) 2023 Salesforce, Inc.
 * All rights reserved.
 * SPDX-License-Identifier: Apache License 2.0
 * For full license text, see LICENSE.txt file in the repo root or http://www.apache.org/licenses/
 * By Can Qin
 * Modified from ControlNet repo: https://github.com/lllyasviel/ControlNet
 * Copyright (c) 2023 Lvmin Zhang and Maneesh Agrawala
 * Modified from UniFormer repo: From https://github.com/Sense-X/UniFormer
 * Apache-2.0 license
'''



import os
from annotator.uniformer_base.mmseg.apis import init_segmentor, inference_segmentor, show_result_pyplot
from annotator.uniformer_base.mmseg.core.evaluation import get_palette
from annotator.util import annotator_ckpts_path

import pdb
# URL of the upernet_global_base.pth checkpoint. UniformerDetector hands this
# string to init_segmentor as the checkpoint argument.
# Disabled alternative location (Hugging Face), kept for reference.
# NOTE(review): this is a "/blob/" URL, which presumably points at the web page
# rather than the raw file -- confirm before re-enabling.
#checkpoint_file = "https://huggingface.co/Salesforce/UniControl/blob/main/annotator/ckpts/upernet_global_base.pth"
checkpoint_file =  "https://storage.googleapis.com/sfr-unicontrol-data-research/annotator/ckpts/upernet_global_base.pth"
class UniformerDetector:
    """Callable wrapper around a segmentor built with ``init_segmentor``.

    Construction creates the model from the ``upernet_global_base`` config
    and the module-level ``checkpoint_file`` and moves it to CUDA. Calling
    the instance segments an image and renders the result with the
    ``'ade'`` palette.
    """

    def __init__(self):
        """Create the segmentor and store it on ``self.model``.

        The config is looked up relative to the parent directory of
        ``annotator_ckpts_path``:
        ``<parent>/uniformer_base/exp/upernet_global_base/config.py``.
        """
        # NOTE: an earlier variant first downloaded the checkpoint into
        # annotator_ckpts_path; the checkpoint URL is now handed to
        # init_segmentor directly (presumably it resolves URLs itself --
        # verify against the vendored mmseg).
        annotator_root = os.path.dirname(annotator_ckpts_path)
        config_path = os.path.join(
            annotator_root,
            "uniformer_base",
            "exp",
            "upernet_global_base",
            "config.py",
        )
        segmentor = init_segmentor(config_path, checkpoint_file)
        # Requires a CUDA device; there is no CPU fallback here.
        self.model = segmentor.cuda()

    def __call__(self, img):
        """Segment ``img`` and return the rendered result.

        Args:
            img: Input forwarded unchanged to ``inference_segmentor`` and
                ``show_result_pyplot`` (presumably an image array or path --
                confirm against the vendored mmseg API).

        Returns:
            Whatever ``show_result_pyplot`` returns for the ``'ade'`` palette
            at ``opacity=1`` (presumably the colour-coded segmentation image).
        """
        segmentation = inference_segmentor(self.model, img)
        return show_result_pyplot(
            self.model,
            img,
            segmentation,
            get_palette('ade'),
            opacity=1,
        )