From 8fb0496b033274fe96d7a5077928add8e96057d5 Mon Sep 17 00:00:00 2001 From: Marnik Bercx Date: Thu, 13 Jan 2022 16:11:19 +0100 Subject: [PATCH] Implement common bands work chain for Quantum ESPRESSO --- .../bands/quantum_espresso/__init__.py | 7 ++ .../bands/quantum_espresso/generator.py | 68 +++++++++++++++++++ .../bands/quantum_espresso/workchain.py | 34 ++++++++++ .../relax/quantum_espresso/generator.py | 8 +++ .../relax/quantum_espresso/workchain.py | 1 + setup.json | 3 +- 6 files changed, 120 insertions(+), 1 deletion(-) create mode 100644 aiida_common_workflows/workflows/bands/quantum_espresso/__init__.py create mode 100644 aiida_common_workflows/workflows/bands/quantum_espresso/generator.py create mode 100644 aiida_common_workflows/workflows/bands/quantum_espresso/workchain.py diff --git a/aiida_common_workflows/workflows/bands/quantum_espresso/__init__.py b/aiida_common_workflows/workflows/bands/quantum_espresso/__init__.py new file mode 100644 index 00000000..1a03f73f --- /dev/null +++ b/aiida_common_workflows/workflows/bands/quantum_espresso/__init__.py @@ -0,0 +1,7 @@ +# -*- coding: utf-8 -*- +# pylint: disable=undefined-variable +"""Module with the implementations of the common bands workchain for Quantum ESPRESSO.""" +from .generator import * +from .workchain import * + +__all__ = (generator.__all__ + workchain.__all__) diff --git a/aiida_common_workflows/workflows/bands/quantum_espresso/generator.py b/aiida_common_workflows/workflows/bands/quantum_espresso/generator.py new file mode 100644 index 00000000..297c664a --- /dev/null +++ b/aiida_common_workflows/workflows/bands/quantum_espresso/generator.py @@ -0,0 +1,68 @@ +# -*- coding: utf-8 -*- +"""Implementation of the ``CommonBandsInputGenerator`` for Quantum ESPRESSO.""" + +from aiida import engine, orm +from aiida.common import LinkType + +from aiida_common_workflows.generators import CodeType + +from ..generator import CommonBandsInputGenerator + +__all__ = ('QuantumEspressoCommonBandsInputGenerator',) + + +class QuantumEspressoCommonBandsInputGenerator(CommonBandsInputGenerator): + """Input generator for the ``QuantumEspressoCommonBandsWorkChain``""" + + @classmethod + def define(cls, spec): + """Define the specification of the input generator. + + The ports defined on the specification are the inputs that will be accepted by the ``get_builder`` method. + """ + super().define(spec) + spec.inputs['engines']['bands']['code'].valid_type = CodeType('quantumespresso.pw') + + def _construct_builder(self, **kwargs) -> engine.ProcessBuilder: + """Construct a process builder based on the provided keyword arguments. + + The keyword arguments will have been validated against the input generator specification. + """ + # pylint: disable=too-many-branches,too-many-statements,too-many-locals + engines = kwargs.get('engines', None) + parent_folder = kwargs['parent_folder'] + bands_kpoints = kwargs['bands_kpoints'] + + # Find the `PwCalculation` that created the `parent_folder` and obtain the restart builder. + parent_calc = parent_folder.get_incoming(link_type=LinkType.CREATE).one().node + if parent_calc.process_type != 'aiida.calculations:quantumespresso.pw': + raise ValueError('The `parent_folder` has not been created by a `PwCalculation`.') + builder_calc = parent_calc.get_builder_restart() + + builder_common_bands_wc = self.process_class.get_builder() + builder_calc.pop('kpoints') + builder_common_bands_wc.pw = builder_calc + parameters = builder_common_bands_wc.pw.parameters.get_dict() + parameters['CONTROL']['calculation'] = 'bands' + builder_common_bands_wc.pw.parameters = orm.Dict(dict=parameters) + builder_common_bands_wc.kpoints = bands_kpoints + builder_common_bands_wc.pw.parent_folder = parent_folder + + # Update the structure in case we have one in output, i.e. the `parent_calc` optimized the structure + if 'output_structure' in parent_calc.outputs: + builder_common_bands_wc.pw.structure = parent_calc.outputs.output_structure + + # Update the code and computational options if `engines` is specified + try: + bands_engine = engines['bands'] + except KeyError: + raise ValueError('The `engines` dictionary must contain `bands` as a top-level key') + if 'code' in bands_engine: + code = engines['bands']['code'] + if isinstance(code, str): + code = orm.load_code(code) + builder_common_bands_wc.pw.code = code + if 'options' in bands_engine: + builder_common_bands_wc.pw.metadata.options = engines['bands']['options'] + + return builder_common_bands_wc diff --git a/aiida_common_workflows/workflows/bands/quantum_espresso/workchain.py b/aiida_common_workflows/workflows/bands/quantum_espresso/workchain.py new file mode 100644 index 00000000..7826e019 --- /dev/null +++ b/aiida_common_workflows/workflows/bands/quantum_espresso/workchain.py @@ -0,0 +1,34 @@ +# -*- coding: utf-8 -*- +"""Implementation of the ``CommonBandsWorkChain`` for Quantum ESPRESSO.""" +from aiida.engine import calcfunction +from aiida.orm import Float +from aiida.plugins import WorkflowFactory + +from ..workchain import CommonBandsWorkChain +from .generator import QuantumEspressoCommonBandsInputGenerator + +__all__ = ('QuantumEspressoCommonBandsWorkChain',) + + +@calcfunction +def get_fermi_energy(output_parameters): + """Extract the Fermi energy from the ``output_parameters`` of a ``PwBaseWorkChain``.""" + return Float(output_parameters['fermi_energy']) + + +class QuantumEspressoCommonBandsWorkChain(CommonBandsWorkChain): + """Implementation of the ``CommonBandsWorkChain`` for Quantum ESPRESSO.""" + + _process_class = WorkflowFactory('quantumespresso.pw.base') + _generator_class = QuantumEspressoCommonBandsInputGenerator + + def convert_outputs(self): + """Convert the outputs of the sub work chain to the common output specification.""" + outputs = self.ctx.workchain.outputs + + if 'output_band' not in outputs: + self.report('The `bands` PwBaseWorkChain does not have the `output_band` output.') + return self.exit_codes.ERROR_SUB_PROCESS_FAILED + + self.out('bands', outputs.output_band) + self.out('fermi_energy', get_fermi_energy(outputs.output_parameters)) diff --git a/aiida_common_workflows/workflows/relax/quantum_espresso/generator.py b/aiida_common_workflows/workflows/relax/quantum_espresso/generator.py index 8a96df49..72578088 100644 --- a/aiida_common_workflows/workflows/relax/quantum_espresso/generator.py +++ b/aiida_common_workflows/workflows/relax/quantum_espresso/generator.py @@ -91,6 +91,12 @@ def define(cls, spec): ) spec.inputs['electronic_type'].valid_type = ChoiceType((ElectronicType.METAL, ElectronicType.INSULATOR)) spec.inputs['engines']['relax']['code'].valid_type = CodeType('quantumespresso.pw') + spec.input( + 'clean_workdir', + valid_type=orm.Bool, + default=lambda: orm.Bool(False), + help='If `True`, work directories of all called calculation will be cleaned at the end of execution.' + ) def _construct_builder(self, **kwargs) -> engine.ProcessBuilder: """Construct a process builder based on the provided keyword arguments. @@ -111,6 +117,7 @@ def _construct_builder(self, **kwargs) -> engine.ProcessBuilder: threshold_forces = kwargs.get('threshold_forces', None) threshold_stress = kwargs.get('threshold_stress', None) reference_workchain = kwargs.get('reference_workchain', None) + clean_workdir = kwargs.get('clean_workdir') if isinstance(electronic_type, str): electronic_type = types.ElectronicType(electronic_type) @@ -162,6 +169,7 @@ def _construct_builder(self, **kwargs) -> engine.ProcessBuilder: spin_type=spin_type, initial_magnetic_moments=initial_magnetic_moments, ) + builder.clean_workdir = clean_workdir if threshold_forces is not None: threshold = threshold_forces * CONSTANTS.bohr_to_ang / CONSTANTS.ry_to_ev diff --git a/aiida_common_workflows/workflows/relax/quantum_espresso/workchain.py b/aiida_common_workflows/workflows/relax/quantum_espresso/workchain.py index 09bd4da4..ed275c34 100644 --- a/aiida_common_workflows/workflows/relax/quantum_espresso/workchain.py +++ b/aiida_common_workflows/workflows/relax/quantum_espresso/workchain.py @@ -63,3 +63,4 @@ def convert_outputs(self): self.out('total_energy', total_energy) self.out('forces', forces) self.out('stress', stress) + self.out('remote_folder', outputs.remote_folder) diff --git a/setup.json b/setup.json index 45fbfbea..8a100749 100644 --- a/setup.json +++ b/setup.json @@ -72,7 +72,8 @@ "common_workflows.relax.quantum_espresso = aiida_common_workflows.workflows.relax.quantum_espresso.workchain:QuantumEspressoCommonRelaxWorkChain", "common_workflows.relax.siesta = aiida_common_workflows.workflows.relax.siesta.workchain:SiestaCommonRelaxWorkChain", "common_workflows.relax.vasp = aiida_common_workflows.workflows.relax.vasp.workchain:VaspCommonRelaxWorkChain", - "common_workflows.bands.siesta = aiida_common_workflows.workflows.bands.siesta.workchain:SiestaCommonBandsWorkChain" + "common_workflows.bands.siesta = aiida_common_workflows.workflows.bands.siesta.workchain:SiestaCommonBandsWorkChain", + "common_workflows.bands.quantum_espresso = aiida_common_workflows.workflows.bands.quantum_espresso.workchain:QuantumEspressoCommonBandsWorkChain" ] }, "license": "MIT License",