
Various minor PPO refactors #167

@vwxyzjn

Description


Problem Description

Many of these formatting changes were suggested by @Howuhh.

1. next_done refactor

The current code for handling done looks like this:

            next_obs, reward, done, info = envs.step(action.cpu().numpy())
            rewards[step] = torch.tensor(reward).to(device).view(-1)
            next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor(done).to(device)

which is fine, but it became an issue when I tried to adapt the code for IsaacGym. Since IsaacGym already returns tensors on the device, I assumed the to(device) conversion was no longer needed and just did

            next_obs, reward, done, info = envs.step(action)

but this is wrong: dropping the conversion lines also drops the assignment next_done = done, so next_done is never updated. Naming the step output done and only later converting it via next_done = torch.Tensor(done).to(device) obscures that the value is really next_done all along.

We should refactor it to

            next_obs, reward, next_done, info = envs.step(action.cpu().numpy())
            rewards[step] = torch.tensor(reward).to(device).view(-1)
            next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor(next_done).to(device)
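To see why the rename matters, here is a minimal runnable sketch (DummyVecEnv is a hypothetical stand-in for an IsaacGym-style vectorized env whose step outputs are already usable as-is): once the third return value is bound directly to next_done, the two conversion lines can simply be deleted without leaving next_done stale.

```python
class DummyVecEnv:
    """Hypothetical stand-in for a vectorized env that, like IsaacGym,
    returns step outputs that need no further conversion."""

    def __init__(self, num_envs):
        self.num_envs = num_envs

    def step(self, action):
        next_obs = [0.0] * self.num_envs
        reward = [1.0] * self.num_envs
        next_done = [False] * self.num_envs
        return next_obs, reward, next_done, {}


envs = DummyVecEnv(num_envs=4)
action = [0] * 4

# Because the step output is already named next_done, dropping the
# tensor-conversion lines cannot silently leave next_done unassigned:
next_obs, reward, next_done, info = envs.step(action)
```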

2. make_env refactor

if capture_video:
    if idx == 0:
        env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")

to

if capture_video and idx == 0:
    env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")

3. flatten batch

        b_obs = obs.reshape((-1,) + envs.single_observation_space.shape)
        b_logprobs = logprobs.reshape(-1)
        b_actions = actions.reshape((-1,) + envs.single_action_space.shape)
        b_advantages = advantages.reshape(-1)
        b_returns = returns.reshape(-1)
        b_values = values.reshape(-1)

to

        b_obs = obs.flatten(0, 1)
        b_actions = actions.flatten(0, 1)
        b_logprobs = logprobs.reshape(-1)
        b_returns = returns.reshape(-1)
        b_advantages = advantages.reshape(-1)
        b_values = values.reshape(-1)
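The two spellings are equivalent; flatten(0, 1) just merges the (num_steps, num_envs) leading dims without restating the observation/action shape. A small numpy analogue (shapes are illustrative; numpy has no flatten(0, 1), so the merged-dims form is written out explicitly):

```python
import numpy as np

# Illustrative rollout buffer shape: (num_steps, num_envs, *obs_shape)
num_steps, num_envs, obs_shape = 128, 4, (8,)
obs = np.arange(num_steps * num_envs * 8, dtype=np.float32).reshape(
    (num_steps, num_envs) + obs_shape
)

# Old style: reshape with the trailing shape restated explicitly.
b_obs_reshape = obs.reshape((-1,) + obs_shape)

# New style: merge only the first two dims, which is what torch's
# obs.flatten(0, 1) does without mentioning obs_shape at all.
b_obs_flatten = obs.reshape((num_steps * num_envs,) + obs.shape[2:])
```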

4. target_kl check


            if args.target_kl is not None:
                if approx_kl > args.target_kl:
                    break

to

            if args.target_kl is not None and approx_kl > args.target_kl:
                break

5. global_step increment

global_step += 1 * args.num_envs

to

global_step += args.num_envs

6. num_updates placement

move

num_updates = args.total_timesteps // args.batch_size

to the argparse section, so it lives next to the arguments it is derived from.
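A minimal sketch of what that could look like (flag names and defaults are illustrative, not the actual cleanrl values): the derived quantities are computed once, right after parse_args, instead of being recomputed later in the training script.

```python
import argparse


def parse_args(argv=None):
    parser = argparse.ArgumentParser()
    # illustrative defaults, not the actual cleanrl ones
    parser.add_argument("--total-timesteps", type=int, default=25000)
    parser.add_argument("--num-envs", type=int, default=4)
    parser.add_argument("--num-steps", type=int, default=128)
    args = parser.parse_args(argv)
    # derived quantities live next to the arguments they depend on
    args.batch_size = args.num_envs * args.num_steps
    args.num_updates = args.total_timesteps // args.batch_size
    return args
```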
