Skip to content

Commit

Permalink
Merge branch 'refs/heads/main' into stratified_metrics
Browse files Browse the repository at this point in the history
  • Loading branch information
kavanase committed Jun 27, 2024
2 parents 0170439 + 7e1ff56 commit 0b5511f
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 9 deletions.
4 changes: 0 additions & 4 deletions nequip/scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,6 @@
from nequip.utils._global_options import _set_global_options
from nequip.scripts._logger import set_up_script_logger

warnings.filterwarnings( # unnecessary e3nn-related JIT warning
"ignore",
message="The TorchScript type system doesn't support instance-level annotations",
)
default_config = dict(
root="./",
tensorboard=False,
Expand Down
10 changes: 5 additions & 5 deletions nequip/train/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1149,14 +1149,14 @@ def _parse_n_train_n_val(
) -> tuple[int]:
# parse n_train and n_val (can be ints or str with percentage):
n_train_n_val = []
for n_name in ["n_train", "n_val"]:
for n_name, dataset_size in (
("n_train", train_dataset_size),
("n_val", val_dataset_size),
):
n = getattr(self, n_name)
if isinstance(n, str) and "%" in n:
dataset_size = (
train_dataset_size if n_name == "n_train" else val_dataset_size
)
n_train_n_val.append(
(float(n.strip("%")) / 100) * dataset_size
(float(n.rstrip("%")) / 100) * dataset_size
) # convert to float first
elif isinstance(n, int):
n_train_n_val.append(n)
Expand Down

0 comments on commit 0b5511f

Please sign in to comment.