import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn.pytorch import RelGraphConv
from neuralkg.model import DistMult
class RGCN(nn.Module):
    """`Modeling Relational Data with Graph Convolutional Networks`_ (RGCN), which uses a GCN framework to model relational data.

    Entity embeddings are refined by a stack of DGL ``RelGraphConv`` layers, and
    triples are then scored with DistMult's scoring function.

    Attributes:
        args: Model configuration parameters. Fields read by this class:
            ``num_ent``, ``num_rel``, ``emb_dim`` and ``num_layers``.
        ent_emb: Entity embedding table, shape: [num_ent, emb_dim].
        rel_emb: Relation embedding matrix, shape: [num_rel, emb_dim].
        RGCN: ``nn.ModuleList`` of ``num_layers`` relational graph convolution layers.
        Loss_emb: Final-layer entity embeddings from the most recent scoring call
            (``None`` until ``forward``/``get_score`` has run). Presumably read by
            the training loss for regularization -- confirm against callers.

    .. _Modeling Relational Data with Graph Convolutional Networks: https://arxiv.org/pdf/1703.06103.pdf
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        self.ent_emb = None
        self.rel_emb = None
        self.RGCN = None
        self.Loss_emb = None
        self.build_model()

    def build_model(self):
        """Initialize the RGCN layers and embeddings.

        Creates the following attributes:
            ent_emb: Entity embedding, shape: [num_ent, emb_dim].
            rel_emb: Relation embedding, shape: [num_rel, emb_dim], initialized
                with Xavier-uniform using the ReLU gain.
            RGCN: The stack of ``num_layers`` relational graph convolution layers.
        """
        self.ent_emb = nn.Embedding(self.args.num_ent, self.args.emb_dim)
        self.rel_emb = nn.Parameter(torch.Tensor(self.args.num_rel, self.args.emb_dim))
        nn.init.xavier_uniform_(self.rel_emb, gain=nn.init.calculate_gain('relu'))
        self.RGCN = nn.ModuleList(
            self.build_hidden_layer(idx) for idx in range(self.args.num_layers)
        )

    def forward(self, graph, ent, rel, norm, triples, mode='single'):
        """Score triples; used in the training and testing phases.

        Args:
            graph: The knowledge graph recorded in a ``dgl.graph``.
            ent: The entity ids of the graph nodes. Any shape holding one id per
                node (e.g. [num_nodes] or [num_nodes, 1]) is accepted.
            rel: The relation ids of the graph edges, passed to each
                ``RelGraphConv`` layer as its edge types.
            norm: The edge norm in the graph.
            triples: The triple ids as (h, r, t), shape: [batch_size, 3]. Head
                and tail ids index rows of the encoded node embeddings, i.e.
                positions in ``ent``.
            mode: Choose head-predict or tail-predict. Defaults to 'single'.

        Returns:
            score: The score of triples.
        """
        return self._score_triples(graph, ent, rel, norm, triples, mode)

    def get_score(self, batch, mode):
        """Score triples; used in the testing phase.

        Args:
            batch: A batch of data, a mapping with keys ``'positive_sample'``,
                ``'graph'``, ``'entity'``, ``'rela'`` and ``'norm'``.
            mode: Choose head-predict or tail-predict.

        Returns:
            score: The score of triples.
        """
        return self._score_triples(
            batch['graph'],
            batch['entity'],
            batch['rela'],
            batch['norm'],
            batch['positive_sample'],
            mode,
        )

    def _encode(self, graph, ent, rel, norm):
        """Look up entity embeddings and propagate them through every RGCN layer."""
        # reshape(-1) rather than squeeze(): squeeze() turns a single-entity
        # input of shape [1, 1] into a 0-d tensor, and the lookup would then
        # return [emb_dim] instead of [1, emb_dim].
        embedding = self.ent_emb(ent.reshape(-1))
        for layer in self.RGCN:
            embedding = layer(graph, embedding, rel, norm)
        return embedding

    def _score_triples(self, graph, ent, rel, norm, triples, mode):
        """Shared pipeline for forward/get_score: encode, cache the result, then score."""
        embedding = self._encode(graph, ent, rel, norm)
        self.Loss_emb = embedding
        head_emb, rela_emb, tail_emb = self.tri2emb(embedding, triples, mode)
        return DistMult.score_func(self, head_emb, rela_emb, tail_emb, mode)

    def tri2emb(self, embedding, triples, mode="single"):  # TODO: merge with XTransE
        """Get embedding of triples.

        This function gets the embeddings of head, relation, and tail
        respectively. Each embedding has three dimensions.

        Args:
            embedding (tensor): The encoded node embeddings, shape:
                [num_nodes, dim]; head/tail ids in ``triples`` index its rows.
            triples (tensor): This tensor saves triple ids, shape:
                [triples number, 3].
            mode (str, optional): This arg indicates that the negative entity
                will replace the head or tail entity. When it is 'single', it
                means that no entity will be replaced. Defaults to 'single'.

        Returns:
            head_emb: Head entity embedding, [bs, 1, dim], or [1, num_nodes, dim]
                in head-batch / head_predict mode.
            rela_emb: Relation embedding, [bs, 1, dim].
            tail_emb: Tail entity embedding, [bs, 1, dim], or [1, num_nodes, dim]
                in tail-batch / tail_predict mode.
        """
        rela_emb = self.rel_emb[triples[:, 1]].unsqueeze(1)  # [bs, 1, dim]
        if mode in ("head-batch", "head_predict"):
            # Every node is a head candidate for each triple.
            head_emb = embedding.unsqueeze(0)  # [1, num_nodes, dim]
        else:
            head_emb = embedding[triples[:, 0]].unsqueeze(1)  # [bs, 1, dim]
        if mode in ("tail-batch", "tail_predict"):
            # Every node is a tail candidate for each triple.
            tail_emb = embedding.unsqueeze(0)  # [1, num_nodes, dim]
        else:
            tail_emb = embedding[triples[:, 2]].unsqueeze(1)  # [bs, 1, dim]
        return head_emb, rela_emb, tail_emb

    def build_hidden_layer(self, idx):
        """Build one relational graph convolution layer.

        Args:
            idx: It's used to identify RGCN layers. Every layer except the last
                uses ReLU as its activation function; the last layer has no
                activation.

        Returns:
            The relation graph convolution layer.
        """
        act = F.relu if idx < self.args.num_layers - 1 else None
        # "bdd" = block-diagonal-decomposition weight regularizer. DGL expects
        # emb_dim to be divisible by num_bases for it (see RelGraphConv docs).
        return RelGraphConv(self.args.emb_dim, self.args.emb_dim, self.args.num_rel, "bdd",
                            num_bases=100, activation=act, self_loop=True, dropout=0.2)