Source code for neuralkg_ind.model.GNNModel.Grail

import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
from .model import Model
from .RGCN import RelGraphConv
from neuralkg_ind.utils.tools import *

class Grail(nn.Module):
    """`Inductive Relation Prediction by Subgraph Reasoning`_ (Grail), which reasons over local subgraph structures.

    Attributes:
        args: Model configuration parameters.
        rel_emb: Relation embedding, shape: [num_rel, rel_emb_dim].
        gnn: RGCN model.

    .. _Inductive Relation Prediction by Subgraph Reasoning: https://arxiv.org/abs/1911.06962
    """
    def __init__(self, args):
        super().__init__()
        self.args = args
        self.ent_emb = None
        self.gnn = RGCN(args=args, basiclayer=RelAttGraphConv)
        self.rel_emb = nn.Embedding(self.args.num_rel, self.args.rel_emb_dim, sparse=False)

        if self.args.add_ht_emb:
            self.fc_layer = nn.Linear(3 * self.args.num_layers * self.args.emb_dim + self.args.rel_emb_dim, 1)
        else:
            self.fc_layer = nn.Linear(self.args.num_layers * self.args.emb_dim + self.args.rel_emb_dim, 1)
    def forward(self, data):
        """Calculating subgraph scores.

        Args:
            data: Tuple of subgraphs and relation labels.

        Returns:
            output: The score of each subgraph.
        """
        g, rel_labels = data
        g = dgl.batch(g)
        g.ndata['h'], _ = self.gnn(g)

        g_out = dgl.mean_nodes(g, 'repr')

        head_ids = (g.ndata['id'] == 1).nonzero().squeeze(1)
        head_embs = g.ndata['repr'][head_ids]

        tail_ids = (g.ndata['id'] == 2).nonzero().squeeze(1)
        tail_embs = g.ndata['repr'][tail_ids]

        if self.args.add_ht_emb:
            g_rep = torch.cat([g_out.view(-1, self.args.num_layers * self.args.emb_dim),
                               head_embs.view(-1, self.args.num_layers * self.args.emb_dim),
                               tail_embs.view(-1, self.args.num_layers * self.args.emb_dim),
                               self.rel_emb(rel_labels)], dim=1)
        else:
            g_rep = torch.cat([g_out.view(-1, self.args.num_layers * self.args.emb_dim),
                               self.rel_emb(rel_labels)], dim=1)

        output = self.fc_layer(g_rep)
        return output
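# A minimal usage sketch (illustrative only, not part of the library): the
# hyperparameter values below are hypothetical, and the argparse.Namespace
# stands in for the config object the NeuralKG-ind pipeline normally builds.
# Subgraphs are assumed to carry ndata 'feat' (size inp_dim) and 'id'
# (1 = head, 2 = tail), plus edata 'type' and 'label'.
#
#   import argparse
#   args = argparse.Namespace(
#       num_rel=10, rel_emb_dim=32, emb_dim=32, inp_dim=32, num_layers=3,
#       add_ht_emb=True, has_attn=True, attn_rel_emb_dim=32, aug_num_rels=20,
#       num_bases=4, dropout=0.0, edge_dropout=0.0, gnn_agg_type='sum')
#   model = Grail(args)
#   scores = model((subgraphs, rel_labels))  # subgraphs: list of DGLGraph,
#                                            # rel_labels: LongTensor [batch]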
class RGCN(Model):
    """RGCN model

    Attributes:
        args: Model configuration parameters.
        basiclayer: Layer of RGCN model.
        inp_dim: Dimension of input.
        emb_dim: Dimension of embedding.
        has_attn: Whether there is an attention mechanism.
        attn_rel_emb: Embedding of relation attention.
        attn_rel_emb_dim: Dimension of relation attention embedding.
    """
    def __init__(self, args, basiclayer):
        super(RGCN, self).__init__(args)

        self.args = args
        self.basiclayer = basiclayer
        self.inp_dim = args.inp_dim
        self.emb_dim = args.emb_dim
        self.has_attn = args.has_attn

        self.attn_rel_emb = None
        self.attn_rel_emb_dim = args.attn_rel_emb_dim

        self.init_emb()
        self.build_model()
    def init_emb(self):
        """Initialize the relation attention embedding, aggregator and features."""
        if self.has_attn:
            self.attn_rel_emb = nn.Embedding(self.args.num_rel, self.attn_rel_emb_dim, sparse=False)

        aggregator_type = self.args.gnn_agg_type.upper() + "Aggregator"
        aggregator_class = import_class(f"neuralkg_ind.model.{aggregator_type}")
        self.aggregator = aggregator_class(self.emb_dim)

        # create initial features
        self.features = torch.arange(self.inp_dim)
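    # Note: the aggregator class is resolved by name at run time. For example,
    # gnn_agg_type='sum' would resolve to neuralkg_ind.model.SUMAggregator
    # (and likewise 'mlp' / 'gru'), assuming those aggregator classes are
    # exported from neuralkg_ind.model.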
    def build_hidden_layer(self, idx):
        """Build a hidden layer of RGCN.

        Args:
            idx: The index of the layer.

        Returns:
            output: A basic layer, configured according to whether it is the first layer.
        """
        input_flag = idx == 0
        input_emb = self.inp_dim if idx == 0 else self.emb_dim
        return self.basiclayer(self.args, input_emb, self.emb_dim, self.aggregator,
                               self.attn_rel_emb_dim, self.args.aug_num_rels, self.args.num_bases,
                               None, F.relu, self.args.dropout, self.args.edge_dropout,
                               is_input_layer=input_flag, has_attn=self.has_attn)
    def forward(self, graph, rela=None):
        """Getting node and relation embeddings.

        Args:
            graph: Subgraph of the corresponding triple.
            rela: Embedding of relation.

        Returns:
            graph.ndata.pop('h'): Node embedding.
            rela: Relation embedding.
        """
        for layer in self.layers:
            rela = layer(graph, rela, self.attn_rel_emb)
        return graph.ndata.pop('h'), rela
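# Note on shapes (a reading aid, not library documentation): self.layers is
# assembled by build_model(), presumably inherited from the Model base class,
# which calls build_hidden_layer once per layer index. Every layer appends its
# output to graph.ndata['repr'], so after the loop 'repr' has shape
# [num_nodes, num_layers, emb_dim] -- exactly what Grail.forward flattens
# with .view(-1, num_layers * emb_dim).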
class RelAttGraphConv(RelGraphConv):
    """Basic layer of RGCN.

    Attributes:
        args: Model configuration parameters.
        bias: Weight bias.
        inp_dim: Dimension of input.
        out_dim: Dimension of output.
        num_rels: The number of relations.
        num_bases: The number of bases.
        has_attn: Whether there is an attention mechanism.
        is_input_layer: Whether it is the input layer.
        aggregator: Type of aggregator.
        weight: Weight matrix.
        w_comp: Bases matrix.
        self_loop_weight: Self-loop weight.
        edge_dropout: Dropout of edges.
    """
    def __init__(self, args, inp_dim, out_dim, aggregator, attn_rel_emb_dim, num_rels,
                 num_bases=-1, bias=None, activation=None, dropout=0.0, edge_dropout=0.0,
                 is_input_layer=False, has_attn=False):
        super().__init__(args, inp_dim, out_dim, 0, None, 0, bias=False, activation=activation,
                         self_loop=False, dropout=dropout, layer_norm=False)

        self.bias = bias
        self.inp_dim = inp_dim
        self.out_dim = out_dim
        self.num_rels = num_rels
        self.num_bases = num_bases
        self.is_input_layer = is_input_layer
        self.has_attn = has_attn
        self.aggregator = aggregator

        if self.bias:
            self.bias = nn.Parameter(torch.Tensor(out_dim))
            nn.init.xavier_uniform_(self.bias, gain=nn.init.calculate_gain('relu'))

        if self.num_bases <= 0 or self.num_bases > self.num_rels:
            self.num_bases = self.num_rels

        # basis-decomposed relation weights: W_r is mixed from num_bases shared matrices
        self.weight = nn.Parameter(torch.Tensor(self.num_bases, self.inp_dim, self.out_dim))
        self.w_comp = nn.Parameter(torch.Tensor(self.num_rels, self.num_bases))

        if self.has_attn:
            self.A = nn.Linear(2 * self.inp_dim + 2 * attn_rel_emb_dim, inp_dim)
            self.B = nn.Linear(inp_dim, 1)

        self.self_loop_weight = nn.Parameter(torch.Tensor(self.inp_dim, self.out_dim))
        self.edge_dropout = nn.Dropout(edge_dropout)

        nn.init.xavier_uniform_(self.self_loop_weight, gain=nn.init.calculate_gain('relu'))
        nn.init.xavier_uniform_(self.weight, gain=nn.init.calculate_gain('relu'))
        nn.init.xavier_uniform_(self.w_comp, gain=nn.init.calculate_gain('relu'))
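    # When has_attn is set, propagate() scores each edge with a two-layer MLP
    # over the concatenated source/target features and the attention embeddings
    # of the edge's relation type and the query relation label, roughly:
    #   alpha = sigmoid(B(relu(A([h_src ; h_dst ; r_type ; r_label]))))
    # which is why A takes 2 * inp_dim + 2 * attn_rel_emb_dim input features.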
    def propagate(self, g, attn_rel_emb=None):
        """Message propagation function.

        Propagate messages and perform calculations according to the graph traversal order.

        Args:
            g: Subgraph of triple.
            attn_rel_emb: Relation attention embedding.
        """
        # compose per-relation weight matrices from the shared bases
        weight = self.weight.view(self.num_bases, self.inp_dim * self.out_dim)
        weight = torch.matmul(self.w_comp, weight).view(self.num_rels, self.inp_dim, self.out_dim)

        g.edata['w'] = self.edge_dropout(torch.ones(g.number_of_edges(), 1)).type_as(weight)

        input_ = 'feat' if self.is_input_layer else 'h'

        def message(edges):
            w = weight.index_select(0, edges.data['type'])
            msg = edges.data['w'] * torch.bmm(edges.src[input_].unsqueeze(1), w).squeeze(1)
            curr_emb = torch.mm(edges.dst[input_], self.self_loop_weight)  # (B, F)

            if self.has_attn:
                e = torch.cat([edges.src[input_], edges.dst[input_],
                               attn_rel_emb(edges.data['type']), attn_rel_emb(edges.data['label'])], dim=1)
                a = torch.sigmoid(self.B(F.relu(self.A(e))))
            else:
                a = torch.ones((len(edges), 1))

            return {'curr_emb': curr_emb, 'msg': msg, 'alpha': a}

        g.update_all(message, self.aggregator, None)
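    # A standalone sketch of the basis decomposition computed above
    # (illustrative shapes, not library code): each relation weight W[r] is a
    # learned mixture of num_bases shared matrices.
    #
    #   import torch
    #   num_rels, num_bases, inp_dim, out_dim = 5, 2, 4, 4
    #   bases = torch.randn(num_bases, inp_dim, out_dim)     # shared bases
    #   coeffs = torch.randn(num_rels, num_bases)            # per-relation mixing
    #   W = torch.matmul(coeffs, bases.view(num_bases, -1))  # [num_rels, inp*out]
    #   W = W.view(num_rels, inp_dim, out_dim)               # W[r] = sum_b coeffs[r, b] * bases[b]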
    def forward(self, g, rel_emb=None, attn_rel_emb=None):
        """Update node representation.

        Args:
            g: Subgraph of the corresponding triple.
            rel_emb: Embedding of relation.
            attn_rel_emb: Embedding of relation attention.

        Returns:
            rel_emb: Embedding of relation, passed through unchanged.
        """
        self.propagate(g, attn_rel_emb)

        # apply bias and activation
        node_repr = g.ndata['h']
        if self.bias:
            node_repr = node_repr + self.bias
        if self.activation:
            node_repr = self.activation(node_repr)
        node_repr = self.dropout(node_repr)

        g.ndata['h'] = node_repr

        if self.is_input_layer:
            g.ndata['repr'] = g.ndata['h'].unsqueeze(1)
        else:
            g.ndata['repr'] = torch.cat([g.ndata['repr'], g.ndata['h'].unsqueeze(1)], dim=1)

        return rel_emb