# Source code for the `aiida_wannier90_workflows.workflows.open_grid` module.

"""Wannierisation workflow using open_grid.x to bypass the nscf step."""
import pathlib
import typing as ty

from aiida import orm
from aiida.common import AttributeDict
from aiida.engine.processes import ProcessBuilder, ToContext, if_

from aiida_quantumespresso.utils.mapping import prepare_process_inputs

from .base.open_grid import OpenGridBaseWorkChain
from .wannier90 import Wannier90WorkChain

# Names exported by `from ... import *`; kept as a tuple.
__all__ = (
    "validate_inputs",
    "Wannier90OpenGridWorkChain",
)


def validate_inputs(  # pylint: disable=unused-argument
    inputs, ctx=None
):
    """Validate the top-level input namespace of `Wannier90OpenGridWorkChain`.

    Validation is delegated entirely to the validator of the parent
    `Wannier90WorkChain`; no additional constraints are checked here.

    :param inputs: the input namespace to validate.
    :param ctx: port-namespace context; accepted to satisfy the validator
        signature but not forwarded to the parent validator.
    :return: ``None`` if the parent validator returns ``None``, otherwise the
        parent validator's return value (presumably an error message).
    """
    from .wannier90 import validate_inputs as _validate_wannier90_inputs

    # The parent result is propagated as-is: ``None`` means "valid".
    return _validate_wannier90_inputs(inputs)


class Wannier90OpenGridWorkChain(Wannier90WorkChain):
    """WorkChain using open_grid.x to bypass the nscf step.

    The open_grid.x unfolds the symmetrized kmesh to a full kmesh, thus the
    full-kmesh nscf step can be avoided.

    2 schemes:

    1. scf w/ symmetry, more nbnd -> open_grid -> pw2wannier90 -> wannier90
    2. scf w/ symmetry, default nbnd -> nscf w/ symm, more nbnd -> open_grid
       -> pw2wannier90 -> wannier90
    """

    @classmethod
    def define(cls, spec):
        """Define the process spec.

        On top of the `Wannier90WorkChain` spec this adds:

        - an optional ``open_grid`` input namespace exposed from `OpenGridBaseWorkChain`;
        - an outline that runs the open_grid step after (n)scf and before projwfc;
        - an optional ``open_grid`` output namespace;
        - exit code 490 for a failed `OpenGridBaseWorkChain`.
        """
        super().define(spec)

        spec.expose_inputs(
            OpenGridBaseWorkChain,
            namespace="open_grid",
            # `open_grid.parent_folder` is filled in at runtime by `run_open_grid`
            # from the remote folder of the preceding (n)scf calculation.
            exclude=("clean_workdir", "open_grid.parent_folder"),
            namespace_options={
                "required": False,
                "populate_defaults": False,
                "help": "Inputs for the `OpenGridBaseWorkChain`, if not specified the open_grid step is skipped.",
            },
        )
        spec.inputs.validator = validate_inputs

        spec.outline(
            cls.setup,
            if_(cls.should_run_scf)(
                cls.run_scf,
                cls.inspect_scf,
            ),
            if_(cls.should_run_nscf)(
                cls.run_nscf,
                cls.inspect_nscf,
            ),
            if_(cls.should_run_open_grid)(
                cls.run_open_grid,
                cls.inspect_open_grid,
            ),
            if_(cls.should_run_projwfc)(
                cls.run_projwfc,
                cls.inspect_projwfc,
            ),
            cls.run_wannier90_pp,
            cls.inspect_wannier90_pp,
            cls.run_pw2wannier90,
            cls.inspect_pw2wannier90,
            cls.run_wannier90,
            cls.inspect_wannier90,
            cls.results,
        )

        spec.expose_outputs(
            OpenGridBaseWorkChain,
            namespace="open_grid",
            namespace_options={"required": False},
        )

        spec.exit_code(
            490,
            "ERROR_SUB_PROCESS_FAILED_OPEN_GRID",
            message="the OpenGridBaseWorkChain sub process failed",
        )

    @classmethod
    def get_protocol_filepath(cls) -> pathlib.Path:
        """Return ``pathlib.Path`` to the ``.yaml`` file that defines the protocols.

        The file is ``open_grid.yaml`` inside the sibling ``protocols`` package.
        """
        from importlib_resources import files

        from . import protocols

        return files(protocols) / "open_grid.yaml"

    @classmethod
    def get_builder_from_protocol(  # pylint: disable=arguments-differ
        cls,
        codes: ty.Mapping[str, ty.Union[str, int, orm.Code]],
        structure: orm.StructureData,
        *,
        open_grid_only_scf: bool = True,
        **kwargs,
    ) -> ProcessBuilder:
        """Return a builder populated with predefined inputs that can be directly submitted.

        Optional keyword arguments are passed to the same function of `Wannier90WorkChain`.
        Overrides `Wannier90WorkChain` workchain.

        ``summary`` and ``print_summary`` are consumed here. ``protocol`` and
        ``overrides`` are read here and also forwarded, together with all other
        keyword arguments, to `Wannier90WorkChain.get_builder_from_protocol`.
        The ``open_grid`` sub-dictionary of ``overrides`` (if any) is used for the
        `OpenGridBaseWorkChain` builder. ``overrides`` may be omitted or ``None``.

        :param codes: mapping from code labels to codes (label, pk or `orm.Code`).
            It must contain an ``open_grid`` entry in addition to the codes
            required by `Wannier90WorkChain.get_builder_from_protocol`.
        :type codes: ty.Mapping[str, ty.Union[str, int, orm.Code]]
        :param structure: the structure, forwarded to the parent builder factory.
        :param open_grid_only_scf: If True, first do a scf with symmetry and an
            increased number of bands, then launch open_grid.x to unfold the kmesh.
            If False, first do a scf with symmetry and the default number of bands,
            then a nscf with symmetry and an increased number of bands, followed
            by open_grid.x.
        :type open_grid_only_scf: bool
        :return: the populated builder.
        """
        from aiida_wannier90_workflows.utils.workflows.builder.submit import (
            recursive_merge_builder,
        )

        summary = kwargs.pop("summary", {})
        print_summary = kwargs.pop("print_summary", True)
        protocol = kwargs.get("protocol", None)
        overrides = kwargs.get("overrides", None)

        # Prepare workchain builder. Use `cls` (not the hard-coded class name) so
        # that subclasses get their own builder and protocol file.
        builder = cls.get_builder()
        inputs = cls.get_protocol_inputs(protocol=protocol, overrides=overrides)
        builder = recursive_merge_builder(builder, inputs)

        parent_builder = super().get_builder_from_protocol(
            codes, structure, summary=summary, print_summary=False, **kwargs
        )
        inputs = parent_builder._inputs(prune=True)  # pylint: disable=protected-access
        builder = recursive_merge_builder(builder, inputs)

        # Adapt pw.x parameters
        if open_grid_only_scf:
            # Scheme 1: the scf itself computes the extra bands that the nscf would
            # have computed, and the nscf step is dropped.
            nbnd = (
                builder.nscf["pw"]["parameters"].get_dict()["SYSTEM"].get("nbnd", None)
            )
            params = builder.scf["pw"]["parameters"].get_dict()
            if nbnd is not None:
                params["SYSTEM"]["nbnd"] = nbnd
            # Keep symmetry on: open_grid.x unfolds the symmetrized kmesh afterwards.
            params["SYSTEM"].pop("nosym", None)
            params["SYSTEM"].pop("noinv", None)
            params.setdefault("ELECTRONS", {})["diago_full_acc"] = True
            builder.scf["pw"]["parameters"] = orm.Dict(params)
            builder.nscf.clear()
        else:
            # Scheme 2: a symmetrized nscf on the same automatic kmesh as the scf.
            params = builder.nscf["pw"]["parameters"].get_dict()
            params["SYSTEM"].pop("nosym", None)
            params["SYSTEM"].pop("noinv", None)
            builder.nscf["pw"]["parameters"] = orm.Dict(params)
            builder.nscf.pop("kpoints", None)
            builder.nscf["kpoints_distance"] = builder.scf["kpoints_distance"]
            builder.nscf["kpoints_force_parity"] = builder.scf["kpoints_force_parity"]

        # Prepare open_grid. `overrides` may be missing or explicitly None.
        open_grid_overrides = (overrides or {}).get("open_grid", {})
        open_grid_builder = OpenGridBaseWorkChain.get_builder_from_protocol(
            code=codes["open_grid"],
            protocol=protocol,
            overrides=open_grid_overrides,
        )
        # Remove workchain excluded inputs
        open_grid_builder.pop("clean_workdir", None)
        builder.open_grid = open_grid_builder._inputs(  # pylint: disable=protected-access
            prune=True
        )

        if print_summary:
            cls.print_summary(summary)

        return builder

    def should_run_open_grid(self):
        """If the 'open_grid' input namespace was specified, we run open_grid after scf or nscf calculation."""
        return "open_grid" in self.inputs

    def run_open_grid(self):
        """Use QE open_grid.x to unfold the irreducible kmesh to a full kmesh.

        The ``parent_folder`` of the open_grid calculation is set to
        ``self.ctx.current_folder``, i.e. the remote folder of the last (n)scf.
        """
        inputs = AttributeDict(
            self.exposed_inputs(OpenGridBaseWorkChain, namespace="open_grid")
        )
        inputs.open_grid.parent_folder = self.ctx.current_folder
        inputs.metadata.call_link_label = "open_grid"

        inputs = prepare_process_inputs(OpenGridBaseWorkChain, inputs)
        running = self.submit(OpenGridBaseWorkChain, **inputs)
        self.report(f"launching {running.process_label}<{running.pk}>")

        return ToContext(workchain_open_grid=running)

    def inspect_open_grid(self):  # pylint: disable=inconsistent-return-statements
        """Verify that the `OpenGridBaseWorkChain` run successfully finished.

        On success the open_grid remote folder becomes ``self.ctx.current_folder``
        for the following steps. On failure it returns
        ``ERROR_SUB_PROCESS_FAILED_OPEN_GRID``.
        """
        workchain = self.ctx.workchain_open_grid

        if not workchain.is_finished_ok:
            self.report(
                f"{workchain.process_label} failed with exit status {workchain.exit_status}"
            )
            return self.exit_codes.ERROR_SUB_PROCESS_FAILED_OPEN_GRID

        self.ctx.current_folder = workchain.outputs.remote_folder

    def prepare_wannier90_pp_inputs(self):
        """Override the parent method in `Wannier90WorkChain`.

        The wannier input kpoints are set as the parsed output from `OpenGridBaseWorkChain`,
        and ``mp_grid`` is set to the mesh of the ``kpoints_mesh`` output.
        """
        base_inputs = super().prepare_wannier90_pp_inputs()
        inputs = base_inputs["wannier90"]

        if self.should_run_open_grid():
            open_grid_outputs = self.ctx.workchain_open_grid.outputs
            inputs.kpoints = open_grid_outputs.kpoints
            parameters = inputs.parameters.get_dict()
            parameters["mp_grid"] = open_grid_outputs.kpoints_mesh.get_kpoints_mesh()[0]
            inputs.parameters = orm.Dict(parameters)

        base_inputs["wannier90"] = inputs

        return base_inputs

    def results(self):
        """Override parent workchain.

        Attach the exposed `OpenGridBaseWorkChain` outputs (if that step ran) under
        the ``open_grid`` namespace, then defer to the parent ``results``.
        """
        if self.should_run_open_grid():
            self.out_many(
                self.exposed_outputs(
                    self.ctx.workchain_open_grid,
                    OpenGridBaseWorkChain,
                    namespace="open_grid",
                )
            )

        super().results()