Prototype multi-gpu support with PPO #162

Closed
wants to merge 18 commits into from

Conversation

vwxyzjn (Owner) commented Apr 16, 2022

Description

This PR contains some prototypes that bring multi-GPU support to PPO. There are many ways to do it, so this PR tries to compare different approaches.

I don't really have multiple GPUs to test this out, so I launch two processes that share the same GPU, mainly to check whether they produce the same performance. In theory, multi-GPU support should not harm performance. However, I plan to verify on a real multi-GPU setup that sample efficiency is not affected.

Option 1: ppo_atari_multigpu.py

My first try is ppo_atari_multigpu.py, which uses PyTorch's low-level distributed API, as shown in this link or this example:

""" Gradient averaging. """
def average_gradients(model):
    size = float(dist.get_world_size())
    for param in model.parameters():
        dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
        param.grad.data /= size
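
As a rough sketch of where this hook sits (my reading, assuming the usual CleanRL names agent, optimizer, loss, nn, and args.max_grad_norm), it is called between the backward pass and the optimizer step:

optimizer.zero_grad()
loss.backward()
average_gradients(agent)  # all_reduce-sum the gradients, then divide by the world size
nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm)
optimizer.step()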

Option 2: ppo_atari_multigpu_batch_reduce.py

In this file, I adopted the approach from entity-neural-network/incubator#220, batch-reducing the gradients.
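
For reference, here is a minimal sketch of the batch-reduce idea (my reconstruction, not the exact code from that PR, with torch and dist assumed imported as above): flatten every gradient into one buffer so that only a single all_reduce call is needed, then copy the averaged values back.

def batch_average_gradients(model):
    """Average gradients across processes with a single all_reduce call."""
    grads = [param.grad.data for param in model.parameters() if param.grad is not None]
    flat = torch.cat([g.view(-1) for g in grads])  # one flat buffer for all gradients
    dist.all_reduce(flat, op=dist.ReduceOp.SUM)
    flat /= dist.get_world_size()
    offset = 0
    for g in grads:  # unflatten: copy the averaged values back into each .grad
        g.copy_(flat[offset : offset + g.numel()].view_as(g))
        offset += g.numel()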

Option 3: ppo_atari_ddp.py

In this file, I adopted the high-level DistributedDataParallel API: https://pytorch.org/docs/stable/notes/ddp.html
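
Roughly, the wrapping looks like the sketch below (my paraphrase of the DDP notes, with rank, size, Agent, envs, and device assumed from the CleanRL script):

import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

dist.init_process_group("gloo", rank=rank, world_size=size)
agent = DDP(Agent(envs).to(device))
# DDP registers backward hooks, so gradients are averaged across processes
# automatically during loss.backward(); no manual all_reduce is needed.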

Option 4: ppo_atari_elastic.py

This file adopts https://pytorch.org/docs/stable/elastic/run.html. We can run the training script via torchrun --standalone --nnodes=1 --nproc_per_node=2 ppo_atari_elastic.py.
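
With torchrun, each worker reads its rank from environment variables instead of being spawned manually; a minimal sketch of the setup I would expect at the top of the script:

import os

import torch
import torch.distributed as dist

local_rank = int(os.environ["LOCAL_RANK"])  # set by torchrun for each worker
world_size = int(os.environ["WORLD_SIZE"])
dist.init_process_group("gloo")  # rank and world size are picked up from the environment
device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu")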

[image]

The sample efficiency seems to suffer, as shown above; however, the wall-time performance is pretty good.

My suspicion for the sample-efficiency regression is that averaging the policy gradient is trickier than averaging the value gradient: see #162 (comment)

Options 3 and 4 are pretty impressive: they cut the wall time in half while using only a single GPU. Maybe the speedup could be even larger with actual multiple GPUs?

Types of changes

  • New feature
  • New algorithm
  • Documentation

Checklist:

  • I've read the CONTRIBUTION guide (required).
  • I have ensured pre-commit run --all-files passes (required).
  • I have updated the documentation and previewed the changes via mkdocs serve.
  • I have updated the tests accordingly (if applicable).

If you are adding new algorithms or your change could result in a performance difference, you may need to (re-)run tracked experiments. See #137 as an example PR.

  • I have contacted @vwxyzjn to obtain access to the openrlbenchmark W&B team (required).
  • I have tracked applicable experiments in openrlbenchmark/cleanrl with --capture-video flag toggled on (required).
  • I have added additional documentation and previewed the changes via mkdocs serve.
    • I have explained note-worthy implementation details.
    • I have explained the logged metrics.
    • I have added links to the original paper and related papers (if applicable).
    • I have added links to the PR related to the algorithm.
    • I have created a table comparing my results against those from reputable sources (i.e., the original paper or other reference implementation).
    • I have added the learning curves (in PNG format with width=500 and height=300).
    • I have added links to the tracked experiments.
  • I have updated the tests accordingly (if applicable).

vwxyzjn (Owner, Author) commented Apr 17, 2022

Here is a script for understanding how gradient accumulation and data parallelism work:

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def init_process(rank, size, fn, backend="gloo"):
    """Initialize the distributed environment."""
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group(backend, rank=rank, world_size=size)
    fn(rank, size)


def train(rank: int, size: int):
    # each process computes the gradient on its own row of the batch
    prediction = torch.tensor(
        [[1.0, 2.0, 3.0, 4.0], [4.0, 7.0, 5.0, 8.0]],
        requires_grad=True,
    )
    label = torch.tensor([[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]])
    loss = (prediction[rank] - label[rank]) ** 2
    loss.mean().backward()
    # sum the gradients across processes, then average them
    dist.all_reduce(prediction.grad.data, op=dist.ReduceOp.SUM)
    prediction.grad.data /= size
    print("gradient with data parallelism (multi-gpu) \n", prediction.grad)


if __name__ == "__main__":
    # baseline: gradient of the whole batch in a single backward pass
    prediction = torch.tensor(
        [[1.0, 2.0, 3.0, 4.0], [4.0, 7.0, 5.0, 8.0]],
        requires_grad=True,
    )
    label = torch.tensor([[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]])
    loss = (prediction - label) ** 2
    loss.mean().backward()
    print("gradient with the whole batch\n", prediction.grad)

    # gradient accumulation: backward pass in two minibatches,
    # then divide the accumulated gradient by 2
    prediction = torch.tensor(
        [[1.0, 2.0, 3.0, 4.0], [4.0, 7.0, 5.0, 8.0]],
        requires_grad=True,
    )
    label = torch.tensor([[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]])
    loss = (prediction[0] - label[0]) ** 2
    loss.mean().backward()
    loss = (prediction[1] - label[1]) ** 2
    loss.mean().backward()
    print("gradient accumulation \n", prediction.grad / 2)

    # data parallelism: one process per row, gradients averaged with all_reduce
    size = 2
    processes = []
    mp.set_start_method("spawn")
    for rank in range(size):
        p = mp.Process(target=init_process, args=(rank, size, train))
        p.start()
        processes.append(p)

    for p in processes:
        p.join()

The script yields the following output:

gradient with the whole batch
 tensor([[0.2500, 0.5000, 0.7500, 1.0000],
        [1.0000, 1.7500, 1.2500, 2.0000]])
gradient accumulation 
 tensor([[0.2500, 0.5000, 0.7500, 1.0000],
        [1.0000, 1.7500, 1.2500, 2.0000]])
gradient with data parallelism (multi-gpu) 
 tensor([[0.2500, 0.5000, 0.7500, 1.0000],
        [1.0000, 1.7500, 1.2500, 2.0000]])
gradient with data parallelism (multi-gpu) 
 tensor([[0.2500, 0.5000, 0.7500, 1.0000],
        [1.0000, 1.7500, 1.2500, 2.0000]])

vwxyzjn (Owner, Author) commented Apr 17, 2022

Attempt 1

[image]

The sample efficiency seems to suffer, as shown above; however, the wall-time performance is pretty good.

My suspicion for the sample-efficiency regression is that averaging the policy gradient is trickier than averaging the value gradient: see #162 (comment)

Options 3 and 4 are pretty impressive: they cut the wall time in half while using only a single GPU. Maybe the speedup could be even larger with actual multiple GPUs?

Some notes

ppo_atari.py does a single forward and backward pass on a minibatch of size 256, whereas ppo_atari_multigpu.py does two forward and backward passes in two separate processes on two minibatches of size 128, calls dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM) to sum the gradients, and then divides the gradients by the number of processes (2). This is basically gradient accumulation, right?

The following script (see here for the full script) demonstrates that this practice results in the same gradient for the value function, but not for the policy function.

optimizer.zero_grad()
start = 0
end = start + args.minibatch_size
mb_inds = b_inds[start:end]
fit_vloss(mb_inds)
print()
print(f"CASE 1: value function: forward and backward pass of the minibatch (size 256: i.e., data[0:256])")
print("agent.critic.weight.grad.sum() =", agent.critic.weight.grad.sum())

optimizer.zero_grad()
args.minibatch_size = 128
for start in [0, 128]:
    end = start + args.minibatch_size
    mb_inds = b_inds[start:end]
    fit_vloss(mb_inds)
print()
print(f"CASE 2: value function: forward and backward pass of 2 minibatches (size 128: i.e., data[0:128] and data[128:256])")
print("agent.critic.weight.grad.sum() / 2 =", agent.critic.weight.grad.sum() / 2)

optimizer.zero_grad()
start = 0
end = start + args.minibatch_size
mb_inds = b_inds[start:end]
fit_pgloss(mb_inds)
print()
print(f"CASE 3: policy function: forward and backward pass of the minibatch (size 256: i.e., data[0:256])")
print("agent.actor.weight.grad.sum() =", agent.actor.weight.grad.sum())

optimizer.zero_grad()
args.minibatch_size = 128
for start in [0, 128]:
    end = start + args.minibatch_size
    mb_inds = b_inds[start:end]
    fit_pgloss(mb_inds)
print()
print(f"CASE 4: policy function: forward and backward pass of 2 minibatches (size 128: i.e., data[0:128] and data[128:256])")
print("agent.actor.weight.grad.sum() / 2 =", agent.actor.weight.grad.sum() / 2)
The output:

CASE 1: value function: forward and backward pass of the minibatch (size 256: i.e., data[0:256])
agent.critic.weight.grad.sum() = tensor(-3.1651, device='cuda:0')

CASE 2: value function: forward and backward pass of 2 minibatches (size 128: i.e., data[0:128] and data[128:256])
agent.critic.weight.grad.sum() / 2 = tensor(-3.1651, device='cuda:0')

As shown, their gradients are the same; this is essentially gradient accumulation.

CASE 3: policy function: forward and backward pass of the minibatch (size 256: i.e., data[0:256])
agent.actor.weight.grad.sum() = tensor(-1.1659e-07, device='cuda:0')

CASE 4: policy function: forward and backward pass of 2 minibatches (size 128: i.e., data[0:128] and data[128:256])
agent.actor.weight.grad.sum() / 2 = tensor(-3.1557e-07, device='cuda:0')

vwxyzjn (Owner, Author) commented Apr 17, 2022

Here is an even simpler demo. The issue is that Categorical.log_prob does not work with gradient accumulation. See pytorch/pytorch#75948.

import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions.categorical import Categorical

class Agent(nn.Module):
    def __init__(self, action_n):
        super().__init__()
        self.network = nn.Sequential(
            nn.Conv2d(4, 32, 8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, stride=1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(64 * 7 * 7, 512),
            nn.ReLU(),
        )
        self.actor = nn.Linear(512, action_n)
        self.critic = nn.Linear(512, 1)

    def get_value(self, x):
        return self.critic(self.network(x / 255.0))

    def get_action_and_value(self, x, action=None):
        hidden = self.network(x / 255.0)
        logits = self.actor(hidden)
        probs = Categorical(logits=logits)
        if action is None:
            action = probs.sample()
        return action, probs.log_prob(action), probs.entropy(), self.critic(hidden)

# setup
agent = Agent(4)
optimizer = optim.Adam(agent.parameters())
next_obs = torch.rand(8, 4, 84, 84)
action, newlogprob, entropy, newvalue = agent.get_action_and_value(next_obs)

optimizer.zero_grad()
_, newlogprob, _, newvalue = agent.get_action_and_value(next_obs, action)
newvalue.mean().backward()
print(f"`agent.critic.weight.grad.sum() = {agent.critic.weight.grad.sum()}` after fitting value loss using data[0:8]")

optimizer.zero_grad()
_, newlogprob, _, newvalue = agent.get_action_and_value(next_obs[0:4], action[0:4])
newvalue.mean().backward()
_, newlogprob, _, newvalue = agent.get_action_and_value(next_obs[4:8], action[4:8])
newvalue.mean().backward()
print(f"`agent.critic.weight.grad.sum() / 2 = {agent.critic.weight.grad.sum() / 2}` after fitting value loss using data[0:4] and data[4:8] respectively")

optimizer.zero_grad()
_, newlogprob, _, newvalue = agent.get_action_and_value(next_obs, action)
newlogprob.mean().backward()
print(f"`agent.actor.weight.grad.sum() = {agent.actor.weight.grad.sum()}` after fitting value loss using data[0:8]")

optimizer.zero_grad()
_, newlogprob, _, newvalue = agent.get_action_and_value(next_obs[0:4], action[0:4])
newlogprob.mean().backward()
_, newlogprob, _, newvalue = agent.get_action_and_value(next_obs[4:8], action[4:8])
newlogprob.mean().backward()
print(f"`agent.actor.weight.grad.sum() / 2 = {agent.actor.weight.grad.sum() / 2}` after fitting value loss using data[0:4] and data[4:8] respectively")

vwxyzjn added 2 commits on April 17, 2022, including one with the message: "otherwise agent.module.get_action_value won't trigger proper gradient sync"
vwxyzjn (Owner, Author) commented Apr 17, 2022

Attempt 2: much more successful

[image]

Fixed a couple of issues:

  • In ppo_atari_multigpu.py and ppo_atari_multigpu_batch_reduce.py, I needed to set different environment random seeds in the separate processes (see the sketch after this list). Previously I didn't, so the two processes collected identical experience and the script was effectively learning from only half of the data.
  • In ppo_atari_ddp.py and ppo_atari_elastic.py, for some reason agent.module.get_action_and_value does not trigger proper gradient synchronization across processes. I needed to rename get_action_and_value to forward so the call goes through the DDP wrapper.
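
A sketch of the per-process seeding fix from the first bullet (the exact offset scheme is my assumption; gym, make_env, args, rank, and run_name are as in the CleanRL scripts):

# give every process a disjoint set of environment seeds so the collected experience differs
envs = gym.vector.SyncVectorEnv(
    [
        make_env(args.env_id, args.seed + rank * args.num_envs + i, i, args.capture_video, run_name)
        for i in range(args.num_envs)
    ]
)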

args.num_envs = int(args.num_envs / size)
args.batch_size = int(args.num_envs * args.num_steps)
args.minibatch_size = int(args.batch_size // args.num_minibatches)
dist.init_process_group("gloo", rank=rank, world_size=size)
A review comment on the dist.init_process_group("gloo", ...) line above:

According to the PyTorch docs, you should be using NCCL for multi-GPU training; Gloo is recommended for multi-CPU training: https://pytorch.org/docs/stable/distributed.html

vwxyzjn (Owner, Author) replied:

I don't actually have a multi-GPU setup to try it out, and NCCL would break :) but this is something I should find a solution for.
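
One possible way to handle this without a multi-GPU machine on hand (my sketch, not something tested in this PR) is to pick the backend from CUDA availability:

import torch
import torch.distributed as dist

backend = "nccl" if torch.cuda.is_available() else "gloo"  # NCCL for GPUs, Gloo for CPU-only
dist.init_process_group(backend, rank=rank, world_size=size)  # rank/size as in the training script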

vwxyzjn (Owner, Author) commented Apr 24, 2022

[image]

https://wandb.ai/costa-huang/cleanRL/reports/Data-Parallelism-Experiment--VmlldzoxODI1OTY0

OK, I did more testing, and it looks like ppo_atari_multigpu_batch_reduce.py has the highest performance. I do need to find an actual multi-GPU machine to test this out, though. There are a couple of settings:

  • w/wo actual multi-GPUs
  • gloo vs nccl backend.

It is worth testing both ppo_atari_multigpu_batch_reduce.py and ppo_atari_elastic.py under these settings.

@vwxyzjn vwxyzjn mentioned this pull request May 4, 2022
vwxyzjn (Owner, Author) commented May 5, 2022

Closing in favor of #178
