Skip to content

Commit

Permalink
Merge pull request #69 from emnigma/server_restart
Browse files Browse the repository at this point in the history
Update mlSearcher
  • Loading branch information
gsvgit authored Aug 15, 2023
2 parents a569f0c + d969c2a commit 7961a55
Show file tree
Hide file tree
Showing 29 changed files with 960 additions and 391 deletions.
13 changes: 6 additions & 7 deletions VSharp.ML.AIAgent/common/constants.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from pathlib import Path

import torch

from config import BrokerConfig


Expand All @@ -14,10 +12,10 @@ def _build_bar_format() -> str:


IMPORTED_FULL_MODEL_PATH = Path(
"ml/imported/GNN_state_pred_het_full_TAGConv_20e_2xAll_10h"
"ml/imported/GNN_state_pred_het_full_StateGNNEncoderConvEdgeAttr_32ch.zip"
)
IMPORTED_DICT_MODEL_PATH = Path(
"ml/imported/GNN_state_pred_het_dict_TAGConv_20e_2xAll_10h"
"ml/imported/GNN_state_pred_het_dict_StateGNNEncoderConvEdgeAttr_32ch.zip"
)

BASE_REPORT_DIR = Path("./report")
Expand All @@ -32,9 +30,6 @@ def _build_bar_format() -> str:
"dynamic_ncols": True,
}

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
BASE_NN_OUT_FEATURES_NUM = 8


class WebsocketSourceLinks:
GET_WS = f"http://0.0.0.0:{BrokerConfig.BROKER_PORT}/get_ws"
Expand All @@ -49,3 +44,7 @@ class ResultsHandlerLinks:
DUMMY_INPUT_PATH = Path("ml/onnx/dummy_input.json")
BEST_MODEL_ONNX_SAVE_PATH = Path("ml/onnx/StateModelEncoder.onnx")
TEMP_EPOCH_INFERENCE_TIMES_DIR = Path(".epoch_inference_times/")
BASE_NN_OUT_FEATURES_NUM = 8

# assuming we start from /VSharp/VSharp.ML.AIAgent
SERVER_WORKING_DIR = "../VSharp.ML.GameServer.Runner/bin/Release/net6.0/"
47 changes: 41 additions & 6 deletions VSharp.ML.AIAgent/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,29 @@
from pathlib import Path
from shutil import rmtree

import torch

import ml.model_modified
import ml.models


class GeneralConfig:
SERVER_COUNT = 8
MAX_STEPS = 3000
SERVER_COUNT = 16
NUM_GENERATIONS = 20
NUM_PARENTS_MATING = 22
KEEP_ELITISM = 2
NUM_RANDOM_SOLUTIONS = 60
NUM_RANDOM_LAST_LAYER = 18
MAX_STEPS = 5000
MUTATION_PERCENT_GENES = 5
LOGGER_LEVEL = logging.INFO
MODEL_INIT = lambda: ml.models.SAGEConvModel(16)
IMPORT_MODEL_INIT = lambda: ml.models.StateModelEncoder(
hidden_channels=32, out_channels=8
)
EXPORT_MODEL_INIT = lambda: ml.model_modified.StateModelEncoderExport(
hidden_channels=32, out_channels=8
)
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


class BrokerConfig:
Expand All @@ -24,7 +39,7 @@ class ServerConfig:
@dataclass(slots=True, frozen=True)
class DumpByTimeoutFeature:
enabled: bool
timeout_seconds: int
timeout_sec: int
save_path: Path

def create_save_path_if_not_exists(self):
Expand All @@ -46,15 +61,35 @@ def create_save_path_if_not_exists(self):
self.save_path.mkdir()


@dataclass(slots=True, frozen=True)
class OnGameServerRestartFeature:
enabled: bool
wait_for_reset_retries: int
wait_for_reset_time: float


class FeatureConfig:
VERBOSE_TABLES = True
SHOW_SUCCESSORS = True
NAME_LEN = 7
N_BEST_SAVED_EACH_GEN = 2
DISABLE_MESSAGE_CHECKS = True
DUMP_BY_TIMEOUT = DumpByTimeoutFeature(
enabled=True, timeout_seconds=1200, save_path=Path("./report/timeouted_agents/")
enabled=True, timeout_sec=1800, save_path=Path("./report/timeouted_agents/")
)
SAVE_EPOCHS_COVERAGES = SaveEpochsCoveragesFeature(
enabled=True, save_path=Path("./report/epochs_tables/")
)
ON_GAME_SERVER_RESTART = OnGameServerRestartFeature(
enabled=True, wait_for_reset_retries=10 * 60, wait_for_reset_time=0.1
)


class GameServerConnectorConfig:
CREATE_CONNECTION_TIMEOUT_SEC = 1
WAIT_FOR_SOCKET_RECONNECTION_MAX_RETRIES = 10 * 60
RESPONCE_TIMEOUT_SEC = (
FeatureConfig.DUMP_BY_TIMEOUT.timeout_sec + 1
if FeatureConfig.DUMP_BY_TIMEOUT.enabled
else 1000
)
SKIP_UTF_VALIDATION = True
13 changes: 12 additions & 1 deletion VSharp.ML.AIAgent/connection/broker_conn/classes.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass, field
from typing import Callable
from typing import Callable, TypeAlias

from dataclasses_json import config, dataclass_json

Expand All @@ -8,6 +8,17 @@
from connection.game_server_conn.unsafe_json import asdict
from ml.model_wrappers.nnwrapper import NNWrapper, decode, encode

WSUrl: TypeAlias = str
Undefined: TypeAlias = None


@dataclass_json
@dataclass(slots=True, frozen=True)
class ServerInstanceInfo:
port: int
ws_url: WSUrl
pid: int | Undefined


def custom_encoder_if_disable_message_checks() -> Callable | None:
return asdict if FeatureConfig.DISABLE_MESSAGE_CHECKS else None
Expand Down
32 changes: 16 additions & 16 deletions VSharp.ML.AIAgent/connection/broker_conn/requests.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,36 @@
import json
import logging

import httplib2

from common.constants import ResultsHandlerLinks, WebsocketSourceLinks

from .classes import Agent2ResultsOnMaps
from .classes import Agent2ResultsOnMaps, ServerInstanceInfo


def aquire_ws() -> str:
while True:
response, content = httplib2.Http().request(WebsocketSourceLinks.GET_WS)
aquired_ws_url = content.decode("utf-8")
if aquired_ws_url == "":
logging.warning(f"all sockets are in use")
continue
logging.info(f"aquired ws: {aquired_ws_url}")
return aquired_ws_url
def acquire_instance() -> ServerInstanceInfo:
response, content = httplib2.Http().request(WebsocketSourceLinks.GET_WS)
if response.status != 200:
logging.error(f"{response.status} with {content=} on acuire_instance call")
raise RuntimeError(f"Not ok response: {response}, {content}")
aquired_instance = ServerInstanceInfo.from_json(json.loads(content.decode("utf-8")))
logging.info(f"acquired ws: {aquired_instance}")
return aquired_instance


def return_ws(ws_url: str):
logging.info(f"returning: {ws_url}")
def return_instance(instance: ServerInstanceInfo):
logging.info(f"returning: {instance}")

response, content = httplib2.Http().request(
WebsocketSourceLinks.POST_WS,
method="POST",
body=ws_url,
body=instance.to_json(),
)

if response.status == 200:
logging.info(f"{ws_url} is returned")
logging.info(f"{instance} is returned")
else:
logging.error(f"{response.status} on returning {ws_url}")
logging.error(f"{response.status} on returning {instance}")
raise RuntimeError(f"Not ok response: {response.status}")


Expand All @@ -51,5 +51,5 @@ def send_game_results(data: Agent2ResultsOnMaps):
def recv_game_result_list() -> str:
response, content = httplib2.Http().request(ResultsHandlerLinks.GET_RES)
games_data = content.decode("utf-8")
logging.info(f"Aquired games data")
logging.info(f"Acquired games data")
return games_data
48 changes: 43 additions & 5 deletions VSharp.ML.AIAgent/connection/broker_conn/socket_manager.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,53 @@
from contextlib import contextmanager
import logging
import os
import time
from contextlib import contextmanager, suppress

import websocket
from .requests import aquire_ws, return_ws

from config import GameServerConnectorConfig
from connection.broker_conn.classes import ServerInstanceInfo

from .requests import acquire_instance, return_instance


def wait_for_connection(server_instance: ServerInstanceInfo):
ws = websocket.WebSocket()

retries_left = GameServerConnectorConfig.WAIT_FOR_SOCKET_RECONNECTION_MAX_RETRIES

while retries_left:
with suppress(
ConnectionRefusedError,
ConnectionResetError,
websocket.WebSocketTimeoutException,
):
ws.settimeout(GameServerConnectorConfig.CREATE_CONNECTION_TIMEOUT_SEC)
ws.connect(
server_instance.ws_url,
skip_utf8_validation=GameServerConnectorConfig.SKIP_UTF_VALIDATION,
)
if ws.connected:
return ws
time.sleep(GameServerConnectorConfig.CREATE_CONNECTION_TIMEOUT_SEC)
logging.info(
f"Try connecting to {server_instance.ws_url}, {retries_left} attempts left; {server_instance}"
)
retries_left -= 1
raise RuntimeError(
f"Retries exsausted wnen trying to connect to {server_instance.ws_url}: {retries_left} left"
)


@contextmanager
def game_server_socket_manager():
socket_url = aquire_ws()
socket = websocket.create_connection(socket_url, skip_utf8_validation=True)
server_instance = acquire_instance()

socket = wait_for_connection(server_instance)

try:
socket.settimeout(GameServerConnectorConfig.RESPONCE_TIMEOUT_SEC)
yield socket
finally:
socket.close()
return_ws(socket_url)
return_instance(server_instance)
13 changes: 7 additions & 6 deletions VSharp.ML.AIAgent/connection/game_server_conn/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,20 +103,21 @@ def send_step(self, next_state_id: int, predicted_usefullness: int):
self._sent_state_id = next_state_id

def recv_reward_or_throw_gameover(self) -> Reward:
data = RewardServerMessage.from_json_handle(
self._raise_if_gameover(self.ws.recv()),
received = self.ws.recv()
decoded = RewardServerMessage.from_json_handle(
self._raise_if_gameover(received),
expected=RewardServerMessage,
)
logging.debug(f"<-- MoveReward : {data.MessageBody}")
logging.debug(f"<-- MoveReward : {decoded.MessageBody}")

return self._process_reward_server_message(data)
return self._process_reward_server_message(decoded)

def _process_reward_server_message(self, msg):
match msg.MessageType:
case ServerMessageType.INCORRECT_PREDICTED_STATEID:
raise Connector.IncorrectSentStateError(
f"Sending state_id={self._sent_state_id} \
at step #{self._current_step} resulted in {msg.MessageType}"
f"Sending state_id={self._sent_state_id} "
f"at step #{self._current_step} resulted in {msg.MessageType}"
)

case ServerMessageType.MOVE_REVARD:
Expand Down
9 changes: 8 additions & 1 deletion VSharp.ML.AIAgent/epochs_statistics/tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,14 @@ def create_pivot_table(
for mutable2result in mutable2result_list:
name_results_dict[map_obj.Id].append(convert_to_view_model(mutable2result))
epoch_percents_dict[map_obj.Id].append(
mutable2result.game_result.actual_coverage_percent
str(
(
mutable2result.game_result.actual_coverage_percent,
mutable2result.game_result.tests_count,
mutable2result.game_result.errors_count,
mutable2result.game_result.steps_count,
)
)
)

mutable_names = get_model_names_in_order(name_results_dict)
Expand Down
1 change: 1 addition & 0 deletions VSharp.ML.AIAgent/install_script.sh
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
conda install numpy pandas tabulate
conda install -c pytorch pytorch=1.13.1 torchvision=0.14.1 torchaudio=0.13.1
python3 -m pip install --force-reinstall -v "torch-scatter==2.1.0" "torch-geometric==2.2.0" "torch-sparse==0.6.16"
python3 -m pip install func_timeout
conda install -c conda-forge dataclasses-json websocket-client pre_commit aiohttp cchardet pygad httplib2 onnx onnxruntime
pre-commit install
Loading

0 comments on commit 7961a55

Please sign in to comment.