Skip to content

Commit

Permalink
fix: 🐛 bugs in STAEformer and MTGNN (issues #219 and #220)
Browse files Browse the repository at this point in the history
  • Loading branch information
zezhishao committed Jan 1, 2025
1 parent 87be63b commit a99ae89
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 32 deletions.
32 changes: 3 additions & 29 deletions baselines/MTGNN/runner/mtgnn_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,36 +17,9 @@ def __init__(self, cfg: dict):
self.num_split = cfg.TRAIN.CUSTOM.NUM_SPLIT
self.perm = None

def select_input_features(self, data: torch.Tensor) -> torch.Tensor:
"""Select input features.
Args:
data (torch.Tensor): input history data, shape [B, L, N, C]
Returns:
torch.Tensor: reshaped data
"""

# select feature using self.forward_features
if self.forward_features is not None:
data = data[:, :, :, self.forward_features]
return data

def select_target_features(self, data: torch.Tensor) -> torch.Tensor:
"""Select target feature
Args:
data (torch.Tensor): prediction of the model with arbitrary shape.
Returns:
torch.Tensor: reshaped data with shape [B, L, N, C]
"""

# select feature using self.target_features
data = data[:, :, :, self.target_features]
return data

def forward(self, data: tuple, epoch: int = None, iter_num: int = None, train: bool = True, **kwargs) -> tuple:
data = self.preprocessing(data)

if train:
future_data, history_data, idx = data['target'], data['inputs'], data['idx']
else:
Expand All @@ -68,6 +41,7 @@ def forward(self, data: tuple, epoch: int = None, iter_num: int = None, train: b
model_return["target"] = self.select_target_features(future_data)
assert list(model_return["prediction"].shape)[:3] == [batch_size, seq_len, num_nodes], \
"error shape of the output, edit the forward function to reshape it to [B, L, N, C]"
model_return = self.postprocessing(model_return)
return model_return

def train_iters(self, epoch: int, iter_index: int, data: Union[torch.Tensor, Tuple]) -> torch.Tensor:
Expand Down
6 changes: 3 additions & 3 deletions baselines/STAEformer/arch/staeformer_arch.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,16 +201,16 @@ def forward(self, history_data: torch.Tensor, future_data: torch.Tensor, batch_s
batch_size = x.shape[0]

if self.tod_embedding_dim > 0:
tod = x[..., 1]
tod = x[..., 1] * self.steps_per_day
if self.dow_embedding_dim > 0:
dow = x[..., 2]
dow = x[..., 2] * 7
x = x[..., : self.input_dim]

x = self.input_proj(x) # (batch_size, in_steps, num_nodes, input_embedding_dim)
features = [x]
if self.tod_embedding_dim > 0:
tod_emb = self.tod_embedding(
(tod * self.steps_per_day).long()
tod.long()
) # (batch_size, in_steps, num_nodes, tod_embedding_dim)
features.append(tod_emb)
if self.dow_embedding_dim > 0:
Expand Down

0 comments on commit a99ae89

Please sign in to comment.