-
Notifications
You must be signed in to change notification settings - Fork 26
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Question: What is the best way to use multiple SpatioTemporalDatasets? #42
Comments
Thought I might answer my own question now that I've got something that works. It's possible to concatenate multiple `SpatioTemporalDataset` objects:
from torch_geometric.data import Dataset
from tsl.data import ImputationDataset, SpatioTemporalDataset
class ForecastingDataset(Dataset):
    """Concatenate one tsl ``SpatioTemporalDataset`` per row of a metadata table.

    Each row of ``metadata_df`` names a file (column ``"file_name"``) under
    ``INPUT_PATH / "time"``. The file is loaded with ``torch.load`` and passed
    through ``RadiusGraph`` + ``calculate_inverse_distance``. It is then
    wrapped in its own ``SpatioTemporalDataset``, so every file keeps its own
    connectivity. All per-file datasets are exposed as a single indexable
    dataset through ``self.data``.

    Args:
        metadata_df: One row per file; must contain a ``"file_name"`` column.
        window, horizon, delay, stride, window_lag, horizon_lag: Forwarded
            unchanged to every ``SpatioTemporalDataset``.
        radius: Radius passed to ``RadiusGraph`` when building edges.
        transform, pre_transform, pre_filter: Forwarded to the PyG
            ``Dataset`` base class.

    Raises:
        ValueError: If ``metadata_df`` has no rows.
    """

    def __init__(
        self,
        metadata_df: pd.DataFrame,
        window: int = 12,
        horizon: int = 1,
        delay: int = 0,
        stride: int = 1,
        window_lag: int = 1,
        horizon_lag: int = 1,
        radius: float = 1.0,
        transform=None,
        pre_transform=None,
        pre_filter=None,
    ):
        from torch.utils.data import ConcatDataset

        # PyG's Dataset signature starts with `root`; pass by keyword so that
        # `transform` is not silently interpreted as the root directory.
        super().__init__(
            root=None,
            transform=transform,
            pre_transform=pre_transform,
            pre_filter=pre_filter,
        )
        self.df = metadata_df.reset_index(drop=True)
        if self.df.empty:
            # Previously this left `self.data = None` and failed later in len().
            raise ValueError("metadata_df is empty: no files to load")
        self.trfm = Compose(
            [
                RadiusGraph(r=radius),
                calculate_inverse_distance,
            ]
        )

        tsl_datasets = []
        for file_name in self.df["file_name"]:
            # NOTE(review): torch.load unpickles the file, so only load
            # trusted data. Newer torch versions default to weights_only=True,
            # which may reject pickled PyG Data objects; confirm for your
            # torch version.
            data = torch.load(INPUT_PATH / "time" / file_name)
            data = self.trfm(data)
            tsl_datasets.append(
                SpatioTemporalDataset(
                    target=data.x.T,
                    connectivity=(data.edge_index, data.edge_attr),
                    window=window,
                    horizon=horizon,
                    delay=delay,
                    stride=stride,
                    window_lag=window_lag,
                    horizon_lag=horizon_lag,
                )
            )

        # Repeated `a += b` on torch Datasets nests a new ConcatDataset per
        # file, so lookups walk one level per file. A single flat
        # ConcatDataset gives the same samples with one bisect per lookup.
        # With a single file, keep the bare dataset as before.
        if len(tsl_datasets) == 1:
            self.data = tsl_datasets[0]
        else:
            self.data = ConcatDataset(tsl_datasets)

    def len(self):
        """Return the total number of samples across all loaded files."""
        return len(self.data)

    def get(self, idx):
        """Return sample ``idx`` of the concatenated tsl datasets."""
        return self.data[idx]


# There's probably a better way of doing the above, but my data is currently
# small enough to fit in memory. Then, to create a Lightning data module:
import lightning as pl
from tsl.data.loader import DisjointGraphLoader
class STDataModule(pl.LightningDataModule):
    """Lightning data module serving one cross-validation fold of ``ForecastingDataset``.

    ``setup`` reads ``INPUT_PATH / "time" / "metadata.csv"``. When the CSV has
    no ``"fold"`` column, it assigns folds with ``create_folds`` (grouped by
    ``"run_id"``). Rows of the requested fold are used for validation and
    prediction; all other rows are used for training.

    Args:
        window, horizon, delay, stride, window_lag, horizon_lag, radius:
            Forwarded to every ``ForecastingDataset``.
        batch_size: Batch size of all dataloaders.
        seed: ``random_state`` passed to ``create_folds``.
        n_folds: ``n_splits`` passed to ``create_folds``.
        num_workers: Worker processes for the train/validation loaders.
        **kwargs: Accepted and ignored.
    """

    def __init__(
        self,
        window: int = 12,
        horizon: int = 1,
        delay: int = 0,
        stride: int = 1,
        window_lag: int = 1,
        horizon_lag: int = 1,
        radius: float = 1.0,
        batch_size: int = 32,
        seed: int = 48,
        n_folds: int = 5,
        num_workers: int = 1,
        **kwargs,
    ):
        super().__init__()
        self.dataset_kwargs = {
            "window": window,
            "horizon": horizon,
            "delay": delay,
            "stride": stride,
            "window_lag": window_lag,
            "horizon_lag": horizon_lag,
            "radius": radius,
        }
        self.batch_size = batch_size
        self.seed = seed
        self.n_folds = n_folds
        self.num_workers = num_workers
        self.train_steps = 0

    def setup(self, stage=None, fold: int = 0):
        """Build the train/validation/prediction datasets for ``fold``.

        Runs for the "fit", "validate" and "predict" stages, and for a manual
        ``setup()`` call (``stage=None``). Other stages are a no-op.
        Sets ``train_ds``, ``valid_ds``, ``pred_ds`` and ``train_steps``.
        """
        # "validate" was previously skipped, leaving val_dataloader() without
        # a dataset when running `trainer.validate(...)`.
        if stage not in (None, "fit", "validate", "predict"):
            return

        folder = INPUT_PATH / "time"
        metadata = pd.read_csv(folder / "metadata.csv")
        if "fold" not in metadata.columns:
            metadata = create_folds(
                metadata,
                n_splits=self.n_folds,
                random_state=self.seed,
                group="run_id",
            )

        # Boolean masks instead of building query strings.
        is_val = metadata["fold"] == fold
        trn_df = metadata[~is_val]
        val_df = metadata[is_val]

        self.train_ds = ForecastingDataset(trn_df, **self.dataset_kwargs)
        self.valid_ds = ForecastingDataset(val_df, **self.dataset_kwargs)
        # The prediction set was built from the same rows and arguments as the
        # validation set; reuse it instead of reloading every file.
        self.pred_ds = self.valid_ds
        # The train loader uses drop_last=True, so only full batches count.
        self.train_steps = len(self.train_ds) // self.batch_size
        print(
            f"Fold {fold}:",
            len(self.train_ds),
            "train and",
            len(self.valid_ds),
            "valid samples",
        )

    def train_dataloader(self):
        """Shuffled loader over ``train_ds`` that drops the last partial batch."""
        return DisjointGraphLoader(
            self.train_ds,
            num_workers=self.num_workers,
            batch_size=self.batch_size,
            shuffle=True,
            drop_last=True,
            pin_memory=True,
            # DataLoader rejects persistent_workers=True when num_workers == 0.
            persistent_workers=self.num_workers > 0,
            force_batch=True,
        )

    def val_dataloader(self):
        """Loader over ``valid_ds`` in dataset order."""
        return DisjointGraphLoader(
            self.valid_ds,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=True,
            persistent_workers=self.num_workers > 0,
            force_batch=True,
        )

    def predict_dataloader(self):
        """Unshuffled loader over ``pred_ds`` (always a single worker)."""
        return DisjointGraphLoader(
            self.pred_ds,
            batch_size=self.batch_size,
            num_workers=1,
            pin_memory=True,
            shuffle=False,
            force_batch=True,
        )
If you have many `SpatioTemporalDataset`s, e.g. experiments with different numbers of sensors (i.e. nodes), what would be the best way to combine them into a single dataset?
Or is it preferred to create a list of `Data` objects for each dataset above, concatenate them (i.e. `data_list.extend(more_data)`) and iterate using `DisjointBatch`?
Another option could be to create a `ConcatDataset`, but I think this breaks the batching, although it does create a list of `Data` objects as I described above.
The text was updated successfully, but these errors were encountered: