From 4cbd7dd005391f6101d575dad0f415f4c8521bfa Mon Sep 17 00:00:00 2001 From: Anya497 Date: Thu, 20 Jul 2023 14:38:02 +0300 Subject: [PATCH 01/47] added common model and new predictor for training this model --- VSharp.ML.AIAgent/ml/common_model.py | 183 +++++++++++++++++++++++++++ 1 file changed, 183 insertions(+) create mode 100644 VSharp.ML.AIAgent/ml/common_model.py diff --git a/VSharp.ML.AIAgent/ml/common_model.py b/VSharp.ML.AIAgent/ml/common_model.py new file mode 100644 index 000000000..c0bf2c7a5 --- /dev/null +++ b/VSharp.ML.AIAgent/ml/common_model.py @@ -0,0 +1,183 @@ +import os.path +from collections import namedtuple + +import torch +import torch.nn.functional as F +from common.constants import DEVICE +from ml import data_loader_compact +from ml.models import GNN_Het +from torch_geometric.data import HeteroData +from torch_geometric.loader import DataLoader +from torch_geometric.nn import to_hetero +from torch import nn +from torch.nn import Linear +from torch_geometric.nn import ( + GATConv, + GatedGraphConv, + GCNConv, + HeteroConv, + Linear, + ResGatedGraphConv, + SAGEConv, + TAGConv, + TransformerConv, + global_mean_pool, + to_hetero, +) +from torchvision.ops import MLP + +from timer.wrapper import timeit +from conn.socket_manager import game_server_socket_manager +from ml.model_wrappers.protocols import Predictor + + + +class CommonModel(torch.nn.Module): + def __init__( + self, + hidden_channels, + num_gv_layers=2, + num_sv_layers=2, + num_history_layers=2, + num_in_layers=2, + ): + super().__init__() + self.gv_layers = nn.ModuleList() + self.gv_layers.append(SAGEConv(-1, hidden_channels)) + for i in range(num_gv_layers - 1): + sage_gv = SAGEConv(-1, hidden_channels) + self.gv_layers.append(sage_gv) + + self.sv_layers = nn.ModuleList() + self.sv_layers.append(SAGEConv(-1, hidden_channels)) + for i in range(num_sv_layers - 1): + sage_sv = SAGEConv(-1, hidden_channels) + self.sv_layers.append(sage_sv) + + self.history1 = GATConv((-1, -1), hidden_channels, add_self_loops=False) + + self.in1 = SAGEConv((-1, -1), hidden_channels) + + self.sv_layers2 = nn.ModuleList() + self.sv_layers2.append(SAGEConv(-1, hidden_channels)) + for i in range(num_sv_layers - 1): + sage_sv = SAGEConv(-1, hidden_channels) + self.sv_layers2.append(sage_sv) + + self.mlp = MLP(hidden_channels, [1]) + + @timeit + def forward(self, x_dict, edge_index_dict, edge_attr_dict): + # print(x_dict) + # print(edge_attr_dict) + game_x = self.gv_layers[0]( + x_dict["game_vertex"], + edge_index_dict[("game_vertex", "to", "game_vertex")], + ).relu() + for layer in self.gv_layers[1:]: + game_x = layer( + game_x, + edge_index_dict[("game_vertex", "to", "game_vertex")], + ).relu() + # print(game_x.size()) + + state_x = self.sv_layers[0]( + x_dict["state_vertex"], + edge_index_dict[("state_vertex", "parent_of", "state_vertex")], + ).relu() + for layer in self.sv_layers[1:]: + state_x = layer( + state_x, + edge_index_dict[("state_vertex", "parent_of", "state_vertex")], + ).relu() + + history_x = self.history1( + (game_x, state_x), + edge_index_dict[("game_vertex", "history", "state_vertex")], + edge_attr_dict, + size=(game_x.size(0), state_x.size(0)) + ).relu() + + # history_x = self.history1( + # (state_x, game_x), + # edge_index_dict[("state_vertex", "history", "game_vertex")], + # edge_attr_dict, + # size=(state_x.size(0), game_x.size(0)) + # ).relu() + + # history_x = self.history1((game_x, state_x), + # edge_index_dict[("game_vertex", "history", "state_vertex")]).relu() + #history_x = self.history2((history_x, game_x), + # edge_index_dict[("state_vertex", "history", "game_vertex")]).relu() + + in_x = self.in1( + (game_x, history_x), + edge_index_dict[("game_vertex", "in", "state_vertex")] + ).relu() + + state_x = self.sv_layers2[0]( + in_x, + edge_index_dict[("state_vertex", "parent_of", "state_vertex")], + ).relu() + for layer in self.sv_layers2[1:]: + state_x = layer( + state_x, + edge_index_dict[("state_vertex", "parent_of", "state_vertex")], + ).relu() + x = self.mlp(in_x) + return x + + +def euclidean_dist(y_pred, y_true): + y_pred_min, ind1 = torch.min(y_pred, dim=0) + y_pred_norm = y_pred - y_pred_min + + y_true_min, ind1 = torch.min(y_true, dim=0) + y_true_norm = y_true - y_true_min + + return torch.sqrt(torch.sum((y_pred_norm - y_true_norm) ** 2)) + + +lr = 0.0001 +model = CommonModel(32) +model.to(DEVICE) +optimizer = torch.optim.Adam(model.parameters(), lr=lr) +criterion = euclidean_dist + + +class CommonModelPredictor(Predictor): + def __init__(self, model: torch.nn.Module, best_models: dict) -> None: + self.model = model + self._name = str(sum(weights_flat)) + self._hash = tuple(weights_flat).__hash__() + + def name(self) -> str: + return self._name + + def predict(self, input: GameState): + hetero_input, state_map = ServerDataloaderHeteroVector.convert_input_to_tensor( + input + ) + assert self.model is not None + + next_step_id = PredictStateVectorHetGNN.predict_state_single_out( + self.model, hetero_input, state_map + ) + + back_prop(best_models[input.map_name], self.model, hetero_input) + + del hetero_input + return next_step_id + + +def back_prop(best_model, model, data): + model.train() + data.to(DEVICE) + optimizer.zero_grad() + + out = model(data.x_dict, data.edge_index_dict, data.edge_attr_dict) + y_true = best_model(data.x_dict, data.edge_index_dict, data.edge_attr_dict) + + loss = criterion(out, y_true) + loss.backward() + optimizer.step() From 346a1c7b5552390bbd353cc04663d66e3e7c6663 Mon Sep 17 00:00:00 2001 From: Anya497 Date: Fri, 21 Jul 2023 11:35:27 +0300 Subject: [PATCH 02/47] fixed play_map --- VSharp.ML.AIAgent/learning/r_learn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/VSharp.ML.AIAgent/learning/r_learn.py b/VSharp.ML.AIAgent/learning/r_learn.py index 1e8b16566..04bcf8401 100644 --- a/VSharp.ML.AIAgent/learning/r_learn.py +++ b/VSharp.ML.AIAgent/learning/r_learn.py @@ -80,7 +80,7 @@ def play_map( tests_count=0, errors_count=0, actual_coverage_percent=0, - ) + ), perf_counter() - start_time if gameover.actual_coverage is not None: actual_coverage = gameover.actual_coverage From a11b1f6c5578d475cebdce2ea84b217cb2666e94 Mon Sep 17 00:00:00 2001 From: Anya497 Date: Fri, 21 Jul 2023 11:53:31 +0300 Subject: [PATCH 03/47] Ignore new python env --- .gitignore | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 0c89a657c..b6f0db2fa 100644 --- a/.gitignore +++ b/.gitignore @@ -244,12 +244,12 @@ VSharp.Test/GeneratedTests/** #Pytjon environments **/.env/** +/torch_venv/** #Python caches **/__pycache__/** #logs **/logs_full*/** - /output/** /References/*.dll From 8da717c674523d13bc0e6e7ad2afb5f553115389 Mon Sep 17 00:00:00 2001 From: Anya497 Date: Fri, 21 Jul 2023 15:24:28 +0300 Subject: [PATCH 04/47] added some functions for CommonModel training --- VSharp.ML.AIAgent/learning/play_game.py | 4 +- VSharp.ML.AIAgent/ml/common_model.py | 129 ++++++++++-------- .../ml/model_wrappers/protocols.py | 2 +- 3 files changed, 77 insertions(+), 58 deletions(-) diff --git a/VSharp.ML.AIAgent/learning/play_game.py b/VSharp.ML.AIAgent/learning/play_game.py index 06791f8dc..79e95f5d6 100644 --- a/VSharp.ML.AIAgent/learning/play_game.py +++ b/VSharp.ML.AIAgent/learning/play_game.py @@ -34,7 +34,7 @@ def play_map( try: for _ in range(steps): game_state = with_connector.recv_state_or_throw_gameover() - predicted_state_id = with_predictor.predict(game_state) + predicted_state_id = with_predictor.predict(game_state, with_connector.map.MapName) logging.debug( f"<{with_predictor.name()}> step: {steps_count}, available states: {get_states(game_state)}, predicted: {predicted_state_id}" ) @@ -59,7 +59,7 @@ def play_map( tests_count=0, errors_count=0, actual_coverage_percent=0, - ) + ), perf_counter() - start_time if gameover.actual_coverage is not None: actual_coverage = gameover.actual_coverage diff --git a/VSharp.ML.AIAgent/ml/common_model.py b/VSharp.ML.AIAgent/ml/common_model.py index c0bf2c7a5..49d8022ef 100644 --- a/VSharp.ML.AIAgent/ml/common_model.py +++ b/VSharp.ML.AIAgent/ml/common_model.py @@ -1,35 +1,28 @@ -import os.path -from collections import namedtuple - import torch -import torch.nn.functional as F from common.constants import DEVICE -from ml import data_loader_compact -from ml.models import GNN_Het -from torch_geometric.data import HeteroData -from torch_geometric.loader import DataLoader -from torch_geometric.nn import to_hetero +import os +import re + from torch import nn -from torch.nn import Linear from torch_geometric.nn import ( GATConv, - GatedGraphConv, - GCNConv, - HeteroConv, - Linear, - ResGatedGraphConv, SAGEConv, - TAGConv, - TransformerConv, - global_mean_pool, - to_hetero, ) from torchvision.ops import MLP -from timer.wrapper import timeit -from conn.socket_manager import game_server_socket_manager +from common.game import GameState +from ml.data_loader_compact import ServerDataloaderHeteroVector from ml.model_wrappers.protocols import Predictor +from ml.predict_state_vector_hetero import PredictStateVectorHetGNN +import csv +from learning.play_game import play_game +from config import GeneralConfig +from connection.game_server_conn.utils import MapsType +from ml.models import SAGEConvModel + +csv_path = '../report/epochs_tables/' +models_path = '../report/epochs_best/' class CommonModel(torch.nn.Module): @@ -38,8 +31,6 @@ def __init__( hidden_channels, num_gv_layers=2, num_sv_layers=2, - num_history_layers=2, - num_in_layers=2, ): super().__init__() self.gv_layers = nn.ModuleList() @@ -66,10 +57,7 @@ def __init__( self.mlp = MLP(hidden_channels, [1]) - @timeit def forward(self, x_dict, edge_index_dict, edge_attr_dict): - # print(x_dict) - # print(edge_attr_dict) game_x = self.gv_layers[0]( x_dict["game_vertex"], edge_index_dict[("game_vertex", "to", "game_vertex")], @@ -79,7 +67,6 @@ def forward(self, x_dict, edge_index_dict, edge_attr_dict): game_x, edge_index_dict[("game_vertex", "to", "game_vertex")], ).relu() - # print(game_x.size()) state_x = self.sv_layers[0]( x_dict["state_vertex"], @@ -97,19 +84,7 @@ def forward(self, x_dict, edge_index_dict, edge_attr_dict): edge_attr_dict, size=(game_x.size(0), state_x.size(0)) ).relu() - - # history_x = self.history1( - # (state_x, game_x), - # edge_index_dict[("state_vertex", "history", "game_vertex")], - # edge_attr_dict, - # size=(state_x.size(0), game_x.size(0)) - # ).relu() - - # history_x = self.history1((game_x, state_x), - # edge_index_dict[("game_vertex", "history", "state_vertex")]).relu() - #history_x = self.history2((history_x, game_x), - # edge_index_dict[("state_vertex", "history", "game_vertex")]).relu() - + in_x = self.in1( (game_x, history_x), edge_index_dict[("game_vertex", "in", "state_vertex")] @@ -138,23 +113,19 @@ def euclidean_dist(y_pred, y_true): return torch.sqrt(torch.sum((y_pred_norm - y_true_norm) ** 2)) -lr = 0.0001 -model = CommonModel(32) -model.to(DEVICE) -optimizer = torch.optim.Adam(model.parameters(), lr=lr) -criterion = euclidean_dist - - -class CommonModelPredictor(Predictor): +class CommonModelWrapper(Predictor): def __init__(self, model: torch.nn.Module, best_models: dict) -> None: self.model = model - self._name = str(sum(weights_flat)) - self._hash = tuple(weights_flat).__hash__() + self.best_models = best_models + self._model = model def name(self) -> str: - return self._name + return "MY AWESOME MODEL" + + def model(self): + return self._model - def predict(self, input: GameState): + def predict(self, input: GameState, map_name): hetero_input, state_map = ServerDataloaderHeteroVector.convert_input_to_tensor( input ) @@ -164,20 +135,68 @@ def predict(self, input: GameState): self.model, hetero_input, state_map ) - back_prop(best_models[input.map_name], self.model, hetero_input) + back_prop(self.best_models[map_name], self.model, hetero_input) del hetero_input return next_step_id +def get_last_epoch_num(path): + epochs = list(map(lambda x: re.findall('[0-9]+', x), os.listdir(path))) + return str(sorted(epochs)[-1][0]) + + +def csv2best_models(): + best_models = {} + values = [] + models_names = [] + + with open(csv_path + get_last_epoch_num(csv_path) + '.csv', 'r') as csv_file: + csv_reader = csv.reader(csv_file) + map_names = next(csv_reader)[1:] + for row in csv_reader: + models_names.append(row[0]) + int_row = list(map(lambda x: int(x), row[1:])) + values.append(int_row) + val, ind = torch.max(torch.tensor(values), dim=0) + for i in range(len(map_names)): + best_models[map_names[i]] = models_names[ind[i]] + return best_models + + def back_prop(best_model, model, data): model.train() data.to(DEVICE) optimizer.zero_grad() - + ref_model = SAGEConvModel(16) + ref_model.load_state_dict(torch.load(models_path + "epoch_" + get_last_epoch_num(models_path) + "/" + best_model + ".pth")) + ref_model.to(DEVICE) out = model(data.x_dict, data.edge_index_dict, data.edge_attr_dict) - y_true = best_model(data.x_dict, data.edge_index_dict, data.edge_attr_dict) + y_true = ref_model(data.x_dict, data.edge_index_dict, data.edge_attr_dict) loss = criterion(out, y_true) loss.backward() optimizer.step() + + +model = CommonModel(32) +model.to(DEVICE) + +cmwrapper = CommonModelWrapper(model, csv2best_models()) +lr = 0.0001 +epochs = 10 +optimizer = torch.optim.Adam(model.parameters(), lr=lr) +criterion = euclidean_dist + + +def train(): + for epoch in range(epochs): + # some function with some parameters + play_game( + with_predictor=cmwrapper, + max_steps=GeneralConfig.MAX_STEPS, + maps_type=MapsType.TRAIN, + ) + + +train() diff --git a/VSharp.ML.AIAgent/ml/model_wrappers/protocols.py b/VSharp.ML.AIAgent/ml/model_wrappers/protocols.py index 5039497b3..2928a1e9b 100644 --- a/VSharp.ML.AIAgent/ml/model_wrappers/protocols.py +++ b/VSharp.ML.AIAgent/ml/model_wrappers/protocols.py @@ -13,7 +13,7 @@ def name(self) -> str: class Predictor(Named, ABC): @abstractmethod - def predict(self, input: GameState): + def predict(self, input: GameState, map_name): raise NotImplementedError @abstractmethod From 301c0e04e53d847d4a41f78a90067765d0f8acfd Mon Sep 17 00:00:00 2001 From: Anya497 Date: Fri, 28 Jul 2023 15:27:58 +0300 Subject: [PATCH 05/47] new maps --- VSharp.ML.GameServer/Maps.fs | 77 ++++++++++++++++++++++++++++++++++++ 1 file changed, 77 insertions(+) diff --git a/VSharp.ML.GameServer/Maps.fs b/VSharp.ML.GameServer/Maps.fs index 303e621a6..a4d01b526 100644 --- a/VSharp.ML.GameServer/Maps.fs +++ b/VSharp.ML.GameServer/Maps.fs @@ -248,6 +248,83 @@ let trainMaps, validationMaps = add 0u "Virtu.dll" CoverageZone.Method "DiskIIController.ReadIoRegionC0C0" add 0u "Virtu.dll" CoverageZone.Method "DiskIIController.WriteIoRegionC0C0" + + add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "MC68000.Disassemble" + add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "CP1610.Disassemble" + add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "CP1610.Execute" + add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "F3850.FetchInstruction" + add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "HuC6280.DisassembleExt" + add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "HuC6280.Execute" + add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "HuC6280.DisassembleCDL" + add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "I8048.Disassemble" + add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "I8048.ExecuteOne" + + add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "LR35902.BuildInstructionTable" + add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "LR35902.ExecuteOne" + add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "LR35902.Disassemble" + add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "LR35902.ADDS_Func" + add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "LR35902.DA_Func" + add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "MC6800.ExecuteOne" + //add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "MC6800.DA_Func" + add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "MOS6502X.State" + add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "x86.Execute" + + //add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "Emu83.LoadStateBinary" + add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "NECUPD765.ReadPort" + add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "NECUPD765.SetUnitSelect" + //add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "AY38912.PortWrite" + add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "CPCBase.PollInput" + add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "CPCBase.LoadAllMedia" + add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "CPCBase.DecodeINPort" + add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "GateArrayBase.SetupScreenMapping" + add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "GateArrayBase.ClockCycle" + //add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "CPC6128.ReadBus" + //add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "CPC6128.WriteBus" + add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "CPC6128.InitROM" + + add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "CPC6128.ReadPort" + add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "CPC6128.WritePort" + //add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "AmstradCPC.AmstradCPC" + add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "AmstradCPC.GetFirmware" + + // add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "AmstradCPC.FrameAdvance" + // add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "CPCMachineMetaData.GetMetaString" + add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "AmstradCPC.OSD_ShowDiskStatus" + // add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "AmstradCPC.OSD_ShowTapeStatus" + // add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "AmstradCPC.CheckMessageSettings" + // add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "SoundProviderMixer.GetSamplesSync" + add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "CartridgeDevice.Load" + // add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "Mapper0000.Mapper0000" + // add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "Mapper000F.Mapper000F" + // add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "Mapper0005.Mapper0005" + // add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "Mapper0013.Mapper0013" + // add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "Mapper0020.Mapper0020" + add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "D64.Read" + add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "DiskBuilder.Build" + add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "G64.Read" + add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "G64.Write" + add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "Prg.Load" + add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "Tape.ExecuteCycle" + add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "Tape.Load" + + // add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "Chip6510.Read" + // add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "Chip6510.Write" + // add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "Chip6526.CreateCia1" + add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "Chip90611401.Write" + add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "Cia.ExecutePhase" + // add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "Cia.Write" + add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "Sid.Flush" + add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "Sid.filter_operator" + add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "Envelope.ExecutePhase2" + add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "Sid.Write" + + add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "Sid.GetSamplesSync" + add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "Via.ExecutePhase" + add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "Via.SyncState" + add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "Via.Write" + add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "Vic.ExecutePhase1" + add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "Vic.Read" + //add 0u "Virtu.dll" CoverageZone.Method "Keyboard.SetKeys" //add 0u "Algorithms.dll" CoverageZone.Method "TopologicalSorter.Sort" From 4702f0185f8241adad73e071ff6f6935b0dc9c32 Mon Sep 17 00:00:00 2001 From: Max Nigmatulin Date: Mon, 7 Aug 2023 00:28:35 +0300 Subject: [PATCH 06/47] Remove new timeout, shorten dict upd --- .../connection/game_server_conn/connector.py | 22 ---------------- VSharp.ML.AIAgent/learning/play_game.py | 25 ++++++++++++------- VSharp.ML.AIAgent/ml/utils.py | 5 ++-- 3 files changed, 18 insertions(+), 34 deletions(-) diff --git a/VSharp.ML.AIAgent/connection/game_server_conn/connector.py b/VSharp.ML.AIAgent/connection/game_server_conn/connector.py index 60c285c19..e9d3b620d 100644 --- a/VSharp.ML.AIAgent/connection/game_server_conn/connector.py +++ b/VSharp.ML.AIAgent/connection/game_server_conn/connector.py @@ -1,12 +1,10 @@ import logging import logging.config -import time from typing import Optional import websocket from common.game import GameMap, GameState -from config import FeatureConfig from .messages import ( ClientMessage, @@ -34,9 +32,6 @@ def __init__( class IncorrectSentStateError(Exception): pass - class TimeoutError(Exception): - pass - class GameOver(Exception): def __init__( self, @@ -55,7 +50,6 @@ def __init__( ws: websocket.WebSocket, map: GameMap, steps: int, - timeout_sec: float, ) -> None: self.ws = ws @@ -66,8 +60,6 @@ def __init__( self.game_is_over = False self.map = map self.steps = steps - self.start_time_ms = time.perf_counter() - self.timeout_sec = timeout_sec def _raise_if_gameover(self, msg) -> GameOverServerMessage | str: if self.game_is_over: @@ -92,9 +84,6 @@ def _raise_if_gameover(self, msg) -> GameOverServerMessage | str: return msg def recv_state_or_throw_gameover(self) -> GameState: - self._handle_timeout( - start_time_ms=self.start_time_ms, current_time_ms=time.perf_counter() - ) received = self.ws.recv() data = GameStateServerMessage.from_json_handle( self._raise_if_gameover(received), @@ -114,9 +103,6 @@ def send_step(self, next_state_id: int, predicted_usefullness: int): self._sent_state_id = next_state_id def recv_reward_or_throw_gameover(self) -> Reward: - self._handle_timeout( - start_time_ms=self.start_time_ms, current_time_ms=time.perf_counter() - ) received = self.ws.recv() decoded = RewardServerMessage.from_json_handle( self._raise_if_gameover(received), @@ -126,14 +112,6 @@ def recv_reward_or_throw_gameover(self) -> Reward: return self._process_reward_server_message(decoded) - def _handle_timeout(self, start_time_ms, current_time_ms): - if not FeatureConfig.DUMP_BY_TIMEOUT.enabled: - return - if self.timeout_sec * 1000 <= (current_time_ms - start_time_ms): - raise Connector.TimeoutError( - f"{self.timeout_sec * 1000=}, {(current_time_ms - start_time_ms)=}" - ) - def _process_reward_server_message(self, msg): match msg.MessageType: case ServerMessageType.INCORRECT_PREDICTED_STATEID: diff --git a/VSharp.ML.AIAgent/learning/play_game.py b/VSharp.ML.AIAgent/learning/play_game.py index 47e1ca1e4..34f5c519a 100644 --- a/VSharp.ML.AIAgent/learning/play_game.py +++ b/VSharp.ML.AIAgent/learning/play_game.py @@ -106,6 +106,13 @@ def play_map_with_stats( return model_result, time_duration +@func_set_timeout(FeatureConfig.DUMP_BY_TIMEOUT.timeout_sec) +def play_map_with_timeout( + with_connector: Connector, with_predictor: Predictor +) -> tuple[GameResult, TimeDuration]: + return play_map_with_stats(with_connector, with_predictor) + + def play_game(with_predictor: Predictor, max_steps: int, maps_type: MapsType): with game_server_socket_manager() as ws: maps = get_maps(websocket=ws, type=maps_type) @@ -119,14 +126,14 @@ def play_game(with_predictor: Predictor, max_steps: int, maps_type: MapsType): logging.info(f"<{with_predictor.name()}> is playing {game_map.MapName}") try: + play_func = ( + play_map_with_timeout + if FeatureConfig.DUMP_BY_TIMEOUT.enabled + else play_map_with_stats + ) with game_server_socket_manager() as ws: - game_result, time = play_map_with_stats( - with_connector=Connector( - ws, - game_map, - max_steps, - FeatureConfig.DUMP_BY_TIMEOUT.timeout_sec, - ), + game_result, time = play_func( + with_connector=Connector(ws, game_map, max_steps), with_predictor=with_predictor, ) logging.info( @@ -134,13 +141,13 @@ def play_game(with_predictor: Predictor, max_steps: int, maps_type: MapsType): f"in {game_result.steps_count} steps, {time} seconds, " f"actual coverage: {game_result.actual_coverage_percent:.2f}" ) - except (Connector.TimeoutError, FunctionTimedOut): + except FunctionTimedOut as fto: game_result, time = ( GameResult(0, 0, 0, 0), FeatureConfig.DUMP_BY_TIMEOUT.timeout_sec, ) logging.warning( - f"<{with_predictor.name()}> timeouted on map {game_map.MapName}" + f"<{with_predictor.name()}> timeouted on map {game_map.MapName} with {fto.timedOutAfter}s" ) save_model( with_predictor.model(), diff --git a/VSharp.ML.AIAgent/ml/utils.py b/VSharp.ML.AIAgent/ml/utils.py index 0e0b2623b..263df9695 100644 --- a/VSharp.ML.AIAgent/ml/utils.py +++ b/VSharp.ML.AIAgent/ml/utils.py @@ -22,9 +22,8 @@ def load_model(path: Path, model: torch.nn.Module): def convert_to_export( old_sd: OrderedDict, new_sd: OrderedDict, last_layer_weights: list[float] ): - for key, value in [(k, v) for k, v in new_sd.items()]: - if key in old_sd: - new_sd.update({key: value}) + for key, value in old_sd.items(): + new_sd.update({key: value}) new_model = GeneralConfig.EXPORT_MODEL_INIT() new_model.load_state_dict(new_sd, strict=False) From 1c4f1c6158bcf0d29dc9cf7dad4b03716bc49453 Mon Sep 17 00:00:00 2001 From: Max Nigmatulin Date: Mon, 7 Aug 2023 19:10:09 +0300 Subject: [PATCH 07/47] Upd socket manager --- VSharp.ML.AIAgent/connection/broker_conn/socket_manager.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/VSharp.ML.AIAgent/connection/broker_conn/socket_manager.py b/VSharp.ML.AIAgent/connection/broker_conn/socket_manager.py index 9de891008..d0857e5c8 100644 --- a/VSharp.ML.AIAgent/connection/broker_conn/socket_manager.py +++ b/VSharp.ML.AIAgent/connection/broker_conn/socket_manager.py @@ -12,7 +12,9 @@ def wait_for_connection(url: WSUrl): ws = websocket.WebSocket() - while True: + retries_left = 60 + + while retries_left: with suppress(ConnectionRefusedError, ConnectionResetError): ws.settimeout(GameServerConnectorConfig.CREATE_CONNECTION_TIMEOUT) ws.connect( @@ -21,6 +23,8 @@ def wait_for_connection(url: WSUrl): if ws.connected: return ws time.sleep(GameServerConnectorConfig.CREATE_CONNECTION_TIMEOUT) + retries_left -= 1 + raise RuntimeError("Retries exsausted") @contextmanager From 60a82e705c4afbca0015169180f5611c2e1efc2e Mon Sep 17 00:00:00 2001 From: Max Nigmatulin Date: Mon, 7 Aug 2023 19:34:58 +0300 Subject: [PATCH 08/47] Add retries count logging --- VSharp.ML.AIAgent/connection/broker_conn/socket_manager.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/VSharp.ML.AIAgent/connection/broker_conn/socket_manager.py b/VSharp.ML.AIAgent/connection/broker_conn/socket_manager.py index d0857e5c8..c9f24a558 100644 --- a/VSharp.ML.AIAgent/connection/broker_conn/socket_manager.py +++ b/VSharp.ML.AIAgent/connection/broker_conn/socket_manager.py @@ -1,3 +1,4 @@ +import logging import time from contextlib import contextmanager, suppress @@ -12,7 +13,8 @@ def wait_for_connection(url: WSUrl): ws = websocket.WebSocket() - retries_left = 60 + max_retries = 60 + retries_left = max_retries while retries_left: with suppress(ConnectionRefusedError, ConnectionResetError): @@ -23,6 +25,7 @@ def wait_for_connection(url: WSUrl): if ws.connected: return ws time.sleep(GameServerConnectorConfig.CREATE_CONNECTION_TIMEOUT) + logging.info(f"Try connecting, did {max_retries - retries_left} attempts") retries_left -= 1 raise RuntimeError("Retries exsausted") From 622b3740f890c877eb42fcf7f78139683b61a5dd Mon Sep 17 00:00:00 2001 From: Max Nigmatulin Date: Mon, 7 Aug 2023 20:47:43 +0300 Subject: [PATCH 09/47] Add retries to acquire_instance call, fix minor issues --- .../connection/broker_conn/requests.py | 3 +++ .../connection/broker_conn/socket_manager.py | 2 +- VSharp.ML.AIAgent/launch_servers.py | 16 +++++++++------- 3 files changed, 13 insertions(+), 8 deletions(-) diff --git a/VSharp.ML.AIAgent/connection/broker_conn/requests.py b/VSharp.ML.AIAgent/connection/broker_conn/requests.py index 924bcbac9..1ccba3aeb 100644 --- a/VSharp.ML.AIAgent/connection/broker_conn/requests.py +++ b/VSharp.ML.AIAgent/connection/broker_conn/requests.py @@ -10,6 +10,9 @@ def acquire_instance() -> ServerInstanceInfo: response, content = httplib2.Http().request(WebsocketSourceLinks.GET_WS) + if response.status != 200: + logging.error(f"{response.status} with {content=} on acuire_instance call") + raise RuntimeError(f"Not ok response: {response.status}") aquired_instance = ServerInstanceInfo.from_json(json.loads(content.decode("utf-8"))) logging.info(f"acquired ws: {aquired_instance}") return aquired_instance diff --git a/VSharp.ML.AIAgent/connection/broker_conn/socket_manager.py b/VSharp.ML.AIAgent/connection/broker_conn/socket_manager.py index c9f24a558..1e7b788fb 100644 --- a/VSharp.ML.AIAgent/connection/broker_conn/socket_manager.py +++ b/VSharp.ML.AIAgent/connection/broker_conn/socket_manager.py @@ -14,7 +14,7 @@ def wait_for_connection(url: WSUrl): ws = websocket.WebSocket() max_retries = 60 - retries_left = max_retries + retries_left = 60 while retries_left: with suppress(ConnectionRefusedError, ConnectionResetError): diff --git a/VSharp.ML.AIAgent/launch_servers.py b/VSharp.ML.AIAgent/launch_servers.py index b4946fe39..30415fb09 100644 --- a/VSharp.ML.AIAgent/launch_servers.py +++ b/VSharp.ML.AIAgent/launch_servers.py @@ -17,13 +17,15 @@ @routes.get("/get_ws") async def dequeue_instance(request): - try: - server_info = SERVER_INSTANCES.get(timeout=0.1) - print(f"issued {server_info}") - return web.json_response(server_info.to_json()) - except Empty as e: - print(f"{os.getpid()} tried to dequeue an empty queue. Waiting...") - return web.Response(text=str(e)) + retry_count = 60 + while retry_count: + try: + server_info = SERVER_INSTANCES.get(timeout=1) + print(f"issued {server_info}") + return web.json_response(server_info.to_json()) + except Empty: + print(f"{os.getpid()} tried to dequeue an empty queue. Waiting...") + retry_count -= 1 @routes.post("/post_ws") From c7f127db55271675d1223f1de8a5bddb0b3501b4 Mon Sep 17 00:00:00 2001 From: Max Nigmatulin Date: Mon, 7 Aug 2023 22:20:30 +0300 Subject: [PATCH 10/47] Add try count --- VSharp.ML.AIAgent/launch_servers.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/VSharp.ML.AIAgent/launch_servers.py b/VSharp.ML.AIAgent/launch_servers.py index 30415fb09..47c256813 100644 --- a/VSharp.ML.AIAgent/launch_servers.py +++ b/VSharp.ML.AIAgent/launch_servers.py @@ -24,7 +24,9 @@ async def dequeue_instance(request): print(f"issued {server_info}") return web.json_response(server_info.to_json()) except Empty: - print(f"{os.getpid()} tried to dequeue an empty queue. Waiting...") + print( + f"{os.getpid()} tried to dequeue an empty queue. Retrying {retry_count}'s time..." + ) retry_count -= 1 From eb45f9b46a404e63ccda02cd5ecff31b53830353 Mon Sep 17 00:00:00 2001 From: Max Nigmatulin Date: Tue, 8 Aug 2023 11:21:35 +0300 Subject: [PATCH 11/47] Add error raise on empty dequeue --- VSharp.ML.AIAgent/launch_servers.py | 1 + 1 file changed, 1 insertion(+) diff --git a/VSharp.ML.AIAgent/launch_servers.py b/VSharp.ML.AIAgent/launch_servers.py index 47c256813..5dab4d8f7 100644 --- a/VSharp.ML.AIAgent/launch_servers.py +++ b/VSharp.ML.AIAgent/launch_servers.py @@ -28,6 +28,7 @@ async def dequeue_instance(request): f"{os.getpid()} tried to dequeue an empty queue. Retrying {retry_count}'s time..." ) retry_count -= 1 + raise RuntimeError("Couldn't dequeue instance, the queue is not replenishing") @routes.post("/post_ws") From 4c2d92475cd62f44d486ec08545414635f06ceb3 Mon Sep 17 00:00:00 2001 From: Max Nigmatulin Date: Tue, 8 Aug 2023 11:32:08 +0300 Subject: [PATCH 12/47] Bring socket instantiation up from try-catch --- VSharp.ML.AIAgent/connection/broker_conn/socket_manager.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/VSharp.ML.AIAgent/connection/broker_conn/socket_manager.py b/VSharp.ML.AIAgent/connection/broker_conn/socket_manager.py index 1e7b788fb..17efd010d 100644 --- a/VSharp.ML.AIAgent/connection/broker_conn/socket_manager.py +++ b/VSharp.ML.AIAgent/connection/broker_conn/socket_manager.py @@ -33,10 +33,9 @@ def wait_for_connection(url: WSUrl): @contextmanager def game_server_socket_manager(): server_instance = acquire_instance() + socket = wait_for_connection(server_instance.ws_url) - socket = None try: - socket = wait_for_connection(server_instance.ws_url) socket.settimeout(GameServerConnectorConfig.RESPONCE_TIMEOUT_SEC) yield socket finally: From 7393d97edd652bf71d27437a44f88da95dd7b893 Mon Sep 17 00:00:00 2001 From: Max Nigmatulin Date: Tue, 8 Aug 2023 11:57:04 +0300 Subject: [PATCH 13/47] Add logging to server --- VSharp.ML.AIAgent/launch_servers.py | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/VSharp.ML.AIAgent/launch_servers.py b/VSharp.ML.AIAgent/launch_servers.py index 5dab4d8f7..e6127e0aa 100644 --- a/VSharp.ML.AIAgent/launch_servers.py +++ b/VSharp.ML.AIAgent/launch_servers.py @@ -1,5 +1,6 @@ import argparse import json +import logging import os import signal import subprocess @@ -13,6 +14,12 @@ from connection.broker_conn.classes import ServerInstanceInfo, Undefined, WSUrl routes = web.RouteTableDef() +logging.basicConfig( + level=GeneralConfig.LOGGER_LEVEL, + filename="instance_manager.log", + filemode="w", + format="%(asctime)s - p%(process)d: %(name)s - [%(levelname)s]: %(message)s", +) @routes.get("/get_ws") @@ -21,13 +28,14 @@ async def dequeue_instance(request): while retry_count: try: server_info = SERVER_INSTANCES.get(timeout=1) - print(f"issued {server_info}") + logging.info(f"issued {server_info}") return web.json_response(server_info.to_json()) except Empty: - print( - f"{os.getpid()} tried to dequeue an empty queue. Retrying {retry_count}'s time..." + logging.warning( + f"{os.getpid()} tried to dequeue an empty queue. {retry_count} retries left" ) retry_count -= 1 + logging.error("Couldn't dequeue instance, the queue is not replenishing") raise RuntimeError("Couldn't dequeue instance, the queue is not replenishing") @@ -37,15 +45,18 @@ async def enqueue_instance(request): returned_instance_info = ServerInstanceInfo.from_json( returned_instance_info_raw.decode("utf-8") ) + logging.info(f"got {returned_instance_info} from client") - print(f"put back {returned_instance_info}") if FeatureConfig.ON_GAME_SERVER_RESTART: kill_server(returned_instance_info.pid, forget=True) + logging.info(f"killing {returned_instance_info.pid}") returned_instance_info = run_server_instance( port=returned_instance_info.port, start_server=START_SERVERS ) + logging.info(f"running new instance: {returned_instance_info}") SERVER_INSTANCES.put(returned_instance_info) + logging.info(f"enqueue {returned_instance_info}") return web.HTTPOk() @@ -88,7 +99,7 @@ def run_server_instance(port: int, start_server: bool) -> ServerInstanceInfo: ) server_pid = proc.pid PROCS.append(server_pid) - print(f"{server_pid}: " + " ".join(launch_server + [str(port)])) + logging.info(f"{server_pid}: " + " ".join(launch_server + [str(port)])) ws_url = get_socket_url(port) return ServerInstanceInfo(port, ws_url, server_pid) @@ -109,7 +120,6 @@ def kill_server(pid: int, forget): os.kill(pid, signal.SIGKILL) if forget: PROCS.remove(pid) - print(f"killed {pid}") @contextmanager From 80d83376aa0d798cc8316813b2e87907e59301a1 Mon Sep 17 00:00:00 2001 From: Max Nigmatulin Date: Tue, 8 Aug 2023 12:13:59 +0300 Subject: [PATCH 14/47] Add killed server status display --- VSharp.ML.AIAgent/launch_servers.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/VSharp.ML.AIAgent/launch_servers.py b/VSharp.ML.AIAgent/launch_servers.py index e6127e0aa..cc5740ed0 100644 --- a/VSharp.ML.AIAgent/launch_servers.py +++ b/VSharp.ML.AIAgent/launch_servers.py @@ -6,7 +6,7 @@ import subprocess from contextlib import contextmanager from queue import Empty, Queue - +import psutil from aiohttp import web from common.constants import SERVER_WORKING_DIR @@ -49,7 +49,9 @@ async def enqueue_instance(request): if FeatureConfig.ON_GAME_SERVER_RESTART: kill_server(returned_instance_info.pid, forget=True) - logging.info(f"killing {returned_instance_info.pid}") + logging.info( + f"killing {returned_instance_info.pid}, its status: {psutil.Process(returned_instance_info.pid).status()}" + ) returned_instance_info = run_server_instance( port=returned_instance_info.port, start_server=START_SERVERS ) From a1532cd4893e48bf2774f9d0cdaa6611527eee6b Mon Sep 17 00:00:00 2001 From: Max Nigmatulin Date: Tue, 8 Aug 2023 12:19:58 +0300 Subject: [PATCH 15/47] Add process status check with retries --- VSharp.ML.AIAgent/launch_servers.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/VSharp.ML.AIAgent/launch_servers.py b/VSharp.ML.AIAgent/launch_servers.py index cc5740ed0..6bade167a 100644 --- a/VSharp.ML.AIAgent/launch_servers.py +++ b/VSharp.ML.AIAgent/launch_servers.py @@ -4,8 +4,10 @@ import os import signal import subprocess +import time from contextlib import contextmanager from queue import Empty, Queue + import psutil from aiohttp import web @@ -49,6 +51,19 @@ async def enqueue_instance(request): if FeatureConfig.ON_GAME_SERVER_RESTART: kill_server(returned_instance_info.pid, forget=True) + + wait_for_reset_retries = 60 + while wait_for_reset_retries: + wait_for_reset_retries -= 1 + logging.log( + f"Waiting for server to die, {wait_for_reset_retries} retries left" + ) + if ( + psutil.Process(returned_instance_info.pid).status() + != psutil.STATUS_RUNNING + ): + break + time.sleep(1) logging.info( f"killing {returned_instance_info.pid}, its status: {psutil.Process(returned_instance_info.pid).status()}" ) From 2d9d9ade243e21900a2f4948360798e6b1597d1a Mon Sep 17 00:00:00 2001 From: Max Nigmatulin Date: Tue, 8 Aug 2023 12:23:48 +0300 Subject: [PATCH 16/47] Fix typo --- VSharp.ML.AIAgent/launch_servers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/VSharp.ML.AIAgent/launch_servers.py b/VSharp.ML.AIAgent/launch_servers.py index 6bade167a..5c1d106a4 100644 --- a/VSharp.ML.AIAgent/launch_servers.py +++ b/VSharp.ML.AIAgent/launch_servers.py @@ -55,7 +55,7 @@ async def enqueue_instance(request): wait_for_reset_retries = 60 while wait_for_reset_retries: wait_for_reset_retries -= 1 - logging.log( + logging.info( f"Waiting for server to die, {wait_for_reset_retries} retries left" ) if ( From 581b3ff2b7091105a31174cbe91ecde70d786035 Mon Sep 17 00:00:00 2001 From: Max Nigmatulin Date: Tue, 8 Aug 2023 12:32:55 +0300 Subject: [PATCH 17/47] Decrease sleep time --- VSharp.ML.AIAgent/launch_servers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/VSharp.ML.AIAgent/launch_servers.py b/VSharp.ML.AIAgent/launch_servers.py index 5c1d106a4..686b05fb0 100644 --- a/VSharp.ML.AIAgent/launch_servers.py +++ b/VSharp.ML.AIAgent/launch_servers.py @@ -52,7 +52,7 @@ async def enqueue_instance(request): if FeatureConfig.ON_GAME_SERVER_RESTART: kill_server(returned_instance_info.pid, forget=True) - wait_for_reset_retries = 60 + wait_for_reset_retries = 100 while wait_for_reset_retries: wait_for_reset_retries -= 1 logging.info( @@ -63,7 +63,7 @@ async def enqueue_instance(request): != psutil.STATUS_RUNNING ): break - time.sleep(1) + time.sleep(0.1) logging.info( f"killing {returned_instance_info.pid}, its status: {psutil.Process(returned_instance_info.pid).status()}" ) From 967144569a4010ad6ad7074166f096f4603abc1b Mon Sep 17 00:00:00 2001 From: Max Nigmatulin Date: Tue, 8 Aug 2023 13:06:06 +0300 Subject: [PATCH 18/47] Up wait reset time --- VSharp.ML.AIAgent/launch_servers.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/VSharp.ML.AIAgent/launch_servers.py b/VSharp.ML.AIAgent/launch_servers.py index 686b05fb0..918b58c25 100644 --- a/VSharp.ML.AIAgent/launch_servers.py +++ b/VSharp.ML.AIAgent/launch_servers.py @@ -52,9 +52,8 @@ async def enqueue_instance(request): if FeatureConfig.ON_GAME_SERVER_RESTART: kill_server(returned_instance_info.pid, forget=True) - wait_for_reset_retries = 100 + wait_for_reset_retries = 60 while wait_for_reset_retries: - wait_for_reset_retries -= 1 logging.info( f"Waiting for server to die, {wait_for_reset_retries} retries left" ) @@ -63,9 +62,14 @@ async def enqueue_instance(request): != psutil.STATUS_RUNNING ): break - time.sleep(0.1) + time.sleep(1) + wait_for_reset_retries -= 1 + + if wait_for_reset_retries == 0: + raise RuntimeError(f"{returned_instance_info} could not be killed") + logging.info( - f"killing {returned_instance_info.pid}, its status: {psutil.Process(returned_instance_info.pid).status()}" + f"killed {returned_instance_info.pid}, its status: {psutil.Process(returned_instance_info.pid).status()}" ) returned_instance_info = run_server_instance( port=returned_instance_info.port, start_server=START_SERVERS From eefc4381a9632f75debd6ca2f2ceaacd3503e310 Mon Sep 17 00:00:00 2001 From: Max Nigmatulin Date: Tue, 8 Aug 2023 13:08:28 +0300 Subject: [PATCH 19/47] 1 -> 0.1 wait time --- VSharp.ML.AIAgent/launch_servers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/VSharp.ML.AIAgent/launch_servers.py b/VSharp.ML.AIAgent/launch_servers.py index 918b58c25..4ac020ca9 100644 --- a/VSharp.ML.AIAgent/launch_servers.py +++ b/VSharp.ML.AIAgent/launch_servers.py @@ -52,7 +52,7 @@ async def enqueue_instance(request): if FeatureConfig.ON_GAME_SERVER_RESTART: kill_server(returned_instance_info.pid, forget=True) - wait_for_reset_retries = 60 + wait_for_reset_retries = 600 while wait_for_reset_retries: logging.info( f"Waiting for server to die, {wait_for_reset_retries} retries left" @@ -62,7 +62,7 @@ async def enqueue_instance(request): != psutil.STATUS_RUNNING ): break - time.sleep(1) + time.sleep(0.1) wait_for_reset_retries -= 1 if wait_for_reset_retries == 0: From 8a6b8428278eaf2b61a751459f964999bd0411d2 Mon Sep 17 00:00:00 2001 From: Max Nigmatulin Date: Tue, 8 Aug 2023 13:35:09 +0300 Subject: [PATCH 20/47] Up retry time --- .../connection/broker_conn/socket_manager.py | 9 ++++----- VSharp.ML.AIAgent/launch_servers.py | 8 +++++--- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/VSharp.ML.AIAgent/connection/broker_conn/socket_manager.py b/VSharp.ML.AIAgent/connection/broker_conn/socket_manager.py index 17efd010d..5230887f5 100644 --- a/VSharp.ML.AIAgent/connection/broker_conn/socket_manager.py +++ b/VSharp.ML.AIAgent/connection/broker_conn/socket_manager.py @@ -13,19 +13,18 @@ def wait_for_connection(url: WSUrl): ws = websocket.WebSocket() - max_retries = 60 - retries_left = 60 + retries_left = GameServerConnectorConfig.WAIT_FOR_SOCKET_RECONNECTION_MAX_RETRIES while retries_left: with suppress(ConnectionRefusedError, ConnectionResetError): - ws.settimeout(GameServerConnectorConfig.CREATE_CONNECTION_TIMEOUT) + ws.settimeout(GameServerConnectorConfig.CREATE_CONNECTION_TIMEOUT_SEC) ws.connect( url, skip_utf8_validation=GameServerConnectorConfig.SKIP_UTF_VALIDATION ) if ws.connected: return ws - time.sleep(GameServerConnectorConfig.CREATE_CONNECTION_TIMEOUT) - logging.info(f"Try connecting, did {max_retries - retries_left} attempts") + time.sleep(GameServerConnectorConfig.CREATE_CONNECTION_TIMEOUT_SEC) + logging.info(f"Try connecting, {retries_left} attempts left") retries_left -= 1 raise RuntimeError("Retries exsausted") diff --git a/VSharp.ML.AIAgent/launch_servers.py b/VSharp.ML.AIAgent/launch_servers.py index 4ac020ca9..7957b269c 100644 --- a/VSharp.ML.AIAgent/launch_servers.py +++ b/VSharp.ML.AIAgent/launch_servers.py @@ -49,10 +49,12 @@ async def enqueue_instance(request): ) logging.info(f"got {returned_instance_info} from client") - if FeatureConfig.ON_GAME_SERVER_RESTART: + if FeatureConfig.ON_GAME_SERVER_RESTART.enabled: kill_server(returned_instance_info.pid, forget=True) - wait_for_reset_retries = 600 + wait_for_reset_retries = ( + FeatureConfig.ON_GAME_SERVER_RESTART.wait_for_reset_retries + ) while wait_for_reset_retries: logging.info( f"Waiting for server to die, {wait_for_reset_retries} retries left" @@ -62,7 +64,7 @@ async def enqueue_instance(request): != psutil.STATUS_RUNNING ): break - time.sleep(0.1) + time.sleep(FeatureConfig.ON_GAME_SERVER_RESTART.wait_for_reset_time) wait_for_reset_retries -= 1 if wait_for_reset_retries == 0: From 7659fb86464d6e310235eda81b1cdacec3004fef Mon Sep 17 00:00:00 2001 From: Max Nigmatulin Date: Tue, 8 Aug 2023 13:39:13 +0300 Subject: [PATCH 21/47] Add config --- VSharp.ML.AIAgent/config.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/VSharp.ML.AIAgent/config.py b/VSharp.ML.AIAgent/config.py index 1424c7704..0bd34472b 100644 --- a/VSharp.ML.AIAgent/config.py +++ b/VSharp.ML.AIAgent/config.py @@ -61,6 +61,13 @@ def create_save_path_if_not_exists(self): self.save_path.mkdir() +@dataclass(slots=True, frozen=True) +class OnGameServerRestartFeature: + enabled: bool + wait_for_reset_retries: int + wait_for_reset_time: float + + class FeatureConfig: VERBOSE_TABLES = True SHOW_SUCCESSORS = True @@ -72,11 +79,14 @@ class FeatureConfig: SAVE_EPOCHS_COVERAGES = SaveEpochsCoveragesFeature( enabled=True, save_path=Path("./report/epochs_tables/") ) - ON_GAME_SERVER_RESTART = True + ON_GAME_SERVER_RESTART = OnGameServerRestartFeature( + enabled=True, wait_for_reset_retries=10 * 60 * 10, wait_for_reset_time=0.1 + ) class GameServerConnectorConfig: - CREATE_CONNECTION_TIMEOUT = 1 + CREATE_CONNECTION_TIMEOUT_SEC = 1 + WAIT_FOR_SOCKET_RECONNECTION_MAX_RETRIES = 10 * 60 RESPONCE_TIMEOUT_SEC = ( FeatureConfig.DUMP_BY_TIMEOUT.timeout_sec + 1 if FeatureConfig.DUMP_BY_TIMEOUT.enabled From c1bba65b465f964a7fc1d7dc622cfa936228af62 Mon Sep 17 00:00:00 2001 From: Max Nigmatulin Date: Tue, 8 Aug 2023 14:42:50 +0300 Subject: [PATCH 22/47] Wait for server alive --- VSharp.ML.AIAgent/launch_servers.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/VSharp.ML.AIAgent/launch_servers.py b/VSharp.ML.AIAgent/launch_servers.py index 7957b269c..807750ab9 100644 --- a/VSharp.ML.AIAgent/launch_servers.py +++ b/VSharp.ML.AIAgent/launch_servers.py @@ -78,6 +78,26 @@ async def enqueue_instance(request): ) logging.info(f"running new instance: {returned_instance_info}") + wait_for_alive_retries = ( + FeatureConfig.ON_GAME_SERVER_RESTART.wait_for_reset_retries + ) + while wait_for_alive_retries: + logging.info( + f"Waiting for server to die, {wait_for_alive_retries} retries left" + ) + if ( + psutil.Process(returned_instance_info.pid).status() + == psutil.STATUS_RUNNING + ): + break + time.sleep(FeatureConfig.ON_GAME_SERVER_RESTART.wait_for_reset_time) + wait_for_alive_retries -= 1 + + if wait_for_alive_retries == 0: + raise RuntimeError(f"{returned_instance_info} could not be killed") + + logging.info(f"{returned_instance_info} is run") + SERVER_INSTANCES.put(returned_instance_info) logging.info(f"enqueue {returned_instance_info}") return web.HTTPOk() From 5ca475a5368d406a0ee20e462d04b974c2c9b3bf Mon Sep 17 00:00:00 2001 From: Max Nigmatulin Date: Tue, 8 Aug 2023 14:51:48 +0300 Subject: [PATCH 23/47] Fix desc --- VSharp.ML.AIAgent/launch_servers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/VSharp.ML.AIAgent/launch_servers.py b/VSharp.ML.AIAgent/launch_servers.py index 807750ab9..51bb7ae9b 100644 --- a/VSharp.ML.AIAgent/launch_servers.py +++ b/VSharp.ML.AIAgent/launch_servers.py @@ -83,7 +83,7 @@ async def enqueue_instance(request): ) while wait_for_alive_retries: logging.info( - f"Waiting for server to die, {wait_for_alive_retries} retries left" + f"Waiting for server to start, {wait_for_alive_retries} retries left" ) if ( psutil.Process(returned_instance_info.pid).status() @@ -94,7 +94,7 @@ async def enqueue_instance(request): wait_for_alive_retries -= 1 if wait_for_alive_retries == 0: - raise RuntimeError(f"{returned_instance_info} could not be killed") + raise RuntimeError(f"{returned_instance_info} has not resurrected") logging.info(f"{returned_instance_info} is run") From 8c027d07a3377c3e2eeb77fe8f38937bc31ffa73 Mon Sep 17 00:00:00 2001 From: Max Nigmatulin Date: Tue, 8 Aug 2023 15:52:39 +0300 Subject: [PATCH 24/47] Add url --- VSharp.ML.AIAgent/connection/broker_conn/socket_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/VSharp.ML.AIAgent/connection/broker_conn/socket_manager.py b/VSharp.ML.AIAgent/connection/broker_conn/socket_manager.py index 5230887f5..8d2a5a6a4 100644 --- a/VSharp.ML.AIAgent/connection/broker_conn/socket_manager.py +++ b/VSharp.ML.AIAgent/connection/broker_conn/socket_manager.py @@ -24,7 +24,7 @@ def wait_for_connection(url: WSUrl): if ws.connected: return ws time.sleep(GameServerConnectorConfig.CREATE_CONNECTION_TIMEOUT_SEC) - logging.info(f"Try connecting, {retries_left} attempts left") + logging.info(f"Try connecting to {url}, {retries_left} attempts left") retries_left -= 1 raise RuntimeError("Retries exsausted") From 66c8fe7f77bc6c22baee32e0a7b06c0463965514 Mon Sep 17 00:00:00 2001 From: Max Nigmatulin Date: Tue, 8 Aug 2023 16:28:48 +0300 Subject: [PATCH 25/47] Verbose message --- VSharp.ML.AIAgent/connection/broker_conn/socket_manager.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/VSharp.ML.AIAgent/connection/broker_conn/socket_manager.py b/VSharp.ML.AIAgent/connection/broker_conn/socket_manager.py index 8d2a5a6a4..a1931a35d 100644 --- a/VSharp.ML.AIAgent/connection/broker_conn/socket_manager.py +++ b/VSharp.ML.AIAgent/connection/broker_conn/socket_manager.py @@ -26,7 +26,9 @@ def wait_for_connection(url: WSUrl): time.sleep(GameServerConnectorConfig.CREATE_CONNECTION_TIMEOUT_SEC) logging.info(f"Try connecting to {url}, {retries_left} attempts left") retries_left -= 1 - raise RuntimeError("Retries exsausted") + raise RuntimeError( + f"Retries exsausted wnen trying to connect to {url}: {retries_left} left" + ) @contextmanager From 16dc283547832ca47b08d48066ee47869e0ea467 Mon Sep 17 00:00:00 2001 From: Max Nigmatulin Date: Tue, 8 Aug 2023 17:25:55 +0300 Subject: [PATCH 26/47] Fix rare server state case --- VSharp.ML.AIAgent/launch_servers.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/VSharp.ML.AIAgent/launch_servers.py b/VSharp.ML.AIAgent/launch_servers.py index 51bb7ae9b..de1a4884c 100644 --- a/VSharp.ML.AIAgent/launch_servers.py +++ b/VSharp.ML.AIAgent/launch_servers.py @@ -59,9 +59,9 @@ async def enqueue_instance(request): logging.info( f"Waiting for server to die, {wait_for_reset_retries} retries left" ) - if ( - psutil.Process(returned_instance_info.pid).status() - != psutil.STATUS_RUNNING + if psutil.Process(returned_instance_info.pid).status() in ( + psutil.STATUS_DEAD, + psutil.STATUS_ZOMBIE, ): break time.sleep(FeatureConfig.ON_GAME_SERVER_RESTART.wait_for_reset_time) From 2a8f21adb0cfb26d0747211b68691cccbd1b5298 Mon Sep 17 00:00:00 2001 From: Max Nigmatulin Date: Tue, 8 Aug 2023 17:53:57 +0300 Subject: [PATCH 27/47] View process status when try connecting --- .../connection/broker_conn/socket_manager.py | 20 +++++++++++-------- VSharp.ML.AIAgent/launch_servers.py | 4 +--- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/VSharp.ML.AIAgent/connection/broker_conn/socket_manager.py b/VSharp.ML.AIAgent/connection/broker_conn/socket_manager.py index a1931a35d..895d66843 100644 --- a/VSharp.ML.AIAgent/connection/broker_conn/socket_manager.py +++ b/VSharp.ML.AIAgent/connection/broker_conn/socket_manager.py @@ -2,15 +2,16 @@ import time from contextlib import contextmanager, suppress +import psutil import websocket from config import GameServerConnectorConfig -from connection.broker_conn.classes import WSUrl +from connection.broker_conn.classes import ServerInstanceInfo from .requests import acquire_instance, return_instance -def wait_for_connection(url: WSUrl): +def wait_for_connection(server_instance: ServerInstanceInfo): ws = websocket.WebSocket() retries_left = GameServerConnectorConfig.WAIT_FOR_SOCKET_RECONNECTION_MAX_RETRIES @@ -19,22 +20,25 @@ def wait_for_connection(url: WSUrl): with suppress(ConnectionRefusedError, ConnectionResetError): ws.settimeout(GameServerConnectorConfig.CREATE_CONNECTION_TIMEOUT_SEC) ws.connect( - url, skip_utf8_validation=GameServerConnectorConfig.SKIP_UTF_VALIDATION + server_instance.ws_url, + skip_utf8_validation=GameServerConnectorConfig.SKIP_UTF_VALIDATION, ) - if ws.connected: - return ws + # if ws.connected: + # return ws time.sleep(GameServerConnectorConfig.CREATE_CONNECTION_TIMEOUT_SEC) - logging.info(f"Try connecting to {url}, {retries_left} attempts left") + logging.info( + f"Try connecting to {server_instance.ws_url}, {retries_left} attempts left; {psutil.Process(server_instance.pid)}" + ) retries_left -= 1 raise RuntimeError( - f"Retries exsausted wnen trying to connect to {url}: {retries_left} left" + f"Retries exsausted wnen trying to connect to {server_instance.ws_url}: {retries_left} left" ) @contextmanager def game_server_socket_manager(): server_instance = acquire_instance() - socket = wait_for_connection(server_instance.ws_url) + socket = wait_for_connection(server_instance) try: socket.settimeout(GameServerConnectorConfig.RESPONCE_TIMEOUT_SEC) diff --git a/VSharp.ML.AIAgent/launch_servers.py b/VSharp.ML.AIAgent/launch_servers.py index de1a4884c..72719729b 100644 --- a/VSharp.ML.AIAgent/launch_servers.py +++ b/VSharp.ML.AIAgent/launch_servers.py @@ -70,9 +70,7 @@ async def enqueue_instance(request): if wait_for_reset_retries == 0: raise RuntimeError(f"{returned_instance_info} could not be killed") - logging.info( - f"killed {returned_instance_info.pid}, its status: {psutil.Process(returned_instance_info.pid).status()}" - ) + logging.info(f"killed {psutil.Process(returned_instance_info.pid)}") returned_instance_info = run_server_instance( port=returned_instance_info.port, start_server=START_SERVERS ) From af8edfac719328d60f58440cd9f0e675d2677589 Mon Sep 17 00:00:00 2001 From: Max Nigmatulin Date: Tue, 8 Aug 2023 18:06:36 +0300 Subject: [PATCH 28/47] Fix typo --- VSharp.ML.AIAgent/connection/broker_conn/socket_manager.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/VSharp.ML.AIAgent/connection/broker_conn/socket_manager.py b/VSharp.ML.AIAgent/connection/broker_conn/socket_manager.py index 895d66843..759a0b55a 100644 --- a/VSharp.ML.AIAgent/connection/broker_conn/socket_manager.py +++ b/VSharp.ML.AIAgent/connection/broker_conn/socket_manager.py @@ -23,8 +23,8 @@ def wait_for_connection(server_instance: ServerInstanceInfo): server_instance.ws_url, skip_utf8_validation=GameServerConnectorConfig.SKIP_UTF_VALIDATION, ) - # if ws.connected: - # return ws + if ws.connected: + return ws time.sleep(GameServerConnectorConfig.CREATE_CONNECTION_TIMEOUT_SEC) logging.info( f"Try connecting to {server_instance.ws_url}, {retries_left} attempts left; {psutil.Process(server_instance.pid)}" From 51ee70a6ebcaa7ece7157a3b05ca0cb2453874d7 Mon Sep 17 00:00:00 2001 From: Max Nigmatulin Date: Tue, 8 Aug 2023 18:18:42 +0300 Subject: [PATCH 29/47] Solve issue on server --- VSharp.ML.AIAgent/connection/broker_conn/socket_manager.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/VSharp.ML.AIAgent/connection/broker_conn/socket_manager.py b/VSharp.ML.AIAgent/connection/broker_conn/socket_manager.py index 759a0b55a..76c6f656c 100644 --- a/VSharp.ML.AIAgent/connection/broker_conn/socket_manager.py +++ b/VSharp.ML.AIAgent/connection/broker_conn/socket_manager.py @@ -2,7 +2,6 @@ import time from contextlib import contextmanager, suppress -import psutil import websocket from config import GameServerConnectorConfig @@ -27,7 +26,7 @@ def wait_for_connection(server_instance: ServerInstanceInfo): return ws time.sleep(GameServerConnectorConfig.CREATE_CONNECTION_TIMEOUT_SEC) logging.info( - f"Try connecting to {server_instance.ws_url}, {retries_left} attempts left; {psutil.Process(server_instance.pid)}" + f"Try connecting to {server_instance.ws_url}, {retries_left} attempts left; {server_instance}" ) retries_left -= 1 raise RuntimeError( From 359c88e78d511ff500f9058544990e55a5e580f4 Mon Sep 17 00:00:00 2001 From: Max Nigmatulin Date: Wed, 9 Aug 2023 12:45:48 +0300 Subject: [PATCH 30/47] Upd logs msg --- VSharp.ML.AIAgent/launch_servers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/VSharp.ML.AIAgent/launch_servers.py b/VSharp.ML.AIAgent/launch_servers.py index 72719729b..d00b5a413 100644 --- a/VSharp.ML.AIAgent/launch_servers.py +++ b/VSharp.ML.AIAgent/launch_servers.py @@ -94,7 +94,7 @@ async def enqueue_instance(request): if wait_for_alive_retries == 0: raise RuntimeError(f"{returned_instance_info} has not resurrected") - logging.info(f"{returned_instance_info} is run") + logging.info(f"we should have run {psutil.Process(returned_instance_info.pid)}") SERVER_INSTANCES.put(returned_instance_info) logging.info(f"enqueue {returned_instance_info}") From 1ef5e4bece932d1e5f278cedcb14fcd29af5fef0 Mon Sep 17 00:00:00 2001 From: Max Nigmatulin Date: Wed, 9 Aug 2023 13:09:56 +0300 Subject: [PATCH 31/47] More verbosity --- VSharp.ML.AIAgent/launch_servers.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/VSharp.ML.AIAgent/launch_servers.py b/VSharp.ML.AIAgent/launch_servers.py index d00b5a413..c831b1fac 100644 --- a/VSharp.ML.AIAgent/launch_servers.py +++ b/VSharp.ML.AIAgent/launch_servers.py @@ -57,7 +57,7 @@ async def enqueue_instance(request): ) while wait_for_reset_retries: logging.info( - f"Waiting for server to die, {wait_for_reset_retries} retries left" + f"Waiting for {returned_instance_info} to die, {wait_for_reset_retries} retries left" ) if psutil.Process(returned_instance_info.pid).status() in ( psutil.STATUS_DEAD, @@ -81,7 +81,7 @@ async def enqueue_instance(request): ) while wait_for_alive_retries: logging.info( - f"Waiting for server to start, {wait_for_alive_retries} retries left" + f"Waiting for {returned_instance_info} to start, {wait_for_alive_retries} retries left" ) if ( psutil.Process(returned_instance_info.pid).status() @@ -94,7 +94,9 @@ async def enqueue_instance(request): if wait_for_alive_retries == 0: raise RuntimeError(f"{returned_instance_info} has not resurrected") - logging.info(f"we should have run {psutil.Process(returned_instance_info.pid)}") + logging.info( + f"we should have run {returned_instance_info} on process: {psutil.Process(returned_instance_info.pid)}" + ) SERVER_INSTANCES.put(returned_instance_info) logging.info(f"enqueue {returned_instance_info}") From cac88802fb7f97d1066cf0059d1e63c47b216da0 Mon Sep 17 00:00:00 2001 From: Max Nigmatulin Date: Wed, 9 Aug 2023 13:54:51 +0300 Subject: [PATCH 32/47] View server info when deque --- VSharp.ML.AIAgent/launch_servers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/VSharp.ML.AIAgent/launch_servers.py b/VSharp.ML.AIAgent/launch_servers.py index c831b1fac..b8ce24e09 100644 --- a/VSharp.ML.AIAgent/launch_servers.py +++ b/VSharp.ML.AIAgent/launch_servers.py @@ -30,7 +30,7 @@ async def dequeue_instance(request): while retry_count: try: server_info = SERVER_INSTANCES.get(timeout=1) - logging.info(f"issued {server_info}") + logging.info(f"issued {server_info}: {psutil.Process(server_info.pid)}") return web.json_response(server_info.to_json()) except Empty: logging.warning( From d80937fc8964593c11fd95ff96f6eacbc3e66f7c Mon Sep 17 00:00:00 2001 From: Max Nigmatulin Date: Wed, 9 Aug 2023 14:11:16 +0300 Subject: [PATCH 33/47] Add not ok responce logging --- VSharp.ML.AIAgent/connection/broker_conn/requests.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/VSharp.ML.AIAgent/connection/broker_conn/requests.py b/VSharp.ML.AIAgent/connection/broker_conn/requests.py index 1ccba3aeb..949267c63 100644 --- a/VSharp.ML.AIAgent/connection/broker_conn/requests.py +++ b/VSharp.ML.AIAgent/connection/broker_conn/requests.py @@ -12,7 +12,7 @@ def acquire_instance() -> ServerInstanceInfo: response, content = httplib2.Http().request(WebsocketSourceLinks.GET_WS) if response.status != 200: logging.error(f"{response.status} with {content=} on acuire_instance call") - raise RuntimeError(f"Not ok response: {response.status}") + raise RuntimeError(f"Not ok response: {response}, {content}") aquired_instance = ServerInstanceInfo.from_json(json.loads(content.decode("utf-8"))) logging.info(f"acquired ws: {aquired_instance}") return aquired_instance From ab8c18704cb6e8b3849c1dbca91d4a1e63e8f65a Mon Sep 17 00:00:00 2001 From: Max Nigmatulin Date: Wed, 9 Aug 2023 15:00:42 +0300 Subject: [PATCH 34/47] Launch servers on dequeue --- VSharp.ML.AIAgent/launch_servers.py | 132 +++++++++++++++------------- 1 file changed, 69 insertions(+), 63 deletions(-) diff --git a/VSharp.ML.AIAgent/launch_servers.py b/VSharp.ML.AIAgent/launch_servers.py index b8ce24e09..8716e6161 100644 --- a/VSharp.ML.AIAgent/launch_servers.py +++ b/VSharp.ML.AIAgent/launch_servers.py @@ -30,6 +30,10 @@ async def dequeue_instance(request): while retry_count: try: server_info = SERVER_INSTANCES.get(timeout=1) + assert server_info.pid is Undefined + server_info = run_server_instance( + port=server_info.port, start_server=START_SERVERS + ) logging.info(f"issued {server_info}: {psutil.Process(server_info.pid)}") return web.json_response(server_info.to_json()) except Empty: @@ -50,52 +54,9 @@ async def enqueue_instance(request): logging.info(f"got {returned_instance_info} from client") if FeatureConfig.ON_GAME_SERVER_RESTART.enabled: - kill_server(returned_instance_info.pid, forget=True) - - wait_for_reset_retries = ( - FeatureConfig.ON_GAME_SERVER_RESTART.wait_for_reset_retries - ) - while wait_for_reset_retries: - logging.info( - f"Waiting for {returned_instance_info} to die, {wait_for_reset_retries} retries left" - ) - if psutil.Process(returned_instance_info.pid).status() in ( - psutil.STATUS_DEAD, - psutil.STATUS_ZOMBIE, - ): - break - time.sleep(FeatureConfig.ON_GAME_SERVER_RESTART.wait_for_reset_time) - wait_for_reset_retries -= 1 - - if wait_for_reset_retries == 0: - raise RuntimeError(f"{returned_instance_info} could not be killed") - - logging.info(f"killed {psutil.Process(returned_instance_info.pid)}") - returned_instance_info = run_server_instance( - port=returned_instance_info.port, start_server=START_SERVERS - ) - logging.info(f"running new instance: {returned_instance_info}") - - wait_for_alive_retries = ( - FeatureConfig.ON_GAME_SERVER_RESTART.wait_for_reset_retries - ) - while wait_for_alive_retries: - logging.info( - f"Waiting for {returned_instance_info} to start, {wait_for_alive_retries} retries left" - ) - if ( - psutil.Process(returned_instance_info.pid).status() - == psutil.STATUS_RUNNING - ): - break - time.sleep(FeatureConfig.ON_GAME_SERVER_RESTART.wait_for_reset_time) - wait_for_alive_retries -= 1 - - if wait_for_alive_retries == 0: - raise RuntimeError(f"{returned_instance_info} has not resurrected") - - logging.info( - f"we should have run {returned_instance_info} on process: {psutil.Process(returned_instance_info.pid)}" + kill_server(returned_instance_info) + returned_instance_info = ServerInstanceInfo( + returned_instance_info.port, returned_instance_info.ws_url, pid=Undefined ) SERVER_INSTANCES.put(returned_instance_info) @@ -133,16 +94,39 @@ def run_server_instance(port: int, start_server: bool) -> ServerInstanceInfo: "--checkactualcoverage", "--port", ] - server_pid = Undefined - if start_server: - proc = subprocess.Popen( - launch_server + [str(port)], - start_new_session=True, - cwd=SERVER_WORKING_DIR, + ws_url = get_socket_url(port) + if not start_server: + return ServerInstanceInfo(port, ws_url, pid=Undefined) + + proc = subprocess.Popen( + launch_server + [str(port)], + start_new_session=True, + cwd=SERVER_WORKING_DIR, + ) + server_pid = proc.pid + PROCS.append(server_pid) + logging.info( + f"running new instance on {port=} with {server_pid=}:" + + f"{server_pid}: " + + " ".join(launch_server + [str(port)]) + ) + + wait_for_alive_retries = FeatureConfig.ON_GAME_SERVER_RESTART.wait_for_reset_retries + while wait_for_alive_retries: + logging.info( + f"Waiting for {server_pid}:{ws_url} to start, {wait_for_alive_retries} retries left" ) - server_pid = proc.pid - PROCS.append(server_pid) - logging.info(f"{server_pid}: " + " ".join(launch_server + [str(port)])) + if psutil.Process(server_pid).status() == psutil.STATUS_RUNNING: + break + time.sleep(FeatureConfig.ON_GAME_SERVER_RESTART.wait_for_reset_time) + wait_for_alive_retries -= 1 + + if wait_for_alive_retries == 0: + raise RuntimeError(f"{server_pid}:{ws_url} has not resurrected") + + logging.info( + f"we should have run {ws_url} on process: {psutil.Process(server_pid)}" + ) ws_url = get_socket_url(port) return ServerInstanceInfo(port, ws_url, server_pid) @@ -159,19 +143,41 @@ def run_servers( return servers_info -def kill_server(pid: int, forget): +def kill_server(server_instance: ServerInstanceInfo): + os.kill(server_instance.pid, signal.SIGKILL) + PROCS.remove(server_instance.pid) + + wait_for_reset_retries = FeatureConfig.ON_GAME_SERVER_RESTART.wait_for_reset_retries + while wait_for_reset_retries: + logging.info( + f"Waiting for {server_instance} to die, {wait_for_reset_retries} retries left" + ) + if psutil.Process(server_instance.pid).status() in ( + psutil.STATUS_DEAD, + psutil.STATUS_ZOMBIE, + ): + break + time.sleep(FeatureConfig.ON_GAME_SERVER_RESTART.wait_for_reset_time) + wait_for_reset_retries -= 1 + + if wait_for_reset_retries == 0: + raise RuntimeError(f"{server_instance} could not be killed") + + logging.info(f"killed {psutil.Process(server_instance.pid)}") + + +def kill_process(pid: int): os.kill(pid, signal.SIGKILL) - if forget: - PROCS.remove(pid) + PROCS.remove(pid) @contextmanager -def server_manager(server_queue: Queue[ServerInstanceInfo], start_servers: bool): +def server_manager(server_queue: Queue[ServerInstanceInfo]): global PROCS servers_info = run_servers( num_inst=GeneralConfig.SERVER_COUNT, start_port=ServerConfig.VSHARP_INSTANCES_START_PORT, - start_servers=start_servers, + start_servers=False, ) for server_info in servers_info: @@ -179,8 +185,8 @@ def server_manager(server_queue: Queue[ServerInstanceInfo], start_servers: bool) try: yield finally: - for proc in PROCS: - kill_server(proc, forget=False) + for proc in list(PROCS): + kill_process(proc) PROCS = [] @@ -200,7 +206,7 @@ def main(): PROCS = [] RESULTS = [] - with server_manager(SERVER_INSTANCES, start_servers=START_SERVERS): + with server_manager(SERVER_INSTANCES): app = web.Application() app.add_routes(routes) web.run_app(app, port=BrokerConfig.BROKER_PORT) From d5eb30732ba4fadde1a5ef11e5d38d8d4561f6fe Mon Sep 17 00:00:00 2001 From: Max Nigmatulin Date: Thu, 10 Aug 2023 11:46:26 +0300 Subject: [PATCH 35/47] Add status check --- .../connection/broker_conn/socket_manager.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/VSharp.ML.AIAgent/connection/broker_conn/socket_manager.py b/VSharp.ML.AIAgent/connection/broker_conn/socket_manager.py index 76c6f656c..2c362b984 100644 --- a/VSharp.ML.AIAgent/connection/broker_conn/socket_manager.py +++ b/VSharp.ML.AIAgent/connection/broker_conn/socket_manager.py @@ -1,4 +1,5 @@ import logging +import os import time from contextlib import contextmanager, suppress @@ -34,9 +35,20 @@ def wait_for_connection(server_instance: ServerInstanceInfo): ) +def get_pid_status(pid: int): + f = os.popen(f"top -pid {pid} -n 1 -l 1", "r") + text = f.read() + for potential_status in ("sleeping", "running", "zombie", "dead"): + if text.find(potential_status) != -1: + return potential_status + raise RuntimeError(f"Unknow status for {pid=}") + + @contextmanager def game_server_socket_manager(): server_instance = acquire_instance() + + logging.info(f"{server_instance} status is {get_pid_status(server_instance.pid)}") socket = wait_for_connection(server_instance) try: From 2162201c78247d70409b1d39d77d3e4a423982ee Mon Sep 17 00:00:00 2001 From: Max Nigmatulin Date: Thu, 10 Aug 2023 11:53:23 +0300 Subject: [PATCH 36/47] Upd status check --- VSharp.ML.AIAgent/connection/broker_conn/socket_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/VSharp.ML.AIAgent/connection/broker_conn/socket_manager.py b/VSharp.ML.AIAgent/connection/broker_conn/socket_manager.py index 2c362b984..64ed31ca0 100644 --- a/VSharp.ML.AIAgent/connection/broker_conn/socket_manager.py +++ b/VSharp.ML.AIAgent/connection/broker_conn/socket_manager.py @@ -41,7 +41,7 @@ def get_pid_status(pid: int): for potential_status in ("sleeping", "running", "zombie", "dead"): if text.find(potential_status) != -1: return potential_status - raise RuntimeError(f"Unknow status for {pid=}") + raise RuntimeError(f"Unknow status for {pid=}: {text}") @contextmanager From 3d30cca4cceb42fda09947af0d339103085fb4f6 Mon Sep 17 00:00:00 2001 From: Max Nigmatulin Date: Thu, 10 Aug 2023 12:35:47 +0300 Subject: [PATCH 37/47] Add error ignoring --- VSharp.ML.AIAgent/connection/broker_conn/socket_manager.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/VSharp.ML.AIAgent/connection/broker_conn/socket_manager.py b/VSharp.ML.AIAgent/connection/broker_conn/socket_manager.py index 64ed31ca0..219c70602 100644 --- a/VSharp.ML.AIAgent/connection/broker_conn/socket_manager.py +++ b/VSharp.ML.AIAgent/connection/broker_conn/socket_manager.py @@ -17,7 +17,11 @@ def wait_for_connection(server_instance: ServerInstanceInfo): retries_left = GameServerConnectorConfig.WAIT_FOR_SOCKET_RECONNECTION_MAX_RETRIES while retries_left: - with suppress(ConnectionRefusedError, ConnectionResetError): + with suppress( + ConnectionRefusedError, + ConnectionResetError, + websocket.WebSocketTimeoutException, + ): ws.settimeout(GameServerConnectorConfig.CREATE_CONNECTION_TIMEOUT_SEC) ws.connect( server_instance.ws_url, From 0550cdda4320a2bdd8a4af2ef63b0968a046436a Mon Sep 17 00:00:00 2001 From: Semyon Grigorev Date: Thu, 10 Aug 2023 15:00:15 +0300 Subject: [PATCH 38/47] Port code --- .../connection/broker_conn/socket_manager.py | 7 +- VSharp.ML.AIAgent/ml/common_model.py | 209 ------------------ VSharp.ML.AIAgent/ml/common_model/models.py | 89 ++++++++ VSharp.ML.AIAgent/ml/common_model/paths.py | 5 + VSharp.ML.AIAgent/ml/common_model/utils.py | 86 +++++++ VSharp.ML.AIAgent/ml/common_model/wrapper.py | 56 +++++ .../ml/model_wrappers/nnwrapper.py | 2 +- VSharp.ML.AIAgent/ml/models.py | 79 +++++++ .../ml/predict_state_vector_hetero.py | 2 +- .../run_common_model_training.py | 83 +++++++ 10 files changed, 401 insertions(+), 217 deletions(-) delete mode 100644 VSharp.ML.AIAgent/ml/common_model.py create mode 100644 VSharp.ML.AIAgent/ml/common_model/models.py create mode 100644 VSharp.ML.AIAgent/ml/common_model/paths.py create mode 100644 VSharp.ML.AIAgent/ml/common_model/utils.py create mode 100644 VSharp.ML.AIAgent/ml/common_model/wrapper.py create mode 100644 VSharp.ML.AIAgent/run_common_model_training.py diff --git a/VSharp.ML.AIAgent/connection/broker_conn/socket_manager.py b/VSharp.ML.AIAgent/connection/broker_conn/socket_manager.py index 219c70602..958525b43 100644 --- a/VSharp.ML.AIAgent/connection/broker_conn/socket_manager.py +++ b/VSharp.ML.AIAgent/connection/broker_conn/socket_manager.py @@ -40,12 +40,7 @@ def wait_for_connection(server_instance: ServerInstanceInfo): def get_pid_status(pid: int): - f = os.popen(f"top -pid {pid} -n 1 -l 1", "r") - text = f.read() - for potential_status in ("sleeping", "running", "zombie", "dead"): - if text.find(potential_status) != -1: - return potential_status - raise RuntimeError(f"Unknow status for {pid=}: {text}") + return "Unknown" @contextmanager diff --git a/VSharp.ML.AIAgent/ml/common_model.py b/VSharp.ML.AIAgent/ml/common_model.py deleted file mode 100644 index 340a475ba..000000000 --- a/VSharp.ML.AIAgent/ml/common_model.py +++ /dev/null @@ -1,209 +0,0 @@ -import torch -import os -import re - -from torch import nn -from torch_geometric.nn import ( - GATConv, - SAGEConv, -) -from torchvision.ops import MLP - -from common.game import GameState -from ml.data_loader_compact import ServerDataloaderHeteroVector -from ml.model_wrappers.protocols import Predictor -from ml.predict_state_vector_hetero import PredictStateVectorHetGNN -import csv - -from learning.play_game import play_game -from config import GeneralConfig -from connection.game_server_conn.utils import MapsType -from ml.models import SAGEConvModel - -csv_path = "../report/epochs_tables/" -models_path = "../report/epochs_best/" - - -class CommonModel(torch.nn.Module): - def __init__( - self, - hidden_channels, - num_gv_layers=2, - num_sv_layers=2, - ): - super().__init__() - self.gv_layers = nn.ModuleList() - self.gv_layers.append(SAGEConv(-1, hidden_channels)) - for i in range(num_gv_layers - 1): - sage_gv = SAGEConv(-1, hidden_channels) - self.gv_layers.append(sage_gv) - - self.sv_layers = nn.ModuleList() - self.sv_layers.append(SAGEConv(-1, hidden_channels)) - for i in range(num_sv_layers - 1): - sage_sv = SAGEConv(-1, hidden_channels) - self.sv_layers.append(sage_sv) - - self.history1 = GATConv((-1, -1), hidden_channels, add_self_loops=False) - - self.in1 = SAGEConv((-1, -1), hidden_channels) - - self.sv_layers2 = nn.ModuleList() - self.sv_layers2.append(SAGEConv(-1, hidden_channels)) - for i in range(num_sv_layers - 1): - sage_sv = SAGEConv(-1, hidden_channels) - self.sv_layers2.append(sage_sv) - - self.mlp = MLP(hidden_channels, [1]) - - def forward(self, x_dict, edge_index_dict, edge_attr_dict): - game_x = self.gv_layers[0]( - x_dict["game_vertex"], - edge_index_dict[("game_vertex", "to", "game_vertex")], - ).relu() - for layer in self.gv_layers[1:]: - game_x = layer( - game_x, - edge_index_dict[("game_vertex", "to", "game_vertex")], - ).relu() - - state_x = self.sv_layers[0]( - x_dict["state_vertex"], - edge_index_dict[("state_vertex", "parent_of", "state_vertex")], - ).relu() - for layer in self.sv_layers[1:]: - state_x = layer( - state_x, - edge_index_dict[("state_vertex", "parent_of", "state_vertex")], - ).relu() - - history_x = self.history1( - (game_x, state_x), - edge_index_dict[("game_vertex", "history", "state_vertex")], - edge_attr_dict, - size=(game_x.size(0), state_x.size(0)), - ).relu() - - in_x = self.in1( - (game_x, history_x), edge_index_dict[("game_vertex", "in", "state_vertex")] - ).relu() - - state_x = self.sv_layers2[0]( - in_x, - edge_index_dict[("state_vertex", "parent_of", "state_vertex")], - ).relu() - for layer in self.sv_layers2[1:]: - state_x = layer( - state_x, - edge_index_dict[("state_vertex", "parent_of", "state_vertex")], - ).relu() - x = self.mlp(in_x) - return x - - -def euclidean_dist(y_pred, y_true): - y_pred_min, ind1 = torch.min(y_pred, dim=0) - y_pred_norm = y_pred - y_pred_min - - y_true_min, ind1 = torch.min(y_true, dim=0) - y_true_norm = y_true - y_true_min - - return torch.sqrt(torch.sum((y_pred_norm - y_true_norm) ** 2)) - - -class CommonModelWrapper(Predictor): - def __init__(self, model: torch.nn.Module, best_models: dict) -> None: - self.model = model - self.best_models = best_models - self._model = model - - def name(self) -> str: - return "MY AWESOME MODEL" - - def model(self): - return self._model - - def predict(self, input: GameState, map_name): - hetero_input, state_map = ServerDataloaderHeteroVector.convert_input_to_tensor( - input - ) - assert self.model is not None - - next_step_id = PredictStateVectorHetGNN.predict_state_single_out( - self.model, hetero_input, state_map - ) - - back_prop(self.best_models[map_name], self.model, hetero_input) - - del hetero_input - return next_step_id - - -def get_last_epoch_num(path): - epochs = list(map(lambda x: re.findall("[0-9]+", x), os.listdir(path))) - return str(sorted(epochs)[-1][0]) - - -def csv2best_models(): - best_models = {} - values = [] - models_names = [] - - with open(csv_path + get_last_epoch_num(csv_path) + ".csv", "r") as csv_file: - csv_reader = csv.reader(csv_file) - map_names = next(csv_reader)[1:] - for row in csv_reader: - models_names.append(row[0]) - int_row = list(map(lambda x: int(x), row[1:])) - values.append(int_row) - val, ind = torch.max(torch.tensor(values), dim=0) - for i in range(len(map_names)): - best_models[map_names[i]] = models_names[ind[i]] - return best_models - - -def back_prop(best_model, model, data): - model.train() - data.to(GeneralConfig.DEVICE) - optimizer.zero_grad() - ref_model = SAGEConvModel(16) - ref_model.load_state_dict( - torch.load( - models_path - + "epoch_" - + get_last_epoch_num(models_path) - + "/" - + best_model - + ".pth" - ) - ) - ref_model.to(GeneralConfig.DEVICE) - out = model(data.x_dict, data.edge_index_dict, data.edge_attr_dict) - y_true = ref_model(data.x_dict, data.edge_index_dict, data.edge_attr_dict) - - loss = criterion(out, y_true) - loss.backward() - optimizer.step() - - -model = CommonModel(32) -model.to(GeneralConfig.DEVICE) - -cmwrapper = CommonModelWrapper(model, csv2best_models()) -lr = 0.0001 -epochs = 10 -optimizer = torch.optim.Adam(model.parameters(), lr=lr) -criterion = euclidean_dist - - -def train(): - for epoch in range(epochs): - # some function with some parameters - play_game( - with_predictor=cmwrapper, - max_steps=GeneralConfig.MAX_STEPS, - maps_type=MapsType.TRAIN, - ) - - -train() diff --git a/VSharp.ML.AIAgent/ml/common_model/models.py b/VSharp.ML.AIAgent/ml/common_model/models.py new file mode 100644 index 000000000..ec00d4e6d --- /dev/null +++ b/VSharp.ML.AIAgent/ml/common_model/models.py @@ -0,0 +1,89 @@ +import torch +from torch import nn +from torch_geometric.nn import ( + GATConv, + SAGEConv, + TAGConv, +) +from torchvision.ops import MLP + + +class CommonModel(torch.nn.Module): + def __init__( + self, + hidden_channels, + num_gv_layers=2, + num_sv_layers=2, + ): + super().__init__() + self.tag_conv1 = TAGConv(5, hidden_channels, 2) + self.tag_conv2 = TAGConv(6, hidden_channels, 3) + self.gv_layers = nn.ModuleList() + self.gv_layers.append(self.tag_conv1) + self.gv_layers.append(SAGEConv(-1, hidden_channels)) + for i in range(num_gv_layers - 1): + sage_gv = SAGEConv(-1, hidden_channels) + self.gv_layers.append(sage_gv) + + self.sv_layers = nn.ModuleList() + self.sv_layers.append(self.tag_conv2) + self.sv_layers.append(SAGEConv(-1, hidden_channels)) + for i in range(num_sv_layers - 1): + sage_sv = SAGEConv(-1, hidden_channels) + self.sv_layers.append(sage_sv) + + self.history1 = GATConv((-1, -1), hidden_channels, add_self_loops=False) + + self.in1 = SAGEConv((-1, -1), hidden_channels) + + self.sv_layers2 = nn.ModuleList() + self.sv_layers2.append(SAGEConv(-1, hidden_channels)) + for i in range(num_sv_layers - 1): + sage_sv = SAGEConv(-1, hidden_channels) + self.sv_layers2.append(sage_sv) + + self.mlp = MLP(hidden_channels, [1]) + + def forward(self, x_dict, edge_index_dict, edge_attr_dict): + game_x = self.gv_layers[0]( + x_dict["game_vertex"], + edge_index_dict[("game_vertex", "to", "game_vertex")], + ).relu() + for layer in self.gv_layers[1:]: + game_x = layer( + game_x, + edge_index_dict[("game_vertex", "to", "game_vertex")], + ).relu() + + state_x = self.sv_layers[0]( + x_dict["state_vertex"], + edge_index_dict[("state_vertex", "parent_of", "state_vertex")], + ).relu() + for layer in self.sv_layers[1:]: + state_x = layer( + state_x, + edge_index_dict[("state_vertex", "parent_of", "state_vertex")], + ).relu() + + history_x = self.history1( + (game_x, state_x), + edge_index_dict[("game_vertex", "history", "state_vertex")], + edge_attr_dict, + size=(game_x.size(0), state_x.size(0)), + ).relu() + + in_x = self.in1( + (game_x, history_x), edge_index_dict[("game_vertex", "in", "state_vertex")] + ).relu() + + state_x = self.sv_layers2[0]( + in_x, + edge_index_dict[("state_vertex", "parent_of", "state_vertex")], + ).relu() + for layer in self.sv_layers2[1:]: + state_x = layer( + state_x, + edge_index_dict[("state_vertex", "parent_of", "state_vertex")], + ).relu() + x = self.mlp(in_x) + return x diff --git a/VSharp.ML.AIAgent/ml/common_model/paths.py b/VSharp.ML.AIAgent/ml/common_model/paths.py new file mode 100644 index 000000000..4ee8c313e --- /dev/null +++ b/VSharp.ML.AIAgent/ml/common_model/paths.py @@ -0,0 +1,5 @@ +import os + +csv_path = os.path.join("report", "epochs_tables") +models_path = os.path.join("report", "epochs_best") +common_models_path = os.path.join("report", "common_models") diff --git a/VSharp.ML.AIAgent/ml/common_model/utils.py b/VSharp.ML.AIAgent/ml/common_model/utils.py new file mode 100644 index 000000000..989035290 --- /dev/null +++ b/VSharp.ML.AIAgent/ml/common_model/utils.py @@ -0,0 +1,86 @@ +import ast +import csv +import os +import re + +import torch + +from config import GeneralConfig +from ml.common_model.paths import csv_path, models_path +from ml.models import SAGEConvModel + + +def euclidean_dist(y_pred, y_true): + if len(y_pred) > 1: + y_pred_min, ind1 = torch.min(y_pred, dim=0) + y_pred_norm = y_pred - y_pred_min + + y_true_min, ind1 = torch.min(y_true, dim=0) + y_true_norm = y_true - y_true_min + return torch.sqrt(torch.sum((y_pred_norm - y_true_norm) ** 2)) + else: + return 0 + + +def get_last_epoch_num(path): + epochs = list(map(lambda x: re.findall("[0-9]+", x), os.listdir(path))) + return str(sorted(epochs)[-1][0]) + + +def get_tuple_for_max(t): + t[1] *= -1 + t[3] *= -1 + return t + + +def csv2best_models(): + best_models = {} + for epoch_num in range(1, len(os.listdir(csv_path)) + 1): + path_to_csv = os.path.join(csv_path, str(epoch_num) + ".csv") + with open(path_to_csv, "r") as csv_file: + csv_reader = csv.reader(csv_file) + map_names = next(csv_reader)[1:] + models = [] + for row in csv_reader: + models_stat = dict() + # int_row = list(map(lambda x: tuple(map(lambda y: int(y), x[1:-1].split(", "))), row[1:])) + int_row = list(map(lambda x: ast.literal_eval(x), row[1:])) + for i in range(len(int_row)): + models_stat[map_names[i]] = int_row[i] + models.append((row[0], models_stat)) + + for map_name in map_names: + best_model = max(models, key=(lambda m: m[1][map_name])) + best_model_name = best_model[0] + best_model_score = best_model[1] + ref_model = SAGEConvModel(16) + path_to_model = os.path.join( + models_path, + "epoch_" + str(epoch_num), + best_model_name + ".pth", + ) + ref_model.load_state_dict(torch.load(path_to_model)) + ref_model.to(GeneralConfig.DEVICE) + best_models[map_name] = (ref_model, best_model_score[map_name]) + return best_models + + +def back_prop(best_model, model, data, optimizer, criterion): + model.train() + data.to(GeneralConfig.DEVICE) + optimizer.zero_grad() + out = model(data.x_dict, data.edge_index_dict, data.edge_attr_dict)["state_vertex"] + y_true = best_model(data.x_dict, data.edge_index_dict, data.edge_attr_dict) + # print(y_true, "\n", out) + if type(y_true) is dict: + y_true = y_true["state_vertex"] + if abs(torch.min(y_true)) > 1: + y_true = 1 / y_true + # print(out, "\n", y_true) + loss = criterion(out, y_true) + if loss == 0: + return 0 + # print('loss:', loss) + loss.backward() + optimizer.step() + return loss diff --git a/VSharp.ML.AIAgent/ml/common_model/wrapper.py b/VSharp.ML.AIAgent/ml/common_model/wrapper.py new file mode 100644 index 000000000..7f9b87151 --- /dev/null +++ b/VSharp.ML.AIAgent/ml/common_model/wrapper.py @@ -0,0 +1,56 @@ +import torch +import copy +import logging + +from common.game import GameState +from ml.data_loader_compact import ServerDataloaderHeteroVector +from ml.model_wrappers.protocols import Predictor +from ml.predict_state_vector_hetero import PredictStateVectorHetGNN +from ml.common_model.utils import back_prop + + +class CommonModelWrapper(Predictor): + def __init__( + self, model: torch.nn.Module, best_models: dict, optimizer, criterion + ) -> None: + self.best_models = best_models + self._model = model + # self.name = sum(torch.cat([p.view(-1) for p in self.model.parameters()], dim=0)) + self.optimizer = optimizer + self.criterion = criterion + + def name(self): + return "Common model" + + def update(self, map_name, map_result): + map_result = map_result.actual_coverage_percent + if self.best_models[map_name][1] <= map_result: + logging.info( + f"The model with result = {self.best_models[map_name][1]} was replaced with the model with " + f"result = {map_result} on the map {map_name}" + ) + self.best_models[map_name] = (copy.deepcopy(self._model), map_result) + + def model(self): + return self._model + + def predict(self, input: GameState, map_name): + hetero_input, state_map = ServerDataloaderHeteroVector.convert_input_to_tensor( + input + ) + assert self._model is not None + + next_step_id = PredictStateVectorHetGNN.predict_state_with_dict( + self._model, hetero_input, state_map + ) + + back_prop( + self.best_models[map_name][0], + self._model, + hetero_input, + self.optimizer, + self.criterion, + ) + + del hetero_input + return next_step_id diff --git a/VSharp.ML.AIAgent/ml/model_wrappers/nnwrapper.py b/VSharp.ML.AIAgent/ml/model_wrappers/nnwrapper.py index 6680e0acc..850382c8e 100644 --- a/VSharp.ML.AIAgent/ml/model_wrappers/nnwrapper.py +++ b/VSharp.ML.AIAgent/ml/model_wrappers/nnwrapper.py @@ -26,7 +26,7 @@ def predict(self, input: GameState, map_name): ) assert self._model is not None - next_step_id = PredictStateVectorHetGNN.predict_state_ekaterina( + next_step_id = PredictStateVectorHetGNN.predict_state_with_dict( self._model, hetero_input, state_map ) del hetero_input diff --git a/VSharp.ML.AIAgent/ml/models.py b/VSharp.ML.AIAgent/ml/models.py index 1e8fde0f3..45841c3b7 100644 --- a/VSharp.ML.AIAgent/ml/models.py +++ b/VSharp.ML.AIAgent/ml/models.py @@ -18,6 +18,9 @@ to_hetero, ) +from torchvision.ops import MLP + + from learning.timer.wrapper import timeit NUM_PREDICTED_VALUES = 4 @@ -684,3 +687,79 @@ def forward(self, x_dict, edge_index_dict): z_dict["state_vertex"] = self.state_encoder(x_dict, edge_index_dict) z_dict["game_vertex"] = x_dict["game_vertex"] return z_dict + + +class SAGEConvModel(torch.nn.Module): + def __init__( + self, + hidden_channels, + num_gv_layers=2, + num_sv_layers=2, + ): + super().__init__() + self.gv_layers = nn.ModuleList() + self.gv_layers.append(SAGEConv(-1, hidden_channels)) + for i in range(num_gv_layers - 1): + sage_gv = SAGEConv(-1, hidden_channels) + self.gv_layers.append(sage_gv) + + self.sv_layers = nn.ModuleList() + self.sv_layers.append(SAGEConv(-1, hidden_channels)) + for i in range(num_sv_layers - 1): + sage_sv = SAGEConv(-1, hidden_channels) + self.sv_layers.append(sage_sv) + + self.history1 = GATConv((-1, -1), hidden_channels, add_self_loops=False) + self.in1 = SAGEConv((-1, -1), hidden_channels) + + self.sv_layers2 = nn.ModuleList() + self.sv_layers2.append(SAGEConv(-1, hidden_channels)) + for i in range(num_sv_layers - 1): + sage_sv = SAGEConv(-1, hidden_channels) + self.sv_layers2.append(sage_sv) + self.mlp = MLP(hidden_channels, [1]) + + @timeit + def forward(self, x_dict, edge_index_dict, edge_attr_dict): + game_x = self.gv_layers[0]( + x_dict["game_vertex"], + edge_index_dict[("game_vertex", "to", "game_vertex")], + ).relu() + for layer in self.gv_layers[1:]: + game_x = layer( + game_x, + edge_index_dict[("game_vertex", "to", "game_vertex")], + ).relu() + + state_x = self.sv_layers[0]( + x_dict["state_vertex"], + edge_index_dict[("state_vertex", "parent_of", "state_vertex")], + ).relu() + for layer in self.sv_layers[1:]: + state_x = layer( + state_x, + edge_index_dict[("state_vertex", "parent_of", "state_vertex")], + ).relu() + + history_x = self.history1( + (game_x, state_x), + edge_index_dict[("game_vertex", "history", "state_vertex")], + edge_attr_dict, + size=(game_x.size(0), state_x.size(0)), + ).relu() + + in_x = self.in1( + (game_x, history_x), edge_index_dict[("game_vertex", "in", "state_vertex")] + ).relu() + + state_x = self.sv_layers2[0]( + in_x, + edge_index_dict[("state_vertex", "parent_of", "state_vertex")], + ).relu() + for layer in self.sv_layers2[1:]: + state_x = layer( + state_x, + edge_index_dict[("state_vertex", "parent_of", "state_vertex")], + ).relu() + x = self.mlp(in_x) + return x diff --git a/VSharp.ML.AIAgent/ml/predict_state_vector_hetero.py b/VSharp.ML.AIAgent/ml/predict_state_vector_hetero.py index e35d4aece..e95072ac4 100644 --- a/VSharp.ML.AIAgent/ml/predict_state_vector_hetero.py +++ b/VSharp.ML.AIAgent/ml/predict_state_vector_hetero.py @@ -87,7 +87,7 @@ def predict_state(model, data: HeteroData, state_map: dict[int, int]) -> int: return state_map[int(out["state_vertex"].argmax(dim=0)[0])] @staticmethod - def predict_state_ekaterina( + def predict_state_with_dict( model: torch.nn.Module, data: HeteroData, state_map: dict[int, int] ) -> int: """Gets state id from model and heterogeneous graph diff --git a/VSharp.ML.AIAgent/run_common_model_training.py b/VSharp.ML.AIAgent/run_common_model_training.py new file mode 100644 index 000000000..76fa5b9cb --- /dev/null +++ b/VSharp.ML.AIAgent/run_common_model_training.py @@ -0,0 +1,83 @@ +import logging +import os +from pathlib import Path + +import torch + +from config import GeneralConfig +from connection.game_server_conn.utils import MapsType +from epochs_statistics.tables import create_pivot_table, table_to_string +from learning.play_game import play_game +from ml.common_model.models import CommonModel +from ml.common_model.utils import csv2best_models, euclidean_dist +from ml.common_model.wrapper import CommonModelWrapper +from ml.fileop import save_model +from ml.common_model.paths import common_models_path +from ml.utils import load_model + + +LOG_PATH = Path("./ml_app.log") +TABLES_PATH = Path("./ml_tables.log") +COMMON_MODELS_PATH = Path(common_models_path) + +logging.basicConfig( + level=GeneralConfig.LOGGER_LEVEL, + filename=LOG_PATH, + filemode="a", + format="%(asctime)s - p%(process)d: %(name)s - [%(levelname)s]: %(message)s", +) + +if not COMMON_MODELS_PATH.exists(): + os.makedirs(common_models_path) + + +def create_file(file: Path): + open(file, "w").close() + + +def append_to_file(file: Path, s: str): + with open(file, "a") as file: + file.write(s) + + +def main(): + lr = 0.0000001 + epochs = 3 + hidden_channels = 32 + num_gv_layers = 2 + num_sv_layers = 2 + print(GeneralConfig.DEVICE) + # model = CommonModel(hidden_channels, num_gv_layers, num_sv_layers) + # model.forward(*ml.onnx.onnx_import.create_torch_dummy_input()) + path_to_model = os.path.join( + "ml", + "pretrained_models", + "-262.75775990410693.pth", + ) + + model = load_model(path_to_model, model=GeneralConfig.EXPORT_MODEL_INIT()) + model.to(GeneralConfig.DEVICE) + + create_file(TABLES_PATH) + create_file(LOG_PATH) + print(model) + optimizer = torch.optim.Adam(model.parameters(), lr=lr) + criterion = euclidean_dist + cmwrapper = CommonModelWrapper(model, csv2best_models(), optimizer, criterion) + + for epoch in range(epochs): + result = play_game( + with_predictor=cmwrapper, + max_steps=GeneralConfig.MAX_STEPS, + maps_type=MapsType.TRAIN, + ) + table, _, _ = create_pivot_table({cmwrapper: result}) + table = table_to_string(table) + append_to_file(TABLES_PATH, f"Epoch#{epoch}" + "\n") + append_to_file(TABLES_PATH, table + "\n") + path_to_model = os.path.join(common_models_path, str(epoch + 1)) + save_model(model=cmwrapper.model(), to=Path(path_to_model)) + + +if __name__ == "__main__": + main() From add882d98eb54f2a8eb9bbe7358b79b656684c96 Mon Sep 17 00:00:00 2001 From: Semyon Grigorev Date: Thu, 10 Aug 2023 15:22:09 +0300 Subject: [PATCH 39/47] added some code --- VSharp.ML.AIAgent/learning/play_game.py | 1 + VSharp.ML.AIAgent/run_common_model_training.py | 1 + 2 files changed, 2 insertions(+) diff --git a/VSharp.ML.AIAgent/learning/play_game.py b/VSharp.ML.AIAgent/learning/play_game.py index 34f5c519a..b6c52a56f 100644 --- a/VSharp.ML.AIAgent/learning/play_game.py +++ b/VSharp.ML.AIAgent/learning/play_game.py @@ -83,6 +83,7 @@ def play_map( actual_coverage_percent=actual_coverage, ) + with_predictor.update(with_connector.map.MapName, model_result) return model_result, end_time - start_time diff --git a/VSharp.ML.AIAgent/run_common_model_training.py b/VSharp.ML.AIAgent/run_common_model_training.py index 76fa5b9cb..2936ab67b 100644 --- a/VSharp.ML.AIAgent/run_common_model_training.py +++ b/VSharp.ML.AIAgent/run_common_model_training.py @@ -13,6 +13,7 @@ from ml.common_model.wrapper import CommonModelWrapper from ml.fileop import save_model from ml.common_model.paths import common_models_path +from ml.model_wrappers.protocols import Predictor from ml.utils import load_model From 7ca8d007d7c4193c58faffaed910081a8cf82b8c Mon Sep 17 00:00:00 2001 From: Anya497 Date: Thu, 10 Aug 2023 15:44:51 +0300 Subject: [PATCH 40/47] added model validation and best_models updating after each epoch --- VSharp.ML.AIAgent/ml/common_model/wrapper.py | 42 ++++++++++++++----- .../run_common_model_training.py | 27 ++++++++---- 2 files changed, 51 insertions(+), 18 deletions(-) diff --git a/VSharp.ML.AIAgent/ml/common_model/wrapper.py b/VSharp.ML.AIAgent/ml/common_model/wrapper.py index 7f9b87151..bee34e3f9 100644 --- a/VSharp.ML.AIAgent/ml/common_model/wrapper.py +++ b/VSharp.ML.AIAgent/ml/common_model/wrapper.py @@ -10,14 +10,10 @@ class CommonModelWrapper(Predictor): - def __init__( - self, model: torch.nn.Module, best_models: dict, optimizer, criterion - ) -> None: + def __init__(self, model: torch.nn.Module, best_models: dict) -> None: self.best_models = best_models self._model = model # self.name = sum(torch.cat([p.view(-1) for p in self.model.parameters()], dim=0)) - self.optimizer = optimizer - self.criterion = criterion def name(self): return "Common model" @@ -44,13 +40,37 @@ def predict(self, input: GameState, map_name): self._model, hetero_input, state_map ) - back_prop( - self.best_models[map_name][0], - self._model, - hetero_input, - self.optimizer, - self.criterion, + del hetero_input + return next_step_id + + +class BestModelsWrapper(Predictor): + def __init__(self, model: torch.nn.Module, best_models: dict, optimizer, criterion) -> None: + self.best_models = best_models + self._model = model + self.optimizer = optimizer + self.criterion = criterion + + def name(self): + return "Common model" + + def update(self, map_name, map_result): + pass + + def model(self): + return self._model + + def predict(self, input: GameState, map_name): + hetero_input, state_map = ServerDataloaderHeteroVector.convert_input_to_tensor( + input + ) + assert self._model is not None + + next_step_id = PredictStateVectorHetGNN.predict_state_single_out( + self.best_models[map_name][0], hetero_input, state_map ) + back_prop(self.best_models[map_name][0], self._model, hetero_input, self.optimizer, self.criterion) + del hetero_input return next_step_id diff --git a/VSharp.ML.AIAgent/run_common_model_training.py b/VSharp.ML.AIAgent/run_common_model_training.py index 2936ab67b..25679745e 100644 --- a/VSharp.ML.AIAgent/run_common_model_training.py +++ b/VSharp.ML.AIAgent/run_common_model_training.py @@ -3,6 +3,7 @@ from pathlib import Path import torch +import torch.nn as nn from config import GeneralConfig from connection.game_server_conn.utils import MapsType @@ -10,7 +11,7 @@ from learning.play_game import play_game from ml.common_model.models import CommonModel from ml.common_model.utils import csv2best_models, euclidean_dist -from ml.common_model.wrapper import CommonModelWrapper +from ml.common_model.wrapper import CommonModelWrapper, BestModelsWrapper from ml.fileop import save_model from ml.common_model.paths import common_models_path from ml.model_wrappers.protocols import Predictor @@ -43,30 +44,41 @@ def append_to_file(file: Path, s: str): def main(): lr = 0.0000001 - epochs = 3 + epochs = 1 hidden_channels = 32 num_gv_layers = 2 num_sv_layers = 2 print(GeneralConfig.DEVICE) # model = CommonModel(hidden_channels, num_gv_layers, num_sv_layers) # model.forward(*ml.onnx.onnx_import.create_torch_dummy_input()) + # path = os.path.join( + # common_models_path, + # "1", + # ) path_to_model = os.path.join( - "ml", - "pretrained_models", + common_models_path, "-262.75775990410693.pth", ) - model = load_model(path_to_model, model=GeneralConfig.EXPORT_MODEL_INIT()) + model = load_model(Path(path_to_model), model=GeneralConfig.EXPORT_MODEL_INIT()) model.to(GeneralConfig.DEVICE) create_file(TABLES_PATH) create_file(LOG_PATH) - print(model) optimizer = torch.optim.Adam(model.parameters(), lr=lr) criterion = euclidean_dist - cmwrapper = CommonModelWrapper(model, csv2best_models(), optimizer, criterion) + best_models_dict = csv2best_models() + cmwrapper = CommonModelWrapper(model, best_models_dict) + bmwrapper = BestModelsWrapper(model, best_models_dict, optimizer, criterion) for epoch in range(epochs): + # training + play_game( + with_predictor=bmwrapper, + max_steps=GeneralConfig.MAX_STEPS, + maps_type=MapsType.TRAIN, + ) + # validation result = play_game( with_predictor=cmwrapper, max_steps=GeneralConfig.MAX_STEPS, @@ -76,6 +88,7 @@ def main(): table = table_to_string(table) append_to_file(TABLES_PATH, f"Epoch#{epoch}" + "\n") append_to_file(TABLES_PATH, table + "\n") + path_to_model = os.path.join(common_models_path, str(epoch + 1)) save_model(model=cmwrapper.model(), to=Path(path_to_model)) From 26f408a907eb7f0b1edffd3cac2bdd75119ff710 Mon Sep 17 00:00:00 2001 From: Anya497 Date: Tue, 8 Aug 2023 12:28:22 +0300 Subject: [PATCH 41/47] added maps names to logs, deleted difficult maps from train dataset --- VSharp.ML.GameServer.Runner/Main.fs | 2 +- VSharp.ML.GameServer/Maps.fs | 56 ++++++++++++++--------------- 2 files changed, 28 insertions(+), 30 deletions(-) diff --git a/VSharp.ML.GameServer.Runner/Main.fs b/VSharp.ML.GameServer.Runner/Main.fs index eeef0a61f..1a9eebbda 100644 --- a/VSharp.ML.GameServer.Runner/Main.fs +++ b/VSharp.ML.GameServer.Runner/Main.fs @@ -111,7 +111,7 @@ let ws checkActualCoverage outputDirectory (webSocket : WebSocket) (context: Htt let _expectedCoverage = 100 let exploredMethodInfo = AssemblyManager.NormalizeMethod method let status,actualCoverage,message = VSharp.Test.TestResultChecker.Check(testsDir, exploredMethodInfo :?> MethodInfo, _expectedCoverage) - printfn $"Actual coverage: {actualCoverage}" + printfn $"Actual coverage for {settings.MapName}: {actualCoverage}" System.Nullable (if actualCoverage < 0 then 0u else uint actualCoverage) with e -> diff --git a/VSharp.ML.GameServer/Maps.fs b/VSharp.ML.GameServer/Maps.fs index a4d01b526..3434c6667 100644 --- a/VSharp.ML.GameServer/Maps.fs +++ b/VSharp.ML.GameServer/Maps.fs @@ -17,7 +17,6 @@ let trainMaps, validationMaps = add 0u "VSharp.ML.GameMaps.dll" CoverageZone.Method "BinarySearch" add 50u "VSharp.ML.GameMaps.dll" CoverageZone.Method "BinarySearch" - add 0u "VSharp.ML.GameMaps.dll" CoverageZone.Method "Switches1" add 0u "VSharp.ML.GameMaps.dll" CoverageZone.Method "Switches2" @@ -26,6 +25,7 @@ let trainMaps, validationMaps = add 0u "VSharp.ML.GameMaps.dll" CoverageZone.Method "Switches3" add 0u "VSharp.ML.GameMaps.dll" CoverageZone.Method "Switches4" add 0u "VSharp.ML.GameMaps.dll" CoverageZone.Method "Switches5" + add 0u "VSharp.ML.GameMaps.dll" CoverageZone.Method "NestedFors" add 0u "VSharp.ML.GameMaps.dll" CoverageZone.Method "SearchKMP" @@ -236,7 +236,7 @@ let trainMaps, validationMaps = add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "TI83LinkPort.Update" add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "AmstradGateArray.ClockCycle" - add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "AmstradGateArray.OnHSYNC" + //add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "AmstradGateArray.OnHSYNC" add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "AmstradGateArray.GetVideoBuffer" add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "CRCT_6845.ClockCycle" @@ -250,31 +250,31 @@ let trainMaps, validationMaps = add 0u "Virtu.dll" CoverageZone.Method "DiskIIController.WriteIoRegionC0C0" add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "MC68000.Disassemble" - add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "CP1610.Disassemble" - add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "CP1610.Execute" - add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "F3850.FetchInstruction" - add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "HuC6280.DisassembleExt" - add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "HuC6280.Execute" - add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "HuC6280.DisassembleCDL" + //add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "CP1610.Disassemble" + //add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "CP1610.Execute" + //add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "F3850.FetchInstruction" + //add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "HuC6280.DisassembleExt" + //add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "HuC6280.Execute" + //add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "HuC6280.DisassembleCDL" add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "I8048.Disassemble" add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "I8048.ExecuteOne" - add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "LR35902.BuildInstructionTable" - add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "LR35902.ExecuteOne" - add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "LR35902.Disassemble" - add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "LR35902.ADDS_Func" - add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "LR35902.DA_Func" + //add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "LR35902.BuildInstructionTable" + //add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "LR35902.ExecuteOne" + //add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "LR35902.Disassemble" + //add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "LR35902.ADDS_Func" + //add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "LR35902.DA_Func" add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "MC6800.ExecuteOne" //add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "MC6800.DA_Func" add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "MOS6502X.State" - add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "x86.Execute" + //add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "x86.Execute" //add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "Emu83.LoadStateBinary" add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "NECUPD765.ReadPort" add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "NECUPD765.SetUnitSelect" //add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "AY38912.PortWrite" add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "CPCBase.PollInput" - add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "CPCBase.LoadAllMedia" + //add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "CPCBase.LoadAllMedia" add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "CPCBase.DecodeINPort" add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "GateArrayBase.SetupScreenMapping" add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "GateArrayBase.ClockCycle" @@ -282,7 +282,7 @@ let trainMaps, validationMaps = //add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "CPC6128.WriteBus" add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "CPC6128.InitROM" - add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "CPC6128.ReadPort" + //add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "CPC6128.ReadPort" add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "CPC6128.WritePort" //add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "AmstradCPC.AmstradCPC" add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "AmstradCPC.GetFirmware" @@ -293,34 +293,34 @@ let trainMaps, validationMaps = // add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "AmstradCPC.OSD_ShowTapeStatus" // add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "AmstradCPC.CheckMessageSettings" // add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "SoundProviderMixer.GetSamplesSync" - add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "CartridgeDevice.Load" + //add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "CartridgeDevice.Load" // add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "Mapper0000.Mapper0000" // add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "Mapper000F.Mapper000F" // add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "Mapper0005.Mapper0005" // add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "Mapper0013.Mapper0013" // add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "Mapper0020.Mapper0020" - add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "D64.Read" - add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "DiskBuilder.Build" - add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "G64.Read" + //add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "D64.Read" + //add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "DiskBuilder.Build" + //add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "G64.Read" add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "G64.Write" add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "Prg.Load" add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "Tape.ExecuteCycle" - add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "Tape.Load" + //add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "Tape.Load" // add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "Chip6510.Read" // add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "Chip6510.Write" // add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "Chip6526.CreateCia1" add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "Chip90611401.Write" - add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "Cia.ExecutePhase" + //add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "Cia.ExecutePhase" // add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "Cia.Write" add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "Sid.Flush" - add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "Sid.filter_operator" + //add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "Sid.filter_operator" add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "Envelope.ExecutePhase2" add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "Sid.Write" - add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "Sid.GetSamplesSync" - add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "Via.ExecutePhase" - add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "Via.SyncState" + //add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "Sid.GetSamplesSync" + //add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "Via.ExecutePhase" + //add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "Via.SyncState" add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "Via.Write" add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "Vic.ExecutePhase1" add 0u "BizHawk.Emulation.Cores.dll" CoverageZone.Method "Vic.Read" @@ -337,12 +337,11 @@ let trainMaps, validationMaps = //add 0u "Algorithms.dll" CoverageZone.Method "BellmanFordShortestPaths.ShortestPathTo" - //add 0u "VSharp.ML.GameMaps.dll" CoverageZone.Method "KruskalMST" //add 20u "VSharp.ML.GameMaps.dll" CoverageZone.Method "KruskalMST" //add 40u "VSharp.ML.GameMaps.dll" CoverageZone.Method "KruskalMST" //add 60u "VSharp.ML.GameMaps.dll" CoverageZone.Method "KruskalMST" - + let add = add' validationMaps //!!!add 1000u 0u "VSharp.ML.GameMaps.dll" CoverageZone.Method "mergeSort" @@ -361,5 +360,4 @@ let trainMaps, validationMaps = add 5000u 0u "JetBrains.Lifetimes.dll" CoverageZone.Method "ReactiveEx.AdviseUntil" add 5000u 0u "JetBrains.Lifetimes.dll" CoverageZone.Method "ReactiveEx.AdviseOnce" add 5000u 0u "JetBrains.Lifetimes.dll" CoverageZone.Method "Types.ToString" - trainMaps, validationMaps From 173f15ab660c697a3d94b138da772a9c7fc22b19 Mon Sep 17 00:00:00 2001 From: Anya497 Date: Fri, 11 Aug 2023 16:17:44 +0300 Subject: [PATCH 42/47] Remove obsolete file --- VSharp.ML.AIAgent/ml/common_model.py | 202 --------------------------- 1 file changed, 202 deletions(-) delete mode 100644 VSharp.ML.AIAgent/ml/common_model.py diff --git a/VSharp.ML.AIAgent/ml/common_model.py b/VSharp.ML.AIAgent/ml/common_model.py deleted file mode 100644 index 49d8022ef..000000000 --- a/VSharp.ML.AIAgent/ml/common_model.py +++ /dev/null @@ -1,202 +0,0 @@ -import torch -from common.constants import DEVICE -import os -import re - -from torch import nn -from torch_geometric.nn import ( - GATConv, - SAGEConv, -) -from torchvision.ops import MLP - -from common.game import GameState -from ml.data_loader_compact import ServerDataloaderHeteroVector -from ml.model_wrappers.protocols import Predictor -from ml.predict_state_vector_hetero import PredictStateVectorHetGNN -import csv - -from learning.play_game import play_game -from config import GeneralConfig -from connection.game_server_conn.utils import MapsType -from ml.models import SAGEConvModel - -csv_path = '../report/epochs_tables/' -models_path = '../report/epochs_best/' - - -class CommonModel(torch.nn.Module): - def __init__( - self, - hidden_channels, - num_gv_layers=2, - num_sv_layers=2, - ): - super().__init__() - self.gv_layers = nn.ModuleList() - self.gv_layers.append(SAGEConv(-1, hidden_channels)) - for i in range(num_gv_layers - 1): - sage_gv = SAGEConv(-1, hidden_channels) - self.gv_layers.append(sage_gv) - - self.sv_layers = nn.ModuleList() - self.sv_layers.append(SAGEConv(-1, hidden_channels)) - for i in range(num_sv_layers - 1): - sage_sv = SAGEConv(-1, hidden_channels) - self.sv_layers.append(sage_sv) - - self.history1 = GATConv((-1, -1), hidden_channels, add_self_loops=False) - - self.in1 = SAGEConv((-1, -1), hidden_channels) - - self.sv_layers2 = nn.ModuleList() - self.sv_layers2.append(SAGEConv(-1, hidden_channels)) - for i in range(num_sv_layers - 1): - sage_sv = SAGEConv(-1, hidden_channels) - self.sv_layers2.append(sage_sv) - - self.mlp = MLP(hidden_channels, [1]) - - def forward(self, x_dict, edge_index_dict, edge_attr_dict): - game_x = self.gv_layers[0]( - x_dict["game_vertex"], - edge_index_dict[("game_vertex", "to", "game_vertex")], - ).relu() - for layer in self.gv_layers[1:]: - game_x = layer( - game_x, - edge_index_dict[("game_vertex", "to", "game_vertex")], - ).relu() - - state_x = self.sv_layers[0]( - x_dict["state_vertex"], - edge_index_dict[("state_vertex", "parent_of", "state_vertex")], - ).relu() - for layer in self.sv_layers[1:]: - state_x = layer( - state_x, - edge_index_dict[("state_vertex", "parent_of", "state_vertex")], - ).relu() - - history_x = self.history1( - (game_x, state_x), - edge_index_dict[("game_vertex", "history", "state_vertex")], - edge_attr_dict, - size=(game_x.size(0), state_x.size(0)) - ).relu() - - in_x = self.in1( - (game_x, history_x), - edge_index_dict[("game_vertex", "in", "state_vertex")] - ).relu() - - state_x = self.sv_layers2[0]( - in_x, - edge_index_dict[("state_vertex", "parent_of", "state_vertex")], - ).relu() - for layer in self.sv_layers2[1:]: - state_x = layer( - state_x, - edge_index_dict[("state_vertex", "parent_of", "state_vertex")], - ).relu() - x = self.mlp(in_x) - return x - - -def euclidean_dist(y_pred, y_true): - y_pred_min, ind1 = torch.min(y_pred, dim=0) - y_pred_norm = y_pred - y_pred_min - - y_true_min, ind1 = torch.min(y_true, dim=0) - y_true_norm = y_true - y_true_min - - return torch.sqrt(torch.sum((y_pred_norm - y_true_norm) ** 2)) - - -class CommonModelWrapper(Predictor): - def __init__(self, model: torch.nn.Module, best_models: dict) -> None: - self.model = model - self.best_models = best_models - self._model = model - - def name(self) -> str: - return "MY AWESOME MODEL" - - def model(self): - return self._model - - def predict(self, input: GameState, map_name): - hetero_input, state_map = ServerDataloaderHeteroVector.convert_input_to_tensor( - input - ) - assert self.model is not None - - next_step_id = PredictStateVectorHetGNN.predict_state_single_out( - self.model, hetero_input, state_map - ) - - back_prop(self.best_models[map_name], self.model, hetero_input) - - del hetero_input - return next_step_id - - -def get_last_epoch_num(path): - epochs = list(map(lambda x: re.findall('[0-9]+', x), os.listdir(path))) - return str(sorted(epochs)[-1][0]) - - -def csv2best_models(): - best_models = {} - values = [] - models_names = [] - - with open(csv_path + get_last_epoch_num(csv_path) + '.csv', 'r') as csv_file: - csv_reader = csv.reader(csv_file) - map_names = next(csv_reader)[1:] - for row in csv_reader: - models_names.append(row[0]) - int_row = list(map(lambda x: int(x), row[1:])) - values.append(int_row) - val, ind = torch.max(torch.tensor(values), dim=0) - for i in range(len(map_names)): - best_models[map_names[i]] = models_names[ind[i]] - return best_models - - -def back_prop(best_model, model, data): - model.train() - data.to(DEVICE) - optimizer.zero_grad() - ref_model = SAGEConvModel(16) - ref_model.load_state_dict(torch.load(models_path + "epoch_" + get_last_epoch_num(models_path) + "/" + best_model + ".pth")) - ref_model.to(DEVICE) - out = model(data.x_dict, data.edge_index_dict, data.edge_attr_dict) - y_true = ref_model(data.x_dict, data.edge_index_dict, data.edge_attr_dict) - - loss = criterion(out, y_true) - loss.backward() - optimizer.step() - - -model = CommonModel(32) -model.to(DEVICE) - -cmwrapper = CommonModelWrapper(model, csv2best_models()) -lr = 0.0001 -epochs = 10 -optimizer = torch.optim.Adam(model.parameters(), lr=lr) -criterion = euclidean_dist - - -def train(): - for epoch in range(epochs): - # some function with some parameters - play_game( - with_predictor=cmwrapper, - max_steps=GeneralConfig.MAX_STEPS, - maps_type=MapsType.TRAIN, - ) - - -train() From 6087f569c48a6c1e40a7b5efb30da9070061e3b9 Mon Sep 17 00:00:00 2001 From: Semyon Grigorev Date: Mon, 14 Aug 2023 14:48:20 +0300 Subject: [PATCH 43/47] added best_models dict saving --- VSharp.ML.AIAgent/ml/common_model/paths.py | 1 + VSharp.ML.AIAgent/ml/common_model/utils.py | 42 ++++++++++++++++++- VSharp.ML.AIAgent/ml/common_model/wrapper.py | 23 ++++++++-- .../run_common_model_training.py | 36 ++++++++++++---- 4 files changed, 89 insertions(+), 13 deletions(-) diff --git a/VSharp.ML.AIAgent/ml/common_model/paths.py b/VSharp.ML.AIAgent/ml/common_model/paths.py index 4ee8c313e..f677b5e96 100644 --- a/VSharp.ML.AIAgent/ml/common_model/paths.py +++ b/VSharp.ML.AIAgent/ml/common_model/paths.py @@ -3,3 +3,4 @@ csv_path = os.path.join("report", "epochs_tables") models_path = os.path.join("report", "epochs_best") common_models_path = os.path.join("report", "common_models") +best_models_dict_path = os.path.join("report", "updated_best_models_dicts") diff --git a/VSharp.ML.AIAgent/ml/common_model/utils.py b/VSharp.ML.AIAgent/ml/common_model/utils.py index 989035290..7a1ee83c6 100644 --- a/VSharp.ML.AIAgent/ml/common_model/utils.py +++ b/VSharp.ML.AIAgent/ml/common_model/utils.py @@ -2,12 +2,14 @@ import csv import os import re +from pathlib import Path import torch from config import GeneralConfig -from ml.common_model.paths import csv_path, models_path +from ml.common_model.paths import csv_path, models_path, common_models_path from ml.models import SAGEConvModel +from ml.utils import load_model def euclidean_dist(y_pred, y_true): @@ -61,7 +63,11 @@ def csv2best_models(): ) ref_model.load_state_dict(torch.load(path_to_model)) ref_model.to(GeneralConfig.DEVICE) - best_models[map_name] = (ref_model, best_model_score[map_name]) + best_models[map_name] = ( + ref_model, + best_model_score[map_name], + best_model_name, + ) return best_models @@ -84,3 +90,35 @@ def back_prop(best_model, model, data, optimizer, criterion): loss.backward() optimizer.step() return loss + + +def save_best_models2csv(best_models: dict, path): + values_for_csv = [] + for map_name in best_models.keys: + values_for_csv.append( + { + "map_name": map_name, + "best_model_name": best_models[map_name][2], + "result": best_models[map_name][1], + } + ) + with open(path, "w") as csv_file: + writer = csv.DictWriter( + csv_file, fieldnames=["map_name", "best_model_name", "result"] + ) + + +def load_best_models_dict(path): + best_models = csv2best_models() + with open(path, "r") as csv_file: + csv_reader = csv.reader(csv_file) + for row in csv_reader: + if row[1] != best_models[row[0]][2]: + path_to_model = os.path.join(common_models_path, row[1]) + ref_model = load_model( + Path(path_to_model), model=GeneralConfig.EXPORT_MODEL_INIT() + ) + + ref_model.load_state_dict(torch.load(path_to_model)) + ref_model.to(GeneralConfig.DEVICE) + best_models[row[0]][0] = ref_model diff --git a/VSharp.ML.AIAgent/ml/common_model/wrapper.py b/VSharp.ML.AIAgent/ml/common_model/wrapper.py index bee34e3f9..26588d61a 100644 --- a/VSharp.ML.AIAgent/ml/common_model/wrapper.py +++ b/VSharp.ML.AIAgent/ml/common_model/wrapper.py @@ -13,11 +13,17 @@ class CommonModelWrapper(Predictor): def __init__(self, model: torch.nn.Module, best_models: dict) -> None: self.best_models = best_models self._model = model + self.model_copy = model + self._name = "1" # self.name = sum(torch.cat([p.view(-1) for p in self.model.parameters()], dim=0)) def name(self): return "Common model" + def make_copy(self, model_name: str): + self.model_copy = copy.deepcopy(self._model) + self._name = model_name + def update(self, map_name, map_result): map_result = map_result.actual_coverage_percent if self.best_models[map_name][1] <= map_result: @@ -25,7 +31,10 @@ def update(self, map_name, map_result): f"The model with result = {self.best_models[map_name][1]} was replaced with the model with " f"result = {map_result} on the map {map_name}" ) - self.best_models[map_name] = (copy.deepcopy(self._model), map_result) + self.best_models[map_name] = ( + self.model_copy, + map_result, + ) def model(self): return self._model @@ -45,7 +54,9 @@ def predict(self, input: GameState, map_name): class BestModelsWrapper(Predictor): - def __init__(self, model: torch.nn.Module, best_models: dict, optimizer, criterion) -> None: + def __init__( + self, model: torch.nn.Module, best_models: dict, optimizer, criterion + ) -> None: self.best_models = best_models self._model = model self.optimizer = optimizer @@ -70,7 +81,13 @@ def predict(self, input: GameState, map_name): self.best_models[map_name][0], hetero_input, state_map ) - back_prop(self.best_models[map_name][0], self._model, hetero_input, self.optimizer, self.criterion) + back_prop( + self.best_models[map_name][0], + self._model, + hetero_input, + self.optimizer, + self.criterion, + ) del hetero_input return next_step_id diff --git a/VSharp.ML.AIAgent/run_common_model_training.py b/VSharp.ML.AIAgent/run_common_model_training.py index 25679745e..8ab29dd4e 100644 --- a/VSharp.ML.AIAgent/run_common_model_training.py +++ b/VSharp.ML.AIAgent/run_common_model_training.py @@ -10,17 +10,19 @@ from epochs_statistics.tables import create_pivot_table, table_to_string from learning.play_game import play_game from ml.common_model.models import CommonModel -from ml.common_model.utils import csv2best_models, euclidean_dist +from ml.common_model.utils import csv2best_models, euclidean_dist, save_best_models2csv from ml.common_model.wrapper import CommonModelWrapper, BestModelsWrapper from ml.fileop import save_model -from ml.common_model.paths import common_models_path +from ml.common_model.paths import common_models_path, best_models_dict_path from ml.model_wrappers.protocols import Predictor from ml.utils import load_model +import numpy as np LOG_PATH = Path("./ml_app.log") TABLES_PATH = Path("./ml_tables.log") COMMON_MODELS_PATH = Path(common_models_path) +BEST_MODELS_DICT = Path(best_models_dict_path) logging.basicConfig( level=GeneralConfig.LOGGER_LEVEL, @@ -32,6 +34,9 @@ if not COMMON_MODELS_PATH.exists(): os.makedirs(common_models_path) +if not BEST_MODELS_DICT.exists(): + os.makedirs(best_models_dict_path) + def create_file(file: Path): open(file, "w").close() @@ -43,8 +48,8 @@ def append_to_file(file: Path, s: str): def main(): - lr = 0.0000001 - epochs = 1 + lr = 0.000001 + epochs = 20 hidden_channels = 32 num_gv_layers = 2 num_sv_layers = 2 @@ -55,9 +60,15 @@ def main(): # common_models_path, # "1", # ) + # path_to_model = os.path.join( + # "ml", + # "pretrained_models", + # "-262.75775990410693.pth", + # ) path_to_model = os.path.join( - common_models_path, - "-262.75775990410693.pth", + "report", + "common_models", + "10", ) model = load_model(Path(path_to_model), model=GeneralConfig.EXPORT_MODEL_INIT()) @@ -68,8 +79,8 @@ def main(): optimizer = torch.optim.Adam(model.parameters(), lr=lr) criterion = euclidean_dist best_models_dict = csv2best_models() - cmwrapper = CommonModelWrapper(model, best_models_dict) bmwrapper = BestModelsWrapper(model, best_models_dict, optimizer, criterion) + cmwrapper = CommonModelWrapper(model, best_models_dict) for epoch in range(epochs): # training @@ -79,18 +90,27 @@ def main(): maps_type=MapsType.TRAIN, ) # validation + cmwrapper.make_copy(str(epoch + 1)) result = play_game( with_predictor=cmwrapper, max_steps=GeneralConfig.MAX_STEPS, maps_type=MapsType.TRAIN, ) + average_result = np.average( + list(map(lambda x: x.game_result.actual_coverage_percent, result)) + ) table, _, _ = create_pivot_table({cmwrapper: result}) table = table_to_string(table) - append_to_file(TABLES_PATH, f"Epoch#{epoch}" + "\n") + append_to_file( + TABLES_PATH, + f"Epoch#{epoch}" + " Average coverage: " + str(average_result) + "\n", + ) append_to_file(TABLES_PATH, table + "\n") path_to_model = os.path.join(common_models_path, str(epoch + 1)) save_model(model=cmwrapper.model(), to=Path(path_to_model)) + path_to_best_models_dict = os.path.join(best_models_dict_path, str(epoch + 1)) + save_best_models2csv(best_models_dict, path_to_best_models_dict) if __name__ == "__main__": From ede21a7916051a547d95a962fc079f2e7b3d6f37 Mon Sep 17 00:00:00 2001 From: Semyon Grigorev Date: Mon, 14 Aug 2023 15:00:42 +0300 Subject: [PATCH 44/47] new model output format is included --- VSharp.ML.AIAgent/ml/predict_state_vector_hetero.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/VSharp.ML.AIAgent/ml/predict_state_vector_hetero.py b/VSharp.ML.AIAgent/ml/predict_state_vector_hetero.py index e95072ac4..fff3d4b92 100644 --- a/VSharp.ML.AIAgent/ml/predict_state_vector_hetero.py +++ b/VSharp.ML.AIAgent/ml/predict_state_vector_hetero.py @@ -123,7 +123,8 @@ def predict_state_single_out( out = model.forward(data.x_dict, data.edge_index_dict, data.edge_attr_dict) remapped = [] - + if type(out) is dict: + out = out["state_vertex"] for index, vector in enumerate(out): state_vector_mapping = StateVectorMapping( state=reversed_state_map[index], From 5f4b6497d2256bd4e2ebe1c5f72b8f89ee14cdc6 Mon Sep 17 00:00:00 2001 From: Semyon Grigorev Date: Mon, 14 Aug 2023 17:50:32 +0300 Subject: [PATCH 45/47] maps shuffle is added --- VSharp.ML.AIAgent/learning/play_game.py | 2 ++ VSharp.ML.AIAgent/ml/common_model/utils.py | 1 + VSharp.ML.AIAgent/run_common_model_training.py | 1 + 3 files changed, 4 insertions(+) diff --git a/VSharp.ML.AIAgent/learning/play_game.py b/VSharp.ML.AIAgent/learning/play_game.py index b6c52a56f..9fed98025 100644 --- a/VSharp.ML.AIAgent/learning/play_game.py +++ b/VSharp.ML.AIAgent/learning/play_game.py @@ -2,6 +2,7 @@ from statistics import StatisticsError from time import perf_counter from typing import TypeAlias +import random import tqdm from func_timeout import FunctionTimedOut, func_set_timeout @@ -117,6 +118,7 @@ def play_map_with_timeout( def play_game(with_predictor: Predictor, max_steps: int, maps_type: MapsType): with game_server_socket_manager() as ws: maps = get_maps(websocket=ws, type=maps_type) + random.shuffle(maps) with tqdm.tqdm( total=len(maps), desc=f"{with_predictor.name():20}: {maps_type.value}", diff --git a/VSharp.ML.AIAgent/ml/common_model/utils.py b/VSharp.ML.AIAgent/ml/common_model/utils.py index 7a1ee83c6..3c69826a2 100644 --- a/VSharp.ML.AIAgent/ml/common_model/utils.py +++ b/VSharp.ML.AIAgent/ml/common_model/utils.py @@ -122,3 +122,4 @@ def load_best_models_dict(path): ref_model.load_state_dict(torch.load(path_to_model)) ref_model.to(GeneralConfig.DEVICE) best_models[row[0]][0] = ref_model + return best_models diff --git a/VSharp.ML.AIAgent/run_common_model_training.py b/VSharp.ML.AIAgent/run_common_model_training.py index 8ab29dd4e..5819b7b0d 100644 --- a/VSharp.ML.AIAgent/run_common_model_training.py +++ b/VSharp.ML.AIAgent/run_common_model_training.py @@ -99,6 +99,7 @@ def main(): average_result = np.average( list(map(lambda x: x.game_result.actual_coverage_percent, result)) ) + result = sorted(result, key=lambda x: x.map.MapName) table, _, _ = create_pivot_table({cmwrapper: result}) table = table_to_string(table) append_to_file( From da67a47ead3202bd7940be6c7280a1e77c2e98d2 Mon Sep 17 00:00:00 2001 From: Semyon Grigorev Date: Tue, 15 Aug 2023 12:14:50 +0300 Subject: [PATCH 46/47] some bugs in the best models dict saving are fixed, tuple results processing is added --- VSharp.ML.AIAgent/ml/common_model/utils.py | 22 +++++++++++++------ VSharp.ML.AIAgent/ml/common_model/wrapper.py | 8 ++++++- .../run_common_model_training.py | 15 ++++++------- 3 files changed, 29 insertions(+), 16 deletions(-) diff --git a/VSharp.ML.AIAgent/ml/common_model/utils.py b/VSharp.ML.AIAgent/ml/common_model/utils.py index 3c69826a2..4cfd86a7f 100644 --- a/VSharp.ML.AIAgent/ml/common_model/utils.py +++ b/VSharp.ML.AIAgent/ml/common_model/utils.py @@ -30,9 +30,10 @@ def get_last_epoch_num(path): def get_tuple_for_max(t): - t[1] *= -1 - t[3] *= -1 - return t + values_list = list(t) + values_list[1] *= -1 + values_list[3] *= -1 + return tuple(values_list) def csv2best_models(): @@ -46,7 +47,9 @@ def csv2best_models(): for row in csv_reader: models_stat = dict() # int_row = list(map(lambda x: tuple(map(lambda y: int(y), x[1:-1].split(", "))), row[1:])) - int_row = list(map(lambda x: ast.literal_eval(x), row[1:])) + int_row = list( + map(lambda x: get_tuple_for_max(ast.literal_eval(x)), row[1:]) + ) for i in range(len(int_row)): models_stat[map_names[i]] = int_row[i] models.append((row[0], models_stat)) @@ -55,13 +58,17 @@ def csv2best_models(): best_model = max(models, key=(lambda m: m[1][map_name])) best_model_name = best_model[0] best_model_score = best_model[1] - ref_model = SAGEConvModel(16) + # ref_model = SAGEConvModel(16) path_to_model = os.path.join( models_path, "epoch_" + str(epoch_num), best_model_name + ".pth", ) - ref_model.load_state_dict(torch.load(path_to_model)) + # ref_model.load_state_dict(torch.load(path_to_model)) + ref_model = load_model( + Path(path_to_model), model=GeneralConfig.EXPORT_MODEL_INIT() + ) + ref_model.to(GeneralConfig.DEVICE) best_models[map_name] = ( ref_model, @@ -94,7 +101,7 @@ def back_prop(best_model, model, data, optimizer, criterion): def save_best_models2csv(best_models: dict, path): values_for_csv = [] - for map_name in best_models.keys: + for map_name in best_models.keys(): values_for_csv.append( { "map_name": map_name, @@ -106,6 +113,7 @@ def save_best_models2csv(best_models: dict, path): writer = csv.DictWriter( csv_file, fieldnames=["map_name", "best_model_name", "result"] ) + writer.writerows(values_for_csv) def load_best_models_dict(path): diff --git a/VSharp.ML.AIAgent/ml/common_model/wrapper.py b/VSharp.ML.AIAgent/ml/common_model/wrapper.py index 26588d61a..855a2fe9e 100644 --- a/VSharp.ML.AIAgent/ml/common_model/wrapper.py +++ b/VSharp.ML.AIAgent/ml/common_model/wrapper.py @@ -25,7 +25,12 @@ def make_copy(self, model_name: str): self._name = model_name def update(self, map_name, map_result): - map_result = map_result.actual_coverage_percent + map_result = ( + map_result.actual_coverage_percent, + -map_result.tests_count, + map_result.errors_count, + -map_result.steps_count, + ) if self.best_models[map_name][1] <= map_result: logging.info( f"The model with result = {self.best_models[map_name][1]} was replaced with the model with " @@ -34,6 +39,7 @@ def update(self, map_name, map_result): self.best_models[map_name] = ( self.model_copy, map_result, + self._name, ) def model(self): diff --git a/VSharp.ML.AIAgent/run_common_model_training.py b/VSharp.ML.AIAgent/run_common_model_training.py index 5819b7b0d..fccd0cb18 100644 --- a/VSharp.ML.AIAgent/run_common_model_training.py +++ b/VSharp.ML.AIAgent/run_common_model_training.py @@ -60,20 +60,19 @@ def main(): # common_models_path, # "1", # ) - # path_to_model = os.path.join( - # "ml", - # "pretrained_models", - # "-262.75775990410693.pth", - # ) path_to_model = os.path.join( - "report", - "common_models", - "10", + "ml", + "pretrained_models", + "-262.75775990410693.pth", ) model = load_model(Path(path_to_model), model=GeneralConfig.EXPORT_MODEL_INIT()) model.to(GeneralConfig.DEVICE) + for name, param in model.named_parameters(): + if "lin_last" not in name: + param.requires_grad = False + create_file(TABLES_PATH) create_file(LOG_PATH) optimizer = torch.optim.Adam(model.parameters(), lr=lr) From 17585acc7a2910e8dc0e0c50f23ff4b12a914812 Mon Sep 17 00:00:00 2001 From: Semyon Grigorev Date: Tue, 15 Aug 2023 13:16:20 +0300 Subject: [PATCH 47/47] Ignore pretrained models and reports. --- VSharp.ML.AIAgent/.gitignore | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/VSharp.ML.AIAgent/.gitignore b/VSharp.ML.AIAgent/.gitignore index 46a3eb297..54e13e45c 100644 --- a/VSharp.ML.AIAgent/.gitignore +++ b/VSharp.ML.AIAgent/.gitignore @@ -1,6 +1,7 @@ # python cache and venv .env __pycache__/ -report/ +report**/ +ml/pretrained_models/ *.pkl *.onnx