Source code for neuralkg.model.RuleModel.IterE

import torch.nn as nn
import torch
import os
from .model import Model
from IPython import embed
from collections import defaultdict
import numpy as np
import pickle 
import copy

[docs]class IterE(Model): """`Iteratively Learning Embeddings and Rules for Knowledge Graph Reasoning. (WWW'19)`_ (IterE). Attributes: args: Model configuration parameters. epsilon: Caculate embedding_range. margin: Caculate embedding_range and loss. embedding_range: Uniform distribution range. ent_emb: Entity embedding, shape:[num_ent, emb_dim]. rel_emb: Relation_embedding, shape:[num_rel, emb_dim]. .. _Iteratively Learning Embeddings and Rules for Knowledge Graph Reasoning. (WWW'19): https://dl.acm.org/doi/10.1145/3308558.3313612 """ def __init__(self, args, train_sampler, test_sampler): super(IterE, self).__init__(args) self.args = args self.ent_emb = None self.rel_emb = None self.init_emb() #print(self.args) #print(train_sampler) #print('run get_axiom()') self.train_sampler = train_sampler self.train_triples_base = copy.deepcopy(train_sampler.train_triples) self.select_probability = self.args.select_probability self.max_entialments = self.args.max_entialments self.axiom_types = self.args.axiom_types self.axiom_weight = self.args.axiom_weight self.inject_triple_percent = self.args.inject_triple_percent self.sparsity = 0.995 self.num_entity = self.args.num_ent self.relation2id=train_sampler.rel2id self.train_ids=train_sampler.train_triples self.valid_ids=train_sampler.valid_triples self.test_ids=train_sampler.test_triples #print(len(self.train_ids)) #print(len(self.valid_ids)) #print(len(self.test_ids)) self.train_ids_labels_inject = np.reshape([], [-1, 4]) # generate r_ht, hr_t print('# generate r_ht, hr_t') self.r_ht, self.hr_t, self.tr_h, self.hr_t_all, self.tr_h_all = self._generate(self.train_ids, self.valid_ids, self.test_ids) # generate entity2frequency and entity2sparsity dict print('# generate entity2frequency and entity2sparsity dict') self.entity2frequency, self.entity2sparsity = self._entity2frequency() print('# get_axiom') self.get_axiom() #self.rule, self.conf = self.get_rule(self.relation2id) def _entity2frequency(self): ent2freq = {ent:0 for ent in range(self.num_entity)} ent2sparsity = {ent:-1 for ent in range(self.num_entity)} for h,r,t in self.train_ids: ent2freq[h] += 1 ent2freq[t] += 1 ent_freq_list = np.asarray([ent2freq[ent] for ent in range(self.num_entity)]) ent_freq_list_sort = np.argsort(ent_freq_list) max_freq = max(list(ent2freq)) min_freq = min(list(ent2freq)) for ent, freq in ent2freq.items(): sparsity = 1 - (freq-min_freq)/(max_freq - min_freq) ent2sparsity[ent] = sparsity return ent2freq, ent2sparsity def _generate(self, train, valid, test): r_ht = defaultdict(set) hr_t = defaultdict(set) tr_h = defaultdict(set) hr_t_all = defaultdict(list) tr_h_all = defaultdict(list) for (h,r,t) in train: r_ht[r].add((h,t)) hr_t[(h,r)].add(t) tr_h[(t,r)].add(h) hr_t_all[(h,r)].append(t) tr_h_all[(t,r)].append(h) for (h,r,t) in test+valid: hr_t_all[(h,r)].append(t) tr_h_all[(t, r)].append(h) return r_ht, hr_t, tr_h, hr_t_all, tr_h_all
[docs] def get_axiom(self, ): self.axiom_dir = os.path.join(self.args.data_path, 'axiom_pool') self.reflexive_dir, self.symmetric_dir, self.transitive_dir, self.inverse_dir, self.subproperty_dir, self.equivalent_dir, self.inferencechain1, self.inferencechain2, self.inferencechain3, self.inferencechain4 = map(lambda x: os.path.join(self.axiom_dir, x), ['axiom_reflexive.txt', 'axiom_symmetric.txt', 'axiom_transitive.txt', 'axiom_inverse.txt', 'axiom_subProperty.txt', 'axiom_equivalent.txt', 'axiom_inferenceChain1.txt', 'axiom_inferenceChain2.txt', 'axiom_inferenceChain3.txt', 'axiom_inferenceChain4.txt']) # read and materialize axioms print('# self._read_axioms()') self._read_axioms() print('# self._read_axioms()') self._materialize_axioms() print('# self._read_axioms()') self._init_valid_axioms()
def _read_axioms(self): # for each axiom, the first id is the basic relation self.axiompool_reflexive = self._read_axiompool_file(self.reflexive_dir) self.axiompool_symmetric = self._read_axiompool_file(self.symmetric_dir) self.axiompool_transitive = self._read_axiompool_file(self.transitive_dir) self.axiompool_inverse = self._read_axiompool_file(self.inverse_dir) self.axiompool_equivalent = self._read_axiompool_file(self.equivalent_dir) self.axiompool_subproperty = self._read_axiompool_file(self.subproperty_dir) self.axiompool_inferencechain1 = self._read_axiompool_file(self.inferencechain1) self.axiompool_inferencechain2 = self._read_axiompool_file(self.inferencechain2) self.axiompool_inferencechain3 = self._read_axiompool_file(self.inferencechain3) self.axiompool_inferencechain4 = self._read_axiompool_file(self.inferencechain4) self.axiompool = [self.axiompool_reflexive, self.axiompool_symmetric, self.axiompool_transitive, self.axiompool_inverse, self.axiompool_subproperty, self.axiompool_equivalent, self.axiompool_inferencechain1,self.axiompool_inferencechain2, self.axiompool_inferencechain3,self.axiompool_inferencechain4] def _read_axiompool_file(self, file): f = open(file, 'r') axioms = [] for line in f.readlines(): line_list = line.strip().split('\t') axiom_ids = list(map(lambda x: self.relation2id[x], line_list)) #axiom_ids = self.relation2id[line_list] axioms.append(axiom_ids) # for the case reflexive pool is empty if len(axioms) == 0: np.reshape(axioms, [-1, 3]) return axioms # for each axioms in axiom pool # generate a series of entailments for each axiom def _materialize_axioms(self, generate=True, dump=True, load=False): if generate: self.reflexive2entailment = defaultdict(list) self.symmetric2entailment = defaultdict(list) self.transitive2entailment = defaultdict(list) self.inverse2entailment = defaultdict(list) self.equivalent2entailment = defaultdict(list) self.subproperty2entailment = defaultdict(list) self.inferencechain12entailment = defaultdict(list) self.inferencechain22entailment = defaultdict(list) self.inferencechain32entailment = defaultdict(list) self.inferencechain42entailment = defaultdict(list) self.reflexive_entailments, self.reflexive_entailments_num = self._materialize_sparse(self.axiompool_reflexive, type='reflexive') self.symmetric_entailments, self.symmetric_entailments_num = self._materialize_sparse(self.axiompool_symmetric, type='symmetric') self.transitive_entailments, self.transitive_entailments_num = self._materialize_sparse(self.axiompool_transitive, type='transitive') self.inverse_entailments, self.inverse_entailments_num = self._materialize_sparse(self.axiompool_inverse, type='inverse') self.subproperty_entailments, self.subproperty_entailments_num = self._materialize_sparse(self.axiompool_subproperty, type='subproperty') self.equivalent_entailments, self.equivalent_entailments_num = self._materialize_sparse(self.axiompool_equivalent, type='equivalent') self.inferencechain1_entailments, self.inferencechain1_entailments_num = self._materialize_sparse(self.axiompool_inferencechain1, type='inferencechain1') self.inferencechain2_entailments, self.inferencechain2_entailments_num = self._materialize_sparse(self.axiompool_inferencechain2, type='inferencechain2') self.inferencechain3_entailments, self.inferencechain3_entailments_num = self._materialize_sparse(self.axiompool_inferencechain3, type='inferencechain3') self.inferencechain4_entailments, self.inferencechain4_entailments_num = self._materialize_sparse(self.axiompool_inferencechain4, type='inferencechain4') print('reflexive entailments for sparse: ', self.reflexive_entailments_num) print('symmetric entailments for sparse: ', self.symmetric_entailments_num) print('transitive entailments for sparse: ', self.transitive_entailments_num) print('inverse entailments for sparse: ', self.inverse_entailments_num) print('subproperty entailments for sparse: ', self.subproperty_entailments_num) print('equivalent entailments for sparse: ', self.equivalent_entailments_num) print('inferencechain1 entailments for sparse: ', self.inferencechain1_entailments_num) print('inferencechain2 entailments for sparse: ', self.inferencechain2_entailments_num) print('inferencechain3 entailments for sparse: ', self.inferencechain3_entailments_num) print('inferencechain4 entailments for sparse: ', self.inferencechain4_entailments_num) print("finish generate axioms entailments for sparse") if dump: pickle.dump(self.reflexive_entailments, open(os.path.join(self.axiom_dir, 'reflexive_entailments'), 'wb')) pickle.dump(self.symmetric_entailments, open(os.path.join(self.axiom_dir, 'symmetric_entailments'), 'wb')) pickle.dump(self.transitive_entailments, open(os.path.join(self.axiom_dir, 'transitive_entailments'), 'wb')) pickle.dump(self.inverse_entailments, open(os.path.join(self.axiom_dir, 'inverse_entailments'), 'wb')) pickle.dump(self.subproperty_entailments, open(os.path.join(self.axiom_dir, 'subproperty_entailments'), 'wb')) #pickle.dump(self.inferencechain_entailments, open(os.path.join(self.axiom_dir, 'inferencechain_entailments'), 'wb')) pickle.dump(self.equivalent_entailments, open(os.path.join(self.axiom_dir, 'equivalent_entailments'), 'wb')) pickle.dump(self.inferencechain1_entailments, open(os.path.join(self.axiom_dir, 'inferencechain1_entailments'), 'wb')) pickle.dump(self.inferencechain2_entailments, open(os.path.join(self.axiom_dir, 'inferencechain2_entailments'), 'wb')) pickle.dump(self.inferencechain3_entailments, open(os.path.join(self.axiom_dir, 'inferencechain3_entailments'), 'wb')) pickle.dump(self.inferencechain4_entailments, open(os.path.join(self.axiom_dir, 'inferencechain4_entailments'), 'wb')) print("finish dump axioms entialments") if load: print("load refexive entailments...") self.reflexive_entailments = pickle.load(open(os.path.join(self.axiom_dir, 'reflexive_entailments'), 'rb')) print(self.reflexive_entailments) print('load symmetric entailments...') self.symmetric_entailments = pickle.load(open(os.path.join(self.axiom_dir, 'symmetric_entailments'), 'rb')) print("load transitive entialments... ") self.transitive_entailments = pickle.load(open(os.path.join(self.axiom_dir, 'transitive_entailments'), 'rb')) print("load inverse entailments...") self.inverse_entailments = pickle.load(open(os.path.join(self.axiom_dir, 'inverse_entailments'), 'rb')) print("load subproperty entailments...") self.subproperty_entailments = pickle.load(open(os.path.join(self.axiom_dir, 'subproperty_entailments'), 'rb')) #print("load inferencechain entailments...") #self.inferencechain_entailments = pickle.load(open(os.path.join(self.axiom_dir, 'inferencechain_entailments'), 'rb')) print("load equivalent entialments...") self.equivalent_entailments = pickle.load(open(os.path.join(self.axiom_dir, 'equivalent_entailments'), 'rb')) print("load inferencechain1 entailments...") self.inferencechain1_entailments = pickle.load( open(os.path.join(self.axiom_dir, 'inferencechain1_entailments'), 'rb')) print("load inferencechain2 entailments...") self.inferencechain2_entailments = pickle.load( open(os.path.join(self.axiom_dir, 'inferencechain2_entailments'), 'rb')) print("load inferencechain3 entailments...") self.inferencechain3_entailments = pickle.load( open(os.path.join(self.axiom_dir, 'inferencechain3_entailments'), 'rb')) print("load inferencechain4 entailments...") self.inferencechain4_entailments = pickle.load( open(os.path.join(self.axiom_dir, 'inferencechain4_entailments'), 'rb')) print("finish load axioms entailments") def _materialize_sparse(self, axioms, type=None, sparse = False): inference = [] # axiom2entailment is a dict # with the all axioms in the axiom pool as keys # and all the entailments for each axiom as values axiom_list = axioms length = len(axioms) max_entailments = self.max_entialments num = 0 if length == 0: if type == 'reflexive': np.reshape(inference, [-1, 3]) elif type == 'symmetric' or type =='inverse' or type =='equivalent' or type =='subproperty': np.reshape(inference, [-1, 6]) elif type=='transitive' or type=='inferencechain': np.reshape(inference, [-1, 9]) else: raise NotImplementedError return inference, num if type == 'reflexive': for axiom in axiom_list: axiom_key =tuple(axiom) r = axiom[0] inference_tmp = [] for (h,t) in self.r_ht[r]: # filter the axiom with too much entailments if len(inference_tmp) > max_entailments: inference_tmp = np.reshape([], [-1, 2]) break if h != t and self.entity2sparsity[h]>self.sparsity: num += 1 inference_tmp.append([h,r,h]) for entailment in inference_tmp: self.reflexive2entailment[axiom_key].append(entailment) inference.append(inference_tmp) if type == 'symmetric': #self.symmetric2entailment = defaultdict(list) for axiom in axiom_list: axiom_key = tuple(axiom) r = axiom[0] inference_tmp = [] for (h,t) in self.r_ht[r]: if len(inference_tmp) > max_entailments: inference_tmp = np.reshape([], [-1, 2]) break if (t,h) not in self.r_ht[r] and (self.entity2sparsity[h]>self.sparsity or self.entity2sparsity[t]>self.sparsity): num += 1 inference_tmp.append([h,r,t,t,r,h]) for entailment in inference_tmp: self.symmetric2entailment[axiom_key].append(entailment) inference.append(inference_tmp) if type == 'transitive': #self.transitive2entailment = defaultdict(list) for axiom in axiom_list: axiom_key = tuple(axiom) r = axiom[0] inference_tmp = [] for (h,t) in self.r_ht[r]: if len(inference_tmp) > max_entailments: inference_tmp = np.reshape([], [-1, 9]) break # (t,r,e) exist but (h,r,e) not exist and e!=h for e in self.hr_t[(t,r)]- self.hr_t[(h,r)]: if e != h and (self.entity2sparsity[h]>self.sparsity or self.entity2sparsity[e]>self.sparsity): num += 1 inference_tmp.append([h,r,t,t,r,e,h,r,e]) for entailment in inference_tmp: self.transitive2entailment[axiom_key].append(entailment) inference.append(inference_tmp) if type == 'inverse': for axiom in axiom_list: axiom_key = tuple(axiom) r1,r2 = axiom inference_tmp = [] for (h,t) in self.r_ht[r1]: if len(inference_tmp) > max_entailments: inference_tmp = np.reshape([], [-1, 6]) break if (t,h) not in self.r_ht[r2] and (self.entity2sparsity[h]>self.sparsity or self.entity2sparsity[t]>self.sparsity): num += 1 inference_tmp.append([h,r1,t, t,r2,h]) #self.inverse2entailment[axiom_key].append([h,r1,t, t,r2,h]) for entailment in inference_tmp: self.inverse2entailment[axiom_key].append(entailment) inference.append(inference_tmp) if type == 'equivalent' or type =='subproperty': for axiom in axiom_list: axiom_key = tuple(axiom) r1,r2 = axiom inference_tmp = [] for (h,t) in self.r_ht[r1]: if len(inference_tmp) > max_entailments: inference_tmp = np.reshape([], [-1, 6]) break if (h,t) not in self.r_ht[r2] and (self.entity2sparsity[h]>self.sparsity or self.entity2sparsity[t]>self.sparsity): num += 1 inference_tmp.append([h,r1,t, h,r2,t]) for entailment in inference_tmp: self.equivalent2entailment[axiom_key].append(entailment) self.subproperty2entailment[axiom_key].append(entailment) inference.append(inference_tmp) if type == 'inferencechain1': self.inferencechain12entailment = defaultdict(list) i = 0 for axiom in axiom_list: axiom_key = tuple(axiom) i += 1 # print('%d/%d' % (i, length)) r1, r2, r3 = axiom inference_tmp = [] for (e, h) in self.r_ht[r2]: if len(inference_tmp) > max_entailments: inference_tmp = np.reshape([], [-1, 9]) break for t in self.hr_t[(e, r3)]: if (h, t) not in self.r_ht[r1] and ( self.entity2sparsity[h] > self.sparsity or self.entity2sparsity[e] > self.sparsity): num += 1 inference_tmp.append([e, r2, h, e, r3, t, h, r1, t]) #self.inferencechain12entailment[axiom_key].append([[e, r2, h, e, r3, t, h, r1, t]]) for entailment in inference_tmp: self.inferencechain12entailment[axiom_key].append(entailment) inference.append(inference_tmp) if type == 'inferencechain2': self.inferencechain22entailment = defaultdict(list) i = 0 for axiom in axiom_list: axiom_key = tuple(axiom) i += 1 # print('%d/%d' % (i, length)) r1, r2, r3 = axiom inference_tmp = [] for (e, h) in self.r_ht[r2]: if len(inference_tmp) > max_entailments: inference_tmp = np.reshape([], [-1, 9]) break for t in self.tr_h[(e, r3)]: if (h, t) not in self.r_ht[r1] and ( self.entity2sparsity[h] > self.sparsity or self.entity2sparsity[e] > self.sparsity): num += 1 inference_tmp.append([e, r2, h, t, r3, e, h, r1, t]) #self.inferencechain22entailment[axiom_key].append([[e, r2, h, t, r3, e, h, r1, t]]) for entailment in inference_tmp: self.inferencechain22entailment[axiom_key].append(entailment) inference.append(inference_tmp) if type == 'inferencechain3': self.inferencechain32entailment = defaultdict(list) i = 0 for axiom in axiom_list: axiom_key = tuple(axiom) i += 1 # print('%d/%d' % (i, length)) r1, r2, r3 = axiom inference_tmp = [] for (h, e) in self.r_ht[r2]: if len(inference_tmp) > max_entailments: inference_tmp = np.reshape([], [-1, 9]) break for t in self.hr_t[(e, r3)]: if (h, t) not in self.r_ht[r1] and ( self.entity2sparsity[h] > self.sparsity or self.entity2sparsity[e] > self.sparsity): num += 1 inference_tmp.append([h, r2, e, e, r3, t, h, r1, t]) for entailment in inference_tmp: self.inferencechain32entailment[axiom_key].append(entailment) inference.append(inference_tmp) if type == 'inferencechain4': self.inferencechain42entailment = defaultdict(list) i = 0 for axiom in axiom_list: axiom_key = tuple(axiom) i += 1 # print('%d/%d' % (i, length)) r1, r2, r3 = axiom inference_tmp = [] for (h, e) in self.r_ht[r2]: if len(inference_tmp) > max_entailments: inference_tmp = np.reshape([], [-1, 9]) break for t in self.tr_h[(e, r3)]: if (h, t) not in self.r_ht[r1] and ( self.entity2sparsity[h] > self.sparsity or self.entity2sparsity[e] > self.sparsity): num += 1 inference_tmp.append([h, r2, e, t, r3, e, h, r1, t]) for entailment in inference_tmp: self.inferencechain42entailment[axiom_key].append(entailment) inference.append(inference_tmp) return inference, num def _materialize(self, axioms, type=None, sparse=False): inference = [] # axiom2entailment is a dict # with the all axioms in the axiom pool as keys # and all the entailments for each axiom as values axiom_list = axioms # print('axiom_list', axiom_list) length = len(axioms) max_entailments = 5000 num = 0 if length == 0: if type == 'reflexive': np.reshape(inference, [-1, 3]) elif type == 'symmetric' or type == 'inverse' or type == 'equivalent' or type == 'subproperty': np.reshape(inference, [-1, 6]) elif type == 'transitive' or type == 'inferencechain': np.reshape(inference, [-1, 9]) else: raise NotImplementedError return inference, num if type == 'reflexive': for axiom in axiom_list: axiom_key = tuple(axiom) r = axiom[0] inference_tmp = [] for (h, t) in self.r_ht[r]: if len(inference_tmp) > max_entailments: inference_tmp = np.reshape([], [-1, 2]) break if h != t: #and self.entity2sparsity[h] > self.sparsity: num += 1 inference_tmp.append([h, r, h]) for entailment in inference_tmp: self.reflexive2entailment[axiom_key].append(entailment) inference.append(inference_tmp) if type == 'symmetric': for axiom in axiom_list: axiom_key = tuple(axiom) r = axiom[0] inference_tmp = [] for (h, t) in self.r_ht[r]: if len(inference_tmp) > max_entailments: inference_tmp = np.reshape([], [-1, 2]) break if (t, h) not in self.r_ht[r]: #and (self.entity2sparsity[h] > self.sparsity or self.entity2sparsity[t] > self.sparsity): num += 1 inference_tmp.append([h, r, t, t, r, h]) for entailment in inference_tmp: self.symmetric2entailment[axiom_key].append(entailment) inference.append(inference_tmp) if type == 'transitive': for axiom in axiom_list: axiom_key = tuple(axiom) r = axiom[0] inference_tmp = [] for (h, t) in self.r_ht[r]: if len(inference_tmp) > max_entailments: inference_tmp = np.reshape([], [-1, 9]) break # (t,r,e) exist but (h,r,e) not exist and e!=h for e in self.hr_t[(t, r)] - self.hr_t[(h, r)]: if e != h: #and (self.entity2sparsity[h] > self.sparsity or self.entity2sparsity[e] > self.sparsity): num += 1 inference_tmp.append([h, r, t, t, r, e, h, r, e]) for entailment in inference_tmp: self.transitive2entailment[axiom_key].append(entailment) inference.append(inference_tmp) if type == 'inverse': # self.inverse2entailment = defaultdict(list) for axiom in axiom_list: axiom_key = tuple(axiom) r1, r2 = axiom inference_tmp = [] for (h, t) in self.r_ht[r1]: if len(inference_tmp) > max_entailments: inference_tmp = np.reshape([], [-1, 6]) break if (t, h) not in self.r_ht[r2]: #and (self.entity2sparsity[h] > self.sparsity or self.entity2sparsity[t] > self.sparsity): num += 1 inference_tmp.append([h, r1, t, t, r2, h]) for entailment in inference_tmp: self.inverse2entailment[axiom_key].append(entailment) inference.append(inference_tmp) if type == 'equivalent' or type == 'subproperty': for axiom in axiom_list: axiom_key = tuple(axiom) r1, r2 = axiom inference_tmp = [] for (h, t) in self.r_ht[r1]: if len(inference_tmp) > max_entailments: inference_tmp = np.reshape([], [-1, 6]) break if (h, t) not in self.r_ht[r2]: #and (self.entity2sparsity[h] > self.sparsity or self.entity2sparsity[t] > self.sparsity): num += 1 inference_tmp.append([h, r1, t, h, r2, t]) for entailment in inference_tmp: self.equivalent2entailment[axiom_key].append(entailment) self.subproperty2entailment[axiom_key].append(entailment) inference.append(inference_tmp) if type == 'inferencechain1': self.inferencechain12entailment = defaultdict(list) i = 0 for axiom in axiom_list: axiom_key = tuple(axiom) i += 1 # print('%d/%d' % (i, length)) r1, r2, r3 = axiom inference_tmp = [] for (e, h) in self.r_ht[r2]: if len(inference_tmp) > max_entailments: inference_tmp = np.reshape([], [-1, 9]) break for t in self.hr_t[(e, r3)]: if (h, t) not in self.r_ht[r1]: num += 1 inference_tmp.append([e, r2, h, e, r3, t, h, r1, t]) for entailment in inference_tmp: self.inferencechain12entailment[axiom_key].append(entailment) inference.append(inference_tmp) if type == 'inferencechain2': self.inferencechain22entailment = defaultdict(list) i = 0 for axiom in axiom_list: axiom_key = tuple(axiom) i += 1 # print('%d/%d' % (i, length)) r1, r2, r3 = axiom inference_tmp = [] for (e, h) in self.r_ht[r2]: if len(inference_tmp) > max_entailments: inference_tmp = np.reshape([], [-1, 9]) break for t in self.tr_h[(e, r3)]: if (h, t) not in self.r_ht[r1]: num += 1 inference_tmp.append([e, r2, h, t, r3, e, h, r1, t]) for entailment in inference_tmp: self.inferencechain22entailment[axiom_key].append(entailment) inference.append(inference_tmp) if type == 'inferencechain3': self.inferencechain32entailment = defaultdict(list) i = 0 for axiom in axiom_list: axiom_key = tuple(axiom) i += 1 # print('%d/%d' % (i, length)) r1, r2, r3 = axiom inference_tmp = [] for (h, e) in self.r_ht[r2]: if len(inference_tmp) > max_entailments: inference_tmp = np.reshape([], [-1, 9]) break for t in self.hr_t[(e, r3)]: if (h, t) not in self.r_ht[r1]: num += 1 inference_tmp.append([h, r2, e, e, r3, t, h, r1, t]) for entailment in inference_tmp: self.inferencechain32entailment[axiom_key].append(entailment) inference.append(inference_tmp) if type == 'inferencechain4': self.inferencechain42entailment = defaultdict(list) i = 0 for axiom in axiom_list: axiom_key = tuple(axiom) i += 1 r1, r2, r3 = axiom inference_tmp = [] for (h, e) in self.r_ht[r2]: if len(inference_tmp) > max_entailments: inference_tmp = np.reshape([], [-1, 9]) break for t in self.tr_h[(e, r3)]: if (h, t) not in self.r_ht[r1]: num += 1 inference_tmp.append([h, r2, e, t, r3, e, h, r1, t]) for entailment in inference_tmp: self.inferencechain42entailment[axiom_key].append(entailment) inference.append(inference_tmp) return inference, num def _init_valid_axioms(self): # init valid axioms self.valid_reflexive, self.valid_symmetric, self.valid_transitive,\ self.valid_inverse, self.valid_subproperty, self.valid_equivalent,\ self.valid_inferencechain1, self.valid_inferencechain2, \ self.valid_inferencechain3, self.valid_inferencechain4 = [[] for x in range(self.axiom_types)] # init valid axiom entailments self.valid_reflexive2entailment, self.valid_symmetric2entailment, self.valid_transitive2entailment, \ self.valid_inverse2entailment, self.valid_subproperty2entailment, self.valid_equivalent2entailment, \ self.valid_inferencechain12entailment, self.valid_inferencechain22entailment, \ self.valid_inferencechain32entailment, self.valid_inferencechain42entailment = [[] for x in range(self.axiom_types)] # init valid axiom entailments probability self.valid_reflexive_p, self.valid_symmetric_p, self.valid_transitive_p, \ self.valid_inverse_p, self.valid_subproperty_p, self.valid_equivalent_p, \ self.valid_inferencechain1_p, self.valid_inferencechain2_p,\ self.valid_inferencechain3_p, self.valid_inferencechain4_p= [[] for x in range(self.axiom_types)] # init valid axiom batchsize self.reflexive_batchsize = 1 self.symmetric_batchsize = 1 self.transitive_batchsize = 1 self.inverse_batchsize = 1 self.subproperty_batchsize = 1 self.equivalent_batchsize = 1 #self.inferencechain_batchsize = 1 self.inferencechain1_batchsize = 1 self.inferencechain2_batchsize = 1 self.inferencechain3_batchsize = 1 self.inferencechain4_batchsize = 1 # add the new triples from axioms to training triple
[docs] def update_train_triples(self, epoch=0, update_per = 10): """add the new triples from axioms to training triple Args: epoch (int, optional): epoch in training process. Defaults to 0. update_per (int, optional): Defaults to 10. Returns: updated_train_data: training triple after adding the new triples from axioms """ reflexive_triples, symmetric_triples, transitive_triples, inverse_triples,\ equivalent_triples, subproperty_triples, inferencechain1_triples, \ inferencechain2_triples, inferencechain3_triples, inferencechain4_triples = [ np.reshape(np.asarray([]), [-1, 3]) for i in range(self.axiom_types)] reflexive_p, symmetric_p, transitive_p, inverse_p, \ equivalent_p, subproperty_p, inferencechain1_p, \ inferencechain2_p, inferencechain3_p, inferencechain4_p = [np.reshape(np.asarray([]), [-1, 1]) for i in range(self.axiom_types)] updated_train_data=None if epoch >= 20: #if True: print("len(self.valid_reflexive2entailment):", len(self.valid_reflexive2entailment)) print("len(self.valid_symmetric2entailment):", len(self.valid_symmetric2entailment)) print("len(self.valid_transitive2entailment)", len(self.valid_transitive2entailment)) print("len(self.valid_inverse2entailment)", len(self.valid_inverse2entailment)) print("len(self.valid_equivalent2entailment)", len(self.valid_equivalent2entailment)) print("len(self.valid_subproperty2entailment)", len(self.valid_subproperty2entailment)) valid_reflexive2entailment, valid_symmetric2entailment, valid_transitive2entailment,\ valid_inverse2entailment, valid_equivalent2entailment, valid_subproperty2entailment, \ valid_inferencechain12entailment, valid_inferencechain22entailment,\ valid_inferencechain32entailment, valid_inferencechain42entailment = [[] for i in range(10)] if len(self.valid_reflexive2entailment)>0: valid_reflexive2entailment = np.reshape(np.asarray(self.valid_reflexive2entailment), [-1, 3]) reflexive_triples = np.asarray(valid_reflexive2entailment)[:, -3:] reflexive_p = np.reshape(np.asarray(self.valid_reflexive_p),[-1,1]) if len(self.valid_symmetric2entailment) > 0: valid_symmetric2entailment = np.reshape(np.asarray(self.valid_symmetric2entailment), [-1, 6]) symmetric_triples = np.asarray(valid_symmetric2entailment)[:, -3:] symmetric_p = np.reshape(np.asarray(self.valid_symmetric_p),[-1,1]) if len(self.valid_transitive2entailment) > 0: valid_transitive2entailment = np.reshape(np.asarray(self.valid_transitive2entailment), [-1, 9]) transitive_triples = np.asarray(valid_transitive2entailment)[:, -3:] transitive_p = np.reshape(np.asarray(self.valid_transitive_p), [-1, 1]) if len(self.valid_inverse2entailment) > 0: valid_inverse2entailment = np.reshape(np.asarray(self.valid_inverse2entailment), [-1, 6]) inverse_triples = np.asarray(valid_inverse2entailment)[:, -3:] inverse_p = np.reshape(np.asarray(self.valid_inverse_p), [-1, 1]) if len(self.valid_equivalent2entailment) > 0: valid_equivalent2entailment = np.reshape(np.asarray(self.valid_equivalent2entailment), [-1, 6]) equivalent_triples = np.asarray(valid_equivalent2entailment)[:, -3:] equivalent_p = np.reshape(np.asarray(self.valid_equivalent_p), [-1, 1]) if len(self.valid_subproperty2entailment) > 0: valid_subproperty2entailment = np.reshape(np.asarray(self.valid_subproperty2entailment), [-1, 6]) subproperty_triples = np.asarray(valid_subproperty2entailment)[:, -3:] subproperty_p = np.reshape(np.asarray(self.valid_subproperty_p),[-1,1]) if len(self.valid_inferencechain12entailment) > 0: valid_inferencechain12entailment = np.reshape(np.asarray(self.valid_inferencechain12entailment), [-1, 9]) inferencechain1_triples = np.asarray(valid_inferencechain12entailment)[:, -3:] inferencechain1_p = np.reshape(np.asarray(self.valid_inferencechain1_p), [-1, 1]) if len(self.valid_inferencechain22entailment) > 0: valid_inferencechain22entailment = np.reshape(np.asarray(self.valid_inferencechain22entailment), [-1, 9]) inferencechain2_triples = np.asarray(valid_inferencechain22entailment)[:, -3:] inferencechain2_p = np.reshape(np.asarray(self.valid_inferencechain2_p), [-1, 1]) if len(self.valid_inferencechain32entailment) > 0: valid_inferencechain32entailment = np.reshape(np.asarray(self.valid_inferencechain32entailment), [-1, 9]) inferencechain3_triples = np.asarray(valid_inferencechain32entailment)[:, -3:] inferencechain3_p = np.reshape(np.asarray(self.valid_inferencechain3_p), [-1, 1]) if len(self.valid_inferencechain42entailment) > 0: valid_inferencechain42entailment = np.reshape(np.asarray(self.valid_inferencechain42entailment), [-1, 9]) inferencechain4_triples = np.asarray(valid_inferencechain42entailment)[:, -3:] inferencechain4_p = np.reshape(np.asarray(self.valid_inferencechain4_p), [-1, 1]) # pickle.dump(self.reflexive_entailments, open(os.path.join(self.axiom_dir, 'reflexive_entailments'), 'wb')) # store all the injected triples entailment_all = (valid_reflexive2entailment, valid_symmetric2entailment, valid_transitive2entailment, valid_inverse2entailment, valid_equivalent2entailment, valid_subproperty2entailment, valid_inferencechain12entailment,valid_inferencechain22entailment, valid_inferencechain32entailment,valid_inferencechain42entailment) pickle.dump(entailment_all, open(os.path.join(self.axiom_dir, 'valid_entailments.pickle'), 'wb')) train_inject_triples = np.concatenate([reflexive_triples, symmetric_triples, transitive_triples, inverse_triples, equivalent_triples, subproperty_triples, inferencechain1_triples, inferencechain2_triples,inferencechain3_triples,inferencechain4_triples], axis=0) train_inject_triples_p = np.concatenate([reflexive_p,symmetric_p, transitive_p, inverse_p, equivalent_p, subproperty_p, inferencechain1_p, inferencechain2_p,inferencechain3_p,inferencechain4_p], axis=0) self.train_inject_triples = train_inject_triples inject_labels = np.reshape(np.ones(len(train_inject_triples)), [-1, 1]) * self.axiom_weight * train_inject_triples_p train_inject_ids_labels = np.concatenate([train_inject_triples, inject_labels], axis=1) self.train_ids_labels_inject = train_inject_triples#train_inject_ids_labels print('num reflexive triples', len(reflexive_triples)) print('num symmetric triples', len(symmetric_triples)) print('num transitive triples', len(transitive_triples)) print('num inverse triples', len(inverse_triples)) print('num equivalent triples', len(equivalent_triples)) print('num subproperty triples', len(subproperty_triples)) print('num inferencechain1 triples', len(inferencechain1_triples)) print('num inferencechain2 triples', len(inferencechain2_triples)) print('num inferencechain3 triples', len(inferencechain3_triples)) print('num inferencechain4 triples', len(inferencechain4_triples)) #print(self.train_ids_labels_inject) updated_train_data=self.generate_new_train_triples() return updated_train_data
[docs] def split_embedding(self, embedding): """split embedding Args: embedding: embeddings need to be splited, shape:[None, dim]. Returns: probability: The similrity between two matrices. """ # embedding: [None, dim] assert self.args.emb_dim % 4 == 0 num_scalar = self.args.emb_dim // 2 num_block = self.args.emb_dim // 4 embedding_scalar = embedding[:, 0:num_scalar] embedding_x = embedding[:, num_scalar:-num_block] embedding_y = embedding[:, -num_block:] return embedding_scalar, embedding_x, embedding_y
# calculate the similrity between two matrices # head: [?, dim] # tail: [?, dim] or [1,dim]
[docs] def sim(self, head=None, tail=None, arity=None): """calculate the similrity between two matrices Args: head: embeddings of head, shape:[batch_size, dim]. tail: embeddings of tail, shape:[batch_size, dim] or [1, dim]. arity: 1,2 or 3 Returns: probability: The similrity between two matrices. """ if arity == 1: A_scalar, A_x, A_y = self.split_embedding(head) elif arity == 2: M1_scalar, M1_x, M1_y = self.split_embedding(head[0]) M2_scalar, M2_x, M2_y = self.split_embedding(head[1]) A_scalar= M1_scalar * M2_scalar A_x = M1_x*M2_x - M1_y*M2_y A_y = M1_x*M2_y + M1_y*M2_x elif arity==3: M1_scalar, M1_x, M1_y = self.split_embedding(head[0]) M2_scalar, M2_x, M2_y = self.split_embedding(head[1]) M3_scalar, M3_x, M3_y = self.split_embedding(head[2]) M1M2_scalar = M1_scalar * M2_scalar M1M2_x = M1_x * M2_x - M1_y * M2_y M1M2_y = M1_x * M2_y + M1_y * M2_x A_scalar = M1M2_scalar * M3_scalar A_x = M1M2_x * M3_x - M1M2_y * M3_y A_y = M1M2_x * M3_y + M1M2_y * M3_x else: raise NotImplemented B_scala, B_x, B_y = self.split_embedding(tail) similarity = torch.cat([(A_scalar - B_scala)**2, (A_x - B_x)**2, (A_x - B_x)**2, (A_y - B_y)**2, (A_y - B_y)**2 ], dim=1) similarity = torch.sqrt(torch.sum(similarity, dim=1)) #recale the probability probability = (torch.max(similarity)-similarity)/(torch.max(similarity)-torch.min(similarity)) return probability
# generate a probality for each axiom in axiom pool
[docs] def run_axiom_probability(self): """this function is used to generate a probality for each axiom in axiom pool """ self.identity = torch.cat((torch.ones(int(self.args.emb_dim-self.args.emb_dim/4)),torch.zeros(int(self.args.emb_dim/4))),0).unsqueeze(0).cuda() if len(self.axiompool_reflexive) != 0: index = torch.LongTensor(self.axiompool_reflexive).cuda() reflexive_embed = self.rel_emb(index) reflexive_prob = self.sim(head=reflexive_embed[:, 0, :], tail=self.identity, arity=1) else: reflexive_prob = [] if len(self.axiompool_symmetric) != 0: index = torch.LongTensor(self.axiompool_symmetric).cuda() symmetric_embed = self.rel_emb(index) symmetric_prob = self.sim(head=[symmetric_embed[:, 0, :], symmetric_embed[:, 0, :]], tail=self.identity, arity=2) #symmetric_prob = sess.run(self.symmetric_probability, {self.symmetric_pool: self.axiompool_symmetric}) else: symmetric_prob = [] if len(self.axiompool_transitive) != 0: index = torch.LongTensor(self.axiompool_transitive).cuda() transitive_embed = self.rel_emb(index) transitive_prob = self.sim(head=[transitive_embed[:, 0, :], transitive_embed[:, 0, :]], tail=transitive_embed[:, 0, :], arity=2) #transitive_prob = sess.run(self.transitive_probability, {self.transitive_pool: self.axiompool_transitive}) else: transitive_prob = [] if len(self.axiompool_inverse) != 0: index = torch.LongTensor(self.axiompool_inverse).cuda() #inverse_prob = sess.run(self.inverse_probability, {self.inverse_pool: self.axiompool_inverse}) inverse_embed = self.rel_emb(index) inverse_probability1 = self.sim(head=[inverse_embed[:, 0,:],inverse_embed[:, 1,:]], tail = self.identity, arity=2) inverse_probability2 = self.sim(head=[inverse_embed[:,1,:],inverse_embed[:, 0,:]], tail=self.identity, arity=2) inverse_prob = (inverse_probability1 + inverse_probability2)/2 else: inverse_prob = [] if len(self.axiompool_subproperty) != 0: index = torch.LongTensor(self.axiompool_subproperty).cuda() #subproperty_prob = sess.run(self.subproperty_probability, {self.subproperty_pool: self.axiompool_subproperty}) subproperty_embed = self.rel_emb(index) subproperty_prob = self.sim(head=subproperty_embed[:, 0,:], tail=subproperty_embed[:, 1, :], arity=1) else: subproperty_prob = [] if len(self.axiompool_equivalent) != 0: index = torch.LongTensor(self.axiompool_equivalent).cuda() #equivalent_prob = sess.run(self.equivalent_probability, {self.equivalent_pool: self.axiompool_equivalent}) equivalent_embed = self.rel_emb(index) equivalent_prob = self.sim(head=equivalent_embed[:, 0,:], tail=equivalent_embed[:, 1,:], arity=1) else: equivalent_prob = [] if len(self.axiompool_inferencechain1) != 0: index = torch.LongTensor(self.axiompool_inferencechain1).cuda() inferencechain_embed = self.rel_emb(index) inferencechain1_prob = self.sim(head=[inferencechain_embed[:, 1, :], inferencechain_embed[:, 0, :]], tail=inferencechain_embed[:, 2, :], arity=2) else: inferencechain1_prob = [] if len(self.axiompool_inferencechain2) != 0: index = torch.LongTensor(self.axiompool_inferencechain2).cuda() inferencechain_embed = self.rel_emb(index) inferencechain2_prob = self.sim(head=[inferencechain_embed[:, 2, :], inferencechain_embed[:, 1, :], inferencechain_embed[:, 0, :]], tail=self.identity, arity=3) else: inferencechain2_prob = [] if len(self.axiompool_inferencechain3) != 0: index = torch.LongTensor(self.axiompool_inferencechain3).cuda() inferencechain_embed = self.rel_emb(index) inferencechain3_prob = self.sim(head=[inferencechain_embed[:, 1, :], inferencechain_embed[:, 2, :]], tail=inferencechain_embed[:, 0, :], arity=2) else: inferencechain3_prob = [] if len(self.axiompool_inferencechain4) != 0: index = torch.LongTensor(self.axiompool_inferencechain4).cuda() inferencechain_embed = self.rel_emb(index) inferencechain4_prob = self.sim(head=[inferencechain_embed[:, 0, :], inferencechain_embed[:, 2, :]],tail=inferencechain_embed[:, 1, :], arity=2) else: inferencechain4_prob = [] output = [reflexive_prob, symmetric_prob, transitive_prob, inverse_prob, subproperty_prob,equivalent_prob,inferencechain1_prob, inferencechain2_prob, inferencechain3_prob, inferencechain4_prob] return output
[docs] def update_valid_axioms(self, input): """this function is used to select high probability axioms as valid axioms and record their scores """ # # valid_axioms = [self._select_high_probability(list(prob), axiom) for prob,axiom in zip(input, self.axiompool)] self.valid_reflexive, self.valid_symmetric, self.valid_transitive, \ self.valid_inverse, self.valid_subproperty, self.valid_equivalent, \ self.valid_inferencechain1, self.valid_inferencechain2, \ self.valid_inferencechain3, self.valid_inferencechain4 = valid_axioms # update the batchsize of axioms and entailments self._reset_valid_axiom_entailment()
def _select_high_probability(self, prob, axiom): # select the high probability axioms and recore their probabilities valid_axiom = [[axiom[prob.index(p)],[p]] for p in prob if p>self.select_probability] return valid_axiom def _reset_valid_axiom_entailment(self): self.infered_hr_t = defaultdict(set) self.infered_tr_h = defaultdict(set) self.valid_reflexive2entailment, self.valid_reflexive_p = \ self._valid_axiom2entailment(self.valid_reflexive, self.reflexive2entailment) self.valid_symmetric2entailment, self.valid_symmetric_p = \ self._valid_axiom2entailment(self.valid_symmetric, self.symmetric2entailment) self.valid_transitive2entailment, self.valid_transitive_p = \ self._valid_axiom2entailment(self.valid_transitive, self.transitive2entailment) self.valid_inverse2entailment, self.valid_inverse_p = \ self._valid_axiom2entailment(self.valid_inverse, self.inverse2entailment) self.valid_subproperty2entailment, self.valid_subproperty_p = \ self._valid_axiom2entailment(self.valid_subproperty, self.subproperty2entailment) self.valid_equivalent2entailment, self.valid_equivalent_p = \ self._valid_axiom2entailment(self.valid_equivalent, self.equivalent2entailment) self.valid_inferencechain12entailment, self.valid_inferencechain1_p = \ self._valid_axiom2entailment(self.valid_inferencechain1, self.inferencechain12entailment) self.valid_inferencechain22entailment, self.valid_inferencechain2_p = \ self._valid_axiom2entailment(self.valid_inferencechain2, self.inferencechain22entailment) self.valid_inferencechain32entailment, self.valid_inferencechain3_p = \ self._valid_axiom2entailment(self.valid_inferencechain3, self.inferencechain32entailment) self.valid_inferencechain42entailment, self.valid_inferencechain4_p = \ self._valid_axiom2entailment(self.valid_inferencechain4, self.inferencechain42entailment) def _valid_axiom2entailment(self, valid_axiom, axiom2entailment): valid_axiom2entailment = [] valid_axiom_p = [] for axiom_p in valid_axiom: axiom = tuple(axiom_p[0]) p = axiom_p[1] for entailment in axiom2entailment[axiom]: valid_axiom2entailment.append(entailment) valid_axiom_p.append(p) h,r,t = entailment[-3:] self.infered_hr_t[(h,r)].add(t) self.infered_tr_h[(t,r)].add(h) return valid_axiom2entailment, valid_axiom_p # updata new train triples:
[docs] def generate_new_train_triples(self): """The function is to updata new train triples and used after each training epoch end Returns: self.train_sampler.train_triples: The new training dataset (triples). """ self.train_sampler.train_triples = copy.deepcopy(self.train_triples_base) print('generate_new_train_triples...') #origin_triples = train_sampler.train_triples inject_triples = self.train_ids_labels_inject inject_num = int(self.inject_triple_percent*len(self.train_sampler.train_triples)) if len(inject_triples)> inject_num and inject_num >0: np.random.shuffle(inject_triples) inject_triples = inject_triples[:inject_num] #train_triples = np.concatenate([origin_triples, inject_triples], axis=0) print('当前train_sampler.train_triples数目',len(self.train_sampler.train_triples)) for h,r,t in inject_triples: self.train_sampler.train_triples.append((int(h),int(r),int(t))) print('添加后train_sampler.train_triples数目',len(self.train_sampler.train_triples)) return self.train_sampler.train_triples
[docs] def get_rule(self, rel2id): """Get rule for rule_base KGE models, such as ComplEx_NNE model. Get rule and confidence from _cons.txt file. Update: (rule_p, rule_q): Rule. confidence: The confidence of rule. """ rule_p, rule_q, confidence = [], [], [] with open(os.path.join(self.args.data_path, '_cons.txt')) as file: lines = file.readlines() for line in lines: rule_str, trust = line.strip().split() body, head = rule_str.split(',') if '-' in body: rule_p.append(rel2id[body[1:]]) rule_q.append(rel2id[head]) else: rule_p.append(rel2id[body]) rule_q.append(rel2id[head]) confidence.append(float(trust)) rule_p = torch.tensor(rule_p).cuda() rule_q = torch.tensor(rule_q).cuda() confidence = torch.tensor(confidence).cuda() return (rule_p, rule_q), confidence
"""def init_emb(self): self.epsilon = 2.0 self.margin = nn.Parameter( torch.Tensor([self.args.margin]), requires_grad=False ) self.embedding_range = nn.Parameter( torch.Tensor([(self.margin.item() + self.epsilon) / self.args.emb_dim]), requires_grad=False ) self.ent_emb = nn.Embedding(self.args.num_ent, self.args.emb_dim * 2) self.rel_emb = nn.Embedding(self.args.num_rel, self.args.emb_dim * 2) nn.init.uniform_(tensor=self.ent_emb.weight.data, a=-self.embedding_range.item(), b=self.embedding_range.item()) nn.init.uniform_(tensor=self.rel_emb.weight.data, a=-self.embedding_range.item(), b=self.embedding_range.item())"""
[docs] def init_emb(self): """Initialize the entity and relation embeddings in the form of a uniform distribution. """ self.epsilon = 2.0 self.margin = nn.Parameter( torch.Tensor([self.args.margin]), requires_grad=False ) self.embedding_range = nn.Parameter( torch.Tensor([(self.margin.item() + self.epsilon) / self.args.emb_dim]), requires_grad=False ) self.ent_emb = nn.Embedding(self.args.num_ent, self.args.emb_dim) self.rel_emb = nn.Embedding(self.args.num_rel, self.args.emb_dim) nn.init.uniform_(tensor=self.ent_emb.weight.data, a=-self.embedding_range.item(), b=self.embedding_range.item()) nn.init.uniform_(tensor=self.rel_emb.weight.data, a=-self.embedding_range.item(), b=self.embedding_range.item())
"""def score_func(self, head_emb, relation_emb, tail_emb, mode): re_head, im_head = torch.chunk(head_emb, 2, dim=-1) re_relation, im_relation = torch.chunk(relation_emb, 2, dim=-1) re_tail, im_tail = torch.chunk(tail_emb, 2, dim=-1) return torch.sum( re_head * re_tail * re_relation + im_head * im_tail * re_relation + re_head * im_tail * im_relation - im_head * re_tail * im_relation, -1 )"""
[docs] def score_func(self, head_emb, relation_emb, tail_emb, mode): """Calculating the score of triples. The formula for calculating the score is DistMult. Args: head_emb: The head entity embedding. relation_emb: The relation embedding. tail_emb: The tail entity embedding. mode: Choose head-predict or tail-predict. Returns: score: The score of triples. """ if mode == 'head-batch': score = head_emb * (relation_emb * tail_emb) else: score = (head_emb * relation_emb) * tail_emb score = score.sum(dim = -1) return score
[docs] def forward(self, triples, negs=None, mode='single'): """The functions used in the training phase Args: triples: The triples ids, as (h, r, t), shape:[batch_size, 3]. negs: Negative samples, defaults to None. mode: Choose head-predict or tail-predict, Defaults to 'single'. Returns: score: The score of triples. """ head_emb, relation_emb, tail_emb = self.tri2emb(triples, negs, mode) score = self.score_func(head_emb, relation_emb, tail_emb, mode) return score
[docs] def get_score(self, batch, mode): """The functions used in the testing phase Args: batch: A batch of data. mode: Choose head-predict or tail-predict. Returns: score: The score of triples. """ triples = batch["positive_sample"] head_emb, relation_emb, tail_emb = self.tri2emb(triples, mode=mode) score = self.score_func(head_emb, relation_emb, tail_emb, mode) return score