
Commit 25d1eef: Fix test errors
taufeeque9 committed Aug 8, 2023
1 parent a14c7d2 commit 25d1eef
Showing 2 changed files with 17 additions and 10 deletions.
22 changes: 14 additions & 8 deletions src/imitation/algorithms/adversarial/common.py
@@ -101,6 +101,8 @@ def __init__(self, adversarial_trainer, *args, **kwargs):
"""Builds TrainDiscriminatorCallback.
Args:
adversarial_trainer: The AdversarialTrainer instance in which
this callback will be called.
*args: Passed through to `callbacks.BaseCallback`.
**kwargs: Passed through to `callbacks.BaseCallback`.
"""
@@ -276,7 +278,7 @@ def __init__(
             # Would use an identity reward fn here, but RewardFns can't see rewards.
             self.venv_wrapped = self.venv_buffering
             self.gen_callback: List[callbacks.BaseCallback] = [
-                self.disc_trainer_callback
+                self.disc_trainer_callback,
             ]
         else:
             self.venv_wrapped = reward_wrapper.RewardVecEnvWrapper(
@@ -369,7 +371,7 @@ def update_rewards_of_rollouts(self) -> None:
             buffer = self.gen_algo.rollout_buffer
             assert buffer is not None
             reward_fn_inputs = replay_buffer_wrapper._rollout_buffer_to_reward_fn_input(
-                self.gen_algo.rollout_buffer
+                self.gen_algo.rollout_buffer,
             )
             rewards = self._reward_net.predict(**reward_fn_inputs)
             rewards = rewards.reshape(buffer.rewards.shape)
@@ -380,13 +382,14 @@ def update_rewards_of_rollouts(self) -> None:
             last_dones = last_values == 0.0
             self.gen_algo.rollout_buffer.rewards[:] = rewards
             self.gen_algo.rollout_buffer.compute_returns_and_advantage(
-                th.tensor(last_values), last_dones
+                th.tensor(last_values),
+                last_dones,
             )
         elif isinstance(self.gen_algo, off_policy_algorithm.OffPolicyAlgorithm):
             buffer = self.gen_algo.replay_buffer
             assert buffer is not None
             reward_fn_inputs = replay_buffer_wrapper._replay_buffer_to_reward_fn_input(
-                buffer
+                buffer,
             )
             rewards = self._reward_net.predict(**reward_fn_inputs)
             buffer.rewards[:] = rewards.reshape(buffer.rewards.shape)
@@ -465,13 +468,15 @@ def train_disc(

         return train_stats

-    def train_gen(
+    def train_gen_with_disc(
         self,
         total_timesteps: Optional[int] = None,
         learn_kwargs: Optional[Mapping] = None,
     ) -> None:
         """Trains the generator to maximize the discriminator loss.

+        The discriminator is also trained after the rollouts are collected and before
+        the generator is trained.
         After the end of training populates the generator replay buffer (used in
         discriminator training) with `self.disc_batch_size` transitions.
@@ -502,7 +507,7 @@ def train(
     ) -> None:
         """Alternates between training the generator and discriminator.

-        Every "round" consists of a call to `train_gen(self.gen_train_timesteps)`,
+        Every "round" consists of a call to `train_gen_with_disc(self.gen_train_timesteps)`,
         a call to `train_disc`, and finally a call to `callback(round)`.

         Training ends once an additional "round" would cause the number of transitions
@@ -522,7 +527,7 @@ def train(
f"total_timesteps={total_timesteps})!"
)
for r in tqdm.tqdm(range(0, n_rounds), desc="round"):
self.train_gen(self.gen_train_timesteps)
self.train_gen_with_disc(self.gen_train_timesteps)
if callback:
callback(r)
self.logger.dump(self._global_step)
@@ -610,7 +615,8 @@ def _make_disc_train_batches(
         if gen_samples is None:
             if self._gen_replay_buffer.size() == 0:
                 raise RuntimeError(
-                    "No generator samples for training. " "Call `train_gen()` first.",
+                    "No generator samples for training. "
+                    "Call `train_gen_with_disc()` first.",
                 )
             gen_samples_dataclass = self._gen_replay_buffer.sample(batch_size)
             gen_samples = types.dataclass_quick_asdict(gen_samples_dataclass)
5 changes: 3 additions & 2 deletions tests/algorithms/test_adversarial.py
@@ -231,8 +231,9 @@ def test_train_gen_train_disc_no_crash(
     trainer_parametrized: common.AdversarialTrainer,
     n_updates: int = 2,
 ) -> None:
-    trainer_parametrized.train_gen(n_updates * trainer_parametrized.gen_train_timesteps)
-    trainer_parametrized.train_disc()
+    trainer_parametrized.train_gen_with_disc(
+        n_updates * trainer_parametrized.gen_train_timesteps
+    )


@pytest.fixture
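Editor's note: taken together, the rename means callers drive training through `train_gen_with_disc` rather than `train_gen`. The following is a hypothetical manual training loop mirroring what `AdversarialTrainer.train` does internally according to its docstring in this diff; `trainer` is assumed to be an already constructed `AdversarialTrainer` subclass (for example GAIL or AIRL from `imitation.algorithms.adversarial`) built against this branch, and is not constructed here.

from imitation.algorithms.adversarial import common


def manual_training_rounds(trainer: common.AdversarialTrainer, n_rounds: int) -> None:
    # Sketch only: one "round" as described in the `train` docstring above.
    for r in range(n_rounds):
        # Collect generator rollouts; with this commit the discriminator is
        # also trained inside this call, before the generator update.
        trainer.train_gen_with_disc(trainer.gen_train_timesteps)
        # The generator replay buffer is now populated, so an explicit extra
        # discriminator update does not hit the RuntimeError shown above.
        trainer.train_disc()
        # Flush accumulated statistics for this round.
        trainer.logger.dump(r)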
