Add ConvE scoring function #35

Merged
merged 2 commits on Nov 8, 2023
18 changes: 18 additions & 0 deletions besskge/embedding.py
@@ -26,6 +26,24 @@ def init_uniform_norm(embedding_table: torch.Tensor) -> torch.Tensor:
return torch.nn.functional.normalize(torch.nn.init.uniform(embedding_table), dim=-1)


def init_xavier_norm(embedding_table: torch.Tensor, gain: float = 1.0) -> torch.Tensor:
"""
    Initialize embeddings according to the Xavier normal scheme, with
    `fan_in = 0` and `fan_out = row_size`.

:param embedding_table:
Tensor of embedding parameters to initialize.
:param gain:
Scaling factor for standard deviation. Default: 1.0.

:return:
Initialized tensor.
"""
return torch.nn.init.normal_(
embedding_table, std=gain * np.sqrt(2.0 / embedding_table.shape[-1])
)


def init_KGE_uniform(
embedding_table: torch.Tensor, b: float = 1.0, divide_by_embedding_size: bool = True
) -> torch.Tensor:
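For reference, a minimal sketch of how the new initializer could be called once this change is merged (the table shape is illustrative; `init_xavier_norm` modifies the tensor in place via `torch.nn.init.normal_` and also returns it):

import torch
from besskge.embedding import init_xavier_norm

# Illustrative shape: 1000 entities, embedding row size 256.
table = torch.empty(1000, 256)
init_xavier_norm(table)

# With gain = 1.0 the expected standard deviation is
# sqrt(2 / row_size) = sqrt(2 / 256) ~= 0.088.
print(table.std())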
219 changes: 210 additions & 9 deletions besskge/scoring.py
@@ -15,6 +15,7 @@
init_KGE_normal,
init_KGE_uniform,
init_uniform_norm,
init_xavier_norm,
initialize_entity_embedding,
initialize_relation_embedding,
refactor_embedding_sharding,
@@ -291,7 +292,7 @@ def __init__(
:param relation_initializer:
Initialization function or table for relation embeddings.
:param inverse_relations:
-        If True, learn embeddings for inverse relations. Default: False
+        If True, learn embeddings for inverse relations. Default: False.
"""
super(TransE, self).__init__(
negative_sample_sharing=negative_sample_sharing, scoring_norm=scoring_norm
@@ -391,7 +392,7 @@ def __init__(
:param relation_initializer:
Initialization function or table for relation embeddings.
:param inverse_relations:
-        If True, learn embeddings for inverse relations. Default: False
+        If True, learn embeddings for inverse relations. Default: False.
"""
super(RotatE, self).__init__(
negative_sample_sharing=negative_sample_sharing, scoring_norm=scoring_norm
@@ -502,7 +503,7 @@ def __init__(
If True, L2-normalize head and tail entity embeddings before projecting,
as in :cite:p:`PairRE`. Default: True.
:param inverse_relations:
-        If True, learn embeddings for inverse relations. Default: False
+        If True, learn embeddings for inverse relations. Default: False.
"""
super(PairRE, self).__init__(
negative_sample_sharing=negative_sample_sharing, scoring_norm=scoring_norm
@@ -637,7 +638,7 @@ def __init__(
Offset factor for head/tail relation projections, as in TripleREv2.
Default: 0.0 (no offset).
:param inverse_relations:
-        If True, learn embeddings for inverse relations. Default: False
+        If True, learn embeddings for inverse relations. Default: False.
"""
super(TripleRE, self).__init__(
negative_sample_sharing=negative_sample_sharing, scoring_norm=scoring_norm
@@ -776,7 +777,7 @@ def __init__(
:param relation_initializer:
Initialization function or table for relation embeddings.
:param inverse_relations:
-        If True, learn embeddings for inverse relations. Default: False
+        If True, learn embeddings for inverse relations. Default: False.
"""
super(DistMult, self).__init__(negative_sample_sharing=negative_sample_sharing)

@@ -870,7 +871,7 @@ def __init__(
:param relation_initializer:
Initialization function or table for relation embeddings.
:param inverse_relations:
-        If True, learn embeddings for inverse relations. Default: False
+        If True, learn embeddings for inverse relations. Default: False.
"""
super(ComplEx, self).__init__(negative_sample_sharing=negative_sample_sharing)

@@ -944,6 +945,206 @@ def score_tails(
)


class ConvE(MatrixDecompositionScoreFunction):
"""
ConvE scoring function :cite:p:`ConvE`.

    Note that, unlike in :cite:p:`ConvE`, the scores returned by this class
    are not passed through a final sigmoid layer, as we assume that this is
    included in the loss function.

By design, this scoring function should be used in combination with a
negative/candidate sampler that only corrupts tails (possibly after
including all inverse triples in the dataset, see the `add_inverse_triples`
argument in :func:`besskge.sharding.PartitionedTripleSet.create_from_dataset`).
"""

def __init__(
self,
negative_sample_sharing: bool,
sharding: Sharding,
n_relation_type: int,
embedding_size: int,
embedding_height: int,
embedding_width: int,
entity_initializer: Union[torch.Tensor, List[Callable[..., torch.Tensor]]] = [
init_xavier_norm,
torch.nn.init.zeros_,
],
relation_initializer: Union[torch.Tensor, List[Callable[..., torch.Tensor]]] = [
init_xavier_norm,
],
inverse_relations: bool = True,
input_channels: int = 1,
output_channels: int = 32,
kernel_height: int = 3,
kernel_width: int = 3,
input_dropout: float = 0.2,
feature_map_dropout: float = 0.2,
hidden_dropout: float = 0.3,
batch_normalization: bool = True,
) -> None:
"""
Initialize ConvE model.

:param negative_sample_sharing:
see :meth:`DistanceBasedScoreFunction.__init__`
:param sharding:
Entity sharding.
:param n_relation_type:
Number of relation types in the knowledge graph.
:param embedding_size:
Size of entity and relation embeddings.
:param embedding_height:
Height of the 2D-reshaping of the concatenation of
head and relation embeddings.
:param embedding_width:
Width of the 2D-reshaping of the concatenation of
head and relation embeddings.
:param entity_initializer:
Initialization functions or table for entity embeddings.
            If not passing a table, two functions are needed: the initializer
            for entity embeddings and the initializer for the (scalar) tail biases.
:param relation_initializer:
Initialization function or table for relation embeddings.
:param inverse_relations:
If True, learn embeddings for inverse relations. Default: True.
:param input_channels:
Number of input channels of the Conv2D operator. Default: 1.
:param output_channels:
Number of output channels of the Conv2D operator. Default: 32.
:param kernel_height:
Height of the Conv2D kernel. Default: 3.
:param kernel_width:
Width of the Conv2D kernel. Default: 3.
:param input_dropout:
Rate of Dropout applied before the convolution. Default: 0.2.
:param feature_map_dropout:
Rate of Dropout applied after the convolution. Default: 0.2.
:param hidden_dropout:
Rate of Dropout applied after the Linear layer. Default: 0.3.
:param batch_normalization:
If True, apply batch normalization before and after the
convolution and after the Linear layer. Default: True.
"""
super(ConvE, self).__init__(negative_sample_sharing=negative_sample_sharing)

self.sharding = sharding

if input_channels * embedding_width * embedding_height != embedding_size:
raise ValueError(
"`embedding_size` needs to be equal to"
" `input_channels * embedding_width * embedding_height`"
)

        # self.entity_embedding[..., :embedding_size] stores the entity embeddings,
        # self.entity_embedding[..., -1] the (scalar) tail biases
self.entity_embedding = initialize_entity_embedding(
self.sharding, entity_initializer, [embedding_size, 1]
)
self.relation_embedding = initialize_relation_embedding(
n_relation_type, inverse_relations, relation_initializer, [embedding_size]
)
assert (
self.entity_embedding.shape[-1] - 1
== self.relation_embedding.shape[-1]
== embedding_size
), (
"ConvE requires `embedding_size + 1` embedding parameters for each entity"
" and `embedding_size` embedding parameters for each relation"
)
self.embedding_size = embedding_size

self.inp_channels = input_channels
self.emb_h = embedding_height
self.emb_w = embedding_width
conv_layers = [
torch.nn.Dropout(input_dropout),
torch.nn.Conv2d(
in_channels=self.inp_channels,
out_channels=output_channels,
kernel_size=(kernel_height, kernel_width),
),
torch.nn.ReLU(),
torch.nn.Dropout2d(feature_map_dropout),
]
fc_layers = [
torch.nn.Linear(
output_channels
* (2 * self.emb_h - kernel_height + 1)
* (self.emb_w - kernel_width + 1),
embedding_size,
),
torch.nn.Dropout(hidden_dropout),
torch.nn.ReLU(),
]
if batch_normalization:
conv_layers.insert(0, torch.nn.BatchNorm2d(input_channels))
conv_layers.insert(3, torch.nn.BatchNorm2d(output_channels))
fc_layers.insert(2, torch.nn.BatchNorm1d(embedding_size))
self.conv_layers = torch.nn.Sequential(*conv_layers)
self.fc_layers = torch.nn.Sequential(*fc_layers)

# docstr-coverage: inherited
def score_triple(
self,
head_emb: torch.Tensor,
relation_id: torch.Tensor,
tail_emb: torch.Tensor,
) -> torch.Tensor:
relation_emb = torch.index_select(
self.relation_embedding, index=relation_id, dim=0
)
# Discard bias for heads
head_emb = head_emb[..., :-1]
tail_emb, tail_bias = torch.split(tail_emb, self.embedding_size, dim=-1)
hr_cat = torch.cat(
[
head_emb.view(-1, self.inp_channels, self.emb_h, self.emb_w),
relation_emb.view(-1, self.inp_channels, self.emb_h, self.emb_w),
],
dim=-2,
)
hr_conv = self.fc_layers(self.conv_layers(hr_cat).flatten(start_dim=1))
return self.reduce_embedding(hr_conv * tail_emb) + tail_bias.squeeze(-1)

# docstr-coverage: inherited
def score_heads(
self,
head_emb: torch.Tensor,
relation_id: torch.Tensor,
tail_emb: torch.Tensor,
) -> torch.Tensor:
raise NotImplementedError("ConvE should not be used with head corruption")

# docstr-coverage: inherited
def score_tails(
self,
head_emb: torch.Tensor,
relation_id: torch.Tensor,
tail_emb: torch.Tensor,
) -> torch.Tensor:
relation_emb = torch.index_select(
self.relation_embedding, index=relation_id, dim=0
)
# Discard bias for heads
head_emb = head_emb[..., :-1]
tail_emb, tail_bias = torch.split(tail_emb, self.embedding_size, dim=-1)
if self.negative_sample_sharing:
tail_bias = tail_bias.view(1, -1)
else:
tail_bias = tail_bias.squeeze(-1)
hr_cat = torch.cat(
[
head_emb.view(-1, self.inp_channels, self.emb_h, self.emb_w),
relation_emb.view(-1, self.inp_channels, self.emb_h, self.emb_w),
],
dim=-2,
)
hr_conv = self.fc_layers(self.conv_layers(hr_cat).flatten(start_dim=1))
return self.broadcasted_dot_product(hr_conv, tail_emb) + tail_bias
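
As a sanity check on the `Linear` layer's input size defined in `__init__` above, the following self-contained PyTorch sketch (illustrative sizes only, not part of this PR) mirrors the reshape-concatenate-convolve-project pipeline of `score_triple`, without dropout, batch normalization, or the tail bias:

import torch

# Illustrative sizes: embedding_size = 200 reshaped as 1 x 10 x 20,
# with the default 3x3 kernel and 32 output channels.
embedding_size, emb_h, emb_w = 200, 10, 20
in_ch, out_ch, kh, kw = 1, 32, 3, 3

head = torch.randn(4, embedding_size)   # batch of 4 head embeddings
rel = torch.randn(4, embedding_size)    # matching relation embeddings
tail = torch.randn(4, embedding_size)   # matching tail embeddings

# Stack head and relation embeddings into a 2D "image" of height 2 * emb_h.
hr_cat = torch.cat(
    [head.view(-1, in_ch, emb_h, emb_w), rel.view(-1, in_ch, emb_h, emb_w)],
    dim=-2,
)  # shape (4, 1, 20, 20)

conv = torch.nn.Conv2d(in_ch, out_ch, (kh, kw))
fc_in = out_ch * (2 * emb_h - kh + 1) * (emb_w - kw + 1)  # 32 * 18 * 18
fc = torch.nn.Linear(fc_in, embedding_size)

hr_conv = fc(conv(hr_cat).flatten(start_dim=1))  # shape (4, embedding_size)
scores = (hr_conv * tail).sum(-1)                # one score per (h, r, t) triple
print(hr_cat.shape, fc_in, scores.shape)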


class BoxE(DistanceBasedScoreFunction):
"""
BoxE scoring function :cite:p:`BoxE`.
@@ -1000,7 +1201,7 @@ def __init__(
Softening parameter for geometric normalization of box widths.
Default: 1e-6.
:param inverse_relations:
-        If True, learn embeddings for inverse relations. Default: False
+        If True, learn embeddings for inverse relations. Default: False.
"""
super(BoxE, self).__init__(
negative_sample_sharing=negative_sample_sharing, scoring_norm=scoring_norm
@@ -1258,7 +1459,7 @@ def __init__(
:param offset:
Offset applied to auxiliary entity embeddings. Default: 1.0.
:param inverse_relations:
-        If True, learn embeddings for inverse relations. Default: False
+        If True, learn embeddings for inverse relations. Default: False.
"""
super(InterHT, self).__init__(
negative_sample_sharing=negative_sample_sharing, scoring_norm=scoring_norm
@@ -1415,7 +1616,7 @@ def __init__(
:param offset:
Offset applied to tilde entity embeddings. Default: 1.0.
:param inverse_relations:
-        If True, learn embeddings for inverse relations. Default: False
+        If True, learn embeddings for inverse relations. Default: False.
"""
super(TranS, self).__init__(
negative_sample_sharing=negative_sample_sharing, scoring_norm=scoring_norm
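
Because ConvE only implements tail scoring (`score_heads` raises `NotImplementedError`), head-prediction queries have to be rewritten as tail-prediction queries over inverse relations before scoring. In BESS-KGE this is handled by the `add_inverse_triples` argument mentioned in the ConvE docstring; the sketch below only illustrates the idea, and the convention of encoding the inverse of relation `r` as `r + n_relation_type` is an assumption made here for illustration:

import torch

n_relation_type = 5
triples = torch.tensor([[0, 2, 7], [3, 4, 1]])  # rows of (head, relation, tail) ids

# Add the inverse triple (t, r + n_relation_type, h) for every (h, r, t).
inverse = triples[:, [2, 1, 0]].clone()
inverse[:, 1] += n_relation_type
all_triples = torch.cat([triples, inverse], dim=0)

# A head-prediction query (?, r, t) can now be scored as the tail-prediction
# query (t, r + n_relation_type, ?), so only tails ever need to be corrupted.
print(all_triples)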