Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change deprecated jax.tree_util.tree_map to jax.tree.map. Fix argument passed to jax.numpy.finfo call. #95

Merged
merged 1 commit into from
Jul 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mctx/_src/policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,7 @@ def _mask_invalid_actions(logits, invalid_actions):


def _get_logits_from_probs(probs):
tiny = jnp.finfo(probs).tiny
tiny = jnp.finfo(probs.dtype).tiny
return jnp.log(jnp.maximum(probs, tiny))


Expand Down
6 changes: 3 additions & 3 deletions mctx/_src/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def expand(
chex.assert_shape([parent_index, action, next_node_index], (batch_size,))

# Retrieve states for nodes to be evaluated.
embedding = jax.tree_util.tree_map(
embedding = jax.tree.map(
lambda x: x[batch_range, parent_index], tree.embeddings)

# Evaluate and create a new node.
Expand Down Expand Up @@ -335,7 +335,7 @@ def update_tree_node(
tree.node_values, value, node_index),
node_visits=batch_update(
tree.node_visits, new_visit, node_index),
embeddings=jax.tree_util.tree_map(
embeddings=jax.tree.map(
lambda t, s: batch_update(t, s, node_index),
tree.embeddings, embedding))

Expand Down Expand Up @@ -375,7 +375,7 @@ def _zeros(x):
children_visits=jnp.zeros(batch_node_action, dtype=jnp.int32),
children_rewards=jnp.zeros(batch_node_action, dtype=data_dtype),
children_discounts=jnp.zeros(batch_node_action, dtype=data_dtype),
embeddings=jax.tree_util.tree_map(_zeros, root.embedding),
embeddings=jax.tree.map(_zeros, root.embedding),
root_invalid_actions=root_invalid_actions,
extra_data=extra_data)

Expand Down
2 changes: 1 addition & 1 deletion mctx/_src/tests/policies_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ def test_gumbel_muzero_policy(self):

# Testing max_depth.
leaf, max_found_depth = _get_deepest_leaf(
jax.tree_util.tree_map(lambda x: x[0], policy_output.search_tree),
jax.tree.map(lambda x: x[0], policy_output.search_tree),
policy_output.search_tree.ROOT_INDEX)
self.assertEqual(max_depth, max_found_depth)
self.assertEqual(6, policy_output.search_tree.node_visits[0, leaf])
Expand Down
Loading