
[Bug]: Reduce transformer vocab of XLM-RoBERTa #3368

@anna-shopova

Description

Describe the bug

I encounter an IndexError when setting reduce_transformer_vocab=True in the train() method and using TransformerWordEmbeddings with model='xlm-roberta-base'.

To Reproduce

import torch  # needed for torch.optim.AdamW below

import flair
from flair.data import Corpus
from flair.datasets import ColumnCorpus
from flair.embeddings import TransformerWordEmbeddings
from flair.models import SpanClassifier
from flair.trainers import ModelTrainer


# data_folder and columns describe a local CoNLL-style dataset (not included here)
corpus: Corpus = ColumnCorpus(data_folder, columns,
                              train_file='train.txt',
                              test_file='test.txt',
                              dev_file='dev.txt',
                              document_separator_token="-DOCSTART-")

label_dictionary = corpus.make_label_dictionary('nel')

corpus.filter_empty_sentences()
corpus.filter_long_sentences(800)

embeddings = TransformerWordEmbeddings(
    model='xlm-roberta-base',
    layers="-1",
    subtoken_pooling="first",
    fine_tune=True,
    use_context=True
)

tagger = SpanClassifier(
    embeddings=embeddings,
    label_dictionary=label_dictionary
)

trainer = ModelTrainer(tagger, corpus)

trainer.train('./results/',
              learning_rate=5.0e-5,
              optimizer=torch.optim.AdamW,
              mini_batch_size=1,
              train_with_dev=False,
              shuffle=False,
              reduce_transformer_vocab=True
              )

Expected behavior

When setting reduce_transformer_vocab=True in the train() method and using TransformerWordEmbeddings with model='xlm-roberta-base', training should proceed without errors and use the reduced transformer vocabulary as intended.
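
For reference, here is a minimal sketch of what reducing a transformer vocabulary implies (my own illustration with plain Transformers/PyTorch, not Flair's actual implementation): the embedding matrix is sliced down to the token IDs that occur in the data, so every ID fed to the model afterwards must be remapped into that smaller range. If original XLM-R token IDs (up to ~250k) ever reach the reduced embedding, an IndexError like the one below is the expected failure mode.

import torch
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")
model = AutoModel.from_pretrained("xlm-roberta-base")

# token IDs observed in a tiny, hypothetical corpus
used_ids = sorted(set(tokenizer("Ein kleines Beispiel .")["input_ids"]))

# shrink the embedding matrix to just those IDs
old_embedding = model.get_input_embeddings()
reduced_weight = old_embedding.weight.detach().clone()[used_ids]
model.set_input_embeddings(torch.nn.Embedding.from_pretrained(reduced_weight, freeze=False))

# after the reduction, inputs must use remapped IDs in [0, len(used_ids));
# feeding the original IDs raises "IndexError: index out of range in self"
id_map = {old: new for new, old in enumerate(used_ids)}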

Logs and Stack traces

File /flair/trainers/trainer.py in train(self, base_path, anneal_factor, patience, min_learning_rate, initial_extra_patience, anneal_with_restarts, learning_rate, decoder_learning_rate, mini_batch_size, eval_batch_size, mini_batch_chunk_size, max_epochs, optimizer, train_with_dev, train_with_test, reduce_transformer_vocab, main_evaluation_metric, monitor_test, monitor_train_sample, use_final_model_for_eval, gold_label_dictionary_for_eval, exclude_labels, sampler, shuffle, shuffle_first_epoch, embeddings_storage_mode, epoch, save_final_model, save_optimizer_state, save_model_each_k_epochs, create_file_logs, create_loss_file, write_weights, plugins, attach_default_scheduler, **kwargs)
    198         ]:
    199             local_variables.pop(var)
--> 200         return self.train_custom(**local_variables, **kwargs)
    201 
    202     def fine_tune(

File /flair/trainers/trainer.py in train_custom(self, base_path, learning_rate, decoder_learning_rate, mini_batch_size, eval_batch_size, mini_batch_chunk_size, max_epochs, optimizer, train_with_dev, train_with_test, max_grad_norm, reduce_transformer_vocab, main_evaluation_metric, monitor_test, monitor_train_sample, use_final_model_for_eval, gold_label_dictionary_for_eval, exclude_labels, sampler, shuffle, shuffle_first_epoch, embeddings_storage_mode, epoch, save_final_model, save_optimizer_state, save_model_each_k_epochs, create_file_logs, create_loss_file, write_weights, use_amp, plugins, **kwargs)
    596                             # forward pass
    597                             with torch.autocast(device_type=flair.device.type, enabled=use_amp):
--> 598                                 loss, datapoint_count = self.model.forward_loss(batch_step)
    599 
    600                             batch_train_samples += datapoint_count

File /flair/nn/model.py in forward_loss(self, sentences)
    745 
    746         # pass data points through network to get encoded data point tensor
--> 747         data_point_tensor = self._encode_data_points(sentences, data_points)
    748 
    749         # decode

File /flair/nn/model.py in _encode_data_points(self, sentences, data_points)
    713         # embed sentences
    714         if self.should_embed_sentence:
--> 715             self.embeddings.embed(sentences)
    716 
    717         # get a tensor of data points

File /flair/embeddings/base.py in embed(self, data_points)
     48 
     49         if not self._everything_embedded(data_points):
---> 50             self._add_embeddings_internal(data_points)
     51 
     52         return data_points

File /flair/embeddings/transformer.py in _add_embeddings_internal(self, sentences)
    703         gradient_context = torch.enable_grad() if (self.fine_tune and self.training) else torch.no_grad()
    704         with gradient_context:
--> 705             embeddings = self._forward_tensors(tensors)
    706 
    707         if self.document_embedding:

File /flair/embeddings/transformer.py in _forward_tensors(self, tensors)
   1422 
   1423     def _forward_tensors(self, tensors) -> Dict[str, torch.Tensor]:
-> 1424         return self.forward(**tensors)
   1425 
   1426     def export_onnx(

File /flair/embeddings/transformer.py in forward(self, input_ids, sub_token_lengths, token_lengths, attention_mask, overflow_to_sample_mapping, word_ids, langs, bbox, pixel_values)
   1322         if pixel_values is not None:
   1323             model_kwargs["pixel_values"] = pixel_values
-> 1324         hidden_states = self.model(input_ids, **model_kwargs)[-1]
   1325         # make the tuple a tensor; makes working with it easier.
   1326         hidden_states = torch.stack(hidden_states)

File /torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1516             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517         else:
-> 1518             return self._call_impl(*args, **kwargs)
   1519 
   1520     def _call_impl(self, *args, **kwargs):

File /torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1525                 or _global_backward_pre_hooks or _global_backward_hooks
   1526                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527             return forward_call(*args, **kwargs)
   1528 
   1529         try:

File /transformers/models/xlm_roberta/modeling_xlm_roberta.py in forward(self, input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, encoder_hidden_states, encoder_attention_mask, past_key_values, use_cache, output_attentions, output_hidden_states, return_dict)
    837         head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
    838 
--> 839         embedding_output = self.embeddings(
    840             input_ids=input_ids,
    841             position_ids=position_ids,

File /torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1516             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517         else:
-> 1518             return self._call_impl(*args, **kwargs)
   1519 
   1520     def _call_impl(self, *args, **kwargs):

File /torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1525                 or _global_backward_pre_hooks or _global_backward_hooks
   1526                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527             return forward_call(*args, **kwargs)
   1528 
   1529         try:

File /transformers/models/xlm_roberta/modeling_xlm_roberta.py in forward(self, input_ids, token_type_ids, position_ids, inputs_embeds, past_key_values_length)
    124 
    125         if inputs_embeds is None:
--> 126             inputs_embeds = self.word_embeddings(input_ids)
    127         token_type_embeddings = self.token_type_embeddings(token_type_ids)
    128 

File /torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1516             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517         else:
-> 1518             return self._call_impl(*args, **kwargs)
   1519 
   1520     def _call_impl(self, *args, **kwargs):

File /torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1525                 or _global_backward_pre_hooks or _global_backward_hooks
   1526                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527             return forward_call(*args, **kwargs)
   1528 
   1529         try:

File /torch/nn/modules/sparse.py in forward(self, input)
    160 
    161     def forward(self, input: Tensor) -> Tensor:
--> 162         return F.embedding(
    163             input, self.weight, self.padding_idx, self.max_norm,
    164             self.norm_type, self.scale_grad_by_freq, self.sparse)

File /torch/nn/functional.py in embedding(input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse)
   2231         # remove once script supports set_grad_enabled
   2232         _no_grad_embedding_renorm_(weight, input, max_norm, norm_type)
-> 2233     return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
   2234 
   2235 

IndexError: index out of range in self
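
The error pattern above suggests that some input IDs are larger than the (possibly reduced) embedding matrix. A quick way to check this is the sketch below; it assumes the Flair embeddings object exposes the underlying Hugging Face model and tokenizer as .model and .tokenizer, and reuses the embeddings variable from the reproduction script.

# diagnostic sketch: compare the largest produced input ID with the embedding size
emb_layer = embeddings.model.get_input_embeddings()
encoded = embeddings.tokenizer("Ein kurzer Testsatz .", return_tensors="pt")
max_id = int(encoded["input_ids"].max())
print(f"num_embeddings={emb_layer.num_embeddings}, max input id={max_id}")
if max_id >= emb_layer.num_embeddings:
    print("input IDs exceed the embedding size -> the IndexError above is expected")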

Screenshots

No response

Additional Context

No response

Environment

Versions:

Flair: 0.13.0
Pytorch: 2.1.0
Transformers: 4.34.1
GPU: False
