Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Since the input data for the SRVs are time series, randomly splitting the data for training and validation is not an option. The dataset creation workflow now supports three common cases for automatic train/validation data splitting as well as allowing the user to feed the validation and train data manually.
If
val_frac
is set to zero during model definition, all data fed to thefit
function is treated as training data, and no validation will be performed. The user can manually provide the validation data to thefit
function using theval_data
keyword.If
val_frac
is non-zero, one of the three cases can happen:If the input data is a single
torch.tensor
(a single continuous trajectory), the lastval_frac
of the data is separated and used for validation.If the input data is a
list
oftorch.tensor
s:If all the trajectories are the same length,
ceil(val_frac*len(data))
of the trajectories are used for validation. E.g., if a list of 5 trajectories of the same length is given as input andval_frac=0.2
, the last trajectory in the list will be used for validation. Ifval_frac*len(data)
is not an integer, it will be rounded up. E.g., a list of 5 trajectories withval_frac=0.25
will use 2 of the five trajectories (without splitting any of the trajectories) for validation (this meansval_frac
will be automatically adjusted to0.4
by the code).If all the trajectories are not the same length, the shortest trajectory is used for validation.
The user can override the above default behavior through
val_frac=0.0
and theval_data
keyword.I suggest a full review of the new lines before merging.