Source code for neuralkg_ind.model.GNNModel.CoMPILE

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable    

import numpy as np
from .layer import BatchGRU

class CoMPILE(nn.Module):
    """`Communicative Message Passing for Inductive Relation Reasoning`_ (CoMPILE),
    which reasons over local directed subgraph structures and strengthens the
    message interactions between edges and entities through a communicative kernel.

    Attributes:
        args: Model configuration parameters.
        latent_dim: Latent dimension.
        output_dim: Output dimension.
        node_emb: Dimension of node embedding.
        relation_emb: Dimension of relation embedding.
        hidden_size: Size of hidden layer.

    .. _Communicative Message Passing for Inductive Relation Reasoning:
        https://arxiv.org/pdf/2012.08911
    """

    def __init__(self, args):
        super(CoMPILE, self).__init__()
        self.args = args
        self.latent_dim = self.args.emb_dim
        self.output_dim = 1

        self.node_emb = self.args.inp_dim
        self.relation_emb = self.args.rel_emb_dim
        self.edge_emb = self.node_emb * 2 + self.relation_emb
        self.hidden_size = self.args.emb_dim

        self.final_relation_embeddings = nn.Parameter(
            torch.randn(self.args.aug_num_rels, self.args.rel_emb_dim))
        self.relation_to_edge = nn.Linear(self.args.rel_emb_dim, self.hidden_size)

        self.linear1 = nn.Linear(self.args.emb_dim, 16)
        self.linear2 = nn.Linear(16, 1)

        self.node_fdim = self.node_emb
        self.edge_fdim = self.edge_emb

        self.bias = False
        self.depth = 3
        self.dropout = 0.5
        self.layers_per_message = 1
        self.undirected = False
        self.node_messages = False

        self.dropout_layer = nn.Dropout(p=self.dropout)
        self.act_func = nn.ReLU()
        self.cached_zero_vector = nn.Parameter(torch.zeros(self.hidden_size),
                                               requires_grad=False)

        # Input transformations for node and edge features.
        input_dim = self.node_fdim
        self.W_i_node = nn.Linear(input_dim, self.hidden_size, bias=self.bias)
        input_dim = self.edge_fdim
        self.W_i_edge = nn.Linear(input_dim, self.hidden_size, bias=self.bias)

        w_h_input_size_node = self.hidden_size + self.edge_fdim
        self.W_h_node = nn.Linear(w_h_input_size_node, self.hidden_size, bias=self.bias)

        self.input_attention1 = nn.Linear(self.hidden_size * 2, self.hidden_size, bias=self.bias)
        self.input_attention2 = nn.Linear(self.hidden_size, 1, bias=self.bias)

        # Per-depth edge update and attention layers.
        w_h_input_size_edge = self.hidden_size
        for depth in range(self.depth - 1):
            self._modules['W_h_edge_{}'.format(depth)] = nn.Linear(
                w_h_input_size_edge, self.hidden_size, bias=self.bias)
            self._modules['Attention1_{}'.format(depth)] = nn.Linear(
                self.hidden_size + self.relation_emb, self.hidden_size, bias=self.bias)
            self._modules['Attention2_{}'.format(depth)] = nn.Linear(
                self.hidden_size, 1, bias=self.bias)

        self.W_o = nn.Linear(self.hidden_size * 2, self.hidden_size)
        self.gru = BatchGRU(self.hidden_size)
        self.communicate_mlp = nn.Linear(self.hidden_size * 3, self.hidden_size, bias=self.bias)

        # Per-depth node update layers.
        for depth in range(self.depth - 1):
            self._modules['W_h_node_{}'.format(depth)] = nn.Linear(
                self.hidden_size, self.hidden_size, bias=self.bias)
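
    # Configuration summary: __init__ reads args.emb_dim (hidden/latent size),
    # args.inp_dim (node feature size), args.rel_emb_dim (relation embedding size,
    # which the additive edge/node updates below implicitly require to equal
    # emb_dim), and args.aug_num_rels (number of augmented relation types).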

    def forward(self, subgraph):
        """Calculate the scores of the batched subgraphs.

        Args:
            subgraph: Subgraphs of the triples.

        Returns:
            out_conv: The output of the convolution layers, i.e. the triple scores.
        """
        subgraph = subgraph[0]
        target_relation = []
        for i in range(len(subgraph)):
            graph = subgraph[i]
            target = graph.edata['label'][-1].squeeze()
            target_relation.append(self.final_relation_embeddings[target, :].unsqueeze(0))
        target_relation = torch.cat(target_relation, dim=0)

        graph_embed, source_embed, target_embed = self.batch_subgraph(subgraph)
        conv_input = torch.tanh(source_embed + target_relation - target_embed)
        out_conv = self.linear1(conv_input)
        out_conv = self.linear2(out_conv)
        return out_conv
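
    # Scoring recap: forward composes the entity and relation embeddings in a
    # TransE-like fashion, score = linear2(linear1(tanh(h_source + r_target - h_target))),
    # where h_source / h_target come from CoMPILEConv via batch_subgraph and
    # r_target is the embedding of the relation being predicted.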

    def batch_subgraph(self, subgraph):
        """Batch the subgraphs and compute their graph, source and target embeddings.

        Args:
            subgraph: Subgraphs of the triples.

        Returns:
            graph_embed: Embedding of subgraph.
            source_embed: Embedding of source entities.
            target_embed: Embedding of target entities.
        """
        graph_sizes = []
        node_feat = []
        list_num_nodes = np.zeros((len(subgraph), ), dtype=np.int32)
        list_num_edges = np.zeros((len(subgraph), ), dtype=np.int32)
        node_count = 0
        edge_count = 0
        edge_feat = []
        total_edge = []
        source_node = []
        target_node = []
        total_target_relation = []
        total_edge2 = []
        total_source = []
        total_target = []

        for i in range(len(subgraph)):
            graph = subgraph[i]
            node_embedding = graph.ndata['feat']
            node_feat.append(node_embedding)

            graph_sizes.append(graph.number_of_nodes())
            list_num_nodes[i] = graph.number_of_nodes()
            list_num_edges[i] = graph.number_of_edges()

            nodes = list((graph.nodes()).data.cpu().numpy())
            source = list((graph.edges()[0]).data.cpu().numpy())
            target = list((graph.edges()[1]).data.cpu().numpy())

            relation = graph.edata['type']
            relation_now = self.final_relation_embeddings[relation, :]
            target_relation = graph.edata['label']
            target_relation_now = self.final_relation_embeddings[target_relation, :]
            total_target_relation.append(target_relation_now)

            # Map local node ids to positions in the batched node tensor.
            mapping = dict(zip(nodes, range(node_count, node_count + list_num_nodes[i])))
            source_map_now = np.array([mapping[v] for v in source]) - node_count
            target_map_now = np.array([mapping[v] for v in target]) - node_count
            source_embed = node_embedding[source_map_now, :]
            target_embed = node_embedding[target_map_now, :]
            edge_embed = torch.cat([source_embed, relation_now, target_embed], dim=1)
            edge_feat.append(edge_embed)

            # Indices of the head (id == 1) and tail (id == 2) entities of the target triple.
            source_now = (graph.ndata['id'] == 1).nonzero().squeeze() + node_count
            target_now = (graph.ndata['id'] == 2).nonzero().squeeze() + node_count
            source_node.append(source_now)
            target_node.append(target_now)
            target_now = target_now.unsqueeze(0).repeat(list_num_edges[i], 1).long()
            source_now = source_now.unsqueeze(0).repeat(list_num_edges[i], 1).long()
            total_source.append(source_now)
            total_target.append(target_now)
            node_count += list_num_nodes[i]

            # Edge-to-node incidence indices: [target node, edge id] and [source node, edge id].
            source_map = torch.LongTensor(np.array([mapping[v] for v in source])).unsqueeze(0)
            target_map = torch.LongTensor(np.array([mapping[v] for v in target])).unsqueeze(0)
            edge_pair = torch.cat([target_map, torch.LongTensor(
                np.array(range(edge_count, edge_count + list_num_edges[i]))).unsqueeze(0)], dim=0)
            edge_pair2 = torch.cat([source_map, torch.LongTensor(
                np.array(range(edge_count, edge_count + list_num_edges[i]))).unsqueeze(0)], dim=0)
            edge_count += list_num_edges[i]
            total_edge.append(edge_pair)
            total_edge2.append(edge_pair2)

        source_node = np.array(torch.tensor(source_node, device='cpu'))
        target_node = np.array(torch.tensor(target_node, device='cpu'))

        total_edge = torch.cat(total_edge, dim=1).type_as(self.final_relation_embeddings).long()
        total_edge2 = torch.cat(total_edge2, dim=1).type_as(self.final_relation_embeddings).long()
        total_target_relation = torch.cat(total_target_relation, dim=0)
        total_source = torch.cat(total_source, dim=0)
        total_target = torch.cat(total_target, dim=0)

        total_num_nodes = np.sum(list_num_nodes)
        total_num_edges = np.sum(list_num_edges)

        e2n_value = torch.FloatTensor(torch.ones(total_edge.shape[1])).type_as(self.final_relation_embeddings)
        e2n_sp = torch.sparse.FloatTensor(total_edge, e2n_value,
                                          torch.Size([total_num_nodes, total_num_edges]))
        e2n_sp2 = torch.sparse.FloatTensor(total_edge2, e2n_value,
                                           torch.Size([total_num_nodes, total_num_edges]))

        node_feat = torch.cat(node_feat, dim=0)
        edge_feat = torch.cat(edge_feat, dim=0)

        graph_embed, source_embed, target_embed = self.CoMPILEConv(
            node_feat, edge_feat, e2n_sp, e2n_sp2, graph_sizes,
            total_target_relation, total_source, total_target,
            source_node, target_node, list(list_num_edges))
        return graph_embed, source_embed, target_embed
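
    # Worked example for the incidence matrices built above: for a subgraph with
    # edges (0 -> 1) and (1 -> 2), edge_pair stacks [target node, edge id] rows
    # [[1, 2], [0, 1]] and edge_pair2 stacks [source node, edge id] rows
    # [[0, 1], [0, 1]]. Multiplying e2n_sp by a (num_edges x hidden) matrix sums
    # incoming edge messages per node, while e2n_sp.t() / e2n_sp2.t() gather each
    # edge's target / source node message in CoMPILEConv.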

    def CoMPILEConv(self, node_feat, edge_feat, e2n_sp, e2n_sp2, graph_sizes,
                    target_relation, total_source, total_target,
                    source_node, target_node, edge_sizes=None, node_degs=None):
        """Calculate the graph, source and target embeddings.

        Args:
            node_feat: Feature of nodes.
            edge_feat: Feature of edges.
            e2n_sp: Sparse edge-to-node incidence matrix indexed by edge target nodes.
            e2n_sp2: Sparse edge-to-node incidence matrix indexed by edge source nodes.
            graph_sizes: The number of nodes in each graph.
            target_relation: Embeddings of the target relations.
            total_source: Total source nodes.
            total_target: Total target nodes.
            source_node: Source node of each triple.
            target_node: Target node of each triple.
            edge_sizes: The number of edges in each graph.
            node_degs: The degrees of nodes.

        Returns:
            mol_vecs: Graph embedding.
            source_embed: Source node embedding.
            target_embed: Target node embedding.
        """
        input_node = self.W_i_node(node_feat)  # num_nodes x hidden_size
        input_node = self.act_func(input_node)
        message_node = input_node.clone()
        relation_embed = edge_feat[:, self.node_emb: self.node_emb + self.relation_emb]

        input_edge = self.W_i_edge(edge_feat)  # num_edges x hidden_size
        message_edge = self.act_func(input_edge)
        input_edge = self.act_func(input_edge)

        graph_source_embed = message_node[total_source, :].squeeze(1)
        graph_target_embed = message_node[total_target, :].squeeze(1)
        graph_edge_embed = graph_source_embed + target_relation - graph_target_embed
        edge_target_message = gnn_spmm(e2n_sp.t(), message_node)
        edge_source_message = gnn_spmm(e2n_sp2.t(), message_node)
        edge_message = edge_source_message + relation_embed - edge_target_message
        attention = torch.cat([graph_edge_embed, edge_message], dim=1)
        attention = torch.relu(self.input_attention1(attention))
        attention = torch.sigmoid(self.input_attention2(attention))

        # Message passing
        for depth in range(self.depth - 1):
            message_edge = (message_edge * attention)
            agg_message = gnn_spmm(e2n_sp, message_edge)
            message_node = message_node + agg_message
            message_node = self.act_func(self._modules['W_h_node_{}'.format(depth)](message_node))

            edge_target_message = gnn_spmm(e2n_sp.t(), message_node)
            edge_source_message = gnn_spmm(e2n_sp2.t(), message_node)
            message_edge = torch.relu(message_edge + torch.tanh(
                edge_source_message + relation_embed - edge_target_message))
            message_edge = self._modules['W_h_edge_{}'.format(depth)](message_edge)
            message_edge = self.act_func(input_edge + message_edge)
            message_edge = self.dropout_layer(message_edge)  # num_edges x hidden

            graph_source_embed = message_node[total_source, :].squeeze(1)
            graph_target_embed = message_node[total_target, :].squeeze(1)
            graph_edge_embed = graph_source_embed + target_relation - graph_target_embed
            edge_message = edge_source_message + relation_embed - edge_target_message
            attention = torch.cat([graph_edge_embed, edge_message], dim=1)
            attention = torch.relu(self._modules['Attention1_{}'.format(depth)](attention))
            attention = torch.sigmoid(self._modules['Attention2_{}'.format(depth)](attention))

        message_edge = (message_edge * attention)
        agg_message = gnn_spmm(e2n_sp, message_edge)
        agg_message2 = self.communicate_mlp(torch.cat([agg_message, message_node, input_node], 1))
        a_message = torch.relu(self.gru(agg_message2, graph_sizes))
        node_hiddens = self.act_func(self.W_o(a_message))  # num_nodes x hidden
        node_hiddens = self.dropout_layer(node_hiddens)  # num_nodes x hidden

        # Readout
        mol_vecs = []
        a_start = 0
        for a_size in graph_sizes:
            if a_size == 0:
                assert 0
            cur_hiddens = node_hiddens.narrow(0, a_start, a_size)
            mol_vecs.append(cur_hiddens.mean(0))
            a_start += a_size
        mol_vecs = torch.stack(mol_vecs, dim=0)

        source_embed = node_hiddens[source_node, :]
        target_embed = node_hiddens[target_node, :]
        return mol_vecs, source_embed, target_embed
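
# A minimal, hypothetical usage sketch (field names mirror what __init__ reads
# above; the SimpleNamespace values are illustrative, not from the original
# module). It instantiates CoMPILE and reproduces the translational scoring
# head on random stand-in embeddings.
def _compile_scoring_sketch():
    from types import SimpleNamespace
    args = SimpleNamespace(
        emb_dim=32,        # hidden / latent size
        inp_dim=16,        # node feature size
        rel_emb_dim=32,    # relation embedding size (must match emb_dim here)
        aug_num_rels=10,   # number of relation types
    )
    model = CoMPILE(args)
    h_source = torch.randn(4, args.emb_dim)   # stand-ins for CoMPILEConv outputs
    h_target = torch.randn(4, args.emb_dim)
    r_target = model.final_relation_embeddings[:4, :]
    scores = model.linear2(model.linear1(torch.tanh(h_source + r_target - h_target)))
    return scores  # shape: (4, 1)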

class MySpMM(torch.autograd.Function):
    """Sparse-dense matrix multiplication with a custom backward pass."""

    @staticmethod
    def forward(ctx, sp_mat, dense_mat):
        ctx.save_for_backward(sp_mat, dense_mat)
        return torch.mm(sp_mat, dense_mat)

    @staticmethod
    def backward(ctx, grad_output):
        sp_mat, dense_mat = ctx.saved_tensors
        grad_matrix1 = grad_matrix2 = None
        # Gradients are only propagated to the dense operand.
        assert not ctx.needs_input_grad[0]
        if ctx.needs_input_grad[1]:
            grad_matrix2 = Variable(torch.mm(sp_mat.data.t(), grad_output.data))
        return grad_matrix1, grad_matrix2

def gnn_spmm(sp_mat, dense_mat):
    return MySpMM.apply(sp_mat, dense_mat)
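
# A minimal sketch (the toy graph below is hypothetical, not from the original
# source) showing how gnn_spmm is used by CoMPILEConv: e2n_sp is a
# (num_nodes x num_edges) incidence matrix, so multiplying it with a
# (num_edges x hidden) edge-message matrix sums each node's incoming edge messages.
def _gnn_spmm_sketch():
    # Toy graph with 3 nodes and 2 edges; edge 0 points at node 1, edge 1 at node 2.
    indices = torch.tensor([[1, 2],    # target node of each edge
                            [0, 1]])   # edge id
    values = torch.ones(2)
    e2n_sp = torch.sparse_coo_tensor(indices, values, (3, 2))
    edge_messages = torch.randn(2, 8)            # num_edges x hidden
    node_agg = gnn_spmm(e2n_sp, edge_messages)   # num_nodes x hidden
    per_edge_target = gnn_spmm(e2n_sp.t(), torch.randn(3, 8))  # gather target-node states per edge
    return node_agg, per_edge_target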