File size: 1,246 Bytes
7ad7e4d f9a6075 7ad7e4d f9a6075 7ad7e4d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 |
import argparse
from onnxmltools.utils.float16_converter import convert_float_to_float16
from onnxmltools.utils import load_model, save_model
# Node names passed to convert_float_to_float16 as its node_block_list argument.
# NOTE(review): presumably these nodes are left unconverted (Float32), and the
# numeric suffixes look tied to one specific exported graph - confirm before
# reusing this script with a different model.
node_block_list = [
    'Sin_689',
    'MatMul_694',
    'MatMul_698',
    'Clip_699',
    'Clip_700',
    'Sub_702',
    'Sub_704',
]
def main():
    """Convert an ONNX model from Float32 to Float16 from the command line.

    Reads two required options, ``--input_model`` and ``--output_model``,
    loads the model found at the input path, runs
    ``convert_float_to_float16`` on it and saves the result to the output
    path. A progress message is printed before each step and once more when
    the conversion has finished.
    """
    # Command-line interface: both paths are mandatory.
    cli = argparse.ArgumentParser(
        description='Convert ONNX model from Float32 to Float16.',
    )
    cli.add_argument(
        '--input_model',
        type=str,
        required=True,
        help='Path to the input ONNX model file.',
    )
    cli.add_argument(
        '--output_model',
        type=str,
        required=True,
        help='Path for saving the converted ONNX model file.',
    )
    opts = cli.parse_args()
    source_path = opts.input_model
    target_path = opts.output_model

    # Read the original model.
    print(f"Loading model from {source_path}")
    fp32_model = load_model(source_path)

    # Float32 -> Float16 conversion.
    # NOTE(review): keep_io_types=True presumably keeps the graph inputs and
    # outputs in Float32, and node_block_list presumably exempts the listed
    # nodes from conversion - confirm against the converter documentation.
    print("Converting model...")
    conversion_options = {
        'min_positive_val': 1e-7,
        'max_finite_val': 1e4,
        'keep_io_types': True,
        'node_block_list': node_block_list,
    }
    fp16_model = convert_float_to_float16(fp32_model, **conversion_options)

    # Write the converted model.
    print(f"Saving converted model to {target_path}")
    save_model(fp16_model, target_path)
    print("Conversion complete.")
# Run the conversion only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|