Commit: bug fixes
StoneT2000 committed Mar 4, 2024
1 parent 0740017 commit 301f53d
Showing 7 changed files with 221 additions and 86 deletions.
31 changes: 30 additions & 1 deletion mani_skill2/utils/common.py
@@ -1,5 +1,5 @@
from collections import OrderedDict, defaultdict
from typing import Dict, Sequence
from typing import Dict, Sequence, Union

import gymnasium as gym
import numpy as np
@@ -27,6 +27,21 @@ def dict_merge(dct: dict, merge_dct: dict):
dct[k] = merge_dct[k]


def append_data(x1: Union[dict, Sequence, Array], x2: Union[dict, Sequence, Array]):
"""Append `x2` in front of `x1` and returns the result. Tries to do this in place if possible.
Assumes both `x1, x2` have the same dictionary structure if they are dictionaries.
They may also both be lists/sequences in which case this is just appending like normal"""
if isinstance(x1, np.ndarray):
return np.concatenate([x1, x2])
elif isinstance(x1, list):
return x1 + x2
elif isinstance(x1, dict):
for k in x1.keys():
assert k in x2, "x1 and x2 need to have the same dictionary layout"
x1[k] = append_data(x1[k], x2[k])
return x1
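
For illustration, a minimal usage sketch of `append_data` (not part of the committed code; it assumes both inputs already share the same layout):

import numpy as np
from mani_skill2.utils.common import append_data

# arrays are concatenated along the first axis
a = np.zeros((2, 3))
b = np.ones((1, 3))
print(append_data(a, b).shape)  # (3, 3)

# nested dicts are merged leaf by leaf, in place where possible
buf = {"obs": {"qpos": np.zeros((2, 7))}, "reward": np.zeros((2,))}
new = {"obs": {"qpos": np.ones((1, 7))}, "reward": np.ones((1,))}
buf = append_data(buf, new)
print(buf["obs"]["qpos"].shape)  # (3, 7)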


def merge_dicts(ds: Sequence[Dict], asarray=False):
"""Merge multiple dicts with the same keys to a single one."""
# NOTE(jigu): To be compatible with generator, we only iterate once.
@@ -308,3 +323,17 @@ def flatten_dict_space_keys(space: spaces.Dict, prefix="") -> spaces.Dict:
else:
out[prefix + k] = v
return spaces.Dict(out)


def find_max_episode_steps_value(env):
cur = env
while cur is not None:
if hasattr(cur, "max_episode_steps"):
return cur.max_episode_steps
if cur.spec is not None and cur.spec.max_episode_steps is not None:
return cur.spec.max_episode_steps
if hasattr(cur, "env"):
cur = cur.env
else:
cur = None
return None
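
A rough sketch of what this helper does when walking a wrapper chain, using a plain gymnasium environment purely for illustration:

import gymnasium as gym
from mani_skill2.utils.common import find_max_episode_steps_value

# CartPole-v1 is registered with a 500-step time limit; the helper walks the
# wrapper chain via `.env` and returns the first limit it finds, either from a
# `max_episode_steps` attribute or from the env spec.
env = gym.make("CartPole-v1")
print(find_max_episode_steps_value(env))  # 500
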
2 changes: 2 additions & 0 deletions mani_skill2/utils/sapien_utils.py
@@ -111,6 +111,8 @@ def _batch(array: Union[Array, Sequence]):


def batch(*args: Tuple[Union[Array, Sequence]]):
"""Adds one dimension in front of everything. If given a dictionary, every leaf in the dictionary
has a new dimension. If given a tuple, returns the same tuple with each element batched"""
x = [_batch(x) for x in args]
if len(args) == 1:
return x[0]
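
A quick usage sketch of `batch` based on the docstring above (assuming numpy inputs):

import numpy as np
from mani_skill2.utils.sapien_utils import batch

obs = {"qpos": np.zeros(7), "rgb": np.zeros((64, 64, 3))}
batched = batch(obs)
print(batched["qpos"].shape, batched["rgb"].shape)  # (1, 7) (1, 64, 64, 3)

# multiple arguments come back batched individually
a, b = batch(np.zeros(3), np.ones(3))
print(a.shape, b.shape)  # (1, 3) (1, 3)
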
182 changes: 129 additions & 53 deletions mani_skill2/utils/wrappers/record.py
@@ -12,11 +12,19 @@

from mani_skill2 import get_commit_info, logger
from mani_skill2.envs.sapien_env import BaseEnv
from mani_skill2.utils.sapien_utils import to_numpy

from ..common import extract_scalars_from_info, flatten_dict_keys
from ..io_utils import dump_json
from ..visualization.misc import images_to_video, put_info_on_image, tile_images
from mani_skill2.utils.common import (
append_data,
extract_scalars_from_info,
find_max_episode_steps_value,
flatten_dict_keys,
)
from mani_skill2.utils.io_utils import dump_json
from mani_skill2.utils.sapien_utils import batch, to_numpy
from mani_skill2.utils.visualization.misc import (
images_to_video,
put_info_on_image,
tile_images,
)


def parse_env_info(env: gym.Env):
@@ -37,6 +45,14 @@ def parse_env_info(env: gym.Env):
)


def temp_deep_print_shapes(x, prefix=""):
if isinstance(x, dict):
for k in x:
temp_deep_print_shapes(x[k], prefix=prefix + "/" + k)
else:
print(prefix, x.shape)


def clean_trajectories(h5_file: h5py.File, json_dict: dict, prune_empty_action=True):
"""Clean trajectories by renaming and pruning trajectories in place.
@@ -87,8 +103,6 @@ class Step:
reward: np.ndarray
terminated: np.ndarray
truncated: np.ndarray
env_ptrs: np.ndarray
"""track the step of each parallel env?"""


def pack_step_data(state, obs, action, rew, terminated, truncated, info):
@@ -105,7 +119,8 @@ def pack_step_data(state, obs, action, rew, terminated, truncated, info):


class RecordEpisode(gym.Wrapper):
"""Record trajectories or videos for episodes.
"""Record trajectories or videos for episodes. You generally should always apply this wrapper last, particularly if you include
observation wrappers which modify the returned observations.
Trajectory data is saved with two files, the actual data in a .h5 file via H5py and metadata in a JSON file of the same basename.
@@ -140,7 +155,6 @@ class RecordEpisode(gym.Wrapper):
- actions: [T, A], `np.float32`. `T` is the number of transitions.
- success: [T], `np.bool_`. It indicates whether the task is successful at each time step.
- env_states: [T+1, D], `np.float32`. Environment states. It can be used to set the environment to a certain state, e.g., `env.set_state(env_states[i])`. However, it may not be enough to reproduce the trajectory.
- env_init_state: [D], `np.float32`. The initial environment state. It is used for soft-body environments, since their states (particle positions) can use too much space.
- obs (optional): observations. If the observation is a `dict`, the value will be stored in `obs/{key}`. The convention is applied recursively for nested dict.
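
As a sketch of how a saved file can be read back: trajectory groups are named `traj_{episode_id}` (see `flush_trajectory` below) and the keys follow the layout described above; the file path here is hypothetical.

import h5py

with h5py.File("demos/trajectory.h5", "r") as f:  # hypothetical output path
    traj = f["traj_0"]
    actions = traj["actions"][:]   # [T, A] float32
    success = traj["success"][:]   # [T] bool
    print(actions.shape, success.mean())
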
@@ -150,10 +164,11 @@ class RecordEpisode(gym.Wrapper):
save_trajectory: whether to save trajectory
trajectory_name: name of trajectory file (.h5). Use timestamp if not provided.
save_video: whether to save video
render_mode: rendering mode passed to `env.render`
save_on_reset: whether to save the previous trajectory automatically when resetting. Note that during partial resets on GPU simulation a video won't be saved.
It will only be saved if a full reset is done or user calls flush_video(). For recording videos on the GPU (to leverage fast parallel rendering) we recommend
setting max_steps_per_video to a fixed number so that every max_steps_per_video steps a video is saved.
info_on_video: whether to write data about reward and data in the info object to the video
save_on_reset: whether to save the previous trajectory (and a video of it if `save_video` is True) automatically when resetting.
Note that for environments simulated on the GPU (to leverage fast parallel rendering) you must
set `max_steps_per_video` to a fixed number so that a video is saved every `max_steps_per_video` steps. This is
required as there may be partial environment resets, which make it ambiguous how to save/cut videos.
max_steps_per_video: how many steps can be recorded into a single video before flushing the video. If None this is not used. An internal step counter is maintained to do this.
If the video is flushed at any point, the step counter is reset to 0.
clean_on_close: whether to rename and prune trajectories when closed.
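
A minimal usage sketch (the environment id and keyword values below are illustrative, not prescribed by this commit):

import gymnasium as gym
import mani_skill2.envs  # registers ManiSkill environments
from mani_skill2.utils.wrappers.record import RecordEpisode

env = gym.make("PickCube-v0")
# apply RecordEpisode last so recorded observations match what the agent sees
env = RecordEpisode(env, output_dir="demos", save_trajectory=True, save_video=True)

obs, info = env.reset(seed=0)
for _ in range(50):
    obs, rew, terminated, truncated, info = env.step(env.action_space.sample())
    if terminated or truncated:
        obs, info = env.reset()
env.close()
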
@@ -176,23 +191,27 @@ def __init__(
init_state_only=False,
video_fps=20,
):
# NOTE (stao): replay by action is not really needed; only replay by state for visuals, otherwise just train directly.
super().__init__(env)

self.output_dir = Path(output_dir)
self.init_state_only = init_state_only
if save_trajectory or save_video:
self.output_dir.mkdir(parents=True, exist_ok=True)
self.save_on_reset = save_on_reset
self.video_fps = video_fps
self._episode_id = -1
self._episode_data = []

self._trajectory_buffer: Step = None
self._episode_info = {}

self.save_on_reset = save_on_reset
self.save_trajectory = save_trajectory
if self._base_env.num_envs > 1:
# TODO (stao): fix trajectory saving on gpu simulation.
assert self.save_trajectory == False
if self._base_env.num_envs > 1 and save_video:
assert (
max_steps_per_video is not None
), "On GPU parallelized environments, \
there must be a given max steps per video value in order to flush videos in order \
to avoid issues caused by partial resets. If your environment does not do partial \
resets you may set max_steps_per_video equal to the max_episode_steps"
self.clean_on_close = clean_on_close
self.record_reward = record_reward
if self.save_trajectory:
@@ -221,6 +240,12 @@ def __init__(
self.max_steps_per_video = max_steps_per_video
self._video_steps = 0

self.max_episode_steps = find_max_episode_steps_value(env)

@property
def num_envs(self):
return self._base_env.num_envs

@property
def _base_env(self) -> BaseEnv:
return self.env.unwrapped
@@ -236,40 +261,65 @@ def reset(
self,
*args,
seed: Optional[Union[int, List[int]]] = None,
options: Optional[dict] = None,
options: Optional[dict] = dict(),
**kwargs,
):
skip_trajectory = False
if options is not None:
options.pop("save_trajectory", False)
options.pop("save_trajectory", False)

if self.save_on_reset and self._episode_id >= 0 and not skip_trajectory:
self.flush_trajectory(ignore_empty_transition=True)
# To make things easier, we only flush videos when there is no partial reset.
if "env_idx" not in options:
# when we just have one env, we look at save_on_reset and clear the trajectory buffer
# when there are multiple envs we save based on timesteps and must do more fine-grained management of the buffer
if (
self.num_envs == 1
and self.save_on_reset
and self._trajectory_buffer is not None
):
if not skip_trajectory:
self.flush_trajectory(ignore_empty_transition=True)
self.flush_video()
else:
self._trajectory_buffer = None

# Clear cache
self._episode_data = []
self._episode_info = {}
if not skip_trajectory:
self._episode_id += 1

reset_kwargs = copy.deepcopy(dict(seed=seed, options=options, **kwargs))
obs, info = super().reset(*args, seed=seed, options=options, **kwargs)

if self.save_trajectory:
state = self._base_env.get_state_dict()
data = pack_step_data(state, obs, None, None, None, None, None)
self._episode_data.append(data)
self._episode_info.update(
episode_id=self._episode_id,
episode_seed=getattr(self.unwrapped, "_episode_seed", None),
reset_kwargs=reset_kwargs,
control_mode=getattr(self.unwrapped, "control_mode", None),
elapsed_steps=0,
)

state_dict = self._base_env.get_state_dict()
# data = pack_step_data(state, obs, None, None, None, None, None)
# self._episode_data.append(data)
# self._episode_info.update(
# episode_id=self._episode_id,
# episode_seed=getattr(self.unwrapped, "_episode_seed", None),
# reset_kwargs=reset_kwargs,
# control_mode=getattr(self.unwrapped, "control_mode", None),
# elapsed_steps=0,
# )

if self._trajectory_buffer is None:
# Initialize trajectory buffer on the first episode based on given observation (which should be generated after all wrappers)
# TODO (stao): we do not really know the max size of the trajectory buffer since we keep it in memory until we flush?
# which for cpu env we do not know max size. gpu env we do.
self._trajectory_buffer = Step(
state=to_numpy(batch(state_dict)),
observation=to_numpy(batch(obs)),
# note first reward/action etc. are ignored when saving trajectories to disk
action=batch(self.action_space.sample()),
reward=np.zeros(
(
1,
self.num_envs,
),
dtype=float,
),
# terminated and truncated are fixed to be True at the start to indicate the start of an episode.
# an episode is done when one of these is True otherwise the trajectory is incomplete / a partial episode
terminated=np.ones((1, self.num_envs), dtype=bool),
truncated=np.ones((1, self.num_envs), dtype=bool),
)
else:
self._trajectory_buffer.observation
if self.save_video:
self._render_images.append(self.capture_image())

@@ -279,11 +329,32 @@ def step(self, action):
obs, rew, terminated, truncated, info = super().step(action)

if self.save_trajectory:
state = self.env.unwrapped.get_state()
data = pack_step_data(state, obs, action, rew, terminated, truncated, info)
self._episode_data.append(data)
self._episode_info["elapsed_steps"] += 1
self._episode_info["info"] = to_numpy(info)
if (
isinstance(truncated, bool)
and self.num_envs > 1
and self.max_episode_steps is not None
):
# this fixes the issue where gymnasium applies a non-batched timelimit wrapper
truncated = self._base_env.elapsed_steps >= self.max_episode_steps
state_dict = self._base_env.get_state_dict()
self._trajectory_buffer.state = append_data(
self._trajectory_buffer.state, to_numpy(batch(state_dict))
)
self._trajectory_buffer.observation = append_data(
self._trajectory_buffer.observation, to_numpy(batch(obs))
)
self._trajectory_buffer.reward = append_data(
self._trajectory_buffer.reward, to_numpy(batch(rew))
)
self._trajectory_buffer.terminated = append_data(
self._trajectory_buffer.terminated, to_numpy(batch(terminated))
)
self._trajectory_buffer.truncated = append_data(
self._trajectory_buffer.truncated, to_numpy(batch(truncated))
)
done = terminated | truncated
if done.any():
self.flush_trajectory()

if self.save_video:
self._video_steps += 1
@@ -306,15 +377,20 @@ def step(self, action):

return obs, rew, terminated, truncated, info
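
The per-env truncation fix in `step` above can be illustrated with plain numpy (hypothetical values; in the wrapper the step counts come from the base environment):

import numpy as np

# gymnasium's TimeLimit wrapper reports a single bool, but with parallel envs
# each environment needs its own truncation flag based on its elapsed steps
elapsed_steps = np.array([50, 12, 50, 3])
max_episode_steps = 50
truncated = elapsed_steps >= max_episode_steps   # [ True, False,  True, False]

terminated = np.array([False, True, False, False])
done = terminated | truncated  # any finished env triggers a trajectory flush
print(done.any())  # True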

def flush_trajectory(self, verbose=False, ignore_empty_transition=False):
if (
not self.save_trajectory or len(self._episode_data) == 0
): # TODO (stao): remove this, this is not intuitive as it depends on data in self.
return
if ignore_empty_transition and len(self._episode_data) == 1:
return
def flush_trajectory(
self,
verbose=False,
ignore_empty_transition=False,
flush_incomplete_trajectories=False,
):
# if ignore_empty_transition and len(self.t) == 1:
# return

# find which trajectories completed
import ipdb

ipdb.set_trace()
self._episode_id += 1
traj_id = "traj_{}".format(self._episode_id)
group = self._h5_file.create_group(traj_id, track_order=True)

13 changes: 3 additions & 10 deletions mani_skill2/vector/wrappers/gymnasium.py
@@ -5,6 +5,7 @@
from gymnasium.vector import VectorEnv

from mani_skill2.envs.sapien_env import BaseEnv
from mani_skill2.utils.common import find_max_episode_steps_value
from mani_skill2.utils.structs.types import Array


@@ -54,16 +55,8 @@ def __init__(
):
self.max_episode_steps = self.base_env.spec.max_episode_steps
if self.max_episode_steps is None:
# search wrappers to see if there is a time limit wrapper
cur = env
while cur is not None:
if cur.spec.max_episode_steps is not None:
self.max_episode_steps = cur.spec.max_episode_steps
break
if hasattr(cur, "env"):
cur = env.env
else:
cur = None
# search wrappers to find where max episode steps may have been defined
self.max_episode_steps = find_max_episode_steps_value(env)

@property
def device(self):