
Large differences between T5 weight initialization in TF and torch #16749

@jorgemcgomes

Description

  • transformers version: 4.18.0, master branch

Who can help

@patrickvonplaten

I found some significant differences in weight init between the PT and TF implementations of T5.

The embeddings (model.shared):

  • In PT, according to T5PreTrainedModel._init_weights, they are initialized with random normal with std=1.0:
    module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0)

  • In TF (TFT5Model), the embeddings are initialized as follows:
    self.shared = TFSharedEmbeddings(config.vocab_size, config.d_model, name="shared")
    Since initializer_range is not being provided, it is using the default, which is hidden_size**-0.5 (see TFSharedEmbeddings).

This means that in the base model (d_model=768), the embedding weights in PT are initialized with stdev=1.0, while in TF they are initialized with stdev=768**-0.5 ≈ 0.036.
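To make the gap concrete, here is a minimal sketch of the two standard deviations described above. It uses only the Python standard library rather than the actual PT/TF initializers; the helper names are hypothetical, and `INITIALIZER_FACTOR = 1.0` is assumed to be the value of `config.initializer_factor`.

```python
import random
import statistics

D_MODEL = 768              # hidden size of the base model, as in the issue
INITIALIZER_FACTOR = 1.0   # assumption: value of config.initializer_factor


def pt_embedding_std(factor=INITIALIZER_FACTOR):
    # Mirrors normal_(mean=0.0, std=factor * 1.0) from _init_weights
    return factor * 1.0


def tf_embedding_std(d_model=D_MODEL):
    # Default used when initializer_range is not passed: hidden_size ** -0.5
    return d_model ** -0.5


def sample_std(std, n=50_000, seed=0):
    # Empirical check: draw n normal samples and measure their spread
    rng = random.Random(seed)
    return statistics.pstdev([rng.gauss(0.0, std) for _ in range(n)])


print(f"PT embedding std: {pt_embedding_std():.4f}")
print(f"TF embedding std: {tf_embedding_std():.4f}")
print(f"ratio PT/TF:      {pt_embedding_std() / tf_embedding_std():.1f}")
print(f"sampled TF std:   {sample_std(tf_embedding_std()):.4f}")
```

Under these assumptions the PT embeddings start out roughly sqrt(768) ≈ 27.7 times larger than the TF ones.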

The LM head (model.lm_head):

  • In PT, the initializer is not specified, meaning it is being initialized with a uniform distribution in [-sqrt(1/d_model), sqrt(1/d_model)] (https://pytorch.org/docs/stable/generated/torch.nn.Linear.html). The weights don't seem to be initialized in _init_weights either.
    lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

  • In TF, the initializer is explicitly provided (TFT5ForConditionalGeneration):
    lm_head_initializer = tf.keras.initializers.RandomNormal(mean=0, stddev=config.initializer_factor)

So, in the base model, the lm_head weights in PT are initialized from a uniform distribution over [-0.036, 0.036], while in TF they are initialized from a normal distribution with stdev=1.0.
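For a like-for-like comparison, the uniform bound can be converted to a standard deviation, since U(-a, a) has standard deviation a / sqrt(3). The sketch below is plain arithmetic under that identity, with hypothetical helper names; it does not call torch or TF.

```python
import math

D_MODEL = 768  # in_features of lm_head in the base model


def linear_default_bound(in_features):
    # Bound documented for nn.Linear: sqrt(1 / in_features)
    return math.sqrt(1.0 / in_features)


def uniform_std(bound):
    # Standard deviation of a uniform distribution over [-bound, bound]
    return bound / math.sqrt(3.0)


bound = linear_default_bound(D_MODEL)
pt_std = uniform_std(bound)
tf_std = 1.0  # stddev=config.initializer_factor, assumed to be 1.0

print(f"PT uniform bound:  {bound:.4f}")
print(f"PT equivalent std: {pt_std:.4f}")
print(f"TF std:            {tf_std:.4f}")
print(f"ratio TF/PT:       {tf_std / pt_std:.1f}")
```

So, measured by standard deviation, the PT lm_head starts at about 0.021 versus 1.0 in TF, a factor of roughly 48.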

I'm not entirely sure about the actual implications of this for model training. But at least the lm_head weights will have a huge impact on the initial loss values.
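A rough way to see the effect on the initial loss is to push random hidden states through a randomly initialized output projection and compute the cross-entropy. This is a toy sketch, not the T5 forward pass: it assumes unit-variance hidden states, no rescaling before the head, a hypothetical vocabulary of 200 tokens to keep it fast, and normally distributed weights for both scales (the PT default is uniform at a similar scale).

```python
import math
import random


def initial_loss(weight_std, d_model=768, vocab=200, n_tokens=5, seed=0):
    """Mean cross-entropy of a random linear head on random inputs."""
    rng = random.Random(seed)
    # lm_head weights: one row of d_model entries per vocabulary item
    weights = [[rng.gauss(0.0, weight_std) for _ in range(d_model)]
               for _ in range(vocab)]
    total = 0.0
    for _ in range(n_tokens):
        # Assumption: hidden states have roughly unit variance
        hidden = [rng.gauss(0.0, 1.0) for _ in range(d_model)]
        logits = [sum(w * h for w, h in zip(row, hidden)) for row in weights]
        target = rng.randrange(vocab)
        # Numerically stable log-sum-exp
        top = max(logits)
        lse = top + math.log(sum(math.exp(z - top) for z in logits))
        total += lse - logits[target]
    return total / n_tokens


small = initial_loss(weight_std=768 ** -0.5)  # scale comparable to the PT default
large = initial_loss(weight_std=1.0)          # scale used by the TF initializer
baseline = math.log(200)                      # loss of a uniform prediction

print(f"uniform-prediction baseline: {baseline:.2f}")
print(f"loss with std=d_model**-0.5: {small:.2f}")
print(f"loss with std=1.0:           {large:.2f}")
```

In this toy setting the small-scale head stays close to the uniform-prediction baseline, while the std=1.0 head produces logits with a spread of about sqrt(768) and a far larger starting loss.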

Based on other transformer models I've seen, the "correct" answer seems to be that both weights should be initialized with stdev=1.0. But neither implementation actually does this for both.
