import torch.nn as nn
import torch
from .model import Model
from numpy.random import RandomState
import numpy as np
[docs]class DualE(Model):
"""`Dual Quaternion Knowledge Graph Embeddings`_ (DualE), which introduces dual quaternions into knowledge graph embeddings.
Attributes:
args: Model configuration parameters.
ent_emb: Entity embedding, shape:[num_ent, emb_dim * 8].
rel_emb: Relation embedding, shape:[num_rel, emb_dim * 8].
.. Dual Quaternion Knowledge Graph Embeddings: https://ojs.aaai.org/index.php/AAAI/article/view/16850
"""
def __init__(self, args):
super(DualE, self).__init__(args)
self.args = args
self.ent_emb = nn.Embedding(self.args.num_ent, self.args.emb_dim*8)
self.rel_emb = nn.Embedding(self.args.num_rel, self.args.emb_dim*8)
self.criterion = nn.Softplus()
self.fc = nn.Linear(100, 50, bias=False)
self.ent_dropout = torch.nn.Dropout(0)
self.rel_dropout = torch.nn.Dropout(0)
self.bn = torch.nn.BatchNorm1d(self.args.emb_dim)
self.init_weights()
[docs] def init_weights(self):
r, i, j, k,r_1,i_1,j_1,k_1 = self.quaternion_init(self.args.num_ent, self.args.emb_dim)
r, i, j, k,r_1,i_1,j_1,k_1 = torch.from_numpy(r), torch.from_numpy(i), torch.from_numpy(j), torch.from_numpy(k),\
torch.from_numpy(r_1), torch.from_numpy(i_1), torch.from_numpy(j_1), torch.from_numpy(k_1)
tmp_ent_emb = torch.cat((r, i, j, k,r_1,i_1,j_1,k_1),1)
self.ent_emb.weight.data = tmp_ent_emb.type_as(self.ent_emb.weight.data)
s, x, y, z,s_1,x_1,y_1,z_1 = self.quaternion_init(self.args.num_ent, self.args.emb_dim)
s, x, y, z,s_1,x_1,y_1,z_1 = torch.from_numpy(s), torch.from_numpy(x), torch.from_numpy(y), torch.from_numpy(z), \
torch.from_numpy(s_1), torch.from_numpy(x_1), torch.from_numpy(y_1), torch.from_numpy(z_1)
tmp_rel_emb = torch.cat((s, x, y, z,s_1,x_1,y_1,z_1),1)
self.rel_emb.weight.data = tmp_rel_emb.type_as(self.ent_emb.weight.data)
#Calculate the Dual Hamiltonian product
def _omult(self, a_0, a_1, a_2, a_3, b_0, b_1, b_2, b_3, c_0, c_1, c_2, c_3, d_0, d_1, d_2, d_3):
h_0=a_0*c_0-a_1*c_1-a_2*c_2-a_3*c_3
h1_0=a_0*d_0+b_0*c_0-a_1*d_1-b_1*c_1-a_2*d_2-b_2*c_2-a_3*d_3-b_3*c_3
h_1=a_0*c_1+a_1*c_0+a_2*c_3-a_3*c_2
h1_1=a_0*d_1+b_0*c_1+a_1*d_0+b_1*c_0+a_2*d_3+b_2*c_3-a_3*d_2-b_3*c_2
h_2=a_0*c_2-a_1*c_3+a_2*c_0+a_3*c_1
h1_2=a_0*d_2+b_0*c_2-a_1*d_3-b_1*c_3+a_2*d_0+b_2*c_0+a_3*d_1+b_3*c_1
h_3=a_0*c_3+a_1*c_2-a_2*c_1+a_3*c_0
h1_3=a_0*d_3+b_0*c_3+a_1*d_2+b_1*c_2-a_2*d_1-b_2*c_1+a_3*d_0+b_3*c_0
return (h_0,h_1,h_2,h_3,h1_0,h1_1,h1_2,h1_3)
#Normalization of relationship embedding
def _onorm(self,r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8):
denominator_0 = r_1 ** 2 + r_2 ** 2 + r_3 ** 2 + r_4 ** 2
denominator_1 = torch.sqrt(denominator_0)
deno_cross = r_5 * r_1 + r_6 * r_2 + r_7 * r_3 + r_8 * r_4
r_5 = r_5 - deno_cross / denominator_0 * r_1
r_6 = r_6 - deno_cross / denominator_0 * r_2
r_7 = r_7 - deno_cross / denominator_0 * r_3
r_8 = r_8 - deno_cross / denominator_0 * r_4
r_1 = r_1 / denominator_1
r_2 = r_2 / denominator_1
r_3 = r_3 / denominator_1
r_4 = r_4 / denominator_1
return r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8
#Calculate the inner product of the head entity and the relationship Hamiltonian product and the tail entity
[docs] def score_func(self, head_emb, relation_emb, tail_emb, mode):
"""Calculating the score of triples.
The formula for calculating the score is :math:` <\boldsymbol{Q}_h \otimes \boldsymbol{W}_r^{\diamond}, \boldsymbol{Q}_t> `
Args:
head_emb: The head entity embedding with 8 dimensionalities.
relation_emb: The relation embedding with 8 dimensionalities.
tail_emb: The tail entity embedding with 8 dimensionalities.
mode: Choose head-predict or tail-predict.
Returns:
score: The score of triples with regul_1 and regul_2
"""
e_1_h,e_2_h,e_3_h,e_4_h,e_5_h,e_6_h,e_7_h,e_8_h = torch.chunk(head_emb, 8, dim=-1)
e_1_t,e_2_t,e_3_t,e_4_t,e_5_t,e_6_t,e_7_t,e_8_t = torch.chunk(tail_emb, 8, dim=-1)
r_1,r_2,r_3,r_4,r_5,r_6,r_7,r_8 = torch.chunk(relation_emb, 8, dim=-1)
r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8 = self._onorm(r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8 )
o_1, o_2, o_3, o_4, o_5, o_6, o_7, o_8 = self._omult(e_1_h, e_2_h, e_3_h, e_4_h, e_5_h, e_6_h, e_7_h, e_8_h,
r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8)
score_r = (o_1 * e_1_t + o_2 * e_2_t + o_3 * e_3_t + o_4 * e_4_t
+ o_5 * e_5_t + o_6 * e_6_t + o_7 * e_7_t + o_8 * e_8_t)
regul_1 = (torch.mean(torch.abs(e_1_h) ** 2)
+ torch.mean(torch.abs(e_2_h) ** 2)
+ torch.mean(torch.abs(e_3_h) ** 2)
+ torch.mean(torch.abs(e_4_h) ** 2)
+ torch.mean(torch.abs(e_5_h) ** 2)
+ torch.mean(torch.abs(e_6_h) ** 2)
+ torch.mean(torch.abs(e_7_h) ** 2)
+ torch.mean(torch.abs(e_8_h) ** 2)
+ torch.mean(torch.abs(e_1_t) ** 2)
+ torch.mean(torch.abs(e_2_t) ** 2)
+ torch.mean(torch.abs(e_3_t) ** 2)
+ torch.mean(torch.abs(e_4_t) ** 2)
+ torch.mean(torch.abs(e_5_t) ** 2)
+ torch.mean(torch.abs(e_6_t) ** 2)
+ torch.mean(torch.abs(e_7_t) ** 2)
+ torch.mean(torch.abs(e_8_t) ** 2)
)
regul_2 = (torch.mean(torch.abs(r_1) ** 2)
+ torch.mean(torch.abs(r_2) ** 2)
+ torch.mean(torch.abs(r_3) ** 2)
+ torch.mean(torch.abs(r_4) ** 2)
+ torch.mean(torch.abs(r_5) ** 2)
+ torch.mean(torch.abs(r_6) ** 2)
+ torch.mean(torch.abs(r_7) ** 2)
+ torch.mean(torch.abs(r_8) ** 2))
return (torch.sum(score_r, -1), regul_1, regul_2)
[docs] def forward(self, triples, negs=None, mode='single'):
if negs != None:
head_emb, relation_emb, tail_emb = self.tri2emb(negs)
else:
head_emb, relation_emb, tail_emb = self.tri2emb(triples)
score, regul_1, regul_2 = self.score_func(head_emb, relation_emb, tail_emb, mode)
return (score, regul_1, regul_2)
[docs] def get_score(self, batch, mode):
"""The functions used in the testing phase
Args:
batch: A batch of data.
mode: Choose head-predict or tail-predict.
Returns:
score: The score of triples.
"""
triples = batch["positive_sample"]
head_emb, relation_emb, tail_emb = self.tri2emb(triples, mode=mode)
e_1_h,e_2_h,e_3_h,e_4_h,e_5_h,e_6_h,e_7_h,e_8_h = torch.chunk(head_emb, 8, dim=-1)
e_1_t,e_2_t,e_3_t,e_4_t,e_5_t,e_6_t,e_7_t,e_8_t = torch.chunk(tail_emb, 8, dim=-1)
r_1,r_2,r_3,r_4,r_5,r_6,r_7,r_8 = torch.chunk(relation_emb, 8, dim=-1)
r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8 = self._onorm(r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8 )
o_1, o_2, o_3, o_4, o_5, o_6, o_7, o_8 = self._omult(e_1_h, e_2_h, e_3_h, e_4_h, e_5_h, e_6_h, e_7_h, e_8_h,
r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8)
score_r = (o_1 * e_1_t + o_2 * e_2_t + o_3 * e_3_t + o_4 * e_4_t
+ o_5 * e_5_t + o_6 * e_6_t + o_7 * e_7_t + o_8 * e_8_t)
return torch.sum(score_r, -1)
[docs] def quaternion_init(self, in_features, out_features, criterion='he'):
"""
Quaternion-valued weight initialization
the initialization scheme is optional on these four datasets,
random initialization can get the same performance. This initialization
scheme might be useful for the case which needs fewer epochs.
"""
fan_in = in_features
fan_out = out_features
if criterion == 'glorot':
s = 1. / np.sqrt(2 * (fan_in + fan_out))
elif criterion == 'he':
s = 1. / np.sqrt(2 * fan_in)
else:
raise ValueError('Invalid criterion: ', criterion)
rng = RandomState(2020)
# Generating randoms and purely imaginary quaternions :
kernel_shape = (in_features, out_features)
number_of_weights = np.prod(kernel_shape) # in_features*out_features
v_i = np.random.uniform(0.0, 1.0, number_of_weights) #(low,high,size)
v_j = np.random.uniform(0.0, 1.0, number_of_weights)
v_k = np.random.uniform(0.0, 1.0, number_of_weights)
# Purely imaginary quaternions unitary
for i in range(0, number_of_weights):
norm = np.sqrt(v_i[i] ** 2 + v_j[i] ** 2 + v_k[i] ** 2) + 0.0001
v_i[i] /= norm
v_j[i] /= norm
v_k[i] /= norm
v_i = v_i.reshape(kernel_shape)
v_j = v_j.reshape(kernel_shape)
v_k = v_k.reshape(kernel_shape)
modulus = rng.uniform(low=-s, high=s, size=kernel_shape)
# Calculate the three parts about t
kernel_shape1 = (in_features, out_features)
number_of_weights1 = np.prod(kernel_shape1)
t_i = np.random.uniform(0.0, 1.0, number_of_weights1)
t_j = np.random.uniform(0.0, 1.0, number_of_weights1)
t_k = np.random.uniform(0.0, 1.0, number_of_weights1)
# Purely imaginary quaternions unitary
for i in range(0, number_of_weights1):
norm1 = np.sqrt(t_i[i] ** 2 + t_j[i] ** 2 + t_k[i] ** 2) + 0.0001
t_i[i] /= norm1
t_j[i] /= norm1
t_k[i] /= norm1
t_i = t_i.reshape(kernel_shape1)
t_j = t_j.reshape(kernel_shape1)
t_k = t_k.reshape(kernel_shape1)
tmp_t = rng.uniform(low=-s, high=s, size=kernel_shape1)
phase = rng.uniform(low=-np.pi, high=np.pi, size=kernel_shape)
phase1 = rng.uniform(low=-np.pi, high=np.pi, size=kernel_shape1)
weight_r = modulus * np.cos(phase)
weight_i = modulus * v_i * np.sin(phase)
weight_j = modulus * v_j * np.sin(phase)
weight_k = modulus * v_k * np.sin(phase)
wt_i = tmp_t * t_i * np.sin(phase1)
wt_j = tmp_t * t_j * np.sin(phase1)
wt_k = tmp_t * t_k * np.sin(phase1)
i_0=weight_r
i_1=weight_i
i_2=weight_j
i_3=weight_k
i_4=(-wt_i*weight_i-wt_j*weight_j-wt_k*weight_k)/2
i_5=(wt_i*weight_r+wt_j*weight_k-wt_k*weight_j)/2
i_6=(-wt_i*weight_k+wt_j*weight_r+wt_k*weight_i)/2
i_7=(wt_i*weight_j-wt_j*weight_i+wt_k*weight_r)/2
return (i_0,i_1,i_2,i_3,i_4,i_5,i_6,i_7)