# Source code for aiida_wannier90_workflows.workflows.wannier90

"""Base class for Wannierisation workflow."""
import pathlib
import typing as ty

from aiida import orm
from aiida.common import AttributeDict
from aiida.common.lang import type_check
from aiida.engine.processes import ProcessBuilder, ToContext, WorkChain, if_
from aiida.orm.nodes.data.base import to_aiida_type

from aiida_quantumespresso.calculations.pw import PwCalculation
from aiida_quantumespresso.common.types import ElectronicType, SpinType
from aiida_quantumespresso.utils.mapping import prepare_process_inputs
from aiida_quantumespresso.workflows.protocols.utils import ProtocolMixin
from aiida_quantumespresso.workflows.pw.base import PwBaseWorkChain

from aiida_wannier90_workflows.common.types import (
    WannierDisentanglementType,
    WannierFrozenType,
    WannierProjectionType,
)

from .base.projwfc import ProjwfcBaseWorkChain
from .base.pw2wannier90 import Pw2wannier90BaseWorkChain
from .base.wannier90 import Wannier90BaseWorkChain

__all__ = ["validate_inputs", "Wannier90WorkChain"]


def validate_inputs(  # pylint: disable=unused-argument,inconsistent-return-statements
    inputs, ctx=None
):
    """Validate the inputs of the entire input namespace of `Wannier90WorkChain`.

    :param inputs: mapping holding the top-level input namespaces (``scf``, ``nscf``,
        ``pw2wannier90``, ``wannier90``, ...).
    :param ctx: not used by this validator.
    :return: an error message if the inputs are inconsistent, otherwise ``None``.
    """
    # If no scf inputs, the nscf must have a `parent_folder`.
    # The `nscf` namespace is optional as well, so only inspect it when it is present
    # instead of failing with a `KeyError` inside the validator.
    if "scf" not in inputs and "nscf" in inputs:
        if "parent_folder" not in inputs["nscf"]["pw"]:
            return "If skipping scf step, nscf inputs must have a `parent_folder`"

    # Cannot specify both `auto_energy_windows` and `scdm_proj`
    pw2wannier_parameters = inputs["pw2wannier90"]["pw2wannier90"][
        "parameters"
    ].get_dict()
    auto_energy_windows = inputs["wannier90"].get("auto_energy_windows", False)
    # A missing `inputpp` namelist simply means that SCDM was not requested.
    scdm_proj = pw2wannier_parameters.get("inputpp", {}).get("scdm_proj", False)
    if auto_energy_windows and scdm_proj:
        return "`auto_energy_windows` is incompatible with SCDM"

    # Cannot specify both `auto_energy_windows` and `shift_energy_windows`
    shift_energy_windows = inputs["wannier90"].get("shift_energy_windows", False)
    if auto_energy_windows and shift_energy_windows:
        return "`auto_energy_windows` and `shift_energy_windows` are incompatible"


# pylint: disable=fixme,too-many-lines
class Wannier90WorkChain(
    ProtocolMixin, WorkChain
):  # pylint: disable=too-many-public-methods
    """Workchain to obtain maximally localised Wannier functions (MLWF).

    Run the following steps:
        scf -> nscf -> projwfc -> wannier90 postproc -> pw2wannier90 -> wannier90
    """

    @classmethod
    def define(cls, spec):
        """Define the process spec: inputs, outline, outputs and exit codes."""
        from .base.pw2wannier90 import (
            validate_inputs_base as validate_inputs_base_pw2wannier90,
        )
        from .base.wannier90 import (
            validate_inputs_base as validate_inputs_base_wannier90,
        )

        super().define(spec)

        spec.input(
            "structure", valid_type=orm.StructureData, help="The input structure."
        )
        spec.input(
            "clean_workdir",
            valid_type=orm.Bool,
            serializer=to_aiida_type,
            default=lambda: orm.Bool(False),
            help=(
                "If True, work directories of all called calculation will be cleaned "
                "at the end of execution."
            ),
        )
        # The scf, nscf and projwfc namespaces are optional: the corresponding steps
        # are only run when the namespace is present in the inputs.
        spec.expose_inputs(
            PwBaseWorkChain,
            namespace="scf",
            exclude=("clean_workdir", "pw.structure"),
            namespace_options={
                "required": False,
                "populate_defaults": False,
                "help": "Inputs for the `PwBaseWorkChain` for the SCF calculation.",
            },
        )
        spec.expose_inputs(
            PwBaseWorkChain,
            namespace="nscf",
            exclude=("clean_workdir", "pw.structure"),
            namespace_options={
                "required": False,
                "populate_defaults": False,
                "help": "Inputs for the `PwBaseWorkChain` for the NSCF calculation.",
            },
        )
        spec.inputs["nscf"]["pw"].validator = PwCalculation.validate_inputs_base
        spec.expose_inputs(
            ProjwfcBaseWorkChain,
            namespace="projwfc",
            exclude=("clean_workdir", "projwfc.parent_folder"),
            namespace_options={
                "required": False,
                "populate_defaults": False,
                "help": "Inputs for the `ProjwfcBaseWorkChain`.",
            },
        )
        spec.expose_inputs(
            Pw2wannier90BaseWorkChain,
            namespace="pw2wannier90",
            exclude=(
                "clean_workdir",
                "pw2wannier90.parent_folder",
                "pw2wannier90.nnkp_file",
            ),
            namespace_options={"help": "Inputs for the `Pw2wannier90BaseWorkChain`."},
        )
        spec.inputs["pw2wannier90"].validator = validate_inputs_base_pw2wannier90
        spec.expose_inputs(
            Wannier90BaseWorkChain,
            namespace="wannier90",
            exclude=("clean_workdir", "wannier90.structure"),
            namespace_options={"help": "Inputs for the `Wannier90BaseWorkChain`."},
        )
        spec.inputs["wannier90"].validator = validate_inputs_base_wannier90

        spec.inputs.validator = validate_inputs

        spec.outline(
            cls.setup,
            if_(cls.should_run_scf)(
                cls.run_scf,
                cls.inspect_scf,
            ),
            if_(cls.should_run_nscf)(
                cls.run_nscf,
                cls.inspect_nscf,
            ),
            if_(cls.should_run_projwfc)(
                cls.run_projwfc,
                cls.inspect_projwfc,
            ),
            cls.run_wannier90_pp,
            cls.inspect_wannier90_pp,
            cls.run_pw2wannier90,
            cls.inspect_pw2wannier90,
            cls.run_wannier90,
            cls.inspect_wannier90,
            cls.results,
        )

        spec.expose_outputs(
            PwBaseWorkChain, namespace="scf", namespace_options={"required": False}
        )
        spec.expose_outputs(
            PwBaseWorkChain, namespace="nscf", namespace_options={"required": False}
        )
        spec.expose_outputs(
            ProjwfcBaseWorkChain,
            namespace="projwfc",
            namespace_options={"required": False},
        )
        spec.expose_outputs(Pw2wannier90BaseWorkChain, namespace="pw2wannier90")
        spec.expose_outputs(Wannier90BaseWorkChain, namespace="wannier90_pp")
        spec.expose_outputs(Wannier90BaseWorkChain, namespace="wannier90")

        spec.exit_code(
            420,
            "ERROR_SUB_PROCESS_FAILED_SCF",
            message="the scf PwBaseWorkChain sub process failed",
        )
        spec.exit_code(
            430,
            "ERROR_SUB_PROCESS_FAILED_NSCF",
            message="the nscf PwBaseWorkChain sub process failed",
        )
        spec.exit_code(
            440,
            "ERROR_SUB_PROCESS_FAILED_PROJWFC",
            message="the ProjwfcBaseWorkChain sub process failed",
        )
        spec.exit_code(
            450,
            "ERROR_SUB_PROCESS_FAILED_WANNIER90PP",
            message="the postproc Wannier90BaseWorkChain sub process failed",
        )
        spec.exit_code(
            460,
            "ERROR_SUB_PROCESS_FAILED_PW2WANNIER90",
            message="the Pw2wannier90BaseWorkChain sub process failed",
        )
        spec.exit_code(
            470,
            "ERROR_SUB_PROCESS_FAILED_WANNIER90",
            message="the Wannier90BaseWorkChain sub process failed",
        )
        spec.exit_code(
            480, "ERROR_SANITY_CHECK_FAILED", message="outputs sanity check failed"
        )

    @classmethod
    def get_protocol_filepath(cls) -> pathlib.Path:
        """Return ``pathlib.Path`` to the ``.yaml`` file that defines the protocols."""
        from importlib_resources import files

        from . import protocols

        return files(protocols) / "wannier90.yaml"

    @classmethod
    def get_protocol_overrides(cls) -> dict:
        """Return the content of the ``overrides/wannier90.yaml`` protocol file.

        :return: the object obtained by loading the YAML file with ``yaml.safe_load``.
        """
        from importlib_resources import files
        import yaml

        from . import protocols

        path = files(protocols) / "overrides" / "wannier90.yaml"
        with path.open() as file:
            return yaml.safe_load(file)

    @classmethod
    def get_builder_from_protocol(  # pylint: disable=unused-argument
        cls,
        codes: ty.Mapping[str, ty.Union[str, int, orm.Code]],
        structure: orm.StructureData,
        *,
        protocol: str = None,
        overrides: dict = None,
        pseudo_family: str = None,
        electronic_type: ElectronicType = ElectronicType.METAL,
        spin_type: SpinType = SpinType.NONE,
        initial_magnetic_moments: dict = None,
        projection_type: WannierProjectionType = WannierProjectionType.SCDM,
        disentanglement_type: WannierDisentanglementType = None,
        frozen_type: WannierFrozenType = None,
        exclude_semicore: bool = True,
        external_projectors_path: str = None,
        plot_wannier_functions: bool = False,
        retrieve_hamiltonian: bool = False,
        retrieve_matrices: bool = False,
        print_summary: bool = True,
        summary: dict = None,
    ) -> ProcessBuilder:
        """Return a builder prepopulated with inputs selected according to the chosen protocol.

        The builder can be submitted directly by `aiida.engine.submit(builder)`.

        :param codes: a dictionary of ``Code`` instance for pw.x, pw2wannier90.x, wannier90.x,
            (optionally) projwfc.x.
        :type codes: dict
        :param structure: the ``StructureData`` instance to use.
        :type structure: orm.StructureData
        :param protocol: protocol to use, if not specified, the default will be used.
        :type protocol: str
        :param overrides: optional dictionary of inputs to override the defaults of the protocol.
            The sub-dictionaries are read from the keys ``wannier90``, ``scf``, ``nscf``,
            ``projwfc`` and ``pw2wannier90``.
        :param pseudo_family: label of the pseudopotential family. If not given,
            ``PseudoDojo/0.4/PBE/FR/standard/upf`` is used for ``SpinType.SPIN_ORBIT``,
            otherwise the value from the ``Wannier90BaseWorkChain`` protocol is used.
        :param electronic_type: indicate the electronic character of the system through
            ``ElectronicType`` instance. Only ``METAL`` and ``INSULATOR`` are accepted.
        :param spin_type: indicate the spin polarization type to use through a ``SpinType``
            instance. Only ``NONE`` and ``SPIN_ORBIT`` are accepted.
        :param initial_magnetic_moments: optional dictionary that maps the initial magnetic moment
            of each kind to a desired value for a spin polarized calculation. It is rejected
            unless ``spin_type == SpinType.COLLINEAR``.
        :param projection_type: indicate the Wannier initial projection type of the system
            through ``WannierProjectionType`` instance. Default to SCDM.
        :param disentanglement_type: indicate the Wannier disentanglement type of the system
            through ``WannierDisentanglementType`` instance. Default to None, which will choose
            the best type based on `projection_type`:
                For WannierProjectionType.SCDM, use WannierDisentanglementType.NONE
                For other WannierProjectionType, use WannierDisentanglementType.SMV
        :param frozen_type: indicate the Wannier frozen-window type of the system
            through ``WannierFrozenType`` instance. Default to None, which will choose
            the best frozen type based on `electronic_type` and `projection_type`.
                for ElectronicType.INSULATOR, use WannierFrozenType.NONE
                for metals or insulators with conduction bands:
                    for WannierProjectionType.ANALYTIC/RANDOM, use WannierFrozenType.ENERGY_FIXED
                    for WannierProjectionType.ATOMIC_PROJECTORS_QE/OPENMX,
                        use WannierFrozenType.FIXED_PLUS_PROJECTABILITY
                    for WannierProjectionType.SCDM, use WannierFrozenType.NONE
        :param exclude_semicore: if True do not Wannierise semicore states.
        :param external_projectors_path: path of the external projectors, mandatory (and must be
            a ``str``) when using ``WannierProjectionType.ATOMIC_PROJECTORS_OPENMX``.
        :param plot_wannier_functions: if True plot Wannier functions as xsf files.
        :param retrieve_hamiltonian: if True retrieve Wannier Hamiltonian.
        :param retrieve_matrices: if True retrieve amn/mmn/eig/chk/spin files.
        :param print_summary: if True print a summary of key input parameters
        :param summary: A dict containing key input parameters and can be printed out
            when the `get_builder_from_protocol` returns, to let user have a quick check of the
            generated inputs. Since in python dict is pass-by-reference, the input dict can be
            modified in the method and used by the invoking function. This allows printing the
            summary only by the last overriding method.
        :raises NotImplementedError: if `electronic_type` or `spin_type` is not supported.
        :raises ValueError: if `initial_magnetic_moments` is incompatible with `spin_type`, or
            `external_projectors_path` is missing for the OPENMX projection type.
        :return: a process builder instance with all inputs defined and ready for launch.
        :rtype: ProcessBuilder
        """
        from aiida_wannier90_workflows.utils.pseudo import (
            get_pseudo_orbitals,
            get_semicore_list,
        )
        from aiida_wannier90_workflows.utils.workflows.builder.generator import (
            get_nscf_builder,
            get_scf_builder,
        )
        from aiida_wannier90_workflows.utils.workflows.builder.projections import (
            guess_wannier_projection_types,
        )
        from aiida_wannier90_workflows.utils.workflows.builder.submit import (
            check_codes,
            recursive_merge_builder,
        )

        # Check function arguments
        codes = check_codes(codes)
        type_check(electronic_type, ElectronicType)
        type_check(spin_type, SpinType)
        type_check(projection_type, WannierProjectionType)
        if disentanglement_type:
            type_check(disentanglement_type, WannierDisentanglementType)
        if frozen_type:
            type_check(frozen_type, WannierFrozenType)

        if electronic_type not in [ElectronicType.METAL, ElectronicType.INSULATOR]:
            raise NotImplementedError(
                f"electronic type `{electronic_type}` is not supported."
            )

        if spin_type not in [SpinType.NONE, SpinType.SPIN_ORBIT]:
            raise NotImplementedError(f"spin type `{spin_type}` is not supported.")

        if initial_magnetic_moments and spin_type != SpinType.COLLINEAR:
            raise ValueError(
                f"`initial_magnetic_moments` is specified but spin type `{spin_type}` is incompatible."
            )

        (
            projection_type,
            disentanglement_type,
            frozen_type,
        ) = guess_wannier_projection_types(
            electronic_type=electronic_type,
            projection_type=projection_type,
            disentanglement_type=disentanglement_type,
            frozen_type=frozen_type,
        )

        if projection_type == WannierProjectionType.ATOMIC_PROJECTORS_OPENMX:
            if external_projectors_path is None:
                raise ValueError(
                    f"Must specify `external_projectors_path` when using {projection_type}"
                )
            type_check(external_projectors_path, str)

        if pseudo_family is None:
            if spin_type == SpinType.SPIN_ORBIT:
                # I use pseudo-dojo for SOC
                pseudo_family = "PseudoDojo/0.4/PBE/FR/standard/upf"
            else:
                pseudo_family = Wannier90BaseWorkChain.get_protocol_inputs(
                    protocol=protocol
                )["meta_parameters"]["pseudo_family"]

        # Prepare workchain builder
        # I need to use explicitly `Wannier90WorkChain.get_protocol_inputs()` instead of
        # `cls.get_protocol_inputs()`, because for a subclass of Wannier90WorkChain,
        # `cls.get_protocol_inputs()` will call the `get_protocol_inputs` of that subclass,
        # which might be different from this base class.
        builder = Wannier90WorkChain.get_builder()
        inputs = Wannier90WorkChain.get_protocol_inputs(protocol, overrides)
        builder = recursive_merge_builder(builder, inputs)

        builder["structure"] = structure

        if not overrides:
            overrides = {}

        # Prepare wannier90
        # NOTE: the `setdefault` calls below write into the `wannier90` dict of the
        # caller-provided `overrides` when it is present.
        wannier_overrides = overrides.get("wannier90", {})
        wannier_overrides.setdefault("meta_parameters", {})
        wannier_overrides["meta_parameters"].setdefault("pseudo_family", pseudo_family)
        wannier_overrides["meta_parameters"].setdefault(
            "exclude_semicore", exclude_semicore
        )
        wannier_builder = Wannier90BaseWorkChain.get_builder_from_protocol(
            code=codes["wannier90"],
            structure=structure,
            protocol=protocol,
            overrides=wannier_overrides,
            electronic_type=electronic_type,
            spin_type=spin_type,
            projection_type=projection_type,
            disentanglement_type=disentanglement_type,
            frozen_type=frozen_type,
        )
        # Remove workchain excluded inputs
        wannier_builder["wannier90"].pop("structure", None)
        wannier_builder.pop("clean_workdir", None)
        builder["wannier90"] = wannier_builder._inputs(  # pylint: disable=protected-access
            prune=True
        )

        kpoints_distance = Wannier90BaseWorkChain.get_protocol_inputs(
            protocol=protocol, overrides=wannier_overrides
        )["meta_parameters"]["kpoints_distance"]

        # Prepare scf
        scf_overrides = overrides.get("scf", {})
        scf_builder = get_scf_builder(
            code=codes["pw"],
            structure=structure,
            kpoints_distance=kpoints_distance,
            pseudo_family=pseudo_family,
            electronic_type=electronic_type,
            spin_type=spin_type,
            overrides=scf_overrides,
        )
        # Remove workchain excluded inputs
        scf_builder["pw"].pop("structure", None)
        scf_builder.pop("clean_workdir", None)
        builder["scf"] = scf_builder._inputs(  # pylint: disable=protected-access
            prune=True
        )

        # Prepare nscf
        num_bands = wannier_builder["wannier90"]["parameters"]["num_bands"]
        exclude_bands = (
            wannier_builder["wannier90"]["parameters"]
            .get_dict()
            .get("exclude_bands", [])
        )
        # The nscf step has to compute the excluded bands as well.
        nbnd = num_bands + len(exclude_bands)
        # Use explicit list of kpoints generated by wannier builder.
        # Since the QE auto generated kpoints might be different from wannier90, here we explicitly
        # generate a list of kpoint coordinates to avoid discrepancies.
        kpoints = wannier_builder["wannier90"]["kpoints"]
        nscf_overrides = overrides.get("nscf", {})
        nscf_builder = get_nscf_builder(
            code=codes["pw"],
            structure=structure,
            nbnd=nbnd,
            kpoints=kpoints,
            pseudo_family=pseudo_family,
            electronic_type=electronic_type,
            spin_type=spin_type,
            overrides=nscf_overrides,
        )
        # Remove workchain excluded inputs
        nscf_builder["pw"].pop("structure", None)
        nscf_builder.pop("clean_workdir", None)
        builder["nscf"] = nscf_builder._inputs(  # pylint: disable=protected-access
            prune=True
        )

        # Prepare projwfc: only needed for SCDM projections or automatic energy windows
        run_projwfc = (
            projection_type == WannierProjectionType.SCDM
            or frozen_type == WannierFrozenType.ENERGY_AUTO
        )
        if run_projwfc:
            projwfc_overrides = overrides.get("projwfc", {})
            projwfc_builder = ProjwfcBaseWorkChain.get_builder_from_protocol(
                code=codes["projwfc"], protocol=protocol, overrides=projwfc_overrides
            )
            # Remove workchain excluded inputs
            projwfc_builder.pop("clean_workdir", None)
            builder["projwfc"] = projwfc_builder._inputs(  # pylint: disable=protected-access
                prune=True
            )

        # Prepare pw2wannier90
        exclude_projectors = None
        if exclude_semicore:
            pseudo_orbitals = get_pseudo_orbitals(builder["scf"]["pw"]["pseudos"])
            exclude_projectors = get_semicore_list(structure, pseudo_orbitals)
        # Read the overrides of the pw2wannier90 namespace (not those of projwfc).
        pw2wannier90_overrides = overrides.get("pw2wannier90", {})
        pw2wannier90_builder = Pw2wannier90BaseWorkChain.get_builder_from_protocol(
            code=codes["pw2wannier90"],
            protocol=protocol,
            overrides=pw2wannier90_overrides,
            electronic_type=electronic_type,
            projection_type=projection_type,
            exclude_projectors=exclude_projectors,
            external_projectors_path=external_projectors_path,
        )
        # Remove workchain excluded inputs
        pw2wannier90_builder.pop("clean_workdir", None)
        builder["pw2wannier90"] = pw2wannier90_builder._inputs(  # pylint: disable=protected-access
            prune=True
        )

        # Apply several overrides
        protocol_overrides = Wannier90WorkChain.get_protocol_overrides()
        if plot_wannier_functions:
            builder = recursive_merge_builder(
                builder, protocol_overrides["plot_wannier_functions"]
            )

        if retrieve_hamiltonian:
            builder = recursive_merge_builder(
                builder, protocol_overrides["retrieve_hamiltonian"]
            )

        if retrieve_matrices:
            builder = recursive_merge_builder(
                builder, protocol_overrides["retrieve_matrices"]
            )

        # A dictionary containing key info of Wannierisation and will be printed when the function returns.
        if summary is None:
            summary = {}
        summary["Formula"] = structure.get_formula()
        summary["PseudoFamily"] = pseudo_family
        summary["ElectronicType"] = electronic_type.name
        summary["SpinType"] = spin_type.name
        summary["WannierProjectionType"] = projection_type.name
        summary["WannierDisentanglementType"] = disentanglement_type.name
        summary["WannierFrozenType"] = frozen_type.name

        params = builder["wannier90"]["wannier90"]["parameters"].get_dict()
        summary["num_bands"] = params["num_bands"]
        summary["num_wann"] = params["num_wann"]
        if "exclude_bands" in params:
            summary["exclude_bands"] = params["exclude_bands"]
        summary["mp_grid"] = params["mp_grid"]

        # Make sure the `notes` key exists, keeping notes already stored by the caller.
        notes = summary.get("notes", [])
        summary["notes"] = notes

        if print_summary:
            cls.print_summary(summary)

        return builder

    @classmethod
    def print_summary(cls, summary: ty.Dict) -> None:
        """Try to pretty print the summary when the `get_builder_from_protocol` returns.

        :param summary: dict of key input parameters. Its ``notes`` entry, if any, is
            removed from the dict and printed separately as a bullet list.
        """
        notes = summary.pop("notes", [])

        print("Summary of key input parameters:")
        for key, val in summary.items():
            print(f"  {key}: {val}")
        print("")

        if not notes:
            return

        print("Notes:")
        for note in notes:
            print(f"  * {note}")

    def setup(self) -> None:
        """Define the current structure in the context to be the input structure."""
        self.ctx.current_structure = self.inputs.structure

        # When the scf step is skipped, start from the parent folder given to nscf.
        if not self.should_run_scf():
            self.ctx.current_folder = self.inputs["nscf"]["pw"]["parent_folder"]

    def should_run_scf(self) -> bool:
        """If the 'scf' input namespace was specified, run the scf workchain."""
        return "scf" in self.inputs

    def run_scf(self):
        """Run the `PwBaseWorkChain` in scf mode on the current structure."""
        inputs = AttributeDict(self.exposed_inputs(PwBaseWorkChain, namespace="scf"))
        inputs.pw.structure = self.ctx.current_structure
        inputs.metadata.call_link_label = "scf"

        inputs = prepare_process_inputs(PwBaseWorkChain, inputs)
        running = self.submit(PwBaseWorkChain, **inputs)
        self.report(f"launching {running.process_label}<{running.pk}> in scf mode")

        return ToContext(workchain_scf=running)

    def inspect_scf(self):  # pylint: disable=inconsistent-return-statements
        """Verify that the `PwBaseWorkChain` for the scf run successfully finished."""
        workchain = self.ctx.workchain_scf

        if not workchain.is_finished_ok:
            self.report(
                f"scf {workchain.process_label} failed with exit status {workchain.exit_status}"
            )
            return self.exit_codes.ERROR_SUB_PROCESS_FAILED_SCF

        self.ctx.current_folder = workchain.outputs.remote_folder

    def should_run_nscf(self) -> bool:
        """If the `nscf` input namespace was specified, run the nscf workchain."""
        return "nscf" in self.inputs

    def run_nscf(self):
        """Run the PwBaseWorkChain in nscf mode."""
        inputs = AttributeDict(self.exposed_inputs(PwBaseWorkChain, namespace="nscf"))
        inputs.pw.structure = self.ctx.current_structure
        inputs.pw.parent_folder = self.ctx.current_folder
        inputs.metadata.call_link_label = "nscf"

        inputs = prepare_process_inputs(PwBaseWorkChain, inputs)
        running = self.submit(PwBaseWorkChain, **inputs)
        self.report(f"launching {running.process_label}<{running.pk}> in nscf mode")

        return ToContext(workchain_nscf=running)

    def inspect_nscf(self):  # pylint: disable=inconsistent-return-statements
        """Verify that the `PwBaseWorkChain` for the nscf run successfully finished."""
        workchain = self.ctx.workchain_nscf

        if not workchain.is_finished_ok:
            self.report(
                f"nscf {workchain.process_label} failed with exit status {workchain.exit_status}"
            )
            return self.exit_codes.ERROR_SUB_PROCESS_FAILED_NSCF

        self.ctx.current_folder = workchain.outputs.remote_folder

    def should_run_projwfc(self) -> bool:
        """If the 'projwfc' input namespace was specified, run the projwfc calculation."""
        return "projwfc" in self.inputs

    def run_projwfc(self):
        """Projwfc step."""
        inputs = AttributeDict(
            self.exposed_inputs(ProjwfcBaseWorkChain, namespace="projwfc")
        )
        inputs.projwfc.parent_folder = self.ctx.current_folder
        inputs.metadata.call_link_label = "projwfc"

        inputs = prepare_process_inputs(ProjwfcBaseWorkChain, inputs)
        running = self.submit(ProjwfcBaseWorkChain, **inputs)
        self.report(f"launching {running.process_label}<{running.pk}>")

        return ToContext(workchain_projwfc=running)

    def inspect_projwfc(self):  # pylint: disable=inconsistent-return-statements
        """Verify that the `ProjwfcBaseWorkChain` for the projwfc run successfully finished."""
        workchain = self.ctx.workchain_projwfc

        if not workchain.is_finished_ok:
            self.report(
                f"{workchain.process_label} failed with exit status {workchain.exit_status}"
            )
            return self.exit_codes.ERROR_SUB_PROCESS_FAILED_PROJWFC

    def prepare_wannier90_pp_inputs(self):  # pylint: disable=too-many-statements
        """Prepare the inputs of wannier90 calculation before submission.

        This method will be called by the workchain at runtime, to fill some parameters such as
        Fermi energy which can only be retrieved after scf step.
        Moreover, this allows overriding the method in derived classes to further modify the inputs.

        :raises ValueError: if neither an scf nor an nscf workchain is available in the context
            to provide the Fermi energy or the bands.
        :return: the inputs for the postproc `Wannier90BaseWorkChain`.
        """
        from aiida_wannier90_workflows.utils.workflows.pw import (
            get_fermi_energy,
            get_fermi_energy_from_nscf,
        )

        base_inputs = AttributeDict(
            self.exposed_inputs(Wannier90BaseWorkChain, namespace="wannier90")
        )
        inputs = base_inputs["wannier90"]
        inputs.structure = self.ctx.current_structure
        parameters = inputs.parameters.get_dict()

        # Add Fermi energy
        if "workchain_scf" in self.ctx:
            scf_output_parameters = self.ctx.workchain_scf.outputs.output_parameters
            fermi_energy = get_fermi_energy(scf_output_parameters)
        elif "workchain_nscf" in self.ctx:
            fermi_energy = get_fermi_energy_from_nscf(self.ctx.workchain_nscf)
        else:
            raise ValueError("Cannot retrieve Fermi energy from scf or nscf output")
        parameters["fermi_energy"] = fermi_energy

        inputs.parameters = orm.Dict(parameters)

        # Add `postproc_setup`
        if "settings" in inputs:
            settings = inputs["settings"].get_dict()
        else:
            settings = {}
        settings["postproc_setup"] = True
        inputs["settings"] = settings

        # I should not stash files in postproc, otherwise there is a RemoteStashFolderData in outputs
        inputs["metadata"]["options"].pop("stash", None)

        base_inputs["wannier90"] = inputs

        if base_inputs["shift_energy_windows"] and "bands" not in base_inputs:
            if "workchain_scf" in self.ctx:
                output_band = self.ctx.workchain_scf.outputs.output_band
            elif "workchain_nscf" in self.ctx:
                output_band = self.ctx.workchain_nscf.outputs.output_band
            else:
                raise ValueError("No output scf or nscf bands")
            base_inputs.bands = output_band

        if base_inputs["auto_energy_windows"]:
            if "bands" not in base_inputs:
                base_inputs.bands = self.ctx.workchain_projwfc.outputs.bands
            if "bands_projections" not in base_inputs:
                base_inputs.bands_projections = (
                    self.ctx.workchain_projwfc.outputs.projections
                )

        # Cleaning is handled by this workchain in `on_terminated`.
        base_inputs["clean_workdir"] = orm.Bool(False)

        return base_inputs

    def run_wannier90_pp(self):
        """Wannier90 post processing step."""
        inputs = self.prepare_wannier90_pp_inputs()
        inputs["metadata"] = {"call_link_label": "wannier90_pp"}

        inputs = prepare_process_inputs(Wannier90BaseWorkChain, inputs)
        running = self.submit(Wannier90BaseWorkChain, **inputs)
        self.report(f"launching {running.process_label}<{running.pk}> in postproc mode")

        return ToContext(workchain_wannier90_pp=running)

    def inspect_wannier90_pp(self):  # pylint: disable=inconsistent-return-statements
        """Verify that the `Wannier90BaseWorkChain` for the postproc run successfully finished."""
        workchain = self.ctx.workchain_wannier90_pp

        if not workchain.is_finished_ok:
            self.report(
                f"wannier90 postproc {workchain.process_label} failed with exit status {workchain.exit_status}"
            )
            return self.exit_codes.ERROR_SUB_PROCESS_FAILED_WANNIER90PP

    def prepare_pw2wannier90_inputs(self):
        """Prepare the inputs of `Pw2wannier90BaseWorkChain` before submission.

        This method will be called by the workchain at runtime, so it can dynamically add/modify inputs
        based on outputs of previous calculations, e.g. add bands and projections for calculating
        scdm_mu/sigma from projectability, etc.
        Moreover, it can be overridden in derived classes.

        :raises ValueError: if scdm_mu/scdm_sigma must be fitted but projwfc was not run.
        :return: the inputs for the `Pw2wannier90BaseWorkChain`.
        """
        base_inputs = AttributeDict(
            self.exposed_inputs(Pw2wannier90BaseWorkChain, namespace="pw2wannier90")
        )
        inputs = base_inputs["pw2wannier90"]

        parameters = inputs.parameters.get_dict().get("inputpp", {})
        scdm_proj = parameters.get("scdm_proj", False)
        scdm_entanglement = parameters.get("scdm_entanglement", None)
        scdm_mu = parameters.get("scdm_mu", None)
        scdm_sigma = parameters.get("scdm_sigma", None)

        # mu/sigma need to be fitted only for `erfc` SCDM when at least one is missing
        fit_scdm = (
            scdm_proj
            and scdm_entanglement == "erfc"
            and (scdm_mu is None or scdm_sigma is None)
        )

        if fit_scdm:
            if "workchain_projwfc" not in self.ctx:
                raise ValueError("Needs to run projwfc for SCDM projection")
            base_inputs["bands"] = self.ctx.workchain_projwfc.outputs.bands
            base_inputs[
                "bands_projections"
            ] = self.ctx.workchain_projwfc.outputs.projections

        inputs["parent_folder"] = self.ctx.current_folder
        inputs["nnkp_file"] = self.ctx.workchain_wannier90_pp.outputs.nnkp_file

        base_inputs["pw2wannier90"] = inputs

        return base_inputs

    def run_pw2wannier90(self):
        """Run the pw2wannier90 step."""
        inputs = self.prepare_pw2wannier90_inputs()
        inputs.metadata.call_link_label = "pw2wannier90"

        inputs = prepare_process_inputs(Pw2wannier90BaseWorkChain, inputs)
        running = self.submit(Pw2wannier90BaseWorkChain, **inputs)
        self.report(f"launching {running.process_label}<{running.pk}>")

        return ToContext(workchain_pw2wannier90=running)

    def inspect_pw2wannier90(self):  # pylint: disable=inconsistent-return-statements
        """Verify that the Pw2wannier90BaseWorkChain for the pw2wannier90 run successfully finished."""
        workchain = self.ctx.workchain_pw2wannier90

        if not workchain.is_finished_ok:
            self.report(
                f"{workchain.process_label} failed with exit status {workchain.exit_status}"
            )
            return self.exit_codes.ERROR_SUB_PROCESS_FAILED_PW2WANNIER90

        self.ctx.current_folder = workchain.outputs.remote_folder

    def prepare_wannier90_inputs(self):  # pylint: disable=too-many-statements
        """Prepare the inputs of wannier90 calculation before submission.

        This method will be called by the workchain at runtime, to fill some parameters such as
        Fermi energy which can only be retrieved after scf step.
        Moreover, this allows overriding the method in derived classes to further modify the inputs.

        :return: the inputs for the final `Wannier90BaseWorkChain`.
        """
        from copy import deepcopy

        from aiida_wannier90_workflows.utils.workflows import get_last_calcjob

        base_inputs = AttributeDict(
            self.exposed_inputs(Wannier90BaseWorkChain, namespace="wannier90")
        )

        # I need to disable Fermi energy shifting since this is done in postproc step,
        # otherwise it will be shifted twice!
        base_inputs.pop("shift_energy_windows", None)
        base_inputs.pop("auto_energy_windows", None)
        base_inputs.pop("auto_energy_windows_threshold", None)
        base_inputs.pop("bands", None)
        base_inputs.pop("bands_projections", None)

        inputs = base_inputs["wannier90"]

        # I should stash files, which was removed from metadata in the postproc step
        stash = None
        if "stash" in inputs["metadata"]["options"]:
            stash = deepcopy(inputs["metadata"]["options"]["stash"])

        # Use the Wannier90BaseWorkChain-corrected parameters
        last_calc = get_last_calcjob(self.ctx.workchain_wannier90_pp)
        # copy postproc inputs, especially the `kmesh_tol` might have been corrected
        for key in last_calc.inputs:
            inputs[key] = last_calc.inputs[key]

        inputs["remote_input_folder"] = self.ctx.current_folder

        if "settings" in inputs:
            settings = inputs.settings.get_dict()
        else:
            settings = {}
        settings["postproc_setup"] = False

        inputs.settings = settings

        # Restore stash files
        if stash:
            options = deepcopy(inputs["metadata"]["options"])
            options["stash"] = stash
            inputs["metadata"]["options"] = options

        base_inputs["wannier90"] = inputs

        # Cleaning is handled by this workchain in `on_terminated`.
        base_inputs["clean_workdir"] = orm.Bool(False)

        return base_inputs

    def run_wannier90(self):
        """Wannier90 step for MLWF."""
        inputs = self.prepare_wannier90_inputs()
        inputs["metadata"] = {"call_link_label": "wannier90"}

        inputs = prepare_process_inputs(Wannier90BaseWorkChain, inputs)
        running = self.submit(Wannier90BaseWorkChain, **inputs)
        self.report(f"launching {running.process_label}<{running.pk}>")

        return ToContext(workchain_wannier90=running)

    def inspect_wannier90(self):  # pylint: disable=inconsistent-return-statements
        """Verify that the `Wannier90BaseWorkChain` for the wannier90 run successfully finished."""
        workchain = self.ctx.workchain_wannier90

        if not workchain.is_finished_ok:
            self.report(
                f"{workchain.process_label} failed with exit status {workchain.exit_status}"
            )
            return self.exit_codes.ERROR_SUB_PROCESS_FAILED_WANNIER90

        self.ctx.current_folder = workchain.outputs.remote_folder

    def results(self):  # pylint: disable=inconsistent-return-statements
        """Attach the desired output nodes directly as outputs of the workchain.

        :return: the exit code of `sanity_check` if it failed, otherwise ``None``.
        """
        # scf, nscf and projwfc are optional steps, attach their outputs only if they were run
        if "workchain_scf" in self.ctx:
            self.out_many(
                self.exposed_outputs(
                    self.ctx.workchain_scf, PwBaseWorkChain, namespace="scf"
                )
            )

        if "workchain_nscf" in self.ctx:
            self.out_many(
                self.exposed_outputs(
                    self.ctx.workchain_nscf, PwBaseWorkChain, namespace="nscf"
                )
            )

        if "workchain_projwfc" in self.ctx:
            self.out_many(
                self.exposed_outputs(
                    self.ctx.workchain_projwfc,
                    ProjwfcBaseWorkChain,
                    namespace="projwfc",
                )
            )

        self.out_many(
            self.exposed_outputs(
                self.ctx.workchain_pw2wannier90,
                Pw2wannier90BaseWorkChain,
                namespace="pw2wannier90",
            )
        )
        self.out_many(
            self.exposed_outputs(
                self.ctx.workchain_wannier90_pp,
                Wannier90BaseWorkChain,
                namespace="wannier90_pp",
            )
        )
        self.out_many(
            self.exposed_outputs(
                self.ctx.workchain_wannier90,
                Wannier90BaseWorkChain,
                namespace="wannier90",
            )
        )

        result = self.sanity_check()
        if result:
            return result

        self.report(f"{self.get_name()} successfully completed")

    def sanity_check(self):  # pylint: disable=inconsistent-return-statements
        """Sanity checks for final outputs.

        Not necessary but it is good to check it.

        :return: ``ERROR_SANITY_CHECK_FAILED`` exit code if a check fails, otherwise ``None``.
        """
        from aiida_wannier90_workflows.utils.pseudo import (
            get_number_of_electrons,
            get_number_of_projections,
        )

        # If using external atomic projectors, disable sanity check
        p2w_params = self.ctx.workchain_pw2wannier90.inputs["pw2wannier90"][
            "parameters"
        ].get_dict()["inputpp"]
        atom_proj = p2w_params.get("atom_proj", False)
        atom_proj_ext = p2w_params.get("atom_proj_ext", False)
        if atom_proj and atom_proj_ext:
            return

        # 1. the calculated number of projections is consistent with QE projwfc.x
        if "scf" in self.inputs:
            pseudos = self.inputs["scf"]["pw"]["pseudos"]
        else:
            pseudos = self.inputs["nscf"]["pw"]["pseudos"]
        args = {
            "structure": self.ctx.current_structure,
            # The type of `self.inputs['scf']['pw']['pseudos']` is AttributesFrozendict,
            # we need to convert it to dict, otherwise get_number_of_projections will fail.
            "pseudos": dict(pseudos),
        }
        if "workchain_projwfc" in self.ctx:
            num_proj = len(
                self.ctx.workchain_projwfc.outputs["projections"].get_orbitals()
            )
            params = self.ctx.workchain_wannier90.inputs["wannier90"][
                "parameters"
            ].get_dict()
            spin_orbit_coupling = params.get("spinors", False)
            number_of_projections = get_number_of_projections(
                **args, spin_orbit_coupling=spin_orbit_coupling
            )
            if number_of_projections != num_proj:
                self.report(
                    f"number of projections {number_of_projections} != projwfc.x output {num_proj}"
                )
                return self.exit_codes.ERROR_SANITY_CHECK_FAILED

        # 2. the number of electrons is consistent with QE output
        if "workchain_scf" in self.ctx:
            num_elec = self.ctx.workchain_scf.outputs["output_parameters"][
                "number_of_electrons"
            ]
        else:
            num_elec = self.ctx.workchain_nscf.outputs["output_parameters"][
                "number_of_electrons"
            ]
        number_of_electrons = get_number_of_electrons(**args)
        if number_of_electrons != num_elec:
            self.report(
                f"number of electrons {number_of_electrons} != QE output {num_elec}"
            )
            return self.exit_codes.ERROR_SANITY_CHECK_FAILED

    def on_terminated(self):
        """Clean the working directories of all child calculations if `clean_workdir=True` in the inputs."""
        super().on_terminated()

        if not self.inputs.clean_workdir:
            self.report("remote folders will not be cleaned")
            return

        cleaned_calcs = []

        for called_descendant in self.node.called_descendants:
            if isinstance(called_descendant, orm.CalcJobNode):
                # Best effort: a calculation whose folder cannot be cleaned is skipped.
                try:
                    called_descendant.outputs.remote_folder._clean()  # pylint: disable=protected-access
                    cleaned_calcs.append(called_descendant.pk)
                except (OSError, KeyError):
                    pass

        if cleaned_calcs:
            self.report(
                f"cleaned remote folders of calculations: {' '.join(map(str, cleaned_calcs))}"
            )