# Source code for pythtb.wfarray

from .tbmodel import TBModel
from .mesh import Mesh
from .lattice import Lattice
from .utils import deprecated
import logging
import copy
import numpy as np
from numpy.typing import ArrayLike

# Module-level logger named after this module; no handlers are attached here.
logger = logging.getLogger(name=__name__)

# Names exported by ``from pythtb.wfarray import *``.
__all__: list[str] = ["WFArray"]


class WFArray:
    r"""Wavefunction container defined on a sampling mesh.

    A :class:`WFArray` stores states on a discrete mesh of k-points and/or
    adiabatic parameters :math:`\lambda`. Once populated, it can be queried
    for Berry connections, Berry curvature, Chern numbers, and other derived
    quantities, or passed to :class:`Wannier` when constructing Wannier
    functions or obtaining smooth gauges from the projection method.

    The underlying :class:`Mesh` may represent a full Monkhorst-Pack grid, a
    lower-dimensional path, or even a mesh that contains only parameter axes.
    In every case, :class:`WFArray` tracks the mesh layout, the stored states,
    and the necessary phase conventions so downstream utilities can consume
    the data consistently.

    Wavefunctions stored in a :class:`WFArray` can be the Hamiltonian
    eigenstates generated by :meth:`solve_model`, or any other
    :class:`Mesh`-defined states in the orbital basis of the underlying
    :class:`Lattice` that match the :attr:`spinful` setting. Use
    :meth:`set_states` for bulk assignments, or the ``[...]`` indexer to
    update the states at an individual mesh point. Periodic boundary
    conditions are enforced automatically on assignment.

    Parameters
    ----------
    lattice : :class:`Lattice`
        Lattice structure with orbital positions, periodic directions, and
        lattice vectors. Use ``model.lattice`` to keep it consistent with a
        :class:`TBModel`.
    mesh : :class:`Mesh`
        Sampling grid in k/parameter space. Build it (e.g. with
        ``build_grid``) before constructing the :class:`WFArray`. Its axes
        set the topology for enforcing periodic boundary conditions.
    nstates : int, optional
        Number of bands per mesh point to store (defaults to
        ``WFArray.norb * WFArray.nspin``).
    spinful : bool, optional
        Whether the model includes spin degrees of freedom (defaults to
        False). If True, each orbital stores two spin components. This
        affects the shape of the stored wavefunction array.

    See Also
    --------
    :class:`pythtb.TBModel`
    :class:`pythtb.Mesh`
    :class:`pythtb.Wannier`
    :ref:`formalism`
    :ref:`haldane-bp-nb` : For an example of using :class:`WFArray` on a
        regular grid of points in k-space.
    :ref:`cone-nb` : For an example of using :class:`WFArray` on a
        non-regular grid of points in k-space.
    :ref:`three-site-thouless-nb` : For an example of using :class:`Mesh`
        with an adiabatic dimension. This example shows how one of the
        directions of a :class:`WFArray` object need not be a k-vector
        direction, but can instead be a Hamiltonian parameter
        :math:`\lambda`. See also the discussion after equation 4.1 in
        :ref:`formalism`.
    :ref:`cubic-slab-hwf-nb` : For an example of using :class:`WFArray` to
        store hybrid Wannier functions.

    Notes
    -----
    - Wavefunctions are always stored with mesh axes leading, followed by
      band, orbital, and (if present) spin indices. When setting the
      wavefunctions manually, ensure the input array matches this
      convention.
    - :class:`WFArray` cooperates with :class:`Wannier` to construct smooth
      Wannier gauges: pass the populated array to ``Wannier(wfarray)`` and
      use :meth:`Wannier.single_shot_projection`.
    - Some features are only defined for regular grids and/or in the energy
      eigenstate gauge. Check the documentation of individual methods for
      details.

    Examples
    --------
    Populate a Mesh with a uniform Monkhorst-Pack grid of k-points

    >>> mesh = Mesh(['k', 'k'])
    >>> mesh.build_grid(shape=(20, 20), gamma_centered=True)
    >>> wfa = WFArray(lattice, mesh, spinful=True, nstates=4)
    >>> wfa.shape
    (20, 20, ...)

    The WFArray is initially empty

    >>> wfa.filled
    False

    Solve a :class:`TBModel` on the mesh and store the eigenstates

    >>> wfa.solve_model(tb_model)
    >>> wfa.energies.shape
    (20, 20, ...)

    Now we can use downstream functions, such as computing the Berry
    curvature on the grid

    >>> curv = wfa.berry_curvature(non_abelian=False)

    We can also store states from a finite model on a parameter sweep
    (no k-axes)

    >>> mesh = Mesh(['l'])
    >>> mesh.build_grid(shape=(101,), lambda_start=0.0, lambda_stop=2*np.pi)

    If the parameter points in the mesh form an adiabatic cycle, ensure the
    boundary conditions are set correctly in the Mesh

    >>> mesh.loop(axis_idx=0, loop_idx=0)
    >>> wfa = WFArray(lattice, mesh)
    >>> wfa.set_states(eigenvectors_lambda, is_cell_periodic=False)

    The states at the endpoints match due to the adiabatic cycle:

    >>> np.allclose(wfa[0], wfa[-1])
    True

    Even if we set the states manually with the indexer, periodic boundary
    conditions are still enforced automatically using the topology of the
    Mesh, so the endpoints still match:

    >>> wfa[-1] = eigvecs  # shape (nstates, norb[, nspin])
    >>> np.allclose(wfa[0], wfa[-1])
    True
    """
@deprecated("Looping is handled automatically by Mesh.")
def impose_loop(self, mesh_dir: int):
    r"""Legacy API kept only to point users at the replacement.

    .. versionremoved:: 2.0.0
        :meth:`impose_loop` has been removed. Looping is handled
        automatically by :class:`Mesh`.

    Raises
    ------
    NotImplementedError
        Always.
    """
    reason = "Looping is handled automatically by Mesh."
    raise NotImplementedError(reason)
@deprecated("Periodic boundary conditions are handled automatically by Mesh.")
def impose_pbc(self, mesh_dir: int, k_dir: int):
    r"""Legacy API kept only to point users at the replacement.

    .. versionremoved:: 2.0.0
        :meth:`impose_pbc` has been removed. Periodic boundary conditions
        are handled automatically by :class:`Mesh`.

    Raises
    ------
    NotImplementedError
        Always.
    """
    reason = "Periodic boundary conditions are handled automatically by Mesh."
    raise NotImplementedError(reason)
@deprecated("Use `solve_model` instead.")
def solve_on_grid(self, start_k=None):
    r"""Legacy API kept only to point users at the replacement.

    .. versionremoved:: 2.0.0
        :meth:`solve_on_grid` has been removed. Use :meth:`solve_model`
        instead.

    Raises
    ------
    NotImplementedError
        Always.
    """
    reason = "Use `solve_model` instead."
    raise NotImplementedError(reason)
@deprecated("Use `solve_model` instead.")
def solve_on_one_point(self, kpt, mesh_indices):
    r"""Legacy API kept only to point users at the replacement.

    .. versionremoved:: 2.0.0
        :meth:`solve_on_one_point` has been removed. Use
        :meth:`solve_model` instead.

    Raises
    ------
    NotImplementedError
        Always.
    """
    reason = "Use `solve_model` instead."
    raise NotImplementedError(reason)
def __init__( self, lattice: Lattice, mesh: Mesh, nstates: int = None, spinful: bool = False ): if not isinstance(lattice, Lattice): raise TypeError("lattice must be of type pythtb.Lattice") if not isinstance(mesh, Mesh): raise TypeError("mesh must be of type pythtb.Mesh") if not isinstance(spinful, bool): raise TypeError("Argument spinful must be a boolean.") if lattice.dim_k != mesh.dim_k: raise ValueError( f"Lattice dim_k ({lattice.dim_k}) does not match Mesh dim_k ({mesh.dim_k})" ) if not mesh.filled: raise ValueError( "Mesh points are not initialized. Did you call build_grid/build_custom?" ) self._lattice = lattice self._mesh = mesh axis_sizes = np.array(self.shape_mesh, dtype=int) loop_axes = [idx for idx, ax in enumerate(mesh.axes) if ax.is_loop] short_loops = [idx for idx in loop_axes if axis_sizes[idx] < 2] if short_loops: raise ValueError( "Looping mesh axes must have at least two samples " f"(axes {short_loops} are too short)." ) if True in (np.array(self.shape_mesh, dtype=int) < 1).tolist(): raise ValueError( "Dimension of WFArray object in each direction must be at least 1.\n" "Maybe you need to build the mesh first?" 
) if not isinstance(spinful, bool): raise TypeError("Argument spinful must be a boolean.") self._spinful = spinful if nstates is not None: if not isinstance(nstates, (int, np.integer)): raise TypeError("Argument nstates must be an integer.") self._nstates = nstates else: self._nstates = self.norb * self.nspin # Default to total number of bands # wfs indexed by [k1, k2,..., state, orb, spin] self._wfs = np.empty(self.shape, dtype=complex) # energies indexed by [k1, k2,..., state] self._energies = None def __getitem__(self, index): return self._wfs[index] def __setitem__(self, index, value): if not isinstance(value, (list, np.ndarray)): raise TypeError("Value must be a list or numpy array!") value = np.array(value, dtype=complex) if self.nspin == 2: if value.ndim == self.naxes + 2: if value.shape[-1] != self.norb * 2: raise ValueError( "Value shape does not match expected shape for spinful model!" ) value = value.reshape(*value.shape[:-1], self.norb, 2) else: if value.shape != self.shape[len(self.shape_mesh) :]: raise ValueError("Incompatible shape for wavefunction!") self._wfs[index] = value self._sync_boundary_from_index(index) self._invalidate_caches() def _check_state_indices( self, state_idx: int | ArrayLike, return_indices: bool = False ) -> np.ndarray | None: """Validate state indices and return as a numpy array.""" # Normalize to numpy array try: state_idx = np.atleast_1d(state_idx).astype(int) except Exception: raise TypeError("state_idx must be an integer or array-like of integers.") if state_idx.ndim != 1: raise ValueError("State indices should be a one-dimensional array.") if np.any(state_idx < 0) or np.any(state_idx >= self.nstates): raise IndexError( "One or more state indices are outside the range of the WFArray." ) return state_idx if return_indices else None def _normalize_state_indices(self, state_idx: int | ArrayLike | None) -> np.ndarray: """Validate state indices and return as a numpy array. 
Differs from _check_state_indices by allowing None input, which returns all indices. """ if state_idx is None: state_idx = np.arange(self.nstates, dtype=int) else: state_idx = self._check_state_indices(state_idx, return_indices=True) return state_idx def _invalidate_caches(self): for attr in ("_P", "_Q", "_P_nbr", "_Q_nbr", "_Mmn"): if hasattr(self, attr): delattr(self, attr) def _sync_boundary_from_index(self, index): """Update linked boundary points after assigning into the array.""" if self.naxes == 0: return if isinstance(index, np.ndarray): index = index.tolist() if self.naxes == 1 and not isinstance(index, (tuple, list)): coords = (int(index),) else: coords = tuple(int(k) for k in index) mesh_coords = [] for ax_idx, k in enumerate(coords): size = self.shape_mesh[ax_idx] mesh_coords.append(k % size) for ax_idx, coord in enumerate(mesh_coords): axis = self.mesh.axes[ax_idx] if not (axis.has_endpoint and axis.is_loop): continue axis_len = self.shape_mesh[ax_idx] if coord not in (0, axis_len - 1): continue if axis.winds_bz: phase, slc_first, slc_last, comps = self._collect_pbc_phase_info(ax_idx) if phase is None: continue from_first = coord == 0 logger.debug( "Syncing PBC on mesh axis %d (%s) for k-components %s (%s edge).", ax_idx, axis, comps, "first" if from_first else "last", ) self._apply_pbc_phase(phase, slc_first, slc_last, from_first=from_first) else: slc_first, slc_last = self._edge_slices(ax_idx) if coord == 0: logger.debug( "Syncing loop boundary (first → last) on mesh axis %d (%s).", ax_idx, axis, ) self._copy_edge(slc_first, slc_last) else: logger.debug( "Syncing loop boundary (last → first) on mesh axis %d (%s).", ax_idx, axis, ) self._copy_edge(slc_last, slc_first) def _canonical_to_mesh_axes(self) -> tuple[int, ...]: """Permutation from canonical axis order to the mesh's declared axis order.""" axes = self.mesh.axes if not axes: return tuple() shape_k = self.mesh.shape_k perm_user_to_canonical: list[int] = [] k_counter = l_counter = 0 for ax in 
axes: if ax.type == "k": perm_user_to_canonical.append(k_counter) k_counter += 1 else: perm_user_to_canonical.append(len(shape_k) + l_counter) l_counter += 1 return tuple(perm_user_to_canonical) def _mesh_axes_to_canonical(self) -> tuple[int, ...]: """Permutation from mesh's declared axis order to canonical axis order.""" perm = self._canonical_to_mesh_axes() return tuple(np.argsort(perm).tolist()) @property def model(self) -> TBModel: """The :class:`TBModel` associated with the :class:`WFArray`. Returns ------- TBModel The :class:`TBModel` associated with the :class:`WFArray`. Only defined if the wavefunctions were computed by :meth:`solve_model`. Raises ------ ValueError If no :class:`TBModel` is associated with this :class:`WFArray`. """ if not hasattr(self, "_model"): raise ValueError( "No TBModel is associated with this WFArray. " "Did you compute the wavefunctions using solve_model?" ) return self._model @property def lattice(self) -> Lattice: """The :class:`Lattice` associated with the :class:`WFArray`. Returns ------- Lattice The :class:`Lattice` associated with the :class:`WFArray`. """ return self._lattice @property def mesh(self) -> Mesh: """The :class:`Mesh` associated with the :class:`WFArray`. Returns ------- Mesh The :class:`Mesh` associated with the :class:`WFArray`. """ return self._mesh @property def filled(self) -> bool: """Whether the wavefunction array has been initialized. Returns ------- bool True if the :attr:`wfs` array is not empty. """ # if uninitialzed, wfs will be np.empty return self._wfs.size > 0 @property def wfs(self) -> np.ndarray: """The stored wavefunctions. In the case of k-axes, these are the cell-periodic (Bloch) states without the plane-wave phase factors. Returns ------- np.ndarray The stored wavefunctions. Shape is ``(*shape_mesh, nstates, norb[, nspin])``. In the case of spinful models, the last axis corresponds to spin. """ return self._wfs @property def u_nk(self) -> np.ndarray: r"""The cell-periodic wavefunctions. 
These are the :math:`|u_{n\mathbf{k}}\rangle` states without the plane-wave phase factors. Returns ------- np.ndarray The cell-periodic wavefunctions. Shape is ``(*shape_mesh, nstates, norb[, nspin])``. Raises ------ ValueError If the wavefunctions are not initialized or if k-axes are not present in the mesh. """ if not self.filled: raise ValueError("Wavefunctions are not initialized.") if self.dim_k == 0: raise ValueError( "Cell-periodic wavefunctions are not defined for 0D k-space." ) return getattr(self, "_u_nk", None) @property def psi_nk(self) -> np.ndarray: r"""The Bloch wavefunctions. These are the :math:`|\psi_{n\mathbf{k}}\rangle` states including the plane-wave phase factors. Returns ------- np.ndarray The Bloch wavefunctions. Shape is ``(*shape_mesh, nstates, norb[, nspin])``. Raises ------ ValueError If the wavefunctions are not initialized or if k-axes are not present in the mesh """ if not self.filled: raise ValueError("Wavefunctions are not initialized.") if self.dim_k == 0: raise ValueError("Bloch wavefunctions are not defined for 0D k-space.") return getattr(self, "_psi_nk", None) @property def Mmn(self) -> np.ndarray: r"""The overlap matrix of the wavefunctions. The overlap matrix is defined as .. math:: M_{mn}^{(\mathbf{b})}(\mathbf{k}) = \langle u_{m,\mathbf{k}} | u_{n,\mathbf{k}+\mathbf{b}} \rangle where :math:`\mathbf{b}` is a vector connecting nearest neighbor k-points in the mesh. Here, the neighboring k-points are computed in Cartesian space. Returns ------- np.ndarray The overlap matrix of the wavefunctions. Shape is ``(*shape_mesh, nnbrs, nstates, nstates)``. Raises ------ ValueError If the wavefunctions are not initialized or if the mesh is not a regular grid. Notes ----- - The overlap matrix is only defined for regular grids in k-space. - The overlap matrix is computed using the Cartesian metric by default. To compute the overlap matrix using reduced neighbors, use :meth:`overlap_matrix` with ``use_k_metric=False``. 
""" if not self.filled: raise ValueError("Wavefunctions are not initialized.") if not self.mesh.is_grid: raise ValueError("Overlap matrix is only defined for regular grids.") if self.dim_k == 0: raise ValueError("Overlap matrix is not defined for 0D k-space.") if not hasattr(self, "_Mmn"): self._Mmn = self.overlap_matrix(use_k_metric=True) return self._Mmn @property def energies(self) -> np.ndarray: """The energies of the :class:`TBModel`. Returns ------- np.ndarray The energies of the :class:`TBModel`. Shape is ``(shape_mesh..., nstates)``. Only defined after calling :meth:`solve_model`. Raises ------ ValueError If the wavefunctions are not initialized or if the energies have not been computed. """ if not self.filled: raise ValueError("Wavefunctions are not initialized.") if self._energies is None: raise ValueError( "Energies are not initialized. Use `solve_model` to compute them." ) return self._energies @property def hamiltonian(self) -> np.ndarray: r"""The Hamiltonian defined on the :class:`Mesh`. Returns ------- np.ndarray The Hamiltonian defined on the :class:`Mesh`. Shape is ``(*shape_mesh, norb[, nspin], norb[, nspin])``. """ return getattr(self, "_H", None) @property def shape(self) -> tuple: """The shape of the state array. Returns ------- tuple of int The shape of the state array, including mesh axes, states, orbitals, and (if present) spin. """ wfs_dim = np.array(self.shape_mesh, dtype=int) wfs_dim = np.append(wfs_dim, self.nstates) wfs_dim = np.append(wfs_dim, self.norb) if self.nspin == 2: wfs_dim = np.append(wfs_dim, self.nspin) return tuple(wfs_dim) @property def nstates(self) -> int: """The number of states (or bands) in the state array. Returns ------- int """ return self._nstates @property def nspin(self) -> int: """The number of spin components. Returns ------- int 2 if the :class:`WFArray` is spinful, 1 otherwise. 
""" return 2 if self.spinful else 1 @property def spinful(self) -> bool: """Whether the :class:`WFArray` includes spin degrees of freedom. Returns ------- bool """ return self._spinful @property def norb(self) -> int: """The number of orbitals defined in the :class:`Lattice`. Returns ------- int """ return self.lattice.norb @property def shape_mesh(self) -> tuple: """The shape of the axes of :class:`Mesh`. Returns ------- tuple of int The shape of the axes of :class:`Mesh`. Corrsponds to :attr:`Mesh.shape_axes`. """ return self.mesh.shape_axes @property def dim_k(self) -> int: """The dimension of k-space in the :class:`Mesh`. Returns ------- int """ return self.lattice.dim_k @property def dim_lambda(self) -> int: """The dimension of lambda-space in the :class:`Mesh`. Returns ------- int """ return self.mesh.dim_lambda @property def naxes(self) -> int: """The number of axes in the :class:`Mesh`. Returns ------- int """ return self.mesh.naxes @property def k_points(self) -> np.ndarray: """The k-points in the :class:`Mesh`. Returns ------- np.ndarray The k-points in the :class:`Mesh`. """ return self.mesh.get_k_points() @property def param_points(self) -> np.ndarray: """The parameter points in the :class:`Mesh`. Returns ------- np.ndarray The parameter points in the :class:`Mesh`. """ return self.mesh.get_param_points()
def empty_like(self, nstates: int = None) -> "WFArray":
    r"""Create a new, unfilled :class:`WFArray` on the same Lattice and Mesh.

    Parameters
    ----------
    nstates : int, optional
        Number of states for the new :class:`WFArray`. If None (default),
        uses the current number of states :attr:`nstates`.

        .. versionchanged:: 2.0.0
            Renamed from ``nsta_arr`` for consistency with initialization.

    Returns
    -------
    WFArray
        A new :class:`WFArray` sharing this object's :class:`Lattice`,
        :class:`Mesh` and ``spinful`` setting, with no states stored.
    """
    if nstates is None:
        # Honour the documented default. Passing None to the constructor
        # would fall back to norb * nspin, which differs from the current
        # count after remove_states/choose_states or an explicit nstates.
        nstates = self.nstates
    return WFArray(self.lattice, self.mesh, nstates=nstates, spinful=self.spinful)
def copy(self) -> "WFArray":
    r"""Return an independent deep copy of this :class:`WFArray`.

    .. versionadded:: 2.0.0

    Every attribute is recursively copied, so mutating the copy's arrays
    never affects the original.

    Returns
    -------
    WFArray
        The deep copy.
    """
    # Local alias keeps the stdlib module reachable even though this
    # function is itself named ``copy``.
    import copy as copy_module

    duplicate = copy_module.deepcopy(self)
    return duplicate
def set_states(
    self, wfs, is_cell_periodic: bool = True, is_spin_axis_flat: bool = False
):
    """Populate the wavefunction array.

    The input can be either the cell-periodic states or the full Bloch
    states, as selected by ``is_cell_periodic``. The data are copied into an
    internal complex buffer, so later changes to ``wfs`` do not affect this
    object (and vice versa).

    .. versionadded:: 2.0.0

    Parameters
    ----------
    wfs : np.ndarray
        Wavefunctions with shape ``(*shape_mesh, nstates, norb[, nspin])``,
        or ``(*shape_mesh, nstates, norb * nspin)`` when
        ``is_spin_axis_flat`` is True for a spinful array.
    is_cell_periodic : bool, optional
        If True (default), ``wfs`` are cell-periodic states; otherwise they
        are Bloch states and are converted with the phase factors.
    is_spin_axis_flat : bool, optional
        If True, the orbital and spin indices of ``wfs`` are flattened into
        a single axis. Only relevant when :attr:`spinful` is True.

    Raises
    ------
    TypeError
        If ``wfs`` is not a numpy ndarray.
    ValueError
        If ``wfs`` does not have the expected shape.

    Notes
    -----
    - When the mesh has k-axes, both the cell-periodic (:attr:`u_nk`) and
      Bloch (:attr:`psi_nk`) states are stored. When the model is finite,
      only :attr:`wfs` is set and ``is_cell_periodic`` is ignored (a warning
      is logged if it is False).
    - Periodic boundary conditions are enforced and cached projectors /
      overlaps are discarded afterwards.

    .. warning::
        Ensure the wavefunctions are consistent with the :attr:`mesh` and
        :attr:`lattice`.
    """
    if not isinstance(wfs, np.ndarray):
        raise TypeError("wfs must be a numpy ndarray.")

    if self.nspin == 2 and is_spin_axis_flat:
        orbital_shape = (self.norb * self.nspin,)
    elif self.nspin == 2:
        orbital_shape = (self.norb, self.nspin)
    else:
        orbital_shape = (self.norb,)
    expected_shape = tuple(self.shape_mesh) + (self.nstates,) + orbital_shape
    if wfs.shape != expected_shape:
        raise ValueError(
            f"wfs shape {wfs.shape} does not match expected shape: {expected_shape}"
        )

    # Copy into a complex buffer: storing a reshape *view* would alias the
    # caller's array, and a real-valued input would make later complex
    # assignments through ``wfa[...] = ...`` silently drop imaginary parts.
    wfs = np.array(wfs, dtype=complex).reshape(self.shape)

    if self.dim_k > 0:
        # Phase factors for the Bloch <-> cell-periodic transformation.
        if is_cell_periodic:
            u_nk = wfs
            psi_nk = wfs * self._get_phases(inverse=False)
        else:
            psi_nk = wfs
            u_nk = wfs * self._get_phases(inverse=True)
        self._u_nk = self._wfs = u_nk
        self._psi_nk = psi_nk
    else:
        if not is_cell_periodic:
            logger.warning(
                "Setting non-cell-periodic wavefunctions for 0D k-space."
            )
        self._wfs = wfs

    self._enforce_pbc()
    self._invalidate_caches()
[docs] def remove_states(self, state_idx: int | ArrayLike): r"""Remove specified states from the :class:`WFArray`. .. versionadded:: 2.0.0 Parameters ---------- state_idx : int or array-like of int Indices of the states to remove. Notes ----- - This modifies the shape of the :attr:`wfs`, :attr:`energies`, :attr:`u_nk` and :attr:`psi_nk` arrays. - The indices in ``state_idx`` refer to the current ordering of states. After removal, the remaining states are re-indexed accordingly. Examples -------- Remove states 0 and 2 from the :class:`WFArray` >>> wf.remove_states([0, 2]) """ if self.nspin == 2: state_ax = -3 elif self.nspin == 1: state_ax = -2 else: raise ValueError( "WFArray object can only handle spinless or spin-1/2 models." ) state_idx = self._check_state_indices(state_idx, return_indices=True) n_states = len(state_idx) self._wfs = np.delete(self._wfs, state_idx, axis=state_ax) self._nstates -= n_states self._energies = np.delete(self._energies, state_idx, axis=-1) if getattr(self, "_u_nk", None) is not None: self._u_nk = np.delete(self._u_nk, state_idx, axis=state_ax) if getattr(self, "_psi_nk", None) is not None: self._psi_nk = np.delete(self._psi_nk, state_idx, axis=state_ax)
[docs] def choose_states(self, state_idx: int | ArrayLike): r"""Keep only the specified states in the :class:`WFArray`. This method modifies the existing states in place to keep only those specified by ``state_idx``. Parameters ---------- state_idx : int or array-like of int Indices of states to keep. .. versionchanged:: 2.0.0 Renamed from ``subset`` for consistency. Notes ----- - This modifies the shape of the :attr:`wfs`, :attr:`energies`, :attr:`u_nk` and :attr:`psi_nk` arrays. - The indices in ``state_idx`` refer to the current ordering of states. After removal, the remaining states are re-indexed accordingly. Examples -------- Keep only states 3 and 5 from the :class:`WFArray` >>> wf.choose_states([3, 5]) """ state_idx = self._check_state_indices(state_idx, return_indices=True) remove_indices = np.setdiff1d(np.arange(self.nstates), state_idx) self.remove_states(remove_indices)
[docs] def states( self, state_idx: ArrayLike | None = None, flatten_spin_axis: bool = False, return_psi: bool = False, ) -> np.ndarray: r"""Return states stored in the *WFArray* object. The states are returned in the same ordering as stored internally, with mesh axes leading, followed by band, orbital, and (if present) spin indices. By default, all states are returned. The user can specify a subset of states to return using the ``state_idx`` argument. .. versionadded:: 2.0.0 Parameters ---------- state_idx : int or array-like of int, optional Index or indices of the states to return. If not provided or None, all states are returned. flatten_spin_axis : bool, optional If True, the spin and orbital indices are flattened into a single index. Default is False. return_psi : bool, optional If True, the function also returns the full Bloch wavefunctions. This should only be requested when k-axes are present in the mesh and ``dim_k > 0``, otherwise an error is raised. Default is False. Returns ------- u : np.ndarray The states stored in the *WFArray* object. By default, these are the cell-periodic states when ``dim_k > 0``. The shape is ``(nk1, nk2,..., nl1, nl2,..., nstate, norb[,nspin])`` If ``flatten_spin_axis=True``, the last two axes are replaced by a single axis of size ``norb*nspin``. psi : np.ndarray, optional Bloch states with the same shape conventions as ``wfs``. These states are related to the cell-periodic states by plane-wave phase factors. Only returned if ``return_psi=True``. 
See Also -------- :ref:`formalism` """ if return_psi and self.dim_k == 0: raise ValueError("Bloch states are not defined for 0D k-space.") u = np.copy(self.wfs) psi = None if not return_psi else np.copy(self.psi_nk) state_idx = self._normalize_state_indices(state_idx) # select requested states sl = ( (..., state_idx, slice(None), slice(None)) if self.nspin == 2 else (..., state_idx, slice(None)) ) u = u[sl] if psi is not None: psi = psi[sl] if flatten_spin_axis and self.nspin == 2: u = u.reshape((*u.shape[:-2], -1)) if psi is not None: psi = psi.reshape((*psi.shape[:-2], -1)) return (u, psi) if return_psi else u
def _nbr_projectors(self, return_Q: bool = False):
    """Band projectors evaluated at the nearest-neighbor k-points.

    For each first-shell neighbor shift returned by
    ``lattice.nn_k_shell``, the states are rolled (with periodic boundary
    conditions) by that shift and ``P = sum_n |u_n><u_n|`` is formed.
    Results are cached in ``_P_nbr`` / ``_Q_nbr`` until the states change.

    Parameters
    ----------
    return_Q : bool, optional
        If True, also return the complementary projectors ``Q = 1 - P``.

    Returns
    -------
    P_nbr : np.ndarray
        Shape ``(*shape_mesh, nnbrs, norb*nspin, norb*nspin)``.
    Q_nbr : np.ndarray, optional
        Same shape as ``P_nbr``; only returned if ``return_Q`` is True.

    Raises
    ------
    NotImplementedError
        If there are no k-axes or the mesh is not a regular grid.
    """
    if self.dim_k == 0:
        raise NotImplementedError(
            "Nearest neighbor projectors are not defined for 0D k-space."
        )
    if not self.mesh.is_grid:
        raise NotImplementedError(
            "Mesh must be a grid to compute nearest neighbor projectors."
        )

    # Check the cache first: previously the on-site projectors were computed
    # (only to read their shape) even when the neighbor cache was warm.
    P_nbr = getattr(self, "_P_nbr", None)
    if P_nbr is None:
        _, nnbr_idx_shell = self.lattice.nn_k_shell(
            self.mesh.shape_k, n_shell=1, report=False
        )
        shifts = nnbr_idx_shell[0]
        dim = self.norb * self.nspin
        P_nbr = np.zeros(
            (*self.shape_mesh, shifts.shape[0], dim, dim), dtype=complex
        )
        for nbr, shift in enumerate(shifts):
            u_shifted = self.roll_states_with_pbc(shift, flatten_spin_axis=True)
            P_nbr[..., nbr, :, :] = np.matmul(
                u_shifted.swapaxes(-2, -1), u_shifted.conj()
            )
        self._P_nbr = P_nbr

    if not return_Q:
        return P_nbr

    # Q may be missing even when P is cached (earlier call with
    # return_Q=False); build it here instead of returning P alone.
    Q_nbr = getattr(self, "_Q_nbr", None)
    if Q_nbr is None:
        Q_nbr = np.eye(P_nbr.shape[-1]) - P_nbr
        self._Q_nbr = Q_nbr
    return P_nbr, Q_nbr
[docs] def projectors( self, state_idx: int | ArrayLike = None, return_Q: bool = False ) -> np.ndarray | tuple[np.ndarray, np.ndarray]: r"""Returns the band projectors associated with the states in the WFArray. The band projectors are defined as the outer product of the wavefunctions: .. math:: P_{n\mathbf{k}} = \lvert u_{n\mathbf{k}}(\mathbf{r})\rangle \langle u_{n\mathbf{k}}(\mathbf{r}) \rvert, \quad Q_{n\mathbf{k}} = \mathbb{I} - P_{n\mathbf{k}} .. versionadded:: 2.0.0 Parameters ---------- state_idx : int or array-like of int, optional Index or indices of the states for which to compute the projectors. If not provided or None, projectors for all states are computed. return_Q : bool, optional If True, the function also returns the orthogonal projector Q. Returns ------- P : np.ndarray The band projectors. Q : np.ndarray, optional The orthogonal projectors. """ # Check cache if state_idx is None and hasattr(self, "_P"): return (self._P, self._Q) if return_Q else self._P # Compute states u_nk = self.states(flatten_spin_axis=True) if state_idx is not None: u_nk = u_nk[..., state_idx, :] # Compute projectors P = np.matmul(u_nk.swapaxes(-2, -1), u_nk.conj()) # Cache full projectors if state_idx is None: Q = np.eye(P.shape[-1]) - P self._P, self._Q = P, Q else: Q = None return (P, Q) if return_Q else P
def solve_model(self, model: TBModel, use_tensorflow: bool = False):
    r"""Diagonalizes ``model`` on every point of the internal :class:`Mesh`.

    The method calls :meth:`TBModel.solve_ham`, passing the k-points and
    model parameters defined in the :class:`Mesh`, and populates the
    :class:`WFArray` with the eigenstates and eigenenergies of the
    Hamiltonian.

    .. note::
        For meshes that include :math:`\lambda`-axes, the axis names are
        interpreted as :class:`TBModel` parameter names. The names and
        values along each :math:`\lambda`-axis are passed as keyword
        arguments to :meth:`TBModel.solve_ham`. These parameter names must
        match those used in the model definition when using
        :meth:`TBModel.set_onsite` and :meth:`TBModel.set_hop`.

    .. versionadded:: 2.0.0
        Replaces :meth:`solve_on_one_point` and :meth:`solve_on_grid`.

    Parameters
    ----------
    model : :class:`TBModel`
        The tight-binding model to diagonalize on the mesh. It should be
        built on the same :class:`Lattice` as this :class:`WFArray`, and its
        ``spinful`` setting must match (only ``spinful`` is checked here).
    use_tensorflow : bool, optional
        If True, uses TensorFlow for diagonalization. This can be beneficial
        for large systems where GPU acceleration is available. This requires
        TensorFlow to be installed. Default is False.

    Raises
    ------
    ValueError
        If ``model.spinful`` differs from :attr:`spinful`, or if
        :meth:`TBModel.solve_ham` returns a number of bands different from
        :attr:`nstates`.

    Notes
    -----
    - The samples along each :math:`\lambda`-axis are obtained from
      :func:`Mesh.get_axis_range` and passed to :meth:`TBModel.solve_ham` as
      keyword arguments, so it is essential that the mesh axis names exactly
      match the symbolic/callable parameter names in the model.
    - Eigenstates stored by ``solve_model`` are chosen to obey a periodic
      gauge :math:`\psi_{n,{\bf k+G}}=\psi_{n {\bf k}}`, so the cell-periodic
      states satisfy :math:`u_{n,{\bf k_0+G}}=e^{-i{\bf G}\cdot{\bf r}}
      u_{n {\bf k_0}}`. See :ref:`formalism` section 4.4 and equation 4.18
      for more detail.
    - If the mesh includes Brillouin zone boundary points along any k-axis,
      or a looping axis is marked as both Brillouin zone winding and closed,
      ``solve_model`` applies the periodic-gauge phase to the states along
      that axis.
    - The minimum gap between adjacent bands over the whole mesh is stored
      in the ``gaps`` attribute (None when ``nstates == 1``).

    Examples
    --------
    Say we have a parametric model function defined as follows:

    >>> model = TBModel(lattice=lat, spinful=False)
    >>> model.set_onsite(['param_1'], mode='set')

    The model will be 2D in k-space for this example. We want to vary
    ``param_1``, and store the energy eigenvalues and eigenstates on a 3D
    mesh: 2 dimensions in k-space and 1 dimension for the parameter
    ``param_1``. This means we will have a :class:`Mesh` with 2 k-axes and
    1 lambda-axis corresponding to ``param_1``. We can create the mesh as
    follows:

    >>> mesh = Mesh(
    ...     dim_k=2,
    ...     dim_lambda=1,
    ...     axis_types=['k','k','l'],
    ...     axis_names=['k1', 'k2', 'param_1']
    ... )

    Note that we must name the last axis ``param_1`` to follow the name we
    set in the model function. We build the mesh values by using
    ``build_grid``. We will construct a uniform 2D grid in k-space of shape
    ``(20, 20)`` and a uniform 1D grid for ``param_1`` with 5 points going
    from 0 to :math:`2\pi`.

    >>> mesh.build_grid(shape=(20, 20, 5), lambda_start=0, lambda_end=2*np.pi)

    To initialize the *WFArray*, we specify the same lattice as the model,
    and the mesh we just created. If the model were spinful, we would need
    to set ``spinful=True``.

    >>> wfa = WFArray(model.lattice, mesh)

    Now, the parameter values are stored in the mesh, and we can diagonalize
    the model on the mesh using :meth:`solve_model`:

    >>> wfa.solve_model(model=model)

    The eigenvalues and eigenstates are stored in the ``.energies`` and
    ``.wfs`` attributes, respectively. They can be accessed as follows:

    >>> wfa.energies.shape
    (20, 20, 5, 2)
    >>> wfa.wfs.shape
    (20, 20, 5, 2, 2)
    """
    if self.spinful != model.spinful:
        raise ValueError("Spinful setting of WFArray does not match the model.")

    # lambda-parameter dict: axis name -> samples along that axis, forwarded
    # to solve_ham as keyword arguments.
    params = {
        ax.name: self.mesh.get_axis_range(i, j)
        for ax, i, j in zip(
            self.mesh.lambda_axes,
            self.mesh.lambda_axis_indices,
            self.mesh.lambda_component_indices,
        )
    }
    # k-points (flatten k-grid only); None for a mesh without k-axes.
    k_flat = self.k_points.reshape(-1, self.dim_k) if self.dim_k else None
    eigvals, eigvecs = model.solve_ham(
        k_pts=k_flat,
        return_eigvecs=True,
        flatten_spin_axis=True,
        use_tensorflow=use_tensorflow,
        **params,
    )

    # Band count is taken from the trailing axis of eigvals (scalar -> 1).
    n_state_returned = eigvals.shape[-1] if eigvals.ndim else 1
    if n_state_returned != self.nstates:
        raise ValueError(
            "solve_model expected "
            f"{self.nstates} bands per mesh point, but TBModel.solve_ham "
            f"returned {n_state_returned}. Recreate the WFArray with "
            f"nstates={n_state_returned} or adjust the model so the band "
            "count matches."
        )

    # Reshape from TBModel's canonical (k-first, lambda-second) ordering to the user-defined mesh ordering.
    shape_k = self.mesh.shape_k
    shape_lambda = self.mesh.shape_lambda
    canonical_mesh_shape = shape_k + shape_lambda
    axes_total = len(self.mesh.axes)
    # (nstates, norb[, nspin]) -- the reshape below also unflattens the spin
    # axis that solve_ham returned flattened.
    state_shape = self.shape[axes_total:]
    eigvals = eigvals.reshape(*canonical_mesh_shape, self.nstates)
    eigvecs = eigvecs.reshape(*canonical_mesh_shape, *state_shape)
    axis_perm = self._canonical_to_mesh_axes()
    if axis_perm:
        # Permute only the mesh axes; band/orbital/spin axes keep their place.
        axes_for_vals = axis_perm + (axes_total,)
        eigvals = np.transpose(eigvals, axes_for_vals)
        trailing_axes = tuple(range(axes_total, axes_total + len(state_shape)))
        eigvecs = np.transpose(eigvecs, axis_perm + trailing_axes)
    eigvals = eigvals.reshape(*self.shape_mesh, self.nstates)
    eigvecs = eigvecs.reshape(*self.shape)

    self.set_states(eigvecs, is_cell_periodic=True, is_spin_axis_flat=False)
    self._energies = eigvals
    self._model = model

    # gaps between adjacent bands (minimum over all mesh points)
    self.gaps = (
        (eigvals[..., 1:] - eigvals[..., :-1]).min(axis=tuple(range(self.naxes)))
        if self.nstates > 1
        else None
    )

    # Enforce PBCs along winding directions.
    # NOTE(review): set_states above already calls _enforce_pbc; this second
    # call is redundant if _enforce_pbc is idempotent -- confirm before
    # removing.
    self._enforce_pbc()
def _get_phases(self, inverse=False):
    r"""Compute phase factors for converting between cell-periodic and Bloch wavefunctions.

    For every k-point on the mesh and every orbital :math:`j` this returns
    :math:`e^{\pm 2\pi i\,\mathbf{k}\cdot\boldsymbol{\tau}_j}` (reduced units),
    where :math:`\boldsymbol{\tau}_j` is the orbital position restricted to the
    lattice's periodic directions.

    Parameters
    ----------
    inverse : bool, optional
        If True, use the ``-`` sign (Bloch -> cell-periodic). If False, use the
        ``+`` sign (cell-periodic -> Bloch). Defaults to False.

    Returns
    -------
    phases : np.ndarray
        Phase factors laid out in mesh-axis order: full-size k axes, a
        singleton axis for every lambda axis, a singleton state axis, then the
        orbital axis of length ``norb``. Spinful models get one extra trailing
        singleton axis for spin.
    """
    sign = -1 if inverse else 1

    # k-grid flattened to (Nk, dim_k)
    k_flat = self.k_points.reshape(-1, self.dim_k)
    # orbital positions restricted to the periodic directions: (norb, dim_k)
    periodic_dirs = np.asarray(self.lattice.periodic_dirs, int)
    tau = self.lattice.orb_vecs[:, periodic_dirs]
    # exp(+-i 2pi k.tau) with shape (Nk, norb)
    phase2d = np.exp(sign * 1j * 2.0 * np.pi * (k_flat @ tau.T))

    # Canonical (k-first, lambda-second) layout with singleton lambda/state axes
    axes_total = len(self.mesh.axes)
    phases = phase2d.reshape(
        *self.mesh.shape_k,
        *([1] * self.mesh.nl_axes),
        1,
        self.norb,
    )
    # Reorder into the user-defined mesh axis order when it differs
    axis_perm = self._canonical_to_mesh_axes()
    if axis_perm:
        phases = np.transpose(phases, axis_perm + (axes_total, axes_total + 1))

    if self.nspin == 2:
        phases = phases[..., np.newaxis]  # broadcast over the spin axis
    return phases


def _enforce_pbc(self):
    r"""Enforce boundary conditions on every looped mesh axis that stores its endpoint.

    For each mesh axis that is a loop *and* includes its endpoint:

    - if the axis winds the Brillouin zone, the states at the last point are
      overwritten with the states at the first point times the combined PBC
      phase from :meth:`_collect_pbc_phase_info`;
    - otherwise (a closed loop that does not wind the BZ) the last point is
      set equal to the first via :meth:`_impose_loop`.

    Does nothing until the array has been filled.
    """
    if not self.filled:
        return

    for idx, ax in enumerate(self.mesh.axes):
        if not (ax.has_endpoint and ax.is_loop):
            continue
        if ax.winds_bz:
            # Endpoint components (k_i = 1 in reduced units) that wind the BZ
            phase_total, slc_first, slc_last, comps = self._collect_pbc_phase_info(
                idx
            )
            if phase_total is None:
                continue
            logger.debug(
                f"Imposing PBC in mesh direction {idx} ({ax}) for k-components {comps}"
            )
            self._apply_pbc_phase(phase_total, slc_first, slc_last)
        else:
            logger.debug(
                f"Imposing loop in mesh direction {idx} ({ax}) without BZ winding."
            )
            self._impose_loop(idx)


def _collect_pbc_phase_info(self, mesh_axis_idx):
    """Gather the combined PBC phase and edge slices for a mesh axis that winds the BZ.

    Returns
    -------
    tuple
        ``(phase_total, slc_first, slc_last, comps)``. ``phase_total`` is the
        product of the per-component phases from :meth:`_get_pbc_phases`, and
        ``comps`` holds the k-components that both have an endpoint and wind
        the BZ. When there are no such components, ``(None, None, None, ())``
        is returned.
    """
    axis = self.mesh.axes[mesh_axis_idx]
    comps = sorted(set(axis.endpoint_components) & set(axis.winds_bz_components))
    if not comps:
        return None, None, None, tuple()

    per_dirs = np.asarray(self.lattice.periodic_dirs, dtype=int)
    phase_total = None
    slc_first = slc_last = None
    for comp in comps:
        # NOTE:
        # `comp` is Mesh's k-component index (0, ..., dim_k-1).
        # _get_pbc_phases expects the real-space index so it grabs the correct
        # orbital column to dot into the k-vector; this is the `periodic_dirs`
        # entry. The two only differ when some lattice directions are
        # non-periodic: if periodic axes are [0, 2] and `comp` is 1, we need
        # periodic_dirs[1] = 2.
        real_dir = per_dirs[comp]
        phase, slc_first, slc_last = self._get_pbc_phases(mesh_axis_idx, real_dir)
        # Multiply phases when several components wind the BZ along this axis
        phase_total = phase if phase_total is None else phase_total * phase
    return phase_total, slc_first, slc_last, tuple(comps)


@staticmethod
def _edge_slices(ax):
    """Return index tuples selecting the first and last points along mesh axis ``ax``.

    Example: for ``ax = 2`` this returns ``(:, :, 0, ...)`` and
    ``(:, :, -1, ...)``, i.e. all leading axes, the chosen edge along ``ax``,
    and an Ellipsis for every remaining axis.
    """
    lead = (slice(None),) * ax
    return lead + (0, Ellipsis), lead + (-1, Ellipsis)


def _apply_pbc_phase(self, phase, slc_first, slc_last, from_first: bool = True):
    r"""Apply the PBC phase between the first and last edge slices of an axis.

    When ``from_first`` is True, the last slice is overwritten with the first
    slice times ``phase``. Otherwise the first slice is regenerated from the
    last slice times ``conj(phase)``. The same update is applied to the cached
    ``_u_nk`` array when present. The cached ``_psi_nk`` array, when present,
    is copied without a phase: the full Bloch function is the same at
    :math:`\mathbf{k}` and :math:`\mathbf{k}+\mathbf{G}`.
    """
    phase_conj = np.conjugate(phase)
    u_attr = getattr(self, "_u_nk", None)
    psi_attr = getattr(self, "_psi_nk", None)

    if from_first:
        logger.debug(
            f"Setting wavefunctions at {slc_last} equal to those at {slc_first} times phase factor."
        )
        self._wfs[slc_last] = self._wfs[slc_first] * phase
        if u_attr is not None:
            self._u_nk[slc_last] = self._u_nk[slc_first] * phase
        if psi_attr is not None:
            self._psi_nk[slc_last] = self._psi_nk[slc_first]
    else:
        logger.debug(
            f"Setting wavefunctions at {slc_first} equal to those at {slc_last} times phase factor."
        )
        self._wfs[slc_first] = self._wfs[slc_last] * phase_conj
        if u_attr is not None:
            self._u_nk[slc_first] = self._u_nk[slc_last] * phase_conj
        if psi_attr is not None:
            self._psi_nk[slc_first] = self._psi_nk[slc_last]


def _copy_edge(self, slc_src, slc_dst):
    """Copy state data between boundary slices (used for pure loops, no phase)."""
    self._wfs[slc_dst] = self._wfs[slc_src]
    u_attr = getattr(self, "_u_nk", None)
    if u_attr is not None:
        u_attr[slc_dst] = u_attr[slc_src]
    psi_attr = getattr(self, "_psi_nk", None)
    if psi_attr is not None:
        psi_attr[slc_dst] = psi_attr[slc_src]


def _get_pbc_phases(self, mesh_dir, k_dir):
    r"""Return the PBC phase factors and edge slices for one mesh axis.

    The phase factors are :math:`e^{-i\mathbf{G}\cdot\boldsymbol{\tau}_j}`.
    In reduced units this is :math:`e^{-2\pi i \tau_{j}}`, where
    :math:`\tau_{j}` is the ``k_dir`` component of orbital :math:`j`'s position
    and :math:`\mathbf{G}` is the reciprocal lattice vector along ``k_dir``.

    Parameters
    ----------
    mesh_dir : int
        Mesh axis along which periodic boundary conditions are imposed.
    k_dir : int
        Real-space lattice direction (an entry of ``lattice.periodic_dirs``)
        whose reciprocal lattice vector the mesh axis spans. It indexes the
        columns of the orbital position array.

    Returns
    -------
    phase : np.ndarray
        Shape ``(norb,)`` for spinless models or ``(norb, 1)`` for spinful
        models, so it broadcasts over the trailing orbital (and spin) axes.
    slc_first, slc_last : tuple
        Index tuples selecting the first and last points along ``mesh_dir``
        (see :meth:`_edge_slices`).

    Raises
    ------
    ValueError
        If ``k_dir`` is not a periodic direction of the lattice.
    TypeError
        If ``mesh_dir`` is not an integer.
    IndexError
        If ``mesh_dir`` is outside ``[0, naxes)``.
    """
    if k_dir not in self.lattice.periodic_dirs:
        # ValueError is a subclass of Exception, so existing handlers still work
        raise ValueError(
            "Periodic boundary condition can be specified only along periodic directions!"
        )
    if not isinstance(mesh_dir, (int, np.integer)):
        raise TypeError("mesh_dir should be an integer!")
    if mesh_dir < 0 or mesh_dir >= self.naxes:
        raise IndexError("mesh_dir outside the range!")

    # Orbital positions dotted with G parallel to k_dir
    phase = np.exp(-2j * np.pi * self.lattice.orb_vecs[:, k_dir])
    if self.nspin != 1:
        phase = phase[:, np.newaxis]

    slc_first, slc_last = self._edge_slices(mesh_dir)
    return phase, slc_first, slc_last


def _impose_pbc(self, mesh_dir: int, k_dir: int):
    r"""Impose periodic boundary conditions on the WFArray.

    This sets the cell-periodic Bloch function at the last point along
    ``mesh_dir`` equal to the first one multiplied by a phase factor,
    overwriting the previous value. Explicitly,
    :math:`u_{n,{\bf k_0+G}}=e^{-i{\bf G}\cdot{\bf r}} u_{n {\bf k_0}}`
    for the reciprocal lattice vector
    :math:`\mathbf{G} = \mathbf{b}_{\texttt{k_dir}}`, where
    :math:`u_{n{\bf k_0}}` is the state in the first element of the mesh along
    the ``mesh_dir`` axis. The cached Bloch states, if any, are copied
    without a phase.

    Parameters
    ----------
    mesh_dir : int
        Axis of the `WFArray` along which to impose periodic boundary
        conditions.
    k_dir : int
        Periodic real-space direction of the underlying *TBModel* whose
        reciprocal lattice vector the mesh axis spans. Since version 1.7.0 it
        is specified between 0 and *dim_r-1*.

    Raises
    ------
    ValueError
        If the model has no k-space dimensions, or ``k_dir`` is not periodic.

    Notes
    -----
    This assumes that the k-point mesh increases by exactly one reciprocal
    lattice vector along ``mesh_dir``.

    Examples
    --------
    Impose PBC along ``mesh_dir=0``, assuming the ``k_dir=1`` component of the
    k-vector increases by one reciprocal lattice vector along that axis (for
    example, a 1D path along :math:`k_y` of a 2D model).

    >>> wf._impose_pbc(mesh_dir=0, k_dir=1)
    """
    if self.dim_k == 0:
        raise ValueError(
            "Cannot impose periodic boundary conditions in 0D k-space.\n"
            "Use `_impose_loop` instead."
        )
    if k_dir not in self.lattice.periodic_dirs:
        raise ValueError(
            "Periodic boundary condition can be specified only along periodic directions!"
        )

    phase, slc_first, slc_last = self._get_pbc_phases(mesh_dir, k_dir)

    logger.debug(
        f"Setting wavefunctions at {slc_last} equal to those at {slc_first} times phase factor."
    )
    self._wfs[slc_last] = self._wfs[slc_first] * phase
    if self.u_nk is not None:
        self._u_nk[slc_last] = self._u_nk[slc_first] * phase
    # Guard independently: Bloch states may not be cached even when u_nk is
    if self.psi_nk is not None:
        self._psi_nk[slc_last] = self._psi_nk[slc_first]


def _impose_loop(self, mesh_dir):
    r"""Impose a loop condition (equal states, equal phase) along a mesh axis.

    The last state along ``mesh_dir`` is replaced with the first one, for
    every band and every point on the other axes.

    Parameters
    ----------
    mesh_dir: int
        Axis of the `WFArray` along which to close the loop.

    Raises
    ------
    TypeError
        If ``mesh_dir`` is not an integer.
    ValueError
        If ``mesh_dir`` is out of range, or is a k-axis of a full k-torus mesh
        (use :meth:`_impose_pbc` there).

    See Also
    --------
    :func:`pythtb.WFArray._impose_pbc`

    Notes
    -----
    Do not use this when the first and last points are related by a
    reciprocal lattice vector; use :func:`pythtb.WFArray._impose_pbc` instead.
    It is assumed (**not** checked) that the first and last points along
    ``mesh_dir`` correspond to the same Hamiltonian.

    Examples
    --------
    For a 3D WFArray over ``(kx, ky, lambda)`` where lambda traverses a closed
    loop, make the lambda endpoints identical before computing Berry phases:

    >>> wf._impose_loop(mesh_dir = 2)
    """
    if not isinstance(mesh_dir, (int, np.integer)):
        raise TypeError("mesh_dir must be an integer.")
    if mesh_dir < 0 or mesh_dir >= self.naxes:
        raise ValueError(
            f"mesh_dir must be between 0 and {self.naxes - 1}, got {mesh_dir}."
        )
    # Compare against integer axis positions (k_axes holds axis objects)
    if mesh_dir in self.mesh.k_axis_indices and self.mesh.is_k_torus:
        raise ValueError("Cannot impose loop condition on periodic k-space axis.")

    slc_first, slc_last = self._edge_slices(mesh_dir)

    logger.debug(
        f"Setting wavefunctions at {slc_last} equal to those at {slc_first}."
    )
    self._wfs[slc_last] = self._wfs[slc_first]
    if self.dim_k > 0:
        if self.u_nk is not None:
            self._u_nk[slc_last] = self._u_nk[slc_first]
        if self.psi_nk is not None:
            self._psi_nk[slc_last] = self._psi_nk[slc_first]


def _unit_shift(self, axis: int, direction: int = 1) -> list[int]:
    """Return an integer shift vector of length ``naxes`` with ±1 along *axis*."""
    if axis < 0 or axis >= self.naxes:
        raise IndexError(f"axis must be in [0, {self.naxes - 1}]")
    if direction not in (-1, 1):
        raise ValueError("direction must be ±1.")
    v = [0] * self.naxes
    v[axis] = direction
    return v


@staticmethod
def _bounded_shift(A: np.ndarray, axis: int, sh: int) -> np.ndarray:
    """Shift array A by *sh* along *axis* without wrapping; fill vacated slab with zeros.

    ``sh > 0`` moves entries toward higher indices (``B[i+sh] = A[i]``);
    ``sh < 0`` moves them toward lower indices. ``sh == 0`` returns ``A``
    itself (no copy).
    """
    if sh == 0:
        return A
    sl_all = [slice(None)] * A.ndim
    B = np.zeros_like(A)
    sl_src = sl_all.copy()
    sl_dst = sl_all.copy()
    if sh > 0:
        sl_src[axis] = slice(0, -sh)
        sl_dst[axis] = slice(sh, None)
    else:
        shn = -sh
        sl_src[axis] = slice(shn, None)
        sl_dst[axis] = slice(0, -shn)
    B[tuple(sl_dst)] = A[tuple(sl_src)]
    return B


def _boundary_phase_for_shift(self, shift_vec):
    r"""Compute the :math:`e^{-i\mathbf{G}\cdot\boldsymbol{\tau}}` mask for an integer mesh shift.

    For every k-point, this counts how many times the shifted index crosses the
    end of each BZ-winding k-axis, converts that into a reciprocal lattice
    vector, and returns the corresponding phase per orbital. Shifts along
    k-axes that do not wind the BZ, or that store their endpoint ("closed"),
    are ignored.

    Parameters
    ----------
    shift_vec : sequence of int
        Shift in full mesh-axis indexing. Missing trailing entries count as 0.

    Returns
    -------
    np.ndarray
        The complex scalar ``1.0`` when there are no k-axes. Otherwise an array
        with full-size k axes, singleton lambda axes, a singleton state axis,
        the orbital axis (and a singleton spin axis for spinful models),
        permuted into mesh-axis order so it broadcasts against the states.
    """
    mesh = self.mesh
    nks = np.array(mesh.shape_k, dtype=int)
    dim_k = nks.size
    if dim_k == 0:
        return np.array(1.0, dtype=complex)

    k_axes = np.asarray(mesh.k_axis_indices, dtype=int)

    # Restrict the shift to k-axes that wind the BZ without a stored endpoint
    shifts = np.zeros(dim_k, dtype=int)
    sv = np.atleast_1d(shift_vec)
    for lk, mx in enumerate(k_axes):
        sh = int(sv[mx]) if mx < sv.size else 0
        if not mesh.is_axis_bz_winding(mx):
            if sh != 0:
                logger.info(f"Axis {mx} is not BZ-winding; removing shift.")
            sh = 0
        elif mesh.is_axis_closed(mx):
            if sh != 0:
                logger.info(f"Axis {mx} is closed; removing shift.")
            sh = 0
        shifts[lk] = sh

    # Integer index grid over k-axes: (*nks, dim_k)
    idx_grid = np.stack(
        np.meshgrid(*[np.arange(n) for n in nks], indexing="ij"), axis=-1
    )
    # Signed number of cells crossed per k-axis (floor division handles
    # |shift| >= 1), e.g. n=10: (-1)//10 -> -1; 10//10 -> 1; 21//10 -> 2
    wraps_k = (idx_grid + shifts) // nks

    # M[local_k_axis, comp] = 1 if that sampling axis winds along k-component comp
    M = np.zeros((dim_k, dim_k), dtype=int)
    for idx, ax in enumerate(k_axes):
        for c in range(dim_k):
            if mesh.is_axis_bz_winding(ax, c):
                M[idx, c] = 1
    # Reciprocal lattice vector crossed at each k-point: (*nks, dim_k)
    G_comp = np.einsum("...i, ic -> ...c", wraps_k, M, dtype=int)

    # Orbital positions restricted to periodic real-space components (norb, dim_k)
    per = getattr(self.lattice, "periodic_dirs", None)
    if per is None:
        if self.dim_k != self.lattice.dim_r:
            logger.warning(
                "WFArray._boundary_phase_for_shift: lattice.periodic_dirs missing; "
                "falling back to first dim_k components."
            )
        orb = self.lattice.orb_vecs[:, :dim_k]
    else:
        per = np.asarray(per, dtype=int)
        if per.size < dim_k:
            raise ValueError(
                f"lattice.periodic_dirs lists {per.size} directions; expected ≥ dim_k={dim_k}."
            )
        orb = self.lattice.orb_vecs[:, per[:dim_k]]

    # G . tau for every k-point and orbital: (*nks, norb)
    dot = np.tensordot(G_comp, orb, axes=([G_comp.ndim - 1], [1]))
    phase = np.exp(-2j * np.pi * dot).astype(complex)

    # Broadcast to (nk..., nl..., nstate, norb[, nspin])
    phase = phase.reshape((*nks, *([1] * self.mesh.nl_axes), 1, self.norb))
    if self.nspin == 2:
        phase = phase[..., np.newaxis]

    axis_perm = self._canonical_to_mesh_axes()
    if axis_perm:
        perm = axis_perm + tuple(range(len(self.mesh.axes), phase.ndim))
        phase = np.transpose(phase, perm)
    return phase


def _invalidate_boundary_links(self, array: np.ndarray, shift_vec) -> np.ndarray:
    """Stamp NaNs (in place) on slabs where a neighbor does not exist for the shift.

    For each shifted axis that does not wrap freely (either it is neither
    looped nor BZ-winding, or it stores its endpoint), the last slab
    (positive shift) or first slab (negative shift) along that axis is set
    to NaN. The same ``array`` object is returned.
    """
    mesh = self.mesh
    ndims = self.naxes
    if not isinstance(shift_vec, (tuple, list, np.ndarray)):
        shift_vec = (shift_vec,)
    for axis, shift in enumerate(shift_vec):
        if axis >= ndims or shift == 0:
            continue
        wraps = mesh.is_axis_looped(axis) or mesh.is_axis_bz_winding(axis)
        closed = mesh.is_axis_closed(axis)
        if wraps and not closed:
            continue
        boundary_index = -1 if shift > 0 else 0
        slicer = [slice(None)] * array.ndim
        slicer[axis] = boundary_index
        array[tuple(slicer)] = np.nan + 0j
    return array
def roll_states_with_pbc(
    self,
    shift_vec: list[int],
    flatten_spin_axis: bool = True,
    strip_boundary: bool = False,
):
    """Roll states by one mesh step per axis, applying boundary conditions.

    The result satisfies ``rolled[idx] = wfs[idx + shift_vec]`` at interior
    points. Along each shifted axis:

    - if the axis wraps (looped or BZ-winding) and does not store its
      endpoint, the states are rolled periodically;
    - otherwise the states are shifted without wrapping and the vacated slab
      is filled with zeros.

    The result is multiplied by the boundary phase factors from
    ``_boundary_phase_for_shift``, which are non-trivial only where a shift
    crosses the Brillouin zone boundary.

    Parameters
    ----------
    shift_vec : list[int]
        Integer shift (-1, 0 or +1) for each mesh axis. It must cover at least
        all k-axes and at most all mesh axes.
    flatten_spin_axis : bool, optional
        Whether to merge the spin axis into the orbital axis, by default True.
    strip_boundary : bool, optional
        Whether to drop the zero-filled boundary slab along shifted axes that
        do not wrap, by default False. The dropped slab is the last one for a
        ``+1`` shift and the first one for a ``-1`` shift.

    Returns
    -------
    np.ndarray
        The rolled states with boundary conditions applied.

    Raises
    ------
    ValueError
        If any shift is not in {-1, 0, +1}, or ``shift_vec`` has the wrong
        length.

    Examples
    --------
    >>> rolled = wfa.roll_states_with_pbc([1, 0], flatten_spin_axis=False)
    >>> np.allclose(rolled[3, 3], wfa.wfs[4, 3])
    True
    """
    states = self.wfs
    mesh = self.mesh

    if np.any(abs(np.array(shift_vec, dtype=int)) > 1):
        raise ValueError("Only unit shifts (+1, 0, -1) are supported in shift_vec.")
    if len(shift_vec) < mesh.nk_axes:
        raise ValueError(
            "shift_vec must have at least as many elements as k-axes in the mesh."
        )
    elif len(shift_vec) > mesh.naxes:
        raise ValueError(
            "shift_vec must have at most as many elements as total axes in the mesh."
        )

    rolled = states
    for ax, sh in enumerate(shift_vec):
        if not sh:
            continue
        wraps = mesh.is_axis_looped(ax) or mesh.is_axis_bz_winding(ax)
        closed = mesh.is_axis_closed(ax)
        if wraps and not closed:
            rolled = np.roll(rolled, shift=-int(sh), axis=ax)
        else:
            logger.info(
                f"Applying bounded shift {sh} to axis {ax} without wrapping."
            )
            rolled = self._bounded_shift(rolled, axis=ax, sh=-int(sh))

    # The boundary phase spans the full (unstripped) mesh, so it must be applied
    # before any boundary slab is removed; otherwise broadcasting fails.
    phase = self._boundary_phase_for_shift(tuple(shift_vec))
    rolled = rolled * phase

    if strip_boundary:
        sl = [slice(None)] * rolled.ndim
        for ax, sh in enumerate(shift_vec):
            loops = mesh.is_axis_looped(ax) or mesh.is_axis_bz_winding(ax)
            closed = mesh.is_axis_closed(ax)
            if sh and (closed or not loops):
                # Drop the zero-filled slab: last for +1, first for -1
                # (same convention as _invalidate_boundary_links).
                sl[ax] = slice(None, -1) if sh > 0 else slice(1, None)
        rolled = rolled[tuple(sl)]

    if flatten_spin_axis and self.nspin == 2:
        rolled = rolled.reshape(*rolled.shape[:-2], self.norb * self.nspin)

    return rolled
def overlap_matrix(self, use_k_metric: bool = False) -> np.ndarray:
    r"""Compute the overlap matrix of the cell-periodic states on the nearest-neighbor shell.

    The overlap matrix is

    .. math::

        M_{m,n}^{\mathbf{b}}(\mathbf{k}, \lambda) =
        \langle u_{m, \mathbf{k}, \lambda} | u_{n, \mathbf{k+b}, \lambda} \rangle

    where :math:`\mathbf{b}` is a displacement vector connecting nearest
    neighbor mesh points.

    .. note::
        This method is intended for WFArray objects defined on a uniform grid
        of k-points (i.e., a Mesh with regular spacing along its k-axes).

    .. versionadded:: 2.0.0

    Parameters
    ----------
    use_k_metric : bool, optional
        Whether to use the Cartesian k-metric for neighbor lookup. If True,
        the nearest-neighbor k-shell is computed in Cartesian space, so
        :math:`\mathbf{b}` is not necessarily a unit step in reduced
        coordinates. If False, neighbors are one step along each k-axis, so
        :math:`\mathbf{b}` is a unit vector in reduced coordinates.
        Default is False.

    Returns
    -------
    M : np.ndarray
        Overlap matrix with shape ``(*mesh_shape, num_nnbrs, n_states, n_states)``.

    Notes
    -----
    - The overlap matrix is computed between cell-periodic parts of the Bloch
      states.
    - When ``use_k_metric`` is True, the neighbor index runs over the
      *Cartesian metric-derived* k-shell (if any), followed by the
      :math:`\pm \mu` unit steps along each :math:`\lambda`-axis.
    - When ``use_k_metric`` is False, every :math:`k`- and
      :math:`\lambda`-axis contributes both the :math:`+\mu` and :math:`-\mu`
      neighbors in *reduced coordinates*. The neighbor axis (the one right
      after the mesh axes) then has length ``2 * mesh.naxes``, ordered as
      :math:`(+\kappa_0, -\kappa_0, \cdots, +\lambda_0, -\lambda_0, \cdots)`.
    - The overlap matrix is only defined where neighbors exist.
    - Boundary behavior follows the ``Mesh`` axis type:

      - **Periodic (no endpoint)**: the final point wraps to the first
        (looped). All entries are meaningful.
      - **Periodic with endpoint included**: the endpoint is explicitly
        included in the mesh. No wrap is applied; the final forward link is
        undefined and is filled with ``NaN``. Discard the final slice along
        that axis.
      - **Nonperiodic (open directions)**: no physical neighbor exists at
        the boundary. Links there are filled with ``NaN`` and should be
        dropped.

    These conventions match those used internally by :meth:`wilson_loop`,
    :meth:`berry_phase`, and related routines, which automatically trim
    invalid points.
    """
    if use_k_metric:
        logger.info("Computing overlap matrix using k-metric for neighbor lookup.")
        k_axis_indices = self.mesh.k_axis_indices
        k_shifts = []
        if k_axis_indices:
            _, idx_shell = self.lattice.nn_k_shell(
                self.mesh.shape_k, n_shell=1, report=False
            )
            for delta in idx_shell[0]:
                shift_vec = [0] * self.naxes
                for offset, axis_idx in zip(delta, k_axis_indices):
                    shift_vec[axis_idx] = int(offset)
                k_shifts.append(shift_vec)

        lambda_shifts = [
            self._unit_shift(axis_idx, direction)
            for axis_idx in self.mesh.lambda_axis_indices
            for direction in (+1, -1)
        ]
        shift_vectors = k_shifts + lambda_shifts
    else:
        logger.info(
            "Computing overlap matrix without k-metric for neighbor lookup."
        )
        shift_vectors = [
            self._unit_shift(axis_idx, direction)
            for axis_idx in (
                self.mesh.k_axis_indices + self.mesh.lambda_axis_indices
            )
            for direction in (+1, -1)
        ]

    M = np.zeros(
        (*self.shape_mesh, len(shift_vectors), self.nstates, self.nstates),
        dtype=complex,
    )

    # The bra states are the same for every neighbor: conjugate them once.
    u_bra = self.states(flatten_spin_axis=True).conj()
    for slot, shift_vec in enumerate(shift_vectors):
        rolled = self.roll_states_with_pbc(shift_vec, flatten_spin_axis=True)
        overlaps = np.einsum("...mj, ...nj -> ...mn", u_bra, rolled)
        # NaN out links whose neighbor does not exist (non-wrapping boundaries)
        M[..., slot, :, :] = self._invalidate_boundary_links(overlaps, shift_vec)

    return M
@staticmethod
def _wilson_loop(wfs_loop, wilson_evals: bool = False):
    r"""Wilson loop unitary matrix along a one-dimensional string of states.

    The Wilson loop is the ordered product of the unitary parts of the overlap
    matrices between neighboring states along the string,

    .. math::

        U_{\text{Wilson}} = \prod_{n} U_{n},

    where :math:`U_n` is obtained from the SVD of
    :math:`M_n = \langle u_{m,n} | u_{m',n+1} \rangle` as
    :math:`U_n = V_n W_n^\dagger` (see :meth:`links`).

    Parameters
    ----------
    wfs_loop : np.ndarray
        Shape ``[loop_idx, band, orbital]`` or
        ``[loop_idx, band, orbital, spin]``. A trailing spin axis is merged
        into the orbital axis. The first and last loop points are assumed to
        be the same, so for ``n`` points only ``n-1`` links are used.
    wilson_evals : bool, optional
        If True, also return the (complex, unit-modulus) eigenvalues of the
        Wilson loop unitary. Default is False.

    Returns
    -------
    U_wilson : np.ndarray
        Wilson loop unitary matrix of shape ``(band, band)``.
    eigvals : np.ndarray, optional
        Eigenvalues :math:`e^{i\phi_n}` of ``U_wilson``. Returned only if
        ``wilson_evals=True``.

    Raises
    ------
    ValueError
        If ``wfs_loop`` is not 3D or 4D.

    Notes
    -----
    The Wilson-loop eigenvalues are distinct from the multiband Berry phases
    returned by :meth:`berry_phase`. The Berry phases are (minus) the phase
    arguments of these eigenvalues, in :math:`(-\pi, \pi]`.
    """
    if wfs_loop.ndim < 3 or wfs_loop.ndim > 4:
        raise ValueError(
            "wfs_loop must be a 3D or 4D array with shape [loop_idx, band, orbital(, spin)]"
        )
    if wfs_loop.ndim == 4:
        # Merge orbital and spin axes -> (loop_idx, band, orbital * spin)
        wfs_loop = wfs_loop.reshape(wfs_loop.shape[0], wfs_loop.shape[1], -1)

    # M_i[m, n] = <u_{m,i} | u_{n,i+1}>: (n_links, band, band)
    ovr_mats = wfs_loop[:-1].conj() @ wfs_loop[1:].swapaxes(-2, -1)
    # Unitary part of each link: M = V S W^dagger -> U = V W^dagger
    V, _, Wh = np.linalg.svd(ovr_mats, full_matrices=False)
    U_link = V @ Wh

    U_wilson = U_link[0]
    for U in U_link[1:]:
        U_wilson = U_wilson @ U

    if wilson_evals:
        eigvals = np.linalg.eigvals(U_wilson)
        return U_wilson, eigvals
    return U_wilson


@staticmethod
def _berry_loop(wfs_path, berry_evals: bool = False):
    r"""Berry phase along a one-dimensional path of states.

    The total Berry phase is computed from the Wilson loop unitary
    :math:`U_{\rm Wilson}` (see :meth:`_wilson_loop`) as

    .. math::

        \phi = -\text{Im} \ln \det U_{\rm Wilson}.

    When ``berry_evals=True``, the multiband Berry phases
    :math:`\phi_n = -\text{Im} \ln \lambda_n` are also returned, where
    :math:`\lambda_n` are the eigenvalues of :math:`U_{\rm Wilson}`. They
    are sometimes called "Wilson loop eigenvalues" or Wannier centers.

    Parameters
    ----------
    wfs_path : np.ndarray
        States along the path, with shape ``(path_idx, band, orbital)`` or
        ``(path_idx, band, orbital, spin)``.
    berry_evals : bool, optional
        If True, return both the total Berry phase and the sorted multiband
        Berry phases. If False (default), return only the total Berry phase.

    Returns
    -------
    berry_phase : float
        Total Berry phase for the path, in :math:`(-\pi, \pi]`.
    eigvals_phase : np.ndarray, optional
        Sorted multiband Berry phases. Returned (as the second element of a
        tuple) only if ``berry_evals=True``.

    Raises
    ------
    ValueError
        If ``wfs_path`` is not 3D or 4D.

    Notes
    -----
    The states must be ordered along the path. For a closed path the result is
    the geometric phase acquired around the loop. For an open path it depends
    on the gauge at the endpoints.
    """
    if wfs_path.ndim < 3 or wfs_path.ndim > 4:
        raise ValueError(
            "wfs_path must be a 3D or 4D array with shape (path_idx, band, orbital(, spin))"
        )

    if berry_evals:
        U_wilson, eigvals = WFArray._wilson_loop(wfs_path, wilson_evals=True)
        # Multiband Berry phases, sorted ascending
        eigvals_phase = np.sort(-np.angle(eigvals))
        berry_phase = -np.angle(np.linalg.det(U_wilson))
        return berry_phase, eigvals_phase

    U_wilson = WFArray._wilson_loop(wfs_path, wilson_evals=False)
    return -np.angle(np.linalg.det(U_wilson))
def wilson_loop(self, axis_idx: int, state_idx=None, wilson_evals: bool = False):
    r"""Wilson loop along a specified mesh axis.

    The Wilson loop is the ordered product of *unitary link matrices* along a
    closed string of points in parameter space. For mesh direction
    :math:`\mu` (``axis_idx``) this computes, for every point on the
    remaining (transverse) axes,

    .. math::

        W_{\mu} = \prod_{n=0}^{N_{\mu}-1} U_{\mu}\!\bigl(\boldsymbol{\kappa}_n \bigr)

    where :math:`U_{\mu}(\boldsymbol{\kappa}_n)` is the **unitary part** of
    the overlap matrix between states at consecutive mesh points (see
    :meth:`links`).

    .. versionadded:: 2.0.0

    Parameters
    ----------
    axis_idx : int
        Index of the ``Mesh`` axis along which the Wilson loop is computed.
    state_idx : int, array-like, optional
        Band index or indices to include. If unspecified, all bands are
        included.
    wilson_evals : bool, optional
        If True, also compute and return the eigenvalues of each Wilson loop.
        Default is False.

    Returns
    -------
    U_wilson : np.ndarray
        Wilson loop unitaries of shape ``(*transverse_shape, nstate, nstate)``,
        where ``transverse_shape`` is the mesh shape with ``axis_idx`` removed.
    eigvals : np.ndarray, optional
        Complex, unit-modulus eigenvalues of shape
        ``(*transverse_shape, nstate)``. Returned only if
        ``wilson_evals=True``.

    Raises
    ------
    ValueError
        If ``axis_idx`` is not an integer in ``[0, naxes)``.

    See Also
    --------
    :meth:`berry_phase`
    :meth:`links`

    Notes
    -----
    - The eigenvalues have the form :math:`\lambda_n = e^{i \phi_n}`, where
      :math:`\phi_n` are the multiband Berry phases. They are distinct from
      the phases returned by :meth:`berry_phase` with ``berry_evals=True``,
      which are the **phase arguments**, not the eigenvalues themselves.
    - For an array of size ``N`` along ``axis_idx``, the loop is formed from
      the ``N-1`` nearest-neighbor inner products. If the axis is a loop whose
      endpoint is not stored, the first state is appended once to close the
      loop. Along BZ-winding k-components it is multiplied by the PBC phase
      (the product over all such components).
    """
    if (
        not isinstance(axis_idx, (int, np.integer))
        or axis_idx < 0
        or axis_idx >= self.naxes
    ):
        raise ValueError(f"axis_idx must be an integer in [0, {self.naxes - 1}]")

    # States (optionally restricted to a subspace), spin merged into orbitals
    u_loop = self.states(state_idx=state_idx, flatten_spin_axis=True)

    # Loop components of this axis whose endpoint is not stored in the mesh
    open_comps = [
        comp
        for ax, comp in self.mesh._get_loop_ax_comp()
        if ax == axis_idx and not self.mesh.is_axis_closed(ax, comp)
    ]
    if open_comps:
        u_expanded = self.states(state_idx=state_idx, flatten_spin_axis=False)
        u_last = np.take(u_expanded, 0, axis=axis_idx)
        for comp in open_comps:
            if self.mesh.is_axis_bz_winding(axis_idx, comp):
                logger.debug(
                    "Applying phase to state at beginning to end of open periodic axis."
                )
                # `comp` is a k-component index; _get_pbc_phases expects the
                # real-space direction, i.e. the matching periodic_dirs entry.
                real_dir = int(np.asarray(self.lattice.periodic_dirs, dtype=int)[comp])
                phase, _, _ = self._get_pbc_phases(axis_idx, real_dir)
                u_last = u_last * phase
        if self.nspin == 2:
            u_last = u_last.reshape(*u_last.shape[:-2], -1)
        logger.debug("Appending state at beginning to end of open periodic axis.")
        u_loop = np.concatenate(
            [u_loop, np.expand_dims(u_last, axis=axis_idx)], axis=axis_idx
        )

    # Bring the loop axis first for easy slicing over transverse axes
    u_loop = np.moveaxis(u_loop, axis_idx, 0)
    tail_shape = u_loop.shape[1:-2]
    n_sub = u_loop.shape[-2]

    unitaries = np.empty((*tail_shape, n_sub, n_sub), dtype=complex)
    # Wilson-loop eigenvalues are complex; a real array would drop Im parts.
    evals = np.empty((*tail_shape, n_sub), dtype=complex) if wilson_evals else None

    it = np.ndindex(*tail_shape) if tail_shape else [()]
    for idx in it:
        # All points along the loop at this transverse index: (n_pts, n_sub, n_orb)
        wf_line = u_loop[(slice(None), *idx)]
        if wilson_evals:
            U, ev = self._wilson_loop(wf_line, wilson_evals=True)
            evals[idx] = ev
        else:
            U = self._wilson_loop(wf_line, wilson_evals=False)
        unitaries[idx] = U

    if wilson_evals:
        return unitaries, evals
    return unitaries
[docs] def berry_connection( self, axis_idx: int | ArrayLike = None, state_idx: int | ArrayLike = None, *, return_unitaries: bool = False, cartesian: bool = False, ): r"""Berry connection from parallel-transport links. This routine evaluates the (non-Abelian) Berry connection on the reduced parameter mesh. For each mesh direction :math:`\mu` in ``axis_idx``, the connection is obtained from the parallel-transport link unitaries :math:`U_{\mu}(\boldsymbol{\kappa})` returned by :meth:`links`, using the discrete finite-difference approximation: .. math:: A_{\mu}(\boldsymbol{\kappa}) \;=\; -\frac{1}{i\,\Delta \kappa_{\mu}} \log\!\big[\,U_{\mu}(\boldsymbol{\kappa})\,\big], where :math:`\Delta \kappa_{\mu}` is the step size (in reduced coordinates) between adjacent points along direction :math:`\mu`. The result is a matrix-valued, non-Abelian connection defined over the full reduced mesh :math:`\boldsymbol{\kappa} = (k_1,\dots,k_{d_k};\,\lambda_1,\dots,\lambda_{d_\lambda})`, with the final two array axes spanning the band subspace specified by ``state_idx``. .. versionadded:: 2.0.0 Parameters ---------- axis_idx : int or array_like of int or None, optional Mesh directions :math:`\mu` along which to compute the connection. If None, all mesh axes are used. Default is None. state_idx : int or array_like of int or None, optional Subspace (band indices) to use. If None, use all. Default is None. return_unitaries : bool, optional If True, also return the link unitaries :math:`U_{\mu}`. Default is False. cartesian : bool, optional If True, compute the step size :math:`\Delta k_\mu` in Cartesian space (using the reciprocal lattice vectors) rather than in reduced coordinates. Default is False. Returns ------- A : ndarray Non-Abelian connection with shape: ``(n_mu, *mesh_shape, nstate, nstate)``, where ``n_mu = len(axis_idx)`` (or ``WFArray.naxes`` if ``axis_idx=None``) and ``nstate = len(state_idx)`` (or ``WFArray.nstates`` if ``state_idx=None``). 
U : ndarray, optional The link unitaries with same shape (returned if ``return_unitaries=True``). See Also -------- :meth:`links`, :meth:`berry_phase`, :meth:`wilson_loop` Notes ----- - The connection is Hermitian, :math:`A_\mu^\dagger = A_\mu`. - The logarithm is computed using spectral decomposition of the unitaries :math:`U = V \, \text{diag}(e^{i\theta_n}) \, V^\dagger`, with principal phases :math:`\theta_n` in :math:`(-\pi, \pi]`. - Entries of the connection at invalid boundaries (where links are NaN; see :meth:`links`) remain NaN. """ # Get link unitaries U_mu (n_mu, ..., nstate, nstate), with NaNs at boundaries U = self.links( state_idx=state_idx, axis_idx=axis_idx ) # (n_mu, ..., nstate, nstate) # Build dk per mu if axis_idx is None: axis_idx = np.arange(self.naxes, dtype=int) else: axis_idx = np.atleast_1d(axis_idx) # Compute spacings step_list = [] dim_tot = self.mesh.dim_k + self.mesh.dim_lambda dim_k = self.mesh.dim_k for ax in axis_idx: delta_vec = [] for comp in range(dim_tot): arr = self.mesh.get_axis_range(ax, comp) if arr.size < 2: continue diff = arr[1] - arr[0] if not np.isclose(diff, 0.0): delta_vec.append(diff) if not delta_vec: raise ValueError( f"Could not determine step size along axis {ax} for Berry connection." ) if cartesian and dim_k: reduced = delta_vec[:dim_k] # reduced k-step cart = reduced @ self.lattice.recip_lat_vecs # Cartesian k-step combined = np.concatenate([cart, delta_vec[dim_k:]]) step_list.append(np.linalg.norm(combined)) else: # reduced step (first varying component) nonzero = np.flatnonzero(delta_vec) step_list.append(delta_vec[nonzero[0]]) dk = np.asarray(step_list, dtype=float) # Compute A_mu from U_mu # We'll handle boundaries (NaNs) by masking rows; leave A as NaN there. 
A = np.empty_like(U, dtype=complex) # A = (1/(i dk)) * log(U) via eigen-decomposition # (unitary U is normal -> unitarily diagonalizable: # U = V diag(e^{i theta}) V^dagger, log U = V diag(i theta) V^dagger) for i_mu, dki in enumerate(dk): Ui = U[i_mu] Ai = np.empty_like(Ui, dtype=complex) # flatten batch, do per-matrix eig, then reshape back batch_shape = Ui.shape[:-2] nB = Ui.shape[-1] Ui_flat = Ui.reshape((-1, nB, nB)) Ai_flat = Ai.reshape((-1, nB, nB)) # mask boundaries: where any entry is NaN, fill A with NaN and skip eig invalid = np.isnan(Ui_flat[..., 0, 0]) for p in range(Ui_flat.shape[0]): if invalid[p]: Ai_flat[p, :, :] = np.nan + 0j continue w, V = np.linalg.eig(Ui_flat[p]) # principal phases in (-pi, pi] theta = np.angle(w) logU = (V * (1j * theta)) @ V.conj().T Ai_flat[p] = -logU / (1j * dki) A[i_mu] = Ai_flat.reshape(batch_shape + (nB, nB)) return (A, U) if return_unitaries else A
def berry_phase(
    self,
    axis_idx: int,
    state_idx: list[int] = None,
    berry_evals: bool = False,
    contin: bool = True,
):
    r"""Berry phase accumulated along one mesh axis.

    Each one-dimensional string of the mesh along ``axis_idx`` is treated as a
    (possibly closed) path in parameter space. The geometric phase picked up
    by the selected states along that path is obtained from the Wilson loop
    unitary :math:`W_{\mu}` (see :meth:`wilson_loop`), i.e. the ordered product
    of the unitary overlap links along direction :math:`\mu`:

    .. math::

        \phi_{\mu} = -\text{Im} \ln \det W_{\mu}

    With ``berry_evals=True`` the individual multiband phases ("hybrid
    Wannier centers") are returned as well. These are the phases of the
    eigenvalues :math:`\lambda_{\mu}^{(n)}` of the Wilson loop,

    .. math::

        \phi_{\mu}^{(n)} = -\text{Im} \ln \lambda_{\mu}^{(n)}

    and they add up to the total Berry phase modulo :math:`2\pi`.

    Parameters
    ----------
    axis_idx : int
        Index of the ``Mesh`` axis along which the phase is accumulated. For
        a one-dimensional ``Mesh`` this must be 0.

        .. versionchanged:: 2.0.0
            Renamed from `dir` to `axis_idx` to avoid conflict with built-in
            Python function `dir()`.

    state_idx : int or array-like of int, optional
        Band index, or list of band indices, spanning the transported
        subspace. All bands are used when omitted.

        .. versionchanged:: 2.0.0
            Renamed from ``occ``. The band indices are not required to be
            occupied bands only. The default behavior is to include all bands,
            and the ``"all"`` option has been removed.

    berry_evals : bool, optional
        If True, additionally return the Wilson loop eigen-phases (hybrid
        Wannier centers), branch-fixed by the same rules as the total phase.
    contin : bool, optional
        Branch handling of the returned phases. When True (default), the
        phase of the first string is kept in :math:`[-\pi, \pi]` and every
        later string is shifted by multiples of :math:`2\pi` so that the
        result varies smoothly along the remaining mesh directions. When
        False, every phase lies in :math:`[-\pi, \pi]`.

    Returns
    -------
    phase : np.ndarray
        Total Berry phase along ``axis_idx``. A 0-d array when the ``Mesh``
        has a single axis; otherwise an array over the remaining axes, e.g.
        for mesh axes ``[i, j, k]`` and ``axis_idx=1`` its shape is ``[i, k]``.
    evals : np.ndarray, optional
        Only when ``berry_evals=True``: the Wilson loop eigen-phases, indexed
        like ``phase`` with one extra trailing index over the phases
        :math:`\phi_n`.

    Raises
    ------
    ValueError
        If ``axis_idx`` is not an integer in ``[0, naxes - 1]``, or if
        continuity is requested for more than two transverse axes.

    See Also
    --------
    :meth:`wilson_loop`
    :ref:`formalism` : Sec. 4.5 for the discretized formula used to compute Berry phase.
    :ref:`haldane-bp-nb` : For an example
    :ref:`cone-nb` : For an example
    :ref:`three-site-thouless-nb` : For an example

    Notes
    -----
    - ``berry_evals`` are the **phase** arguments of the Wilson loop
      eigenvalues (``wilson_evals`` in :meth:`wilson_loop`), not the
      eigenvalues themselves.
    - A single 1D string always yields a phase in :math:`[-\pi,\pi]`.
    - On multidimensional meshes every 1D string obtained by slicing along
      ``axis_idx`` is handled separately; branch continuity between strings
      is governed by ``contin`` and is implemented for at most two transverse
      axes.
    - The subspace given by ``state_idx`` should be isolated from the other
      states (no degeneracies with states outside the subset); checking this
      is left to the caller.
    - A string of ``N`` points contributes ``N-1`` nearest-neighbor overlaps,
      i.e. an open-path phase. If the endpoints describe the same physical
      Hamiltonian but are not already identical, the first state is appended
      to the end (with the PBC phase along k-axes) to close the loop.

    Examples
    --------
    Berry phases along the second mesh axis for the three lowest bands. If
    ``wf`` has axes ``[i, j, k]``, ``phase[i, k]`` belongs to the string
    ``wf[i, :, k]``.

    >>> phase = wf.berry_phase(axis_idx=1, state_idx=[0, 1, 2])
    """
    axis_ok = isinstance(axis_idx, (int, np.integer)) and 0 <= axis_idx < self.naxes
    if not axis_ok:
        raise ValueError(f"axis_idx must be an integer in [0, {self.naxes - 1}]")

    # Orbital(+spin)-flattened states drive the loop; the unflattened copy is
    # what the PBC phase is applied to before it is flattened itself.
    flat_states = self.states(state_idx=state_idx, flatten_spin_axis=True)
    spinor_states = self.states(state_idx=state_idx, flatten_spin_axis=False)

    path = flat_states
    for ax, comp in self.mesh._get_loop_ax_comp():
        # Only open, periodic components of the requested axis need closing.
        if ax != axis_idx or self.mesh.is_axis_closed(ax, comp):
            continue
        if self.mesh.is_axis_bz_winding(ax, comp):
            logger.debug(
                "Applying phase to state at beginning to end of open periodic axis."
            )
            lattice_dir = self.lattice.periodic_dirs[comp]
            phase, _, _ = self._get_pbc_phases(ax, lattice_dir)
            closing = np.take(spinor_states, 0, axis=axis_idx) * phase
        else:
            closing = np.take(spinor_states, 0, axis=axis_idx)
        if self.nspin == 2:
            closing = closing.reshape(*closing.shape[:-2], -1)
        logger.debug("Appending state at beginning to end of open periodic axis.")
        path = np.concatenate(
            [path, np.expand_dims(closing, axis=axis_idx)], axis=axis_idx
        )

    # Loop axis first; remaining leading axes are the transverse strings.
    path = np.moveaxis(path, axis_idx, 0)
    tail_shape = path.shape[1:-2]
    n_sub = path.shape[-2]

    result_shape = (*tail_shape, n_sub) if berry_evals else tail_shape
    out = np.empty(result_shape, dtype=float)

    positions = np.ndindex(*tail_shape) if tail_shape else [()]
    for pos in positions:
        # One string: (n_points, n_sub, norb*nspin)
        string = path[(slice(None), *pos, slice(None), slice(None))]
        if berry_evals:
            _, value = self._berry_loop(string, berry_evals=berry_evals)
        else:
            value = self._berry_loop(string, berry_evals=berry_evals)
        out[pos] = value

    n_tail = len(tail_shape)
    if not contin or n_tail == 0:
        return out

    if n_tail > 2:
        # Continuity is only implemented for up to two transverse axes.
        # NOTE(review): `_dim_arr` is not set anywhere visible here; confirm
        # that it exists on WFArray.
        if self._dim_arr != 1:
            raise ValueError("Wrong dimensionality!")
        return out

    smooth = _array_phases_cont if berry_evals else _one_phase_cont
    if n_tail == 1:
        return smooth(out, out[0])

    # Two transverse axes: make each column continuous, anchoring column i
    # to the (already adjusted) first entry of column i-1.
    for col in range(out.shape[1]):
        anchor = out[0, 0] if col == 0 else out[0, col - 1]
        out[:, col] = smooth(out[:, col], anchor)
    return out
[docs] def berry_flux( self, plane: ArrayLike | None = None, state_idx: int | ArrayLike | None = None, non_abelian: bool = False, *, use_tensorflow: bool = False, ): r"""Berry flux tensor using the Fukui-Hatsugai-Suzuki plaquette method. The Berry flux tensor :math:`\mathcal{F}_{\mu\nu}(\boldsymbol{\kappa})` on a reduced mesh point :math:`\boldsymbol{\kappa}` is evaluated by forming the ordered product of parallel-transport link unitaries around an elementary plaquette in directions :math:`(\mu,\nu)` [1]_: .. math:: \mathcal{F}_{\mu\nu}(\boldsymbol{\kappa}) = -\,\operatorname{Im}\, \ln\!\Big[ U_{\mu}(\boldsymbol{\kappa})\, U_{\nu}(\boldsymbol{\kappa}+\hat{\mu})\, U_{\mu}^{\dagger}(\boldsymbol{\kappa}+\hat{\nu})\, U_{\nu}^{\dagger}(\boldsymbol{\kappa}) \Big], where :math:`U_{\mu}(\boldsymbol{\kappa})` is the unitary link obtained from the overlap between states at :math:`\boldsymbol{\kappa}` and its forward neighbor along direction :math:`\mu` (see :meth:`links`). When a multiband subspace is supplied, this expression yields the **non-Abelian** matrix-valued flux; taking the matrix determinant gives the usual **Abelian** (band-traced) flux. .. versionremoved:: 2.0.0 The `individual_phases` parameter has been removed. Parameters ---------- plane : (2,) array_like, optional Array or tuple of two mesh-axis indices over which the flux is computed. By default, all index pairs are evaluated. .. versionchanged:: 2.0.0 Renamed from ``dirs``. state_idx : int or array_like of int, optional Indices of the states spanning the subspace. If None, all states are used. .. versionchanged:: 2.0.0 Renamed from ``occ``. The band indices are not required to be occupied bands only. The default behavior is to include all bands, and the ``"all"`` option has been removed. non_abelian : bool, optional If *True*, return the matrix-valued non-Abelian flux. If *False* (default), return the Abelian (band-traced) flux. .. 
versionadded:: 2.0.0 use_tensorflow : bool, optional If *True*, use TensorFlow for speedup on large calculations. Default is *False*. Returns ------- flux : ndarray The Berry flux tensor, which is an array of shape - ``non_abelian=True``, ``plane=None``: ``(naxes, naxes, *mesh_shape, nstates, nstates)``, - ``non_abelian=True``, ``plane=(mu, nu)``: ``(*mesh_shape, nstates, nstates)``, - ``non_abelian=False``, ``plane=None``: ``(naxes, naxes, *mesh_shape)``, - ``non_abelian=False``, ``plane=(mu, nu)``: ``(*mesh_shape,)``, where ``nstates = len(state_idx)`` (or ``WFArray.nstates`` if ``state_idx=None``). Notes ----- - For a given :math:`(\mathbf{k}, \boldsymbol{\lambda})`-point :math:`\boldsymbol{\kappa}` and pair of mesh directions :math:`(\mu, \nu)`, the plaquette is formed by the points: .. math:: \begin{pmatrix} \boldsymbol{\kappa} \\ \boldsymbol{\kappa} + \hat{\mu} \\ \boldsymbol{\kappa} + \hat{\mu} + \hat{\nu} \\ \boldsymbol{\kappa} + \hat{\nu} \end{pmatrix} where :math:`\hat{\mu}` and :math:`\hat{\nu}` are the vectors connecting neighboring points along directions :math:`\mu` and :math:`\nu` in the reduced mesh. - Let :math:`U_{\mu}(\boldsymbol{\kappa})` denote the unitary **link matrix** (unitary part of overlap matrix between states) from :math:`\boldsymbol{\kappa}` to :math:`\boldsymbol{\kappa} + \hat{\mu}`: .. math:: U_{\mu}(\boldsymbol{\kappa}) \equiv \mathcal{U} \,\Big[ M_\mu (\boldsymbol{\kappa}) \Big] where :math:`\mathcal{U}` denotes the unitary part of (polar decomposition, see :meth:`links`), and :math:`M_\mu (\boldsymbol{\kappa})_{mn} = \langle u_{m}(\boldsymbol{\kappa}) \,|\, u_{n}(\boldsymbol{\kappa} + \hat{\mu}) \rangle` is the overlap matrix of the states in the subspace defined by ``state_idx``. When ``non_abelian=True``, the (non-Abelian) Berry flux tensor is computed by taking the imaginary part of the matrix logarithm of the product of the link matrices around the plaquettes. It is defined as, .. 
math:: \mathcal{F}_{\mu\nu}(\boldsymbol{\kappa}) = -\mathrm{Im} \,\ln \Big[ U_{\mu}(\boldsymbol{\kappa}) \; U_{\nu}(\boldsymbol{\kappa} + \hat{\mu}) \; U_{\mu}^\dagger(\boldsymbol{\kappa} + \hat{\nu}) \; U_{\nu}^\dagger(\boldsymbol{\kappa}) \Big]. This definition holds for multi-band subspaces, where the link matrices are square and unitary in the occupied-band space. When ``non_abelian=False``, the (Abelian) Berry flux tensor is computed by taking the imaginary part of the logarithm of the determinant of the product of the link matrices around the plaquettes. This is equivalently the band-trace of the non-Abelian Berry flux tensor. It is defined as, .. math:: \mathcal{F}_{\mu\nu}(\boldsymbol{\kappa}) = -\mathrm{Im} \,\ln \,\det \Big[ U_{\mu}(\boldsymbol{\kappa}) U_{\nu}(\boldsymbol{\kappa} + \hat{\mu}) U_{\mu}^{\dagger}(\boldsymbol{\kappa} + \hat{\nu}) U_{\nu}^{\dagger}(\boldsymbol{\kappa}) \Big]. - This method requires at least a two-dimensional mesh in the combined adiabatic space (momentum + adiabatic parameters). Thus, ``WFArray.naxes >= 2``. - The last point along closed or non-periodic axes is trimmed from the flux array to avoid overcounting, since it is equivalent to the first point. Examples -------- Consider the case where we have a mesh with ``naxes`` axes in the combined adiabatic space (momentum + adiabatic parameters). The Berry flux can be computed over all 2D planes of the mesh, >>> flux = wf.berry_flux( ... state_idx = [0, 1, 2] ... ) # shape: (naxes, naxes, *mesh_shape) or over a specific plane, >>> flux = wf.berry_flux( ... plane = (0, 1), ... state_idx = [0, 1, 2] ... ) # shape: (*mesh_shape) We can also compute the non-Abelian Berry flux tensor over a specific plane: >>> flux = wf.berry_flux( ... plane = (0,1), ... state_idx = [0, 1, 2], ... non_abelian = True ... ) # shape: (*mesh_shape, nstates, nstates) References ---------- .. [1] \ T. Fukui, Y. Hatsugai, and H. Suzuki, *J. Phys. Soc. Jpn.* **74**, 1674 (2005). 
""" if (self.naxes) < 2: raise ValueError( "Berry curvature only defined if number of mesh axes >= 2." ) # Validate plane ndims = self.naxes # Total dimensionality of adiabatic space: d if plane is not None: if not isinstance(plane, (list, tuple, np.ndarray)): raise TypeError("plane must be None, a list, tuple, or numpy array.") if len(plane) != 2: raise ValueError("plane must contain exactly two directions.") if any(p < 0 or p >= ndims for p in plane): raise ValueError(f"Plane indices must be between 0 and {ndims - 1}.") if plane[0] == plane[1]: raise ValueError("Plane indices must be different.") state_idx = self._normalize_state_indices(state_idx) n_states = len(state_idx) # Number of states considered flux_shape = list( self.shape_mesh ) # Number of points in adiabatic mesh: (nk1, nk2, ..., nkd) # Initialize the Berry flux array if plane is None: shape = ( (ndims, ndims, *flux_shape, n_states, n_states) if non_abelian else (ndims, ndims, *flux_shape) ) berry_flux = np.zeros(shape, dtype=complex) dirs = list(range(ndims)) plane_idxs = ndims else: p, q = plane # Unpack plane directions dirs = [p, q] plane_idxs = 2 shape = (*flux_shape, n_states, n_states) if non_abelian else (*flux_shape,) berry_flux = np.zeros(shape, dtype=complex) # Trim the last point along closed/non-periodic axes to avoid overcounting for ax in dirs: if self.mesh.is_axis_closed(ax) or ( not self.mesh.is_axis_looped(ax) and not self.mesh.is_axis_bz_winding(ax) ): logger.debug( f"Axis {ax} is closed or non-periodic. " "Trimming the last point in the flux array to avoid overcounting." 
) if plane is None: berry_flux = np.delete(berry_flux, -1, axis=ax + 2) else: berry_flux = np.delete(berry_flux, -1, axis=ax) # U_forward: Unitary part of overlaps <u_{nk} | u_{n, k+delta k_mu}> U_forward = self.links(state_idx=state_idx, axis_idx=dirs) # Compute Berry flux for each pair of states for mu in range(plane_idxs): for nu in range(mu + 1, plane_idxs): # NOTE: The order of U_forward follows the order in dirs, so we index accordingly # e.g., if dirs = [p, q], then mu=0 -> p, mu=1 -> q U_mu = U_forward[mu] U_nu = U_forward[nu] # Shift the links along the mu and nu directions # NOTE: We index dirs to get the correct ordering axis_mu = dirs[mu] axis_nu = dirs[nu] U_nu_shift_mu = np.roll(U_nu, -1, axis=axis_mu) U_mu_shift_nu = np.roll(U_mu, -1, axis=axis_nu) # Wilson loops: W = U_{mu}(k_0) U_{nu}(k_0+delta_mu) U^{-1}_{mu}(k_0+delta_mu+delta_nu) U^{-1}_{nu}(k_0) if use_tensorflow: try: import tensorflow as tf except ImportError: raise ImportError( "TensorFlow is not installed. Please install it or set use_tensorflow=False." ) U_mu_tf = tf.convert_to_tensor(U_mu) U_nu_shift_mu_tf = tf.convert_to_tensor(U_nu_shift_mu) U_mu_shift_nu_tf = tf.convert_to_tensor(U_mu_shift_nu) U_nu_tf = tf.convert_to_tensor(U_nu) U_wilson_tf = tf.linalg.matmul( tf.linalg.matmul( tf.linalg.matmul(U_mu_tf, U_nu_shift_mu_tf), tf.linalg.adjoint(U_mu_shift_nu_tf), ), tf.linalg.adjoint(U_nu_tf), ) U_wilson = U_wilson_tf.numpy() else: U_wilson = ( U_mu
[docs] @ U_nu_shift_mu @ U_mu_shift_nu.conj().swapaxes(-1, -2) @ U_nu.conj().swapaxes(-1, -2) ) # Trim the last point along closed/non-periodic axes to avoid overcounting for ax in dirs: if self.mesh.is_axis_closed(ax) or ( not self.mesh.is_axis_looped(ax) and not self.mesh.is_axis_bz_winding(ax) ): logger.debug( f"Axis {ax} is closed or non-periodic. " "Trimming the last point in the Wilson loop to avoid overcounting." ) U_wilson = np.delete(U_wilson, -1, axis=ax) if non_abelian: # Non-Abelian lattice field strength: F = -i Log(U_wilson) # Matrix log using eigen-decompositon # Eigen-decompose U_wilson = V diag(-phi_j) V^{-1}, phi_j in (-pi, pi] if use_tensorflow: try: import tensorflow as tf except ImportError: raise ImportError( "TensorFlow is not installed. Please install it or set use_tensorflow=False." ) eigvals, eigvecs = tf.linalg.eig(tf.convert_to_tensor(U_wilson)) eigvals = eigvals.numpy() eigvecs = eigvecs.numpy() else: eigvals, eigvecs = np.linalg.eig(U_wilson) phi = -np.angle(eigvals) F_diag = np.einsum("...i, ij -> ...ij", phi, np.eye(phi.shape[-1])) eigvecs_inv = np.linalg.inv(eigvecs) phases_plane = eigvecs @ F_diag @ eigvecs_inv else: det_U = np.linalg.det(U_wilson) phases_plane = -np.angle(det_U) if plane is None: # Store the Berry flux in a 2D array for each pair of directions berry_flux[mu, nu] = phases_plane berry_flux[nu, mu] = -phases_plane else: berry_flux = phases_plane return berry_flux
def berry_curvature( self, plane: ArrayLike = None, state_idx: ArrayLike | int = None, non_abelian: bool = False, return_flux: bool = False, ): r"""Berry curvature tensor using the Fukui-Hatsugai-Suzuki plaquette method. The Berry curvature tensor :math:`\Omega_{\mu\nu}(\mathbf{k})` is computed using a discretized formula based on the Fukui-Hatsugai-Suzuki (FHS) plaquette-based method [1]_. The curvature is approximated from the Berry flux (computed in :meth:`berry_flux`) by dividing the flux by the (presumed uniform) area of the plaquette in parameter space, .. math:: \Omega_{\mu\nu}(\mathbf{k}) \approx \frac{\mathcal{F}_{\mu\nu}(\mathbf{k})}{A_{\mu\nu}}, where :math:`A_{\mu\nu}` is the area (in Cartesian units) of the plaquette in parameter space. The tensor is either defined is a matrix-valued quantity (non-Abelian case) or as a scalar quantity obtained by tracing over the band indices (Abelian case). .. versionadded:: 2.0.0 Parameters ---------- plane : (2,) array_like, optional Array or tuple of two indices defining the axes in the WFArray mesh which the Berry flux is computed over. By default, all directions are considered, and the full Berry flux tensor is returned. state_idx : int or array_like of int, optional Optional index or array of indices defining the states to be included in the subsequent calculations, typically the indices of bands considered occupied. If not specified, or None, all bands are included. non_abelian : bool, optional If *True* then the non-Abelian Berry flux tensor is computed defined as a matrix-valued quantity. If *False* (default) then the Berry flux is computed as a scalar quantity by tracing over the band indices. return_flux : bool, optional If *True*, the function returns a tuple containing both the Berry curvature and the Berry flux tensors. If *False* (default), only the Berry curvature tensor is returned. Returns ------- berry_curv : np.ndarray Berry curvature tensor with shape depending on input parameters. 
Shape is - ``non_abelian=True``, ``plane=None``: ``(naxes, naxes, *mesh_shape, nstates, nstates)``, - ``non_abelian=True``, ``plane=(mu, nu)``: ``(*mesh_shape, nstates, nstates)``, - ``non_abelian=False``, ``plane=None``: ``(naxes, naxes, *mesh_shape)``, - ``non_abelian=False``, ``plane=(mu, nu)``: ``(*mesh_shape,)``, berry_flux : np.ndarray, optional Berry flux tensor with shape depending on input parameters. Returned only if ``return_flux=True``. Shape is same as that of ``berry_curv``. See Also -------- :meth:`berry_flux` : For details and formalism on the Berry flux tensor. Notes ----- - The method requires at least a two-dimensional mesh in the combined adiabatic space (momentum + adiabatic parameters). Thus, ``WFArray.naxes >= 2``. - The last point along closed or non-periodic axes is trimmed from the curvature array to avoid overcounting, since it is equivalent to the first point. References ---------- .. [1] \ T. Fukui, Y. Hatsugai, and H. Suzuki, *J. Phys. Soc. Jpn.* **74**, 1674 (2005). """ ndims = self.naxes # Total dimensionality of adiabatic space: d if plane is None: dirs = list(range(ndims)) else: p, q = plane # Unpack plane directions dirs = [p, q] berry_flux = self.berry_flux( plane=plane, state_idx=state_idx, non_abelian=non_abelian ) berry_curv = np.zeros_like(berry_flux, dtype=complex) coords = self.mesh.points # (..., dim_k + dim_lam) dim_total = coords.shape[-1] dim_k = self.lattice.dim_k recip = np.asarray(self.lattice.recip_lat_vecs, float) if dim_k else None # Collect the physical step vectors for each sampling axis axis_vecs = [] for ax in range(ndims): delta = np.zeros(dim_total, dtype=float) for comp in range(dim_total): arr = self.mesh.get_axis_range(ax, comp) if arr.size >= 2: diff = arr[1] - arr[0] if not np.isclose(diff, 0.0): delta[comp] = diff if not np.any(delta): raise ValueError( f"Cannot compute Berry curvature: " f"Mesh axis {ax} has zero length in all parameter directions." 
) if dim_k: k_cart = delta[:dim_k] @ recip vec = np.concatenate([k_cart, delta[dim_k:]]) else: vec = delta[dim_k:] axis_vecs.append(vec) # Divide by area elements for the (mu, nu)-plane if plane is not None: # first two axes absent mu, nu = p, q A = np.vstack([axis_vecs[mu], axis_vecs[nu]]) area = np.sqrt(np.linalg.det(A @ A.T)) berry_curv = berry_flux / area else: for i, mu in enumerate(dirs): for j in range(i + 1, len(dirs)): nu = dirs[j] A = np.vstack([axis_vecs[mu], axis_vecs[nu]]) area = np.sqrt(np.linalg.det(A @ A.T)) # Divide flux by the area element to get approx curvature berry_curv[mu, nu] = berry_flux[mu, nu] / area berry_curv[nu, mu] = berry_flux[nu, mu] / area return (berry_curv, berry_flux) if return_flux else berry_curv
def chern_number(self, plane=(0, 1), state_idx=None):
    r"""Computes the Chern number in the specified plane.

    The Chern number is the sum of the Abelian Berry flux over the plaquettes
    of the specified plane, divided by :math:`2 \pi`:

    .. math::

        C = \frac{1}{2\pi} \sum_{\mathbf{k}_{\mu}, \mathbf{k}_{\nu}} F_{\mu\nu}(\mathbf{k}).

    The plane :math:`(\mu, \nu)` is specified by `plane`, a pair of mesh-axis
    indices.

    .. versionadded:: 2.0.0

    Parameters
    ----------
    plane : tuple, list or array of two ints, optional
        Mesh-axis indices of the plane in which the Chern number is computed,
        each between 0 and the number of mesh axes minus 1. If None, the
        first two mesh axes ``(0, 1)`` are used. Default is ``(0, 1)``.
    state_idx : array-like, optional
        Indices of states to be included in the Chern number calculation.
        If None (default), all states are included.

    Returns
    -------
    chern : np.ndarray, float
        For a two-dimensional mesh, a floating-point approximation of the
        integer Chern number for the plane. For higher-dimensional meshes,
        one value per point of the remaining axes: e.g. the returned shape is
        `(nk3, ..., nkd)` for ``plane=(0, 1)``.

    See Also
    --------
    :meth:`berry_flux` : For details and formalism on the Berry flux tensor.

    Notes
    -----
    - The Chern number approaches an integer in the limit of a dense mesh
      only when the plane forms a closed manifold.
    - The method requires at least a two-dimensional mesh in the combined
      adiabatic space (momentum + adiabatic parameters), i.e.
      ``WFArray.naxes >= 2``.
    - The last point along closed or non-periodic axes of the plane is
      trimmed before summation to avoid overcounting (see :meth:`berry_flux`).

    Examples
    --------
    Suppose ``wfs`` is a populated `WFArray` on a three-dimensional mesh of
    shape ``(10, 11, 12)`` whose axes are all periodic. The Chern number in
    the ``(0, 1)`` plane, for every point along the third axis, is

    >>> chern = wfs.chern_number(plane=(0, 1), state_idx=np.arange(n_occ))
    >>> print(chern.shape)
    (12,)
    """
    if plane is None:
        plane = (0, 1)

    # Abelian flux with the plane axes still at their mesh positions:
    # shape (*mesh_shape) with trimmed plane axes; validation of ``plane``
    # happens in berry_flux.
    berry_flux = self.berry_flux(state_idx=state_idx, plane=plane, non_abelian=False)

    # np.sum requires an int or a tuple of ints for ``axis``; ``plane`` may be
    # a list or an array.
    sum_axes = tuple(int(ax) for ax in plane)
    chern = np.sum(berry_flux, axis=sum_axes) / (2 * np.pi)
    return chern
def position_matrix(
    self, pos_dir: int, mesh_idx: list[int], state_idx: list[int] = None
):
    r"""Position matrix for a given k-point and set of states.

    The position operator is defined in reduced coordinates. The returned
    object :math:`X` is

    .. math::

        X_{m n {\bf k}}^{\alpha} = \langle u_{m {\bf k}} \vert
        r^{\alpha} \vert u_{n {\bf k}} \rangle

    where :math:`r^{\alpha}` is the (diagonal-in-orbitals) position operator
    along direction :math:`\alpha` selected by ``pos_dir``. Each orbital
    contributes its reduced coordinate along ``pos_dir``; for spinful models
    both spin components of an orbital share that coordinate.

    Parameters
    ----------
    pos_dir : int
        Real-space direction of the position operator, an index in
        ``[0, lattice.dim_r)`` into the orbital coordinates. It must not be
        one of the lattice's periodic directions.

        .. versionchanged:: 2.0.0
            Renamed from ``dir`` to ``pos_dir`` to avoid conflict with
            built-in Python function `dir()`.

    mesh_idx : int or array-like of int
        Indices of the :math:`(k, \lambda)`-point of interest in the mesh,
        one per mesh axis.
    state_idx : array-like, optional
        States to be included. If not specified, all states are included.

        .. versionchanged:: 2.0.0
            Renamed from ``occ``. The band indices are not required to be
            occupied bands only. The default behavior is to include all bands,
            and the ``"all"`` option has been removed.

    Returns
    -------
    pos_mat : np.ndarray
        Complex position matrix :math:`X_{m n}` as defined above, of shape
        ``(nsel, nsel)`` with ``nsel`` the number of states selected by
        ``state_idx``. The first index corresponds to the bra (:math:`m`) and
        the second to the ket (:math:`n`).

    Raises
    ------
    TypeError
        If ``mesh_idx`` is not an int, list, tuple or numpy array.
    ValueError
        If ``mesh_idx`` does not have one entry per mesh axis, if ``pos_dir``
        is periodic or out of range, if the selected states have an
        unexpected number of dimensions, or if the result is not Hermitian.

    See Also
    --------
    :func:`pythtb.TBModel.position_matrix`

    Notes
    -----
    Compared to :func:`pythtb.TBModel.position_matrix`, the states are taken
    from this :class:`WFArray`: one specifies ``mesh_idx`` (mesh point of
    interest) and ``state_idx`` (states to include) instead of passing
    eigenvectors.
    """
    if isinstance(mesh_idx, (list, np.ndarray, tuple)):
        mesh_idx = tuple(mesh_idx)
    elif isinstance(mesh_idx, (int, np.integer)):
        mesh_idx = (mesh_idx,)
    else:
        raise TypeError(
            "mesh_idx must be a list, numpy array, or tuple defining "
            "k-point indices of interest."
        )
    if len(mesh_idx) != self.naxes:
        raise ValueError(
            f"mesh_idx must have length {self.naxes} corresponding to "
            "number of mesh axes."
        )

    # The position operator is only well-defined along non-periodic directions.
    if pos_dir in self.lattice.periodic_dirs:
        raise ValueError(
            "Can not compute position matrix elements along periodic direction!"
        )
    if pos_dir < 0 or pos_dir >= self.lattice.dim_r:
        raise ValueError("Direction out of range!")

    state_idx = self._normalize_state_indices(state_idx)
    evec = self.wfs[mesh_idx][state_idx]

    if not isinstance(evec, np.ndarray):
        raise TypeError("evec must be a numpy array.")
    if self.nspin == 1:
        if evec.ndim != 2:
            raise ValueError(
                "evec must be a 2D array with shape (band, orbital) for spinless models."
            )
    elif self.nspin == 2:
        if evec.ndim != 3:
            raise ValueError(
                "evec must be a 3D array with shape (band, orbital, spin) for spinful models."
            )

    # Reduced coordinate of every orbital along pos_dir.
    pos = np.asarray(self.lattice.orb_vecs)[:, pos_dir]
    if self.nspin == 2:
        # Flattened index is orbital * 2 + spin, so repeat each coordinate.
        pos = np.repeat(pos, 2)
        evec = evec.reshape((evec.shape[0], evec.shape[1] * evec.shape[2]))

    # X[m, n] = sum_o conj(evec[m, o]) * pos[o] * evec[n, o]
    pos_mat = np.asarray((evec.conj() * pos) @ evec.T, dtype=complex)

    if not np.allclose(pos_mat, pos_mat.T.conj()):
        raise ValueError("Position matrix is not Hermitian.")

    return pos_mat
def position_expectation(self, pos_dir: int, mesh_idx=None, state_idx=None):
    r"""Position expectation values for a given mesh point and set of states.

    These are the diagonal elements :math:`X_{n n}` of the position matrix.
    They can be interpreted as the average position of the n-th selected
    state along direction ``pos_dir``.

    Parameters
    ----------
    pos_dir : int
        Index of the real-space direction of the position operator. It must
        lie in ``[0, lattice.dim_r)`` and must not be one of the periodic
        directions (see :meth:`position_matrix`).

        .. versionchanged:: 2.0.0
            Renamed from ``dir`` to ``pos_dir`` to avoid conflict with the
            built-in Python function `dir()`.

    mesh_idx : array-like of int, optional
        Index of the :math:`(k, \lambda)`-point of interest in the mesh. If
        not specified, expectation values are computed at every mesh point.
    state_idx : array-like, optional
        States to include. If not specified, all states are included.

        .. versionchanged:: 2.0.0
            Renamed from ``occ``. The band indices are not required to be
            occupied bands only. The default behavior is to include all
            bands, and the ``"all"`` option has been removed.

    Returns
    -------
    pos_exp : np.ndarray
        Real diagonal of :math:`X`.

        - If ``mesh_idx`` is given, the result is 1D with one entry per
          selected state.
        - Otherwise its shape is ``(*shape_mesh, n_selected)``, where
          ``n_selected`` is the number of selected states.

    See Also
    --------
    :func:`pythtb.TBModel.position_expectation`
    :ref:`haldane-hwf-nb` : For an example.
    position_matrix : For definition of matrix :math:`X`.

    Notes
    -----
    The only difference from :func:`pythtb.TBModel.position_expectation` is
    that, in addition to ``pos_dir``, one specifies ``mesh_idx`` (mesh point
    of interest) and ``state_idx`` (states to be included). Generally these
    centers are *not* hybrid Wannier function centers (which are instead
    returned by :func:`position_hwf`).
    """

    def _diag_at(idx):
        # Real part of the diagonal of X at one mesh point.
        pos_mat = self.position_matrix(
            mesh_idx=idx, state_idx=state_idx, pos_dir=pos_dir
        )
        return np.array(np.real(pos_mat.diagonal()), dtype=float)

    if mesh_idx is not None:
        return _diag_at(mesh_idx)

    pos_exp = None
    for idx in np.ndindex(*self.shape_mesh):
        diag = _diag_at(idx)
        if pos_exp is None:
            # Size the last axis from the number of states actually selected.
            # It differs from ``nstates`` whenever state_idx picks a subset.
            pos_exp = np.zeros((*self.shape_mesh, diag.shape[0]), dtype=float)
        pos_exp[idx] = diag

    if pos_exp is None:
        # The mesh contains no points; return an empty array of the nominal shape.
        pos_exp = np.zeros((*self.shape_mesh, self.nstates), dtype=float)

    return pos_exp
def position_hwf(
    self,
    pos_dir,
    mesh_idx,
    state_idx=None,
    hwf_evec: bool = False,
    basis: str = "wavefunction",
):
    r"""Eigenvalues and eigenvectors of the position operator in a given basis.

    Parameters
    ----------
    pos_dir : int
        Index of the real-space direction of the position operator. It must
        lie in ``[0, lattice.dim_r)`` and must not be one of the periodic
        directions (see :meth:`position_matrix`).

        .. versionchanged:: 2.0.0
            Renamed from ``dir`` to ``pos_dir`` to avoid conflict with the
            built-in Python function `dir()`.

    mesh_idx : int or array-like of int
        Index of the point of interest in the mesh, one entry per mesh axis.
        A bare integer is accepted for single-axis meshes.
    state_idx : array-like, optional
        States to include. If not specified, all states are included.

        .. versionchanged:: 2.0.0
            Renamed from ``occ``. The band indices are not required to be
            occupied bands only. The default behavior is to include all
            bands, and the ``"all"`` option has been removed.

    hwf_evec : bool, optional
        Default is `False`. If `True`, return the eigenvectors along with the
        eigenvalues of the position operator.
    basis : {"orbital", "wavefunction", "bloch"}, optional
        Basis in which the eigenvectors are returned (case-insensitive,
        surrounding whitespace ignored). Only consulted when
        ``hwf_evec=True``.

    Returns
    -------
    hwfc : np.ndarray
        Eigenvalues of the position matrix :math:`X` (hybrid Wannier function
        centers), in ascending order, one per selected state. In general the
        ``n``-th center does not correspond to the ``n``-th selected state.
    hwf : np.ndarray, optional
        Eigenvectors of :math:`X` (hybrid Wannier functions), returned only if
        ``hwf_evec=True``. ``hwf[i]`` belongs to ``hwfc[i]``.

        - If ``basis`` is ``"wavefunction"`` or ``"bloch"``, the shape is
          ``[h, x]`` and ``x`` indexes the selected states.
        - If ``basis`` is ``"orbital"``, the shape is ``[h, orbital]`` for
          spinless models and ``[h, orbital, spin]`` for spinful ones.

    Raises
    ------
    ValueError
        If ``hwf_evec=True`` and ``basis`` is not one of the accepted values.

    See Also
    --------
    :ref:`haldane-hwf-nb` : For an example.
    position_matrix : For the definition of the matrix :math:`X`.
    position_expectation : For the position expectation value.
    :func:`pythtb.TBModel.position_hwf`

    Notes
    -----
    Similar to :func:`pythtb.TBModel.position_hwf`, except that in addition
    to ``pos_dir`` one specifies ``mesh_idx`` (mesh point of interest) and
    ``state_idx`` (states to be included). For backwards compatibility the
    default value of *basis* here differs from that in
    :func:`pythtb.TBModel.position_hwf`.
    """
    state_idx = self._normalize_state_indices(state_idx)

    # position_matrix also validates mesh_idx and pos_dir.
    pos_mat = self.position_matrix(
        mesh_idx=mesh_idx, state_idx=state_idx, pos_dir=pos_dir
    )

    if not hwf_evec:
        return np.linalg.eigvalsh(pos_mat)

    hwfc, hwf = np.linalg.eigh(pos_mat)
    # eigh stores eigenvectors as columns; transpose so that hwf[i] pairs
    # with hwfc[i].
    hwf = hwf.T

    basis_key = basis.lower().strip()
    if basis_key in ("wavefunction", "bloch"):
        return hwfc, hwf
    if basis_key != "orbital":
        raise ValueError(
            "Basis must be either 'wavefunction', 'bloch', or 'orbital'"
        )

    # Selected states at this mesh point. atleast_1d accepts a bare integer
    # index, just as position_matrix does.
    evec = self.wfs[tuple(np.atleast_1d(mesh_idx))][state_idx]
    n_hwf = hwf.shape[0]

    # Project onto the orbital basis: hwf_orb[i] = sum_n hwf[i, n] * evec[n].
    if self.nspin == 1:
        hwf_orb = hwf @ evec
    else:
        # Flatten (orbital, spin), project, then restore the spin axis.
        evec_flat = evec.reshape([n_hwf, self.norb * 2])
        hwf_orb = (hwf @ evec_flat).reshape([n_hwf, self.norb, 2])

    return hwfc, np.asarray(hwf_orb, dtype=complex)
def _trace_metric(self):
    """Per-k sum over nearest neighbours of ``w_b * Tr[P_k Q_{k+b}]``.

    ``P`` comes from ``self.projectors()`` and the neighbour complements
    ``Q`` from ``self._nbr_projectors(return_Q=True)``. Only the first-shell
    weight ``w_b[0]`` from ``lattice.k_shell_weights(..., n_shell=1)`` is
    used, which presumes a single uniform weight for the shell.

    Returns
    -------
    np.ndarray
        Complex array with one value per mesh k-point. The shape is that of
        the Q array with its last three axes dropped; the assumed Q layout
        is (*k_shape, n_nbr, n, n) — TODO confirm.
    """
    proj = self.projectors()
    _, q_nbr = self._nbr_projectors(return_Q=True)
    k_shape = q_nbr.shape[:-3]
    n_nbr = q_nbr.shape[-3]
    weights, _, _ = self.lattice.k_shell_weights(self.mesh.shape_k, n_shell=1)

    traces = np.zeros((*k_shape, n_nbr), dtype=complex)
    for b in range(n_nbr):
        # Tr[P_k Q_{k+b}] at every k-point simultaneously.
        product = proj @ q_nbr[..., b, :, :]
        traces[..., b] = np.trace(product, axis1=-1, axis2=-2)

    return weights[0] * traces.sum(axis=-1)

def _omega_til(self):
    """Gauge-dependent spread built from the overlap matrices ``self._Mmn``.

    The expression has the Marzari–Vanderbilt form
    ``Omega~ = (w_b / N_k) * [ sum (-Im ln M_nn - b . r_n)^2
    + sum |M_mn|^2 - sum |M_nn|^2 ]``. Here
    ``r_n = -(w_b / N_k) * sum_{k,b} Im ln M_nn * b``, and ``w_b`` / ``b``
    are the first-shell weight and vectors from
    ``lattice.k_shell_weights(..., n_shell=1)``.

    The overlaps are assumed to have shape (*k_shape, n_nbr, n, n), as
    implied by the ``shape[:-3]`` split and the diagonal over the last two
    axes — TODO confirm.

    Returns
    -------
    float
        Scalar spread summed over all k-points, neighbours and states.
    """
    overlaps = self._Mmn
    shell_weights, shell_vecs, _ = self.lattice.k_shell_weights(
        self.mesh.shape_k, n_shell=1
    )
    wb = shell_weights[0]
    bvecs = shell_vecs[0]

    k_shape = overlaps.shape[:-3]
    n_k = np.prod(k_shape)
    k_axes = tuple(range(len(k_shape)))

    # Diagonal overlaps M_nn(k, b): their phases and squared magnitudes.
    m_nn = np.diagonal(overlaps, axis1=-1, axis2=-2)
    phase_nn = np.log(m_nn).imag
    weight_nn = abs(m_nn) ** 2

    # Centers r_n, shape (n, dim).
    prefactor = -(1 / n_k) * wb
    centers = prefactor * np.sum(phase_nn, axis=k_axes).T @ bvecs

    diag_term = np.sum((-phase_nn - bvecs @ centers.T) ** 2)
    total = diag_term + np.sum(abs(overlaps) ** 2) - np.sum(weight_nn)
    return (1 / n_k) * wb * total
def _no_2pi(phi, ref):
    """Shift phase ``phi`` by an integer multiple of 2π so it lies within π of ``ref``.

    The result matches repeatedly adding (or subtracting) 2π while
    ``|ref - phi| > π``. It is computed in a single step, so large offsets
    need no iteration and an infinite difference cannot hang.

    Parameters
    ----------
    phi : float
        Phase to shift.
    ref : float
        Reference phase.

    Returns
    -------
    float
        ``phi + 2π n`` for the integer ``n`` that brings it within π of
        ``ref`` (up to rounding). If ``ref - phi`` is not finite, ``phi`` is
        returned unchanged.
    """
    diff = ref - phi
    if not np.isfinite(diff):
        # No finite number of 2π steps can shrink an inf/nan difference.
        return phi
    two_pi = 2.0 * np.pi
    if diff > np.pi:
        # Smallest n with diff - 2πn <= π.
        phi += two_pi * np.ceil((diff - np.pi) / two_pi)
    elif diff < -np.pi:
        # Smallest n with diff + 2πn >= -π.
        phi -= two_pi * np.ceil((-diff - np.pi) / two_pi)
    return phi


def _array_phases_cont(arr_pha, clos):
    """Enforce continuity (no 2π jumps) of a 2D phase array along its first axis.

    Row ``i`` is matched against row ``i - 1``, and row 0 against ``clos``.
    For each reference phase in turn, the not-yet-used entry of the current
    row that lies closest on the unit circle is picked greedily. It is then
    shifted by multiples of 2π to lie within π of the reference. Output
    columns therefore follow the closest-matching branch from row to row.

    Parameters
    ----------
    arr_pha : np.ndarray
        2D array of phases, indexed as ``arr_pha[row, column]``.
    clos : np.ndarray
        1D array of reference phases for the first row. It must not have
        more entries than ``arr_pha`` has columns.

    Returns
    -------
    np.ndarray
        Array with the same shape and dtype as ``arr_pha``.

    Raises
    ------
    ValueError
        If ``clos`` has more entries than ``arr_pha`` has columns.
    """
    ret = np.zeros_like(arr_pha)
    n_rows = arr_pha.shape[0]
    n_cols = arr_pha.shape[1]
    for i in range(n_rows):
        cmpr = clos if i == 0 else ret[i - 1, :]
        if cmpr.shape[0] > n_cols:
            raise ValueError(
                "clos has more reference phases than arr_pha has columns."
            )
        avail = list(range(n_cols))
        for j in range(cmpr.shape[0]):
            ref_unit = np.exp(1j * cmpr[j])
            best_k, min_dist = None, 1e10
            for k in avail:
                cur_dist = np.abs(ref_unit - np.exp(1j * arr_pha[i, k]))
                # "<=" keeps the last of several equally close candidates.
                if cur_dist <= min_dist:
                    min_dist = cur_dist
                    best_k = k
            avail.remove(best_k)
            ret[i, j] = _no_2pi(arr_pha[i, best_k], cmpr[j])
    return ret


def _one_phase_cont(pha, clos):
    """Remove 2π jumps from a 1D sequence of phases.

    Each element is shifted by multiples of 2π to lie within π of the
    previous (already corrected) element. The first element is compared
    against ``clos``.

    Parameters
    ----------
    pha : array-like
        1D sequence of phases; it is copied and not modified.
    clos : float
        Reference phase for the first element.

    Returns
    -------
    np.ndarray
        Corrected copy of ``pha``.
    """
    ret = np.copy(pha)
    prev = clos
    for i in range(len(ret)):
        ret[i] = _no_2pi(ret[i], prev)
        prev = ret[i]
    return ret