
Question answering pipeline: error for long text sequences when max_seq_len is specified #17241

@ATroxler

Description


System Info

- `transformers` version: 4.17.0
- Platform: Linux-5.4.188+-x86_64-with-Ubuntu-18.04-bionic
- Python version: 3.7.13
- PyTorch version (GPU?): 1.11.0+cu113 (False)
- Tensorflow version (GPU?): 2.8.0 (False)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using GPU in script?: no
- Using distributed or parallel set-up in script?: no

Who can help?

@Narsil

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

code:

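# Colab/Jupyter cell; the "!" lines are notebook shell commands.
#!pip install transformers==4.16.0   # with 4.16.0 the call below succeeds
!pip install transformers==4.17.0     # with 4.17.0 it raises the RuntimeError below
from transformers import pipeline
# 100 repetitions of the sentence: about 1,000 tokens, well over the default model's 512 positions
context = 100 * "The quick brown fox jumps over the lazy dog. "
qa_pipeline = pipeline("question-answering", max_seq_len=2000)
qa_pipeline(question="what does the fox do?", context=context)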
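For context on what `max_seq_len` is meant to control: the question-answering pipeline documents it as the maximum length, in tokens, of each chunk (question plus context) passed to the model, with the context split into overlapping chunks (`doc_stride` sets the overlap) when it does not fit. Below is a minimal pure-Python sketch of that splitting on plain token ids. It assumes a 512-position limit (the size reported in the traceback), three special tokens per chunk as in `[CLS] question [SEP] context [SEP]`, a 6-token question and a 1,000-token context. `clamp_max_seq_len` and `split_into_windows` are hypothetical helpers for illustration, not the pipeline's actual code.

```python
from typing import List

# Assumption: 512 matches "tensor b (512)" in the traceback for the default model.
MODEL_MAX_POSITIONS = 512


def clamp_max_seq_len(requested: int, model_max_positions: int = MODEL_MAX_POSITIONS) -> int:
    """Hypothetical helper: never request more tokens per chunk than the
    model has position embeddings for."""
    return min(requested, model_max_positions)


def split_into_windows(
    context_ids: List[int],
    question_len: int,
    max_seq_len: int,
    doc_stride: int,
    num_special_tokens: int = 3,  # assumed: [CLS] question [SEP] context [SEP]
) -> List[List[int]]:
    """Hypothetical helper: split context token ids into windows that each fit,
    together with the question and special tokens, into max_seq_len tokens.
    Consecutive windows share doc_stride tokens of overlap."""
    window = max_seq_len - question_len - num_special_tokens
    if window <= doc_stride:
        raise ValueError("max_seq_len leaves no room beyond the doc_stride overlap")
    step = window - doc_stride
    windows = []
    start = 0
    while True:
        windows.append(context_ids[start:start + window])
        if start + window >= len(context_ids):
            break  # this window reaches the end of the context
        start += step
    return windows


if __name__ == "__main__":
    context_ids = list(range(1000))  # stand-in for ~1,000 context token ids
    clamped = split_into_windows(context_ids, 6, clamp_max_seq_len(2000), 128)
    print([len(w) for w in clamped])  # [503, 503, 250]
    unclamped = split_into_windows(context_ids, 6, 2000, 128)
    print([len(w) + 6 + 3 for w in unclamped])  # [1009]
```

Under these assumptions, an unclamped `max_seq_len=2000` leaves room for the whole context in a single chunk of 1,000 + 6 + 3 = 1009 tokens. That matches the sequence length in the error message. Clamping to 512 first yields three overlapping windows that each fit within the model's positions. Splitting alone does not help if `max_seq_len` itself exceeds what the model can embed.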

exception traceback:

No model was supplied, defaulted to distilbert-base-cased-distilled-squad (https://huggingface.co/distilbert-base-cased-distilled-squad)
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-4-d1e4a860038f> in <module>()
      1 qa_pipeline = pipeline("question-answering", max_seq_len=2000)
----> 2 qa_pipeline(question="what does the fox do?", context=context)

/usr/local/lib/python3.7/dist-packages/transformers/pipelines/question_answering.py in __call__(self, *args, **kwargs)
    249         examples = self._args_parser(*args, **kwargs)
    250         if len(examples) == 1:
--> 251             return super().__call__(examples[0], **kwargs)
    252         return super().__call__(examples, **kwargs)
    253 

/usr/local/lib/python3.7/dist-packages/transformers/pipelines/base.py in __call__(self, inputs, num_workers, batch_size, *args, **kwargs)
   1025             return self.iterate(inputs, preprocess_params, forward_params, postprocess_params)
   1026         else:
-> 1027             return self.run_single(inputs, preprocess_params, forward_params, postprocess_params)
   1028 
   1029     def run_multi(self, inputs, preprocess_params, forward_params, postprocess_params):

/usr/local/lib/python3.7/dist-packages/transformers/pipelines/base.py in run_single(self, inputs, preprocess_params, forward_params, postprocess_params)
   1047         all_outputs = []
   1048         for model_inputs in self.preprocess(inputs, **preprocess_params):
-> 1049             model_outputs = self.forward(model_inputs, **forward_params)
   1050             all_outputs.append(model_outputs)
   1051         outputs = self.postprocess(all_outputs, **postprocess_params)

/usr/local/lib/python3.7/dist-packages/transformers/pipelines/base.py in forward(self, model_inputs, **forward_params)
    942                 with inference_context():
    943                     model_inputs = self._ensure_tensor_on_device(model_inputs, device=self.device)
--> 944                     model_outputs = self._forward(model_inputs, **forward_params)
    945                     model_outputs = self._ensure_tensor_on_device(model_outputs, device=torch.device("cpu"))
    946             else:

/usr/local/lib/python3.7/dist-packages/transformers/pipelines/question_answering.py in _forward(self, inputs)
    369         example = inputs["example"]
    370         model_inputs = {k: inputs[k] for k in self.tokenizer.model_input_names}
--> 371         start, end = self.model(**model_inputs)[:2]
    372         return {"start": start, "end": end, "example": example, **inputs}
    373 

/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1108         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1109                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110             return forward_call(*input, **kwargs)
   1111         # Do not call functions when jit is used
   1112         full_backward_hooks, non_full_backward_hooks = [], []

/usr/local/lib/python3.7/dist-packages/transformers/models/distilbert/modeling_distilbert.py in forward(self, input_ids, attention_mask, head_mask, inputs_embeds, start_positions, end_positions, output_attentions, output_hidden_states, return_dict)
    853             output_attentions=output_attentions,
    854             output_hidden_states=output_hidden_states,
--> 855             return_dict=return_dict,
    856         )
    857         hidden_states = distilbert_output[0]  # (bs, max_query_len, dim)

/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1108         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1109                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110             return forward_call(*input, **kwargs)
   1111         # Do not call functions when jit is used
   1112         full_backward_hooks, non_full_backward_hooks = [], []

/usr/local/lib/python3.7/dist-packages/transformers/models/distilbert/modeling_distilbert.py in forward(self, input_ids, attention_mask, head_mask, inputs_embeds, output_attentions, output_hidden_states, return_dict)
    546 
    547         if inputs_embeds is None:
--> 548             inputs_embeds = self.embeddings(input_ids)  # (bs, seq_length, dim)
    549         return self.transformer(
    550             x=inputs_embeds,

/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1108         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1109                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110             return forward_call(*input, **kwargs)
   1111         # Do not call functions when jit is used
   1112         full_backward_hooks, non_full_backward_hooks = [], []

/usr/local/lib/python3.7/dist-packages/transformers/models/distilbert/modeling_distilbert.py in forward(self, input_ids)
    131         position_embeddings = self.position_embeddings(position_ids)  # (bs, max_seq_length, dim)
    132 
--> 133         embeddings = word_embeddings + position_embeddings  # (bs, max_seq_length, dim)
    134         embeddings = self.LayerNorm(embeddings)  # (bs, max_seq_length, dim)
    135         embeddings = self.dropout(embeddings)  # (bs, max_seq_length, dim)

RuntimeError: The size of tensor a (1009) must match the size of tensor b (512) at non-singleton dimension 1
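The last frame fails on an element-wise addition. According to the error message, the word embeddings cover all 1009 input tokens (tensor a), while the position embeddings cover only 512 (tensor b). Below is a minimal NumPy sketch of the same shape mismatch. It uses zero arrays and a placeholder hidden size, so it mirrors only the shapes in the error, not DistilBERT's actual embedding code.

```python
import numpy as np

SEQ_LEN = 1009       # tensor a in the error: question + context + special tokens
MAX_POSITIONS = 512  # tensor b in the error: position embeddings available
DIM = 8              # placeholder; distilbert-base uses a hidden size of 768

word_embeddings = np.zeros((1, SEQ_LEN, DIM))
position_embeddings = np.zeros((1, MAX_POSITIONS, DIM))


def try_add(a: np.ndarray, b: np.ndarray) -> str:
    """Mimic `word_embeddings + position_embeddings` from the last frame."""
    try:
        return f"ok: {(a + b).shape}"
    except ValueError:
        # NumPy raises ValueError where PyTorch raises RuntimeError
        return f"mismatch at dim 1: {a.shape[1]} vs {b.shape[1]}"


print(try_add(word_embeddings, position_embeddings))  # mismatch at dim 1: 1009 vs 512
print(try_add(word_embeddings[:, :MAX_POSITIONS], position_embeddings))  # ok: (1, 512, 8)
```

Dimension 1 is neither equal nor 1 in both operands, so it cannot broadcast. Every chunk sent to the model therefore has to stay within the 512 available positions.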

Expected behavior

The call should run through and return a result similar to the following, as it does with transformers 4.16.0:

{'answer': 'The quick brown fox jumps over the lazy dog',
 'end': 3418,
 'score': 0.017251048237085342,
 'start': 3375}
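In this output, `start` and `end` are character offsets into `context`. Slicing the reproduction's context with them recovers the answer text. It also shows that the answer comes from the 76th of the 100 repetitions, about three quarters of the way into a context that tokenizes to roughly 1,000 tokens (1009 including the question, per the error message). This suggests that 4.16.0 scored parts of the context beyond the first 512 token positions, presumably by splitting it into chunks. The check below uses only the values printed above.

```python
# Rebuild the context exactly as in the reproduction.
SENTENCE = "The quick brown fox jumps over the lazy dog. "
context = 100 * SENTENCE

# Offsets copied from the expected (4.16.0) output.
start, end = 3375, 3418
answer = context[start:end]

print(repr(answer))            # 'The quick brown fox jumps over the lazy dog'
print(len(SENTENCE))           # 45
print(start // len(SENTENCE))  # 75, the zero-based index of the repetition
```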
