pesi
/

Luigi commited on
Commit
8072759
1 Parent(s): 43a489d

Add new class 'RTMO_GPU_Batch' that can perform inference on batch of images

Browse files
Files changed (1) hide show
  1. rtmo_gpu.py +84 -0
rtmo_gpu.py CHANGED
@@ -378,3 +378,87 @@ class RTMO_GPU(object):
378
  self.std = std
379
  self.device = device
380
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
378
  self.std = std
379
  self.device = device
380
 
381
+ class RTMO_GPU_Batch(RTMO_GPU):
382
+ def preprocess_batch(self, imgs: List[np.ndarray]) -> Tuple[np.ndarray, List[float]]:
383
+ """Process a batch of images for RTMPose model inference.
384
+
385
+ Args:
386
+ imgs (List[np.ndarray]): List of input images.
387
+
388
+ Returns:
389
+ tuple:
390
+ - batch_img (np.ndarray): Batch of preprocessed images.
391
+ - ratios (List[float]): Ratios used for preprocessing each image.
392
+ """
393
+ batch_img = []
394
+ ratios = []
395
+
396
+ for img in imgs:
397
+ preprocessed_img, ratio = super().preprocess(img)
398
+ batch_img.append(preprocessed_img)
399
+ ratios.append(ratio)
400
+
401
+ # Stack along the first dimension to create a batch
402
+ batch_img = np.stack(batch_img, axis=0)
403
+
404
+ return batch_img, ratios
405
+
406
+ def inference(self, batch_img: np.ndarray):
407
+ """Override to handle batch inference.
408
+
409
+ Args:
410
+ batch_img (np.ndarray): Batch of preprocessed images.
411
+
412
+ Returns:
413
+ outputs (List[np.ndarray]): Outputs of RTMPose model for each image.
414
+ """
415
+ batch_img = batch_img.transpose(0, 3, 1, 2) # NCHW format
416
+ batch_img = np.ascontiguousarray(batch_img, dtype=np.float32)
417
+
418
+ input = batch_img
419
+
420
+ # Create an IO Binding object
421
+ io_binding = self.session.io_binding()
422
+
423
+ # Bind the model inputs and outputs to the IO Binding object
424
+ io_binding.bind_input(name='input', device_type='cpu', device_id=0, element_type=np.float32, shape=input.shape, buffer_ptr=input.ctypes.data)
425
+ io_binding.bind_output(name='dets')
426
+ io_binding.bind_output(name='keypoints')
427
+
428
+ # Run inference with IO Binding
429
+ self.session.run_with_iobinding(io_binding)
430
+
431
+ # Retrieve the outputs from the IO Binding object
432
+ outputs = [output.numpy() for output in io_binding.get_outputs()]
433
+
434
+ return outputs
435
+
436
+ def postprocess_batch(
437
+ self,
438
+ outputs: List[np.ndarray],
439
+ ratios: List[float]
440
+ ) -> List[Tuple[np.ndarray, np.ndarray]]:
441
+ """Process outputs for a batch of images.
442
+
443
+ Args:
444
+ outputs (List[np.ndarray]): Outputs from the model for each image.
445
+ ratios (List[float]): Ratios used for preprocessing each image.
446
+
447
+ Returns:
448
+ List[Tuple[np.ndarray, np.ndarray]]: keypoints and scores for each image.
449
+ """
450
+ batch_keypoints = []
451
+ batch_scores = []
452
+
453
+ for i, ratio in enumerate(ratios):
454
+ keypoints, scores = super().postprocess(outputs, ratio)
455
+ batch_keypoints.append(keypoints)
456
+ batch_scores.append(scores)
457
+
458
+ return batch_keypoints, batch_scores
459
+
460
+ def __call__(self, images: List[np.ndarray]):
461
+ batch_img, ratios = self.preprocess_batch(images)
462
+ outputs = self.inference(batch_img)
463
+ keypoints, scores = self.postprocess_batch(outputs, ratios)
464
+ return keypoints, scores