from .tbmodel import TBModel
from .mesh import Mesh
from .lattice import Lattice
import warnings
import functools
import logging
import copy
import numpy as np
from numpy.typing import ArrayLike
__all__ = ["WFArray"]
# Module-level logger; the boundary-sync and set_states diagnostics use it.
logger = logging.getLogger(__name__)
def deprecated(message: str, category=FutureWarning):
    """Build a decorator that flags a callable as deprecated.

    Each call to the decorated callable first emits a warning (of class
    ``category``, ``FutureWarning`` by default) naming the callable and
    appending ``message``. It then delegates to the original callable and
    returns its result unchanged.

    Parameters
    ----------
    message : str
        Explanation appended to the warning text (e.g. what to use instead).
    category : type[Warning], optional
        Warning class handed to :func:`warnings.warn`.

    Returns
    -------
    callable
        A decorator that wraps a function, preserving its metadata via
        :func:`functools.wraps`.
    """

    def mark(func):
        @functools.wraps(func)
        def warn_then_call(*args, **kwargs):
            text = (
                f"{func.__qualname__} is deprecated and will be removed "
                f"in a future release: {message}"
            )
            # stacklevel=2 attributes the warning to the caller, not this wrapper.
            warnings.warn(text, category=category, stacklevel=2)
            return func(*args, **kwargs)

        return warn_then_call

    return mark
[docs]
class WFArray:
r"""Wavefunction container defined on a sampling mesh.
A :class:`WFArray` stores states on a discrete mesh of k-points
and/or adiabatic parameters :math:`\lambda`. Once populated,
it can be queried for Berry connections, Berry curvature,
Chern numbers, and other derived quantities, or passed to :class:`Wannier`
when constructing Wannier functions or obtaining smooth gauges from
the projection method.
The underlying :class:`Mesh` may represent a full Monkhorst-Pack grid,
a lower-dimensional path, or even a mesh that contains only
parameter axes. In every case, :class:`WFArray` tracks the mesh layout,
the stored states, and the necessary phase conventions so downstream
utilities can consume the data consistently.
Wavefunctions stored in a :class:`WFArray` can be the Hamiltonian eigenstates
generated by :meth:`solve_model`, or any other :class:`Mesh`-defined states in the
orbital basis of the underlying :class:`Lattice` that match :attr:`spinful` setting.
Use :meth:`set_states` for bulk assignments, or the ``[...]`` indexer to
update the states at an individual mesh point. Periodic boundary conditions are
enforced automatically on assignment.
Parameters
----------
lattice : :class:`Lattice`
Lattice structure with orbital positions, periodic directions, and lattice vectors.
Use `model.lattice` to keep it consistent with a :class:`TBModel`.
mesh : :class:`Mesh`
Sampling grid in k/parameter space. Build it before constructing the :class:`WFArray`.
Its axes set the topology for enforcing periodic boundary conditions.
spinful : bool, optional
Whether the model includes spin degrees of freedom (defaults to False).
If True, each orbital stores two spin components. This affects the shape
of the stored wavefunction array.
nstates : int, optional
Number of bands per mesh point to store (defaults to ``WFArray.norb * WFArray.nspin``).
See Also
--------
:class:`pythtb.TBModel`
:class:`pythtb.Mesh`
:class:`pythtb.Wannier`
:ref:`formalism`
:ref:`haldane-bp-nb` : For an example of using :class:`WFArray` on a regular grid of points in k-space.
:ref:`cone-nb` : For an example of using :class:`WFArray` on a non-regular grid of points in k-space.
:ref:`three-site-thouless-nb` : For an example of using :class:`Mesh` with an adiabatic dimension.
This example shows how one of the directions of a :class:`WFArray` object need not be a k-vector direction,
but can instead be a Hamiltonian parameter :math:`\lambda`. See also discussion after equation 4.1 in
:ref:`formalism`.
:ref:`cubic-slab-hwf-nb` : For an example of using :class:`WFArray` to store hybrid Wannier functions.
Notes
-----
- Wavefunctions are always stored with mesh axes leading, followed by bands, orbital,
and (if present) spin indices. When setting the wavefunctions manually, ensure the input
array matches this convention.
- :class:`WFArray` cooperates with :class:`Wannier` to construct smooth Wannier gauges:
pass the populated array to ``Wannier(wfarray)`` and use
:meth:`Wannier.single_shot_projection`.
- Some features are only defined for regular grids and/or in the energy eigenstate gauge.
Check the documentation of individual methods for details.
Examples
--------
Populate Mesh with a uniform Monkhorst-Pack grid of k-points
>>> mesh = Mesh(['k', 'k'])
>>> mesh.build_grid(shape=(20, 20), gamma_centered=True)
>>> wfa = WFArray(lattice, mesh, spinful=True, nstates=4)
>>> wfa.shape
(20, 20, ...)
The WFArray is initially empty
>>> wfa.filled
False
Solve a :class:`TBModel` on the mesh and store the eigenstates
>>> wfa.solve_model(tb_model)
>>> wfa.energies.shape
(20, 20, ...)
Now we can use downstream functions, such as computing the Berry curvature on the grid
>>> curv = wfa.berry_curvature(non_abelian=False)
We can also store states from a finite model on a parameter sweep (no k-axes)
>>> mesh = Mesh(['l'])
>>> mesh.build_grid(shape=(101,), lambda_start=0.0, lambda_stop=2*np.pi)
If the parameter points in the mesh are adiabatic cycles, ensure the
boundary conditions are set correctly in the Mesh
>>> mesh.loop(axis_idx=0, loop_idx=0)
>>> wfa = WFArray(lattice, mesh)
>>> wfa.set_states(eigenvectors_lambda, is_cell_periodic=False)
>>> np.allclose(wfa[0], wfa[-1])  # states at the endpoints match due to adiabatic cycle
True
Even if we set the states manually with the indexer, periodic boundary conditions
are still enforced automatically using the topology of the Mesh
>>> wfa[-1] = eigvecs # shape (nstates, norb[, nspin])
>>> np.allclose(wfa[0], wfa[-1])  # states at the endpoints still match due to adiabatic cycle
True
"""
[docs]
@deprecated("Looping is handled automatically by Mesh.")
def impose_loop(self, mesh_dir: int):
    r"""Former loop-closing helper; it now always raises.

    .. versionremoved:: 2.0.0
        :meth:`impose_loop` has been removed. Looping is handled automatically by :class:`Mesh`.

    Raises
    ------
    NotImplementedError
        Always. The deprecation decorator emits a warning before this is raised.
    """
    reason = "Looping is handled automatically by Mesh."
    raise NotImplementedError(reason)
[docs]
@deprecated("Periodic boundary conditions are handled automatically by Mesh.")
def impose_pbc(self, mesh_dir: int, k_dir: int):
    r"""Former periodic-boundary helper; it now always raises.

    .. versionremoved:: 2.0.0
        :meth:`impose_pbc` has been removed.
        Periodic boundary conditions are handled automatically by :class:`Mesh`.

    Raises
    ------
    NotImplementedError
        Always. The deprecation decorator emits a warning before this is raised.
    """
    reason = "Periodic boundary conditions are handled automatically by Mesh."
    raise NotImplementedError(reason)
[docs]
@deprecated("Use `solve_model` instead.")
def solve_on_grid(self, start_k=None):
    r"""Former grid solver; it now always raises.

    .. versionremoved:: 2.0.0
        :meth:`solve_on_grid` has been removed. Use :meth:`solve_model` instead.

    Raises
    ------
    NotImplementedError
        Always. The deprecation decorator emits a warning before this is raised.
    """
    reason = "Use `solve_model` instead."
    raise NotImplementedError(reason)
[docs]
@deprecated("Use `solve_model` instead.")
def solve_on_one_point(self, kpt, mesh_indices):
    r"""Former single-point solver; it now always raises.

    .. versionremoved:: 2.0.0
        :meth:`solve_on_one_point` has been removed. Use :meth:`solve_model` instead.

    Raises
    ------
    NotImplementedError
        Always. The deprecation decorator emits a warning before this is raised.
    """
    reason = "Use `solve_model` instead."
    raise NotImplementedError(reason)
def __init__(
    self, lattice: Lattice, mesh: Mesh, nstates: int = None, spinful: bool = False
):
    """Validate the inputs and allocate the (not yet populated) state array.

    Parameters
    ----------
    lattice : Lattice
        Must have the same ``dim_k`` as ``mesh``.
    mesh : Mesh
        Must already be built (``mesh.filled``). Looping axes need at least
        two samples, and every axis needs at least one.
    nstates : int, optional
        Number of states per mesh point; must be a positive integer.
        Defaults to ``norb * nspin``.
    spinful : bool, optional
        Whether each orbital carries two spin components.

    Raises
    ------
    TypeError
        If ``lattice``, ``mesh``, ``spinful`` or ``nstates`` has the wrong type.
    ValueError
        If lattice and mesh disagree on ``dim_k``, the mesh is not built,
        an axis is too short, or ``nstates`` is not positive.
    """
    if not isinstance(lattice, Lattice):
        raise TypeError("lattice must be of type pythtb.Lattice")
    if not isinstance(mesh, Mesh):
        raise TypeError("mesh must be of type pythtb.Mesh")
    if not isinstance(spinful, bool):
        raise TypeError("Argument spinful must be a boolean.")
    if lattice.dim_k != mesh.dim_k:
        raise ValueError(
            f"Lattice dim_k ({lattice.dim_k}) does not match Mesh dim_k ({mesh.dim_k})"
        )
    if not mesh.filled:
        raise ValueError(
            "Mesh points are not initialized. Did you call build_grid/build_custom?"
        )

    self._lattice = lattice
    self._mesh = mesh

    axis_sizes = np.array(self.shape_mesh, dtype=int)
    loop_axes = [idx for idx, ax in enumerate(mesh.axes) if ax.is_loop]
    short_loops = [idx for idx in loop_axes if axis_sizes[idx] < 2]
    if short_loops:
        raise ValueError(
            "Looping mesh axes must have at least two samples "
            f"(axes {short_loops} are too short)."
        )
    if np.any(axis_sizes < 1):
        raise ValueError(
            "Dimension of WFArray object in each direction must be at least 1.\n"
            "Maybe you need to build the mesh first?"
        )

    self._spinful = spinful

    if nstates is None:
        # Default to the full orbital (x spin) dimension.
        self._nstates = self.norb * self.nspin
    else:
        # bool is an int subclass; reject it so nstates=True is not read as 1.
        if isinstance(nstates, bool) or not isinstance(nstates, (int, np.integer)):
            raise TypeError("Argument nstates must be an integer.")
        if nstates < 1:
            raise ValueError("Argument nstates must be a positive integer.")
        self._nstates = int(nstates)

    # wfs indexed by [mesh axes..., state, orb(, spin)]. Zero-filled so that
    # reading before population never exposes uninitialized memory.
    self._wfs = np.zeros(self.shape, dtype=complex)
    # energies indexed by [mesh axes..., state]; None until computed.
    self._energies = None
def __getitem__(self, index):
    """Return ``wfs[index]`` straight from the stored array (a view for basic indexing)."""
    stored = self._wfs
    return stored[index]
def __setitem__(self, index, value):
    """Store the states of a single mesh point and re-sync linked boundaries.

    Parameters
    ----------
    index : int or tuple of int
        Mesh coordinates of the point to overwrite (negative values allowed).
    value : list or np.ndarray
        States at that point, with shape ``(nstates, norb)`` (spinless) or
        ``(nstates, norb, 2)`` (spinful). For spinful arrays the flattened
        layout ``(nstates, 2 * norb)`` is also accepted.

    Raises
    ------
    TypeError
        If ``value`` is not a list or numpy array.
    ValueError
        If ``value`` does not have the per-point shape.
    """
    if not isinstance(value, (list, np.ndarray)):
        raise TypeError("Value must be a list or numpy array!")
    value = np.array(value, dtype=complex)

    # Shape of one mesh point: (nstates, norb[, nspin]).
    point_shape = tuple(self.shape[len(self.shape_mesh):])
    if self.nspin == 2:
        flat_shape = point_shape[:-2] + (self.norb * 2,)
        if value.shape == flat_shape:
            # (nstates, 2*norb) -> (nstates, norb, 2) in C order (spin fastest).
            value = value.reshape(point_shape)
        elif value.shape != point_shape:
            raise ValueError(
                "Value shape does not match expected shape for spinful model!"
            )
    elif value.shape != point_shape:
        raise ValueError("Incompatible shape for wavefunction!")

    self._wfs[index] = value
    # Copy/phase the written edge onto its periodic partner, then drop caches.
    self._sync_boundary_from_index(index)
    self._invalidate_caches()
    # NOTE(review): _psi_nk (stored by set_states) is not refreshed here, so
    # psi_nk may be stale after indexer writes — confirm whether intended.
def _check_state_indices(
self, state_idx: int | ArrayLike, return_indices: bool = False
) -> np.ndarray | None:
"""Validate state indices and return as a numpy array."""
# Normalize to numpy array
try:
state_idx = np.atleast_1d(state_idx).astype(int)
except Exception:
raise TypeError("state_idx must be an integer or array-like of integers.")
if state_idx.ndim != 1:
raise ValueError("State indices should be a one-dimensional array.")
if np.any(state_idx < 0) or np.any(state_idx >= self.nstates):
raise IndexError(
"One or more state indices are outside the range of the WFArray."
)
return state_idx if return_indices else None
def _normalize_state_indices(self, state_idx: int | ArrayLike | None) -> np.ndarray:
"""Validate state indices and return as a numpy array.
Differs from _check_state_indices by allowing None input,
which returns all indices.
"""
if state_idx is None:
state_idx = np.arange(self.nstates, dtype=int)
else:
state_idx = self._check_state_indices(state_idx, return_indices=True)
return state_idx
def _invalidate_caches(self):
    """Drop cached projector/overlap data derived from the current states.

    Removes whichever of ``_P``, ``_Q``, ``_P_nbr``, ``_Q_nbr`` and ``_Mmn``
    are present; missing ones are ignored.
    """
    cached_names = ("_P", "_Q", "_P_nbr", "_Q_nbr", "_Mmn")
    for name in cached_names:
        if not hasattr(self, name):
            continue
        delattr(self, name)
def _sync_boundary_from_index(self, index):
    """Update linked boundary points after assigning into the array.

    After ``self._wfs[index]`` has been written, propagate the new values to
    the mesh points that the mesh topology identifies with it. This applies
    to every mesh axis that both has an endpoint and is a loop
    (``axis.has_endpoint and axis.is_loop``) and whose written coordinate is
    the first (0) or last (``axis_len - 1``) sample:

    - If the axis winds the Brillouin zone (``axis.winds_bz``), the phase
      data from ``_collect_pbc_phase_info`` is passed to
      ``_apply_pbc_phase``, with ``from_first`` indicating which edge was
      written. The axis is skipped when no phase is returned.
    - Otherwise the written edge is copied to the opposite edge with
      ``_copy_edge`` (first -> last, or last -> first).

    Parameters
    ----------
    index : int, tuple/list of int, or np.ndarray
        Mesh coordinates of the written point. Each entry is converted with
        ``int`` and wrapped modulo the axis length, so negative indices map
        onto the corresponding edge. A bare int is only accepted when the
        mesh has exactly one axis.
        NOTE(review): slices/Ellipsis are not supported here (``int`` on them
        raises TypeError) — confirm callers only pass integer coordinates.
    """
    # Nothing to sync when the mesh has no axes.
    if self.naxes == 0:
        return
    if isinstance(index, np.ndarray):
        index = index.tolist()
    # Normalize the index into a tuple of integer coordinates, one per axis.
    if self.naxes == 1 and not isinstance(index, (tuple, list)):
        coords = (int(index),)
    else:
        coords = tuple(int(k) for k in index)
    # Wrap (possibly negative) coordinates into [0, axis_len).
    mesh_coords = []
    for ax_idx, k in enumerate(coords):
        size = self.shape_mesh[ax_idx]
        mesh_coords.append(k % size)
    for ax_idx, coord in enumerate(mesh_coords):
        axis = self.mesh.axes[ax_idx]
        # Only closed, looping axes link their first and last samples.
        if not (axis.has_endpoint and axis.is_loop):
            continue
        axis_len = self.shape_mesh[ax_idx]
        # Interior points have no partner on the opposite edge.
        if coord not in (0, axis_len - 1):
            continue
        if axis.winds_bz:
            # BZ-winding axis: the edges are related by the periodic-gauge
            # phase rather than being equal.
            phase, slc_first, slc_last, comps = self._collect_pbc_phase_info(ax_idx)
            if phase is None:
                continue
            from_first = coord == 0
            logger.debug(
                "Syncing PBC on mesh axis %d (%s) for k-components %s (%s edge).",
                ax_idx,
                axis,
                comps,
                "first" if from_first else "last",
            )
            self._apply_pbc_phase(phase, slc_first, slc_last, from_first=from_first)
        else:
            # Parameter loop: the two edges hold identical states.
            slc_first, slc_last = self._edge_slices(ax_idx)
            if coord == 0:
                logger.debug(
                    "Syncing loop boundary (first → last) on mesh axis %d (%s).",
                    ax_idx,
                    axis,
                )
                self._copy_edge(slc_first, slc_last)
            else:
                logger.debug(
                    "Syncing loop boundary (last → first) on mesh axis %d (%s).",
                    ax_idx,
                    axis,
                )
                self._copy_edge(slc_last, slc_first)
def _canonical_to_mesh_axes(self) -> tuple[int, ...]:
    """Permutation relating canonical axis order to the mesh's declared order.

    In canonical order all k-axes come first, followed by every non-``"k"``
    axis, and each group keeps its declared relative order. Entry ``i`` of
    the result is the canonical position of declared axis ``i``: k-axes are
    numbered from 0 and the remaining axes from ``len(mesh.shape_k)``.
    An empty tuple is returned when the mesh has no axes.
    """
    mesh_axes = self.mesh.axes
    if not mesh_axes:
        return ()
    next_k = 0
    next_other = len(self.mesh.shape_k)
    order: list[int] = []
    for axis in mesh_axes:
        if axis.type == "k":
            order.append(next_k)
            next_k += 1
        else:
            order.append(next_other)
            next_other += 1
    return tuple(order)
def _mesh_axes_to_canonical(self) -> tuple[int, ...]:
    """Inverse of :meth:`_canonical_to_mesh_axes`, as a tuple of Python ints.

    Computed with ``np.argsort`` of the forward permutation. When the forward
    map is a true permutation, entry ``j`` is the declared position of the
    axis that sits at canonical slot ``j``.
    """
    forward = self._canonical_to_mesh_axes()
    order = np.argsort(forward)
    return tuple(int(pos) for pos in order)
@property
def model(self) -> TBModel:
    """The :class:`TBModel` associated with the :class:`WFArray`.

    Returns
    -------
    TBModel
        The stored model. Only defined if the wavefunctions were computed
        by :meth:`solve_model`.

    Raises
    ------
    ValueError
        If no :class:`TBModel` is associated with this :class:`WFArray`.
    """
    missing = object()
    found = getattr(self, "_model", missing)
    if found is missing:
        raise ValueError(
            "No TBModel is associated with this WFArray. "
            "Did you compute the wavefunctions using solve_model?"
        )
    return found
@property
def lattice(self) -> Lattice:
    """The :class:`Lattice` this :class:`WFArray` was constructed with.

    Returns
    -------
    Lattice
        The same object that was passed to the constructor.
    """
    lat = self._lattice
    return lat
@property
def mesh(self) -> Mesh:
    """The :class:`Mesh` this :class:`WFArray` was constructed with.

    Returns
    -------
    Mesh
        The same object that was passed to the constructor.
    """
    grid = self._mesh
    return grid
@property
def filled(self) -> bool:
    """Whether the wavefunction array holds any elements.

    Returns
    -------
    bool
        True when the internal state array ``_wfs`` has a nonzero size.

    Notes
    -----
    NOTE(review): ``__init__`` pre-allocates ``_wfs`` with the full
    :attr:`shape`, so this already reports True right after construction
    (whenever every dimension is positive), before any states are stored.
    Confirm whether an explicit "populated" flag is intended.
    """
    return bool(self._wfs.size)
@property
def wfs(self) -> np.ndarray:
    """The stored wavefunction array (returned directly, not copied).

    For meshes with k-axes these are the cell-periodic states, i.e. the
    Bloch states with their plane-wave phase factors removed.

    Returns
    -------
    np.ndarray
        Shape ``(*shape_mesh, nstates, norb[, nspin])``. The trailing spin
        axis is present only for spinful arrays.
    """
    stored = self._wfs
    return stored
@property
def u_nk(self) -> np.ndarray:
    r"""The cell-periodic wavefunctions.

    These are the :math:`|u_{n\mathbf{k}}\rangle` states without the plane-wave
    phase factors. For meshes with k-axes the primary storage :attr:`wfs`
    holds exactly these states (``set_states`` stores the cell-periodic
    array as both ``_wfs`` and ``_u_nk``), so the primary array is returned.
    As a result, :attr:`u_nk` stays in sync with writes made through the
    ``[...]`` indexer and after :meth:`remove_states`, and it is never
    ``None`` for a filled array.

    Returns
    -------
    np.ndarray
        The cell-periodic wavefunctions.
        Shape is ``(*shape_mesh, nstates, norb[, nspin])``.

    Raises
    ------
    ValueError
        If the wavefunctions are not initialized or if k-axes are not present in the mesh.
    """
    if not self.filled:
        raise ValueError("Wavefunctions are not initialized.")
    if self.dim_k == 0:
        raise ValueError(
            "Cell-periodic wavefunctions are not defined for 0D k-space."
        )
    return self._wfs
@property
def psi_nk(self) -> np.ndarray:
    r"""The Bloch wavefunctions.

    These are the :math:`|\psi_{n\mathbf{k}}\rangle` states, including the
    plane-wave phase factors.

    Returns
    -------
    np.ndarray or None
        The Bloch wavefunctions, with shape ``(*shape_mesh, nstates, norb[, nspin])``.
        ``None`` if no Bloch states have been stored on this object.
        NOTE(review): writes through the ``[...]`` indexer do not refresh
        this array — confirm callers do not rely on it after such writes.

    Raises
    ------
    ValueError
        If the wavefunctions are not initialized or if k-axes are not present in the mesh.
    """
    problem = None
    if not self.filled:
        problem = "Wavefunctions are not initialized."
    elif self.dim_k == 0:
        problem = "Bloch wavefunctions are not defined for 0D k-space."
    if problem is not None:
        raise ValueError(problem)
    return getattr(self, "_psi_nk", None)
@property
def Mmn(self) -> np.ndarray:
    r"""Overlap matrices between states at neighbouring k-points (cached).

    The overlap matrix is defined as

    .. math::

        M_{mn}^{(\mathbf{b})}(\mathbf{k}) = \langle u_{m,\mathbf{k}} | u_{n,\mathbf{k}+\mathbf{b}} \rangle

    where :math:`\mathbf{b}` connects nearest-neighbour k-points of the mesh,
    with the neighbours determined in Cartesian space. The first access calls
    :meth:`overlap_matrix` with ``use_k_metric=True`` and caches the result
    in ``_Mmn``; later accesses return the cached array.

    Returns
    -------
    np.ndarray
        Shape ``(*shape_mesh, nnbrs, nstates, nstates)``.

    Raises
    ------
    ValueError
        If the wavefunctions are not initialized, the mesh is not a regular
        grid, or the mesh has no k-axes.

    Notes
    -----
    - Only defined for regular grids in k-space.
    - To build overlaps from reduced (non-Cartesian) neighbours, call
      :meth:`overlap_matrix` with ``use_k_metric=False``.
    """
    problem = None
    if not self.filled:
        problem = "Wavefunctions are not initialized."
    elif not self.mesh.is_grid:
        problem = "Overlap matrix is only defined for regular grids."
    elif self.dim_k == 0:
        problem = "Overlap matrix is not defined for 0D k-space."
    if problem is not None:
        raise ValueError(problem)
    if hasattr(self, "_Mmn"):
        return self._Mmn
    self._Mmn = self.overlap_matrix(use_k_metric=True)
    return self._Mmn
@property
def energies(self) -> np.ndarray:
    """The eigenenergies stored alongside the states.

    Returns
    -------
    np.ndarray
        Shape ``(*shape_mesh, nstates)``. Only available after the energies
        have been computed with :meth:`solve_model`.

    Raises
    ------
    ValueError
        If the wavefunctions are not initialized or no energies are stored.
    """
    if not self.filled:
        raise ValueError("Wavefunctions are not initialized.")
    stored = self._energies
    if stored is None:
        raise ValueError(
            "Energies are not initialized. Use `solve_model` to compute them."
        )
    return stored
@property
def hamiltonian(self) -> np.ndarray:
    r"""The Hamiltonian sampled on the :class:`Mesh`, if one is stored.

    Returns
    -------
    np.ndarray or None
        Shape ``(*shape_mesh, norb[, nspin], norb[, nspin])``, or ``None``
        when no Hamiltonian has been stored on this object.
    """
    try:
        return self._H
    except AttributeError:
        return None
@property
def shape(self) -> tuple:
    """The shape of the state array.

    Returns
    -------
    tuple of int
        ``(*shape_mesh, nstates, norb)``, followed by ``nspin`` for spinful
        arrays. Entries are plain Python ``int`` values (not NumPy scalars),
        so the tuple prints cleanly and compares equal to ``ndarray.shape``.
    """
    dims = tuple(int(n) for n in self.shape_mesh)
    dims += (int(self.nstates), int(self.norb))
    if self.nspin == 2:
        dims += (int(self.nspin),)
    return dims
@property
def nstates(self) -> int:
    """Number of states (bands) stored at each mesh point.

    Returns
    -------
    int
    """
    count = self._nstates
    return count
@property
def nspin(self) -> int:
    """Number of spin components per orbital.

    Returns
    -------
    int
        2 for a spinful :class:`WFArray`, otherwise 1.
    """
    if self.spinful:
        return 2
    return 1
@property
def spinful(self) -> bool:
    """Whether each orbital of the :class:`WFArray` carries spin components.

    Returns
    -------
    bool
    """
    flag = self._spinful
    return flag
@property
def norb(self) -> int:
    """Number of orbitals, taken from the underlying :class:`Lattice`.

    Returns
    -------
    int
    """
    lattice = self.lattice
    return lattice.norb
@property
def shape_mesh(self) -> tuple:
    """Number of samples along each axis of the :class:`Mesh`.

    Returns
    -------
    tuple of int
        Corresponds to :attr:`Mesh.shape_axes`.
    """
    grid = self.mesh
    return grid.shape_axes
@property
def dim_k(self) -> int:
    """Dimension of k-space.

    Read from the :class:`Lattice`; the constructor requires it to match
    the :class:`Mesh`.

    Returns
    -------
    int
    """
    lattice = self.lattice
    return lattice.dim_k
@property
def dim_lambda(self) -> int:
    """Number of parameter (lambda) dimensions in the :class:`Mesh`.

    Returns
    -------
    int
    """
    grid = self.mesh
    return grid.dim_lambda
@property
def naxes(self) -> int:
    """Total number of axes (k and parameter) in the :class:`Mesh`.

    Returns
    -------
    int
    """
    grid = self.mesh
    return grid.naxes
@property
def k_points(self) -> np.ndarray:
    """The k-points of the :class:`Mesh`.

    Returns
    -------
    np.ndarray
        Whatever :meth:`Mesh.get_k_points` returns.
    """
    grid = self.mesh
    return grid.get_k_points()
@property
def param_points(self) -> np.ndarray:
    """The parameter (lambda) points of the :class:`Mesh`.

    Returns
    -------
    np.ndarray
        Whatever :meth:`Mesh.get_param_points` returns.
    """
    grid = self.mesh
    return grid.get_param_points()
[docs]
def empty_like(self, nstates: int = None) -> "WFArray":
    r"""Create a new, unpopulated :class:`WFArray` with the same :class:`Lattice` and :class:`Mesh`.

    Parameters
    ----------
    nstates : int, optional
        Number of states for the new :class:`WFArray`.
        If None (default), the current :attr:`nstates` of this object is used.

        .. versionchanged:: 2.0.0
            Renamed from ``nsta_arr`` for consistency with initialization.

    Returns
    -------
    WFArray
        A new object sharing this object's lattice, mesh and spin setting.
        No states or energies are copied.
    """
    # Passing None straight through would make __init__ fall back to
    # norb * nspin, contradicting the documented "current nstates" default.
    if nstates is None:
        nstates = self.nstates
    return WFArray(self.lattice, self.mesh, nstates=nstates, spinful=self.spinful)
[docs]
def copy(self) -> "WFArray":
    r"""Return an independent deep copy of this :class:`WFArray`.

    .. versionadded:: 2.0.0

    Returns
    -------
    WFArray
        Produced by :func:`copy.deepcopy`, so the stored arrays and other
        attributes are duplicated rather than shared (unless a referenced
        object customises its own deep-copy behaviour).
    """
    duplicate = copy.deepcopy(self)
    return duplicate
[docs]
def set_states(
    self, wfs, is_cell_periodic: bool = True, is_spin_axis_flat: bool = False
):
    r"""Populate the wavefunction array.

    The input can be either the cell-periodic states or the full Bloch
    states, depending on ``is_cell_periodic``. For meshes with k-axes the
    other representation is derived with the plane-wave phase factors.

    .. versionadded:: 2.0.0

    Parameters
    ----------
    wfs : np.ndarray
        Wavefunctions with shape ``(*shape_mesh, nstates, norb[, nspin])``,
        or ``(*shape_mesh, nstates, norb * nspin)`` for a spinful array when
        ``is_spin_axis_flat`` is True. The data are copied and cast to
        ``complex``; the caller's array is never stored or modified.
    is_cell_periodic : bool, optional
        If True (default), ``wfs`` are the cell-periodic states
        :math:`u_{n\mathbf{k}}`; otherwise they are the full Bloch states.
    is_spin_axis_flat : bool, optional
        If True, the spin and orbital indices of ``wfs`` have been flattened
        into a single index. Default is False. Only relevant when
        :attr:`spinful` is True.

    Raises
    ------
    TypeError
        If ``wfs`` is not a numpy array.
    ValueError
        If ``wfs`` does not have the expected shape.

    Notes
    -----
    - On a k-mesh both the Bloch and the cell-periodic states are stored.
      Without k-axes only :attr:`wfs` is set, and ``is_cell_periodic`` only
      triggers a logged warning when False.

    .. warning::
        Make sure the wavefunctions are consistent with the :attr:`mesh`
        and :attr:`lattice`.
    """
    if not isinstance(wfs, np.ndarray):
        raise TypeError("wfs must be a numpy ndarray.")

    # Per-point trailing shape accepted for the input.
    point_tail = (self.nstates, self.norb)
    if self.nspin == 2:
        if is_spin_axis_flat:
            point_tail = (self.nstates, self.norb * self.nspin)
        else:
            point_tail = (self.nstates, self.norb, self.nspin)
    expected_shape = tuple(self.shape_mesh) + point_tail
    if wfs.shape != expected_shape:
        raise ValueError(
            f"wfs shape {wfs.shape} does not match expected shape: {expected_shape}"
        )

    # Copy + cast. Storing the caller's array (or a reshape view of it) would
    # let later in-place updates (indexer writes, PBC enforcement) mutate it,
    # and a real dtype would silently drop imaginary parts on later writes.
    wfs = np.array(wfs, dtype=complex).reshape(self.shape)

    if self.dim_k > 0:
        # Convert between Bloch and cell-periodic forms with the phase factors.
        if is_cell_periodic:
            u_nk = wfs
            psi_nk = wfs * self._get_phases(inverse=False)
        else:
            psi_nk = wfs
            u_nk = wfs * self._get_phases(inverse=True)
        self._u_nk = self._wfs = u_nk
        self._psi_nk = psi_nk
    else:
        if not is_cell_periodic:
            logger.warning(
                "Setting non-cell-periodic wavefunctions for 0D k-space."
            )
        self._wfs = wfs

    self._enforce_pbc()
    self._invalidate_caches()
[docs]
def remove_states(self, state_idx: int | ArrayLike):
r"""Remove specified states from the :class:`WFArray`.
.. versionadded:: 2.0.0
Parameters
----------
state_idx : int or array-like of int
Indices of the states to remove.
Notes
-----
- This modifies the shape of the :attr:`wfs`, :attr:`energies`,
:attr:`u_nk` and :attr:`psi_nk` arrays.
- The indices in ``state_idx`` refer to the current ordering of states.
After removal, the remaining states are re-indexed accordingly.
Examples
--------
Remove states 0 and 2 from the :class:`WFArray`
>>> wf.remove_states([0, 2])
"""
if self.nspin == 2:
state_ax = -3
elif self.nspin == 1:
state_ax = -2
else:
raise ValueError(
"WFArray object can only handle spinless or spin-1/2 models."
)
state_idx = self._check_state_indices(state_idx, return_indices=True)
n_states = len(state_idx)
self._wfs = np.delete(self._wfs, state_idx, axis=state_ax)
self._nstates -= n_states
self._energies = np.delete(self._energies, state_idx, axis=-1)
if getattr(self, "_u_nk", None) is not None:
self._u_nk = np.delete(self._u_nk, state_idx, axis=state_ax)
if getattr(self, "_psi_nk", None) is not None:
self._psi_nk = np.delete(self._psi_nk, state_idx, axis=state_ax)
[docs]
def choose_states(self, state_idx: int | ArrayLike):
r"""Keep only the specified states in the :class:`WFArray`.
This method modifies the existing states in place to keep only
those specified by ``state_idx``.
Parameters
----------
state_idx : int or array-like of int
Indices of states to keep.
.. versionchanged:: 2.0.0
Renamed from ``subset`` for consistency.
Notes
-----
- This modifies the shape of the :attr:`wfs`, :attr:`energies`,
:attr:`u_nk` and :attr:`psi_nk` arrays.
- The indices in ``state_idx`` refer to the current ordering of states.
After removal, the remaining states are re-indexed accordingly.
Examples
--------
Keep only states 3 and 5 from the :class:`WFArray`
>>> wf.choose_states([3, 5])
"""
state_idx = self._check_state_indices(state_idx, return_indices=True)
remove_indices = np.setdiff1d(np.arange(self.nstates), state_idx)
self.remove_states(remove_indices)
[docs]
def states(
self,
state_idx: ArrayLike | None = None,
flatten_spin_axis: bool = False,
return_psi: bool = False,
) -> np.ndarray:
r"""Return states stored in the *WFArray* object.
The states are returned in the same ordering as stored internally,
with mesh axes leading, followed by band, orbital, and (if present)
spin indices. By default, all states are returned. The user can
specify a subset of states to return using the ``state_idx`` argument.
.. versionadded:: 2.0.0
Parameters
----------
state_idx : int or array-like of int, optional
Index or indices of the states to return. If not provided or None,
all states are returned.
flatten_spin_axis : bool, optional
If True, the spin and orbital indices are flattened into a single index.
Default is False.
return_psi : bool, optional
If True, the function also returns the full Bloch wavefunctions. This should
only be requested when k-axes are present in the mesh and ``dim_k > 0``, otherwise
an error is raised. Default is False.
Returns
-------
u : np.ndarray
The states stored in the *WFArray* object. By default, these are the
cell-periodic states when ``dim_k > 0``. The shape is
``(nk1, nk2,..., nl1, nl2,..., nstate, norb[,nspin])``
If ``flatten_spin_axis=True``, the last two axes are replaced by a single
axis of size ``norb*nspin``.
psi : np.ndarray, optional
Bloch states with the same shape conventions as ``wfs``. These states are
related to the cell-periodic states by plane-wave phase factors. Only
returned if ``return_psi=True``.
See Also
--------
:ref:`formalism`
"""
if return_psi and self.dim_k == 0:
raise ValueError("Bloch states are not defined for 0D k-space.")
u = np.copy(self.wfs)
psi = None if not return_psi else np.copy(self.psi_nk)
state_idx = self._normalize_state_indices(state_idx)
# select requested states
sl = (
(..., state_idx, slice(None), slice(None))
if self.nspin == 2
else (..., state_idx, slice(None))
)
u = u[sl]
if psi is not None:
psi = psi[sl]
if flatten_spin_axis and self.nspin == 2:
u = u.reshape((*u.shape[:-2], -1))
if psi is not None:
psi = psi.reshape((*psi.shape[:-2], -1))
return (u, psi) if return_psi else u
def _nbr_projectors(self, return_Q: bool = False):
    """Band projectors built from the states at each nearest-neighbour k-point.

    For every first-shell neighbour shift returned by ``lattice.nn_k_shell``,
    the states are rolled with :meth:`roll_states_with_pbc` and the projector
    ``P = u^T u^*`` (summed over states) is formed at every mesh point.
    Results are cached in ``_P_nbr`` / ``_Q_nbr``.

    Parameters
    ----------
    return_Q : bool, optional
        Also return the complementary projectors ``Q = I - P``.

    Returns
    -------
    P_nbr : np.ndarray
        Shape ``(*shape_mesh, n_nbrs, D, D)`` with ``D = norb * nspin``.
    Q_nbr : np.ndarray, optional
        Only returned if ``return_Q`` is True.

    Raises
    ------
    NotImplementedError
        For 0D k-space or non-grid meshes.
    """
    if self.dim_k == 0:
        raise NotImplementedError(
            "Nearest neighbor projectors are not defined for 0D k-space."
        )
    if not self.mesh.is_grid:
        raise NotImplementedError(
            "Mesh must be a grid to compute nearest neighbor projectors."
        )

    # Fast path: reuse cached neighbour projectors. Q is derived (and cached)
    # on demand, so return_Q=True always yields a (P, Q) pair even when an
    # earlier call only cached P.
    if hasattr(self, "_P_nbr"):
        if not return_Q:
            return self._P_nbr
        if not hasattr(self, "_Q_nbr"):
            self._Q_nbr = np.eye(self._P_nbr.shape[-1]) - self._P_nbr
        return self._P_nbr, self._Q_nbr

    # On-site projectors fix the output layout (this also caches self._P).
    P_onsite = self.projectors(return_Q=False)

    # Nearest neighbor shifts (first shell).
    _, nnbr_idx_shell = self.lattice.nn_k_shell(
        self.mesh.shape_k, n_shell=1, report=False
    )
    shifts = nnbr_idx_shell[0]
    num_nnbrs = shifts.shape[0]

    P_nbr = np.zeros(
        P_onsite.shape[:-2] + (num_nnbrs,) + P_onsite.shape[-2:], dtype=complex
    )
    for nbr, idx_vec in enumerate(shifts):
        u_shifted = self.roll_states_with_pbc(idx_vec, flatten_spin_axis=True)
        P_nbr[..., nbr, :, :] = np.matmul(u_shifted.swapaxes(-2, -1), u_shifted.conj())
    self._P_nbr = P_nbr

    if not return_Q:
        return P_nbr
    Q_nbr = np.eye(P_nbr.shape[-1]) - P_nbr
    self._Q_nbr = Q_nbr
    return P_nbr, Q_nbr
[docs]
def projectors(
self, state_idx: int | ArrayLike = None, return_Q: bool = False
) -> np.ndarray | tuple[np.ndarray, np.ndarray]:
r"""Returns the band projectors associated with the states in the WFArray.
The band projectors are defined as the outer product of the wavefunctions:
.. math::
P_{n\mathbf{k}} = \lvert u_{n\mathbf{k}}(\mathbf{r})\rangle \langle u_{n\mathbf{k}}(\mathbf{r}) \rvert,
\quad Q_{n\mathbf{k}} = \mathbb{I} - P_{n\mathbf{k}}
.. versionadded:: 2.0.0
Parameters
----------
state_idx : int or array-like of int, optional
Index or indices of the states for which to compute the projectors.
If not provided or None, projectors for all states are computed.
return_Q : bool, optional
If True, the function also returns the orthogonal projector Q.
Returns
-------
P : np.ndarray
The band projectors.
Q : np.ndarray, optional
The orthogonal projectors.
"""
# Check cache
if state_idx is None and hasattr(self, "_P"):
return (self._P, self._Q) if return_Q else self._P
# Compute states
u_nk = self.states(flatten_spin_axis=True)
if state_idx is not None:
u_nk = u_nk[..., state_idx, :]
# Compute projectors
P = np.matmul(u_nk.swapaxes(-2, -1), u_nk.conj())
# Cache full projectors
if state_idx is None:
Q = np.eye(P.shape[-1]) - P
self._P, self._Q = P, Q
else:
Q = None
return (P, Q) if return_Q else P
[docs]
def solve_model(self, model: TBModel, use_tensorflow: bool = False):
    r"""Diagonalize ``model`` on every point of the internal :class:`Mesh`.

    The method calls :meth:`TBModel.solve_ham` with the k-points and model
    parameters defined in the :class:`Mesh`, and populates the
    :class:`WFArray` with the eigenstates and eigenenergies of the Hamiltonian.

    .. note::
        For meshes that include :math:`\lambda`-axes, the axis names are
        interpreted as :class:`TBModel` parameter names. The names and values
        along each :math:`\lambda`-axis are passed as keyword arguments to
        :meth:`TBModel.solve_ham`. These parameter names must match those used
        in the model definition with :meth:`TBModel.set_onsite` and
        :meth:`TBModel.set_hop`.

    .. versionadded:: 2.0.0
        Replaces :meth:`solve_on_one_point` and :meth:`solve_on_grid`.

    Parameters
    ----------
    model : :class:`TBModel`
        The tight-binding model to diagonalize on the mesh. Its
        :class:`Lattice` and ``spinful`` setting must match those of the
        :class:`WFArray` (only ``spinful`` is checked here).
    use_tensorflow : bool, optional
        If True, uses TensorFlow for diagonalization. This can be beneficial
        for large systems where GPU acceleration is available. Requires
        TensorFlow to be installed. Default is False.

    Raises
    ------
    ValueError
        If the ``spinful`` settings differ, or if ``solve_ham`` returns a
        number of bands different from ``self.nstates``.

    Notes
    -----
    - The samples along each :math:`\lambda`-axis are obtained from
      :func:`Mesh.get_axis_range` and passed to :meth:`TBModel.solve_ham` as
      keyword arguments, so the mesh axis names must exactly match the
      symbolic/callable parameter names in the model.
    - Eigenstates stored by ``solve_model`` are chosen to obey a periodic
      gauge :math:`\psi_{n,{\bf k+G}}=\psi_{n {\bf k}}`, so the cell-periodic
      states satisfy :math:`u_{n,{\bf k_0+G}}=e^{-i{\bf G}\cdot{\bf r}} u_{n {\bf k_0}}`.
      See :ref:`formalism` section 4.4 and equation 4.18 for more detail.
    - If the mesh includes Brillouin zone boundary points along any k-axis,
      or a looping axis is marked as both Brillouin zone winding and closed,
      ``solve_model`` applies the periodic-gauge phase to the states along
      that axis.

    Examples
    --------
    Say we have a parametric model defined as follows:

    >>> model = TBModel(lattice=lat, spinful=False)
    >>> model.set_onsite(['param_1'], mode='set')

    The model is 2D in k-space for this example. We want to vary
    ``param_1`` and store the energy eigenvalues and eigenstates on a 3D
    mesh: 2 dimensions in k-space and 1 dimension for ``param_1``. This
    means we need a :class:`Mesh` with 2 k-axes and 1 lambda-axis named
    ``param_1``:

    >>> mesh = Mesh(
    ...     dim_k=2,
    ...     dim_lambda=1,
    ...     axis_types=['k', 'k', 'l'],
    ...     axis_names=['k1', 'k2', 'param_1']
    ... )

    The last axis must be named ``param_1`` to match the name used in the
    model. We build the mesh values with ``build_grid``: a uniform 2D grid
    in k-space of shape ``(20, 20)`` and a uniform 1D grid for ``param_1``
    with 5 points from 0 to :math:`2\pi`.

    >>> mesh.build_grid(shape=(20, 20, 5), lambda_start=0, lambda_end=2*np.pi)

    To initialize the *WFArray*, pass the same lattice as the model and the
    mesh just created. For a spinful model, set ``spinful=True``.

    >>> wfa = WFArray(model.lattice, mesh)

    The parameter values are now stored in the mesh, and we can diagonalize
    the model on the mesh:

    >>> wfa.solve_model(model=model)

    The eigenvalues and eigenstates are stored in the ``.energies`` and
    ``.wfs`` attributes, respectively:

    >>> wfa.energies.shape
    (20, 20, 5, 2)
    >>> wfa.wfs.shape
    (20, 20, 5, 2, 2)
    """
    if self.spinful != model.spinful:
        raise ValueError("Spinful setting of WFArray does not match the model.")
    # lambda-parameter dict: axis name -> sampled values along that axis
    params = {
        ax.name: self.mesh.get_axis_range(i, j)
        for ax, i, j in zip(
            self.mesh.lambda_axes,
            self.mesh.lambda_axis_indices,
            self.mesh.lambda_component_indices,
        )
    }
    # k-points (flatten k-grid only)
    k_flat = self.k_points.reshape(-1, self.dim_k) if self.dim_k else None
    eigvals, eigvecs = model.solve_ham(
        k_pts=k_flat,
        return_eigvecs=True,
        flatten_spin_axis=True,
        use_tensorflow=use_tensorflow,
        **params,
    )
    n_state_returned = eigvals.shape[-1] if eigvals.ndim else 1
    if n_state_returned != self.nstates:
        raise ValueError(
            "solve_model expected "
            f"{self.nstates} bands per mesh point, but TBModel.solve_ham "
            f"returned {n_state_returned}. Recreate the WFArray with "
            f"nstates={n_state_returned} or adjust the model so the band "
            "count matches."
        )
    # Reshape from TBModel's canonical (k-first, lambda-second) ordering to
    # the user-defined mesh ordering.
    shape_k = self.mesh.shape_k
    shape_lambda = self.mesh.shape_lambda
    canonical_mesh_shape = shape_k + shape_lambda
    axes_total = len(self.mesh.axes)
    state_shape = self.shape[axes_total:]
    eigvals = eigvals.reshape(*canonical_mesh_shape, self.nstates)
    eigvecs = eigvecs.reshape(*canonical_mesh_shape, *state_shape)
    axis_perm = self._canonical_to_mesh_axes()
    if axis_perm:
        axes_for_vals = axis_perm + (axes_total,)
        eigvals = np.transpose(eigvals, axes_for_vals)
        trailing_axes = tuple(range(axes_total, axes_total + len(state_shape)))
        eigvecs = np.transpose(eigvecs, axis_perm + trailing_axes)
    eigvals = eigvals.reshape(*self.shape_mesh, self.nstates)
    eigvecs = eigvecs.reshape(*self.shape)
    self.set_states(eigvecs, is_cell_periodic=True, is_spin_axis_flat=False)
    self._energies = eigvals
    self._model = model
    # Smallest gap between adjacent bands over the whole mesh
    self.gaps = (
        (eigvals[..., 1:] - eigvals[..., :-1]).min(axis=tuple(range(self.naxes)))
        if self.nstates > 1
        else None
    )
    # Enforce PBCs along winding directions
    self._enforce_pbc()
def _get_phases(self, inverse=False):
    r"""Phase factors converting between cell-periodic and Bloch states.

    For every k-point and orbital this evaluates
    :math:`e^{\pm 2\pi i\, \mathbf{k}\cdot\boldsymbol{\tau}}` in reduced
    units, where :math:`\boldsymbol{\tau}` is the orbital position restricted
    to the lattice's periodic directions.

    Parameters
    ----------
    inverse : bool, optional
        If False (default), use the ``+`` sign (cell-periodic -> Bloch).
        If True, use the ``-`` sign (Bloch -> cell-periodic).

    Returns
    -------
    phases : np.ndarray
        Complex array whose leading axes follow the mesh axis order
        (k-axes have their full extent, lambda-axes have length 1), followed
        by a length-1 state axis and the orbital axis. When ``nspin == 2``
        a trailing length-1 spin axis is appended for broadcasting.
    """
    sign = -1 if inverse else 1
    kpts = self.k_points.reshape(-1, self.dim_k)  # (Nk, dim_k)
    dirs = np.asarray(self.lattice.periodic_dirs, int)
    orb_periodic = self.lattice.orb_vecs[:, dirs]  # (norb, dim_k)
    flat = np.exp(sign * 1j * 2.0 * np.pi * (kpts @ orb_periodic.T))  # (Nk, norb)

    n_mesh_axes = len(self.mesh.axes)
    # Canonical layout: k-axes first, singleton lambda axes, state axis, orbitals.
    target_shape = (*self.mesh.shape_k, *([1] * self.mesh.nl_axes), 1, self.norb)
    out = flat.reshape(target_shape)

    perm = self._canonical_to_mesh_axes()
    if perm:
        out = out.transpose(perm + (n_mesh_axes, n_mesh_axes + 1))

    return out[..., np.newaxis] if self.nspin == 2 else out
def _enforce_pbc(self):
    r"""Enforce boundary conditions on every closed loop axis of the mesh.

    Nothing is done until the array is filled. For each mesh axis that both
    includes its endpoint and is a loop:

    - if the axis winds the Brillouin zone, the last slice is overwritten by
      the first slice times the combined PBC phase returned by
      ``_collect_pbc_phase_info`` (skipped when that phase is None);
    - otherwise the last slice is made equal to the first via
      ``_impose_loop``.
    """
    if not self.filled:
        return
    for idx, ax in enumerate(self.mesh.axes):
        if not (ax.has_endpoint and ax.is_loop):
            continue
        if not ax.winds_bz:
            logger.debug(
                f"Imposing loop in mesh direction {idx} ({ax}) without BZ winding."
            )
            self._impose_loop(idx)
            continue
        # The helper determines which k-components wind the BZ (and have
        # endpoints) and multiplies their phases together.
        phase_total, slc_first, slc_last, comps = self._collect_pbc_phase_info(idx)
        if phase_total is None:
            continue
        logger.debug(
            f"Imposing PBC in mesh direction {idx} ({ax}) for k-components {comps}"
        )
        self._apply_pbc_phase(phase_total, slc_first, slc_last)
def _collect_pbc_phase_info(self, mesh_axis_idx):
    """Combine PBC phases for all k-components a mesh axis winds through.

    Parameters
    ----------
    mesh_axis_idx : int
        Index of the mesh axis.

    Returns
    -------
    tuple
        ``(phase_total, slc_first, slc_last, comps)``. ``comps`` is the sorted
        tuple of k-components that are both endpoint components and
        BZ-winding components of the axis. ``phase_total`` is the product of
        the per-component phases from ``_get_pbc_phases``, and the slices are
        those returned by the last such call. When no component qualifies,
        ``(None, None, None, ())`` is returned.
    """
    axis = self.mesh.axes[mesh_axis_idx]
    winding_comps = sorted(
        set(axis.endpoint_components) & set(axis.winds_bz_components)
    )
    if not winding_comps:
        return None, None, None, tuple()

    periodic = np.asarray(self.lattice.periodic_dirs, dtype=int)
    total = None
    first = last = None
    for comp in winding_comps:
        # `comp` counts k-components (0..dim_k-1); _get_pbc_phases needs the
        # real-space lattice direction, i.e. periodic_dirs[comp]. These differ
        # when some lattice directions are non-periodic (e.g. periodic
        # directions [0, 2] map comp 1 -> direction 2).
        phase, first, last = self._get_pbc_phases(mesh_axis_idx, periodic[comp])
        # Several winding components contribute multiplicatively.
        total = phase if total is None else total * phase
    return total, first, last, tuple(winding_comps)
@staticmethod
def _edge_slices(ax):
"""Helper function to get slices for the first and last edges of an axis."""
# add one for Python counting and one for ellipses
# Example ax = 2 (2 defines the axis in Python counting)
slc_last = [slice(None)] * (ax + 2) # e.g., [:, :, :, :]
slc_first = [slice(None)] * (ax + 2) # e.g., [:, :, :, :]
# last element along mesh_dir axis
slc_last[ax] = -1 # e.g., [:, :, -1, :]
# first element along mesh_dir axis
slc_first[ax] = 0 # e.g., [:, :, 0, :]
# take all components of remaining axes with ellipses
slc_last[ax + 1] = Ellipsis # e.g., [:, :, -1, ...]
slc_first[ax + 1] = Ellipsis # e.g., [:, :, 0, ...]
return tuple(slc_first), tuple(slc_last)
def _apply_pbc_phase(self, phase, slc_first, slc_last, from_first: bool = True):
    """Copy one boundary slice onto the other, applying the PBC phase.

    Parameters
    ----------
    phase : np.ndarray
        Phase multiplied onto ``_wfs`` (and ``_u_nk`` when present) when
        copying first -> last; its complex conjugate is used for last -> first.
    slc_first, slc_last : tuple
        Index tuples for the first and last slices along a mesh axis
        (as produced by ``_edge_slices``).
    from_first : bool, optional
        If True (default), overwrite the last slice from the first;
        otherwise overwrite the first slice from the last.

    Notes
    -----
    ``_psi_nk`` (when present) is copied without any phase factor.
    """
    phase_conj = np.conjugate(phase)
    # Optional arrays are looked up before any write.
    u_arr = getattr(self, "_u_nk", None)
    psi_arr = getattr(self, "_psi_nk", None)

    if from_first:
        src, dst, factor = slc_first, slc_last, phase
    else:
        src, dst, factor = slc_last, slc_first, phase_conj

    logger.debug(
        f"Setting wavefunctions at {dst} equal to those at {src} times phase factor."
    )
    self._wfs[dst] = self._wfs[src] * factor
    if u_arr is not None:
        self._u_nk[dst] = self._u_nk[src] * factor
    if psi_arr is not None:
        self._psi_nk[dst] = self._psi_nk[src]
def _copy_edge(self, slc_src, slc_dst):
    """Copy data from one boundary slice to another without any phase.

    ``_wfs`` is always copied; ``_u_nk`` and ``_psi_nk`` are copied only when
    they exist and are not None. Intended for pure (non-BZ-winding) loops.
    """
    self._wfs[slc_dst] = self._wfs[slc_src]
    for name in ("_u_nk", "_psi_nk"):
        arr = getattr(self, name, None)
        if arr is not None:
            arr[slc_dst] = arr[slc_src]
def _get_pbc_phases(self, mesh_dir, k_dir):
    r"""Phase factors and edge slices for PBCs along one mesh direction.

    The phase factors are :math:`e^{-i{\bf G}\cdot{\bf r}}`, which in reduced
    units is :math:`e^{-2\pi i \tau_k}` with :math:`\tau_k` the orbital
    coordinate along lattice direction ``k_dir``.

    Parameters
    ----------
    mesh_dir : int
        Mesh axis along which periodic boundary conditions are imposed.
    k_dir : int
        Real-space lattice direction (an entry of ``lattice.periodic_dirs``)
        whose reciprocal vector :math:`{\bf G}` connects the mesh ends.

    Returns
    -------
    phase : np.ndarray
        Phase per orbital, shape ``(norb,)``; shape ``(norb, 1)`` when
        ``nspin == 2`` so it broadcasts over the spin axis.
    slc_first, slc_last : tuple
        Index tuples selecting the first and last slices along ``mesh_dir``.

    Raises
    ------
    ValueError
        If ``k_dir`` is not a periodic lattice direction.
    TypeError
        If ``mesh_dir`` is not an integer.
    IndexError
        If ``mesh_dir`` is outside ``[0, naxes)``.
    """
    if k_dir not in self.lattice.periodic_dirs:
        # ValueError (a subclass of Exception) matches _impose_pbc's check.
        raise ValueError(
            "Periodic boundary condition can be specified only along periodic directions!"
        )
    if not isinstance(mesh_dir, (int, np.integer)):
        raise TypeError("mesh_dir should be an integer!")
    if mesh_dir < 0 or mesh_dir >= self.naxes:
        raise IndexError("mesh_dir outside the range!")

    orb_vecs = self.lattice.orb_vecs
    # exp(-i G . tau) with G the reciprocal vector along k_dir
    phase = np.exp(-2j * np.pi * orb_vecs[:, k_dir])
    if self.nspin != 1:
        phase = phase[:, np.newaxis]
    slc_first, slc_last = self._edge_slices(mesh_dir)
    return phase, slc_first, slc_last
def _impose_pbc(self, mesh_dir: int, k_dir: int):
    r"""Impose periodic boundary conditions on the WFArray.

    Overwrites the state at the end of the mesh along ``mesh_dir`` with the
    first one multiplied by a phase factor. Explicitly, this sets
    :math:`u_{n,{\bf k_0+G}}=e^{-i{\bf G}\cdot{\bf r}} u_{n {\bf k_0}}` for
    :math:`\mathbf{G} = \mathbf{b}_{\texttt{k_dir}}`, the reciprocal lattice
    basis vector along ``k_dir``. :math:`u_{n{\bf k_0}}` is the state in the
    first element of the mesh along ``mesh_dir``.

    Parameters
    ----------
    mesh_dir : int
        Direction of the `WFArray` along which to impose periodic boundary
        conditions.
    k_dir : int
        Corresponding periodic k-vector direction in the Brillouin zone of the
        underlying *TBModel*. Since version 1.7.0 this parameter is specified
        between 0 and *dim_r-1*.

    Raises
    ------
    ValueError
        If the k-space is 0D or ``k_dir`` is not a periodic direction.

    Notes
    -----
    This assumes the k-point mesh increases by exactly one reciprocal lattice
    vector along ``mesh_dir``. The optional ``_u_nk`` / ``_psi_nk`` arrays are
    updated only when present (``_psi_nk`` is copied without a phase).

    Examples
    --------
    Impose PBCs along ``mesh_dir=0``, assuming that along that direction the
    ``k_dir=1`` component of the k-vector increases by one reciprocal lattice
    vector (e.g. a 2D model sampled on a 1D path along :math:`k_y`):

    >>> wf._impose_pbc(mesh_dir=0, k_dir=1)
    """
    if self.dim_k == 0:
        raise ValueError(
            "Cannot impose periodic boundary conditions in 0D k-space.\n"
            "Use `_impose_loop` instead."
        )
    if k_dir not in self.lattice.periodic_dirs:
        raise ValueError(
            "Periodic boundary condition can be specified only along periodic directions!"
        )
    phase, slc_first, slc_last = self._get_pbc_phases(mesh_dir, k_dir)
    # Shared helper: last slice <- first slice * phase, with None-guards for
    # the optional _u_nk / _psi_nk arrays (previously _psi_nk was unguarded).
    self._apply_pbc_phase(phase, slc_first, slc_last)
def _impose_loop(self, mesh_dir):
    r"""Impose a loop condition along a given mesh direction.

    Replaces the last state along ``mesh_dir`` with the first one (same
    phase, for every band), so the two ends of the path coincide.

    Parameters
    ----------
    mesh_dir : int
        Direction of the `WFArray` along which to impose the loop condition.

    Raises
    ------
    TypeError
        If ``mesh_dir`` is not an integer.
    ValueError
        If ``mesh_dir`` is out of range, or it is a k-axis of a mesh that is
        a k-space torus (use :func:`pythtb.WFArray._impose_pbc` instead).

    See Also
    --------
    :func:`pythtb.WFArray._impose_pbc`

    Notes
    -----
    Do not use this routine when the first and last points are related by a
    reciprocal lattice vector; use :func:`pythtb.WFArray._impose_pbc`
    instead. It is assumed (**not** checked) that the first and last points
    along ``mesh_dir`` correspond to the same Hamiltonian.

    Examples
    --------
    For a three-dimensional WFArray over ``(kx, ky, lambda)`` where
    ``lambda`` goes around a closed loop, make the states at both ends of
    the lambda path equal before computing Berry phases in lambda:

    >>> wf._impose_loop(mesh_dir=2)
    """
    if not isinstance(mesh_dir, (int, np.integer)):
        raise TypeError("mesh_dir must be an integer.")
    if mesh_dir < 0 or mesh_dir >= self.naxes:
        raise ValueError(
            f"mesh_dir must be between 0 and {self.naxes - 1}, got {mesh_dir}."
        )
    # Compare against integer axis *indices*. NOTE(review): mesh.k_axes
    # presumably holds axis objects (like mesh.lambda_axes, whose items have
    # `.name`), in which case an int membership test never matched.
    if mesh_dir in self.mesh.k_axis_indices and self.mesh.is_k_torus:
        raise ValueError("Cannot impose loop condition on periodic k-space axis.")
    slc_first, slc_last = self._edge_slices(mesh_dir)
    logger.debug(
        f"Setting wavefunctions at {slc_last} equal to those at {slc_first}."
    )
    self._wfs[slc_last] = self._wfs[slc_first]
    if self.dim_k > 0:
        if self.u_nk is not None:
            self._u_nk[slc_last] = self._u_nk[slc_first]
        if self.psi_nk is not None:
            self._psi_nk[slc_last] = self._psi_nk[slc_first]
def _unit_shift(self, axis: int, direction: int = 1) -> list[int]:
    """Build a length-``naxes`` shift vector that is zero except ``±1`` at *axis*.

    Raises
    ------
    IndexError
        If ``axis`` is outside ``[0, naxes)``.
    ValueError
        If ``direction`` is not ``+1`` or ``-1``.
    """
    n_axes = self.naxes
    if not 0 <= axis < n_axes:
        raise IndexError(f"axis must be in [0, {self.naxes - 1}]")
    if direction not in (-1, 1):
        raise ValueError("direction must be ±1.")
    return [direction if i == axis else 0 for i in range(n_axes)]
@staticmethod
def _bounded_shift(A: np.ndarray, axis: int, sh: int) -> np.ndarray:
"""Shift array A by *sh* along *axis* without wrapping; fill vacated slab with zeros."""
if sh == 0:
return A
sl_all = [slice(None)] * A.ndim
B = np.zeros_like(A)
if sh > 0:
sl_src = sl_all.copy()
sl_dst = sl_all.copy()
sl_src[axis] = slice(0, -sh)
sl_dst[axis] = slice(sh, None)
else: # sh < 0
shn = -sh
sl_src = sl_all.copy()
sl_dst = sl_all.copy()
sl_src[axis] = slice(shn, None)
sl_dst[axis] = slice(0, -shn)
B[tuple(sl_dst)] = A[tuple(sl_src)]
return B
def _boundary_phase_for_shift(self, shift_vec):
    """Compute the ``exp(-i G . r)`` boundary phase for an integer mesh shift.

    For each k-point the number of Brillouin-zone boundaries crossed by the
    shift is counted per k-axis; those wraps are mapped to reciprocal lattice
    components and dotted into the orbital positions (restricted to the
    periodic directions). Only k-axes that wind the BZ and are not closed
    contribute; shifts on other k-axes are ignored.

    Parameters
    ----------
    shift_vec : int or sequence of int
        Shift per mesh axis (full mesh-axis indexing). Entries beyond its
        length are treated as 0.

    Returns
    -------
    np.ndarray
        Complex phase broadcastable to the stored state tensor: mesh axes in
        mesh order (k-axes full extent, lambda-axes length 1), a length-1
        state axis, the orbital axis, and a trailing length-1 spin axis when
        ``nspin == 2``. For a mesh without k-axes, the scalar ``1+0j``.
    """
    mesh = self.mesh
    nks = np.array(mesh.shape_k, dtype=int)
    dim_k = nks.size
    if dim_k == 0:
        return np.array(1.0, dtype=complex)
    k_axes = np.asarray(mesh.k_axis_indices, dtype=int)

    # Normalize shift vector and restrict to k-axes
    shifts = np.zeros(dim_k, dtype=int)
    sv = np.atleast_1d(shift_vec)
    for lk, mx in enumerate(k_axes):
        # shift_vec may be shorter than the number of mesh axes
        sh = int(sv[mx]) if mx < sv.size else 0
        # Only keep shifts on axes that wind the BZ; zero out closed axes.
        # Log only when a requested shift is actually discarded.
        if mesh.is_axis_bz_winding(mx):
            if mesh.is_axis_closed(mx):
                if sh != 0:
                    logger.info(f"Axis {mx} is closed; removing shift.")
                sh = 0
        elif sh != 0:
            logger.info(f"Axis {mx} is not BZ-winding; removing shift.")
            sh = 0
        shifts[lk] = sh

    # Integer index grid over k-axes: shape (*nks, dim_k)
    idx_grid = np.stack(
        np.meshgrid(*[np.arange(n) for n in nks], indexing="ij"), axis=-1
    )
    shifted = idx_grid + shifts
    # Floor division counts cells crossed (signed), handling |shift| >= 1:
    # e.g. n=10: (-1)//10 -> -1; (10)//10 -> 1; (21)//10 -> 2
    wraps_k = shifted // nks

    # M[local_k_axis, comp] = 1 if that sampling axis contributes to comp
    M = np.zeros((dim_k, dim_k), dtype=int)
    for idx, ax in enumerate(k_axes):
        for c in range(dim_k):
            if mesh.is_axis_bz_winding(ax, c):
                M[idx, c] = 1
    # Project wraps to reciprocal components: shape (*nks, dim_k)
    G_comp = np.einsum("...i, ic -> ...c", wraps_k, M, dtype=int)

    # Orbital positions restricted to periodic real-space components (norb, dim_k)
    per = getattr(self.lattice, "periodic_dirs", None)
    if per is None:
        if self.dim_k != self.lattice.dim_r:
            logger.warning(
                "WFArray._boundary_phase_for_shift: lattice.periodic_dirs missing; "
                "falling back to first dim_k components."
            )
        orb = self.lattice.orb_vecs[:, :dim_k]
    else:
        per = np.asarray(per, dtype=int)
        if per.size < dim_k:
            raise ValueError(
                f"lattice.periodic_dirs lists {per.size} directions; expected ≥ dim_k={dim_k}."
            )
        orb = self.lattice.orb_vecs[:, per[:dim_k]]

    # dot = sum_c G_comp[..., c] * tau[:, c] -> shape (*nks, norb)
    dot = np.tensordot(G_comp, orb, axes=([G_comp.ndim - 1], [1]))
    phase = np.exp(-2j * np.pi * dot).astype(complex)

    # Canonical layout (nk..., nl..., nstate=1, norb[, nspin=1]) then mesh order
    shape = (*nks, *([1] * self.mesh.nl_axes), 1, self.norb)
    phase = phase.reshape(shape)
    if self.nspin == 2:
        phase = phase[..., np.newaxis]
    axis_perm = self._canonical_to_mesh_axes()
    if axis_perm:
        perm = axis_perm + tuple(range(len(self.mesh.axes), phase.ndim))
        phase = np.transpose(phase, perm)
    return phase
def _invalidate_boundary_links(self, array: np.ndarray, shift_vec) -> np.ndarray:
    """Write NaN on slabs whose neighbor does not exist for ``shift_vec``.

    For every shifted mesh axis that does not wrap (neither looped nor
    BZ-winding) or that is closed, the boundary slab is set to NaN: the last
    index for a positive shift, the first index for a negative shift. The
    array is modified in place and also returned. A scalar ``shift_vec`` is
    treated as a shift along axis 0.
    """
    mesh = self.mesh
    n_axes = self.naxes
    shifts = (
        shift_vec
        if isinstance(shift_vec, (tuple, list, np.ndarray))
        else (shift_vec,)
    )
    for ax, step in enumerate(shifts):
        if ax >= n_axes or step == 0:
            continue
        periodic = mesh.is_axis_looped(ax) or mesh.is_axis_bz_winding(ax)
        closed = mesh.is_axis_closed(ax)
        if closed or not periodic:
            index = [slice(None)] * array.ndim
            index[ax] = -1 if step > 0 else 0
            array[tuple(index)] = np.nan + 0j
    return array
[docs]
def roll_states_with_pbc(
    self,
    shift_vec: list[int],
    flatten_spin_axis: bool = True,
    strip_boundary: bool = False,
):
    """Roll states by one mesh step per axis, applying boundary phases.

    Entry ``i`` of the result along a shifted axis holds the state at mesh
    index ``i + shift``. Axes that wrap (looped or BZ-winding) and are not
    closed are rolled periodically; all other shifted axes are shifted
    without wrap-around, leaving zeros in the vacated slab. The boundary
    phase from ``_boundary_phase_for_shift`` is then multiplied in.

    Parameters
    ----------
    shift_vec : list[int]
        Shift (+1, 0 or -1) for each axis. Must contain at least as many
        entries as k-axes and at most as many as mesh axes.
    flatten_spin_axis : bool, optional
        Whether to flatten the spin axis into the orbital axis, by default True.
    strip_boundary : bool, optional
        Whether to drop the last index along every shifted axis that does not
        wrap (or is closed), by default False.

    Returns
    -------
    np.ndarray
        The rolled states with boundary conditions applied.

    Raises
    ------
    ValueError
        If any shift has magnitude greater than 1, or ``shift_vec`` has the
        wrong length.

    Examples
    --------
    For a spinless model, shifting by ``+1`` along the first axis brings the
    next state to the current index at interior points:

    >>> rolled = wfa.roll_states_with_pbc([1, 0])
    >>> np.allclose(rolled[3, 3], wfa.wfs[4, 3])
    True
    """
    states = self.wfs
    mesh = self.mesh

    if np.any(abs(np.array(shift_vec, dtype=int)) > 1):
        raise ValueError("Only unit shifts (+1, 0, -1) are supported in shift_vec.")
    if len(shift_vec) < mesh.nk_axes:
        raise ValueError(
            "shift_vec must have at least as many elements as k-axes in the mesh."
        )
    elif len(shift_vec) > mesh.naxes:
        raise ValueError(
            "shift_vec must have at most as many elements as total axes in the mesh."
        )

    rolled = states
    for ax, sh in enumerate(shift_vec):
        if not sh:
            continue
        wraps = mesh.is_axis_looped(ax) or mesh.is_axis_bz_winding(ax)
        closed = mesh.is_axis_closed(ax)
        if not closed and wraps:
            rolled = np.roll(rolled, shift=-int(sh), axis=ax)
        else:
            logger.info(
                f"Applying bounded shift {sh} to axis {ax} without wrapping."
            )
            rolled = self._bounded_shift(rolled, axis=ax, sh=-int(sh))

    # The phase array spans the full k-grid, so it must be applied before any
    # boundary slab is stripped; otherwise the shapes no longer broadcast.
    phase = self._boundary_phase_for_shift(tuple(shift_vec))
    rolled = rolled * phase

    if strip_boundary:
        sl = [slice(None)] * rolled.ndim
        for ax, sh in enumerate(shift_vec):
            loops = mesh.is_axis_looped(ax) or mesh.is_axis_bz_winding(ax)
            closed = mesh.is_axis_closed(ax)
            if sh and (closed or not loops):
                # drop the last index in that direction
                sl[ax] = slice(None, -1)
        rolled = rolled[tuple(sl)]

    if flatten_spin_axis and self.nspin == 2:
        rolled = rolled.reshape(*rolled.shape[:-2], self.norb * self.nspin)
    return rolled
[docs]
def overlap_matrix(self, use_k_metric: bool = False) -> np.ndarray:
    r"""Overlap matrices of the cell-periodic states with their nearest neighbors.

    .. math::

        M_{m,n}^{\mathbf{b}}(\mathbf{k}, \lambda)
        = \langle u_{m, \mathbf{k}, \lambda} | u_{n, \mathbf{k+b}, \lambda} \rangle

    where :math:`\mathbf{b}` is a displacement connecting neighboring mesh
    points.

    .. note::
        This method is intended for WFArray objects defined on a uniform grid
        of k-points (i.e. a Mesh with regular spacing along its k-axes).

    .. versionadded:: 2.0.0

    Parameters
    ----------
    use_k_metric : bool, optional
        If True, k-space neighbors are taken from the first Cartesian
        nearest-neighbor shell returned by ``lattice.nn_k_shell``, so
        :math:`\mathbf{b}` need not be a unit step in reduced coordinates.
        If False (default), neighbors are the ``±1`` steps along each k-axis
        in reduced coordinates.

    Returns
    -------
    M : np.ndarray
        Overlap matrices of shape ``(*mesh_shape, num_nnbrs, n_states, n_states)``.

    Notes
    -----
    - With ``use_k_metric=True`` the neighbor axis lists the metric-derived
      k-shell (if the mesh has k-axes) followed by the :math:`+\mu` and
      :math:`-\mu` unit steps along each :math:`\lambda`-axis.
    - With ``use_k_metric=False`` every k- and :math:`\lambda`-axis
      contributes a :math:`+\mu` and a :math:`-\mu` neighbor in reduced
      coordinates, ordered
      :math:`(+\kappa_0, -\kappa_0, \ldots, +\lambda_0, -\lambda_0, \ldots)`.
    - Entries are only meaningful where a neighbor exists. Boundary behavior
      follows the ``Mesh`` axis type:

      - **Periodic (no endpoint)**: the final point wraps to the first; all
        entries are meaningful.
      - **Periodic with endpoint included**: no wrap is applied; the final
        forward link is undefined and filled with ``NaN``. Discard the final
        slice along that axis.
      - **Nonperiodic (open directions)**: no neighbor exists at the
        boundary; those links are filled with ``NaN`` and should be dropped.

      These conventions match :meth:`wilson_loop`, :meth:`berry_phase` and
      related routines, which trim invalid points automatically.
    """
    mesh = self.mesh
    shift_vectors = []
    if use_k_metric:
        logger.info("Computing overlap matrix using k-metric for neighbor lookup.")
        k_axes = mesh.k_axis_indices
        if k_axes:
            _, idx_shell = self.lattice.nn_k_shell(
                mesh.shape_k, n_shell=1, report=False
            )
            # Embed each k-shell displacement into a full mesh-axis shift vector.
            for delta in idx_shell[0]:
                vec = [0] * self.naxes
                for ax, off in zip(k_axes, delta):
                    vec[ax] = int(off)
                shift_vectors.append(vec)
        for ax in mesh.lambda_axis_indices:
            shift_vectors.append(self._unit_shift(ax, +1))
            shift_vectors.append(self._unit_shift(ax, -1))
    else:
        logger.info(
            "Computing overlap matrix without k-metric for neighbor lookup."
        )
        for ax in mesh.k_axis_indices + mesh.lambda_axis_indices:
            shift_vectors.append(self._unit_shift(ax, +1))
            shift_vectors.append(self._unit_shift(ax, -1))

    M = np.zeros(
        (*self.shape_mesh, len(shift_vectors), self.nstates, self.nstates),
        dtype=complex,
    )
    bra = self.states(flatten_spin_axis=True).conj()
    for slot, vec in enumerate(shift_vectors):
        ket = self.roll_states_with_pbc(vec, flatten_spin_axis=True)
        ovl = np.einsum("...mj,...nj->...mn", bra, ket)
        M[..., slot, :, :] = self._invalidate_boundary_links(ovl, vec)
    return M
[docs]
def links(
self, axis_idx: int | ArrayLike = None, state_idx: int | ArrayLike = None
) -> np.ndarray:
r"""Compute unitary link matrices for specified axis directions.
For each mesh direction :math:`\mu` in ``axis_idx``,
this routine computes the parallel-transport unitary (“link”)
between the states at each point :math:`\boldsymbol{\kappa}` and
its nearest forward neighbor :math:`\boldsymbol{\kappa} + \hat{\mu}`.
First define the overlap matrix:
.. math::
M_{\mu}(\boldsymbol{\kappa})_{mn}
= \langle
u_{m}(\boldsymbol{\kappa})
\mid
u_{n}(\boldsymbol{\kappa} + \hat{\mu})
\rangle.
where :math:`m,n` index the selected states in ``state_idx``.
The **unitary link** is then obtained by taking the unitary factor
of this overlap matrix:
.. math::
U_{\mu}(\boldsymbol{\kappa})
\;\equiv\;
\mathcal{U}\, \big[M_{\mu}(\boldsymbol{\kappa})\big],
Here :math:`\mathcal{U}[\cdot]` denotes the **unitary part**
of the matrix, i.e. the unitary factor in a matrix-factorization
of :math:`M_{\mu}(\boldsymbol{\kappa})`. In practice this is
obtained from the singular-value decomposition (see Notes).
.. versionadded:: 2.0.0
Parameters
----------
axis_idx : int or array_like of int, optional
List of `Mesh` axes along which to compute the links.
If not provided, links will be computed for all directions
in the mesh.
state_idx : int or array_like of int, optional
Index or indices of the states for which to compute the links.
If an integer is provided, only that state will be considered.
If a list is provided, links for all specified states will be computed.
If not provided, links will be computed for all states.
Returns
-------
U_forward : np.ndarray
Array of unitary links. The leading dimension indexes the chosen
``axis_idx`` directions (or all axes if ``axis_idx`` is None).
The trailing two dimensions are the matrix indices in band space.
Shape:
``(ndirs, *mesh_shape, n_states, n_states)``
where ``ndirs = len(axis_idx)`` (or ``WFArray.naxes`` by default).
See Also
--------
:meth:`berry_connection`, :meth:`berry_phase`, :meth:`berry_flux`,
:meth:`wilson_loop`
Notes
-----
- This is the primitive building block for :meth:`berry_connection`,
:meth:`berry_phase`, :meth:`berry_flux`, and :meth:`wilson_loop`
functions.
- The unitary link is construced from the singular value
decomposition of the overlap matrix,
.. math::
M_{\mu} = V_{\mu}\,\Sigma_{\mu}\,W_{\mu}^{\dagger},
as,
.. math::
\mathcal{U} \big[ M_{\mu}(\boldsymbol{\kappa}) \big]
\equiv V_{\mu} W_{\mu}^\dagger
This is equivalent to taking the unitary factor in the
polar decomposition.
.. math::
M_{\mu} = U_{\mu}\,H_{\mu}, \qquad
H_{\mu} = \bigl(M_{\mu}^{\dagger} M_{\mu}\bigr)^{1/2},
\\
\Rightarrow\quad
\mathcal{U}[M_{\mu}] \equiv U_{\mu}
= M_{\mu}\bigl(M_{\mu}^{\dagger} M_{\mu}\bigr)^{-1/2}.
- Boundary behavior follows the ``Mesh`` axis type:
- **Periodic (no endpoint)**
The final point wraps to the first (looped). All entries
are meaningful.
- **Periodic with endpoint included**
The endpoint is explicitly included in the mesh. No wrap is applied;
the final forward link is undefined and is filled with ``NaN``.
Discard the final slice along that axis.
- **Nonperiodic (open directions)**
No physical forward neighbor exists at the boundary. Forward links there
are filled with ``NaN`` to keep the shape consistent. These values are not
physically meaningful and should be dropped.
These conventions match those used internally by :meth:`wilson_loop`,
:meth:`berry_phase`, and related routines, which automatically trim
invalid points.
"""
if axis_idx is None:
axis_idx = np.arange(self.naxes, dtype=int)
else:
axis_idx = np.atleast_1d(axis_idx)
if not np.issubdtype(axis_idx.dtype, np.integer):
raise TypeError("axis_idx must be integer or an integer array.")
if (axis_idx < 0).any() or (axis_idx >= self.naxes).any():
raise IndexError("axis index in axis_idx is out of range.")
# select bands and states once
state_idx = self._normalize_state_indices(state_idx)
wfs = self.states(flatten_spin_axis=True, state_idx=state_idx)
# stack all shifted states along a new leading axis (n_mu, ...)
shifts = [self._unit_shift(mu) for mu in axis_idx]
W = np.stack(
[
np.take(
self.roll_states_with_pbc(s, flatten_spin_axis=True),
state_idx,
axis=-2,
)
for s in shifts
],
axis=0, # (n_mu, ..., nstate, norb)
)
# overlaps O_mu = <u(k)|u(k+dk_mu)> with batched matmul
overlaps = wfs.conj()[None, ...] @ W.swapaxes(
-2, -1
) # (n_mu, ..., nstate, nstate)
# unitary (parallel-transport) factor via polar/SVD: U = V @ W^H
V, _, Wh = np.linalg.svd(overlaps, full_matrices=False) # batched SVD
U_forward = V @ Wh # (n_mu, ..., nstate, nstate)
# invalidate boundary links per axis
for i, s in enumerate(shifts):
U_forward[i] = self._invalidate_boundary_links(U_forward[i], s)
return U_forward
@staticmethod
def _wilson_loop(wfs_loop, wilson_evals: bool = False):
    r"""Wilson loop unitary matrix along a one-dimensional string of states.

    Computes the ordered product of the unitary link matrices between
    neighboring states of ``wfs_loop``,

    .. math::

        U_{\text{Wilson}} = \prod_{n=0}^{N-2} U_{n},

    where :math:`U_n` is the unitary (polar) part of the overlap matrix
    :math:`M_n` with elements
    :math:`(M_n)_{ab} = \langle u_a(n) | u_b(n+1) \rangle`
    (see :meth:`links`). The factors are multiplied left to right in order
    of increasing :math:`n`.

    When ``wilson_evals=True``, the eigenvalues of the Wilson loop matrix are
    returned as well. They are complex numbers
    :math:`\lambda_n = e^{i \theta_n}`. The multiband Berry phases are
    :math:`\phi_n = -\theta_n` (see :meth:`_berry_loop`).

    Parameters
    ----------
    wfs_loop : np.ndarray
        Array of shape ``[loop_idx, band, orbital]`` or
        ``[loop_idx, band, orbital, spin]``. For 4D input, the trailing (spin)
        axis is merged with the orbital axis before the overlaps are formed.
        With ``N`` points, only the ``N-1`` links between consecutive points
        are used. A closed loop must therefore repeat its first point at the
        end.
    wilson_evals : bool, optional
        If True, also return the eigenvalues of the Wilson loop matrix.
        Default is False.

    Returns
    -------
    U_wilson : np.ndarray
        Wilson loop matrix of shape ``(band, band)``. For a string with a
        single point (no links), this is the identity.
    eigvals : np.ndarray, optional
        Complex eigenvalues of ``U_wilson``, with shape ``(band,)``.
        Returned only if ``wilson_evals=True``.

    Raises
    ------
    ValueError
        If ``wfs_loop`` is not a 3D or 4D array.

    Notes
    -----
    ``wilson_evals`` are distinct from the multiband Berry phases returned by
    :meth:`berry_phase`. The latter are the (negated) phase arguments of the
    former and always lie between :math:`-\pi` and :math:`\pi`.
    """
    wfs_loop = np.asarray(wfs_loop)
    if wfs_loop.ndim < 3 or wfs_loop.ndim > 4:
        raise ValueError(
            "wfs_loop must be a 3D or 4D array with shape [loop_idx, band, orbital(, spin)]"
        )

    n_pts, n_band = wfs_loop.shape[:2]
    if wfs_loop.ndim == 4:
        # Merge (orbital, spin) into a single basis axis so that every point is
        # an (n_band, n_basis) matrix. Keeping the array 4D would make the
        # matmul below batch over bands instead of contracting the basis.
        wfs_loop = wfs_loop.reshape(n_pts, n_band, -1)

    if n_pts < 2:
        # No links: the ordered product is empty, i.e. the identity.
        U_wilson = np.eye(n_band, dtype=complex)
    else:
        # ovr_mats[n]_{ab} = <u_a(n) | u_b(n+1)>, shape (n_pts-1, n_band, n_band)
        ovr_mats = wfs_loop[:-1].conj() @ wfs_loop[1:].swapaxes(-2, -1)
        # Unitary (polar) part of each overlap: M = V S W^H  ->  U = V W^H
        V, _, Wh = np.linalg.svd(ovr_mats, full_matrices=False)
        U_link = V @ Wh
        U_wilson = U_link[0]
        for U_step in U_link[1:]:
            U_wilson = U_wilson @ U_step

    if wilson_evals:
        return U_wilson, np.linalg.eigvals(U_wilson)
    return U_wilson
@staticmethod
def _berry_loop(wfs_path, berry_evals: bool = False):
    r"""Berry phase along a one-dimensional path of wavefunctions.

    The phase is obtained from the Wilson loop matrix
    :math:`U_{\rm Wilson}` that :meth:`_wilson_loop` returns for the same
    path. The total Berry phase is

    .. math::

        \phi = -\text{Im} \ln \det U_{\rm Wilson}.

    When ``berry_evals=True``, the individual (multiband) Berry phases are
    returned as well,

    .. math::

        \phi_n = -\text{Im} \ln \lambda_n,

    where :math:`\lambda_n` are the eigenvalues of :math:`U_{\rm Wilson}`.
    These are also called the Wilson loop eigenphases.

    Parameters
    ----------
    wfs_path : np.ndarray
        Wavefunctions along the path, with shape
        ``(path_idx, band, orbital)`` or ``(path_idx, band, orbital, spin)``.
        With ``N`` points, only the ``N-1`` links between consecutive points
        are used. A closed loop must therefore repeat its first point at the
        end.
    berry_evals : bool, optional
        If True, return the total Berry phase together with the sorted
        multiband Berry phases. If False (default), return only the total
        Berry phase.

    Returns
    -------
    berry_phase : float
        Total Berry phase of the path, :math:`-\arg\det U_{\rm Wilson}`.
    eigvals_phase : np.ndarray
        Returned only when ``berry_evals=True``, as the second element of the
        tuple ``(berry_phase, eigvals_phase)``. Contains the multiband Berry
        phases :math:`-\arg\lambda_n`, with shape ``(band,)``, sorted in
        ascending order.

    Raises
    ------
    ValueError
        If ``wfs_path`` is not a 3D or 4D array.

    Notes
    -----
    The path is assumed to be one-dimensional and ordered: point ``n`` of the
    path is ``wfs_path[n]``. When the path is closed, the Berry phase is the
    geometric phase acquired by the states as they are transported around
    the loop. If the path is not closed, the result depends on the specific
    path taken.
    """
    if wfs_path.ndim < 3 or wfs_path.ndim > 4:
        raise ValueError(
            "wfs_path must be a 3D or 4D array with shape (path_idx, band, orbital(, spin))"
        )

    if berry_evals:
        U_wilson, eigvals = WFArray._wilson_loop(wfs_path, wilson_evals=True)
    else:
        U_wilson = WFArray._wilson_loop(wfs_path, wilson_evals=False)

    # Total phase: -Im ln det W == -arg det W
    berry_phase = -np.angle(np.linalg.det(U_wilson))
    if not berry_evals:
        return berry_phase

    # Multiband Berry phases (negated arguments of the Wilson eigenvalues), sorted
    eigvals_phase = np.sort(-np.angle(eigvals))
    return berry_phase, eigvals_phase
def wilson_loop(self, axis_idx: int, state_idx=None, wilson_evals: bool = False):
    r"""Wilson loop along a specified mesh axis.

    The Wilson loop is the ordered product of *unitary link matrices* along a
    closed string of points in parameter space. For a mesh direction
    :math:`\mu`, this routine computes the Wilson loop unitary matrix along
    that direction,

    .. math::

        W_{\mu} = \prod_{n=0}^{N_{\mu}-1}
        U_{\mu}\!\bigl(\boldsymbol{\kappa}_n \bigr)

    Here :math:`\mu` corresponds to ``axis_idx``, and
    :math:`U_{\mu}(\boldsymbol{\kappa}_n)` is the **unitary part** of the
    overlap matrix between the states at consecutive mesh points
    (see :meth:`links`). One Wilson loop is computed for every combination of
    indices along the remaining (transverse) mesh axes.

    .. versionadded:: 2.0.0

    Parameters
    ----------
    axis_idx : int
        Index of the ``Mesh`` axis along which the Wilson loop is computed.
    state_idx : int, array-like, optional
        Band index or array of band indices to include in the calculation.
        If unspecified, all bands are included.
    wilson_evals : bool, optional
        If True, also compute the eigenvalues of each Wilson loop and return
        them together with the Wilson loop matrices. Default is False.

    Returns
    -------
    U_wilson : np.ndarray
        Wilson loop unitary matrices with shape
        ``(*transverse_shape, band, band)``. Here ``transverse_shape`` is the
        mesh shape with ``axis_idx`` removed, so the result is
        ``(band, band)`` for a one-axis mesh.
    eigvals : np.ndarray, optional
        Unit-norm complex eigenvalues of each Wilson loop matrix, with shape
        ``(*transverse_shape, band)``. Returned only if ``wilson_evals=True``.

    Raises
    ------
    ValueError
        If ``axis_idx`` is not an integer in ``[0, naxes - 1]``.

    See Also
    --------
    :meth:`berry_phase`
    :meth:`links`

    Notes
    -----
    - When ``wilson_evals=True``, the eigenvalues of the Wilson loop unitary
      matrix are returned. They are complex numbers of the form

      .. math::

          \lambda_n = e^{i \phi_n}

      where :math:`\phi_n` are the multiband Berry phases associated with
      each band.
    - ``wilson_evals`` are distinct from the multiband Berry phases returned
      by :meth:`berry_phase` with ``berry_evals=True``. The ``berry_evals``
      are the **phase** arguments of ``wilson_evals``, not the eigenvalues
      themselves.
    - For an array of size ``N`` along ``axis_idx``, the Wilson loop is
      formed from the ``N-1`` nearest-neighbor inner products. This gives an
      open-path "Wilson line" unless the endpoints correspond to the same
      physical Hamiltonian. In that case the first state is appended to the
      end (if the endpoints are not already identical) to close the loop,
      with the appropriate PBC phase along k-axes.
    """
    if (
        not isinstance(axis_idx, (int, np.integer))
        or axis_idx < 0
        or axis_idx >= self.naxes
    ):
        raise ValueError(f"axis_idx must be an integer in [0, {self.naxes - 1}]")

    # States, optionally restricted to a subspace. The spin-resolved copy is
    # used to build the appended end point before its spin axis is flattened.
    u = self.states(state_idx=state_idx, flatten_spin_axis=True)
    u_expanded = self.states(state_idx=state_idx, flatten_spin_axis=False)

    u_loop = u
    for ax, comp in self.mesh._get_loop_ax_comp():
        # A periodic but open axis (endpoint excluded) needs its first state
        # appended at the end to close the loop.
        if ax != axis_idx or self.mesh.is_axis_closed(ax, comp):
            continue
        u_last = np.take(u_expanded, 0, axis=axis_idx)
        if self.mesh.is_axis_bz_winding(ax, comp):
            logger.debug(
                "Applying phase to state at beginning to end of open periodic axis."
            )
            # Map the mesh k-component to its real-space lattice direction
            # before requesting the PBC phase (same convention as berry_phase).
            real_comp = self.lattice.periodic_dirs[comp]
            phase, _, _ = self._get_pbc_phases(ax, real_comp)
            u_last = u_last * phase
        if self.nspin == 2:
            # flatten (orbital, spin) -> orbital*spin
            u_last = u_last.reshape(*u_last.shape[:-2], -1)
        logger.debug("Appending state at beginning to end of open periodic axis.")
        u_loop = np.concatenate(
            [u_loop, np.expand_dims(u_last, axis=axis_idx)], axis=axis_idx
        )

    # Loop axis first, so each transverse index selects one 1D string.
    u_loop = np.moveaxis(u_loop, axis_idx, 0)
    tail_shape = u_loop.shape[1:-2]  # transverse axes
    n_sub = u_loop.shape[-2]  # number of selected states

    unitar = np.empty((*tail_shape, n_sub, n_sub), dtype=complex)
    if wilson_evals:
        # Wilson-loop eigenvalues are complex; a float buffer would silently
        # discard their imaginary part.
        evals = np.empty((*tail_shape, n_sub), dtype=complex)

    for idx in np.ndindex(*tail_shape) if tail_shape else [()]:
        # All points along the loop at this transverse index:
        # shape (n_mu or n_mu + 1, n_sub, norb * nspin)
        wf_line = u_loop[(slice(None),) + idx]
        if wilson_evals:
            U, line_evals = self._wilson_loop(wf_line, wilson_evals=True)
            evals[idx] = line_evals
        else:
            U = self._wilson_loop(wf_line, wilson_evals=False)
        unitar[idx] = U

    if wilson_evals:
        return unitar, evals
    return unitar
def berry_connection(
    self,
    axis_idx: int | ArrayLike = None,
    state_idx: int | ArrayLike = None,
    *,
    return_unitaries: bool = False,
    cartesian: bool = False,
):
    r"""Berry connection from parallel-transport links.

    This routine evaluates the (non-Abelian) Berry connection on the reduced
    parameter mesh. For each mesh direction :math:`\mu` in ``axis_idx``, the
    connection is obtained from the parallel-transport link unitaries
    :math:`U_{\mu}(\boldsymbol{\kappa})` returned by :meth:`links`, using the
    discrete finite-difference approximation

    .. math::

        A_{\mu}(\boldsymbol{\kappa})
        \;=\;
        -\frac{1}{i\,\Delta \kappa_{\mu}}
        \log\!\big[\,U_{\mu}(\boldsymbol{\kappa})\,\big],

    where :math:`\Delta \kappa_{\mu}` is the step size between adjacent points
    along direction :math:`\mu`. The result is a matrix-valued, non-Abelian
    connection defined over the full reduced mesh
    :math:`\boldsymbol{\kappa} = (k_1,\dots,k_{d_k};\,\lambda_1,\dots,\lambda_{d_\lambda})`.
    The final two array axes span the band subspace specified by
    ``state_idx``.

    .. versionadded:: 2.0.0

    Parameters
    ----------
    axis_idx : int or array_like of int or None, optional
        Mesh directions :math:`\mu` along which to compute the connection.
        If None, all mesh axes are used. Default is None.
    state_idx : int or array_like of int or None, optional
        Subspace (band indices) to use. If None, all bands are used.
        Default is None.
    return_unitaries : bool, optional
        If True, also return the link unitaries :math:`U_{\mu}`.
        Default is False.
    cartesian : bool, optional
        If False (default), :math:`\Delta \kappa_\mu` is the first non-zero
        component of the reduced-coordinate step along axis :math:`\mu`
        (signed). If True, :math:`\Delta \kappa_\mu` is the Euclidean norm of
        the step vector, with its k-components converted to Cartesian units
        using the reciprocal lattice vectors and its :math:`\lambda`
        components left unchanged.

    Returns
    -------
    A : ndarray
        Non-Abelian connection with shape ``(n_mu, *mesh_shape, nstate, nstate)``,
        where ``n_mu = len(axis_idx)`` (or ``WFArray.naxes`` if
        ``axis_idx=None``) and ``nstate = len(state_idx)`` (or
        ``WFArray.nstates`` if ``state_idx=None``).
    U : ndarray, optional
        The link unitaries, with the same shape. Returned only if
        ``return_unitaries=True``.

    Raises
    ------
    ValueError
        If no mesh component varies along one of the requested axes, so no
        step size can be determined.

    See Also
    --------
    :meth:`links`, :meth:`berry_phase`, :meth:`wilson_loop`

    Notes
    -----
    - The connection is Hermitian, :math:`A_\mu^\dagger = A_\mu`.
    - The matrix logarithm is computed from the eigendecomposition
      :math:`U = V \, \text{diag}(e^{i\theta_n}) \, V^{-1}`, with principal
      phases :math:`\theta_n` in :math:`(-\pi, \pi]`. The explicit inverse is
      used because the eigenvector matrix returned for degenerate
      eigenvalues need not be unitary.
    - Entries of the connection at invalid boundaries (where the links are
      NaN; see :meth:`links`) are NaN.
    """
    # Link unitaries U_mu: (n_mu, ..., nstate, nstate), NaN at invalid boundaries
    U = self.links(state_idx=state_idx, axis_idx=axis_idx)

    if axis_idx is None:
        axis_idx = np.arange(self.naxes, dtype=int)
    else:
        axis_idx = np.atleast_1d(axis_idx)

    dim_k = self.mesh.dim_k
    dim_tot = dim_k + self.mesh.dim_lambda

    # Step size per mu
    step_list = []
    for ax in axis_idx:
        # Full-length step vector with one slot per (k, lambda) component.
        # Components that do not vary stay zero, so delta_vec[:dim_k] really
        # is the k-part and delta_vec[dim_k:] the lambda-part.
        delta_vec = np.zeros(dim_tot)
        for comp in range(dim_tot):
            arr = np.asarray(self.mesh.get_axis_range(ax, comp))
            if arr.size < 2:
                continue
            diff = arr[1] - arr[0]
            if not np.isclose(diff, 0.0):
                delta_vec[comp] = diff
        nonzero = np.flatnonzero(delta_vec)
        if nonzero.size == 0:
            raise ValueError(
                f"Could not determine step size along axis {ax} for Berry connection."
            )
        if cartesian and dim_k:
            cart = delta_vec[:dim_k] @ self.lattice.recip_lat_vecs  # Cartesian k-step
            combined = np.concatenate([cart, delta_vec[dim_k:]])
            step_list.append(np.linalg.norm(combined))
        else:
            # reduced step (first varying component)
            step_list.append(delta_vec[nonzero[0]])
    dk = np.asarray(step_list, dtype=float)

    # A_mu = -(1 / (i dk)) log(U_mu), with log U = V diag(i theta) V^{-1}
    A = np.empty_like(U, dtype=complex)
    for i_mu, dki in enumerate(dk):
        Ui = U[i_mu]
        nB = Ui.shape[-1]
        Ui_flat = Ui.reshape((-1, nB, nB))
        # Boundary links are NaN: keep A NaN there and skip the decomposition.
        valid = ~np.isnan(Ui_flat).any(axis=(-2, -1))
        Ai_flat = np.full(Ui_flat.shape, np.nan + 0j, dtype=complex)
        if valid.any():
            w, V = np.linalg.eig(Ui_flat[valid])  # batched over valid points
            theta = np.angle(w)  # principal phases in (-pi, pi]
            logU = (V * (1j * theta)[..., None, :]) @ np.linalg.inv(V)
            Ai_flat[valid] = -logU / (1j * dki)
        A[i_mu] = Ai_flat.reshape(Ui.shape)

    return (A, U) if return_unitaries else A
def berry_phase(
    self,
    axis_idx: int,
    state_idx: list[int] = None,
    berry_evals: bool = False,
    contin: bool = True,
):
    r"""Berry phase accumulated along a specified mesh axis.

    This routine evaluates the geometric phase accumulated by a set of states
    transported along a closed loop in parameter space. The phase is computed
    from the Wilson loop unitary matrix. Explicitly,

    .. math::

        \phi_{\mu} = -\text{Im} \ln \det W_{\mu}

    where :math:`W_{\mu}` (obtained as in :meth:`wilson_loop`) is the ordered
    product of the unitary (overlap) links along the ``axis_idx`` direction
    :math:`\mu` in the mesh.

    If ``berry_evals=True``, the function instead returns the *individual*
    multiband phases ("hybrid Wannier centers"). These are determined from
    the phases of the eigenvalues of the Wilson loop unitary matrix,

    .. math::

        \phi_{\mu}^{(n)} = -\text{Im} \ln \lambda_{\mu}^{(n)}

    When summed modulo :math:`2\pi`, the :math:`\phi_{\mu}^{(n)}` reproduce
    the total Berry phase.

    Parameters
    ----------
    axis_idx : int
        Index of the ``Mesh`` axis along which the Berry phase is computed.
        For a one-dimensional ``Mesh``, this must be 0.

        .. versionchanged:: 2.0.0
            Renamed from `dir` to `axis_idx` to avoid conflict with the
            built-in Python function `dir()`.
    state_idx : int or array-like of int, optional
        Band index or list of indices specifying which states to include in
        the Berry phase calculation. If omitted, all bands are included.

        .. versionchanged:: 2.0.0
            Renamed from ``occ``. The band indices are not required to be
            occupied bands only. The default behavior is to include all
            bands, and the ``"all"`` option has been removed.
    berry_evals : bool, optional
        If True, return the Wilson loop eigen-phases (hybrid Wannier centers)
        instead of the total phase. These follow the same branch rules as
        described for ``contin``.
    contin : bool, optional
        Controls branch continuity of the returned Berry phases. If True
        (default), phases are chosen to vary smoothly along the orthogonal
        mesh directions. The reference phase (first string) lies in
        :math:`[-\pi, \pi]`. Subsequent phases are shifted by multiples of
        :math:`2 \pi` as needed to ensure continuity with the previous
        string. If False, all phases lie in :math:`[-\pi, \pi]`.

    Returns
    -------
    phase : np.ndarray
        Total Berry phase along ``axis_idx`` (when ``berry_evals=False``).
        This is a 0-d array when the ``Mesh`` has a single axis. Otherwise it
        is an array with one dimension fewer than the ``Mesh``, giving the
        Berry phase for each remaining ``Mesh`` point. For example, if the
        ``Mesh`` has axes ``[i, j, k]`` and ``axis_idx=1``, the returned
        array has shape ``[i, k]``.
    evals : np.ndarray
        Returned instead of ``phase`` if ``berry_evals=True``. Contains the
        Wilson loop eigen-phases along ``axis_idx``. It follows the same
        indexing convention as above, with an additional trailing index
        labeling the phases :math:`\phi_n`.

    Raises
    ------
    ValueError
        If ``axis_idx`` is out of range. Also raised if ``contin=True`` and
        the mesh has more than two axes transverse to ``axis_idx``.

    See Also
    --------
    :meth:`wilson_loop`
    :ref:`formalism` : Sec. 4.5 for the discretized formula used to compute Berry phase.
    :ref:`haldane-bp-nb` : For an example
    :ref:`cone-nb` : For an example
    :ref:`three-site-thouless-nb` : For an example

    Notes
    -----
    - ``berry_evals`` are distinct from the Wilson loop eigenvalues returned
      by :meth:`wilson_loop` with ``wilson_evals=True``. The ``berry_evals``
      are the **phase** arguments of ``wilson_evals``, not the eigenvalues
      themselves.
    - For a single 1D string, the Berry phase is always returned in
      :math:`[-\pi,\pi]`.
    - For multidimensional meshes, the phase is computed for each 1D string
      obtained by slicing along ``axis_idx``. The continuity treatment
      depends on ``contin``.
    - A manifold specified by ``state_idx`` should be isolated from other
      states (no degeneracies with states outside the subset). Ensuring this
      is the user's responsibility.
    - For an array of size ``N`` along ``axis_idx``, the Wilson loop is
      formed from the ``N-1`` nearest-neighbor inner products. This gives an
      open-path phase unless the endpoints correspond to the same physical
      Hamiltonian. In that case the first state is appended to the end (if
      the endpoints are not already identical) to close the loop, with the
      appropriate PBC phase along k-axes.

    Examples
    --------
    Compute Berry phases along the second mesh axis for the three lowest
    bands. If ``wf`` has axes ``[i, j, k]``, then ``phase[i, k]`` is the
    result along the string ``wf[i, :, k]``.

    >>> phase = wf.berry_phase(axis_idx=1, state_idx=[0, 1, 2])
    """
    if (
        not isinstance(axis_idx, (int, np.integer))
        or axis_idx < 0
        or axis_idx >= self.naxes
    ):
        raise ValueError(f"axis_idx must be an integer in [0, {self.naxes - 1}]")

    # States, optionally restricted to a subspace.
    u = self.states(state_idx=state_idx, flatten_spin_axis=True)
    u_expanded = self.states(state_idx=state_idx, flatten_spin_axis=False)

    u_loop = u
    for ax, comp in self.mesh._get_loop_ax_comp():
        # A periodic but open axis needs its first state appended at the end.
        if ax != axis_idx or self.mesh.is_axis_closed(ax, comp):
            continue
        if self.mesh.is_axis_bz_winding(ax, comp):
            logger.debug(
                "Applying phase to state at beginning to end of open periodic axis."
            )
            # mesh k-component -> real-space lattice direction
            real_comp = self.lattice.periodic_dirs[comp]
            phase, _, _ = self._get_pbc_phases(ax, real_comp)
            u_last = np.take(u_expanded, 0, axis=axis_idx) * phase
        else:
            u_last = np.take(u_expanded, 0, axis=axis_idx)
        if self.nspin == 2:
            # flatten (orbital, spin) -> orbital*spin
            u_last = u_last.reshape(*u_last.shape[:-2], -1)
        logger.debug("Appending state at beginning to end of open periodic axis.")
        u_loop = np.concatenate(
            [u_loop, np.expand_dims(u_last, axis=axis_idx)], axis=axis_idx
        )

    # Loop axis first, so each transverse index selects one 1D string.
    u_loop = np.moveaxis(u_loop, axis_idx, 0)
    tail_shape = u_loop.shape[1:-2]  # transverse axes
    n_sub = u_loop.shape[-2]  # number of selected states

    out_shape = (*tail_shape, n_sub) if berry_evals else tail_shape
    out = np.empty(out_shape, dtype=float)

    for idx in np.ndindex(*tail_shape) if tail_shape else [()]:
        # shape: (n_mu or n_mu + 1, n_sub, norb * nspin)
        wf_line = u_loop[(slice(None),) + idx]
        if berry_evals:
            # individual (sorted) phases of the Wilson-loop eigenvalues
            _, val = self._berry_loop(wf_line, berry_evals=True)
        else:
            # total Berry phase of the string
            val = self._berry_loop(wf_line, berry_evals=False)
        out[idx] = val

    # Branch continuity across transverse axes (nothing to do for one string).
    if contin and tail_shape:
        make_cont = _array_phases_cont if berry_evals else _one_phase_cont
        if len(tail_shape) == 1:
            out = make_cont(out, out[0])
        elif len(tail_shape) == 2:
            for i in range(out.shape[1]):
                clos = out[0, 0] if i == 0 else out[0, i - 1]
                out[:, i] = make_cont(out[:, i], clos)
        else:
            raise ValueError(
                "Wrong dimensionality! contin=True supports at most two mesh "
                f"axes transverse to axis_idx, got {len(tail_shape)}; "
                "use contin=False."
            )
    return out
def berry_flux(
    self,
    plane: ArrayLike | None = None,
    state_idx: int | ArrayLike | None = None,
    non_abelian: bool = False,
    *,
    use_tensorflow: bool = False,
):
    r"""Berry flux tensor using the Fukui-Hatsugai-Suzuki plaquette method.

    The Berry flux tensor :math:`\mathcal{F}_{\mu\nu}(\boldsymbol{\kappa})`
    at a reduced mesh point :math:`\boldsymbol{\kappa}` is evaluated by
    forming the ordered product of parallel-transport link unitaries around
    an elementary plaquette in directions :math:`(\mu,\nu)` [1]_:

    .. math::

        \mathcal{F}_{\mu\nu}(\boldsymbol{\kappa})
        = -\,\operatorname{Im}\,
        \ln\!\Big[
        U_{\mu}(\boldsymbol{\kappa})\,
        U_{\nu}(\boldsymbol{\kappa}+\hat{\mu})\,
        U_{\mu}^{\dagger}(\boldsymbol{\kappa}+\hat{\nu})\,
        U_{\nu}^{\dagger}(\boldsymbol{\kappa})
        \Big],

    where :math:`U_{\mu}(\boldsymbol{\kappa})` is the unitary link obtained
    from the overlap between the states at :math:`\boldsymbol{\kappa}` and
    its forward neighbor along direction :math:`\mu` (see :meth:`links`).
    When a multiband subspace is supplied, this expression yields the
    **non-Abelian** matrix-valued flux. Taking the matrix determinant gives
    the usual **Abelian** (band-traced) flux.

    .. versionremoved:: 2.0.0
        The `individual_phases` parameter has been removed.

    Parameters
    ----------
    plane : (2,) array_like, optional
        Array or tuple of two mesh-axis indices over which the flux is
        computed. By default, all index pairs are evaluated.

        .. versionchanged:: 2.0.0
            Renamed from ``dirs``.
    state_idx : int or array_like of int, optional
        Indices of the states spanning the subspace. If None, all states are
        used.

        .. versionchanged:: 2.0.0
            Renamed from ``occ``. The band indices are not required to be
            occupied bands only. The default behavior is to include all
            bands, and the ``"all"`` option has been removed.
    non_abelian : bool, optional
        If *True*, return the matrix-valued non-Abelian flux. If *False*
        (default), return the Abelian (band-traced) flux.

        .. versionadded:: 2.0.0
    use_tensorflow : bool, optional
        If *True*, use TensorFlow for the plaquette products and eigen
        decompositions. Default is *False*.

    Returns
    -------
    flux : ndarray
        The Berry flux tensor. Its shape is

        - ``non_abelian=True``, ``plane=None``:
          ``(naxes, naxes, *mesh_shape, nstates, nstates)``,
        - ``non_abelian=True``, ``plane=(mu, nu)``:
          ``(*mesh_shape, nstates, nstates)``,
        - ``non_abelian=False``, ``plane=None``:
          ``(naxes, naxes, *mesh_shape)``,
        - ``non_abelian=False``, ``plane=(mu, nu)``:
          ``(*mesh_shape,)``,

        where ``nstates = len(state_idx)`` (or ``WFArray.nstates`` if
        ``state_idx=None``). Here ``mesh_shape`` is the mesh shape with the
        last point removed along every trimmed axis (see Notes). The Abelian
        flux is real-valued; the non-Abelian flux is complex-valued.

    Raises
    ------
    ValueError
        If the mesh has fewer than two axes, or ``plane`` is malformed.
    TypeError
        If ``plane`` is not None, a list, a tuple, or a numpy array.
    ImportError
        If ``use_tensorflow=True`` and TensorFlow is not installed.

    Notes
    -----
    - For a given :math:`(\mathbf{k}, \boldsymbol{\lambda})`-point
      :math:`\boldsymbol{\kappa}` and pair of mesh directions
      :math:`(\mu, \nu)`, the plaquette is formed by the points

      .. math::

          \begin{pmatrix}
          \boldsymbol{\kappa} \\
          \boldsymbol{\kappa} + \hat{\mu} \\
          \boldsymbol{\kappa} + \hat{\mu} + \hat{\nu} \\
          \boldsymbol{\kappa} + \hat{\nu}
          \end{pmatrix}

      where :math:`\hat{\mu}` and :math:`\hat{\nu}` are the vectors
      connecting neighboring points along directions :math:`\mu` and
      :math:`\nu` in the reduced mesh.
    - Let :math:`U_{\mu}(\boldsymbol{\kappa})` denote the unitary **link
      matrix** (unitary part of the overlap matrix between states) from
      :math:`\boldsymbol{\kappa}` to :math:`\boldsymbol{\kappa} + \hat{\mu}`:

      .. math::

          U_{\mu}(\boldsymbol{\kappa}) \equiv
          \mathcal{U} \,\Big[ M_\mu (\boldsymbol{\kappa}) \Big]

      where :math:`\mathcal{U}` denotes the unitary part (polar
      decomposition, see :meth:`links`), and
      :math:`M_\mu (\boldsymbol{\kappa})_{mn} = \langle u_{m}(\boldsymbol{\kappa})
      \,|\, u_{n}(\boldsymbol{\kappa} + \hat{\mu}) \rangle` is the overlap
      matrix of the states in the subspace defined by ``state_idx``.

      When ``non_abelian=True``, the flux is the matrix
      :math:`V\,\mathrm{diag}(-\theta_j)\,V^{-1}`, where
      :math:`e^{i\theta_j}` and :math:`V` are the eigenvalues and
      eigenvectors of the plaquette product

      .. math::

          W = U_{\mu}(\boldsymbol{\kappa}) \;
          U_{\nu}(\boldsymbol{\kappa} + \hat{\mu}) \;
          U_{\mu}^\dagger(\boldsymbol{\kappa} + \hat{\nu}) \;
          U_{\nu}^\dagger(\boldsymbol{\kappa}).

      This definition holds for multiband subspaces, where the link matrices
      are square and unitary in the selected-band space.

      When ``non_abelian=False``, the (Abelian) flux is
      :math:`-\arg \det W`, i.e. the band trace of the non-Abelian flux
      modulo :math:`2\pi`.
    - This method requires at least a two-dimensional mesh in the combined
      adiabatic space (momentum + adiabatic parameters), i.e.
      ``WFArray.naxes >= 2``.
    - The last point along closed or non-periodic axes is trimmed from the
      flux array. On a closed axis that point duplicates the first one; on a
      non-periodic axis it has no forward neighbor.
    - With ``plane=None``, ``flux[nu, mu] = -flux[mu, nu]`` and the diagonal
      entries are zero.

    Examples
    --------
    Consider a mesh with ``naxes`` axes in the combined adiabatic space
    (momentum + adiabatic parameters). The Berry flux can be computed over
    all 2D planes of the mesh,

    >>> flux = wf.berry_flux(
    ...     state_idx = [0, 1, 2]
    ... ) # shape: (naxes, naxes, *mesh_shape)

    or over a specific plane,

    >>> flux = wf.berry_flux(
    ...     plane = (0, 1),
    ...     state_idx = [0, 1, 2]
    ... ) # shape: (*mesh_shape)

    The non-Abelian Berry flux tensor can also be computed over a specific
    plane:

    >>> flux = wf.berry_flux(
    ...     plane = (0,1),
    ...     state_idx = [0, 1, 2],
    ...     non_abelian = True
    ... ) # shape: (*mesh_shape, nstates, nstates)

    References
    ----------
    .. [1] \ T. Fukui, Y. Hatsugai, and H. Suzuki, *J. Phys. Soc. Jpn.* **74**, 1674 (2005).
    """
    if self.naxes < 2:
        raise ValueError(
            "Berry curvature only defined if number of mesh axes >= 2."
        )

    # Validate plane
    ndims = self.naxes  # total dimensionality of adiabatic space
    if plane is not None:
        if not isinstance(plane, (list, tuple, np.ndarray)):
            raise TypeError("plane must be None, a list, tuple, or numpy array.")
        if len(plane) != 2:
            raise ValueError("plane must contain exactly two directions.")
        if any(p < 0 or p >= ndims for p in plane):
            raise ValueError(f"Plane indices must be between 0 and {ndims - 1}.")
        if plane[0] == plane[1]:
            raise ValueError("Plane indices must be different.")

    # Import TensorFlow once, up front, instead of on every plaquette pair.
    tf = None
    if use_tensorflow:
        try:
            import tensorflow as tf
        except ImportError as err:
            raise ImportError(
                "TensorFlow is not installed. Please install it or set use_tensorflow=False."
            ) from err

    state_idx = self._normalize_state_indices(state_idx)
    n_states = len(state_idx)
    mesh_shape = tuple(self.shape_mesh)  # (nk1, nk2, ..., nkd)

    dirs = list(range(ndims)) if plane is None else [plane[0], plane[1]]
    n_dirs = len(dirs)

    # Axes whose last point must be dropped: closed axes (endpoint duplicates
    # the first point) and non-periodic axes (no forward neighbor there).
    trim_axes = []
    for ax in dirs:
        if self.mesh.is_axis_closed(ax) or (
            not self.mesh.is_axis_looped(ax)
            and not self.mesh.is_axis_bz_winding(ax)
        ):
            logger.debug(
                f"Axis {ax} is closed or non-periodic. "
                "Trimming the last point in the flux array to avoid overcounting."
            )
            trim_axes.append(ax)
    trim = tuple(
        slice(0, -1) if ax in trim_axes else slice(None)
        for ax in range(len(mesh_shape))
    )

    flux_dtype = complex if non_abelian else float
    if plane is None:
        trimmed_shape = tuple(
            n - 1 if ax in trim_axes else n for ax, n in enumerate(mesh_shape)
        )
        matrix_shape = (n_states, n_states) if non_abelian else ()
        berry_flux = np.zeros(
            (ndims, ndims, *trimmed_shape, *matrix_shape), dtype=flux_dtype
        )

    # U_forward[i]: unitary part of <u_k | u_{k + delta k_{dirs[i]}}>
    U_forward = self.links(state_idx=state_idx, axis_idx=dirs)

    for mu in range(n_dirs):
        for nu in range(mu + 1, n_dirs):
            # U_forward follows the order of dirs (e.g. dirs = [p, q]).
            axis_mu, axis_nu = dirs[mu], dirs[nu]
            U_mu = U_forward[mu]
            U_nu = U_forward[nu]
            U_nu_shift_mu = np.roll(U_nu, -1, axis=axis_mu)
            U_mu_shift_nu = np.roll(U_mu, -1, axis=axis_nu)

            # Plaquette: W = U_mu(k) U_nu(k+mu) U_mu(k+nu)^dagger U_nu(k)^dagger
            if tf is not None:
                U_wilson = tf.linalg.matmul(
                    tf.linalg.matmul(
                        tf.linalg.matmul(
                            tf.convert_to_tensor(U_mu),
                            tf.convert_to_tensor(U_nu_shift_mu),
                        ),
                        tf.linalg.adjoint(tf.convert_to_tensor(U_mu_shift_nu)),
                    ),
                    tf.linalg.adjoint(tf.convert_to_tensor(U_nu)),
                ).numpy()
            else:
                U_wilson = (
                    U_mu
                    @ U_nu_shift_mu
                    @ U_mu_shift_nu.conj().swapaxes(-1, -2)
                    @ U_nu.conj().swapaxes(-1, -2)
                )

            # Drop the plaquettes anchored at the last point of trimmed axes.
            U_wilson = U_wilson[trim]

            if non_abelian:
                # F = V diag(-theta_j) V^{-1}, with U_wilson = V diag(e^{i theta_j}) V^{-1}
                if tf is not None:
                    eigvals, eigvecs = tf.linalg.eig(tf.convert_to_tensor(U_wilson))
                    eigvals = eigvals.numpy()
                    eigvecs = eigvecs.numpy()
                else:
                    eigvals, eigvecs = np.linalg.eig(U_wilson)
                phi = -np.angle(eigvals)
                phases_plane = (eigvecs * phi[..., None, :]) @ np.linalg.inv(eigvecs)
            else:
                phases_plane = -np.angle(np.linalg.det(U_wilson))

            if plane is None:
                berry_flux[mu, nu] = phases_plane
                berry_flux[nu, mu] = -phases_plane
            else:
                berry_flux = phases_plane

    return berry_flux
def berry_curvature(
self,
plane: ArrayLike = None,
state_idx: ArrayLike | int = None,
non_abelian: bool = False,
return_flux: bool = False,
):
r"""Berry curvature tensor using the Fukui-Hatsugai-Suzuki plaquette method.
The Berry curvature tensor :math:`\Omega_{\mu\nu}(\mathbf{k})`
is computed using a discretized formula based on the
Fukui-Hatsugai-Suzuki (FHS) plaquette-based method [1]_. The curvature
is approximated from the Berry flux (computed in :meth:`berry_flux`)
by dividing the flux by the (presumed uniform) area of the plaquette in
parameter space,
.. math::
\Omega_{\mu\nu}(\mathbf{k}) \approx
\frac{\mathcal{F}_{\mu\nu}(\mathbf{k})}{A_{\mu\nu}},
where :math:`A_{\mu\nu}` is the area (in Cartesian units) of the
plaquette in parameter space.
The tensor is either defined is a matrix-valued quantity (non-Abelian case)
or as a scalar quantity obtained by tracing over the band indices
(Abelian case).
.. versionadded:: 2.0.0
Parameters
----------
plane : (2,) array_like, optional
Array or tuple of two indices defining the axes in the
WFArray mesh which the Berry flux is computed over. By default,
all directions are considered, and the full Berry flux tensor is
returned.
state_idx : int or array_like of int, optional
Optional index or array of indices defining the states to be included
in the subsequent calculations, typically the indices of
bands considered occupied. If not specified, or None, all bands are
included.
non_abelian : bool, optional
If *True* then the non-Abelian Berry flux tensor is computed defined
as a matrix-valued quantity.
If *False* (default) then the Berry flux is computed as a scalar quantity
by tracing over the band indices.
return_flux : bool, optional
If *True*, the function returns a tuple containing both the Berry
curvature and the Berry flux tensors. If *False* (default), only the
Berry curvature tensor is returned.
Returns
-------
berry_curv : np.ndarray
Berry curvature tensor with shape depending on input parameters.
Shape is
- ``non_abelian=True``, ``plane=None``:
``(naxes, naxes, *mesh_shape, nstates, nstates)``,
- ``non_abelian=True``, ``plane=(mu, nu)``:
``(*mesh_shape, nstates, nstates)``,
- ``non_abelian=False``, ``plane=None``:
``(naxes, naxes, *mesh_shape)``,
- ``non_abelian=False``, ``plane=(mu, nu)``:
``(*mesh_shape,)``,
berry_flux : np.ndarray, optional
Berry flux tensor with shape depending on input parameters.
Returned only if ``return_flux=True``. Shape is same as that of
``berry_curv``.
See Also
--------
:meth:`berry_flux` : For details and formalism on the Berry flux tensor.
Notes
-----
- The method requires at least a two-dimensional mesh in the
combined adiabatic space (momentum + adiabatic parameters). Thus,
``WFArray.naxes >= 2``.
- The last point along closed or non-periodic axes is trimmed from the curvature
array to avoid overcounting, since it is equivalent to the first point.
References
----------
.. [1] \ T. Fukui, Y. Hatsugai, and H. Suzuki, *J. Phys. Soc. Jpn.* **74**, 1674 (2005).
"""
ndims = self.naxes # Total dimensionality of adiabatic space: d
if plane is None:
dirs = list(range(ndims))
else:
p, q = plane # Unpack plane directions
dirs = [p, q]
berry_flux = self.berry_flux(
plane=plane, state_idx=state_idx, non_abelian=non_abelian
)
berry_curv = np.zeros_like(berry_flux, dtype=complex)
coords = self.mesh.points # (..., dim_k + dim_lam)
dim_total = coords.shape[-1]
dim_k = self.lattice.dim_k
recip = np.asarray(self.lattice.recip_lat_vecs, float) if dim_k else None
# Collect the physical step vectors for each sampling axis
axis_vecs = []
for ax in range(ndims):
delta = np.zeros(dim_total, dtype=float)
for comp in range(dim_total):
arr = self.mesh.get_axis_range(ax, comp)
if arr.size >= 2:
diff = arr[1] - arr[0]
if not np.isclose(diff, 0.0):
delta[comp] = diff
if not np.any(delta):
raise ValueError(
f"Cannot compute Berry curvature: "
f"Mesh axis {ax} has zero length in all parameter directions."
)
if dim_k:
k_cart = delta[:dim_k] @ recip
vec = np.concatenate([k_cart, delta[dim_k:]])
else:
vec = delta[dim_k:]
axis_vecs.append(vec)
# Divide by area elements for the (mu, nu)-plane
if plane is not None: # first two axes absent
mu, nu = p, q
A = np.vstack([axis_vecs[mu], axis_vecs[nu]])
area = np.sqrt(np.linalg.det(A @ A.T))
berry_curv = berry_flux / area
else:
for i, mu in enumerate(dirs):
for j in range(i + 1, len(dirs)):
nu = dirs[j]
A = np.vstack([axis_vecs[mu], axis_vecs[nu]])
area = np.sqrt(np.linalg.det(A @ A.T))
# Divide flux by the area element to get approx curvature
berry_curv[mu, nu] = berry_flux[mu, nu] / area
berry_curv[nu, mu] = berry_flux[nu, mu] / area
return (berry_curv, berry_flux) if return_flux else berry_curv
[docs]
def chern_number(self, plane=(0, 1), state_idx=None):
    r"""Computes the Chern number in the specified plane.

    The Chern number is computed as the integral of the Berry flux
    over the specified plane, divided by :math:`2 \pi`.

    .. math::

        C = \frac{1}{2\pi} \sum_{\mathbf{k}_{\mu}, \mathbf{k}_{\nu}} F_{\mu\nu}(\mathbf{k}).

    The plane :math:`(\mu, \nu)` is specified by `plane`, a pair of two indices.

    .. versionadded:: 2.0.0

    Parameters
    ----------
    plane : tuple or array-like of two ints, optional
        Two indices specifying the plane in which the Chern number is computed.
        The indices should be between 0 and the number of mesh dimensions minus 1.
        If None, the Chern number is computed for the first two dimensions of
        the mesh, i.e. ``(0, 1)``. Lists and arrays are accepted as well as tuples.
    state_idx : array-like, optional
        Indices of states to be included in the Chern number calculation.
        If None, all states are included. None by default.

    Returns
    -------
    chern : np.ndarray, float
        In the two-dimensional case, the result
        will be a floating point approximation of the integer Chern number
        for that plane. In a higher-dimensional space, the Chern number
        is computed for each 2D slice of the higher-dimensional space.
        E.g., the shape of the returned array is `(nk3, ..., nkd)` if the plane is
        `(0, 1)`, where `(nk3, ..., nkd)` are the sizes of the mesh in the remaining
        dimensions.

    See Also
    --------
    :meth:`berry_flux` : For details and formalism on the Berry flux tensor.

    Notes
    -----
    - The Chern number gives an integer value in the
      limit of an infinitely dense mesh only when the plane
      forms a closed manifold.
    - The method requires at least a two-dimensional mesh in the
      combined adiabatic space (momentum + adiabatic parameters). Thus,
      ``WFArray.naxes >= 2``.
    - The last point along closed or non-periodic axes is trimmed from the Chern
      number array to avoid overcounting, since it is equivalent to the first point.

    Examples
    --------
    Suppose we have a `WFArray` mesh in three-dimensional space
    of shape `(nk1, nk2, nk3)`. We can compute the Chern number for the
    `(0, 1)` plane as follows; the result has one entry per point of the
    remaining axis:

    >>> wfs = WFArray(model, [10, 11, 12])
    >>> wfs.solve_on_grid()
    >>> chern = wfs.chern_number(plane=(0, 1), state_idx=np.arange(n_occ))
    >>> print(chern.shape)
    (12,)
    """
    # ``None`` is documented as "first two mesh axes". Passing it through would
    # return the full antisymmetric flux tensor, whose total sum vanishes.
    if plane is None:
        plane = (0, 1)
    # ``np.sum`` accepts only an int or a *tuple* of ints for ``axis``.
    plane = tuple(int(ax) for ax in plane)

    # shape of the Berry flux array: (nk1, nk2, ..., nkd)
    berry_flux = self.berry_flux(
        state_idx=state_idx, plane=plane, non_abelian=False
    )
    # shape of chern (if plane is (0,1)): (nk3, ..., nkd)
    chern = np.sum(berry_flux, axis=plane) / (2 * np.pi)
    return chern
[docs]
def position_matrix(
    self, pos_dir: int, mesh_idx: list[int], state_idx: list[int] = None
):
    r"""Position matrix for a given mesh point and set of states.

    Position operator is defined in reduced coordinates.
    The returned object :math:`X` is

    .. math::

        X_{m n {\bf k}}^{\alpha} = \langle u_{m {\bf k}} \vert
        r^{\alpha} \vert u_{n {\bf k}} \rangle

    Here :math:`r^{\alpha}` is the position operator along direction
    :math:`\alpha` that is selected by `pos_dir`. It is taken to be diagonal
    in the orbital basis: each orbital contributes its coordinate
    ``lattice.orb_vecs[:, pos_dir]``, and in the spinful case both spin
    components of an orbital share that coordinate.

    This routine can be used to compute the position matrix for a
    given mesh point and set of states (which can be all states, or
    a specific subset).

    Parameters
    ----------
    pos_dir: int
        Index of the coordinate component (column of the orbital position
        array) along which the position operator acts. Must satisfy
        ``0 <= pos_dir < lattice.dim_r`` and must not be one of
        ``lattice.periodic_dirs``, since the position operator is not well
        defined along a periodic direction.

        .. versionchanged:: 2.0.0
            Renamed from ``dir`` to ``pos_dir`` to avoid conflict with built-in Python function `dir()`.
    mesh_idx: int or array-like of int
        Set of integers specifying the :math:`(k, \lambda)`-point of interest in
        the mesh. A single integer is accepted for one-axis meshes.
    state_idx: array-like, optional
        List of states to be included. If not specified, all states are included.

        .. versionchanged:: 2.0.0
            Renamed from ``occ``. The band indices are not required to be
            occupied bands only. The default behavior is to include all bands,
            and the ``"all"`` option has been removed.

    Returns
    -------
    pos_mat : np.ndarray
        Complex position operator matrix :math:`X_{m n}` as defined above.
        This is a square matrix whose size equals the number of selected
        states. First index of `pos_mat` corresponds to the bra vector
        (:math:`m`) and second index to the ket (:math:`n`).

    Raises
    ------
    TypeError
        If ``mesh_idx`` is not an int, list, tuple or numpy array, or the
        selected states are not a numpy array.
    ValueError
        If ``mesh_idx`` has the wrong length, ``pos_dir`` is periodic or out of
        range, the selected states have the wrong number of dimensions, or
        the resulting matrix is not Hermitian.

    See Also
    --------
    :func:`pythtb.TBModel.position_matrix`

    Notes
    -----
    The only difference in :func:`pythtb.TBModel.position_matrix` is that,
    in addition to specifying ``pos_dir``, one also has to specify ``mesh_idx``
    (mesh-point of interest) and ``state_idx``
    (list of states to be included; all states by default).
    """
    if isinstance(mesh_idx, (list, np.ndarray, tuple)):
        mesh_idx = tuple(mesh_idx)
    elif isinstance(mesh_idx, (int, np.integer)):
        mesh_idx = (mesh_idx,)
    else:
        raise TypeError(
            "mesh_idx must be a list, numpy array, or tuple defining "
            "k-point indices of interest."
        )
    if len(mesh_idx) != self.naxes:
        raise ValueError(
            f"mesh_idx must have length {self.naxes} corresponding to "
            "number of mesh axes."
        )

    # Validate the direction before touching the wavefunction storage.
    if pos_dir in self.lattice.periodic_dirs:
        raise ValueError(
            "Can not compute position matrix elements along periodic direction!"
        )
    if pos_dir < 0 or pos_dir >= self.lattice.dim_r:
        raise ValueError("Direction out of range!")

    state_idx = self._normalize_state_indices(state_idx)
    evec = self.wfs[mesh_idx][state_idx]

    if not isinstance(evec, np.ndarray):
        raise TypeError("evec must be a numpy array.")
    if self.nspin == 1:
        if evec.ndim != 2:
            raise ValueError(
                "evec must be a 2D array with shape (band, orbital) for spinless models."
            )
    elif self.nspin == 2:
        if evec.ndim != 3:
            raise ValueError(
                "evec must be a 3D array with shape (band, orbital, spin) for spinful models."
            )

    # coordinates of orbitals along the specified direction
    pos_tmp = self.lattice.orb_vecs[:, pos_dir]
    if self.nspin == 2:
        # States are laid out as (orbital, spin): repeat each orbital
        # coordinate once per spin component -> [x0, x0, x1, x1, ...].
        pos_use = np.repeat(pos_tmp, 2)
        evec_use = evec.reshape((evec.shape[0], evec.shape[1] * evec.shape[2]))
    else:
        pos_use = pos_tmp
        evec_use = evec

    # X_mn = sum_i conj(u_m[i]) * x_i * u_n[i], evaluated for all (m, n) at once.
    pos_mat = np.asarray((evec_use.conj() * pos_use) @ evec_use.T, dtype=complex)

    # make sure matrix is Hermitian
    if not np.allclose(pos_mat, pos_mat.T.conj()):
        raise ValueError("Position matrix is not Hermitian.")
    return pos_mat
[docs]
def position_expectation(self, pos_dir: int, mesh_idx=None, state_idx=None):
    r"""Position expectation value for a given mesh point and set of states.

    These elements :math:`X_{n n}` can be interpreted as an
    average position of the n-th selected Bloch state along
    direction ``pos_dir``.

    This routine can be used to compute the position expectation value for a
    given mesh point and set of states (which can be all states, or
    a specific subset), or for every point of the mesh at once.

    Parameters
    ----------
    pos_dir: int
        Index of the coordinate component along which the position operator
        acts; see :meth:`position_matrix` for the constraints on its value.

        .. versionchanged:: 2.0.0
            Renamed from ``dir`` to ``pos_dir`` to avoid conflict with built-in Python function `dir()`.
    mesh_idx: array-like of int, optional
        Set of integers specifying the :math:`(k, \lambda)`-point of interest in the mesh.
        If not specified, position expectation values are computed for all mesh points.
    state_idx: array-like, optional
        List of states to be included. If not specified, all states are included.

        .. versionchanged:: 2.0.0
            Renamed from ``occ``. The band indices are not required to be
            occupied bands only. The default behavior is to include all bands,
            and the ``"all"`` option has been removed.

    Returns
    -------
    pos_exp : np.ndarray
        Real diagonal elements of the position operator matrix :math:`X`.
        Shape is ``(nsel,)`` when ``mesh_idx`` is given and
        ``(*shape_mesh, nsel)`` otherwise, where ``nsel`` is the number of
        selected states.

    See Also
    --------
    :func:`pythtb.TBModel.position_expectation`
    :ref:`haldane-hwf-nb` : For an example.
    position_matrix : For definition of matrix :math:`X`.

    Notes
    -----
    The only difference in :func:`pythtb.TBModel.position_expectation` is that,
    in addition to specifying ``pos_dir``, one also has to specify ``mesh_idx``
    (mesh-point of interest) and ``state_idx`` (list of states to be included).

    Generally speaking these centers are *not* hybrid Wannier function
    centers (which are instead returned by :func:`position_hwf`).
    """
    if mesh_idx is not None:
        pos_exp_mat = self.position_matrix(
            mesh_idx=mesh_idx, state_idx=state_idx, pos_dir=pos_dir
        ).diagonal()
        return np.array(np.real(pos_exp_mat), dtype=float)

    shape_mesh = tuple(self.shape_mesh)
    # Collect per-point diagonals first so the trailing axis is sized by the
    # number of *selected* states, not ``self.nstates`` (which is wrong for
    # any subset ``state_idx``).
    diagonals = [
        np.real(
            self.position_matrix(
                mesh_idx=idx, state_idx=state_idx, pos_dir=pos_dir
            ).diagonal()
        )
        for idx in np.ndindex(*shape_mesh)
    ]
    if not diagonals:
        # A mesh with a zero-length axis has no points to evaluate.
        return np.zeros((*shape_mesh, self.nstates), dtype=float)
    n_sel = len(diagonals[0])
    return np.asarray(diagonals, dtype=float).reshape((*shape_mesh, n_sel))
[docs]
def position_hwf(
    self,
    pos_dir,
    mesh_idx,
    state_idx=None,
    hwf_evec: bool = False,
    basis: str = "wavefunction",
):
    r"""Eigenvalues and eigenvectors of the position operator in a given basis.

    Parameters
    ----------
    pos_dir: int
        Direction along which to compute the position operator; see
        :meth:`position_matrix` for the constraints on its value.

        .. versionchanged:: 2.0.0
            Renamed from ``dir`` to ``pos_dir`` to avoid conflict with built-in Python function `dir()`.
    mesh_idx: int or array-like of int
        Set of integers specifying the index of interest in the mesh. A single
        integer is accepted for one-axis meshes, as in :meth:`position_matrix`.
    state_idx: array-like, optional
        List of states to be included. If not specified, all states are included.

        .. versionchanged:: 2.0.0
            Renamed from ``occ``. The band indices are not required to be
            occupied bands only. The default behavior is to include all bands,
            and the ``"all"`` option has been removed.
    hwf_evec: bool, optional
        Default is `False`. If `True`, return the eigenvectors along with eigenvalues
        of the position operator.
    basis: {"orbital", "wavefunction", "bloch"}, optional
        The basis in which to express the eigenvectors (case-insensitive,
        surrounding whitespace ignored). ``"wavefunction"`` and ``"bloch"``
        are synonyms.

    Returns
    -------
    hwfc : np.ndarray
        Eigenvalues of the position operator matrix :math:`X`
        (also called hybrid Wannier function centers).
        Length of this vector equals the number of selected states.
        Hybrid Wannier function centers are ordered in ascending order.
        Note that in general the `n`-th hwfc does not correspond to the `n`-th
        selected state.
    hwf : np.ndarray, optional
        Eigenvectors of the position operator matrix :math:`X`
        (also called hybrid Wannier functions); ``hwf[i]`` belongs to
        ``hwfc[i]``. These are returned only if parameter ``hwf_evec=True``.
        The shape of this array is ``[h,x]`` or ``[h,x,s]`` depending on value of
        ``basis`` and ``spinful``.

        - If ``basis = "bloch"`` (or ``"wavefunction"``) then ``x`` refers to
          indices of the selected Bloch states.
        - If ``basis = "orbital"`` then ``x`` (or ``x`` and ``s``)
          correspond to orbital index (or orbital and spin index
          if ``spinful=True``).

    Raises
    ------
    ValueError
        If ``basis`` is not one of the supported values (checked even when
        ``hwf_evec=False``).

    See Also
    --------
    :ref:`haldane-hwf-nb` : For an example.
    position_matrix : For the definition of the matrix :math:`X`.
    position_expectation : For the position expectation value.
    :func:`pythtb.TBModel.position_hwf`

    Notes
    -----
    Similar to :func:`pythtb.TBModel.position_hwf`, except that
    in addition to specifying ``pos_dir``, one also has to specify ``mesh_idx``
    (mesh-point of interest) and ``state_idx`` (list of states to be included).

    For backwards compatibility the default value of *basis* here is different
    from that in :func:`pythtb.TBModel.position_hwf`.
    """
    basis_key = basis.lower().strip()
    if basis_key not in ("wavefunction", "bloch", "orbital"):
        raise ValueError(
            "Basis must be either 'wavefunction', 'bloch', or 'orbital'"
        )

    state_idx = self._normalize_state_indices(state_idx)
    pos_mat = self.position_matrix(
        mesh_idx=mesh_idx, state_idx=state_idx, pos_dir=pos_dir
    )

    if not hwf_evec:
        return np.linalg.eigvalsh(pos_mat)

    hwfc, hwf = np.linalg.eigh(pos_mat)
    # transpose so hwf[i, :] is the eigenvector for the i-th eigenvalue
    hwf = hwf.T
    if basis_key in ("wavefunction", "bloch"):
        return hwfc, hwf

    # Orbital basis: accept an int mesh index like position_matrix does.
    if isinstance(mesh_idx, (int, np.integer)):
        mesh_key = (mesh_idx,)
    else:
        mesh_key = tuple(mesh_idx)
    evec = self.wfs[mesh_key][state_idx]
    # hwf[i] = sum_n hwf[i, n] * evec[n]; contracts over the state index and
    # keeps the trailing (orbital[, spin]) axes of evec.
    hwf = np.asarray(np.tensordot(hwf, evec, axes=1), dtype=complex)
    return hwfc, hwf
def _trace_metric(self):
P = self.projectors()
_, Q_nbr = self._nbr_projectors(return_Q=True)
nks = Q_nbr.shape[:-3]
num_nnbrs = Q_nbr.shape[-3]
w_b, _, _ = self.lattice.k_shell_weights(self.mesh.shape_k, n_shell=1)
T_kb = np.zeros((*nks, num_nnbrs), dtype=complex)
for nbr_idx in range(num_nnbrs): # nearest neighbors
T_kb[..., nbr_idx] = np.trace(
P[..., :, :] @ Q_nbr[..., nbr_idx, :, :], axis1=-1, axis2=-2
)
return w_b[0] * np.sum(T_kb, axis=-1)
def _omega_til(self):
Mmn = self._Mmn
w_b, k_shell, _ = self.lattice.k_shell_weights(self.mesh.shape_k, n_shell=1)
w_b = w_b[0]
k_shell = k_shell[0]
nks = Mmn.shape[:-3]
Nk = np.prod(nks)
k_axes = tuple([i for i in range(len(nks))])
diag_M = np.diagonal(Mmn, axis1=-1, axis2=-2)
log_diag_M_imag = np.log(diag_M).imag
abs_diag_M_sq = abs(diag_M) ** 2
r_n = -(1 / Nk) * w_b * np.sum(log_diag_M_imag, axis=k_axes).T @ k_shell
Omega_tilde = (
(1 / Nk)
* w_b
* (
np.sum((-log_diag_M_imag - k_shell @ r_n.T) ** 2)
+ np.sum(abs(Mmn) ** 2)
- np.sum(abs_diag_M_sq)
)
)
return Omega_tilde
def _no_2pi(phi, ref):
"""Shift phase phi by integer multiples of 2π so it is closest to ref."""
while abs(ref - phi) > np.pi:
if ref - phi > np.pi:
phi += 2.0 * np.pi
elif ref - phi < -1.0 * np.pi:
phi -= 2.0 * np.pi
return phi
def _array_phases_cont(arr_pha, clos):
    """Make a 2-D array of phases continuous along its first index.

    Row 0 is matched against ``clos``; every later row is matched against the
    previously processed row. For each reference phase (in order), the
    closest still-unused phase of the current row is chosen, measured as the
    distance between the corresponding points on the unit circle. When
    several are equally close, the one in the later column wins. The chosen
    phase is then shifted by multiples of 2π to lie within π of the reference.
    Returns a new array of the same shape and dtype as ``arr_pha``.
    """
    out = np.zeros_like(arr_pha)
    n_cols = arr_pha.shape[1]
    for row_no, row in enumerate(arr_pha):
        reference = clos if row_no == 0 else out[row_no - 1, :]
        unused = list(range(n_cols))
        for col in range(reference.shape[0]):
            target = np.exp(1j * reference[col])
            chosen, smallest = None, 1e10
            for cand in unused:
                gap = np.abs(target - np.exp(1j * row[cand]))
                # ``<=`` keeps the LAST candidate among equally close ones.
                if gap <= smallest:
                    chosen, smallest = cand, gap
            unused.remove(chosen)
            out[row_no, col] = _no_2pi(row[chosen], reference[col])
    return out
def _one_phase_cont(pha, clos):
    """Remove 2π jumps from a 1-D sequence of phases.

    Works on a copy of ``pha``: each entry is shifted by multiples of 2π to lie
    within π of the previous (already corrected) entry. The first entry is
    compared with ``clos`` instead.
    """
    out = np.copy(pha)
    previous = clos
    for pos in range(len(out)):
        out[pos] = _no_2pi(out[pos], previous)
        # Read back from the array so the next comparison sees the stored value.
        previous = out[pos]
    return out