From 3eb2d991995c8aceda2097d9475d48b0bfe7c47c Mon Sep 17 00:00:00 2001
From: Moritz
Date: Thu, 24 Oct 2024 15:01:14 +0200
Subject: [PATCH] Changes made to address issue #753

---
 src/graphnet/models/components/embedding.py | 44 +++++++++++++--------
 src/graphnet/models/gnn/icemix.py           |  7 ++--
 2 files changed, 31 insertions(+), 20 deletions(-)

diff --git a/src/graphnet/models/components/embedding.py b/src/graphnet/models/components/embedding.py
index 1b49cd901..005ab3a8b 100644
--- a/src/graphnet/models/components/embedding.py
+++ b/src/graphnet/models/components/embedding.py
@@ -55,9 +55,9 @@ class FourierEncoder(LightningModule):
 
     This module incorporates sinusoidal positional embeddings and auxiliary
     embeddings to process input sequences and produce meaningful
-    representations. The module assumes that the input data is in the format of
-    (x, y, z, time, charge, auxiliary), being the first four features
-    mandatory.
+    representations. The features x, y, z and time are mandatory, while
+    charge and auxiliary are optional. Use the `mapping` argument to specify
+    which input column holds each of these features.
     """
 
     def __init__(
@@ -66,7 +66,7 @@ def __init__(
         mlp_dim: Optional[int] = None,
         output_dim: int = 384,
         scaled: bool = False,
-        n_features: int = 6,
+        mapping: list = [0, 1, 2, 3, 4, 5],
     ):
         """Construct `FourierEncoder`.
 
@@ -79,23 +79,29 @@ def __init__(
                 depending on `n_features`.
             output_dim: Dimension of the output (I.e. number of columns).
             scaled: Whether or not to scale the embeddings.
-            n_features: The number of features in the input data.
+            mapping: Mapping of the input columns to [x, y, z, time, charge, auxiliary]. Use None for missing features.
         """
         super().__init__()
-
+        self.mapping_str = ["x", "y", "z", "time", "charge", "auxiliary"]
+        self.mapping = mapping
+        self.n_features = len([i for i in mapping if i is not None])
         self.sin_emb = SinusoidalPosEmb(dim=seq_length, scaled=scaled)
         self.sin_emb2 = SinusoidalPosEmb(dim=seq_length // 2, scaled=scaled)
 
-        if n_features < 4:
+        assert len(mapping) == 6, "Fourier mapping must have 6 elements. Use None for missing features."
+        assert all([isinstance(i, int) or i is None for i in mapping]), "Use int or None in the Fourier mapping."
+
+        if any([i is None for i in mapping[:4]]):
+            missing = [self.mapping_str[i] for i in range(4) if mapping[i] is None]
             raise ValueError(
-                f"At least x, y, z and time of the DOM are required. Got only "
-                f"{n_features} features."
+                "x, y, z and time of the DOM are required. "
+                f"Missing in mapping: {missing}."
             )
-        elif n_features >= 6:
+        elif self.n_features == 6:
             self.aux_emb = nn.Embedding(2, seq_length // 2)
             hidden_dim = 6 * seq_length
         else:
-            hidden_dim = int((n_features + 0.5) * seq_length)
+            hidden_dim = int((self.n_features + 0.5) * seq_length)
 
         if mlp_dim is None:
             mlp_dim = hidden_dim
@@ -107,7 +113,6 @@ def __init__(
             nn.Linear(mlp_dim, output_dim),
         )
 
-        self.n_features = n_features
 
     def forward(
         self,
@@ -115,16 +120,21 @@ def forward(
         x: Tensor,
         seq_length: Tensor,
     ) -> Tensor:
         """Forward pass."""
+        max_index = max(i for i in self.mapping if i is not None)
+        if max_index + 1 > x.shape[2]:
+            raise IndexError(f"Fourier mapping does not fit the given data: "
+ f"Feature space of data is too small (size {x.shape[2]})," + f"given fourier mapping requires at least {max(self.mapping) + 1}.") + length = torch.log10(seq_length.to(dtype=x.dtype)) - embeddings = [self.sin_emb(4096 * x[:, :, :3]).flatten(-2)] # Position + embeddings = [self.sin_emb(4096 * x[:, :, self.mapping[:3]]).flatten(-2)] # Position if self.n_features >= 5: - embeddings.append(self.sin_emb(1024 * x[:, :, 4])) # Charge + embeddings.append(self.sin_emb(1024 * x[:, :, self.mapping[4]])) # Charge - embeddings.append(self.sin_emb(4096 * x[:, :, 3])) # Time + embeddings.append(self.sin_emb(4096 * x[:, :, self.mapping[3]])) # Time - if self.n_features >= 6: - embeddings.append(self.aux_emb(x[:, :, 5].long())) # Auxiliary + if self.n_features == 6: + embeddings.append(self.aux_emb(x[:, :, self.mapping[5]].long())) # Auxiliary embeddings.append( self.sin_emb2(length).unsqueeze(1).expand(-1, max(seq_length), -1) diff --git a/src/graphnet/models/gnn/icemix.py b/src/graphnet/models/gnn/icemix.py index a073e3ca8..569ac5a38 100644 --- a/src/graphnet/models/gnn/icemix.py +++ b/src/graphnet/models/gnn/icemix.py @@ -43,7 +43,7 @@ def __init__( scaled_emb: bool = False, include_dynedge: bool = False, dynedge_args: Dict[str, Any] = None, - n_features: int = 6, + fourier_mapping: list = [0,1,2,3,4,5] ): """Construct `DeepIce`. @@ -62,7 +62,8 @@ def __init__( provided, DynEdge will be initialized with the original Kaggle Competition settings. If `include_dynedge` is False, this argument have no impact. - n_features: The number of features in the input data. + fourier_mapping: Mapping of the data to [x,y,z,time,charge,auxiliary] + for the FourierEncoder. Use None for missing features. """ super().__init__(seq_length, hidden_dim) fourier_out_dim = hidden_dim // 2 if include_dynedge else hidden_dim @@ -71,7 +72,7 @@ def __init__( mlp_dim=None, output_dim=fourier_out_dim, scaled=scaled_emb, - n_features=n_features, + mapping = fourier_mapping, ) self.rel_pos = SpacetimeEncoder(head_size) self.sandwich = nn.ModuleList(