Describe the bug
Related to #33.
When an environment is "done", the auto-reset feature in openai/gym's API resets that environment and returns the initial observation of the next episode. Here is a simple demonstration of how it works with `gym==0.23.1`:
```python
import gym


class TestEnv(gym.Env):
    def __init__(self):
        self.action_space = gym.spaces.Discrete(2)
        self.observation_space = gym.spaces.Discrete(10)

    def reset(self):
        self.obs = 0
        return self.obs

    def step(self, action):
        self.obs += 1
        return self.obs, 0, False, {}


def thunk():
    env = TestEnv()
    env = gym.wrappers.TimeLimit(env, max_episode_steps=4)
    return env


env = gym.vector.SyncVectorEnv([thunk])
env.reset()
print(env.step([0]))
print(env.step([0]))
print(env.step([0]))
print(env.step([0]))
```

```
(array([1]), array([0.]), array([False]), [{}])
(array([2]), array([0.]), array([False]), [{}])
(array([3]), array([0.]), array([False]), [{}])
(array([0]), array([0.]), array([ True]), [{'TimeLimit.truncated': True, 'terminal_observation': 4}])
```
Note that `done=True` and `obs=0` are returned together in this example, and the truncated observation is put into the `info` dict (under the `terminal_observation` key).
However, envpool does not implement this behavior and only returns the initial observation of the next episode after an additional step. See the reproduction below.
To Reproduce
```python
import envpool
import numpy as np
import matplotlib.pyplot as plt

# make gym env
env = envpool.make(
    "Breakout-v5",
    env_type="gym",
    num_envs=1,
    max_episode_steps=4,
)
fig, axs = plt.subplots(ncols=4, nrows=2, figsize=(5.5, 3.5), dpi=200)
obs = env.reset()
plt.imshow(obs[0, 3].reshape(84, 84, 1), interpolation='nearest')
plt.savefig("static/reset.png")
for i, ax in zip(range(8), axs.flatten()):
    act = np.zeros(1, dtype=int) + 2
    obs, rew, done, info = env.step(act)
    ax.imshow(obs[0, 3].reshape(84, 84, 1), interpolation='nearest')
    ax.yaxis.set_visible(False)
    ax.xaxis.set_visible(False)
    ax.set_title(f"step {i}, d={int(done[0])}")
    print(i, done, info['TimeLimit.truncated'])
plt.savefig("static/envpool.png")
```
With `stable_baselines3==1.2.0`:
```python
import numpy as np
import matplotlib.pyplot as plt
from stable_baselines3.common.atari_wrappers import (  # isort:skip
    ClipRewardEnv,
    EpisodicLifeEnv,
    FireResetEnv,
    MaxAndSkipEnv,
    NoopResetEnv,
)
import gym


def thunk():
    env = gym.make("BreakoutNoFrameskip-v4")
    env = NoopResetEnv(env, noop_max=30)
    env = MaxAndSkipEnv(env, skip=4)
    # env = EpisodicLifeEnv(env)  # have to comment this out due to how timelimit works
    if "FIRE" in env.unwrapped.get_action_meanings():
        env = FireResetEnv(env)
    env = ClipRewardEnv(env)
    env = gym.wrappers.ResizeObservation(env, (84, 84))
    env = gym.wrappers.GrayScaleObservation(env)
    env = gym.wrappers.FrameStack(env, 4)
    env = gym.wrappers.TimeLimit(env, max_episode_steps=4)
    return env


env = gym.vector.SyncVectorEnv([thunk])
fig, axs = plt.subplots(ncols=4, nrows=2, figsize=(5.5, 3.5), dpi=200)
obs = env.reset()
plt.imshow(obs[0, 3].reshape(84, 84, 1), interpolation='nearest')
plt.savefig("static/gym-reset.png")
for i, ax in zip(range(8), axs.flatten()):
    act = np.zeros(1, dtype=int) + 2
    obs, rew, done, info = env.step(act)
    ax.imshow(obs[0, 3].reshape(84, 84, 1), interpolation='nearest')
    ax.yaxis.set_visible(False)
    ax.xaxis.set_visible(False)
    ax.set_title(f"step {i}, d={int(done[0])}")
    print(i, done, info)
plt.savefig("static/gym.png")
```
It can be observed from the figures that envpool returns the initial observation of the new episode at step 4, whereas gym's vector env returns it at step 3, the same step at which `done=True` is reported.
Expected behavior
In the screenshot above, envpool should return the initial observation of the new episode at step 3.
This is highly relevant to return calculation because it causes an off-by-one error. Consider the following return calculation:
```python
import numpy as np

# assume the game is terminated and resulted in the terminal observation of obs3
rewards = np.array([1, 0.1, 0.01, 2, 0.1, 0.01, 0.001, 0.0001]).reshape(-1, 1)
dones = np.array([0, 0, 0, 1, 0, 0, 0, 0]).reshape(-1, 1)
gamma = 1.0
num_steps = 8
next_done = 0
next_value = 0.0005  # value of obs8
returns = np.zeros_like(rewards)
for t in reversed(range(num_steps)):
    if t == num_steps - 1:
        nextnonterminal = 1.0 - next_done
        next_return = next_value
    else:
        nextnonterminal = 1.0 - dones[t + 1]
        next_return = returns[t + 1]
    returns[t] = rewards[t] + gamma * nextnonterminal * next_return
print(list(returns))

# assume the game is truncated and resulted in the truncated observation of obs3
rewards = np.array([1, 0.1, 0.01, 2, 0.1, 0.01, 0.001, 0.0001]).reshape(-1, 1)
dones = np.array([0, 0, 0, 1, 0, 0, 0, 0]).reshape(-1, 1)
v_obs3 = 0.008
rewards[2] += v_obs3
next_done = 0
next_value = 0.0005  # value of obs8
returns = np.zeros_like(rewards)
for t in reversed(range(num_steps)):
    if t == num_steps - 1:
        nextnonterminal = 1.0 - next_done
        next_return = next_value
    else:
        nextnonterminal = 1.0 - dones[t + 1]
        next_return = returns[t + 1]
    returns[t] = rewards[t] + gamma * nextnonterminal * next_return
print(list(returns))
```
```
[array([1.11]), array([0.11]), array([0.01]), array([2.1116]), array([0.1116]), array([0.0116]), array([0.0016]), array([0.0006])]
[array([1.118]), array([0.118]), array([0.018]), array([2.1116]), array([0.1116]), array([0.0116]), array([0.0016]), array([0.0006])]
```
This calculates the returns for the two trajectories correctly.
If the `dones` are off by one, e.g. `np.array([0, 0, 0, 0, 1, 0, 0, 0]).reshape(-1, 1)`, the results are quite different:
```
[array([3.11]), array([2.11]), array([2.01]), array([2.]), array([0.1116]), array([0.0116]), array([0.0016]), array([0.0006])]
[array([3.118]), array([2.118]), array([2.018]), array([2.]), array([0.1116]), array([0.0116]), array([0.0016]), array([0.0006])]
```
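For reference, the off-by-one numbers above can be reproduced with the same return loop, only with the `dones` array shifted one step later. This is a minimal sketch; the second, truncated variant additionally adds `v_obs3 = 0.008` to `rewards[2]` as before.

```python
import numpy as np

# same trajectory as above, but with the done flag shifted one step later,
# which is what the delayed auto-reset produces
rewards = np.array([1, 0.1, 0.01, 2, 0.1, 0.01, 0.001, 0.0001]).reshape(-1, 1)
dones = np.array([0, 0, 0, 0, 1, 0, 0, 0]).reshape(-1, 1)
gamma = 1.0
num_steps = 8
next_done = 0
next_value = 0.0005  # value of obs8
returns = np.zeros_like(rewards)
for t in reversed(range(num_steps)):
    if t == num_steps - 1:
        nextnonterminal = 1.0 - next_done
        next_return = next_value
    else:
        nextnonterminal = 1.0 - dones[t + 1]
        next_return = returns[t + 1]
    returns[t] = rewards[t] + gamma * nextnonterminal * next_return
print(list(returns))
# [array([3.11]), array([2.11]), array([2.01]), array([2.]), ...]
```

Note how the reward at step 3 no longer gets cut off from the rewards of the next episode, and the episode's first three returns are inflated accordingly.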
System info
Describe the characteristics of your environment:
- Describe how the library was installed (pip, source, ...)
- Python version
- Versions of any other relevant libraries
```python
import envpool, numpy, sys

print(envpool.__version__, numpy.__version__, sys.version, sys.platform)
```

```
0.6.4 1.21.6 3.7.8 (default, Mar 30 2022, 09:38:46)
[GCC 11.2.0] linux
```
Additional context
There are ways to manually trigger a reset by calling `env.reset(done_env_ids)` as follows, but this is supported in neither the XLA nor the async API.
```python
import envpool
import numpy as np
import matplotlib.pyplot as plt

# make gym env
env = envpool.make(
    "Breakout-v5",
    env_type="gym",
    num_envs=1,
    max_episode_steps=4,
)
fig, axs = plt.subplots(ncols=4, nrows=2, figsize=(5.5, 3.5), dpi=200)
obs = env.reset()
plt.imshow(obs[0, 3].reshape(84, 84, 1), interpolation='nearest')
plt.savefig("static/reset.png")
for i, ax in zip(range(8), axs.flatten()):
    act = np.zeros(1, dtype=int) + 2
    obs, rew, done, info = env.step(act)
    # proper auto-reset: manually reset the envs that are truncated or terminated
    # (elementwise | instead of `or`, so this also works for num_envs > 1)
    auto_reset_ids = np.where((info['TimeLimit.truncated'] | info['terminated']) == 1)[0]
    obs[auto_reset_ids] = env.reset(auto_reset_ids)
    ax.imshow(obs[0, 3].reshape(84, 84, 1), interpolation='nearest')
    ax.yaxis.set_visible(False)
    ax.xaxis.set_visible(False)
    ax.set_title(f"step {i}, d={int(done[0])}")
    print(i, done, info['TimeLimit.truncated'], info['terminated'])
plt.savefig("static/envpool_manual.png")
```
Reason and Possible fixes
A possible solution is to add a `last_observation` key to the `info` dict (analogous to gym's `terminal_observation`), so that the auto-reset can happen on the same step that `done=True` is returned.
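Until something like that lands, a user-side workaround for the synchronous gym interface is a thin wrapper around the manual reset shown in the Additional context section. The sketch below is illustrative only: the `GymStyleAutoReset` class and the `info['last_observation']` key are hypothetical and not part of envpool's API, and it assumes `env.reset(env_ids)` behaves as in the snippet above.

```python
import numpy as np


class GymStyleAutoReset:
    """Illustrative wrapper (hypothetical, not part of envpool): reset finished
    sub-envs on the same step they report done, and keep the true final
    observations available in the info dict."""

    def __init__(self, env):
        self.env = env

    def reset(self):
        return self.env.reset()

    def step(self, action):
        obs, rew, done, info = self.env.step(action)
        # `done` is used here for simplicity; one could instead combine
        # info['TimeLimit.truncated'] and info['terminated'] as in the snippet above
        done_ids = np.where(done)[0]
        if len(done_ids) > 0:
            # stash the real final observations before overwriting them
            info['last_observation'] = obs.copy()
            obs[done_ids] = self.env.reset(done_ids)
        return obs, rew, done, info
```

As noted above, this only helps with the synchronous gym interface; the XLA and async APIs would still need the fix inside envpool itself.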
Checklist
- I have checked that there is no similar issue in the repo (required)
- I have read the documentation (required)
- I have provided a minimal working example to reproduce the bug (required)