
FlaxT5ForConditionalGeneration: Inconsistency in Final Block Hidden State of Encoder/Decoder #23960

@ztjhz

Description

System Info

I'm implementing my own T5 model in JAX and using the FlaxT5ForConditionalGeneration module to validate my implementation against the reference.

During testing, I ran into an issue. Using the code below, I noticed that the hidden states from block 0 to block 10 of my implementation match the corresponding hidden states of the transformer model (i.e., output_flax['hidden_states'][0] to output_flax['hidden_states'][10]).

However, in the final block, the hidden state of my model doesn't match the transformer model's corresponding hidden state (output_flax['hidden_states'][11]). This is strange because after I apply the RMS layer normalization to my final block's hidden state to obtain the final hidden state, it does align with the transformer model's last hidden state (output_flax['last_hidden_state']).

As I understand it, the encoder block is simply replicated 12 times, with no special processing in the final block, so I am unclear what could be causing this inconsistency in the final block's hidden state for both the encoder and decoder.

In summary, here's what I've observed:

  • My hidden states match output_flax['hidden_states'][0] through output_flax['hidden_states'][10].
  • My hidden state does not match output_flax['hidden_states'][11] (before applying the final layer norm).
  • My final hidden state (after applying the final layer norm) matches output_flax['last_hidden_state'] (see the sketch after this list).
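
For concreteness, here is a minimal sketch of the check described above, using the model and output_flax objects created in the Reproduction section below. my_final_block_output is a placeholder for the last encoder block's output from my own implementation, and I am assuming the encoder's final layer-norm weight sits at model.params["encoder"]["final_layer_norm"]["weight"]:

import numpy as np

def t5_rms_layer_norm(x, weight, eps=1e-6):
    # T5-style layer norm: scale by the root mean square only
    # (no mean subtraction, no bias)
    x = np.asarray(x, dtype=np.float32)
    variance = np.mean(np.square(x), axis=-1, keepdims=True)
    return weight * (x / np.sqrt(variance + eps))

# my_final_block_output: output of my last encoder block, shape (batch, seq_len, d_model)
final_norm_weight = np.asarray(model.params["encoder"]["final_layer_norm"]["weight"])
my_final_hidden_state = t5_rms_layer_norm(my_final_block_output, final_norm_weight)

# This comparison succeeds, even though comparing the raw final-block outputs does not
print(np.allclose(my_final_hidden_state, output_flax["last_hidden_state"], atol=1e-5))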

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Here is the Python code I used for testing:

from transformers import AutoTokenizer, FlaxT5ForConditionalGeneration

tokenizer = AutoTokenizer.from_pretrained("t5-base")
model = FlaxT5ForConditionalGeneration.from_pretrained("allenai/unifiedqa-t5-base")

inputs = tokenizer(
    ["summarize: My friends are cool but they eat too many carbs."], return_tensors="np"
)
input_ids = inputs["input_ids"]

output_flax = model.encode(
    input_ids, output_hidden_states=True, return_dict=True, output_attentions=True
)
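
For reference, this is roughly how I am inspecting the encoder outputs (the tolerance is arbitrary, and the last check is exactly the point I am unsure about):

import numpy as np

hidden_states = output_flax["hidden_states"]
print(len(hidden_states))                      # 13 for t5-base: embeddings + one entry per block
print(hidden_states[0].shape)                  # (1, seq_len, 768)
print(output_flax["last_hidden_state"].shape)  # (1, seq_len, 768)

# Check whether the last stored hidden state is already the post-layer-norm output
print(np.allclose(hidden_states[-1], output_flax["last_hidden_state"], atol=1e-5))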

Expected behavior

I expect each block's hidden state in my implementation of the encoder/decoder to align with the corresponding block's hidden state in the transformer model.
