Skip to content

Commit

Permalink
Add namedconfigs for all seals envs and set the number of timesteps i…
Browse files Browse the repository at this point in the history
…n namedconfig for SQIL.
  • Loading branch information
ernestum committed Feb 27, 2024
1 parent d3860a3 commit 80edc62
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 4 deletions.
30 changes: 29 additions & 1 deletion src/imitation/scripts/config/train_imitation.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,13 @@ def seals_mountain_car():
environment = dict(gym_id="seals/MountainCar-v0")
bc = dict(l2_weight=0.0)
dagger = dict(total_timesteps=20000)
sqil = dict(total_timesteps=1e5)


@train_imitation_ex.named_config
def seals_ant():
environment = dict(gym_id="seals/Ant-v1")
sqil = dict(total_timesteps=2e6)


@train_imitation_ex.named_config
Expand All @@ -57,11 +64,13 @@ def cartpole():
def seals_cartpole():
environment = dict(gym_id="seals/CartPole-v0")
dagger = dict(total_timesteps=20000)
sqil = dict(total_timesteps=1e5)


@train_imitation_ex.named_config
def pendulum():
environment = dict(gym_id="Pendulum-v1")
sqil = dict(total_timesteps=1e5)


@train_imitation_ex.named_config
Expand All @@ -76,14 +85,33 @@ def half_cheetah():
dagger = dict(total_timesteps=60000)


@train_imitation_ex.named_config
def seals_half_cheetah():
environment = dict(gym_id="seals/HalfCheetah-v1")
sqil = dict(total_timesteps=2e6)


@train_imitation_ex.named_config
def seals_hopper():
environment = dict(gym_id="seals/Hopper-v1")
sqil = dict(total_timesteps=2e6)


@train_imitation_ex.named_config
def seals_walker():
environment = dict(gym_id="seals/Walker2d-v1")
sqil = dict(total_timesteps=2e6)


@train_imitation_ex.named_config
def humanoid():
environment = dict(gym_id="Humanoid-v2")


@train_imitation_ex.named_config
def seals_humanoid():
environment = dict(gym_id="seals/Humanoid-v0")
environment = dict(gym_id="seals/Humanoid-v1")
sqil = dict(total_timesteps=2e6)


@train_imitation_ex.named_config
Expand Down
3 changes: 0 additions & 3 deletions tuning/hp_search_spaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,9 +132,6 @@ def __call__(
"n_expert_demos": 100,
"source": "generated",
},
"sqil": {
"total_timesteps": 1e6,
},
"rl": {
"rl_kwargs": {
"learning_rate": trial.suggest_float("learning_rate", 1e-6, 1e-2, log=True),
Expand Down

0 comments on commit 80edc62

Please sign in to comment.