import torch
import torch.nn as nn
import dgl
import dgl.function as fn
from neuralkg_ind import utils
from neuralkg_ind.utils.tools import get_param
from neuralkg_ind.model import ConvE

class SEGNN(nn.Module):
    def __init__(self, args):
        super(SEGNN, self).__init__()
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  # TODO: make the device configurable
        self.args = args
        self.dataset = self.args.dataset_name
        self.n_ent = self.args.num_ent
        self.n_rel = self.args.num_rel
        self.emb_dim = self.args.emb_dim
        # entity embedding
        self.ent_emb = get_param(self.n_ent, self.emb_dim)
        # number of GNN layers (e.g. 1)
        self.kg_n_layer = self.args.kg_layer
        # relation SE layer
        self.edge_layers = nn.ModuleList([EdgeLayer(self.args) for _ in range(self.kg_n_layer)])
        # entity SE layer
        self.node_layers = nn.ModuleList([NodeLayer(self.args) for _ in range(self.kg_n_layer)])
        # triple SE layer
        self.comp_layers = nn.ModuleList([CompLayer(self.args) for _ in range(self.kg_n_layer)])
        # relation embedding for aggregation (forward and inverse relations)
        self.rel_embs = nn.ParameterList([get_param(self.n_rel * 2, self.emb_dim) for _ in range(self.kg_n_layer)])
        # relation embedding for prediction
        if self.args.pred_rel_w:  # e.g. True
            self.rel_w = get_param(self.emb_dim * self.kg_n_layer, self.emb_dim).to(self.device)
        else:
            self.pred_rel_emb = get_param(self.n_rel * 2, self.emb_dim)
        self.predictor = ConvE(self.args)  # (200, 250, 7)
        # embedding dropout (e.g. ent_drop=0.2, rel_drop=0)
        self.ent_drop = nn.Dropout(self.args.ent_drop)
        self.rel_drop = nn.Dropout(self.args.rel_drop)
        self.ent_pred_drop = nn.Dropout(self.args.ent_drop_pred)
        self.act = nn.Tanh()
    def concat(self, head_emb, rela_emb):
        # stack (head, relation) embeddings into a 2D "image" for ConvE
        head_emb = head_emb.view(-1, 1, head_emb.shape[-1])
        rela_emb = rela_emb.view(-1, 1, rela_emb.shape[-1])
        stacked_input = torch.cat([head_emb, rela_emb], 1)
        stacked_input = torch.transpose(stacked_input, 2, 1).reshape((-1, 1, 2 * self.args.k_h, self.args.k_w))
        return stacked_input
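
    # Shape sketch (hypothetical sizes, assuming emb_dim == k_h * k_w,
    # e.g. k_h=10, k_w=20, emb_dim=200):
    #   head_emb: (bs, 200) -> (bs, 1, 200)
    #   rela_emb: (bs, 200) -> (bs, 1, 200)
    #   cat       -> (bs, 2, 200)
    #   transpose -> (bs, 200, 2)
    #   reshape   -> (bs, 1, 2*k_h, k_w) = (bs, 1, 20, 20)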
    def forward(self, h_id, r_id, kg):
        """
        Matching computation between query (h, r) and all candidate answers t.
        :param h_id: head entity id, (bs, )
        :param r_id: relation id, (bs, )
        :param kg: aggregation graph
        :return: matching score, (bs, n_ent)
        """
        # aggregate embeddings over the graph
        kg = kg.to(self.device)
        ent_emb, rel_emb = self.aggregate_emb(kg)
        head = ent_emb[h_id]
        rel = rel_emb[r_id]
        # score every entity as a candidate tail: (bs, n_ent)
        ent_emb = self.ent_pred_drop(ent_emb)
        score = self.predictor.score_func(head, rel, self.concat, ent_emb)
        return score
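
    # Usage sketch (hypothetical ids; assumes `kg` was built by the data
    # pipeline with an integer edge feature 'rel_id'):
    #
    #   model = SEGNN(args)
    #   h_id = torch.tensor([0, 5])
    #   r_id = torch.tensor([1, 3])
    #   score = model(h_id, r_id, kg)  # (2, num_ent)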
    def aggregate_emb(self, kg):
        """
        Aggregate entity and relation embeddings over the graph.
        :param kg: aggregation graph
        :return: entity embeddings (n_ent, emb_dim) and relation embeddings for prediction
        """
        ent_emb = self.ent_emb
        rel_emb_list = []
        for edge_layer, node_layer, comp_layer, rel_emb in zip(self.edge_layers, self.node_layers, self.comp_layers, self.rel_embs):
            ent_emb, rel_emb = self.ent_drop(ent_emb), self.rel_drop(rel_emb)
            ent_emb = ent_emb.to(self.device)
            rel_emb = rel_emb.to(self.device)
            edge_ent_emb = edge_layer(kg, ent_emb, rel_emb)
            node_ent_emb = node_layer(kg, ent_emb)
            comp_ent_emb = comp_layer(kg, ent_emb, rel_emb)
            # residual sum of the three semantic-evidence views
            ent_emb = ent_emb + edge_ent_emb + node_ent_emb + comp_ent_emb
            rel_emb_list.append(rel_emb)
        if self.args.pred_rel_w:
            pred_rel_emb = torch.cat(rel_emb_list, dim=1).to(self.device)
            pred_rel_emb = pred_rel_emb.mm(self.rel_w)
        else:
            pred_rel_emb = self.pred_rel_emb
        return ent_emb, pred_rel_emb
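
    # Shape note: with pred_rel_w, the kg_n_layer per-layer relation tables
    # are concatenated to (2*n_rel, emb_dim * kg_n_layer) and projected back
    # to (2*n_rel, emb_dim) by rel_w.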

class CompLayer(nn.Module):
    def __init__(self, args):
        super(CompLayer, self).__init__()
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.args = args
        self.dataset = self.args.dataset_name
        self.n_ent = self.args.num_ent
        self.n_rel = self.args.num_rel
        self.emb_dim = self.args.emb_dim
        self.comp_op = self.args.comp_op  # 'add' or 'mul'
        assert self.comp_op in ['add', 'mul']
        self.neigh_w = get_param(self.emb_dim, self.emb_dim).to(self.device)
        self.act = nn.Tanh()
        if self.args.bn:
            self.bn = torch.nn.BatchNorm1d(self.emb_dim).to(self.device)
        else:
            self.bn = None
    def forward(self, kg, ent_emb, rel_emb):
        assert kg.number_of_nodes() == ent_emb.shape[0]
        assert rel_emb.shape[0] == 2 * self.n_rel
        ent_emb = ent_emb.to(self.device)
        rel_emb = rel_emb.to(self.device)
        kg = kg.to(self.device)
        with kg.local_scope():
            kg.ndata['emb'] = ent_emb
            rel_id = kg.edata['rel_id']
            kg.edata['emb'] = rel_emb[rel_id]
            # neighbor entity and relation composition
            if self.comp_op == 'add':
                kg.apply_edges(fn.u_add_e('emb', 'emb', 'comp_emb'))
            elif self.comp_op == 'mul':
                kg.apply_edges(fn.u_mul_e('emb', 'emb', 'comp_emb'))
            else:
                raise NotImplementedError
            # attention over incoming edges
            kg.apply_edges(fn.e_dot_v('comp_emb', 'emb', 'norm'))  # (n_edge, 1)
            kg.edata['norm'] = dgl.ops.edge_softmax(kg, kg.edata['norm'])
            # aggregate weighted messages
            kg.edata['comp_emb'] = kg.edata['comp_emb'] * kg.edata['norm']
            kg.update_all(fn.copy_e('comp_emb', 'm'), fn.sum('m', 'neigh'))
            neigh_ent_emb = kg.ndata['neigh']
        neigh_ent_emb = neigh_ent_emb.mm(self.neigh_w)
        if self.bn is not None:
            neigh_ent_emb = self.bn(neigh_ent_emb)
        neigh_ent_emb = self.act(neigh_ent_emb)
        return neigh_ent_emb
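
    # Minimal usage sketch (toy graph; `args` is assumed to carry num_ent,
    # num_rel, emb_dim, comp_op and bn as read in __init__):
    #
    #   g = dgl.graph((torch.tensor([0, 1, 2]), torch.tensor([1, 2, 0])))
    #   g.edata['rel_id'] = torch.tensor([0, 1, 0])  # indices into rel_emb
    #   layer = CompLayer(args)
    #   ent = torch.randn(3, args.emb_dim)
    #   rel = torch.randn(2 * args.num_rel, args.emb_dim)
    #   out = layer(g, ent, rel)  # (3, emb_dim)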

class NodeLayer(nn.Module):
    def __init__(self, args):
        super(NodeLayer, self).__init__()
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.args = args
        self.dataset = self.args.dataset_name
        self.n_ent = self.args.num_ent
        self.n_rel = self.args.num_rel
        self.emb_dim = self.args.emb_dim
        self.neigh_w = get_param(self.emb_dim, self.emb_dim).to(self.device)
        self.act = nn.Tanh()
        if self.args.bn:
            self.bn = torch.nn.BatchNorm1d(self.emb_dim).to(self.device)
        else:
            self.bn = None
    def forward(self, kg, ent_emb):
        assert kg.number_of_nodes() == ent_emb.shape[0]
        kg = kg.to(self.device)
        ent_emb = ent_emb.to(self.device)
        with kg.local_scope():
            kg.ndata['emb'] = ent_emb
            # attention over incoming edges
            kg.apply_edges(fn.u_dot_v('emb', 'emb', 'norm'))  # (n_edge, 1)
            kg.edata['norm'] = dgl.ops.edge_softmax(kg, kg.edata['norm'])
            # aggregate weighted neighbor embeddings
            kg.update_all(fn.u_mul_e('emb', 'norm', 'm'), fn.sum('m', 'neigh'))
            neigh_ent_emb = kg.ndata['neigh']
        neigh_ent_emb = neigh_ent_emb.mm(self.neigh_w)
        if self.bn is not None:
            neigh_ent_emb = self.bn(neigh_ent_emb)
        neigh_ent_emb = self.act(neigh_ent_emb)
        return neigh_ent_emb
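
    # NodeLayer follows the same attention/aggregation skeleton as CompLayer,
    # but scores edges with a plain head-tail dot product (u_dot_v) and
    # aggregates entity embeddings only, with no relation features involved.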

class EdgeLayer(nn.Module):
    def __init__(self, args):
        super(EdgeLayer, self).__init__()
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.args = args
        self.dataset = self.args.dataset_name
        self.n_ent = self.args.num_ent
        self.n_rel = self.args.num_rel
        self.emb_dim = self.args.emb_dim
        self.neigh_w = get_param(self.emb_dim, self.emb_dim).to(self.device)
        self.act = nn.Tanh()
        if self.args.bn:  # e.g. True
            self.bn = torch.nn.BatchNorm1d(self.emb_dim).to(self.device)
        else:
            self.bn = None
    def forward(self, kg, ent_emb, rel_emb):
        assert kg.number_of_nodes() == ent_emb.shape[0]
        assert rel_emb.shape[0] == 2 * self.n_rel
        kg = kg.to(self.device)
        ent_emb = ent_emb.to(self.device)
        rel_emb = rel_emb.to(self.device)
        with kg.local_scope():
            kg.ndata['emb'] = ent_emb
            rel_id = kg.edata['rel_id']
            kg.edata['emb'] = rel_emb[rel_id]
            # attention over incoming edges
            kg.apply_edges(fn.e_dot_v('emb', 'emb', 'norm'))  # (n_edge, 1)
            kg.edata['norm'] = dgl.ops.edge_softmax(kg, kg.edata['norm'])
            # aggregate weighted relation messages
            kg.edata['emb'] = kg.edata['emb'] * kg.edata['norm']
            kg.update_all(fn.copy_e('emb', 'm'), fn.sum('m', 'neigh'))
            neigh_ent_emb = kg.ndata['neigh']
        neigh_ent_emb = neigh_ent_emb.mm(self.neigh_w)
        if self.bn is not None:
            neigh_ent_emb = self.bn(neigh_ent_emb)
        neigh_ent_emb = self.act(neigh_ent_emb)
        return neigh_ent_emb
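
    # Note: the asserts above imply edata['rel_id'] indexes a table of
    # 2 * n_rel relations, presumably the original relations plus their
    # inverses; the exact id layout is determined by the data pipeline.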