Skip to content

Commit

Permalink
Remove best_models_dict saving. Add directory for teachers. Remove su…
Browse files Browse the repository at this point in the history
…pport dict as model's output. Add function for last layer adding
  • Loading branch information
Anya497 committed Nov 23, 2023
1 parent 0d33cdc commit 1be35f1
Show file tree
Hide file tree
Showing 11 changed files with 250 additions and 256 deletions.
28 changes: 9 additions & 19 deletions VSharp.ML.AIAgent/ml/common_model/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,20 +48,12 @@ def remove_similar_steps(self, map_steps):
filtered_map_steps = []
for step in map_steps:
if len(filtered_map_steps) != 0:
if (
step["y_true"]["state_vertex"].size()
== filtered_map_steps[-1]["y_true"]["state_vertex"].size()
):
if step["y_true"].size() == filtered_map_steps[-1]["y_true"].size():
cos_d = 1 - torch.sum(
(
step["y_true"]["state_vertex"]
/ torch.linalg.vector_norm(step["y_true"]["state_vertex"])
)
(step["y_true"] / torch.linalg.vector_norm(step["y_true"]))
* (
filtered_map_steps[-1]["y_true"]["state_vertex"]
/ torch.linalg.vector_norm(
filtered_map_steps[-1]["y_true"]["state_vertex"]
)
filtered_map_steps[-1]["y_true"]
/ torch.linalg.vector_norm(filtered_map_steps[-1]["y_true"])
)
)
if (
Expand Down Expand Up @@ -91,13 +83,11 @@ def remove_similar_steps(self, map_steps):
def filter_map_steps(self, map_steps):
filtered_map_steps = []
for step in map_steps:
if step["y_true"]["state_vertex"].size()[0] != 1:
if not step["y_true"]["state_vertex"].isnan().any():
max_ind = torch.argmax(step["y_true"]["state_vertex"])
step["y_true"]["state_vertex"] = torch.zeros_like(
step["y_true"]["state_vertex"]
)
step["y_true"]["state_vertex"][max_ind] = 1.0
if step["y_true"].size()[0] != 1:
if not step["y_true"].isnan().any():
max_ind = torch.argmax(step["y_true"])
step["y_true"] = torch.zeros_like(step["y_true"])
step["y_true"][max_ind] = 1.0
filtered_map_steps.append(step)

return filtered_map_steps
Expand Down
1 change: 1 addition & 0 deletions VSharp.ML.AIAgent/ml/common_model/paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
dataset_root_path = os.path.join("report", "dataset")
dataset_map_results_file_name = os.path.join("report", "dataset_state.csv")
training_data_path = os.path.join("report", "run_tables")
pretrained_models_path = os.path.join("ml", "models")

path_to_models_for_parallel_architecture = os.path.join(
"ml", "pretrained_models", "models_for_parallel_architecture"
Expand Down
15 changes: 12 additions & 3 deletions VSharp.ML.AIAgent/ml/common_model/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,16 @@
from pathlib import Path

import torch
import numpy as np

from config import GeneralConfig
from ml.common_model.paths import (
csv_path,
models_path,
common_models_path,
)
from ml.models import SAGEConvModel
from ml.utils import load_model
from ml.models.TAGSageTeacher.model_modified import StateModelEncoderLastLayer


def euclidean_dist(y_pred, y_true):
Expand Down Expand Up @@ -67,7 +68,7 @@ def csv2best_models():
best_model_name + ".pth",
)
ref_model = load_model(
Path(path_to_model), model=GeneralConfig.EXPORT_MODEL_INIT()
Path(path_to_model), model=StateModelEncoderLastLayer(32, 8)
)

ref_model.to(GeneralConfig.DEVICE)
Expand Down Expand Up @@ -122,7 +123,7 @@ def load_best_models_dict(path):
if row[1] != best_models[row[0]][2]:
path_to_model = os.path.join(common_models_path, row[1])
ref_model = load_model(
Path(path_to_model), model=GeneralConfig.EXPORT_MODEL_INIT()
Path(path_to_model), model=StateModelEncoderLastLayer(32, 8)
)
ref_model.load_state_dict(torch.load(path_to_model))
ref_model.to(GeneralConfig.DEVICE)
Expand All @@ -139,3 +140,11 @@ def load_dataset_state_dict(path):
for row in csv_reader:
dataset_state_dict[row[0]] = ast.literal_eval(row[1])
return dataset_state_dict


def get_model(path_to_weights: Path, model: torch.nn.Module):
weights = torch.load(path_to_weights)
weights["lin_last.weight"] = torch.tensor(np.random.random([1, 8]))
weights["lin_last.bias"] = torch.tensor(np.random.random([1]))
model.load_state_dict(weights)
return model
10 changes: 2 additions & 8 deletions VSharp.ML.AIAgent/ml/common_model/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,12 @@
import logging

import torch
from predict import predict_state_single_out, predict_state_with_dict

from common.game import GameState
from ml.common_model.utils import back_prop
from ml.data_loader_compact import ServerDataloaderHeteroVector
from ml.model_wrappers.protocols import Predictor

from predict import predict_state_with_dict
from ml.predict import predict_state_with_dict


class CommonModelWrapper(Predictor):
Expand Down Expand Up @@ -48,24 +46,20 @@ def predict(self, input: GameState, map_name):
class BestModelsWrapper(Predictor):
def __init__(
self,
model: torch.nn.Module,
best_models: dict,
) -> None:
self.best_models = best_models
self._model = model

def name(self):
return "Best model"

def model(self):
return self._model
return None

def predict(self, input: GameState, map_name):
hetero_input, state_map = ServerDataloaderHeteroVector.convert_input_to_tensor(
input
)
assert self._model is not None

next_step_id, nn_output = predict_state_with_dict(
self.best_models[map_name][0], hetero_input, state_map
)
Expand Down
182 changes: 0 additions & 182 deletions VSharp.ML.AIAgent/ml/model_modified.py

This file was deleted.

2 changes: 1 addition & 1 deletion VSharp.ML.AIAgent/ml/model_wrappers/nnwrapper.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json

import torch.nn
from predict import predict_state_with_dict
from ml.predict import predict_state_with_dict

from common.game import GameState
from ml.data_loader_compact import ServerDataloaderHeteroVector
Expand Down
40 changes: 40 additions & 0 deletions VSharp.ML.AIAgent/ml/models/TAGSageSimple/model_modified.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import torch
from torch.nn import Linear
from torch_geometric.nn import Linear

from learning.timer.wrapper import timeit
from torch.nn.functional import softmax
from .model import StateModelEncoder


class StateModelEncoderLastLayer(StateModelEncoder):
def __init__(self, hidden_channels, out_channels):
super().__init__(hidden_channels, out_channels)
self.lin_last = Linear(out_channels, 1)

def forward(
self,
game_x,
state_x,
edge_index_v_v,
edge_type_v_v,
edge_index_history_v_s,
edge_attr_history_v_s,
edge_index_in_v_s,
edge_index_s_s,
):
return softmax(
self.lin_last(
super().forward(
game_x=game_x,
state_x=state_x,
edge_index_v_v=edge_index_v_v,
edge_type_v_v=edge_type_v_v,
edge_index_history_v_s=edge_index_history_v_s,
edge_attr_history_v_s=edge_attr_history_v_s,
edge_index_in_v_s=edge_index_in_v_s,
edge_index_s_s=edge_index_s_s,
)
),
dim=0,
)
Loading

0 comments on commit 1be35f1

Please sign in to comment.