import torch.nn as nn
import torch
import os
from .model import Model
from IPython import embed
from collections import defaultdict
import numpy as np
import pickle 
import copy
class IterE(Model):
    """`Iteratively Learning Embeddings and Rules for Knowledge Graph Reasoning. (WWW'19)`_ (IterE).
    Attributes:
        args: Model configuration parameters.
        epsilon: Calculate embedding_range.
        margin: Calculate embedding_range and loss.
        embedding_range: Uniform distribution range.
        ent_emb: Entity embedding, shape:[num_ent, emb_dim].
        rel_emb: Relation_embedding, shape:[num_rel, emb_dim].
    
    .. _Iteratively Learning Embeddings and Rules for Knowledge Graph Reasoning. (WWW'19): https://dl.acm.org/doi/10.1145/3308558.3313612
    """
    def __init__(self, args, train_sampler, test_sampler):
        """Set up embeddings, triple indexes, sparsity statistics and axiom pools.

        Args:
            args: Model configuration parameters. The fields read here are
                ``select_probability``, ``max_entialments``, ``axiom_types``,
                ``axiom_weight``, ``inject_triple_percent`` and ``num_ent``.
            train_sampler: Sampler exposing ``train_triples``,
                ``valid_triples``, ``test_triples`` and ``rel2id``.
            test_sampler: Not used by this constructor.
        """
        super(IterE, self).__init__(args)
        self.args = args
        self.ent_emb = None
        self.rel_emb = None
        self.init_emb()

        sampler = train_sampler
        self.train_sampler = sampler
        # Independent snapshot of the training triples as they are right now.
        self.train_triples_base = copy.deepcopy(sampler.train_triples)

        config = self.args
        self.select_probability = config.select_probability
        self.max_entialments = config.max_entialments
        self.axiom_types = config.axiom_types
        self.axiom_weight = config.axiom_weight
        self.inject_triple_percent = config.inject_triple_percent
        # Threshold compared against entity2sparsity when materializing axioms.
        self.sparsity = 0.995

        self.num_entity = config.num_ent
        self.relation2id = sampler.rel2id

        self.train_ids = sampler.train_triples
        self.valid_ids = sampler.valid_triples
        self.test_ids = sampler.test_triples

        # Starts as an empty array with four columns.
        self.train_ids_labels_inject = np.reshape([], (-1, 4))

        # Lookup tables over the triples (see _generate).
        print('# generate r_ht, hr_t')
        indexes = self._generate(self.train_ids, self.valid_ids, self.test_ids)
        (self.r_ht, self.hr_t, self.tr_h,
         self.hr_t_all, self.tr_h_all) = indexes

        # Per-entity frequency and sparsity dictionaries.
        print('# generate entity2frequency and entity2sparsity dict')
        frequency, sparsity = self._entity2frequency()
        self.entity2frequency = frequency
        self.entity2sparsity = sparsity

        print('# get_axiom')
        self.get_axiom()
        
    def _entity2frequency(self):
        ent2freq = {ent:0 for ent in range(self.num_entity)}
        ent2sparsity = {ent:-1 for ent in range(self.num_entity)}
        for h,r,t in self.train_ids:
            ent2freq[h] += 1
            ent2freq[t] += 1
        ent_freq_list = np.asarray([ent2freq[ent] for ent in range(self.num_entity)])
        ent_freq_list_sort = np.argsort(ent_freq_list)
        max_freq = max(list(ent2freq))
        min_freq = min(list(ent2freq))
        for ent, freq in ent2freq.items():
            sparsity = 1 - (freq-min_freq)/(max_freq - min_freq)
            ent2sparsity[ent] = sparsity
        return ent2freq, ent2sparsity
    
    def _generate(self, train, valid, test):
        r_ht = defaultdict(set)
        hr_t = defaultdict(set)
        tr_h = defaultdict(set)
        hr_t_all = defaultdict(list)
        tr_h_all = defaultdict(list)
        for (h,r,t) in train:
            r_ht[r].add((h,t))
            hr_t[(h,r)].add(t)
            tr_h[(t,r)].add(h)
            hr_t_all[(h,r)].append(t)
            tr_h_all[(t,r)].append(h)
        for (h,r,t) in test+valid:
            hr_t_all[(h,r)].append(t)
            tr_h_all[(t, r)].append(h)
        return r_ht, hr_t, tr_h, hr_t_all, tr_h_all
[docs]    def get_axiom(self, ):
        self.axiom_dir = os.path.join(self.args.data_path, 'axiom_pool')
        self.reflexive_dir, self.symmetric_dir, self.transitive_dir, self.inverse_dir, self.subproperty_dir, self.equivalent_dir, self.inferencechain1, self.inferencechain2, self.inferencechain3, self.inferencechain4 = map(lambda x: os.path.join(self.axiom_dir, x),
                                                             ['axiom_reflexive.txt',
                                                              'axiom_symmetric.txt',
                                                              'axiom_transitive.txt',
                                                              'axiom_inverse.txt',
                                                              'axiom_subProperty.txt',
                                                              'axiom_equivalent.txt',
                                                              'axiom_inferenceChain1.txt',
                                                              'axiom_inferenceChain2.txt',
                                                              'axiom_inferenceChain3.txt',
                                                              'axiom_inferenceChain4.txt'])
        # read and materialize axioms
        print('# self._read_axioms()')
        self._read_axioms()
        print('# self._read_axioms()')
        self._materialize_axioms()
        print('# self._read_axioms()')
        self._init_valid_axioms() 
        
        
        
        
    def _read_axioms(self):
        # for each axiom, the first id is the basic relation
        self.axiompool_reflexive = self._read_axiompool_file(self.reflexive_dir)
        self.axiompool_symmetric = self._read_axiompool_file(self.symmetric_dir)
        self.axiompool_transitive = self._read_axiompool_file(self.transitive_dir)
        self.axiompool_inverse = self._read_axiompool_file(self.inverse_dir)
        self.axiompool_equivalent = self._read_axiompool_file(self.equivalent_dir)
        self.axiompool_subproperty = self._read_axiompool_file(self.subproperty_dir)
        self.axiompool_inferencechain1 = self._read_axiompool_file(self.inferencechain1)
        self.axiompool_inferencechain2 = self._read_axiompool_file(self.inferencechain2)
        self.axiompool_inferencechain3 = self._read_axiompool_file(self.inferencechain3)
        self.axiompool_inferencechain4 = self._read_axiompool_file(self.inferencechain4)
        self.axiompool = [self.axiompool_reflexive, self.axiompool_symmetric, self.axiompool_transitive,
                          self.axiompool_inverse, self.axiompool_subproperty, self.axiompool_equivalent,
                          self.axiompool_inferencechain1,self.axiompool_inferencechain2,
                          self.axiompool_inferencechain3,self.axiompool_inferencechain4]
    
    def _read_axiompool_file(self, file):
        f = open(file, 'r')
        axioms = []
        for line in f.readlines():
            line_list = line.strip().split('\t')
            axiom_ids = list(map(lambda x: self.relation2id[x], line_list))
            #axiom_ids = self.relation2id[line_list]
            axioms.append(axiom_ids)
        # for the case reflexive pool is empty
        if len(axioms) == 0:
            np.reshape(axioms, [-1, 3])
        return axioms
    
    # for each axioms in axiom pool
    # generate a series of entailments for each axiom
    def _materialize_axioms(self, generate=True, dump=True, load=False):
        if generate:
            self.reflexive2entailment = defaultdict(list)
            self.symmetric2entailment = defaultdict(list)
            self.transitive2entailment = defaultdict(list)
            self.inverse2entailment = defaultdict(list)
            self.equivalent2entailment = defaultdict(list)
            self.subproperty2entailment = defaultdict(list)
            self.inferencechain12entailment = defaultdict(list)
            self.inferencechain22entailment = defaultdict(list)
            self.inferencechain32entailment = defaultdict(list)
            self.inferencechain42entailment = defaultdict(list)
            self.reflexive_entailments, self.reflexive_entailments_num = self._materialize_sparse(self.axiompool_reflexive, type='reflexive')
            self.symmetric_entailments, self.symmetric_entailments_num = self._materialize_sparse(self.axiompool_symmetric, type='symmetric')
            self.transitive_entailments, self.transitive_entailments_num = self._materialize_sparse(self.axiompool_transitive, type='transitive')
            self.inverse_entailments, self.inverse_entailments_num = self._materialize_sparse(self.axiompool_inverse, type='inverse')
            self.subproperty_entailments, self.subproperty_entailments_num = self._materialize_sparse(self.axiompool_subproperty, type='subproperty')
            self.equivalent_entailments, self.equivalent_entailments_num  = self._materialize_sparse(self.axiompool_equivalent, type='equivalent')
            self.inferencechain1_entailments, self.inferencechain1_entailments_num = self._materialize_sparse(self.axiompool_inferencechain1, type='inferencechain1')
            self.inferencechain2_entailments, self.inferencechain2_entailments_num = self._materialize_sparse(self.axiompool_inferencechain2, type='inferencechain2')
            self.inferencechain3_entailments, self.inferencechain3_entailments_num = self._materialize_sparse(self.axiompool_inferencechain3, type='inferencechain3')
            self.inferencechain4_entailments, self.inferencechain4_entailments_num = self._materialize_sparse(self.axiompool_inferencechain4, type='inferencechain4')
            print('reflexive entailments for sparse: ', self.reflexive_entailments_num)
            print('symmetric entailments for sparse: ', self.symmetric_entailments_num)
            print('transitive entailments for sparse: ', self.transitive_entailments_num)
            print('inverse entailments for sparse: ', self.inverse_entailments_num)
            print('subproperty entailments for sparse: ', self.subproperty_entailments_num)
            print('equivalent entailments for sparse: ', self.equivalent_entailments_num)
            print('inferencechain1 entailments for sparse: ', self.inferencechain1_entailments_num)
            print('inferencechain2 entailments for sparse: ', self.inferencechain2_entailments_num)
            print('inferencechain3 entailments for sparse: ', self.inferencechain3_entailments_num)
            print('inferencechain4 entailments for sparse: ', self.inferencechain4_entailments_num)
            print("finish generate axioms entailments for sparse")
        if dump:
            pickle.dump(self.reflexive_entailments, open(os.path.join(self.axiom_dir, 'reflexive_entailments'), 'wb'))
            pickle.dump(self.symmetric_entailments, open(os.path.join(self.axiom_dir, 'symmetric_entailments'), 'wb'))
            pickle.dump(self.transitive_entailments, open(os.path.join(self.axiom_dir, 'transitive_entailments'), 'wb'))
            pickle.dump(self.inverse_entailments, open(os.path.join(self.axiom_dir, 'inverse_entailments'), 'wb'))
            pickle.dump(self.subproperty_entailments, open(os.path.join(self.axiom_dir, 'subproperty_entailments'), 'wb'))
            #pickle.dump(self.inferencechain_entailments, open(os.path.join(self.axiom_dir, 'inferencechain_entailments'), 'wb'))
            pickle.dump(self.equivalent_entailments, open(os.path.join(self.axiom_dir, 'equivalent_entailments'), 'wb'))
            pickle.dump(self.inferencechain1_entailments,
                        open(os.path.join(self.axiom_dir, 'inferencechain1_entailments'), 'wb'))
            pickle.dump(self.inferencechain2_entailments,
                        open(os.path.join(self.axiom_dir, 'inferencechain2_entailments'), 'wb'))
            pickle.dump(self.inferencechain3_entailments,
                        open(os.path.join(self.axiom_dir, 'inferencechain3_entailments'), 'wb'))
            pickle.dump(self.inferencechain4_entailments,
                        open(os.path.join(self.axiom_dir, 'inferencechain4_entailments'), 'wb'))
            print("finish dump axioms entialments")
        if load:
            print("load refexive entailments...")
            self.reflexive_entailments = pickle.load(open(os.path.join(self.axiom_dir, 'reflexive_entailments'), 'rb'))
            print(self.reflexive_entailments)
            print('load symmetric entailments...')
            self.symmetric_entailments = pickle.load(open(os.path.join(self.axiom_dir, 'symmetric_entailments'), 'rb'))
            print("load transitive entialments... ")
            self.transitive_entailments = pickle.load(open(os.path.join(self.axiom_dir, 'transitive_entailments'), 'rb'))
            print("load inverse entailments...")
            self.inverse_entailments = pickle.load(open(os.path.join(self.axiom_dir, 'inverse_entailments'), 'rb'))
            print("load subproperty entailments...")
            self.subproperty_entailments = pickle.load(open(os.path.join(self.axiom_dir, 'subproperty_entailments'), 'rb'))
            #print("load inferencechain entailments...")
            #self.inferencechain_entailments = pickle.load(open(os.path.join(self.axiom_dir, 'inferencechain_entailments'), 'rb'))
            print("load equivalent entialments...")
            self.equivalent_entailments = pickle.load(open(os.path.join(self.axiom_dir, 'equivalent_entailments'), 'rb'))
            print("load inferencechain1 entailments...")
            self.inferencechain1_entailments = pickle.load(
                open(os.path.join(self.axiom_dir, 'inferencechain1_entailments'), 'rb'))
            print("load inferencechain2 entailments...")
            self.inferencechain2_entailments = pickle.load(
                open(os.path.join(self.axiom_dir, 'inferencechain2_entailments'), 'rb'))
            print("load inferencechain3 entailments...")
            self.inferencechain3_entailments = pickle.load(
                open(os.path.join(self.axiom_dir, 'inferencechain3_entailments'), 'rb'))
            print("load inferencechain4 entailments...")
            self.inferencechain4_entailments = pickle.load(
                open(os.path.join(self.axiom_dir, 'inferencechain4_entailments'), 'rb'))
            print("finish load axioms entailments")
    def _materialize_sparse(self, axioms, type=None, sparse = False):
        inference = []
        # axiom2entailment is a dict
        # with the all axioms in the axiom pool as keys
        # and all the entailments for each axiom as values
        axiom_list = axioms
        length = len(axioms)
        max_entailments = self.max_entialments
        num = 0
        if length == 0:
            if type == 'reflexive':
                np.reshape(inference, [-1, 3])
            elif type == 'symmetric' or type =='inverse' or  type =='equivalent' or type =='subproperty':
                np.reshape(inference, [-1, 6])
            elif type=='transitive' or type=='inferencechain':
                np.reshape(inference, [-1, 9])
            else:
                raise NotImplementedError
            return inference, num
        if type == 'reflexive':
            for axiom in axiom_list:
                axiom_key =tuple(axiom)
                r = axiom[0]
                inference_tmp = []
                for (h,t) in self.r_ht[r]:
                    # filter the axiom with too much entailments
                    if len(inference_tmp) > max_entailments:
                        inference_tmp = np.reshape([], [-1, 2])
                        break
                    if h != t and self.entity2sparsity[h]>self.sparsity:
                        num += 1
                        inference_tmp.append([h,r,h])
                for entailment in inference_tmp:
                    self.reflexive2entailment[axiom_key].append(entailment)
                inference.append(inference_tmp)
        if type == 'symmetric':
            #self.symmetric2entailment = defaultdict(list)
            for axiom in axiom_list:
                axiom_key = tuple(axiom)
                r = axiom[0]
                inference_tmp = []
                for (h,t) in self.r_ht[r]:
                    if len(inference_tmp) > max_entailments:
                        inference_tmp = np.reshape([], [-1, 2])
                        break
                    if (t,h) not in self.r_ht[r] and (self.entity2sparsity[h]>self.sparsity or self.entity2sparsity[t]>self.sparsity):
                        num += 1
                        inference_tmp.append([h,r,t,t,r,h])
                for entailment in inference_tmp:
                    self.symmetric2entailment[axiom_key].append(entailment)
                inference.append(inference_tmp)
        if type == 'transitive':
            #self.transitive2entailment = defaultdict(list)
            for axiom in axiom_list:
                axiom_key = tuple(axiom)
                r = axiom[0]
                inference_tmp = []
                for (h,t) in self.r_ht[r]:
                    if len(inference_tmp) > max_entailments:
                        inference_tmp = np.reshape([], [-1, 9])
                        break
                    # (t,r,e) exist but (h,r,e) not exist and e!=h
                    for e in self.hr_t[(t,r)]- self.hr_t[(h,r)]:
                        if e != h and (self.entity2sparsity[h]>self.sparsity or self.entity2sparsity[e]>self.sparsity):
                            num += 1
                            inference_tmp.append([h,r,t,t,r,e,h,r,e])
                for entailment in inference_tmp:
                    self.transitive2entailment[axiom_key].append(entailment)
                inference.append(inference_tmp)
        if type == 'inverse':
            for axiom in axiom_list:
                axiom_key = tuple(axiom)
                r1,r2 = axiom
                inference_tmp = []
                for (h,t) in self.r_ht[r1]:
                    if len(inference_tmp) > max_entailments:
                        inference_tmp = np.reshape([], [-1, 6])
                        break
                    if (t,h) not in self.r_ht[r2] and (self.entity2sparsity[h]>self.sparsity or self.entity2sparsity[t]>self.sparsity):
                        num += 1
                        inference_tmp.append([h,r1,t, t,r2,h])
                        #self.inverse2entailment[axiom_key].append([h,r1,t, t,r2,h])
                for entailment in inference_tmp:
                    self.inverse2entailment[axiom_key].append(entailment)
                inference.append(inference_tmp)
        if type == 'equivalent' or type =='subproperty':
            for axiom in axiom_list:
                axiom_key = tuple(axiom)
                r1,r2 = axiom
                inference_tmp = []
                for (h,t) in self.r_ht[r1]:
                    if len(inference_tmp) > max_entailments:
                        inference_tmp = np.reshape([], [-1, 6])
                        break
                    if (h,t) not in self.r_ht[r2] and (self.entity2sparsity[h]>self.sparsity or self.entity2sparsity[t]>self.sparsity):
                        num += 1
                        inference_tmp.append([h,r1,t, h,r2,t])
                for entailment in inference_tmp:
                    self.equivalent2entailment[axiom_key].append(entailment)
                    self.subproperty2entailment[axiom_key].append(entailment)
                inference.append(inference_tmp)
        if type == 'inferencechain1':
            self.inferencechain12entailment = defaultdict(list)
            i = 0
            for axiom in axiom_list:
                axiom_key = tuple(axiom)
                i += 1
                # print('%d/%d' % (i, length))
                r1, r2, r3 = axiom
                inference_tmp = []
                for (e, h) in self.r_ht[r2]:
                    if len(inference_tmp) > max_entailments:
                        inference_tmp = np.reshape([], [-1, 9])
                        break
                    for t in self.hr_t[(e, r3)]:
                        if (h, t) not in self.r_ht[r1] and (
                                        self.entity2sparsity[h] > self.sparsity or self.entity2sparsity[e] > self.sparsity):
                            num += 1
                            inference_tmp.append([e, r2, h, e, r3, t, h, r1, t])
                            #self.inferencechain12entailment[axiom_key].append([[e, r2, h, e, r3, t, h, r1, t]])
                for entailment in inference_tmp:
                    self.inferencechain12entailment[axiom_key].append(entailment)
                inference.append(inference_tmp)
        if type == 'inferencechain2':
            self.inferencechain22entailment = defaultdict(list)
            i = 0
            for axiom in axiom_list:
                axiom_key = tuple(axiom)
                i += 1
                # print('%d/%d' % (i, length))
                r1, r2, r3 = axiom
                inference_tmp = []
                for (e, h) in self.r_ht[r2]:
                    if len(inference_tmp) > max_entailments:
                        inference_tmp = np.reshape([], [-1, 9])
                        break
                    for t in self.tr_h[(e, r3)]:
                        if (h, t) not in self.r_ht[r1] and (
                                        self.entity2sparsity[h] > self.sparsity or self.entity2sparsity[e] > self.sparsity):
                            num += 1
                            inference_tmp.append([e, r2, h, t, r3, e, h, r1, t])
                            #self.inferencechain22entailment[axiom_key].append([[e, r2, h, t, r3, e, h, r1, t]])
                for entailment in inference_tmp:
                    self.inferencechain22entailment[axiom_key].append(entailment)
                inference.append(inference_tmp)
        if type == 'inferencechain3':
            self.inferencechain32entailment = defaultdict(list)
            i = 0
            for axiom in axiom_list:
                axiom_key = tuple(axiom)
                i += 1
                # print('%d/%d' % (i, length))
                r1, r2, r3 = axiom
                inference_tmp = []
                for (h, e) in self.r_ht[r2]:
                    if len(inference_tmp) > max_entailments:
                        inference_tmp = np.reshape([], [-1, 9])
                        break
                    for t in self.hr_t[(e, r3)]:
                        if (h, t) not in self.r_ht[r1] and (
                                        self.entity2sparsity[h] > self.sparsity or self.entity2sparsity[e] > self.sparsity):
                            num += 1
                            inference_tmp.append([h, r2, e, e, r3, t, h, r1, t])
                for entailment in inference_tmp:
                    self.inferencechain32entailment[axiom_key].append(entailment)
                inference.append(inference_tmp)
        if type == 'inferencechain4':
            self.inferencechain42entailment = defaultdict(list)
            i = 0
            for axiom in axiom_list:
                axiom_key = tuple(axiom)
                i += 1
                # print('%d/%d' % (i, length))
                r1, r2, r3 = axiom
                inference_tmp = []
                for (h, e) in self.r_ht[r2]:
                    if len(inference_tmp) > max_entailments:
                        inference_tmp = np.reshape([], [-1, 9])
                        break
                    for t in self.tr_h[(e, r3)]:
                        if (h, t) not in self.r_ht[r1] and (
                                        self.entity2sparsity[h] > self.sparsity or self.entity2sparsity[e] > self.sparsity):
                            num += 1
                            inference_tmp.append([h, r2, e, t, r3, e, h, r1, t])
                for entailment in inference_tmp:
                    self.inferencechain42entailment[axiom_key].append(entailment)
                inference.append(inference_tmp)
        return inference, num
    def _materialize(self, axioms, type=None, sparse=False):
        inference = []
        # axiom2entailment is a dict
        # with the all axioms in the axiom pool as keys
        # and all the entailments for each axiom as values
        axiom_list = axioms
        # print('axiom_list', axiom_list)
        length = len(axioms)
        max_entailments = 5000
        num = 0
        if length == 0:
            if type == 'reflexive':
                np.reshape(inference, [-1, 3])
            elif type == 'symmetric' or type == 'inverse' or type == 'equivalent' or type == 'subproperty':
                np.reshape(inference, [-1, 6])
            elif type == 'transitive' or type == 'inferencechain':
                np.reshape(inference, [-1, 9])
            else:
                raise NotImplementedError
            return inference, num
        if type == 'reflexive':
            for axiom in axiom_list:
                axiom_key = tuple(axiom)
                r = axiom[0]
                inference_tmp = []
                for (h, t) in self.r_ht[r]:
                    if len(inference_tmp) > max_entailments:
                        inference_tmp = np.reshape([], [-1, 2])
                        break
                    if h != t: #and self.entity2sparsity[h] > self.sparsity:
                        num += 1
                        inference_tmp.append([h, r, h])
                for entailment in inference_tmp:
                    self.reflexive2entailment[axiom_key].append(entailment)
                inference.append(inference_tmp)
        if type == 'symmetric':
            for axiom in axiom_list:
                axiom_key = tuple(axiom)
                r = axiom[0]
                inference_tmp = []
                for (h, t) in self.r_ht[r]:
                    if len(inference_tmp) > max_entailments:
                        inference_tmp = np.reshape([], [-1, 2])
                        break
                    if (t, h) not in self.r_ht[r]: #and (self.entity2sparsity[h] > self.sparsity or self.entity2sparsity[t] > self.sparsity):
                        num += 1
                        inference_tmp.append([h, r, t, t, r, h])
                for entailment in inference_tmp:
                    self.symmetric2entailment[axiom_key].append(entailment)
                inference.append(inference_tmp)
        if type == 'transitive':
            for axiom in axiom_list:
                axiom_key = tuple(axiom)
                r = axiom[0]
                inference_tmp = []
                for (h, t) in self.r_ht[r]:
                    if len(inference_tmp) > max_entailments:
                        inference_tmp = np.reshape([], [-1, 9])
                        break
                    # (t,r,e) exist but (h,r,e) not exist and e!=h
                    for e in self.hr_t[(t, r)] - self.hr_t[(h, r)]:
                        if e != h: #and (self.entity2sparsity[h] > self.sparsity or self.entity2sparsity[e] > self.sparsity):
                            num += 1
                            inference_tmp.append([h, r, t, t, r, e, h, r, e])
                for entailment in inference_tmp:
                    self.transitive2entailment[axiom_key].append(entailment)
                inference.append(inference_tmp)
        if type == 'inverse':
            # self.inverse2entailment = defaultdict(list)
            for axiom in axiom_list:
                axiom_key = tuple(axiom)
                r1, r2 = axiom
                inference_tmp = []
                for (h, t) in self.r_ht[r1]:
                    if len(inference_tmp) > max_entailments:
                        inference_tmp = np.reshape([], [-1, 6])
                        break
                    if (t, h) not in self.r_ht[r2]: #and (self.entity2sparsity[h] > self.sparsity or self.entity2sparsity[t] > self.sparsity):
                        num += 1
                        inference_tmp.append([h, r1, t, t, r2, h])
                for entailment in inference_tmp:
                    self.inverse2entailment[axiom_key].append(entailment)
                inference.append(inference_tmp)
        if type == 'equivalent' or type == 'subproperty':
            for axiom in axiom_list:
                axiom_key = tuple(axiom)
                r1, r2 = axiom
                inference_tmp = []
                for (h, t) in self.r_ht[r1]:
                    if len(inference_tmp) > max_entailments:
                        inference_tmp = np.reshape([], [-1, 6])
                        break
                    if (h, t) not in self.r_ht[r2]: #and (self.entity2sparsity[h] > self.sparsity or self.entity2sparsity[t] > self.sparsity):
                        num += 1
                        inference_tmp.append([h, r1, t, h, r2, t])
                for entailment in inference_tmp:
                    self.equivalent2entailment[axiom_key].append(entailment)
                    self.subproperty2entailment[axiom_key].append(entailment)
                inference.append(inference_tmp)
        if type == 'inferencechain1':
            self.inferencechain12entailment = defaultdict(list)
            i = 0
            for axiom in axiom_list:
                axiom_key = tuple(axiom)
                i += 1
                # print('%d/%d' % (i, length))
                r1, r2, r3 = axiom
                inference_tmp = []
                for (e, h) in self.r_ht[r2]:
                    if len(inference_tmp) > max_entailments:
                        inference_tmp = np.reshape([], [-1, 9])
                        break
                    for t in self.hr_t[(e, r3)]:
                        if (h, t) not in self.r_ht[r1]:
                            num += 1
                            inference_tmp.append([e, r2, h, e, r3, t, h, r1, t])
                for entailment in inference_tmp:
                    self.inferencechain12entailment[axiom_key].append(entailment)
                inference.append(inference_tmp)
        if type == 'inferencechain2':
            self.inferencechain22entailment = defaultdict(list)
            i = 0
            for axiom in axiom_list:
                axiom_key = tuple(axiom)
                i += 1
                # print('%d/%d' % (i, length))
                r1, r2, r3 = axiom
                inference_tmp = []
                for (e, h) in self.r_ht[r2]:
                    if len(inference_tmp) > max_entailments:
                        inference_tmp = np.reshape([], [-1, 9])
                        break
                    for t in self.tr_h[(e, r3)]:
                        if (h, t) not in self.r_ht[r1]:
                            num += 1
                            inference_tmp.append([e, r2, h, t, r3, e, h, r1, t])
                for entailment in inference_tmp:
                    self.inferencechain22entailment[axiom_key].append(entailment)
                inference.append(inference_tmp)
        if type == 'inferencechain3':
            self.inferencechain32entailment = defaultdict(list)
            i = 0
            for axiom in axiom_list:
                axiom_key = tuple(axiom)
                i += 1
                # print('%d/%d' % (i, length))
                r1, r2, r3 = axiom
                inference_tmp = []
                for (h, e) in self.r_ht[r2]:
                    if len(inference_tmp) > max_entailments:
                        inference_tmp = np.reshape([], [-1, 9])
                        break
                    for t in self.hr_t[(e, r3)]:
                        if (h, t) not in self.r_ht[r1]:
                            num += 1
                            inference_tmp.append([h, r2, e, e, r3, t, h, r1, t])
                for entailment in inference_tmp:
                    self.inferencechain32entailment[axiom_key].append(entailment)
                inference.append(inference_tmp)
        if type == 'inferencechain4':
            self.inferencechain42entailment = defaultdict(list)
            i = 0
            for axiom in axiom_list:
                axiom_key = tuple(axiom)
                i += 1
                r1, r2, r3 = axiom
                inference_tmp = []
                for (h, e) in self.r_ht[r2]:
                    if len(inference_tmp) > max_entailments:
                        inference_tmp = np.reshape([], [-1, 9])
                        break
                    for t in self.tr_h[(e, r3)]:
                        if (h, t) not in self.r_ht[r1]:
                            num += 1
                            inference_tmp.append([h, r2, e, t, r3, e, h, r1, t])
                for entailment in inference_tmp:
                    self.inferencechain42entailment[axiom_key].append(entailment)
                inference.append(inference_tmp)
        return inference, num
    
    def _init_valid_axioms(self):
        # init valid axioms
        self.valid_reflexive, self.valid_symmetric, self.valid_transitive,\
        
self.valid_inverse, self.valid_subproperty, self.valid_equivalent,\
        
self.valid_inferencechain1, self.valid_inferencechain2, \
        
self.valid_inferencechain3, self.valid_inferencechain4 = [[] for x in range(self.axiom_types)]
        # init valid axiom entailments
        self.valid_reflexive2entailment, self.valid_symmetric2entailment, self.valid_transitive2entailment, \
        
self.valid_inverse2entailment, self.valid_subproperty2entailment, self.valid_equivalent2entailment, \
        
self.valid_inferencechain12entailment, self.valid_inferencechain22entailment, \
        
self.valid_inferencechain32entailment, self.valid_inferencechain42entailment = [[] for x in range(self.axiom_types)]
        # init valid axiom entailments probability
        self.valid_reflexive_p, self.valid_symmetric_p, self.valid_transitive_p, \
        
self.valid_inverse_p, self.valid_subproperty_p, self.valid_equivalent_p, \
        
self.valid_inferencechain1_p, self.valid_inferencechain2_p,\
        
self.valid_inferencechain3_p, self.valid_inferencechain4_p= [[] for x in range(self.axiom_types)]
        # init valid axiom batchsize
        self.reflexive_batchsize = 1
        self.symmetric_batchsize = 1
        self.transitive_batchsize = 1
        self.inverse_batchsize = 1
        self.subproperty_batchsize = 1
        self.equivalent_batchsize = 1
        #self.inferencechain_batchsize = 1
        self.inferencechain1_batchsize = 1
        self.inferencechain2_batchsize = 1
        self.inferencechain3_batchsize = 1
        self.inferencechain4_batchsize = 1
    # add the new triples from axioms to training triple
[docs]    def update_train_triples(self, epoch=0, update_per = 10):
        """add the new triples from axioms to training triple
        Args:
            epoch (int, optional): epoch in training process. Defaults to 0.
            update_per (int, optional): Defaults to 10.
        Returns:
            updated_train_data: training triple after adding the new triples from axioms
        """
        reflexive_triples, symmetric_triples, transitive_triples, inverse_triples,\
            
equivalent_triples, subproperty_triples, inferencechain1_triples, \
            
inferencechain2_triples, inferencechain3_triples, inferencechain4_triples = [ np.reshape(np.asarray([]), [-1, 3]) for i in range(self.axiom_types)]
        reflexive_p, symmetric_p, transitive_p, inverse_p, \
            
equivalent_p, subproperty_p, inferencechain1_p, \
            
inferencechain2_p, inferencechain3_p, inferencechain4_p = [np.reshape(np.asarray([]), [-1, 1]) for i in
                                                                           range(self.axiom_types)]
        updated_train_data=None
        if epoch >= 20:
        #if True:
            print("len(self.valid_reflexive2entailment):", len(self.valid_reflexive2entailment))
            print("len(self.valid_symmetric2entailment):", len(self.valid_symmetric2entailment))
            print("len(self.valid_transitive2entailment)", len(self.valid_transitive2entailment))
            print("len(self.valid_inverse2entailment)", len(self.valid_inverse2entailment))
            print("len(self.valid_equivalent2entailment)", len(self.valid_equivalent2entailment))
            print("len(self.valid_subproperty2entailment)", len(self.valid_subproperty2entailment))
            valid_reflexive2entailment, valid_symmetric2entailment, valid_transitive2entailment,\
            
valid_inverse2entailment, valid_equivalent2entailment, valid_subproperty2entailment, \
            
valid_inferencechain12entailment, valid_inferencechain22entailment,\
            
valid_inferencechain32entailment, valid_inferencechain42entailment = [[] for i in range(10)]
            if len(self.valid_reflexive2entailment)>0:
                valid_reflexive2entailment = np.reshape(np.asarray(self.valid_reflexive2entailment), [-1, 3])
                reflexive_triples = np.asarray(valid_reflexive2entailment)[:, -3:]
                reflexive_p = np.reshape(np.asarray(self.valid_reflexive_p),[-1,1])
            if len(self.valid_symmetric2entailment) > 0:
                valid_symmetric2entailment = np.reshape(np.asarray(self.valid_symmetric2entailment), [-1, 6])
                symmetric_triples = np.asarray(valid_symmetric2entailment)[:, -3:]
                symmetric_p = np.reshape(np.asarray(self.valid_symmetric_p),[-1,1])
            if len(self.valid_transitive2entailment) > 0:
                valid_transitive2entailment = np.reshape(np.asarray(self.valid_transitive2entailment), [-1, 9])
                transitive_triples = np.asarray(valid_transitive2entailment)[:, -3:]
                transitive_p = np.reshape(np.asarray(self.valid_transitive_p), [-1, 1])
            if len(self.valid_inverse2entailment) > 0:
                valid_inverse2entailment = np.reshape(np.asarray(self.valid_inverse2entailment), [-1, 6])
                inverse_triples = np.asarray(valid_inverse2entailment)[:, -3:]
                inverse_p = np.reshape(np.asarray(self.valid_inverse_p), [-1, 1])
            if len(self.valid_equivalent2entailment) > 0:
                valid_equivalent2entailment = np.reshape(np.asarray(self.valid_equivalent2entailment), [-1, 6])
                equivalent_triples = np.asarray(valid_equivalent2entailment)[:, -3:]
                equivalent_p = np.reshape(np.asarray(self.valid_equivalent_p), [-1, 1])
            if len(self.valid_subproperty2entailment) > 0:
                valid_subproperty2entailment = np.reshape(np.asarray(self.valid_subproperty2entailment), [-1, 6])
                subproperty_triples = np.asarray(valid_subproperty2entailment)[:, -3:]
                subproperty_p = np.reshape(np.asarray(self.valid_subproperty_p),[-1,1])
            if len(self.valid_inferencechain12entailment) > 0:
                valid_inferencechain12entailment = np.reshape(np.asarray(self.valid_inferencechain12entailment), [-1, 9])
                inferencechain1_triples = np.asarray(valid_inferencechain12entailment)[:, -3:]
                inferencechain1_p = np.reshape(np.asarray(self.valid_inferencechain1_p), [-1, 1])
            if len(self.valid_inferencechain22entailment) > 0:
                valid_inferencechain22entailment = np.reshape(np.asarray(self.valid_inferencechain22entailment), [-1, 9])
                inferencechain2_triples = np.asarray(valid_inferencechain22entailment)[:, -3:]
                inferencechain2_p = np.reshape(np.asarray(self.valid_inferencechain2_p), [-1, 1])
            if len(self.valid_inferencechain32entailment) > 0:
                valid_inferencechain32entailment = np.reshape(np.asarray(self.valid_inferencechain32entailment), [-1, 9])
                inferencechain3_triples = np.asarray(valid_inferencechain32entailment)[:, -3:]
                inferencechain3_p = np.reshape(np.asarray(self.valid_inferencechain3_p), [-1, 1])
            if len(self.valid_inferencechain42entailment) > 0:
                valid_inferencechain42entailment = np.reshape(np.asarray(self.valid_inferencechain42entailment), [-1, 9])
                inferencechain4_triples = np.asarray(valid_inferencechain42entailment)[:, -3:]
                inferencechain4_p = np.reshape(np.asarray(self.valid_inferencechain4_p), [-1, 1])
            # pickle.dump(self.reflexive_entailments, open(os.path.join(self.axiom_dir, 'reflexive_entailments'), 'wb'))
            # store all the injected triples
            entailment_all = (valid_reflexive2entailment, valid_symmetric2entailment, valid_transitive2entailment,
                     valid_inverse2entailment, valid_equivalent2entailment, valid_subproperty2entailment,
                     valid_inferencechain12entailment,valid_inferencechain22entailment,
                              valid_inferencechain32entailment,valid_inferencechain42entailment)
            pickle.dump(entailment_all, open(os.path.join(self.axiom_dir, 'valid_entailments.pickle'), 'wb'))
        train_inject_triples = np.concatenate([reflexive_triples, symmetric_triples, transitive_triples, inverse_triples,
                                                equivalent_triples, subproperty_triples, inferencechain1_triples,
                                               inferencechain2_triples,inferencechain3_triples,inferencechain4_triples],
                                                axis=0)
        train_inject_triples_p = np.concatenate([reflexive_p,symmetric_p, transitive_p, inverse_p,
                                               equivalent_p, subproperty_p, inferencechain1_p,
                                                 inferencechain2_p,inferencechain3_p,inferencechain4_p],
                                              axis=0)
        self.train_inject_triples = train_inject_triples
        inject_labels = np.reshape(np.ones(len(train_inject_triples)), [-1, 1]) * self.axiom_weight * train_inject_triples_p
        train_inject_ids_labels = np.concatenate([train_inject_triples, inject_labels],
                                                axis=1)
        self.train_ids_labels_inject = train_inject_triples#train_inject_ids_labels
        print('num reflexive triples', len(reflexive_triples))
        print('num symmetric triples', len(symmetric_triples))
        print('num transitive triples', len(transitive_triples))
        print('num inverse triples', len(inverse_triples))
        print('num equivalent triples', len(equivalent_triples))
        print('num subproperty triples', len(subproperty_triples))
        print('num inferencechain1 triples', len(inferencechain1_triples))
        print('num inferencechain2 triples', len(inferencechain2_triples))
        print('num inferencechain3 triples', len(inferencechain3_triples))
        print('num inferencechain4 triples', len(inferencechain4_triples))
        #print(self.train_ids_labels_inject)
        updated_train_data=self.generate_new_train_triples()
        return updated_train_data 
[docs]    def split_embedding(self, embedding):
        """split embedding
        Args:
            embedding: embeddings need to be splited, shape:[None, dim].
        Returns:
            probability: The similrity between two matrices.
        """
        # embedding: [None, dim]
        assert self.args.emb_dim % 4 == 0
        num_scalar = self.args.emb_dim // 2
        num_block = self.args.emb_dim // 4
        embedding_scalar = embedding[:, 0:num_scalar]
        embedding_x = embedding[:, num_scalar:-num_block]
        embedding_y = embedding[:, -num_block:]
        return embedding_scalar, embedding_x, embedding_y 
    
    # calculate the similarity between two matrices
    # head: [?, dim]
    # tail: [?, dim] or [1,dim]
[docs]    def sim(self, head=None, tail=None, arity=None):
        """calculate the similrity between two matrices
        Args:
            head: embeddings of head, shape:[batch_size, dim].
            tail: embeddings of tail, shape:[batch_size, dim] or [1, dim].
            arity: 1,2 or 3
        Returns:
            probability: The similrity between two matrices.
        """
        if arity == 1:
            A_scalar, A_x, A_y = self.split_embedding(head)
        elif arity == 2:
            M1_scalar, M1_x, M1_y = self.split_embedding(head[0])
            M2_scalar, M2_x, M2_y = self.split_embedding(head[1])
            A_scalar= M1_scalar * M2_scalar
            A_x = M1_x*M2_x - M1_y*M2_y
            A_y = M1_x*M2_y + M1_y*M2_x
        elif arity==3:
            M1_scalar, M1_x, M1_y = self.split_embedding(head[0])
            M2_scalar, M2_x, M2_y = self.split_embedding(head[1])
            M3_scalar, M3_x, M3_y = self.split_embedding(head[2])
            M1M2_scalar = M1_scalar * M2_scalar
            M1M2_x = M1_x * M2_x - M1_y * M2_y
            M1M2_y = M1_x * M2_y + M1_y * M2_x
            A_scalar = M1M2_scalar * M3_scalar
            A_x = M1M2_x * M3_x - M1M2_y * M3_y
            A_y = M1M2_x * M3_y + M1M2_y * M3_x
        else:
            raise NotImplemented
        B_scala, B_x, B_y = self.split_embedding(tail)
        similarity = torch.cat([(A_scalar - B_scala)**2, (A_x - B_x)**2, (A_x - B_x)**2, (A_y - B_y)**2, (A_y - B_y)**2 ], dim=1)
        similarity = torch.sqrt(torch.sum(similarity, dim=1))
        #recale the probability
        probability = (torch.max(similarity)-similarity)/(torch.max(similarity)-torch.min(similarity))
        return probability 
    # generate a probability for each axiom in axiom pool
[docs]    def run_axiom_probability(self):
        """this function is used to generate a probality for each axiom in axiom pool
        """
        self.identity = torch.cat((torch.ones(int(self.args.emb_dim-self.args.emb_dim/4)),torch.zeros(int(self.args.emb_dim/4))),0).unsqueeze(0).cuda()
        
        if len(self.axiompool_reflexive) != 0: 
            index = torch.LongTensor(self.axiompool_reflexive).cuda()
            reflexive_embed = self.rel_emb(index)
            reflexive_prob = self.sim(head=reflexive_embed[:, 0, :], tail=self.identity, arity=1)
        else: 
            reflexive_prob = []
        if len(self.axiompool_symmetric) != 0: 
            index = torch.LongTensor(self.axiompool_symmetric).cuda()
            symmetric_embed = self.rel_emb(index)
            symmetric_prob = self.sim(head=[symmetric_embed[:, 0, :], symmetric_embed[:, 0, :]], tail=self.identity, arity=2)
            #symmetric_prob = sess.run(self.symmetric_probability, {self.symmetric_pool: self.axiompool_symmetric})
        else: 
            symmetric_prob = []
        if len(self.axiompool_transitive) != 0: 
            index = torch.LongTensor(self.axiompool_transitive).cuda()
            transitive_embed = self.rel_emb(index)
            transitive_prob = self.sim(head=[transitive_embed[:, 0, :], transitive_embed[:, 0, :]], tail=transitive_embed[:, 0, :], arity=2)
            #transitive_prob = sess.run(self.transitive_probability, {self.transitive_pool: self.axiompool_transitive})
        else: 
            transitive_prob = []
        if len(self.axiompool_inverse) != 0: 
            index = torch.LongTensor(self.axiompool_inverse).cuda()
            #inverse_prob = sess.run(self.inverse_probability, {self.inverse_pool: self.axiompool_inverse})
            inverse_embed = self.rel_emb(index)
            inverse_probability1 = self.sim(head=[inverse_embed[:, 0,:],inverse_embed[:, 1,:]], tail = self.identity, arity=2)
            inverse_probability2 = self.sim(head=[inverse_embed[:,1,:],inverse_embed[:, 0,:]], tail=self.identity, arity=2)
            inverse_prob = (inverse_probability1 + inverse_probability2)/2
        else: 
            inverse_prob = []
        if len(self.axiompool_subproperty) != 0: 
            index = torch.LongTensor(self.axiompool_subproperty).cuda()
            #subproperty_prob = sess.run(self.subproperty_probability, {self.subproperty_pool: self.axiompool_subproperty})
            subproperty_embed = self.rel_emb(index)
            subproperty_prob = self.sim(head=subproperty_embed[:, 0,:], tail=subproperty_embed[:, 1, :], arity=1)
        else: 
            subproperty_prob = []
        if len(self.axiompool_equivalent) != 0: 
            index = torch.LongTensor(self.axiompool_equivalent).cuda()
            #equivalent_prob = sess.run(self.equivalent_probability, {self.equivalent_pool: self.axiompool_equivalent})
            equivalent_embed = self.rel_emb(index)
            equivalent_prob = self.sim(head=equivalent_embed[:, 0,:], tail=equivalent_embed[:, 1,:], arity=1)
        else: 
            equivalent_prob = []
        if len(self.axiompool_inferencechain1) != 0:
            index = torch.LongTensor(self.axiompool_inferencechain1).cuda()
            inferencechain_embed = self.rel_emb(index) 
            inferencechain1_prob = self.sim(head=[inferencechain_embed[:, 1, :], inferencechain_embed[:, 0, :]], tail=inferencechain_embed[:, 2, :], arity=2)
        else:
            inferencechain1_prob = []
        if len(self.axiompool_inferencechain2) != 0:
            index = torch.LongTensor(self.axiompool_inferencechain2).cuda()
            inferencechain_embed = self.rel_emb(index)
            inferencechain2_prob = self.sim(head=[inferencechain_embed[:, 2, :], inferencechain_embed[:, 1, :], inferencechain_embed[:, 0, :]], tail=self.identity, arity=3)
        else:
            inferencechain2_prob = []
        if len(self.axiompool_inferencechain3) != 0:
            index = torch.LongTensor(self.axiompool_inferencechain3).cuda()
            inferencechain_embed = self.rel_emb(index)
            inferencechain3_prob = self.sim(head=[inferencechain_embed[:, 1, :], inferencechain_embed[:, 2, :]], tail=inferencechain_embed[:, 0, :], arity=2)
        else:
            inferencechain3_prob = []
        if len(self.axiompool_inferencechain4) != 0:
            index = torch.LongTensor(self.axiompool_inferencechain4).cuda()
            inferencechain_embed = self.rel_emb(index)
            inferencechain4_prob = self.sim(head=[inferencechain_embed[:, 0, :], inferencechain_embed[:, 2, :]],tail=inferencechain_embed[:, 1, :], arity=2)
        else:
            inferencechain4_prob = []
        output = [reflexive_prob, symmetric_prob, transitive_prob, inverse_prob,
                  subproperty_prob,equivalent_prob,inferencechain1_prob, inferencechain2_prob,
                  inferencechain3_prob, inferencechain4_prob]
        return output 
[docs]    def update_valid_axioms(self, input):
        """this function is used to select high probability axioms as valid axioms and record their scores
        """
        # 
        # 
        valid_axioms = [self._select_high_probability(list(prob), axiom) for prob,axiom in zip(input, self.axiompool)]
        self.valid_reflexive, self.valid_symmetric, self.valid_transitive, \
        
self.valid_inverse, self.valid_subproperty, self.valid_equivalent, \
        
self.valid_inferencechain1, self.valid_inferencechain2, \
        
self.valid_inferencechain3, self.valid_inferencechain4 = valid_axioms
        # update the batchsize of axioms and entailments
        self._reset_valid_axiom_entailment() 
    def _select_high_probability(self, prob, axiom):
        # select the high probability axioms and recore their probabilities
        valid_axiom = [[axiom[prob.index(p)],[p]] for p in prob if p>self.select_probability]
        return valid_axiom
    def _reset_valid_axiom_entailment(self):
        self.infered_hr_t = defaultdict(set)
        self.infered_tr_h = defaultdict(set)
        self.valid_reflexive2entailment, self.valid_reflexive_p = \
            
self._valid_axiom2entailment(self.valid_reflexive, self.reflexive2entailment)
        self.valid_symmetric2entailment, self.valid_symmetric_p = \
            
self._valid_axiom2entailment(self.valid_symmetric, self.symmetric2entailment)
        self.valid_transitive2entailment, self.valid_transitive_p = \
            
self._valid_axiom2entailment(self.valid_transitive, self.transitive2entailment)
        self.valid_inverse2entailment, self.valid_inverse_p = \
            
self._valid_axiom2entailment(self.valid_inverse, self.inverse2entailment)
        self.valid_subproperty2entailment, self.valid_subproperty_p = \
            
self._valid_axiom2entailment(self.valid_subproperty, self.subproperty2entailment)
        self.valid_equivalent2entailment, self.valid_equivalent_p = \
            
self._valid_axiom2entailment(self.valid_equivalent, self.equivalent2entailment)
        self.valid_inferencechain12entailment, self.valid_inferencechain1_p = \
            
self._valid_axiom2entailment(self.valid_inferencechain1, self.inferencechain12entailment)
        self.valid_inferencechain22entailment, self.valid_inferencechain2_p = \
            
self._valid_axiom2entailment(self.valid_inferencechain2, self.inferencechain22entailment)
        self.valid_inferencechain32entailment, self.valid_inferencechain3_p = \
            
self._valid_axiom2entailment(self.valid_inferencechain3, self.inferencechain32entailment)
        self.valid_inferencechain42entailment, self.valid_inferencechain4_p = \
            
self._valid_axiom2entailment(self.valid_inferencechain4, self.inferencechain42entailment)
    def _valid_axiom2entailment(self, valid_axiom, axiom2entailment):
        valid_axiom2entailment = []
        valid_axiom_p = []
        for axiom_p in valid_axiom:
            axiom = tuple(axiom_p[0])
            p = axiom_p[1]
            for entailment in axiom2entailment[axiom]:
                valid_axiom2entailment.append(entailment)
                valid_axiom_p.append(p)
                h,r,t = entailment[-3:]
                self.infered_hr_t[(h,r)].add(t)
                self.infered_tr_h[(t,r)].add(h)
        return valid_axiom2entailment, valid_axiom_p
    # update new train triples:
[docs]    def generate_new_train_triples(self):
        """The function is to updata new train triples and used after each training epoch end
        Returns:
            self.train_sampler.train_triples: The new training dataset (triples).
        """
        self.train_sampler.train_triples = copy.deepcopy(self.train_triples_base)
        print('generate_new_train_triples...')
        #origin_triples = train_sampler.train_triples
        inject_triples = self.train_ids_labels_inject
        inject_num = int(self.inject_triple_percent*len(self.train_sampler.train_triples))
        if len(inject_triples)> inject_num and inject_num >0:
            np.random.shuffle(inject_triples)
            inject_triples = inject_triples[:inject_num]
        #train_triples = np.concatenate([origin_triples, inject_triples], axis=0)
        print('当前train_sampler.train_triples数目',len(self.train_sampler.train_triples))
        
        for h,r,t in inject_triples:       
            self.train_sampler.train_triples.append((int(h),int(r),int(t)))
        print('添加后train_sampler.train_triples数目',len(self.train_sampler.train_triples))
        return self.train_sampler.train_triples 
    
[docs]    def get_rule(self, rel2id):
        """Get rule for rule_base KGE models, such as ComplEx_NNE model.
        Get rule and confidence from _cons.txt file.
        Update:
            (rule_p, rule_q): Rule.
            confidence: The confidence of rule.
        """
        rule_p, rule_q, confidence = [], [], []
        with open(os.path.join(self.args.data_path, '_cons.txt')) as file:
            lines = file.readlines()
            for line in lines:
                rule_str, trust = line.strip().split()
                body, head = rule_str.split(',')
                if '-' in body:
                    rule_p.append(rel2id[body[1:]])
                    rule_q.append(rel2id[head])
                else:
                    rule_p.append(rel2id[body])
                    rule_q.append(rel2id[head])
                confidence.append(float(trust))
        rule_p = torch.tensor(rule_p).cuda()
        rule_q = torch.tensor(rule_q).cuda()
        confidence = torch.tensor(confidence).cuda()
        return (rule_p, rule_q), confidence 
    """def init_emb(self):
        self.epsilon = 2.0
        self.margin = nn.Parameter(
            torch.Tensor([self.args.margin]), 
            requires_grad=False
        )
        self.embedding_range = nn.Parameter(
            torch.Tensor([(self.margin.item() + self.epsilon) / self.args.emb_dim]), 
            requires_grad=False
        )
        
        self.ent_emb = nn.Embedding(self.args.num_ent, self.args.emb_dim * 2)
        self.rel_emb = nn.Embedding(self.args.num_rel, self.args.emb_dim * 2)
        nn.init.uniform_(tensor=self.ent_emb.weight.data, a=-self.embedding_range.item(), b=self.embedding_range.item())
        nn.init.uniform_(tensor=self.rel_emb.weight.data, a=-self.embedding_range.item(), b=self.embedding_range.item())"""
    
[docs]    def init_emb(self):
        """Initialize the entity and relation embeddings in the form of a uniform distribution.
        """
        self.epsilon = 2.0
        self.margin = nn.Parameter(
            torch.Tensor([self.args.margin]), 
            requires_grad=False
        )
        self.embedding_range = nn.Parameter(
            torch.Tensor([(self.margin.item() + self.epsilon) / self.args.emb_dim]), 
            requires_grad=False
        )
        self.ent_emb = nn.Embedding(self.args.num_ent, self.args.emb_dim)
        self.rel_emb = nn.Embedding(self.args.num_rel, self.args.emb_dim)
        nn.init.uniform_(tensor=self.ent_emb.weight.data, a=-self.embedding_range.item(), b=self.embedding_range.item())
        nn.init.uniform_(tensor=self.rel_emb.weight.data, a=-self.embedding_range.item(), b=self.embedding_range.item()) 
        
    """def score_func(self, head_emb, relation_emb, tail_emb, mode):
        re_head, im_head = torch.chunk(head_emb, 2, dim=-1)
        re_relation, im_relation = torch.chunk(relation_emb, 2, dim=-1)
        re_tail, im_tail = torch.chunk(tail_emb, 2, dim=-1)
        return torch.sum(
            re_head * re_tail * re_relation
            + im_head * im_tail * re_relation
            + re_head * im_tail * im_relation
            - im_head * re_tail * im_relation,
            -1
        )"""
        
[docs]    def score_func(self, head_emb, relation_emb, tail_emb, mode):
        """Calculating the score of triples.
        
        The formula for calculating the score is DistMult.
        Args:
            head_emb: The head entity embedding.
            relation_emb: The relation embedding.
            tail_emb: The tail entity embedding.
            mode: Choose head-predict or tail-predict.
        Returns:
            score: The score of triples.
        """
        if mode == 'head-batch':
            score = head_emb * (relation_emb * tail_emb)
        else:
            score = (head_emb * relation_emb) * tail_emb
        score = score.sum(dim = -1)
        return score 
[docs]    def forward(self, triples, negs=None, mode='single'):
        """The functions used in the training phase
        Args:
            triples: The triples ids, as (h, r, t), shape:[batch_size, 3].
            negs: Negative samples, defaults to None.
            mode: Choose head-predict or tail-predict, Defaults to 'single'.
        Returns:
            score: The score of triples.
        """
        head_emb, relation_emb, tail_emb = self.tri2emb(triples, negs, mode)
        score = self.score_func(head_emb, relation_emb, tail_emb, mode)
        return score 
    
[docs]    def get_score(self, batch, mode):
        """The functions used in the testing phase
        Args:
            batch: A batch of data.
            mode: Choose head-predict or tail-predict.
        Returns:
            score: The score of triples.
        """
        triples = batch["positive_sample"]
        head_emb, relation_emb, tail_emb = self.tri2emb(triples, mode=mode)
        score = self.score_func(head_emb, relation_emb, tail_emb, mode)
        return score