import torch.nn as nn
import torch
[docs]class Model(nn.Module):
def __init__(self, args):
super(Model, self).__init__()
[docs] def init_emb(self):
raise NotImplementedError
[docs] def score_func(self, head_emb, relation_emb, tail_emb):
raise NotImplementedError
[docs] def forward(self, triples, negs, mode):
raise NotImplementedError
[docs] def tri2emb(self, triples, negs=None, mode="single"):
"""Get embedding of triples.
This function get the embeddings of head, relation, and tail
respectively. each embedding has three dimensions.
Args:
triples (tensor): This tensor save triples id, which dimension is
[triples number, 3].
negs (tensor, optional): This tenosr store the id of the entity to
be replaced, which has one dimension. when negs is None, it is
in the test/eval phase. Defaults to None.
mode (str, optional): This arg indicates that the negative entity
will replace the head or tail entity. when it is 'single', it
means that entity will not be replaced. Defaults to 'single'.
Returns:
head_emb: Head entity embedding.
relation_emb: Relation embedding.
tail_emb: Tail entity embedding.
"""
if mode == "single":
head_emb = self.ent_emb(triples[:, 0]).unsqueeze(1) # [bs, 1, dim]
relation_emb = self.rel_emb(triples[:, 1]).unsqueeze(1) # [bs, 1, dim]
tail_emb = self.ent_emb(triples[:, 2]).unsqueeze(1) # [bs, 1, dim]
elif mode == "head-batch" or mode == "head_predict":
if negs is None: # 说明这个时候是在evluation,所以需要直接用所有的entity embedding
head_emb = self.ent_emb.weight.data.unsqueeze(0) # [1, num_ent, dim]
else:
head_emb = self.ent_emb(negs) # [bs, num_neg, dim]
relation_emb = self.rel_emb(triples[:, 1]).unsqueeze(1) # [bs, 1, dim]
tail_emb = self.ent_emb(triples[:, 2]).unsqueeze(1) # [bs, 1, dim]
elif mode == "tail-batch" or mode == "tail_predict":
head_emb = self.ent_emb(triples[:, 0]).unsqueeze(1) # [bs, 1, dim]
relation_emb = self.rel_emb(triples[:, 1]).unsqueeze(1) # [bs, 1, dim]
if negs is None:
tail_emb = self.ent_emb.weight.data.unsqueeze(0) # [1, num_ent, dim]
else:
tail_emb = self.ent_emb(negs) # [bs, num_neg, dim]
return head_emb, relation_emb, tail_emb