Question: What is the best way to use multiple SpatioTemporalDatasets? #42

Anjum48 opened this issue Jul 31, 2024 · 1 comment

Anjum48 commented Jul 31, 2024

If you have many SpatioTemporalDatasets, e.g. experiments with different numbers of sensors (i.e. nodes), what would be the best way to combine them into a single dataset?

Or is it preferred to create a list of Data objects for each dataset above, concatenate them (i.e. data_list.extend(more_data)) and iterate using DisjointBatch?

Another option could be to create a ConcatDataset, but I think this breaks the batching; effectively it just gives you the list of Data objects described above (rough sketch below).
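An untested sketch of that option, where ds_a and ds_b stand in for two SpatioTemporalDatasets built from experiments with different node counts:

import torch
from torch.utils.data import ConcatDataset, DataLoader

# Hypothetical datasets from two experiments with different node counts
combined = ConcatDataset([ds_a, ds_b])

# The default collate stacks tensors, which fails when samples have different
# numbers of nodes, so you either need batch_size=1 or a disjoint-graph
# collate_fn that merges each batch into one disconnected graph.
loader = DataLoader(combined, batch_size=1, shuffle=True,
                    collate_fn=lambda samples: samples[0])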

Anjum48 commented Aug 19, 2024

Thought I might answer my own question now that I've got something that works.

It's possible to concatenate multiple SpatioTemporalDataset objects using something like this (I have each experiment in its own .pt file, with the graph stored as a PyTorch Geometric Data object):

import pandas as pd
import torch
from torch_geometric.data import Dataset
from torch_geometric.transforms import Compose, RadiusGraph
from tsl.data import SpatioTemporalDataset

# INPUT_PATH (a pathlib.Path to my data directory) and
# calculate_inverse_distance are my own helpers; the latter is sketched below.

class ForecastingDataset(Dataset):
    def __init__(
        self,
        metadata_df: pd.DataFrame,
        window: int = 12,
        horizon: int = 1,
        delay: int = 0,
        stride: int = 1,
        window_lag: int = 1,
        horizon_lag: int = 1,
        radius: float = 1.0,
        transform=None,
        pre_transform=None,
        pre_filter=None,
    ):
        # PyG's Dataset takes `root` as its first argument, so pass by keyword
        super().__init__(
            transform=transform, pre_transform=pre_transform, pre_filter=pre_filter
        )
        self.df = metadata_df.reset_index(drop=True)
        # Connect sensors within `radius`, then weight edges by inverse distance
        self.trfm = Compose(
            [
                RadiusGraph(r=radius),
                calculate_inverse_distance,
            ]
        )
        self.data = None

        for index, row in self.df.iterrows():
            # Load this experiment's PyG Data object and build its graph
            data = torch.load(INPUT_PATH / "time" / row["file_name"])
            data = self.trfm(data)

            tsl_data = SpatioTemporalDataset(
                target=data.x.T,  # transpose to time-first, as tsl expects
                connectivity=(data.edge_index, data.edge_attr),
                window=window,
                horizon=horizon,
                delay=delay,
                stride=stride,
                window_lag=window_lag,
                horizon_lag=horizon_lag,
            )

            # SpatioTemporalDataset supports concatenation via `+`
            if self.data is None:
                self.data = tsl_data
            else:
                self.data += tsl_data

    def len(self):
        return len(self.data)

    def get(self, idx):
        return self.data[idx]

There's probably a better way of doing the above, but my data is currently small enough to fit in memory.
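For reference, calculate_inverse_distance is just a small custom transform of mine; roughly something like this (assuming data.pos holds the sensor coordinates):

def calculate_inverse_distance(data):
    # Weight each edge produced by RadiusGraph with the inverse Euclidean
    # distance between its endpoints
    row, col = data.edge_index
    dist = (data.pos[row] - data.pos[col]).norm(dim=-1)
    data.edge_attr = 1.0 / dist.clamp(min=1e-6)  # clamp avoids division by zero
    return data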

Then, to create a Lightning DataModule, you can use DisjointGraphLoader, which is not currently documented (I found it by exploring the code). It returns DisjointBatch objects:

import lightning as pl
import pandas as pd
from tsl.data.loader import DisjointGraphLoader

# create_folds is my own helper (sketched below)

class STDataModule(pl.LightningDataModule):
    def __init__(
        self,
        window: int = 12,
        horizon: int = 1,
        delay: int = 0,
        stride: int = 1,
        window_lag: int = 1,
        horizon_lag: int = 1,
        radius: float = 1.0,
        batch_size: int = 32,
        seed: int = 48,
        n_folds: int = 5,
        num_workers: int = 1,
        **kwargs,
    ):
        super().__init__()
        self.dataset_kwargs = {
            "window": window,
            "horizon": horizon,
            "delay": delay,
            "stride": stride,
            "window_lag": window_lag,
            "horizon_lag": horizon_lag,
            "radius": radius,
        }
        self.batch_size = batch_size
        self.seed = seed
        self.n_folds = n_folds
        self.num_workers = num_workers
        self.train_steps = 0

    def setup(self, stage=None, fold: int = 0):
        if stage == "fit" or stage == "predict":
            folder = INPUT_PATH / "time"

            metadata = pd.read_csv(folder / "metadata.csv")

            if "fold" not in metadata.columns:
                metadata = create_folds(
                    metadata,
                    n_splits=self.n_folds,
                    random_state=self.seed,
                    group="run_id",
                )

            trn_df = metadata.query(f"fold != {fold}")
            val_df = metadata.query(f"fold == {fold}")

            self.train_ds = ForecastingDataset(trn_df, **self.dataset_kwargs)
            self.valid_ds = ForecastingDataset(val_df, **self.dataset_kwargs)
            self.pred_ds = ForecastingDataset(val_df, **self.dataset_kwargs)

            self.train_steps = len(self.train_ds) // self.batch_size  # drop_last=True in train loader
            print(
                f"Fold {fold}:",
                len(self.train_ds),
                "train and",
                len(self.valid_ds),
                "valid samples",
            )

    def train_dataloader(self):
        return DisjointGraphLoader(
            self.train_ds,
            num_workers=self.num_workers,
            batch_size=self.batch_size,
            shuffle=True,
            drop_last=True,
            pin_memory=True,
            persistent_workers=True,
            force_batch=True,
        )

    def val_dataloader(self):
        return DisjointGraphLoader(
            self.valid_ds,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=True,
            persistent_workers=True,
            force_batch=True,
        )

    def predict_dataloader(self):
        return DisjointGraphLoader(
            self.pred_ds,
            batch_size=self.batch_size,
            num_workers=1,
            pin_memory=True,
            shuffle=False,
            force_batch=True,
        )
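
For completeness, create_folds just adds a "fold" column with grouped splits. A sketch using scikit-learn's GroupKFold (note GroupKFold is deterministic, so random_state is accepted only for API parity):

import pandas as pd
from sklearn.model_selection import GroupKFold

def create_folds(df: pd.DataFrame, n_splits: int, random_state: int, group: str) -> pd.DataFrame:
    # Keep rows sharing the same group value (e.g. run_id) in the same fold
    gkf = GroupKFold(n_splits=n_splits)
    df = df.copy()
    df["fold"] = -1
    for fold, (_, val_idx) in enumerate(gkf.split(df, groups=df[group])):
        df.loc[df.index[val_idx], "fold"] = fold
    return df

Training is then the usual Lightning flow (model being whatever LightningModule you are fitting):

dm = STDataModule(window=12, horizon=1, batch_size=32)
trainer = pl.Trainer(max_epochs=50)
trainer.fit(model, datamodule=dm)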
