Introduce interactive policies to gather data from a user #776

Merged
merged 16 commits into master from 701-interactive-data
Sep 8, 2023

Conversation

michalzajac-ml
Contributor

@michalzajac-ml michalzajac-ml commented Sep 6, 2023

Description

This PR introduces interactive policies that query the user for actions, as requested in #701.
Such policies can be used, for example, in behavioral cloning or DAgger.
An example showing the use for Atari is included.
Acknowledgement: tests were heavily based on the ones from #768 by @jas-ho.

Testing

pytest tests/policies/test_interactive.py to run unit tests.
python examples/train_dagger_atari_interactive_policy.py to run the interactive demo.

src/imitation/policies/interactive.py (outdated)
assert isinstance(action_space, gym.spaces.Discrete)
assert len(action_names) == len(action_keys) == action_space.n
# Names and keys should be unique.
assert len(set(action_names)) == len(set(action_keys)) == action_space.n
Member

Why do we have both of these as sequences rather than a dictionary mapping from action key to action name (or vice versa)? A dictionary would enforce equal lengths and unique keys, so one would only need to check that len(the_dict) == action_space.n and that the values (action names) are unique.

Member

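I guess we would need an ordered dictionary, so perhaps that's a reason against, although plain dictionaries have preserved insertion order since Python 3.6 (as a CPython implementation detail in 3.6, and as a language guarantee from 3.7).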

Contributor Author

Yeah, I did this because of ordering, and I had doubts about enforcing the OrderedDict type, but now I agree it's more elegant, so I'm changing it to an OrderedDict.
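The validation discussed in this thread could look roughly like the sketch below. It is an illustration only, not the merged implementation: `validate_action_keys_names` is a hypothetical helper, and `n_actions` stands in for `action_space.n` so that the snippet runs without gym installed.

```python
import collections
from typing import Mapping


def validate_action_keys_names(
    action_keys_names: Mapping[str, str],
    n_actions: int,
) -> None:
    """Check a key -> action-name mapping against a discrete action space.

    Hypothetical helper; `n_actions` plays the role of `action_space.n`.
    """
    # Keys of a mapping are unique by construction, so only their count
    # has to match the number of actions.
    if len(action_keys_names) != n_actions:
        raise ValueError(
            f"Expected {n_actions} actions, got {len(action_keys_names)}.",
        )
    # Values are not deduplicated by the mapping, so check them explicitly.
    names = list(action_keys_names.values())
    if len(set(names)) != len(names):
        raise ValueError("Action names must be unique.")


# Insertion order determines the action index of each key.
action_keys_names = collections.OrderedDict(
    [("w", "UP"), ("s", "DOWN"), ("f", "FIRE")],
)
validate_action_keys_names(action_keys_names, n_actions=3)
action_key_to_index = {k: i for i, k in enumerate(action_keys_names.keys())}
```

Compared with two parallel sequences, the mapping removes one length check and the key-uniqueness check, at the cost of relying on insertion order to define the action index.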

src/imitation/policies/interactive.py (outdated)
tests/policies/test_interactive.py
tests/policies/test_interactive.py (outdated)
env.seed(0)

action_names = env.envs[0].get_action_meanings()
names_to_keys = {
Member

There's only a small finite number of legal actions in Atari, so we could define a more comprehensive version of these in a constant somewhere (or even subclass ImageObsDiscreteInteractivePolicy to handle this directly) rather than it having to live in an example.

Contributor Author

@michalzajac-ml michalzajac-ml Sep 7, 2023

Done, thanks for the suggestion!
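A constant of the kind suggested here might be sketched as follows. The action names are the standard Atari (ALE) action meanings, but the keyboard keys and the names `ATARI_ACTION_NAMES_TO_KEYS` and `keys_names_for_env` are assumptions made for illustration; the mapping in the merged code may differ.

```python
import collections
from typing import Sequence

# Hypothetical key bindings for the 18 legal Atari actions.
ATARI_ACTION_NAMES_TO_KEYS = {
    "NOOP": "1",
    "FIRE": "2",
    "UP": "w",
    "RIGHT": "d",
    "LEFT": "a",
    "DOWN": "x",
    "UPRIGHT": "e",
    "UPLEFT": "q",
    "DOWNRIGHT": "c",
    "DOWNLEFT": "z",
    "UPFIRE": "t",
    "RIGHTFIRE": "h",
    "LEFTFIRE": "f",
    "DOWNFIRE": "b",
    "UPRIGHTFIRE": "y",
    "UPLEFTFIRE": "r",
    "DOWNRIGHTFIRE": "n",
    "DOWNLEFTFIRE": "v",
}


def keys_names_for_env(
    action_meanings: Sequence[str],
) -> "collections.OrderedDict[str, str]":
    """Build an ordered key -> name mapping for one environment.

    `action_meanings` is the list an Atari environment reports through
    `get_action_meanings()`; its order defines the action indices.
    """
    return collections.OrderedDict(
        (ATARI_ACTION_NAMES_TO_KEYS[name], name) for name in action_meanings
    )


# Example with a six-action subset, as a Pong-like game might report.
example = keys_names_for_env(
    ["NOOP", "FIRE", "RIGHT", "LEFT", "RIGHTFIRE", "LEFTFIRE"],
)
```

Keeping the full table in one place means each example script only has to pass the environment's own action meanings.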

@michalzajac-ml michalzajac-ml changed the title from "WIP: Introduce interactive policies to gather data from a user" to "Introduce interactive policies to gather data from a user" Sep 7, 2023
@@ -30,9 +30,12 @@ def _paths_to_strs(x: Iterable[pathlib.Path]) -> Sequence[str]:
EXAMPLES_DIR = THIS_DIR / ".." / "examples"
TUTORIALS_DIR = THIS_DIR / ".." / "docs" / "tutorials"

SH_PATHS = _paths_to_strs(EXAMPLES_DIR.glob("*.sh"))
EXCLUDED_EXAMPLE_FILES = ["train_dagger_atari_interactive_policy.py"]
Contributor Author

Probably one could think about an alternative where we mock parts of the example script etc. However, it does not seem to be super useful, since we have unit tests that check analogous things that this mocked version would check.
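For readers unfamiliar with the pattern, excluding an interactive script from the automatically collected examples could be done along these lines. This is a sketch under assumptions: `collect_example_scripts` is a hypothetical helper, and only `EXCLUDED_EXAMPLE_FILES` comes from the diff above.

```python
import pathlib
from typing import List, Sequence

EXCLUDED_EXAMPLE_FILES = ["train_dagger_atari_interactive_policy.py"]


def collect_example_scripts(
    examples_dir: pathlib.Path,
    excluded: Sequence[str] = tuple(EXCLUDED_EXAMPLE_FILES),
) -> List[str]:
    """Return sorted paths of *.py examples, skipping excluded file names."""
    return sorted(
        str(path)
        for path in examples_dir.glob("*.py")
        # Compare on the bare file name so the check is independent of
        # where the examples directory lives.
        if path.name not in excluded
    )
```

The resulting list can then be fed to a parametrized test, so scripts that wait for keyboard input never run in CI.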

Member

@AdamGleave AdamGleave left a comment

Thanks for this PR! Looks nearly ready -- only major suggestion is to add some more tests (covering AtariInteractivePolicy), others are pretty minor suggestions.

tests/test_examples.py (outdated)
@@ -0,0 +1,41 @@
"""Training DAgger with an interactive policy that queries the user for actions.
Member

If a matplotlib GUI backend isn't installed, it'll fail with a somewhat cryptic error:

  fig.show()

and indeed no figure displays.

Installing the relevant backend seems out-of-scope for this project. But might want to check if the backend is interactive (I think plt.isinteractive() checks for this) and warn if not with link to relevant docs e.g. https://matplotlib.org/stable/users/explain/backends.html

Contributor Author

Could you please check again what message you are getting and if this is an error or a warning? In my case, when I set a non-GUI backend like Agg, I get a warning like this (and the execution continues):

UserWarning: Matplotlib is currently using agg, which is a non-GUI backend, so cannot show the figure.
  fig.show()

Actually, plt.isinteractive() checks for the interactive mode; GUI backends like MacOSX can be in both modes (and we do not need to turn on interactive mode for it to work). What we would actually like is to check whether the backend is "GUI" or "non-GUI", but from a simple search, there does not seem to be a nice way to do that (other than checking against some whitelist of backends). Given that, and the fact that the message I listed above is not that bad, I'd keep this as-is for now. Alternatively, we could opt for throwing an error/assert instead of a warning, but again this would require a whitelist of backends. WDYT?

Contributor Author

For additional context: on my laptop, the example runs nicely although the mode is not interactive by default.
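The list-based check mentioned in this thread could be sketched as below. It is not part of the PR. The set of non-GUI backend names is an assumption based on Matplotlib's built-in non-interactive backends and may be incomplete (for instance, third-party `module://` backends are not covered), and both function names are hypothetical.

```python
import warnings

# Assumed list of Matplotlib's built-in non-GUI backends (lower-case).
KNOWN_NON_GUI_BACKENDS = frozenset(
    {"agg", "cairo", "pdf", "pgf", "ps", "svg", "template"},
)

BACKEND_DOCS_URL = "https://matplotlib.org/stable/users/explain/backends.html"


def is_known_non_gui_backend(backend_name: str) -> bool:
    """Return True if the backend name is in the assumed non-GUI list."""
    return backend_name.lower() in KNOWN_NON_GUI_BACKENDS


def warn_if_non_gui_backend(backend_name: str) -> bool:
    """Warn (and return True) when figures cannot be shown on screen."""
    if is_known_non_gui_backend(backend_name):
        warnings.warn(
            f"Matplotlib backend {backend_name!r} cannot display figures; "
            f"see {BACKEND_DOCS_URL} for how to select a GUI backend.",
            UserWarning,
        )
        return True
    return False
```

In a script that uses Matplotlib, the name to pass in would come from `matplotlib.get_backend()`. Because the list only names backends known to be non-GUI, an unrecognized backend is treated as potentially able to show figures.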

src/imitation/policies/interactive.py
src/imitation/policies/interactive.py (outdated)
self.action_key_to_index = {k: i for i, k in enumerate(action_keys)}
self.action_keys_names = action_keys_names
self.action_key_to_index = {
k: i for i, k in enumerate(action_keys_names.keys())
Member

Iterating over a dict gives you the keys by default (you can leave as-is if you want to be explicit about it)

Suggested change
- k: i for i, k in enumerate(action_keys_names.keys())
+ k: i for i, k in enumerate(action_keys_names)

Contributor Author

Yeah, in this case I slightly prefer to keep it.
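A quick check that the two spellings behave identically (the mapping contents here are made up for illustration):

```python
import collections

action_keys_names = collections.OrderedDict([("n", "NOOP"), ("f", "FIRE")])

# Iterating over a mapping yields its keys, so both comprehensions
# build the same key -> index dictionary.
explicit = {k: i for i, k in enumerate(action_keys_names.keys())}
implicit = {k: i for i, k in enumerate(action_keys_names)}
```

The choice between them is purely stylistic; the explicit `.keys()` call only signals intent to the reader.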

src/imitation/policies/interactive.py
import abc
from typing import Optional, List
import collections
import typing
Member

Our style guide allows importing types directly from typing (i.e. from typing import Optional is permissible) although it's not obligatory -- fine to use this style if you prefer. https://google.github.io/styleguide/pyguide.html#2241-exemptions

Contributor Author

Thanks, good to know!

src/imitation/policies/interactive.py (outdated)
tests/policies/test_interactive.py (outdated)
with mock.patch("builtins.input", mock_input_invalid_then_valid()):
interactive_policy.predict(obs)
stdout = capsys.readouterr().out
assert "Your choice" in stdout and "Invalid" in stdout
Member

Tests for DiscreteInteractivePolicy look great. We're not testing AtariInteractivePolicy at all, though. It's pretty simple, granted, but it might still be worth testing, even if just with a simple smoke test (it runs, if we feed in a key corresponding to "FIRE" we get the correct action back, etc.).
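A smoke test of the kind proposed here could follow the same mocking pattern as the test above. The sketch below does not use the real AtariInteractivePolicy: `query_action` is a hypothetical stand-in for the policy's prompt loop, and the key bindings are assumptions.

```python
import contextlib
import io
from typing import Mapping, Sequence, Tuple
from unittest import mock


def query_action(key_to_index: Mapping[str, int]) -> int:
    """Prompt until the user types a known key; return its action index."""
    while True:
        print("Your choice (" + ", ".join(key_to_index) + "): ", end="")
        key = input()
        if key in key_to_index:
            return key_to_index[key]
        print("Invalid choice, try again.")


def run_with_inputs(
    key_to_index: Mapping[str, int],
    typed: Sequence[str],
) -> Tuple[int, str]:
    """Run the prompt with scripted keystrokes and capture what it prints."""
    buffer = io.StringIO()
    # side_effect makes each call to input() return the next scripted value.
    with mock.patch("builtins.input", side_effect=list(typed)):
        with contextlib.redirect_stdout(buffer):
            action = query_action(key_to_index)
    return action, buffer.getvalue()


# An invalid key followed by the key assumed to be bound to FIRE (index 1).
action, stdout = run_with_inputs({"n": 0, "f": 1}, ["?", "f"])
```

Under pytest, the `capsys` fixture would replace the manual `redirect_stdout` capture, and the assertions would check both the returned action and the printed prompt.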

Base automatically changed from dependency_fixes to master September 7, 2023 22:56
Member

@AdamGleave AdamGleave left a comment

LGTM

@AdamGleave AdamGleave merged commit f6a4888 into master Sep 8, 2023
7 of 9 checks passed
@AdamGleave AdamGleave deleted the 701-interactive-data branch September 8, 2023 16:20
lukasberglund pushed a commit to lukasberglund/imitation that referenced this pull request Sep 12, 2023
…tibleAI#776)

* Pin huggingface_sb3 version.

* Properly specify the compatible seals version so it does not auto-upgrade to 0.2.

* Make random_mdp test deterministic by seeding the environment.

* WIP: Introduce interactive policies to gather data from a user

* Addressing remarks from review

* fixes

* fix types

* formatting

* Dummy commit to acknowledge co-authorship.

Co-authored-by: Jason Hoelscher-Obermaier <[email protected]>

* Exclude interactive example from running during tests

* formatting

* Apply suggestions from code review

Co-authored-by: Adam Gleave <[email protected]>

* Adressing further suggestions from review

* formatting

* formatting

---------

Co-authored-by: Maximilian Ernestus <[email protected]>
Co-authored-by: Jason Hoelscher-Obermaier <[email protected]>
Co-authored-by: Adam Gleave <[email protected]>