Skip to content

Commit

Permalink
Merge branch 'rebased_common_model' into mlSearcher
Browse files Browse the repository at this point in the history
  • Loading branch information
gsvgit committed Aug 15, 2023
2 parents 5fea8de + 17585ac commit ee0e59d
Show file tree
Hide file tree
Showing 14 changed files with 561 additions and 244 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -244,12 +244,12 @@ VSharp.Test/GeneratedTests/**

#Pytjon environments
**/.env/**
/torch_venv/**

#Python caches
**/__pycache__/**

#logs
**/logs_full*/**

/output/**
/References/*.dll
3 changes: 2 additions & 1 deletion VSharp.ML.AIAgent/.gitignore
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# python cache and venv
.env
__pycache__/
report/
report**/
ml/pretrained_models/
*.pkl
*.onnx
3 changes: 3 additions & 0 deletions VSharp.ML.AIAgent/learning/play_game.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from statistics import StatisticsError
from time import perf_counter
from typing import TypeAlias
import random

import tqdm
from func_timeout import FunctionTimedOut, func_set_timeout
Expand Down Expand Up @@ -83,6 +84,7 @@ def play_map(
actual_coverage_percent=actual_coverage,
)

with_predictor.update(with_connector.map.MapName, model_result)
return model_result, end_time - start_time


Expand Down Expand Up @@ -116,6 +118,7 @@ def play_map_with_timeout(
def play_game(with_predictor: Predictor, max_steps: int, maps_type: MapsType):
with game_server_socket_manager() as ws:
maps = get_maps(websocket=ws, type=maps_type)
random.shuffle(maps)
with tqdm.tqdm(
total=len(maps),
desc=f"{with_predictor.name():20}: {maps_type.value}",
Expand Down
209 changes: 0 additions & 209 deletions VSharp.ML.AIAgent/ml/common_model.py

This file was deleted.

89 changes: 89 additions & 0 deletions VSharp.ML.AIAgent/ml/common_model/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import torch
from torch import nn
from torch_geometric.nn import (
GATConv,
SAGEConv,
TAGConv,
)
from torchvision.ops import MLP


class CommonModel(torch.nn.Module):
def __init__(
self,
hidden_channels,
num_gv_layers=2,
num_sv_layers=2,
):
super().__init__()
self.tag_conv1 = TAGConv(5, hidden_channels, 2)
self.tag_conv2 = TAGConv(6, hidden_channels, 3)
self.gv_layers = nn.ModuleList()
self.gv_layers.append(self.tag_conv1)
self.gv_layers.append(SAGEConv(-1, hidden_channels))
for i in range(num_gv_layers - 1):
sage_gv = SAGEConv(-1, hidden_channels)
self.gv_layers.append(sage_gv)

self.sv_layers = nn.ModuleList()
self.sv_layers.append(self.tag_conv2)
self.sv_layers.append(SAGEConv(-1, hidden_channels))
for i in range(num_sv_layers - 1):
sage_sv = SAGEConv(-1, hidden_channels)
self.sv_layers.append(sage_sv)

self.history1 = GATConv((-1, -1), hidden_channels, add_self_loops=False)

self.in1 = SAGEConv((-1, -1), hidden_channels)

self.sv_layers2 = nn.ModuleList()
self.sv_layers2.append(SAGEConv(-1, hidden_channels))
for i in range(num_sv_layers - 1):
sage_sv = SAGEConv(-1, hidden_channels)
self.sv_layers2.append(sage_sv)

self.mlp = MLP(hidden_channels, [1])

def forward(self, x_dict, edge_index_dict, edge_attr_dict):
game_x = self.gv_layers[0](
x_dict["game_vertex"],
edge_index_dict[("game_vertex", "to", "game_vertex")],
).relu()
for layer in self.gv_layers[1:]:
game_x = layer(
game_x,
edge_index_dict[("game_vertex", "to", "game_vertex")],
).relu()

state_x = self.sv_layers[0](
x_dict["state_vertex"],
edge_index_dict[("state_vertex", "parent_of", "state_vertex")],
).relu()
for layer in self.sv_layers[1:]:
state_x = layer(
state_x,
edge_index_dict[("state_vertex", "parent_of", "state_vertex")],
).relu()

history_x = self.history1(
(game_x, state_x),
edge_index_dict[("game_vertex", "history", "state_vertex")],
edge_attr_dict,
size=(game_x.size(0), state_x.size(0)),
).relu()

in_x = self.in1(
(game_x, history_x), edge_index_dict[("game_vertex", "in", "state_vertex")]
).relu()

state_x = self.sv_layers2[0](
in_x,
edge_index_dict[("state_vertex", "parent_of", "state_vertex")],
).relu()
for layer in self.sv_layers2[1:]:
state_x = layer(
state_x,
edge_index_dict[("state_vertex", "parent_of", "state_vertex")],
).relu()
x = self.mlp(in_x)
return x
6 changes: 6 additions & 0 deletions VSharp.ML.AIAgent/ml/common_model/paths.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
import os

csv_path = os.path.join("report", "epochs_tables")
models_path = os.path.join("report", "epochs_best")
common_models_path = os.path.join("report", "common_models")
best_models_dict_path = os.path.join("report", "updated_best_models_dicts")
Loading

0 comments on commit ee0e59d

Please sign in to comment.