JAX Implementation of T5

This project is an implementation of the T5 model (Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer) in JAX, written in a functional style. It has two primary objectives: to offer a versatile codebase for researching Transformer-based LLM architectures, and to show how Transformer-based language models can be implemented in JAX and trained on Google Cloud TPUs.

This project is supported by Cloud TPUs from Google's TPU Research Cloud (TRC).

This project is inspired by ayaka/bart-base-jax, but the code was written entirely by me.

Setup Instructions

  1. Install JAX

    pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
  2. Then install the remaining dependencies from requirements.txt:

    pip install -r requirements.txt

Usage examples

  1. Tokenize inputs

    from transformers import AutoTokenizer, FlaxT5ForConditionalGeneration
    
    model = FlaxT5ForConditionalGeneration.from_pretrained("allenai/unifiedqa-t5-base")
    tokenizer = AutoTokenizer.from_pretrained("t5-base")
    inputs = tokenizer(
       ["summarize: My friends are cool but they eat too many carbs."], return_tensors="np"
    )
    input_ids = inputs["input_ids"]
  2. Initialize model parameters

    from utils.params_utils import init_params_pretrained
    
    params = init_params_pretrained()
  3. Encoder

    from model.transformer_encoder import fwd_transformer_encoder
    
    encoder_output = fwd_transformer_encoder(
       encoder_params=params["encoder"],
       embedding_params=params["shared"],
       input_ids=input_ids,
    )
  4. Decoder

    import jax.numpy as jnp
    
    from model.transformer_decoder import fwd_transformer_decoder
    
    decoder_start_token_id = model.config.decoder_start_token_id
    decoder_input_ids = (
       jnp.ones((input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id
    )
    
    decoder_output = fwd_transformer_decoder(
       decoder_params=params["decoder"],
       embedding_params=params["shared"],
       decoder_input_ids=decoder_input_ids,
       encoder_output=encoder_output,
    )
  5. Generate

    from model.t5_generate import fwd_t5_generate
    from config import config
    
    sequences = fwd_t5_generate(
       params,
       encoder_input_ids=input_ids,
       eos_token_id=config.EOS_TOKEN_ID,
       decoder_start_token_id=config.DECODER_START_TOKEN_ID,
    )
    output = tokenizer.batch_decode(sequences, skip_special_tokens=True)

Discoveries

I discovered an issue in the Hugging Face transformers FlaxT5 implementation: its hidden state outputs were not consistent with mine.

I observed that my encoder and decoder block 11 hidden states do not match their block 11 hidden states, even though my hidden states for blocks 0 to 10 match theirs. Additionally, my final hidden state (after applying the final layer norm) also matches their final hidden state after the layer norm.

I then raised an issue and submitted a PR to fix it.

Analysis

1. JAX precision

  1. On TPU, JAX defaults to bfloat16 for matrix multiplication even when the data type is specified as float32. While this may speed up training, some precision is lost.
  2. On GPU, the Hugging Face transformers model runs at a different numerical precision than this JAX implementation, so outputs differ slightly. JAX's matmul precision can be raised when comparing implementations, as in the sketch below.
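
As a minimal sketch (not part of this codebase), JAX's matrix-multiplication precision can be raised to full float32 when comparing outputs against another implementation:

    import jax
    import jax.numpy as jnp

    # Option 1: raise the default matmul precision globally.
    jax.config.update("jax_default_matmul_precision", "float32")

    # Option 2: raise it only within a scope, e.g. while comparing hidden states.
    with jax.default_matmul_precision("float32"):
        a = jnp.ones((128, 128), dtype=jnp.float32)
        b = jnp.ones((128, 128), dtype=jnp.float32)
        c = a @ b  # computed in full float32, even on TPU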

2. Layer normalisation

T5 performs pre-layer norm instead of post-layer norm.

Attention:

  • (layer norm -> self-attention -> dropout -> add) instead of
  • (self-attention -> dropout -> add -> layer norm)

Feed forward:

  • (layer norm -> DenseReluDense -> dropout -> add) instead of
  • (DenseReluDense -> dropout -> add -> layer norm)

3. Dropout

  • Dropout is applied once at the end of the feed-forward block (linear -> linear -> dropout) instead of after each linear layer (linear -> dropout -> linear -> dropout); the sketch after this bullet illustrates this and the previous point.
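
Below is a minimal sketch (not the repository's actual code) of a T5-style feed-forward sub-block in JAX, illustrating pre-layer norm and the single dropout at the end. The parameter names and the rms_norm helper are illustrative.

    import jax
    import jax.numpy as jnp

    def rms_norm(x, weight, eps=1e-6):
        # T5-style layer norm: no mean subtraction, no bias (see point 6 below).
        variance = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
        return x * jax.lax.rsqrt(variance + eps) * weight

    def t5_feed_forward(params, hidden, dropout_rate, dropout_key):
        # Pre-layer norm: normalise before the feed-forward transformation.
        x = rms_norm(hidden, params["layer_norm_weight"])
        # DenseReluDense: two linear projections with a ReLU in between, no biases.
        x = jax.nn.relu(x @ params["wi"])
        x = x @ params["wo"]
        # Dropout once, at the end of the block (not after each linear layer).
        keep = jax.random.bernoulli(dropout_key, 1.0 - dropout_rate, x.shape)
        x = jnp.where(keep, x / (1.0 - dropout_rate), 0.0)
        # Residual add after dropout; note there is no post-layer norm.
        return hidden + x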

4. Scaling QK matrices

Hugging Face's T5 does not scale the attention scores $QK^T$ by $\sqrt{d_k}$.

The T5 paper does not mention the exclusion of this scaling.

$$ \text{Attention}(Q, K, V) = \text{softmax}\left(QK^T\right)V $$

instead of

$$ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V $$

5. Relative Attention Bias / Position embeddings

T5's position embeddings (relative attention bias) differ from those in Self-Attention with Relative Position Representations (see Hugging Face's implementation). T5:

  1. Uses a binned relative attention bias to reduce time complexity for long sequences
  2. Only applies the bias before the $\text{softmax}$

The T5 paper does not mention that the bias is only applied before the $\text{softmax}$.

$$ \text{Attention}(Q, K, V) = \text{softmax}\left(QK^T + X\right)V $$

instead of

$$ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T + X}{\sqrt{d_k}}\right)V + Y $$

Where:

  • $Q$ is the query matrix
  • $K$ is the key matrix
  • $V$ is the value matrix
  • $d_k$ is the dimension of the keys
  • $X$ is the relative position bias added to the attention logits
  • $Y$ is the relative position contribution added to the values in Shaw et al.'s formulation

In the case of multi-head attention, the above process is performed multiple times with different learned linear transformations of the original $Q$, $K$, and $V$. If we have $h$ heads, then we have:

$$ \text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)W_O $$

where each head is defined as:

$$ \text{head}_i = \text{Attention}(QW_{Q_i}, KW_{K_i}, VW_{V_i}) $$
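
As a minimal single-head sketch (illustrative only, not the repository's actual code), T5's attention therefore looks like:

    import jax
    import jax.numpy as jnp

    def t5_attention(q, k, v, position_bias):
        # q, k, v: [seq_len, d_k]; position_bias: [seq_len, seq_len] (the X above)
        scores = q @ k.T + position_bias   # no division by sqrt(d_k)
        weights = jax.nn.softmax(scores, axis=-1)
        return weights @ v                 # nothing extra added to the values

    # Example with random inputs
    key = jax.random.PRNGKey(0)
    kq, kk, kv = jax.random.split(key, 3)
    q = jax.random.normal(kq, (8, 64))
    k = jax.random.normal(kk, (8, 64))
    v = jax.random.normal(kv, (8, 64))
    out = t5_attention(q, k, v, jnp.zeros((8, 8)))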

6. Layer norm in T5 does not subtract mean

Layer Norm Definition

Given an input $x \in \mathbb{R}^{d}$, the output of layer normalization $y \in \mathbb{R}^{d}$ is calculated as:

$$ y = \gamma \left( \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} \right) + \beta $$

Where:

  • $\mu$ is the mean of the input $x$: $\mu = \frac{1}{d} \sum_{i=1}^{d} x_i$
  • $\sigma^2$ is the variance of the input $x$: $\sigma^2 = \frac{1}{d} \sum_{i=1}^{d} (x_i - \mu)^2$
  • $\gamma$ and $\beta$ are learnable parameters (the weight and bias), which have the same dimension as $x$.
  • $\epsilon$ is a small number for numerical stability, typically on the order of $10^{-5}$ to $10^{-8}$.

T5 Layer Norm

T5's layer norm does not subtract the mean ($\mu$) and does not have a bias ($\beta$); it uses Root Mean Square Layer Normalization (see Hugging Face's implementation).

The T5 paper does not mention that Root Mean Square Layer Normalization is used.

Root Mean Square Layer Normalization formula:

$$ \bar{a}_i = \frac{a_i}{\mathrm{RMS}(a)} g_i \text{, where } \mathrm{RMS}(a) = \sqrt{\frac{1}{n}\sum_{i=1}^{n} a_i^2} $$

Where:

  • $g_i$ is the gain (weight) parameter
  • $a_i$ are the input values
  • $\bar{a}_i$ are the rescaled inputs
  • $\mathrm{RMS}(a)$ is the root mean square of $a$
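
A minimal sketch of this normalisation in JAX (names and the epsilon value are illustrative):

    import jax
    import jax.numpy as jnp

    def t5_layer_norm(x, weight, eps=1e-6):
        # No mean subtraction and no bias: scale by the root mean square only.
        variance = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
        return x * jax.lax.rsqrt(variance + eps) * weight

    x = jax.random.normal(jax.random.PRNGKey(0), (2, 768))
    y = t5_layer_norm(x, jnp.ones((768,)))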

7. T5 employs a final layer norm on the output of the encoder and decoder

In the original transformer model proposed by Vaswani et al., 2017, there is no final layer normalization on the outputs of the encoder and decoder. The outputs of these components are fed directly into subsequent operations.

In the T5 model, there is a final layer normalization step after the output from both the encoder and decoder.

8. T5 uses tied word embeddings

T5 uses tied word embeddings: the language model head (lm_head) applied to the output of the final decoder reuses the input embedding matrix. This differs from the conventional Transformer architecture, which typically uses a separately learned linear layer for the lm_head.

In T5, the lm_head is the transpose of the word embedding matrix. This reduces the number of trainable parameters by sharing the same embeddings for the input and output layers, which not only decreases the computational load but also helps regularize the model, leading to improved generalization and potentially better performance on unseen data.

The output of the final decoder block is fed into a dense layer with a softmax output, whose weights are shared with the input embedding matrix.

9. T5 also rescales the decoder output for tied word embedding in the language model head

The rescaling of the decoder output before it is passed into the lm_head is not mentioned in the T5 paper.

However, the official T5 implementation scales the decoder output:

$$ \mathrm{lm\_head}(x) = \frac{x}{\sqrt{d_{\text{model}}}} W_e \quad\text{instead of}\quad \mathrm{lm\_head}(x) = x W_e $$

$$ y = \text{Softmax}(\mathrm{lm\_head}(x)) $$

Where:

  • $x$ is the decoder output.
  • $y$ is the output probability distribution over the vocabulary.
  • $d_{\text{model}}$ is the dimensionality of the model.
  • $W_e$ is the input embedding matrix, reused for the tied word embeddings.
  • $\mathrm{lm\_head}$ is the language model head, which maps the decoder output to vocabulary logits.
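
A minimal sketch of the tied, rescaled language model head covering points 8 and 9 (shapes and names are illustrative, not the repository's actual code):

    import jax
    import jax.numpy as jnp

    def lm_head(decoder_output, embedding_matrix):
        # decoder_output: [batch, seq_len, d_model]
        # embedding_matrix: [vocab_size, d_model] -- the input embedding table,
        # reused (tied) as the output projection instead of a separate linear layer.
        d_model = decoder_output.shape[-1]
        scaled = decoder_output / jnp.sqrt(d_model)   # rescale by 1/sqrt(d_model)
        return scaled @ embedding_matrix.T            # logits over the vocabulary

    logits = lm_head(jnp.ones((1, 4, 512)), jnp.ones((32128, 512)))
    probs = jax.nn.softmax(logits, axis=-1)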

T5 Jax Implementation Results

Input and Output

| Input | Hugging Face Output | My Output |
| --- | --- | --- |
| translate English to German: That is good. | 'Das ist gut so.' | 'Das ist gut so.' |
| cola sentence: The course is jumping well. | acceptable | acceptable |
| stsb sentence1: The rhino grazed on the grass. sentence2: A rhino is grazing in a field. | 4.0 | 4.0 |
| summarize: In recent times, rapid advancements in technology have revolutionized various industries, enhancing efficiency, connectivity, and convenience for individuals and businesses alike. | rapid advancements in technology have revolutionized various industries | rapid advancements in technology have revolutionized various industries |

Time taken

The inputs above were fed into both the Hugging Face transformers model and my own model. Generation was repeated 100 times; the total time taken is shown below:

| Device | Hugging Face | Mine | Speed Improvement |
| --- | --- | --- | --- |
| GPU | 190.63s | 64.36s | 66.24% faster |
| TPU | 466.59s | 42.31s | 90.93% faster |

Conclusion

In a direct comparison, my implementation produces the same results as Hugging Face's while running significantly faster. Both implementations generated identical translations, acceptability scores, and summarization outputs for the provided examples, but mine completed the tasks approximately 90.93% faster on TPU and 66.24% faster on GPU.

Fine-tuning

The original T5 paper primarily focuses on translating English to German, French, and Romanian, which made me curious whether the model could also translate from French to English. To test this, I ran the pre-trained model with the task prefix "translate French to English: ", but it was unable to produce the desired translation. I therefore fine-tuned my own model for French-to-English translation.

For more details on the fine-tuning process, see the GitHub branch or the WandB runs.

Dataset

To fine-tune my model, I used the WMT14 fr-en dataset, which contains approximately 40 million sentence pairs in the training set and around 3,000 rows in each of the test and validation sets.
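
For reference, a minimal sketch of loading this dataset with the Hugging Face datasets library (assuming the standard wmt14 fr-en configuration on the Hub):

    from datasets import load_dataset

    # ~40M training pairs; validation and test each have ~3k rows
    dataset = load_dataset("wmt14", "fr-en")
    print(dataset["train"][0]["translation"])  # {'en': '...', 'fr': '...'}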

Results

Coming soon...

