
no attribute 'policy' when pushing PPOTrainer to hub using example ppo.py script #3301

@AMindToThink


Reproduction

I'm working on adding support for keeping the value model (critic model) that is created by PPOTrainer. I'm running into an odd issue where running the example script with the --push_to_hub flag results in an error, but the model is pushed anyway.

This might not seem like a problem, but when I try changing save_model so that both models can be saved, I get the error and then neither model is pushed.

I suspect this has something to do with save_model and push_to_hub re-entering each other rather than anything running in parallel: the traceback shows save_model → push_to_hub → save_model(_internal_call=True), and the second call hits self.model after it has already been swapped to the bare policy.

I would much appreciate some help!

(Additionally, the progress bars don't indicate when the push finishes! This confused me greatly)

Standard Command Run (straight from examples/scripts/ppo/ppo.py)

(trl) cs29824@sting-vm-1:~/matthew/trl$ python -i examples/scripts/ppo/ppo.py \
    --dataset_name trl-internal-testing/descriptiveness-sentiment-trl-style \
    --dataset_train_split descriptiveness \
    --learning_rate 3e-6 \
    --output_dir models/minimal/ppo_push_main_12 \
    --per_device_train_batch_size 64 \
    --gradient_accumulation_steps 1 \
    --total_episodes 1 \
    --model_name_or_path EleutherAI/pythia-1b-deduped \
    --missing_eos_penalty 1.0 \
    --push_to_hub

Resulting error from the standard command:

Traceback (most recent call last):
  File "/home/cs29824/matthew/trl/examples/scripts/ppo/ppo.py", line 166, in <module>
    trainer.save_model(training_args.output_dir)
  File "/home/cs29824/matthew/trl/trl/trainer/ppo_trainer.py", line 338, in save_model
    super().save_model(output_dir, _internal_call)
  File "/home/cs29824/matthew/trl/.venv/lib/python3.10/site-packages/transformers/trainer.py", line 3906, in save_model
    self.push_to_hub(commit_message="Model save")
  File "/home/cs29824/matthew/trl/.venv/lib/python3.10/site-packages/transformers/trainer.py", line 4816, in push_to_hub
    self.save_model(_internal_call=True)
  File "/home/cs29824/matthew/trl/trl/trainer/ppo_trainer.py", line 332, in save_model
    self.model = self.model.policy  # save only the policy
  File "/home/cs29824/matthew/trl/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1928, in __getattr__
    raise AttributeError(
AttributeError: 'GPTNeoXForCausalLM' object has no attribute 'policy'
training_args.bin: 100%|█| 6.20k/6.20k [00:00<00:00, 91.4kB
model.safetensors: 100%|█| 649M/649M [00:15<00:00, 40.7MB/s
Upload 2 LFS files: 100%|████| 2/2 [00:16<00:00,  8.10s/it]
Upload 2 LFS files:  50%|██  | 1/2 [00:16<00:16, 16.19s/it]
>>> 
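
To make the failure mode concrete, here is a minimal self-contained sketch of the re-entrancy pattern I believe the traceback shows. The classes are hypothetical stand-ins, not the real TRL/transformers code, and the push_to_hub hop is collapsed into the base class's save_model:

    # BaseTrainer stands in for transformers.Trainer: when this is not an
    # internal call, it re-enters save_model (via push_to_hub in the real code).
    class BaseTrainer:
        def save_model(self, _internal_call=False):
            if not _internal_call:
                # transformers: save_model -> push_to_hub -> save_model(_internal_call=True)
                self.save_model(_internal_call=True)

    class Policy:
        pass  # a bare policy has no .policy attribute

    class FullModel:
        policy = Policy()

    # PPOLikeTrainer stands in for trl's PPOTrainer.save_model override.
    class PPOLikeTrainer(BaseTrainer):
        def __init__(self, model):
            self.model = model

        def save_model(self, _internal_call=False):
            backup_model = self.model
            self.model = self.model.policy      # first call: full model -> policy
            super().save_model(_internal_call)  # re-enters this method via the base class
            self.model = backup_model           # never reached on the failing path

    PPOLikeTrainer(FullModel()).save_model()
    # AttributeError: 'Policy' object has no attribute 'policy'

Note that on the failing path self.model is never restored, so anything that runs after the error sees the trainer in a half-swapped state.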

Change I want to make

In addition to the save_model change below, I modified PPOConfig to include a save_value_model boolean.
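
Roughly like this (a sketch only; the field name matches my change, but the default and help text here are assumptions):

    from dataclasses import field

    # New field inside the PPOConfig dataclass (sketch; the default and
    # help text are my assumptions):
    save_value_model: bool = field(
        default=False,
        metadata={"help": "Whether to also save the value model (critic) alongside the policy."},
    )

And here is my modified save_model: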

    def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
        # Temporarily swap in the policy so the parent Trainer saves only it.
        backup_model = self.model
        self.model = self.model.policy  # save only the policy
        if self.is_deepspeed_enabled:
            backup_deepspeed = self.deepspeed
            self.deepspeed = self.model

        # When the value model is also saved, give each model its own subdirectory.
        policy_output_dir = os.path.join(output_dir, "policy_model") if self.args.save_value_model else output_dir
        super().save_model(policy_output_dir, _internal_call)

        self.model = backup_model
        if self.is_deepspeed_enabled:
            self.deepspeed = backup_deepspeed

        if self.args.save_value_model:
            # Same swap-save-restore dance for the value model (critic).
            self.model = self.model.value_model
            if self.is_deepspeed_enabled:
                backup_deepspeed = self.deepspeed
                self.deepspeed = self.model
            super().save_model(os.path.join(output_dir, "value_model"), _internal_call)
            self.model = backup_model
            if self.is_deepspeed_enabled:
                self.deepspeed = backup_deepspeed
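
With this change, and assuming the save_value_model flag sketched above, I would expect the output directory to look like:

    models/minimal/ppo_push_main_12/
    ├── policy_model/   # policy weights, config, tokenizer files
    └── value_model/    # value model (critic) weights

With save_value_model=False the behavior is unchanged: only the policy is saved, directly to output_dir.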

System Info

  • Platform: Linux-5.15.0-1073-kvm-x86_64-with-glibc2.35
  • Python version: 3.10.12
  • TRL version: 0.17.0.dev0+df737f9
  • PyTorch version: 2.6.0
  • CUDA device(s): Quadro RTX 8000, Quadro RTX 8000
  • Transformers version: 4.51.3
  • Accelerate version: 1.3.0
  • Accelerate config:
    • compute_environment: LOCAL_MACHINE
    • distributed_type: NO
    • mixed_precision: fp16
    • use_cpu: False
    • debug: False
    • num_processes: 1
    • machine_rank: 0
    • num_machines: 1
    • gpu_ids: 0
    • rdzv_backend: static
    • same_network: True
    • main_training_function: main
    • enable_cpu_affinity: False
    • downcast_bf16: no
    • tpu_use_cluster: False
    • tpu_use_sudo: False
    • tpu_env: []
  • Datasets version: 3.5.0
  • HF Hub version: 0.30.2
  • bitsandbytes version: 0.45.5
  • DeepSpeed version: 0.16.5
  • Diffusers version: 0.32.2
  • Liger-Kernel version: 0.5.8
  • LLM-Blender version: 0.0.2
  • OpenAI version: 1.74.0
  • PEFT version: 0.15.2
  • vLLM version: 0.8.4

Checklist

  • I have checked that my issue isn't already filed (see open issues)
  • I have included my system information
  • Any code provided is minimal, complete, and reproducible (more on MREs)
  • Any code provided is properly formatted in code blocks (no screenshots; more on code blocks)
  • Any traceback provided is complete

Metadata

    Labels

    🏋 PPO (Related to PPO) · 🐛 bug (Something isn't working)
