Skip to content

Commit

Permalink
fix deterministic valid
Browse files Browse the repository at this point in the history
  • Loading branch information
ilyes319 committed Jun 6, 2024
1 parent 346999c commit 5463656
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 4 deletions.
8 changes: 5 additions & 3 deletions mace/cli/fine_tuning_select.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ def select_samples(
]
if len(atoms_list_pt_filtered) <= args.num_samples:
logging.info(
f"Number of configurations after filtering {len(atoms_list_pt_filtered} "
f"Number of configurations after filtering {len(atoms_list_pt_filtered)} "
f"is less than the number of samples {args.num_samples}, "
"selecting random configurations for the rest."
)
Expand All @@ -244,9 +244,11 @@ def select_samples(
atoms_list_pt_random_inds = np.random.choice(
list(range(len(atoms_list_pt_minus_filtered))),
args.num_samples - len(atoms_list_pt_filtered),
replace=False
replace=False,
)
atoms_list_pt = atoms_list_pt_filtered + [atoms_list_pt_minus_filtered[ind] for ind in atoms_list_pt_random_inds]
atoms_list_pt = atoms_list_pt_filtered + [
atoms_list_pt_minus_filtered[ind] for ind in atoms_list_pt_random_inds
]
else:
atoms_list_pt = atoms_list_pt_filtered

Expand Down
2 changes: 1 addition & 1 deletion mace/cli/run_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,7 @@ def main() -> None:
dataset=valid_set,
batch_size=args.valid_batch_size,
sampler=valid_samplers[head] if args.distributed else None,
shuffle=(valid_sampler is None),
shuffle=False
drop_last=False,
pin_memory=args.pin_memory,
num_workers=args.num_workers,
Expand Down

0 comments on commit 5463656

Please sign in to comment.