rtmo / fix_input_batch_size.py
Make ONNX models compatible with ONNX Runtime's TensorrtExecutionProvider
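"""Pin the batch dimension of an ONNX model's first input to a fixed value.

Background (inferred from the commit message above, hedged): TensorRT builds
engines for concrete tensor shapes, so a free/dynamic batch dimension on the
model input can prevent ONNX Runtime's TensorrtExecutionProvider from taking
the graph. Fixing dim 0 to a constant works around that.
"""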
import argparse

import onnx


def modify_batch_size(input_model_path, output_model_path, batch_size):
    # Load the ONNX model
    model = onnx.load(input_model_path)

    # Modify the batch size (dim 0) of the first input tensor; assigning
    # dim_value also clears any symbolic dim_param on that dimension
    input_tensor = model.graph.input[0]
    input_tensor.type.tensor_type.shape.dim[0].dim_value = batch_size

    # Save the modified model
    onnx.save(model, output_model_path)
    print(f"Modified model saved to {output_model_path}")
if __name__ == "__main__":
    # Parse command line arguments
    parser = argparse.ArgumentParser(
        description="Modify the batch size of the first input tensor of an ONNX model."
    )
    parser.add_argument("input_model_path", type=str, help="Path to the input ONNX model file.")
    parser.add_argument("output_model_path", type=str, help="Path to save the modified ONNX model file.")
    parser.add_argument("--batch_size", type=int, help="Desired batch size for the first input tensor.")
    args = parser.parse_args()

    # Ensure that the batch_size argument has been provided
    if args.batch_size is None:
        raise ValueError("Please specify the new batch size for the input tensor using --batch_size.")

    # Call the function to modify the model's batch size
    modify_batch_size(args.input_model_path, args.output_model_path, args.batch_size)
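
# Example invocation (file names and batch size are placeholders):
#
#   python fix_input_batch_size.py rtmo.onnx rtmo_batch1.onnx --batch_size 1
#
# The fixed model can then be loaded through ONNX Runtime with the TensorRT
# execution provider (a minimal sketch, assuming an onnxruntime-gpu build
# with TensorRT support; falls back to CUDA if TensorRT declines the graph):
#
#   import onnxruntime as ort
#   sess = ort.InferenceSession(
#       "rtmo_batch1.onnx",
#       providers=["TensorrtExecutionProvider", "CUDAExecutionProvider"],
#   )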