Conversation

@FeixLiu (Contributor) commented Sep 14, 2021

PR types

Bug fixes

PR changes

Others

Describe

Under mp (model parallel), some vars are not distributed, e.g. scale, bias, etc. If such vars are accumulated on every mp rank during GradientClipByGlobalNorm, the resulting global norm comes out slightly larger than it should be.
After this change, grads whose is_distributed is False are accumulated into the norm only on the mp_rank=0 node, and the GlobalNorm is then obtained via c_allreduce_sum (see the sketch after the screenshots below).
Program on mp_rank=1 before the fix: (screenshots)

Program on mp_rank=1 after the fix: (screenshot)
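To make the intent concrete, here is a minimal, self-contained sketch of the arithmetic in plain NumPy (illustrative names only; the actual change rewrites the static program rather than operating on tensors like this):

import numpy as np

def local_squared_norm(grads, mp_rank):
    # Each item of `grads` is a (tensor, is_distributed) pair.
    # Distributed grads hold a different shard on every mp rank, so every
    # rank contributes its shard. Replicated grads (scale, bias, ...) are
    # identical on all ranks, so only mp_rank 0 contributes them; otherwise
    # they would be counted mp_degree times and inflate the global norm.
    total = 0.0
    for tensor, is_distributed in grads:
        if is_distributed or mp_rank == 0:
            total += float(np.sum(np.square(tensor)))
    return total

# Simulate an mp_degree=2 group: the weight is sharded, the bias replicated.
weight_shards = [np.ones(8), np.ones(8)]   # is_distributed=True
bias = np.ones(4)                          # is_distributed=False

partials = [
    local_squared_norm([(weight_shards[r], True), (bias, False)], mp_rank=r)
    for r in range(2)
]
# c_allreduce_sum over the mp group, then sqrt -> GlobalNorm.
global_norm = float(np.sqrt(sum(partials)))   # sqrt(16 + 4): bias counted once
print(global_norm)

Without the pruning, the bias term would be added on both ranks and the norm would come out as sqrt(16 + 8) instead.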

@paddle-bot-old

Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@wangxicoding (Contributor) left a comment


Please add a comparison of hybrid-parallel accuracy with MP enabled and with MP disabled.

@FeixLiu force-pushed the fix_mp_multi_gradient_clip_prob branch from 5e1fe33 to 7ebecff on September 15, 2021 01:32
@FeixLiu (Contributor, Author) commented Sep 15, 2021

(screenshot: Screen Shot 2021-09-15 at 2.23.36 PM)

After changing the gradient clip value to 1e-6, the loss of the dev branch and the PR branch is basically aligned both w/ mp and w/o mp.

# Therefore, we prune those duplicated vars for grad clip.
if mp_rank > 0 and (not (hasattr(input_var, 'is_distributed')
                         and input_var.is_distributed)):
    removed_op_idx.append(idx)

Strictly speaking, this way of pruning is not rigorous: it would break under the previous global_clip implementation based on square and sum ops. But since the clip now uses a squared_l2_norm implementation, it does not cause problems.
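For readers unfamiliar with the two implementations, a rough sketch of the point being made (hypothetical op lists, not the actual program):

# Current global clip: one squared_l2_norm op per grad, each producing a
# standalone scalar, so on mp_rank > 0 the op for a replicated grad can be
# dropped wholesale and the trailing sum simply sees one input fewer.
new_impl = ['squared_l2_norm(g0)', 'squared_l2_norm(g1)', 'sum', 'sqrt']
pruned = [op for op in new_impl if op != 'squared_l2_norm(g1)']  # g1 replicated

# Older global clip: elementwise square ops feeding a shared sum. Pruning only
# the square op of a replicated grad would leave that shared sum with a
# dangling input unless its input list were edited as well.
old_impl = ['square(g0)', 'square(g1)', 'sum', 'sqrt']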

@FeixLiu force-pushed the fix_mp_multi_gradient_clip_prob branch from a9763af to ee1eff2 on September 16, 2021 01:25
@FeixLiu force-pushed the fix_mp_multi_gradient_clip_prob branch from 70baf73 to d9b6d50 on September 16, 2021 06:58
@wangxicoding (Contributor) left a comment


LGTM

'cast', 'sum', 'fill_constant', 'cast', 'sum', 'fill_constant',
'cast', 'sum', 'c_sync_comm_stream', 'check_finite_and_unscale',
'cast', 'c_allreduce_max', 'c_allreduce_max', 'cast',
'update_loss_scaling', 'fill_constant', 'c_allreduce_sum',

👻

OP_ROLE_KEY: OpRole.Optimize,
})
return
for idx, op in list(enumerate(block.ops)):

Actually, iterating in reverse would be a bit better here.
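For context, a small sketch of what reverse handling buys when ops are removed by index (assuming a removed_op_idx list collected during the scan; Block._remove_op is used here for illustration):

# Deleting back-to-front keeps the remaining indices valid, since each
# removal only shifts the ops that come after it.
for idx in sorted(removed_op_idx, reverse=True):
    block._remove_op(idx)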

@wangxicoding changed the title from "Fix mp multi gradient clip prob" to "[hybrid] Fix mp multi gradient clip prob" Sep 16, 2021
@sandyhouse left a comment


LGTM

@wangxicoding merged commit a4eadd1 into PaddlePaddle:develop Sep 16, 2021
@FeixLiu deleted the fix_mp_multi_gradient_clip_prob branch September 16, 2021 23:58
AnnaTrainingG pushed a commit to AnnaTrainingG/Paddle that referenced this pull request Sep 29, 2021
@FeixLiu changed the title from "[hybrid] Fix mp multi gradient clip prob" to "[hybrid bug fix] Fix mp multi gradient clip prob" Oct 11, 2021