# retinaface / widerface_onnx_evalute.py
# Last updated by zhengrongzhang: "Update widerface_onnx_evalute.py (#2)" (commit c2b2584, verified).
# Hugging Face file-page metadata: 4.16 kB, scan result "No virus".
import os
import argparse
import onnxruntime as ort
from utils import *
# Prior-box configuration passed to `postprocess` for the "mobilenet0.25"
# RetinaFace variant. Keys presumably mean: per-level anchor sizes,
# feature-map strides and box-decoding variances (confirm against utils).
CFG = dict(
    name="mobilenet0.25",
    min_sizes=[[16, 32], [64, 128], [256, 512]],
    steps=[8, 16, 32],
    variance=[0.1, 0.2],
    clip=False,
)

# Resize target handed to `preprocess` (axis order not visible here —
# presumably [height, width]; confirm in utils.preprocess).
INPUT_SIZE = [608, 640]

# Device passed to both `preprocess` and `postprocess`.
DEVICE = torch.device("cpu")
def save_result(img_name, dets, save_folder):
    """Write one image's detections to a text file under ``save_folder``.

    The output path mirrors ``img_name`` relative to ``save_folder``, with the
    image extension replaced by ``.txt``. File layout::

        <image file name without extension>
        <number of boxes>
        <x> <y> <w> <h> <score>     (one line per box, trailing space kept)

    Args:
        img_name: image path relative to the dataset image root (e.g. a line
            of ``wider_val.txt``). A leading path separator is ignored.
        dets: sequence of boxes; each box is indexed as
            ``[x1, y1, x2, y2, score, ...]`` (assumed corner coordinates,
            since width/height are computed as ``x2 - x1`` / ``y2 - y1``).
        save_folder: root directory for result files; created if missing.
            A trailing separator is optional.
    """
    # Drop a leading separator so os.path.join does not treat img_name as an
    # absolute path and silently discard save_folder.
    rel_name = img_name.lstrip("/\\")
    save_name = os.path.join(save_folder, os.path.splitext(rel_name)[0] + ".txt")

    # exist_ok avoids the check-then-create race; creating the file's parent
    # directory also creates save_folder itself.
    dirname = os.path.dirname(save_name)
    if dirname:
        os.makedirs(dirname, exist_ok=True)

    lines = [
        os.path.splitext(os.path.basename(save_name))[0] + "\n",
        str(len(dets)) + "\n",
    ]
    for box in dets:
        x = int(box[0])
        y = int(box[1])
        # Each corner is truncated to int before subtracting (original semantics).
        w = int(box[2]) - x
        h = int(box[3]) - y
        # str() (not format()) keeps the score text identical for numpy scalars.
        lines.append(f"{x} {y} {w} {h} {str(box[4])} \n")

    with open(save_name, "w") as fw:
        fw.writelines(lines)
def Retinaface_evalute(run_ort, args):
    """Run the ONNX model over every image in ``wider_val.txt`` and save detections.

    The image list is read from ``wider_val.txt`` in the parent directory of
    ``args.dataset_folder`` (e.g. ``.../val/wider_val.txt`` for
    ``.../val/images/``). Each listed image is preprocessed, run through
    ``run_ort``, postprocessed, and written with ``save_result``.

    Args:
        run_ort: ONNX Runtime ``InferenceSession`` for the RetinaFace model.
        args: parsed command-line namespace; reads ``dataset_folder``,
            ``save_folder``, ``confidence_threshold`` and ``nms_threshold``.

    Returns:
        None. Predictions are written as text files under ``args.save_folder``.

    Raises:
        FileNotFoundError: if a listed image cannot be read by OpenCV.
    """
    testset_folder = args.dataset_folder
    # Normalise before taking the parent so a trailing separator is optional;
    # the previous `[:-7]` slice only worked for paths ending in "images/".
    list_dir = os.path.dirname(os.path.normpath(testset_folder))
    testset_list = os.path.join(list_dir, "wider_val.txt")
    with open(testset_list, "r") as fr:
        test_dataset = fr.read().split()
    num_images = len(test_dataset)

    # The model's input name does not change between images; look it up once.
    input_name = run_ort.get_inputs()[0].name

    for i, img_name in enumerate(test_dataset, start=1):
        # A leading separator in the list entry would make os.path.join
        # discard the dataset folder, so strip it.
        image_path = os.path.join(testset_folder, img_name.lstrip("/\\"))
        img_raw = cv2.imread(image_path, cv2.IMREAD_COLOR)
        if img_raw is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"cannot read image: {image_path}")

        img, scale, resize = preprocess(img_raw, INPUT_SIZE, DEVICE)
        # Reorder axes to NHWC for the exported model (preprocess output is
        # presumably NCHW — confirm in utils.preprocess).
        img = np.transpose(img, (0, 2, 3, 1))

        outputs = run_ort.run(None, {input_name: img})
        dets = postprocess(
            CFG,
            img,
            outputs,
            scale,
            resize,
            args.confidence_threshold,
            args.nms_threshold,
            DEVICE,
        )

        save_result(img_name, dets, args.save_folder)
        print("im_detect: {:d}/{:d}".format(i, num_images))
if __name__ == '__main__':
parser = argparse.ArgumentParser(description="Retinaface")
parser.add_argument(
"-m",
"--trained_model",
default="./weights/RetinaFace_int.onnx",
type=str,
help="Trained state_dict file path to open",
)
parser.add_argument(
"--save_folder",
default="./widerface_evaluate/widerface_txt/",
type=str,
help="Dir to save txt results",
)
parser.add_argument(
"--dataset_folder",
default="./data/widerface/val/images/",
type=str,
help="dataset path",
)
parser.add_argument(
"--confidence_threshold",
default=0.02,
type=float,
help="confidence_threshold",
)
parser.add_argument(
"--nms_threshold",
default=0.4,
type=float,
help="nms_threshold",
)
parser.add_argument(
"--ipu",
action="store_true",
help="Use IPU for inference.",
)
parser.add_argument(
"--provider_config",
type=str,
default="vaip_config.json",
help="Path of the config file for seting provider_options.",
)
args = parser.parse_args()
if args.ipu:
providers = ["VitisAIExecutionProvider"]
provider_options = [{"config_file": args.provider_config}]
else:
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
provider_options = None
print("Loading pretrained model from {}".format(args.trained_model))
run_ort = ort.InferenceSession(args.trained_model, providers=providers, provider_options=provider_options)
Retinaface_evalute(run_ort, args)