import torch
import torch.nn as nn
class RugE_Loss(nn.Module):
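    """Soft-rule injection loss in the style of RUGE (Guo et al., 2018,
    "Knowledge Graph Embedding with Iterative Guidance from Soft Rules").

    Combines binary cross-entropy over labeled positive/negative triples
    with a soft-label BCE term for unlabeled triples that appear as the
    conclusions of grounded rules.
    """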
    def __init__(self, args, model):
super(RugE_Loss, self).__init__()
self.args = args
self.model = model
    def forward(self, pos_score, neg_score, rule, confidence, triple_num, pos_len):
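        """Compute the combined loss for one batch.

        Args:
            pos_score: model scores of positive triples, shape (pos_len, 1).
            neg_score: model scores of negative triples,
                shape (pos_len, num_neg).
            rule: grounded rules; rule[0][i] (and rule[1][i] for 3-atom
                rules) are premise triples, the last slot holds the
                unlabeled conclusion triple.
            confidence: confidence value of each grounded rule.
            triple_num: number of atoms (2 or 3) in each grounded rule.
            pos_len: number of positive triples in the batch.
        """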
        entropy = nn.BCELoss()
        # NOTE: this implementation is still rough; the priority was getting
        # it running first.
        # Labels: positives are 1, negatives are 0.
        pos_label = torch.ones([pos_len, 1], dtype=torch.float, device=self.args.gpu)
        neg_label = torch.zeros([pos_len, self.args.num_neg], dtype=torch.float, device=self.args.gpu)
        sigmoid_neg = torch.sigmoid(neg_score)
        sigmoid_pos = torch.sigmoid(pos_score)
        positive_loss = entropy(sigmoid_pos, pos_label)
        negative_loss = entropy(sigmoid_neg, neg_label)
        pi_gradient = dict()
        # TODO: this cache should live outside forward(); rebuilding it on
        # every call defeats its purpose.
        sigmoid_value = dict()
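        # NOTE: the cache keys are tensors, which hash by object identity,
        # so lookups only hit when the very same tensor object recurs.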
        # Accumulate the gradient-like term pi for the unlabeled triple of
        # each grounded rule.
for i in range(len(rule[0])):
if triple_num[i] == 2:
p1_rule = rule[0][i]
unlabel_rule = rule[1][i]
if p1_rule not in sigmoid_value:
p1_rule_score = self.model(p1_rule.unsqueeze(0))
sigmoid_rule = torch.sigmoid(p1_rule_score)
sigmoid_value[p1_rule] = sigmoid_rule
else:
sigmoid_rule = sigmoid_value[p1_rule]
if unlabel_rule not in pi_gradient:
pi_gradient[unlabel_rule] = self.args.slackness_penalty * confidence[i] * sigmoid_rule
else:
pi_gradient[unlabel_rule] += self.args.slackness_penalty * confidence[i] * sigmoid_rule
elif triple_num[i] == 3:
p1_rule = rule[0][i]
p2_rule = rule[1][i]
unlabel_rule = rule[2][i]
if p1_rule not in sigmoid_value:
p1_rule_score = self.model(p1_rule.unsqueeze(0))
sigmoid_rule = torch.sigmoid(p1_rule_score)
sigmoid_value[p1_rule] = sigmoid_rule
else:
sigmoid_rule = sigmoid_value[p1_rule]
if p2_rule not in sigmoid_value:
p2_rule_score = self.model(p2_rule.unsqueeze(0))
sigmoid_rule2 = torch.sigmoid(p2_rule_score)
                    sigmoid_value[p2_rule] = sigmoid_rule2
else:
sigmoid_rule2 = sigmoid_value[p2_rule]
if unlabel_rule not in pi_gradient:
pi_gradient[unlabel_rule] = self.args.slackness_penalty * confidence[i] * sigmoid_rule * sigmoid_rule2
else:
pi_gradient[unlabel_rule] += self.args.slackness_penalty * confidence[i] * sigmoid_rule * sigmoid_rule2
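        # pi_gradient[x_u] now holds pi(x_u) = C * sum_r conf_r * prod of
        # premise sigmoids, with C the slackness penalty.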
unlabel_loss = 0.
unlabel_triples = []
gradient = []
        # For each distinct unlabeled triple in pi_gradient, compute its soft
        # label s from the formula below.
for unlabel_triple in pi_gradient.keys():
unlabel_triples.append(unlabel_triple.cpu().numpy())
gradient.append(pi_gradient[unlabel_triple].cpu().detach().numpy())
unlabel_triples = torch.tensor(unlabel_triples).to(self.args.gpu)
gradient = torch.tensor(gradient).to(self.args.gpu).view(-1, 1)
unlabel_triple_score = self.model(unlabel_triples)
unlabel_triple_score = torch.sigmoid(unlabel_triple_score)
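        # Soft label: s(x_u) = min(max(sigmoid(f(x_u)) + pi(x_u), 0), 1),
        # used as a fixed (detached) BCE target for the unlabeled triples.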
        # The sigmoid scores and soft labels of all unlabeled groundings are
        # batched into tensors so the loss is computed over them at once.
        unlabel_scores = torch.clamp(unlabel_triple_score + gradient, 0.0, 1.0).detach()
        unlabel_loss = entropy(unlabel_triple_score, unlabel_scores)
        loss = positive_loss + negative_loss + unlabel_loss
        if self.args.weight_decay != 0.0:
            # Use L2 regularization, as for ComplEx_NNE_AER.
            ent_emb_all = self.model.ent_emb(torch.arange(self.args.num_ent).to(self.args.gpu))
            rel_emb_all = self.model.rel_emb(torch.arange(self.args.num_rel).to(self.args.gpu))
            regularization = self.args.weight_decay * (
                ent_emb_all.norm(p=2) ** 2 + rel_emb_all.norm(p=2) ** 2
            )
            # Guarded inside the branch so `regularization` is only used when
            # it was computed (avoids a NameError when weight_decay == 0).
            loss += regularization
return loss
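
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module).
# `DummyModel`, the SimpleNamespace `args`, and the tensor shapes below are
# assumptions inferred from how forward() indexes its inputs; real training
# code would use the project's own model and argparse configuration.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    class DummyModel(nn.Module):
        """Scores a (batch, 3) LongTensor of (head, relation, tail) triples."""

        def __init__(self, num_ent, num_rel, dim=8):
            super().__init__()
            self.ent_emb = nn.Embedding(num_ent, dim)
            self.rel_emb = nn.Embedding(num_rel, dim)

        def forward(self, triples):
            h = self.ent_emb(triples[:, 0])
            r = self.rel_emb(triples[:, 1])
            t = self.ent_emb(triples[:, 2])
            return (h * r * t).sum(dim=-1, keepdim=True)

    args = SimpleNamespace(gpu="cpu", num_neg=2, num_ent=10, num_rel=4,
                           slackness_penalty=0.1, weight_decay=0.0)
    model = DummyModel(args.num_ent, args.num_rel)
    criterion = RugE_Loss(args, model)

    pos_len = 3
    pos_score = torch.randn(pos_len, 1)
    neg_score = torch.randn(pos_len, args.num_neg)
    # One grounded 3-atom rule: two premise triples and one unlabeled
    # conclusion triple, each encoded as (head, relation, tail) ids.
    rule = [torch.tensor([[0, 1, 2]]), torch.tensor([[2, 1, 3]]),
            torch.tensor([[0, 2, 3]])]
    confidence = torch.tensor([0.9])
    triple_num = [3]

    loss = criterion(pos_score, neg_score, rule, confidence, triple_num, pos_len)
    print(loss.item())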