diff --git a/nequip/scripts/train.py b/nequip/scripts/train.py index 02460aad..7ba1538b 100644 --- a/nequip/scripts/train.py +++ b/nequip/scripts/train.py @@ -24,10 +24,6 @@ from nequip.utils._global_options import _set_global_options from nequip.scripts._logger import set_up_script_logger -warnings.filterwarnings( # unnecessary e3nn-related JIT warning - "ignore", - message="The TorchScript type system doesn't support instance-level annotations", -) default_config = dict( root="./", tensorboard=False, diff --git a/nequip/train/trainer.py b/nequip/train/trainer.py index 21c7199d..8788c84a 100644 --- a/nequip/train/trainer.py +++ b/nequip/train/trainer.py @@ -1149,14 +1149,14 @@ def _parse_n_train_n_val( ) -> tuple[int]: # parse n_train and n_val (can be ints or str with percentage): n_train_n_val = [] - for n_name in ["n_train", "n_val"]: + for n_name, dataset_size in ( + ("n_train", train_dataset_size), + ("n_val", val_dataset_size), + ): n = getattr(self, n_name) if isinstance(n, str) and "%" in n: - dataset_size = ( - train_dataset_size if n_name == "n_train" else val_dataset_size - ) n_train_n_val.append( - (float(n.strip("%")) / 100) * dataset_size + (float(n.rstrip("%")) / 100) * dataset_size ) # convert to float first elif isinstance(n, int): n_train_n_val.append(n)