System Info
I'm implementing my own T5 model in JAX and using the `FlaxT5ForConditionalGeneration` module to evaluate the results of my work. During my testing phase, I ran into an issue. Using the provided code, I noticed that the hidden states from block 0 to block 10 in my implementation are consistent with the corresponding hidden states in the transformer model (i.e., `output_flax['hidden_states'][0]` to `output_flax['hidden_states'][10]`).
However, the issue arises in the final block, where my model's hidden state doesn't match the transformer model's corresponding hidden state (`output_flax['hidden_states'][11]`). This is strange because after I apply the RMS layer normalization to my final block's hidden state to get the `final_hidden_state`, it aligns with the final hidden state of the transformer model (`output_flax['last_hidden_state']`).
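For reference, this is roughly the RMS layer normalization I apply to my final block output; a minimal sketch of the T5-style norm (no mean subtraction, no bias), where the parameter path in the trailing comment is an assumption about where the pretrained scale lives and may need adjusting:

```python
import jax.numpy as jnp

def t5_rms_layer_norm(hidden_states, weight, eps=1e-6):
    # T5-style RMS norm: scale by the root mean square of the activations;
    # no mean subtraction and no bias, only a learned per-feature scale.
    variance = jnp.power(hidden_states.astype(jnp.float32), 2).mean(axis=-1, keepdims=True)
    normed = hidden_states / jnp.sqrt(variance + eps)
    return weight * normed

# Assumed parameter path for the pretrained encoder's final layer-norm scale:
# final_layer_norm_weight = model.params["encoder"]["final_layer_norm"]["weight"]
```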
According to my understanding, the encoder block is replicated 12 times without any special processing in the final block. Hence, I am unclear about what could be causing this inconsistency in the final block's hidden states for both the encoder and decoder.
In summary, here's what I've observed (a sketch of the comparison follows this list):

- My hidden state aligns with `output_flax['hidden_states'][0]` to `output_flax['hidden_states'][10]`.
- My hidden state doesn't match `output_flax['hidden_states'][11]` (before applying the final layer norm).
- My final hidden state (after applying the layer norm) aligns with `output_flax['last_hidden_state']`.
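A rough sketch of how I run this comparison; `my_block_outputs` and `final_layer_norm_weight` are placeholders for my own implementation's per-block encoder outputs and the final layer-norm scale (not library names), and `t5_rms_layer_norm` is the sketch above:

```python
import numpy as np

# my_block_outputs: placeholder list of my encoder's per-block hidden states,
# ordered from block 0 to block 11; output_flax comes from the reproduction code below.
for i, mine in enumerate(my_block_outputs):
    theirs = np.asarray(output_flax["hidden_states"][i])
    print(f"block {i:2d}: allclose={np.allclose(np.asarray(mine), theirs, atol=1e-5)}")

# Only the last block disagrees; after applying the final RMS layer norm to my
# last block output, it matches the model's last_hidden_state.
final_mine = t5_rms_layer_norm(my_block_outputs[-1], final_layer_norm_weight)
print(np.allclose(np.asarray(final_mine), np.asarray(output_flax["last_hidden_state"]), atol=1e-5))
```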
Who can help?
No response
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the `examples` folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
Here is the Python code I used for testing:
from transformers import AutoTokenizer, FlaxT5ForConditionalGeneration
tokenizer = AutoTokenizer.from_pretrained("t5-base")
model = FlaxT5ForConditionalGeneration.from_pretrained("allenai/unifiedqa-t5-base")
inputs = tokenizer(
["summarize: My friends are cool but they eat too many carbs."], return_tensors="np"
)
input_ids = inputs["input_ids"]
output_flax = model.encode(
input_ids, output_hidden_states=True, return_dict=True, output_attentions=True
)
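A small diagnostic sketch that might help locate the mismatch: it checks how many entries `hidden_states` contains and whether its last entry is already identical to `last_hidden_state` (i.e., already layer-normalized). This is only an assumption about where the discrepancy could come from, not something I have confirmed in the source:

```python
import numpy as np

hidden_states = output_flax["hidden_states"]
last_hidden_state = np.asarray(output_flax["last_hidden_state"])

# t5-base has 12 encoder blocks; the returned tuple usually also includes the
# embedding output, so more entries than blocks is expected.
print("number of hidden states returned:", len(hidden_states))

# If this prints True, the last entry of `hidden_states` already has the final
# layer norm applied, which would explain the observation above.
print(
    "hidden_states[-1] == last_hidden_state:",
    np.allclose(np.asarray(hidden_states[-1]), last_hidden_state, atol=1e-6),
)
```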
Expected behavior
I expect each block's hidden state in my implementation of the encoder/decoder to align with the corresponding block's hidden state in the transformer model.