Source code for aiida_wannier90_workflows.workflows.base.pw2wannier90

"""Wrapper workchain for Pw2wannier90Calculation to automatically handle several errors."""
import pathlib
import typing as ty

from aiida import orm
from aiida.common import AttributeDict
from aiida.common.lang import type_check
from aiida.engine import process_handler
from aiida.engine.processes.builder import ProcessBuilder
from aiida.orm.nodes.data.base import to_aiida_type

from aiida_quantumespresso.calculations.pw2wannier90 import Pw2wannier90Calculation
from aiida_quantumespresso.common.types import ElectronicType
from aiida_quantumespresso.workflows.protocols.utils import ProtocolMixin

from aiida_wannier90_workflows.common.types import WannierProjectionType

from .qebaserestart import QeBaseRestartWorkChain

__all__ = ["validate_inputs_base", "validate_inputs", "Pw2wannier90BaseWorkChain"]


def validate_inputs_base(  # pylint: disable=unused-argument
    inputs: AttributeDict, ctx=None
) -> None:
    """Validate the inputs of the entire input namespace."""
    return


def validate_inputs(  # pylint: disable=unused-argument,inconsistent-return-statements
    inputs: AttributeDict, ctx=None
) -> None:
    """Validate the inputs of the entire input namespace of `Pw2wannier90BaseWorkChain`."""

    result = validate_inputs_base(inputs, ctx)  # pylint: disable=assignment-from-none
    if result:
        return result

    calc_inputs = AttributeDict(inputs[Pw2wannier90BaseWorkChain._inputs_namespace])
    calc_parameters = calc_inputs["parameters"].get_dict().get("inputpp", {})

    scdm_proj = calc_parameters.get("scdm_proj", False)
    scdm_entanglement = calc_parameters.get("scdm_entanglement", "isolated")

    fit_scdm = False
    if scdm_proj and scdm_entanglement != "isolated":
        scdm_mu = calc_parameters.get("scdm_mu", None)
        scdm_sigma = calc_parameters.get("scdm_sigma", None)
        if scdm_mu is None or scdm_sigma is None:
            fit_scdm = True

    if fit_scdm:
        # Check bands and bands_projections are provided
        if any(_ not in inputs for _ in ("bands_projections", "bands")):
            return "`scdm_proj` is True but `bands_projections` or `bands` is empty"

        # Check `bands` and `bands_projections` are consistent
        bands_num_kpoints, bands_num_bands = inputs["bands"].attributes["array|bands"]
        projections_num_kpoints, projections_num_bands = inputs[
            "bands_projections"
        ].attributes["array|proj_array_0"]
        if bands_num_kpoints != projections_num_kpoints:
            return (
                "`bands` and `bands_projections` have different number of kpoints: "
                f"{bands_num_kpoints} != {projections_num_kpoints}"
            )
        if bands_num_bands != projections_num_bands:
            return (
                "`bands` and `bands_projections` have different number of bands: "
                f"{bands_num_bands} != {projections_num_bands}"
            )

    atom_proj = calc_parameters.get("atom_proj", False)
    atom_proj_ext = calc_parameters.get("atom_proj_ext", False)
    atom_proj_dir = calc_parameters.get("atom_proj_dir", None)
    if atom_proj and atom_proj_ext and not atom_proj_dir:
        return "`atom_proj_dir` must be specified when using external projectors."


[docs]class Pw2wannier90BaseWorkChain(ProtocolMixin, QeBaseRestartWorkChain): """Workchain to run a pw2wannier90 calculation with automated error handling and restarts.""" _process_class = Pw2wannier90Calculation _inputs_namespace = "pw2wannier90" @classmethod def define(cls, spec) -> None: """Define the process spec.""" super().define(spec) spec.input( "scdm_sigma_factor", valid_type=orm.Float, default=lambda: orm.Float(3.0), serializer=to_aiida_type, help="The `sigma` factor of occupation function for SCDM projection.", ) spec.input( "bands", valid_type=orm.BandsData, required=False, help="Bands to calculate SCDM `mu`, `sigma`.", ) spec.input( "bands_projections", valid_type=orm.ProjectionData, required=False, help="Bands projectability to calculate SCDM `mu`, `sigma`.", ) spec.inputs.validator = validate_inputs spec.exit_code( 400, "ERROR_SCDM_FITTING", message="Error when fitting `scdm_mu` and `scdm_sigma`.", ) @classmethod def get_protocol_filepath(cls) -> pathlib.Path: """Return the ``pathlib.Path`` to the ``.yaml`` file that defines the protocols.""" from importlib_resources import files from .. import protocols return files(protocols) / "base" / "pw2wannier90.yaml" @classmethod def get_builder_from_protocol( cls, code: ty.Union[orm.Code, str, int], *, protocol: str = None, overrides: dict = None, electronic_type: ElectronicType = ElectronicType.METAL, projection_type: WannierProjectionType = WannierProjectionType.ATOMIC_PROJECTORS_QE, exclude_projectors: list = None, external_projectors_path: str = None, ) -> ProcessBuilder: """Return a builder prepopulated with inputs selected according to the chosen protocol. :param code: [description] :type code: ty.Union[orm.Code, str, int] :param structure: [description] :type structure: orm.StructureData :param protocol: [description], defaults to None :type protocol: str, optional :param overrides: [description], defaults to None :type overrides: dict, optional :return: [description] :rtype: ProcessBuilder """ from aiida_quantumespresso.workflows.protocols.utils import recursive_merge if isinstance(code, (int, str)): code = orm.load_code(code) type_check(code, orm.Code) type_check(electronic_type, ElectronicType) type_check(projection_type, WannierProjectionType) # Update the parameters based on the protocol inputs inputs = cls.get_protocol_inputs(protocol, overrides) parameters = inputs[cls._inputs_namespace]["parameters"]["inputpp"] metadata = inputs[cls._inputs_namespace]["metadata"] # Set projection if projection_type == WannierProjectionType.SCDM: parameters["scdm_proj"] = True if electronic_type == ElectronicType.INSULATOR: parameters["scdm_entanglement"] = "isolated" else: parameters["scdm_entanglement"] = "erfc" # scdm_mu, scdm_sigma will be set at runtime elif projection_type in [ WannierProjectionType.ATOMIC_PROJECTORS_QE, WannierProjectionType.ATOMIC_PROJECTORS_OPENMX, ]: parameters["atom_proj"] = True if exclude_projectors is not None and len(exclude_projectors) > 0: parameters["atom_proj_exclude"] = list(exclude_projectors) if projection_type == WannierProjectionType.ATOMIC_PROJECTORS_OPENMX: parameters["atom_proj_ext"] = True if external_projectors_path is None: raise ValueError( f"Must specify `external_projectors_path` when using {projection_type}" ) parameters["atom_proj_dir"] = external_projectors_path parameters = {"inputpp": parameters} # If overrides are provided, they take precedence over default protocol if overrides: parameter_overrides = overrides.get(cls._inputs_namespace, {}).get( "parameters", {} ) parameters = recursive_merge(parameters, parameter_overrides) metadata_overrides = overrides.get(cls._inputs_namespace, {}).get( "metadata", {} ) metadata = recursive_merge(metadata, metadata_overrides) # pylint: disable=no-member builder = cls.get_builder() builder[cls._inputs_namespace]["code"] = code builder[cls._inputs_namespace]["parameters"] = orm.Dict(parameters) builder[cls._inputs_namespace]["metadata"] = metadata if "settings" in inputs[cls._inputs_namespace]: builder[cls._inputs_namespace]["settings"] = orm.Dict( dict=inputs[cls._inputs_namespace]["settings"] ) if "settings" in inputs: builder["settings"] = orm.Dict(inputs["settings"]) builder.clean_workdir = orm.Bool(inputs["clean_workdir"]) # pylint: enable=no-member return builder def setup(self) -> None: """Call the `setup` of the `BaseRestartWorkChain` and then create the inputs dictionary in `self.ctx.inputs`. This `self.ctx.inputs` dictionary will be used by the `BaseRestartWorkChain` to submit the calculations in the internal loop. """ super().setup() self.ctx.inputs = self.prepare_inputs() def prepare_inputs(self) -> AttributeDict: """Prepare `Pw2wannier90Calculation` inputs according to workchain input sepc. Different from `get_builder_from_protocol', this function is executed at runtime. """ from aiida_wannier90_workflows.utils.scdm import fit_scdm_mu_sigma inputs = AttributeDict( self.exposed_inputs(Pw2wannier90Calculation, self._inputs_namespace) ) parameters = inputs["parameters"].get_dict().get("inputpp", {}) scdm_proj = parameters.get("scdm_proj", False) scdm_entanglement = parameters.get("scdm_entanglement", None) scdm_mu = parameters.get("scdm_mu", None) scdm_sigma = parameters.get("scdm_sigma", None) fit_scdm = ( scdm_proj and scdm_entanglement == "erfc" and (scdm_mu is None or scdm_sigma is None) ) if scdm_entanglement == "gaussian": if scdm_mu is None or scdm_sigma is None: raise NotImplementedError( "scdm_entanglement = gaussian but scdm_mu or scdm_sigma is empty." ) if fit_scdm: # pylint: disable=unbalanced-tuple-unpacking try: mu_new, sigma_new = fit_scdm_mu_sigma( self.inputs.bands, self.inputs.bands_projections, self.inputs.scdm_sigma_factor, ) except ValueError: # raise ValueError(f'SCDM mu/sigma fitting failed! {exc.args}') from exc return self.exit_codes.ERROR_SCDM_FITTING # If `scdm_mu` and/or `scdm_sigma` is present in the input parameters, # the workchain will directly use them, only the missing one will be populated. if "scdm_mu" not in parameters: parameters["scdm_mu"] = mu_new if "scdm_sigma" not in parameters: parameters["scdm_sigma"] = sigma_new inputs["parameters"] = orm.Dict({"inputpp": parameters}) return inputs @process_handler( exit_codes=[_process_class.exit_codes.ERROR_OUTPUT_STDOUT_INCOMPLETE] ) def handle_output_stdout_incomplete(self, calculation): """Overide parent function.""" return super().handle_output_stdout_incomplete(calculation)