
Commit

Implement output trajectory merge in Cp2kBaseWorkChain
yakutovicha committed Mar 6, 2024
1 parent 33fd994 commit 21eab70
Showing 2 changed files with 49 additions and 23 deletions.
2 changes: 2 additions & 0 deletions aiida_cp2k/utils/__init__.py
@@ -6,6 +6,7 @@
###############################################################################
"""AiiDA-CP2K utils"""

from .datatype_helpers import merge_trajectory_data
from .input_generator import (
Cp2kInput,
add_ext_restart_section,
@@ -42,4 +43,5 @@
"merge_Dict",
"ot_has_small_bandgap",
"resize_unit_cell",
"merge_trajectory_data",
]
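
With the two added lines above, the helper is re-exported from the package's utils namespace. A minimal sketch of the resulting import options (paths taken from this diff):

# Both imports now work; the second relies on the re-export added in this commit.
from aiida_cp2k.utils.datatype_helpers import merge_trajectory_data
from aiida_cp2k.utils import merge_trajectory_data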
70 changes: 47 additions & 23 deletions aiida_cp2k/utils/datatype_helpers.py
@@ -9,8 +9,8 @@
import re
from collections.abc import Sequence

from aiida.common import InputValidationError
from aiida.plugins import DataFactory
import numpy as np
from aiida import common, orm, plugins


def _unpack(adict):
@@ -50,7 +50,9 @@ def _kind_element_from_kind_section(section):
try:
kind = section["_"]
except KeyError:
raise InputValidationError("No default parameter '_' found in KIND section.")
raise common.InputValidationError(
"No default parameter '_' found in KIND section."
)

try:
element = section["ELEMENT"]
@@ -60,7 +62,7 @@ def _kind_element_from_kind_section(section):
try:
element = match["sym"]
except TypeError:
raise InputValidationError(
raise common.InputValidationError(
f"Unable to figure out atomic symbol from KIND '{kind}'."
)

@@ -125,7 +127,7 @@ def _write_gdt(inp, entries, folder, key, fname):
def validate_basissets_namespace(basissets, _):
"""A input_namespace validator to ensure passed down basis sets have the correct type."""
return _validate_gdt_namespace(
basissets, DataFactory("gaussian.basisset"), "basis set"
basissets, plugins.DataFactory("gaussian.basisset"), "basis set"
)


@@ -176,7 +178,7 @@ def validate_basissets(inp, basissets, structure):
bsets = [(t, b) for t, s, b in basissets if s == element]

if not bsets:
raise InputValidationError(
raise common.InputValidationError(
f"No basis set found for kind {kind} or element {element}"
f" in basissets input namespace and not explicitly set."
)
@@ -203,7 +205,7 @@ def validate_basissets(inp, basissets, structure):
bsets = [(t, b) for t, s, b in basissets if s == element]

if not bsets:
raise InputValidationError(
raise common.InputValidationError(
f"'BASIS_SET {bstype} {bsname}' for element {element} (from kind {kind})"
" not found in basissets input namespace"
)
@@ -213,7 +215,7 @@ def validate_basissets(inp, basissets, structure):
basissets_used.add(bset)
break
else:
raise InputValidationError(
raise common.InputValidationError(
f"'BASIS_SET {bstype} {bsname}' for element {element} (from kind {kind})"
" not found in basissets input namespace"
)
@@ -222,14 +224,14 @@ def validate_basissets(inp, basissets, structure):
if not structure and any(
bset not in basissets_used for bset in basissets_specified
):
raise InputValidationError(
raise common.InputValidationError(
"No explicit structure given and basis sets not referenced in input"
)

if isinstance(inp["FORCE_EVAL"], Sequence) and any(
kind.name not in explicit_kinds for kind in structure.kinds
):
raise InputValidationError(
raise common.InputValidationError(
"Automated BASIS_SET keyword creation is not yet supported with multiple FORCE_EVALs."
" Please explicitly reference a BASIS_SET for each KIND."
)
@@ -250,13 +252,13 @@ def validate_basissets(inp, basissets, structure):
bsets = [(t, b) for t, s, b in basissets if s == kind.symbol]

if not bsets:
raise InputValidationError(
raise common.InputValidationError(
f"No basis set found in the given basissets for kind '{kind.name}' of your structure."
)

for _, bset in bsets:
if bset.element != kind.symbol:
raise InputValidationError(
raise common.InputValidationError(
f"Basis set '{bset.name}' for '{bset.element}' specified"
f" for kind '{kind.name}' (of '{kind.symbol}')."
)
@@ -274,7 +276,7 @@ def validate_basissets(inp, basissets, structure):

for bset in basissets_specified:
if bset not in basissets_used:
raise InputValidationError(
raise common.InputValidationError(
f"Basis set '{bset.name}' ('{bset.element}') specified in the basissets"
f" input namespace but not referenced by either input or structure."
)
@@ -287,7 +289,9 @@ def write_basissets(inp, basissets, folder):

def validate_pseudos_namespace(pseudos, _):
"""A input_namespace validator to ensure passed down pseudopentials have the correct type."""
return _validate_gdt_namespace(pseudos, DataFactory("gaussian.pseudo"), "pseudo")
return _validate_gdt_namespace(
pseudos, plugins.DataFactory("gaussian.pseudo"), "pseudo"
)


def validate_pseudos(inp, pseudos, structure):
@@ -318,7 +322,7 @@ def validate_pseudos(inp, pseudos, structure):
try:
pseudo = pseudos[element]
except KeyError:
raise InputValidationError(
raise common.InputValidationError(
f"No pseudopotential found for kind {kind} or element {element}"
f" in pseudos input namespace and not explicitly set."
)
@@ -335,19 +339,19 @@ def validate_pseudos(inp, pseudos, structure):
try:
pseudo = pseudos[element]
except KeyError:
raise InputValidationError(
raise common.InputValidationError(
f"'POTENTIAL {ptype} {pname}' for element {element} (from kind {kind})"
" not found in pseudos input namespace"
)

if pname not in pseudo.aliases:
raise InputValidationError(
raise common.InputValidationError(
f"'POTENTIAL {ptype} {pname}' for element {element} (from kind {kind})"
" not found in pseudos input namespace"
)

if pseudo.element != element:
raise InputValidationError(
raise common.InputValidationError(
f"Pseudopotential '{pseudo.name}' for '{pseudo.element}' specified"
f" for element '{element}'."
)
@@ -358,14 +362,14 @@ def validate_pseudos(inp, pseudos, structure):
if not structure and any(
pseudo not in pseudos_used for pseudo in pseudos_specified
):
raise InputValidationError(
raise common.InputValidationError(
"No explicit structure given and pseudo not referenced in input"
)

if isinstance(inp["FORCE_EVAL"], Sequence) and any(
kind.name not in explicit_kinds for kind in structure.kinds
):
raise InputValidationError(
raise common.InputValidationError(
"Automated POTENTIAL keyword creation is not yet supported with multiple FORCE_EVALs."
" Please explicitly reference a POTENTIAL for each KIND."
)
@@ -383,13 +387,13 @@ def validate_pseudos(inp, pseudos, structure):
try:
pseudo = pseudos[kind.symbol]
except KeyError:
raise InputValidationError(
raise common.InputValidationError(
f"No basis set found in the given basissets"
f" for kind '{kind.name}' (or '{kind.symbol}') of your structure."
)

if pseudo.element != kind.symbol:
raise InputValidationError(
raise common.InputValidationError(
f"Pseudopotential '{pseudo.name}' for '{pseudo.element}' specified"
f" for kind '{kind.name}' (of '{kind.symbol}')."
)
@@ -402,7 +406,7 @@

for pseudo in pseudos_specified:
if pseudo not in pseudos_used:
raise InputValidationError(
raise common.InputValidationError(
f"Pseudopodential '{pseudo.name}' specified in the pseudos input namespace"
f" but not referenced by either input or structure."
)
@@ -411,3 +415,23 @@ def validate_pseudos(inp, pseudos, structure):
def write_pseudos(inp, pseudos, folder):
"""Writes the unified POTENTIAL file with the used pseudos"""
_write_gdt(inp, pseudos, folder, "POTENTIAL_FILE_NAME", "POTENTIAL")


def merge_trajectory_data(*trajectories):
    """Merge TrajectoryData nodes by concatenating their arrays along the first (step) axis."""
    if not trajectories:
        return None

    final_trajectory = orm.TrajectoryData()

    array_names = trajectories[0].get_arraynames()
    for array_name in array_names:
        if any(array_name not in traj.get_arraynames() for traj in trajectories):
            raise ValueError(
                f"Array name '{array_name}' not found in all trajectories."
            )
        # Append the steps of each trajectory in the order the nodes were passed in.
        merged_array = np.concatenate(
            [traj.get_array(array_name) for traj in trajectories], axis=0
        )
        final_trajectory.set_array(array_name, merged_array)

    return final_trajectory
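
For context, a minimal usage sketch of the new helper (not part of the commit; the array names, shapes, and values are purely illustrative, and a configured AiiDA profile is assumed):

import numpy as np
from aiida import load_profile, orm
from aiida_cp2k.utils import merge_trajectory_data

load_profile()  # assumes an existing AiiDA profile

# Two hypothetical trajectories with matching array names: 5 + 3 steps of 2 atoms each.
traj_a = orm.TrajectoryData()
traj_a.set_array("positions", np.zeros((5, 2, 3)))
traj_a.set_array("cells", np.tile(np.eye(3), (5, 1, 1)))

traj_b = orm.TrajectoryData()
traj_b.set_array("positions", np.ones((3, 2, 3)))
traj_b.set_array("cells", np.tile(np.eye(3), (3, 1, 1)))

merged = merge_trajectory_data(traj_a, traj_b)
print(merged.get_array("positions").shape)  # (8, 2, 3): steps concatenated along axis 0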
