|
import argparse |
|
import onnx |
|
from onnx import numpy_helper |
|
import numpy as np |
|
|
|
def fix_batch_dimension(input_model_path, output_model_path, batch_size=1):
    """Pin the leading (batch) dimension of every graph input to a fixed size.

    Loads the ONNX model at ``input_model_path`` and overwrites dimension 0 of
    each real graph input's tensor shape with ``batch_size``. This replaces any
    symbolic ``dim_param`` that was there. The result is saved to
    ``output_model_path`` and a confirmation line is printed.

    Graph inputs whose names match an initializer are skipped. Models with IR
    version < 4 list their weights in ``graph.input`` as well, and rewriting
    the first dimension of a weight would corrupt its declared shape.

    Only ``graph.input`` is modified; ``graph.output`` and ``graph.value_info``
    are left unchanged.

    Args:
        input_model_path: Path of the ONNX model to read.
        output_model_path: Path the modified model is written to.
        batch_size: Value assigned to dimension 0 of each graph input.

    Raises:
        ValueError: If ``batch_size`` is less than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    model = onnx.load(input_model_path)

    # Weights/constants; older models also list these among graph.input.
    initializer_names = {init.name for init in model.graph.initializer}

    for graph_input in model.graph.input:
        if graph_input.name in initializer_names:
            continue

        # Sequence/map/sparse inputs have no dense tensor shape to pin.
        if not graph_input.type.HasField("tensor_type"):
            continue

        dims = graph_input.type.tensor_type.shape.dim

        # Scalars and unknown-rank inputs have no leading dimension.
        if dims:
            # dim_value and dim_param share a protobuf oneof, so assigning
            # dim_value also clears a symbolic name such as "batch" or "N".
            dims[0].dim_value = batch_size

    onnx.save(model, output_model_path)

    print(f"Model saved with updated batch size of {batch_size} to {output_model_path}")
|
|
|
def main(argv=None):
    """Parse command-line arguments and run ``fix_batch_dimension``.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(description="Fix batch dimension of an ONNX model.")
    parser.add_argument("input_model_path", type=str, help="Path to the input ONNX model.")
    parser.add_argument("output_model_path", type=str, help="Path to save the output ONNX model with fixed batch dimension.")
    parser.add_argument("--batch_size", type=int, default=1, help="Value of batch size to assign (default is 1).")

    args = parser.parse_args(argv)

    # Reject non-positive sizes with a usage message (exit status 2) instead of
    # writing a meaningless batch dimension into the model.
    if args.batch_size < 1:
        parser.error(f"--batch_size must be a positive integer, got {args.batch_size}")

    fix_batch_dimension(args.input_model_path, args.output_model_path, args.batch_size)


if __name__ == "__main__":
    main()
|
|
|
|