
AndSonder (Contributor) commented Apr 9, 2024

PR Category

Auto Parallel

PR Types

Improvements

Description

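Currently, when use_reentrant == True, recompute is implemented with PyLayer. However, PyLayer does not yet support Tensor arguments passed in dict form (Tensors passed inside a dict do not get backward nodes or backward edges created).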

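To make distributed training easier to use, this PR lets recompute accept Tensor arguments passed in dict form (as keyword arguments) when use_reentrant == True. The main idea is to rearrange position-args + keyword-args into position-args.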
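The rearrangement can be sketched in plain Python with `inspect.signature(...).bind(...)`, mirroring the step that produces `input_args` for `RecomputeFunction.apply(function, preserve, *input_args)`. This is a minimal sketch that assumes every parameter is an ordinary positional or positional-or-keyword parameter; `to_positional` and `decoder_layer_forward` are hypothetical names for illustration, and no Paddle API is involved.

```python
import inspect

def to_positional(fn, *args, **kwargs):
    """Rearrange position-args + keyword-args into position-args.

    Sketch only: assumes every parameter of fn is POSITIONAL_ONLY or
    POSITIONAL_OR_KEYWORD, so each bound value maps to exactly one slot.
    """
    sig = inspect.signature(fn)
    for p in sig.parameters.values():
        if p.kind not in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD):
            raise TypeError(f"unsupported parameter kind for {p.name!r}: {p.kind}")
    bound = sig.bind(*args, **kwargs)
    bound.apply_defaults()  # fill in defaults so every slot has a value
    return list(bound.arguments.values())

def decoder_layer_forward(hidden_states, attention_mask=None, position_ids=None):
    # Hypothetical forward signature, loosely modeled on a decoder layer
    return hidden_states, attention_mask, position_ids

positional = to_positional(decoder_layer_forward, "h", position_ids="pos")
print(positional)  # ['h', None, 'pos']
```

Rejecting `*args`, `**kwargs`, and keyword-only parameters up front keeps the sketch honest: for those kinds, one bound value does not map to exactly one positional slot.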

Performance test results:

Test environment: 4× RTX 3090, Llama2 model with num_hidden_layer hacked down to 4

Performance data collected at step 30:

| Case | interval_runtime | interval_samples_per_second | interval_steps_per_second | Loss (step 30) |
|---|---|---|---|---|
| Llama2 (without kwargs) | 10.3995 | 1.5394 | 0.0962 | 7.09293509 |
| Llama2 (with kwargs) | 10.4043 | 1.5378 | 0.0961 | 7.09293509 |
| GPT3 (without kwargs) | 2.4434 | 1.6371 | 0.4093 | 10.3778286 |
| GPT3 (with kwargs) | 2.4371 | 1.6413 | 0.4103 | 10.3778286 |
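To put the differences in perspective, the relative change in `interval_runtime` with kwargs versus without can be computed directly from the table values. This is a small worked calculation, not part of the original benchmark tooling.

```python
# Step-30 interval_runtime values copied from the table above
runtime = {
    "Llama2": {"without_kwargs": 10.3995, "with_kwargs": 10.4043},
    "GPT3": {"without_kwargs": 2.4434, "with_kwargs": 2.4371},
}

def relative_change(before, after):
    """Relative change of `after` versus `before`, as a fraction."""
    return (after - before) / before

for model, r in runtime.items():
    delta = relative_change(r["without_kwargs"], r["with_kwargs"])
    print(f"{model}: {delta:+.2%}")
```

This gives about +0.05% for Llama2 and about -0.26% for GPT3; together with the identical step-30 losses, it suggests the rearrangement adds no measurable overhead in this setup.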

Llama2 test script:

# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
# 
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# 
#     http://www.apache.org/licenses/LICENSE-2.0
# 
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# just for debug

set -x
unset CUDA_VISIBLE_DEVICES

task_name="llama_auto_dp2mp2pp2"
rm -rf output/$task_name/
rm -rf "output/$task_name""_log"

export SOT_LOG_LEVEL=4
export PYTHONPATH=../../../:$PYTHONPATH
# ulimit -c unlimited
# export GLOG_v=4

# export FLAGS_call_stack_level=3
# export FLAGS_use_cuda_managed_memory=true

# export FLAGS_embedding_deterministic=1        
# export FLAGS_cudnn_deterministic=1
# export NVIDIA_TF32_OVERRIDE=0

to_static=0  # whether to enable dynamic-to-static training

python -u  -m paddle.distributed.launch \
    --gpus "0,1,2,3" \
    --log_dir "auto_3d" \
    run_pretrain_auto.py \
    --model_type "llama" \
    --model_name_or_path "facebook/llama-7b" \
    --tokenizer_name_or_path "facebook/llama-7b" \
    --input_dir "../data" \
    --output_dir "output/$task_name" \
    --split 949,50,1 \
    --max_seq_length 2048 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 16 \
    --use_flash_attention 0 \
    --use_fused_rms_norm 0 \
    --fp16 0 \
    --fp16_opt_level "O2"  \
    --scale_loss 1024 \
    --pipeline_parallel_degree 4 \
    --tensor_parallel_degree 1 \
    --sharding_parallel_degree 1 \
    --learning_rate 0.0001 \
    --min_learning_rate 0.00001 \
    --max_steps 30 \
    --save_steps 5000000 \
    --weight_decay 0.01 \
    --warmup_ratio 0.01 \
    --logging_steps 1 \
    --dataloader_num_workers 1 \
    --sharding "stage1" \
    --eval_steps 1000000 \
    --disable_tqdm true \
    --continue_training 0 \
    --recompute 1 \
    --recompute_granularity full \
    --do_train \
    --do_eval \
    --device "gpu" \
    --data_impl "mmap" \
    --enable_auto_parallel 1 \
    --max_grad_norm 1.0 \
    --to_static $to_static

Changes: every place in paddlenlp/transformers/llama/modeling_auto.py where recompute is enabled (3 places in total) was modified:

[Screenshots 1-3: the modified recompute call sites in llama/modeling_auto.py]

GPT run script:

export PYTHONPATH="../../../":$PYTHONPATH
export FLAGS_cudnn_deterministic=1
export FLAGS_embedding_deterministic=1 
export NVIDIA_TF32_OVERRIDE=0
export FLAGS_call_stack_level=3

to_static=0
# export TRANSLATOR_DISABLE_NEW_ERROR=0
# export TRANSLATOR_CODE_LEVEL=100

task_name="gpt3_auto_dp2mp2pp2_${to_static}"
log_dir="output/$task_name""_log"
output_dir="output/$task_name"
rm -rf $log_dir
rm -rf $output_dir

python -u -m paddle.distributed.launch \
    --gpus "0,1,2,3" \
    --log_dir ${log_dir} \
    run_pretrain_auto.py \
    --model_name_or_path gpt2-medium-en \
    --tokenizer_name_or_path gpt2-medium-en \
    --input_dir "../data" \
    --output_dir ${output_dir}  \
    --split 949,50,1 \
    --max_seq_length 1024 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --sharding "" \
    --tensor_parallel_degree 2 \
    --pipeline_parallel_degree 2 \
    --sequence_parallel 0 \
    --fuse_attention_qkv 0 \
    --use_flash_attention 0 \
    --scale_loss 1024 \
    --learning_rate 0.00001 \
    --min_learning_rate 0.000005 \
    --max_steps 30 \
    --save_steps 50000 \
    --weight_decay 0.01 \
    --warmup_ratio 0.01 \
    --max_grad_norm 1.0 \
    --logging_steps 1 \
    --continue_training 0 \
    --dataloader_num_workers 1 \
    --eval_steps 100000 \
    --report_to "visualdl" \
    --disable_tqdm true \
    --recompute 0 \
    --gradient_accumulation_steps 4 \
    --do_train \
    --do_eval \
    --device "gpu" \
    --model_type "gpt" \
    --enable_auto_parallel 1 \
    --to_static ${to_static} \
    --fp16 0 \
    --fp16_opt_level "O2"

Changes to paddlenlp/transformers/gpt/modeling_auto.py:

[Screenshot: the change to gpt/modeling_auto.py]


paddle-bot bot commented Apr 9, 2024

Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

paddle-bot added the contributor (External developers) label Apr 9, 2024
AndSonder marked this pull request as ready for review April 9, 2024 01:55
AndSonder (Contributor, Author) commented:

@ForFishes @MarioLulab CI is all passing now; could you please review this?

ForFishes (Member) commented:

Hi, this PR touches on some issues that need further internal discussion.

ForFishes (Member) left a comment:

LGTM

luotao1 merged commit 64cad15 into PaddlePaddle:develop Apr 16, 2024
Asthestarsfalll pushed a commit to Asthestarsfalll/Paddle that referenced this pull request Apr 17, 2024
…== True (PaddlePaddle#63337)

* support kwargs for recompute when open use_reentrant

* update test

* fix

* Update recompute.py

* fix

* fix
Comment on lines +533 to +543
input_args = args
# rearrange `position-args + keyword-args` into `position-args`
if isinstance(function, paddle.nn.Layer):
    dyfunc_sig = inspect.signature(function.forward)
else:
    dyfunc_sig = inspect.signature(function)

bound_args = dyfunc_sig.bind(*args, **kwargs)
bound_args.apply_defaults()
input_args = list(bound_args.arguments.values())
return RecomputeFunction.apply(function, preserve, *input_args)
SigureMo (Member) commented Apr 17, 2024:

Consider the following case:

# Taken from PaddleMIX: https://github.com/PaddlePaddle/PaddleMIX/blob/8b896d533811a3500af3064c5f1952b77003d4c8/ppdiffusers/ppdiffusers/models/unet_2d_blocks.py#L1149-L1155
def custom_forward(*inputs):
    ...

Using bound_args.arguments here is wrong: no matter how many values are passed in, bound_args.arguments holds only one entry, the packed inputs tuple.
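The failure can be reproduced without Paddle. In this minimal sketch, `custom_forward` stands in for the PaddleMIX function above; taking `bound_args.arguments.values()`, as the PR does, yields a single packed tuple, so re-calling the function with that list nests the inputs one level deeper.

```python
import inspect

def custom_forward(*inputs):
    # Stand-in for the PaddleMIX helper: every input is packed into *inputs
    return inputs

sig = inspect.signature(custom_forward)
bound_args = sig.bind(1, 2, 3)
bound_args.apply_defaults()

# The PR's approach: one list entry per parameter, not per value
input_args = list(bound_args.arguments.values())
print(input_args)  # [(1, 2, 3)]

# Re-calling with the "rearranged" list wraps the original tuple again
print(custom_forward(*input_args))  # ((1, 2, 3),)
```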

All parameter kinds need to be handled:

import inspect

def custom_forward(*inputs, **kwargs):
    return inputs

def convert_inputs_to_positional_args(fn, *args, **kwargs):
    positional_args = []
    sig = inspect.signature(fn)
    bound_args = sig.bind(*args, **kwargs)
    bound_args.apply_defaults()

    # after apply_defaults(), arguments holds one entry per parameter, in signature order
    for arg, param in zip(bound_args.arguments.values(), sig.parameters.values()):
        if param.kind == param.VAR_POSITIONAL:
            positional_args.extend(arg)
        elif param.kind in (param.POSITIONAL_ONLY, param.POSITIONAL_OR_KEYWORD):
            positional_args.append(arg)
        elif param.kind == param.VAR_KEYWORD:
            # only the values are kept here; the keyword names are dropped
            positional_args.extend(arg.values())
        elif param.kind == param.KEYWORD_ONLY:
            raise ValueError("Currently, keyword-only arguments are not supported.")
        else:
            raise ValueError("Unknown parameter kind.")
    return positional_args

convert_inputs_to_positional_args(custom_forward, 1, 2, y=2, x=1)  # -> [1, 2, 2, 1]

Regarding "The main idea is to rearrange position-args + keyword-args into position-args":

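Note that this approach inherently cannot support functions with keyword-only parameters; if those need to be supported, this approach is not viable.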
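One possible direction, sketched here as an assumption rather than what the PR or the review implements: the standard library's `BoundArguments.args` and `BoundArguments.kwargs` already split a call by parameter kind. `args` expands `*inputs` and stops at the first keyword-only or `**kwargs` parameter, while `kwargs` collects the rest, so `fn(*args, **kwargs)` reproduces the original call. `split_bound_arguments`, `packed_forward`, and `layer_forward` are hypothetical names. Tensors left in the keyword part would still hit the PyLayer limitation from the description, so a caller would have to reject or otherwise handle them.

```python
import inspect

def split_bound_arguments(fn, *args, **kwargs):
    """Hypothetical helper: normalize a call into (positional, keyword) parts."""
    bound = inspect.signature(fn).bind(*args, **kwargs)
    bound.apply_defaults()  # missing *args becomes (), missing **kwargs becomes {}
    # .args: POSITIONAL_ONLY / POSITIONAL_OR_KEYWORD values plus expanded *args
    # .kwargs: KEYWORD_ONLY values plus the contents of **kwargs
    return bound.args, bound.kwargs

def packed_forward(*inputs, **kwargs):
    return inputs, kwargs

def layer_forward(x, y=0, *, scale=1.0):
    return (x + y) * scale

pos, kw = split_bound_arguments(packed_forward, 1, 2, y=2, x=1)
print(pos, kw)  # (1, 2) {'y': 2, 'x': 1}

pos, kw = split_bound_arguments(layer_forward, 1, y=2, scale=3.0)
print(pos, kw)  # (1, 2) {'scale': 3.0}
```

Unlike flattening everything into positional arguments, this keeps keyword names intact, at the cost of leaving a keyword part that PyLayer cannot track on its own.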

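In addition, this PR has already affected Stable Diffusion, a high-priority monitored model. I'll first open a PR that tries to revert it (#63637); in the meantime we can look into how to fix it.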

AndSonder (Contributor, Author) replied:

OK, got it.

One more question to confirm: is the impact on Stable Diffusion the error from the case above, or some other problem?

Reply from a Member:

The case above.

luotao1 changed the title from "【Hackathon 6th No.35】support kwargs for recompute when use_reentrant == True" to "【Hackathon 6th No.35】support kwargs for recompute when use_reentrant == True -part" Apr 24, 2024

5 participants