import math

import torch
import torch.nn as nn

from .model import Model
class SimplE(Model):
    """`SimplE Embedding for Link Prediction in Knowledge Graphs`_ (SimplE), a simple enhancement of CP that allows the two embeddings of each entity to be learned dependently.

    Attributes:
        args: Model configuration parameters.
        ent_h_emb: Head entity embedding, shape: [num_ent, emb_dim].
        ent_t_emb: Tail entity embedding, shape: [num_ent, emb_dim].
        rel_emb: Relation embedding, shape: [num_rel, emb_dim].
        rel_inv_emb: Inverse relation embedding, shape: [num_rel, emb_dim].

    .. _SimplE Embedding for Link Prediction in Knowledge Graphs: http://papers.neurips.cc/paper/7682-simple-embedding-for-link-prediction-in-knowledge-graphs.pdf
    """
    def __init__(self, args):
        super(SimplE, self).__init__(args)
        self.args = args
        self.ent_emb = None
        self.rel_emb = None
        self.init_emb()
    def init_emb(self):
        """Initialize the entity and relation embeddings from a uniform distribution."""
        self.ent_h_emb = nn.Embedding(self.args.num_ent, self.args.emb_dim)
        self.ent_t_emb = nn.Embedding(self.args.num_ent, self.args.emb_dim)
        self.rel_emb = nn.Embedding(self.args.num_rel, self.args.emb_dim)
        self.rel_inv_emb = nn.Embedding(self.args.num_rel, self.args.emb_dim)
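        # Embeddings are drawn uniformly from [-6/sqrt(d), 6/sqrt(d)], the scale
        # used by the original SimplE implementation. Note this is the paper's
        # own choice, not the Glorot uniform bound sqrt(6 / (fan_in + fan_out)).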
        sqrt_size = 6.0 / math.sqrt(self.args.emb_dim)
        nn.init.uniform_(tensor=self.ent_h_emb.weight.data, a=-sqrt_size, b=sqrt_size)
        nn.init.uniform_(tensor=self.ent_t_emb.weight.data, a=-sqrt_size, b=sqrt_size)
        nn.init.uniform_(tensor=self.rel_emb.weight.data, a=-sqrt_size, b=sqrt_size)
        nn.init.uniform_(tensor=self.rel_inv_emb.weight.data, a=-sqrt_size, b=sqrt_size)
    def score_func(self, hh_emb, rel_emb, tt_emb, ht_emb, rel_inv_emb, th_emb):
        """Calculate the score of triples.

        Args:
            hh_emb: The head entity embedding on embedding h.
            rel_emb: The relation embedding.
            tt_emb: The tail entity embedding on embedding t.
            ht_emb: The tail entity embedding on embedding h.
            rel_inv_emb: The inverse relation embedding.
            th_emb: The head entity embedding on embedding t.

        Returns:
            score: The score of triples.
        """
        scores1 = torch.sum(hh_emb * rel_emb * tt_emb, dim=-1)
        scores2 = torch.sum(ht_emb * rel_inv_emb * th_emb, dim=-1)
        return torch.clamp((scores1 + scores2) / 2, -20, 20)
    def l2_loss(self):
        """Return the squared L2 norm of all embedding matrices, used as a regularization term."""
        return (self.ent_h_emb.weight.norm(p=2) ** 2
                + self.ent_t_emb.weight.norm(p=2) ** 2
                + self.rel_emb.weight.norm(p=2) ** 2
                + self.rel_inv_emb.weight.norm(p=2) ** 2)
    def forward(self, triples, negs=None, mode='single'):
        """The function used in the training phase.

        Args:
            triples: The triple ids, as (h, r, t), shape: [batch_size, 3].
            negs: Negative samples, defaults to None.
            mode: Choose head-predict or tail-predict, defaults to 'single'.

        Returns:
            score: The score of triples.
        """
        rel_emb, rel_inv_emb, hh_emb, th_emb, ht_emb, tt_emb = self.get_emb(triples, negs, mode)
        return self.score_func(hh_emb, rel_emb, tt_emb, ht_emb, rel_inv_emb, th_emb)
    def get_score(self, batch, mode):
        """The function used in the testing phase.

        Args:
            batch: A batch of data.
            mode: Choose head-predict or tail-predict.

        Returns:
            score: The score of triples.
        """
        triples = batch["positive_sample"]
        rel_emb, rel_inv_emb, hh_emb, th_emb, ht_emb, tt_emb = self.get_emb(triples, mode=mode)
        return self.score_func(hh_emb, rel_emb, tt_emb, ht_emb, rel_inv_emb, th_emb)
    def get_emb(self, triples, negs=None, mode='single'):
        """Look up the head/tail entity embeddings and the relation (and inverse
        relation) embeddings for the given triples, expanding over the negative
        samples or over all entities depending on mode.
        """
        if mode == 'single':
            rel_emb = self.rel_emb(triples[:, 1]).unsqueeze(1)  # [bs, 1, dim]
            rel_inv_emb = self.rel_inv_emb(triples[:, 1]).unsqueeze(1)
            hh_emb = self.ent_h_emb(triples[:, 0]).unsqueeze(1)  # [bs, 1, dim]
            th_emb = self.ent_t_emb(triples[:, 0]).unsqueeze(1)  # [bs, 1, dim]
            ht_emb = self.ent_h_emb(triples[:, 2]).unsqueeze(1)  # [bs, 1, dim]
            tt_emb = self.ent_t_emb(triples[:, 2]).unsqueeze(1)  # [bs, 1, dim]
        elif mode == 'head-batch' or mode == "head_predict":
            if negs is None:  # In evaluation, score against all entity embeddings.
                hh_emb = self.ent_h_emb.weight.data.unsqueeze(0)  # [1, num_ent, dim]
                th_emb = self.ent_t_emb.weight.data.unsqueeze(0)  # [1, num_ent, dim]
            else:
                hh_emb = self.ent_h_emb(negs)  # [bs, num_neg, dim]
                th_emb = self.ent_t_emb(negs)  # [bs, num_neg, dim]
            rel_emb = self.rel_emb(triples[:, 1]).unsqueeze(1)  # [bs, 1, dim]
            rel_inv_emb = self.rel_inv_emb(triples[:, 1]).unsqueeze(1)  # [bs, 1, dim]
            ht_emb = self.ent_h_emb(triples[:, 2]).unsqueeze(1)  # [bs, 1, dim]
            tt_emb = self.ent_t_emb(triples[:, 2]).unsqueeze(1)  # [bs, 1, dim]
        elif mode == 'tail-batch' or mode == "tail_predict":
            if negs is None:  # In evaluation, score against all entity embeddings.
                ht_emb = self.ent_h_emb.weight.data.unsqueeze(0)  # [1, num_ent, dim]
                tt_emb = self.ent_t_emb.weight.data.unsqueeze(0)  # [1, num_ent, dim]
            else:
                ht_emb = self.ent_h_emb(negs)  # [bs, num_neg, dim]
                tt_emb = self.ent_t_emb(negs)  # [bs, num_neg, dim]
            rel_emb = self.rel_emb(triples[:, 1]).unsqueeze(1)  # [bs, 1, dim]
            rel_inv_emb = self.rel_inv_emb(triples[:, 1]).unsqueeze(1)  # [bs, 1, dim]
            hh_emb = self.ent_h_emb(triples[:, 0]).unsqueeze(1)  # [bs, 1, dim]
            th_emb = self.ent_t_emb(triples[:, 0]).unsqueeze(1)  # [bs, 1, dim]
        return rel_emb, rel_inv_emb, hh_emb, th_emb, ht_emb, tt_emb
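

if __name__ == "__main__":
    # Minimal smoke test -- a sketch, not part of the original module. It
    # assumes the Model base class needs nothing beyond `args`, and that
    # `args` carries num_ent, num_rel, and emb_dim; adapt if the base class
    # requires more configuration.
    from argparse import Namespace

    args = Namespace(num_ent=100, num_rel=20, emb_dim=64)
    model = SimplE(args)
    triples = torch.stack(
        [
            torch.randint(0, args.num_ent, (8,)),  # head entity ids
            torch.randint(0, args.num_rel, (8,)),  # relation ids
            torch.randint(0, args.num_ent, (8,)),  # tail entity ids
        ],
        dim=1,
    )  # shape [8, 3]
    print(model(triples).shape)  # mode='single' -> torch.Size([8, 1])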