diff --git a/README.md b/README.md index 81306968c..ab6472ee4 100644 --- a/README.md +++ b/README.md @@ -172,7 +172,7 @@ state, timestep = jax.jit(env.reset)(key) env.render(state) # Interact with the (jit-able) environment -action = env.action_spec().generate_value() # Action selection (dummy value here) +action = env.action_spec.generate_value() # Action selection (dummy value here) state, timestep = jax.jit(env.step)(state, action) # Take a step and observe the next state and time step ``` diff --git a/docs/guides/advanced_usage.md b/docs/guides/advanced_usage.md index 19c29c358..0b1eed97a 100644 --- a/docs/guides/advanced_usage.md +++ b/docs/guides/advanced_usage.md @@ -16,7 +16,7 @@ env = AutoResetWrapper(env) # Automatically reset the environment when an ep batch_size = 7 rollout_length = 5 -num_actions = env.action_spec().num_values +num_actions = env.action_spec.num_values random_key = jax.random.PRNGKey(0) key1, key2 = jax.random.split(random_key) diff --git a/docs/guides/wrappers.md b/docs/guides/wrappers.md index 838131480..4623a1588 100644 --- a/docs/guides/wrappers.md +++ b/docs/guides/wrappers.md @@ -13,7 +13,7 @@ env = jumanji.make("Snake-6x6-v0") dm_env = jumanji.wrappers.JumanjiToDMEnvWrapper(env) timestep = dm_env.reset() -action = dm_env.action_spec().generate_value() +action = dm_env.action_spec.generate_value() next_timestep = dm_env.step(action) ... ``` @@ -52,7 +52,7 @@ key = jax.random.PRNGKey(0) state, timestep = env.reset(key) print("New episode") for i in range(100): - action = env.action_spec().generate_value() # Returns jnp.array(0) when using Snake. + action = env.action_spec.generate_value() # Returns jnp.array(0) when using Snake. 
state, timestep = env.step(state, action) if timestep.first(): print("New episode") diff --git a/jumanji/env.py b/jumanji/env.py index d3ddac6bd..48035a992 100644 --- a/jumanji/env.py +++ b/jumanji/env.py @@ -17,13 +17,14 @@ from __future__ import annotations import abc +from functools import cached_property from typing import Any, Generic, Tuple, TypeVar import chex from typing_extensions import Protocol from jumanji import specs -from jumanji.types import TimeStep +from jumanji.types import Observation, TimeStep class StateProtocol(Protocol): @@ -33,9 +34,10 @@ class StateProtocol(Protocol): State = TypeVar("State", bound="StateProtocol") +ActionSpec = TypeVar("ActionSpec", bound=specs.Array) -class Environment(abc.ABC, Generic[State]): +class Environment(abc.ABC, Generic[State, ActionSpec, Observation]): """Environment written in Jax that differs from the gym API to make the step and reset functions jittable. The state contains all the dynamics and data needed to step the environment, no computation stored in attributes of self. @@ -45,8 +47,15 @@ class Environment(abc.ABC, Generic[State]): def __repr__(self) -> str: return "Environment." + def __init__(self) -> None: + """Initialize environment.""" + self.observation_spec + self.action_spec + self.reward_spec + self.discount_spec + @abc.abstractmethod - def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep]: + def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]: """Resets the environment to an initial state. Args: @@ -58,7 +67,9 @@ def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep]: """ @abc.abstractmethod - def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep]: + def step( + self, state: State, action: chex.Array + ) -> Tuple[State, TimeStep[Observation]]: """Run one timestep of the environment's dynamics. 
Args: @@ -71,33 +82,35 @@ def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep]: """ @abc.abstractmethod - def observation_spec(self) -> specs.Spec: + @cached_property + def observation_spec(self) -> specs.Spec[Observation]: """Returns the observation spec. Returns: - observation_spec: a NestedSpec tree of spec. + observation_spec: a potentially nested `Spec` structure representing the observation. """ @abc.abstractmethod - def action_spec(self) -> specs.Spec: + @cached_property + def action_spec(self) -> ActionSpec: """Returns the action spec. Returns: - action_spec: a NestedSpec tree of spec. + action_spec: a potentially nested `Spec` structure representing the action. """ + @cached_property def reward_spec(self) -> specs.Array: - """Describes the reward returned by the environment. By default, this is assumed to be a - single float. + """Returns the reward spec. By default, this is assumed to be a single float. Returns: reward_spec: a `specs.Array` spec. """ return specs.Array(shape=(), dtype=float, name="reward") + @cached_property def discount_spec(self) -> specs.BoundedArray: - """Describes the discount returned by the environment. By default, this is assumed to be a - single float between 0 and 1. + """Returns the discount spec. By default, this is assumed to be a single float between 0 and 1. Returns: discount_spec: a `specs.BoundedArray` spec. @@ -107,7 +120,7 @@ def discount_spec(self) -> specs.BoundedArray: ) @property - def unwrapped(self) -> Environment: + def unwrapped(self) -> Environment[State, ActionSpec, Observation]: return self def render(self, state: State) -> Any: diff --git a/jumanji/environments/logic/game_2048/env.py b/jumanji/environments/logic/game_2048/env.py index 1f0a91a6f..45d189d2e 100644 --- a/jumanji/environments/logic/game_2048/env.py +++ b/jumanji/environments/logic/game_2048/env.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from functools import cached_property from typing import Optional, Sequence, Tuple import chex @@ -29,7 +30,7 @@ from jumanji.viewer import Viewer -class Game2048(Environment[State]): +class Game2048(Environment[State, specs.DiscreteArray, Observation]): """Environment for the game 2048. The game consists of a board of size board_size x board_size (4x4 by default) in which the player can take actions to move the tiles on the board up, down, left, or right. The goal of the game is to combine tiles with the same number to create a tile @@ -69,7 +70,7 @@ class Game2048(Environment[State]): key = jax.random.PRNGKey(0) state, timestep = jax.jit(env.reset)(key) env.render(state) - action = env.action_spec().generate_value() + action = env.action_spec.generate_value() state, timestep = jax.jit(env.step)(state, action) env.render(state) ``` @@ -85,6 +86,7 @@ def __init__( viewer: `Viewer` used for rendering. Defaults to `Game2048Viewer`. """ self.board_size = board_size + super().__init__() # Create viewer used for rendering self._viewer = viewer or Game2048Viewer("2048", board_size) @@ -97,6 +99,7 @@ def __repr__(self) -> str: """ return f"2048 Game(board_size={self.board_size})" + @cached_property def observation_spec(self) -> specs.Spec[Observation]: """Specifications of the observation of the `Game2048` environment. @@ -122,6 +125,7 @@ def observation_spec(self) -> specs.Spec[Observation]: ), ) + @cached_property def action_spec(self) -> specs.DiscreteArray: """Returns the action spec. 
diff --git a/jumanji/environments/logic/game_2048/env_test.py b/jumanji/environments/logic/game_2048/env_test.py index 7985e0278..125fba401 100644 --- a/jumanji/environments/logic/game_2048/env_test.py +++ b/jumanji/environments/logic/game_2048/env_test.py @@ -19,7 +19,10 @@ from jumanji.environments.logic.game_2048.env import Game2048 from jumanji.environments.logic.game_2048.types import Board, State -from jumanji.testing.env_not_smoke import check_env_does_not_smoke +from jumanji.testing.env_not_smoke import ( + check_env_does_not_smoke, + check_env_specs_does_not_smoke, +) from jumanji.testing.pytrees import assert_is_jax_array_tree from jumanji.types import TimeStep @@ -154,3 +157,8 @@ def test_game_2048__get_action_mask(game_2048: Game2048, board: Board) -> None: def test_game_2048__does_not_smoke(game_2048: Game2048) -> None: """Test that we can run an episode without any errors.""" check_env_does_not_smoke(game_2048) + + +def test_game_2048__specs_does_not_smoke(game_2048: Game2048) -> None: + """Test that we access specs without any errors.""" + check_env_specs_does_not_smoke(game_2048) diff --git a/jumanji/environments/logic/graph_coloring/env.py b/jumanji/environments/logic/graph_coloring/env.py index f20f05a88..b5e65a3e5 100644 --- a/jumanji/environments/logic/graph_coloring/env.py +++ b/jumanji/environments/logic/graph_coloring/env.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from functools import cached_property from typing import Optional, Sequence, Tuple import chex @@ -33,7 +34,7 @@ from jumanji.viewer import Viewer -class GraphColoring(Environment[State]): +class GraphColoring(Environment[State, specs.DiscreteArray, Observation]): """Environment for the GraphColoring problem. 
The problem is a combinatorial optimization task where the goal is to assign a color to each vertex of a graph @@ -76,7 +77,7 @@ class GraphColoring(Environment[State]): key = jax.random.PRNGKey(0) state, timestep = jax.jit(env.reset)(key) env.render(state) - action = env.action_spec().generate_value() + action = env.action_spec.generate_value() state, timestep = jax.jit(env.step)(state, action) env.render(state) ``` @@ -100,6 +101,7 @@ def __init__( num_nodes=20, edge_probability=0.8 ) self.num_nodes = self.generator.num_nodes + super().__init__() # Create viewer used for rendering self._env_viewer = viewer or GraphColoringViewer(name="GraphColoring") @@ -206,6 +208,7 @@ def step( ) return next_state, timestep + @cached_property def observation_spec(self) -> specs.Spec[Observation]: """Returns the observation spec. @@ -253,6 +256,7 @@ def observation_spec(self) -> specs.Spec[Observation]: ), ) + @cached_property def action_spec(self) -> specs.DiscreteArray: """Specification of the action for the `GraphColoring` environment. 
diff --git a/jumanji/environments/logic/graph_coloring/env_test.py b/jumanji/environments/logic/graph_coloring/env_test.py index d0418da77..f7b618b1d 100644 --- a/jumanji/environments/logic/graph_coloring/env_test.py +++ b/jumanji/environments/logic/graph_coloring/env_test.py @@ -18,7 +18,10 @@ from jumanji.environments.logic.graph_coloring import GraphColoring from jumanji.environments.logic.graph_coloring.types import State -from jumanji.testing.env_not_smoke import check_env_does_not_smoke +from jumanji.testing.env_not_smoke import ( + check_env_does_not_smoke, + check_env_specs_does_not_smoke, +) from jumanji.testing.pytrees import assert_is_jax_array_tree from jumanji.types import TimeStep @@ -90,3 +93,8 @@ def test_graph_coloring_get_action_mask(graph_coloring: GraphColoring) -> None: def test_graph_coloring_does_not_smoke(graph_coloring: GraphColoring) -> None: """Test that we can run an episode without any errors.""" check_env_does_not_smoke(graph_coloring) + + +def test_graph_coloring_specs_does_not_smoke(graph_coloring: GraphColoring) -> None: + """Test that we can access specs without any errors.""" + check_env_specs_does_not_smoke(graph_coloring) diff --git a/jumanji/environments/logic/minesweeper/env.py b/jumanji/environments/logic/minesweeper/env.py index 641ee48f9..1e9d8d4f1 100644 --- a/jumanji/environments/logic/minesweeper/env.py +++ b/jumanji/environments/logic/minesweeper/env.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from functools import cached_property from typing import Optional, Sequence, Tuple import chex @@ -36,7 +37,7 @@ from jumanji.viewer import Viewer -class Minesweeper(Environment[State]): +class Minesweeper(Environment[State, specs.MultiDiscreteArray, Observation]): """A JAX implementation of the minesweeper game. 
- observation: `Observation` @@ -81,7 +82,7 @@ class Minesweeper(Environment[State]): key = jax.random.PRNGKey(0) state, timestep = jax.jit(env.reset)(key) env.render(state) - action = env.action_spec().generate_value() + action = env.action_spec.generate_value() state, timestep = jax.jit(env.step)(state, action) env.render(state) ``` @@ -127,6 +128,7 @@ def __init__( self.num_rows = self.generator.num_rows self.num_cols = self.generator.num_cols self.num_mines = self.generator.num_mines + super().__init__() self._viewer = viewer or MinesweeperViewer( num_rows=self.num_rows, num_cols=self.num_cols ) @@ -182,6 +184,7 @@ def step( ) return next_state, next_timestep + @cached_property def observation_spec(self) -> specs.Spec[Observation]: """Specifications of the observation of the `Minesweeper` environment. @@ -229,6 +232,7 @@ def observation_spec(self) -> specs.Spec[Observation]: step_count=step_count, ) + @cached_property def action_spec(self) -> specs.MultiDiscreteArray: """Returns the action spec. An action consists of the height and width of the square to be explored. 
diff --git a/jumanji/environments/logic/minesweeper/env_test.py b/jumanji/environments/logic/minesweeper/env_test.py index 197675f0e..7ae532deb 100644 --- a/jumanji/environments/logic/minesweeper/env_test.py +++ b/jumanji/environments/logic/minesweeper/env_test.py @@ -24,7 +24,10 @@ from jumanji.environments.logic.minesweeper.env import Minesweeper from jumanji.environments.logic.minesweeper.types import State -from jumanji.testing.env_not_smoke import check_env_does_not_smoke +from jumanji.testing.env_not_smoke import ( + check_env_does_not_smoke, + check_env_specs_does_not_smoke, +) from jumanji.testing.pytrees import assert_is_jax_array_tree from jumanji.types import StepType, TimeStep @@ -123,7 +126,7 @@ def test_minesweeper__step(minesweeper_env: Minesweeper) -> None: key = jax.random.PRNGKey(0) state, timestep = jax.jit(minesweeper_env.reset)(key) # For this board, this action will be a non-mined square - action = minesweeper_env.action_spec().generate_value() + action = minesweeper_env.action_spec.generate_value() next_state, next_timestep = step_fn(state, action) # Check that the state has changed @@ -154,6 +157,11 @@ def test_minesweeper__does_not_smoke(minesweeper_env: Minesweeper) -> None: check_env_does_not_smoke(env=minesweeper_env) +def test_minesweeper__specs_does_not_smoke(minesweeper_env: Minesweeper) -> None: + """Test that we can access specs without any errors.""" + check_env_specs_does_not_smoke(minesweeper_env) + + def test_minesweeper__render( monkeypatch: pytest.MonkeyPatch, minesweeper_env: Minesweeper ) -> None: @@ -162,7 +170,7 @@ def test_minesweeper__render( state, timestep = jax.jit(minesweeper_env.reset)(jax.random.PRNGKey(0)) minesweeper_env.render(state) minesweeper_env.close() - action = minesweeper_env.action_spec().generate_value() + action = minesweeper_env.action_spec.generate_value() state, timestep = jax.jit(minesweeper_env.step)(state, action) minesweeper_env.render(state) minesweeper_env.close() @@ -171,7 +179,7 @@ def 
test_minesweeper__render( def test_minesweeper__done_invalid_action(minesweeper_env: Minesweeper) -> None: """Test that the strict done signal is sent correctly""" # Note that this action corresponds to not stepping on a mine - action = minesweeper_env.action_spec().generate_value() + action = minesweeper_env.action_spec.generate_value() *_, episode_length = play_and_get_episode_stats( env=minesweeper_env, actions=[action for _ in range(10)], time_limit=10 ) diff --git a/jumanji/environments/logic/rubiks_cube/env.py b/jumanji/environments/logic/rubiks_cube/env.py index 84a2dff44..a4472e0ed 100644 --- a/jumanji/environments/logic/rubiks_cube/env.py +++ b/jumanji/environments/logic/rubiks_cube/env.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from functools import cached_property from typing import Optional, Sequence, Tuple import chex @@ -42,7 +43,7 @@ from jumanji.viewer import Viewer -class RubiksCube(Environment[State]): +class RubiksCube(Environment[State, specs.MultiDiscreteArray, Observation]): """A JAX implementation of the Rubik's Cube with a configurable cube size (by default, 3) and number of scrambles at reset. @@ -75,7 +76,7 @@ class RubiksCube(Environment[State]): key = jax.random.PRNGKey(0) state, timestep = jax.jit(env.reset)(key) env.render(state) - action = env.action_spec().generate_value() + action = env.action_spec.generate_value() state, timestep = jax.jit(env.step)(state, action) env.render(state) ``` @@ -113,6 +114,7 @@ def __init__( cube_size=3, num_scrambles_on_reset=100, ) + super().__init__() self._viewer = viewer or RubiksCubeViewer( sticker_colors=DEFAULT_STICKER_COLORS, cube_size=self.generator.cube_size ) @@ -173,6 +175,7 @@ def step( ) return next_state, next_timestep + @cached_property def observation_spec(self) -> specs.Spec[Observation]: """Specifications of the observation of the `RubiksCube` environment. 
@@ -202,6 +205,7 @@ def observation_spec(self) -> specs.Spec[Observation]: step_count=step_count, ) + @cached_property def action_spec(self) -> specs.MultiDiscreteArray: """Returns the action spec. An action is composed of 3 elements that range in: 6 faces, each with cube_size//2 possible depths, and 3 possible directions. diff --git a/jumanji/environments/logic/rubiks_cube/env_test.py b/jumanji/environments/logic/rubiks_cube/env_test.py index 3cbf7ac55..59d56cfa3 100644 --- a/jumanji/environments/logic/rubiks_cube/env_test.py +++ b/jumanji/environments/logic/rubiks_cube/env_test.py @@ -23,7 +23,10 @@ from jumanji.environments.logic.rubiks_cube.env import RubiksCube from jumanji.environments.logic.rubiks_cube.generator import ScramblingGenerator from jumanji.environments.logic.rubiks_cube.types import State -from jumanji.testing.env_not_smoke import check_env_does_not_smoke +from jumanji.testing.env_not_smoke import ( + check_env_does_not_smoke, + check_env_specs_does_not_smoke, +) from jumanji.testing.pytrees import assert_is_jax_array_tree from jumanji.types import TimeStep @@ -51,7 +54,7 @@ def test_rubiks_cube__step(rubiks_cube: RubiksCube) -> None: step_fn = jax.jit(chex.assert_max_traces(rubiks_cube.step, n=1)) key = jax.random.PRNGKey(0) state, timestep = rubiks_cube.reset(key) - action = rubiks_cube.action_spec().generate_value() + action = rubiks_cube.action_spec.generate_value() next_state, next_timestep = step_fn(state, action) # Check that the state has changed @@ -84,6 +87,16 @@ def test_rubiks_cube__does_not_smoke(cube_size: int) -> None: check_env_does_not_smoke(env) +@pytest.mark.parametrize("cube_size", [3, 4, 5]) +def test_rubiks_cube__specs_does_not_smoke(cube_size: int) -> None: + """Test that we can access specs without any errors.""" + env = RubiksCube( + time_limit=10, + generator=ScramblingGenerator(cube_size=cube_size, num_scrambles_on_reset=5), + ) + check_env_specs_does_not_smoke(env) + + def test_rubiks_cube__render( monkeypatch: 
pytest.MonkeyPatch, rubiks_cube: RubiksCube ) -> None: @@ -92,7 +105,7 @@ def test_rubiks_cube__render( state, timestep = rubiks_cube.reset(jax.random.PRNGKey(0)) rubiks_cube.render(state) rubiks_cube.close() - action = rubiks_cube.action_spec().generate_value() + action = rubiks_cube.action_spec.generate_value() state, timestep = rubiks_cube.step(state, action) rubiks_cube.render(state) rubiks_cube.close() @@ -103,7 +116,7 @@ def test_rubiks_cube__done(time_limit: int) -> None: """Test that the done signal is sent correctly.""" env = RubiksCube(time_limit=time_limit) state, timestep = env.reset(jax.random.PRNGKey(0)) - action = env.action_spec().generate_value() + action = env.action_spec.generate_value() episode_length = 0 step_fn = jax.jit(env.step) while not timestep.last(): diff --git a/jumanji/environments/logic/sliding_tile_puzzle/env.py b/jumanji/environments/logic/sliding_tile_puzzle/env.py index fe6a29f3d..9ab17e6c4 100644 --- a/jumanji/environments/logic/sliding_tile_puzzle/env.py +++ b/jumanji/environments/logic/sliding_tile_puzzle/env.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from functools import cached_property from typing import Dict, Optional, Sequence, Tuple import chex @@ -40,7 +41,7 @@ from jumanji.viewer import Viewer -class SlidingTilePuzzle(Environment[State]): +class SlidingTilePuzzle(Environment[State, specs.DiscreteArray, Observation]): """Environment for the Sliding Tile Puzzle problem. 
The problem is a combinatorial optimization task where the goal is @@ -95,8 +96,8 @@ def __init__( grid_size=5, num_random_moves=200 ) self.reward_fn = reward_fn or DenseRewardFn() - self.time_limit = time_limit + super().__init__() # Create viewer used for rendering self._env_viewer = viewer or SlidingTilePuzzleViewer(name="SlidingTilePuzzle") @@ -205,6 +206,7 @@ def _get_extras(self, state: State) -> Dict[str, chex.Array]: num_correct_tiles = jnp.sum(self.solved_puzzle == state.puzzle) return {"prop_correctly_placed": num_correct_tiles / state.puzzle.size} + @cached_property def observation_spec(self) -> specs.Spec[Observation]: """Returns the observation spec.""" grid_size = self.generator.grid_size @@ -241,6 +243,7 @@ def observation_spec(self) -> specs.Spec[Observation]: ), ) + @cached_property def action_spec(self) -> specs.DiscreteArray: """Returns the action spec.""" # Up, Right, Down, Left diff --git a/jumanji/environments/logic/sliding_tile_puzzle/env_test.py b/jumanji/environments/logic/sliding_tile_puzzle/env_test.py index 31bab5f0e..5ed9ddc75 100644 --- a/jumanji/environments/logic/sliding_tile_puzzle/env_test.py +++ b/jumanji/environments/logic/sliding_tile_puzzle/env_test.py @@ -18,7 +18,10 @@ from jumanji.environments.logic.sliding_tile_puzzle import SlidingTilePuzzle from jumanji.environments.logic.sliding_tile_puzzle.types import State -from jumanji.testing.env_not_smoke import check_env_does_not_smoke +from jumanji.testing.env_not_smoke import ( + check_env_does_not_smoke, + check_env_specs_does_not_smoke, +) from jumanji.testing.pytrees import assert_is_jax_array_tree from jumanji.types import TimeStep @@ -88,6 +91,13 @@ def test_sliding_tile_puzzle_does_not_smoke( check_env_does_not_smoke(sliding_tile_puzzle) +def test_sliding_tile_puzzle_specs_does_not_smoke( + sliding_tile_puzzle: SlidingTilePuzzle, +) -> None: + """Test that we access specs without any errors.""" + check_env_specs_does_not_smoke(sliding_tile_puzzle) + + def 
test_env_one_move_to_solve(sliding_tile_puzzle: SlidingTilePuzzle) -> None: """Test that the environment correctly handles a situation where the puzzle is one move away from being solved. diff --git a/jumanji/environments/logic/sudoku/env.py b/jumanji/environments/logic/sudoku/env.py index 8e8f5ca12..64a91899d 100644 --- a/jumanji/environments/logic/sudoku/env.py +++ b/jumanji/environments/logic/sudoku/env.py @@ -13,6 +13,7 @@ # limitations under the License. import os +from functools import cached_property from typing import Any, Optional, Sequence, Tuple import chex @@ -32,7 +33,7 @@ from jumanji.viewer import Viewer -class Sudoku(Environment[State]): +class Sudoku(Environment[State, specs.MultiDiscreteArray, Observation]): """A JAX implementation of the sudoku game. - observation: `Observation` @@ -66,7 +67,7 @@ class Sudoku(Environment[State]): key = jax.random.PRNGKey(0) state, timestep = jax.jit(env.reset)(key) env.render(state) - action = env.action_spec().generate_value() + action = env.action_spec.generate_value() state, timestep = jax.jit(env.step)(state, action) env.render(state) ``` @@ -78,6 +79,7 @@ def __init__( reward_fn: Optional[RewardFn] = None, viewer: Optional[Viewer[State]] = None, ): + super().__init__() if generator is None: file_path = os.path.dirname(os.path.abspath(__file__)) database_file = DATABASES["mixed"] @@ -129,6 +131,7 @@ def step( return next_state, timestep + @cached_property def observation_spec(self) -> specs.Spec[Observation]: """Returns the observation spec containing the board and action_mask arrays. @@ -158,6 +161,7 @@ def observation_spec(self) -> specs.Spec[Observation]: Observation, "ObservationSpec", board=board, action_mask=action_mask ) + @cached_property def action_spec(self) -> specs.MultiDiscreteArray: """Returns the action spec. An action is composed of 3 integers: the row index, the column index and the value to be placed in the cell. 
diff --git a/jumanji/environments/logic/sudoku/env_test.py b/jumanji/environments/logic/sudoku/env_test.py index 9e55cdc12..85bd9d672 100644 --- a/jumanji/environments/logic/sudoku/env_test.py +++ b/jumanji/environments/logic/sudoku/env_test.py @@ -23,7 +23,10 @@ from jumanji.environments.logic.sudoku.env import Sudoku from jumanji.environments.logic.sudoku.types import State -from jumanji.testing.env_not_smoke import check_env_does_not_smoke +from jumanji.testing.env_not_smoke import ( + check_env_does_not_smoke, + check_env_specs_does_not_smoke, +) from jumanji.testing.pytrees import assert_is_jax_array_tree from jumanji.types import TimeStep @@ -53,7 +56,7 @@ def test_sudoku__step(sudoku_env: Sudoku) -> None: key = jax.random.PRNGKey(0) state, timestep = jax.jit(sudoku_env.reset)(key) - action = sudoku_env.action_spec().generate_value() + action = sudoku_env.action_spec.generate_value() next_state, next_timestep = step_fn(state, action) # Check that the state has changed @@ -75,13 +78,18 @@ def test_sudoku__does_not_smoke(sudoku_env: Sudoku) -> None: check_env_does_not_smoke(env=sudoku_env) +def test_sudoku__specs_does_not_smoke(sudoku_env: Sudoku) -> None: + """Test that we can access specs without any errors.""" + check_env_specs_does_not_smoke(env=sudoku_env) + + def test_sudoku__render(monkeypatch: pytest.MonkeyPatch, sudoku_env: Sudoku) -> None: """Check that the render method builds the figure but does not display it.""" monkeypatch.setattr(plt, "show", lambda fig: None) state, timestep = jax.jit(sudoku_env.reset)(jax.random.PRNGKey(0)) sudoku_env.render(state) sudoku_env.close() - action = sudoku_env.action_spec().generate_value() + action = sudoku_env.action_spec.generate_value() state, timestep = jax.jit(sudoku_env.step)(state, action) sudoku_env.render(state) sudoku_env.close() diff --git a/jumanji/environments/packing/bin_pack/conftest.py b/jumanji/environments/packing/bin_pack/conftest.py index 9325960ba..33ca6c437 100644 --- 
a/jumanji/environments/packing/bin_pack/conftest.py +++ b/jumanji/environments/packing/bin_pack/conftest.py @@ -111,7 +111,7 @@ def bin_pack(dummy_generator: DummyGenerator) -> BinPack: @pytest.fixture def obs_spec(bin_pack: BinPack) -> specs.Spec: - return bin_pack.observation_spec() + return bin_pack.observation_spec @pytest.fixture diff --git a/jumanji/environments/packing/bin_pack/env.py b/jumanji/environments/packing/bin_pack/env.py index c3127c07e..4f410af62 100644 --- a/jumanji/environments/packing/bin_pack/env.py +++ b/jumanji/environments/packing/bin_pack/env.py @@ -13,6 +13,7 @@ # limitations under the License. import itertools +from functools import cached_property from typing import Dict, Optional, Sequence, Tuple import chex @@ -43,7 +44,7 @@ from jumanji.viewer import Viewer -class BinPack(Environment[State]): +class BinPack(Environment[State, specs.MultiDiscreteArray, Observation]): """Problem of 3D bin packing, where a set of items have to be placed in a 3D container with the goal of maximizing its volume utilization. This environment only supports 1 bin, meaning it is equivalent to the 3D-knapsack problem. We use the Empty Maximal Space (EMS) formulation of this @@ -106,7 +107,7 @@ class BinPack(Environment[State]): key = jax.random.PRNGKey(0) state, timestep = jax.jit(env.reset)(key) env.render(state) - action = env.action_spec().generate_value() + action = env.action_spec.generate_value() state, timestep = jax.jit(env.step)(state, action) env.render(state) ``` @@ -154,6 +155,7 @@ def __init__( self.obs_num_ems = obs_num_ems self.reward_fn = reward_fn or DenseReward() self.normalize_dimensions = normalize_dimensions + super().__init__() self._viewer = viewer or BinPackViewer("BinPack", render_mode="human") self.debug = debug @@ -171,6 +173,7 @@ def __repr__(self) -> str: ] ) + @cached_property def observation_spec(self) -> specs.Spec[Observation]: """Specifications of the observation of the `BinPack` environment. 
@@ -248,6 +251,7 @@ def observation_spec(self) -> specs.Spec[Observation]: action_mask=action_mask, ) + @cached_property def action_spec(self) -> specs.MultiDiscreteArray: """Specifications of the action expected by the `BinPack` environment. @@ -610,13 +614,11 @@ def _get_intersections_dict( _, direction_intersections_mask, ) in zip(intersections_ems_dict.items(), intersections_mask_dict.items()): - # Inner loop iterates through alternative directions. for (alt_direction, alt_direction_intersections_ems), ( _, alt_direction_intersections_mask, ) in zip(intersections_ems_dict.items(), intersections_mask_dict.items()): - # The current direction EMS is included in the alternative EMS. directions_included_in_alt_directions = jax.vmap( jax.vmap(Space.is_included, in_axes=(None, 0)), in_axes=(0, None) diff --git a/jumanji/environments/packing/bin_pack/env_test.py b/jumanji/environments/packing/bin_pack/env_test.py index 921ce7025..967a56538 100644 --- a/jumanji/environments/packing/bin_pack/env_test.py +++ b/jumanji/environments/packing/bin_pack/env_test.py @@ -33,14 +33,18 @@ item_from_space, location_from_space, ) -from jumanji.testing.env_not_smoke import SelectActionFn, check_env_does_not_smoke +from jumanji.testing.env_not_smoke import ( + SelectActionFn, + check_env_does_not_smoke, + check_env_specs_does_not_smoke, +) from jumanji.testing.pytrees import assert_is_jax_array_tree from jumanji.types import TimeStep @pytest.fixture def bin_pack_random_select_action(bin_pack: BinPack) -> SelectActionFn: - num_ems, num_items = np.asarray(bin_pack.action_spec().num_values) + num_ems, num_items = np.asarray(bin_pack.action_spec.num_values) def select_action(key: chex.PRNGKey, observation: Observation) -> chex.Array: """Randomly sample valid actions, as determined by `observation.action_mask`.""" @@ -148,7 +152,7 @@ def test_bin_pack_step__jit(bin_pack: BinPack) -> None: key = jax.random.PRNGKey(0) state, timestep = bin_pack.reset(key) - action = 
bin_pack.action_spec().generate_value() + action = bin_pack.action_spec.generate_value() _ = step_fn(state, action) # Call again to check it does not compile twice. state, timestep = step_fn(state, action) @@ -168,6 +172,11 @@ def test_bin_pack__does_not_smoke( check_env_does_not_smoke(bin_pack, bin_pack_random_select_action) +def test_bin_pack__specs_does_not_smoke(bin_pack: BinPack) -> None: + """Test that we can access specs without any errors.""" + check_env_specs_does_not_smoke(bin_pack) + + def test_bin_pack__pack_all_items_dummy_instance( bin_pack: BinPack, bin_pack_random_select_action: SelectActionFn ) -> None: diff --git a/jumanji/environments/packing/flat_pack/env.py b/jumanji/environments/packing/flat_pack/env.py index 573486a73..e1125e98b 100644 --- a/jumanji/environments/packing/flat_pack/env.py +++ b/jumanji/environments/packing/flat_pack/env.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from functools import cached_property from typing import Optional, Sequence, Tuple import chex @@ -34,7 +35,7 @@ from jumanji.viewer import Viewer -class FlatPack(Environment[State]): +class FlatPack(Environment[State, specs.MultiDiscreteArray, Observation]): """The FlatPack environment with a configurable number of row and column blocks. 
Here the goal of an agent is to completely fill an empty grid by placing all @@ -91,7 +92,7 @@ class FlatPack(Environment[State]): key = jax.random.PRNGKey(0) state, timestep = jax.jit(env.reset)(key) env.render(state) - action = env.action_spec().generate_value() + action = env.action_spec.generate_value() state, timestep = jax.jit(env.step)(state, action) env.render(state) ``` @@ -129,6 +130,7 @@ def __init__( self.viewer = viewer or FlatPackViewer( "FlatPack", self.num_blocks, render_mode="human" ) + super().__init__() def __repr__(self) -> str: return ( @@ -141,7 +143,6 @@ def reset( self, key: chex.PRNGKey, ) -> Tuple[State, TimeStep[Observation]]: - """Resets the environment. Args: @@ -259,6 +260,7 @@ def close(self) -> None: self.viewer.close() + @cached_property def observation_spec(self) -> specs.Spec[Observation]: """Returns the observation spec of the environment. @@ -307,6 +309,7 @@ def observation_spec(self) -> specs.Spec[Observation]: action_mask=action_mask, ) + @cached_property def action_spec(self) -> specs.MultiDiscreteArray: """Specifications of the action expected by the `FlatPack` environment. 
diff --git a/jumanji/environments/packing/flat_pack/env_test.py b/jumanji/environments/packing/flat_pack/env_test.py index 923306349..36b82f77d 100644 --- a/jumanji/environments/packing/flat_pack/env_test.py +++ b/jumanji/environments/packing/flat_pack/env_test.py @@ -28,7 +28,10 @@ CellDenseReward, ) from jumanji.environments.packing.flat_pack.types import State -from jumanji.testing.env_not_smoke import check_env_does_not_smoke +from jumanji.testing.env_not_smoke import ( + check_env_does_not_smoke, + check_env_specs_does_not_smoke, +) from jumanji.testing.pytrees import assert_is_jax_array_tree from jumanji.types import StepType, TimeStep @@ -182,6 +185,11 @@ def test_flat_pack__does_not_smoke(flat_pack: FlatPack) -> None: check_env_does_not_smoke(flat_pack) +def test_flat_pack__specs_does_not_smoke(flat_pack: FlatPack) -> None: + """Test that we can access specs without any errors.""" + check_env_specs_does_not_smoke(flat_pack) + + def test_flat_pack__is_done(flat_pack: FlatPack, key: chex.PRNGKey) -> None: """Test that the is_done method works as expected.""" diff --git a/jumanji/environments/packing/job_shop/env.py b/jumanji/environments/packing/job_shop/env.py index 0e2421524..ec1e22e79 100644 --- a/jumanji/environments/packing/job_shop/env.py +++ b/jumanji/environments/packing/job_shop/env.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from functools import cached_property from typing import Any, Optional, Sequence, Tuple import chex @@ -29,7 +30,7 @@ from jumanji.viewer import Viewer -class JobShop(Environment[State]): +class JobShop(Environment[State, specs.MultiDiscreteArray, Observation]): """The Job Shop Scheduling Problem, as described in [1], is one of the best known combinatorial optimization problems. We are given `num_jobs` jobs, each consisting of at most `max_num_ops` ops, which need to be processed on `num_machines` machines. 
@@ -83,7 +84,7 @@ class JobShop(Environment[State]): key = jax.random.PRNGKey(0) state, timestep = jax.jit(env.reset)(key) env.render(state) - action = env.action_spec().generate_value() + action = env.action_spec.generate_value() state, timestep = jax.jit(env.step)(state, action) env.render(state) ``` @@ -113,6 +114,7 @@ def __init__( self.num_machines = self.generator.num_machines self.max_num_ops = self.generator.max_num_ops self.max_op_duration = self.generator.max_op_duration + super().__init__() # Define the "job id" of a no-op action as the number of jobs self.no_op_idx = self.num_jobs @@ -356,6 +358,7 @@ def _update_machines( return updated_machines_job_ids, updated_machines_remaining_times + @cached_property def observation_spec(self) -> specs.Spec[Observation]: """Specifications of the observation of the `JobShop` environment. @@ -421,6 +424,7 @@ def observation_spec(self) -> specs.Spec[Observation]: action_mask=action_mask, ) + @cached_property def action_spec(self) -> specs.MultiDiscreteArray: """Specifications of the action in the `JobShop` environment. 
The action gives each machine a job id ranging from 0, 1, ..., num_jobs where the last value corresponds diff --git a/jumanji/environments/packing/job_shop/env_test.py b/jumanji/environments/packing/job_shop/env_test.py index 737f42bea..964042dac 100644 --- a/jumanji/environments/packing/job_shop/env_test.py +++ b/jumanji/environments/packing/job_shop/env_test.py @@ -19,7 +19,10 @@ from jumanji.environments.packing.job_shop.env import JobShop from jumanji.environments.packing.job_shop.generator import ToyGenerator from jumanji.environments.packing.job_shop.types import State -from jumanji.testing.env_not_smoke import check_env_does_not_smoke +from jumanji.testing.env_not_smoke import ( + check_env_does_not_smoke, + check_env_specs_does_not_smoke, +) from jumanji.types import TimeStep @@ -816,3 +819,7 @@ def test_job_shop__toy_generator_reward(self) -> None: def test_job_shop_env__does_not_smoke(self, job_shop_env: JobShop) -> None: """Test that we can run an episode without any errors.""" check_env_does_not_smoke(job_shop_env) + + def test_job_shop_env__specs_does_not_smoke(self, job_shop_env: JobShop) -> None: + """Test that we can access specs without any errors.""" + check_env_specs_does_not_smoke(job_shop_env) diff --git a/jumanji/environments/packing/knapsack/env.py b/jumanji/environments/packing/knapsack/env.py index 573b83b32..3d544132a 100644 --- a/jumanji/environments/packing/knapsack/env.py +++ b/jumanji/environments/packing/knapsack/env.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from functools import cached_property from typing import Optional, Sequence, Tuple import chex @@ -30,7 +31,7 @@ from jumanji.viewer import Viewer -class Knapsack(Environment[State]): +class Knapsack(Environment[State, specs.DiscreteArray, Observation]): """Knapsack environment as described in [1]. 
- observation: Observation @@ -76,7 +77,7 @@ class Knapsack(Environment[State]): key = jax.random.PRNGKey(0) state, timestep = jax.jit(env.reset)(key) env.render(state) - action = env.action_spec().generate_value() + action = env.action_spec.generate_value() state, timestep = jax.jit(env.step)(state, action) env.render(state) ``` @@ -107,6 +108,7 @@ def __init__( total_budget=12.5, ) self.num_items = self.generator.num_items + super().__init__() self.total_budget = self.generator.total_budget self.reward_fn = reward_fn or DenseReward() self._viewer = viewer or KnapsackViewer( @@ -176,6 +178,7 @@ def step( return next_state, timestep + @cached_property def observation_spec(self) -> specs.Spec[Observation]: """Returns the observation spec. @@ -223,6 +226,7 @@ def observation_spec(self) -> specs.Spec[Observation]: action_mask=action_mask, ) + @cached_property def action_spec(self) -> specs.DiscreteArray: """Returns the action spec. diff --git a/jumanji/environments/packing/knapsack/env_test.py b/jumanji/environments/packing/knapsack/env_test.py index 139c9b3fb..32ea10cf7 100644 --- a/jumanji/environments/packing/knapsack/env_test.py +++ b/jumanji/environments/packing/knapsack/env_test.py @@ -17,7 +17,10 @@ from jax import numpy as jnp from jumanji.environments.packing.knapsack import Knapsack, State -from jumanji.testing.env_not_smoke import check_env_does_not_smoke +from jumanji.testing.env_not_smoke import ( + check_env_does_not_smoke, + check_env_specs_does_not_smoke, +) from jumanji.testing.pytrees import assert_is_jax_array_tree from jumanji.types import StepType, TimeStep @@ -75,6 +78,12 @@ def test_knapsack_sparse__does_not_smoke( """Test that we can run an episode without any errors.""" check_env_does_not_smoke(knapsack_sparse_reward) + def test_knapsack_sparse__specs_does_not_smoke( + self, knapsack_sparse_reward: Knapsack + ) -> None: + """Test that we can access specs without any errors.""" + check_env_specs_does_not_smoke(knapsack_sparse_reward) + def 
test_knapsack_sparse__trajectory_action( self, knapsack_sparse_reward: Knapsack ) -> None: diff --git a/jumanji/environments/packing/tetris/env.py b/jumanji/environments/packing/tetris/env.py index 8223225f1..995cb1fd6 100644 --- a/jumanji/environments/packing/tetris/env.py +++ b/jumanji/environments/packing/tetris/env.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from functools import cached_property from typing import Optional, Sequence, Tuple import chex @@ -34,7 +35,7 @@ from jumanji.viewer import Viewer -class Tetris(Environment[State]): +class Tetris(Environment[State, specs.MultiDiscreteArray, Observation]): """RL Environment for the game of Tetris. The environment has a grid where the player can place tetrominoes. The environment has the following characteristics: @@ -69,7 +70,7 @@ class Tetris(Environment[State]): key = jax.random.PRNGKey(0) state, timestep = jax.jit(env.reset)(key) env.render(state) - action = env.action_spec().generate_value() + action = env.action_spec.generate_value() state, timestep = jax.jit(env.step)(state, action) env.render(state) ``` @@ -106,6 +107,7 @@ def __init__( self.TETROMINOES_LIST = jnp.array(TETROMINOES_LIST, jnp.int32) self.reward_list = jnp.array(REWARD_LIST, float) self.time_limit = time_limit + super().__init__() self._viewer = viewer or TetrisViewer( num_rows=self.num_rows, num_cols=self.num_cols, @@ -246,6 +248,7 @@ def render(self, state: State) -> Optional[NDArray]: """ return self._viewer.render(state) + @cached_property def observation_spec(self) -> specs.Spec[Observation]: """Specifications of the observation of the `Tetris` environment. @@ -285,6 +288,7 @@ def observation_spec(self) -> specs.Spec[Observation]: ), ) + @cached_property def action_spec(self) -> specs.MultiDiscreteArray: """Returns the action spec. 
An action consists of two pieces of information: the amount of rotation (number of 90-degree rotations) and the x-position of diff --git a/jumanji/environments/packing/tetris/env_test.py b/jumanji/environments/packing/tetris/env_test.py index a46017d22..d68a7026f 100644 --- a/jumanji/environments/packing/tetris/env_test.py +++ b/jumanji/environments/packing/tetris/env_test.py @@ -19,6 +19,10 @@ from jumanji.environments.packing.tetris.env import Tetris from jumanji.environments.packing.tetris.types import State +from jumanji.testing.env_not_smoke import ( + check_env_does_not_smoke, + check_env_specs_does_not_smoke, +) from jumanji.testing.pytrees import assert_is_jax_array_tree from jumanji.types import TimeStep @@ -115,3 +119,13 @@ def test_calculate_action_mask(tetris_env: Tetris, grid: chex.Array) -> None: ] ) assert (action_mask == expected_action_mask).all() + + +def test_tetris__does_not_smoke(tetris_env: Tetris) -> None: + """Test that we can run an episode without any errors.""" + check_env_does_not_smoke(tetris_env) + + +def test_tetris__specs_does_not_smoke(tetris_env: Tetris) -> None: + """Test that we can access specs without any errors.""" + check_env_specs_does_not_smoke(tetris_env) diff --git a/jumanji/environments/routing/cleaner/env.py b/jumanji/environments/routing/cleaner/env.py index 6377bcb8f..7dd8c6423 100644 --- a/jumanji/environments/routing/cleaner/env.py +++ b/jumanji/environments/routing/cleaner/env.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from functools import cached_property from typing import Any, Dict, Optional, Sequence, Tuple import chex @@ -30,7 +31,7 @@ from jumanji.viewer import Viewer -class Cleaner(Environment[State]): +class Cleaner(Environment[State, specs.MultiDiscreteArray, Observation]): """A JAX implementation of the 'Cleaner' game where multiple agents have to clean all tiles of a maze. 
@@ -74,7 +75,7 @@ class Cleaner(Environment[State]): key = jax.random.PRNGKey(0) state, timestep = jax.jit(env.reset)(key) env.render(state) - action = env.action_spec().generate_value() + action = env.action_spec.generate_value() state, timestep = jax.jit(env.step)(state, action) env.render(state) ``` @@ -107,6 +108,7 @@ def __init__( self.num_cols = self.generator.num_cols self.grid_shape = (self.num_rows, self.num_cols) self.time_limit = time_limit or (self.num_rows * self.num_cols) + super().__init__() self.penalty_per_timestep = penalty_per_timestep # Create viewer used for rendering @@ -122,6 +124,7 @@ def __repr__(self) -> str: ")" ) + @cached_property def observation_spec(self) -> specs.Spec[Observation]: """Specification of the observation of the `Cleaner` environment. @@ -152,6 +155,7 @@ def observation_spec(self) -> specs.Spec[Observation]: step_count=step_count, ) + @cached_property def action_spec(self) -> specs.MultiDiscreteArray: """Specification of the action for the `Cleaner` environment. 
diff --git a/jumanji/environments/routing/cleaner/env_test.py b/jumanji/environments/routing/cleaner/env_test.py index 7a8eb3192..f386a5dad 100644 --- a/jumanji/environments/routing/cleaner/env_test.py +++ b/jumanji/environments/routing/cleaner/env_test.py @@ -21,7 +21,10 @@ from jumanji.environments.routing.cleaner.env import Cleaner from jumanji.environments.routing.cleaner.generator import Generator from jumanji.environments.routing.cleaner.types import Observation, State -from jumanji.testing.env_not_smoke import check_env_does_not_smoke +from jumanji.testing.env_not_smoke import ( + check_env_does_not_smoke, + check_env_specs_does_not_smoke, +) from jumanji.testing.pytrees import assert_is_jax_array_tree from jumanji.types import StepType, TimeStep @@ -191,6 +194,10 @@ def select_action( check_env_does_not_smoke(cleaner, select_actions) + def test_cleaner__specs_does_not_smoke(self, cleaner: Cleaner) -> None: + """Test that we can access specs without any errors.""" + check_env_specs_does_not_smoke(cleaner) + def test_cleaner__compute_extras(self, cleaner: Cleaner, key: chex.PRNGKey) -> None: state, _ = cleaner.reset(key) diff --git a/jumanji/environments/routing/connector/env.py b/jumanji/environments/routing/connector/env.py index e76ba9da7..fc78dd891 100644 --- a/jumanji/environments/routing/connector/env.py +++ b/jumanji/environments/routing/connector/env.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from functools import cached_property from typing import Dict, Optional, Sequence, Tuple import chex @@ -46,7 +47,7 @@ from jumanji.viewer import Viewer -class Connector(Environment[State]): +class Connector(Environment[State, specs.MultiDiscreteArray, Observation]): """The `Connector` environment is a gridworld problem where multiple pairs of points (sets) must be connected without overlapping the paths taken by any other set. 
This is achieved by allowing certain points to move to an adjacent cell at each step. However, each time a @@ -88,7 +89,7 @@ class Connector(Environment[State]): key = jax.random.PRNGKey(0) state, timestep = jax.jit(env.reset)(key) env.render(state) - action = env.action_spec().generate_value() + action = env.action_spec.generate_value() state, timestep = jax.jit(env.step)(state, action) env.render(state) ``` @@ -118,6 +119,7 @@ def __init__( self.time_limit = time_limit self.num_agents = self._generator.num_agents self.grid_size = self._generator.grid_size + super().__init__() self._agent_ids = jnp.arange(self.num_agents) self._viewer = viewer or ConnectorViewer( "Connector", self.num_agents, render_mode="human" @@ -318,6 +320,7 @@ def close(self) -> None: """ self._viewer.close() + @cached_property def observation_spec(self) -> specs.Spec[Observation]: """Specifications of the observation of the `Connector` environment. @@ -356,6 +359,7 @@ def observation_spec(self) -> specs.Spec[Observation]: step_count=step_count, ) + @cached_property def action_spec(self) -> specs.MultiDiscreteArray: """Returns the action spec for the Connector environment. 
diff --git a/jumanji/environments/routing/connector/env_test.py b/jumanji/environments/routing/connector/env_test.py index c0a649fff..12a6c1d94 100644 --- a/jumanji/environments/routing/connector/env_test.py +++ b/jumanji/environments/routing/connector/env_test.py @@ -24,7 +24,10 @@ from jumanji.environments.routing.connector.env import Connector from jumanji.environments.routing.connector.types import Agent, State from jumanji.environments.routing.connector.utils import get_position, get_target -from jumanji.testing.env_not_smoke import check_env_does_not_smoke +from jumanji.testing.env_not_smoke import ( + check_env_does_not_smoke, + check_env_specs_does_not_smoke, +) from jumanji.tree_utils import tree_slice from jumanji.types import StepType, TimeStep @@ -230,6 +233,11 @@ def test_connector__does_not_smoke(connector: Connector) -> None: check_env_does_not_smoke(connector) +def test_connector__specs_does_not_smoke(connector: Connector) -> None: + """Test that we can access specs without any errors.""" + check_env_specs_does_not_smoke(connector) + + def test_connector__get_action_mask(state: State, connector: Connector) -> None: """Validates the action masking.""" action_masks = jax.vmap(connector._get_action_mask, (0, None))( diff --git a/jumanji/environments/routing/cvrp/env.py b/jumanji/environments/routing/cvrp/env.py index ca6f76920..921dc646e 100644 --- a/jumanji/environments/routing/cvrp/env.py +++ b/jumanji/environments/routing/cvrp/env.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from functools import cached_property from typing import Optional, Sequence, Tuple import chex @@ -30,7 +31,7 @@ from jumanji.viewer import Viewer -class CVRP(Environment[State]): +class CVRP(Environment[State, specs.DiscreteArray, Observation]): """Capacitated Vehicle Routing Problem (CVRP) environment as described in [1]. 
- observation: `Observation` @@ -89,7 +90,7 @@ class CVRP(Environment[State]): key = jax.random.PRNGKey(0) state, timestep = jax.jit(env.reset)(key) env.render(state) - action = env.action_spec().generate_value() + action = env.action_spec.generate_value() state, timestep = jax.jit(env.step)(state, action) env.render(state) ``` @@ -121,6 +122,7 @@ def __init__( max_demand=10, ) self.num_nodes = self.generator.num_nodes + super().__init__() self.max_capacity = self.generator.max_capacity self.max_demand = self.generator.max_demand if self.max_capacity < self.max_demand: @@ -195,6 +197,7 @@ def step( ) return next_state, timestep + @cached_property def observation_spec(self) -> specs.Spec[Observation]: """Returns the observation spec. @@ -261,6 +264,7 @@ def observation_spec(self) -> specs.Spec[Observation]: action_mask=action_mask, ) + @cached_property def action_spec(self) -> specs.DiscreteArray: """Returns the action spec. diff --git a/jumanji/environments/routing/cvrp/env_test.py b/jumanji/environments/routing/cvrp/env_test.py index 6ad2d2184..c0f828db2 100644 --- a/jumanji/environments/routing/cvrp/env_test.py +++ b/jumanji/environments/routing/cvrp/env_test.py @@ -19,7 +19,10 @@ from jumanji.environments.routing.cvrp.constants import DEPOT_IDX from jumanji.environments.routing.cvrp.env import CVRP from jumanji.environments.routing.cvrp.types import State -from jumanji.testing.env_not_smoke import check_env_does_not_smoke +from jumanji.testing.env_not_smoke import ( + check_env_does_not_smoke, + check_env_specs_does_not_smoke, +) from jumanji.testing.pytrees import assert_is_jax_array_tree from jumanji.types import TimeStep @@ -95,6 +98,10 @@ def test_cvrp_sparse__does_not_smoke(self, cvrp_sparse_reward: CVRP) -> None: """Test that we can run an episode without any errors.""" check_env_does_not_smoke(cvrp_sparse_reward) + def test_cvrp_sparse__specs_does_not_smoke(self, cvrp_sparse_reward: CVRP) -> None: + """Test that we can access specs without any errors.""" 
+ check_env_specs_does_not_smoke(cvrp_sparse_reward) + def test_cvrp_sparse__trajectory_action(self, cvrp_sparse_reward: CVRP) -> None: """Tests a trajectory by visiting nodes in increasing and cyclic order, visiting the depot when the next node in the list surpasses the current capacity of the agent. diff --git a/jumanji/environments/routing/maze/env.py b/jumanji/environments/routing/maze/env.py index d0045144c..c2f0100dd 100644 --- a/jumanji/environments/routing/maze/env.py +++ b/jumanji/environments/routing/maze/env.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from functools import cached_property from typing import Optional, Sequence, Tuple import chex @@ -30,7 +31,7 @@ from jumanji.viewer import Viewer -class Maze(Environment[State]): +class Maze(Environment[State, specs.DiscreteArray, Observation]): """A JAX implementation of a 2D Maze. The goal is to navigate the maze to find the target position. @@ -71,7 +72,7 @@ class Maze(Environment[State]): key = jax.random.PRNGKey(0) state, timestep = jax.jit(env.reset)(key) env.render(state) - action = env.action_spec().generate_value() + action = env.action_spec.generate_value() state, timestep = jax.jit(env.step)(state, action) env.render(state) ``` @@ -100,6 +101,7 @@ def __init__( self.generator = generator or RandomGenerator(num_rows=10, num_cols=10) self.num_rows = self.generator.num_rows self.num_cols = self.generator.num_cols + super().__init__() self.shape = (self.num_rows, self.num_cols) self.time_limit = time_limit or self.num_rows * self.num_cols @@ -117,6 +119,7 @@ def __repr__(self) -> str: ] ) + @cached_property def observation_spec(self) -> specs.Spec[Observation]: """Specifications of the observation of the `Maze` environment. @@ -159,6 +162,7 @@ def observation_spec(self) -> specs.Spec[Observation]: action_mask=action_mask, ) + @cached_property def action_spec(self) -> specs.DiscreteArray: """Returns the action spec. 
4 actions: [0,1,2,3] -> [Up, Right, Down, Left]. diff --git a/jumanji/environments/routing/maze/env_test.py b/jumanji/environments/routing/maze/env_test.py index e84c9b942..5cf2a69a4 100644 --- a/jumanji/environments/routing/maze/env_test.py +++ b/jumanji/environments/routing/maze/env_test.py @@ -20,7 +20,10 @@ from jumanji.environments.routing.maze.env import Maze from jumanji.environments.routing.maze.generator import RandomGenerator, ToyGenerator from jumanji.environments.routing.maze.types import Position, State -from jumanji.testing.env_not_smoke import check_env_does_not_smoke +from jumanji.testing.env_not_smoke import ( + check_env_does_not_smoke, + check_env_specs_does_not_smoke, +) from jumanji.testing.pytrees import assert_is_jax_array_tree from jumanji.types import StepType, TimeStep @@ -227,3 +230,7 @@ def test_maze__toy_generator(self) -> None: def test_maze__does_not_smoke(self, maze: Maze) -> None: check_env_does_not_smoke(maze) + + def test_maze__specs_does_not_smoke(self, maze: Maze) -> None: + """Test that we can access specs without any errors.""" + check_env_specs_does_not_smoke(maze) diff --git a/jumanji/environments/routing/mmst/env.py b/jumanji/environments/routing/mmst/env.py index 485fa87ad..386f2dd3c 100644 --- a/jumanji/environments/routing/mmst/env.py +++ b/jumanji/environments/routing/mmst/env.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from functools import cached_property from typing import Any, Dict, Optional, Sequence, Tuple import chex @@ -42,7 +43,7 @@ from jumanji.viewer import Viewer -class MMST(Environment[State]): +class MMST(Environment[State, specs.MultiDiscreteArray, Observation]): """The `MMST` (Multi Minimum Spanning Tree) environment consists of a random connected graph with groups of nodes (same node types) that needs to be connected. 
@@ -124,7 +125,7 @@ class MMST(Environment[State]): key = jax.random.PRNGKey(0) state, timestep = jax.jit(env.reset)(key) env.render(state) - action = env.action_spec().generate_value() + action = env.action_spec.generate_value() state, timestep = jax.jit(env.step)(state, action) env.render(state) ``` @@ -168,6 +169,7 @@ def __init__( self._env_viewer = viewer or MMSTViewer(num_agents=self.num_agents) self.time_limit = time_limit + super().__init__() def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]: """Resets the environment. @@ -209,7 +211,6 @@ def step_agent_fn( indices: chex.Array, agent_id: int, ) -> Tuple[chex.Array, ...]: - is_invalid_choice = jnp.any(action == INVALID_CHOICE) | jnp.any( action == INVALID_TIE_BREAK ) @@ -284,6 +285,7 @@ def step_agent_fn( state, timestep = self._state_to_timestep(state, action) return state, timestep + @cached_property def action_spec(self) -> specs.MultiDiscreteArray: """Returns the action spec. @@ -295,6 +297,7 @@ def action_spec(self) -> specs.MultiDiscreteArray: name="action", ) + @cached_property def observation_spec(self) -> specs.Spec[Observation]: """Returns the observation spec. 
diff --git a/jumanji/environments/routing/mmst/env_test.py b/jumanji/environments/routing/mmst/env_test.py index ccd494f86..06f8e57d7 100644 --- a/jumanji/environments/routing/mmst/env_test.py +++ b/jumanji/environments/routing/mmst/env_test.py @@ -24,13 +24,16 @@ ) from jumanji.environments.routing.mmst.env import MMST from jumanji.environments.routing.mmst.types import State -from jumanji.testing.env_not_smoke import check_env_does_not_smoke +from jumanji.testing.env_not_smoke import ( + check_env_does_not_smoke, + check_env_specs_does_not_smoke, +) from jumanji.testing.pytrees import assert_is_jax_array_tree from jumanji.types import TimeStep def test__mmst_agent_observation( - deterministic_mmst_env: Tuple[MMST, State, TimeStep] + deterministic_mmst_env: Tuple[MMST, State, TimeStep], ) -> None: """Test that agent observation view of the node types is correct""" @@ -49,7 +52,7 @@ def test__mmst_agent_observation( def test__mmst_action_tie_break( - deterministic_mmst_env: Tuple[MMST, State, TimeStep] + deterministic_mmst_env: Tuple[MMST, State, TimeStep], ) -> None: """Test if the actions are mask correctly if multiple agents select the same node as next nodes. 
@@ -131,10 +134,14 @@ def test__mmst_does_not_smoke( check_env_does_not_smoke(mmst_split_gn_env) +def test__mmst_specs_does_not_smoke(mmst_split_gn_env: MMST) -> None: + """Test that we can access specs without any errors.""" + check_env_specs_does_not_smoke(mmst_split_gn_env) + + def test__mmst_termination( - deterministic_mmst_env: Tuple[MMST, State, TimeStep] + deterministic_mmst_env: Tuple[MMST, State, TimeStep], ) -> None: - env, state, timestep = deterministic_mmst_env step_fn = jax.jit(env.step) @@ -170,7 +177,6 @@ def test__mmst_termination( def test__mmst_truncation(deterministic_mmst_env: Tuple[MMST, State, TimeStep]) -> None: - env, state, timestep = deterministic_mmst_env step_fn = jax.jit(env.step) @@ -182,9 +188,8 @@ def test__mmst_truncation(deterministic_mmst_env: Tuple[MMST, State, TimeStep]) def test__mmst_action_masking( - deterministic_mmst_env: Tuple[MMST, State, TimeStep] + deterministic_mmst_env: Tuple[MMST, State, TimeStep], ) -> None: - env, state, _ = deterministic_mmst_env step_fn = jax.jit(env.step) diff --git a/jumanji/environments/routing/multi_cvrp/env.py b/jumanji/environments/routing/multi_cvrp/env.py index 01aa933ce..2cd53c46c 100644 --- a/jumanji/environments/routing/multi_cvrp/env.py +++ b/jumanji/environments/routing/multi_cvrp/env.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from functools import cached_property from typing import Optional, Sequence, Tuple import chex @@ -48,7 +49,7 @@ from jumanji.viewer import Viewer -class MultiCVRP(Environment[State]): +class MultiCVRP(Environment[State, specs.BoundedArray, Observation]): """ Multi-Vehicle Routing Problems with Soft Time Windows (MVRPSTW) environment as described in [1]. We simplfy the naming to multi-agent capacitated vehicle routing problem (MultiCVRP). 
@@ -71,7 +72,7 @@ class MultiCVRP(Environment[State]): key = jax.random.PRNGKey(0) state, timestep = jax.jit(env.reset)(key) env.render(state) - action = env.action_spec().generate_value() + action = env.action_spec.generate_value() state, timestep = jax.jit(env.step)(state, action) env.render(state) ``` @@ -138,6 +139,7 @@ def __init__( max_single_vehicle_distance(self._map_max, self._num_customers) / self._speed ) + super().__init__() def __repr__(self) -> str: return f"MultiCVRP(num_customers={self._num_customers}, num_vehicles={self._num_vehicles})" @@ -188,6 +190,7 @@ def step( return new_state, timestep + @cached_property def observation_spec(self) -> specs.Spec[Observation]: """ Returns the observation spec. @@ -317,6 +320,7 @@ def observation_spec(self) -> specs.Spec[Observation]: action_mask=action_mask, ) + @cached_property def action_spec(self) -> specs.BoundedArray: """ Returns the action spec. diff --git a/jumanji/environments/routing/multi_cvrp/env_test.py b/jumanji/environments/routing/multi_cvrp/env_test.py index a93aaf17e..f3292cf40 100644 --- a/jumanji/environments/routing/multi_cvrp/env_test.py +++ b/jumanji/environments/routing/multi_cvrp/env_test.py @@ -23,7 +23,10 @@ test_node_demand, ) from jumanji.environments.routing.multi_cvrp.types import State -from jumanji.testing.env_not_smoke import check_env_does_not_smoke +from jumanji.testing.env_not_smoke import ( + check_env_does_not_smoke, + check_env_specs_does_not_smoke, +) from jumanji.testing.pytrees import assert_is_jax_array_tree from jumanji.types import TimeStep @@ -278,3 +281,9 @@ def select_action( return select_action(subkeys, observation.action_mask) check_env_does_not_smoke(multicvrp_env, select_actions) + + def test_env_multicvrp__specs_does_not_smoke( + self, multicvrp_env: MultiCVRP + ) -> None: + """Test that we can access specs without any errors.""" + check_env_specs_does_not_smoke(multicvrp_env) diff --git a/jumanji/environments/routing/pac_man/env.py 
b/jumanji/environments/routing/pac_man/env.py index 6db84f16c..3007042b2 100644 --- a/jumanji/environments/routing/pac_man/env.py +++ b/jumanji/environments/routing/pac_man/env.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from functools import cached_property from typing import Any, Optional, Sequence, Tuple import chex @@ -35,7 +36,7 @@ from jumanji.viewer import Viewer -class PacMan(Environment[State]): +class PacMan(Environment[State, specs.DiscreteArray, Observation]): """A JAX implementation of the 'PacMan' game where a single agent must navigate a maze to collect pellets and avoid 4 heuristic agents. The game takes place on a 31x28 grid where the player can move in 4 directions (left, right, up, down) and collect @@ -103,7 +104,7 @@ class PacMan(Environment[State]): key = jax.random.PRNGKey(0) state, timestep = jax.jit(env.reset)(key) env.render(state) - action = env.action_spec().generate_value() + action = env.action_spec.generate_value() state, timestep = jax.jit(env.step)(state, action) env.render(state) ``` @@ -129,9 +130,11 @@ def __init__( self.x_size = self.generator.x_size self.y_size = self.generator.y_size self.pellet_spaces = self.generator.pellet_spaces + super().__init__() self._viewer = viewer or PacManViewer("Pacman", render_mode="human") self.time_limit = 1000 or time_limit + @cached_property def observation_spec(self) -> specs.Spec[Observation]: """Specifications of the observation of the `PacMan` environment. @@ -199,6 +202,7 @@ def observation_spec(self) -> specs.Spec[Observation]: score=score, ) + @cached_property def action_spec(self) -> specs.DiscreteArray: """Returns the action spec. 
@@ -210,7 +214,6 @@ def action_spec(self) -> specs.DiscreteArray: return specs.DiscreteArray(5, name="action") def __repr__(self) -> str: - return ( f"PacMan(\n" f"\tnum_rows={self.x_size!r},\n" @@ -460,7 +463,6 @@ def check_power_up( return power_up_locations, eat, reward def check_wall_collisions(self, state: State, new_player_pos: Position) -> Any: - """ Check if the new player position collides with a wall. diff --git a/jumanji/environments/routing/pac_man/env_test.py b/jumanji/environments/routing/pac_man/env_test.py index f2ab5a7ec..54bc13d5c 100644 --- a/jumanji/environments/routing/pac_man/env_test.py +++ b/jumanji/environments/routing/pac_man/env_test.py @@ -19,7 +19,10 @@ from jumanji.environments.routing.pac_man.env import PacMan from jumanji.environments.routing.pac_man.types import Position, State -from jumanji.testing.env_not_smoke import check_env_does_not_smoke +from jumanji.testing.env_not_smoke import ( + check_env_does_not_smoke, + check_env_specs_does_not_smoke, +) from jumanji.testing.pytrees import assert_is_jax_array_tree from jumanji.types import TimeStep @@ -96,12 +99,15 @@ def test_pac_man_step_invalid(pac_man: PacMan) -> None: def test_pac_man_does_not_smoke(pac_man: PacMan) -> None: - check_env_does_not_smoke(pac_man) -def test_power_pellet(pac_man: PacMan) -> None: +def test_env_pac_man_specs_does_not_smoke(pac_man: PacMan) -> None: + """Test that we can access specs without any errors.""" + check_env_specs_does_not_smoke(pac_man) + +def test_power_pellet(pac_man: PacMan) -> None: key = jax.random.PRNGKey(0) state, timestep = pac_man.reset(key) diff --git a/jumanji/environments/routing/robot_warehouse/env.py b/jumanji/environments/routing/robot_warehouse/env.py index 908fe8053..eb9c2c578 100644 --- a/jumanji/environments/routing/robot_warehouse/env.py +++ b/jumanji/environments/routing/robot_warehouse/env.py @@ -13,6 +13,7 @@ # limitations under the License. 
import functools +from functools import cached_property from typing import List, Optional, Sequence, Tuple import chex @@ -48,7 +49,7 @@ from jumanji.viewer import Viewer -class RobotWarehouse(Environment[State]): +class RobotWarehouse(Environment[State, specs.MultiDiscreteArray, Observation]): """A JAX implementation of the 'Robotic warehouse' environment: https://github.com/semitable/robotic-warehouse which is described in the paper [1]. @@ -127,7 +128,7 @@ class RobotWarehouse(Environment[State]): key = jax.random.PRNGKey(0) state, timestep = jax.jit(env.reset)(key) env.render(state) - action = env.action_spec().generate_value() + action = env.action_spec.generate_value() state, timestep = jax.jit(env.step)(state, action) env.render(state) ``` @@ -182,6 +183,7 @@ def __init__( ) self.goals = self._generator.goals self.time_limit = time_limit + super().__init__() # create viewer for rendering environment self._viewer = viewer or RobotWarehouseViewer( @@ -334,6 +336,7 @@ def update_reward_and_request_queue_scan( ) return next_state, timestep + @cached_property def observation_spec(self) -> specs.Spec[Observation]: """Specification of the observation of the `RobotWarehouse` environment. Returns: @@ -357,6 +360,7 @@ def observation_spec(self) -> specs.Spec[Observation]: step_count=step_count, ) + @cached_property def action_spec(self) -> specs.MultiDiscreteArray: """Returns the action spec. 5 actions: [0,1,2,3,4] -> [No Op, Forward, Left, Right, Toggle_load]. Since this is a multi-agent environment, the environment expects an array of actions. 
diff --git a/jumanji/environments/routing/robot_warehouse/env_test.py b/jumanji/environments/routing/robot_warehouse/env_test.py index e5d60b94d..cf37e3b2e 100644 --- a/jumanji/environments/routing/robot_warehouse/env_test.py +++ b/jumanji/environments/routing/robot_warehouse/env_test.py @@ -21,7 +21,10 @@ from jumanji.environments.routing.robot_warehouse.env import RobotWarehouse from jumanji.environments.routing.robot_warehouse.types import State -from jumanji.testing.env_not_smoke import check_env_does_not_smoke +from jumanji.testing.env_not_smoke import ( + check_env_does_not_smoke, + check_env_specs_does_not_smoke, +) from jumanji.testing.pytrees import assert_is_jax_array_tree from jumanji.tree_utils import tree_slice from jumanji.types import TimeStep @@ -29,8 +32,8 @@ def test_robot_warehouse__specs(robot_warehouse_env: RobotWarehouse) -> None: """Validate environment specs conform to the expected shapes and values""" - action_spec = robot_warehouse_env.action_spec() - observation_spec = robot_warehouse_env.observation_spec() + action_spec = robot_warehouse_env.action_spec + observation_spec = robot_warehouse_env.observation_spec assert observation_spec.agents_view.shape == (2, 66) # type: ignore assert action_spec.num_values.shape[0] == robot_warehouse_env.num_agents @@ -60,7 +63,7 @@ def test_robot_warehouse__reset(robot_warehouse_env: RobotWarehouse) -> None: def test_robot_warehouse__agent_observation( - deterministic_robot_warehouse_env: Tuple[RobotWarehouse, State, TimeStep] + deterministic_robot_warehouse_env: Tuple[RobotWarehouse, State, TimeStep], ) -> None: """Validate the agent observation function.""" env, state, timestep = deterministic_robot_warehouse_env @@ -163,6 +166,13 @@ def test_robot_warehouse__does_not_smoke(robot_warehouse_env: RobotWarehouse) -> check_env_does_not_smoke(robot_warehouse_env) +def test_robot_warehouse__specs_does_not_smoke( + robot_warehouse_env: RobotWarehouse, +) -> None: + """Test that we can access specs without 
any errors.""" + check_env_specs_does_not_smoke(robot_warehouse_env) + + def test_robot_warehouse__time_limit(robot_warehouse_env: RobotWarehouse) -> None: """Validate the terminal reward.""" step_fn = jax.jit(robot_warehouse_env.step) @@ -179,7 +189,7 @@ def test_robot_warehouse__time_limit(robot_warehouse_env: RobotWarehouse) -> Non def test_robot_warehouse__truncation( - deterministic_robot_warehouse_env: Tuple[RobotWarehouse, State, TimeStep] + deterministic_robot_warehouse_env: Tuple[RobotWarehouse, State, TimeStep], ) -> None: """Validate episode truncation based on set time limit.""" robot_warehouse_env, state, timestep = deterministic_robot_warehouse_env @@ -197,7 +207,7 @@ def test_robot_warehouse__truncation( def test_robot_warehouse__truncate_upon_collision( - deterministic_robot_warehouse_env: Tuple[RobotWarehouse, State, TimeStep] + deterministic_robot_warehouse_env: Tuple[RobotWarehouse, State, TimeStep], ) -> None: """Validate episode terminates upon collision of agents.""" robot_warehouse_env, state, timestep = deterministic_robot_warehouse_env @@ -217,7 +227,7 @@ def test_robot_warehouse__truncate_upon_collision( def test_robot_warehouse__reward_in_goal( - deterministic_robot_warehouse_env: Tuple[RobotWarehouse, State, TimeStep] + deterministic_robot_warehouse_env: Tuple[RobotWarehouse, State, TimeStep], ) -> None: """Validate goal reward behavior.""" robot_warehouse_env, state, timestep = deterministic_robot_warehouse_env diff --git a/jumanji/environments/routing/snake/env.py b/jumanji/environments/routing/snake/env.py index fafd0cd20..0a1d0451c 100644 --- a/jumanji/environments/routing/snake/env.py +++ b/jumanji/environments/routing/snake/env.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from functools import cached_property from typing import Optional, Sequence, Tuple import chex @@ -29,7 +30,7 @@ from jumanji.viewer import Viewer -class Snake(Environment[State]): +class Snake(Environment[State, specs.DiscreteArray, Observation]): """A JAX implementation of the 'Snake' game. - observation: `Observation` @@ -84,7 +85,7 @@ class Snake(Environment[State]): key = jax.random.PRNGKey(0) state, timestep = jax.jit(env.reset)(key) env.render(state) - action = env.action_spec().generate_value() + action = env.action_spec.generate_value() state, timestep = jax.jit(env.step)(state, action) env.render(state) ``` @@ -108,11 +109,11 @@ def __init__( the episode ends. Defaults to 4000. viewer: `Viewer` used for rendering. Defaults to `SnakeViewer`. """ - super().__init__() self.num_rows = num_rows self.num_cols = num_cols self.board_shape = (num_rows, num_cols) self.time_limit = time_limit + super().__init__() self._viewer = viewer or SnakeViewer() def __repr__(self) -> str: @@ -235,6 +236,7 @@ def step( ) return next_state, timestep + @cached_property def observation_spec(self) -> specs.Spec[Observation]: """Returns the observation spec. @@ -269,6 +271,7 @@ def observation_spec(self) -> specs.Spec[Observation]: action_mask=action_mask, ) + @cached_property def action_spec(self) -> specs.DiscreteArray: """Returns the action spec. 4 actions: [0,1,2,3] -> [Up, Right, Down, Left]. 
diff --git a/jumanji/environments/routing/snake/env_test.py b/jumanji/environments/routing/snake/env_test.py index df37aff3a..c7b0e9711 100644 --- a/jumanji/environments/routing/snake/env_test.py +++ b/jumanji/environments/routing/snake/env_test.py @@ -22,7 +22,10 @@ from jumanji.environments.routing.snake.env import Snake, State from jumanji.environments.routing.snake.types import Position -from jumanji.testing.env_not_smoke import check_env_does_not_smoke +from jumanji.testing.env_not_smoke import ( + check_env_does_not_smoke, + check_env_specs_does_not_smoke, +) from jumanji.testing.pytrees import assert_is_jax_array_tree from jumanji.types import TimeStep @@ -60,7 +63,7 @@ def test_snake__step(snake: Snake) -> None: # Sample two different actions action1, action2 = jax.random.choice( action_key, - jnp.arange(snake.action_spec()._num_values), + jnp.arange(snake.action_spec._num_values), shape=(2,), replace=False, ) @@ -94,6 +97,11 @@ def test_snake__does_not_smoke(snake: Snake) -> None: check_env_does_not_smoke(snake) +def test_snake__specs_does_not_smoke(snake: Snake) -> None: + """Test that we can access specs without any errors.""" + check_env_specs_does_not_smoke(snake) + + def test_update_head_position(snake: Snake) -> None: """Validates _update_head_position method. 
Checks that starting from a certain position, taking some actions @@ -140,7 +148,7 @@ def test_snake__render(monkeypatch: pytest.MonkeyPatch, snake: Snake) -> None: monkeypatch.setattr(plt, "show", lambda fig: None) step_fn = jax.jit(snake.step) state, timestep = snake.reset(jax.random.PRNGKey(0)) - action = snake.action_spec().generate_value() + action = snake.action_spec.generate_value() state, timestep = step_fn(state, action) snake.render(state) snake.close() @@ -151,7 +159,7 @@ def test_snake__animation(snake: Snake, tmpdir: py.path.local) -> None: step_fn = jax.jit(snake.step) state, _ = snake.reset(jax.random.PRNGKey(0)) states = [state] - action = snake.action_spec().generate_value() + action = snake.action_spec.generate_value() state, _ = step_fn(state, action) states.append(state) animation = snake.animate(states) diff --git a/jumanji/environments/routing/sokoban/env.py b/jumanji/environments/routing/sokoban/env.py index c56fcbf89..2433df322 100644 --- a/jumanji/environments/routing/sokoban/env.py +++ b/jumanji/environments/routing/sokoban/env.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from functools import cached_property from typing import Dict, Optional, Sequence, Tuple import chex @@ -45,7 +46,7 @@ from jumanji.viewer import Viewer -class Sokoban(Environment[State]): +class Sokoban(Environment[State, specs.DiscreteArray, Observation]): """A JAX implementation of the 'Sokoban' game from deepmind. 
- observation: `Observation` @@ -103,7 +104,7 @@ class Sokoban(Environment[State]): key_train = jax.random.PRNGKey(0) state, timestep = jax.jit(env_train.reset)(key_train) env_train.render(state) - action = env_train.action_spec().generate_value() + action = env_train.action_spec.generate_value() state, timestep = jax.jit(env_train.step)(state, action) env_train.render(state) ``` @@ -136,6 +137,8 @@ def __init__( self.shape = (self.num_rows, self.num_cols) self.time_limit = time_limit + super().__init__() + self.generator = generator or HuggingFaceDeepMindGenerator( "unfiltered-train", proportion_of_files=1, @@ -256,6 +259,7 @@ def step( return next_state, timestep + @cached_property def observation_spec(self) -> specs.Spec[Observation]: """ Returns the specifications of the observation of the `Sokoban` @@ -279,6 +283,7 @@ def observation_spec(self) -> specs.Spec[Observation]: step_count=step_count, ) + @cached_property def action_spec(self) -> specs.DiscreteArray: """ Returns the action specification for the Sokoban environment. 
diff --git a/jumanji/environments/routing/sokoban/env_test.py b/jumanji/environments/routing/sokoban/env_test.py index c6e935e4f..e7b61959c 100644 --- a/jumanji/environments/routing/sokoban/env_test.py +++ b/jumanji/environments/routing/sokoban/env_test.py @@ -26,7 +26,10 @@ SimpleSolveGenerator, ) from jumanji.environments.routing.sokoban.types import State -from jumanji.testing.env_not_smoke import check_env_does_not_smoke +from jumanji.testing.env_not_smoke import ( + check_env_does_not_smoke, + check_env_specs_does_not_smoke, +) from jumanji.types import TimeStep @@ -214,3 +217,8 @@ def test_sokoban__reward_function_solved(sokoban_simple: Sokoban) -> None: def test_sokoban__does_not_smoke(sokoban: Sokoban) -> None: """Test that we can run an episode without any errors.""" check_env_does_not_smoke(sokoban) + + +def test_sokoban__specs_does_not_smoke(sokoban: Sokoban) -> None: + """Test that we can access specs without any errors.""" + check_env_specs_does_not_smoke(sokoban) diff --git a/jumanji/environments/routing/tsp/env.py b/jumanji/environments/routing/tsp/env.py index 0428e646c..f6d57bf93 100644 --- a/jumanji/environments/routing/tsp/env.py +++ b/jumanji/environments/routing/tsp/env.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from functools import cached_property from typing import Optional, Sequence, Tuple import chex @@ -31,7 +32,7 @@ from jumanji.viewer import Viewer -class TSP(Environment[State]): +class TSP(Environment[State, specs.DiscreteArray, Observation]): """Traveling Salesman Problem (TSP) environment as described in [1]. 
- observation: Observation @@ -83,7 +84,7 @@ class TSP(Environment[State]): key = jax.random.PRNGKey(0) state, timestep = jax.jit(env.reset)(key) env.render(state) - action = env.action_spec().generate_value() + action = env.action_spec.generate_value() state, timestep = jax.jit(env.step)(state, action) env.render(state) ``` @@ -112,6 +113,7 @@ def __init__( num_cities=20, ) self.num_cities = self.generator.num_cities + super().__init__() self.reward_fn = reward_fn or DenseReward() self._viewer = viewer or TSPViewer(name="TSP", render_mode="human") @@ -169,6 +171,7 @@ def step( ) return next_state, timestep + @cached_property def observation_spec(self) -> specs.Spec[Observation]: """Returns the observation spec. @@ -212,6 +215,7 @@ def observation_spec(self) -> specs.Spec[Observation]: action_mask=action_mask, ) + @cached_property def action_spec(self) -> specs.DiscreteArray: """Returns the action spec. diff --git a/jumanji/environments/routing/tsp/env_test.py b/jumanji/environments/routing/tsp/env_test.py index d43e51f13..7d320b13d 100644 --- a/jumanji/environments/routing/tsp/env_test.py +++ b/jumanji/environments/routing/tsp/env_test.py @@ -19,7 +19,10 @@ from jumanji.environments.routing.tsp.env import TSP from jumanji.environments.routing.tsp.types import State -from jumanji.testing.env_not_smoke import check_env_does_not_smoke +from jumanji.testing.env_not_smoke import ( + check_env_does_not_smoke, + check_env_specs_does_not_smoke, +) from jumanji.testing.pytrees import assert_is_jax_array_tree from jumanji.types import StepType, TimeStep @@ -198,6 +201,12 @@ def test_tsp_sparse__does_not_smoke( """Test that we can run an episode without any errors.""" check_env_does_not_smoke(tsp_sparse_reward) + def test_tsp_sparse__specs_does_not_smoke( + self, tsp_sparse_reward: TSP, capsys: pytest.CaptureFixture + ) -> None: + """Test that we can access specs without any errors.""" + check_env_specs_does_not_smoke(tsp_sparse_reward) + def 
test_tsp_sparse__trajectory_action(self, tsp_sparse_reward: TSP) -> None: """Checks that the agent stops when there are no more cities to be selected and that the appropriate reward is received. The testing loop ensures that no city is selected twice. diff --git a/jumanji/registration_test.py b/jumanji/registration_test.py index 5ec06bd0a..a37e05935 100644 --- a/jumanji/registration_test.py +++ b/jumanji/registration_test.py @@ -18,7 +18,7 @@ import pytest_mock import jumanji -from jumanji import registration +from jumanji import registration, specs from jumanji.testing.fakes import FakeEnvironment @@ -91,10 +91,10 @@ def test_register__override_kwargs(mocker: pytest_mock.MockerFixture) -> None: id=env_id, entry_point="jumanji.testing.fakes:FakeEnvironment", ) - env: FakeEnvironment = registration.make( # type: ignore + obs_spec: specs.Array = registration.make( # type: ignore env_id, observation_shape=obs_shape - ) - assert env.observation_spec().shape == obs_shape + ).observation_spec + assert obs_spec.shape == obs_shape def test_registration__make() -> None: diff --git a/jumanji/testing/env_not_smoke.py b/jumanji/testing/env_not_smoke.py index 8a3cb34b4..4a9603c9d 100644 --- a/jumanji/testing/env_not_smoke.py +++ b/jumanji/testing/env_not_smoke.py @@ -20,8 +20,8 @@ from jumanji import specs from jumanji.env import Environment +from jumanji.types import Observation -Observation = TypeVar("Observation") Action = TypeVar("Action") SelectActionFn = Callable[[chex.PRNGKey, Observation], Action] @@ -29,7 +29,7 @@ def make_random_select_action_fn( action_spec: Union[ specs.BoundedArray, specs.DiscreteArray, specs.MultiDiscreteArray - ] + ], ) -> SelectActionFn: """Create select action function that chooses random actions.""" @@ -72,17 +72,16 @@ def check_env_does_not_smoke( assert_finite_check: bool = True, ) -> None: """Run an episode of the environment, with a jitted step function to check no errors occur.""" - action_spec = env.action_spec() if select_action is 
None: - if isinstance(action_spec, specs.BoundedArray) or isinstance( - action_spec, specs.DiscreteArray + if isinstance(env.action_spec, specs.BoundedArray) or isinstance( + env.action_spec, specs.DiscreteArray ): - select_action = make_random_select_action_fn(action_spec) + select_action = make_random_select_action_fn(env.action_spec) else: raise NotImplementedError( f"Currently the `make_random_select_action_fn` only works for environments with " f"either discrete actions or bounded continuous actions. The input environment to " - f"this test has an action spec of type {action_spec}, and therefore requires " + f"this test has an action spec of type {env.action_spec}, and therefore requires " f"a custom `SelectActionFn` to be provided to this test." ) key = jax.random.PRNGKey(0) @@ -92,8 +91,21 @@ def check_env_does_not_smoke( while not timestep.last(): key, action_key = jax.random.split(key) action = select_action(action_key, timestep.observation) - env.action_spec().validate(action) + env.action_spec.validate(action) state, timestep = step_fn(state, action) - env.observation_spec().validate(timestep.observation) + env.observation_spec.validate(timestep.observation) if assert_finite_check: chex.assert_tree_all_finite((state, timestep)) + + +def access_specs(env: Environment) -> None: + """Access specs of the environment.""" + env.observation_spec + env.action_spec + env.reward_spec + env.discount_spec + + +def check_env_specs_does_not_smoke(env: Environment) -> None: + """Access specs of the environment in a jitted function to check no errors occur.""" + jax.jit(access_specs, static_argnums=0)(env) diff --git a/jumanji/testing/env_not_smoke_test.py b/jumanji/testing/env_not_smoke_test.py index d11f357d1..8900aaaeb 100644 --- a/jumanji/testing/env_not_smoke_test.py +++ b/jumanji/testing/env_not_smoke_test.py @@ -20,6 +20,7 @@ from jumanji.testing.env_not_smoke import ( SelectActionFn, check_env_does_not_smoke, + check_env_specs_does_not_smoke, 
make_random_select_action_fn, ) from jumanji.testing.fakes import FakeEnvironment @@ -54,10 +55,15 @@ def test_random_select_action(fake_env: FakeEnvironment) -> None: """Validate that the `select_action` method returns random actions meeting the environment spec.""" key = jax.random.PRNGKey(0) - select_action = make_random_select_action_fn(fake_env.action_spec()) + select_action = make_random_select_action_fn(fake_env.action_spec) key1, key2, key3 = jax.random.split(key, 3) env_state, timestep = fake_env.reset(key1) action_1 = select_action(key2, timestep.observation) action_2 = select_action(key3, timestep.observation) - fake_env.action_spec().validate(action_1) + fake_env.action_spec.validate(action_1) assert not jnp.all(action_1 == action_2) + + +def test_env_specs_not_smoke(fake_env: FakeEnvironment) -> None: + """Test that we can access specs without any errors.""" + check_env_specs_does_not_smoke(fake_env) diff --git a/jumanji/testing/fakes.py b/jumanji/testing/fakes.py index 0835a47ac..a41246e1d 100644 --- a/jumanji/testing/fakes.py +++ b/jumanji/testing/fakes.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from functools import cached_property from typing import TYPE_CHECKING, Tuple if TYPE_CHECKING: @@ -34,7 +35,7 @@ class FakeState: step: jnp.int32 -class FakeEnvironment(Environment[FakeState]): +class FakeEnvironment(Environment[FakeState, specs.BoundedArray, chex.Array]): """ A fake environment that inherits from Environment, for testing purposes. The observation is an array full of `state.step` of shape `(self.observation_shape,)` @@ -56,8 +57,10 @@ def __init__( self.time_limit = time_limit self.observation_shape = observation_shape self.action_shape = action_shape - self._example_action = self.action_spec().generate_value() + super().__init__() + self._example_action = self.action_spec.generate_value() + @cached_property def observation_spec(self) -> specs.Array: """Returns the observation spec. 
@@ -69,6 +72,7 @@ def observation_spec(self) -> specs.Array: shape=self.observation_shape, dtype=float, name="observation" ) + @cached_property def action_spec(self) -> specs.BoundedArray: """Returns the action spec. @@ -142,7 +146,7 @@ def _state_to_obs(self, state: FakeState) -> chex.Array: return state.step * jnp.ones(self.observation_shape, float) -class FakeMultiEnvironment(Environment[FakeState]): +class FakeMultiEnvironment(Environment[FakeState, specs.BoundedArray, chex.Array]): """ A fake multi agent environment that inherits from Environment, for testing purposes. """ @@ -169,12 +173,14 @@ def __init__( self.observation_shape = observation_shape self.num_action_values = num_action_values self.num_agents = num_agents + super().__init__() self.reward_per_step = reward_per_step assert ( observation_shape[0] == num_agents ), f"""a leading dimension of size 'num_agents': {num_agents} is expected for the observation, got shape: {observation_shape}.""" + @cached_property def observation_spec(self) -> specs.Array: """Returns the observation spec. @@ -186,6 +192,7 @@ def observation_spec(self) -> specs.Array: shape=self.observation_shape, dtype=float, name="observation" ) + @cached_property def action_spec(self) -> specs.BoundedArray: """Returns the action spec. @@ -197,6 +204,7 @@ def action_spec(self) -> specs.BoundedArray: (self.num_agents,), int, 0, self.num_action_values - 1 ) + @cached_property def reward_spec(self) -> specs.Array: """Returns the reward spec. @@ -205,6 +213,7 @@ def reward_spec(self) -> specs.Array: """ return specs.Array(shape=(self.num_agents,), dtype=float, name="reward") + @cached_property def discount_spec(self) -> specs.BoundedArray: """Describes the discount returned by the environment. 
@@ -231,7 +240,7 @@ def reset(self, key: chex.PRNGKey) -> Tuple[FakeState, TimeStep]: """ state = FakeState(key=key, step=0) - observation = self.observation_spec().generate_value() + observation = self.observation_spec.generate_value() timestep = restart(observation=observation, shape=(self.num_agents,)) return state, timestep diff --git a/jumanji/testing/fakes_test.py b/jumanji/testing/fakes_test.py index 42cb98415..81d75dd16 100644 --- a/jumanji/testing/fakes_test.py +++ b/jumanji/testing/fakes_test.py @@ -31,7 +31,7 @@ def test_fake_environment__reset(fake_environment: fakes.FakeEnvironment) -> Non def test_fake_environment__step(fake_environment: fakes.FakeEnvironment) -> None: """Validates the step function of the fake environment.""" state, timestep = fake_environment.reset(random.PRNGKey(0)) - action = fake_environment.action_spec().generate_value() + action = fake_environment.action_spec.generate_value() next_state, timestep = fake_environment.step(state, action) # Check that the step value is now different assert state.step != next_state.step @@ -43,7 +43,7 @@ def test_fake_environment__does_not_smoke( ) -> None: """Validates the run of an episode in the fake environment. 
Check that it does not smoke.""" state, timestep = fake_environment.reset(random.PRNGKey(0)) - action = fake_environment.action_spec().generate_value() + action = fake_environment.action_spec.generate_value() while not timestep.last(): state, timestep = fake_environment.step(state, action) @@ -67,7 +67,7 @@ def test_fake_multi_environment__step( ) -> None: """Validates the step function of the fake multi agent environment.""" state, timestep = fake_multi_environment.reset(random.PRNGKey(0)) - action = fake_multi_environment.action_spec().generate_value() + action = fake_multi_environment.action_spec.generate_value() assert action.shape[0] == fake_multi_environment.num_agents next_state, timestep = fake_multi_environment.step(state, action) @@ -85,7 +85,7 @@ def test_fake_multi_environment__does_not_smoke( """Validates the run of an episode in the fake multi agent environment. Check that it does not smoke.""" state, timestep = fake_multi_environment.reset(random.PRNGKey(0)) - action = fake_multi_environment.action_spec().generate_value() + action = fake_multi_environment.action_spec.generate_value() assert action.shape[0] == fake_multi_environment.num_agents while not timestep.last(): state, timestep = fake_multi_environment.step(state, action) diff --git a/jumanji/training/agents/a2c/a2c_agent.py b/jumanji/training/agents/a2c/a2c_agent.py index 09eca211b..2392ab05c 100644 --- a/jumanji/training/agents/a2c/a2c_agent.py +++ b/jumanji/training/agents/a2c/a2c_agent.py @@ -51,7 +51,7 @@ def __init__( ) -> None: super().__init__(total_batch_size=total_batch_size) self.env = env - self.observation_spec = env.observation_spec() + self.observation_spec = env.observation_spec self.n_steps = n_steps self.actor_critic_networks = actor_critic_networks self.optimizer = optimizer diff --git a/jumanji/training/agents/random/random_agent.py b/jumanji/training/agents/random/random_agent.py index 4c17edd48..904fbe98f 100644 --- a/jumanji/training/agents/random/random_agent.py +++ 
b/jumanji/training/agents/random/random_agent.py @@ -33,7 +33,7 @@ def __init__( ) -> None: super().__init__(total_batch_size=total_batch_size) self.env = env - self.observation_spec = env.observation_spec() + self.observation_spec = env.observation_spec self.n_steps = n_steps self.random_policy = random_policy diff --git a/jumanji/training/networks/bin_pack/actor_critic.py b/jumanji/training/networks/bin_pack/actor_critic.py index 1643a5dc8..c3de93f8f 100644 --- a/jumanji/training/networks/bin_pack/actor_critic.py +++ b/jumanji/training/networks/bin_pack/actor_critic.py @@ -40,7 +40,7 @@ def make_actor_critic_networks_bin_pack( transformer_mlp_units: Sequence[int], ) -> ActorCriticNetworks: """Make actor-critic networks for the `BinPack` environment.""" - num_values = np.asarray(bin_pack.action_spec().num_values) + num_values = np.asarray(bin_pack.action_spec.num_values) parametric_action_distribution = FactorisedActionSpaceParametricDistribution( action_spec_num_values=num_values ) diff --git a/jumanji/training/networks/bin_pack/random.py b/jumanji/training/networks/bin_pack/random.py index 537772239..add32d75a 100644 --- a/jumanji/training/networks/bin_pack/random.py +++ b/jumanji/training/networks/bin_pack/random.py @@ -21,7 +21,7 @@ def make_random_policy_bin_pack(bin_pack: BinPack) -> RandomPolicy: """Make random policy for BinPack.""" - action_spec_num_values = bin_pack.action_spec().num_values + action_spec_num_values = bin_pack.action_spec.num_values return make_masked_categorical_random_ndim( action_spec_num_values=action_spec_num_values ) diff --git a/jumanji/training/networks/cleaner/actor_critic.py b/jumanji/training/networks/cleaner/actor_critic.py index b8002df3b..2fedfe289 100644 --- a/jumanji/training/networks/cleaner/actor_critic.py +++ b/jumanji/training/networks/cleaner/actor_critic.py @@ -39,7 +39,7 @@ def make_actor_critic_networks_cleaner( value_layers: Sequence[int], ) -> ActorCriticNetworks: """Make actor-critic networks for the `Cleaner` 
environment.""" - num_values = np.asarray(cleaner.action_spec().num_values) + num_values = np.asarray(cleaner.action_spec.num_values) parametric_action_distribution = MultiCategoricalParametricDistribution( num_values=num_values ) diff --git a/jumanji/training/networks/connector/actor_critic.py b/jumanji/training/networks/connector/actor_critic.py index 1473de23b..edcc4946f 100644 --- a/jumanji/training/networks/connector/actor_critic.py +++ b/jumanji/training/networks/connector/actor_critic.py @@ -46,7 +46,7 @@ def make_actor_critic_networks_connector( conv_n_channels: int, ) -> ActorCriticNetworks: """Make actor-critic networks for the `Connector` environment.""" - num_values = np.asarray(connector.action_spec().num_values) + num_values = np.asarray(connector.action_spec.num_values) parametric_action_distribution = MultiCategoricalParametricDistribution( num_values=num_values ) diff --git a/jumanji/training/networks/cvrp/actor_critic.py b/jumanji/training/networks/cvrp/actor_critic.py index 5bae24f7f..e5b498a2a 100644 --- a/jumanji/training/networks/cvrp/actor_critic.py +++ b/jumanji/training/networks/cvrp/actor_critic.py @@ -38,7 +38,7 @@ def make_actor_critic_networks_cvrp( mean_nodes_in_query: bool, ) -> ActorCriticNetworks: """Make actor-critic networks for the `CVRP` environment.""" - num_actions = cvrp.action_spec().num_values + num_actions = cvrp.action_spec.num_values parametric_action_distribution = CategoricalParametricDistribution( num_actions=num_actions ) diff --git a/jumanji/training/networks/flat_pack/actor_critic.py b/jumanji/training/networks/flat_pack/actor_critic.py index aeb1819b0..40aeaa037 100644 --- a/jumanji/training/networks/flat_pack/actor_critic.py +++ b/jumanji/training/networks/flat_pack/actor_critic.py @@ -40,7 +40,7 @@ def make_actor_critic_networks_flat_pack( hidden_size: int, ) -> ActorCriticNetworks: """Make actor-critic networks for the `FlatPack` environment.""" - num_values = np.asarray(flat_pack.action_spec().num_values) + 
num_values = np.asarray(flat_pack.action_spec.num_values) parametric_action_distribution = FactorisedActionSpaceParametricDistribution( action_spec_num_values=num_values ) @@ -172,7 +172,6 @@ def __call__(self, observation: Observation) -> Tuple[chex.Array, chex.Array]: ) # (B, model_size), (B, num_rows-2, num_cols-2, hidden_size) for block_id in range(self.num_transformer_layers): - ( self_attention_mask, # (B, 1, num_blocks, num_blocks) cross_attention_mask, # (B, 1, num_blocks, 1) diff --git a/jumanji/training/networks/flat_pack/random.py b/jumanji/training/networks/flat_pack/random.py index a81ba43f0..7c8c09463 100644 --- a/jumanji/training/networks/flat_pack/random.py +++ b/jumanji/training/networks/flat_pack/random.py @@ -21,7 +21,7 @@ def make_random_policy_flat_pack(flat_pack: FlatPack) -> RandomPolicy: """Make random policy for FlatPack.""" - action_spec_num_values = flat_pack.action_spec().num_values + action_spec_num_values = flat_pack.action_spec.num_values return make_masked_categorical_random_ndim( action_spec_num_values=action_spec_num_values diff --git a/jumanji/training/networks/game_2048/actor_critic.py b/jumanji/training/networks/game_2048/actor_critic.py index 6a7d66fa5..caa4351dd 100644 --- a/jumanji/training/networks/game_2048/actor_critic.py +++ b/jumanji/training/networks/game_2048/actor_critic.py @@ -37,7 +37,7 @@ def make_actor_critic_networks_game_2048( value_layers: Sequence[int], ) -> ActorCriticNetworks: """Make actor-critic networks for the `Game2048` environment.""" - num_actions = game_2048.action_spec().num_values + num_actions = game_2048.action_spec.num_values parametric_action_distribution = CategoricalParametricDistribution( num_actions=num_actions ) diff --git a/jumanji/training/networks/graph_coloring/actor_critic.py b/jumanji/training/networks/graph_coloring/actor_critic.py index 2833061c0..6e2e336f6 100644 --- a/jumanji/training/networks/graph_coloring/actor_critic.py +++ 
b/jumanji/training/networks/graph_coloring/actor_critic.py @@ -38,7 +38,7 @@ def make_actor_critic_networks_graph_coloring( transformer_mlp_units: Sequence[int], ) -> ActorCriticNetworks: """Make actor-critic networks for the `GraphColoring` environment.""" - num_actions = graph_coloring.action_spec().num_values + num_actions = graph_coloring.action_spec.num_values parametric_action_distribution = CategoricalParametricDistribution( num_actions=num_actions ) diff --git a/jumanji/training/networks/job_shop/actor_critic.py b/jumanji/training/networks/job_shop/actor_critic.py index 0c09c8bf6..77d070a17 100644 --- a/jumanji/training/networks/job_shop/actor_critic.py +++ b/jumanji/training/networks/job_shop/actor_critic.py @@ -42,7 +42,7 @@ def make_actor_critic_networks_job_shop( transformer_mlp_units: Sequence[int], ) -> ActorCriticNetworks: """Create an actor-critic network for the `JobShop` environment.""" - num_values = np.asarray(job_shop.action_spec().num_values) + num_values = np.asarray(job_shop.action_spec.num_values) parametric_action_distribution = MultiCategoricalParametricDistribution( num_values=num_values ) diff --git a/jumanji/training/networks/knapsack/actor_critic.py b/jumanji/training/networks/knapsack/actor_critic.py index 799a08e4d..b8a676e23 100644 --- a/jumanji/training/networks/knapsack/actor_critic.py +++ b/jumanji/training/networks/knapsack/actor_critic.py @@ -36,7 +36,7 @@ def make_actor_critic_networks_knapsack( transformer_mlp_units: Sequence[int], ) -> ActorCriticNetworks: """Make actor-critic networks for the `Knapsack` environment.""" - num_actions = knapsack.action_spec().num_values + num_actions = knapsack.action_spec.num_values parametric_action_distribution = CategoricalParametricDistribution( num_actions=num_actions ) diff --git a/jumanji/training/networks/maze/actor_critic.py b/jumanji/training/networks/maze/actor_critic.py index ae93b0286..8d236ef5a 100644 --- a/jumanji/training/networks/maze/actor_critic.py +++ 
b/jumanji/training/networks/maze/actor_critic.py @@ -37,7 +37,7 @@ def make_actor_critic_networks_maze( value_layers: Sequence[int], ) -> ActorCriticNetworks: """Make actor-critic networks for the `Maze` environment.""" - num_actions = np.asarray(maze.action_spec().num_values) + num_actions = np.asarray(maze.action_spec.num_values) parametric_action_distribution = CategoricalParametricDistribution( num_actions=num_actions ) diff --git a/jumanji/training/networks/minesweeper/actor_critic.py b/jumanji/training/networks/minesweeper/actor_critic.py index 3789be5f3..593673f9f 100644 --- a/jumanji/training/networks/minesweeper/actor_critic.py +++ b/jumanji/training/networks/minesweeper/actor_critic.py @@ -45,7 +45,7 @@ def make_actor_critic_networks_minesweeper( vocab_size = 1 + PATCH_SIZE**2 # unexplored, or 0, 1, ..., 8 parametric_action_distribution = FactorisedActionSpaceParametricDistribution( - action_spec_num_values=np.asarray(minesweeper.action_spec().num_values) + action_spec_num_values=np.asarray(minesweeper.action_spec.num_values) ) policy_network = make_network_cnn( vocab_size=vocab_size, diff --git a/jumanji/training/networks/minesweeper/random.py b/jumanji/training/networks/minesweeper/random.py index b7e80a3de..c7194091f 100644 --- a/jumanji/training/networks/minesweeper/random.py +++ b/jumanji/training/networks/minesweeper/random.py @@ -22,7 +22,7 @@ def make_random_policy_minesweeper(minesweeper: Minesweeper) -> RandomPolicy: """Make random policy for Minesweeper.""" - action_spec_num_values = minesweeper.action_spec().num_values + action_spec_num_values = minesweeper.action_spec.num_values return make_masked_categorical_random_ndim( action_spec_num_values=action_spec_num_values diff --git a/jumanji/training/networks/mmst/actor_critic.py b/jumanji/training/networks/mmst/actor_critic.py index ccc5d5674..45e776b4c 100644 --- a/jumanji/training/networks/mmst/actor_critic.py +++ b/jumanji/training/networks/mmst/actor_critic.py @@ -38,7 +38,7 @@ def 
make_actor_critic_networks_mmst( transformer_mlp_units: Sequence[int], ) -> ActorCriticNetworks: """Make actor-critic networks for the `MMST` environment.""" - num_values = mmst.action_spec().num_values + num_values = mmst.action_spec.num_values parametric_action_distribution = MultiCategoricalParametricDistribution( num_values=num_values ) @@ -96,12 +96,10 @@ def get_node_feats(node: chex.Array) -> chex.Array: return embeddings def embed_agents(self, agents: chex.Array) -> chex.Array: - embeddings = hk.Linear(self.model_size, name="agent_projection")(agents) return embeddings def __call__(self, observation: Observation) -> chex.Array: - batch_size, num_nodes = observation.node_types.shape num_agents = observation.positions.shape[1] agents_used = jnp.arange(num_agents).reshape(-1, 1) diff --git a/jumanji/training/networks/multi_cvrp/actor_critic.py b/jumanji/training/networks/multi_cvrp/actor_critic.py index 268426149..3300b7835 100644 --- a/jumanji/training/networks/multi_cvrp/actor_critic.py +++ b/jumanji/training/networks/multi_cvrp/actor_critic.py @@ -46,7 +46,7 @@ def make_actor_critic_networks_multicvrp( # Add depot to the number of customers num_customers += 1 - num_actions = MultiCVRP.action_spec().maximum + num_actions = MultiCVRP.action_spec.maximum parametric_action_distribution = MultiCategoricalParametricDistribution( num_values=np.asarray(num_actions).reshape(1) ) @@ -161,7 +161,6 @@ def customer_encoder( o_customers: chex.Array, v_embedding: chex.Array, ) -> chex.Array: - # Embed the depot differently # (B, C, D) depot_projection = hk.Linear(self.model_size, name="depot_projection") @@ -211,7 +210,6 @@ def vehicle_encoder( v_embedding: chex.Array, c_embedding: chex.Array, ) -> chex.Array: - # Projection of the operations embeddings = hk.Linear(self.model_size, name="o_vehicle_projections")( v_embedding diff --git a/jumanji/training/networks/pac_man/actor_critic.py b/jumanji/training/networks/pac_man/actor_critic.py index 566adf7b9..59ca353ad 100644 
--- a/jumanji/training/networks/pac_man/actor_critic.py +++ b/jumanji/training/networks/pac_man/actor_critic.py @@ -37,7 +37,7 @@ def make_actor_critic_networks_pacman( value_layers: Sequence[int], ) -> ActorCriticNetworks: """Make actor-critic networks for the `PacMan` environment.""" - num_actions = np.asarray(pac_man.action_spec().num_values) + num_actions = np.asarray(pac_man.action_spec.num_values) parametric_action_distribution = CategoricalParametricDistribution( num_actions=num_actions ) diff --git a/jumanji/training/networks/robot_warehouse/actor_critic.py b/jumanji/training/networks/robot_warehouse/actor_critic.py index a1aca10cd..965caf397 100644 --- a/jumanji/training/networks/robot_warehouse/actor_critic.py +++ b/jumanji/training/networks/robot_warehouse/actor_critic.py @@ -39,7 +39,7 @@ def make_actor_critic_networks_robot_warehouse( transformer_mlp_units: Sequence[int], ) -> ActorCriticNetworks: """Make actor-critic networks for the `RobotWarehouse` environment.""" - num_values = np.asarray(robot_warehouse.action_spec().num_values) + num_values = np.asarray(robot_warehouse.action_spec.num_values) parametric_action_distribution = MultiCategoricalParametricDistribution( num_values=num_values ) diff --git a/jumanji/training/networks/rubiks_cube/actor_critic.py b/jumanji/training/networks/rubiks_cube/actor_critic.py index 5e79d9e38..53a2643ac 100644 --- a/jumanji/training/networks/rubiks_cube/actor_critic.py +++ b/jumanji/training/networks/rubiks_cube/actor_critic.py @@ -37,7 +37,7 @@ def make_actor_critic_networks_rubiks_cube( dense_layer_dims: Sequence[int], ) -> ActorCriticNetworks: """Make actor-critic networks for the `RubiksCube` environment.""" - action_spec_num_values = np.asarray(rubiks_cube.action_spec().num_values) + action_spec_num_values = np.asarray(rubiks_cube.action_spec.num_values) num_actions = int(np.prod(action_spec_num_values)) parametric_action_distribution = FactorisedActionSpaceParametricDistribution( 
action_spec_num_values=action_spec_num_values diff --git a/jumanji/training/networks/rubiks_cube/random.py b/jumanji/training/networks/rubiks_cube/random.py index b8d18dd0d..3040db114 100644 --- a/jumanji/training/networks/rubiks_cube/random.py +++ b/jumanji/training/networks/rubiks_cube/random.py @@ -21,8 +21,8 @@ def make_random_policy_rubiks_cube(rubiks_cube: RubiksCube) -> RandomPolicy: """Make random policy for RubiksCube.""" - action_minimum = rubiks_cube.action_spec().minimum - action_maximum = rubiks_cube.action_spec().maximum + action_minimum = rubiks_cube.action_spec.minimum + action_maximum = rubiks_cube.action_spec.maximum def random_policy(observation: Observation, key: chex.PRNGKey) -> chex.Array: batch_size = observation.cube.shape[0] diff --git a/jumanji/training/networks/sliding_tile_puzzle/actor_critic.py b/jumanji/training/networks/sliding_tile_puzzle/actor_critic.py index 71625f7cc..5c4a6752c 100644 --- a/jumanji/training/networks/sliding_tile_puzzle/actor_critic.py +++ b/jumanji/training/networks/sliding_tile_puzzle/actor_critic.py @@ -40,7 +40,7 @@ def make_actor_critic_networks_sliding_tile_puzzle( value_layers: Sequence[int], ) -> ActorCriticNetworks: """Make actor-critic networks for the `SlidingTilePuzzle` environment.""" - num_actions = sliding_tile_puzzle.action_spec().num_values + num_actions = sliding_tile_puzzle.action_spec.num_values parametric_action_distribution = CategoricalParametricDistribution( num_actions=num_actions ) diff --git a/jumanji/training/networks/snake/actor_critic.py b/jumanji/training/networks/snake/actor_critic.py index f23906154..0be42e223 100644 --- a/jumanji/training/networks/snake/actor_critic.py +++ b/jumanji/training/networks/snake/actor_critic.py @@ -36,7 +36,7 @@ def make_actor_critic_networks_snake( value_layers: Sequence[int], ) -> ActorCriticNetworks: """Make actor-critic networks for the `Snake` environment.""" - num_actions = snake.action_spec().num_values + num_actions = snake.action_spec.num_values 
parametric_action_distribution = CategoricalParametricDistribution( num_actions=num_actions ) diff --git a/jumanji/training/networks/sokoban/actor_critic.py b/jumanji/training/networks/sokoban/actor_critic.py index 968180c60..37942b7dd 100644 --- a/jumanji/training/networks/sokoban/actor_critic.py +++ b/jumanji/training/networks/sokoban/actor_critic.py @@ -36,7 +36,7 @@ def make_actor_critic_networks_sokoban( value_layers: Sequence[int], ) -> ActorCriticNetworks: """Make actor-critic networks for the `Sokoban` environment.""" - num_actions = sokoban.action_spec().num_values + num_actions = sokoban.action_spec.num_values parametric_action_distribution = CategoricalParametricDistribution( num_actions=num_actions ) diff --git a/jumanji/training/networks/sudoku/actor_critic.py b/jumanji/training/networks/sudoku/actor_critic.py index f3bd1092b..8ba664e65 100644 --- a/jumanji/training/networks/sudoku/actor_critic.py +++ b/jumanji/training/networks/sudoku/actor_critic.py @@ -40,7 +40,7 @@ def make_cnn_actor_critic_networks_sudoku( ) -> ActorCriticNetworks: """Make actor-critic networks for the `Sudoku` environment. Uses the CNN network architecture.""" - num_actions = sudoku.action_spec().num_values + num_actions = sudoku.action_spec.num_values parametric_action_distribution = FactorisedActionSpaceParametricDistribution( action_spec_num_values=np.asarray(num_actions) ) @@ -71,7 +71,7 @@ def make_equivariant_actor_critic_networks_sudoku( ) -> ActorCriticNetworks: """Make actor-critic networks for the `Sudoku` environment. 
Uses the digits-permutation equivariant network architecture.""" - num_actions = sudoku.action_spec().num_values + num_actions = sudoku.action_spec.num_values parametric_action_distribution = FactorisedActionSpaceParametricDistribution( action_spec_num_values=np.asarray(num_actions) ) diff --git a/jumanji/training/networks/sudoku/random.py b/jumanji/training/networks/sudoku/random.py index 3b394d7b2..e2cce1fde 100644 --- a/jumanji/training/networks/sudoku/random.py +++ b/jumanji/training/networks/sudoku/random.py @@ -23,7 +23,7 @@ def make_random_policy_sudoku(sudoku: Sudoku) -> RandomPolicy: """Make random policy for the `Sudoku` environment.""" - action_spec_num_values = sudoku.action_spec().num_values + action_spec_num_values = sudoku.action_spec.num_values return make_masked_categorical_random_ndim( action_spec_num_values=action_spec_num_values diff --git a/jumanji/training/networks/tetris/actor_critic.py b/jumanji/training/networks/tetris/actor_critic.py index 5ac0488de..4e37052fd 100644 --- a/jumanji/training/networks/tetris/actor_critic.py +++ b/jumanji/training/networks/tetris/actor_critic.py @@ -39,7 +39,7 @@ def make_actor_critic_networks_tetris( """Make actor-critic networks for the `Tetris` environment.""" parametric_action_distribution = FactorisedActionSpaceParametricDistribution( - action_spec_num_values=np.asarray(tetris.action_spec().num_values) + action_spec_num_values=np.asarray(tetris.action_spec.num_values) ) policy_network = make_network_cnn( conv_num_channels=conv_num_channels, diff --git a/jumanji/training/networks/tetris/random.py b/jumanji/training/networks/tetris/random.py index e7410f35c..eef995792 100644 --- a/jumanji/training/networks/tetris/random.py +++ b/jumanji/training/networks/tetris/random.py @@ -21,7 +21,7 @@ def make_random_policy_tetris(tetris: Tetris) -> RandomPolicy: """Make random policy for `Tetris`.""" - action_spec_num_values = tetris.action_spec().num_values + action_spec_num_values = tetris.action_spec.num_values 
return make_masked_categorical_random_ndim( action_spec_num_values=action_spec_num_values ) diff --git a/jumanji/training/networks/tsp/actor_critic.py b/jumanji/training/networks/tsp/actor_critic.py index 6b0761411..cff891c5e 100644 --- a/jumanji/training/networks/tsp/actor_critic.py +++ b/jumanji/training/networks/tsp/actor_critic.py @@ -38,7 +38,7 @@ def make_actor_critic_networks_tsp( mean_cities_in_query: bool, ) -> ActorCriticNetworks: """Make actor-critic networks for the `TSP` environment.""" - num_actions = tsp.action_spec().num_values + num_actions = tsp.action_spec.num_values parametric_action_distribution = CategoricalParametricDistribution( num_actions=num_actions ) diff --git a/jumanji/wrappers.py b/jumanji/wrappers.py index 1945910c3..04a87db45 100644 --- a/jumanji/wrappers.py +++ b/jumanji/wrappers.py @@ -13,17 +13,8 @@ # limitations under the License. from __future__ import annotations -from typing import ( - Any, - Callable, - ClassVar, - Dict, - Generic, - Optional, - Tuple, - TypeVar, - Union, -) +from functools import cached_property +from typing import Any, Callable, ClassVar, Dict, Generic, Optional, Tuple, Union import chex import dm_env.specs @@ -33,23 +24,23 @@ import numpy as np from jumanji import specs, tree_utils -from jumanji.env import Environment, State +from jumanji.env import ActionSpec, Environment, Observation, State from jumanji.types import TimeStep -Observation = TypeVar("Observation") - # Type alias that corresponds to ObsType in the Gym API GymObservation = Any -class Wrapper(Environment[State], Generic[State]): +class Wrapper( + Environment[State, ActionSpec, Observation], Generic[State, ActionSpec, Observation] +): """Wraps the environment to allow modular transformations. 
Source: https://github.com/google/brax/blob/main/brax/envs/env.py#L72 """ - def __init__(self, env: Environment): - super().__init__() + def __init__(self, env: Environment[State, ActionSpec, Observation]): self._env = env + super().__init__() def __repr__(self) -> str: return f"{self.__class__.__name__}({repr(self._env)})" @@ -60,11 +51,11 @@ def __getattr__(self, name: str) -> Any: return getattr(self._env, name) @property - def unwrapped(self) -> Environment: + def unwrapped(self) -> Environment[State, ActionSpec, Observation]: """Returns the wrapped env.""" return self._env.unwrapped - def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep]: + def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]: """Resets the environment to an initial state. Args: @@ -76,7 +67,9 @@ def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep]: """ return self._env.reset(key) - def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep]: + def step( + self, state: State, action: chex.Array + ) -> Tuple[State, TimeStep[Observation]]: """Run one timestep of the environment's dynamics. 
Args: @@ -89,21 +82,25 @@ def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep]: """ return self._env.step(state, action) - def observation_spec(self) -> specs.Spec: + @cached_property + def observation_spec(self) -> specs.Spec[Observation]: """Returns the observation spec.""" - return self._env.observation_spec() + return self._env.observation_spec - def action_spec(self) -> specs.Spec: + @cached_property + def action_spec(self) -> ActionSpec: """Returns the action spec.""" - return self._env.action_spec() + return self._env.action_spec + @cached_property def reward_spec(self) -> specs.Array: """Returns the reward spec.""" - return self._env.reward_spec() + return self._env.reward_spec + @cached_property def discount_spec(self) -> specs.BoundedArray: """Returns the discount spec.""" - return self._env.discount_spec() + return self._env.discount_spec def render(self, state: State) -> Any: """Compute render frames during initialisation of the environment. @@ -128,10 +125,16 @@ def __exit__(self, *args: Any) -> None: self.close() -class JumanjiToDMEnvWrapper(dm_env.Environment): +class JumanjiToDMEnvWrapper( + dm_env.Environment, Generic[State, ActionSpec, Observation] +): """A wrapper that converts Environment to dm_env.Environment.""" - def __init__(self, env: Environment, key: Optional[chex.PRNGKey] = None): + def __init__( + self, + env: Environment[State, ActionSpec, Observation], + key: Optional[chex.PRNGKey] = None, + ): """Create the wrapped environment. Args: @@ -165,7 +168,7 @@ def reset(self) -> dm_env.TimeStep: - observation: A NumPy array, or a nested dict, list or tuple of arrays. Scalar values that can be cast to NumPy arrays (e.g. Python floats) are also valid in place of a scalar array. Must conform to the - specification returned by `observation_spec()`. + specification returned by `observation_spec`. 
""" reset_key, self._key = jax.random.split(self._key) self._state, timestep = self._jitted_reset(reset_key) @@ -184,21 +187,21 @@ def step(self, action: chex.ArrayNumpy) -> dm_env.TimeStep: Args: action: A NumPy array, or a nested dict, list or tuple of arrays - corresponding to `action_spec()`. + corresponding to `action_spec`. Returns: A `TimeStep` namedtuple containing: - step_type: A `StepType` value. - reward: Reward at this timestep, or None if step_type is `StepType.FIRST`. Must conform to the specification returned by - `reward_spec()`. + `reward_spec`. - discount: A discount in the range [0, 1], or None if step_type is `StepType.FIRST`. Must conform to the specification returned by - `discount_spec()`. + `discount_spec`. - observation: A NumPy array, or a nested dict, list or tuple of arrays. Scalar values that can be cast to NumPy arrays (e.g. Python floats) are also valid in place of a scalar array. Must conform to the - specification returned by `observation_spec()`. + specification returned by `observation_spec`. 
""" self._state, timestep = self._jitted_step(self._state, action) return dm_env.TimeStep( @@ -210,23 +213,25 @@ def step(self, action: chex.ArrayNumpy) -> dm_env.TimeStep: def observation_spec(self) -> dm_env.specs.Array: """Returns the dm_env observation spec.""" - return specs.jumanji_specs_to_dm_env_specs(self._env.observation_spec()) + return specs.jumanji_specs_to_dm_env_specs(self._env.observation_spec) def action_spec(self) -> dm_env.specs.Array: """Returns the dm_env action spec.""" - return specs.jumanji_specs_to_dm_env_specs(self._env.action_spec()) + return specs.jumanji_specs_to_dm_env_specs(self._env.action_spec) @property - def unwrapped(self) -> Environment: + def unwrapped(self) -> Environment[State, ActionSpec, Observation]: return self._env -class MultiToSingleWrapper(Wrapper): +class MultiToSingleWrapper( + Wrapper[State, ActionSpec, Observation], Generic[State, ActionSpec, Observation] +): """A wrapper that converts a multi-agent Environment to a single-agent Environment.""" def __init__( self, - env: Environment, + env: Environment[State, ActionSpec, Observation], reward_aggregator: Callable = jnp.sum, discount_aggregator: Callable = jnp.max, ): @@ -243,7 +248,9 @@ def __init__( self._reward_aggregator = reward_aggregator self._discount_aggregator = discount_aggregator - def _aggregate_timestep(self, timestep: TimeStep) -> TimeStep: + def _aggregate_timestep( + self, timestep: TimeStep[Observation] + ) -> TimeStep[Observation]: """Apply the reward and discount aggregator to a multi-agent timestep object to create a new timestep object that consists of a scalar reward and discount value. @@ -298,7 +305,9 @@ def step( return state, timestep -class VmapWrapper(Wrapper): +class VmapWrapper( + Wrapper[State, ActionSpec, Observation], Generic[State, ActionSpec, Observation] +): """Vectorized Jax env. Please note that all methods that return arrays do not return a batch dimension because the batch size is not known to the VmapWrapper. 
Methods that omit the batch dimension include: @@ -379,7 +388,9 @@ def add_obs_to_extras(timestep: TimeStep[Observation]) -> TimeStep[Observation]: return timestep.replace(extras=extras) # type: ignore -class AutoResetWrapper(Wrapper): +class AutoResetWrapper( + Wrapper[State, ActionSpec, Observation], Generic[State, ActionSpec, Observation] +): """Automatically resets environments that are done. Once the terminal state is reached, the state, observation, and step_type are reset. The observation and step_type of the terminal TimeStep is reset to the reset observation and StepType.LAST, respectively. @@ -390,7 +401,11 @@ class AutoResetWrapper(Wrapper): being processed each time `step` is called. Please use the `VmapAutoResetWrapper` instead. """ - def __init__(self, env: Environment, next_obs_in_extras: bool = False): + def __init__( + self, + env: Environment[State, ActionSpec, Observation], + next_obs_in_extras: bool = False, + ): """Wrap an environment to automatically reset it when the episode terminates. Args: @@ -453,7 +468,9 @@ def step( return state, timestep -class VmapAutoResetWrapper(Wrapper): +class VmapAutoResetWrapper( + Wrapper[State, ActionSpec, Observation], Generic[State, ActionSpec, Observation] +): """Efficient combination of VmapWrapper and AutoResetWrapper, to be used as a replacement of the combination of both wrappers. `env = VmapAutoResetWrapper(env)` is equivalent to `env = VmapWrapper(AutoResetWrapper(env))` @@ -465,7 +482,11 @@ class VmapAutoResetWrapper(Wrapper): NOTE: The observation from the terminal TimeStep is stored in timestep.extras["next_obs"]. """ - def __init__(self, env: Environment, next_obs_in_extras: bool = False): + def __init__( + self, + env: Environment[State, ActionSpec, Observation], + next_obs_in_extras: bool = False, + ): """Wrap an environment to vmap it and automatically reset it when the episode terminates. 
Args: @@ -529,7 +550,7 @@ def step( return state, timestep def _auto_reset( - self, state: State, timestep: TimeStep + self, state: State, timestep: TimeStep[Observation] ) -> Tuple[State, TimeStep[Observation]]: """Reset the state and overwrite `timestep.observation` with the reset observation if the episode has terminated. @@ -556,7 +577,7 @@ def _auto_reset( return state, timestep def _maybe_reset( - self, state: State, timestep: TimeStep + self, state: State, timestep: TimeStep[Observation] ) -> Tuple[State, TimeStep[Observation]]: """Overwrite the state and timestep appropriately if the episode terminates.""" state, timestep = jax.lax.cond( @@ -580,14 +601,19 @@ def render(self, state: State) -> Any: return super().render(state_0) -class JumanjiToGymWrapper(gym.Env): +class JumanjiToGymWrapper(gym.Env, Generic[State, ActionSpec, Observation]): """A wrapper that converts a Jumanji `Environment` to one that follows the `gym.Env` API.""" # Flag that prevents `gym.register` from misinterpreting the `_step` and # `_reset` as signs of a deprecated gym Env API. _gym_disable_underscore_compat: ClassVar[bool] = True - def __init__(self, env: Environment, seed: int = 0, backend: Optional[str] = None): + def __init__( + self, + env: Environment[State, ActionSpec, Observation], + seed: int = 0, + backend: Optional[str] = None, + ): """Create the Gym environment. 
Args: @@ -601,9 +627,9 @@ def __init__(self, env: Environment, seed: int = 0, backend: Optional[str] = Non self.backend = backend self._state = None self.observation_space = specs.jumanji_specs_to_gym_spaces( - self._env.observation_spec() + self._env.observation_spec ) - self.action_space = specs.jumanji_specs_to_gym_spaces(self._env.action_spec()) + self.action_space = specs.jumanji_specs_to_gym_spaces(self._env.action_spec) def reset(key: chex.PRNGKey) -> Tuple[State, Observation, Optional[Dict]]: """Reset function of a Jumanji environment to be jitted.""" @@ -691,6 +717,8 @@ def render(self, mode: str = "human") -> Any: mode: currently not used since Jumanji does not currently support modes. """ del mode + if self._state is None: + raise ValueError("Cannot render when _state is None.") return self._env.render(self._state) def close(self) -> None: @@ -698,7 +726,7 @@ def close(self) -> None: self._env.close() @property - def unwrapped(self) -> Environment: + def unwrapped(self) -> Environment[State, ActionSpec, Observation]: return self._env diff --git a/jumanji/wrappers_test.py b/jumanji/wrappers_test.py index 556ff54ca..53d914f3a 100644 --- a/jumanji/wrappers_test.py +++ b/jumanji/wrappers_test.py @@ -13,7 +13,7 @@ # limitations under the License. 
from collections import namedtuple -from typing import Tuple, Type, TypeVar +from typing import Tuple, Type import chex import dm_env.specs @@ -43,13 +43,12 @@ jumanji_to_gym_obs, ) -State = TypeVar("State") -Observation = TypeVar("Observation") +FakeWrapper = Wrapper[FakeState, specs.BoundedArray, chex.Array] @pytest.fixture -def mock_wrapper_class() -> Type[Wrapper]: - class MockWrapper(Wrapper[FakeState]): +def mock_wrapper_class() -> Type[FakeWrapper]: + class MockWrapper(Wrapper[FakeState, specs.BoundedArray, chex.Array]): pass return MockWrapper @@ -70,13 +69,13 @@ class TestBaseWrapper: @pytest.fixture def wrapped_fake_environment( - self, mock_wrapper_class: Type[Wrapper], fake_environment: FakeEnvironment - ) -> Wrapper: + self, mock_wrapper_class: Type[FakeWrapper], fake_environment: FakeEnvironment + ) -> FakeWrapper: wrapped_env = mock_wrapper_class(fake_environment) return wrapped_env def test_wrapper__unwrapped( - self, wrapped_fake_environment: Wrapper, fake_environment: FakeEnvironment + self, wrapped_fake_environment: FakeWrapper, fake_environment: FakeEnvironment ) -> None: """Checks `Wrapper.unwrapped` returns the unwrapped env.""" assert wrapped_fake_environment.unwrapped is fake_environment @@ -84,7 +83,7 @@ def test_wrapper__unwrapped( def test_wrapper__step( self, mocker: pytest_mock.MockerFixture, - wrapped_fake_environment: Wrapper, + wrapped_fake_environment: FakeWrapper, fake_environment: FakeEnvironment, ) -> None: """Checks `Wrapper.step` calls the step method of the underlying env.""" @@ -99,7 +98,7 @@ def test_wrapper__step( def test_wrapper__reset( self, mocker: pytest_mock.MockerFixture, - wrapped_fake_environment: Wrapper, + wrapped_fake_environment: FakeWrapper, fake_environment: FakeEnvironment, ) -> None: """Checks `Wrapper.reset` calls the reset method of the underlying env.""" @@ -113,36 +112,38 @@ def test_wrapper__reset( def test_wrapper__observation_spec( self, mocker: pytest_mock.MockerFixture, - wrapped_fake_environment: 
Wrapper, + mock_wrapper_class: Type[FakeWrapper], fake_environment: FakeEnvironment, ) -> None: - """Checks `Wrapper.observation_spec` calls the observation_spec function of - the underlying env. - """ + """Checks `Wrapper.__init__` calls the observation_spec function of the underlying env.""" mock_obs_spec = mocker.patch.object( - fake_environment, "observation_spec", autospec=True + FakeEnvironment, "observation_spec", new_callable=mocker.PropertyMock ) - wrapped_fake_environment.observation_spec() + wrapped_fake_environment = mock_wrapper_class(fake_environment) + mock_obs_spec.assert_called_once() + wrapped_fake_environment.observation_spec mock_obs_spec.assert_called_once() def test_wrapper__action_spec( self, mocker: pytest_mock.MockerFixture, - wrapped_fake_environment: Wrapper, + mock_wrapper_class: Type[FakeWrapper], fake_environment: FakeEnvironment, ) -> None: - """Checks `Wrapper.action_spec` calls the action_spec function of the underlying env.""" + """Checks `Wrapper.__init__` calls the action_spec function of the underlying env.""" mock_action_spec = mocker.patch.object( - fake_environment, "action_spec", autospec=True + FakeEnvironment, "action_spec", new_callable=mocker.PropertyMock ) - wrapped_fake_environment.action_spec() + wrapped_fake_environment = mock_wrapper_class(fake_environment) + mock_action_spec.assert_called_once() + wrapped_fake_environment.action_spec mock_action_spec.assert_called_once() - def test_wrapper__repr(self, wrapped_fake_environment: Wrapper) -> None: + def test_wrapper__repr(self, wrapped_fake_environment: FakeWrapper) -> None: """Checks `Wrapper.__repr__` returns the expected representation string.""" repr_str = repr(wrapped_fake_environment) assert "MockWrapper" in repr_str @@ -150,7 +151,7 @@ def test_wrapper__repr(self, wrapped_fake_environment: Wrapper) -> None: def test_wrapper__render( self, mocker: pytest_mock.MockerFixture, - wrapped_fake_environment: Wrapper, + wrapped_fake_environment: FakeWrapper, 
fake_environment: FakeEnvironment, ) -> None: """Checks `Wrapper.render` calls the render method of the underlying env.""" @@ -167,7 +168,7 @@ def test_wrapper__render( def test_wrapper__close( self, mocker: pytest_mock.MockerFixture, - wrapped_fake_environment: Wrapper, + wrapped_fake_environment: FakeWrapper, fake_environment: FakeEnvironment, ) -> None: """Checks `Wrapper.close` calls the close method of the underlying env.""" @@ -179,13 +180,18 @@ def test_wrapper__close( mock_action_spec.assert_called_once() def test_wrapper__getattr( - self, wrapped_fake_environment: Wrapper, fake_environment: FakeEnvironment + self, wrapped_fake_environment: FakeWrapper, fake_environment: FakeEnvironment ) -> None: """Checks `Wrapper.__getattr__` calls the underlying env for unknown attr.""" # time_limit is defined in the mock env assert wrapped_fake_environment.time_limit == fake_environment.time_limit +FakeJumanjiToDMEnvWrapper = JumanjiToDMEnvWrapper[ + FakeState, specs.BoundedArray, chex.Array +] + + class TestJumanjiEnvironmentToDeepMindEnv: """Test the JumanjiEnvironmentToDeepMindEnv that transforms an Environment into a dm_env.Environment format. 
@@ -205,14 +211,14 @@ def test_jumanji_environment_to_deep_mind_env__init( ) assert isinstance(dm_environment_with_key, dm_env.Environment) - def test_dm_env__reset(self, fake_dm_env: JumanjiToDMEnvWrapper) -> None: + def test_dm_env__reset(self, fake_dm_env: FakeJumanjiToDMEnvWrapper) -> None: """Validates reset function and timestep type of the wrapped environment.""" timestep = fake_dm_env.reset() assert isinstance(timestep, dm_env.TimeStep) assert timestep.step_type == dm_env.StepType.FIRST def test_jumanji_environment_to_deep_mind_env__step( - self, fake_dm_env: JumanjiToDMEnvWrapper + self, fake_dm_env: FakeJumanjiToDMEnvWrapper ) -> None: """Validates step function of the wrapped environment.""" timestep = fake_dm_env.reset() @@ -221,31 +227,34 @@ def test_jumanji_environment_to_deep_mind_env__step( assert next_timestep != timestep def test_jumanji_environment_to_deep_mind_env__observation_spec( - self, fake_dm_env: JumanjiToDMEnvWrapper + self, fake_dm_env: FakeJumanjiToDMEnvWrapper ) -> None: """Validates observation_spec property of the wrapped environment.""" assert isinstance(fake_dm_env.observation_spec(), dm_env.specs.Array) def test_jumanji_environment_to_deep_mind_env__action_spec( - self, fake_dm_env: JumanjiToDMEnvWrapper + self, fake_dm_env: FakeJumanjiToDMEnvWrapper ) -> None: """Validates action_spec property of the wrapped environment.""" assert isinstance(fake_dm_env.action_spec(), dm_env.specs.Array) def test_jumanji_environment_to_deep_mind_env__unwrapped( - self, fake_dm_env: JumanjiToDMEnvWrapper + self, fake_dm_env: FakeJumanjiToDMEnvWrapper ) -> None: """Validates unwrapped property of the wrapped environment.""" assert isinstance(fake_dm_env.unwrapped, Environment) +FakeJumanjiToGymWrapper = JumanjiToGymWrapper[FakeState, specs.BoundedArray, chex.Array] + + class TestJumanjiEnvironmentToGymEnv: """ Test the JumanjiEnvironmentToGymEnv that transforms an Environment into a gym.Env format. 
""" @pytest.fixture - def fake_gym_env(self, time_limit: int = 10) -> gym.Env: + def fake_gym_env(self, time_limit: int = 10) -> FakeJumanjiToGymWrapper: """Creates a fake environment wrapped as a gym.Env.""" return JumanjiToGymWrapper(FakeEnvironment(time_limit=time_limit)) @@ -259,12 +268,12 @@ def test_jumanji_environment_to_gym_env__init( assert isinstance(gym_environment_with_seed, gym.Env) def test_jumanji_environment_to_gym_env__reset( - self, fake_gym_env: JumanjiToGymWrapper + self, fake_gym_env: FakeJumanjiToGymWrapper ) -> None: """Validates reset function of the wrapped environment.""" - observation1 = fake_gym_env.reset() # type: ignore + observation1 = fake_gym_env.reset() state1 = fake_gym_env._state - observation2 = fake_gym_env.reset() # type: ignore + observation2 = fake_gym_env.reset() state2 = fake_gym_env._state # Observation is typically numpy array @@ -276,24 +285,24 @@ def test_jumanji_environment_to_gym_env__reset( assert_trees_are_different(state1, state2) def test_jumanji_environment_to_gym_env__step( - self, fake_gym_env: JumanjiToGymWrapper + self, fake_gym_env: FakeJumanjiToGymWrapper ) -> None: """Validates step function of the wrapped environment.""" - observation = fake_gym_env.reset() # type: ignore + observation = fake_gym_env.reset() action = fake_gym_env.action_space.sample() - next_observation, reward, terminated, info = fake_gym_env.step(action) # type: ignore + next_observation, reward, terminated, info = fake_gym_env.step(action) assert_trees_are_different(observation, next_observation) assert isinstance(reward, float) assert isinstance(terminated, bool) def test_jumanji_environment_to_gym_env__observation_space( - self, fake_gym_env: JumanjiToGymWrapper + self, fake_gym_env: FakeJumanjiToGymWrapper ) -> None: """Validates observation_space attribute of the wrapped environment.""" assert isinstance(fake_gym_env.observation_space, gym.spaces.Space) def test_jumanji_environment_to_gym_env__action_space( - self, fake_gym_env: 
JumanjiToGymWrapper + self, fake_gym_env: FakeJumanjiToGymWrapper ) -> None: """Validates action_space attribute of the wrapped environment.""" assert isinstance(fake_gym_env.action_space, gym.spaces.Space) @@ -301,14 +310,16 @@ def test_jumanji_environment_to_gym_env__action_space( def test_jumanji_environment_to_gym_env__render( self, mocker: pytest_mock.MockerFixture, - fake_gym_env: JumanjiToGymWrapper, + fake_gym_env: FakeJumanjiToGymWrapper, ) -> None: - mock_render = mocker.patch.object( fake_gym_env.unwrapped, "render", autospec=True ) mock_state = mocker.MagicMock() + with pytest.raises(ValueError): + fake_gym_env.render(mock_state) + fake_gym_env.reset() fake_gym_env.render(mock_state) mock_render.assert_called_once() @@ -316,9 +327,8 @@ def test_jumanji_environment_to_gym_env__render( def test_jumanji_environment_to_gym_env__close( self, mocker: pytest_mock.MockerFixture, - fake_gym_env: JumanjiToGymWrapper, + fake_gym_env: FakeJumanjiToGymWrapper, ) -> None: - mock_close = mocker.patch.object(fake_gym_env.unwrapped, "close", autospec=True) fake_gym_env.close() @@ -326,17 +336,22 @@ def test_jumanji_environment_to_gym_env__close( mock_close.assert_called_once() def test_jumanji_environment_to_gym_env__unwrapped( - self, fake_gym_env: JumanjiToGymWrapper + self, fake_gym_env: FakeJumanjiToGymWrapper ) -> None: """Validates unwrapped property of the wrapped environment.""" assert isinstance(fake_gym_env.unwrapped, Environment) +FakeMultiToSingleWrapper = MultiToSingleWrapper[ + FakeState, specs.BoundedArray, chex.Array +] + + class TestMultiToSingleEnvironment: @pytest.fixture def fake_multi_to_single_env( self, fake_multi_environment: FakeMultiEnvironment - ) -> MultiToSingleWrapper: + ) -> FakeMultiToSingleWrapper: """Creates a fake wrapper that converts a multi-agent Environment to a single-agent Environment.""" return MultiToSingleWrapper(fake_multi_environment) @@ -344,7 +359,7 @@ def fake_multi_to_single_env( def test_multi_env_wrapper__init( self, 
fake_multi_environment: FakeMultiEnvironment, - fake_multi_to_single_env: MultiToSingleWrapper, + fake_multi_to_single_env: FakeMultiToSingleWrapper, ) -> None: """Validates initialization of the multi agent to single agent wrapper.""" single_agent_env = MultiToSingleWrapper(fake_multi_environment) @@ -353,7 +368,7 @@ def test_multi_env_wrapper__init( def test_multi_env__reset( self, fake_multi_environment: FakeMultiEnvironment, - fake_multi_to_single_env: MultiToSingleWrapper, + fake_multi_to_single_env: FakeMultiToSingleWrapper, key: chex.PRNGKey, ) -> None: """Validates (jitted) reset function and timestep type of the multi agent @@ -369,14 +384,14 @@ def test_multi_env__reset( def test_multi_env__step( self, fake_multi_environment: FakeMultiEnvironment, - fake_multi_to_single_env: MultiToSingleWrapper, + fake_multi_to_single_env: FakeMultiToSingleWrapper, key: chex.PRNGKey, ) -> None: """Validates (jitted) step function of the multi agent to single agent wrapped environment. """ - state, timestep = fake_multi_to_single_env.reset(key) # type: ignore - action = fake_multi_to_single_env.action_spec().generate_value() + state, timestep = fake_multi_to_single_env.reset(key) + action = fake_multi_to_single_env.action_spec.generate_value() state, next_timestep = jax.jit(fake_multi_to_single_env.step)(state, action) assert next_timestep != timestep assert next_timestep.reward.shape == () @@ -391,18 +406,16 @@ def test_multi_env__step( def test_multi_env__different_reward_aggregator( self, fake_multi_environment: FakeMultiEnvironment, - fake_multi_to_single_env: MultiToSingleWrapper, + fake_multi_to_single_env: FakeMultiToSingleWrapper, key: chex.PRNGKey, ) -> None: """Checks that using a different reward aggregator is correct.""" mean_fake_multi_to_single_env = MultiToSingleWrapper( fake_multi_environment, reward_aggregator=jnp.mean ) - state, timestep = mean_fake_multi_to_single_env.reset(key) # type: ignore - action = 
mean_fake_multi_to_single_env.action_spec().generate_value() - state, next_timestep = mean_fake_multi_to_single_env.step( - state, action - ) # type: Tuple[FakeState, TimeStep] + state, timestep = mean_fake_multi_to_single_env.reset(key) + action = mean_fake_multi_to_single_env.action_spec.generate_value() + state, next_timestep = mean_fake_multi_to_single_env.step(state, action) assert next_timestep != timestep assert next_timestep.reward.shape == () assert next_timestep.reward == fake_multi_environment.reward_per_step @@ -412,31 +425,32 @@ def test_multi_env__different_reward_aggregator( def test_multi_env__observation_spec( self, fake_multi_environment: FakeMultiEnvironment, - fake_multi_to_single_env: MultiToSingleWrapper, + fake_multi_to_single_env: FakeMultiToSingleWrapper, ) -> None: """Validates observation_spec property of the multi agent to single agent wrapped environment. """ - obs_spec: specs.Array = fake_multi_to_single_env.observation_spec() # type: ignore + obs_spec: specs.Array = fake_multi_to_single_env.observation_spec # type: ignore assert isinstance(obs_spec, specs.Array) - assert obs_spec.shape == fake_multi_environment.observation_spec().shape + multi_obs_spec: specs.Array = fake_multi_environment.observation_spec # type: ignore + assert obs_spec.shape == multi_obs_spec.shape def test_multi_env__action_spec( self, fake_multi_environment: FakeMultiEnvironment, - fake_multi_to_single_env: MultiToSingleWrapper, + fake_multi_to_single_env: FakeMultiToSingleWrapper, ) -> None: """Validates action_spec property of the multi agent to single agent wrapped environment. 
""" - action_spec: specs.Array = fake_multi_to_single_env.action_spec() # type: ignore - assert isinstance(fake_multi_to_single_env.action_spec(), specs.Array) - assert action_spec.shape == fake_multi_environment.action_spec().shape + action_spec = fake_multi_to_single_env.action_spec + assert isinstance(fake_multi_to_single_env.action_spec, specs.Array) + assert action_spec.shape == fake_multi_environment.action_spec.shape def test_multi_env__unwrapped( self, fake_multi_environment: FakeMultiEnvironment, - fake_multi_to_single_env: MultiToSingleWrapper, + fake_multi_to_single_env: FakeMultiToSingleWrapper, ) -> None: """Validates unwrapped property of the multi agent to single agent wrapped environment. @@ -445,9 +459,14 @@ def test_multi_env__unwrapped( assert fake_multi_to_single_env._env is fake_multi_environment +FakeVmapWrapper = Wrapper[FakeState, specs.BoundedArray, chex.Array] + + class TestVmapWrapper: @pytest.fixture - def fake_vmap_environment(self, fake_environment: FakeEnvironment) -> VmapWrapper: + def fake_vmap_environment( + self, fake_environment: FakeEnvironment + ) -> FakeVmapWrapper: return VmapWrapper(fake_environment) def test_vmap_wrapper__init(self, fake_environment: FakeEnvironment) -> None: @@ -456,7 +475,7 @@ def test_vmap_wrapper__init(self, fake_environment: FakeEnvironment) -> None: assert isinstance(vmap_env, Environment) def test_vmap_env__reset( - self, fake_vmap_environment: VmapWrapper, keys: chex.PRNGKey + self, fake_vmap_environment: FakeVmapWrapper, keys: chex.PRNGKey ) -> None: """Validates reset function and timestep type of the vmap wrapped environment.""" _, timestep = jax.jit(fake_vmap_environment.reset)(keys) @@ -468,19 +487,15 @@ def test_vmap_env__reset( assert timestep.discount.shape == (keys.shape[0],) def test_vmap_env__step( - self, fake_vmap_environment: VmapWrapper, keys: chex.PRNGKey + self, fake_vmap_environment: FakeVmapWrapper, keys: chex.PRNGKey ) -> None: """Validates step function of the vmap 
environment.""" - state, timestep = fake_vmap_environment.reset( + state, timestep = fake_vmap_environment.reset(keys) + action = jax.vmap(lambda _: fake_vmap_environment.action_spec.generate_value())( keys - ) # type: Tuple[FakeState, TimeStep] - action = jax.vmap( - lambda _: fake_vmap_environment.action_spec().generate_value() - )(keys) + ) - state, next_timestep = jax.jit(fake_vmap_environment.step)( - state, action - ) # type: Tuple[FakeState, TimeStep] + state, next_timestep = jax.jit(fake_vmap_environment.step)(state, action) assert_trees_are_different(next_timestep, timestep) chex.assert_trees_all_equal(next_timestep.reward, 0) @@ -490,45 +505,46 @@ def test_vmap_env__step( assert next_timestep.observation.shape[0] == keys.shape[0] def test_vmap_env__render( - self, fake_vmap_environment: VmapWrapper, keys: chex.PRNGKey + self, fake_vmap_environment: FakeVmapWrapper, keys: chex.PRNGKey ) -> None: - states, _ = fake_vmap_environment.reset( - keys - ) # type: Tuple[FakeState, TimeStep] + states, _ = fake_vmap_environment.reset(keys) result = fake_vmap_environment.render(states) assert result == (keys.shape[1:], ()) def test_vmap_env__unwrapped( - self, fake_environment: Environment, fake_vmap_environment: VmapWrapper + self, fake_environment: FakeEnvironment, fake_vmap_environment: FakeVmapWrapper ) -> None: """Validates unwrapped property of the vmap environment.""" assert isinstance(fake_vmap_environment.unwrapped, Environment) assert fake_vmap_environment._env is fake_environment +FakeAutoResetWrapper = AutoResetWrapper[FakeState, specs.BoundedArray, chex.Array] + + class TestAutoResetWrapper: @pytest.fixture def fake_auto_reset_environment( - self, fake_environment: Environment - ) -> AutoResetWrapper: + self, fake_environment: FakeEnvironment + ) -> FakeAutoResetWrapper: return AutoResetWrapper(fake_environment, next_obs_in_extras=True) @pytest.fixture def fake_state_and_timestep( - self, fake_auto_reset_environment: AutoResetWrapper, key: chex.PRNGKey - 
) -> Tuple[State, TimeStep[Observation]]: + self, fake_auto_reset_environment: FakeAutoResetWrapper, key: chex.PRNGKey + ) -> Tuple[FakeState, TimeStep[chex.Array]]: state, timestep = jax.jit(fake_auto_reset_environment.reset)(key) return state, timestep - def test_auto_reset_wrapper__init(self, fake_environment: Environment) -> None: + def test_auto_reset_wrapper__init(self, fake_environment: FakeEnvironment) -> None: """Validates initialization of the AutoResetWrapper.""" auto_reset_env = AutoResetWrapper(fake_environment) assert isinstance(auto_reset_env, Environment) def test_auto_reset_wrapper__auto_reset( self, - fake_auto_reset_environment: AutoResetWrapper, - fake_state_and_timestep: Tuple[State, TimeStep[Observation]], + fake_auto_reset_environment: FakeAutoResetWrapper, + fake_state_and_timestep: Tuple[FakeState, TimeStep[chex.Array]], ) -> None: """Validates the auto_reset function of the AutoResetWrapper.""" state, timestep = fake_state_and_timestep @@ -540,21 +556,19 @@ def test_auto_reset_wrapper__auto_reset( assert jnp.all(timestep.observation == timestep.extras[NEXT_OBS_KEY_IN_EXTRAS]) def test_auto_reset_wrapper__step_no_reset( - self, fake_auto_reset_environment: AutoResetWrapper, key: chex.PRNGKey + self, fake_auto_reset_environment: FakeAutoResetWrapper, key: chex.PRNGKey ) -> None: """Validates that step function of the AutoResetWrapper does not do an auto-reset when the terminal state is not reached. 
""" - state, first_timestep = fake_auto_reset_environment.reset( - key - ) # type: Tuple[FakeState, TimeStep] + state, first_timestep = fake_auto_reset_environment.reset(key) # Generate an action - action = fake_auto_reset_environment.action_spec().generate_value() + action = fake_auto_reset_environment.action_spec.generate_value() state, timestep = jax.jit(fake_auto_reset_environment.step)( state, action - ) # type: Tuple[FakeState, TimeStep] + ) # type: Tuple[FakeState, TimeStep[chex.Array]] assert timestep.step_type == StepType.MID assert_trees_are_different(timestep, first_timestep) @@ -565,7 +579,7 @@ def test_auto_reset_wrapper__step_no_reset( def test_auto_reset_wrapper__step_reset( self, fake_environment: FakeEnvironment, - fake_auto_reset_environment: AutoResetWrapper, + fake_auto_reset_environment: FakeAutoResetWrapper, key: chex.PRNGKey, ) -> None: """Validates that the auto-reset is done correctly by the step function @@ -577,7 +591,7 @@ def test_auto_reset_wrapper__step_reset( # Loop across time_limit so auto-reset occurs for _ in range(fake_environment.time_limit - 1): - action = fake_auto_reset_environment.action_spec().generate_value() + action = fake_auto_reset_environment.action_spec.generate_value() state, timestep = jax.jit(fake_auto_reset_environment.step)(state, action) assert jnp.all( timestep.observation == timestep.extras[NEXT_OBS_KEY_IN_EXTRAS] @@ -597,19 +611,26 @@ def test_auto_reset_wrapper__step_reset( ) +FakeVmapAutoResetWrapper = VmapAutoResetWrapper[ + FakeState, specs.BoundedArray, chex.Array +] + + class TestVmapAutoResetWrapper: @pytest.fixture def fake_vmap_auto_reset_environment( self, fake_environment: FakeEnvironment - ) -> VmapAutoResetWrapper: + ) -> FakeVmapAutoResetWrapper: return VmapAutoResetWrapper(fake_environment, next_obs_in_extras=True) @pytest.fixture def action( - self, fake_vmap_auto_reset_environment: VmapAutoResetWrapper, keys: chex.PRNGKey + self, + fake_vmap_auto_reset_environment: FakeVmapAutoResetWrapper, 
+ keys: chex.PRNGKey, ) -> chex.Array: generate_action_fn = ( - lambda _: fake_vmap_auto_reset_environment.action_spec().generate_value() + lambda _: fake_vmap_auto_reset_environment.action_spec.generate_value() ) return jax.vmap(generate_action_fn)(keys) @@ -621,7 +642,9 @@ def test_vmap_auto_reset_wrapper__init( assert isinstance(vmap_auto_reset_env, Environment) def test_vmap_auto_reset_wrapper__reset( - self, fake_vmap_auto_reset_environment: VmapAutoResetWrapper, keys: chex.PRNGKey + self, + fake_vmap_auto_reset_environment: FakeVmapAutoResetWrapper, + keys: chex.PRNGKey, ) -> None: """Validates reset function and timestep type of the wrapper.""" _, timestep = jax.jit(fake_vmap_auto_reset_environment.reset)(keys) @@ -636,11 +659,11 @@ def test_vmap_auto_reset_wrapper__reset( def test_vmap_auto_reset_wrapper__auto_reset( self, - fake_vmap_auto_reset_environment: VmapAutoResetWrapper, + fake_vmap_auto_reset_environment: FakeVmapAutoResetWrapper, keys: chex.PRNGKey, ) -> None: """Validates the auto_reset function of the wrapper.""" - state, timestep = fake_vmap_auto_reset_environment.reset(keys) # type: ignore + state, timestep = fake_vmap_auto_reset_environment.reset(keys) _, reset_timestep = jax.lax.map( lambda args: fake_vmap_auto_reset_environment._auto_reset(*args), (state, timestep), @@ -653,11 +676,11 @@ def test_vmap_auto_reset_wrapper__auto_reset( def test_vmap_auto_reset_wrapper__maybe_reset( self, - fake_vmap_auto_reset_environment: VmapAutoResetWrapper, + fake_vmap_auto_reset_environment: FakeVmapAutoResetWrapper, keys: chex.PRNGKey, ) -> None: """Validates the auto_reset function of the wrapper.""" - state, timestep = fake_vmap_auto_reset_environment.reset(keys) # type: ignore + state, timestep = fake_vmap_auto_reset_environment.reset(keys) _, reset_timestep = jax.lax.map( lambda args: fake_vmap_auto_reset_environment._maybe_reset(*args), (state, timestep), @@ -670,14 +693,14 @@ def test_vmap_auto_reset_wrapper__maybe_reset( def 
test_vmap_auto_reset_wrapper__step_no_reset( self, - fake_vmap_auto_reset_environment: VmapAutoResetWrapper, + fake_vmap_auto_reset_environment: FakeVmapAutoResetWrapper, keys: chex.PRNGKey, action: chex.Array, ) -> None: """Validates that step function of the wrapper does not do an auto-reset when the terminal state is not reached. """ - state, first_timestep = fake_vmap_auto_reset_environment.reset(keys) # type: ignore + state, first_timestep = fake_vmap_auto_reset_environment.reset(keys) state, timestep = jax.jit(fake_vmap_auto_reset_environment.step)(state, action) assert jnp.all(timestep.step_type == StepType.MID) @@ -693,15 +716,15 @@ def test_vmap_auto_reset_wrapper__step_no_reset( def test_vmap_auto_reset_wrapper__step_reset( self, - fake_environment: Environment, - fake_vmap_auto_reset_environment: VmapAutoResetWrapper, + fake_environment: FakeEnvironment, + fake_vmap_auto_reset_environment: FakeVmapAutoResetWrapper, keys: chex.PRNGKey, action: chex.Array, ) -> None: """Validates that the auto-reset is done correctly by the step function of the wrapper when the terminal timestep is reached. 
""" - state, first_timestep = fake_vmap_auto_reset_environment.reset(keys) # type: ignore + state, first_timestep = fake_vmap_auto_reset_environment.reset(keys) fake_vmap_auto_reset_environment.unwrapped.time_limit = 5 # type: ignore # Loop across time_limit so auto-reset occurs @@ -729,12 +752,12 @@ def test_vmap_auto_reset_wrapper__step_reset( def test_vmap_auto_reset_wrapper__step( self, - fake_vmap_auto_reset_environment: VmapAutoResetWrapper, + fake_vmap_auto_reset_environment: FakeVmapAutoResetWrapper, keys: chex.PRNGKey, action: chex.Array, ) -> None: """Validates step function of the vmap environment.""" - state, timestep = fake_vmap_auto_reset_environment.reset(keys) # type: ignore + state, timestep = fake_vmap_auto_reset_environment.reset(keys) state, next_timestep = jax.jit(fake_vmap_auto_reset_environment.step)( state, action ) @@ -750,16 +773,18 @@ def test_vmap_auto_reset_wrapper__step( ) def test_vmap_auto_reset_wrapper__render( - self, fake_vmap_auto_reset_environment: VmapAutoResetWrapper, keys: chex.PRNGKey + self, + fake_vmap_auto_reset_environment: FakeVmapAutoResetWrapper, + keys: chex.PRNGKey, ) -> None: - states, _ = fake_vmap_auto_reset_environment.reset(keys) # type: ignore + states, _ = fake_vmap_auto_reset_environment.reset(keys) result = fake_vmap_auto_reset_environment.render(states) assert result == (keys.shape[1:], ()) def test_vmap_auto_reset_wrapper__unwrapped( self, fake_environment: FakeEnvironment, - fake_vmap_auto_reset_environment: VmapAutoResetWrapper, + fake_vmap_auto_reset_environment: FakeVmapAutoResetWrapper, ) -> None: """Validates unwrapped property of the vmap environment.""" assert isinstance(fake_vmap_auto_reset_environment.unwrapped, FakeEnvironment) @@ -817,9 +842,8 @@ def test_jumanji_to_gym_obs__wrong_observation(self) -> None: def test_jumanji_to_gym_obs__bin_pack(self) -> None: """Check that an example bin_pack observation is correctly converted.""" - env = BinPack(obs_num_ems=1) - env.generator = 
bin_pack_conftest.DummyGenerator() - obs = env.observation_spec().generate_value() + env = BinPack(generator=bin_pack_conftest.DummyGenerator(), obs_num_ems=1) + obs = env.observation_spec.generate_value() converted_obs = jumanji_to_gym_obs(obs) correct_obs = { diff --git a/requirements/requirements-dev.txt b/requirements/requirements-dev.txt index ea0b437c2..58afa9227 100644 --- a/requirements/requirements-dev.txt +++ b/requirements/requirements-dev.txt @@ -25,5 +25,5 @@ pytype scipy>=1.7.3 testfixtures types-Pillow -types-requests +types-requests<1.27 types-setuptools