|
import os |
|
import argparse |
|
|
|
import onnxruntime as ort |
|
from utils import * |
|
|
|
|
|
# Prior-box (anchor) configuration for the "mobilenet0.25" RetinaFace model.
# These values are handed to utils.postprocess together with the raw ONNX
# outputs; their exact interpretation (e.g. per-level min_sizes/steps and the
# box-decoding variances) is defined there — NOTE(review): confirm in utils.
CFG = dict(
    name="mobilenet0.25",
    min_sizes=[[16, 32], [64, 128], [256, 512]],
    steps=[8, 16, 32],
    variance=[0.1, 0.2],
    clip=False,
)

# Network input size passed to utils.preprocess. The axis order (H, W vs.
# W, H) is decided by preprocess — TODO confirm before changing it.
INPUT_SIZE = [608, 640]

# torch device handed to preprocess/postprocess; inference itself runs
# through ONNX Runtime with the providers chosen on the command line.
DEVICE = torch.device("cpu")
|
|
|
|
|
def vis(img_raw, dets, vis_thres, save_path="./results/result.jpg"):
    """Draw detections on an image and write the result to disk.

    Args:
        img_raw: image to draw on (as loaded by cv2.imread). It is modified
            in place.
        dets: iterable of detection rows. Each row is read as
            [x1, y1, x2, y2, score, lx1, ly1, ..., lx5, ly5], i.e. at least
            15 values: a box, a score and five (x, y) landmark points.
        vis_thres: rows whose score (index 4) is below this are not drawn.
        save_path: file the annotated image is written to. Defaults to
            "./results/result.jpg"; its parent directory is created if missing.

    Returns:
        The annotated image (the same object as ``img_raw``).

    Raises:
        OSError: if cv2.imwrite reports that the file could not be written.
    """
    # Colors of the five landmark points, in row order (OpenCV colors are BGR).
    landmark_colors = (
        (0, 0, 255),
        (0, 255, 255),
        (255, 0, 255),
        (0, 255, 0),
        (255, 0, 0),
    )

    for det in dets:
        if det[4] < vis_thres:
            continue
        # Format the score before the row is truncated to integer pixels.
        text = "{:.4f}".format(det[4])
        b = list(map(int, det))

        cv2.rectangle(img_raw, (b[0], b[1]), (b[2], b[3]), (0, 0, 255), 2)
        # Score label just inside the top-left corner of the box.
        cv2.putText(img_raw, text, (b[0], b[1] + 12),
                    cv2.FONT_HERSHEY_DUPLEX, 0.5, (255, 255, 255))

        # Landmarks occupy columns 5..14 as consecutive (x, y) pairs.
        for i, color in enumerate(landmark_colors):
            x, y = b[5 + 2 * i], b[6 + 2 * i]
            cv2.circle(img_raw, (x, y), 1, color, 4)

    out_dir = os.path.dirname(save_path)
    if out_dir:
        # exist_ok avoids the check-then-create race of os.path.exists().
        os.makedirs(out_dir, exist_ok=True)
    # cv2.imwrite reports failure by returning False rather than raising.
    if not cv2.imwrite(save_path, img_raw):
        raise OSError("failed to write visualization to {}".format(save_path))
    return img_raw
|
|
|
|
|
def Retinaface_inference(run_ort, args):
    """Detect faces in one image with an ONNX Runtime session.

    Args:
        run_ort: ONNX Runtime inference session for the RetinaFace model.
            Its first input receives the preprocessed image.
        args: namespace providing ``image_path``, ``confidence_threshold``,
            ``nms_threshold``, ``save_image`` and ``vis_thres``.

    Returns:
        (boxes_list, confidence_list, landm_list), one entry per detection:
            boxes_list = [[left, top, right, bottom], ...]
            confidence_list = [[confidence], ...]
            landm_list = [[landms (dim=10)], ...]

    Raises:
        FileNotFoundError: if ``args.image_path`` is missing or cannot be
            decoded as an image.
    """
    img_raw = cv2.imread(args.image_path, cv2.IMREAD_COLOR)
    # cv2.imread signals failure by returning None instead of raising; fail
    # here with a clear message rather than deep inside preprocess.
    if img_raw is None:
        raise FileNotFoundError(
            "Could not read image: {}".format(args.image_path))

    img, scale, resize = preprocess(img_raw, INPUT_SIZE, DEVICE)

    # Move axis 1 to the end. Presumably preprocess yields NCHW and the
    # exported model expects NHWC. NOTE(review): confirm against
    # utils.preprocess, and that postprocess (which receives this
    # transposed img below) expects the NHWC layout as well.
    img = np.transpose(img, (0, 2, 3, 1))

    outputs = run_ort.run(None, {run_ort.get_inputs()[0].name: img})

    dets = postprocess(CFG, img, outputs, scale, resize,
                       args.confidence_threshold, args.nms_threshold, DEVICE)

    # Each detection row: 4 box coords, 1 score, then the landmark values.
    boxes_list = dets[:, :4].tolist()
    confidence_list = dets[:, 4:5].tolist()
    landm_list = dets[:, 5:].tolist()

    if args.save_image:
        vis(img_raw, dets, args.vis_thres)

    return boxes_list, confidence_list, landm_list
|
|
|
|
|
if __name__ == '__main__': |
|
parser = argparse.ArgumentParser(description="Retinaface") |
|
parser.add_argument( |
|
"-m", |
|
"--trained_model", |
|
default="./weights/RetinaFace_int.onnx", |
|
type=str, |
|
help="Trained state_dict file path to open", |
|
) |
|
parser.add_argument( |
|
"--image_path", |
|
default="./data/widerface/val/images/18--Concerts/18_Concerts_Concerts_18_38.jpg", |
|
type=str, |
|
help="image path", |
|
) |
|
parser.add_argument( |
|
"--confidence_threshold", |
|
default=0.4, |
|
type=float, |
|
help="confidence_threshold" |
|
) |
|
parser.add_argument( |
|
"--nms_threshold", |
|
default=0.4, |
|
type=float, |
|
help="nms_threshold" |
|
) |
|
parser.add_argument( |
|
"-s", |
|
"--save_image", |
|
action="store_true", |
|
default=False, |
|
help="show detection results", |
|
) |
|
parser.add_argument( |
|
"--vis_thres", |
|
default=0.5, |
|
type=float, |
|
help="visualization_threshold" |
|
) |
|
parser.add_argument( |
|
"--ipu", |
|
action="store_true", |
|
help="Use IPU for inference.", |
|
) |
|
parser.add_argument( |
|
"--provider_config", |
|
type=str, |
|
default="vaip_config.json", |
|
help="Path of the config file for seting provider_options.", |
|
) |
|
|
|
args = parser.parse_args() |
|
|
|
if args.ipu: |
|
providers = ["VitisAIExecutionProvider"] |
|
provider_options = [{"config_file": args.provider_config}] |
|
else: |
|
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] |
|
provider_options = None |
|
|
|
print("Loading pretrained model from {}".format(args.trained_model)) |
|
run_ort = ort.InferenceSession(args.trained_model, providers=providers, provider_options=provider_options) |
|
|
|
boxes_list, confidence_list, landm_list = Retinaface_inference(run_ort, args) |
|
print('inference done!') |
|
print(boxes_list, confidence_list, landm_list) |
|
|