#!/usr/bin/env python
# coding: utf-8

# # `Wannier`: disentanglement with reduced Wannier functions

# In[ ]:


from pythtb import Mesh, Wannier, WFArray
from pythtb.models import haldane
import numpy as np


# ## Haldane model supercell construction
# 
# We begin by constructing the Haldane model in the topological phase. In this phase the non-zero Chern number of the occupied band is a topological obstruction: the occupied band cannot be represented by exponentially localized Wannier functions at all.

# In[ ]:


# Haldane-model parameters, passed positionally to `haldane` below.
delta = 1  # presumably the sublattice on-site energy offset -- confirm in pythtb
t1 = 1  # presumably the nearest-neighbour hopping
t2 = -0.4  # presumably the next-nearest-neighbour (complex) hopping amplitude
prim_model = haldane(delta, t1, t2)

# Chern number of the occupied band. The two tuples are presumably the pair of
# k-directions spanning the plane and the k-mesh size -- confirm in pythtb.
chern = prim_model.chern_number((0, 1), (20, 20))
print(f"Chern number: {chern:0.3f}")


# To circumvent the obstruction, we construct "reduced Wannier" functions, which are exponentially localized functions spanning a smaller subspace of the occupied space. In the primitive cell the occupied space consists of a single band, so there is no proper subspace to choose. This is why we construct a supercell: folding the bands into the smaller supercell Brillouin zone gives more occupied bands to choose from. Here we use a 2x2 supercell, which gives us 4 occupied bands to work with.

# In[ ]:


# Enlarge the unit cell by `n_super_cell` along both lattice directions
# (rows of the matrix are presumably the supercell vectors in primitive
# reduced coordinates).
n_super_cell = 2
sc_red_lat = [[n_super_cell, 0], [0, n_super_cell]]
model = prim_model.make_supercell(sc_red_lat)
model.info(show=True, short=False)


# We construct the `WFArray` and diagonalize the model on a _semi-full_ k-mesh. It is important that the mesh not include the endpoints $k_i=1$, which correspond to the boundaries of the Brillouin zone. By periodicity, $k_i=1$ is equivalent to $k_i=0$, so including both would double-count the boundary. The discrete Fourier transform to Wannier functions assumes that each k-point of the periodic grid appears exactly once. We therefore use a k-mesh covering $[0, 1)$ along each direction. This is the default behavior of `Mesh.build_grid`.
# 
# .. note:: 
#     The `Mesh.build_grid` function automatically imposes periodic boundary conditions on the k-mesh. To state explicitly that a k-mesh is periodic (i.e. has the topology of a torus), we can use the `Mesh.wind_bz` function, specifying the mesh axis and the k-component that is wrapped. This is not necessary here, since `Mesh.build_grid` already does it for us, but for a custom k-mesh we may need `Mesh.wind_bz` to impose periodic boundary conditions.

# In[ ]:


# Uniform grid over both reduced k-directions; the mesh dimensionality and
# axis types follow from the length of `nks`.
nks = (20, 20)  # number of k points along each dimension
mesh = Mesh(dim_k=len(nks), axis_types=["k"] * len(nks))
mesh.build_grid(shape=nks)
print(mesh)


# Now we pass this mesh to the `WFArray` constructor and solve the mesh.

# In[ ]:


# Attach the k-mesh to the supercell lattice, then diagonalize the supercell
# Hamiltonian at every point of the mesh.
wfa = WFArray(model.lattice, mesh)
wfa.solve_model(model)


# We know that Wannierizing the full set of 4 occupied bands is obstructed by the topology of the band structure. The next best thing is to Wannierize a 3-dimensional subspace. To do this, we choose three trial wavefunctions centered on 3 of the 4 low-energy orbitals. That is where we expect the Wannier functions of a topologically trivial subspace to be localized.

# In[ ]:


n_orb = model.norb  # number of orbitals in the supercell
n_occ = n_orb // 2  # number of occupied bands (assumes half-filling)

# Low-energy sites are taken to be the even-indexed orbitals and high-energy
# sites the odd-indexed ones.
low_E_sites = np.arange(0, n_orb, 2)
high_E_sites = np.arange(1, n_orb, 2)

omit_site = 6  # omit one low-energy site so that a proper subspace is Wannierized
if omit_site not in low_E_sites:
    # np.setdiff1d would silently remove nothing here, leaving n_occ trial
    # functions and reintroducing the topological obstruction.
    raise ValueError(
        f"omit_site={omit_site} is not one of the low-energy sites "
        f"{low_E_sites.tolist()}"
    )

# Cast to built-in ints so the trial wavefunctions print as `(0, 1)` rather
# than `(np.int64(0), 1)` under NumPy >= 2.
sites = [int(site) for site in np.setdiff1d(low_E_sites, [omit_site])]
tf_list = [
    [(orb, 1)] for orb in sites
]  # trial wavefunctions in form of [(orbital index, weight)]

n_tfs = len(tf_list)

print(f"Trial wavefunctions: {tf_list}")
print(f"# of Wannier functions: {n_tfs}")
print(f"# of occupied bands: {n_occ}")
print(f"Wannier fraction: {n_tfs / n_occ}")


# ## Optimal Alignment

# Next, we initialize the `Wannier` object with the `WFArray` object (which already holds the solved eigenstates of the model). We initialize the Bloch-like states with the `project` function, which aligns the trial wavefunctions with the target bands specified by `band_idxs`.

# In[ ]:


# Bloch-like starting states: align the trial wavefunctions with all of the
# occupied bands.
occ_band_idxs = list(range(n_occ))
WF = Wannier(wfa)
WF.project(tf_list, band_idxs=occ_band_idxs)


# This already gives us a set of Wannier functions that are exponentially localized, showing that this is a trivial subspace of the obstructed manifold.

# In[ ]:


# Summarize the Wannier functions obtained from the projection
# (presumably their centers and spreads; see `Wannier.info`).
WF.info()


# ## Disentanglement

# We can make these states even more localized through subspace selection via the disentanglement procedure. This picks the 3-dimensional subspace of the 4-band manifold that minimizes the gauge-independent spread.

# In[ ]:


frozen_window = None  # frozen window in energy (none used here)
outer_window = [-4, 0]  # outer window in energy

# Select, from the bands inside the outer window, the subspace with one state
# per trial wavefunction that minimizes the gauge-independent spread. Using
# `n_tfs` keeps the subspace size tied to the trial-wavefunction set above.
WF.disentangle(
    n_wfs=n_tfs,
    frozen_window=frozen_window,
    outer_window=outer_window,
    verbose=True,
    tf_speedup=True,
    max_iter=500,
    tol=1e-10,
)


# In[ ]:


# Summarize the Wannier functions after disentanglement.
WF.info()


# ## Maximal localization
# 
# To obtain maximally localized Wannier functions, we follow this with another projection to initialize a smooth gauge, then maximal localization.
# - Note that we must pass the flag `use_tilde=True` to indicate that we are projecting the trial wavefunctions onto the disentangled ("tilde") states rather than onto the energy eigenstates.

# In[ ]:


# Re-project onto the disentangled ("tilde") states rather than the energy
# eigenstates, to initialize a smooth gauge before maximal localization.
# No trial wavefunctions are passed, so presumably those from the first
# `project` call are reused -- confirm in pythtb.
WF.project(use_tilde=True)


# In[ ]:


# Summarize the Wannier functions after re-projection, before maximal localization.
WF.info()


# In[ ]:


# Maximal localization of the spread. Presumably `alpha` is the step size and
# `grad_min` a gradient-norm stopping threshold -- confirm in pythtb.
WF.maxloc(
    alpha=0.5,
    max_iter=1000,
    tol=1e-10,
    grad_min=1e-10,
    verbose=True,
)


# In[ ]:


# Summarize the maximally localized Wannier functions.
WF.info()


# Now the spreads have been minimized, and the Wannier functions are maximally localized. To help validate that the Wannier functions are indeed exponentially localized, we can plot the decay of a Wannier function's weight away from its center with `plot_decay`. Here we plot the first Wannier function (index 0), showing its magnitude as a function of distance from its center on a logarithmic scale.

# In[ ]:


# Decay of Wannier function 0 away from its center (logarithmic scale), as a
# check of exponential localization.
fig, ax = WF.plot_decay(0, show=True)


# In[ ]:


# Real-space density of Wannier function 0 (presumably |w(r)|^2 on the
# orbital sites -- confirm in pythtb).
fig, ax = WF.plot_density(0, show=True)


# Note that we have effectively broken the primitive translational symmetry of the underlying lattice by choosing trial wavefunctions on only three of the four low-energy sites in the supercell. We can see the Wannier centers using `plot_centers`.

# In[ ]:


# Wannier centers. `color_home_cell` presumably highlights centers in the home
# cell; `pmx`/`pmy` presumably set how many neighbouring cells are drawn.
fig, ax = WF.plot_centers(
    color_home_cell=True,
    center_scale=15,
    legend=True,
    pmx=4,
    pmy=4,
    show=True,
)


# ## Wannier interpolation
# 
# We can view the Wannier interpolated bands by calling `plot_interp_bands`. We specify a set of high-symmetry k-points that defines the one-dimensional path along which the bands are plotted. 

# In[ ]:


# High-symmetry points, presumably in reduced coordinates of the supercell
# reciprocal lattice, keyed by their plot label.
hs_points = {
    r"$\Gamma $": [0, 0],
    r"$K$": [2 / 3, 1 / 3],
    r"$M$": [1 / 2, 1 / 2],
    r"$K^\prime$": [1 / 3, 2 / 3],
}

# Band-structure path Gamma -> K -> M -> K' -> Gamma -> M.
k_label = (r"$\Gamma $", r"$K$", r"$M$", r"$K^\prime$", r"$\Gamma $", r"$M$")
# Fresh list per node so repeated points do not share one list object.
k_nodes = [list(hs_points[label]) for label in k_label]


# In[ ]:


# Wannier-interpolated band energies along the high-symmetry path; only the
# energies are needed, so eigenvectors are not returned.
n_interp = 501  # number of k points along the interpolation path
interp_energies = WF.interp_bands(
    k_nodes,
    n_interp=n_interp,
    ret_eigvecs=False,
)


# In[ ]:


# Exact supercell bands along the path (presumably colored by weight on the
# high-energy sites). Use the same number of k points as the interpolation so
# the two band structures are sampled consistently.
fig, ax = model.plot_bands(
    k_nodes=k_nodes,
    nk=n_interp,
    k_node_labels=k_label,
    proj_orb_idx=high_E_sites,
    cmap="plasma",
)

# Overlay the Wannier-interpolated bands on the same path-distance axis.
(k_vec, k_dist, k_node) = model.k_path(k_nodes, nk=n_interp, report=False)
interp_lines = ax.plot(
    k_dist, interp_energies, ls="--", c="lightgreen", lw=2, zorder=5, alpha=1
)
# Label only the first interpolated curve so the legend gets a single entry.
if interp_lines:
    interp_lines[0].set_label("Wannier interpolation")

# plot windows
if frozen_window is not None:
    ax.axhline(frozen_window[0], ls="--", c="b", label="frozen window")
    ax.axhline(frozen_window[1], ls="--", c="b")

ax.axhline(outer_window[0], ls=":", c="r", label="disentanglement window")
ax.axhline(outer_window[1], ls=":", c="r")
ax.legend()

