
Error when set_transformations is given transformation instances instead of "ohe" or "gumbel" #41

Closed
BirkhoffG opened this issue Feb 22, 2024 · 1 comment · Fixed by #44
Labels
bug Something isn't working

Comments

@BirkhoffG
Owner

BirkhoffG commented Feb 22, 2024

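import jax
import relax

dm = relax.load_data('adult')
# cat_feats (the categorical feature names) and SoftmaxTransformation come from
# the surrounding notebook. feats_list_2 and correct_shape come from the test
# helper test_set_transformations shown in the traceback below, which is called
# with (32561, 29) as its shape argument.
dm = dm.set_transformations({
    feat: SoftmaxTransformation() for feat in cat_feats
})
x = jax.random.uniform(jax.random.PRNGKey(0), shape=(100, correct_shape[-1]))
X_constraints = feats_list_2.apply_constraints(feats_list_2.transformed_data[:100], x, hard=False)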

This raises a TypeError:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[70], line 7
      3 test_set_transformations('gumbel', (32561, 29))
      4 # TODO: [bug] raise error when set_transformations is called with 
      5 # SoftmaxTransformation() or GumbelSoftmaxTransformation(),
      6 # instead of "ohe" or "gumbel".
----> 7 test_set_transformations(SoftmaxTransformation(), (32561, 29))
      8 # test_set_transformations(GumbelSoftmaxTransformation(), (32561, 29))

Cell In[69], line 20
     17     assert feat.is_immutable is False
     19 x = jax.random.uniform(jax.random.PRNGKey(0), shape=(100, correct_shape[-1]))
---> 20 _ = feats_list_2.apply_constraints(feats_list_2.transformed_data[:100], x, hard=False)
     21 _ = feats_list_2.apply_constraints(feats_list_2.transformed_data[:100], x, hard=True)

Cell In[48], line 145
    143 def apply_constraints(self, xs, cfs, hard: bool = False, rng_key=None, **kwargs):
    144     return jnp.concatenate(
--> 145         [feat.apply_constraints(xs[:, start:end], cfs[:, start:end], 
    146                                 hard=hard, rng_key=rng_key, **kwargs) for feat, (start, end) in self.features_and_indices], axis=-1)

Cell In[48], line 145
    143 def apply_constraints(self, xs, cfs, hard: bool = False, rng_key=None, **kwargs):
    144     return jnp.concatenate(
--> 145         [feat.apply_constraints(xs[:, start:end], cfs[:, start:end], 
    146                                 hard=hard, rng_key=rng_key, **kwargs) for feat, (start, end) in self.features_and_indices], axis=-1)

Cell In[45], line 133
    132 def apply_constraints(self, xs, cfs, hard: bool = False, rng_key=None, **kwargs):
--> 133     return jax.lax.cond(
    134         self.is_immutable,
    135         true_fun=lambda xs: jnp.broadcast_to(xs, cfs.shape),
    136         false_fun=lambda _: self.transformation.apply_constraints(xs, cfs, hard=hard, rng_key=rng_key, **kwargs),
    137         operand=xs,
    138     )

    [... skipping hidden 13 frame]

Cell In[45], line 136
    132 def apply_constraints(self, xs, cfs, hard: bool = False, rng_key=None, **kwargs):
    133     return jax.lax.cond(
    134         self.is_immutable,
    135         true_fun=lambda xs: jnp.broadcast_to(xs, cfs.shape),
--> 136         false_fun=lambda _: self.transformation.apply_constraints(xs, cfs, hard=hard, rng_key=rng_key, **kwargs),
    137         operand=xs,
    138     )

Cell In[39], line 18
     17 def apply_constraints(self, xs, cfs, hard: bool = False, rng_key=None, **kwargs):
---> 18     return jax.lax.cond(
     19         hard,
     20         true_fun=self.hard_constraints,
     21         false_fun=self.soft_constraints,
     22         operand=(cfs, rng_key, kwargs),
     23     )

    [... skipping hidden 3 frame]

File ~/mambaforge-pypy3/envs/dev/lib/python3.9/site-packages/jax/_src/lax/control_flow/common.py:203, in _check_tree_and_avals(what, tree1, avals1, tree2, avals2)
    200 if not all(map(core.typematch, avals1, avals2)):
    201   diff = tree_map(_show_diff, tree_unflatten(tree1, avals1),
    202                   tree_unflatten(tree2, avals2))
--> 203   raise TypeError(f"{what} must have identical types, got\n{diff}.")

TypeError: true_fun and false_fun output must have identical types, got
DIFFERENT ShapedArray(float32[100,2]) vs. ShapedArray(float32[100,4]).
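The last frames show the transformation calling jax.lax.cond(hard, hard_constraints, soft_constraints, ...). jax.lax.cond traces both branches and requires their outputs to have identical shapes and dtypes, so the error appears even with hard=False: one branch produces 2 columns, the other 4. The sketch below reproduces that failure mode in isolation. The functions soft, hard_mismatched and hard_matched are hypothetical stand-ins, not jax-relax's actual hard_constraints/soft_constraints.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-ins for a transformation's soft/hard constraint branches.
def soft(x):
    # Soft branch keeps all 4 columns (softmax over the last axis).
    return jax.nn.softmax(x, axis=-1)

def hard_mismatched(x):
    # Buggy hard branch: returns only 2 of the 4 columns.
    return x[:, :2]

def hard_matched(x):
    # Hard branch with the same width: one-hot encoding of the argmax column.
    return jax.nn.one_hot(jnp.argmax(x, axis=-1), x.shape[-1])

x = jnp.ones((100, 4))

# lax.cond traces BOTH branches even though the predicate is False,
# so the shape mismatch is reported before either branch runs.
caught = None
try:
    jax.lax.cond(False, hard_mismatched, soft, x)
except TypeError as e:
    caught = e
    print(e)

# With matching output types, tracing succeeds and the false branch is used.
out = jax.lax.cond(False, hard_matched, soft, x)
print(out.shape)  # (100, 4)
```

Switching to a Python `if hard:` would only trace the selected branch when hard is a static Python bool, but that would hide the mismatch rather than fix it: the two branches still disagree on the feature width.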
@BirkhoffG BirkhoffG added the bug Something isn't working label Feb 22, 2024
@BirkhoffG
Owner Author

It works when the transformations are passed as the string "ohe" instead:

dm = dm.set_transformations({
    feat: 'ohe' for feat in cat_feats
})
