
Commit

address PR comments
AlCatt91 committed Nov 8, 2023
1 parent 0a3af53 commit 7d5bfc1
Showing 1 changed file with 16 additions and 39 deletions.
besskge/scoring.py: 55 changes (16 additions & 39 deletions)
@@ -965,8 +965,8 @@ def __init__(
         sharding: Sharding,
         n_relation_type: int,
         embedding_size: int,
-        embedding_width: int,
         embedding_height: int,
+        embedding_width: int,
         entity_initializer: Union[torch.Tensor, List[Callable[..., torch.Tensor]]] = [
             init_xavier_norm,
             torch.nn.init.zeros_,
@@ -977,8 +977,8 @@ def __init__(
         inverse_relations: bool = True,
         input_channels: int = 1,
         output_channels: int = 32,
-        kernel_width: int = 3,
         kernel_height: int = 3,
+        kernel_width: int = 3,
         input_dropout: float = 0.2,
         feature_map_dropout: float = 0.2,
         hidden_dropout: float = 0.3,
@@ -995,12 +995,12 @@ def __init__(
             Number of relation types in the knowledge graph.
         :param embedding_size:
             Size of entity and relation embeddings.
-        :param embedding_width:
-            Width of the 2D-reshaping of the concatenation of
-            head and relation embeddings.
         :param embedding_height:
             Height of the 2D-reshaping of the concatenation of
             head and relation embeddings.
+        :param embedding_width:
+            Width of the 2D-reshaping of the concatenation of
+            head and relation embeddings.
         :param entity_initializer:
             Initialization functions or table for entity embeddings.
             If not passing a table, two functions are needed: the initializer
@@ -1013,10 +1013,10 @@ def __init__(
             Number of input channels of the Conv2D operator. Default: 1.
         :param output_channels:
             Number of output channels of the Conv2D operator. Default: 32.
-        :param kernel_width:
-            Width of the Conv2D kernel. Default: 3.
         :param kernel_height:
             Height of the Conv2D kernel. Default: 3.
+        :param kernel_width:
+            Width of the Conv2D kernel. Default: 3.
         :param input_dropout:
             Rate of Dropout applied before the convolution. Default: 0.2.
         :param feature_map_dropout:
@@ -1056,23 +1056,23 @@ def __init__(
         self.embedding_size = embedding_size

         self.inp_channels = input_channels
-        self.emb_w = embedding_width
         self.emb_h = embedding_height
+        self.emb_w = embedding_width
         conv_layers = [
             torch.nn.Dropout(input_dropout),
             torch.nn.Conv2d(
                 in_channels=self.inp_channels,
                 out_channels=output_channels,
-                kernel_size=(kernel_width, kernel_height),
+                kernel_size=(kernel_height, kernel_width),
             ),
             torch.nn.ReLU(),
             torch.nn.Dropout2d(feature_map_dropout),
         ]
         fc_layers = [
             torch.nn.Linear(
                 output_channels
-                * (2 * self.emb_w - kernel_width + 1)
-                * (self.emb_h - kernel_height + 1),
+                * (2 * self.emb_h - kernel_height + 1)
+                * (self.emb_w - kernel_width + 1),
                 embedding_size,
             ),
             torch.nn.Dropout(hidden_dropout),
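The reordered Linear input size follows from the standard Conv2d shape rule: with stride 1 and no padding, H_out = H_in - kernel_height + 1 and W_out = W_in - kernel_width + 1, applied to the stacked (2 * emb_h, emb_w) input. A minimal sketch checking this arithmetic; the embedding sizes are illustrative assumptions, not the library's defaults:

```python
import torch

# Hypothetical sizes for illustration only
emb_h, emb_w = 10, 20          # embedding_size = emb_h * emb_w = 200
kernel_h, kernel_w = 3, 3      # matches the Default: 3 above
in_ch, out_ch = 1, 32          # matches the defaults above
batch = 4

# PyTorch's Conv2d kernel_size is (height, width), hence the swap in the diff
conv = torch.nn.Conv2d(in_ch, out_ch, kernel_size=(kernel_h, kernel_w))
# Stacking head and relation along the height axis doubles H: (B, C, 2*H, W)
x = torch.randn(batch, in_ch, 2 * emb_h, emb_w)
y = conv(x)
# Stride 1, no padding: H_out = H_in - kH + 1, W_out = W_in - kW + 1
assert y.shape == (batch, out_ch, 2 * emb_h - kernel_h + 1, emb_w - kernel_w + 1)
# Exactly the in_features used by the Linear layer in the diff
in_features = out_ch * (2 * emb_h - kernel_h + 1) * (emb_w - kernel_w + 1)
assert y.flatten(start_dim=1).shape[-1] == in_features
```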
@@ -1100,8 +1100,8 @@ def score_triple(
         tail_emb, tail_bias = torch.split(tail_emb, self.embedding_size, dim=-1)
         hr_cat = torch.cat(
             [
-                head_emb.view(-1, self.inp_channels, self.emb_w, self.emb_h),
-                relation_emb.view(-1, self.inp_channels, self.emb_w, self.emb_h),
+                head_emb.view(-1, self.inp_channels, self.emb_h, self.emb_w),
+                relation_emb.view(-1, self.inp_channels, self.emb_h, self.emb_w),
             ],
             dim=-2,
         )
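The corrected view order matches PyTorch's NCHW convention: each flat embedding is viewed as a (height, width) map, and concatenating along dim=-2 stacks the head map above the relation map, doubling the height. A self-contained sketch of just this reshaping, with toy sizes assumed:

```python
import torch

# Toy shapes, assumed for illustration only
batch, in_ch, emb_h, emb_w = 2, 1, 4, 5
head_emb = torch.randn(batch, emb_h * emb_w)
relation_emb = torch.randn(batch, emb_h * emb_w)

# Conv2d expects NCHW, so each flat embedding becomes (batch, C, H, W)
hr_cat = torch.cat(
    [
        head_emb.view(-1, in_ch, emb_h, emb_w),
        relation_emb.view(-1, in_ch, emb_h, emb_w),
    ],
    dim=-2,  # height axis: head stacked above relation
)
assert hr_cat.shape == (batch, in_ch, 2 * emb_h, emb_w)
```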
@@ -1115,30 +1115,7 @@ def score_heads(
         relation_id: torch.Tensor,
         tail_emb: torch.Tensor,
     ) -> torch.Tensor:
-        relation_emb = torch.index_select(
-            self.relation_embedding, index=relation_id, dim=0
-        )
-        # Discard bias for heads
-        head_emb = head_emb[..., :-1]
-        tail_emb, tail_bias = torch.split(tail_emb, self.embedding_size, dim=-1)
-        if self.negative_sample_sharing:
-            head_emb = head_emb.view(1, -1, self.embedding_size)
-        broadcast_dim = (relation_emb.shape[0], head_emb.shape[1], self.embedding_size)
-        hr_cat = torch.cat(
-            [
-                head_emb.expand(*broadcast_dim).view(
-                    -1, self.inp_channels, self.emb_w, self.emb_h
-                ),
-                relation_emb.unsqueeze(1)
-                .expand(*broadcast_dim)
-                .view(-1, self.inp_channels, self.emb_w, self.emb_h),
-            ],
-            dim=-2,
-        )
-        hr_conv = self.fc_layers(self.conv_layers(hr_cat).flatten(start_dim=1)).view(
-            *broadcast_dim
-        )
-        return self.reduce_embedding(hr_conv * tail_emb.unsqueeze(1)) + tail_bias
+        raise NotImplementedError("ConvE should not be used with head corruption")

     # docstr-coverage: inherited
     def score_tails(
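Replacing the broadcast implementation with NotImplementedError reflects a cost asymmetry in ConvE: the convolution couples head and relation, so one conv pass per (head, relation) pair can be reused against every candidate tail, while scoring candidate heads needs a separate pass per candidate, which is what the deleted expand-based code paid for. A rough sketch of the asymmetry; shapes are assumed and this is not the BESS-KGE API:

```python
import torch

# Assumed sizes for illustration
batch, n_cand, emb = 8, 1000, 200
hr_conv = torch.randn(batch, emb)     # conv + fc output, one pass per (h, r)
cand_tails = torch.randn(n_cand, emb)

# Tail corruption: one matmul scores all candidates against the shared hr_conv
tail_scores = hr_conv @ cand_tails.T  # (batch, n_cand)
assert tail_scores.shape == (batch, n_cand)
# Head corruption would instead require batch * n_cand convolution passes,
# since each candidate head changes the conv input; hence the error above.
```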
@@ -1159,8 +1136,8 @@ def score_tails(
         tail_bias = tail_bias.squeeze(-1)
         hr_cat = torch.cat(
             [
-                head_emb.view(-1, self.inp_channels, self.emb_w, self.emb_h),
-                relation_emb.view(-1, self.inp_channels, self.emb_w, self.emb_h),
+                head_emb.view(-1, self.inp_channels, self.emb_h, self.emb_w),
+                relation_emb.view(-1, self.inp_channels, self.emb_h, self.emb_w),
             ],
             dim=-2,
         )
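Putting the pieces together, here is a minimal end-to-end sketch of the corrected scoring path with the (height, width) ordering used throughout. Names and sizes are illustrative assumptions, the tail bias term is omitted, and this is not the BESS-KGE class itself:

```python
import torch

# Assumed sizes for illustration
emb_h, emb_w, out_ch, k_h, k_w = 10, 20, 32, 3, 3
emb_size = emb_h * emb_w

conv_layers = torch.nn.Sequential(
    torch.nn.Dropout(0.2),
    torch.nn.Conv2d(1, out_ch, kernel_size=(k_h, k_w)),  # (height, width)
    torch.nn.ReLU(),
    torch.nn.Dropout2d(0.2),
)
fc_layers = torch.nn.Sequential(
    torch.nn.Linear(out_ch * (2 * emb_h - k_h + 1) * (emb_w - k_w + 1), emb_size),
    torch.nn.Dropout(0.3),
)

batch = 4
head = torch.randn(batch, emb_size)
rel = torch.randn(batch, emb_size)
tail = torch.randn(batch, emb_size)

# Reshape to NCHW maps and stack head above relation along the height axis
hr_cat = torch.cat(
    [head.view(-1, 1, emb_h, emb_w), rel.view(-1, 1, emb_h, emb_w)], dim=-2
)
hr_conv = fc_layers(conv_layers(hr_cat).flatten(start_dim=1))
scores = (hr_conv * tail).sum(-1)  # one score per triple
assert scores.shape == (batch,)
```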
