Source code for neuralkg.loss.ComplEx_NNE_AER_Loss

import torch
import torch.nn as nn
from IPython import embed
from neuralkg.data import KGData


[docs]class ComplEx_NNE_AER_Loss(nn.Module): def __init__(self, args, model): super(ComplEx_NNE_AER_Loss, self).__init__() self.args = args self.model = model self.rule_p, self.rule_q = model.rule self.confidence = model.conf
[docs] def forward(self, pos_score, neg_score): logistic_neg = torch.log(1 + torch.exp(neg_score)).sum(dim=1) logistic_pos = torch.log(1 + torch.exp(-pos_score)).sum(dim=1) logistic_loss = logistic_neg + logistic_pos re_p, im_p = self.model.rel_emb(self.rule_p).chunk(2, dim=-1) re_q, im_q = self.model.rel_emb(self.rule_q).chunk(2, dim=-1) entail_loss_re = self.args.mu * torch.sum( self.confidence * (re_p - re_q).clamp(min=0).sum(dim=-1) ) entail_loss_im = self.args.mu * torch.sum( self.confidence * (im_p - im_q).pow(2).sum(dim=-1) ) entail_loss = entail_loss_re + entail_loss_im loss = logistic_loss + entail_loss # return loss if self.args.regularization != 0.0: # Use L2 regularization for ComplEx_NNE_AER regularization = self.args.regularization * ( self.model.ent_emb.weight.norm(p=2) ** 2 + self.model.rel_emb.weight.norm(p=2) ** 2 ) loss = loss + regularization loss = loss.mean() return loss