-
Notifications
You must be signed in to change notification settings - Fork 32
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement common bands work chain for Quantum ESPRESSO
- Loading branch information
Showing
6 changed files
with
120 additions
and
1 deletion.
There are no files selected for viewing
7 changes: 7 additions & 0 deletions
7
aiida_common_workflows/workflows/bands/quantum_espresso/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
# -*- coding: utf-8 -*- | ||
# pylint: disable=undefined-variable | ||
"""Module with the implementations of the common bands workchain for Quantum ESPRESSO.""" | ||
from .generator import * | ||
from .workchain import * | ||
|
||
__all__ = (generator.__all__ + workchain.__all__) |
68 changes: 68 additions & 0 deletions
68
aiida_common_workflows/workflows/bands/quantum_espresso/generator.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
# -*- coding: utf-8 -*- | ||
"""Implementation of the ``CommonBandsInputGenerator`` for Quantum ESPRESSO.""" | ||
|
||
from aiida import engine, orm | ||
from aiida.common import LinkType | ||
|
||
from aiida_common_workflows.generators import CodeType | ||
|
||
from ..generator import CommonBandsInputGenerator | ||
|
||
__all__ = ('QuantumEspressoCommonBandsInputGenerator',) | ||
|
||
|
||
class QuantumEspressoCommonBandsInputGenerator(CommonBandsInputGenerator): | ||
"""Input generator for the ``QuantumEspressoCommonBandsWorkChain``""" | ||
|
||
@classmethod | ||
def define(cls, spec): | ||
"""Define the specification of the input generator. | ||
The ports defined on the specification are the inputs that will be accepted by the ``get_builder`` method. | ||
""" | ||
super().define(spec) | ||
spec.inputs['engines']['bands']['code'].valid_type = CodeType('quantumespresso.pw') | ||
|
||
def _construct_builder(self, **kwargs) -> engine.ProcessBuilder: | ||
"""Construct a process builder based on the provided keyword arguments. | ||
The keyword arguments will have been validated against the input generator specification. | ||
""" | ||
# pylint: disable=too-many-branches,too-many-statements,too-many-locals | ||
engines = kwargs.get('engines', None) | ||
parent_folder = kwargs['parent_folder'] | ||
bands_kpoints = kwargs['bands_kpoints'] | ||
|
||
# Find the `PwCalculation` that created the `parent_folder` and obtain the restart builder. | ||
parent_calc = parent_folder.get_incoming(link_type=LinkType.CREATE).one().node | ||
if parent_calc.process_type != 'aiida.calculations:quantumespresso.pw': | ||
raise ValueError('The `parent_folder` has not been created by a `PwCalculation`.') | ||
builder_calc = parent_calc.get_builder_restart() | ||
|
||
builder_common_bands_wc = self.process_class.get_builder() | ||
builder_calc.pop('kpoints') | ||
builder_common_bands_wc.pw = builder_calc | ||
parameters = builder_common_bands_wc.pw.parameters.get_dict() | ||
parameters['CONTROL']['calculation'] = 'bands' | ||
builder_common_bands_wc.pw.parameters = orm.Dict(dict=parameters) | ||
builder_common_bands_wc.kpoints = bands_kpoints | ||
builder_common_bands_wc.pw.parent_folder = parent_folder | ||
|
||
# Update the structure in case we have one in output, i.e. the `parent_calc` optimized the structure | ||
if 'output_structure' in parent_calc.outputs: | ||
builder_common_bands_wc.pw.structure = parent_calc.outputs.output_structure | ||
|
||
# Update the code and computational options if `engines` is specified | ||
try: | ||
bands_engine = engines['bands'] | ||
except KeyError: | ||
raise ValueError('The `engines` dictionary must contain `bands` as a top-level key') | ||
if 'code' in bands_engine: | ||
code = engines['bands']['code'] | ||
if isinstance(code, str): | ||
code = orm.load_code(code) | ||
builder_common_bands_wc.pw.code = code | ||
if 'options' in bands_engine: | ||
builder_common_bands_wc.pw.metadata.options = engines['bands']['options'] | ||
|
||
return builder_common_bands_wc |
34 changes: 34 additions & 0 deletions
34
aiida_common_workflows/workflows/bands/quantum_espresso/workchain.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
# -*- coding: utf-8 -*- | ||
"""Implementation of the ``CommonBandsWorkChain`` for Quantum ESPRESSO.""" | ||
from aiida.engine import calcfunction | ||
from aiida.orm import Float | ||
from aiida.plugins import WorkflowFactory | ||
|
||
from ..workchain import CommonBandsWorkChain | ||
from .generator import QuantumEspressoCommonBandsInputGenerator | ||
|
||
__all__ = ('QuantumEspressoCommonBandsWorkChain',) | ||
|
||
|
||
@calcfunction | ||
def get_fermi_energy(output_parameters): | ||
"""Extract the Fermi energy from the ``output_parameters`` of a ``PwBaseWorkChain``.""" | ||
return Float(output_parameters['fermi_energy']) | ||
|
||
|
||
class QuantumEspressoCommonBandsWorkChain(CommonBandsWorkChain): | ||
"""Implementation of the ``CommonBandsWorkChain`` for Quantum ESPRESSO.""" | ||
|
||
_process_class = WorkflowFactory('quantumespresso.pw.base') | ||
_generator_class = QuantumEspressoCommonBandsInputGenerator | ||
|
||
def convert_outputs(self): | ||
"""Convert the outputs of the sub work chain to the common output specification.""" | ||
outputs = self.ctx.workchain.outputs | ||
|
||
if 'output_band' not in outputs: | ||
self.report('The `bands` PwBaseWorkChain does not have the `output_band` output.') | ||
return self.exit_codes.ERROR_SUB_PROCESS_FAILED | ||
|
||
self.out('bands', outputs.output_band) | ||
self.out('fermi_energy', get_fermi_energy(outputs.output_parameters)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters