Skip to content

Commit

Permalink
Merge pull request #365 from Limmen/fix_stopping_tests
Browse files Browse the repository at this point in the history
update stopping_game tests
  • Loading branch information
Limmen authored May 30, 2024
2 parents f2db9d5 + 2b996bf commit 92d12d8
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 80 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from unittest.mock import patch, MagicMock
from gymnasium.spaces import Box, Discrete
import numpy as np
from gym_csle_stopping_game.util.stopping_game_util import StoppingGameUtil
from gym_csle_stopping_game.envs.stopping_game_env import StoppingGameEnv
from gym_csle_stopping_game.dao.stopping_game_config import StoppingGameConfig
from gym_csle_stopping_game.dao.stopping_game_state import StoppingGameState
Expand All @@ -23,19 +24,19 @@ def setup_env(self) -> None:
:return: None
"""
env_name = "test_env"
T = np.array([[[0.1, 0.9], [0.4, 0.6]], [[0.7, 0.3], [0.2, 0.8]]])
O = np.array([0, 1])
Z = np.array([[[0.8, 0.2], [0.5, 0.5]], [[0.4, 0.6], [0.9, 0.1]]])
T = StoppingGameUtil.transition_tensor(L=3, p=0)
O = StoppingGameUtil.observation_space(n=100)
Z = StoppingGameUtil.observation_tensor(n=100)
R = np.zeros((2, 3, 3, 3))
S = np.array([0, 1, 2])
A1 = np.array([0, 1, 2])
A2 = np.array([0, 1, 2])
S = StoppingGameUtil.state_space()
A1 = StoppingGameUtil.defender_actions()
A2 = StoppingGameUtil.attacker_actions()
L = 2
R_INT = 1
R_COST = 2
R_SLA = 3
R_ST = 4
b1 = np.array([0.6, 0.4])
b1 = StoppingGameUtil.b1()
save_dir = "save_directory"
checkpoint_traces_freq = 100
gamma = 0.9
Expand Down Expand Up @@ -69,12 +70,12 @@ def test_stopping_game_init_(self) -> None:
:return: None
"""
T = np.array([[[0.1, 0.9], [0.4, 0.6]], [[0.7, 0.3], [0.2, 0.8]]])
O = np.array([0, 1])
A1 = np.array([0, 1, 2])
A2 = np.array([0, 1, 2])
T = StoppingGameUtil.transition_tensor(L=3, p=0)
O = StoppingGameUtil.observation_space(n=100)
A1 = StoppingGameUtil.defender_actions()
A2 = StoppingGameUtil.attacker_actions()
L = 2
b1 = np.array([0.6, 0.4])
b1 = StoppingGameUtil.b1()
attacker_observation_space = Box(
low=np.array([0.0, 0.0, 0.0]),
high=np.array([float(L), 1.0, 2.0]),
Expand Down Expand Up @@ -304,7 +305,7 @@ def test_is_state_terminal(self) -> None:
assert not env.is_state_terminal(state_tuple)

with pytest.raises(ValueError):
env.is_state_terminal([1, 2, 3]) # type: ignore
env.is_state_terminal([1, 2, 3]) # type: ignore

def test_get_observation_from_history(self) -> None:
"""
Expand Down Expand Up @@ -346,26 +347,6 @@ def test_step(self) -> None:
:return: None
"""
env = StoppingGameEnv(self.config)
env.state = MagicMock()
env.state.s = 1
env.state.l = 2
env.state.t = 0
env.state.attacker_observation.return_value = np.array([1, 2, 3])
env.state.defender_observation.return_value = np.array([4, 5, 6])
env.state.b = np.array([0.5, 0.5, 0.0])

env.trace = MagicMock()
env.trace.defender_rewards = []
env.trace.attacker_rewards = []
env.trace.attacker_actions = []
env.trace.defender_actions = []
env.trace.infos = []
env.trace.states = []
env.trace.beliefs = []
env.trace.infrastructure_metrics = []
env.trace.attacker_observations = []
env.trace.defender_observations = []

with patch("gym_csle_stopping_game.util.stopping_game_util.StoppingGameUtil.sample_next_state",
return_value=2):
with patch("gym_csle_stopping_game.util.stopping_game_util.StoppingGameUtil.sample_next_observation",
Expand All @@ -376,32 +357,20 @@ def test_step(self) -> None:
1,
(
np.array(
[[0.2, 0.8, 0.0], [0.6, 0.4, 0.0], [0.5, 0.5, 0.0]]
[[0.2, 0.8], [0.6, 0.4], [0.5, 0.5]]
),
2,
),
)
observations, rewards, terminated, truncated, info = env.step(
action_profile
)

assert (observations[0] == np.array([4, 5, 6])).all(), "Incorrect defender observations"
assert (observations[1] == np.array([1, 2, 3])).all(), "Incorrect attacker observations"
assert observations[0].all() == np.array([1, 0.7]).all(), "Incorrect defender observations"
assert observations[1].all() == np.array([1, 2, 3]).all(), "Incorrect attacker observations"
assert rewards == (0, 0)
assert not terminated
assert not truncated
assert env.trace.defender_rewards[-1] == 0
assert env.trace.attacker_rewards[-1] == 0
assert env.trace.attacker_actions[-1] == 2
assert env.trace.defender_actions[-1] == 1
assert env.trace.infos[-1] == info
assert env.trace.states[-1] == 2
print(env.trace.beliefs)
assert env.trace.beliefs[-1] == 0.7
assert env.trace.infrastructure_metrics[-1] == 1
assert (env.trace.attacker_observations[-1] == np.array([1, 2, 3])).all()
assert (env.trace.defender_observations[-1] == np.array([4, 5, 6])).all()


def test_info(self) -> None:
"""
Tests the function of adding the cumulative reward and episode length to the info dict
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,12 @@
from gym_csle_stopping_game.dao.stopping_game_attacker_mdp_config import (
StoppingGameAttackerMdpConfig,
)
from gym_csle_stopping_game.util.stopping_game_util import StoppingGameUtil
from gym_csle_stopping_game.envs.stopping_game_env import StoppingGameEnv
from csle_common.dao.training.policy import Policy
from csle_common.dao.training.random_policy import RandomPolicy
from csle_common.dao.training.player_type import PlayerType
from csle_common.dao.simulation_config.action import Action
import pytest
from unittest.mock import MagicMock
import numpy as np
Expand All @@ -25,19 +29,19 @@ def setup_env(self) -> None:
:return: None
"""
env_name = "test_env"
T = np.array([[[0.1, 0.9], [0.4, 0.6]], [[0.7, 0.3], [0.2, 0.8]]])
O = np.array([0, 1])
Z = np.array([[[0.8, 0.2], [0.5, 0.5]], [[0.4, 0.6], [0.9, 0.1]]])
T = StoppingGameUtil.transition_tensor(L=3, p=0)
O = StoppingGameUtil.observation_space(n=100)
Z = StoppingGameUtil.observation_tensor(n=100)
R = np.zeros((2, 3, 3, 3))
S = np.array([0, 1, 2])
A1 = np.array([0, 1, 2])
A2 = np.array([0, 1, 2])
S = StoppingGameUtil.state_space()
A1 = StoppingGameUtil.defender_actions()
A2 = StoppingGameUtil.attacker_actions()
L = 2
R_INT = 1
R_COST = 2
R_SLA = 3
R_ST = 4
b1 = np.array([0.6, 0.4])
b1 = StoppingGameUtil.b1()
save_dir = "save_directory"
checkpoint_traces_freq = 100
gamma = 0.9
Expand Down Expand Up @@ -107,9 +111,8 @@ def test_reset(self) -> None:
)

env = StoppingGameMdpAttackerEnv(config=attacker_mdp_config)
attacker_obs, info = env.reset()
assert env.latest_defender_obs.all() == np.array([2, 0.4]).all() # type: ignore
assert info == {}
info = env.reset()
assert info[-1] == {}

def test_set_model(self) -> None:
"""
Expand Down Expand Up @@ -144,7 +147,7 @@ def test_set_state(self) -> None:
)

env = StoppingGameMdpAttackerEnv(config=attacker_mdp_config)
assert not env.set_state(1) # type: ignore
assert not env.set_state(1) # type: ignore

def test_calculate_stage_policy(self) -> None:
"""
Expand Down Expand Up @@ -190,7 +193,7 @@ def test_get_attacker_dist(self) -> None:
def test_render(self) -> None:
"""
Tests the function for rendering the environment
:return: None
"""
defender_strategy = MagicMock(spec=Policy)
Expand Down Expand Up @@ -317,7 +320,7 @@ def test_get_actions_from_particles(self) -> None:
particles = [1, 2, 3]
t = 0
observation = 0
expected_actions = [0, 1, 2]
expected_actions = [0, 1]
assert (
env.get_actions_from_particles(particles, t, observation)
== expected_actions
Expand All @@ -326,18 +329,32 @@ def test_get_actions_from_particles(self) -> None:
def test_step(self) -> None:
"""
Tests the function for taking a step in the environment by executing the given action
:return: None
"""
defender_strategy = MagicMock(spec=Policy)
defender_stage_strategy = np.zeros((3, 2))
defender_stage_strategy[0][0] = 0.9
defender_stage_strategy[0][1] = 0.1
defender_stage_strategy[1][0] = 0.9
defender_stage_strategy[1][1] = 0.1
defender_actions = list(map(lambda x: Action(id=x, descr=""), self.config.A1))
defender_strategy = RandomPolicy(
actions=defender_actions,
player_type=PlayerType.DEFENDER,
stage_policy_tensor=list(defender_stage_strategy),
)
attacker_mdp_config = StoppingGameAttackerMdpConfig(
env_name="test_env",
stopping_game_config=self.config,
defender_strategy=defender_strategy,
stopping_game_name="csle-stopping-game-v1",
)

env = StoppingGameMdpAttackerEnv(config=attacker_mdp_config)
pi2 = np.array([[0.5, 0.5]])
with pytest.raises(AssertionError):
env.step(pi2)
env.reset()
pi2 = env.calculate_stage_policy(o=list(env.latest_attacker_obs), a2=0) # type: ignore
attacker_obs, reward, terminated, truncated, info = env.step(pi2)
assert isinstance(attacker_obs[0], float) # type: ignore
assert isinstance(terminated, bool) # type: ignore
assert isinstance(truncated, bool) # type: ignore
assert isinstance(reward, float) # type: ignore
assert isinstance(info, dict) # type: ignore
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
from gym_csle_stopping_game.envs.stopping_game_pomdp_defender_env import StoppingGamePomdpDefenderEnv
from gym_csle_stopping_game.envs.stopping_game_pomdp_defender_env import (
StoppingGamePomdpDefenderEnv,
)
from gym_csle_stopping_game.dao.stopping_game_config import StoppingGameConfig
from gym_csle_stopping_game.dao.stopping_game_defender_pomdp_config import StoppingGameDefenderPomdpConfig
from gym_csle_stopping_game.dao.stopping_game_defender_pomdp_config import (
StoppingGameDefenderPomdpConfig,
)
from gym_csle_stopping_game.envs.stopping_game_env import StoppingGameEnv
from gym_csle_stopping_game.util.stopping_game_util import StoppingGameUtil
from csle_common.dao.training.policy import Policy
from csle_common.dao.simulation_config.action import Action
from csle_common.dao.training.random_policy import RandomPolicy
from csle_common.dao.training.player_type import PlayerType
import pytest
Expand Down Expand Up @@ -219,7 +224,7 @@ def test_set_state(self) -> None:
stopping_game_name="csle-stopping-game-v1",
)
env = StoppingGamePomdpDefenderEnv(config=defender_pomdp_config)
assert env.set_state(1) is None # type: ignore
assert env.set_state(1) is None # type: ignore

def test_get_observation_from_history(self) -> None:
"""
Expand Down Expand Up @@ -301,7 +306,10 @@ def test_get_actions_from_particles(self) -> None:
t = 0
observation = 0
expected_actions = [0, 1]
assert env.get_actions_from_particles(particles, t, observation) == expected_actions
assert (
env.get_actions_from_particles(particles, t, observation)
== expected_actions
)

def test_step(self) -> None:
"""
Expand All @@ -315,8 +323,12 @@ def test_step(self) -> None:
attacker_stage_strategy[1][0] = 0.9
attacker_stage_strategy[1][1] = 0.1
attacker_stage_strategy[2] = attacker_stage_strategy[1]
attacker_strategy = RandomPolicy(actions=list(self.config.A2), player_type=PlayerType.ATTACKER,
stage_policy_tensor=list(attacker_stage_strategy))
attacker_actions = list(map(lambda x: Action(id=x, descr=""), self.config.A2))
attacker_strategy = RandomPolicy(
actions=attacker_actions,
player_type=PlayerType.ATTACKER,
stage_policy_tensor=list(attacker_stage_strategy),
)
defender_pomdp_config = StoppingGameDefenderPomdpConfig(
env_name="test_env",
stopping_game_config=self.config,
Expand All @@ -328,9 +340,9 @@ def test_step(self) -> None:
env.reset()
defender_obs, reward, terminated, truncated, info = env.step(a1)
assert len(defender_obs) == 2
assert isinstance(defender_obs[0], float) # type: ignore
assert isinstance(defender_obs[1], float) # type: ignore
assert isinstance(reward, float) # type: ignore
assert isinstance(terminated, bool) # type: ignore
assert isinstance(truncated, bool) # type: ignore
assert isinstance(info, dict) # type: ignore
assert isinstance(defender_obs[0], float) # type: ignore
assert isinstance(defender_obs[1], float) # type: ignore
assert isinstance(reward, float) # type: ignore
assert isinstance(terminated, bool) # type: ignore
assert isinstance(truncated, bool) # type: ignore
assert isinstance(info, dict) # type: ignore

0 comments on commit 92d12d8

Please sign in to comment.