# Source code for aiida_wannier90_workflows.workflows.bands

"""WorkChain to automatically calculate Wannier band structure."""
import pathlib
import typing as ty

from aiida import orm
from aiida.engine import ProcessBuilder, if_
from aiida.orm.nodes.data.base import to_aiida_type

from .open_grid import Wannier90OpenGridWorkChain

__all__ = ["validate_inputs", "Wannier90BandsWorkChain"]


def validate_inputs(  # pylint: disable=unused-argument
    inputs, ctx=None
):
    """Validate the top-level input namespace of `Wannier90BandsWorkChain`.

    The parent (`open_grid`) validator runs first; its error message, if any,
    is returned unchanged. Afterwards the kpoint-related inputs are checked.

    :param inputs: mapping-like container of the workchain inputs.
    :param ctx: unused, accepted to match the validator call signature.
    :return: an error message string when validation fails, otherwise ``None``.
    """
    from .open_grid import validate_inputs as parent_validate_inputs

    # Defer to the parent validator and propagate its verdict untouched.
    parent_error = parent_validate_inputs(inputs)
    if parent_error is not None:
        return parent_error

    # The three ways of defining the band path are mutually exclusive.
    exclusive_ports = ("kpoint_path", "bands_kpoints", "bands_kpoints_distance")
    provided = [port for port in exclusive_ports if port in inputs]
    if len(provided) > 1:
        return "Can only specify one of the `kpoint_path`, `bands_kpoints` and `bands_kpoints_distance`."

    # Explicitly supplied kpoints must carry high-symmetry `labels`.
    for port in ("kpoint_path", "bands_kpoints"):
        if port in inputs and inputs[port].labels is None:
            return f"`{port}` must contain `labels`"

    return None


class Wannier90BandsWorkChain(Wannier90OpenGridWorkChain):
    """WorkChain to automatically compute a Wannier band structure for a given structure.

    On top of the parent workchain it adds three mutually exclusive inputs that
    control the band path (`kpoint_path`, `bands_kpoints`,
    `bands_kpoints_distance`), an optional seekpath step, and a
    `band_structure` output.
    """

    @classmethod
    def define(cls, spec):
        """Define the process specification."""
        super().define(spec)

        spec.input(
            "kpoint_path",
            valid_type=orm.KpointsData,
            required=False,
            help=(
                "High symmetry kpoints to use for the wannier90 bands interpolation. "
                "If specified, the high symmetry kpoint labels will be used and wannier90 will use the "
                "`bands_num_points` mechanism to auto generate a list of kpoints along the kpath. "
                "If not specified, the workchain will run seekpath to generate "
                "a primitive cell and a bands_kpoints. Specify either this or `bands_kpoints` "
                "or `bands_kpoints_distance`."
            ),
        )
        spec.input(
            "bands_kpoints",
            valid_type=orm.KpointsData,
            required=False,
            help=(
                "Explicit kpoints to use for the wannier90 bands interpolation. "
                "If specified, wannier90 will use this list of kpoints and will not use the "
                "`bands_num_points` mechanism to auto generate a list of kpoints along the kpath. "
                "If not specified, the workchain will run seekpath to generate "
                "a primitive cell and a bands_kpoints. Specify either this or `kpoint_path` "
                "or `bands_kpoints_distance`. "
                "This ensures the wannier interpolated bands has the exact same number of kpoints "
                "as PW bands, to calculate bands distance."
            ),
        )
        spec.input(
            "bands_kpoints_distance",
            valid_type=orm.Float,
            serializer=to_aiida_type,
            required=False,
            help="Minimum kpoints distance for seekpath to generate a list of kpoints along the path. "
            "Specify either this or `bands_kpoints` or `kpoint_path`.",
        )

        # We expose the in/output of `Wannier90OpenGridWorkChain` since `Wannier90WorkChain` in/output
        # is a subset of `Wannier90OpenGridWorkChain`, this allows us to launch either `Wannier90WorkChain`
        # or `Wannier90OpenGridWorkChain`.
        spec.expose_inputs(
            Wannier90OpenGridWorkChain,
            exclude=(
                "wannier90.wannier90.kpoint_path",
                "wannier90.wannier90.bands_kpoints",
            ),
            namespace_options={"required": True},
        )
        spec.inputs.validator = validate_inputs

        spec.outline(
            cls.setup,
            if_(cls.should_run_seekpath)(
                cls.run_seekpath,
            ),
            if_(cls.should_run_scf)(
                cls.run_scf,
                cls.inspect_scf,
            ),
            if_(cls.should_run_nscf)(
                cls.run_nscf,
                cls.inspect_nscf,
            ),
            if_(cls.should_run_open_grid)(
                cls.run_open_grid,
                cls.inspect_open_grid,
            ),
            if_(cls.should_run_projwfc)(
                cls.run_projwfc,
                cls.inspect_projwfc,
            ),
            cls.run_wannier90_pp,
            cls.inspect_wannier90_pp,
            cls.run_pw2wannier90,
            cls.inspect_pw2wannier90,
            cls.run_wannier90,
            cls.inspect_wannier90,
            cls.results,
        )

        spec.output(
            "primitive_structure",
            valid_type=orm.StructureData,
            required=False,
            help="The normalized and primitivized structure for which the calculations are computed.",
        )
        spec.output(
            "seekpath_parameters",
            valid_type=orm.Dict,
            required=False,
            help="The parameters used in the SeeKpath call to normalize the input or relaxed structure.",
        )
        spec.expose_outputs(
            Wannier90OpenGridWorkChain, namespace_options={"required": True}
        )
        spec.output(
            "band_structure",
            valid_type=orm.BandsData,
            help="The Wannier interpolated band structure.",
        )

    @classmethod
    def get_protocol_filepath(cls) -> pathlib.Path:
        """Return ``pathlib.Path`` to the ``.yaml`` file that defines the protocols."""
        from importlib_resources import files

        from . import protocols

        return files(protocols) / "bands.yaml"

    @classmethod
    def get_builder_from_protocol(  # pylint: disable=arguments-differ
        cls,
        codes: ty.Mapping[str, ty.Union[str, int, orm.Code]],
        structure: orm.StructureData,
        *,
        kpoint_path: ty.Optional[orm.KpointsData] = None,
        bands_kpoints: ty.Optional[orm.KpointsData] = None,
        bands_kpoints_distance: ty.Optional[float] = None,
        run_open_grid: bool = False,
        open_grid_only_scf: bool = True,
        **kwargs,
    ) -> ProcessBuilder:
        """Return a builder prepopulated with inputs selected according to the specified arguments.

        :param codes: a dictionary of codes for pw.x, pw2wannier90.x, wannier90.x, and optionally
            projwfc.x, open_grid.x.
        :type codes: ty.Mapping[str, ty.Union[str, int, orm.Code]]
        :param structure: the ``StructureData`` instance to use.
        :type structure: orm.StructureData
        :param kpoint_path: High symmetry kpoints to use for the Wannier bands interpolation.
            If `kpoint_path` or `bands_kpoints` is provided, the workchain will directly generate input
            parameters for the structure and use the provided `KpointsData`, e.g. when one wants to
            Wannierise a conventional cell structure.
            If not provided, will use seekpath to generate a primitive cell, and generate input
            parameters for the PRIMITIVE cell. After submission of the workchain, a
            `seekpath_structure_analysis` calcfunction will be launched to store the provenance from
            non-primitive cell to primitive cell.
            In any case, the `get_builder_from_protocol` will NOT launch any calcfunction so the
            aiida database is kept unmodified. Defaults to None.
        :type kpoint_path: orm.KpointsData, optional
        :param bands_kpoints: Explicit kpoints to use for the Wannier bands interpolation.
            See `kpoint_path`.
        :type bands_kpoints: orm.KpointsData, optional
        :param bands_kpoints_distance: Minimum kpoints distance for the Wannier bands interpolation.
            Specify either this or `kpoint_path` or `bands_kpoints`.
            If not provided, will use the default of seekpath. Defaults to None.
        :type bands_kpoints_distance: float, optional
        :param run_open_grid: if True use open_grid.x to accelerate calculations.
        :type run_open_grid: bool, defaults to False
        :param open_grid_only_scf: if True only one scf calculation will be performed
            in the OpenGridWorkChain.
        :type open_grid_only_scf: bool, defaults to True
        :param kwargs: forwarded to the parent ``get_builder_from_protocol``; the optional keys
            ``summary`` and ``print_summary`` are consumed here.
        :raises ValueError: if more than one of `kpoint_path`, `bands_kpoints` and
            `bands_kpoints_distance` is given, or if `run_open_grid` is combined with
            spin-orbit coupling.
        :return: a process builder instance with all inputs defined and ready for launch.
        :rtype: ProcessBuilder
        """
        from aiida.tools import get_explicit_kpoints_path

        from aiida_quantumespresso.common.types import SpinType

        from aiida_wannier90_workflows.utils.workflows.builder.submit import (
            recursive_merge_builder,
            recursive_merge_container,
        )

        kpt_inputs = (kpoint_path, bands_kpoints, bands_kpoints_distance)
        if sum(_ is not None for _ in kpt_inputs) > 1:
            raise ValueError(
                "Can only specify one of the `kpoint_path`, `bands_kpoints` and `bands_kpoints_distance`"
            )

        # The spin treatment is selected through the `spin_type` keyword; comparing the
        # `electronic_type` keyword against a `SpinType` member can never be true, which
        # silently disabled this guard.
        if run_open_grid and kwargs.get("spin_type", None) == SpinType.SPIN_ORBIT:
            raise ValueError("open_grid.x does not support spin orbit coupling")

        # I will call different parent_class.get_builder_from_protocol()
        if run_open_grid:
            # i.e. Wannier90OpenGridWorkChain
            parent_class = super()
            kwargs["open_grid_only_scf"] = open_grid_only_scf
        else:
            # i.e. Wannier90WorkChain
            parent_class = super(Wannier90OpenGridWorkChain, cls)

        summary = kwargs.pop("summary", {})
        print_summary = kwargs.pop("print_summary", True)

        if kpoint_path is None and bands_kpoints is None:
            # If no `kpoint_path` and `bands_kpoints` provided, the workchain will always run seekpath
            # even if the structure is a primitive cell.
            # However, if seekpath reduces the structure to a primitive cell, then I need to populate the
            # builder with parameters for the primitive cell, otherwise parameters depending on the
            # number of atoms, e.g. num_wann, num_bands, are wrong!
            #
            # In principle, the cleanest way to run workflows is first run a bunch of
            # `seekpath_structure_analysis`, store the primitive structure and the corresponding kpath,
            # when launching `WannierBandsWorkChain` always use both structure and kpath for inputs.
            #
            # Note don't use `seekpath_structure_analysis`, since it's a calcfunction and will
            # modify aiida database!
            args = {"structure": structure}
            if bands_kpoints_distance:
                args["reference_distance"] = bands_kpoints_distance
            result = get_explicit_kpoints_path(**args)
            primitive_structure = result["primitive_structure"]

            # The number of sites is used to decide whether seekpath reduced the cell
            # (comparing the ASE `Atoms` objects would be an alternative test).
            if len(structure.sites) == len(primitive_structure.sites):
                # To be consistent, if no `kpoint_path` and `bands_kpoints` are provided,
                # `kpoint_path` is deliberately left unset so that seekpath is always
                # run inside the workchain.
                parent_builder = parent_class.get_builder_from_protocol(
                    codes=codes,
                    structure=structure,
                    **kwargs,
                    summary=summary,
                    print_summary=False,
                )
            else:
                notes = summary.get("notes", [])
                notes.append(
                    f"The input structure {structure.get_formula()}<{structure.pk}> is a supercell, "
                    "the auto generated parameters are for the primitive cell "
                    f"{primitive_structure.get_formula()} found by seekpath. "
                    "Although this is inconsistent, after submitting the workchain a seekpath run will "
                    "reduce the structure to primitive cell, so the Wannierisation is correct."
                )
                summary["notes"] = notes
                # I need to use primitive cell to generate all the input parameters,
                # e.g. num_wann, num_bands, etc.
                parent_builder = parent_class.get_builder_from_protocol(
                    codes=codes,
                    structure=primitive_structure,
                    **kwargs,
                    summary=summary,
                    print_summary=False,
                )
                # Don't set `kpoint_path` and `bands_kpoints_distance`, so the workchain will run seekpath.
                # However I still need to use the original cell, so the `seekpath_structure_analysis` will
                # store the provenance from original cell to primitive cell.
                parent_builder.structure = structure
        else:
            parent_builder = parent_class.get_builder_from_protocol(  # pylint: disable=too-many-function-args
                codes, structure, **kwargs, summary=summary, print_summary=False
            )

        # Prepare workchain builder
        # I need to explicitly write `Wannier90BandsWorkChain.get_builder()` instead of
        # `cls.get_builder()`, otherwise for a subclass, e.g. `Wannier90OptimizeWorkChain`,
        # it will return the builder of the subclass.
        builder = Wannier90BandsWorkChain.get_builder()
        if kpoint_path:
            builder.kpoint_path = kpoint_path
        if bands_kpoints:
            builder.bands_kpoints = bands_kpoints
        # NOTE(review): `bands_kpoints_distance` is only used above for the seekpath preview and
        # is not assigned to the builder here - confirm whether callers are expected to supply
        # it through `overrides` / protocol inputs instead.

        protocol_inputs = Wannier90BandsWorkChain.get_protocol_inputs(
            protocol=kwargs.get("protocol", None),
            overrides=kwargs.get("overrides", None),
        )
        inputs = parent_builder._inputs(prune=True)  # pylint: disable=protected-access
        inputs = recursive_merge_container(inputs, protocol_inputs)
        builder = recursive_merge_builder(builder, inputs)

        if print_summary:
            cls.print_summary(summary)

        return builder

    def setup(self):
        """Run the parent setup and initialise the kpoint-related context variables.

        `ctx.current_kpoint_path` and `ctx.current_bands_kpoints` start as ``None`` and are
        filled from the inputs when seekpath is not going to be run.
        """
        from aiida_wannier90_workflows.utils.kpoints import get_path_from_kpoints

        super().setup()

        self.ctx.current_kpoint_path = None
        self.ctx.current_bands_kpoints = None

        if not self.should_run_seekpath():
            if "kpoint_path" in self.inputs:
                self.ctx.current_kpoint_path = get_path_from_kpoints(
                    self.inputs.kpoint_path
                )
            if "bands_kpoints" in self.inputs:
                self.ctx.current_bands_kpoints = self.inputs.bands_kpoints

    def should_run_seekpath(self):
        """Seekpath should only be run if the `kpoint_path` or `bands_kpoints` input is not specified."""
        return not any(_ in self.inputs for _ in ("kpoint_path", "bands_kpoints"))

    def run_seekpath(self):
        """Run the structure through SeeKpath to get the primitive and normalized structure.

        Stores the primitive structure and the kpoint path in the context and attaches the
        `primitive_structure` and `seekpath_parameters` outputs.
        """
        from aiida_quantumespresso.calculations.functions.seekpath_structure_analysis import (
            seekpath_structure_analysis,
        )

        args = {
            "structure": self.inputs.structure,
            "metadata": {"call_link_label": "seekpath_structure_analysis"},
        }
        if "bands_kpoints_distance" in self.inputs:
            args["reference_distance"] = self.inputs["bands_kpoints_distance"]

        result = seekpath_structure_analysis(**args)

        self.ctx.current_structure = result["primitive_structure"]

        # Add `kpoint_path` for Wannier bands
        self.ctx.current_kpoint_path = orm.Dict(
            {
                "path": result["parameters"]["path"],
                "point_coords": result["parameters"]["point_coords"],
            }
        )

        structure_formula = self.inputs.structure.get_formula()
        primitive_structure_formula = result["primitive_structure"].get_formula()
        self.report(
            f"launching seekpath: {structure_formula} -> {primitive_structure_formula}"
        )

        self.out("primitive_structure", result["primitive_structure"])
        self.out("seekpath_parameters", result["parameters"])

    def prepare_wannier90_pp_inputs(self):
        """Override parent method.

        Enables `bands_plot` in the wannier90 parameters and forwards whichever of
        `kpoint_path` / `bands_kpoints` is stored in the context.
        """
        base_inputs = super().prepare_wannier90_pp_inputs()
        inputs = base_inputs["wannier90"]

        parameters = inputs.parameters.get_dict()
        parameters["bands_plot"] = True
        inputs.parameters = orm.Dict(parameters)

        if self.ctx.current_kpoint_path:
            inputs.kpoint_path = self.ctx.current_kpoint_path
        if self.ctx.current_bands_kpoints:
            inputs.bands_kpoints = self.ctx.current_bands_kpoints

        base_inputs["wannier90"] = inputs

        return base_inputs

    def results(self):
        """Attach the relevant output nodes from the band calculation to the workchain outputs for convenience."""
        super().results()

        # `band_structure` is only attached when wannier90 produced interpolated bands.
        if "interpolated_bands" in self.outputs["wannier90"]:
            w90_bands = self.outputs["wannier90"]["interpolated_bands"]
            self.out("band_structure", w90_bands)