import dgl
import torch
from .BaseLitModel import BaseLitModel
from neuralkg_ind.eval_task import *
[docs]class SEGNNLitModel(BaseLitModel):
def __init__(self, model, args, src_list, dst_list, rel_list):
    """Wrap ``model`` and keep the knowledge-graph edge lists.

    Args:
        model: network being trained; handed to ``BaseLitModel``.
        args: run configuration; ``get_kg`` reads ``args.num_ent`` from it.
        src_list: source node id of every edge.
        dst_list: destination node id of every edge.
        rel_list: relation id of every edge (stored as the ``rel_id`` edge feature).
    """
    super().__init__(model, args)
    # The raw edge lists are kept so training_step can rebuild a fresh graph each step.
    self.src_list, self.dst_list, self.rel_list = src_list, dst_list, rel_list
    # Graph used by validation_step / test_step for link prediction.
    self.kg = self.get_kg(self.src_list, self.dst_list, self.rel_list)
def forward(self, x, *args, **kwargs):
    """Run the wrapped model.

    ``x`` is passed first; any extra positional or keyword arguments are
    forwarded too. This lets the same entry point serve models that take
    several inputs; ``training_step`` calls ``self.model(head, rel, kg)``.

    Returns:
        Whatever ``self.model`` returns.
    """
    return self.model(x, *args, **kwargs)
def training_step(self, batch):
    """Run one manually optimized training step on ``batch``.

    The model is scored against a freshly built knowledge graph. When
    ``args.rm_rate > 0``, the edges in ``rm_edges`` are removed from that
    graph first.

    Args:
        batch: unpacked as ``((head, rel, _), label, rm_edges)``; the third
            element of the query tuple is ignored.

    Returns:
        The loss computed by ``self.loss(score, label)``.
    """
    optimizer = self.optimizers()
    optimizer.zero_grad()

    (head, rel, _), label, rm_edges = batch

    # Build a new graph every step: ``remove_edges`` edits a DGL graph in
    # place, so a fresh copy keeps edge removal from accumulating across steps.
    kg = self.get_kg(self.src_list, self.dst_list, self.rel_list)
    # Follow the module's device instead of hard-coding "cuda:0", so training
    # also works on CPU and on GPUs other than the first one.
    kg = kg.to(self.device)
    if self.args.rm_rate > 0:
        kg.remove_edges(rm_edges)

    score = self.model(head, rel, kg)
    loss = self.loss(score, label)
    self.manual_backward(loss)
    optimizer.step()

    # ``lr_schedulers()`` returns None when no scheduler is configured and a
    # list when several are; step whatever is present, once per batch.
    schedulers = self.lr_schedulers()
    if schedulers is None:
        schedulers = []
    elif not isinstance(schedulers, (list, tuple)):
        schedulers = [schedulers]
    for scheduler in schedulers:
        scheduler.step()
    return loss
def validation_step(self, batch, batch_idx):
    """Compute per-batch tail-prediction ranking totals on the validation split.

    Args:
        batch: forwarded unchanged to ``link_predict_SEGNN``.
        batch_idx: index of the batch (unused).

    Returns:
        dict with
        - ``count``: the number of ranks,
        - ``mrr``: the *sum* of reciprocal ranks,
        - ``hits@k``: for each ``k`` in ``args.calc_hits``, the number of ranks <= k.

        These are batch totals, not averages; presumably ``get_results``
        normalises them by the summed ``count`` (confirm in BaseLitModel).
    """
    ranks = link_predict_SEGNN(batch, self.kg, self.model, prediction='tail')
    metrics = {
        "count": torch.numel(ranks),
        "mrr": torch.sum(1.0 / ranks).item(),
    }
    for k in self.args.calc_hits:
        metrics[f"hits@{k}"] = torch.numel(ranks[ranks <= k])
    return metrics
def validation_epoch_end(self, results) -> None:
    """Aggregate the validation outputs via ``get_results`` and log them.

    Args:
        results: the per-batch outputs of ``validation_step`` for the epoch.

    NOTE(review): Lightning 2.0 removed the ``*_epoch_end`` hooks; confirm the
    pinned Lightning version still calls this method.
    """
    self.log_dict(self.get_results(results, "Eval"), prog_bar=True, on_epoch=True)
def test_step(self, batch, batch_idx):
    """Compute per-batch tail-prediction ranking totals on the test split.

    Args:
        batch: forwarded unchanged to ``link_predict_SEGNN``.
        batch_idx: index of the batch (unused).

    Returns:
        dict with
        - ``count``: the number of ranks,
        - ``mrr``: the summed reciprocal ranks,
        - ``hits@k``: for each ``k`` in ``args.calc_hits``, the number of ranks <= k.
    """
    ranks = link_predict_SEGNN(batch, self.kg, self.model, prediction='tail')
    n_ranked = torch.numel(ranks)
    reciprocal_sum = torch.sum(1.0 / ranks).item()
    hits = {
        'hits@{}'.format(k): torch.numel(ranks[ranks <= k])
        for k in self.args.calc_hits
    }
    return {"count": n_ranked, "mrr": reciprocal_sum, **hits}
[docs] def test_epoch_end(self, results) -> None:
outputs = self.get_results(results, "Test")
self.log_dict(outputs, prog_bar=True, on_epoch=True)
def get_kg(self, src_list, dst_list, rel_list):
    """Build a DGL graph over all entities from parallel edge lists.

    Args:
        src_list: source node id of each edge.
        dst_list: destination node id of each edge.
        rel_list: per-edge relation ids, stored as edge feature ``'rel_id'``.

    Returns:
        A ``dgl.graph`` whose node count is fixed to ``self.args.num_ent``,
        so entities without any edge still get a node.
    """
    graph = dgl.graph((src_list, dst_list), num_nodes=self.args.num_ent)
    graph.edata['rel_id'] = rel_list
    return graph
# Set up the optimizer and lr_scheduler here.