Closed
Problem Description
Many of the formatting changes below were suggested by @Howuhh.
1. Refactor the next_done handling
The current code to handle done looks like this:
next_obs, reward, done, info = envs.step(action.cpu().numpy())
rewards[step] = torch.tensor(reward).to(device).view(-1)
next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor(done).to(device)
This works, but it became a problem when I tried to adapt the code to isaacgym. Specifically, I assumed the to(device) conversion was no longer needed, so I simply wrote
next_obs, reward, done, info = envs.step(action)
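but this is wrong, because I should also have written next_done = done. The current line next_done = torch.Tensor(done).to(device) is confusing: the step output is named done while the rest of the loop reads next_done, so dropping the conversion silently leaves next_done stuck at its previous value.
We should refactor it to: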
next_obs, reward, next_done, info = envs.step(action.cpu().numpy())
rewards[step] = torch.tensor(reward).to(device).view(-1)
next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor(next_done).to(device)
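Why unpacking straight into next_done is safer: the rollout records the done flag of the current observation before stepping (in the current ppo.py loop this is dones[step] = next_done), so next_done must be refreshed on every iteration. The sketch below uses NumPy and a hypothetical DummyVecEnv (not a gym or CleanRL class) whose step returns (obs, reward, done, info) like the old gym vector API; the astype call stands in for the torch.Tensor(...).to(device) conversion.

```python
import numpy as np


class DummyVecEnv:
    """Hypothetical stand-in for a vectorized env (not a real gym class).

    Sub-environment i reports done every (i + 2) steps, so flags differ per env.
    """

    def __init__(self, num_envs):
        self.num_envs = num_envs
        self.t = 0

    def reset(self):
        self.t = 0
        return np.zeros((self.num_envs, 3), dtype=np.float32)

    def step(self, action):
        self.t += 1
        obs = np.full((self.num_envs, 3), self.t, dtype=np.float32)
        reward = np.ones(self.num_envs, dtype=np.float32)
        done = np.array([self.t % (i + 2) == 0 for i in range(self.num_envs)])
        return obs, reward, done, {}


def rollout(envs, num_steps):
    next_obs = envs.reset()
    next_done = np.zeros(envs.num_envs, dtype=np.float32)
    dones = np.zeros((num_steps, envs.num_envs), dtype=np.float32)
    for step in range(num_steps):
        # Store the done flag that belongs to the observation we act on.
        dones[step] = next_done
        action = np.zeros(envs.num_envs)  # placeholder policy
        # Unpack directly into next_done so it is refreshed every step,
        # whether or not a conversion line follows.
        next_obs, reward, next_done, info = envs.step(action)
        # Stand-in for torch.Tensor(next_done).to(device); deleting this line
        # no longer leaves next_done stale.
        next_done = next_done.astype(np.float32)
    return dones


dones = rollout(DummyVecEnv(2), 4)  # rows are steps, columns are envs
```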
2. make_env refactor
if capture_video:
    if idx == 0:
        env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
to
if capture_video and idx == 0:
    env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
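The two forms behave identically because a nested if with no else branch is equivalent to one condition joined by and. A quick sanity check in plain Python, using hypothetical helper names in place of the body of make_env:

```python
from itertools import product


def records_video_nested(capture_video, idx):
    # Original structure: two nested conditions.
    if capture_video:
        if idx == 0:
            return True
    return False


def records_video_flat(capture_video, idx):
    # Proposed structure: one combined condition.
    if capture_video and idx == 0:
        return True
    return False


# Only the first sub-environment records video, and only when requested.
for capture_video, idx in product([False, True], range(4)):
    assert records_video_nested(capture_video, idx) == records_video_flat(capture_video, idx)
```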
3. flatten batch
b_obs = obs.reshape((-1,) + envs.single_observation_space.shape)
b_logprobs = logprobs.reshape(-1)
b_actions = actions.reshape((-1,) + envs.single_action_space.shape)
b_advantages = advantages.reshape(-1)
b_returns = returns.reshape(-1)
b_values = values.reshape(-1)
to
b_obs = obs.flatten(0, 1)
b_actions = actions.flatten(0, 1)
b_logprobs = logprobs.reshape(-1)
b_returns = returns.reshape(-1)
b_advantages = advantages.reshape(-1)
b_values = values.reshape(-1)
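flatten(0, 1) merges the leading (num_steps, num_envs) axes and keeps the trailing shape from the tensor itself, so envs.single_observation_space.shape is no longer needed. NumPy has no start/end-dim flatten, so this sketch uses a hypothetical helper flatten_first_two that reproduces what torch's flatten(0, 1) does, and checks the step-major layout of the resulting batch:

```python
import numpy as np


def flatten_first_two(x):
    """Merge the two leading axes into one batch axis (like torch's flatten(0, 1))."""
    return x.reshape((-1,) + x.shape[2:])


num_steps, num_envs, obs_dim = 4, 3, 5
obs = np.arange(num_steps * num_envs * obs_dim).reshape(num_steps, num_envs, obs_dim)
logprobs = np.arange(num_steps * num_envs, dtype=np.float32).reshape(num_steps, num_envs)

b_obs = flatten_first_two(obs)            # shape (12, 5)
b_logprobs = flatten_first_two(logprobs)  # shape (12,), same as reshape(-1)

# Rows are laid out step-major: batch index = step * num_envs + env.
assert b_obs.shape == (num_steps * num_envs, obs_dim)
assert np.array_equal(b_obs[2 * num_envs + 1], obs[2, 1])
assert np.array_equal(b_logprobs, logprobs.reshape(-1))
```

For 2-D storage such as logprobs the same merge yields a 1-D vector, which is why reshape(-1) and flatten(0, 1) agree there.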
4.
if args.target_kl is not None:
    if approx_kl > args.target_kl:
        break
to
if args.target_kl is not None and approx_kl > args.target_kl:
    break
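The combined condition relies on short-circuit evaluation: when args.target_kl is None, approx_kl > args.target_kl is never evaluated, which matters because comparing a float with None raises TypeError in Python 3. A minimal sketch, assuming a hypothetical list of per-epoch approx_kl values in place of the real PPO update:

```python
def epochs_run(kl_per_epoch, target_kl):
    """Count update epochs before early stopping (hypothetical helper)."""
    run = 0
    for approx_kl in kl_per_epoch:
        run += 1  # this epoch's minibatch updates have already happened
        # Short-circuit: the comparison is skipped entirely when target_kl is None.
        if target_kl is not None and approx_kl > target_kl:
            break
    return run


kls = [0.001, 0.004, 0.02, 0.03]
print(epochs_run(kls, target_kl=0.015))  # 3: stops after the epoch with 0.02
print(epochs_run(kls, target_kl=None))   # 4: early stopping disabled
```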
5.
Line 209 in commit 9a74142:
global_step += 1 * args.num_envs
to
global_step += args.num_envs
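Each envs.step call advances all num_envs sub-environments by one step, so adding args.num_envs per iteration counts environment frames; the 1 * factor changes nothing. A small arithmetic check with illustrative values (the helper name is hypothetical):

```python
def count_global_steps(num_updates, num_steps, num_envs):
    global_step = 0
    for _update in range(num_updates):
        for _step in range(num_steps):
            global_step += num_envs  # one vectorized step = num_envs frames
    return global_step


# 3 updates * 128 steps * 4 envs = 1536 frames
assert count_global_steps(3, 128, 4) == 3 * 128 * 4
```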
6.
Move line 183 in commit 9a74142
num_updates = args.total_timesteps // args.batch_size
to the argparse setup.
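A minimal sketch of what moving it into the argparse setup could look like, assuming a parse_args function that also derives args.batch_size from num_envs and num_steps. The flag names and defaults here are illustrative and not necessarily the ones in the repository.

```python
import argparse


def parse_args(argv=None):
    parser = argparse.ArgumentParser()
    # Illustrative flags and defaults.
    parser.add_argument("--total-timesteps", type=int, default=500000)
    parser.add_argument("--num-envs", type=int, default=4)
    parser.add_argument("--num-steps", type=int, default=128)
    args = parser.parse_args(argv)
    # Derived quantities live next to the flags they depend on.
    args.batch_size = int(args.num_envs * args.num_steps)
    args.num_updates = args.total_timesteps // args.batch_size
    return args


args = parse_args(["--total-timesteps", "100000"])
print(args.batch_size, args.num_updates)  # 512 195
```

The training loop would then read args.num_updates instead of computing a local num_updates.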