Skip to content

Commit

Permalink
Don't apply np.random.choice to list(Atoms) since it thinks it's mult…
Browse files Browse the repository at this point in the history
…idimensional
  • Loading branch information
bernstei committed Jun 5, 2024
1 parent e4ac498 commit 21e6716
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions mace/cli/fine_tuning_select.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,17 +234,19 @@ def select_samples(
]
if len(atoms_list_pt_filtered) <= args.num_samples:
logging.info(
"Number of configurations after filtering is less than the number of samples, "
"selecting random configurations, for the rest."
f"Number of configurations after filtering {len(atoms_list_pt_filtered} "
f"is less than the number of samples {args.num_samples}, "
"selecting random configurations for the rest."
)
atoms_list_pt_minus_filtered = [
x for x in atoms_list_pt if x not in atoms_list_pt_filtered
]
atoms_list_pt_random = np.random.choice(
atoms_list_pt_minus_filtered,
atoms_list_pt_random_inds = np.random.choice(
list(range(len(atoms_list_pt_minus_filtered))),
args.num_samples - len(atoms_list_pt_filtered),
).tolist()
atoms_list_pt = atoms_list_pt_filtered + atoms_list_pt_random
replace=False
)
atoms_list_pt = atoms_list_pt_filtered + [atoms_list_pt_minus_filtered[ind] for ind in atoms_list_pt_random_inds]
else:
atoms_list_pt = atoms_list_pt_filtered

Expand Down

0 comments on commit 21e6716

Please sign in to comment.