Introduce interactive policies to gather data from a user #776

Merged · 16 commits · Sep 8, 2023
Changes from 10 commits
41 changes: 41 additions & 0 deletions examples/train_dagger_atari_interactive_policy.py
@@ -0,0 +1,41 @@
"""Training DAgger with an interactive policy that queries the user for actions.

Member:

If a matplotlib GUI backend isn't installed, this fails with a somewhat cryptic error pointing at:

  fig.show()

and indeed no figure displays.

Installing the relevant backend seems out-of-scope for this project. But we might want to check whether the backend is interactive (I think plt.isinteractive() checks for this) and, if not, warn with a link to the relevant docs, e.g. https://matplotlib.org/stable/users/explain/backends.html

Contributor Author:

Could you please check again what message you are getting, and whether it is an error or a warning? In my case, when I set a non-GUI backend like Agg, I get a warning like this (and execution continues):

UserWarning: Matplotlib is currently using agg, which is a non-GUI backend, so cannot show the figure.
  fig.show()

Actually, plt.isinteractive() checks for interactive mode; GUI backends like MacOSX can run in either mode (and we do not need to turn interactive mode on for this to work). What we would actually like is to check whether the backend is GUI or non-GUI, but from a quick search there does not seem to be a clean way to do that other than checking against a whitelist of backends. Given that, and that the message listed above is not that bad, I'd keep this as-is for now. Alternatively, we could throw an error/assert instead of a warning, but again this would require a whitelist of backends. WDYT?
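
For reference, a minimal sketch of the whitelist-style check discussed above (the set of backend names is an assumption; matplotlib exposes no official GUI/non-GUI flag):

import warnings

import matplotlib

# Assumed whitelist of common non-GUI backends.
_NON_GUI_BACKENDS = {"agg", "pdf", "ps", "svg", "cairo", "template"}


def warn_if_non_gui_backend() -> None:
    """Warns if the current matplotlib backend cannot display figures."""
    backend = matplotlib.get_backend().lower()
    if backend in _NON_GUI_BACKENDS:
        warnings.warn(
            f"Matplotlib is using the non-GUI backend '{backend}'; figures "
            "will not display. See "
            "https://matplotlib.org/stable/users/explain/backends.html",
        )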

Contributor Author:

For additional context: on my laptop, the example runs nicely even though interactive mode is off by default.


Note that this is a toy example that does not lead to training a reasonable policy.
"""

import tempfile

import gym
import numpy as np
from stable_baselines3.common import vec_env

from imitation.algorithms import bc, dagger
from imitation.policies import interactive

if __name__ == "__main__":
    rng = np.random.default_rng(0)

    env = vec_env.DummyVecEnv([lambda: gym.wrappers.TimeLimit(gym.make("Pong-v4"), 10)])
    env.seed(0)

    expert = interactive.AtariInteractivePolicy(env)

    bc_trainer = bc.BC(
        observation_space=env.observation_space,
        action_space=env.action_space,
        rng=rng,
    )

    with tempfile.TemporaryDirectory(prefix="dagger_example_") as tmpdir:
        dagger_trainer = dagger.SimpleDAggerTrainer(
            venv=env,
            scratch_dir=tmpdir,
            expert_policy=expert,
            bc_trainer=bc_trainer,
            rng=rng,
        )
        dagger_trainer.train(
            total_timesteps=20,
            rollout_round_min_episodes=1,
            rollout_round_min_timesteps=10,
        )
7 changes: 5 additions & 2 deletions setup.py
@@ -203,11 +203,14 @@ def get_local_version(version: "ScmVersion", time_format="%Y%m%d") -> str:
         "torch>=1.4.0",
         "tqdm",
         "scikit-learn>=0.21.2",
-        "seals>=0.1.5",
+        "seals~=0.1.5",
         STABLE_BASELINES3,
         "sacred>=0.8.4",
         "tensorboard>=1.14",
-        "huggingface_sb3>=2.2.1",
+        # TODO: remove once https://github.com/huggingface/huggingface_sb3/issues/37 is
+        # fixed
+        "huggingface_sb3==2.2.5",
         "optuna>=3.0.1",
         "datasets>=2.8.0",
     ],
     tests_require=TESTS_REQUIRE,
10 changes: 5 additions & 5 deletions src/imitation/policies/base.py
@@ -13,11 +13,11 @@
 from imitation.util import networks


-class HardCodedPolicy(policies.BasePolicy, abc.ABC):
-    """Abstract class for hard-coded (non-trainable) policies."""
+class NonTrainablePolicy(policies.BasePolicy, abc.ABC):
+    """Abstract class for non-trainable (e.g. hard-coded or interactive) policies."""

     def __init__(self, observation_space: gym.Space, action_space: gym.Space):
-        """Builds HardcodedPolicy with specified observation and action space."""
+        """Builds NonTrainablePolicy with specified observation and action space."""
         super().__init__(
             observation_space=observation_space,
             action_space=action_space,
@@ -43,14 +43,14 @@ def forward(self, *args):
         raise NotImplementedError  # pragma: no cover


-class RandomPolicy(HardCodedPolicy):
+class RandomPolicy(NonTrainablePolicy):
     """Returns random actions."""

     def _choose_action(self, obs: np.ndarray) -> np.ndarray:
         return self.action_space.sample()


-class ZeroPolicy(HardCodedPolicy):
+class ZeroPolicy(NonTrainablePolicy):
     """Returns constant zero action."""

     def _choose_action(self, obs: np.ndarray) -> np.ndarray:
152 changes: 152 additions & 0 deletions src/imitation/policies/interactive.py
@@ -0,0 +1,152 @@
"""Interactive policies that query the user for actions."""

import abc
import collections
import typing

Member:

Our style guide allows importing types directly from typing (i.e. from typing import Optional is permissible) although it's not obligatory -- fine to use this style if you prefer. https://google.github.io/styleguide/pyguide.html#2241-exemptions
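
For quick illustration, the two permitted styles (a trivial sketch, not code from this PR):

import typing                   # then: typing.Optional[object]
from typing import Optional     # then: Optional[object]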

Contributor Author:

Thanks, good to know!


import gym
import matplotlib.pyplot as plt
import numpy as np
from stable_baselines3.common import vec_env

import imitation.policies.base as base_policies
from imitation.util import util


class DiscreteInteractivePolicy(base_policies.NonTrainablePolicy, abc.ABC):
    """Abstract class for interactive policies with discrete actions.

    For each query, the observation is rendered and then the action is provided
    as a keyboard input.
    """

    def __init__(
        self,
        observation_space: gym.Space,
        action_space: gym.Space,
        action_keys_names: collections.OrderedDict,
        clear_screen_on_query: bool = True,
    ):
        """Builds DiscreteInteractivePolicy.

        Args:
            observation_space: Observation space.
            action_space: Action space.
            action_keys_names: `OrderedDict` containing pairs (key, name) for every
                action, where key will be used in the console interface, and name
                is a semantic action name.
            clear_screen_on_query: If `True`, console will be cleared on every query.
        """
        super().__init__(
            observation_space=observation_space,
            action_space=action_space,
        )

        assert isinstance(action_space, gym.spaces.Discrete)
        assert (
            len(action_keys_names)
            == len(set(action_keys_names.values()))
            == action_space.n
        )

        self.action_keys_names = action_keys_names
        self.action_key_to_index = {
            k: i for i, k in enumerate(action_keys_names.keys())

Member:

Iterating over a dict gives you the keys by default (you can leave this as-is if you want to be explicit about it).

Suggested change:
-            k: i for i, k in enumerate(action_keys_names.keys())
+            k: i for i, k in enumerate(action_keys_names)

Contributor Author:

Yeah, in this case I slightly prefer to keep it.

        }
        self.clear_screen_on_query = clear_screen_on_query

    def _choose_action(self, obs: np.ndarray) -> np.ndarray:
        if self.clear_screen_on_query:
            util.clear_screen()

        context = self._render(obs)
        key = self._get_input_key()
        self._clean_up(context)

        return np.array([self.action_key_to_index[key]])

    def _get_input_key(self) -> str:
        """Obtains input key for action selection."""
        print(
            "Please select an action. Possible choices in [ACTION_NAME:KEY] format:",
            ", ".join([f"{n}:{k}" for k, n in self.action_keys_names.items()]),
        )

        key = input("Your choice (enter key):")
        while key not in self.action_keys_names.keys():
            key = input("Invalid key, please try again! Your choice (enter key):")

        return key

    @abc.abstractmethod
    def _render(self, obs: np.ndarray) -> typing.Optional[object]:
        """Renders an observation, optionally returns a context for later cleanup."""

    def _clean_up(self, context: object) -> None:
        """Cleans up after the input has been captured, e.g. stops showing the image."""
        pass


class ImageObsDiscreteInteractivePolicy(DiscreteInteractivePolicy):
    """DiscreteInteractivePolicy that renders image observations."""

    def _render(self, obs: np.ndarray) -> plt.Figure:
        img = self._prepare_obs_image(obs)

        fig, ax = plt.subplots()
        ax.imshow(img, cmap="gray", vmin=0, vmax=255)  # cmap is ignored for RGB images.
        ax.axis("off")
        fig.show()

        return fig

    def _clean_up(self, context: plt.Figure) -> None:
        plt.close(context)

    def _prepare_obs_image(self, obs: np.ndarray) -> np.ndarray:
        """Applies any required observation processing to get an image to show."""
        return obs


ATARI_ACTION_NAMES_TO_KEYS = {
    "NOOP": "1",
    "FIRE": "2",
    "UP": "w",
    "RIGHT": "d",
    "LEFT": "a",
    "DOWN": "x",
    "UPRIGHT": "e",
    "UPLEFT": "q",
    "DOWNRIGHT": "c",
    "DOWNLEFT": "z",
    "UPFIRE": "t",
    "RIGHTFIRE": "h",
    "LEFTFIRE": "f",
    "DOWNFIRE": "b",
    "UPRIGHTFIRE": "y",
    "UPLEFTFIRE": "r",
    "DOWNRIGHTFIRE": "n",
    "DOWNLEFTFIRE": "v",
}


class AtariInteractivePolicy(ImageObsDiscreteInteractivePolicy):
    """Interactive policy for Atari environments."""

    def __init__(self, env: typing.Union[gym.Env, vec_env.VecEnv], *args, **kwargs):
        """Builds AtariInteractivePolicy."""
        action_names = (
            env.get_action_meanings()
            if isinstance(env, gym.Env)
            else env.env_method("get_action_meanings", indices=[0])[0]
        )
        action_keys_names = collections.OrderedDict(
            [(ATARI_ACTION_NAMES_TO_KEYS[name], name) for name in action_names],
        )
        super().__init__(
            env.observation_space,
            env.action_space,
            action_keys_names,
            *args,
            **kwargs,
        )
8 changes: 8 additions & 0 deletions src/imitation/util/util.py
@@ -460,3 +460,11 @@ def split_in_half(x: int) -> Tuple[int, int]:
     """
     half = x // 2
     return half, x - half
+
+
+def clear_screen() -> None:
+    """Clears the console screen."""
+    if os.name == "nt":  # Windows
+        os.system("cls")
+    else:
+        os.system("clear")
2 changes: 2 additions & 0 deletions tests/algorithms/test_mce_irl.py
@@ -132,6 +132,8 @@ def test_infinite_horizon_error(random_mdp, rng):
 def test_policy_om_random_mdp(discount: float):
     """Test that optimal policy occupancy measure ("om") for a random MDP is sane."""
     mdp = gym.make("seals/Random-v0")
+    mdp.seed(0)
+
     V, Q, pi = mce_partition_fh(mdp, discount=discount)
     assert np.all(np.isfinite(V))
     assert np.all(np.isfinite(Q))
117 changes: 117 additions & 0 deletions tests/policies/test_interactive.py
@@ -0,0 +1,117 @@
"""Tests interactive policies."""

import collections
from unittest import mock

import gym
import numpy as np
import pytest
from stable_baselines3.common import vec_env

from imitation.policies import interactive

ENVS = [
    "CartPole-v0",
]


class NoRenderingDiscreteInteractivePolicy(interactive.DiscreteInteractivePolicy):
    """DiscreteInteractivePolicy with no rendering."""

    def _render(self, obs: np.ndarray) -> None:
        pass


def _get_interactive_policy(env: vec_env.VecEnv):
    num_actions = env.action_space.n
    action_keys_names = collections.OrderedDict(
        [(f"k{i}", f"n{i}") for i in range(num_actions)],
    )
    interactive_policy = NoRenderingDiscreteInteractivePolicy(
        env.observation_space,
        env.action_space,
        action_keys_names,
    )
    return interactive_policy


@pytest.mark.parametrize("env_name", ENVS)
def test_interactive_policy(env_name: str):
    """Test if correct actions are selected, as specified by input keys."""
    env = vec_env.DummyVecEnv([lambda: gym.wrappers.TimeLimit(gym.make(env_name), 10)])
    env.seed(0)

    interactive_policy = _get_interactive_policy(env)
    action_keys = list(interactive_policy.action_keys_names.keys())

    obs = env.reset()
    done = np.array([False])

    class mock_input:
        def __init__(self):
            self.index = 0

        def __call__(self, _):
            # Sometimes insert incorrect keys, which should get ignored by the policy.
            if np.random.uniform() < 0.5:
                return "invalid"
            key = action_keys[self.index]
            self.index = (self.index + 1) % len(action_keys)
            return key

    with mock.patch("builtins.input", mock_input()):
        requested_action = 0
        while not done.all():
            action, _ = interactive_policy.predict(obs)
            assert isinstance(action, np.ndarray)
            assert all(env.action_space.contains(a) for a in action)
            assert action[0] == requested_action

            obs, reward, done, info = env.step(action)
            assert isinstance(obs, np.ndarray)
            assert all(env.observation_space.contains(o) for o in obs)
            assert isinstance(reward, np.ndarray)
            assert isinstance(done, np.ndarray)

            requested_action = (requested_action + 1) % len(action_keys)


@pytest.mark.parametrize("env_name", ENVS)
def test_interactive_policy_input_validity(capsys, env_name: str):
    """Test if appropriate feedback is given on the validity of the input."""
    env = vec_env.DummyVecEnv([lambda: gym.wrappers.TimeLimit(gym.make(env_name), 10)])
    env.seed(0)

    interactive_policy = _get_interactive_policy(env)
    action_keys = list(interactive_policy.action_keys_names.keys())

    # Valid input key case
    obs = env.reset()

    def mock_input_valid(prompt):
        print(prompt)
        return action_keys[0]

    with mock.patch("builtins.input", mock_input_valid):
        interactive_policy.predict(obs)
        stdout = capsys.readouterr().out
        assert "Your choice" in stdout and "Invalid" not in stdout

    # First invalid input key, then valid
    obs = env.reset()

    class mock_input_invalid_then_valid:
        def __init__(self):
            self.return_valid = False

        def __call__(self, prompt):
            print(prompt)
            if self.return_valid:
                return action_keys[0]
            self.return_valid = True
            return "invalid"

    with mock.patch("builtins.input", mock_input_invalid_then_valid()):
        interactive_policy.predict(obs)
        stdout = capsys.readouterr().out
        assert "Your choice" in stdout and "Invalid" in stdout

Member:

Tests for DiscreteInteractivePolicy look great. We're not testing AtariInteractivePolicy at all, though. Granted, it's pretty simple, but it might still be worth testing, even if just with a simple smoke test (it runs; if we feed in the key corresponding to "FIRE", we get the correct action back, etc.). See the sketch below.
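
For illustration, a minimal sketch of such a smoke test (the test name, the use of Pong-v4, and forcing the Agg backend are assumptions, not code from this PR):

from unittest import mock

import gym
import matplotlib
from stable_baselines3.common import vec_env

matplotlib.use("Agg")  # Headless-safe: fig.show() only warns on a non-GUI backend.

from imitation.policies import interactive


def test_atari_interactive_policy_smoke():
    env = vec_env.DummyVecEnv(
        [lambda: gym.wrappers.TimeLimit(gym.make("Pong-v4"), 10)],
    )
    env.seed(0)
    policy = interactive.AtariInteractivePolicy(env, clear_screen_on_query=False)

    obs = env.reset()
    fire_key = interactive.ATARI_ACTION_NAMES_TO_KEYS["FIRE"]
    with mock.patch("builtins.input", return_value=fire_key):
        action, _ = policy.predict(obs)

    # "FIRE" should map back to its index in the env's action meanings.
    expected = list(policy.action_keys_names.values()).index("FIRE")
    assert action[0] == expected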

4 changes: 2 additions & 2 deletions tests/policies/test_policies.py
@@ -17,13 +17,13 @@
 SIMPLE_DISCRETE_ENV = "CartPole-v0"  # Discrete(2) action space
 SIMPLE_CONTINUOUS_ENV = "MountainCarContinuous-v0"  # Box(1) action space
 SIMPLE_ENVS = [SIMPLE_DISCRETE_ENV, SIMPLE_CONTINUOUS_ENV]
-HARDCODED_TYPES = ["random", "zero"]
+NONTRAINABLE_TYPES = ["random", "zero"]

 assert_equal = functools.partial(th.testing.assert_close, rtol=0, atol=0)


 @pytest.mark.parametrize("env_name", SIMPLE_ENVS)
-@pytest.mark.parametrize("policy_type", HARDCODED_TYPES)
+@pytest.mark.parametrize("policy_type", NONTRAINABLE_TYPES)
 def test_actions_valid(env_name, policy_type, rng):
     """Test output actions of our custom policies always lie in action space."""
     venv = util.make_vec_env(