diff --git a/dpdata/utils.py b/dpdata/utils.py index e008120e..07e10caa 100644 --- a/dpdata/utils.py +++ b/dpdata/utils.py @@ -77,8 +77,17 @@ def sort_atom_names(data, type_map=None): """ if type_map is not None: # assign atom_names index to the specify order + if not set(data["atom_names"]).issubset(set(type_map)): + # delete the types where atom_numbs == 0 + real_index = np.array(data["atom_numbs"]) > 0 + real_map = list(np.array(data["atom_names"])[real_index]) + assert set(real_map).issubset(set(type_map)) + data["atom_numbs"] = list(np.array(data["atom_numbs"])[real_index]) + idx = np.zeros(len(data["atom_names"]), dtype=int) + idx[real_index] = np.arange(len(real_map)) + data["atom_names"] = real_map + data["atom_types"] = idx[data["atom_types"]] # atom_names must be a subset of type_map - assert set(data["atom_names"]).issubset(set(type_map)) # for the condition that type_map is a proper superset of atom_names # new_atoms = set(type_map) - set(data["atom_names"]) new_atoms = [e for e in type_map if e not in data["atom_names"]]