[GRPO] Gradient mode issue when num_iterations > 1 #2953

@willccbb

Description

When setting num_iterations > 1, I get a runtime error during the first training step because inference-mode tensors end up being used in autograd. The same run works fine with num_iterations = 1.

Error:

[rank0]:     return inner()
[rank0]:            ^^^^^^^
[rank0]:   File "/workspace/projects/finetuning/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in inner
[rank0]:     result = forward_call(*args, **kwargs)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/workspace/projects/finetuning/.venv/lib/python3.11/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 256, in forward
[rank0]:     hidden_states = self.input_layernorm(hidden_states)
[rank0]:                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/workspace/projects/finetuning/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/workspace/projects/finetuning/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1844, in _call_impl
[rank0]:     return inner()
[rank0]:            ^^^^^^^
[rank0]:   File "/workspace/projects/finetuning/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in inner
[rank0]:     result = forward_call(*args, **kwargs)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/workspace/projects/finetuning/.venv/lib/python3.11/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 222, in forward
[rank0]:     return self.weight * hidden_states.to(input_dtype)
[rank0]:            ~~~~~~~~~~~~^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
[rank0]: RuntimeError: Inference tensors cannot be saved for backward. To work around you can make a clone to get a normal tensor and use it in autograd.
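For what it's worth, the error itself is not GRPO-specific: it is the generic PyTorch failure you get when a tensor created under torch.inference_mode() is later fed through a module whose parameters require grad, so autograd tries (and is not allowed) to save it for backward. Presumably something cached during generation is being reused here when num_iterations > 1. A minimal standalone illustration of the same error class (not the TRL code path, just the mechanism):

import torch

lin = torch.nn.Linear(4, 4)        # parameters require grad

with torch.inference_mode():
    x = torch.randn(2, 4)          # x is an inference tensor

try:
    # autograd must save x to compute the weight gradient, which is forbidden
    # for inference tensors, so this raises the same RuntimeError as above
    lin(x)
except RuntimeError as e:
    print(e)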

Args:

...
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import GRPOConfig, GRPOTrainer

model_name = "Qwen/Qwen2.5-7B-Instruct"

model_kwargs = dict(
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    use_cache=False
)
training_args = GRPOConfig(
    output_dir=output_dir,
    run_name=run_name,
    learning_rate=2e-6,
    beta=0.02,
    adam_beta1=0.9,
    adam_beta2=0.99,
    warmup_ratio=0.1,
    lr_scheduler_type='cosine',
    logging_steps=1,
    bf16=True,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=2,
    num_generations=2,
    num_iterations=2,
    max_prompt_length=1024,
    max_completion_length=512,
    num_train_epochs=1,
    save_strategy='epoch',
    save_only_model=True,
    max_grad_norm=1.0,
    report_to="wandb",
    use_vllm=True,
    vllm_gpu_memory_utilization=0.7,
    vllm_device="cuda:7",
    log_on_each_node=False,
)

#model = AutoLigerKernelForCausalLM.from_pretrained(
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    **model_kwargs
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token

trainer = GRPOTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=[
        xmlcount_reward_func,
        format_reward_func,
        int_reward_func,
        close_reward_func,
        correctness_reward_func],
    args=training_args,
    train_dataset=dataset,
)
trainer.train()
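As the error message suggests, cloning an inference tensor outside of inference mode yields a normal tensor that autograd can save. The real fix belongs in the trainer, but a tiny defensive helper of that shape (purely hypothetical, not part of TRL) is what I'd expect the workaround to look like:

import torch

def escape_inference_mode(t: torch.Tensor) -> torch.Tensor:
    # clone() materializes an inference-mode tensor as a normal tensor with the
    # same values, so it can be saved for backward; normal tensors pass through.
    return t.clone() if t.is_inference() else t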

Launch command (8x H100: 7 GPUs for training, 1 reserved for vLLM):

accelerate launch --num-processes 7 --config-file configs/accelerate/zero3.yaml grpo_demo.py
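To narrow down which cached tensor ends up as an inference tensor when num_iterations > 1, a quick check like the following (hypothetical debugging aid, dropped wherever the batch/inputs dict is visible before the forward pass) can help:

import torch

def report_inference_tensors(tensors: dict) -> None:
    # Flags every tensor that was created under torch.inference_mode(); any
    # True here is a candidate for the clone() workaround above.
    for name, t in tensors.items():
        if isinstance(t, torch.Tensor):
            print(f"{name}: is_inference={t.is_inference()}")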

Labels: 🏋 GRPO (Related to GRPO), 🐛 bug (Something isn't working)
