diff --git a/src/imitation/algorithms/adversarial/common.py b/src/imitation/algorithms/adversarial/common.py
index 013f4aaf7..f4008fd09 100644
--- a/src/imitation/algorithms/adversarial/common.py
+++ b/src/imitation/algorithms/adversarial/common.py
@@ -101,6 +101,8 @@ def __init__(self, adversarial_trainer, *args, **kwargs):
         """Builds TrainDiscriminatorCallback.
 
         Args:
+            adversarial_trainer: The AdversarialTrainer instance in which
+                this callback will be called.
            *args: Passed through to `callbacks.BaseCallback`.
            **kwargs: Passed through to `callbacks.BaseCallback`.
         """
@@ -276,7 +278,7 @@ def __init__(
             # Would use an identity reward fn here, but RewardFns can't see rewards.
             self.venv_wrapped = self.venv_buffering
             self.gen_callback: List[callbacks.BaseCallback] = [
-                self.disc_trainer_callback
+                self.disc_trainer_callback,
             ]
         else:
             self.venv_wrapped = reward_wrapper.RewardVecEnvWrapper(
@@ -369,7 +371,7 @@ def update_rewards_of_rollouts(self) -> None:
             buffer = self.gen_algo.rollout_buffer
             assert buffer is not None
             reward_fn_inputs = replay_buffer_wrapper._rollout_buffer_to_reward_fn_input(
-                self.gen_algo.rollout_buffer
+                self.gen_algo.rollout_buffer,
             )
             rewards = self._reward_net.predict(**reward_fn_inputs)
             rewards = rewards.reshape(buffer.rewards.shape)
@@ -380,13 +382,14 @@ def update_rewards_of_rollouts(self) -> None:
             last_dones = last_values == 0.0
             self.gen_algo.rollout_buffer.rewards[:] = rewards
             self.gen_algo.rollout_buffer.compute_returns_and_advantage(
-                th.tensor(last_values), last_dones
+                th.tensor(last_values),
+                last_dones,
             )
         elif isinstance(self.gen_algo, off_policy_algorithm.OffPolicyAlgorithm):
             buffer = self.gen_algo.replay_buffer
             assert buffer is not None
             reward_fn_inputs = replay_buffer_wrapper._replay_buffer_to_reward_fn_input(
-                buffer
+                buffer,
             )
             rewards = self._reward_net.predict(**reward_fn_inputs)
             buffer.rewards[:] = rewards.reshape(buffer.rewards.shape)
@@ -465,13 +468,15 @@ def train_disc(
 
         return train_stats
 
-    def train_gen(
+    def train_gen_with_disc(
         self,
         total_timesteps: Optional[int] = None,
         learn_kwargs: Optional[Mapping] = None,
     ) -> None:
         """Trains the generator to maximize the discriminator loss.
 
+        The discriminator is also trained after the rollouts are collected and before
+        the generator is trained.
         After the end of training populates the generator replay buffer (used in
         discriminator training) with `self.disc_batch_size` transitions.
 
@@ -502,7 +507,7 @@ def train(
     ) -> None:
         """Alternates between training the generator and discriminator.
 
-        Every "round" consists of a call to `train_gen(self.gen_train_timesteps)`,
+        Every "round" consists of a call to `train_gen_with_disc(self.gen_train_timesteps)`,
         a call to `train_disc`, and finally a call to `callback(round)`.
 
         Training ends once an additional "round" would cause the number of transitions
@@ -522,7 +527,7 @@ def train(
                 f"total_timesteps={total_timesteps})!"
             )
         for r in tqdm.tqdm(range(0, n_rounds), desc="round"):
-            self.train_gen(self.gen_train_timesteps)
+            self.train_gen_with_disc(self.gen_train_timesteps)
             if callback:
                 callback(r)
             self.logger.dump(self._global_step)
@@ -610,7 +615,8 @@ def _make_disc_train_batches(
         if gen_samples is None:
             if self._gen_replay_buffer.size() == 0:
                 raise RuntimeError(
-                    "No generator samples for training. " "Call `train_gen()` first.",
+                    "No generator samples for training. "
+                    "Call `train_gen_with_disc()` first.",
                 )
             gen_samples_dataclass = self._gen_replay_buffer.sample(batch_size)
             gen_samples = types.dataclass_quick_asdict(gen_samples_dataclass)
diff --git a/tests/algorithms/test_adversarial.py b/tests/algorithms/test_adversarial.py
index d3609efaa..769b2d52f 100644
--- a/tests/algorithms/test_adversarial.py
+++ b/tests/algorithms/test_adversarial.py
@@ -231,8 +231,9 @@ def test_train_gen_train_disc_no_crash(
     trainer_parametrized: common.AdversarialTrainer,
     n_updates: int = 2,
 ) -> None:
-    trainer_parametrized.train_gen(n_updates * trainer_parametrized.gen_train_timesteps)
-    trainer_parametrized.train_disc()
+    trainer_parametrized.train_gen_with_disc(
+        n_updates * trainer_parametrized.gen_train_timesteps
+    )
 
 
 @pytest.fixture