
[BUG] Vectorized Environment Autoreset Incompatible with openai/baselines' API #194

@vwxyzjn

Description


Describe the bug

Related to #33.

When an environment is "done", the autoreset feature in openai/gym's API resets that environment and returns the initial observation of the next episode. Here is a simple demonstration of how it works with gym==0.23.1:

import gym

class TestEnv(gym.Env):
    def __init__(self):
        self.action_space = gym.spaces.Discrete(2)
        self.observation_space = gym.spaces.Discrete(10)
        
    def reset(self):
        self.obs = 0
        return self.obs

    def step(self, action):
        self.obs += 1
        return self.obs, 0, False, {}

def thunk():
    env = TestEnv()
    env = gym.wrappers.TimeLimit(env, max_episode_steps=4)
    return env

env = gym.vector.SyncVectorEnv([thunk])
env.reset()
print(env.step([0]))
print(env.step([0]))
print(env.step([0]))
print(env.step([0]))
(array([1]), array([0.]), array([False]), [{}])
(array([2]), array([0.]), array([False]), [{}])
(array([3]), array([0.]), array([False]), [{}])
(array([0]), array([0.]), array([ True]), [{'TimeLimit.truncated': True, 'terminal_observation': 4}])

Note that done=True is returned together with obs=0 (the first observation of the next episode), and the truncated observation (obs=4) is placed in the info dict under 'terminal_observation'.
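
Continuing the TestEnv example above, this is how a downstream training loop would typically recover the real final observation when it needs to bootstrap from a truncated state (a minimal sketch against gym==0.23.1's vector API, where the infos are returned as a per-env list of dicts):

# reusing the SyncVectorEnv built from TestEnv above (gym==0.23.1)
obs = env.reset()
for _ in range(4):
    obs, rew, done, info = env.step([0])
for env_id, d in enumerate(done):
    if d:
        # obs[env_id] is already the first observation of the next episode;
        # the observation that actually ended the previous episode is here:
        terminal_obs = info[env_id]["terminal_observation"]
        truncated = info[env_id].get("TimeLimit.truncated", False)
        print(env_id, terminal_obs, truncated)  # -> 0 4 True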

However, envpool does not implement this behavior: it only returns the initial observation of the next episode after an additional step. See the reproduction below.

To Reproduce

import envpool
import numpy as np
import matplotlib.pyplot as plt

# make gym env
env = envpool.make(
    "Breakout-v5",
    env_type="gym",
    num_envs=1,
    max_episode_steps=4
)
fig, axs = plt.subplots(ncols=4, nrows=2, figsize=(5.5, 3.5), dpi=200)

obs = env.reset()
plt.imshow(obs[0,3].reshape(84, 84, 1), interpolation='nearest')
plt.savefig("static/reset.png")
for i, ax in zip(range(8), axs.flatten()):
    act = np.zeros(1, dtype=int) + 2
    obs, rew, done, info = env.step(act)
    ax.imshow(obs[0,3].reshape(84, 84, 1), interpolation='nearest')
    ax.yaxis.set_visible(False)
    ax.xaxis.set_visible(False)
    ax.set_title(f"step {i}, d={int(done[0])}")
    print(i, done, info['TimeLimit.truncated'])
plt.savefig(f"static/envpool.png")

For comparison, the equivalent gym setup using the Atari wrappers from stable_baselines3==1.2.0:

import numpy as np
import matplotlib.pyplot as plt
from stable_baselines3.common.atari_wrappers import (  # isort:skip
    ClipRewardEnv,
    EpisodicLifeEnv,
    FireResetEnv,
    MaxAndSkipEnv,
    NoopResetEnv,
)
import gym

def thunk():
    env = gym.make("BreakoutNoFrameskip-v4")
    env = NoopResetEnv(env, noop_max=30)
    env = MaxAndSkipEnv(env, skip=4)
    # env = EpisodicLifeEnv(env) # have to comment this out due to how timelimit works
    if "FIRE" in env.unwrapped.get_action_meanings():
        env = FireResetEnv(env)
    env = ClipRewardEnv(env)
    env = gym.wrappers.ResizeObservation(env, (84, 84))
    env = gym.wrappers.GrayScaleObservation(env)
    env = gym.wrappers.FrameStack(env, 4)
    env = gym.wrappers.TimeLimit(env, max_episode_steps=4)
    return env
env = gym.vector.SyncVectorEnv([thunk])
fig, axs = plt.subplots(ncols=4, nrows=2, figsize=(5.5, 3.5), dpi=200)

obs = env.reset()
plt.imshow(obs[0,3].reshape(84, 84, 1), interpolation='nearest')
plt.savefig("static/gym-reset.png")
for i, ax in zip(range(8), axs.flatten()):
    act = np.zeros(1, dtype=int) + 2
    obs, rew, done, info = env.step(act)
    ax.imshow(obs[0,3].reshape(84, 84, 1), interpolation='nearest')
    ax.yaxis.set_visible(False)
    ax.xaxis.set_visible(False)
    ax.set_title(f"step {i}, d={int(done[0])}")
    print(i, done, info)

plt.savefig(f"static/gym.png")

[screenshot: the envpool and gym observation grids produced by the scripts above (static/envpool.png and static/gym.png), steps 0-7 with their done flags]

It can be observed from the screenshot that envpool returns the initial observation of the new episode at step 4, whereas gym's vec env returns it at step 3, at the same time that done=True is reported.

Expected behavior

In the screenshot above, envpool should return the initial observation of the new episode at step 3.

This is highly relevant to the return calculation because it causes an off-by-one error. Consider the following return calculation:

import numpy as np

# assume the episode terminated, producing the terminal observation obs3
rewards = np.array([1, 0.1, 0.01, 2, 0.1, 0.01, 0.001, 0.0001]).reshape(-1, 1)
dones = np.array([0, 0, 0, 1, 0, 0, 0, 0]).reshape(-1, 1)
gamma = 1.0
num_steps = 8
next_done = 0
next_value = 0.0005 # value of obs8
returns = np.zeros_like(rewards)
for t in reversed(range(num_steps)): 
    if t == num_steps - 1: 
        nextnonterminal = 1.0 - next_done 
        next_return = next_value 
    else: 
        nextnonterminal = 1.0 - dones[t + 1] 
        next_return = returns[t + 1] 
    returns[t] = rewards[t] + gamma * nextnonterminal * next_return
print(list(returns))


# assume the episode was truncated, producing the truncated observation obs3; bootstrap with v(obs3)
rewards = np.array([1, 0.1, 0.01, 2, 0.1, 0.01, 0.001, 0.0001]).reshape(-1, 1)
dones = np.array([0, 0, 0, 1, 0, 0, 0, 0]).reshape(-1, 1)
v_obs3 = 0.008
rewards[2] += v_obs3
next_done = 0
next_value = 0.0005 # value of obs8
returns = np.zeros_like(rewards)
for t in reversed(range(num_steps)): 
    if t == num_steps - 1: 
        nextnonterminal = 1.0 - next_done 
        next_return = next_value 
    else: 
        nextnonterminal = 1.0 - dones[t + 1] 
        next_return = returns[t + 1] 
    returns[t] = rewards[t] + gamma * nextnonterminal * next_return
print(list(returns))
[array([1.11]), array([0.11]), array([0.01]), array([2.1116]), array([0.1116]), array([0.0116]), array([0.0016]), array([0.0006])]
[array([1.118]), array([0.118]), array([0.018]), array([2.1116]), array([0.1116]), array([0.0116]), array([0.0016]), array([0.0006])]

which calculates the returns for two trajectories correctly.

If the dones are off by one, e.g. np.array([0, 0, 0, 0, 1, 0, 0, 0]).reshape(-1, 1), the results are quite different.
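
For reference, the two result lines below come from re-running exactly the same backward pass with the shifted mask (a minimal, self-contained sketch; compute_returns simply wraps the loop shown above):

import numpy as np

def compute_returns(rewards, dones, next_value, next_done=0, gamma=1.0):
    # same backward pass as in the snippets above
    num_steps = len(rewards)
    returns = np.zeros_like(rewards)
    for t in reversed(range(num_steps)):
        if t == num_steps - 1:
            nextnonterminal = 1.0 - next_done
            next_return = next_value
        else:
            nextnonterminal = 1.0 - dones[t + 1]
            next_return = returns[t + 1]
        returns[t] = rewards[t] + gamma * nextnonterminal * next_return
    return returns

rewards = np.array([1, 0.1, 0.01, 2, 0.1, 0.01, 0.001, 0.0001]).reshape(-1, 1)
shifted_dones = np.array([0, 0, 0, 0, 1, 0, 0, 0]).reshape(-1, 1)  # off by one

# terminated case
print(list(compute_returns(rewards, shifted_dones, next_value=0.0005)))
# truncated case: bootstrap with v(obs3) = 0.008 as above
rewards_trunc = rewards.copy()
rewards_trunc[2] += 0.008
print(list(compute_returns(rewards_trunc, shifted_dones, next_value=0.0005)))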

[array([3.11]), array([2.11]), array([2.01]), array([2.]), array([0.1116]), array([0.0116]), array([0.0016]), array([0.0006])]
[array([3.118]), array([2.118]), array([2.018]), array([2.]), array([0.1116]), array([0.0116]), array([0.0016]), array([0.0006])]

System info

Describe the characteristics of your environment:

  • Describe how the library was installed (pip, source, ...)
  • Python version
  • Versions of any other relevant libraries
import envpool, numpy, sys
print(envpool.__version__, numpy.__version__, sys.version, sys.platform)
0.6.4 1.21.6 3.7.8 (default, Mar 30 2022, 09:38:46) 
[GCC 11.2.0] linux

Additional context

There are ways to manually trigger a reset by calling env.reset(done_env_ids) as follows, but this is not supported in either the XLA or the async API.

import envpool
import numpy as np
import matplotlib.pyplot as plt

# make gym env
env = envpool.make(
    "Breakout-v5",
    env_type="gym",
    num_envs=1,
    max_episode_steps=4
)
fig, axs = plt.subplots(ncols=4, nrows=2, figsize=(5.5, 3.5), dpi=200)

obs = env.reset()
plt.imshow(obs[0,3].reshape(84, 84, 1), interpolation='nearest')
plt.savefig("static/reset.png")
for i, ax in zip(range(8), axs.flatten()):
    act = np.zeros(1, dtype=int) + 2
    obs, rew, done, info = env.step(act)

    # proper auto-reset
    auto_reset_ids = np.where(info['TimeLimit.truncated'] | info['terminated'])[0]
    obs[auto_reset_ids] = env.reset(auto_reset_ids)

    ax.imshow(obs[0,3].reshape(84, 84, 1), interpolation='nearest')
    ax.yaxis.set_visible(False)
    ax.xaxis.set_visible(False)
    ax.set_title(f"step {i}, d={int(done[0])}")
    print(i, done, info['TimeLimit.truncated'], info['terminated'])
plt.savefig("static/envpool_manual.png")

Reason and Possible fixes

A possible solution is to add a last_observation key to the info dict, analogous to the terminal_observation key in gym's API.
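
Until something like that lands, the gym-style autoreset timing can be approximated in the sync gym API with a thin wrapper around env.reset(done_env_ids) (a minimal sketch, not part of envpool's API; the GymAutoResetWrapper name and the last_observation key are illustrative only, and as noted above this approach does not carry over to the XLA or async APIs):

import envpool
import numpy as np

class GymAutoResetWrapper:
    """Emulate gym-style autoreset on top of envpool's sync gym interface."""

    def __init__(self, env):
        self.env = env

    def reset(self):
        return self.env.reset()

    def step(self, action):
        obs, rew, done, info = self.env.step(action)
        done_ids = np.where(done)[0]
        if len(done_ids) > 0:
            # stash the observation that ended the episode, then replace it
            # with the first observation of the next episode
            info["last_observation"] = obs[done_ids].copy()
            obs[done_ids] = self.env.reset(done_ids)
        return obs, rew, done, info

env = GymAutoResetWrapper(
    envpool.make("Breakout-v5", env_type="gym", num_envs=1, max_episode_steps=4)
)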

Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)
