Skip to content

Commit

Permalink
some Black changes
Browse files Browse the repository at this point in the history
  • Loading branch information
Aidandos committed Apr 3, 2024
1 parent fc8f68d commit fe995f8
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 5 deletions.
9 changes: 7 additions & 2 deletions gymnax/environments/bsuite/memory_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,13 @@ def reset_env(
key_context, key_query = jax.random.split(key)
context = jax.random.bernoulli(key_context, p=0.5, shape=(self.num_bits,))
query = jax.random.randint(key_query, minval=0, maxval=self.num_bits, shape=())
state = EnvState(context=jnp.int32(context), query=jnp.int32(query),
total_perfect=0, total_regret=jnp.float32(0), time=0,)
state = EnvState(
context=jnp.int32(context),
query=jnp.int32(query),
total_perfect=0,
total_regret=jnp.float32(0),
time=0,
)
return self.get_obs(state, params), state

def get_obs(self, state: EnvState, params: EnvParams, key=None) -> chex.Array:
Expand Down
7 changes: 6 additions & 1 deletion gymnax/environments/bsuite/umbrella_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,12 @@ def reset_env(
key_need, key_has, key_distractor = jax.random.split(key, 3)
need_umbrella = jnp.int32(jax.random.bernoulli(key_need, p=0.5, shape=()))
has_umbrella = jnp.int32(jax.random.bernoulli(key_has, p=0.5, shape=()))
state = EnvState(need_umbrella=need_umbrella, has_umbrella=has_umbrella, total_regret=0, time=0)
state = EnvState(
need_umbrella=need_umbrella,
has_umbrella=has_umbrella,
total_regret=0,
time=0,
)
return self.get_obs(state=state, key=key_distractor, params=params), state

def get_obs(
Expand Down
4 changes: 2 additions & 2 deletions gymnax/wrappers/brax.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def action_size(self) -> int:
a_space = self.env.action_space(self.env.default_params)
example_a = a_space.sample(jax.random.PRNGKey(0))
return len(jax.tree_util.tree_flatten(example_a)[0])

def observation_size(self) -> int:
"""DEFAULT size of observation vector expected by step."""
o_space = self.env.observation_space(self.env.default_params)
Expand All @@ -84,4 +84,4 @@ def observation_size(self) -> int:

def backend(self) -> str:
"""Return backend of the environment."""
return "jax"
return "jax"

0 comments on commit fe995f8

Please sign in to comment.