Source code for aiida_wannier90_workflows.workflows.base.wannier90

"""Wrapper workchain for `Wannier90Calculation` to automatically handle several errors."""
import pathlib
import typing as ty

from aiida import orm
from aiida.common import AttributeDict
from aiida.common.lang import type_check
from aiida.engine import (
    BaseRestartWorkChain,
    ProcessHandlerReport,
    process_handler,
    while_,
)
from aiida.engine.processes.builder import ProcessBuilder
from aiida.orm.nodes.data.base import to_aiida_type

from aiida_quantumespresso.common.types import ElectronicType, SpinType
from aiida_quantumespresso.workflows.protocols.utils import ProtocolMixin

from aiida_wannier90.calculations import Wannier90Calculation

from aiida_wannier90_workflows.common.types import (
    WannierDisentanglementType,
    WannierFrozenType,
    WannierProjectionType,
)

__all__ = ["validate_inputs_base", "validate_inputs", "Wannier90BaseWorkChain"]

# pylint: disable=inconsistent-return-statements


def validate_inputs_base(inputs: AttributeDict, ctx=None) -> None:
    """Validate the inputs of the entire input namespace."""
    # pylint: disable=inconsistent-return-statements,unused-argument

    # Check `settings`
    if "settings" in inputs:
        settings = inputs["settings"].get_dict()
        valid_keys = ("remote_symlink_files",)
        for key in settings:
            if key not in valid_keys:
                return f"Invalid settings: `{key}`, valid keys are: {valid_keys}"


def validate_inputs(inputs: AttributeDict, ctx=None) -> None:
    """Validate the inputs of the entire input namespace."""
    # pylint: disable=too-many-return-statements,unused-argument

    result = validate_inputs_base(inputs, ctx)
    if result:
        return result

    # pylint: disable=protected-access
    calc_inputs = AttributeDict(inputs[Wannier90BaseWorkChain._inputs_namespace])
    calc_parameters = calc_inputs.parameters.get_dict()

    # Check existence of `fermi_energy`
    if inputs["shift_energy_windows"]:
        if "fermi_energy" not in calc_parameters:
            return "`shift_energy_windows` is requested but no `fermi_energy` in input parameters"
        if "bands" not in inputs:
            return "`shift_energy_windows` is requested but no `bands` in inputs"

    # Check `bands`
    if any(_ in inputs for _ in ("bands", "bands_projections")) and all(
        _ not in inputs for _ in ("shift_energy_windows", "auto_energy_windows")
    ):
        return (
            "`bands` and/or `bands_projections` are provided but both `shift_energy_windows` "
            "and `auto_energy_windows` are False?"
        )

    # Check `auto_energy_windows`
    if inputs["auto_energy_windows"]:
        if inputs["shift_energy_windows"]:
            return "No need to shift energy windows when auto set energy windows"

        if any(_ not in inputs for _ in ("bands_projections", "bands")):
            return "`auto_energy_windows` is requested but `bands_projections` or `bands` is empty"

        # Check bands and bands_projections are consistent
        bands_num_kpoints, bands_num_bands = inputs["bands"].attributes["array|bands"]
        projections_num_kpoints, projections_num_bands = inputs[
            "bands_projections"
        ].attributes["array|proj_array_0"]
        if bands_num_kpoints != projections_num_kpoints:
            return (
                "`bands` and `bands_projections` have different number of kpoints: "
                f"{bands_num_kpoints} != {projections_num_kpoints}"
            )
        if bands_num_bands != projections_num_bands:
            return (
                "`bands` and `bands_projections` have different number of bands: "
                f"{bands_num_bands} != {projections_num_bands}"
            )


[docs]class Wannier90BaseWorkChain(ProtocolMixin, BaseRestartWorkChain): """Workchain to run a `Wannier90Calculation` with automated error handling and restarts.""" _process_class = Wannier90Calculation _inputs_namespace = "wannier90" _WANNIER90_DEFAULT_KMESH_TOL = 1e-6 _WANNIER90_DEFAULT_DIS_PROJ_MIN = 0.1 _WANNIER90_DEFAULT_DIS_PROJ_MAX = 0.9 _WANNIER90_DEFAULT_WANNIER_PLOT_SUPERCELL = 2 @classmethod def define(cls, spec) -> None: """Define the process spec.""" from aiida_wannier90.calculations.wannier90 import ( validate_inputs_base as validate_inputs_base_wannier90, ) super().define(spec) spec.expose_inputs(Wannier90Calculation, namespace=cls._inputs_namespace) spec.inputs[cls._inputs_namespace]["metadata"]["options"][ "resources" ].default = { "num_machines": 1, # 'num_mpiprocs_per_machine': 1, } spec.inputs["wannier90"].validator = validate_inputs_base_wannier90 spec.input( "shift_energy_windows", valid_type=orm.Bool, default=lambda: orm.Bool(False), serializer=to_aiida_type, help=( "If True the `dis_froz_min`, `dis_froz_max`, `dis_win_min`, `dis_win_max` will be shifted by " "Fermi enerngy. False is the default behaviour of wannier90." ), ) spec.input( "auto_energy_windows", valid_type=orm.Bool, default=lambda: orm.Bool(False), serializer=to_aiida_type, help=( "If True use the energy corresponding to projectability = auto_energy_windows_threshold " "as `dis_froz_max` for wannier90." ), ) spec.input( "auto_energy_windows_threshold", valid_type=orm.Float, default=lambda: orm.Float(0.9), serializer=to_aiida_type, help="Threshold for auto_energy_windows.", ) spec.input( "bands", valid_type=orm.BandsData, required=False, help=( "For shift_energy_windows, if provided the energy windows will be shifted by Fermi energy " "for metals or minimum of lowest-unoccupied bands for insulators. " "The bands should be along a kpath to better estimate the band gap. " "For auto_energy_windows, the bands is used to find out the energy corresponds to " "projectability = auto_energy_windows_threshold, the energy is used as `dis_froz_max`. " "In this case the bands should be on a nscf kmesh." ), ) spec.input( "bands_projections", valid_type=orm.ProjectionData, required=False, help="Projectability of bands to auto set `dis_froz_max`.", ) spec.input( "settings", valid_type=orm.Dict, required=False, serializer=to_aiida_type, help="Additional settings.", ) spec.inputs.validator = validate_inputs spec.outline( cls.setup, while_(cls.should_run_process)( cls.run_process, cls.inspect_process, ), cls.results, ) spec.expose_outputs(Wannier90Calculation) spec.exit_code(401, "ERROR_BVECTORS", message="Unrecoverable bvectors error.") spec.exit_code( 402, "ERROR_DISENTANGLEMENT_NOT_ENOUGH_STATES", message="Unrecoverable disentanglement error.", ) spec.exit_code( 403, "ERROR_PLOT_WF_CUBE", message="Unrecoverable cube format error." ) spec.exit_code( 404, "ERROR_OUTPUT_STDOUT_INCOMPLETE", message="The stdout output file was incomplete probably because the calculation got interrupted.", ) @classmethod def get_protocol_filepath(cls) -> pathlib.Path: """Return the ``pathlib.Path`` to the ``.yaml`` file that defines the protocols.""" from importlib_resources import files from .. import protocols return files(protocols) / "base" / "wannier90.yaml" @classmethod def get_builder_from_protocol( cls, code: ty.Union[orm.Code, str, int], *, structure: orm.StructureData, protocol: str = None, overrides: dict = None, pseudo_family: str = None, electronic_type: ElectronicType = ElectronicType.METAL, spin_type: SpinType = SpinType.NONE, projection_type: WannierProjectionType = WannierProjectionType.ATOMIC_PROJECTORS_QE, disentanglement_type: WannierDisentanglementType = WannierDisentanglementType.SMV, frozen_type: WannierFrozenType = WannierFrozenType.FIXED_PLUS_PROJECTABILITY, ) -> ProcessBuilder: """Return a builder prepopulated with inputs selected according to the chosen protocol. :param code: [description] :type code: ty.Union[orm.Code, str, int] :param structure: [description] :type structure: orm.StructureData :param protocol: [description], defaults to None :type protocol: str, optional :param overrides: [description], defaults to None :type overrides: dict, optional :return: [description] :rtype: ProcessBuilder """ # pylint: disable=too-many-statements,too-many-locals,too-many-branches from aiida_quantumespresso.workflows.protocols.utils import recursive_merge from aiida_wannier90_workflows.utils.kpoints import ( create_kpoints_from_distance, get_explicit_kpoints, ) from aiida_wannier90_workflows.utils.pseudo import ( get_number_of_projections, get_pseudo_and_cutoff, get_pseudo_orbitals, get_semicore_list, get_wannier_number_of_bands, ) if isinstance(code, (int, str)): code = orm.load_code(code) type_check(code, orm.Code) type_check(structure, orm.StructureData) type_check(electronic_type, ElectronicType) type_check(spin_type, SpinType) type_check(projection_type, WannierProjectionType) type_check(disentanglement_type, WannierDisentanglementType) type_check(frozen_type, WannierFrozenType) if electronic_type in [ElectronicType.AUTOMATIC]: raise NotImplementedError("`ElectronicType.AUTOMATIC` not implemented") if spin_type == SpinType.COLLINEAR: raise NotImplementedError("`SpinType.COLLINEAR` not implemented") if spin_type == SpinType.SPIN_ORBIT and pseudo_family is None: raise ValueError( "Need to explicitly specify `pseudo_family` for SOC calculation" ) inputs = cls.get_protocol_inputs(protocol, overrides) # Update the parameters based on the protocol inputs parameters = inputs[cls._inputs_namespace]["parameters"] metadata = inputs[cls._inputs_namespace]["metadata"] meta_parameters = inputs.pop("meta_parameters") num_atoms = len(structure.sites) parameters["conv_tol"] = num_atoms * meta_parameters["conv_tol_per_atom"] parameters["dis_conv_tol"] = ( num_atoms * meta_parameters["dis_conv_tol_per_atom"] ) # Set `num_bands`, `num_wann`, also take care of semicore states only_valence = electronic_type == ElectronicType.INSULATOR spin_polarized = spin_type == SpinType.COLLINEAR spin_orbit_coupling = spin_type == SpinType.SPIN_ORBIT if pseudo_family is None: pseudo_family = meta_parameters["pseudo_family"] pseudos, _, _ = get_pseudo_and_cutoff(pseudo_family, structure) num_bands = get_wannier_number_of_bands( structure=structure, pseudos=pseudos, factor=meta_parameters["num_bands_factor"], only_valence=only_valence, spin_polarized=spin_polarized, spin_orbit_coupling=spin_orbit_coupling, ) num_projs = get_number_of_projections( structure=structure, pseudos=pseudos, spin_orbit_coupling=spin_orbit_coupling, ) if electronic_type == ElectronicType.INSULATOR: num_wann = num_bands else: num_wann = num_projs if meta_parameters["exclude_semicore"]: pseudo_orbitals = get_pseudo_orbitals(pseudos) semicore_list = get_semicore_list(structure, pseudo_orbitals) num_excludes = len(semicore_list) # TODO I assume all the semicore bands are the lowest # pylint: disable=fixme exclude_pswfcs = range(1, num_excludes + 1) if num_excludes != 0: parameters["exclude_bands"] = exclude_pswfcs num_wann -= num_excludes num_bands -= num_excludes if num_wann <= 0: raise ValueError(f"Wrong num_wann {num_wann}") parameters["num_wann"] = num_wann parameters["num_bands"] = num_bands # Spinor if spin_type in [SpinType.NON_COLLINEAR, SpinType.SPIN_ORBIT]: parameters["spinors"] = True # Set initial projections if projection_type in [ WannierProjectionType.SCDM, WannierProjectionType.ATOMIC_PROJECTORS_QE, WannierProjectionType.ATOMIC_PROJECTORS_OPENMX, ]: parameters["auto_projections"] = True elif projection_type == WannierProjectionType.ANALYTIC: pseudo_orbitals = get_pseudo_orbitals(pseudos) projections = [] for kind in structure.kinds: for orb in pseudo_orbitals[kind.name]["pswfcs"]: if meta_parameters["exclude_semicore"]: if orb in pseudo_orbitals[kind.name]["semicores"]: continue projections.append(f"{kind.name}:{orb[-1].lower()}") inputs[cls._inputs_namespace]["projections"] = orm.List(list=projections) elif projection_type == WannierProjectionType.RANDOM: settings = inputs[cls._inputs_namespace].get("settings", {}) settings.update({"random_projections": True}) inputs[cls._inputs_namespace]["settings"] = settings else: raise ValueError(f"Unrecognized projection type {projection_type}") # Set disentanglement if disentanglement_type == WannierDisentanglementType.NONE: parameters["dis_num_iter"] = 0 parameters.pop("dis_froz_min", None) parameters.pop("dis_froz_max", None) parameters.pop("dis_win_min", None) parameters.pop("dis_win_max", None) elif disentanglement_type == WannierDisentanglementType.SMV: if frozen_type == WannierFrozenType.ENERGY_FIXED: inputs["shift_energy_windows"] = True parameters.update( { # Here +2 means fermi_energy + 2 eV, however Fermi energy is calculated at runtime # inside Wannier90WorkChain, so it will add Fermi energy with this # dis_froz_max dynamically. "dis_froz_max": +2.0, } ) elif frozen_type == WannierFrozenType.ENERGY_AUTO: # ENERGY_AUTO needs projectability, will be set dynamically when workchain is running inputs["auto_energy_windows"] = True elif frozen_type == WannierFrozenType.PROJECTABILITY: parameters.update( { "dis_proj_min": 0.01, "dis_proj_max": 0.95, } ) elif frozen_type == WannierFrozenType.FIXED_PLUS_PROJECTABILITY: inputs["shift_energy_windows"] = True parameters.update( { "dis_proj_min": 0.01, "dis_proj_max": 0.95, "dis_froz_max": +2.0, # relative to fermi_energy } ) elif frozen_type == WannierFrozenType.NONE: parameters.pop("dis_froz_min", None) parameters.pop("dis_froz_max", None) parameters.pop("dis_win_min", None) parameters.pop("dis_win_max", None) else: raise ValueError(f"Not supported frozen type: {frozen_type}") else: raise ValueError( f"Not supported disentanglement type: {disentanglement_type}" ) # Set kpoints # If inputs.kpoints is a kmesh, mp_grid will be auto-set by `Wannier90Calculation`, # otherwise we need to set it manually. If use open_grid, kpoints will be set dynamically # after open_grid calculation. kpoints = create_kpoints_from_distance( structure, distance=meta_parameters["kpoints_distance"], force_parity=meta_parameters["kpoints_force_parity"], ) inputs["kpoints"] = get_explicit_kpoints(kpoints) parameters["mp_grid"] = kpoints.get_kpoints_mesh()[0] # If overrides are provided, they take precedence over default values if overrides: parameter_overrides = overrides.get(cls._inputs_namespace, {}).get( "parameters", {} ) parameters = recursive_merge(parameters, parameter_overrides) metadata_overrides = overrides.get(cls._inputs_namespace, {}).get( "metadata", {} ) metadata = recursive_merge(metadata, metadata_overrides) # pylint: disable=no-member builder = cls.get_builder() builder[cls._inputs_namespace]["code"] = code builder[cls._inputs_namespace]["structure"] = structure builder[cls._inputs_namespace]["kpoints"] = inputs["kpoints"] builder[cls._inputs_namespace]["parameters"] = orm.Dict(parameters) builder[cls._inputs_namespace]["metadata"] = metadata if "projections" in inputs[cls._inputs_namespace]: builder[cls._inputs_namespace]["projections"] = inputs[ cls._inputs_namespace ]["projections"] if "settings" in inputs[cls._inputs_namespace]: builder[cls._inputs_namespace]["settings"] = orm.Dict( dict=inputs[cls._inputs_namespace]["settings"] ) if "settings" in inputs: builder["settings"] = orm.Dict(inputs["settings"]) builder["clean_workdir"] = orm.Bool(inputs["clean_workdir"]) builder["shift_energy_windows"] = orm.Bool(inputs["shift_energy_windows"]) builder["auto_energy_windows"] = orm.Bool(inputs["auto_energy_windows"]) builder["auto_energy_windows_threshold"] = orm.Float( inputs["auto_energy_windows_threshold"] ) # pylint: enable=no-member return builder def setup(self) -> None: """Call the `setup` of the `BaseRestartWorkChain` and then create the inputs dictionary in `self.ctx.inputs`. This `self.ctx.inputs` dictionary will be used by the `BaseRestartWorkChain` to submit the calculations in the internal loop. """ super().setup() self.ctx.inputs = self.prepare_inputs() self.ctx.kmeshtol_new = [self._WANNIER90_DEFAULT_KMESH_TOL, 1e-8, 1e-4] self.ctx.disprojmin_multipliers = [0.5, 0.25, 0.125, 0] self.ctx.wannier_plot_supercell_new = [4, 6, 8, 10] def prepare_inputs( self, ) -> AttributeDict: """Prepare `Wannier90Calculation` inputs according to workchain input sepc. Different from `get_builder_from_protocol', this function is executed at runtime. """ # pylint: disable=too-many-statements,too-many-locals,too-many-branches import numpy as np from aiida_wannier90_workflows.utils.bands import ( get_homo_lumo, remove_exclude_bands, ) from aiida_wannier90_workflows.utils.scdm import get_energy_of_projectability inputs = AttributeDict( self.exposed_inputs(Wannier90Calculation, self._inputs_namespace) ) parameters = inputs.parameters.get_dict() if self.inputs.shift_energy_windows: fermi_energy = parameters["fermi_energy"] # For metal, we shift the four parameters by Fermi energy. shift_energy = fermi_energy if "bands" in self.inputs: # Check the system is metal or insulator. # For insulator, we shift them by the minimum of LUMO. bands = self.inputs.bands.get_bands() homo, lumo = get_homo_lumo(bands, fermi_energy) bandgap = lumo - homo if bandgap > 1e-3: shift_energy = lumo keys = ("dis_froz_min", "dis_froz_max", "dis_win_min", "dis_win_max") for key in keys: if key in parameters: parameters[key] += shift_energy # Auto set `dis_froz_max` if self.inputs.auto_energy_windows: dis_froz_max = get_energy_of_projectability( bands=self.inputs.bands, projections=self.inputs.bands_projections, thresholds=self.inputs.auto_energy_windows_threshold.value, ) parameters["dis_froz_max"] = dis_froz_max # Prevent error: # dis_windows: More states in the frozen window than target WFs if "dis_froz_max" in parameters and "bands" in self.inputs: bands = self.inputs.bands.get_bands() if parameters.get("exclude_bands", None): # Index of parameters['exclude_bands'] starts from 1, # I need to change it to 0-based exclude_bands = [_ - 1 for _ in parameters["exclude_bands"]] bands = remove_exclude_bands(bands=bands, exclude_bands=exclude_bands) highest_band = bands[:, min(parameters["num_wann"], bands.shape[1]) - 1] # This is the energy to freeze all the num_wann bands max_froz_energy = np.max(highest_band) # Cannot freeze more bands than num_wann, # this sets the upper limit of `dis_froz_max`. if bands.shape[1] > parameters["num_wann"]: min_next_band = np.min(bands[:, parameters["num_wann"]]) # I subtract a small value for safety min_next_band -= 1e-4 max_froz_energy = min(max_froz_energy, min_next_band) # `dis_froz_max` should be smaller than this max_froz_energy # to allow doing disentanglement dis_froz_max = min(max_froz_energy, parameters["dis_froz_max"]) parameters["dis_froz_max"] = dis_froz_max inputs.parameters = orm.Dict(parameters) if "remote_input_folder" in inputs and "settings" in self.inputs: # Note there is an `additional_remote_symlink_list` in Wannier90Calculation.inputs.settings, # however it requires user providing a list of # (computer_uuid, remote_input_folder_abs_path, dest_path) # This is impossible if we launch a Wannier90Calculation inside a workflow since we don't # know the remote_input_folder when setting the inputs of the workflow. # Thus I add an `inputs.settings['remote_symlink_files']` to Wannier90BaseWorkChain, # which only accepts a list of filenames and generate the full # `additional_remote_symlink_list` here. remote_input_folder = inputs["remote_input_folder"] remote_input_folder_path = pathlib.Path( remote_input_folder.get_remote_path() ) workflow_settings = self.inputs.settings.get_dict() if "settings" in inputs: calc_settings = inputs.settings.get_dict() else: calc_settings = {} remote_symlink_list = calc_settings.get( "additional_remote_symlink_list", [] ) existed_symlinks = [_[-1] for _ in remote_symlink_list] for filename in workflow_settings.get("remote_symlink_files", []): if filename in existed_symlinks: continue remote_symlink_list.append( ( remote_input_folder.computer.uuid, str(remote_input_folder_path / filename), filename, ) ) calc_settings["additional_remote_symlink_list"] = remote_symlink_list inputs.settings = orm.Dict(calc_settings) return inputs def report_error_handled(self, calculation, action) -> None: """Report an action taken for a calculation that has failed. This should be called in a registered error handler if its condition is met and an action was taken. :param calculation: the failed calculation node :param action: a string message with the action taken """ message = f"{calculation.process_label}<{calculation.pk}> failed" message += ( f" with exit status {calculation.exit_status}: {calculation.exit_message}" ) self.report(message) self.report(f"Action taken: {action}") @process_handler(exit_codes=[Wannier90Calculation.exit_codes.ERROR_BVECTORS]) def handle_bvectors(self, calculation) -> ProcessHandlerReport: """Try to fix Wannier90 bvectors errors by tunning `kmesh_tol`. The handler will try to use kmesh_tol = 1e-6, 1e-8, 1e-4. """ parameters = self.ctx.inputs.parameters.get_dict() # If the user has specified `kmesh_tol` in the input parameters and it is different # from the default, we will first try to use the default `kmesh_tol`. current_kmeshtol = parameters.get( "kmesh_tol", self._WANNIER90_DEFAULT_KMESH_TOL ) if current_kmeshtol in self.ctx.kmeshtol_new: self.ctx.kmeshtol_new.remove(current_kmeshtol) if len(self.ctx.kmeshtol_new) == 0: action = "Unrecoverable bvectors error after several trials of kmesh_tol" self.report_error_handled(calculation, action) return ProcessHandlerReport(True, self.exit_codes.ERROR_BVECTORS) new_kmeshtol = self.ctx.kmeshtol_new.pop(0) parameters["kmesh_tol"] = new_kmeshtol action = f"Bvectors error, current kmesh_tol = {current_kmeshtol}, new kmesh_tol = {new_kmeshtol}" self.report_error_handled(calculation, action) self.ctx.inputs.parameters = orm.Dict(parameters) return ProcessHandlerReport(True) @process_handler( exit_codes=[ Wannier90Calculation.exit_codes.ERROR_DISENTANGLEMENT_NOT_ENOUGH_STATES ] ) def handle_disentanglement_not_enough_states( self, calculation ) -> ProcessHandlerReport: """Try to fix Wannier90 wout error message related to projectability disentanglement. The error message is: 'Energy window contains fewer states than number of target WFs, consider reducing dis_proj_min/increasing dis_win_max?'. The handler will try to use decrease 'dis_proj_min' to allow for more states for disentanglement. """ parameters = self.ctx.inputs.parameters.get_dict() if "dis_proj_min" not in parameters and "dis_proj_max" not in parameters: # If neither is present, I should never encounter this exit_code action = "Unrecoverable bvectors error: the error handler is only for projectability disentanglement" self.report_error_handled(calculation, action) return ProcessHandlerReport( True, self.exit_codes.ERROR_DISENTANGLEMENT_NOT_ENOUGH_STATES ) if len(self.ctx.disprojmin_multipliers) == 0: action = "Unrecoverable error after several trials of dis_proj_min" self.report_error_handled(calculation, action) return ProcessHandlerReport( True, self.exit_codes.ERROR_DISENTANGLEMENT_NOT_ENOUGH_STATES ) current_disprojmin = parameters.get( "dis_proj_min", self._WANNIER90_DEFAULT_KMESH_TOL ) multiplier = self.ctx.disprojmin_multipliers.pop(0) new_disprojmin = current_disprojmin * multiplier parameters["dis_proj_min"] = new_disprojmin action = "Not enough states for disentanglement, " action += f"current dis_proj_min = {current_disprojmin}, new dis_proj_min = {new_disprojmin}" self.report_error_handled(calculation, action) self.ctx.inputs.parameters = orm.Dict(parameters) return ProcessHandlerReport(True) @process_handler(exit_codes=[Wannier90Calculation.exit_codes.ERROR_PLOT_WF_CUBE]) def handle_plot_wf_cube(self, calculation) -> ProcessHandlerReport: """Try to fix Wannier90 wout error message related to cube format. The error message is: 'Error plotting WF cube. Try one of the following:'. The handler will try to increase 'wannier_plot_supercell'. """ parameters = self.ctx.inputs.parameters.get_dict() current_supercell = parameters.get( "wannier_plot_supercell", self._WANNIER90_DEFAULT_WANNIER_PLOT_SUPERCELL ) # Remove sizes which are smaller equal than current supercell size self.ctx.wannier_plot_supercell_new = [ _ for _ in self.ctx.wannier_plot_supercell_new if _ > current_supercell ] if len(self.ctx.wannier_plot_supercell_new) == 0: action = ( "Unrecoverable error after several trials of wannier_plot_supercell" ) self.report_error_handled(calculation, action) return ProcessHandlerReport(True, self.exit_codes.ERROR_PLOT_WF_CUBE) new_supercell = self.ctx.wannier_plot_supercell_new.pop(0) parameters["wannier_plot_supercell"] = new_supercell action = "Error plotting WFs in cube format, " action += f"current wannier_plot_supercell = {current_supercell}, new wannier_plot_supercell = {new_supercell}" self.report_error_handled(calculation, action) self.ctx.inputs.parameters = orm.Dict(parameters) return ProcessHandlerReport(True) @process_handler( exit_codes=[Wannier90Calculation.exit_codes.ERROR_OUTPUT_STDOUT_INCOMPLETE] ) def handle_output_stdout_incomplete(self, calculation) -> ProcessHandlerReport: """Try to fix incomplete stdout error by reducing the number of cores. Often the ERROR_OUTPUT_STDOUT_INCOMPLETE is due to out-of-memory. The handler will try to set `num_mpiprocs_per_machine` to 1. """ import re regex_oom = re.compile(r"Detected \d+ oom-kill event\(s\) in step") regex_bus = re.compile( r"Program received signal SIGBUS: Access to an undefined portion of a memory object" ) # srun: error: nid002144: tasks 0-63: Killed regex_killed = re.compile(r"srun: error: \w+: tasks? [\w\-,]+: Killed") scheduler_stderr = calculation.get_scheduler_stderr() for line in scheduler_stderr.split("\n"): if ( regex_oom.search(line) or regex_bus.search(line) or regex_killed.search(line) or "Out Of Memory" in line ): break else: action = "Unrecoverable incomplete stdout error" self.report_error_handled(calculation, action) return ProcessHandlerReport( True, self.exit_codes.ERROR_OUTPUT_STDOUT_INCOMPLETE ) metadata = self.ctx.inputs["metadata"] current_num_mpiprocs_per_machine = metadata["options"]["resources"].get( "num_mpiprocs_per_machine", 1 ) # num_mpiprocs_per_machine = calculation.attributes['resources'].get('num_mpiprocs_per_machine', 1) if current_num_mpiprocs_per_machine == 1: action = "Unrecoverable out-of-memory error after setting num_mpiprocs_per_machine to 1" self.report_error_handled(calculation, action) return ProcessHandlerReport( True, self.exit_codes.ERROR_OUTPUT_STDOUT_INCOMPLETE ) new_num_mpiprocs_per_machine = current_num_mpiprocs_per_machine // 2 metadata["options"]["resources"][ "num_mpiprocs_per_machine" ] = new_num_mpiprocs_per_machine action = f"Out-of-memory error, current num_mpiprocs_per_machine = {current_num_mpiprocs_per_machine}" action += f", new num_mpiprocs_per_machine = {new_num_mpiprocs_per_machine}" self.report_error_handled(calculation, action) self.ctx.inputs["metadata"] = metadata return ProcessHandlerReport(True)