"""Command-line tool that sets a fixed batch size on an ONNX model's first input."""

import argparse

import onnx


def modify_batch_size(input_model_path, output_model_path, batch_size):
    """Set a fixed batch size on the first data input of an ONNX model.

    Only dimension 0 of the first graph input that is not also an
    initializer is rewritten. Output shapes and intermediate ``value_info``
    entries are left untouched, so they may still carry the old batch size.

    Args:
        input_model_path: Path of the ONNX model to read.
        output_model_path: Path where the modified model is written.
        batch_size: New batch size; must be a positive integer.

    Raises:
        ValueError: If ``batch_size`` is not positive, the graph has no
            non-initializer inputs, or that input is not a tensor with at
            least one dimension.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}.")

    # Load the ONNX model
    model = onnx.load(input_model_path)

    # ONNX IR versions before 4 require every initializer (weight) to also be
    # listed in graph.input; skip those so only a real data input is changed.
    initializer_names = {init.name for init in model.graph.initializer}
    data_inputs = [inp for inp in model.graph.input if inp.name not in initializer_names]
    if not data_inputs:
        raise ValueError(f"Model {input_model_path!r} has no non-initializer graph inputs.")
    input_tensor = data_inputs[0]

    if not input_tensor.type.HasField("tensor_type"):
        raise ValueError(f"First input {input_tensor.name!r} is not a tensor.")
    dims = input_tensor.type.tensor_type.shape.dim
    if not dims:
        raise ValueError(
            f"First input {input_tensor.name!r} has no shape dimensions to treat as batch."
        )

    # dim_value and dim_param form a protobuf oneof, so assigning dim_value also
    # clears any symbolic name (e.g. "N") previously stored on this dimension.
    dims[0].dim_value = batch_size

    # Save the modified model
    onnx.save(model, output_model_path)
    print(f"Modified model saved to {output_model_path}")


def main(argv=None):
    """Parse command-line arguments and rewrite the model's batch size.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(
        description="Modify the batch size of the first input tensor of an ONNX model."
    )
    parser.add_argument("input_model_path", type=str, help="Path to the input ONNX model file.")
    parser.add_argument(
        "output_model_path", type=str, help="Path to save the modified ONNX model file."
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        required=True,
        help="Desired batch size for the first input tensor.",
    )
    args = parser.parse_args(argv)

    # Report bad values as a normal CLI usage error rather than a traceback.
    if args.batch_size <= 0:
        parser.error("--batch_size must be a positive integer.")

    modify_batch_size(args.input_model_path, args.output_model_path, args.batch_size)


if __name__ == "__main__":
    main()