
Causal Diffusion Transformers for Generative Modeling


This repo provides the official implementation of our paper:

Causal Diffusion Transformers for Generative Modeling
Chaorui Deng, Deyao Zhu, Kunchang Li, Shi Guang, Haoqi Fan
Bytedance Research

Setup

Install the dependencies:

git clone https://github.com/causalfusion/causalfusion.git
pip install -U torch==2.5.1 torchvision==0.20.1 transformers==4.46.2

Download the pretrained VAE from MAR.

Training

Train CausalFusion-XL on 8 GPUs with a global batch size of 256:

torchrun --nnodes=1 --nproc_per_node=8 train.py --data-path=$PATH_TO_IMAGENET_TRAIN_DIR --tokenizer-path=$PATH_TO_MAR_VAE --results-dir=$PATH_TO_RESULTS_DIR --model=CausalFusion-XL --global-batch-size=256 --ckpt-every=50000 --lr=1e-4 --distributed --grad-checkpoint
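
The relationship between the global batch size and the number of GPUs can be sketched as follows. This is a minimal illustration, assuming that the value passed to --global-batch-size is split evenly across the processes launched by torchrun, as DiT-style training scripts do. The helper name per_gpu_batch_size is hypothetical, and train.py may handle the split differently.

```python
def per_gpu_batch_size(global_batch_size: int, world_size: int) -> int:
    """Return the per-process batch size for an even split across ranks.

    Hypothetical helper for illustration only; not taken from train.py.
    """
    if world_size <= 0:
        raise ValueError("world_size must be a positive integer")
    if global_batch_size % world_size != 0:
        raise ValueError(
            f"global batch size {global_batch_size} is not divisible "
            f"by world size {world_size}"
        )
    return global_batch_size // world_size


if __name__ == "__main__":
    # The command above uses 8 processes and a global batch size of 256.
    print(per_gpu_batch_size(256, 8))
```

Under this assumption, changing --nproc_per_node while keeping --global-batch-size fixed changes the per-GPU memory footprint but not the effective batch size.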

Sampling

Download the pretrained CausalFusion-XL checkpoint.

Sample 10,000 images on 8 GPUs with a CFG scale of 4.0:

torchrun --nnodes=1 --nproc_per_node=8 sample.py --distributed --model=CausalFusion-XL --tokenizer-path=$PATH_TO_MAR_VAE --num-fid-samples=10000 --ckpt=$PATH_TO_PRETRAINED_CKPT --sample-dir=$PATH_TO_SAMPLE_DIR --cfg-scale=4.0
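
For readers unfamiliar with the --cfg-scale flag, the sketch below shows the standard classifier-free guidance combination of conditional and unconditional model outputs. It is a generic illustration on plain Python lists, not the code in sample.py; the function name apply_cfg is hypothetical, and the exact form used by this repo is an assumption.

```python
from typing import List, Sequence


def apply_cfg(
    cond: Sequence[float], uncond: Sequence[float], scale: float
) -> List[float]:
    """Combine predictions as uncond + scale * (cond - uncond).

    Hypothetical helper: scale = 1.0 returns the conditional prediction,
    and larger values push the output further toward the condition.
    """
    if len(cond) != len(uncond):
        raise ValueError("cond and uncond must have the same length")
    return [u + scale * (c - u) for c, u in zip(cond, uncond)]


if __name__ == "__main__":
    # Toy values standing in for model outputs; scale matches the command above.
    print(apply_cfg([1.0, 2.0], [0.0, 1.0], 4.0))
```

With a scale of 4.0, the difference between the conditional and unconditional predictions is amplified fourfold, which typically trades sample diversity for fidelity to the class condition.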

Evaluation

See the instructions in ADM for evaluation.

BibTeX

@article{deng2024causalfusion,
  title={Causal Diffusion Transformers for Generative Modeling},
  author={Chaorui Deng and Deyao Zhu and Kunchang Li and Shi Guang and Haoqi Fan},
  year={2024},
  journal={arXiv preprint arXiv:2412.12095},
}

Acknowledgments

This codebase borrows from DiT, MAR, and ADM. Thanks to their authors for the great work!
