# Source code for neuralkg_ind.model.GNNModel.layer

import abc
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class GNNLayer(nn.Module):
    """Base class for message-passing layers.

    Subclasses are expected to override :meth:`message` and :meth:`forward`;
    the implementations here only raise ``NotImplementedError``.
    """

    def __init__(self):
        super(GNNLayer, self).__init__()

    def message(self, edges):
        """Build per-edge messages; must be provided by a subclass."""
        raise NotImplementedError

    def forward(self, g, feat):
        """Apply the layer to ``g`` (presumably a graph) with features ``feat``.

        Must be provided by a subclass.
        """
        raise NotImplementedError
class Aggregator(nn.Module):
    """Base class for node updates that reduce an incoming mailbox.

    :meth:`forward` forms an attention-weighted sum of neighbour messages and
    hands it, together with the node's current embedding, to
    :meth:`update_embedding`, which subclasses must implement.

    Args:
        emb_dim: Embedding size. Unused by this base class; kept so every
            subclass shares the same constructor signature.
    """

    def __init__(self, emb_dim):
        super(Aggregator, self).__init__()

    def forward(self, node):
        """Aggregate ``node.mailbox`` into a new embedding.

        Args:
            node: Object exposing a ``mailbox`` mapping with the keys
                ``'curr_emb'``, ``'alpha'`` and ``'msg'``, each indexed here as a
                3-D tensor (presumably a DGL node batch -- confirm with caller).

        Returns:
            dict: ``{'h': new_emb}``, where ``new_emb`` is the result of
            :meth:`update_embedding`.
        """
        # Only the first mailbox slot's current embedding is used (presumably
        # every slot holds the same value -- confirm).
        curr_emb = node.mailbox['curr_emb'][:, 0, :]
        # alpha^T @ msg: attention-weighted sum over mailbox slots. Assumes
        # alpha is (B, K, 1) so squeeze(1) yields a 2-D result -- TODO confirm.
        nei_msg = torch.bmm(
            node.mailbox['alpha'].transpose(1, 2), node.mailbox['msg']
        ).squeeze(1)
        new_emb = self.update_embedding(curr_emb, nei_msg)
        return {'h': new_emb}

    @abc.abstractmethod
    def update_embedding(self, curr_emb, nei_msg):
        """Combine the current embedding with the aggregated neighbour message.

        Subclasses must override this. ``nn.Module`` does not use ``ABCMeta``,
        so the decorator is not enforced at instantiation; the explicit raise
        below is what reports a missing override.

        Raises:
            NotImplementedError: always, in this base class.
        """
        # ``self`` is required: forward() calls this as a bound method.
        raise NotImplementedError
class SUMAggregator(Aggregator):
    """Aggregator whose update adds the neighbour message to the current embedding."""

    def __init__(self, emb_dim):
        super().__init__(emb_dim)

    def update_embedding(self, curr_emb, nei_msg):
        """Return ``nei_msg + curr_emb`` (element-wise)."""
        return nei_msg + curr_emb
class MLPAggregator(Aggregator):
    """Aggregator that fuses both inputs through one linear layer and a ReLU.

    Args:
        emb_dim: Width of ``curr_emb`` and ``nei_msg``; the linear layer maps
            their ``2 * emb_dim`` concatenation back to ``emb_dim``.
    """

    def __init__(self, emb_dim):
        super().__init__(emb_dim)
        # Attribute name ``linear`` is part of the state_dict keys.
        self.linear = nn.Linear(2 * emb_dim, emb_dim)

    def update_embedding(self, curr_emb, nei_msg):
        """Return ``relu(linear([nei_msg, curr_emb]))`` concatenated on dim 1."""
        # Neighbour message comes first: the learned weights depend on order.
        joint = torch.cat([nei_msg, curr_emb], dim=1)
        return F.relu(self.linear(joint))
class GRUAggregator(Aggregator):
    """Aggregator that updates the embedding with a single ``GRUCell`` step.

    Args:
        emb_dim: Input and hidden size of the GRU cell.
    """

    def __init__(self, emb_dim):
        super().__init__(emb_dim)
        # Attribute name ``gru`` is part of the state_dict keys.
        self.gru = nn.GRUCell(emb_dim, emb_dim)

    def update_embedding(self, curr_emb, nei_msg):
        """Advance the GRU cell: ``nei_msg`` is the input, ``curr_emb`` the hidden state."""
        return self.gru(nei_msg, curr_emb)
class BatchGRU(nn.Module):
    """Bidirectional GRU run over variable-sized, consecutive groups of rows.

    The rows of the input are split into consecutive groups whose sizes come
    from ``a_scope``. Each group is zero-padded to the largest group size, and
    all groups are run as one batch through a bidirectional GRU. The padding is
    then stripped again.

    Args:
        hidden_size: Feature width of each input row and of each GRU direction.
    """

    def __init__(self, hidden_size=300):
        super(BatchGRU, self).__init__()
        self.hidden_size = hidden_size
        self.gru = nn.GRU(self.hidden_size, self.hidden_size, batch_first=True,
                          bidirectional=True)
        self.bias = nn.Parameter(torch.Tensor(self.hidden_size))
        bound = 1.0 / math.sqrt(self.hidden_size)
        self.bias.data.uniform_(-bound, bound)

    def forward(self, node, a_scope):
        """Run the bidirectional GRU independently over each group of rows.

        Args:
            node: 2-D tensor ``(total_rows, hidden_size)``; groups are taken
                from its leading rows in order.
            a_scope: Sequence of positive group sizes (ints or 0-d tensors).

        Returns:
            Tensor of shape ``(sum(a_scope), 2 * hidden_size)`` holding the GRU
            outputs of both directions for every un-padded row.

        Raises:
            ValueError: if ``a_scope`` is empty or contains a size <= 0.
        """
        sizes = [int(size) for size in a_scope]
        if not sizes:
            raise ValueError("a_scope must contain at least one group size")
        # Explicit check instead of `assert`, which is stripped under -O.
        if any(size <= 0 for size in sizes):
            raise ValueError("a_scope sizes must be positive, got %r" % (sizes,))

        message = F.relu(node + self.bias)
        max_len = max(sizes)

        padded_msgs = []
        init_hidden = []
        start = 0
        for size in sizes:
            group_msg = message.narrow(0, start, size)
            group_raw = node.narrow(0, start, size)
            # Max-pool of the raw (pre-bias, pre-ReLU) rows seeds the GRU state.
            init_hidden.append(group_raw.max(0)[0].unsqueeze(0).unsqueeze(0))
            # Zero-pad rows (dim -2) at the bottom up to max_len.
            padded_msgs.append(
                F.pad(group_msg, (0, 0, 0, max_len - size)).unsqueeze(0))
            start += size

        batch = torch.cat(padded_msgs, 0)  # (num_groups, max_len, hidden_size)
        # Both directions share the same initial state: (2, num_groups, hidden).
        h0 = torch.cat(init_hidden, 1).repeat(2, 1, 1)
        output, _ = self.gru(batch, h0)

        unpadded = [output[k, :size].view(-1, 2 * self.hidden_size)
                    for k, size in enumerate(sizes)]
        return torch.cat(unpadded, 0)