Source code for aiida_wannier90_workflows.workflows.optimize

"""Workchain to automatically optimize dis_proj_min/max for projectability disentanglement."""
import pathlib
import typing as ty
import warnings

import numpy as np

from aiida import orm
from aiida.engine import ProcessBuilder, ToContext, append_, if_, while_
from aiida.orm.nodes.data.base import to_aiida_type

from aiida_quantumespresso.utils.mapping import prepare_process_inputs

from aiida_wannier90_workflows.utils.workflows import get_last_calcjob

from .bands import Wannier90BandsWorkChain
from .base.wannier90 import Wannier90BaseWorkChain

__all__ = ["validate_inputs", "Wannier90OptimizeWorkChain"]


def validate_inputs(inputs, ctx=None):  # pylint: disable=unused-argument
    """Validate the inputs of the entire input namespace of `Wannier90OptimizeWorkChain`."""
    from .bands import validate_inputs as parent_validate_inputs

    # Call parent validator
    result = parent_validate_inputs(inputs)
    if result is not None:
        return result

    parameters = inputs["wannier90"]["wannier90"]["parameters"].get_dict()

    if inputs["optimize_disproj"]:
        if all(_ not in parameters for _ in ("dis_proj_min", "dis_proj_max")):
            return "Trying to optimize dis_proj_min/max but no dis_proj_min/max in wannier90 parameters?"

    if "optimize_reference_bands" in inputs and not inputs["optimize_disproj"]:
        warnings.warn(
            "`optimize_reference_bands` is provided but `optimize_disproj = False`?"
        )

    if (
        "optimize_bands_distance_threshold" in inputs
        and "optimize_reference_bands" not in inputs
    ):
        return "No `optimize_reference_bands` but `optimize_bands_distance_threshold` is set?"

    if inputs["separate_plotting"]:
        plot_inputs = [
            parameters.get(_, False)
            for _ in Wannier90OptimizeWorkChain._WANNIER90_PLOT_INPUTS  # pylint: disable=protected-access
        ]
        if not any(plot_inputs):
            return (
                "Trying to separate plotting routines but no "
                f"{'/'.join(Wannier90OptimizeWorkChain._WANNIER90_PLOT_INPUTS)} in wannier90 parameters?"  # pylint: disable=protected-access
            )

    if inputs["optimize_disproj"] and not inputs["separate_plotting"]:
        warnings.warn(
            "`optimize_disproj = True` but `separate_plotting = False`. For optimizing projectability "
            "disentanglement, it is highly recommended to run the plotting mode in a separate step."
        )

    return None


[docs]class Wannier90OptimizeWorkChain(Wannier90BandsWorkChain): """Workchain to optimize dis_proj_min/max for projectability disentanglement.""" # The following keys are for wannier90.x plotting, i.e. they can be restarted from # chk file by setting `restart = plot` in wannier90.win. _WANNIER90_PLOT_INPUTS = ( "wannier_plot", "bands_plot", "write_tb", "write_hr", "write_hhmn", "write_hkmn", "write_hvmn", "write_hdmn", ) @classmethod def define(cls, spec): """Define the process spec.""" super().define(spec) spec.input( "separate_plotting", valid_type=orm.Bool, default=lambda: orm.Bool(False), serializer=to_aiida_type, help=( "If True separate the maximal localisation and the plotting of bands/Wannier function in two steps. " "This allows reusing the chk file to restart plotting if it were crashed due to memory issue." ), ) spec.input( "optimize_disproj", valid_type=orm.Bool, default=lambda: orm.Bool(True), serializer=to_aiida_type, help=( "If True iterate dis_proj_min/max to find the best MLWFs for projectability disentanglement." ), ) spec.input( "optimize_disprojmax_range", valid_type=orm.List, default=lambda: orm.List(list=list(np.linspace(0.99, 0.85, 15))), serializer=to_aiida_type, help=( "The range to iterate dis_proj_min. `None` means disabling projectability disentanglement." ), ) spec.input( "optimize_disprojmin_range", valid_type=orm.List, default=lambda: orm.List(list=list(np.linspace(0.01, 0.02, 2))), serializer=to_aiida_type, help=( "The range to iterate dis_proj_max. `None` means disabling projectability disentanglement." ), ) spec.input( "optimize_reference_bands", valid_type=orm.BandsData, required=False, help=( "If provided, during the iteration of dis_proj_min/max, the BandsData will be the reference " "for calculating bands distance, the final optimal MLWFs will be selected based on both spreads " "and bands distance. If not provided, spreads will be the criterion for selecting optimal MLWFs. " "The bands distance is calculated for bands below Fermi energy + 2eV." ), ) spec.input( "optimize_bands_distance_threshold", valid_type=orm.Float, required=False, serializer=to_aiida_type, help=( "If provided, during the iteration of dis_proj_min/max, if the bands distance is smaller " "than this threshold, the optimization will stop. Unit is eV." ), ) spec.input( "optimize_spreads_imbalence_threshold", valid_type=orm.Float, required=False, serializer=to_aiida_type, help=( "If provided, during the iteration of dis_proj_min/max, if the spreads imbalence is smaller " "than this threshold, the optimization will stop." ), ) spec.inputs.validator = validate_inputs spec.outline( cls.setup, if_(cls.should_run_seekpath)( cls.run_seekpath, ), if_(cls.should_run_scf)( cls.run_scf, cls.inspect_scf, ), if_(cls.should_run_nscf)( cls.run_nscf, cls.inspect_nscf, ), if_(cls.should_run_open_grid)( cls.run_open_grid, cls.inspect_open_grid, ), if_(cls.should_run_projwfc)( cls.run_projwfc, cls.inspect_projwfc, ), cls.run_wannier90_pp, cls.inspect_wannier90_pp, cls.run_pw2wannier90, cls.inspect_pw2wannier90, cls.run_wannier90, cls.inspect_wannier90, while_(cls.should_run_wannier90_optimize)( cls.run_wannier90_optimize, cls.inspect_wannier90_optimize, ), cls.inspect_wannier90_optimize_final, if_(cls.should_run_wannier90_plot)( cls.run_wannier90_plot, cls.inspect_wannier90_plot, ), cls.results, ) spec.expose_outputs( Wannier90BaseWorkChain, namespace="wannier90_optimal", namespace_options={"required": False}, ) spec.expose_outputs( Wannier90BaseWorkChain, namespace="wannier90_plot", namespace_options={"required": False}, ) spec.output( "bands_distance", valid_type=orm.Float, required=False, help="Bands distances between reference bands and Wannier interpolated bands for Ef to Ef+5eV.", ) spec.exit_code( 500, "ERROR_SUB_PROCESS_FAILED_WANNIER90_OPTIMIZE", message="All the trials on dis_proj_min/max have failed, cannot compare bands distance", ) spec.exit_code( 501, "ERROR_SUB_PROCESS_FAILED_WANNIER90_OPTIMIZE", message="All the trials on dis_proj_min/max have failed, cannot compare spreads", ) spec.exit_code( 500, "ERROR_SUB_PROCESS_FAILED_WANNIER90_PLOT", message="the Wannier90Calculation plotting sub process failed", ) @classmethod def get_protocol_filepath(cls) -> pathlib.Path: """Return ``pathlib.Path`` to the ``.yaml`` file that defines the protocols.""" from importlib_resources import files from . import protocols return files(protocols) / "optimize.yaml" @classmethod def get_builder_from_protocol( # pylint: disable=arguments-differ cls, codes: ty.Mapping[str, ty.Union[str, int, orm.Code]], structure: orm.StructureData, *, reference_bands: orm.BandsData = None, bands_distance_threshold: float = 1e-2, # unit is eV **kwargs, ) -> ProcessBuilder: """Return a builder prepopulated with inputs selected according to the specified arguments. :return: [description] :rtype: ProcessBuilder """ from aiida_wannier90_workflows.utils.workflows.bands import ( has_overlapping_semicore, ) from aiida_wannier90_workflows.utils.workflows.builder.submit import ( recursive_merge_builder, ) parent_builder = super().get_builder_from_protocol(codes, structure, **kwargs) if reference_bands is not None: exclude_semicore = kwargs.get("exclude_semicore", True) if exclude_semicore: params = parent_builder.wannier90.wannier90.parameters.get_dict() exclude_bands = params.get("exclude_bands", None) overlapping_semicore = has_overlapping_semicore( reference_bands, exclude_bands ) if overlapping_semicore: warnings.warn( "The reference bands has overlapping semicore bands, " "the exclude_semicore option is set to False." ) kwargs["exclude_semicore"] = False parent_builder = super().get_builder_from_protocol( codes, structure, **kwargs ) # Prepare workchain builder builder = Wannier90OptimizeWorkChain.get_builder() inputs = Wannier90OptimizeWorkChain.get_protocol_inputs( protocol=kwargs.get("protocol", None), overrides=kwargs.get("overrides", None), ) builder = recursive_merge_builder(builder, inputs) inputs = parent_builder._inputs(prune=True) # pylint: disable=protected-access builder = recursive_merge_builder(builder, inputs) # Inputs for optimizing dis_proj_min/max if reference_bands: builder.separate_plotting = True builder.optimize_disproj = True builder.optimize_reference_bands = reference_bands builder.optimize_bands_distance_threshold = bands_distance_threshold return builder def setup(self): """Define the current structure in the context to be the input structure.""" super().setup() dis_proj_min = self.inputs["optimize_disprojmin_range"].get_list() dis_proj_max = self.inputs["optimize_disprojmax_range"].get_list() # dis_proj_max changes the fastest self.ctx.optimize_minmax_new = [ (i, j) for i in dis_proj_min for j in dis_proj_max ] # Arrays to save calculated results self.ctx.optimize_minmax = [] self.ctx.optimize_bandsdist = [] self.ctx.optimize_spreads_imbalence = [] # The optimal wannier90 workchain self.ctx.optimize_best = None # For separate_plotting, restore these inputs when running plotting calc. self.ctx.saved_parameters = {} if self.inputs["separate_plotting"]: parameters = self.inputs.wannier90.wannier90["parameters"].get_dict() # I convert the tuple to list so it can be changed excluded_inputs = list(Wannier90OptimizeWorkChain._WANNIER90_PLOT_INPUTS) # I need to calculate bands for comparing bands distance if "optimize_reference_bands" in self.inputs: excluded_inputs.remove("bands_plot") for key in excluded_inputs: plot_input = parameters.get(key, False) if plot_input: self.ctx.saved_parameters[key] = plot_input def should_run_wannier90_optimize(self): """Whether should optimize dis_proj_min/max.""" if not self.inputs["optimize_disproj"]: return False if "optimize_bands_distance_threshold" in self.inputs: threshold = self.inputs["optimize_bands_distance_threshold"] if self.ctx.workchain_wannier90_bandsdist <= threshold: # Stop if the initial bands distance is already good enough self.ctx.optimize_minmax_new = [] else: # Replace `None` by a huge number to avoid np.min error: # TypeError: '<=' not supported between instances of 'float' and 'NoneType' opt_dist = [_ if _ else 1e5 for _ in self.ctx.optimize_bandsdist] if len(opt_dist) > 0 and np.min(opt_dist) <= threshold: self.ctx.optimize_minmax_new = [] elif "optimize_spreads_imbalence_threshold" in self.inputs: threshold = self.inputs["optimize_spreads_imbalence_threshold"] if self.ctx.workchain_wannier90_spreads_imbalence <= threshold: # Stop if the initial spreads are already good enough self.ctx.optimize_minmax_new = [] elif ( len(self.ctx.optimize_spreads_imbalence) > 0 and np.min(self.ctx.optimize_spreads_imbalence) <= threshold ): self.ctx.optimize_minmax_new = [] if len(self.ctx.optimize_minmax_new) == 0: return False return True def has_run_wannier90_optimize(self): """Whether the optimization loop has been invoked.""" return "workchain_wannier90_optimize" in self.ctx def should_run_wannier90_plot(self): """Whether to run wannier90 maximal localisation and plotting in two steps or in one step.""" return self.inputs["separate_plotting"] def prepare_wannier90_pp_inputs(self): """Override parent method. :return: the inputs port :rtype: InputPort """ base_inputs = super().prepare_wannier90_pp_inputs() inputs = base_inputs["wannier90"] parameters = inputs.parameters.get_dict() # Do not run plotting subroutines in the Wannier90BandsWorkChain, they will be run in a separate step if self.should_run_wannier90_plot(): for key in self.ctx.saved_parameters: parameters.pop(key, None) inputs.parameters = orm.Dict(parameters) if "optimize_reference_bands" not in self.inputs: inputs.pop("kpoint_path", None) inputs.pop("bands_kpoints", None) base_inputs["wannier90"] = inputs # This will use the input bands which usually is a bands along kpath, # also used to calculate bands distance. # However in Wannier90BaseWorkChain, this bands will also be used to # set the max of dis_froz_max (in `prepare_inputs()`), however, if # the number of bands of this input is too small, this will lead to # too small max limit of dis_froz_max, essentially causing a very low # dis_froz_max in the actual wannier90 calculation. # In such case, it is safer to use the output_band of scf or nscf step # inside the workflow, however # these are bands on a grid, usually their LUMO is a bit larger than # LUMO from bands along kpath, so when shifting dis_froz_max w.r.t. LUMO, # it might be a bit inaccurate. # # The parent class Wannier90BandsWorkChain.inputs.wannier90.bands is empty, # so it will use output_band of scf or nscf. Here I explicitly overwrite the # wannier90.bands by the reference_bands. if self.inputs.optimize_disproj and "optimize_reference_bands" in self.inputs: base_inputs.bands = self.inputs.optimize_reference_bands return base_inputs def run_wannier90(self): """Overide parent, pop stash settings.""" inputs = self.prepare_wannier90_inputs() # I should not stash files if there is an additional plotting step, # otherwise there is a RemoteStashFolderData in outputs if self.should_run_wannier90_plot(): inputs["wannier90"]["metadata"]["options"].pop("stash", None) inputs["metadata"] = {"call_link_label": "wannier90"} inputs = prepare_process_inputs(Wannier90BaseWorkChain, inputs) running = self.submit(Wannier90BaseWorkChain, **inputs) self.report(f"launching {running.process_label}<{running.pk}>") return ToContext(workchain_wannier90=running) def inspect_wannier90(self): """Overide parent.""" super().inspect_wannier90() workchain = self.ctx.workchain_wannier90 if "optimize_reference_bands" in self.inputs: bandsdist = self._get_bands_distance(workchain) self.ctx.workchain_wannier90_bandsdist = bandsdist self.report( f"current workchain<{workchain.pk}> bands distance={bandsdist:.2e}eV" ) self.ctx.workchain_wannier90_spreads_imbalence = get_spreads_imbalence( workchain.outputs.output_parameters["wannier_functions_output"] ) def prepare_wannier90_optimize_inputs(self): """Prepare inputs for optimize run.""" base_inputs = super().prepare_wannier90_inputs() inputs = base_inputs["wannier90"] # I need to save `inputs.wannier90.metadata.options.resources`, somehow it is missing if I # copy all the inputs from `self.ctx.workchain_wannier90`. # resources = base_inputs['wannier90']['wannier90']['metadata']['options']['resources'] # Use the Wannier90BaseWorkChain-corrected parameters, especially `num_mpiprocs_per_machine` last_calc = get_last_calcjob(self.ctx.workchain_wannier90) for key in last_calc.inputs: inputs[key] = last_calc.inputs[key] parameters = inputs.parameters.get_dict() dis_proj_min, dis_proj_max = self.ctx.optimize_minmax_new[0] parameters["dis_proj_min"] = dis_proj_min parameters["dis_proj_max"] = dis_proj_max if "optimize_reference_bands" in self.inputs: parameters["bands_plot"] = True if self.ctx.current_kpoint_path: inputs.kpoint_path = self.ctx.current_kpoint_path if self.ctx.current_bands_kpoints: inputs.bands_kpoints = self.ctx.current_bands_kpoints inputs.parameters = orm.Dict(parameters) # I should not stash files if there is an additional plotting step, # otherwise there is a RemoteStashFolderData in outputs if self.should_run_wannier90_plot(): inputs["metadata"]["options"].pop("stash", None) base_inputs["wannier90"] = inputs base_inputs["clean_workdir"] = orm.Bool(False) return base_inputs def run_wannier90_optimize(self): """Optimize dis_proj_min/max.""" inputs = self.prepare_wannier90_optimize_inputs() iteration = len(self.ctx.optimize_minmax) + 1 # Start from 1 inputs["metadata"] = { "call_link_label": f"wannier90_optimize_iteration{iteration}" } # Disable the error handler which might modify dis_proj_min handler_overrides = {"handle_disentanglement_not_enough_states": False} inputs["handler_overrides"] = orm.Dict(handler_overrides) inputs = prepare_process_inputs(Wannier90BaseWorkChain, inputs) running = self.submit(Wannier90BaseWorkChain, **inputs) dis_proj_min, dis_proj_max = self.ctx.optimize_minmax_new[0] self.report( f"launching {running.process_label}<{running.pk}> with dis_proj_min={dis_proj_min} " f"dis_proj_max={dis_proj_max}" ) return ToContext(workchain_wannier90_optimize=append_(running)) def inspect_wannier90_optimize(self): """Verify that the `Wannier90BaseWorkChain` for the wannier90 optimization run successfully finished.""" workchain = self.ctx.workchain_wannier90_optimize[-1] if workchain.is_finished_ok: spreads = get_spreads_imbalence( workchain.outputs.output_parameters["wannier_functions_output"] ) if ( "optimize_reference_bands" in self.inputs and "interpolated_bands" in workchain.outputs ): bandsdist = self._get_bands_distance(workchain) self.report( f"current workchain<{workchain.pk}> bands distance={bandsdist:.2e}eV" ) else: bandsdist = None else: self.report( f"{workchain.process_label} failed with exit status {workchain.exit_status}, " "but I will keep launching next iteration" ) spreads = None bandsdist = None minmax = self.ctx.optimize_minmax_new.pop(0) self.ctx.optimize_minmax.append(minmax) self.ctx.optimize_bandsdist.append(bandsdist) self.ctx.optimize_spreads_imbalence.append(spreads) def inspect_wannier90_optimize_final(self): """Select the optimal choice for dis_proj_min/max.""" if not self.has_run_wannier90_optimize(): return workchains = self.ctx.workchain_wannier90_optimize # The optimal wannier90 workchain self.ctx.optimize_best = None if "optimize_reference_bands" in self.inputs: # Usually good bands distance means MLWFs have good spreads fake_max = 1e5 bandsdist = np.array( [_ if _ else fake_max for _ in self.ctx.optimize_bandsdist] ) idx = np.argmin(bandsdist) if bandsdist[idx] < min(fake_max, self.ctx.workchain_wannier90_bandsdist): self.ctx.optimize_best = workchains[idx] opt_bandsdist = bandsdist[idx] minmax = self.ctx.optimize_minmax[idx] else: # All optimizations failed, just output the initial w90 self.ctx.optimize_best = self.ctx.workchain_wannier90 opt_bandsdist = self.ctx.workchain_wannier90_bandsdist # dis_proj_min/max might be corrected by error handlers, # output the last min/max. last_calc = get_last_calcjob(self.ctx.optimize_best) params = last_calc.inputs.parameters.get_dict() minmax = ( params.get("dis_proj_min", None), params.get("dis_proj_max", None), ) self.report( f"Optimal bands distance={opt_bandsdist:.2e}, " f"dis_proj_min={minmax[0]} dis_proj_max={minmax[1]}" ) else: # I only check the spreads are balenced spreads = np.array( [_ if _ else 1e5 for _ in self.ctx.optimize_spreads_imbalence] ) idx = np.argmin(spreads) self.ctx.optimize_best = workchains[idx] minmax = self.ctx.optimize_minmax[idx] self.report( f"Optimal spreads={spreads[idx]}, " f"dis_proj_min={minmax[0]} dis_proj_max={minmax[1]}" ) self.ctx.current_folder = self.ctx.optimize_best.outputs.remote_folder def prepare_wannier90_plot_inputs(self): """Wannier90 plot step, also stash files.""" # Note with `Wannier90WorkChain.prepare_wannier90_inputs()`, the stash setting # has been restored. base_inputs = super().prepare_wannier90_inputs() inputs = base_inputs["wannier90"] # Use the corrected parameters if self.has_run_wannier90_optimize(): # Use the optimal parameters optimal_workchain = self.ctx.optimize_best else: # Just use the base workchain optimal_workchain = self.ctx.workchain_wannier90 # Copy inputs, especially the `dis_proj_min/max` might have been corrected last_calc = get_last_calcjob(optimal_workchain) for key in last_calc.inputs: inputs[key] = last_calc.inputs[key] # Use `current_folder` which points to the optimal wannier90 folder, since we need the chk file. # However we need to explicitly # symlink UNK files in that folder, otherwise the plot calculation would fail. inputs["remote_input_folder"] = self.ctx.current_folder # Maybe in aiida-w90, should let Calculation accepts an optional SinglefileData? for chk, # so we don't need to explicitly symlink. settings = inputs.settings.get_dict() remote_input_folder_uuid = ( self.ctx.workchain_pw2wannier90.outputs.remote_folder.computer.uuid ) remote_input_folder_path = pathlib.Path( self.ctx.workchain_pw2wannier90.outputs.remote_folder.get_remote_path() ) additional_remote_symlink_list = settings.get( "additional_remote_symlink_list", [] ) additional_remote_symlink_list += [ (remote_input_folder_uuid, str(remote_input_folder_path / "UNK*"), ".") ] settings["additional_remote_symlink_list"] = additional_remote_symlink_list inputs.settings = orm.Dict(settings) # Restore plotting related tags parameters = inputs.parameters.get_dict() for key in self.ctx.saved_parameters: parameters[key] = True parameters["restart"] = "plot" inputs.parameters = orm.Dict(parameters) if parameters.get("bands_plot", False): if self.ctx.current_kpoint_path: inputs.kpoint_path = self.ctx.current_kpoint_path if self.ctx.current_bands_kpoints: inputs.bands_kpoints = self.ctx.current_bands_kpoints base_inputs["wannier90"] = inputs base_inputs["clean_workdir"] = orm.Bool(False) return base_inputs def run_wannier90_plot(self): """Wannier90 plot step, also stash files.""" inputs = self.prepare_wannier90_plot_inputs() inputs["metadata"] = {"call_link_label": "wannier90_plot"} inputs = prepare_process_inputs(Wannier90BaseWorkChain, inputs) running = self.submit(Wannier90BaseWorkChain, **inputs) self.report(f"launching {running.process_label}<{running.pk}> in plotting mode") return ToContext(workchain_wannier90_plot=running) def inspect_wannier90_plot(self): """Verify that the `Wannier90BaseWorkChain` for the wannier90 plotting run successfully finished.""" workchain = self.ctx.workchain_wannier90_plot if not workchain.is_finished_ok: self.report( f"{workchain.process_label} failed with exit status {workchain.exit_status}" ) return self.exit_codes.ERROR_SUB_PROCESS_FAILED_WANNIER90_PLOT self.ctx.current_folder = workchain.outputs.remote_folder return None def results(self): """Attach the relevant output nodes from the band calculation to the workchain outputs for convenience.""" super().results() if self.inputs["optimize_disproj"]: if self.has_run_wannier90_optimize(): optimal_workchain = self.ctx.optimize_best else: optimal_workchain = self.ctx.workchain_wannier90 self.out_many( self.exposed_outputs( optimal_workchain, Wannier90BaseWorkChain, namespace="wannier90_optimal", ) ) if self.should_run_wannier90_plot(): self.out_many( self.exposed_outputs( self.ctx.workchain_wannier90_plot, Wannier90BaseWorkChain, namespace="wannier90_plot", ) ) if "interpolated_bands" in self.outputs["wannier90_plot"]: w90_bands = self.outputs["wannier90_plot"]["interpolated_bands"] self.out("band_structure", w90_bands) if "optimize_reference_bands" in self.inputs: if self.has_run_wannier90_optimize(): optimal_workchain = self.ctx.optimize_best else: # Even if I haven't run optimization, I still output bands distance if reference bands is present optimal_workchain = self.ctx.workchain_wannier90 bandsdist = self._get_bands_distance(optimal_workchain) bandsdist = orm.Float(bandsdist) bandsdist.store() self.out("bands_distance", bandsdist) def _get_bands_distance(self, wannier_workchain: Wannier90BaseWorkChain) -> float: """Get bands distance for Fermi energy + 2eV.""" ref_bands = self.inputs["optimize_reference_bands"] bandsdist = get_bands_distance_ef2(ref_bands, wannier_workchain) return bandsdist
def get_bands_distance_ef2( ref_bands: orm.BandsData, wannier_workchain: Wannier90BaseWorkChain ) -> float: """Get bands distance for E <= Fermi energy + 2eV.""" from aiida_wannier90_workflows.utils.bands.distance import bands_distance wan_bands = wannier_workchain.outputs["interpolated_bands"] wan_parameters = wannier_workchain.inputs["wannier90"]["parameters"].get_dict() fermi_energy = wan_parameters.get("fermi_energy") exclude_list_dft = wan_parameters.get("exclude_bands", None) # Bands distance from Ef to Ef+5 bandsdist = bands_distance(ref_bands, wan_bands, fermi_energy, exclude_list_dft) # Only return average distance, not max distance bandsdist = bandsdist[:, 1] # Return Ef+2 bandsdist = bandsdist[2] return bandsdist def get_spreads_imbalence(wannier_functions_output: dict) -> float: """Calculate the variance of spreads. There could be other ways to calculate the spreads imbalence, for now I just use variance. :param wannier_functions_output: [description] :type wannier_functions_output: dict :return: [description] :rtype: float """ spreads = [_["wf_spreads"] for _ in wannier_functions_output] var = np.var(spreads) # TODO try K-Means clustering? pylint: disable=fixme return var