Update UnsupGraphsage #425

Merged · 3 commits · Apr 17, 2023
9 changes: 6 additions & 3 deletions cogdl/data/sampler.py
@@ -160,12 +160,15 @@ def __getitem__(self, idx):
"""
batch = self.node_idx[idx * self.batch_size : (idx + 1) * self.batch_size]
self.random_walker.build_up(self.edge_index, self.total_num_nodes)
walk_res=self.random_walker.walk_one(batch,length=1,p=0.0)

walk_res = self.random_walker.walk(
batch, walk_length=2, parallel=False
)[:,1]

neg_batch = torch.randint(0, self.total_num_nodes, (batch.numel(), ),
dtype=torch.int64)
pos_batch=torch.tensor(walk_res)
batch = torch.cat([batch, pos_batch, neg_batch], dim=0)
if self.sizes != [-1]:
batch = torch.cat([batch, pos_batch, neg_batch], dim=0)
node_id = batch
adj_list = []
for size in self.sizes:
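Reviewer note: after this change the sampler pairs every anchor node with a random-walk positive (the second node of a length-2 walk) and a uniformly drawn negative, then concatenates the three groups. A minimal sketch of that layout, using a hypothetical `build_unsup_batch` helper rather than the CogDL sampler itself:

```python
# Sketch only (not the CogDL API): one random-walk positive and one uniform
# negative per anchor node, concatenated as [anchors | positives | negatives].
import torch

def build_unsup_batch(batch, edge_index, num_nodes):
    row, col = edge_index                      # COO edges, shape [2, E]
    pos = torch.empty_like(batch)
    for i, node in enumerate(batch.tolist()):
        nbrs = col[row == node]                # neighbors of the anchor node
        if len(nbrs) > 0:
            pos[i] = nbrs[torch.randint(len(nbrs), (1,)).item()]
        else:
            pos[i] = node                      # isolated node: use itself as positive
    neg = torch.randint(0, num_nodes, (batch.numel(),), dtype=torch.int64)
    return torch.cat([batch, pos, neg], dim=0)
```

With this layout, `out.split(out.size(0) // 3, dim=0)` in the wrapper's `train_step` recovers the anchor, positive and negative embeddings again.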
2 changes: 1 addition & 1 deletion cogdl/models/__init__.py
@@ -104,7 +104,7 @@ def build_model(args):
"sortpool": "cogdl.models.nn.sortpool.SortPool",
"srgcn": "cogdl.models.nn.srgcn.SRGCN",
"gcc": "cogdl.models.nn.gcc_model.GCCModel",
"unsup_graphsage": "cogdl.models.nn.graphsage.Graphsage",
"unsup_graphsage": "cogdl.models.nn.graphsage.UnsupGraphsage",
"graphsaint": "cogdl.models.nn.graphsaint.GraphSAINT",
"m3s": "cogdl.models.nn.m3s.M3S",
"moe_gcn": "cogdl.models.nn.moe_gcn.MoEGCN",
16 changes: 16 additions & 0 deletions cogdl/models/nn/graphsage.py
@@ -189,3 +189,19 @@ def forward(self, graph):
for layer in self.layers:
x = layer(graph, x)
return x

class UnsupGraphsage(Graphsage):
def __init__(self, num_features, num_classes, hidden_size, num_layers, sample_size, dropout, aggr):
super(Graphsage, self).__init__()
assert num_layers == len(sample_size)
self.adjlist = {}
self.num_features = num_features
self.num_classes = num_classes
self.hidden_size = hidden_size
self.num_layers = num_layers
self.sample_size = sample_size
self.dropout = dropout
shapes = [num_features] + hidden_size * num_layers
self.convs = nn.ModuleList(
[SAGELayer(shapes[layer], shapes[layer + 1], aggr=aggr) for layer in range(num_layers)]
)
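A note on the `shapes` line above: list repetition is being used, so this only works if `hidden_size` is passed as a list (e.g. `[128]`); with a plain int the expression raises a TypeError. A small illustration under that assumption:

```python
# Illustration only: assumed layer-width construction when hidden_size is a
# one-element list such as [128].
num_features, hidden_size, num_layers = 1433, [128], 2
shapes = [num_features] + hidden_size * num_layers     # [1433, 128, 128]
print(list(zip(shapes[:-1], shapes[1:])))               # [(1433, 128), (128, 128)]
```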
22 changes: 0 additions & 22 deletions cogdl/utils/sampling.py
@@ -110,25 +110,3 @@ def walk(self, start, walk_length, restart_p=0.0, parallel=True):
result = random_walk_single(start, walk_length, self.indptr, self.indices, restart_p)
result = np.array(result, dtype=np.int64)
return result

def walk_one(self, start, length, p):
walk_res = [np.zeros(length, dtype=np.int32)] * len(start)
p = 0.0
for i in range(len(start)):
node = start[i]
result = [np.int32(0)] * length
index = np.int32(0)
_node = node
while index < length:
start1 = self.indptr[node]
end1 = self.indptr[node + 1]
sample1 = random.randint(start1, end1 - 1)
node = self.indices[sample1]
if np.random.uniform(0, 1) > p:
result[index] = node
else:
result[index] = _node
index += 1
k = int(np.floor(np.random.rand() * len(result)))
walk_res[i] = result[k]
return walk_res
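The removed `walk_one` is superseded by the existing `walk` method: with `walk_length=2` each row of the result is assumed to be `[start_node, next_node]`, so column 1 gives the same one-hop positive sample. A toy sketch of that assumed behaviour on a CSR graph (simplified, not the compiled CogDL implementation):

```python
# Toy sketch of the assumed walk semantics on a CSR graph (indptr/indices):
# each row is a walk starting at the given node; column 1 is the first hop.
import numpy as np

def toy_walk(start, walk_length, indptr, indices):
    out = np.empty((len(start), walk_length), dtype=np.int64)
    for i, node in enumerate(start):
        out[i, 0] = node
        for step in range(1, walk_length):
            lo, hi = indptr[node], indptr[node + 1]
            node = indices[np.random.randint(lo, hi)] if hi > lo else node
            out[i, step] = node
    return out
```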
@@ -1,52 +1,54 @@
import torch

import numpy as np
from cogdl.utils import RandomWalker
from cogdl.wrappers.tools.wrapper_utils import evaluate_node_embeddings_using_logreg
from cogdl.wrappers.tools.wrapper_utils import evaluate_node_embeddings_using_liblinear
from .. import UnsupervisedModelWrapper

from torch.nn import functional as F

class UnsupGraphSAGEModelWrapper(UnsupervisedModelWrapper):
@staticmethod
def add_args(parser):
# fmt: off
parser.add_argument("--num-shuffle", type=int, default=1)
parser.add_argument("--training-percents", default=[0.2], type=float, nargs="+")
parser.add_argument("--walk-length", type=int, default=10)
parser.add_argument("--negative-samples", type=int, default=30)
# fmt: on

def __init__(self, model, optimizer_cfg, walk_length, negative_samples):
def __init__(self, model, optimizer_cfg, walk_length, negative_samples, num_shuffle=1, training_percents=[0.1]):
super(UnsupGraphSAGEModelWrapper, self).__init__()
self.model = model
self.optimizer_cfg = optimizer_cfg
self.walk_length = walk_length
self.num_negative_samples = negative_samples
self.num_shuffle = num_shuffle
self.training_percents = training_percents


def train_step(self, batch):
x_src, adjs = batch
out = self.model(x_src,adjs)
out, pos_out, neg_out = out.split(out.size(0) // 3, dim=0)

pos_loss = torch.log(torch.sigmoid((out * pos_out).sum(-1)).mean())
neg_loss = torch.log(torch.sigmoid(-(out * neg_out).sum(-1)).mean())
pos_loss = F.logsigmoid((out * pos_out).sum(-1)).mean()
neg_loss = F.logsigmoid(-(out * neg_out).sum(-1)).mean()
loss = -pos_loss - neg_loss
return loss
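The switch to `F.logsigmoid` keeps the per-pair averaging of the standard unsupervised GraphSAGE objective and is numerically stabler than taking `log` of an averaged `sigmoid`, as the old lines did. A self-contained sketch of the loss as written above (shapes assumed to be `[B, d]`):

```python
# Self-contained version of the loss used in train_step; out, pos_out and
# neg_out are assumed to be [batch_size, dim] embedding tensors.
import torch
import torch.nn.functional as F

def unsup_sage_loss(out, pos_out, neg_out):
    pos_loss = F.logsigmoid((out * pos_out).sum(-1)).mean()   # pull positives together
    neg_loss = F.logsigmoid(-(out * neg_out).sum(-1)).mean()  # push negatives apart
    return -pos_loss - neg_loss
```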


def test_step(self, batch):
dataset, test_loader = batch
def test_step(self, graph):
dataset, test_loader = graph
graph = dataset.data
if hasattr(self.model, "inference"):
pred = self.model.inference(graph.x, test_loader)
with torch.no_grad():
if hasattr(self.model, "inference"):
pred = self.model.inference(graph.x, test_loader)
else:
pred = self.model(graph)
if len(graph.y.shape) > 1:
self.label_matrix = graph.y.numpy()
else:
pred = self.model(graph)
pred= pred.split(pred.size(0) // 3, dim=0)[0]
pred = pred[graph.test_mask]
y = graph.y[graph.test_mask]

metric = self.evaluate(pred, y, metric="auto")
self.note("test_loss", self.default_loss_fn(pred, y))
self.note("test_metric", metric)
self.label_matrix = np.zeros((graph.num_nodes, graph.num_classes), dtype=int)
self.label_matrix[range(graph.num_nodes), graph.y.numpy()] = 1
return evaluate_node_embeddings_using_liblinear(pred, self.label_matrix, self.num_shuffle, self.training_percents)
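Evaluation now goes through the liblinear-style node-classification protocol: the embeddings are frozen, the labels are turned into a one-hot matrix, and a linear classifier is trained at each entry of `training_percents`, so the returned dict is keyed like `micro-f1 0.1` (which the updated test below asserts on). A minimal sketch of the label-matrix construction, assuming `y` is a 1-D tensor of class ids:

```python
# Sketch of the one-hot label matrix fed to the liblinear evaluator; y is
# assumed to be a 1-D LongTensor of class indices.
import numpy as np
import torch

y = torch.tensor([0, 2, 1, 2])
num_nodes, num_classes = y.numel(), int(y.max()) + 1
label_matrix = np.zeros((num_nodes, num_classes), dtype=int)
label_matrix[range(num_nodes), y.numpy()] = 1
```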


def setup_optimizer(self):
2 changes: 1 addition & 1 deletion tests/models/ssl/test_contrastive_models.py
@@ -49,7 +49,7 @@ def test_unsupervised_graphsage():
args.epochs = 2
args.checkpoint_path = "graphsage.pt"
ret = train(args)
assert ret["test_acc"] > 0
assert ret["micro-f1 0.1"] > 0


def test_dgi():