diff --git a/examples/train_dagger_atari_interactive_policy.py b/examples/train_dagger_atari_interactive_policy.py
new file mode 100644
index 000000000..7c53e728b
--- /dev/null
+++ b/examples/train_dagger_atari_interactive_policy.py
@@ -0,0 +1,41 @@
+"""Training DAgger with an interactive policy that queries the user for actions.
+
+Note that this is a toy example that does not lead to training a reasonable policy.
+"""
+
+import tempfile
+
+import gym
+import numpy as np
+from stable_baselines3.common import vec_env
+
+from imitation.algorithms import bc, dagger
+from imitation.policies import interactive
+
+if __name__ == "__main__":
+    rng = np.random.default_rng(0)
+
+    env = vec_env.DummyVecEnv([lambda: gym.wrappers.TimeLimit(gym.make("Pong-v4"), 10)])
+    env.seed(0)
+
+    expert = interactive.AtariInteractivePolicy(env)
+
+    bc_trainer = bc.BC(
+        observation_space=env.observation_space,
+        action_space=env.action_space,
+        rng=rng,
+    )
+
+    with tempfile.TemporaryDirectory(prefix="dagger_example_") as tmpdir:
+        dagger_trainer = dagger.SimpleDAggerTrainer(
+            venv=env,
+            scratch_dir=tmpdir,
+            expert_policy=expert,
+            bc_trainer=bc_trainer,
+            rng=rng,
+        )
+        dagger_trainer.train(
+            total_timesteps=20,
+            rollout_round_min_episodes=1,
+            rollout_round_min_timesteps=10,
+        )
diff --git a/src/imitation/policies/base.py b/src/imitation/policies/base.py
index 3101cf2c7..60db89f50 100644
--- a/src/imitation/policies/base.py
+++ b/src/imitation/policies/base.py
@@ -13,11 +13,11 @@
 from imitation.util import networks
 
 
-class HardCodedPolicy(policies.BasePolicy, abc.ABC):
-    """Abstract class for hard-coded (non-trainable) policies."""
+class NonTrainablePolicy(policies.BasePolicy, abc.ABC):
+    """Abstract class for non-trainable (e.g. hard-coded or interactive) policies."""
 
     def __init__(self, observation_space: gym.Space, action_space: gym.Space):
-        """Builds HardcodedPolicy with specified observation and action space."""
+        """Builds NonTrainablePolicy with specified observation and action space."""
         super().__init__(
             observation_space=observation_space,
             action_space=action_space,
@@ -43,14 +43,14 @@ def forward(self, *args):
         raise NotImplementedError  # pragma: no cover
 
 
-class RandomPolicy(HardCodedPolicy):
+class RandomPolicy(NonTrainablePolicy):
     """Returns random actions."""
 
     def _choose_action(self, obs: np.ndarray) -> np.ndarray:
         return self.action_space.sample()
 
 
-class ZeroPolicy(HardCodedPolicy):
+class ZeroPolicy(NonTrainablePolicy):
     """Returns constant zero action."""
 
     def _choose_action(self, obs: np.ndarray) -> np.ndarray:
diff --git a/src/imitation/policies/interactive.py b/src/imitation/policies/interactive.py
new file mode 100644
index 000000000..64be29b0f
--- /dev/null
+++ b/src/imitation/policies/interactive.py
@@ -0,0 +1,152 @@
+"""Interactive policies that query the user for actions."""
+
+import abc
+import collections
+from typing import Optional, Union
+
+import gym
+import matplotlib.pyplot as plt
+import numpy as np
+from stable_baselines3.common import vec_env
+
+import imitation.policies.base as base_policies
+from imitation.util import util
+
+
+class DiscreteInteractivePolicy(base_policies.NonTrainablePolicy, abc.ABC):
+    """Abstract class for interactive policies with discrete actions.
+
+    For each query, the observation is rendered and then the action is provided
+    as a keyboard input.
+    """
+
+    def __init__(
+        self,
+        observation_space: gym.Space,
+        action_space: gym.Space,
+        action_keys_names: collections.OrderedDict,
+        clear_screen_on_query: bool = True,
+    ):
+        """Builds DiscreteInteractivePolicy.
+
+        Args:
+            observation_space: Observation space.
+            action_space: Action space.
+            action_keys_names: `OrderedDict` containing pairs (key, name) for every
+                action, where key will be used in the console interface, and name
+                is a semantic action name. The index of the pair in the dictionary
+                will be used as the discrete, integer action.
+            clear_screen_on_query: If `True`, console will be cleared on every query.
+        """
+        super().__init__(
+            observation_space=observation_space,
+            action_space=action_space,
+        )
+
+        assert isinstance(action_space, gym.spaces.Discrete)
+        assert (
+            len(action_keys_names)
+            == len(set(action_keys_names.values()))
+            == action_space.n
+        )
+
+        self.action_keys_names = action_keys_names
+        self.action_key_to_index = {
+            k: i for i, k in enumerate(action_keys_names.keys())
+        }
+        self.clear_screen_on_query = clear_screen_on_query
+
+    def _choose_action(self, obs: np.ndarray) -> np.ndarray:
+        if self.clear_screen_on_query:
+            util.clear_screen()
+
+        context = self._render(obs)
+        key = self._get_input_key()
+        self._clean_up(context)
+
+        return np.array([self.action_key_to_index[key]])
+
+    def _get_input_key(self) -> str:
+        """Obtains input key for action selection."""
+        print(
+            "Please select an action. Possible choices in [ACTION_NAME:KEY] format:",
+            ", ".join([f"{n}:{k}" for k, n in self.action_keys_names.items()]),
+        )
+
+        key = input("Your choice (enter key):")
+        while key not in self.action_keys_names.keys():
+            key = input("Invalid key, please try again! Your choice (enter key):")
+
+        return key
+
+    @abc.abstractmethod
+    def _render(self, obs: np.ndarray) -> Optional[object]:
+        """Renders an observation, optionally returns a context for later cleanup."""
+
+    def _clean_up(self, context: object) -> None:
+        """Cleans up after the input has been captured, e.g. stops showing the image."""
+
+
+class ImageObsDiscreteInteractivePolicy(DiscreteInteractivePolicy):
+    """DiscreteInteractivePolicy that renders image observations."""
+
+    def _render(self, obs: np.ndarray) -> plt.Figure:
+        img = self._prepare_obs_image(obs)
+
+        fig, ax = plt.subplots()
+        ax.imshow(img, cmap="gray", vmin=0, vmax=255)  # cmap is ignored for RGB images.
+        ax.axis("off")
+        fig.show()
+
+        return fig
+
+    def _clean_up(self, context: plt.Figure) -> None:
+        plt.close(context)
+
+    def _prepare_obs_image(self, obs: np.ndarray) -> np.ndarray:
+        """Applies any required observation processing to get an image to show."""
+        return obs
+
+
+ATARI_ACTION_NAMES_TO_KEYS = {
+    "NOOP": "1",
+    "FIRE": "2",
+    "UP": "w",
+    "RIGHT": "d",
+    "LEFT": "a",
+    "DOWN": "x",
+    "UPRIGHT": "e",
+    "UPLEFT": "q",
+    "DOWNRIGHT": "c",
+    "DOWNLEFT": "z",
+    "UPFIRE": "t",
+    "RIGHTFIRE": "h",
+    "LEFTFIRE": "f",
+    "DOWNFIRE": "b",
+    "UPRIGHTFIRE": "y",
+    "UPLEFTFIRE": "r",
+    "DOWNRIGHTFIRE": "n",
+    "DOWNLEFTFIRE": "v",
+}
+
+
+class AtariInteractivePolicy(ImageObsDiscreteInteractivePolicy):
+    """Interactive policy for Atari environments."""
+
+    def __init__(self, env: Union[gym.Env, vec_env.VecEnv], *args, **kwargs):
+        """Builds AtariInteractivePolicy."""
+        action_names = (
+            env.get_action_meanings()
+            if isinstance(env, gym.Env)
+            else env.env_method("get_action_meanings", indices=[0])[0]
+        )
+        action_keys_names = collections.OrderedDict(
+            [(ATARI_ACTION_NAMES_TO_KEYS[name], name) for name in action_names],
+        )
+        super().__init__(
+            env.observation_space,
+            env.action_space,
+            action_keys_names,
+            *args,
+            **kwargs,
+        )
diff --git a/src/imitation/util/util.py b/src/imitation/util/util.py
index 83696028d..2abae1605 100644
--- a/src/imitation/util/util.py
+++ b/src/imitation/util/util.py
@@ -460,3 +460,11 @@ def split_in_half(x: int) -> Tuple[int, int]:
     """
     half = x // 2
     return half, x - half
+
+
+def clear_screen() -> None:
+    """Clears the console screen."""
+    if os.name == "nt":  # Windows
+        os.system("cls")
+    else:
+        os.system("clear")
diff --git a/tests/policies/test_interactive.py b/tests/policies/test_interactive.py
new file mode 100644
index 000000000..dece73fed
--- /dev/null
+++ b/tests/policies/test_interactive.py
@@ -0,0 +1,154 @@
+"""Tests interactive policies."""
+
+import collections
+from unittest import mock
+
+import gym
+import numpy as np
+import pytest
+from stable_baselines3.common import vec_env
+
+from imitation.policies import interactive
+
+SIMPLE_ENVS = [
+    "seals/CartPole-v0",
+]
+ATARI_ENVS = [
+    "Pong-v4",
+]
+
+
+class NoRenderingDiscreteInteractivePolicy(interactive.DiscreteInteractivePolicy):
+    """DiscreteInteractivePolicy with no rendering."""
+
+    def _render(self, obs: np.ndarray) -> None:
+        pass
+
+
+def _get_simple_interactive_policy(env: vec_env.VecEnv):
+    num_actions = env.action_space.n
+    action_keys_names = collections.OrderedDict(
+        [(f"k{i}", f"n{i}") for i in range(num_actions)],
+    )
+    interactive_policy = NoRenderingDiscreteInteractivePolicy(
+        env.observation_space,
+        env.action_space,
+        action_keys_names,
+    )
+    return interactive_policy
+
+
+@pytest.mark.parametrize("env_name", SIMPLE_ENVS + ATARI_ENVS)
+def test_interactive_policy(env_name: str):
+    """Test if correct actions are selected, as specified by input keys."""
+    env = vec_env.DummyVecEnv([lambda: gym.wrappers.TimeLimit(gym.make(env_name), 50)])
+    env.seed(0)
+
+    if env_name in ATARI_ENVS:
+        interactive_policy = interactive.AtariInteractivePolicy(env)
+    else:
+        interactive_policy = _get_simple_interactive_policy(env)
+    action_keys = list(interactive_policy.action_keys_names.keys())
+
+    obs = env.reset()
+    done = np.array([False])
+
+    class mock_input:
+        def __init__(self):
+            self.index = 0
+
+        def __call__(self, _):
+            # Sometimes insert incorrect keys, which should get ignored by the policy.
+            if np.random.uniform() < 0.5:
+                return "invalid"
+            key = action_keys[self.index]
+            self.index = (self.index + 1) % len(action_keys)
+            return key
+
+    with mock.patch("builtins.input", mock_input()):
+        requested_action = 0
+        while not done.all():
+            action, _ = interactive_policy.predict(np.array(obs))
+            assert isinstance(action, np.ndarray)
+            assert all(env.action_space.contains(a) for a in action)
+            assert action[0] == requested_action
+
+            obs, reward, done, info = env.step(action)
+            assert isinstance(obs, np.ndarray)
+            assert all(env.observation_space.contains(o) for o in obs)
+            assert isinstance(reward, np.ndarray)
+            assert isinstance(done, np.ndarray)
+
+            requested_action = (requested_action + 1) % len(action_keys)
+
+
+@pytest.mark.parametrize("env_name", SIMPLE_ENVS)
+def test_interactive_policy_input_validity(capsys, env_name: str):
+    """Test if appropriate feedback is given on the validity of the input."""
+    env = vec_env.DummyVecEnv([lambda: gym.wrappers.TimeLimit(gym.make(env_name), 10)])
+    env.seed(0)
+
+    interactive_policy = _get_simple_interactive_policy(env)
+    action_keys = list(interactive_policy.action_keys_names.keys())
+
+    # Valid input key case
+    obs = env.reset()
+
+    def mock_input_valid(prompt):
+        print(prompt)
+        return action_keys[0]
+
+    with mock.patch("builtins.input", mock_input_valid):
+        interactive_policy.predict(obs)
+    stdout = capsys.readouterr().out
+    assert "Your choice" in stdout and "Invalid" not in stdout
+
+    # First invalid input key, then valid
+    obs = env.reset()
+
+    class mock_input_invalid_then_valid:
+        def __init__(self):
+            self.return_valid = False
+
+        def __call__(self, prompt):
+            print(prompt)
+            if self.return_valid:
+                return action_keys[0]
+            self.return_valid = True
+            return "invalid"
+
+    with mock.patch("builtins.input", mock_input_invalid_then_valid()):
+        interactive_policy.predict(obs)
+    stdout = capsys.readouterr().out
+    assert "Your choice" in stdout and "Invalid" in stdout
+
+
+@pytest.mark.parametrize("env_name", ATARI_ENVS)
+def test_atari_action_mappings(env_name: str):
+    """Test if correct actions are selected, as specified by input keys."""
+    env = vec_env.DummyVecEnv([lambda: gym.wrappers.TimeLimit(gym.make(env_name), 50)])
+    env.seed(0)
+    action_meanings = env.env_method("get_action_meanings", indices=[0])[0]
+
+    interactive_policy = interactive.AtariInteractivePolicy(env)
+
+    obs = env.reset()
+
+    provided_keys = ["2", "a", "d"]
+    expected_action_meanings = ["FIRE", "LEFT", "RIGHT"]
+
+    class mock_input:
+        def __init__(self):
+            self.index = 0
+
+        def __call__(self, _):
+            key = provided_keys[self.index]
+            self.index += 1
+            return key
+
+    with mock.patch("builtins.input", mock_input()):
+        for expected_action_meaning in expected_action_meanings:
+            action, _ = interactive_policy.predict(np.array(obs))
+            obs, reward, done, info = env.step(action)
+
+            assert action_meanings[action[0]] == expected_action_meaning
diff --git a/tests/policies/test_policies.py b/tests/policies/test_policies.py
index a79b134f1..e957eaf52 100644
--- a/tests/policies/test_policies.py
+++ b/tests/policies/test_policies.py
@@ -17,13 +17,13 @@
 SIMPLE_DISCRETE_ENV = "CartPole-v0"  # Discrete(2) action space
 SIMPLE_CONTINUOUS_ENV = "MountainCarContinuous-v0"  # Box(1) action space
 SIMPLE_ENVS = [SIMPLE_DISCRETE_ENV, SIMPLE_CONTINUOUS_ENV]
-HARDCODED_TYPES = ["random", "zero"]
+NONTRAINABLE_TYPES = ["random", "zero"]
 
 assert_equal = functools.partial(th.testing.assert_close, rtol=0, atol=0)
 
 
 @pytest.mark.parametrize("env_name", SIMPLE_ENVS)
-@pytest.mark.parametrize("policy_type", HARDCODED_TYPES)
+@pytest.mark.parametrize("policy_type", NONTRAINABLE_TYPES)
 def test_actions_valid(env_name, policy_type, rng):
     """Test output actions of our custom policies always lie in action space."""
     venv = util.make_vec_env(
diff --git a/tests/test_examples.py b/tests/test_examples.py
index 214fa0415..e7acc204b 100644
--- a/tests/test_examples.py
+++ b/tests/test_examples.py
@@ -30,9 +30,12 @@ def _paths_to_strs(x: Iterable[pathlib.Path]) -> Sequence[str]:
 EXAMPLES_DIR = THIS_DIR / ".." / "examples"
 TUTORIALS_DIR = THIS_DIR / ".." / "docs" / "tutorials"
 
-SH_PATHS = _paths_to_strs(EXAMPLES_DIR.glob("*.sh"))
+EXCLUDED_EXAMPLE_FILES = ["train_dagger_atari_interactive_policy.py"]
+EXCLUDED_EXAMPLE_PATHS = [EXAMPLES_DIR / f for f in EXCLUDED_EXAMPLE_FILES]
+
+SH_PATHS = _paths_to_strs(set(EXAMPLES_DIR.glob("*.sh")) - set(EXCLUDED_EXAMPLE_PATHS))
 TUTORIAL_PATHS = _paths_to_strs(TUTORIALS_DIR.glob("*.ipynb"))
-PY_PATHS = _paths_to_strs(EXAMPLES_DIR.glob("*.py"))
+PY_PATHS = _paths_to_strs(set(EXAMPLES_DIR.glob("*.py")) - set(EXCLUDED_EXAMPLE_PATHS))
 
 # Note: This is excluded from coverage since is computed on linux. However, it is
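
Usage sketch (not part of the diff above): the new DiscreteInteractivePolicy is only wired up to Atari in this changeset, but any environment with a discrete action space could reuse it by implementing _render and passing an OrderedDict that maps console keys to action names. The class name, key bindings, and print-based rendering below are illustrative assumptions, not code from this diff.

    import collections

    import gym
    from stable_baselines3.common import vec_env

    from imitation.policies import interactive


    class CartPoleInteractivePolicy(interactive.DiscreteInteractivePolicy):
        """Hypothetical text-only interactive policy; not part of the diff above."""

        def _render(self, obs) -> None:
            # CartPole observations are small vectors, so printing them is enough.
            print(f"cart pos/vel, pole angle/vel: {obs}")


    if __name__ == "__main__":
        env = vec_env.DummyVecEnv([lambda: gym.make("CartPole-v0")])
        # Keys "a"/"d" are arbitrary choices; insertion order fixes the action index,
        # so "a" -> action 0 (push left) and "d" -> action 1 (push right).
        keys = collections.OrderedDict([("a", "LEFT"), ("d", "RIGHT")])
        policy = CartPoleInteractivePolicy(
            env.observation_space,
            env.action_space,
            keys,
            clear_screen_on_query=False,  # keep the printed observation visible
        )
        obs = env.reset()
        action, _ = policy.predict(obs)  # prompts on the console for "a" or "d"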