import torch.nn as nn
import torch
import dgl
import pickle
import numpy as np
import torch.nn.functional as F
from .RGCN import RelGraphConv
from .model import Model
from neuralkg_ind.utils.tools import *
from neuralkg_ind.model import TransE, DistMult, ComplEx, RotatE
from neuralkg_ind.utils import get_indtest_test_dataset_and_train_g, get_g_bidir
class MorsE(nn.Module):
"""`Meta-Knowledge Transfer for Inductive Knowledge Graph Embedding`_ (MorsE), which learns transferable meta-knowledge that
can be used to produce entity embeddings.
Attributes:
args: Model configuration parameters.
ent_init: Entity embedding initialization module.
rgcn: RGCN model.
kge_model: KGE model for scoring triples.
.. _Meta-Knowledge Transfer for Inductive Knowledge Graph Embedding: https://arxiv.org/abs/2110.14170
"""
def __init__(self, args):
super().__init__()
self.args = args
args.ent_dim = args.emb_dim
args.rel_dim = args.emb_dim
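# ComplEx and RotatE embed entities as complex numbers (real and imaginary
# parts concatenated), so ent_dim doubles; ComplEx relations are complex too,
# while RotatE relations are phase vectors of size emb_dim.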
if args.kge_model in ['ComplEx', 'RotatE']:
args.ent_dim = args.emb_dim * 2
if args.kge_model in ['ComplEx']:
args.rel_dim = args.emb_dim * 2
args.num_rel = self.get_num_rel(args)
self.ent_init = EntInit(args)
self.rgcn = RGCN(args, basiclayer = RelMorsGraphConv)
self.kge_model = KGEModel(args)
def forward(self, sample, ent_emb, mode='single'):
"""Calculating triple scores.
Args:
sample: Sampled triples.
ent_emb: Embedding of entities.
mode: Indicates whether negative entities replace the head or the tail entity.
Returns:
score: Score of triples.
"""
return self.kge_model(sample, ent_emb, mode)
def get_intest_train_g(self):
"""Getting the train graph of the inductive test set.
Returns:
indtest_train_g: Train graph of the inductive test set.
"""
data, _, _, _ = get_indtest_test_dataset_and_train_g(self.args)
self.indtest_train_g = get_g_bidir(torch.LongTensor(data['train']), self.args)
self.indtest_train_g = self.indtest_train_g.to(self.args.gpu)
return self.indtest_train_g
def get_ent_emb(self, sup_g_bidir):
"""Getting entity embeddings.
Args:
sup_g_bidir: Bidirectional support graph.
Returns:
ent_emb: Embedding of entities.
"""
self.ent_init(sup_g_bidir)
ent_emb = self.rgcn(sup_g_bidir)
return ent_emb
def get_score(self, batch, mode):
"""Getting scores of triples.
Args:
batch: Including positive sample, entity embeddings, etc.
mode: 'tail_predict' or 'head_predict'.
Returns:
score: Score of positive or negative sample.
"""
pos_triple = batch["positive_sample"]
ent_emb = batch["ent_emb"]
if batch['cand'] == 'all':
return self.kge_model((pos_triple, None), ent_emb, mode)
else:
if mode == 'tail_predict':
tail_cand = batch['tail_cand']
return self.kge_model((pos_triple, tail_cand), ent_emb, mode)
else:
head_cand = batch['head_cand']
return self.kge_model((pos_triple, head_cand), ent_emb, mode)
def get_num_rel(self, args):
"""Getting the number of relations.
Args:
args: Model configuration parameters.
Returns:
num_rel: The number of relations.
"""
data = pickle.load(open(args.pk_path, 'rb'))
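# Column 1 of the (head, relation, tail) training triples holds relation ids,
# so the count of unique values is the number of relations.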
num_rel = len(np.unique(np.array(data['train_graph']['train'])[:, 1]))
return num_rel
class EntInit(nn.Module):
"""Class for initializing entity embeddings.
Attributes:
args: Model configuration parameters.
rel_head_emb: Relation-specific embeddings for head entities.
rel_tail_emb: Relation-specific embeddings for tail entities.
"""
def __init__(self, args):
super(EntInit, self).__init__()
self.args = args
self.rel_head_emb = nn.Parameter(torch.Tensor(args.num_rel, args.ent_dim))
self.rel_tail_emb = nn.Parameter(torch.Tensor(args.num_rel, args.ent_dim))
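# Per-relation embeddings: rel_head_emb initializes entities appearing as the
# head of a relation, rel_tail_emb those appearing as the tail.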
nn.init.xavier_normal_(self.rel_head_emb, gain=nn.init.calculate_gain('relu'))
nn.init.xavier_normal_(self.rel_tail_emb, gain=nn.init.calculate_gain('relu'))
def forward(self, g_bidir):
"""Initializing entities in the graph.
Args:
g_bidir: Bidirectional graph.
"""
num_edge = g_bidir.num_edges()
etypes = g_bidir.edata['type']
g_bidir.edata['ent_e'] = torch.zeros(num_edge, self.args.ent_dim).type_as(etypes).float()
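# Edge types in [0, num_rel) are original edges; get_g_bidir labels the added
# inverse edges with type + num_rel, hence the offset when indexing below.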
rh_idx = etypes < self.args.num_rel
rt_idx = etypes >= self.args.num_rel
g_bidir.edata['ent_e'][rh_idx] = self.rel_head_emb[etypes[rh_idx]]
g_bidir.edata['ent_e'][rt_idx] = self.rel_tail_emb[etypes[rt_idx] - self.args.num_rel]
message_func = dgl.function.copy_e('ent_e', 'msg')
reduce_func = dgl.function.mean('msg', 'feat')
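# Each entity's initial feature 'feat' becomes the mean of the relation-head /
# relation-tail embeddings carried by its incoming edges.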
g_bidir.update_all(message_func, reduce_func)
g_bidir.edata.pop('ent_e')
class RelMorsGraphConv(RelGraphConv):
"""Basic layer of RGCN.
Attributes:
args: Model configuration parameters.
bias: Weight bias.
inp_dim: Dimension of input.
out_dim: Dimension of output.
num_rels: The number of relations.
num_bases: The number of bases.
has_attn: Whether there is attention mechanism.
is_input_layer: Whether it is input layer.
aggregator: Type of aggregator.
weight: Weight matrix.
w_comp: Bases matrix.
self_loop_weight: Self-loop weight.
edge_dropout: Dropout of edge.
"""
def __init__(self, args, inp_dim, out_dim, aggregator, num_rels, num_bases=-1, bias=False,
activation=None, dropout=0.0, edge_dropout=0.0, is_input_layer=False, has_attn=False):
super().__init__(args, inp_dim, out_dim, 0, None, 0, bias=bias, activation=activation,
self_loop=True, dropout=0.0, layer_norm=False)
self.in_dim = inp_dim
self.out_dim = out_dim
self.num_rels = num_rels
self.aggregator = aggregator
self.num_bases = num_bases
if self.num_bases is None or self.num_bases > self.num_rels or self.num_bases <= 0:
self.num_bases = self.num_rels
self.rel_weight = None
self.input_ = None
self.activation = activation
self.is_input_layer = is_input_layer
self.weight = nn.Parameter(torch.Tensor(self.num_bases, self.in_dim, self.out_dim))
self.w_comp = nn.Parameter(torch.Tensor(self.num_rels*2, self.num_bases))
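# Basis decomposition: num_bases shared basis matrices (weight) are mixed by
# per-relation coefficients (w_comp); the 2*num_rels rows cover inverse relations.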
nn.init.xavier_uniform_(self.weight, gain=nn.init.calculate_gain('relu'))
nn.init.xavier_uniform_(self.w_comp, gain=nn.init.calculate_gain('relu'))
def message(self, edges):
"""Message function for propagation.
Args:
edges: Edges in the graph.
Returns:
curr_emb: Embedding of the current layer.
msg: Message for propagation.
alpha: Normalization coefficient.
"""
w = self.rel_weight.index_select(0, edges.data['type'])
msg = torch.bmm(edges.src[self.input_].unsqueeze(1), w).squeeze(1)
curr_emb = torch.mm(edges.dst[self.input_], self.loop_weight) # (B, F)
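# 1 / in-degree of the destination node ('in_d' is set in forward()),
# used by the aggregator to normalize incoming messages.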
a = 1 / edges.dst['in_d'].type_as(w).to(torch.float32).reshape(-1, 1)
return {'curr_emb': curr_emb, 'msg': msg, 'alpha': a}
def apply_node_func(self, nodes):
"""Node apply function: adds bias and activation.
Args:
nodes: Nodes in the graph.
Returns:
node_repr: Representation of nodes.
"""
node_repr = nodes.data['h']
if self.bias:
node_repr = node_repr + self.h_bias
if self.activation:
node_repr = self.activation(node_repr)
return {'h': node_repr}
def forward(self, g):
"""Updating node representations.
Args:
g: Subgraph of the corresponding triple.
"""
# generate all relations' weight from bases
weight = self.weight.view(self.num_bases, self.in_dim * self.out_dim)
self.rel_weight = torch.matmul(self.w_comp, weight).view(
self.num_rels*2, self.in_dim, self.out_dim)
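# i.e. W_r = sum_b w_comp[r, b] * weight[b], so only the mixing
# coefficients are relation-specific.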
# normalization constant
g.dstdata['in_d'] = g.in_degrees()
self.input_ = 'feat' if self.is_input_layer else 'h'
g.update_all(self.message, self.aggregator, self.apply_node_func)
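# update_all runs message() over all edges, aggregates with the configured
# aggregator, then applies apply_node_func to write the layer output to 'h'.
# 'repr' accumulates the input features and every layer's output; RGCN.forward
# feeds this jumping-knowledge-style concatenation to jk_linear.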
if self.is_input_layer:
g.ndata['repr'] = torch.cat([g.ndata['feat'], g.ndata['h']], dim=1)
else:
g.ndata['repr'] = torch.cat([g.ndata['repr'], g.ndata['h']], dim=1)
class RGCN(Model):
"""RGCN model.
Attributes:
args: Model configuration parameters.
basiclayer: Basic layer of the RGCN model.
inp_dim: Dimension of input.
emb_dim: Dimension of embedding.
num_rel: The number of relations.
num_bases: The number of bases.
aggregator: Type of aggregator.
"""
def __init__(self, args, basiclayer):
super(RGCN, self).__init__(args)
self.args = args
self.basiclayer = basiclayer
self.inp_dim = args.ent_dim
self.emb_dim = args.ent_dim
self.num_rel = args.num_rel
self.num_bases = args.num_bases
aggregator_type = self.args.gnn_agg_type.upper()+"Aggregator"
aggregator_class = import_class(f"neuralkg_ind.model.{aggregator_type}")
self.aggregator = aggregator_class(self.emb_dim)
self.build_model()
self.jk_linear = nn.Linear(self.emb_dim*(self.args.num_layers+1), self.emb_dim)
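# jk_linear maps the (num_layers + 1) concatenated blocks in 'repr'
# (input features plus each layer's output) back to emb_dim.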
def build_hidden_layer(self, idx):
"""Building a hidden layer of RGCN.
Args:
idx: The index of the layer.
Returns:
output: A basic layer, marked as the input layer when idx == 0.
"""
input_flag = (idx == 0)
return self.basiclayer(self.args, self.inp_dim, self.emb_dim, self.aggregator, self.num_rel, self.num_bases, bias=True,
activation=F.relu, is_input_layer=input_flag)
def forward(self, g):
"""Getting node embeddings.
Args:
g: Subgraph of the corresponding task.
Returns:
g.ndata['h']: Node embeddings.
"""
for layer in self.layers:
layer(g)
g.ndata['h'] = self.jk_linear(g.ndata['repr'])
return g.ndata['h']
class KGEModel(nn.Module):
"""KGE model.
Attributes:
args: Model configuration parameters.
model_name: The name of the model.
nrelation: The number of relations.
emb_dim: Dimension of embedding.
epsilon: Constant used to calculate embedding_range.
margin: Margin used to calculate embedding_range and loss.
embedding_range: Range of the uniform initialization.
relation_embedding: Embedding of relations.
"""
def __init__(self, args):
super(KGEModel, self).__init__()
self.args = args
self.model_name = args.kge_model
self.nrelation = args.num_rel
self.emb_dim = args.emb_dim
self.epsilon = 2.0
self.margin = torch.Tensor([args.margin])
self.embedding_range = torch.Tensor([(self.margin.item() + self.epsilon) / args.emb_dim])
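# Uniform init range (margin + epsilon) / emb_dim, following the RotatE
# reference implementation.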
self.relation_embedding = nn.Parameter(torch.zeros(self.nrelation, self.args.rel_dim))
nn.init.uniform_(
tensor=self.relation_embedding,
a=-self.embedding_range.item(),
b=self.embedding_range.item()
)
if self.model_name not in ['TransE', 'DistMult', 'ComplEx', 'RotatE']:
raise ValueError('model %s not supported' % self.model_name)
def forward(self, sample, ent_emb, mode='single'):
'''Forward function that calculates the score of a batch of triples.
In the 'single' mode, sample is a batch of triples.
In the 'head_predict' or 'tail_predict' mode, sample consists of two parts.
The first part is usually the positive sample,
and the second part holds the entities of the negative samples,
because negative and positive samples usually share two elements
of their triple ((head, relation) or (relation, tail)).
Args:
sample: Positive and negative samples.
ent_emb: Embedding of entities.
mode: 'single', 'head_predict', 'tail_predict' or 'rel-batch'.
Returns:
score: The score of the sample.
'''
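# Unlike a standard KGE model, entity embeddings are not parameters here:
# MorsE produces them per support graph with EntInit + RGCN and injects
# them on every call.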
self.entity_embedding = ent_emb
if mode == 'single':
batch_size, negative_sample_size = sample.size(0), 1
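# unsqueeze(1) shapes head/relation/tail as (batch, 1, dim) so the score
# functions broadcast against (batch, n, dim) candidate sets in other modes.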
head = torch.index_select(
self.entity_embedding,
dim=0,
index=sample[:, 0]
).unsqueeze(1)
relation = torch.index_select(
self.relation_embedding,
dim=0,
index=sample[:, 1]
).unsqueeze(1)
tail = torch.index_select(
self.entity_embedding,
dim=0,
index=sample[:, 2]
).unsqueeze(1)
elif mode == 'head_predict':
tail_part, head_part = sample
if head_part is not None:
batch_size, negative_sample_size = head_part.size(0), head_part.size(1)
if head_part is None:
head = self.entity_embedding.unsqueeze(0)
else:
head = torch.index_select(
self.entity_embedding,
dim=0,
index=head_part.view(-1)
).view(batch_size, negative_sample_size, -1)
relation = torch.index_select(
self.relation_embedding,
dim=0,
index=tail_part[:, 1]
).unsqueeze(1)
tail = torch.index_select(
self.entity_embedding,
dim=0,
index=tail_part[:, 2]
).unsqueeze(1)
elif mode == 'tail_predict':
head_part, tail_part = sample
if tail_part is not None:
batch_size, negative_sample_size = tail_part.size(0), tail_part.size(1)
head = torch.index_select(
self.entity_embedding,
dim=0,
index=head_part[:, 0]
).unsqueeze(1)
relation = torch.index_select(
self.relation_embedding,
dim=0,
index=head_part[:, 1]
).unsqueeze(1)
if tail_part is None:
tail = self.entity_embedding.unsqueeze(0)
else:
tail = torch.index_select(
self.entity_embedding,
dim=0,
index=tail_part.view(-1)
).view(batch_size, negative_sample_size, -1)
elif mode == 'rel-batch':
head_part, tail_part = sample
if tail_part is not None:
batch_size, negative_sample_size = tail_part.size(0), tail_part.size(1)
head = torch.index_select(
self.entity_embedding,
dim=0,
index=head_part[:, 0]
).unsqueeze(1)
tail = torch.index_select(
self.entity_embedding,
dim=0,
index=head_part[:, 2]
).unsqueeze(1)
if tail_part is None:
relation = self.relation_embedding.unsqueeze(0)
else:
relation = torch.index_select(
self.relation_embedding,
dim=0,
index=tail_part.view(-1)
).view(batch_size, negative_sample_size, -1)
else:
raise ValueError('mode %s not supported' % mode)
model_func = {
'TransE': TransE.score_func,
'DistMult': DistMult.score_func,
'ComplEx': ComplEx.score_func,
'RotatE': RotatE.score_func,
}
if self.model_name in model_func:
score = model_func[self.model_name](self, head, relation, tail, mode)
else:
raise ValueError('model %s not supported' % self.model_name)
return score
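# A minimal usage sketch (hypothetical variable names; `args` must carry the
# fields read above, e.g. emb_dim, kge_model, margin, num_bases, gnn_agg_type,
# num_layers, pk_path):
#   model = MorsE(args)
#   ent_emb = model.get_ent_emb(sup_g_bidir)           # EntInit + RGCN encoding
#   score = model(pos_triple, ent_emb, mode='single')  # KGE scoring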