# Source code for neuralkg.loss.Margin_Loss

import torch
import torch.nn.functional as F
import torch.nn as nn
from IPython import embed

class Margin_Loss(nn.Module):
    """Log-sigmoid loss over positive and negative triple scores.

    Despite the class name, no margin hyperparameter is used. The base loss is::

        -(mean(logsigmoid(pos_score)) + mean(logsigmoid(-neg_score))) / 2

    For a few specific models, an embedding-norm penalty is added to it
    (see ``_regularization``).

    Args:
        args: Config object. ``forward`` reads ``args.model_name``, and for
            regularized models also ``args.regularization`` (penalty weight).
        model: The KG embedding model. For regularized models it must expose
            ``ent_emb`` and ``rel_emb``, each with a ``.weight`` tensor
            (presumably ``nn.Embedding`` tables; confirm against the model classes).
    """

    def __init__(self, args, model):
        super(Margin_Loss, self).__init__()
        self.args = args
        self.model = model

    def _regularization(self):
        """Return the embedding-norm penalty for ``args.model_name``, or None.

        * ``XTransE``: L1 norm of the entity table plus L1 norm of the
          relation table.
        * ``ComplEx`` / ``DistMult``: L3 regularization, i.e. the sum of the
          cubed L3 norms of both tables.
        * Any other model: no penalty (returns ``None``).

        Any returned penalty is scaled by ``args.regularization``.
        """
        model_name = self.args.model_name
        if model_name == "XTransE":
            penalty = (
                self.model.ent_emb.weight.norm(p=1)
                + self.model.rel_emb.weight.norm(p=1)
            )
        elif model_name in ("ComplEx", "DistMult"):
            # Use L3 regularization for ComplEx and DistMult.
            penalty = (
                self.model.ent_emb.weight.norm(p=3) ** 3
                + self.model.rel_emb.weight.norm(p=3) ** 3
            )
        else:
            return None
        return self.args.regularization * penalty

    def forward(self, pos_score, neg_score):
        """Compute the loss for one batch of scores.

        Args:
            pos_score: Scores of positive triples. The original comments
                suggest shape ``[bs, 1]``, but any shape works because the
                term is mean-reduced.
            neg_score: Scores of negative triples. The original comments
                suggest shape ``[bs]``; this term is likewise mean-reduced.

        Returns:
            Scalar loss tensor: the average of the positive and negative
            log-sigmoid losses, plus the model-specific regularization term
            when one applies.
        """
        # Minimizing this pushes positive scores up and negative scores down.
        positive_sample_loss = -F.logsigmoid(pos_score).mean()
        negative_sample_loss = -F.logsigmoid(-neg_score).mean()
        loss = (positive_sample_loss + negative_sample_loss) / 2

        # Previously `regularization` was only bound inside model-specific
        # branches, so other models could hit an unbound name and the XTransE
        # penalty could be computed but never applied.
        regularization = self._regularization()
        if regularization is not None:
            loss = loss + regularization
        return loss