Reworked validation data selection #1

Open
wants to merge 7 commits into main
Conversation

arminshzd

Since the input data for the SRVs are time series, randomly splitting the data for training and validation is not an option. The dataset creation workflow now supports three common cases for automatic train/validation splitting, as well as allowing the user to provide the training and validation data manually.
If val_frac is set to zero during model definition, all data fed to the fit function is treated as training data, and no validation will be performed. The user can manually provide the validation data to the fit function using the val_data keyword.
If val_frac is non-zero, one of three cases applies:

  1. If the input data is a single torch.tensor (a single continuous trajectory), the last val_frac of the data is separated and used for validation.

  2. If the input data is a list of torch.tensors:

    • If all the trajectories are the same length, the last ceil(val_frac*len(data)) trajectories are used for validation. E.g., if a list of 5 trajectories of the same length is given as input and val_frac=0.2, the last trajectory in the list will be used for validation. If val_frac*len(data) is not an integer, it is rounded up. E.g., a list of 5 trajectories with val_frac=0.25 will use 2 of the 5 trajectories for validation (without splitting any of them), which means val_frac is effectively adjusted to 0.4 by the code.

    • If the trajectories are not all the same length, the shortest trajectory is used for validation.

The user can override the above default behavior by setting val_frac=0.0 and passing validation data through the val_data keyword. A sketch of these rules is given below.
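
For reference, here is a minimal sketch of the splitting rules described above, assuming the input is either a single torch.Tensor or a list of torch.Tensor objects. The function name split_train_val and its signature are illustrative only and do not necessarily match the identifiers introduced in this PR.

```python
import math
from typing import List, Tuple, Union

import torch


def split_train_val(
    data: Union[torch.Tensor, List[torch.Tensor]], val_frac: float
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    """Split time-series data into contiguous train/validation sets.

    Illustrative sketch of the rules in this PR description; names are
    hypothetical and not taken from the actual code.
    """
    if val_frac == 0.0:
        # No automatic split: everything is training data; validation data,
        # if any, is expected to be passed separately via the val_data
        # keyword of fit().
        train = [data] if isinstance(data, torch.Tensor) else list(data)
        return train, []

    if isinstance(data, torch.Tensor):
        # Single continuous trajectory: hold out the last val_frac of frames.
        n_val = int(math.ceil(val_frac * len(data)))
        return [data[:-n_val]], [data[-n_val:]]

    lengths = [len(traj) for traj in data]
    if len(set(lengths)) == 1:
        # Equal-length trajectories: hold out the last ceil(val_frac * N)
        # whole trajectories, so val_frac is effectively rounded up.
        n_val = int(math.ceil(val_frac * len(data)))
        return list(data[:-n_val]), list(data[-n_val:])

    # Unequal lengths: use the shortest trajectory for validation.
    i_short = lengths.index(min(lengths))
    train = [traj for i, traj in enumerate(data) if i != i_short]
    return train, [data[i_short]]
```

For example, with five equal-length trajectories and val_frac=0.25, this sketch returns the last two trajectories as validation data, matching the rounding behavior described above.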

I suggest a full review of the new lines before merging.

arminshzd and others added 7 commits February 7, 2024 12:07