Closed
Labels
bug (Something isn't working)
Description
Describe the bug
The TextPairRegressor fails to initialize from a state dict created by _get_state_dict because of a mismatch between the keys 'embeddings' and 'document_embeddings'. The logic appears to have been copied from TextPairClassifier, which inherits from DefaultClassifier.
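The mismatch can be illustrated without Flair by using a stand-in class. This is only a minimal sketch, not the library code: `FakeRegressor`, `get_state`, and `from_state` are hypothetical names that mirror the shape suggested by the traceback, assuming the state dict stores the embeddings under 'document_embeddings' while `__init__` accepts `embeddings`.

```python
class FakeRegressor:
    """Hypothetical stand-in: the saved key and the __init__ parameter differ."""

    def __init__(self, embeddings, label_type):
        self.embeddings = embeddings
        self.label_type = label_type

    def get_state(self):
        # Saved under 'document_embeddings', although __init__ expects 'embeddings'.
        return {"document_embeddings": self.embeddings, "label_type": self.label_type}

    @classmethod
    def from_state(cls, state):
        # Forward the saved keys verbatim as constructor keyword arguments.
        return cls(**state)


model = FakeRegressor(embeddings="emb", label_type="test")
try:
    FakeRegressor.from_state(model.get_state())
    error_message = None
except TypeError as exc:
    # The constructor rejects the key it does not know about.
    error_message = str(exc)
print(error_message)
```

The round trip fails with the same kind of TypeError as in the stack trace, because the key written when saving is not a parameter name the constructor accepts.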
To Reproduce
from flair.models import TextPairRegressor
from flair.embeddings import TransformerDocumentEmbeddings
from flair.data import Sentence
import torch
document_embeddings = TransformerDocumentEmbeddings('bert-base-uncased')
model = TextPairRegressor(document_embeddings, label_type='test')
model._init_model_with_state_dict(model._get_state_dict())
Expected behavior
This call should succeed and return a model loaded from the state dict.
Logs and Stack traces
/pyzr/active_venv/lib/python3.12/site-packages/transformers/tokenization_utils_base.py:1601: FutureWarning: `clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be depracted in transformers v4.45, and will be then set to `False` by default. For more details check this issue: https://github.com/huggingface/transformers/issues/31884
warnings.warn(
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[17], line 10
6 document_embeddings = TransformerDocumentEmbeddings('bert-base-uncased')
8 model = TextPairRegressor(document_embeddings, label_type='test')
---> 10 model._init_model_with_state_dict(model._get_state_dict())
File /pyzr/active_venv/lib/python3.12/site-packages/flair/models/pairwise_regression_model.py:221, in TextPairRegressor._init_model_with_state_dict(cls, state, **kwargs)
218 if arg not in kwargs and arg in state:
219 kwargs[arg] = state[arg]
--> 221 return super()._init_model_with_state_dict(state, **kwargs)
File /pyzr/active_venv/lib/python3.12/site-packages/flair/nn/model.py:103, in Model._init_model_with_state_dict(cls, state, **kwargs)
100 embeddings = load_embeddings(embeddings)
101 kwargs["embeddings"] = embeddings
--> 103 model = cls(**kwargs)
105 model.load_state_dict(state["state_dict"])
107 return model
TypeError: TextPairRegressor.__init__() got an unexpected keyword argument 'document_embeddings'
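One possible direction for a fix is to rename the saved key to the parameter name before the constructor is called. The sketch below shows the idea on a stand-in class; `rename_kwarg` and `FakeRegressor` are hypothetical and do not exist in Flair, and whether the real fix belongs in _get_state_dict or in _init_model_with_state_dict is an assumption left open here.

```python
def rename_kwarg(kwargs, old, new):
    """Return a copy of kwargs with key `old` renamed to `new` (hypothetical helper)."""
    fixed = dict(kwargs)
    if old in fixed and new not in fixed:
        fixed[new] = fixed.pop(old)
    return fixed


class FakeRegressor:
    """Hypothetical stand-in whose __init__ expects 'embeddings'."""

    def __init__(self, embeddings, label_type):
        self.embeddings = embeddings
        self.label_type = label_type

    @classmethod
    def from_state(cls, state):
        # Map the saved key onto the constructor's parameter name first.
        kwargs = rename_kwarg(state, "document_embeddings", "embeddings")
        return cls(**kwargs)


state = {"document_embeddings": "emb", "label_type": "test"}
restored = FakeRegressor.from_state(state)
print(restored.embeddings, restored.label_type)
```

The helper copies the dict rather than mutating it, so the saved state stays unchanged and state dicts that still use the old key remain loadable.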
Screenshots
No response
Additional Context
No response
Environment
Versions:
Flair: 0.14.0
Pytorch: 2.4.0+cu121
Transformers: 4.44.1
GPU: False