Skip to content

Commit

Permalink
Fix actor show/hide visual gpu sim (#221)
Browse files Browse the repository at this point in the history
* Fix: actor show/hide visual not working on gpu sim

* Fix: by default hidden poses hidden

* bug fix

* Fix: setter and getter for actor.pose returns before_hide_pose if hidden, Fix: render human also sets _hidden_objs to hidden when doen rendering

* Unit test + fix for hide/show visual

---------

Co-authored-by: StoneT2000 <[email protected]>
  • Loading branch information
arth-shukla and StoneT2000 authored Mar 2, 2024
1 parent 1defbae commit 23a5ef6
Show file tree
Hide file tree
Showing 4 changed files with 137 additions and 30 deletions.
2 changes: 0 additions & 2 deletions examples/benchmarking/benchmark_maniskill.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
import tqdm

import mani_skill2.envs
from mani_skill2.envs.scenes.tasks.planner.planner import PickSubtask
from mani_skill2.envs.scenes.tasks.sequential_task import SequentialTaskEnv
from mani_skill2.utils.scene_builder.ai2thor.variants import ArchitecTHORSceneBuilder
from mani_skill2.utils.scene_builder.replicacad.scene_builder import ReplicaCADSceneBuilder
from mani_skill2.vector.wrappers.gymnasium import ManiSkillVectorEnv
Expand Down
15 changes: 11 additions & 4 deletions mani_skill2/envs/sapien_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -644,6 +644,10 @@ def reset(self, seed=None, options=None):
self._set_episode_rng(self._episode_seed)
self.agent.reset()
self.initialize_episode(env_idx)
# reset the reset mask back to all ones so any internal code in maniskill can continue to manipulate all scenes at once as usual
self._scene._reset_mask = torch.ones(
self.num_envs, dtype=bool, device=self.device
)
obs = self.get_obs()
if physx.is_gpu_enabled():
# ensure all updates to object poses and configurations are applied on GPU after task initialization
Expand Down Expand Up @@ -962,19 +966,20 @@ def _setup_viewer(self):
)

def render_human(self):
for obj in self._hidden_objects:
obj.show_visual()
if self._viewer is None:
self._viewer = Viewer()
self._setup_viewer()
if "render_camera" in self._human_render_cameras:
self._viewer.set_camera_pose(
self._human_render_cameras["render_camera"].camera.global_pose
)

for obj in self._hidden_objects:
obj.show_visual()
if physx.is_gpu_enabled() and self._scene._gpu_sim_initialized:
self.physx_system.sync_poses_gpu_to_cpu()
self._viewer.render()
for obj in self._hidden_objects:
obj.hide_visual()
return self._viewer

def render_rgb_array(self, camera_name: str = None):
Expand Down Expand Up @@ -1009,15 +1014,17 @@ def render_rgb_array(self, camera_name: str = None):
return None
if len(images) == 1:
return images[0]
for obj in self._hidden_objects:
obj.hide_visual()
return tile_images(images)

def render_sensors(self):
"""
Renders all sensors that the agent can use and see and displays them
"""
images = []
for obj in self._hidden_objects:
obj.hide_visual()
images = []
self._scene.update_render()
self.capture_sensor_data()
sensor_images = self.get_sensor_obs()
Expand Down
53 changes: 29 additions & 24 deletions mani_skill2/utils/structs/actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,37 +162,36 @@ def hide_visual(self):
if self.hidden:
return
if physx.is_gpu_enabled():
# TODO (stao): fix hiding visuals
pass
# self.last_pose = self.px.cuda_rigid_body_data.torch()[
# self._body_data_index, :7
# ].clone()
# temp_pose = self.pose.raw_pose
# temp_pose[..., :3] += 99999
# self.pose = temp_pose
# self.px.gpu_apply_rigid_dynamic_data()
# self.px.gpu_fetch_rigid_dynamic_data()
# print("HIDE", self.pose.raw_pose[0, :3])
self.before_hide_pose = self.px.cuda_rigid_body_data.torch()[
self._body_data_index, :7
].clone()
temp_pose = self.pose.raw_pose
temp_pose[..., :3] += 99999
self.pose = temp_pose
self.px.gpu_apply_rigid_dynamic_data()
self.px.gpu_fetch_rigid_dynamic_data()
else:
self._objs[0].find_component_by_type(
sapien.render.RenderBodyComponent
).visibility = 0
# set hidden *after* setting/getting so not applied to self.before_hide_pose erroenously
self.hidden = True

def show_visual(self):
assert not self.has_collision_shapes()
if not self.hidden:
return
# set hidden *before* setting/getting so not applied to self.before_hide_pose erroenously
self.hidden = False
if physx.is_gpu_enabled():
if hasattr(self, "last_pose"):
self.pose = self.last_pose
if hasattr(self, "before_hide_pose"):
self.pose = self.before_hide_pose
self.px.gpu_apply_rigid_dynamic_data()
self.px.gpu_fetch_rigid_dynamic_data()
else:
self._objs[0].find_component_by_type(
sapien.render.RenderBodyComponent
).visibility = 1
self.hidden = False

def is_static(self, lin_thresh=1e-2, ang_thresh=1e-1):
"""
Expand Down Expand Up @@ -223,13 +222,16 @@ def pose(self) -> Pose:
# as part of observations if needed
return self._builder_initial_pose
else:
raw_pose = self.px.cuda_rigid_body_data.torch()[
self._body_data_index, :7
]
# if self.hidden:
# print(self.name, "hidden", raw_pose[0, :3])
# raw_pose[..., :3] = raw_pose[..., :3] - 99999
return Pose.create(raw_pose)
if self.hidden:
return Pose.create(self.before_hide_pose)
else:
raw_pose = self.px.cuda_rigid_body_data.torch()[
self._body_data_index, :7
]
# if self.hidden:
# print(self.name, "hidden", raw_pose[0, :3])
# raw_pose[..., :3] = raw_pose[..., :3] - 99999
return Pose.create(raw_pose)
else:
assert len(self._objs) == 1
return Pose.create(self._objs[0].pose)
Expand All @@ -239,9 +241,12 @@ def pose(self, arg1: Union[Pose, sapien.Pose, Array]) -> None:
if physx.is_gpu_enabled():
if not isinstance(arg1, torch.Tensor):
arg1 = vectorize_pose(arg1)
self.px.cuda_rigid_body_data.torch()[
self._body_data_index[self._scene._reset_mask], :7
] = arg1
if self.hidden:
self.before_hide_pose = arg1
else:
self.px.cuda_rigid_body_data.torch()[
self._body_data_index[self._scene._reset_mask], :7
] = arg1
else:
self._objs[0].pose = to_sapien_pose(arg1)

Expand Down
97 changes: 97 additions & 0 deletions tests/test_gpu_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,4 +249,101 @@ def test_timelimits(env_id):
del env


@pytest.mark.gpu_sim
@pytest.mark.parametrize("env_id", ["PickCube-v1"])
def test_hidden_objs(env_id):
env: ManiSkillVectorEnv = gym.make_vec(env_id, num_envs=16, vectorization_mode="custom")
obs, _ = env.reset()

# for PickCube, this is env.goal_site
hide_obj = env.unwrapped._hidden_objects[0]

raw_pose = hide_obj.pose.raw_pose.clone()
p = hide_obj.pose.p.clone()
q = hide_obj.pose.q.clone()
linvel = hide_obj.linear_velocity.clone()
angvel = hide_obj.angular_velocity.clone()

# hide_visual tests
hide_obj.hide_visual()

# 1. check relevant hidden properties are active
assert hide_obj.hidden
assert hasattr(hide_obj, "before_hide_pose")

# 2. check state data for new pos is not too low or high
assert (
hide_obj.px.cuda_rigid_body_data.torch()[
hide_obj._body_data_index, :7
].clone()[..., :3] > 1e3
).all()
assert (
hide_obj.px.cuda_rigid_body_data.torch()[
hide_obj._body_data_index, :7
].clone()[..., :3] < 1e6
).all()

# 3. check that linvel and angvel same as before
assert (hide_obj.linear_velocity == linvel).all()
assert (hide_obj.angular_velocity == angvel).all()

# 4. Check data stored in buffer has same q but different p
assert (
hide_obj.px.cuda_rigid_body_data.torch()[
hide_obj._body_data_index, :7
].clone()[..., :3] != p
).all()
assert (
hide_obj.px.cuda_rigid_body_data.torch()[
hide_obj._body_data_index, :7
].clone()[..., 3:] == q
).all()

# 5. Check data stored in before_hide_pose has same q and p
assert (hide_obj.before_hide_pose[..., :3] == p).all()
assert (hide_obj.before_hide_pose[..., 3:] == q).all()

# 6. check that direct calls to raw_pose, pos, and rot same as before
# (should return `before_hide_pose`)
assert (hide_obj.pose.raw_pose == raw_pose).all()
assert (hide_obj.pose.p == p).all()
assert (hide_obj.pose.q == q).all()
assert (hide_obj.pose.raw_pose == hide_obj.before_hide_pose).all()

# show_visual tests
hide_obj.show_visual()

# 1. check relevant hidden properties are active
assert not hide_obj.hidden

# 2. check that qvel, linvel, angvel same as before
assert (hide_obj.linear_velocity == linvel).all()
assert (hide_obj.angular_velocity == angvel).all()

# 3. check gpu buffer goes back to normal
print(
hide_obj.px.cuda_rigid_body_data.torch()[
hide_obj._body_data_index, :7
].clone()[..., :3]
)
print(p)
assert (
hide_obj.px.cuda_rigid_body_data.torch()[
hide_obj._body_data_index, :7
].clone()[..., :3] == p
).all()
assert (
hide_obj.px.cuda_rigid_body_data.torch()[
hide_obj._body_data_index, :7
].clone()[..., 3:] == q
).all()

# 4. check that direct calls to raw_pose, pos, and rot same as before
assert (hide_obj.pose.raw_pose == raw_pose).all()
assert (hide_obj.pose.p == p).all()
assert (hide_obj.pose.q == q).all()

env.close()
del env

# TODO (stao): Add test for tasks where there is no success/success and failure/no success or failure

0 comments on commit 23a5ef6

Please sign in to comment.