
Conversation


pytorch-bot bot commented May 1, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/125336

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit a4d47b1 with merge base 746da87:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

    curr_obj = getattr(curr_obj, curr_obj_name)
    if curr_obj_name == nn.modules.module._EXTRA_STATE_KEY_SUFFIX:
        if i != len(obj_names) - 1:
            raise RuntimeError("Expect `_extra_state` to be the last obj name")
Collaborator

Nit:

Suggested change:
-            raise RuntimeError("Expect `_extra_state` to be the last obj name")
+            raise ValueError("Expect `_extra_state` to be the last obj name")

Contributor Author

@fegin fegin May 7, 2024

This is not a case where the API receives an incorrect argument; rather, the traversal logic would be wrong. So ValueError doesn't seem any closer to the meaning; AssertionError may be a better fit. I'll change it in another PR.

Contributor

@LucasLLC LucasLLC left a comment

lgtm

@fegin
Contributor Author

fegin commented May 7, 2024

@pytorchbot merge -f "The failing tests are not related."

@pytorchmergebot
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as a last resort; instead consider -i/--ignore-current to continue the merge while ignoring current failures. This allows currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

pytorchmergebot pushed a commit that referenced this pull request May 7, 2024
[DSD] Fix to remove non_persistent buffer in distributed state dict (#125337)

Summary:
Fixes #122792

state_dict includes only persistent buffers, while named_buffers() would
include non_persistent buffers.

Pull Request resolved: #125337
Approved by: https://github.com/awgu
ghstack dependencies: #125333, #125501, #125334, #125335, #125336
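
To make that summary concrete, here is a minimal sketch (the module and buffer names are made up for illustration, not taken from the PR) of the persistent vs. non-persistent buffer distinction:

    import torch
    import torch.nn as nn

    class BufferedModule(nn.Module):
        def __init__(self):
            super().__init__()
            # Persistent buffer: included in state_dict() and restored on load.
            self.register_buffer("running_mean", torch.zeros(4))
            # Non-persistent buffer: reported by named_buffers() but excluded
            # from state_dict(), so a checkpoint never carries it.
            self.register_buffer("scratch", torch.zeros(4), persistent=False)

    m = BufferedModule()
    print(sorted(name for name, _ in m.named_buffers()))  # ['running_mean', 'scratch']
    print(sorted(m.state_dict().keys()))                  # ['running_mean']
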
mvpatel2000 pushed a commit to mvpatel2000/pytorch that referenced this pull request May 17, 2024
Summary:
distributed_state_dict should not try to use `getattr` to get `_extra_state` as this is not well-defined.

Pull Request resolved: pytorch#125336
Approved by: https://github.com/LucasLLC
ghstack dependencies: pytorch#125333, pytorch#125501, pytorch#125334, pytorch#125335
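
As background on that summary: `_extra_state` is a state_dict-only key produced by `get_extra_state()` and consumed by `set_extra_state()`; it is not a module attribute, so `getattr` cannot retrieve it. A minimal sketch (the module below is a made-up example, not code from the PR):

    import torch.nn as nn

    class ModuleWithExtraState(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(4, 4)
            self._step = 0  # bookkeeping that should travel with the checkpoint

        def get_extra_state(self):
            # nn.Module.state_dict() calls this and stores the result under the
            # "<prefix>_extra_state" key.
            return {"step": self._step}

        def set_extra_state(self, state):
            # nn.Module.load_state_dict() passes the saved value back here.
            self._step = state["step"]

    m = ModuleWithExtraState()
    print("_extra_state" in m.state_dict())  # True: it exists only as a state_dict key
    print(hasattr(m, "_extra_state"))        # False: getattr on the module would fail
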
atalman added a commit that referenced this pull request May 27, 2024
* [DSD] Correctly handle _extra_state (#125336)

Summary:
distributed_state_dict should not try to use `getattr` to get `_extra_state` as this is not well-defined.

Pull Request resolved: #125336
Approved by: https://github.com/LucasLLC
ghstack dependencies: #125333, #125501, #125334, #125335

* lint

* lint

---------

Co-authored-by: Chien-Chin Huang <chienchin@fb.com>
Co-authored-by: Andrey Talman <atalman@fb.com>
antoinebrl pushed a commit to antoinebrl/pytorch that referenced this pull request May 27, 2024
[DSD] Fix to remove non_persistent buffer in distributed state dict (pytorch#125337)

Summary:
Fixes pytorch#122792

state_dict includes only persistent buffers, while named_buffers() would
include non_persistent buffers.

Pull Request resolved: pytorch#125337
Approved by: https://github.com/awgu
ghstack dependencies: pytorch#125333, pytorch#125501, pytorch#125334, pytorch#125335, pytorch#125336
huydhn pushed a commit that referenced this pull request May 27, 2024
[DSD] Fix to remove non_persistent buffer in distributed state dict (#125337) (#127219)

* [DSD] Fix to remove non_persistent buffer in distributed state dict (#125337)

Summary:
Fixes #122792

state_dict includes only persistent buffers, while named_buffers() would
include non_persistent buffers.

Pull Request resolved: #125337
Approved by: https://github.com/awgu
ghstack dependencies: #125333, #125501, #125334, #125335, #125336

* lintrunner

* lint

---------

Co-authored-by: Chien-Chin Huang <chienchin@fb.com>
Co-authored-by: Andrey Talman <atalman@fb.com>
@j316chuck

j316chuck commented Jun 5, 2024

Hey @fegin! For this PR, we noticed a regression when loading the state dicts of models with tied weight embeddings. Here is the link to the broken test.

Error traceback:

_________________ test_algorithm_resumption[1-SeqLengthWarmup] _________________

tmp_path = PosixPath('/tmp/pytest-of-root/pytest-0/test_algorithm_resumption_1_Se1')
alg_cls = <class 'composer.algorithms.seq_length_warmup.seq_length_warmup.SeqLengthWarmup'>
world_size = 1

    @pytest.mark.gpu
    @pytest.mark.parametrize('alg_cls', get_algs_with_marks())
    @pytest.mark.filterwarnings(
        'ignore:Detected call of `lr_scheduler.step()',
    )  # optimizer.step() sometimes skipped when NaN/inf on low batch size
    @pytest.mark.filterwarnings(r'ignore:.*Plan failed with a cudnnException.*:UserWarning')  # Torch 2.3 regression
    @world_size(1, 2)
    def test_algorithm_resumption(
        tmp_path: pathlib.Path,
        alg_cls: Type[Algorithm],
        world_size,
    ):
        folder1 = os.path.join(tmp_path, 'folder1')
        folder2 = os.path.join(tmp_path, 'folder2')
        os.makedirs(folder1, exist_ok=True)
        os.makedirs(folder2, exist_ok=True)
    
        model = get_alg_model(alg_cls)
        alg_kwargs = get_alg_kwargs(alg_cls)
    
        copied_model = copy.deepcopy(model)  # copy the model so the params will start from the same point
    
        if alg_cls is LayerFreezing:
            pytest.xfail('Known issues')
    
        if alg_cls in (SAM, StochasticDepth):
            pytest.xfail('Mismatch in weights when resuming from a checkpoint.')
    
        if alg_cls is GyroDropout:
            pytest.xfail('GyroDropoutLayer is not implemented in a way that allows correct resumption.')
    
        if alg_cls is SWA and world_size > 1:
            pytest.xfail('SWA is not implemented in a way that is compatible correct resumption on multiple devices.')
    
        optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=5)
    
        shared_config = {
            'max_duration': '2ep',
            'save_filename': 'ep{epoch}-rank{rank}',
            'save_interval': '1ep',
            'train_subset_num_batches': 2,
            'precision': 'amp_bf16',
        }
        train_dataloader = get_alg_dataloader(alg_cls) if world_size == 1 else get_alg_dataloader(alg_cls, multigpu=True)
        # train model once, saving checkpoints every epoch
        trainer1 = Trainer(
            model=model,
            train_dataloader=train_dataloader,
            optimizers=optimizer,
            schedulers=scheduler,
            save_folder=folder1,
            algorithms=alg_cls(**alg_kwargs),
            **shared_config,
        )
        trainer1.fit()
    
        # create second trainer, load an intermediate checkpoint
        # and continue training
    
        optimizer = torch.optim.Adam(copied_model.parameters(), lr=0.01)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=5)
    
        alg = alg_cls(**alg_kwargs)
        # SeqLengthWarmup has a call to ._activate_model() that happens on the first call to the algorithm
        # in order to get complete matching of the rng state, we have to cause that extra call to be skipped
        # when reloading.
        if alg_cls is SeqLengthWarmup:
            alg._activated = True  # type: ignore
    
        train_dataloader = get_alg_dataloader(alg_cls) if world_size == 1 else get_alg_dataloader(alg_cls, multigpu=True)
>       trainer2 = Trainer(
            model=copied_model,
            train_dataloader=train_dataloader,
            load_path=os.path.join(folder1, 'ep1-rank{rank}'),
            load_weights_only=False,
            load_strict_model_weights=False,
            optimizers=optimizer,
            schedulers=scheduler,
            save_folder=folder2,
            algorithms=alg,
            **shared_config,
        )

/composer/tests/algorithms/test_algorithm_resumption.py:91: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
/composer/composer/trainer/trainer.py:1715: in __init__
    self._rng_state = checkpoint.load_checkpoint(
/composer/composer/utils/checkpoint.py:529: in load_checkpoint
    rng_state_dicts = _restore_checkpoint(
/composer/composer/utils/checkpoint.py:1006: in _restore_checkpoint
    state.load_state_dict(
/composer/composer/core/state.py:1425: in load_state_dict
    self.load_optim_state(state)
/composer/composer/core/state.py:1338: in load_optim_state
    set_optimizer_state_dict(
/composer/composer/trainer/mosaic_fsdp_utils.py:965: in set_optimizer_state_dict
    _load_optim_state_dict(model, optimizers, optim_state_dict, info)
/usr/lib/python3/dist-packages/torch/distributed/checkpoint/state_dict.py:589: in _load_optim_state_dict
    optim_state_dict = _split_optim_state_dict(model, optim, state_dict, info)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

model = HuggingFaceModel(
  (model): BertForMaskedLM(
    (bert): BertModel(
      (embeddings): BertEmbeddings(
        (word...se_affine=True)
        )
        (decoder): Linear(in_features=128, out_features=30522, bias=True)
      )
    )
  )
)
optim = Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    initial_lr: 0.01
    lr: 0.01
    maximize: False
    weight_decay: 0
)
optim_state_dict = {'param_groups': [{'amsgrad': False, 'betas': (0.9, 0.999), 'capturable': False, 'differentiable': False, ...}], 'stat...000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00]]), 'step': tensor(2.)}, ...}}
info = _StateDictInfo(full_state_dict=False, cpu_offload=False, ignore_frozen_params=False, keep_submodule_prefixes=True, str..._prefixes=set(), handle_model=False, handle_optim=True, fsdp_context=<class 'contextlib.nullcontext'>, fsdp_modules=[])

    def _split_optim_state_dict(
        model: nn.Module,
        optim: torch.optim.Optimizer,
        optim_state_dict: OptimizerStateType,
        info: _StateDictInfo,
    ) -> OptimizerStateType:
        """
        Extract the corresponding optim state_dict from ``optim_state_dict`` for
        ``optim`` and return the result optim state_dict.
    
        Args:
            model (nn.Module): the root model.
            optim (torch.optim.Optimizer): the optimizer.
            optim_state_dict (Dict[str, ValueType]): the superset optim state_dict that
                contains the optim state_dict of ``optim``.
            info (_StateDictInfo): state dict information.
    
        Returns:
            The optim state_dict of ``optim``.
        """
    
        state: DictValueType = {}
        pg_state: ListDictValueType = []
        return_osd: OptimizerStateType = {STATE: state, PG: pg_state}
        pg_mapping: Dict[int, int] = {}
    
        for param_group in optim.param_groups:
            pg_state.append({PARAMS: []})
            for param in param_group[PARAMS]:
                for fqn in info.fqn_param_mapping[param]:
                    params = pg_state[-1][PARAMS]
                    assert isinstance(params, list)
                    params.append(fqn)
                    if param.requires_grad:
>                       state[fqn] = cast(DictValueType, optim_state_dict[STATE])[fqn]
E                       KeyError: 'model.cls.predictions.decoder.weight'

/usr/lib/python3/dist-packages/torch/distributed/checkpoint/state_dict.py:559: KeyError
------------------------------ Captured log setup ------------------------------

We believe adding the following fix in our monkey patch of the _verify_option function addresses the issue:

        fqn_param_mapping: Dict[
            Union[str, torch.Tensor], Union[Set[str], torch.Tensor],
        ] = {}
        for name, param in chain(model.named_parameters(), model.named_buffers()):
            fqns = _get_fqns(model, name)
            fqn_param_mapping[param] = fqns
            for fqn in fqns:
                fqn_param_mapping[fqn] = param

        all_fqns = set()
        for name, _ in _iterate_valid_model_state(model):
            fqns = _get_fqns(model, name)
            for fqn in fqns:
                all_fqns.add(fqn)

Curious if you can help us upstream this composer bug fix into pytorch as well?
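
For reference, the tied-weight situation behind the KeyError can be reproduced with a toy model (hypothetical names, not the Composer test model); one parameter object ends up with two fully qualified names, which is why the patch above maps each parameter to a set of FQNs:

    import torch
    import torch.nn as nn

    class TinyLM(nn.Module):
        """Embedding and output projection tie (share) the same weight tensor."""
        def __init__(self, vocab=30522, dim=128):
            super().__init__()
            self.embed = nn.Embedding(vocab, dim)
            self.decoder = nn.Linear(dim, vocab, bias=False)
            self.decoder.weight = self.embed.weight  # weight tying

    model = TinyLM()

    # One tensor object, two fully qualified names. named_parameters() deduplicates
    # shared parameters by default, so keep duplicates to see both FQNs.
    fqns_by_param = {}
    for name, param in model.named_parameters(remove_duplicate=False):
        fqns_by_param.setdefault(param, set()).add(name)

    for fqns in fqns_by_param.values():
        print(sorted(fqns))  # ['decoder.weight', 'embed.weight']

If the saved optimizer state only keys on one of those names, looking it up under the other name raises the KeyError shown in the traceback above.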

Labels: ciflow/periodic, ciflow/trunk, Merged, oncall: distributed