import onnx | |
import argparse | |
def modify_batch_size(input_model_path, output_model_path, batch_size):
    """Rewrite dimension 0 of an ONNX model's first graph input to a fixed size.

    Loads the model from *input_model_path*, sets dimension 0 of the first
    entry in ``model.graph.input`` to *batch_size*, saves the result to
    *output_model_path* and prints the output path. Only that input
    declaration is changed; output and intermediate ``value_info`` shapes are
    left untouched.

    Args:
        input_model_path: Path of the ONNX model to read.
        output_model_path: Path the modified model is written to.
        batch_size: New fixed size for dimension 0; must be >= 1.

    Raises:
        ValueError: If *batch_size* is less than 1, the graph declares no
            inputs, or the first input has no dimensions (unknown or scalar
            shape) to modify.
    """
    # Validate before any file I/O so bad arguments fail fast. Negative sizes
    # would otherwise be stored verbatim in the int64 dim_value field.
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")

    model = onnx.load(input_model_path)

    graph_inputs = model.graph.input
    if not graph_inputs:
        raise ValueError(f"Model {input_model_path!r} declares no graph inputs")
    # NOTE(review): in models with IR version < 4, initializers may also be
    # listed in graph.input; this assumes entry 0 is a real model input.
    input_tensor = graph_inputs[0]

    # Reading nested protobuf messages does not create them, so an input
    # without a tensor shape simply yields an empty dim list here.
    dims = input_tensor.type.tensor_type.shape.dim
    if not dims:
        raise ValueError(
            f"First input {input_tensor.name!r} has no dimensions to modify "
            "(unknown or scalar shape)"
        )

    # dim_value and dim_param share a protobuf oneof, so assigning dim_value
    # also clears any symbolic dimension name the input previously used.
    dims[0].dim_value = batch_size

    onnx.save(model, output_model_path)
    print(f"Modified model saved to {output_model_path}")
def _positive_int(text):
    """Parse *text* as a strictly positive integer (argparse ``type=`` hook)."""
    try:
        value = int(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid int value: {text!r}") from None
    if value < 1:
        raise argparse.ArgumentTypeError(
            f"batch size must be a positive integer, got {value}"
        )
    return value


def main(argv=None):
    """Command-line entry point: parse arguments and rewrite the model's batch size.

    Args:
        argv: Argument list to parse; ``None`` makes argparse use ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(
        description="Modify the batch size of the first input tensor of an ONNX model."
    )
    parser.add_argument("input_model_path", type=str, help="Path to the input ONNX model file.")
    parser.add_argument("output_model_path", type=str, help="Path to save the modified ONNX model file.")
    # required=True makes argparse print usage and exit with status 2 when the
    # option is missing, instead of a traceback from a hand-raised ValueError.
    parser.add_argument(
        "--batch_size",
        type=_positive_int,
        required=True,
        help="Desired batch size for the first input tensor.",
    )
    args = parser.parse_args(argv)
    modify_batch_size(args.input_model_path, args.output_model_path, args.batch_size)


if __name__ == "__main__":
    main()