Source code for neuralkg.model.GNNModel.RGCN

import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn.pytorch import RelGraphConv
from neuralkg.model import DistMult


class RGCN(nn.Module):
    """`Modeling Relational Data with Graph Convolutional Networks`_ (RGCN), which uses the GCN framework to model relational data.

    Entity embeddings are refined by a stack of relational graph convolution
    layers and the resulting triples are scored with the DistMult score
    function.

    Attributes:
        args: Model configuration parameters. This class reads ``num_ent``,
            ``num_rel``, ``emb_dim`` and ``num_layers`` from it.
        ent_emb: Entity embedding table, shape: [num_ent, emb_dim].
        rel_emb: Relation embedding parameter, shape: [num_rel, emb_dim].
        RGCN: ``nn.ModuleList`` holding the relational graph convolution layers.
        Loss_emb: Entity embeddings produced by the most recent encoder pass
            (``None`` until the model has been run once). Presumably consumed
            by the loss for regularization -- confirm against the caller.

    .. _Modeling Relational Data with Graph Convolutional Networks: https://arxiv.org/pdf/1703.06103.pdf
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        self.ent_emb = None
        self.rel_emb = None
        self.RGCN = None
        self.Loss_emb = None
        self.build_model()

    def build_model(self):
        """Initialize the embeddings and the stack of RGCN layers.

        Populates the following attributes:
            ent_emb: Entity embedding, shape: [num_ent, emb_dim].
            rel_emb: Relation embedding, shape: [num_rel, emb_dim],
                Xavier-uniform initialized with the ReLU gain.
            RGCN: The relational graph convolution layers, ``num_layers`` of them.
        """
        self.ent_emb = nn.Embedding(self.args.num_ent, self.args.emb_dim)

        self.rel_emb = nn.Parameter(torch.Tensor(self.args.num_rel, self.args.emb_dim))
        nn.init.xavier_uniform_(self.rel_emb, gain=nn.init.calculate_gain('relu'))

        self.RGCN = nn.ModuleList()
        for idx in range(self.args.num_layers):
            self.RGCN.append(self.build_hidden_layer(idx))

    def forward(self, graph, ent, rel, norm, triples, mode='single'):
        """The function used in the training and testing phase.

        Args:
            graph: The knowledge graph recorded in dgl.graph().
            ent: The entity ids of the graph nodes.
            rel: The relation ids of the graph edges.
            norm: The edge norm in graph.
            triples: The triples ids, as (h, r, t), shape: [batch_size, 3].
            mode: Choose head-predict or tail-predict, Defaults to 'single'.

        Returns:
            score: The score of triples.
        """
        return self._score(graph, ent, rel, norm, triples, mode)

    def get_score(self, batch, mode):
        """The function used in the testing phase.

        Args:
            batch: A batch of data. The keys ``'positive_sample'``, ``'graph'``,
                ``'entity'``, ``'rela'`` and ``'norm'`` are read from it.
            mode: Choose head-predict or tail-predict.

        Returns:
            score: The score of triples.
        """
        return self._score(
            batch['graph'],
            batch['entity'],
            batch['rela'],
            batch['norm'],
            batch['positive_sample'],
            mode,
        )

    def _score(self, graph, ent, rel, norm, triples, mode):
        """Encode the graph, then score ``triples`` with the DistMult score function."""
        embedding = self._encode(graph, ent, rel, norm)
        head_emb, rela_emb, tail_emb = self.tri2emb(embedding, triples, mode)
        # DistMult.score_func is borrowed as a plain function; this module is
        # passed in the ``self`` slot.
        return DistMult.score_func(self, head_emb, rela_emb, tail_emb, mode)

    def _encode(self, graph, ent, rel, norm):
        """Run the entity embeddings through every RGCN layer.

        The result is also stored in ``self.Loss_emb``.

        Returns:
            The node embeddings after the last layer.
        """
        ent_ids = ent.squeeze()
        if ent_ids.dim() == 0:
            # A single-node input (e.g. shape [1, 1]) is squeezed down to a
            # scalar; restore the node axis so the lookup yields [1, emb_dim]
            # rather than [emb_dim].
            ent_ids = ent_ids.unsqueeze(0)

        embedding = self.ent_emb(ent_ids)
        for layer in self.RGCN:
            embedding = layer(graph, embedding, rel, norm)

        self.Loss_emb = embedding
        return embedding

    def tri2emb(self, embedding, triples, mode="single"):
        # TODO: merge with XTransE
        """Get embedding of triples.

        This function gets the embeddings of head, relation, and tail
        respectively. Each embedding has three dimensions.

        Args:
            embedding (tensor): This embedding saves the entity embeddings.
            triples (tensor): This tensor saves triples id, whose dimension is
                [triples number, 3].
            mode (str, optional): This arg indicates that the negative entity
                will replace the head or tail entity. 'head-batch' /
                'head_predict' replace the head, 'tail-batch' / 'tail_predict'
                replace the tail; any other value (such as 'single') means that
                no entity is replaced. Defaults to 'single'.

        Returns:
            head_emb: Head entity embedding.
            rela_emb: Relation embedding.
            tail_emb: Tail entity embedding.
        """
        rela_emb = self.rel_emb[triples[:, 1]].unsqueeze(1)  # [bs, 1, dim]
        head_emb = embedding[triples[:, 0]].unsqueeze(1)  # [bs, 1, dim]
        tail_emb = embedding[triples[:, 2]].unsqueeze(1)  # [bs, 1, dim]

        # When predicting one side, every entity in ``embedding`` is a candidate.
        if mode in ("head-batch", "head_predict"):
            head_emb = embedding.unsqueeze(0)  # [1, num_ent, dim]
        elif mode in ("tail-batch", "tail_predict"):
            tail_emb = embedding.unsqueeze(0)  # [1, num_ent, dim]

        return head_emb, rela_emb, tail_emb

    def build_hidden_layer(self, idx, num_bases=100, dropout=0.2):
        """The function used to initialize one RGCN layer.

        Args:
            idx: Index of the layer being built. Every layer except the last
                one uses ReLU as activation function; the last layer has no
                activation.
            num_bases: Number of bases passed to the "bdd" regularizer.
                Defaults to 100.
                NOTE(review): DGL's block-diagonal decomposition is documented
                to require emb_dim to be divisible by num_bases -- confirm
                against the installed DGL version.
            dropout: Dropout rate of the layer. Defaults to 0.2.

        Returns:
            The relation graph convolution layer.
        """
        act = F.relu if idx < self.args.num_layers - 1 else None
        return RelGraphConv(
            self.args.emb_dim,
            self.args.emb_dim,
            self.args.num_rel,
            "bdd",
            num_bases=num_bases,
            activation=act,
            self_loop=True,
            dropout=dropout,
        )