
RuntimeError while saving embedding model in ONNX format #2930

@kobiche

Description


Describe the bug
When following the flair/resources/docs/TUTORIAL_13_TRANSFORMERS_PRODUCTION.md tutorial, the model cannot be saved in ONNX format. The export raises the following error:
RuntimeError: Dynamic shape axis should be no more than the shape dimension for sequ_length
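The error appears when a `dynamic_axes` entry refers to an axis index that the tensor bound to that input name does not have. `torch.onnx.export` binds `input_names` to the positional `args` in order, so a mis-ordered args sequence can attach a 1-D tensor to a name whose dynamic axes expect a 2-D tensor. The sketch below is a simplified, hypothetical model of that check in plain Python, not PyTorch's or Flair's code; the function name, the shapes, and the `"batch"` axis label are made up for illustration.

```python
from typing import Dict, List, Sequence, Tuple

Shape = Tuple[int, ...]


def check_dynamic_axes(
    input_names: Sequence[str],
    arg_shapes: Sequence[Shape],
    dynamic_axes: Dict[str, Dict[int, str]],
) -> List[str]:
    """Return one message per dynamic axis that exceeds the rank of its input.

    Hypothetical model: names are bound to positional args by index, the way
    torch.onnx.export pairs input_names with args.
    """
    bound = dict(zip(input_names, arg_shapes))
    errors = []
    for name, axes in dynamic_axes.items():
        shape = bound.get(name)
        if shape is None:
            continue  # name not among the inputs; nothing to check here
        for axis, label in axes.items():
            if axis >= len(shape):
                errors.append(f"{name}: axis {axis} ({label}) but rank is {len(shape)}")
    return errors


if __name__ == "__main__":
    names = ["input_ids", "lengths", "attention_mask"]
    axes = {
        "input_ids": {0: "batch", 1: "sequ_length"},
        "lengths": {0: "batch"},
        "attention_mask": {0: "batch", 1: "sequ_length"},
    }
    # Intended order: 2-D ids, 1-D lengths, 2-D mask (shapes are made up).
    print(check_dynamic_axes(names, [(5, 12), (5,), (5, 12)], axes))
    # → []
    # Wrong order: the 1-D tensor ends up bound to "attention_mask".
    print(check_dynamic_axes(names, [(5, 12), (5, 12), (5,)], axes))
    # → ['attention_mask: axis 1 (sequ_length) but rank is 1']
```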

To Reproduce

from flair.data import Sentence
from flair.models import SequenceTagger
from flair.embeddings import TransformerWordEmbeddings, TransformerDocumentEmbeddings
import flair.datasets

if __name__ == '__main__':
    model = SequenceTagger.load("ner-large")
    assert isinstance(model.embeddings, (TransformerWordEmbeddings, TransformerDocumentEmbeddings))

    # Serialize the model (use sample sentences)
    sentences = list(flair.datasets.CONLL_03_SPANISH().test)[:5]
    model.embeddings = model.embeddings.export_onnx("flert-embeddings.onnx", sentences, providers=["CUDAExecutionProvider", "CPUExecutionProvider"])

Expected behavior
The model is saved in ONNX format.

Environment

  • OS: Linux
  • Flair version: master (last commit: ee7e619)
  • ONNX runtime version: 1.12.0

Additional context
I found that the problem arises when torch.onnx.export is called. While debugging, I traced it to the order of example_tensors: the model expects its inputs in the order ['input_ids', 'lengths', 'attention_mask', 'overflow_to_sample_mapping', 'word_ids'], so as a quick fix I suggest reordering example_tensors accordingly:

desired_order = ['input_ids', 'lengths', 'attention_mask', 'overflow_to_sample_mapping', 'word_ids']
example_tensors = {k: example_tensors[k] for k in desired_order}
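Because Python dicts keep insertion order (3.7+), rebuilding the dict from `desired_order` is enough to fix the positional order. The hypothetical helper below goes one step further: it skips names that are absent from the tensors (assuming, which the issue does not confirm, that some embedding configurations do not produce every input, e.g. `word_ids` or `overflow_to_sample_mapping`) and raises on unexpected names instead of silently dropping them. It is a sketch, not Flair's code.

```python
from typing import Dict, Iterable, TypeVar

T = TypeVar("T")


def reorder_inputs(tensors: Dict[str, T], desired_order: Iterable[str]) -> Dict[str, T]:
    """Return a new dict whose keys follow desired_order (hypothetical helper).

    Names in desired_order that are missing from tensors are skipped; keys of
    tensors that are not in desired_order raise, so nothing is dropped unnoticed.
    """
    order = list(desired_order)
    unexpected = [k for k in tensors if k not in order]
    if unexpected:
        raise ValueError(f"inputs not covered by desired_order: {unexpected}")
    # Insertion order of the new dict defines the positional order of the args.
    return {k: tensors[k] for k in order if k in tensors}


if __name__ == "__main__":
    desired_order = ["input_ids", "lengths", "attention_mask",
                     "overflow_to_sample_mapping", "word_ids"]
    example = {"attention_mask": "m", "input_ids": "i", "lengths": "l"}
    print(list(reorder_inputs(example, desired_order)))
    # → ['input_ids', 'lengths', 'attention_mask']
```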

While we are at it, I would also suggest fixing the type of the args argument passed to torch.onnx.export, which should be a tuple:

        torch.onnx.export(
            embedding,
            tuple(example_tensors.values()),
            ...
        )
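One way to keep the positional args and their names from drifting apart again is to derive both from the same ordered mapping. The helper below is hypothetical (the real call's remaining arguments are elided above, so how Flair passes `input_names` is not shown); it only splits an ordered dict into an args tuple and a matching names list.

```python
from typing import Dict, List, Tuple, TypeVar

T = TypeVar("T")


def export_args(ordered: Dict[str, T]) -> Tuple[Tuple[T, ...], List[str]]:
    """Split an ordered name -> tensor dict into (args tuple, input_names list).

    Both come from the same dict, so position i in args always matches names[i].
    """
    return tuple(ordered.values()), list(ordered.keys())


if __name__ == "__main__":
    args, names = export_args({"input_ids": "i", "lengths": "l"})
    print(args, names)  # → ('i', 'l') ['input_ids', 'lengths']
    # These could then be passed together, e.g. (path is a placeholder):
    # torch.onnx.export(embedding, args, path, input_names=names, ...)
```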

Metadata


Assignees

No one assigned

    Labels

    bug (Something isn't working), wontfix (This will not be worked on)

