diff --git a/mace/cli/fine_tuning_select.py b/mace/cli/fine_tuning_select.py index 19e44d7d..0435e5c3 100644 --- a/mace/cli/fine_tuning_select.py +++ b/mace/cli/fine_tuning_select.py @@ -234,17 +234,19 @@ def select_samples( ] if len(atoms_list_pt_filtered) <= args.num_samples: logging.info( - "Number of configurations after filtering is less than the number of samples, " - "selecting random configurations, for the rest." + f"Number of configurations after filtering {len(atoms_list_pt_filtered} " + f"is less than the number of samples {args.num_samples}, " + "selecting random configurations for the rest." ) atoms_list_pt_minus_filtered = [ x for x in atoms_list_pt if x not in atoms_list_pt_filtered ] - atoms_list_pt_random = np.random.choice( - atoms_list_pt_minus_filtered, + atoms_list_pt_random_inds = np.random.choice( + list(range(len(atoms_list_pt_minus_filtered))), args.num_samples - len(atoms_list_pt_filtered), - ).tolist() - atoms_list_pt = atoms_list_pt_filtered + atoms_list_pt_random + replace=False + ) + atoms_list_pt = atoms_list_pt_filtered + [atoms_list_pt_minus_filtered[ind] for ind in atoms_list_pt_random_inds] else: atoms_list_pt = atoms_list_pt_filtered