This project is an implementation of the T5 model using JAX. It takes a functional approach, leveraging the capabilities of JAX to achieve its goals. The primary objective of this project is twofold: to offer a versatile codebase for researching Transformer-based LLM architectures, and to show how Transformer-based language models can be implemented using JAX and trained on Google Cloud TPUs.
This project is supported by Cloud TPUs from Google's TPU Research Cloud (TRC).
This project is inspired by ayaka/bart-base-jax, although the code for this project was written entirely by me.
- Setup Instructions
- Usage examples
- Discoveries
- Analysis
- 1. JAX precision
- 2. Layer normalisation
- 3. Dropout
- 4. Scaling QK matrices
- 5. Relative Attention Bias / Position embeddings
- 6. Layer norm in T5 does not subtract mean
- 7. T5 employs a final layer norm on the output of the encoder and decoder
- 8. T5 uses tied word embeddings
- 9. T5 also rescales the decoder output for tied word embedding in the language model head
- T5 Jax Implementation Results
- Fine-tuning
- Install JAX:

  ```bash
  pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
  ```
- Then install requirements.txt:

  ```bash
  pip install -r requirements.txt
  ```
- Tokenize inputs

  ```python
  from transformers import AutoTokenizer, FlaxT5ForConditionalGeneration

  model = FlaxT5ForConditionalGeneration.from_pretrained("allenai/unifiedqa-t5-base")
  tokenizer = AutoTokenizer.from_pretrained("t5-base")

  inputs = tokenizer(
      ["summarize: My friends are cool but they eat too many carbs."],
      return_tensors="np",
  )
  input_ids = inputs["input_ids"]
  ```
- Initialize model parameters

  ```python
  from utils.params_utils import init_params_pretrained

  params = init_params_pretrained()
  ```
- Encoder

  ```python
  from model.transformer_encoder import fwd_transformer_encoder

  encoder_output = fwd_transformer_encoder(
      encoder_params=params["encoder"],
      embedding_params=params["shared"],
      input_ids=input_ids,
  )
  ```
- Decoder

  ```python
  import jax.numpy as jnp

  from model.transformer_decoder import fwd_transformer_decoder

  decoder_start_token_id = model.config.decoder_start_token_id
  decoder_input_ids = (
      jnp.ones((input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id
  )
  decoder_output = fwd_transformer_decoder(
      decoder_params=params["decoder"],
      embedding_params=params["shared"],
      decoder_input_ids=decoder_input_ids,
      encoder_output=encoder_output,
  )
  ```
- Generate

  ```python
  from model.t5_generate import fwd_t5_generate
  from config import config

  sequences = fwd_t5_generate(
      params,
      encoder_input_ids=input_ids,
      eos_token_id=config.EOS_TOKEN_ID,
      decoder_start_token_id=config.DECODER_START_TOKEN_ID,
  )
  output = tokenizer.batch_decode(sequences, skip_special_tokens=True)
  ```
I discovered an issue in the Hugging Face transformers FlaxT5: its hidden-state outputs were not consistent with mine. My encoder and decoder block 11 hidden states did not align with their block 11 hidden states, even though my hidden states from blocks 0 to 10 aligned with theirs. Additionally, my final hidden state (after applying the final layer norm) did align with their final hidden state after the layer norm. I then raised an issue and made a PR to fix this.
- On TPU, JAX defaults to using `bfloat16` for matrix multiplication even when the data type is specified as `float32`. While this may speed up training, some precision is lost.
- When using a GPU, the Hugging Face transformers model exhibits slightly different numerical precision compared to JAX.
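If full `float32` matrix multiplication is needed on TPU, the default matmul precision can be raised. This is a minimal sketch using standard JAX configuration APIs (not code from this repository):

```python
import jax
import jax.numpy as jnp

# Raise the default matmul precision globally so TPU matmuls
# accumulate in float32 instead of the default bfloat16.
jax.config.update("jax_default_matmul_precision", "float32")

a = jnp.ones((128, 128), dtype=jnp.float32)
b = jnp.ones((128, 128), dtype=jnp.float32)

# Or override the precision only within a scoped region.
with jax.default_matmul_precision("float32"):
    c = a @ b

# Or request it per operation.
d = jnp.matmul(a, b, precision=jax.lax.Precision.HIGHEST)
```

Raising the precision trades some TPU speed for numerics closer to the GPU/CPU results.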
T5 performs pre-layer norm instead of post-layer norm:

- Attention: `(layer norm -> self-attention -> dropout -> add)` instead of `(self-attention -> dropout -> add -> layer norm)`
- Feed forward: `(layer norm -> DenseReluDense -> dropout -> add)` instead of `(DenseReluDense -> dropout -> add -> layer norm)`

In the feed-forward sub-layer, dropout is performed once at the end, `(linear -> linear -> dropout)`, instead of twice after each linear layer, `(linear -> dropout -> linear -> dropout)`. A sketch of this block is shown below.
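The following is a minimal sketch of such a pre-layer-norm feed-forward block in plain JAX, with hypothetical parameter names rather than this repository's exact functions; it shows the layer norm applied before the sub-layer and dropout applied only once at the end:

```python
import jax
import jax.numpy as jnp


def rms_layer_norm(weight, x, eps=1e-6):
    # T5-style layer norm (no mean subtraction); see the layer norm section below.
    return weight * x / jnp.sqrt(jnp.mean(x * x, axis=-1, keepdims=True) + eps)


def pre_ln_ff_block(params, x, rng=None, dropout_rate=0.1, deterministic=True):
    """Pre-layer-norm feed forward: layer norm -> linear -> relu -> linear -> dropout -> add."""
    h = rms_layer_norm(params["layer_norm"], x)   # layer norm comes first (pre-LN)
    h = jax.nn.relu(h @ params["wi"])             # first linear + ReLU
    h = h @ params["wo"]                          # second linear

    # Dropout is applied only once, after the second linear layer.
    if not deterministic:
        keep = jax.random.bernoulli(rng, 1.0 - dropout_rate, h.shape)
        h = jnp.where(keep, h / (1.0 - dropout_rate), 0.0)

    # Residual add; note there is no layer norm after the add.
    return x + h
```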
Hugging Face T5 does not scale the QK matrices. The T5 paper does not mention the exclusion of QK matrix scaling:

$$\text{Attention}(Q, K, V) = \text{softmax}\left(QK^T\right)V$$

instead of

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$
T5's position embeddings (relative attention bias) are different from Self-Attention with Relative Position Representations (Hugging Face's T5 implementation uses the relative attention bias). T5:

- Uses a binned relative attention bias to reduce time complexity for long sequences
- Only applies the bias before the $\text{softmax}$

It is not mentioned in the T5 paper that the bias is only added to the attention logits before the $\text{softmax}$:

$$\text{Attention}(Q, K, V) = \text{softmax}\left(QK^T + B\right)V$$

where $B$ is the binned relative attention bias, rather than incorporating relative position representations into the keys and values as in Shaw et al.
Where:

- $Q$ is the query matrix
- $K$ is the key matrix
- $V$ is the value matrix
- $d_k$ is the dimension of the keys
In the case of multi-head attention, the above process is performed multiple times with different learned linear transformations of the original $Q$, $K$, and $V$. If we have $h$ heads, then we have:

$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \dots, \text{head}_h)W_O$$

where each head is defined as:

$$\text{head}_i = \text{Attention}(QW_{Qi}, KW_{Ki}, VW_{Vi})$$
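Putting the two observations above together, here is a minimal sketch of single-head T5-style attention in plain JAX (hypothetical names, not this repository's exact functions): the logits are not divided by $\sqrt{d_k}$, and the relative attention bias is simply added to them before the softmax:

```python
import jax
import jax.numpy as jnp


def t5_style_attention(q, k, v, position_bias):
    """Single-head T5-style attention.

    q, k, v:        (seq_len, d_k) query/key/value matrices.
    position_bias:  (q_len, k_len) binned relative attention bias B.
    """
    logits = q @ k.T                 # no 1/sqrt(d_k) scaling
    logits = logits + position_bias  # bias added only before the softmax
    weights = jax.nn.softmax(logits, axis=-1)
    return weights @ v
```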
Layer Norm Definition
Given an input $x \in \mathbb{R}^{d}$, layer normalization is defined as:

$$\text{LayerNorm}(x) = \gamma \odot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta$$

Where:

- $\mu$ is the mean of the input $x$: $\mu = \frac{1}{d} \sum_{i=1}^{d} x_i$
- $\sigma^2$ is the variance of the input $x$: $\sigma^2 = \frac{1}{d} \sum_{i=1}^{d} (x_i - \mu)^2$
- $\gamma$ and $\beta$ are learnable parameters (the weight and bias), which have the same dimension as $x$.
- $\epsilon$ is a small number for numerical stability, typically on the order of $10^{-5}$ to $10^{-8}$.
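For reference, a minimal JAX sketch of this standard layer norm (a hypothetical helper matching the definition above, not this repository's code):

```python
import jax.numpy as jnp


def standard_layer_norm(x, gamma, beta, eps=1e-6):
    """LayerNorm(x) = gamma * (x - mean) / sqrt(var + eps) + beta, over the last axis."""
    mu = jnp.mean(x, axis=-1, keepdims=True)
    var = jnp.mean((x - mu) ** 2, axis=-1, keepdims=True)
    return gamma * (x - mu) / jnp.sqrt(var + eps) + beta
```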
T5 Layer Norm
T5's layer norm does not subtract the mean (it only rescales the input by its root mean square). The T5 paper does not mention that they used Root Mean Square Layer Normalization.

Root Mean Square Layer Normalization formula:

$$\bar{a_i} = \frac{a_i}{\text{RMS}(a)} g_i, \quad \text{where } \text{RMS}(a) = \sqrt{\frac{1}{n} \sum_{i=1}^{n} a_i^2}$$
Where:

- $g_i$ is the gain (weight) parameter
- $a_i$ are the inputs
- $\bar{a_i}$ are the scaled values of the inputs
- $\text{RMS}(a)$ is the root mean square of $a$.
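A minimal JAX sketch of this T5-style RMS layer norm (a hypothetical helper, not this repository's exact function): no mean subtraction and no bias term, only a learned gain:

```python
import jax.numpy as jnp


def t5_layer_norm(x, weight, eps=1e-6):
    """RMS layer norm: scale by 1/RMS(x) and a learned gain; no mean subtraction, no bias."""
    mean_square = jnp.mean(x ** 2, axis=-1, keepdims=True)  # RMS(x)^2
    return weight * x / jnp.sqrt(mean_square + eps)
```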
In the original transformer model proposed by Vaswani et al., 2017, there is no final layer normalization on the outputs of the encoder and decoder. The outputs of these components are fed directly into subsequent operations.
In the T5 model, a final layer normalization is applied to the output of both the encoder and the decoder.
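Schematically, the encoder (and likewise the decoder) stack therefore ends with one extra layer norm. A minimal sketch with hypothetical names, reusing the RMS layer norm above rather than this repository's exact code:

```python
import jax.numpy as jnp


def rms_norm(x, weight, eps=1e-6):
    return weight * x / jnp.sqrt(jnp.mean(x ** 2, axis=-1, keepdims=True) + eps)


def encoder_stack(block_fns, final_layer_norm_weight, x):
    # Run the embeddings through every encoder block (each block_fn maps x -> x)...
    for block_fn in block_fns:
        x = block_fn(x)
    # ...then apply the final layer norm, which the original Transformer does not have.
    return rms_norm(x, final_layer_norm_weight)
```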
T5 uses tied word embeddings, which are applied to the output of the final decoder. This differs from the conventional Transformer architecture, which uses a separate linear layer for the language model head (`lm_head`).

However, for T5 during training, the `lm_head` is the transpose of the word embedding matrix. This reduces the number of trainable parameters in the model by sharing the same embeddings for the input and output layers, which not only decreases the computational load but also helps regularize the model, leading to improved generalization and potentially better performance on unseen data.
The output of the final decoder block is fed into a dense layer with a softmax output, whose weights are shared with the input embedding matrix.
The rescaling of the decoder output before passing it into the `lm_head` is not mentioned in the T5 paper. However, their T5 implementation scales the decoder output:

$$y = \mathrm{lm\_head}\left(x \cdot d_{\text{model}}^{-0.5}\right) = \left(x \cdot d_{\text{model}}^{-0.5}\right) W_e^{T}$$

Where:

- $x$ is the decoder output.
- $y$ is the logits.
- $d_{\text{model}}$ is the dimensionality of the model.
- $W_e$ is the input embedding matrix used for the tied word embeddings.
- $\mathrm{lm\_head}$ is the language model head, whose weight is the transpose of the input embedding matrix $W_e$.
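A minimal sketch of this tied, rescaled language model head in JAX (hypothetical names, based on the behaviour described above rather than this repository's exact code):

```python
import jax.numpy as jnp


def tied_lm_head(decoder_output, shared_embedding, d_model):
    """Project decoder output to vocabulary logits using the tied embedding matrix.

    decoder_output:    (batch, seq_len, d_model) final decoder hidden states.
    shared_embedding:  (vocab_size, d_model) input embedding matrix W_e.
    """
    x = decoder_output * (d_model ** -0.5)  # rescale before the projection
    return x @ shared_embedding.T           # lm_head weight is W_e transposed
```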
| Input | Hugging Face Output | My Output |
|---|---|---|
| translate English to German: That is good. | 'Das ist gut so.' | 'Das ist gut so.' |
| cola sentence: The course is jumping well. | acceptable | acceptable |
| stsb sentence1: The rhino grazed on the grass. sentence2: A rhino is grazing in a field. | 4.0 | 4.0 |
| summarize: In recent times, rapid advancements in technology have revolutionized various industries, enhancing efficiency, connectivity, and convenience for individuals and businesses alike. | rapid advancements in technology have revolutionized various industries | rapid advancements in technology have revolutionized various industries |
The inputs above are fed into the Hugging Face transformers model and my own model. Generation was repeated 100 times and here is the total time taken:
| Device | Hugging Face | Mine | Speed Improvement |
|---|---|---|---|
| GPU | 190.63s | 64.36s | 66.24% faster |
| TPU | 466.59s | 42.31s | 90.93% faster |
In this comparison, my implementation produces outputs identical to Hugging Face's on the examples above (translation, acceptability, STS-B score, and summary) while completing the generation benchmark approximately 66.24% faster on GPU and 90.93% faster on TPU.
Upon reading the original T5 paper, I discovered that it primarily focused on translating English to German, French, and Romanian. This sparked my curiosity about whether the model could also translate from French to English. To test this, I used the pre-trained model with the task prefix "translate French to English: ". The model proved incapable of performing the desired translation, so I decided to fine-tune my own model specifically for French-to-English translation.
For more in-depth information regarding my fine-tuning process, you can visit the GitHub branch or explore the WandB runs. These resources provide additional insights into the details of my fine-tuning procedure.
To fine-tune my model, I used the WMT-14 fr-en dataset, which consists of approximately 40 million sentence pairs in the training set and around 3,000 rows each in the test and validation sets.
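For reference, a minimal sketch of loading this dataset with the Hugging Face `datasets` library (an assumption about tooling; the fine-tuning branch may load the data differently):

```python
from datasets import load_dataset

# WMT-14 French-English: ~40M training pairs, ~3k validation/test examples each.
wmt14 = load_dataset("wmt14", "fr-en")

example = wmt14["train"][0]["translation"]
print(example["fr"], "->", example["en"])
```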
Coming soon...