Add new class 'RTMO_GPU_Batch' that can perform inference on a batch of images
Browse files- rtmo_gpu.py +84 -0
rtmo_gpu.py
CHANGED
@@ -378,3 +378,87 @@ class RTMO_GPU(object):
|
|
378 |
self.std = std
|
379 |
self.device = device
|
380 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
378 |
self.std = std
|
379 |
self.device = device
|
380 |
|
381 |
+
class RTMO_GPU_Batch(RTMO_GPU):
    """Batched variant of RTMO_GPU.

    Preprocesses a list of images into one stacked batch, runs a single
    ONNX Runtime inference call via IO binding, and splits the outputs
    back into per-image keypoints/scores.
    """

    def preprocess_batch(self, imgs: List[np.ndarray]) -> Tuple[np.ndarray, List[float]]:
        """Process a batch of images for RTMPose model inference.

        Args:
            imgs (List[np.ndarray]): List of input images.

        Returns:
            tuple:
            - batch_img (np.ndarray): Batch of preprocessed images, stacked
              along a new leading (batch) axis.
            - ratios (List[float]): Ratios used for preprocessing each image.
        """
        batch_img = []
        ratios = []

        for img in imgs:
            # Reuse the single-image preprocessing from the parent class.
            preprocessed_img, ratio = super().preprocess(img)
            batch_img.append(preprocessed_img)
            ratios.append(ratio)

        # Stack along the first dimension to create a batch.
        # NOTE(review): np.stack requires every preprocessed image to have
        # the same H/W/C — assumed guaranteed by the parent preprocess;
        # verify against RTMO_GPU.preprocess.
        batch_img = np.stack(batch_img, axis=0)

        return batch_img, ratios

    def inference(self, batch_img: np.ndarray) -> List[np.ndarray]:
        """Override to handle batch inference.

        Args:
            batch_img (np.ndarray): Batch of preprocessed images (NHWC).

        Returns:
            outputs (List[np.ndarray]): Outputs of RTMPose model for the
            whole batch (one array per model output: 'dets', 'keypoints').
        """
        batch_img = batch_img.transpose(0, 3, 1, 2)  # NHWC -> NCHW format
        # Contiguous float32 buffer is required for zero-copy IO binding below.
        batch_img = np.ascontiguousarray(batch_img, dtype=np.float32)

        # Avoid shadowing the builtin `input`.
        model_input = batch_img

        # Create an IO Binding object
        io_binding = self.session.io_binding()

        # Bind the model inputs and outputs to the IO Binding object.
        # NOTE(review): the input buffer lives in host memory, hence
        # device_type='cpu' even though the session may execute on GPU.
        io_binding.bind_input(name='input', device_type='cpu', device_id=0,
                              element_type=np.float32, shape=model_input.shape,
                              buffer_ptr=model_input.ctypes.data)
        io_binding.bind_output(name='dets')
        io_binding.bind_output(name='keypoints')

        # Run inference with IO Binding
        self.session.run_with_iobinding(io_binding)

        # Retrieve the outputs from the IO Binding object
        outputs = [output.numpy() for output in io_binding.get_outputs()]

        return outputs

    def postprocess_batch(
        self,
        outputs: List[np.ndarray],
        ratios: List[float]
    ) -> List[Tuple[np.ndarray, np.ndarray]]:
        """Process outputs for a batch of images.

        Args:
            outputs (List[np.ndarray]): Outputs from the model for the whole
                batch (batch axis first in each array).
            ratios (List[float]): Ratios used for preprocessing each image.

        Returns:
            List[Tuple[np.ndarray, np.ndarray]]: keypoints and scores for each image.
        """
        batch_keypoints = []
        batch_scores = []

        for i, ratio in enumerate(ratios):
            # BUGFIX: the original passed the full batch `outputs` for every
            # i, so each image got the same (first-image) result. Slice out
            # the i-th image's outputs as a single-image batch instead.
            # NOTE(review): assumes RTMO_GPU.postprocess expects outputs with
            # a leading batch axis of size 1 — verify against the parent.
            image_outputs = [output[i:i + 1] for output in outputs]
            keypoints, scores = super().postprocess(image_outputs, ratio)
            batch_keypoints.append(keypoints)
            batch_scores.append(scores)

        return batch_keypoints, batch_scores

    def __call__(self, images: List[np.ndarray]):
        """Run the full batched pipeline: preprocess -> inference -> postprocess.

        Args:
            images (List[np.ndarray]): Input images.

        Returns:
            tuple: (keypoints, scores) lists, one entry per input image.
        """
        batch_img, ratios = self.preprocess_batch(images)
        outputs = self.inference(batch_img)
        keypoints, scores = self.postprocess_batch(outputs, ratios)
        return keypoints, scores
|