[mcore] option to use dist checkpoint #1030
Merged
mcore dist checkpointing is a parallel-invariant weight format: you can save and load under arbitrary parallel settings, e.g. save with tp2pp2 and load with tp4pp1.
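The parallel-invariance idea can be illustrated with a minimal sketch (plain Python, not the actual mcore `dist_checkpointing` API): each shard is saved together with its global offset, so a load under a different tensor-parallel layout can reassemble the full tensor and re-slice it for the new ranks.

```python
# Illustrative sketch of parallel-invariant sharding; the helper names
# (save_shards, load_shard) are hypothetical, not mcore's API.

def save_shards(full_weight, tp_size):
    """Split a 1-D 'weight' into tp_size shards with global-offset metadata."""
    size = len(full_weight) // tp_size
    return [
        {"offset": r * size, "data": full_weight[r * size:(r + 1) * size]}
        for r in range(tp_size)
    ]

def load_shard(shards, rank, tp_size):
    """Reassemble the full tensor from saved shards, then take this rank's slice."""
    full = [0] * sum(len(s["data"]) for s in shards)
    for s in shards:
        full[s["offset"]:s["offset"] + len(s["data"])] = s["data"]
    size = len(full) // tp_size
    return full[rank * size:(rank + 1) * size]

# Save with tp2, load with tp4: shard boundaries differ but values agree.
weight = list(range(8))
saved = save_shards(weight, tp_size=2)
print(load_shard(saved, rank=1, tp_size=4))  # [2, 3]
```

The offset metadata is what makes the format layout-independent: no rank needs to know the parallel configuration used at save time.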
This PR introduces an option to use dist checkpoints with the mcore backend. It is disabled by default for backward compatibility, but future support for mcore MoE models and VLM models will only work with dist ckpt enabled, for an easier implementation.
Before this PR, when initializing the actor and critic workers, each GPU would load the entire huggingface weights and then re-shard them into the correct mcore model state dict, making the procedure slow and complicated.
With this PR, we convert the hf weights to a dist ckpt with offline scripts, and each GPU loads only its own part from the dist ckpt. Loading is faster and no online resharding is needed.
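The offline-convert / partial-load flow can be sketched as follows (file layout and helper names are illustrative, not those of the actual `converter_hf_to_mcore.py`): the converter writes one shard file per rank ahead of time, and at startup each rank opens only its own file.

```python
# Hedged sketch: hypothetical one-file-per-shard layout, not verl's real format.
import json
import os
import tempfile

def convert_hf_to_dist_ckpt(hf_state, ckpt_dir, tp_size):
    """Offline step: split each hf weight into tp_size shard files."""
    for name, weight in hf_state.items():
        size = len(weight) // tp_size
        for rank in range(tp_size):
            path = os.path.join(ckpt_dir, f"{name}.tp{rank}.json")
            with open(path, "w") as f:
                json.dump(weight[rank * size:(rank + 1) * size], f)

def load_rank_shard(ckpt_dir, name, rank):
    """Online step: each rank reads only its own shard file."""
    with open(os.path.join(ckpt_dir, f"{name}.tp{rank}.json")) as f:
        return json.load(f)

ckpt_dir = tempfile.mkdtemp()
convert_hf_to_dist_ckpt({"w": [1, 2, 3, 4]}, ckpt_dir, tp_size=2)
print(load_rank_shard(ckpt_dir, "w", rank=1))  # [3, 4]
```

Because the expensive split happens once offline, startup time scales with the size of one shard rather than the full huggingface checkpoint.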
When loading Qwen2-7B-Instruct for the critic worker, the loading time is reduced from 109s to 25s, a 4.36x speedup.

The `converter_hf_to_mcore.py` in this version uses the existing online resharding function to convert weights; it should be refactored for better efficiency and for MoE/VLM models. Thanks to #998 for the optimization of loading the hf weights on GPU 0 only.
Future TODO:
- `megatron_checkpoint_manager.py` with dist ckpt
- `model_merger.py`