# Source code for neuralkg.model.KGEModel.model

import torch.nn as nn
import torch

class Model(nn.Module):
    """Base class for knowledge-graph embedding (KGE) models.

    Subclasses implement :meth:`init_emb`, :meth:`score_func` and
    :meth:`forward`. They must also define ``self.ent_emb`` and
    ``self.rel_emb``, which :meth:`tri2emb` calls with id tensors.
    :meth:`tri2emb` also reads ``self.ent_emb.weight``, so these are
    presumably ``nn.Embedding`` modules (TODO confirm in subclasses).
    """

    def __init__(self, args):
        # ``args`` is accepted so subclasses share one constructor signature;
        # this base class itself does not use it.
        super().__init__()

    def init_emb(self):
        """Initialize the embedding parameters. Must be overridden."""
        raise NotImplementedError

    def score_func(self, head_emb, relation_emb, tail_emb):
        """Score triples from their embeddings. Must be overridden."""
        raise NotImplementedError

    def forward(self, triples, negs, mode):
        """Compute scores for a batch of triples. Must be overridden."""
        raise NotImplementedError

    def tri2emb(self, triples, negs=None, mode="single"):
        """Get the embeddings of the head, relation and tail of triples.

        Every returned embedding is 3-D, so the three tensors broadcast
        against each other inside ``score_func``.

        Args:
            triples (tensor): Triple ids with shape ``[bs, 3]``, laid out as
                ``(head, relation, tail)``.
            negs (tensor, optional): Ids of the entities that replace the head
                or tail. The inline shape notes assume ``[bs, num_neg]``; a 1-D
                ``[num_neg]`` tensor also broadcasts. When ``None`` in a
                head/tail mode, every entity is used as a candidate (the
                test/eval phase). Defaults to None.
            mode (str, optional): Which entity to replace. ``'single'``
                replaces nothing. ``'head-batch'`` / ``'head_predict'``
                replace the head. ``'tail-batch'`` / ``'tail_predict'``
                replace the tail. Defaults to ``'single'``.

        Returns:
            head_emb: Head entity embedding.
            relation_emb: Relation embedding, shape ``[bs, 1, dim]``.
            tail_emb: Tail entity embedding.

        Raises:
            ValueError: If ``mode`` is not one of the values above.
        """
        head_modes = ("head-batch", "head_predict")
        tail_modes = ("tail-batch", "tail_predict")
        # Fail loudly instead of hitting an UnboundLocalError further down.
        if mode != "single" and mode not in head_modes and mode not in tail_modes:
            raise ValueError(f"Unknown mode: {mode!r}")

        # The relation is never replaced, so it is the same in every mode.
        relation_emb = self.rel_emb(triples[:, 1]).unsqueeze(1)  # [bs, 1, dim]

        if mode in head_modes:
            if negs is None:
                # Evaluation: score against all entity embeddings at once.
                head_emb = self.ent_emb.weight.unsqueeze(0)  # [1, num_ent, dim]
            else:
                head_emb = self.ent_emb(negs)  # [bs, num_neg, dim]
        else:
            head_emb = self.ent_emb(triples[:, 0]).unsqueeze(1)  # [bs, 1, dim]

        if mode in tail_modes:
            if negs is None:
                # Evaluation: score against all entity embeddings at once.
                tail_emb = self.ent_emb.weight.unsqueeze(0)  # [1, num_ent, dim]
            else:
                tail_emb = self.ent_emb(negs)  # [bs, num_neg, dim]
        else:
            tail_emb = self.ent_emb(triples[:, 2]).unsqueeze(1)  # [bs, 1, dim]

        return head_emb, relation_emb, tail_emb