Skip to content

Commit

Permalink
Implement common bands work chain for Quantum ESPRESSO
Browse files Browse the repository at this point in the history
  • Loading branch information
mbercx committed Jan 13, 2022
1 parent 69468a7 commit 8fb0496
Show file tree
Hide file tree
Showing 6 changed files with 120 additions and 1 deletion.
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# -*- coding: utf-8 -*-
# pylint: disable=undefined-variable
"""Module with the implementations of the common bands workchain for Quantum ESPRESSO."""
from .generator import *
from .workchain import *

__all__ = (generator.__all__ + workchain.__all__)
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# -*- coding: utf-8 -*-
"""Implementation of the ``CommonBandsInputGenerator`` for Quantum ESPRESSO."""

from aiida import engine, orm
from aiida.common import LinkType

from aiida_common_workflows.generators import CodeType

from ..generator import CommonBandsInputGenerator

__all__ = ('QuantumEspressoCommonBandsInputGenerator',)


class QuantumEspressoCommonBandsInputGenerator(CommonBandsInputGenerator):
"""Input generator for the ``QuantumEspressoCommonBandsWorkChain``"""

@classmethod
def define(cls, spec):
"""Define the specification of the input generator.
The ports defined on the specification are the inputs that will be accepted by the ``get_builder`` method.
"""
super().define(spec)
spec.inputs['engines']['bands']['code'].valid_type = CodeType('quantumespresso.pw')

def _construct_builder(self, **kwargs) -> engine.ProcessBuilder:
"""Construct a process builder based on the provided keyword arguments.
The keyword arguments will have been validated against the input generator specification.
"""
# pylint: disable=too-many-branches,too-many-statements,too-many-locals
engines = kwargs.get('engines', None)
parent_folder = kwargs['parent_folder']
bands_kpoints = kwargs['bands_kpoints']

# Find the `PwCalculation` that created the `parent_folder` and obtain the restart builder.
parent_calc = parent_folder.get_incoming(link_type=LinkType.CREATE).one().node
if parent_calc.process_type != 'aiida.calculations:quantumespresso.pw':
raise ValueError('The `parent_folder` has not been created by a `PwCalculation`.')
builder_calc = parent_calc.get_builder_restart()

builder_common_bands_wc = self.process_class.get_builder()
builder_calc.pop('kpoints')
builder_common_bands_wc.pw = builder_calc
parameters = builder_common_bands_wc.pw.parameters.get_dict()
parameters['CONTROL']['calculation'] = 'bands'
builder_common_bands_wc.pw.parameters = orm.Dict(dict=parameters)
builder_common_bands_wc.kpoints = bands_kpoints
builder_common_bands_wc.pw.parent_folder = parent_folder

# Update the structure in case we have one in output, i.e. the `parent_calc` optimized the structure
if 'output_structure' in parent_calc.outputs:
builder_common_bands_wc.pw.structure = parent_calc.outputs.output_structure

# Update the code and computational options if `engines` is specified
try:
bands_engine = engines['bands']
except KeyError:
raise ValueError('The `engines` dictionary must contain `bands` as a top-level key')
if 'code' in bands_engine:
code = engines['bands']['code']
if isinstance(code, str):
code = orm.load_code(code)
builder_common_bands_wc.pw.code = code
if 'options' in bands_engine:
builder_common_bands_wc.pw.metadata.options = engines['bands']['options']

return builder_common_bands_wc
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# -*- coding: utf-8 -*-
"""Implementation of the ``CommonBandsWorkChain`` for Quantum ESPRESSO."""
from aiida.engine import calcfunction
from aiida.orm import Float
from aiida.plugins import WorkflowFactory

from ..workchain import CommonBandsWorkChain
from .generator import QuantumEspressoCommonBandsInputGenerator

__all__ = ('QuantumEspressoCommonBandsWorkChain',)


@calcfunction
def get_fermi_energy(output_parameters):
"""Extract the Fermi energy from the ``output_parameters`` of a ``PwBaseWorkChain``."""
return Float(output_parameters['fermi_energy'])


class QuantumEspressoCommonBandsWorkChain(CommonBandsWorkChain):
"""Implementation of the ``CommonBandsWorkChain`` for Quantum ESPRESSO."""

_process_class = WorkflowFactory('quantumespresso.pw.base')
_generator_class = QuantumEspressoCommonBandsInputGenerator

def convert_outputs(self):
"""Convert the outputs of the sub work chain to the common output specification."""
outputs = self.ctx.workchain.outputs

if 'output_band' not in outputs:
self.report('The `bands` PwBaseWorkChain does not have the `output_band` output.')
return self.exit_codes.ERROR_SUB_PROCESS_FAILED

self.out('bands', outputs.output_band)
self.out('fermi_energy', get_fermi_energy(outputs.output_parameters))
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,12 @@ def define(cls, spec):
)
spec.inputs['electronic_type'].valid_type = ChoiceType((ElectronicType.METAL, ElectronicType.INSULATOR))
spec.inputs['engines']['relax']['code'].valid_type = CodeType('quantumespresso.pw')
spec.input(
'clean_workdir',
valid_type=orm.Bool,
default=lambda: orm.Bool(False),
help='If `True`, work directories of all called calculation will be cleaned at the end of execution.'
)

def _construct_builder(self, **kwargs) -> engine.ProcessBuilder:
"""Construct a process builder based on the provided keyword arguments.
Expand All @@ -111,6 +117,7 @@ def _construct_builder(self, **kwargs) -> engine.ProcessBuilder:
threshold_forces = kwargs.get('threshold_forces', None)
threshold_stress = kwargs.get('threshold_stress', None)
reference_workchain = kwargs.get('reference_workchain', None)
clean_workdir = kwargs.get('clean_workdir')

if isinstance(electronic_type, str):
electronic_type = types.ElectronicType(electronic_type)
Expand Down Expand Up @@ -162,6 +169,7 @@ def _construct_builder(self, **kwargs) -> engine.ProcessBuilder:
spin_type=spin_type,
initial_magnetic_moments=initial_magnetic_moments,
)
builder.clean_workdir = clean_workdir

if threshold_forces is not None:
threshold = threshold_forces * CONSTANTS.bohr_to_ang / CONSTANTS.ry_to_ev
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,4 @@ def convert_outputs(self):
self.out('total_energy', total_energy)
self.out('forces', forces)
self.out('stress', stress)
self.out('remote_folder', outputs.remote_folder)
3 changes: 2 additions & 1 deletion setup.json
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,8 @@
"common_workflows.relax.quantum_espresso = aiida_common_workflows.workflows.relax.quantum_espresso.workchain:QuantumEspressoCommonRelaxWorkChain",
"common_workflows.relax.siesta = aiida_common_workflows.workflows.relax.siesta.workchain:SiestaCommonRelaxWorkChain",
"common_workflows.relax.vasp = aiida_common_workflows.workflows.relax.vasp.workchain:VaspCommonRelaxWorkChain",
"common_workflows.bands.siesta = aiida_common_workflows.workflows.bands.siesta.workchain:SiestaCommonBandsWorkChain"
"common_workflows.bands.siesta = aiida_common_workflows.workflows.bands.siesta.workchain:SiestaCommonBandsWorkChain",
"common_workflows.bands.quantum_espresso = aiida_common_workflows.workflows.bands.quantum_espresso.workchain:QuantumEspressoCommonBandsWorkChain"
]
},
"license": "MIT License",
Expand Down

0 comments on commit 8fb0496

Please sign in to comment.