
Commit

Change the SQIL structure to instead subclass the replay buffer, new test
RedTachyon committed Jul 7, 2023
1 parent ae43a75 commit 5b23f84
Showing 3 changed files with 178 additions and 156 deletions.
260 changes: 118 additions & 142 deletions src/imitation/algorithms/sqil.py
@@ -7,7 +7,7 @@

import numpy as np
import torch as th
-import torch.nn.functional as F
+from gym import spaces
from stable_baselines3 import dqn
from stable_baselines3.common import buffers, policies, type_aliases, vec_env
from stable_baselines3.dqn.policies import DQNPolicy
@@ -107,8 +107,6 @@ def __init__(
"""
self.venv = venv

-        super().__init__(demonstrations=demonstrations, custom_logger=custom_logger)

self.dqn = dqn.DQN(
policy=policy,
env=venv,
@@ -120,8 +118,8 @@ def __init__(
gamma=gamma,
train_freq=train_freq,
gradient_steps=gradient_steps,
-            replay_buffer_class=replay_buffer_class,
-            replay_buffer_kwargs=replay_buffer_kwargs,
+            replay_buffer_class=SQILReplayBuffer,
+            replay_buffer_kwargs={"demonstrations": demonstrations},
optimize_memory_usage=optimize_memory_usage,
target_update_interval=target_update_interval,
exploration_fraction=exploration_fraction,
@@ -136,7 +134,91 @@ def __init__(
_init_setup_model=_init_setup_model,
)

+        super().__init__(demonstrations=demonstrations, custom_logger=custom_logger)

def set_demonstrations(self, demonstrations: algo_base.AnyTransitions) -> None:
assert isinstance(self.dqn.replay_buffer, SQILReplayBuffer)
self.dqn.replay_buffer.set_demonstrations(demonstrations)

def train(self, *, total_timesteps: int):
self.dqn.learn(total_timesteps=total_timesteps)

@property
def policy(self) -> policies.BasePolicy:
assert isinstance(self.dqn.policy, policies.BasePolicy)
return self.dqn.policy


class SQILReplayBuffer(buffers.ReplayBuffer):
    """A replay buffer that mixes expert demonstrations into every sampled batch.

    Batches returned by ``sample`` are half transitions collected by the learner
    (relabeled with reward 0) and half expert transitions (relabeled with reward 1),
    as required by SQIL.

    :param buffer_size: Max number of elements in the buffer
    :param observation_space: Observation space
    :param action_space: Action space
    :param demonstrations: Expert demonstrations
    :param device: PyTorch device
    :param n_envs: Number of parallel environments
    :param optimize_memory_usage: Enable a memory efficient variant
        of the replay buffer which reduces by almost a factor two the memory used,
        at a cost of more complexity.
        See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195
        and https://github.com/DLR-RM/stable-baselines3/pull/28#issuecomment-637559274
        Cannot be used in combination with handle_timeout_termination.
    """

def __init__(
self,
buffer_size: int,
observation_space: spaces.Space,
action_space: spaces.Space,
demonstrations: algo_base.AnyTransitions,
device: Union[th.device, str] = "auto",
n_envs: int = 1,
optimize_memory_usage: bool = False,
):
"""A modification of the SB3 ReplayBuffer.
This buffer is fundamentally the same as ReplayBuffer,
but it includes an expert demonstration internal buffer.
When sampling a batch of data, it will be 50/50 expert and collected data.
Args:
            buffer_size: Max number of elements in the buffer
observation_space: Observation space
action_space: Action space
demonstrations: Expert demonstrations.
device: PyTorch device.
n_envs: Number of parallel environments. Defaults to 1.
optimize_memory_usage: Enable a memory efficient variant
of the replay buffer which reduces by almost a factor two
the memory used, at a cost of more complexity.
"""
super().__init__(
buffer_size,
observation_space,
action_space,
device,
n_envs,
optimize_memory_usage,
handle_timeout_termination=False,
)

self.expert_buffer = self.set_demonstrations(demonstrations)

def set_demonstrations(
self,
demonstrations: algo_base.AnyTransitions,
) -> buffers.ReplayBuffer:
"""Set the demonstrations to be used in the buffer.
Args:
demonstrations (algo_base.AnyTransitions): Expert demonstrations.
Returns:
buffers.ReplayBuffer: The buffer with demonstrations added
"""
# If demonstrations is a list of trajectories,
# flatten it into a list of transitions
if isinstance(demonstrations, Iterable):
@@ -152,15 +234,15 @@ def set_demonstrations(self, demonstrations: algo_base.AnyTransitions) -> None:
)

n_samples = len(demonstrations) # type: ignore[arg-type]
-        self.expert_buffer = buffers.ReplayBuffer(
+        expert_buffer = buffers.ReplayBuffer(
n_samples,
-            self.venv.observation_space,
-            self.venv.action_space,
+            self.observation_space,
+            self.action_space,
handle_timeout_termination=False,
)

for transition in demonstrations:
-            self.expert_buffer.add(
+            expert_buffer.add(
obs=np.array(transition["obs"]), # type: ignore[index]
next_obs=np.array(transition["next_obs"]), # type: ignore[index]
action=np.array(transition["acts"]), # type: ignore[index]
@@ -169,146 +251,40 @@ def set_demonstrations(self, demonstrations: algo_base.AnyTransitions) -> None:
                infos=[{}],
            )

-    def train(self, *, total_timesteps: int):
-        self.learn_dqn(total_timesteps=total_timesteps)

-    @property
-    def policy(self) -> policies.BasePolicy:
-        assert isinstance(self.dqn.policy, policies.BasePolicy)
-        return self.dqn.policy

-    def train_dqn(self, gradient_steps: int, batch_size: int = 100) -> None:

-        # Needed to make mypy happy, because SB3 typing is shoddy
-        assert isinstance(self.dqn.policy, policies.BasePolicy)

-        # Switch to train mode (this affects batch norm / dropout)
-        self.dqn.policy.set_training_mode(True)
-        # Update learning rate according to type_aliases.Schedule
-        self.dqn._update_learning_rate(self.dqn.policy.optimizer)

-        losses = []
-        for _ in range(gradient_steps):
-            # Sample replay buffer
-            new_data = self.dqn.replay_buffer.sample(  # type: ignore[union-attr]
-                batch_size // 2,
-                env=self.dqn._vec_normalize_env,
-            )
-            new_data.rewards.zero_()  # Zero out the rewards

-            expert_data = self.expert_buffer.sample(
-                batch_size // 2,
-                env=self.dqn._vec_normalize_env,
-            )

-            expert_data.rewards.fill_(1)  # Fill the rewards with 1

-            # Concatenate the two batches of data
-            replay_data = type_aliases.ReplayBufferSamples(
-                *(
-                    th.cat((getattr(new_data, name), getattr(expert_data, name)))
-                    for name in new_data._fields
-                ),
-            )

-            with th.no_grad():
-                # Compute the next Q-values using the target network
-                next_q_values = self.dqn.q_net_target(replay_data.next_observations)
-                # Follow greedy policy: use the one with the highest value
-                next_q_values, _ = next_q_values.max(dim=1)
-                # Avoid potential broadcast issue
-                next_q_values = next_q_values.reshape(-1, 1)
-                # 1-step TD target
-                target_q_values = (
-                    replay_data.rewards
-                    + (1 - replay_data.dones) * self.dqn.gamma * next_q_values
-                )

-            # Get current Q-values estimates
-            current_q_values = self.dqn.q_net(replay_data.observations)

-            # Retrieve the q-values for the actions from the replay buffer
-            current_q_values = th.gather(
-                current_q_values,
-                dim=1,
-                index=replay_data.actions.long(),
-            )

-            # Compute Huber loss (less sensitive to outliers)
-            loss = F.smooth_l1_loss(current_q_values, target_q_values)
-            losses.append(loss.item())

-            # Optimize the policy
-            self.dqn.policy.optimizer.zero_grad()
-            loss.backward()
-            # Clip gradient norm
-            # For some reason pytype doesn't see nn.utils, so adding a type ignore
-            th.nn.utils.clip_grad_norm_(  # type: ignore[module-attr]
-                self.dqn.policy.parameters(),
-                self.dqn.max_grad_norm,
-            )
-            self.dqn.policy.optimizer.step()

-        # Increase update counter
-        self.dqn._n_updates += gradient_steps

-        self.dqn.logger.record(
-            "train/n_updates",
-            self.dqn._n_updates,
-            exclude="tensorboard",
-        )
-        self.dqn.logger.record("train/loss", np.mean(losses))

-    def learn_dqn(
-        self,
-        total_timesteps: int,
-        callback: type_aliases.MaybeCallback = None,
-        log_interval: int = 4,
-        tb_log_name: str = "run",
-        reset_num_timesteps: bool = True,
-        progress_bar: bool = False,
-    ) -> None:

-        total_timesteps, callback = self.dqn._setup_learn(
-            total_timesteps,
-            callback,
-            reset_num_timesteps,
-            tb_log_name,
-            progress_bar,
-        )

-        callback.on_training_start(locals(), globals())

-        while self.dqn.num_timesteps < total_timesteps:
-            rollout = self.dqn.collect_rollouts(
-                self.dqn.env,  # type: ignore[arg-type]  # This is from SB3 code
-                train_freq=self.dqn.train_freq,  # type: ignore[arg-type]  # SB3
-                action_noise=self.dqn.action_noise,
-                callback=callback,
-                learning_starts=self.dqn.learning_starts,
-                replay_buffer=self.dqn.replay_buffer,  # type: ignore[arg-type]  # SB3
-                log_interval=log_interval,
-            )

-            if rollout.continue_training is False:
-                break

-            if (
-                self.dqn.num_timesteps > 0
-                and self.dqn.num_timesteps > self.dqn.learning_starts
-            ):
-                # If no `gradient_steps` is specified,
-                # do as many gradients steps as steps performed during the rollout
-                gradient_steps = (
-                    self.dqn.gradient_steps
-                    if self.dqn.gradient_steps >= 0
-                    else rollout.episode_timesteps
-                )
-                # Special case when the user passes `gradient_steps=0`
-                if gradient_steps > 0:
-                    self.train_dqn(
-                        batch_size=self.dqn.batch_size,
-                        gradient_steps=gradient_steps,
-                    )

-        callback.on_training_end()

+        return expert_buffer

+    def sample(
+        self,
+        batch_size: int,
+        env: Optional[vec_env.VecNormalize] = None,
+    ) -> buffers.ReplayBufferSamples:
+        """Sample a batch of data.

+        Half of the batch will be from the expert buffer,
+        and the other half will be from the collected data.

+        Args:
+            batch_size: Number of elements to sample in total
+            env: associated gym VecEnv to normalize the observations/rewards
+                when sampling

+        Returns:
+            A batch of samples for DQN
+        """
+        new_sample_size, expert_sample_size = util.split_in_half(batch_size)

+        new_sample = super().sample(new_sample_size, env)
+        new_sample.rewards.fill_(0)

+        expert_sample = self.expert_buffer.sample(expert_sample_size, env)
+        expert_sample.rewards.fill_(1)

+        replay_data = type_aliases.ReplayBufferSamples(
+            *(
+                th.cat((getattr(new_sample, name), getattr(expert_sample, name)))
+                for name in new_sample._fields
+            ),
+        )

+        return replay_data
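For orientation, here is a minimal usage sketch of the refactored trainer. It assumes the trainer class defined in sqil.py is named SQIL and accepts the constructor arguments visible in this diff (venv, demonstrations, policy); the CartPole environment and the dummy Transitions object are purely illustrative stand-ins for real expert data.

import gym
import numpy as np
from stable_baselines3.common.vec_env import DummyVecEnv

from imitation.algorithms.sqil import SQIL
from imitation.data.types import Transitions

# Tiny stand-in demonstration set; real expert data would be loaded instead.
n = 8
obs = np.zeros((n, 4), dtype=np.float32)
demonstrations = Transitions(
    obs=obs,
    acts=np.zeros((n,), dtype=np.int64),
    infos=np.array([{}] * n),
    next_obs=obs,
    dones=np.zeros((n,), dtype=bool),
)

venv = DummyVecEnv([lambda: gym.make("CartPole-v1")])
trainer = SQIL(
    venv=venv,
    demonstrations=demonstrations,
    policy="MlpPolicy",
)
# DQN now samples from SQILReplayBuffer, so every batch is half expert
# transitions (relabeled with reward 1) and half collected transitions
# (relabeled with reward 0).
trainer.train(total_timesteps=10_000)

With this structure, reward relabeling happens inside SQILReplayBuffer.sample rather than in a custom training loop, which is why the hand-rolled train_dqn and learn_dqn methods could be deleted in favor of plain DQN.learn.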
15 changes: 15 additions & 0 deletions src/imitation/util/util.py
@@ -445,3 +445,18 @@ def parse_optional_path(
return None
else:
return parse_path(path, allow_relative, base_directory)


def split_in_half(x: int) -> Tuple[int, int]:
    """Split an integer into two halves that sum to the original integer.

    If `x` is odd, the second half is the larger one (rounded up).

    Args:
        x: The integer to split.

    Returns:
        A tuple containing the two halves of `x`.
    """
    half = x // 2
    return half, x - half
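A quick sanity check of the new helper, worked out by hand from the definition above (the import path follows the file header src/imitation/util/util.py):

from imitation.util.util import split_in_half

assert split_in_half(6) == (3, 3)
# For an odd total, the second half is the larger one, so the two sample
# sizes used by SQILReplayBuffer.sample still add up to batch_size.
assert split_in_half(7) == (3, 4)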
