Skip to content

Commit

Permalink
Update TPGMM variants
Browse files Browse the repository at this point in the history
  • Loading branch information
Skylark0924 committed Mar 5, 2024
1 parent 6b16e7e commit da402fd
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 18 deletions.
11 changes: 11 additions & 0 deletions examples/learning_ml/example_tpgmmbi_RPCtrl.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,14 @@

# Reproductions for the same situations
Repr.reproduce([model_l, model_r, model_c], show_demo_idx=2)

# Reproductions for new situations
start_xdx_l = [np.array([-0.5, 1, 0, 0])]
end_xdx_l = [np.array([5, 4, 0, 0])]
start_xdx_r = [np.array([6.5, 7, 0, 0])]
end_xdx_r = end_xdx_l

Repr.task_params = {"left": {"frame_origins": [start_xdx_l, end_xdx_l], "frame_names": ["start", "end"]},
"right": {"frame_origins": [start_xdx_r, end_xdx_r], "frame_names": ["start", "end"]}}
traj_l, traj_r, _, _ = Repr.generate([model_l, model_r, model_c], ref_demo_idx=0)

32 changes: 16 additions & 16 deletions rofunc/learning/ml/tpgmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,13 +567,14 @@ def reproduce(self, models: List, show_demo_idx: int) -> Tuple[ndarray, ndarray,
plt.show()
return ctraj_l, ctraj_r, prod_l, prod_r

def iterative_generate(self, model_l: HMM, model_r: HMM, ref_demo_idx: int, task_params: dict, nb_iter=1) -> \
def iterative_generate(self, model_l: HMM, model_r: HMM, ref_demo_idx: int, nb_iter=1) -> \
Tuple[ndarray, ndarray, GMM, GMM]:
beauty_print('generate trajectories from learned representation with new task parameters iteratively',
type='info')

vanilla_repr = TPGMMBi(self.demos_left_x, self.demos_right_x, nb_states=self.nb_states, plot=self.plot,
save=self.save, save_params=self.save_params)
vanilla_repr = TPGMMBi(self.demos_left_x, self.demos_right_x, task_params=self.task_params,
nb_states=self.nb_states, plot=self.plot, save=self.save, save_params=self.save_params)
vanilla_repr.task_params = self.task_params
vanilla_model_l, vanilla_model_r = vanilla_repr.fit()

vanilla_traj_l, vanilla_traj_r, _, _ = vanilla_repr.generate([vanilla_model_l, vanilla_model_r],
Expand Down Expand Up @@ -601,12 +602,10 @@ def iterative_generate(self, model_l: HMM, model_r: HMM, ref_demo_idx: int, task
traj_l, traj_r = vanilla_traj_l, vanilla_traj_r
for i in range(nb_iter):

task_params['left']['traj'] = traj_l[:, :self.nb_dim]
_, ctraj_r, _, prod_r = self.conditional_generate(model_l, model_r, ref_demo_idx, task_params,
leader='left')
task_params['right']['traj'] = traj_r[:, :self.nb_dim]
_, ctraj_l, _, prod_l = self.conditional_generate(model_l, model_r, ref_demo_idx, task_params,
leader='right')
self.task_params['left']['traj'] = traj_l[:, :self.nb_dim]
_, ctraj_r, _, prod_r = self.conditional_generate(model_l, model_r, ref_demo_idx, leader='left')
self.task_params['right']['traj'] = traj_r[:, :self.nb_dim]
_, ctraj_l, _, prod_l = self.conditional_generate(model_l, model_r, ref_demo_idx, leader='right')

traj_l, traj_r = ctraj_l, ctraj_r

Expand All @@ -620,24 +619,24 @@ def iterative_generate(self, model_l: HMM, model_r: HMM, ref_demo_idx: int, task

return ctraj_l, ctraj_r, prod_l, prod_r

def conditional_generate(self, model_l: HMM, model_r: HMM, ref_demo_idx: int, task_params: dict, leader: str) -> \
def conditional_generate(self, model_l: HMM, model_r: HMM, ref_demo_idx: int, leader: str) -> \
Tuple[ndarray, ndarray, None, GMM]:
follower = 'left' if leader == 'right' else 'right'
leader_traj = task_params[leader]['traj']
leader_traj = self.task_params[leader]['traj']
models = {'left': model_l, 'right': model_r}
reprs = {'left': self.repr_l, 'right': self.repr_r}

beauty_print('generate the {} trajectory from learned representation conditioned on the {} trajectory'.format(
follower, leader), type='info')

A, b, index_list = self._get_dyna_A_b(models[follower], reprs[follower], ref_demo_idx,
task_params=task_params[follower])
task_params=self.task_params[follower])
b[:, 2, :self.nb_dim] = leader_traj[index_list, :self.nb_dim]

follower_prod = self._uni_poe(models[follower], reprs[follower], ref_demo_idx, task_params={'A': A, 'b': b})

follower_traj = reprs[follower]._reproduce(models[follower], follower_prod, ref_demo_idx,
task_params[follower]['start_xdx'])
self.task_params[follower]['start_xdx'])

data_lst = [leader_traj[:, :self.nb_dim], follower_traj[:, :self.nb_dim]]
fig = rf.visualab.traj_plot(data_lst, title='Generated bimanual trajectories in leader-follower manner')
Expand All @@ -647,11 +646,12 @@ def conditional_generate(self, model_l: HMM, model_r: HMM, ref_demo_idx: int, ta
plt.show()
return leader_traj, follower_traj, None, follower_prod

def generate(self, model_l: HMM, model_r: HMM, ref_demo_idx: int, task_params: dict, leader: str = None):
def generate(self, models: List, ref_demo_idx: int, leader: str = None):
model_l, model_r = models
if leader is None:
return self.iterative_generate(model_l, model_r, ref_demo_idx, task_params)
return self.iterative_generate(model_l, model_r, ref_demo_idx)
else:
return self.conditional_generate(model_l, model_r, ref_demo_idx, task_params, leader)
return self.conditional_generate(model_l, model_r, ref_demo_idx, leader)


class TPGMM_RPAll(TPGMM_RPRepr, TPGMM_RPCtrl):
Expand Down
9 changes: 7 additions & 2 deletions rofunc/utils/datalab/poselib/README_CHN.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,18 @@ Then you can install the `fbx sdk` in the Python 3.7 env:
4. Copy these three files to the `site-packages` folder of the Python 3.7 env you created
![img.png](img/img3.png)

For this process, you can refer to this [blog](https://zhuanlan.zhihu.com/p/585738703).
You can also refer to this [blog](https://zhuanlan.zhihu.com/p/585738703) for the installation of `fbx sdk`.

### Motion Retargeting

After installation, you can run the script with the following command:

```bash
python xsens_fbx_to_hotu_npy.py --input_fbx_path <path_to_input_fbx> --output_npy_path <path_to_output_npy>
python xsens_fbx_to_hotu_npy.py --fbx_file <path_to_fbx_file>
```

Or you can run the script with the following command to convert all the `fbx` files in a folder:

```bash
python xsens_fbx_to_hotu_npy.py --fbx_folder <path_to_fbx_folder>
```

0 comments on commit da402fd

Please sign in to comment.