This repo provides the official implementation of our paper
Causal Diffusion Transformers for Generative Modeling
Chaorui Deng, Deyao Zhu, Kunchang Li, Shi Guang, Haoqi Fan
Bytedance Research
Install the dependencies:
git clone https://github.com/causalfusion/causalfusion.git
pip install -U torch==2.5.1 torchvision==0.20.1 transformers==4.46.2
Download the pretrained VAE from MAR.
To train CausalFusion-XL on 8 GPUs with a global batch size of 256:
torchrun --nnodes=1 --nproc_per_node=8 train.py --data-path=$PATH_TO_IMAGENET_TRAIN_DIR --tokenizer-path=$PATH_TO_MAR_VAE --results-dir=$PATH_TO_RESULTS_DIR --model=CausalFusion-XL --global-batch-size=256 --ckpt-every=50000 --lr=1e-4 --distributed --grad-checkpoint
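For intuition on the flags above: torchrun spawns one process per GPU, and a --global-batch-size is conventionally split evenly across those processes (a common DDP convention, assumed here rather than read from train.py). A minimal sketch:

```python
def per_gpu_batch_size(global_batch_size: int, world_size: int) -> int:
    """Split a global batch evenly across distributed processes."""
    assert global_batch_size % world_size == 0, "global batch must divide evenly"
    return global_batch_size // world_size

# With the settings above: 256 images per step across 8 GPUs
print(per_gpu_batch_size(256, 8))  # → 32
```

So each of the 8 processes sees 32 images per optimizer step.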
Download the pretrained CausalFusion-XL checkpoint.
To sample 10,000 images on 8 GPUs with a CFG scale of 4.0:
torchrun --nnodes=1 --nproc_per_node=8 sample.py --distributed --model=CausalFusion-L --tokenizer-path=$PATH_TO_MAR_VAE --num-fid-samples=10000 --ckpt=$PATH_TO_PRETRAINED_CKPT --sample-dir=$PATH_TO_SAMPLE_DIR --cfg-scale=4.0
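For intuition on the --cfg-scale flag: classifier-free guidance combines a conditional and an unconditional model prediction, extrapolating toward the conditional one. A minimal NumPy sketch of the standard formulation (the exact form inside sample.py may differ):

```python
import numpy as np

def cfg_combine(eps_uncond: np.ndarray, eps_cond: np.ndarray, scale: float) -> np.ndarray:
    # Classifier-free guidance: move the unconditional prediction
    # toward (and, for scale > 1, past) the conditional prediction.
    return eps_uncond + scale * (eps_cond - eps_uncond)

# Toy predictions; scale 4.0 as in the command above
eps_u = np.zeros(4)
eps_c = np.ones(4)
print(cfg_combine(eps_u, eps_c, 4.0))  # → [4. 4. 4. 4.]
```

At scale 1.0 this reduces to the plain conditional prediction; larger scales trade diversity for fidelity.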
See the instructions in ADM for evaluation.
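ADM's evaluation suite reads samples packed into a single .npz file rather than a folder of images. A rough sketch of that packing step (the array key and layout here are assumptions; check the ADM instructions for the exact format):

```python
import numpy as np

def pack_samples(images, out_path="samples.npz"):
    """Stack HWC uint8 images into one (N, H, W, C) array for ADM-style FID tooling."""
    arr = np.stack(images).astype(np.uint8)
    np.savez(out_path, arr_0=arr)  # 'arr_0' key assumed, matching ADM reference batches
    return arr.shape

# Stand-ins for images loaded from $PATH_TO_SAMPLE_DIR (random noise, for illustration only)
dummy = [np.random.randint(0, 256, (256, 256, 3), dtype=np.uint8) for _ in range(4)]
print(pack_samples(dummy))  # → (4, 256, 256, 3)
```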
@article{deng2024causalfusion,
title={Causal Diffusion Transformers for Generative Modeling},
author={Chaorui Deng and Deyao Zhu and Kunchang Li and Shi Guang and Haoqi Fan},
year={2024},
journal={arXiv preprint arXiv:2412.12095},
}
This codebase borrows from DiT, MAR, and ADM; thanks to the authors for their great work!