From 47da1654c3632b7d0f37a7d281790338e6fc197d Mon Sep 17 00:00:00 2001 From: Alberto Cattaneo <84471416+AlCatt91@users.noreply.github.com> Date: Tue, 23 Jan 2024 09:42:05 +0000 Subject: [PATCH] change RotatE relation initialization (#38) --- besskge/embedding.py | 16 ++++++++++++++++ besskge/scoring.py | 3 ++- 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/besskge/embedding.py b/besskge/embedding.py index 86244da..91ed36e 100644 --- a/besskge/embedding.py +++ b/besskge/embedding.py @@ -44,6 +44,22 @@ def init_xavier_norm(embedding_table: torch.Tensor, gain: float = 1.0) -> torch. ) +def init_uniform_rotation(embedding_table: torch.Tensor) -> torch.Tensor: + r""" + Initialize tensor with each entry being a uniformly distributed + phase between 0 and :math:`2 \pi`. + To be used for initialization of relation embedding tables + in RotatE scoring function. + + :param embedding_table: + Tensor of embedding parameters to initialize. + + :return: + Initialized tensor. + """ + return torch.rand_like(embedding_table) * 2 * np.pi + + def init_KGE_uniform( embedding_table: torch.Tensor, b: float = 1.0, divide_by_embedding_size: bool = True ) -> torch.Tensor: diff --git a/besskge/scoring.py b/besskge/scoring.py index 7c0b985..063d4a7 100644 --- a/besskge/scoring.py +++ b/besskge/scoring.py @@ -15,6 +15,7 @@ init_KGE_normal, init_KGE_uniform, init_uniform_norm, + init_uniform_rotation, init_xavier_norm, initialize_entity_embedding, initialize_relation_embedding, @@ -369,7 +370,7 @@ def __init__( init_KGE_uniform ], relation_initializer: Union[torch.Tensor, List[Callable[..., torch.Tensor]]] = [ - init_KGE_uniform + init_uniform_rotation ], inverse_relations: bool = False, ) -> None: