yixionghuo commited on
Commit
361519c
1 Parent(s): 32865f3

Update onnx_inference.py

Browse files
Files changed (1) hide show
  1. onnx_inference.py +4 -1
onnx_inference.py CHANGED
@@ -127,8 +127,11 @@ if __name__ == '__main__':
127
 
128
  img0 = cv2.imread(path)
129
  img = pre_process(img0)
130
- onnx_input = {onnx_model.get_inputs()[0].name: img}
 
131
  onnx_output = onnx_model.run(None, onnx_input)
 
 
132
  pred = post_process(onnx_output, conf_thres,
133
  iou_thres, multi_label=False,
134
  classes=classes, agnostic=agnostic_nms)
 
127
 
128
  img0 = cv2.imread(path)
129
  img = pre_process(img0)
130
+ # onnx_input = {onnx_model.get_inputs()[0].name: img}
131
+ onnx_input = {onnx_model.get_inputs()[0].name: np.transpose(img, (0, 2 ,3, 1))}
132
  onnx_output = onnx_model.run(None, onnx_input)
133
+ onnx_output = [np.transpose(out, (0, 3, 1, 2)) for out in onnx_output]
134
+
135
  pred = post_process(onnx_output, conf_thres,
136
  iou_thres, multi_label=False,
137
  classes=classes, agnostic=agnostic_nms)