Problem description.
The regular advantage calculation in PPO is a special case of the GAE advantage calculation when gae_lambda=1; we demonstrate this empirically with the debugging output at the bottom. Based on this result, we should remove the non-GAE code path below:
Lines 232 to 242 in 94a685d
else:
    returns = torch.zeros_like(rewards).to(device)
    for t in reversed(range(args.num_steps)):
        if t == args.num_steps - 1:
            nextnonterminal = 1.0 - next_done
            next_return = next_value
        else:
            nextnonterminal = 1.0 - dones[t + 1]
            next_return = returns[t + 1]
        returns[t] = rewards[t] + args.gamma * nextnonterminal * next_return
    advantages = returns - values
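For completeness, here is a short sketch of why the two computations coincide. It assumes a single rollout of length T and ignores episode boundaries for brevity; the nextnonterminal masking zeroes out the same terms in both branches, so the equality still holds with dones.

\hat{A}_t = \sum_{l=0}^{T-t-1} (\gamma\lambda)^l \delta_{t+l}, \qquad \delta_{t+l} = r_{t+l} + \gamma V(s_{t+l+1}) - V(s_{t+l})

\lambda = 1 \;\Rightarrow\; \hat{A}_t = \sum_{l=0}^{T-t-1} \gamma^l \bigl(r_{t+l} + \gamma V(s_{t+l+1}) - V(s_{t+l})\bigr) = \sum_{l=0}^{T-t-1} \gamma^l r_{t+l} + \gamma^{T-t} V(s_T) - V(s_t)

The sum telescopes, leaving exactly the bootstrapped return computed by the loop above minus values[t], with next_value playing the role of V(s_T).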
Debugging output (ppo.py was run under IPython with a bare raise temporarily inserted after the advantage calculation, so execution stops there and the rollout tensors remain available in the interactive session):
(cleanrl-ghSZGHE3-py3.9) ➜ cleanrl git:(explain-non-modular) ✗ ipython -i ppo.py
Python 3.9.5 (default, Jul 19 2021, 13:27:26)
Type 'copyright', 'credits' or 'license' for more information
IPython 8.4.0 -- An enhanced Interactive Python. Type '?' for help.
/home/costa/.cache/pypoetry/virtualenvs/cleanrl-ghSZGHE3-py3.9/lib/python3.9/site-packages/gym/utils/passive_env_checker.py:97: UserWarning: WARN: We recommend you to use a symmetric and normalized Box action space (range=[-1, 1]) https://stable-baselines3.readthedocs.io/en/master/guide/rl_tips.html
logger.warn(
/home/costa/.cache/pypoetry/virtualenvs/cleanrl-ghSZGHE3-py3.9/lib/python3.9/site-packages/gym/core.py:200: DeprecationWarning: WARN: Function `env.seed(seed)` is marked as deprecated and will be removed in the future. Please use `env.reset(seed=seed)` instead.
deprecation(
global_step=36, episodic_return=9.0
global_step=52, episodic_return=13.0
global_step=100, episodic_return=25.0
global_step=112, episodic_return=19.0
global_step=128, episodic_return=32.0
global_step=144, episodic_return=11.0
global_step=152, episodic_return=13.0
global_step=176, episodic_return=12.0
global_step=196, episodic_return=11.0
global_step=228, episodic_return=13.0
global_step=260, episodic_return=16.0
global_step=296, episodic_return=46.0
global_step=300, episodic_return=39.0
global_step=312, episodic_return=13.0
global_step=360, episodic_return=15.0
global_step=388, episodic_return=40.0
global_step=400, episodic_return=22.0
global_step=408, episodic_return=28.0
global_step=440, episodic_return=13.0
global_step=460, episodic_return=15.0
global_step=484, episodic_return=31.0
global_step=500, episodic_return=23.0
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
File ~/Documents/go/src/github.com/cleanrl/cleanrl/ppo.py:243, in <module>
241 returns[t] = rewards[t] + args.gamma * nextnonterminal * next_return
242 advantages = returns - values
--> 243 raise
244 # flatten the batch
245 b_obs = obs.reshape((-1,) + envs.single_observation_space.shape)
RuntimeError: No active exception to reraise
In [1]: returns = torch.zeros_like(rewards).to(device)
   ...: for t in reversed(range(args.num_steps)):
   ...:     if t == args.num_steps - 1:
   ...:         nextnonterminal = 1.0 - next_done
   ...:         next_return = next_value
   ...:     else:
   ...:         nextnonterminal = 1.0 - dones[t + 1]
   ...:         next_return = returns[t + 1]
   ...:     returns[t] = rewards[t] + args.gamma * nextnonterminal * next_return
   ...: advantages = returns - values
In [2]: returns.sum()
Out[2]: tensor(6017.7227, device='cuda:0')
In [3]: advantages.sum()
Out[3]: tensor(6005.0435, device='cuda:0')
In [4]: advantages = torch.zeros_like(rewards).to(device)
   ...: lastgaelam = 0
   ...: for t in reversed(range(args.num_steps)):
   ...:     if t == args.num_steps - 1:
   ...:         nextnonterminal = 1.0 - next_done
   ...:         nextvalues = next_value
   ...:     else:
   ...:         nextnonterminal = 1.0 - dones[t + 1]
   ...:         nextvalues = values[t + 1]
   ...:     delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t]
   ...:     advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam
   ...: returns = advantages + values
In [5]: returns.sum()
Out[5]: tensor(4088.1948, device='cuda:0')
In [6]: advantages.sum()
Out[6]: tensor(4075.5161, device='cuda:0')
In [7]: args.gae_lambda
Out[7]: 0.95
In [8]: args.gae_lambda = 1
In [9]: advantages = torch.zeros_like(rewards).to(device)
   ...: lastgaelam = 0
   ...: for t in reversed(range(args.num_steps)):
   ...:     if t == args.num_steps - 1:
   ...:         nextnonterminal = 1.0 - next_done
   ...:         nextvalues = next_value
   ...:     else:
   ...:         nextnonterminal = 1.0 - dones[t + 1]
   ...:         nextvalues = values[t + 1]
   ...:     delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t]
   ...:     advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam
   ...: returns = advantages + values
In [10]: returns.sum()
Out[10]: tensor(6017.7227, device='cuda:0')
In [11]: advantages.sum()
Out[11]: tensor(6005.0435, device='cuda:0')
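The same check can be reproduced without running the environment at all. Below is a minimal, self-contained sketch (not part of the issue's session; the tensor shapes, gamma, and the random data are arbitrary assumptions) that compares the plain bootstrapped-return advantage against GAE with gae_lambda=1 on synthetic rollout data:

import torch

# Synthetic rollout data; shapes and values are arbitrary for this check.
torch.manual_seed(0)
T, num_envs, gamma = 8, 4, 0.99
rewards = torch.randn(T, num_envs)
values = torch.randn(T, num_envs)
dones = (torch.rand(T, num_envs) < 0.1).float()
next_value = torch.randn(num_envs)
next_done = torch.zeros(num_envs)

# Plain bootstrapped returns, as in the non-GAE branch.
returns = torch.zeros_like(rewards)
for t in reversed(range(T)):
    if t == T - 1:
        nextnonterminal = 1.0 - next_done
        next_return = next_value
    else:
        nextnonterminal = 1.0 - dones[t + 1]
        next_return = returns[t + 1]
    returns[t] = rewards[t] + gamma * nextnonterminal * next_return
adv_plain = returns - values

# GAE with gae_lambda = 1.
gae_lambda = 1.0
adv_gae = torch.zeros_like(rewards)
lastgaelam = 0
for t in reversed(range(T)):
    if t == T - 1:
        nextnonterminal = 1.0 - next_done
        nextvalues = next_value
    else:
        nextnonterminal = 1.0 - dones[t + 1]
        nextvalues = values[t + 1]
    delta = rewards[t] + gamma * nextvalues * nextnonterminal - values[t]
    adv_gae[t] = lastgaelam = delta + gamma * gae_lambda * nextnonterminal * lastgaelam

print(torch.allclose(adv_plain, adv_gae))  # True, up to floating-point tolerance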