-
Notifications
You must be signed in to change notification settings - Fork 246
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Use data acquired by users #768
Conversation
PS: also a lot of clean-up still needed to pass linters.. |
@@ -6,7 +6,8 @@ | |||
import gym | |||
import numpy as np | |||
import torch as th | |||
from stable_baselines3.common import policies, torch_layers | |||
from stable_baselines3.common import policies, torch_layers, type_aliases | |||
from stable_baselines3.common.vec_env import VecEnv |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
from stable_baselines3.common.vec_env import VecEnv | |
from stable_baselines3.common import vec_env |
Style guide https://google.github.io/styleguide/pyguide.html#22-imports
np_obs = obs.detach().cpu().numpy() | ||
np_actions = [] | ||
|
||
for np_ob in np_obs: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do we iterate if we require num_envs == 1
?
return th_actions | ||
|
||
@abc.abstractmethod | ||
def _render(self, obs: np.ndarray) -> None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Concrete implementations of this don't use obs
-- do we need it?
"""Render the environment, optionally based on observation obs.""" | ||
|
||
@abc.abstractmethod | ||
def _query_action(self, obs: np.ndarray) -> np.ndarray: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Concrete implementations of this don't use obs
-- do we need it?
def _query_action(self, obs: np.ndarray) -> np.ndarray: | ||
"""Query human for an action, optionally based on observation obs.""" | ||
|
||
def forward(self, *args): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could inherit from HardCodedPolicy
which already does this (code seems to be copied from there). Debatable whether it's "hard coded" (given it takes human input), but it's not a learned policy at least.
user_input = input(self.action_prompt).strip().lower() | ||
if user_input in self.action_map: | ||
return np.array([self.action_map[user_input]]) | ||
print("Invalid input. Try again.") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
print("Invalid input. Try again.") | |
print(f"Invalid input '{input}. Try again.") |
print("Invalid input. Try again.") | ||
|
||
|
||
if __name__ == "__main__": |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Consider moving example code into separate file
@@ -0,0 +1,285 @@ | |||
# type: ignore | |||
# flake8: noqa |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We'll want to fix these linting errors not just suppress them prior to merge.
self._episode_returns - self._prev_episode_returns | ||
) | ||
self._prev_episode_returns = self._episode_returns | ||
mess = "steps={self._steps} episode_steps={self._episode_steps} episode_returns_delta={episode_returns_delta} episode_returns={self._episode_returns}".format( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can use f-strings in modern Python instead of .format(**locals())
from pyglet.window import key as keycodes | ||
|
||
|
||
class Interactive(abc.ABC): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We probably want to refactor some of this Retro code before merging.
Description
This PR introduces a new abstract base policy
InteractivePolicy
to support interactive data collection from human experts.There are two implementations for demo purposes:
TextInteractivePolicy
which supports simple Toy Text environments like FrozenLakeAtariInteractivePolicy
which is meant to support Atari games like Pong. This one is work in progress with many missing pieces. The idea is to rely on some code from retro-gym which is added in this PR with very slight modifications at src/imitation/policies/retro_gym_.pyLimitations
retro-gym
environments. Maybe there are other options for demonstrating continuous-time data gathering?Testing
To explore the behavior of
TextInteractivePolicy
you can runpython src/imitation/policies/interactive_text.py
to observe a single episode rollout ofTextInteractivePolicy
on FrozenLakepython examples/train_dagger_with_human_demos.py
to run a script collecting human demonstration data on FrozenLake and using it to train DAggerTo test that the interactive code taken from retro-gym is working you can
pip install gym-retro
and then runpython src/imitation/policies/retro_gym_.py
There are no demo scripts or tests for our wrapper
AtariInteractivePolicy
yet since that is still WIP.