Skip to content

Commit

Permalink
Changes made to address issue graphnet-team#753
Browse files Browse the repository at this point in the history
  • Loading branch information
mobra7 committed Oct 24, 2024
1 parent 6309445 commit 3eb2d99
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 20 deletions.
44 changes: 27 additions & 17 deletions src/graphnet/models/components/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,9 @@ class FourierEncoder(LightningModule):
This module incorporates sinusoidal positional embeddings and auxiliary
embeddings to process input sequences and produce meaningful
representations. The module assumes that the input data is in the format of
(x, y, z, time, charge, auxiliary), being the first four features
mandatory.
representations. The features x, y, z and time are mandatory, while charge
and auxiliary are optional. Please use the mapping to ensure correct fourier
encoding.
"""

def __init__(
Expand All @@ -66,7 +66,7 @@ def __init__(
mlp_dim: Optional[int] = None,
output_dim: int = 384,
scaled: bool = False,
n_features: int = 6,
mapping: list = [0, 1, 2, 3, 4, 5],
):
"""Construct `FourierEncoder`.
Expand All @@ -79,23 +79,29 @@ def __init__(
depending on `n_features`.
output_dim: Dimension of the output (I.e. number of columns).
scaled: Whether or not to scale the embeddings.
n_features: The number of features in the input data.
mapping: Mapping of the data to [x,y,z,time,charge,auxiliary]. Use None for missing features.
"""
super().__init__()

self.mapping_str = ["x", "y", "z", "time", "charge", "auxiliary"]
self.mapping = mapping
self.n_features = len([i for i in mapping if i is not None])
self.sin_emb = SinusoidalPosEmb(dim=seq_length, scaled=scaled)
self.sin_emb2 = SinusoidalPosEmb(dim=seq_length // 2, scaled=scaled)

if n_features < 4:
assert len(mapping) == 6, "Fourier mapping must have 6 elements. Use None for missing features."
assert all([isinstance(i, int) or i is None for i in mapping]), "Use int or None in fourier mapping."

if any([i is None for i in mapping[:4]]):
missing = [self.mapping_str[i] for i in range(4) if mapping[i] is None]
raise ValueError(
f"At least x, y, z and time of the DOM are required. Got only "
f"{n_features} features."
f"x, y, z and time of the DOM are required."
f"{missing} missing in mapping."
)
elif n_features >= 6:
elif self.n_features == 6:
self.aux_emb = nn.Embedding(2, seq_length // 2)
hidden_dim = 6 * seq_length
else:
hidden_dim = int((n_features + 0.5) * seq_length)
hidden_dim = int((self.n_features + 0.5) * seq_length)

if mlp_dim is None:
mlp_dim = hidden_dim
Expand All @@ -107,24 +113,28 @@ def __init__(
nn.Linear(mlp_dim, output_dim),
)

self.n_features = n_features

def forward(
self,
x: Tensor,
seq_length: Tensor,
) -> Tensor:
"""Forward pass."""
if max(self.mapping)+1 > x.shape[2]:
raise IndexError(f"Fourier mapping does not fit given data."
f"Feature space of data is too small (size {x.shape[2]}),"
f"given fourier mapping requires at least {max(self.mapping) + 1}.")

length = torch.log10(seq_length.to(dtype=x.dtype))
embeddings = [self.sin_emb(4096 * x[:, :, :3]).flatten(-2)] # Position
embeddings = [self.sin_emb(4096 * x[:, :, self.mapping[:3]]).flatten(-2)] # Position

if self.n_features >= 5:
embeddings.append(self.sin_emb(1024 * x[:, :, 4])) # Charge
embeddings.append(self.sin_emb(1024 * x[:, :, self.mapping[4]])) # Charge

embeddings.append(self.sin_emb(4096 * x[:, :, 3])) # Time
embeddings.append(self.sin_emb(4096 * x[:, :, self.mapping[3]])) # Time

if self.n_features >= 6:
embeddings.append(self.aux_emb(x[:, :, 5].long())) # Auxiliary
if self.n_features == 6:
embeddings.append(self.aux_emb(x[:, :, self.mapping[5]].long())) # Auxiliary

embeddings.append(
self.sin_emb2(length).unsqueeze(1).expand(-1, max(seq_length), -1)
Expand Down
7 changes: 4 additions & 3 deletions src/graphnet/models/gnn/icemix.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def __init__(
scaled_emb: bool = False,
include_dynedge: bool = False,
dynedge_args: Dict[str, Any] = None,
n_features: int = 6,
fourier_mapping: list = [0,1,2,3,4,5]
):
"""Construct `DeepIce`.
Expand All @@ -62,7 +62,8 @@ def __init__(
provided, DynEdge will be initialized with the original Kaggle
Competition settings. If `include_dynedge` is False, this
argument have no impact.
n_features: The number of features in the input data.
fourier_mapping: Mapping of the data to [x,y,z,time,charge,auxiliary]
for the FourierEncoder. Use None for missing features.
"""
super().__init__(seq_length, hidden_dim)
fourier_out_dim = hidden_dim // 2 if include_dynedge else hidden_dim
Expand All @@ -71,7 +72,7 @@ def __init__(
mlp_dim=None,
output_dim=fourier_out_dim,
scaled=scaled_emb,
n_features=n_features,
mapping = fourier_mapping,
)
self.rel_pos = SpacetimeEncoder(head_size)
self.sandwich = nn.ModuleList(
Expand Down

0 comments on commit 3eb2d99

Please sign in to comment.