diff --git a/aiida_cp2k/calculations/__init__.py b/aiida_cp2k/calculations/__init__.py index 3b0de0b9..afbe54c7 100644 --- a/aiida_cp2k/calculations/__init__.py +++ b/aiida_cp2k/calculations/__init__.py @@ -22,6 +22,7 @@ validate_pseudos_namespace, write_basissets, write_pseudos, + estimate_added_mos, ) from ..utils import Cp2kInput @@ -139,7 +140,7 @@ def prepare_for_submission(self, folder): :return: `aiida.common.datastructures.CalcInfo` instance """ - # pylint: disable=too-many-statements,too-many-branches + # pylint: disable=too-many-statements,too-many-branches,too-many-locals # Create cp2k input file. inp = Cp2kInput(self.inputs.parameters.get_dict()) @@ -167,6 +168,20 @@ def prepare_for_submission(self, folder): self.inputs.structure if 'structure' in self.inputs else None) write_basissets(inp, self.inputs.basissets, folder) + # if we have both basissets and structure we can start helping the user :) + if 'basissets' in self.inputs and 'structure' in self.inputs: + try: + scf_section = inp.get_section_dict('FORCE_EVAL/DFT/SCF') + except (KeyError, NotImplementedError): + pass # if not found or multiple FORCE_EVAL, do nothing (yet) + else: + if 'SMEAR' in scf_section and 'ADDED_MOS' not in scf_section: + # now is our time to shine! + added_mos = estimate_added_mos(self.inputs.basissets, self.inputs.structure) + inp.add_keyword('FORCE_EVAL/DFT/SCF/ADDED_MOS', added_mos) + self.logger.info(f'The FORCE_EVAL/DFT/SCF/ADDED_MOS was added' + f' with an automatically estimated value of {added_mos}') + if 'pseudos' in self.inputs: validate_pseudos(inp, self.inputs.pseudos, self.inputs.structure if 'structure' in self.inputs else None) write_pseudos(inp, self.inputs.pseudos, folder) diff --git a/aiida_cp2k/utils/datatype_helpers.py b/aiida_cp2k/utils/datatype_helpers.py index 8f6e813d..c32bf227 100644 --- a/aiida_cp2k/utils/datatype_helpers.py +++ b/aiida_cp2k/utils/datatype_helpers.py @@ -160,6 +160,31 @@ def validate_basissets(inp, basissets, structure): kind_sec["ELEMENT"] = bset.element +def estimate_added_mos(basissets, structure, fraction=0.3): + """Calculate an estimate for ADDED_MOS based on used basis sets""" + + symbols = [structure.get_kind(s.kind_name).get_symbols_string() for s in structure.sites] + n_mos = 0 + + # We are currently overcounting in the following cases: + # * if we get a mix of ORB basissets for the same chemical symbol but different sites + # * if we get multiple basissets for one element (merged within CP2K) + + for label, bset in _unpack(basissets): + try: + _, bstype = label.split("_", maxsplit=1) + except ValueError: + bstype = "ORB" + + if bstype != "ORB": # ignore non-ORB basissets + continue + + n_mos += symbols.count(bset.element) * bset.n_orbital_functions + + # at least one additional MO per site, otherwise a fraction of the total number of orbital functions + return max(len(symbols), int(fraction * n_mos)) + + def write_basissets(inp, basissets, folder): """Writes the unified BASIS_SETS file with the used basissets""" _write_gdt(inp, basissets, folder, "BASIS_SET_FILE_NAME", "BASIS_SETS") diff --git a/aiida_cp2k/utils/input_generator.py b/aiida_cp2k/utils/input_generator.py index 52c5bb43..5ed145bc 100644 --- a/aiida_cp2k/utils/input_generator.py +++ b/aiida_cp2k/utils/input_generator.py @@ -56,6 +56,74 @@ def add_keyword(self, kwpath, value, override=True, conflicting_keys=None): Cp2kInput._add_keyword(kwpath, value, self._params, ovrd=override, cfct=conflicting_keys) + @staticmethod + def _stringify_path(kwpath): + """Stringify a kwpath argument""" + if isinstance(kwpath, str): + return kwpath + + assert isinstance(kwpath, Sequence), "path is neither Sequence nor String" + return "/".join(kwpath) + + def get_section_dict(self, kwpath=""): + """Get a copy of a section from the current input structure + + Args: + + kwpath: Can be a single keyword, a path with `/` as divider for sections & key, + or a sequence with sections and key. + """ + + section = self._get_section_or_kw(kwpath) + + if not isinstance(section, Mapping): + raise TypeError(f"Section '{self._stringify_path(kwpath)}' requested, but keyword found") + + return deepcopy(section) + + def get_keyword_value(self, kwpath): + """Get the value of a keyword from the current input structure + + Args: + + kwpath: Can be a single keyword, a path with `/` as divider for sections & key, + or a sequence with sections and key. + """ + + keyword = self._get_section_or_kw(kwpath) + + if isinstance(keyword, Mapping): + raise TypeError(f"Keyword '{self._stringify_path(kwpath)}' requested, but section found") + + return keyword + + def _get_section_or_kw(self, kwpath): + """Retrieve either a section or a keyword given a path""" + + if isinstance(kwpath, str): + kwpath = kwpath.split("/") # convert to list of sections if string + + # get a copy of the path in a mutable sequence + # accept any case, but internally we use uppercase + # strip empty strings to accept leading "/", "//", etc. + path = [k.upper() for k in kwpath if k] + + # start with a reference to the root of the parameters + current = self._params + + try: + while path: + current = current[path.pop(0)] + except KeyError: + raise KeyError(f"Section '{self._stringify_path(kwpath)}' not found in parameters") + except TypeError: + if isinstance(current, Sequence) and not isinstance(current, str): + raise NotImplementedError(f"Repeated sections are not yet supported when retrieving data" + f" with path '{self._stringify_path(kwpath)}'") + raise + + return current + def render(self): output = [self.DISCLAIMER] self._render_section(output, deepcopy(self._params)) diff --git a/test/test_gaussian_datatypes.py b/test/test_gaussian_datatypes.py index 4285dc33..d22c2fa6 100644 --- a/test/test_gaussian_datatypes.py +++ b/test/test_gaussian_datatypes.py @@ -778,3 +778,86 @@ def test_without_kinds(cp2k_code, cp2k_basissets, cp2k_pseudos, clear_database): _, calc_node = run_get_node(CalculationFactory("cp2k"), **inputs) assert calc_node.exit_status == 0 + + +def test_added_mos(cp2k_code, cp2k_basissets, cp2k_pseudos, clear_database): # pylint: disable=unused-argument + """Testing CP2K with the Basis Set stored in gaussian.basisset and a smearing section but no predefined ADDED_MOS""" + + structure = StructureData(cell=[[4.00759, 0.0, 0.0], [-2.003795, 3.47067475, 0.0], + [3.06349683e-16, 5.30613216e-16, 5.00307]], + pbc=True) + structure.append_atom(position=(-0.00002004, 2.31379473, 0.87543719), symbols="H") + structure.append_atom(position=(2.00381504, 1.15688001, 4.12763281), symbols="H") + structure.append_atom(position=(2.00381504, 1.15688001, 3.37697219), symbols="H") + structure.append_atom(position=(-0.00002004, 2.31379473, 1.62609781), symbols="H") + + # parameters + parameters = Dict( + dict={ + 'GLOBAL': { + 'RUN_TYPE': 'ENERGY', + }, + 'FORCE_EVAL': { + 'METHOD': 'Quickstep', + 'DFT': { + "XC": { + "XC_FUNCTIONAL": { + "_": "PBE", + }, + }, + "MGRID": { + "CUTOFF": 100.0, + "REL_CUTOFF": 10.0, + }, + "QS": { + "METHOD": "GPW", + "EXTRAPOLATION": "USE_GUESS", + }, + "SCF": { + "EPS_SCF": 1e-05, + "MAX_SCF": 3, + "MIXING": { + "METHOD": "BROYDEN_MIXING", + "ALPHA": 0.4, + }, + "SMEAR": { + "METHOD": "FERMI_DIRAC", + "ELECTRONIC_TEMPERATURE": 300.0, + }, + }, + "KPOINTS": { + "SCHEME": "MONKHORST-PACK 2 2 1", + "FULL_GRID": False, + "SYMMETRY": False, + "PARALLEL_GROUP_SIZE": -1, + }, + }, + }, + }) + + options = { + "resources": { + "num_machines": 1, + "num_mpiprocs_per_machine": 1 + }, + "max_wallclock_seconds": 1 * 3 * 60, + } + + inputs = { + "structure": structure, + "parameters": parameters, + "code": cp2k_code, + "metadata": { + "options": options, + }, + "basissets": {label: b for label, b in cp2k_basissets.items() if label == "H"}, + "pseudos": {label: p for label, p in cp2k_pseudos.items() if label == "H"}, + } + + _, calc_node = run_get_node(CalculationFactory("cp2k"), **inputs) + + assert calc_node.exit_status == 0 + + # check that the ADDED_MOS keyword was added within the calculation + with calc_node.open("aiida.inp") as fhandle: + assert any("ADDED_MOS" in line for line in fhandle), "ADDED_MOS not found in the generated CP2K input file" diff --git a/test/test_input_generator.py b/test/test_input_generator.py index 9f154e14..a8694013 100644 --- a/test/test_input_generator.py +++ b/test/test_input_generator.py @@ -185,3 +185,42 @@ def test_invalid_preprocessor(): inp = Cp2kInput({"@SET": "bar"}) with pytest.raises(ValueError): inp.render() + + +def test_get_keyword_value(): + """Test get_keyword_value()""" + inp = Cp2kInput({"FOO": "bar", "A": {"KW1": "val1"}}) + assert inp.get_keyword_value("FOO") == "bar" + assert inp.get_keyword_value("/FOO") == "bar" + assert inp.get_keyword_value("A/KW1") == "val1" + assert inp.get_keyword_value("/A/KW1") == "val1" + assert inp.get_keyword_value(["A", "KW1"]) == "val1" + with pytest.raises(TypeError): + inp.get_keyword_value("A") + with pytest.raises(KeyError): + inp.get_keyword_value("B") + + +def test_get_section_dict(): + """Test get_section_dict()""" + orig_dict = {"FOO": "bar", "A": {"KW1": "val1"}} + inp = Cp2kInput(orig_dict) + assert inp.get_section_dict("/") == orig_dict + assert inp.get_section_dict("////") == orig_dict + assert inp.get_section_dict("") == orig_dict + assert inp.get_section_dict() == orig_dict + assert inp.get_section_dict("/") is not orig_dict # make sure we get a distinct object + assert inp.get_section_dict("A") == orig_dict["A"] + assert inp.get_section_dict("/A") == orig_dict["A"] + assert inp.get_section_dict(["A"]) == orig_dict["A"] + with pytest.raises(TypeError): + inp.get_section_dict("FOO") + with pytest.raises(KeyError): + inp.get_section_dict("BAR") + + +def test_get_section_dict_repeated(): + """Test NotImplementedError for repeated sections in get_section_dict()""" + inp = Cp2kInput({"FOO": [{"KW1": "val1_1"}, {"KW1": "val1_2"}]}) + with pytest.raises(NotImplementedError): + inp.get_keyword_value("/FOO/KW1")