import torch
import torch.nn as nn


class ComplEx_NNE_AER_Loss(nn.Module):
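    """Loss for ComplEx-NNE+AER training.

    Combines a logistic (softplus) loss over positive and negative triple
    scores with an approximate-entailment penalty on relation embeddings,
    weighted by each rule's confidence and the hyperparameter ``args.mu``.
    """
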
    def __init__(self, args, model):
        super(ComplEx_NNE_AER_Loss, self).__init__()
        self.args = args
        self.model = model
        # Entailment rules: ids of premise (p) and conclusion (q) relations,
        # together with the confidence score of each rule.
        self.rule_p, self.rule_q = model.rule
        self.confidence = model.conf
    def forward(self, pos_score, neg_score):
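        """Compute the loss from positive and negative triple scores.

        Both score tensors are expected to be 2-D (batch first), so the
        per-triple terms can be summed over dim=1.
        """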
        # Logistic (softplus) loss: log(1 + exp(score)) for negatives and
        # log(1 + exp(-score)) for positives, summed over each batch row.
        logistic_neg = torch.log(1 + torch.exp(neg_score)).sum(dim=1)
        logistic_pos = torch.log(1 + torch.exp(-pos_score)).sum(dim=1)
        logistic_loss = logistic_neg + logistic_pos
        # Real/imaginary parts of the embeddings of premise (p) and
        # conclusion (q) relations appearing in the entailment rules.
        re_p, im_p = self.model.rel_emb(self.rule_p).chunk(2, dim=-1)
        re_q, im_q = self.model.rel_emb(self.rule_q).chunk(2, dim=-1)
        # Approximate-entailment penalty, weighted by rule confidence and args.mu:
        # a hinge on the real parts (penalize re_p > re_q) and a squared
        # difference on the imaginary parts.
        entail_loss_re = self.args.mu * torch.sum(
            self.confidence * (re_p - re_q).clamp(min=0).sum(dim=-1)
        )
        entail_loss_im = self.args.mu * torch.sum(
            self.confidence * (im_p - im_q).pow(2).sum(dim=-1)
        )
        entail_loss = entail_loss_re + entail_loss_im
        loss = logistic_loss + entail_loss
        if self.args.regularization != 0.0:
            # Use L2 regularization for ComplEx_NNE_AER
            regularization = self.args.regularization * (
                self.model.ent_emb.weight.norm(p=2) ** 2
                + self.model.rel_emb.weight.norm(p=2) ** 2
            )
            loss = loss + regularization
        loss = loss.mean()
        return loss
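

# Minimal usage sketch (illustrative only, not part of the library). It assumes
# a model object exposing `ent_emb` / `rel_emb` embeddings, `rule` (premise and
# conclusion relation id tensors) and `conf` (rule confidences), plus an `args`
# namespace with `mu` and `regularization`, which are exactly the attributes the
# loss accesses above. The toy model and values below are hypothetical.
if __name__ == "__main__":
    from types import SimpleNamespace

    class _ToyModel(nn.Module):
        def __init__(self, num_ent=5, num_rel=4, dim=8):
            super().__init__()
            # ComplEx-style embeddings: first half real part, second half imaginary
            self.ent_emb = nn.Embedding(num_ent, 2 * dim)
            self.rel_emb = nn.Embedding(num_rel, 2 * dim)
            # Two toy rules r0 -> r1 and r2 -> r3 with confidences 0.9 and 0.8
            self.rule = (torch.tensor([0, 2]), torch.tensor([1, 3]))
            self.conf = torch.tensor([0.9, 0.8])

    args = SimpleNamespace(mu=0.1, regularization=1e-5)
    loss_fn = ComplEx_NNE_AER_Loss(args, _ToyModel())
    pos_score = torch.randn(3, 1)    # scores of 3 positive triples
    neg_score = torch.randn(3, 10)   # scores of 10 negatives per positive
    print(loss_fn(pos_score, neg_score).item())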