Skip to content

Commit

Permalink
Fix: PickSequentialTask spawning
Browse files Browse the repository at this point in the history
  • Loading branch information
arth-shukla committed Mar 3, 2024
1 parent 5ff3217 commit ccbb0fc
Showing 1 changed file with 10 additions and 12 deletions.
22 changes: 10 additions & 12 deletions mani_skill2/envs/scenes/tasks/pick.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,8 @@ def reconfigure(self):
qpos = torch.tensor(
self.agent.RESTING_QPOS[..., None].repeat(self.num_envs, axis=-1).transpose(1, 0)
).float()
accept_spawn_loc_rots = [[]] * self.num_envs
accept_dists = [[]] * self.num_envs
accept_spawn_loc_rots = [[] for _ in range(self.num_envs)]
accept_dists = [[] for _ in range(self.num_envs)]
bounding_box_corners = [
torch.tensor([dx, dy, 0]) for dx, dy in itertools.product([0.1, -0.1], [0.1, -0.1])
]
Expand Down Expand Up @@ -206,13 +206,12 @@ def reconfigure(self):

for i in torch.where(slrs_within_range & (robot_force < 1e-3))[0]:
accept_spawn_loc_rots[i].append(slrs[i].cpu().numpy().tolist())
accept_dists[i].append(dists[i].cpu().numpy().tolist())

accept_dists[i].append(dists[i].item())

self.num_spawn_loc_rots = torch.tensor([len(x) for x in accept_spawn_loc_rots])
self.spawn_loc_rots = pad_sequence([
torch.tensor(x) for x in accept_spawn_loc_rots
], batch_first=True, padding_value=0,)
], batch_first=True, padding_value=0)

self.closest_spawn_loc_rots = torch.stack([
self.spawn_loc_rots[i][torch.argmin(torch.tensor(x))] for i, x in enumerate(accept_dists)
Expand Down Expand Up @@ -253,10 +252,9 @@ def _initialize_agent(self, env_idx):
idxs = torch.tensor([
torch.randint(max_idx.item(), (1,)) for max_idx in self.num_spawn_loc_rots
])
loc_rot = self.spawn_loc_rots[torch.arange(self.num_envs), idxs].to(self.device)
loc_rot = self.spawn_loc_rots[torch.arange(self.num_envs), idxs]
else:
loc_rot = self.closest_spawn_loc_rots.to(self.device)

loc_rot = self.closest_spawn_loc_rots
robot_pos = self.agent.robot.pose.p
robot_pos[..., :2] = loc_rot[..., :2]
self.agent.robot.set_pose(Pose.create_from_pq(p=robot_pos))
Expand All @@ -265,19 +263,19 @@ def _initialize_agent(self, env_idx):
if self.randomize_base:
# base pos
robot_pos = self.agent.robot.pose.p
robot_pos[..., :2] += torch.clip(torch.normal(
robot_pos[..., :2] += torch.clamp(torch.normal(
0, 0.1, (b, len(robot_pos[0, :2]))
), -0.1, 0.1).to(self.device)
self.agent.robot.set_pose(Pose.create_from_pq(p=robot_pos))
# base rot
qpos[..., 2:3] += torch.clip(torch.normal(
qpos[..., 2:3] += torch.clamp(torch.normal(
0, 0.25, (b, len(qpos[0, 2:3]))
), -0.5, 0.5).to(self.device)
if self.randomize_arm:
qpos[..., 5:6] += torch.clip(torch.normal(
qpos[..., 5:6] += torch.clamp(torch.normal(
0, 0.05, (b, len(qpos[0, 5:6]))
), -0.1, 0.1).to(self.device)
qpos[..., 7:-2] += torch.clip(torch.normal(
qpos[..., 7:-2] += torch.clamp(torch.normal(
0, 0.05, (b, len(qpos[0, 7:-2]))
), -0.1, 0.1).to(self.device)
self.agent.reset(qpos)
Expand Down

0 comments on commit ccbb0fc

Please sign in to comment.