Skip to content

Commit

Permalink
Merge pull request #59 from robinzyb/devel
Browse files Browse the repository at this point in the history
raise error for not finding the pos file in dpdata md paser
  • Loading branch information
robinzyb authored Jun 25, 2024
2 parents f377796 + 2e1b654 commit 8cd58b8
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 71 deletions.
143 changes: 73 additions & 70 deletions cp2kdata/dpdata_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,10 @@ def from_labeled_system(self, file_name, restart: bool=None, **kwargs):
# atom_numbs not total num of atoms!
data['energies'] = cp2kmd.energies_list * AU_TO_EV
data['cells'] = cells
data['coords'] = cp2kmd.atomic_frames_list
if cp2kmd.atomic_frames_list is None:
raise ValueError("No atomic coordinates found in cp2k output, do you have *-pos-*.xyz file?")
else:
data['coords'] = cp2kmd.atomic_frames_list
data['forces'] = cp2kmd.atomic_forces_list * AU_TO_EV/AU_TO_ANG
if cp2kmd.has_stress():
# note that virial = stress * volume
Expand Down Expand Up @@ -172,72 +175,72 @@ def get_uniq_atom_names_and_types(chemical_symbols):
# NOTE: incomplete function, do not release!


@Format.register("cp2kdata/md_wannier")
class CP2KMDWannierFormat(Format):
def from_labeled_system(self, file_name, **kwargs):

# -- Set Basic Parameters --
path_prefix = file_name # in cp2k md, file_name is directory name.
true_symbols = kwargs.get('true_symbols', False)
cells = kwargs.get('cells', None)
cp2k_output_name = kwargs.get('cp2k_output_name', None)

# -- start parsing --
print(WRAPPER)

cp2kmd = Cp2kOutput(output_file=cp2k_output_name,
run_type="MD", path_prefix=path_prefix)

num_frames = cp2kmd.get_num_frames()

chemical_symbols = get_chemical_symbols_from_cp2kdata(
cp2koutput=cp2kmd,
true_symbols=true_symbols
)

if cells is None:
if cp2kmd.filename:
# cells = cp2kmd.get_init_cell()
# cells = cells[np.newaxis, :, :]
# cells = np.repeat(cells, repeats=num_frames, axis=0)
cells = cp2kmd.get_all_cells()
else:
print("No cell information, please check if your inputs are correct.")
elif isinstance(cells, np.ndarray):
if cells.shape == (3, 3):
cells = cells[np.newaxis, :, :]
cells = np.repeat(cells, repeats=num_frames, axis=0)
elif cells.shape == (num_frames, 3, 3):
pass
else:
print(
"Illegal Cell Information, cells shape should be (num_frames, 3, 3) or (3, 3)")
else:
print(
"Illegal Cell Information, cp2kdata accepts np.ndarray as cells information")

# -- data dict collects information, and return to dpdata --
data = {}
data['atom_names'], data['atom_numbs'], data["atom_types"] = get_uniq_atom_names_and_types(
chemical_symbols=chemical_symbols)
# atom_numbs not total num of atoms!
data['energies'] = cp2kmd.energies_list * AU_TO_EV
data['cells'] = cells

# get wannier centers from wannier xyz file

cp2k_wannier_file = kwargs.get('cp2k_wannier_file', None)
if cp2k_wannier_file:
print("This is wannier center parser")
print("Position parsed from pos files are not used.")
cp2k_wannier_file = os.path.join(path_prefix, cp2k_wannier_file)
data['coords'] = parse_pos_xyz_from_wannier(cp2k_wannier_file)
else:
raise ValueError("Please specify the cp2k wannier file name!")

data['forces'] = cp2kmd.atomic_forces_list * AU_TO_EV/AU_TO_ANG
if cp2kmd.has_stress():
data['virials'] = cp2kmd.stress_tensor_list/EV_ANG_m3_TO_GPa
# print(len(data['cells']), len(data['coords']), len(data['energies']))
print(WRAPPER)
return data
# @Format.register("cp2kdata/md_wannier")
# class CP2KMDWannierFormat(Format):
# def from_labeled_system(self, file_name, **kwargs):

# # -- Set Basic Parameters --
# path_prefix = file_name # in cp2k md, file_name is directory name.
# true_symbols = kwargs.get('true_symbols', False)
# cells = kwargs.get('cells', None)
# cp2k_output_name = kwargs.get('cp2k_output_name', None)

# # -- start parsing --
# print(WRAPPER)

# cp2kmd = Cp2kOutput(output_file=cp2k_output_name,
# run_type="MD", path_prefix=path_prefix)

# num_frames = cp2kmd.get_num_frames()

# chemical_symbols = get_chemical_symbols_from_cp2kdata(
# cp2koutput=cp2kmd,
# true_symbols=true_symbols
# )

# if cells is None:
# if cp2kmd.filename:
# # cells = cp2kmd.get_init_cell()
# # cells = cells[np.newaxis, :, :]
# # cells = np.repeat(cells, repeats=num_frames, axis=0)
# cells = cp2kmd.get_all_cells()
# else:
# print("No cell information, please check if your inputs are correct.")
# elif isinstance(cells, np.ndarray):
# if cells.shape == (3, 3):
# cells = cells[np.newaxis, :, :]
# cells = np.repeat(cells, repeats=num_frames, axis=0)
# elif cells.shape == (num_frames, 3, 3):
# pass
# else:
# print(
# "Illegal Cell Information, cells shape should be (num_frames, 3, 3) or (3, 3)")
# else:
# print(
# "Illegal Cell Information, cp2kdata accepts np.ndarray as cells information")

# # -- data dict collects information, and return to dpdata --
# data = {}
# data['atom_names'], data['atom_numbs'], data["atom_types"] = get_uniq_atom_names_and_types(
# chemical_symbols=chemical_symbols)
# # atom_numbs not total num of atoms!
# data['energies'] = cp2kmd.energies_list * AU_TO_EV
# data['cells'] = cells

# # get wannier centers from wannier xyz file

# cp2k_wannier_file = kwargs.get('cp2k_wannier_file', None)
# if cp2k_wannier_file:
# print("This is wannier center parser")
# print("Position parsed from pos files are not used.")
# cp2k_wannier_file = os.path.join(path_prefix, cp2k_wannier_file)
# data['coords'] = parse_pos_xyz_from_wannier(cp2k_wannier_file)
# else:
# raise ValueError("Please specify the cp2k wannier file name!")

# data['forces'] = cp2kmd.atomic_forces_list * AU_TO_EV/AU_TO_ANG
# if cp2kmd.has_stress():
# data['virials'] = cp2kmd.stress_tensor_list/EV_ANG_m3_TO_GPa
# # print(len(data['cells']), len(data['coords']), len(data['energies']))
# print(WRAPPER)
# return data
2 changes: 1 addition & 1 deletion cp2kdata/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,8 +394,8 @@ def parse_md(self):
# if no pos file and ener file, parse energies from the output file
format_logger(info="Energies", filename=self.filename)
self.energies_list = parse_energies_list(self.output_file)

self.energies_list = self.drop_last_info(self.cp2k_info, self.energies_list)
self.atomic_frames_list = None

frc_xyz_file_list = glob.glob(
os.path.join(self.path_prefix, "*frc*.xyz"))
Expand Down

0 comments on commit 8cd58b8

Please sign in to comment.