ONNX Export

#5
by frederikschubert - opened

Are there plans to convert the model to ONNX for easier use in, e.g., Vespa?

I tried to export the model with Optimum, but the export fails during graph optimisation:
optimum-cli export onnx --model jinaai/jina-colbert-v2 --trust-remote-code --task feature-extraction --opset 20 jina-colbert-v2

Using framework PyTorch: 2.4.1+cu12
...
Traceback (most recent call last):
  File ".../.venv/bin/optimum-cli", line 8, in <module>
    sys.exit(main())
             ^^^^^^
  File ".../.venv/lib/python3.11/site-packages/optimum/commands/optimum_cli.py", line 208, in main
    service.run()
  File ".../.venv/lib/python3.11/site-packages/optimum/commands/export/onnx.py", line 265, in run
    main_export(
  File ".../.venv/lib/python3.11/site-packages/optimum/exporters/onnx/__main__.py", line 365, in main_export
    onnx_export_from_model(
  File ".../.venv/lib/python3.11/site-packages/optimum/exporters/onnx/convert.py", line 1170, in onnx_export_from_model
    _, onnx_outputs = export_models(
                      ^^^^^^^^^^^^^^
  File ".../.venv/lib/python3.11/site-packages/optimum/exporters/onnx/convert.py", line 776, in export_models
    export(
  File ".../.venv/lib/python3.11/site-packages/optimum/exporters/onnx/convert.py", line 881, in export
    export_output = export_pytorch(
                    ^^^^^^^^^^^^^^^
  File ".../.venv/lib/python3.11/site-packages/optimum/exporters/onnx/convert.py", line 577, in export_pytorch
    onnx_export(
  File ".../.venv/lib/python3.11/site-packages/torch/onnx/utils.py", line 551, in export
    _export(
  File ".../.venv/lib/python3.11/site-packages/torch/onnx/utils.py", line 1648, in _export
    graph, params_dict, torch_out = _model_to_graph(
                                    ^^^^^^^^^^^^^^^^
  File ".../.venv/lib/python3.11/site-packages/torch/onnx/utils.py", line 1174, in _model_to_graph
    graph = _optimize_graph(
            ^^^^^^^^^^^^^^^^
  File ".../.venv/lib/python3.11/site-packages/torch/onnx/utils.py", line 656, in _optimize_graph
    _C._jit_pass_peephole(graph, True)
IndexError: Argument passed to at() was not in the map.

Jina AI org

Hi @frederikschubert, yes, I'll do it today or tomorrow.

I'm also interested in using this with Vespa; let me see if I can export it and then quantize it. Thanks for sharing this model, by the way: it's arguably the best retrieval model currently available :)

Jina AI org

Hi @frederikschubert, the issue was caused by the dynamic nature of the RoPE (rotary position embedding) implementation. I've just uploaded the ONNX weights; let me know how they work for you.
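The thread does not say exactly what was changed in the RoPE code, so the following is only a hedged illustration. A common pattern in rotary-embedding implementations is to cache cos/sin tables and rebuild them inside a Python branch such as `if seq_len > cached_len:`; a tracing exporter records only the branch taken at export time, so the cached tables can end up frozen into the graph. A minimal sketch of a trace-friendly alternative, assuming the "rotate half" layout and inputs shaped `(batch, seq_len, n_heads, head_dim)`, recomputes the angles from the input's shape on every call using only tensor ops. `TraceFriendlyRotary` and `rotate_half` are hypothetical names, not Jina's code.

```python
import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Split the last dim into halves (x1, x2) and return (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


class TraceFriendlyRotary(torch.nn.Module):
    """Hypothetical RoPE module with no Python-side cache or data-dependent branching."""

    def __init__(self, dim: int, base: float = 10000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Assumed layout: (batch, seq_len, n_heads, head_dim).
        seq_len = x.shape[1]
        # Angles are recomputed from the current sequence length on every call.
        t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
        freqs = torch.outer(t, self.inv_freq)  # (seq_len, head_dim / 2)
        emb = torch.cat((freqs, freqs), dim=-1)  # (seq_len, head_dim)
        cos = emb.cos()[None, :, None, :].to(x.dtype)
        sin = emb.sin()[None, :, None, :].to(x.dtype)
        # Rotates each pair (x[i], x[i + head_dim / 2]) by angle position * inv_freq[i].
        return x * cos + rotate_half(x) * sin


rope = TraceFriendlyRotary(dim=8)
q = torch.randn(2, 5, 3, 8)
print(rope(q).shape)
```

Because the tables are derived from `x.shape[1]` rather than from state stored on the module, a tracer should be able to record the sequence length symbolically instead of baking in the length seen at export time; this is worth confirming by running the exported graph on a longer input.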

Great, thanks for your help! Would you mind sharing your conversion script? Either way, this is really helpful for us!

Jina AI org

Sure, here it is:

import os

import torch
from transformers import AutoModel, AutoTokenizer

# Local clone of the jinaai/jina-colbert-v2 repository (model code is loaded via trust_remote_code).
model_dir = '/home/admin/saba/jina-colbert-v2'
onnx_path = '/home/admin/saba/jina-colbert-v2/onnx/model.onnx'
os.makedirs(os.path.dirname(onnx_path), exist_ok=True)

# Disable the flash-attention kernels, which the ONNX exporter cannot trace.
model = AutoModel.from_pretrained(model_dir, trust_remote_code=True, use_flash_attn=False)
model.eval()

# Small batch of dummy texts used for tracing, padded to the longest one.
tokenizer = AutoTokenizer.from_pretrained(model_dir)
inputs = tokenizer(["jina", "ai"], return_tensors="pt", padding="longest")
inps = inputs['input_ids']
mask = inputs['attention_mask']

torch.onnx.export(
    model,
    (inps, mask),
    onnx_path,
    export_params=True,
    do_constant_folding=True,
    input_names=['input_ids', 'attention_mask'],
    output_names=['text_embeds'],
    opset_version=16,
    dynamic_axes={
        'input_ids': {0: 'batch_size', 1: 'sequence_length'},
        'attention_mask': {0: 'batch_size', 1: 'sequence_length'},
        # Token-level (ColBERT) embeddings: the sequence axis varies as well.
        'text_embeds': {0: 'batch_size', 1: 'sequence_length'},
    },
)

The main challenge, though, was modifying the model implementation to make it ONNX-compatible. That's why your export was failing.

Great, thanks again! :)

frederikschubert changed discussion status to closed
