
Use data acquired by users #768

Closed · wants to merge 13 commits

Conversation

@jas-ho (Contributor) commented Aug 10, 2023

Description

This PR introduces a new abstract base policy, InteractivePolicy, to support interactive data collection from human experts. A rough interface sketch follows the list below.

There are two implementations for demo purposes:

  • TextInteractivePolicy, which supports simple Toy Text environments like FrozenLake
  • AtariInteractivePolicy, which is meant to support Atari games like Pong. This one is a work in progress with many missing pieces. The idea is to rely on some code from gym-retro, which is added in this PR with very slight modifications at src/imitation/policies/retro_gym_.py
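
For orientation, here is a rough sketch of the base class, reconstructed from the diff fragments further down in this PR. The _render and _query_action hooks and their docstrings are taken from the diff; the _predict plumbing and everything else is an assumption, not the PR's exact code:

import abc

import numpy as np
import torch as th
from stable_baselines3.common import policies


class InteractivePolicy(policies.BasePolicy, abc.ABC):
    """Policy that queries a human for every action (sketch, not the exact PR code)."""

    def _predict(self, obs: th.Tensor, deterministic: bool = False) -> th.Tensor:
        np_obs = obs.detach().cpu().numpy()
        np_actions = []
        for np_ob in np_obs:
            # Show the current state to the human, then block on their input.
            self._render(np_ob)
            np_actions.append(self._query_action(np_ob))
        th_actions = th.as_tensor(np.stack(np_actions), device=self.device)
        return th_actions

    @abc.abstractmethod
    def _render(self, obs: np.ndarray) -> None:
        """Render the environment, optionally based on observation obs."""

    @abc.abstractmethod
    def _query_action(self, obs: np.ndarray) -> np.ndarray:
        """Query human for an action, optionally based on observation obs."""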

Limitations

  • The current interface design will not be sufficient to support real-time applications like Pong, where the environment does not pause while waiting for input.
  • It's unclear whether we want to (or can) support gym-retro environments. Maybe there are other options for demonstrating real-time data gathering?
  • Due to limited time, there are still a lot of unresolved TODO comments...
  • ...and only very basic testing.

Testing

To explore the behavior of TextInteractivePolicy, you can:

  • run python src/imitation/policies/interactive_text.py to observe a single-episode rollout of TextInteractivePolicy on FrozenLake (a usage sketch follows this list)
  • run python examples/train_dagger_with_human_demos.py to collect human demonstration data on FrozenLake and use it to train DAgger
  • play with the very rudimentary tests at tests/policies/test_interactive.py
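
For reference, a minimal sketch of what such a single-episode interactive rollout might look like. The import path, constructor signature, and environment ID are assumptions; the scripts above are the authoritative entry points:

import gym
from stable_baselines3.common.vec_env import DummyVecEnv

# Assumed import path for the policy added in this PR.
from imitation.policies.interactive_text import TextInteractivePolicy

venv = DummyVecEnv([lambda: gym.make("FrozenLake-v1")])
policy = TextInteractivePolicy(venv.observation_space, venv.action_space)

obs = venv.reset()
done = False
while not done:
    # predict() renders the board and blocks until the human enters an action.
    actions, _states = policy.predict(obs)
    obs, rewards, dones, infos = venv.step(actions)
    done = dones[0]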

To test that the interactive code taken from gym-retro is working, you can pip install gym-retro and then run

  • python src/imitation/policies/retro_gym_.py

There are no demo scripts or tests for our wrapper AtariInteractivePolicy yet, since it is still WIP.

@jas-ho jas-ho linked an issue Aug 10, 2023 that may be closed by this pull request
@jas-ho (Contributor, Author) commented Aug 10, 2023

PS: a lot of clean-up is also still needed to pass the linters.

@@ -6,7 +6,8 @@
 import gym
 import numpy as np
 import torch as th
-from stable_baselines3.common import policies, torch_layers
+from stable_baselines3.common import policies, torch_layers, type_aliases
 from stable_baselines3.common.vec_env import VecEnv
Member:
Suggested change:
- from stable_baselines3.common.vec_env import VecEnv
+ from stable_baselines3.common import vec_env

Style guide: https://google.github.io/styleguide/pyguide.html#22-imports
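
With the module import, call sites then reference the class through its module; a one-line illustration (the function name here is hypothetical):

from stable_baselines3.common import vec_env


def collect_demos(venv: vec_env.VecEnv) -> None:
    ...  # VecEnv is referenced via its module, per the style guide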

np_obs = obs.detach().cpu().numpy()
np_actions = []

for np_ob in np_obs:
Member:
Why do we iterate if we require num_envs == 1?
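
If num_envs == 1 is indeed enforced, the loop could presumably collapse to handling a single observation. A sketch of the implied simplification (the loop body is not shown in this excerpt, so the hook calls are assumptions based on the abstract methods below):

np_obs = obs.detach().cpu().numpy()
# With exactly one environment there is only np_obs[0], so no loop is needed.
self._render(np_obs[0])
np_actions = [self._query_action(np_obs[0])]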

return th_actions

@abc.abstractmethod
def _render(self, obs: np.ndarray) -> None:
Member:
Concrete implementations of this don't use obs -- do we need it?

"""Render the environment, optionally based on observation obs."""

@abc.abstractmethod
def _query_action(self, obs: np.ndarray) -> np.ndarray:
Member:
Concrete implementations of this don't use obs -- do we need it?

def _query_action(self, obs: np.ndarray) -> np.ndarray:
"""Query human for an action, optionally based on observation obs."""

def forward(self, *args):
Member:
Could inherit from HardCodedPolicy which already does this (code seems to be copied from there). Debatable whether it's "hard coded" (given it takes human input), but it's not a learned policy at least.
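
A sketch of that suggestion, assuming imitation's policies.base.HardCodedPolicy provides the forward()/_predict() plumbing around an abstract _choose_action hook (as the copied code suggests); the _render/_query_action hooks would still be declared on this class:

import abc

import numpy as np

from imitation.policies import base


class InteractivePolicy(base.HardCodedPolicy, abc.ABC):
    # forward() and the tensor/numpy conversion in _predict() would be
    # inherited from HardCodedPolicy, so the copied overrides could be dropped.
    def _choose_action(self, obs: np.ndarray) -> np.ndarray:
        self._render(obs)
        return self._query_action(obs)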

user_input = input(self.action_prompt).strip().lower()
if user_input in self.action_map:
    return np.array([self.action_map[user_input]])
print("Invalid input. Try again.")
Member:
Suggested change:
- print("Invalid input. Try again.")
+ print(f"Invalid input '{user_input}'. Try again.")

print("Invalid input. Try again.")


if __name__ == "__main__":
Member:
Consider moving the example code into a separate file.

@@ -0,0 +1,285 @@
# type: ignore
# flake8: noqa
Member:
We'll want to fix these linting errors, not just suppress them, prior to merge.

self._episode_returns - self._prev_episode_returns
)
self._prev_episode_returns = self._episode_returns
mess = "steps={self._steps} episode_steps={self._episode_steps} episode_returns_delta={episode_returns_delta} episode_returns={self._episode_returns}".format(
Member:
Can use f-strings in modern Python instead of .format(**locals())
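
i.e., with the same content, the message could be built as:

mess = (
    f"steps={self._steps} episode_steps={self._episode_steps} "
    f"episode_returns_delta={episode_returns_delta} "
    f"episode_returns={self._episode_returns}"
)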

from pyglet.window import key as keycodes


class Interactive(abc.ABC):
Member:
We probably want to refactor some of this Retro code before merging.

Linked issue (may be closed by this pull request): Use data acquired by users