Source code for pythtb.wannier

import numpy as np
import logging
from .wfarray import WFArray
from .visualization import plot_centers, plot_decay, plot_density
from .mesh import Mesh
from .utils import mat_exp, copydoc
from itertools import product
from typing import TYPE_CHECKING

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Placeholder for imports needed only by static type checkers; currently empty.
if TYPE_CHECKING:
    pass

# Names exported by ``from <module> import *``.
__all__ = ["Wannier"]


class Wannier:
    r"""Construct Wannier functions through the projection method.

    This class implements projection, disentanglement [2]_, and
    maximal-localization [1]_ workflows using Bloch states represented in a
    finite tight-binding orbital basis.

    The high-level workflow is:

    1. Single-shot SVD projection to trial functions (:meth:`project`).
    2. Optional disentanglement in an outer/frozen window
       (:meth:`disentangle`).
    3. Unitary gauge optimization for maximal localization (:meth:`maxloc`).

    Parameters
    ----------
    bloch_states : WFArray
        Bloch-like states to Wannierize. The mesh must be a torus in k-space
        (no endpoint duplication), and states must already be populated.

    Notes
    -----
    - :class:`Wannier` performs a *re-Wannierization* relative to typical
      Wannier90 workflows:

      - Wannier90 commonly starts from first-principles Bloch states and
        projects onto trial orbitals, typically taking the form of localized
        atomic orbitals.
      - :class:`Wannier` works directly with :class:`WFArray` states and
        trial functions expressed in the same tight-binding orbital/spin
        basis.

    - Wannier functions are obtained from Bloch-like states
      :math:`\tilde{\psi}_{n\mathbf{k}}` via inverse FFT over k-axes:

      .. math::

          w_{n\mathbf{R}} = \frac{1}{\sqrt{N_k}} \sum_{\mathbf{k}}
          e^{i\mathbf{k}\cdot\mathbf{R}} \tilde{\psi}_{n\mathbf{k}}.

    References
    ----------
    .. [1] Marzari, N., & Vanderbilt, D. Phys. Rev. B 56, 12847 (1997).
    .. [2] Souza, I., Marzari, N., & Vanderbilt, D. Phys. Rev. B 65, 035109
       (2001).
    """

    def __init__(self, bloch_states: "WFArray"):
        self._wfa: WFArray = bloch_states

        if not self.mesh.is_k_torus:
            # The two sentences are separate literals; the trailing space
            # keeps them from running together in the rendered message.
            raise ValueError(
                "Mesh is not a torus. The Wannier class requires a toroidal k-mesh. "
                "To construct a toroidal k-space mesh use `Mesh.build_grid`"
            )

        for ax in self.mesh.k_axes:
            if ax.has_endpoint:
                raise ValueError(
                    f"Detected a closed k-axis: {ax}. The endpoints of the Brillouin zone "
                    f"must not be included."
                )

        # Integer cell offsets along each k-axis; the grid of all offset
        # combinations is used for real-space looping over Wannier functions.
        ranges = [np.arange(-nk // 2, nk // 2) for nk in self.nks]
        mesh = np.meshgrid(*ranges, indexing="ij")
        # Stack to (..., len(nks)), then flatten to (product, dims).
        self.supercell = np.stack(mesh, axis=-1).reshape(-1, len(self.nks))

    @property
    def mesh(self) -> "Mesh":
        """Mesh associated with this Wannier workflow.

        Returns
        -------
        Mesh
            Mesh used by :attr:`bloch_states`.
        """
        return self.bloch_states.mesh

    @property
    def lattice(self):
        """Lattice associated with this Wannier workflow.

        Returns
        -------
        Lattice
            Lattice used by :attr:`bloch_states`.
        """
        return self._wfa.lattice

    @property
    def bloch_states(self) -> "WFArray":
        """Input Bloch states to be Wannierized.

        Returns
        -------
        WFArray
            Source state container.
        """
        return self._wfa

    @property
    def tilde_states(self) -> "WFArray":
        r"""Bloch-like states :math:`\tilde{\psi}_{n\mathbf{k}}`.

        These states are Fourier transformed to build Wannier functions and
        are related to reference states by a (semi-)unitary gauge rotation:

        .. math::

            |\tilde{\psi}_{n\mathbf{k}} \rangle =
            \sum_{m=1}^{N} U_{mn}^{(\mathbf{k})} |\psi_{m\mathbf{k}} \rangle

        Returns
        -------
        WFArray
            Current Bloch-like states.

        Raises
        ------
        ValueError
            If tilde states have not been initialized.
        """
        try:
            return self._tilde_states
        except AttributeError:
            raise ValueError(
                "Bloch-like states have not been set. "
                "Use `set_tilde_states` or `project`."
            ) from None

    @property
    def nks(self) -> tuple[int, ...]:
        """Number of k points along each k-axis.

        Returns
        -------
        tuple of int
            Mesh shape in reciprocal directions.
        """
        return self.mesh.shape_k

    @property
    def wannier(self) -> np.ndarray:
        r"""Wannier functions in the supercell implied by the k-grid.

        The Wannier functions are discrete inverse Fourier transforms of
        :math:`\tilde{\psi}`:

        .. math::

            w_{n\mathbf{R}} = \frac{1}{\sqrt{N_k}} \sum_{\mathbf{k}}
            e^{i\mathbf{k} \cdot \mathbf{R}} \tilde{\psi}_{n\mathbf{k}}

        Returns
        -------
        np.ndarray
            Wannier functions with mesh/supercell and orbital/spin axes.

        Raises
        ------
        ValueError
            If tilde states are not initialized.
        """
        if not self.tilde_states.filled:
            raise ValueError("Tilde states are not initialized.")
        return getattr(self, "_wannier", None)

    @property
    def spread(self) -> list[float]:
        r"""Quadratic spread :math:`\Omega_n` for each Wannier function.

        .. math::

            \Omega_n = \langle \mathbf{0} n | r^2 | \mathbf{0} n \rangle
            - \langle \mathbf{0} n | \mathbf{r} | \mathbf{0} n \rangle^2

        Returns
        -------
        list of float
            Per-band quadratic spreads.

        Raises
        ------
        ValueError
            If tilde states are not initialized.
        """
        if not self.tilde_states.filled:
            raise ValueError("Tilde states are not initialized.")
        return getattr(self, "_spread", None)

    @property
    def Omega_OD(self) -> float:
        r"""Off-diagonal gauge-dependent spread :math:`\Omega_{\mathrm{OD}}`.

        Part of the decomposition of the quadratic spread into gauge-invariant
        (:math:`\Omega_I`) and gauge-dependent (:math:`\widetilde{\Omega}`)
        parts,

        .. math::

            \Omega = \widetilde{\Omega} + \Omega_I
            = \Omega_{\rm OD} + \Omega_{\rm D} + \Omega_I

        The off-diagonal part :math:`\Omega_{\rm OD}` is computed via

        .. math::

            \Omega_{\rm OD} = \frac{1}{N_k} \sum_{\mathbf{k}, \mathbf{b}} w_b
            \sum_{m\neq n} |M_{mn}^{(\mathbf{b})}(\mathbf{k})|^2

        Returns
        -------
        float
            Off-diagonal spread contribution.

        Raises
        ------
        ValueError
            If tilde states are not initialized.
        """
        if not self.tilde_states.filled:
            raise ValueError("Tilde states are not initialized.")
        return getattr(self, "_omega_od", None)

    @property
    def Omega_D(self) -> float:
        r"""Diagonal gauge-dependent spread :math:`\Omega_{\mathrm{D}}`.

        Part of the decomposition of the quadratic spread into gauge-invariant
        (:math:`\Omega_I`) and gauge-dependent (:math:`\widetilde{\Omega}`)
        parts,

        .. math::

            \Omega = \widetilde{\Omega} + \Omega_I
            = \Omega_{\rm OD} + \Omega_{\rm D} + \Omega_I

        The diagonal part :math:`\Omega_{\rm D}` is computed via

        .. math::

            \Omega_{\rm D} = \frac{1}{N_k} \sum_{\mathbf{k}, \mathbf{b}} w_b
            \sum_n \left(
            -\operatorname{Im}\!\left[\ln M_{nn}^{(\mathbf{b})}(\mathbf{k})\right]
            - \mathbf{b}\cdot\mathbf{r}_n \right)^2

        Returns
        -------
        float
            Diagonal spread contribution.

        Raises
        ------
        ValueError
            If tilde states are not initialized.
        """
        if not self.tilde_states.filled:
            raise ValueError("Tilde states are not initialized.")
        return getattr(self, "_omega_d", None)

    @property
    def Omega_I(self) -> float:
        r"""Gauge-invariant spread :math:`\Omega_I`.

        Part of the decomposition of the quadratic spread into gauge-invariant
        (:math:`\Omega_I`) and gauge-dependent (:math:`\widetilde{\Omega}`)
        parts,

        .. math::

            \Omega = \widetilde{\Omega} + \Omega_I
            = \Omega_{\rm OD} + \Omega_{\rm D} + \Omega_I

        The gauge-invariant part :math:`\Omega_I` is independent of the choice
        of Wannier gauge. It is computed via

        .. math::

            \Omega_I = \frac{1}{N_k} \sum_{\mathbf{k}, \mathbf{b}} w_b
            \left( N_{\rm bands}
            - \sum_{m,n} |M_{mn}^{(\mathbf{b})}(\mathbf{k})|^2 \right)

        Returns
        -------
        float
            Gauge-invariant spread contribution.

        Raises
        ------
        ValueError
            If tilde states are not initialized.
        """
        if not self.tilde_states.filled:
            raise ValueError("Tilde states are not initialized.")
        return getattr(self, "_omega_i", None)

    @property
    def centers(self) -> np.ndarray:
        r"""Wannier centers in Cartesian coordinates.

        The Wannier center for band :math:`n` is obtained from the phases of
        the diagonal overlaps,

        .. math::

            \mathbf{r}_n = -\frac{1}{N_k} \sum_{\mathbf{k}, \mathbf{b}}
            w_b \mathbf{b}
            \operatorname{Im}\!\left[\ln M_{nn}^{(\mathbf{b})}(\mathbf{k})\right].

        Returns
        -------
        np.ndarray
            Array of shape ``(n_wannier, dim_r)``.

        Raises
        ------
        ValueError
            If tilde states are not initialized.
        """
        if not self.tilde_states.filled:
            raise ValueError("Tilde states are not set.")
        return getattr(self, "_centers", None)

    @property
    def trial_wfs(self) -> np.ndarray:
        """Trial wavefunctions used for projection.

        Returns
        -------
        np.ndarray or None
            Trial wavefunctions in orbital (and optional spin) basis, or
            ``None`` when none have been stored.
        """
        return getattr(self, "_trial_wfs", None)

    @property
    def num_twfs(self) -> int:
        """Number of trial wavefunctions.

        Returns
        -------
        int
            Number of trial wavefunctions.

        Raises
        ------
        ValueError
            If trial wavefunctions are not initialized.
        """
        if self.trial_wfs is None:
            raise ValueError("Trial wavefunctions are not set.")
        return self.trial_wfs.shape[0]

    @property
    def Amn(self) -> np.ndarray:
        r"""Overlap matrix between reference states and trial wavefunctions.

        The overlap matrix is defined as

        .. math::

            A(\mathbf{k})_{mn} = \langle \psi_{m \mathbf{k}} | t_{n} \rangle

        where :math:`|\psi_{m\mathbf{k}}\rangle` are the Bloch energy
        eigenstates and :math:`|t_n\rangle` are the trial wavefunctions.

        Returns
        -------
        np.ndarray or None
            Last computed overlap matrix, if available.
        """
        return getattr(self, "_A", None)
def info(self, precision=8):
    """Print a formatted report of Wannier centers and spreads.

    The report lists one row per Wannier function, a row with the summed
    centers and spreads, and the three spread components together with their
    total.

    Parameters
    ----------
    precision : int, optional
        Number of decimal places in printed values.

    Raises
    ------
    ValueError
        If tilde states are not initialized.
    """
    if not getattr(self.tilde_states, "filled", False):
        raise ValueError("Tilde states are not set.")

    def fmt(value):
        # Fixed-point formatting with the requested number of decimals.
        return f"{value:.{precision}f}"

    def fmt_vec(vec):
        return ", ".join(fmt(x) for x in vec)

    spread_arr = np.asarray(self.spread, float)
    center_arr = np.atleast_2d(np.asarray(self.centers, float))
    # Unpacking fails for anything that is not two-dimensional.
    n_wf, n_dim = center_arr.shape

    report = ["Wannier Function Report"]

    # One row per Wannier function, numbered from 1.
    report.extend(
        f"WF {idx}: center = [{fmt_vec(ctr)}] Omega = {fmt(spr)}"
        for idx, (ctr, spr) in enumerate(zip(center_arr, spread_arr), 1)
    )

    # Totals over all Wannier functions.
    total_center = center_arr.sum(axis=0)
    total_spread = spread_arr.sum()
    report.append(
        f"Sum : center = [{fmt_vec(total_center)}] Omega tot = {fmt(total_spread)}"
    )

    # Spread decomposition.
    omega_i = float(getattr(self, "Omega_I", np.nan))
    omega_d = float(getattr(self, "Omega_D", np.nan))
    omega_od = float(getattr(self, "Omega_OD", np.nan))
    omega_total = omega_i + omega_d + omega_od
    report.extend(
        [
            f"Omega I = {fmt(omega_i)}",
            f"Omega D = {fmt(omega_d)}",
            f"Omega OD = {fmt(omega_od)}",
            f"Omega tot = {fmt(omega_total)}",
        ]
    )

    # Rules are as wide as the longest text row: one under the title, one
    # directly above the four decomposition rows.
    width = max(map(len, report))
    report.insert(1, "=" * width)
    report.insert(len(report) - 4, "-" * width)

    print("\n".join(report))
def get_centers(self, cartesian=False):
    r"""Return Wannier centers in Cartesian or fractional coordinates.

    The center of Wannier function :math:`n` is computed from the phases of
    diagonal overlap matrices as

    .. math::

        \mathbf{r}_n = -\frac{1}{N_k} \sum_{\mathbf{k},\mathbf{b}}
        w_b\,\mathbf{b}\,
        \operatorname{Im}\!\left[\ln M_{nn}^{(\mathbf{b})}(\mathbf{k})\right],

    where :math:`M_{mn}^{(\mathbf{b})}(\mathbf{k})` are nearest-neighbor
    overlap matrices of cell-periodic states, :math:`w_b` are shell weights,
    and :math:`N_k` is the number of k points.

    Parameters
    ----------
    cartesian : bool, optional
        If ``True``, return the stored centers unchanged (Cartesian). If
        ``False``, return fractional coordinates in the lattice basis.

    Returns
    -------
    np.ndarray
        Wannier centers with shape ``(n_wannier, dim_r)``.

    Notes
    -----
    Fractional coordinates are obtained by right-multiplying the stored
    centers with the inverse of the lattice-vector matrix.
    """
    stored = self.centers
    if cartesian:
        return stored

    to_fractional = np.linalg.inv(self.lattice.lat_vecs)
    return stored @ to_fractional
def get_trial_wfs(self, twf_list=None):
    """Build normalized trial-wavefunction array from tuple specifications.

    .. versionadded:: 2.0.2

    Parameters
    ----------
    twf_list : list of list of tuple or None, optional
        Trial-wavefunction specification. For spinless systems each entry is
        ``(orb, amp)``; for spinful systems each entry is
        ``(orb, spin, amp)``. If ``None``, return previously stored trial
        wavefunctions (``None`` when nothing has been stored yet).

    Returns
    -------
    np.ndarray or None
        Normalized trial wavefunctions with shape
        ``(n_trial, n_orb[, n_spin])``.

    Raises
    ------
    TypeError
        If a trial function is not a list, tuple, or array of entries.
    ValueError
        If a trial function has zero norm and therefore cannot be normalized.
    """
    if twf_list is None:
        # Agree with the ``trial_wfs`` property: "nothing stored" is None,
        # not an AttributeError.
        return getattr(self, "_trial_wfs", None)

    spinful = self.bloch_states.spinful
    shape = [len(twf_list), self.lattice.norb]
    if spinful:
        shape.append(self.bloch_states.nspin)
    tfs = np.zeros(shape, dtype=complex)

    for j, tf in enumerate(twf_list):
        # Explicit raise instead of ``assert`` so the check survives ``-O``.
        if not isinstance(tf, (list, tuple, np.ndarray)):
            raise TypeError("Trial function must be a list of tuples")

        if spinful:
            for orb, spin, amp in tf:
                tfs[j, orb, spin] = amp
        else:
            for site, amp in tf:
                tfs[j, site] = amp

        norm = np.linalg.norm(tfs[j])
        if norm == 0:
            # Dividing by zero would silently fill the row with NaN.
            raise ValueError(
                f"Trial function {j} has zero norm and cannot be normalized."
            )
        tfs[j] /= norm

    return tfs
def set_trial_wfs(self, tf_list):
    r"""Store trial wavefunctions and allocate a matching state container.

    The specification is normalized through :meth:`get_trial_wfs`, stored,
    and a new :class:`WFArray` with one state per trial wavefunction replaces
    any existing tilde-state container.

    Parameters
    ----------
    tf_list : list of list of tuple
        List of trial wavefunctions. Each trial wavefunction is a list of the
        form ``[(orb, amp), ...]``, for spinless models, or
        ``[(orb, spin, amp), ...]`` for spinful models, where ``orb`` is the
        orbital index, ``spin`` is the spin index, and ``amp`` is the complex
        amplitude. Trial wavefunctions are normalized internally, so only the
        relative amplitudes matter.

    Examples
    --------
    For a system with 4 orbitals and no spin, the following defines two trial
    wavefunctions:

    >>> twf_list = [[(0, 1.0), (2, 1.0)], [(1, 1.0), (3, -1.0)]]
    >>> wan.set_trial_wfs(twf_list)

    The first is an equal superposition of orbitals 0 and 2, and the second
    is an equal superposition of orbitals 1 and 3 with a relative minus sign.
    """
    normalized = self.get_trial_wfs(tf_list)
    # Must be stored before ``num_twfs`` is read below.
    self._trial_wfs = normalized

    container = WFArray(
        self.lattice,
        self.mesh,
        nstates=self.num_twfs,
        spinful=self.bloch_states.spinful,
    )
    self._tilde_states = container
def set_tilde_states(self, states, is_cell_periodic=True, is_spin_axis_flat=False):
    r"""Set the Bloch-like states for the Wannier functions.

    These states are Fourier transformed to form the Wannier functions. They
    are related to the original energy eigenstates via the (semi-)unitary
    transformation

    .. math::

        |\tilde{\psi}_{n\mathbf{k}} \rangle =
        \sum_{m=1}^{N} U_{mn}^{(\mathbf{k})} |\psi_{m\mathbf{k}} \rangle

    Parameters
    ----------
    states : np.ndarray
        The states to set as Bloch-like states. Must have the shape
        ``(nk1, ..., nstates, n_orbs[, n_spins])``.
    is_cell_periodic : bool, optional
        Whether to treat ``states`` as cell-periodic parts
        :math:`u_{n\mathbf{k}}`.
    is_spin_axis_flat : bool, optional
        Whether the spin dimension is flattened into the orbital dimension.
        If True, ``states`` must have shape
        ``(nk1, ..., nstates, n_orbs*n_spins)``. If False, ``states`` must
        have shape ``(nk1, ..., nstates, n_orbs, n_spins)``. Defaults to
        ``False``.

    Raises
    ------
    ValueError
        If the input states are not a numpy array or have an invalid shape.

    Notes
    -----
    - If ``is_cell_periodic`` is True, the states are treated as
      cell-periodic parts of Bloch functions :math:`u_{n\mathbf{k}}`,
      otherwise as full Bloch functions :math:`\psi_{n\mathbf{k}}`.
    - If ``is_spin_axis_flat`` is False and wavefunctions have spin, states
      are reshaped to flatten the spin dimension into the orbital dimension
      and handed to the container as flat.
    - If no tilde-state container exists yet, one is created with the number
      of states taken from the state axis of ``states``.
    - The Wannier functions, spreads, and centers are computed upon setting
      the Bloch-like states.
    """
    if not isinstance(states, np.ndarray):
        raise ValueError("Bloch-like states must be a numpy array.")

    nk_axes = self.mesh.nk_axes
    if is_spin_axis_flat:
        if states.ndim != nk_axes + 2:
            raise ValueError(
                f"Bloch-like states must have shape (nk1, ..., nstates, n_orbs*n_spins), "
                f"but got {states.shape}."
            )
    elif states.ndim != nk_axes + 2 + (self.bloch_states.nspin - 1):
        raise ValueError(
            f"Bloch-like states must have shape (nk1, ..., nstates, n_orbs[, n_spins]), "
            f"but got {states.shape}."
        )

    spinful = self.bloch_states.spinful
    if spinful and not is_spin_axis_flat:
        states = states.reshape((*states.shape[:-2], -1))
        # The array is flat from here on, so the flag passed to the container
        # must say so. This is the (flat states, flag True) combination that
        # ``project`` already uses.
        # NOTE(review): assumes ``WFArray.set_states`` honors the flag;
        # confirm against its implementation.
        is_spin_axis_flat = True

    container = getattr(self, "_tilde_states", None)
    if container is None:
        # Allow standalone use: build the container the same way
        # ``set_trial_wfs`` does. After the ndim checks above the state axis
        # directly follows the k-axes.
        container = WFArray(
            self.lattice,
            self.mesh,
            nstates=states.shape[nk_axes],
            spinful=spinful,
        )
        self._tilde_states = container

    logger.info("Setting Bloch-like states...")
    container.set_states(
        states,
        is_cell_periodic=is_cell_periodic,
        is_spin_axis_flat=is_spin_axis_flat,
    )

    # Fourier transform Bloch-like states over the k-axes to get the Wannier
    # functions.
    # FFT NOTE: A non-repeating grid is required for consistent inverse FFTs.
    psi_nk = container.psi_nk
    self._wannier = self.WFs = np.fft.ifftn(
        psi_nk, axes=list(range(nk_axes)), norm=None
    )

    # Cache spreads, their decomposition, and the centers.
    (spread_n, omega_i, omega_d, omega_od), centers = self._spread_recip(
        decomp=True
    )[:2]
    self._spread = spread_n
    self._omega_i = omega_i
    self._omega_d = omega_d
    self._omega_od = omega_od
    self._centers = centers
def _compute_Amn(self, psi_nk, twfs, band_idxs):
    r"""Overlap matrix between Bloch states and trial wavefunctions.

    The overlap matrix is defined as

    .. math::

        A_{k, n, j} = \langle \psi_{n,k} \mid t_j \rangle

    where :math:`|\psi_{n\mathbf{k}}\rangle` are reference states and
    :math:`|t_j\rangle` are trial wavefunctions.

    Parameters
    ----------
    psi_nk : np.ndarray or None
        States used for overlaps. If ``None``, use Bloch eigenstates from
        :attr:`bloch_states`. Expected shape:
        ``(*shape_mesh, n_states, n_orb*n_spin)``.
    twfs : np.ndarray
        Trial wavefunctions with shape ``(n_trial, n_orb[, n_spin])``.
    band_idxs : sequence of int
        State indices to include from ``psi_nk``.

    Returns
    -------
    np.ndarray
        Overlap matrix with shape ``(*shape_mesh, n_selected, n_trial)``.
    """
    if psi_nk is None:
        # Fall back to the Bloch energy eigenstates (spin axis flattened).
        psi_nk = self.bloch_states.states(flatten_spin_axis=True, return_psi=True)[1]

    # Restrict to the requested states along the state axis.
    selected = np.take(psi_nk, band_idxs, axis=-2)

    # Collapse any spin axis so each trial function is one flat vector.
    n_trial = twfs.shape[0]
    flat_twfs = twfs.reshape((n_trial, -1))

    # Contract the conjugated states with the trial vectors over the basis
    # index.
    return np.einsum("...ij, kj -> ...ik", selected.conj(), flat_twfs)


def _single_shot_project(self, psi_nk, twfs, state_idx):
    """Perform single-shot SVD projection/alignment onto trial wavefunctions.

    The overlap matrix of the selected states with the trial functions is
    decomposed by SVD at each mesh point, and the product of its two
    orthonormal factors is used to rotate the selected states.

    Parameters
    ----------
    psi_nk : np.ndarray
        States to project with shape ``(*mesh_shape, n_states, n_orb*n_spin)``.
    twfs : np.ndarray
        Trial wavefunctions with shape ``(n_trial, n_orb[, n_spin])``.
    state_idx : sequence of int
        Indices of states to project.

    Returns
    -------
    np.ndarray
        Projected states with shape ``(*mesh_shape, n_trial, n_orb*n_spin)``:
        the rotation has shape ``(n_selected, n_trial)`` and is contracted
        over the selected-state index.
    """
    overlaps = self._compute_Amn(psi_nk, twfs, state_idx)
    left, _, right_h = np.linalg.svd(overlaps, full_matrices=False)
    rotation = left @ right_h

    # Keep only the states that entered the overlap matrix.
    chosen = np.take(psi_nk, state_idx, axis=-2)

    # Optimal alignment: sum over the selected states for each trial index.
    return np.einsum("...mn, ...mj -> ...nj", rotation, chosen)
def project(self, tf_list: list = None, band_idxs: list = None, use_tilde=False):
    r"""Initialize or update Bloch-like states by projection onto trial functions.

    This method performs the single-shot projection step used in
    Wannierization as described in [1]_ (Sec. II) and used in the
    disentanglement initialization of [2]_. For each k-point, it builds the
    overlap matrix between selected states and the trial wavefunctions,
    computes its SVD, and applies the optimal unitary (or semi-unitary)
    alignment to produce projected states :math:`\tilde{\psi}_{n\mathbf{k}}`.

    The projected states are passed to :meth:`set_tilde_states`, which also
    updates derived quantities (Wannier functions, centers, and spreads).

    Parameters
    ----------
    tf_list : list of list of tuple, optional
        Trial-wavefunction specification passed to :meth:`set_trial_wfs`. If
        ``None``, already stored trial wavefunctions are reused. Each trial
        wavefunction is a list of the form ``[(orb, amp), ...]``, for
        spinless models, or ``[(orb, spin, amp), ...]`` for spinful models,
        where ``orb`` is the orbital index, ``spin`` is the spin index, and
        ``amp`` is the complex amplitude. Trial wavefunctions are normalized
        internally, so only the relative amplitudes matter.
    band_idxs : list of int, optional
        Indices of states to project.

        - If ``use_tilde=False`` and ``band_idxs is None``, defaults to the
          first half of Bloch eigenstates (half-filling assumption).
        - If ``use_tilde=True`` and ``band_idxs is None``, defaults to all
          current tilde-state indices.
    use_tilde : bool, optional
        If ``False`` (default), project from Bloch energy eigenstates. If
        ``True``, re-project within the current tilde-state manifold.

    Raises
    ------
    ValueError
        If trial wavefunctions are unavailable, or if ``use_tilde=True`` and
        no tilde states have been set.

    See Also
    --------
    :meth:`set_trial_wfs` : for setting trial wavefunctions.
    :meth:`set_tilde_states` : for setting the Bloch-like states directly.

    Notes
    -----
    - Specifically, for the overlap matrix

      .. math::

          A_{n j}(\mathbf{k}) = \langle \psi_{n\mathbf{k}} \mid t_j \rangle ,

      we compute
      :math:`A(\mathbf{k}) = V(\mathbf{k}) \Sigma(\mathbf{k}) W^\dagger(\mathbf{k})`
      and rotate the selected states by
      :math:`U(\mathbf{k}) \equiv V(\mathbf{k}) W^\dagger(\mathbf{k})`:

      .. math::

          \tilde{\psi}_{n\mathbf{k}} =
          \sum_{m}^{\texttt{band_idxs}} U_{mn}(\mathbf{k}) \, \psi_{m\mathbf{k}} .

    - Projection can produce at most `len(band_idxs)` independent states. For
      best stability, use no more trial functions than selected bands
      (`n_trial <= len(band_idxs)`).

    References
    ----------
    .. [1] Marzari, N., & Vanderbilt, D. Maximally localized generalized
       Wannier functions for composite energy bands. Phys. Rev. B 56, 12847
       (1997).
    .. [2] Souza, I., Marzari, N., & Vanderbilt, D. Maximally localized
       Wannier functions for entangled energy bands. Phys. Rev. B 65, 035109
       (2001).
    """
    if tf_list is None and self.trial_wfs is None:
        raise ValueError("Trial wavefunctions must be set before Wannierization.")

    # Read the source states BEFORE touching the trial functions:
    # ``set_trial_wfs`` replaces the tilde-state container with a fresh one,
    # which would otherwise discard the states a ``use_tilde`` re-projection
    # is meant to start from.
    source = self.tilde_states if use_tilde else self.bloch_states
    psi_nk = source.states(flatten_spin_axis=True, return_psi=True)[1]

    if band_idxs is None:
        if use_tilde:
            # Re-project all current tilde states.
            band_idxs = list(range(source.nstates))
        else:
            # Half-filling assumption: Wannierize the lower half of the bands.
            band_idxs = list(range(source.nstates // 2))

    if tf_list is not None:
        self.set_trial_wfs(tf_list)
    twfs = self.trial_wfs

    psi_tilde = self._single_shot_project(psi_nk, twfs, state_idx=band_idxs)
    self.set_tilde_states(psi_tilde, is_cell_periodic=False, is_spin_axis_flat=True)
def _spread_recip(self, decomp=False): r"""Compute quadratic spreads and their MV97 decomposition on a discrete k-shell. Computes per-band spreads and (optionally) the decomposition :math:`(\Omega_I,\Omega_{\mathrm{D}},\Omega_{\mathrm{OD}})` using discrete overlaps on the first nearest-neighbor k-shell. Parameters ---------- decomp : bool, optional If True, also return the components :math:`\Omega_I`, :math:`\Omega_{\mathrm{D}}`, and :math:`\Omega_{\mathrm{OD}}`. Returns ------- tuple Spread-related quantities. Specifically: If ``decomp=False``: spread_n, r_n, rsq_n If ``decomp=True``: [spread_n, Omega_I, Omega_D, Omega_OD], r_n, rsq_n Notes ----- The implementation currently uses a **single k-shell** (``n_shell=1``) and assumes a **uniform Monkhorst–Pack mesh**. """ M = self.tilde_states.Mmn w_b, k_shell, _ = self.lattice.k_shell_weights(self.mesh.shape_k, n_shell=1) w_b, k_shell = w_b[0], k_shell[0] # Assume only one shell for now n_states = self.tilde_states.nstates nks = self.nks k_axes = tuple(self.mesh.k_axis_indices) Nk = np.prod(nks) diag_M = np.diagonal(M, axis1=-1, axis2=-2) log_diag_M_imag = np.log(diag_M).imag abs_diag_M_sq = abs(diag_M) ** 2 r_n = -(1 / Nk) * w_b * np.sum(log_diag_M_imag, axis=k_axes).T @ k_shell rsq_n = ( (1 / Nk) * w_b * np.sum( (1 - abs_diag_M_sq + log_diag_M_imag**2), axis=k_axes + tuple([-2]) ) ) spread_n = rsq_n - np.array( [np.vdot(r_n[n, :], r_n[n, :]) for n in range(r_n.shape[0])] ) if decomp: Omega_i = w_b * n_states * k_shell.shape[0] - (1 / Nk) * w_b * np.sum( abs(M) ** 2 ) Omega_d = ( (1 / Nk) * w_b * (np.sum((-log_diag_M_imag - k_shell @ r_n.T) ** 2)) ) Omega_od = (1 / Nk) * w_b * (+np.sum(abs(M) ** 2) - np.sum(abs_diag_M_sq)) return [spread_n, Omega_i, Omega_d, Omega_od], r_n, rsq_n else: return spread_n, r_n, rsq_n def _get_omega_til(self, Mmn, wb, k_shell): """Compute :math:`\\widetilde{\\Omega}` from overlap matrices. Parameters ---------- Mmn : np.ndarray Overlap matrices on nearest-neighbor shell. 
wb : float Shell weight. k_shell : np.ndarray Neighbor displacement vectors in reduced coordinates. Returns ------- float Gauge-dependent spread :math:`\\widetilde{\\Omega}`. """ nks = self.nks Nk = np.prod(nks) k_axes = tuple(self.mesh.k_axis_indices) diag_M = np.diagonal(Mmn, axis1=-1, axis2=-2) log_diag_M_imag = np.log(diag_M).imag abs_diag_M_sq = abs(diag_M) ** 2 r_n = -(1 / Nk) * wb * np.sum(log_diag_M_imag, axis=k_axes).T @ k_shell Omega_tilde = ( (1 / Nk) * wb * ( np.sum((-log_diag_M_imag - k_shell @ r_n.T) ** 2) + np.sum(abs(Mmn) ** 2) - np.sum(abs_diag_M_sq) ) ) return Omega_tilde def _get_omega_d(self, Mmn, wb, k_shell): """Compute diagonal gauge-dependent spread :math:`\\Omega_D`. Parameters ---------- Mmn : np.ndarray Overlap matrices on nearest-neighbor shell. wb : float Shell weight. k_shell : np.ndarray Neighbor displacement vectors in reduced coordinates. Returns ------- float Diagonal spread term. """ nks = self.nks Nk = np.prod(nks) k_axes = tuple(self.mesh.k_axis_indices) diag_M = np.diagonal(Mmn, axis1=-1, axis2=-2) log_diag_M_imag = np.log(diag_M).imag r_n = -(1 / Nk) * wb * np.sum(log_diag_M_imag, axis=k_axes).T @ k_shell Omega_d = (1 / Nk) * wb * (np.sum((-log_diag_M_imag - k_shell @ r_n.T) ** 2)) return Omega_d def _get_omega_od(self, Mmn, wb): """Compute off-diagonal gauge-dependent spread :math:`\\Omega_{OD}`. Parameters ---------- Mmn : np.ndarray Overlap matrices on nearest-neighbor shell. wb : float Shell weight. Returns ------- float Off-diagonal spread term. """ Nk = np.prod(self.nks) diag_M = np.diagonal(Mmn, axis1=-1, axis2=-2) abs_diag_M_sq = abs(diag_M) ** 2 Omega_od = (1 / Nk) * wb * (+np.sum(abs(Mmn) ** 2) - np.sum(abs_diag_M_sq)) return Omega_od def _get_omega_i(self, Mmn, wb, k_shell): """Compute gauge-invariant spread :math:`\\Omega_I`. Parameters ---------- Mmn : np.ndarray Overlap matrices on nearest-neighbor shell. wb : float Shell weight. k_shell : np.ndarray Neighbor displacement vectors in reduced coordinates. 
Returns ------- float Gauge-invariant spread term. """ Nk = np.prod(self.tilde_states.mesh.shape_k) n_states = self.tilde_states.nstates Omega_i = wb * n_states * k_shell.shape[0] - (1 / Nk) * wb * np.sum( abs(Mmn) ** 2 ) return Omega_i def _get_omega_i_k(self): r"""Calculate the gauge-independent quadratic spread for the Wannier functions. This function computes the gauge-invariant spread density related to :math:`\Omega_I` of the Wannier functions as a function of `k`. This is related to the integrated Cartesian-traced quantum metric. Returns ------- np.ndarray k-resolved contribution to :math:`\Omega_I`. """ P = self.tilde_states.projectors() _, Q_nbr = self.tilde_states._nbr_projectors(return_Q=True) nks = self.nks Nk = np.prod(nks) w_b, _, idx_shell = self.lattice.k_shell_weights(self.mesh.shape_k, n_shell=1) num_nnbrs = idx_shell[0].shape[0] T_kb = np.zeros((*nks, num_nnbrs), dtype=complex) for nbr_idx in range(num_nnbrs): # nearest neighbors T_kb[..., nbr_idx] = np.trace( P[..., :, :] @ Q_nbr[..., nbr_idx, :, :], axis1=-1, axis2=-2 ) return (1 / Nk) * w_b[0] * np.sum(T_kb, axis=-1) ####### Maximally Localized WF ####### def _optimal_subspace( self, n_wfs=None, inner_window="occupied", outer_window="all", iter_num=1000, verbose=True, tol=1e-10, beta=1, tf_speedup=False, ): r"""Iteratively optimize a subspace that minimizes :math:`\Omega_I`. Parameters ---------- n_wfs : int or None, optional Target subspace size. inner_window : str, tuple, list, dict, or None, optional Frozen window specification. outer_window : str, tuple, list, or dict, optional Candidate/disentanglement window specification. iter_num : int, optional Maximum number of iterations. verbose : bool, optional If ``True``, print iteration updates. tol : float, optional Convergence tolerance on :math:`\Delta\Omega_I`. beta : float, optional Linear-mixing factor for projector updates. tf_speedup : bool, optional If ``True``, use TensorFlow eigensolver. 
Returns ------- np.ndarray Optimized states spanning the selected subspace. Raises ------ ValueError If window specifications are invalid or frozen-state count exceeds requested subspace size at any k point. ImportError If ``tf_speedup=True`` and TensorFlow is unavailable. """ # useful constants nks = self.nks Nk = np.prod(nks) n_orb = self.bloch_states.norb n_states = self.bloch_states.nstates n_occ = int(n_states / 2) # eigenenergies and eigenstates for inner/outer window energies = self.bloch_states.energies u_nk = self.bloch_states.states(flatten_spin_axis=True) # number of states in target manifold if n_wfs is None: n_wfs = self.tilde_states.nstates ########### Setting energy windows ############ #### outer window #### if isinstance(outer_window, str): if outer_window.lower() != "occupied" and outer_window.lower() != "all": raise ValueError( "If outer_window is a string, it must be 'occupied' or 'all'." ) outer_window_type = "bands" if outer_window.lower() == "all": outer_band_idxs = list(range(n_states)) outer_energy_range = [np.min(energies), np.max(energies)] elif outer_window.lower() == "occupied": outer_band_idxs = list(range(n_occ)) outer_band_energies = energies[..., outer_band_idxs] outer_energy_range = [ np.min(outer_band_energies), np.max(outer_band_energies), ] elif isinstance(outer_window, dict): if list(outer_window.keys())[0].lower() not in ["bands", "energy"]: raise ValueError( "If outer_window is a dict, it must have keys 'bands' or 'energy'." 
) outer_window_type = list(outer_window.keys())[0].lower() if outer_window_type == "bands": outer_band_idxs = list(outer_window.values())[0] outer_band_energies = energies[..., outer_band_idxs] outer_energy_range = [ np.min(outer_band_energies), np.max(outer_band_energies), ] elif outer_window_type == "energy": outer_energy_range = np.sort(list(outer_window.values())[0]) elif ( isinstance(outer_window, (list, tuple)) and len(outer_window) == 2 and all(isinstance(x, (int, float, np.floating)) for x in outer_window) ): outer_window_type = "energy" outer_energy_range = [float(outer_window[0]), float(outer_window[1])] ######## inner window ######## if inner_window is None: N_inner = 0 inner_window_type = outer_window_type inner_band_idxs = None # make inner energies such that no states are found inside inner_energies = [np.inf, -np.inf] elif isinstance(inner_window, str): if inner_window.lower() != "occupied": raise ValueError("If inner_window is a string, it must be 'occupied'.") inner_window_type = "bands" inner_band_idxs = list(range(n_occ)) inner_band_energies = energies[..., inner_band_idxs] inner_energies = [np.min(inner_band_energies), np.max(inner_band_energies)] elif isinstance(inner_window, dict): if list(inner_window.keys())[0].lower() not in ["bands", "energy"]: raise ValueError( "If inner_window is a dict, it must have keys 'bands' or 'energy'." 
) inner_window_type = list(inner_window.keys())[0].lower() if inner_window_type == "bands": inner_band_idxs = list(inner_window.values())[0] inner_band_energies = energies[..., inner_band_idxs] inner_energies = [ np.min(inner_band_energies), np.max(inner_band_energies), ] elif inner_window_type == "energy": inner_energies = np.sort(list(inner_window.values())[0]) elif ( isinstance(inner_window, (list, tuple)) and len(inner_window) == 2 and all(isinstance(x, (int, float, np.floating)) for x in inner_window) ): inner_window_type = "energy" inner_energies = [float(inner_window[0]), float(inner_window[1])] if inner_window_type == outer_window_type == "bands": logger.debug( "Both inner and outer windows specified via band indices. " "Using the faster optimal_subspace_bands method instead." ) # defer to the faster function return self._optimal_subspace_bands( n_wfs=n_wfs, frozen_bands=inner_band_idxs, disentang_bands=outer_band_idxs, iter_num=iter_num, verbose=verbose, tol=tol, beta=beta, tf_speedup=tf_speedup, ) # create array of nans for masking nan = np.empty(u_nk.shape) nan.fill(np.nan) # mask out states outside outer window states_sliced = np.where( np.logical_and( energies[..., np.newaxis] >= outer_energy_range[0], energies[..., np.newaxis] <= outer_energy_range[1], ), u_nk, nan, ) mask_outer = np.isnan(states_sliced) masked_outer_states = np.ma.masked_array(states_sliced, mask=mask_outer) # mask out states outside inner window states_sliced = np.where( np.logical_and( energies[..., np.newaxis] >= inner_energies[0], energies[..., np.newaxis] <= inner_energies[1], ), u_nk, nan, ) mask_inner = np.isnan(states_sliced) masked_inner_states = np.ma.masked_array(states_sliced, mask=mask_inner) # minimization manifold if inner_window is not None: # states in outer manifold and outside inner manifold min_mask = ~np.logical_and(~mask_outer, mask_inner) min_states = np.ma.masked_array(u_nk, mask=min_mask) min_states = np.ma.filled(min_states, fill_value=0) N_inner = 
(~masked_inner_states.mask).sum(axis=(-1, -2)) // n_orb if np.any(N_inner > n_wfs): mesh = self.mesh bad_kpts = np.where(N_inner > n_wfs) bad_kpts = [mesh.grid[idx] for idx in zip(*bad_kpts)] raise ValueError( f"Number of states in inner window exceeds n_wfs at k-points {bad_kpts}." ) num_keep = n_wfs - N_inner # matrix of integers keep_mask = ( np.arange(min_states.shape[-2]) >= (num_keep[:, :, np.newaxis, np.newaxis]) ) keep_mask = keep_mask.repeat(min_states.shape[-2], axis=-2) keep_mask = np.swapaxes(keep_mask, axis1=-1, axis2=-2) else: min_states = np.ma.filled(masked_outer_states, fill_value=0) # keep all the states from minimization manifold num_keep = np.full(min_states.shape, n_wfs) keep_mask = np.arange(min_states.shape[-2]) >= num_keep keep_mask = np.swapaxes(keep_mask, axis1=-1, axis2=-2) # Assumes only one shell for now w_b, _, _ = self.lattice.k_shell_weights(self.mesh.shape_k, n_shell=1) w_b = w_b[0] # Assume only one shell for now # Projector of initial tilde subspace at each k-point init_states = self.tilde_states P = init_states.projectors(return_Q=False) P_nbr, Q_nbr = init_states._nbr_projectors(return_Q=True) T_kb = np.einsum("...ij, ...kji -> ...k", P, Q_nbr) omega_I_prev = (1 / Nk) * w_b * np.sum(T_kb) logger.info(f"Initial Omega_I: {omega_I_prev.real}") if verbose: print(f"Initial Omega_I: {omega_I_prev.real}") P_min = np.copy(P) # for start of iteration P_nbr_min = np.copy(P_nbr) # for start of iteration Q_nbr_min = np.copy(Q_nbr) # for start of iteration if tf_speedup: try: import tensorflow as tf except ImportError: raise ImportError( "TensorFlow must be installed to use tf_speedup option." 
) #### Start of minimization iteration #### for i in range(iter_num): P_avg = np.sum(w_b * P_nbr_min, axis=-3) Z = min_states.conj() @ P_avg @ np.transpose(min_states, axes=(0, 1, 3, 2)) if tf_speedup: Z_tf = tf.convert_to_tensor(Z, dtype=tf.complex64) eigvals, eigvecs = tf.linalg.eigh(Z_tf) # [..., val, idx] eigvals = eigvals.numpy() eigvecs = eigvecs.numpy() else: eigvals, eigvecs = np.linalg.eigh(Z) # [..., val, idx] eigvecs = np.swapaxes(eigvecs, axis1=-1, axis2=-2) # [..., idx, val] # eigvals = 0 correspond to states outside the minimization manifold. Mask these out. zero_mask = eigvals.round(10) == 0 non_zero_eigvals = np.ma.masked_array(eigvals, mask=zero_mask) non_zero_eigvecs = np.ma.masked_array( eigvecs, mask=np.repeat( zero_mask[..., np.newaxis], repeats=eigvals.shape[-1], axis=-1 ), ) # sort eigvals and eigvecs by eigenvalues in descending order excluding eigvals=0 sorted_eigvals_idxs = np.argsort(-non_zero_eigvals, axis=-1) # sorted_eigvals = np.take_along_axis(non_zero_eigvals, sorted_eigvals_idxs, axis=-1) sorted_eigvecs = np.take_along_axis( non_zero_eigvecs, sorted_eigvals_idxs[..., np.newaxis], axis=-2 ) sorted_eigvecs = np.ma.filled(sorted_eigvecs, fill_value=0) states_min = np.einsum("...ji, ...ik->...jk", sorted_eigvecs, min_states) keep_states_ma = np.ma.masked_array(states_min, mask=keep_mask) # need to concatenate with frozen states if inner_window is not None: min_states_ma = np.ma.concatenate( (keep_states_ma, masked_inner_states), axis=-2 ) min_states_sliced = min_states_ma[np.where(~min_states_ma.mask)] min_states_sliced = min_states_sliced.reshape((*nks, n_wfs, n_orb)) states_min = np.array(min_states_sliced) else: min_states_sliced = keep_states_ma[np.where(~keep_states_ma.mask)] min_states_sliced = min_states_sliced.reshape((*nks, n_wfs, n_orb)) states_min = np.array(min_states_sliced) # update projectors min_wfa = WFArray( self.lattice, self.mesh, nstates=states_min.shape[-2], spinful=self.bloch_states.spinful, ) min_wfa.set_states( 
states_min, is_cell_periodic=True, is_spin_axis_flat=True ) P_new = min_wfa.projectors() P_nbr_new = min_wfa._nbr_projectors(return_Q=False) if beta != 1: # for next iteration P_min = beta * P_new + (1 - beta) * P_min P_nbr_min = beta * P_nbr_new + (1 - beta) * P_nbr_min else: # for next iteration P_min = P_new P_nbr_min = P_nbr_new Q_nbr_min = np.eye(P_nbr_min.shape[-1]) - P_nbr_min T_kb = np.einsum("...ij, ...kji -> ...k", P_min, Q_nbr_min) omega_I_new = (1 / Nk) * w_b * np.sum(T_kb) delta = omega_I_new - omega_I_prev logger.info( f"iter {i:4d} | Ω_I = {omega_I_new.real:12.9e} | ΔΩ = {delta.real:10.5e}" ) if verbose: print( f"iter {i:4d} | Ω_I = {omega_I_new.real:12.9e} | ΔΩ = {delta.real:10.5e}" ) if abs(delta) <= tol: logger.info( f"Converged within tolerance in {i} iterations. Breaking the loop." ) if verbose: print( f"Converged within tolerance in {i} iterations. Breaking the loop." ) break if omega_I_new > omega_I_prev: beta = max(beta - 0.01, 0) logger.warning(f"Warning: Ω_I is increasing. Reducing beta to {beta}.") if verbose: print(f"Warning: Ω_I is increasing. Reducing beta to {beta}.") omega_I_prev = omega_I_new return states_min def _optimal_subspace_bands( self, n_wfs: int | None = None, frozen_bands: list | None = None, disentang_bands: list | str = "occupied", iter_num: int = 1000, tol: float = 1e-10, beta: float = 1, verbose: bool = True, tf_speedup: bool = False, ): r"""Optimize a subspace (band-index windows) to minimize :math:`\Omega_I`. This function utilizes the 'disentanglement' technique to find the subspaces throughout the BZ that minimizes the gauge-independent spread. Parameters ---------- n_wfs : int or None Number of states in the optimal subspace. If ``None``, the number of trial wavefunctions is used. frozen_bands : list of int or None, optional List of band indices defining the 'frozen window', specifying the states totally included within the optimized subspace. Defaults to `None`, in which case no bands are frozen. 
disentang_bands : list of int or {"occupied"}, optional List of band indices defining 'disentanglement window' where states are borrowed in order to minimize the gauge independent spread. If "occupied", all occupied bands are disentangled. Defaults to "occupied". iter_num : int, optional Maximum number of optimization iterations. Defaults to 100. tol : float, optional Convergence tolerance for the optimization. Defaults to 1e-10. beta : float, optional Mixing parameter for the optimization. If 1, the current step is taken fully. Lower values result in a percentage ``beta`` of the previous step being mixed into the result. Defaults to 1. verbose : bool, optional If True, print detailed information during optimization. tf_speedup : bool, optional If True, uses the ``tensorflow`` package for faster linear algebra operations. Returns ------- np.ndarray The states spanning the optimized subspace that minimizes the gauge-independent spread. Raises ------ ImportError If ``tf_speedup=True`` and TensorFlow is unavailable. 
""" nks = self.nks Nk = np.prod(nks) n_orb = self.lattice.norb n_occ = int(n_orb / 2) # Assumes only one shell for now w_b, _, _ = self.lattice.k_shell_weights(self.mesh.shape_k, n_shell=1) w_b = w_b[0] # Assume only one shell for now # initial subspace u_nk = self.bloch_states.states(flatten_spin_axis=True) # u_wfs_til = init_states.states(flatten_spin_axis=True) if n_wfs is None: # assume number of states in the subspace is number of tilde states n_wfs = self.tilde_states.nstates if isinstance(disentang_bands, str) and disentang_bands == "occupied": disentang_bands = list(range(n_occ)) # Projector of initial tilde subspace at each k-point if frozen_bands is None: N_inner = 0 init_states = self.tilde_states # manifold from which we borrow states to minimize omega_i comp_states = u_nk.take(disentang_bands, axis=-2) else: N_inner = len(frozen_bands) inner_states = u_nk.take(frozen_bands, axis=-2) P_inner = np.swapaxes(inner_states, -1, -2) @ inner_states.conj() Q_inner = np.eye(P_inner.shape[-1]) - P_inner P_tilde = self.tilde_states.projectors() # chosing initial subspace as highest eigenvalues MinMat = Q_inner @ P_tilde @ Q_inner _, eigvecs = np.linalg.eigh(MinMat) eigvecs = np.swapaxes(eigvecs, -1, -2) init_evecs = eigvecs[..., -(n_wfs - N_inner) :, :] init_states = WFArray( self.lattice, self.mesh, nstates=init_evecs.shape[-2], spinful=self.bloch_states.spinful, ) init_states.set_states( init_evecs, is_cell_periodic=False, is_spin_axis_flat=True ) comp_bands = list(np.setdiff1d(disentang_bands, frozen_bands)) comp_states = u_nk.take(comp_bands, axis=-2) P = init_states.projectors(return_Q=False) P_nbr, Q_nbr = init_states._nbr_projectors(return_Q=True) T_kb = np.einsum("...ij, ...kji -> ...k", P, Q_nbr) omega_I_prev = (1 / Nk) * w_b * np.sum(T_kb) logger.info(f"Initial Omega_I: {omega_I_prev.real}") if verbose: print(f"Initial Omega_I: {omega_I_prev.real}") P_min = np.copy(P) # for start of iteration P_nbr_min = np.copy(P_nbr) # for start of iteration if 
tf_speedup: try: import tensorflow as tf except ImportError: raise ImportError( "TensorFlow must be installed to use tf_speedup option." ) for i in range(iter_num): # states spanning optimal subspace minimizing gauge invariant spread P_avg = w_b * np.sum(P_nbr_min, axis=-3) Z = comp_states.conj() @ P_avg @ np.swapaxes(comp_states, -1, -2) if tf_speedup: Z_tf = tf.convert_to_tensor(Z, dtype=tf.complex64) _, eigvecs_tf = tf.linalg.eigh(Z_tf) eigvecs = eigvecs_tf.numpy() else: _, eigvecs = np.linalg.eigh(Z) # [val, idx] evecs_keep = eigvecs[..., -(n_wfs - N_inner) :] comp_min = np.swapaxes(evecs_keep, -1, -2) @ comp_states if frozen_bands is not None: states_min = np.concatenate((inner_states, comp_min), axis=-2) else: states_min = comp_min min_wfa = WFArray( self.lattice, self.mesh, nstates=states_min.shape[-2], spinful=self.bloch_states.spinful, ) min_wfa.set_states( states_min, is_cell_periodic=True, is_spin_axis_flat=True ) P_new = min_wfa.projectors() P_nbr_new = min_wfa._nbr_projectors(return_Q=False) if beta != 1: # for next iteration P_min = beta * P_new + (1 - beta) * P_min P_nbr_min = beta * P_nbr_new + (1 - beta) * P_nbr_min else: # for next iteration P_min = P_new P_nbr_min = P_nbr_new Q_nbr_min = np.eye(P_nbr_min.shape[-1]) - P_nbr_min T_kb = np.einsum("...ij, ...kji -> ...k", P_min, Q_nbr_min) omega_I_new = (1 / Nk) * w_b * np.sum(T_kb) delta = omega_I_new - omega_I_prev logger.info( f"iter {i:4d} | Ω_I = {omega_I_new.real:12.9e} | ΔΩ = {delta.real:10.5e}" ) if verbose: print( f"iter {i:4d} | Ω_I = {omega_I_new.real:12.9e} | ΔΩ = {delta.real:10.5e}" ) if abs(delta) <= tol: logger.info( f"Converged within tolerance in {i} iterations. Breaking the loop." ) if verbose: print( f"Converged within tolerance in {i} iterations. Breaking the loop." ) break if omega_I_new > omega_I_prev: beta = max(beta - 0.01, 0) logger.warning( f"Warning: Ω_I is increasing. Reducing beta from {beta + 0.01} to {beta}." ) if verbose: print(f"Warning: Ω_I is increasing. 
Reducing beta to {beta}.") omega_I_prev = omega_I_new return states_min def _max_loc_unitary( self, alpha=1 / 2, iter_num=100, verbose=False, tol=1e-10, grad_min=1e-3 ): r"""Find unitary rotations that minimize gauge-dependent spread. Parameters ---------- alpha : float, optional Step-size prefactor for gradient updates. iter_num : int Maximum number of iterations. verbose : bool, optional If ``True``, print progress. tol : float, optional Convergence tolerance on spread change. grad_min : float, optional Convergence tolerance on gradient norm. Returns ------- np.ndarray Unitary matrix field ``U(k)`` that rotates tilde states toward minimal gauge-dependent spread. """ M = self.tilde_states.Mmn w_b, k_shell, idx_shell = self.lattice.k_shell_weights( self.mesh.shape_k, n_shell=1 ) # Assumes only one shell for now w_b, k_shell, idx_shell = w_b[0], k_shell[0], idx_shell[0] k_axes = tuple(self.mesh.k_axis_indices) nks = self.nks shape_mesh = self.mesh.shape_axes Nk = np.prod(nks) num_state = self.tilde_states.nstates U = np.zeros( (*shape_mesh, num_state, num_state), dtype=complex ) # unitary transformation U[...] = np.eye(num_state, dtype=complex) # initialize as identity M0 = np.copy(M) # initial overlap matrix M = np.copy(M) # new overlap matrix # initializing omega_tilde_prev = self._get_omega_til(M, w_b, k_shell) grad_mag_prev = 0 for i in range(iter_num): r_n = ( -(1 / Nk) * w_b * np.sum( log_diag_M_imag := np.log(np.diagonal(M, axis1=-1, axis2=-2)).imag, axis=k_axes, ).T
[docs] @ k_shell ) q = log_diag_M_imag + (k_shell @ r_n.T) R = np.multiply( M, np.diagonal(M, axis1=-1, axis2=-2)[..., np.newaxis, :].conj() ) T = np.multiply( np.divide(M, np.diagonal(M, axis1=-1, axis2=-2)[..., np.newaxis, :]), q[..., np.newaxis, :], ) A_R = (R - np.swapaxes(R, axis1=-1, axis2=-2).conj()) / 2 S_T = (T + np.swapaxes(T, axis1=-1, axis2=-2).conj()) / (2j) G = 4 * w_b * np.sum(A_R - S_T, axis=-3) U = np.einsum( "...ij, ...jk -> ...ik", U, mat_exp((alpha / (4 * k_shell.shape[0] * w_b)) * G), ) for idx, idx_vec in enumerate(idx_shell): M[..., idx, :, :] = ( np.swapaxes(U, -1, -2).conj() @ M0[..., idx, :, :] @ np.roll( U, shift=tuple(-idx_vec), axis=tuple(self.mesh.k_axis_indices) ) ) grad_mag = np.linalg.norm(np.sum(G, axis=tuple(self.mesh.k_axis_indices))) omega_tilde_new = self._get_omega_til(M, w_b, k_shell) delta = omega_tilde_new - omega_tilde_prev logger.info( f"iter {i:4d} | Ω_tilde = {omega_tilde_new.real:12.9e} | ΔΩ = {delta.real:12.5e} | ‖∇‖ = {grad_mag:10.5e}" ) if verbose: print( f"iter {i:4d} | Ω_tilde = {omega_tilde_new.real:12.9e} | ΔΩ = {delta.real:12.5e} | ‖∇‖ = {grad_mag:10.5e}" ) # Check for convergence if abs(grad_mag) <= grad_min and abs(delta) <= tol: logger.info( f"Converged within tolerance in {i} iterations. Breaking the loop." ) if verbose: print( f"Converged within tolerance in {i} iterations. Breaking the loop." ) break if grad_mag_prev < grad_mag and i != 0: logger.warning("Warning: Gradient increasing.") if verbose: print("Warning: Gradient increasing.") # Reduce step size to try and stabilize # eps *= 0.9 grad_mag_prev = grad_mag omega_tilde_prev = omega_tilde_new return U def disentangle( self, n_wfs: int | None = None, outer_window: str | tuple | list | dict = "all", frozen_window: str | tuple | list | dict | None = None, max_iter: int = 1000, tol: float = 1e-10, mix: float = 1.0, tf_speedup: bool = False, verbose: bool = True, ): r"""Disentanglement of a subspace that minimizes gauge-independent spread. 
This procedure implements the Souza–Marzari–Vanderbilt (SMV) disentanglement algorithm [1]_. The goal is to select an ``n_wfs``-dimensional optimal subspace from a larger set of Bloch eigenstates in a specified "``outer window``," such that the gauge-independent part of the Wannier spread :math:`\Omega_I` is minimized. The procedure is iterative, updating the subspace at each k-point until self-consistency is achieved. Parameters ---------- n_wfs : int or None, optional Number of states in the optimal subspace. If ``None``, the number of trial wavefunctions is used. outer_window : str, tuple, list, or dict, optional Defines the "disentanglement window," i.e. the set of candidate states from which the optimal subspace is chosen. States outside this window are ignored. Options: - ``"occupied"``: All states below the Fermi level. - ``"all"``: All available states. - ``(Emin, Emax)``: Energy range in eV. - ``{"bands": [i1, i2, ...]}``: Explicit band indices. - ``{"energy": (Emin, Emax)}``: Energy window. Defaults to ``"all"``. frozen_window : str, tuple, list, dict, or None, optional Defines the "frozen window," i.e. states that must be exactly included in the subspace. This ensures that, for example, the occupied manifold is preserved while disentangling higher states. Options follow the same conventions as ``outer_window``. If ``None``, no states are frozen. Defaults to ``None``. max_iter : int, optional Maximum number of optimization iterations. Defaults to 1000. tol : float, optional Convergence tolerance for the optimization. Defaults to 1e-10. mix : float, optional Mixing parameter for iterative updates. ``mix=1`` uses the new step fully, while smaller values blend the new and old projectors. Defaults to 1. tf_speedup : bool, optional If True, use the ``tensorflow`` package for accelerated linear algebra. Defaults to False. verbose : bool, optional If True, print detailed iteration to the logger. Defaults to True. 
Notes ----- - The disentanglement algorithm iteratively refines the projectors onto the optimal subspace by solving the eigenvalue problem .. math:: \left[\sum_{\mathbf{b}} w_b \,\hat{\mathcal{P}}_{\mathbf{k}+\mathbf{b}}^{(i)}\right] | u_{m\mathbf{k}}^{(i)} \rangle = \lambda_{m\mathbf{k}}^{(i)} | u_{m\mathbf{k}}^{(i)} \rangle, where :math:`\hat{\mathcal{P}}_{\mathbf{k}+\mathbf{b}}^{(i)}` is the projector from the previous iteration. The states with the largest :math:`n_\text{wfs}` eigenvalues are selected to form the new subspace. - The role of the **outer window** is to provide flexibility: states above or below the frozen region can be borrowed to reduce :math:`\Omega_I`. The **frozen window** ensures that crucial states (e.g. fully occupied bands) are exactly included regardless of the optimization outcome. - After convergence, the ``.tilde_states`` attribute stores the disentangled wavefunctions spanning the optimized subspace. References ---------- .. [1] Souza, I., Marzari, N., & Vanderbilt, D. Maximally localized Wannier functions for entangled energy bands. Phys. Rev. B 65, 035109 (2001). 
""" # if we haven't done single-shot projection yet (set tilde states) if not hasattr(self.tilde_states, "_u_nk"): # check if we have trial wavefunctions if not hasattr(self, "_trial_wfs"): # we use energy eigenstates tilde states self.set_tilde_states( self.bloch_states.states(flatten_spin_axis=True), is_cell_periodic=True, ) else: # we initialize tilde states with previous trial wavefunctions n_occ = int(self.bloch_states.nstates / 2) # assuming half filled band_idxs = list(range(0, n_occ)) # project onto occ manifold psi_nk = self.bloch_states.states( flatten_spin_axis=True, return_psi=True )[1] self._single_shot_project(psi_nk, self._twfs, state_idx=band_idxs) # Minimizing Omega_I via disentanglement util_min = self._optimal_subspace( n_wfs=n_wfs, inner_window=frozen_window, outer_window=outer_window, iter_num=max_iter, verbose=verbose, beta=mix, tol=tol, tf_speedup=tf_speedup, ) self.set_tilde_states(util_min, is_cell_periodic=True, is_spin_axis_flat=True)
def maxloc(self, alpha=1 / 2, max_iter=1000, tol=1e-5, grad_min=1e-3, verbose=False):
    r"""Rotate the tilde states into the gauge of minimal gauge-dependent spread.

    Implements the Marzari-Vanderbilt maximal-localization step [1]_. For
    the subspace currently stored in ``.tilde_states`` a unitary field
    ``U(k)`` is obtained from :meth:`_max_loc_unitary` and applied to the
    cell-periodic states; the rotated states replace ``.tilde_states``.

    Parameters
    ----------
    alpha : float, optional
        Step-size prefactor for the gradient updates. Typical values lie
        between 0 and 1.
    max_iter : int, optional
        Maximum number of iterations. Default is 1000.
    tol : float, optional
        Convergence tolerance on the change in spread. Default is 1e-5.
    grad_min : float, optional
        Gradient-norm threshold for convergence. Default is 1e-3.
    verbose : bool, optional
        If True, print progress messages. Default is False.

    Notes
    -----
    - The minimized quantity is the gauge-dependent part of the spread,

      .. math::

          \widetilde{\Omega} = \Omega - \Omega_I .

    - The states are rotated as

      .. math::

          | u_{m\mathbf{k}}^{\text{new}} \rangle =
          \sum_{n} U_{nm}(\mathbf{k}) \, | u_{n\mathbf{k}}^{\text{old}} \rangle ,

      with :math:`U(\mathbf{k})` a square matrix of the size of the
      subspace.
    - Each iteration uses the step size
      :math:`\epsilon = \alpha / (4 \sum_{\mathbf{b}} w_b)`.
    - Iteration stops once the gradient norm is below ``grad_min`` and the
      spread change is below ``tol``, or after ``max_iter`` iterations.

    References
    ----------
    .. [1] Marzari, N., & Vanderbilt, D. Maximally localized generalized
       Wannier functions for composite energy bands.
       Phys. Rev. B 56, 12847 (1997).
    """
    optimizer_options = dict(
        alpha=alpha,
        iter_num=max_iter,
        verbose=verbose,
        tol=tol,
        grad_min=grad_min,
    )
    rotation = self._max_loc_unitary(**optimizer_options)

    current_states = self.tilde_states.states(flatten_spin_axis=True)
    # new_i = sum_j U_ji * old_j, applied independently at every k-point
    rotated_states = np.einsum("...ji, ...jm -> ...im", rotation, current_states)

    self.set_tilde_states(
        rotated_states, is_cell_periodic=True, is_spin_axis_flat=True
    )
def min_spread(
    self,
    outer_window="all",
    inner_window=None,
    twfs_2=None,
    n_wfs=None,
    max_iter=1000,
    max_iter_dis=1000,
    alpha=1 / 2,
    tol_max_loc=1e-5,
    tol_dis=1e-10,
    grad_min=1e-3,
    mix=1,
    verbose=False,
):
    r"""Run disentanglement + projection + maximal-localization workflow.

    This method performs three steps:

    1. Calls :meth:`disentangle` to find the optimal subspace that
       minimizes the gauge-independent spread.
    2. Applies a second single-shot projection onto ``twfs_2`` if provided,
       or onto the original trial wavefunctions otherwise, to refine the
       states within the optimal subspace before localization.
    3. Calls :meth:`maxloc` to find the unitary transformation that
       minimizes the gauge-dependent spread.

    Parameters
    ----------
    outer_window : str, tuple, list, or dict, optional
        Defines the "disentanglement window," i.e. the set of candidate
        states from which the optimal subspace is chosen. States outside
        this window are ignored. Options:

        - ``"occupied"``: All states below the Fermi level.
        - ``"all"``: All available states.
        - ``(Emin, Emax)``: Energy range in eV.
        - ``{"bands": [i1, i2, ...]}``: Explicit band indices.
        - ``{"energy": (Emin, Emax)}``: Energy window.

        Defaults to ``"all"``.
    inner_window : str, tuple, list, dict, or None, optional
        Defines the "frozen window," i.e. states that must be exactly
        included in the subspace. Options follow the same conventions as
        ``outer_window``. If ``None``, no states are frozen.
        Defaults to ``None``.
    twfs_2 : list of list of tuple or None, optional
        A second set of trial wavefunctions for the projection step after
        disentanglement. If ``None``, the original trial wavefunctions
        (``self.trial_wfs``) are used. Defaults to ``None``.
    n_wfs : int or None, optional
        Number of states in the optimal subspace. If ``None``, the number
        of trial wavefunctions is used. Defaults to ``None``.
    max_iter : int, optional
        Maximum number of iterations for the maximal localization step.
        Default is 1000.
    max_iter_dis : int, optional
        Maximum number of iterations for the disentanglement step.
        Default is 1000.
    alpha : float, optional
        Step size for gradient descent in the maximal localization step.
        Typical values are between 0 and 1. Default is 1/2.
    tol_max_loc : float, optional
        Convergence tolerance for the change in spread in the maximal
        localization step. Default is 1e-5.
    tol_dis : float, optional
        Convergence tolerance for the disentanglement step.
        Default is 1e-10.
    grad_min : float, optional
        Minimum gradient magnitude for convergence in the maximal
        localization step. Default is 1e-3.
    mix : float, optional
        Mixing parameter for iterative updates in the disentanglement
        step. ``mix=1`` uses the new step fully, while smaller values blend
        the new and old projectors. Defaults to 1.
    verbose : bool, optional
        If True, print detailed iteration information to the logger.
        Default is False.

    Notes
    -----
    - The resulting states are stored through :meth:`set_tilde_states` and
      are available from the ``.tilde_states`` attribute after the
      procedure completes.
    """
    # --- Step 1: subspace selection -------------------------------------
    self.disentangle(
        outer_window=outer_window,
        inner_window=inner_window,
        n_wfs=n_wfs,
        max_iter=max_iter_dis,
        tol=tol_dis,
        mix=mix,
        verbose=verbose,
    )

    # --- Step 2: second projection --------------------------------------
    # A different (e.g. smaller) trial set may be supplied for the selected
    # subspace; otherwise reuse the trial functions of the first projection.
    if twfs_2 is not None:
        trial_wfs = self.get_trial_wfs(twfs_2)
    else:
        trial_wfs = self.trial_wfs

    # ``states(..., return_psi=True)`` returns a pair; element 1 is the one
    # stored back below with ``is_cell_periodic=False``.
    psi_til = self.tilde_states.states(flatten_spin_axis=True, return_psi=True)[1]
    psi_til_til = self._single_shot_project(
        psi_til,
        trial_wfs,
        state_idx=list(range(self.tilde_states.nstates)),
    )
    self.set_tilde_states(
        psi_til_til, is_cell_periodic=False, is_spin_axis_flat=True
    )

    # --- Step 3: optimal gauge via maximal localization ------------------
    # NOTE(review): the body of ``maxloc`` reads a local named ``max_iter``
    # and forwards it to the optimizer as ``iter_num``; its own keyword is
    # therefore taken to be ``max_iter``. The previous ``iter_num=`` keyword
    # here would raise ``TypeError`` -- confirm against the signature.
    self.maxloc(
        alpha=alpha,
        max_iter=max_iter,
        tol=tol_max_loc,
        grad_min=grad_min,
        verbose=verbose,
    )
def interp_bands(
    self, k_nodes, n_interp: int = 20, wan_idxs=None, ret_eigvecs=False
):
    r"""Wannier interpolate the band structure along a k-path.

    The model Hamiltonian is evaluated on the k-mesh, rotated into the
    basis of the current tilde states and diagonalized. The resulting
    eigenvalues and eigenvectors are Fourier transformed to a set of
    lattice vectors ``R`` and then transformed back onto the k-path defined
    by ``k_nodes``.

    Parameters
    ----------
    k_nodes : array-like
        Array of k-points defining the path in reciprocal space.
    n_interp : int, optional
        Number of k-points requested from ``self.lattice.k_path``
        (forwarded as its ``nk`` argument). Defaults to 20.
    wan_idxs : list of int or None, optional
        Indices of Wannier functions to include in the interpolation.
        If None, all Wannier functions are used. Defaults to None.
    ret_eigvecs : bool, optional
        If True, return the eigenvectors along with the eigenvalues.
        Defaults to False.

    Returns
    -------
    np.ndarray or tuple[np.ndarray, np.ndarray]
        Real part of the interpolated eigenvalues, one row per k-point on
        the path, and optionally the interpolated eigenvectors (normalized
        along their last axis) if ``ret_eigvecs=True``.
    """
    u_tilde = self.tilde_states.states(flatten_spin_axis=False)
    if wan_idxs is not None:
        # ``np.take`` picks whole slices along axis -2 from a plain list of
        # indices. ``np.take_along_axis`` (used previously) requires an
        # index *array* with the same number of dimensions as ``u_tilde``
        # and therefore failed for the documented list-of-int input.
        # NOTE(review): axis -2 is assumed to be the state axis -- confirm
        # for spinful arrays returned with ``flatten_spin_axis=False``.
        u_tilde = np.take(u_tilde, np.atleast_1d(wan_idxs), axis=-2)

    # Hamiltonian on the mesh, reshaped back onto the mesh axes.
    k_mesh = self.mesh.get_k_points()
    k_flat = k_mesh.reshape(-1, k_mesh.shape[-1])
    H_k = self.bloch_states.model.hamiltonian(k_flat)
    H_k = H_k.reshape(k_mesh.shape[:-1] + H_k.shape[1:])
    if self.bloch_states.spinful:
        # Collapse the four trailing (orbital, spin, orbital, spin)-like
        # axes into a single square matrix.
        new_shape = H_k.shape[:-4] + (
            self.bloch_states.nstates,
            self.bloch_states.nstates,
        )
        H_k = H_k.reshape(*new_shape)

    # Rotate into the tilde basis, diagonalize, and express the
    # eigenvectors back in the original basis.
    H_rot_k = u_tilde.conj() @ H_k @ np.swapaxes(u_tilde, -1, -2)
    eigvals, eigvecs = np.linalg.eigh(H_rot_k)
    eigvecs = np.einsum("...ij, ...ik->...kj", u_tilde, eigvecs)

    nks = self.nks
    n_k = int(np.prod(nks))

    # Lattice vectors R in [-nk//2, nk//2] along every mesh axis.
    # NOTE(review): for even ``nk`` both -nk/2 and +nk/2 are included even
    # though they give identical phases on the mesh -- confirm that this
    # double counting at the zone boundary is intended.
    supercell = list(
        product(*[range(-(int(nk) // 2), int(nk) // 2 + 1) for nk in nks])
    )
    r_vecs = np.array(supercell)  # (n_R, number of mesh axes)

    # Mesh -> R: f(R) = (1/Nk) sum_k f(k) exp(-2 pi i k.R), done as a single
    # contraction instead of a Python double loop over (R, k). The flattened
    # mesh order matches ``k_flat`` (C order over the mesh axes).
    eigvals_flat = eigvals.reshape(n_k, eigvals.shape[-1])
    eigvecs_flat = eigvecs.reshape(n_k, eigvecs.shape[-2], eigvecs.shape[-1])
    phase_to_R = np.exp(-2j * np.pi * (k_flat @ r_vecs.T))  # (n_k, n_R)
    eval_R = np.einsum("kr,kn->rn", phase_to_R, eigvals_flat) / n_k
    evecs_R = np.einsum("kr,knj->rnj", phase_to_R, eigvecs_flat) / n_k

    # R -> arbitrary k on the path: f(k) = sum_R f(R) exp(+2 pi i k.R).
    k_path, _, _ = self.lattice.k_path(k_nodes, nk=n_interp, report=False)
    k_path_2d = np.asarray(k_path).reshape(k_path.shape[0], -1)
    phase_to_k = np.exp(2j * np.pi * (k_path_2d @ r_vecs.T))  # (n_path, n_R)
    eigvals_k_interp = np.einsum("pr,rn->pn", phase_to_k, eval_R)
    eigvecs_k_interp = np.einsum("pr,rnj->pnj", phase_to_k, evecs_R)

    # The summed eigenvectors are not unit length; renormalize each one.
    eigvecs_k_interp /= np.linalg.norm(eigvecs_k_interp, axis=-1, keepdims=True)

    if ret_eigvecs:
        return eigvals_k_interp.real, eigvecs_k_interp
    return eigvals_k_interp.real
def _get_sc_centers(self):
    r"""Collect Wannier-center positions across translated supercells.

    Every center in ``self.centers`` is translated by each ``(tx, ty)``
    entry of ``self.supercell`` using the first two lattice vectors of
    ``self.lattice.lat_vecs``.

    Returns
    -------
    dict
        A dictionary with keys ``'centers all'`` and ``'centers home'``,
        each containing a sub-dictionary with keys ``'xs'`` and ``'ys'``.
        For ``'centers all'`` the arrays hold one row per center and one
        column per supercell translation; for ``'centers home'`` they hold
        one entry per center for each ``(0, 0)`` translation encountered.
    """
    lat_vecs = self.lattice.lat_vecs
    centers = self.centers
    n_centers = centers.shape[0]

    all_xs = [[] for _ in range(n_centers)]
    all_ys = [[] for _ in range(n_centers)]
    home_xs, home_ys = [], []

    for j in range(n_centers):
        for tx, ty in self.supercell:
            center = centers[j] + tx * lat_vecs[0] + ty * lat_vecs[1]
            all_xs[j].append(center[0])
            all_ys[j].append(center[1])
            if tx == 0 and ty == 0:
                home_xs.append(center[0])
                home_ys.append(center[1])

    return {
        "centers all": {"xs": np.array(all_xs), "ys": np.array(all_ys)},
        "centers home": {"xs": np.array(home_xs), "ys": np.array(home_ys)},
    }


def _get_sc_weights(self, wan_idx, special_sites=None):
    r"""Collect Wannier density weights across translated supercells.

    For every ``(tx, ty)`` entry of ``self.supercell`` and every orbital in
    ``self.lattice.orb_vecs`` this records the orbital position, its
    distance from the center of Wannier function ``wan_idx`` and the
    weight ``sum(|self.WFs[tx, ty, wan_idx, i]|**2)``.

    Parameters
    ----------
    wan_idx : int
        Index of the Wannier function to analyze.
    special_sites : sequence of int or None, optional
        Orbital indices considered as special sites. If provided, the
        result additionally contains a ``'special'`` group restricted to
        these orbitals (empty arrays if none of them match).
        Defaults to None.

    Returns
    -------
    dict
        A dictionary with keys ``'all'``, ``'home'``, and -- only when
        ``special_sites`` is not None -- ``'special'``, each containing a
        sub-dictionary with keys ``'xs'``, ``'ys'``, ``'r'``, and ``'wt'``
        for x-coordinates, y-coordinates, radial distances, and weights.
    """
    w0 = self.WFs
    center = self.centers[wan_idx]
    orbs = self.lattice.orb_vecs
    lat_vecs = self.lattice.lat_vecs

    fields = ("xs", "ys", "r", "wt")
    positions = {
        "all": {name: [] for name in fields},
        "home": {name: [] for name in fields},
    }
    # The 'special' group must exist before anything is appended to it;
    # previously it was never created, so any matching special site raised
    # a KeyError.
    if special_sites is not None:
        positions["special"] = {name: [] for name in fields}

    for tx, ty in self.supercell:
        in_home_cell = tx == 0 and ty == 0
        for i, orb in enumerate(orbs):
            wt = np.sum(np.abs(w0[tx, ty, wan_idx, i]) ** 2)

            # Orbital position: reduced coordinates plus cell translation,
            # expanded on the first two lattice vectors.
            pos = (
                orb[0] * lat_vecs[0]
                + tx * lat_vecs[0]
                + orb[1] * lat_vecs[1]
                + ty * lat_vecs[1]
            )
            rel_pos = pos - center
            rad = np.sqrt(rel_pos[0] ** 2 + rel_pos[1] ** 2)
            record = (pos[0], pos[1], rad, wt)

            groups = ["all"]
            if special_sites is not None and i in special_sites:
                groups.append("special")
            if in_home_cell:
                groups.append("home")

            for group in groups:
                for name, value in zip(fields, record):
                    positions[group][name].append(value)

    # Convert the accumulated lists to numpy arrays.
    return {
        group: {name: np.array(values) for name, values in data.items()}
        for group, data in positions.items()
    }
@copydoc(plot_centers)
def plot_centers(
    self,
    center_scale=200,
    section_home_cell=True,
    color_home_cell=True,
    translate_centers=False,
    show=False,
    legend=False,
    pmx=4,
    pmy=4,
    center_color="r",
    center_marker="*",
    lat_home_color="b",
    lat_color="k",
    fig=None,
    ax=None,
):
    # Thin delegate: inside the method body the name ``plot_centers``
    # resolves to the module-level plotting function, not to this method.
    # Documentation is attached by the ``copydoc`` decorator.
    marker_options = {
        "center_scale": center_scale,
        "center_color": center_color,
        "center_marker": center_marker,
    }
    cell_options = {
        "section_home_cell": section_home_cell,
        "color_home_cell": color_home_cell,
        "translate_centers": translate_centers,
        "pmx": pmx,
        "pmy": pmy,
        "lat_home_color": lat_home_color,
        "lat_color": lat_color,
    }
    figure_options = {"show": show, "legend": legend, "fig": fig, "ax": ax}
    return plot_centers(self, **marker_options, **cell_options, **figure_options)
@copydoc(plot_decay)
def plot_decay(self, wan_idx, fig=None, ax=None, show=False):
    # Thin delegate to the module-level ``plot_decay`` function; the
    # documentation is attached by the ``copydoc`` decorator.
    figure_options = {"fig": fig, "ax": ax, "show": show}
    return plot_decay(self, wan_idx=wan_idx, **figure_options)
@copydoc(plot_density)
def plot_density(
    self,
    wan_idx,
    mark_home_cell=False,
    mark_center=False,
    show_lattice=False,
    dens_size=40,
    lat_size=2,
    show=False,
    fig=None,
    ax=None,
    cbar=True,
):
    # Thin delegate to the module-level ``plot_density`` function; the
    # documentation is attached by the ``copydoc`` decorator.
    overlay_options = {
        "mark_home_cell": mark_home_cell,
        "mark_center": mark_center,
        "show_lattice": show_lattice,
    }
    size_options = {"dens_size": dens_size, "lat_size": lat_size}
    figure_options = {"show": show, "fig": fig, "ax": ax, "cbar": cbar}
    return plot_density(
        self,
        wan_idx=wan_idx,
        **overlay_options,
        **size_options,
        **figure_options,
    )