Source code for neuralkg.model.KGEModel.SimplE

import torch.nn as nn
import torch
import torch.nn.functional as F
import math
from .model import Model


[docs]class SimplE(Model):
    """`SimplE Embedding for Link Prediction in Knowledge Graphs`_ (SimplE), which presents
    a simple enhancement of CP to allow the two embeddings of each entity to be learned
    dependently.

    Attributes:
        args: Model configuration parameters.
        ent_h_emb: Entity embedding for the head role, shape: [num_ent, emb_dim].
        ent_t_emb: Entity embedding for the tail role, shape: [num_ent, emb_dim].
        rel_emb: Relation embedding, shape: [num_rel, emb_dim].
        rel_inv_emb: Inverse relation embedding, shape: [num_rel, emb_dim].

    .. _SimplE Embedding for Link Prediction in Knowledge Graphs: http://papers.neurips.cc/paper/7682-simple-embedding-for-link-prediction-in-knowledge-graphs.pdf
    """

    def __init__(self, args):
        super(SimplE, self).__init__(args)
        self.args = args
        self.ent_emb = None
        self.rel_emb = None
        self.init_emb()

[docs]    def init_emb(self):
        """Initialize the entity and relation embeddings in the form of a uniform distribution."""
        self.ent_h_emb = nn.Embedding(self.args.num_ent, self.args.emb_dim)
        self.ent_t_emb = nn.Embedding(self.args.num_ent, self.args.emb_dim)
        self.rel_emb = nn.Embedding(self.args.num_rel, self.args.emb_dim)
        self.rel_inv_emb = nn.Embedding(self.args.num_rel, self.args.emb_dim)
        sqrt_size = 6.0 / math.sqrt(self.args.emb_dim)
        nn.init.uniform_(tensor=self.ent_h_emb.weight.data, a=-sqrt_size, b=sqrt_size)
        nn.init.uniform_(tensor=self.ent_t_emb.weight.data, a=-sqrt_size, b=sqrt_size)
        nn.init.uniform_(tensor=self.rel_emb.weight.data, a=-sqrt_size, b=sqrt_size)
        nn.init.uniform_(tensor=self.rel_inv_emb.weight.data, a=-sqrt_size, b=sqrt_size)

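    # Illustrative check (not in the original file): after init_emb, every
    # weight lies inside the uniform bound 6 / sqrt(emb_dim), e.g.
    #     assert model.ent_h_emb.weight.abs().max().item() <= 6.0 / math.sqrt(args.emb_dim)
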
[docs]    def score_func(self, hh_emb, rel_emb, tt_emb, ht_emb, rel_inv_emb, th_emb):
        """Calculate the score of triples.

        Args:
            hh_emb: The head entity embedding from the head ('h') table.
            rel_emb: The relation embedding.
            tt_emb: The tail entity embedding from the tail ('t') table.
            ht_emb: The tail entity embedding from the head ('h') table.
            rel_inv_emb: The inverse relation embedding.
            th_emb: The head entity embedding from the tail ('t') table.

        Returns:
            score: The score of triples.
        """
        scores1 = torch.sum(hh_emb * rel_emb * tt_emb, dim=-1)
        scores2 = torch.sum(ht_emb * rel_inv_emb * th_emb, dim=-1)
        # Average the two directions and clamp for numerical stability.
        return torch.clamp((scores1 + scores2) / 2, -20, 20)

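    # Reference note (added, following the paper's definition): for a triple
    # (h, r, t), score_func computes
    #     score(h, r, t) = ( <hh, rel, tt> + <ht, rel_inv, th> ) / 2
    # where <a, b, c> = sum_i a_i * b_i * c_i is the triple dot product; the
    # clamp to [-20, 20] bounds the averaged score before it enters the loss.
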
[docs]    def l2_loss(self):
        """Return the squared L2 norm of all four embedding tables, used for regularization."""
        return (self.ent_h_emb.weight.norm(p=2) ** 2
                + self.ent_t_emb.weight.norm(p=2) ** 2
                + self.rel_emb.weight.norm(p=2) ** 2
                + self.rel_inv_emb.weight.norm(p=2) ** 2)

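    # Usage sketch (assumes a hypothetical `regularization` weight on `args`,
    # not defined in this file): the squared norms are typically folded into
    # the training objective, e.g.
    #     loss = fit_loss + self.args.regularization * self.l2_loss()
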
[docs]    def forward(self, triples, negs=None, mode='single'):
        """The function used in the training phase.

        Args:
            triples: The triple ids, as (h, r, t), shape: [batch_size, 3].
            negs: Negative samples, defaults to None.
            mode: Choose head-predict or tail-predict, defaults to 'single'.

        Returns:
            score: The score of triples.
        """
        rel_emb, rel_inv_emb, hh_emb, th_emb, ht_emb, tt_emb = self.get_emb(triples, negs, mode)
        return self.score_func(hh_emb, rel_emb, tt_emb, ht_emb, rel_inv_emb, th_emb)

[docs]    def get_score(self, batch, mode):
        """The function used in the testing phase.

        Args:
            batch: A batch of data.
            mode: Choose head-predict or tail-predict.

        Returns:
            score: The score of triples.
        """
        triples = batch["positive_sample"]
        rel_emb, rel_inv_emb, hh_emb, th_emb, ht_emb, tt_emb = self.get_emb(triples, mode=mode)
        return self.score_func(hh_emb, rel_emb, tt_emb, ht_emb, rel_inv_emb, th_emb)

[docs]    def get_emb(self, triples, negs=None, mode='single'):
        """Look up the embeddings needed to score a batch of triples.

        Args:
            triples: The triple ids, as (h, r, t), shape: [batch_size, 3].
            negs: Negative sample ids, defaults to None.
            mode: 'single', head-predict, or tail-predict.

        Returns:
            rel_emb, rel_inv_emb, hh_emb, th_emb, ht_emb, tt_emb: The embeddings
            consumed by score_func, shaped for broadcasting.
        """
        if mode == 'single':
            rel_emb = self.rel_emb(triples[:, 1]).unsqueeze(1)          # [bs, 1, dim]
            rel_inv_emb = self.rel_inv_emb(triples[:, 1]).unsqueeze(1)  # [bs, 1, dim]
            hh_emb = self.ent_h_emb(triples[:, 0]).unsqueeze(1)         # [bs, 1, dim]
            th_emb = self.ent_t_emb(triples[:, 0]).unsqueeze(1)         # [bs, 1, dim]
            ht_emb = self.ent_h_emb(triples[:, 2]).unsqueeze(1)         # [bs, 1, dim]
            tt_emb = self.ent_t_emb(triples[:, 2]).unsqueeze(1)         # [bs, 1, dim]
        elif mode == 'head-batch' or mode == "head_predict":
            if negs is None:
                # negs is None means we are in evaluation, so score against
                # every entity embedding.
                hh_emb = self.ent_h_emb.weight.data.unsqueeze(0)  # [1, num_ent, dim]
                th_emb = self.ent_t_emb.weight.data.unsqueeze(0)  # [1, num_ent, dim]
            else:
                hh_emb = self.ent_h_emb(negs)  # [bs, num_neg, dim]
                th_emb = self.ent_t_emb(negs)  # [bs, num_neg, dim]
            rel_emb = self.rel_emb(triples[:, 1]).unsqueeze(1)          # [bs, 1, dim]
            rel_inv_emb = self.rel_inv_emb(triples[:, 1]).unsqueeze(1)  # [bs, 1, dim]
            ht_emb = self.ent_h_emb(triples[:, 2]).unsqueeze(1)         # [bs, 1, dim]
            tt_emb = self.ent_t_emb(triples[:, 2]).unsqueeze(1)         # [bs, 1, dim]
        elif mode == 'tail-batch' or mode == "tail_predict":
            if negs is None:
                # negs is None means we are in evaluation, so score against
                # every entity embedding.
                ht_emb = self.ent_h_emb.weight.data.unsqueeze(0)  # [1, num_ent, dim]
                tt_emb = self.ent_t_emb.weight.data.unsqueeze(0)  # [1, num_ent, dim]
            else:
                ht_emb = self.ent_h_emb(negs)  # [bs, num_neg, dim]
                tt_emb = self.ent_t_emb(negs)  # [bs, num_neg, dim]
            rel_emb = self.rel_emb(triples[:, 1]).unsqueeze(1)          # [bs, 1, dim]
            rel_inv_emb = self.rel_inv_emb(triples[:, 1]).unsqueeze(1)  # [bs, 1, dim]
            hh_emb = self.ent_h_emb(triples[:, 0]).unsqueeze(1)         # [bs, 1, dim]
            th_emb = self.ent_t_emb(triples[:, 0]).unsqueeze(1)         # [bs, 1, dim]
        return rel_emb, rel_inv_emb, hh_emb, th_emb, ht_emb, tt_emb
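
# Minimal usage sketch (not part of the original module). It assumes the
# `Model` base class is an nn.Module that only stores `args`, and that `args`
# exposes `num_ent`, `num_rel`, and `emb_dim`; shapes follow the comments in
# get_emb above.
if __name__ == "__main__":
    from argparse import Namespace

    args = Namespace(num_ent=100, num_rel=20, emb_dim=32)
    model = SimplE(args)

    # A batch of 4 random (h, r, t) triples.
    triples = torch.stack(
        [torch.randint(0, args.num_ent, (4,)),   # head entity ids
         torch.randint(0, args.num_rel, (4,)),   # relation ids
         torch.randint(0, args.num_ent, (4,))],  # tail entity ids
        dim=1)

    print(model(triples).shape)                       # 'single' mode -> [4, 1]
    print(model(triples, mode="tail_predict").shape)  # score all tails -> [4, num_ent]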