from logging import debug
import pytorch_lightning as pl
import torch
import sys
sys.path.append("../../src/neuralkg")   
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import os
import json
from collections import defaultdict as ddict
from IPython import embed
from .BaseLitModel import BaseLitModel
from neuralkg.eval_task import *
from IPython import embed
from functools import partial
class RGCNLitModel(BaseLitModel):
    def __init__(self, model, args):
        """Initialise the module by forwarding both arguments to ``BaseLitModel``.

        Args:
            model: Passed unchanged to ``BaseLitModel.__init__``; the other
                methods of this class call it through ``self.model``.
            args: Passed unchanged to ``BaseLitModel.__init__``; the
                evaluation steps read ``self.args.calc_hits``.
        """
        super().__init__(model, args)
        
[docs]    def forward(self, x):
        return self.model(x) 
    
[docs]    def training_step(self, batch, batch_idx):
        graph    = batch["graph"]
        triples  = batch["triples"]
        label    = batch["label"]
        entity   = batch['entity']
        relation = batch['relation']
        norm     = batch['norm']
        score = self.model(graph, entity, relation, norm, triples)
        loss  = self.loss(score,  label)
        self.log("Train|loss", loss,  on_step=False, on_epoch=True)
        return loss 
    
[docs]    def validation_step(self, batch, batch_idx):
        # pos_triple, tail_label, head_label = batch
        results = dict()
        ranks = link_predict(batch, self.model, prediction='all')
        results["count"] = torch.numel(ranks)
        results["mrr"] = torch.sum(1.0 / ranks).item()
        for k in self.args.calc_hits:
            results['hits@{}'.format(k)] = torch.numel(ranks[ranks <= k])
        return results 
    
[docs]    def validation_epoch_end(self, results) -> None:
        outputs = self.get_results(results, "Eval")
        # self.log("Eval|mrr", outputs["Eval|mrr"], on_epoch=True)
        self.log_dict(outputs, prog_bar=True, on_epoch=True) 
[docs]    def test_step(self, batch, batch_idx):
        results = dict()
        ranks = link_predict(batch, self.model, prediction='all')
        results["count"] = torch.numel(ranks)
        results["mrr"] = torch.sum(1.0 / ranks).item()
        for k in self.args.calc_hits:
            results['hits@{}'.format(k)] = torch.numel(ranks[ranks <= k])
        return results 
    
[docs]    def test_epoch_end(self, results) -> None:
        outputs = self.get_results(results, "Test")
        self.log_dict(outputs, prog_bar=True, on_epoch=True) 
    '''Set up the optimizer and lr_scheduler here.'''