
[Bug]: TextPairRegressor has state dict key mismatch for embedding #3536

@MattGPT-ai

Description

Describe the bug

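TextPairRegressor fails to initialize from a state dict created by its own _get_state_dict because of a mismatch between the keys 'embeddings' and 'document_embeddings'. The state dict stores the embeddings under 'document_embeddings', _init_model_with_state_dict forwards that key to the constructor as a keyword argument, and TextPairRegressor.__init__ does not accept it.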

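It appears the logic was copied from TextPairClassifier, which inherits from DefaultClassifier. TextPairRegressor, by contrast, hands off directly to Model._init_model_with_state_dict (see the traceback below), which handles the 'embeddings' key but not 'document_embeddings'.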
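To make the mechanism concrete, here is a minimal, self-contained sketch. The classes and helper below (Model, _RegressorBase, BrokenRegressor, FixedRegressor, round_trip) are hypothetical stand-ins, not flair's actual code; they only mimic the kwargs handling visible in the traceback (the loop at pairwise_regression_model.py lines 218-219 and the kwargs["embeddings"] assignment in flair/nn/model.py). Saving the embeddings under the same key as the __init__ parameter, as FixedRegressor does, is one possible direction for a fix, not necessarily the one the maintainers would choose.

```python
# Hypothetical stand-ins that mimic only the kwargs plumbing seen in the
# traceback. These are NOT flair's real classes.


class Model:
    @classmethod
    def _init_model_with_state_dict(cls, state, **kwargs):
        # Mimics flair/nn/model.py: the base class looks for "embeddings" only.
        if "embeddings" not in kwargs and "embeddings" in state:
            kwargs["embeddings"] = state["embeddings"]
        return cls(**kwargs)


class _RegressorBase(Model):
    # Key under which _get_state_dict stores the embeddings (assumption).
    EMBEDDINGS_KEY = "embeddings"

    def __init__(self, embeddings, label_type):
        self.embeddings = embeddings
        self.label_type = label_type

    def _get_state_dict(self):
        return {self.EMBEDDINGS_KEY: self.embeddings, "label_type": self.label_type}

    @classmethod
    def _init_model_with_state_dict(cls, state, **kwargs):
        # Mimics pairwise_regression_model.py lines 218-219: copy saved
        # arguments from the state dict into the constructor kwargs.
        for arg in [cls.EMBEDDINGS_KEY, "label_type"]:
            if arg not in kwargs and arg in state:
                kwargs[arg] = state[arg]
        return super()._init_model_with_state_dict(state, **kwargs)


class BrokenRegressor(_RegressorBase):
    # Saved key does not match the __init__ parameter name -> TypeError on load.
    EMBEDDINGS_KEY = "document_embeddings"


class FixedRegressor(_RegressorBase):
    # Saved key matches the __init__ parameter name -> round trip succeeds.
    EMBEDDINGS_KEY = "embeddings"


def round_trip(cls):
    """Build a model, save its state dict, and rebuild a model from it."""
    model = cls("emb", label_type="test")
    return cls._init_model_with_state_dict(model._get_state_dict())


if __name__ == "__main__":
    for model_cls in (BrokenRegressor, FixedRegressor):
        try:
            reloaded = round_trip(model_cls)
            print(model_cls.__name__, "reloaded with", reloaded.embeddings)
        except TypeError as err:
            print(model_cls.__name__, "failed:", err)
```

BrokenRegressor fails with the same kind of TypeError as in the report ("unexpected keyword argument 'document_embeddings'"), while FixedRegressor reloads cleanly because the saved key, the forwarded kwarg, and the constructor parameter all use one name.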

To Reproduce

from flair.models import TextPairRegressor
from flair.embeddings import TransformerDocumentEmbeddings
from flair.data import Sentence
import torch

document_embeddings = TransformerDocumentEmbeddings('bert-base-uncased')

model = TextPairRegressor(document_embeddings, label_type='test')

model._init_model_with_state_dict(model._get_state_dict())

Expected behavior

Calling _init_model_with_state_dict on the output of _get_state_dict should succeed and return a TextPairRegressor with the same configuration and weights.

Logs and Stack traces

/pyzr/active_venv/lib/python3.12/site-packages/transformers/tokenization_utils_base.py:1601: FutureWarning: `clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be depracted in transformers v4.45, and will be then set to `False` by default. For more details check this issue: https://github.com/huggingface/transformers/issues/31884
  warnings.warn(
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[17], line 10
      6 document_embeddings = TransformerDocumentEmbeddings('bert-base-uncased')
      8 model = TextPairRegressor(document_embeddings, label_type='test')
---> 10 model._init_model_with_state_dict(model._get_state_dict())

File /pyzr/active_venv/lib/python3.12/site-packages/flair/models/pairwise_regression_model.py:221, in TextPairRegressor._init_model_with_state_dict(cls, state, **kwargs)
    218     if arg not in kwargs and arg in state:
    219         kwargs[arg] = state[arg]
--> 221 return super()._init_model_with_state_dict(state, **kwargs)

File /pyzr/active_venv/lib/python3.12/site-packages/flair/nn/model.py:103, in Model._init_model_with_state_dict(cls, state, **kwargs)
    100         embeddings = load_embeddings(embeddings)
    101     kwargs["embeddings"] = embeddings
--> 103 model = cls(**kwargs)
    105 model.load_state_dict(state["state_dict"])
    107 return model

TypeError: TextPairRegressor.__init__() got an unexpected keyword argument 'document_embeddings'


Environment

Versions:

Flair: 0.14.0
PyTorch: 2.4.0+cu121
Transformers: 4.44.1
GPU: False

Labels: bug (Something isn't working)