# Spaces:
# Running
# Running
import os
from torch.multiprocessing import Process, Manager, set_start_method, Pool
import functools
import argparse
import yaml
import numpy as np
import sys
import cv2
from tqdm import trange

# Every worker (stage processes and pool workers) is started with "spawn".
set_start_method("spawn", force=True)

# Make the repository root (parent of this file's directory) importable so the
# project packages below resolve regardless of the current working directory.
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, ROOT_DIR)

from components import load_component
from utils import evaluation_utils, metrics
def _build_arg_parser():
    """Return the command-line parser for this evaluation script.

    Options:
        --config_path: YAML config with ``reader``/``matcher``/``evaluator``
            sections.
        --num_process_match: pool size (and batch size) of the matcher stage.
        --num_process_eval: pool size (and batch size) of the evaluator stage.
        --vis_folder: if given, match visualisations are written here.
    """
    cli = argparse.ArgumentParser(description="dump eval data.")
    cli.add_argument(
        "--config_path", type=str, default="configs/eval/scannet_eval_sgm.yaml"
    )
    cli.add_argument("--num_process_match", type=int, default=4)
    cli.add_argument("--num_process_eval", type=int, default=4)
    cli.add_argument("--vis_folder", type=str, default=None)
    return cli


# Parsed at import time on purpose: the stage handlers read the module-level
# ``args``, and spawned children re-import this module to obtain it.
parser = _build_arg_parser()
args = parser.parse_args()
def feed_match(info, matcher):
    """Run ``matcher`` on one sample and return its correspondences.

    Args:
        info: Mapping providing ``"x1"``, ``"x2"``, ``"desc1"``, ``"desc2"``
            (forwarded unchanged) and ``"img1"``/``"img2"`` (only the first
            two entries of each ``.shape`` are used).
        matcher: Object whose ``run(test_data)`` returns a pair
            ``(corr1, corr2)``.

    Returns:
        ``[corr1, corr2]`` as a two-element list.
    """

    def _flipped_size(img):
        # Reverse img.shape[:2]; for an (H, W[, C]) array this yields (W, H).
        return np.flip(np.asarray(img.shape[:2]))

    payload = {key: info[key] for key in ("x1", "x2", "desc1", "desc2")}
    payload["size1"] = _flipped_size(info["img1"])
    payload["size2"] = _flipped_size(info["img2"])

    first, second = matcher.run(payload)
    return [first, second]
def reader_handler(config, read_que):
    """Producer stage: read every sample and push it onto ``read_que``.

    Args:
        config: Reader config; ``config["name"]`` selects the component and the
            whole mapping is passed to ``load_component``.
        read_que: Queue that receives one ``reader.run(index)`` result per
            index in ``range(len(reader))``.

    The string ``"over"`` is always enqueued last, even if loading or reading
    fails part-way, so the downstream matcher stage can terminate instead of
    blocking forever on ``read_que.get()``. Any exception still propagates.
    """
    try:
        reader = load_component("reader", config["name"], config)
        for index in range(len(reader)):
            read_que.put(reader.run(index))
    finally:
        read_que.put("over")
def match_handler(config, read_que, match_que):
    """Matcher stage: match samples from ``read_que`` in parallel batches.

    Samples are collected into batches of ``args.num_process_match`` and each
    batch is matched with ``feed_match`` on a process pool. Every sample gets
    ``"corr1"``/``"corr2"`` keys set from the match result and is forwarded to
    ``match_que`` in input order. A trailing partial batch is flushed when the
    ``"over"`` sentinel arrives.

    Args:
        config: Matcher config; ``config["name"]`` selects the component.
        read_que: Queue filled by ``reader_handler``.
        match_que: Queue consumed by ``evaluate_handler``.

    ``"over"`` is always forwarded to ``match_que`` (also on error) so the
    evaluator stage does not block forever; exceptions still propagate.
    """
    pool = None
    try:
        matcher = load_component("matcher", config["name"], config)
        match_func = functools.partial(feed_match, matcher=matcher)
        pool = Pool(args.num_process_match)

        def _match_batch(batch):
            # pool.map preserves input order, so results line up with batch.
            for cur_item, cur_result in zip(batch, pool.map(match_func, batch)):
                cur_item["corr1"], cur_item["corr2"] = cur_result[0], cur_result[1]
                match_que.put(cur_item)

        cache = []
        while True:
            item = read_que.get()
            # isinstance guard keeps the sentinel check safe for any payload type.
            if isinstance(item, str) and item == "over":
                break
            cache.append(item)
            if len(cache) == args.num_process_match:
                _match_batch(cache)
                cache = []
        if cache:
            _match_batch(cache)
    finally:
        match_que.put("over")
        if pool is not None:
            pool.close()
            pool.join()
def evaluate_handler(config, match_que):
    """Evaluator stage: score matched samples in parallel and report metrics.

    Reads from ``match_que`` until ``config["num_pair"]`` items have been read
    or the ``"over"`` sentinel arrives. Samples are evaluated in batches of
    ``args.num_process_eval`` with ``evaluator.run`` on a process pool. Each
    result goes to ``evaluator.res_inqueue``, and ``evaluator.parse()`` is
    called once at the end. When ``args.vis_folder`` is set, a match drawing is
    written for every received sample as ``<vis_folder>/<item["index"]>.png``.

    Args:
        config: Evaluator config. Reads ``"name"`` (component to load),
            ``"num_pair"`` (maximum number of samples to consume) and, when
            visualising, ``"inlier_th"``.
        match_que: Queue filled by ``match_handler``.
    """
    evaluator = load_component("evaluator", config["name"], config)
    pool = Pool(args.num_process_eval)

    def _evaluate_batch(batch):
        # pool.map preserves input order, so results are queued in arrival order.
        for cur_res in pool.map(evaluator.run, batch):
            evaluator.res_inqueue(cur_res)

    cache = []
    try:
        for _ in trange(config["num_pair"]):
            item = match_que.get()
            if isinstance(item, str) and item == "over":
                break
            cache.append(item)
            if len(cache) == args.num_process_eval:
                _evaluate_batch(cache)
                cache = []
            if args.vis_folder is not None:
                # dump visualization
                corr1_norm = evaluation_utils.normalize_intrinsic(
                    item["corr1"], item["K1"]
                )
                corr2_norm = evaluation_utils.normalize_intrinsic(
                    item["corr2"], item["K2"]
                )
                # NOTE(review): item["e"] is presumably the ground-truth
                # essential matrix used for the epipolar test — confirm.
                inlier_mask = metrics.compute_epi_inlier(
                    corr1_norm, corr2_norm, item["e"], config["inlier_th"]
                )
                display = evaluation_utils.draw_match(
                    item["img1"], item["img2"], item["corr1"], item["corr2"], inlier_mask
                )
                cv2.imwrite(
                    os.path.join(args.vis_folder, str(item["index"]) + ".png"), display
                )
        # Flush the final partial batch on BOTH exit paths. Previously this only
        # happened after "over"; when the loop ended by reading exactly num_pair
        # samples, the last num_pair % num_process_eval were never evaluated.
        if cache:
            _evaluate_batch(cache)
    finally:
        # Release the worker pool explicitly; it was never closed before.
        pool.close()
        pool.join()
    evaluator.parse()
if __name__ == "__main__":
    with open(args.config_path, "r", encoding="utf-8") as f:
        # safe_load: yaml.load() without an explicit Loader can construct
        # arbitrary Python objects and raises TypeError on PyYAML >= 6.
        config = yaml.safe_load(f)
    if args.vis_folder is not None:
        # makedirs also creates missing parents and tolerates an existing dir.
        os.makedirs(args.vis_folder, exist_ok=True)

    # One Manager server process serves both pipeline queues. The original
    # started three managers and never used the third queue.
    with Manager() as manager:
        read_que = manager.Queue(maxsize=100)
        match_que = manager.Queue(maxsize=100)

        stages = [
            Process(target=reader_handler, args=(config["reader"], read_que)),
            Process(
                target=match_handler, args=(config["matcher"], read_que, match_que)
            ),
            Process(target=evaluate_handler, args=(config["evaluator"], match_que)),
        ]
        for stage in stages:
            stage.start()
        # Join before leaving the context: the manager must outlive the stages.
        for stage in stages:
            stage.join()