Luigi commited on
Commit
d0c0220
1 Parent(s): cc2ef74

Add calibration for int8 quantization

Browse files
Files changed (1) hide show
  1. yolo_nas_pose_to_onnx.py +68 -2
yolo_nas_pose_to_onnx.py CHANGED
@@ -12,6 +12,12 @@ import onnxruntime
12
  import os
13
  from super_gradients.training.utils.visualization.pose_estimation import PoseVisualization
14
  import matplotlib.pyplot as plt
 
 
 
 
 
 
15
 
16
  os.environ['CRASH_HANDLER']='0'
17
 
@@ -19,7 +25,7 @@ os.environ['CRASH_HANDLER']='0'
19
 
20
  CONVERSION = True
21
  input_image_shape = [640, 640]
22
- quantization_modes = [None, ExportQuantizationMode.INT8, ExportQuantizationMode.FP16]
23
  output_predictions_format=DetectionOutputFormatMode.FLAT_FORMAT
24
 
25
  # NMS-related Setting
@@ -37,6 +43,61 @@ image_name = "https://deci-pretrained-models.s3.amazonaws.com/sample_images/beat
37
  # Check
38
  SHAPE_CHECK=True
39
  VISUAL_CHECK=True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
  def iterate_over_flat_predictions(predictions, batch_size):
42
  [flat_predictions] = predictions
@@ -65,6 +126,11 @@ image = load_image(image_name)
65
  image = cv2.resize(image, (input_image_shape[1], input_image_shape[0]))
66
  image_bchw = np.transpose(np.expand_dims(image, 0), (0, 3, 1, 2))
67
 
 
 
 
 
 
68
  for model_name in [Models.YOLO_NAS_POSE_L, Models.YOLO_NAS_POSE_M, Models.YOLO_NAS_POSE_N, Models.YOLO_NAS_POSE_S ]:
69
  for q in quantization_modes:
70
 
@@ -94,7 +160,7 @@ for model_name in [Models.YOLO_NAS_POSE_L, Models.YOLO_NAS_POSE_M, Models.YOLO
94
  engine=ExportTargetBackend.ONNXRUNTIME,
95
  quantization_mode=q,
96
  #selective_quantizer: Optional["SelectiveQuantizer"] = None, # noqa
97
- #calibration_loader: Optional[DataLoader] = None,
98
  #calibration_method: str = "percentile",
99
  #calibration_batches: int = 16,
100
  #calibration_percentile: float = 99.99,
 
12
  import os
13
  from super_gradients.training.utils.visualization.pose_estimation import PoseVisualization
14
  import matplotlib.pyplot as plt
15
+ from datasets import load_dataset
16
+ from torchvision import transforms
17
+ from torch.utils.data import DataLoader, Dataset
18
+ from torchvision import transforms
19
+ import matplotlib.pyplot as plt
20
+
21
 
22
  os.environ['CRASH_HANDLER']='0'
23
 
 
25
 
26
  CONVERSION = True
27
  input_image_shape = [640, 640]
28
+ quantization_modes = [ExportQuantizationMode.INT8, ExportQuantizationMode.FP16, None]
29
  output_predictions_format=DetectionOutputFormatMode.FLAT_FORMAT
30
 
31
  # NMS-related Setting
 
43
  # Check
44
  SHAPE_CHECK=True
45
  VISUAL_CHECK=True
46
+ CALIBRATION_DATASET_CHECK=False
47
+
48
+ # Function to convert tensor to image for visualization
49
+ def tensor_to_image(tensor):
50
+ # Convert the tensor to a numpy array
51
+ numpy_image = tensor.numpy()
52
+
53
+ # The output of ToTensor() is in C x H x W format, convert to H x W x C
54
+ numpy_image = numpy_image.transpose(1, 2, 0)
55
+
56
+ # Undo the normalization (if any)
57
+ # numpy_image = numpy_image * std + mean # Adjust based on your normalization
58
+
59
+ return numpy_image
60
+
61
+ class HFDatasetWrapper(Dataset):
62
+ def __init__(self, hf_dataset, transform=None):
63
+ self.hf_dataset = hf_dataset
64
+ self.transform = transform
65
+
66
+ def __len__(self):
67
+ return len(self.hf_dataset)
68
+
69
+ def __getitem__(self, idx):
70
+ item = self.hf_dataset[idx]
71
+ if self.transform:
72
+ item = self.transform(item)
73
+ return item['image']
74
+
75
+ def preprocess(data):
76
+ # Convert byte data to PIL Image
77
+ image = data['image']
78
+
79
+ # Convert to RGB if not already
80
+ if image.mode != 'RGB':
81
+ image = image.convert('RGB')
82
+
83
+ # Define your transformations
84
+ transform = transforms.Compose([
85
+ transforms.Resize((640, 640)), # Resize (example size)
86
+ transforms.ToTensor(), # Convert to tensor
87
+ # Add normalization or other transformations if needed
88
+ ])
89
+
90
+ # Process Image
91
+ transformed = transform(image)
92
+
93
+ if CALIBRATION_DATASET_CHECK:
94
+ # Display the Processed Image
95
+ plt_image = tensor_to_image(transformed)
96
+ plt.imshow(plt_image)
97
+ plt.axis('off') # Turn off axis numbers
98
+ plt.show()
99
+
100
+ return {'image': transformed}
101
 
102
  def iterate_over_flat_predictions(predictions, batch_size):
103
  [flat_predictions] = predictions
 
126
  image = cv2.resize(image, (input_image_shape[1], input_image_shape[0]))
127
  image_bchw = np.transpose(np.expand_dims(image, 0), (0, 3, 1, 2))
128
 
129
+ # Prepare Calibration Dataset for INT8 Quantization
130
+ dataset = load_dataset("cppe-5", split="test")
131
+ hf_dataset_wrapper = HFDatasetWrapper(dataset, transform=preprocess)
132
+ calibration_loader = DataLoader(hf_dataset_wrapper, batch_size=8)
133
+
134
  for model_name in [Models.YOLO_NAS_POSE_L, Models.YOLO_NAS_POSE_M, Models.YOLO_NAS_POSE_N, Models.YOLO_NAS_POSE_S ]:
135
  for q in quantization_modes:
136
 
 
160
  engine=ExportTargetBackend.ONNXRUNTIME,
161
  quantization_mode=q,
162
  #selective_quantizer: Optional["SelectiveQuantizer"] = None, # noqa
163
+ calibration_loader = calibration_loader,
164
  #calibration_method: str = "percentile",
165
  #calibration_batches: int = 16,
166
  #calibration_percentile: float = 99.99,